gymnax

gymnax

JAX驱动的高效强化学习环境集合

gymnax是基于JAX构建的强化学习环境库,充分利用JAX的即时编译和向量化功能,显著提升了传统gym API的性能。该库涵盖经典控制、bsuite和MinAtar等多种环境,支持精确控制环境参数。通过在加速器上并行处理环境和策略,gymnax实现了高效的强化学习实验,尤其适合大规模并行和元强化学习研究。

gymnax强化学习JAX环境仿真加速计算Github开源项目
<h1 align="center"> <a href="https://yellow-cdn.veclightyear.com/835a84d5/2a34392f-ae3f-48b2-bd27-5970dbd15ff8.png"> <img src="https://yellow-cdn.veclightyear.com/835a84d5/2a34392f-ae3f-48b2-bd27-5970dbd15ff8.png?raw=true" width="215" /></a><br> <b>JAX中的强化学习环境 🌍</b><br> </h1> <p align="center"> <a href="https://pypi.python.org/pypi/gymnax"> <img src="https://yellow-cdn.veclightyear.com/835a84d5/955ebf74-3ad4-40af-9a6f-aaf547161ffc.svg?style=flat-square" /></a> <a href= "https://badge.fury.io/py/gymnax"> <img src="https://yellow-cdn.veclightyear.com/835a84d5/73158e47-46e2-4ddd-902f-44d5162942c9.svg" /></a> <a href= "https://github.com/RobertTLange/gymnax/blob/master/LICENSE.md"> <img src="https://yellow-cdn.veclightyear.com/835a84d5/53d70f62-418b-4b9b-bd7b-7357122defaa.svg" /></a> <a href= "https://codecov.io/gh/RobertTLange/gymnax"> <img src="https://yellow-cdn.veclightyear.com/835a84d5/c27a074d-3eb9-47e2-aff8-ce050c30f8af.svg?token=OKKPDRIQJR" /></a> <a href= "https://github.com/psf/black"> <img src="https://yellow-cdn.veclightyear.com/835a84d5/2bf6dbd6-e45d-46c6-8332-3a3ae2996bf6.svg" /></a> </p>

你是否厌倦了基于CPU的缓慢RL环境进程?你是否想利用大规模矢量化来进行高吞吐量的RL实验?gymnaxjitvmap/pmap的强大功能引入经典的gym API。它支持多种不同的环境,包括经典控制bsuiteMinAtar以及一系列经典/元强化学习任务。gymnax允许显式控制环境设置(随机种子或超参数),从而实现不同配置的加速和并行化rollouts(例如用于元强化学习)。通过在加速器上同时执行环境和策略,它促进了Podracer论文(Hessel et al., 2021)中提出的Anakin子架构和高度分布式的进化优化(例如使用evosax)。我们在gymnax-blines中提供了PPO和ES的训练与检查点。从这里开始 👉 Colab

基本gymnax API使用 🍲

import jax import gymnax rng = jax.random.PRNGKey(0) rng, key_reset, key_act, key_step = jax.random.split(rng, 4) # 实例化环境及其设置 env, env_params = gymnax.make("Pendulum-v1") # 重置环境 obs, state = env.reset(key_reset, env_params) # 采样随机动作 action = env.action_space(env_params).sample(key_act) # 执行步骤转换 n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params)

已实现的加速环境 🏎️

环境名称参考文献源代码🤖 检查点(回报)每百万步秒数 🦶 <br /> A100 (2k 🌎)
Acrobot-v1Brockman等人 (2016)点击PPO, ES (回报: -80)0.07
Pendulum-v1Brockman等人 (2016)点击PPO, ES (回报: -130)0.07
CartPole-v1Brockman等人 (2016)点击PPO, ES (回报: 500)0.05
MountainCar-v0Brockman等人 (2016)点击PPO, ES (回报: -118)0.07
MountainCarContinuous-v0Brockman等人 (2016)点击PPO, ES (回报: 92)0.09
Asterix-MinAtarYoung和Tian (2019)点击PPO (回报: 15)0.92
Breakout-MinAtarYoung和Tian (2019)点击PPO (回报: 28)0.19
Freeway-MinAtarYoung和Tian (2019)点击PPO (回报: 58)0.87
SpaceInvaders-MinAtarYoung和Tian (2019)点击PPO (回报: 131)0.33
Catch-bsuiteOsband等人 (2019)点击PPO, ES (回报: 1)0.15
DeepSea-bsuiteOsband等人 (2019)点击PPO, ES (回报: 0)0.22
MemoryChain-bsuiteOsband等人 (2019)点击PPO, ES (回报: 0.1)0.13
UmbrellaChain-bsuiteOsband等人 (2019)点击PPO, ES (回报: 1)0.08
DiscountingChain-bsuiteOsband等人 (2019)点击PPO, ES (回报: 1.1)0.06
MNISTBandit-bsuiteOsband等人 (2019)点击--
SimpleBandit-bsuiteOsband等人 (2019)点击--
FourRooms-miscSutton等人 (1999)点击PPO, ES (回报: 1)0.07
MetaMaze-miscMicconi等人 (2020)点击ES (回报: 32)0.09
PointRobot-miscDorfman等人 (2021)点击ES (R: 10)0.08
BernoulliBandit-miscWang等人 (2017)点击ES (R: 90)0.08
GaussianBandit-miscLange & Sprekeler (2022)点击ES (R: 0)0.07
Reacher-miscLenton等人 (2021)点击
Swimmer-miscLenton等人 (2021)点击
Pong-miscKirsch (2018)点击
  • 所有显示的速度均为在NVIDIA A100 GPU上使用jit编译的情节展开和2000个环境工作线程进行100万步转换(随机策略)的估计值。有关不同加速器(CPU、RTX 2080Ti)和MLP策略的更详细速度比较,请参阅gymnax-blines文档。

安装 ⏳

可以直接从PyPI安装最新的gymnax版本:

pip install gymnax

如果你想获取最新的提交,请直接从仓库安装:

pip install git+https://github.com/RobertTLange/gymnax.git@main

要在加速器上使用JAX,可以在JAX文档中找到更多详细信息。

示例 📖

主要卖点 💵

  • 环境向量化和加速: 轻松组合JAX原语(如jitvmappmap):

    # Jit加速的步骤转换 jit_step = jax.jit(env.step) # 映射(vmap/pmap)随机键以进行批量展开 reset_rng = jax.vmap(env.reset, in_axes=(0, None)) step_rng = jax.vmap(env.step, in_axes=(0, 0, 0, None)) # 映射(vmap/pmap)环境参数(例如用于元学习) reset_params = jax.vmap(env.reset, in_axes=(None, 0)) step_params = jax.vmap(env.step, in_axes=(None, 0, 0, 0))

    有关与标准向量化NumPy环境的速度比较,请查看gymnax-blines

  • 扫描整个情节展开: 你还可以使用lax.scan扫描整个resetstep情节循环以实现快速编译:

    def rollout(rng_input, policy_params, env_params, steps_in_episode): """使用lax.scan展开一个jitted gymnax情节。""" # 重置环境 rng_reset, rng_episode = jax.random.split(rng_input) obs, state = env.reset(rng_reset, env_params) def policy_step(state_input, tmp): """与lax.scan兼容的jax环境中的步骤转换。""" obs, state, policy_params, rng = state_input rng, rng_step, rng_net = jax.random.split(rng, 3) action = model.apply(policy_params, obs) next_obs, next_state, reward, done, _ = env.step( rng_step, state, action, env_params ) carry = [next_obs, next_state, policy_params, rng] return carry, [obs, action, reward, next_obs, done] # 扫描情节步骤循环 _, scan_out = jax.lax.scan( policy_step, [obs, state, policy_params, rng_episode], (), steps_in_episode ) # 返回代理在情节中累积的奖励的掩码和 obs, action, reward, next_obs, done = scan_out return obs, action, reward, next_obs, done
  • 内置可视化工具: 你还可以使用Visualizer工具顺利生成GIF动画,该工具涵盖了所有classic_controlMinAtar和大多数misc环境:

    from gymnax.visualize import Visualizer state_seq, reward_seq = [], [] rng, rng_reset = jax.random.split(rng) obs, env_state = env.reset(rng_reset, env_params) while True: state_seq.append(env_state) rng, rng_act, rng_step = jax.random.split(rng, 3) action = env.action_space(env_params).sample(rng_act) next_obs, next_env_state, reward, done, info = env.step( rng_step, env_state, action, env_params ) reward_seq.append(reward) if done: break else: obs = next_obs env_state = next_env_state cum_rewards = jnp.cumsum(jnp.array(reward_seq)) vis = Visualizer(env, env_params, state_seq, cum_rewards) vis.animate(f"docs/anim.gif")
  • 训练流程和预训练代理: 查看gymnax-blines获取训练好的代理、专家展开可视化和PPO/ES流程。这些代理经过最小程度的调整,但可以帮助你快速上手。

  • 简单的批量代理评估: 正在进行中

    from gymnax.experimental import RolloutWrapper # 为摆环境定义展开管理器 manager = RolloutWrapper(model.apply, env_name="Pendulum-v1") # 策略的简单单情节展开 obs, action, reward, next_obs, done, cum_ret = manager.single_rollout(rng, policy_params) # 同一网络的多次展开(不同的rng,例如评估) rng_batch = jax.random.split(rng, 10) obs, action, reward, next_obs, done, cum_ret = manager.batch_rollout( rng_batch, policy_params ) # 不同网络和rng的多次展开(例如用于ES) batch_params = jax.tree_map( # 堆叠参数或使用不同的参数 lambda x: jnp.tile(x, (5, 1)).reshape(5, *x.shape), policy_params ) obs, action, reward, next_obs, done, cum_ret = manager.population_rollout( rng_batch, batch_params )

资源和其他优秀工具 📝

  • 💻 Brax: Google Brain基于JAX的刚体物理库,提供JAX风格的MuJoCo替代品。
  • 💻 envpool: 向量化并行环境执行引擎。
  • 💻 Jumanji: 一套多样化且具有挑战性的JAX强化学习环境。
  • 💻 Pgx: 基于JAX的经典棋盘游戏环境。

致谢和引用gymnax ✏️

如果你在研究中使用gymnax,请按以下方式引用:

@software{gymnax2022github,
  author = {Robert Tjarko Lange},
  title = {{gymnax}: A {JAX}-based Reinforcement Learning Environment Library},
  url = {http://github.com/RobertTLange/gymnax},
  version = {0.0.4},
  year = {2022},
}

我们感谢Google TRC和德国研究基金会(DFG,德国研究基金会)在德国卓越战略 - EXC 2002/1 "科学智能" - 项目编号390523135下的财政支持。

开发 👷

你可以通过python -m pytest -vv --all运行测试套件。如果你发现了bug或缺少你最喜欢的功能,欢迎创建问题并/或开始贡献 🤗。

编辑推荐精选

讯飞智文

讯飞智文

一键生成PPT和Word,让学习生活更轻松

讯飞智文是一个利用 AI 技术的项目,能够帮助用户生成 PPT 以及各类文档。无论是商业领域的市场分析报告、年度目标制定,还是学生群体的职业生涯规划、实习避坑指南,亦或是活动策划、旅游攻略等内容,它都能提供支持,帮助用户精准表达,轻松呈现各种信息。

AI办公办公工具AI工具讯飞智文AI在线生成PPTAI撰写助手多语种文档生成AI自动配图热门
讯飞星火

讯飞星火

深度推理能力全新升级,全面对标OpenAI o1

科大讯飞的星火大模型,支持语言理解、知识问答和文本创作等多功能,适用于多种文件和业务场景,提升办公和日常生活的效率。讯飞星火是一个提供丰富智能服务的平台,涵盖科技资讯、图像创作、写作辅助、编程解答、科研文献解读等功能,能为不同需求的用户提供便捷高效的帮助,助力用户轻松获取信息、解决问题,满足多样化使用场景。

热门AI开发模型训练AI工具讯飞星火大模型智能问答内容创作多语种支持智慧生活
Spark-TTS

Spark-TTS

一种基于大语言模型的高效单流解耦语音令牌文本到语音合成模型

Spark-TTS 是一个基于 PyTorch 的开源文本到语音合成项目,由多个知名机构联合参与。该项目提供了高效的 LLM(大语言模型)驱动的语音合成方案,支持语音克隆和语音创建功能,可通过命令行界面(CLI)和 Web UI 两种方式使用。用户可以根据需求调整语音的性别、音高、速度等参数,生成高质量的语音。该项目适用于多种场景,如有声读物制作、智能语音助手开发等。

Trae

Trae

字节跳动发布的AI编程神器IDE

Trae是一种自适应的集成开发环境(IDE),通过自动化和多元协作改变开发流程。利用Trae,团队能够更快速、精确地编写和部署代码,从而提高编程效率和项目交付速度。Trae具备上下文感知和代码自动完成功能,是提升开发效率的理想工具。

AI工具TraeAI IDE协作生产力转型热门
咔片PPT

咔片PPT

AI助力,做PPT更简单!

咔片是一款轻量化在线演示设计工具,借助 AI 技术,实现从内容生成到智能设计的一站式 PPT 制作服务。支持多种文档格式导入生成 PPT,提供海量模板、智能美化、素材替换等功能,适用于销售、教师、学生等各类人群,能高效制作出高品质 PPT,满足不同场景演示需求。

讯飞绘文

讯飞绘文

选题、配图、成文,一站式创作,让内容运营更高效

讯飞绘文,一个AI集成平台,支持写作、选题、配图、排版和发布。高效生成适用于各类媒体的定制内容,加速品牌传播,提升内容营销效果。

热门AI辅助写作AI工具讯飞绘文内容运营AI创作个性化文章多平台分发AI助手
材料星

材料星

专业的AI公文写作平台,公文写作神器

AI 材料星,专业的 AI 公文写作辅助平台,为体制内工作人员提供高效的公文写作解决方案。拥有海量公文文库、9 大核心 AI 功能,支持 30 + 文稿类型生成,助力快速完成领导讲话、工作总结、述职报告等材料,提升办公效率,是体制打工人的得力写作神器。

openai-agents-python

openai-agents-python

OpenAI Agents SDK,助力开发者便捷使用 OpenAI 相关功能。

openai-agents-python 是 OpenAI 推出的一款强大 Python SDK,它为开发者提供了与 OpenAI 模型交互的高效工具,支持工具调用、结果处理、追踪等功能,涵盖多种应用场景,如研究助手、财务研究等,能显著提升开发效率,让开发者更轻松地利用 OpenAI 的技术优势。

Hunyuan3D-2

Hunyuan3D-2

高分辨率纹理 3D 资产生成

Hunyuan3D-2 是腾讯开发的用于 3D 资产生成的强大工具,支持从文本描述、单张图片或多视角图片生成 3D 模型,具备快速形状生成能力,可生成带纹理的高质量 3D 模型,适用于多个领域,为 3D 创作提供了高效解决方案。

3FS

3FS

一个具备存储、管理和客户端操作等多种功能的分布式文件系统相关项目。

3FS 是一个功能强大的分布式文件系统项目,涵盖了存储引擎、元数据管理、客户端工具等多个模块。它支持多种文件操作,如创建文件和目录、设置布局等,同时具备高效的事件循环、节点选择和协程池管理等特性。适用于需要大规模数据存储和管理的场景,能够提高系统的性能和可靠性,是分布式存储领域的优质解决方案。

下拉加载更多