flash-linear-attention

flash-linear-attention

Triton实现的高效线性注意力模型库

Flash Linear Attention是一个基于Triton实现的线性注意力模型库。该项目集成了RetNet、GLA和Based等多种先进模型,实现了高效的token混合和文本生成。兼容Hugging Face Transformers库,提供预训练模型、评估工具和基准测试,为线性注意力技术的研究和应用提供了便利。

Flash Linear Attention线性注意力模型Triton实现深度学习自然语言处理Github开源项目
<div align="center">

快速线性注意力

hf_model | Discord

</div>

本仓库旨在提供一系列基于Triton的高效实现,用于最先进的线性注意力模型。欢迎任何拉取请求!

<div align="center"> <img width="400" alt="image" src="https://github.com/sustcsonglin/flash-linear-attention/assets/18402347/02ff2e26-1495-4088-b701-e72cd65ac6cf"> </div>

模型

日期模型标题论文代码FLA实现
2023-07RetNet (@MSRA@THU)保留网络:大型语言模型的Transformer继任者[arxiv][官方] [RetNet]代码
2023-12GLA (@MIT@IBM)具有硬件高效训练的门控线性注意力Transformer[arxiv][官方]代码
2023-12Based (@Stanford@Hazyresearch)一个教育性且有效的序列混合器[博客][官方]代码
2024-01Rebased具有可学习核函数的线性Transformer是更好的上下文模型[arxiv][官方]代码
2021-02Delta Net线性Transformer实际上是快速权重编程器[arxiv][官方]代码
2023-09Hedgehog (@HazyResearch)刺猬和豪猪:具有Softmax模仿的表达性线性注意力openreview代码
2023-10PolySketchFormer (@CMU@Google)通过多项式核草图实现快速Transformerarxiv待完成
2023-07TransnormerLLM一种更快更好的大型语言模型,采用改进的TransNormer(@上海人工智能实验室)openreview arxiv[官方] [Lightning2]待完成
2023-05RWKV-v4 (@BlinkDL)为Transformer时代重新发明RNNarxiv[官方]待完成
2023-10GateLoop用于序列建模的完全数据控制的线性递归openreview arxiv[官方] [jax]待完成
2021-10ABC (@UW)具有有界内存控制的注意力arxiv代码
2023-09VQ-transformer通过向量量化实现线性时间Transformerarxiv[官方]待完成
2023-09HGRN用于序列建模的分层门控递归神经网络openreview[官方]代码
2024-04HGRN2HGRN2:具有状态扩展的门控线性RNNarxiv[官方]代码
2024-04RWKV6鹰和雀鹀:具有矩阵值状态和动态递归的RWKVarxiv[官方]代码
2024-06SambaSamba:用于高效无限上下文语言建模的简单混合状态空间模型arxiv[官方]代码
2024-05Mamba2Transformer是SSM:通过结构化状态空间对偶性实现广义模型和高效算法arxiv[官方]代码

安装

需满足以下要求:

  • PyTorch >= 2.0
  • Triton >=2.2
  • einops 由于fla目前正在积极开发中,暂时没有提供已发布的软件包。 如果您确实需要使用fla的操作/模块并考虑进一步探索,可以通过以下方式从源代码安装软件包
pip install -U git+https://github.com/sustcsonglin/flash-linear-attention

或者使用子模块管理fla

git submodule add https://github.com/sustcsonglin/flash-linear-attention.git 3rdparty/flash-linear-attention ln -s 3rdparty/flash-linear-attention/fla fla

[!注意] 如果您没有使用Triton v2.2或其每夜版本,请注意FusedChunk实现可能存在潜在问题,详见此问题。 您可以运行测试python tests/test_fused_chunk.py来检查您的版本是否受到类似编译器问题的影响。 虽然我们为Triton<=2.1提供了一些修复方案,但请注意这些可能会导致性能下降。

对于Triton 2.2和更早版本(最高2.1),您可以可靠地使用Chunk版本(隐藏状态具体化到HBM中)。 经过仔细优化,这个版本在大多数情况下通常能提供高性能。

使用方法

令牌混合

我们在fla.layers中提供了"令牌混合"线性注意力层供您使用。 您可以用其他线性注意力层替换模型中的标准多头注意力层。 使用示例如下:

>>> import torch >>> from fla.layers import MultiScaleRetention >>> batch_size, num_heads, seq_len, hidden_size, = 32, 4, 2048, 1024 >>> device, dtype = 'cuda:0', torch.bfloat16 >>> retnet = MultiScaleRetention(hidden_size=hidden_size, num_heads=num_heads).to(device=device, dtype=dtype) >>> x = torch.randn(batch_size, seq_len, hidden_size).to(device=device, dtype=dtype) >>> y, *_ = retnet(x) >>> y.shape torch.Size([32, 2048, 1024])

我们提供了与🤗 Transformers库兼容的模型实现。 以下是如何从fla中的默认配置初始化GLA模型的示例:

>>> from fla.models import GLAConfig >>> from transformers import AutoModel >>> config = GLAConfig() >>> config GLAConfig { "attn_mode": "fused_chunk", "bos_token_id": 1, "clamp_min": null, "conv_size": 4, "eos_token_id": 2, "expand_k": 0.5, "expand_v": 1, "fuse_cross_entropy": true, "fuse_norm": true, "hidden_act": "swish", "hidden_ratio": 4, "hidden_size": 2048, "initializer_range": 0.02, "intermediate_size": null, "max_position_embeddings": 2048, "model_type": "gla", "num_heads": 4, "num_hidden_layers": 24, "rms_norm_eps": 1e-06, "share_conv_kernel": true, "tie_word_embeddings": false, "transformers_version": "4.39.1", "use_cache": true, "use_gk": true, "use_gv": false, "use_short_conv": false, "vocab_size": 32000 } >>> AutoModel.from_config(config) GLAModel( (embed_tokens): Embedding(32000, 2048) (layers): ModuleList( (0-23): 24 x GLABlock( (attn_norm): RMSNorm() (attn): GatedLinearAttention( (gate_fn): SiLU() (q_proj): Linear(in_features=2048, out_features=1024, bias=False) (k_proj): Linear(in_features=2048, out_features=1024, bias=False) (v_proj): Linear(in_features=2048, out_features=2048, bias=False) (g_proj): Linear(in_features=2048, out_features=2048, bias=False) (gk_proj): Sequential( (0): Linear(in_features=2048, out_features=16, bias=False) (1): Linear(in_features=16, out_features=1024, bias=True) ) (o_proj): Linear(in_features=2048, out_features=2048, bias=False) (g_norm_swish_gate): FusedRMSNormSwishGate() ) (mlp_norm): RMSNorm() (mlp): GLAMLP( (gate_proj): Linear(in_features=2048, out_features=11264, bias=False) (down_proj): Linear(in_features=5632, out_features=2048, bias=False) (act_fn): SiLU() ) ) ) (norm): RMSNorm() )

生成

成功预训练模型后,就可以使用🤗文本生成API来生成文本。 以下是一个生成示例:

>>> import fla >>> from transformers import AutoModelForCausalLM, AutoTokenizer >>> name = 'fla-hub/gla-1.3B-100B' >>> tokenizer = AutoTokenizer.from_pretrained(name) >>> model = AutoModelForCausalLM.from_pretrained(name).cuda() >>> input_prompt = "Power goes with permanence. Impermanence is impotence. And rotation is castration." >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids.cuda() >>> outputs = model.generate(input_ids, max_length=64) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

我们还提供了一个简单的脚本这里用于基准测试生成速度。 只需运行:

$ python -m benchmarks.benchmark_generation \ --path 'fla-hub/gla-1.3B-100B' \ --repetition_penalty 2. \ --prompt="Hello everyone, I'm Songlin Yang" 提示: Hello everyone, I'm Songlin Yang 生成: Hello everyone, I'm Songlin Yang. I am a 20 year old girl from China who is currently studying in the United States of America for my Master degree and also working as an English teacher at school here on campus since last summer (1st semester). My main goal to be able do well with this course so that we can have 提示长度:10,生成长度:64 总提示处理 + 解码时间:4593ms

所有当前可用的预训练模型都可以在fla-hub中找到。

>>> from huggingface_hub import list_models >>> for model in list_models(author='fla-hub'): print(model.id)

评估

lm-evaluation-harness库允许您轻松执行(零样本)模型评估。 按照以下步骤使用此库:

  1. 按照他们的说明安装lm_eval

  2. 运行评估:

$ PATH='fla-hub/gla-1.3B-100B' $ python -m evals.harness --model hf \ --model_args pretrained=$PATH,dtype=bfloat16 \ --tasks wikitext,lambada_openai,piqa,hellaswag,winogrande,arc_easy,arc_challenge,boolq,sciq,copa,openbookqa \ --batch_size 64 \ --num_fewshot 0 \ --device cuda \ --show_config

我们已经使fla与hf风格的评估兼容,您可以调用evals.harness来完成评估。 运行上述命令将提供GLA论文中报告的任务结果。

[!提示] 如果您将lm-evaluation-harness作为外部库使用,却发现(几乎)没有可用的任务,在调用lm_eval.evaluate()lm_eval.simple_evaluate()之前,只需运行以下命令来加载库的默认任务!

>>> from lm_eval.tasks import TaskManager; TaskManager().initialize_tasks()

基准测试

我们将基于Triton的RetNet实现与基于CUDA的FlashAttention2进行了比较,使用批量大小为8、32个头部和128的头部维度,在不同的序列长度下进行测试。这些测试在单个A100 80GB GPU上进行,如下图所示:

# 你可能需要先通过 `pip install -e .` 安装 `fla` 以启用其导入 $ python benchmark_retention.py 性能: seq_len fused_chunk_fwd chunk_fwd parallel_fwd fused_chunk_fwdbwd chunk_fwdbwd parallel_fwdbwd flash_fwd flash_fwdbwd 0 128.0 0.093184 0.185344 0.067584 1.009664 1.591296 1.044480 0.041984 0.282624 1 256.0 0.165888 0.219136 0.126976 1.024000 1.596928 1.073152 0.074752 0.413696 2 512.0 0.308224 0.397312 0.265216 1.550336 1.603584 1.301504 0.156672 0.883712 3 1024.0 0.603136 0.747520 0.706560 3.044864 3.089408 3.529728 0.467968 2.342912 4 2048.0 1.191424 1.403904 2.141184 6.010880 6.059008 11.009024 1.612800 7.135232 5 4096.0 2.377728 2.755072 7.392256 11.932672 11.938816 37.792770 5.997568 24.435200 6 8192.0 4.750336 5.491712 26.402817 23.759359 23.952385 141.014023 22.682114 90.619904 7 16384.0 9.591296 10.870784 101.262337 47.666176 48.745472 539.853821 91.346947 346.318848

性能

线性注意力的不同形式

关于线性注意力不同形式的硬件考虑,请参考GLA论文的第2.3节。

  • 并行:自注意力风格的计算,时间复杂度为O(L^2),具有序列并行性。
  • 融合递归:递归计算,时间复杂度为O(L)。隐藏状态在共享内存中即时计算,无需物化到全局内存(详见此论文的算法1)。这节省了大量I/O成本,应该是速度比较的强基准。
  • 融合分块:分块计算,时间复杂度为O(LC),其中C是块大小。隐藏状态同样即时计算,不物化到全局内存。这个版本通常比融合递归更好,因为可以使用张量核心进行序列级"归约",而融合递归完全无法使用张量核心。注意,此实现中没有序列级并行性,因此不适合非常小的批量大小设置。应比并行分块更节省内存。
  • 并行分块:具有序列并行性的分块计算。需要为每个块将隐藏状态物化到全局内存。需要适当设置C以获得良好性能,因为当C小时,需要加载/存储到全局内存的隐藏状态太多;当C太大时,浮点运算量高。推荐的C值为[64, 128, 256]。

引用

如果您觉得这个仓库有用,请考虑引用我们的工作:

@article{yang2024delta, title = {Parallelizing Linear Transformers with the Delta Rule over Sequence Length}, author = {Songlin Yang and Bailin Wang and Yu Zhang and Yikang Shen and Yoon Kim}, journal = {arXiv preprint arXiv:2406.06484}, year = {2024}, } @article{yang2023gated, title = {Gated Linear Attention Transformers with Hardware-Efficient Training}, author = {Yang, Songlin and Wang, Bailin and Shen, Yikang and Panda, Rameswar and Kim, Yoon}, journal = {arXiv preprint arXiv:2312.06635}, year = {2023} } @software{yang2024fla, title = {FLA: A Triton-Based Library for Hardware-Efficient Implementations of Linear Attention Mechanism}, author = {Yang, Songlin and Zhang, Yu}, url = {https://github.com/sustcsonglin/flash-linear-attention}, month = jan, year = {2024} }

编辑推荐精选

Vora

Vora

免费创建高清无水印Sora视频

Vora是一个免费创建高清无水印Sora视频的AI工具

Refly.AI

Refly.AI

最适合小白的AI自动化工作流平台

无需编码,轻松生成可复用、可变现的AI自动化工作流

酷表ChatExcel

酷表ChatExcel

大模型驱动的Excel数据处理工具

基于大模型交互的表格处理系统,允许用户通过对话方式完成数据整理和可视化分析。系统采用机器学习算法解析用户指令,自动执行排序、公式计算和数据透视等操作,支持多种文件格式导入导出。数据处理响应速度保持在0.8秒以内,支持超过100万行数据的即时分析。

AI工具酷表ChatExcelAI智能客服AI营销产品使用教程
TRAE编程

TRAE编程

AI辅助编程,代码自动修复

Trae是一种自适应的集成开发环境(IDE),通过自动化和多元协作改变开发流程。利用Trae,团队能够更快速、精确地编写和部署代码,从而提高编程效率和项目交付速度。Trae具备上下文感知和代码自动完成功能,是提升开发效率的理想工具。

AI工具TraeAI IDE协作生产力转型热门
AIWritePaper论文写作

AIWritePaper论文写作

AI论文写作指导平台

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

AI辅助写作AI工具AI论文工具论文写作智能生成大纲数据安全AI助手热门
博思AIPPT

博思AIPPT

AI一键生成PPT,就用博思AIPPT!

博思AIPPT,新一代的AI生成PPT平台,支持智能生成PPT、AI美化PPT、文本&链接生成PPT、导入Word/PDF/Markdown文档生成PPT等,内置海量精美PPT模板,涵盖商务、教育、科技等不同风格,同时针对每个页面提供多种版式,一键自适应切换,完美适配各种办公场景。

AI办公办公工具AI工具博思AIPPTAI生成PPT智能排版海量精品模板AI创作热门
潮际好麦

潮际好麦

AI赋能电商视觉革命,一站式智能商拍平台

潮际好麦深耕服装行业,是国内AI试衣效果最好的软件。使用先进AIGC能力为电商卖家批量提供优质的、低成本的商拍图。合作品牌有Shein、Lazada、安踏、百丽等65个国内外头部品牌,以及国内10万+淘宝、天猫、京东等主流平台的品牌商家,为卖家节省将近85%的出图成本,提升约3倍出图效率,让品牌能够快速上架。

iTerms

iTerms

企业专属的AI法律顾问

iTerms是法大大集团旗下法律子品牌,基于最先进的大语言模型(LLM)、专业的法律知识库和强大的智能体架构,帮助企业扫清合规障碍,筑牢风控防线,成为您企业专属的AI法律顾问。

SimilarWeb流量提升

SimilarWeb流量提升

稳定高效的流量提升解决方案,助力品牌曝光

稳定高效的流量提升解决方案,助力品牌曝光

Sora2视频免费生成

Sora2视频免费生成

最新版Sora2模型免费使用,一键生成无水印视频

最新版Sora2模型免费使用,一键生成无水印视频

下拉加载更多