• 欢迎访问开心洋葱网站,在线教程,推荐使用最新版火狐浏览器和Chrome浏览器访问本网站,欢迎加入开心洋葱 QQ群
  • 为方便开心洋葱网用户,开心洋葱官网已经开启复制功能!
  • 欢迎访问开心洋葱网站,手机也能访问哦~欢迎加入开心洋葱多维思维学习平台 QQ群
  • 如果您觉得本站非常有看点,那么赶紧使用Ctrl+D 收藏开心洋葱吧~~~~~~~~~~~~~!
  • 由于近期流量激增,小站的ECS没能经的起亲们的访问,本站依然没有盈利,如果各位看如果觉着文字不错,还请看官给小站打个赏~~~~~~~~~~~~~!

深度学习模型调优方法详细解析(Deep Learning学习记录)

人工智能 Charzous 1516次浏览 0个评论

深度学习模型的调优,首先需要对各方面进行评估,主要包括定义函数、模型在训练集和测试集拟合效果、交叉验证、激活函数和优化算法的选择等。   那如何对我们自己的模型进行判断呢?——通过模型训练跑代码,我们可以分别从训练集和测试集上看到这个模型造成的损失大小(loss),还有它的精确率(accuracy)。

目录

  • 前言
  • 1、定义模型函数
  • 2、交叉验证(Cross-validation)
  • 3、优化算法
  • 4、激活函数(activation)
  • 5、dropout
  • 6、early stopping
  • 模型训练实战案例

前言

  1. 最初可从分析数据集的特征,选择适合的函数以及优化器。
  2. 其次,模型在训练集上的效果。至少模型在训练集上得到理想的效果,模型优化方向:激活函数(activation function)、学习率(learning rate)。
  3. 再者,模型在训练集上得到好的效果之后,对测试集上的拟合程度更为重要,因为我们训练之后目的就是去预测并达到好的结果。模型优化方向:正则化(regularization)、丢弃参数(dropout)、提前停止训练(early stopping)。

  以下介绍具体的方法和过程,图文结合,一目了然,这样方便学习记忆!

1、定义模型函数

  衡量方法通常用到真实f与预测f*的方差(variance)和偏差(bias),最后对数据集的拟合程度可分为4种情况。如下:  

深度学习模型调优方法详细解析(Deep Learning学习记录)

这好比数据模型预测对准靶心后的偏移和分散程度,期望达到的效果就是(low,low),模型的偏差相当于与目标的偏离程度,而方差就是数据之间的分散程度。  
深度学习模型调优方法详细解析(Deep Learning学习记录)   定义一个模型函数后结果会遇到(low,high)和(high,low),显然,如果是(high,high)说明函数完全与模型不匹配,可重新考虑其他函数模型。那出现其他情况如何评估这个函数呢?  

  1. 小偏差,大方差:所谓模型过拟合,在训练数据上得到好的效果,而在测试集上效果并不满意。优化方向:增加数据量;数据正则化处理。增加数据可以提高模型的鲁棒性,不被特殊数据影响整个模型的偏向;正则化是另一种方法,为了减小variance,但直接影响到bias,需要权衡。
  2. 大偏差,小方差:模型过于简单,在训练上没有得到好的效果。优化:增加模型参数(特征),更好去拟合数据。

2、交叉验证(Cross-validation)

  在训练一个模型时候,通常会将数据分为:训练集,测试集,开放集(小型训练集)。这里交叉验证是在训练集上进行的,选择最优模型。  
深度学习模型调优方法详细解析(Deep Learning学习记录)   假设一个划分为:  

深度学习模型调优方法详细解析(Deep Learning学习记录)

这里进行十折交叉验证,每一轮训练去9份数据作为训练集,1份作为测试集,并且每一轮的训练集与测试集对换,实现了所有数据都作为样本训练,所得到的模型避免了过拟合与低拟合的问题。  
交叉验证演示   K折交叉验证类似。 优点: (1)每一轮训练中几乎所有的样本数据用于训练模型,这样最接近原始样本的分布,评估所得的结果比较可靠。 (2)训练模型过程中较少随机特殊因素会影响实验数据,鲁棒性更好。  

3、优化算法

  模型优化算法选择根据要训练的模型以及数据,选择合适的算法,常用优化算法有:Gradient descent,Stochastic gradient descent,Adagrad,Adam,RMSprop。   前两种算法原理好理解,这里给出后面三种的算法原理如图,具体就不介绍了,写代码时候通常直接指定算法就行,算法原理理解就OK啦。   Adagrad  
Adagrad   RMSprop  
RMSprop   Adam  
Adam

4、激活函数(activation)

  选择合适的激活函数对模型训练也有很大的影响,需要适应训练任务,比如任务是二分类、多分类或许更新参数梯度问题等。常用的有:sigmoid、tanh、relu、softmax。   sigmoid  
sigmoid   功能特点:平滑函数,连续可导,适合二分类,存在梯度消失问题。   tanh
tanh     功能特点:与sigmoid相同的缺点,存在梯度消失,梯度下降的速度变慢。一般用在二分问题输出层,不在隐藏层中使用。   relu  
reluc   功能特点:ReLU在神经网络中使用最广泛的激活函数。根据图像x负半轴的特点,relu不会同时激活所有的神经元,如果输入值是负的,ReLU函数会转换为0,而神经元不被激活。这意味着,在一段时间内,只有少量的神经元被激活,神经网络的这种稀疏性使其变得高效且易于计算。   softmax  
softmax   功能特点:又称归一化指数函数。它是二分类函数sigmoid在多分类上的推广,目的是将多分类的结果以概率的形式展现出来。softmax通常在分类器的输出层使用。   在模型训练任务中激活函数通过指定选择,我们理解原理后帮助选择正确的函数,具体无需自己实现。  

5、dropout

  它解决深度神经网络的**过拟合(overfitting)梯度消失(gradient vanishing)**问题。 简单理解就是,当模型在训练集上得到较好的效果,而在测试集效果并不乐观,此时使用dropout对训练时候的参数进行优化调整(减少训练时候的参数),在学习过程中通过将隐含层的部分权重或输出随机归零,降低节点间的相互依赖性,使得模型在测试集上得到较好的结果的一种方法。 相当于运用新的神经网络结构训练模型,在模型训练时候可在每一层指定设计。  

6、early stopping

  在训练时候观察模型在training-set和testing-set上的损失(loss),我们想要得到的模型是在测试时候损失误差更小,训练时候在训练集上可能不是最好的效果,所以需要提前停止保证了模型预测得到较好的结果。理解如图:  
early stopping

模型训练实战案例

 

<code class="prism language-python has-numbering"><span class="token punctuation">(</span>X_train<span class="token punctuation">,</span> Y_train<span class="token punctuation">)</span><span class="token punctuation">,</span> <span class="token punctuation">(</span>X_test<span class="token punctuation">,</span> Y_test<span class="token punctuation">)</span> <span class="token operator">=</span> load_data<span class="token punctuation">(</span><span class="token punctuation">)
</span>model <span class="token operator">=</span> keras<span class="token punctuation">.</span>Sequential<span class="token punctuation">(</span><span class="token punctuation">)</span>
model<span class="token punctuation">.</span>add<span class="token punctuation">(</span>Dense<span class="token punctuation">(</span>input_dim<span class="token operator">=</span><span class="token number">28</span> <span class="token operator">*</span> <span class="token number">28</span><span class="token punctuation">,</span> units<span class="token operator">=</span><span class="token number">690</span><span class="token punctuation">,</span>
                activation<span class="token operator">=</span><span class="token string">'relu'</span><span class="token punctuation">)</span><span class="token punctuation">)</span>  <span class="token comment"># tanh  activation:Sigmoid、tanh、ReLU、LeakyReLU、pReLU、ELU、maxout、softmax</span>
model<span class="token punctuation">.</span>add<span class="token punctuation">(</span>Dense<span class="token punctuation">(</span>units<span class="token operator">=</span><span class="token number">690</span><span class="token punctuation">,</span> activation<span class="token operator">=</span><span class="token string">'relu'</span><span class="token punctuation">)</span><span class="token punctuation">)</span>
model<span class="token punctuation">.</span>add<span class="token punctuation">(</span>Dense<span class="token punctuation">(</span>units<span class="token operator">=</span><span class="token number">690</span><span class="token punctuation">,</span> activation<span class="token operator">=</span><span class="token string">'relu'</span><span class="token punctuation">)</span><span class="token punctuation">)</span>  <span class="token comment"># tanh</span>

model<span class="token punctuation">.</span>add<span class="token punctuation">(</span>Dense<span class="token punctuation">(</span>units<span class="token operator">=</span><span class="token number">10</span><span class="token punctuation">,</span> activation<span class="token operator">=</span><span class="token string">'softmax'</span><span class="token punctuation">)</span><span class="token punctuation">)</span>
model<span class="token punctuation">.</span><span class="token builtin">compile</span><span class="token punctuation">(</span>loss<span class="token operator">=</span><span class="token string">'categorical_crossentropy'</span><span class="token punctuation">,</span> optimizer<span class="token operator">=</span><span class="token string">'adam'</span><span class="token punctuation">,</span>
              metrics<span class="token operator">=</span><span class="token punctuation">[</span><span class="token string">'accuracy'</span><span class="token punctuation">]</span><span class="token punctuation">)</span>  <span class="token comment"># loss:mse,categorical_crossentropy,optimizer: rmsprop 或 adagrad、SGD(此处推荐)</span>
model<span class="token punctuation">.</span>fit<span class="token punctuation">(</span>X_train<span class="token punctuation">,</span> Y_train<span class="token punctuation">,</span> batch_size<span class="token operator">=</span><span class="token number">100</span><span class="token punctuation">,</span> epochs<span class="token operator">=</span><span class="token number">20</span><span class="token punctuation">)</span>
result <span class="token operator">=</span> model<span class="token punctuation">.</span>evaluate<span class="token punctuation">(</span>X_train<span class="token punctuation">,</span> Y_train<span class="token punctuation">,</span> batch_size<span class="token operator">=</span><span class="token number">10000</span><span class="token punctuation">)</span>
<span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string">'Train ACC:'</span><span class="token punctuation">,</span> result<span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
result <span class="token operator">=</span> model<span class="token punctuation">.</span>evaluate<span class="token punctuation">(</span>X_test<span class="token punctuation">,</span> Y_test<span class="token punctuation">,</span> batch_size<span class="token operator">=</span><span class="token number">10000</span><span class="token punctuation">)</span>
<span class="token keyword">print</span><span class="token punctuation">(</span><span class="token string">'Test ACC:'</span><span class="token punctuation">,</span> result<span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">)</span>
</code>

 

案例中深度学习模型调优总结:  
深度学习模型调优方法详细解析(Deep Learning学习记录)  
深度学习模型调优方法详细解析(Deep Learning学习记录)  

禁止商业转载,非商业转载请注明出处! 本文链接:https://blog.csdn.net/Charzous/article/details/107822055 版权声明:本文为博主原创文章,转载请附上原文出处链接和本声明。


开心洋葱 , 版权所有丨如未注明 , 均为原创丨未经授权请勿修改 , 转载请注明深度学习模型调优方法详细解析(Deep Learning学习记录)
喜欢 (0)

您必须 登录 才能发表评论!

加载中……