Post-Training 面试速记: 从分词到对齐

整理自 Day3 复习笔记。面向”白纸能写出来 + 能讲明白”的面试目标,每节固定四块:Key Insight → 实现层次图 → 代码 → 高频追问


目录

  1. MinHash — 海量文档去重
  2. Tokenizer 三件套: BPE / WordPiece / Unigram
  3. 反向传播 — Softmax + Attention 手推
  4. Memory-Efficient Attention — Online Softmax → FlashAttention
  5. Reward Model — Bradley-Terry pairwise
  6. RL Fine-tuning — 策略梯度 + 组内归一化
  7. GAE — 广义优势估计
  8. PPO — Clipped Surrogate + KL
  9. GRPO — PPO 砍掉 Value 网络
  10. DPO — 把 RM + RL 合成一个分类损失
  11. 一图全景: 四种对齐算法对比

1. MinHash — 海量文档去重

Key Insight: 对一个随机 hash \(h\), \(\boxed{\,P\bigl(\min h(A) = \min h(B)\bigr) = J(A, B)\,}\)。 用 \(K\) 个独立 hash 取最小,签名向量逐位相等的比例就是 Jaccard 的无偏估计。

为什么需要

预训练前必须去重: 重复样本会让模型记忆而不是泛化,还会让评测集泄露到训练集。 两两算 Jaccard 是 O(N²),百万文档跑不动。MinHash 把”集合比较”压缩到”向量比较”。

实现层次

                                ┌──────────────────────────┐
原始文档(string)               │  create_shingles(text, k) │   预处理
   │                            │  滑动 k 字符窗口          │
   └──────────────────────────►│  → set of strings        │
                                └─────────────┬────────────┘
                                              ▼
                                       shingles set
                                              │
                                ┌─────────────▼────────────────────┐
                                │  compute_signature(shingles)     │   签名生成
                                │  for shingle in document:        │
                                │    for i, (a,b) in hash_funcs:   │
                                │      h = (a*x + b) mod prime     │
                                │      sig[i] = min(sig[i], h)     │
                                └─────────────┬────────────────────┘
                                              ▼
                                  signature [num_perm]
                                              │
                                ┌─────────────▼─────────────┐
                                │  similarity(sig1, sig2)    │   相似度
                                │  Σ(sig1[i]==sig2[i])/N    │
                                │  ≈ Jaccard(A, B)          │
                                └────────────────────────────┘

工程外壳 (在 MinHash 类外):
   LSH 分桶 → 桶内两两比签名 → Union-Find 聚簇 → 每簇留一篇

代码骨架

class MinHash:
    def __init__(self, num_perm=128):
        self.num_perm = num_perm
        self.prime = 2147483647
        self.hash_funcs = []
        for i in range(num_perm):
            seed = hashlib.sha256(f"minhash_{i}".encode()).digest()
            a = int.from_bytes(seed[:4], "big") % self.prime or 1
            b = int.from_bytes(seed[4:8], "big") % self.prime
            self.hash_funcs.append((a, b))           # h(x) = (a*x + b) mod p

    def compute_signature(self, document: set) -> list:
        sig = [float("inf")] * self.num_perm
        for shingle in document:
            for i, (a, b) in enumerate(self.hash_funcs):
                sig[i] = min(sig[i], self._compute_hash(shingle, a, b))
        return sig

    def similarity(self, sig1, sig2) -> float:
        return sum(x == y for x, y in zip(sig1, sig2)) / len(sig1)

关键定理的证明梗概

命题: 对一个随机排列 \(\pi\),

\[P\bigl(\min \pi(A) = \min \pi(B)\bigr) = J(A, B) = \frac{|A \cap B|}{|A \cup B|}\]

证明:

  • 令 \(U = A \cup B\),在 \(U\) 上看 \(\pi\) 的最小值。\(\pi\) 在 \(U\) 上是均匀随机的全排列。
  • \(\min \pi(A) = \min \pi(B) \;\Longleftrightarrow\;\) \(U\) 中最小值同时属于 \(A\) 和 \(B\) \(\;\Longleftrightarrow\;\) \(U\) 中最小值属于 \(A \cap B\)。
  • 由对称性,\(U\) 中任一元素被选为最小的概率均为 $$1/ U $$。
  • 所以概率 $$= A \cap B / U = A \cap B / A \cup B = J(A, B)\(。\)\blacksquare$$

工程化: 真实排列 \(\pi\) 太贵,用 universal hash family \(h(x) = (a x + b) \bmod p\)(\(p\) 是大素数,\(a, b\) 随机) 近似一个随机排列 —— 工程上够好。

估计精度 (\(K\) = num_perm 怎么选)

\(K\) 个独立 hash 各自取最小,签名等值率 \(\hat{J} = X / K\),其中 \(X \sim \mathrm{Binomial}(K, J)\)。

\[\mathbb{E}[\hat{J}] = J \quad\text{(无偏)},\qquad \mathrm{Var}[\hat{J}] = \frac{J(1 - J)}{K} \quad\text{(方差随 } K \text{ 线性下降)}\]
\(K\) 标准差 (\(J=0.5\)) 95% CI 宽度
64 ±6.3% ~12.5%
128 ±4.4% ~8.7%
256 ±3.1% ~6.2%
1024 ±1.6% ~3.1%

工业经验: 128-256 起步,严格去重(LSH 阈值 0.7+)可以加到 512。

LSH 分桶: 把 \(\mathcal{O}(N^2)\) 压成近线性

直接两两比签名仍是 \(\mathcal{O}(N^2 K)\)。LSH (Locality-Sensitive Hashing) banding 解法:

把 \(K\) 个 hash 切成 \(b\) 个 band,每 band 包含 \(r\) 行 (\(b \cdot r = K\))。 两文档在至少一个 band 内 \(r\) 行全等 → 进入同桶 → 候选对。

S 曲线: 两文档 Jaccard 为 \(J\),被选为候选对的概率:

\[P(\text{候选}) = 1 - (1 - J^r)^b\]
\((b, r)\) \(J=0.5\) \(J=0.7\) \(J=0.8\) \(J=0.9\)
(20, 5) 0.47 0.97 1.00 1.00
(50, 4) 0.95 1.00 1.00 1.00
(25, 10) 0.02 0.36 0.82 1.00

\((b, r)\) 控制 S 曲线的陡峭程度和阈值位置: 想滤掉 \(J < 0.7\),选大 \(r\) 小 \(b\);想宽松些,反之。 经验阈值公式: \(\tau \approx (1/b)^{1/r}\),在 \(P = 0.5\) 处的拐点。

完整去重 pipeline

def dedupe_corpus(documents):
    mh = MinHash(num_perm=256)
    signatures = [mh.compute_signature(shingles(d)) for d in documents]

    # LSH banding 分桶
    buckets = defaultdict(list)
    b, r = 50, 4   # 50 bands × 4 rows = 200 (用前 200 维)
    for idx, sig in enumerate(signatures):
        for band in range(b):
            key = (band, tuple(sig[band*r : (band+1)*r]))
            buckets[key].append(idx)

    # Union-Find 聚簇
    uf = UnionFind(len(documents))
    for bucket in buckets.values():
        if len(bucket) > 1:
            for j in bucket[1:]:
                if mh.similarity(signatures[bucket[0]], signatures[j]) > 0.8:
                    uf.union(bucket[0], j)

    # 每簇留一篇
    return [documents[i] for i in uf.representatives()]

时间复杂度: 签名 O(N·K·|shingles|) + 分桶 O(N·b) + 桶内验证 O(候选数 · K)近线性

数值例子

d1 = "The quick brown fox jumps over the lazy dog."
d2 = "The quick brown fox jumps over the lazy dog!"   # 只差一个标点
d3 = "Completely unrelated sentence about cats."

sim(d1, d2) ≈ 0.97   ← 近重复 (会被聚到同一簇)
sim(d1, d3) ≈ 0.00   ← 不相关

MinHash vs SimHash vs 嵌入

方法 估计什么 输入要求 典型用途
MinHash Jaccard (集合) shingle 集 (字符 / 词 n-gram) 文档级去重
SimHash Cosine (向量) 特征向量 (TF-IDF 等) 网页指纹、相似搜索
Embedding (BGE, OpenAI) Cosine in 语义空间 文本 → 向量 语义去重、检索
Bloom Filter 完全相等 任意 精确查重

LLM 预训练去重: 先 MinHash + LSH 粗筛,再 embedding 模型在候选对上精筛(可选)。 RedPajama, FineWeb, SlimPajama 全都用这套。

高频追问

Q A
为什么 min-hash 碰撞概率 = Jaccard? 见上面证明:U=A∪B 中的最小元素均匀分布,同时属于交集的概率 = |A∩B|/|A∪B|
num_perm 怎么选? 标准差 ≈ √(J(1-J)/K);128 ±4.4%,256 ±3.1%。代价是签名内存和比较时间
MinHash 用什么 hash? 工程上用 (a·x + b) mod p (universal hash);近似随机排列。p 通常 2^31 - 1
LSH banding 的 (b, r) 怎么调? 想 J>τ 的对都被召回 → 让 S 曲线在 τ 附近过 0.5: τ ≈ (1/b)^(1/r)
为什么 shingle 用 5-gram 字符? 平衡: 太短(字符)→ 无信息;太长(句子)→ 微小差异就 miss。5-7 字符是经验最优
MinHash vs SimHash? MinHash 估 Jaccard(集合相似),适合 shingle 集;SimHash 估 cosine(向量相似),适合特征向量
为什么去重对 LLM 重要? (1) 重复 → 记忆而非泛化 (2) 评测数据泄露 (3) 浪费算力。GPT-3, LLaMA, Chinchilla 论文都强调
工业级 dedup 流程? 文档级 MinHash+LSH → 句子级 SimHash/Bloom → 评测集污染检查 (n-gram 重叠);三层都过

2. Tokenizer 三件套: BPE / WordPiece / Unigram

Key Insight: 都是 subword,唯一区别在”用什么准则合并/删除 token”。 BPE = 频率;WordPiece = 互信息;Unigram = 删除后 likelihood 下降最小。

为什么需要 subword

方案 问题
Word-level 词表巨大(50万+),OOV 严重
Char-level 序列太长,模型学不到语义
Subword(BPE/WP/Uni) 词表 30k~100k,常用词整 token,稀有词降到子词

实现层次 (以 BPE 为例)

══════════ 训练阶段 (Training) ══════════

原始 corpus (list[str])
   │
   ▼  get_word_freqs(corpus)            按空格切词 + 计数
{"low": 5, "lower": 2, ...}
   │
   ▼  get_splits(word)                  拆字符 + </w>
{"low": [l,o,w,</w>], "lower": [l,o,w,e,r,</w>], ...}
   │
   ▼  ╔═════════════════════════════════════════════╗
      ║  loop merges_needed 次:                      ║
      ║    pairs = get_stats(splits)                ║
      ║    best  = argmax(pairs)        # 频率最高  ║
      ║    splits = merge_pair(splits, best)        ║
      ║    merges.append(best)                      ║
      ╚═════════════════════════════════════════════╝
   │
   ▼
final_vocab + merges (有序列表) + merges_lookup (pair→优先级)


══════════ 编码阶段 (Encoding) ══════════

word
   │
   ▼  get_splits(word)
[c1, c2, ..., </w>]
   │
   ▼  ╔═════════════════════════════════════════════╗
      ║  apply_bpe:                                  ║
      ║    while True:                              ║
      ║      在相邻 pair 里找 merges_lookup          ║
      ║      优先级最高 (索引最小) 的那个            ║
      ║      没有 → break;有 → 合并所有出现位置     ║
      ╚═════════════════════════════════════════════╝
   │
   ▼  vocab 查询
[id1, id2, ...]

2.1 BPE — 频率最高的 pair 合并

完整跑一遍 ["low", "lower", "lowest"]:

init splits:
  low:    [l, o, w, </w>]
  lower:  [l, o, w, e, r, </w>]
  lowest: [l, o, w, e, s, t, </w>]

第1轮 pair 频率: (l,o)=3, (o,w)=3, (w,e)=2, ...
  → 合并 (l, o)
  → splits 里 l,o 相邻处全变 lo

第2轮: (lo,w)=3 ← 最高 → 合并

第3轮: (low,e)=2 ← 最高 → 合并

... 直到攒够 vocab_size 个 token

编码代码:

def apply_bpe(word, merges_lookup):
    splits = list(word) + ["</w>"]
    while True:
        # 找当前相邻 pair 里"优先级最高"(最早学到)的
        best_pair, best_pri = None, float("inf")
        for i in range(len(splits) - 1):
            pri = merges_lookup.get((splits[i], splits[i+1]))
            if pri is not None and pri < best_pri:
                best_pair, best_pri = (splits[i], splits[i+1]), pri
        if best_pair is None:
            break
        splits = merge_in_place(splits, best_pair)
    return splits

关键: 编码必须严格按训练时的合并顺序重放,否则同一个词可能被切成不同 token,模型会崩。

2.2 WordPiece (BERT) — 互信息最大的 pair 合并

唯一的区别: 合并准则换成互信息

BPE:        argmax  count(A, B)
WordPiece:  argmax  count(A, B) / (count(A) × count(B))

直觉: WordPiece 选「A 和 B 真的经常一起出现」而非「A 自己就常见」。

词边界标记: 用 ## 表示”接续片段”:

"playing" → ["play", "##ing"]
"unhappiness" → ["un", "##happy", "##ness"]

OOV 处理: 拼不出的输出 [UNK](BPE 会用字符兜底,WP 不会 → 更脆)。

2.3 Unigram LM (T5, SentencePiece 默认) — 反向删词表

自上而下删:

1. 启发式构造一个超大候选词表(100万)
2. 训练 Unigram 语言模型: p(token)
3. 对每个词,Viterbi 找概率最高的切分
4. 计算每个 token 的"删除损失"(整个语料 likelihood 下降)
5. 删损失最小的 10%
6. 重复 2-5,直到词表 = vocab_size

关键: 一个词有多种切分方式,选概率最高的(或随机采样做数据增强)。

2.4 Byte-level BPE (GPT-2/3/4, LLaMA-3)

不在字符层,在 UTF-8 字节层做 BPE:

"你好" → UTF-8 字节 [228, 189, 160, 229, 165, 189] → 当作 6 个初始 token
  • 初始词表固定 256(所有字节值)
  • 任何 Unicode 都能编码 → 永远不会 OOV
  • GPT-2 用 Ġ 等可见字符替代不可见字节(仅显示,不影响算法)

三者对比

维度 BPE WordPiece Unigram
训练方向 自下而上加 自下而上加 自上而下删
合并/删除准则 频率最高 互信息最大 删除后 likelihood 下降最小
词边界标记 </w> 后缀 ## 续接前缀 空格前缀(SP)
编码 按 merges 贪心 最长前缀贪心 Viterbi 最优切分
OOV 处理 字符级兜底 [UNK] 字符级兜底
代表模型 GPT, LLaMA, RoBERTa BERT, DistilBERT T5, ALBERT, XLNet

坑题: SentencePiece 不是算法,是 Google 的库,默认 Unigram,也支持 BPE。LLaMA 用 SentencePiece-BPE 模式。

高频追问

Q A
为什么 BPE 需要 </w>? 区分”词尾 e”(est</w>)和”词中 e”,避免错误合并跨词边界
编码是确定性的吗? BPE/WP 是;Unigram 默认是(选概率最高),训练时可采样做正则
同一个词不同上下文切分会变吗? 不会(BPE/WP);Unigram 在 subword regularization 下会
为什么 GPT 用 byte-level? 多语言通用 + 代码符号 + 永远无 OOV
WordPiece 的 [UNK] 怎么来的? 输入有训练时没见过的字符,且无法拼回 vocab 子串

3. 反向传播 — Softmax + Attention 手推

Key Insight: Softmax 的 Jacobian \(J = \mathrm{diag}(s) - s s^\top\),向量化后 \(\nabla_z = s \odot (g - \langle g, s\rangle)\)。 Attention 反向 = 把前向的 4 个矩阵乘逆序求转置,中间夹一个 softmax_backward

实现层次

══════════ Forward ══════════
                                  ┌─ Q [N, D]
              S = Q·Kᵀ/√D         │
   ────────────────────────────── │  ↓ matmul + scale
              A = softmax(S)      ├─ K [M, D]
              out = A·V           │  ↓ softmax
   ────────────────────────────── │  ↓ matmul
                                  │
                                  └─ V [M, D]


══════════ Backward (反向链) ══════════

grad  ← upstream gradient on out
  │
  ├─►  dV = Aᵀ · grad                      [M, D]
  │           ▲
  │           └─ out = A·V → ∂out/∂V = Aᵀ
  │
  ├─►  dA = grad · Vᵀ                      [N, M]
  │           ▲
  │           └─ out = A·V → ∂out/∂A 走 V
  │
  ├─►  dS = softmax_backward(A, dA)        [N, M]
  │           │
  │           │  对每一行 s_i = A[i]:
  │           │  J = diag(s_i) - s_i·s_iᵀ
  │           │  向量化: s_i ⊙ (dA[i] - (dA[i]·s_i))
  │           ▼
  ├─►  dQ = dS · K / √D                    [N, D]
  │
  └─►  dK = dSᵀ · Q / √D                   [M, D]
              ▲
              └─ S 的列 ↔ K 的行,所以要转置

Softmax Jacobian 的完整推导(必背)

对一行 \(s = \mathrm{softmax}(z)\),即 \(s_i = \dfrac{e^{z_i}}{\sum_k e^{z_k}}\)。

Case 1: \(i = j\)

\[\frac{\partial s_i}{\partial z_i} = \frac{e^{z_i} \cdot \sum_k e^{z_k} - e^{z_i} \cdot e^{z_i}}{\left(\sum_k e^{z_k}\right)^2} = s_i - s_i^2 = s_i (1 - s_i)\]

Case 2: \(i \neq j\)

\[\frac{\partial s_i}{\partial z_j} = \frac{0 \cdot \sum_k e^{z_k} - e^{z_i} \cdot e^{z_j}}{\left(\sum_k e^{z_k}\right)^2} = -s_i s_j\]

合并(用 Kronecker \(\delta\)):

\[\frac{\partial s_i}{\partial z_j} = s_i (\delta_{ij} - s_j) \quad\Longleftrightarrow\quad J = \mathrm{diag}(s) - s s^\top \quad\text{(对称矩阵)}\]

Jacobian-Vector Product 化简(为什么不用显式构造 \(J\))

显式构造 \(J\) 是 \(\mathcal{O}(M^2)\),但 backward 只要 \(g \cdot J\)(VJP):

\[\begin{aligned} (g \cdot J)_j &= \sum_i g_i \cdot s_i (\delta_{ij} - s_j) \\ &= g_j s_j - s_j \sum_i g_i s_i \\ &= s_j \bigl(g_j - \langle g, s\rangle\bigr) \end{aligned}\]

→ 向量化: \(\boxed{\,s \odot \bigl(g - \langle g, s\rangle\bigr)\,}\),只要 \(\mathcal{O}(M)\) 乘加,不用建 \(J\)。

# 显式 (慢, O(M²)):              vs       向量化 (快, O(M)):
J = np.diag(s) - np.outer(s, s)          out = s * (g - (g * s).sum(-1, keepdims=True))
out = g @ J                              # 直接对 batch 维广播

Softmax + Cross-Entropy = \(p - y\) (经典化简)

CE loss: \(L = -\sum_i y_i \log s_i\),其中 \(y\) 是 one-hot。

\[\begin{aligned} \frac{\partial L}{\partial z_j} &= \sum_i \frac{\partial L}{\partial s_i} \cdot \frac{\partial s_i}{\partial z_j} = \sum_i \left(-\frac{y_i}{s_i}\right) s_i (\delta_{ij} - s_j) \\ &= -y_j + s_j \underbrace{\sum_i y_i}_{= 1\ (\text{one-hot})} \\ &= s_j - y_j \end{aligned}\]

所以 logits 上的梯度就是 \(\mathrm{softmax}(z) - \mathrm{onehot}(\text{target})\),完全不用建 Jacobian。 这就是为什么 deep learning 框架把 softmax + CE 融成一个 op (F.cross_entropy(logits, targets)),数值更稳、反向更快。

Attention 反向链(QKV 完整推导)

前向:

\[S = \frac{Q K^\top}{\sqrt{D}},\qquad A = \mathrm{softmax}(S,\,\text{dim}=-1),\qquad \mathrm{out} = A V\]

已知 \(\mathrm{dout}\)(上游梯度,shape \([N, D]\))。

反向 1: \(\mathrm{out} = A V\) (普通矩乘)

\[\mathrm{d}V = A^\top \cdot \mathrm{dout} \quad [M, D],\qquad \mathrm{d}A = \mathrm{dout} \cdot V^\top \quad [N, M]\]

反向 2: \(A = \mathrm{softmax}(S, \text{dim}=-1)\) (逐行 softmax)

\[\mathrm{d}S[i] = \mathrm{softmax\_backward}(A[i], \mathrm{d}A[i]) = A[i] \odot \bigl(\mathrm{d}A[i] - \langle \mathrm{d}A[i], A[i]\rangle\bigr)\]

反向 3: \(S = Q K^\top / \sqrt{D}\) (matmul + scale)

\[\mathrm{d}Q = \frac{\mathrm{d}S \cdot K}{\sqrt{D}} \quad [N, D],\qquad \mathrm{d}K = \frac{\mathrm{d}S^\top \cdot Q}{\sqrt{D}} \quad [M, D]\]

记忆法: 反向是前向4 个矩阵乘的逆序求转置,中间夹一个 softmax_backward。\(\sqrt{D}\) 的位置和前向一样(标量,直接除)。

代码骨架

def softmax_backward(s, grad_s):
    # 向量化版: s * (grad_s - (grad_s * s).sum(-1, keepdims=True))
    N, M = s.shape
    grad = np.zeros((N, M), dtype=np.float32)
    for i in range(N):
        J = np.diag(s[i]) - s[i][:, None] @ s[i][None, :]
        grad[i] = grad_s[i] @ J
    return grad

def sdpa_backward(grad, q, k, v):
    D = q.shape[1]
    A = softmax(q @ k.T / np.sqrt(D))
    dv = A.T @ grad
    dA = grad @ v.T
    dS = softmax_backward(A, dA)
    dq = dS @ k / np.sqrt(D)
    dk = dS.T @ q / np.sqrt(D)
    return dq, dk, dv

验证: 和 PyTorch autograd 对比,atol=1e-4 内 close。

为什么除 \(\sqrt{D}\)? (方差分析)

设 \(Q[i], K[j]\) 各元素 i.i.d. \(\sim \mathcal{N}(0, 1)\), 维度 \(D\)。

\[\begin{aligned} S[i,j] &= \sum_{k=1}^{D} Q[i,k] \cdot K[j,k] \\ \mathbb{E}[S] &= 0 \\ \mathrm{Var}[S] &= \sum_k \mathrm{Var}[Q \cdot K] = D \cdot 1 = D \quad\text{(方差线性增长)} \\ \mathrm{Std}[S] &= \sqrt{D} \end{aligned}\]

\(D = 64\) 时 \(\mathrm{Std}[S] \approx 8\),\(D = 128\) 时 \(\approx 11.3\)。 进 softmax 前 logits 太大 → softmax 退化为 one-hot → 梯度消失

除 \(\sqrt{D}\): \(\mathrm{Var}[S/\sqrt{D}] = D/D = 1\),无论 \(D\) 多大方差稳定。 所以这个 \(\sqrt{D}\) 是为了梯度健康,不是为了 forward 数值稳定

数值稳定的 softmax(减 max 技巧)

def softmax(z):
    z = z - z.max(-1, keepdims=True)   # ★ 减 max
    e = np.exp(z)
    return e / e.sum(-1, keepdims=True)

为什么不影响结果:

\[\mathrm{softmax}(z)_i = \frac{e^{z_i - c}}{\sum_k e^{z_k - c}} = \frac{e^{z_i} \cdot e^{-c}}{e^{-c} \sum_k e^{z_k}} = \frac{e^{z_i}}{\sum_k e^{z_k}}\]

任意常数 \(c\) 上下消去 → 结果不变。

为什么减最大值不溢出: 减完后所有元素 \(\le 0\),\(e^{\le 0} \in (0, 1]\),绝不会上溢。

训练时 attention 反向的显存代价

每一层 attention,正向产生:

  • S [B, H, N, M] 中间矩阵
  • A [B, H, N, M] softmax 输出

反向都要用,所以必须保存 → 显存 O(B·H·N²)。 对 LLaMA-7B (32 层, 32 头),序列 4k → ~16 GB 仅 attention 中间量。 这就是为什么 FlashAttention(§4)能让训练装下更大的模型

高频追问

Q A
为什么 softmax 数值稳定要减 max? exp(z) 大数会上溢;softmax(z) == softmax(z - max(z))(分子分母同乘常数)
为什么 attention 要除 √D? 点积方差 = D → softmax 饱和 → 梯度消失。除 √D 让 logit 方差稳定为 1
softmax + cross-entropy 为什么简化成 softmax - onehot? 见上面推导:-y/s · J 展开后 one-hot 把求和压到一个分量,剩 p - y
为什么不显式建 Jacobian? J 是 O(M²),但 VJP 化简后 s ⊙ (g − ⟨g,s⟩) 只要 O(M)
反向能 inplace 写 dq, dk, dv 吗? 不行,共用 A 的话会乱;PyTorch 是 save_for_backward 保留前向中间量
训练时哪几个张量必须存? 前向: Q, K, V, A(softmax 输出);反向重算 S 或直接复用 A。FlashAttention 只存 (O, m, l) 三个
为什么 PyTorch autograd 默认是 VJP 不是 JVP? 神经网络 loss 是标量,VJP (反向模式) 1 次反向得所有参数梯度;JVP 要 dim(params) 次正向
softmax(z/T) 中 T 的作用? T → 0: one-hot(贪心);T → ∞: 均匀。LLM 推理用来控制采样随机性
推导一下 dz = s · (g - g·s),为什么不是 s · g - s² · g? (g · s)内积标量 Σ g_i·s_i,不是逐元素 g⊙s。是同一个标量减回 g 的每个分量

4. Memory-Efficient Attention — Online Softmax → FlashAttention

Key Insight: 数值稳定的 safe softmax 本身最少要 2 pass(必须先扫一遍找 \(m_N\) 才能减 max)。 但 attention 的最终目标是 \(O = A V\),不是 \(A\) 本身 —— 对 \(O\) 再施一次 “surrogate” 技巧,可以把整个 attention 压到 1 pass。这就是 FlashAttention 的全部魔法。

参考: From Online Softmax to FlashAttention (Zihao Ye, UW CSE 599M)

实现层次: 3 pass → 2 pass → 1 pass

══════════ Pass 1: Safe Softmax (3 passes) ══════════

   pass 1:  m_N = max(x_1, ..., x_N)          ← 扫一遍取全局 max
   pass 2:  d_N = Σ exp(x_i - m_N)             ← 再扫一遍算分母
   pass 3:  a_i = exp(x_i - m_N) / d_N         ← 第三遍得 softmax

   问题: 长序列里 logits {x_i} 装不下 SRAM,3 次 pass 就要 3 次重算 Q·Kᵀ


══════════ Pass 2: Online Softmax (Milakov 2018) ══════════

   关键: 用 surrogate d'_i 代替依赖 m_N 的 d_i

      d'_i := Σ_{j=1}^i exp(x_j - m_i)         ← 只用当前 max m_i,不用 m_N

   递推:
      m_i  = max(m_{i-1}, x_i)
      d'_i = d'_{i-1} · exp(m_{i-1} - m_i)     ← 旧累加用 exp(Δm) 重新缩放
             + exp(x_i - m_i)

   性质: 当 i = N 时 d'_N = d_N,所以可以用 d'_N 替换 d_N

   pass 1:  循环里同时算 m_i 和 d'_i
   pass 2:  a_i = exp(x_i - m_N) / d'_N

   但 softmax 本身**没法压到 1 pass** —— a_i 必须等 m_N 算完


══════════ Pass 3: FlashAttention (1 pass) ══════════

   关键观察: 我们不需要 a_i,只需要 O = Σ a_i · V[i,:]

   再施一次 surrogate,对 O 做递推:

      o'_i := Σ_{j=1}^i (exp(x_j - m_i) / d'_i) · V[j,:]

      o'_i = o'_{i-1} · (d'_{i-1} · exp(m_{i-1} - m_i)) / d'_i
                       + (exp(x_i - m_i) / d'_i) · V[i,:]

   性质: o'_N = O[k, :]

   一个循环里同时维护 (m_i, d'_i, o'_i) 三个状态,扫一遍 K, V 就完事

单 pass 算法 (核心 5 行)

def flash_attention_row(q_k, K, V):
    """对 Q 的第 k 行计算 attention 输出 (single pass over K, V)."""
    D = V.shape[1]
    m, d, o = -np.inf, 0.0, np.zeros(D)
    for i in range(K.shape[0]):
        x_i   = q_k @ K[i] / np.sqrt(D)                        # logit
        m_new = max(m, x_i)
        d_new = d * np.exp(m - m_new) + np.exp(x_i - m_new)
        o     = o * (d * np.exp(m - m_new)) / d_new \
              + (np.exp(x_i - m_new) / d_new) * V[i]
        m, d  = m_new, d_new
    return o   # = O[k, :]

工程版: tiled (块大小 b 进 SRAM)

def flash_attention_tiled(Q, K, V, b):
    L, D = Q.shape
    O = np.zeros_like(Q)
    for k in range(L):                                         # 行间天然并行
        m, d, o = -np.inf, 0.0, np.zeros(D)
        for start in range(0, L, b):                           # 沿 K, V 滑动
            Kc, Vc  = K[start:start+b], V[start:start+b]       # 一块进 SRAM
            x       = Q[k] @ Kc.T / np.sqrt(D)                 # [b]
            m_new   = max(m, x.max())
            scale   = np.exp(m - m_new)
            e       = np.exp(x - m_new)                        # [b]
            d_new   = d * scale + e.sum()
            o       = o * (d * scale) / d_new + (e @ Vc) / d_new
            m, d    = m_new, d_new
        O[k] = o
    return O

硬件视角: 为什么省 HBM

朴素:    HBM ↔ S [N, M] ↔ HBM ↔ A = softmax(S) [N, M] ↔ HBM ↔ A·V ↔ HBM
         中间矩阵 S, A 反复读写 HBM  (这是真正的瓶颈)

Flash:   HBM → load Q[k], K[i:i+b], V[i:i+b] → SRAM tile 内算完 (m,d,o) → HBM 只写 O[k]
         中间结果 (S 的 tile, A 的 tile) 永远不出 SRAM

H100 单个 SM 的 SRAM ≈ 228 KB(比 HBM 快 ~30×)。整体 SRAM footprint 只和 b, D 有关,和序列长度 L 无关 → 这就是 FlashAttention 能撑 16k+ context 的原因。

反向传播: 重算 (recomputation) 而非缓存

朴素 attention 的反向(见 §3) 需要保留 [N, M]A 矩阵;长序列下显存爆炸。 FlashAttention 的解法: 正向只保存每行的标量 (m, l),反向时在每个 tile 内重新算一遍 softmax

正向保存:                反向重算:
  O   [L, D]              for each Kⱼ, Vⱼ tile:
  m   [L]   (running max)   重算  S_ij = Qᵢ Kⱼᵀ / √D
  l   [L]   (running sum)   重算  A_ij = exp(S_ij - mᵢ) / lᵢ   ← 用保存的 mᵢ, lᵢ
                            标准 chain rule:
                              dV_j  += A_ijᵀ · dO_i
                              dA_ij  = dO_i · V_jᵀ
                              dS_ij  = softmax_backward(A_ij, dA_ij)
                                     = A_ij ⊙ (dA_ij - D_i)
                                     其中  D_i = (dO_i · O_iᵀ).sum(-1)    ← 一个标量
                              dQ_i  += dS_ij · K_j / √D
                              dK_j  += dS_ijᵀ · Q_i / √D

关键 trick: 那个标量 \(D_i\) —— 朴素 softmax 反向要 \(\langle g, s\rangle\) 内积(见 §3), FlashAttention 里 \(g = \mathrm{d}A\), \(s = A\) 都不显式存。但代入 \(O = A V\) 可以证明:

\[\sum_j \mathrm{d}A_{ij} \cdot A_{ij} = \sum_j (\mathrm{d}O_i \cdot V_j^\top) \cdot A_{ij} = \mathrm{d}O_i \cdot (\sum_j A_{ij} V_j)^\top = \langle \mathrm{d}O_i,\ O_i\rangle\]

所以可以用已经存好的 \(O, \mathrm{d}O\) 直接算出这个标量,完全不需要 materialize \(A\)

正反向显存对比

方案 正向显存 反向额外显存 总计
朴素 attention \(\mathcal{O}(N^2)\) 存 \(S, A\) 0 (复用前向) \(\mathcal{O}(N^2)\)
FlashAttention \(\mathcal{O}(N D)\) 存 \(O\) + \(\mathcal{O}(N)\) 存 \((m, l)\) tile 缓冲 \(\mathcal{O}(b^2)\) \(\boxed{\mathcal{O}(N D)}\)
序列 16k, D 128 ~256 MB ~8 MB (\(32\times\))

IO 复杂度证明梗概(为什么是 \(\mathcal{O}(N^2 D^2 / M_{\mathrm{SRAM}})\))

FlashAttention 论文给出: 朴素 attention 的 HBM 访问量

\[\mathrm{IO}_{\mathrm{naive}} = \Theta(N D + N^2)\]

FlashAttention 的 HBM 访问量

\[\mathrm{IO}_{\mathrm{flash}} = \Theta\!\left(\frac{N^2 D^2}{M_{\mathrm{SRAM}}}\right)\]

当 \(D \ll \sqrt{M_{\mathrm{SRAM}}}\) 时(典型 \(D=128\), \(M_{\mathrm{SRAM}}=100\mathrm{KB}\) → \(\sqrt{M} \approx 300\)), Flash 的 HBM 访问 少 1-2 个数量级。这正是 FlashAttention 实际加速 \(2\)-\(4\times\) 的来源(不是 FLOPs,而是 IO)。

高频追问

Q A
为什么 softmax 最少 2 pass? 数值稳定要先有 m_N 才能减;d'_i 的 surrogate 让 m_i 替代 m_N,但 a_i 仍依赖 m_N
为什么 FlashAttention 能 1 pass? 目标是 O = A·V 而非 A,对 O 再施一次 surrogate trick (o'_i) 消去 m_N 依赖
exp(m_{i-1} - m_i) 在做什么? 把”用旧 max 累加的量”修正成”用新 max 累加的”,保持数值等价
FlashAttention 为什么快?(常见误解) 不是减 FLOPs,是减 HBM 读写。tiling 让中间矩阵留在 SRAM
反向怎么实现? 只存 (m, d'),反向重算 softmax + V 乘积;FLOPs 换显存
FlashAttention 2 / 3 改了什么? v2: 交换内外循环 + 减 non-matmul FLOPs;v3: Hopper warp-specialization + FP8 异步
为什么 Flash 对 long-context LLM 关键? SRAM 占用与 L 无关,只与 b, D 有关 → 64k context 也能稳跑

5. Reward Model — Bradley-Terry pairwise

Key Insight: 不要让人打绝对分(尺度不一致),只让人选 A/B 哪个好。 Bradley-Terry: \(P(y_w \succ y_l) = \sigma(r_w - r_l)\),损失 \(= -\log \sigma(r_w - r_l)\)。

实现层次

            ┌──────── chosen 分支 ────────┐    ┌──────── rejected 分支 ────────┐
            │                              │    │                                │
input:      prompt + y_w  [B, T]                prompt + y_l  [B, T]
            │                                   │
            ▼ backbone (causal LM)              ▼ backbone (共享权重,前向 2 次)
            h_w [B, T, hidden]                  h_l [B, T, hidden]
            │                                   │
            ▼ 取最后一个 token (左 padding 保证)  ▼ 取最后一个 token
            h_w[:, -1, :] [B, hidden]           h_l[:, -1, :] [B, hidden]
            │                                   │
            ▼ reward_head (Linear → 1)          ▼ reward_head (同一个)
            r_chosen   [B]                      r_rejected [B]
                │                                   │
                └─────────────┬─────────────────────┘
                              ▼
                Bradley-Terry pairwise loss
                L = −log σ(r_chosen − r_rejected)
                              ▼
                          backward
                  (训练 backbone + reward_head)

公式

每个回答 \(y\) 有隐式效用 \(r(x, y)\),偏好概率:

\[P(y_w \succ y_l \mid x) = \sigma\bigl(r(x, y_w) - r(x, y_l)\bigr)\]

最大化对数似然 → 损失:

\[\mathcal{L}_{\mathrm{RM}} = -\log \sigma\bigl(r(x, y_w) - r(x, y_l)\bigr)\]

代码骨架

class RewardModel(nn.Module):
    def __init__(self, backbone, hidden_dim):
        super().__init__()
        self.backbone = backbone
        self.reward_head = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        h = self.backbone(x)              # [B, T, hidden]
        last = h[:, -1, :]                # 最后一个 token(causal 模型看到整句)
        return self.reward_head(last).squeeze(-1)

def reward_model_loss(r_chosen, r_rejected):
    return -torch.log(torch.sigmoid(r_chosen - r_rejected)).sum()

工程细节:

  • 左 padding(encode_batch_left_pad),保证 h[:, -1, :] 是真实结尾
  • Backbone 通常 init from SFT checkpoint,reward_head 随机初始化
  • 训练时 chosen / rejected 共享 backbone 前向,只跑 2 次

Bradley-Terry 的等价 logistic 写法

\(-\log \sigma(r_w - r_l)\) 其实就是二分类 logistic loss,把”\(y_w \succ y_l\)“当作 label = 1。 等价于在 reward 差上做二分类:

\[\sigma(r_w - r_l) = P(\text{chosen wins}) \quad\text{(模型预测)}\] \[\text{target} = 1,\qquad \text{binary CE} = -\log \sigma(r_w - r_l) \quad\text{(标准 logistic loss)}\]

→ 任何稳定的 logistic 损失 trick(如 label smoothing)都能直接套上。

奖励的尺度问题(reward scaling)

RM 训完后,raw reward 的尺度任意(取决于 backbone 初始化、训练步数等)。 不归一化直接送 PPO 会出问题:reward 数量级波动 → KL 系数失效 → policy 训练不稳

常用归一化:

# 方法 1: 跑 RM 在大批 prompt 上算 mean/std,做 z-score (InstructGPT 做法)
rewards = (rm(x) - reward_mean) / reward_std

# 方法 2: per-batch 归一化(简单)
rewards = (r - r.mean()) / (r.std() + 1e-8)

# 方法 3: 取 baseline (减 SFT 模型在同 prompt 上的 reward)
shaped = rm(policy_out) - rm(sft_out)

Reward Hacking 是什么

policy 学会”骗”RM 拿高分但人类不喜欢的行为。常见模式:

模式 例子 修法
长度作弊 RM 训练时 chosen 倾向更长 → policy 全写超长答案 length penalty / SimPO
格式作弊 加 “Sure!” “Of course.” 这种 RM 喜欢的开头 reward shaping / 多样化训练数据
拒答作弊 模糊回答避免被 RM 惩罚 helpfulness vs harmlessness 多 RM
重复回应 复述 prompt 占 token repetition penalty in reward
数学幻觉 装作算了一长串 (RM 看格式不看正确性) 用 verifier 替代 RM (GRPO 数学路线)

→ 这就是为什么 PPO 要加 KL 惩罚 β·KL(π‖π_ref):约束 policy 不要为了高 reward 偏离 SFT 太远

多目标 Reward Model (Helpfulness vs Harmlessness)

Anthropic Constitutional AI / Claude 路线:两个 RM,加权合并:

\[R = \alpha \cdot R_{\mathrm{helpful}} + \beta \cdot R_{\mathrm{harmless}}\]

\(\alpha, \beta\) 是动态调的: 检测到不安全输入 → \(\beta\) 提高;一般问答 → \(\alpha\) 主导。 也可以训成单个 RM 输出多维标量,然后融合层学加权。

Pairwise vs Pointwise vs Listwise

范式 数据 Loss 典型
Pointwise (prompt, response, score 1-5) MSE 早期人工评分,人标不一致
Pairwise (BT) (prompt, chosen, rejected) -log σ(r_w - r_l) InstructGPT, Claude
Listwise (Plackett-Luce) (prompt, ranked list of k) k 选 1 的概率乘积 排序信息更密集

Pairwise 工业标配 —— 标注成本 / 一致性 / 算法成熟度三方面最优。

评估 RM 的指标

Accuracy on held-out pairs: 给定 (chosen, rejected),RM 算出 r_w > r_l 的比例。 经验值: 训得好的 RM 在新数据上 ~65-75%(注意:这个 ceiling 比想象低,因为偏好本身不一致)。

KL-controlled win rate: 让 PPO 训完的 policy 在固定 KL 距离 β 下, 对比 baseline (SFT 或上一版 policy) 的人类胜率。这才是 RM 真正的”线上指标”。

高频追问

Q A
为什么取最后一个 token? Causal mask 下,最后一个位置的 hidden 看到了完整 prompt + response
Reward hacking 是什么? Policy 学会钻 RM 漏洞(长度/格式/拒答作弊),拿到高 reward 但人类不喜欢。修: 加 KL,length penalty,多目标 RM
RM 和 DPO 的关系? DPO 在数学上等价于”训 RM 然后 PPO”,但消去 RM 这一步直接优化 policy(见 §10)
多个 chosen 怎么处理? (a) 列出所有 pairwise (chosen, rejected) 对逐对算损失 (b) 用 listwise Plackett-Luce
Reward 要不要归一化? 必须。raw reward 尺度任意 → PPO 不稳。z-score 或减 SFT baseline
RM 训多少数据? InstructGPT 用 ~33k pairs;开源 (Anthropic HH-RLHF) ~170k;实际工业线 100k-1M 量级
为什么 BT 假设可疑但还能用? BT 假设 reward 是固定标量函数,实际人类偏好可能 cyclic。但平均下来够好
RM 能直接做 inference rank 吗? 可以(best-of-N 采样,挑 r 最高的)。计算成本: N 倍 inference
Pairwise 数据怎么收集? 同 prompt 不同 temperature 采 2 个 → 人标 chosen/rejected。或 GPT-4 当 judge (RLAIF)
RM 会不会”偏向自己生成”? 会。如果 RM 训练数据是 SFT 模型采样的,RM 会偏好 SFT 风格 → 限制了 PPO 的探索

6. RL Fine-tuning — 策略梯度 + 组内归一化

Key Insight: 策略梯度的方差太大,必须减 baseline。 组内均值当 baseline(同 prompt 才可比),既降方差又不用单独训 value 网络。

实现层次

                  ┌─────────── 采样阶段 ───────────┐
                  │                                 │
   prompt [B] ────┤  policy.sample × G             │
                  │  (一个 prompt 采 G 个 rollout)  │
                  └────────────────┬───────────────┘
                                   ▼
              rollouts [B, G, T]   ←─── 每个回答的 token 序列
                                   │
                  ┌────────────────┼────────────────┐
                  ▼                                 ▼
       Reward Model 打分                  policy 重新前向 (要梯度)
              rewards [B, G]                  logits [B, G, T, V]
                  │                                 │
                  ▼ compute_group_advantage         ▼ gather + (-log_softmax)
       (r - mean(-1)) / (std(-1) + ε)      neg_logp [B, G, T]
       advantage [B, G]                            │
                  │                                 │
                  └──────────────┬──────────────────┘
                                 ▼
                policy_gradient_loss
                = (neg_logp * advantage[..., None]).sum(-1).mean(1).sum()
                  − entropy_weight · entropy
                                 │
                                 ▼
                            backward → 更新 policy

为什么要 advantage,不直接用 reward?

\[\nabla L = -\mathbb{E}\bigl[\nabla \log \pi(a) \cdot R\bigr] \quad\text{(方差大)}\] \[\nabla L = -\mathbb{E}\bigl[\nabla \log \pi(a) \cdot (R - b)\bigr] \quad\text{(减 baseline } b\text{ 不改变期望但降方差)}\]

\(b\) 选什么?

  • 经典: value network \(V(s)\)(actor-critic)
  • GRPO 路线: 同 prompt 的组内均值(省一个 value 网络)

代码骨架

def compute_group_advantage(rewards):     # [B, G]
    return (rewards - rewards.mean(-1, keepdim=True)) / \
           (rewards.std(-1, keepdim=True) + 1e-8)

def policy_gradient_loss(neg_logp, advantage, entropy, entropy_weight=1e-3):
    # neg_logp: [B, G, T]  ; advantage: [B, G]
    adv_loss = neg_logp * advantage.unsqueeze(-1)
    return adv_loss.sum(-1).mean(1).sum() - entropy_weight * entropy

熵正则: 鼓励探索,防止过早坍缩到一种回答模式。

策略梯度的”血脉” (REINFORCE → A2C → PPO → GRPO)

1992  Williams  REINFORCE        ∇L = -E[∇logπ · R]            ← 纯朴素,方差爆炸
                                                              
1999  Sutton    Policy Gradient  ∇L = -E[∇logπ · A]            ← 引入 advantage
      Theorem                       A = R - b(s)                   降方差
                                                              
2016  Mnih      A2C/A3C          A = Σ γ^k r_{t+k} - V(s_t)    ← 用 value 网络当 baseline
                                                              
2017  Schulman  PPO              加 ratio + clip + KL            ← 防策略跑飞
                                                              
2024  DeepSeek  GRPO             去掉 value,组内 z-score 当 A   ← 省一个模型

每代都在解决前一代的痛点。今天 RLHF 实际用的就是 PPOGRPO 两条线。

完整训练循环 (伪代码)

def rl_finetune(policy, ref, rm, prompts, epochs):
    for epoch in range(epochs):
        for batch_prompts in prompts:
            # ===== 采样阶段 (no grad) =====
            with torch.no_grad():
                rollouts = []
                for _ in range(group_size):                # G 个 rollout per prompt
                    out = policy.generate(batch_prompts)
                    rollouts.append(out)
                rewards = rm(rollouts)                     # [B, G]

            # ===== 优势归一化 =====
            advantage = (rewards - rewards.mean(-1, keepdim=True)) / \
                        (rewards.std(-1, keepdim=True) + 1e-8)

            # ===== 训练阶段 (需要 grad) =====
            for _ in range(ppo_epochs):                    # PPO 多轮利用同一批 rollout
                logits = policy(rollouts)                  # 重新前向
                neg_logp = F.cross_entropy(logits, rollouts, reduction='none')
                entropy  = compute_entropy(logits)

                loss = (neg_logp * advantage.unsqueeze(-1)).sum(-1).mean() \
                     - entropy_weight * entropy

                # 可选: KL 惩罚
                kl = compute_kl(policy, ref, rollouts)
                loss += kl_coef * kl

                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

On-Policy vs Off-Policy 的微妙差别

On-policy: 用当前 π 采样的数据更新 π。每次梯度步后,旧数据就作废Off-policy: 用别的 π (历史 / SFT / 别人的) 采样的数据更新 π。需要 importance sampling 修正。

范式 数据来源 修正 复杂度 算法
纯 on-policy π_θ (当前) 慢 (一步一采) REINFORCE, vanilla PG
近 on-policy π_θ_old (几步前) importance ratio 中 (一批数据多步) PPO, GRPO
纯 off-policy replay buffer / SFT 强 IS / Q-learning 难 (分布外严重) DDPG, SAC, DPO

PPO 和 GRPO 的妙处: 用 importance ratio 复用同一批 rollout 训多步,省了 90%+ 的采样成本

为什么减 baseline 不改变期望(严格证明)

\[\begin{aligned} \mathbb{E}_\pi\bigl[\nabla \log \pi(a \mid s) \cdot b(s)\bigr] &= b(s) \int \pi(a \mid s)\,\nabla \log \pi(a \mid s)\,\mathrm{d}a \\ &= b(s) \int \nabla \pi(a \mid s)\,\mathrm{d}a \quad\text{(log 求导技巧)} \\ &= b(s) \cdot \nabla \int \pi(a \mid s)\,\mathrm{d}a \\ &= b(s) \cdot \nabla 1 = 0 \end{aligned}\]

\(b(s)\) 只要不依赖 action \(a\)(可以依赖 state \(s\))都不改期望;但会改方差

最优 baseline:

\[b^*(s) = \frac{\mathbb{E}\bigl[(\nabla \log \pi)^2 \cdot R\bigr]}{\mathbb{E}\bigl[(\nabla \log \pi)^2\bigr]}\]

(但实践用 \(V(s)\) 已经够好)。

熵正则的具体作用

\[\mathcal{L} = -\mathbb{E}\bigl[\log \pi \cdot A\bigr] - \beta_{\mathrm{ent}} \cdot H(\pi) \qquad (H \text{ 越大,}\pi \text{ 越分散})\]
\(\beta_{\mathrm{ent}}\) 效果
0 容易 mode collapse,最后只剩 1-2 个高 reward 答案
\(10^{-3}\) ~ \(10^{-2}\) 典型范围 (InstructGPT 用 \(10^{-2}\) 量级)
\(> 0.1\) \(\pi\) 太接近均匀分布,根本学不到偏好

对 LLM,熵指 token 级别 \(H = -\sum_v p_v \log p_v\) 在 vocab 维度上的均值。

高频追问

Q A
为什么减 baseline 不改变期望? E[∇logπ · b] = b · ∇∫π = 0,只要 b 不依赖 action a(可依赖 state s)
为什么用组内均值? 同 prompt 下 reward 才可比;跨 prompt 平均无意义,reward 尺度差太大
on-policy vs off-policy? on: 当前 π 采样,数据用 1 次;近 on: PPO/GRPO 用 IS ratio 重用;off: DPO 直接用离线偏好数据
熵正则太强 / 太弱? 太强 → 不收敛,policy 一直瞎采;太弱 → 早期 mode collapse
和 GRPO 关系? 这一节就是 GRPO 的”裸版”(没有 PPO 的 ratio + clip);加上后变成 §9
REINFORCE 为什么不实用? 没有 baseline → 方差极大;每个梯度步都要 fresh 采样 → 慢
为什么 LLM 不用 actor-critic 的 V(s)? (a) state 是整个 token 序列,V(s) 很难训 (b) GRPO 用组内 baseline 已经够好
group size G 的影响? G 太小 (≤2) 组内 std 不稳定;G 太大 (>64) 显存爆且边际收益递减;典型 8-16
奖励信号为什么稀疏(只在 EOS 给)? 用 RM 打分时,RM 是 sequence-level 输出。token 级 reward 需要 step-wise RM 或 process RM
token-level 还是 sequence-level loss? 序列级 advantage 广播到所有 token 是主流 (GRPO/PPO);DPO 直接序列级 logp 差

7. GAE — 广义优势估计

Key Insight: 用 \(\lambda\) 在「低偏差高方差(蒙特卡洛回报)」和「高偏差低方差(一步 TD)」之间插值。 关键计算:从后往前递推 \(A_t = \delta_t + \gamma\lambda \cdot A_{t+1}\)。

实现层次

   输入: rewards [B, T] , values [B, T] (来自 value head)
                                │
   ════════ 从后往前递推 ════════
                                │
   t = T-1 ──┐  next_V = 0      │
             │  δ = r_{T-1} - V_{T-1}
             │  A_{T-1} = δ
             ▼
   t = T-2 ──┐  next_V = V_{T-1}
             │  δ = r_{T-2} + γ·V_{T-1} - V_{T-2}
             │  A_{T-2} = δ + γλ·A_{T-1}
             ▼
   ...        ↓ 不断累乘 γλ
   t = 0   ──┐  next_V = V_1
             │  δ = r_0 + γ·V_1 - V_0
             │  A_0 = δ + γλ·A_1
             ▼
   advantages [B, T]
                                │
                                ▼
   returns = advantages + values   (用于 value head 的 MSE)
   
   极端情况:
     λ=0  ⇒  A_t = δ_t                   (一步 TD,低方差高偏差)
     λ=1  ⇒  A_t = Σ γ^k·δ_{t+k}         (≈MC return − V,高方差低偏差)

公式

TD 残差:

\[\delta_t = r_t + \gamma \cdot V(s_{t+1}) - V(s_t)\]

GAE 递推(从后往前):

\[A_t = \delta_t + \gamma \lambda \cdot A_{t+1}\]

代码骨架

def compute_gae(rewards, values, gamma=0.99, lam=0.95):
    advantages = torch.zeros_like(values)
    lastgae = 0.0
    T = rewards.shape[-1]
    for t in reversed(range(T)):
        next_value = values[:, t+1] if t < T-1 else torch.zeros_like(values[:, t])
        delta = rewards[:, t] + gamma * next_value - values[:, t]
        lastgae = delta + gamma * lam * lastgae
        advantages[:, t] = lastgae
    return advantages

数值验证: γ = λ = 1 时 GAE 应该等于 cumsum(rewards, from end) - values

n-step Returns 复习 (GAE 的前身)

回报的几种估计:

\[\begin{aligned} \text{1-step TD:}\quad &\hat{A}_t^{(1)} = r_t + \gamma V(s_{t+1}) - V(s_t) \quad (= \delta_t)\\ \text{2-step:}\quad &\hat{A}_t^{(2)} = r_t + \gamma r_{t+1} + \gamma^2 V(s_{t+2}) - V(s_t)\\ \text{k-step:}\quad &\hat{A}_t^{(k)} = \sum_{l=0}^{k-1} \gamma^l r_{t+l} + \gamma^k V(s_{t+k}) - V(s_t)\\ \infty\text{-step (MC):}\quad &\hat{A}_t^{(\infty)} = \sum_{l=0}^{\infty} \gamma^l r_{t+l} - V(s_t) \end{aligned}\]

偏差-方差权衡:

  • \(k\) 小: bootstrapping \(V\) 多,\(V\) 不准 → 偏差大;但 Var 小,因为只用了 \(k\) 步真实 reward
  • \(k\) 大: 用更多真实 reward → 偏差小;但 Var 大,因为长 trajectory 累加随机性

GAE = n-step Advantages 的指数加权平均

Schulman 2015 的 GAE 定义:

\[A_t^{\mathrm{GAE}(\gamma,\lambda)} = (1 - \lambda) \sum_{k=1}^{\infty} \lambda^{k-1} \hat{A}_t^{(k)} \quad\text{(各 k-step 的指数加权)}\]

代入 \(\hat{A}^{(k)}\) 展开,化简后等于 \(\delta_t\) 的指数加权:

\[A_t^{\mathrm{GAE}} = \sum_{l=0}^{\infty} (\gamma \lambda)^l \delta_{t+l} = \delta_t + \gamma\lambda \cdot \delta_{t+1} + (\gamma\lambda)^2 \cdot \delta_{t+2} + \cdots\]

这就是为什么 GAE 的递推这么简单: \(\boxed{A_t = \delta_t + \gamma\lambda \cdot A_{t+1}}\)(一行)。

两个极端的严格推导

\(\lambda = 0\): 只保留 \(k=1\) 项

\[A_t^{\mathrm{GAE}} = \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t) \quad\text{(一步 TD)}\]

bias 高(完全信任 \(V\)),variance 低(只有一项随机)。

\(\lambda = 1\): 所有项加权

\[A_t^{\mathrm{GAE}} = \sum_{l=0}^{\infty} \gamma^l \delta_{t+l} = \sum_{l=0}^{\infty} \gamma^l r_{t+l} - V(s_t) \quad\text{(MC return} - V\text{)}\]

bias 低(用真实 reward),variance 高(累加无限项)。

→ 选 \(\lambda \in (0, 1)\) 在两者间插值,典型 \(\lambda = 0.95\)(略偏蒙特卡洛)。

γ 和 λ 的实际差别

超参 含义 典型 调谁
γ (折扣) 未来 reward 衰减率 0.99 任务相关:长 horizon 任务接近 1,短 horizon 可低些
λ (GAE) n-step 加权 0.95 bias-variance 调优,通常固定 0.9-0.97

LLM RLHF 里: γ ≈ 1(序列只到 EOS,无需折扣)、λ ≈ 0.95-1.0。

Value Head 训练

Actor-critic 同时优化 policy 和 value:

class PolicyValueModel(nn.Module):
    def __init__(self, backbone):
        super().__init__()
        self.backbone   = backbone
        self.policy_head = nn.Linear(d, vocab_size)
        self.value_head  = nn.Linear(d, 1)          # 标量 V(s)

    def forward(self, x):
        h = self.backbone(x)
        return self.policy_head(h), self.value_head(h).squeeze(-1)

# 训练
logits, V = model(rollouts)
A_gae     = compute_gae(rewards, V, gamma=0.99, lam=0.95)
returns   = A_gae + V                                 # 即 V_target

policy_loss = -mean(logp * A_gae.detach())            # advantage detach!
value_loss  = F.mse_loss(V, returns.detach())         # returns detach!
loss = policy_loss + 0.5 * value_loss

: A 和 returns 都要 .detach(),否则 value 梯度会污染 policy 梯度。

稀疏 reward 是怎么传回去的

LLM RLHF 中,RM 只在 EOS 给一个 scalar reward。其他 token reward \(= 0\)。 GAE 把这个唯一的非零 reward 折扣回传:

奖励序列:

\[r = [0, 0, 0, \ldots, 0, R_{\mathrm{final}}]\]

当 \(\gamma = 1, \lambda = 1\):

\[A_t = R_{\mathrm{final}} - V(s_t) \quad\text{(所有 t 都拿到 } R_{\mathrm{final}} \text{ 信号)}\]

当 \(\gamma < 1, \lambda < 1\):

\[A_t = \gamma^{T-t-1} \cdot R_{\mathrm{final}} + \text{bootstrap}(V)\]

→ 即使 reward 极稀疏,GAE + \(V\) 也能给每个 token 一个 advantage 信号。

高频追问

Q A
γ 和 λ 各管什么? γ 是回报折扣(管”未来 reward 看多远”);λ 管”信任 value 估计 vs 真实回报”的程度
为什么从后往前? 递推式 A_t = δ_t + γλ·A_{t+1} 依赖 t+1 的值,所以反向遍历
GAE 怎么从 n-step 推出来? n-step advantages 的指数加权平均,化简得 Σ (γλ)^l · δ_{t+l},等价于 A_t = δ_t + γλ·A_{t+1}
Actor 和 Critic 为什么共享 backbone? 省显存 + 表征共享;value head 只是个 Linear(dim, 1)
Value loss 怎么算? MSE(V(s_t), returns_t),returns = advantages + values
value 和 reward 的区别? reward 是即时奖励 r_t;value V(s_t) 是从 s_t 起的预期总折扣回报
advantage 为什么要 detach? 进 policy loss 时要,否则 value 反向梯度会污染 policy
稀疏 reward (只有最后一步)怎么办? GAE 会把最后一步的 reward 沿 γ^t 折扣回传到所有时间步
GAE 在 LLM 里是 token 级还是序列级? token 级:每个 token 有自己的 A_t。但 reward 一般只在 EOS 给
为什么 GRPO 不用 GAE? GRPO 砍掉 value 网络,直接用组内 z-score。代价: 没法 token 级精细化

8. PPO — Clipped Surrogate + KL

Key Insight: 用 importance sampling ratio 修正 off-policy + clip 防止策略跑飞 + KL 约束别偏离 SFT 太远。 4 个模型 (policy / ref / reward / value) 是 PPO 工程痛点。

实现层次

══════════ RLHF 三步走 ══════════
   Step 1: SFT             → π_SFT
   Step 2: 训 Reward Model → RM   (见 §5)
   Step 3: PPO 优化 π             ← 本节

══════════ PPO 一步更新的 4 个模型 ══════════

                        ┌──────────────────────────────┐
                        │  采样: π_old.sample(prompt)   │
                        │  得到 (a, logp_old) 缓存       │
                        └────────────┬──────────────────┘
                                     │ rollout
                                     ▼
   ┌───────────────┬─────────────────┴────────────────┬───────────────┐
   ▼               ▼                                  ▼               ▼
policy(πθ)      ref(π_SFT)                     reward(RM)         value(V)
要梯度           冻结                            冻结                要梯度
   │               │                                  │               │
logp_new         logp_ref                          rewards          V(s)
   │               │                                  │               │
   │     KL(π_θ ‖ π_ref) per token                    │       GAE(rewards, V)
   │               │                                  │               │
   │               └─── β·KL ───┐                     │               │
   │                            ▼                     │               ▼
   │                  shaped_reward = r − β·KL        │           advantages
   │                                                  │
   └────── ratio = exp(logp_new − logp_old) ──────────┘
                       │
            ┌──────────┴──────────┐
            ▼                     ▼
       ratio·A          clip(ratio, 1±ε)·A
            └───────── min() ─────┘
                       ▼
                   −mean = PPO loss
                       ▼
                   backward → 更新 policy + value

公式

重要性采样比 + clipped surrogate:

\[\mathrm{ratio}_t = \exp\bigl(\log \pi_\theta(a_t) - \log \pi_{\mathrm{old}}(a_t)\bigr)\] \[\mathcal{L}_{\mathrm{PPO}} = -\mathbb{E}\Bigl[ \min\bigl(\mathrm{ratio} \cdot A,\ \mathrm{clip}(\mathrm{ratio},\,1-\varepsilon,\,1+\varepsilon) \cdot A\bigr) \Bigr]\]

为什么 clip: 防止单步更新太大让策略跑飞。\(\varepsilon = 0.2\) 是经典值。

KL 惩罚

\[\mathrm{advantage} = \mathrm{reward} - \beta \cdot \mathrm{KL}\bigl(\pi_\theta \,\|\, \pi_{\mathrm{ref}}\bigr)\]

\(\pi_{\mathrm{ref}}\) = SFT 的冻结副本,全程不更新。这个 KL 防止 policy 偏离语言模型本能(避免 reward hacking)。

代码骨架

def ppo_loss(logp_new, logp_ref, advantages, clip=0.2):
    ratio = torch.exp(logp_new - logp_ref.detach())
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip, 1.0 + clip) * advantages
    return -torch.mean(torch.min(unclipped, clipped))

def apply_kl_penalty(rewards, logp_new_dist, logp_ref_dist, kl_coef=0.01):
    kl = (torch.exp(logp_new_dist) * (logp_new_dist - logp_ref_dist.detach())).sum(-1)
    return rewards - kl_coef * kl

Clip 的可视化(理解 min 的含义)

                Advantage > 0 (好动作,想拉高 logp)
                ┌──────────────────────────────┐
                │  ratio 增大方向                │
                │                              │
   ratio·A      │      ╱                       │ ← unclipped:无限抬升
                │     ╱                        │
                │    ╱                         │
                │   ╱                          │
   1+ε ─────────│──╱──────────── clipped       │ ← clipped:被天花板限制
                │ ╱            ────────────    │
                │╱                             │
   ratio=1 ─────┼──────────────────────────────│
                ↑                              ↓
                │  ratio 减小方向                │
                │                              │
              对 A > 0: 选 min → 取被 clip 的(保守,防过度拉高)


                Advantage < 0 (差动作,想压低 logp)
                ┌──────────────────────────────┐
   ratio=1 ─────┼──────────────────────────────│
                │\                             │
                │ \                            │
                │  \                           │
   1-ε ─────────│---\---────────── clipped     │
                │    \           ─────────     │
                │     \                        │
                │      \                       │
   ratio·A      │       \                      │ ← unclipped:无限下压
                │        \                     │
                └──────────────────────────────┘

              对 A < 0: 选 min → 取没被 clip 的(继续压低,不饶过坏动作)

核心直觉: 对正 A,限上界防止过度乐观;对负 A,不限下界保留惩罚力度。 这是 PPO 比 TRPO 简洁但等效的关键。

Multi-Epoch 训练: 为什么能复用同一批 rollout

朴素 on-policy: 一批数据用 1 个梯度步,扔掉,重新采样 — 极慢。 PPO: 一批数据用 k 个 epoch(典型 k=4)的梯度更新,通过 ratio 修正 off-policy 偏差:

for batch in batches:
    rollouts, logp_old = sample(policy)          # 采一次
    for ppo_epoch in range(4):                    # ← 复用 4 次
        logp_new = policy(rollouts)
        ratio = exp(logp_new - logp_old)         # ratio 起飞表示 policy 漂太远
        loss = clipped_surrogate(ratio, A)
        loss.backward(); step()

ratio 越接近 1 → 策略和 old 没差多远 → IS 估计准;ratio 离 1 远 → clip 生效避免崩。 多 epoch 是 PPO 比 REINFORCE 快 10× 的根本原因

Token-level vs Sequence-level PPO

LLM 里的 \(a_t\) 是 token,不是动作。两种实现:

Token-level (主流):

\[\mathrm{ratio}_t = \frac{\pi_\theta(a_t \mid \cdot)}{\pi_{\mathrm{old}}(a_t \mid \cdot)}, \qquad \mathcal{L} = -\frac{1}{T}\sum_t \min\!\bigl(\mathrm{ratio}_t A_t,\ \mathrm{clip}(\mathrm{ratio}_t, 1\pm\varepsilon)\,A_t\bigr)\]

Sequence-level (DPO 风格):

\[\log \pi(\mathrm{seq}) = \sum_t \log \pi(a_t),\quad \mathrm{ratio} = e^{\log \pi_\theta(\mathrm{seq}) - \log \pi_{\mathrm{old}}(\mathrm{seq})},\quad \mathcal{L} = -\min(\mathrm{ratio} \cdot A_{\mathrm{seq}},\ \mathrm{clip} \cdot A_{\mathrm{seq}})\]

Token-level 更细致,能给 GAE 提供 per-token \(A_t\);sequence-level 更稳定,梯度信号简洁。

4 个模型 (PPO 工程痛点)

模型 角色 是否更新 显存代价
policy 主角,被 RL 训练 1× + 优化器 ~5×
ref (= SFT 冻结) 算 KL 惩罚 1× (inference only)
reward RM 给分 1× (inference only)
value GAE baseline (actor-critic) ✅ (共享 backbone) 0.1× (只是个 head)

总显存 ≈ 3-4 倍模型大小。这就是为什么 RLHF PPO 难做 70B+ 模型,GRPO 把 value 砍了缓解。

PPO 完整 loss

def ppo_total_loss(logp_new, logp_old, A, V, returns, dist_new, dist_ref):
    # 1. Policy loss (clipped surrogate)
    ratio   = torch.exp(logp_new - logp_old)
    policy  = -torch.min(ratio * A,
                          torch.clamp(ratio, 1-ε, 1+ε) * A).mean()

    # 2. Value loss
    value   = F.mse_loss(V, returns)

    # 3. Entropy bonus
    entropy = -(dist_new.exp() * dist_new).sum(-1).mean()

    # 4. KL penalty (optional, 也可以放在 reward shaping 里)
    kl      = (dist_new.exp() * (dist_new - dist_ref)).sum(-1).mean()

    return policy + c_v * value - c_ent * entropy + β_kl * kl

InstructGPT 经验值: c_v = 0.5, c_ent = 0.01, β_kl = 0.1, ε = 0.2

高频追问

Q A
为什么要 clip? 单步更新太大会让 ratio 离 1 太远 → importance sampling 估计崩;clip 提供”信任域”
ε 太大/太小? 太大失去 clip 意义,等同朴素 PG;太小更新太保守,收敛慢。0.2 是经验最优
为什么需要 KL? RM 是不完美的代理,policy 全力对抗 RM 会 reward hacking;KL 限制偏离 SFT
Min(unclipped, clipped) 的直觉? 保守的那一边:正 A 限上界、负 A 不限下界(本来就要压)
Ratio 用 logp_new - logp_old 不是 logp_ref? old采样时的 policy 快照 (每个 minibatch 之前 detach);ref 是 SFT 副本,做 KL 用
PPO 为什么能多 epoch? 用 ratio 修正 off-policy,只要 ratio 没漂太远(被 clip 控住),数据可复用 4-8 次
token-level 还是 sequence-level? LLM 通常 token-level,提供 per-token advantage;sequence-level 适合稀疏 reward
value head 怎么初始化? (a) 从 policy 复用 backbone + 随机初始 head (b) 单独训几个 epoch 让 value 先 warm up
PPO vs TRPO? TRPO 用二阶 + KL 信任域 (严格但贵);PPO 用一阶 + clip 近似 (便宜 90% 效果)
4 个模型显存怎么压? (a) ref 用 LoRA delta (b) value 共享 backbone (c) reward 量化 (d) 用 GRPO 砍 value
ratio.mean() 应该等于多少? 应该 ≈ 1。> 1.5 或 < 0.5 表示 policy 漂太远,该早停或减小 lr

9. GRPO — PPO 砍掉 Value 网络

Key Insight: 同 prompt 的 G 个 rollout 组内 z-score 就是个很好的 baseline,不再需要 value 网络。 省一个模型 = 省 25% 显存,适合可验证 reward(数学/代码)。

实现层次

══════════ PPO vs GRPO 差在哪 ══════════

PPO:    policy + ref + reward + VALUE  (4 个模型, GAE 算 A)
GRPO:   policy + ref + reward          (3 个模型, 组内归一化算 A)
                              ↑
                       省掉的就是 value


══════════ GRPO 一步更新 ══════════

  prompt
     │
     ▼ policy.sample × G             一个 prompt 采 G 个 rollout
  rollouts [B, G]
     │
     ├──► RM 打分 ─────► rewards [B, G]
     │                       │
     │                       ▼  组内 z-score (不需要 value!)
     │             advantage = (r - r.mean(-1)) / (r.std(-1) + ε)
     │                       │
     │                       ▼
     │             advantage [B, G]
     │
     ├──► policy 重新前向 ─► logp_policy [B, G]
     │
     └──► ref (本轮 deepcopy) ─► logp_ref [B, G]
                                     │
                  ratio = exp(logp_policy - logp_ref)
                                     │
                  ┌──────────────────┴──────────────────┐
                  ▼                                     ▼
            ratio · A                       clip(ratio, 1±ε) · A
                  └──────────── min() ──────────────────┘
                                ▼
                          −mean = GRPO loss
                                ▼
                          backward → 更新 policy

公式

组内 z-score advantage + PPO 风格 clipped surrogate:

\[A_i = \frac{r_i - \mathrm{mean}(r)}{\mathrm{std}(r) + \epsilon}, \qquad \mathrm{ratio}_i = \exp\bigl(\log \pi_\theta(\mathrm{rollout}_i) - \log \pi_{\mathrm{ref}}(\mathrm{rollout}_i)\bigr)\] \[\mathcal{L}_{\mathrm{GRPO}} = -\mathbb{E}\Bigl[ \min\bigl(\mathrm{ratio}_i \cdot A_i,\ \mathrm{clip}(\mathrm{ratio}_i,\,1-\varepsilon,\,1+\varepsilon) \cdot A_i\bigr) \Bigr]\]

\(\pi_{\mathrm{ref}}\) 不再是 SFT 副本,而是这一轮 RL step 开始时 deepcopy 的快照(每步刷新)。

代码骨架

def grpo_loss(logp_policy, logp_ref, rewards, epsilon=0.2):
    advantage = (rewards - rewards.mean(-1, keepdim=True)) / \
                (rewards.std(-1, keepdim=True) + 1e-8)
    ratio = torch.exp(logp_policy - logp_ref.detach())
    unclipped = ratio * advantage
    clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantage
    return -torch.minimum(unclipped, clipped).mean()

GRPO vs PPO 对比表

维度 PPO GRPO
Value 网络 必须 砍掉
Advantage GAE(rewards, V) 组内归一化 reward
Ref policy SFT 副本(冻结) 每步 deepcopy 的快照
显存 4 模型 3 模型(policy + ref + RM/verifier)
适合任务 通用 RLHF 有可验证 reward(数学 / 代码)
代表 InstructGPT, ChatGPT 早期 DeepSeek-R1, DeepSeek-Math

GRPO vs RLOO vs Vanilla PG

算法 Baseline Off-policy 修正 代表
REINFORCE 教科书
Vanilla PG with mean 全 batch 均值 早期
RLOO (REINFORCE Leave-One-Out) 同 prompt 其他 G-1 个 rollout 的均值 无 (on-policy) Cohere 用
GRPO 组内 z-score (mean + std 归一化) ratio + clip DeepSeek
PPO with value V(s) ratio + clip InstructGPT

GRPO 比 RLOO 多了 std 归一化(让 advantage 尺度稳定)+ PPO 的 clip(允许 multi-epoch)。

DeepSeek-R1 的具体配置

DeepSeek-R1 论文 (2025) 用 GRPO 训了一个仅靠 RL 就能 reasoning 的模型:

配置:
  Group size G = 64        ← 比 InstructGPT 大很多
  ε_clip      = 0.2
  β_KL        = 0          ← 不加 KL!直接让 policy 探索
  Reward      = 0/1 数学正确性 (verifier, 不用 RM)
  Ref         = 每个 RL step 之前 deepcopy 的快照

为什么能 work?
  (1) 数学/代码任务有 verifiable reward,不会 reward hacking → 不需要 KL
  (2) 大 G 让组内 z-score 估计稳定
  (3) 长 CoT 自动涌现:奖励正确就行,policy 学会"先想后答"

观察: SFT 不需要长 CoT 数据,RL 自己会涌现”先思考再答案”的链。

长度归一化问题 (length normalization)

GRPO 序列级 logp 会偏向短回答(每多一个 token 多一个 logp 项,绝对值变大):

\[\log \pi_{\mathrm{policy}}(\mathrm{seq}_{\mathrm{short}}) = \log p(t_1) + \log p(t_2) \quad\text{(2 项)}\] \[\log \pi_{\mathrm{policy}}(\mathrm{seq}_{\mathrm{long}}) = \log p(t_1) + \cdots + \log p(t_{50}) \quad\text{(50 项,绝对值偏大)}\]

→ ratio 受序列长度影响。修法:

# Token-mean normalization (DeepSeek-Math 用)
logp_seq = neg_logp.sum(-1) / mask.sum(-1)        # 除以序列长度

# Or 算 per-token 的 ratio 和 advantage,然后逐 token 加权
ratio_t = exp(logp_new_t - logp_old_t)            # 每个 token 独立
loss_t  = min(ratio_t · A, clip(ratio_t, 1±ε) · A)
loss    = loss_t.sum(-1).mean()                   # 沿 token 加和

DeepSeek-V3/R1 实际用的是 token-level loss + length normalization 的组合。

Verifier vs Reward Model

GRPO 不一定要 RM,可以用可验证的 verifier:

任务 Verifier 优势
数学题 答案正则匹配 / SymPy 化简 无 reward hacking
代码题 跑测试用例,通过率 客观
自然语言 用 RM (LLM-as-judge 或学习的 RM) 灵活但有 hacking 风险

DeepSeek-R1 用了纯 verifier(0/1 准确率),所以才能砍掉 KL 还稳定。

完整 GRPO 训练循环

def grpo_train_step(policy, prompts, verifier, G=16, eps=0.2):
    # === 采样阶段 ===
    ref = copy.deepcopy(policy).eval()          # 本轮 ref 快照
    with torch.no_grad():
        rollouts = [policy.generate(p) for p in prompts for _ in range(G)]
        rewards  = verifier(rollouts)            # [B, G],可验证 reward
        logp_old = compute_logp(ref, rollouts)   # ref = policy snapshot

    # === 组内归一化 ===
    A = (rewards - rewards.mean(-1, keepdim=True)) / \
        (rewards.std(-1, keepdim=True) + 1e-8)

    # === 多 epoch PPO 风格更新 ===
    for _ in range(ppo_epochs):
        logp_new = compute_logp(policy, rollouts)
        ratio    = torch.exp(logp_new - logp_old)
        loss     = -torch.min(ratio * A,
                              torch.clamp(ratio, 1-eps, 1+eps) * A).mean()
        loss.backward(); step()

高频追问

Q A
为什么组内归一化能替代 value baseline? 同 prompt 下 reward 才可比,组内均值是无偏 baseline 估计
Group size G 怎么选? InstructGPT 8;DeepSeek-R1 用 64。大 G 估计准但显存爆;小 G 方差大
没 value 网络怎么处理 token-level? 序列级 A 广播到所有 token (typical);或 token-level loss + length norm
为什么 ref 每步 deepcopy 而不是固定 SFT? GRPO 的 ratio 主要做 off-policy 修正(单批数据多轮更新);不需要长期 KL 约束
GRPO vs RLOO 区别? RLOO 用 leave-one-out 均值;GRPO 多了 std 归一化 + PPO clip 支持多 epoch
DeepSeek-R1 为什么能不加 KL? 数学/代码用 verifier (0/1 正确性),不会 reward hacking
长 CoT 怎么自然涌现? 大 G 探索充足 + 正确答案被加权 → policy 学会先想后答(无监督)
长度偏置怎么处理? Token-mean normalization (logp / 序列长度) 或 per-token loss
GRPO 能 multi-epoch 吗? 能,和 PPO 一样靠 ratio + clip 控住 off-policy 漂移
GRPO 适合什么、不适合什么? 适合: 数学/代码/有 verifier 的任务。不适合: 主观对话 (需要 RM 的偏好)

10. DPO — 把 RM + RL 合成一个分类损失

Key Insight: 「最大化 reward + KL 约束」的最优解 → 反解出 \(r = \beta \log(\pi/\pi_{\mathrm{ref}})\) → 代入 Bradley-Terry → RM 这一步消失,变成一个 sigmoid 分类损失。

实现层次

══════════ 数据 ══════════
   (prompt, y_w (chosen tokens), y_l (rejected tokens))


══════════ 4 次前向 = 2 (policy/ref) × 2 (chosen/rejected) ══════════

                ┌──────────────────── policy(πθ) ────────────────────┐
                │                                                     │
   y_w ────────►│ logits_w ──► seq_logprob ──► π_w  (要梯度)         │
                │                                                     │
   y_l ────────►│ logits_l ──► seq_logprob ──► π_l  (要梯度)         │
                └─────────────────────────────────────────────────────┘

                ┌──────────────────── ref(π_SFT) ─────────────────────┐
                │  (冻结, with torch.no_grad())                       │
   y_w ────────►│ logits  ──► seq_logprob ──► r_w                    │
                │                                                     │
   y_l ────────►│ logits  ──► seq_logprob ──► r_l                    │
                └─────────────────────────────────────────────────────┘

══════════ 损失 ══════════

   margin = (π_w − r_w) − (π_l − r_l)        ← chosen 相对 ref 的 log-ratio
                                                减 rejected 相对 ref 的 log-ratio
   logits = β · margin
   loss   = −log σ(logits) = −F.logsigmoid(logits).sum()
                              │
                              ▼ backward → 只更新 policy

══════════ 模型数 ══════════
   policy (训练中) + ref (冻结) = 2 个    ← 比 PPO 的 4 个少一半

公式

把隐式 reward $$r(x,y) = \beta \log \dfrac{\pi_\theta(y x)}{\pi_{\mathrm{ref}}(y x)}$$ 代入 Bradley-Terry:
\[\mathcal{L}_{\mathrm{DPO}} = -\log \sigma\Bigl(\beta \Bigl[ \bigl(\log \pi_\theta(y_w) - \log \pi_{\mathrm{ref}}(y_w)\bigr) - \bigl(\log \pi_\theta(y_l) - \log \pi_{\mathrm{ref}}(y_l)\bigr) \Bigr]\Bigr)\]

直觉

  • 抬高 chosen 相对 ref 的 log-ratio,压低 rejected 的
  • β 控制偏离 ref 的强度(≈ PPO 里的 KL 系数)
  • 不需要 RM,不需要 RL,纯监督训练

代码骨架

def sequence_logprob(logits, tokens):
    logp = torch.log_softmax(logits, dim=-1)
    # t 时刻 logits 预测 t+1 的 token
    chosen = logp[:, :-1, :].gather(-1, tokens[:, 1:].unsqueeze(-1)).squeeze(-1)
    return chosen.sum(-1)

def dpo_loss(policy_logits_w, policy_logits_l,
             ref_logits_w, ref_logits_l,
             tokens_w, tokens_l, beta=0.2):
    pi_w = sequence_logprob(policy_logits_w, tokens_w)
    pi_l = sequence_logprob(policy_logits_l, tokens_l)
    with torch.no_grad():
        ref_w = sequence_logprob(ref_logits_w, tokens_w)
        ref_l = sequence_logprob(ref_logits_l, tokens_l)
    logits = beta * ((pi_w - ref_w) - (pi_l - ref_l))
    return -F.logsigmoid(logits).sum()

DPO 损失从 RLHF 目标的完整推导(必背)

Step 1: RLHF 的优化目标 = “最大化 reward + KL 不偏离 ref”:

\[\pi^* = \arg\max_\pi\ \mathbb{E}_{x,\,y \sim \pi}\bigl[r(x, y)\bigr] - \beta \cdot \mathrm{KL}\bigl(\pi \,\|\, \pi_{\mathrm{ref}}\bigr)\]

Step 2: 这个约束优化有闭式最优解(Lagrangian 配方法):

\[\pi^*(y \mid x) = \frac{1}{Z(x)}\,\pi_{\mathrm{ref}}(y \mid x)\,\exp\!\left(\frac{r(x, y)}{\beta}\right), \quad Z(x) = \sum_y \pi_{\mathrm{ref}}(y \mid x)\,\exp\!\left(\frac{r(x, y)}{\beta}\right)\]

Step 3: 反解出 reward(从最优 policy 反推 reward):

\[r(x, y) = \beta \log\frac{\pi^*(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)} + \beta \log Z(x)\]

\(\beta \log Z(x)\) 是 prompt 相关的常数项,在 Bradley-Terry 的消掉了

Step 4: 代入 Bradley-Terry \(P(y_w \succ y_l) = \sigma(r(x, y_w) - r(x, y_l))\):

\[\begin{aligned} P(y_w \succ y_l) &= \sigma\!\left(\beta \log\frac{\pi^*(y_w)}{\pi_{\mathrm{ref}}(y_w)} - \beta \log\frac{\pi^*(y_l)}{\pi_{\mathrm{ref}}(y_l)}\right) \\ &= \sigma\!\Bigl(\beta\bigl[(\log\pi^*(y_w) - \log\pi_{\mathrm{ref}}(y_w)) - (\log\pi^*(y_l) - \log\pi_{\mathrm{ref}}(y_l))\bigr]\Bigr) \end{aligned}\]

Step 5: 最大化对数似然 → DPO loss(把 \(\pi^*\) 替换成可训练的 \(\pi_\theta\)):

\[\boxed{\;\mathcal{L}_{\mathrm{DPO}} = -\log\sigma\Bigl(\beta\bigl[ (\log\pi_\theta(y_w) - \log\pi_{\mathrm{ref}}(y_w)) - (\log\pi_\theta(y_l) - \log\pi_{\mathrm{ref}}(y_l)) \bigr]\Bigr)\;}\]

关键观察: 这个推导完全等价于 RM + PPO,但消去了 RM 这一步。 代价: PPO 是 on-policy 自己采样,DPO 是 off-policy 用固定偏好数据。

DPO 梯度的直觉(看公式知道在干啥)

对 DPO loss 求 \(\theta\) 梯度,令 \(h = \beta \cdot \mathrm{margin}\):

\[\nabla_\theta \mathcal{L}_{\mathrm{DPO}} = -\beta \cdot \underbrace{\sigma(-h)}_{\text{置信度调节子}} \cdot \nabla_\theta\bigl[\log\pi_\theta(y_w) - \log\pi_\theta(y_l)\bigr]\]

\(\sigma(-h)\) 是个置信度调节子: 模型已经把 chosen 拉高很多 → \(h\) 大 → \(\sigma(-h)\) 很小 → 梯度自动减弱。 自动 hard mining: 关注那些当前还分不清的 pair。

长度偏置 (DPO 的著名痛点)

\[\log \pi(\mathrm{seq}) = \sum_{t=1}^{|\mathrm{seq}|} \log \pi(t_t \mid t_{<t}) \quad\text{(项数 = 序列长度)}\]

如果 chosen 平均比 rejected 长 50 tokens,\(\log\pi(\mathrm{chosen}) - \log\pi(\mathrm{rejected})\) 会有 ~50 个额外 logp 项贡献,模型学到”输出长 = 高 reward”的虚假关联

实测: 普通 DPO 训完后,输出长度比 SFT 长 2-3 倍,但人类胜率提升有限。

修法:

  • SimPO: 用 logπ / |seq| 归一化,去掉 ref 模型
  • Length-normalized DPO: 加 length penalty 项
  • R-DPO: 在 reward 里减 length term

DPO 的几个进化版本

算法 改了什么 解决了什么
DPO (Rafailov 2023) RM + PPO → 直接分类 loss 简化 RLHF
IPO (Azar 2023) sigmoid → identity 防止过拟合 / 偏好概率饱和
KTO (Ethayarajh 2024) pairwise → 单样本 like/dislike 不需要成对数据,只要 thumbs up/down
SimPO (Meng 2024) 去掉 ref 模型 + length normalize 省显存 + 解长度偏置
RPO (Pal 2024) DPO + 显式 SFT loss 项 防 logp 整体下降的坍塌
R-DPO 加 length penalty 到 reward 解长度偏置
β-DPO 动态调 β 不同样本难度自适应

DPO 训练中的”logp 双降”现象 (常见 bug)

实际训完 DPO 后画图,常发现:

logπ_θ(chosen):    -25   (训前)  →   -40   (训后)    ← 居然降了!
logπ_θ(rejected):  -27   (训前)  →   -55   (训后)    ← 降得更多
margin:              2                  15            ← margin 增大了

margin 在涨,所以 DPO loss 在降,但 chosen 的概率也降了!原因: DPO 只关心 margin,不关心绝对概率。

修法 (RPO): 加一个 SFT loss 项强制 \(\log \pi_\theta(\mathrm{chosen})\) 不要下降:

\[\mathcal{L}_{\mathrm{RPO}} = \mathcal{L}_{\mathrm{DPO}} + \lambda \cdot \bigl(-\log \pi_\theta(y_w)\bigr) \qquad\text{(第二项就是普通 SFT loss)}\]

DPO vs PPO 实操对比

维度 PPO DPO
训练阶段 SFT → RM → PPO SFT → DPO(一步到位)
Reward model 显式训 隐式
RL 采样 on-policy 采 用离线偏好数据
显存 4 模型 (policy + ref + reward + value) 2 模型 (policy + ref)
实现复杂度 高 (4 模型 + GAE + clip) 低 (一个 sigmoid loss)
训练稳定性 难调 稳定(分类损失)
数据分布敏感性 低 (自己采样) 高 (OOD 偏好学不好)
工业线落地 Anthropic, OpenAI 早期 Llama 3, Mistral, Zephyr
性能上限 略高 接近但有差距

完整 DPO 训练代码

def dpo_train_step(policy, ref, batch, beta=0.1):
    # 4 次前向: 2 (policy/ref) × 2 (chosen/rejected)
    logp_pol_w = sequence_logprob(policy(batch['chosen']), batch['chosen'])
    logp_pol_l = sequence_logprob(policy(batch['rejected']), batch['rejected'])

    with torch.no_grad():
        logp_ref_w = sequence_logprob(ref(batch['chosen']), batch['chosen'])
        logp_ref_l = sequence_logprob(ref(batch['rejected']), batch['rejected'])

    # margin: chosen 的 log-ratio 减 rejected 的 log-ratio
    chosen_logratio   = logp_pol_w - logp_ref_w
    rejected_logratio = logp_pol_l - logp_ref_l
    margin            = chosen_logratio - rejected_logratio

    loss = -F.logsigmoid(beta * margin).mean()

    # 监控指标
    chosen_reward   = beta * chosen_logratio.detach()
    rejected_reward = beta * rejected_logratio.detach()
    reward_gap      = (chosen_reward - rejected_reward).mean()    # 应该上涨
    accuracy        = (chosen_reward > rejected_reward).float().mean()  # 训练 accuracy

    return loss, {'reward_gap': reward_gap, 'acc': accuracy}

高频追问

Q A
DPO 为什么不需要 RM? 反解 RLHF 最优解: r = β·log(π/π_ref) + const,代入 BT,const 抵消,直接得分类 loss
β 怎么选? 0.1-0.5 常见;太大 → 几乎不动 ref;太小 → 偏离 ref 太远易崩
为什么还要 ref? 限制 policy 不要离 SFT 太远;ref 同时给 chosen / rejected 提供归一化基准
DPO 的缺点? (a) Off-policy → OOD 偏好学不好 (b) 长度偏置 (c) chosen logp 可能整体下降
怎么改进 DPO? IPO (防过拟合), KTO (单点偏好), SimPO (去 ref + length norm), RPO (加 SFT)
DPO 训练中 chosen logp 应该上升吗? 不一定!DPO 只优化 margin,实测 chosen 和 rejected 经常一起下降 (用 RPO 修)
DPO 能 multi-epoch 吗? 能,但易过拟合;经验 1-3 epoch 最佳。IPO 提出就是为了缓解
为什么 SimPO 去掉 ref 还 work? logπ/|seq| 自带 length norm + ref 的归一化作用被 length scale 替代
DPO 数据从哪来? 同 prompt 不同 temperature 采样的两个回复 → 人标 chosen/rejected;或 GPT-4 当 judge
DPO 训练时 ref 要不要也 grad? 绝对不要!ref 必须 with torch.no_grad(),否则梯度全乱

11. 一图全景: 四种对齐算法对比

              数据形式                    需要 Reward Model?           是否需要在线采样?
──────────────────────────────────────────────────────────────────────────────────
SFT       (prompt, response)              ─                            ─
RM        (prompt, chosen, rejected)      训出来 ←                     ─
PPO       (prompt) + RM 打分              是 (独立训)                  是 (on-policy)
GRPO      (prompt) + RM 或 verifier       是 或 否 (可验证 reward)      是 (组内采样)
DPO       (prompt, chosen, rejected)      ─                            ─ (纯监督)
                            ┌─ value baseline ─► PPO  (4 模型)
            策略梯度 + RM ──┤
                            └─ 组内归一化 ────► GRPO (3 模型, 无 value)

            Bradley-Terry + 最优策略解析解 ───► DPO  (2 模型, 无 RM 无 RL)
显存:  PPO  ≫  GRPO  >  DPO  ≫  SFT
难度:  PPO  ≫  GRPO  ≈  DPO  >  SFT
质量:  PPO  >  GRPO  ≳  DPO    (DPO 在 OOD 数据上掉得快)

附: 复习自检清单

  • MinHash: 能口述”为什么 min h(A) = min h(B) 的概率 = Jaccard”
  • BPE: 能在白纸上跑完一个 3 词例子的前 3 轮合并
  • BPE vs WordPiece vs Unigram: 三个准则、三种方向、三种 OOV 处理
  • Byte-level BPE: GPT-2 的 Ġ 是什么,为什么用字节
  • Softmax 反向: 写出 s ⊙ (g - (g·s)),解释为什么
  • Attention 反向: dq/dk/dv 三个公式 + √D 的位置
  • Memory-Efficient Attention: 分块 num/den 累加;online softmax 的 running max 怎么维护;FlashAttention 为什么快(HBM,不是 FLOPs)
  • RM loss: -log σ(r_w - r_l),从 Bradley-Terry 推
  • GAE: 写出递推式 + γ/λ 极值含义
  • PPO: clipped surrogate + KL,4 个模型
  • GRPO: 比 PPO 少了什么(value),advantage 怎么算
  • DPO: 写出损失,解释为什么不需要 RM
  • 四者对比: 数据形式 / RM / on-policy / 显存 4 个维度