tensorflow学习,几个基本概念:图,会话,feed,fetch等(一)
发表于: 2018-06-19 23:03:37 | 已被阅读: 46 | 分类于: tensorflow
基本概念
tensorflow 的安装可以参考:
- 在 tensorflow 中,一个
图(grapg)
表示一个计算任务 - 执行图(任务)的上下文,称之为
会话(session)
- tensorflow 的数据单元当然是
张量(tensor)
- tensorflow 通过
变量(variable)
维护状态 - 图中的节点被称之为
op (operation 的缩写)
. 一个 op 获得 0 个或多个 Tensor , 执行计算, 产生 0 个或多个 Tensor . 每个 Tensor 是一个类型化的多维数组 - tensorflow 使用
feed
和fetch
可以为任意的操作赋值,或者从中获取数据。
接下来本节几部分内容都是围绕这几句话阐述的。
tensorflow 的一次基础 python 代码流程
插入一点代码,结合着,再理解上面几句话。
#encoding=utf8
import tensorflow as tf
# 使用 tensorflow 的默认图 tf
matrix1 = tf.constant([[3., 3.]])
matrix2 = tf.constant([[2.],[2.]])
# 创建一个矩阵乘法 matmul op , 把 'matrix1' 和 'matrix2' 作为输入.
product = tf.matmul(matrix1, matrix2)
# 现在我们在默认图 tf 上添加了 3 个 op(matrix1, matrix2, product)
# 按照上面说的,tensorflow 计算需要在 session 里进行
# 启动 tf 默认图
session = tf.Session()
# 执行计算
res = session.run(product)
print res
# 使用完要关闭 session
session.close()
结合注释,应该大概了解 tensorflow 的一次 python 代码基本流程了。
tensorflow 的变量
在实战项目中,变量(tensor)可以存储神经网络的权重,通过重复运行训练图,更新这个 tensor.
#encoding=utf8
import tensorflow as tf
# 使用 tensorflow 的默认图 tf
# 创建变量,初始化为标量 0
state = tf.Variable(0, name='counter')
# 常量 op
one = tf.constant(1)
# 加法 op
tmp = tf.add(state, one)
# assign() 操作是图所描绘的表达式的一部分, 正如 add() 操作一样.
# 在调用 run() 执行表达式之前, 它并不会真正执行赋值操作.
new = tf.assign(state, tmp)
# 初始化 op 到图中
initOp = tf.global_variables_initializer()
# 创建 session
session = tf.Session()
session.run(initOp)
print session.run(state) # 初始值
for _ in range(3):
session.run(new)
print session.run(state)
相信注释已经非常清楚了。
tensorflow 的 fetch
fetch,顾名思义,就是取数据而已,在上面的例子中,我们都是取出单个 tensor,tensorflow 支持同时取出多个 tensor,以下是一个实例,是本节第一个实例改写的:
#encoding=utf8
import tensorflow as tf
# 使用 tensorflow 的默认图 tf
matrix1 = tf.constant([[3., 3.]])
matrix2 = tf.constant([[2.],[2.]])
# 添加了 sum op
product = tf.matmul(matrix1, matrix2)
sum = tf.add(matrix1, matrix2)
session = tf.Session()
# 执行计算 乘积 与 和
res = session.run([product, sum])
print res
# 使用完要关闭 session
session.close()
tensorflow 的 feed
feed,看其名字猜测就是给某些变量(tensor)传递数据,官方的解释是:
feed 使用一个 tensor 值临时替换一个操作的输出结果. 你可以提供 feed 数据作为 run() 调用的参数. feed只在调用它的方法内有效, 方法结束, feed 就会消失.
还是以实例说明:
#encoding=utf8
import tensorflow as tf
# 使用 tensorflow 的默认图 tf
input1 = tf.placeholder(tf.types.float32)
input2 = tf.placeholder(tf.types.float32)
# 此时,input1 和 input2 只说明了数据类型,还没有传递数据
output = tf.mul(input1, input2)
with tf.Session() as sess:
# feed 传递数据给 output op 的 input1 和 input2
print sess.run([output], feed_dict={input1:[7.], input2:[2.]})