Stein Variational Gradient Descent

A new way to learn


Introduction

Gradient descent has become the basic algorithm for training almost all deep learning models. Stein variational gradient descent was proposed as a "natural counterpart of gradient descent for optimization."

In our previous blog post Differentiable Bayesian Structure Learning, we briefly mentioned that the core engine of that Bayesian algorithm is Stein variational gradient descent. In this article, we expand on the topic and lay out the motivation, the underlying logic, and the mathematical derivation of this novel optimization method. The original paper is listed in the Reference section; here we derive the algorithm in a more intuitive way.

Notations

The usual inner product of two functions $f$ and $g$ is defined as

$$\langle f, g \rangle = \int dx\, f(x) g(x).$$

Let $f$ and $g$ be functions in a reproducing kernel Hilbert space (RKHS) $\mathcal H$; their inner product is denoted $\langle f, g \rangle_{\mathcal H}$.

Let $\mathbf f = \{f_\mu\}_{\mu=1}^d$ and $\mathbf g = \{g_\nu\}_{\nu=1}^d$ be elements of the vector-valued RKHS $\mathcal H^d$; their inner product is

$$\langle \mathbf f, \mathbf g \rangle_{\mathcal H^d} = \delta^{\mu\nu} \langle f_\mu, g_\nu \rangle_{\mathcal H} = \sum_{\mu=1}^d \langle f_\mu, g_\mu \rangle_{\mathcal H},$$

with summation over repeated indices implied.

Stein identity and discrepancy

Observe that, using Stokes' theorem,

$$\int_M \left( f\,\mathrm dp + p\,\mathrm df \right) = \int_M \mathrm d(pf) = \int_{\partial M} pf.$$

If $pf \rightarrow 0$ at the boundary $\partial M$ and both $p$ and $f$ are smooth functions, then we have

$$\int_M \left( f\,\mathrm dp + p\,\mathrm df \right) = 0.$$

If $p$ is a probability density over $M$, then we have the Stein identity

$$\int_M \left( fp\,\mathrm d\log p + p\,\mathrm df \right) = \mathbb E_p\left[ f\,\mathrm d\log p + \mathrm df \right] = 0$$

for any test function $f$ that satisfies these requirements. Now we replace the sampling distribution with $q$,

$$\mathbb E_p\left[ f\,\mathrm d\log p + \mathrm df \right] \rightarrow \mathbb E_q\left[ f\,\mathrm d\log p + \mathrm df \right] \equiv S_{q,p}f,$$

and we get the Stein discrepancy, which vanishes when $q$ equals $p$. Thus we obtain a measure of "distance" between $q$ and $p$, given a properly chosen test function $f$.
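As a quick numerical sanity check (our addition, not part of the original post), the sketch below draws samples from a standard normal $p$ and verifies that $\mathbb E_p\left[ f\,\mathrm d\log p + \mathrm df \right]$ vanishes up to Monte Carlo error for the arbitrary test function $f(x) = \sin x$, while replacing the $p$-samples with samples from a shifted $q$ yields a nonzero Stein discrepancy.

```python
import numpy as np

rng = np.random.default_rng(0)

# Target p: standard normal, so d/dx log p(x) = -x.
x = rng.standard_normal(100_000)

# A smooth test function with p*f -> 0 at infinity, e.g. f(x) = sin(x).
f, df = np.sin(x), np.cos(x)
dlogp = -x

# Stein identity: E_p[f dlogp + df] should vanish up to Monte Carlo error.
print(f"E_p[f dlogp + df] ~ {np.mean(f * dlogp + df):.4f}")   # close to 0

# Replacing the p-samples by q-samples (q = N(1, 1)) gives the Stein discrepancy,
# which is nonzero because q differs from p.
xq = rng.standard_normal(100_000) + 1.0
print(f"S_qp f ~ {np.mean(np.sin(xq) * (-xq) + np.cos(xq)):.4f}")  # clearly nonzero
```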

Variational inference

The goal of variational inference is to approximate a target distribution $p$ with a tractable ansatz distribution $q$ by minimizing the KL divergence

$$D_{\rm KL}(q||p) = \mathbb E_q \log\frac{q}{p} = \left(-\mathbb E_q \log p\right) - \left(-\mathbb E_q \log q\right),$$

which has the form of a free energy $F = U - TS$ with the temperature set to unity. Minimizing the KL divergence is therefore equivalent to striking a balance between minimizing the $q$-average of the energy ($-\log p$) and maximizing the entropy of $q$.

The KL divergence is non-negative and attains its minimum of zero when $q = p$. We wish to solve $\min_q D_{\rm KL}(q||p)$ subject to the constraint that $q$ is a probability distribution, $\int dx\, q(x) = \langle q, 1 \rangle = 1$. The total objective is therefore

$$\mathcal L[q] = \langle q, \log q - \log p \rangle - \lambda \langle q, 1 \rangle,$$

where $\lambda$ is a Lagrange multiplier. Taking the functional derivative and setting it to zero,

$$\frac{\delta \mathcal L}{\delta q} = \log q - \log p + 1 - \lambda = 0,$$

we get

$$q = p \exp(\lambda - 1).$$

Since $q$ must integrate to one and $p$ is already normalized, $\lambda = 1$, so $q = p$ is the solution of the optimization problem. However, a real-world target $p$ may be arbitrarily complicated, so exact equality is generally unattainable; the best we can do is the closest approximation allowed by the functional form of the ansatz $q$.

The ansatz distribution can be constructed by hand using knowledge of the target distribution, e.g. a mean-field approximation built from exponential-family distributions. In this article we discuss a non-parametric approach that represents $q$ by a set of particles.
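To make the free-energy reading concrete, here is a small sketch (ours; the particular Gaussians are arbitrary) that estimates $D_{\rm KL}(q||p)$ by Monte Carlo, splits it into the energy term $-\mathbb E_q \log p$ and the negative-entropy term $\mathbb E_q \log q$, and compares the sum with the closed-form KL divergence between Gaussians.

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)

# Ansatz q = N(0, 1), target p = N(2, 1.5).
q, p = norm(loc=0.0, scale=1.0), norm(loc=2.0, scale=1.5)
x = q.rvs(size=200_000, random_state=rng)

energy = -np.mean(p.logpdf(x))        # U  = -E_q[log p]
neg_entropy = np.mean(q.logpdf(x))    # -S = E_q[log q]
kl_mc = energy + neg_entropy          # F  = U - T*S with T = 1

# Closed form for two Gaussians, for comparison.
kl_exact = np.log(1.5 / 1.0) + (1.0**2 + (0.0 - 2.0)**2) / (2 * 1.5**2) - 0.5
print(f"Monte Carlo KL ~ {kl_mc:.4f}, exact KL = {kl_exact:.4f}")
```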

Coordinate flow

Since both $q$ and $p$ are smooth probability distributions, we can adiabatically deform $q$ into $p$ by shifting the coordinates,

$$x^\mu \mapsto x^\mu + v^\mu(x)\,\delta t,$$

where for simplicity we assume $x \in \mathbb R^d$. The task of finding such a transformation is equivalent to searching for a suitable velocity field $v(x)$.

The total mass of $q$ is conserved, so we have a conserved current of $q$-charge,

$$j_q^\mu(x) = q(x)\,v^\mu(x),$$

and the continuity equation

$$\dot q(t) = -\partial_\mu j_q^\mu = -v^\mu \partial_\mu q - q\,\partial_\mu v^\mu.$$

We have the following equivalent optimization problems,

$$\min_v D_{\rm KL}(q_{t+\delta t}||p_t) \;\Leftrightarrow\; \min_v D_{\rm KL}(q_t||p_{t-\delta t}),$$

which is the equivalence of the active and passive perspectives on a coordinate transformation. In the passive picture, $p$ is carried by the velocity field in the reverse direction,

$$\dot p(t) = -\partial_\mu j_p^\mu = v^\mu \partial_\mu p + p\,\partial_\mu v^\mu,$$

where $j_p^\mu = -p\,v^\mu$ is the conserved current of $p$-charge (note the reversed velocity).
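As an illustration of the coordinate flow (our own sketch, not part of the original derivation), we can transport samples of $q$ by $x \mapsto x + v(x)\,\delta t$ and check that the transported density matches the continuity-equation prediction $q - \delta t\,\partial_x(qv)$. Here $q = \mathcal N(0, 1)$ and the velocity field $v(x) = -x$ are arbitrary choices.

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)
dt = 0.05

# Initial q = N(0, 1) and velocity field v(x) = -x.
x = rng.standard_normal(1_000_000)
x_new = x + (-x) * dt                         # active transport of the particles

# Passive prediction from the continuity equation: q_new ~ q - dt * d/dx (q v).
grid = np.linspace(-3, 3, 61)
q0 = norm.pdf(grid)
q_pred = q0 - dt * np.gradient(q0 * (-grid), grid)

# Empirical density of the transported particles, interpolated onto the same grid.
hist, edges = np.histogram(x_new, bins=200, range=(-4, 4), density=True)
centers = 0.5 * (edges[:-1] + edges[1:])
q_emp = np.interp(grid, centers, hist)

print(np.round(q_pred[::10], 3))
print(np.round(q_emp[::10], 3))               # the two rows should nearly agree
```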

We have the following expansions, where the first-order term carries a plus sign because $\dot p$ above is the rate of change along the reversed flow, which is exactly what transports $p$ from $t$ to the label $t-\delta t$:

$$p(x, t-\delta t) = p(x, t) + \dot p(x, t)\,\delta t + \mathcal O(\delta t^2),$$

$$\log p(x, t-\delta t) = \log p(x, t) + \frac{d}{dt}\log p(x, t)\,\delta t + \mathcal O(\delta t^2).$$

The increment of the objective

$$\mathcal L[v] = D_{\rm KL}(q_t||p_{t-\delta t}) = \int dx\, q_t(x) \log\frac{q_t(x)}{p_{t-\delta t}(x)},$$

to first order in $\delta t$, is

$$\delta\mathcal L[v] = -\delta t \int dx\, q_t(x)\,\frac{\dot p_t(x)}{p_t(x)}.$$

Replacing $\dot p$ using the continuity equation above, we get

$$\frac{\delta\mathcal L[v]}{\delta t} = -\int dx\, q(x)\left( v^\mu(x)\,\partial_\mu \log p(x) + \partial_\mu v^\mu(x) \right) = -\mathbb E_q\left[ v^\mu \partial_\mu \log p + \partial_\mu v^\mu \right].$$

In other words, the gradient of $D_{\rm KL}(q_t||p_{t-\delta t})$ with respect to the flow is minus the Stein discrepancy of $(q, p)$ with test function $v$, i.e. $-S_{q,p}v$.
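It may help to see why this gradient vanishes exactly at the optimum. The following short manipulation is our addition for clarity (it is not spelled out in the original post): integrating the divergence term by parts, with $qv$ vanishing at the boundary,

$$S_{q,p}v = \mathbb E_q\left[ v^\mu \partial_\mu \log p + \partial_\mu v^\mu \right] = \mathbb E_q\left[ v^\mu \partial_\mu \log p \right] - \mathbb E_q\left[ v^\mu \partial_\mu \log q \right] = \mathbb E_q\left[ v^\mu \left( \partial_\mu \log p - \partial_\mu \log q \right) \right],$$

where we used $\int dx\, q\,\partial_\mu v^\mu = -\int dx\, v^\mu \partial_\mu q = -\mathbb E_q\left[ v^\mu \partial_\mu \log q \right]$. The Stein discrepancy therefore compares the score functions of $p$ and $q$ along $v$, and the KL gradient $-S_{q,p}v$ is zero when the two scores agree, i.e. when $q = p$.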

Method of steepest descent

We now have the gradient of our variational-inference objective, but we want the steepest descent, which means searching for the best velocity field.

We further assume the velocity field is an element of the $d$-dimensional RKHS $\mathcal H^d$. Then we have the reproducing property

$$v = \langle K, v \rangle_{\mathcal H^d} = \langle v, K \rangle_{\mathcal H^d},$$

where $K(\cdot, \cdot)$ is the kernel function of $\mathcal H^d$, e.g. a Gaussian RBF kernel. Furthermore, the linear operator $S_{q,p}$ can be shifted onto $K$,

$$S_{q,p}v = \langle v, S_{q,p}K \rangle_{\mathcal H^d},$$

which is the "dot product" of $v$ and $S_{q,p}K$ in the RKHS. The solution of the optimization

$$\max_{v,\;\Vert v \Vert_{\mathcal H^d} \le 1} S_{q,p}v = \langle v, S_{q,p}K \rangle_{\mathcal H^d}$$

is simply $v^* = S_{q,p}K / \Vert S_{q,p}K \Vert_{\mathcal H^d}$. Also note that this velocity vanishes when $p = q$.
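In practice $S_{q,p}K$ is evaluated at the particle positions. The sketch below is our own illustration, not code from the paper: the function names and the fixed kernel bandwidth `h` are assumptions (the original paper uses a median-distance heuristic for the bandwidth). It computes the unnormalized optimal direction $S_{q,p}K(\cdot, x) = \mathbb E_{x' \sim q}\left[ K(x', x)\,\partial_{x'}\log p(x') + \partial_{x'}K(x', x) \right]$ with a Gaussian RBF kernel.

```python
import numpy as np

def rbf_kernel(x, y, h=1.0):
    """Gaussian RBF kernel k(x, y) and its gradient with respect to x."""
    diff = x[:, None, :] - y[None, :, :]                   # (m, n, d)
    k = np.exp(-np.sum(diff**2, axis=-1) / (2 * h**2))     # (m, n)
    grad_x = -diff / h**2 * k[:, :, None]                  # (m, n, d), d k / d x
    return k, grad_x

def svgd_direction(particles, grad_log_p, h=1.0):
    """Particle estimate of S_{q,p}K evaluated at each particle position."""
    k, grad_k = rbf_kernel(particles, particles, h)        # kernel between particles
    scores = grad_log_p(particles)                         # (m, d), rows are d log p / dx
    # E_{x'~q}[ k(x', x) dlogp(x') + d/dx' k(x', x) ], averaged over the particles x'.
    return (k.T @ scores + grad_k.sum(axis=0)) / len(particles)
```

For instance, `svgd_direction(x, lambda x: -x)` gives the flow that drives particles toward a standard normal target, since $\partial_\mu \log p = -x_\mu$ in that case.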

The Stein variational inference algorithm

Using the method of steepest descent, we have obtained the optimal flow field $v^*$. Next, we just need to go with the flow,

$$x \mapsto x + v^*\,\delta t = x + \delta t'\, S_{q,p}K,$$

and incrementally update $q(x)$; the inverse of the norm of $v^*$ is absorbed into the learning rate $\delta t'$.

The algorithm is as follows:

  • Sample $m$ particles from the initial $q$
  • Approximate $S_{q,p}K = \mathbb E_q\left[ K^\mu \partial_\mu \log p + \partial_\mu K^\mu \right]$ using the sample mean over the $m$ particles
  • Update the coordinates of the particles using the calculated flow
  • Repeat the process until convergence
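Putting the steps together, here is a minimal, self-contained sketch of the loop; it is our illustration rather than the authors' reference code, and the mixture target, bandwidth, step size, and iteration count are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)

def grad_log_p(x):
    """Score of the 1D mixture target p = 0.5 N(-2, 1) + 0.5 N(2, 1)."""
    log_c1 = -0.5 * (x + 2.0) ** 2
    log_c2 = -0.5 * (x - 2.0) ** 2
    w1 = 1.0 / (1.0 + np.exp(log_c2 - log_c1))   # responsibility of the first component
    return w1 * (-(x + 2.0)) + (1.0 - w1) * (-(x - 2.0))

def svgd_step(x, step=0.1, h=0.5):
    """One update x <- x + step * S_{q,p}K, estimated from the particles themselves."""
    diff = x[:, None] - x[None, :]               # (m, m)
    k = np.exp(-diff**2 / (2 * h**2))            # RBF kernel k(x_i, x_j)
    grad_k = -diff / h**2 * k                    # d k(x_i, x_j) / d x_i
    phi = (k.T @ grad_log_p(x) + grad_k.sum(axis=0)) / len(x)
    return x + step * phi

# Sample m particles from a broad initial q and go with the flow.
x = rng.normal(loc=0.0, scale=3.0, size=200)
for _ in range(500):
    x = svgd_step(x)

print(f"particle mean ~ {x.mean():.2f}, std ~ {x.std():.2f}")
# The particles should split into two clusters near -2 and +2, matching the target;
# the kernel-gradient (repulsive) term keeps them spread out instead of collapsing
# onto the modes.
```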

Reference

  • [1608.04471] Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm
  • Reproducing kernel Hilbert space
  • Kullback–Leibler divergence