Esempio n. 1
0
  def _merge_eval(self, all_estimator_spec):
    """Combines the per-head EVAL `EstimatorSpec`s into a single spec.

    For every head, a mean metric of the head's loss is registered under the
    head's LOSS summary key, the head's own `eval_metric_ops` (if any) are
    merged in, and each prediction is re-keyed as `(head_name, key)`.

    Args:
      all_estimator_spec: list of `EstimatorSpec` for the individual heads,
        paired positionally with `self._heads`.

    Returns:
      `EstimatorSpec` for EVAL holding the re-keyed predictions, the merged
      metrics, and the loss produced by `_merge_losses` from the per-head
      losses and `self._head_weights`.
    """
    merged_predictions = {}
    merged_metrics = {}
    head_losses = []
    with ops.name_scope('merge_eval'):
      for head, head_spec in zip(self._heads, all_estimator_spec):
        head_losses.append(head_spec.loss)
        name = head.name
        # No loss metric is added by default, so register one per head.
        loss_metric_key = head_lib._summary_key(  # pylint:disable=protected-access
            name, metric_keys.MetricKeys.LOSS)
        merged_metrics[loss_metric_key] = metrics_lib.mean(
            head_spec.loss, name=loss_metric_key)
        # The head's own metric keys already contain head.name.
        head_metrics = head_spec.eval_metric_ops
        if head_metrics:
          merged_metrics.update(head_metrics)
        # Namespace every prediction by the head that produced it.
        merged_predictions.update(
            ((name, key), value)
            for key, value in six.iteritems(head_spec.predictions))
      merged_loss = _merge_losses(head_losses, self._head_weights)

    return model_fn.EstimatorSpec(
        mode=model_fn.ModeKeys.EVAL,
        predictions=merged_predictions,
        loss=merged_loss,
        eval_metric_ops=merged_metrics)
Esempio n. 2
0
 def create_loss(self, features, mode, logits=None, labels=None):
   """See `_Head`.

   Delegates loss construction to `self.state_manager.define_loss` and writes
   the resulting `loss` as a scalar summary under this head's LOSS summary
   key. `logits` and `labels` are accepted but not used in this method.

   Returns:
     The object returned by `define_loss`.
   """
   outputs = self.state_manager.define_loss(self.model, features, mode)
   loss_tag = head_lib._summary_key(  # pylint:disable=protected-access
       self._name, metric_keys.MetricKeys.LOSS)
   tf.compat.v1.summary.scalar(loss_tag, outputs.loss)
   return outputs
Esempio n. 3
0
 def _eval_metric_ops(self, labels, probabilities, weights, unreduced_loss,
                      regularization_loss):
     """Builds the dict of metrics used as `eval_metric_ops`.

     Every dict key is `head_lib._summary_key(self._name, <metric key>)`. The
     dict always holds the weighted mean loss, AUC and PR-AUC. It additionally
     holds the mean regularization loss when `regularization_loss` is not
     `None`, accuracy/precision/recall for each value in `self._thresholds`,
     and mean probability, AUC and PR-AUC for each class id in
     `self._classes_for_class_based_metrics`.

     Args:
       labels: Labels handed to the AUC and threshold metrics; sliced at
         `class_id` on the last axis for the class-based metrics.
       probabilities: Probabilities used as the metric predictions.
       weights: Weights forwarded to the weighted metrics.
       unreduced_loss: Loss values averaged (using `weights`) for LOSS_MEAN.
       regularization_loss: Regularization loss, or `None` to skip its metric.

     Returns:
       Dict mapping summary keys to the values returned by the metric
       functions.
     """

     def scoped_key(metric_key):
         # Builds the dict key from this head's name and the metric key.
         return head_lib._summary_key(self._name, metric_key)  # pylint:disable=protected-access

     scope_values = [
         labels, probabilities, weights, unreduced_loss, regularization_loss
     ]
     with ops.name_scope(None, 'metrics', scope_values):
         keys = metric_keys.MetricKeys

         # Estimator already adds a metric for loss, so only the mean loss is
         # reported here.
         loss_mean_key = scoped_key(keys.LOSS_MEAN)
         loss_mean = metrics_lib.mean(
             values=unreduced_loss, weights=weights, name=keys.LOSS_MEAN)
         overall_auc_key = scoped_key(keys.AUC)
         overall_auc = metrics_lib.auc(
             labels=labels, predictions=probabilities, weights=weights,
             name=keys.AUC)
         overall_auc_pr_key = scoped_key(keys.AUC_PR)
         overall_auc_pr = metrics_lib.auc(
             labels=labels, predictions=probabilities, weights=weights,
             curve='PR', name=keys.AUC_PR)
         metric_ops = {
             loss_mean_key: loss_mean,
             overall_auc_key: overall_auc,
             overall_auc_pr_key: overall_auc_pr,
         }

         if regularization_loss is not None:
             regularization_key = scoped_key(keys.LOSS_REGULARIZATION)
             metric_ops[regularization_key] = metrics_lib.mean(
                 values=regularization_loss, name=keys.LOSS_REGULARIZATION)

         for threshold in self._thresholds:
             # Arguments shared by the three per-threshold metrics.
             at_threshold = dict(
                 labels=labels,
                 predictions=probabilities,
                 weights=weights,
                 threshold=threshold)

             accuracy_key = keys.ACCURACY_AT_THRESHOLD % threshold
             accuracy = head_lib._accuracy_at_threshold(  # pylint:disable=protected-access
                 name=accuracy_key, **at_threshold)
             metric_ops[scoped_key(accuracy_key)] = accuracy

             # Precision for positive examples.
             precision_key = keys.PRECISION_AT_THRESHOLD % threshold
             precision = head_lib._precision_at_threshold(  # pylint:disable=protected-access
                 name=precision_key, **at_threshold)
             metric_ops[scoped_key(precision_key)] = precision

             # Recall for positive examples.
             recall_key = keys.RECALL_AT_THRESHOLD % threshold
             recall = head_lib._recall_at_threshold(  # pylint:disable=protected-access
                 name=recall_key, **at_threshold)
             metric_ops[scoped_key(recall_key)] = recall

         for class_id in self._classes_for_class_based_metrics:
             # Slice out index `class_id` on the last axis while keeping every
             # leading axis whole: `begin` is zeros followed by class_id and
             # `size` is -1s followed by 1.
             batch_rank = array_ops.rank(probabilities) - 1
             leading_begin = array_ops.zeros([batch_rank], dtype=dtypes.int32)
             begin = array_ops.concat([leading_begin, [class_id]], axis=0)
             leading_size = -1 * array_ops.ones([batch_rank],
                                                dtype=dtypes.int32)
             size = array_ops.concat([leading_size, [1]], axis=0)
             class_probabilities = array_ops.slice(
                 probabilities, begin=begin, size=size)
             class_labels = array_ops.slice(labels, begin=begin, size=size)

             # Keys use the class id, or the vocabulary entry when a label
             # vocabulary is set.
             prob_key = (
                 keys.PROBABILITY_MEAN_AT_CLASS % class_id
                 if self._label_vocabulary is None else
                 keys.PROBABILITY_MEAN_AT_NAME %
                 self._label_vocabulary[class_id])
             probability_mean = head_lib._predictions_mean(  # pylint:disable=protected-access
                 predictions=class_probabilities,
                 weights=weights,
                 name=prob_key)
             metric_ops[scoped_key(prob_key)] = probability_mean

             auc_key = (
                 keys.AUC_AT_CLASS % class_id
                 if self._label_vocabulary is None else
                 keys.AUC_AT_NAME % self._label_vocabulary[class_id])
             class_auc = head_lib._auc(  # pylint:disable=protected-access
                 labels=class_labels,
                 predictions=class_probabilities,
                 weights=weights,
                 name=auc_key)
             metric_ops[scoped_key(auc_key)] = class_auc

             auc_pr_key = (
                 keys.AUC_PR_AT_CLASS % class_id
                 if self._label_vocabulary is None else
                 keys.AUC_PR_AT_NAME % self._label_vocabulary[class_id])
             class_auc_pr = head_lib._auc(  # pylint:disable=protected-access
                 labels=class_labels,
                 predictions=class_probabilities,
                 weights=weights,
                 curve='PR',
                 name=auc_pr_key)
             metric_ops[scoped_key(auc_pr_key)] = class_auc_pr
     return metric_ops
Esempio n. 4
0
    def _create_tpu_estimator_spec(self,
                                   features,
                                   mode,
                                   logits,
                                   labels=None,
                                   optimizer=None,
                                   train_op_fn=None,
                                   regularization_losses=None):
        """Builds the `model_fn._TPUEstimatorSpec` for the given mode.

        Predictions are the logits and their sigmoid probabilities. PREDICT
        returns right after the export outputs are built, EVAL returns the
        (optionally regularized) loss together with the eval metrics, and any
        other mode is handled as TRAIN.

        Args:
          features: Input `dict` of `Tensor` or `SparseTensor` objects.
          mode: Estimator's `ModeKeys`.
          logits: logits `Tensor` with shape `[D0, D1, ... DN, n_classes]`.
            For many applications, the shape is `[batch_size, n_classes]`.
          labels: Labels with shape matching `logits`. Can be multi-hot
            `Tensor` with shape `[D0, D1, ... DN, n_classes]` or
            `SparseTensor` with `dense_shape` `[D0, D1, ... DN, ?]`. Required
            when `mode` equals `TRAIN` or `EVAL`.
          optimizer: `Optimizer` instance to optimize the loss in TRAIN mode.
            Namely, sets `train_op = optimizer.minimize(loss, global_step)`,
            which updates variables and increments `global_step`.
          train_op_fn: Function that takes a scalar loss `Tensor` and returns
            `train_op`. Used if `optimizer` is `None`.
          regularization_losses: A list of additional scalar losses to be
            added to the training loss, such as regularization losses. These
            losses are usually expressed as a batch average, so for best
            results users need to set `loss_reduction=SUM_OVER_BATCH_SIZE` or
            `loss_reduction=SUM_OVER_NONZERO_WEIGHTS` when creating the head
            to avoid scaling errors.

        Returns:
          `model_fn._TPUEstimatorSpec`.

        Raises:
          ValueError: If both `train_op_fn` and `optimizer` are `None` in
            TRAIN mode, or if both are set.
        """
        with ops.name_scope(self._name, 'head'):
            logits = head_lib._check_logits_final_dim(  # pylint:disable=protected-access
                logits, self.logits_dimension)

            # Predict.
            p_keys = prediction_keys.PredictionKeys
            with ops.name_scope(None, 'predictions', (logits,)):
                probabilities = math_ops.sigmoid(
                    logits, name=p_keys.PROBABILITIES)
                predictions = {
                    p_keys.LOGITS: logits,
                    p_keys.PROBABILITIES: probabilities,
                }
            if mode == model_fn.ModeKeys.PREDICT:
                classifier_output = head_lib._classification_output(  # pylint:disable=protected-access
                    scores=probabilities,
                    n_classes=self._n_classes,
                    label_vocabulary=self._label_vocabulary)
                # The classification output is exported under both the
                # default and the classify serving keys.
                serving_outputs = {
                    _DEFAULT_SERVING_KEY: classifier_output,
                    head_lib._CLASSIFY_SERVING_KEY: classifier_output,  # pylint:disable=protected-access
                    head_lib._PREDICT_SERVING_KEY:  # pylint:disable=protected-access
                        export_output.PredictOutput(predictions),
                }
                return model_fn._TPUEstimatorSpec(  # pylint:disable=protected-access
                    mode=model_fn.ModeKeys.PREDICT,
                    predictions=predictions,
                    export_outputs=serving_outputs)

            loss_parts = self.create_loss(
                features=features, mode=mode, logits=logits, labels=labels)
            (training_loss, unreduced_loss, weights,
             processed_labels) = loss_parts

            # Without regularization losses the training loss is used as-is.
            regularization_loss = None
            regularized_training_loss = training_loss
            if regularization_losses:
                regularization_loss = math_ops.add_n(regularization_losses)
                regularized_training_loss = math_ops.add_n(
                    [training_loss, regularization_loss])

            # Eval.
            if mode == model_fn.ModeKeys.EVAL:
                metric_kwargs = {
                    'labels': processed_labels,
                    'probabilities': probabilities,
                    'weights': weights,
                    'unreduced_loss': unreduced_loss,
                    'regularization_loss': regularization_loss,
                }
                eval_metrics = head_lib._create_eval_metrics_tuple(  # pylint:disable=protected-access
                    self._eval_metric_ops, metric_kwargs)
                return model_fn._TPUEstimatorSpec(  # pylint:disable=protected-access
                    mode=model_fn.ModeKeys.EVAL,
                    predictions=predictions,
                    loss=regularized_training_loss,
                    eval_metrics=eval_metrics)

            # Train: exactly one of `optimizer` / `train_op_fn` must be given.
            if optimizer is None and train_op_fn is None:
                raise ValueError(
                    'train_op_fn and optimizer cannot both be None.')
            if optimizer is not None and train_op_fn is not None:
                raise ValueError(
                    'train_op_fn and optimizer cannot both be set.')
            if optimizer is not None:
                train_op = optimizer.minimize(
                    regularized_training_loss,
                    global_step=training_util.get_global_step())
            else:
                train_op = train_op_fn(regularized_training_loss)
            train_op = head_lib._append_update_ops(train_op)  # pylint:disable=protected-access

            # Only summarize mean_loss for SUM reduction to preserve backwards
            # compatibility. Otherwise skip it to avoid unnecessary
            # computation.
            mean_loss = None
            if self._loss_reduction == losses.Reduction.SUM:
                broadcast_weights = weights * array_ops.ones_like(
                    unreduced_loss)
                example_weight_sum = math_ops.reduce_sum(broadcast_weights)
                mean_loss = training_loss / example_weight_sum

        # Summaries are written under an empty name scope (presumably so the
        # summary keys are not prefixed by the head scope -- confirm).
        with ops.name_scope(''):
            keys = metric_keys.MetricKeys
            summary.scalar(
                head_lib._summary_key(self._name, keys.LOSS),  # pylint:disable=protected-access
                regularized_training_loss)
            # These two are only summarized when they were computed above.
            optional_summaries = (
                (keys.LOSS_MEAN, mean_loss),
                (keys.LOSS_REGULARIZATION, regularization_loss),
            )
            for metric_key, value in optional_summaries:
                if value is not None:
                    summary.scalar(
                        head_lib._summary_key(self._name, metric_key),  # pylint:disable=protected-access
                        value)
        return model_fn._TPUEstimatorSpec(  # pylint:disable=protected-access
            mode=model_fn.ModeKeys.TRAIN,
            predictions=predictions,
            loss=regularized_training_loss,
            train_op=train_op)