예제 #1
0
  def create_estimator_spec(
      self, features, mode, logits, labels=None, train_op_fn=None):
    """Builds the `EstimatorSpec` for the requested `mode`.

    Sigmoid probabilities are computed from `logits` under a 'head' name
    scope. PREDICT returns only predictions and a classification export
    output. For the other modes the loss from `self.create_loss` is averaged
    over the last (class) axis, weighted, and summed into the training loss.
    EVAL attaches metric ops. TRAIN writes loss summaries and builds the
    train op with `train_op_fn`.

    Args:
      features: Feature dict; forwarded to `create_loss` and used to look
        up example weights.
      mode: A `model_fn.ModeKeys` value.
      logits: Logits tensor, validated against `self.logits_dimension`.
      labels: Labels forwarded to `create_loss`; unused in PREDICT mode.
      train_op_fn: Callable mapping the training loss to a train op.
        Required when `mode` is neither PREDICT nor EVAL.

    Returns:
      A `model_fn.EstimatorSpec`.

    Raises:
      ValueError: If `train_op_fn` is None and `mode` is neither PREDICT
        nor EVAL.
    """
    mode_keys = model_fn.ModeKeys
    keys = prediction_keys.PredictionKeys
    with ops.name_scope('head'):
      checked_logits = head_lib._check_logits(logits, self.logits_dimension)  # pylint:disable=protected-access

      # Predictions: an independent sigmoid probability per logit.
      with ops.name_scope(None, 'predictions', (checked_logits,)):
        probs = math_ops.sigmoid(checked_logits, name=keys.PROBABILITIES)
        preds = {keys.LOGITS: checked_logits, keys.PROBABILITIES: probs}
      if mode == mode_keys.PREDICT:
        serving_output = export_output.ClassificationOutput(scores=probs)
        return model_fn.EstimatorSpec(
            mode=mode_keys.PREDICT,
            predictions=preds,
            export_outputs={'': serving_output})

      # Loss shared by EVAL and TRAIN.
      loss_per_class, processed_labels = self.create_loss(
          features=features, mode=mode, logits=checked_logits, labels=labels)
      # Collapse the class axis so each example contributes one loss value.
      example_loss = math_ops.reduce_mean(
          loss_per_class, axis=-1, keep_dims=True)
      example_weights = head_lib._weights(features, self._weight_column)  # pylint:disable=protected-access
      total_loss = losses.compute_weighted_loss(
          example_loss, weights=example_weights,
          reduction=losses.Reduction.SUM)

      if mode == mode_keys.EVAL:
        metric_ops = self._eval_metric_ops(
            labels=processed_labels,
            probabilities=probs,
            weights=example_weights,
            per_example_loss=example_loss)
        return model_fn.EstimatorSpec(
            mode=mode_keys.EVAL,
            predictions=preds,
            loss=total_loss,
            eval_metric_ops=metric_ops)

      # Any remaining mode is treated as TRAIN and needs a train op builder.
      if train_op_fn is None:
        raise ValueError('train_op_fn can not be None.')
    # NOTE(review): presumably resets to the root scope so the summary tags
    # come only from _summary_key and are not prefixed by 'head' — confirm.
    with ops.name_scope(''):
      summary.scalar(
          head_lib._summary_key(self._name, metric_keys.MetricKeys.LOSS),  # pylint:disable=protected-access
          total_loss)
      mean_loss = losses.compute_weighted_loss(
          loss_per_class, weights=example_weights,
          reduction=losses.Reduction.MEAN)
      summary.scalar(
          head_lib._summary_key(  # pylint:disable=protected-access
              self._name, metric_keys.MetricKeys.LOSS_MEAN),
          mean_loss)
    return model_fn.EstimatorSpec(
        mode=mode_keys.TRAIN,
        predictions=preds,
        loss=total_loss,
        train_op=train_op_fn(total_loss))
예제 #2
0
 def create_loss(self, features, mode, logits, labels):
     """Computes the summed, weighted loss for this head.

     Without a custom loss function, the loss is element-wise sigmoid cross
     entropy averaged over the last (class) axis. Otherwise
     `self._loss_fn` is invoked through `_call_loss_fn`.

     Args:
         features: Feature dict; used to look up example weights and
             forwarded to the custom loss function, if any.
         mode: Ignored by this head.
         logits: Logits tensor.
         labels: Raw labels, converted with `self._process_labels`.

     Returns:
         A `head_lib.LossSpec` holding `weighted_sum_loss`,
         `example_weight_sum` and `processed_labels`.
     """
     del mode  # Unused for this head.
     processed_labels = self._process_labels(labels)
     if not self._loss_fn:
         per_class_loss = losses.sigmoid_cross_entropy(
             multi_class_labels=processed_labels,
             logits=logits,
             reduction=losses.Reduction.NONE)
         # Mean over classes leaves one loss value per example.
         example_loss = math_ops.reduce_mean(
             per_class_loss, axis=-1, keep_dims=True)
     else:
         example_loss = _call_loss_fn(
             loss_fn=self._loss_fn,
             labels=processed_labels,
             logits=logits,
             features=features)

     weight_values = head_lib._weights(features, self._weight_column)  # pylint:disable=protected-access
     summed_loss = losses.compute_weighted_loss(
         example_loss, weights=weight_values, reduction=losses.Reduction.SUM)
     # _weights() can return the scalar 1; multiplying by ones_like broadcasts
     # it to the loss shape so the sum counts every example.
     weight_total = math_ops.reduce_sum(
         weight_values * array_ops.ones_like(example_loss))
     return head_lib.LossSpec(
         weighted_sum_loss=summed_loss,
         example_weight_sum=weight_total,
         processed_labels=processed_labels)
예제 #3
0
 def create_loss(self, features, mode, logits, labels):
   """Builds the summed, weighted loss for this head.

   Args:
     features: Feature dict used to look up example weights (and forwarded
       to a custom loss function, if one is configured).
     mode: Ignored by this head.
     logits: Logits tensor.
     labels: Raw labels; converted with `self._process_labels`.

   Returns:
     A `head_lib.LossSpec` with `weighted_sum_loss`, `example_weight_sum`
     and `processed_labels`.
   """
   del mode  # Unused for this head.
   targets = self._process_labels(labels)
   if not self._loss_fn:
     # Default loss: sigmoid cross entropy per element, then the mean over
     # the last (class) axis leaves a single value per example.
     elementwise = losses.sigmoid_cross_entropy(
         multi_class_labels=targets, logits=logits,
         reduction=losses.Reduction.NONE)
     per_example = math_ops.reduce_mean(elementwise, axis=-1, keep_dims=True)
   else:
     per_example = _call_loss_fn(
         loss_fn=self._loss_fn, labels=targets, logits=logits,
         features=features)
   weight_tensor = head_lib._weights(features, self._weight_column)  # pylint:disable=protected-access
   total = losses.compute_weighted_loss(
       per_example, weights=weight_tensor, reduction=losses.Reduction.SUM)
   # _weights() may return the scalar 1; multiplying by ones_like broadcasts
   # it to the loss shape so the sum counts every example.
   weight_sum = math_ops.reduce_sum(
       weight_tensor * array_ops.ones_like(per_example))
   return head_lib.LossSpec(
       weighted_sum_loss=total,
       example_weight_sum=weight_sum,
       processed_labels=targets)
예제 #4
0
    def create_estimator_spec(self, features, mode, logits, labels=None,
                              train_op_fn=None):
        """Builds the `EstimatorSpec` for the requested `mode`.

        Sigmoid probabilities are computed from `logits` inside a name scope
        given by `self._name` (falling back to 'head'). The modes behave as
        follows:

        - PREDICT: returns predictions plus classify/predict export outputs.
        - EVAL: computes the loss via `create_loss` and attaches metric ops.
        - TRAIN: computes the loss via `create_loss`, writes loss and
          mean-loss summaries, and builds the train op with `train_op_fn`.

        Args:
            features: Feature dict; forwarded to `create_loss` and used to
                look up example weights for metrics.
            mode: A `model_fn.ModeKeys` value.
            logits: Logits tensor, validated against `self.logits_dimension`.
            labels: Labels forwarded to `create_loss`; unused in PREDICT.
            train_op_fn: Callable mapping the loss tensor to a train op.
                Required when `mode` is neither PREDICT nor EVAL.

        Returns:
            A `model_fn.EstimatorSpec`.

        Raises:
            ValueError: If `train_op_fn` is None and `mode` is neither
                PREDICT nor EVAL.
        """
        mode_keys = model_fn.ModeKeys
        keys = prediction_keys.PredictionKeys
        with ops.name_scope(self._name, 'head'):
            checked_logits = head_lib._check_logits(  # pylint:disable=protected-access
                logits, self.logits_dimension)

            # Predictions: an independent sigmoid probability per logit.
            with ops.name_scope(None, 'predictions', (checked_logits, )):
                probs = math_ops.sigmoid(checked_logits,
                                         name=keys.PROBABILITIES)
                preds = {
                    keys.LOGITS: checked_logits,
                    keys.PROBABILITIES: probs,
                }

            if mode == mode_keys.PREDICT:
                classifier_output = head_lib._classification_output(  # pylint:disable=protected-access
                    scores=probs,
                    n_classes=self._n_classes,
                    label_vocabulary=self._label_vocabulary)
                serving_outputs = {
                    _DEFAULT_SERVING_KEY: classifier_output,
                    head_lib._CLASSIFY_SERVING_KEY: classifier_output,  # pylint:disable=protected-access
                    head_lib._PREDICT_SERVING_KEY:  # pylint:disable=protected-access
                        export_output.PredictOutput(preds),
                }
                return model_fn.EstimatorSpec(mode=mode_keys.PREDICT,
                                              predictions=preds,
                                              export_outputs=serving_outputs)

            loss_spec = self.create_loss(features=features,
                                         mode=mode,
                                         logits=checked_logits,
                                         labels=labels)
            weighted_sum_loss, example_weight_sum, processed_labels = loss_spec

            if mode == mode_keys.EVAL:
                eval_weights = head_lib._weights(  # pylint:disable=protected-access
                    features, self._weight_column)
                metric_ops = self._eval_metric_ops(
                    labels=processed_labels,
                    probabilities=probs,
                    weights=eval_weights,
                    weighted_sum_loss=weighted_sum_loss,
                    example_weight_sum=example_weight_sum)
                return model_fn.EstimatorSpec(mode=mode_keys.EVAL,
                                              predictions=preds,
                                              loss=weighted_sum_loss,
                                              eval_metric_ops=metric_ops)

            # Any remaining mode is treated as TRAIN and needs a train op.
            if train_op_fn is None:
                raise ValueError('train_op_fn can not be None.')
        # NOTE(review): presumably resets to the root scope so the summary
        # tags come only from _summary_key — confirm.
        with ops.name_scope(''):
            summary.scalar(
                head_lib._summary_key(self._name, metric_keys.MetricKeys.LOSS),  # pylint:disable=protected-access
                weighted_sum_loss)
            # Weighted average loss per example.
            mean_loss = weighted_sum_loss / example_weight_sum
            summary.scalar(
                head_lib._summary_key(  # pylint:disable=protected-access
                    self._name, metric_keys.MetricKeys.LOSS_MEAN),
                mean_loss)
        train_op = train_op_fn(weighted_sum_loss)
        return model_fn.EstimatorSpec(mode=mode_keys.TRAIN,
                                      predictions=preds,
                                      loss=weighted_sum_loss,
                                      train_op=train_op)
예제 #5
0
  def create_estimator_spec(
      self, features, mode, logits, labels=None, train_op_fn=None):
    """Builds the `EstimatorSpec` for the requested `mode`.

    Sigmoid probabilities are computed from `logits` inside a name scope
    given by `self._name` (falling back to 'head'). PREDICT returns
    predictions plus classify/predict export outputs. For the other modes
    the loss comes from `create_loss`: EVAL attaches metric ops, and TRAIN
    writes loss and mean-loss summaries and builds the train op with
    `train_op_fn`.

    Args:
      features: Feature dict; forwarded to `create_loss` and used to look up
        example weights for metrics.
      mode: A `model_fn.ModeKeys` value.
      logits: Logits tensor, validated against `self.logits_dimension`.
      labels: Labels forwarded to `create_loss`; unused in PREDICT mode.
      train_op_fn: Callable mapping the loss tensor to a train op. Required
        when `mode` is neither PREDICT nor EVAL.

    Returns:
      A `model_fn.EstimatorSpec`.

    Raises:
      ValueError: If `train_op_fn` is None and `mode` is neither PREDICT
        nor EVAL.
    """
    mode_keys = model_fn.ModeKeys
    keys = prediction_keys.PredictionKeys
    with ops.name_scope(self._name, 'head'):
      checked_logits = head_lib._check_logits(logits, self.logits_dimension)  # pylint:disable=protected-access

      # Predictions: an independent sigmoid probability per logit.
      with ops.name_scope(None, 'predictions', (checked_logits,)):
        probs = math_ops.sigmoid(checked_logits, name=keys.PROBABILITIES)
        preds = {keys.LOGITS: checked_logits, keys.PROBABILITIES: probs}

      if mode == mode_keys.PREDICT:
        classifier_output = head_lib._classification_output(  # pylint:disable=protected-access
            scores=probs, n_classes=self._n_classes,
            label_vocabulary=self._label_vocabulary)
        serving_outputs = {
            _DEFAULT_SERVING_KEY: classifier_output,
            head_lib._CLASSIFY_SERVING_KEY: classifier_output,  # pylint:disable=protected-access
            head_lib._PREDICT_SERVING_KEY:  # pylint:disable=protected-access
                export_output.PredictOutput(preds),
        }
        return model_fn.EstimatorSpec(
            mode=mode_keys.PREDICT,
            predictions=preds,
            export_outputs=serving_outputs)

      loss_spec = self.create_loss(
          features=features, mode=mode, logits=checked_logits, labels=labels)
      weighted_sum_loss, example_weight_sum, processed_labels = loss_spec

      if mode == mode_keys.EVAL:
        eval_weights = head_lib._weights(features, self._weight_column)  # pylint:disable=protected-access
        metric_ops = self._eval_metric_ops(
            labels=processed_labels,
            probabilities=probs,
            weights=eval_weights,
            weighted_sum_loss=weighted_sum_loss,
            example_weight_sum=example_weight_sum)
        return model_fn.EstimatorSpec(
            mode=mode_keys.EVAL,
            predictions=preds,
            loss=weighted_sum_loss,
            eval_metric_ops=metric_ops)

      # Any remaining mode is treated as TRAIN and needs a train op builder.
      if train_op_fn is None:
        raise ValueError('train_op_fn can not be None.')
    # NOTE(review): presumably resets to the root scope so the summary tags
    # come only from _summary_key — confirm.
    with ops.name_scope(''):
      summary.scalar(
          head_lib._summary_key(self._name, metric_keys.MetricKeys.LOSS),  # pylint:disable=protected-access
          weighted_sum_loss)
      # Weighted average loss per example.
      mean_loss = weighted_sum_loss / example_weight_sum
      summary.scalar(
          head_lib._summary_key(  # pylint:disable=protected-access
              self._name, metric_keys.MetricKeys.LOSS_MEAN),
          mean_loss)
    train_op = train_op_fn(weighted_sum_loss)
    return model_fn.EstimatorSpec(
        mode=mode_keys.TRAIN,
        predictions=preds,
        loss=weighted_sum_loss,
        train_op=train_op)