当前位置:首页|资讯|AI绘画

90天学会GAN--Day1--从MNIST数据集开始

作者:弱弱的小汤汤发布时间:2023-05-31

1. 什么是GAN

GAN, 全称是Generative Adversarial Networks, 是一种对抗生成网络,用于生成图片:比如AI换脸,AI绘画风格转换。该模型由两个部分组成,分别是 生成器 generator 和 鉴别器 discriminator。 其中,生成器的作用就是生成图片,而鉴别器的作用就是鉴别该图片究竟是输入的图片还是生成器生成的图片 (若是输入的图片则返回1,否则返回0)。

生成器和鉴别器的关系就像是画家与鉴赏家的关系。生成器画一幅画让鉴别器鉴别这是由名家画的还是由生成器画的,然后生成器反馈两者的区别来提高生成器画图的能力。更加通俗易懂的解释还可以看 百度AI 写的《四天搞懂生成对抗网络(一)——通俗理解经典GAN》中的例子。

2. GAN的构建(以MNIST数据集为例)

2.1. 数据预处理

为了更加贴近实际使用,首先使用 gen_label.py 将下载的二进制文件转换为图片

首先使用 pytorch 内置的函数获取 MNIST 数据集:

此时数据已经下载到 /data/mnist 目录下,并且已经存储在了dataloader中 (格式为 (图片,标签)),下一步需要做的就是将图片从二进制文件转化为.png等可以可视化的方式,因此我们构造了以下函数 (需要 import CV2) :

之后我们就可以调用 save_img 函数来把图片写入该目录

另外为了方便之后读取,我们在 /data/mnist 目录下增加了一个 .txt 文件用于索引,格式为图片地址+标签

实现的方法很简单,只需要遍历一遍 dataloader 就好了



Copyright © 2024 aigcdaily.cn  北京智识时代科技有限公司  版权所有  京ICP备2023006237号-1