rcg

rcg

RCG框架实现突破性无条件图像生成性能

RCG是一种创新的自监督图像生成框架,在ImageNet 256x256数据集上达到了无条件图像生成的最佳性能。该框架缩小了无条件和有条件图像生成之间的性能差距。项目提供基于PyTorch的GPU实现,包含表示扩散模型(RDM)以及MAGE、DiT、ADM和LDM等多种像素生成器的训练和评估代码。同时提供预训练模型和可视化工具,便于研究人员复现和拓展相关工作。

RCGPyTorch图像生成自监督学习神经网络Github开源项目

RCG PyTorch 实现

<p align="center"> <img src="https://yellow-cdn.veclightyear.com/835a84d5/3f46cfbd-4a30-4e55-83a7-5b3391b767f2.png" width="560"> </p>

这是论文无条件生成的回归:一种自监督表示生成方法的 PyTorch/GPU 实现:

@Article{RCG2023,
  author  = {Tianhong Li and Dina Katabi and Kaiming He},
  journal = {arXiv:2312.03701},
  title   = {Return of Unconditional Generation: A Self-supervised Representation Generation Method},
  year    = {2023},
}

RCG 是一个自条件图像生成框架,在 ImageNet 256x256 上实现了最先进的无条件图像生成性能,弥合了长期存在的无条件和类条件图像生成之间的性能差距。

<p align="center"> <img src="https://yellow-cdn.veclightyear.com/835a84d5/b52af2ce-905b-4f76-98be-b4f061e61dfb.png" width="560"> </p>

更新

2024 年 3 月

  • 更新 FID 评估和结果,遵循 ADM suite,通过在 torch-fidelity 中硬编码 ADM 统计信息。 可以通过以下方式安装修改后的 torch-fidelity
pip install -e git+https://github.com/LTH14/torch-fidelity.git@master#egg=torch-fidelity
  • 更新训练 400 个周期的 ADM 检查点(与原论文相同)。
  • 包含使用 RCG 训练 DiT-XL 的脚本和预训练检查点(400 个周期)。
  • 更新 Arxiv。

准备工作

数据集

下载 ImageNet 数据集,并将其放在您的 IMAGENET_DIR 中。 准备 ImageNet 验证集以进行 FID 评估:

python prepare_imgnet_val.py --data_path ${IMAGENET_DIR} --output_dir imagenet-val

要对验证集进行 FID 评估,请执行 pip install torch-fidelity,这将安装原始的 torch-fidelity 包。

安装

下载代码

git clone https://github.com/LTH14/rcg.git
cd rcg

可以使用以下命令创建并激活名为 rcg 的合适 conda 环境:

conda env create -f environment.yaml
conda activate rcg

使用此链接 下载预训练的 VQGAN 分词器,命名为 vqgan_jax_strongaug.ckpt

使用此链接 下载预训练的 moco v3 ViT-B 编码器,并将其命名为 pretrained_enc_ckpts/mocov3/vitb.pth.tar

使用此链接 下载预训练的 moco v3 ViT-L 编码器,并将其命名为 pretrained_enc_ckpts/mocov3/vitl.pth.tar

使用方法

RDM

使用 4 个 V100 GPU 训练 Moco v3 ViT-B 表示扩散模型:

python -m torch.distributed.launch --nproc_per_node=4 --nnodes=1 --node_rank=0 \
main_rdm.py \
--config config/rdm/mocov3vitb_simplemlp_l12_w1536.yaml \
--batch_size 128 --input_size 256 \
--epochs 200 \
--blr 1e-6 --weight_decay 0.01 \
--output_dir ${OUTPUT_DIR} \
--data_path ${IMAGENET_DIR} \
--dist_url tcp://${MASTER_SERVER_ADDRESS}:2214

要继续之前中断的训练会话,请将 --resume 设置为存储 checkpoint-last.pthOUTPUT_DIR

下表提供了论文中使用的预训练 Moco v3 ViT-B/ViT-L RDM 权重:

<table><tbody> <!-- START TABLE --> <!-- TABLE HEADER --> <th valign="bottom"></th> <th valign="bottom">Moco v3 ViT-B</th> <th valign="bottom">Moco v3 ViT-L</th> <!-- TABLE BODY --> <tr><td align="left">类无条件 RDM</td> <td align="center"><a href="https://drive.google.com/file/d/1gdsvzKLmmBWuF4Ymy4rQ_T1t6dDHnTEA/view?usp=sharing">Google Drive</a> / <a href="config/rdm/mocov3vitb_simplemlp_l12_w1536.yaml">配置</a></td> <td align="center"><a href="https://drive.google.com/file/d/1E5E3i9LRpSy0tVF7NA0bGXEh4CrjHAXz/view?usp=sharing">Google Drive</a> / <a href="config/rdm/mocov3vitl_simplemlp_l12_w1536.yaml">配置</a></td> </tr> <tr><td align="left">类条件 RDM</td> <td align="center"><a href="https://drive.google.com/file/d/1roanmVfg-UaddVehstQErvByqi0OYs2R/view?usp=sharing">Google Drive</a> / <a href="config/rdm/mocov3vitb_simplemlp_l12_w1536_classcond.yaml">配置</a></td> <td align="center"><a href="https://drive.google.com/file/d/1lZmXOcdHE97Qmn2azNAo2tNVX7dtTAkY/view?usp=sharing">Google Drive</a> / <a href="config/rdm/mocov3vitl_simplemlp_l12_w1536_classcond.yaml">配置</a></td> </tr> </tbody></table>

像素生成器:MAGE

使用 64 个 V100 GPU 训练一个基于 Moco v3 ViT-B 表示的 MAGE-B,训练 200 个周期:

python -m torch.distributed.launch --nproc_per_node=8 --nnodes=8 --node_rank=0 \
main_mage.py \
--pretrained_enc_arch mocov3_vit_base \
--pretrained_enc_path pretrained_enc_ckpts/mocov3/vitb.pth.tar --rep_drop_prob 0.1 \
--use_rep --rep_dim 256 --pretrained_enc_withproj --pretrained_enc_proj_dim 256 \
--pretrained_rdm_cfg ${RDM_CFG_PATH} --pretrained_rdm_ckpt ${RDM_CKPT_PATH} \
--rdm_steps 250 --eta 1.0 --temp 6.0 --num_iter 20 --num_images 50000 --cfg 0.0 \
--batch_size 64 --input_size 256 \
--model mage_vit_base_patch16 \
--mask_ratio_min 0.5 --mask_ratio_max 1.0 --mask_ratio_mu 0.75 --mask_ratio_std 0.25 \
--epochs 200 \
--warmup_epochs 10 \
--blr 1.5e-4 --weight_decay 0.05 \
--output_dir ${OUTPUT_DIR} \
--data_path ${IMAGENET_DIR} \
--dist_url tcp://${MASTER_SERVER_ADDRESS}:2214

要训练基于 Moco v3 ViT-L 表示的 MAGE-L, 更改 Moco v3 ViT-L RDM 的 RDM_CFG_PATHRDM_CKPT_PATH,以及以下参数:

--pretrained_enc_arch mocov3_vit_large --pretrained_enc_path pretrained_enc_ckpts/mocov3/vitl.pth.tar --temp 11.0 --model mage_vit_large_patch16

恢复:将 --resume 设置为存储 checkpoint-last.pthOUTPUT_DIR

评估:将 --resume 设置为预训练的 MAGE 检查点, 并在上述脚本中包含 --evaluate 标志。

预训练模型

<table><tbody> <!-- START TABLE --> <!-- TABLE HEADER --> <th valign="bottom"></th> <th valign="bottom">表示条件 MAGE-B</th> <th valign="bottom">表示条件 MAGE-L</th> <!-- TABLE BODY --> <tr><td align="left">检查点</td> <td align="center"><a href="https://drive.google.com/file/d/1iZY0ujWp5GVochTLj0U6j4HgVTOyWPUI/view?usp=sharing">Google Drive</a></td> <td align="center"><a href="https://drive.google.com/file/d/1nQh9xCqjQCd78zKwn2L9eLfLyVosb1hp/view?usp=sharing">Google Drive</a></td> </tr> <tr><td align="left">类无条件生成(无 CFG)</td> <td align="center">FID=3.98,IS=177.8</td> <td align="center">FID=3.44,IS=186.9</td> </tr> <tr><td align="left">类无条件生成(有 CFG)</td> <td align="center">FID=3.19,IS=214.9(cfg=1.0)</td> <td align="center">FID=2.15,IS=253.4(cfg=6.0)</td> </tr> <tr><td align="left">类条件生成(无 CFG)</td> <td align="center">FID=3.50,IS=194.9</td> <td align="center">FID=2.99,IS=215.5</td> </tr> <tr><td align="left">类条件生成(有 CFG)</td> <td align="center">FID=3.18,IS=242.6(cfg=1.0)</td> <td align="center">FID=2.25,IS=300.7(cfg=6.0)</td> </tr> </tbody></table>

可视化:使用 viz_rcg.ipynb 可视化生成结果。

类无条件生成示例:

<p align="center"> <img src="https://yellow-cdn.veclightyear.com/835a84d5/0a1704ff-068e-4b70-95a0-942e34ea6669.jpg" width="800"> </p>

类条件生成示例:

<p align="center"> <img src="https://yellow-cdn.veclightyear.com/835a84d5/48cd7973-e792-4777-b818-9ba687df837c.jpg" width="800"> </p>

像素生成器:DiT

要训练一个基于Moco v3 ViT-B表示的DiT-L模型,使用128个V100 GPU进行400轮训练:

python -m torch.distributed.launch --nproc_per_node=8 --nnodes=16 --node_rank=0 \
main_dit.py \
--rep_cond --rep_dim 256 \
--pretrained_enc_arch mocov3_vit_base \
--pretrained_enc_path pretrained_enc_ckpts/mocov3/vitb.pth.tar \
--pretrained_rdm_cfg ${RDM_CFG_PATH} \
--pretrained_rdm_ckpt ${RDM_CKPT_PATH} \
--batch_size 16 --image_size 256 --dit_model DiT-L/2 --num-sampling-steps ddim25 \
--epochs 400  \
--lr 1e-4 --weight_decay 0.0 \
--output_dir ${OUTPUT_DIR} \
--data_path ${IMAGENET_DIR} \
--dist_url tcp://${MASTER_SERVER_ADDRESS}:2214

注意:有时对于DiT-XL,batch_size=16会导致内存溢出。将其改为12或14对性能影响很小。

恢复训练:将--resume设置为存储checkpoint-last.pthOUTPUT_DIR

评估:将--resume设置为预训练的DiT检查点,并在上述脚本中包含--evaluate标志。设置--num-sampling-steps 250以获得更好的生成性能。

基于Moco v3 ViT-B表示的预训练DiT-XL/2(400轮)可以在这里下载(FID=4.89,IS=143.2)。

像素生成器:ADM

要训练一个基于Moco v3 ViT-B表示的ADM模型,使用128个V100 GPU进行100轮训练:

python -m torch.distributed.launch --nproc_per_node=8 --nnodes=16 --node_rank=0 \
main_adm.py \
--rep_cond --rep_dim 256 \
--pretrained_enc_arch mocov3_vit_base \
--pretrained_enc_path pretrained_enc_ckpts/mocov3/vitb.pth.tar \
--pretrained_rdm_cfg ${RDM_CFG_PATH} \
--pretrained_rdm_ckpt ${RDM_CKPT_PATH} \
--batch_size 2 --image_size 256 \
--epochs 100  \
--lr 1e-4 --weight_decay 0.0 \
--attention_resolutions 32,16,8 --diffusion_steps 1000 \
--learn_sigma --noise_schedule linear \
--num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown \
--use_scale_shift_norm \
--gen_timestep_respacing ddim25 --use_ddim \
--output_dir ${OUTPUT_DIR} \
--data_path ${IMAGENET_DIR} \
--dist_url tcp://${MASTER_SERVER_ADDRESS}:2214

恢复训练:将--resume设置为存储checkpoint-last.pthOUTPUT_DIR

评估:将--resume设置为预训练的ADM检查点,并在上述脚本中包含--evaluate标志。设置--gen_timestep_respacing 250并禁用--use_ddim以获得更好的生成性能。

基于Moco v3 ViT-B表示的预训练ADM(400轮)可以在这里下载(FID=6.24,IS=136.9)。

像素生成器:LDM

使用此链接下载分词器,并将其命名为vqgan-ckpts/ldm_vqgan_f8_16384/checkpoints/last.ckpt

要训练一个基于Moco v3 ViT-B表示的LDM-8模型,使用64个V100 GPU进行40轮训练:

python -m torch.distributed.launch --nproc_per_node=8 --nnodes=8 --node_rank=0 \
main_ldm.py \
--config config/ldm/cin-ldm-vq-f8-repcond.yaml \
--batch_size 4 \
--epochs 40 \
--blr 2.5e-7 --weight_decay 0.01 \
--output_dir ${OUTPUT_DIR} \
--data_path ${IMAGENET_DIR} \
--dist_url tcp://${MASTER_SERVER_ADDRESS}:2214

恢复训练:将--resume设置为存储checkpoint-last.pthOUTPUT_DIR

评估:将--resume设置为预训练的LDM检查点,并在上述脚本中包含--evaluate标志。

基于Moco v3 ViT-B表示的预训练LDM(40轮)可以在这里下载(FID=11.30,IS=101.9)。

联系方式

如果您有任何问题,请随时通过电子邮件(tianhong@mit.edu)与我联系。祝您使用愉快!

编辑推荐精选

讯飞智文

讯飞智文

一键生成PPT和Word,让学习生活更轻松

讯飞智文是一个利用 AI 技术的项目,能够帮助用户生成 PPT 以及各类文档。无论是商业领域的市场分析报告、年度目标制定,还是学生群体的职业生涯规划、实习避坑指南,亦或是活动策划、旅游攻略等内容,它都能提供支持,帮助用户精准表达,轻松呈现各种信息。

AI办公办公工具AI工具讯飞智文AI在线生成PPTAI撰写助手多语种文档生成AI自动配图热门
讯飞星火

讯飞星火

深度推理能力全新升级,全面对标OpenAI o1

科大讯飞的星火大模型,支持语言理解、知识问答和文本创作等多功能,适用于多种文件和业务场景,提升办公和日常生活的效率。讯飞星火是一个提供丰富智能服务的平台,涵盖科技资讯、图像创作、写作辅助、编程解答、科研文献解读等功能,能为不同需求的用户提供便捷高效的帮助,助力用户轻松获取信息、解决问题,满足多样化使用场景。

热门AI开发模型训练AI工具讯飞星火大模型智能问答内容创作多语种支持智慧生活
Spark-TTS

Spark-TTS

一种基于大语言模型的高效单流解耦语音令牌文本到语音合成模型

Spark-TTS 是一个基于 PyTorch 的开源文本到语音合成项目,由多个知名机构联合参与。该项目提供了高效的 LLM(大语言模型)驱动的语音合成方案,支持语音克隆和语音创建功能,可通过命令行界面(CLI)和 Web UI 两种方式使用。用户可以根据需求调整语音的性别、音高、速度等参数,生成高质量的语音。该项目适用于多种场景,如有声读物制作、智能语音助手开发等。

Trae

Trae

字节跳动发布的AI编程神器IDE

Trae是一种自适应的集成开发环境(IDE),通过自动化和多元协作改变开发流程。利用Trae,团队能够更快速、精确地编写和部署代码,从而提高编程效率和项目交付速度。Trae具备上下文感知和代码自动完成功能,是提升开发效率的理想工具。

AI工具TraeAI IDE协作生产力转型热门
咔片PPT

咔片PPT

AI助力,做PPT更简单!

咔片是一款轻量化在线演示设计工具,借助 AI 技术,实现从内容生成到智能设计的一站式 PPT 制作服务。支持多种文档格式导入生成 PPT,提供海量模板、智能美化、素材替换等功能,适用于销售、教师、学生等各类人群,能高效制作出高品质 PPT,满足不同场景演示需求。

讯飞绘文

讯飞绘文

选题、配图、成文,一站式创作,让内容运营更高效

讯飞绘文,一个AI集成平台,支持写作、选题、配图、排版和发布。高效生成适用于各类媒体的定制内容,加速品牌传播,提升内容营销效果。

热门AI辅助写作AI工具讯飞绘文内容运营AI创作个性化文章多平台分发AI助手
材料星

材料星

专业的AI公文写作平台,公文写作神器

AI 材料星,专业的 AI 公文写作辅助平台,为体制内工作人员提供高效的公文写作解决方案。拥有海量公文文库、9 大核心 AI 功能,支持 30 + 文稿类型生成,助力快速完成领导讲话、工作总结、述职报告等材料,提升办公效率,是体制打工人的得力写作神器。

openai-agents-python

openai-agents-python

OpenAI Agents SDK,助力开发者便捷使用 OpenAI 相关功能。

openai-agents-python 是 OpenAI 推出的一款强大 Python SDK,它为开发者提供了与 OpenAI 模型交互的高效工具,支持工具调用、结果处理、追踪等功能,涵盖多种应用场景,如研究助手、财务研究等,能显著提升开发效率,让开发者更轻松地利用 OpenAI 的技术优势。

Hunyuan3D-2

Hunyuan3D-2

高分辨率纹理 3D 资产生成

Hunyuan3D-2 是腾讯开发的用于 3D 资产生成的强大工具,支持从文本描述、单张图片或多视角图片生成 3D 模型,具备快速形状生成能力,可生成带纹理的高质量 3D 模型,适用于多个领域,为 3D 创作提供了高效解决方案。

3FS

3FS

一个具备存储、管理和客户端操作等多种功能的分布式文件系统相关项目。

3FS 是一个功能强大的分布式文件系统项目,涵盖了存储引擎、元数据管理、客户端工具等多个模块。它支持多种文件操作,如创建文件和目录、设置布局等,同时具备高效的事件循环、节点选择和协程池管理等特性。适用于需要大规模数据存储和管理的场景,能够提高系统的性能和可靠性,是分布式存储领域的优质解决方案。

下拉加载更多