Esempio n. 1
0
    def _unl_parser(value):
        """Decode one raw record into two views for unlabeled training.

        The first byte of the record (the label slot in the CIFAR-10 binary
        layout) is skipped. The remaining bytes are reshaped to a 3x32x32
        uint8 tensor and moved to height-width-channel order.

        `params` is not an argument: it is read from the enclosing scope
        (presumably the dataset builder's config -- confirm at the definition
        site). Fields read: `image_size`, `augment_magnitude`.

        Args:
          value: a serialized record decodable by `tf.io.decode_raw` as uint8.

        Returns:
          A dict with
            'ori_images': the (optionally resized) image after one
              `_flip_and_jitter` pass, then `convert_and_normalize`;
            'aug_images': that same flipped image further passed through
              RandAugment, cutout and a second `_flip_and_jitter`, then
              `convert_and_normalize`.
        """
        size = params.image_size
        record = tf.io.decode_raw(value, tf.uint8)
        planar = tf.reshape(record[1:], [3, 32, 32])  # uint8, channel-first
        weak_view = tf.transpose(planar, [1, 2, 0])
        if size != 32:
            # resize_bicubic works on a batch, so wrap and unwrap one image.
            weak_view = tf.image.resize_bicubic([weak_view], [size, size])[0]
        weak_view.set_shape([size, size, 3])
        weak_view = _flip_and_jitter(weak_view, replace_value=128)

        # Strong view: RandAugment -> cutout -> another flip/jitter pass,
        # stacked on top of the already flipped/jittered weak view.
        policy = augment.RandAugment(cutout_const=size // 8,
                                     translate_const=size // 8,
                                     magnitude=params.augment_magnitude)
        strong_view = policy.distort(weak_view)
        strong_view = augment.cutout(strong_view,
                                     pad_size=size // 4,
                                     replace=128)
        strong_view = _flip_and_jitter(strong_view, replace_value=128)

        return {
            'ori_images': convert_and_normalize(params, weak_view),
            'aug_images': convert_and_normalize(params, strong_view),
        }
Esempio n. 2
0
def _cifar10_parser(params, value, training):
    """Decode one raw CIFAR-10 record into an (image, one-hot label) pair.

    Args:
      params: configuration object. Reads `image_size` and `num_classes`.
        When `training` is truthy it also reads `use_augment`, and, if that
        is set, `augment_magnitude`.
      value: a serialized record. Byte 0 is taken as the class id. The
        remaining bytes are reshaped to 3x32x32 uint8 and moved to
        height-width-channel order.
      training: when truthy, apply `_flip_and_jitter`. If
        `params.use_augment` is also set, follow it with RandAugment and
        cutout.

    Returns:
      A tuple `(image, label)`. `image` is the result of
      `convert_and_normalize`. `label` is a float32 one-hot vector of depth
      `params.num_classes`.
    """
    size = params.image_size
    record = tf.io.decode_raw(value, tf.uint8)

    class_id = tf.cast(record[0], tf.int32)
    one_hot = tf.one_hot(class_id, depth=params.num_classes, dtype=tf.float32)

    # Pixel bytes are stored channel-first; move channels last.
    pixels = tf.transpose(tf.reshape(record[1:], [3, 32, 32]), [1, 2, 0])
    if size != 32:
        # resize_bicubic works on a batch, so wrap and unwrap one image.
        pixels = tf.image.resize_bicubic([pixels], [size, size])[0]
    pixels.set_shape([size, size, 3])

    if training:
        # Flip/jitter is applied in every training configuration. RandAugment
        # and cutout are layered on afterwards only when augmentation is on.
        policy = None
        if params.use_augment:
            policy = augment.RandAugment(cutout_const=size // 8,
                                         translate_const=size // 8,
                                         magnitude=params.augment_magnitude)
        pixels = _flip_and_jitter(pixels, 128)
        if policy is not None:
            pixels = policy.distort(pixels)
            pixels = augment.cutout(pixels, pad_size=size // 4, replace=128)

    return convert_and_normalize(params, pixels), one_hot