tensorflow学习,MNIST数据集的使用,识别手写数字实战项目(二)
发表于: 2018-06-20 23:18:00 | 已被阅读: 55 | 分类于: tensorflow
前言
在学习各种编程语言时,最经典的入门例子就是打印出 "hello world" 了。对于 tensorflow 而言,与之对应地位的入门实战项目就是使用 MNIST数据集
实现手写数字识别了。
MNIST 数据集
1. MNIST 数据集的下载
MNIST 数据集的官方网站:
http://yann.lecun.com/exdb/mnist/
手动下载,也不慢。也可使用 python 代码直接下载:
#encoding=utf8
import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
上面的
"""Functions for downloading and reading MNIST data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import
import gzip
import os
import tempfile
import numpy
from six.moves import urllib
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
# pylint: enable=unused-import
执行python脚本,也可以下载,下载在 py 文件同目录的
2. MNIST 数据集简介
MNIST 数据集分为训练(
每一个手写数字是 28X28=784 的图片。所以,
手写数字识别
1. 搭建识别模型
这里手写数字识别使用
已经知道的事实是:手写数字图片 x 对应的标签(正确结果)为 y,那么,我们设置一个系数 w (权重系数),和一个偏置 b,肯定可以满足如下关系:
y = wx + b
使用 softmax 函数,则有
y = softmax(wx+b)
作为入门,暂且不提为啥使用 softmax 函数。如果提升到多维,则有:
那么,x w b y 的 tensorflow 的 python 描述代码可以如下写:
import tensorflow as tf
x = tf.placeholder("float", [None, 784]) # None 表示不关心该纬度
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x,W) + b)
上面的 y 是计算得到的值,下面引入
y_ = tf.placeholder("float", [None,10])
引入
cross_entropy = -tf.reduce_sum(y_*tf.log(y)) # 计算交叉熵
2. 训练模型
训练方式采取经典的反向传播法,tensorflow 使用一行代码就可以描述
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
即,使用 0.01 的学习因子,使 cross_entropy 尽可能小。
下面就可以初始化,训练了,要在 session 里训练,这点可以参考上一节:
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for i in range(1000): # 训练 1000 次
batch_xs, batch_ys = mnist.train.next_batch(100)
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) # feed 数据
3. 测试模型
训练完成后,当然需要测试其准确性,直接上代码,主要利用了 tf.argmax 函数,它返回给出某个tensor对象在某一维上 的其数据最大值所在的索引值,因为本节使用的索引是特殊的 10 维向量,所以下面的代码应该非常好理解才对。
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
完整代码
有了上面的解释,下面的代码应该很好理解:
#encoding=utf8
import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
import tensorflow as tf
x = tf.placeholder("float", [None, 784])
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x,W) + b)
y_ = tf.placeholder("float", [None,10])
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for i in range(1000):
batch_xs, batch_ys = mnist.train.next_batch(100)
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
最终,程序输出的结果是 0.9137,不算高的正确率,下面几节将提高正确率。