高效的大模型训练工具
<p align="center"> <a href="#概述">概述</a> • <a href="#文档">文档</a> • <a href="#安装">安装</a> • <a href="#使用">使用</a> • <a href="#性能">性能</a> • <a href="./README-ZH.md" target="_blank">简体中文</a> <br> </p> <p align="center"> <a href='https://bmtrain.readthedocs.io/en/latest/?badge=latest'> <img src='https://readthedocs.org/projects/bmtrain/badge/?version=latest' alt='文档状态' /> </a> <a href="https://github.com/OpenBMB/BMTrain/releases"> <img alt="GitHub release (最新by日期包括预发布)" src="https://img.shields.io/github/v/release/OpenBMB/BMTrain?include_prereleases"> </a> <a href="https://github.com/OpenBMB/BMTrain/blob/main/LICENSE"> <img alt="GitHub" src="https://img.shields.io/github/license/OpenBMB/BMTrain"> </a> </p> </div>BMTrain是一个高效的大模型训练工具包,可用于训练具有数百亿参数的大模型。它可以以分布式方式训练模型,同时保持代码像单机训练一样简单。
<div id="文档"></div>我们的文档提供了关于该软件包的更多信息。
<div id="安装"></div>通过pip(推荐):pip install bmtrain
从源代码:下载软件包并运行pip install .
安装BMTrain可能需要几分钟到十几分钟,因为它需要在安装时编译c/cuda源代码。 我们建议直接在训练环境中编译BMTrain,以避免不同环境可能引起的潜在问题。
<div id="使用"></div>在使用BMTrain之前,您需要在代码开头对其进行初始化。就像使用PyTorch的分布式模块需要在代码开头使用init_process_group一样,使用BMTrain需要在代码开头使用init_distributed。
import bmtrain as bmt bmt.init_distributed( seed=0, # ... )
注意: 在使用BMTrain时,不要使用PyTorch的分布式模块及其相关通信函数。
要启用ZeRO优化,您需要对原始模型代码进行一些简单的替换。
torch.nn.Module
-> bmtrain.DistributedModule
torch.nn.Parameter
-> bmtrain.DistributedParameter
并用bmtrain.Block
包装transformer块。
这里有一个例子。
原始代码
import torch class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.param = torch.nn.Parameter(torch.empty(1024)) self.module_list = torch.nn.ModuleList([ SomeTransformerBlock(), SomeTransformerBlock(), SomeTransformerBlock() ]) def forward(self): x = self.param for module in self.module_list: x = module(x, 1, 2, 3) return x
替换后的代码
import torch import bmtrain as bmt class MyModule(bmt.DistributedModule): # 此处改变 def __init__(self): super().__init__() self.param = bmt.DistributedParameter(torch.empty(1024)) # 此处改变 self.module_list = torch.nn.ModuleList([ bmt.Block(SomeTransformerBlock(), zero_level=3), # 此处改变,现在支持2和3 bmt.Block(SomeTransformerBlock(), zero_level=3), # 此处改变,现在支持2和3 bmt.Block(SomeTransformerBlock(), zero_level=3) # 此处改变,现在支持2和3 ]) def forward(self): x = self.param for module in self.module_list: x = module(x, 1, 2, 3) return x
为进一步减少额外的通信开销并将通信与计算时间重叠,可以使用TransformerBlockList
进行优化。
您可以通过对代码进行以下替换来启用它们:
torch.nn.ModuleList
-> bmtrain.TransformerBlockList
for module in self.module_list: x = module(x, ...)
-> x = self.module_list(x, ...)
原始代码
import torch import bmtrain as bmt class MyModule(bmt.DistributedModule): def __init__(self): super().__init__() self.param = bmt.DistributedParameter(torch.empty(1024)) self.module_list = torch.nn.ModuleList([ bmt.Block(SomeTransformerBlock()), bmt.Block(SomeTransformerBlock()), bmt.Block(SomeTransformerBlock()) ]) def forward(self): x = self.param for module in self.module_list: x = module(x, 1, 2, 3) return x
替换后的代码
import torch import bmtrain as bmt class MyModule(bmt.DistributedModule): def __init__(self): super().__init__() self.param = bmt.DistributedParameter(torch.empty(1024)) self.module_list = bmt.TransformerBlockList([ # 此处改变 bmt.Block(SomeTransformerBlock()), bmt.Block(SomeTransformerBlock()), bmt.Block(SomeTransformerBlock()) ]) def forward(self): x = self.param for module in self.module_list: x = module(x, 1, 2, 3) return x
BMTrain使用与PyTorch分布式模块相同的启动命令。
您可以根据您的PyTorch版本选择其中一种。
${MASTER_ADDR}
表示主节点的IP地址。${MASTER_PORT}
表示主节点的端口。${NNODES}
表示节点总数。${GPU_PER_NODE}
表示每个节点的GPU数量。${NODE_RANK}
表示此节点的排名。$ python3 -m torch.distributed.launch --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} --nproc_per_node ${GPU_PER_NODE} --nnodes ${NNODES} --node_rank ${NODE_RANK} train.py
$ torchrun --nnodes=${NNODES} --nproc_per_node=${GPU_PER_NODE} --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} train.py
有关更多信息,请参阅文档。
我们提供了一个基于BMTrain的GPT-2训练示例。 代码主要由以下部分组成。
├── layers
│ ├── attention.py
│ ├── embedding.py
│ ├── feedforward.py
│ ├── __init__.py
│ ├── layernorm.py
│ └── linear.py
└── models
├── gpt.py
└── __init__.py
上面是模型定义部分的代码目录结构。
我们定义了GPT-2所需的所有层,并使用BMTrain的DistributedModule
和DistributedParameter
来启用ZeRO优化。
bmtrain.init_distributed(seed=0) model = GPT( num_layers=8, vocab_size=10240, dim_model=2560, dim_head=80, num_heads=32, dim_ff=8192, max_distance=1024, bias=True, dtype=torch.half ) bmtrain.init_parameters(model) # 或使用`bmtrain.load`加载检查点 # ... 其他初始化(数据集)...
bmtrain.init_distributed(seed=0)
用于初始化分布式训练环境并设置随机种子以确保可重现性。
bmtrain.init_parameters(model)
用于初始化模型的分布式参数。
loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100) optimizer = bmtrain.optim.AdamOffloadOptimizer(model.parameters(), weight_decay=1e-2) lr_scheduler = bmtrain.lr_scheduler.Noam(optimizer, start_lr=1e-3, warmup_iter=40, end_iter=1000, num_iter=0)
BMTrain支持所有PyTorch原生优化器和损失函数,您还可以使用BMTrain提供的融合优化器进行混合精度训练。
此外,BMTrain还在bmtrain.lr_scheduler
模块中提供了常用的学习率调度器。
# 创建优化器管理器的新实例 optim_manager = bmtrain.optim.OptimManager(loss_scale=1024) # 让优化器管理器处理所有的优化器和(可选的)对应的学习率调度器 optim_manager.add_optimizer(optimizer, lr_scheduler) # add_optimizer可以多次调用以添加其他优化器 for iteration in range(1000): # ... 为每个进程加载数据 ... # 前向传播和计算损失 pos = torch.arange(enc_input.size(1)).long().cuda().repeat(enc_input.size(0), 1) logits = model( enc_input, pos, pos < enc_length[:, None] ) batch, seq_len, vocab_out_size = logits.size() loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len)) global_loss = bmtrain.sum_loss(loss).item() # 汇总所有进程的损失。这仅用于训练日志 # 梯度清零 optim_manager.zero_grad() # 为每个优化器调用zero_grad # 损失缩放和反向传播 optim_manager.backward(loss) # 梯度裁剪 grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, max_norm=1.0) # 优化器步进 optim_manager.step() # ... 保存检查点或打印日志 ...
训练循环部分会稍长一些,但就像普通的训练循环一样,你不需要对分布式训练做太多适配。
你可以根据代码中的注释了解每个代码部分的功能。
唯一需要额外注意的是optimizer
。使用BMTrain后,优化器中的一些细节需要调整。我们已经在optim_manager
中实现了所有需要的细节。你只需要让optim_manager
通过add_optimizer
处理所有的优化器,并让optim_manager
代替执行zero_grad()
、backward()
、clip_grad_norm()
和step()
。
如果你不使用混合精度训练,可以在不使用loss_scale
的情况下进行训练。只需在OptimManager
的__init__
函数中将loss_scale
设置为None(OptimManager(loss_scale=None)
),这也是默认设置。
如果你使用混合精度训练,