infini-transformer

infini-transformer

针对无限长度上下文设计的高效Transformer模型

Infini-Transformer是一种创新的Transformer模型,专门用于处理无限长度的上下文。该模型采用压缩性记忆机制和混合深度技术,能有效处理超长序列。Infini-Transformer支持文本分类、问答和语言生成等多种任务,并集成RoPE和YaRN等先进位置编码技术。这一模型为长文本处理和大规模语言任务提供了高效解决方案。

Infini-Transformer自然语言处理长序列处理注意力机制位置编码Github开源项目

Infini-Transformer

概述

Infini-Transformer(https://arxiv.org/abs/2404.07143)是一个强大而多功能的transformer模型,专为各种自然语言处理任务而设计。它利用最先进的技术和架构,实现了卓越的性能,并可扩展到无限的上下文长度。

特性

  • 可扩展的架构,能够处理长序列
  • 在多样化数据集上进行大规模预训练
  • 支持多种下游任务,包括文本分类、问答和语言生成
  • 高效的微调,适应特定任务
  • 包含一个融合了Infini-Attention的Mixture-of-Depths(https://arxiv.org/abs/2404.02258)transformer层
  • 实现了符合Infini-Attention和Mixture-of-Depth内存高效设计的RoPE(https://arxiv.org/abs/2104.09864
  • 实现了符合Infini-Attention和Mixture-of-Depth内存高效设计的YaRN(https://arxiv.org/abs/2309.00071

目录结构

infini-transformer/ ├── infini_transformer/ │ ├── __init__.py │ ├── transformer.py │ ├── compressive_memory.py │ ├── positional_embedder.py │ └── activations.py ├── examples/ │ ├── __init__.py │ └── modinfiniformer.py ├── tests/ │ ├── __init__.py │ └── test_transformer.py ├── LICENSE ├── README.md ├── requirements.txt ├── MANIFEST.in └── pyproject.toml

入门指南

要开始使用Infini-Transformer,你可以克隆仓库并从源代码安装:

git clone https://github.com/dingo-actual/infini-transformer.git cd infini-transformer pip install -e .

使用方法

CompressiveMemory

CompressiveMemory模块是Infini-Transformer架构的关键组件。它旨在通过压缩并将输入标记存储在内存矩阵和归一化向量中,有效地处理长序列。这使得模型能够在保持有限内存使用的同时维持大型上下文窗口。

它通过沿序列维度(假定为维度1)划分输入张量,执行多头自注意力的变体,并带有循环更新步骤。它首先对输入进行学习的线性投影,生成键、查询和值张量,然后从中提取每个循环步骤的片段。

在每个循环步骤中,它计算线性注意力(使用内存和归一化矩阵)和SDP注意力的学习线性组合。然后使用当前步骤的键和值矩阵,以及当前内存矩阵和归一化向量来更新内存矩阵和归一化向量。在输出之前,组合的注意力张量在所有头上堆叠,然后投影回输入维度。

每个循环步骤的输出沿序列维度(维度1)连接,生成最终的输出张量。

内存矩阵的更新有两种变体:线性和增量。

线性更新规则为: $$M_t = M_{t-1} + \bigl(\textrm{ELU}(K_{t-1}\bigr) + 1)^TV_{t-1}$$

增量更新规则为: $$M_t = M_{t-1} + \bigl(\textrm{ELU}(K_{t-1}) + 1\bigr)^T \biggl( V_{t-1} - \frac{(\textrm{ELU}(K_{t-1}) + 1)M_{t-1}}{(\textrm{ELU}(K_{t-1}) + 1)z_{t-1}}\biggr)$$

其中$M_i$是步骤$i$的内存矩阵,$z_i$是步骤$i$的归一化向量。$K$和$V$矩阵的下标表示它们对应的循环步骤。

计算尽可能沿嵌入维度堆叠,以高效利用多头注意力。

CompressiveMemory模块接受以下参数:

  • dim_input:张量的输入维度。
  • dim_key:键张量和查询张量的维度。
  • dim_value:值张量的维度。
  • num_heads:注意力头的数量。
  • segment_len:递归注意力计算中每个段的长度。
  • sampling_factor:如果使用混合深度(Mixture-of-Depths)则使用的采样因子(如果不使用混合深度则为None)。(默认为None。)
  • update:用于内存矩阵更新的类型。可以是"linear"或"delta"。(默认为"linear"。)
  • causal:是否在SDP计算中使用因果注意力(每个位置只能关注之前的位置)。(默认为False。)
  • positional_embedder:可选的PositionEmbeddings对象:RoPEEmbeddingsYaRNEmbeddings(默认为None。)
  • init_state_learnable:初始内存状态和归一化向量是否为可学习参数。(默认为False。)

CompressiveMemory模块的示例用法如下:

import torch from infini_transformer.compressive_memory import CompressiveMemory cm = CompressiveMemory( dim_input=768, dim_key=64, dim_value=64, num_heads=8, segment_len=2048, sampling_factor=None, update="linear", causal=True, positional_embedder="rope", init_state_learnable=False ) batch = torch.randn( 2, # 批量大小 65536, # 序列长度 768 # 输入维度 ) output = cm(batch)

在训练过程中,不需要对输出进行特殊处理。

InfiniTransformer

InfiniTransformer类实现了原始transformer的一个变体,它使用CompressiveMemory代替标准的自注意力机制。这使得模型能够通过压缩和存储输入标记到内存矩阵和归一化向量中来高效处理长序列。它利用CompressiveMemory模块执行多头自注意力的变体,并包含一个递归更新步骤。

InfiniTransformer与普通transformer的主要区别在于用CompressiveMemory替换了标准的多头自注意力机制。

InfiniTransformer模块接受以下参数:

  • dim_input:张量的输入维度。

  • dim_hidden:多头自注意力后应用的MLP的隐藏维度。

  • dim_key:键张量和查询张量的维度。

  • dim_value:值张量的维度。

  • num_heads:注意力头的数量。

  • activation:在MLP中应用的非线性激活函数。支持以下激活函数:

    • "relu":ReLU激活
    • "abs":绝对值激活
    • "gelu":高斯误差线性单元(GELU)激活
    • "swish":Swish激活
    • "swiglu":SwiGLU激活
    • "geglu":门控高斯误差线性单元(GeGELU)激活
    • "ffnglu":带门控线性单元的前馈网络(FFNGLU)激活
    • "ffngeglu":带门控高斯误差线性单元的前馈网络(FFNGeGLU)激活
    • "ffnswiglu":带Swish门控线性单元的前馈网络(FFNSwiGLU)激活
  • segment_len:递归注意力计算中每个段的长度。

  • update:用于内存矩阵更新的类型。可以是"linear"或"delta"。(默认为"linear"。)

  • causal:是否在SDP计算中使用因果注意力(每个位置只能关注之前的位置)。(默认为False。)

  • positional_embedder:可选的PositionEmbeddings对象:RoPEEmbeddingsYaRNEmbeddings(默认为None。)

  • init_state_learnable:初始内存状态和归一化向量是否为可学习参数。(默认为False。)

  • dropout:在MLP中应用的dropout率。(默认为0.0。)

InfiniTransformer模块的示例用法如下:

import torch from infini_transformer import InfiniTransformer tfm = InfiniTransformer( dim_input=768, dim_hidden=2048, dim_key=64, dim_value=64, num_heads=8, activation="ffngeglu", segment_len=2048, update="delta", causal=True, positional_embedder=None, init_state_learnable=False, dropout=0.1 ) batch = torch.randn( 2, # 批量大小 65536, # 序列长度 768 # 输入维度 ) output = tfm(batch)

在训练过程中,不需要对输出进行特殊处理。

MoDInfiniTransformer

MoDInfiniTransformer模块扩展了InfiniTransformer模块,引入了混合深度(Mixture-of-Depths)(Raposo等人;https://arxiv.org/abs/2404.02258)。MoDInfiniTransformer块将其输入进行学习的线性投影到单一维度,并使用具有最高前k个值的标记执行InfiniTransformer的操作,将所有剩余标记添加到残差连接中。这使得模型能够将其容量集中在输入序列中最重要的部分,进一步减少了整体计算和内存需求,比单独使用InfiniTransformer更加高效。

前k个选择通常会导致递归循环中的段具有不同的长度。我们通过在所有段中均匀分配选择来避免这种情况。 由于top-k选择的非因果性质,在推理时,投影到1维时产生的分数被视为独立二元分类器的logits。因此,我们在训练模型时为每个ModInfiniFormer层添加了一个额外的损失项,即logits与训练期间选择的top-k词元之间的二元交叉熵损失。

因此,ModInfiniTransformer的输出是由三个张量组成的元组:

  • 常规输出张量,其维度与输入张量相匹配
  • 形状为(batch_size * sequence_length, 1)的张量,表示训练期间选择的top-k词元的二元掩码。这将作为额外二元交叉熵损失的目标。
  • 形状为(batch_size * sequence_length, 1)的张量,包含与上述二元掩码对应的logits。这表示用于选择top-k词元的分数,被视为额外二元交叉熵损失的预测。

在推理时,可以安全地忽略元组的第二和第三个元素,因为所有词元选择逻辑都在MoDInfiniTransformer模块内部处理。

重要提示:基于二元分类器的词元选择机制在推理时无法保证为批次中的每个元素选择相同数量的词元。如果不加以控制,这将导致一个不规则数组,目前PyTorch不支持这种情况。当前的解决方案是将批次大小强制设为1,并在单个观察值上连接前向传播。我们意识到这并不是最优解,希望在不久的将来能够解决这个问题。

MoDInfiniTransformer模块接受以下参数:

  • dim_input:张量的输入维度。

  • dim_hidden:多头自注意力后应用的MLP的隐藏维度。

  • dim_key:键张量和查询张量的维度。

  • dim_value:值张量的维度。

  • num_heads:注意力头的数量。

  • activation:在MLP中应用的非线性激活函数。支持以下激活函数:

    • "relu":ReLU激活
    • "abs":绝对值激活
    • "gelu":高斯误差线性单元(GELU)激活
    • "swish":Swish激活
    • "swiglu":SwiGLU激活
    • "geglu":门控高斯误差线性单元(GeGELU)激活
    • "ffnglu":带门控线性单元的前馈网络(FFNGLU)激活
    • "ffngeglu":带门控高斯误差线性单元的前馈网络(FFNGeGLU)激活
    • "ffnswiglu":带Swish门控线性单元的前馈网络(FFNSwiGLU)激活
  • segment_len:循环注意力计算中每个段的长度。

  • sampling_factor:区间(1, segment_len)内的数值,决定在top-k选择期间从每个段中选择的词元数量。sampling_factor值越大,选择的词元越少。

  • update:用于更新记忆矩阵的方式。可以是"linear"或"delta"。(默认为"linear"。)

  • causal:在SDP计算中是否使用因果注意力(每个位置只能关注之前的位置)。(默认为False。)

  • positional_embedder:可选的PositionEmbeddings对象:RoPEEmbeddingsYaRNEmbeddings(默认为None。)

  • init_state_learnable:初始记忆状态和标准化向量是否为可学习参数。(默认为False。)

  • dropout:在MLP中应用的dropout率。(默认为0.0。)

InfiniTransformer模块的使用示例如下:

import torch from infini_transformer import MoDInfiniTransformer tfm = MoDInfiniTransformer( dim_input=768, dim_hidden=2048, dim_key=64, dim_value=64, num_heads=8, activation="ffngeglu", segment_len=2048, sampling_factor=8, update="delta", causal=True, init_state_learnable=False, positional_embedder=None, dropout=0.1 ) batch = torch.randn( 2, # 批次大小 65536, # 序列长度 768 # 输入维度 ) output, select_target, select_pred = tfm(batch)

在训练过程中,我们必须考虑MoDInfiniFormer的额外输出,以便将它们用于二元交叉熵损失。请参阅infini_transformer/example/modinfiniformer.py,了解如何将额外输出整合到整体模型输出和训练循环中的示例。

RoPEEmbeddings

RoPEEmbeddings模块应用了Su等人的论文"RoFormer: Enhanced Transformer with Rotary Position Embedding"(https://arxiv.org/abs/2104.09864)中的RoPE。一旦实例化,它可以作为positional_embedder参数传递给InfiniTransformerMoDInfiniTransformer模块,然后传递给CompressiveMemory,在那里将位置感知嵌入应用于键和查询张量。

RoPEEmbeddings模块接受以下参数:

  • dim: 键/值张量的维度。
  • seq_len: CompressiveMemory输入序列的最大长度(必须与CompressiveMemorysegment_len参数匹配)。
  • dim_embeddings_pct: 用于位置感知嵌入的键/值张量维度比例。例如,如果dim为64,dim_embeddings_pct为0.5,则将使用32个维度用于位置感知嵌入。(默认为0.5)
  • base: 用于位置嵌入角度的基值。(默认为10000)

RoPEEmbeddings模块的使用示例如下:

import torch from infini_transformer import InfiniTransformer from infini_transformer import RoPEEmbeddings embedder = RoPEEmbeddings( dim=64, # 必须与InfiniTransformer中的dim_key参数匹配 seq_len=2048, # 必须与InfiniTransformer中的segment_len参数匹配 dim_embeddings_pct=0.5, base=10000 ) tfm = InfiniTransformer( dim_input=768, dim_hidden=2048, dim_key=64, # 必须与RoPEEmbeddings中的dim参数匹配 dim_value=64, num_heads=8, activation="ffngeglu", segment_len=2048, # 必须与RoPEEmbeddings中的seq_len参数匹配 update="delta", causal=True, positional_embedder=embedder, init_state_learnable=False, dropout=0.1 ) batch = torch.randn( 2, # 批次大小 65536, # 序列长度 768 # 输入维度 ) output = tfm(batch)

YaRNEmbeddings

YaRNEmbeddings模块应用了Peng等人的论文"YaRN: Efficient Context Window Extension of Large Language Models"中的YaRN技术(https://arxiv.org/abs/2309.00071)。实例化后,它可以作为positional_embedder参数传递给InfiniTransformerMoDInfiniTransformer模块,然后传递给CompressiveMemory,在那里将位置感知嵌入应用于键和查询张量。

YaRNEmbeddings模块接受以下参数:

  • dim: 键/值张量的维度。
  • seq_len: CompressiveMemory输入序列的最大长度(必须与CompressiveMemorysegment_len参数匹配)。
  • context_len: 训练期间使用的上下文长度。
  • context_len_ext: 要扩展到的上下文长度。
  • dim_embeddings_pct: 用于位置感知嵌入的键/值张量维度比例。例如,如果dim为64,dim_embeddings_pct为0.5,则将使用32个维度用于位置感知嵌入。(默认为0.5)
  • base: 用于位置嵌入角度的基值。(默认为10000)
  • alpha: 动态缩放的插值最小值。(默认为1)
  • beta: 动态缩放的插值最小值。(默认为32)
  • len_scale: 注意力计算的长度缩放。默认为None(自动计算)。

YaRNEmbeddings模块的使用示例如下:

import torch from infini_transformer import InfiniTransformer from infini_transformer import YaRNEmbeddings embedder = YaRNEmbeddings( dim=64, # 必须与InfiniTransformer中的dim_key匹配 seq_len=2048, # 必须与InfiniTransformer中的segment_len参数匹配 context_len=32768, context_len_ext=65536, dim_embeddings_pct=0.5, base=10000, alpha=1, beta=32, len_scale=None ) tfm = InfiniTransformer( dim_input=768, dim_hidden=2048, dim_key=64, # 必须与YaRNEmbeddings中的dim匹配 dim_value=64, num_heads=8, activation="ffngeglu", segment_len=2048, # 必须与YaRNEmbeddings中的seq_len参数匹配 update="delta", causal=True, positional_embedder=embedder, init_state_learnable=False, dropout=0.1 ) batch = torch.randn( 2, # 批次大小 65536, # 序列长度 768 # 输入维度 ) output = tfm(batch)

使用示例

请参阅infini_transformer/example/modinfiniformer.py,了解使用MoDInfiniTransformer模块的模型和训练流程示例。

更多示例将陆续推出。

许可证

本项目采用MIT许可证

致谢

我们要感谢那些启发和促进Infini-Transformer和Mixture-of-Depths Transformer开发的研究人员和开发者。

同时,我们要特别感谢所有贡献者、合作者以及提供反馈的人。你们的努力使一个粗略的实现框架变成了真正可用的东西。

如果您有任何问题或需要进一步的帮助,请随时联系我,邮箱是ryan@beta-reduce.net

编辑推荐精选

蛙蛙写作

蛙蛙写作

AI小说写作助手,一站式润色、改写、扩写

蛙蛙写作—国内先进的AI写作平台,涵盖小说、学术、社交媒体等多场景。提供续写、改写、润色等功能,助力创作者高效优化写作流程。界面简洁,功能全面,适合各类写作者提升内容品质和工作效率。

AI辅助写作AI工具蛙蛙写作AI写作工具学术助手办公助手营销助手AI助手
Trae

Trae

字节跳动发布的AI编程神器IDE

Trae是一种自适应的集成开发环境(IDE),通过自动化和多元协作改变开发流程。利用Trae,团队能够更快速、精确地编写和部署代码,从而提高编程效率和项目交付速度。Trae具备上下文感知和代码自动完成功能,是提升开发效率的理想工具。

AI工具TraeAI IDE协作生产力转型热门
问小白

问小白

全能AI智能助手,随时解答生活与工作的多样问题

问小白,由元石科技研发的AI智能助手,快速准确地解答各种生活和工作问题,包括但不限于搜索、规划和社交互动,帮助用户在日常生活中提高效率,轻松管理个人事务。

热门AI助手AI对话AI工具聊天机器人
Transly

Transly

实时语音翻译/同声传译工具

Transly是一个多场景的AI大语言模型驱动的同声传译、专业翻译助手,它拥有超精准的音频识别翻译能力,几乎零延迟的使用体验和支持多国语言可以让你带它走遍全球,无论你是留学生、商务人士、韩剧美剧爱好者,还是出国游玩、多国会议、跨国追星等等,都可以满足你所有需要同传的场景需求,线上线下通用,扫除语言障碍,让全世界的语言交流不再有国界。

讯飞智文

讯飞智文

一键生成PPT和Word,让学习生活更轻松

讯飞智文是一个利用 AI 技术的项目,能够帮助用户生成 PPT 以及各类文档。无论是商业领域的市场分析报告、年度目标制定,还是学生群体的职业生涯规划、实习避坑指南,亦或是活动策划、旅游攻略等内容,它都能提供支持,帮助用户精准表达,轻松呈现各种信息。

AI办公办公工具AI工具讯飞智文AI在线生成PPTAI撰写助手多语种文档生成AI自动配图热门
讯飞星火

讯飞星火

深度推理能力全新升级,全面对标OpenAI o1

科大讯飞的星火大模型,支持语言理解、知识问答和文本创作等多功能,适用于多种文件和业务场景,提升办公和日常生活的效率。讯飞星火是一个提供丰富智能服务的平台,涵盖科技资讯、图像创作、写作辅助、编程解答、科研文献解读等功能,能为不同需求的用户提供便捷高效的帮助,助力用户轻松获取信息、解决问题,满足多样化使用场景。

热门AI开发模型训练AI工具讯飞星火大模型智能问答内容创作多语种支持智慧生活
Spark-TTS

Spark-TTS

一种基于大语言模型的高效单流解耦语音令牌文本到语音合成模型

Spark-TTS 是一个基于 PyTorch 的开源文本到语音合成项目,由多个知名机构联合参与。该项目提供了高效的 LLM(大语言模型)驱动的语音合成方案,支持语音克隆和语音创建功能,可通过命令行界面(CLI)和 Web UI 两种方式使用。用户可以根据需求调整语音的性别、音高、速度等参数,生成高质量的语音。该项目适用于多种场景,如有声读物制作、智能语音助手开发等。

咔片PPT

咔片PPT

AI助力,做PPT更简单!

咔片是一款轻量化在线演示设计工具,借助 AI 技术,实现从内容生成到智能设计的一站式 PPT 制作服务。支持多种文档格式导入生成 PPT,提供海量模板、智能美化、素材替换等功能,适用于销售、教师、学生等各类人群,能高效制作出高品质 PPT,满足不同场景演示需求。

讯飞绘文

讯飞绘文

选题、配图、成文,一站式创作,让内容运营更高效

讯飞绘文,一个AI集成平台,支持写作、选题、配图、排版和发布。高效生成适用于各类媒体的定制内容,加速品牌传播,提升内容营销效果。

热门AI辅助写作AI工具讯飞绘文内容运营AI创作个性化文章多平台分发AI助手
材料星

材料星

专业的AI公文写作平台,公文写作神器

AI 材料星,专业的 AI 公文写作辅助平台,为体制内工作人员提供高效的公文写作解决方案。拥有海量公文文库、9 大核心 AI 功能,支持 30 + 文稿类型生成,助力快速完成领导讲话、工作总结、述职报告等材料,提升办公效率,是体制打工人的得力写作神器。

下拉加载更多