def create_loss(self, features, mode, logits, labels):
    """Build the `LossSpec` for this multi-label head (implements `Head`).

    The per-example loss comes from `self._loss_fn` when one is configured.
    Otherwise it is the element-wise sigmoid cross-entropy, averaged over the
    last axis. The per-example loss is then weighted and reduced into the
    training loss using `self._loss_reduction`.

    Args:
      features: Input features. They are forwarded to a custom `loss_fn`, and
        the example weights are looked up from them via `self._weight_column`.
      mode: Ignored; this head computes the same loss regardless of mode.
      logits: Logits, converted to a tensor before use.
      labels: Raw labels. They are passed through `self._process_labels` and
        then checked/reshaped against `logits` with
        `expected_labels_dimension=self.logits_dimension`.

    Returns:
      A `head_lib.LossSpec` holding the reduced training loss, the unreduced
      per-example loss, the weights and the processed labels.
    """
    del mode  # Not needed: the loss does not depend on the mode.
    logits_tensor = ops.convert_to_tensor(logits)
    dense_labels = head_lib._check_dense_labels_match_logits_and_reshape(  # pylint:disable=protected-access
        labels=self._process_labels(labels),
        logits=logits_tensor,
        expected_labels_dimension=self.logits_dimension)

    if not self._loss_fn:
        # Unreduced sigmoid cross-entropy: one value per label entry.
        per_class_loss = losses.sigmoid_cross_entropy(
            multi_class_labels=dense_labels,
            logits=logits_tensor,
            reduction=losses.Reduction.NONE)
        # Mean over the class axis; keepdims leaves a trailing dimension of 1.
        per_example_loss = math_ops.reduce_mean(
            per_class_loss, axis=-1, keepdims=True)
    else:
        # A user-supplied loss is validated to have loss dimension 1
        # (`expected_loss_dim=1`), so no class averaging is applied to it.
        per_example_loss = head_lib._call_loss_fn(  # pylint:disable=protected-access
            loss_fn=self._loss_fn,
            labels=dense_labels,
            logits=logits_tensor,
            features=features,
            expected_loss_dim=1)

    example_weights = head_lib._get_weights_and_check_match_logits(  # pylint:disable=protected-access
        features=features,
        weight_column=self._weight_column,
        logits=logits_tensor)
    reduced_loss = losses.compute_weighted_loss(
        per_example_loss,
        weights=example_weights,
        reduction=self._loss_reduction)
    return head_lib.LossSpec(
        training_loss=reduced_loss,
        unreduced_loss=per_example_loss,
        weights=example_weights,
        processed_labels=dense_labels)
def create_loss(self, features, mode, logits, labels):
    """See `Head`. Build the `LossSpec` for this multi-label head.

    If `self._loss_fn` is set, it produces the per-example loss (validated by
    `head_lib._call_loss_fn` with `expected_loss_dim=1`). Otherwise the
    unreduced sigmoid cross-entropy is averaged over the last (class) axis,
    keeping that axis with size 1. The per-example loss is then weighted and
    reduced with `self._loss_reduction`.

    Args:
      features: Input features. They are forwarded to a custom `loss_fn`, and
        the example weights are looked up from them via `self._weight_column`.
      mode: Unused by this head.
      logits: Logits, converted to a tensor before use.
      labels: Raw labels. They are passed through `self._process_labels` and
        then checked/reshaped against `logits` with
        `expected_labels_dimension=self.logits_dimension`.

    Returns:
      A `head_lib.LossSpec` with `training_loss`, `unreduced_loss`, `weights`
      and `processed_labels`.
    """
    del mode  # Unused for this head.
    logits = ops.convert_to_tensor(logits)
    processed_labels = self._process_labels(labels)
    processed_labels = head_lib._check_dense_labels_match_logits_and_reshape(  # pylint:disable=protected-access
        labels=processed_labels,
        logits=logits,
        expected_labels_dimension=self.logits_dimension)
    if self._loss_fn:
        unweighted_loss = head_lib._call_loss_fn(  # pylint:disable=protected-access
            loss_fn=self._loss_fn,
            labels=processed_labels,
            logits=logits,
            features=features,
            expected_loss_dim=1)
    else:
        unweighted_loss = losses.sigmoid_cross_entropy(
            multi_class_labels=processed_labels,
            logits=logits,
            reduction=losses.Reduction.NONE)
        # Averages loss over classes. Use `keepdims`: the legacy `keep_dims`
        # spelling is deprecated and not accepted by TF 2's `reduce_mean`.
        unweighted_loss = math_ops.reduce_mean(
            unweighted_loss, axis=-1, keepdims=True)
    weights = head_lib._get_weights_and_check_match_logits(  # pylint:disable=protected-access
        features=features,
        weight_column=self._weight_column,
        logits=logits)
    training_loss = losses.compute_weighted_loss(
        unweighted_loss, weights=weights, reduction=self._loss_reduction)
    return head_lib.LossSpec(
        training_loss=training_loss,
        unreduced_loss=unweighted_loss,
        weights=weights,
        processed_labels=processed_labels)