def build_inputs(
    self,
    params: exp_cfg.DataConfig,
    input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
    """Builds classification input.

    Wires a decoder, a parser and an optional Mixup/CutMix post-processing
    step into the input reader factory, then reads the dataset.

    Args:
      params: Data config for the split being read. Supplies the data source
        (`tfds_name`, `file_type`), parser/augmentation options
        (`decode_jpeg_only`, `aug_rand_hflip`, `aug_type`, `color_jitter`,
        `random_erasing`, `dtype`), `is_training`, and the optional
        `mixup_and_cutmix` settings.
      input_context: Optional distribution input context, forwarded to the
        reader's `read` call.

    Returns:
      The dataset returned by the reader's `read` call.
    """
    model_cfg = self.task_config.model
    train_cfg = self.task_config.train_data
    n_classes = model_cfg.num_classes
    # NOTE(review): the feature keys and multilabel flag are always taken from
    # `train_data`, even when `params` describes an eval split -- confirm this
    # is intended rather than reading them from `params`.
    field_keys = {
        'image_field_key': train_cfg.image_field_key,
        'label_field_key': train_cfg.label_field_key,
    }
    multilabel = train_cfg.is_multilabel

    # Records not coming from TFDS are decoded with the configured keys;
    # TFDS sources use a dataset-specific decoder looked up by name.
    if not params.tfds_name:
        decoder = classification_input.Decoder(
            is_multilabel=multilabel, **field_keys)
    else:
        decoder = tfds_factory.get_classification_decoder(params.tfds_name)

    # Only the spatial dims (first two entries) of the model input size are
    # passed to the parser as its output size.
    parser = classification_input.Parser(
        output_size=model_cfg.input_size[:2],
        num_classes=n_classes,
        decode_jpeg_only=params.decode_jpeg_only,
        aug_rand_hflip=params.aug_rand_hflip,
        aug_type=params.aug_type,
        color_jitter=params.color_jitter,
        random_erasing=params.random_erasing,
        is_multilabel=multilabel,
        dtype=params.dtype,
        **field_keys)

    # Mixup/CutMix is attached only when the split's config enables it.
    mix_cfg = params.mixup_and_cutmix
    postprocess_fn = None
    if mix_cfg:
        postprocess_fn = augment.MixupAndCutmix(
            mixup_alpha=mix_cfg.mixup_alpha,
            cutmix_alpha=mix_cfg.cutmix_alpha,
            prob=mix_cfg.prob,
            label_smoothing=mix_cfg.label_smoothing,
            num_classes=n_classes)

    return input_reader_factory.input_reader_generator(
        params,
        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training),
        postprocess_fn=postprocess_fn,
    ).read(input_context=input_context)
def test_mixup_and_cutmix_smoothes_labels(self):
    """Mixup/CutMix output keeps labels within the label-smoothing bounds."""
    batch_size = 12
    num_classes = 1000
    label_smoothing = 0.1
    # Slack for float rounding in the lower-bound check. This must be a small
    # number (1e-4, not 1e4); otherwise the lower bound is hugely negative and
    # the assertion can never fail.
    tolerance = 1e-4

    images = tf.random.normal((batch_size, 224, 224, 3), dtype=tf.float32)
    labels = tf.range(batch_size)
    augmenter = augment.MixupAndCutmix(
        num_classes=num_classes, label_smoothing=label_smoothing)

    aug_images, aug_labels = augmenter.distort(images, labels)

    self.assertEqual(images.shape, aug_images.shape)
    self.assertEqual(images.dtype, aug_images.dtype)
    self.assertEqual([batch_size, num_classes], aug_labels.shape)
    # Upper bound: 1 - label_smoothing, with 2 / num_classes of slack.
    self.assertAllLessEqual(aug_labels,
                            1. - label_smoothing + 2. / num_classes)
    # Lower bound: label_smoothing / num_classes (presumably the smoothed
    # "off" value), minus a small float tolerance.
    self.assertAllGreaterEqual(aug_labels,
                               label_smoothing / num_classes - tolerance)
def test_cutmix_changes_image(self):
    """CutMix-only augmentation alters images and keeps labels in bounds."""
    batch_size = 12
    num_classes = 1000
    label_smoothing = 0.1
    # Slack for float rounding in the lower-bound check. This must be a small
    # number (1e-4, not 1e4); otherwise the lower bound is hugely negative and
    # the assertion can never fail.
    tolerance = 1e-4

    images = tf.random.normal((batch_size, 224, 224, 3), dtype=tf.float32)
    labels = tf.range(batch_size)
    # mixup_alpha=0 with cutmix_alpha=1 isolates CutMix. label_smoothing is
    # passed explicitly so the bounds asserted below match the augmenter's
    # configuration instead of relying on its default.
    augmenter = augment.MixupAndCutmix(
        mixup_alpha=0.,
        cutmix_alpha=1.,
        label_smoothing=label_smoothing,
        num_classes=num_classes)

    aug_images, aug_labels = augmenter.distort(images, labels)

    self.assertEqual(images.shape, aug_images.shape)
    self.assertEqual(images.dtype, aug_images.dtype)
    self.assertEqual([batch_size, num_classes], aug_labels.shape)
    # Upper bound: 1 - label_smoothing, with 2 / num_classes of slack.
    self.assertAllLessEqual(aug_labels,
                            1. - label_smoothing + 2. / num_classes)
    # Lower bound: label_smoothing / num_classes (presumably the smoothed
    # "off" value), minus a small float tolerance.
    self.assertAllGreaterEqual(aug_labels,
                               label_smoothing / num_classes - tolerance)
    # At least one pixel must differ, i.e. CutMix actually pasted a patch.
    self.assertFalse(tf.math.reduce_all(images == aug_images))