Stein Variational Gradient Descent
Gradient descent has become the workhorse algorithm for training almost all deep learning models. Stein variational gradient descent (SVGD) was proposed as a "natural counterpart of gradient descent for optimization."
In our previous blog post Differentiable Bayesian Structure Learning, we briefly mentioned that the core engine of the Bayesian algorithm was Stein variational gradient descent. In this article, we expand on this topic and articulate the motivation, fundamental logic, and mathematical derivation of this novel optimization method. The original paper can be found in the Reference section, while in this article we derive the algorithm in a more intuitive way. Below, we first collect a few definitions and a bird's-eye view of the derivation, and then work through each step in detail.
The usual inner product of two functions $f$ and $g$ is defined as
$$\langle f, g\rangle = \int f(x)\,g(x)\,\mathrm{d}x .$$
Let $f$ and $g$ be functions in a reproducing kernel Hilbert space (RKHS) $\mathcal{H}$; then their inner product is denoted $\langle f, g\rangle_{\mathcal{H}}$ and is induced by the kernel $k$ of $\mathcal{H}$ through $\langle k(x,\cdot),\, k(y,\cdot)\rangle_{\mathcal{H}} = k(x, y)$.
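As a small illustration (my own, not from the original post): for functions that are finite combinations of kernel sections, the RKHS inner product reduces to plain kernel evaluations, which is all the algorithm below ever needs.

```python
import numpy as np

def rbf(x, y, h=1.0):
    """Gaussian RBF kernel k(x, y) = exp(-(x - y)^2 / (2 h^2)) for scalars or arrays."""
    return np.exp(-(x - y)**2 / (2 * h**2))

# f(.) = sum_i alpha_i k(x_i, .)  and  g(.) = sum_j beta_j k(y_j, .)
x_i, alpha = np.array([-1.0, 0.5]), np.array([0.7, -0.2])
y_j, beta = np.array([0.0, 2.0]), np.array([1.1, 0.3])

# <f, g>_H = sum_{i,j} alpha_i beta_j k(x_i, y_j)
inner_fg = alpha @ rbf(x_i[:, None], y_j[None, :]) @ beta
print("<f, g>_H =", inner_fg)

# Reproducing property: <f, k(x, .)>_H = sum_i alpha_i k(x_i, x) = f(x)
x = 0.3
print("<f, k(x, .)>_H = f(x) =", alpha @ rbf(x_i, x))
```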
Observe that, using Stokes' theorem (in the form of the divergence theorem), for a smooth density $\rho$ and a smooth vector field $u$ with $\rho\,u$ vanishing on the boundary,
$$\oint_{\partial\Omega}\rho(x)\,u(x)\cdot\mathrm{d}S = \int_{\Omega}\nabla_x\cdot\bigl(\rho(x)\,u(x)\bigr)\,\mathrm{d}x ,$$
we get
$$\int \nabla_x\cdot\bigl(\rho(x)\,u(x)\bigr)\,\mathrm{d}x = 0 ,$$
and, for any smooth function $f$,
$$\int f(x)\,\nabla_x\cdot\bigl(\rho(x)\,u(x)\bigr)\,\mathrm{d}x = -\int \rho(x)\,u(x)\cdot\nabla_x f(x)\,\mathrm{d}x .$$
These two integration-by-parts identities are used repeatedly below.
We have the following equivalent optimization problems,
$$\min_{v}\ \mathrm{KL}\bigl(q_{[T]}\,\big\|\,p\bigr) \quad\Longleftrightarrow\quad \min_{v}\ \mathrm{KL}\bigl(q\,\big\|\,p_{[T^{-1}]}\bigr), \qquad T(x) = x + \epsilon\,v(x),$$
where $q_{[T]}$ denotes the pushforward of $q$ under $T$ and $p_{[T^{-1}]}$ the pushforward of $p$ under $T^{-1}$; this is the equivalence of the active and passive perspectives of a coordinate transformation. In the latter case, we have the velocity in the reverse direction,
$$T^{-1}(x) \approx x - \epsilon\, v(x) .$$
We have the following expansion to first order in $\epsilon$,
$$p_{[T^{-1}]}(x) = p\bigl(T(x)\bigr)\,\bigl|\det\nabla_x T(x)\bigr| \approx p(x)\,\Bigl(1 + \epsilon\,v(x)\cdot\nabla_x\log p(x) + \epsilon\,\nabla_x\cdot v(x)\Bigr).$$
The increment of the objective is therefore
$$\delta\,\mathrm{KL} = -\mathbb{E}_{x\sim q}\bigl[\log p_{[T^{-1}]}(x) - \log p(x)\bigr] \approx -\epsilon\,\mathbb{E}_{x\sim q}\bigl[v(x)\cdot\nabla_x\log p(x) + \nabla_x\cdot v(x)\bigr].$$
Now we have the gradient of our variational-inference objective, but we wish to obtain the steepest descent by searching for the best velocity field. The solution of this optimization,
$$v^*(\cdot) \;\propto\; \mathbb{E}_{x\sim q}\bigl[k(x,\cdot)\,\nabla_x\log p(x) + \nabla_x k(x,\cdot)\bigr],$$
is derived below by restricting $v$ to a reproducing kernel Hilbert space with kernel $k$.
The algorithm is as follows:
1. Update the coordinates of the particles using the calculated flow, $x_i \leftarrow x_i + \epsilon\, v^*(x_i)$.
2. Repeat the process until the particles converge.
If $p(x)\,\phi(x) \to 0$ at the boundary (or at infinity) and both $p$ and $\phi$ are smooth functions, then we have
$$\int \nabla_x\cdot\bigl(p(x)\,\phi(x)\bigr)\,\mathrm{d}x = 0 \quad\Longrightarrow\quad \mathbb{E}_{x\sim p}\bigl[\phi(x)\cdot\nabla_x\log p(x) + \nabla_x\cdot\phi(x)\bigr] = 0 .$$
If $p$ is a probability density over $\mathbb{R}^d$, then we have the Stein identity
$$\mathbb{E}_{x\sim p}\bigl[\operatorname{trace}\bigl(\mathcal{A}_p\phi(x)\bigr)\bigr] = 0, \qquad \mathcal{A}_p\phi(x) := \phi(x)\,\nabla_x\log p(x)^{\top} + \nabla_x\phi(x),$$
for any test function $\phi$ that satisfies the above requirements. Now, if we replace the sampling distribution $p$ by a different distribution $q$ and maximize over a class of test functions $\mathcal{F}$, we get the Stein discrepancy
$$\mathbb{S}(q, p) = \max_{\phi\in\mathcal{F}}\ \mathbb{E}_{x\sim q}\bigl[\operatorname{trace}\bigl(\mathcal{A}_p\phi(x)\bigr)\bigr],$$
which vanishes when $q$ is $p$. Thus, we obtain a measure of "distance" between $q$ and $p$ with a properly chosen test function $\phi$.
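As a quick sanity check (my own illustration, not from the original post), the Stein identity can be verified by Monte Carlo for a standard normal $p$, where $\nabla_x\log p(x) = -x$, using for instance the test function $\phi(x) = \sin(x)$; the same estimator evaluated under a different sampling distribution $q \neq p$ is generally non-zero, which is exactly the quantity the Stein discrepancy maximizes over test functions.

```python
import numpy as np

rng = np.random.default_rng(0)

def stein_term(x):
    """trace(A_p phi)(x) = phi(x) * d/dx log p(x) + phi'(x) for p = N(0, 1), phi = sin."""
    return np.sin(x) * (-x) + np.cos(x)

# Under x ~ p the expectation vanishes (Stein identity) ...
x_p = rng.normal(0.0, 1.0, size=1_000_000)
# ... under x ~ q != p it generally does not.
x_q = rng.normal(1.0, 2.0, size=1_000_000)
print(f"E_p[trace(A_p phi)] = {stein_term(x_p).mean():+.4f}")   # close to 0
print(f"E_q[trace(A_p phi)] = {stein_term(x_q).mean():+.4f}")   # clearly non-zero
```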
The goal of variational inference is to approximate a target distribution $p$ with a tractable ansatz distribution $q$ by minimizing the KL-divergence
$$\mathrm{KL}(q\,\|\,p) = \mathbb{E}_{x\sim q}\bigl[\log q(x) - \log p(x)\bigr] = \mathbb{E}_{x\sim q}\bigl[E(x)\bigr] - H(q), \qquad E(x) := -\log p(x),$$
which is in the form of a free energy with temperature equal to unity. Thus, minimization of the KL-divergence is equivalent to striking a balance between minimizing the $q$-average of the energy ($\mathbb{E}_{x\sim q}[E(x)]$) and maximizing the entropy of $q$ ($H(q) = -\mathbb{E}_{x\sim q}[\log q(x)]$).
The KL-divergence is non-negative and attains its minimum of zero when $q = p$. We wish to minimize it subject to the constraint that $q$ is a probability distribution, $\int q(x)\,\mathrm{d}x = 1$. Thus, the total objective is
$$\mathcal{L}[q] = \mathrm{KL}(q\,\|\,p) + \lambda\left(\int q(x)\,\mathrm{d}x - 1\right),$$
where $\lambda$ is a Lagrange multiplier. Taking the functional derivative and setting it to zero,
$$\frac{\delta\mathcal{L}}{\delta q(x)} = \log q(x) - \log p(x) + 1 + \lambda = 0 \quad\Longrightarrow\quad q(x) = p(x)\,e^{-1-\lambda}.$$
Since $q$ is a distribution, normalization forces $e^{-1-\lambda} = 1$ and we get $q = p$. Thus, we have shown that $q = p$ is the solution of the optimization problem. However, for a real-world distribution $p$, which may be arbitrarily complicated, an exact equality is unattainable; the best we can do is the closest approximation allowed by the functional form of the ansatz $q$.
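As a quick illustration of this (not part of the original derivation), take $p = \mathcal{N}(0, 1)$ and a Gaussian ansatz $q = \mathcal{N}(\mu, \sigma^2)$. The KL-divergence has the closed form $\tfrac{1}{2}(\sigma^2 + \mu^2 - 1 - \ln\sigma^2)$, which is non-negative and vanishes exactly at $\mu = 0$, $\sigma = 1$, i.e. at $q = p$:

```python
import numpy as np

def kl_gauss_vs_std_normal(mu, sigma):
    """Closed-form KL( N(mu, sigma^2) || N(0, 1) )."""
    return 0.5 * (sigma**2 + mu**2 - 1.0 - np.log(sigma**2))

# Scan a grid of ansatz parameters; the minimum (= 0) sits at mu = 0, sigma = 1,
# i.e. q = p, in agreement with the Lagrange-multiplier argument above.
mus = np.linspace(-2.0, 2.0, 81)
sigmas = np.linspace(0.25, 2.75, 51)
kl = np.array([[kl_gauss_vs_std_normal(m, s) for s in sigmas] for m in mus])
i, j = np.unravel_index(kl.argmin(), kl.shape)
print(f"min KL = {kl[i, j]:.2e} at mu = {mus[i]:.2f}, sigma = {sigmas[j]:.2f}")
```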
The ansatz distribution can be constructed by hand based on knowledge of the target distribution, e.g. a mean-field approximation built from exponential-family distributions. In this article, we discuss a non-parametric approach that represents $q$ by a set of particles (samples) of $q$.
Since both $q$ and $p$ are smooth functions and, more importantly, probability distributions, we can adiabatically deform $q$ into $p$ by shifting the coordinates,
$$x \;\longrightarrow\; x' = x + \epsilon\, v(x),$$
where for simplicity we assume the step size $\epsilon$ is infinitesimal. The task of seeking such a transformation is equivalent to searching for a proper velocity field $v(x)$.
The total mass of $q$ is conserved under this flow, so the density obeys a continuity equation: to first order in $\epsilon$ the change of the density is
$$\delta q(x) = -\epsilon\,\nabla_x\cdot j(x), \qquad j(x) = q(x)\,v(x),$$
where $j(x)$ is the conserved current of $q$-charge.
The increment of the KL-divergence is
$$\delta\,\mathrm{KL}(q\,\|\,p) = \int \delta q(x)\,\bigl[\log q(x) - \log p(x)\bigr]\,\mathrm{d}x$$
in first order of $\epsilon$ (the remaining term $\int\delta q(x)\,\mathrm{d}x$ vanishes because the total mass is conserved). Replacing $\delta q$ using the continuity equation and integrating by parts, we get
$$\delta\,\mathrm{KL}(q\,\|\,p) = -\epsilon\int q(x)\,\bigl[v(x)\cdot\nabla_x\log p(x) + \nabla_x\cdot v(x)\bigr]\,\mathrm{d}x = -\epsilon\,\mathbb{E}_{x\sim q}\bigl[\operatorname{trace}\bigl(\mathcal{A}_p v(x)\bigr)\bigr].$$
In other words, the gradient of $\mathrm{KL}(q\,\|\,p)$ with respect to the deformation is the negative Stein discrepancy term of $q$ with test function $v$, i.e.
$$\frac{\mathrm{d}}{\mathrm{d}\epsilon}\,\mathrm{KL}(q\,\|\,p)\Big|_{\epsilon=0} = -\,\mathbb{E}_{x\sim q}\bigl[\operatorname{trace}\bigl(\mathcal{A}_p v(x)\bigr)\bigr].$$
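Here is a minimal numerical sanity check of this gradient identity (my own illustration, not from the post): for a Gaussian $q$ and a linear velocity field $v(x) = ax + b$, the pushforward of $q$ stays Gaussian, so $\mathrm{KL}(q_\epsilon\,\|\,p)$ is available in closed form and its finite-difference $\epsilon$-derivative can be compared against $-\mathbb{E}_q[v\cdot\nabla\log p + \nabla\cdot v]$.

```python
import numpy as np

rng = np.random.default_rng(0)

# Target p = N(0, 1); ansatz q = N(m, s^2); linear velocity field v(x) = a*x + b.
m, s = 1.5, 0.7
a, b = -0.3, 0.4

def kl_vs_std_normal(mu, sigma):
    """Closed-form KL( N(mu, sigma^2) || N(0, 1) )."""
    return 0.5 * (sigma**2 + mu**2 - 1.0 - np.log(sigma**2))

def kl_after_step(eps):
    """Push q through T(x) = x + eps * (a*x + b); the result is still Gaussian."""
    mu = (1 + eps * a) * m + eps * b
    sigma = abs(1 + eps * a) * s
    return kl_vs_std_normal(mu, sigma)

# Finite-difference derivative of KL with respect to eps at eps = 0.
eps = 1e-5
fd_grad = (kl_after_step(eps) - kl_after_step(-eps)) / (2 * eps)

# Monte Carlo estimate of -E_q[ v(x) * d/dx log p(x) + v'(x) ]
# with d/dx log p(x) = -x and v'(x) = a; the two numbers should agree.
x = rng.normal(m, s, size=200_000)
stein_term = (a * x + b) * (-x) + a
print(f"finite-difference dKL/deps : {fd_grad:+.4f}")
print(f"-E_q[trace(A_p v)] (MC)    : {-stein_term.mean():+.4f}")
```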
We further assume the velocity field $v$ is an element of the $d$-dimensional RKHS $\mathcal{H}^d$. Then we have the reproducing property
$$v(x) = \bigl\langle v(\cdot),\, k(x, \cdot)\bigr\rangle_{\mathcal{H}^d},$$
where $k(\cdot,\cdot)$ is the kernel function of $\mathcal{H}$, e.g. the Gaussian RBF kernel. Furthermore, the linear Stein operator can be shifted onto the kernel,
$$\mathbb{E}_{x\sim q}\bigl[\operatorname{trace}\bigl(\mathcal{A}_p v(x)\bigr)\bigr] = \Bigl\langle v(\cdot),\ \underbrace{\mathbb{E}_{x\sim q}\bigl[k(x,\cdot)\,\nabla_x\log p(x) + \nabla_x k(x,\cdot)\bigr]}_{=:\ \phi^*_{q,p}(\cdot)}\Bigr\rangle_{\mathcal{H}^d},$$
which is the "dot product" of $v$ and $\phi^*_{q,p}$ in the RKHS. The velocity of unit RKHS norm that maximizes the decrease of the KL-divergence is then simply $v^*(\cdot) = \phi^*_{q,p}(\cdot)\,/\,\|\phi^*_{q,p}\|_{\mathcal{H}^d}$. Also note that the velocity will vanish when $q = p$, by the Stein identity, so the particles stop moving exactly at the target distribution.
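In practice $q$ is represented by particles, so the expectation defining $\phi^*_{q,p}$ is replaced by a sample mean. Below is a minimal sketch of that estimator with a Gaussian RBF kernel; the function name `svgd_direction` and the fixed bandwidth are my own illustrative choices, not taken from the post.

```python
import numpy as np

def svgd_direction(particles, score, bandwidth=1.0):
    """Empirical phi*(x_i) = (1/n) sum_j [ k(x_j, x_i) grad log p(x_j) + grad_{x_j} k(x_j, x_i) ]
    with a Gaussian RBF kernel k(x, y) = exp(-||x - y||^2 / (2 h^2)).

    particles : (n, d) array of current particle positions.
    score     : callable returning grad_x log p(x) for an (n, d) array, shape (n, d).
    """
    n, d = particles.shape
    diff = particles[:, None, :] - particles[None, :, :]   # diff[j, i] = x_j - x_i
    sq_dists = np.sum(diff**2, axis=-1)                    # ||x_j - x_i||^2
    K = np.exp(-sq_dists / (2 * bandwidth**2))             # K[j, i] = k(x_j, x_i)
    grad_logp = score(particles)                           # (n, d)
    # Kernel-weighted average of the scores: sum_j k(x_j, x_i) * grad log p(x_j).
    drift = K.T @ grad_logp
    # Repulsive term: sum_j grad_{x_j} k(x_j, x_i) = -sum_j (x_j - x_i) k(x_j, x_i) / h^2.
    repulsion = -np.einsum('ji,jid->id', K, diff) / bandwidth**2
    return (drift + repulsion) / n
```

The first term transports each particle toward high-probability regions of $p$ (a kernel-smoothed ascent on $\log p$), while the $\nabla k$ term acts as a repulsive force between particles that keeps them spread out instead of collapsing onto a single mode.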
Using the method of steepest descent, we have obtained the optimal flow field $v^* \propto \phi^*_{q,p}$. Next, we just need to go with the flow,
$$x \;\longleftarrow\; x + \epsilon\,\phi^*_{q,p}(x),$$
and incrementally update $q$ toward $p$.
The algorithm is therefore:
1. Sample $n$ particles $\{x_i\}_{i=1}^{n}$ from an initial distribution $q_0$.
2. Approximate $\phi^*_{q,p}$ using the sample mean over the particles,
$$\hat{\phi}^*(x) = \frac{1}{n}\sum_{j=1}^{n}\Bigl[k(x_j, x)\,\nabla_{x_j}\log p(x_j) + \nabla_{x_j} k(x_j, x)\Bigr].$$
3. Update the coordinates of the particles, $x_i \leftarrow x_i + \epsilon\,\hat{\phi}^*(x_i)$, and repeat until convergence.

The inverse of the RKHS norm of $\phi^*_{q,p}$ is absorbed into the learning rate $\epsilon$.
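Putting everything together, here is a minimal, self-contained sketch of the full particle update loop on a toy 1D bimodal target (the target, kernel bandwidth, step size, and all variable names are illustrative choices, not taken from the original post):

```python
import numpy as np

def rbf_kernel_terms(x, bandwidth):
    """Return K[j, i] = k(x_j, x_i) and sum_j grad_{x_j} k(x_j, x_i) for a Gaussian RBF kernel."""
    diff = x[:, None, :] - x[None, :, :]                   # diff[j, i] = x_j - x_i
    K = np.exp(-np.sum(diff**2, axis=-1) / (2 * bandwidth**2))
    grad_k = -np.einsum('ji,jid->id', K, diff) / bandwidth**2
    return K, grad_k

def svgd(score, x0, n_steps=500, step_size=0.05, bandwidth=0.5):
    """Transport particles x0 toward the target whose score function is `score`."""
    x = np.array(x0, dtype=float)
    n = len(x)
    for _ in range(n_steps):
        K, grad_k = rbf_kernel_terms(x, bandwidth)
        phi = (K.T @ score(x) + grad_k) / n                # empirical phi*(x_i)
        x = x + step_size * phi                            # go with the flow
    return x

# Toy target: equal-weight mixture of N(-2, 1) and N(+2, 1) in one dimension.
def score(x):
    a = np.exp(-0.5 * (x + 2.0)**2)
    b = np.exp(-0.5 * (x - 2.0)**2)
    return (-(x + 2.0) * a - (x - 2.0) * b) / (a + b)

rng = np.random.default_rng(0)
particles = svgd(score, rng.normal(0.0, 1.0, size=(100, 1)))
# The mixture has mean 0 and standard deviation ~2.24; the particles should roughly match.
print(f"particle mean = {particles.mean():.2f}, std = {particles.std():.2f}")
```

The original paper additionally adapts the kernel bandwidth with a median heuristic and uses AdaGrad-style step sizes; a fixed bandwidth and step size are used here only to keep the sketch short.

Reference
Qiang Liu and Dilin Wang. "Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm." NIPS 2016. arXiv:1608.04471.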