最后的 30%
前 12 个 part 给的是算法层面的加速。这一部讲工程层面的—— 把 kernel launch 干掉、把 op 自动 fuse、把数据异步搬。 这些招通常各自只有 10-30% 提升, 但是 production 部署上"最后的 30%" 全靠它们。
- §13.1CUDA Graph
- §13.2torch.compile / Triton
- §13.3Kernel Fusion
- §13.4Mixed Precision
- §13.5异步预取与重计算
- §13.6还没解决的问题
§13.1CUDA Graph · 把 launch 开销消掉
Decode 时每步只跑一个 token,但要 launch 几百个 kernel—— 每个 launch 有 ~1–5 µs 的 host-side overhead。 在 7B 模型上 launch 开销能占 30–50%。
CUDA Graph:把一整段 kernel 序列录制成一张图, 之后整张图一次 launch。 vLLM / TRT-LLM 在 decode 时默认启用——decode 提速 ~2×。 限制:图固定,输入 shape 必须不变(不同 batch 要预录多张图)。
# CUDA Graph capture (PyTorch 2.x)
import torch.cuda
# 1. warm-up
for _ in range(3):
output = model(static_input)
# 2. capture
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
output = model(static_input)
# 3. replay
for _ in range(1000):
static_input.copy_(new_data) # 输入 buffer 是固定地址!
g.replay() # 一次 launch 整张图
use(output)
vLLM 对不同 batch size 预录一组图(1, 2, 4, 8, 16, 32, ..., 256); scheduler 根据当前 batch 选最近一档图来 replay,多余 pad 0。
§13.2torch.compile / Triton · 自动 fuse
Triton(Tillet 2019)让你用 Python 写 GPU kernel,编译器搞 tiling / shared-mem。 FlashAttention / Mamba / 大量 RLHF op 都是 Triton 写的。
torch.compile (PyTorch 2.x) = TorchDynamo + Inductor。 把 Python forward 翻成 IR,Inductor 后端为 GPU 生成 Triton kernel,把可融合 op 合并。 decode 时常见提速 ~30%。
# 一行: 把 model 编译成 fused kernel
compiled = torch.compile(model, mode='reduce-overhead', fullgraph=True)
# mode 选项:
# 'default' - 编译 + 一些 graph optimization
# 'reduce-overhead' - 自动用 CUDA Graph
# 'max-autotune' - 离线 autotune kernel 配置
Triton kernel 写一个 RMSNorm
@triton.jit
def rmsnorm_kernel(x_ptr, g_ptr, out_ptr, N: tl.constexpr, eps: tl.constexpr):
row = tl.program_id(0)
offs = tl.arange(0, N)
x = tl.load(x_ptr + row*N + offs)
g = tl.load(g_ptr + offs)
rms = tl.sqrt(tl.sum(x * x) / N + eps)
tl.store(out_ptr + row*N + offs, x / rms * g)
# 一行 Python = 一个 fused kernel, 比 PyTorch 慢的 eager 版本快 5-10x
§13.3Kernel Fusion · 一张速查表
| 融合点 | 例子 |
|---|---|
| RMSNorm + 残差 | FasterTransformer / TRT-LLM 标配 |
| QKV 投影 | 三个 $W_q, W_k, W_v$ 合成一个大 GEMM |
| SwiGLU FFN | $\text{SiLU}(xW_1) \odot xW_2$ 一次出 |
| Rotary + 写 KV cache | kernel 内做 rope 立刻写到 cache |
| Attention + soft-cap + sink + mask | FlashAttention 的常见 epilogue |
| Sampling (softmax + multinomial) | fused 采样 kernel |
# Fused QKV 投影: 三个矩阵 -> 一个大矩阵
# 原本: q = x @ Wq, k = x @ Wk, v = x @ Wv (3 次 launch + 3 次 HBM 读 x)
# 融合后: qkv = x @ Wqkv (1 次 launch, x 只读一次), 再 split
W_qkv = torch.cat([W_q, W_k, W_v], dim=-1) # [D, 3D]
qkv = x @ W_qkv
q, k, v = qkv.split(D, dim=-1)
§13.4Mixed Precision · BF16 + FP32 master
训练通货:BF16 forward/backward + FP32 master weights + FP32 accumulation。 loss scale 在 BF16 已不必要(指数位足够),FP16 时代必须。
推理:BF16 → FP16(H100 推理路径优化更好)→ FP8(H100/Blackwell)→ FP4 (Blackwell)。 把 GEMM 输入降 dtype,accumulate 仍走 FP32 以维持精度。
# PyTorch autocast: 全自动混合精度
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
loss = model(x).log_softmax(-1).nll_loss(y)
# matmul / conv: BF16
# softmax / loss / batchnorm reduce: FP32 (数值稳定)
# 反向: 梯度也是 BF16, 加到 FP32 master weights
# Hopper FP8 (Transformer Engine)
from transformer_engine.pytorch import fp8_autocast
with fp8_autocast():
out = model(x)
# attention QK GEMM 走 FP8 (e4m3), output proj 走 FP8 (e5m2), accumulate FP32
§13.5异步预取与重计算 · 让显存"虚胖"
关键招式整理:
- cudaMemcpyAsync + double buffer—— 把下一层 weights 异步从 CPU/NVMe prefetch,pipeline 隐藏 PCIe。
- Layer-wise CPU offload(DeepSpeed / FSDP CPUOffload)—— 把当前不用的 layer 推回 CPU。
- Selective Activation Recompute—— 只重算 attention,因为 attention activations 大但便宜重算。
- GS-Scale 路线—— 在生成式领域已有先例(3DGS 的 GS-Scale 把高斯放 host 内存), LLM 训练的 Pliny / FlashAttention-Offload 也走类似思路。
# Double-buffered weight prefetch (训练或推理都通用)
class OffloadedWeights:
def __init__(self, layers, host_mem):
self.host = host_mem # CPU/NVMe 存全量
self.gpu_buf = [None, None] # 两个 GPU buffer 轮换
self.stream = torch.cuda.Stream()
def get(self, i):
# 计算当前 layer 时, 异步预取下一层
with torch.cuda.stream(self.stream):
self.gpu_buf[(i+1) % 2] = self.host[i+1].cuda(non_blocking=True)
return self.gpu_buf[i % 2]
§13.6站在 2026-05 · 还没解决的问题
这份综述写于 2026 年 5 月,列出当下仍然空着的坑:
- 真正在 4-bit 推理下无损训出来的大模型。BitNet 路线证明可行, 但还没有 GPT-4 / Gemini 量级的 1.58-bit 公开模型。
- VLA 的 100 Hz 实时控制。 π0、CogACT 推动到 ~30 Hz,但触觉 / 双足等需要 200 Hz+。 在 Jetson Orin 量级的边缘卡上还做不到。
- VLM 视觉 token 的"原生"压缩。 FastV/VisionZip 都是 post-hoc 剪枝; 把"视觉 token 数量"作为模型可控参数的 SoTA 还没出现。
- 1-step diffusion 通用化。DMD2/SDXL Turbo 在 1024×1024 上 SoTA, 但 4k / 长视频 / 复杂 condition 下质量缩水严重。
- 更大上下文的高效 attention。NSA 把 dense O(S²) 推到 sparse O(S log S), 但 10M+ 上下文下与 Mamba/RWKV 等 RNN 派系的对决还没定胜负。
- 跨厂商 kernel 兼容性。FlashAttention-3 强绑 Hopper; AMD CDNA3 / 华为昇腾 / 国产卡都要写一套。 Triton 跨硬件后端是希望,但能不能跑得过手工 kernel 仍是问号。
- 统一的 P/D + EP + Caching 调度器。 vLLM / SGLang / TRT-LLM 各自有不完整的实现, 生产部署还是需要"再缝合"。
- RLHF 的 rollout-train 同卡复用。 目前 verl 类是分卡(infer & train 占不同 GPU 池子), 单卡 in-place 切换(参考 OneRL / RLite)才刚起步。
把这一切压缩成一句话:"无论你在做哪种 transformer/diffusion 模型, 优化目标永远是 ① 更少字节从 HBM 进来 ② 更多算力被复用 ③ 通信和算尽量重叠。" 所有 60+ 节都是这条主线上的某一颗螺丝。