flashbax

flashbax

JAX强化学习高效体验回放缓冲库

Flashbax是一个为JAX设计的高效体验回放缓冲库,适用于强化学习算法。它提供平坦缓冲、轨迹缓冲及其优先级变体等多种缓冲类型,特点是高效内存使用、易于集成到编译函数中,并支持优先级采样。Flashbax还具有Vault功能,可将大型缓冲区保存到磁盘。这个简单灵活的框架适用于学术研究、工业应用和个人项目中的体验回放处理。

Flashbax经验回放缓冲区强化学习JAX深度学习Github开源项目
<p align="center"> <a href="./docs/img/logo.png#gh-light-mode-only"> <img src="https://raw.githubusercontent.com/instadeepai/flashbax/main/docs/imgs/logo.png#gh-light-mode-only" alt="Flashbax标志" width="70%"/> </a> <a href="./docs/img/logo_dm.png#gh-dark-mode-only"> <img src="https://raw.githubusercontent.com/instadeepai/flashbax/main/docs/imgs/logo_dm.png#gh-dark-mode-only" alt="Flashbax标志" width="70%"/> </a> </p>

Python版本 PyPI版本 测试 代码风格 MyPy 许可证

<div align="center"> <h3> <a href="#overview-">概述</a> | <a href="#features-%EF%B8%8F">特性</a> | <a href="#setup-">设置</a> | <a href="#quickstart-">快速开始</a> | <a href="#examples-">示例</a> | <a href="#important-considerations-%EF%B8%8F">重要注意事项</a> | <a href="#benchmarks-">基准测试</a> | <a href="#contributing-">贡献</a> | <a href="#see-also-">另请参阅</a> | <a href="#citing">引用</a> | <a href="#acknowledgements-">致谢</a> </h3> </div>

⚡ Jax中的高速缓冲区 ⚡

概述 🔍

Flashbax是一个旨在简化强化学习(RL)中经验回放缓冲区使用的库。它专为与JAX范式兼容而设计,允许这些缓冲区在完全编译的函数和训练循环中轻松使用。

Flashbax提供了各种不同类型缓冲区的实现,如平面缓冲区、轨迹缓冲区以及两者的优先级变体。无论是用于学术研究、工业应用还是个人项目,Flashbax都为RL经验回放处理提供了简单灵活的框架。

特性 🛠️

🚀 高效缓冲区变体:所有Flashbax缓冲区都是作为轨迹缓冲区的专门变体构建的,在各种类型的缓冲区中优化内存使用和功能。

🗄️ 平面缓冲区:平面缓冲区类似于DQN等算法中使用的转换缓冲区,是一个核心组件。它使用序列长度为2(即$s_t$, $s_{t+1}$),周期为1,以全面考虑转换对。

🧺 项目缓冲区:项目缓冲区是一个存储单个项目的简单缓冲区。它适用于存储相互独立的数据,如(观察、动作、奖励、折扣、下一个观察)元组或整个回合。

🛤️ 轨迹缓冲区:轨迹缓冲区便于采样多步轨迹,适用于使用循环网络的算法,如R2D2(Kapturowski等人,2018)。

🏅 优先级缓冲区:平面缓冲区和轨迹缓冲区都可以设置优先级,实现基于用户定义优先级的采样。优先机制与PER论文(Schaul等人,2016)中概述的原则一致。

🚶 轨迹/平面队列:提供了一种队列数据结构,用于按先进先出(FIFO)顺序采样数据。该队列可用于特定用例的在线策略算法。

设置 🎬

要将Flashbax集成到您的项目中,请按以下步骤操作:

  1. 安装:首先使用pip安装Flashbax:
pip install flashbax
  1. 选择缓冲区:从平面缓冲区、轨迹缓冲区和优先级变体等多种缓冲区选项中选择。
import flashbax as fbx buffer = fbx.make_trajectory_buffer(...) # 或 buffer = fbx.make_prioritised_trajectory_buffer(...) # 或 buffer = fbx.make_flat_buffer(...) # 或 buffer = fbx.make_prioritised_flat_buffer(...) # 或 buffer = fbx.make_item_buffer(...) # 或 buffer = fbx.make_trajectory_queue(...) # 初始化 state = buffer.init(example_timestep) # 添加数据 state = buffer.add(state, example_data) # 采样数据 batch = buffer.sample(state, rng_key)

快速开始 🏁

以下我们提供了使用平面缓冲区的最小代码示例。在此示例中,我们展示了如何使用定义平面缓冲区的每个纯函数。请注意,这些纯函数都与jax.pmapjax.jit兼容,但为简单起见,以下示例中未使用这些函数。

import jax import jax.numpy as jnp import flashbax as fbx # 使用简单配置通过`make_flat_buffer`实例化平面缓冲区NamedTuple。 # 返回的`buffer`只是使用平面缓冲区所需的纯函数的容器。 buffer = fbx.make_flat_buffer(max_length=32, min_length=2, sample_batch_size=1) # 初始化缓冲区的状态。 fake_timestep = {"obs": jnp.array([0, 0]), "reward": jnp.array(0.0)} state = buffer.init(fake_timestep) # 现在我们向缓冲区添加数据。 state = buffer.add(state, {"obs": jnp.array([1, 2]), "reward": jnp.array(3.0)}) print(buffer.can_sample(state)) # False,因为尚未达到min_length。 state = buffer.add(state, {"obs": jnp.array([4, 5]), "reward": jnp.array(6.0)}) print(buffer.can_sample(state)) # 仍为False,因为我们需要2个*转换*(即3个时间步)。 state = buffer.add(state, {"obs": jnp.array([7, 8]), "reward": jnp.array(9.0)}) print(buffer.can_sample(state)) # True!我们有2个转换(3个时间步)。 # 从缓冲区获取一个转换。 rng_key = jax.random.PRNGKey(0) # 随机源。 batch = buffer.sample(state, rng_key) # 采样 # 我们有一个转换!打印:obs = [[4 5]], obs' = [[7 8]] print( f"obs = {batch.experience.first['obs']}, obs' = {batch.experience.second['obs']}" )

示例 🧑‍💻

我们提供以下Colab示例,作为如何使用每个flashbax缓冲区的更高级教程以及使用示例:

Colab 笔记本描述
Colab平面缓冲区快速入门
Colab轨迹缓冲区快速入门
Colab优先级平面缓冲区快速入门
Colab使用 Matrax 的项目缓冲区示例
ColabAnakin DQN
ColabAnakin 优先级 DQN
ColabAnakin PPO
Colab使用向量化 Gym 环境的 DQN
  • 👾 Anakin - 基于 JAX 的架构,用于端到端地即时编译强化学习代理的训练。
  • 🎮 DQN - 实现改编自 CleanRL 的 DQN JAX 示例。
  • 🦎 Jumanji - 利用 Jumanji 基于 JAX 的环境(如贪吃蛇)进行完全即时编译的示例。
  • ✖️ Matrax - JAX 中的双人矩阵游戏。

保险库 💾

保险库是一种将 Flashbax 缓冲区保存到持久数据存储的高效机制,例如用于离线强化学习。考虑一个维度为 $(B, T, *E)$ 的 Flashbax 缓冲区,其中 $B$ 是批次维度(用于同步记录独立轨迹),$T$ 是时间/序列维度,$*E$ 表示经验数据本身的一个或多个维度。由于特定环境可能会生成大量数据,保险库通过沿时间轴读写缓冲区切片来将 $T$ 维度扩展到几乎不受限制的程度。这样,巨大的缓冲区存储可以驻留在磁盘上,从中可以将子缓冲区加载到 RAM/VRAM 中进行高效的离线训练。保险库已经在项目、平面和轨迹缓冲区上进行了测试。

更多信息,请参阅演示笔记本:Colab

重要考虑事项 ⚠️

在使用 Flashbax 缓冲区时,需要注意某些考虑事项以确保强化学习代理的正常功能。

顺序数据添加

Flashbax 使用轨迹缓冲区作为所有缓冲区类型的基础。这意味着数据必须按顺序添加。具体而言,对于平面缓冲区,每个添加的时间步必须紧跟其连续的时间步。在大多数情况下,这个要求自然得到满足,不需要过多考虑。然而,当添加完全独立的数据批次时,必须注意这个限制。未能维持时间步之间的序列关系可能导致算法问题。用户需要处理从最后一个时间步到第一个时间步的情况。这发生在同一批次中从第 n 个情节到第 n+1 个情节时。例如,我们使用自动重置包装器在终止时间步时自动重置环境。此外,我们使用折扣值(非终止状态为 1,终止状态为 0)来相应地掩蔽价值函数和奖励折扣。

有效缓冲区大小

添加数据批次时,缓冲区以块状结构创建。这意味着有效缓冲区大小取决于批次维度的大小。轨迹缓冲区允许用户指定添加批次维度和时间轴的最大长度。这将创建一个 (批次, 时间) 的块状结构,允许存储的最大时间步数为 批次*时间。为了便于使用,我们提供了 max_size 参数,允许用户设置所需的总时间步数,我们根据提供的添加批次维度计算时间轴的最大长度。因此,重要的是要注意,使用 max_size 参数时,时间轴的最大长度将等于 max_size // 添加批次大小,这将向下取整,从而减少有效缓冲区大小。这意味着人们可能认为他们增加了一定量的缓冲区大小,但实际上并没有增加。因此,为避免这种情况,我们建议采取以下两种方法之一:明确使用最大时间轴长度参数,或者以添加批次大小的倍数增加 max_size 参数。

处理情节截断

另一个关键方面是情节截断。当截断情节并将数据添加到缓冲区时,必须确保适当设置完成标志或"折扣"值。忽视这一点可能会给算法的实现和行为带来挑战。如前所述,预期算法会适当处理这些情况。使用平面缓冲区或轨迹缓冲区处理截断可能很困难,因为算法必须处理一个情节的最后时间步后面跟着下一个情节的第一个时间步的情况。为了牺牲内存效率来换取易用性,可以使用项目缓冲区来独立存储转换或整个轨迹。这意味着算法不需要处理一个情节的最后时间步后面跟着下一个情节的第一个时间步的情况,因为只有明确插入的数据才能被采样。

独立数据使用

对于打算使用缺乏顺序信息的数据的缓冲区的情况,你可以利用项目缓冲区,它是一个具有特定配置的包装轨迹缓冲区。通过将序列维度设置为 1 并将周期设置为 1,每个项目将被视为独立的。然而,当处理独立的转换项目(如观察、动作、奖励、折扣、下一个观察)时,请注意这种方法将导致缓冲区中的观察重复,从而导致不必要的内存消耗。值得注意的是,平面缓冲区的实现速度会比以这种方式使用项目缓冲区慢,这是由于硬件加速器上数据索引的固有速度问题;然而,这种权衡是为了提高内存效率。如果速度远比内存效率更重要,那么使用序列为 1 和周期为 1 的轨迹缓冲区存储完整的转换数据项。

缓冲区状态的原地更新

由于缓冲区通常很大并占用设备内存的大部分,因此执行原地更新是有益的。为此,重要的是要向顶级编译函数指定你希望执行这种原地更新操作。这表示如下:

def train(train_state, buffer_state): ... return train_state, buffer_state # 初始化缓冲区状态 buffer_fn = fbx.make_trajectory_buffer(...) buffer_state = buffer_fn.init(example_timestep) # 初始化一些训练状态 train_state = train_state.init(...) # 编译训练函数并指定缓冲区状态参数的捐赠 train_state, buffer_state = jax.jit(train, donate_argnums=(1,))( train_state, buffer_state )

在调用 jax.jit 时包含 donate_argnums 很重要,这可以使 JAX 对回放缓冲区状态进行原地更新。如果省略 donate_argnums,JAX 将被迫为回放缓冲区状态的任何修改创建副本,可能会抵消所有性能优势。有关 JAX 中缓冲区捐赠的更多信息,可以在文档中找到。

使用 Vault 存储数据

如上所述,Vault 通过扩展 Flashbax 缓冲区状态的时间轴将经验数据存储到磁盘。默认情况下,Vault 方便地处理此过程的簿记:消耗缓冲区状态并保存任何新的、以前未见过的数据。例如,假设我们向 Flashbax 缓冲区写入 10 个时间步,然后将此状态保存到 Vault;由于所有这些数据都是新的,所有数据都将写入磁盘。但是,如果我们再写入一个时间步并将状态保存到 Vault,则只会写入该新时间步,防止重复已保存的数据。重要的是,必须记住 Flashbax 状态是作为环形缓冲区实现的,这意味着必须足够频繁地更新 Vault,然后再覆盖 Flashbax 缓冲区状态中未见过的数据。即如果我们的缓冲区状态的时间轴长度为 $\tau$,那么我们必须每 $\tau - 1$ 步保存到 vault 一次,以免覆盖(并丢失)未保存的数据。

总之,理解并解决这些考虑因素将帮助您避开潜在的陷阱,并确保在使用 Flashbax 缓冲区时强化学习策略的有效性。

基准测试 📈

这里我们提供了一系列初步基准测试,概述了各种 Flashbax 缓冲区与常用开源缓冲区相比的性能。在这些基准测试中,我们(除非另有明确说明)使用以下配置:

参数
缓冲区大小500_000
采样批次大小256
观察大小(32, 32, 3)
添加序列长度1
添加序列批次大小1
采样序列长度1
采样序列周期1

我们使用采样序列长度和周期为 1 的原因是为了直接与其他缓冲区进行比较,这意味着轨迹缓冲区的速度与项目缓冲区的速度相当,因为项目缓冲区只是具有此配置的包装轨迹缓冲区。这实际上意味着轨迹缓冲区被用作内存效率低下的转换缓冲区。需要注意的是,Flat Buffer 实现使用采样序列长度为 2。此外,必须记住,并非所有其他缓冲区实现都能有效利用 GPU/TPU,因此它们只在 CPU 上运行并执行设备转换。最后,我们明确使用 Python 循环来公平比较其他缓冲区。使用扫描操作可以大大提高速度(取决于观察大小)。

CPU 速度

<p float="left"> <img alt="CPU_Add" src="https://yellow-cdn.veclightyear.com/835a84d5/9b7b6a0c-0005-4b43-96f5-4f20bb0a0843.png" width="49%"> <img alt="CPU_Sample" src="https://yellow-cdn.veclightyear.com/835a84d5/fd099347-e288-472e-b8d3-1a44003fe818.png" width="49%"> </p>

TPU 速度

<p float="left"> <img alt="TPU_Add" src="https://yellow-cdn.veclightyear.com/835a84d5/20a450a6-c96b-4059-9a61-884981c6aa15.png" width="49%"> <img alt="TPU_Sample" src="https://yellow-cdn.veclightyear.com/835a84d5/56bc1836-496b-4e30-817c-8d2fa29cbcfc.png" width="49%"> </p>

GPU 速度

我们注意到添加数据时 GPU 速度出现奇怪的行为。我们认为这是因为某些 JAX 操作尚未针对 GPU 使用进行充分优化,我们看到 Dejax 也有相同的性能问题。我们预计这些速度将来会有所改善。

<p float="left"> <img alt="GPU_Add" src="https://yellow-cdn.veclightyear.com/835a84d5/70093fe3-dcb8-48c2-8d73-e28f9ebbe5d0.png" width="49%"> <img alt="GPU_Sample" src="https://yellow-cdn.veclightyear.com/835a84d5/abd30a18-a9ba-400c-99f1-acea59bacde6.png" width="49%"> </p>

CPU、GPU 和 TPU 添加批次

之前的基准测试每次只添加一个时间步,现在我们评估每次添加 128 个时间步的批次 - 这是大多数人在高吞吐量 RL 中会使用的功能。我们只与具有此功能的缓冲区进行比较。

<p float="left"> <img alt="CPU_Add_Batch" src="https://yellow-cdn.veclightyear.com/835a84d5/3b96a9c0-ff60-4096-95ab-5dd170db10d1.png" width="49%"> <img alt="TPU_Add_Batch" src="https://yellow-cdn.veclightyear.com/835a84d5/4bb2c5e7-a692-47f7-abec-2c9184657186.png" width="49%"> </p> <p align="center"> <img alt="GPU_Add_Batch" src="https://yellow-cdn.veclightyear.com/835a84d5/47417c2e-b856-4a3a-9143-9fc232963343.png" width="49%"> </p>

最终,我们看到性能优于或可与基准测试的缓冲区相媲美,同时提供完全兼容 JAX 的缓冲区,此外还提供批量添加以及能够添加不同长度的序列等功能。我们确实注意到,由于 JAX 对 CPU、GPU 和 TPU 有不同的 XLA 后端,缓冲区的性能可能会因设备和所调用的特定操作而异。

贡献 🤝

欢迎贡献!请查看我们的问题跟踪器以了解适合新手的问题。请阅读我们的贡献指南,了解如何提交拉取请求、我们的贡献者许可协议和社区指南的详细信息。

另请参阅 📚

其他缓冲区

查看我们在基准测试中强调的社区中的其他缓冲区库。

  • 📀 Dejax: 第一个提供兼容 JAX 的回放缓冲区的库。
  • 🎶 Reverb: 用于本地和大规模分布式 RL 的高效回放缓冲区。
  • 🍰 Dopamine: 用于快速原型设计的研究框架,提供了几个核心回放缓冲区。
  • 🤖 StableBaselines3: 可靠的 RL 基线套件,具有自己易于使用的回放缓冲区。

使用示例

查看社区中使用 flashbax 的一些库:

  • 🦁 Mava: 利用 flashbax 的多智能体算法的端到端 JAX 实现。
  • 🏛️ Stoix: 利用 flashbax 的单智能体算法的端到端 JAX 实现。

引用 Flashbax ✏️

如果您在工作中使用了 Flashbax,请使用以下方式引用该库:

@misc{flashbax,
    title={Flashbax: Streamlining Experience Replay Buffers for Reinforcement Learning with JAX},
    author={Edan Toledo and Laurence Midgley and Donal Byrne and Callum Rhys Tilbury and
    Matthew Macfarlane and Cyprien Courtot and Alexandre Laterre},
    year={2023},
    url={https://github.com/instadeepai/flashbax/},
}

致谢 🙏

该库的开发得到了来自 Google 的 TPU Research Cloud (TRC) 🌤 的 Cloud TPU 支持。

编辑推荐精选

讯飞智文

讯飞智文

一键生成PPT和Word,让学习生活更轻松

讯飞智文是一个利用 AI 技术的项目,能够帮助用户生成 PPT 以及各类文档。无论是商业领域的市场分析报告、年度目标制定,还是学生群体的职业生涯规划、实习避坑指南,亦或是活动策划、旅游攻略等内容,它都能提供支持,帮助用户精准表达,轻松呈现各种信息。

热门AI工具AI办公办公工具讯飞智文AI在线生成PPTAI撰写助手多语种文档生成AI自动配图
讯飞星火

讯飞星火

深度推理能力全新升级,全面对标OpenAI o1

科大讯飞的星火大模型,支持语言理解、知识问答和文本创作等多功能,适用于多种文件和业务场景,提升办公和日常生活的效率。讯飞星火是一个提供丰富智能服务的平台,涵盖科技资讯、图像创作、写作辅助、编程解答、科研文献解读等功能,能为不同需求的用户提供便捷高效的帮助,助力用户轻松获取信息、解决问题,满足多样化使用场景。

模型训练热门AI工具内容创作智能问答AI开发讯飞星火大模型多语种支持智慧生活
Spark-TTS

Spark-TTS

一种基于大语言模型的高效单流解耦语音令牌文本到语音合成模型

Spark-TTS 是一个基于 PyTorch 的开源文本到语音合成项目,由多个知名机构联合参与。该项目提供了高效的 LLM(大语言模型)驱动的语音合成方案,支持语音克隆和语音创建功能,可通过命令行界面(CLI)和 Web UI 两种方式使用。用户可以根据需求调整语音的性别、音高、速度等参数,生成高质量的语音。该项目适用于多种场景,如有声读物制作、智能语音助手开发等。

Trae

Trae

字节跳动发布的AI编程神器IDE

Trae是一种自适应的集成开发环境(IDE),通过自动化和多元协作改变开发流程。利用Trae,团队能够更快速、精确地编写和部署代码,从而提高编程效率和项目交付速度。Trae具备上下文感知和代码自动完成功能,是提升开发效率的理想工具。

热门AI工具生产力协作转型TraeAI IDE
咔片PPT

咔片PPT

AI助力,做PPT更简单!

咔片是一款轻量化在线演示设计工具,借助 AI 技术,实现从内容生成到智能设计的一站式 PPT 制作服务。支持多种文档格式导入生成 PPT,提供海量模板、智能美化、素材替换等功能,适用于销售、教师、学生等各类人群,能高效制作出高品质 PPT,满足不同场景演示需求。

讯飞绘文

讯飞绘文

选题、配图、成文,一站式创作,让内容运营更高效

讯飞绘文,一个AI集成平台,支持写作、选题、配图、排版和发布。高效生成适用于各类媒体的定制内容,加速品牌传播,提升内容营销效果。

AI助手热门AI工具AI创作AI辅助写作讯飞绘文内容运营个性化文章多平台分发
材料星

材料星

专业的AI公文写作平台,公文写作神器

AI 材料星,专业的 AI 公文写作辅助平台,为体制内工作人员提供高效的公文写作解决方案。拥有海量公文文库、9 大核心 AI 功能,支持 30 + 文稿类型生成,助力快速完成领导讲话、工作总结、述职报告等材料,提升办公效率,是体制打工人的得力写作神器。

openai-agents-python

openai-agents-python

OpenAI Agents SDK,助力开发者便捷使用 OpenAI 相关功能。

openai-agents-python 是 OpenAI 推出的一款强大 Python SDK,它为开发者提供了与 OpenAI 模型交互的高效工具,支持工具调用、结果处理、追踪等功能,涵盖多种应用场景,如研究助手、财务研究等,能显著提升开发效率,让开发者更轻松地利用 OpenAI 的技术优势。

Hunyuan3D-2

Hunyuan3D-2

高分辨率纹理 3D 资产生成

Hunyuan3D-2 是腾讯开发的用于 3D 资产生成的强大工具,支持从文本描述、单张图片或多视角图片生成 3D 模型,具备快速形状生成能力,可生成带纹理的高质量 3D 模型,适用于多个领域,为 3D 创作提供了高效解决方案。

3FS

3FS

一个具备存储、管理和客户端操作等多种功能的分布式文件系统相关项目。

3FS 是一个功能强大的分布式文件系统项目,涵盖了存储引擎、元数据管理、客户端工具等多个模块。它支持多种文件操作,如创建文件和目录、设置布局等,同时具备高效的事件循环、节点选择和协程池管理等特性。适用于需要大规模数据存储和管理的场景,能够提高系统的性能和可靠性,是分布式存储领域的优质解决方案。

下拉加载更多