def cifar10(mode, params):
    """Builds the CIFAR10 input pipeline.

    Args:
        mode: A tf.estimator mode, converted internally to a ModelMode.
        params: Dict carrying 'hparams', 'use_tpu', and 'data_dir'.

    Returns:
        The value produced by TfdsInput.input_fn(params) for CIFAR10.
    """
    # Images are naturally 32x32.
    model_mode = utils.estimator_mode_to_model_mode(mode)
    hparams = params['hparams']
    training = model_mode == enums.ModelMode.TRAIN

    # Per-channel normalization constants passed as (mean, std).
    # NOTE(review): the mean values are negative — presumably precomputed
    # offsets rather than raw pixel means; confirm against the
    # preprocessing pipeline.
    channel_mean = np.array([[[-0.0172, -0.0356, -0.107]]])
    channel_std = np.array([[[0.4046, 0.3988, 0.402]]])

    view_preprocessor = preprocessing.ImageToMultiViewedImagePreprocessor(
        is_training=training,
        preprocessing_options=hparams.input_data.preprocessing,
        dataset_options=preprocessing.DatasetOptions(
            decode_input=False,
            image_mean_std=(channel_mean, channel_std)),
        bfloat16_supported=params['use_tpu'])

    dataset = TfdsInput(
        dataset_name='cifar10:3.*.*',
        split='train' if training else 'test',
        mode=model_mode,
        preprocessor=view_preprocessor,
        shard_per_host=hparams.input_data.shard_per_host,
        # Caching is only useful when the data is revisited across epochs.
        cache=training,
        # Buffer covers the full 50k-example CIFAR10 train split.
        shuffle_buffer=50000,
        num_parallel_calls=64,
        max_samples=hparams.input_data.max_samples,
        label_noise_prob=hparams.input_data.label_noise_prob,
        num_classes=get_num_classes(hparams),
        data_dir=params['data_dir'],
    )
    return dataset.input_fn(params)
def imagenet(mode, params):
    """An input_fn for ImageNet (ILSVRC 2012) data.

    Args:
        mode: A tf.estimator mode, converted internally to a ModelMode.
        params: Dict carrying 'hparams', 'use_tpu', and 'data_dir'.

    Returns:
        The value produced by TfdsInput.input_fn(params) for ImageNet.
    """
    model_mode = utils.estimator_mode_to_model_mode(mode)
    hparams = params['hparams']
    training = model_mode == enums.ModelMode.TRAIN

    view_preprocessor = preprocessing.ImageToMultiViewedImagePreprocessor(
        is_training=training,
        preprocessing_options=hparams.input_data.preprocessing,
        dataset_options=preprocessing.DatasetOptions(decode_input=False),
        bfloat16_supported=params['use_tpu'])

    dataset = TfdsInput(
        dataset_name='imagenet2012:5.*.*',
        split='train' if training else 'validation',
        mode=model_mode,
        preprocessor=view_preprocessor,
        shuffle_buffer=1024,
        shard_per_host=hparams.input_data.shard_per_host,
        # Caching is only useful when the data is revisited across epochs.
        cache=training,
        num_parallel_calls=64,
        max_samples=hparams.input_data.max_samples,
        label_noise_prob=hparams.input_data.label_noise_prob,
        num_classes=get_num_classes(hparams),
        data_dir=params['data_dir'],
    )
    return dataset.input_fn(params)
  def test_preprocess_image(self, decode_image, image_dtype, image_size,
                            is_training, augmentation_type, warp_prob,
                            augmentation_magnitude, eval_crop_method,
                            image_mean_std):
    """Checks preprocess_image output dtype/shape across option combos."""
    # tf.random.uniform() doesn't allow generating random values that are
    # uint8, so draw floats and cast down.
    pixels = tf.random.uniform(
        (300, 400, 3), minval=0, maxval=255, dtype=tf.float32)
    image = tf.cast(pixels, tf.uint8)
    if decode_image:
      image = tf.image.encode_jpeg(image)
    else:
      image = tf.cast(image, image_dtype)

    # Undecoded inputs are only accepted as uint8; other dtypes should
    # trip an assertion inside preprocess_image.
    if not decode_image and image_dtype != tf.uint8:
      scope = self.assertRaises(AssertionError)
    else:
      scope = nullcontext()
    with scope:
      output = preprocessing.preprocess_image(
          image,
          is_training=is_training,
          bfloat16_supported=False,
          preprocessing_options=hparams.ImagePreprocessing(
              image_size=image_size,
              augmentation_type=augmentation_type,
              warp_probability=warp_prob,
              augmentation_magnitude=augmentation_magnitude,
              eval_crop_method=eval_crop_method,
          ),
          dataset_options=preprocessing.DatasetOptions(
              image_mean_std=image_mean_std, decode_input=decode_image))
      self.assertEqual(output.dtype, tf.float32)
      self.assertEqual([image_size, image_size, 3], output.shape.as_list())
# Example #4
    def test_input_class(self, input_class, model_mode, image_size,
                         max_samples):
        """Checks dataset/receiver shapes and batch counts for an input class."""
        is_train = model_mode == enums.ModelMode.TRAIN
        split = 'train' if is_train else 'test'
        batch_size = 2
        dataset_size = 10
        # max_samples caps the training data; otherwise the mock dataset
        # size determines how many batches we expect to draw.
        if max_samples is not None and is_train:
            expected_num_batches = max_samples // batch_size
        else:
            expected_num_batches = dataset_size // batch_size

        params = {'batch_size': batch_size}
        if input_class != 'TfdsInput':
            raise ValueError(f'Unknown input class {input_class}')
        with tfds.testing.mock_data(num_examples=dataset_size):
            view_preprocessor = (
                preprocessing.ImageToMultiViewedImagePreprocessor(
                    is_training=is_train,
                    preprocessing_options=hparams.ImagePreprocessing(
                        image_size=image_size, num_views=2),
                    dataset_options=preprocessing.DatasetOptions(
                        decode_input=False)))
            data = inputs.TfdsInput(
                'cifar10',
                split,
                mode=model_mode,
                preprocessor=view_preprocessor,
                max_samples=max_samples,
                num_classes=10).input_fn(params)

        inference = model_mode == enums.ModelMode.INFERENCE
        # Two stacked views double the channel count outside inference.
        expected_num_channels = 3 if inference else 6
        expected_batch_size = None if inference else batch_size

        if inference:
            self.assertIsInstance(
                data, tf.estimator.export.TensorServingInputReceiver)
            image_shape = data.features.shape.as_list()
        else:
            self.assertIsInstance(data, tf.data.Dataset)
            shapes = tf.data.get_output_shapes(data)
            image_shape = shapes[0].as_list()
            self.assertEqual([expected_batch_size], shapes[1].as_list())
        self.assertEqual(
            [expected_batch_size, image_size, image_size,
             expected_num_channels], image_shape)

        if inference:
            return

        # Now extract the image Tensor from the dataset.
        batch_tensor = tf.data.make_one_shot_iterator(data).get_next()[0]

        with self.cached_session() as sess:
            for step in range(expected_num_batches + 1):
                if (step == expected_num_batches
                        and model_mode == enums.ModelMode.EVAL):
                    # EVAL must exhaust exactly; one more pull should raise.
                    with self.assertRaises(tf.errors.OutOfRangeError):
                        sess.run(batch_tensor)
                    break
                sess.run(batch_tensor)
  def test_image_to_multi_viewed_image_preprocessor(self, is_training,
                                                    decode_image, num_views):
    """Verifies multi-view preprocessing output dtype and stacked shape."""
    # tf.random.uniform() doesn't allow generating random values that are
    # uint8, so draw floats and cast down.
    image = tf.cast(
        tf.random.uniform((300, 400, 3), minval=0, maxval=255,
                          dtype=tf.float32),
        tf.uint8)
    if decode_image:
      image = tf.image.encode_jpeg(image)

    side = 32
    view_preprocessor = preprocessing.ImageToMultiViewedImagePreprocessor(
        is_training=is_training,
        preprocessing_options=hparams.ImagePreprocessing(
            image_size=side, num_views=num_views),
        dataset_options=preprocessing.DatasetOptions(decode_input=decode_image))
    views = view_preprocessor.preprocess(image)
    self.assertEqual(views.dtype, tf.float32)
    # Views are concatenated along channels: 3 channels per view.
    self.assertEqual([side, side, 3 * num_views], views.shape.as_list())