Example #1
 def normalize_weights(self, labels, weights):
     """See _RankingLoss."""
     if weights is None:
         weights = 1.
     return torch.where(utils.is_label_valid(labels),
                        torch.ones_like(labels) * weights,
                        torch.zeros_like(labels))
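
A minimal standalone sketch of the masking pattern above, assuming (as utils.is_label_valid appears to) that negative labels mark padded entries; the inputs are hypothetical:

import torch

labels = torch.tensor([[2., 0., -1.]])  # -1 marks a padded entry
weights = 0.5
normalized = torch.where(labels >= 0.,
                         torch.ones_like(labels) * weights,
                         torch.zeros_like(labels))
# normalized == tensor([[0.5000, 0.5000, 0.0000]])
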
Example #2
def _prepare_and_validate_params(labels, predictions, weights=None, topn=None):
    """Prepares and validates the parameters.

    Args:
      labels: A `Tensor` of the same shape as `predictions`. A value >= 1 means a
        relevant example.
      predictions: A `Tensor` with shape [batch_size, list_size]. Each value is
        the ranking score of the corresponding example.
      weights: A `Tensor` of the same shape as `predictions`, or with shape
        [batch_size, 1]. The former is per-example and the latter is per-list.
      topn: A cutoff for how many examples to consider for this metric.

    Returns:
      (labels, predictions, weights, topn) ready to be used for metric
      calculation.
    """
    weights = 1.0 if weights is None else weights
    example_weights = torch.ones_like(labels) * weights
    _assert_is_compatible(predictions, example_weights)
    _assert_is_compatible(predictions, labels)
    _assert_has_rank(predictions, 2)

    if topn is None:
        topn = predictions.shape[1]

    # All labels should be >= 0. Invalid entries are reset.
    is_label_valid = utils.is_label_valid(labels)
    labels = torch.where(is_label_valid, labels, torch.zeros_like(labels))
    pred_min, _ = torch.min(predictions, dim=1, keepdim=True)
    predictions = torch.where(
        is_label_valid, predictions, -1e-6 * torch.ones_like(predictions) + pred_min)
    return labels, predictions, example_weights, topn
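
A standalone sketch of the invalid-entry reset at the end of the function, again assuming negative labels mark padding (hypothetical inputs):

import torch

labels = torch.tensor([[1., 0., -1.]])
predictions = torch.tensor([[0.3, 0.1, 0.9]])
is_valid = labels >= 0.
pred_min, _ = torch.min(predictions, dim=1, keepdim=True)
predictions = torch.where(is_valid, predictions,
                          pred_min - 1e-6 * torch.ones_like(predictions))
# The padded item's score is pushed just below the list minimum, so it
# always sorts last.
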
Example #3
 def compute_unreduced_loss(self, labels, logits):
     """See `_RankingLoss`."""
     is_valid = utils.is_label_valid(labels)
     # Reset the invalid labels to 0 and reset the invalid logits to a logit with
     # ~= 0 contribution.
     labels = torch.where(is_valid, labels.float(),
                          torch.zeros_like(labels))
     logits = torch.where(is_valid, logits,
                          math.log(_EPSILON) * torch.ones_like(logits))
     labels_min, _ = torch.min(labels, dim=1, keepdim=True)
     scores = torch.where(is_valid, labels,
                          labels_min - 1e-6 * torch.ones_like(labels))
     sorted_labels, sorted_logits = utils.sort_by_scores(
         scores, [labels, logits])
     raw_max, _ = torch.max(sorted_logits, dim=1, keepdim=True)
     sorted_logits = sorted_logits - raw_max
     sums = self._cumsum_reverse(torch.exp(sorted_logits), 1)
     sums = torch.log(sums) - sorted_logits
     if self._lambda_weight is not None and isinstance(
             self._lambda_weight, ListMLELambdaWeight):
         batch_size, list_size = sorted_labels.shape
         sums *= self._lambda_weight.individual_weights(
             sorted_labels,
             torch.unsqueeze(torch.arange(list_size, device=sorted_labels.device) + 1,
                             0).repeat(batch_size, 1))
     negative_log_likelihood = torch.sum(sums, dim=1, keepdim=True)
     return negative_log_likelihood, 1.
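
The helper self._cumsum_reverse is not shown in these excerpts; a common way to implement a reverse cumulative sum is to flip, cumsum, and flip back, which is enough to reproduce the ListMLE log-likelihood term above (a sketch, not the class's actual method):

import torch

def cumsum_reverse(x, dim):
    # out[..., i] = x[..., i] + x[..., i+1] + ... + x[..., n-1]
    return torch.flip(torch.cumsum(torch.flip(x, [dim]), dim), [dim])

sorted_logits = torch.tensor([[2., 1., 0.]])
sums = cumsum_reverse(torch.exp(sorted_logits), 1)
nll = torch.sum(torch.log(sums) - sorted_logits, dim=1, keepdim=True)
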
Example #4
def _get_valid_pairs_and_clean_labels(labels):
    """Returns a boolean Tensor for valid pairs and cleaned labels."""
    assert labels.dim() >= 2
    is_valid = utils.is_label_valid(labels)
    valid_pairs = _apply_pairwise_op(utils.logical_and, is_valid)
    labels = torch.where(is_valid, labels, torch.zeros_like(labels))
    return valid_pairs, labels
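
_apply_pairwise_op is not defined in these excerpts. Given its call sites, a plausible definition broadcasts the op over all (i, j) pairs, mirroring the expand_dims pattern used in Example #17:

def _apply_pairwise_op(op, tensor):
    # [batch_size, list_size] -> [batch_size, list_size, list_size], where
    # entry [:, i, j] holds op(tensor[:, i], tensor[:, j]).
    return op(tensor.unsqueeze(2), tensor.unsqueeze(1))
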
Example #5
def _pairwise_comparison(labels, logits):
    r"""Returns pairwise comparison `Tensor`s.

    Given a list of n items, the labels of graded relevance l_i and the logits
    s_i, we form n^2 pairs. For each pair, we have the following:

                          /
                          | 1   if l_i > l_j for valid l_i and l_j.
    * `pairwise_labels` = |
                          | 0   otherwise
                          \
    * `pairwise_logits` = s_i - s_j

    Args:
      labels: A `Tensor` with shape [batch_size, list_size].
      logits: A `Tensor` with shape [batch_size, list_size].

    Returns:
      A tuple of (pairwise_labels, pairwise_logits), each with shape
      [batch_size, list_size, list_size].
    """
    # Compute the difference for all pairs in a list. The output is a Tensor with
    # shape [batch_size, list_size, list_size] where the entry [-1, i, j] stores
    # the information for pair (i, j).
    pairwise_label_diff = _apply_pairwise_op(torch.sub, labels)
    pairwise_logits = _apply_pairwise_op(torch.sub, logits)
    # Only keep the case when l_i > l_j.
    pairwise_labels = torch.gt(pairwise_label_diff, 0).float()
    is_valid = utils.is_label_valid(labels)
    valid_pair = _apply_pairwise_op(utils.logical_and, is_valid)
    pairwise_labels *= valid_pair.float()
    return pairwise_labels, pairwise_logits
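
A worked, self-contained instance of the pairwise construction above for a single three-item list (hypothetical values; -1 marks padding):

import torch

labels = torch.tensor([[3., 1., -1.]])
logits = torch.tensor([[0.5, 0.2, 0.9]])
pairwise_label_diff = labels.unsqueeze(2) - labels.unsqueeze(1)
pairwise_logits = logits.unsqueeze(2) - logits.unsqueeze(1)
pairwise_labels = (pairwise_label_diff > 0).float()
is_valid = labels >= 0.
valid_pair = is_valid.unsqueeze(2) & is_valid.unsqueeze(1)
pairwise_labels *= valid_pair.float()
# pairwise_labels[0, 0, 1] == 1 because l_0 > l_1; every pair involving the
# padded third item is zeroed out.
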
Example #6
def _softmax_loss(labels,
                  logits,
                  weights=None,
                  lambda_weight=None,
                  reduction=core_losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
                  name=None):
    """Computes the softmax cross entropy for a list.

  Given the labels l_i and the logits s_i, we sort the examples and obtain ranks
  r_i. The standard softmax loss doesn't need r_i and is defined as
      -sum_i l_i * log(exp(s_i) / (exp(s_1) + ... + exp(s_n))).
  The `lambda_weight` re-weights examples based on l_i and r_i:
      -sum_i w(l_i, r_i) * log(exp(s_i) / (exp(s_1) + ... + exp(s_n))).
  See 'individual_weights' in 'DCGLambdaWeight' for how w(l_i, r_i) is computed.

  Args:
    labels: A `Tensor` of the same shape as `logits` representing graded
      relevance.
    logits: A `Tensor` with shape [batch_size, list_size]. Each value is the
      ranking score of the corresponding item.
    weights: A scalar, a `Tensor` with shape [batch_size, 1] for list-wise
      weights, or a `Tensor` with shape [batch_size, list_size] for item-wise
      weights.
    lambda_weight: A `DCGLambdaWeight` instance.
    reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
      reduce training loss over batch.
    name: A string used as the name for this loss.

  Returns:
    An op for the softmax cross entropy as a loss.
  """
    with ops.name_scope(name, 'softmax_loss', (labels, logits, weights)):
        sorted_labels, sorted_logits, sorted_weights = _sort_and_normalize(
            labels, logits, weights)
        is_label_valid = utils.is_label_valid(sorted_labels)
        # Reset the invalid labels to 0 and reset the invalid logits to a logit with
        # ~= 0 contribution in softmax.
        sorted_labels = array_ops.where(is_label_valid, sorted_labels,
                                        array_ops.zeros_like(sorted_labels))
        sorted_logits = array_ops.where(
            is_label_valid, sorted_logits,
            math_ops.log(_EPSILON) * array_ops.ones_like(sorted_logits))
        if lambda_weight is not None and isinstance(lambda_weight,
                                                    DCGLambdaWeight):
            sorted_labels = lambda_weight.individual_weights(sorted_labels)
        sorted_labels *= sorted_weights
        label_sum = math_ops.reduce_sum(sorted_labels, 1, keepdims=True)
        nonzero_mask = math_ops.greater(array_ops.reshape(label_sum, [-1]),
                                        0.0)
        label_sum, sorted_labels, sorted_logits = [
            array_ops.boolean_mask(x, nonzero_mask)
            for x in [label_sum, sorted_labels, sorted_logits]
        ]
        return core_losses.softmax_cross_entropy(sorted_labels / label_sum,
                                                 sorted_logits,
                                                 weights=array_ops.reshape(
                                                     label_sum, [-1]),
                                                 reduction=reduction)
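
A compact numeric sketch of the softmax cross entropy computed above, written in PyTorch for brevity with hypothetical inputs; it mirrors the normalization by label_sum in the function:

import torch

labels = torch.tensor([[2., 1., 0.]])
logits = torch.tensor([[1.0, 0.5, -0.2]])
label_sum = labels.sum(dim=1, keepdim=True)
log_softmax = torch.log_softmax(logits, dim=1)
# Per-list loss: -sum_i (l_i / sum_j l_j) * log softmax(s)_i; the function
# above then weights each list by label_sum during reduction.
loss = -(labels / label_sum * log_softmax).sum(dim=1)
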
Example #7
 def compute_unreduced_loss(self, labels, logits):
     """See `_RankingLoss`."""
     is_valid = utils.is_label_valid(labels)
     labels = torch.where(is_valid, labels.float(),
                          torch.zeros_like(labels))
     logits = torch.where(is_valid, logits.float(),
                          torch.zeros_like(logits))
     losses = (labels - logits)**2
     return losses, 1.
Example #8
 def _get_valid_pairs_and_clean_labels(self, sorted_labels):
     """Returns a boolean Tensor for valid pairs and cleaned labels."""
     sorted_labels = ops.convert_to_tensor(sorted_labels)
     sorted_labels.get_shape().assert_has_rank(2)
     is_label_valid = utils.is_label_valid(sorted_labels)
     valid_pairs = math_ops.logical_and(
         array_ops.expand_dims(is_label_valid, 2),
         array_ops.expand_dims(is_label_valid, 1))
     sorted_labels = array_ops.where(is_label_valid, sorted_labels,
                                     array_ops.zeros_like(sorted_labels))
     return valid_pairs, sorted_labels
Example #9
 def normalize_weights(self, labels, weights):
     """See _RankingLoss."""
     # The `weights` is item-wise and is applied non-symmetrically to update
     # pairwise_weights as
     #   pairwise_weights(i, j) = w_i * pairwise_weights(i, j).
     # This effectively applies to all pairs with l_i > l_j. Note that it is
     # actually symmetric when `weights` are constant per list, i.e., listwise
     # weights.
     if weights is None:
         weights = 1.
     weights = torch.where(utils.is_label_valid(labels),
                           torch.ones_like(labels) * weights,
                           torch.zeros_like(labels))
     return torch.unsqueeze(weights, dim=2)
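
A small illustration of the broadcast described in the comment: the returned [batch_size, list_size, 1] weights scale pairwise weights along the i axis only, i.e. pairwise_weights(i, j) = w_i * pairwise_weights(i, j):

import torch

w = torch.tensor([[2., 1.]]).unsqueeze(2)  # shape [1, 2, 1], entries w_i
pw = torch.ones(1, 2, 2)                   # pairwise_weights(i, j)
print(w * pw)  # row i is scaled by w_i: [[[2., 2.], [1., 1.]]]
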
Example #10
 def individual_weights(self, sorted_labels):
     """See `_LambdaWeight`."""
     with ops.name_scope(None, 'dcg_lambda_weight', (sorted_labels, )):
         sorted_labels = ops.convert_to_tensor(sorted_labels)
         sorted_labels = array_ops.where(
             utils.is_label_valid(sorted_labels), sorted_labels,
             array_ops.zeros_like(sorted_labels))
         gain = self._gain_fn(sorted_labels)
         if self._normalized:
             gain *= self._inverse_max_dcg(sorted_labels)
         rank_discount = self._rank_discount_fn(
             math_ops.to_float(
                 math_ops.range(array_ops.shape(sorted_labels)[1]) + 1))
         return gain * rank_discount
Example #11
 def individual_weights(self, labels, ranks):
     """See `_LambdaWeight`."""
     _check_tensor_shapes([labels, ranks])
     labels = torch.where(utils.is_label_valid(labels), labels,
                          torch.zeros_like(labels))
     gain = self._gain_fn(labels)
     if self._normalized:
         gain *= utils.inverse_max_dcg(
             labels,
             gain_fn=self._gain_fn,
             rank_discount_fn=self._rank_discount_fn,
             topn=self._topn)
     rank_discount = self._rank_discount_fn(ranks.float())
     return gain * rank_discount
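
A numeric sketch of the gain/discount product, assuming the common DCG choices gain_fn(l) = 2^l - 1 and rank_discount_fn(r) = 1 / log2(1 + r); both functions are constructor parameters, so these particular forms are assumptions:

import torch

labels = torch.tensor([[3., 1., 0.]])
ranks = (torch.arange(labels.shape[1]) + 1).float()
gain = torch.pow(2., labels) - 1.
rank_discount = 1. / torch.log2(1. + ranks)
print(gain * rank_discount)  # per-item weights w(l_i, r_i)
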
Example #12
 def precompute(self, labels, logits, weights):
     """Precomputes Tensors for softmax cross entropy inputs."""
     is_valid = utils.is_label_valid(labels)
     ranks = _compute_ranks(logits, is_valid)
     # Reset the invalid labels to 0 and reset the invalid logits to a logit with
     # ~= 0 contribution in softmax.
     labels = torch.where(is_valid, labels, torch.zeros_like(labels))
     logits = torch.where(is_valid, logits,
                          math.log(_EPSILON) * torch.ones_like(logits))
     if self._lambda_weight is not None and isinstance(
             self._lambda_weight, DCGLambdaWeight):
         labels = self._lambda_weight.individual_weights(labels, ranks)
     if weights is not None:
         labels *= weights
     return labels, logits
Example #13
    def compute_unreduced_loss(self, labels, logits):
        """See `_RankingLoss`."""
        is_valid = utils.is_label_valid(labels)
        ranks = _compute_ranks(logits, is_valid)
        pairwise_labels, pairwise_logits = _pairwise_comparison(labels, logits)
        pairwise_weights = pairwise_labels
        if self._lambda_weight is not None:
            pairwise_weights *= self._lambda_weight.pair_weights(labels, ranks)
            # For LambdaLoss with relative rank difference, the scale of the
            # loss becomes much smaller when applying LambdaWeight. This affects
            # training and can make the optimal learning rate much larger. We
            # use a heuristic to scale it up to the same magnitude as the
            # standard pairwise loss.
            pairwise_weights *= float(labels.shape[1])

        # PyTorch equivalent of tf.stop_gradient(pairwise_weights).
        pairwise_weights = pairwise_weights.clone().detach()
        return self._pairwise_loss(pairwise_logits), pairwise_weights
Example #14
    def compute_unreduced_loss(self, labels, logits):
        """See `_RankingLoss`."""
        alpha = self._params.get('alpha', 10.0)
        is_valid = utils.is_label_valid(labels)
        labels = torch.where(is_valid, labels, torch.zeros_like(labels))
        logits_min, _ = torch.min(logits, dim=-1, keepdim=True)
        logits = torch.where(is_valid, logits,
                             -1e3 * torch.ones_like(logits) + logits_min)

        label_sum = torch.sum(labels, dim=1, keepdim=True)

        nonzero_mask = torch.gt(label_sum, 0.0)
        labels = torch.where(nonzero_mask, labels,
                             _EPSILON * torch.ones_like(labels))

        rr = 1. / utils.approx_ranks(logits, alpha=alpha)
        rr = torch.sum(rr * labels, dim=-1, keepdim=True)
        mrr = rr / torch.sum(labels, dim=-1, keepdim=True)
        return -mrr, nonzero_mask.float()
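
utils.approx_ranks is not shown in these excerpts; a sketch of the smooth rank approximation it is expected to compute, where each item's rank is 0.5 plus a sum of sigmoids of pairwise score differences (so the top-scored item gets a rank close to 1):

import torch

def approx_ranks(logits, alpha=10.):
    x = logits.unsqueeze(2)                 # s_i as [batch, list, 1]
    y = logits.unsqueeze(1)                 # s_j as [batch, 1, list]
    pairs = torch.sigmoid(alpha * (y - x))  # ~1 wherever s_j > s_i
    return torch.sum(pairs, dim=-1) + 0.5
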
Example #15
    def compute_unreduced_loss(self, labels, logits):
        """See `_RankingLoss`."""
        alpha = self._params.get('alpha', 10.0)
        is_valid = utils.is_label_valid(labels)
        labels = torch.where(is_valid, labels, torch.zeros_like(labels))
        logits_min, _ = torch.min(logits, dim=-1, keepdim=True)
        logits = torch.where(is_valid, logits,
                             -1e3 * torch.ones_like(logits) + logits_min)

        label_sum = torch.sum(labels, dim=1, keepdim=True)
        nonzero_mask = torch.gt(label_sum, 0.0)
        labels = torch.where(nonzero_mask, labels,
                             _EPSILON * torch.ones_like(labels))
        gains = torch.pow(2., labels.float()) - 1.
        ranks = utils.approx_ranks(logits, alpha=alpha)
        discounts = 1. / torch.log1p(ranks)
        dcg = torch.sum(gains * discounts, dim=-1, keepdim=True)
        cost = -dcg * utils.inverse_max_dcg(labels)
        return cost, nonzero_mask.float()
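
utils.inverse_max_dcg is likewise not shown; a sketch consistent with the gains and discounts used above (2^l - 1 gains, 1 / log(1 + rank) discounts over the ideal ordering):

import torch

def inverse_max_dcg(labels):
    ideal, _ = torch.sort(labels, dim=-1, descending=True)
    ranks = (torch.arange(ideal.shape[-1]) + 1).float()
    gains = torch.pow(2., ideal) - 1.
    discounts = 1. / torch.log1p(ranks)
    max_dcg = torch.sum(gains * discounts, dim=-1, keepdim=True)
    return torch.where(max_dcg > 0., 1. / max_dcg,
                       torch.zeros_like(max_dcg))
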
Example #16
def _sigmoid_cross_entropy_loss(
        labels,
        logits,
        weights=None,
        reduction=core_losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
        name=None):
    """Computes the sigmoid_cross_entropy loss for a list.

  Given the labels of graded relevance l_i and the logits s_i, we calculate
  the sigmoid cross entropy for each ith position and aggregate the per position
  losses.

  Args:
    labels: A `Tensor` of the same shape as `logits` representing graded
      relevance.
    logits: A `Tensor` with shape [batch_size, list_size]. Each value is the
      ranking score of the corresponding item.
    weights: A scalar, a `Tensor` with shape [batch_size, 1] for list-wise
      weights, or a `Tensor` with shape [batch_size, list_size] for item-wise
      weights.
    reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
      reduce training loss over batch.
    name: A string used as the name for this loss.

  Returns:
    An op for the sigmoid cross entropy as a loss.
  """
    with ops.name_scope(name, 'sigmoid_cross_entropy_loss',
                        (labels, logits, weights)):
        is_label_valid = array_ops.reshape(utils.is_label_valid(labels), [-1])
        weights = 1.0 if weights is None else ops.convert_to_tensor(weights)
        weights = array_ops.ones_like(labels) * weights
        label_vector, logit_vector, weight_vector = [
            array_ops.boolean_mask(array_ops.reshape(x, [-1]), is_label_valid)
            for x in [labels, logits, weights]
        ]
        return core_losses.sigmoid_cross_entropy(label_vector,
                                                 logit_vector,
                                                 weights=weight_vector,
                                                 reduction=reduction)
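
The same mask-then-flatten pattern in PyTorch, as a minimal sketch assuming negative labels mark padding; only valid entries contribute to the loss:

import torch
import torch.nn.functional as F

labels = torch.tensor([[1., 0., -1.]])
logits = torch.tensor([[0.8, -0.3, 0.1]])
mask = (labels >= 0.).reshape(-1)
loss = F.binary_cross_entropy_with_logits(logits.reshape(-1)[mask],
                                          labels.reshape(-1)[mask])
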
Example #17
def _pairwise_comparison(sorted_labels,
                         sorted_logits,
                         sorted_weights,
                         lambda_weight=None):
    r"""Returns pairwise comparison `Tensor`s.

  Given a list of n items, the labels of graded relevance l_i and the logits
  s_i, we sort the items in a list based on s_i and obtain ranks r_i. We form
  n^2 pairs of items. For each pair, we have the following:

                        /
                        | 1   if l_i > l_j
  * `pairwise_labels` = |
                        | 0   if l_i <= l_j
                        \
  * `pairwise_logits` = s_i - s_j
                         /
                         | 0              if l_i <= l_j,
  * `pairwise_weights` = | |l_i - l_j|    if lambda_weight is None,
                         | lambda_weight  otherwise.
                         \

  The `sorted_weights` is item-wise and is applied non-symmetrically to update
  pairwise_weights as
    pairwise_weights(i, j) = w_i * pairwise_weights(i, j).
  This effectively applies to all pairs with l_i > l_j. Note that it is actually
  symmetric when `sorted_weights` are constant per list, i.e., listwise weights.

  Args:
    sorted_labels: A `Tensor` with shape [batch_size, list_size] of sorted
      labels.
    sorted_logits: A `Tensor` with shape [batch_size, list_size] of sorted
      logits.
    sorted_weights: A `Tensor` with shape [batch_size, list_size] of sorted
      item-wise weights.
    lambda_weight: A `_LambdaWeight` object.

  Returns:
    A tuple of (pairwise_labels, pairwise_logits, pairwise_weights) with each
    having the shape [batch_size, list_size, list_size].
  """
    # Compute the difference for all pairs in a list. The output is a Tensor with
    # shape [batch_size, list_size, list_size] where the entry [-1, i, j] stores
    # the information for pair (i, j).
    pairwise_label_diff = array_ops.expand_dims(
        sorted_labels, 2) - array_ops.expand_dims(sorted_labels, 1)
    pairwise_logits = array_ops.expand_dims(
        sorted_logits, 2) - array_ops.expand_dims(sorted_logits, 1)
    pairwise_labels = math_ops.to_float(
        math_ops.greater(pairwise_label_diff, 0))
    is_label_valid = utils.is_label_valid(sorted_labels)
    valid_pair = math_ops.logical_and(array_ops.expand_dims(is_label_valid, 2),
                                      array_ops.expand_dims(is_label_valid, 1))
    # Only keep the case when l_i > l_j.
    pairwise_weights = pairwise_labels * math_ops.to_float(valid_pair)
    # Apply the item-wise weights along l_i.
    pairwise_weights *= array_ops.expand_dims(sorted_weights, 2)
    if lambda_weight is not None:
        pairwise_weights *= lambda_weight.pair_weights(sorted_labels)
    else:
        pairwise_weights *= math_ops.abs(pairwise_label_diff)
    pairwise_weights = array_ops.stop_gradient(pairwise_weights,
                                               name='weights_stop_gradient')
    return pairwise_labels, pairwise_logits, pairwise_weights
Example #18
 def compute_unreduced_loss(self, labels, logits):
     """See `_RankingLoss`."""
     labels = torch.where(utils.is_label_valid(labels), labels,
                          torch.zeros_like(labels))
     losses = self._sigmoid_cross_entropy_with_logits(labels, logits)
     return losses, 1.
Example #19
def _list_mle_loss(labels,
                   logits,
                   weights=None,
                   lambda_weight=None,
                   reduction=core_losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
                   name=None,
                   seed=None):
    """Computes the ListMLE loss [Xia et al.

  2008] for a list.

  Given the labels of graded relevance l_i and the logits s_i, we calculate
  the ListMLE loss for the given list.

  The `lambda_weight` re-weights examples based on l_i and r_i.
  The recommended weighting scheme is the formulation presented in the
  "Position-Aware ListMLE" paper (Lan et al.) and is available via the
  create_p_list_mle_lambda_weight() factory function above.

  Args:
    labels: A `Tensor` of the same shape as `logits` representing graded
      relevance.
    logits: A `Tensor` with shape [batch_size, list_size]. Each value is the
      ranking score of the corresponding item.
    weights: A scalar, a `Tensor` with shape [batch_size, 1] for list-wise
      weights, or a `Tensor` with shape [batch_size, list_size] for item-wise
      weights.
    lambda_weight: A `DCGLambdaWeight` instance.
    reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
      reduce training loss over batch.
    name: A string used as the name for this loss.
    seed: A randomization seed used when shuffling ground truth permutations.

  Returns:
    An op for the ListMLE loss.
  """
    with ops.name_scope(name, 'list_mle_loss', (labels, logits, weights)):
        is_label_valid = utils.is_label_valid(labels)
        # Reset the invalid labels to 0 and reset the invalid logits to a logit with
        # ~= 0 contribution.
        labels = array_ops.where(is_label_valid, labels,
                                 array_ops.zeros_like(labels))
        logits = array_ops.where(
            is_label_valid, logits,
            math_ops.log(_EPSILON) * array_ops.ones_like(logits))
        weights = 1.0 if weights is None else ops.convert_to_tensor(weights)
        weights = array_ops.squeeze(weights)

        # Shuffle labels and logits to add randomness to sort.
        shuffled_indices = utils.shuffle_valid_indices(is_label_valid, seed)
        shuffled_labels = array_ops.gather_nd(labels, shuffled_indices)
        shuffled_logits = array_ops.gather_nd(logits, shuffled_indices)

        sorted_labels, sorted_logits = utils.sort_by_scores(
            shuffled_labels, [shuffled_labels, shuffled_logits])

        raw_max = math_ops.reduce_max(sorted_logits, axis=1, keepdims=True)
        sorted_logits = sorted_logits - raw_max
        sums = math_ops.cumsum(math_ops.exp(sorted_logits),
                               axis=1,
                               reverse=True)
        sums = math_ops.log(sums) - sorted_logits

        if lambda_weight is not None and isinstance(lambda_weight,
                                                    ListMLELambdaWeight):
            sums *= lambda_weight.individual_weights(sorted_labels)

        negative_log_likelihood = math_ops.reduce_sum(sums, 1)

        return core_losses.compute_weighted_loss(negative_log_likelihood,
                                                 weights=weights,
                                                 reduction=reduction)