Пример #1
0
  def test_autoaugment(self):
    """Runs AutoAugment on a blank image to catch gross errors.

    Only checks that distortion completes and that the output keeps the
    input's shape; the resulting pixel values are not inspected.
    """
    expected_shape = (224, 224, 3)
    blank_image = tf.zeros(expected_shape, dtype=tf.uint8)

    distorted = augment.AutoAugment().distort(blank_image)

    self.assertEqual(expected_shape, distorted.shape)
Пример #2
0
def _augment_image(image,
                   augmentation_type=enums.AugmentationType.SIMCLR,
                   warp_prob=0.5,
                   blur_prob=0.5,
                   augmentation_magnitude=0.5,
                   use_pytorch_color_jitter=False):
    """Applies data augmentation to an image.

    Args:
      image: An image Tensor of shape [height, width, 3] and dtype tf.uint8.
      augmentation_type: An enums.AugmentationType selecting the augmentation.
      warp_prob: The probability of applying a warp augmentation at the end
        (only applies to augmentation_types SIMCLR and STACKED_RANDAUGMENT).
      blur_prob: A Python float between 0 and 1. The probability of applying a
        blur transformation (only applies to augmentation_types SIMCLR and
        STACKED_RANDAUGMENT).
      augmentation_magnitude: The magnitude of augmentation. Valid range differs
        depending on the augmentation_type. It is forwarded as `strength` for
        SIMCLR and STACKED_RANDAUGMENT and as `magnitude` for RANDAUGMENT, and
        is not used by AUTOAUGMENT or IDENTITY. A value <= 0 disables
        augmentation for every augmentation_type (AUTOAUGMENT included): the
        input image is then returned unchanged.
      use_pytorch_color_jitter: A Python bool. Whether to use a color jittering
        algorithm that aims to replicate `torchvision.transforms.ColorJitter`
        rather than the standard TensorFlow color jittering. Only used for
        augmentation_types SIMCLR and STACKED_RANDAUGMENT.

    Returns:
      The augmented image, or `image` itself when augmentation_magnitude <= 0.

    Raises:
      ValueError: If `augmentation_type` is unknown.
      TypeError: If `image` does not have dtype tf.uint8.
    """
    # Every entry is a callable mapping an image to its augmented version.
    # The class-based augmenters are constructed inside their lambdas. This
    # way only the selected augmenter is built, and only when augmentation
    # actually runs, instead of instantiating all of them on every call.
    augmenters = {
        enums.AugmentationType.SIMCLR:
            functools.partial(_simclr_augment,
                              warp_prob=warp_prob,
                              blur_prob=blur_prob,
                              strength=augmentation_magnitude,
                              use_pytorch_color_jitter=use_pytorch_color_jitter),
        enums.AugmentationType.STACKED_RANDAUGMENT:
            functools.partial(_stacked_simclr_randaugment,
                              warp_prob=warp_prob,
                              blur_prob=blur_prob,
                              strength=augmentation_magnitude,
                              use_pytorch_color_jitter=use_pytorch_color_jitter),
        enums.AugmentationType.RANDAUGMENT:
            lambda img: augment.RandAugment(
                magnitude=augmentation_magnitude).distort(img),
        enums.AugmentationType.AUTOAUGMENT:
            lambda img: augment.AutoAugment().distort(img),
        enums.AugmentationType.IDENTITY: lambda img: img,
    }
    if augmentation_type not in augmenters:
        raise ValueError(f'Invalid augmentation_type: {augmentation_type}.')

    if image.dtype != tf.uint8:
        raise TypeError(f'Image must have dtype tf.uint8. Was {image.dtype}.')

    # Checked after validation, so invalid arguments are still reported even
    # when augmentation is disabled by a non-positive magnitude.
    if augmentation_magnitude <= 0.:
        return image

    with tf.name_scope('augment_image'):
        return augmenters[augmentation_type](image)