Example #1
0
def imagenet(mode, params):
    """An input_fn for ImageNet (ILSVRC 2012) data.

    Args:
      mode: Estimator mode; converted to a model mode with
        `utils.estimator_mode_to_model_mode`.
      params: Parameter dict. The 'hparams', 'use_tpu' and 'data_dir' entries
        are read here, and the whole dict is forwarded to
        `TfdsInput.input_fn`.

    Returns:
      The result of `TfdsInput.input_fn(params)` for the selected split.
    """
    model_mode = utils.estimator_mode_to_model_mode(mode)
    hp = params['hparams']
    training = model_mode == enums.ModelMode.TRAIN
    data_cfg = hp.input_data
    # Training reads the 'train' split; every other mode reads 'validation'.
    split = 'train' if training else 'validation'

    view_maker = preprocessing.ImageToMultiViewedImagePreprocessor(
        is_training=training,
        preprocessing_options=data_cfg.preprocessing,
        # The preprocessor is told not to decode its input.
        dataset_options=preprocessing.DatasetOptions(decode_input=False),
        bfloat16_supported=params['use_tpu'])

    pipeline = TfdsInput(
        dataset_name='imagenet2012:5.*.*',
        split=split,
        mode=model_mode,
        preprocessor=view_maker,
        shard_per_host=data_cfg.shard_per_host,
        shuffle_buffer=1024,
        num_parallel_calls=64,
        # Caching is enabled only when training.
        cache=training,
        max_samples=data_cfg.max_samples,
        label_noise_prob=data_cfg.label_noise_prob,
        num_classes=get_num_classes(hp),
        data_dir=params['data_dir'],
    )
    return pipeline.input_fn(params)
Example #2
0
def cifar10(mode, params):
    """CIFAR10 dataset creator.

    Args:
      mode: Estimator mode; converted to a model mode with
        `utils.estimator_mode_to_model_mode`.
      params: Parameter dict. The 'hparams', 'use_tpu' and 'data_dir' entries
        are read here, and the whole dict is forwarded to
        `TfdsInput.input_fn`.

    Returns:
      The result of `TfdsInput.input_fn(params)` for the selected split.
    """
    # Images are naturally 32x32.
    model_mode = utils.estimator_mode_to_model_mode(mode)
    hp = params['hparams']
    training = model_mode == enums.ModelMode.TRAIN
    data_cfg = hp.input_data
    # Training reads the 'train' split; every other mode reads 'test'.
    split = 'train' if training else 'test'

    aug_opts = data_cfg.preprocessing
    # Per-channel mean and std handed to DatasetOptions as `image_mean_std`;
    # each array has shape (1, 1, 3).
    channel_mean = np.array([[[-0.0172, -0.0356, -0.107]]])
    channel_std = np.array([[[0.4046, 0.3988, 0.402]]])
    dataset_opts = preprocessing.DatasetOptions(
        decode_input=False,
        image_mean_std=(channel_mean, channel_std))

    view_maker = preprocessing.ImageToMultiViewedImagePreprocessor(
        is_training=training,
        preprocessing_options=aug_opts,
        dataset_options=dataset_opts,
        bfloat16_supported=params['use_tpu'])

    pipeline = TfdsInput(
        dataset_name='cifar10:3.*.*',
        split=split,
        mode=model_mode,
        preprocessor=view_maker,
        shard_per_host=data_cfg.shard_per_host,
        # Caching is enabled only when training.
        cache=training,
        shuffle_buffer=50000,
        num_parallel_calls=64,
        max_samples=data_cfg.max_samples,
        label_noise_prob=data_cfg.label_noise_prob,
        num_classes=get_num_classes(hp),
        data_dir=params['data_dir'],
    )
    return pipeline.input_fn(params)
  def test_image_to_multi_viewed_image_preprocessor(self, is_training,
                                                    decode_image, num_views):
    """Checks dtype and shape of ImageToMultiViewedImagePreprocessor output.

    A random 300x400x3 uint8 image (optionally JPEG-encoded) is fed through
    the preprocessor. The output must be float32 with shape
    [image_size, image_size, 3 * num_views].

    Args:
      is_training: Forwarded to the preprocessor's `is_training`.
      decode_image: If True, the image is JPEG-encoded and the preprocessor is
        configured with `decode_input=True`.
      num_views: Number of views requested via `hparams.ImagePreprocessing`.
    """
    # tf.random.uniform() doesn't allow generating random values that are uint8,
    # so sample floats and cast. The float range is half-open [0, maxval) and
    # the cast truncates, so maxval=256 covers every uint8 value 0..255
    # (maxval=255 would never produce 255).
    image = tf.cast(
        tf.random.uniform(
            shape=(300, 400, 3), minval=0, maxval=256, dtype=tf.float32),
        dtype=tf.uint8)
    if decode_image:
      image = tf.image.encode_jpeg(image)

    image_size = 32
    preprocessor = preprocessing.ImageToMultiViewedImagePreprocessor(
        is_training=is_training,
        preprocessing_options=hparams.ImagePreprocessing(
            image_size=image_size, num_views=num_views),
        dataset_options=preprocessing.DatasetOptions(decode_input=decode_image))
    output = preprocessor.preprocess(image)
    self.assertEqual(output.dtype, tf.float32)
    # Last dimension holds 3 channels per requested view.
    self.assertEqual([image_size, image_size, 3 * num_views],
                     output.shape.as_list())