flash-linear-attention

flash-linear-attention

Triton实现的高效线性注意力模型库

Flash Linear Attention是一个基于Triton实现的线性注意力模型库。该项目集成了RetNet、GLA和Based等多种先进模型,实现了高效的token混合和文本生成。兼容Hugging Face Transformers库,提供预训练模型、评估工具和基准测试,为线性注意力技术的研究和应用提供了便利。

Flash Linear Attention线性注意力模型Triton实现深度学习自然语言处理Github开源项目
<div align="center">

快速线性注意力

hf_model | Discord

</div>

本仓库旨在提供一系列基于Triton的高效实现,用于最先进的线性注意力模型。欢迎任何拉取请求!

<div align="center"> <img width="400" alt="image" src="https://github.com/sustcsonglin/flash-linear-attention/assets/18402347/02ff2e26-1495-4088-b701-e72cd65ac6cf"> </div>

模型

日期模型标题论文代码FLA实现
2023-07RetNet (@MSRA@THU)保留网络:大型语言模型的Transformer继任者[arxiv][官方] [RetNet]代码
2023-12GLA (@MIT@IBM)具有硬件高效训练的门控线性注意力Transformer[arxiv][官方]代码
2023-12Based (@Stanford@Hazyresearch)一个教育性且有效的序列混合器[博客][官方]代码
2024-01Rebased具有可学习核函数的线性Transformer是更好的上下文模型[arxiv][官方]代码
2021-02Delta Net线性Transformer实际上是快速权重编程器[arxiv][官方]代码
2023-09Hedgehog (@HazyResearch)刺猬和豪猪:具有Softmax模仿的表达性线性注意力openreview代码
2023-10PolySketchFormer (@CMU@Google)通过多项式核草图实现快速Transformerarxiv待完成
2023-07TransnormerLLM一种更快更好的大型语言模型,采用改进的TransNormer(@上海人工智能实验室)openreview arxiv[官方] [Lightning2]待完成
2023-05RWKV-v4 (@BlinkDL)为Transformer时代重新发明RNNarxiv[官方]待完成
2023-10GateLoop用于序列建模的完全数据控制的线性递归openreview arxiv[官方] [jax]待完成
2021-10ABC (@UW)具有有界内存控制的注意力arxiv代码
2023-09VQ-transformer通过向量量化实现线性时间Transformerarxiv[官方]待完成
2023-09HGRN用于序列建模的分层门控递归神经网络openreview[官方]代码
2024-04HGRN2HGRN2:具有状态扩展的门控线性RNNarxiv[官方]代码
2024-04RWKV6鹰和雀鹀:具有矩阵值状态和动态递归的RWKVarxiv[官方]代码
2024-06SambaSamba:用于高效无限上下文语言建模的简单混合状态空间模型arxiv[官方]代码
2024-05Mamba2Transformer是SSM:通过结构化状态空间对偶性实现广义模型和高效算法arxiv[官方]代码

安装

需满足以下要求:

  • PyTorch >= 2.0
  • Triton >=2.2
  • einops 由于fla目前正在积极开发中,暂时没有提供已发布的软件包。 如果您确实需要使用fla的操作/模块并考虑进一步探索,可以通过以下方式从源代码安装软件包
pip install -U git+https://github.com/sustcsonglin/flash-linear-attention

或者使用子模块管理fla

git submodule add https://github.com/sustcsonglin/flash-linear-attention.git 3rdparty/flash-linear-attention ln -s 3rdparty/flash-linear-attention/fla fla

[!注意] 如果您没有使用Triton v2.2或其每夜版本,请注意FusedChunk实现可能存在潜在问题,详见此问题。 您可以运行测试python tests/test_fused_chunk.py来检查您的版本是否受到类似编译器问题的影响。 虽然我们为Triton<=2.1提供了一些修复方案,但请注意这些可能会导致性能下降。

对于Triton 2.2和更早版本(最高2.1),您可以可靠地使用Chunk版本(隐藏状态具体化到HBM中)。 经过仔细优化,这个版本在大多数情况下通常能提供高性能。

使用方法

令牌混合

我们在fla.layers中提供了"令牌混合"线性注意力层供您使用。 您可以用其他线性注意力层替换模型中的标准多头注意力层。 使用示例如下:

>>> import torch >>> from fla.layers import MultiScaleRetention >>> batch_size, num_heads, seq_len, hidden_size, = 32, 4, 2048, 1024 >>> device, dtype = 'cuda:0', torch.bfloat16 >>> retnet = MultiScaleRetention(hidden_size=hidden_size, num_heads=num_heads).to(device=device, dtype=dtype) >>> x = torch.randn(batch_size, seq_len, hidden_size).to(device=device, dtype=dtype) >>> y, *_ = retnet(x) >>> y.shape torch.Size([32, 2048, 1024])

我们提供了与🤗 Transformers库兼容的模型实现。 以下是如何从fla中的默认配置初始化GLA模型的示例:

>>> from fla.models import GLAConfig >>> from transformers import AutoModel >>> config = GLAConfig() >>> config GLAConfig { "attn_mode": "fused_chunk", "bos_token_id": 1, "clamp_min": null, "conv_size": 4, "eos_token_id": 2, "expand_k": 0.5, "expand_v": 1, "fuse_cross_entropy": true, "fuse_norm": true, "hidden_act": "swish", "hidden_ratio": 4, "hidden_size": 2048, "initializer_range": 0.02, "intermediate_size": null, "max_position_embeddings": 2048, "model_type": "gla", "num_heads": 4, "num_hidden_layers": 24, "rms_norm_eps": 1e-06, "share_conv_kernel": true, "tie_word_embeddings": false, "transformers_version": "4.39.1", "use_cache": true, "use_gk": true, "use_gv": false, "use_short_conv": false, "vocab_size": 32000 } >>> AutoModel.from_config(config) GLAModel( (embed_tokens): Embedding(32000, 2048) (layers): ModuleList( (0-23): 24 x GLABlock( (attn_norm): RMSNorm() (attn): GatedLinearAttention( (gate_fn): SiLU() (q_proj): Linear(in_features=2048, out_features=1024, bias=False) (k_proj): Linear(in_features=2048, out_features=1024, bias=False) (v_proj): Linear(in_features=2048, out_features=2048, bias=False) (g_proj): Linear(in_features=2048, out_features=2048, bias=False) (gk_proj): Sequential( (0): Linear(in_features=2048, out_features=16, bias=False) (1): Linear(in_features=16, out_features=1024, bias=True) ) (o_proj): Linear(in_features=2048, out_features=2048, bias=False) (g_norm_swish_gate): FusedRMSNormSwishGate() ) (mlp_norm): RMSNorm() (mlp): GLAMLP( (gate_proj): Linear(in_features=2048, out_features=11264, bias=False) (down_proj): Linear(in_features=5632, out_features=2048, bias=False) (act_fn): SiLU() ) ) ) (norm): RMSNorm() )

生成

成功预训练模型后,就可以使用🤗文本生成API来生成文本。 以下是一个生成示例:

>>> import fla >>> from transformers import AutoModelForCausalLM, AutoTokenizer >>> name = 'fla-hub/gla-1.3B-100B' >>> tokenizer = AutoTokenizer.from_pretrained(name) >>> model = AutoModelForCausalLM.from_pretrained(name).cuda() >>> input_prompt = "Power goes with permanence. Impermanence is impotence. And rotation is castration." >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids.cuda() >>> outputs = model.generate(input_ids, max_length=64) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

我们还提供了一个简单的脚本这里用于基准测试生成速度。 只需运行:

$ python -m benchmarks.benchmark_generation \ --path 'fla-hub/gla-1.3B-100B' \ --repetition_penalty 2. \ --prompt="Hello everyone, I'm Songlin Yang" 提示: Hello everyone, I'm Songlin Yang 生成: Hello everyone, I'm Songlin Yang. I am a 20 year old girl from China who is currently studying in the United States of America for my Master degree and also working as an English teacher at school here on campus since last summer (1st semester). My main goal to be able do well with this course so that we can have 提示长度:10,生成长度:64 总提示处理 + 解码时间:4593ms

所有当前可用的预训练模型都可以在fla-hub中找到。

>>> from huggingface_hub import list_models >>> for model in list_models(author='fla-hub'): print(model.id)

评估

lm-evaluation-harness库允许您轻松执行(零样本)模型评估。 按照以下步骤使用此库:

  1. 按照他们的说明安装lm_eval

  2. 运行评估:

$ PATH='fla-hub/gla-1.3B-100B' $ python -m evals.harness --model hf \ --model_args pretrained=$PATH,dtype=bfloat16 \ --tasks wikitext,lambada_openai,piqa,hellaswag,winogrande,arc_easy,arc_challenge,boolq,sciq,copa,openbookqa \ --batch_size 64 \ --num_fewshot 0 \ --device cuda \ --show_config

我们已经使fla与hf风格的评估兼容,您可以调用evals.harness来完成评估。 运行上述命令将提供GLA论文中报告的任务结果。

[!提示] 如果您将lm-evaluation-harness作为外部库使用,却发现(几乎)没有可用的任务,在调用lm_eval.evaluate()lm_eval.simple_evaluate()之前,只需运行以下命令来加载库的默认任务!

>>> from lm_eval.tasks import TaskManager; TaskManager().initialize_tasks()

基准测试

我们将基于Triton的RetNet实现与基于CUDA的FlashAttention2进行了比较,使用批量大小为8、32个头部和128的头部维度,在不同的序列长度下进行测试。这些测试在单个A100 80GB GPU上进行,如下图所示:

# 你可能需要先通过 `pip install -e .` 安装 `fla` 以启用其导入 $ python benchmark_retention.py 性能: seq_len fused_chunk_fwd chunk_fwd parallel_fwd fused_chunk_fwdbwd chunk_fwdbwd parallel_fwdbwd flash_fwd flash_fwdbwd 0 128.0 0.093184 0.185344 0.067584 1.009664 1.591296 1.044480 0.041984 0.282624 1 256.0 0.165888 0.219136 0.126976 1.024000 1.596928 1.073152 0.074752 0.413696 2 512.0 0.308224 0.397312 0.265216 1.550336 1.603584 1.301504 0.156672 0.883712 3 1024.0 0.603136 0.747520 0.706560 3.044864 3.089408 3.529728 0.467968 2.342912 4 2048.0 1.191424 1.403904 2.141184 6.010880 6.059008 11.009024 1.612800 7.135232 5 4096.0 2.377728 2.755072 7.392256 11.932672 11.938816 37.792770 5.997568 24.435200 6 8192.0 4.750336 5.491712 26.402817 23.759359 23.952385 141.014023 22.682114 90.619904 7 16384.0 9.591296 10.870784 101.262337 47.666176 48.745472 539.853821 91.346947 346.318848

性能

线性注意力的不同形式

关于线性注意力不同形式的硬件考虑,请参考GLA论文的第2.3节。

  • 并行:自注意力风格的计算,时间复杂度为O(L^2),具有序列并行性。
  • 融合递归:递归计算,时间复杂度为O(L)。隐藏状态在共享内存中即时计算,无需物化到全局内存(详见此论文的算法1)。这节省了大量I/O成本,应该是速度比较的强基准。
  • 融合分块:分块计算,时间复杂度为O(LC),其中C是块大小。隐藏状态同样即时计算,不物化到全局内存。这个版本通常比融合递归更好,因为可以使用张量核心进行序列级"归约",而融合递归完全无法使用张量核心。注意,此实现中没有序列级并行性,因此不适合非常小的批量大小设置。应比并行分块更节省内存。
  • 并行分块:具有序列并行性的分块计算。需要为每个块将隐藏状态物化到全局内存。需要适当设置C以获得良好性能,因为当C小时,需要加载/存储到全局内存的隐藏状态太多;当C太大时,浮点运算量高。推荐的C值为[64, 128, 256]。

引用

如果您觉得这个仓库有用,请考虑引用我们的工作:

@article{yang2024delta, title = {Parallelizing Linear Transformers with the Delta Rule over Sequence Length}, author = {Songlin Yang and Bailin Wang and Yu Zhang and Yikang Shen and Yoon Kim}, journal = {arXiv preprint arXiv:2406.06484}, year = {2024}, } @article{yang2023gated, title = {Gated Linear Attention Transformers with Hardware-Efficient Training}, author = {Yang, Songlin and Wang, Bailin and Shen, Yikang and Panda, Rameswar and Kim, Yoon}, journal = {arXiv preprint arXiv:2312.06635}, year = {2023} } @software{yang2024fla, title = {FLA: A Triton-Based Library for Hardware-Efficient Implementations of Linear Attention Mechanism}, author = {Yang, Songlin and Zhang, Yu}, url = {https://github.com/sustcsonglin/flash-linear-attention}, month = jan, year = {2024} }

编辑推荐精选

蛙蛙写作

蛙蛙写作

AI小说写作助手,一站式润色、改写、扩写

蛙蛙写作—国内先进的AI写作平台,涵盖小说、学术、社交媒体等多场景。提供续写、改写、润色等功能,助力创作者高效优化写作流程。界面简洁,功能全面,适合各类写作者提升内容品质和工作效率。

AI辅助写作AI工具蛙蛙写作AI写作工具学术助手办公助手营销助手AI助手
Trae

Trae

字节跳动发布的AI编程神器IDE

Trae是一种自适应的集成开发环境(IDE),通过自动化和多元协作改变开发流程。利用Trae,团队能够更快速、精确地编写和部署代码,从而提高编程效率和项目交付速度。Trae具备上下文感知和代码自动完成功能,是提升开发效率的理想工具。

AI工具TraeAI IDE协作生产力转型热门
问小白

问小白

全能AI智能助手,随时解答生活与工作的多样问题

问小白,由元石科技研发的AI智能助手,快速准确地解答各种生活和工作问题,包括但不限于搜索、规划和社交互动,帮助用户在日常生活中提高效率,轻松管理个人事务。

热门AI助手AI对话AI工具聊天机器人
Transly

Transly

实时语音翻译/同声传译工具

Transly是一个多场景的AI大语言模型驱动的同声传译、专业翻译助手,它拥有超精准的音频识别翻译能力,几乎零延迟的使用体验和支持多国语言可以让你带它走遍全球,无论你是留学生、商务人士、韩剧美剧爱好者,还是出国游玩、多国会议、跨国追星等等,都可以满足你所有需要同传的场景需求,线上线下通用,扫除语言障碍,让全世界的语言交流不再有国界。

讯飞智文

讯飞智文

一键生成PPT和Word,让学习生活更轻松

讯飞智文是一个利用 AI 技术的项目,能够帮助用户生成 PPT 以及各类文档。无论是商业领域的市场分析报告、年度目标制定,还是学生群体的职业生涯规划、实习避坑指南,亦或是活动策划、旅游攻略等内容,它都能提供支持,帮助用户精准表达,轻松呈现各种信息。

AI办公办公工具AI工具讯飞智文AI在线生成PPTAI撰写助手多语种文档生成AI自动配图热门
讯飞星火

讯飞星火

深度推理能力全新升级,全面对标OpenAI o1

科大讯飞的星火大模型,支持语言理解、知识问答和文本创作等多功能,适用于多种文件和业务场景,提升办公和日常生活的效率。讯飞星火是一个提供丰富智能服务的平台,涵盖科技资讯、图像创作、写作辅助、编程解答、科研文献解读等功能,能为不同需求的用户提供便捷高效的帮助,助力用户轻松获取信息、解决问题,满足多样化使用场景。

热门AI开发模型训练AI工具讯飞星火大模型智能问答内容创作多语种支持智慧生活
Spark-TTS

Spark-TTS

一种基于大语言模型的高效单流解耦语音令牌文本到语音合成模型

Spark-TTS 是一个基于 PyTorch 的开源文本到语音合成项目,由多个知名机构联合参与。该项目提供了高效的 LLM(大语言模型)驱动的语音合成方案,支持语音克隆和语音创建功能,可通过命令行界面(CLI)和 Web UI 两种方式使用。用户可以根据需求调整语音的性别、音高、速度等参数,生成高质量的语音。该项目适用于多种场景,如有声读物制作、智能语音助手开发等。

咔片PPT

咔片PPT

AI助力,做PPT更简单!

咔片是一款轻量化在线演示设计工具,借助 AI 技术,实现从内容生成到智能设计的一站式 PPT 制作服务。支持多种文档格式导入生成 PPT,提供海量模板、智能美化、素材替换等功能,适用于销售、教师、学生等各类人群,能高效制作出高品质 PPT,满足不同场景演示需求。

讯飞绘文

讯飞绘文

选题、配图、成文,一站式创作,让内容运营更高效

讯飞绘文,一个AI集成平台,支持写作、选题、配图、排版和发布。高效生成适用于各类媒体的定制内容,加速品牌传播,提升内容营销效果。

热门AI辅助写作AI工具讯飞绘文内容运营AI创作个性化文章多平台分发AI助手
材料星

材料星

专业的AI公文写作平台,公文写作神器

AI 材料星,专业的 AI 公文写作辅助平台,为体制内工作人员提供高效的公文写作解决方案。拥有海量公文文库、9 大核心 AI 功能,支持 30 + 文稿类型生成,助力快速完成领导讲话、工作总结、述职报告等材料,提升办公效率,是体制打工人的得力写作神器。

下拉加载更多