def model_fn(features, labels, mode, params):
  """Contrastive model function.

  Builds a `ContrastiveTrainer` for the requested Estimator mode and wraps its
  outputs in a `tf.estimator.tpu.TPUEstimatorSpec`.

  Args:
    features: Model inputs, forwarded to the trainer as `model_inputs`.
    labels: Labels, forwarded to the trainer unchanged.
    mode: A `tf.estimator.ModeKeys` value. PREDICT and EVAL are handled
      explicitly; any other mode is treated as TRAIN.
    params: Dict supplied by the Estimator. Must contain 'hparams' (which must
      expose `bs`) and 'use_tpu'.

  Returns:
    A `tf.estimator.tpu.TPUEstimatorSpec` appropriate for `mode`.
  """
  trainer_mode = utils.estimator_mode_to_model_mode(mode)
  hp = params['hparams']
  trainer = ContrastiveTrainer(
      model_inputs=features,
      labels=labels,
      train_global_batch_size=hp.bs,
      hparams=hp,
      mode=trainer_mode,
      num_classes=inputs.get_num_classes(hp),
      training_set_size=inputs.get_num_train_images(hp),
      is_tpu=params['use_tpu'])

  if mode == tf.estimator.ModeKeys.PREDICT:
    signatures = trainer.signature_def_map()
    export_outputs = {}
    for signature_name, signature_outputs in signatures.items():
      export_outputs[signature_name] = tf.estimator.export.PredictOutput(
          signature_outputs)
    # Serving expects a default SignatureDef; alias it to 'contrastive_eval'
    # so the export API is satisfied.
    default_key = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    export_outputs[default_key] = (export_outputs['contrastive_eval'])
    return tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        predictions=signatures['contrastive_eval'],
        export_outputs=export_outputs)

  # Summaries for the relevant losses are written directly, so the Estimator
  # only receives a placeholder scalar loss to keep its API happy.
  placeholder_loss = tf.constant(0.)

  if mode == tf.estimator.ModeKeys.EVAL:
    return tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=placeholder_loss,
        eval_metrics=trainer.eval_metrics())

  # Every remaining mode is TRAIN. The three trainer calls below run in the
  # same order as the original keyword arguments were evaluated.
  train_op = trainer.train_op()
  scaffold_fn = trainer.scaffold_fn()
  host_call = trainer.host_call(FLAGS.model_dir)
  return tf.estimator.tpu.TPUEstimatorSpec(
      mode=mode,
      train_op=train_op,
      loss=placeholder_loss,
      scaffold_fn=scaffold_fn,
      host_call=host_call)
def test_get_num_classes(self, input_fn, expected_classes):
  """Checks `inputs.get_num_classes` against an expected class count.

  Args:
    input_fn: Passed to `make_params` to build the params dict. The concrete
      values are presumably supplied by a parameterized-test decorator outside
      this view.
    expected_classes: The class count `inputs.get_num_classes` should return
      for the resulting hparams.
  """
  hparams = make_params(input_fn)['hparams']
  num_classes = inputs.get_num_classes(hparams)
  self.assertEqual(num_classes, expected_classes)