ThunderKittens 是一个框架,旨在使用 CUDA 轻松编写快速的深度学习内核(不久后还将支持 ROCm 等其他平台)。
ThunderKittens 基于三个关键原则构建:
ThunderKittens 是从硬件层面构建的 —— 我们按照硅芯片的指示行事。现代 GPU 告诉我们,它们希望处理相当小的数据瓦片。GPU 并不真的是一个 1000x1000 矩阵乘法机器(即使它经常被这样使用);它是一个多核处理器,每个核心可以高效地执行约 16x16 的矩阵乘法。因此,ThunderKittens 围绕操作不小于 16x16 值的数据瓦片构建。
ThunderKittens 让一些棘手的事情变得简单,从而在现代硬件上实现高利用率。
示例:一个简单的注意力内核
以下是使用 ThunderKittens 为 RTX 4090 编写的简单 FlashAttention-2 内核示例。
#define NUM_WORKERS 16 // 此内核每个块并行使用16个工作线程,以帮助更快地发出指令。 using namespace kittens; // 为简单起见,此内核仅处理 headdim=64。此外,n 应该是 256 的倍数。 __global__ void attend_ker64(int n, const bf16* __restrict__ __q__, const bf16* __restrict__ __k__, const bf16* __restrict__ __v__, bf16* __o__) { auto warpid = kittens::warpid(); auto block_start = blockIdx.x*(n*64); const bf16 *_q = __q__ + block_start, *_k = __k__ + block_start, *_v = __v__ + block_start; bf16 *_o = __o__ + block_start; extern __shared__ alignment_dummy __shm[]; // 这是 CUDA 共享内存 shared_allocator al((int*)&__shm[0]); // K 和 V 存储在共享内存中 —— 这几乎是所能容纳的全部。 st_bf_1x4<ducks::st_layout::swizzle> (&k_smem)[NUM_WORKERS] = al.allocate<st_bf_1x4<ducks::st_layout::swizzle>, NUM_WORKERS>(); st_bf_1x4<ducks::st_layout::swizzle> (&v_smem)[NUM_WORKERS] = al.allocate<st_bf_1x4<ducks::st_layout::swizzle>, NUM_WORKERS>(); // 初始化所有寄存器瓦片。 rt_bf_1x4<> q_reg, k_reg, v_reg; // v_reg 需要交换到 col_l rt_fl_1x1<> att_block; rt_bf_1x1<> att_block_mma; rt_fl_1x4<> o_reg; rt_fl_1x1<>::col_vec max_vec_last, max_vec; // 这些是注意力块的列向量 rt_fl_1x1<>::col_vec norm_vec_last, norm_vec; // 这些是注意力块的列向量 int qo_blocks = n / (q_reg.rows*NUM_WORKERS), kv_blocks = n / (q_reg.rows*NUM_WORKERS); for(auto q_blk = 0; q_blk < qo_blocks; q_blk++) { // 每个线程束加载自己的 16x64 的 Q 瓦片,然后乘以 1/sqrt(d) load(q_reg, _q + (q_blk*NUM_WORKERS + warpid)*q_reg.num_elements, q_reg.cols); mul(q_reg, q_reg, __float2bfloat16(0.125f)); // 温度调整 // 将 flash 注意力 L、M 和 O 寄存器置零。 neg_infty(max_vec); // 为 Q 块清零寄存器 zero(norm_vec); zero(o_reg); // 对已加载的这些 q 迭代 k、v for(auto kv_idx = 0; kv_idx < kv_blocks; kv_idx++) { // 每个线程束将自己的 k、v 块加载到共享内存中 load(v_smem[warpid], _v + (kv_idx*NUM_WORKERS + warpid)*q_reg.num_elements, q_reg.cols); load(k_smem[warpid], _k + (kv_idx*NUM_WORKERS + warpid)*q_reg.num_elements, q_reg.cols); __syncthreads(); // 我们需要确保在开始计算阶段之前所有内存都已加载 // 现在每个线程束遍历所有子瓦片,加载它们,然后执行 flash 注意力内部算法。 for(int subtile = 0; subtile < NUM_WORKERS; subtile++) { load(k_reg, k_smem[subtile]); // 从共享内存加载 k 到寄存器 zero(att_block); // 将 16x16 注意力瓦片置零 mma_ABt(att_block, q_reg, k_reg, att_block); // Q@K.T copy(norm_vec_last, norm_vec); copy(max_vec_last, max_vec); row_max(max_vec, att_block, max_vec); // 累积到 max_vec sub_row(att_block, att_block, max_vec); // 从注意力中减去最大值 —— 现在所有值 <=0 exp(att_block, att_block); // 原地对块进行指数运算。 sub(max_vec_last, max_vec_last, max_vec); // 从旧的最大值中减去新的最大值以找到新的归一化。 exp(max_vec_last, max_vec_last); // 对这个向量进行指数运算 —— 这是我们需要用来归一化的。 mul(norm_vec, norm_vec, max_vec_last); // norm_vec 现在已归一化。 row_sum(norm_vec, att_block, norm_vec); // 将新的注意力块累积到现在已重新缩放的 norm_vec 上 div_row(att_block, att_block, norm_vec); // 现在注意力块已正确归一化 mul(norm_vec_last, norm_vec_last, max_vec_last); // 根据新的最大值归一化先前的 norm vec div(norm_vec_last, norm_vec_last, norm_vec); // 根据新的范数归一化先前的 norm vec copy(att_block_mma, att_block); // 转 换为 bf16 以用于 mma_AB load(v_reg, v_smem[subtile]); // 从共享内存加载 v 到寄存器。 rt_bf_1x4<ducks::rt_layout::col> &v_reg_col = swap_layout_inplace(v_reg); // 这是一个引用,该调用使 v_reg 失效 mul_row(o_reg, o_reg, norm_vec_last); // 在进行 mma_AB 之前预先归一化 o_reg mma_AB(o_reg, att_block_mma, v_reg_col, o_reg); // 使用局部注意力@V 矩阵乘法对 o_reg 进行 mfma。 } __syncthreads(); // 我们需要确保所有线程束都完成后才能开始加载下一个 kv 块 } store(_o + (q_blk*NUM_WORKERS + warpid)*q_reg.num_elements, o_reg, q_reg.cols); // 写出 o。如果 d 设为 constexpr q_reg.rows,编译器在寄存器使用上会有问题 :/ } }
总的来说,这是 58 行代码(不包括空白行),在 RTX 4090 上可以达到约 122 TFLOPs。(理论最大值的 74%)我们将在下一节 ThunderKittens 手册中更仔细地介绍这些原语。
要使用 Thunderkittens,你不需要对 TK 本身做太多操作。它是一个仅包含头文件的库,所以只需克隆仓库,并包含 kittens.cuh。轻松搞定。
但 ThunderKittens 确实使用了许多现代功能,因此它有相当严格的要求。
sudo apt update
sudo apt install gcc-10 g++-10
sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-10 100 --slave /usr/bin/g++ g++ /usr/bin/g++-10
sudo apt update
sudo apt install clang-10
如果你找不到 nvcc,或者遇到环境指向错误 CUDA 版本的问题:
export CUDA_HOME=/usr/local/cuda-12/
export PATH=${CUDA_HOME}/bin:${PATH}
export LD_LIBRARY_PATH=${CUDA_HOME}/lib64:$LD_LIBRARY_PATH
最后,感谢 Jordan Juravsky 整理了一份关于设置兼容 kittens 的 conda 环境的快速文档。
要试验我们现有的 TK 内核,请在 config.py
文件中指定你感兴趣的内核,然后运行 python setup.py install
。
欢迎贡献新的内核!
要验证你的安装,并运行 TK 相当全面的单元测试套件,只需在 tests 文件夹中运行 make -j
。注意:这可能会在编译数千个内核时占用你的电脑一两分钟。
要编译示例,请在根目录运行 source env.src
,然后进入 examples 目录。(许多示例使用 $THUNDERKITTENS_ROOT
环境变量来定位自己并找到 src 目录。)
ThunderKittens 实际上是一个相当小的库,就其提供的功能而言。
尽管它很简单,但如果你不了解底层的工作原理,仍可能会遇到一些棘手的问题。因此,我们建议你在开始编写内核之前好好阅读这份手册 —— 我们保证它不会太长!
为了理解ThunderKittens,首先回顾一下NVIDIA的编程模型如何工作会有所帮助,因为NVIDIA在编写并行代码时提供了几个不同的"作用域"供考虑。
"寄存器"对象存在于线程束级别 -- 它们的内容分布在线程束的各个线程中。寄存器对象包括:
src/register_tile/rt.cuh
中声明为kittens::rt
结构。Kittens提供了一些有用的包装器 -- 例如,可以将32x16行布局的bfloat16寄存器瓦片声明为kittens::rt_bf_2x1;
-- 默认情况下行布局是隐含的。kittens::rt_bf_2x1<>::col_vec;
相比之下,"共享"对象存在于块级别,仅位于共享内存中。所有ThunderKittens函数都遵循一个通用的签名。与汇编语言类似(ThunderKittens本质上是一个抽象的面向瓦片的RISC指令集),每个函数的目标是第一个操作数,源操作数按顺序传递。
例如,如果我们有三个32x64浮点寄存器瓦片:kittens::rt_fl_2x4 a, b, c;
,我们可以对a
和b
进行元素级乘法并将结果存储在c
中,调用如下:kittens::mul(c, a, b);
。
同样,如果我们想将结果存储到共享瓦片__shared__ kittens:st_bf_2x4 s;
中,我们可以类似地写函数:kittens::store(s, c);
。
ThunderKittens努力保护你免受自身错误的影响。特别是,ThunderKittens希望在编译时知道对象的布局,并确保它们在允许你进行操作之前是兼容的。这很重要,因为某些操作的允许布局有微妙之处,如果没有静态检查,很容易出现令人痛苦的静默失败。例如,普通的矩阵乘法要求B操作数采用列布局,而外积则要求B操作数采用行布局。
如果你被告知你认为存在的操作不存在,请仔细检查你的布局 -- 这是最常见的错误。只有在确认后才报告bug :)
默认情况下,ThunderKittens操作存在于线程束级别。换句话说,每个函数期望只由单个线程束调用,并且该单个线程束将完成函数的所有工作。如果将多个线程束分配给相同的工作,将导致未定义行为。(如果操作涉及内存移动,很可能会完全崩溃。)通常,你应该期望你的编程模式涉及在内核开始时使用kittens::warpid()
实例化一个warpid
,并基于该id将任务分配给数据。
然而,并非所有ThunderKittens函数都在线程束级别操作。许多重要操作,特别是WGMMA指令,需要线程束的协作组。这些操作存在于模板kittens::group<collaborative size>
中。例如,wgmma指令可通过kittens::group<4>::mma_AB
(或其别名kittens::warpgroup::mma_AB
)获得。线程束组还可以协作加载共享内存或在共享内存中进行归约。
ThunderKittens中的大多数操作都是纯函数式的。然而,一些操作确实有特殊限制;ThunderKittens试图通过给它们起显眼的名字来警告你。例如,寄存器瓦片转置需要可分离的参数:如果给它相同的底层寄存器作为源和目标,它会静默失败。因此,它被命名为transpose_sep
。
一站式AI创作平台
提供 AI 驱动的图片、视频生成及数字人等功能,助力创意创作
AI办公助手,复杂任务高效处理
AI办公助手,复杂任务高效处理。办公效率低?扣子空间AI助手支持播客生成、PPT制作、网页开发及报告写作,覆盖科研、商业、舆情等领域的专家Agent 7x24小时响应,生活工作无缝切换,提升50%效率!
AI数字人视频创作平台
Keevx 一款开箱即用的AI数字人视频创作平台,广泛适用于电商广告、企业培训与社媒宣传,让全球企业与个人创作者无需拍摄剪辑,就能快速生成多语言、高质量的专业视频。
AI辅助编程,代码自动修复
Trae是一种自适应的集成开发环 境(IDE),通过自动化和多元协作改变开发流程。利用Trae,团队能够更快速、精确地编写和部署代码,从而提高编程效率和项目交付速度。Trae具备上下文感知和代码自动完成功能,是提升开发效率的理想工具。
AI小说写 作助手,一站式润色、改写、扩写
蛙蛙写作—国内先进的AI写作平台,涵盖小说、学术、社交媒体等多场景。提供续写、改写、润色等功能,助力创作者高效优化写作流程。界面简洁,功能全面,适合各类写作者提升内容品质和工作效率。
全能AI智能助手,随时解答生活与工作的多样问题
问小白,由元石科技研发的AI智能助手,快速准确地解答各种生活和工作问题,包括但不限于搜索、规划和社交互动,帮助用户在日常生活中提高效率,轻松管理个人事务。
实时语音翻译/同声传译工具
Transly是一个多场景的AI大语言模型驱动的同声传译、专业翻译助手,它拥有超精准的音频识别翻译能力,几乎零延迟的使用体验和支持多国语言可以让你带它走遍全球,无论你是留学生、商务人士、韩剧美剧爱好者,还是出国游玩、多国会议、跨国追星等等,都可以满足你所有需要同传的场景需求,线上线下通用,扫除语言障碍,让全世界的语言交流不再有国界。
一键生成PPT和Word,让学习生活更轻松
讯飞智文是一个利用 AI 技术的项目,能够帮助用户生成 PPT 以及各类文档。无论是商业领域的市场分析报告、年度目标制定,还是学生群体的职业生涯规划、实习避坑指南,亦或 是活动策划、旅游攻略等内容,它都能提供支持,帮助用户精准表达,轻松呈现各种信息。
深度推理能力全新升级,全面对标OpenAI o1
科大讯飞的星火大模型,支持语言理解、知识问答和文本创作等多功能,适用于多种文件和业务场景,提升办公和日常生活的效率。讯飞星火是一个提供丰富智能服务的平台,涵盖科技资讯、图像创作、写作辅助、编程解答、科研文献解读等功能,能为不同需求的用户提供便捷高效的帮助,助力用户轻松获取信息、解决问题,满足多样化使用场景。
一种基于大语言模型的高效单流解耦语音令牌文本到语音合成模型
Spark-TTS 是一个基于 PyTorch 的开源文本到语音合成项目,由多个知名机构联合参与。该项目提供了高效的 LLM(大语言模型)驱动的语音合成方案,支持语音克隆和语音创建功能,可通过命令行界面(CLI)和 Web UI 两种方式使用。用户可以根据需求调整语音的性别、音高、速度等参数,生成高质量的语音。该项目适用于多种场景,如有声读物制作、智能语音助手开发等。