Finite Step Unconditional Markov Diffusion Models

Let $p(\cdot\mid\theta)$ be the parametric model of the data $x_0 \sim q_0$. We can optimize $\theta$ by maximizing the likelihood $p(x_0\mid\theta)$, or equivalently

$$\max_\theta \mathbb E_{x_0 \sim q_0} \log p(x_0\mid\theta).$$

From now on, we use $q$'s to denote the forward physical distributions and $p$'s the backward variational ansatz. The parameters are implied: $p(\cdot\mid\theta) \equiv p(\cdot)$.

Let $T > 1$ be the number of diffusion steps. If the process is Markov, the joint density of the forward process $q(x_0, x_1, \cdots, x_T)$ can be expanded sequentially:

$$q(x_0, x_1, \cdots, x_T) = q_0(x_0) \prod_{t=1}^T q_{t|t-1}(x_t|x_{t-1}).$$

The reverse process variational ansatz can be similarly constructed

$$p(x_0,x_1, \cdots, x_T) = p_T(x_T) \prod_{t=1}^T p_{t-1|t}(x_{t-1}|x_{t}),$$

which can be interpreted as a series of consecutive priors for the physical observation $x_0$.

Marginalizing the latent variables $x_{1,\cdots,T}$ yields the objective function, the likelihood of the observable:

$$p(x_0) = \int \mathcal D x_{1,\cdots,T} \; p(x_0,x_1, \cdots, x_T)\equiv\int_{1,\cdots,T} p_{0,1,\cdots,T}.$$

MLE and variational approach

We will derive the lower bound of the maximum likelihood objective $\mathbb E_{x_0 \sim q_0} \log p(x_0)$ and then show that it is related to a KL-divergence up to a constant.

Proof. The original approach would expand the $p$ joint density

$$p(x_0) = \int_{1,\cdots,T} \; p_T(x_T) \prod_{t=1}^T p_{t-1|t}(x_{t-1}|x_{t}),$$

assuming the Markov property of the backward process. However, even without the Markov assumption, we can insert an identity and get

$$p(x_0) = \int_{1,\cdots,T} \; \frac{p_{0,1,\cdots,T}}{q_{0,1,\cdots,T}}\,q_{0,1,\cdots,T}.$$

Note that $q_{0,1,\cdots,T} = q_{1,\cdots,T|0}\, q_0$, so we can reinterpret the integral as an expectation:

$$p(x_0) = q_0\,\mathbb E_{q_{1,\cdots,T|0}} \frac{p_{0,1,\cdots,T}}{q_{0,1,\cdots,T}}.$$

The log-likelihood averaged over the data distribution is

$$\mathbb E_{q_0} \log p(x_0) = \mathbb E_{q_0} \log q_0 +\mathbb E_{q_0} \log \mathbb E_{q_{1,\cdots,T|0}} \frac{p_{0,1,\cdots,T}}{q_{0,1,\cdots,T}}.$$

Using the concavity of the logarithm (Jensen's inequality),

$$\mathbb E_{q_0} \log p(x_0) \ge \mathbb E_{q_0} \log q_0 + \mathbb E_{q_{0,1,\cdots,T}} \log \frac{p_{0,1,\cdots,T}}{q_{0,1,\cdots,T}} = L_{\rm ELBO},$$

where, using $\mathbb E_{q_0} \log q_0 = -H[q_0]$ and $\mathbb E_{q} \log (p/q) = -D_{\rm KL}(q\Vert p)$,

$$L_{\rm ELBO} = - H[q_0] - D_{\rm KL} (q_{0,\cdots,T}\Vert p_{0,\cdots,T}).$$

The entropy $H[q_0]$ depends only on the data distribution and contains no model parameters. Thus, maximizing the log-likelihood lower bound $L_{\rm ELBO}$ is equivalent to minimizing the KL-divergence between the forward-process and backward-process joint densities.

Regardless of the Markov property of the processes, maximizing the likelihood lower bound is equivalent to minimizing the KL-divergence between the variational ansatz and the physical forward process.

Markov variational Ansatz

If the forward process is Markov, then we have

$$q_{0,1,\cdots,T} = q_0 \prod_{t=1}^T q_{t|t-1}.$$

Similarly, we assume the backward joint density can be expanded as

$$p_{0,1,\cdots,T} = \left(\prod_{t=1}^T p_{t-1|t} \right) p_T.$$

The posterior of the physical forward process may be represented as

$$q_{t-1|t} \propto q_{t|t-1}\, q_{t-1};$$

however, without knowledge of the initial state $x_0$, there are infinitely many possibilities. Therefore, we fix the initial state and condition the probabilities on a fixed $x_0 \sim q_0$:

$$q_{t-1|t, 0} \propto q_{t|t-1, 0}\, q_{t-1 | 0},$$

or, for $t>1$, the equality

$$q_{t-1|t, 0}\, q_{t|0} = q_{t, t-1|0} = q_{t|t-1, 0}\, q_{t-1 | 0}.$$

Thus, the forward-process joint density has a posterior expansion

$$q_{0,1,\cdots,T} = q_0\, q_{1|0}\prod_{t=2}^T q_{t|t-1,0} = q_0\, q_{1|0}\prod_{t=2}^T q_{t-1|t,0}\frac{q_{t|0}}{q_{t-1|0}},$$

where the last factor telescopes, $\prod_{t=2}^T q_{t|0}/q_{t-1|0} = q_{T|0}/q_{1|0}$, cancelling the $q_{1|0}$:

$$q_{0,1,\cdots,T} = q_0 \left( \prod_{t=2}^T q_{t-1|t,0} \right) q_{T|0}.$$

The ratio of the forward and backward densities can then be expanded as

$$\frac{q_{0,1,\cdots,T}}{p_{0,1,\cdots,T}} =\frac{q_{T|0}\,q_0}{p_T\,p_{0|1}}\prod_{t=2}^T\frac{q_{t-1|t,0}}{p_{t-1|t}},$$

whose logarithm reads

$$\log\frac{q_{0,1,\cdots,T}}{p_{0,1,\cdots,T}} =\log\frac{q_0}{p_{0|1}} + \sum_{t=2}^T\log\frac{q_{t-1|t,0}}{p_{t-1|t}} +\log\frac{q_{T|0}}{p_T}.$$

Using the posterior expansion of $q$, the total KL-divergence decomposes as

$$D_{\rm KL} (q_{0,\cdots,T}\Vert p_{0,\cdots,T}) \equiv \sum_{t=0}^{T} D_{t},$$

where

  • $D_0 = D_{\rm KL}(q_0\Vert p_{0|1})$,

  • $D_{t-1} = D_{\rm KL}(q_{t-1|t,0}\Vert p_{t-1|t})$ for $1<t\le T$, and

  • $D_T = D_{\rm KL} (q_{T|0}\Vert p_T)$.
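To spell out the identification (expectations over the remaining variables are left implicit in the $D$ notation above): for each middle term, integrating out ${\bf x}_{t-1}$ first at fixed $(x_0, x_t)$ gives

$$\mathbb E_{q_{0,\cdots,T}} \log\frac{q_{t-1|t,0}}{p_{t-1|t}} = \mathbb E_{q_{0}}\,\mathbb E_{q_{t|0}}\, D_{\rm KL}\left(q_{t-1|t,0}\,\Vert\, p_{t-1|t}\right),$$

and analogously for the boundary terms.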

The last term $D_T$ is a constant for a fixed distribution $p_T$. If we add back the entropy term $H[q_0]$, the first term becomes the usual likelihood

$$L_0 = D_0 + H[q_0] = -\log p_{0|1},$$

and the total loss becomes a typical variational inference loss: the sum of the data negative log-likelihood and a series of prior KL-divergences.

So far we have not assumed any specific distribution. The objective is based purely on the Markov property.

Gaussian diffusion

For the particular case of Gaussian diffusion models, we assume

  • the terminal distribution is standard normal,

    $$q_{T|0} = q_T = p_T = \mathcal N({\bf x}_T; {\bf 0}, {\bf 1}),$$

  • the forward transition process is Gaussian,

    $$q_{t|t-1} = \mathcal N({\bf x}_t; \sqrt{1-\beta_t}\,{\bf x}_{t-1},\, \beta_t {\bf 1}),$$

    where $0<\beta_t \le 1$.

It is useful to introduce additional notation:

  • $\alpha_t \equiv 1 - \beta_t$, and

  • $\bar \alpha_t \equiv \prod_{s=1}^t \alpha_s$.
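For concreteness, here is a minimal sketch of how these quantities are typically precomputed, in Python/NumPy. The linear schedule and its endpoints ($10^{-4}$ to $0.02$) are assumptions borrowed from common practice, not something fixed by the derivation:

```python
import numpy as np

T = 1000  # number of diffusion steps (an assumed hyperparameter)

# A linear beta schedule; the endpoints are a common convention.
betas = np.linspace(1e-4, 0.02, T)   # betas[t-1] stores beta_t
alphas = 1.0 - betas                 # alpha_t = 1 - beta_t
alpha_bars = np.cumprod(alphas)      # alpha_bar_t = prod_{s<=t} alpha_s
```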

Reparameterization

Let ${\bf z}\sim \mathcal N({\bf 0}, {\bf 1})$; the forward process can be written as

$${\bf x}_t = \sqrt{1-\beta_t}\, {\bf x}_{t-1} + \sqrt{\beta_t}\, {\bf z} =\sqrt{\alpha_t}\, {\bf x}_{t-1} + \sqrt{\beta_t}\, {\bf z}.$$

Applying the formula iteratively, we get

$${\bf x}_t = \sqrt{\bar \alpha_t}\, {\bf x}_0 + \sqrt{1-\bar \alpha_t}\, {\bf z}.$$

Thus, we can generate ${\bf x}_t$ for any $t$ without actually doing the iterative calculation. A similar property holds for general Markov processes, cf. the Feynman–Kac formula.
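As a sketch (reusing the arrays from the previous snippet, with index $t-1$ holding the step-$t$ quantity), the closed form turns forward sampling into a single line:

```python
def q_sample(x0, t, z):
    """Sample x_t ~ q(x_t | x_0) in one shot:
    x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * z."""
    ab = alpha_bars[t - 1]
    return np.sqrt(ab) * x0 + np.sqrt(1.0 - ab) * z

# Example usage with a toy x0:
# x0 = np.random.randn(32, 32)
# z = np.random.randn(*x0.shape)
# xt = q_sample(x0, t=500, z=z)
```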

Posterior is Gaussian

We can compute the posterior of the physical process using

$$q_{t-1|t, 0} \propto q_{t|t-1, 0}\, q_{t-1 | 0},$$

where the RHS is a product of Gaussians. One can show that the posterior is indeed Gaussian after a straightforward but lengthy calculation:

$$q_{t-1|t, 0} = \mathcal N({\bf x}_{t-1}; \tilde \mu_t({\bf x}_t, {\bf x}_0), \tilde \beta_t {\bf 1}),$$

where

$$\tilde \mu_t = \frac{\sqrt{\alpha_t}\, (1 - \bar \alpha_{t-1})\,{\bf x}_t + \beta_t \sqrt{\bar \alpha_{t-1}}\,{\bf x}_0}{1-\bar\alpha_t},$$

and

$$\tilde \beta_t = \beta_t\, \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}.$$

Expressing ${\bf x}_0$ in terms of ${\bf x}_t$ and the noise, ${\bf x}_0 = \bar\alpha_t^{-1/2}\left({\bf x}_t - \sqrt{1-\bar\alpha_t}\,{\bf z}\right)$, the mean simplifies to

$$\tilde \mu_t ({\bf x}_t) = \alpha_t^{-\frac12 } \left( {\bf x}_t - \frac{\beta_t}{\sqrt{1-\bar \alpha_t}}\, {\bf z} \right).$$
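In code, the posterior parameters $\tilde\mu_t({\bf x}_t, {\bf x}_0)$ and $\tilde\beta_t$ read as follows; a minimal sketch under the same conventions as before, with the extra assumption $\bar\alpha_0 = 1$ so the $t=1$ edge case degenerates gracefully:

```python
def q_posterior(x0, xt, t):
    """Mean and variance of the Gaussian posterior q(x_{t-1} | x_t, x_0)."""
    beta_t, alpha_t = betas[t - 1], alphas[t - 1]
    ab_t = alpha_bars[t - 1]
    ab_prev = alpha_bars[t - 2] if t > 1 else 1.0  # convention: alpha_bar_0 = 1
    mean = (np.sqrt(alpha_t) * (1.0 - ab_prev) * xt
            + beta_t * np.sqrt(ab_prev) * x0) / (1.0 - ab_t)
    var = beta_t * (1.0 - ab_prev) / (1.0 - ab_t)  # beta_t tilde
    return mean, var
```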

Variational Ansatz

Since the target distribution is Gaussian, it is a good idea to choose a Gaussian distribution as the variational ansatz

$$p_{t-1|t} = \mathcal N ({\bf x}_{t-1}; \mu_t, \sigma^2_t{\bf 1} ),$$

where the model parameters are $\mu_t$ and $\sigma_t$. The variance will eventually contribute to the learning rate, so we treat $\sigma_t$ as a hyperparameter instead of learning it by stochastic gradient descent. The only learnable parameter is then $\mu_t=\mu_t({\bf x}_t, t)$.

Recall the objective for each time step $1<t\le T$:

$$D_{t-1} = D_{\rm KL}(q_{t-1|t,0}\Vert p_{t-1|t}) = \frac{1}{2\sigma^2_t} \Vert \mu_t - \tilde \mu_t \Vert^2 + {\rm const.},$$

where we used the formula for the KL-divergence between two Gaussian distributions.
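For reference, the formula used here: for two isotropic Gaussians in $d$ dimensions,

$$D_{\rm KL}\left(\mathcal N(\tilde\mu_t, \tilde\beta_t {\bf 1})\,\Vert\,\mathcal N(\mu_t, \sigma_t^2 {\bf 1})\right) = \frac{1}{2\sigma_t^2}\Vert \mu_t - \tilde\mu_t\Vert^2 + \frac{d}{2}\left(\frac{\tilde\beta_t}{\sigma_t^2} - 1 - \log\frac{\tilde\beta_t}{\sigma_t^2}\right),$$

where the second term does not depend on $\mu_t$ and is absorbed into the constant.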

Clearly, we have the exact solution

$$\frac{\delta D_{t-1}[\mu_t]}{\delta \mu_t} = 0 \Rightarrow \mu_t = \tilde \mu_t,$$

and therefore,

$$\mu_t ({\bf x}_t, t) = \alpha_t^{-\frac12 } \left( {\bf x}_t - \frac{\beta_t}{\sqrt{1-\bar \alpha_t}}\, {\bf z}_t \right),$$

where ${\bf z}_t\sim\mathcal N({\bf 0}, {\bf 1})$ is the noise that generated ${\bf x}_t$ from ${\bf x}_0$.

Next, we reparameterize $\mu_t$ to separate the explicit dependence on ${\bf x}_t$ and $t$, letting the model focus only on the implicit dependencies, i.e.

$$\mu_t ({\bf x}_t, t) = \alpha_t^{-\frac12 } \left( {\bf x}_t - \frac{\beta_t}{\sqrt{1-\bar \alpha_t}}\, {\bf z}({\bf x}_t, t) \right).$$

Finally, the objective becomes

$$D_{t-1} = \frac{\beta_t^2}{2\alpha_t(1-\bar\alpha_t)\sigma_t^2} \Vert {\bf z}_t -{\bf z}({\bf x}_t,t)\Vert^2,$$

where ${\bf z}({\bf x}_t, t)$ is the model output.

It can also be shown that

$$L_0 = \frac{\beta_1}{2\alpha_1\sigma_1^2} \Vert {\bf z}_1 -{\bf z}({\bf x}_1,1)\Vert^2,$$

where $\bar \alpha_1 = \alpha_1 = 1 -\beta_1$.

Thus, the generic loss term for $0 < t \le T$ is

$$L_{t-1} = \frac{\beta_t^2}{2\alpha_t(1-\bar\alpha_t)\sigma_t^2} \Vert {\bf z}_t -{\bf z}({\bf x}_t,t)\Vert^2.$$

Note that we still have the freedom to choose $\sigma_t$, which controls the importance of each step. In the literature, however, a heuristic approach is usually taken: the weight factor is ignored and only the $\ell_2$ loss is kept.

Sampling the backward process

During training, the model learns the backward transition distribution

$$p_{t-1|t} = \mathcal N \left ({\bf x}_{t-1}; \alpha_t^{-\frac12 } \left( {\bf x}_t - \frac{\beta_t}{\sqrt{1-\bar \alpha_t}}\, {\bf z}({\bf x}_t, t) \right), \sigma^2_t{\bf 1} \right).$$

The backward iteration then essentially amounts to sampling and computing

$${\bf x}_{t-1} = \alpha_t^{-\frac12 } \left( {\bf x}_t - \frac{\beta_t}{\sqrt{1-\bar \alpha_t}}\, {\bf z}({\bf x}_t, t) \right) + \sigma_t {\bf z},$$

where ${\bf z} \sim \mathcal N({\bf 0}, {\bf 1})$.

Training and inference algorithms

Training

  • Sample

    • ${\bf x}_0 \sim q_0$

    • $t \sim {\sf Uniform}(1,\cdots,T)$

    • ${\bf z}_t \sim \mathcal N({\bf 0}, {\bf 1})$

  • Construct ${\bf x}_t$ via the closed form ${\bf x}_t = \sqrt{\bar\alpha_t}\,{\bf x}_0 + \sqrt{1-\bar\alpha_t}\,{\bf z}_t$

  • Feed ${\bf x}_t, t$ to the model

  • Minimize $L_{t-1}$ (a code sketch follows this list)
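A minimal training-step sketch in the same Python/NumPy setting (`model` is a placeholder noise predictor ${\bf z}({\bf x}_t, t)$; the simplified, unweighted $\ell_2$ loss discussed above is used):

```python
def training_loss(model, x0, rng=np.random):
    """One Monte Carlo sample of the simplified diffusion loss."""
    t = rng.randint(1, T + 1)          # t ~ Uniform(1, ..., T)
    z = rng.randn(*x0.shape)           # z_t ~ N(0, 1)
    xt = q_sample(x0, t, z)            # construct x_t in closed form
    z_pred = model(xt, t)              # model's noise prediction z(x_t, t)
    return np.mean((z - z_pred) ** 2)  # unweighted ell_2 objective
```

Minimizing this over the model parameters with any stochastic-gradient optimizer implements the loop above; the weight $\beta_t^2 / (2\alpha_t(1-\bar\alpha_t)\sigma_t^2)$ can be reinstated to recover the exact $L_{t-1}$.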

Inference

  • Sample ${\bf x}_T \sim \mathcal N({\bf 0}, {\bf 1})$

  • Loop $t = T,\cdots, 1$

    • Sample ${\bf z} \sim \mathcal N({\bf 0}, {\bf 1})$

    • Compute ${\bf x}_{t-1}$

  • Return ${\bf x}_0$

The reconstruction formula is given in the previous section, "Sampling the backward process."
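A matching sketch of the inference loop, under the same conventions. Taking $\sigma_t^2 = \tilde\beta_t$ is one common choice (another is $\sigma_t^2 = \beta_t$), and adding no noise at the final step is likewise a standard convention rather than something fixed by the text:

```python
def p_sample_loop(model, shape, rng=np.random):
    """Generate x_0 by iterating the learned backward transition."""
    x = rng.randn(*shape)                    # x_T ~ N(0, 1)
    for t in range(T, 0, -1):
        beta_t, alpha_t = betas[t - 1], alphas[t - 1]
        ab_t = alpha_bars[t - 1]
        z_pred = model(x, t)                 # noise prediction z(x_t, t)
        mean = (x - beta_t / np.sqrt(1.0 - ab_t) * z_pred) / np.sqrt(alpha_t)
        if t > 1:
            ab_prev = alpha_bars[t - 2]
            sigma = np.sqrt(beta_t * (1.0 - ab_prev) / (1.0 - ab_t))  # sigma_t^2 = beta_t tilde
            x = mean + sigma * rng.randn(*shape)
        else:
            x = mean                         # no noise at the final step
    return x
```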
