二、训练算法
2.1 BPTT 算法
以
多输出&隐-隐RNN
为例,设:输入到隐状态的权重为 ,隐状态到输出的权重为 ,隐状态到隐状态的权重为 , 为输入偏置向量和输出偏置向量;激活函数为双曲正切激活函数 。设网络从特定的初始状态 开始前向传播,从 到 的每个时间步,则有更新方程:
多输出&隐-隐RNN
的单个样本损失函数为: 。该损失函数的梯度计算代价较高:- 因为每个时间步只能一前一后的计算无法并行化,因此时间复杂度为 。
- 前向传播中各个状态必须保存直到它们反向传播中被再次使用,因此空间复杂度也是 。
- 采用
tanh
激活函数而不是ReLU
激活函数的原因是为了缓解长期依赖。
back-propagation through time:BPTT
:通过时间反向传播算法,其算法复杂度为 。由
BPTT
计算得到梯度,再结合任何通用的、基于梯度的技术就可以训练RNN
。计算图的节点包括参数 ,以及以 为索引的节点序列 以及 。
根据 ,则有: 。
令节点 ,则有: 。则有:
其中 表示 的第 个分量。
则有:
表示梯度 的第 个分量, 为示性函数。写成向量形式为:
其中 为真实标签 扩充得到的概率分布,其真实的类别 位置上的分量为 1,而其它位置上的分量为 0。
根据定义 ,得到:
根据导数: ,则有:
设隐向量长度为 ,定义:
则有: 。
根据定义 ,即 ,则有: ,记作:
因此得到隐单元的梯度:
当 时, 只有一个后续结点 (从而只有一个后继节点 ) ,因此有:
当 时, 同时具有 两个后续节点,因此有:
由于 依赖于 ,因此求解隐单元的梯度时,从末尾开始反向计算。
一旦获得了隐单元及输出单元的梯度,则可以获取参数节点的梯度。
注意:由于参数在多个时间步共享,因此在参数节点的微分操作时必须谨慎对待。
微分中的算子 在计算 对于 的贡献时,将计算图的所有边都考虑进去了。但是事实上:有一条边是 时间步的 ,还有一条边是 时间步的 ,…. 。
为了消除歧义,使用虚拟变量 作为 的副本。用 表示参数 在时间步 对于梯度的贡献。将所有时间步上的梯度相加,即可得到 。
根据定义 ,即 。则有:
考虑到 对于每个输出 都有贡献,因此有:
记:
考虑到 对于每个输出 都有贡献,因此有:
其中 表示 的第 个分量。
根据定义 ,即:
则有:
考虑到 对于每个隐向量 都有贡献,因此有:
记:
考虑到每个 都对 有贡献,则:
其中 表示 的第 个分量。
记:
考虑到每个 都对 有贡献,则:
其中 表示 的第 个分量。
因为任何参数都不是训练数据 的父节点,因此不需要计算 。
2.2 Teacher forcing 算法
多输出&输出-隐连接RNN
模型可以使用teacher forcing
算法进行训练。- 模型的数学表示:
- 单个样本的损失:
- 训练时:在时刻 接受真实类别分布 作为输入,而不必等待 时刻的模型输出分布 。
- 推断时:真实的标记通常是未知的,因此必须用模型的输出分布 。
teacher forcing
训练的本质原因是:当前隐状态与早期隐状态没有直接连接。虽然有间接连接,但是由于 已知,因此这种连接被切断。- 如果模型的隐状态依赖于早期时间步的隐状态,则需要采用
BPTT
算法。 - 某些模型训练时,需要同时使用
teacher forcing
和BPTT
算法。
- 如果模型的隐状态依赖于早期时间步的隐状态,则需要采用