tensorflow学习,循环神经网络(RNN)相关的函数简介(25)
发表于: 2018-07-25 22:53:29 | 已被阅读: 28 | 分类于: tensorflow
本节,将介绍 tensorflow 实现循环神经网络 RNN 的主要函数。
实现 RNN 的基本单元 RNNCell
学习
- 类方法 call
- 类属性 state_size
- 类属性 output_size
简单的说,
output, next_state = call(input, state) 通常 input 的形状是 [batch_size, input_size],所以隐状态的形状 [batch_size, state_size],输出形状[batch_size, output_size]。
定义经典 RNN 单元的方法
rnnCell = tf.nn.rnn_cell.BasicRNNCell(num_units=128)
print rnnCell.state_size
# 应 state_size = 128
定义 LSTM 单元的方法
lstmCell = tf.nn.rnn_cell.BasicLSTMCell(num_units=128)
print lstmCell.state_size
# 应 state_size = LSTMStateTuple(c=128, h=128)
多层循环神经网络:MultiRNNCell
很多时候,单层 RNN 的能力有限,需要多层 RNN,在 tensorflow 中,可以使用
import tensorflow as tf
import numpy as np
# 创建单个cell并堆叠多层
def get_a_cell(lstm_size, keep_prob):
rnn = tf.nn.rnn_cell.BasicRNNCell(num_units=128)
return rnn
# 建立 3 层
cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell() for _ in range(3)])
这里 cell 的 state_size 为 (128,128,128),表示有 3 个隐状态,每个隐状态大小为 128。
MultiRNNCell 也是 RNNCell 的子类,所以它也有 call 方法,和 state_size, output_size 属性。
使用 dynamic_rnn 展开时间维度
对于单个 RNNCell,使用它的 call 方法进行运算时,只在序列时间是前进了一步。如使用 x1,h0 计算得到 h1,根据 x2,h1 计算得到 h2等。如果序列长度为 n,就需要调用 n 次 call 函数。tensorflow 提供了
outputs, state = tf.nn.dynamic_rnn(cell, inputs)
至此,建立循环神经网络的几个比较重要的 tensorflow 函数就介绍完了,下一节将尝试建立 RNN 网络,训练其作诗。
本节主要参考《21个项目玩转深度学习》。