上一节我们又详细分析了16节建立的深度学习网络,并且尝试在训练中对比预测值和标签值。这一节将结合官方代码,对上一节的内容进行补充。
再说说建立的识别 CIFAR-10 数据集的深度学习网络
深度学习网络部分的代码如下:
def NetworkRes(images):
# 第一层卷积
kernel = tf.Variable(tf.zeros([5,5,3,64]))
biases = tf.Variable(tf.random_normal( [64]))
conv = tf.nn.conv2d(images, kernel, [1,1,1,1], padding='SAME')
...
# 第二层卷积
...
# 第一层全连接层
...
# 第二层全连接层
...
# 第三层全连接层
weights = tf.Variable(tf.zeros([192, NUM_CLASSES])) # NUM_CLASSES=10
biases = tf.Variable(tf.random_normal( [NUM_CLASSES]))
logits = tf.add(tf.matmul(res2, weights), biases)
return logits
其实网络已经很清晰了,就是利用两层卷积和池化提取图片特征,再经过几层全连接层降维,将每张图片的输出结果降低至 10 类(CIFAR-10数据集一共10类物体)。所以最终输出的 logits
形状为 [batch_size, 10]。
再来看看标签值:
def Distorted_inputs(filenames, batch_size):
# 读入数据
label, uint8image = ReadImages(filenames)
...
# 返回的是加工过的图片和标签的batch
return images, tf.reshape(label_batch, [batch_size])
最终的标签值的形状是 [batch_size] 的。网络输出形状为 [batch_size, 10] 的logits
是如何与形状为 [batch_size] 的label
对应起来的呢?如果 batch_size=1
,也就是说每张图片输出 10 个值,对应 1 个标签值,这是怎么对应的呢?
官方例程使用了 tensorflow 的 sparse_softmax_cross_entropy_with_logits
方法,直接将形状为 [batch_size, 10] 的logits
是如何与形状为 [batch_size] 的label
传输给该方法即可,该方法内部实现了以下三个步骤:
- sm=nn.softmax(logits)
- onehot=tf.sparse_to_dense(label,…)
- nn.sparse_cross_entropy(sm,onehot)
因为 logits 是从网络直接输出的,没有经过归一化,所以先用 softmax 对其归一化。然后将标签值转换为 one-hot
标签值,这样图片经过网络的 10 个输出值与标签的 10 个 bit,就对应起来了。
官方例程如何实现预测值与标签值对比的
这点可以从官方例程的测试用例中得到答案。我们看以下一段代码:
def evaluate():
"""Eval CIFAR-10 for a number of steps."""
with tf.Graph().as_default() as g:
# Get images and labels for CIFAR-10.
eval_data = FLAGS.eval_data == 'test'
images, labels = cifar10.inputs(eval_data=eval_data)
# Build a Graph that computes the logits predictions from the
# inference model.
logits = cifar10.inference(images)
# Calculate predictions.
top_k_op = tf.nn.in_top_k(logits, labels, 1)
# Restore the moving average version of the learned variables for eval.
variable_averages = tf.train.ExponentialMovingAverage(
cifar10.MOVING_AVERAGE_DECAY)
variables_to_restore = variable_averages.variables_to_restore()
saver = tf.train.Saver(variables_to_restore)
# Build the summary operation based on the TF collection of Summaries.
summary_op = tf.summary.merge_all()
summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, g)
while True:
eval_once(saver, summary_writer, top_k_op, summary_op)
if FLAGS.run_once:
break
time.sleep(FLAGS.eval_interval_secs)
关键就是借助了 top_k_op = tf.nn.in_top_k(logits, labels, 1)
这句,它会自动将 logits 与 labels 对比,得到 batch_size 个 True 或者 False 的结果。再有以下一段代码:
...
while step < num_iter and not coord.should_stop():
predictions = sess.run([top_k_op])
true_count += np.sum(predictions)
step += 1
# Compute precision @ 1.
precision = true_count / total_sample_count
print('%s: precision @ 1 = %.3f' % (datetime.now(), precision))
...
官方例程将对比值加在一起了,表示有 true_count 个正确预测结果。
优化上一节的预测和标签对比
这里我直接将修改后的 Train 函数部分的代码放出来
def Train():
# 创建文件名 list
for i in range(1,6):
filenames = [os.path.join('cifar10_data/cifar-10-batches-bin', 'data_batch_%d.bin' % i)]
images, label_batch = Distorted_inputs(filenames, BATCH_SIZE)
logits = NetworkRes(images)
loss = Loss(logits, label_batch)
prediction = tf.nn.in_top_k(logits,label_batch,1)
train_step = tf.train.AdamOptimizer(5e-4).minimize(loss)
saver = tf.train.Saver()
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
threads = tf.train.start_queue_runners(sess=sess)
for i in range(500000):
# print 'times: %d' % i
sess.run(train_step)
if i%10 == 0:
rp = sess.run(prediction)
print 'i=%d, loss:' % i, sess.run(loss), 'prediction %d/%d' % (np.sum(rp), BATCH_SIZE)
print rp
if i%20000 == 0:
print "save ckpt----------------------"
saver.save(sess, 'net_w_b/save_w_b.ckpt')
我们以一个 batch_size 为单位,将正确预测的个数和打印出来。与此同时,我们每隔 20000 步将网络训练的结果保存一次,万一意外终止训练,可以方便我们下一次接着训练。最终打印结果如下:
$ python t2.py
save ckpt----------------------
i=10, loss: 2.3245995 prediction 4/50
[False False False True False False False False False False False False
False False False False False False False False False False False False
False False False False False False False False False False True False
True False False False False True False False False False False False
False False]
i=20, loss: 2.341774 prediction 5/50
[False False False False True False False False False False False True
False False False False False False True False False False False False
False False False False False False True False False False False False
False False False False False False True False False False False False
False False]
i=30, loss: 2.3172011 prediction 10/50
[False True False True False True True False True False False False
True False False False False False False False False False False False
False False False False False False True True False False False False
False True False False False False False False False False True False
False False]
i=40, loss: 2.3874798 prediction 5/50
[False True False False False False True False False False False True
False False False False False False False True False False False False
False False False False False False False False False False False False
False False False False False True False False False False False False
False False]
i=50, loss: 2.2945552 prediction 1/50
[False False False False False False False False False False False False
False False False True False False False False False False False False
False False False False False False False False False False False False
False False False False False False False False False False False False
False False]
...
可以看出,我们打印出了 batch_size(50) 次预测,正确的次数。刚开始,正确次数比较低,基本都在 10/50 以下。随着训练的进行,发现 loss 减小了,正确率也变高了,这说明我们自己搭建的网络正常工作了!!!
接下来应该用测试集做测试了。
...
i=13960, loss: 1.3711777 prediction 35/50
[ True True False True True True False True True True True False
False True True True False True True False False True True True
True True True False False False True False True False True True
False True True True True True True True False True True True
True False]
i=13970, loss: 0.91767156 prediction 29/50
[ True False True True True True False False False True False False
True True True True True True True True True True True False
True False False True True False True False True False True False
True False True False True False False False False True True True
False False]
i=13980, loss: 1.0278193 prediction 30/50
[ True True False True False True False True True False True True
False True True False True True False False True True True False
False False False False True True True True True False True True
True False True True True True True False False False False False
True True]
i=13990, loss: 1.0174195 prediction 30/50
[False True False False True True True True False False True True
True True False True False False False True True False True True
False True False False True True True True True False True True
False False True True True False True True True False False True
True False]
i=14000, loss: 0.9888096 prediction 31/50
[False False False True True True True False True True True False
True True False True True True True True True True False True
False True True True False True False True True True False True
True False False True False True False False False False True False
True True]
i=14010, loss: 1.245312 prediction 32/50
[ True False True True True True False False True True False False
False True True False True True True True True False False False
True True True True False True True False False False True True
False False True True True True True True True True True True
False False]
...
[…] 上一节,我们参照官方例程,在训练网络时,实时的对比标签值和预测值,并将结果打印到终端。但是,不足的是,使用训练集测试训练结果,并没有什么说服力。本节将使用 CIFAR-10 的测试集,在训练时,实时测试训练网络。 […]