Hub Part II · Attention 算子
Part II · Attention 算子

把 softmax 装进 SRAM

Attention 是整个 transformer 里唯一跟序列长度 $S$ 平方相关的算子。 所有人想方设法把它的 $O(S^2)$ 显存和 $O(S^2)$ 内存读写降下来—— FlashAttention 用 tile + online softmax 把实际显存压成 $O(S)$, PagedAttention 把 KV 当虚拟内存管理, Ring Attention 把 sequence 切到多卡接力。这是过去三年最重要的算子工程。

FlashAttention v1 · NeurIPS 22 FlashAttention-2 2023-07 FlashAttention-3 NeurIPS 24 · Hopper FlexAttention PyTorch 2024 FlashDecoding 2023 · decode 专用 FlashDecoding++ 2024 · 平面化 KV FlashInfer vLLM 后端 SageAttention INT8 attn · 2024 PagedAttention SOSP 23 · vLLM Ring Attention ICLR 24 · 长上下文 Mamba-2 / NSA SSM · sparse · 2024-25

§2.1FlashAttention v1 · 把 softmax 装进 SRAM

FlashAttention — Fast and Memory-Efficient Exact Attention with IO-Awareness
NeurIPS 2022 Dao, Fu, Ermon, Rudra, Ré · Stanford · arXiv:2205.14135 · code

朴素 attention 必须把 $S \times S$ 的 score 矩阵物化在 HBM 上, 一来显存 $O(S^2)$,二来读 / 写两次 HBM 把内核打回 memory-bound—— 长上下文上谁都跑不动。

关键想法:把 $Q$、$K$、$V$ 切成 tile, 内层循环里一次只把一对 $(Q_i, K_j, V_j)$ tile 加载到 SM 的 shared memory, 用 online softmax(Milakov & Gimelshein 2018) 在 tile 之间维护两个标量 $(m_i, \ell_i)$—— running max 与 running sum——边算边归一化。 最后输出 $O$ 直接累加到 HBM,从不把 $S\times S$ 物化。

在线 softmax · 一行不变的恒等式

给定已扫的旧最大值 $m$、旧分母 $\ell$ 与新一项 $s$,新值:

$$ m' = \max(m, s),\quad \ell' = \ell \cdot e^{m - m'} + e^{s - m'} $$

数学上 $\ell$ 和 $\sum e^{s_j - m'}$ 相等,所以可以拆任意多个块来算。 把 $V$ 那一侧的"加权和"也照样 rescale 即可。这个恒等式是整篇论文 的核心;剩下的工程都是怎么把它写进一个 GPU kernel。

教学版前向

# 教学版 FlashAttention 前向(PyTorch + Numpy 风格)
def flash_attn(Q, K, V, BQ=64, BK=64):
    N, d = Q.shape
    O = torch.zeros_like(Q)
    m = torch.full((N,), -float('inf'))
    l = torch.zeros((N,))
    for i in range(0, N, BQ):
        Qi = Q[i:i+BQ]
        Oi = torch.zeros((BQ, d))
        mi = torch.full((BQ,), -float('inf'))
        li = torch.zeros((BQ,))
        for j in range(0, N, BK):
            Kj, Vj = K[j:j+BK], V[j:j+BK]
            Sij = (Qi @ Kj.T) / d**0.5
            mij = Sij.max(dim=-1).values
            Pij = torch.exp(Sij - mij[:, None])
            mi_new = torch.maximum(mi, mij)
            scale  = torch.exp(mi - mi_new)
            li     = li * scale + Pij.sum(dim=-1)
            Oi     = Oi * scale[:, None] + Pij @ Vj
            mi     = mi_new
        O[i:i+BQ] = Oi / li[:, None]
    return O

Triton 视角 · 真实 kernel 的骨架

# 简化的 Triton 内核骨架(FA-2 风格,省略 mask/causal)
@triton.jit
def fa_fwd(Q_ptr, K_ptr, V_ptr, O_ptr,
           N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, D: tl.constexpr):
    pid_m = tl.program_id(0)           # 一个 program 处理 BLOCK_M 个 Q-行
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, D)
    # 把整块 Qi 一次 load 到 SRAM(不再回写)
    Qi = tl.load(Q_ptr + offs_m[:, None]*D + offs_d[None, :])
    mi = tl.full([BLOCK_M], -float('inf'), dtype=tl.float32)
    li = tl.zeros([BLOCK_M], dtype=tl.float32)
    Oi = tl.zeros([BLOCK_M, D], dtype=tl.float32)
    for start_n in range(0, N, BLOCK_N):
        offs_n = start_n + tl.arange(0, BLOCK_N)
        Kj = tl.load(K_ptr + offs_n[:, None]*D + offs_d[None, :])
        Vj = tl.load(V_ptr + offs_n[:, None]*D + offs_d[None, :])
        S = tl.dot(Qi, tl.trans(Kj)) * (1.0 / tl.sqrt(D))
        mij = tl.max(S, axis=1)
        mi_new = tl.maximum(mi, mij)
        Pij = tl.exp(S - mi_new[:, None])
        scale = tl.exp(mi - mi_new)
        li = li * scale + tl.sum(Pij, axis=1)
        Oi = Oi * scale[:, None] + tl.dot(Pij.to(Vj.dtype), Vj)
        mi = mi_new
    Oi = Oi / li[:, None]
    tl.store(O_ptr + offs_m[:, None]*D + offs_d[None, :], Oi)
Demo 2 · 在线 softmax + tile 扫描
左侧的 $S = QK^\top$ 矩阵被一行一行 tile 扫描; 右侧实时显示当前 Q-block 每行的 running max $m$ 与 sum $\ell$。 每扫完一整行,$O$ 自动归一化——结果与逐元素 softmax 精确相等。 把 tile 调大 = 更少 kernel launch + 更高 occupancy,但需要的 shared memory 更多—— 这就是 H100 上 FlashAttention 选 BQ=64/128 的取舍。

v1 在 A100 上把 GPT-2 1.5B 训练里 attention 的耗时从 50% 砍到 5%, 内存 $O(S^2) \to O(S)$。从此长上下文成为可能。

§2.2FlashAttention v2 · 把工作搬到正确的轴上

FlashAttention-2 — Faster Attention with Better Parallelism and Work Partitioning
2023-07 Dao · arXiv:2307.08691

v1 把外层循环放在 K-block 上、内层放 Q-block。 每个 K-block 内部所有 warp 都要写同一个 $O_{i}$, 引入大量 atomic add + 仅在 head 与 batch 维并行—— 长上下文 / 小 batch 时 SM 没吃饱。

关键想法

  1. 外层循环改在 Q-block,每个 thread block 独占一段 query → 无需 atomic。
  2. 在 seq 维加一层并行(thread block 维度 = $\lceil S / B_Q \rceil$), 长序列时 SM 终于吃饱。
  3. 减少非 matmul 操作:rescale 时把除法挪到 epilogue。

v2 在 A100 上把前向打到 ~225 TFLOPS(peak 312), 比 v1 快约 2×。所有现代框架(vLLM、SGLang、TRT-LLM、TGI、TF Transformer Engine)都默认用 v2 backward。

§2.3FlashAttention v3 · Hopper / FP8 / 异步

FlashAttention-3 — Fast and Accurate Attention with Asynchrony and Low-Precision
NeurIPS 2024 Shah, Bikshandi, Zhang, Thakkar, Ramani, Dao · arXiv:2407.08608

关键想法:充分利用 H100 上新东西——

  • WGMMA(warpgroup matmul):tensor core 异步发射,CPU 发完命令就走,不阻塞。
  • TMA(tensor memory accelerator):搬数据交给 DMA 引擎, SM 不用占线程做拷贝。
  • warp specialization:4 个 warp 专门 producer 搬数据、 其余做 consumer 算 matmul,pipeline 自然形成。
  • FP8:score / output 走 e4m3, soft-attention sink + block-wise scaling 防溢出, 把吞吐再 ×1.5。

H100 FP16 上从 v2 的 ~340 TFLOPS 升到 ~690 TFLOPS(peak 989), FP8 上 ~1.2 PFLOPS。Llama-3-70B prefill 一次推到 4× 实测加速。

派生分支 · 让 FlashAttention 适应 "怪" mask

变种解决的问题
FlashDecoding (Dao 2023)decode 阶段只有 1 个 query;切 KV 在 seq 维并行,比 v2 快 8×。
FlashDecoding++ (Hong 2024)统一 softmax 归一化常数 + 平面化 KV 减少 atomic。
FlashAttention-2 with sliding window支持 Mistral 的 4096 滑窗,把 mask 当成 tile-skipping 条件。
FlexAttention (PyTorch 2024)把 mask 当一阶函数喂给 compiler,编一个定制 kernel — 一行 Python = 一个 SoTA kernel。
FlashInfer专做推理:paged KV + 多种 mask,是 vLLM/TGI 的默认后端之一。
SageAttention / SageAttention-2INT8/INT4 attention:把 QK 走 INT8 GEMM,softmax 走 fp16,~30% 加速近无损。

§2.4PagedAttention · 把 KV 当虚拟内存

vLLM / PagedAttention — Efficient Memory Management for LLM Serving
SOSP 2023 Kwon, Li, Zhuang, Sheng, Zheng, Yu, Gonzalez, Zhang, Stoica · UC Berkeley · arXiv:2309.06180 · code

传统 serving 把每个请求的 KV cache 当连续大块预分配。 不同请求的输出长度差异巨大,"预定一个最坏情况" → 显存碎片 60%+, batch size 上不去。

关键想法:把 KV cache 按固定 16 token 大小的 page 分块管理,每个序列维护一份 page table(就像操作系统的 4 KB 页表)。 分配 = 找空 page;释放 = 还 page。 碎片几乎为 0,同一段 prompt 的多个 sample(beam / parallel sampling)可以共享同一份 KV pages, 仅 copy-on-write 那一格。

在 OPT-13B 上 throughput ×4 vs HuggingFace TGI 当时版本。 PagedAttention 已经是事实标准—— 几乎所有 serving 框架都重新实现了一份。

# Page table 长这样(最小化伪代码)
class Sequence:
    block_table: list[int]   # 物理 block 编号
    seq_len:     int

# 当 seq 推进一个 token:
if seq.seq_len % BLOCK == 0:
    seq.block_table.append(allocator.alloc())
write_kv(physical_blocks[seq.block_table[-1]], offset=seq.seq_len % BLOCK)

# Attention kernel 接受 block_table,做 indirect address lookup
attn_output = paged_attention_kernel(
    q, k_cache_phys, v_cache_phys, block_table, seq_len)
系统视角 · Page 是 LLM 服务的"4KB"

PagedAttention 借的就是操作系统教科书。Block 大小 16 经过了仔细权衡: 太小 → page table 长,attention kernel 间接寻址 overhead 大; 太大 → 内部碎片回来。SGLang / TRT-LLM 一度尝试过 32 / 64, 但 16 仍是大多数框架的默认。

§2.5Ring Attention · 跨卡接力,长上下文不再卡死

Ring Attention — Ring Attention with Blockwise Transformers for Near-Infinite Context
ICLR 2024 Liu, Zaharia, Abbeel · UC Berkeley · arXiv:2310.01889

关键想法:把 sequence 切到 $P$ 张卡上 → 每张卡只持有 $S/P$ 个 token 的 $Q_i, K_i, V_i$。让 $K, V$ tile 环形传一圈, 每张卡在收到来自邻居的 KV 时,立刻和自己的 $Q$ 做一段 FlashAttention 累加。 一次 forward 一共 $P$ 步通信, 每步可以与计算 overlap,等价于一张 巨型卡装下整段 attention。

Ring 让 1M-token 上下文真的能训。Gemini 1.5 的 1M、Llama-3.1 的 128k、 Qwen2 的 1M,几乎都依赖 Ring 或它的变种(Striped Attention, Context Parallel in Megatron, Flash-Mask Ring, DeepSpeed-Ulysses)。

# Ring Attention 一次 forward 的骨架(教学版)
def ring_attn_forward(Q_local, K_local, V_local, rank, world_size):
    """
    每张卡持有 [S/P, D] 的 Q/K/V。
    一共 world_size 步,每步:
      1. 用本地 K/V 做一次 FlashAttention 累加进 O;
      2. 把 K/V 异步发给右邻居、同时从左邻居收 K/V;
      3. 换上新收到的 K/V,进入下一步。
    通信和计算 overlap, 总通信量 O(S·D), 与 P 无关。
    """
    O = torch.zeros_like(Q_local)
    K, V = K_local, V_local
    m = torch.full((Q_local.shape[0],), -float('inf'))
    l = torch.zeros_like(m)
    for step in range(world_size):
        # 异步 P2P
        K_next, V_next = isend_irecv(K, V, dst=(rank+1) % world_size,
                                      src=(rank-1) % world_size)
        # 与此同时本地 FA 累加
        O, m, l = flash_attn_accumulate(Q_local, K, V, O, m, l)
        wait(K_next, V_next)
        K, V = K_next, V_next
    return O / l[:, None]
两种长上下文并行 · Ring vs Ulysses

Ring 把 sequence 切,KV 在卡间走。通信量 $= O(S \cdot d)$ 与卡数无关,长 $S$ 优势大。
DeepSpeed-Ulysses 把 head 切,做 all-to-all 把 head 维和 sequence 维交换。通信量 $= O(S \cdot d)$,但实现简单。 一般规则:seq 极长(128k+)用 Ring,head 多(64+)用 Ulysses, 或者把两者组合(USP)。

§2.6稀疏 · 线性 · Mamba · NSA · attention 之外

上面三节是把 dense $O(S^2)$ 算得更快。 另一条路是从复杂度上把它降下来:

家族关键想法代表
滑窗 / 局部每个 token 只看附近 $w$ 个Longformer, Mistral, Sliding Window
Routing / TopK每 query 只关注 top-K 个最相关 tokenSparse Transformer, BigBird, Reformer
线性 attention用 $\phi(Q)(\phi(K)^\top V)$ 把顺序换掉Performer, Linear Transformer, Hedgehog
State-Space用 RNN-like 递推,$O(S)$ 推理S4, Mamba, Mamba-2, RWKV
Native Sparse Attention (NSA)训练即学的 hierarchical 稀疏DeepSeek-V3.2 / NSA 2025
Mamba-2 — Transformers are SSMs: Generalized Models and Efficient Algorithms
ICML 2024 Dao, Gu · arXiv:2405.21060

Mamba 把 attention 换成选择性 SSM, decode 内存与 $S$ 无关、是常数;同时硬件友好(结构上是 GEMM-able)。 在 codegen / 长上下文 retrieval 上仍稍逊于 attention, 但混合(如 Jamba、Zamba、Granite-4 Hybrid)已是常见配方。

Native Sparse Attention — Hardware-Aligned, Natively Trainable Sparse Attention
DeepSeek 2025-02 · arXiv:2502.11089

把 sparsity 从训练开始就引入,按 block 选 top-K, 实现上贴着 SM 拓扑写 kernel。 DeepSeek-V3.2 直接用上了,64k 上下文吞吐对 dense 是几倍提升。 路线开始压过传统"训完再剪" — sparse pretraining 重新流行。

NSA 把每个 query 的 attention 拆三路:

  • Compressed:把过去的 KV 按 block 平均 / max-pool, 每个 query 看一份压缩版的"远处历史"——抓全局信号。
  • Selected:在 compressed pool 之后选 top-K block 做原分辨率 attention——抓重要细节。
  • Sliding:本地滑窗——抓近邻。

三路 attention 加权求和,全部可微,反向也走稀疏 kernel。