本节,将介绍 tensorflow 实现循环神经网络 RNN 的主要函数。
实现 RNN 的基本单元 RNNCell
RNNCell
是 tensorflow 中的循环神经网络的基本单元,它是一个抽象类,本身不能实例化。它的两个子类,一个 BasicRNNCell
,另一个BasicLSTMCell
,分别对应经典循环神经网络,和长短记忆循环神经网络。
学习 RNNCell
要重点关注三个地方:
- 类方法 call
- 类属性 state_size
- 类属性 output_size
简单的说,call
方法就是用来计算隐状态的。关于隐状态可以参考前面两节(RNN和LSTM)。而state_size
和output_size
则表示隐状态的大小和输出向量的大小。
output, next_state = call(input, state)
通常 input 的形状是 [batch_size, input_size],所以隐状态的形状 [batch_size, state_size],输出形状[batch_size, output_size]。
定义经典 RNN 单元的方法
rnnCell = tf.nn.rnn_cell.BasicRNNCell(num_units=128)
print rnnCell.state_size
# 应 state_size = 128
定义 LSTM 单元的方法
lstmCell = tf.nn.rnn_cell.BasicLSTMCell(num_units=128)
print lstmCell.state_size
# 应 state_size = LSTMStateTuple(c=128, h=128)
多层循环神经网络:MultiRNNCell
很多时候,单层 RNN 的能力有限,需要多层 RNN,在 tensorflow 中,可以使用 tf.nn.rnn_cell.MultiRNNCell
函数建立多层的 RNN,下面是一个示例小 demo
import tensorflow as tf
import numpy as np
# 创建单个cell并堆叠多层
def get_a_cell(lstm_size, keep_prob):
rnn = tf.nn.rnn_cell.BasicRNNCell(num_units=128)
return rnn
# 建立 3 层
cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell() for _ in range(3)])
这里 cell 的 state_size 为 (128,128,128),表示有 3 个隐状态,每个隐状态大小为 128。
MultiRNNCell 也是 RNNCell 的子类,所以它也有 call 方法,和 state_size, output_size 属性。
使用 dynamic_rnn 展开时间维度
对于单个 RNNCell,使用它的 call 方法进行运算时,只在序列时间是前进了一步。如使用 x1,h0 计算得到 h1,根据 x2,h1 计算得到 h2等。如果序列长度为 n,就需要调用 n 次 call 函数。tensorflow 提供了 tf.nn.dynamic_rnn
函数,等价于调用 n 次 call 函数。即通过 {h0, x1, x2, x3, ...} 直接得到 {h1, h2, ...}
outputs, state = tf.nn.dynamic_rnn(cell, inputs)
至此,建立循环神经网络的几个比较重要的 tensorflow 函数就介绍完了,下一节将尝试建立 RNN 网络,训练其作诗。
本节主要参考《21个项目玩转深度学习》。
[…] 上一节介绍了 tensorflow 关于循环神经网络的方法,接下来几节,将建立LSTM网络,并且训练之,让电脑拥有写诗的能力。在建立网络前,我们需要先确定数据的输入方式,只有确定了数据特点,才能建立高效好用的网络,本节,将分析需要用到的数据特点,并且给出相应得 python 代码。 […]