def _model_fn(features, labels, params, mode, config):
    """Builds the ``tf.estimator.EstimatorSpec`` for the requested ``mode``.

    In TRAIN mode the inputs are split with ``dispatcher.shard`` and the loss is
    built per shard through ``_loss_op``. EVAL and PREDICT build the model once
    inside the ``self.name`` variable scope.

    Raises:
      RuntimeError: if ``mode`` is not TRAIN, EVAL or PREDICT.
    """
    mode_keys = tf.estimator.ModeKeys

    if mode == mode_keys.TRAIN:
        self._register_word_counters(features, labels)
        sharded_features = dispatcher.shard(features)
        sharded_labels = dispatcher.shard(labels)
        with tf.variable_scope(self.name, initializer=self._initializer(params)):
            per_shard_losses = dispatcher(
                _loss_op, sharded_features, sharded_labels, params, mode, config)
        training_loss = _extract_loss(per_shard_losses)
        training_op = optimize(training_loss, params)
        return tf.estimator.EstimatorSpec(
            mode, loss=training_loss, train_op=training_op)

    if mode == mode_keys.EVAL:
        with tf.variable_scope(self.name):
            model_outputs, predictions = self._build(
                features, labels, params, mode, config=config)
            raw_loss = self._compute_loss(
                features, labels, model_outputs, params, mode)
        eval_loss = _extract_loss(raw_loss)
        metric_ops = self._compute_metrics(features, labels, predictions)
        if predictions is not None:
            # Publish the predictions in a graph collection so hooks can fetch them.
            add_dict_to_collection("predictions", predictions)
        return tf.estimator.EstimatorSpec(
            mode, loss=eval_loss, eval_metric_ops=metric_ops)

    if mode == mode_keys.PREDICT:
        with tf.variable_scope(self.name):
            _, predictions = self._build(
                features, labels, params, mode, config=config)
        serving_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
        export_outputs = {serving_key: tf.estimator.export.PredictOutput(predictions)}
        return tf.estimator.EstimatorSpec(
            mode, predictions=predictions, export_outputs=export_outputs)

    raise RuntimeError("Invalid mode")
def _model_fn(features, labels, params, mode, config):
    """Builds the ``tf.estimator.EstimatorSpec`` for the requested ``mode``.

    TRAIN: registers word counters and attaches a ``hooks.CountersHook`` that
    reports them. It then builds the loss per input shard via ``dispatcher`` and
    optimizes it, enabling mixed precision when ``self.dtype`` is ``tf.float16``.

    EVAL: builds the model, loss and metrics. Any hooks produced by
    ``eval_prediction_hooks_fn(predictions)`` are attached when both
    predictions and that callback are available.

    PREDICT: builds the model and exposes its predictions under the default
    serving signature.

    Raises:
      RuntimeError: if ``mode`` is not TRAIN, EVAL or PREDICT.
    """
    mode_keys = tf.estimator.ModeKeys

    if mode == mode_keys.TRAIN:
        word_counters = self._register_word_counters(features, labels)
        counters_hook = hooks.CountersHook(
            every_n_steps=config.save_summary_steps,
            output_dir=config.model_dir,
            counters=word_counters)
        sharded_features = dispatcher.shard(features)
        sharded_labels = dispatcher.shard(labels)
        with tf.variable_scope(self.name, initializer=self._initializer(params)):
            per_shard_losses = dispatcher(
                _loss_op, sharded_features, sharded_labels, params, mode, config)
        training_loss = _extract_loss(per_shard_losses)
        use_fp16 = self.dtype == tf.float16
        training_op = optimize(training_loss, params, mixed_precision=use_fp16)
        return tf.estimator.EstimatorSpec(
            mode,
            loss=training_loss,
            train_op=training_op,
            training_hooks=[counters_hook])

    if mode == mode_keys.EVAL:
        with tf.variable_scope(self.name):
            model_outputs, predictions = self._build(
                features, labels, params, mode, config=config)
            raw_loss = self._compute_loss(
                features, labels, model_outputs, params, mode)
        eval_loss = _extract_loss(raw_loss)
        metric_ops = self._compute_metrics(features, labels, predictions)
        wants_prediction_hooks = (
            predictions is not None and eval_prediction_hooks_fn is not None)
        eval_hooks = (
            list(eval_prediction_hooks_fn(predictions)) if wants_prediction_hooks else [])
        return tf.estimator.EstimatorSpec(
            mode,
            loss=eval_loss,
            eval_metric_ops=metric_ops,
            evaluation_hooks=eval_hooks)

    if mode == mode_keys.PREDICT:
        with tf.variable_scope(self.name):
            _, predictions = self._build(
                features, labels, params, mode, config=config)
        serving_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
        export_outputs = {serving_key: tf.estimator.export.PredictOutput(predictions)}
        return tf.estimator.EstimatorSpec(
            mode, predictions=predictions, export_outputs=export_outputs)

    raise RuntimeError("Invalid mode")
def __call__(self, features, labels, params, mode, config):
    """Creates the model and returns the matching ``tf.estimator.EstimatorSpec``.

    The model is built inside the ``self.name`` variable scope. Non-None
    predictions are registered in the "predictions" collection. In PREDICT mode
    the predictions are exported under the default serving signature. Otherwise
    the loss is computed, logged as a "loss" scalar summary, and either
    optimized (TRAIN) or paired with metrics (EVAL).

    See Also:
      ``tf.estimator.Estimator`` 's ``model_fn`` argument for more details about
      arguments and the returned value.
    """
    mode_keys = tf.estimator.ModeKeys
    is_training = mode == mode_keys.TRAIN

    if is_training:
        self._register_word_counters(features, labels)

    with tf.variable_scope(
            self.name, initializer=self._initializer(params)) as model_scope:
        model_outputs, predictions = self._build(
            features, labels, params, mode, config)

    if predictions is not None:
        # Publish the predictions in a graph collection so hooks can fetch them.
        add_dict_to_collection("predictions", predictions)

    if mode == mode_keys.PREDICT:
        serving_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
        return tf.estimator.EstimatorSpec(
            mode,
            predictions=predictions,
            export_outputs={serving_key: tf.estimator.export.PredictOutput(predictions)})

    # Re-enter the scope captured while building the model for the loss computation.
    with tf.variable_scope(model_scope):
        loss = self._compute_loss(features, labels, model_outputs, params, mode)

    # ``_compute_loss`` may return a (loss, display_loss) pair; otherwise the
    # same tensor is both optimized and displayed.
    loss, display_loss = loss if isinstance(loss, tuple) else (loss, loss)
    tf.summary.scalar("loss", display_loss)

    if is_training:
        training_op = optimize(loss, params)
        return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=training_op)

    metric_ops = self._compute_metrics(features, labels, predictions)
    return tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops=metric_ops)