BaseTransform

class paddle.vision.transforms. BaseTransform ( keys=None ) [源代码]

视觉中图像变化的基类。

调用逻辑:

  1. if keys is None:
  2. _get_params -> _apply_image()
  3. else:
  4. _get_params -> _apply_*() for * in keys

如果你想要定义自己的图像变化方法, 需要重写子类中的 _apply_* 方法。

参数

  • keys (list[str]|tuple[str], optional) - 输入的类型. 你的输入可以是单一的图像,也可以是包含不同数据结构的元组, keys 可以用来指定输入类型. 举个例子, 如果你的输入就是一个单一的图像,那么 keys 可以为 None 或者 (“image”)。如果你的输入是两个图像:(image, image) ,那么 keys 应该设置为 ("image", "image") 。如果你的输入是 (image, boxes), 那么 keys 应该为 ("image", "boxes")

    目前支持的数据类型如下所示:

    • “image”: 输入的图像, 它的维度为 (H, W, C)

    • “coords”: 输入的左边, 它的维度为 (N, 2)

    • “boxes”: 输入的矩形框, 他的维度为 (N, 4), 形式为 “xyxy”, 第一个 “xy” 表示矩形框左上方的坐标, 第二个 “xy” 表示矩形框右下方的坐标.

    • “mask”: 分割的掩码,它的维度为 (H, W, 1)

    你也可以通过自定义 apply*的方法来处理特殊的数据结构。

返回

PIL.Image 或 numpy ndarray,变换后的图像。

代码示例

  1. import numpy as np
  2. from PIL import Image
  3. import paddle.vision.transforms.functional as F
  4. from paddle.vision.transforms import BaseTransform
  5. def _get_image_size(img):
  6. if F._is_pil_image(img):
  7. return img.size
  8. elif F._is_numpy_image(img):
  9. return img.shape[:2][::-1]
  10. else:
  11. raise TypeError("Unexpected type {}".format(type(img)))
  12. class CustomRandomFlip(BaseTransform):
  13. def __init__(self, prob=0.5, keys=None):
  14. super(CustomRandomFlip, self).__init__(keys)
  15. self.prob = prob
  16. def _get_params(self, inputs):
  17. image = inputs[self.keys.index('image')]
  18. params = {}
  19. params['flip'] = np.random.random() < self.prob
  20. params['size'] = _get_image_size(image)
  21. return params
  22. def _apply_image(self, image):
  23. if self.params['flip']:
  24. return F.hflip(image)
  25. return image
  26. # if you only want to transform image, do not need to rewrite this function
  27. def _apply_coords(self, coords):
  28. if self.params['flip']:
  29. w = self.params['size'][0]
  30. coords[:, 0] = w - coords[:, 0]
  31. return coords
  32. # if you only want to transform image, do not need to rewrite this function
  33. def _apply_boxes(self, boxes):
  34. idxs = np.array([(0, 1), (2, 1), (0, 3), (2, 3)]).flatten()
  35. coords = np.asarray(boxes).reshape(-1, 4)[:, idxs].reshape(-1, 2)
  36. coords = self._apply_coords(coords).reshape((-1, 4, 2))
  37. minxy = coords.min(axis=1)
  38. maxxy = coords.max(axis=1)
  39. trans_boxes = np.concatenate((minxy, maxxy), axis=1)
  40. return trans_boxes
  41. # if you only want to transform image, do not need to rewrite this function
  42. def _apply_mask(self, mask):
  43. if self.params['flip']:
  44. return F.hflip(mask)
  45. return mask
  46. # create fake inputs
  47. fake_img = Image.fromarray((np.random.rand(400, 500, 3) * 255.).astype('uint8'))
  48. fake_boxes = np.array([[2, 3, 200, 300], [50, 60, 80, 100]])
  49. fake_mask = fake_img.convert('L')
  50. # only transform for image:
  51. flip_transform = CustomRandomFlip(1.0)
  52. converted_img = flip_transform(fake_img)
  53. # transform for image, boxes and mask
  54. flip_transform = CustomRandomFlip(1.0, keys=('image', 'boxes', 'mask'))
  55. (converted_img, converted_boxes, converted_mask) = flip_transform((fake_img, fake_boxes, fake_mask))
  56. print('converted boxes', converted_boxes)