tensorflow学习,在训练时使用CIFAR-10测试集,实时测试搭建的物体识别网络(19)

上一节,我们参照官方例程,在训练网络时,实时的对比标签值和预测值,并将结果打印到终端。但是,不足的是,使用训练集测试训练结果,并没有什么说服力。本节将使用 CIFAR-10 的测试集,在训练时,实时测试训练网络。

总体思路


思路还是非常清晰的,就是将测试时用的 CIFAR-10 训练集数据,替换成测试集数据。但是训练不能中断,因此训练集数据不能中断,所以可以另外建立一个队列,专门读入测试集数据,这样,我们就可以使用训练集数据训练网络,而使用测试集数据测试网络。

实现方法


因为我们已经将读入数据的动作封装为函数(代码详情可参照 第 16 节):

def ReadImages(filenames):
    # 使用 tensorflow 将文件名 list 转成队列(queue)
    filename_queue = tf.train.string_input_producer(filenames)
    # 标签占一个字节
    label_bytes = 1
    ...
    # 创建固定长度的数据 reader
    reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
    key, value = reader.read(filename_queue)
    ...
    # 将图片从 [深度,高(长),宽] 转成 [高(长),宽, 深度] 形状
    uint8image = tf.transpose(depth_major, [1, 2, 0])

    return label, uint8image

def Distorted_inputs(filenames, batch_size):    
    # 读入数据
    label, uint8image = ReadImages(filenames)
...
    # 产生一个图片 batch
    num_preprocess_threads = 16
    images, label_batch = tf.train.batch(
        [float_image, label],
        batch_size=batch_size,
        num_threads=num_preprocess_threads,
        capacity=min_queue_examples + 3 * batch_size)

    # 返回的是加工过的图片和标签batch
    return images, tf.reshape(label_batch, [batch_size])

所以读入测试集数据变得非常容易,我们只需将测试集数据的文件名加入到文件名队列,再调用已写成的函数即可。

testFilenames = [os.path.join('cifar10_data/cifar-10-batches-bin', 'test_batch.bin')]
testimages, testLabels = Distorted_inputs(testFilenames, BATCH_SIZE)
testlogits = NetworkRes(testimages)
prediction = tf.nn.in_top_k(testlogits,testLabels,1)

修改部分的代码如下,还是只有 Train 函数修改了,其他的与第 16 节一致。注意saver = tf.train.Saver()的位置,是放在测试集数据相关部分的前面,因为我们不需要保存测试集部分的参数。

def Train():
    # 创建文件名 list
    for i in range(1,6):
        filenames = [os.path.join('cifar10_data/cifar-10-batches-bin', 'data_batch_%d.bin' % i)]

    images, label_batch = Distorted_inputs(filenames, BATCH_SIZE)
    logits = NetworkRes(images)
    loss = Loss(logits, label_batch)

    saver = tf.train.Saver()

    # 使用测试集测试网络
    testFilenames = [os.path.join('cifar10_data/cifar-10-batches-bin', 'test_batch.bin')]
    testimages, testLabels = Distorted_inputs(testFilenames, BATCH_SIZE)
    testlogits = NetworkRes(testimages)
    prediction = tf.nn.in_top_k(testlogits,testLabels,1)

    train_step = tf.train.AdamOptimizer(5e-5).minimize(loss)
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        saver.restore(sess, 'net_w_b/save_w_b.ckpt')
        threads = tf.train.start_queue_runners(sess=sess)
        for i in range(500000):
            # print 'times: %d' % i
            sess.run(train_step)
            if i%10 == 0:
                rp = sess.run(prediction)
                print 'i=%d, loss:' % i, sess.run(loss), 'prediction %d/%d' % (np.sum(rp), BATCH_SIZE)
                print rp
            if i%10000 == 0:
                print "save ckpt----------------------"
                saver.save(sess, 'net_w_b/save_w_b.ckpt') 

运行结果


Train 函数中增加了一行代码

saver.restore(sess, 'net_w_b/save_w_b.ckpt')

这句是将我们之前保存的网络系数加载到网络,这样我们可以接着之前的训练结果继续训练。最终运行脚本,得到以下:

...
i=330, loss: 1.1338959 prediction 4/50
[False  True False False False False False  True False False False False
 False False False False False False False False False False False False
 False False  True False False False False False False False False False
 False  True False False False False False False False False False False
 False False]
i=340, loss: 0.8164284 prediction 4/50
[False  True False False False False False False False  True False False
 False False False False False False False False False False False False
 False  True False False False False False False False False False False
 False False False False False False False False  True False False False
 False False]
i=350, loss: 1.2925113 prediction 9/50
[False  True False False  True False False False False False False  True
  True False  True False False False False  True  True False False False
 False False False False False False False False False False False False
 False False False False False  True False False False  True False False
 False False]
i=360, loss: 0.96927905 prediction 1/50
[False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False  True False False False False
 False False False False False False False False False False False False
 False False]
i=370, loss: 1.0052602 prediction 7/50
[False False  True False False False False  True False  True False False
 False False False False False False False False False  True  True False
 False False False False False False False False  True False False False
 False False  True False False False False False False False False False
 False False]
i=380, loss: 1.0849243 prediction 6/50
[False False False False False False False False False False False False
 False False False False False False  True  True  True False False  True
 False False False False False False False False False False False False
 False False False  True False False False False False False False False
 False  True]
i=390, loss: 1.3486145 prediction 4/50
[False False False False False  True  True False False False False False
  True False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False  True False False False False False
 False False]
...

正确率明显比用训练集数据测试时的低,不过这样的测试结果才真正的有参考价值。
接下来就是等待了,让其继续训练,看看正确率是否会上升。

阅读更多:   tensorflow
仅有 1 条评论
  1. […] 在第16节,我们建立了最基本的深度学习网络,并且进行了训练。在第19节,参照官方例子,给出了实时评估训练结果的例子。训练到 30000 步时(batch=100),正确率达到了 70%。本节,将分析一下官方例子的代码。 […]

添加新评论

icon_redface.gificon_idea.gificon_cool.gif2016kuk.gificon_mrgreen.gif2016shuai.gif2016tp.gif2016db.gif2016ch.gificon_razz.gif2016zj.gificon_sad.gificon_cry.gif2016zhh.gificon_question.gif2016jk.gif2016bs.gificon_lol.gif2016qiao.gificon_surprised.gif2016fendou.gif2016ll.gif