前言
在学习各种编程语言时,最经典的入门例子就是打印出 "hello world" 了。对于 tensorflow 而言,与之对应地位的入门实战项目就是使用 MNIST数据集
实现手写数字识别了。
MNIST 数据集
1. MNIST 数据集的下载
MNIST 数据集的官方网站:http://yann.lecun.com/exdb/mnist/,非常知名的网站,看着却非常简陋,不过能提供好数据就是好网站。
手动下载,也不慢。也可使用 python 代码直接下载:
#encoding=utf8
import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
上面的 input_data
的代码如下,记得文件名为 input_data.py
:
"""Functions for downloading and reading MNIST data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import
import gzip
import os
import tempfile
import numpy
from six.moves import urllib
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
# pylint: enable=unused-import
执行python脚本,也可以下载,下载在 py 文件同目录的 MNIST_data
文件夹里,如下图:
2. MNIST 数据集简介
MNIST 数据集分为训练(mnist.train
)和测试(mnist.test
)两部分。每一个数据单元分为图片数据
和标签
两部分,图片数据即为手写数字图片,标签则对应着手写结果。
每一个手写数字是 28X28=784 的图片。所以,mnist.train.images
是一个 [60000, 784] 的张量,60000 表示图片的数量,784 表示每一张图片的数据点数。标签
则为 0~9 的数字,用 1 个 10 维向量表示 10 个数字,例如,用 [1,0,0,0,0,0,0,0,0,0,0] 表示 0,用 [0,0,1,0,0,0,0,0,0,0,0] 表示 2,那么,mnist.train.labels
是一个 [60000, 10] 的张量。
手写数字识别
1. 搭建识别模型
这里手写数字识别使用 softmax回归
方法。
已经知道的事实是:手写数字图片 x 对应的标签(正确结果)为 y,那么,我们设置一个系数 w (权重系数),和一个偏置 b,肯定可以满足如下关系:
y = wx + b
使用 softmax 函数,则有
y = softmax(wx+b)
作为入门,暂且不提为啥使用 softmax 函数。如果提升到多维,则有:
那么,x w b y 的 tensorflow 的 python 描述代码可以如下写:
import tensorflow as tf
x = tf.placeholder("float", [None, 784]) # None 表示不关心该纬度
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x,W) + b)
上面的 y 是计算得到的值,下面引入 y_
作为正确值:
y_ = tf.placeholder("float", [None,10])
引入 y_
是为了估计计算值与实际值的接近程度,这个接近程度可以用 交叉熵
表示,估计值和实际值越接近,二者的 交叉熵
越小。交叉熵的公式如下:
cross_entropy = -tf.reduce_sum(y_*tf.log(y)) # 计算交叉熵
2. 训练模型
训练方式采取经典的反向传播法,tensorflow 使用一行代码就可以描述
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
即,使用 0.01 的学习因子,使 cross_entropy 尽可能小。
下面就可以初始化,训练了,要在 session 里训练,这点可以参考上一节:几个基本概念:图,会话,feed,fetch
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for i in range(1000): # 训练 1000 次
batch_xs, batch_ys = mnist.train.next_batch(100)
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) # feed 数据
3. 测试模型
训练完成后,当然需要测试其准确性,直接上代码,主要利用了 tf.argmax 函数,它返回给出某个tensor对象在某一维上
的其数据最大值所在的索引值,因为本节使用的索引是特殊的 10 维向量,所以下面的代码应该非常好理解才对。
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
完整代码
有了上面的解释,下面的代码应该很好理解:
#encoding=utf8
import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
import tensorflow as tf
x = tf.placeholder("float", [None, 784])
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x,W) + b)
y_ = tf.placeholder("float", [None,10])
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for i in range(1000):
batch_xs, batch_ys = mnist.train.next_batch(100)
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
最终,程序输出的结果是 0.9137,不算高的正确率,下面几节将提高正确率。
[…] tensorflow学习,MNIST数据集的使用,识别手写数字实战项目(二) […]
[…] tensorflow学习,MNIST数据集的使用,识别手写数字实战项目(二) […]
[…] tensorflow学习,MNIST数据集的使用,识别手写数字实战项目(二) […]
[…] tensorflow学习,MNIST数据集的使用,识别手写数字实战项目(二) […]
[…] tensorflow学习,MNIST数据集的使用,识别手写数字实战项目(二) […]
[…] tensorflow学习,MNIST数据集的使用,识别手写数字实战项目(二) […]
[…] tensorflow学习,MNIST数据集的使用,识别手写数字实战项目(二) […]
[…] tensorflow学习,MNIST数据集的使用,识别手写数字实战项目(二) […]
[…] tensorflow学习,MNIST数据集的使用,识别手写数字实战项目(二) […]
感谢作者对这一系列文章的分享