TensorFlow 2.0 (五) - mnist手写数字识别(CNN卷积神经网络)
源代码/数据集已上传到Github - tensorflow-tutorial-samples
大白话讲解卷积神经网络工作原理,推荐一个bilibili的讲卷积神经网络的视频,up主从youtube搬运过来,用中文讲了一遍。
这篇文章是TensorFlow 2.0 Tutorial入门教程的第五篇文章,介绍如何使用卷积神经网络(Convolutional Neural Network,CNN)来提高mnist手写数字识别的准确性。之前使用了最简单的784x10的神经网络,达到了0.91
的正确性,而这篇文章在使用了卷积神经网络后,正确性达到了0.99
卷积神经网络(Convolutional Neural Network,CNN)是一种前馈神经网络,它的人工神经元可以响应一部分覆盖范围内的周围单元,对于大型图像处理有出色表现。
卷积神经网络由一个或多个卷积层和顶端的全连通层(对应经典的神经网络)组成,同时也包括关联权重和池化层(pooling layer)。这一结构使得卷积神经网络能够利用输入数据的二维结构。与其他深度学习结构相比,卷积神经网络在图像和语音识别方面能够给出更好的结果。这一模型也可以使用反向传播算法进行训练。相比较其他深度、前馈神经网络,卷积神经网络需要考量的参数更少,使之成为一种颇具吸引力的深度学习结构。
——维基百科
1. 安装TensorFlow 2.0
Google与3月发布了TensorFlow 2.0,TensorFlow 2.0 清理了废弃的API,通过减少重复来简化API,并且通过Keras能够轻松地构建模型,从这篇文章开始,教程示例采用TensorFlow 2.0
版本。
或者在这里下载whl包安装:https://pypi.tuna./simple/tensorflow/
2. 代码目录结构
3. CNN模型代码(train.py)
模型定义的前半部分主要使用Keras.layers提供的Conv2D
(卷积)与MaxPooling2D
(池化)函数。
CNN的输入是维度为 (image_height, image_width, color_channels)的张量,mnist数据集是黑白的,因此只有一个color_channel
(颜色通道),一般的彩色图片有3个(R,G,B),熟悉Web前端的同学可能知道,有些图片有4个通道(R,G,B,A),A代表透明度。对于mnist数据集,输入的张量维度就是(28,28,1),通过参数input_shape
传给网络的第一层。
model.summary()
用来打印我们定义的模型的结构。
我们可以看到,每一个Conv2D
和MaxPooling2D
层的输出都是一个三维的张量(height, width, channels)。height和width会逐渐地变小。输出的channel的个数,是由第一个参数(例如,32或64)控制的,随着height和width的变小,channel可以变大(从算力的角度)。
模型的后半部分,是定义输出张量的。layers.Flatten
会将三维的张量转为一维的向量。展开前张量的维度是(3, 3, 64) ,转为一维(576)的向量后,紧接着使用layers.Dense
层,构造了2层全连接层,逐步地将一维向量的位数从576变为64,再变为10。
后半部分相当于是构建了一个隐藏层为64,输入层为576,输出层为10的普通的神经网络。最后一层的激活函数是softmax
,10位恰好可以表达0-9十个数字。
最大值的下标即可代表对应的数字,使用numpy
很容易计算出来:
4. mnist数据集预处理(train.py)
因为mnist数据集国内下载不稳定,因此数据集也同步到了Github仓库。
对mnist数据集的介绍,大家可以参考这个系列的第一篇文章TensorFlow入门(一) - mnist手写数字识别(网络搭建)。
5. 开始训练并保存训练结果(train.py)
在执行python train.py
后,会得到以下的结果:
可以看到,在第一轮训练后,识别准确率达到了0.9536,5轮之后,使用测试集验证,准确率达到了0.9901
在第五轮时,模型参数成功保存在了./ckpt/cp-0005.ckpt
。接下来我们就可以加载保存的模型参数,恢复整个卷积神经网络,进行真实图片的预测了。
6. 图片预测(predict.py)
为了将模型的训练和加载分开,预测的代码写在了predict.py
中。
最终,执行predict.py,可以看到: