def preprocess(image, label):
  """Image preprocessing function.

  On the training split: pads, randomly crops and randomly flips the image,
  and optionally applies RandAugment and/or AugMix. Then normalizes the image
  (unless AugMix produced it) and prepares the label.

  This is a closure: it reads `split`, `image_shape`, `random_augment`,
  `aug_params`, `dtype`, `normalize`, `onehot` and `num_classes` from the
  enclosing scope.

  Args:
    image: The input image tensor.
    label: The input label tensor.

  Returns:
    A tuple `(image, label)`. When RandAugment is enabled during training,
    `aug_params['aug_count']` augmented copies are stacked along a new leading
    axis. The label is one-hot over `num_classes` when training with
    `onehot`; otherwise it is cast to `dtype`.
  """
  if split == tfds.Split.TRAIN:
    # Pad height and width by 4 pixels each, then randomly crop back down to
    # `image_shape`.
    image = tf.image.resize_with_crop_or_pad(
        image, image_shape[0] + 4, image_shape[1] + 4)
    image = tf.image.random_crop(image, image_shape)
    image = tf.image.random_flip_left_right(image)

    # Only random augment for now.
    if random_augment:
      count = aug_params['aug_count']
      augmenter = augment_utils.RandAugment()
      augmented = [augmenter.distort(image) for _ in range(count)]
      image = tf.stack(augmented)

  # Use .get() (as the other parsers do) so a missing 'augmix' key means
  # "AugMix disabled" rather than raising KeyError.
  if split == tfds.Split.TRAIN and aug_params.get('augmix', False):
    augmenter = augment_utils.RandAugment()
    image = _augmix(image, aug_params, augmenter, dtype)
  elif normalize:
    image = normalize_convert_image(image, dtype)
  # NOTE(review): when normalize is False and AugMix is not applied, the image
  # is returned without being converted to `dtype` — confirm this is intended.

  if split == tfds.Split.TRAIN and onehot:
    label = tf.cast(label, tf.int32)
    label = tf.one_hot(label, num_classes)
  else:
    label = tf.cast(label, dtype)
  return image, label
def _example_parser(example: types.Features) -> types.Features:
  """A pre-process function to return images in [0, 1].

  During training the image is padded, randomly cropped and randomly flipped
  using stateless random ops. Their seeds are derived per example by folding
  `self._seed` with the example's enumerate id. RandAugment and AugMix are
  then optionally applied. Unless AugMix produced the final image, the image
  is then normalized with `normalize_convert_image` or converted to the target
  dtype.

  Args:
    example: A feature dict containing at least 'image' and 'label', plus the
      `self._enumerate_id_key` entry, which is read during training.

  Returns:
    A shallow copy of `example` with 'image' and 'label' removed. It adds
    'features' (the processed image) and 'labels' (float32). 'labels' is
    one-hot over 10 classes when mixup or label smoothing is enabled.
  """
  image = example['image']
  # Fixed typo: `tf.bloat16` does not exist and raised AttributeError whenever
  # bfloat16 was requested.
  image_dtype = tf.bfloat16 if self._use_bfloat16 else tf.float32
  # AugMix only runs inside the training branch below. Gate the flag on
  # training so that eval-time images still go through normalize/convert
  # instead of being returned unconverted.
  use_augmix = self._is_training and self._aug_params.get('augmix', False)
  if self._is_training:
    image_shape = tf.shape(image)
    # Pad height and width by 4 pixels each, then randomly crop back down to
    # the original size (32x32 for CIFAR).
    image = tf.image.resize_with_crop_or_pad(
        image, image_shape[0] + 4, image_shape[1] + 4)
    # Note that self._seed will already be shape (2,), as is required for
    # stateless random ops.
    per_example_step_seed = tf.random.experimental.stateless_fold_in(
        self._seed, example[self._enumerate_id_key])
    # Shape (2, 2): row 0 seeds the random crop, row 1 seeds the flip.
    per_example_step_seeds = tf.random.experimental.stateless_split(
        per_example_step_seed, num=2)
    # Crop to (height, width, 3). Using image_shape[1] for the width keeps
    # non-square inputs from being cropped to a square.
    image = tf.image.stateless_random_crop(
        image, (image_shape[0], image_shape[1], 3),
        seed=per_example_step_seeds[0])
    image = tf.image.stateless_random_flip_left_right(
        image, seed=per_example_step_seeds[1])

    # Only random augment for now.
    if self._aug_params.get('random_augment', False):
      count = self._aug_params['aug_count']
      augmenter = augment_utils.RandAugment()
      # NOTE(review): unlike crop/flip, RandAugment is not seeded here, so it
      # is not reproducible per example.
      augmented = [augmenter.distort(image) for _ in range(count)]
      # `count` augmented copies stacked along a new leading axis.
      image = tf.stack(augmented)

    if use_augmix:
      augmenter = augment_utils.RandAugment()
      image = _augmix(image, self._aug_params, augmenter, image_dtype)

  # Convert the image to `image_dtype` (values in [0, 1]), or optionally
  # normalize it by the dataset statistics. AugMix output is used as-is.
  if not use_augmix:
    if self._normalize:
      image = normalize_convert_image(image, image_dtype)
    else:
      image = tf.image.convert_image_dtype(image, image_dtype)

  parsed_example = example.copy()
  parsed_example['features'] = image

  # Note that labels are always float32, even when images are bfloat16.
  mixup_alpha = self._aug_params.get('mixup_alpha', 0)
  label_smoothing = self._aug_params.get('label_smoothing', 0.)
  should_onehot = mixup_alpha > 0 or label_smoothing > 0
  if should_onehot:
    # NOTE(review): the class count is hard-coded to 10 (CIFAR-10).
    parsed_example['labels'] = tf.one_hot(
        example['label'], 10, dtype=tf.float32)
  else:
    parsed_example['labels'] = tf.cast(example['label'], tf.float32)

  del parsed_example['image']
  del parsed_example['label']
  return parsed_example
def _example_parser(example: types.Features) -> types.Features:
  """A pre-process function to return images in [0, 1].

  During training the image is padded, randomly cropped and randomly flipped
  using stateless random ops. Their seeds are derived per example by folding
  `self._seed` with the example's enumerate id. RandAugment and AugMix are
  then optionally applied, also with per-example seeds. Unless AugMix produced
  the final image, the image is then normalized with
  `augmix.normalize_convert_image` or converted to the target dtype.

  Args:
    example: A feature dict containing at least 'image', 'label' and the
      `self._enumerate_id_key` entry. When `self._add_fingerprint_key` is
      set, it must also contain `self._fingerprint_key`.

  Returns:
    A new dict with:
      - 'features': the processed image.
      - The enumerate-id entry.
      - The fingerprint entry, when `self._add_fingerprint_key` is set.
      - 'labels': float32. One-hot when mixup or label smoothing is enabled,
        over 100 classes if `self.name == 'cifar100'` and 10 otherwise.
  """
  image = example['image']
  image_dtype = tf.bfloat16 if self._use_bfloat16 else tf.float32
  # AugMix only runs inside the training branch below, since its seed is
  # created there. Gate the flag on training so that eval-time images still
  # go through normalize/convert instead of being returned unconverted.
  use_augmix = self._is_training and self._aug_params.get('augmix', False)
  if self._is_training:
    image_shape = tf.shape(image)
    # Pad height and width by 4 pixels each, then randomly crop back down to
    # the original size (32x32 for CIFAR).
    image = tf.image.resize_with_crop_or_pad(
        image, image_shape[0] + 4, image_shape[1] + 4)
    # Note that self._seed will already be shape (2,), as is required for
    # stateless random ops, and so will per_example_step_seed.
    per_example_step_seed = tf.random.experimental.stateless_fold_in(
        self._seed, example[self._enumerate_id_key])
    # per_example_step_seeds has shape (num, 2) = (4, 2). Each row is a seed:
    # the first for random_crop, the second for flip, the third optionally
    # for RandAugment, and the fourth optionally for AugMix.
    per_example_step_seeds = tf.random.experimental.stateless_split(
        per_example_step_seed, num=4)
    # Crop to (height, width, 3). Using image_shape[1] for the width keeps
    # non-square inputs from being cropped to a square.
    image = tf.image.stateless_random_crop(
        image, (image_shape[0], image_shape[1], 3),
        seed=per_example_step_seeds[0])
    image = tf.image.stateless_random_flip_left_right(
        image, seed=per_example_step_seeds[1])

    # Only random augment for now.
    if self._aug_params.get('random_augment', False):
      count = self._aug_params['aug_count']
      # One independent seed per augmented copy.
      augment_seeds = tf.random.experimental.stateless_split(
          per_example_step_seeds[2], num=count)
      augmenter = augment_utils.RandAugment()
      augmented = [
          augmenter.distort(image, seed=augment_seeds[c])
          for c in range(count)
      ]
      # `count` augmented copies stacked along a new leading axis.
      image = tf.stack(augmented)

    if use_augmix:
      augmenter = augment_utils.RandAugment()
      # NOTE(review): CIFAR-10 statistics are used even when this parser
      # serves 'cifar100' (see num_classes below) — confirm this is intended.
      image = augmix.do_augmix(
          image, self._aug_params, augmenter, image_dtype,
          mean=CIFAR10_MEAN, std=CIFAR10_STD,
          seed=per_example_step_seeds[3])

  # Convert the image to `image_dtype` (values in [0, 1]), or optionally
  # normalize it by the dataset statistics. AugMix output is used as-is.
  if not use_augmix:
    if self._normalize:
      image = augmix.normalize_convert_image(
          image, image_dtype, mean=CIFAR10_MEAN, std=CIFAR10_STD)
    else:
      image = tf.image.convert_image_dtype(image, image_dtype)

  parsed_example = {'features': image}
  parsed_example[self._enumerate_id_key] = example[self._enumerate_id_key]
  if self._add_fingerprint_key:
    parsed_example[self._fingerprint_key] = example[self._fingerprint_key]

  # Note that labels are always float32, even when images are bfloat16.
  mixup_alpha = self._aug_params.get('mixup_alpha', 0)
  label_smoothing = self._aug_params.get('label_smoothing', 0.)
  should_onehot = mixup_alpha > 0 or label_smoothing > 0
  labels = example['label']
  if should_onehot:
    num_classes = 100 if self.name == 'cifar100' else 10
    parsed_example['labels'] = tf.one_hot(
        labels, num_classes, dtype=tf.float32)
  else:
    parsed_example['labels'] = tf.cast(labels, tf.float32)
  return parsed_example