示例#1
0
  def test_build_all_signature_defs_without_receiver_alternatives(self):
    """Each export output becomes one signature def under its own key.

    With no receiver alternatives, every expected signature is fed by the
    same string receiver tensor.
    """
    serving_input = array_ops.placeholder(dtypes.string)
    regress_value = constant_op.constant([1.])
    class_labels = constant_op.constant(["2"])
    predict_value = constant_op.constant(["3"])

    regression_head = export_output.RegressionOutput(value=regress_value)
    classification_head = export_output.ClassificationOutput(
        classes=class_labels)
    predict_head = export_output.PredictOutput(
        outputs={"some_output_3": predict_value})
    heads = {
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: regression_head,
        "head-2": classification_head,
        "head-3": predict_head,
    }

    actual = export.build_all_signature_defs(serving_input, heads)

    expected = {}
    expected["serving_default"] = signature_def_utils.regression_signature_def(
        serving_input, regress_value)
    expected["head-2"] = signature_def_utils.classification_signature_def(
        serving_input, class_labels, None)
    expected["head-3"] = signature_def_utils.predict_signature_def(
        {"input": serving_input}, {"some_output_3": predict_value})

    self.assertDictEqual(expected, actual)
示例#2
0
    def test_build_standardized_signature_def_regression(self):
        """RegressionOutput produces a regress-method SignatureDef.

        The string input appears under REGRESS_INPUTS and the float value
        under REGRESS_OUTPUTS, both with static shape [1].
        """
        feature_placeholder = array_ops.placeholder(
            dtypes.string, 1, name="input-tensor-1")
        receiver = {"input-1": feature_placeholder}
        prediction = array_ops.placeholder(
            dtypes.float32, 1, name="output-tensor-1")

        regression_output = export_output_lib.RegressionOutput(prediction)
        actual_signature_def = regression_output.as_signature_def(receiver)

        vector_of_one = tensor_shape_pb2.TensorShapeProto(
            dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)])

        def tensor_info(tensor_name, dtype_name):
            # Every TensorInfo in this signature shares the [1] shape.
            return meta_graph_pb2.TensorInfo(
                name=tensor_name,
                dtype=types_pb2.DataType.Value(dtype_name),
                tensor_shape=vector_of_one)

        expected_signature_def = meta_graph_pb2.SignatureDef(
            method_name=signature_constants.REGRESS_METHOD_NAME)
        expected_signature_def.inputs[
            signature_constants.REGRESS_INPUTS].CopyFrom(
                tensor_info("input-tensor-1:0", "DT_STRING"))
        expected_signature_def.outputs[
            signature_constants.REGRESS_OUTPUTS].CopyFrom(
                tensor_info("output-tensor-1:0", "DT_FLOAT"))

        self.assertEqual(actual_signature_def, expected_signature_def)
示例#3
0
 def test_regress_value_must_be_float(self):
     """RegressionOutput rejects a string-typed value with a ValueError."""
     value = constant_op.constant("1",
                                  dtype=dtypes.string,
                                  name="output-tensor-1")
     # assertRaisesRegexp was a deprecated alias removed in Python 3.12.
     with self.assertRaisesRegex(
             ValueError,
             "Regression output value must be a float32 Tensor;"):
         export_output_lib.RegressionOutput(value)
示例#4
0
 def test_regress_value_must_be_float(self):
     """RegressionOutput reports the offending string tensor in its error."""
     string_value = array_ops.placeholder(
         dtypes.string, 1, name="output-tensor-1")
     expected_message = (
         'Regression output value must be a float32 Tensor; got '
         'Tensor("output-tensor-1:0", shape=(1,), dtype=string)')
     with self.assertRaises(ValueError) as ctx:
         export_output_lib.RegressionOutput(string_value)
     self.assertEqual(expected_message, str(ctx.exception))
示例#5
0
  def test_build_all_signature_defs_with_single_alternatives(self):
    """Single-Tensor receiver alternatives only yield predict signatures.

    The alternatives are passed as bare Tensors, so each is wrapped under the
    default receiver tensor name "input". Only the predict head ("head-3")
    gets "<alternative>:head-3" variants; the comment at the end of the
    expected dict explains why the other heads do not.
    """
    receiver_tensor = array_ops.placeholder(dtypes.string)
    receiver_tensors_alternative_1 = array_ops.placeholder(dtypes.int64)
    receiver_tensors_alternative_2 = array_ops.sparse_placeholder(
        dtypes.float32)
    # Note we are passing single Tensors as values of
    # receiver_tensors_alternatives, where normally that is a dict.
    # In this case a dict will be created using the default receiver tensor
    # name "input".
    receiver_tensors_alternatives = {"other1": receiver_tensors_alternative_1,
                                     "other2": receiver_tensors_alternative_2}
    output_1 = constant_op.constant([1.])
    output_2 = constant_op.constant(["2"])
    output_3 = constant_op.constant(["3"])
    export_outputs = {
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
            export_output.RegressionOutput(value=output_1),
        "head-2": export_output.ClassificationOutput(classes=output_2),
        "head-3": export_output.PredictOutput(outputs={
            "some_output_3": output_3
        }),
    }

    signature_defs = export.build_all_signature_defs(
        receiver_tensor, export_outputs, receiver_tensors_alternatives)

    expected_signature_defs = {
        "serving_default":
            signature_def_utils.regression_signature_def(
                receiver_tensor,
                output_1),
        "head-2":
            signature_def_utils.classification_signature_def(
                receiver_tensor,
                output_2, None),
        "head-3":
            signature_def_utils.predict_signature_def(
                {"input": receiver_tensor},
                {"some_output_3": output_3}),
        "other1:head-3":
            signature_def_utils.predict_signature_def(
                {"input": receiver_tensors_alternative_1},
                {"some_output_3": output_3}),
        "other2:head-3":
            signature_def_utils.predict_signature_def(
                {"input": receiver_tensors_alternative_2},
                {"some_output_3": output_3})

        # Note that the alternatives 'other1:serving_default',
        # 'other1:head-2', 'other2:serving_default' and 'other2:head-2' are
        # invalid, because regression and classification signatures must take
        # a single string input.  Here we verify that these invalid signatures
        # are not included in the export.
    }

    self.assertDictEqual(expected_signature_defs, signature_defs)
示例#6
0
def _predict_spec(tower_specs, aggregation_device):
    """Populate replicated EstimatorSpec for `GraphKeys.PREDICT`.

    Merges the per-tower PREDICT specs into a single spec. All merged tensors
    are created on `aggregation_device`. Predictions are merged with
    `_concat_tensor_dicts`. PredictOutput, RegressionOutput and
    ClassificationOutput export outputs are rebuilt from the per-tower
    outputs, with tensors joined along axis 0.

    Args:
      tower_specs: Sequence of per-tower `EstimatorSpec`s. The first one
        supplies every field that is not rebuilt here.
      aggregation_device: Device on which the merged tensors are created.

    Returns:
      A single `EstimatorSpec` in PREDICT mode. Export outputs of any type
      other than the three listed above are not carried over.
    """
    estimator_spec = _asdict(tower_specs[0])
    estimator_spec['mode'] = model_fn_lib.ModeKeys.PREDICT

    def _concat_batches(tensors):
        # Towers split the batch along axis 0, so re-join along that axis.
        return array_ops.concat(list(tensors), axis=0)

    with ops_lib.device(aggregation_device):
        estimator_spec['predictions'] = _concat_tensor_dicts(
            *[tower_spec.predictions for tower_spec in tower_specs])

        export_outputs_dict = _dict_concat(
            *[tower_spec.export_outputs for tower_spec in tower_specs])

        export_outputs = {}
        for name, export_output_list in six.iteritems(export_outputs_dict):
            first = export_output_list[0]
            if isinstance(first, export_output_lib.PredictOutput):
                export_outputs[name] = export_output_lib.PredictOutput(
                    outputs=_concat_tensor_dicts(
                        *[out.outputs for out in export_output_list]))
            elif isinstance(first, export_output_lib.RegressionOutput):
                export_outputs[name] = export_output_lib.RegressionOutput(
                    value=_concat_batches(
                        out.value for out in export_output_list))
            elif isinstance(first, export_output_lib.ClassificationOutput):
                scores = None
                if first.scores is not None:
                    scores = _concat_batches(
                        out.scores for out in export_output_list)

                classes = None
                if first.classes is not None:
                    # Concatenate (not stack) so `classes` rows stay aligned
                    # with `scores` rows. Stacking would add a leading tower
                    # dimension and requires equal per-tower batch sizes.
                    classes = _concat_batches(
                        out.classes for out in export_output_list)

                export_outputs[name] = export_output_lib.ClassificationOutput(
                    scores=scores, classes=classes)
            # Any other ExportOutput type is not carried into the merged spec.

    estimator_spec['export_outputs'] = export_outputs
    return model_fn_lib.EstimatorSpec(**estimator_spec)
        def _model_fn(features, labels, mode, params):
            """Single-unit dense model_fn that checks the batch size it sees.

            Outside export mode, `params['batch_size']` must equal
            `batch_size_dict[mode]`. In export mode, the key must be absent.
            PREDICT returns early. Other modes get a spec with MSE loss, a
            cross-shard SGD train op and an absolute-error metric.
            """
            if self._export_mode:
                self.assertNotIn('batch_size', params)
            else:
                # Always check batch size in params
                self.assertEqual(batch_size_dict[mode], params['batch_size'])

            x = features['x']
            # Only a fully known static shape (e.g. TPU train) is validated.
            # CPU eval and predict may legitimately use a dynamic shape.
            if x.shape.is_fully_defined():
                self.assertEqual(batch_size_dict[mode], x.shape[0])

            predictions = layers.dense(
                x, 1, kernel_initializer=init_ops.ones_initializer())
            export_outputs = {
                'predictions': export_output.RegressionOutput(predictions)
            }
            prediction_dict = {'predictions': predictions}

            if mode == _PREDICT:
                return _create_estimator_spec(
                    mode=mode,
                    predictions=prediction_dict,
                    export_outputs=export_outputs)

            loss = losses.mean_squared_error(labels, predictions)

            base_optimizer = training.GradientDescentOptimizer(
                learning_rate=0.5)
            train_op = tf.tpu.CrossShardOptimizer(base_optimizer).minimize(
                loss, global_step=training.get_global_step())

            def metric_fn(labels, predictions):
                return {
                    'absolute_error':
                        metrics_lib.mean_absolute_error(labels, predictions)
                }

            return _create_estimator_spec(
                mode=mode,
                loss=loss,
                predictions=prediction_dict,
                export_outputs=export_outputs,
                train_op=train_op,
                eval_metrics=(metric_fn, [labels, predictions]))
示例#8
0
 def testExportOutputsSingleheadMissingDefault(self):
   """A lone export output is also registered under the default serving key."""
   with tf.Graph().as_default(), self.cached_session():
     predictions = {'loss': tf.constant(1.)}
     regression_output = export_output.RegressionOutput(
         value=tf.constant([1.]))
     spec = model_fn.EstimatorSpec(
         mode=ModeKeys.PREDICT,
         predictions=predictions,
         export_outputs={'head-1': regression_output})
     self.assertEqual(
         {
             tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                 regression_output,
             'head-1': regression_output,
         },
         spec.export_outputs)
示例#9
0
 def testExportOutputsMultiheadWithDefault(self):
   """Multiple heads with an explicit default are kept as given."""
   with ops.Graph().as_default(), self.cached_session():
     predictions = {'loss': constant_op.constant(1.)}
     regress_value, class_labels, predict_value = (
         constant_op.constant([1.]),
         constant_op.constant(['2']),
         constant_op.constant(['3']))

     heads = {}
     heads[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = (
         export_output.RegressionOutput(value=regress_value))
     heads['head-2'] = export_output.ClassificationOutput(classes=class_labels)
     heads['head-3'] = export_output.PredictOutput(
         outputs={'some_output_3': predict_value})

     spec = model_fn.EstimatorSpec(
         mode=model_fn.ModeKeys.PREDICT,
         predictions=predictions,
         export_outputs=heads)
     self.assertEqual(heads, spec.export_outputs)
示例#10
0
 def testExportOutputsMultiheadMissingDefault(self):
   """Several export outputs with no default key are rejected."""
   with ops.Graph().as_default(), self.cached_session():
     predictions = {'loss': constant_op.constant(1.)}
     output_1 = constant_op.constant([1.])
     output_2 = constant_op.constant(['2'])
     output_3 = constant_op.constant(['3'])
     export_outputs = {
         'head-1': export_output.RegressionOutput(value=output_1),
         'head-2': export_output.ClassificationOutput(classes=output_2),
         'head-3': export_output.PredictOutput(outputs={
             'some_output_3': output_3
         })}
     # assertRaisesRegexp was a deprecated alias removed in Python 3.12.
     with self.assertRaisesRegex(
         ValueError,
         'Multiple export_outputs were provided, but none of them is '
         'specified as the default.  Do this by naming one of them with '
         'signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.'):
       model_fn.EstimatorSpec(
           mode=model_fn.ModeKeys.PREDICT,
           predictions=predictions,
           export_outputs=export_outputs)
示例#11
0
 def testExportOutputsMultiheadMissingDefault(self):
     """Several export outputs with no default key are rejected."""
     with tf.Graph().as_default(), self.cached_session():
         predictions = {'loss': tf.constant(1.)}
         output_1 = tf.constant([1.])
         output_2 = tf.constant(['2'])
         output_3 = tf.constant(['3'])
         export_outputs = {
             'head-1':
             export_output.RegressionOutput(value=output_1),
             'head-2':
             export_output.ClassificationOutput(classes=output_2),
             'head-3':
             export_output.PredictOutput(
                 outputs={'some_output_3': output_3})
         }
         # assertRaisesRegexp was a deprecated alias removed in Python 3.12.
         # The regex tolerates optional backticks around `export_outputs`.
         with self.assertRaisesRegex(
                 ValueError,
                 'Multiple [`]*export_outputs[`]* were provided'):
             model_fn.EstimatorSpec(mode=ModeKeys.PREDICT,
                                    predictions=predictions,
                                    export_outputs=export_outputs)
    def _create_tpu_estimator_spec(self,
                                   features,
                                   mode,
                                   logits,
                                   labels=None,
                                   optimizer=None,
                                   trainable_variables=None,
                                   train_op_fn=None,
                                   update_ops=None,
                                   regularization_losses=None):
        """Returns a `model_fn._TPUEstimatorSpec` for this binary head.

        See superclass for the full argument description.

        PREDICT returns early with classification, regression (logistic) and
        predict export outputs. EVAL returns the mean of the regularized
        training loss together with an eval-metrics tuple. Any other mode
        falls through to a TRAIN spec with a train op and loss summaries.
        """

        with tf.compat.v1.name_scope(self._name, 'head'):
            # Predict.
            pred_keys = prediction_keys.PredictionKeys
            predictions = self.predictions(logits)
            if mode == ModeKeys.PREDICT:
                probabilities = predictions[pred_keys.PROBABILITIES]
                logistic = predictions[pred_keys.LOGISTIC]
                classifier_output = base_head.classification_output(
                    scores=probabilities,
                    n_classes=2,
                    label_vocabulary=self._label_vocabulary)
                # The classifier output also serves as the default serving
                # signature. The logistic value is exported as a regression.
                return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
                    mode=ModeKeys.PREDICT,
                    predictions=predictions,
                    export_outputs={
                        base_head.DEFAULT_SERVING_KEY: classifier_output,
                        base_head.CLASSIFY_SERVING_KEY: classifier_output,
                        base_head.REGRESS_SERVING_KEY:
                            export_output.RegressionOutput(value=logistic),
                        base_head.PREDICT_SERVING_KEY:
                            export_output.PredictOutput(predictions)
                    })
            regularized_training_loss = self.loss(
                logits=logits,
                labels=labels,
                features=features,
                mode=mode,
                regularization_losses=regularization_losses)
            # Scalar loss used for the spec's `loss` field and the summary.
            # The train op below is built from the unreduced value.
            scalar_loss = tf.reduce_mean(regularized_training_loss)
            # Eval.
            if mode == ModeKeys.EVAL:
                eval_metrics = self.metrics(
                    regularization_losses=regularization_losses)
                return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
                    mode=ModeKeys.EVAL,
                    predictions=predictions,
                    loss=scalar_loss,
                    eval_metrics=base_head.create_eval_metrics_tuple(
                        self.update_metrics, {
                            'eval_metrics': eval_metrics,
                            'features': features,
                            'logits': logits,
                            'labels': labels,
                            'regularization_losses': regularization_losses
                        }))
            # Train.
            train_op = base_head.create_estimator_spec_train_op(
                head_name=self._name,
                optimizer=optimizer,
                train_op_fn=train_op_fn,
                update_ops=update_ops,
                trainable_variables=trainable_variables,
                regularized_training_loss=regularized_training_loss,
                loss_reduction=self._loss_reduction)
        # Create summary.
        # NOTE(review): this runs after the head's name scope has closed,
        # presumably so summary names come only from `summary_key_fn` —
        # confirm before moving it inside the `with` block.
        base_head.create_estimator_spec_summary(
            regularized_training_loss=scalar_loss,
            regularization_losses=regularization_losses,
            summary_key_fn=self._summary_key)
        return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
            mode=ModeKeys.TRAIN,
            predictions=predictions,
            loss=scalar_loss,
            train_op=train_op)
示例#13
0
    def _create_tpu_estimator_spec(self,
                                   features,
                                   mode,
                                   logits,
                                   labels=None,
                                   optimizer=None,
                                   trainable_variables=None,
                                   train_op_fn=None,
                                   update_ops=None,
                                   regularization_losses=None):
        """Returns a `model_fn._TPUEstimatorSpec` for this regression head.

        PREDICT returns early with regression and predict export outputs. EVAL
        returns the regularized loss with an eval-metrics tuple. Any other
        mode falls through to a TRAIN spec with a train op and loss summaries.

        Args:
          features: Input `dict` mapping string feature names to `Tensor` or
            `SparseTensor` objects containing the values for that feature in a
            minibatch. Often to be used to fetch example-weight tensor.
          mode: Estimator's `ModeKeys`.
          logits: logits `Tensor` with shape `[D0, D1, ... DN, logits_dimension]`.
            For many applications, the shape is `[batch_size, logits_dimension]`.
          labels: Labels `Tensor` with shape matching `logits`, namely
            `[D0, D1, ... DN, logits_dimension]`. When `logits_dimension=1`, shape
            `[D0, D1, ... DN]` is also supported. `labels` is a required argument
            when `mode` equals `TRAIN` or `EVAL`.
          optimizer: A `tf.keras.optimizers.Optimizer` instance to optimize the
            loss in TRAIN mode. Namely, sets
            `train_op = optimizer.get_updates(loss, trainable_variables)`,
            which updates variables to minimize `loss`.
          trainable_variables: A list or tuple of `Variable` objects to update to
            minimize `loss`. In Tensorflow 1.x, by default these are the list of
            variables collected in the graph under the key
            `GraphKeys.TRAINABLE_VARIABLES`. As Tensorflow 2.x doesn't have
            collections and GraphKeys, trainable_variables need to be passed
            explicitly here.
          train_op_fn: Function that takes a scalar loss `Tensor` and returns
            `train_op`. Used if `optimizer` is `None`.
          update_ops: A list or tuple of update ops to be run at training time. For
            example, layers such as BatchNormalization create mean and variance
            update ops that need to be run at training time. In Tensorflow 1.x,
            these are thrown into an UPDATE_OPS collection. As Tensorflow 2.x
            doesn't have collections, update_ops need to be passed explicitly here.
          regularization_losses: A list of additional scalar losses to be added to
            the training loss, such as regularization losses. These losses are
            usually expressed as a batch average, so for best results users need to
            set `loss_reduction=SUM_OVER_BATCH_SIZE` when creating the head to
            avoid scaling errors.

        Returns:
          A `model_fn._TPUEstimatorSpec` instance.

        Raises:
          ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN
            mode, or if both are set.
        """
        with ops.name_scope(self._name, 'head'):
            # Predict.
            predictions = self.predictions(logits)
            if mode == ModeKeys.PREDICT:
                keys = prediction_keys.PredictionKeys
                # The same regression output backs both the default and the
                # regress serving signatures.
                regression_output = export_output.RegressionOutput(
                    value=predictions[keys.PREDICTIONS])
                return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
                    mode=ModeKeys.PREDICT,
                    predictions=predictions,
                    export_outputs={
                        base_head.DEFAULT_SERVING_KEY: regression_output,
                        base_head.REGRESS_SERVING_KEY: regression_output,
                        base_head.PREDICT_SERVING_KEY: export_output.PredictOutput(
                            predictions)
                    })
            # NOTE(review): this value is passed to the spec's `loss` without
            # an explicit reduce_mean; presumably `self.loss` already returns
            # a scalar for this head's loss_reduction — confirm.
            regularized_training_loss = self.loss(
                logits=logits,
                labels=labels,
                features=features,
                mode=mode,
                regularization_losses=regularization_losses)
            # Eval.
            if mode == ModeKeys.EVAL:
                eval_metrics = self.metrics(
                    regularization_losses=regularization_losses)
                return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
                    mode=ModeKeys.EVAL,
                    predictions=predictions,
                    loss=regularized_training_loss,
                    eval_metrics=base_head.create_eval_metrics_tuple(
                        self.update_metrics, {
                            'eval_metrics': eval_metrics,
                            'features': features,
                            'logits': logits,
                            'labels': labels,
                            'regularization_losses': regularization_losses
                        }))
            # Train.
            train_op = base_head.create_estimator_spec_train_op(
                head_name=self._name,
                optimizer=optimizer,
                train_op_fn=train_op_fn,
                update_ops=update_ops,
                trainable_variables=trainable_variables,
                regularized_training_loss=regularized_training_loss,
                loss_reduction=self._loss_reduction)
        # Create summary (outside the head's name scope).
        base_head.create_estimator_spec_summary(
            regularized_training_loss=regularized_training_loss,
            regularization_losses=regularization_losses,
            summary_key_fn=self._summary_key)
        return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
            mode=ModeKeys.TRAIN,
            predictions=predictions,
            loss=regularized_training_loss,
            train_op=train_op)
示例#14
0
    def _create_tpu_estimator_spec(self,
                                   features,
                                   mode,
                                   logits,
                                   labels=None,
                                   optimizer=None,
                                   train_op_fn=None,
                                   regularization_losses=None):
        """Returns a `model_fn._TPUEstimatorSpec` for this binary head.

        PREDICT returns early with classification, regression (logistic) and
        predict export outputs. EVAL returns the regularized loss with an
        eval-metrics tuple. Any other mode falls through to a TRAIN spec with
        a train op and loss summaries.

        Args:
          features: Input `dict` mapping string feature names to `Tensor` or
            `SparseTensor` objects containing the values for that feature in a
            minibatch. Often to be used to fetch example-weight tensor.
          mode: Estimator's `ModeKeys`.
          logits: logits `Tensor` with shape `[D0, D1, ... DN, 1]`. For many
            applications, the shape is `[batch_size, 1]`.
          labels: Labels integer or string `Tensor` with shape matching `logits`,
            namely `[D0, D1, ... DN, 1]` or `[D0, D1, ... DN]`. `labels` is a
            required argument when `mode` equals `TRAIN` or `EVAL`.
          optimizer: `Optimizer` instance to optimize the loss in TRAIN mode.
            Namely, sets `train_op = optimizer.minimize(loss, global_step)`, which
            updates variables and increments `global_step`.
          train_op_fn: Function that takes a scalar loss `Tensor` and returns
            `train_op`. Used if `optimizer` is `None`.
          regularization_losses: A list of additional scalar losses to be added to
            the training loss, such as regularization losses. These losses are
            usually expressed as a batch average, so for best results users need to
            set `loss_reduction=SUM_OVER_BATCH_SIZE` or
            `loss_reduction=SUM_OVER_NONZERO_WEIGHTS` when creating the head to
            avoid scaling errors.

        Returns:
          A `model_fn._TPUEstimatorSpec` instance.

        Raises:
          ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN
            mode, or if both are set.
        """
        with ops.name_scope(self._name, 'head'):
            # Predict.
            pred_keys = prediction_keys.PredictionKeys
            predictions = self.predictions(logits)
            if mode == model_fn.ModeKeys.PREDICT:
                probabilities = predictions[pred_keys.PROBABILITIES]
                logistic = predictions[pred_keys.LOGISTIC]
                classifier_output = base_head.classification_output(
                    scores=probabilities,
                    n_classes=2,
                    label_vocabulary=self._label_vocabulary)
                # The classifier output also serves as the default serving
                # signature. The logistic value is exported as a regression.
                return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
                    mode=model_fn.ModeKeys.PREDICT,
                    predictions=predictions,
                    export_outputs={
                        base_head.DEFAULT_SERVING_KEY: classifier_output,
                        base_head.CLASSIFY_SERVING_KEY: classifier_output,
                        base_head.REGRESS_SERVING_KEY:
                            export_output.RegressionOutput(value=logistic),
                        base_head.PREDICT_SERVING_KEY:
                            export_output.PredictOutput(predictions)
                    })
            regularized_training_loss = self.loss(
                logits=logits,
                labels=labels,
                features=features,
                mode=mode,
                regularization_losses=regularization_losses)
            # Eval.
            if mode == model_fn.ModeKeys.EVAL:
                eval_metrics = self.metrics(
                    regularization_losses=regularization_losses)
                return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
                    mode=model_fn.ModeKeys.EVAL,
                    predictions=predictions,
                    loss=regularized_training_loss,
                    eval_metrics=base_head.create_eval_metrics_tuple(
                        self.update_metrics, {
                            'eval_metrics': eval_metrics,
                            'features': features,
                            'logits': logits,
                            'labels': labels,
                            'regularization_losses': regularization_losses
                        }))
            # Train.
            # NOTE(review): arguments are passed positionally; confirm they
            # match the parameter order of
            # `base_head.create_estimator_spec_train_op` in the base_head
            # version this head targets.
            train_op = base_head.create_estimator_spec_train_op(
                self._name, optimizer, train_op_fn, regularized_training_loss)
        # Create summary (outside the head's name scope).
        # NOTE(review): positional call as well — confirm the parameter order
        # of `base_head.create_estimator_spec_summary`.
        base_head.create_estimator_spec_summary(self._summary_key,
                                                regularized_training_loss,
                                                regularization_losses)
        return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
            mode=model_fn.ModeKeys.TRAIN,
            predictions=predictions,
            loss=regularized_training_loss,
            train_op=train_op)
示例#15
0
    def _create_tpu_estimator_spec(self,
                                   features,
                                   mode,
                                   logits,
                                   labels=None,
                                   optimizer=None,
                                   train_op_fn=None,
                                   regularization_losses=None):
        """Returns an `EstimatorSpec`.

    Args:
      features: Input `dict` of `Tensor` or `SparseTensor` objects.
      mode: Estimator's `ModeKeys`.
      logits: logits `Tensor` with shape `[D0, D1, ... DN, 1]`. For many
        applications, the shape is `[batch_size, 1]`.
      labels: Labels integer or string `Tensor` with shape matching `logits`,
        namely `[D0, D1, ... DN, 1]` or `[D0, D1, ... DN]`. `labels` is required
        argument when `mode` equals `TRAIN` or `EVAL`.
      optimizer: `Optimizer` instance to optimize the loss in TRAIN mode.
        Namely, sets `train_op = optimizer.minimize(loss, global_step)`, which
        updates variables and increments `global_step`.
      train_op_fn: Function that takes a scalar loss `Tensor` and returns
        `train_op`. Used if `optimizer` is `None`.
      regularization_losses: A list of additional scalar losses to be added to
        the training loss, such as regularization losses. These losses are
        usually expressed as a batch average, so for best results users need to
        set `loss_reduction=SUM_OVER_BATCH_SIZE` when creating the head to avoid
        scaling errors.

    Returns:
      `EstimatorSpec`.
    Raises:
      ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN
        mode, or if both are set.
    """
        # Predict.
        with tf.compat.v1.name_scope(self._name, 'head'):
            with tf.compat.v1.name_scope(None, 'predictions', (logits, )):
                pred_keys = prediction_keys.PredictionKeys
                logits = _check_logits_final_dim(logits, self.logits_dimension)
                logistic = tf.math.sigmoid(logits, name=pred_keys.LOGISTIC)
                two_class_logits = tf.concat(
                    (tf.compat.v1.zeros_like(logits), logits),
                    axis=-1,
                    name='two_class_logits')
                probabilities = tf.compat.v1.nn.softmax(
                    two_class_logits, name=pred_keys.PROBABILITIES)
                class_ids = tf.compat.v1.math.argmax(two_class_logits,
                                                     axis=-1,
                                                     name=pred_keys.CLASS_IDS)
                class_ids = tf.compat.v1.expand_dims(class_ids, axis=-1)
                all_class_ids = _all_class_ids(logits, n_classes=2)
                all_classes = _all_classes(
                    logits,
                    n_classes=2,
                    label_vocabulary=self._label_vocabulary)

                if self._label_vocabulary:
                    table = lookup_ops.index_to_string_table_from_tensor(
                        vocabulary_list=self._label_vocabulary,
                        name='class_string_lookup')
                    classes = table.lookup(class_ids)
                else:
                    classes = tf.strings.as_string(class_ids,
                                                   name='str_classes')
                predictions = {
                    pred_keys.LOGITS: logits,
                    pred_keys.LOGISTIC: logistic,
                    pred_keys.PROBABILITIES: probabilities,
                    pred_keys.CLASS_IDS: class_ids,
                    pred_keys.CLASSES: classes,
                    pred_keys.ALL_CLASS_IDS: all_class_ids,
                    pred_keys.ALL_CLASSES: all_classes,
                }
            if mode == ModeKeys.PREDICT:
                classifier_output = _classification_output(
                    scores=probabilities,
                    n_classes=2,
                    label_vocabulary=self._label_vocabulary)
                return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
                    mode=ModeKeys.PREDICT,
                    predictions=predictions,
                    export_outputs={
                        _DEFAULT_SERVING_KEY: classifier_output,
                        _CLASSIFY_SERVING_KEY: classifier_output,
                        _REGRESS_SERVING_KEY: export_output.RegressionOutput(
                            value=logistic),
                        _PREDICT_SERVING_KEY: export_output.PredictOutput(predictions)
                    })

            (training_loss, unreduced_loss, weights,
             processed_labels) = (self.create_loss(features=features,
                                                   mode=mode,
                                                   logits=logits,
                                                   labels=labels))
            if regularization_losses:
                regularization_loss = tf.math.add_n(regularization_losses)
                regularized_training_loss = tf.math.add_n(
                    [training_loss, regularization_loss])
            else:
                regularization_loss = None
                regularized_training_loss = training_loss

            if self._loss_reduction == tf.compat.v1.losses.Reduction.NONE:
                scalar_loss = tf.reduce_mean(regularized_training_loss)
            else:
                scalar_loss = regularized_training_loss
            # Eval.
            if mode == ModeKeys.EVAL:
                return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
                    mode=ModeKeys.EVAL,
                    predictions=predictions,
                    loss=scalar_loss,
                    eval_metrics=_create_eval_metrics_tuple(
                        self._eval_metric_ops, {
                            'labels': processed_labels,
                            'logits': logits,
                            'logistic': logistic,
                            'class_ids': class_ids,
                            'weights': weights,
                            'unreduced_loss': unreduced_loss,
                            'regularization_loss': regularization_loss
                        }))

            # Train.
            if optimizer is not None:
                if train_op_fn is not None:
                    raise ValueError(
                        'train_op_fn and optimizer cannot both be set.')
                train_op = optimizer.minimize(
                    regularized_training_loss,
                    global_step=tf.compat.v1.train.get_global_step())
            elif train_op_fn is not None:
                train_op = train_op_fn(regularized_training_loss)
            else:
                raise ValueError(
                    'train_op_fn and optimizer cannot both be None.')
            train_op = _append_update_ops(train_op)
            # Only summarize mean_loss for SUM reduction to preserve backwards
            # compatibility. Otherwise skip it to avoid unnecessary computation.
            if self._loss_reduction == tf.compat.v1.losses.Reduction.SUM:
                example_weight_sum = tf.math.reduce_sum(
                    weights * tf.compat.v1.ones_like(unreduced_loss))
                mean_loss = training_loss / example_weight_sum
            else:
                mean_loss = None
        with tf.compat.v1.name_scope(''):
            keys = metric_keys.MetricKeys
            tf.compat.v1.summary.scalar(_summary_key(self._name, keys.LOSS),
                                        scalar_loss)
            if mean_loss is not None:
                tf.compat.v1.summary.scalar(
                    _summary_key(self._name, keys.LOSS_MEAN), mean_loss)
            if regularization_loss is not None:
                tf.compat.v1.summary.scalar(
                    _summary_key(self._name, keys.LOSS_REGULARIZATION),
                    regularization_loss)
        return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access
            mode=ModeKeys.TRAIN,
            predictions=predictions,
            loss=scalar_loss,
            train_op=train_op)