我要努力工作,加油!

tensorflow入门学习,保存训练好的权值和偏置,并且再导入到空网络模型(13)

		发表于: 2018-07-02 22:58:45 | 已被阅读: 27 | 分类于: tensorflow
		
经过前面几节的学习,现在已经可以完成一些简单的回归例子了,例如第11节中的非线性回归。对网络的训练的目的最终是为了让其独立预测数据集本身具有的特性,那么保存下好不容易训练出的网络权值和偏置是非常重要的,可以说,训练后的权值和偏置保存着网络的学习记忆。本节将介绍如何保存,以及如何导入到空网络模型中。

保存训练后的权值和偏置


这里以保存第11节的非线性回归例子中的权值和偏置为例,其实非常简单,tensorflow 提供了

tf.train.Saver()
方法来保存这些数据。只需在 session 前构建该方法,训练后保存即可。

# 定义 saver
...
saver = tf.train.Saver()
...
saver.save(sess, 'net_w_b/save_w_b.ckpt') # 保存到 net_w_b文件夹里的 save_w_b.ckpt 文件里

导入到空网络模型里


要想将保存下来的权值和偏置导入到空的网络模型里使用,显然需要网络模型与训练时的保持一致,无论是数据类型和是形状,以及计算流程,都应该一致,否则即使导入成功,也是没有什么意义的。

#encoding=utf8
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt  
from personalTools import AddLayer

# 构建网络
x = tf.placeholder(tf.float32)
    # 隐藏层
layer1 = AddLayer(x, 1, 10, tf.nn.relu)
    # 输出层
y = AddLayer(layer1, 10, 1) 

# 定义 saver
saver = tf.train.Saver()

# 初始化
init = tf.global_variables_initializer()

# 训练
with tf.Session() as sess:
    sess.run(init)
    saver.restore(sess, 'net_w_b/save_w_b.ckpt')
    print sess.run(y, feed_dict={x:[[2.]]})         # 当 x=2 时,观察结果,用来判断是否导入成功

以上代码文件名:

empty.py
,可以看出,由于不需要训练了,所以我们不需要生成数据,不需要定义损失函数和训练方法,只需要定义权值和偏置的数据类型和形状,以及网络模型即可。我们在第11节非线性回归的是
y = x^2 - 0.5
,因此,当 x=2 时,结果应该在 1.5 附近。

我们执行

empty.py
,得到:

$ python empty_t.py 
[[1.5684248]]

成功了。