pfgmpp

pfgmpp

统一扩散和泊松流的生成模型框架

PFGM++是一个统一扩散模型和泊松流生成模型的框架,通过在高维空间嵌入路径来生成数据。它可以退化为PFGM或扩散模型,并允许通过选择额外维度D来平衡模型的鲁棒性和刚性。实验显示,特定D值的PFGM++模型在CIFAR-10和FFHQ数据集上的性能超越了现有的扩散模型,并对建模误差表现出更好的鲁棒性。

PFGM++生成模型图像生成深度学习人工智能Github开源项目

PFGM++:释放物理启发生成模型的潜力

PWC

论文《PFGM++:释放物理启发生成模型的潜力》的PyTorch实现

作者:徐逸伦刘子明田永龙、童尚源、Max TegmarkTommi S. Jaakkola

[幻灯片]

CIFAR-10FFHQ-64LSUN-Church-256
cifar_2ffhq_2lsun_2

😇 相比PFGM和扩散模型的改进:

  • 不再需要PFGM中的大批量训练目标,从而实现灵活的条件生成和更高效的训练!
  • 更一般的 $D \in \mathbb{R}^+$ 维增广变量。PFGM++包含了PFGM和扩散模型PFGM对应 $D=1$,扩散模型对应 $D\to \infty$。
  • 存在 $(1,\infty)$ 中间的最佳点 $D^*$!
  • 较小的 $D$ 比扩散模型($D\to \infty$)更稳健
  • 可以调整模型的稳健性和刚性
  • 可以直接迁移任何现有扩散模型($D\to \infty$)的精调超参数

*摘要:我们提出了一个名为PFGM++*的通用框架,统一了扩散模型和泊松流生成模型(PFGM)。这些模型通过在 $N{+}D$ 维空间中嵌入路径来实现 $N$ 维数据的生成轨迹,同时仍然使用 $D$ 个额外变量的简单标量范数来控制进程。新模型在 $D{=}1$ 时退化为PFGM,在 $D{\to}\infty$ 时退化为扩散模型。选择 $D$ 的灵活性使我们能够在稳健性和刚性之间权衡,因为增加 $D$ 会导致数据和额外变量范数之间的耦合更加集中。我们摒弃了PFGM中使用的有偏大批量场目标,而是提供了一个类似于扩散模型的无偏扰动目标。为了探索不同的 $D$ 选择,我们提供了一种直接对齐方法,用于将精调的扩散模型($D{\to} \infty$)超参数转移到任何有限 $D$ 值。我们的实验表明,有限 $D$ 的模型可以优于之前最先进的扩散模型,在CIFAR-10/FFHQ $64{\times}64$ 数据集上,当 $D{=}2048/128$ 时FID分数为 $1.91/2.43$。在类条件生成中,$D{=}2048$ 在CIFAR-10上产生了当前最先进的FID $1.74$。此外,我们证明较小 $D$ 的模型对建模错误表现出更好的稳健性

示意图


大纲

我们的实现基于EDM仓库。我们首先提供了一个指导,说明如何快速将精调扩散模型($D\to \infty$)的超参数转移到PFGM++家族($D\in \mathbb{R}^+$),如EDMDDPM,这种方式与任务/数据集无关(我们在论文的第4节(将超参数转移到有限 $D$)和附录C.2中提供了更多细节)。我们基于他们原始的命令行突出显示了我们对训练采样和评估的修改。我们在检查点部分提供了检查点。

我们还提供了来自EDM仓库的原始设置说明,如环境要求和数据集准备。

通过 $r=\sigma\sqrt{D}$ 公式进行转移指导

下面我们提供了如何快速将精调的扩散模型($D\to \infty$)超参数(如 $\sigma_{\textrm{max}}$ 和 $p(\sigma)$)转移到有限 $D$ 的指导。我们采用论文中的 $r=\sigma\sqrt{D}$ 公式进行对齐(参见第4节)。请将以下指导作为原型使用。

😀 请根据你的任务/数据集/模型调整增广维度 $D$。

训练超参数转移。我们提供的示例是本仓库中 loss.py 的简化版本。

示意图

def train(y, N, D, pfgmpp): ''' y: 小批量干净图像 N: 数据维度 D: 增广维度 pfgmpp: 使用PFGM++框架,否则为扩散模型(D\to\infty情况)。选项:0 | 1 ''' if not pfgmpp: ###################### === 扩散模型 === ###################### rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device) sigma = (rnd_normal * self.P_std + self.P_mean).exp() # 从p(\sigma)采样sigma n = torch.randn_like(y) * sigma D_yn = net(y + n, sigma) loss = (D_yn - y) ** 2 ###################### === 扩散模型 === ###################### else: ###################### === PFGM++ === ###################### rnd_normal = torch.randn(images.shape[0], device=images.device) sigma = (rnd_normal * self.P_std + self.P_mean).exp() # 从p(\sigma)采样sigma r = sigma.double() * np.sqrt(self.D).astype(np.float64) # r=sigma\sqrt{D}公式

= 从扰动核 p_r 采样噪声 =

从逆贝塔分布采样

samples_norm = np.random.beta(a=self.N / 2., b=self.D / 2., size=images.shape[0]).astype(np.double) inverse_beta = samples_norm / (1 - samples_norm +1e-8) inverse_beta = torch.from_numpy(inverse_beta).to(images.device).double()

通过变量变换从 p_r(R) 采样 (参见附录 B)

samples_norm = (r * torch.sqrt(inverse_beta +1e-8)).view(len(samples_norm), -1)

均匀采样角度分量

gaussian = torch.randn(images.shape[0], self.N).to(samples_norm.device) unit_gaussian = gaussian / torch.norm(gaussian, p=2, dim=1, keepdim=True)

构造扰动

perturbation_x = (unit_gaussian * samples_norm).float()

= 从扰动核 p_r 采样噪声 =

sigma = sigma.reshape((len(sigma), 1, 1, 1)) n = perturbation_x.view_as(y) D_yn = net(y + n, sigma) loss = (D_yn - y) ** 2 ###################### === PFGM++ === ######################

采样超参数转换。我们提供的示例是这个仓库中 [generate.py] 的简化版本。如下图所示,唯一的修改是先验采样过程。因此,我们在代码片段中仅包含了扩散模型和 PFGM++ 的先验采样比较。

![示意图]

def generate(sigma_max, N, D, pfgmpp) ''' sigma_max: 扩散模型的起始条件 N: 数据维度 D: 增广维度 pfgmpp: 使用 PFGM++ 框架,否则为扩散模型(D\to\infty 情况)。选项:0 | 1 ''' if not pfgmpp: ###################### === 扩散模型 === ###################### x = torch.randn_like(data_size) * sigma_max ###################### === 扩散模型 === ###################### else: ###################### === PFGM++ === ###################### # 从逆贝塔分布采样 r = sigma_max * np.sqrt(self.D) # r=sigma\sqrt{D} 公式 samples_norm = np.random.beta(a=self.N / 2., b=self.D / 2., size=data_size).astype(np.double) inverse_beta = samples_norm / (1 - samples_norm +1e-8) inverse_beta = torch.from_numpy(inverse_beta).to(images.device).double() # 通过变量变换从 p_r(R) 采样 (参见附录 B) samples_norm = (r * torch.sqrt(inverse_beta +1e-8)).view(len(samples_norm), -1) # 均匀采样角度分量 gaussian = torch.randn(images.shape[0], self.N).to(samples_norm.device) unit_gaussian = gaussian / torch.norm(gaussian, p=2, dim=1, keepdim=True) # 构造扰动 x = (unit_gaussian * samples_norm).float().view(data_size) ###################### === PFGM++ === ####################### ######################################################## # Heun 二阶方法(又称改进的欧拉方法) # ########################################################

请参阅附录 C.2了解从 EDMDDPM 进行详细超参数转换的程序。

训练 PFGM++

您可以使用 train.py 训练新模型。例如:

torchrun --standalone --nproc_per_node=8 train.py --outdir=training-runs --name exp_name \ --data=datasets/cifar10-32x32.zip --cond=0 --arch=arch \ --pfgmpp=1 --batch 512 \ --aug_dim aug_dim (--resume resume_path) exp_name: 实验名称 aug_dim: D(额外维度) arch: 模型架构。选项:ncsnpp | ddpmpp pfgmpp: 使用 PFGM++ 框架,否则为扩散模型(D\to\infty 情况)。选项:0 | 1 resume_path: 恢复检查点的路径

上述示例使用默认的批量大小为 512 张图像(由 --batch 控制),这些图像在 8 个 GPU 之间均匀分配(由 --nproc_per_node 控制),每个 GPU 处理 64 张图像。训练大型模型可能会耗尽 GPU 内存;避免这种情况的最佳方法是限制每个 GPU 的批量大小,例如 --batch-gpu=32。这使用梯度累积来产生与使用完整的每 GPU 批量相同的结果。有关完整的选项列表,请参阅 [python train.py --help]。

每次训练运行的结果都保存在新创建的目录 training-runs/exp_name 中。训练循环会定期导出网络快照(training-state-*.pt)(由 --dump 控制)。网络快照可用于使用 generate.py 生成图像,训练状态可用于稍后恢复训练(--resume)。其他有用信息记录在 log.txtstats.jsonl 中。为了监控训练收敛情况,我们建议查看训练损失(stats.jsonl 中的 "Loss/loss"),并定期使用 generate.pyfid.py 评估 training-state-*.pt 的 FID。

对于 FFHQ 数据集,将 --data=datasets/cifar10-32x32.zip 替换为 --data=datasets/ffhq-64x64.zip

注意: 原始 EDM 仓库提供了更多数据集:FFHQ、AFHQv2、ImageNet-64。由于计算资源有限,我们没有测试 PFGM++ 在这些数据集上的性能。然而,我们相信某些有限的 D(最佳点)会优于扩散模型(D\to\infty 情况)。如果您有这些结果,请告诉我们 😀

生成和评估

  • 生成 50k 个样本:

    torchrun --standalone --nproc_per_node=8 generate.py \ --seeds=0-49999 --outdir=./training-runs/exp_name \ --pfgmpp=1 --aug_dim=aug_dim (--use_pickle=1)(--save_images) exp_name: 实验名称 aug_dim: D(额外维度) arch: 模型架构。选项:ncsnpp | ddpmpp pfgmpp: 使用 PFGM++ 框架,否则为扩散模型(D\to\infty 情况)。选项:0 | 1。(默认:0) use_pickle: 当检查点以 pickle 格式(.pkl)存储时。(默认:0)

请注意,FID 的数值在不同的随机种子间会有变化,并且对图像数量非常敏感。默认情况下,fid.py 总是使用 50,000 张生成的图像;提供更少的图像会导致错误,而提供更多则会使用随机子集。为了减少随机变化的影响,我们建议使用不同的种子重复计算多次,例如 --seeds=0-49999--seeds=50000-99999--seeds=100000-149999。在 EDM 论文中,他们计算了每个 FID 三次并报告了最小值。

对于 FID 与受控 $\alpha$/NFE/量化的对比,请使用 generate_alpha.py/generate_steps.py/generate_quant.py 进行生成。

  • FID 评估

    torchrun --standalone --nproc_per_node=8 fid.py calc --images=training-runs/exp_name --ref=fid-refs/cifar10-32x32.npz --num 50000 exp_name: 实验名称

检查点

所有检查点都提供在这个 [Google Drive 文件夹] 中。我们从 [EDM] 仓库借用了特定于数据集的超参数,例如批量大小、学习率等。如果您想尝试更多数据集(如 ImageNet 64),请参考该仓库的超参数。由于历史原因,一些检查点是 .pkl 格式,使用 generate.py 进行图像生成时请添加 --use_pickle=1 标志。在运行上述生成命令之前,请将检查点下载到指定的 ./training-runs/exp_name 文件夹中。

模型检查点路径$D$FID选项
cifar10-ncsnpp-D-128pfgmpp/cifar10_ncsnpp_D_128/1281.92--cond=0 --arch=ncsnpp --pfgmpp=1 --aug_dim=128
cifar10-ncsnpp-D-2048pfgmpp/cifar10_ncsnpp_D_2048/20481.91--cond=0 --arch=ncsnpp --pfgmpp=1 --aug_dim=2048
cifar10-ncsnpp-D-2048-conditionalpfgmpp/cifar10_ncsnpp_D_2048_conditional/20481.74--cond=1 --arch=ncsnpp --pfgmpp=1 --aug_dim=2048
cifar10-ncsnpp-D-inf (EDM)pfgmpp/cifar10_ncsnpp_D_inf/$\infty$1.98--cond=0 --arch=ncsnpp
ffhq-ddpm-D-128pfgmpp/ffhq_ddpm_D_128/1282.43--cond=0 --arch=ddpmpp --batch=256 --cres=1,2,2,2 --lr=2e-4 --dropout=0.05 --augment=0.15 --pfgmpp=1 --aug_dim=128
ffhq-ddpm-D-inf (EDM)pfgmpp/ffhq_ddpm_D_inf/$\infty$2.53--cond=0 --arch=ddpmpp --batch=256 --cres=1,2,2,2 --lr=2e-4 --dropout=0.05 --augment=0.15

EDM仓库的设置说明

要求

  • Python库:具体的库依赖请参见environment.yml。您可以使用以下命令和Miniconda3创建并激活Python环境:
    • conda env create -f environment.yml -n edm
    • conda activate edm
  • Docker用户:

准备数据集

数据集的存储格式与StyleGAN相同:未压缩的ZIP存档,包含未压缩的PNG文件和用于标签的元数据文件dataset.json。可以从包含图像的文件夹创建自定义数据集;更多信息请参见python dataset_tool.py --help

CIFAR-10: 下载CIFAR-10 Python版本并转换为ZIP存档:

python dataset_tool.py --source=downloads/cifar10/cifar-10-python.tar.gz \
    --dest=datasets/cifar10-32x32.zip
python fid.py ref --data=datasets/cifar10-32x32.zip --dest=fid-refs/cifar10-32x32.npz

FFHQ: 下载Flickr-Faces-HQ数据集的1024x1024图像,并转换为64x64分辨率的ZIP存档:

python dataset_tool.py --source=downloads/ffhq/images1024x1024 \
    --dest=datasets/ffhq-64x64.zip --resolution=64x64
python fid.py ref --data=datasets/ffhq-64x64.zip --dest=fid-refs/ffhq-64x64.npz

AFHQv2: 下载更新的Animal Faces-HQ数据集afhq-v2-dataset),并转换为64x64分辨率的ZIP存档:

python dataset_tool.py --source=downloads/afhqv2 \
    --dest=datasets/afhqv2-64x64.zip --resolution=64x64
python fid.py ref --data=datasets/afhqv2-64x64.zip --dest=fid-refs/afhqv2-64x64.npz

ImageNet: 下载ImageNet对象定位挑战赛数据集,并转换为64x64分辨率的ZIP存档:

python dataset_tool.py --source=downloads/imagenet/ILSVRC/Data/CLS-LOC/train \
    --dest=datasets/imagenet-64x64.zip --resolution=64x64 --transform=center-crop
python fid.py ref --data=datasets/imagenet-64x64.zip --dest=fid-refs/imagenet-64x64.npz

编辑推荐精选

讯飞智文

讯飞智文

一键生成PPT和Word,让学习生活更轻松

讯飞智文是一个利用 AI 技术的项目,能够帮助用户生成 PPT 以及各类文档。无论是商业领域的市场分析报告、年度目标制定,还是学生群体的职业生涯规划、实习避坑指南,亦或是活动策划、旅游攻略等内容,它都能提供支持,帮助用户精准表达,轻松呈现各种信息。

AI办公办公工具AI工具讯飞智文AI在线生成PPTAI撰写助手多语种文档生成AI自动配图热门
讯飞星火

讯飞星火

深度推理能力全新升级,全面对标OpenAI o1

科大讯飞的星火大模型,支持语言理解、知识问答和文本创作等多功能,适用于多种文件和业务场景,提升办公和日常生活的效率。讯飞星火是一个提供丰富智能服务的平台,涵盖科技资讯、图像创作、写作辅助、编程解答、科研文献解读等功能,能为不同需求的用户提供便捷高效的帮助,助力用户轻松获取信息、解决问题,满足多样化使用场景。

热门AI开发模型训练AI工具讯飞星火大模型智能问答内容创作多语种支持智慧生活
Spark-TTS

Spark-TTS

一种基于大语言模型的高效单流解耦语音令牌文本到语音合成模型

Spark-TTS 是一个基于 PyTorch 的开源文本到语音合成项目,由多个知名机构联合参与。该项目提供了高效的 LLM(大语言模型)驱动的语音合成方案,支持语音克隆和语音创建功能,可通过命令行界面(CLI)和 Web UI 两种方式使用。用户可以根据需求调整语音的性别、音高、速度等参数,生成高质量的语音。该项目适用于多种场景,如有声读物制作、智能语音助手开发等。

Trae

Trae

字节跳动发布的AI编程神器IDE

Trae是一种自适应的集成开发环境(IDE),通过自动化和多元协作改变开发流程。利用Trae,团队能够更快速、精确地编写和部署代码,从而提高编程效率和项目交付速度。Trae具备上下文感知和代码自动完成功能,是提升开发效率的理想工具。

AI工具TraeAI IDE协作生产力转型热门
咔片PPT

咔片PPT

AI助力,做PPT更简单!

咔片是一款轻量化在线演示设计工具,借助 AI 技术,实现从内容生成到智能设计的一站式 PPT 制作服务。支持多种文档格式导入生成 PPT,提供海量模板、智能美化、素材替换等功能,适用于销售、教师、学生等各类人群,能高效制作出高品质 PPT,满足不同场景演示需求。

讯飞绘文

讯飞绘文

选题、配图、成文,一站式创作,让内容运营更高效

讯飞绘文,一个AI集成平台,支持写作、选题、配图、排版和发布。高效生成适用于各类媒体的定制内容,加速品牌传播,提升内容营销效果。

热门AI辅助写作AI工具讯飞绘文内容运营AI创作个性化文章多平台分发AI助手
材料星

材料星

专业的AI公文写作平台,公文写作神器

AI 材料星,专业的 AI 公文写作辅助平台,为体制内工作人员提供高效的公文写作解决方案。拥有海量公文文库、9 大核心 AI 功能,支持 30 + 文稿类型生成,助力快速完成领导讲话、工作总结、述职报告等材料,提升办公效率,是体制打工人的得力写作神器。

openai-agents-python

openai-agents-python

OpenAI Agents SDK,助力开发者便捷使用 OpenAI 相关功能。

openai-agents-python 是 OpenAI 推出的一款强大 Python SDK,它为开发者提供了与 OpenAI 模型交互的高效工具,支持工具调用、结果处理、追踪等功能,涵盖多种应用场景,如研究助手、财务研究等,能显著提升开发效率,让开发者更轻松地利用 OpenAI 的技术优势。

Hunyuan3D-2

Hunyuan3D-2

高分辨率纹理 3D 资产生成

Hunyuan3D-2 是腾讯开发的用于 3D 资产生成的强大工具,支持从文本描述、单张图片或多视角图片生成 3D 模型,具备快速形状生成能力,可生成带纹理的高质量 3D 模型,适用于多个领域,为 3D 创作提供了高效解决方案。

3FS

3FS

一个具备存储、管理和客户端操作等多种功能的分布式文件系统相关项目。

3FS 是一个功能强大的分布式文件系统项目,涵盖了存储引擎、元数据管理、客户端工具等多个模块。它支持多种文件操作,如创建文件和目录、设置布局等,同时具备高效的事件循环、节点选择和协程池管理等特性。适用于需要大规模数据存储和管理的场景,能够提高系统的性能和可靠性,是分布式存储领域的优质解决方案。

下拉加载更多