def normalize_weights(self, labels, weights):
  """See `_RankingLoss`."""
  if weights is None:
    weights = 1.
  return torch.where(
      utils.is_label_valid(labels),
      torch.ones_like(labels) * weights,
      torch.zeros_like(labels))
def _prepare_and_validate_params(labels, predictions, weights=None, topn=None):
  """Prepares and validates the parameters.

  Args:
    labels: A `Tensor` of the same shape as `predictions`. A value >= 1 means a
      relevant example.
    predictions: A `Tensor` with shape [batch_size, list_size]. Each value is
      the ranking score of the corresponding example.
    weights: A `Tensor` of the same shape as `predictions` or [batch_size, 1].
      The former case is per-example and the latter case is per-list.
    topn: A cutoff for how many examples to consider for this metric.

  Returns:
    (labels, predictions, weights, topn) ready to be used for metric
    calculation.
  """
  weights = 1.0 if weights is None else weights
  example_weights = torch.ones_like(labels) * weights
  _assert_is_compatible(predictions, example_weights)
  _assert_is_compatible(predictions, labels)
  _assert_has_rank(predictions, 2)
  if topn is None:
    topn = predictions.shape[1]

  # All labels should be >= 0. Invalid entries are reset. Predictions for
  # invalid entries are pushed just below the per-list minimum so that they
  # sort last.
  is_label_valid = utils.is_label_valid(labels)
  labels = torch.where(is_label_valid, labels, torch.zeros_like(labels))
  pred_min, _ = torch.min(predictions, dim=1, keepdim=True)
  predictions = torch.where(
      is_label_valid, predictions,
      pred_min - 1e-6 * torch.ones_like(predictions))
  return labels, predictions, example_weights, topn
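# Example: a minimal standalone sketch of the invalid-entry handling above,
# assuming (as in `utils.is_label_valid`) that labels < 0 mark padded entries.
# Padded labels are reset to 0 and their predictions are pushed just below the
# per-list minimum so they sort last. Names here are illustrative only.
import torch

labels = torch.tensor([[1., -1., 2.]])
predictions = torch.tensor([[0.4, 0.9, 0.1]])
is_valid = torch.ge(labels, 0.)
labels = torch.where(is_valid, labels, torch.zeros_like(labels))
pred_min, _ = torch.min(predictions, dim=1, keepdim=True)
predictions = torch.where(is_valid, predictions,
                          pred_min - 1e-6 * torch.ones_like(predictions))
# labels      -> [[1., 0., 2.]]
# predictions -> [[0.4, 0.1 - 1e-6, 0.1]]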
def compute_unreduced_loss(self, labels, logits):
  """See `_RankingLoss`."""
  is_valid = utils.is_label_valid(labels)
  # Reset the invalid labels to 0 and reset the invalid logits to a logit with
  # ~= 0 contribution.
  labels = torch.where(is_valid, labels.float(), torch.zeros_like(labels))
  logits = torch.where(is_valid, logits,
                       math.log(_EPSILON) * torch.ones_like(logits))
  # Sort by labels, pushing invalid entries just below the per-list minimum so
  # that they sort last.
  labels_min, _ = torch.min(labels, dim=1, keepdim=True)
  scores = torch.where(is_valid, labels,
                       labels_min - 1e-6 * torch.ones_like(labels))
  sorted_labels, sorted_logits = utils.sort_by_scores(scores, [labels, logits])
  # Subtract the per-list max logit for numerical stability.
  raw_max, _ = torch.max(sorted_logits, dim=1, keepdim=True)
  sorted_logits = sorted_logits - raw_max
  sums = self._cumsum_reverse(torch.exp(sorted_logits), 1)
  sums = torch.log(sums) - sorted_logits
  if self._lambda_weight is not None and isinstance(self._lambda_weight,
                                                    ListMLELambdaWeight):
    batch_size, list_size = sorted_labels.shape
    ranks = torch.arange(list_size, device=sorted_labels.device) + 1
    sums *= self._lambda_weight.individual_weights(
        sorted_labels,
        torch.unsqueeze(ranks, 0).repeat(batch_size, 1))
  negative_log_likelihood = torch.sum(sums, dim=1, keepdim=True)
  return negative_log_likelihood, 1.
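# Example: a toy check of the reverse-cumsum identity used above. For logits
# already sorted into the target (label) order, the ListMLE negative
# log-likelihood under the Plackett-Luce model is
#   sum_i [ log(exp(s_i) + ... + exp(s_n)) - s_i ].
# This sketch recomputes it directly with torch.cumsum and torch.flip; the
# names are illustrative, not part of the class above.
import torch

sorted_logits = torch.tensor([[1.2, 0.4, -0.3]])
rev_cumsum = torch.flip(
    torch.cumsum(torch.flip(torch.exp(sorted_logits), [1]), dim=1), [1])
nll = torch.sum(torch.log(rev_cumsum) - sorted_logits, dim=1, keepdim=True)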
def _get_valid_pairs_and_clean_labels(labels):
  """Returns a boolean `Tensor` for valid pairs and cleaned labels."""
  assert labels.dim() >= 2
  is_valid = utils.is_label_valid(labels)
  valid_pairs = _apply_pairwise_op(utils.logical_and, is_valid)
  labels = torch.where(is_valid, labels, torch.zeros_like(labels))
  return valid_pairs, labels
def _pairwise_comparison(labels, logits):
  r"""Returns pairwise comparison `Tensor`s.

  Given a list of n items, the labels of graded relevance l_i and the logits
  s_i, we form n^2 pairs. For each pair, we have the following:

                          /
                          | 1   if l_i > l_j for valid l_i and l_j.
    * `pairwise_labels` = |
                          | 0   otherwise
                          \

    * `pairwise_logits` = s_i - s_j

  Args:
    labels: A `Tensor` with shape [batch_size, list_size].
    logits: A `Tensor` with shape [batch_size, list_size].

  Returns:
    A tuple of (pairwise_labels, pairwise_logits) with each having the shape
    [batch_size, list_size, list_size].
  """
  # Compute the difference for all pairs in a list. The output is a Tensor with
  # shape [batch_size, list_size, list_size] where the entry [-1, i, j] stores
  # the information for pair (i, j).
  pairwise_label_diff = _apply_pairwise_op(torch.sub, labels)
  pairwise_logits = _apply_pairwise_op(torch.sub, logits)
  # Only keep the case when l_i > l_j.
  pairwise_labels = torch.gt(pairwise_label_diff, 0).float()
  is_valid = utils.is_label_valid(labels)
  valid_pair = _apply_pairwise_op(utils.logical_and, is_valid)
  pairwise_labels *= valid_pair.float()
  return pairwise_labels, pairwise_logits
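# Example: a standalone sketch of the pairwise construction above, with
# `_apply_pairwise_op` expanded inline and labels of -1 treated as padding
# (matching the usual `utils.is_label_valid` convention).
import torch

labels = torch.tensor([[2., 0., 1., -1.]])
logits = torch.tensor([[1.5, 0.3, 0.9, 0.0]])
# Entry [b, i, j] holds the value for pair (i, j).
pairwise_label_diff = labels.unsqueeze(2) - labels.unsqueeze(1)
pairwise_logits = logits.unsqueeze(2) - logits.unsqueeze(1)
is_valid = torch.ge(labels, 0.)
valid_pair = is_valid.unsqueeze(2) & is_valid.unsqueeze(1)
pairwise_labels = torch.gt(pairwise_label_diff, 0).float() * valid_pair.float()
# pairwise_labels[0, i, j] == 1 exactly when l_i > l_j and both are valid.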
def _softmax_loss(labels,
                  logits,
                  weights=None,
                  lambda_weight=None,
                  reduction=core_losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
                  name=None):
  """Computes the softmax cross entropy for a list.

  Given the labels l_i and the logits s_i, we sort the examples and obtain
  ranks r_i. The standard softmax loss doesn't need r_i and is defined as
      -sum_i l_i * log(exp(s_i) / (exp(s_1) + ... + exp(s_n))).
  The `lambda_weight` re-weights examples based on l_i and r_i:
      -sum_i w(l_i, r_i) * log(exp(s_i) / (exp(s_1) + ... + exp(s_n))).
  See `individual_weights` in `DCGLambdaWeight` for how w(l_i, r_i) is
  computed.

  Args:
    labels: A `Tensor` of the same shape as `logits` representing graded
      relevance.
    logits: A `Tensor` with shape [batch_size, list_size]. Each value is the
      ranking score of the corresponding item.
    weights: A scalar, a `Tensor` with shape [batch_size, 1] for list-wise
      weights, or a `Tensor` with shape [batch_size, list_size] for item-wise
      weights.
    lambda_weight: A `DCGLambdaWeight` instance.
    reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
      reduce training loss over batch.
    name: A string used as the name for this loss.

  Returns:
    An op for the softmax cross entropy as a loss.
  """
  with ops.name_scope(name, 'softmax_loss', (labels, logits, weights)):
    sorted_labels, sorted_logits, sorted_weights = _sort_and_normalize(
        labels, logits, weights)
    is_label_valid = utils.is_label_valid(sorted_labels)
    # Reset the invalid labels to 0 and reset the invalid logits to a logit
    # with ~= 0 contribution in softmax.
    sorted_labels = array_ops.where(is_label_valid, sorted_labels,
                                    array_ops.zeros_like(sorted_labels))
    sorted_logits = array_ops.where(
        is_label_valid, sorted_logits,
        math_ops.log(_EPSILON) * array_ops.ones_like(sorted_logits))
    if lambda_weight is not None and isinstance(lambda_weight, DCGLambdaWeight):
      sorted_labels = lambda_weight.individual_weights(sorted_labels)
    sorted_labels *= sorted_weights
    label_sum = math_ops.reduce_sum(sorted_labels, 1, keepdims=True)
    # Drop lists whose labels sum to zero; they contribute no gradient.
    nonzero_mask = math_ops.greater(array_ops.reshape(label_sum, [-1]), 0.0)
    label_sum, sorted_labels, sorted_logits = [
        array_ops.boolean_mask(x, nonzero_mask)
        for x in [label_sum, sorted_labels, sorted_logits]
    ]
    return core_losses.softmax_cross_entropy(
        sorted_labels / label_sum,
        sorted_logits,
        weights=array_ops.reshape(label_sum, [-1]),
        reduction=reduction)
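# Example: a minimal numeric sketch (written with torch for brevity) of the
# listwise softmax cross entropy defined above:
#   loss = -sum_i (l_i / sum_j l_j) * log softmax(s)_i,
# weighted by the label sum, which mirrors the weights=label_sum call above.
import torch
import torch.nn.functional as F

labels = torch.tensor([[3., 0., 1.]])
logits = torch.tensor([[0.5, -1.0, 0.2]])
label_sum = labels.sum(dim=1, keepdim=True)
per_list = -((labels / label_sum) * F.log_softmax(logits, dim=1)).sum(dim=1)
weighted = per_list * label_sum.squeeze(1)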
def compute_unreduced_loss(self, labels, logits):
  """See `_RankingLoss`."""
  is_valid = utils.is_label_valid(labels)
  labels = torch.where(is_valid, labels.float(), torch.zeros_like(labels))
  logits = torch.where(is_valid, logits.float(), torch.zeros_like(logits))
  losses = (labels - logits)**2
  return losses, 1.
def _get_valid_pairs_and_clean_labels(self, sorted_labels):
  """Returns a boolean `Tensor` for valid pairs and cleaned labels."""
  sorted_labels = ops.convert_to_tensor(sorted_labels)
  sorted_labels.get_shape().assert_has_rank(2)
  is_label_valid = utils.is_label_valid(sorted_labels)
  valid_pairs = math_ops.logical_and(
      array_ops.expand_dims(is_label_valid, 2),
      array_ops.expand_dims(is_label_valid, 1))
  sorted_labels = array_ops.where(is_label_valid, sorted_labels,
                                  array_ops.zeros_like(sorted_labels))
  return valid_pairs, sorted_labels
def normalize_weights(self, labels, weights):
  """See `_RankingLoss`."""
  # The `weights` is item-wise and is applied non-symmetrically to update
  # pairwise_weights as
  #   pairwise_weights(i, j) = w_i * pairwise_weights(i, j).
  # This effectively applies to all pairs with l_i > l_j. Note that it is
  # actually symmetric when `weights` are constant per list, i.e., listwise
  # weights.
  if weights is None:
    weights = 1.
  weights = torch.where(utils.is_label_valid(labels),
                        torch.ones_like(labels) * weights,
                        torch.zeros_like(labels))
  return torch.unsqueeze(weights, dim=2)
def individual_weights(self, sorted_labels):
  """See `_LambdaWeight`."""
  with ops.name_scope(None, 'dcg_lambda_weight', (sorted_labels,)):
    sorted_labels = ops.convert_to_tensor(sorted_labels)
    sorted_labels = array_ops.where(
        utils.is_label_valid(sorted_labels), sorted_labels,
        array_ops.zeros_like(sorted_labels))
    gain = self._gain_fn(sorted_labels)
    if self._normalized:
      gain *= self._inverse_max_dcg(sorted_labels)
    rank_discount = self._rank_discount_fn(
        math_ops.to_float(
            math_ops.range(array_ops.shape(sorted_labels)[1]) + 1))
    return gain * rank_discount
def individual_weights(self, labels, ranks):
  """See `_LambdaWeight`."""
  _check_tensor_shapes([labels, ranks])
  labels = torch.where(utils.is_label_valid(labels), labels,
                       torch.zeros_like(labels))
  gain = self._gain_fn(labels)
  if self._normalized:
    gain *= utils.inverse_max_dcg(
        labels,
        gain_fn=self._gain_fn,
        rank_discount_fn=self._rank_discount_fn,
        topn=self._topn)
  rank_discount = self._rank_discount_fn(ranks.float())
  return gain * rank_discount
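# Example: a toy illustration of the per-item weight w(l_i, r_i) =
# gain(l_i) * rank_discount(r_i) computed above, assuming the common choices
# gain(l) = 2^l - 1 and rank_discount(r) = 1 / ln(1 + r). The actual
# `_gain_fn` and `_rank_discount_fn` are whatever the lambda weight was
# constructed with.
import torch

labels = torch.tensor([[2., 0., 1.]])
ranks = torch.tensor([[1., 2., 3.]])
weights = (torch.pow(2., labels) - 1.) * (1. / torch.log1p(ranks))
# -> [[3 / ln 2, 0, 1 / ln 4]]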
def precompute(self, labels, logits, weights):
  """Precomputes `Tensor`s for softmax cross entropy inputs."""
  is_valid = utils.is_label_valid(labels)
  ranks = _compute_ranks(logits, is_valid)
  # Reset the invalid labels to 0 and reset the invalid logits to a logit with
  # ~= 0 contribution in softmax.
  labels = torch.where(is_valid, labels, torch.zeros_like(labels))
  logits = torch.where(is_valid, logits,
                       math.log(_EPSILON) * torch.ones_like(logits))
  if self._lambda_weight is not None and isinstance(self._lambda_weight,
                                                    DCGLambdaWeight):
    labels = self._lambda_weight.individual_weights(labels, ranks)
  if weights is not None:
    labels *= weights
  return labels, logits
def compute_unreduced_loss(self, labels, logits):
  """See `_RankingLoss`."""
  is_valid = utils.is_label_valid(labels)
  ranks = _compute_ranks(logits, is_valid)
  pairwise_labels, pairwise_logits = _pairwise_comparison(labels, logits)
  pairwise_weights = pairwise_labels
  if self._lambda_weight is not None:
    pairwise_weights *= self._lambda_weight.pair_weights(labels, ranks)
    # For LambdaLoss with relative rank difference, the scale of the loss
    # becomes much smaller when applying LambdaWeight. This affects training
    # and can make the optimal learning rate much larger. We use a heuristic
    # to scale it up to the same magnitude as the standard pairwise loss.
    pairwise_weights *= float(labels.shape[1])
  # Equivalent of tf.stop_gradient: the weights are treated as constants
  # during backpropagation.
  pairwise_weights = pairwise_weights.clone().detach()
  return self._pairwise_loss(pairwise_logits), pairwise_weights
def compute_unreduced_loss(self, labels, logits):
  """See `_RankingLoss`."""
  alpha = self._params.get('alpha', 10.0)
  is_valid = utils.is_label_valid(labels)
  labels = torch.where(is_valid, labels, torch.zeros_like(labels))
  logits_min, _ = torch.min(logits, dim=-1, keepdim=True)
  logits = torch.where(is_valid, logits,
                       -1e3 * torch.ones_like(logits) + logits_min)
  label_sum = torch.sum(labels, dim=1, keepdim=True)
  nonzero_mask = torch.gt(label_sum, 0.0)
  labels = torch.where(nonzero_mask, labels,
                       _EPSILON * torch.ones_like(labels))
  rr = 1. / utils.approx_ranks(logits, alpha=alpha)
  rr = torch.sum(rr * labels, dim=-1, keepdim=True)
  mrr = rr / torch.sum(labels, dim=-1, keepdim=True)
  return -mrr, nonzero_mask.float()
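# Example: a sketch of the smooth rank approximation assumed to be implemented
# by `utils.approx_ranks`:
#   rank_i ~= 1 + sum_{j != i} sigmoid(alpha * (s_j - s_i)),
# which approaches the true rank as alpha grows. `approx_ranks_sketch` is an
# illustrative stand-in, not the library function.
import torch

def approx_ranks_sketch(logits, alpha=10.0):
  # diff[b, i, j] = s_j - s_i, shape [batch_size, n, n].
  diff = logits.unsqueeze(1) - logits.unsqueeze(2)
  pairs = torch.sigmoid(alpha * diff)
  # Remove the j == i diagonal term, sigmoid(0) = 0.5, then add 1.
  return torch.sum(pairs, dim=-1) - 0.5 + 1.

print(approx_ranks_sketch(torch.tensor([[2., 1., 3.]])))  # ~[[2., 3., 1.]]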
def compute_unreduced_loss(self, labels, logits):
  """See `_RankingLoss`."""
  alpha = self._params.get('alpha', 10.0)
  is_valid = utils.is_label_valid(labels)
  labels = torch.where(is_valid, labels, torch.zeros_like(labels))
  logits_min, _ = torch.min(logits, dim=-1, keepdim=True)
  logits = torch.where(is_valid, logits,
                       -1e3 * torch.ones_like(logits) + logits_min)
  label_sum = torch.sum(labels, dim=1, keepdim=True)
  nonzero_mask = torch.gt(label_sum, 0.0)
  labels = torch.where(nonzero_mask, labels,
                       _EPSILON * torch.ones_like(labels))
  gains = torch.pow(2., labels.float()) - 1.
  ranks = utils.approx_ranks(logits, alpha=alpha)
  discounts = 1. / torch.log1p(ranks)
  dcg = torch.sum(gains * discounts, dim=-1, keepdim=True)
  cost = -dcg * utils.inverse_max_dcg(labels)
  return cost, nonzero_mask.float()
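# Example: a toy check of the NDCG pieces used above, with gains 2^l - 1 and
# discounts 1 / ln(1 + rank); `utils.inverse_max_dcg` is assumed to return the
# reciprocal of the DCG of the ideal (label-sorted) ordering.
import torch

labels = torch.tensor([[3., 1., 0.]])
ranks = torch.tensor([[2., 1., 3.]])        # positions induced by some scores
gains = torch.pow(2., labels) - 1.
dcg = torch.sum(gains * (1. / torch.log1p(ranks)), dim=-1, keepdim=True)
ideal_ranks = torch.tensor([[1., 2., 3.]])  # labels are already descending
max_dcg = torch.sum(gains * (1. / torch.log1p(ideal_ranks)), dim=-1,
                    keepdim=True)
ndcg = dcg / max_dcg                        # what -cost reduces to here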
def _sigmoid_cross_entropy_loss(
    labels,
    logits,
    weights=None,
    reduction=core_losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    name=None):
  """Computes the sigmoid cross entropy loss for a list.

  Given the labels of graded relevance l_i and the logits s_i, we calculate
  the sigmoid cross entropy for each i-th position and aggregate the per
  position losses.

  Args:
    labels: A `Tensor` of the same shape as `logits` representing graded
      relevance.
    logits: A `Tensor` with shape [batch_size, list_size]. Each value is the
      ranking score of the corresponding item.
    weights: A scalar, a `Tensor` with shape [batch_size, 1] for list-wise
      weights, or a `Tensor` with shape [batch_size, list_size] for item-wise
      weights.
    reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
      reduce training loss over batch.
    name: A string used as the name for this loss.

  Returns:
    An op for the sigmoid cross entropy as a loss.
  """
  with ops.name_scope(name, 'sigmoid_cross_entropy_loss',
                      (labels, logits, weights)):
    is_label_valid = array_ops.reshape(utils.is_label_valid(labels), [-1])
    weights = 1.0 if weights is None else ops.convert_to_tensor(weights)
    weights = array_ops.ones_like(labels) * weights
    label_vector, logit_vector, weight_vector = [
        array_ops.boolean_mask(array_ops.reshape(x, [-1]), is_label_valid)
        for x in [labels, logits, weights]
    ]
    return core_losses.sigmoid_cross_entropy(
        label_vector,
        logit_vector,
        weights=weight_vector,
        reduction=reduction)
def _pairwise_comparison(sorted_labels,
                         sorted_logits,
                         sorted_weights,
                         lambda_weight=None):
  r"""Returns pairwise comparison `Tensor`s.

  Given a list of n items, the labels of graded relevance l_i and the logits
  s_i, we sort the items in a list based on s_i and obtain ranks r_i. We form
  n^2 pairs of items. For each pair, we have the following:

                           /
                           | 1   if l_i > l_j
    * `pairwise_labels` =  |
                           | 0   if l_i <= l_j
                           \

    * `pairwise_logits` = s_i - s_j

                           /
                           | 0               if l_i <= l_j,
    * `pairwise_weights` = | |l_i - l_j|     if lambda_weight is None,
                           | lambda_weight   otherwise.
                           \

  The `sorted_weights` is item-wise and is applied non-symmetrically to update
  pairwise_weights as
    pairwise_weights(i, j) = w_i * pairwise_weights(i, j).
  This effectively applies to all pairs with l_i > l_j. Note that it is
  actually symmetric when `sorted_weights` are constant per list, i.e.,
  listwise weights.

  Args:
    sorted_labels: A `Tensor` with shape [batch_size, list_size] of labels
      sorted.
    sorted_logits: A `Tensor` with shape [batch_size, list_size] of logits
      sorted.
    sorted_weights: A `Tensor` with shape [batch_size, list_size] of item-wise
      weights sorted.
    lambda_weight: A `_LambdaWeight` object.

  Returns:
    A tuple of (pairwise_labels, pairwise_logits, pairwise_weights) with each
    having the shape [batch_size, list_size, list_size].
  """
  # Compute the difference for all pairs in a list. The output is a Tensor with
  # shape [batch_size, list_size, list_size] where the entry [-1, i, j] stores
  # the information for pair (i, j).
  pairwise_label_diff = array_ops.expand_dims(
      sorted_labels, 2) - array_ops.expand_dims(sorted_labels, 1)
  pairwise_logits = array_ops.expand_dims(
      sorted_logits, 2) - array_ops.expand_dims(sorted_logits, 1)
  pairwise_labels = math_ops.to_float(math_ops.greater(pairwise_label_diff, 0))
  is_label_valid = utils.is_label_valid(sorted_labels)
  valid_pair = math_ops.logical_and(
      array_ops.expand_dims(is_label_valid, 2),
      array_ops.expand_dims(is_label_valid, 1))
  # Only keep the case when l_i > l_j.
  pairwise_weights = pairwise_labels * math_ops.to_float(valid_pair)
  # Apply the item-wise weights along l_i.
  pairwise_weights *= array_ops.expand_dims(sorted_weights, 2)
  if lambda_weight is not None:
    pairwise_weights *= lambda_weight.pair_weights(sorted_labels)
  else:
    pairwise_weights *= math_ops.abs(pairwise_label_diff)
  pairwise_weights = array_ops.stop_gradient(
      pairwise_weights, name='weights_stop_gradient')
  return pairwise_labels, pairwise_logits, pairwise_weights
def compute_unreduced_loss(self, labels, logits):
  """See `_RankingLoss`."""
  labels = torch.where(utils.is_label_valid(labels), labels,
                       torch.zeros_like(labels))
  losses = self._sigmoid_cross_entropy_with_logits(labels, logits)
  return losses, 1.
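# Example: a minimal sketch of the pointwise loss above, assuming
# `_sigmoid_cross_entropy_with_logits` behaves like torch's built-in binary
# cross entropy with logits.
import torch
import torch.nn.functional as F

labels = torch.tensor([[1., 0., 1.]])
logits = torch.tensor([[0.8, -0.5, 0.1]])
losses = F.binary_cross_entropy_with_logits(logits, labels, reduction='none')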
def _list_mle_loss(labels,
                   logits,
                   weights=None,
                   lambda_weight=None,
                   reduction=core_losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
                   name=None,
                   seed=None):
  """Computes the ListMLE loss [Xia et al. 2008] for a list.

  Given the labels of graded relevance l_i and the logits s_i, we calculate
  the ListMLE loss for the given list. The `lambda_weight` re-weights examples
  based on l_i and r_i. The recommended weighting scheme is the formulation
  presented in the "Position-Aware ListMLE" paper (Lan et al.) and is
  available via the create_p_list_mle_lambda_weight() factory function above.

  Args:
    labels: A `Tensor` of the same shape as `logits` representing graded
      relevance.
    logits: A `Tensor` with shape [batch_size, list_size]. Each value is the
      ranking score of the corresponding item.
    weights: A scalar, a `Tensor` with shape [batch_size, 1] for list-wise
      weights, or a `Tensor` with shape [batch_size, list_size] for item-wise
      weights.
    lambda_weight: A `ListMLELambdaWeight` instance.
    reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
      reduce training loss over batch.
    name: A string used as the name for this loss.
    seed: A randomization seed used when shuffling ground truth permutations.

  Returns:
    An op for the ListMLE loss.
  """
  with ops.name_scope(name, 'list_mle_loss', (labels, logits, weights)):
    is_label_valid = utils.is_label_valid(labels)
    # Reset the invalid labels to 0 and reset the invalid logits to a logit
    # with ~= 0 contribution.
    labels = array_ops.where(is_label_valid, labels,
                             array_ops.zeros_like(labels))
    logits = array_ops.where(
        is_label_valid, logits,
        math_ops.log(_EPSILON) * array_ops.ones_like(logits))
    weights = 1.0 if weights is None else ops.convert_to_tensor(weights)
    weights = array_ops.squeeze(weights)

    # Shuffle labels and logits to add randomness to sort.
    shuffled_indices = utils.shuffle_valid_indices(is_label_valid, seed)
    shuffled_labels = array_ops.gather_nd(labels, shuffled_indices)
    shuffled_logits = array_ops.gather_nd(logits, shuffled_indices)

    sorted_labels, sorted_logits = utils.sort_by_scores(
        shuffled_labels, [shuffled_labels, shuffled_logits])

    # Subtract the per-list max logit for numerical stability.
    raw_max = math_ops.reduce_max(sorted_logits, axis=1, keepdims=True)
    sorted_logits = sorted_logits - raw_max
    sums = math_ops.cumsum(math_ops.exp(sorted_logits), axis=1, reverse=True)
    sums = math_ops.log(sums) - sorted_logits

    if lambda_weight is not None and isinstance(lambda_weight,
                                                ListMLELambdaWeight):
      sums *= lambda_weight.individual_weights(sorted_labels)

    negative_log_likelihood = math_ops.reduce_sum(sums, 1)

    return core_losses.compute_weighted_loss(
        negative_log_likelihood, weights=weights, reduction=reduction)