当前位置:首页|资讯|AIGC

AIGC: DDIM (Denoising Diffusion Implicit Models) 笔记

作者:刹那-Ksana-发布时间:2023-07-19

TL;DR: 去噪扩散隐式模型 (DDIM) 是利用非马尔可夫的思想,以牺牲一小部分图片质量为代价,对图像生成过程大幅度加速的采样方法。

这个话题太过复杂,如内容有错误,还请在评论里面指正。

本人数学不好,尽量绕开复杂的公式(?) 

大局观

首先,有一个问题必须要回答——为什么 DDPM 要基于马尔可夫链,马尔可夫链到底起一个什么样的作用。

在这里,以一个小白的视角来理一下DDPM的大致过程:

q(%5Cmathbf%7Bx%7D_t%7C%5Cmathbf%7Bx%7D_%7Bt-1%7D)%20%3A%3D%20%20%5Cmathcal%7BN%7D(%5Cmathbf%7Bx%7D_t%3B%5Csqrt%7B1-%5Cbeta_t%7D%5Cmathbf%7Bx%7D_%7Bt-1%7D%2C%5Cbeta_t%20%5Cmathbf%7BI%7D)

在多次加噪之后,数据最终将会变成高斯分布。

p_%5Ctheta(%5Ctextbf%7Bx%7D_%7Bt-1%7D%7C%5Ctextbf%7Bx%7D_%7Bt%7D)%20%3A%3D%20%5Cmathcal%7BN%7D(%5Cmathbf%7Bx%7D_%7Bt-1%7D%3B%20%5Cmu_%5Ctheta(%5Cmathbf%7Bx%7D_t%2C%20t)%2C%20%5CSigma_%5Ctheta(%5Cmathbf%7Bx%7D_t%2C%20t)) 去贴合去噪的分布。

p_%5Ctheta (事实上我们也没有去学习 p_%5Ctheta). 我们的最终目的,是去模拟出原始数据的分布 p(%5Ctextbf%7Bx%7D_0)

%5Cmathbb%7BE%7D%5Bp_%5Ctheta%20(%5Ctextbf%7Bx%7D_0)%5D (很明显,如果我们能够直接极大化似然的话,就不用在这里费这么大力气了😂), DDPM 的优化目标是最大化其变分下界(Evidence Lower Bound)%5Cmathbb%7BE%7D_q%5B%5Clog%5Cfrac%7Bp_%7B%5Ctheta%7D(x_0%2C%20x_%7B1%3AT%7D)%7D%7Bq(x_%7B1%3AT%7D%7Cx_%7B0%7D)%7D%5D. (这里 x_%7B1%3AT%7D 的意思是 x_1%2C%20x_2%2C...%2Cx_T

我们在 log 前面加一个负号,把最大化目标变成最小化目标,于是就得到了 DDPM 的优化目标:

L%3D%5Cmathbb%7BE%7D_q%5B-%20%5Clog%5Cfrac%7Bp_%7B%5Ctheta%7D(x_%7B0%3AT%7D)%7D%7Bq(x_%7B1%3AT%7D%7Cx_%7B0%7D)%7D%5D

q(x_%7B1%3AT%7D%7Cx_%7B0%7D). 很明显,这是一个依赖于 %5Ctextbf%7Bx%7D_0 的联合分布(Joint Distribution). 对于联合分布,我们初高中学过,有链式法则

P(x_%7B1%3A3%7D)%3DP(x_3%7Cx_2%2Cx_1)P(x_2%7Cx_1)P(x_1)

而在马尔可夫链的情况下,上面这个公式将变成

P(x_%7B1%3A3%7D)%3DP(x_3%7Cx_2)P(x_2%7Cx_1)P(x_1)

所以,在马尔可夫链的前提下,  q(x_%7B1%3AT%7D%7Cx_%7B0%7D) 这个联合分布可以被写成如下的乘积形式(如果外面带 log 的话就相当于加和形式)

q(x_%7B1%3AT%7D%7Cx_%7B0%7D)%3D%5Cprod%5Cnolimits_%7Bt%5Cgeq1%7D%20q(x_t%7Cx_%7Bt-1%7D)

p_%7B%5Ctheta%7D(x_%7B0%3AT%7D) 也是差不多的道理,这里就不去花力气解释了。总之,在马尔可夫链的前提下,上面的优化目标可以进一步地写下去:

%5Cbegin%7Balign*%7D%0A%26%20%5Cmathbb%7BE%7D_q%20%5B%20-%20%5Clog%20%5Cfrac%7Bp_%5Ctheta(%5Ctextbf%7Bx%7D_%7B0%3AT%7D)%7D%7Bq(%5Ctextbf%7Bx%7D_%7B1%3AT%7D%20%7C%20%5Ctextbf%7Bx%7D_0)%7D%20%5D%20%5C%5C%0A%26%3D%5Cmathbb%7BE%7D_%7Bq%7D%5B%20-%5Clog%20p(%5Cmathbf%7Bx%7D_T)%20-%20%5Csum_%7Bt%20%5Cgeq%201%7D%20%5Clog%20%5Cfrac%7Bp_%5Ctheta(%5Cmathbf%7Bx%7D_%7Bt-1%7D%20%7C%20%5Cmathbf%7Bx%7D_t)%7D%7Bq(%5Cmathbf%7Bx%7D_t%7C%5Cmathbf%7Bx%7D_%7Bt-1%7D)%7D%20%5D%0A%5Cend%7Balign*%7D

%5Clog%5Cfrac%7Bp_%5Ctheta(%5Cmathbf%7Bx%7D_0%7C%5Cmathbf%7Bx%7D_1)%7D%7Bq(%5Cmathbf%7Bx%7D_1%7C%5Cmathbf%7Bx%7D_0)%7D 这一项从 %5Csum_%7Bt%20%5Cgeq%201%7D%20%5Clog%20%5Cfrac%7Bp_%5Ctheta(%5Cmathbf%7Bx%7D_%7Bt-1%7D%20%7C%20%5Cmathbf%7Bx%7D_t)%7D%7Bq(%5Cmathbf%7Bx%7D_t%7C%5Cmathbf%7Bx%7D_%7Bt-1%7D)%7D 中拆分出来。然后我们再将 %5Csum_%7Bt%20%5Cgt%201%7D%20%5Clog%20%5Cfrac%7Bp_%5Ctheta(%5Cmathbf%7Bx%7D_%7Bt-1%7D%20%7C%20%5Cmathbf%7Bx%7D_t)%7D%7Bq(%5Cmathbf%7Bx%7D_t%7C%5Cmathbf%7Bx%7D_%7Bt-1%7D)%7D 改写一下,变成 %5Csum_%7Bt%20%3E%201%7D%20%5Clog%20%5Cfrac%7Bp_%5Ctheta(%5Cmathbf%7Bx%7D_%7Bt-1%7D%20%7C%20%5Cmathbf%7Bx%7D_t)%7D%7Bq(%5Cmathbf%7Bx%7D_%7Bt-1%7D%7C%5Cmathbf%7Bx%7D_t%2C%5Cmathbf%7Bx%7D_0)%7D%5Ccdot%5Cfrac%7Bq(%5Cmathbf%7Bx%7D_%7Bt-1%7D%7C%5Cmathbf%7Bx%7D_0)%7D%7Bq(%5Cmathbf%7Bx%7D_t%7C%5Cmathbf%7Bx%7D_0)%7D 的形式(后面一项可以被相互抵消掉),然后我们就得到了目标的一个新的形式:

%5Cmathbb%7BE%7D_%7Bq%7D%5B%20%5Cmathbf%7BD%7D_%7Bkl%7D(q(%5Cmathbf%7Bx%7D_T%7C%5Cmathbf%7Bx%7D_0)%7C%7Cp(%5Cmathbf%7Bx%7D_T))%20%2B%20%5Csum_%7Bt%20%3E%201%7D%20%5Cmathbf%7BD%7D_%7Bkl%7D(q(%5Cmathbf%7Bx%7D_%7Bt-1%7D%7C%5Cmathbf%7Bx%7D_t%2C%5Cmathbf%7Bx%7D_0)%7C%7C%7Bp_%5Ctheta(%5Cmathbf%7Bx%7D_%7Bt-1%7D%20%7C%20%5Cmathbf%7Bx%7D_t)%7D)%20-%20%5Clog%20p_%5Ctheta(%5Cmathbf%7Bx%7D_0%20%7C%20%5Cmathbf%7Bx%7D_1)%20%5D

至此,我们利用了马尔可夫链的性质对公式做了变形,如果我们把高斯分布这一个条件也加上去的话,我们可以对上面的目标做进一步的推导(这里就省略过程了,如果想知道过程,可以参考之前推荐的DDPM的文章,或者DDIM原论文附录C,链接在文章末尾),得到如下的最终形式:

L_%5Cgamma%3D%5Csum_%7Bt%3D1%7D%5E%7BT%7D%5Cgamma_t%20%5Cmathbb%7BE%7D_%7Bx%5Csim%20q(x_t%7Cx_0)%7D%5B%7C%7C%20%5Cepsilon_t-%5Cepsilon_%7B%5Ctheta%7D(x_t%2Ct)%20%7C%7C_2%5E2%5D%2C%20%5Cepsilon_t%20%5Csim%20%5Cmathcal%7BN%7D(0%2CI)

L,依赖于 q(x_t%7Cx_0) 而不是联合分布 q(x_%7B1%3AT%7D%7Cx_0). 那么什么是 q(x_t%7Cx_0) 呢?首先,这是个加噪/前向过程,并且根据定义,有

q(x_t%7Cx_0)%3A%3D%5Cint%20q(x_%7B1%3AT%7D%7Cx_0)dx_%7B1%3A(t-1)%7D

并且,根据马尔可夫链和高斯分布的两个大前提,我们还知道,

q(x_t%7Cx_0)%3A%3D%5Cint%20q(x_%7B1%3AT%7D%7Cx_0)dx_%7B1%3A(t-1)%7D%20%3D%20%5Cmathcal%7BN%7D(x_t%3B%5Csqrt%7B%5Calpha_t%7Dx_0%2C(1-%5Calpha_t)I)

呃,所以扯了大半天,这家伙在一本正经的胡八道什么?

q(x_t%7Cx_0)%3A%3D%5Cint%20q(x_%7B1%3AT%7D%7Cx_0)dx_%7B1%3A(t-1)%7D 到公式的“%5Cmathcal%7BN%7D(x_t%3B%5Csqrt%7B%5Calpha_t%7Dx_0%2C(1-%5Calpha_t)I),我们是通过了马尔可夫链推导出来的。但是实际上,我们未必一定要通过马尔可夫链去"求解"。有没有一种方法,基于非马尔可夫链的方式,也能求得这个"解"呢?

非马尔可夫过程

这里,直接照搬论文给出的标准答案了。如果我们的概率分布 q 满足以下的条件:

q_%5Csigma(%5Cmathbf%7Bx%7D_%7B1%3AT%7D%20%7C%20%5Cmathbf%7Bx%7D_0)%20%3A%3D%20q_%5Csigma(%5Cmathbf%7Bx%7D_T%20%7C%20%5Cmathbf%7Bx%7D_0)%20%5Cprod_%7Bt%3D2%7D%5E%7BT%7D%20q_%5Csigma(%5Cmathbf%7Bx%7D_%7Bt-1%7D%20%7C%20%5Cmathbf%7Bx%7D_%7Bt%7D%2C%20%5Cmathbf%7Bx%7D_0)%2C%20%5Csigma%20%5Cin%20%5Cmathbb%7BR%7D_%7B%5Cgeq%200%7D%5E%7BT%7D

q_%5Csigma(%5Cmathbf%7Bx%7D_%7BT%7D%20%7C%20%5Cmathbf%7Bx%7D_0)%20%3D%20%5Cmathcal%7BN%7D(%5Csqrt%7B%5Calpha_T%7D%20%5Cmathbf%7Bx%7D_0%2C%20(1%20-%20%5Calpha_T)%20%5Cmathbf%7BI%7D)

%5Ccolor%7Bpurple%7D%20%7Bq_%5Csigma(%5Cmathbf%7Bx%7D_%7Bt-1%7D%20%7C%20%5Cmathbf%7Bx%7D_t%2C%20%5Cmathbf%7Bx%7D_0)%20%3D%20%5Cmathcal%7BN%7D(%5Csqrt%7B%5Calpha_%7Bt-1%7D%7D%20%5Cmathbf%7Bx%7D_%7B0%7D%20%2B%20%5Csqrt%7B1%20-%20%5Calpha_%7Bt-1%7D%20-%20%5Csigma%5E2_t%7D%20%5Ccdot%20%7B%5Cfrac%7B%5Cmathbf%7Bx%7D_%7Bt%7D%20%20-%20%5Csqrt%7B%5Calpha_%7Bt%7D%7D%20%5Cmathbf%7Bx%7D_0%7D%7B%5Csqrt%7B1%20-%20%5Calpha_%7Bt%7D%7D%7D%7D%2C%20%5Csigma_t%5E2%20%5Cmathbf%7BI%7D)%7D

q_%5Csigma(%5Cmathbf%7Bx%7D_%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D_0)%20%3D%20%5Cmathcal%7BN%7D(%5Csqrt%7B%5Calpha_t%7D%20%5Cmathbf%7Bx%7D_0%2C%20(1%20-%20%5Calpha_t)%20%5Cmathbf%7BI%7D)

上面的第三个公式非常重要,至于怎么来的,论文没有给出过程和说明。网上有许多大神们针对这一步写了不少文章,感兴趣的可以去看(链接在文章末尾)。

然后我们利用贝叶斯理论(Bayes' Rule)获得非马尔科夫下的前向过程:

q_%5Csigma(%5Cmathbf%7Bx%7D_%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D_%7Bt-1%7D%2C%20%5Cmathbf%7Bx%7D_0)%20%3D%20%5Cfrac%7Bq_%5Csigma(%5Cmathbf%7Bx%7D_%7Bt-1%7D%20%7C%20%5Cmathbf%7Bx%7D_%7Bt%7D%2C%20%5Cmathbf%7Bx%7D_0)%20q_%5Csigma(%5Cmathbf%7Bx%7D_%7Bt%7D%20%7C%20%5Cmathbf%7Bx%7D_0)%7D%7Bq_%5Csigma(%5Cmathbf%7Bx%7D_%7Bt-1%7D%20%7C%20%5Cmathbf%7Bx%7D_0)%7D

q(x_t%7Cx_%7Bt-1%7D) 不再已知。

q_%5Csigma(%5Cmathbf%7Bx%7D_%7Bt-1%7D%20%7C%20%5Cmathbf%7Bx%7D_t%2C%20%5Cmathbf%7Bx%7D_0)%20,和前向一样,同时依赖于 x_tx_0. 当然遵循 DDPM 的思路,我们可以将这个概率分布改写一下

%5Cbegin%7Balign%7D%0A%26%20x_%7Bt-1%7D%3D%5Csqrt%7B%5Calpha_%7Bt-1%7D%7Dx_0%2B%5Csqrt%7B1-%5Calpha_%7Bt-1%7D-%5Csigma_t%5E2%7D%20%5Ccdot%20%5Cepsilon_%5Ctheta%5E%7B(t)%7D(x_t)%2B%5Csigma_t%20%5Cepsilon_t%20%5C%5C%0A%26%20x_0%20%3D%20(x_t-%5Csqrt%7B1-%5Calpha_t%7D%20%5Ccdot%20%5Cepsilon_%5Ctheta%5E%7B(t)%7D(x_t))%2F%5Csqrt%7B%5Calpha_t%7D%0A%5Cend%7Balign%7D

这里注意 α 和 σ 长得特别像,不要搞错了 (lll¬ω¬)。另外,在这里,我们依旧是老老实实地在一步步地进行采样,还没有涉及到任何加速采样的内容。

借用 DDPM 模型

BANG! 又是一个重磅炸弹——DDIM 不需要再训练一个单独的模型,直接利用已经训练好的 DDPM 模型就可以进行采样。

%5Csigma%3E0, 都存在 %5Cgamma%5Cin%5Cmathbb%7BR%7D%5ET_%7B%5Cgt%200%7D, 和一个常数 C%5Cin%20%5Cmathbb%7BR%7D, 使得等式 J_%7B%5Csigma%7D%3DL_%7B%5Cgamma%7D%2BC 成立。(这是什么天书)

这里对于论证过程不做叙述,但是有必要解释一下上面的这个结论是什么意思。

J_%5Csigma,或者完整地说,J_%5Csigma%20(%5Cepsilon_%5Ctheta), 是我们 DDIM 非马尔可夫过程下的目标,而 L_%7B%5Cgamma%7D 是我们利用 DDPM 推导出来的目标(上文中有完整的公式)。这里主要想说的一点,就是因为两者目标相等,所以 DDIM 可以借用其对应的参数的 DDPM 的模型。

%5Cepsilon_%5Ctheta, 在不同的时间 t, 权重是不共享的(比如,不同的时间点 t,我们都用一个不同的神经网络)。那么最小化目标 L_%5Cgamma 就变成了独立地最小化 %5Csum_%7Bt%3D1%7D%5E%7BT%7D%5Cgamma_t%20%5Cmathbb%7BE%7D_%7Bx%5Csim%20q(x_t%7Cx_0)%7D%5B%7C%7C%20%5Cepsilon_t-%5Cepsilon_%7B%5Ctheta%7D(x_t%2Ct)%20%7C%7C_2%5E2%5D  里面的每一项,于是 %5Cgamma_t 在优化目标的时候就不再起作用(我们最优解 %5Cepsilon_%5Ctheta 将不再依赖 %5Cgamma 的取值)。比如,我们可以取 %5Cgamma%3D1,那么这就变成了 DDPM 原论文中的情况。

加速推理

所以,讲了这么多废话。终于到了最关键的部分——如何利用DDIM来加速推理。

DDIM 加速推理的示意图

%5C%7B1%2C2%2C3%2C...%2C999%2C1000%5C%7D 的集合中,取一个子集出来 %5C%7B%5Ctau_1%2C%20%5Ctau_2%2C%5Ctau_3%2C...%5C%7D, 这个子集的长度将远小于1000。然后我们把 x_%5Ctau%2C%20x_%7B%5Ctau-1%7D 带入非马尔可夫过程这一节里面的迭代公式,就得到了我们加速采样的过程。

Why? 首先,论文里面把联合分布拆解成了以下的形式。

p_%5Ctheta(%5Cmathbf%7Bx%7D_%7B0%3AT%7D)%20%3A%3D%20p_%7B%5Ctheta%7D(x_T)%5Cprod_%7Bi%3D1%7D%5E%7BS%7Dp_%7B%5Ctheta%7D%5E%7B(%5Ctau_i)%7D(x_%7B%5Ctau_%7Bi-1%7D%7D%7Cx_%7B%5Ctau_%7Bi%7D%7D)%5Ctimes%20%5Cprod_%7Bt%5Cin%20%5Cbar%7B%5Ctau%7D%7D%20p_%7B%5Ctheta%7D%5E%7B(t)%7D(x_0%7Cx_t)

L_%5Cgamma 是等价的。

%5Csigma%3D0 的情况)。所以目前就把它当作一个定理来看吧,不要太深究了

ODE

上一篇文章最后留了一个坑没填。

%5Csigma%3D0 时,那么方差这一项就变成了 0. 于是原先的 stochastic 的过程就变成了 deterministic 的过程(即,已知 x_t 和 x_0 的情况下,x_%7Bt-1%7D是一个确定的值;意味着从相同的噪音出发,将会导出相同的图片)

%5Csigma%3D0 的离散情况连续化,就能得到对应的 ODE. 

x_i 出发

%5Cbegin%7Balign%7D%0A%26%20x_%7Bt-1%7D%3D%5Csqrt%7B%5Calpha_%7Bt-1%7D%7Dx_0%2B%5Csqrt%7B1-%5Calpha_%7Bt-1%7D%7D%20%5Ccdot%20%5Cepsilon_%5Ctheta%5E%7B(t)%7D(x_t)%20%5C%5C%0A%26%20x_0%20%3D%20(x_t-%5Csqrt%7B1-%5Calpha_t%7D%20%5Ccdot%20%5Cepsilon_%5Ctheta%5E%7B(t)%7D(x_t))%2F%5Csqrt%7B%5Calpha_t%7D%0A%5Cend%7Balign%7D

N%20%5Cto%20%5Cinfty, 然后我们再把它压缩到一个连续的区间 [0,1] 上面去,所以 %5CDelta%20t%3D1%2FN, 并且

%5Cfrac%7Bx_%7Bt-%5CDelta%20t%7D%7D%7B%5Csqrt%7B%5Calpha_%7Bt-%5CDelta%20t%7D%7D%7D%3D%5Cfrac%7Bx_t%7D%7B%5Csqrt%7B%5Calpha_%7Bt%7D%7D%7D%2B(%5Csqrt%7B%5Cfrac%7B1-%5Calpha_%7Bt-%5CDelta%7Bt%7D%7D%7D%7B%5Calpha_%7Bt-%5CDelta%20t%7D%7D%7D%20-%20%5Csqrt%7B%5Cfrac%7B1-%5Calpha_%7Bt%7D%7D%7B%5Calpha_%7Bt%7D%7D%7D)%5Cepsilon_%5Ctheta(x_t%2Ct)

%5Cmathrm%7BX%7D_t%3D%5Cfrac%7Bx_t%7D%7B%5Csqrt%7B%5Calpha_t%7D%7D%5Cmathrm%7BA%7D_t%3D%20%5Csqrt%7B%5Cfrac%7B1-%5Calpha_%7Bt%7D%7D%7B%5Calpha_%7Bt%7D%7D%7D 来使公式变得更加简洁。

%5Cmathrm%7BA%7D_t-%5Cmathrm%7BA%7D_%7Bt-%5CDelta%7Bt%7D%7D 就变成了 d%5Cmathrm%7BA%7D_t. %5Cmathrm%7BX%7D_t%20-%20%5Cmathrm%7BX%7D_%7Bt-%5CDelta%7Bt%7D%7D%20 就变成了 d%5Cmathrm%7BX%7D_t.最终形式就是

d%5Cmathrm%7BX%7D_t%3D%5Cepsilon_%5Ctheta(x_t%2Ct)d%5Cmathrm%7BA%7D_t.

一些推荐参考的资料

推荐过N遍了😂,这里不厌其烦地再推荐一遍,一篇系统性介绍DDPM的文章:https://zhuanlan.zhihu.com/p/638442430

讲 DDIM 的文章中,经典中的经典(里面有不少跳步,建议先看一眼上面那个讲DDPM的):https://www.zhangzhenhu.com/aigc/ddim.html#equation-eq-ddim-226

非马尔可夫下的公式由来:https://zhuanlan.zhihu.com/p/627616358

填坑结束(😵)


Copyright © 2024 aigcdaily.cn  北京智识时代科技有限公司  版权所有  京ICP备2023006237号-1