Hub Part III · KV Cache
Part III · KV Cache 压缩

KV 才是长上下文的真正瓶颈

Llama-3-70B 一个 token 的 KV 在 GQA-8 / FP16 下是 320 KB。 把它喂成 8k 上下文、批量 32,KV 就吃 80 GB——和权重等量。 所以 KV cache 压缩是过去两年最热的方向之一: head 共享(MQA/GQA/MLA)、量化、驱逐、前缀树共享,四条路同时压。

本页 5 节 · 1 个 demo
  1. §3.1MQA / GQA / MLA
  2. §3.2KV 量化
  3. §3.3KV 驱逐
  4. §3.4RadixAttention · 前缀树
  5. §3.5Chunked Prefill

§3.1MQA / GQA / MLA · 把 KV 头共享或潜变量化

MQA — Multi-Query Attention
arXiv 2019 Shazeer · arXiv:1911.02150

一句话:所有 query head 共享同一份 K 与 V。 KV cache 立刻除以 $H_q$ —— Llama-3-70B 上从 64× 缩到 1×。 代价:质量略掉,长上下文 retrieval 退化。

GQA — Grouped-Query Attention
EMNLP 2023 Ainslie, Lee-Thorp, de Jong, Zemlyanskiy, Lebrón, Sanghai · arXiv:2305.13245

MHA 与 MQA 的中间地带:把 $H_q$ 个 query head 分到 $H_{kv} \ll H_q$ 个组 (Llama-3 选 8)。质量与 MHA 持平、KV cache 缩 $H_q / H_{kv} = 8$×。 几乎所有现代开源模型(Llama 2/3、Qwen2、Mistral、Yi、DeepSeek 早期)默认 GQA。

# GQA: 一组 query head 共享一份 K,V
# 假设 H_q = 64, H_kv = 8 → group size = 8
def gqa(q, k, v):
    # q: [B, H_q,  S, D]   k,v: [B, H_kv, S, D]
    # 把 K,V 在 head 维复制 group_size 次(实际 kernel 是 broadcast,不真复制)
    k = k.repeat_interleave(H_q // H_kv, dim=1)
    v = v.repeat_interleave(H_q // H_kv, dim=1)
    return flash_attn(q, k, v)
MLA — Multi-head Latent Attention (DeepSeek-V2/V3)
DeepSeek 2024-05 / 2024-12 · DeepSeek-V2 · DeepSeek-V3 Tech Report

关键想法:与其共享 head,不如把每个 token 的 K 和 V 共同压缩到 一个低秩潜变量 $c_t \in \mathbb{R}^{r}$(DeepSeek-V3: $r = 512$)。 推理时只缓存 $c_t$;attention 时再用一个 up-proj 把 $c_t$ 投回完整 K, V。 再配合 rotary 的处理(保留一段非压缩部分),训得到与 MHA 接近的 quality。 KV 体积近似 7%–14% 的 MHA——这是 DeepSeek 能开 128k 上下文的根本。

MLA 解耦了"建模容量"与"缓存体积"——历史上头一回。 缺点:实现复杂,必须配套 RoPE 的特殊处理; serving 时需要专门的 kernel(vLLM、SGLang 都补了)。

MLA 的伪代码(含 RoPE 解耦)

# DeepSeek MLA: 训练时上投出 K/V, 推理时只缓存 c_kv
# rotary 部分分开走 (rope decoupling), 避免 c_kv 必须包含位置信息

def mla_forward(x, theta, kv_cache):
    # 1. 下投: hidden -> latent
    c_kv = x @ theta.W_DKV          # [B, S, r]  r=512
    k_rope = x @ theta.W_KR         # [B, S, d_rope]  独立的 rotary key
    k_rope = apply_rope(k_rope)

    # 2. 缓存 (c_kv, k_rope) 而不是 K, V
    kv_cache.append(c_kv, k_rope)

    # 3. attention 时再上投
    K_nope = c_kv @ theta.W_UK      # [B, S, H, d_h]   不带 rotary
    V      = c_kv @ theta.W_UV      # [B, S, H, d_h]
    K = concat([K_nope, k_rope.expand_to_heads()], dim=-1)

    # 4. q 也分两路: nope + rope
    q_lat = x @ theta.W_DQ
    Q_nope = q_lat @ theta.W_UQ
    Q_rope = apply_rope(q_lat @ theta.W_QR)
    Q = concat([Q_nope, Q_rope], dim=-1)

    return flash_attn(Q, K, V)
# 缓存 per token = r + d_rope = 512 + 64 ≈ 576 (DSv3, fp16 → ~1.1 KB)
# 对比 MHA 70B: 2 * 64 * 128 * 2 = 32 KB —— 缩到 ~3.5%
Demo 3 · KV cache · 把 cache 拉到爆显存
切 variant 看 KV 怎么塌缩——MQA / MLA 比 MHA 缩到 ~3% 的体积。 再开 FP8 / INT4 KV 量化,进一步打 2× / 4×。 这就是为什么"长上下文 + 高并发"成为 2024 之后的可能。 注意 GQA-4 与 GQA-8 在质量上几乎没差,但 KV 减半—— 所以 Llama-3.1-70B 实际选了 8。

§3.2KV 量化 · KIVI / KVQuant / FP8 KV

KIVI — Tuning-Free Asymmetric 2-bit Quantization for KV Cache
ICML 2024 Liu et al. · arXiv:2402.02750

关键观察:K 的分布在 channel 维上方差极大(有大 outlier), V 在 token 维上方差大。 所以——K 按 channel 量化(per-channel),V 按 token 量化(per-token)。 2-bit KV、量化 + 反量化在 attention 中即时做, 在 Llama / Mistral 上几乎不掉点。

# KIVI: K per-channel, V per-token
def kivi_quant_K(K):
    # K: [B, H, S, D]  → quantize along D-axis groups
    scale = K.abs().max(dim=2, keepdim=True).values   # per (B, H, channel)
    qK = (K / scale * 127).round().clip(-127, 127).to(int8)
    return qK, scale

def kivi_quant_V(V):
    # V: [B, H, S, D]  → quantize along D
    scale = V.abs().max(dim=-1, keepdim=True).values   # per (B, H, S)
    qV = (V / scale * 127).round().clip(-127, 127).to(int8)
    return qV, scale
KVQuant — Towards 10M Context Length with Quantized KV Cache
NeurIPS 2024 · arXiv:2401.18079

几招组合:non-uniform 量化 grid(用 sensitivity 学)、 pre-RoPE quant、dense-and-sparse decomposition(异常值保留 fp16)。 打到 ~1 bit + 异常值时仍能跑 1M+ context。

工业界更常见的是更朴素的 FP8 KV(vLLM / TensorRT-LLM 默认选项):8-bit per scalar,per-tensor 或 per-channel scale, 搭配 e4m3 / e5m2,吞吐 +30%,质量几乎无损。

§3.3KV 驱逐 · H2O / StreamingLLM / SnapKV

H2O — Heavy-Hitter Oracle
NeurIPS 2023 Zhang et al. · UT Austin / Rice · arXiv:2306.14048

经验观察:注意力的累积权重在长上下文里集中在少数 "重击者" token 上。 H2O 保留这些"高得分历史"以及一段"近期 window",其余的 KV 直接驱逐。 在 OPT / Llama 上把 KV 减 5× 而 longbench 不掉。

StreamingLLM — Efficient Streaming Language Models with Attention Sinks
ICLR 2024 Xiao, Tian, Chen, Han, Lewis · MIT / Meta · arXiv:2309.17453

一个奇怪的现象:drop 掉最早几个 token 会让模型"崩"—— 因为 softmax 必须把概率质量倒在某处,最早的 token 充当"sink"。 StreamingLLM 永远保留开头若干个 token + 最近 window, 其余驱逐,从而开启无限流式生成而不退化。 这一观察后来在 SoftMax-attention sink 的工程里反复被复用(FA-3 也用了)。

# StreamingLLM: keep first-k sinks + last-w window
def streaming_kv_select(K, V, k_sink=4, w_window=2048):
    S = K.shape[2]
    if S <= k_sink + w_window:
        return K, V
    K_sink, V_sink = K[..., :k_sink, :], V[..., :k_sink, :]
    K_win,  V_win  = K[..., -w_window:, :], V[..., -w_window:, :]
    return torch.cat([K_sink, K_win], dim=2), torch.cat([V_sink, V_win], dim=2)
SnapKV — LLM Knows What You Look For: Compressing Prompt KV
NeurIPS 2024 · arXiv:2404.14469

在 prefill 结束时,用 prompt 末尾几行的 attention 模式来推测"哪些早期 token 重要",对每层选 top-K KV 保留。后续 decode 都用这份压缩版。 在长 prompt(RAG / 多文档)场景把 KV 缩到 5%–15%,几乎不损质量。

派生分支地图

方向方法
训练时学驱逐Scissorhands, KeyFormer, Quest
per-head / per-layer 异构PyramidKV, AdaKV, ThinK
问题感知SnapKV (question-aware), QA-KV
合并而非丢CaM, GEAR, MiniCache
跨请求共享RadixAttention (§3.4), CacheBlend

§3.4RadixAttention · 前缀树共享缓存

SGLang / RadixAttention — Efficient Execution of Structured Language Model Programs
NeurIPS 2024 Zheng, Yin, Xie, ... Sheng, Stoica, Zhang · UC Berkeley / Stanford · arXiv:2312.07104 · code

Agent / RAG / few-shot 工作流里, 多个请求经常共享一段 system prompt 或文档前缀。 RadixAttention 把所有当前活跃的 prefix 维护成一棵 基数树(每条边是一段 token 序列), 一份 KV cache 物理上挂一次, 新请求只需"在最长共享前缀上分叉"。 命中 prefix → 跳过整段 prefill。

多轮对话场景 5–10× 吞吐。SGLang 把 RadixAttention 与 structured decoding(regex / JSON schema 强约束)结合, 成为 agent 系统的事实标准之一。

# RadixAttention: 维护一棵活跃 prefix 的基数树
class RadixNode:
    tokens: list[int]      # 该 edge 上的 token 序列
    kv_pages: list[int]    # 对应的 KV page id
    children: dict[int, RadixNode]
    ref_count: int

def insert(root, prompt_tokens):
    # 沿树往下走, 找到最长匹配 prefix
    node, depth = root, 0
    while depth < len(prompt_tokens):
        nxt = node.children.get(prompt_tokens[depth])
        if nxt is None: break
        # 检查 edge 上 token 是否还匹配
        m = match_len(nxt.tokens, prompt_tokens[depth:])
        if m == len(nxt.tokens):
            node, depth = nxt, depth + m   # 完全吃下这条 edge, 继续走
        else:
            split_edge(nxt, m); break
    # 剩余 token 新建 edge, alloc 新 KV pages
    new_node = alloc_edge(prompt_tokens[depth:])
    node.children[prompt_tokens[depth]] = new_node
    return new_node

§3.5Chunked Prefill · 让 decode 别被 prefill 噎死

Continuous batching(vLLM 引入)下,一个 batch 同时有: 正在 prefill 的长 prompt + 正在 decode 的一群短请求。 一个 8k prompt 的 prefill 占满 SM 时,decode 请求被堵在后面 → 时延飙升。

Chunked Prefill(Sarathi-Serve, DistServe, vLLM 0.6+): 把长 prompt 切成 ~512–1024 token 的 chunk, 每个 schedule step 只 prefill 一个 chunk 并和 decode token 拼 batch。 长 prompt 排队时延 ↓ ~5×,吞吐近乎无损。 现在是所有主流 serving 框架的默认开关。

# Sarathi-Serve 风格 scheduler 一步
def schedule_step(running_seqs, prefill_queue, chunk_budget=1024):
    batch = []
    # 优先把所有 decode 请求 (S=1 forward) 拼进去
    for s in running_seqs:
        if s.state == 'decode':
            batch.append((s, [s.last_token]))
    # 剩余预算分给 prefill 中的请求, 一个 chunk 一个 chunk 上
    budget = chunk_budget - len(batch)
    for s in prefill_queue:
        if budget <= 0: break
        chunk = s.next_prefill_chunk(min(s.prefill_remaining, budget))
        batch.append((s, chunk))
        budget -= len(chunk)
    run_one_forward(batch)
关键认知

KV cache 压缩的四条路是正交的——你可以同时开: Llama-3-70B 在 SGLang 上 GQA-8 + FP8 KV + Radix 共享 + chunked prefill 通常一起开,端到端吞吐比"全无"提升 10–20×。 长上下文要再叠 H2O / SnapKV / NSA。