tensorflow学习,继续研究识别CIFAR-10数据集的深度学习网络,官方例程如何对比预测值和标签值(18)
发表于: 2018-07-11 23:03:48 | 已被阅读: 27 | 分类于: tensorflow
再说说建立的识别 CIFAR-10 数据集的深度学习网络
深度学习网络部分的代码如下:
def NetworkRes(images):
# 第一层卷积
kernel = tf.Variable(tf.zeros([5,5,3,64]))
biases = tf.Variable(tf.random_normal( [64]))
conv = tf.nn.conv2d(images, kernel, [1,1,1,1], padding='SAME')
...
# 第二层卷积
...
# 第一层全连接层
...
# 第二层全连接层
...
# 第三层全连接层
weights = tf.Variable(tf.zeros([192, NUM_CLASSES])) # NUM_CLASSES=10
biases = tf.Variable(tf.random_normal( [NUM_CLASSES]))
logits = tf.add(tf.matmul(res2, weights), biases)
return logits
其实网络已经很清晰了,就是利用两层卷积和池化提取图片特征,再经过几层全连接层降维,将每张图片的输出结果降低至 10 类(CIFAR-10数据集一共10类物体)。所以最终输出的
再来看看标签值:
def Distorted_inputs(filenames, batch_size):
# 读入数据
label, uint8image = ReadImages(filenames)
...
# 返回的是加工过的图片和标签的batch
return images, tf.reshape(label_batch, [batch_size])
最终的标签值的形状是 [batch_size] 的。网络输出形状为 [batch_size, 10] 的
官方例程使用了 tensorflow 的
- sm=nn.softmax(logits)
- onehot=tf.sparse_to_dense(label,…)
- nn.sparse_cross_entropy(sm,onehot)
因为 logits 是从网络直接输出的,没有经过归一化,所以先用 softmax 对其归一化。然后将标签值转换为
官方例程如何实现预测值与标签值对比的
这点可以从官方例程的测试用例中得到答案。我们看以下一段代码:
def evaluate():
"""Eval CIFAR-10 for a number of steps."""
with tf.Graph().as_default() as g:
# Get images and labels for CIFAR-10.
eval_data = FLAGS.eval_data == 'test'
images, labels = cifar10.inputs(eval_data=eval_data)
# Build a Graph that computes the logits predictions from the
# inference model.
logits = cifar10.inference(images)
# Calculate predictions.
top_k_op = tf.nn.in_top_k(logits, labels, 1)
# Restore the moving average version of the learned variables for eval.
variable_averages = tf.train.ExponentialMovingAverage(
cifar10.MOVING_AVERAGE_DECAY)
variables_to_restore = variable_averages.variables_to_restore()
saver = tf.train.Saver(variables_to_restore)
# Build the summary operation based on the TF collection of Summaries.
summary_op = tf.summary.merge_all()
summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, g)
while True:
eval_once(saver, summary_writer, top_k_op, summary_op)
if FLAGS.run_once:
break
time.sleep(FLAGS.eval_interval_secs)
关键就是借助了
...
while step < num_iter and not coord.should_stop():
predictions = sess.run([top_k_op])
true_count += np.sum(predictions)
step += 1
# Compute precision @ 1.
precision = true_count / total_sample_count
print('%s: precision @ 1 = %.3f' % (datetime.now(), precision))
...
官方例程将对比值加在一起了,表示有 true_count 个正确预测结果。
优化上一节的预测和标签对比
这里我直接将修改后的 Train 函数部分的代码放出来
def Train():
# 创建文件名 list
for i in range(1,6):
filenames = [os.path.join('cifar10_data/cifar-10-batches-bin', 'data_batch_%d.bin' % i)]
images, label_batch = Distorted_inputs(filenames, BATCH_SIZE)
logits = NetworkRes(images)
loss = Loss(logits, label_batch)
prediction = tf.nn.in_top_k(logits,label_batch,1)
train_step = tf.train.AdamOptimizer(5e-4).minimize(loss)
saver = tf.train.Saver()
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
threads = tf.train.start_queue_runners(sess=sess)
for i in range(500000):
# print 'times: %d' % i
sess.run(train_step)
if i%10 == 0:
rp = sess.run(prediction)
print 'i=%d, loss:' % i, sess.run(loss), 'prediction %d/%d' % (np.sum(rp), BATCH_SIZE)
print rp
if i%20000 == 0:
print "save ckpt----------------------"
saver.save(sess, 'net_w_b/save_w_b.ckpt')
我们以一个 batch_size 为单位,将正确预测的个数和打印出来。与此同时,我们每隔 20000 步将网络训练的结果保存一次,万一意外终止训练,可以方便我们下一次接着训练。最终打印结果如下:
$ python t2.py
save ckpt----------------------
i=10, loss: 2.3245995 prediction 4/50
[False False False True False False False False False False False False
False False False False False False False False False False False False
False False False False False False False False False False True False
True False False False False True False False False False False False
False False]
i=20, loss: 2.341774 prediction 5/50
[False False False False True False False False False False False True
False False False False False False True False False False False False
False False False False False False True False False False False False
False False False False False False True False False False False False
False False]
i=30, loss: 2.3172011 prediction 10/50
[False True False True False True True False True False False False
True False False False False False False False False False False False
False False False False False False True True False False False False
False True False False False False False False False False True False
False False]
i=40, loss: 2.3874798 prediction 5/50
[False True False False False False True False False False False True
False False False False False False False True False False False False
False False False False False False False False False False False False
False False False False False True False False False False False False
False False]
i=50, loss: 2.2945552 prediction 1/50
[False False False False False False False False False False False False
False False False True False False False False False False False False
False False False False False False False False False False False False
False False False False False False False False False False False False
False False]
...
可以看出,我们打印出了 batch_size(50) 次预测,正确的次数。刚开始,正确次数比较低,基本都在 10/50 以下。随着训练的进行,发现 loss 减小了,正确率也变高了,这说明我们自己搭建的网络正常工作了!!!
接下来应该用测试集做测试了。
...
i=13960, loss: 1.3711777 prediction 35/50
[ True True False True True True False True True True True False
False True True True False True True False False True True True
True True True False False False True False True False True True
False True True True True True True True False True True True
True False]
i=13970, loss: 0.91767156 prediction 29/50
[ True False True True True True False False False True False False
True True True True True True True True True True True False
True False False True True False True False True False True False
True False True False True False False False False True True True
False False]
i=13980, loss: 1.0278193 prediction 30/50
[ True True False True False True False True True False True True
False True True False True True False False True True True False
False False False False True True True True True False True True
True False True True True True True False False False False False
True True]
i=13990, loss: 1.0174195 prediction 30/50
[False True False False True True True True False False True True
True True False True False False False True True False True True
False True False False True True True True True False True True
False False True True True False True True True False False True
True False]
i=14000, loss: 0.9888096 prediction 31/50
[False False False True True True True False True True True False
True True False True True True True True True True False True
False True True True False True False True True True False True
True False False True False True False False False False True False
True True]
i=14010, loss: 1.245312 prediction 32/50
[ True False True True True True False False True True False False
False True True False True True True True True False False False
True True True True False True True False False False True True
False False True True True True True True True True True True
False False]
...