三、Transformer XL
Transformer
解决的核心问题是:如何将任意长度的文本序列编码成一个固定维度的隐向量。假设有无限的内存和算力,一个简单的方法是:取语料库中最长序列的长度,然后把所有较短的序列填充到最大长度作为输入。但是由于内存和算力的限制,这种方式不现实。
一个可行的方案是:将整个语料库分割成固定大小的、较短的片段
segment
,然后用每个片段来训练模型,而忽略之前所有片段的信息。这种模型称作
vanilla model
,这也是Transformer
的做法。
vanilla model
没有任何跨segment
的信息流,这会带来两个问题:模型能够捕获的最大依赖的长度不超过
segment
的大小。假设
segment
的长度为 ,则模型无法捕获长度超过 的依赖性。划分
segment
的时候未考虑句子边界或者短语的边界,这破坏了语义完整性。因此模型缺乏必要的上下文信息来预测segment
的开头的前几个token
和结束的尾部几个token
。这会导致低效的优化效率和更差的泛化能力。这个问题称作上下文碎片化
context fragmentation
问题。
vanilla model
不仅在训练期间遇到问题,在推断期间也会遇到问题。Transformer
在推断期间反复执行推断过程,每次推断都根据前几轮输入结果和输入来预测下一个输出结果。- 推断期间,模型也采取与训练期相同的
segment
大小,因此也会遇到上下文碎片化问题。 - 每个新的
segment
的计算需要从头开始重新计算,计算速度较慢。
- 推断期间,模型也采取与训练期相同的
Transformer XL
通过引入递归机制和相对位置编码来解决上述问题。- 能够学到的最长依赖性的长度:
Transformer XL
比RNN
长 80%,比Transformer
长 450%。 - 推断速度:
Transformer XL
比Transformer
快 1800 多倍。
- 能够学到的最长依赖性的长度:
3.1 Segment-level 递归
为解决固定长度上下文的局限性,
Transformer XL
引入了Segment-level
递归机制。训练期间:
Transformer XL
缓存前一个segment
计算得到的隐状态序列,然后在下一个segment
中重用。这种额外的输入使得网络能够利用历史中的有效信息,从而能够对长期依赖建模,并且避免了上下文碎片化。
在对语料库进行分割来生成样本之后,
Transformer XL
要求对样本不能进行混洗。因为一旦混洗就会破坏segment
的先后顺序。由于需要考虑
segment
之间的先后关系,因此训练期间要将连续的一组segment
分别放置在连续的一组batchment
中。这样可以在尽可能满足
segment
先后关系的条件下提高数据并行度。
推断期间:
Transformer XL
缓存前一个segment
计算得到的隐状态序列,使得后续需要用到这些隐状态时直接使用而不必重新计算,加快了推断速度。实验表明,
Transformer XL
的推断速度是传统Transformer
的 1800 倍。
令
segment
长度为 ,第 个segment
的输入为token
序列: ;第 个segment
的输入为: 。假设网络有 层,第 层网络每个位置的输出隐状态分别为 ():
其中 为 维列向量。令:
考虑
Segment-level
递归机制,则 的第 层各位置的隐向量为:拼接 的第 层的隐向量序列 和 的第 层的隐向量序列:
其中
SG
表示冻结参数(不用计算梯度),concate
表示沿位置拼接:计算
query,key,value
向量:其中 分别为
query,key,value
转换矩阵。计算 :
其中
Transformer-Layer
为常规的Transformer
层。
与标准
Transformer
不同,这里计算key
向量和value
向量时使用了扩展上下文,其中 缓存了前一个segment
的状态。这种在前后两个
segment
之间共享状态的方式构成了segment-level
的递归,因此上下文信息可以跨segment
流动。是在 中被使用的,这不仅跨了一个段,也垮了一层。这显著的区别于其它的
RNN
语言模型。正因为如此,
Transformer XL
最大依赖长度扩展到了 。下图中,每个位置的
context
都是它左下方的4个位置。Transformer XL
的训练方式类似于BPTT
,但是与BPTT
不同的是:Transformer XL
缓存了上一个segment
的多组隐向量。理论上不仅可以缓存前一个
segment
,也可以缓存前面 个segment
。
3.2 相对位置编码
采用
segment-level
递归机制之后存在一个问题:如何保持位置信息的一致性。令:
segment
长度为 。- 第 个
segment
的输入为token
序列: ,对应的token embedding
矩阵为 。 - 第 个
segment
的输入为: , 对应的token embedding
矩阵为 。 position embedding
矩阵为 。
则有:
可见 和 都采用了同样的位置编码,因此模型没有任何信息来区分一个
token
是位于segment
还是segment
。令:
- 为第 个位置的
token
的embedding
(对应于token embedding
矩阵 的第 行 ) - 为第 个位置的
position embedding
(对应于position embedding
矩阵 的第 行 ) - 为第 个位置的
token + position embedding
- 分别为
query, key, value
向量
则
Transformer
的attention score
(不考虑softmax
归一化,以及除以 )为:它表示
query i
与key j
的相关性。- 第一项刻画了位置 的
token
和位置 的token
的相关性。 - 第二项刻画了位置 的
token
和位置 的position
的相关性。 - 第三项刻画了位置 的
position
和位置 的token
的相关性。 - 第四项刻画了位置 的
position
和位置 的position
的相关性。
- 为第 个位置的
Transformer XL
引入相对位置编码。位置 相对于位置 的距离为 ,则位置 相对于位置 的relative position embedding
为:令 为
Transformer XL
中使用的最大相对距离,令相对位置编码矩阵为:Transformer XL
修改attention score
为:- 将第二、四项中的绝对位置编码 修改为相对位置编码 。其中的相对位置是:
key
相对于query
的位置。 - 通过参数 来代替 。这表示对于位置 的
key token
,同一个query
在不同位置上无影响。因为这种影响被剥离到第二项中。 - 通过参数 来代替 。这表示对
key
相对于value
的相对位置 ,同一个query
在不同位置上无影响。因为这种影响被剥离到第二项中。 - 通过 和 来生成不同的
key
向量。
修改后的
attention score
各项的意义为:- 第一项刻画了基于内容的
attention
- 第二项刻画了内容相对于每个相对位置的
bias
- 第三项刻画了内容的全局的
bias
- 第四项刻画了位置的全局
bias
- 将第二、四项中的绝对位置编码 修改为相对位置编码 。其中的相对位置是:
3.3 实验结果
Transformer XL
验证了word-level
语言模型(以困惑度PPL
为指标),以及char-level
语言模型(以bpc:Bit per Character
为指标) 。其中包含以下数据集:WikiText-103
数据集:最大的word-level
语言模型benchmark
,包含 2.8万篇文章总计103M 训练token
,平均每篇文章 3.6K token 。由于文章的平均
token
数量较大,因此这会考验模型的长期依赖建模能力。训练期间
attention length = 384
,推断期间attention length = 1600
。attention length
也等于segment
长度
enwik8
数据集:包含 100M 字节的未经处理的wiki
文本。结果表明:12层的
Transformer XL
的表现超过了 12层的Transformer
。text8
数据集:包含 100M 字节的、经过处理的wiki
文本。处理方式为:大写字母改小写、移除a~z
之外的所有字符。One Billion Word
数据集:数据集的句子已经被混洗过,因此该数据集无法验证序列的长期依赖。因而该数据集用于测试模型的短期依赖。结果表明:
Transformer XL
对短期依赖的建模效果也非常好。Penn Treebank
数据集:包含 1Mtoken
,用于验证模型在小数据集上的表现。结果表明:
transformer XL
在小数据集上效果也非常好。
在
WikiText-103
数据集上验证segment-level
递归、相对位置编码的作用,验证结果如下。其中:结果划分为三个部分,分别为
128M
参数的Transformer-XL
、128M
参数的Transformer
、151M
参数的Transformer-XL
。最后四列的意义:
PPL init
表示推断期间使用与训练时相同的attention length
PPL best
表示推断期间使用最佳长度的attention length
(通过超参数搜索得到)Attn Len
给出了推断期间使用的最佳attention length
Full Loss/ Half Loss
表示计算当前segment
中所有位置的损失,还是计算后一半位置的损失
结果表明:
相对位置编码非常重要对于
Transformer XL
非常重要,而绝对位置编码只有在Half Loss
中工作良好。因为
Half Loss
只考虑当前segment
后一半位置的损失,因此受到前一个segment
的绝对位置编码的影响比较小。随着
attention length
的增加,模型的困惑度下降。尽管每个
segment
训练期间的长度是 128,但推断期间可以推广到 640 。
此外,由于
Transformer XL
采取了缓存前一个segment
的状态的策略,因此会消耗更多的内存。下图给出在相同GPU
内存的约束条件下的比较结果。结果表明:即使
Transformer XL
由于内存约束而不得不使用一个较小的backprop len
,其最终效果仍然超越Transformer
。为了证明
Transformer XL
能够解决上下文碎片问题,作者在One Billion Word
数据集上做了实验。因为该数据集的句子经过混洗,导致不需要对长期依赖建模。因此效果的提升一定是由于模型能够解决上下文碎片,而不是由于模型捕捉到更长的上下文。