一句话:所有 query head 共享同一份 K 与 V。 KV cache 立刻除以 $H_q$ —— Llama-3-70B 上从 64× 缩到 1×。 代价:质量略掉,长上下文 retrieval 退化。
KV 才是长上下文的真正瓶颈
Llama-3-70B 一个 token 的 KV 在 GQA-8 / FP16 下是 320 KB。 把它喂成 8k 上下文、批量 32,KV 就吃 80 GB——和权重等量。 所以 KV cache 压缩是过去两年最热的方向之一: head 共享(MQA/GQA/MLA)、量化、驱逐、前缀树共享,四条路同时压。
- §3.1MQA / GQA / MLA
- §3.2KV 量化
- §3.3KV 驱逐
- §3.4RadixAttention · 前缀树
- §3.5Chunked Prefill
§3.1MQA / GQA / MLA · 把 KV 头共享或潜变量化
MHA 与 MQA 的中间地带:把 $H_q$ 个 query head 分到 $H_{kv} \ll H_q$ 个组 (Llama-3 选 8)。质量与 MHA 持平、KV cache 缩 $H_q / H_{kv} = 8$×。 几乎所有现代开源模型(Llama 2/3、Qwen2、Mistral、Yi、DeepSeek 早期)默认 GQA。
# GQA: 一组 query head 共享一份 K,V
# 假设 H_q = 64, H_kv = 8 → group size = 8
def gqa(q, k, v):
# q: [B, H_q, S, D] k,v: [B, H_kv, S, D]
# 把 K,V 在 head 维复制 group_size 次(实际 kernel 是 broadcast,不真复制)
k = k.repeat_interleave(H_q // H_kv, dim=1)
v = v.repeat_interleave(H_q // H_kv, dim=1)
return flash_attn(q, k, v)
关键想法:与其共享 head,不如把每个 token 的 K 和 V 共同压缩到 一个低秩潜变量 $c_t \in \mathbb{R}^{r}$(DeepSeek-V3: $r = 512$)。 推理时只缓存 $c_t$;attention 时再用一个 up-proj 把 $c_t$ 投回完整 K, V。 再配合 rotary 的处理(保留一段非压缩部分),训得到与 MHA 接近的 quality。 KV 体积近似 7%–14% 的 MHA——这是 DeepSeek 能开 128k 上下文的根本。
MLA 解耦了"建模容量"与"缓存体积"——历史上头一回。 缺点:实现复杂,必须配套 RoPE 的特殊处理; serving 时需要专门的 kernel(vLLM、SGLang 都补了)。
MLA 的伪代码(含 RoPE 解耦)
# DeepSeek MLA: 训练时上投出 K/V, 推理时只缓存 c_kv
# rotary 部分分开走 (rope decoupling), 避免 c_kv 必须包含位置信息
def mla_forward(x, theta, kv_cache):
# 1. 下投: hidden -> latent
c_kv = x @ theta.W_DKV # [B, S, r] r=512
k_rope = x @ theta.W_KR # [B, S, d_rope] 独立的 rotary key
k_rope = apply_rope(k_rope)
# 2. 缓存 (c_kv, k_rope) 而不是 K, V
kv_cache.append(c_kv, k_rope)
# 3. attention 时再上投
K_nope = c_kv @ theta.W_UK # [B, S, H, d_h] 不带 rotary
V = c_kv @ theta.W_UV # [B, S, H, d_h]
K = concat([K_nope, k_rope.expand_to_heads()], dim=-1)
# 4. q 也分两路: nope + rope
q_lat = x @ theta.W_DQ
Q_nope = q_lat @ theta.W_UQ
Q_rope = apply_rope(q_lat @ theta.W_QR)
Q = concat([Q_nope, Q_rope], dim=-1)
return flash_attn(Q, K, V)
# 缓存 per token = r + d_rope = 512 + 64 ≈ 576 (DSv3, fp16 → ~1.1 KB)
# 对比 MHA 70B: 2 * 64 * 128 * 2 = 32 KB —— 缩到 ~3.5%
§3.2KV 量化 · KIVI / KVQuant / FP8 KV
关键观察:K 的分布在 channel 维上方差极大(有大 outlier), V 在 token 维上方差大。 所以——K 按 channel 量化(per-channel),V 按 token 量化(per-token)。 2-bit KV、量化 + 反量化在 attention 中即时做, 在 Llama / Mistral 上几乎不掉点。
# KIVI: K per-channel, V per-token
def kivi_quant_K(K):
# K: [B, H, S, D] → quantize along D-axis groups
scale = K.abs().max(dim=2, keepdim=True).values # per (B, H, channel)
qK = (K / scale * 127).round().clip(-127, 127).to(int8)
return qK, scale
def kivi_quant_V(V):
# V: [B, H, S, D] → quantize along D
scale = V.abs().max(dim=-1, keepdim=True).values # per (B, H, S)
qV = (V / scale * 127).round().clip(-127, 127).to(int8)
return qV, scale
几招组合:non-uniform 量化 grid(用 sensitivity 学)、 pre-RoPE quant、dense-and-sparse decomposition(异常值保留 fp16)。 打到 ~1 bit + 异常值时仍能跑 1M+ context。
工业界更常见的是更朴素的 FP8 KV(vLLM / TensorRT-LLM 默认选项):8-bit per scalar,per-tensor 或 per-channel scale, 搭配 e4m3 / e5m2,吞吐 +30%,质量几乎无损。
§3.3KV 驱逐 · H2O / StreamingLLM / SnapKV
经验观察:注意力的累积权重在长上下文里集中在少数 "重击者" token 上。 H2O 保留这些"高得分历史"以及一段"近期 window",其余的 KV 直接驱逐。 在 OPT / Llama 上把 KV 减 5× 而 longbench 不掉。
一个奇怪的现象:drop 掉最早几个 token 会让模型"崩"—— 因为 softmax 必须把概率质量倒在某处,最早的 token 充当"sink"。 StreamingLLM 永远保留开头若干个 token + 最近 window, 其余驱逐,从而开启无限流式生成而不退化。 这一观察后来在 SoftMax-attention sink 的工程里反复被复用(FA-3 也用了)。
# StreamingLLM: keep first-k sinks + last-w window
def streaming_kv_select(K, V, k_sink=4, w_window=2048):
S = K.shape[2]
if S <= k_sink + w_window:
return K, V
K_sink, V_sink = K[..., :k_sink, :], V[..., :k_sink, :]
K_win, V_win = K[..., -w_window:, :], V[..., -w_window:, :]
return torch.cat([K_sink, K_win], dim=2), torch.cat([V_sink, V_win], dim=2)
在 prefill 结束时,用 prompt 末尾几行的 attention 模式来推测"哪些早期 token 重要",对每层选 top-K KV 保留。后续 decode 都用这份压缩版。 在长 prompt(RAG / 多文档)场景把 KV 缩到 5%–15%,几乎不损质量。
派生分支地图
| 方向 | 方法 |
|---|---|
| 训练时学驱逐 | Scissorhands, KeyFormer, Quest |
| per-head / per-layer 异构 | PyramidKV, AdaKV, ThinK |
| 问题感知 | SnapKV (question-aware), QA-KV |
| 合并而非丢 | CaM, GEAR, MiniCache |
| 跨请求共享 | RadixAttention (§3.4), CacheBlend |
§3.4RadixAttention · 前缀树共享缓存
Agent / RAG / few-shot 工作流里, 多个请求经常共享一段 system prompt 或文档前缀。 RadixAttention 把所有当前活跃的 prefix 维护成一棵 基数树(每条边是一段 token 序列), 一份 KV cache 物理上挂一次, 新请求只需"在最长共享前缀上分叉"。 命中 prefix → 跳过整段 prefill。
多轮对话场景 5–10× 吞吐。SGLang 把 RadixAttention 与 structured decoding(regex / JSON schema 强约束)结合, 成为 agent 系统的事实标准之一。
# RadixAttention: 维护一棵活跃 prefix 的基数树
class RadixNode:
tokens: list[int] # 该 edge 上的 token 序列
kv_pages: list[int] # 对应的 KV page id
children: dict[int, RadixNode]
ref_count: int
def insert(root, prompt_tokens):
# 沿树往下走, 找到最长匹配 prefix
node, depth = root, 0
while depth < len(prompt_tokens):
nxt = node.children.get(prompt_tokens[depth])
if nxt is None: break
# 检查 edge 上 token 是否还匹配
m = match_len(nxt.tokens, prompt_tokens[depth:])
if m == len(nxt.tokens):
node, depth = nxt, depth + m # 完全吃下这条 edge, 继续走
else:
split_edge(nxt, m); break
# 剩余 token 新建 edge, alloc 新 KV pages
new_node = alloc_edge(prompt_tokens[depth:])
node.children[prompt_tokens[depth]] = new_node
return new_node
§3.5Chunked Prefill · 让 decode 别被 prefill 噎死
Continuous batching(vLLM 引入)下,一个 batch 同时有: 正在 prefill 的长 prompt + 正在 decode 的一群短请求。 一个 8k prompt 的 prefill 占满 SM 时,decode 请求被堵在后面 → 时延飙升。
Chunked Prefill(Sarathi-Serve, DistServe, vLLM 0.6+): 把长 prompt 切成 ~512–1024 token 的 chunk, 每个 schedule step 只 prefill 一个 chunk 并和 decode token 拼 batch。 长 prompt 排队时延 ↓ ~5×,吞吐近乎无损。 现在是所有主流 serving 框架的默认开关。
# Sarathi-Serve 风格 scheduler 一步
def schedule_step(running_seqs, prefill_queue, chunk_budget=1024):
batch = []
# 优先把所有 decode 请求 (S=1 forward) 拼进去
for s in running_seqs:
if s.state == 'decode':
batch.append((s, [s.last_token]))
# 剩余预算分给 prefill 中的请求, 一个 chunk 一个 chunk 上
budget = chunk_budget - len(batch)
for s in prefill_queue:
if budget <= 0: break
chunk = s.next_prefill_chunk(min(s.prefill_remaining, budget))
batch.append((s, chunk))
budget -= len(chunk)
run_one_forward(batch)
KV cache 压缩的四条路是正交的——你可以同时开: Llama-3-70B 在 SGLang 上 GQA-8 + FP8 KV + Radix 共享 + chunked prefill 通常一起开,端到端吞吐比"全无"提升 10–20×。 长上下文要再叠 H2O / SnapKV / NSA。