我要努力工作,加油!

tensorflow学习,循环神经网络(RNN)相关的函数简介(25)

		发表于: 2018-07-25 22:53:29 | 已被阅读: 28 | 分类于: tensorflow
		
本节,将介绍 tensorflow 实现循环神经网络 RNN 的主要函数。

实现 RNN 的基本单元 RNNCell


RNNCell
是 tensorflow 中的循环神经网络的基本单元,它是一个抽象类,本身不能实例化。它的两个子类,一个
BasicRNNCell
,另一个
BasicLSTMCell
,分别对应经典循环神经网络,和长短记忆循环神经网络。

学习

RNNCell
要重点关注三个地方:

  • 类方法 call
  • 类属性 state_size
  • 类属性 output_size

简单的说,

call
方法就是用来计算隐状态的。关于隐状态可以参考前面两节(
RNN
LSTM
)。而
state_size
output_size
则表示隐状态的大小和输出向量的大小。

output, next_state = call(input, state) 通常 input 的形状是 [batch_size, input_size],所以隐状态的形状 [batch_size, state_size],输出形状[batch_size, output_size]。

定义经典 RNN 单元的方法

rnnCell = tf.nn.rnn_cell.BasicRNNCell(num_units=128)
print rnnCell.state_size
# 应 state_size = 128

定义 LSTM 单元的方法

lstmCell = tf.nn.rnn_cell.BasicLSTMCell(num_units=128)
print lstmCell.state_size
# 应 state_size = LSTMStateTuple(c=128, h=128)

多层循环神经网络:MultiRNNCell


很多时候,单层 RNN 的能力有限,需要多层 RNN,在 tensorflow 中,可以使用

tf.nn.rnn_cell.MultiRNNCell
函数建立多层的 RNN,下面是一个示例小 demo

import tensorflow as tf
import numpy as np

# 创建单个cell并堆叠多层
def get_a_cell(lstm_size, keep_prob):
    rnn = tf.nn.rnn_cell.BasicRNNCell(num_units=128)
    return rnn
# 建立 3 层
cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell() for _ in range(3)])

这里 cell 的 state_size 为 (128,128,128),表示有 3 个隐状态,每个隐状态大小为 128。

MultiRNNCell 也是 RNNCell 的子类,所以它也有 call 方法,和 state_size, output_size 属性。

使用 dynamic_rnn 展开时间维度


对于单个 RNNCell,使用它的 call 方法进行运算时,只在序列时间是前进了一步。如使用 x1,h0 计算得到 h1,根据 x2,h1 计算得到 h2等。如果序列长度为 n,就需要调用 n 次 call 函数。tensorflow 提供了

tf.nn.dynamic_rnn
函数,等价于调用 n 次 call 函数。即通过 {h0, x1, x2, x3, ...} 直接得到 {h1, h2, ...}

outputs, state = tf.nn.dynamic_rnn(cell, inputs)

至此,建立循环神经网络的几个比较重要的 tensorflow 函数就介绍完了,下一节将尝试建立 RNN 网络,训练其作诗。

本节主要参考《21个项目玩转深度学习》。