此存储库包含 PyTorch 模型定义、预训练权重和训练/采样代码,用于我们的论文探索 带变压器 (DiT) 的扩散模型。你可以在我们的项目页面上找到更多可视化效果。
Scalable Diffusion Models with Transformers
William Peebles, Saining Xie
, 加州大学伯克利分校, 纽约大学
我们训练潜伏扩散模型,用一个运行在 潜伏补丁。我们通过前向通道的视角分析了扩散变压器 (DiT) 的可扩展性 以 Gflops 衡量的复杂度。我们发现具有更高 Gflops 的 DiT---通过增加变压器深度/宽度或 输入代币数量增加---始终具有较低的 FID。除了良好的可扩展性外,我们的 DiT-XL/2 模型在类条件 ImageNet 512×512 和 256×256 基准测试中优于所有先前的扩散模型, 后者的 FID 达到 2.27 的先进水平。
此存储库包含:
也可以在此处找到直接在 Hugging Face 中实现 DiT。
diffusers
首先,下载并设置存储库:
git clone https://github.com/facebookresearch/DiT.git
cd DiT
我们提供了一个environment.yml
文件,可用于创建 Conda 环境。如果你只想要
若要在 CPU 上本地运行预训练模型,可以从文件中删除 和 要求。
cudatoolkit
pytorch-cuda
conda env create -f environment.yml
conda activate DiT
预先训练的 DiT 检查点。你可以使用 sample.py
从我们预先训练的 DiT 模型中采样。我们预训练的 DiT 模型的权重为
根据你使用的型号自动下载。该脚本具有各种参数,可在 256x256 之间切换
和 512x512 模型,调整采样步骤,更改无分类器引导量表等。例如,要从中取样
我们的 512x512 DiT-XL/2 型号,你可以使用:
python sample.py --image-size 512 --seed 1
为方便起见,我们预先训练的 DiT 模型也可以直接在此处下载:
DiT模型 | 图像分辨率 | FID-50K型 | 初始分数 | Gflops的 |
---|---|---|---|---|
XL/2型 | 约256×256 | 2.27 | 278.24 | 119 |
XL/2型 | 约512×512 | 3.04 | 240.82 | 525 |
自定义 DiT 检查点。如果你已经使用 train.py
训练了新的 DiT 模型(见下文),则可以添加参数以改用自己的检查点。例如,从自定义的 EMA 权重中取样
256x256 DiT-L/4 型号,运行:
--ckpt
python sample.py --model DiT-L/4 --image-size 256 --ckpt /path/to/model.pt
我们在 train.py
中为 DiT 提供了训练脚本。此脚本可用于训练类条件
DiT 型号,但可以很容易地修改以支持其他类型的调节。在 GPU 开启的情况下启动 DiT-XL/2 (256x256) 训练
一个节点:
N
torchrun --nnodes=1 --nproc_per_node=N train.py --model DiT-XL/2 --data-path /path/to/imagenet/train
我们使用 PyTorch 训练脚本从头开始训练 DiT-XL/2 和 DiT-B/4 模型 以验证它是否重现了多达数十万次训练迭代的原始 JAX 结果。在我们的实验中,PyTorch 训练的模型给出 与JAX训练的模型相比,结果相似(有时略好),直至合理的随机变化。一些数据点:
DiT模型 | 列车步骤 | FID-50K (JAX培训) |
FID-50K (PyTorch 训练) |
PyTorch 全球训练种子 |
---|---|---|---|---|
XL/2型 | 400 千米赛 | 19.5 | 18.1 | 42 |
B/4号楼 | 400 千米赛 | 68.4 | 68.9 | 42 |
B/4号楼 | 400 千米赛 | 68.4 | 68.3 | 100 |
这些模型以 256x256 分辨率进行训练;我们使用 8 架 A100 训练 XL/2,使用 4 架 A100 训练 B/4。请注意,FID 这里是用 250 个 DDPM 采样步骤计算的,使用 VAE 解码器且没有指导 ()。
mse
cfg-scale=1
TF32 Note(对 A100 用户很重要)。当我们运行上述测试时,根据 PyTorch 的默认值禁用了 TF32 matmuls。 我们在顶部启用了它们,因为它使训练和采样方式更快 A100s(也应该用于其他 Ampere GPU),但请注意,与 以上结果。
train.py
sample.py
培训(和抽样)可能会通过以下方式显著加快:
torch.compile
添加的基本功能会很好:
🔥 功能更新在 https://github.com/chuanyangjin/fast-DiT 上查看此存储库,预览一系列训练速度、加速和内存节省功能,包括梯度检查点、混合精度训练和预加的 VAE 功能。凭借这些进步,我们仅使用单个 A100 GPU 就实现了 DiT-XL/2 的 0.84 步/秒的训练速度。
我们包含一个sample_ddp.py
脚本,该脚本并行地从 DiT 模型中采样大量图像。此脚本
生成一个样本文件夹以及一个可直接与 ADM 的 TensorFlow 一起使用的文件
用于计算 FID、初始分数和
其他指标。例如,要通过 GPU 从预训练的 DiT-XL/2 模型中采样 50K 图像,请运行:
.npz
N
torchrun --nnodes=1 --nproc_per_node=N sample_ddp.py --model DiT-XL/2 --num-fid-samples 50000
还有其他几个选项;有关详细信息,请参阅sample_ddp.py
。
我们的模型最初是在 TPU 上的 JAX 中训练的。此存储库中的权重直接从 JAX 模型移植。 使用不同浮点精度进行采样的结果可能会有细微的差异。我们重新评估了 我们在 FP32 移植了 PyTorch 权重,它们实际上比 JAX 中的采样(2.21 FID 而论文中为 2.27)。
@article{Peebles2022DiT,
title={Scalable Diffusion Models with Transformers},
author={William Peebles and Saining Xie},
year={2022},
journal={arXiv preprint arXiv:2212.09748},
}
我们感谢何开明、胡荣航、亚历山大·伯格、舒比克·德布纳特、蒂姆·布鲁克斯、伊利亚·拉多萨沃维奇和泰特·肖的有益讨论。 威廉·皮布尔斯(William Peebles)得到了NSF研究生研究奖学金的支持。
这个代码库借鉴了 OpenAI 的扩散存储库,最著名的是 ADM。
代码和型号权重在 CC-BY-NC 下获得许可。有关详细信息,请参阅LICENSE.txt
。