Example #1
0
 def create_loss(self, features, mode, logits, labels):
   """Implements `Head.create_loss` for this head.

   Args:
     features: Input features. Forwarded to a user-supplied `loss_fn` when one
       is configured, and used to look up weights for `self._weight_column`.
     mode: Not used; the loss is computed the same way in every mode.
     logits: Logits, converted to a `Tensor` before any computation.
     labels: Raw labels. They are run through `self._process_labels` and then
       checked against (and reshaped to match) `logits`, expecting
       `self.logits_dimension` label columns.

   Returns:
     A `head_lib.LossSpec` holding the reduced `training_loss`, the unweighted
     `unreduced_loss` (last dimension of size 1), the `weights` used for the
     reduction, and the `processed_labels`.
   """
   del mode  # The loss does not depend on the mode.
   logits_tensor = ops.convert_to_tensor(logits)
   label_tensor = head_lib._check_dense_labels_match_logits_and_reshape(  # pylint:disable=protected-access
       labels=self._process_labels(labels),
       logits=logits_tensor,
       expected_labels_dimension=self.logits_dimension)

   if not self._loss_fn:
     # Default loss: element-wise sigmoid cross-entropy, then the mean across
     # classes (last axis), keeping that axis so it ends up with size 1.
     per_class_loss = losses.sigmoid_cross_entropy(
         multi_class_labels=label_tensor,
         logits=logits_tensor,
         reduction=losses.Reduction.NONE)
     example_loss = math_ops.reduce_mean(
         per_class_loss, axis=-1, keepdims=True)
   else:
     # Custom loss: `_call_loss_fn` is asked for a loss whose last dimension
     # is 1 (`expected_loss_dim=1`), matching the default branch.
     example_loss = head_lib._call_loss_fn(  # pylint:disable=protected-access
         loss_fn=self._loss_fn,
         labels=label_tensor,
         logits=logits_tensor,
         features=features,
         expected_loss_dim=1)

   example_weights = head_lib._get_weights_and_check_match_logits(  # pylint:disable=protected-access
       features=features,
       weight_column=self._weight_column,
       logits=logits_tensor)
   return head_lib.LossSpec(
       training_loss=losses.compute_weighted_loss(
           example_loss,
           weights=example_weights,
           reduction=self._loss_reduction),
       unreduced_loss=example_loss,
       weights=example_weights,
       processed_labels=label_tensor)
Example #2
0
 def create_loss(self, features, mode, logits, labels):
     """Computes the loss for this head. See `Head`.

     Args:
       features: Input features. Forwarded to a user-supplied `loss_fn` when
         one is configured, and used to look up weights for
         `self._weight_column`.
       mode: Ignored; the loss is computed the same way in every mode.
       logits: Logits, converted to a `Tensor` before use.
       labels: Raw labels. They are passed through `self._process_labels` and
         then checked against (and reshaped to match) `logits`, with
         `self.logits_dimension` as the expected label dimension.

     Returns:
       A `head_lib.LossSpec` with `training_loss` (the weighted loss reduced
       according to `self._loss_reduction`), `unreduced_loss` (the unweighted
       loss, whose last dimension is 1), `weights`, and `processed_labels`.
     """
     del mode  # Unused for this head.
     logits = ops.convert_to_tensor(logits)
     processed_labels = self._process_labels(labels)
     processed_labels = head_lib._check_dense_labels_match_logits_and_reshape(  # pylint:disable=protected-access
         labels=processed_labels,
         logits=logits,
         expected_labels_dimension=self.logits_dimension)
     if self._loss_fn:
         # A user-supplied loss_fn is required to yield a loss whose last
         # dimension is 1, the same shape the default branch produces.
         unweighted_loss = head_lib._call_loss_fn(  # pylint:disable=protected-access
             loss_fn=self._loss_fn,
             labels=processed_labels,
             logits=logits,
             features=features,
             expected_loss_dim=1)
     else:
         unweighted_loss = losses.sigmoid_cross_entropy(
             multi_class_labels=processed_labels,
             logits=logits,
             reduction=losses.Reduction.NONE)
         # Averages loss over classes. Use `keepdims`, not the deprecated
         # `keep_dims` alias (rejected by TF 2's reduce_mean), so the class
         # axis is kept with size 1.
         unweighted_loss = math_ops.reduce_mean(unweighted_loss,
                                                axis=-1,
                                                keepdims=True)
     weights = head_lib._get_weights_and_check_match_logits(  # pylint:disable=protected-access
         features=features,
         weight_column=self._weight_column,
         logits=logits)
     training_loss = losses.compute_weighted_loss(
         unweighted_loss, weights=weights, reduction=self._loss_reduction)
     return head_lib.LossSpec(training_loss=training_loss,
                              unreduced_loss=unweighted_loss,
                              weights=weights,
                              processed_labels=processed_labels)