def _create_estimator_spec(self, current_iteration, mode,
                           iteration_number_tensor, previous_iteration_vars):
  """Builds the spec for the current iteration. See `Estimator` for details.

  When TPU execution is disabled, this defers entirely to the parent
  class. Otherwise it repackages the current iteration's estimator spec as a
  `TPUEstimatorSpec`.

  Args:
    current_iteration: The iteration being built. Its `estimator_spec`
      supplies the predictions, loss, eval metrics and export outputs.
    mode: The estimator mode; compared against `tf.estimator.ModeKeys.TRAIN`.
    iteration_number_tensor: Only forwarded to `self._training_hooks` (and
      to the parent implementation off TPU).
    previous_iteration_vars: Only forwarded to `self._training_hooks` (and
      to the parent implementation off TPU).

  Returns:
    The parent's spec when `self._use_tpu` is false, else a
    `tf_compat.TPUEstimatorSpec`.
  """
  if not self._use_tpu:
    base = super(TPUEstimator, self)
    return base._create_estimator_spec(current_iteration, mode,
                                       iteration_number_tensor,
                                       previous_iteration_vars)

  is_training = mode == tf.estimator.ModeKeys.TRAIN
  spec = current_iteration.estimator_spec

  # The helpers are invoked in the same order as before, in case any of them
  # adds ops to the graph.
  train_op = self._train_op(spec)
  host_call = self._create_host_call(current_iteration, is_training)
  raw_training_hooks = self._training_hooks(current_iteration, is_training,
                                            iteration_number_tensor,
                                            previous_iteration_vars)
  training_hooks = self._decorate_hooks(raw_training_hooks)
  evaluation_hooks = self._evaluation_hooks(current_iteration, is_training)

  def _constant_summary_scaffold():
    # Return a constant summary_op, otherwise `Estimator` creates summary
    # ops that do not work on TPU.
    return tf.train.Scaffold(summary_op=tf.constant(""))

  return tf_compat.TPUEstimatorSpec(
      mode=mode,
      predictions=spec.predictions,
      loss=spec.loss,
      train_op=train_op,
      host_call=host_call,
      eval_metrics=spec.eval_metrics,
      export_outputs=spec.export_outputs,
      scaffold_fn=_constant_summary_scaffold,
      training_hooks=training_hooks,
      evaluation_hooks=evaluation_hooks)
def setUp(self):
  """Builds the two-head EVAL `TPUEstimatorSpec` shared by the metric tests.

  Populates `self._features`, `self._labels` and `self._estimator_spec`.
  The spec's `eval_metrics` pairs `self._spec_metric_fn` with the features,
  per-head labels, predictions and loss created here.
  """
  super(MetricsTest, self).setUp()
  # Only the multi-head layout is exercised, since it is the general case.
  self._features = {"x": tf.constant([[1.], [2.]])}
  head_names = ("head_1", "head_2")
  shared_labels = tf.constant([0, 1])
  # Every head gets the same labels tensor. That tensor also serves as the
  # head's predictions, keyed by `(head, "predictions")`.
  self._labels = dict.fromkeys(head_names, shared_labels)
  head_predictions = {
      (name, "predictions"): shared_labels for name in head_names
  }
  constant_loss = tf.constant(2.)
  metric_fn_kwargs = {
      "features": self._features,
      "labels": self._labels,
      "predictions": head_predictions,
      "loss": constant_loss,
  }
  self._estimator_spec = tf_compat.TPUEstimatorSpec(
      mode=tf.estimator.ModeKeys.EVAL,
      loss=constant_loss,
      predictions=head_predictions,
      eval_metrics=(self._spec_metric_fn, metric_fn_kwargs))