• 欢迎访问开心洋葱网站,在线教程,推荐使用最新版火狐浏览器和Chrome浏览器访问本网站,欢迎加入开心洋葱 QQ群
  • 为方便开心洋葱网用户,开心洋葱官网已经开启复制功能!
  • 欢迎访问开心洋葱网站,手机也能访问哦~欢迎加入开心洋葱多维思维学习平台 QQ群
  • 如果您觉得本站非常有看点,那么赶紧使用Ctrl+D 收藏开心洋葱吧~~~~~~~~~~~~~!
  • 由于近期流量激增,小站的ECS没能经的起亲们的访问,本站依然没有盈利,如果各位看如果觉着文字不错,还请看官给小站打个赏~~~~~~~~~~~~~!

【tensorflow2.0】fashion mnist 数据集训练

人工智能 我是。 2763次浏览 0个评论

数据集介绍

  使用Fashion MNIST数据集,其中包含10个类别的70,000个灰度图像。图像显示了低分辨率(28 x 28像素)的单个衣​​物,如下所示(图片来自tensorflow官方文档):  
【tensorflow2.0】fashion mnist 数据集训练   图像是28×28 NumPy数组,像素值范围是0到255。标签是整数数组,范围是0到9。这些对应于图像表示的衣服类别:  
【tensorflow2.0】fashion mnist 数据集训练  

代码

 

import tensorflow as tf
import pandas as pd
import matplotlib as mlt
import matplotlib.pyplot as plt
print(tf.__version__)
print(tf.test.is_gpu_available())
# 加载mnist数据集
fashion_mnist = tf.keras.datasets.fashion_mnist
(X_train_all, Y_train_all),(X_test, Y_test) = fashion_mnist.load_data()
X_train_all = X_train_all/255
X_test = X_test/255
# 将训练集拆分出验证集,让模型每跑完一次数据就验证一次准确度
x_valid, x_train  = X_train_all[:5000], X_train_all[5000:]
y_valid, y_train  = Y_train_all[:5000], Y_train_all[5000:]
# 模型构建 使用的是tf.keras.Sequential
# relu:y=max(0,x) 即取0和x中的最大值
# softmax: 将输出向量变成概率分布,例如 x = [x1, x2, x3], 则
#                                     y = [e^x1/sum, e^x2/sum, e^x3/sum],
#                                     sum = e^x1+e^x2+e^x3
model = tf.keras.models.Sequential(
    [
        tf.keras.layers.Flatten(input_shape=(28,28)), # Flatten函数的作用是将输入的二维数组进行展开,使其变成一维的数组
        tf.keras.layers.Dense(256,activation='relu'), # 创建权连接层,激活函数使用relu
        tf.keras.layers.Dropout(0.2),                 # 使用dropout缓解过拟合的发生
        tf.keras.layers.Dense(10, activation='softmax') # 输出层
    ]
)
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy', # 损失函数使用交叉熵
              metrics=['accuracy'])
model.summary() # 打印模型信息
# history记录模型训练过程中的一些值
history = model.fit(x_train, y_train, epochs=5,
                    validation_data=(x_valid,y_valid))
print('history:',history.history)
# 将history中的数据以图片表示出来
pd.DataFrame(history.history).plot(figsize=(8, 5))
plt.grid(True)
plt.ylim(0,1)
plt.show()
model.evaluate(X_test,  Y_test, verbose=2)

  模型结构  

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
flatten (Flatten)            (None, 784)               0
_________________________________________________________________
dense (Dense)                (None, 256)               200960
_________________________________________________________________
dropout (Dropout)            (None, 256)               0
_________________________________________________________________
dense_1 (Dense)              (None, 10)                2570
=================================================================
Total params: 203,530
Trainable params: 203,530
Non-trainable params: 0
_________________________________________________________________

  训练过程  

Train on 55000 samples, validate on 5000 samples
Epoch 1/5
55000/55000 [==============================] - 6s 106us/sample - loss: 0.5183 - accuracy: 0.8162 - val_loss: 0.3885 - val_accuracy: 0.8598
Epoch 2/5
55000/55000 [==============================] - 5s 95us/sample - loss: 0.3908 - accuracy: 0.8570 - val_loss: 0.3656 - val_accuracy: 0.8696
Epoch 3/5
55000/55000 [==============================] - 5s 95us/sample - loss: 0.3585 - accuracy: 0.8697 - val_loss: 0.3203 - val_accuracy: 0.8836
Epoch 4/5
55000/55000 [==============================] - 5s 95us/sample - loss: 0.3358 - accuracy: 0.8767 - val_loss: 0.3326 - val_accuracy: 0.8796
Epoch 5/5
55000/55000 [==============================] - 5s 98us/sample - loss: 0.3237 - accuracy: 0.8808 - val_loss: 0.3297 - val_accuracy: 0.8824

 
【tensorflow2.0】fashion mnist 数据集训练    


开心洋葱 , 版权所有丨如未注明 , 均为原创丨未经授权请勿修改 , 转载请注明【tensorflow2.0】fashion mnist 数据集训练
喜欢 (0)

您必须 登录 才能发表评论!

加载中……