关键想法:来自 OBQ (Optimal Brain Quantizer) 的 Hessian 思想。 对每层 $W$ 与一小批校准数据 $X$, 逐列选最优量化值并把"丢的精度"通过 Hessian 更新分摊到其余还没量化的列。 复杂度从原 OBQ 的 $O(n^4)$ 降到 $O(n^2)$ 并 GPU 化。
一句口诀:"量一列、补一列、再量下一列。" Llama 7B/13B INT4 几乎无损,single GPU 几小时完工。 GPTQ 是所有 w4a16 LLM 部署的起点。
Decode 阶段权重读入是头号开销——FP16 → INT4 几乎 4× 加速、内存 4× 缩小。 但激活值里那"1% 的 outlier"会毁掉一切。SmoothQuant 把 outlier 推给权重; GPTQ 用 Hessian 做最优分配;AWQ 保护关键 1% 通道; QuaRot/SpinQuant 用 Hadamard 旋转把 outlier 摊平;BitNet 干脆把 1.58 bit 训进去。
近几年 GPU 加了一堆新数据类型。一张速查表:
| 类型 | bits | 能存的 | 谁支持 |
|---|---|---|---|
| FP32 | 32 | 训练 master / Adam state | 所有 |
| BF16 | 16 | 训练主流通货 | A100/H100/MI300 |
| FP16 | 16 | 历史训练 / 推理 | 所有 |
| FP8 (e4m3 / e5m2) | 8 | Hopper / Blackwell 推理 | H100/H200/MI300 |
| FP6 (e3m2 / e2m3) | 6 | 实验性,多用在 GEMM in / out | Blackwell |
| MXFP4 / NVFP4 | 4 + group scale | 权重 / 激活极致量化 | Blackwell |
| INT8 | 8 | 权重 + 激活的最早 baseline | 所有 |
| INT4 (W4A16) | 4 (with group scale) | GPTQ / AWQ 主流 | 所有(kernel) |
(1) group quantization——大权重矩阵切成 group=128 的小块,
每块共享一个 scale (fp16) + zero-point (int)。
bits/weight = $n + 16/g$,g=128, n=4 → 4.125 bits。
(2) symmetric vs asymmetric——对称 $w \in [-A, A]$ 只需一个 scale,
适合权重。激活通常用非对称(保留 zero-point),因为 ReLU/SiLU 输出非零均值。
Blackwell 新加的 FP4 格式:尾数 2 bit、指数 1-2 bit。 单纯 4-bit 浮点 dynamic range 不够 → 用 block scale: 每 16 或 32 个数共享一个 FP8 scale。"bits/weight = 4 + 8/16 = 4.5"。 和 INT4-group128 (4.125 bits) 在精度上接近,但能直接走 tensor core, 量化版 GEMM 比 INT4 dequantize-then-FP16 快 ~2×。
# 教学版: NVFP4 风格的 block-scaled 量化
def nvfp4_quant(w, block=16):
n = w.shape[-1]
w = w.reshape(*w.shape[:-1], n // block, block)
amax = w.abs().max(dim=-1, keepdim=True).values
# NVFP4 max = 6.0 (e2m1 with 1 sign, 2 exp, 1 mantissa)
scale = (amax / 6.0).to(torch.float8_e4m3fn)
q = (w / scale.float() * 6.0).round().clip(-6.0, 6.0)
return q.to(torch.float4_e2m1fn), scale
关键想法:来自 OBQ (Optimal Brain Quantizer) 的 Hessian 思想。 对每层 $W$ 与一小批校准数据 $X$, 逐列选最优量化值并把"丢的精度"通过 Hessian 更新分摊到其余还没量化的列。 复杂度从原 OBQ 的 $O(n^4)$ 降到 $O(n^2)$ 并 GPU 化。
一句口诀:"量一列、补一列、再量下一列。" Llama 7B/13B INT4 几乎无损,single GPU 几小时完工。 GPTQ 是所有 w4a16 LLM 部署的起点。
设要量化的层是 $Y = WX$,$W \in \mathbb{R}^{d_{out} \times d_{in}}$。 GPTQ 选 $\hat{W}$ 最小化重建误差:
$$ \min_{\hat W} \| WX - \hat W X \|_F^2 $$当我们已经量化前 $i-1$ 列、现要量化第 $i$ 列时, 最优策略是把第 $i$ 列误差通过 Hessian 反向投影到剩余列, 让剩余列吸收。Hessian $H = 2XX^\top$,更新为:
$$ \hat{w}_i = \text{quant}(w_i),\quad w_{j \gt i} \mathrel{+}= -\frac{w_i - \hat w_i}{[H^{-1}]_{ii}} [H^{-1}]_{ij} $$# GPTQ 主循环 (教学版, 省略 group + cholesky)
def gptq(W, X, n_bits=4):
# W: [d_out, d_in], X: [d_in, n_calib]
H = 2 * X @ X.T # [d_in, d_in]
H += 1e-3 * torch.eye(H.shape[0]) # 数值稳定
Hinv = torch.linalg.cholesky(H).inverse().T # 实际用 cholesky 三角避免 inv
Q = torch.zeros_like(W)
for i in range(W.shape[1]):
q_col = quantize_column(W[:, i], n_bits)
Q[:, i] = q_col
err = (W[:, i] - q_col) / Hinv[i, i] # 标量列
# 把误差按 Hinv 行 i 摊给后面所有列
W[:, i+1:] -= err[:, None] * Hinv[i, i+1:][None, :]
return Q
关键观察:只1% 的 weight channel对输出影响巨大—— 它们对应那些激活值常常很大的输入维度。 AWQ 给这些"重要列"乘 $s$,对应的 activation 维度反向除 $s$(数学上等价), 使重要列的 dynamic range 缩小 → 量化误差几乎消失。 无需反向传播,只用前向 stats,分钟级完工。
比 GPTQ 略好且训练成本更低,是 llama.cpp / MLC / vLLM AWQ-kernel 的常驻选项。
# AWQ scale 选择: grid search 找一个 s 让量化 + 反量化后 MSE 最小
def awq_search_scale(W, X, n_bits=4):
# X: [n_calib, d_in], per-channel activation magnitude
act = X.abs().mean(dim=0) # [d_in]
best_loss, best_s = float('inf'), None
for alpha in torch.linspace(0, 1, 20):
s = act.pow(alpha) # [d_in], 重要列得到大 s
s = s / s.mean() # 归一化, 不改总规模
W_scaled = W * s[None, :]
Wq = fake_quant(W_scaled, n_bits)
Y_q = (Wq / s[None, :]) @ X.T
loss = (W @ X.T - Y_q).pow(2).mean()
if loss < best_loss:
best_loss, best_s = loss, s
return best_s
关键观察:W8A8 在 LLM 上崩,是因为 activation 有 "几个通道异常大"——softmax 之前的几个 channel 经常飙到 $\pm 100$。 Weight 是平的,所以两边都 INT8 ⇒ activation 端被"挤死"。
关键想法:定义一个 per-channel scale $s$: $Y = \mathrm{diag}(s)^{-1}\, \mathrm{diag}(s) \, X \, W$。 把 $X$ 的尖刺 "推" 到 $W$ 这一侧, 两边 dynamic range 互相平衡。 选 $s = (\max|X|)^\alpha / (\max|W|)^{1-\alpha}$。 $\alpha = 0.5$ 是甜蜜点。
使纯整数 W8A8 GEMM 在 LLM 上几乎无损 → 极大加速(int8 tensor core 比 fp16 快 2×)。 被 TensorRT-LLM、TurboMind 全数采用。后续 AWQ / QuaRot 都受其启发。
X 的第 $i$ 个 channel 的 max 是 $\sigma^X_i$,W 同位置是 $\sigma^W_i$。 我们想让缩放后两边的 max 接近。设缩放后 X 的 max 是 $\sigma^X_i / s_i$, W 的 max 是 $\sigma^W_i \cdot s_i$。让它们各自均衡到某个目标 $T$:
$$ \frac{\sigma^X_i}{s_i} = T,\quad \sigma^W_i \cdot s_i = T \;\;\Rightarrow\;\; s_i = \sqrt{\sigma^X_i / \sigma^W_i} $$$\alpha$ 是个软化系数——$\alpha=1$ 全推给权重,$\alpha=0$ 全留给激活, 实测 $\alpha = 0.5$ 最优。
关键想法:在每个 layer 进出处插入一对正交 Hadamard 矩阵 $H, H^{-1}$,乘到权重里 — 数学上完全等价, 但 activations 经过 $H$ 之后不再有 outlier (Hadamard 把"一个大值"摊到 $\sqrt{d}$ 个普通值上)。 所以 W4A4 (KV 也 4-bit) 突然可行——Llama-2-70B 仅掉 0.4 PPL。
QuaRot 的 Hadamard 是"硬编码"的。SpinQuant 让旋转矩阵可学, 用 Cayley 参数化保持正交,在校准集上端到端优化。 W4A4KV4 把质量推到 SoTA,Llama-3-70B 几乎无损。
Hadamard 矩阵的每一行是 $\pm 1/\sqrt{d}$,所以 $H x$ 的每个元素是 $x$ 中 所有元素的"加权平均"——大值被均摊到所有维度。 输入信号原本"集中在一两个 channel" → 旋转后"均匀分布在所有 channel"。 然后量化时不再有大 outlier 把 dynamic range 撑爆。 这就是 QuaRot 的核心:把 outlier 从"channel 维"扩散到"全部维度"。
关键想法:weights 只取 $\{-1, 0, +1\}$ 三值($\log_2 3 \approx 1.58$ bits)。 必须从头训练(PTQ 不行),用 STE 反传梯度。 所有 W·X 变 add/subtract—— Blackwell 上 BitNet kernel 比 fp16 GEMM 快 ~10×、内存 ~0.1×。 Llama-3-8B 同等 token 训出来与 fp16 baseline 持平。
路线和 PTQ 系列正交:weights 离散化嵌进训练目标里。 后续 BitNet b1.58 2B/3B(2024-12)证明可商用, BitNet a4.8(2024-10)把 activation 也压到 4-bit。
# BitNet 1.58: 把 weight 投到 {-1, 0, +1}, STE 反传
class BitLinear(nn.Module):
def __init__(self, in_f, out_f):
self.W = nn.Parameter(torch.randn(out_f, in_f))
def forward(self, x):
# 训练时:
w_abs_mean = self.W.abs().mean()
Wq = (self.W / w_abs_mean).round().clamp(-1, 1)
# STE: forward 用 Wq, backward 走 self.W 的梯度
Wq = self.W + (Wq * w_abs_mean - self.W).detach()
# x 也量化到 INT8 per-token
s = 127.0 / x.abs().max(dim=-1, keepdim=True).values
xq = (x * s).round().clamp(-128, 127) / s
return xq @ Wq.T