Exemplo n.º 1
0
def _get_loss(features, labels, logits, params):
    """Build CTR and CTCVR sigmoid cross-entropy losses and their weighted sum.

    Args:
      features: Forwarded to `_get_weights_and_check_match_logits` to look up
        each task's weight column.
      labels: Mapping with 'convert_label' (CTCVR target) and 'click_label'
        (CTR target); both are cast to float32 with `tf.to_float`.
      logits: Mapping with 'ctr_logits' and 'ctcvr_logits'.
      params: Object exposing `ctr_weight_column` and `ctcvr_weight_column`.

    Returns:
      A 4-tuple `(weighted_loss, unweighted_loss, weights, labels)`:
        weighted_loss: sum of the two tasks' losses, each reduced with
          `tf.losses.Reduction.MEAN` using that task's weights.
        unweighted_loss: element-wise sum of the two unreduced losses
          (op name 'total_loss').
        weights: {'convert_weight': CTCVR weights, 'click_weight': CTR weights}.
        labels: {'convert_label': float CTCVR label,
                 'click_label': float CTR label}.
    """
    # The casts and the two cross-entropies are created in the same order as
    # before (CTCVR first), so auto-generated graph op names are unchanged.
    ctcvr_label = tf.to_float(labels['convert_label'])
    ctr_label = tf.to_float(labels['click_label'])
    ctr_logits = logits['ctr_logits']
    ctcvr_logits = logits['ctcvr_logits']

    # Per-example (unreduced) losses for each task and their sum.
    ctcvr_xent = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=ctcvr_label, logits=ctcvr_logits, name='ctcvr_loss')
    ctr_xent = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=ctr_label, logits=ctr_logits, name='ctr_loss')
    unweighted_loss = tf.add(ctr_xent, ctcvr_xent, name='total_loss')

    def _reduce(per_example_loss, weight_column, task_logits):
        """Fetch a task's example weights and return (weights, weighted mean)."""
        task_weights = _get_weights_and_check_match_logits(
            features=features, weight_column=weight_column, logits=task_logits)
        reduced = tf.losses.compute_weighted_loss(
            per_example_loss,
            weights=task_weights,
            reduction=tf.losses.Reduction.MEAN)
        return task_weights, reduced

    ctr_weight_column = params.ctr_weight_column
    ctcvr_weight_column = params.ctcvr_weight_column
    # CTR is reduced before CTCVR, matching the original op-creation order.
    ctr_weights, ctr_weighted = _reduce(ctr_xent, ctr_weight_column, ctr_logits)
    ctcvr_weights, ctcvr_weighted = _reduce(
        ctcvr_xent, ctcvr_weight_column, ctcvr_logits)
    weighted_loss = tf.add(ctr_weighted, ctcvr_weighted)

    return (weighted_loss,
            unweighted_loss,
            {'convert_weight': ctcvr_weights, 'click_weight': ctr_weights},
            {'convert_label': ctcvr_label, 'click_label': ctr_label})
Exemplo n.º 2
0
 def create_loss(self, features, mode, logits, labels):
   """See `Head`.

   Builds the per-example loss for this multi-label head and reduces it with
   the example weights.

   Args:
     features: Used to look up `self._weight_column`; also forwarded to the
       custom `self._loss_fn` when one is configured.
     mode: Ignored by this head.
     logits: Logits; converted with `ops.convert_to_tensor`.
     labels: Raw labels, passed through `self._process_labels` and then
       checked/reshaped against `logits`.

   Returns:
     A `head_lib.LossSpec` with the training loss (reduced with
     `self._loss_reduction`), the unreduced per-example loss, the example
     weights and the processed labels.
   """
   # pylint:disable=protected-access
   del mode  # Unused for this head.
   logits = ops.convert_to_tensor(logits)
   checked_labels = head_lib._check_dense_labels_match_logits_and_reshape(
       labels=self._process_labels(labels),
       logits=logits,
       expected_labels_dimension=self.logits_dimension)

   if not self._loss_fn:
     # Default loss: element-wise sigmoid cross-entropy, then averaged over
     # the last (class) axis while keeping that axis with size 1.
     per_class = losses.sigmoid_cross_entropy(
         multi_class_labels=checked_labels,
         logits=logits,
         reduction=losses.Reduction.NONE)
     per_example = math_ops.reduce_mean(per_class, axis=-1, keepdims=True)
   else:
     per_example = head_lib._call_loss_fn(
         loss_fn=self._loss_fn,
         labels=checked_labels,
         logits=logits,
         features=features,
         expected_loss_dim=1)

   example_weights = head_lib._get_weights_and_check_match_logits(
       features=features, weight_column=self._weight_column, logits=logits)
   return head_lib.LossSpec(
       training_loss=losses.compute_weighted_loss(
           per_example,
           weights=example_weights,
           reduction=self._loss_reduction),
       unreduced_loss=per_example,
       weights=example_weights,
       processed_labels=checked_labels)
Exemplo n.º 3
0
 def create_loss(self, features, mode, logits, labels):
     """See `Head`.

     Computes the multi-label loss (sigmoid cross-entropy averaged over
     classes, or the custom `self._loss_fn` if one is set) and reduces it with
     the example weights.

     Args:
       features: Used to look up `self._weight_column`; also forwarded to the
         custom `self._loss_fn` when one is configured.
       mode: Unused for this head.
       logits: Logits; converted with `ops.convert_to_tensor`.
       labels: Raw labels, processed by `self._process_labels` and then
         checked/reshaped against `logits`.

     Returns:
       A `head_lib.LossSpec` with the training loss (reduced with
       `self._loss_reduction`), the unreduced per-example loss, the example
       weights and the processed labels.
     """
     del mode  # Unused for this head.
     logits = ops.convert_to_tensor(logits)
     processed_labels = self._process_labels(labels)
     processed_labels = head_lib._check_dense_labels_match_logits_and_reshape(  # pylint:disable=protected-access
         labels=processed_labels,
         logits=logits,
         expected_labels_dimension=self.logits_dimension)
     if self._loss_fn:
         unweighted_loss = head_lib._call_loss_fn(  # pylint:disable=protected-access
             loss_fn=self._loss_fn,
             labels=processed_labels,
             logits=logits,
             features=features,
             expected_loss_dim=1)
     else:
         unweighted_loss = losses.sigmoid_cross_entropy(
             multi_class_labels=processed_labels,
             logits=logits,
             reduction=losses.Reduction.NONE)
         # Averages loss over classes, keeping the last axis with size 1.
         # Use `keepdims`: `keep_dims` is a deprecated alias that the TF2
         # `reduce_mean` signature no longer accepts.
         unweighted_loss = math_ops.reduce_mean(unweighted_loss,
                                                axis=-1,
                                                keepdims=True)
     weights = head_lib._get_weights_and_check_match_logits(  # pylint:disable=protected-access
         features=features,
         weight_column=self._weight_column,
         logits=logits)
     training_loss = losses.compute_weighted_loss(
         unweighted_loss, weights=weights, reduction=self._loss_reduction)
     return head_lib.LossSpec(training_loss=training_loss,
                              unreduced_loss=unweighted_loss,
                              weights=weights,
                              processed_labels=processed_labels)
def _create_loss(features, params, logits, labels):
    """Compute a sigmoid cross-entropy loss, optionally weighted per example.

    Note: the weighted loss uses `tf.losses.Reduction.MEAN`. This is not the
    default reduction of `tf.losses.sigmoid_cross_entropy`.

    Args:
      features: Forwarded to `_get_weights_and_check_match_logits` when
        `params.weight_column` is set.
      params: Object exposing `weight_column`. A falsy value means every
        example gets weight 1.0.
      logits: Logits tensor.
      labels: Targets; cast to float32 with `tf.to_float`.

    Returns:
      A 4-tuple `(weighted_loss, unweighted_loss, weights, labels)`, where
      `labels` is the float-cast label tensor and `weights` is either the
      weight-column tensor or the Python float 1.0.
    """
    float_labels = tf.to_float(labels)

    # Element-wise (unreduced) cross-entropy.
    per_example_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=float_labels, logits=logits)

    # Cost-sensitive learning: per-example weights come from the optional
    # weight column; otherwise every example counts equally.
    weight_column = params.weight_column
    example_weights = (
        _get_weights_and_check_match_logits(
            features=features, weight_column=weight_column, logits=logits)
        if weight_column else 1.0)

    reduced_loss = tf.losses.compute_weighted_loss(
        per_example_loss,
        weights=example_weights,
        reduction=tf.losses.Reduction.MEAN)
    return reduced_loss, per_example_loss, example_weights, float_labels
Exemplo n.º 5
0
    def create_estimator_spec(self,
                              features,
                              mode,
                              logits,
                              labels=None,
                              train_op_fn=None):
        """See `Head`.

        Builds logits/probability predictions for every mode, then:
          * PREDICT: returns them with classification and predict export
            outputs.
          * EVAL: also returns the loss and `self._eval_metric_ops` metrics.
          * any other mode (training): writes loss / mean-loss summaries and
            returns the op produced by `train_op_fn`.

        Raises:
          ValueError: if `mode` is neither PREDICT nor EVAL and
            `train_op_fn` is None.
        """
        # pylint:disable=protected-access
        mode_keys = model_fn.ModeKeys
        with ops.name_scope(self._name, 'head'):
            logits = head_lib._check_logits_final_dim(logits,
                                                      self.logits_dimension)

            # Predictions are needed by every mode.
            keys = prediction_keys.PredictionKeys
            with ops.name_scope(None, 'predictions', (logits, )):
                probs = math_ops.sigmoid(logits, name=keys.PROBABILITIES)
                predictions = {keys.LOGITS: logits, keys.PROBABILITIES: probs}

            if mode == mode_keys.PREDICT:
                classify_out = head_lib._classification_output(
                    scores=probs,
                    n_classes=self._n_classes,
                    label_vocabulary=self._label_vocabulary)
                outputs = {
                    _DEFAULT_SERVING_KEY: classify_out,
                    head_lib._CLASSIFY_SERVING_KEY: classify_out,
                    head_lib._PREDICT_SERVING_KEY:
                        export_output.PredictOutput(predictions),
                }
                return model_fn.EstimatorSpec(mode=mode_keys.PREDICT,
                                              predictions=predictions,
                                              export_outputs=outputs)

            loss, weight_sum, processed_labels = self.create_loss(
                features=features, mode=mode, logits=logits, labels=labels)

            if mode == mode_keys.EVAL:
                eval_weights = head_lib._get_weights_and_check_match_logits(
                    features=features,
                    weight_column=self._weight_column,
                    logits=logits)
                metric_ops = self._eval_metric_ops(
                    labels=processed_labels,
                    probabilities=probs,
                    weights=eval_weights,
                    weighted_sum_loss=loss,
                    example_weight_sum=weight_sum)
                return model_fn.EstimatorSpec(mode=mode_keys.EVAL,
                                              predictions=predictions,
                                              loss=loss,
                                              eval_metric_ops=metric_ops)

            if train_op_fn is None:
                raise ValueError('train_op_fn can not be None.')

        # Training summaries are emitted under the root name scope ('').
        with ops.name_scope(''):
            summary.scalar(
                head_lib._summary_key(self._name, metric_keys.MetricKeys.LOSS),
                loss)
            summary.scalar(
                head_lib._summary_key(self._name,
                                      metric_keys.MetricKeys.LOSS_MEAN),
                loss / weight_sum)
        train_op = train_op_fn(loss)
        return model_fn.EstimatorSpec(mode=mode_keys.TRAIN,
                                      predictions=predictions,
                                      loss=loss,
                                      train_op=train_op)
Exemplo n.º 6
0
  def create_estimator_spec(
      self, features, mode, logits, labels=None, train_op_fn=None):
    """See `Head`.

    Predictions (logits and sigmoid probabilities) are built for every mode.
    PREDICT returns them with classification and predict export outputs; EVAL
    also returns the weighted-sum loss and the metrics from
    `self._eval_metric_ops`; any other mode is treated as training: loss
    summaries are written and the op from `train_op_fn` is returned.

    Raises:
      ValueError: if `mode` is neither PREDICT nor EVAL and `train_op_fn`
        is None.
    """
    # pylint:disable=protected-access
    modes = model_fn.ModeKeys
    with ops.name_scope(self._name, 'head'):
      logits = head_lib._check_logits_final_dim(logits, self.logits_dimension)

      keys = prediction_keys.PredictionKeys
      with ops.name_scope(None, 'predictions', (logits,)):
        probabilities = math_ops.sigmoid(logits, name=keys.PROBABILITIES)
        predictions = {
            keys.LOGITS: logits,
            keys.PROBABILITIES: probabilities,
        }

      if mode == modes.PREDICT:
        classify_output = head_lib._classification_output(
            scores=probabilities,
            n_classes=self._n_classes,
            label_vocabulary=self._label_vocabulary)
        export_outputs = {
            _DEFAULT_SERVING_KEY: classify_output,
            head_lib._CLASSIFY_SERVING_KEY: classify_output,
            head_lib._PREDICT_SERVING_KEY:
                export_output.PredictOutput(predictions),
        }
        return model_fn.EstimatorSpec(
            mode=modes.PREDICT,
            predictions=predictions,
            export_outputs=export_outputs)

      loss_sum, weight_sum, processed_labels = self.create_loss(
          features=features, mode=mode, logits=logits, labels=labels)

      if mode == modes.EVAL:
        eval_weights = head_lib._get_weights_and_check_match_logits(
            features=features,
            weight_column=self._weight_column,
            logits=logits)
        metrics = self._eval_metric_ops(
            labels=processed_labels,
            probabilities=probabilities,
            weights=eval_weights,
            weighted_sum_loss=loss_sum,
            example_weight_sum=weight_sum)
        return model_fn.EstimatorSpec(
            mode=modes.EVAL,
            predictions=predictions,
            loss=loss_sum,
            eval_metric_ops=metrics)

      if train_op_fn is None:
        raise ValueError('train_op_fn can not be None.')

    # Training summaries are created under the root name scope ('').
    with ops.name_scope(''):
      summary.scalar(
          head_lib._summary_key(self._name, metric_keys.MetricKeys.LOSS),
          loss_sum)
      summary.scalar(
          head_lib._summary_key(self._name, metric_keys.MetricKeys.LOSS_MEAN),
          loss_sum / weight_sum)
    train_op = train_op_fn(loss_sum)
    return model_fn.EstimatorSpec(
        mode=modes.TRAIN,
        predictions=predictions,
        loss=loss_sum,
        train_op=train_op)