self-rag

self-rag

通过自反学习使语言模型实现按需检索、生成和评估的框架

Self-RAG是一种创新框架,通过自反学习使语言模型实现按需检索、生成和评估。该方法预测反思标记,支持多次检索或跳过检索,并从多角度评估生成内容。这不仅提高了模型输出的事实性和质量,还保持了语言模型的通用性能。

Self-RAG语言模型检索增强生成自我反思关键词生成Github开源项目

SELF-RAG: 通过自我反思学习检索、生成和评判

这包括原始实现的SELF-RAG: 通过自我反思学习检索、生成和评判(ICLR 2024,口头报告前1%),作者为Akari Asai、Zeqiu Wu、Yizhong Wang、Avirup Sil和Hannaneh Hajishirzi。

网站 | 7B模型 | 13B模型 | 论文 | 训练数据 | Twitter摘要 | 更新

Self-RAG(右图)是一个新的框架,用于训练任意语言模型学习检索、生成和评判,以提高生成内容的事实性和质量,同时不影响大型语言模型的多功能性。

与广泛采用的检索增强生成(RAG;左图)方法不同,Self-RAG根据需求进行检索(例如,可以多次检索或完全跳过检索),针对不同的查询,并通过预测反思标记作为生成的组成部分,从多个细粒度方面对自身生成进行评判。我们进行分段束搜索,以选择能最大化多样化偏好效用的输出。

如果您发现我们的代码、数据、模型或论文有用,请引用以下论文:

@inproceedings{
asai2024selfrag,
author={Asai, Akari and Wu, Zeqiu and Wang, Yizhong and Sil, Avirup and Hajishirzi, Hannaneh},
title={Self-{RAG}: Learning to Retrieve, Generate, and Critique through Self-Reflection},
booktitle={The Twelfth International Conference on Learning Representations},
year={2024},
url={https://openreview.net/forum?id=hSyW5go0v8}
}

更新

  • 2023年10月:首次发布代码、模型和论文。

内容

  1. 安装
  2. 快速开始
  3. 检索器设置
  4. 训练
  5. 推理
  6. 基线
  7. 常见问题
  8. 联系方式

安装

通过运行以下命令安装依赖的Python库。

pip install -r requirements.txt

请使用最新版本的vllm,因为旧版本可能无法通过SamplingParam设置skip_special_tokens,这是由(这个PR)添加的。

您也可以通过运行以下命令创建conda环境。

conda env create -f environment.yml

快速开始

您可以从HuggingFace Hub下载Self-RAG。对于推理,我们建议使用vllm,因为它可以显著加快推理速度。

from vllm import LLM, SamplingParams model = LLM("selfrag/selfrag_llama2_7b", download_dir="/gscratch/h2lab/akari/model_cache", dtype="half") sampling_params = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=100, skip_special_tokens=False) def format_prompt(input, paragraph=None): prompt = "### 指令:\n{0}\n\n### 回复:\n".format(input) if paragraph is not None: prompt += "[检索]<段落>{0}</段落>".format(paragraph) return prompt query_1 = "找出不同项:twitter、instagram、whatsapp。" query_2 = "你能告诉我美洲驼和羊驼的区别吗?" queries = [query_1, query_2] # 对于不需要检索的查询 preds = model.generate([format_prompt(query) for query in queries], sampling_params) for pred in preds: print("模型预测: {0}".format(pred.outputs[0].text))

输出:

模型预测: Twitter、Instagram和WhatsApp都是社交媒体平台。[无检索]WhatsApp是不同项,因为它是一个消息应用,而Twitter和Instagram主要用于分享照片和视频。[效用:5]</s> 模型预测: 好的![检索]<段落><段落>

如您所见,在第一个查询中,当不需要检索时,Self-RAG开始生成回答而不进行检索。另一方面,对于第二个查询,Self-RAG输出了[检索]标记,因为这个问题需要更细粒度的事实依据。

对于需要事实依据的查询,您可以插入一个段落。Self-RAG可以在生成过程中随时检索和插入段落,只要它们被上下文标记特殊标记<段落></段落>包围,就能识别它们。

# 对于需要事实依据的查询
prompt = format_prompt("你能告诉我美洲驼和羊驼的区别吗?", "羊驼(Lama pacos)是南美洲骆驼科哺乳动物的一种。它与美洲驼相似,常常被混淆。羊驼比美洲驼小得多,与美洲驼不同的是,它们不是被培育为工作动物,而是专门为了它们的纤维而被培育。")
preds = model.generate([prompt], sampling_params)
print([pred.outputs[0].text for pred in preds])
# ['[相关]羊驼比美洲驼小得多,与美洲驼不同的是,它们不是被培育为工作动物,而是专门为了它们的纤维而被培育。[完全支持][效用:5]</s>']

Self-RAG找到相关的插入文档,并生成完全由证据支持的答案。

使用在线检索模型进行评估

您也可以按需进行检索并与Self-RAG一起使用。由于在完整的英文维基百科上运行检索需要大量RAM和多个GPU,我们为演示目的创建了一个只包含维基百科文章介绍段落的子集。

首先,请下载语料库和嵌入(共9GB)。

git clone git@github.com:AkariAsai/self-rag.git
cd retrieval_lm
bash download_demo_corpus.sh

如果脚本不起作用,您可以从Google DriveHF数据集下载数据。 然后,您可以在retrieval_lm下运行脚本。我们在1个RTX 6000(24GB)和100G RAM上测试了该脚本(但应该可以在更小的RAM上运行)。

from passage_retrieval import Retriever retriever = Retriever({}) retriever.setup_retriever_demo("facebook/contriever-msmarco", "enwiki_2020_intro_only/enwiki_2020_dec_intro_only.jsonl", "enwiki_2020_intro_only/enwiki_dec_2020_contriever_intro/*", n_docs=5, save_or_load_index=False) retrieved_documents = retriever.search_document_demo(query_3, 5) prompts = [format_prompt(query_3, doc["title"] +"\n"+ doc["text"]) for doc in retrieved_documents] preds = model.generate(prompts, sampling_params) top_doc = retriever.search_document_demo(query_3, 1)[0] print("参考: {0}\n模型预测: {1}".format(top_doc["title"] + "\n" + top_doc["text"], preds[0].outputs[0].text))

输出:

参考: 过拟合 在统计学中,过拟合是"产生一个与特定数据集过于紧密或完全对应的分析,因此可能无法可靠地适应额外数据或预测未来观察结果"。过拟合模型是一个包含比数据可以证明更多参数的统计模型。过拟合的本质是无意中将一些残差变异(即噪声)提取出来,就好像这种变异代表了潜在的模型结构。欠拟合发生在统计模型无法充分捕捉数据的潜在结构时。欠拟合模型是一个模型,其中一些在正确指定的模型中会出现的参数或项缺失 模型预测: [相关]过拟合发生在模型相对于其训练数据量而言具有太多参数时,导致它过度记忆训练数据,并在新的、未见过的数据上表现不佳。[完全支持][效用:5]</s>

检索系统正确检索了必要的文档并生成了完全有根据的输出。

请注意,此演示使用较小的语料库和具有完整推理算法的Self-RAG。对于完整评估,您需要设置检索器或下载我们的检索结果。请按照推理中的说明进行操作。

检索器设置

默认情况下,我们使用Contriever作为我们的检索组件。

下载数据

下载DPR中使用的预处理段落数据。

cd retrieval_lm
wget https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz

然后,下载生成的段落。我们使用Contriever-MSMARCO

wget https://dl.fbaipublicfiles.com/contriever/embeddings/contriever-msmarco/wikipedia_embeddings.tar

运行检索器

您可以通过运行以下命令来执行段落检索。

cd retrieval_lm
python passage_retrieval.py \
    --model_name_or_path facebook/contriever-msmarco --passages psgs_w100.tsv \
    --passages_embeddings "wikipedia_embeddings/*" \
    --data 你的输入文件  \
    --output_dir 你的输出文件 \
    --n_docs 20

你的输入文件应该是jsonjsonl格式。每个实例必须包含questioninstruction,这将在检索过程中用作查询。

为你自己的数据生成嵌入

你可以通过运行以下命令为你自己的数据生成嵌入。(该脚本改编自Contriever仓库。)请注意,为大规模语料库(>1000万文档)生成嵌入可能需要一些时间,我们建议在多个GPU上运行。

cd retrieval_lm
for i in {0..3}; do
  export CUDA_VISIBLE_DEVICES=${i}
  python generate_passage_embeddings.py  --model_name_or_path facebook/contriever-msmarco \
  --output_dir 你的输出目录 \
  --passages 你的段落数据 --shard_id ${i}  --num_shards 4 > ./log/nohup.my_embeddings.${i} 2>&1 &

训练

Self-RAG训练两个模型,CriticGenerator,它们都扩展了反思标记的词汇表,并使用标准的下一个标记预测目标进行训练。

或者,你可以下载我们由15万个实例组成的训练数据这里

收集反思标记

我们从GPT-4收集训练数据。用于为每种特殊标记类型调用GPT-4的脚本可在data_creation/critic获得。

或者,你可以在这里下载我们的训练数据。

Critic训练

一旦你创建或下载了训练数据,运行下面的命令对Llama2-7B进行critic训练的微调。

cd data_creation
torchrun --nproc_per_node=2 \
  --master_port=2568 train_special_tokens.py \
  --model_name_or_path meta-llama/Llama-2-7b-hf \
  --data_path 训练数据文件路径 \
  --bf16  True \
  --output_dir CRITIC模型路径 \
  --num_train_epochs 3  \
  --per_device_train_batch_size 1 --per_device_eval_batch_size 1 \
  --gradient_accumulation_steps 8 \
  --evaluation_strategy "no" \
  --save_strategy "steps" \
  --save_steps 300 \
  --save_total_limit 1 \
  --learning_rate 2e-5 \
  --weight_decay 0. \
  --warmup_ratio 0.01 \
  --lr_scheduler_type "cosine" \
  --logging_steps 10 \
  --fsdp "full_shard auto_wrap"

Generator数据创建

创建Generator训练数据的代码在generator_data_creation下。请参阅README.md中的说明。

或者,你可以在HuggingFace数据集这里下载我们的训练数据

Generator训练

对于generator训练,我们使用DeepSpeed来提高训练效率。你可以通过运行下面的脚本来进行训练,设置好训练数据路径后。

cd retrieval_lm
bash script_finetune_7b.sh

对于13B模型训练,使用training_13b。我们使用8个40GB内存的A100进行7B模型训练,使用4个80GB内存的A100进行13B训练。7B模型应该可以在1-2个A100上运行,尽管训练可能会很慢。

推理

对于论文中进行的任务评估,请在这里下载数据。

每个文件都已经包含了检索到的文档,所以如果你不想在推理过程中运行检索器,你可以简单地在contexts加载检索到的文档。

下面,我们描述Self-RAG和基线。

短文本生成(PubHealth, ARC-Challenge, TriviaQA, PopQA)

由于我们通常只为短文本生成任务检索一次,我们提供了一个易于运行的评估脚本,利用Contriever离线预先检索的文档。请参见下面的各个命令。

问答

python run_short_form.py \
--model_name selfrag/selfrag_llama2_7b \
--input_file eval_data/popqa_longtail_w_gs.jsonl \
--mode 模式 --max_new_tokens 100 \
--threshold 0.2 \
--output_file 你的输出文件 \
--metric match --ndocs 10 --use_groundness --use_utility --use_seqscore \
--dtype half

mode指定推理时的模式,可选择['adaptive_retrieval', 'no_retrieval', 'always_retrieve']

  • adaptive_retrieval根据threshold或Self-RAG预测进行检索
  • no_retrieval在推理时禁用检索
  • always_retrieve始终进行检索。

对于13B,如果你在单个24GB内存的GPU上运行,可能会遇到内存不足的问题。你可以通过设置--world_size在多个GPU上运行推理。

ARC Challenge

python run_short_form.py \
  --model_name selfrag/selfrag_llama2_7b \
  --input_file eval_data/arc_challenge_processed.jsonl \
  --max_new_tokens 50 --threshold 0.2 \
  --output_file 输出文件名 \
  --metric match --ndocs 5 --use_groundness --use_utility --use_seqscore \
  --task arc_c

PubHealth

python run_short_form.py \
  --model_name selfrag/selfrag_llama2_7b \
  --input_file eval_data/health_claims_processed.jsonl \
  --max_new_tokens 50 \
  --threshold 0.2 --output_file 输出文件名 \
  --metric match --ndocs 5 \
  --use_groundness --use_utility --use_seqscore \
  --task fever

长文本生成(ASQA, FactScore)

对于长文本问答,你可以使用检索模型运行评估,也可以使用预先给定的段落运行评估。 目前,我们正在努力减少运行时内存需求(DPR / Contriever与整个英语维基百科嵌入需要100 GB RAM),加快长文本生成的速度,并首先发布使用一小组初始检索文档(~20)的推理代码。

注意:我们当前的实现专门为目标任务数据集的评估而设计。我们计划更新我们的代码库,使接口更简单,更易于使用。当我们发布另一个版本时,我们会宣布。

使用预先检索的段落运行推理

对于ASQA,请运行以下命令,

python run_long_form_static.py \
  --model_name selfrag/selfrag_llama2_7b \
  --ndocs 5 --max_new_tokens 300 --threshold 0.2 \
  --use_grounding --use_utility --use_seqscore \
  --task asqa --input_file eval_data/asqa_eval_gtr_top100.json \
  --output_file 你的输出文件名 --max_depth 7 --mode always_retrieve \

对于FactScore,

python run_long_form_static.py \
  --model_name selfrag/selfrag_llama2_7b \
  --ndocs 5 --max_new_tokens 300 --threshold 0.2 \
  --use_grounding --use_utility --use_seqscore \
  --task factscore --input_file eval_data/factscore_unlabeled_alpaca_13b_retrieval.jsonl \
  --output_file 你的输出文件名 --max_depth 7 \
长文本生成的关键参数

Self-RAG的推理有几个关键参数。

  • w_rel(默认1.0):w_rel控制在波束搜索过程中对isRel(评判检索到的段落是否相关的评价标记)标记概率的强调。
  • w_sup(默认1.0):w_sup控制在波束搜索过程中对isSup(评判生成是否被文档支持的评价标记)标记概率的强调。
  • w_use(默认0.5):w_use控制在波束搜索过程中对isUse(整体质量的评价标记)标记概率的强调。
  • threshold(默认0.2):此阈值控制自适应检索的频率。
  • max_depth(默认6):这对应于论文中的T,它定义了搜索的最大深度。
  • beam_width(默认2):这控制段级波束搜索中波束的大小。

更多详细信息,请参阅我们论文中的详细说明(第3.3节)和分析(第5节)。

运行评估

对于长文本评估,设置外部库或仓库以运行评估。

  • factscore==v0.1.5(生物) 请按照FactScore官方仓库的说明设置你的环境。
python -m factscore.factscorer --data_path 你的输出文件  --model_name retrieval+ChatGPT --cache_dir 你的缓存目录 --openai_key 你的OPEN_AI_密钥 --verbose

ALCE 为长篇问答提供了使用多种不同指标的全面评估。对于您的首次评估,请安装 ALCE 仓库并下载数据。

git clone https://github.com/princeton-nlp/ALCE.git
python3 -m alce_env
cd ALCE
bash download_data.sh

对于 ASQA,您可以按以下方式运行评估。请注意,ASQA 评估需要基于 T5-XXL (11B) 的 NLI 模块。

python eval.py --f YOUR_OUTPUT_FILE --citations --qa --mauve

基准测试

重新运行基准测试的代码可在 run_baseline_lm.py 找到。 要运行检索增强基准测试,请确保下载包含检索段落的任务输入文件。

普通语言模型基准测试

  • Huggingface 模型
python run_baseline_lm.py \
--model_name meta-llama/Llama-2-7b-hf \
--input_file INPUT_FILE_SAME_AS_SELF_RAG \
 --max_new_tokens 100 --metric match \
--result_fp RESULT_FILE_PATH --task qa --prompt_name "prompt_no_input"

例如,PubHealth

python run_baseline_lm.py \
--model_name meta-llama/Llama-2-7b-hf \
--input_file eval_data/health_claims_processed.jsonl \
--max_new_tokens 20 \
--metric accuracy \
--result_fp llama2_7b_pubhealth_results.json \
--task fever

注意:对于 PubHealth 和 ARC,请传入任务名称(ARC = arc_c 和 PubHealth = fever)以正确设置指令。

  • OpenAI API

对于 OpenAI API 模型,您还需要在这里设置组织密钥。您还需要有一个包含 OpenAI API 密钥的 txt 文件。

python run_baseline_lm.py \
--model_name gpt-3.5-turbo-0301 \
--input_file INPUT_FILE_SAME_AS_SELF_RAG \
--max_new_tokens 100 --metric match \
--result_fp RESULT_FILE_PATH \
 --task qa \
--api_key YOUR_OPEN_AI_API_KEY_FILE \
--prompt_name "prompt_no_input"

检索增强基准测试

  • Huggingface 模型
python run_baseline_refactor.py \
--model_name meta-llama/Llama-2-7b-hf \
--input_file INPUT_FILE_SAME_AS_SELF_RAG \
 --max_new_tokens 100 --metric match \
--result_fp RESULT_FILE_PATH --task qa \
--mode retrieval \
--prompt_name "prompt_no_input_retrieval"
  • OpenAI API
python run_baseline_lm.py \
--model_name gpt-3.5-turbo-0301 \
--input_file INPUT_FILE_SAME_AS_SELF_RAG \
--max_new_tokens 100 --metric match \
--result_fp RESULT_FILE_PATH \
 --task qa \
--api_key YOUR_OPEN_AI_API_KEY_FILE \
--mode retrieval \
--prompt_name "prompt_no_input_retrieval"

常见问题

问题1:如何使用 Self-RAG 方案训练新的预训练语言模型? -- 如果您使用的是 Hugging Face transformers,您可以简单地在我们的训练脚本 script_finetune_7b.sh 中更改 model_name_or_pathtokenizer_name。如果您想使用自己的微调脚本,请确保添加特殊标记并屏蔽段落上下文,如此问题中所讨论的。

问题2:你们计划发布基于 Mistral-7B 的 Self-RAG 吗? -- 目前我的时间有限,无法这样做,但社区已经训练了一个基于 Mistral-7B 的 Self-RAG 版本 SciPhi-Self-RAG-Mistral-7B-32k。如果我们能够在 Mistral-7B 上训练 Self-RAG 并发布检查点,我们会通知大家。

联系方式

如果您有问题,请提出一个问题并提及 @AkariAsai,或发送电子邮件至 akari[at]cs.washington.edu。

编辑推荐精选

讯飞星火

讯飞星火

深度推理能力全新升级,全面对标OpenAI o1

科大讯飞的星火大模型,支持语言理解、知识问答和文本创作等多功能,适用于多种文件和业务场景,提升办公和日常生活的效率。讯飞星火是一个提供丰富智能服务的平台,涵盖科技资讯、图像创作、写作辅助、编程解答、科研文献解读等功能,能为不同需求的用户提供便捷高效的帮助,助力用户轻松获取信息、解决问题,满足多样化使用场景。

模型训练热门AI工具内容创作智能问答AI开发讯飞星火大模型多语种支持智慧生活
Spark-TTS

Spark-TTS

一种基于大语言模型的高效单流解耦语音令牌文本到语音合成模型

Spark-TTS 是一个基于 PyTorch 的开源文本到语音合成项目,由多个知名机构联合参与。该项目提供了高效的 LLM(大语言模型)驱动的语音合成方案,支持语音克隆和语音创建功能,可通过命令行界面(CLI)和 Web UI 两种方式使用。用户可以根据需求调整语音的性别、音高、速度等参数,生成高质量的语音。该项目适用于多种场景,如有声读物制作、智能语音助手开发等。

Trae

Trae

字节跳动发布的AI编程神器IDE

Trae是一种自适应的集成开发环境(IDE),通过自动化和多元协作改变开发流程。利用Trae,团队能够更快速、精确地编写和部署代码,从而提高编程效率和项目交付速度。Trae具备上下文感知和代码自动完成功能,是提升开发效率的理想工具。

热门AI工具生产力协作转型TraeAI IDE
咔片PPT

咔片PPT

AI助力,做PPT更简单!

咔片是一款轻量化在线演示设计工具,借助 AI 技术,实现从内容生成到智能设计的一站式 PPT 制作服务。支持多种文档格式导入生成 PPT,提供海量模板、智能美化、素材替换等功能,适用于销售、教师、学生等各类人群,能高效制作出高品质 PPT,满足不同场景演示需求。

讯飞绘文

讯飞绘文

选题、配图、成文,一站式创作,让内容运营更高效

讯飞绘文,一个AI集成平台,支持写作、选题、配图、排版和发布。高效生成适用于各类媒体的定制内容,加速品牌传播,提升内容营销效果。

AI助手热门AI工具AI创作AI辅助写作讯飞绘文内容运营个性化文章多平台分发
材料星

材料星

专业的AI公文写作平台,公文写作神器

AI 材料星,专业的 AI 公文写作辅助平台,为体制内工作人员提供高效的公文写作解决方案。拥有海量公文文库、9 大核心 AI 功能,支持 30 + 文稿类型生成,助力快速完成领导讲话、工作总结、述职报告等材料,提升办公效率,是体制打工人的得力写作神器。

openai-agents-python

openai-agents-python

OpenAI Agents SDK,助力开发者便捷使用 OpenAI 相关功能。

openai-agents-python 是 OpenAI 推出的一款强大 Python SDK,它为开发者提供了与 OpenAI 模型交互的高效工具,支持工具调用、结果处理、追踪等功能,涵盖多种应用场景,如研究助手、财务研究等,能显著提升开发效率,让开发者更轻松地利用 OpenAI 的技术优势。

Hunyuan3D-2

Hunyuan3D-2

高分辨率纹理 3D 资产生成

Hunyuan3D-2 是腾讯开发的用于 3D 资产生成的强大工具,支持从文本描述、单张图片或多视角图片生成 3D 模型,具备快速形状生成能力,可生成带纹理的高质量 3D 模型,适用于多个领域,为 3D 创作提供了高效解决方案。

3FS

3FS

一个具备存储、管理和客户端操作等多种功能的分布式文件系统相关项目。

3FS 是一个功能强大的分布式文件系统项目,涵盖了存储引擎、元数据管理、客户端工具等多个模块。它支持多种文件操作,如创建文件和目录、设置布局等,同时具备高效的事件循环、节点选择和协程池管理等特性。适用于需要大规模数据存储和管理的场景,能够提高系统的性能和可靠性,是分布式存储领域的优质解决方案。

TRELLIS

TRELLIS

用于可扩展和多功能 3D 生成的结构化 3D 潜在表示

TRELLIS 是一个专注于 3D 生成的项目,它利用结构化 3D 潜在表示技术,实现了可扩展且多功能的 3D 生成。项目提供了多种 3D 生成的方法和工具,包括文本到 3D、图像到 3D 等,并且支持多种输出格式,如 3D 高斯、辐射场和网格等。通过 TRELLIS,用户可以根据文本描述或图像输入快速生成高质量的 3D 资产,适用于游戏开发、动画制作、虚拟现实等多个领域。

下拉加载更多