# FreeU：Diffusion U-Net 中的免费午餐（FreeU: Free Lunch in Diffusion U-Net）

# FreeU：扩散U-Net中的免费午餐

## FreeU 代码

```python

def Fourier_filter(x, threshold, scale):
    """Scale the low-frequency band of ``x`` in the 2-D Fourier domain.

    The last two dimensions of ``x`` are treated as the spatial (H, W) axes.
    After an FFT and ``fftshift`` (which places the zero frequency at index
    ``(H // 2, W // 2)``), a square window of half-width ``threshold`` around
    that centre is multiplied by ``scale``; all other frequencies are kept
    unchanged. The result is transformed back and its real part returned.

    :param x: real tensor whose last two dimensions are spatial
        (e.g. an [N x C x H x W] feature map).
    :param threshold: half-width, in frequency bins, of the low-frequency window.
    :param scale: factor applied to the frequencies inside the window.
    :return: filtered real tensor with the same shape, dtype and device as ``x``.
    """
    orig_dtype = x.dtype
    # torch.fft has no (or restricted) support for float16/bfloat16, e.g. on
    # CPU or for non-power-of-two sizes on CUDA, so compute in float32.
    if orig_dtype in (torch.float16, torch.bfloat16):
        x = x.float()

    # FFT over the spatial axes, with the zero frequency shifted to the centre.
    x_freq = fft.fftn(x, dim=(-2, -1))
    x_freq = fft.fftshift(x_freq, dim=(-2, -1))

    H, W = x_freq.shape[-2:]
    # Build the mask on x's device rather than hard-coding .cuda(); an (H, W)
    # mask broadcasts over any leading (batch/channel) dimensions.
    mask = torch.ones((H, W), device=x.device)
    crow, ccol = H // 2, W // 2
    # Clamp the lower edges: a negative slice start would wrap around to the
    # end of the axis and leave the centre (DC) bin outside the window.
    mask[max(crow - threshold, 0):crow + threshold,
         max(ccol - threshold, 0):ccol + threshold] = scale
    # Apply the mask; without this step the filter would be a no-op.
    x_freq = x_freq * mask

    # IFFT back to the spatial domain.
    x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
    x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real

    return x_filtered.to(orig_dtype)

class Free_UNetModel(UNetModel):
    """UNetModel whose decoder applies FreeU re-weighting.

    In ``forward``, just before each decoder (output) block concatenates the
    current hidden state with its skip connection:

    * if the hidden state has 1280 channels, its first 640 channels are
      multiplied by ``b1`` and the skip tensor is passed through
      ``Fourier_filter`` (threshold 1) with scale ``s1``;
    * if the hidden state has 640 channels, its first 320 channels are
      multiplied by ``b2`` and the skip tensor is filtered with scale ``s2``.

    Every other decoder block runs unchanged. NOTE(review): the 1280/640
    channel counts are hard-coded and presumably match the SD 1.x/2.x decoder
    stages -- confirm before using with other UNet configurations.

    :param b1: backbone factor of the first stage block of decoder.
    :param b2: backbone factor of the second stage block of decoder.
    :param s1: skip factor of the first stage block of decoder.
    :param s2: skip factor of the second stage block of decoder.
    :param args: forwarded unchanged to ``UNetModel.__init__``.
    :param kwargs: forwarded unchanged to ``UNetModel.__init__``.
    """

    def __init__(self, b1, b2, s1, s2, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Backbone factors (b*) and skip factors (s*) for the two FreeU stages.
        self.b1, self.b2 = b1, b2
        self.s1, self.s2 = s1, s2

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        """Run the UNet on ``x`` with FreeU applied inside the decoder.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via cross-attention; passed
            to every block.
        :param y: an [N] Tensor of labels; required if and only if the model
            is class-conditional.
        :param kwargs: accepted and ignored.
        :return: ``self.id_predictor(h)`` when ``predict_codebook_ids`` is
            set, otherwise ``self.out(h)``.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"

        emb = self.time_embed(
            timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        )
        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

        # FreeU settings keyed by the decoder hidden state's channel count:
        # (leading channels to scale, backbone factor, skip factor).
        freeu_stages = {
            1280: (640, self.b1, self.s1),
            640: (320, self.b2, self.s2),
        }

        # Encoder: every block's output is kept as a skip connection.
        skips = []
        h = x.type(self.dtype)
        for block in self.input_blocks:
            h = block(h, emb, context)
            skips.append(h)

        h = self.middle_block(h, emb, context)

        # Decoder: skip connections are consumed in reverse (LIFO) order.
        for block in self.output_blocks:
            skip = skips.pop()
            stage = freeu_stages.get(h.shape[1])
            if stage is not None:
                n_scaled, backbone, skip_scale = stage
                # Scale the leading backbone channels in place.
                h[:, :n_scaled] = h[:, :n_scaled] * backbone
                skip = Fourier_filter(skip, threshold=1, scale=skip_scale)
            h = block(th.cat([h, skip], dim=1), emb, context)

        h = h.type(x.dtype)
        head = self.id_predictor if self.predict_codebook_ids else self.out
        return head(h)

```

## 参数

### SD1.4

`b1 = 1.2`, `b2 = 1.4`, `s1 = 0.9`, `s2 = 0.2`

### SD2.1

`b1 = 1.1`, `b2 = 1.2`, `s1 = 0.9`, `s2 = 0.2`

### 尝试其他参数时的推荐范围

- `b1`：1 ≤ b1 ≤ 1.2
- `b2`：1.2 ≤ b2 ≤ 1.6
- `s1`：s1 ≤ 1
- `s2`：s2 ≤ 1

