tensorflow学习,在训练时使用CIFAR-10测试集,实时测试搭建的物体识别网络(19)
发表于: 2018-07-12 22:44:04 | 已被阅读: 23 | 分类于: tensorflow
上一节,我们参照官方例程,在训练网络时,实时的对比标签值和预测值,并将结果打印到终端。但是,不足的是,使用训练集测试训练结果,并没有什么说服力。本节将使用 CIFAR-10 的测试集,在训练时,实时测试训练网络。
总体思路
思路还是非常清晰的,就是将测试时用的 CIFAR-10
实现方法
因为我们已经将读入数据的动作封装为函数(代码详情可参照
def ReadImages(filenames):
# 使用 tensorflow 将文件名 list 转成队列(queue)
filename_queue = tf.train.string_input_producer(filenames)
# 标签占一个字节
label_bytes = 1
...
# 创建固定长度的数据 reader
reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
key, value = reader.read(filename_queue)
...
# 将图片从 [深度,高(长),宽] 转成 [高(长),宽, 深度] 形状
uint8image = tf.transpose(depth_major, [1, 2, 0])
return label, uint8image
def Distorted_inputs(filenames, batch_size):
# 读入数据
label, uint8image = ReadImages(filenames)
...
# 产生一个图片 batch
num_preprocess_threads = 16
images, label_batch = tf.train.batch(
[float_image, label],
batch_size=batch_size,
num_threads=num_preprocess_threads,
capacity=min_queue_examples + 3 * batch_size)
# 返回的是加工过的图片和标签batch
return images, tf.reshape(label_batch, [batch_size])
所以读入
testFilenames = [os.path.join('cifar10_data/cifar-10-batches-bin', 'test_batch.bin')]
testimages, testLabels = Distorted_inputs(testFilenames, BATCH_SIZE)
testlogits = NetworkRes(testimages)
prediction = tf.nn.in_top_k(testlogits,testLabels,1)
修改部分的代码如下,还是只有 Train 函数修改了,其他的与第 16 节一致。注意
def Train():
# 创建文件名 list
for i in range(1,6):
filenames = [os.path.join('cifar10_data/cifar-10-batches-bin', 'data_batch_%d.bin' % i)]
images, label_batch = Distorted_inputs(filenames, BATCH_SIZE)
logits = NetworkRes(images)
loss = Loss(logits, label_batch)
saver = tf.train.Saver()
# 使用测试集测试网络
testFilenames = [os.path.join('cifar10_data/cifar-10-batches-bin', 'test_batch.bin')]
testimages, testLabels = Distorted_inputs(testFilenames, BATCH_SIZE)
testlogits = NetworkRes(testimages)
prediction = tf.nn.in_top_k(testlogits,testLabels,1)
train_step = tf.train.AdamOptimizer(5e-5).minimize(loss)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
saver.restore(sess, 'net_w_b/save_w_b.ckpt')
threads = tf.train.start_queue_runners(sess=sess)
for i in range(500000):
# print 'times: %d' % i
sess.run(train_step)
if i%10 == 0:
rp = sess.run(prediction)
print 'i=%d, loss:' % i, sess.run(loss), 'prediction %d/%d' % (np.sum(rp), BATCH_SIZE)
print rp
if i%10000 == 0:
print "save ckpt----------------------"
saver.save(sess, 'net_w_b/save_w_b.ckpt')
运行结果
Train 函数中增加了一行代码
saver.restore(sess, 'net_w_b/save_w_b.ckpt')
这句是将我们之前保存的网络系数加载到网络,这样我们可以接着之前的训练结果继续训练。最终运行脚本,得到以下:
...
i=330, loss: 1.1338959 prediction 4/50
[False True False False False False False True False False False False
False False False False False False False False False False False False
False False True False False False False False False False False False
False True False False False False False False False False False False
False False]
i=340, loss: 0.8164284 prediction 4/50
[False True False False False False False False False True False False
False False False False False False False False False False False False
False True False False False False False False False False False False
False False False False False False False False True False False False
False False]
i=350, loss: 1.2925113 prediction 9/50
[False True False False True False False False False False False True
True False True False False False False True True False False False
False False False False False False False False False False False False
False False False False False True False False False True False False
False False]
i=360, loss: 0.96927905 prediction 1/50
[False False False False False False False False False False False False
False False False False False False False False False False False False
False False False False False False False True False False False False
False False False False False False False False False False False False
False False]
i=370, loss: 1.0052602 prediction 7/50
[False False True False False False False True False True False False
False False False False False False False False False True True False
False False False False False False False False True False False False
False False True False False False False False False False False False
False False]
i=380, loss: 1.0849243 prediction 6/50
[False False False False False False False False False False False False
False False False False False False True True True False False True
False False False False False False False False False False False False
False False False True False False False False False False False False
False True]
i=390, loss: 1.3486145 prediction 4/50
[False False False False False True True False False False False False
True False False False False False False False False False False False
False False False False False False False False False False False False
False False False False False False True False False False False False
False False]
...
正确率明显比用