Stein Variational Gradient Descent

A new way to learn

Introduction

Gradient descent has become the workhorse algorithm for training almost all deep learning models. Stein variational gradient descent (SVGD) was proposed as a "natural counterpart of gradient descent for optimization."

In our previous blog post Differentiable Bayesian Structure Learning, we briefly mentioned that the core engine of the Bayesian algorithm was Stein variational gradient descent. In this article, we expand on this topic and articulate the motivation, the fundamental logic, and the mathematical derivation of this optimization method. The original paper is listed in the Reference section; here we derive the algorithm in a more intuitive way.

Notations

The usual inner product of two functions $f$ and $g$ is defined as

$$\langle f, g \rangle = \int dx \; f(x)\, g(x).$$

Let $f$ and $g$ be functions in a reproducing kernel Hilbert space (RKHS) $\mathcal H$; their inner product is denoted $\langle f, g \rangle_{\mathcal H}$.

Let $\{f_\mu\}_{\mu=1}^d$ and $\{g_\nu\}_{\nu=1}^d$ be the components of functions $\mathbf f, \mathbf g$ in the product RKHS $\mathcal H^d$; then the inner product is

$$\langle \mathbf f, \mathbf g \rangle_{\mathcal H^d} = \delta^{\mu\nu}\langle f_\mu, g_\nu \rangle_{\mathcal H},$$

with summation over repeated indices implied.

Stein identity and discrepancy

Observe that, by Stokes' theorem,

$$\int_{M} (f\,\mathrm dp + p\,\mathrm df) = \int_{M} \mathrm d (pf) = \int_{\partial M} pf.$$

If $pf \rightarrow 0$ at the boundary $\partial M$ and both $p$ and $f$ are smooth functions, then we have

$$\int_M (f\,\mathrm d p + p\, \mathrm df) = 0.$$

If $p$ is a probability density over $M$, then we have the Stein identity

$$\int_M (f p\; \mathrm d\log p + p\, \mathrm d f) = \mathbb E_p \left[ f\, \mathrm d \log p + \mathrm d f \right] = 0$$

for any test function $f$ that satisfies these requirements. Now replace the sampling distribution with $q$:

$$\mathbb E_p [f\, \mathrm d\log p + \mathrm d f] \rightarrow \mathbb E_q [f\, \mathrm d\log p + \mathrm d f] \equiv S_{q,p}f.$$

This quantity vanishes when $q$ is $p$; with a properly chosen test function $f$, it therefore serves as a measure of the "distance" between $q$ and $p$, called the Stein discrepancy.
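To make this concrete, here is a small Monte Carlo sanity check (our own sketch, not from the original paper): take $p$ to be the standard normal, so $\mathrm d\log p = -x\,\mathrm dx$, and the test function $f(x) = x^2$, for which $pf \to 0$ at infinity. Sampling from $p$ recovers the Stein identity, while sampling from a shifted $q$ yields a nonzero $S_{q,p}f$ (analytically $-2$ for this choice).

```python
import numpy as np

rng = np.random.default_rng(0)

# Target p = N(0, 1), so d/dx log p(x) = -x.
def dlogp(x):
    return -x

# Test function f(x) = x^2 and its derivative; p*f -> 0 at infinity.
def f(x):
    return x ** 2

def df(x):
    return 2 * x

def stein_expectation(samples):
    """Monte Carlo estimate of E[f dlogp + df] under the sampling distribution."""
    x = samples
    return np.mean(f(x) * dlogp(x) + df(x))

n = 200_000
# Sampling from p itself: the Stein identity says the expectation vanishes.
identity_value = stein_expectation(rng.normal(0.0, 1.0, n))

# Sampling from q = N(1, 1): the expectation no longer vanishes
# (analytically it equals -2), signalling the mismatch between q and p.
discrepancy_value = stein_expectation(rng.normal(1.0, 1.0, n))

print(identity_value)     # ~ 0
print(discrepancy_value)  # ~ -2
```

The nonzero value under $q$ is exactly the signal that the Stein discrepancy turns into a notion of distance.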

Variational inference

The goal of variational inference is to approximate a target distribution $p$ with a tractable ansatz distribution $q$ by minimizing the Kullback–Leibler divergence

$$D_{\rm KL} (q||p) = \mathbb E_q \log \frac{q}{p} = (-\mathbb E_q\log p) - (-\mathbb E_q \log q),$$

which has the form of a free energy $F = U - TS$ with temperature equal to unity. Thus, minimizing the KL divergence is equivalent to striking a balance between minimizing the $q$-average of the energy ($-\log p$) and maximizing the entropy of $q$.

The KL divergence is non-negative and attains its minimum of zero when $q=p$. We wish to solve $\min_q D_{\rm KL}(q||p)$ subject to the constraint that $q$ is a probability distribution, $\int dx\, q(x) = \langle q, 1 \rangle = 1$. Thus, the total objective is

$$\mathcal L[q] = \langle q, \log q - \log p \rangle - \lambda \langle q, 1 \rangle$$

where $\lambda$ is a Lagrange multiplier. Taking the functional derivative and setting it to zero,

$$\frac{\delta \mathcal L}{\delta q} = \log q - \log p + 1 - \lambda = 0,$$

we get

$$q = p \exp (\lambda -1).$$

Since both $p$ and $q$ are normalized distributions, integrating both sides gives $\exp(\lambda - 1) = 1$, i.e. $\lambda=1$. Thus, we have shown that $q=p$ solves the optimization problem. However, a real-world target $p$ may be arbitrarily complicated, so exact equality is unattainable; the best we can do is the closest approximation allowed by the functional form of the ansatz $q$.
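A quick numerical illustration (our own sketch, assuming a Gaussian target and a Gaussian ansatz family): discretizing the densities on a grid, the KL divergence of $q = \mathcal N(\mu, 1)$ against $p = \mathcal N(0,1)$ equals $\mu^2/2$, which is non-negative and vanishes exactly at $q = p$.

```python
import numpy as np

# Discretized densities on a grid; KL(q||p) ~ sum q * log(q/p) * dx.
x = np.linspace(-8, 8, 4001)
dx = x[1] - x[0]

def gaussian(x, mu, sigma):
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))

p = gaussian(x, 0.0, 1.0)

def kl(q, p):
    return np.sum(q * np.log(q / p)) * dx

# Scan a family of Gaussian ansatz distributions q = N(mu, 1);
# the closed form for this family is KL = mu^2 / 2.
mus = np.linspace(-2, 2, 81)
kls = [kl(gaussian(x, mu, 1.0), p) for mu in mus]

best = mus[int(np.argmin(kls))]
print(best)      # ~ 0: the minimizing ansatz is q = p itself
print(min(kls))  # ~ 0: the minimum value of the KL divergence
```

Within the ansatz family, the minimizer coincides with the target; for a target outside the family, the same scan would return the closest member instead.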

The ansatz distribution can be constructed by hand based on knowledge of the target distribution, e.g. a mean-field approximation using exponential-family distributions. In this article, we discuss a non-parametric approach that represents $q$ by a set of particles.

Coordinate flow

Since both $q$ and $p$ are smooth functions and, more importantly, probability distributions, we can adiabatically deform $q$ into $p$ by shifting the coordinates

$$x^\mu \mapsto x^\mu + v^\mu(x)\, \delta t,$$

where for simplicity we assume $x \in \mathbb R^d$.

The task of finding such a transformation is equivalent to searching for a proper velocity field $v(x)$.

The total mass of $q$ is conserved, so we have a conserved current of $q$-charge

$$j_q^\mu (x) = q(x)\, v^\mu(x)$$

and the continuity equation

$$\dot q (t) = - \partial_\mu j^\mu _q = - v^\mu \partial_\mu q - q\, \partial_\mu v^\mu .$$

We have the following equivalent optimization problems

$$\min_{v} D_{\rm KL} (q_{t+\delta t}||p_t) \Leftrightarrow \min_{v} D_{\rm KL}(q_t||p_{t - \delta t}),$$

which is the equivalence of the active and passive perspectives on a coordinate transformation. In the latter case, $p$ is transported with the velocity in the reverse direction, so its conserved current is $j_p = -p\, v$ and

$$\dot p (t) = - \partial_\mu j_p^\mu = v^\mu \partial_\mu p + p\, \partial_\mu v^\mu .$$

Moving $p$ backward by $\delta t$ along the original flow is the same as moving it forward by $\delta t$ along this reversed flow, so we have the expansion

$$p(x, t-\delta t) = p(x, t) + \dot p (x, t)\, \delta t + \mathcal O(\delta t^2)$$

$$\log p(x, t-\delta t) = \log p(x, t) + \frac{\dot p (x, t)}{p(x, t)}\, \delta t + \mathcal O(\delta t^2),$$

with $\dot p$ given by the reversed-flow continuity equation above.

The objective

$$\mathcal L[v] = D_{\rm KL}(q_t||p_{t - \delta t}) = \int dx\, q_t(x) \log \frac{q_t(x)}{p_{t-\delta t}(x)}$$

then has the first-order increment in $\delta t$

$$\delta \mathcal L[v] = - \delta t \int dx\, q_t(x)\, \frac{\dot p_t(x)}{p_t(x)} .$$

Replacing $\dot p$ using the continuity equation, we get

$$\frac{ \delta \mathcal L[v] }{\delta t} = - \int dx\, q(x) \left( v^\mu(x)\, \partial_\mu \log p(x) + \partial_\mu v^\mu \right) = - \mathbb E_q\!\left[ v^\mu \partial_\mu \log p + \partial_\mu v^\mu \right].$$

In other words, the gradient of $D_{\rm KL}(q_t||p_{t - \delta t})$ along the flow is the negative Stein discrepancy of $(q,p)$ with test function $v$: $\delta \mathcal L[v]/\delta t = -S_{q,p} v$.
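This key identity can be checked numerically (a sketch under simplifying assumptions of our own: Gaussian $q = \mathcal N(m, 1)$ and $p = \mathcal N(0,1)$, and a constant velocity field $v(x) = c$, for which the KL divergence along the flow is known in closed form). The finite-difference derivative of the KL should match $-S_{q,p}v$ estimated by Monte Carlo; analytically both equal $mc$.

```python
import numpy as np

rng = np.random.default_rng(0)

# Target p = N(0,1), ansatz q = N(m,1), velocity field v(x) = c (a constant shift).
m, c = 1.5, 0.8

def kl_gaussian(mu):
    """Closed-form KL( N(mu,1) || N(0,1) ) = mu^2 / 2."""
    return 0.5 * mu ** 2

# Left side: finite-difference derivative of the KL along the flow x -> x + v dt,
# which shifts the mean of q from m to m + c*dt.
dt = 1e-5
lhs = (kl_gaussian(m + c * dt) - kl_gaussian(m)) / dt

# Right side: -S_{q,p} v = -E_q[ v dlogp + dv ], with dlogp(x) = -x and dv = 0.
x = rng.normal(m, 1.0, 1_000_000)
rhs = -np.mean(c * (-x) + 0.0)

print(lhs)  # ~ m*c = 1.2
print(rhs)  # ~ m*c = 1.2
```

The agreement confirms that moving the particles along $v$ changes the KL divergence at exactly the rate $-S_{q,p}v$.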

Method of steepest descent

We now have the gradient of our variational inference objective, but we wish to achieve the steepest descent by searching for the best velocity field.

We further assume the velocity field is an element of the $d$-dimensional RKHS $\mathcal H^d$. Then we have the reproducing property

$$v = \langle K, v \rangle_{\mathcal H^d} = \langle v, K \rangle_{\mathcal H^d},$$

where $K(\cdot, \cdot)$ is the kernel function of $\mathcal H^d$, e.g. the Gaussian RBF kernel. Furthermore, the linear operator $S_{q,p}$ can be shifted onto $K$,

$$S_{q,p}v = \langle v, S_{q,p}K \rangle_{\mathcal H^d},$$

which is the "dot product" of $v$ and $S_{q,p} K$ in the RKHS.

The solution of the optimization problem

$$\max_{v,\, \Vert v \Vert_{\mathcal H^d} \le 1} S_{q,p} v = \langle v, S_{q,p}K \rangle_{\mathcal H^d}$$

is, by the Cauchy–Schwarz inequality, simply $v^* = S_{q,p}K/\Vert S_{q,p}K \Vert_{\mathcal H^d}$. Also note that the velocity vanishes when $p=q$, so the flow stops once the target is reached.

The Stein variational inference algorithm

Using the method of steepest descent, we obtained the optimal flow field $v^*$. Next, we just need to go with the flow,

$$x \mapsto x + v^* \delta t = x + \delta t'\, S_{q,p}K,$$

and incrementally update $q(x)$.

The algorithm is as follows:

  • Sample $m$ particles from an initial distribution $q$

  • Approximate $S_{q,p}K = \mathbb E_q [K^\mu \partial_\mu \log p + \partial_\mu K^\mu]$ using the sample mean over the $m$ particles

  • Absorb the inverse norm of $v^*$ into the learning rate $\delta t'$

  • Update the coordinates of the particles using the computed flow

  • Repeat until convergence
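Putting the steps together, here is a minimal one-dimensional sketch in NumPy (our own illustration; the target $\mathcal N(2,1)$, the median-heuristic bandwidth, and all names are assumptions, not from the original post):

```python
import numpy as np

rng = np.random.default_rng(0)

# Target p = N(2, 1), so the score is d/dx log p(x) = -(x - 2).
def dlogp(x):
    return -(x - 2.0)

def svgd(x, step=0.1, iters=1000):
    """Move the particles x along the estimated optimal flow S_qp K."""
    m = len(x)
    for _ in range(iters):
        diff = x[:, None] - x[None, :]       # (m, m) pairwise differences
        # Median heuristic for the RBF bandwidth, as in the SVGD paper.
        h2 = np.median(diff ** 2) / (2.0 * np.log(m + 1))
        K = np.exp(-diff ** 2 / (2.0 * h2))  # kernel matrix K(x_i, x_j)
        gradK = -diff / h2 * K               # gradK[i, j] = dK(x_i, x_j)/dx_i
        # Sample-mean estimate of S_qp K = E_q[K dlogp + dK]:
        # the first term drives particles toward high probability,
        # the second is a repulsive force that prevents collapse.
        phi = (K @ dlogp(x) + gradK.sum(axis=0)) / m
        x = x + step * phi
    return x

# Start the particle cloud far from the target and let it flow.
particles = svgd(rng.normal(-5.0, 0.5, 100))
print(particles.mean())  # ~ 2
print(particles.std())   # ~ 1
```

The `K @ dlogp(x)` term pulls particles toward high-probability regions, while the kernel-gradient term spreads them out, so the final cloud approximates the whole distribution $p$ rather than collapsing onto its mode.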

Reference

Liu, Q. and Wang, D. (2016). Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm. Advances in Neural Information Processing Systems 29 (NIPS 2016).
