tensorflow学习,循环神经网络(RNN)相关的函数简介(25)

本节,将介绍 tensorflow 实现循环神经网络 RNN 的主要函数。

实现 RNN 的基本单元 RNNCell


RNNCell 是 tensorflow 中的循环神经网络的基本单元,它是一个抽象类,本身不能实例化。它的两个子类,一个 BasicRNNCell,另一个BasicLSTMCell,分别对应经典循环神经网络,和长短记忆循环神经网络。

学习 RNNCell 要重点关注三个地方:

  • 类方法 call
  • 类属性 state_size
  • 类属性 output_size

简单的说,call方法就是用来计算隐状态的。关于隐状态可以参考前面两节(RNNLSTM)。而state_sizeoutput_size则表示隐状态的大小和输出向量的大小。

output, next_state = call(input, state)
通常 input 的形状是 [batch_size, input_size],所以隐状态的形状 [batch_size, state_size],输出形状[batch_size, output_size]。

定义经典 RNN 单元的方法

rnnCell = tf.nn.rnn_cell.BasicRNNCell(num_units=128)
print rnnCell.state_size
# 应 state_size = 128

定义 LSTM 单元的方法

lstmCell = tf.nn.rnn_cell.BasicLSTMCell(num_units=128)
print lstmCell.state_size
# 应 state_size = LSTMStateTuple(c=128, h=128)

多层循环神经网络:MultiRNNCell


很多时候,单层 RNN 的能力有限,需要多层 RNN,在 tensorflow 中,可以使用 tf.nn.rnn_cell.MultiRNNCell 函数建立多层的 RNN,下面是一个示例小 demo

import tensorflow as tf
import numpy as np

# 创建单个cell并堆叠多层
def get_a_cell(lstm_size, keep_prob):
    rnn = tf.nn.rnn_cell.BasicRNNCell(num_units=128)
    return rnn
# 建立 3 层
cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell() for _ in range(3)])

这里 cell 的 state_size 为 (128,128,128),表示有 3 个隐状态,每个隐状态大小为 128。

MultiRNNCell 也是 RNNCell 的子类,所以它也有 call 方法,和 state_size, output_size 属性。

使用 dynamic_rnn 展开时间维度


对于单个 RNNCell,使用它的 call 方法进行运算时,只在序列时间是前进了一步。如使用 x1,h0 计算得到 h1,根据 x2,h1 计算得到 h2等。如果序列长度为 n,就需要调用 n 次 call 函数。tensorflow 提供了 tf.nn.dynamic_rnn 函数,等价于调用 n 次 call 函数。即通过 {h0, x1, x2, x3, ...} 直接得到 {h1, h2, ...}

outputs, state = tf.nn.dynamic_rnn(cell, inputs)

至此,建立循环神经网络的几个比较重要的 tensorflow 函数就介绍完了,下一节将尝试建立 RNN 网络,训练其作诗。

本节主要参考《21个项目玩转深度学习》。
阅读更多:   tensorflow
仅有 1 条评论
  1. […] 上一节介绍了 tensorflow 关于循环神经网络的方法,接下来几节,将建立LSTM网络,并且训练之,让电脑拥有写诗的能力。在建立网络前,我们需要先确定数据的输入方式,只有确定了数据特点,才能建立高效好用的网络,本节,将分析需要用到的数据特点,并且给出相应得 python 代码。 […]

添加新评论

icon_redface.gificon_idea.gificon_cool.gif2016kuk.gificon_mrgreen.gif2016shuai.gif2016tp.gif2016db.gif2016ch.gificon_razz.gif2016zj.gificon_sad.gificon_cry.gif2016zhh.gificon_question.gif2016jk.gif2016bs.gificon_lol.gif2016qiao.gificon_surprised.gif2016fendou.gif2016ll.gif