scaled_dot_product_attention
paddle.fluid.nets.
scaled_dot_product_attention
(queries, keys, values, num_heads=1, dropout_rate=0.0)[源代码]
该接口实现了的基于点积(并进行了缩放)的多头注意力(Multi-Head Attention)机制。attention可以表述为将一个查询(query)和一组键值对(key-value pair)映射为一个输出;Multi-Head Attention则是使用多路进行attention,而且对attention的输入进行了线性变换。公式如下:
其中,
分别对应 queries
、 keys
和 values
,详细内容请参阅 Attention Is All You Need
要注意该接口实现支持的是batch形式,
中使用的矩阵乘是batch形式的矩阵乘法,参考 fluid.layers. matmul 。
- 参数:
- queries (Variable) - 形状为 的三维Tensor,其中 为batch_size, 为查询序列长度, 为查询的特征维度大小, 为head数。数据类型为float32或float64。
- keys (Variable) - 形状为
的三维Tensor,其中
为batch_size,
为键值序列长度,
为键的特征维度大小,
为head数。数据类型与
queries
相同。 - values (Variable) - 形状为
的三维Tensor,其中
为batch_size,
为键值序列长度,
为值的特征维度大小,
为head数。数据类型与
queries
相同。 - num_heads (int) - 指明所使用的head数。head数为1时不对输入进行线性变换。默认值为1。
- dropout_rate (float) - 以指定的概率对要attention到的内容进行dropout。默认值为0,即不使用dropout。
返回: 形状为
的三维Tensor,其中 为batch_size, 为查询序列长度, 为值的特征维度大小。与输入具有相同的数据类型。表示Multi-Head Attention的输出。
返回类型: Variable
- 抛出异常:
ValueError
:queries
、keys
和values
必须都是三维。ValueError
:queries
和keys
的最后一维(特征维度)大小必须相同。ValueError
:keys
和values
的第二维(长度维度)大小必须相同。ValueError
:keys
的最后一维(特征维度)大小必须是num_heads
的整数倍。ValueError
:values
的最后一维(特征维度)大小必须是num_heads
的整数倍。
代码示例
- import paddle.fluid as fluid
- queries = fluid.layers.data(name="queries",
- shape=[3, 5, 9],
- dtype="float32",
- append_batch_size=False)
- queries.stop_gradient = False
- keys = fluid.layers.data(name="keys",
- shape=[3, 6, 9],
- dtype="float32",
- append_batch_size=False)
- keys.stop_gradient = False
- values = fluid.layers.data(name="values",
- shape=[3, 6, 10],
- dtype="float32",
- append_batch_size=False)
- values.stop_gradient = False
- contexts = fluid.nets.scaled_dot_product_attention(queries, keys, values)
- contexts.shape # [3, 5, 10]