二三事

“みんなみんな大好き!”

0%

自注意力与 Transformer 的数学原理

若移动端访问时未显示侧栏,可点击左侧按钮强制按桌面端渲染。

* 本文的示意图使用了 AI 辅助创作。

RNN

背景引入

在前文 MLP 与 BP 算法的数学原理 中,我们从数学上推导了普通前馈神经网络的误差反向传播算法,并最终编码实现了一个能够训练并完成分类任务的 MLP。

MLP 是最经典、最基础的一种前馈神经网络,它只接受固定维度的张量 ,并在经过内部若干层全连接层的非线性变换 后,最终输出固定维度的张量 。对于能够确定输入与输出张量维度的任务,例如图像识别、缺陷检测等,输入输出张量维度固定是能够接受的。但是,对于序列数据,特别是那些具有强时间相关性的序列数据,譬如股票指数的时间序列、自然语言等,MLP 便显得力不从心了,主要体现在:

  • 输入维度固定,意味着无法处理可变长序列;
  • 固定维度与无记忆性的假设,使得模型更难捕捉到序列元素间的依赖关系。

一个简单的改进就是 RNN(Recurrent Neural Network,循环神经网络)。下面介绍 RNN,并做简要的数学推导。


参考文献:

  1. Y. Bengio, P. Simard and P. Frasconi, "Learning long-term dependencies with gradient descent is difficult," in IEEE Transactions on Neural Networks, vol. 5, no. 2, pp. 157-166, March 1994, doi: 10.1109/72.279181.

基本结构

RNN 的核心公式是如下的递归表达式, 其中 是时间步、 为隐状态, 是两个能够被训练与优化的权重参数张量, 是能够被训练与优化的偏置。

如果没有偏置项, 只能够做到线性变换,而不能做到仿射变换。

Recurrent Neural Network

我们认为, 包含了 的信息,因为 式本质上是一个递推公式,在完全展开后可见 由变量 唯一决定。RNN 通过显式地对时间步与历史状态进行建模,能够累积长期信息,这是 RNN 善于处理序列问题的一个重要因素。

这里要注意的关键点是 同时间步完全无关,即对于任何时间步,权重参数都是全局共享的,这是 RNN 擅长 NLP 之类的长建模序列依赖的另一个重要因素。

从权重规模看,RNN 的权重规模是固定的,不随输入序列长度 而变化,可视为常数级别复杂度的空间开销;而如果强行用 MLP 处理序列,常见的做法是要么让每个时间步输入 共享同一份权重参数,要么将所有时间步输入 拼接后再将其整体作为 MLP 的输入。然而,前者丢失了时间依赖性信息(所以 RNN 在其基础上设计了隐状态 ),后者则使得权重规模随序列长度 线性增长,显然这都不能称为很好的做法。

不严谨地,我们可以把 RNN 视为在时间维度上权重共享且具有隐状态的一种 MLP 改进。换言之,RNN 是在时间维度上对同一个 MLP 进行递归展开的结构。

RNN 的最终输出是可以灵活选取的,不存在统一的「硬规定」。可以只使用最后一个隐状态 或对其做进一步的变换 ,常见场景有情感分析、垃圾邮件识别与时间序列预测任务;也可以使用隐状态序列 ,场景场景有序列标注(譬如词性标注)、语音识别与早期机器翻译的 Decoder。

计算 时,需要我们特别提供 ,最常见的做法是取 为零值。也可以将 设计为一个可学习的参数,这样可以为 RNN 提供一个默认的初始上下文信息。

BPTT

在前文中我们推导了 MLP 的 BP 算法(Backpropagation Algorithm,反向传播算法)中梯度的形式。RNN 考虑了时间维度,我们称适用这类模型的最优化算法为 BPTT 算法(Backpropagation Through Time Algorithm,沿时间的反向传播算法)。下面我们推导 RNN 的 BPTT 算法中梯度的形式。

这里花了很大的篇幅推导 BPTT 中梯度的形式,这是有必要的,因为知晓了梯度后我们就可以很容易地使用梯度下降法等最优化算法训练模型了。

为方便表示,记

则根据 式,隐状态 可被表示为

通常假设 仅依赖 ,在假设成立的前提下总损失为

其中第 时间步的损失 不直接依赖于 ,而是由一系列隐状态 传递的。因此,要计算 的梯度,我们首先需要讨论 的梯度,进而才能讨论 的梯度。按经典的推导思路,我们记 接下来我们先推导出 的形式——这是 BPTT 数学推导中最重要的一环。


这里要明确根据式子 可知,总损失 的直接依赖路径是唯一的,即通过第 时间步的隐状态 ;而 又通过两种路径影响

  1. 直接影响路径: 直接影响当前时刻的损失 ,进而影响总损失

  2. 间接影响路径(通过递归关系): 作为下一时刻的输入,间接影响后续时刻的所有损失 ,进而影响总损失 。这类影响路径是 RNN 所特有的,是时间依赖的体现。

    然而,我们不能直接计算出 ,因为若将所有 直接展开到 ,我们将得到一个指数爆炸的间接路径数量;故我们应考虑藉由递归式 ,通过中间变量 表示 对总损失 造成的所有间接影响。

因此,我们根据链式法则,有 上式中方框内的部分利用了递归定义。


计算出 后,我们就可以方便地表示出总损失 的梯度了。


我们再分别导出回归任务场景均方误差与分类任务场景交叉熵损失下 的具体表达式,记真实值为 ,并记

对于均方误差损失函数 ,若基于最后一个时间步的隐状态 经线性层得到最终输出,即 对于交叉熵损失函数 ,考虑真实标签为 one-hot 编码,若基于最后一个时间步的隐状态 经线性层得到 logits 并将 logits 作为 Softmax 函数输入最终输出概率分布 ,即

由 Softmax 与交叉熵复合函数的导数性质,不难知道 因此有 可以看到,在上述前提下,BPTT 针对回归任务与分类任务的梯度计算在形式上是十分相似的,仅在误差项的来源上有所差异。

局限

RNN 的循环结构使得模型具备了跨时间维度共享状态的能力。我们可以列举出 RNN 的若干优点:

  • 能够轻松地处理可变长序列;
  • 参数量与序列长度无关,大幅降低了长序列数据的计算量;
  • 具备一定程度的 “记忆” 能力,因为任何时刻的隐状态在理论上都蕴含了过去时间步的历史信息;
  • ……

RNN 的设计的确十分优雅,然而从上述推导中,我们能够发现 RNN 依然存在着若干缺陷:

  • 没有解决长期依赖的问题,因为 RNN 内部存在梯度爆炸 / 梯度消失的现象,这使得早期信息难以长期保留;
  • 无法直接访问全部的历史信息:当前时间步 只利用了上一时间步输出的隐状态 ,尽管 理论上蕴含了过去的历史信息,但终归是将所有历史信息全部压缩编码到 这一个量中了,而无法直接利用过去完整的历史信息 。注意力机制解决了这个问题,使得模型在任意时间步都能够完整访问并利用过去某一时间步的信息;
  • 串行计算:在计算隐状态时,RNN 只能按时间步顺序串行化计算,因为 依赖于 。Transformer 在训练阶段彻底解决了这个问题,实现了所有 token 并行计算。

这里简要说明 RNN 为什么存在梯度爆炸 / 梯度消失的风险。我们讨论早期隐状态 对总损失 的影响程度,这在数值上等价于关注 。根据链式法则,有 因此, 可以被分解为若干矩阵连乘的和,并且跨越的时间步愈长,连乘的矩阵愈多。

Hochreiter 与 Schmidhuber 在 1997 年指出,保持时间依赖性的必要条件是长期保持梯度。然而, 由于激活函数的导数往往是有界的(导数无界的激活函数易导致更严重的梯度爆炸 / 梯度消失,极少考虑),因此可认为 。实际上,对于常见的激活函数(Sigmoid、Tanh、ReLU 等),通常有 。在 的假设下,有 不严谨地说,若 ,则 ,梯度 便呈指数级衰减,这使得早期隐状态 对总损失的影响变得微乎其微。如需更严谨的证明,请参考 Bengio、Simard 与 Frasconi 在 1994 年的研究。

Hochreiter 与 Schmidhuber 提出 LSTM 的动机,正是为了解决 RNN 的结构性问题所带来的梯度爆炸 / 梯度消失。LSTM 的核心创新点是通过添加精妙的输入门、遗忘门与输出门机制,使得隐状态链路径上的梯度不再是多个矩阵连乘,而只需要连乘多个遗忘门的值——而遗忘门是可学习的,这使得梯度消失的问题得到了很大的缓解。

Long Short-Term Memory

下面我们跳过 LSTM、GRU 等若干 RNN 的改良,让我们直接关注划时代的自注意力机制。


参考文献:

  1. S. Hochreiter and J. Schmidhuber, "Long Short-Term Memory," in Neural Computation, vol. 9, no. 8, pp. 1735-1780, 15 Nov. 1997, doi: 10.1162/neco.1997.9.8.1735.

自注意力

背景引入

LSTM 在 1997 年被首次提出,直到 2014 年注意力机制出现以前,机器翻译领域的最佳实践是 Encoder–Decoder 架构(without Attention)。

上文提到,RNN 在时间步 只能直接利用上一时间步的隐状态 ,即使是 LSTM 也主要针对梯度消失的问题进行改良,并未解决该问题。为了解决这个问题

Transformer

背景引入

Encoder–Decoder

待写