Google 出品,必属精品?
我们先从 DDIM 的知识蒸馏开始(2101.02388),在这个知识蒸馏的设定里面,我们有一个老师 (teacher) 和一个学生 (student),学生的目标是让自己的输出 尽量地接近老师的输出 ,用数学公式表达,就是最小化:
另外,知识蒸馏有一个要求是,输出需要是确定的 (deterministic),所以这里采用的是 DDIM 的设定。
步),当这个学生学习结束以后,这个学生就成了新的老师,然后重复如上的过程。
. 我们通常所见到的 Variance Preserving 扩散过程,是其在 时的特例。 是所谓的 latent, 其实就是 加噪后的数据. .
, 代表了纯高斯噪声 (注意下标 t 的范围是从0到1)。
. 在 的时候,很明显 , , 故信噪比为 0.
针对 loss 函数我们有
代表了在 t 时间点所生成的图片。在公式做了如上的变形之后,我们可以把 loss 看成是在 空间里面的函数(预测图像和原图像的距离),而信噪比则控制了 loss 的权重 (weight). 这里我们把这个权重称作权重函数 (weighting function). 当然,我们还可以设计各种不同的权重函数。
时(即扩散初期),因为 , 所以 任何一点小的波动都会被超级放大。在蒸馏的初期,因为我们的步数很多,早期的一些的错误会在后期被修复;但是越往下蒸馏,步数越少的时候,这种情况就要出问题了。在极端的情况下,如果我们这个逐步蒸馏,进行到只剩下一步了(意味着直接从纯高斯噪声一步生成图片),那么这个时候,整个 loss 也变成 0 了,学生就学不到任何东西了。
对此,论文里面有三种解决方案:
在分母上的问题)
的同时,也预测 ,然后用公式 生成图片。(两种渠道预测的 加权求和)
, 然后
另外,论文里面还提出了两种可行的 loss 的方案:
,所以 . 显然,由 , 我们可得 .
的速度 (velocity) 为:
利用三角函数的那些定理(初高中知识哦),对上面的公式变形后,我们可以得到:
在这里,我们再定义一个预测速度 (predicted velocity):
,我们有:
所以这里解释了上一节的解决方案3的公式由来。
接下来我们要做的只是一些公式变形了,最终我们会得到:
对于每一步的更新,其方法是可以有很多种的。
这里,论文里面使用的更新公式为:
, 对其求导的话就可以得到 .(这里论文假定了 score function 可以用 来近似;详细过程见论文附录)
, 根据 再计算前一步的 . 然后我们计算目标 ,最小化上述的 loss. 大功告成。
完。
注:B站的公式编辑器频繁抽风,如果遇到一些 tex parse error 之类的错误时,尝试刷新一下页面。