Let $p(\cdot\mid\theta)$ be a parametric model for data $x_0 \sim q_0$. We can optimize $\theta$ by maximizing the likelihood $p(x_0\mid\theta)$, or equivalently
$$\max_\theta \mathbb{E}_{x_0\sim q_0}\log p(x_0\mid\theta).$$
From now on, we use $q$'s to denote the forward physical distributions and $p$'s the backward variational ansatz. The parameters are implied: $p(\cdot\mid\theta)\equiv p(\cdot)$.
Let $T>1$ be the number of diffusion steps. If the process is Markov, the joint density of the forward process $q(x_{0,1,\cdots,T})$ can be expanded sequentially.
We will derive a lower bound of the maximum-likelihood objective $\mathbb{E}_{x_0\sim q_0}\log p(x_0)$ and then show that it is related to a KL-divergence up to a constant.
Proof. The original approach expands the joint density of $p$ as
$$p(x_0)=\int_{1,\cdots,T} p_T(x_T)\prod_{t=1}^{T} p_{t-1\mid t}(x_{t-1}\mid x_t),$$
which assumes the Markov property of the backward process. However, without assuming the Markov property, we can still insert an identity and get
$$p(x_0)=\int_{1,\cdots,T} q_{0,1,\cdots,T}\,\frac{p_{0,1,\cdots,T}}{q_{0,1,\cdots,T}}.$$
Note that $q_{0,1,\cdots,T}=q_{1,\cdots,T\mid 0}\,q_0$, so we can reinterpret the integral as
$$p(x_0)=q_0\,\mathbb{E}_{q_{1,\cdots,T\mid 0}}\frac{p_{0,1,\cdots,T}}{q_{0,1,\cdots,T}}.$$
The log-likelihood averaged over the data distribution is then
$$\mathbb{E}_{x_0\sim q_0}\log p(x_0)=\mathbb{E}_{q_0}\log q_0+\mathbb{E}_{q_0}\log \mathbb{E}_{q_{1,\cdots,T\mid 0}}\frac{p_{0,1,\cdots,T}}{q_{0,1,\cdots,T}}\ge -H[q_0]-D_{\mathrm{KL}}(q_{0,1,\cdots,T}\,\|\,p_{0,1,\cdots,T})\equiv L_{\mathrm{ELBO}},$$
where Jensen's inequality moves the logarithm inside the expectation.
The entropy $H[q_0]$ depends only on the data distribution and contains no model parameters. Thus, maximizing the log-likelihood lower bound $L_{\mathrm{ELBO}}$ is equivalent to minimizing the KL-divergence between the forward-process and backward-process joint densities.
Regardless of the Markov property of the processes, maximizing the likelihood lower bound is equivalent to minimizing the KL-divergence between the variational ansatz and the physical forward process.
Markov variational Ansatz
If the forward process is Markov, then we have
$$q_{0,1,\cdots,T}=q_0\prod_{t=1}^{T} q_{t\mid t-1}.$$
Similarly, we assume the backward joint density can be expanded as
$$p_{0,1,\cdots,T}=\left(\prod_{t=1}^{T} p_{t-1\mid t}\right)p_T.$$
The posterior of the physical forward process may be represented as
$$q_{t-1\mid t}\propto q_{t\mid t-1}\,q_{t-1};$$
however, without knowledge of the initial state $x_0$, there are infinitely many possibilities. Therefore, we fix the initial state and work with the probabilities conditioned on a fixed $x_0\sim q_0$,
$$q_{t-1\mid t,0}\propto q_{t\mid t-1,0}\,q_{t-1\mid 0},$$
or, as an equality for $t>1$,
$$q_{t-1\mid t,0}\,q_{t\mid 0}=q_{t,t-1\mid 0}=q_{t\mid t-1,0}\,q_{t-1\mid 0}.$$
Thus, the forward process joint density has a posterior expansion,
$$q_{0,1,\cdots,T}=q_0\,q_{T\mid 0}\prod_{t=2}^{T} q_{t-1\mid t,0},$$
where we used $q_{t\mid t-1}=q_{t\mid t-1,0}$ (Markov property) and the marginals $q_{t\mid 0}$ telescope.
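As a sanity check, the sequential and posterior expansions can be compared numerically. Below is a minimal sketch, not from the original text: the toy two-state chain, the step count, and all helper names are illustrative assumptions.

```python
# Verify that the forward joint q_{0..T} = q_0 * prod_t q_{t|t-1} equals the
# posterior expansion q_0 * q_{T|0} * prod_{t=2}^{T} q_{t-1|t,0}
# for a random 2-state Markov chain with T = 3 steps.
import itertools
import numpy as np

rng = np.random.default_rng(0)
T, S = 3, 2                                  # diffusion steps, state count
q0 = rng.dirichlet(np.ones(S))               # initial distribution q_0
Q = [rng.dirichlet(np.ones(S), size=S) for _ in range(T)]  # Q[t][i, j] = q(x_{t+1}=j | x_t=i)

def forward_joint(path):
    """q_{0..T} via the sequential (Markov) expansion."""
    prob = q0[path[0]]
    for t in range(T):
        prob *= Q[t][path[t], path[t + 1]]
    return prob

def marginal_given_0(t, x0):
    """q_{t|0} as a row vector: one-hot at x0 pushed through t transitions."""
    m = np.eye(S)[x0]
    for s in range(t):
        m = m @ Q[s]
    return m

def posterior_joint(path):
    """q_{0..T} via q_0 * q_{T|0} * prod_{t=2}^{T} q_{t-1|t,0}."""
    x0 = path[0]
    prob = q0[x0] * marginal_given_0(T, x0)[path[T]]
    for t in range(2, T + 1):
        # Bayes: q_{t-1|t,0} = q_{t|t-1} q_{t-1|0} / q_{t|0}
        num = Q[t - 1][path[t - 1], path[t]] * marginal_given_0(t - 1, x0)[path[t - 1]]
        prob *= num / marginal_given_0(t, x0)[path[t]]
    return prob

for path in itertools.product(range(S), repeat=T + 1):
    assert np.isclose(forward_joint(path), posterior_joint(path))
print("posterior expansion matches the sequential expansion")
```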
Using the posterior expansion of $q$, the total KL-divergence decomposes as
$$D_{\mathrm{KL}}(q_{0,\cdots,T}\,\|\,p_{0,\cdots,T})=D_0+\sum_{t=2}^{T} D_{t-1}+D_T,$$
where
$$D_0=D_{\mathrm{KL}}(q_0\,\|\,p_{0\mid 1}),$$
$$D_{t-1}=D_{\mathrm{KL}}(q_{t-1\mid t,0}\,\|\,p_{t-1\mid t})\quad\text{for }1<t\le T,$$
and
$$D_T=D_{\mathrm{KL}}(q_{T\mid 0}\,\|\,p_T).$$
The last term $D_T$ is a constant since the distribution $p_T$ is fixed. Adding back the entropy term $H[q_0]$, the first term becomes the usual negative log-likelihood,
$$L_0=D_0+H[q_0]=-\mathbb{E}\log p_{0\mid 1},$$
and the total loss becomes a typical variational inference loss: the sum of the data negative log-likelihood and a series of prior KL-divergences.
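The decomposition can also be verified numerically. The following sketch again uses a toy discrete chain (all names and sizes are illustrative assumptions) and checks that the total KL-divergence equals the sum of the $D_0$, $D_{t-1}$, and $D_T$ terms averaged over the forward joint.

```python
# Check D_KL(q_{0..T} || p_{0..T}) = D_0 + sum_{t=2}^{T} D_{t-1} + D_T
# for random discrete forward (q) and backward (p) chains.
import itertools
import numpy as np

rng = np.random.default_rng(3)
T, S = 3, 2
q0 = rng.dirichlet(np.ones(S))
Q = [rng.dirichlet(np.ones(S), size=S) for _ in range(T)]   # forward q_{t|t-1}
pT = rng.dirichlet(np.ones(S))
P = [rng.dirichlet(np.ones(S), size=S) for _ in range(T)]   # backward p_{t-1|t}

def q_marg(t, x0):                       # q_{t|0} as a row vector
    m = np.eye(S)[x0]
    for s in range(t):
        m = m @ Q[s]
    return m

def q_joint(path):
    prob = q0[path[0]]
    for t in range(T):
        prob *= Q[t][path[t], path[t + 1]]
    return prob

def p_joint(path):
    prob = pT[path[T]]
    for t in range(T, 0, -1):
        prob *= P[t - 1][path[t], path[t - 1]]               # p_{t-1|t}
    return prob

paths = list(itertools.product(range(S), repeat=T + 1))
total = sum(q_joint(x) * np.log(q_joint(x) / p_joint(x)) for x in paths)

# term-by-term, averaged over the forward joint
terms = 0.0
for x in paths:
    w, x0 = q_joint(x), x[0]
    terms += w * np.log(q0[x0] / P[0][x[1], x[0]])           # D_0 piece
    for t in range(2, T + 1):                                # D_{t-1} pieces
        post = Q[t - 1][x[t - 1], x[t]] * q_marg(t - 1, x0)[x[t - 1]] / q_marg(t, x0)[x[t]]
        terms += w * np.log(post / P[t - 1][x[t], x[t - 1]])
    terms += w * np.log(q_marg(T, x0)[x[T]] / pT[x[T]])      # D_T piece
assert np.isclose(total, terms)
print("KL decomposition verified:", total)
```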
So far we have not assumed any specific distributions; the objective is based purely on the Markov property.
Gaussian diffusion
For the particular case of Gaussian diffusion models, we assume that the terminal distribution is normal,
$$q_{T\mid 0}=q_T=p_T=\mathcal{N}(x_T;0,\mathbb{1}),$$
and that the forward transition is Gaussian,
$$q_{t\mid t-1}=\mathcal{N}(x_t;\sqrt{1-\beta_t}\,x_{t-1},\beta_t\mathbb{1}),$$
where $0<\beta_t\le 1$.
It is useful to introduce the additional notations $\alpha_t\equiv 1-\beta_t$ and $\bar\alpha_t\equiv\prod_{s=1}^{t}\alpha_s$.
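In code, this bookkeeping is a few lines. The sketch below assumes a linear $\beta_t$ schedule, which the text does not specify; any schedule with $0<\beta_t\le 1$ works the same way.

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)   # assumed linear schedule for beta_1..beta_T
alphas = 1.0 - betas                 # alpha_t = 1 - beta_t
alpha_bars = np.cumprod(alphas)      # alpha_bar_t = prod_{s<=t} alpha_s
```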
Reparameterization
Let $z\sim\mathcal{N}(0,\mathbb{1})$; the forward process can be written as
$$x_t=\sqrt{1-\beta_t}\,x_{t-1}+\sqrt{\beta_t}\,z=\sqrt{\alpha_t}\,x_{t-1}+\sqrt{\beta_t}\,z.$$
Applying the formula iteratively, we get
$$x_t=\sqrt{\bar\alpha_t}\,x_0+\sqrt{1-\bar\alpha_t}\,z.$$
Thus, we can generate $x_t$ for any $t$ without actually doing the iterative calculation. A similar property holds for any Markov process, cf. the Feynman–Kac formula.
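The closed form can be checked against the iterative simulation. A minimal sketch, assuming a linear $\beta_t$ schedule and toy bimodal 1-d data (both illustrative, not from the text):

```python
# Sample x_t in one shot via the closed form and compare its first two
# moments with the step-by-step simulation.
import numpy as np

rng = np.random.default_rng(1)
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alpha_bars = np.cumprod(1.0 - betas)

t = 300
x0 = rng.choice([-2.0, 2.0], size=100_000)        # toy non-Gaussian data

# iterative: x_s = sqrt(1 - beta_s) x_{s-1} + sqrt(beta_s) z
x = x0.copy()
for s in range(t):
    x = np.sqrt(1.0 - betas[s]) * x + np.sqrt(betas[s]) * rng.normal(size=x.shape)

# one shot: x_t = sqrt(alpha_bar_t) x_0 + sqrt(1 - alpha_bar_t) z
x_direct = np.sqrt(alpha_bars[t - 1]) * x0 \
    + np.sqrt(1.0 - alpha_bars[t - 1]) * rng.normal(size=x0.shape)

print(x.mean(), x.std())             # statistics match the line below
print(x_direct.mean(), x_direct.std())
```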
Posterior is Gaussian
We can compute the posterior of the physical process using
$$q_{t-1\mid t,0}\propto q_{t\mid t-1,0}\,q_{t-1\mid 0},$$
where the RHS is a product of Gaussians. After a straightforward but lengthy calculation, one can show that the posterior is indeed Gaussian,
$$q_{t-1\mid t,0}=\mathcal{N}(x_{t-1};\tilde\mu_t(x_t,x_0),\tilde\beta_t\mathbb{1}),$$
where
$$\tilde\mu_t=\frac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t-1})\,x_t+\sqrt{\bar\alpha_{t-1}}\,\beta_t\,x_0}{1-\bar\alpha_t},$$
and
$$\tilde\beta_t=\frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\,\beta_t.$$
Expressing $x_0$ in terms of $x_t$ and the noise, the mean simplifies to
$$\tilde\mu_t(x_t)=\frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,z\right).$$
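As a quick numerical check (a sketch; the schedule and step index are illustrative assumptions), the two expressions for $\tilde\mu_t$ agree once $x_0$ is expressed through $x_t$ and $z$:

```python
# Verify that the (x_t, x_0) form and the (x_t, z) form of the posterior
# mean coincide when x_0 = (x_t - sqrt(1 - alpha_bar_t) z) / sqrt(alpha_bar_t).
import numpy as np

rng = np.random.default_rng(2)
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alphas, alpha_bars = 1.0 - betas, np.cumprod(1.0 - betas)

t = 500                                  # 1-indexed step, t > 1
a_t, ab_t = alphas[t - 1], alpha_bars[t - 1]
ab_prev, b_t = alpha_bars[t - 2], betas[t - 1]

x0 = rng.normal()
z = rng.normal()
xt = np.sqrt(ab_t) * x0 + np.sqrt(1.0 - ab_t) * z

mu_x0 = (np.sqrt(a_t) * (1.0 - ab_prev) * xt + np.sqrt(ab_prev) * b_t * x0) / (1.0 - ab_t)
mu_z = (xt - b_t / np.sqrt(1.0 - ab_t) * z) / np.sqrt(a_t)
assert np.isclose(mu_x0, mu_z)
```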
Variational Ansatz
Since the target distribution is Gaussian, it is natural to choose a Gaussian as the variational ansatz,
$$p_{t-1\mid t}=\mathcal{N}(x_{t-1};\mu_t,\sigma_t^2\mathbb{1}),$$
where the model parameters are $\mu_t$ and $\sigma_t$. The variance will eventually act as a per-step weight, effectively a learning rate; we treat $\sigma_t$ as a hyperparameter instead of learning it by stochastic gradient descent. The only learnable parameter is then $\mu_t=\mu_t(x_t,t)$.
The per-step divergence then reduces to
$$D_{t-1}=\frac{1}{2\sigma_t^2}\,\big\|\tilde\mu_t-\mu_t\big\|^2+\text{const},$$
where we used the closed-form KL-divergence between two Gaussian distributions.
Clearly, we have the exact solution
$$\frac{\delta D_{t-1}[\mu_t]}{\delta\mu_t}=0\;\Rightarrow\;\mu_t=\tilde\mu_t,$$
and therefore
$$\mu_t(x_t,t)=\frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,z_t\right),$$
where $z_t\sim\mathcal{N}(0,\mathbb{1})$ is the noise that generated $x_t$ from $x_0$.
Next, we reparameterize $\mu_t$ to separate out the explicit dependence on $x_t$ and $t$, and let the model focus only on the implicit dependencies, i.e.
$$\mu_t(x_t,t)=\frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,z(x_t,t)\right).$$
Finally, the objective becomes
$$D_{t-1}=\frac{\beta_t^2}{2\alpha_t(1-\bar\alpha_t)\sigma_t^2}\,\big\|z_t-z(x_t,t)\big\|^2,$$
where $z(x_t,t)$ is the model output.
It can also be shown that
$$L_0=\frac{\beta_1}{2\alpha_1\sigma_1^2}\,\big\|z_1-z(x_1,1)\big\|^2,$$
where $\bar\alpha_1=\alpha_1=1-\beta_1$.
Thus, the generic loss term for $0<t\le T$ is
$$L_{t-1}=\frac{\beta_t^2}{2\alpha_t(1-\bar\alpha_t)\sigma_t^2}\,\big\|z_t-z(x_t,t)\big\|^2.$$
Note that we still have the freedom to choose $\sigma_t$, which controls the relative importance of each step. In the literature, however, a heuristic approach is usually taken: the weight factor is dropped and only the $\ell_2$ loss is kept.
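Putting the pieces together, a training step with the simplified (unweighted) objective looks roughly as follows. This is a minimal sketch, not a reference implementation: the toy MLP `eps_model`, the linear $\beta_t$ schedule, and the 1-d data are all illustrative assumptions.

```python
# One stochastic training step of the simplified DDPM objective:
# sample t, form x_t in closed form, and regress the noise with plain MSE.
import torch
import torch.nn as nn

T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alpha_bars = torch.cumprod(1.0 - betas, dim=0)

eps_model = nn.Sequential(nn.Linear(2, 64), nn.SiLU(), nn.Linear(64, 1))
opt = torch.optim.Adam(eps_model.parameters(), lr=1e-3)

def train_step(x0):                                  # x0: (batch, 1)
    t = torch.randint(0, T, (x0.shape[0],))          # uniform random step
    ab = alpha_bars[t].unsqueeze(-1)
    z = torch.randn_like(x0)                         # target noise z_t
    xt = ab.sqrt() * x0 + (1 - ab).sqrt() * z        # closed-form forward sample
    # model sees x_t and the (normalized) step t; predicts the noise z(x_t, t)
    inp = torch.cat([xt, t.unsqueeze(-1).float() / T], dim=-1)
    loss = ((z - eps_model(inp)) ** 2).mean()        # weight factor dropped
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()

x0 = torch.randn(128, 1) * 0.5 + 1.0                 # toy data
print(train_step(x0))
```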
Sampling the backward process
During training, the model has learned the backward transition distribution