sequence_scatter

  • paddle.fluid.layers.sequence_scatter(input, index, updates, name=None)[源代码]

注解

该OP的输入index,updates必须是LoDTensor。

该OP根据index提供的位置将updates中的信息更新到输出中。

该OP先使用input初始化output,然后通过output[instance_index][index[pos]] += updates[pos]方式,将updates的信息更新到output中,其中instance_idx是pos对应的在batch中第k个样本。

output[i][j]的值取决于能否在index中第i+1个区间中找到对应的数据j,若能找到out[i][j] = input[i][j] + update[m][n],否则 out[i][j] = input[i][j]。

例如,在下面样例中,index的lod信息分为了3个区间。其中,out[0][0]能在index中第1个区间中找到对应数据0,所以,使用updates对应位置的值进行更新,out[0][0] = input[0][0]+updates[0][0]。out[2][1]不能在index中第3个区间找到对应数据1,所以,它等于输入对应位置的值,out[2][1] = input[2][1]。

样例:

  1. 输入:
  2.  
  3. input.data = [[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
  4. [1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
  5. [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]
  6. input.dims = [3, 6]
  7.  
  8. index.data = [[0], [1], [2], [5], [4], [3], [2], [1], [3], [2], [5], [4]]
  9. index.lod = [[0, 3, 8, 12]]
  10.  
  11. updates.data = [[0.3], [0.3], [0.4], [0.1], [0.2], [0.3], [0.4], [0.0], [0.2], [0.3], [0.1], [0.4]]
  12. updates.lod = [[ 0, 3, 8, 12]]
  13.  
  14. 输出:
  15.  
  16. out.data = [[1.3, 1.3, 1.4, 1.0, 1.0, 1.0],
  17. [1.0, 1.0, 1.4, 1.3, 1.2, 1.1],
  18. [1.0, 1.0, 1.3, 1.2, 1.4, 1.1]]
  19. out.dims = X.dims = [3, 6]
  • 参数:
    • input (Variable) - 维度为 sequence_scatter - 图1 的Tensor, 支持的数据类型:float32,float64,int32,int64。
    • index (Variable) - 包含index信息的LoDTensor,lod level必须等于1,支持的数据类型:int64。
    • updates (Variable) - 包含updates信息的LoDTensor,lod level和index一致,数据类型与input的数据类型一致。支持的数据类型:float32,float64,int32,int64。
    • name (str,可选) – 具体用法请参见 Name ,一般无需设置,默认值为None。

返回:在input的基础上使用updates进行更新后得到的Tensor,它与input有相同的维度和数据类型。

返回类型:Variable

代码示例:

  1. import paddle.fluid as fluid
  2. import paddle.fluid.layers as layers
  3.  
  4. input = fluid.data( name="x", shape=[3, 6], dtype='float32' )
  5. index = fluid.data( name='index', shape=[12, 1], dtype='int64', lod_level=1)
  6. updates = fluid.data( name='updates', shape=[12, 1], dtype='float32', lod_level=1)
  7. output = fluid.layers.sequence_scatter(input, index, updates)