Approximating KL Divergence
主要介绍使用MC方法来近似KL散度的技巧。
参考link
KL公式:
KL[q,p]=x∑q(x)logp(x)q(x)=Ex∼qlogp(x)q(x)
前置假设:
- 我们知道概率密度计算,但是没法做遍历x做求和或者积分。
- 已知x1,x2,...∼q,即从真实分布中采样的样本。
- 一般在机器里,我们的模型可以表示p的函数。
K1
一个straightforward的做法是直接使用k1=logp(x)q(x)=−logr,这里定义r=q(x)p(x),我们可以用MC方法抽样算k1来近似。
- 但是有high variance, 它甚至可能是负的。
- 无偏的
K2
21(logp(x)q(x))2=21(logr)2
- 都是正的
- 有低的variance
- 低bias
Estimator k2=21(logr)2 的期望实际上是一个 f-divergence。f-divergence 是一类衡量概率分布差异的函数族,定义为:
Df(p,q)=Ex∼q[f(q(x)p(x))]
其中 f 是一个凸函数。KL 散度、χ² 散度、Total Variation 等常见的概率距离都属于 f-divergence。在分布 p≈q 的情形下,所有可微的 f-divergence 在数学上都近似于一个统一的二次形式:
Df(p0,pθ)=2f′′(1)θTFθ+O(∥θ∥3)
其中 F 是 Fisher 信息矩阵。这意味着不同的 f-divergence 在局部行为上高度相似。以 k2 为例,它对应的 f(x)=21(logx)2,而 KL[q‖p] 对应的是 f(x)=−logx,它们在 x=1 处的二阶导数都是 1,因此在 p≈q 时提供几乎相同的距离估计。因此,k2 不仅具有明确的理论解释,而且在实际中是一个偏差小、方差低的 KL 散度近似器。
所以在两个分布p,q接近时可以用k2来近似KL散度,增加了稳定性。
K3
k3=(r−1)−logr
我们希望构造一个对 KL 散度(例如 KL[q∥p])的估计器,既无偏又具有较低方差。一个常见的技巧是使用 control variate:即在原始估计量上加一个期望为零、但与其负相关的项,以减少方差。
KL[q‖p] 的标准估计器是:
k1=−logr,其中 r=q(x)p(x)
为了降低其方差,可以加上 λ(r−1),因为:
Ex∼q[r−1]=Eq[q(x)p(x)−1]=1−1=0
于是构造出一类无偏估计器:
−logr+λ(r−1)
通过最小化方差可以解出最优的 λ,但这个解依赖于 p(x) 和 q(x),通常难以解析求得。
为了解决这个问题,可以采用一个更简单又合理的选择:λ=1。由于 logx≤x−1(因为对数函数是 concave 的),所以这个选择保证估计器始终为正:
k3=(r−1)−logr
这个形式正好就是前面提到的 KL[q‖p] 的 Bregman 形式,几何意义上表示 r=q(x)p(x) 下,logx 与其在 x=1 处切线之间的垂直距离。它不仅是一个无偏估计器,而且更稳定、易计算,是强化学习等场景中常用的 KL 近似方式之一。
从数学角度推广上述感性想法
我们可以推广 Bregman 散度的思想,构造出对任意 f-divergence 的始终为正的估计器。给定凸函数 f(x),f-divergence 定义为:
Df(p,q)=Ex∼q[f(q(x)p(x))]
由于凸函数始终位于其切线之上,我们可以用以下表达式作为 f-divergence 的估计器:(r-1在期望下是0所以是无偏的)
f(r)−f′(1)(r−1),其中 r=q(x)p(x)
这个估计器永远非负,其几何含义是:f(r) 与它在 r=1 处的切线之间的垂直距离。
具体到 KL 散度:
-
对于 KL[p∥q],对应的 f(x)=xlogx,有 f′(1)=1,估计器为:
rlogr−(r−1)
-
对于 KL[q∥p],对应的 f(x)=−logx,有 f′(1)=−1,估计器为:
(r−1)−logr
这两个表达式不仅无偏,且始终为正,是在机器学习和强化学习中非常实用的 KL 散度估计方法。