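References
- Liao Xingyu, 《深度学习入门之PyTorch》 (Introduction to Deep Learning with PyTorch)
- PyTorch official documentation
- Other references are given as hyperlinks in the text
Contents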
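- 0. Preface
- 1. PyTorch basics
- 1.1 Tensor
- 1.2 Variable
- 1.3 Dataset
- 1.4 Module (nn.Module)
- 1.5 Optimization (torch.optim)
- 1.6 Saving and loading models
- 2. Case study: MNIST handwritten digit classification with a multilayer fully connected network
- 2.1 Defining a simple three-layer fully connected network
- 2.2 Improving the network: adding activation functions
- 2.3 Improving it further: adding batch normalization
- 2.4 Training the network
- 2.5 Comparing the three network models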
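0. Preface
The multilayer fully connected neural network is the foundation for the more complex networks used in deep learning. It is also a convenient vehicle for getting to know some basic PyTorch concepts. Chapters 2 and 3 of Liao Xingyu's 《深度学习入门之PyTorch》 explain the underlying neural-network background in detail, and he also lists many further resources: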
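Topic | Resources |
---|---|
Python | (1) Learn Python the Hard Way (《笨方法学Python》). Written for complete beginners, this book teaches basic Python quickly through a series of simple examples. (2) Liao Xuefeng's Python tutorial. This series covers Python more comprehensively; for students heading into machine learning, the first few chapters on Python basics are enough. (3) edX: Introduction to Computer Science and Programming Using Python. This MIT open course uses Python as the first language and covers computer science concisely and comprehensively; it suits readers who want to go further. |
Linear algebra | (1) Linear Algebra Done Right (2) MIT's open course on linear algebra (the videos are available on Bilibili). Get a basic overview of linear algebra before taking this course: the lecturer's line of thought jumps around, and students with no background may find it hard. (3) Coding the Matrix |
Machine learning basics | (1) Andrew Ng's introductory Machine Learning course (a video series with Chinese and English subtitles) (2) Hsuan-Tien Lin's Machine Learning Foundations (taught in Mandarin) and Machine Learning Techniques (3) Udacity's Machine Learning Nanodegree (4) Zhou Zhihua, Machine Learning (《机器学习》) (5) Li Hang, Statistical Learning Methods (《统计学习方法》), an excellent book that helps a great deal in understanding models and theory (6) Pattern Recognition and Machine Learning |
Deep learning | (1) Udacity's two deep learning courses (2) Coursera's Neural Networks for Machine Learning (3) Stanford's cs231n (4) Stanford's cs224n |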
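1. PyTorch basics
1.1 Tensor
The Tensor is the most basic object PyTorch operates on, and it can be converted to and from a NumPy ndarray. The difference between them is that a Tensor can run on the GPU, while an ndarray runs only on the CPU. A 3×2 matrix (three rows, two columns) with given elements can be defined like this: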
<code class="prism language-python has-numbering">import torch

a = torch.Tensor([[2, 3], [4, 8], [7, 9]])
print('a is {}'.format(a))
print('a size is {}'.format(a.size()))
'''
out:
a is tensor([[2., 3.],
        [4., 8.],
        [7., 9.]])
a size is torch.Size([3, 2])
'''
</code>
As with NumPy, you can change a tensor's values by indexing:
<code class="prism language-python has-numbering">a[0, 1] = 100
print('a is changed to {}'.format(a))
'''
out:
a is changed to tensor([[  2., 100.],
        [  4.,   8.],
        [  7.,   9.]])
'''
</code>
You can also convert between Tensor and ndarray:
<code class="prism language-python has-numbering">numpy_a = a.numpy()
print('convert to numpy is \n {}'.format(numpy_a))
'''
out:
convert to numpy is
 [[  2. 100.]
 [  4.   8.]
 [  7.   9.]]
'''

import numpy as np

b = np.array([[2, 3], [4, 5]])
torch_b = torch.from_numpy(b)
print('from numpy to torch.Tensor is {}'.format(torch_b))
'''
out (the dtype follows NumPy's default integer type on your platform;
for int64, PyTorch's default integer type, no dtype is printed):
from numpy to torch.Tensor is tensor([[2, 3],
        [4, 5]], dtype=torch.int32)
'''
</code>
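For CPU tensors, torch.from_numpy() and Tensor.numpy() do not copy the data. The Tensor and the ndarray share the same memory, so modifying one also changes the other. Here is a minimal sketch with arbitrary values; it uses clone() to get an independent copy:

```python
import numpy as np
import torch

b = np.array([[2, 3], [4, 5]])
torch_b = torch.from_numpy(b)   # shares memory with b, no copy is made

b[0, 0] = 100                   # modify the NumPy array in place
print(torch_b[0, 0].item())     # 100: the Tensor sees the change

c = torch_b.clone()             # clone() makes an independent copy
b[0, 0] = 2                     # torch_b follows b again, c does not
print(c[0, 0].item())           # still 100
```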
torch.Tensor creates tensors of the default type torch.FloatTensor, but you can also create other data types explicitly:
<code class="prism language-python has-numbering">a = torch.LongTensor([[2, 3], [4, 8], [7, 9]])
print('a is {}'.format(a))
'''
out:
a is tensor([[2, 3],
        [4, 8],
        [7, 9]])
'''
</code>
You can likewise create an all-zeros tensor or a random tensor:
<code class="prism language-python has-numbering">a = torch.zeros((3, 2))
print('a is {}'.format(a))
'''
out:
a is tensor([[0., 0.],
        [0., 0.],
        [0., 0.]])
'''
a = torch.randn((3, 2))  # entries drawn from a standard normal distribution
print('a is {}'.format(a))
'''
out (random, so the values differ on every run):
a is tensor([[ 0.9284,  0.4900],
        [ 0.3578, -1.0652],
        [ 0.5255, -1.2100]])
'''
</code>
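The main practical difference from an ndarray is that a Tensor can live on the GPU. The following is a minimal sketch of the usual device-agnostic pattern; it assumes nothing beyond a standard PyTorch install and falls back to the CPU when no CUDA device is available:

```python
import torch

# Use the GPU if one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

a = torch.randn(3, 2)
a_dev = a.to(device)          # move (copy) the tensor to the chosen device
print(a_dev.device)

b = (a_dev * 2).cpu()         # compute on that device, then bring the result back
print(b.numpy().shape)        # .numpy() only works on CPU tensors
```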
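1.2 Variable
Variable is a rather special concept in PyTorch. It is not fundamentally different from a Tensor: to turn a Tensor x into a Variable, you just call torch.autograd.Variable(x). The difference is that a Variable carries three important attributes: data, grad and grad_fn. data holds the Tensor stored inside the Variable. grad holds the gradient of the Variable computed by backpropagation. grad_fn records the operation (addition, subtraction, multiplication, division, etc.) that produced the Variable. Since PyTorch 0.4, Variable has been merged into Tensor: Variable(x) simply returns a Tensor, and every Tensor carries the data, grad and grad_fn attributes itself.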
<code class="prism language-python has-numbering">import torch
from torch.autograd import Variable

x = Variable(torch.Tensor([1]), requires_grad=True)
</code>
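Here is a minimal sketch of data, grad and grad_fn in action. The function y = x² + 3x and the value x = 2 are arbitrary choices for illustration. Calling backward() fills x.grad with dy/dx = 2x + 3 = 7, and y.grad_fn records the addition that produced y:

```python
import torch
from torch.autograd import Variable

# Since PyTorch 0.4 this returns a plain Tensor with requires_grad=True
x = Variable(torch.Tensor([2.0]), requires_grad=True)
y = x * x + 3 * x        # y = x^2 + 3x

y.backward()             # backpropagate: dy/dx = 2x + 3

print(x.data)            # tensor([2.])
print(x.grad)            # tensor([7.])
print(y.grad_fn)         # the addition node that produced y
print(x.grad_fn)         # None: x was created by the user, not by an operation
```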
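When you construct a Variable you can pass the argument requires_grad, which defaults to False. Setting it to True means gradients will be computed for this Variable during backpropagation.
1.3 Dataset
Data loading and preprocessing are a foundational step in any deep learning problem. PyTorch provides several tools to help: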
- Subclass torch.utils.data.Dataset and override its __len__ and __getitem__ methods, for example:
<code class="prism language-python has-numbering">import pandas as pd
from torch.utils.data import Dataset

class myDataset(Dataset):
    def __init__(self, csv_file, txt_file, root_dir, other_file):
        self.csv_data = pd.read_csv(csv_file)
        with open(txt_file, 'r') as f:
            data_list = f.readlines()
        self.txt_data = data_list
        self.root_dir = root_dir

    def __len__(self):
        # number of samples = number of rows in the CSV file
        return len(self.csv_data)

    def __getitem__(self, idx):
        # sample idx: row idx of the CSV file together with line idx of the text file
        data = (self.csv_data.iloc[idx], self.txt_data[idx])
        return data
</code>
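- Use torch.utils.data.DataLoader to read the data in batches, optionally with several worker processes (num_workers).
- For details, see the article "PyTorch Source Code Walkthrough (1): torch.utils.data.DataLoader". For its use in a worked example, see Section 3.3 of the previous article, "Deep Learning Made Easy (1): Getting Started with Deep Learning in PyTorch".
- The torchvision package also provides data-loading classes for computer vision, such as ImageFolder, which is mainly used for loading images. For usage details, see Section 3.2 of "Deep Learning Made Easy (1): Getting Started with Deep Learning in PyTorch".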
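The following minimal sketch shows how a Dataset and a DataLoader fit together. SquaresDataset is a hypothetical in-memory dataset made up for illustration, not a PyTorch class. A Dataset subclass only needs __len__ and __getitem__, and DataLoader groups the samples into batches. With num_workers=0 the data is loaded in the main process; a value above 0 starts worker processes.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the pair (i, i**2)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n                          # number of samples

    def __getitem__(self, idx):
        x = torch.tensor([float(idx)])
        y = torch.tensor([float(idx ** 2)])
        return x, y                            # one (input, target) pair

dataset = SquaresDataset(10)
loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0)

for xb, yb in loader:
    print(xb.shape, yb.shape)                  # batches of 4, 4, then 2 samples
```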
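1.4 Module (nn.Module)
nn.Module is one of the core tools for building neural networks in PyTorch, and the torch.nn package contains the network layers and loss functions. Every model is built by subclassing nn.Module.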
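The following is a minimal sketch of subclassing nn.Module; TinyNet is a hypothetical name. You define the layers in __init__ and the computation in forward(). Layers assigned as attributes are registered automatically, so their parameters appear in model.parameters(), which is what an optimizer receives.

```python
import torch
from torch import nn

class TinyNet(nn.Module):
    """Hypothetical minimal model, used only to illustrate nn.Module."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)      # assigning a layer as an attribute registers it

    def forward(self, x):
        return self.fc(x)              # calling model(x) runs forward()

model = TinyNet()
criterion = nn.MSELoss()               # loss functions also live in torch.nn

x = torch.randn(3, 4)
target = torch.zeros(3, 2)
loss = criterion(model(x), target)

for name, p in model.named_parameters():
    print(name, tuple(p.shape))        # fc.weight (2, 4) and fc.bias (2,)
print(loss.item())
```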
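1.5 Optimization (torch.optim)
Optimization is the strategy for updating a model's parameters. Broadly speaking, optimization algorithms fall into two categories:
- First-order optimization algorithms. These use only first derivatives (the gradient); the most commonly used one is gradient descent.
- Second-order optimization algorithms. These use second derivatives, but computing them is too expensive.
The torch.optim package implements a variety of optimization algorithms, such as stochastic gradient descent (SGD), SGD with momentum, and adaptive learning-rate methods. For example: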
<code class="prism language-python has-numbering">import torch

# assumes `model` (an nn.Module) and `loss` (a scalar computed from its output) are already defined
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# the model's parameters are passed to the optimizer: SGD with learning rate 0.01 and momentum 0.9
optimizer.zero_grad()  # zero the gradients before each optimization step
loss.backward()        # backpropagation: autograd computes the gradient of every parameter
optimizer.step()       # take one parameter-update step using those gradients
</code>
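Below is a complete, runnable version of this loop on a toy problem. Fitting the line y = 2x + 1 is an assumption for illustration and does not come from the original. For reference, PyTorch's SGD with momentum keeps a velocity v for each parameter p. With gradient g and momentum μ, each step updates v ← μv + g and then p ← p − lr·v.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy data (illustrative assumption): y = 2x + 1 plus a little noise
x = torch.linspace(-1, 1, 100).unsqueeze(1)   # shape (100, 1)
y = 2 * x + 1 + 0.05 * torch.randn(x.size())

model = nn.Linear(1, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

for step in range(200):
    loss = criterion(model(x), y)   # forward pass
    optimizer.zero_grad()           # clear gradients left over from the previous step
    loss.backward()                 # compute new gradients
    optimizer.step()                # update weight and bias

print(model.weight.item(), model.bias.item())  # should end up close to 2 and 1
```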
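1.6 Saving and loading models
PyTorch provides two ways to save a model, each with a corresponding way to load it. The first saves both the model's structure and its parameters, so the object saved is the model itself. For large networks this method is slower to load and the file takes up more storage. The saved file also depends on the model's class definition, which must be importable when loading.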
<code class="prism language-python has-numbering"># save the whole model (structure + parameters)
torch.save(model, './model.pth')
# load it back; the model's class must be importable here.
# weights_only=False is required from PyTorch 2.6 on (the argument exists since 1.13)
load_model = torch.load('./model.pth', weights_only=False)
</code>
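The second method saves only the model's parameters, so the object saved is the model's state, model.state_dict(). To load it, first build a model with the same architecture, then load the parameters into it: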
<code class="prism language-python has-numbering"># save only the parameters
torch.save(model.state_dict(), './model_state.pth')
# load: `model` must already be an instance of the same architecture
model.load_state_dict(torch.load('./model_state.pth'))
</code>
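A save/load round trip is a quick way to check the second method. This sketch uses a small nn.Linear as a stand-in model and a temporary directory instead of the paths above. The key point is that load_state_dict() needs a freshly built model of the same architecture to load into.

```python
import os
import tempfile

import torch
from torch import nn

model = nn.Linear(3, 1)                      # stand-in for any nn.Module

path = os.path.join(tempfile.mkdtemp(), 'model_state.pth')
torch.save(model.state_dict(), path)         # save the parameters only

new_model = nn.Linear(3, 1)                  # rebuild the same architecture first
new_model.load_state_dict(torch.load(path))  # then load the saved parameters

print(list(model.state_dict().keys()))       # ['weight', 'bias']
print(torch.equal(model.weight, new_model.weight))  # True
```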
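2. Case study: MNIST handwritten digit classification with a multilayer fully connected network
2.1 Defining a simple three-layer fully connected network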
<code class="prism language-python has-numbering">import torch.nn as nn

class simpleNet(nn.Module):
    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        super(simpleNet, self).__init__()
        self.layer1 = nn.Linear(in_dim, n_hidden_1)
        self.layer2 = nn.Linear(n_hidden_1, n_hidden_2)
        self.layer3 = nn.Linear(n_hidden_2, out_dim)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x
</code>
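This defines a three-layer fully connected network. Its constructor arguments are the input dimension, the number of neurons in the first layer, the number of neurons in the second layer, and the number of neurons in the third (output) layer. A fully connected neural network is shown in the figure below.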
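The sketch below shows how the constructor arguments map to shapes. It instantiates the network with the sizes used later for MNIST (784 inputs, hidden layers of 300 and 100 neurons, 10 outputs) and runs a random batch through it. The class is repeated so the snippet runs on its own; its forward() is written more compactly but computes the same thing.

```python
import torch
import torch.nn as nn

class simpleNet(nn.Module):
    # same three-layer network as above, repeated so this snippet is self-contained
    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        super(simpleNet, self).__init__()
        self.layer1 = nn.Linear(in_dim, n_hidden_1)
        self.layer2 = nn.Linear(n_hidden_1, n_hidden_2)
        self.layer3 = nn.Linear(n_hidden_2, out_dim)

    def forward(self, x):
        return self.layer3(self.layer2(self.layer1(x)))

model = simpleNet(28 * 28, 300, 100, 10)
x = torch.randn(64, 28 * 28)        # a fake batch of 64 flattened 28x28 images
print(model(x).shape)               # torch.Size([64, 10])

# weights + biases: (784*300 + 300) + (300*100 + 100) + (100*10 + 10)
n_params = sum(p.numel() for p in model.parameters())
print(n_params)                     # 266610
```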
2.2 Improving the network: adding activation functions
<code class="prism language-python has-numbering">class Activation_Net(nn.Module):
    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        super(Activation_Net, self).__init__()
        # Linear layer followed by ReLU; ReLU(True) means inplace=True
        self.layer1 = nn.Sequential(nn.Linear(in_dim, n_hidden_1), nn.ReLU(True))
        self.layer2 = nn.Sequential(nn.Linear(n_hidden_1, n_hidden_2), nn.ReLU(True))
        # no activation on the output layer: nn.CrossEntropyLoss expects raw scores (logits)
        self.layer3 = nn.Sequential(nn.Linear(n_hidden_2, out_dim))

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x
</code>
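All it takes is adding an activation function after the output of each hidden layer. nn.Sequential() combines each layer's pieces into a single module, self.layer1, self.layer2 or self.layer3. The output layer deliberately has no activation. Without activation functions, simpleNet's three linear layers compose into a single linear map. The nonlinearities are what let the network model non-linear relationships.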
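You can check numerically that stacked linear layers collapse into one. In this sketch the layer sizes are arbitrary. Three nn.Linear layers are merged into a single weight W = W3·W2·W1 and bias b = W3·(W2·b1 + b2) + b3, and the outputs match. A ReLU between the layers breaks this equivalence in general.

```python
import torch
from torch import nn

torch.manual_seed(0)
l1, l2, l3 = nn.Linear(8, 6), nn.Linear(6, 4), nn.Linear(4, 3)

# Merge the three linear layers into one: y = x W^T + b
W = l3.weight @ l2.weight @ l1.weight
b = l3.weight @ (l2.weight @ l1.bias + l2.bias) + l3.bias

x = torch.randn(5, 8)
stacked = l3(l2(l1(x)))
single = x @ W.T + b
print(torch.allclose(stacked, single, atol=1e-5))   # True: no extra expressive power

# With ReLU between the layers, the network is no longer a single linear map
with_relu = l3(torch.relu(l2(torch.relu(l1(x)))))
```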
2.3 Improving it further: adding batch normalization
<code class="prism language-python has-numbering">class Batch_Net(nn.Module):
    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        super(Batch_Net, self).__init__()
        self.layer1 = nn.Sequential(nn.Linear(in_dim, n_hidden_1), nn.BatchNorm1d(n_hidden_1), nn.ReLU(True))
        self.layer2 = nn.Sequential(nn.Linear(n_hidden_1, n_hidden_2), nn.BatchNorm1d(n_hidden_2), nn.ReLU(True))
        self.layer3 = nn.Sequential(nn.Linear(n_hidden_2, out_dim))

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x
</code>
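Again, nn.Sequential() is used to add nn.BatchNorm1d() to each hidden layer. Batch normalization is usually placed after the fully connected layer and before the nonlinearity (the activation function). During training, BatchNorm normalizes each feature over the current mini-batch to zero mean and unit variance, then applies a learnable scale and shift. This keeps the distribution of each layer's inputs stable while a deep network trains.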
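The sketch below shows what nn.BatchNorm1d does to a batch in training mode; the feature count and input distribution are arbitrary. It also shows that evaluation mode behaves differently, because it normalizes with running statistics instead of the current batch. For this reason, a network containing BatchNorm, such as Batch_Net, should be switched with model.train() and model.eval() as appropriate.

```python
import torch
from torch import nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(3)                  # one mean/variance pair per feature

x = torch.randn(64, 3) * 5 + 10         # 64 samples, features far from zero mean / unit variance
bn.train()                              # training mode: normalize with this batch's statistics
y = bn(x)
print(y.mean(dim=0))                    # close to 0 for every feature
print(y.std(dim=0))                     # close to 1 (scale=1, shift=0 at initialization)

bn.eval()                               # eval mode: normalize with the running mean/variance
y_eval = bn(x)                          # after one batch these are still far from the batch stats
```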
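2.4 Training the network
First, import the required packages. net is the module (net.py) that contains the three network models above.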
<code class="prism language-python has-numbering">import torch
from torch import nn, optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import net  # net.py defines simpleNet, Activation_Net and Batch_Net from above
</code>
Next, set some of the model's hyperparameters:
<code class="prism language-python has-numbering"># hyperparameters
batch_size = 64
learning_rate = 1e-2
num_epoches = 20
</code>
Define the preprocessing:
<code class="prism language-python has-numbering"># data preprocessing
data_tf = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
</code>
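torchvision.transforms provides many image preprocessing methods. transforms.ToTensor() converts an image (a PIL Image or an H×W×C NumPy array) into a PyTorch Tensor of shape C×H×W. For 8-bit images, it also scales the pixel values from [0, 255] to [0.0, 1.0]. transforms.Normalize() takes two arguments: the first is the mean and the second is the standard deviation. It subtracts the mean and then divides by the standard deviation, so Normalize([0.5], [0.5]) maps [0, 1] to [-1, 1]. transforms.Compose() chains the individual preprocessing steps together.
Because these are grayscale images there is only one channel, hence transforms.Normalize([0.5], [0.5]). Color images have three channels, so you would write transforms.Normalize([a, b, c], [d, e, f]) to give the mean and standard deviation of each channel.
Next, download the dataset and read in the data. torch.utils.data.DataLoader builds a data iterator from a dataset and a batch_size. shuffle=True shuffles the data each time the loader is iterated over.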
<code class="prism language-python has-numbering"># download the MNIST handwritten-digit dataset into ./data
train_dataset = datasets.MNIST(root="./data", train=True, transform=data_tf, download=True)
test_dataset = datasets.MNIST(root="./data", train=False, transform=data_tf)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
</code>
Next, import the network and define the loss function and the optimizer:
<code class="prism language-python has-numbering">model = net.simpleNet(28*28, 300, 100, 10)
if torch.cuda.is_available():
    model = model.cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
</code>
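The loss function and optimizer defined above can be checked on a toy model. The sketch below assumes a hypothetical nn.Linear(4, 10) in place of simpleNet and uses made-up inputs. It verifies two things. First, nn.CrossEntropyLoss takes raw scores plus integer class labels and equals log-softmax followed by the negative log-likelihood. Second, one plain optim.SGD step (no momentum, no weight decay) moves each parameter to p - lr * p.grad.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(0)

# Hypothetical stand-in for simpleNet: 4 input features -> 10 class scores.
model = nn.Linear(4, 10)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(3, 4)            # batch of 3 made-up samples
label = torch.tensor([2, 7, 0])  # integer class indices, not one-hot vectors

out = model(x)                   # raw scores (logits), shape (3, 10)
loss = criterion(out, label)

# CrossEntropyLoss = log_softmax + negative log-likelihood, averaged over the batch.
manual = -F.log_softmax(out, dim=1)[torch.arange(3), label].mean()
print(torch.allclose(loss, manual))  # True

# One plain SGD step: each parameter p becomes p - lr * p.grad.
optimizer.zero_grad()
loss.backward()
w_before = model.weight.detach().clone()
grad = model.weight.grad.detach().clone()
optimizer.step()
print(torch.allclose(model.weight.detach(), w_before - 0.1 * grad))  # True
```

Because the loss applies log-softmax internally, the network's last layer should output raw scores rather than probabilities.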
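net.simpleNet(28*28, 300, 100, 10) builds the simple three-layer network with 28*28 = 784 inputs, hidden layers of 300 and 100 neurons, and an output layer of 10 neurons. The output layer has 10 neurons because recognizing handwritten digits is a multi-class problem with ten classes, the digits 0-9. Now we can start training the network: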
<code class="prism language-python has-numbering"># Train the model
def train_model(model, criterion, optimizer, num_epoches):
    for epoch in range(num_epoches):
        print('epoch {}/{}'.format(epoch, num_epoches - 1))
        print('-' * 10)
        ## training ------------------
        # Switch back to training mode: the evaluation step below calls model.eval(),
        # which would otherwise leave BatchNorm/Dropout layers in eval mode next epoch
        model.train()
        train_loss = 0.0
        train_acc = 0.0
        for data in train_loader:  # one batch of samples
            img, label = data  # images and labels
            img = img.view(img.size(0), -1)  # flatten each 1x28x28 image into 784 values
            # Since PyTorch 0.4 tensors track gradients themselves; no Variable wrapper is needed
            if torch.cuda.is_available():
                img = img.cuda()
                label = label.cuda()
            # Reset the gradients
            optimizer.zero_grad()
            # Forward pass
            out = model(img)  # runs model.forward(img) (plus any registered hooks)
            loss = criterion(out, label)
            _, preds = torch.max(out, 1)  # predicted class = index of the largest score
            # Backward pass and parameter update
            loss.backward()
            optimizer.step()
            # Statistics
            train_loss += loss.item()
            train_acc += torch.sum(preds == label).item()
        # Loss: mean of per-batch losses; accuracy: correct samples / total samples
        print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / len(train_loader),
                                                       train_acc / len(train_dataset)))
        ## evaluation -------------
        model.eval()
        eval_loss = 0.0
        eval_acc = 0.0
        with torch.no_grad():  # no gradients during evaluation (replaces volatile=True)
            for data in test_loader:
                img, label = data
                img = img.view(img.size(0), -1)
                if torch.cuda.is_available():
                    img = img.cuda()
                    label = label.cuda()
                out = model(img)
                loss = criterion(out, label)
                eval_loss += loss.item()
                _, preds = torch.max(out, 1)
                eval_acc += torch.sum(preds == label).item()
        print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / len(test_loader),
                                                      eval_acc / len(test_dataset)))
</code>
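The loop relies on two switches: model.train()/model.eval() and torch.no_grad(). The sketch below shows why they matter, using a standalone nn.BatchNorm1d layer as a stand-in for the batch-normalized network of section 2.3. In training mode the layer normalizes with the current batch's statistics. In eval mode it uses its running estimates, so the same input gives a different output. Under torch.no_grad() no computation graph is recorded. The layer size and the random input are assumptions for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for a batch-normalized layer in the network.
bn = nn.BatchNorm1d(3)
x = torch.randn(8, 3) * 5 + 10   # a batch far from zero mean / unit variance

bn.train()
y_train = bn(x)                  # normalized with this batch's own mean and variance
print(y_train.mean(dim=0))       # approximately zero for every feature

bn.eval()
with torch.no_grad():            # no graph is built, so nothing here can be backpropagated
    y_eval = bn(x)               # normalized with the running statistics instead
print(y_eval.requires_grad)      # False
print(torch.allclose(y_train, y_eval))  # False: the two modes give different outputs
```

If model.train() were never called again after model.eval(), every later epoch would train with eval-mode BatchNorm (and Dropout switched off).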
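The call img.view(img.size(0), -1) in the loop works much like reshape: it changes a tensor's shape without changing its data. Here it keeps the batch dimension and flattens each 1x28x28 image into a single row of 784 values, the input shape the fully connected network expects. For an example, see: view() function in PyTorch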
<code class="prism language-python has-numbering">import torch

a = torch.zeros(2, 3)  # torch.Tensor(2, 3) would leave the memory uninitialized
print(a)
# tensor([[0., 0., 0.],
#         [0., 0., 0.]])
print(a.view(1, -1))  # merge both rows into one
# tensor([[0., 0., 0., 0., 0., 0.]])
</code>
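The short sketch below illustrates two further details about view(), using a hypothetical all-zero batch. First, passing -1 lets PyTorch infer the flattened size (1*28*28 = 784 for an MNIST batch). Second, view() needs a memory layout compatible with the new shape. It therefore fails on a transposed (non-contiguous) tensor, where reshape() falls back to copying.

```python
import torch

# Hypothetical batch shaped like the MNIST loader output: (batch, channel, height, width).
img = torch.zeros(64, 1, 28, 28)
flat = img.view(img.size(0), -1)   # -1 lets PyTorch infer 1 * 28 * 28 = 784
print(flat.shape)                  # torch.Size([64, 784])

# view() only works when the memory layout allows it; a transpose breaks that.
t = torch.arange(6).view(2, 3).t()  # shape (3, 2), no longer contiguous
view_ok = True
try:
    t.view(-1)
except RuntimeError:
    view_ok = False
print(view_ok)                     # False
print(t.reshape(-1))               # reshape copies when needed: tensor([0, 3, 1, 4, 2, 5])
```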
Finally, call the training function to train the network:
<code class="prism language-python has-numbering">train_model(model, criterion, optimizer, num_epoches)
</code>
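Accuracy should be the number of correct predictions divided by the number of samples, not by the number of batches len(loader). Below is a minimal sketch of that bookkeeping, using a hypothetical helper epoch_accuracy and two made-up batches of 3-class scores.

```python
import torch

def epoch_accuracy(batches):
    """Hypothetical helper: batches is an iterable of (scores, labels) pairs.
    Returns the fraction of samples whose highest-scoring class matches the label."""
    correct = 0
    total = 0
    for out, label in batches:
        _, preds = torch.max(out, 1)          # index of the largest score = predicted class
        correct += (preds == label).sum().item()
        total += label.size(0)                # count samples, not batches
    return correct / total

# Two made-up batches of 3-class scores (sizes 2 and 1).
batches = [
    (torch.tensor([[2.0, 0.1, 0.3], [0.2, 0.1, 1.5]]), torch.tensor([0, 1])),
    (torch.tensor([[0.0, 3.0, 1.0]]), torch.tensor([1])),
]
print(epoch_accuracy(batches))  # 0.6666666666666666 (2 of 3 samples correct)
```

Here 2 of the 3 samples are classified correctly. Dividing those 2 correct predictions by the 2 batches would report 1.0, which is not an accuracy at all.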
2.5 Comparison of the Three Neural Network Models
Each of the three models created above was trained for 20 epochs, giving the following accuracies: