Hub
›
Part II · Attention 算子
§2.1FlashAttention v1 · 把 softmax 装进 SRAM
FlashAttention
— Fast and Memory-Efficient Exact Attention with IO-Awareness
朴素 attention 必须把 $S \times S$ 的 score 矩阵物化在 HBM 上,
一来显存 $O(S^2)$,二来读 / 写两次 HBM 把内核打回 memory-bound——
长上下文上谁都跑不动。
关键想法:把 $Q$、$K$、$V$ 切成 tile,
内层循环里一次只把一对 $(Q_i, K_j, V_j)$ tile 加载到
SM 的 shared memory,
用 online softmax(Milakov & Gimelshein 2018)
在 tile 之间维护两个标量 $(m_i, \ell_i)$——
running max 与 running sum——边算边归一化。
最后输出 $O$ 直接累加到 HBM,从不把 $S\times S$ 物化。
在线 softmax · 一行不变的恒等式
给定已扫的旧最大值 $m$、旧分母 $\ell$ 与新一项 $s$,新值:
$$ m' = \max(m, s),\quad \ell' = \ell \cdot e^{m - m'} + e^{s - m'} $$
数学上 $\ell$ 和 $\sum e^{s_j - m'}$ 相等,所以可以拆任意多个块来算。
把 $V$ 那一侧的"加权和"也照样 rescale 即可。这个恒等式是整篇论文
的核心;剩下的工程都是怎么把它写进一个 GPU kernel。
教学版前向
# 教学版 FlashAttention 前向(PyTorch + Numpy 风格)
def flash_attn(Q, K, V, BQ=64, BK=64):
N, d = Q.shape
O = torch.zeros_like(Q)
m = torch.full((N,), -float('inf'))
l = torch.zeros((N,))
for i in range(0, N, BQ):
Qi = Q[i:i+BQ]
Oi = torch.zeros((BQ, d))
mi = torch.full((BQ,), -float('inf'))
li = torch.zeros((BQ,))
for j in range(0, N, BK):
Kj, Vj = K[j:j+BK], V[j:j+BK]
Sij = (Qi @ Kj.T) / d**0.5
mij = Sij.max(dim=-1).values
Pij = torch.exp(Sij - mij[:, None])
mi_new = torch.maximum(mi, mij)
scale = torch.exp(mi - mi_new)
li = li * scale + Pij.sum(dim=-1)
Oi = Oi * scale[:, None] + Pij @ Vj
mi = mi_new
O[i:i+BQ] = Oi / li[:, None]
return O
Triton 视角 · 真实 kernel 的骨架
# 简化的 Triton 内核骨架(FA-2 风格,省略 mask/causal)
@triton.jit
def fa_fwd(Q_ptr, K_ptr, V_ptr, O_ptr,
N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, D: tl.constexpr):
pid_m = tl.program_id(0) # 一个 program 处理 BLOCK_M 个 Q-行
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_d = tl.arange(0, D)
# 把整块 Qi 一次 load 到 SRAM(不再回写)
Qi = tl.load(Q_ptr + offs_m[:, None]*D + offs_d[None, :])
mi = tl.full([BLOCK_M], -float('inf'), dtype=tl.float32)
li = tl.zeros([BLOCK_M], dtype=tl.float32)
Oi = tl.zeros([BLOCK_M, D], dtype=tl.float32)
for start_n in range(0, N, BLOCK_N):
offs_n = start_n + tl.arange(0, BLOCK_N)
Kj = tl.load(K_ptr + offs_n[:, None]*D + offs_d[None, :])
Vj = tl.load(V_ptr + offs_n[:, None]*D + offs_d[None, :])
S = tl.dot(Qi, tl.trans(Kj)) * (1.0 / tl.sqrt(D))
mij = tl.max(S, axis=1)
mi_new = tl.maximum(mi, mij)
Pij = tl.exp(S - mi_new[:, None])
scale = tl.exp(mi - mi_new)
li = li * scale + tl.sum(Pij, axis=1)
Oi = Oi * scale[:, None] + tl.dot(Pij.to(Vj.dtype), Vj)
mi = mi_new
Oi = Oi / li[:, None]
tl.store(O_ptr + offs_m[:, None]*D + offs_d[None, :], Oi)
Demo 2 · 在线 softmax + tile 扫描
左侧的 $S = QK^\top$ 矩阵被一行一行 tile 扫描;
右侧实时显示当前 Q-block 每行的 running max $m$ 与 sum $\ell$。
每扫完一整行,$O$ 自动归一化——结果与逐元素 softmax 精确相等。
把 tile 调大 = 更少 kernel launch + 更高 occupancy,但需要的 shared memory 更多——
这就是 H100 上 FlashAttention 选 BQ=64/128 的取舍。
v1 在 A100 上把 GPT-2 1.5B 训练里 attention 的耗时从 50% 砍到 5%,
内存 $O(S^2) \to O(S)$。从此长上下文成为可能。
§2.2FlashAttention v2 · 把工作搬到正确的轴上
FlashAttention-2
— Faster Attention with Better Parallelism and Work Partitioning
v1 把外层循环放在 K-block 上、内层放 Q-block。
每个 K-block 内部所有 warp 都要写同一个 $O_{i}$,
引入大量 atomic add + 仅在 head 与 batch 维并行——
长上下文 / 小 batch 时 SM 没吃饱。
关键想法:
- 外层循环改在 Q-block,每个 thread block 独占一段 query → 无需 atomic。
- 在 seq 维加一层并行(thread block 维度 = $\lceil S / B_Q \rceil$),
长序列时 SM 终于吃饱。
- 减少非 matmul 操作:rescale 时把除法挪到 epilogue。
v2 在 A100 上把前向打到 ~225 TFLOPS(peak 312),
比 v1 快约 2×。所有现代框架(vLLM、SGLang、TRT-LLM、TGI、TF Transformer
Engine)都默认用 v2 backward。
§2.3FlashAttention v3 · Hopper / FP8 / 异步
FlashAttention-3
— Fast and Accurate Attention with Asynchrony and Low-Precision
关键想法:充分利用 H100 上新东西——
- WGMMA(warpgroup matmul):tensor core 异步发射,CPU
发完命令就走,不阻塞。
- TMA(tensor memory accelerator):搬数据交给 DMA 引擎,
SM 不用占线程做拷贝。
- warp specialization:4 个 warp 专门 producer 搬数据、
其余做 consumer 算 matmul,pipeline 自然形成。
- FP8:score / output 走 e4m3,
soft-attention sink + block-wise scaling 防溢出,
把吞吐再 ×1.5。
H100 FP16 上从 v2 的 ~340 TFLOPS 升到 ~690 TFLOPS(peak 989),
FP8 上 ~1.2 PFLOPS。Llama-3-70B prefill 一次推到 4× 实测加速。
派生分支 · 让 FlashAttention 适应 "怪" mask
| 变种 | 解决的问题 |
| FlashDecoding (Dao 2023) | decode 阶段只有 1 个 query;切 KV 在 seq 维并行,比 v2 快 8×。 |
| FlashDecoding++ (Hong 2024) | 统一 softmax 归一化常数 + 平面化 KV 减少 atomic。 |
| FlashAttention-2 with sliding window | 支持 Mistral 的 4096 滑窗,把 mask 当成 tile-skipping 条件。 |
| FlexAttention (PyTorch 2024) | 把 mask 当一阶函数喂给 compiler,编一个定制 kernel — 一行 Python = 一个 SoTA kernel。 |
| FlashInfer | 专做推理:paged KV + 多种 mask,是 vLLM/TGI 的默认后端之一。 |
| SageAttention / SageAttention-2 | INT8/INT4 attention:把 QK 走 INT8 GEMM,softmax 走 fp16,~30% 加速近无损。 |
§2.4PagedAttention · 把 KV 当虚拟内存
vLLM / PagedAttention
— Efficient Memory Management for LLM Serving
传统 serving 把每个请求的 KV cache 当连续大块预分配。
不同请求的输出长度差异巨大,"预定一个最坏情况" → 显存碎片 60%+,
batch size 上不去。
关键想法:把 KV cache 按固定 16 token 大小的 page
分块管理,每个序列维护一份 page table(就像操作系统的 4 KB 页表)。
分配 = 找空 page;释放 = 还 page。
碎片几乎为 0,同一段 prompt 的多个 sample(beam / parallel
sampling)可以共享同一份 KV pages,
仅 copy-on-write 那一格。
在 OPT-13B 上 throughput ×4 vs HuggingFace TGI 当时版本。
PagedAttention 已经是事实标准——
几乎所有 serving 框架都重新实现了一份。
# Page table 长这样(最小化伪代码)
class Sequence:
block_table: list[int] # 物理 block 编号
seq_len: int
# 当 seq 推进一个 token:
if seq.seq_len % BLOCK == 0:
seq.block_table.append(allocator.alloc())
write_kv(physical_blocks[seq.block_table[-1]], offset=seq.seq_len % BLOCK)
# Attention kernel 接受 block_table,做 indirect address lookup
attn_output = paged_attention_kernel(
q, k_cache_phys, v_cache_phys, block_table, seq_len)
系统视角 · Page 是 LLM 服务的"4KB"
PagedAttention 借的就是操作系统教科书。Block 大小 16 经过了仔细权衡:
太小 → page table 长,attention kernel 间接寻址 overhead 大;
太大 → 内部碎片回来。SGLang / TRT-LLM 一度尝试过 32 / 64,
但 16 仍是大多数框架的默认。
§2.5Ring Attention · 跨卡接力,长上下文不再卡死
Ring Attention
— Ring Attention with Blockwise Transformers for Near-Infinite Context
关键想法:把 sequence 切到 $P$ 张卡上 → 每张卡只持有
$S/P$ 个 token 的 $Q_i, K_i, V_i$。让 $K, V$ tile 环形传一圈,
每张卡在收到来自邻居的 KV 时,立刻和自己的 $Q$ 做一段
FlashAttention 累加。
一次 forward 一共 $P$ 步通信,
每步可以与计算 overlap,等价于一张
巨型卡装下整段 attention。
Ring 让 1M-token 上下文真的能训。Gemini 1.5 的 1M、Llama-3.1 的 128k、
Qwen2 的 1M,几乎都依赖 Ring 或它的变种(Striped Attention,
Context Parallel in Megatron, Flash-Mask Ring, DeepSpeed-Ulysses)。
# Ring Attention 一次 forward 的骨架(教学版)
def ring_attn_forward(Q_local, K_local, V_local, rank, world_size):
"""
每张卡持有 [S/P, D] 的 Q/K/V。
一共 world_size 步,每步:
1. 用本地 K/V 做一次 FlashAttention 累加进 O;
2. 把 K/V 异步发给右邻居、同时从左邻居收 K/V;
3. 换上新收到的 K/V,进入下一步。
通信和计算 overlap, 总通信量 O(S·D), 与 P 无关。
"""
O = torch.zeros_like(Q_local)
K, V = K_local, V_local
m = torch.full((Q_local.shape[0],), -float('inf'))
l = torch.zeros_like(m)
for step in range(world_size):
# 异步 P2P
K_next, V_next = isend_irecv(K, V, dst=(rank+1) % world_size,
src=(rank-1) % world_size)
# 与此同时本地 FA 累加
O, m, l = flash_attn_accumulate(Q_local, K, V, O, m, l)
wait(K_next, V_next)
K, V = K_next, V_next
return O / l[:, None]
两种长上下文并行 · Ring vs Ulysses
Ring 把 sequence 切,KV 在卡间走。通信量 $= O(S \cdot d)$
与卡数无关,长 $S$ 优势大。
DeepSpeed-Ulysses 把 head 切,做 all-to-all 把 head 维和
sequence 维交换。通信量 $= O(S \cdot d)$,但实现简单。
一般规则:seq 极长(128k+)用 Ring,head 多(64+)用 Ulysses,
或者把两者组合(USP)。
§2.6稀疏 · 线性 · Mamba · NSA · attention 之外
上面三节是把 dense $O(S^2)$ 算得更快。
另一条路是从复杂度上把它降下来:
| 家族 | 关键想法 | 代表 |
| 滑窗 / 局部 | 每个 token 只看附近 $w$ 个 | Longformer, Mistral, Sliding Window |
| Routing / TopK | 每 query 只关注 top-K 个最相关 token | Sparse Transformer, BigBird, Reformer |
| 线性 attention | 用 $\phi(Q)(\phi(K)^\top V)$ 把顺序换掉 | Performer, Linear Transformer, Hedgehog |
| State-Space | 用 RNN-like 递推,$O(S)$ 推理 | S4, Mamba, Mamba-2, RWKV |
| Native Sparse Attention (NSA) | 训练即学的 hierarchical 稀疏 | DeepSeek-V3.2 / NSA 2025 |
Mamba-2
— Transformers are SSMs: Generalized Models and Efficient Algorithms
Mamba 把 attention 换成选择性 SSM,
decode 内存与 $S$ 无关、是常数;同时硬件友好(结构上是 GEMM-able)。
在 codegen / 长上下文 retrieval 上仍稍逊于 attention,
但混合(如 Jamba、Zamba、Granite-4 Hybrid)已是常见配方。
Native Sparse Attention
— Hardware-Aligned, Natively Trainable Sparse Attention
把 sparsity 从训练开始就引入,按 block 选 top-K,
实现上贴着 SM 拓扑写 kernel。
DeepSeek-V3.2 直接用上了,64k 上下文吞吐对 dense 是几倍提升。
路线开始压过传统"训完再剪" — sparse pretraining 重新流行。
NSA 把每个 query 的 attention 拆三路:
- Compressed:把过去的 KV 按 block 平均 / max-pool,
每个 query 看一份压缩版的"远处历史"——抓全局信号。
- Selected:在 compressed pool 之后选 top-K block
做原分辨率 attention——抓重要细节。
- Sliding:本地滑窗——抓近邻。
三路 attention 加权求和,全部可微,反向也走稀疏 kernel。