CS224N学习笔记(七)—— RNN、LSTM和GRU

一、语言模型

语言模型计算特定序列中多个单词的出现概率。一个 m 个单词的序列 \{w_{1},…,w_{m}\} 的概率定义为 P(w_{1},…,w_{m})。

传统的语言模型为了简化问题,引入了马尔科夫假设,即句子的概率通常是通过待预测单词之前长度为n的窗口建立条件概率来预测:

\begin{equation} P(w_1,…,w_m) = \prod_{i=1}^{i=m} P(w_{i} | w_1, …, w_{i-1}) \approx \prod_{i=1}^{i=m} P(w_{i} | w_{i-n}, …, w_{i-1}) \label{eqn:nat_model} \end{equation}

简单来说就是每次考虑都当前词前面所有的词的信息显然对是很冗余的,因为离当前词越远的词通常和当前词没多少联系了,那么就干脆我们就取离当前词最近的几个词的信息来替代当前词前面所有词的信息。嗯,人之常情,很好理解。

例如,考虑一种情况,有一篇文章是讨论西班牙和法国的历史,然后在文章的某个地方,你读到一句话 “The two country went on a battle”;显然,这句话中提供的信息不足以确定文章所讨论的两个国家的名称(法国和西班牙)。


二、循环神经网络

传统的翻译模型只能以有限窗口大小的前 n 个单词作为条件进行语言模型建模,循环神经网络与其不同,RNN 有能力以语料库中所有前面的单词为条件进行语言模型建模。

下图展示的 RNN 的架构,其中矩形框是在一个时间步的一个隐藏层 :

如上图所示:

  • RNNs包含输入单元(Input units),输入集标记为{x_0, x_1, ..., x_t, x_{t+1},...},是含有 T 个单词的语料库对应的词向量

  • 输出单元(Output units)的输出集标记为{y_0,y_1,...,y_t,y_{t+1}.,..},是在每个时间步 t 全部单词的概率分布输出,用公式表示为\widehat{y} = softmax(W_{(hy)}h_{t})

  • RNNs还包含隐藏单元(Hidden units),{h_0,h_1,...,h_t,h_{t+1},...},表示每个时间步 t 的隐藏层的输出特征的计算关系,可以理解成RNN的记忆

准确来讲,RNN其实就是一个网络的多次复用:

其中:

如果将每一个复用都展开的话:

RNN的一个实际小例子:


三、BPTT

其实从理解上RNN其实很好理解,毕竟很直观,因为想囊获前面所有的信息,所以每次都计算一个h_{t-1}(这里使用网络得到,但是其实最简单的囊获前面所有信息的方式就是累加,只是效果不好而已)。

但RNN麻烦的地方就在于它的求导,由于循环操作,这就导致更新权值的时候会牵一发而动全身。

如下图的RNN:

首先写出前向公式:
\mathbf{h}_t = \phi(\mathbf{W}_{hx} \mathbf{x}_t + \mathbf{W}_{hh} \mathbf{h}_{t-1})

\mathbf{o}_t = \mathbf{W}_{yh} \mathbf{h}_{t}

损失函数:
L = \frac{1}{T} \sum_{t=1}^T \ell (\mathbf{o}_t, y_t)

其中:T是序列长度,在这里就是3(为了好理解,此次的推导会先按T=3进行,然后再推广到一般)

3.1 对 \mathbf{W}_{yh} 的求导

对于 \mathbf{W}_{yh} 的求导显然不难,但是要注意的是,由于RNN是循环的网络,所有导致每一个时刻的o都对 \mathbf{W}_{yh} 有关,所以求导会变成加和的形式:

\frac{\partial L}{\partial \mathbf{W}_{yh}} = \sum_{t=1}^T \text{prod}(\frac{\partial L}{\partial \mathbf{o}_t}, \frac{\partial \mathbf{o}_t}{\partial \mathbf{W}_{yh}}) = \sum_{t=1}^T \frac{\partial L}{\partial \mathbf{o}_t} \mathbf{h}_t^\top

矩阵求导的具体矩阵应该怎么乘可以根据左右两边的维度来确定,如在上式中,由于 \frac{\partial L}{\partial \mathbf{W}_{yh}}肯定是(output_num, hidden_num)维度,\frac{\partial L}{\partial \mathbf{o}_t}是标量对向量的求导,不难得到其维度为(output_num, 1),那么剩下的h_t就只能是转置的形式。当然,由于这里不涉及行向量对行向量求导或者列向量对列向量求导,所以可以用简单的矩阵求导法则得到一样的结果,但如果出现行向量对行向量求导或者列向量对列向量求导的情况,还是只能使用这样的维度匹配。


3.2 对\mathbf{W}_{hx} 的求导

对于\mathbf{W}_{hx} 的求导就会变得更复杂一点,因为不仅是o,隐藏层h之间也存在联系。

先考虑最简单的情况,当我们对h_3求导时,由于没有其他ho的牵制,所以可以很简单地写出其表达式:

\frac{\partial L}{\partial \mathbf{h}_3} = \text{prod}(\frac{\partial L}{\partial \mathbf{o}_3}, \frac{\partial \mathbf{o}_3}{\partial \mathbf{h}_3} ) = \mathbf{W}_{yh}^\top \frac{\partial L}{\partial \mathbf{o}_3}

那么我们将情况拓展到 h_2,由于h_3也与其有关,所以最后可以写出表达式:

\frac{\partial L}{\partial \mathbf{h}_2} = \text{prod}(\frac{\partial L}{\partial \mathbf{h}_{3}}, \frac{\partial \mathbf{h}_{3}}{\partial \mathbf{h}_2} ) + \text{prod}(\frac{\partial L}{\partial \mathbf{o}_2}, \frac{\partial \mathbf{o}_2}{\partial \mathbf{h}_2} ) = \mathbf{W}_{hh}^\top \frac{\partial L}{\partial \mathbf{h}_{3}} + \mathbf{W}_{yh}^\top \frac{\partial L}{\partial \mathbf{o}_2}

如果我们将对h_3求导时的情况带入的话:

\frac{\partial L}{\partial \mathbf{h}_2} = \mathbf{W}_{hh}^\top \frac{\partial L}{\partial \mathbf{h}_{3}} + \mathbf{W}_{yh}^\top \frac{\partial L}{\partial \mathbf{o}_2} = \mathbf{W}_{hh}^\top \mathbf{W}_{yh}^\top \frac{\partial L}{\partial \mathbf{o}_3} + \mathbf{W}_{yh}^\top \frac{\partial L}{\partial \mathbf{o}_2}

以此类推,我们就可以得到当t < 3时(或者说T时)的通用表达式:
\frac{\partial L}{\partial \mathbf{h}_t} = \text{prod}(\frac{\partial L}{\partial \mathbf{h}_{t+1}}, \frac{\partial \mathbf{h}_{t+1}}{\partial \mathbf{h}_t} ) + \text{prod}(\frac{\partial L}{\partial \mathbf{o}_t}, \frac{\partial \mathbf{o}_t}{\partial \mathbf{h}_t} ) = \mathbf{W}_{hh}^\top \frac{\partial L}{\partial \mathbf{h}_{t+1}} + \mathbf{W}_{yh}^\top \frac{\partial L}{\partial \mathbf{o}_t}

那么将其他的h带入通用形式变为:

\frac{\partial L}{\partial \mathbf{h}_t} = \sum_{i=t}^T {(\mathbf{W}_{hh}^\top)}^{T-i} \mathbf{W}_{yh}^\top \frac{\partial L}{\partial \mathbf{o}_{T+t-i}}

可以看到这里有一个阶乘的项,这也就是导致RNN更容易产生梯度爆炸和消失的原因。

求出\frac{\partial L}{\partial \mathbf{h}_t} 的通用形式之后,对\mathbf{W}_{hx} 的求导就变得简单了:

\frac{\partial L}{\partial \mathbf{W}_{hx}} = \sum_{t=1}^T \text{prod}(\frac{\partial L}{\partial \mathbf{h}_t}, \frac{\partial \mathbf{h}_t}{\partial \mathbf{W}_{hx}}) = \sum_{t=1}^T \frac{\partial L}{\partial \mathbf{h}_t} \mathbf{x}_t^\top

  • 防止梯度爆炸:
    一种暴力的方法是,当梯度的长度大于某个阈值的时候,将其缩放到某个阈值。虽然在数学上非常丑陋,但实践效果挺好。其直观解释是,在一个只有一个隐藏节点的网络中,损失函数和权值w偏置b构成error surface,其中有一堵墙如下图所示,每次迭代梯度本来是正常的,一次一小步,但遇到这堵墙之后突然梯度爆炸到非常大,可能指向一个莫名其妙的地方(实线长箭头)。但缩放之后,能够把这种误导控制在可接受的范围内(虚线短箭头)。但这种trick无法推广到梯度消失,因为你不想设置一个最低值硬性规定之前的单词都相同重要地影响当前单词。
伪代码
  • 减缓梯度消失
  1. 不去随机初始化 W_{hh},而是初始化为单位矩阵
  2. 使用 Rectified Linear(ReLU)单元代替 sigmoid 函数,ReLU 的导数是 0 或者 1。这样梯度传回神经元的导数是 1,而不会在反向传播了一定的时间步后梯度变小


3.3 对\mathbf{W}_{hh} 的求导

同理,我们可以轻松地写出\frac{\partial L}{\partial \mathbf{W}_{hh}}

\frac{\partial L}{\partial \mathbf{W}_{hh}} = \sum_{t=1}^T \text{prod}(\frac{\partial L}{\partial \mathbf{h}_t}, \frac{\partial \mathbf{h}_t}{\partial \mathbf{W}_{hh}}) = \sum_{t=1}^T \frac{\partial L}{\partial \mathbf{h}_t} \mathbf{h}_{t-1}^\top



到这里,我们基本上完成了RNN的求导,通过如下公式,便可以对RNN进行训练:

\mathbf{W}_{hx} = \mathbf{W}_{hx} - \eta \frac{\partial L}{\partial \mathbf{W}_{hx}}

\mathbf{W}_{hh} = \mathbf{W}_{hh} - \eta \frac{\partial L}{\partial \mathbf{W}_{hh}}

\mathbf{W}_{yh} = \mathbf{W}_{yh} - \eta \frac{\partial L}{\partial \mathbf{W}_{yh}}


四、双向RNN

这部分就不展开了,主要因为理解上其实不难,就是在原来从左到右进行训练的基础上,又加了一个从右到左的顺序,整体原理和推导变化不大:

通过总结过去和未来词表示来预测下一个词的类别关系:

\begin{eqnarray} \overrightarrow{\,h}_{t} &=& f(\overrightarrow{\,W}x_{t}+\overrightarrow{\,V}\overrightarrow{\,h}_{t-1}+\overrightarrow{\,b}) \tag{17} \\ \overleftarrow{h}_{t} &=& f(\overleftarrow{W}x_{t}+\overleftarrow{V}\overleftarrow{h}_{t+1}+\overleftarrow) \tag{18} \\ \widehat{y}_{t} &=& g(Uh_{t}+c) = g(U[\overrightarrow{\,h}_{t};\overleftarrow{h}_{t}]+c) \tag{19} \end{eqnarray}

虽然看起来复杂,不过因为大部分库都已经封装了双向的操作,所以日常使用还是很轻松的,而且推导上其实也和原来的一样,只是要求的东西变多了而已。


五、LSTM

LSTM的全称是Long-Short-Term-Memories,其在原先RNN的基础上引入了门机制:

\begin{eqnarray} i_{t} &=& \sigma(W^{(i)}x_{t}+U^{(i)}h_{t-1}) \tag{Input gate} \\ {f}_{t} &=& \sigma(W^{(f)}x_{t}+U^{(f)}h_{t-1}) \tag{Forget gate} \\ {o}_{t} &=& \sigma(W^{(o)}x_{t}+U^{(o)}h_{t-1}) \tag{Output/Exposure gate} \\ \widetilde{c}_{t} &=& tanh(W^{(c)}x_{t}+U^{(c)}h_{t-1}) \tag{New memory cell} \\ {c}_{t} &=& f_{t} \circ c_{t-1} +i_{t} \circ \widetilde{c}_{t} \tag{Final memory cell} \\ {h}_{t} &=& o_{t} \circ tanh(c_{t}) \nonumber \end{eqnarray}

可以看到一共有4个门:

  • Input Gate:我们看到在生成新的记忆之前,新的记忆的生成阶段不会检查新单词是否重要,这需要输入门函数来做这个判断。输入门使用输入单词和过去的隐藏状态来决定输入值是否值得保留,从而决定该输入值是否加入到新的记忆中

  • Forget Gate:这个门与输入门类似,但是它不能决定输入单词是否有用,而是评估过去的记忆单元是否对当前记忆单元的计算有用,因此,忘记门观测输入单词和过去的隐藏状态并生成f_t用来决定是否使用c_{t-1}

  • New memory generation:生成新的记忆的阶段。我们基本上是用输入单词 x_t 和过去的隐藏状态来生成一个包括新单词 x_t 的新的记忆 \widetilde{c}_{t}(这个算不算门也因人而异,我个人是将其计算成门,所以一共是4个门)

  • Output/Exposure Gate:用处是将最终记忆与隐状态分离开来。记忆 c_t 中的信息不是全部都需要存放到隐状态中,隐状态是个很重要的使用很频繁的东西,因此,该门是要评估关于记忆单元 c_t 的哪些部分需要显露在隐藏状态 h_t 中。用于评估的信号是 o_t,然后与 c_t 通过 o_{t} \circ tanh(c_{t})运算得到最终的 h_t(注意,这里的h_t和RNN里面其实不太一样,在LSTM里面,其实c_t才是真正的h_t

其他步骤:

  • Final memory generation:这个阶段首先根据忘记门 f_t 的判断,相应地忘记过去的记忆 c_{t?1}。类似地,根据输入门 i_t 的判断,相应地生成新的记忆 \widetilde{c}_{t}。然后将上面的两个结果相加生成最终的记忆 c_t

例子:

如上图所示,权值随机初始化,输入在右下角,接下来按顺序先输入[3,1,0]:

再输入[4,1,0]:

接下去以此类推:

整体上来看:

LSTM 如此复杂,那么它到底解决了什么问题呢?
观察LSTM的前向公式我们可以看到\begin{eqnarray} {c}_{t} &=& f_{t} \circ c_{t-1} +i_{t} \circ \widetilde{c}_{t} \end{eqnarray}其实就是我们之前的\mathbf{h}_t = \mathbf{W}_{hx} \mathbf{x}_t + \mathbf{W}_{hh} \mathbf{h}_{t-1} 。那么如果对c_{t-1}求导也会得到类似\mathbf{W}_{hh}的阶乘的情况,但是此时\mathbf{W}_{hh}被换成了f_t(当然和之前求RNN的梯度一样,这里同样省略了其他东西,但大体上不影响结论),作为门机制的f_t,它的取值基本上不是0就是1,但大多时候是1,(注意,遗忘门并是不为了解决梯度消失而出现,在最早的LSTM中是不存在遗忘门,而是直接把这里设置为1,这很有残差网络的味道)所以通过这样的结果很好地解决了传统RNN中梯度消失的问题。

当然这里仅仅解决了c_{t-1}处的梯度问题,实际上LSTM在其他的梯度求导上还是容易出问题,这部分还是有待研究。至于遗忘门,查到的资料是说Gers 等人(2000)首先发现如果没有使记忆单元遗忘信息的机制,那么它们可能会无限增长,最终导致网络崩溃。为解决这个问题,他们为这个 LSTM 架构加上了另一个乘法门,即遗忘门。


六、GRU

GRU全称为Gated Recurrent Unit,它是LSTM的简化版变种,就目前的实验来看,GRU在性能上几乎与LSTM持平,但是在资源消耗方面会小一些。

GRU对LSTM的门进行了删减整合,将遗忘门、输入门和输出门换成了更新门和重置门,即下图中的z和r:

\begin{align*} z_{t} &= \sigma(W^{(z)}x_{t} + U^{(z)}h_{t-1})&~\text{(Update gate)}\\ r_{t} &= \sigma(W^{(r)}x_{t} + U^{(r)}h_{t-1})&~\text{(Reset gate)}\\ \tilde{h}_{t} &= \operatorname{tanh}(r_{t}\circ Uh_{t-1} + Wx_{t} )&~\text{(New memory)}\\ h_{t} &= (1 - z_{t}) \circ \tilde{h}_{t} + z_{t} \circ h_{t-1}&~\text{(Hidden state)} \end{align*}

更新门用于控制前一时刻的状态信息被带入到当前状态中的程度,更新门的值越大说明前一时刻的状态信息带入越多。重置门用于控制忽略前一时刻的状态信息的程度,重置门的值越小说明忽略得越多。


参考

  1. LSTM如何来避免梯度弥散和梯度爆炸?
  2. 三次简化一张图:一招理解LSTM/GRU门控机制
  3. LSTM如何解决梯度消失问题
  4. 详解 LSTM
  5. 学界|神奇!只有遗忘门的LSTM性能优于标准LSTM
  6. LSTM详解 反向传播公式推导
  7. CS224n笔记9 机器翻译和高级LSTM及GRU
  8. CS224n自然语言处理与深度学习 Lecture Notes Five
  9. Understanding LSTM Networks
  10. YJango的循环神经网络——实现LSTM
  11. GRU神经网络
  12. mxnet深度学习
  13. GRU与LSTM总结
  14. 李宏毅机器学习
最后编辑于
?著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 214,100评论 6 493
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,308评论 3 388
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 159,718评论 0 349
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,275评论 1 287
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,376评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,454评论 1 292
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,464评论 3 412
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,248评论 0 269
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,686评论 1 306
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,974评论 2 328
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,150评论 1 342
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,817评论 4 337
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,484评论 3 322
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,140评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,374评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,012评论 2 365
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,041评论 2 351

推荐阅读更多精彩内容