从 SFT 到 LoRA
给大模型"换上你的工作服"
预训练让大模型变成"博学的实习生"——它懂很多,却不知道在你这儿该怎么干活。
SFT(监督微调)教它怎么干,LoRA则让你用一张消费级显卡就能完成这件事。
大模型已经很聪明,
为什么还要再"调教"?
把一个大语言模型(LLM)的训练流程类比成培养一个员工,大概长这样:
- 预训练(Pretraining) — 像让 TA 读完整个互联网:百科、小说、代码、论文……得到一个什么都懂一点的通才。
- 监督微调(SFT) — 像入职培训:用"问题→标准答案"的对子告诉 TA:「我们公司客服就是这么说话的」。
- 偏好对齐(RLHF / DPO) — 像绩效反馈:在多个答案里告诉 TA「这条比那条好」,磨细风格与安全性。
这一节我们关心第 2 步。LoRA 不是另一种新的微调流派——它是让"第 2 步"变得便宜的实现技巧。
典型 Post-training 流水线
$\sim 10^{12}$ tokens
$\sim 10^{4}\text{–}10^{6}$ 对
$\sim 10^{4}\text{–}10^{5}$ 偏好
LoRA 主要应用在 SFT 阶段,但也能用于后两步。
SFT 的本质:给模型看「好答案」
SFT 用的损失函数和预训练一模一样——都是 next-token 交叉熵:
$$\mathcal{L}_{\text{SFT}} = - \sum_{t \in \text{response}} \log P_\theta(y_t \mid y_{\lt t},\; x)$$关键差别只有两点:
- 数据少而精:从万亿 token 的网络爬虫,缩到几千到几十万对人工写好的
(指令, 答案)。 - 只对答案算 loss:把 prompt 部分的 label 设成
-100(PyTorch 的 ignore_index),让模型学怎么答,而不是学怎么问。
下面这个交互演示让你看到这件事:
🎯 交互演示 · Loss Mask 可视化
鼠标悬停每个 token,看它是否计入损失。绿色 = 计入;灰色 = 屏蔽。
labels = input_ids.clone()
# 把 prompt 部分置为 -100,让 CE loss 忽略它
labels[: prompt_len] = -100
常见 SFT 数据格式(Alpaca 风)
{
"instruction": "用一句话解释什么是低秩矩阵。",
"input": "",
"output": "低秩矩阵的列空间维度远小于其形状所允许的最大维度。"
}
Stanford Alpaca 用 52K 条这样的样本就把 LLaMA 7B 调出了类 ChatGPT 风格。 [来源]
把全部参数都改一遍?
显存先崩。
给 1B 参数的模型做全参数 SFT,每个可训练参数大约需要:
- 2 字节存权重(fp16/bf16)
- 2 字节存梯度
- 8 字节存 Adam 的两个动量 m, v(fp32)
合计 ~16 字节 / 参数。再加上前向激活、KV cache、梯度累积,实际显存往往是参数量的 20 倍。
结果:7B 模型全参 SFT 已经吃掉一张 A100 80G,175B 直接劝退个人玩家。
📊 交互演示 · 显存估算器
关键直觉:
微调改的东西,本来就"很瘦"
线性代数里,一个 $d \times k$ 矩阵的秩 (rank) 等于它的列向量张成空间的维度。秩越低,这个矩阵"压缩"得越狠——把整个空间压成一个低维子空间。
LoRA 论文有一个大胆的假设:
预训练已经给了模型一个非常好的起点。SFT 让权重的变化量 $\Delta W$ "落在一个内在维度极低的子空间"—— 也就是说,$\Delta W$ 本身就接近一个低秩矩阵。
论文实验显示,在 175B 的 GPT-3 上,把 $\Delta W$ 约束到 秩 1 或 2 都几乎不掉点。 [来源]
右边的演示展示这件事的几何含义:用一张小图像逼近原图,只保留前 r 个奇异值,看 r 多小就够"看出原貌"。
🖼️ 交互演示 · 低秩重构
滑动 r,看一张 $64\times 64$ 的"权重图"被低秩近似还原得有多像。
LoRA 的核心:
把 $\Delta W$ 写成两个小矩阵的乘积
它在做什么?
原本你要训练一个 $d \times k$ 的"修改量" $\Delta W$,这是 $dk$ 个参数。LoRA 说:
既然 $\Delta W$ 反正是低秩的,那我干脆只让你训两个低秩因子 $B, A$,乘起来的 $BA$ 自动是秩 $\le r$ 的矩阵。
参数量从 $dk$ 变成 $r(d+k)$。当 $d=k=4096, r=8$ 时:
- 原本:$4096 \times 4096 = 16{,}777{,}216$
- LoRA:$8 \times (4096 + 4096) = 65{,}536$
- 缩减 256 倍(仅一层;模型有上百层,整体缩减更夸张)
🧮 交互演示 · 参数节省计算器
形状一眼看懂
$d\times k$
$d\times r$
$r\times k$
$B$ 又高又瘦,$A$ 又扁又长——它们的乘积是一个秩最多为 $r$ 的"瘦扁"矩阵 $\Delta W$。
训练时冻结,推理时合并:
LoRA 不增加任何延迟
① 训练阶段
冻结 🧊
训练 🔥
训练 🔥
- 梯度只流过 $A, B$
- 优化器只为 $A, B$ 维护动量
- 显存峰值 $\approx$ $3\text{-}5\times$ 小于全参 SFT
② 部署:保留 adapter
adapter
- 不同任务 → 不同 LoRA 文件(~10-100 MB)
- 一个 base 模型 + N 个 adapter
- 切任务零成本(换文件即可)
③ 部署:合并权重
- 把 $BA$ 加回 $W_0$,得到一个普通的 Transformer
- 推理延迟和原模型完全一样(不像 Adapter 方法)
- 代价:失去多任务切换的灵活性
动手:20 行 PEFT + TRL
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
dataset = load_dataset("tatsu-lab/alpaca", split="train")
peft_config = LoraConfig(
r=8, # 秩,常用 8/16/32
lora_alpha=16, # α = 2r 经验法则
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
target_modules=["q_proj", "v_proj"], # 或 "all-linear"
)
trainer = SFTTrainer(
model="meta-llama/Llama-3.2-1B",
train_dataset=dataset,
args=SFTConfig(
output_dir="lora-alpaca",
num_train_epochs=1,
per_device_train_batch_size=2,
learning_rate=2e-4, # LoRA 用 ~10× 于全参微调
),
peft_config=peft_config,
)
trainer.train()
来源: HF PEFT 文档 · TRL SFTTrainer · QLoRA 论文
调 LoRA 的"七寸"在哪?
r 秩
最重要的旋钮。8 / 16 / 32 是 90% 任务的最佳点。
r 不是越大越好——研究发现 r=256 偶尔最优,但 8 已经能拿到 $\geq 95\%$ 的性能。任务越"窄"(风格模仿、领域知识),r 越小就够。
$\alpha$ alpha
默认规则:$\alpha = 2r$。
$\alpha/r$ 控制 LoRA 输出对原模型的"影响强度"。把 $\alpha$ 翻倍 $\approx$ 把学习率翻倍。如果想要不依赖 r 的稳定尺度,可用 RSLoRA 的 $\alpha/\sqrt{r}$。
target 目标层
原论文只动 $W_q, W_v$。
实践派推荐 all-linear:q, k, v, o + MLP 全套。
覆盖全部线性层效果最佳,代价是 LoRA 参数 $\times 4\text{-}5$。Llama 系如想省,至少加上 MLP 的 down_proj。
lr 学习率
LoRA 通常 1e-4 ~ 3e-4(比全参 SFT 高 10 倍)。
因为只有 BA 在更新,原模型不会被冲坏,可以"放心大胆"上大 lr。RL 阶段(DPO/GRPO)需降到 5e-6。
dropout
典型 0.0 ~ 0.1,常见 0.05。小数据集多加一点防过拟合,大数据集设 0 也行。
🍳 推荐入门配方
r = 16
lora_alpha = 32
lora_dropout = 0.05
target = "all-linear"
lr = 2e-4
epochs = 1-3
batch = 越大越好(用 grad-accum)
LoRA 之后:一个迅速膨胀的家族
一句话总结
"SFT 是用问答对教大模型干活;
LoRA 让你只训两小块矩阵就完成这件事——
参数省 10000 倍,效果几乎一样。"
🧠 该带走的 5 个事实
- SFT 与预训练共享同一个损失,但只在 response token 上算。
- 全参数 SFT 显存 $\approx$ 参数量 $\times 16\text{-}20$ 字节,根本吃不消。
- LoRA 假设 $\Delta W$ 是低秩的,所以分解为 $BA$,只训这 $r(d+k)$ 个数。
- B 零初始化保证训练起点等价于原模型,安全无扰动。
- 推理时可以合并权重,零延迟代价。