当前位置:首页|资讯|AIGC|谷歌

AIGC: Progressive Distillation 笔记

作者:刹那-Ksana-发布时间:2023-08-18

Google 出品,必属精品?

DDIM 知识蒸馏(Knowledge Distillation)

我们先从 DDIM 的知识蒸馏开始(2101.02388),在这个知识蒸馏的设定里面,我们有一个老师 (teacher) 和一个学生 (student),学生的目标是让自己的输出 p_%7Bstudent%7D(%5Cmathbf%7Bx%7D_0%20%7C%20%5Cmathbf%7Bx%7D_T) 尽量地接近老师的输出 p_%7Bteacher%7D(%5Cmathbf%7Bx%7D_0%20%7C%20%5Cmathbf%7Bx%7D_T),用数学公式表达,就是最小化:

L_%7Bstudent%7D%3D%20%5Cmathbb%7BE%7D_%7Bx_T%7D%5B%20D_%7BKL%7D(p_%7Bteacher%7D(%5Cmathbf%7Bx%7D_0%20%7C%20%5Cmathbf%7Bx%7D_T)%20%7C%7C%20p_%7Bstudent%7D(%5Cmathbf%7Bx%7D_0%20%7C%20%5Cmathbf%7Bx%7D_T))%20%5D

另外,知识蒸馏有一个要求是,输出需要是确定的 (deterministic),所以这里采用的是 DDIM 的设定。

逐步蒸馏(Progressive Distillation)

示意图;x 代表了我们通常意义的 x0, z 代表了中间步骤 x_t 

N%2F2 步),当这个学生学习结束以后,这个学生就成了新的老师,然后重复如上的过程。

q(z_t%7C%5Cmathbf%7Bx_0%7D)%3D%5Cmathcal%7BN%7D(%5Calpha_t%20%5Cmathbf%7Bx%7D_0%2C%20%5Csigma_t%5E2%20%5Cmathbf%7BI%7D). 我们通常所见到的 Variance Preserving 扩散过程,是其在 %5Csigma_t%3D%5Csqrt%7B1-%5Calpha_t%5E2%7D 时的特例。z_t 是所谓的 latent, 其实就是 %5Cmathbf%7Bx%7D_0 加噪后的数据. t%5Cin%20%5B0%2C1%5D.

%5Calpha_t%20%3D%20%5Ccos(0.5%5Cpi%20t)%5Cmathbf%7Bz_1%7D 代表了纯高斯噪声 %5Cmathcal%7BN%7D(0%2CI)(注意下标 t 的范围是从0到1)。

SNR(t)%3D%5Calpha_t%5E2%20%2F%20%5Csigma_t%5E2.  在 z_1 的时候,很明显 %5Calpha_%7B1%7D%3D0%5Csigma_1%5E2%3D1, 故信噪比为 0.

Loss

针对 loss 函数我们有

L_%7B%5Ctheta%7D%20%3D%20%5ClVert%20%5Cepsilon%20-%20%5Chat%7B%5Cmathbf%7B%5Cepsilon%7D%7D_%7B%5Ctheta%7D(%7B%5Cmathbf%7Bz%7D%7D_t)%5CrVert_%7B2%7D%5E%7B2%7D%20%3D%20%5Cleft%5C%7C%20%5Cfrac%7B1%7D%7B%5Csigma_t%7D(%7B%5Cmathbf%7Bz%7D%7D_t%20-%20%5Calpha_t%7B%5Cmathbf%7Bx%7D_0%7D)%20-%20%5Cfrac%7B1%7D%7B%5Csigma_t%7D(%7B%5Cmathbf%7Bz%7D%7D_t%20-%20%5Calpha_t%5Chat%7B%5Cmathbf%7Bx%7D%7D_%7B%5Ctheta%7D(%7B%5Cmathbf%7Bz%7D%7D_t))%5Cright%5C%7C_%7B2%7D%5E%7B2%7D%20%3D%20%5Cfrac%7B%5Calpha%5E%7B2%7D_t%7D%7B%5Csigma%5E%7B2%7D_t%7D%20%5ClVert%20%7B%5Cmathbf%7Bx%7D_0%7D%20-%20%5Chat%7B%5Cmathbf%7Bx%7D%7D_%7B%5Ctheta%7D(%7B%5Cmathbf%7Bz%7D%7D_t)%20%5CrVert_%7B2%7D%5E%7B2%7D

%5Chat%7B%5Cmathbf%7Bx%7D%7D_%5Ctheta(z_t%2C%20t)%3D(%5Cmathbf%7Bz%7D_t-%5Csigma_t%20%5Cepsilon_%5Ctheta%20(z_t%2Ct))%2F%20%5Calpha_t 代表了在 t 时间点所生成的图片。在公式做了如上的变形之后,我们可以把 loss 看成是在 %5Cmathbf%7Bx%7D 空间里面的函数(预测图像和原图像的距离),而信噪比则控制了 loss 的权重 (weight). 这里我们把这个权重称作权重函数 (weighting function). 当然,我们还可以设计各种不同的权重函数。

%5Calpha_t%20%5Cto%200 时(即扩散初期),因为 %5Chat%7B%5Cmathbf%7Bx%7D%7D_%5Ctheta(z_t%2C%20t)%3D(%5Cmathbf%7Bz%7D_t-%5Csigma_t%20%5Cepsilon_%5Ctheta%20(z_t%2Ct))%2F%20%5Calpha_t, 所以 %5Cepsilon_%5Ctheta%20 任何一点小的波动都会被超级放大。在蒸馏的初期,因为我们的步数很多,早期的一些的错误会在后期被修复;但是越往下蒸馏,步数越少的时候,这种情况就要出问题了。在极端的情况下,如果我们这个逐步蒸馏,进行到只剩下一步了(意味着直接从纯高斯噪声一步生成图片),那么这个时候,整个 loss 也变成 0 了,学生就学不到任何东西了。

对此,论文里面有三种解决方案:

  1. %5Calpha_t 在分母上的问题)

  2. %5Ctilde%7B%5Cepsilon%7D_%5Ctheta 的同时,也预测 %5Ctilde%7B%5Cmathbf%7Bx%7D%7D_%7B%5Ctheta%7D,然后用公式 %5Chat%7B%5Cmathbf%7Bx%7D%7D%20%3D%20%5Csigma%5E%7B2%7D_t%5Ctilde%7B%5Cmathbf%7Bx%7D%7D_%7B%5Ctheta%7D(%7B%5Cmathbf%7Bz%7D%7D_t)%20%2B%20%5Calpha_%7Bt%7D(%7B%5Cmathbf%7Bz%7D%7D_t%20-%20%5Csigma_t%5Ctilde%7B%5Cmathbf%7B%5Cepsilon%7D%7D_%7B%5Ctheta%7D(%7B%5Cmathbf%7Bz%7D%7D_t))%3D(1-%5Calpha_t%5E2)%5C%20%5Ctilde%7B%5Cmathbf%7Bx%7D%7D_%7B%5Ctheta%7D(%5Cmathbf%7Bz%7D_t)%20%2B%20%5Calpha_t%5E2%20%20%5Chat%7B%5Cmathbf%7Bx%7D%7D_%7B%5Ctheta%7D(%5Cmathbf%7Bz%7D_t) 生成图片。(两种渠道预测的 %5Cmathbf%7Bx%7D 加权求和)

  3. %5Cmathbf%7Bv%7D%3A%3D%5Calpha_t%7B%5Cmathbf%7B%5Cepsilon%7D%7D%20-%20%5Csigma_%7Bt%7D%7B%5Cmathbf%7Bx%7D%7D, 然后 %5Chat%7B%5Cmathbf%7Bx%7D%7D%20%3D%20%5Calpha_t%7B%5Cmathbf%7Bz%7D%7D_t%20-%20%5Csigma_t%5Chat%7B%5Cmathbf%7Bv%7D%7D_%7B%5Ctheta%7D(%7B%5Cmathbf%7Bz%7D%7D_t)

三种解决方案+原方案的效果对比,第一个数是FID(越低越好),第二个数是IS(越高越好);论文中认为,这三种解决方案都能取得不错的效果;并且,三种方案都可以直接用来训练去噪扩散模型;N/A 意味着这个过程不稳定

另外,论文里面还提出了两种可行的 loss 的方案:

  1. L_%7B%5Ctheta%7D%20%3D%20%5Ctext%7Bmax%7D(%5ClVert%20%7B%5Cmathbf%7Bx%7D%7D%20-%20%5Chat%7B%7B%5Cmathbf%7Bx%7D%7D%7D_t%20%5CrVert_%7B2%7D%5E%7B2%7D%2C%20%5ClVert%20%7B%5Cmathbf%7B%5Cepsilon%7D%7D%20-%20%5Chat%7B%7B%5Cmathbf%7B%5Cepsilon%7D%7D%7D_t%20%5CrVert_%7B2%7D%5E%7B2%7D)%20%3D%20%5Ctext%7Bmax%7D(%5Cfrac%7B%5Calpha%5E%7B2%7D_t%7D%7B%5Csigma%5E%7B2%7D_t%7D%2C%201)%5ClVert%20%7B%5Cmathbf%7Bx%7D%7D%20-%20%5Chat%7B%7B%5Cmathbf%7Bx%7D%7D%7D_t%20%5CrVert_%7B2%7D%5E%7B2%7D

  2. L_%7B%5Ctheta%7D%20%3D%20%5ClVert%20%7B%5Cmathbf%7Bv%7D%7D_t%20-%20%5Chat%7B%7B%5Cmathbf%7Bv%7D%7D%7D_t%20%5CrVert_%7B2%7D%5E%7B2%7D%20%3D%20(1%2B%5Cfrac%7B%5Calpha%5E%7B2%7D_t%7D%7B%5Csigma%5E%7B2%7D_t%7D)%5ClVert%20%7B%5Cmathbf%7Bx%7D%7D%20-%20%5Chat%7B%7B%5Cmathbf%7Bx%7D%7D%7D_t%20%5CrVert_%7B2%7D%5E%7B2%7D

DDIM Angular Parameterization

%5Cphi_%7Bt%7D%20%3D%20%5Ctext%7Barctan%7D(%5Csigma_%7Bt%7D%2F%5Calpha_%7Bt%7D),所以 %5Calpha_%7B%5Cphi%7D%20%3D%20%5Ccos(%5Cphi)%2C%20%5Csigma_%7B%5Cphi%7D%3D%5Csin(%5Cphi). 显然,由 %5Cmathbf%7Bz%7D_t%20%3D%20%5Calpha_t%20%5Cmathbf%7Bx%7D_0%20%2B%20%5Csigma_t%20%5Cepsilon, 我们可得 %7B%5Cmathbf%7Bz%7D%7D_%7B%5Cphi%7D%20%3D%20%5Ccos(%5Cphi)%7B%5Cmathbf%7Bx%7D_0%7D%20%2B%20%5Csin(%5Cphi)%7B%5Cmathbf%7B%5Cepsilon%7D%7D.

%5Cmathbf%7Bz%7D_%5Cphi 的速度 (velocity) 为:

%5Cmathbf%7Bv%7D_%5Cphi%20%3A%3D%20%5Cfrac%7Bd%20%7B%5Cmathbf%7Bz%7D%7D_%7B%5Cphi%7D%7D%7Bd%5Cphi%7D%20%3D%20%5Cfrac%7Bd%5Ccos(%5Cphi)%7D%7Bd%5Cphi%7D%7B%5Cmathbf%7Bx%7D%7D%20%2B%20%5Cfrac%7Bd%5Csin(%5Cphi)%7D%7Bd%5Cphi%7D%7B%5Cmathbf%7B%5Cepsilon%7D%7D%20%3D%5Ccos(%5Cphi)%7B%5Cmathbf%7B%5Cepsilon%7D%7D%20-%20%5Csin(%5Cphi)%7B%5Cmathbf%7Bx%7D%7D

利用三角函数的那些定理(初高中知识哦),对上面的公式变形后,我们可以得到:

%7B%5Cmathbf%7Bx%7D%7D%20%3D%20%5Ccos(%5Cphi)%7B%5Cmathbf%7Bz%7D%7D%20-%20%5Csin(%5Cphi)%7B%5Cmathbf%7Bv%7D%7D_%7B%5Cphi%7D

%5Cepsilon%20%3D%20%5Csin(%5Cphi)%20%7B%5Cmathbf%7Bz%7D%7D_%7B%5Cphi%7D%20%2B%20%5Ccos(%5Cphi)%20%7B%5Cmathbf%7Bv%7D%7D_%7B%5Cphi%7D

在这里,我们再定义一个预测速度 (predicted velocity): 

%5Chat%7B%5Cmathbf%7Bv%7D%7D_%5Ctheta(%7B%5Cmathbf%7Bz%7D%7D_%7B%5Cphi%7D)%20%5Cequiv%20%5Ccos(%5Cphi)%5Chat%7B%5Cmathbf%7B%5Cepsilon%7D%7D_%5Ctheta(%7B%5Cmathbf%7Bz%7D%7D_%7B%5Cphi%7D)%20-%20%5Csin(%5Cphi)%5Chat%7B%5Cmathbf%7Bx%7D%7D_%5Ctheta(%7B%5Cmathbf%7Bz%7D%7D_%7B%5Cphi%7D)%0A

%5Chat%7B%5Cmathbf%7Bx%7D%7D_%5Ctheta(z_t%2C%20t)%3D(%5Cmathbf%7Bz%7D_t-%5Csigma_t%20%5Cepsilon_%5Ctheta%20(z_t%2Ct))%2F%20%5Calpha_t,我们有: 

%5Chat%7B%5Cmathbf%7B%5Cepsilon%7D%7D_%5Ctheta(%7B%5Cmathbf%7Bz%7D%7D_%7B%5Cphi%7D)%20%3D%20(%7B%5Cmathbf%7Bz%7D%7D_%7B%5Cphi%7D%20-%20%5Ccos(%5Cphi)%5Chat%7B%5Cmathbf%7Bx%7D%7D_%5Ctheta(%7B%5Cmathbf%7Bz%7D%7D_%7B%5Cphi%7D))%2F%5Csin(%5Cphi)

所以这里解释了上一节的解决方案3的公式由来。

接下来我们要做的只是一些公式变形了,最终我们会得到:

%7B%5Cmathbf%7Bz%7D%7D_%7B%5Cphi_%7Bs%7D%7D%20%3D%20%5Ccos(%5Cphi_s%20-%20%5Cphi_t)%7B%5Cmathbf%7Bz%7D%7D_%7B%5Cphi_%7Bt%7D%7D%20%2B%20%5Csin(%5Cphi_s%20-%20%5Cphi_t)%5Chat%7B%5Cmathbf%7Bv%7D%7D_%5Ctheta(%7B%5Cmathbf%7Bz%7D%7D_%7B%5Cphi_t%7D)%20

%7B%5Cmathbf%7Bz%7D%7D_%7B%5Cphi_%7Bt%7D-%5Cdelta%7D%20%3D%20%5Ccos(%5Cdelta)%7B%5Cmathbf%7Bz%7D%7D_%7B%5Cphi_%7Bt%7D%7D%20-%20%5Csin(%5Cdelta)%5Chat%7B%5Cmathbf%7Bv%7D%7D_%5Ctheta(%7B%5Cmathbf%7Bz%7D%7D_%7B%5Cphi_t%7D)

从纯高斯噪声e到原图像x,我们是朝着 -v 的方向沿着一个圆弧在前进

学习目标

对于每一步的更新,其方法是可以有很多种的。

这里,论文里面使用的更新公式为:

%7B%5Cmathbf%7Bz%7D%7D_s%20%3D%20%5Calpha_s%20%5Chat%7B%5Cmathbf%7Bx%7D%7D_%5Ctheta(%7B%5Cmathbf%7Bz%7D%7D_t)%20%2B%20%5Csigma_s%5Cfrac%7B%7B%5Cmathbf%7Bz%7D%7D_t-%5Calpha_t%5Chat%7B%5Cmathbf%7Bx%7D%7D_%5Ctheta(%5Cmathbf%7Bz%7D_t)%7D%7B%5Csigma_t%7D, 对其求导的话就可以得到 d%7B%5Cmathbf%7Bz%7D%7D%20%3D%20%5Bf(%7B%5Cmathbf%7Bz%7D%7D%2C%20t)%20-%20%5Cfrac%7B1%7D%7B2%7Dg%5E%7B2%7D(t)%5Cnabla_%7Bz%7D%5Clog%20p_%7Bt%7D(%7B%5Cmathbf%7Bz%7D%7D)%20%5Ddt.(这里论文假定了 score function %5Cnabla_%7Bz%7D%5Clog%20p_%7Bt%7D(%5Cmathbf%7Bz%7D) 可以用 %5Cnabla_%7Bz%7D%5Clog%20p_%7Bt%7D(%7B%5Cmathbf%7Bz%7D%7D)%20%5Capprox%20%5Cfrac%7B%5Calpha_%7Bt%7D%5Chat%7B%5Cmathbf%7Bx%7D%7D_%7B%5Ctheta%7D(%7B%5Cmathbf%7Bz%7D%7D_t)%20-%20%7B%5Cmathbf%7Bz%7D%7D_t%7D%7B%5Csigma%5E%7B2%7D_t%7D 来近似;详细过程见论文附录)

%5Cmathbf%7Bz%7D_%7Bt'%7D, 根据 %5Cmathbf%7Bz%7D_%7Bt'%7D再计算前一步的 %5Cmathbf%7Bz%7D_%7Bt''%7D. 然后我们计算目标 %5Ctilde%7B%5Cmathbf%7Bx%7D%7D%20%3D%20%5Cfrac%7B%7B%5Cmathbf%7Bz%7D%7D_%7Bt''%7D%20-%20%5Cfrac%7B%5Csigma_%7Bt''%7D%7D%7B%5Csigma_%7Bt%7D%7D%7B%5Cmathbf%7Bz%7D%7D_t%7D%7B%5Calpha_%7Bt''%7D%20-%20%5Cfrac%7B%5Csigma_%7Bt''%7D%7D%7B%5Csigma_%7Bt%7D%7D%5Calpha_t%7D,最小化上述的 loss. 大功告成。

完。

注:B站的公式编辑器频繁抽风,如果遇到一些 tex parse error 之类的错误时,尝试刷新一下页面。


Copyright © 2024 aigcdaily.cn  北京智识时代科技有限公司  版权所有  京ICP备2023006237号-1