예제 #1
0
 def _map_augment_batch(*inputs):
     """Apply one shared random augmentation to the three images of a sample.

     The random parameters are drawn a single time and handed unchanged to
     ``batch_augmentation`` for every image, so all three images receive the
     same parameters. The label is returned untouched.

     Args:
         *inputs: exactly four elements ``(image1, image2, image3, label)``
             (presumably three views of one example -- TODO confirm with caller).

     Returns:
         A 4-tuple ``(image1, image2, image3, label)`` with the images replaced
         by their augmented versions.
     """
     first, second, third, label = inputs

     # Draw the shared random parameters once for the whole sample.
     noise_stddev = tf.random.uniform((1,), 0, 0.02, dtype=tf.float32)
     flip_ud_draw = tf.random.uniform([], 0, 1)
     flip_rl_draw = tf.random.uniform([], 0, 1)
     blur_draw = tf.random.uniform([], 0, 1)
     brightness_shift = tf.random.uniform((1,), -0.2, 0.2, dtype=tf.float32)

     augmented = tuple(
         batch_augmentation(view, noise_stddev, flip_ud_draw, flip_rl_draw,
                            brightness_shift, blur_draw)
         for view in (first, second, third)
     )
     return (*augmented, label)
예제 #2
0
    def _map_augment_img(*inputs):
        """Augment a single ``(image, label)`` pair according to ``augmentation_type``.

        ``augmentation_type`` is read from the enclosing scope:

        * ``'simple'`` or ``'batch'``: random left/right flip, then pad by 4
          pixels on each spatial side and randomly crop back to the original size.
        * ``'single'``: the same simple augmentation, followed by
          ``batch_augmentation`` with random parameters drawn for this image
          alone (so it is comparable to metabyol).

        Args:
            *inputs: exactly two elements ``(image, label)``. The image is
                assumed to be a square, 3-channel ``(size, size, 3)`` tensor with
                a static first dimension (e.g. CIFAR-10: 32x32x3) -- TODO confirm.

        Returns:
            ``(image, label)`` with the augmented image; ``label`` is unchanged.

        Raises:
            ValueError: if ``augmentation_type`` is not ``'simple'``,
                ``'batch'`` or ``'single'``.
        """

        def _flip_pad_crop(img):
            """Random left/right flip, pad by 4 pixels, random crop back to size."""
            img = tf.image.random_flip_left_right(img)
            # Square image assumed: the first (static) dim is used for both
            # height and width, and the channel count is fixed at 3.
            img_size = img.shape[0]  # CIFAR10: 32
            img = tf.image.resize_with_crop_or_pad(img, img_size + 4, img_size + 4)
            return tf.image.random_crop(img, size=[img_size, img_size, 3])

        image, label = inputs
        if augmentation_type in ('simple', 'batch'):
            image = _flip_pad_crop(image)
        elif augmentation_type == 'single':
            # Per-image special augmentation: first the simple augmentation,
            # then batch_augmentation with parameters drawn for this image only.
            image = _flip_pad_crop(image)
            stddev = tf.random.uniform((1,), 0, 0.02, dtype=tf.float32)
            flip_ud = tf.random.uniform([], 0, 1)
            flip_rl = tf.random.uniform([], 0, 1)
            blur = tf.random.uniform([], 0, 1)
            brightness = tf.random.uniform((1,), -0.2, 0.2, dtype=tf.float32)
            image = batch_augmentation(image, stddev, flip_ud, flip_rl, brightness, blur)
        else:
            # Fail loudly: a bare `assert '<message>'` is always true (non-empty
            # string), so unsupported types used to slip through un-augmented.
            raise ValueError(
                'augmentation type not supported! Got {!r}'.format(augmentation_type))
        return image, label