我要努力工作,加油!

tensorflow学习,继续研究识别CIFAR-10数据集的深度学习网络,官方例程如何对比预测值和标签值(18)

		发表于: 2018-07-11 23:03:48 | 已被阅读: 27 | 分类于: tensorflow
		
上一节我们又详细分析了16节建立的深度学习网络,并且尝试在训练中对比预测值和标签值。这一节将结合官方代码,对上一节的内容进行补充。

再说说建立的识别 CIFAR-10 数据集的深度学习网络


深度学习网络部分的代码如下:

def NetworkRes(images):
    # 第一层卷积
    kernel = tf.Variable(tf.zeros([5,5,3,64]))
    biases = tf.Variable(tf.random_normal( [64])) 
    conv = tf.nn.conv2d(images, kernel, [1,1,1,1], padding='SAME')
    ...
    # 第二层卷积
    ...
    # 第一层全连接层
    ...
    # 第二层全连接层
    ...
    # 第三层全连接层
    weights = tf.Variable(tf.zeros([192, NUM_CLASSES]))  # NUM_CLASSES=10
    biases = tf.Variable(tf.random_normal( [NUM_CLASSES]))
    logits = tf.add(tf.matmul(res2, weights), biases)
    return logits

其实网络已经很清晰了,就是利用两层卷积和池化提取图片特征,再经过几层全连接层降维,将每张图片的输出结果降低至 10 类(CIFAR-10数据集一共10类物体)。所以最终输出的

logits
形状为 [batch_size, 10]。

再来看看标签值:

def Distorted_inputs(filenames, batch_size):    
    # 读入数据
    label, uint8image = ReadImages(filenames)
    ...
    # 返回的是加工过的图片和标签的batch
    return images, tf.reshape(label_batch, [batch_size])

最终的标签值的形状是 [batch_size] 的。网络输出形状为 [batch_size, 10] 的

logits
是如何与形状为 [batch_size] 的
label
对应起来的呢?如果
batch_size=1
,也就是说每张图片输出 10 个值,对应 1 个标签值,这是怎么对应的呢?

官方例程使用了 tensorflow 的

sparse_softmax_cross_entropy_with_logits
方法,直接将形状为 [batch_size, 10] 的
logits
是如何与形状为 [batch_size] 的
label
传输给该方法即可,该方法内部实现了以下三个步骤:

  • sm=nn.softmax(logits)
  • onehot=tf.sparse_to_dense(label,…)
  • nn.sparse_cross_entropy(sm,onehot)

因为 logits 是从网络直接输出的,没有经过归一化,所以先用 softmax 对其归一化。然后将标签值转换为

one-hot
标签值,这样图片经过网络的 10 个输出值与标签的 10 个 bit,就对应起来了。

官方例程如何实现预测值与标签值对比的


这点可以从官方例程的测试用例中得到答案。我们看以下一段代码:

def evaluate():
  """Eval CIFAR-10 for a number of steps."""
  with tf.Graph().as_default() as g:
    # Get images and labels for CIFAR-10.
    eval_data = FLAGS.eval_data == 'test'
    images, labels = cifar10.inputs(eval_data=eval_data)

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = cifar10.inference(images)

    # Calculate predictions.
    top_k_op = tf.nn.in_top_k(logits, labels, 1)

    # Restore the moving average version of the learned variables for eval.
    variable_averages = tf.train.ExponentialMovingAverage(
        cifar10.MOVING_AVERAGE_DECAY)
    variables_to_restore = variable_averages.variables_to_restore()
    saver = tf.train.Saver(variables_to_restore)

    # Build the summary operation based on the TF collection of Summaries.
    summary_op = tf.summary.merge_all()

    summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, g)

    while True:
      eval_once(saver, summary_writer, top_k_op, summary_op)
      if FLAGS.run_once:
        break
      time.sleep(FLAGS.eval_interval_secs)

关键就是借助了

top_k_op = tf.nn.in_top_k(logits, labels, 1)
这句,它会自动将 logits 与 labels 对比,得到 batch_size 个 True 或者 False 的结果。再有以下一段代码:

...
      while step < num_iter and not coord.should_stop():
        predictions = sess.run([top_k_op])
        true_count += np.sum(predictions)
        step += 1

      # Compute precision @ 1.
      precision = true_count / total_sample_count
      print('%s: precision @ 1 = %.3f' % (datetime.now(), precision))
...

官方例程将对比值加在一起了,表示有 true_count 个正确预测结果。

优化上一节的预测和标签对比


这里我直接将修改后的 Train 函数部分的代码放出来

def Train():
    # 创建文件名 list
    for i in range(1,6):
        filenames = [os.path.join('cifar10_data/cifar-10-batches-bin', 'data_batch_%d.bin' % i)]

    images, label_batch = Distorted_inputs(filenames, BATCH_SIZE)
    logits = NetworkRes(images)
    loss = Loss(logits, label_batch)

    prediction = tf.nn.in_top_k(logits,label_batch,1)

    train_step = tf.train.AdamOptimizer(5e-4).minimize(loss)
    saver = tf.train.Saver()
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        threads = tf.train.start_queue_runners(sess=sess)
        for i in range(500000):
            # print 'times: %d' % i
            sess.run(train_step)
            if i%10 == 0:
                rp = sess.run(prediction)
                print 'i=%d, loss:' % i, sess.run(loss), 'prediction %d/%d' % (np.sum(rp), BATCH_SIZE)
                print rp
            if i%20000 == 0:
                print "save ckpt----------------------"
                saver.save(sess, 'net_w_b/save_w_b.ckpt') 

我们以一个 batch_size 为单位,将正确预测的个数和打印出来。与此同时,我们每隔 20000 步将网络训练的结果保存一次,万一意外终止训练,可以方便我们下一次接着训练。最终打印结果如下:

$ python t2.py 
save ckpt----------------------
i=10, loss: 2.3245995 prediction 4/50
[False False False  True False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False  True False
  True False False False False  True False False False False False False
 False False]
i=20, loss: 2.341774 prediction 5/50
[False False False False  True False False False False False False  True
 False False False False False False  True False False False False False
 False False False False False False  True False False False False False
 False False False False False False  True False False False False False
 False False]
i=30, loss: 2.3172011 prediction 10/50
[False  True False  True False  True  True False  True False False False
  True False False False False False False False False False False False
 False False False False False False  True  True False False False False
 False  True False False False False False False False False  True False
 False False]
i=40, loss: 2.3874798 prediction 5/50
[False  True False False False False  True False False False False  True
 False False False False False False False  True False False False False
 False False False False False False False False False False False False
 False False False False False  True False False False False False False
 False False]
i=50, loss: 2.2945552 prediction 1/50
[False False False False False False False False False False False False
 False False False  True False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False]
...

可以看出,我们打印出了 batch_size(50) 次预测,正确的次数。刚开始,正确次数比较低,基本都在 10/50 以下。随着训练的进行,发现 loss 减小了,正确率也变高了,这说明我们自己搭建的网络正常工作了!!!

接下来应该用测试集做测试了。

...
i=13960, loss: 1.3711777 prediction 35/50
[ True  True False  True  True  True False  True  True  True  True False
 False  True  True  True False  True  True False False  True  True  True
  True  True  True False False False  True False  True False  True  True
 False  True  True  True  True  True  True  True False  True  True  True
  True False]
i=13970, loss: 0.91767156 prediction 29/50
[ True False  True  True  True  True False False False  True False False
  True  True  True  True  True  True  True  True  True  True  True False
  True False False  True  True False  True False  True False  True False
  True False  True False  True False False False False  True  True  True
 False False]
i=13980, loss: 1.0278193 prediction 30/50
[ True  True False  True False  True False  True  True False  True  True
 False  True  True False  True  True False False  True  True  True False
 False False False False  True  True  True  True  True False  True  True
  True False  True  True  True  True  True False False False False False
  True  True]
i=13990, loss: 1.0174195 prediction 30/50
[False  True False False  True  True  True  True False False  True  True
  True  True False  True False False False  True  True False  True  True
 False  True False False  True  True  True  True  True False  True  True
 False False  True  True  True False  True  True  True False False  True
  True False]
i=14000, loss: 0.9888096 prediction 31/50
[False False False  True  True  True  True False  True  True  True False
  True  True False  True  True  True  True  True  True  True False  True
 False  True  True  True False  True False  True  True  True False  True
  True False False  True False  True False False False False  True False
  True  True]
i=14010, loss: 1.245312 prediction 32/50
[ True False  True  True  True  True False False  True  True False False
 False  True  True False  True  True  True  True  True False False False
  True  True  True  True False  True  True False False False  True  True
 False False  True  True  True  True  True  True  True  True  True  True
 False False]
...