matmul

  • paddle.fluid.layers.matmul(x, y, transpose_x=False, transpose_y=False, alpha=1.0, name=None)[源代码]

输入 x 和输入 y 矩阵相乘。

两个输入的形状可为任意维度,但当任一输入维度大于3时,两个输入的维度必须相等。 实际的操作取决于 xy 的维度和 transpose_xtranspose_y 的布尔值。具体如下:

  • 如果 transpose 为真,则对应 Tensor 的后两维会转置。假定 x 是一个 shape=[D] 的一维 Tensor,则 x 非转置形状为 [1, D],转置形状为 [D, 1]。转置之后的输入形状需满足矩阵乘法要求,即 x_width 与 y_height 相等。
    • 转置后,输入的两个 Tensor 维度将为 2-D 或 n-D,将根据下列规则矩阵相乘:
      • 如果两个矩阵都是 2-D,则同普通矩阵一样进行矩阵相乘。
      • 如果任意一个矩阵是 n-D,则将其视为带 batch 的二维矩阵乘法。
  • 如果原始 Tensor x 或 y 的秩为 1 且未转置,则矩阵相乘后的前置或附加维度 1 将移除。
  • 参数:
    • x (Variable) : 输入变量,类型为 Tensor 或 LoDTensor。
    • y (Variable) : 输入变量,类型为 Tensor 或 LoDTensor。
    • transpose_x (bool) : 相乘前是否转置 x。
    • transpose_y (bool) : 相乘前是否转置 y。
    • alpha (float) : 输出比例,默认为 1.0。
    • name (str|None) : 该层名称(可选),如果设置为空,则自动为该层命名。
  • 返回:
    • Variable (Tensor / LoDTensor),矩阵相乘后的结果。
  • 返回类型:
    • Variable(变量)。
  1. * 1:
  2.  
  3. x: [B, ..., M, K], y: [B, ..., K, N]
  4. out: [B, ..., M, N]
  5.  
  6. * 2:
  7.  
  8. x: [B, M, K], y: [B, K, N]
  9. out: [B, M, N]
  10.  
  11. * 3:
  12.  
  13. x: [B, M, K], y: [K, N]
  14. out: [B, M, N]
  15.  
  16. * 4:
  17.  
  18. x: [M, K], y: [K, N]
  19. out: [M, N]
  20.  
  21. * 5:
  22.  
  23. x: [B, M, K], y: [K]
  24. out: [B, M]
  25.  
  26. * 6:
  27.  
  28. x: [K], y: [K]
  29. out: [1]
  30.  
  31. * 7:
  32.  
  33. x: [M], y: [N]
  34. out: [M, N]

代码示例

  1. import paddle.fluid as fluid
  2. import numpy
  3.  
  4. # Graph Organizing
  5. x = fluid.layers.data(name='x', shape=[2, 3], dtype='float32')
  6. y = fluid.layers.data(name='y', shape=[3, 2], dtype='float32')
  7. output = fluid.layers.matmul(x, y, True, True)
  8.  
  9. # Create an executor using CPU as an example
  10. exe = fluid.Executor(fluid.CPUPlace())
  11. exe.run(fluid.default_startup_program())
  12.  
  13. # Execute
  14. input_x = numpy.ones([2, 3]).astype(numpy.float32)
  15. input_y = numpy.ones([3, 2]).astype(numpy.float32)
  16. res, = exe.run(fluid.default_main_program(),
  17. feed={'x':input_x, 'y':input_y},
  18. fetch_list=[output])
  19. print(res)
  20. '''
  21. Output Value:
  22. [[2. 2. 2.]
  23. [2. 2. 2.]
  24. [2. 2. 2.]]
  25. '''