tensorflow入门学习,继续巩固,一行一行写出线性回归代码(六)
发表于: 2018-06-25 23:35:17 | 已被阅读: 37 | 分类于: tensorflow
经过第四节和第五节的总结,对 tensorflow 的认识越来越深了。现在觉得它有点像一门特殊的编程语言,如果想使用它,就得先了解它的语法(规则)。虽然说第二节被称为 tensorflow 界的 “hello world”,我还是希望能够利用 tensorflow 做些自认为简单
的事情。所以,本节先通过 python 的 numpy 模块生成一个线性函数,并且用 tensorflow 逼近它。
生成数据
先生成 100 个随机数,随机数类型为 float32,然后,做一个斜率为 0.5,偏置为 0.8 的直线,python 代码如下
#encoding=utf8
import tensorflow as tf
import numpy as np
xdata = np.random.rand(100).astype(np.float32)
ydata = xdata*0.5 + 0.8
可以参照
#encoding=utf8
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
xdata = np.random.rand(100).astype(np.float32)
ydata = xdata*0.5 + 0.8
plt.figure()
plt.plot(xdata,xdata)
plt.show()
执行,直线图如下:
描述 tensorflow 模型
模型依然是简单的
y = w*x + b
权值矩阵 w 默认给了 -1.0 到 1.0 范围内的随机数,偏置 b 则默认全是 0。损失函数定义为计算值
w = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
b = tf.Variable(tf.zeros([1]))
y = w*xdata + b
loss = tf.reduce_mean(tf.square(y-ydata))
train = tf.train.GradientDescentOptimizer(0.3).minimize(loss)
init = tf.global_variables_initializer()
建立 session 会话,训练模型
之前几节提到,tensorflow 是以计算图为元素的,本节的代码都是建立在默认图
sess = tf.Session()
sess.run(init)
for i in range(301):
sess.run(train)
if i%30 == 0:
print i, sess.run(w), sess.run(b)
代码共训练 301 次,每 30 次输出一次结果。这里需要说明的是,
#encoding=utf8
import tensorflow as tf
import numpy as np
xdata = np.random.rand(100).astype(np.float32)
ydata = xdata*0.5 + 0.8
w = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
b = tf.Variable(tf.zeros([1]))
y = w*xdata + b
loss = tf.reduce_mean(tf.square(y-ydata))
train = tf.train.GradientDescentOptimizer(0.3).minimize(loss)
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
for i in range(301):
sess.run(train)
if i%30 == 0:
print i, sess.run(w), sess.run(b)
运行之,得到结果如下:
$ python test1.py
0 [0.05677211] [0.7462111]
30 [0.39830765] [0.8555181]
60 [0.467565] [0.8177076]
90 [0.48965475] [0.8056479]
120 [0.49670038] [0.80180144]
150 [0.4989476] [0.80057454]
180 [0.4996643] [0.8001833]
210 [0.49989298] [0.8000584]
240 [0.49996585] [0.80001867]
270 [0.49998906] [0.800006]
300 [0.4999965] [0.8000019]
可以看出,w 和 b 分别非常接近