手写数字识别问题作为机器学习领域中的一个经典问题,本文介绍如何使用 keras 构建卷积神经网络模型实现 MNIST 手写数字识别。文本代码只需更换训练集目录,修改图片输入尺寸和类别数量等少量参数,即可直接应用到其他图像分类的问题中。
关于如何解析 MNIST 数据集,可以参看另一片文章:python 读取 MNIST 数据集,并解析为图片文件。解析出的数据集具有如下目录结构:
mnist_data/
|----train/
|----0/
|----0.png
|----1/
|----1.png
...
|----test/
...
卷积神经网络的概念最早出自 19 世纪 60 年代提出的感受野 ( Receptive Field [1]),到了 20 世纪 80 年代,日本科学家提出神经认知机 ( Neocognitron [2] ) 的概念,是卷积神经网络的实现原型。卷积神经网络可以利用空间结构关系减少需要学习的参数量,提高反向传播算法的训练效率。一般的 CNN 有多个卷积层构成,每个卷积层会进行如下操作:
使用 keras API 可以很容易搭建一个卷积神经网络,在此我们搭建的网络结构具有 重复三重的 3 * 3 卷积核加 2 * 2 最大池化层,最后加两个全连接层。
第一层卷积输入图像尺寸,设置64个卷积核,经过 2 * 2 最大池化层后,第二层卷积层需要 128 个卷积核,同理第三层卷积是 256 个卷积核。然后将矩阵平整为一维,接一层 256 输出的全链接层,最后一层输出为类别标签数的全连接层,并且该层使用 ‘softmax’ 作为激活函数。
def build_model(self):
"""构建网络模型"""
setting = self.setting
model = Sequential()
model.add(Conv2D(64, (3, 3), input_shape=setting.input_shape))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(128, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(256, (3, 3)))
model.add(Activation('relu'))
model.add(Flatten())
model.add(Dropout(0.5))
model.add(Dense(256))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(setting.lebel_nums))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy',
optimizer='rmsprop',
metrics=['accuracy'])
return model
使用 keras 的图像预处理类可以很容易加载图像数据并进行数据预处理,详情可以参考 keras 的官方文档:ImageDataGenerator。将训练集划分 20% 作为验证集,将像素点缩放为 0~1.0 的尺度。
def data_generator(self, setting: Setting):
"""图像数据生成器,对生成的字体图像划分训练集和验证集,同时进行数据增强,提高模型泛化能力。"""
datagen = ImageDataGenerator(
rescale=1.0 / 255,
validation_split=0.2
)
train_generator = datagen.flow_from_directory(
setting.train_dir,
target_size=(setting.img_width, setting.img_height),
batch_size=setting.batch_size,
class_mode='categorical',
subset='training'
)
validation_generator = datagen.flow_from_directory(
setting.train_dir,
target_size=(setting.img_width, setting.img_height),
batch_size=setting.batch_size,
class_mode='categorical',
subset='validation'
)
return train_generator, validation_generator
在此,我将程序用到的一些参数的设置放置于一个设置类中:
class Setting:
"""配置类"""
def __init__(self):
self.data_dir = 'mnist_data/'
self.train_dir = self.data_dir + 'train'
self.test_dir = self.data_dir + 'test'
self.img_width = 28
self.img_height = 28
# 在不同系统中,通道数的位置可能不同,因此图像的输入维度略有不同
if backend.image_data_format() == 'channels_first':
self.input_shape = (3, self.img_width, self.img_height)
else:
self.input_shape = (self.img_width, self.img_height, 3)
# 模型训练过程中,根据验证准确率的提升保存模型权重参数
self.checkpoint_path = self.data_dir + 'model_weight/weights-improvement-{epoch:02d}-{val_acc:.5f}.hdf5'
# 这两个参数可以设置加载现有模型继续重新训练
self.restore_model = None
self.is_restore = False
# 模型应用预测可以将所需识别的图像放置于 tmp 文件夹中
self.imp_dir = self.data_dir + 'tmp'
# 模型训练的配置
self.lebel_nums = 10
self.epoches = 20 # 训练的总轮次
self.save_epoches = 1 # 模型保存的检验轮次
self.batch_size = 32
self.initial_epoch = 0
全部代码见 mnist_keras.py 文件。
训练模型只需以如下命令运行 mnist_keras.py 文件:
python mnist_keras.py --mode train
在 GPU 中训练可以很快就能训练结束,第 8 轮中验证集正确率达到 0.98825 后便不再提升。
用测试集评估模型:
python mnist_keras.py --mode evaluate
得到 0.9895 的正确率。
在应用模型预测时,将图片放置于 tmp 文件夹中
运行模型预测
python mnist_keras.py --mode predict
运行后得到预测结果:
mnist_dat/tmp/8.png
预测结果为:8
mnist_dat/tmp/9.png
预测结果为:9
mnist_dat/tmp/4.png
预测结果为:4
mnist_dat/tmp/5.png
预测结果为:5
mnist_dat/tmp/7.png
预测结果为:7
mnist_dat/tmp/6.png
预测结果为:6
mnist_dat/tmp/2.png
预测结果为:2
mnist_dat/tmp/3.png
预测结果为:3
mnist_dat/tmp/1.png
预测结果为:1
mnist_dat/tmp/0.png
预测结果为:0
[1] Sherrington C S. Observations on the scratch-reflex in the spinal dog[J]. Journal of Physiology, 1906, 34(1-2):1.
[2] Fukushima K, Miyake S, Ito T. Neocognitron: a neural network model for a mechanism of visual pattern recognition[M]// Competition and Cooperation in Neural Nets. Springer Berlin Heidelberg, 1982:826-834.
[3] Krizhevsky A, Sutskever I, Hinton G E. ImageNet classification with deep convolutional neural networks[J]. Communications of the Acm, 2012, 60(2):2012.
[4] Lecun Y, Bottou L, Bengio Y, et al. Gradient-based learning applied to document recognition[J]. Proceedings of the IEEE, 1998, 86(11):2278-2324.