我要努力工作,加油!

tensorflow入门学习,继续巩固,一行一行写出线性回归代码(六)

		发表于: 2018-06-25 23:35:17 | 已被阅读: 37 | 分类于: tensorflow
		
经过第四节第五节的总结,对 tensorflow 的认识越来越深了。现在觉得它有点像一门特殊的编程语言,如果想使用它,就得先了解它的语法(规则)。虽然说第二节被称为 tensorflow 界的 “hello world”,我还是希望能够利用 tensorflow 做些自认为简单的事情。所以,本节先通过 python 的 numpy 模块生成一个线性函数,并且用 tensorflow 逼近它。

生成数据

先生成 100 个随机数,随机数类型为 float32,然后,做一个斜率为 0.5,偏置为 0.8 的直线,python 代码如下

#encoding=utf8
import tensorflow as tf
import numpy as np

xdata = np.random.rand(100).astype(np.float32)
ydata = xdata*0.5 + 0.8

可以参照

python基础,安装并使用matplotlib库画图
小节,把该直线画出来,有个直观感受。python 代码如下:

#encoding=utf8
import matplotlib.pyplot as plt  
import tensorflow as tf
import numpy as np

xdata = np.random.rand(100).astype(np.float32)
ydata = xdata*0.5 + 0.8

plt.figure()  
plt.plot(xdata,xdata)  
plt.show()  

执行,直线图如下:

描述 tensorflow 模型

模型依然是简单的

y = w*x + b

权值矩阵 w 默认给了 -1.0 到 1.0 范围内的随机数,偏置 b 则默认全是 0。损失函数定义为计算值

y
和实际值
ydata
的平方和的平均值,训练 train 则采用学习因子为 0.3 的梯度下降法,最后初始化所有变量。以下是 python 代码:

w = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
b = tf.Variable(tf.zeros([1]))
y = w*xdata + b
loss = tf.reduce_mean(tf.square(y-ydata))
train = tf.train.GradientDescentOptimizer(0.3).minimize(loss)

init = tf.global_variables_initializer()

建立 session 会话,训练模型

之前几节提到,tensorflow 是以计算图为元素的,本节的代码都是建立在默认图

tf
上的,
loss
train
init
都是图上的节点(op),计算图都是在
session
里进行的,所以咱们要先建立 session,然后就可以 run(init) 了。

sess = tf.Session()
sess.run(init)
for i in range(301):
    sess.run(train)
    if i%30 == 0:
        print i, sess.run(w), sess.run(b)

代码共训练 301 次,每 30 次输出一次结果。这里需要说明的是,

w
b
是图
tf
的节点,想要观察其表示的值,需要通过 session 的 run 方法。全部 python 代码如下:

#encoding=utf8
import tensorflow as tf
import numpy as np

xdata = np.random.rand(100).astype(np.float32)
ydata = xdata*0.5 + 0.8

w = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
b = tf.Variable(tf.zeros([1]))
y = w*xdata + b
loss = tf.reduce_mean(tf.square(y-ydata))
train = tf.train.GradientDescentOptimizer(0.3).minimize(loss)

init = tf.global_variables_initializer()

sess = tf.Session()
sess.run(init)
for i in range(301):
    sess.run(train)
    if i%30 == 0:
        print i, sess.run(w), sess.run(b)

运行之,得到结果如下:

$ python test1.py 

0 [0.05677211] [0.7462111]
30 [0.39830765] [0.8555181]
60 [0.467565] [0.8177076]
90 [0.48965475] [0.8056479]
120 [0.49670038] [0.80180144]
150 [0.4989476] [0.80057454]
180 [0.4996643] [0.8001833]
210 [0.49989298] [0.8000584]
240 [0.49996585] [0.80001867]
270 [0.49998906] [0.800006]
300 [0.4999965] [0.8000019]

可以看出,w 和 b 分别非常接近

0.5
0.8
,成功了。