高性能强化学习框架助力大规模语言模型优化
OpenRLHF是一款基于Ray、DeepSpeed和Hugging Face Transformers构建的高性能强化学习框架。该框架简单易用,兼容Hugging Face模型和数据集,性能优于优化后的DeepSpeedChat。它支持分布式RLHF,能够在多GPU环境下进行70B+参数模型的全规模微调。OpenRLHF集成了多项PPO实现技巧以提升训练稳定性,同时支持vLLM生成加速和多奖励模型等先进特性,为大规模语言模型优化提供了强大支持。
<span>[ 英文 | <a href="README_zh.md">中文</a> ]</span>
OpenRLHF是一个基于Ray、DeepSpeed和HF Transformers构建的高性能RLHF框架:
tokenizer.apply_chat_template
(--apply_chat_template和--input_key)。特性 | OpenRLHF | DSChat | CAIChat | TRL |
---|---|---|---|---|
使用16个A100-80GB进行70B+完整微调 | ✅ | ❌ | ❌ | ❌ |
使用4个RTX4090进行7B完整微调 | ✅ | ❌ | ❌ | ❌ |
使用8个A100-80GB进行34B DPO完整微调 | ✅ | ❌ | ❌ | ❌ |
PPO中的推理引擎 | ✅ | ✅ | ❌ | ❌ |
PPO实现技巧 | ✅ | ❌ | ❌ | ✅ |
支持QLoRA | ✅ | ❌ | ❌ | ✅ |
支持Mixtral 8*7b | ✅ | ❌ | ❌ | ❌ |
支持未合并的Actor-Critic | ✅ | ✅ | ✅ | ❌ |
支持多个奖励模型 | ✅ | ❌ | ❌ | ❌ |
支持Huggingface模型 | ✅ | ✅ | ✅ | ✅ |
易用性 | ✅ | ❌ (HybridEngine问题) | ✅ | ✅ |
要使用OpenRLHF,首先启动docker容器(推荐),然后在docker容器内使用pip install
安装openrlhf:
# 启动docker容器 docker run --runtime=nvidia -it --rm --shm-size="10g" --cap-add=SYS_ADMIN -v $PWD:/openrlhf nvcr.io/nvidia/pytorch:24.02-py3 bash sudo pip uninstall xgboost transformer_engine flash_attn -y # pip安装 pip install openrlhf # 如果你想使用vLLM加速(安装vLLM 0.4.2) pip install openrlhf[vllm] # 也支持最新版本的vLLM(使用Gloo) pip install openrlhf[vllm_latest] # 安装最新版本 pip install git+https://github.com/OpenRLHF/OpenRLHF.git # 或者git克隆 git clone https://github.com/OpenRLHF/OpenRLHF.git cd OpenRLHF pip install -e .
[!注意] 我们建议使用vLLM 0.4.2,因为0.4.3+版本目前仅支持通过Gloo进行权重同步(DeepSpeed到vLLM)(
--vllm_sync_backend gloo
)。 我们还提供了vLLM的Dockerfile和Nvidia-Docker一键安装脚本。
OpenRLHF在我们的数据集类中提供了多种数据处理方法。 例如在Prompt Dataset中:
def preprocess_data(data, input_template=None, input_key="input", apply_chat_template=None) -> str: if apply_chat_template: prompt = apply_chat_template(data[input_key], tokenize=False, add_generation_prompt=True) else: prompt = data[input_key] if input_template: prompt = input_template.format(prompt) return prompt
--input_key
来指定输入数据集 --prompt_data {名称或路径}
(PPO) 或 --dataset {名称或路径}
的 JSON 键名
,并使用 --apply_chat_template
来利用 Huggingface Tokenizer 中的 chat_template
。--apply_chat_template
,你可以使用 --input_template
代替,或者提前离线预处理数据集。--prompt_data_probs 0.1,0.4,0.5
(PPO) 或 --dataset_probs 0.1,0.4,0.5
混合多个数据集。聊天模板的工作原理:
dataset = [{"input_key": [ {"role": "user", "content": "你好,你好吗?"}, {"role": "assistant", "content": "我很好。今天我能为你做些什么?"}, {"role": "user", "content": "我想展示一下聊天模板是如何工作的!"}, ]}] tokenizer.apply_chat_template(dataset[0]["input_key"], tokenize=False) "<s>[INST] 你好,你好吗? [/INST]我很好。今天我能为你做些什么?</s> [INST] 我想展示一下聊天模板是如何工作的! [/INST]"
如何指定训练和测试数据集?
你可以使用 data_type@data_dir
格式来指定。例如,数据集可以设置为 --dataset json@./data
。
data
├── test.jsonl
└── train.jsonl
[!注意] 默认情况下,我们使用
train
和test
作为分割来区分 Huggingface 的训练和测试数据集。JSON 键
选项取决于具体的数据集。参见 奖励数据集 和 SFT 数据集
OpenRLHF 的模型检查点与 HuggingFace 模型完全兼容。你可以使用 --pretrain {名称或路径}
、--reward_pretrain {名称或路径}
和 --critic_pretrain {名称或路径}
来指定模型名称或路径。我们在 HuggingFace OpenRLHF 上提供了一些预训练的检查点和数据集。
然后你可以使用我们在 examples/scripts 目录中提供的启动脚本,或使用以下命令开始训练。
deepspeed --module openrlhf.cli.train_sft \ --max_len 4096 \ --dataset Open-Orca/OpenOrca \ --input_key question \ --output_key response \ --input_template '用户: {}\n助手: ' \ --train_batch_size 256 \ --micro_train_batch_size 2 \ --max_samples 500000 \ --pretrain meta-llama/Meta-Llama-3-8B \ --save_path ./checkpoint/llama3-8b-sft \ --save_steps -1 \ --logging_steps 1 \ --eval_steps -1 \ --zero_stage 2 \ --max_epochs 1 \ --bf16 \ --flash_attn \ --learning_rate 5e-6 \ --gradient_checkpointing \ --use_wandb {wandb_token} # 支持 HF tokenizer.apply_chat_template # --apply_chat_template # --input_key {JSON 键} # --tokenizer_chat_template {HF 聊天模板} # 支持样本打包 # --packing_samples # 也可用于继续预训练 # --pretrain_mode
[!注意] OpenRLHF SFT/DPO/RewardModel 训练器支持
--packing_samples
基于--flash_attn
deepspeed --module openrlhf.cli.train_rm \ --save_path ./checkpoint/llama3-8b-rm \ --save_steps -1 \ --logging_steps 1 \ --eval_steps -1 \ --train_batch_size 256 \ --micro_train_batch_size 1 \ --pretrain OpenRLHF/Llama-3-8b-sft-mixture \ --bf16 \ --max_epochs 1 \ --max_len 8192 \ --zero_stage 3 \ --learning_rate 9e-6 \ --dataset OpenRLHF/preference_dataset_mixture2_and_safe_pku \ --apply_chat_template \ --chosen_key chosen \ --rejected_key rejected \ --flash_attn \ --gradient_checkpointing \ --use_wandb {wandb_token} # 支持样本打包 # --packing_samples
deepspeed --module openrlhf.cli.train_ppo \ --pretrain OpenRLHF/Llama-3-8b-sft-mixture \ --reward_pretrain OpenRLHF/Llama-3-8b-rm-mixture \ --save_path ./checkpoint/llama-3-8b-rlhf \ --save_steps -1 \ --logging_steps 1 \ --eval_steps -1 \ --micro_train_batch_size 2 \ --train_batch_size 128 \ --micro_rollout_batch_size 4 \ --rollout_batch_size 1024 \ --max_epochs 1 \ --prompt_max_len 1024 \ --generate_max_len 1024 \ --zero_stage 2 \ --bf16 \ --actor_learning_rate 5e-7 \ --critic_learning_rate 9e-6 \ --init_kl_coef 0.01 \ --prompt_data OpenRLHF/prompt-collection-v0.1 \ --input_key context_messages \ --apply_chat_template \ --max_samples 100000 \ --normalize_reward \ --adam_offload \ --flash_attn \ --gradient_checkpointing \ --use_wandb {wandb_token} # 支持远程奖励模型(HTTP) # --remote_rm_url http://localhost:5000/get_reward
为了提高 RLHF 训练速度或支持 70B 模型,我们可以使用 Ray 和 vLLM 加速的 PPO
# 在容器中启动 ray 的主节点 ray start --head --node-ip-address 0.0.0.0 --num-gpus 8 # 如果你想在更多节点上启动 ray,使用 ray start --address {主节点地址}:6379 --num-gpus 8 ray job submit --address="http://127.0.0.1:8265" \ --runtime-env-json='{"working_dir": "/openrlhf"}' \ -- python3 -m openrlhf.cli.train_ppo_ray \ --ref_num_nodes 1 \ --ref_num_gpus_per_node 2 \ --reward_num_nodes 1 \ --reward_num_gpus_per_node 2 \ --critic_num_nodes 1 \ --critic_num_gpus_per_node 2 \ --actor_num_nodes 1 \ --actor_num_gpus_per_node 2 \ --vllm_num_engines 2 \ --vllm_tensor_parallel_size 2 \ --colocate_critic_reward \ --colocate_actor_ref \ --pretrain OpenRLHF/Llama-3-8b-sft-mixture \ --reward_pretrain OpenRLHF/Llama-3-8b-rm-mixture \ --save_path /openrlhf/examples/checkpoint/llama3-8b-rlhf \ --micro_train_batch_size 8 \ --train_batch_size 128 \ --micro_rollout_batch_size 16 \ --rollout_batch_size 1024 \ --max_samples 100000 \ --max_epochs 1 \ --prompt_max_len 1024 \ --generate_max_len 1024 \ --zero_stage 3 \ --bf16 \ --actor_learning_rate 5e-7 \ --critic_learning_rate 9e-6 \ --init_kl_coef 0.01 \ --prompt_data OpenRLHF/prompt-collection-v0.1 \ --input_key context_messages \ --apply_chat_template \ --normalize_reward \ --adam_offload \ --flash_attn \ --gradient_checkpointing \ --use_wandb {wandb_token} # 支持远程奖励模型(HTTP) # --remote_rm_url http://localhost:5000/get_reward
[!注意] 不设置
--vllm_num_engines
意味着不使用 vLLM 引擎。 你也可以使用setup_commands
让 Ray 自动部署环境,比如--runtime-env-json='{"setup_commands": ["pip install openrlhf[vllm]"]}'
。
支持的算法的启动脚本和文档在 example/scripts 和 文档 - 使用
我们通过采用诸如启用Adam卸载、奖励模型(RM)和参考模型(Ref)卸载等技术,最大程度地优化了DSChat的性能,以在推理阶段增加微批次大小并避免内存不足问题。我们甚至修复了DSChat中的一些错误,以便为LLaMA2启用混合引擎(HE)。使用优化后的DSChat和OpenRLHF训练1024个提示词,1个PPO周期所需的平均时间(秒)如下:
大小 | NVIDIA A800-80GB GPU数量 | 优化后的DSChat (使用混合引擎) | OpenRLHF | 加速比 |
---|---|---|---|---|
7B | 16 | 855.09 | 471.11 | 1.82x |
13B | 32 | 1528.93 | 608.93 | 2.5x |
34B | 32 | 3634.98 | 1526.4 | 2.4x |
70B | 32 | 10407.0 | 4488.53 | 2.3x |
为了获得最佳性能,我们建议为vLLM引擎分配更多节点。例如,对于使用32个A100 GPU的70B模型,建议为vLLM引擎分配超过16个A100 GPU,为Actor模型分配8个GPU,剩余的8个GPU分配给Critic模型。此外,启用--colocate_critic_reward
、--colocate_actor_ref
和--ref_reward_offload
选项以合并节点。最后,应尽可能增加rollout_micro_batch_size
(并最小化vLLM引擎的TP大小),避免Reward/Reference
模型前向传播出现OOM(内存不足)问题。在训练阶段,更大的--micro_train_batch_size
效果更好。当n_samples_per_prompt > 1
时,在vLLM生成中启用enable_prefix_caching
。
如何加入?
你可以做什么?
你的赞助可以帮助我们维护和改进OpenRLHF。如果你觉得这个项目有用,请考虑赞助我们。你可以在Open Collective ↗上赞助我们。
衷心感谢所有贡献者!如果你想贡献,欢迎提交拉取请求或创建issue。
<a href="https://github.com/OpenRLHF/OpenRLHF/graphs/contributors"> <img src="https://contrib.rocks/image?repo=OpenRLHF/OpenRLHF" /> </a>我们要感谢以下项目和组织对AI和NLP领域的贡献:
我们的项目还要感谢ColossalChat和DeepSpeedChat。在项目早期阶段,我们参考了他们的代码设计。
(2024/7) 我们的GitHub组织已从OpenLLMAI更名为OpenRLHF。
@article{hu2024openrlhf,
title={OpenRLHF: An Easy-to-use, Scalable and High-performance RLHF Framework},
author={Jian Hu and Xibin Wu and Weixun Wang and Xianyu and Dehao Zhang and Yu Cao},
journal={arXiv preprint arXiv:2405.11143},
year={2024}
}
OpenRLHF © 2024 OpenRLHF. 版权所有。
AI数字人视频创作平台
Keevx 一款开箱即用的AI数字人视频创作平台,广泛适用于电商广告、企业培训与社媒宣传,让全球企业与个人创作者无需拍摄剪辑,就能快速生成多语言、高质量的专业视频。
一站式AI创作平台
提供 AI 驱动的图片、视频生成及数字人等功能,助力创意创作
AI办公助手,复杂任务高效处理
AI办公助手,复杂任务高效处理。办公效率低?扣子空间AI助手支持播客生成、PPT制作、网页开发及报告写作,覆盖科研、商业、舆情等领域的专家Agent 7x24小时响应,生活工作无缝切换,提升50%效率!
AI辅助编程,代码自动修复
Trae是一种自适应的集成开发环境(IDE),通过自动化和多元协作改变开发流程。利用Trae,团队能够更快速、精确地 编写和部署代码,从而提高编程效率和项目交付速度。Trae具备上下文感知和代码自动完成功能,是提升开发效率的理想工具。
AI小说写作助手,一站式润色、改写、扩写