nuxt-typescript - 使用 Transformers (DiT)官方 PyTorch 实现的可扩展扩散模型

Created at: 2018-05-23 02:31:25
Language: JavaScript
License: MIT

带变压器的可扩展扩散模型 (DiT)
官方 PyTorch 实现

造纸 |项目页面 |运行 DiT-XL/2拥抱脸部空间 在 Colab 中打开

DiT 样品

此存储库包含 PyTorch 模型定义、预训练权重和训练/采样代码,用于我们的论文探索 带变压器 (DiT) 的扩散模型。你可以在我们的项目页面上找到更多可视化效果。

Scalable Diffusion Models with Transformers
William PeeblesSaining Xie
, 加州大学伯克利分校, 纽约大学

我们训练潜伏扩散模型,用一个运行在 潜伏补丁。我们通过前向通道的视角分析了扩散变压器 (DiT) 的可扩展性 以 Gflops 衡量的复杂度。我们发现具有更高 Gflops 的 DiT---通过增加变压器深度/宽度或 输入代币数量增加---始终具有较低的 FID。除了良好的可扩展性外,我们的 DiT-XL/2 模型在类条件 ImageNet 512×512 和 256×256 基准测试中优于所有先前的扩散模型, 后者的 FID 达到 2.27 的先进水平。

此存储库包含:

  • 🪐 DiT 的简单 PyTorch 实现
  • ⚡️ 在 ImageNet 上训练的预训练类条件 DiT 模型(512x512 和 256x256)
  • 💥 一个独立的 Hugging Face SpaceColab 笔记本,用于运行预训练的 DiT-XL/2 模型
  • 🛸 使用 PyTorch DDP 的 DiT 训练脚本

也可以在此处找到直接在 Hugging Face 中实现 DiT。

diffusers

设置

首先,下载并设置存储库:

git clone https://github.com/facebookresearch/DiT.git
cd DiT

我们提供了一个environment.yml文件,可用于创建 Conda 环境。如果你只想要 若要在 CPU 上本地运行预训练模型,可以从文件中删除 和 要求。

cudatoolkit
pytorch-cuda

conda env create -f environment.yml
conda activate DiT

采样拥抱脸部空间 在 Colab 中打开

更多 DiT 样品

预先训练的 DiT 检查点。你可以使用 sample.py 从我们预先训练的 DiT 模型中采样。我们预训练的 DiT 模型的权重为 根据你使用的型号自动下载。该脚本具有各种参数,可在 256x256 之间切换 和 512x512 模型,调整采样步骤,更改无分类器引导量表等。例如,要从中取样 我们的 512x512 DiT-XL/2 型号,你可以使用:

python sample.py --image-size 512 --seed 1

为方便起见,我们预先训练的 DiT 模型也可以直接在此处下载:

DiT模型 图像分辨率 FID-50K型 初始分数 Gflops的
XL/2型 约256×256 2.27 278.24 119
XL/2型 约512×512 3.04 240.82 525

自定义 DiT 检查点。如果你已经使用 train.py 训练了新的 DiT 模型(见下文),则可以添加参数以改用自己的检查点。例如,从自定义的 EMA 权重中取样 256x256 DiT-L/4 型号,运行:

--ckpt

python sample.py --model DiT-L/4 --image-size 256 --ckpt /path/to/model.pt

训练 DiT

我们在 train.py 中为 DiT 提供了训练脚本。此脚本可用于训练类条件 DiT 型号,但可以很容易地修改以支持其他类型的调节。在 GPU 开启的情况下启动 DiT-XL/2 (256x256) 训练 一个节点:

N

torchrun --nnodes=1 --nproc_per_node=N train.py --model DiT-XL/2 --data-path /path/to/imagenet/train

PyTorch 训练结果

我们使用 PyTorch 训练脚本从头开始训练 DiT-XL/2 和 DiT-B/4 模型 以验证它是否重现了多达数十万次训练迭代的原始 JAX 结果。在我们的实验中,PyTorch 训练的模型给出 与JAX训练的模型相比,结果相似(有时略好),直至合理的随机变化。一些数据点:

DiT模型 列车步骤 FID-50K
(JAX培训)
FID-50K
(PyTorch 训练)
PyTorch 全球训练种子
XL/2型 400 千米赛 19.5 18.1 42
B/4号楼 400 千米赛 68.4 68.9 42
B/4号楼 400 千米赛 68.4 68.3 100

这些模型以 256x256 分辨率进行训练;我们使用 8 架 A100 训练 XL/2,使用 4 架 A100 训练 B/4。请注意,FID 这里是用 250 个 DDPM 采样步骤计算的,使用 VAE 解码器且没有指导 ()。

mse
cfg-scale=1

TF32 Note(对 A100 用户很重要)。当我们运行上述测试时,根据 PyTorch 的默认值禁用了 TF32 matmuls。 我们在顶部启用了它们,因为它使训练和采样方式更快 A100s(也应该用于其他 Ampere GPU),但请注意,与 以上结果。

train.py
sample.py

增强

培训(和抽样)可能会通过以下方式显著加快:

  • [ ] 在 DiT 模型中使用闪光注意力
  • [ ] 在 PyTorch 2.0 中使用
    torch.compile

添加的基本功能会很好:

  • [ ] 监控 FID 和其他指标
  • [ ] 定期从 EMA 模型生成和保存样本
  • [ ] 从检查点恢复训练
  • [ ] 支持 AMP/bfloat16

🔥 功能更新https://github.com/chuanyangjin/fast-DiT 上查看此存储库,预览一系列训练速度、加速和内存节省功能,包括梯度检查点、混合精度训练和预加的 VAE 功能。凭借这些进步,我们仅使用单个 A100 GPU 就实现了 DiT-XL/2 的 0.84 步/秒的训练速度。

评估(FID、初始分数等)

我们包含一个sample_ddp.py脚本,该脚本并行地从 DiT 模型中采样大量图像。此脚本 生成一个样本文件夹以及一个可直接与 ADM 的 TensorFlow 一起使用的文件 用于计算 FID、初始分数和 其他指标。例如,要通过 GPU 从预训练的 DiT-XL/2 模型中采样 50K 图像,请运行:

.npz
N

torchrun --nnodes=1 --nproc_per_node=N sample_ddp.py --model DiT-XL/2 --num-fid-samples 50000

还有其他几个选项;有关详细信息,请参阅sample_ddp.py

与 JAX 的区别

我们的模型最初是在 TPU 上的 JAX 中训练的。此存储库中的权重直接从 JAX 模型移植。 使用不同浮点精度进行采样的结果可能会有细微的差异。我们重新评估了 我们在 FP32 移植了 PyTorch 权重,它们实际上比 JAX 中的采样(2.21 FID 而论文中为 2.27)。

BibTeX的

@article{Peebles2022DiT,
  title={Scalable Diffusion Models with Transformers},
  author={William Peebles and Saining Xie},
  year={2022},
  journal={arXiv preprint arXiv:2212.09748},
}

确认

我们感谢何开明、胡荣航、亚历山大·伯格、舒比克·德布纳特、蒂姆·布鲁克斯、伊利亚·拉多萨沃维奇和泰特·肖的有益讨论。 威廉·皮布尔斯(William Peebles)得到了NSF研究生研究奖学金的支持。

这个代码库借鉴了 OpenAI 的扩散存储库,最著名的是 ADM

许可证

代码和型号权重在 CC-BY-NC 下获得许可。有关详细信息,请参阅LICENSE.txt