• 欢迎访问开心洋葱网站,在线教程,推荐使用最新版火狐浏览器和Chrome浏览器访问本网站,欢迎加入开心洋葱 QQ群
  • 为方便开心洋葱网用户,开心洋葱官网已经开启复制功能!
  • 欢迎访问开心洋葱网站,手机也能访问哦~欢迎加入开心洋葱多维思维学习平台 QQ群
  • 如果您觉得本站非常有看点,那么赶紧使用Ctrl+D 收藏开心洋葱吧~~~~~~~~~~~~~!
  • 由于近期流量激增,小站的ECS没能经的起亲们的访问,本站依然没有盈利,如果各位看如果觉着文字不错,还请看官给小站打个赏~~~~~~~~~~~~~!

【深度学习实战】从零开始深度学习(五):生成对抗网络——深度学习中的非监督学习问题

人工智能 喵木木 1901次浏览 0个评论

 

【深度学习实战】从零开始深度学习(五):生成对抗网络——深度学习中的非监督学习问题

目录

    • 1. 生成模型(Generative Model)
      • 1.1 自编码器(Autoencoder)
      • 1.2 变分自编码器(Variational AutoEncoder,VAE)
    • 2. 生成对抗网络(Generative Adversarial Networks,GAN)
      • 2.1 生成对抗网络模型概述
      • 2.2 生成对抗网络的数学原理
        • 2.2.1 预备知识
        • 2.2.2 生成对抗网络的数学原理
    • 3 【案例一】利用PyTorch实现GAN【生成新的图片】
      • 3.1 模型构建
      • 3.2 损失函数和优化器
      • 3.3 训练模型
      • 3.4 采用不同的loss函数
      • 3.5 使用更复杂的卷积神经网络
    • 4. 【案例二】使用语言模型生成新的文本

1. 生成模型(Generative Model)

  生成模型是指一些列用于随机生成可观测数据的模型,并使“生成”的样本和“真实”样本尽可能地相似。这个概念属于概率统计和机器学习。生成模型的两个主要功能就是学习概率分布P_model(X)和生成数据。   前面介绍的CNN和RNN等网络,大多应用于监督学习的任务,生成模型的目的之一就是希望能让无监督学习取得比较大的进展。  

概率生成模型,简称生成模型(Generative Model),是概率统计和机器学习中的一类重要模型,指一系列用于随机生成可观测数据的模型 。生成模型的应用十分广泛,可以用来不同的数据进行建模,比如图像、文本、声音等。比如图像生成,我们将图像表示为一个随机向量X,其中每一维都表示一个像素值。假设自然场景的图像都服从一个未知的分布pr(x),希望通过一些观测样本来估计其分布。高维随机向量一般比较难以直接建模,需要通过一些条件独立性来简化模型。深度生成模型就是利用深层神经网络可以近似任意函数的能力来建模一个复杂的分布。方法:从统计的角度表示数据的分布情况,能够反映同类数据本身的相似度;生成方法还原出联合概率分布,而判别方法不能;生成方法的学习收敛速度更快、即当样本容量增加的时候,学到的模型可以更快地收敛于真实模型;当存在隐变量时,仍可以用生成方法学习,此时判别方法不能用。 ——百度百科  

这里我们介绍两个最简单的生成模型—— 自编码器(Autoencoder) 和 变分自编码器(Variational AutoEncoder)。  

1.1 自编码器(Autoencoder)

 
【深度学习实战】从零开始深度学习(五):生成对抗网络——深度学习中的非监督学习问题   自编码器包括两个部分,第一部分是编码器(encoder),第二部分是解码器(decoder),编码器和解码器都可以是任意的模型,通常我们可以使用神经网络作为我们的编码器和解码器,输入的数据经过神经网络(NN Encoder)降维到一个编码(Code),然后又通过另外一个神经网络解码(NN Decoder)得到一个与原始数据一模一样的生成数据,通过比较原始数据和生成数据,希望他们尽可能接近,所以最小化他们之间的差异来训练网络中编码器和解码器的参数。   我们接下来可以使用PyTorch实现一个简单的自编码器  

数据准备

  我们还是使用MNIST手写数据集  

<code class="prism language-python has-numbering">im_tfs <span class="token operator">=</span> tfs<span class="token punctuation">.</span>Compose<span class="token punctuation">(</span><span class="token punctuation">[</span>
    tfs<span class="token punctuation">.</span>ToTensor<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
    tfs<span class="token punctuation">.</span>Normalize<span class="token punctuation">(</span><span class="token punctuation">(</span><span class="token number">0.1307</span><span class="token punctuation">,</span><span class="token punctuation">)</span><span class="token punctuation">,</span> <span class="token punctuation">(</span><span class="token number">0.3081</span><span class="token punctuation">,</span><span class="token punctuation">)</span><span class="token punctuation">)</span> 
<span class="token punctuation">]</span><span class="token punctuation">)</span>

train_set <span class="token operator">=</span> MNIST<span class="token punctuation">(</span><span class="token string">'./mnist'</span><span class="token punctuation">,</span> transform<span class="token operator">=</span>im_tfs<span class="token punctuation">,</span> download<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">)</span>
train_data <span class="token operator">=</span> DataLoader<span class="token punctuation">(</span>train_set<span class="token punctuation">,</span> batch_size<span class="token operator">=</span><span class="token number">128</span><span class="token punctuation">,</span> shuffle<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">)</span>
</code>

构建模型

  我们定义一个简单的4层网络作为编码器,中间使用ReLU激活函数;解码器同样使用简单的4层网络,前面三层采用ReLU激活函数,最后一层采用Tanh函数。  

<code class="prism language-python has-numbering"><span class="token keyword">class</span> <span class="token class-name">autoencoder</span><span class="token punctuation">(</span>nn<span class="token punctuation">.</span>Module<span class="token punctuation">)</span><span class="token punctuation">:</span>
    <span class="token keyword">def</span> <span class="token function">__init__</span><span class="token punctuation">(</span>self<span class="token punctuation">)</span><span class="token punctuation">:</span>
        <span class="token builtin">super</span><span class="token punctuation">(</span>autoencoder<span class="token punctuation">,</span> self<span class="token punctuation">)</span><span class="token punctuation">.</span>__init__<span class="token punctuation">(</span><span class="token punctuation">)</span>
        
        self<span class="token punctuation">.</span>encoder <span class="token operator">=</span> nn<span class="token punctuation">.</span>Sequential<span class="token punctuation">(</span>
            nn<span class="token punctuation">.</span>Linear<span class="token punctuation">(</span><span class="token number">28</span><span class="token operator">*</span><span class="token number">28</span><span class="token punctuation">,</span> <span class="token number">128</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
            nn<span class="token punctuation">.</span>ReLU<span class="token punctuation">(</span><span class="token boolean">True</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
            nn<span class="token punctuation">.</span>Linear<span class="token punctuation">(</span><span class="token number">128</span><span class="token punctuation">,</span> <span class="token number">64</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
            nn<span class="token punctuation">.</span>ReLU<span class="token punctuation">(</span><span class="token boolean">True</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
            nn<span class="token punctuation">.</span>Linear<span class="token punctuation">(</span><span class="token number">64</span><span class="token punctuation">,</span> <span class="token number">12</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
            nn<span class="token punctuation">.</span>ReLU<span class="token punctuation">(</span><span class="token boolean">True</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
            nn<span class="token punctuation">.</span>Linear<span class="token punctuation">(</span><span class="token number">12</span><span class="token punctuation">,</span> <span class="token number">3</span><span class="token punctuation">)</span> <span class="token comment"># 输出的 code 是 3 维,便于可视化</span>
        <span class="token punctuation">)</span>
        
        self<span class="token punctuation">.</span>decoder <span class="token operator">=</span> nn<span class="token punctuation">.</span>Sequential<span class="token punctuation">(</span>
            nn<span class="token punctuation">.</span>Linear<span class="token punctuation">(</span><span class="token number">3</span><span class="token punctuation">,</span> <span class="token number">12</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
            nn<span class="token punctuation">.</span>ReLU<span class="token punctuation">(</span><span class="token boolean">True</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
            nn<span class="token punctuation">.</span>Linear<span class="token punctuation">(</span><span class="token number">12</span><span class="token punctuation">,</span> <span class="token number">64</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
            nn<span class="token punctuation">.</span>ReLU<span class="token punctuation">(</span><span class="token boolean">True</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
            nn<span class="token punctuation">.</span>Linear<span class="token punctuation">(</span><span class="token number">64</span><span class="token punctuation">,</span> <span class="token number">128</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
            nn<span class="token punctuation">.</span>ReLU<span class="token punctuation">(</span><span class="token boolean">True</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
            nn<span class="token punctuation">.</span>Linear<span class="token punctuation">(</span><span class="token number">128</span><span class="token punctuation">,</span> <span class="token number">28</span><span class="token operator">*</span><span class="token number">28</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
            nn<span class="token punctuation">.</span>Tanh<span class="token punctuation">(</span><span class="token punctuation">)</span>
        <span class="token punctuation">)</span>

    <span class="token keyword">def</span> <span class="token function">forward</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> x<span class="token punctuation">)</span><span class="token punctuation">:</span>
        encode <span class="token operator">=</span> self<span class="token punctuation">.</span>encoder<span class="token punctuation">(</span>x<span class="token punctuation">)</span>
        decode <span class="token operator">=</span> self<span class="token punctuation">.</span>decoder<span class="token punctuation">(</span>encode<span class="token punctuation">)</span>
        <span class="token keyword">return</span> encode<span class="token punctuation">,</span> decode
</code>

  开始训练  

<code class="prism language-python has-numbering">net <span class="token operator">=</span> autoencoder<span class="token punctuation">(</span><span class="token punctuation">)</span>
    
criterion <span class="token operator">=</span> nn<span class="token punctuation">.</span>MSELoss<span class="token punctuation">(</span>size_average<span class="token operator">=</span><span class="token boolean">False</span><span class="token punctuation">)</span>
optimizer <span class="token operator">=</span> torch<span class="token punctuation">.</span>optim<span class="token punctuation">.</span>Adam<span class="token punctuation">(</span>net<span class="token punctuation">.</span>parameters<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">,</span> lr<span class="token operator">=</span><span class="token number">1e</span><span class="token operator">-</span><span class="token number">3</span><span class="token punctuation">)</span>

<span class="token keyword">def</span> <span class="token function">to_img</span><span class="token punctuation">(</span>x<span class="token punctuation">)</span><span class="token punctuation">:</span>
    <span class="token triple-quoted-string string">'''
    定义一个函数将最后的结果转换回图片
    '''</span>
    x <span class="token operator">=</span> <span class="token number">0.5</span> <span class="token operator">*</span> <span class="token punctuation">(</span>x <span class="token operator">+</span> <span class="token number">1</span><span class="token punctuation">.</span><span class="token punctuation">)</span>
    x <span class="token operator">=</span> x<span class="token punctuation">.</span>clamp<span class="token punctuation">(</span><span class="token number">0</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span>
    x <span class="token operator">=</span> x<span class="token punctuation">.</span>view<span class="token punctuation">(</span>x<span class="token punctuation">.</span>shape<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">28</span><span class="token punctuation">,</span> <span class="token number">28</span><span class="token punctuation">)</span>
    <span class="token keyword">return</span> x

<span class="token comment"># 开始训练自动编码器</span>
<span class="token keyword">for</span> e <span class="token keyword">in</span> <span class="token builtin">range</span><span class="token punctuation">(</span><span class="token number">100</span><span class="token punctuation">)</span><span class="token punctuation">:</span>
    <span class="token keyword">for</span> im<span class="token punctuation">,</span> _ <span class="token keyword">in</span> train_data<span class="token punctuation">:</span>
        im <span class="token operator">=</span> im<span class="token punctuation">.</span>view<span class="token punctuation">(</span>im<span class="token punctuation">.</span>shape<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">,</span> <span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">)</span>
        im <span class="token operator">=</span> Variable<span class="token punctuation">(</span>im<span class="token punctuation">)</span>
        <span class="token comment"># 前向传播</span>
        _<span class="token punctuation">,</span> output <span class="token operator">=</span> net<span class="token punctuation">(</span>im<span class="token punctuation">)</span>
        loss <span class="token operator">=</span> criterion<span class="token punctuation">(</span>output<span class="token punctuation">,</span> im<span class="token punctuation">)</span> <span class="token operator">/</span> im<span class="token punctuation">.</span>shape<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span> <span class="token comment"># 平均</span>
        <span class="token comment"># 反向传播</span>
        optimizer<span class="token punctuation">.</span>zero_grad<span class="token punctuation">(</span><span class="token punctuation">)</span>
        loss<span class="token punctuation">.</span>backward<span class="token punctuation">(</span><span class="token punctuation">)</span>
        optimizer<span class="token punctuation">.</span>step<span class="token punctuation">(</span><span class="token punctuation">)</span>
    
    <span class="token keyword">if</span> <span class="token punctuation">(</span>e<span class="token operator">+</span><span class="token number">1</span><span class="token punctuation">)</span> <span class="token operator">%</span> <span class="token number">20</span> <span class="token operator">==</span> <span class="token number">0</span><span class="token punctuation">:</span> <span class="token comment"># 每 20 次,将生成的图片保存一下</span>
        <span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string">'epoch: {}, Loss: {:.4f}'</span><span class="token punctuation">.</span><span class="token builtin">format</span><span class="token punctuation">(</span>e <span class="token operator">+</span> <span class="token number">1</span><span class="token punctuation">,</span> loss<span class="token punctuation">.</span>data<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">)</span><span class="token punctuation">)</span>
        pic <span class="token operator">=</span> to_img<span class="token punctuation">(</span>output<span class="token punctuation">.</span>cpu<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">.</span>data<span class="token punctuation">)</span>
        <span class="token keyword">if</span> <span class="token operator">not</span> os<span class="token punctuation">.</span>path<span class="token punctuation">.</span>exists<span class="token punctuation">(</span><span class="token string">'./simple_autoencoder'</span><span class="token punctuation">)</span><span class="token punctuation">:</span>
            os<span class="token punctuation">.</span>mkdir<span class="token punctuation">(</span><span class="token string">'./simple_autoencoder'</span><span class="token punctuation">)</span>
        save_image<span class="token punctuation">(</span>pic<span class="token punctuation">,</span> <span class="token string">'./simple_autoencoder/image_{}.png'</span><span class="token punctuation">.</span><span class="token builtin">format</span><span class="token punctuation">(</span>e <span class="token operator">+</span> <span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">)</span>
</code>

  我们看一下生成的图片,效果还是很不错的:  
在这里插入图片描述   使用更复杂一点的网络  

<code class="prism language-python has-numbering"><span class="token keyword">class</span> <span class="token class-name">conv_autoencoder</span><span class="token punctuation">(</span>nn<span class="token punctuation">.</span>Module<span class="token punctuation">)</span><span class="token punctuation">:</span>
    <span class="token keyword">def</span> <span class="token function">__init__</span><span class="token punctuation">(</span>self<span class="token punctuation">)</span><span class="token punctuation">:</span>
        <span class="token builtin">super</span><span class="token punctuation">(</span>conv_autoencoder<span class="token punctuation">,</span> self<span class="token punctuation">)</span><span class="token punctuation">.</span>__init__<span class="token punctuation">(</span><span class="token punctuation">)</span>
        
        self<span class="token punctuation">.</span>encoder <span class="token operator">=</span> nn<span class="token punctuation">.</span>Sequential<span class="token punctuation">(</span>
            nn<span class="token punctuation">.</span>Conv2d<span class="token punctuation">(</span><span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">16</span><span class="token punctuation">,</span> <span class="token number">3</span><span class="token punctuation">,</span> stride<span class="token operator">=</span><span class="token number">3</span><span class="token punctuation">,</span> padding<span class="token operator">=</span><span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">,</span>  <span class="token comment"># (b, 16, 10, 10)</span>
            nn<span class="token punctuation">.</span>ReLU<span class="token punctuation">(</span><span class="token boolean">True</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
            nn<span class="token punctuation">.</span>MaxPool2d<span class="token punctuation">(</span><span class="token number">2</span><span class="token punctuation">,</span> stride<span class="token operator">=</span><span class="token number">2</span><span class="token punctuation">)</span><span class="token punctuation">,</span>  <span class="token comment"># (b, 16, 5, 5)</span>
            nn<span class="token punctuation">.</span>Conv2d<span class="token punctuation">(</span><span class="token number">16</span><span class="token punctuation">,</span> <span class="token number">8</span><span class="token punctuation">,</span> <span class="token number">3</span><span class="token punctuation">,</span> stride<span class="token operator">=</span><span class="token number">2</span><span class="token punctuation">,</span> padding<span class="token operator">=</span><span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">,</span>  <span class="token comment"># (b, 8, 3, 3)</span>
            nn<span class="token punctuation">.</span>ReLU<span class="token punctuation">(</span><span class="token boolean">True</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
            nn<span class="token punctuation">.</span>MaxPool2d<span class="token punctuation">(</span><span class="token number">2</span><span class="token punctuation">,</span> stride<span class="token operator">=</span><span class="token number">1</span><span class="token punctuation">)</span>  <span class="token comment"># (b, 8, 2, 2)</span>
        <span class="token punctuation">)</span>
        
        self<span class="token punctuation">.</span>decoder <span class="token operator">=</span> nn<span class="token punctuation">.</span>Sequential<span class="token punctuation">(</span>
            nn<span class="token punctuation">.</span>ConvTranspose2d<span class="token punctuation">(</span><span class="token number">8</span><span class="token punctuation">,</span> <span class="token number">16</span><span class="token punctuation">,</span> <span class="token number">3</span><span class="token punctuation">,</span> stride<span class="token operator">=</span><span class="token number">2</span><span class="token punctuation">)</span><span class="token punctuation">,</span>  <span class="token comment"># (b, 16, 5, 5)</span>
            nn<span class="token punctuation">.</span>ReLU<span class="token punctuation">(</span><span class="token boolean">True</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
            nn<span class="token punctuation">.</span>ConvTranspose2d<span class="token punctuation">(</span><span class="token number">16</span><span class="token punctuation">,</span> <span class="token number">8</span><span class="token punctuation">,</span> <span class="token number">5</span><span class="token punctuation">,</span> stride<span class="token operator">=</span><span class="token number">3</span><span class="token punctuation">,</span> padding<span class="token operator">=</span><span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">,</span>  <span class="token comment"># (b, 8, 15, 15)</span>
            nn<span class="token punctuation">.</span>ReLU<span class="token punctuation">(</span><span class="token boolean">True</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
            nn<span class="token punctuation">.</span>ConvTranspose2d<span class="token punctuation">(</span><span class="token number">8</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">,</span> stride<span class="token operator">=</span><span class="token number">2</span><span class="token punctuation">,</span> padding<span class="token operator">=</span><span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">,</span>  <span class="token comment"># (b, 1, 28, 28)</span>
            nn<span class="token punctuation">.</span>Tanh<span class="token punctuation">(</span><span class="token punctuation">)</span>
        <span class="token punctuation">)</span>

    <span class="token keyword">def</span> <span class="token function">forward</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> x<span class="token punctuation">)</span><span class="token punctuation">:</span>
        encode <span class="token operator">=</span> self<span class="token punctuation">.</span>encoder<span class="token punctuation">(</span>x<span class="token punctuation">)</span>
        decode <span class="token operator">=</span> self<span class="token punctuation">.</span>decoder<span class="token punctuation">(</span>encode<span class="token punctuation">)</span>
        <span class="token keyword">return</span> encode<span class="token punctuation">,</span> decode

conv_net <span class="token operator">=</span> conv_autoencoder<span class="token punctuation">(</span><span class="token punctuation">)</span>

criterion <span class="token operator">=</span> nn<span class="token punctuation">.</span>MSELoss<span class="token punctuation">(</span>size_average<span class="token operator">=</span><span class="token boolean">False</span><span class="token punctuation">)</span>
<span class="token keyword">if</span> torch<span class="token punctuation">.</span>cuda<span class="token punctuation">.</span>is_available<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">:</span>
    conv_net <span class="token operator">=</span> conv_net<span class="token punctuation">.</span>cuda<span class="token punctuation">(</span><span class="token punctuation">)</span>
optimizer <span class="token operator">=</span> torch<span class="token punctuation">.</span>optim<span class="token punctuation">.</span>Adam<span class="token punctuation">(</span>conv_net<span class="token punctuation">.</span>parameters<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">,</span> lr<span class="token operator">=</span><span class="token number">1e</span><span class="token operator">-</span><span class="token number">3</span><span class="token punctuation">,</span> weight_decay<span class="token operator">=</span><span class="token number">1e</span><span class="token operator">-</span><span class="token number">5</span><span class="token punctuation">)</span>
</code>

  我们使用一个更加复杂一点的网络——多层卷积神经网络构建编码器;解码器使用了nn.ConvTranspose2d(),相当于卷积的反操作。   我们可以对比下两个模型的训练效果:  
在这里插入图片描述   可以看到,同样是训练40轮,卷积模型生成的图片更清晰一些。  

1.2 变分自编码器(Variational AutoEncoder,VAE)

  变分自编码器是自动编码器的升级版本,它的结构和自动编码器非常相似。与自动编码器不同的是,在编码过程中增加一些限制,使其生成的隐含向量能够粗略地遵循一个标准正态分布;这样只要给它一个标准正态分布的随机隐含向量,通过解码器就能够生成想要的图片,而不需要先给它一张原始图片编码。   因此,我们的编码次每次生成的不是一个隐含向量,而是一个均值和一个标准差。   同样可以利用PyTorch实现变分自编码器的结构:   完整代码  

<code class="prism language-python has-numbering">reconstruction_function <span class="token operator">=</span> nn<span class="token punctuation">.</span>MSELoss<span class="token punctuation">(</span>size_average<span class="token operator">=</span><span class="token boolean">False</span><span class="token punctuation">)</span>

<span class="token keyword">def</span> <span class="token function">loss_function</span><span class="token punctuation">(</span>recon_x<span class="token punctuation">,</span> x<span class="token punctuation">,</span> mu<span class="token punctuation">,</span> logvar<span class="token punctuation">)</span><span class="token punctuation">:</span>
    <span class="token triple-quoted-string string">"""
    recon_x: generating images
    x: origin images
    mu: latent mean
    logvar: latent log variance
    """</span>
    MSE <span class="token operator">=</span> reconstruction_function<span class="token punctuation">(</span>recon_x<span class="token punctuation">,</span> x<span class="token punctuation">)</span>
    <span class="token comment"># loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)</span>
    KLD_element <span class="token operator">=</span> mu<span class="token punctuation">.</span><span class="token builtin">pow</span><span class="token punctuation">(</span><span class="token number">2</span><span class="token punctuation">)</span><span class="token punctuation">.</span>add_<span class="token punctuation">(</span>logvar<span class="token punctuation">.</span>exp<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">.</span>mul_<span class="token punctuation">(</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">.</span>add_<span class="token punctuation">(</span><span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">.</span>add_<span class="token punctuation">(</span>logvar<span class="token punctuation">)</span>
    KLD <span class="token operator">=</span> torch<span class="token punctuation">.</span><span class="token builtin">sum</span><span class="token punctuation">(</span>KLD_element<span class="token punctuation">)</span><span class="token punctuation">.</span>mul_<span class="token punctuation">(</span><span class="token operator">-</span><span class="token number">0.5</span><span class="token punctuation">)</span>
    <span class="token comment"># KL divergence</span>
    <span class="token keyword">return</span> MSE <span class="token operator">+</span> KLD

<span class="token keyword">class</span> <span class="token class-name">VAE</span><span class="token punctuation">(</span>nn<span class="token punctuation">.</span>Module<span class="token punctuation">)</span><span class="token punctuation">:</span>
    <span class="token keyword">def</span> <span class="token function">__init__</span><span class="token punctuation">(</span>self<span class="token punctuation">)</span><span class="token punctuation">:</span>
        <span class="token builtin">super</span><span class="token punctuation">(</span>VAE<span class="token punctuation">,</span> self<span class="token punctuation">)</span><span class="token punctuation">.</span>__init__<span class="token punctuation">(</span><span class="token punctuation">)</span>

        self<span class="token punctuation">.</span>fc1 <span class="token operator">=</span> nn<span class="token punctuation">.</span>Linear<span class="token punctuation">(</span><span class="token number">784</span><span class="token punctuation">,</span> <span class="token number">400</span><span class="token punctuation">)</span>
        self<span class="token punctuation">.</span>fc21 <span class="token operator">=</span> nn<span class="token punctuation">.</span>Linear<span class="token punctuation">(</span><span class="token number">400</span><span class="token punctuation">,</span> <span class="token number">20</span><span class="token punctuation">)</span> <span class="token comment"># mean</span>
        self<span class="token punctuation">.</span>fc22 <span class="token operator">=</span> nn<span class="token punctuation">.</span>Linear<span class="token punctuation">(</span><span class="token number">400</span><span class="token punctuation">,</span> <span class="token number">20</span><span class="token punctuation">)</span> <span class="token comment"># var</span>
        self<span class="token punctuation">.</span>fc3 <span class="token operator">=</span> nn<span class="token punctuation">.</span>Linear<span class="token punctuation">(</span><span class="token number">20</span><span class="token punctuation">,</span> <span class="token number">400</span><span class="token punctuation">)</span>
        self<span class="token punctuation">.</span>fc4 <span class="token operator">=</span> nn<span class="token punctuation">.</span>Linear<span class="token punctuation">(</span><span class="token number">400</span><span class="token punctuation">,</span> <span class="token number">784</span><span class="token punctuation">)</span>

    <span class="token keyword">def</span> <span class="token function">encode</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> x<span class="token punctuation">)</span><span class="token punctuation">:</span>
        h1 <span class="token operator">=</span> F<span class="token punctuation">.</span>relu<span class="token punctuation">(</span>self<span class="token punctuation">.</span>fc1<span class="token punctuation">(</span>x<span class="token punctuation">)</span><span class="token punctuation">)</span>
        <span class="token keyword">return</span> self<span class="token punctuation">.</span>fc21<span class="token punctuation">(</span>h1<span class="token punctuation">)</span><span class="token punctuation">,</span> self<span class="token punctuation">.</span>fc22<span class="token punctuation">(</span>h1<span class="token punctuation">)</span>

    <span class="token keyword">def</span> <span class="token function">reparametrize</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> mu<span class="token punctuation">,</span> logvar<span class="token punctuation">)</span><span class="token punctuation">:</span>
        std <span class="token operator">=</span> logvar<span class="token punctuation">.</span>mul<span class="token punctuation">(</span><span class="token number">0.5</span><span class="token punctuation">)</span><span class="token punctuation">.</span>exp_<span class="token punctuation">(</span><span class="token punctuation">)</span>
        eps <span class="token operator">=</span> torch<span class="token punctuation">.</span>FloatTensor<span class="token punctuation">(</span>std<span class="token punctuation">.</span>size<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">.</span>normal_<span class="token punctuation">(</span><span class="token punctuation">)</span>
        <span class="token keyword">if</span> torch<span class="token punctuation">.</span>cuda<span class="token punctuation">.</span>is_available<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">:</span>
            eps <span class="token operator">=</span> Variable<span class="token punctuation">(</span>eps<span class="token punctuation">.</span>cuda<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">)</span>
        <span class="token keyword">else</span><span class="token punctuation">:</span>
            eps <span class="token operator">=</span> Variable<span class="token punctuation">(</span>eps<span class="token punctuation">)</span>
        <span class="token keyword">return</span> eps<span class="token punctuation">.</span>mul<span class="token punctuation">(</span>std<span class="token punctuation">)</span><span class="token punctuation">.</span>add_<span class="token punctuation">(</span>mu<span class="token punctuation">)</span>

    <span class="token keyword">def</span> <span class="token function">decode</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> z<span class="token punctuation">)</span><span class="token punctuation">:</span>
        h3 <span class="token operator">=</span> F<span class="token punctuation">.</span>relu<span class="token punctuation">(</span>self<span class="token punctuation">.</span>fc3<span class="token punctuation">(</span>z<span class="token punctuation">)</span><span class="token punctuation">)</span>
        <span class="token keyword">return</span> torch<span class="token punctuation">.</span>tanh<span class="token punctuation">(</span><span class="token punctuation">(</span>self<span class="token punctuation">.</span>fc4<span class="token punctuation">(</span>h3<span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">)</span>

    <span class="token keyword">def</span> <span class="token function">forward</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> x<span class="token punctuation">)</span><span class="token punctuation">:</span>
        mu<span class="token punctuation">,</span> logvar <span class="token operator">=</span> self<span class="token punctuation">.</span>encode<span class="token punctuation">(</span>x<span class="token punctuation">)</span> <span class="token comment"># 编码</span>
        z <span class="token operator">=</span> self<span class="token punctuation">.</span>reparametrize<span class="token punctuation">(</span>mu<span class="token punctuation">,</span> logvar<span class="token punctuation">)</span> <span class="token comment"># 重新参数化成正态分布</span>
        <span class="token keyword">return</span> self<span class="token punctuation">.</span>decode<span class="token punctuation">(</span>z<span class="token punctuation">)</span><span class="token punctuation">,</span> mu<span class="token punctuation">,</span> logvar <span class="token comment"># 解码,同时输出均值方差</span>
</code>

  可以把VAE的结果和自动编码器的结果进行对比:  
在这里插入图片描述  

2. 生成对抗网络(Generative Adversarial Networks,GAN)

  生成对抗网络由两部分组成——生成+对抗。第一部分是生成模型,类似自动编码器的解码部分;第二部分是对抗模型,可以看成是一个判断真假图片的判别器。生成对抗网络就是让两个网络相互竞争,通过生成网络来生成假的数据,对抗网络通过判别器判别真伪,最后希望生成网络生成数据能够以假乱真骗过判别器。   下面这张图表示生成对抗网络生成数据过程。
在这里插入图片描述  

2.1 生成对抗网络模型概述

 

生成模型

  在生成对抗网络中,不再是将图片输入编码器得到隐含向量然后生成图片,而是随机初始化一个隐含向量,根据变分自编码器的特点,初始化一个正态分布的隐含向量,通过类似解码的过程,将它映射到一个更高的维度,最后生成一个与输入数据相似的数据,这就是假的图片。生成对抗网络会通过对抗过程来计算出这个损失函数。  

对抗模型

  对抗模型简单来说就是一个判断真假的判别器,相当于一个二分类问题,真的图片输出1,假的图片输出0。对抗模型不关心图片的分类(例如,是猫还是狗),只关心图片的真假(不管是猫的图片还是狗的图片,对抗模型都会认为它是真的图片)。  

生成对抗网络的训练

  生成对抗网络在训练的时候,先训练判别器,将假的数据和真的数据都输入给判别模型,并优化这个判别器模型,希望它能够准确地判断出真的数据和假的数据。具体而言,是希望假数据尽可能得到标签0。   然后,开始训练生成器,希望它生成的假数据能够骗过已经经过优化的判别器。在这个阶段,将判别器的参数固定,通过反向传播优化生成器的参数,希望生成器得到的数据经过判别之后的标签尽可能是1(即为真数据)。  

2.2 生成对抗网络的数学原理

  这一节涉及到比较多的数学推理,但是对理解生成对抗网络很有帮助   推荐阅读: 生成对抗网络(GAN) 背后的数学理论  

2.2.1 预备知识

  KL散度  

相对熵(relative entropy),又被称为Kullback-Leibler散度(Kullback-Leibler divergence)或信息散度(information divergence),是两个概率分布(probability distribution)间差异的非对称性度量。  

KL散度是统计学中的一个概念,用于衡量两种概率分布的相似程度,数值越小,表示两种概率分布越接近。   假设P ( x ) Q ( x ) )是两个概率分布。若P ( x ) P(x)P(x)Q ( x )是离散的概率分布,则它们的KL散度定义为  
【深度学习实战】从零开始深度学习(五):生成对抗网络——深度学习中的非监督学习问题   若P ( x ) Q ( x ) 是连续概率分布,则他们的KL散度定义为
【深度学习实战】从零开始深度学习(五):生成对抗网络——深度学习中的非监督学习问题   举个例子,假设有两个分布A B,它们出现0和1的概率分别为
【深度学习实战】从零开始深度学习(五):生成对抗网络——深度学习中的非监督学习问题   那么,这两个概率分布的KL散度为
【深度学习实战】从零开始深度学习(五):生成对抗网络——深度学习中的非监督学习问题   极大似然估计   推荐阅读/参考文献: 一文搞懂极大似然估计 极大似然估计详解  

极大似然估计,通俗理解来说,就是利用已知的样本结果信息,反推最具有可能(最大概率)导致这些样本结果出现的模型参数值! 换句话说,极大似然估计提供了一种给定观察数据来评估模型参数的方法,即:“模型已定,参数未知”。  

我们先来看一下贝叶斯决策:  
【深度学习实战】从零开始深度学习(五):生成对抗网络——深度学习中的非监督学习问题   其中P ( ω ) 为先验概率,即每种类别分布的概率;P ( x ) 为某件事情发生的概率;P ( x ∣ ω ) 为类条件概率,即某种类别前提下,某事发生的概率;P ( ω ∣ x )为后验概率,即某事发生了,并且它属于某一类别的概率。后验概率越大,说明某事物属于这个类别的可能性越大,我们越有理由把它归到这个类别下。 . 实际问题中,我们可能只有样本数据,先验概率P ( ω )和类条件概率P ( x ∣ ω ) 都是未知的。那么我们怎么根据仅有的样本数据进行分类呢?一种可行的办法就是对先验概率P ( ω ) 和类条件概率P ( x ∣ ω ) 进行估计,然后再套用上述公式进行分类。   对先验概率P ( ω ) P(\omega)P(ω)的估计是比较简单的,我们可以利用经验或者依据数据样本中各类别出现的频率进行估计。   而对类条件概率P ( x ∣ ω ) P(x|\omega)P(xω)的估计就难得多。原因包括:概率密度函数包含了一个随机变量的全部信息;样本数据可能不多;特征向量x的维度可能很大等等。一个解决的方法就是把概率密度p ( x ∣ ω i ) p(x|\omega_i)p(xωi)的估计转化为参数估计问题(参数估计方法)。比如可以假设为正态分布,那么需要估计的参数就是σ \sigmaσμ \muμ。当然,概率密度分布p ( x ∣ ω i ) p(x|\omega_i)p(xωi)非常重要。同时,采用这种参数估计的方法,要求数据满足“独立同分布”的假设,并且要保证样本量充足。   那么,假设我们已经确定了概率密度分布的模型,我们怎么确定哪些参数组合时最好的呢?这就是“极大似然估计”的目的。  

最大似然估计的目的就是:利用已知的样本结果,反推最有可能(最大概率)导致这样结果的参数值。 原理:极大似然估计是建立在极大似然原理的基础上的一个统计方法,是概率论在统计学中的应用。极大似然估计提供了一种给定观察数据来评估模型参数的方法,即:“模型已定,参数未知”。通过若干次试验,观察其结果,利用试验结果得到某个参数值能够使样本出现的概率为最大,则称为极大似然估计。  

假设已知的样本集为D = { x 1 , x 2 , . . . , x n },需要估计的参数向量为θ \thetaθ,则相对于{ x 1 , x 2 , . . . , x n } θ 的似然函数(likehood function)为其联合概率密度函数P ( D ∣ θ ),可以理解为在参数为θ 的条件下D = { x 1 , x 2 , . . . , x n } 发生的概率(数据样本服从独立同分布):  
【深度学习实战】从零开始深度学习(五):生成对抗网络——深度学习中的非监督学习问题   这样一来,要使似然函数最大(即在参数为θ 的条件下D = { x 1 , x 2 , . . . , x n } 发生的概率最大),就是求似然函数l ( θ ) 的极值θ ^θ ^ 称作极大似然函数估计值  
【深度学习实战】从零开始深度学习(五):生成对抗网络——深度学习中的非监督学习问题   为了便于分析,定义似然函数的对数  
【深度学习实战】从零开始深度学习(五):生成对抗网络——深度学习中的非监督学习问题   求极值自然就是求导了。  

  1. 如果θ为标量,则在似然函数满足连续、可微的正则条件下,极大似然估计量是下面微分方程的解 【深度学习实战】从零开始深度学习(五):生成对抗网络——深度学习中的非监督学习问题
  2. 如果未知参数有很多个,即θ \thetaθ为向量,则可以对每个参数求偏导数,得到梯度算子 【深度学习实战】从零开始深度学习(五):生成对抗网络——深度学习中的非监督学习问题 若似然函数满足连续可导的条件,则最大似然估计量就是如下方程的解。 【深度学习实战】从零开始深度学习(五):生成对抗网络——深度学习中的非监督学习问题

 

2.2.2 生成对抗网络的数学原理

  我们再回到我们的生成对抗网络。   我们想要将一个随机高斯噪声z zz通过一个生成网络G 得到一个和真实数据分布P d a t a ( x ) 差不多的生成分布P G ( x ; θ ) ,其中参数θ是 网络的参数,并且我们希望θ 可以使P G ( x ; θ )尽可能和P d a t a ( x ) 接近。2   我们从真实数据分布P d a t a ( x ) 中取样m 个点,{ x 1 , x 2 , . . . x m } 。根据给定的参数θ可以计算概率P G ( x i ; θ ) ,那么生成m 个样本数据的似然就是   【深度学习实战】从零开始深度学习(五):生成对抗网络——深度学习中的非监督学习问题     则极大似然函数估计值θ ^为  
【深度学习实战】从零开始深度学习(五):生成对抗网络——深度学习中的非监督学习问题   也就是说,极大似然的另一个角度就是使KL散度极小。其中   【深度学习实战】从零开始深度学习(五):生成对抗网络——深度学习中的非监督学习问题   里面的I II表示指示性函数,即   也就是说,我们希望得到一组模型参数θ ^,使得其随机生成的向量的概率分布P G ( x ; θ )与真实向量的概率分布P d a t a ( x ; θ ) 之间的KL散度最小(KL散度用于衡量两种概率分布的相似程度,数值越小,表示两种概率分布越接近)。   那么,我们怎么得到P G ( x ; θ 呢?在生成对抗网络中,生成器(生成模型,Generator,G)的作用就是将服从正态分布的样本得到输出集合,通过计算输出集合中各类样本的频次,会得到一个概率分布,这个概率分布就是P G ( x ; θ )。这里的θ 就是生成器的神经网络模型的参数。换句话说,机器学习的任务就是不断地训练、学习参数θ ,使得P G ( x ; θ ) P d a t a ( x ; θ )之间的KL散度最小。而判别器(对抗模型,Discriminator,D)的作用就是衡量P G ( x ; θ ) P d a t a ( x ; θ ) 之间的差距。   我们知道,判别器的作用其实是一个二分类问题,那么我们可以定义判别器的目标函数V ( G , D ) 为  
【深度学习实战】从零开始深度学习(五):生成对抗网络——深度学习中的非监督学习问题   这个函数和逻辑回归的目标函数的形式是一样的。   我们的目标是找到一个合适的判别器D DD,使得max ⁡ D V ( G , D ) \max_{D} V(G,D)maxDV(G,D)。为什么呢?因为我们希望判别器能够判断哪个是真,哪个是假,自然是他们之间的差距越大越好。由  
【深度学习实战】从零开始深度学习(五):生成对抗网络——深度学习中的非监督学习问题   那么,
【深度学习实战】从零开始深度学习(五):生成对抗网络——深度学习中的非监督学习问题   在给定生成器G GG的前提下,P d a t a ( x ) P_{data}(x)Pdata(x)P G ( x ) P_G(x)PG(x)都可以看成是常数,分别用a aab bb来表示,那么,等价于求解f ( D ) f(D)f(D)的极值,f ( D ) f(D)f(D)为  
【深度学习实战】从零开始深度学习(五):生成对抗网络——深度学习中的非监督学习问题   这样就求得了在给定G GG的前提下,使得V ( D ) V(D)V(D)取得最大值的D DD。将D DD带回到V ( D ) V(D)V(D),得到  
【深度学习实战】从零开始深度学习(五):生成对抗网络——深度学习中的非监督学习问题   其中J S D ( ∗ )为JS散度,定义为  
【深度学习实战】从零开始深度学习(五):生成对抗网络——深度学习中的非监督学习问题   这也就是为什么我们再训练生成对抗网络的时候,先训练判别器。因为我们首先希望得到判别器的参数θ D,使其能够将真假图片最大化,以得到较优的判别器。   然后我们在这个较优判别器的基础上,再训练生成器,我们希望得到生成器的参数θ G ,能够和真实数据之间的概率分布的KL散度最小,即整体的优化目标是  
【深度学习实战】从零开始深度学习(五):生成对抗网络——深度学习中的非监督学习问题  

3 【案例一】利用PyTorch实现GAN【生成新的图片】

  完整代码   在这个案例里面,我们同样是希望能够生成手写数字的图片。数据集是MNIST。  

3.1 模型构建

  构建判别器——对抗模型   判别器的结构非常简单,由三层全连接神经网络构成,中间使用了斜率为0.2的LeakyReLU激活函数。 leakyrelu 是指 f(x) = max(α \alphaα x, x)。  

<code class="prism language-python has-numbering"><span class="token keyword">class</span> <span class="token class-name">discriminator</span><span class="token punctuation">(</span>nn<span class="token punctuation">.</span>Module<span class="token punctuation">)</span><span class="token punctuation">:</span>
    <span class="token keyword">def</span> <span class="token function">__init__</span><span class="token punctuation">(</span>self<span class="token punctuation">)</span><span class="token punctuation">:</span>
        <span class="token builtin">super</span><span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">.</span>__init__<span class="token punctuation">(</span><span class="token punctuation">)</span>
        self<span class="token punctuation">.</span>net <span class="token operator">=</span> nn<span class="token punctuation">.</span>Sequential<span class="token punctuation">(</span>nn<span class="token punctuation">.</span>Linear<span class="token punctuation">(</span><span class="token number">784</span><span class="token punctuation">,</span> <span class="token number">256</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
                                 nn<span class="token punctuation">.</span>LeakyReLU<span class="token punctuation">(</span><span class="token number">0.2</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
                                 nn<span class="token punctuation">.</span>Linear<span class="token punctuation">(</span><span class="token number">256</span><span class="token punctuation">,</span> <span class="token number">256</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
                                 nn<span class="token punctuation">.</span>LeakyReLU<span class="token punctuation">(</span><span class="token number">0.2</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
                                 nn<span class="token punctuation">.</span>Linear<span class="token punctuation">(</span><span class="token number">256</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span>
                                 <span class="token punctuation">)</span>
    
    <span class="token keyword">def</span> <span class="token function">forward</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> x<span class="token punctuation">)</span><span class="token punctuation">:</span>
        x <span class="token operator">=</span> self<span class="token punctuation">.</span>net<span class="token punctuation">(</span>x<span class="token punctuation">)</span>
        <span class="token keyword">return</span> x
</code>

  构建生成器——生成模型   生成网络的结构也很简单,最后一层的激活函数选择tanh()函数的目的是将图片规范化到-1~1之间。  

<code class="prism language-python has-numbering"><span class="token keyword">class</span> <span class="token class-name">generator</span><span class="token punctuation">(</span>nn<span class="token punctuation">.</span>Module<span class="token punctuation">)</span><span class="token punctuation">:</span>
    <span class="token keyword">def</span> <span class="token function">__init__</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> input_size<span class="token punctuation">)</span><span class="token punctuation">:</span>
        <span class="token builtin">super</span><span class="token punctuation">(</span>generator<span class="token punctuation">,</span> self<span class="token punctuation">)</span><span class="token punctuation">.</span>__init__<span class="token punctuation">(</span><span class="token punctuation">)</span>
        self<span class="token punctuation">.</span>gen <span class="token operator">=</span>nn<span class="token punctuation">.</span>Sequential<span class="token punctuation">(</span>
                nn<span class="token punctuation">.</span>Linear<span class="token punctuation">(</span>input_size<span class="token punctuation">,</span> <span class="token number">1024</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
                nn<span class="token punctuation">.</span>ReLU<span class="token punctuation">(</span><span class="token boolean">True</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
                nn<span class="token punctuation">.</span>Linear<span class="token punctuation">(</span><span class="token number">1024</span><span class="token punctuation">,</span> <span class="token number">1024</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
                nn<span class="token punctuation">.</span>ReLU<span class="token punctuation">(</span><span class="token boolean">True</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
                nn<span class="token punctuation">.</span>Linear<span class="token punctuation">(</span><span class="token number">1024</span><span class="token punctuation">,</span> <span class="token number">784</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
                nn<span class="token punctuation">.</span>Tanh<span class="token punctuation">(</span><span class="token punctuation">)</span>
                <span class="token punctuation">)</span>
    
    <span class="token keyword">def</span> <span class="token function">forward</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> x<span class="token punctuation">)</span><span class="token punctuation">:</span>
        x <span class="token operator">=</span> self<span class="token punctuation">.</span>gen<span class="token punctuation">(</span>x<span class="token punctuation">)</span>
        <span class="token keyword">return</span> x
</code>

 

3.2 损失函数和优化器

 

<code class="prism language-python has-numbering">bce_loss <span class="token operator">=</span> nn<span class="token punctuation">.</span>BCEWithLogitsLoss<span class="token punctuation">(</span><span class="token punctuation">)</span>

<span class="token keyword">def</span> <span class="token function">discriminator_loss</span><span class="token punctuation">(</span>logits_real<span class="token punctuation">,</span> logits_fake<span class="token punctuation">)</span><span class="token punctuation">:</span> <span class="token comment"># 判别器的 loss</span>
    size <span class="token operator">=</span> logits_real<span class="token punctuation">.</span>shape<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span>
    true_labels <span class="token operator">=</span> Variable<span class="token punctuation">(</span>torch<span class="token punctuation">.</span>ones<span class="token punctuation">(</span>size<span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">.</span><span class="token builtin">float</span><span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">.</span>cuda<span class="token punctuation">(</span><span class="token punctuation">)</span>
    false_labels <span class="token operator">=</span> Variable<span class="token punctuation">(</span>torch<span class="token punctuation">.</span>zeros<span class="token punctuation">(</span>size<span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">.</span><span class="token builtin">float</span><span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">.</span>cuda<span class="token punctuation">(</span><span class="token punctuation">)</span>
    loss <span class="token operator">=</span> bce_loss<span class="token punctuation">(</span>logits_real<span class="token punctuation">,</span> true_labels<span class="token punctuation">)</span> <span class="token operator">+</span> bce_loss<span class="token punctuation">(</span>logits_fake<span class="token punctuation">,</span> false_labels<span class="token punctuation">)</span>
    <span class="token keyword">return</span> loss

<span class="token keyword">def</span> <span class="token function">generator_loss</span><span class="token punctuation">(</span>logits_fake<span class="token punctuation">)</span><span class="token punctuation">:</span> <span class="token comment"># 生成器的 loss  </span>
    size <span class="token operator">=</span> logits_fake<span class="token punctuation">.</span>shape<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span>
    true_labels <span class="token operator">=</span> Variable<span class="token punctuation">(</span>torch<span class="token punctuation">.</span>ones<span class="token punctuation">(</span>size<span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">.</span><span class="token builtin">float</span><span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">.</span>cuda<span class="token punctuation">(</span><span class="token punctuation">)</span>
    loss <span class="token operator">=</span> bce_loss<span class="token punctuation">(</span>logits_fake<span class="token punctuation">,</span> true_labels<span class="token punctuation">)</span>
    <span class="token keyword">return</span> loss

<span class="token comment"># 使用 adam 来进行训练,学习率是 3e-4, beta1 是 0.5, beta2 是 0.999</span>
<span class="token keyword">def</span> <span class="token function">get_optimizer</span><span class="token punctuation">(</span>net<span class="token punctuation">)</span><span class="token punctuation">:</span>
    optimizer <span class="token operator">=</span> torch<span class="token punctuation">.</span>optim<span class="token punctuation">.</span>Adam<span class="token punctuation">(</span>net<span class="token punctuation">.</span>parameters<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">,</span> lr<span class="token operator">=</span><span class="token number">3e</span><span class="token operator">-</span><span class="token number">4</span><span class="token punctuation">,</span> betas<span class="token operator">=</span><span class="token punctuation">(</span><span class="token number">0.5</span><span class="token punctuation">,</span> <span class="token number">0.999</span><span class="token punctuation">)</span><span class="token punctuation">)</span>
    <span class="token keyword">return</span> optimizer
</code>

 

3.3 训练模型

 

<code class="prism language-python has-numbering"><span class="token keyword">def</span> <span class="token function">train_a_gan</span><span class="token punctuation">(</span>D_net<span class="token punctuation">,</span> G_net<span class="token punctuation">,</span> D_optimizer<span class="token punctuation">,</span> G_optimizer<span class="token punctuation">,</span> discriminator_loss<span class="token punctuation">,</span> generator_loss<span class="token punctuation">,</span> show_every<span class="token operator">=</span><span class="token number">250</span><span class="token punctuation">,</span> 
                noise_size<span class="token operator">=</span>NOISE_DIM<span class="token punctuation">,</span> num_epochs<span class="token operator">=</span><span class="token number">10</span><span class="token punctuation">)</span><span class="token punctuation">:</span>
    iter_count <span class="token operator">=</span> <span class="token number">0</span>
    <span class="token keyword">for</span> epoch <span class="token keyword">in</span> <span class="token builtin">range</span><span class="token punctuation">(</span>num_epochs<span class="token punctuation">)</span><span class="token punctuation">:</span>
        <span class="token keyword">for</span> x<span class="token punctuation">,</span> _ <span class="token keyword">in</span> train_data<span class="token punctuation">:</span>
            bs <span class="token operator">=</span> x<span class="token punctuation">.</span>shape<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span>
            <span class="token comment"># 判别网络</span>
            real_data <span class="token operator">=</span> Variable<span class="token punctuation">(</span>x<span class="token punctuation">)</span><span class="token punctuation">.</span>view<span class="token punctuation">(</span>bs<span class="token punctuation">,</span> <span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">.</span>cuda<span class="token punctuation">(</span><span class="token punctuation">)</span> <span class="token comment"># 真实数据</span>
            logits_real <span class="token operator">=</span> D_net<span class="token punctuation">(</span>real_data<span class="token punctuation">)</span> <span class="token comment"># 判别网络得分</span>
            
            sample_noise <span class="token operator">=</span> <span class="token punctuation">(</span>torch<span class="token punctuation">.</span>rand<span class="token punctuation">(</span>bs<span class="token punctuation">,</span> noise_size<span class="token punctuation">)</span> <span class="token operator">-</span> <span class="token number">0.5</span><span class="token punctuation">)</span> <span class="token operator">/</span> <span class="token number">0.5</span> <span class="token comment"># -1 ~ 1 的均匀分布</span>
            g_fake_seed <span class="token operator">=</span> Variable<span class="token punctuation">(</span>sample_noise<span class="token punctuation">)</span><span class="token punctuation">.</span>cuda<span class="token punctuation">(</span><span class="token punctuation">)</span>
            fake_images <span class="token operator">=</span> G_net<span class="token punctuation">(</span>g_fake_seed<span class="token punctuation">)</span> <span class="token comment"># 生成的假的数据</span>
            logits_fake <span class="token operator">=</span> D_net<span class="token punctuation">(</span>fake_images<span class="token punctuation">)</span> <span class="token comment"># 判别网络得分</span>

            d_total_error <span class="token operator">=</span> discriminator_loss<span class="token punctuation">(</span>logits_real<span class="token punctuation">,</span> logits_fake<span class="token punctuation">)</span> <span class="token comment"># 判别器的 loss</span>
            D_optimizer<span class="token punctuation">.</span>zero_grad<span class="token punctuation">(</span><span class="token punctuation">)</span>
            d_total_error<span class="token punctuation">.</span>backward<span class="token punctuation">(</span><span class="token punctuation">)</span>
            D_optimizer<span class="token punctuation">.</span>step<span class="token punctuation">(</span><span class="token punctuation">)</span> <span class="token comment"># 优化判别网络</span>
            
            <span class="token comment"># 生成网络</span>
            g_fake_seed <span class="token operator">=</span> Variable<span class="token punctuation">(</span>sample_noise<span class="token punctuation">)</span><span class="token punctuation">.</span>cuda<span class="token punctuation">(</span><span class="token punctuation">)</span>
            fake_images <span class="token operator">=</span> G_net<span class="token punctuation">(</span>g_fake_seed<span class="token punctuation">)</span> <span class="token comment"># 生成的假的数据</span>

            gen_logits_fake <span class="token operator">=</span> D_net<span class="token punctuation">(</span>fake_images<span class="token punctuation">)</span>
            g_error <span class="token operator">=</span> generator_loss<span class="token punctuation">(</span>gen_logits_fake<span class="token punctuation">)</span> <span class="token comment"># 生成网络的 loss</span>
            G_optimizer<span class="token punctuation">.</span>zero_grad<span class="token punctuation">(</span><span class="token punctuation">)</span>
            g_error<span class="token punctuation">.</span>backward<span class="token punctuation">(</span><span class="token punctuation">)</span>
            G_optimizer<span class="token punctuation">.</span>step<span class="token punctuation">(</span><span class="token punctuation">)</span> <span class="token comment"># 优化生成网络</span>

            <span class="token keyword">if</span> <span class="token punctuation">(</span>iter_count <span class="token operator">%</span> show_every <span class="token operator">==</span> <span class="token number">0</span><span class="token punctuation">)</span><span class="token punctuation">:</span>
                <span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string">'Iter: {}, D: {:.4}, G:{:.4}'</span><span class="token punctuation">.</span><span class="token builtin">format</span><span class="token punctuation">(</span>iter_count<span class="token punctuation">,</span> d_total_error<span class="token punctuation">.</span>item<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">,</span> g_error<span class="token punctuation">.</span>item<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">)</span>
                imgs_numpy <span class="token operator">=</span> deprocess_img<span class="token punctuation">(</span>fake_images<span class="token punctuation">.</span>data<span class="token punctuation">.</span>cpu<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">.</span>numpy<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">)</span>
                show_images<span class="token punctuation">(</span>imgs_numpy<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">:</span><span class="token number">16</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
                plt<span class="token punctuation">.</span>show<span class="token punctuation">(</span><span class="token punctuation">)</span>
                <span class="token keyword">print</span><span class="token punctuation">(</span><span class="token punctuation">)</span>
            iter_count <span class="token operator">+=</span> <span class="token number">1</span>
</code>

  训练结果如下:  
在这里插入图片描述 可以看到,在iteration=4000的时候,隐约有了手写数字的样子。

3.4 采用不同的loss函数

  Least Squares GAN 比最原始的 GANs 的 loss 更加稳定。我们可以定义一下loss函数:  

<code class="prism language-python has-numbering"><span class="token keyword">def</span> <span class="token function">ls_discriminator_loss</span><span class="token punctuation">(</span>scores_real<span class="token punctuation">,</span> scores_fake<span class="token punctuation">)</span><span class="token punctuation">:</span>
    loss <span class="token operator">=</span> <span class="token number">0.5</span> <span class="token operator">*</span> <span class="token punctuation">(</span><span class="token punctuation">(</span>scores_real <span class="token operator">-</span> <span class="token number">1</span><span class="token punctuation">)</span> <span class="token operator">**</span> <span class="token number">2</span><span class="token punctuation">)</span><span class="token punctuation">.</span>mean<span class="token punctuation">(</span><span class="token punctuation">)</span> <span class="token operator">+</span> <span class="token number">0.5</span> <span class="token operator">*</span> <span class="token punctuation">(</span>scores_fake <span class="token operator">**</span> <span class="token number">2</span><span class="token punctuation">)</span><span class="token punctuation">.</span>mean<span class="token punctuation">(</span><span class="token punctuation">)</span>
    <span class="token keyword">return</span> loss

<span class="token keyword">def</span> <span class="token function">ls_generator_loss</span><span class="token punctuation">(</span>scores_fake<span class="token punctuation">)</span><span class="token punctuation">:</span>
    loss <span class="token operator">=</span> <span class="token number">0.5</span> <span class="token operator">*</span> <span class="token punctuation">(</span><span class="token punctuation">(</span>scores_fake <span class="token operator">-</span> <span class="token number">1</span><span class="token punctuation">)</span> <span class="token operator">**</span> <span class="token number">2</span><span class="token punctuation">)</span><span class="token punctuation">.</span>mean<span class="token punctuation">(</span><span class="token punctuation">)</span>
    <span class="token keyword">return</span> loss
</code>

  再次训练,可以比较一下采用不同的loss函数的结果:  
在这里插入图片描述  

3.5 使用更复杂的卷积神经网络

  前面只是使用了简单的全连接网络来构建生成器和判别器,我们同样可以使用更复杂的卷积神经网络来构建生成器和判别器。   判别器  

<code class="prism language-python has-numbering"><span class="token keyword">class</span> <span class="token class-name">build_dc_classifier</span><span class="token punctuation">(</span>nn<span class="token punctuation">.</span>Module<span class="token punctuation">)</span><span class="token punctuation">:</span>
    <span class="token keyword">def</span> <span class="token function">__init__</span><span class="token punctuation">(</span>self<span class="token punctuation">)</span><span class="token punctuation">:</span>
        <span class="token builtin">super</span><span class="token punctuation">(</span>build_dc_classifier<span class="token punctuation">,</span> self<span class="token punctuation">)</span><span class="token punctuation">.</span>__init__<span class="token punctuation">(</span><span class="token punctuation">)</span>
        self<span class="token punctuation">.</span>conv <span class="token operator">=</span> nn<span class="token punctuation">.</span>Sequential<span class="token punctuation">(</span>
            nn<span class="token punctuation">.</span>Conv2d<span class="token punctuation">(</span><span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">32</span><span class="token punctuation">,</span> <span class="token number">5</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
            nn<span class="token punctuation">.</span>LeakyReLU<span class="token punctuation">(</span><span class="token number">0.01</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
            nn<span class="token punctuation">.</span>MaxPool2d<span class="token punctuation">(</span><span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
            nn<span class="token punctuation">.</span>Conv2d<span class="token punctuation">(</span><span class="token number">32</span><span class="token punctuation">,</span> <span class="token number">64</span><span class="token punctuation">,</span> <span class="token number">5</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
            nn<span class="token punctuation">.</span>LeakyReLU<span class="token punctuation">(</span><span class="token number">0.01</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
            nn<span class="token punctuation">.</span>MaxPool2d<span class="token punctuation">(</span><span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">)</span>
        <span class="token punctuation">)</span>
        self<span class="token punctuation">.</span>fc <span class="token operator">=</span> nn<span class="token punctuation">.</span>Sequential<span class="token punctuation">(</span>
            nn<span class="token punctuation">.</span>Linear<span class="token punctuation">(</span><span class="token number">1024</span><span class="token punctuation">,</span> <span class="token number">1024</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
            nn<span class="token punctuation">.</span>LeakyReLU<span class="token punctuation">(</span><span class="token number">0.01</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
            nn<span class="token punctuation">.</span>Linear<span class="token punctuation">(</span><span class="token number">1024</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">)</span>
        <span class="token punctuation">)</span>
        
    <span class="token keyword">def</span> <span class="token function">forward</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> x<span class="token punctuation">)</span><span class="token punctuation">:</span>
        x <span class="token operator">=</span> self<span class="token punctuation">.</span>conv<span class="token punctuation">(</span>x<span class="token punctuation">)</span>
        x <span class="token operator">=</span> x<span class="token punctuation">.</span>view<span class="token punctuation">(</span>x<span class="token punctuation">.</span>shape<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">,</span> <span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">)</span>
        x <span class="token operator">=</span> self<span class="token punctuation">.</span>fc<span class="token punctuation">(</span>x<span class="token punctuation">)</span>
        <span class="token keyword">return</span> x
</code>

  生成器  

<code class="prism language-python has-numbering"><span class="token keyword">class</span> <span class="token class-name">build_dc_generator</span><span class="token punctuation">(</span>nn<span class="token punctuation">.</span>Module<span class="token punctuation">)</span><span class="token punctuation">:</span> 
    <span class="token keyword">def</span> <span class="token function">__init__</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> noise_dim<span class="token operator">=</span>NOISE_DIM<span class="token punctuation">)</span><span class="token punctuation">:</span>
        <span class="token builtin">super</span><span class="token punctuation">(</span>build_dc_generator<span class="token punctuation">,</span> self<span class="token punctuation">)</span><span class="token punctuation">.</span>__init__<span class="token punctuation">(</span><span class="token punctuation">)</span>
        self<span class="token punctuation">.</span>fc <span class="token operator">=</span> nn<span class="token punctuation">.</span>Sequential<span class="token punctuation">(</span>
            nn<span class="token punctuation">.</span>Linear<span class="token punctuation">(</span>noise_dim<span class="token punctuation">,</span> <span class="token number">1024</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
            nn<span class="token punctuation">.</span>ReLU<span class="token punctuation">(</span><span class="token boolean">True</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
            nn<span class="token punctuation">.</span>BatchNorm1d<span class="token punctuation">(</span><span class="token number">1024</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
            nn<span class="token punctuation">.</span>Linear<span class="token punctuation">(</span><span class="token number">1024</span><span class="token punctuation">,</span> <span class="token number">7</span> <span class="token operator">*</span> <span class="token number">7</span> <span class="token operator">*</span> <span class="token number">128</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
            nn<span class="token punctuation">.</span>ReLU<span class="token punctuation">(</span><span class="token boolean">True</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
            nn<span class="token punctuation">.</span>BatchNorm1d<span class="token punctuation">(</span><span class="token number">7</span> <span class="token operator">*</span> <span class="token number">7</span> <span class="token operator">*</span> <span class="token number">128</span><span class="token punctuation">)</span>
        <span class="token punctuation">)</span>
        
        self<span class="token punctuation">.</span>conv <span class="token operator">=</span> nn<span class="token punctuation">.</span>Sequential<span class="token punctuation">(</span>
            nn<span class="token punctuation">.</span>ConvTranspose2d<span class="token punctuation">(</span><span class="token number">128</span><span class="token punctuation">,</span> <span class="token number">64</span><span class="token punctuation">,</span> <span class="token number">4</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">,</span> padding<span class="token operator">=</span><span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
            nn<span class="token punctuation">.</span>ReLU<span class="token punctuation">(</span><span class="token boolean">True</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
            nn<span class="token punctuation">.</span>BatchNorm2d<span class="token punctuation">(</span><span class="token number">64</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
            nn<span class="token punctuation">.</span>ConvTranspose2d<span class="token punctuation">(</span><span class="token number">64</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">4</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">,</span> padding<span class="token operator">=</span><span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
            nn<span class="token punctuation">.</span>Tanh<span class="token punctuation">(</span><span class="token punctuation">)</span>
        <span class="token punctuation">)</span>
        
    <span class="token keyword">def</span> <span class="token function">forward</span><span class="token punctuation">(</span>self<span class="token punctuation">,</span> x<span class="token punctuation">)</span><span class="token punctuation">:</span>
        x <span class="token operator">=</span> self<span class="token punctuation">.</span>fc<span class="token punctuation">(</span>x<span class="token punctuation">)</span>
        x <span class="token operator">=</span> x<span class="token punctuation">.</span>view<span class="token punctuation">(</span>x<span class="token punctuation">.</span>shape<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span><span class="token punctuation">,</span> <span class="token number">128</span><span class="token punctuation">,</span> <span class="token number">7</span><span class="token punctuation">,</span> <span class="token number">7</span><span class="token punctuation">)</span> <span class="token comment"># reshape 通道是 128,大小是 7x7</span>
        x <span class="token operator">=</span> self<span class="token punctuation">.</span>conv<span class="token punctuation">(</span>x<span class="token punctuation">)</span>
        <span class="token keyword">return</span> x
</code>

  开始训练   训练结果与简单全连接神经网络的对比如下:  
在这里插入图片描述   可以看到,采用卷积神经网络训练的效果要好很多很多。  

4. 【案例二】使用语言模型生成新的文本

  上面一个案例主要是关于新图片的生成,而也有很多深度学习任务的数据集是和文本有关的。我们同样可以利用LSTM来创建新的文本序列。简单来说,我们希望模型能够根据给定的上下文预测下一个词。生成序列化数据这一能力在很多领域都有应用,比如图像标注、语音识别、语言翻译、自动回复邮件等等。   github:Word-level language modeling RNN


开心洋葱 , 版权所有丨如未注明 , 均为原创丨未经授权请勿修改 , 转载请注明【深度学习实战】从零开始深度学习(五):生成对抗网络——深度学习中的非监督学习问题
喜欢 (0)

您必须 登录 才能发表评论!

加载中……