我要努力工作,加油!

tensorflow学习,在训练时使用CIFAR-10测试集,实时测试搭建的物体识别网络(19)

		发表于: 2018-07-12 22:44:04 | 已被阅读: 23 | 分类于: tensorflow
		
上一节,我们参照官方例程,在训练网络时,实时的对比标签值和预测值,并将结果打印到终端。但是,不足的是,使用训练集测试训练结果,并没有什么说服力。本节将使用 CIFAR-10 的测试集,在训练时,实时测试训练网络。

总体思路


思路还是非常清晰的,就是将测试时用的 CIFAR-10

训练集数据
,替换成
测试集数据
。但是训练不能中断,因此
训练集数据
不能中断,所以可以另外建立一个队列,专门读入
测试集数据
,这样,我们就可以使用
训练集数据
训练网络,而使用
测试集数据
测试网络。

实现方法


因为我们已经将读入数据的动作封装为函数(代码详情可参照

第 16 节
):

def ReadImages(filenames):
    # 使用 tensorflow 将文件名 list 转成队列(queue)
    filename_queue = tf.train.string_input_producer(filenames)
    # 标签占一个字节
    label_bytes = 1
    ...
    # 创建固定长度的数据 reader
    reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
    key, value = reader.read(filename_queue)
    ...
    # 将图片从 [深度,高(长),宽] 转成 [高(长),宽, 深度] 形状
    uint8image = tf.transpose(depth_major, [1, 2, 0])

    return label, uint8image

def Distorted_inputs(filenames, batch_size):    
    # 读入数据
    label, uint8image = ReadImages(filenames)
...
    # 产生一个图片 batch
    num_preprocess_threads = 16
    images, label_batch = tf.train.batch(
        [float_image, label],
        batch_size=batch_size,
        num_threads=num_preprocess_threads,
        capacity=min_queue_examples + 3 * batch_size)

    # 返回的是加工过的图片和标签batch
    return images, tf.reshape(label_batch, [batch_size])

所以读入

测试集数据
变得非常容易,我们只需将
测试集数据
的文件名加入到文件名队列,再调用已写成的函数即可。

testFilenames = [os.path.join('cifar10_data/cifar-10-batches-bin', 'test_batch.bin')]
testimages, testLabels = Distorted_inputs(testFilenames, BATCH_SIZE)
testlogits = NetworkRes(testimages)
prediction = tf.nn.in_top_k(testlogits,testLabels,1)

修改部分的代码如下,还是只有 Train 函数修改了,其他的与第 16 节一致。注意

saver = tf.train.Saver()
的位置,是放在测试集数据相关部分的前面,因为我们不需要保存测试集部分的参数。

def Train():
    # 创建文件名 list
    for i in range(1,6):
        filenames = [os.path.join('cifar10_data/cifar-10-batches-bin', 'data_batch_%d.bin' % i)]

    images, label_batch = Distorted_inputs(filenames, BATCH_SIZE)
    logits = NetworkRes(images)
    loss = Loss(logits, label_batch)

    saver = tf.train.Saver()

    # 使用测试集测试网络
    testFilenames = [os.path.join('cifar10_data/cifar-10-batches-bin', 'test_batch.bin')]
    testimages, testLabels = Distorted_inputs(testFilenames, BATCH_SIZE)
    testlogits = NetworkRes(testimages)
    prediction = tf.nn.in_top_k(testlogits,testLabels,1)

    train_step = tf.train.AdamOptimizer(5e-5).minimize(loss)
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        saver.restore(sess, 'net_w_b/save_w_b.ckpt')
        threads = tf.train.start_queue_runners(sess=sess)
        for i in range(500000):
            # print 'times: %d' % i
            sess.run(train_step)
            if i%10 == 0:
                rp = sess.run(prediction)
                print 'i=%d, loss:' % i, sess.run(loss), 'prediction %d/%d' % (np.sum(rp), BATCH_SIZE)
                print rp
            if i%10000 == 0:
                print "save ckpt----------------------"
                saver.save(sess, 'net_w_b/save_w_b.ckpt') 

运行结果


Train 函数中增加了一行代码

saver.restore(sess, 'net_w_b/save_w_b.ckpt')

这句是将我们之前保存的网络系数加载到网络,这样我们可以接着之前的训练结果继续训练。最终运行脚本,得到以下:

...
i=330, loss: 1.1338959 prediction 4/50
[False  True False False False False False  True False False False False
 False False False False False False False False False False False False
 False False  True False False False False False False False False False
 False  True False False False False False False False False False False
 False False]
i=340, loss: 0.8164284 prediction 4/50
[False  True False False False False False False False  True False False
 False False False False False False False False False False False False
 False  True False False False False False False False False False False
 False False False False False False False False  True False False False
 False False]
i=350, loss: 1.2925113 prediction 9/50
[False  True False False  True False False False False False False  True
  True False  True False False False False  True  True False False False
 False False False False False False False False False False False False
 False False False False False  True False False False  True False False
 False False]
i=360, loss: 0.96927905 prediction 1/50
[False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False  True False False False False
 False False False False False False False False False False False False
 False False]
i=370, loss: 1.0052602 prediction 7/50
[False False  True False False False False  True False  True False False
 False False False False False False False False False  True  True False
 False False False False False False False False  True False False False
 False False  True False False False False False False False False False
 False False]
i=380, loss: 1.0849243 prediction 6/50
[False False False False False False False False False False False False
 False False False False False False  True  True  True False False  True
 False False False False False False False False False False False False
 False False False  True False False False False False False False False
 False  True]
i=390, loss: 1.3486145 prediction 4/50
[False False False False False  True  True False False False False False
  True False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False  True False False False False False
 False False]
...

正确率明显比用

训练集数据
测试时的低,不过这样的测试结果才真正的有参考价值。 接下来就是等待了,让其继续训练,看看正确率是否会上升。