def test_autoaugment(self):
    """Smoke test to be sure there are no syntax errors."""
    # A blank uint8 image is enough to drive every op in the policy once.
    blank_image = tf.zeros((224, 224, 3), dtype=tf.uint8)
    distorted = augment.AutoAugment().distort(blank_image)
    # AutoAugment must preserve the spatial size and channel count.
    self.assertEqual((224, 224, 3), distorted.shape)
def _augment_image(image,
                   augmentation_type=enums.AugmentationType.SIMCLR,
                   warp_prob=0.5,
                   blur_prob=0.5,
                   augmentation_magnitude=0.5,
                   use_pytorch_color_jitter=False):
    """Applies data augmentation to an image.

    Args:
        image: An image Tensor of shape [height, width, 3] and dtype tf.uint8.
        augmentation_type: An enums.AugmentationType selecting the augmenter.
        warp_prob: A Python float between 0 and 1. The probability of applying
            a warp augmentation at the end (only applies to augmentation_types
            SIMCLR and STACKED_RANDAUGMENT).
        blur_prob: A Python float between 0 and 1. The probability of applying
            a blur transformation (only applies to augmentation_types SIMCLR
            and STACKED_RANDAUGMENT).
        augmentation_magnitude: The magnitude of augmentation. Valid range
            differs depending on the augmentation_type. It is passed as
            `strength` for SIMCLR and STACKED_RANDAUGMENT and as `magnitude`
            for RANDAUGMENT; AUTOAUGMENT and IDENTITY do not use it. A value
            <= 0 disables augmentation and `image` is returned unchanged.
        use_pytorch_color_jitter: A Python bool. Whether to use a color
            jittering algorithm that aims to replicate
            `torchvision.transforms.ColorJitter` rather than the standard
            TensorFlow color jittering. Only used for augmentation_types SIMCLR
            and STACKED_RANDAUGMENT.

    Returns:
        The augmented image.

    Raises:
        ValueError: If `augmentation_type` is unknown.
        TypeError: If `image` does not have dtype tf.uint8.
    """
    # Map each type to a zero-argument factory so that only the augmenter that
    # is actually requested gets constructed. Building RandAugment and
    # AutoAugment eagerly on every call wastes work, and any constructor
    # errors or side effects would otherwise hit unrelated augmentation types.
    augmenter_factories = {
        enums.AugmentationType.SIMCLR: lambda: functools.partial(
            _simclr_augment,
            warp_prob=warp_prob,
            blur_prob=blur_prob,
            strength=augmentation_magnitude,
            use_pytorch_color_jitter=use_pytorch_color_jitter),
        enums.AugmentationType.STACKED_RANDAUGMENT: lambda: functools.partial(
            _stacked_simclr_randaugment,
            warp_prob=warp_prob,
            blur_prob=blur_prob,
            strength=augmentation_magnitude,
            use_pytorch_color_jitter=use_pytorch_color_jitter),
        enums.AugmentationType.RANDAUGMENT: (
            lambda: augment.RandAugment(
                magnitude=augmentation_magnitude).distort),
        enums.AugmentationType.AUTOAUGMENT: (
            lambda: augment.AutoAugment().distort),
        enums.AugmentationType.IDENTITY: lambda: (lambda x: x),
    }
    if augmentation_type not in augmenter_factories:
        raise ValueError(f'Invalid augmentation_type: {augmentation_type}.')
    if image.dtype != tf.uint8:
        raise TypeError(f'Image must have dtype tf.uint8. Was {image.dtype}.')
    # Non-positive magnitude means "no augmentation" for every type.
    if augmentation_magnitude <= 0.:
        return image
    # Construct the augmenter outside the name scope, matching the original
    # ordering in which augmenters were built before the scope was entered.
    augment_fn = augmenter_factories[augmentation_type]()
    with tf.name_scope('augment_image'):
        return augment_fn(image)