FreeU - FreeU: Free Lunch in Diffusion U-Net

Created at: 2023-09-14 21:12:40
Language:
License: MIT

FreeU: Free Lunch in Diffusion U-Net

Chenyang Si, Ziqi Huang, Yuming Jiang, Ziwei Liu
S-Lab, Nanyang Technological University

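Project Page | Video

We propose FreeU, a method that substantially improves the sample quality of diffusion models at no cost: no training, no additional parameters, and no increase in memory or sampling time.

📖 For more visual results, check out our project page

FreeU Code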

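import torch
import torch as th  # alias used by Free_UNetModel.forward below
import torch.fft as fft


def Fourier_filter(x, threshold, scale):
    # FFT: move to the frequency domain and center the zero frequency
    x_freq = fft.fftn(x, dim=(-2, -1))
    x_freq = fft.fftshift(x_freq, dim=(-2, -1))

    B, C, H, W = x_freq.shape
    # Build the mask on the input's device instead of assuming CUDA
    mask = torch.ones((B, C, H, W), device=x.device)

    # Scale the low-frequency band around the center by `scale`
    crow, ccol = H // 2, W // 2
    mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
    x_freq = x_freq * mask

    # IFFT: undo the shift and keep the real part
    x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
    x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real

    return x_filtered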

# UNetModel and timestep_embedding are assumed to come from the Stable Diffusion (ldm) codebase.
class Free_UNetModel(UNetModel):
    """
    :param b1: backbone factor of the first stage block of decoder.
    :param b2: backbone factor of the second stage block of decoder.
    :param s1: skip factor of the first stage block of decoder.
    :param s2: skip factor of the second stage block of decoder.
    """

    def __init__(
        self,
        b1,
        b2,
        s1,
        s2,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.b1 = b1 
        self.b2 = b2
        self.s1 = s1
        self.s2 = s2

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, context)
            hs.append(h)
        h = self.middle_block(h, emb, context)
        for module in self.output_blocks:
            hs_ = hs.pop()

            # --------------- FreeU code -----------------------
            # Only operate on the first two stages
            if h.shape[1] == 1280:
                h[:,:640] = h[:,:640] * self.b1
                hs_ = Fourier_filter(hs_, threshold=1, scale=self.s1)
            if h.shape[1] == 640:
                h[:,:320] = h[:,:320] * self.b2
                hs_ = Fourier_filter(hs_, threshold=1, scale=self.s2)
            # ---------------------------------------------------------

            h = th.cat([h, hs_], dim=1)
            h = module(h, emb, context)
        h = h.type(x.dtype)
        if self.predict_codebook_ids:
            return self.id_predictor(h)
        else:
            return self.out(h)
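
To see what the two FreeU operations do in isolation, here is a minimal CPU sketch on random tensors. It is not part of the FreeU release: `fourier_filter` restates the function above with the mask created on the input's device, and `freeu_step` is a hypothetical helper that mirrors the body of the decoder loop for a single stage. The tensor shapes are arbitrary stand-ins, not real U-Net features.

```python
import torch
import torch.fft as fft


def fourier_filter(x, threshold, scale):
    """Scale the centered low-frequency band of x (shape B x C x H x W)."""
    x_freq = fft.fftshift(fft.fftn(x, dim=(-2, -1)), dim=(-2, -1))
    _, _, H, W = x_freq.shape
    mask = torch.ones(x_freq.shape, device=x.device)
    crow, ccol = H // 2, W // 2
    mask[..., crow - threshold:crow + threshold,
         ccol - threshold:ccol + threshold] = scale
    x_freq = fft.ifftshift(x_freq * mask, dim=(-2, -1))
    return fft.ifftn(x_freq, dim=(-2, -1)).real


def freeu_step(h, skip, b, s):
    """Hypothetical helper: scale the first half of the backbone channels
    by b, scale the skip feature's low frequencies by s, then concatenate."""
    h = h.clone()  # avoid modifying the caller's tensor in place
    half = h.shape[1] // 2
    h[:, :half] = h[:, :half] * b
    skip = fourier_filter(skip, threshold=1, scale=s)
    return torch.cat([h, skip], dim=1)


torch.manual_seed(0)
h = torch.randn(2, 8, 16, 16)     # stand-in backbone feature
skip = torch.randn(2, 8, 16, 16)  # stand-in skip feature
out = freeu_step(h, skip, b=1.2, s=0.9)
print(out.shape)
```

With `threshold=1` the mask covers a 2x2 block around the zero frequency, so one visible effect is that each channel's spatial mean in the skip feature is multiplied by `s`, while `scale=1` leaves the feature unchanged up to floating-point error.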

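Parameters

Feel free to adjust these parameters based on your models, image/video style, or tasks. The following parameters are for reference only.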

SD1.4:

b1: 1.2, b2: 1.4, s1: 0.9, s2: 0.2

SD2.1:

b1: 1.1, b2: 1.2, s1: 0.9, s2: 0.2
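
One way to keep these reference values in code is a small lookup table. This is only a sketch: `FREEU_PRESETS` and `freeu_kwargs` are illustrative names, not part of the FreeU code, and the numbers are copied from the two entries above.

```python
# Reference values copied from the README; the names below are illustrative.
FREEU_PRESETS = {
    "SD1.4": {"b1": 1.2, "b2": 1.4, "s1": 0.9, "s2": 0.2},
    "SD2.1": {"b1": 1.1, "b2": 1.2, "s1": 0.9, "s2": 0.2},
}


def freeu_kwargs(model_name):
    """Return a copy of the reference parameters for a known model name."""
    if model_name not in FREEU_PRESETS:
        known = ", ".join(sorted(FREEU_PRESETS))
        raise ValueError(f"no preset for {model_name!r}; known presets: {known}")
    return dict(FREEU_PRESETS[model_name])


params = freeu_kwargs("SD2.1")
print(params)
```

Since `Free_UNetModel.__init__` takes `b1`, `b2`, `s1`, and `s2` ahead of the base arguments, the returned dict could be unpacked as keyword arguments alongside whatever configuration your base `UNetModel` already uses.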

Range for More Parameters

When trying additional parameters, consider the following ranges:

  • b1: 1 ≤ b1 ≤ 1.2
  • b2: 1.2 ≤ b2 ≤ 1.6
  • s1: s1 ≤ 1
  • s2: s2 ≤ 1
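
The ranges above can be sketched as a quick sanity check before a run. The function name `check_freeu_params` is hypothetical, and because the ranges are suggestions rather than hard limits, it returns messages instead of raising.

```python
def check_freeu_params(b1, b2, s1, s2):
    """Return messages for values outside the suggested FreeU ranges."""
    messages = []
    if not 1.0 <= b1 <= 1.2:
        messages.append(f"b1={b1} is outside the suggested range [1, 1.2]")
    if not 1.2 <= b2 <= 1.6:
        messages.append(f"b2={b2} is outside the suggested range [1.2, 1.6]")
    if s1 > 1.0:
        messages.append(f"s1={s1} is above the suggested maximum of 1")
    if s2 > 1.0:
        messages.append(f"s2={s2} is above the suggested maximum of 1")
    return messages


# The SD1.4 reference values fall inside every suggested range.
print(check_freeu_params(b1=1.2, b2=1.4, s1=0.9, s2=0.2))
# Three of these four values fall outside the suggestions.
for message in check_freeu_params(b1=1.5, b2=1.0, s1=1.1, s2=0.2):
    print(message)
```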

Results from the Community

If you have tried FreeU and want to share your results, please let us know, and we can link to them here.

BibTeX

@article{si2023freeu,
  title={FreeU: Free Lunch in Diffusion U-Net},
  author={Si, Chenyang and Huang, Ziqi and Jiang, Yuming and Liu, Ziwei},
  journal={arXiv preprint arXiv:2309.11497},
  year={2023}
}

🗞️ License

Distributed under the MIT License. See LICENSE for more information.