tensorflow学习,MNIST数据集的使用,识别手写数字实战项目(二)

前言


在学习各种编程语言时,最经典的入门例子就是打印出 "hello world" 了。对于 tensorflow 而言,与之对应地位的入门实战项目就是使用 MNIST数据集 实现手写数字识别了。

MNIST 数据集


1. MNIST 数据集的下载

MNIST 数据集的官方网站:http://yann.lecun.com/exdb/mnist/,非常知名的网站,看着却非常简陋,不过能提供好数据就是好网站。

手动下载,也不慢。也可使用 python 代码直接下载:

#encoding=utf8
import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

上面的 input_data 的代码如下,记得文件名为 input_data.py:

"""Functions for downloading and reading MNIST data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=unused-import
import gzip
import os
import tempfile

import numpy
from six.moves import urllib
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
# pylint: enable=unused-import

执行python脚本,也可以下载,下载在 py 文件同目录的 MNIST_data 文件夹里,如下图:

2. MNIST 数据集简介

MNIST 数据集分为训练(mnist.train)和测试(mnist.test)两部分。每一个数据单元分为图片数据标签两部分,图片数据即为手写数字图片,标签则对应着手写结果。

每一个手写数字是 28X28=784 的图片。所以,mnist.train.images 是一个 [60000, 784] 的张量,60000 表示图片的数量,784 表示每一张图片的数据点数。标签则为 0~9 的数字,用 1 个 10 维向量表示 10 个数字,例如,用 [1,0,0,0,0,0,0,0,0,0,0] 表示 0,用 [0,0,1,0,0,0,0,0,0,0,0] 表示 2,那么,mnist.train.labels 是一个 [60000, 10] 的张量。

手写数字识别


1. 搭建识别模型

这里手写数字识别使用 softmax回归 方法。

已经知道的事实是:手写数字图片 x 对应的标签(正确结果)为 y,那么,我们设置一个系数 w (权重系数),和一个偏置 b,肯定可以满足如下关系:

y = wx + b
使用 softmax 函数,则有
y = softmax(wx+b)

作为入门,暂且不提为啥使用 softmax 函数。如果提升到多维,则有:

那么,x w b y 的 tensorflow 的 python 描述代码可以如下写:

import tensorflow as tf
x = tf.placeholder("float", [None, 784])    # None 表示不关心该纬度
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x,W) + b)

上面的 y 是计算得到的值,下面引入 y_ 作为正确值:

y_ = tf.placeholder("float", [None,10])

引入 y_ 是为了估计计算值与实际值的接近程度,这个接近程度可以用 交叉熵 表示,估计值和实际值越接近,二者的 交叉熵 越小。交叉熵的公式如下:

cross_entropy = -tf.reduce_sum(y_*tf.log(y))    # 计算交叉熵

2. 训练模型

训练方式采取经典的反向传播法,tensorflow 使用一行代码就可以描述

train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

即,使用 0.01 的学习因子,使 cross_entropy 尽可能小。

下面就可以初始化,训练了,要在 session 里训练,这点可以参考上一节:几个基本概念:图,会话,feed,fetch

init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for i in range(1000):           # 训练 1000 次
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) # feed 数据

3. 测试模型

训练完成后,当然需要测试其准确性,直接上代码,主要利用了 tf.argmax 函数,它返回给出某个tensor对象在某一维上
的其数据最大值所在的索引值,因为本节使用的索引是特殊的 10 维向量,所以下面的代码应该非常好理解才对。

correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})

完整代码


有了上面的解释,下面的代码应该很好理解:

#encoding=utf8
import input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
import tensorflow as tf
x = tf.placeholder("float", [None, 784])
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x,W) + b)
y_ = tf.placeholder("float", [None,10])
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for i in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})

最终,程序输出的结果是 0.9137,不算高的正确率,下面几节将提高正确率。

阅读更多:   tensorflow
已有 10 条评论
  1. […] tensorflow学习,MNIST数据集的使用,识别手写数字实战项目(二) […]

  2. […] tensorflow学习,MNIST数据集的使用,识别手写数字实战项目(二) […]

  3. […] tensorflow学习,MNIST数据集的使用,识别手写数字实战项目(二) […]

  4. […] tensorflow学习,MNIST数据集的使用,识别手写数字实战项目(二) […]

  5. […] tensorflow学习,MNIST数据集的使用,识别手写数字实战项目(二) […]

  6. […] tensorflow学习,MNIST数据集的使用,识别手写数字实战项目(二) […]

  7. […] tensorflow学习,MNIST数据集的使用,识别手写数字实战项目(二) […]

  8. […] tensorflow学习,MNIST数据集的使用,识别手写数字实战项目(二) […]

  9. […] tensorflow学习,MNIST数据集的使用,识别手写数字实战项目(二) […]

  10. anonymous

    感谢作者对这一系列文章的分享 icon_mrgreen.gif

添加新评论

icon_redface.gificon_idea.gificon_cool.gif2016kuk.gificon_mrgreen.gif2016shuai.gif2016tp.gif2016db.gif2016ch.gificon_razz.gif2016zj.gificon_sad.gificon_cry.gif2016zhh.gificon_question.gif2016jk.gif2016bs.gificon_lol.gif2016qiao.gificon_surprised.gif2016fendou.gif2016ll.gif