diffusion-forcing

diffusion-forcing

创新机器学习方法结合下一步预测和全序列扩散技术

Diffusion Forcing是一种结合下一步预测和全序列扩散技术的机器学习方法。该项目为视频预测、迷宫规划和时间序列分析等任务提供了框架。通过时间注意力机制,Diffusion Forcing可生成长序列预测并在复杂环境中进行规划。该方法在Minecraft和DMLab视频数据集以及迷宫规划任务中表现优异。项目包含使用说明和预训练模型,便于研究者快速上手和复现结果。

Diffusion Forcing深度学习视频预测迷宫规划模型训练Github开源项目

扩散强制:下一个词预测遇上全序列扩散

[项目网站] [论文]

Boyuan Chen<sup>1</sup>, Diego Martí Monsó<sup>2</sup>, Yilun Du<sup>1</sup>, Max Simchowitz<sup>1</sup>, Russ Tedrake<sup>1</sup>, Vincent Sitzmann<sup>1</sup> <br/> <sup>1</sup>麻省理工学院 <sup>2</sup>慕尼黑工业大学 </br>

这是我们论文扩散强制:下一个词预测遇上全序列扩散的v1.5代码库。main分支包含我们最新的带有时间注意力的重新实现(推荐使用),而paper分支包含原始论文用于复现目的的RNN代码。

图片

@misc{chen2024diffusionforcingnexttokenprediction,
      title={扩散强制:下一个词预测遇上全序列扩散},
      author={Boyuan Chen and Diego Marti Monso and Yilun Du and Max Simchowitz and Russ Tedrake and Vincent Sitzmann},
      year={2024},
      eprint={2407.01392},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2407.01392},
}

项目说明

设置

如果你想使用我们最新改进的视频和规划实现(使用时间注意力而非RNN),请保持在当前分支。如果你对复现原始论文的声明感兴趣,请通过git checkout paper切换到原始论文使用的分支。

运行conda create python=3.10 -n diffusion-forcing创建环境。 运行conda activate diffusion-forcing激活此环境。

安装时间序列、视频和机器人学所需的依赖:

pip install -r requirements.txt

注册一个wandb账户以进行云端日志记录和检查点保存。在命令行中,运行wandb login登录。

然后修改configurations/config.yaml中的wandb实体为你的wandb账户。

可选地,如果你想进行迷宫规划,由于d4rl的过时依赖,需要安装以下复杂的依赖。这涉及首先安装mujoco 210,然后运行

pip install -r extra_requirements.txt

使用预训练检查点快速开始

由于数据集很大,我们提供了一个迷你子集和预训练检查点,供你快速测试我们的模型!要使用它们,请从这里下载迷你数据集和检查点到项目根目录,并用tar -xzvf quickstart_atten.tar.gz解压。文件将出现在dataoutputs/xxx.ckpt中。如果你在发布检查点之前fork了项目,请确保也git pull上游以使用最新版本的代码!

然后运行以下命令,并前往wandb面板查看结果。

视频预测:

我们的可视化是并排的,左侧是预测,右侧是真实数据。然而,由于序列高度随机,真实数据预计不会与预测对齐。提供真实数据仅用于提供质量参考。

自回归生成与训练长度相同的Minecraft视频: python -m main +name=sample_minecraft_pretrained load=outputs/minecraft.ckpt experiment.tasks=[validation]

要让模型滚动超出训练长度,只需在上述命令后附加dataset.validation_multiplier=8,它将滚动8倍于训练的最大序列长度。

上述检查点使用少量帧训练了100K步。我们已经验证了扩散强制在潜在扩散设置中有效,并且可以扩展到更多标记而不牺牲组合性(使用本仓库之外的一些额外技术)!敬请期待我们的下一个项目!

迷宫规划:

随着我们获得更多见解,迷宫规划设置有所改变,详情请参见训练部分相应段落。我们尚未重新实现MCTG,但你已经可以在wandb日志上看到不错的可视化效果。

中等迷宫

python -m main experiment=exp_planning algorithm=df_planning dataset=maze2d_medium dataset.action_mean=[] dataset.action_std=[] dataset.observation_mean=[3.5092521,3.4765592] dataset.observation_std=[1.3371079,1.52102] load=outputs/maze2d_medium_x.ckpt experiment.tasks=[validation] algorithm.guidance_scale=3 +name=maze2d_medium_x_sampling

大型迷宫

python -m main experiment=exp_planning algorithm=df_planning dataset=maze2d_large dataset.observation_mean=[3.7296331,5.3047247] dataset.observation_std=[1.8070312,2.5687592] dataset.action_mean=[] dataset.action_std=[] load=outputs/maze2d_large_x.ckpt experiment.tasks=[validation] algorithm.guidance_scale=2 +name=maze2d_large_x_sampling

我们还探索了几个更多的设置,但尚未重新实现原始论文中的所有内容。如果你对那些检查点感兴趣,请查看本README文件的源代码中被注释掉的检查点加载说明。 python -m main experiment=exp_planning algorithm=df_planning dataset=maze2d_large dataset.observation_std=[3.6140624,5.1375184,9.747382,10.5974788] dataset.action_mean=[] dataset.action_std=[] load=outputs/maze2d_large_xv.ckpt experiment.tasks=[validation] algorithm.guidance_scale=4 +name=maze2d_large_xv_sampling

这里还有一个检查点,我们在其中采用了扩散动作,这是一个具有挑战性的设置,之前的论文中没有涉及。虽然我们还没有让它像原始的RNN版本的扩散强制那样工作得很好,但它确实有不错的数据。你可以稍微调高guidance scale。

python -m main experiment=exp_planning algorithm=df_planning dataset=maze2d_medium dataset.observation_std=[2.67,3.04,8,8] dataset.action_std=[6,6] load=outputs/maze2d_medium_xva.ckpt experiment.tasks=[validation] algorithm.guidance_scale=2 algorithm.open_loop_horizon=10 +name=maze2d_medium_xva_sampling

python -m main experiment=exp_planning algorithm=df_planning dataset=maze2d_large dataset.observation_std=[3.62,5.14,9.76,10.6] dataset.action_std=[3,3] load=outputs/maze2d_large_xva.ckpt experiment.tasks=[validation] algorithm.guidance_scale=2 algorithm.open_loop_horizon=10 +name=maze2d_large_xva_sampling

训练

视频

视频预测需要下载大型数据集。首先,如果你按照"使用预训练检查点快速开始"部分下载了mini子集,请删除mini子集文件夹data/minecraftdata/dmlab,因为这次我们必须下载完整的数据集。我们已经在Python中编写了代码,如果数据集不存在,它会为你下载。由于源代码的下载速度较慢,这可能需要几天时间。如果你更喜欢自己通过bash脚本来完成,请参考原始TECO数据集中的bash脚本,并使用他们README的Dataset部分中的dmlab.shminecraft.sh,也许可以将bash脚本分割成并行脚本。

然后只需运行相应的命令:

Minecraft

python -m main +name=your_experiment_name algorithm=df_video dataset=video_minecraft

DMLab

python -m main +name=your_experiment_name algorithm=df_video dataset=video_dmlab algorithm.weight_decay=1e-3 algorithm.diffusion.architecture.network_size=48 algorithm.diffusion.architecture.attn_dim_head=32 algorithm.diffusion.architecture.attn_resolutions=[8,16,32,64] algorithm.diffusion.beta_schedule=cosine

无因果掩码

只需在命令后添加algorithm.causal=False即可。

尝试采样

请查看"加载检查点以进行评估"段落,了解如何使用load=加载检查点。然后,运行完全相同的训练命令,添加experiment.tasks=[validation] load={wandb_run_id}来加载检查点并尝试采样。

要了解如何生成比训练序列更长的序列,你可以在"使用预训练检查点快速开始"部分找到说明。请记住,无限滚动而不使用滑动窗口是paper分支上原始RNN实现的特性,而这个版本必须使用滑动窗口,因为它是时间注意力。

默认情况下,我们运行带稳定化的自回归采样。要联合采样下两个标记,你可以在上述命令后添加以下内容:algorithm.scheduling_matrix=full_sequence algorithm.chunk_size=2

迷宫规划

对于那些只想重现原始论文而不是transformer架构的人,请查看代码的paper分支。

中等迷宫

python -m main experiment=exp_planning algorithm=df_planning dataset=maze2d_medium dataset.action_mean=[] dataset.action_std=[] dataset.observation_mean=[3.5092521,3.4765592] dataset.observation_std=[1.3371079,1.52102] +name=maze2d_medium_x

大型迷宫

python -m main experiment=exp_planning algorithm=df_planning dataset=maze2d_large dataset.observation_mean=[3.7296331,5.3047247] dataset.observation_std=[1.8070312,2.5687592] dataset.action_mean=[] dataset.action_std=[] +name=maze2d_large_x

模型训练后运行规划

请查看"加载检查点以进行评估"段落,了解如何使用load=加载检查点。要进行采样,只需在上述训练完成的命令后添加load={wandb_id_of_above_runs} experiment.tasks=[validation] algorithm.guidance_scale=2 +name=maze2d_sampling。你可以自由调整guidance_scale,范围从1到5。

这个版本的迷宫规划使用了与原始论文不同的扩散强制版本 - 在对扩散强制进行后续研究时,我们意识到使用独立噪声进行训练实际上也构建了因果和非因果模型之间的平滑插值,因为我们可以通过完全噪声(完全因果)或部分噪声(插值)来掩盖未来。最好的是,在这种设置下,你仍然可以通过金字塔采样考虑因果不确定性,方法是在不同噪声级别掩盖标记,并且你仍然可以拥有灵活的视野,因为你可以告诉模型填充的条目是纯噪声,这是扩散强制的独特能力。

我们还反思了一下环境,得出结论认为原始指标不一定是一个好指标,因为迷宫规划应该奖励那些能够最快规划到达目标路线的人,而不是最终到达那里的缓慢行走代理。数据集从未包含停留在目标处的数据,所以代理应该在达到目标后离开。我认为Diffuser有一个不公平的优势,只是生成缓慢的计划,恰好让代理在目标附近停留更长时间并获得很高的奖励,利用了环境设计的缺陷(一个好的设计应该包括对到达目标所需时间较长的惩罚)。因此,在这个版本的代码中,我们只是优化了灵活视野的规划,试图尽快到达目标,如果离开目标,规划器会自动返回目标,因为停留从未出现在数据集中。你可以在wandb日志界面中看到我们设计的新指标。

时间序列和机器人

请查看paper分支获取原始论文使用的代码。如果我以后有时间,我也会用transformer重新实现这两个领域,以完成这个分支。

更新日志

日期备注
2024/7/30将RNN升级为时间注意力,将原始代码移至'paper'分支
2024/7/3代码初始发布。如果你有问题或发现此版本中的任何错误,请给我发邮件。

基础设施说明

本仓库源自Boyuan Chen的研究模板仓库。根据MIT许可证,你必须在README.md中保留上述句子并保留LICENSE文件以对作者表示致谢。

所有实验可以通过python -m main +name=xxxx {选项}来启动,你可以在本文后面找到更多细节。

代码库会在可用时自动使用CUDA或Macbook M1 GPU。

对于Slurm集群(如MIT超级云),你可以在登录节点上运行python -m main cluster=mit_supercloud {选项}。它会自动生成Slurm脚本并在计算节点上运行。即使计算节点离线,脚本仍会自动将wandb日志同步到云端,延迟不到1分钟。按照"添加Slurm集群"部分,添加你自己的Slurm也很容易。

为你的项目修改

首先,使用此模板创建一个新仓库。确保新仓库的名称是你想用于wandb日志记录的名称。

按照algorithms/README.mdalgorithms/diffusion_forcing/df_video.py中的示例代码,在algorithms中添加你的方法和基准。对于PyTorch实验,将你的算法编写为pytorch lightning pl.LightningModule,它有详尽的文档。快速入门可以阅读此链接中的"Define a LightningModule"部分。最后,为你添加的每个算法在configurations/algorithm中添加一个yaml配置文件,仿照configurations/algorithm/df_video.yaml

按照datasets/README.mddatasets/video中的示例代码,在datasets中添加你的数据集。最后,为你添加的每个数据集在configurations/dataset中添加一个yaml配置文件,仿照configurations/dataset/video_dmlab.yaml

按照experiments/README.mdexperiments/exp_video.py中的示例代码,在experiments中添加你的实验。然后在experiments/__init__.py中注册你的实验。最后,为你添加的每个实验在configurations/experiment中添加一个yaml配置文件,仿照configurations/experiment/exp_video.yaml

修改configurations/config.yaml,将algorithm设置为你想在configurations/algorithm中使用的yaml文件;将experiment设置为你想在configurations/experiment中使用的yaml文件;将dataset设置为你想在configurations/dataset中使用的yaml文件,如果不需要数据集则设为null。注意字段不应包含.yaml后缀。

设置完成!

进入你的项目根目录。现在你可以通过python main.py +name=<为你的实验命名>来启动新实验。你可以通过添加algorithm=xxxdataset=xxx等参数来运行基准或不同的数据集。你也可以按照下一节的说明覆盖任何yaml配置。

特别注意,如果你想为你的实验定义一个新任务(例如除了trainingtest之外的任务),你可以在实验类中将其定义为一个方法,并使用experiment.tasks=[task_name]来运行它。假设你在training任务之前有一个generate_dataset任务,并且你在实验类中实现了它,那么你可以运行python -m main +name xxxx experiment.tasks=[generate_dataset,training]来在训练前执行它。

传递参数

我们使用hydra而不是argparse来配置每个代码层级的参数。你既可以在configuration文件夹中编写静态配置,也可以在运行时覆盖部分静态配置,使用命令行参数。

例如,参数algorithm=example_classifier experiment.lr=1e-3将覆盖configurations/experiment/example_classifier.yaml中的lr变量。参数wandb.mode将覆盖configurations/config.yaml文件中wandb命名空间下的mode

所有静态配置和运行时覆盖将自动记录到云端。

恢复检查点和日志

对于机器学习实验,所有检查点和日志都会自动记录到云端,因此你可以在另一台服务器上恢复它们。只需在命令行参数中添加resume={wandb_run_id}即可恢复。run_id可以在wandb仪表板中的wandb运行URL中找到。默认情况下,一次运行中的最新检查点会无限期存储,而该运行中较早的检查点会在5天后删除以节省存储空间。

另一方面,有时你可能想要使用不同的run_id启动新运行,但仍然加载先前的检查点。这可以通过设置load={wandb_run_id / ckpt path}标志来完成。

加载检查点进行评估

参数experiment.tasks=[task_name1,task_name2](注意这里需要[]括号)允许选择要执行的一系列任务,如trainingvalidationtest。因此,要测试机器学习检查点,你可以运行python -m main load={your_wandb_run_id} experiment.tasks=[test]

更一般地,任务名称是你的实验类中相应的方法名称。对于BaseLightningExperiment,我们已经为你定义了三个方法:trainingvalidationtest,但你也可以通过在预期任务名称下为你的实验类创建方法来定义自己的任务。

调试

我们提供了一个有用的调试标志,你可以通过python main.py debug=True启用。这将启用数值错误跟踪,并为你的实验、算法和数据集类设置cfg.debugTrue。但是,这个调试标志会使机器学习代码变得非常慢,因为它会自动跟踪所有参数和梯度!

添加Slurm集群

通过在configurations/cluster中添加yaml文件,可以很容易地添加你自己的Slurm集群。你可以参考configurations/cluster/mit_supercloud.yaml作为示例。

编辑推荐精选

Trae

Trae

字节跳动发布的AI编程神器IDE

Trae是一种自适应的集成开发环境(IDE),通过自动化和多元协作改变开发流程。利用Trae,团队能够更快速、精确地编写和部署代码,从而提高编程效率和项目交付速度。Trae具备上下文感知和代码自动完成功能,是提升开发效率的理想工具。

AI工具TraeAI IDE协作生产力转型热门
问小白

问小白

全能AI智能助手,随时解答生活与工作的多样问题

问小白,由元石科技研发的AI智能助手,快速准确地解答各种生活和工作问题,包括但不限于搜索、规划和社交互动,帮助用户在日常生活中提高效率,轻松管理个人事务。

热门AI助手AI对话AI工具聊天机器人
Transly

Transly

实时语音翻译/同声传译工具

Transly是一个多场景的AI大语言模型驱动的同声传译、专业翻译助手,它拥有超精准的音频识别翻译能力,几乎零延迟的使用体验和支持多国语言可以让你带它走遍全球,无论你是留学生、商务人士、韩剧美剧爱好者,还是出国游玩、多国会议、跨国追星等等,都可以满足你所有需要同传的场景需求,线上线下通用,扫除语言障碍,让全世界的语言交流不再有国界。

讯飞智文

讯飞智文

一键生成PPT和Word,让学习生活更轻松

讯飞智文是一个利用 AI 技术的项目,能够帮助用户生成 PPT 以及各类文档。无论是商业领域的市场分析报告、年度目标制定,还是学生群体的职业生涯规划、实习避坑指南,亦或是活动策划、旅游攻略等内容,它都能提供支持,帮助用户精准表达,轻松呈现各种信息。

AI办公办公工具AI工具讯飞智文AI在线生成PPTAI撰写助手多语种文档生成AI自动配图热门
讯飞星火

讯飞星火

深度推理能力全新升级,全面对标OpenAI o1

科大讯飞的星火大模型,支持语言理解、知识问答和文本创作等多功能,适用于多种文件和业务场景,提升办公和日常生活的效率。讯飞星火是一个提供丰富智能服务的平台,涵盖科技资讯、图像创作、写作辅助、编程解答、科研文献解读等功能,能为不同需求的用户提供便捷高效的帮助,助力用户轻松获取信息、解决问题,满足多样化使用场景。

热门AI开发模型训练AI工具讯飞星火大模型智能问答内容创作多语种支持智慧生活
Spark-TTS

Spark-TTS

一种基于大语言模型的高效单流解耦语音令牌文本到语音合成模型

Spark-TTS 是一个基于 PyTorch 的开源文本到语音合成项目,由多个知名机构联合参与。该项目提供了高效的 LLM(大语言模型)驱动的语音合成方案,支持语音克隆和语音创建功能,可通过命令行界面(CLI)和 Web UI 两种方式使用。用户可以根据需求调整语音的性别、音高、速度等参数,生成高质量的语音。该项目适用于多种场景,如有声读物制作、智能语音助手开发等。

咔片PPT

咔片PPT

AI助力,做PPT更简单!

咔片是一款轻量化在线演示设计工具,借助 AI 技术,实现从内容生成到智能设计的一站式 PPT 制作服务。支持多种文档格式导入生成 PPT,提供海量模板、智能美化、素材替换等功能,适用于销售、教师、学生等各类人群,能高效制作出高品质 PPT,满足不同场景演示需求。

讯飞绘文

讯飞绘文

选题、配图、成文,一站式创作,让内容运营更高效

讯飞绘文,一个AI集成平台,支持写作、选题、配图、排版和发布。高效生成适用于各类媒体的定制内容,加速品牌传播,提升内容营销效果。

热门AI辅助写作AI工具讯飞绘文内容运营AI创作个性化文章多平台分发AI助手
材料星

材料星

专业的AI公文写作平台,公文写作神器

AI 材料星,专业的 AI 公文写作辅助平台,为体制内工作人员提供高效的公文写作解决方案。拥有海量公文文库、9 大核心 AI 功能,支持 30 + 文稿类型生成,助力快速完成领导讲话、工作总结、述职报告等材料,提升办公效率,是体制打工人的得力写作神器。

openai-agents-python

openai-agents-python

OpenAI Agents SDK,助力开发者便捷使用 OpenAI 相关功能。

openai-agents-python 是 OpenAI 推出的一款强大 Python SDK,它为开发者提供了与 OpenAI 模型交互的高效工具,支持工具调用、结果处理、追踪等功能,涵盖多种应用场景,如研究助手、财务研究等,能显著提升开发效率,让开发者更轻松地利用 OpenAI 的技术优势。

下拉加载更多