Example #1
import tensorflow.compat.v1 as tf  # assumed: TF1-style Estimator APIs

# `utils`, `inputs`, `ContrastiveTrainer`, and `FLAGS` are project-local
# helpers from the surrounding codebase, assumed to be in scope here.


def model_fn(features, labels, mode, params):
    """Contrastive model function."""

    model_mode = utils.estimator_mode_to_model_mode(mode)
    hparams = params['hparams']

    trainer = ContrastiveTrainer(
        model_inputs=features,
        labels=labels,
        train_global_batch_size=hparams.bs,
        hparams=hparams,
        mode=model_mode,
        num_classes=inputs.get_num_classes(hparams),
        training_set_size=inputs.get_num_train_images(hparams),
        is_tpu=params['use_tpu'])

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions_map = trainer.signature_def_map()
        exports = {
            k: tf.estimator.export.PredictOutput(v)
            for k, v in predictions_map.items()
        }
        # Export a default SignatureDef to keep the API happy.
        exports[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = (
            exports['contrastive_eval'])
        spec = tf.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            predictions=predictions_map['contrastive_eval'],
            export_outputs=exports)
        return spec

    # We directly write summaries for the relevant losses, so just hard-code a
    # dummy value to keep the Estimator API happy.
    loss = tf.constant(0.)

    if mode == tf.estimator.ModeKeys.EVAL:
        spec = tf.estimator.tpu.TPUEstimatorSpec(
            mode=mode, loss=loss, eval_metrics=trainer.eval_metrics())
        return spec
    else:  # TRAIN
        spec = tf.estimator.tpu.TPUEstimatorSpec(
            mode=mode,
            train_op=trainer.train_op(),
            loss=loss,
            scaffold_fn=trainer.scaffold_fn(),
            host_call=trainer.host_call(FLAGS.model_dir))
        return spec
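
For context, here is a minimal sketch of how a model_fn like this is typically handed to tf.estimator.tpu.TPUEstimator. The make_estimator helper, its arguments, and the iterations_per_loop value are illustrative assumptions rather than part of the source; only the params keys ('hparams', 'use_tpu') and the global batch size hparams.bs follow from the snippet above. It reuses the tf import and model_fn from Example #1.

def make_estimator(hparams, model_dir, use_tpu=False):
    # Hypothetical helper: wires the model_fn above into a TPUEstimator.
    run_config = tf.estimator.tpu.RunConfig(
        model_dir=model_dir,
        tpu_config=tf.estimator.tpu.TPUConfig(iterations_per_loop=100))
    return tf.estimator.tpu.TPUEstimator(
        model_fn=model_fn,
        config=run_config,
        use_tpu=use_tpu,
        # TPUEstimator takes global batch sizes and shards them per core.
        train_batch_size=hparams.bs,
        eval_batch_size=hparams.bs,
        params={'hparams': hparams, 'use_tpu': use_tpu})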
Example #2
    # Parameterized case: `input_fn` supplies the input configuration under
    # test and `expected_classes` is the class count it should report.
    def test_get_num_classes(self, input_fn, expected_classes):
        params = make_params(input_fn)

        self.assertEqual(inputs.get_num_classes(params['hparams']),
                         expected_classes)
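
Example #2 is a method fragment from a larger parameterized test class. The following self-contained sketch reconstructs a surrounding harness purely for illustration, assuming absl's parameterized framework: Hparams, make_params, and the fake inputs object are hypothetical stand-ins, and only the test method body mirrors the snippet above.

# Self-contained harness sketch. Everything except the test method body is
# a hypothetical stand-in for the real helpers; it shows the
# parameterized-test shape, not the project's actual code.
import collections

from absl.testing import absltest
from absl.testing import parameterized

Hparams = collections.namedtuple('Hparams', ['num_classes'])


def make_params(input_fn):
    # Package hparams the way the code under test expects them.
    return {'hparams': input_fn()}


class _FakeInputs:
    """Stand-in for the project's `inputs` module."""

    @staticmethod
    def get_num_classes(hparams):
        return hparams.num_classes


inputs = _FakeInputs()


class InputsTest(parameterized.TestCase):

    @parameterized.named_parameters(
        ('ten_classes', lambda: Hparams(num_classes=10), 10),
        ('thousand_classes', lambda: Hparams(num_classes=1000), 1000),
    )
    def test_get_num_classes(self, input_fn, expected_classes):
        params = make_params(input_fn)
        self.assertEqual(inputs.get_num_classes(params['hparams']),
                         expected_classes)


if __name__ == '__main__':
    absltest.main()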