def preprocess(image, label):
    """Preprocess a single (image, label) pair.

    On the training split the image is padded/cropped to be 4 pixels larger
    than `image_shape` in height and width, randomly cropped back to
    `image_shape`, and randomly flipped left-right. When `random_augment` is
    set, `aug_params['aug_count']` RandAugment-distorted copies are stacked
    along a new leading axis.

    On the training split with `aug_params['augmix']` truthy, AugMix is
    applied instead of normalization; otherwise the image is normalized only
    when `normalize` is set.

    Labels are one-hot encoded to `num_classes` on the training split when
    `onehot` is set, and are cast to `dtype` otherwise.

    NOTE(review): `split`, `image_shape`, `random_augment`, `aug_params`,
    `normalize`, `onehot`, `num_classes` and `dtype` are free variables taken
    from an enclosing scope that is not visible here.
    """
    is_train = split == tfds.Split.TRAIN

    if is_train:
        padded_height = image_shape[0] + 4
        padded_width = image_shape[1] + 4
        image = tf.image.resize_with_crop_or_pad(image, padded_height,
                                                 padded_width)
        image = tf.image.random_crop(image, image_shape)
        image = tf.image.random_flip_left_right(image)

        # Only random augment for now.
        if random_augment:
            num_copies = aug_params['aug_count']
            rand_aug = augment_utils.RandAugment()
            distorted = []
            for _ in range(num_copies):
                distorted.append(rand_aug.distort(image))
            image = tf.stack(distorted)

    # AugMix takes precedence over plain normalization on the training split.
    if is_train and aug_params['augmix']:
        rand_aug = augment_utils.RandAugment()
        image = _augmix(image, aug_params, rand_aug, dtype)
    elif normalize:
        image = normalize_convert_image(image, dtype)

    if is_train and onehot:
        label = tf.one_hot(tf.cast(label, tf.int32), num_classes)
    else:
        label = tf.cast(label, dtype)
    return image, label
# Example No. 2
# 0
    def _example_parser(example: types.Features) -> types.Features:
      """Pre-processes one example into a `features`/`labels` dict.

      When training, the image is padded by 2 pixels on each side, randomly
      cropped back to its original height and width, and randomly flipped
      left-right. The crop and flip use stateless ops seeded from
      `self._seed` folded with the example's `self._enumerate_id_key` value.
      RandAugment and AugMix are then applied according to
      `self._aug_params`.

      Outside the AugMix path, the image is converted to `image_dtype`
      (bfloat16 if `self._use_bfloat16`, else float32). It is additionally
      normalized by the dataset statistics when `self._normalize` is set.

      Args:
        example: feature dict with at least 'image' and 'label' keys, plus
          `self._enumerate_id_key` when training.

      Returns:
        A copy of `example` in which 'image' is replaced by 'features' and
        'label' by float32 'labels'. Labels are one-hot over 10 classes when
        mixup or label smoothing is enabled.
      """
      image = example['image']
      image_dtype = tf.bfloat16 if self._use_bfloat16 else tf.float32
      use_augmix = self._aug_params.get('augmix', False)
      # AugMix is only ever applied to training data. Every other image must
      # still go through the dtype conversion / normalization below, otherwise
      # non-training examples would be returned in their raw dtype.
      apply_augmix = self._is_training and use_augmix
      if self._is_training:
        image_shape = tf.shape(image)
        # Pad the image by 2 pixels on each side (e.g. 32x32 -> 36x36), then
        # randomly crop back down to the original height and width.
        image = tf.image.resize_with_crop_or_pad(
            image, image_shape[0] + 4, image_shape[1] + 4)
        # Note that self._seed will already be shape (2,), as is required for
        # stateless random ops.
        per_example_step_seed = tf.random.experimental.stateless_fold_in(
            self._seed, example[self._enumerate_id_key])
        # Two seeds: the first for the random crop, the second for the flip.
        per_example_step_seeds = tf.random.experimental.stateless_split(
            per_example_step_seed, num=2)
        # Crop to (height, width, 3); using the width for the second
        # dimension keeps non-square images correct.
        image = tf.image.stateless_random_crop(
            image,
            (image_shape[0], image_shape[1], 3),
            seed=per_example_step_seeds[0])
        image = tf.image.stateless_random_flip_left_right(
            image,
            seed=per_example_step_seeds[1])

        # Only random augment for now.
        if self._aug_params.get('random_augment', False):
          count = self._aug_params['aug_count']
          augmenter = augment_utils.RandAugment()
          # Stack `count` distorted copies along a new leading axis.
          augmented = [augmenter.distort(image) for _ in range(count)]
          image = tf.stack(augmented)

        if apply_augmix:
          augmenter = augment_utils.RandAugment()
          image = _augmix(image, self._aug_params, augmenter, image_dtype)

      # The image has values in the range [0, 1].
      # Optionally normalize by the dataset statistics.
      if not apply_augmix:
        if self._normalize:
          image = normalize_convert_image(image, image_dtype)
        else:
          image = tf.image.convert_image_dtype(image, image_dtype)
      parsed_example = example.copy()
      parsed_example['features'] = image

      # Note that labels are always float32, even when images are bfloat16.
      mixup_alpha = self._aug_params.get('mixup_alpha', 0)
      label_smoothing = self._aug_params.get('label_smoothing', 0.)
      should_onehot = mixup_alpha > 0 or label_smoothing > 0
      if should_onehot:
        parsed_example['labels'] = tf.one_hot(
            example['label'], 10, dtype=tf.float32)
      else:
        parsed_example['labels'] = tf.cast(example['label'], tf.float32)

      del parsed_example['image']
      del parsed_example['label']
      return parsed_example
# Example No. 3
# 0
        def _example_parser(example: types.Features) -> types.Features:
            """Pre-processes one example into a `features`/`labels` dict.

            When training, the image is padded by 2 pixels on each side,
            randomly cropped back to its original height and width, and
            randomly flipped left-right. Optional RandAugment and AugMix
            follow, according to `self._aug_params`. Every random op receives
            a stateless seed derived from `self._seed` folded with the
            example's `self._enumerate_id_key` value.

            Outside the AugMix path, the image is converted to `image_dtype`
            (bfloat16 if `self._use_bfloat16`, else float32). It is
            additionally normalized with CIFAR10_MEAN / CIFAR10_STD when
            `self._normalize` is set.

            Args:
                example: feature dict with at least 'image', 'label' and
                    `self._enumerate_id_key` keys, plus
                    `self._fingerprint_key` when `self._add_fingerprint_key`.

            Returns:
                A new dict with 'features' (the image), the enumerate-id key,
                the fingerprint key (if enabled), and float32 'labels'.
                Labels are one-hot over 100 classes for 'cifar100' (else 10)
                when mixup or label smoothing is enabled.
            """
            image = example['image']
            image_dtype = tf.bfloat16 if self._use_bfloat16 else tf.float32
            use_augmix = self._aug_params.get('augmix', False)
            # AugMix is only ever applied to training data. Every other image
            # must still go through the dtype conversion / normalization
            # below, otherwise non-training examples would be returned in
            # their raw dtype.
            apply_augmix = self._is_training and use_augmix
            if self._is_training:
                image_shape = tf.shape(image)
                # Pad the image by 2 pixels on each side (e.g. 32x32 -> 36x36),
                # then randomly crop back down to the original height/width.
                image = tf.image.resize_with_crop_or_pad(
                    image, image_shape[0] + 4, image_shape[1] + 4)
                # Note that self._seed will already be shape (2,), as is required for
                # stateless random ops, and so will per_example_step_seed.
                per_example_step_seed = tf.random.experimental.stateless_fold_in(
                    self._seed, example[self._enumerate_id_key])
                # per_example_step_seeds has shape (4, 2): the first seed is for
                # random_crop, the second for flip, the third optionally for
                # RandAugment, and the fourth optionally for AugMix.
                per_example_step_seeds = tf.random.experimental.stateless_split(
                    per_example_step_seed, num=4)
                # Crop to (height, width, 3); using the width for the second
                # dimension keeps non-square images correct.
                image = tf.image.stateless_random_crop(
                    image, (image_shape[0], image_shape[1], 3),
                    seed=per_example_step_seeds[0])
                image = tf.image.stateless_random_flip_left_right(
                    image, seed=per_example_step_seeds[1])

                # Only random augment for now.
                if self._aug_params.get('random_augment', False):
                    count = self._aug_params['aug_count']
                    # One stateless seed per augmented copy.
                    augment_seeds = tf.random.experimental.stateless_split(
                        per_example_step_seeds[2], num=count)
                    augmenter = augment_utils.RandAugment()
                    augmented = [
                        augmenter.distort(image, seed=augment_seeds[c])
                        for c in range(count)
                    ]
                    image = tf.stack(augmented)

                if apply_augmix:
                    augmenter = augment_utils.RandAugment()
                    image = augmix.do_augmix(image,
                                             self._aug_params,
                                             augmenter,
                                             image_dtype,
                                             mean=CIFAR10_MEAN,
                                             std=CIFAR10_STD,
                                             seed=per_example_step_seeds[3])

            # The image has values in the range [0, 1].
            # Optionally normalize by the dataset statistics.
            if not apply_augmix:
                if self._normalize:
                    image = augmix.normalize_convert_image(image,
                                                           image_dtype,
                                                           mean=CIFAR10_MEAN,
                                                           std=CIFAR10_STD)
                else:
                    image = tf.image.convert_image_dtype(image, image_dtype)
            parsed_example = {'features': image}
            parsed_example[self._enumerate_id_key] = example[
                self._enumerate_id_key]
            if self._add_fingerprint_key:
                parsed_example[self._fingerprint_key] = example[
                    self._fingerprint_key]

            # Note that labels are always float32, even when images are bfloat16.
            mixup_alpha = self._aug_params.get('mixup_alpha', 0)
            label_smoothing = self._aug_params.get('label_smoothing', 0.)
            should_onehot = mixup_alpha > 0 or label_smoothing > 0

            labels = example['label']

            if should_onehot:
                num_classes = 100 if self.name == 'cifar100' else 10
                parsed_example['labels'] = tf.one_hot(labels,
                                                      num_classes,
                                                      dtype=tf.float32)
            else:
                parsed_example['labels'] = tf.cast(labels, tf.float32)

            return parsed_example