上一节,介绍了如何建立卷积深度学习网络(CNN),识别和破解captcha图形验证码,在最后,我们提到了训练耗时,事实的确如此啊,在我的云服务器上用 cpu 训练了近 24 小时,才勉强达到 80% 的正确率,本节将进行可视化的测试
,先用 python 的 captcha 库生成几张验证码图片,然后测试训练好的网络。
训练过程中的小插曲
以下几张图是训练过程中我截下来的。
发现了吧,训练到正确率超过 80% 的步数居然才 5000 多,反而变小了。
其实,并不是变小了。训练到大概 30000 步的时候,在我查看终端的输出时,有道词典
给我的终端发送了 ctrl+c
命令!!!直接结束了训练。。。
幸好我留了个心眼,在设计网络时留下了下面这句(全部代码参考上一节):
...
if step % 1000 == 0:
saver.save(sess,"./tmpModel/capcha_model_tmp.ckpt")
...
即每训练 1000 步,就保留一次参数。这样,即使训练杯意外终止了,我也可以接着上一次保存的参数继续训练,不然我就又得从头训练了。
保存验证码图片
我们先在工作目录新建文件夹testImages
,用于存放验证码图片。参考上一节的例子,生成验证码,并且保存到testImages
文件夹的代码可以如下写
# 文件名 genACaptcha.py
from captcha.image import ImageCaptcha
characters = string.digits + string.ascii_uppercase + string.ascii_lowercase
image = ImageCaptcha(width = 160,height = 40)
captcha_str = ''.join(random.sample(characters,4))
img = image.generate_image(captcha_str)
img.save('testImages' + captcha_str + '.jpg')
每次执行
python genACaptcha.py
就会在testImages
文件夹里生成一张验证码图片。
加载训练好的模型
其实就是重复训练时的模型,不再需要计算损失和训练的部分。文件名取为 test.py
#-*- coding:utf-8 -*-
from PIL import Image, ImageFilter
import tensorflow as tf
import numpy as np
import string
import sys
import generate_captcha
import captcha_model
if __name__ == '__main__':
captcha = generate_captcha.generateCaptcha()
width,height,char_num,characters,classes = captcha.get_parameter()
gray_image = Image.open(sys.argv[1]).convert('L') # 这一段是仿造上一节生成验证码 batch 的处理方法
img = np.array(gray_image.getdata())
test_x = np.reshape(img,[height,width,1])/255.0
x = tf.placeholder(tf.float32, [None, height,width,1])
keep_prob = tf.placeholder(tf.float32)
model = captcha_model.captchaModel(width,height,char_num,classes) # 建立和训练时一样的模型
y_conv = model.create_model(x,keep_prob)
predict = tf.argmax(tf.reshape(y_conv, [-1,char_num, classes]),2) # 计算概率
init_op = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init_op)
saver.restore(sess, './tmpModel/capcha_model_tmp.ckpt')
pre_list = sess.run(predict,feed_dict={x: [test_x], keep_prob: 1})
for i in pre_list:
s = ''
for j in i:
s += characters[j] # 拼接字符
print(s)
测试时,我们只需要执行
python test.py 【要识别的验证码图片名】
测试
首先,我们先生成一张验证码,文件名7f0N.jpg
,内容如下:
然后执行
$ python test.py testImages/7f0N.jpg
7f0N
成功了!当然由于只有 83% 的正确率,CNN 深度学习网络也有识别错的情况发生,例如下面这张验证码,
$ python test.py testImages/0jkZ.jpg
0jF2 # 识别错误
网络继续训练下去,应该还能提高分辨率,目标 99% !!!