def _augment_image(image,
                   augmentation_type=enums.AugmentationType.SIMCLR,
                   warp_prob=0.5,
                   blur_prob=0.5,
                   augmentation_magnitude=0.5,
                   use_pytorch_color_jitter=False):
  """Dispatches `image` to the pipeline selected by `augmentation_type`.

  Args:
    image: An image Tensor of shape [height, width, 3] and dtype tf.uint8.
    augmentation_type: An enums.AugmentationType selecting the pipeline.
    warp_prob: The probability of applying a warp augmentation. Only forwarded
      to the SIMCLR and STACKED_RANDAUGMENT pipelines.
    blur_prob: A Python float between 0 and 1. The probability of applying a
      blur transformation. Only forwarded to the SIMCLR and
      STACKED_RANDAUGMENT pipelines.
    augmentation_magnitude: The magnitude of augmentation. Valid range differs
      depending on the augmentation_type. It is passed as `strength` to the
      SIMCLR and STACKED_RANDAUGMENT pipelines and as `magnitude` to
      RandAugment. A value <= 0 disables augmentation for every type.
    use_pytorch_color_jitter: A Python bool. Whether to use a color jittering
      algorithm that aims to replicate `torchvision.transforms.ColorJitter`
      rather than the standard TensorFlow color jittering. Only forwarded to
      the SIMCLR and STACKED_RANDAUGMENT pipelines.

  Returns:
    The augmented image, or `image` itself when `augmentation_magnitude` is
    not positive.

  Raises:
    ValueError: If `augmentation_type` is unknown.
    TypeError: If `image` does not have dtype tf.uint8.
  """
  # Keyword arguments shared by the two SimCLR-style pipelines.
  simclr_kwargs = dict(
      warp_prob=warp_prob,
      blur_prob=blur_prob,
      strength=augmentation_magnitude,
      use_pytorch_color_jitter=use_pytorch_color_jitter)

  # One callable per supported augmentation type; each maps image -> image.
  augment_fns = {
      enums.AugmentationType.SIMCLR:
          functools.partial(_simclr_augment, **simclr_kwargs),
      enums.AugmentationType.STACKED_RANDAUGMENT:
          functools.partial(_stacked_simclr_randaugment, **simclr_kwargs),
      enums.AugmentationType.RANDAUGMENT:
          augment.RandAugment(magnitude=augmentation_magnitude).distort,
      enums.AugmentationType.AUTOAUGMENT:
          augment.AutoAugment().distort,
      enums.AugmentationType.IDENTITY: (lambda x: x),
  }

  augment_fn = augment_fns.get(augmentation_type)
  if augment_fn is None:
    raise ValueError(f'Invalid augmentation_type: {augmentation_type}.')

  if image.dtype != tf.uint8:
    raise TypeError(f'Image must have dtype tf.uint8. Was {image.dtype}.')

  # A non-positive magnitude means "no augmentation", whatever the type.
  if augmentation_magnitude <= 0.:
    return image

  with tf.name_scope('augment_image'):
    return augment_fn(image)
def test_randaug(self):
  """Smoke-tests `RandAugment.distort` on an all-zero uint8 image.

  Only checks that the call runs and that the output keeps the input's
  (224, 224, 3) shape; pixel values are not inspected.
  """
  expected_shape = (224, 224, 3)
  blank_image = tf.zeros(expected_shape, dtype=tf.uint8)
  distorted = augment.RandAugment().distort(blank_image)
  self.assertEqual(expected_shape, distorted.shape)
def _stacked_simclr_randaugment(image,
                                warp_prob=0.5,
                                blur_prob=0.5,
                                strength=0.5,
                                side_length=IMAGE_SIZE,
                                use_pytorch_color_jitter=False):
  """A combination of the data augmentation sequences from SimCLR and RandAugment.

  The stages run in this order: random color jitter (p=0.8), RandAugment with
  its default settings, random grayscale conversion (p=0.2), random warp
  (p=`warp_prob`) and random Gaussian blur (p=`blur_prob`).

  Citations:
    SimCLR: https://arxiv.org/abs/2002.05709
    RandAugment: https://arxiv.org/abs/1909.13719

  Args:
    image: An image Tensor of shape [height, width, 3] and dtype tf.uint8.
    warp_prob: A Python float between 0 and 1. The probability of applying a
      warp transformation. The warp stage is skipped entirely when this is
      not positive.
    blur_prob: A Python float between 0 and 1. The probability of applying a
      blur transformation. The blur stage is skipped entirely when this is
      not positive.
    strength: A Python float in range [0,1] controlling the maximum strength
      of the color jitter. It is not forwarded to RandAugment.
    side_length: A Python integer. The length, in pixels, of the width and
      height of `image`.
    use_pytorch_color_jitter: A Python bool. Whether to use a color jittering
      algorithm that aims to replicate `torchvision.transforms.ColorJitter`
      rather than the standard TensorFlow color jittering.

  Returns:
    An image with the same shape and dtype as the input image.
  """
  with tf.name_scope('stacked_simclr_randaugment'):
    # Remember the caller's dtype so it can be restored after the blur stage,
    # which runs in float32.
    input_dtype = image.dtype

    image = _random_apply(
        functools.partial(
            _jitter_colors,
            strength=strength,
            use_pytorch_impl=use_pytorch_color_jitter),
        p=0.8,
        x=image)
    image = augment.RandAugment().distort(image)
    image = _random_apply(_rgb_to_gray, p=0.2, x=image)

    if warp_prob > 0.:
      image = _random_apply(
          functools.partial(_warp, side_length=side_length),
          p=warp_prob,
          x=image)

    if blur_prob > 0.:
      image = _convert_image_dtype(image, tf.float32)
      image = _random_apply(
          functools.partial(_gaussian_blur, side_length=side_length),
          p=blur_prob,
          x=image)
      # Without this the result would be float32 whenever blur is enabled and
      # the input dtype otherwise, contradicting the documented return dtype.
      # NOTE(review): assumes _convert_image_dtype rescales float32 back to
      # the input dtype like tf.image.convert_image_dtype does -- confirm.
      image = _convert_image_dtype(image, input_dtype)

    return image