通过自反学习使语言模型实现按需检索、生成和评估的框架
Self-RAG是一种创新框架,通过自反学 习使语言模型实现按需检索、生成和评估。该方法预测反思标记,支持多次检索或跳过检索,并从多角度评估生成内容。这不仅提高了模型输出的事实性和质量,还保持了语言模型的通用性能。
这包括原始实现的SELF-RAG: 通过自我反思学习检索、生成和评判(ICLR 2024,口头报告前1%),作者为Akari Asai、Zeqiu Wu、Yizhong Wang、Avirup Sil和Hannaneh Hajishirzi。
网站 | 7B模型 | 13B模型 | 论文 | 训练数据 | Twitter摘要 | 更新
Self-RAG(右图)是一个新的框架,用于训练任意语言模型学习检索、生成和评判,以提高生成内容的事实性和质量,同时不影响大型语言模型的多功能性。
与广泛采用的检索增强生成(RAG;左图)方法不同,Self-RAG根据需求进行检索(例如,可以多次检索或完全跳过检索),针对不同的查询,并通过预测反思标记作为生成的组成部分,从多个细粒度方面对自身生成进行评判。我们进行分段束搜索,以选择能最大化多样化偏好效用的输出。
如果您发现我们的代码、数据、模型或论文有用,请引用以下论文:
@inproceedings{
asai2024selfrag,
author={Asai, Akari and Wu, Zeqiu and Wang, Yizhong and Sil, Avirup and Hajishirzi, Hannaneh},
title={Self-{RAG}: Learning to Retrieve, Generate, and Critique through Self-Reflection},
booktitle={The Twelfth International Conference on Learning Representations},
year={2024},
url={https://openreview.net/forum?id=hSyW5go0v8}
}
通过运行以下命令安装依赖的Python库。
pip install -r requirements.txt
请使用最新版本的vllm
,因为旧版本可能无法通过SamplingParam
设置skip_special_tokens
,这是由(这个PR)添加的。
您也可以通过运行以下命令创建conda环境。
conda env create -f environment.yml
您可以从HuggingFace Hub下载Self-RAG。对于推理,我们建议使用vllm,因为它可以显著加快推理速度。
from vllm import LLM, SamplingParams model = LLM("selfrag/selfrag_llama2_7b", download_dir="/gscratch/h2lab/akari/model_cache", dtype="half") sampling_params = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=100, skip_special_tokens=False) def format_prompt(input, paragraph=None): prompt = "### 指令:\n{0}\n\n### 回复:\n".format(input) if paragraph is not None: prompt += "[检索]<段落>{0}</段落>".format(paragraph) return prompt query_1 = "找出不同项:twitter、instagram、whatsapp。" query_2 = "你能告诉我美洲驼和羊驼的区别吗?" queries = [query_1, query_2] # 对于不需要检索的查询 preds = model.generate([format_prompt(query) for query in queries], sampling_params) for pred in preds: print("模型预测: {0}".format(pred.outputs[0].text))
输出:
模型预测: Twitter、Instagram和WhatsApp都是社交媒体平台。[无检索]WhatsApp是不同项,因为它是一个消息应用,而Twitter和Instagram主要用于分享照片和视频。[效用:5]</s> 模型预测: 好的![检索]<段落><段落>
如您所见,在第一个查询中,当不需要检索时,Self-RAG开始生成回答而不进行检索。另一方面,对于第二个查询,Self-RAG输出了[检索]
标记,因为这个问题需要更细粒度的事实依据。
对于需要事实依据的查询,您可以插入一个段落。Self-RAG可以在生成过程中随时检索和插入段落,只要它们被上下文标记特殊标记<段落>
、</段落>
包围,就能识别它们。
# 对于需要事实依据的查询
prompt = format_prompt("你能告诉我美洲驼和羊驼的区别吗?", "羊驼(Lama pacos)是南美洲骆驼科哺乳动物的一种。它与美洲驼相似,常常被混淆。羊驼比美洲驼小得多,与美洲驼不同的是,它们不是被培育为工作动物,而是专门为了它们的纤维而被培育。")
preds = model.generate([prompt], sampling_params)
print([pred.outputs[0].text for pred in preds])
# ['[相关]羊驼比美洲驼小得多,与美洲驼不同的是,它们不是被培育为工作动物,而是专门为了它们的纤维而被培育。[完全支持][效用:5]</s>']
Self-RAG找到相关的插入文档,并生成完全由证据支持的答案。
您也可以按需进行检索并与Self-RAG一起使用。由于在完整的英文维基百科上运行检索需要大量RAM和多个GPU,我们为演示目的创建了一个只包含维基百科文章介绍段落的子集。
首先 ,请下载语料库和嵌入(共9GB)。
git clone git@github.com:AkariAsai/self-rag.git
cd retrieval_lm
bash download_demo_corpus.sh
如果脚本不起作用,您可以从Google Drive或HF数据集下载数据。
然后,您可以在retrieval_lm
下运行脚本。我们在1个RTX 6000(24GB)和100G RAM上测试了该脚本(但应该可以在更小的RAM上运行)。
from passage_retrieval import Retriever retriever = Retriever({}) retriever.setup_retriever_demo("facebook/contriever-msmarco", "enwiki_2020_intro_only/enwiki_2020_dec_intro_only.jsonl", "enwiki_2020_intro_only/enwiki_dec_2020_contriever_intro/*", n_docs=5, save_or_load_index=False) retrieved_documents = retriever.search_document_demo(query_3, 5) prompts = [format_prompt(query_3, doc["title"] +"\n"+ doc["text"]) for doc in retrieved_documents] preds = model.generate(prompts, sampling_params) top_doc = retriever.search_document_demo(query_3, 1)[0] print("参考: {0}\n模型预测: {1}".format(top_doc["title"] + "\n" + top_doc["text"], preds[0].outputs[0].text))
输出:
参考: 过拟合 在统计学中,过拟合是"产生一个与特定数据集过于紧密或完全对应的分析,因此可能无法可靠地适应额外数据或预测未来观察结果"。过拟合模型是一个包含比数据可以证明更多参数的统计模型。过拟合的本质是无意中将一些残差变异(即噪声)提取出来,就好像这种变异代表了潜在的模型结构。欠拟合发生在统计模型无法充分捕捉数据的潜在结构时。欠拟合模型是一个模型,其中一些在正确指定的模型中会出现的参数或项缺失 模型预测: [相关]过拟合发生在模型相对于其训练数据量而言具有太多参数时,导致它过度记忆训练数据,并在新的、未见过的数据上表现不佳。[完全支持][效用:5]</s>
检索系统正确检索了必要的文档并生成了完全有根据的输出。
请注意,此演示使用较小的语料库和具有完整推理算法的Self-RAG。对于完整评估,您需要设置检索器或下载我们的检索结果。请按照推理中的说明进行操作。
默认情况下,我 们使用Contriever作为我们的检索组件。
下载DPR中使用的预处理段落数据。
cd retrieval_lm
wget https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz
然后,下载生成的段落。我们使用Contriever-MSMARCO
wget https://dl.fbaipublicfiles.com/contriever/embeddings/contriever-msmarco/wikipedia_embeddings.tar
您可以通过运行以下命令来执行段落检索。
cd retrieval_lm
python passage_retrieval.py \
--model_name_or_path facebook/contriever-msmarco --passages psgs_w100.tsv \
--passages_embeddings "wikipedia_embeddings/*" \
--data 你的输入文件 \
--output_dir 你的输出文件 \
--n_docs 20
你的输入文件应该是json
或jsonl
格式。每个实例必须包含question
或instruction
,这将在检索过程中用作查询。
你可以通过运行以下命令为你自己的数据生成嵌入。(该脚本改编自Contriever仓库。)请注意,为大规模语料库(>1000万文档)生成嵌入可能需要一些时间,我们建议在多个GPU上运行。
cd retrieval_lm
for i in {0..3}; do
export CUDA_VISIBLE_DEVICES=${i}
python generate_passage_embeddings.py --model_name_or_path facebook/contriever-msmarco \
--output_dir 你的输出目录 \
--passages 你的段落数据 --shard_id ${i} --num_shards 4 > ./log/nohup.my_embeddings.${i} 2>&1 &
Self-RAG训练两个模型,Critic和Generator,它们都扩展了反思标记的词汇表,并使用标准的下一个标记预测目标进行训练。
或者,你可以下载我们由15万个实例组成的训练数据这里。
我们从GPT-4收集训练数据。用于为每种特殊标记类型调用GPT-4的脚本可在data_creation/critic获得。
或者,你可以在这里下载我们的训练数据。
一旦你创建或下载了训练数据,运行下面的命令对Llama2-7B进行critic训练的微调。
cd data_creation
torchrun --nproc_per_node=2 \
--master_port=2568 train_special_tokens.py \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--data_path 训练数据文件路径 \
--bf16 True \
--output_dir CRITIC模型路径 \
--num_train_epochs 3 \
--per_device_train_batch_size 1 --per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 300 \
--save_total_limit 1 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.01 \
--lr_scheduler_type "cosine" \
--logging_steps 10 \
--fsdp "full_shard auto_wrap"
创建Generator训练数据的代码在generator_data_creation下。请参阅README.md中的说明。
或者,你可以在HuggingFace数据集或这里下载我们的训练数据
对于generator训练,我们使用DeepSpeed来提高训练效率。你可以通过运行下面的脚本来进行训练,设置好训练数据路径后。
cd retrieval_lm
bash script_finetune_7b.sh
对于13B模型训练,使用training_13b
。我们使用8个40GB内存的A100进行7B模型训练,使用4个80GB内存的A100进行13B训练。7B模型应该可以在1-2个A100上运行,尽管训练可能会很慢。
对于论文中进行的任务评估,请在这里下载数据。
每个文件都已经包含了检索到的文档,所以如果你不想在推理过程中运行检索器,你可以简单地在contexts
加载检索到的文档。
下面,我们描述Self-RAG和基线。
由于我们通常只为短文本生成任务检索一次,我们提供了一个易于运行的评估脚本,利用Contriever离线预先检索的文档。请参见下面的各个命令。
python run_short_form.py \
--model_name selfrag/selfrag_llama2_7b \
--input_file eval_data/popqa_longtail_w_gs.jsonl \
--mode 模式 --max_new_tokens 100 \
--threshold 0.2 \
--output_file 你的输出文件 \
--metric match --ndocs 10 --use_groundness --use_utility --use_seqscore \
--dtype half
mode
指定推理时的模式,可选择['adaptive_retrieval', 'no_retrieval', 'always_retrieve']
。
adaptive_retrieval
根据threshold
或Self-RAG预测进行检索no_retrieval
在推理时禁用检索always_retrieve
始终进行检索。对于13B,如果你在单个24GB内存的GPU上运行,可能会遇到内存不足的问题。你可以通过设置--world_size
在多个GPU上运行推理。
python run_short_form.py \
--model_name selfrag/selfrag_llama2_7b \
--input_file eval_data/arc_challenge_processed.jsonl \
--max_new_tokens 50 --threshold 0.2 \
--output_file 输出文件名 \
--metric match --ndocs 5 --use_groundness --use_utility --use_seqscore \
--task arc_c
python run_short_form.py \
--model_name selfrag/selfrag_llama2_7b \
--input_file eval_data/health_claims_processed.jsonl \
--max_new_tokens 50 \
--threshold 0.2 --output_file 输出文件名 \
--metric match --ndocs 5 \
--use_groundness --use_utility --use_seqscore \
--task fever
对于长文本问答,你可以使用检索模型运行评估,也可以使用预先给定的段落运行评估。 目前,我们正在努力减少运行时内存需求(DPR / Contriever与整个英语维基百科嵌入需要100 GB RAM),加快长文本生成的速度,并首先发布使用一小组初始检索文档(~20)的推理代码。
注意:我们当前的实现专门为目标任务数据集的评估而设计。我们计划更新我们的代码库,使接口更简单,更易于使用。当我们发布另一个版本时,我们会宣布。
对于ASQA,请运行以下命令,
python run_long_form_static.py \
--model_name selfrag/selfrag_llama2_7b \
--ndocs 5 --max_new_tokens 300 --threshold 0.2 \
--use_grounding --use_utility --use_seqscore \
--task asqa --input_file eval_data/asqa_eval_gtr_top100.json \
--output_file 你的输出文件名 --max_depth 7 --mode always_retrieve \
对于FactScore,
python run_long_form_static.py \
--model_name selfrag/selfrag_llama2_7b \
--ndocs 5 --max_new_tokens 300 --threshold 0.2 \
--use_grounding --use_utility --use_seqscore \
--task factscore --input_file eval_data/factscore_unlabeled_alpaca_13b_retrieval.jsonl \
--output_file 你的输出文件名 --max_depth 7 \
Self-RAG的推理有几个关键参数。
w_rel
(默认1.0):w_rel
控制在波束搜索过程中对isRel
(评判检索到的段落是否相关的评价标记)标记概率的强调。w_sup
(默认1.0):w_sup
控制在波束搜索过程中对isSup
(评判生成是否被文档支持的评价标记)标记概率的强调。w_use
(默认0.5):w_use
控制在波束搜索过程中对isUse
(整体质量的评价标记)标记概率的强调。threshold
(默认0.2):此阈值控制自适应检索的频率。max_depth
(默认6):这对应于论文中的T
,它定义了搜索的最大深度。beam_width
(默认2):这控制段级波束搜索中波束的大小。更多详细信息,请参阅我们论文中的详细说明(第3.3节)和分析(第5节)。
对于长文本评估,设置外部库或仓库以运行评估。
factscore==v0.1.5
(生物)
请按照FactScore官方仓库的说明设置你的环境。python -m factscore.factscorer --data_path 你的输出文件 --model_name retrieval+ChatGPT --cache_dir 你的缓存目录 --openai_key 你的OPEN_AI_密钥 --verbose
ALCE 为长篇问答提供了使用多种不同指标的全面评估。对于您的首次评估,请安装 ALCE 仓库并下载数据。
git clone https://github.com/princeton-nlp/ALCE.git
python3 -m alce_env
cd ALCE
bash download_data.sh
对于 ASQA,您可以按以下方式运行评估。请注意,ASQA 评估需要基于 T5-XXL (11B) 的 NLI 模块。
python eval.py --f YOUR_OUTPUT_FILE --citations --qa --mauve
重新运行基准测试的代码可在 run_baseline_lm.py 找到。 要运行检索增强基准测试,请确保下载包含检索段落的任务输入文件。
python run_baseline_lm.py \
--model_name meta-llama/Llama-2-7b-hf \
--input_file INPUT_FILE_SAME_AS_SELF_RAG \
--max_new_tokens 100 --metric match \
--result_fp RESULT_FILE_PATH --task qa --prompt_name "prompt_no_input"
例如,PubHealth
python run_baseline_lm.py \
--model_name meta-llama/Llama-2-7b-hf \
--input_file eval_data/health_claims_processed.jsonl \
--max_new_tokens 20 \
--metric accuracy \
--result_fp llama2_7b_pubhealth_results.json \
--task fever
注意:对于 PubHealth 和 ARC,请传入任务名称(ARC = arc_c
和 PubHealth = fever
)以正确设置指令。
对于 OpenAI API 模型,您还需要在这里设置组织密钥。您还需要有一个包含 OpenAI API 密钥的 txt 文件。
python run_baseline_lm.py \
--model_name gpt-3.5-turbo-0301 \
--input_file INPUT_FILE_SAME_AS_SELF_RAG \
--max_new_tokens 100 --metric match \
--result_fp RESULT_FILE_PATH \
--task qa \
--api_key YOUR_OPEN_AI_API_KEY_FILE \
--prompt_name "prompt_no_input"
python run_baseline_refactor.py \
--model_name meta-llama/Llama-2-7b-hf \
--input_file INPUT_FILE_SAME_AS_SELF_RAG \
--max_new_tokens 100 --metric match \
--result_fp RESULT_FILE_PATH --task qa \
--mode retrieval \
--prompt_name "prompt_no_input_retrieval"
python run_baseline_lm.py \
--model_name gpt-3.5-turbo-0301 \
--input_file INPUT_FILE_SAME_AS_SELF_RAG \
--max_new_tokens 100 --metric match \
--result_fp RESULT_FILE_PATH \
--task qa \
--api_key YOUR_OPEN_AI_API_KEY_FILE \
--mode retrieval \
--prompt_name "prompt_no_input_retrieval"
问题1:如何使用 Self-RAG 方案训练新的预训练语言模型? -- 如果您使用的是 Hugging Face transformers,您可以简单地在我们的训练脚本 script_finetune_7b.sh 中更改 model_name_or_path
和 tokenizer_name
。如果您想使用自己的微调脚本,请确保添加特殊标记并屏蔽段落上下文,如此问题中所讨论的。
问题2:你们计划发布基于 Mistral-7B 的 Self-RAG 吗? -- 目前我的时间有限,无法这样做,但社区已经训练了一个基于 Mistral-7B 的 Self-RAG 版本 SciPhi-Self-RAG-Mistral-7B-32k。如果我们能够在 Mistral-7B 上训练 Self-RAG 并发布检查点,我们会通知大家。
如果您有问题,请提出一个问题并提及 @AkariAsai,或发送电子邮件至 akari[at]cs.washington.edu。
深度推理能力全新升级,全面对标OpenAI o1
科大讯飞的星火大模型,支持语言理解、知识问答和文本创作等多功能,适用于多种文件和业务场景,提升办公和日常生活的效率。讯飞星火是一个提供丰富智能服务的平台,涵盖科技资讯、图像创作、写作辅助、编程解答、科研文献解读等功能,能为不同需求的用户提供便捷高效的帮助,助力用户轻松获取信息、解决问题,满足多样化使用场景。
一种基于大语言模型的高效单流解耦语音令牌文本到语音合成模型
Spark-TTS 是一个基于 PyTorch 的开源文本到语音合成项目,由多个知名机构联合参与。该项目提供了高效的 LLM(大语言模型)驱动的语音合成方案,支持语音克隆和语音创建功能,可通过命令行界面(CLI)和 Web UI 两种方式使用。用户可以根据需求调整语音的性别、音高、速度等参数,生成高质量的语音。该项目适用于多种场景,如有声读物制作、智能语音助手开发等。
字节跳动发布的AI编程神器IDE
Trae是一种自适应的集成开发环境(IDE),通过自动化和多元协作改变开发流程。利用Trae,团队能够更快速、精确地编写和部署代码,从而提高编程效率和项目交付速度。Trae具备上下文感知和代码自动完成功能,是提升开发效率的理想工具。
AI助力,做PPT更简单!
咔片是一款轻量化在线演示设计工具,借助 AI 技术,实现从内容生成到智能设计的 一站式 PPT 制作服务。支持多种文档格式导入生成 PPT,提供海量模板、智能美化、素材替换等功能,适用于销售、教师、学生等各类人群,能高效制作出高品质 PPT,满足不同场景演示需求。
选题、配图、成文,一站式创作,让内容运营更高效
讯飞绘文,一个AI集成平台,支持写作、选题、配图、排版和发布。高效生成适用于各类媒体的定制内容,加速品牌传播,提升内容营销效果。
专业的AI公文写作平台,公文写作神器
AI 材料星,专业的 AI 公文写作辅助平台,为体制内工作人员提供高效的公文写作解决方案。拥有海量公文文库、9 大核心 AI 功能,支持 30 + 文稿类型生成,助力快速完成领导讲话、工作总结、述职报告等材料,提升办公效率,是体制打工人的得力写作神器。
OpenAI Agents SDK,助力开发者便捷使用 OpenAI 相关功能。
openai-agents-python 是 OpenAI 推出的一款强大 Python SDK,它为开发者提供了与 OpenAI 模型交互的高效工具,支持工具调用、结果处理、追踪等功能,涵盖多种应用场景,如研究助手、财务研究等,能显著提升开发效率,让开发者更轻松地利用 OpenAI 的技术优势。
高分辨率纹理 3D 资产生成
Hunyuan3D-2 是腾讯开发的用于 3D 资产生成的强大工具,支持从文本描述、单张图片或多视角图片生成 3D 模型,具备快速形状生成能力,可生成带纹理的高质量 3D 模型,适用于多个领域,为 3D 创作提供了高效解决方案。
一个具备存储、管理和客户端操作等多种功能 的分布式文件系统相关项目。
3FS 是一个功能强大的分布式文件系统项目,涵盖了存储引擎、元数据管理、客户端工具等多个模块。它支持多种文件操作,如创建文件和目录、设置布局等,同时具备高效的事件循环、节点选择和协程池管理等特性。适用于需要大规模数据存储和管理的场景,能够提高系统的性能和可靠性,是分布式存储领域的优质解决方案。
用于可扩展和多功能 3D 生成的结构化 3D 潜在表示
TRELLIS 是一个专注于 3D 生成的项目,它利用结构化 3D 潜在表示技术,实现了可扩展且多功能的 3D 生成。项目提供了多种 3D 生成的方法和工具,包括文本到 3D、图像到 3D 等,并且支持多种输出格式,如 3D 高斯、辐射场和网格等。通过 TRELLIS,用户可 以根据文本描述或图像输入快速生成高质量的 3D 资产,适用于游戏开发、动画制作、虚拟现实等多个领域。
最新AI工具、AI资讯
独家AI资源、AI项目落地
微信扫一扫关注公众号