Esempio n. 1
0
 def create_loss(self, features, mode, logits, labels):
     """See `Head`.

     Builds the training loss for this multi-label head. Labels are processed
     with `self._process_labels` and then checked/reshaped against `logits`.
     If `self._loss_fn` is set it computes the per-example loss; otherwise a
     per-class sigmoid cross-entropy is computed and averaged over the last
     axis. The per-example loss is then weighted and reduced with
     `self._loss_reduction`.

     Args:
       features: Input `dict` of `Tensor` objects. Passed to the custom loss
         function and to the weight lookup for `self._weight_column`.
       mode: Unused by this head.
       logits: Logits `Tensor`, or a value convertible to one.
       labels: Labels in any form accepted by `self._process_labels`.

     Returns:
       A `head_lib.LossSpec` with the reduced weighted `training_loss`, the
       per-example `unreduced_loss`, the example `weights` and the
       `processed_labels`.
     """
     del mode  # Unused for this head.
     logits = ops.convert_to_tensor(logits)
     processed_labels = self._process_labels(labels)
     processed_labels = head_lib._check_dense_labels_match_logits_and_reshape(  # pylint:disable=protected-access
         labels=processed_labels,
         logits=logits,
         expected_labels_dimension=self.logits_dimension)
     if self._loss_fn:
         unweighted_loss = head_lib._call_loss_fn(  # pylint:disable=protected-access
             loss_fn=self._loss_fn,
             labels=processed_labels,
             logits=logits,
             features=features,
             expected_loss_dim=1)
     else:
         unweighted_loss = losses.sigmoid_cross_entropy(
             multi_class_labels=processed_labels,
             logits=logits,
             reduction=losses.Reduction.NONE)
         # Averages loss over classes. The reduced axis is kept so the loss
         # retains a trailing dimension of size 1, mirroring the
         # `expected_loss_dim=1` requested in the custom-loss branch.
         # `keepdims` replaces the deprecated `keep_dims` alias, which the
         # TF 2.x `reduce_mean` signature no longer accepts.
         unweighted_loss = math_ops.reduce_mean(unweighted_loss,
                                                axis=-1,
                                                keepdims=True)
     weights = head_lib._get_weights_and_check_match_logits(  # pylint:disable=protected-access
         features=features,
         weight_column=self._weight_column,
         logits=logits)
     training_loss = losses.compute_weighted_loss(
         unweighted_loss, weights=weights, reduction=self._loss_reduction)
     return head_lib.LossSpec(training_loss=training_loss,
                              unreduced_loss=unweighted_loss,
                              weights=weights,
                              processed_labels=processed_labels)
Esempio n. 2
0
  def create_loss(self, features, mode, logits, labels):
    """Returns a loss Tensor from provided logits.

    This function is designed to be used by framework developers.  Almost all
    users should use create_estimator_spec(), which calls this internally.

    This head delegates the loss computation entirely to `self._loss_fn`,
    which is invoked as `self._loss_fn(labels, logits, features)`. Its return
    value is used as the training loss without any further weighting or
    reduction here, so no per-example loss and no example weights are exposed.

    Args:
      features: Input `dict` of `Tensor` objects, forwarded to
        `self._loss_fn`.
      mode: Estimator's `ModeKeys`. Ignored by this head.
      logits: logits `Tensor` (or a value convertible to one) to be used for
        loss construction.
      labels: Labels `Tensor`. It is passed through `math_ops.to_float`, so it
        must be convertible to a single `Tensor`; a `dict` of labels cannot be
        used with this head.

    Returns:
      A `head_lib.LossSpec` that contains
      * `training_loss`: the value returned by `self._loss_fn` (presumably
        already reduced to a scalar -- TODO confirm against `_loss_fn`)
      * `unreduced_loss`: always `None` for this head
      * `weights`: always `None` for this head
      * `processed_labels`: `labels` cast to float
    """
    del mode  # Unused for this head.
    logits = ops.convert_to_tensor(logits)
    # The custom loss function always receives float labels.
    labels = math_ops.to_float(labels)

    training_loss = self._loss_fn(labels, logits, features)

    return head_lib.LossSpec(
        training_loss=training_loss,
        unreduced_loss=None,
        weights=None,
        processed_labels=labels)
Esempio n. 3
0
 def create_loss(self, features, mode, logits, labels):
     """See `Head`.

     Computes the summed weighted loss and the total example weight for this
     multi-label head. The per-example loss comes from `self._loss_fn` (via
     the module-level `_call_loss_fn`) when one is configured; otherwise it
     is the per-class sigmoid cross-entropy averaged over the last axis.

     Returns:
       A `head_lib.LossSpec` holding `weighted_sum_loss` (the weighted loss
       reduced with `losses.Reduction.SUM`), `example_weight_sum` (the sum of
       the example weights broadcast to the loss shape) and
       `processed_labels`.
     """
     del mode  # This head does not depend on the mode.
     labels_out = self._process_labels(labels)
     if not self._loss_fn:
         # Default loss: sigmoid cross-entropy per class, then the mean over
         # the last (class) axis with that axis kept.
         per_class_loss = losses.sigmoid_cross_entropy(
             multi_class_labels=labels_out,
             logits=logits,
             reduction=losses.Reduction.NONE)
         per_example_loss = math_ops.reduce_mean(
             per_class_loss, axis=-1, keep_dims=True)
     else:
         per_example_loss = _call_loss_fn(
             loss_fn=self._loss_fn,
             labels=labels_out,
             logits=logits,
             features=features)
     example_weights = head_lib._weights(features, self._weight_column)  # pylint:disable=protected-access
     summed_loss = losses.compute_weighted_loss(
         per_example_loss,
         weights=example_weights,
         reduction=losses.Reduction.SUM)
     # `_weights()` can return the scalar 1; multiplying by `ones_like` yields
     # one weight per loss element before summing.
     broadcast_weights = example_weights * array_ops.ones_like(per_example_loss)
     return head_lib.LossSpec(
         weighted_sum_loss=summed_loss,
         example_weight_sum=math_ops.reduce_sum(broadcast_weights),
         processed_labels=labels_out)
Esempio n. 4
0
    def create_loss(self, features, mode, logits, labels):
        """See `Head`.

        Builds each sub-head's loss with its own `create_loss` and combines
        the results into a single `head_lib.LossSpec`.

        Args:
          features: Input `dict` of `Tensor` objects, forwarded to every head.
          mode: Estimator's `ModeKeys`, forwarded to every head.
          logits: A `dict` mapping head name to logits, or a single value that
            `self._split_logits` splits into such a `dict`.
          labels: Mapping from head name to that head's labels.

        Returns:
          A `head_lib.LossSpec` where `training_loss` is the sum of the heads'
          training losses, and `unreduced_loss`, `weights` and
          `processed_labels` are `dict`s keyed by head name. When
          `self._head_weights` is set, each head's training loss, unreduced
          loss and example weights are multiplied by that head's weight.
        """
        logits_by_head = (logits if isinstance(logits, dict)
                          else self._split_logits(logits))
        head_weights = self._head_weights
        head_training_losses = []
        labels_by_head = {}
        unreduced_losses_by_head = {}
        example_weights_by_head = {}
        for index, head in enumerate(self._heads):
            name = head.name
            (head_loss, head_unreduced_loss, head_example_weights,
             head_labels) = head.create_loss(features, mode,
                                             logits_by_head[name],
                                             labels[name])
            head_training_losses.append(head_loss)
            labels_by_head[name] = head_labels
            if not head_weights:
                unreduced_losses_by_head[name] = head_unreduced_loss
                example_weights_by_head[name] = head_example_weights
                continue
            # Scale the per-head tensors by this head's weight, the same
            # factor applied to its training loss when merging below.
            scale = head_weights[index]
            unreduced_losses_by_head[name] = math_ops.multiply(
                head_unreduced_loss, scale)
            example_weights_by_head[name] = math_ops.multiply(
                head_example_weights, scale)

        head_training_losses = tuple(head_training_losses)
        with ops.name_scope('merge_losses',
                            values=head_training_losses +
                            (head_weights or tuple())):
            if head_weights:
                summands = [
                    math_ops.multiply(loss, weight)
                    for loss, weight in zip(head_training_losses,
                                            head_weights)
                ]
            else:
                summands = head_training_losses
            merged_training_loss = math_ops.add_n(summands)

        return head_lib.LossSpec(training_loss=merged_training_loss,
                                 unreduced_loss=unreduced_losses_by_head,
                                 weights=example_weights_by_head,
                                 processed_labels=labels_by_head)
Esempio n. 5
0
  def create_loss(self, features, mode, logits, labels):
    """See `Head`.

    Builds each sub-head's loss with its own `create_loss` and sums the
    results into a single `head_lib.LossSpec`. Its `weighted_sum_loss` and
    `example_weight_sum` are sums over the heads; each term is multiplied by
    its head weight when `self._head_weights` is set. Its `processed_labels`
    is a `dict` keyed by head name.

    Raises:
      ValueError: if `logits` is not a `dict` keyed by head name.
    """
    # TODO(roumposg): Add support for logits as single Tensor (with
    # _split_logits utility).
    if not isinstance(logits, dict):
      raise ValueError('logits must be a dict.  Single Tensor support coming '
                       'soon.')

    loss_terms = []
    weight_sum_terms = []
    labels_by_head = {}
    for head in self._heads:
      head_loss, head_weight_sum, head_labels = head.create_loss(
          features, mode, logits[head.name], labels[head.name])
      loss_terms.append(head_loss)
      weight_sum_terms.append(head_weight_sum)
      labels_by_head[head.name] = head_labels

    loss_terms = tuple(loss_terms)
    head_weights = self._head_weights
    with ops.name_scope('merge_losses',
                        values=loss_terms + (head_weights or tuple())):
      if head_weights:
        # Scale each head's loss and weight sum by that head's weight.
        scaled_losses = []
        scaled_weight_sums = []
        for loss, weight_sum, weight in zip(loss_terms, weight_sum_terms,
                                            head_weights):
          scaled_losses.append(math_ops.multiply(loss, weight))
          scaled_weight_sums.append(math_ops.multiply(weight_sum, weight))
        loss_terms = scaled_losses
        weight_sum_terms = scaled_weight_sums
      merged_weighted_sum_loss = math_ops.add_n(loss_terms)
      merged_example_weight_sum = math_ops.add_n(weight_sum_terms)

    return head_lib.LossSpec(
        weighted_sum_loss=merged_weighted_sum_loss,
        example_weight_sum=merged_example_weight_sum,
        processed_labels=labels_by_head)
Esempio n. 6
0
    def create_loss(self, features, mode, logits, labels):
        """See `Head`.

        Builds each sub-head's loss with its own `create_loss` and sums the
        results into a single `head_lib.LossSpec`.

        Args:
          features: Input `dict` of `Tensor` objects, forwarded to every head.
          mode: Estimator's `ModeKeys`, forwarded to every head.
          logits: A `dict` mapping head name to logits, or a single value that
            `self._split_logits` splits into such a `dict`.
          labels: Mapping from head name to that head's labels.

        Returns:
          A `head_lib.LossSpec` with the following fields:
          * `weighted_sum_loss` and `example_weight_sum`: sums over the
            heads, each term multiplied by its head weight when
            `self._head_weights` is set.
          * `processed_labels`: a `dict` keyed by head name.
        """
        if not isinstance(logits, dict):
            logits = self._split_logits(logits)

        sum_losses = []
        weight_sums = []
        labels_by_head = {}
        for head in self._heads:
            head_sum_loss, head_weight_sum, head_labels = head.create_loss(
                features, mode, logits[head.name], labels[head.name])
            sum_losses.append(head_sum_loss)
            weight_sums.append(head_weight_sum)
            labels_by_head[head.name] = head_labels
        sum_losses = tuple(sum_losses)

        head_weights = self._head_weights
        with ops.name_scope('merge_losses',
                            values=sum_losses + (head_weights or tuple())):
            if not head_weights:
                merged_weighted_sum_loss = math_ops.add_n(sum_losses)
                merged_example_weight_sum = math_ops.add_n(weight_sums)
            else:
                scaled_losses = []
                scaled_weight_sums = []
                for loss, weight_sum, head_weight in zip(
                        sum_losses, weight_sums, head_weights):
                    scaled_losses.append(math_ops.multiply(loss, head_weight))
                    scaled_weight_sums.append(
                        math_ops.multiply(weight_sum, head_weight))
                merged_weighted_sum_loss = math_ops.add_n(scaled_losses)
                merged_example_weight_sum = math_ops.add_n(scaled_weight_sums)

        return head_lib.LossSpec(
            weighted_sum_loss=merged_weighted_sum_loss,
            example_weight_sum=merged_example_weight_sum,
            processed_labels=labels_by_head)