斯坦福让“GPU高速运转”的新工具火了,比FlashAttention2更快

西风 发自 凹非寺量子位 | 公众号 QbitAI
AI算力资源越发紧张的当下,斯坦福新研究将GPU运行效率再提升一波——
内核只有100行代码,让H100比使用FlashAttention-2,性能还要提升30%。
怎么做到的?
研究人员从“硬件实际需要什么?如何满足这些需求?”这两个问题出发,设计了 一个嵌入式CUDA DSL工具,名为ThunderKittens(暂且译为雷猫)。
雷猫可简化AI内核的编写,同时充分利用底层硬件能力。

具体来说,雷猫的主要抽象是寄存器和共享内存中的小型张量块(tile),和目前GPU中对小矩阵乘法的优化相匹配。
通过操作这些tile,开发者可相对简单地编写代码,充分利用张量核心、异步数据传输和共享内存等硬件特性。
使用雷猫实现的注意力机制内核,代码量少且能实现很高的硬件利用率,性能超过直接使用底层库(如Cutlass)。
详细讨论过程以及雷猫是怎么设计出的,研究人员以“GPUs Go Brrr”为题,发在了斯坦福Hazy Research的Blog网站上。

网友们对此讨论也十分热烈。
有网友表示读这篇Blog时,让他想起了初次了解超标量CPU架构时的惊讶感受:
GPU真的达到了新高度。

还有网友表示:
这篇文章重新点燃了我在CS 149并行编程课中所感受到的快乐。

H100里有什么?
斯坦福研究人员以H100为例,探讨了优化GPU的方法。
首先,回顾一下H100的硬件细节,这对于接下来的讨论非常重要。

一个H100 SXM GPU包含:
(1)80GB的HBM3内存,带宽为3TB/s(实际带宽略低)。
(2)50MB的L2缓存,带宽为12TB/s,在GPU上分为两个25MB的部分,通过交叉开关连接(这个交叉开关表现不佳)。
(3)132个流式多处理器(SM),每个包含:
高达227KB的共享内存位于256KB的L1缓存中(这些加起来的带宽大约33TB/s)。
一个张量内存加速器(TMA)——这是英伟达Hopper架构中的一种新硬件组件,可进行异步地址生成和内存获取,还能促进片上内存网络。
4个子单元,每个含:一个warp scheduler;512个向量寄存器(每个包含32个4字节的词);一个用于执行矩阵乘法的张量核心;一组内置指令,如求和、乘法等,这些指令能够并行操作这些向量寄存器。
除了这些,一个GPU还包括内存控制器、指令缓存……但对于这项研究而言不重要。
重要的是,所有的计算都发生在流式多处理器中,大部分计算是在寄存器中。
H100 GPU拥有989 TFLOPs的半精度矩阵乘法计算能力,以及约60 TFLOPs的“其他”计算能力。因此,每个周期内张量核心被使用时,至少能达到94%的硬件利用率。而张量核心不被使用时,硬件的利用率不会超过6%。
换句话说:
H100的利用率=张量核心活跃周期的百分比+/- 6%。

所以要充分发挥H100的能力,关键是保持张量核心持续运算。
榨干H100,要注意什么?
然鹅,要保持张量核心持续运行并不容易。
研究人员发现GPU硬件具有一些特性,对于保持矩阵乘法的运行非常重要:
WGMMA指令虽然是必要的,但使用起来颇为麻烦。
共享内存的速度并不如预期的快,使用时还需格外注意。
生成地址的成本较高。
保持高占用率对于提升性能是有益的,寄存器至关重要。
这些特性在非H100 GPU上也有所适用,在H100上更加典型,就拿RTX 4090来说,相比H100处理起来简单得多。

所以接下来还是以H100为例,展开探讨这几点特性。
WGMMA指令
H100引入了一套新的指令集,名为“warp group matrix multiply accumulate”(在PTX中为wgmma.mma_async,在SASS中为HGMMA/IGMMA/QGMMA/BGMMA)。
要理解这些指令的特点,需回顾以往张量核心的使用方式。
早期GPU中的张量核心指令如wmma.mma.sync和mma.sync,要求SM一个子单元内的32个线程的一个warp同步传输数据块至张量核心并等待结果。
wgmma.mma_async指令则不同。它允许128个连续线程跨SM所有子单元协作同步,并从共享内存及寄存器(可选)异步启动矩阵乘法。这使得这些warp在等待矩阵乘法结果时可以处理其他任务。
研究人员通过微观基准测试,发现这些指令是充分发挥H100计算能力所必需的。没有这些指令,GPU的峰值利用率大约只有63%。
他们推测,这是由于张量核心需要从本地资源维持一个深度硬件pipeline。
然而,这些指令的内存布局极其复杂。未重排的共享内存布局合并性差,需要额外的L2带宽。重排的内存布局记录不准确,研究人员花费了大量时间才弄明白。

最终发现,这些布局只适用于特定矩阵形状,并与wgmma.mma_async指令的其他部分不兼容,例如硬件仅在未重排的布局下转置子矩阵。
此外,未重排的wgmma布局内存合并性差且有bank conflicts。尽管TMA和L2缓存在如flash attention这类内核上能较好地掩盖这些问题,但要充分利用硬件,必须精心控制内存请求的合并和避免bank conflicts。
尽管有这些问题,但这些指令对于充分利用H100是必不可少的。没有它们,GPU的潜在性能就损失了37%。
共享内存
共享内存的单次访问延迟约为30个周期(这也与研究人员观察的相符),这看似不多,但在这段时间内,SM的张量核心几乎能完成两次完整的32x32方阵乘法。
以前的研究,如Flash Attention,研究人员更多关注的是HBM-SRAM的瓶颈。但随着HBM速度的提升和张量核心的快速发展,即使是共享内存的相对较小延迟也变得尤为关键。
由于共享内存被分为32个独立的存储单元,处理不当可能会引发bank conflicts,即同一个内存bank同时被多个请求访问,这种情况会导致请求被序列化。研究人员实验后认为,这会显著拖慢内核速度,且wgmma与mma指令需要的寄存器布局容易受到bank conflicts的影响。
解决方法是通过各种“重排”模式调整共享内存的配置,避免bank conflicts,但细节要处理得当。
此外研究人员发现,尽可能避免在寄存器和共享内存之间的移动数据非常重要。可能的话,可使用内置硬件(如wgmma和TMA指令)进行异步数据传输。实在没法子了,再使用warp进行同步数据传输。
地址生成
H100还有一个有趣的特性,其张量核心和内存都足够快,以至于仅生成用于获取数据的内存地址就占用了芯片的大量资源,特别是加入复杂的交错或重排模式时,这种情况更为明显。
研究人员表示,英伟达提供了张量内存加速器(TMA),似乎就是已经意识到了这个问题。
TMA允许用户在全局和共享内存中指定多维张量布局,命令其异步提取张量的一部分,并在完成后触发一个屏障。这大大节省了地址生成的开销,并简化了pipelines的构建。
研究人员认为,TMA对于充分发挥H100的潜力至关重要,可能比wgmma.mma_async更为关键。
它不仅节省了寄存器资源和指令派发,还提供了如异步在全局内存上执行归约等实用功能——这在处理复杂的反向内核时尤其有用。
虽然TMA的重排模式解读有一定难度,需要进行一些逆向工程,但研究人员表示,相比之下,他们在这上面遇到的问题要少得多。
占用率
占用率指的是在GPU的相同执行硬件上同时调度的线程数。每个周期,SM的某一子单元的warp scheduler会尝试向准备就绪的warp线程发出指令。
研究人员认为,英伟达采用这种模型可以更容易地保持硬件的满负荷运行。例如,当一个线程warp等待执行矩阵乘法时,另一个可以被指派执行使用快速指数运算的指令。
在某些方面,H100对占用率的依赖程度低于前几代硬件。
它的异步特性使得即使单一指令流也能使多个硬件部分同时持续运行,包括读取内存、执行矩阵乘法、进行共享内存的归约,同时还能在寄存器上进行计算。
但高占用率容易隐藏缺陷或同步问题,一个设计良好的pipeline即使在占用率不高的情况下也能运行得相当快。
据研究人员观察,英伟达在设计GPU时确实考虑到了占用率。且由于存在足够多的同步操作和足够多的错误可能性,根据他们的经验,提高占用率通常能显著增加硬件的实际利用率。
此外,相比H100,A100和RTX 4090更依赖同步指令调度,占用率更重要。
用雷猫优化GPU
鉴于以上情况,如何才能更轻松地编写所需的内核类型,同时充分发挥硬件的全部潜力?
雷猫(ThunderKittens)登场了。
这是一个嵌入在CUDA中的DSL,本是斯坦福研究人员设计出来给自己内部使用的,后来发现还真挺好使。
Ps:起这么个名,一是他们觉得小猫很可爱,二来他们觉得大伙儿在代码中输入kittens::会很有趣。
具体来说,雷猫包含四种模板类型:
寄存器tiles:在寄存器文件上表示二维张量。
寄存器向量:在寄存器文件上表示一维张量。
共享tiles:在共享内存中表示二维张量。
共享向量:在共享内存中表示一维张量。
tiles通过高度、宽度和布局进行参数化;寄存器向量通过长度和布局进行参数化;而共享向量仅通过长度进行参数化,通常不会遇到bank conflicts问题。
此外,研究人员提供了一系列操作来处理这些张量,既可在warp级别使用,也可用于多个warp协作,包含初始化器,如将共享向量清零;一元操作,如exp;二元操作,如mul;行/列操作,例如行求和。
雷猫作为一个嵌入到CUDA中的库,其提供的抽象层在遇到不支持的功能时能够很好地处理。如果雷猫缺少某些功能,可以直接扩展它来实现你想要的效果。
以Tri的flash attention算法为例,在实际应用中,即使是使用英伟达的Cutlass库,实现起来也是相当复杂。
以下是一个在RTX 4090上使用雷猫编写的简单flash attention内核的示例。
总共约60行CUDA代码,硬件利用率达到了75%。代码复杂性主要在于算法本身,而非交织模式或寄存器布局。
#define NUM_WORKERS 16 // This kernel uses 16 workers in parallel per block, to help issue instructions more quickly.using namespace kittens; // this kernel only handles headdim=64 for simplicity. Also n should be a multiple of 256 here.__global__ void attend_ker64(int n, const bf16* __restrict__ __q__, const bf16* __restrict__ __k__, const bf16* __restrict__ __v__, bf16* __o__) {    auto warpid        = kittens::warpid();    auto block_start   = blockIdx.x*(n*64);    const bf16 *_q = __q__ + block_start, *_k = __k__ + block_start, *_v = __v__ + block_start;          bf16 *_o = __o__ + block_start;    extern __shared__ alignment_dummy __shm[]; // this is the CUDA shared memory    shared_allocator al((int*)&__shm[0]);    // K and V live in shared memory -- this is about all that will fit.    st_bf_1x4<ducks::st_layout::swizzle> (&k_smem)[NUM_WORKERS] = al.allocate<st_bf_1x4<ducks::st_layout::swizzle>, NUM_WORKERS>();    st_bf_1x4<ducks::st_layout::swizzle> (&v_smem)[NUM_WORKERS] = al.allocate<st_bf_1x4<ducks::st_layout::swizzle>, NUM_WORKERS>();    // Initialize all of the register tiles.    rt_bf_1x4<> q_reg, k_reg, v_reg; // v_reg need to be swapped into col_l    rt_fl_1x1<> att_block;    rt_bf_1x1<> att_block_mma;    rt_fl_1x4<> o_reg;    rt_fl_1x1<>::col_vec max_vec_last, max_vec; // these are column vectors for the attention block    rt_fl_1x1<>::col_vec norm_vec_last, norm_vec; // these are column vectors for the attention block    int qo_blocks = n / (q_reg.rows*NUM_WORKERS), kv_blocks = n / (q_reg.rows*NUM_WORKERS);    for(auto q_blk = 0; q_blk < qo_blocks; q_blk++) {        // each warp loads its own Q tile of 16x64, and then multiplies by 1/sqrt(d)        load(q_reg, _q + (q_blk*NUM_WORKERS + warpid)*q_reg.num_elements, q_reg.cols);        mul(q_reg, q_reg, __float2bfloat16(0.125f)); // temperature adjustment        // zero flash attention L, M, and O registers.        neg_infty(max_vec); // zero registers for the Q chunk        zero(norm_vec);        zero(o_reg);        // iterate over k, v for these q's that have been loaded        for(auto kv_idx = 0; kv_idx < kv_blocks; kv_idx++) {            // each warp loads its own chunk of k, v into shared memory            load(v_smem[warpid], _v + (kv_idx*NUM_WORKERS + warpid)*q_reg.num_elements, q_reg.cols);            load(k_smem[warpid], _k + (kv_idx*NUM_WORKERS + warpid)*q_reg.num_elements, q_reg.cols);            __syncthreads(); // we need to make sure all memory is loaded before we can begin the compute phase            // now each warp goes through all of the subtiles, loads them, and then does the flash attention internal alg.            for(int subtile = 0; subtile < NUM_WORKERS; subtile++) {                load(k_reg, k_smem[subtile]); // load k from shared into registers                zero(att_block); // zero 16x16 attention tile                mma_ABt(att_block, q_reg, k_reg, att_block); // Q@K.T                copy(norm_vec_last, norm_vec);                copy(max_vec_last,  max_vec);                row_max(max_vec, att_block, max_vec); // accumulate onto the max_vec                sub_row(att_block, att_block, max_vec); // subtract max from attention -- now all <=0                exp(att_block, att_block); // exponentiate the block in-place.                sub(max_vec_last, max_vec_last, max_vec); // subtract new max from old max to find the new normalization.                exp(max_vec_last, max_vec_last); // exponentiate this vector -- this is what we need to normalize by.                mul(norm_vec, norm_vec, max_vec_last); // and the norm vec is now normalized.                row_sum(norm_vec, att_block, norm_vec); // accumulate the new attention block onto the now-rescaled norm_vec                div_row(att_block, att_block, norm_vec); // now the attention block is correctly normalized                mul(norm_vec_last, norm_vec_last, max_vec_last); // normalize the previous norm vec according to the new max                div(norm_vec_last, norm_vec_last, norm_vec); // normalize the previous norm vec according to the new norm                copy(att_block_mma, att_block); // convert to bf16 for mma_AB                load(v_reg, v_smem[subtile]); // load v from shared into registers.                rt_bf_1x4<ducks::rt_layout::col> &v_reg_col = swap_layout_inplace(v_reg); // this is a reference and the call has invalidated v_reg                mul_row(o_reg, o_reg, norm_vec_last); // normalize o_reg in advance of mma_AB'ing onto it                mma_AB(o_reg, att_block_mma, v_reg_col, o_reg); // mfma onto o_reg with the local attention@V matmul.            }            __syncthreads(); // we need to make sure all warps are done before we can start loading the next kv chunk        }        store(_o + (q_blk*NUM_WORKERS + warpid)*q_reg.num_elements, o_reg, q_reg.cols); // write out o. compiler has an issue with register usage if d is made constexpr q_reg.rows :/    }}
关于TMA、WGMMA、交织模式和描述符的复杂性,这里展示了一个使用雷猫编写的,针对H100的FlashAttention-2算法的前向传递示例。
template<int D>__global__  __launch_bounds__((NUM_WORKERS)*kittens::WARP_THREADS, 2)void fwd_attend_ker_dim(int N, const CUtensorMap* tma_q, const CUtensorMap* tma_k, const CUtensorMap* tma_v, CUtensorMap* tma_o) {    extern __shared__ int __shm[]; // this is the CUDA shared memory    tma_swizzle_allocator al((int*)&__shm[0]);    constexpr int tile_width = fwd_attend_ker_tile_dims<D>::tile_width; // constants    constexpr int qo_height  = fwd_attend_ker_tile_dims<D>::qo_height;    constexpr int kv_height  = fwd_attend_ker_tile_dims<D>::kv_height;    st_bf<qo_height, tile_width, layout_q>          (&q_smem)   [NUM_WARPGROUPS] = al.allocate<st_bf<qo_height, tile_width, layout_q>,          NUM_WARPGROUPS>();    st_bf<kv_height, tile_width, layout_k>          (&k_smem)[2][NUM_WORKERS_KV] = al.allocate<st_bf<kv_height, tile_width, layout_k>, 2,       NUM_WORKERS_KV>();    st_bf<kv_height, tile_width, layout_v>          (&v_smem)[2][NUM_WORKERS_KV] = al.allocate<st_bf<kv_height, tile_width, layout_v>, 2,       NUM_WORKERS_KV>();    int tic = 0, toc = 1;    rt_fl<1, kv_height> att_block;    rt_bf<1, kv_height> att_block_mma;    rt_fl<1, qo_height> o_prev;    col_vec<rt_fl<1, kv_height>> max_vec_last, max_vec;    col_vec<rt_fl<1, kv_height>> norm_vec_last, norm_vec;    int warpid      = kittens::warpid();    int warpgroupid = warpid/kittens::WARPGROUP_WARPS;    int kv_blocks = N / (NUM_WORKERS_KV*k_smem[0][0].rows);    __shared__ uint64_t qsmem_barrier, kvsmem_barrier;//, vsmem_barrier;    int q_phasebit = 0;    int kv_phasebit = 0;    if (threadIdx.x == 0) {        tma::init_barrier<st_bf<qo_height, tile_width, layout_q>, NUM_WARPGROUPS>(qsmem_barrier, 1);        tma::init_barrier<st_bf<kv_height, tile_width, layout_k>, NUM_WORKERS_KV*2>(kvsmem_barrier, 1);     }    if (warpid == 0) {        for (int wg = 0; wg < NUM_WORKERS/kittens::WARPGROUP_WARPS; wg++) { // load q            int tile_idx = (blockIdx.y * NUM_WARPGROUPS * gridDim.x) + (blockIdx.x * NUM_WARPGROUPS) + wg;            tma::load_async((q_smem[wg]), tma_q, qsmem_barrier, tile_idx);         }        for (int w = 0; w < NUM_WORKERS_KV; w++) { // load k, v                  int tile_idx = (blockIdx.y * NUM_WORKERS_KV * kv_blocks) + (0 * NUM_WORKERS_KV) + w;             tma::load_async((k_smem[tic][w]), tma_k, kvsmem_barrier, tile_idx);             tma::load_async((v_smem[tic][w]), tma_v, kvsmem_barrier, tile_idx);         }    }    neg_infty(max_vec); // zero registers for the Q chunk    zero(norm_vec);    zero(o_prev);    __syncthreads();    tma::arrive_and_wait(qsmem_barrier, q_phasebit);    q_phasebit ^= 1;    if constexpr (D == 64) { warpgroup::mul(q_smem[warpgroupid], q_smem[warpgroupid], __float2bfloat16(0.125f)); }     else { warpgroup::mul(q_smem[warpgroupid], q_smem[warpgroupid], __float2bfloat16(0.08838834764f)); }    for (auto kv_idx = 0; kv_idx < kv_blocks; kv_idx++, tic ^= 1, toc ^= 1) {        tma::arrive_and_wait(kvsmem_barrier, kv_phasebit);        kv_phasebit ^= 1;        __syncthreads();        if (warpid == 0) {            tma::set_bytes(kvsmem_barrier, 2 * NUM_WORKERS_KV * k_smem[0][0].num_elements * sizeof(bf16));            if (kv_idx + 1 < kv_blocks) {                for (int w = 0; w < NUM_WORKERS_KV; w++) {                            int tile_idx = (blockIdx.y * NUM_WORKERS_KV * kv_blocks) + ((kv_idx + 1) * NUM_WORKERS_KV) + w;                     tma::load_async((k_smem[toc][w]), tma_k, kvsmem_barrier, tile_idx);                     tma::load_async((v_smem[toc][w]), tma_v, kvsmem_barrier, tile_idx);                }            }        }        warpgroup::mma_fence(att_block);        warpgroup::mm_ABt(att_block, q_smem[warpgroupid], k_smem[tic][0]);        warpgroup::mma_commit_group();        copy(norm_vec_last, norm_vec);        copy(max_vec_last,  max_vec);        warpgroup::mma_async_wait();        row_max(max_vec, att_block, max_vec); // accumulate onto the max_vec        sub_row(att_block, att_block, max_vec);        exp(att_block, att_block);        sub(max_vec_last, max_vec_last, max_vec);        exp(max_vec_last, max_vec_last);        mul(norm_vec, norm_vec, max_vec_last);        row_sum(norm_vec, att_block, norm_vec); // accumulate onto the norm_vec        div_row(att_block, att_block, norm_vec);        mul(norm_vec_last, norm_vec_last, max_vec_last);        div(norm_vec_last, norm_vec_last, norm_vec);        copy(att_block_mma, att_block); // convert to bf16 for mma        mul_row(o_prev, o_prev, norm_vec_last); // normalize o_prev in advance of mma'ing onto it        warpgroup::mma_fence(o_prev);        warpgroup::mma_AB(o_prev, att_block_mma, v_smem[tic][0]);        warpgroup::mma_commit_group();    }    auto (*o_smem) = reinterpret_cast<st_bf<qo_height, tile_width, layout_o>(*)>(q_smem); // reuse q memory    warpgroup::store(o_smem[warpgroupid], o_prev);     __syncthreads();    if (warpid % 4 == 0) { // store o        int tile_idx = (blockIdx.y * NUM_WARPGROUPS * gridDim.x) + (blockIdx.x * NUM_WARPGROUPS) + warpgroupid;        tma::store_async(tma_o, (o_smem[warpgroupid]), tile_idx);         tma::store_commit_group();     }    tma::store_async_wait();}
那么,它的表现如何?
这个内核只有100行代码,实际上它在H100上的性能比FlashAttention-2高出约30%。雷猫负责包装布局和指令,提供了一个可以在GPU上使用的迷你pytorch环境。

△FA2(通过Pytorch实现)与TK在H100 SXM上的多种配置比较
此外,研究人员还发布了基于线性注意力和其他新架构的内核。其中基于线性注意力的内核的运行速度可达215 TFLOPs,如果考虑到算法中固有的重计算,速度可超过300 TFLOPs。
尽管线性注意力在理论上效率更高,但此前在实际硬件上表现并不佳。因此,研究人员认为这可能促进一系列高吞吐量应用的发展。

small tile符合AI和硬件发展趋势
最后,雷猫研究团队总结了开发雷猫的一些思考。在他们看来,雷猫之所以有效,是因为它的目标并不是试图做所有事:
CUDA的确比雷猫表达能力更广,雷猫小而简单,功能有限。但雷猫的small tiles抽象设计符合AI和硬件的发展趋势。
虽然雷猫不支持小于16的维度,但研究人员认为这并不重要,因为硬件也不倾向于支持过小的维度。
如果你的矩阵乘法小于16x16,你确定你正在做的是AI吗?
从理论出发,研究人员认为需要进行一种框架转变。
“寄存器当然不应该像旧CPU那样32位字。CUDA使用的1024位宽向量寄存器确实是朝着正确方向迈出的一步。但对我们来说,寄存器是16x16的数据tile。我们认为AI需要这样的设计,毕竟,它仍然只是矩阵乘法、归约和重塑。我们认为硬件也需要这样的设计,小型矩阵乘法迫切需要超出系统级MMA的硬件支持。”
研究人员认为,应该根据硬件特性来重新定义AI的设计理念。例如,循环状态应该有多大?应该足够大以适应一个SM。计算的密度应该有多高?不应低于硬件的需求。
我们未来工作的一个重要方向是利用我们对硬件的了解来帮助我们设计与之匹配的AI。
参考链接:[1]https://hazyresearch.stanford.edu/blog/2024-05-12-tk[2]https://github.com/HazyResearch/ThunderKittens[3]https://news.ycombinator.com/item?id=40337936
— 完 —
量子位年度AI主题策划正在征集中!
欢迎投稿专题 一千零一个AI应用,365行AI落地方案
或与我们分享你在寻找的AI产品,或发现的AI新动向

点这里👇关注我,记得标星哦~
一键三连「分享」、「点赞」和「在看」
科技前沿进展日日相见 ~ 

到顶部