def mixup_and_cutmix(features, labels):
    """Runs Mixup/CutMix over the batch's 'image' entry and its labels.

    `params` and `self` are free names here: they are resolved from an
    enclosing scope that is not part of this snippet (presumably the
    enclosing task method and its data config -- TODO confirm).

    Args:
      features: Mapping holding the image batch under the key 'image'. It is
        updated in place: its 'image' entry is replaced by the augmented batch.
      labels: Labels passed to the augmenter together with the images.

    Returns:
      A `(features, labels)` pair: the same `features` object (now holding
      the augmented images) and the labels returned by the augmenter.
    """
    mix_cfg = params.mixup_and_cutmix
    mixer = augment.MixupAndCutmix(
        mixup_alpha=mix_cfg.mixup_alpha,
        cutmix_alpha=mix_cfg.cutmix_alpha,
        prob=mix_cfg.prob,
        label_smoothing=mix_cfg.label_smoothing,
        num_classes=self._get_num_classes())
    mixed_images, mixed_labels = mixer(features['image'], labels)
    features['image'] = mixed_images
    return features, mixed_labels
  def build_inputs(
      self,
      params: exp_cfg.DataConfig,
      input_context: Optional[tf.distribute.InputContext] = None
  ) -> tf.data.Dataset:
    """Builds classification input.

    Assembles a decoder, a parser and an optional Mixup/CutMix post-processor,
    then hands them to the input-reader factory and reads the dataset.

    The number of classes and the input size come from the model config. The
    image/label field keys and the multilabel flag are always read from
    `task_config.train_data`, even when `params` describes another split.
    NOTE(review): presumably intentional so all splits share one schema --
    confirm.

    Args:
      params: Data config for the split being built. It supplies the TFDS
        name, augmentation options, dtype, file type and training flag.
      input_context: Optional distribution input context forwarded to
        `reader.read`.

    Returns:
      The dataset produced by the configured input reader.
    """
    model_cfg = self.task_config.model
    train_cfg = self.task_config.train_data

    num_classes = model_cfg.num_classes
    input_size = model_cfg.input_size
    image_field_key = train_cfg.image_field_key
    label_field_key = train_cfg.label_field_key
    is_multilabel = train_cfg.is_multilabel

    # TFDS datasets ship their own decoder; otherwise decode raw records.
    if not params.tfds_name:
      decoder = classification_input.Decoder(
          image_field_key=image_field_key,
          label_field_key=label_field_key,
          is_multilabel=is_multilabel)
    else:
      decoder = tfds_factory.get_classification_decoder(params.tfds_name)

    parser = classification_input.Parser(
        output_size=input_size[:2],  # Only (height, width) are used here.
        num_classes=num_classes,
        image_field_key=image_field_key,
        label_field_key=label_field_key,
        decode_jpeg_only=params.decode_jpeg_only,
        aug_rand_hflip=params.aug_rand_hflip,
        aug_crop=params.aug_crop,
        aug_type=params.aug_type,
        color_jitter=params.color_jitter,
        random_erasing=params.random_erasing,
        is_multilabel=is_multilabel,
        dtype=params.dtype)

    # Mixup/CutMix runs as a post-processing step only when it is configured.
    mix_cfg = params.mixup_and_cutmix
    postprocess_fn = None
    if mix_cfg:
      postprocess_fn = augment.MixupAndCutmix(
          mixup_alpha=mix_cfg.mixup_alpha,
          cutmix_alpha=mix_cfg.cutmix_alpha,
          prob=mix_cfg.prob,
          label_smoothing=mix_cfg.label_smoothing,
          num_classes=num_classes)

    reader = input_reader_factory.input_reader_generator(
        params,
        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training),
        postprocess_fn=postprocess_fn)

    return reader.read(input_context=input_context)
Example #3
0
    def test_mixup_and_cutmix_smoothes_labels(self):
        """Mixup/CutMix keeps image shape/dtype and label-smoothing bounds.

        The assertions expect every augmented label value to stay inside
        [label_smoothing / num_classes, 1 - label_smoothing + 1 / num_classes]
        up to a small tolerance on each side.
        """
        batch_size = 12
        num_classes = 1000
        label_smoothing = 0.1

        images = tf.random.normal((batch_size, 224, 224, 3), dtype=tf.float32)
        labels = tf.range(batch_size)
        augmenter = augment.MixupAndCutmix(num_classes=num_classes,
                                           label_smoothing=label_smoothing)

        aug_images, aug_labels = augmenter.distort(images, labels)

        self.assertEqual(images.shape, aug_images.shape)
        self.assertEqual(images.dtype, aug_images.dtype)
        self.assertEqual([batch_size, num_classes], aug_labels.shape)
        # Upper bound: the smoothed "on" value plus 1 / num_classes of slack.
        self.assertAllLessEqual(aug_labels, 1. - label_smoothing +
                                2. / num_classes)
        # Lower bound: the smoothed "off" value minus a 1e-4 tolerance. This
        # was previously written as `1e4`, which put the bound near -1e4 and
        # made the assertion unable to fail.
        self.assertAllGreaterEqual(aug_labels, label_smoothing / num_classes -
                                   1e-4)
Example #4
0
    def test_cutmix_changes_image(self):
        """CutMix-only augmentation alters the images and keeps label bounds.

        Mixup is disabled (`mixup_alpha=0.`), so any change to the images must
        come from CutMix. The label assertions expect values inside the
        label-smoothing range for `label_smoothing`, up to a small tolerance.
        """
        batch_size = 12
        num_classes = 1000
        label_smoothing = 0.1

        images = tf.random.normal((batch_size, 224, 224, 3), dtype=tf.float32)
        labels = tf.range(batch_size)
        # Pass label_smoothing explicitly. Previously it was only used in the
        # assertions, so the bounds silently relied on the augmenter's default.
        augmenter = augment.MixupAndCutmix(mixup_alpha=0.,
                                           cutmix_alpha=1.,
                                           label_smoothing=label_smoothing,
                                           num_classes=num_classes)

        aug_images, aug_labels = augmenter.distort(images, labels)

        self.assertEqual(images.shape, aug_images.shape)
        self.assertEqual(images.dtype, aug_images.dtype)
        self.assertEqual([batch_size, num_classes], aug_labels.shape)
        # Upper bound: the smoothed "on" value plus 1 / num_classes of slack.
        self.assertAllLessEqual(aug_labels, 1. - label_smoothing +
                                2. / num_classes)
        # Lower bound: the smoothed "off" value minus a 1e-4 tolerance. This
        # was previously `1e4`, which made the assertion unable to fail.
        self.assertAllGreaterEqual(aug_labels, label_smoothing / num_classes -
                                   1e-4)
        self.assertFalse(tf.math.reduce_all(images == aug_images))