Example #1
0
def imagenet(mode, params):
    """An input_fn for ImageNet (ILSVRC 2012) data."""
    run_mode = utils.estimator_mode_to_model_mode(mode)
    hparams = params['hparams']
    training = run_mode == enums.ModelMode.TRAIN

    # Build the multi-view preprocessor; TFDS records arrive already decoded.
    view_preprocessor = preprocessing.ImageToMultiViewedImagePreprocessor(
        is_training=training,
        preprocessing_options=hparams.input_data.preprocessing,
        dataset_options=preprocessing.DatasetOptions(decode_input=False),
        bfloat16_supported=params['use_tpu'])

    dataset = TfdsInput(
        dataset_name='imagenet2012:5.*.*',
        split='train' if training else 'validation',
        mode=run_mode,
        preprocessor=view_preprocessor,
        shuffle_buffer=1024,
        shard_per_host=hparams.input_data.shard_per_host,
        cache=training,
        num_parallel_calls=64,
        max_samples=hparams.input_data.max_samples,
        label_noise_prob=hparams.input_data.label_noise_prob,
        num_classes=get_num_classes(hparams),
        data_dir=params['data_dir'],
    )
    return dataset.input_fn(params)
Example #2
0
def cifar10(mode, params):
    """CIFAR10 dataset creator."""
    # Images are naturally 32x32.
    run_mode = utils.estimator_mode_to_model_mode(mode)
    hparams = params['hparams']
    training = run_mode == enums.ModelMode.TRAIN

    # Per-channel mean/std used to normalize CIFAR10 images; records arrive
    # already decoded from TFDS.
    channel_mean = np.array([[[-0.0172, -0.0356, -0.107]]])
    channel_std = np.array([[[0.4046, 0.3988, 0.402]]])
    view_preprocessor = preprocessing.ImageToMultiViewedImagePreprocessor(
        is_training=training,
        preprocessing_options=hparams.input_data.preprocessing,
        dataset_options=preprocessing.DatasetOptions(
            decode_input=False,
            image_mean_std=(channel_mean, channel_std)),
        bfloat16_supported=params['use_tpu'])

    dataset = TfdsInput(
        dataset_name='cifar10:3.*.*',
        split='train' if training else 'test',
        mode=run_mode,
        preprocessor=view_preprocessor,
        shard_per_host=hparams.input_data.shard_per_host,
        cache=training,
        shuffle_buffer=50000,
        num_parallel_calls=64,
        max_samples=hparams.input_data.max_samples,
        label_noise_prob=hparams.input_data.label_noise_prob,
        num_classes=get_num_classes(hparams),
        data_dir=params['data_dir'],
    )
    return dataset.input_fn(params)
Example #3
0
def model_fn(features, labels, mode, params):
    """Contrastive model function."""
    run_mode = utils.estimator_mode_to_model_mode(mode)
    hparams = params['hparams']

    trainer = ContrastiveTrainer(
        model_inputs=features,
        labels=labels,
        train_global_batch_size=hparams.bs,
        hparams=hparams,
        mode=run_mode,
        num_classes=inputs.get_num_classes(hparams),
        training_set_size=inputs.get_num_train_images(hparams),
        is_tpu=params['use_tpu'])

    if mode == tf.estimator.ModeKeys.PREDICT:
        signature_map = trainer.signature_def_map()
        export_outputs = {
            name: tf.estimator.export.PredictOutput(outputs)
            for name, outputs in signature_map.items()
        }
        # The SavedModel API requires a default serving signature; reuse the
        # contrastive-eval one.
        export_outputs[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = (
            export_outputs['contrastive_eval'])
        return tf.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            predictions=signature_map['contrastive_eval'],
            export_outputs=export_outputs)

    # The relevant losses are written to summaries directly, so the
    # Estimator-level loss is only a placeholder to satisfy the API.
    placeholder_loss = tf.constant(0.)

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=placeholder_loss,
            eval_metrics=trainer.eval_metrics())

    # Any remaining mode is TRAIN.
    return tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        train_op=trainer.train_op(),
        loss=placeholder_loss,
        scaffold_fn=trainer.scaffold_fn(),
        host_call=trainer.host_call(FLAGS.model_dir))