def _augment_image(image,
                   augmentation_type=enums.AugmentationType.SIMCLR,
                   warp_prob=0.5,
                   blur_prob=0.5,
                   augmentation_magnitude=0.5,
                   use_pytorch_color_jitter=False):
  """Applies data augmentation to an image.

  Args:
    image: An image Tensor of shape [height, width, 3] and dtype tf.uint8.
    augmentation_type: An enums.AugmentationType.
    warp_prob: The probability of applying a warp augmentation at the end (only
      applies to augmentation_types SIMCLR and STACKED_RANDAUGMENT).
    blur_prob: A Python float between 0 and 1. The probability of applying a
      blur transformation (only forwarded for augmentation_types SIMCLR and
      STACKED_RANDAUGMENT).
    augmentation_magnitude: The magnitude of augmentation. Valid range differs
      depending on the augmentation_type. It is forwarded as `strength` for
      SIMCLR and STACKED_RANDAUGMENT and as `magnitude` for RANDAUGMENT. If it
      is not positive, `image` is returned unchanged for every
      augmentation_type.
    use_pytorch_color_jitter: A Python bool. Whether to use a color jittering
      algorithm that aims to replicate `torchvision.transforms.ColorJitter`
      rather than the standard TensorFlow color jittering. Only used for
      augmentation_types SIMCLR and STACKED_RANDAUGMENT.

  Returns:
    The augmented image, or `image` itself when `augmentation_magnitude` is
    less than or equal to zero.

  Raises:
    ValueError: If `augmentation_type` is unknown.
    TypeError: If `image` does not have dtype tf.uint8.
  """
  # Each value is a zero-argument factory rather than a ready-made augmenter,
  # so that only the requested augmenter is constructed, and only once the
  # inputs have been validated and the no-op path has been ruled out.
  augmenter_factories = {
      enums.AugmentationType.SIMCLR:
          lambda: functools.partial(
              _simclr_augment,
              warp_prob=warp_prob,
              blur_prob=blur_prob,
              strength=augmentation_magnitude,
              use_pytorch_color_jitter=use_pytorch_color_jitter),
      enums.AugmentationType.STACKED_RANDAUGMENT:
          lambda: functools.partial(
              _stacked_simclr_randaugment,
              warp_prob=warp_prob,
              blur_prob=blur_prob,
              strength=augmentation_magnitude,
              use_pytorch_color_jitter=use_pytorch_color_jitter),
      enums.AugmentationType.RANDAUGMENT:
          lambda: augment.RandAugment(
              magnitude=augmentation_magnitude).distort,
      enums.AugmentationType.AUTOAUGMENT:
          lambda: augment.AutoAugment().distort,
      enums.AugmentationType.IDENTITY:
          lambda: (lambda x: x),
  }
  if augmentation_type not in augmenter_factories:
    raise ValueError(f'Invalid augmentation_type: {augmentation_type}.')

  if image.dtype != tf.uint8:
    raise TypeError(f'Image must have dtype tf.uint8. Was {image.dtype}.')

  # A non-positive magnitude disables augmentation entirely.
  if augmentation_magnitude <= 0.:
    return image

  # Build the augmenter outside the name scope (as before); only the
  # application of the augmentation is grouped under 'augment_image'.
  augment_fn = augmenter_factories[augmentation_type]()
  with tf.name_scope('augment_image'):
    return augment_fn(image)
Exemple #2
0
    def test_randaug(self):
        """Smoke test: default RandAugment runs and keeps the image shape."""
        expected_shape = (224, 224, 3)
        blank_image = tf.zeros(expected_shape, dtype=tf.uint8)

        distorted = augment.RandAugment().distort(blank_image)

        self.assertEqual(expected_shape, distorted.shape)
def _stacked_simclr_randaugment(image,
                                warp_prob=0.5,
                                blur_prob=0.5,
                                strength=0.5,
                                side_length=IMAGE_SIZE,
                                use_pytorch_color_jitter=False):
  """Combines the data augmentation sequences from SimCLR and RandAugment.

  The steps run in a fixed order: random color jitter, RandAugment (with its
  default settings), random grayscale conversion, an optional random warp and
  an optional random Gaussian blur.

  Citations:
    SimCLR: https://arxiv.org/abs/2002.05709
    RandAugment: https://arxiv.org/abs/1909.13719

  Args:
    image: An image Tensor of shape [height, width, 3] and dtype tf.uint8.
    warp_prob: A Python float between 0 and 1. The probability of applying a
      warp transformation. The warp step is skipped entirely when this is not
      positive.
    blur_prob: A Python float between 0 and 1. The probability of applying a
      blur transformation. The blur step is skipped entirely when this is not
      positive.
    strength: A Python float in range [0,1] controlling the maximum strength
      of the augmentations. It is forwarded to the color jitter step only.
    side_length: A Python integer. The length, in pixels, of the width and
      height of `image`.
    use_pytorch_color_jitter: A Python bool. Whether to use a color jittering
      algorithm that aims to replicate `torchvision.transforms.ColorJitter`
      rather than the standard TensorFlow color jittering.

  Returns:
    The augmented image.
    NOTE(review): when `blur_prob` is positive the image is converted to
    tf.float32 before the blur step; whether the result still has the input
    dtype depends on `_gaussian_blur` -- confirm against that helper.
  """
  # functools.partial only binds arguments, so this can be prepared before
  # entering the name scope.
  color_jitter_fn = functools.partial(
      _jitter_colors,
      strength=strength,
      use_pytorch_impl=use_pytorch_color_jitter)

  with tf.name_scope('stacked_simclr_randaugment'):
    # Color jitter is applied with probability 0.8.
    image = _random_apply(color_jitter_fn, p=0.8, x=image)

    rand_augmenter = augment.RandAugment()
    image = rand_augmenter.distort(image)

    # Grayscale conversion is applied with probability 0.2.
    image = _random_apply(_rgb_to_gray, p=0.2, x=image)

    if warp_prob > 0.:
      warp_fn = functools.partial(_warp, side_length=side_length)
      image = _random_apply(warp_fn, p=warp_prob, x=image)

    if blur_prob > 0.:
      # The image is converted to float32 before the blur is (maybe) applied.
      image = _convert_image_dtype(image, tf.float32)
      blur_fn = functools.partial(_gaussian_blur, side_length=side_length)
      image = _random_apply(blur_fn, p=blur_prob, x=image)

    return image