Hub Part VIII · 后训练加速
Part VIII · 后训练加速

把训练显存挤进单卡

全参微调 70B 一张 H100 都装不下。LoRA 把可训参数压到几 MB; QLoRA 把 base 模型量化到 4-bit,单 24GB 4090 微调 65B; FSDP 把权重切到 N 卡;TP/PP/SP 切到不同维度; verl / OpenRLHF 让 RLHF rollout 走 vLLM、训练走 FSDP,端到端加速 10×。

§8.1LoRA / QLoRA / DoRA · 把全参微调压成几 MB

LoRA — Low-Rank Adaptation of Large Language Models
ICLR 2022 Hu et al. · Microsoft · arXiv:2106.09685

冻结大模型 $W$,给每个矩阵加一个低秩残差 $W + \alpha \cdot AB$, $A \in \mathbb{R}^{d \times r}$,$B \in \mathbb{R}^{r \times d}$,$r \ll d$。 可训参数从 $d^2$ 降到 $2dr$。 7B 模型微调一个 task:从 80GB+ 显存降到 ~14GB,几 MB 的 checkpoint。 推理时把 $AB$ merge 回 $W$,零额外延迟。

# LoRA forward: y = x W^T + α x (BA)^T,  其中 A 初始化为零
class LoRALinear(nn.Module):
    def __init__(self, W, r=8, alpha=16):
        super().__init__()
        self.W = W                                  # frozen
        self.A = nn.Parameter(torch.zeros(r, W.shape[1]))
        self.B = nn.Parameter(torch.randn(W.shape[0], r) * 0.01)
        self.scale = alpha / r
    def forward(self, x):
        return x @ self.W.T + self.scale * (x @ self.A.T) @ self.B.T
# 推理时 merge: W_merged = W + scale * B @ A, 零额外算力开销
QLoRA — Efficient Finetuning of Quantized LLMs
NeurIPS 2023 Dettmers et al. · arXiv:2305.14314

把 base $W$ 量化到 NF4(4-bit normal float,针对正态分布优化), 只反传 LoRA 增量、用 paged optimizer 把 Adam state 放主存。 单张 24GB 4090 微调 65B 不是梦。

DoRA — Weight-Decomposed Low-Rank Adaptation
ICML 2024 Oral · arXiv:2402.09353

把权重拆为方向 + 大小:方向用 LoRA 学,大小用一个 vector 学。 LoRA 性能上限被往上推一档,几乎赶上全参 SFT。

其他常见变种:VeRALoRA+PiSSAOLoRArsLoRAMoRAQA-LoRAGaLore。 GaLore 走"低秩 gradient projection"路线,全参微调也能省显存。

§8.2FSDP / ZeRO · 切谁,什么时候切

ZeRO — Memory Optimization Toward Training Trillion Parameter Models
SC 2020 Rajbhandari et al. · Microsoft · arXiv:1910.02054

数据并行的"显存碎片"问题:每张卡都备份了一份权重 + 梯度 + Adam state,浪费。 ZeRO 把它们按 rank 切:

  • ZeRO-1:切 optimizer state(Adam m,v)→ 显存 / 4。
  • ZeRO-2:再切 gradients → 显存 / 8。
  • ZeRO-3:再切 weights → 显存 / N(卡数)。

前向 / 反向时各 rank "all-gather 当前需要的 shard 到全卡",用完即释放。 FSDP(PyTorch 原生)= ZeRO-3 的开源实现。

FSDP 一步的通信图

# FSDP 一层 forward (简化)
def fsdp_layer_forward(layer, x):
    # 1. 把当前层 weights 从全 rank all-gather 起来 (恢复 full shape)
    full_W = all_gather(layer.local_W_shard)
    # 2. 正常 forward
    out = layer.module(x, full_W)
    # 3. 立刻释放 (其他卡需要的话再 all-gather)
    del full_W
    return out

# Backward 时:
def fsdp_layer_backward(layer, grad_out):
    full_W = all_gather(layer.local_W_shard)
    grad_in, grad_W_full = backward_fn(grad_out, full_W)
    # reduce-scatter: 每个 rank 拿回它自己 shard 的梯度部分
    layer.local_W_shard.grad = reduce_scatter(grad_W_full)
    del full_W
    return grad_in

§8.3TP · PP · SP · 切模型的三条轴

并行切什么主要通信适合
Data parallel (DP)batch梯度 all-reduce小模型 / 多机
FSDP / ZeRO-3params + grads + stateweight all-gather + grad reduce-scatter大模型 + 中等节点数
Tensor parallel (TP)每个 GEMM 拆 head / row / col每层 2 次 all-reduce单节点 (NVLink) 内
Pipeline parallel (PP)层切到不同卡micro-batch 流水(少量)大模型 + 多节点
Sequence parallel (SP)seq 维切activation all-gather长上下文
Expert parallel (EP)experts 切到不同卡all-to-allMoE
Context parallel (Ring §2.5)seq 维 KV 走环ring 接力极长上下文

Tensor Parallel 速通

把一个 $W \in \mathbb{R}^{d \times d'}$ 沿列切到 $P$ 张卡,每张拿 $d \times d'/P$。 Forward $Y = XW$ 时每张卡独立算自己那一列;最后不需要立刻 all-gather——后面紧跟的 $W'$ 再沿行切,自然消掉中间维。 Megatron-LM 的核心 trick 就这一招。

# Megatron-style column + row TP
# Layer 1 (column-parallel)
class ColumnParallelLinear(nn.Module):
    def __init__(self, in_f, out_f, P):
        # 每个 rank 拿 (out_f / P) 列
        self.W = nn.Parameter(torch.randn(in_f, out_f // P))
    def forward(self, x):
        return x @ self.W              # 输出 shape: [B, S, out_f/P], 各卡持有自己一段

# Layer 2 (row-parallel)
class RowParallelLinear(nn.Module):
    def __init__(self, in_f, out_f, P):
        self.W = nn.Parameter(torch.randn(in_f // P, out_f))
    def forward(self, x):              # x 已经是 [B, S, in_f/P]
        y = x @ self.W                 # [B, S, out_f]
        return all_reduce(y)           # 各卡 sum, 得到 full output
3D / 4D / 5D 并行

大模型实际是这些维度组合使用: TP 在节点内 8 卡 NVLink,PP 跨节点,DP/FSDP 跨"replicas", 再叠 SP 处理长 seq,再叠 EP 处理 MoE。 Megatron-LM、DeepSpeed、Colossal-AI、MindSpore 都是这套配置的"DSL"。 DeepSeek-V3 671B 的训练是 EP=64 × DP × PP,没用 TP—— EP 通信和 TP 共用 NVLink 带宽时会打架。

§8.4Activation Checkpointing · 用算力换显存

训练时,前向产生的 activations 必须留到反向用 → 显存 $O(L \cdot B \cdot S \cdot D)$。 Activation checkpointing:只保留 layer 边界的 activation, 反向时再重算中间的。显存 $\downarrow \sqrt{L}$ 倍,FLOPs $\uparrow 33\%$

# 一段教学版手写: 在 backward 时重算 forward
class CheckpointedBlock(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, *params):
        ctx.save_for_backward(x, *params)
        with torch.no_grad():
            return block_forward(x, *params)   # 不存中间 activation
    @staticmethod
    def backward(ctx, grad_out):
        x, *params = ctx.saved_tensors
        with torch.enable_grad():
            x = x.detach().requires_grad_(True)
            y = block_forward(x, *params)      # 重算!
        return torch.autograd.grad(y, [x, *params], grad_out)

现代 trick:

  • selective recomputation(Megatron)—— 只重算 attention,不重算 GEMM。
  • FlashAttention 已经内置:attention 不存 $S \times S$,本身就免重算。
  • offload(DeepSpeed / FSDP)—— activation / Adam state 异步存到 CPU、NVMe。

§8.5RLHF / DPO / GRPO 加速 · verl / OpenRLHF

RLHF / GRPO 的训练 loop:

  1. rollout:当前 policy 生成 $N$ 条 trajectory(推理)。
  2. reward / value 评分(reward model 推理)。
  3. policy update(actor 反向)。

rollout 占 60%+ 总时间,但传统训练框架(Megatron / DeepSpeed)把它当训练 forward 跑—— 浪费。

verl / OpenRLHF / NeMo-Aligner
2024–2025 · verl · OpenRLHF

把 rollout 走 vLLM / SGLang(带 paged + specdec),policy update 走 Megatron / FSDP, 两套 engine 之间通过权重同步桥接(async / sync weight reshard)。 端到端 GRPO 训练比 baseline 快 3–10×。 DeepSeek-R1、Qwen-2.5-Math、StarCoder-Math 都在 verl 类系统上训出。

GRPO 一步

# GRPO (DeepSeek-Math/R1 用的 RL 算法) 一步
def grpo_step(policy, ref_policy, prompts, vllm_engine):
    # 1. 用 vLLM 大批量 rollout (G 个 sample / prompt)
    rollouts = vllm_engine.generate(prompts, n=G, temperature=1.0)
    # 2. reward = 任务级 reward (例如 math: 答案对错)
    rewards = compute_reward(rollouts)         # [B, G]
    # 3. 用 group 内 baseline 算 advantage
    A = (rewards - rewards.mean(dim=1, keepdim=True)) / rewards.std(dim=1, keepdim=True)
    # 4. policy update on (prompt, response, advantage)
    logp = policy.log_prob(rollouts)
    logp_ref = ref_policy.log_prob(rollouts).detach()
    ratio = (logp - logp_ref.detach()).exp()
    loss = -(torch.min(ratio * A, ratio.clamp(0.8, 1.2) * A)).mean() + KL_PENALTY * kl
    loss.backward()
    # 5. 把更新后的 weight 同步回 vLLM rollout engine (重要!)
    vllm_engine.update_weights(policy.state_dict())

§8.6长上下文 SFT · sequence packing + RingFlash

32k / 128k SFT 训练里,activations 暴涨是大头。 关键招:

  • Sequence Packing:把多条短样本拼成一个长序列(mask 标好), 避免 padding 浪费 50%+ 算力。
  • RingFlash / Context Parallel:见 §2.5,把 attention 跨卡接力。
  • Sliding Window + Yarn / LongRoPE:让模型先在短上下文 SFT, 再 RoPE 缩放扩到长上下文,避免一开始就吃 32k 的全注意力开销。
  • Llama-Factory / Axolotl / Open-Instruct:把上述全包成一行 yaml。
# Sequence packing: 多条样本拼成一段, document mask
def pack(samples, max_len=8192):
    packed_ids, packed_mask, doc_id = [], [], []
    cur_doc = 0
    for s in samples:
        if len(packed_ids) + len(s) > max_len: break
        packed_ids.extend(s)
        doc_id.extend([cur_doc] * len(s))
        cur_doc += 1
    # attention mask: 同一 doc_id 才能互相 attend
    M = torch.tensor(doc_id)[:, None] == torch.tensor(doc_id)[None, :]
    M = M & torch.tril(torch.ones_like(M))  # 加 causal
    return packed_ids, M