tensorflow入门学习,保存训练好的权值和偏置,并且再导入到空网络模型(13)
发表于: 2018-07-02 22:58:45 | 已被阅读: 27 | 分类于: tensorflow
经过前面几节的学习,现在已经可以完成一些简单的回归例子了,例如第11节中的非线性回归。对网络的训练的目的最终是为了让其独立预测数据集本身具有的特性,那么保存下好不容易训练出的网络权值和偏置是非常重要的,可以说,训练后的权值和偏置保存着网络的学习记忆。本节将介绍如何保存,以及如何导入到空网络模型中。
保存训练后的权值和偏置
这里以保存第11节的非线性回归例子中的权值和偏置为例,其实非常简单,tensorflow 提供了
# 定义 saver
...
saver = tf.train.Saver()
...
saver.save(sess, 'net_w_b/save_w_b.ckpt') # 保存到 net_w_b文件夹里的 save_w_b.ckpt 文件里
导入到空网络模型里
要想将保存下来的权值和偏置导入到空的网络模型里使用,显然需要网络模型与训练时的保持一致,无论是数据类型和是形状,以及计算流程,都应该一致,否则即使导入成功,也是没有什么意义的。
#encoding=utf8
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from personalTools import AddLayer
# 构建网络
x = tf.placeholder(tf.float32)
# 隐藏层
layer1 = AddLayer(x, 1, 10, tf.nn.relu)
# 输出层
y = AddLayer(layer1, 10, 1)
# 定义 saver
saver = tf.train.Saver()
# 初始化
init = tf.global_variables_initializer()
# 训练
with tf.Session() as sess:
sess.run(init)
saver.restore(sess, 'net_w_b/save_w_b.ckpt')
print sess.run(y, feed_dict={x:[[2.]]}) # 当 x=2 时,观察结果,用来判断是否导入成功
以上代码文件名:
我们执行
$ python empty_t.py
[[1.5684248]]
成功了。