First Attempt of CNN on MNIST
这是我模式识别课的一个实验。简而言之,就是识别手写在纸上的数字。
这个工作仔细分析还是比较简单的,我的实现方式是把整个过程分为许多 stage,每个 stage 只做一件单独的事,这样不仅方便调试(你可以看代码卡在了哪一步,或者单独显示哪一步的输出)而且写出来的代码也比较漂亮(不至于出现过了一段时间看不懂的情况)。
我把整个流程分为了以下的几个 stage:
- loader - 加载图片
- thresholding - 去噪声和二值化
- scissor - 裁剪图片
- normalizer - 处理裁剪后的单个数字
- recognizer - 识别裁剪后的单个数字
- marker - 把结果标注在原图片上
前三个步骤都可以直接调用 OpenCV 的函数完成。第四个稍稍费了些脑筋,为了提高准确率,我在裁剪出的每个数字周围添加了一圈补白,使之构成正方形图像(方便在下一个 stage 喂给神经网络);然后用到了一个神奇的 deskew
函数来矫正倾斜;接着缩放成 28x28 的大小(神经网络是用 MNIST 训练的);因为大的数字图片缩小后会出现笔画变得极细的问题,所以最后还需要做一个膨胀操作。
第五步就是最不必说的东西了……我偷了懒没用 TensorFlow,而是用了比较傻瓜的 Keras,不过好处就是构建起来快……可以把更多的时间花费在优化识别率上。
有一个值得一提的小插曲。我的最初版本在 9 的识别率上非常低,经常被误识别成 1 或 7。其原因是大家写 9 的时候会写的比较长,导致缩放成 28x28 后「9」上面的圈会变得极小无比,看起来就像是个 1 或者 7。解决这个问题的办法非常简单,就是在 MNIST 数据集上做了一个小手术:
def transform_nine(x_train, y_train):
x_superset = []
y_superset = []
for x, y in zip(x_train, y_train):
if y == 9 and random.random() < 0.5:
upper = x[0:14,:]
lower = x[14:,:]
stretched = numpy.zeros(shape=(35, 28))
stretched[0:14,:] = upper
stretched[14:,:] = cv2.resize(lower, (28, 21), interpolation=cv2.INTER_AREA)
stretched = cv2.erode(stretched, (7, 7))
stretched = cv2.resize(stretched, (28, 28), interpolation=cv2.INTER_AREA)
x_superset.append(stretched)
y_superset.append(9)
else:
x_superset.append(x)
y_superset.append(y)
return numpy.array(x_superset), numpy.array(y_superset)
这段代码做的事情就是把数据集中一半的 9 的下半部分拉长了,使得它们变得像我们写的 9 缩放成 28x28 的样子。这个做法虽然简单暴力但是效果拔群。听老师讲这种手段还有一个高大上的名字叫 domain adaption 哈哈哈哈哈。最终的效果如下,这是我手写的一堆数字以及程序的识别结果。
代码在 GitHub 上。