gymnax

gymnax

JAX驱动的高效强化学习环境集合

gymnax是基于JAX构建的强化学习环境库,充分利用JAX的即时编译和向量化功能,显著提升了传统gym API的性能。该库涵盖经典控制、bsuite和MinAtar等多种环境,支持精确控制环境参数。通过在加速器上并行处理环境和策略,gymnax实现了高效的强化学习实验,尤其适合大规模并行和元强化学习研究。

gymnax强化学习JAX环境仿真加速计算Github开源项目
<h1 align="center"> <a href="https://yellow-cdn.veclightyear.com/835a84d5/2a34392f-ae3f-48b2-bd27-5970dbd15ff8.png"> <img src="https://yellow-cdn.veclightyear.com/835a84d5/2a34392f-ae3f-48b2-bd27-5970dbd15ff8.png?raw=true" width="215" /></a><br> <b>JAX中的强化学习环境 🌍</b><br> </h1> <p align="center"> <a href="https://pypi.python.org/pypi/gymnax"> <img src="https://yellow-cdn.veclightyear.com/835a84d5/955ebf74-3ad4-40af-9a6f-aaf547161ffc.svg?style=flat-square" /></a> <a href= "https://badge.fury.io/py/gymnax"> <img src="https://yellow-cdn.veclightyear.com/835a84d5/73158e47-46e2-4ddd-902f-44d5162942c9.svg" /></a> <a href= "https://github.com/RobertTLange/gymnax/blob/master/LICENSE.md"> <img src="https://yellow-cdn.veclightyear.com/835a84d5/53d70f62-418b-4b9b-bd7b-7357122defaa.svg" /></a> <a href= "https://codecov.io/gh/RobertTLange/gymnax"> <img src="https://yellow-cdn.veclightyear.com/835a84d5/c27a074d-3eb9-47e2-aff8-ce050c30f8af.svg?token=OKKPDRIQJR" /></a> <a href= "https://github.com/psf/black"> <img src="https://yellow-cdn.veclightyear.com/835a84d5/2bf6dbd6-e45d-46c6-8332-3a3ae2996bf6.svg" /></a> </p>

你是否厌倦了基于CPU的缓慢RL环境进程?你是否想利用大规模矢量化来进行高吞吐量的RL实验?gymnaxjitvmap/pmap的强大功能引入经典的gym API。它支持多种不同的环境,包括经典控制bsuiteMinAtar以及一系列经典/元强化学习任务。gymnax允许显式控制环境设置(随机种子或超参数),从而实现不同配置的加速和并行化rollouts(例如用于元强化学习)。通过在加速器上同时执行环境和策略,它促进了Podracer论文(Hessel et al., 2021)中提出的Anakin子架构和高度分布式的进化优化(例如使用evosax)。我们在gymnax-blines中提供了PPO和ES的训练与检查点。从这里开始 👉 Colab

基本gymnax API使用 🍲

import jax import gymnax rng = jax.random.PRNGKey(0) rng, key_reset, key_act, key_step = jax.random.split(rng, 4) # 实例化环境及其设置 env, env_params = gymnax.make("Pendulum-v1") # 重置环境 obs, state = env.reset(key_reset, env_params) # 采样随机动作 action = env.action_space(env_params).sample(key_act) # 执行步骤转换 n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params)

已实现的加速环境 🏎️

环境名称参考文献源代码🤖 检查点(回报)每百万步秒数 🦶 <br /> A100 (2k 🌎)
Acrobot-v1Brockman等人 (2016)点击PPO, ES (回报: -80)0.07
Pendulum-v1Brockman等人 (2016)点击PPO, ES (回报: -130)0.07
CartPole-v1Brockman等人 (2016)点击PPO, ES (回报: 500)0.05
MountainCar-v0Brockman等人 (2016)点击PPO, ES (回报: -118)0.07
MountainCarContinuous-v0Brockman等人 (2016)点击PPO, ES (回报: 92)0.09
Asterix-MinAtarYoung和Tian (2019)点击PPO (回报: 15)0.92
Breakout-MinAtarYoung和Tian (2019)点击PPO (回报: 28)0.19
Freeway-MinAtarYoung和Tian (2019)点击PPO (回报: 58)0.87
SpaceInvaders-MinAtarYoung和Tian (2019)点击PPO (回报: 131)0.33
Catch-bsuiteOsband等人 (2019)点击PPO, ES (回报: 1)0.15
DeepSea-bsuiteOsband等人 (2019)点击PPO, ES (回报: 0)0.22
MemoryChain-bsuiteOsband等人 (2019)点击PPO, ES (回报: 0.1)0.13
UmbrellaChain-bsuiteOsband等人 (2019)点击PPO, ES (回报: 1)0.08
DiscountingChain-bsuiteOsband等人 (2019)点击PPO, ES (回报: 1.1)0.06
MNISTBandit-bsuiteOsband等人 (2019)点击--
SimpleBandit-bsuiteOsband等人 (2019)点击--
FourRooms-miscSutton等人 (1999)点击PPO, ES (回报: 1)0.07
MetaMaze-miscMicconi等人 (2020)点击ES (回报: 32)0.09
PointRobot-miscDorfman等人 (2021)点击ES (R: 10)0.08
BernoulliBandit-miscWang等人 (2017)点击ES (R: 90)0.08
GaussianBandit-miscLange & Sprekeler (2022)点击ES (R: 0)0.07
Reacher-miscLenton等人 (2021)点击
Swimmer-miscLenton等人 (2021)点击
Pong-miscKirsch (2018)点击
  • 所有显示的速度均为在NVIDIA A100 GPU上使用jit编译的情节展开和2000个环境工作线程进行100万步转换(随机策略)的估计值。有关不同加速器(CPU、RTX 2080Ti)和MLP策略的更详细速度比较,请参阅gymnax-blines文档。

安装 ⏳

可以直接从PyPI安装最新的gymnax版本:

pip install gymnax

如果你想获取最新的提交,请直接从仓库安装:

pip install git+https://github.com/RobertTLange/gymnax.git@main

要在加速器上使用JAX,可以在JAX文档中找到更多详细信息。

示例 📖

主要卖点 💵

  • 环境向量化和加速: 轻松组合JAX原语(如jitvmappmap):

    # Jit加速的步骤转换 jit_step = jax.jit(env.step) # 映射(vmap/pmap)随机键以进行批量展开 reset_rng = jax.vmap(env.reset, in_axes=(0, None)) step_rng = jax.vmap(env.step, in_axes=(0, 0, 0, None)) # 映射(vmap/pmap)环境参数(例如用于元学习) reset_params = jax.vmap(env.reset, in_axes=(None, 0)) step_params = jax.vmap(env.step, in_axes=(None, 0, 0, 0))

    有关与标准向量化NumPy环境的速度比较,请查看gymnax-blines

  • 扫描整个情节展开: 你还可以使用lax.scan扫描整个resetstep情节循环以实现快速编译:

    def rollout(rng_input, policy_params, env_params, steps_in_episode): """使用lax.scan展开一个jitted gymnax情节。""" # 重置环境 rng_reset, rng_episode = jax.random.split(rng_input) obs, state = env.reset(rng_reset, env_params) def policy_step(state_input, tmp): """与lax.scan兼容的jax环境中的步骤转换。""" obs, state, policy_params, rng = state_input rng, rng_step, rng_net = jax.random.split(rng, 3) action = model.apply(policy_params, obs) next_obs, next_state, reward, done, _ = env.step( rng_step, state, action, env_params ) carry = [next_obs, next_state, policy_params, rng] return carry, [obs, action, reward, next_obs, done] # 扫描情节步骤循环 _, scan_out = jax.lax.scan( policy_step, [obs, state, policy_params, rng_episode], (), steps_in_episode ) # 返回代理在情节中累积的奖励的掩码和 obs, action, reward, next_obs, done = scan_out return obs, action, reward, next_obs, done
  • 内置可视化工具: 你还可以使用Visualizer工具顺利生成GIF动画,该工具涵盖了所有classic_controlMinAtar和大多数misc环境:

    from gymnax.visualize import Visualizer state_seq, reward_seq = [], [] rng, rng_reset = jax.random.split(rng) obs, env_state = env.reset(rng_reset, env_params) while True: state_seq.append(env_state) rng, rng_act, rng_step = jax.random.split(rng, 3) action = env.action_space(env_params).sample(rng_act) next_obs, next_env_state, reward, done, info = env.step( rng_step, env_state, action, env_params ) reward_seq.append(reward) if done: break else: obs = next_obs env_state = next_env_state cum_rewards = jnp.cumsum(jnp.array(reward_seq)) vis = Visualizer(env, env_params, state_seq, cum_rewards) vis.animate(f"docs/anim.gif")
  • 训练流程和预训练代理: 查看gymnax-blines获取训练好的代理、专家展开可视化和PPO/ES流程。这些代理经过最小程度的调整,但可以帮助你快速上手。

  • 简单的批量代理评估: 正在进行中

    from gymnax.experimental import RolloutWrapper # 为摆环境定义展开管理器 manager = RolloutWrapper(model.apply, env_name="Pendulum-v1") # 策略的简单单情节展开 obs, action, reward, next_obs, done, cum_ret = manager.single_rollout(rng, policy_params) # 同一网络的多次展开(不同的rng,例如评估) rng_batch = jax.random.split(rng, 10) obs, action, reward, next_obs, done, cum_ret = manager.batch_rollout( rng_batch, policy_params ) # 不同网络和rng的多次展开(例如用于ES) batch_params = jax.tree_map( # 堆叠参数或使用不同的参数 lambda x: jnp.tile(x, (5, 1)).reshape(5, *x.shape), policy_params ) obs, action, reward, next_obs, done, cum_ret = manager.population_rollout( rng_batch, batch_params )

资源和其他优秀工具 📝

  • 💻 Brax: Google Brain基于JAX的刚体物理库,提供JAX风格的MuJoCo替代品。
  • 💻 envpool: 向量化并行环境执行引擎。
  • 💻 Jumanji: 一套多样化且具有挑战性的JAX强化学习环境。
  • 💻 Pgx: 基于JAX的经典棋盘游戏环境。

致谢和引用gymnax ✏️

如果你在研究中使用gymnax,请按以下方式引用:

@software{gymnax2022github,
  author = {Robert Tjarko Lange},
  title = {{gymnax}: A {JAX}-based Reinforcement Learning Environment Library},
  url = {http://github.com/RobertTLange/gymnax},
  version = {0.0.4},
  year = {2022},
}

我们感谢Google TRC和德国研究基金会(DFG,德国研究基金会)在德国卓越战略 - EXC 2002/1 "科学智能" - 项目编号390523135下的财政支持。

开发 👷

你可以通过python -m pytest -vv --all运行测试套件。如果你发现了bug或缺少你最喜欢的功能,欢迎创建问题并/或开始贡献 🤗。

编辑推荐精选

Vora

Vora

免费创建高清无水印Sora视频

Vora是一个免费创建高清无水印Sora视频的AI工具

Refly.AI

Refly.AI

最适合小白的AI自动化工作流平台

无需编码,轻松生成可复用、可变现的AI自动化工作流

酷表ChatExcel

酷表ChatExcel

大模型驱动的Excel数据处理工具

基于大模型交互的表格处理系统,允许用户通过对话方式完成数据整理和可视化分析。系统采用机器学习算法解析用户指令,自动执行排序、公式计算和数据透视等操作,支持多种文件格式导入导出。数据处理响应速度保持在0.8秒以内,支持超过100万行数据的即时分析。

AI工具酷表ChatExcelAI智能客服AI营销产品使用教程
TRAE编程

TRAE编程

AI辅助编程,代码自动修复

Trae是一种自适应的集成开发环境(IDE),通过自动化和多元协作改变开发流程。利用Trae,团队能够更快速、精确地编写和部署代码,从而提高编程效率和项目交付速度。Trae具备上下文感知和代码自动完成功能,是提升开发效率的理想工具。

AI工具TraeAI IDE协作生产力转型热门
AIWritePaper论文写作

AIWritePaper论文写作

AI论文写作指导平台

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

AI辅助写作AI工具AI论文工具论文写作智能生成大纲数据安全AI助手热门
博思AIPPT

博思AIPPT

AI一键生成PPT,就用博思AIPPT!

博思AIPPT,新一代的AI生成PPT平台,支持智能生成PPT、AI美化PPT、文本&链接生成PPT、导入Word/PDF/Markdown文档生成PPT等,内置海量精美PPT模板,涵盖商务、教育、科技等不同风格,同时针对每个页面提供多种版式,一键自适应切换,完美适配各种办公场景。

AI办公办公工具AI工具博思AIPPTAI生成PPT智能排版海量精品模板AI创作热门
潮际好麦

潮际好麦

AI赋能电商视觉革命,一站式智能商拍平台

潮际好麦深耕服装行业,是国内AI试衣效果最好的软件。使用先进AIGC能力为电商卖家批量提供优质的、低成本的商拍图。合作品牌有Shein、Lazada、安踏、百丽等65个国内外头部品牌,以及国内10万+淘宝、天猫、京东等主流平台的品牌商家,为卖家节省将近85%的出图成本,提升约3倍出图效率,让品牌能够快速上架。

iTerms

iTerms

企业专属的AI法律顾问

iTerms是法大大集团旗下法律子品牌,基于最先进的大语言模型(LLM)、专业的法律知识库和强大的智能体架构,帮助企业扫清合规障碍,筑牢风控防线,成为您企业专属的AI法律顾问。

SimilarWeb流量提升

SimilarWeb流量提升

稳定高效的流量提升解决方案,助力品牌曝光

稳定高效的流量提升解决方案,助力品牌曝光

Sora2视频免费生成

Sora2视频免费生成

最新版Sora2模型免费使用,一键生成无水印视频

最新版Sora2模型免费使用,一键生成无水印视频

下拉加载更多