Stein Variational Gradient Descent
A new way to learn
Introduction
Gradient descent has become the workhorse algorithm for training almost all deep learning models. Stein variational gradient descent (SVGD) was proposed as a "natural counterpart of gradient descent for optimization."
In our previous blog post, Differentiable Bayesian Structure Learning, we briefly mentioned that the core engine of the Bayesian algorithm was Stein variational gradient descent. In this article, we expand on this topic and articulate the motivation, fundamental logic, and mathematical derivation of this novel optimization method. The original paper can be found in the Reference section; here we derive the algorithm in a more intuitive way.
Notations
The usual inner product of two functions f and g is defined as

⟨f, g⟩ = ∫ dx f(x) g(x).
Let f and g be functions in a reproducing kernel Hilbert space (RKHS) H; then their inner product is denoted ⟨f, g⟩_H.
Let {f_μ}_{μ=1}^d and {g_ν}_{ν=1}^d be functions in the product RKHS H^d; then the inner product is

⟨f, g⟩_{H^d} = Σ_{μ=1}^d ⟨f_μ, g_μ⟩_H.
Stein identity and discrepancy
Observe that, using Stokes' theorem,

∫_M dx ∂_μ(p f_μ) = ∮_{∂M} dS_μ p f_μ;

if p f → 0 at the boundary ∂M and both p and f are smooth functions, then we have

∫_M dx (f_μ ∂_μ p + p ∂_μ f_μ) = 0.

If p is a probability density over M, then writing f_μ ∂_μ p = p f_μ ∂_μ log p turns this into the Stein identity

E_{x∼p}[f_μ(x) ∂_μ log p(x) + ∂_μ f_μ(x)] = 0
for any test function f that satisfies these requirements. Now, replacing the sampling distribution p by q,

S_{q,p}(f) = E_{x∼q}[f_μ(x) ∂_μ log p(x) + ∂_μ f_μ(x)],

we get the Stein discrepancy, which vanishes when q equals p. Thus, we obtain a measure of "distance" between q and p for a properly chosen test function f.
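As a quick sanity check, both the identity and the discrepancy can be probed by Monte Carlo. The sketch below uses illustrative choices (standard-normal p, shifted-normal q, and the test function f(x) = x): the expectation vanishes under p but not under q.

```python
import numpy as np

rng = np.random.default_rng(0)

# Target p = N(0, 1), so grad log p(x) = -x.
# Test function f(x) = x, so div f = 1.
x = rng.standard_normal(100_000)          # samples from p
stein_p = np.mean(x * (-x) + 1.0)         # Stein identity: should be ~0

# Sampling from q = N(1, 1) instead gives the Stein discrepancy,
# which does not vanish because q != p (analytically E_q[1 - x^2] = -1).
y = rng.standard_normal(100_000) + 1.0    # samples from q
stein_q = np.mean(y * (-y) + 1.0)

print(round(stein_p, 2), round(stein_q, 2))
```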
Variational inference
The goal of variational inference is to approximate a target distribution p with a tractable ansatz distribution q by minimizing the Kullback–Leibler divergence

D_KL(q ∥ p) = ∫ dx q(x) log [q(x)/p(x)] = E_q[log q − log p],
which has the form of a free energy F = U − TS with temperature equal to unity: U = E_q[−log p] is the q-average of the energy −log p, and S = −E_q[log q] is the entropy of q. Thus, minimizing the KL divergence is equivalent to striking a balance between minimizing the average energy and maximizing the entropy.
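To make the decomposition concrete, U and S can be estimated by Monte Carlo for a pair of Gaussians, for which D_KL(N(1,1) ∥ N(0,1)) = 1/2 in closed form; the particular distributions here are purely illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)

# q = N(1, 1), p = N(0, 1); closed form D_KL(q || p) = 1/2.
x = rng.standard_normal(200_000) + 1.0            # samples from q

log_q = -0.5 * (x - 1.0) ** 2 - 0.5 * np.log(2 * np.pi)
log_p = -0.5 * x ** 2 - 0.5 * np.log(2 * np.pi)

U = np.mean(-log_p)   # q-average of the energy -log p
S = np.mean(-log_q)   # entropy of q (Monte Carlo estimate)
F = U - S             # free energy at T = 1, equals D_KL(q || p)

print(round(F, 3))
```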
The KL divergence is non-negative and attains its minimum of zero when q = p. We wish to min_q D_KL(q ∥ p) subject to the constraint that q is a probability distribution, ∫ dx q(x) = ⟨q, 1⟩ = 1. Thus, the total objective is

L[q] = D_KL(q ∥ p) + λ (1 − ∫ dx q(x)),

where λ is a Lagrange multiplier. Taking the functional derivative and setting it to zero,

δL/δq(x) = log q(x) − log p(x) + 1 − λ = 0,

we get

q(x) = p(x) e^{λ−1}.
Since p is itself normalized, the constraint on q forces λ = 1. Thus we have shown that q = p solves the optimization problem. However, for a real-world distribution p, which may be arbitrarily complicated, exact equality is unattainable; we can only find the best approximation within the functional form of the ansatz q.
The ansatz distribution can be constructed manually based on knowledge of the target distribution, e.g. a mean-field approximation using exponential-family distributions. In this article, we discuss a non-parametric approach that represents q by a set of particles.
Coordinate flow
Since both q and p are smooth functions and, more importantly, probability distributions, we can adiabatically deform q into p by shifting the coordinates,

x → x′ = x + v(x) δt,

where for simplicity we assume x ∈ R^d.
The task of seeking such a transformation is equivalent to searching for a proper velocity field v(x).
The total mass of q is conserved, so we have a conserved current of q-charge,

j_q = q v,

and the continuity equation

∂_t q + ∂_μ (q v_μ) = 0.
We have the following equivalent optimization problems,

min_v D_KL(q_{t+δt} ∥ p_t) ⟺ min_v D_KL(q_t ∥ p_{t−δt}),

which is the equivalence of the active and passive perspectives of a coordinate transformation: instead of flowing q forward, we hold q fixed and flow p backward in time along the same velocity field. In the latter case, p carries its own conserved current,

∂_t p + ∂_μ (j_p)_μ = 0, with j_p = p v,

where j_p is the conserved current of p-charge.
We have the following expansion,

p_{t−δt}(x) = p_t(x) − δt ṗ_t(x) + O(δt²).

The increment of the objective,

δD_KL = D_KL(q_t ∥ p_{t−δt}) − D_KL(q_t ∥ p_t) = −∫ dx q_t [log p_{t−δt} − log p_t],

in first order of δt reads

δD_KL = δt ∫ dx q_t ṗ_t/p_t.

Replacing ṗ = −∂_μ(p v_μ) using the continuity equation, we get

δD_KL = −δt ∫ dx q_t (1/p_t) ∂_μ(p_t v_μ) = −δt E_{q_t}[v_μ ∂_μ log p_t + ∂_μ v_μ] = −δt S_{q,p}(v).
In other words, the gradient of D_KL(q_t ∥ p_{t−δt}) with respect to δt is the negative Stein discrepancy of (q, p) with test function v, i.e. −S_{q,p}(v).
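This relation can be checked numerically for a case where the KL divergence has a closed form; the constant velocity field and the pair of Gaussians below are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
c, eps = 0.7, 1e-4

# q = N(1, 1), p = N(0, 1), constant velocity field v(x) = c.
# Transporting q by x -> x + eps*c shifts its mean: q_eps = N(1 + eps*c, 1),
# and KL(N(m, 1) || N(0, 1)) = m**2 / 2 in closed form.
kl = lambda m: m ** 2 / 2
lhs = (kl(1 + eps * c) - kl(1)) / eps        # d KL / d(eps) at eps ~ 0

# Minus the Stein discrepancy: -E_q[v * grad log p + div v], with div v = 0.
x = rng.standard_normal(500_000) + 1.0       # samples from q
rhs = -np.mean(c * (-x))

print(round(lhs, 2), round(rhs, 2))
```

Both sides agree: a constant shift away from p increases the KL divergence at exactly the rate −S_{q,p}(v).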
Method of steepest descent
We now have a gradient for our variational-inference objective, but we wish to obtain the steepest descent by searching for the optimal velocity field.
We further assume that the velocity field is an element of the d-dimensional RKHS H^d. Then we have the reproducing property

v_μ(x) = ⟨v_μ, K(x, ·)⟩_H,

where K(·,·) is the kernel function of H, e.g. the Gaussian RBF kernel. Furthermore, the linear operator S_{q,p} can be shifted onto K,

S_{q,p}(v) = ⟨v, S_{q,p}K⟩_{H^d}, with (S_{q,p}K)_μ(·) = E_{x∼q}[K(x, ·) ∂_μ log p(x) + ∂_{x_μ} K(x, ·)],

which is the "dot product" of v and S_{q,p}K in the RKHS.
The solution of the optimization

max_{∥v∥_{H^d} ≤ 1} S_{q,p}(v)

is simply v* = S_{q,p}K / ∥S_{q,p}K∥_{H^d}. Note also that the velocity vanishes when p = q, by the Stein identity.
The Stein variational inference algorithm
Using the method of steepest descent, we obtained the optimal flow field v*. Next, we just need to go with the flow,

x ← x + δt v*(x),

and incrementally update q(x).
The algorithm is as follows:
Sample m particles {x_i}_{i=1}^m from an initial distribution q_0.
Approximate (S_{q,p}K)_μ = E_q[K ∂_μ log p + ∂_μ K] by the sample mean over the m particles.
Absorb the normalization factor 1/∥S_{q,p}K∥_{H^d} of v* into the learning rate δt′.
Update the coordinates of the particles using the calculated flow: x_i ← x_i + δt′ (S_{q,p}K)(x_i).
Repeat until convergence.
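The steps above can be sketched end to end; the 1-D Gaussian target, unit kernel bandwidth, and learning rate below are illustrative choices, not values from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)

def svgd_direction(x, grad_log_p, h=1.0):
    """Sample-mean estimate of (S_qp K) at each particle:
    a kernel-smoothed score plus a repulsive kernel-gradient term."""
    diff = x[:, None] - x[None, :]          # diff[j, i] = x_j - x_i
    k = np.exp(-diff ** 2 / (2 * h))        # RBF kernel K(x_j, x_i)
    grad_k = -diff / h * k                  # d K(x_j, x_i) / d x_j
    return (k * grad_log_p(x)[:, None] + grad_k).mean(axis=0)

# Target p = N(3, 1), so grad log p(x) = -(x - 3).
grad_log_p = lambda x: -(x - 3.0)

# 1. sample m particles from an initial distribution q_0 = N(0, 1)
x = rng.standard_normal(100)

# 2-5. repeatedly move the particles along the estimated flow,
#      with the norm of S_qp K absorbed into the learning rate dt'
for _ in range(500):
    x = x + 0.1 * svgd_direction(x, grad_log_p)

print(round(x.mean(), 2), round(x.std(), 2))
```

The repulsive grad_k term is what keeps the particles spread out to match the target's variance instead of all collapsing onto the mode.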
Reference

Liu, Q. and Wang, D. "Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm." NeurIPS 2016.