Hub Part V · 量化
Part V · 量化

把 FP16 weights 压成 4 bit

Decode 阶段权重读入是头号开销——FP16 → INT4 几乎 4× 加速、内存 4× 缩小。 但激活值里那"1% 的 outlier"会毁掉一切。SmoothQuant 把 outlier 推给权重; GPTQ 用 Hessian 做最优分配;AWQ 保护关键 1% 通道; QuaRot/SpinQuant 用 Hadamard 旋转把 outlier 摊平;BitNet 干脆把 1.58 bit 训进去。

本页 6 节 · 1 个 demo
  1. §5.1数据类型大杂烩
  2. §5.2GPTQ · OBQ
  3. §5.3AWQ · 关键通道
  4. §5.4SmoothQuant
  5. §5.5旋转家族
  6. §5.6BitNet · 1.58 bit

§5.1数据类型大杂烩 · 你需要的全部数字

近几年 GPU 加了一堆新数据类型。一张速查表:

类型bits能存的谁支持
FP3232训练 master / Adam state所有
BF1616训练主流通货A100/H100/MI300
FP1616历史训练 / 推理所有
FP8 (e4m3 / e5m2)8Hopper / Blackwell 推理H100/H200/MI300
FP6 (e3m2 / e2m3)6实验性,多用在 GEMM in / outBlackwell
MXFP4 / NVFP44 + group scale权重 / 激活极致量化Blackwell
INT88权重 + 激活的最早 baseline所有
INT4 (W4A16)4 (with group scale)GPTQ / AWQ 主流所有(kernel)
两个最常见的细节

(1) group quantization——大权重矩阵切成 group=128 的小块, 每块共享一个 scale (fp16) + zero-point (int)。 bits/weight = $n + 16/g$,g=128, n=4 → 4.125 bits。
(2) symmetric vs asymmetric——对称 $w \in [-A, A]$ 只需一个 scale, 适合权重。激活通常用非对称(保留 zero-point),因为 ReLU/SiLU 输出非零均值。

NVFP4 / MXFP4 的细节

Blackwell 新加的 FP4 格式:尾数 2 bit、指数 1-2 bit。 单纯 4-bit 浮点 dynamic range 不够 → 用 block scale: 每 16 或 32 个数共享一个 FP8 scale。"bits/weight = 4 + 8/16 = 4.5"。 和 INT4-group128 (4.125 bits) 在精度上接近,但能直接走 tensor core, 量化版 GEMM 比 INT4 dequantize-then-FP16 快 ~2×。

# 教学版: NVFP4 风格的 block-scaled 量化
def nvfp4_quant(w, block=16):
    n = w.shape[-1]
    w = w.reshape(*w.shape[:-1], n // block, block)
    amax = w.abs().max(dim=-1, keepdim=True).values
    # NVFP4 max = 6.0 (e2m1 with 1 sign, 2 exp, 1 mantissa)
    scale = (amax / 6.0).to(torch.float8_e4m3fn)
    q = (w / scale.float() * 6.0).round().clip(-6.0, 6.0)
    return q.to(torch.float4_e2m1fn), scale
Demo 5 · 量化 · 滑动 bit-width 看 MSE
橙线是原始权重(注意两个尖刺 outlier); 绿线是 INT$n$ 反量化后的结果。 把 bit-width 调到 4 仍能跟得很紧——但 outlier 处变扁,MSE 暴涨。 勾上 SmoothQuant 让权重把 outlier "分给" 激活, MSE 立刻回到无 outlier 时的水平—— 这就是 §5.4 要讲的核心招。

§5.2GPTQ · 用二阶信息一层一层量

GPTQ — Accurate Post-Training Quantization for Generative Pre-trained Transformers
ICLR 2023 Frantar, Ashkboos, Hoefler, Alistarh · arXiv:2210.17323

关键想法:来自 OBQ (Optimal Brain Quantizer) 的 Hessian 思想。 对每层 $W$ 与一小批校准数据 $X$, 逐列选最优量化值并把"丢的精度"通过 Hessian 更新分摊到其余还没量化的列。 复杂度从原 OBQ 的 $O(n^4)$ 降到 $O(n^2)$ 并 GPU 化。

一句口诀:"量一列、补一列、再量下一列。" Llama 7B/13B INT4 几乎无损,single GPU 几小时完工。 GPTQ 是所有 w4a16 LLM 部署的起点。

核心更新公式

设要量化的层是 $Y = WX$,$W \in \mathbb{R}^{d_{out} \times d_{in}}$。 GPTQ 选 $\hat{W}$ 最小化重建误差:

$$ \min_{\hat W} \| WX - \hat W X \|_F^2 $$

当我们已经量化前 $i-1$ 列、现要量化第 $i$ 列时, 最优策略是把第 $i$ 列误差通过 Hessian 反向投影到剩余列, 让剩余列吸收。Hessian $H = 2XX^\top$,更新为:

$$ \hat{w}_i = \text{quant}(w_i),\quad w_{j \gt i} \mathrel{+}= -\frac{w_i - \hat w_i}{[H^{-1}]_{ii}} [H^{-1}]_{ij} $$
# GPTQ 主循环 (教学版, 省略 group + cholesky)
def gptq(W, X, n_bits=4):
    # W: [d_out, d_in],  X: [d_in, n_calib]
    H = 2 * X @ X.T                       # [d_in, d_in]
    H += 1e-3 * torch.eye(H.shape[0])     # 数值稳定
    Hinv = torch.linalg.cholesky(H).inverse().T  # 实际用 cholesky 三角避免 inv
    Q = torch.zeros_like(W)
    for i in range(W.shape[1]):
        q_col = quantize_column(W[:, i], n_bits)
        Q[:, i] = q_col
        err = (W[:, i] - q_col) / Hinv[i, i]   # 标量列
        # 把误差按 Hinv 行 i 摊给后面所有列
        W[:, i+1:] -= err[:, None] * Hinv[i, i+1:][None, :]
    return Q

§5.3AWQ · 保护"关键 1% 通道"就够了

AWQ — Activation-aware Weight Quantization
MLSys 2024 Best Lin, Tang, Tang, Yang, Han · MIT · arXiv:2306.00978 · code

关键观察:只1% 的 weight channel对输出影响巨大—— 它们对应那些激活值常常很大的输入维度。 AWQ 给这些"重要列"乘 $s$,对应的 activation 维度反向除 $s$(数学上等价), 使重要列的 dynamic range 缩小 → 量化误差几乎消失。 无需反向传播,只用前向 stats,分钟级完工。

比 GPTQ 略好且训练成本更低,是 llama.cpp / MLC / vLLM AWQ-kernel 的常驻选项。

# AWQ scale 选择: grid search 找一个 s 让量化 + 反量化后 MSE 最小
def awq_search_scale(W, X, n_bits=4):
    # X: [n_calib, d_in],  per-channel activation magnitude
    act = X.abs().mean(dim=0)             # [d_in]
    best_loss, best_s = float('inf'), None
    for alpha in torch.linspace(0, 1, 20):
        s = act.pow(alpha)                # [d_in], 重要列得到大 s
        s = s / s.mean()                  # 归一化, 不改总规模
        W_scaled = W * s[None, :]
        Wq = fake_quant(W_scaled, n_bits)
        Y_q = (Wq / s[None, :]) @ X.T
        loss = (W @ X.T - Y_q).pow(2).mean()
        if loss < best_loss:
            best_loss, best_s = loss, s
    return best_s

§5.4SmoothQuant · 把激活的"刺"挪给权重

SmoothQuant — Accurate and Efficient Post-Training Quantization for LLMs
ICML 2023 Xiao, Lin, Seznec, Wu, Demouth, Han · MIT · arXiv:2211.10438

关键观察:W8A8 在 LLM 上崩,是因为 activation 有 "几个通道异常大"——softmax 之前的几个 channel 经常飙到 $\pm 100$。 Weight 是平的,所以两边都 INT8 ⇒ activation 端被"挤死"。

关键想法:定义一个 per-channel scale $s$: $Y = \mathrm{diag}(s)^{-1}\, \mathrm{diag}(s) \, X \, W$。 把 $X$ 的尖刺 "推" 到 $W$ 这一侧, 两边 dynamic range 互相平衡。 选 $s = (\max|X|)^\alpha / (\max|W|)^{1-\alpha}$。 $\alpha = 0.5$ 是甜蜜点。

使纯整数 W8A8 GEMM 在 LLM 上几乎无损 → 极大加速(int8 tensor core 比 fp16 快 2×)。 被 TensorRT-LLM、TurboMind 全数采用。后续 AWQ / QuaRot 都受其启发。

scale 公式推导

X 的第 $i$ 个 channel 的 max 是 $\sigma^X_i$,W 同位置是 $\sigma^W_i$。 我们想让缩放后两边的 max 接近。设缩放后 X 的 max 是 $\sigma^X_i / s_i$, W 的 max 是 $\sigma^W_i \cdot s_i$。让它们各自均衡到某个目标 $T$:

$$ \frac{\sigma^X_i}{s_i} = T,\quad \sigma^W_i \cdot s_i = T \;\;\Rightarrow\;\; s_i = \sqrt{\sigma^X_i / \sigma^W_i} $$

$\alpha$ 是个软化系数——$\alpha=1$ 全推给权重,$\alpha=0$ 全留给激活, 实测 $\alpha = 0.5$ 最优。

§5.5旋转家族 · QuaRot / SpinQuant · 让所有数变得"乖"

QuaRot — Outlier-Free 4-Bit Inference in Rotated LLMs
NeurIPS 2024 Ashkboos et al. · ETH / IST · arXiv:2404.00456

关键想法:在每个 layer 进出处插入一对正交 Hadamard 矩阵 $H, H^{-1}$,乘到权重里 — 数学上完全等价, 但 activations 经过 $H$ 之后不再有 outlier (Hadamard 把"一个大值"摊到 $\sqrt{d}$ 个普通值上)。 所以 W4A4 (KV 也 4-bit) 突然可行——Llama-2-70B 仅掉 0.4 PPL。

SpinQuant — LLM Quantization with Learned Rotations
ICLR 2025 Liu et al. · Meta · arXiv:2405.16406

QuaRot 的 Hadamard 是"硬编码"的。SpinQuant 让旋转矩阵可学, 用 Cayley 参数化保持正交,在校准集上端到端优化。 W4A4KV4 把质量推到 SoTA,Llama-3-70B 几乎无损。

直觉 · 为什么旋转能消 outlier

Hadamard 矩阵的每一行是 $\pm 1/\sqrt{d}$,所以 $H x$ 的每个元素是 $x$ 中 所有元素的"加权平均"——大值被均摊到所有维度。 输入信号原本"集中在一两个 channel" → 旋转后"均匀分布在所有 channel"。 然后量化时不再有大 outlier 把 dynamic range 撑爆。 这就是 QuaRot 的核心:把 outlier 从"channel 维"扩散到"全部维度"。

§5.6BitNet · 把 weights 压到 1.58 bit

BitNet b1.58 — The Era of 1-Bit LLMs
2024-02 / 2024-08 Ma, Wang, et al. · Microsoft · arXiv:2402.17764

关键想法:weights 只取 $\{-1, 0, +1\}$ 三值($\log_2 3 \approx 1.58$ bits)。 必须从头训练(PTQ 不行),用 STE 反传梯度。 所有 W·X 变 add/subtract—— Blackwell 上 BitNet kernel 比 fp16 GEMM 快 ~10×、内存 ~0.1×。 Llama-3-8B 同等 token 训出来与 fp16 baseline 持平。

路线和 PTQ 系列正交:weights 离散化嵌进训练目标里。 后续 BitNet b1.58 2B/3B(2024-12)证明可商用, BitNet a4.8(2024-10)把 activation 也压到 4-bit。

# BitNet 1.58: 把 weight 投到 {-1, 0, +1}, STE 反传
class BitLinear(nn.Module):
    def __init__(self, in_f, out_f):
        self.W = nn.Parameter(torch.randn(out_f, in_f))
    def forward(self, x):
        # 训练时:
        w_abs_mean = self.W.abs().mean()
        Wq = (self.W / w_abs_mean).round().clamp(-1, 1)
        # STE: forward 用 Wq, backward 走 self.W 的梯度
        Wq = self.W + (Wq * w_abs_mean - self.W).detach()
        # x 也量化到 INT8 per-token
        s = 127.0 / x.abs().max(dim=-1, keepdim=True).values
        xq = (x * s).round().clamp(-128, 127) / s
        return xq @ Wq.T
速查 · 怎么挑量化方案
  • 已有训好的 fp16 模型 + 1 GPU + 几小时 → GPTQ / AWQ
  • 想跑 W8A8 整数 GEMM → SmoothQuant
  • 想 W4A4 极致量化(连 KV 也 4-bit)→ QuaRot / SpinQuant
  • 有训练 budget,想 4× 推理加速 → BitNet b1.58(必须从头训)。
  • H100 / Blackwell 卡 → 优先考虑 FP8 / NVFP4(硬件原生)。