上一节介绍了 tensorflow 关于循环神经网络
的方法,接下来几节,将建立LSTM
网络,并且训练之,让电脑拥有写诗的能力。在建立网络前,我们需要先确定数据的输入方式,只有确定了数据特点,才能建立高效好用的网络,本节,将分析需要用到的数据特点,并且给出相应得 python 代码。
char RNN 基本功能
基础理论这里不说,我们直接以一个实例来了解char RNN
的基本功能。例如有一句很简单的话:
hello!
现在,如果输入序列是{h,e,l,l,o}
,则期望输出为{e,l,l,o,!}
,即,给定一个字符,网络可以生成下一个最有可能的字符,以此类推,可以生成任意长度的文字。如下图(此图的含义可参考第 23 节)
先不谈标点等符号,对于英文,由于它只有 26 个字母,所以建立 26 位的 one-hot 向量即可映射所有单词。但是汉字的种类比较多,可能导致模型过大,所以有以下两种优化方法:
- 取最常用的 N 个汉字。剩下的所有汉字变为单独一类,可以用特殊字符
XX
代替。 - 输入时,可以增加一层 embedding 层。(这个本节先不谈)
至于输出层,中英文都是一样的,都相当于 N 分类问题。
中文数据的处理
接下来计划用于训练网络的数据集是一本诗集,它的大致内容如下:
1. 取最常出现的前 N 个汉字
经过上面的分析,为了减小数据规模,我们只需要对最常用的汉字,而不是所有汉字建模,所以第一步就是对诗集里所有汉字做统计,只有最常出现的前 N 个汉字单独做一类,其他所有不常出现的汉字共同做 1 类。
假设诗集的内容为 text
,可以如下写 python 代码:
vocab = set(text) # 只要独一无二的,重复的只要一个
vocab_count = {}
for word in vocab:
vocab_count[word] = 0
for word in text:
vocab_count[word] += 1 # 某个字出现一次,就加一。这样,它出现的次数越多,对应的值越大
vocab_count_list = []
for word in vocab_count:
vocab_count_list.append((word, vocab_count[word]))
vocab_count_list.sort(key=lambda x: x[1], reverse=True) # lambda 是一个隐函数,x是随便写的,写成其他值也可以,reverse=True 表示降序
# 这个函数的意思是排序,按照 list 的第二个(0表示第一个)元素排序,并且是逆向排序(从大到小)
if len(vocab_count_list) > max_vocab:
vocab_count_list = vocab_count_list[:max_vocab] # 只要前 max_vocab 个(因为排序了,所以目的就是要最常用的 max_vocab 个汉字)
vocab = [x[0] for x in vocab_count_list] # 只要汉字,不要出现的次数了
这样,我们就得到了每一个字出现的次数从高到低的排序了,取前 N 个汉字就方便了。
2. 汉字和数之间的互换方法
建立模型时,数字比汉字更加适合做运算,因此我们需要建立汉字和数字的映射关系,即,每个汉字都有唯一一个数字与自己对应。经过步骤1
的处理,其实这种映射关系很简单。
# 汉字 -> 数字
word_to_int_table = {c: i for i, c in enumerate(vocab)} # {'字':index}
# 数字 -> 汉字
int_to_word_table = dict(enumerate(vocab)) # {index:'字'}
这样,我们就建立了数字与汉字的映射关系,一句话由若干个汉字组成,现在根据映射表,一句话就可以转换成一个数组了。
所以我们可以将这整个过程封装成一个类,以下代码文件名取为read_utils.py
。(也可取其他名字,不过以后几节做工作 import 时,名字也需要修改。否则会提示 import 失败)
#encoding=utf8
import numpy as np
import copy
import time
import tensorflow as tf
import pickle
# seqs 一个 batch 内的句子条数,steps 表示每个句子的长度
def batch_generator(arr, n_seqs, n_steps):
arr = copy.copy(arr)
batch_size = n_seqs * n_steps
n_batches = int(len(arr) / batch_size)
arr = arr[:batch_size * n_batches] # 这 3 句的目的就是确保 arr 能够正好分为 batch_size 组
arr = arr.reshape((n_seqs, -1))
while True:
np.random.shuffle(arr)
# print(arr.shape, arr.shape[1])
for n in range(0, arr.shape[1], n_steps):
# print(n)
x = arr[:, n:n + n_steps]
y = np.zeros_like(x)
y[:, :-1], y[:, -1] = x[:, 1:], x[:, 0] # y 比 x 延长一个字(例如有 hello, x为hell,则y为ello)
yield x, y
class TextConverter(object):
def __init__(self, text=None, max_vocab=5000, filename=None):
if filename is not None:
with open(filename, 'rb') as f:
self.vocab = pickle.load(f)
else:
vocab = set(text) # 只要独一无二的,重复的只要一个
print(len(vocab))
# max_vocab_process
vocab_count = {}
for word in vocab:
vocab_count[word] = 0
for word in text:
vocab_count[word] += 1 # 某个字出现的次数越多,对应的值越大
vocab_count_list = []
for word in vocab_count:
vocab_count_list.append((word, vocab_count[word]))
vocab_count_list.sort(key=lambda x: x[1], reverse=True) # lambda 是一个隐函数,x是随便写的,写成其他值也可以,reverse=True 表示降序
# 这个函数的意思是排序,按照 list 的第二个(0表示第一个)元素排序,并且是逆向排序(从大到小)
if len(vocab_count_list) > max_vocab:
vocab_count_list = vocab_count_list[:max_vocab] # 只要前 max_vocab 个(因为排序了,所以目的就是要最常用的 max_vocab 个汉字)
vocab = [x[0] for x in vocab_count_list] # 只要汉字,不要出现的次数了
# print(vocab)
self.vocab = vocab
self.word_to_int_table = {c: i for i, c in enumerate(self.vocab)} # {'字':index}
self.int_to_word_table = dict(enumerate(self.vocab)) # {index:'字'}
@property
def vocab_size(self):
return len(self.vocab) + 1
def word_to_int(self, word):
if word in self.word_to_int_table:
return self.word_to_int_table[word]
else:
return len(self.vocab)
def int_to_word(self, index):
if index == len(self.vocab):
return 'XX'
elif index < len(self.vocab):
return self.int_to_word_table[index]
else:
raise Exception('Unknown index!')
def text_to_arr(self, text):
arr = []
for word in text:
arr.append(self.word_to_int(word))
return np.array(arr)
def arr_to_text(self, arr):
words = []
for index in arr:
words.append(self.int_to_word(index))
return "".join(words)
def save_to_file(self, filename):
with open(filename, 'wb') as f:
pickle.dump(self.vocab, f)
经过上面的分析,这段代码虽然略长,但是应该很好理解。
测试
现在,我们写点代码测试下该类。
import codecs
from read_utils import TextConverter, batch_generator
if __name__ == "__main__":
with codecs.open("data/poetry.txt", encoding='utf-8') as f:
text = f.read()
converter = TextConverter(text, 3500)
arr = converter.text_to_arr(text)
g = batch_generator(arr, 100, 100)
x,y=next(g)
for i in x:
print('--------')
print(converter.arr_to_text(i))
print('-------------------------------')
print(converter.arr_to_text(x[0]))
print(converter.arr_to_text(y[0]))
这段代码非常简单了,就是调用我们封装好的方法,最后输出结果如下,可以看出,y 比 x 滞后一个字。
下一节,可以建立网络训练了。