Пример #1
0
    def _create_estimator_spec(self, current_iteration, mode,
                               iteration_number_tensor,
                               previous_iteration_vars):
        """Builds the estimator spec for `current_iteration`.

        When `self._use_tpu` is false the call is forwarded unchanged to the
        parent class. Otherwise a `tf_compat.TPUEstimatorSpec` is assembled
        from the fields of `current_iteration.estimator_spec`.

        Args:
          current_iteration: Object exposing an `estimator_spec` attribute;
            also handed to the host-call and hook builders.
          mode: Compared against `tf.estimator.ModeKeys.TRAIN` to decide
            whether this is a training run.
          iteration_number_tensor: Forwarded to the parent implementation or
            to `self._training_hooks`.
          previous_iteration_vars: Forwarded to the parent implementation or
            to `self._training_hooks`.

        Returns:
          The parent's result when not using TPU, otherwise a
          `tf_compat.TPUEstimatorSpec`.
        """
        if not self._use_tpu:
            parent = super(TPUEstimator, self)
            return parent._create_estimator_spec(
                current_iteration, mode, iteration_number_tensor,
                previous_iteration_vars)

        is_training = mode == tf.estimator.ModeKeys.TRAIN
        spec = current_iteration.estimator_spec

        def _constant_summary_scaffold():
            # A constant summary_op keeps `Estimator` from creating summary
            # ops that do not work on TPU.
            return tf.train.Scaffold(summary_op=tf.constant(""))

        # Built up front, in the same order the spec arguments are listed.
        train_op = self._train_op(spec)
        host_call = self._create_host_call(current_iteration, is_training)
        undecorated_hooks = self._training_hooks(
            current_iteration, is_training, iteration_number_tensor,
            previous_iteration_vars)
        training_hooks = self._decorate_hooks(undecorated_hooks)
        evaluation_hooks = self._evaluation_hooks(current_iteration,
                                                  is_training)

        return tf_compat.TPUEstimatorSpec(
            mode=mode,
            predictions=spec.predictions,
            loss=spec.loss,
            train_op=train_op,
            host_call=host_call,
            eval_metrics=spec.eval_metrics,
            export_outputs=spec.export_outputs,
            scaffold_fn=_constant_summary_scaffold,
            training_hooks=training_hooks,
            evaluation_hooks=evaluation_hooks)
Пример #2
0
  def setUp(self):
    """Creates the features, labels and EVAL-mode spec stored on the test.

    Two heads ("head_1", "head_2") share one labels tensor, which is also
    reused as each head's prediction.
    """
    super(MetricsTest, self).setUp()

    # Only the multi-head setup is exercised, since it is the general case.
    head_names = ("head_1", "head_2")
    self._features = {"x": tf.constant([[1.], [2.]])}
    shared_labels = tf.constant([0, 1])
    self._labels = dict.fromkeys(head_names, shared_labels)
    head_predictions = dict.fromkeys(
        [(name, "predictions") for name in head_names], shared_labels)
    loss_value = tf.constant(2.)

    # Keyword arguments paired with the metric function in `eval_metrics`.
    metric_fn_kwargs = {
        "features": self._features,
        "labels": self._labels,
        "predictions": head_predictions,
        "loss": loss_value,
    }
    self._estimator_spec = tf_compat.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.EVAL,
        loss=loss_value,
        predictions=head_predictions,
        eval_metrics=(self._spec_metric_fn, metric_fn_kwargs))