Finite Step Unconditional Markov Diffusion Models
Let $p(\cdot|\theta)$ be a parametric model for the data $x_0 \sim q_0$. We can optimize $\theta$ by maximizing the likelihood $p(x_0|\theta)$, or equivalently

$$\max_\theta \mathbb E_{x_0 \sim q_0} \log p(x_0|\theta).$$

From now on, we use $q$'s to denote the forward physical distributions and $p$'s the backward variational ansatz. The parameters are implied, $p(\cdot|\theta) \equiv p(\cdot)$.

Let $T > 1$ be the number of diffusion steps. The joint density of the forward process, $q(x_{0,1,\cdots,T})$, can be expanded sequentially if the process is Markov:

$$q(x_0, x_1, \cdots, x_T) = q_0(x_0) \prod_{t=1}^T q_{t|t-1}(x_t|x_{t-1}).$$

The reverse-process variational ansatz can be constructed similarly,

$$p(x_0, x_1, \cdots, x_T) = p_T(x_T) \prod_{t=1}^T p_{t-1|t}(x_{t-1}|x_t),$$

which can be interpreted as a series of consecutive priors for the physical observation $x_0$.

Marginalizing out the latent variables $x_{1,\cdots,T}$ gives the objective function, the likelihood of the observable:

$$p(x_0) = \int \mathcal D x_{1,\cdots,T}\; p(x_0, x_1, \cdots, x_T) \equiv \int_{1,\cdots,T} p_{0,1,\cdots,T}.$$

MLE and variational approach

We will derive the lower bound of the maximum-likelihood objective $\mathbb E_{x_0 \sim q_0} \log p(x_0)$ and then show that it is related to a KL-divergence up to a constant.

Proof. The original approach would expand the $p$ joint density,

$$p(x_0) = \int_{1,\cdots,T} \; p_T(x_T) \prod_{t=1}^T p_{t-1|t}(x_{t-1}|x_t),$$

which assumes the Markov property of the backward process. However, without assuming the Markov property, we can still insert an identity and get

$$p(x_0) = \int_{1,\cdots,T} \; \frac{p_{0,1,\cdots,T}}{q_{0,1,\cdots,T}}\, q_{0,1,\cdots,T}.$$

Note that $q_{0,1,\cdots,T} = q_{1,\cdots,T|0}\, q_0$, so we can reinterpret the integral as an expectation,

$$p(x_0) = q_0\, \mathbb E_{q_{1,\cdots,T|0}} \frac{p_{0,1,\cdots,T}}{q_{0,1,\cdots,T}}.$$

The log-likelihood averaged over the data distribution is

$$\mathbb E_{q_0} \log p(x_0) = \mathbb E_{q_0} \log q_0 + \mathbb E_{q_0} \log \mathbb E_{q_{1,\cdots,T|0}} \frac{p_{0,1,\cdots,T}}{q_{0,1,\cdots,T}}.$$

Using the concavity of the logarithm (Jensen's inequality),

$$\mathbb E_{q_0} \log p(x_0) \ge \mathbb E_{q_0} \log q_0 + \mathbb E_{q_{0,1,\cdots,T}} \log \frac{p_{0,1,\cdots,T}}{q_{0,1,\cdots,T}} = L_{\rm ELBO},$$

where

$$L_{\rm ELBO} = -H[q_0] - D_{\rm KL}(q_{0,\cdots,T}|p_{0,\cdots,T}).$$

The entropy $H[q_0]$ depends only on the data distribution and contains no model parameters.

Thus, regardless of the Markov property of the processes, maximizing the likelihood lower bound $L_{\rm ELBO}$ is equivalent to minimizing the KL-divergence between the variational ansatz and the physical forward process.
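To make the inequality concrete, here is a small numerical check for a toy $T = 1$ model in which everything is Gaussian and the exact likelihood is available in closed form; the particular values of $\beta$ and of the ansatz parameters $a$, $s$ are arbitrary illustrative choices.

```python
# Numerical sanity check of E log p(x0) >= L_ELBO for a toy T = 1 model:
# q0 = N(0, 1), q_{1|0} = N(sqrt(1 - beta) x0, beta),
# ansatz p_1 = N(0, 1), p_{0|1} = N(a x1, s^2).
import numpy as np

rng = np.random.default_rng(0)
n = 200_000
beta, a, s = 0.5, 0.6, 0.9

def log_normal(x, mu, var):
    return -0.5 * (np.log(2 * np.pi * var) + (x - mu) ** 2 / var)

x0 = rng.normal(size=n)                                            # x0 ~ q0
x1 = np.sqrt(1 - beta) * x0 + np.sqrt(beta) * rng.normal(size=n)   # x1 ~ q_{1|0}

# Exact marginal: p(x0) = ∫ p_{0|1}(x0|x1) p_1(x1) dx1 = N(x0; 0, a^2 + s^2).
exact = log_normal(x0, 0.0, a**2 + s**2).mean()

# L_ELBO = E_q [ log p_1(x1) + log p_{0|1}(x0|x1) - log q_{1|0}(x1|x0) ]
elbo = (log_normal(x1, 0.0, 1.0)
        + log_normal(x0, a * x1, s**2)
        - log_normal(x1, np.sqrt(1 - beta) * x0, beta)).mean()

print(f"E log p(x0) = {exact:.4f} >= L_ELBO = {elbo:.4f}")
```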

Markov variational Ansatz

If the forward process is Markov, then we have

$$q_{0,1,\cdots,T} = q_0 \prod_{t=1}^T q_{t|t-1}.$$

Similarly, we assume the backward joint density can be expanded as

$$p_{0,1,\cdots,T} = \left( \prod_{t=1}^T p_{t-1|t} \right) p_T.$$

The posterior of the physical forward process may be represented, via Bayes' rule, as

$$q_{t-1|t} \propto q_{t|t-1}\, q_{t-1};$$

however, without knowledge of the initial state $x_0$, there are infinitely many possibilities. We therefore fix the initial state and work with the probabilities conditioned on $x_0 \sim q_0$,

$$q_{t-1|t,0} \propto q_{t|t-1,0}\, q_{t-1|0},$$

or, as an equality for $t > 1$,

$$q_{t-1|t,0}\, q_{t|0} = q_{t,t-1|0} = q_{t|t-1,0}\, q_{t-1|0}.$$

Thus, the forward-process joint density has a posterior expansion,

$$q_{0,1,\cdots,T} = q_0\, q_{1|0} \prod_{t=2}^T q_{t|t-1,0} = q_0\, q_{1|0} \prod_{t=2}^T q_{t-1|t,0}\, \frac{q_{t|0}}{q_{t-1|0}},$$

where the last factor telescopes, leaving

$$q_{0,1,\cdots,T} = q_0 \left( \prod_{t=2}^T q_{t-1|t,0} \right) q_{T|0}.$$

The ratio of the forward and backward densities can then be expanded in the following fashion,

$$\frac{q_{0,1,\cdots,T}}{p_{0,1,\cdots,T}} = \frac{q_{T|0}\, q_0}{p_T\, p_{0|1}} \prod_{t=2}^T \frac{q_{t-1|t,0}}{p_{t-1|t}},$$

whose logarithm reads

$$\log \frac{q_{0,1,\cdots,T}}{p_{0,1,\cdots,T}} = \log \frac{q_0}{p_{0|1}} + \sum_{t=2}^T \log \frac{q_{t-1|t,0}}{p_{t-1|t}} + \log \frac{q_{T|0}}{p_T}.$$

Taking expectations over $q_{0,\cdots,T}$, the total KL-divergence decomposes as

$$D_{\rm KL}(q_{0,\cdots,T}|p_{0,\cdots,T}) = \sum_{t=0}^{T} D_t,$$

where $D_0 = D_{\rm KL}(q_0|p_{0|1})$, $D_{t-1} = D_{\rm KL}(q_{t-1|t,0}|p_{t-1|t})$ for $1 < t \le T$, and $D_T = D_{\rm KL}(q_{T|0}|p_T)$. The last term $D_T$ is a constant, since the distribution $p_T$ is fixed. Adding back the entropy term $H[q_0]$, the first term becomes the usual likelihood,

$$L_0 = D_0 + H[q_0] = -\mathbb E \log p_{0|1},$$

and the total loss becomes a typical variational-inference loss: the sum of the data negative log-likelihood and a series of prior KL-divergences.

So far we have not assumed any specific distributions; the objective is based purely on the Markov property.

Gaussian diffusion

For the particular case of Gaussian diffusion models, we assume

  • the terminal distribution is normal,

$$q_{T|0} = q_T = p_T = \mathcal N({\bf x}_T; {\bf 0}, {\bf 1}),$$

  • the forward transition process is Gaussian,

$$q_{t|t-1} = \mathcal N({\bf x}_t; \sqrt{1-\beta_t}\, {\bf x}_{t-1}, \beta_t {\bf 1}),$$

where $0 < \beta_t \le 1$.

It is useful to introduce additional notation: $\alpha_t \equiv 1 - \beta_t$ and $\bar\alpha_t \equiv \prod_{s=1}^t \alpha_s$.
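These schedule quantities can be precomputed once. A minimal sketch, assuming an illustrative linear $\beta$ schedule (the endpoints and $T$ are arbitrary choices, not prescribed by the derivation); the later snippets reuse these arrays:

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)   # beta_t for t = 1..T, with 0 < beta_t <= 1
alphas = 1.0 - betas                 # alpha_t = 1 - beta_t
alpha_bars = np.cumprod(alphas)      # alpha_bar_t = prod_{s<=t} alpha_s
```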

Reparameterization

Let ${\bf z} \sim \mathcal N({\bf 0}, {\bf 1})$; the forward process can then be written as

$${\bf x}_t = \sqrt{1-\beta_t}\, {\bf x}_{t-1} + \sqrt{\beta_t}\, {\bf z} = \sqrt{\alpha_t}\, {\bf x}_{t-1} + \sqrt{\beta_t}\, {\bf z}.$$

Iteratively applying the formula, we get

$${\bf x}_t = \sqrt{\bar\alpha_t}\, {\bf x}_0 + \sqrt{1-\bar\alpha_t}\, {\bf z}.$$

Thus, we can generate ${\bf x}_t$ for any $t$ without actually doing the iterative calculation. (There is a similar property for general Markov processes, cf. the Feynman–Kac formula.)
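In code, the closed form is a one-liner. Continuing the running sketch (the helper also returns the noise ${\bf z}$, since it is needed later as the training target):

```python
def sample_xt(x0, t, alpha_bars, rng):
    """Draw x_t ~ q(x_t | x_0) = N(sqrt(a_bar_t) x0, (1 - a_bar_t) I), t 1-indexed."""
    z = rng.standard_normal(x0.shape)          # z ~ N(0, 1)
    a_bar = alpha_bars[t - 1]                  # alpha_bar_t
    return np.sqrt(a_bar) * x0 + np.sqrt(1.0 - a_bar) * z, z
```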

Posterior is Gaussian

We can compute the posterior of the physical process using

$$q_{t-1|t,0} \propto q_{t|t-1,0}\, q_{t-1|0},$$

where the RHS is a product of Gaussians. An easy but lengthy calculation shows that the posterior is indeed Gaussian,

$$q_{t-1|t,0} = \mathcal N({\bf x}_{t-1};\, \tilde\mu_t({\bf x}_t, {\bf x}_0),\, \tilde\beta_t {\bf 1}),$$

where

$$\tilde\mu_t = \frac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t-1})\, {\bf x}_t + \beta_t \sqrt{\bar\alpha_{t-1}}\, {\bf x}_0}{1-\bar\alpha_t},$$

and

$$\tilde\beta_t = \beta_t\, \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}.$$

Expressing ${\bf x}_0$ in terms of ${\bf x}_t$ and the noise, the mean simplifies to

$$\tilde\mu_t({\bf x}_t) = \alpha_t^{-\frac12} \left( {\bf x}_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\, {\bf z} \right).$$
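The posterior parameters are simple functions of the precomputed schedule. Continuing the sketch (with the convention $\bar\alpha_0 \equiv 1$):

```python
def posterior_mean_var(xt, x0, t, betas, alphas, alpha_bars):
    """Mean and variance of q(x_{t-1} | x_t, x_0), with t 1-indexed."""
    a_bar_prev = alpha_bars[t - 2] if t > 1 else 1.0       # alpha_bar_{t-1}
    a_bar, beta, alpha = alpha_bars[t - 1], betas[t - 1], alphas[t - 1]
    mean = (np.sqrt(alpha) * (1 - a_bar_prev) * xt
            + beta * np.sqrt(a_bar_prev) * x0) / (1 - a_bar)
    var = beta * (1 - a_bar_prev) / (1 - a_bar)            # tilde beta_t
    return mean, var
```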

Variational Ansatz

Since the target distribution is Gaussian, it is a good idea to choose a Gaussian distribution as the variational Ansatz,

$$p_{t-1|t} = \mathcal N({\bf x}_{t-1};\, \mu_t,\, \sigma_t^2 {\bf 1}),$$

where the model parameters are $\mu_t$ and $\sigma_t$. The variance will eventually contribute only an overall weight per step (much like a learning rate), so we treat $\sigma_t$ as a hyperparameter instead of learning it by stochastic gradient descent. The only learnable parameter is then $\mu_t = \mu_t({\bf x}_t, t)$.

Recall the objective for each time step $1 < t \le T$,

$$D_{t-1} = D_{\rm KL}(q_{t-1|t,0}|p_{t-1|t}) = \frac{1}{2\sigma_t^2} \Vert \mu_t - \tilde\mu_t \Vert^2 + {\rm const.},$$

where we used the formula for the KL-divergence between two Gaussian distributions.

Clearly, we have the exact solution

$$\frac{\delta D_{t-1}[\mu_t]}{\delta \mu_t} = 0 \;\Rightarrow\; \mu_t = \tilde\mu_t,$$

and therefore

$$\mu_t({\bf x}_t, t) = \alpha_t^{-\frac12} \left( {\bf x}_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\, {\bf z}_t \right),$$

where ${\bf z}_t \sim \mathcal N({\bf 0}, {\bf 1})$ is the noise that generated ${\bf x}_t$ from ${\bf x}_0$. Next, we reparameterize $\mu_t$ to separate out the explicit dependence on ${\bf x}_t$ and $t$ and let the model focus only on the implicit dependencies, i.e.

$$\mu_t({\bf x}_t, t) = \alpha_t^{-\frac12} \left( {\bf x}_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\, {\bf z}({\bf x}_t, t) \right).$$

Finally, the objective becomes

$$D_{t-1} = \frac{\beta_t^2}{2\alpha_t(1-\bar\alpha_t)\sigma_t^2} \Vert {\bf z}_t - {\bf z}({\bf x}_t, t) \Vert^2,$$

where ${\bf z}({\bf x}_t, t)$ is the model output.

It can also be shown that

$$L_0 = \frac{\beta_1}{2\alpha_1\sigma_1^2} \Vert {\bf z}_1 - {\bf z}({\bf x}_1, 1) \Vert^2,$$

where $\bar\alpha_1 = \alpha_1 = 1-\beta_1$. Thus, the generic loss term for $0 < t \le T$ is

$$L_{t-1} = \frac{\beta_t^2}{2\alpha_t(1-\bar\alpha_t)\sigma_t^2} \Vert {\bf z}_t - {\bf z}({\bf x}_t, t) \Vert^2.$$

Note that we still have the freedom to choose $\sigma_t$, which controls the relative importance of each step; in the literature, however, a heuristic approach is usually taken, dropping the weight factor and keeping only the plain $\ell_2$ loss.
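Continuing the sketch, the per-step loss is a (possibly weighted) squared error between the true and predicted noise; `model` is a placeholder for any network mapping $({\bf x}_t, t)$ to a noise prediction, and $\sigma_t^2 = \beta_t$ is one common illustrative choice, not the only one:

```python
def diffusion_loss(model, x0, t, betas, alphas, alpha_bars, rng, weighted=False):
    """L_{t-1}: squared error between true noise z_t and model output z(x_t, t)."""
    xt, z = sample_xt(x0, t, alpha_bars, rng)
    z_pred = model(xt, t)
    loss = np.sum((z - z_pred) ** 2)
    if weighted:                     # keep beta_t^2 / (2 alpha_t (1 - a_bar) sigma^2)
        sigma2 = betas[t - 1]        # assumed choice: sigma_t^2 = beta_t
        loss *= betas[t - 1] ** 2 / (2 * alphas[t - 1]
                                     * (1 - alpha_bars[t - 1]) * sigma2)
    return loss
```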

Sampling the backward process

During training, the model has learned the backward transition distribution

$$p_{t-1|t} = \mathcal N\!\left( {\bf x}_{t-1};\, \alpha_t^{-\frac12} \left( {\bf x}_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\, {\bf z}({\bf x}_t, t) \right),\, \sigma_t^2 {\bf 1} \right).$$

The backward iteration is essentially sampling ${\bf z} \sim \mathcal N({\bf 0}, {\bf 1})$ and calculating

$${\bf x}_{t-1} = \alpha_t^{-\frac12} \left( {\bf x}_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\, {\bf z}({\bf x}_t, t) \right) + \sigma_t {\bf z}.$$
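One reverse step, continuing the sketch (again assuming $\sigma_t^2 = \beta_t$; suppressing the noise at the final step is a common practice, not something required by the derivation above):

```python
def backward_step(model, xt, t, betas, alphas, alpha_bars, rng):
    """Sample x_{t-1} ~ p(x_{t-1} | x_t), with t 1-indexed."""
    z_pred = model(xt, t)
    mean = (xt - betas[t - 1] / np.sqrt(1 - alpha_bars[t - 1]) * z_pred) \
           / np.sqrt(alphas[t - 1])
    sigma = np.sqrt(betas[t - 1])                        # assumed sigma_t^2 = beta_t
    z = rng.standard_normal(xt.shape) if t > 1 else 0.0  # no noise on the last step
    return mean + sigma * z
```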

Training and inference algorithms

Training

  • Sample ${\bf x}_0 \sim q_0$, $t \sim {\sf Uniform}(1,\cdots,T)$, and ${\bf z}_t \sim \mathcal N({\bf 0}, {\bf 1})$.
  • Construct ${\bf x}_t = \sqrt{\bar\alpha_t}\, {\bf x}_0 + \sqrt{1-\bar\alpha_t}\, {\bf z}_t$.
  • Feed ${\bf x}_t, t$ to the model.
  • Minimize $L_{t-1}$.

A minimal training step implementing this recipe is sketched below.
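The sketch only assembles the loss; the actual parameter update is left to whatever autodiff framework `model` lives in, since plain NumPy provides no gradients:

```python
def train_step(model, x0_batch, betas, alphas, alpha_bars, rng):
    """Average L_{t-1} over a batch; minimize this w.r.t. the model parameters."""
    T = len(betas)
    loss = 0.0
    for x0 in x0_batch:
        t = int(rng.integers(1, T + 1))      # t ~ Uniform(1, ..., T)
        loss += diffusion_loss(model, x0, t, betas, alphas, alpha_bars, rng)
    return loss / len(x0_batch)
```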

Inference

  • Sample ${\bf x}_T \sim \mathcal N({\bf 0}, {\bf 1})$.
  • Loop $t = T, \cdots, 1$: sample ${\bf z} \sim \mathcal N({\bf 0}, {\bf 1})$ and compute ${\bf x}_{t-1}$ using the reconstruction formula from the section "Sampling the backward process."
  • Return ${\bf x}_0$.

A matching sampling loop is sketched below.
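Under the same assumptions as the previous snippets:

```python
def sample(model, shape, betas, alphas, alpha_bars, rng):
    """Generate x_0 by iterating the backward transition from t = T down to 1."""
    xt = rng.standard_normal(shape)          # x_T ~ N(0, 1)
    for t in range(len(betas), 0, -1):       # t = T, ..., 1
        xt = backward_step(model, xt, t, betas, alphas, alpha_bars, rng)
    return xt                                # x_0
```

For instance, `sample(lambda x, t: np.zeros_like(x), (2,), betas, alphas, alpha_bars, np.random.default_rng(0))` runs the loop end to end with a dummy zero-noise model.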
