def test_randaug(self):
    """Smoke test: RandAugment.distort runs and keeps a 224x224x3 image shape."""
    expected_shape = (224, 224, 3)
    blank_image = tf.zeros(expected_shape, dtype=tf.uint8)

    distorted = augment.RandAugment().distort(blank_image)

    self.assertEqual(expected_shape, distorted.shape)
# Example #2
# 0
def _augment_image(image,
                   augmentation_type=enums.AugmentationType.SIMCLR,
                   warp_prob=0.5,
                   blur_prob=0.5,
                   augmentation_magnitude=0.5,
                   use_pytorch_color_jitter=False):
    """Applies data augmentation to an image.

    Args:
        image: An image Tensor of shape [height, width, 3] and dtype tf.uint8.
        augmentation_type: An enums.AugmentationType selecting the pipeline
            (SIMCLR, STACKED_RANDAUGMENT, RANDAUGMENT, AUTOAUGMENT or IDENTITY).
        warp_prob: A Python float between 0 and 1. The probability of applying
            a warp augmentation at the end (only applies to augmentation_types
            SIMCLR and STACKED_RANDAUGMENT).
        blur_prob: A Python float between 0 and 1. The probability of applying
            a blur transformation (only applies to augmentation_types SIMCLR
            and STACKED_RANDAUGMENT).
        augmentation_magnitude: The magnitude of augmentation. Valid range
            differs depending on the augmentation_type. It is passed as
            `strength` for SIMCLR and STACKED_RANDAUGMENT and as `magnitude`
            for RANDAUGMENT; AUTOAUGMENT and IDENTITY do not use it. A value
            <= 0 returns `image` unchanged regardless of augmentation_type.
        use_pytorch_color_jitter: A Python bool. Whether to use a color
            jittering algorithm that aims to replicate
            `torchvision.transforms.ColorJitter` rather than the standard
            TensorFlow color jittering. Only used for augmentation_types
            SIMCLR and STACKED_RANDAUGMENT.

    Returns:
        The augmented image. Callers should not assume tf.uint8 output: the
        STACKED_RANDAUGMENT pipeline converts to tf.float32 when
        `blur_prob > 0`.

    Raises:
        ValueError: If `augmentation_type` is unknown.
        TypeError: If `image` does not have dtype tf.uint8.
    """
    # Every value is a callable taking the image. The RandAugment/AutoAugment
    # objects are constructed inside the lambdas, so only the augmenter that
    # is actually selected (and actually run) ever gets built.
    augmentation_type_map = {
        enums.AugmentationType.SIMCLR:
            functools.partial(_simclr_augment,
                              warp_prob=warp_prob,
                              blur_prob=blur_prob,
                              strength=augmentation_magnitude,
                              use_pytorch_color_jitter=use_pytorch_color_jitter),
        enums.AugmentationType.STACKED_RANDAUGMENT:
            functools.partial(_stacked_simclr_randaugment,
                              warp_prob=warp_prob,
                              blur_prob=blur_prob,
                              strength=augmentation_magnitude,
                              use_pytorch_color_jitter=use_pytorch_color_jitter),
        enums.AugmentationType.RANDAUGMENT:
            lambda x: augment.RandAugment(
                magnitude=augmentation_magnitude).distort(x),
        enums.AugmentationType.AUTOAUGMENT:
            lambda x: augment.AutoAugment().distort(x),
        enums.AugmentationType.IDENTITY: lambda x: x,
    }
    if augmentation_type not in augmentation_type_map:
        raise ValueError(f'Invalid augmentation_type: {augmentation_type}.')

    if image.dtype != tf.uint8:
        raise TypeError(f'Image must have dtype tf.uint8. Was {image.dtype}.')

    # NOTE(review): this short-circuit also skips AUTOAUGMENT, which does not
    # take a magnitude at all; confirm that disabling it here is intended.
    if augmentation_magnitude <= 0.:
        return image

    with tf.name_scope('augment_image'):
        return augmentation_type_map[augmentation_type](image)
# Example #3
# 0
def _stacked_simclr_randaugment(image,
                                warp_prob=0.5,
                                blur_prob=0.5,
                                strength=0.5,
                                side_length=IMAGE_SIZE,
                                use_pytorch_color_jitter=False):
    """A combination of the data augmentation sequences from SimCLR and RandAugment.

    Steps, in order: color jitter (p=0.8), RandAugment with its default
    settings, grayscale conversion (p=0.2), warp (p=warp_prob) and Gaussian
    blur (p=blur_prob).

    Citations:
        SimCLR: https://arxiv.org/abs/2002.05709
        RandAugment: https://arxiv.org/abs/1909.13719

    Args:
        image: An image Tensor of shape [height, width, 3] and dtype tf.uint8.
        warp_prob: A Python float between 0 and 1. The probability of applying
            a warp transformation. If 0, the warp step is skipped entirely.
        blur_prob: A Python float between 0 and 1. The probability of applying
            a blur transformation. If 0, the blur step (including the float32
            conversion that precedes it) is skipped entirely.
        strength: A Python float in range [0, 1] controlling the maximum
            strength of the color jitter. It is not forwarded to RandAugment,
            which runs with its default magnitude.
        side_length: A Python integer. The length, in pixels, of the width and
            height of `image`.
        use_pytorch_color_jitter: A Python bool. Whether to use a color
            jittering algorithm that aims to replicate
            `torchvision.transforms.ColorJitter` rather than the standard
            TensorFlow color jittering.

    Returns:
        An image with the same shape as the input image. When `blur_prob > 0`
        the dtype is always tf.float32, whether or not the blur actually
        fires. Otherwise it is the dtype produced by the earlier steps
        (presumably tf.uint8, like the input; TODO confirm for `_warp`).
    """
    with tf.name_scope('stacked_simclr_randaugment'):
        image = _random_apply(
            functools.partial(_jitter_colors,
                              strength=strength,
                              use_pytorch_impl=use_pytorch_color_jitter),
            p=0.8,
            x=image)
        image = augment.RandAugment().distort(image)
        image = _random_apply(_rgb_to_gray, p=0.2, x=image)
        if warp_prob > 0.:
            image = _random_apply(
                functools.partial(_warp, side_length=side_length),
                p=warp_prob,
                x=image)
        if blur_prob > 0.:
            # The conversion runs before the random draw, so every image that
            # leaves this branch is tf.float32 (presumably _gaussian_blur
            # expects float input; confirm).
            image = _convert_image_dtype(image, tf.float32)
            image = _random_apply(
                functools.partial(_gaussian_blur, side_length=side_length),
                p=blur_prob,
                x=image)
        return image