grid_sampler

  • paddle.fluid.layers.grid_sampler(x, grid, name=None)[源代码]

该OP基于flow field网格的对输入X进行双线性插值采样。网格通常由affine_grid生成, shape为[N, H, W, 2],是shape为[N, H, W]的采样点张量的(x, y)坐标。 其中,x坐标是对输入数据X的第四个维度(宽度维度)的索引,y坐标是第三维度(高维度)的索引,最终输出采样值为采样点的4个最接近的角点的双线性插值结果,输出张量的shape为[N, C, H, W]。

step 1:

得到(x, y)网格坐标,缩放到[0,h -1/W-1]
  1. grid_x = 0.5 * (grid[:, :, :, 0] + 1) * (W - 1) grid_y = 0.5 * (grid[:, :, :, 1] + 1) * (H - 1)

step 2:

在每个[H, W]区域用网格(X, y)作为输入数据X的索引,并将双线性插值点值由4个最近的点表示。
  1. wn ------- y_n ------- en
  2. | | |
  3. | d_n |
  4. | | |
  5. x_w --d_w-- grid--d_e-- x_e
  6. | | |
  7. | d_s |
  8. | | |
  9. ws ------- y_s ------- wn
  10.  
  11. x_w = floor(x) // west side x coord
  12. x_e = x_w + 1 // east side x coord
  13. y_n = floor(y) // north side y coord
  14. y_s = y_s + 1 // south side y coord
  15. d_w = grid_x - x_w // distance to west side
  16. d_e = x_e - grid_x // distance to east side
  17. d_n = grid_y - y_n // distance to north side
  18. d_s = y_s - grid_y // distance to south side
  19. wn = X[:, :, y_n, x_w] // north-west point value
  20. en = X[:, :, y_n, x_e] // north-east point value
  21. ws = X[:, :, y_s, x_w] // south-east point value
  22. es = X[:, :, y_s, x_w] // north-east point value
  23.  
  24. output = wn * d_e * d_s + en * d_w * d_s
  25. + ws * d_e * d_n + es * d_w * d_n
  • 参数:
    • x (Variable): 输入张量,维度为 grid_sampler - 图1 的4-D Tensor,N为批尺寸,C是通道数,H是特征高度,W是特征宽度, 数据类型为float32或float64。
    • grid (Variable): 输入网格数据张量,维度为 grid_sampler - 图2 的4-D Tensor,N为批尺寸,C是通道数,H是特征高度,W是特征宽度, 数据类型为float32或float64。
    • name (None|str) – 该参数供开发人员打印调试信息时使用,具体用法请参见 Name ,默认值为None。

返回: Variable(Tensor): 输入X基于输入网格的双线性插值计算结果,维度为

grid_sampler - 图3 的4-D Tensor

返回类型:变量(Variable),数据类型与 x 一致

代码示例:

  1. # 一般与 affine_grid 组合使用
  2. x = fluid.layers.data(name='x', shape=[10, 32, 32], dtype='float32')
  3. theta = fluid.layers.data(name='theta', shape=[2, 3], dtype='float32')
  4. grid = fluid.layers.affine_grid(theta=theta, out_shape=[3, 10, 32, 32])
  5. out = fluid.layers.grid_sampler(x=x, grid=grid)