chatglm-maths

chatglm-maths

ChatGLM-6B数学运算能力优化项目

该项目旨在优化ChatGLM-6B模型的整数和小数四则运算能力。项目采用LORA、PPO等多种训练方法,支持GPU和CPU环境。内容包括自动生成的训练样本、微调数据集、LORA权重,以及环境配置和使用说明。这一工具主要面向开发者和研究人员,用于提升大语言模型的数学计算表现。

ChatGLM-6B微调LORAPPO数学计算Github开源项目

chatglm-maths

chatglm-6b微调/LORA/PPO/推理, 样本为自动生成的整数/小数加减乘除运算, 可gpu/cpu

踩坑

1. eps=1e-5(不要改小), 半精度float16, 以及LN采用的是Post-LN(泛化性更好) + DeepNorm, 【害, Attention前也有LN】目的是大模型为了防止梯度溢出等; 2. 模型输入输出, 默认的tokenization_chatglm.py/modeling_chatglm.py不能用, 因为那是完全为生成generate设置的, 需要自己写好所有缩入参数, 或者机子改成适配的; 2.1 ChatGLMModel中, get_masks()正常, get_position_ids()函数中‘context_length = seq.index(150004) + 1’ 改为 ‘context_length = len(seq); 2.2 训练输入input_ids格式暂定为(训练后post-padding, 推理前pre-padding[tokenization_chatglm.py默认pre-padding]) x: prompt_1 + "_" + text_1 + "\n" + prompt_2 + [gMASK] + [BOS] + "_" + text_2 + [PAD]*N 2.3 训练输入label_ids格式暂定为(CrossEntropyLoss默认忽略-100不参与计算loss) y = [-100]*len(text_1) + [BOS] + text_2 + [EOS] + [-100]*N 2.4 注意position/mask(自带的只是推理用的batch_size=1, 所以训练输入还得自己写), 可参考GLM-130的README.md, huozhe 查看GLM-1源码https://github.com/THUDM/GLM/blob/main/tasks/seq2seq/dataset.py 3. 注意chatglm-6b权重是float16的, 不过计算loss时候会转成float32计算, 最后loss再转回float16更新梯度; 4. ChatGLMTokenizer有时候会报奇奇怪怪的错误, 建议生成时候设置max_new_tokens, 最大{"max_new_tokens": 2048}; decode有时候会出现不存在id; 5. 低秩自适应LORA, RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! 尝试 transformers升级到最新, get_peft_model后再.cuda(), device_map={'':torch.cuda.current_device()},

微调数据

  1. 原始数据来自https://github.com/LYH-YF/MWPToolkit

    处理后的微调数据(算式/解方程)-MWP: https://huggingface.co/datasets/Macropodus/MWP-Instruct

  2. 大数加减乘除来自: https://github.com/liutiedong/goat.git

LoRA权重

Baichuan-7B-GPT4ForALL: https://huggingface.co/Macropodus/MWP-Instruct Bloomz-7B-GPT4ForALL: https://huggingface.co/Macropodus/MWP-Instruct ChatGLM-6B-GPT4ForALL: https://huggingface.co/Macropodus/MWP-Instruct LlaMA-7B-GPT4ForALL: https://huggingface.co/Macropodus/MWP-Instruct ChatGLM-6B-MWP: https://huggingface.co/Macropodus/MWP-Instruct

数据集-中文

环境配置

transformers>=4.26.1 cpm_kernels==1.0.11 icetk==0.0.4 torch>=1.10.1 rouge==1.0.1 nltk==3.6.6 peft>=0.2.0 numpy tqdm lion_pytorch macropodus trl>=0.4.1

微调-计算题

lora 微调: python c00_toy_lora_train_6b.py 推理: python p00_toy_lora_predict_6b.py ppo 训练: python t10_toy_trl_train_ppo.py 测试: python t10_toy_trl_predict_ppo.py 6b 微调: python c00_toy_cpu_train_6b.py 推理: python p00_toy_cpu_predit_6b.py small-layer 微调: python c01_toy_cpu_train_small.py 推理: python p01_toy_cpu_predict_small.py

参考/感谢

推理日志toy

generator_calculate_line: ('13+75=', '13+75=88') tokenizer.vocab_size: 150344 eval: 0%| | 0/1 [00:00<?, ?it/s]batch_query: ['简便运算: 98+83= 剖析: 98+83=181'] batch_qtext_0: 简便运算: 98+83= 剖析: batch_qans_0: 98+83=181 response_0: 98+83=171 {'rouge-1': 0.0, 'rouge-2': 0.0, 'rouge-l': 0.0, 'bleu': 0.0} 请输入: 25.31+86.35= 请稍等... 25.31+86.35=101.66

微调日志toy

generator_calculate_line: ('13+75=', '13+75=88') tokenizer.vocab_size: 150344 Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:10<00:00, 1.31s/it] transformer.word_embeddings.weight False ...... transformer.layers.26.mlp.dense_4h_to_h.bias False transformer.layers.27.input_layernorm.weight True transformer.layers.27.input_layernorm.bias True transformer.layers.27.attention.query_key_value.weight True transformer.layers.27.attention.query_key_value.bias True transformer.layers.27.attention.dense.weight True transformer.layers.27.attention.dense.bias True transformer.layers.27.post_attention_layernorm.weight True transformer.layers.27.post_attention_layernorm.bias True transformer.layers.27.mlp.dense_h_to_4h.weight True transformer.layers.27.mlp.dense_h_to_4h.bias True transformer.layers.27.mlp.dense_4h_to_h.weight True transformer.layers.27.mlp.dense_4h_to_h.bias True transformer.final_layernorm.weight True transformer.final_layernorm.bias True model.chat start 13+75=88, but that's not the correct answer. The correct answer is 13+75=88, which is 90. /anaconda3/envs/py371/lib/python3.7/site-packages/transformers/optimization.py:395: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning FutureWarning, epoch: 0%| | 0/21 [00:00<?, ?it/s]epochs: batch_query: ['简便运算: 98+83= 剖析: 98+83=181'] | 0/8 [00:00<?, ?it/s] epoch_global: 0, step_global: 1, step: 0, loss: 4.0625 batch_query: ['口算: 57.84+13.64 解: 57.84+13.64=71.48'] epoch_global: 0, step_global: 2, step: 1, loss: 2.5625███▌ | 2/8 [00:17<00:51, 8.54s/it] batch_query: ['计算题: 48+1 解答: 48+1=49'] epoch_global: 0, step_global: 3, step: 2, loss: 4.15625█████████████████████▎ | 3/8 [00:38<01:09, 13.94s/it] batch_query: ['计算题: 61.65+33.05 解答: 61.65+33.05=94.7'] epoch_global: 0, step_global: 4, step: 3, loss: 2.40625████████████████████████████████████████ | 4/8 [01:01<01:09, 17.43s/it] batch_query: ['计算: 81+75 回答: 81+75=156'] epoch_global: 0, step_global: 5, step: 4, loss: 3.546875█████████████████████████████████████████████████████████▊ | 5/8 [01:27<01:01, 20.41s/it] epoch: 5%|███████▎ | 1/21 [03:07<1:02:30, 187.52s/it]epochs: step: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [02:41<00:00, 23.15s/it] epoch_0_step: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [03:07<00:00, 23.44s/it] batch_query: ['问题: 99+37 答案: 99+37=136'] epoch_global: 1, step_global: 9, step: 0, loss: 3.640625 | 0/8 [00:00<?, ?it/s] batch_query: ['问题: 26.81+55.91 答案: 26.81+55.91=82.72'] | 0/1 [00:00<?, ?it/s] batch_qtext_0: 问题: 26.81+55.91 答案: batch_qans_0: 26.81+55.91=82.72 response_0: 26.81+55.91=83.72 {'rouge-1': 0.749999995, 'rouge-2': 0.3333333283333334, 'rouge-l': 0.749999995, 'bleu': 0.0} epoch_global: 1, step_global: 9, step: 0 best_score_avg: 0.45833 current_mertics: {'rouge-1': 0.749999995, 'rouge-2': 0.3333333283333334, 'rouge-l': 0.749999995, 'bleu': 0.0} batch_query: ['数学题: 23.34+68.45 点拨: 23.34+68.45=91.79'] epoch_global: 1, step_global: 10, step: 1, loss: 2.09375 batch_query: ['计算: 77+14 回答: 77+14=91']█████████████▌ | 2/8 [00:33<01:39, 16.58s/it] epoch_global: 1, step_global: 11, step: 2, loss: 3.265625 batch_query: ['口算: 79.69+17.43= 解: 79.69+17.43=97.12']██████████████████▎ | 3/8 [00:35<00:53, 10.75s/it] epoch_global: 1, step_global: 12, step: 3, loss: 2.171875 batch_query: ['简便运算: 59.67+86.73 剖析: 59.67+86.73=146.4']████████████████████████████████ | 4/8 [00:37<00:29, 7.43s/it] epoch_global: 1, step_global: 13, step: 4, loss: 2.328125 epoch: 10%|██████████████▊ | 2/21 [03:56<33:33, 105.97s/it]epochs: epoch_1_step: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:48<00:00, 6.11s/it] batch_query: ['初等数学: 24.29+76.26 解析: 24.29+76.26=100.55'] epoch_global: 2, step_global: 17, step: 0, loss: 2.046875 epoch_2_step: 0%| | 0/8 [00:00<?, ?it/sbatch_query: ['计算题: 69.85+28.46= 解答: 69.85+28.46=98.31'] batch_qtext_0: 计算题: 69.85+28.46= 解答: | 0/1 [00:00<?, ?it/s] batch_qans_0: 69.85+28.46=98.31 response_0: 69.85+28.46=97.21 {'rouge-1': 0.4999999950000001, 'rouge-2': 0.3333333283333334, 'rouge-l': 0.4999999950000001, 'bleu': 0.0} eval: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:07<00:00, 7.83s/it] epoch_global: 2, step_global: 17, step: 0 best_score_avg: 0.33333 current_mertics: {'rouge-1': 0.4999999950000001, 'rouge-2': 0.3333333283333334, 'rouge-l': 0.4999999950000001, 'bleu': 0.0} batch_query: ['问题: 113.79+81.78= 答案: 113.79+81.78=195.57'] epoch_global: 2, step_global: 18, step: 1, loss: 1.8515625 batch_query: ['计算: 10.74+17.87= 回答:

编辑推荐精选

TRAE编程

TRAE编程

AI辅助编程,代码自动修复

Trae是一种自适应的集成开发环境(IDE),通过自动化和多元协作改变开发流程。利用Trae,团队能够更快速、精确地编写和部署代码,从而提高编程效率和项目交付速度。Trae具备上下文感知和代码自动完成功能,是提升开发效率的理想工具。

热门AI工具生产力协作转型TraeAI IDE
蛙蛙写作

蛙蛙写作

AI小说写作助手,一站式润色、改写、扩写

蛙蛙写作—国内先进的AI写作平台,涵盖小说、学术、社交媒体等多场景。提供续写、改写、润色等功能,助力创作者高效优化写作流程。界面简洁,功能全面,适合各类写作者提升内容品质和工作效率。

AI助手AI工具AI写作工具AI辅助写作蛙蛙写作学术助手办公助手营销助手
问小白

问小白

全能AI智能助手,随时解答生活与工作的多样问题

问小白,由元石科技研发的AI智能助手,快速准确地解答各种生活和工作问题,包括但不限于搜索、规划和社交互动,帮助用户在日常生活中提高效率,轻松管理个人事务。

聊天机器人AI助手热门AI工具AI对话
Transly

Transly

实时语音翻译/同声传译工具

Transly是一个多场景的AI大语言模型驱动的同声传译、专业翻译助手,它拥有超精准的音频识别翻译能力,几乎零延迟的使用体验和支持多国语言可以让你带它走遍全球,无论你是留学生、商务人士、韩剧美剧爱好者,还是出国游玩、多国会议、跨国追星等等,都可以满足你所有需要同传的场景需求,线上线下通用,扫除语言障碍,让全世界的语言交流不再有国界。

讯飞智文

讯飞智文

一键生成PPT和Word,让学习生活更轻松

讯飞智文是一个利用 AI 技术的项目,能够帮助用户生成 PPT 以及各类文档。无论是商业领域的市场分析报告、年度目标制定,还是学生群体的职业生涯规划、实习避坑指南,亦或是活动策划、旅游攻略等内容,它都能提供支持,帮助用户精准表达,轻松呈现各种信息。

热门AI工具AI办公办公工具讯飞智文AI在线生成PPTAI撰写助手多语种文档生成AI自动配图
讯飞星火

讯飞星火

深度推理能力全新升级,全面对标OpenAI o1

科大讯飞的星火大模型,支持语言理解、知识问答和文本创作等多功能,适用于多种文件和业务场景,提升办公和日常生活的效率。讯飞星火是一个提供丰富智能服务的平台,涵盖科技资讯、图像创作、写作辅助、编程解答、科研文献解读等功能,能为不同需求的用户提供便捷高效的帮助,助力用户轻松获取信息、解决问题,满足多样化使用场景。

模型训练热门AI工具内容创作智能问答AI开发讯飞星火大模型多语种支持智慧生活
Spark-TTS

Spark-TTS

一种基于大语言模型的高效单流解耦语音令牌文本到语音合成模型

Spark-TTS 是一个基于 PyTorch 的开源文本到语音合成项目,由多个知名机构联合参与。该项目提供了高效的 LLM(大语言模型)驱动的语音合成方案,支持语音克隆和语音创建功能,可通过命令行界面(CLI)和 Web UI 两种方式使用。用户可以根据需求调整语音的性别、音高、速度等参数,生成高质量的语音。该项目适用于多种场景,如有声读物制作、智能语音助手开发等。

咔片PPT

咔片PPT

AI助力,做PPT更简单!

咔片是一款轻量化在线演示设计工具,借助 AI 技术,实现从内容生成到智能设计的一站式 PPT 制作服务。支持多种文档格式导入生成 PPT,提供海量模板、智能美化、素材替换等功能,适用于销售、教师、学生等各类人群,能高效制作出高品质 PPT,满足不同场景演示需求。

讯飞绘文

讯飞绘文

选题、配图、成文,一站式创作,让内容运营更高效

讯飞绘文,一个AI集成平台,支持写作、选题、配图、排版和发布。高效生成适用于各类媒体的定制内容,加速品牌传播,提升内容营销效果。

AI助手热门AI工具AI创作AI辅助写作讯飞绘文内容运营个性化文章多平台分发
材料星

材料星

专业的AI公文写作平台,公文写作神器

AI 材料星,专业的 AI 公文写作辅助平台,为体制内工作人员提供高效的公文写作解决方案。拥有海量公文文库、9 大核心 AI 功能,支持 30 + 文稿类型生成,助力快速完成领导讲话、工作总结、述职报告等材料,提升办公效率,是体制打工人的得力写作神器。

下拉加载更多