ThunderKittens

ThunderKittens

高效瓦片原语框架助力深度学习内核开发

ThunderKittens是一个用于开发高性能CUDA深度学习内核的框架。它基于现代GPU架构设计,通过操作16x16及以上的数据瓦片实现高效计算。框架支持张量核心、共享内存优化和异步数据传输等特性,充分利用GPU性能。ThunderKittens以简洁、可扩展和高速为设计原则,适用于各类深度学习算法的高效实现。

ThunderKittensCUDAGPU编程深度学习矩阵运算Github开源项目

ThunderKittens

快速内核的瓦片原语

<div align="center" > <img src="https://yellow-cdn.veclightyear.com/835a84d5/48f011e6-aa1e-4f32-a2d8-72a9bbe59f22.png" height=350 alt="ThunderKittens 标志" style="margin-bottom:px"/> </div> <br> <br>

ThunderKittens 是一个框架,旨在使用 CUDA 轻松编写快速的深度学习内核(不久后还将支持 ROCm 等其他平台)。

ThunderKittens 基于三个关键原则构建:

  1. 简单性。ThunderKittens 编写起来异常简单。
  2. 可扩展性。ThunderKittens 原生嵌入,如果你需要的功能超出了 ThunderKittens 的能力范围,它不会妨碍你自行构建。
  3. 速度。使用 ThunderKittens 编写的内核应该至少与从头编写的内核一样快 —— 特别是因为 ThunderKittens 可以在底层以"正确"的方式处理事情。我们认为我们的 Flash Attention 2 实现证明了这一点。
<div align="center" > <img src="https://yellow-cdn.veclightyear.com/835a84d5/ffc1a947-ec1f-4a72-8537-ea0abab07949.png" height=600 alt="Flash Attention 2,但是带有小猫!" style="margin-bottom:px"/> </div>

ThunderKittens 是从硬件层面构建的 —— 我们按照硅芯片的指示行事。现代 GPU 告诉我们,它们希望处理相当小的数据瓦片。GPU 并不真的是一个 1000x1000 矩阵乘法机器(即使它经常被这样使用);它是一个多核处理器,每个核心可以高效地执行约 16x16 的矩阵乘法。因此,ThunderKittens 围绕操作不小于 16x16 值的数据瓦片构建。

ThunderKittens 让一些棘手的事情变得简单,从而在现代硬件上实现高利用率。

  1. 张量核心。ThunderKittens 可以调用快速的张量核心函数,包括在 H100 GPU 上的异步 WGMMA 调用。
  2. 共享内存。我有九十九个问题,但银行冲突不是其中之一。
  3. 加载和存储。通过异步复制隐藏延迟,通过 TMA 进行地址生成。
  4. 分布式共享内存。L2 已经是过去式了。

示例:一个简单的注意力内核

以下是使用 ThunderKittens 为 RTX 4090 编写的简单 FlashAttention-2 内核示例。

#define NUM_WORKERS 16 // 此内核每个块并行使用16个工作线程,以帮助更快地发出指令。 using namespace kittens; // 为简单起见,此内核仅处理 headdim=64。此外,n 应该是 256 的倍数。 __global__ void attend_ker64(int n, const bf16* __restrict__ __q__, const bf16* __restrict__ __k__, const bf16* __restrict__ __v__, bf16* __o__) { auto warpid = kittens::warpid(); auto block_start = blockIdx.x*(n*64); const bf16 *_q = __q__ + block_start, *_k = __k__ + block_start, *_v = __v__ + block_start; bf16 *_o = __o__ + block_start; extern __shared__ alignment_dummy __shm[]; // 这是 CUDA 共享内存 shared_allocator al((int*)&__shm[0]); // K 和 V 存储在共享内存中 —— 这几乎是所能容纳的全部。 st_bf_1x4<ducks::st_layout::swizzle> (&k_smem)[NUM_WORKERS] = al.allocate<st_bf_1x4<ducks::st_layout::swizzle>, NUM_WORKERS>(); st_bf_1x4<ducks::st_layout::swizzle> (&v_smem)[NUM_WORKERS] = al.allocate<st_bf_1x4<ducks::st_layout::swizzle>, NUM_WORKERS>(); // 初始化所有寄存器瓦片。 rt_bf_1x4<> q_reg, k_reg, v_reg; // v_reg 需要交换到 col_l rt_fl_1x1<> att_block; rt_bf_1x1<> att_block_mma; rt_fl_1x4<> o_reg; rt_fl_1x1<>::col_vec max_vec_last, max_vec; // 这些是注意力块的列向量 rt_fl_1x1<>::col_vec norm_vec_last, norm_vec; // 这些是注意力块的列向量 int qo_blocks = n / (q_reg.rows*NUM_WORKERS), kv_blocks = n / (q_reg.rows*NUM_WORKERS); for(auto q_blk = 0; q_blk < qo_blocks; q_blk++) { // 每个线程束加载自己的 16x64 的 Q 瓦片,然后乘以 1/sqrt(d) load(q_reg, _q + (q_blk*NUM_WORKERS + warpid)*q_reg.num_elements, q_reg.cols); mul(q_reg, q_reg, __float2bfloat16(0.125f)); // 温度调整 // 将 flash 注意力 L、M 和 O 寄存器置零。 neg_infty(max_vec); // 为 Q 块清零寄存器 zero(norm_vec); zero(o_reg); // 对已加载的这些 q 迭代 k、v for(auto kv_idx = 0; kv_idx < kv_blocks; kv_idx++) { // 每个线程束将自己的 k、v 块加载到共享内存中 load(v_smem[warpid], _v + (kv_idx*NUM_WORKERS + warpid)*q_reg.num_elements, q_reg.cols); load(k_smem[warpid], _k + (kv_idx*NUM_WORKERS + warpid)*q_reg.num_elements, q_reg.cols); __syncthreads(); // 我们需要确保在开始计算阶段之前所有内存都已加载 // 现在每个线程束遍历所有子瓦片,加载它们,然后执行 flash 注意力内部算法。 for(int subtile = 0; subtile < NUM_WORKERS; subtile++) { load(k_reg, k_smem[subtile]); // 从共享内存加载 k 到寄存器 zero(att_block); // 将 16x16 注意力瓦片置零 mma_ABt(att_block, q_reg, k_reg, att_block); // Q@K.T copy(norm_vec_last, norm_vec); copy(max_vec_last, max_vec); row_max(max_vec, att_block, max_vec); // 累积到 max_vec sub_row(att_block, att_block, max_vec); // 从注意力中减去最大值 —— 现在所有值 <=0 exp(att_block, att_block); // 原地对块进行指数运算。 sub(max_vec_last, max_vec_last, max_vec); // 从旧的最大值中减去新的最大值以找到新的归一化。 exp(max_vec_last, max_vec_last); // 对这个向量进行指数运算 —— 这是我们需要用来归一化的。 mul(norm_vec, norm_vec, max_vec_last); // norm_vec 现在已归一化。 row_sum(norm_vec, att_block, norm_vec); // 将新的注意力块累积到现在已重新缩放的 norm_vec 上 div_row(att_block, att_block, norm_vec); // 现在注意力块已正确归一化 mul(norm_vec_last, norm_vec_last, max_vec_last); // 根据新的最大值归一化先前的 norm vec div(norm_vec_last, norm_vec_last, norm_vec); // 根据新的范数归一化先前的 norm vec copy(att_block_mma, att_block); // 转换为 bf16 以用于 mma_AB load(v_reg, v_smem[subtile]); // 从共享内存加载 v 到寄存器。 rt_bf_1x4<ducks::rt_layout::col> &v_reg_col = swap_layout_inplace(v_reg); // 这是一个引用,该调用使 v_reg 失效 mul_row(o_reg, o_reg, norm_vec_last); // 在进行 mma_AB 之前预先归一化 o_reg mma_AB(o_reg, att_block_mma, v_reg_col, o_reg); // 使用局部注意力@V 矩阵乘法对 o_reg 进行 mfma。 } __syncthreads(); // 我们需要确保所有线程束都完成后才能开始加载下一个 kv 块 } store(_o + (q_blk*NUM_WORKERS + warpid)*q_reg.num_elements, o_reg, q_reg.cols); // 写出 o。如果 d 设为 constexpr q_reg.rows,编译器在寄存器使用上会有问题 :/ } }

总的来说,这是 58 行代码(不包括空白行),在 RTX 4090 上可以达到约 122 TFLOPs。(理论最大值的 74%)我们将在下一节 ThunderKittens 手册中更仔细地介绍这些原语。

库安装

要使用 Thunderkittens,你不需要对 TK 本身做太多操作。它是一个仅包含头文件的库,所以只需克隆仓库,并包含 kittens.cuh。轻松搞定。

但 ThunderKittens 确实使用了许多现代功能,因此它有相当严格的要求。

  • CUDA 12.3+。CUDA 12.1 之后的任何版本可能都能工作,但由于这些早期 CUDA 版本中的一个 bug,你可能会遇到串行化的 wgmma 管道。
  • 广泛使用 C++20 —— TK 基于概念运行。
sudo apt update
sudo apt install gcc-10 g++-10

sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-10 100 --slave /usr/bin/g++ g++ /usr/bin/g++-10

sudo apt update
sudo apt install clang-10

如果你找不到 nvcc,或者遇到环境指向错误 CUDA 版本的问题:

export CUDA_HOME=/usr/local/cuda-12/
export PATH=${CUDA_HOME}/bin:${PATH} 
export LD_LIBRARY_PATH=${CUDA_HOME}/lib64:$LD_LIBRARY_PATH

最后,感谢 Jordan Juravsky 整理了一份关于设置兼容 kittens 的 conda 环境的快速文档。

内核安装

要试验我们现有的 TK 内核,请在 config.py 文件中指定你感兴趣的内核,然后运行 python setup.py install

欢迎贡献新的内核!

测试

要验证你的安装,并运行 TK 相当全面的单元测试套件,只需在 tests 文件夹中运行 make -j。注意:这可能会在编译数千个内核时占用你的电脑一两分钟。

示例

要编译示例,请在根目录运行 source env.src,然后进入 examples 目录。(许多示例使用 $THUNDERKITTENS_ROOT 环境变量来定位自己并找到 src 目录。)

ThunderKittens 手册

ThunderKittens 实际上是一个相当小的库,就其提供的功能而言。

  • 数据类型:(寄存器 + 共享)*(瓦片 + 向量),所有这些都由布局、类型和大小参数化。
  • 用于操作这些对象的操作。

尽管它很简单,但如果你不了解底层的工作原理,仍可能会遇到一些棘手的问题。因此,我们建议你在开始编写内核之前好好阅读这份手册 —— 我们保证它不会太长!

NVIDIA 的编程模型

为了理解ThunderKittens,首先回顾一下NVIDIA的编程模型如何工作会有所帮助,因为NVIDIA在编写并行代码时提供了几个不同的"作用域"供考虑。

  1. 线程 -- 这是在单个数据位上执行工作的级别,如浮点乘法。一个线程每个周期可以访问最多256个32位寄存器。
  2. 线程束 -- 32个线程组成一个线程束。这是硬件发出指令的级别。它也是ThunderKittens操作的基本(和默认)作用域;大多数ThunderKittens编程都发生在这个级别。
  3. 线程束组 -- 4个线程束组成一个线程束组。这是发出异步线程束组矩阵乘累加指令的级别。(我们真希望能忽略这个级别,但不幸的是H100需要它。)相应地,许多矩阵乘法和内存操作都在线程束组级别得到支持。
  4. 块 -- N个线程束组成一个块,这是在CUDA编程模型中共享"共享内存"的级别。在ThunderKittens中,N通常是8。
  5. 网格 -- M个块组成一个网格,其中M应该等于(或略小于)GPU上SM数量的倍数,以避免尾部效应。ThunderKittens不直接操作网格作用域,除了通过帮助初始化TMA描述符。

"寄存器"对象存在于线程束级别 -- 它们的内容分布在线程束的各个线程中。寄存器对象包括:

  • 寄存器瓦片,在src/register_tile/rt.cuh中声明为kittens::rt结构。Kittens提供了一些有用的包装器 -- 例如,可以将32x16行布局的bfloat16寄存器瓦片声明为kittens::rt_bf_2x1; -- 默认情况下行布局是隐含的。
  • 寄存器向量,与寄存器瓦片相关联。它们有两种形式:列向量和行向量。列向量用于在瓦片行上进行归约或映射,而行向量在瓦片列上进行归约和映射。例如,要保存上面声明的瓦片行的和,我们可以创建一个kittens::rt_bf_2x1<>::col_vec; 相比之下,"共享"对象存在于块级别,仅位于共享内存中。

所有ThunderKittens函数都遵循一个通用的签名。与汇编语言类似(ThunderKittens本质上是一个抽象的面向瓦片的RISC指令集),每个函数的目标是第一个操作数,源操作数按顺序传递。

例如,如果我们有三个32x64浮点寄存器瓦片:kittens::rt_fl_2x4 a, b, c;,我们可以对ab进行元素级乘法并将结果存储在c中,调用如下:kittens::mul(c, a, b);

同样,如果我们想将结果存储到共享瓦片__shared__ kittens:st_bf_2x4 s;中,我们可以类似地写函数:kittens::store(s, c);

类型系统

ThunderKittens努力保护你免受自身错误的影响。特别是,ThunderKittens希望在编译时知道对象的布局,并确保它们在允许你进行操作之前是兼容的。这很重要,因为某些操作的允许布局有微妙之处,如果没有静态检查,很容易出现令人痛苦的静默失败。例如,普通的矩阵乘法要求B操作数采用列布局,而外积则要求B操作数采用行布局。

如果你被告知你认为存在的操作不存在,请仔细检查你的布局 -- 这是最常见的错误。只有在确认后才报告bug :)

作用域

默认情况下,ThunderKittens操作存在于线程束级别。换句话说,每个函数期望只由单个线程束调用,并且该单个线程束将完成函数的所有工作。如果将多个线程束分配给相同的工作,将导致未定义行为。(如果操作涉及内存移动,很可能会完全崩溃。)通常,你应该期望你的编程模式涉及在内核开始时使用kittens::warpid()实例化一个warpid,并基于该id将任务分配给数据。

然而,并非所有ThunderKittens函数都在线程束级别操作。许多重要操作,特别是WGMMA指令,需要线程束的协作组。这些操作存在于模板kittens::group<collaborative size>中。例如,wgmma指令可通过kittens::group<4>::mma_AB(或其别名kittens::warpgroup::mma_AB)获得。线程束组还可以协作加载共享内存或在共享内存中进行归约。

其他限制

ThunderKittens中的大多数操作都是纯函数式的。然而,一些操作确实有特殊限制;ThunderKittens试图通过给它们起显眼的名字来警告你。例如,寄存器瓦片转置需要可分离的参数:如果给它相同的底层寄存器作为源和目标,它会静默失败。因此,它被命名为transpose_sep

编辑推荐精选

即梦AI

即梦AI

一站式AI创作平台

提供 AI 驱动的图片、视频生成及数字人等功能,助力创意创作

扣子-AI办公

扣子-AI办公

AI办公助手,复杂任务高效处理

AI办公助手,复杂任务高效处理。办公效率低?扣子空间AI助手支持播客生成、PPT制作、网页开发及报告写作,覆盖科研、商业、舆情等领域的专家Agent 7x24小时响应,生活工作无缝切换,提升50%效率!

Keevx

Keevx

AI数字人视频创作平台

Keevx 一款开箱即用的AI数字人视频创作平台,广泛适用于电商广告、企业培训与社媒宣传,让全球企业与个人创作者无需拍摄剪辑,就能快速生成多语言、高质量的专业视频。

TRAE编程

TRAE编程

AI辅助编程,代码自动修复

Trae是一种自适应的集成开发环境(IDE),通过自动化和多元协作改变开发流程。利用Trae,团队能够更快速、精确地编写和部署代码,从而提高编程效率和项目交付速度。Trae具备上下文感知和代码自动完成功能,是提升开发效率的理想工具。

AI工具TraeAI IDE协作生产力转型热门
蛙蛙写作

蛙蛙写作

AI小说写作助手,一站式润色、改写、扩写

蛙蛙写作—国内先进的AI写作平台,涵盖小说、学术、社交媒体等多场景。提供续写、改写、润色等功能,助力创作者高效优化写作流程。界面简洁,功能全面,适合各类写作者提升内容品质和工作效率。

AI辅助写作AI工具蛙蛙写作AI写作工具学术助手办公助手营销助手AI助手
问小白

问小白

全能AI智能助手,随时解答生活与工作的多样问题

问小白,由元石科技研发的AI智能助手,快速准确地解答各种生活和工作问题,包括但不限于搜索、规划和社交互动,帮助用户在日常生活中提高效率,轻松管理个人事务。

热门AI助手AI对话AI工具聊天机器人
Transly

Transly

实时语音翻译/同声传译工具

Transly是一个多场景的AI大语言模型驱动的同声传译、专业翻译助手,它拥有超精准的音频识别翻译能力,几乎零延迟的使用体验和支持多国语言可以让你带它走遍全球,无论你是留学生、商务人士、韩剧美剧爱好者,还是出国游玩、多国会议、跨国追星等等,都可以满足你所有需要同传的场景需求,线上线下通用,扫除语言障碍,让全世界的语言交流不再有国界。

讯飞智文

讯飞智文

一键生成PPT和Word,让学习生活更轻松

讯飞智文是一个利用 AI 技术的项目,能够帮助用户生成 PPT 以及各类文档。无论是商业领域的市场分析报告、年度目标制定,还是学生群体的职业生涯规划、实习避坑指南,亦或是活动策划、旅游攻略等内容,它都能提供支持,帮助用户精准表达,轻松呈现各种信息。

AI办公办公工具AI工具讯飞智文AI在线生成PPTAI撰写助手多语种文档生成AI自动配图热门
讯飞星火

讯飞星火

深度推理能力全新升级,全面对标OpenAI o1

科大讯飞的星火大模型,支持语言理解、知识问答和文本创作等多功能,适用于多种文件和业务场景,提升办公和日常生活的效率。讯飞星火是一个提供丰富智能服务的平台,涵盖科技资讯、图像创作、写作辅助、编程解答、科研文献解读等功能,能为不同需求的用户提供便捷高效的帮助,助力用户轻松获取信息、解决问题,满足多样化使用场景。

热门AI开发模型训练AI工具讯飞星火大模型智能问答内容创作多语种支持智慧生活
Spark-TTS

Spark-TTS

一种基于大语言模型的高效单流解耦语音令牌文本到语音合成模型

Spark-TTS 是一个基于 PyTorch 的开源文本到语音合成项目,由多个知名机构联合参与。该项目提供了高效的 LLM(大语言模型)驱动的语音合成方案,支持语音克隆和语音创建功能,可通过命令行界面(CLI)和 Web UI 两种方式使用。用户可以根据需求调整语音的性别、音高、速度等参数,生成高质量的语音。该项目适用于多种场景,如有声读物制作、智能语音助手开发等。

下拉加载更多