冻结大模型 $W$,给每个矩阵加一个低秩残差 $W + \alpha \cdot AB$, $A \in \mathbb{R}^{d \times r}$,$B \in \mathbb{R}^{r \times d}$,$r \ll d$。 可训参数从 $d^2$ 降到 $2dr$。 7B 模型微调一个 task:从 80GB+ 显存降到 ~14GB,几 MB 的 checkpoint。 推理时把 $AB$ merge 回 $W$,零额外延迟。
把训练显存挤进单卡
全参微调 70B 一张 H100 都装不下。LoRA 把可训参数压到几 MB; QLoRA 把 base 模型量化到 4-bit,单 24GB 4090 微调 65B; FSDP 把权重切到 N 卡;TP/PP/SP 切到不同维度; verl / OpenRLHF 让 RLHF rollout 走 vLLM、训练走 FSDP,端到端加速 10×。
- §8.1LoRA / QLoRA / DoRA
- §8.2FSDP / ZeRO
- §8.3TP / PP / SP
- §8.4Activation Checkpointing
- §8.5RLHF / GRPO 加速
- §8.6长上下文 SFT
§8.1LoRA / QLoRA / DoRA · 把全参微调压成几 MB
# LoRA forward: y = x W^T + α x (BA)^T, 其中 A 初始化为零
class LoRALinear(nn.Module):
def __init__(self, W, r=8, alpha=16):
super().__init__()
self.W = W # frozen
self.A = nn.Parameter(torch.zeros(r, W.shape[1]))
self.B = nn.Parameter(torch.randn(W.shape[0], r) * 0.01)
self.scale = alpha / r
def forward(self, x):
return x @ self.W.T + self.scale * (x @ self.A.T) @ self.B.T
# 推理时 merge: W_merged = W + scale * B @ A, 零额外算力开销
把 base $W$ 量化到 NF4(4-bit normal float,针对正态分布优化), 只反传 LoRA 增量、用 paged optimizer 把 Adam state 放主存。 单张 24GB 4090 微调 65B 不是梦。
把权重拆为方向 + 大小:方向用 LoRA 学,大小用一个 vector 学。 LoRA 性能上限被往上推一档,几乎赶上全参 SFT。
其他常见变种:VeRALoRA+PiSSAOLoRArsLoRAMoRAQA-LoRAGaLore。 GaLore 走"低秩 gradient projection"路线,全参微调也能省显存。
§8.2FSDP / ZeRO · 切谁,什么时候切
数据并行的"显存碎片"问题:每张卡都备份了一份权重 + 梯度 + Adam state,浪费。 ZeRO 把它们按 rank 切:
- ZeRO-1:切 optimizer state(Adam m,v)→ 显存 / 4。
- ZeRO-2:再切 gradients → 显存 / 8。
- ZeRO-3:再切 weights → 显存 / N(卡数)。
前向 / 反向时各 rank "all-gather 当前需要的 shard 到全卡",用完即释放。 FSDP(PyTorch 原生)= ZeRO-3 的开源实现。
FSDP 一步的通信图
# FSDP 一层 forward (简化)
def fsdp_layer_forward(layer, x):
# 1. 把当前层 weights 从全 rank all-gather 起来 (恢复 full shape)
full_W = all_gather(layer.local_W_shard)
# 2. 正常 forward
out = layer.module(x, full_W)
# 3. 立刻释放 (其他卡需要的话再 all-gather)
del full_W
return out
# Backward 时:
def fsdp_layer_backward(layer, grad_out):
full_W = all_gather(layer.local_W_shard)
grad_in, grad_W_full = backward_fn(grad_out, full_W)
# reduce-scatter: 每个 rank 拿回它自己 shard 的梯度部分
layer.local_W_shard.grad = reduce_scatter(grad_W_full)
del full_W
return grad_in
§8.3TP · PP · SP · 切模型的三条轴
| 并行 | 切什么 | 主要通信 | 适合 |
|---|---|---|---|
| Data parallel (DP) | batch | 梯度 all-reduce | 小模型 / 多机 |
| FSDP / ZeRO-3 | params + grads + state | weight all-gather + grad reduce-scatter | 大模型 + 中等节点数 |
| Tensor parallel (TP) | 每个 GEMM 拆 head / row / col | 每层 2 次 all-reduce | 单节点 (NVLink) 内 |
| Pipeline parallel (PP) | 层切到不同卡 | micro-batch 流水(少量) | 大模型 + 多节点 |
| Sequence parallel (SP) | seq 维切 | activation all-gather | 长上下文 |
| Expert parallel (EP) | experts 切到不同卡 | all-to-all | MoE |
| Context parallel (Ring §2.5) | seq 维 KV 走环 | ring 接力 | 极长上下文 |
Tensor Parallel 速通
把一个 $W \in \mathbb{R}^{d \times d'}$ 沿列切到 $P$ 张卡,每张拿 $d \times d'/P$。 Forward $Y = XW$ 时每张卡独立算自己那一列;最后不需要立刻 all-gather——后面紧跟的 $W'$ 再沿行切,自然消掉中间维。 Megatron-LM 的核心 trick 就这一招。
# Megatron-style column + row TP
# Layer 1 (column-parallel)
class ColumnParallelLinear(nn.Module):
def __init__(self, in_f, out_f, P):
# 每个 rank 拿 (out_f / P) 列
self.W = nn.Parameter(torch.randn(in_f, out_f // P))
def forward(self, x):
return x @ self.W # 输出 shape: [B, S, out_f/P], 各卡持有自己一段
# Layer 2 (row-parallel)
class RowParallelLinear(nn.Module):
def __init__(self, in_f, out_f, P):
self.W = nn.Parameter(torch.randn(in_f // P, out_f))
def forward(self, x): # x 已经是 [B, S, in_f/P]
y = x @ self.W # [B, S, out_f]
return all_reduce(y) # 各卡 sum, 得到 full output
大模型实际是这些维度组合使用: TP 在节点内 8 卡 NVLink,PP 跨节点,DP/FSDP 跨"replicas", 再叠 SP 处理长 seq,再叠 EP 处理 MoE。 Megatron-LM、DeepSpeed、Colossal-AI、MindSpore 都是这套配置的"DSL"。 DeepSeek-V3 671B 的训练是 EP=64 × DP × PP,没用 TP—— EP 通信和 TP 共用 NVLink 带宽时会打架。
§8.4Activation Checkpointing · 用算力换显存
训练时,前向产生的 activations 必须留到反向用 → 显存 $O(L \cdot B \cdot S \cdot D)$。 Activation checkpointing:只保留 layer 边界的 activation, 反向时再重算中间的。显存 $\downarrow \sqrt{L}$ 倍,FLOPs $\uparrow 33\%$。
# 一段教学版手写: 在 backward 时重算 forward
class CheckpointedBlock(torch.autograd.Function):
@staticmethod
def forward(ctx, x, *params):
ctx.save_for_backward(x, *params)
with torch.no_grad():
return block_forward(x, *params) # 不存中间 activation
@staticmethod
def backward(ctx, grad_out):
x, *params = ctx.saved_tensors
with torch.enable_grad():
x = x.detach().requires_grad_(True)
y = block_forward(x, *params) # 重算!
return torch.autograd.grad(y, [x, *params], grad_out)
现代 trick:
- selective recomputation(Megatron)—— 只重算 attention,不重算 GEMM。
- FlashAttention 已经内置:attention 不存 $S \times S$,本身就免重算。
- offload(DeepSpeed / FSDP)—— activation / Adam state 异步存到 CPU、NVMe。
§8.5RLHF / DPO / GRPO 加速 · verl / OpenRLHF
RLHF / GRPO 的训练 loop:
- rollout:当前 policy 生成 $N$ 条 trajectory(推理)。
- reward / value 评分(reward model 推理)。
- policy update(actor 反向)。
rollout 占 60%+ 总时间,但传统训练框架(Megatron / DeepSpeed)把它当训练 forward 跑—— 浪费。
把 rollout 走 vLLM / SGLang(带 paged + specdec),policy update 走 Megatron / FSDP, 两套 engine 之间通过权重同步桥接(async / sync weight reshard)。 端到端 GRPO 训练比 baseline 快 3–10×。 DeepSeek-R1、Qwen-2.5-Math、StarCoder-Math 都在 verl 类系统上训出。
GRPO 一步
# GRPO (DeepSeek-Math/R1 用的 RL 算法) 一步
def grpo_step(policy, ref_policy, prompts, vllm_engine):
# 1. 用 vLLM 大批量 rollout (G 个 sample / prompt)
rollouts = vllm_engine.generate(prompts, n=G, temperature=1.0)
# 2. reward = 任务级 reward (例如 math: 答案对错)
rewards = compute_reward(rollouts) # [B, G]
# 3. 用 group 内 baseline 算 advantage
A = (rewards - rewards.mean(dim=1, keepdim=True)) / rewards.std(dim=1, keepdim=True)
# 4. policy update on (prompt, response, advantage)
logp = policy.log_prob(rollouts)
logp_ref = ref_policy.log_prob(rollouts).detach()
ratio = (logp - logp_ref.detach()).exp()
loss = -(torch.min(ratio * A, ratio.clamp(0.8, 1.2) * A)).mean() + KL_PENALTY * kl
loss.backward()
# 5. 把更新后的 weight 同步回 vLLM rollout engine (重要!)
vllm_engine.update_weights(policy.state_dict())
§8.6长上下文 SFT · sequence packing + RingFlash
32k / 128k SFT 训练里,activations 暴涨是大头。 关键招:
- Sequence Packing:把多条短样本拼成一个长序列(mask 标好), 避免 padding 浪费 50%+ 算力。
- RingFlash / Context Parallel:见 §2.5,把 attention 跨卡接力。
- Sliding Window + Yarn / LongRoPE:让模型先在短上下文 SFT, 再 RoPE 缩放扩到长上下文,避免一开始就吃 32k 的全注意力开销。
- Llama-Factory / Axolotl / Open-Instruct:把上述全包成一行 yaml。
# Sequence packing: 多条样本拼成一段, document mask
def pack(samples, max_len=8192):
packed_ids, packed_mask, doc_id = [], [], []
cur_doc = 0
for s in samples:
if len(packed_ids) + len(s) > max_len: break
packed_ids.extend(s)
doc_id.extend([cur_doc] * len(s))
cur_doc += 1
# attention mask: 同一 doc_id 才能互相 attend
M = torch.tensor(doc_id)[:, None] == torch.tensor(doc_id)[None, :]
M = M & torch.tril(torch.ones_like(M)) # 加 causal
return packed_ids, M