def normalize_weights(self, labels, weights):
  """See `_RankingLoss`."""
  # `weights` is item-wise and is applied non-symmetrically to update
  # pairwise_weights as
  #   pairwise_weights(i, j) = w_i * pairwise_weights(i, j).
  # This effectively applies to all pairs with l_i > l_j. Note that it is
  # actually symmetric when `weights` are constant per list, i.e., listwise
  # weights.
  if weights is None:
    weights = 1.
  weights = tf.compat.v1.where(
      utils.is_label_valid(labels),
      tf.ones_like(labels) * weights, tf.zeros_like(labels))
  return tf.expand_dims(weights, axis=2)
def individual_weights(self, sorted_labels):
  """See `_LambdaWeight`."""
  with ops.name_scope(None, 'dcg_lambda_weight', (sorted_labels,)):
    sorted_labels = ops.convert_to_tensor(sorted_labels)
    sorted_labels = array_ops.where(
        utils.is_label_valid(sorted_labels), sorted_labels,
        array_ops.zeros_like(sorted_labels))
    gain = self._gain_fn(sorted_labels)
    if self._normalized:
      gain *= self._inverse_max_dcg(sorted_labels)
    rank_discount = self._rank_discount_fn(
        math_ops.to_float(
            math_ops.range(array_ops.shape(sorted_labels)[1]) + 1))
    return gain * rank_discount
def _infer_sizes(example_features, labels):
  """Infers batch_size, list_size, and is_valid based on inputs."""
  with tf.compat.v1.name_scope('infer_sizes'):
    if labels is not None:
      if isinstance(labels, dict):
        labels = next(six.itervalues(labels))
      batch_size, list_size = tf.unstack(tf.shape(input=labels))
      is_valid = utils.is_label_valid(labels)
    else:
      if not example_features:
        raise ValueError('`example_features` is empty.')
      # Infer batch_size and list_size from a feature.
      example_tensor_shape = tf.shape(
          input=next(six.itervalues(example_features)))
      batch_size = example_tensor_shape[0]
      list_size = example_tensor_shape[1]
      # Mark all entries as valid in case we don't have enough information.
      # TODO: Be smarter about inferring is_valid.
      is_valid = utils.is_label_valid(tf.ones([batch_size, list_size]))
  if batch_size is None or list_size is None:
    raise ValueError('Invalid batch_size=%s or list_size=%s' %
                     (batch_size, list_size))
  return batch_size, list_size, is_valid
def precompute(self, labels, logits, weights):
  """Precomputes Tensors for softmax cross entropy inputs."""
  is_valid = utils.is_label_valid(labels)
  ranks = _compute_ranks(logits, is_valid)
  # Reset the invalid labels to 0 and reset the invalid logits to a logit
  # with ~= 0 contribution in softmax.
  labels = tf.compat.v1.where(is_valid, labels, tf.zeros_like(labels))
  logits = tf.compat.v1.where(is_valid, logits,
                              tf.math.log(_EPSILON) * tf.ones_like(logits))
  if self._lambda_weight is not None and isinstance(self._lambda_weight,
                                                    DCGLambdaWeight):
    labels = self._lambda_weight.individual_weights(labels, ranks)
  if weights is not None:
    labels *= weights
  return labels, logits
def test_organize_valid_indices(self):
  tf.compat.v1.set_random_seed(1)
  labels = [[1.0, 0.0, -1.0], [-1.0, 1.0, 2.0]]
  is_valid = utils.is_label_valid(labels)
  shuffled_indices = utils.shuffle_valid_indices(is_valid)
  organized_indices = utils.organize_valid_indices(is_valid, shuffle=False)
  with tf.compat.v1.Session() as sess:
    shuffled_indices = sess.run(shuffled_indices)
    self.assertAllEqual(shuffled_indices,
                        [[[0, 1], [0, 0], [0, 2]], [[1, 1], [1, 2], [1, 0]]])
    organized_indices = sess.run(organized_indices)
    self.assertAllEqual(organized_indices,
                        [[[0, 0], [0, 1], [0, 2]], [[1, 1], [1, 2], [1, 0]]])
def individual_weights(self, labels, ranks):
  """See `_LambdaWeight`."""
  with tf.compat.v1.name_scope(name='dcg_lambda_weight'):
    _check_tensor_shapes([labels, ranks])
    labels = tf.convert_to_tensor(value=labels)
    labels = tf.compat.v1.where(
        utils.is_label_valid(labels), labels, tf.zeros_like(labels))
    gain = self._gain_fn(labels)
    if self._normalized:
      gain *= inverse_max_dcg(
          labels,
          gain_fn=self._gain_fn,
          rank_discount_fn=self._rank_discount_fn,
          topn=self._topn)
    rank_discount = self._rank_discount_fn(tf.cast(ranks, dtype=tf.float32))
    return gain * rank_discount
def compute_unreduced_loss(self, labels, logits):
  """See `_RankingLoss`."""
  is_valid = utils.is_label_valid(labels)
  labels = tf.compat.v1.where(is_valid, labels, tf.zeros_like(labels))
  logits = tf.compat.v1.where(
      is_valid, logits, -1e3 * tf.ones_like(logits) +
      tf.reduce_min(input_tensor=logits, axis=-1, keepdims=True))

  label_sum = tf.reduce_sum(input_tensor=labels, axis=1, keepdims=True)
  nonzero_mask = tf.greater(tf.reshape(label_sum, [-1]), 0.0)
  labels = tf.compat.v1.where(nonzero_mask, labels,
                              _EPSILON * tf.ones_like(labels))
  ranks = approx_ranks(logits, temperature=self._temperature)

  return -ndcg(labels, ranks), tf.reshape(
      tf.cast(nonzero_mask, dtype=tf.float32), [-1, 1])
def compute_unreduced_loss(self, labels, logits):
  """See `_RankingLoss`."""
  is_valid = utils.is_label_valid(labels)
  ranks = _compute_ranks(logits, is_valid)
  pairwise_labels, pairwise_logits = _pairwise_comparison(labels, logits)
  pairwise_weights = pairwise_labels
  if self._lambda_weight is not None:
    pairwise_weights *= self._lambda_weight.pair_weights(labels, ranks)
    # For LambdaLoss with relative rank difference, the scale of the loss
    # becomes much smaller when applying LambdaWeight. This affects training
    # and can make the optimal learning rate much larger. We use a heuristic
    # to scale it up to the same magnitude as the standard pairwise loss.
    pairwise_weights *= tf.cast(tf.shape(input=labels)[1], dtype=tf.float32)

  pairwise_weights = tf.stop_gradient(
      pairwise_weights, name='weights_stop_gradient')
  return self._pairwise_loss(pairwise_logits), pairwise_weights
def individual_weights(self, sorted_labels):
  """See `_LambdaWeight`."""
  with tf.name_scope(name='dcg_lambda_weight'):
    sorted_labels = tf.convert_to_tensor(value=sorted_labels)
    sorted_labels = tf.where(
        utils.is_label_valid(sorted_labels), sorted_labels,
        tf.zeros_like(sorted_labels))
    gain = self._gain_fn(sorted_labels)
    if self._normalized:
      gain *= utils.inverse_max_dcg(
          sorted_labels,
          gain_fn=self._gain_fn,
          rank_discount_fn=self._rank_discount_fn,
          topn=self._topn)
    rank_discount = self._rank_discount_fn(
        tf.cast(
            tf.range(tf.shape(input=sorted_labels)[1]) + 1,
            dtype=tf.float32))
    return gain * rank_discount
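# `utils.inverse_max_dcg`, used above for normalization, is not shown in this
# listing. A minimal sketch of what it is assumed to do, consistent with the
# call sites (assuming `tf` and `utils` are imported as in the surrounding
# snippets): sort labels in decreasing order, compute the ideal (maximum) DCG
# per list, and return its reciprocal, with 0 where the ideal DCG is 0. The
# name `inverse_max_dcg_sketch` and the body are assumptions, not the library
# source.
def inverse_max_dcg_sketch(labels,
                           gain_fn=lambda l: tf.pow(2.0, l) - 1.,
                           rank_discount_fn=lambda r: 1. / tf.math.log1p(r),
                           topn=None):
  # Ideal ordering: the labels sorted by themselves, optionally cut at topn.
  ideal_sorted_labels, = utils.sort_by_scores(labels, [labels], topn=topn)
  rank = tf.range(tf.shape(input=ideal_sorted_labels)[1]) + 1
  discounted_gain = gain_fn(ideal_sorted_labels) * rank_discount_fn(
      tf.cast(rank, dtype=tf.float32))
  # Shape [batch_size, 1]: the maximum achievable DCG per list.
  discounted_gain = tf.reduce_sum(
      input_tensor=discounted_gain, axis=1, keepdims=True)
  return tf.compat.v1.where(
      tf.greater(discounted_gain, 0.), 1. / discounted_gain,
      tf.zeros_like(discounted_gain))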
def compute_unreduced_loss(self, labels, logits):
  """See `_RankingLoss`."""
  alpha = self._params.get('alpha', 10.0)
  is_valid = utils.is_label_valid(labels)
  labels = tf.compat.v1.where(is_valid, labels, tf.zeros_like(labels))
  logits = tf.compat.v1.where(
      is_valid, logits, -1e3 * tf.ones_like(logits) +
      tf.math.reduce_min(input_tensor=logits, axis=-1, keepdims=True))

  label_sum = tf.math.reduce_sum(input_tensor=labels, axis=1, keepdims=True)
  nonzero_mask = tf.math.greater(tf.reshape(label_sum, [-1]), 0.0)
  labels = tf.compat.v1.where(nonzero_mask, labels,
                              _EPSILON * tf.ones_like(labels))

  rr = 1. / utils.approx_ranks(logits, alpha=alpha)
  rr = tf.math.reduce_sum(input_tensor=rr * labels, axis=-1, keepdims=True)
  mrr = rr / tf.math.reduce_sum(input_tensor=labels, axis=-1, keepdims=True)
  return -mrr, tf.reshape(tf.cast(nonzero_mask, dtype=tf.float32), [-1, 1])
def compute_unreduced_loss(labels, logits):
  """See `_RankingLoss`."""
  alpha = 10.0
  is_valid = utils.is_label_valid(labels)
  labels = tf.compat.v1.where(is_valid, labels, tf.zeros_like(labels))
  logits = tf.compat.v1.where(
      is_valid, logits, -1e3 * tf.ones_like(logits) +
      tf.reduce_min(input_tensor=logits, axis=-1, keepdims=True))

  label_sum = tf.reduce_sum(input_tensor=labels, axis=1, keepdims=True)
  nonzero_mask = tf.greater(tf.reshape(label_sum, [-1]), 0.0)
  labels = tf.compat.v1.where(nonzero_mask, labels,
                              _EPSILON * tf.ones_like(labels))

  gains = tf.pow(2., tf.cast(labels, dtype=tf.float32)) - 1.
  ranks = utils.approx_ranks(logits, alpha=alpha)
  discounts = 1. / tf.math.log1p(ranks)
  dcg = tf.reduce_sum(input_tensor=gains * discounts, axis=-1, keepdims=True)
  cost = -dcg * utils.inverse_max_dcg(labels)
  return cost, tf.reshape(tf.cast(nonzero_mask, dtype=tf.float32), [-1, 1])
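# `approx_ranks`, used in the ApproxNDCG/ApproxMRR losses above, is assumed
# to follow the ApproxNDCG construction (Qin et al., 2010): the rank of item
# i is approximated by summing sigmoids of scaled pairwise score differences.
# A minimal sketch with the `alpha` form used here (the `temperature` variant
# in another snippet divides by a temperature instead of multiplying by
# alpha); the name and body are assumptions, not the library source.
def approx_ranks_sketch(logits, alpha=10.):
  list_size = tf.shape(input=logits)[1]
  x = tf.tile(tf.expand_dims(logits, 2), [1, 1, list_size])
  y = tf.tile(tf.expand_dims(logits, 1), [1, list_size, 1])
  # pairs[b, i, j] ~= 1 when s_j > s_i, so summing over j counts how many
  # items outscore item i. The diagonal contributes sigmoid(0) = 0.5, and
  # adding another 0.5 yields a 1-based smooth rank:
  #   rank_i ~= 1 + sum_{j != i} sigmoid(alpha * (s_j - s_i)).
  pairs = tf.sigmoid(alpha * (y - x))
  return tf.reduce_sum(input_tensor=pairs, axis=-1) + .5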
def compute_unreduced_loss(self, labels, logits, weights):
  """See `_RankingLoss`."""
  # Check that labels are valid, i.e., >= 0.
  is_label_valid = utils.is_label_valid(labels)
  # Reset the invalid labels to 0 and reset the invalid logits to a logit
  # with ~= 0 contribution.
  labels = tf.where(is_label_valid, labels, tf.zeros_like(labels))
  logits = tf.where(is_label_valid, logits,
                    tf.math.log(_EPSILON) * tf.ones_like(logits))
  # Coefficient used when combining multiple losses.
  weights = 1.0 if weights is None else tf.convert_to_tensor(value=weights)
  weights = tf.squeeze(weights)

  # Shuffle labels and logits to add randomness to sort.
  shuffled_indices = utils.shuffle_valid_indices(is_label_valid, self._seed)
  shuffled_labels = tf.gather_nd(labels, shuffled_indices)
  shuffled_logits = tf.gather_nd(logits, shuffled_indices)

  # Sort by the labels to obtain the sorted logits and corresponding labels.
  sorted_labels, sorted_logits = utils.sort_by_scores(
      shuffled_labels, [shuffled_labels, shuffled_logits])

  # Row-wise max of the sorted logits, a tensor of shape [batch_size, 1].
  raw_max = tf.reduce_max(input_tensor=sorted_logits, axis=1, keepdims=True)
  sorted_logits = sorted_logits - raw_max
  # Reverse cumulative sum: the first position needs
  # exp(logits[0]) + ... + exp(logits[n]) in its denominator.
  sums = tf.cumsum(tf.exp(sorted_logits), axis=1, reverse=True)
  # Transform per the ListMLE loss function; see the formula for details.
  sums = tf.math.log(sums) - sorted_logits
  # The ListMLELambdaWeight branch is not used at the moment.
  if self._lambda_weight is not None and isinstance(self._lambda_weight,
                                                    ListMLELambdaWeight):
    sums *= self._lambda_weight.individual_weights(sorted_labels)
  # Sum over each row, giving shape [batch_size].
  negative_log_likelihood = tf.reduce_sum(input_tensor=sums, axis=1)
  return negative_log_likelihood, weights
def _prepare_and_validate_params(self, labels, predictions, weights, mask):
  """Prepares and validates the parameters.

  Args:
    labels: A `Tensor` of the same shape as `predictions`. A value >= 1 means
      a relevant example.
    predictions: A `Tensor` with shape [batch_size, list_size]. Each value is
      the ranking score of the corresponding example.
    weights: A `Tensor` of the same shape as `predictions` or
      [batch_size, 1]. The former case is per-example and the latter case is
      per-list.
    mask: A `Tensor` of the same shape as predictions indicating which
      entries are valid for computing the metric.

  Returns:
    (labels, predictions, weights, mask) ready to be used for metric
    calculation.
  """
  if any(
      isinstance(tensor, tf.RaggedTensor)
      for tensor in [labels, predictions, weights]):
    raise ValueError('labels, predictions and/or weights are ragged tensors, '
                     'use ragged=True to enable ragged support for metrics.')
  labels = tf.convert_to_tensor(value=labels)
  predictions = tf.convert_to_tensor(value=predictions)
  weights = 1.0 if weights is None else tf.convert_to_tensor(value=weights)
  example_weights = tf.ones_like(labels) * weights
  predictions.get_shape().assert_is_compatible_with(
      example_weights.get_shape())
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())
  predictions.get_shape().assert_has_rank(2)

  # All labels should be >= 0. Invalid entries are reset.
  if mask is None:
    mask = utils.is_label_valid(labels)
  labels = tf.compat.v1.where(mask, labels, tf.zeros_like(labels))
  predictions = tf.compat.v1.where(
      mask, predictions, -1e-6 * tf.ones_like(predictions) +
      tf.reduce_min(input_tensor=predictions, axis=1, keepdims=True))
  return labels, predictions, example_weights, mask
def _mean_squared_loss(
    labels,
    logits,
    weights=None,
    reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    name=None):
  """Computes the mean squared loss for a list.

  Given the labels of graded relevance l_i and the logits s_i, we calculate
  the squared error for each ith position and aggregate the per position
  losses.

  Args:
    labels: A `Tensor` of the same shape as `logits` representing graded
      relevance.
    logits: A `Tensor` with shape [batch_size, list_size]. Each value is the
      ranking score of the corresponding item.
    weights: A scalar, a `Tensor` with shape [batch_size, 1] for list-wise
      weights, or a `Tensor` with shape [batch_size, list_size] for item-wise
      weights.
    reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
      reduce training loss over batch.
    name: A string used as the name for this loss.

  Returns:
    An op for the mean squared error as a loss.
  """
  with tf.compat.v1.name_scope(name, 'mean_squared_loss',
                               (labels, logits, weights)):
    is_label_valid = tf.reshape(utils.is_label_valid(labels), [-1])
    weights = 1.0 if weights is None else tf.convert_to_tensor(value=weights)
    weights = tf.ones_like(labels) * weights
    label_vector, logit_vector, weight_vector = [
        tf.boolean_mask(tensor=tf.reshape(x, [-1]), mask=is_label_valid)
        for x in [labels, logits, weights]
    ]
    return tf.compat.v1.losses.mean_squared_error(
        label_vector,
        logit_vector,
        weights=weight_vector,
        reduction=reduction)
def compute_unreduced_loss(self, labels, logits):
  """See `_RankingLoss`."""
  is_valid = utils.is_label_valid(labels)
  labels = tf.compat.v1.where(is_valid, labels, tf.zeros_like(labels))
  logits = tf.compat.v1.where(
      is_valid, logits, -1e3 * tf.ones_like(logits) +
      tf.reduce_min(input_tensor=logits, axis=-1, keepdims=True))

  label_sum = tf.reduce_sum(input_tensor=labels, axis=1, keepdims=True)
  nonzero_mask = tf.greater(tf.reshape(label_sum, [-1]), 0.0)
  labels = tf.compat.v1.where(is_valid, labels, -1e3 * tf.ones_like(labels))

  # shape = [batch_size, list_size, list_size].
  true_perm = neural_sort(labels)
  smooth_perm = neural_sort(logits)
  losses = tf.compat.v1.nn.softmax_cross_entropy_with_logits_v2(
      labels=true_perm, logits=tf.math.log(1e-20 + smooth_perm), axis=2)
  # shape = [batch_size, list_size].
  losses = tf.reduce_mean(input_tensor=losses, axis=-1, keepdims=True)

  return losses, tf.reshape(tf.cast(nonzero_mask, dtype=tf.float32), [-1, 1])
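# `neural_sort` above is not shown in this listing. It presumably implements
# the NeuralSort relaxation of Grover et al. (2019), which maps scores to a
# row-stochastic [batch_size, n, n] matrix approximating the permutation that
# sorts the scores in decreasing order:
#   P_hat[i, :] = softmax(((n + 1 - 2i) * s - A_s @ 1) / temperature),
# where A_s[j, k] = |s_j - s_k|. A minimal sketch under that assumption (the
# library version likely also handles masking of invalid entries):
def neural_sort_sketch(logits, temperature=1.0):
  n = tf.shape(input=logits)[1]
  # A_s[b, j, k] = |s_j - s_k|, shape [batch_size, n, n].
  abs_pairwise_diffs = tf.abs(
      tf.expand_dims(logits, 2) - tf.expand_dims(logits, 1))
  # Row sums of A_s, shape [batch_size, n].
  row_sums = tf.reduce_sum(input_tensor=abs_pairwise_diffs, axis=-1)
  # Scaling (n + 1 - 2i) for rank positions i = 1..n, shape [n].
  scaling = tf.cast(n + 1 - 2 * (tf.range(n) + 1), dtype=tf.float32)
  # p[b, i, j] = (n + 1 - 2i) * s_j - sum_k |s_j - s_k|.
  p = (tf.expand_dims(scaling, 1) * tf.expand_dims(logits, 1) -
       tf.expand_dims(row_sums, 1))
  return tf.nn.softmax(p / temperature, axis=-1)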
def compute_unreduced_loss(self, labels, logits):
  """See `_RankingLoss`."""
  is_valid = utils.is_label_valid(labels)
  # Reset the invalid labels to 0 and reset the invalid logits to a logit
  # with ~= 0 contribution.
  labels = tf.compat.v1.where(is_valid, labels, tf.zeros_like(labels))
  logits = tf.compat.v1.where(is_valid, logits,
                              tf.math.log(_EPSILON) * tf.ones_like(logits))
  scores = tf.compat.v1.where(
      is_valid, labels,
      tf.reduce_min(input_tensor=labels, axis=1, keepdims=True) -
      1e-6 * tf.ones_like(labels))
  # Use a fixed ops-level seed so that the randomness is controlled by the
  # graph-level seed.
  sorted_labels, sorted_logits = utils.sort_by_scores(
      scores, [labels, logits], shuffle_ties=True, seed=37)

  raw_max = tf.reduce_max(input_tensor=sorted_logits, axis=1, keepdims=True)
  sorted_logits = sorted_logits - raw_max
  sums = tf.cumsum(tf.exp(sorted_logits), axis=1, reverse=True)
  sums = tf.math.log(sums) - sorted_logits

  if self._lambda_weight is not None and isinstance(self._lambda_weight,
                                                    ListMLELambdaWeight):
    batch_size, list_size = tf.unstack(tf.shape(input=sorted_labels))
    sums *= self._lambda_weight.individual_weights(
        sorted_labels,
        tf.tile(tf.expand_dims(tf.range(list_size) + 1, 0), [batch_size, 1]))

  negative_log_likelihood = tf.reduce_sum(
      input_tensor=sums, axis=1, keepdims=True)
  return negative_log_likelihood, 1.
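# A minimal NumPy check (hypothetical values, not part of the library) that
# the reverse-cumsum trick in the ListMLE snippets computes the Plackett-Luce
# negative log-likelihood
#   -log prod_i exp(s_i) / sum_{j >= i} exp(s_j)
# for a single list whose logits are already in ground-truth label order:
import numpy as np

s = np.array([2.0, 0.5, -1.0])  # Logits in ground-truth order.
# sums[i] = log(sum_{j >= i} exp(s_j)) - s_i, via a reverse cumulative sum.
sums = np.log(np.cumsum(np.exp(s)[::-1])[::-1]) - s
nll_trick = sums.sum()
# Direct computation of the Plackett-Luce negative log-likelihood.
nll_direct = -sum(
    s[i] - np.log(np.exp(s[i:]).sum()) for i in range(len(s)))
assert np.isclose(nll_trick, nll_direct)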
def _prepare_and_validate_params(labels, predictions, weights=None,
                                 topn=None):
  """Prepares and validates the parameters.

  Args:
    labels: A `Tensor` of the same shape as `predictions`. A value >= 1 means
      a relevant example.
    predictions: A `Tensor` with shape [batch_size, list_size]. Each value is
      the ranking score of the corresponding example.
    weights: A `Tensor` of the same shape as `predictions` or
      [batch_size, 1]. The former case is per-example and the latter case is
      per-list.
    topn: A cutoff for how many examples to consider for this metric.

  Returns:
    (labels, predictions, weights, topn) ready to be used for metric
    calculation.
  """
  labels = ops.convert_to_tensor(labels)
  predictions = ops.convert_to_tensor(predictions)
  weights = 1.0 if weights is None else ops.convert_to_tensor(weights)
  example_weights = array_ops.ones_like(labels) * weights
  predictions.get_shape().assert_is_compatible_with(
      example_weights.get_shape())
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())
  predictions.get_shape().assert_has_rank(2)
  if topn is None:
    topn = array_ops.shape(predictions)[1]

  # All labels should be >= 0. Invalid entries are reset.
  is_label_valid = utils.is_label_valid(labels)
  labels = array_ops.where(is_label_valid, labels,
                           array_ops.zeros_like(labels))
  predictions = array_ops.where(
      is_label_valid, predictions,
      -1e-6 * array_ops.ones_like(predictions) +
      math_ops.reduce_min(predictions, axis=1, keepdims=True))
  return labels, predictions, example_weights, topn
def _prepare_and_validate_params(self, labels, predictions, weights, mask):
  """Prepares and validates the parameters.

  Args:
    labels: A `Tensor` with shape [batch_size, list_size, subtopic_size]. A
      nonzero value means that the example covers the corresponding subtopic.
    predictions: A `Tensor` with shape [batch_size, list_size]. Each value is
      the ranking score of the corresponding example.
    weights: A `Tensor` of the same shape as `predictions` or
      [batch_size, 1]. The former case is per-example and the latter case is
      per-list.
    mask: A `Tensor` of the same shape as predictions indicating which
      entries are valid for computing the metric.

  Returns:
    A 4-tuple of (labels, predictions, weights, mask) ready to be used for
    metric calculation.
  """
  labels = tf.convert_to_tensor(value=labels)
  predictions = tf.convert_to_tensor(value=predictions)
  labels.get_shape().assert_has_rank(3)
  if mask is None:
    mask = utils.is_label_valid(labels)
  mask = tf.convert_to_tensor(value=mask)
  if mask.get_shape().rank == 3:
    mask = tf.reduce_any(mask, axis=2)
  predictions = tf.where(
      mask, predictions, -1e-6 * tf.ones_like(predictions) +
      tf.reduce_min(input_tensor=predictions, axis=1, keepdims=True))
  # All labels should be >= 0. Invalid entries are reset.
  labels = tf.where(
      tf.expand_dims(mask, axis=2), labels, tf.zeros_like(labels))
  weights = (
      tf.constant(1.0, dtype=tf.float32)
      if weights is None else tf.convert_to_tensor(value=weights))
  example_weights = tf.ones_like(predictions) * weights
  return labels, predictions, example_weights, mask
def test_is_label_valid(self):
  labels = [[1.0, 0.0, -1.0]]
  labels_validity = [[True, True, False]]
  with tf.compat.v1.Session() as sess:
    is_valid = sess.run(utils.is_label_valid(labels))
    self.assertAllEqual(is_valid, labels_validity)
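# For context, a minimal sketch of `utils.is_label_valid` consistent with the
# test above: an entry is valid iff its label is >= 0, so negative labels mark
# padding. The name `is_label_valid_sketch` and the body are an assumption
# about the helper (assuming `import tensorflow as tf`), not the library
# source.
def is_label_valid_sketch(labels):
  """Returns a boolean `Tensor`, True for entries whose label is >= 0."""
  labels = tf.convert_to_tensor(value=labels)
  return tf.greater_equal(labels, 0.)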
def _pairwise_comparison(sorted_labels, sorted_logits, sorted_weights,
                         lambda_weight=None):
  r"""Returns pairwise comparison `Tensor`s.

  Given a list of n items, the labels of graded relevance l_i and the logits
  s_i, we sort the items in a list based on s_i and obtain ranks r_i. We form
  n^2 pairs of items. For each pair, we have the following:

                        / 1 if l_i > l_j
  * `pairwise_labels` = |
                        \ 0 if l_i <= l_j

  * `pairwise_logits` = s_i - s_j

                         / 0              if l_i <= l_j,
  * `pairwise_weights` = | |l_i - l_j|    if lambda_weight is None,
                         \ lambda_weight  otherwise.

  `sorted_weights` is item-wise and is applied non-symmetrically to update
  pairwise_weights as
    pairwise_weights(i, j) = w_i * pairwise_weights(i, j).
  This effectively applies to all pairs with l_i > l_j. Note that it is
  actually symmetric when `sorted_weights` are constant per list, i.e.,
  listwise weights.

  Args:
    sorted_labels: A `Tensor` with shape [batch_size, list_size] of labels
      sorted.
    sorted_logits: A `Tensor` with shape [batch_size, list_size] of logits
      sorted.
    sorted_weights: A `Tensor` with shape [batch_size, list_size] of
      item-wise weights sorted.
    lambda_weight: A `_LambdaWeight` object.

  Returns:
    A tuple of (pairwise_labels, pairwise_logits, pairwise_weights) with each
    having the shape [batch_size, list_size, list_size].
  """
  # Compute the difference for all pairs in a list. The output is a Tensor
  # with shape [batch_size, list_size, list_size] where the entry [-1, i, j]
  # stores the information for pair (i, j).
  pairwise_label_diff = tf.expand_dims(sorted_labels, 2) - tf.expand_dims(
      sorted_labels, 1)
  pairwise_logits = tf.expand_dims(sorted_logits, 2) - tf.expand_dims(
      sorted_logits, 1)
  pairwise_labels = tf.cast(
      tf.greater(pairwise_label_diff, 0), dtype=tf.float32)
  is_label_valid = utils.is_label_valid(sorted_labels)
  valid_pair = tf.logical_and(
      tf.expand_dims(is_label_valid, 2), tf.expand_dims(is_label_valid, 1))
  # Only keep the case when l_i > l_j.
  pairwise_weights = pairwise_labels * tf.cast(valid_pair, dtype=tf.float32)
  # Apply the item-wise weights along l_i.
  pairwise_weights *= tf.expand_dims(sorted_weights, 2)
  if lambda_weight is not None:
    pairwise_weights *= lambda_weight.pair_weights(sorted_labels)
  pairwise_weights = tf.stop_gradient(
      pairwise_weights, name='weights_stop_gradient')
  return pairwise_labels, pairwise_logits, pairwise_weights
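# A small eager-mode usage example with hypothetical values, showing the
# shapes and the l_i > l_j masking for a single three-item list. The function
# expects inputs already sorted by score; unsorted values are used here only
# to illustrate the pairwise construction.
import tensorflow as tf

labels = tf.constant([[2.0, 0.0, 1.0]])
logits = tf.constant([[1.0, 3.0, 2.0]])
weights = tf.ones_like(labels)
pair_labels, pair_logits, pair_weights = _pairwise_comparison(
    labels, logits, weights)
# pair_labels[0] marks pairs with l_i > l_j:
#   [[0., 1., 1.],
#    [0., 0., 0.],
#    [0., 1., 0.]]
# pair_logits[0][i][j] == logits[0][i] - logits[0][j], and with unit weights
# and no lambda_weight, pair_weights equals pair_labels here.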
def _softmax_loss(
    labels,
    logits,
    weights=None,
    lambda_weight=None,
    reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    name=None):
  """Computes the softmax cross entropy for a list.

  Given the labels l_i and the logits s_i, we sort the examples and obtain
  ranks r_i. The standard softmax loss doesn't need r_i and is defined as
      -sum_i l_i * log(exp(s_i) / (exp(s_1) + ... + exp(s_n))).
  The `lambda_weight` re-weights examples based on l_i and r_i:
      -sum_i w(l_i, r_i) * log(exp(s_i) / (exp(s_1) + ... + exp(s_n))).
  See 'individual_weights' in 'DCGLambdaWeight' for how w(l_i, r_i) is
  computed.

  Args:
    labels: A `Tensor` of the same shape as `logits` representing graded
      relevance.
    logits: A `Tensor` with shape [batch_size, list_size]. Each value is the
      ranking score of the corresponding item.
    weights: A scalar, a `Tensor` with shape [batch_size, 1] for list-wise
      weights, or a `Tensor` with shape [batch_size, list_size] for item-wise
      weights.
    lambda_weight: A `DCGLambdaWeight` instance.
    reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
      reduce training loss over batch.
    name: A string used as the name for this loss.

  Returns:
    An op for the softmax cross entropy as a loss.
  """
  with tf.compat.v1.name_scope(name, 'softmax_loss',
                               (labels, logits, weights)):
    sorted_labels, sorted_logits, sorted_weights = _sort_and_normalize(
        labels, logits, weights)
    is_label_valid = utils.is_label_valid(sorted_labels)
    # Reset the invalid labels to 0 and reset the invalid logits to a logit
    # with ~= 0 contribution in softmax.
    sorted_labels = tf.where(is_label_valid, sorted_labels,
                             tf.zeros_like(sorted_labels))
    sorted_logits = tf.where(
        is_label_valid, sorted_logits,
        tf.math.log(_EPSILON) * tf.ones_like(sorted_logits))
    if lambda_weight is not None and isinstance(lambda_weight,
                                                DCGLambdaWeight):
      sorted_labels = lambda_weight.individual_weights(sorted_labels)
    sorted_labels *= sorted_weights
    label_sum = tf.reduce_sum(
        input_tensor=sorted_labels, axis=1, keepdims=True)
    nonzero_mask = tf.greater(tf.reshape(label_sum, [-1]), 0.0)
    label_sum, sorted_labels, sorted_logits = [
        tf.boolean_mask(tensor=x, mask=nonzero_mask)
        for x in [label_sum, sorted_labels, sorted_logits]
    ]
    return tf.compat.v1.losses.softmax_cross_entropy(
        sorted_labels / label_sum,
        sorted_logits,
        weights=tf.reshape(label_sum, [-1]),
        reduction=reduction)
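# A quick NumPy check (hypothetical values, not part of the library) of the
# identity the code above relies on: the listwise softmax loss
# -sum_i l_i * log softmax(s)_i equals a label-normalized cross entropy
# weighted by sum_i l_i, which is how `softmax_cross_entropy` is invoked.
import numpy as np

l = np.array([2.0, 0.0, 1.0])  # Labels with a nonzero sum.
s = np.array([1.0, 3.0, 2.0])  # Logits.
log_softmax = s - np.log(np.exp(s).sum())
direct = -(l * log_softmax).sum()
via_normalization = l.sum() * -((l / l.sum()) * log_softmax).sum()
assert np.isclose(direct, via_normalization)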
def _list_mle_loss(labels,
                   logits,
                   weights=None,
                   lambda_weight=None,
                   reduction=core_losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
                   name=None,
                   seed=None):
  """Computes the ListMLE loss [Xia et al. 2008] for a list.

  Given the labels of graded relevance l_i and the logits s_i, we calculate
  the ListMLE loss for the given list.

  The `lambda_weight` re-weights examples based on l_i and r_i. The
  recommended weighting scheme is the formulation presented in the
  "Position-Aware ListMLE" paper (Lan et al.) and is available via the
  create_p_list_mle_lambda_weight() factory function above.

  Args:
    labels: A `Tensor` of the same shape as `logits` representing graded
      relevance.
    logits: A `Tensor` with shape [batch_size, list_size]. Each value is the
      ranking score of the corresponding item.
    weights: A scalar, a `Tensor` with shape [batch_size, 1] for list-wise
      weights, or a `Tensor` with shape [batch_size, list_size] for item-wise
      weights.
    lambda_weight: A `DCGLambdaWeight` instance.
    reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
      reduce training loss over batch.
    name: A string used as the name for this loss.
    seed: A randomization seed used when shuffling ground truth permutations.

  Returns:
    An op for the ListMLE loss.
  """
  with ops.name_scope(name, 'list_mle_loss', (labels, logits, weights)):
    is_label_valid = utils.is_label_valid(labels)
    # Reset the invalid labels to 0 and reset the invalid logits to a logit
    # with ~= 0 contribution.
    labels = array_ops.where(is_label_valid, labels,
                             array_ops.zeros_like(labels))
    logits = array_ops.where(
        is_label_valid, logits,
        math_ops.log(_EPSILON) * array_ops.ones_like(logits))
    weights = 1.0 if weights is None else ops.convert_to_tensor(weights)
    weights = array_ops.squeeze(weights)

    # Shuffle labels and logits to add randomness to sort.
    shuffled_indices = utils.shuffle_valid_indices(is_label_valid, seed)
    shuffled_labels = array_ops.gather_nd(labels, shuffled_indices)
    shuffled_logits = array_ops.gather_nd(logits, shuffled_indices)

    sorted_labels, sorted_logits = utils.sort_by_scores(
        shuffled_labels, [shuffled_labels, shuffled_logits])

    raw_max = math_ops.reduce_max(sorted_logits, axis=1, keepdims=True)
    sorted_logits = sorted_logits - raw_max
    sums = math_ops.cumsum(math_ops.exp(sorted_logits), axis=1, reverse=True)
    sums = math_ops.log(sums) - sorted_logits

    if lambda_weight is not None and isinstance(lambda_weight,
                                                ListMLELambdaWeight):
      sums *= lambda_weight.individual_weights(sorted_labels)

    negative_log_likelihood = math_ops.reduce_sum(sums, 1)

    return core_losses.compute_weighted_loss(
        negative_log_likelihood, weights=weights, reduction=reduction)
def sample(self, labels, logits, weights=None):
  """Samples scores from Concrete(logits).

  Args:
    labels: A `Tensor` with shape [batch_size, list_size], same as `logits`,
      representing graded relevance. Or, in diversity tasks, a `Tensor` with
      shape [batch_size, list_size, subtopic_size]. Each value represents
      relevance to a subtopic: 1 for a relevant subtopic, 0 for irrelevant,
      and -1 for padding. When the actual subtopic number of a query is
      smaller than the `subtopic_size`, `labels` will be padded to
      `subtopic_size` with -1.
    logits: A `Tensor` with shape [batch_size, list_size]. Each value is the
      ranking score of the corresponding item.
    weights: A scalar, a `Tensor` with shape [batch_size, 1] for list-wise
      weights, or a `Tensor` with shape [batch_size, list_size] for item-wise
      weights. If None, the weight of a list in the mini-batch is set to the
      sum of the labels of the items in that list.

  Returns:
    A tuple of expanded labels, logits, and weights where the first dimension
    is now batch_size * sample_size. Logit Tensors are sampled from
    Concrete(logits) while labels and weights are simply tiled so the
    resulting Tensor has the updated dimensions.
  """
  with tf.compat.v1.name_scope(self._name, 'gumbel_softmax_sample',
                               (labels, logits, weights)):
    batch_size = tf.shape(input=labels)[0]
    list_size = tf.shape(input=labels)[1]

    # Expand labels.
    expanded_labels = tf.expand_dims(labels, 1)
    expanded_labels = tf.repeat(expanded_labels, [self._sample_size], axis=1)
    expanded_labels = utils.reshape_first_ndims(
        expanded_labels, 2, [batch_size * self._sample_size])

    # Sample logits from Concrete(logits).
    sampled_logits = tf.expand_dims(logits, 1)
    sampled_logits = tf.tile(sampled_logits, [1, self._sample_size, 1])
    sampled_logits += _sample_gumbel(
        [batch_size, self._sample_size, list_size], seed=self._seed)
    sampled_logits = tf.reshape(sampled_logits,
                                [batch_size * self._sample_size, list_size])

    is_label_valid = utils.is_label_valid(expanded_labels)
    if is_label_valid.shape.rank > 2:
      is_label_valid = tf.reduce_any(is_label_valid, axis=-1)
    sampled_logits = tf.compat.v1.where(
        is_label_valid, sampled_logits / self._temperature,
        tf.math.log(1e-20) * tf.ones_like(sampled_logits))
    sampled_logits = tf.math.log(tf.nn.softmax(sampled_logits) + 1e-20)

    expanded_weights = weights
    if expanded_weights is not None:
      true_fn = lambda: tf.expand_dims(
          tf.expand_dims(expanded_weights, 1), 1)
      false_fn = lambda: tf.expand_dims(expanded_weights, 1)
      expanded_weights = tf.cond(
          pred=tf.math.equal(tf.rank(expanded_weights), 1),
          true_fn=true_fn,
          false_fn=false_fn)
      expanded_weights = tf.tile(expanded_weights, [1, self._sample_size, 1])
      expanded_weights = tf.reshape(
          expanded_weights, [batch_size * self._sample_size, -1])

    return expanded_labels, sampled_logits, expanded_weights
def _groupwise_dnn_v2(features, labels, mode, params, config):
  """Defines the dnn for groupwise scoring functions."""
  with ops.name_scope('transform'):
    context_features, per_example_features = _call_transform_fn(
        features, mode)

  def _score_fn(context_features, group_features, reuse):
    with variable_scope.variable_scope('group_score', reuse=reuse):
      return group_score_fn(context_features, group_features, mode, params,
                            config)

  # Scatter/Gather per-example scores through groupwise comparison. Each
  # instance in a mini-batch will form a number of groups. Each group of
  # examples is scored by 'score_fn' and scores for individual examples are
  # accumulated over groups.
  with ops.name_scope('groupwise_dnn_v2'):
    with ops.name_scope('infer_sizes'):
      if labels is not None:
        batch_size, list_size = array_ops.unstack(array_ops.shape(labels))
        is_valid = utils.is_label_valid(labels)
      else:
        # Infer batch_size and list_size from a feature.
        example_tensor_shape = array_ops.shape(
            next(six.itervalues(per_example_features)))
        batch_size = example_tensor_shape[0]
        list_size = example_tensor_shape[1]
        is_valid = utils.is_label_valid(
            array_ops.ones([batch_size, list_size]))
    if batch_size is None or list_size is None:
      raise ValueError('Invalid batch_size=%s or list_size=%s' %
                       (batch_size, list_size))

    # For each example feature, assume the shape is [batch_size, list_size,
    # feature_size]; the groups are formed along the 2nd dim. Each group has
    # a 'group_size' number of indices in [0, list_size). Based on these
    # indices, we can gather the example feature into a sub-tensor for each
    # group. The total number of groups we have for a mini-batch is
    # batch_size * num_groups. Inside each group, we have a 'group_size'
    # number of examples.
    indices, mask = _form_group_indices_nd(
        is_valid, group_size, shuffle=(mode != model_fn.ModeKeys.PREDICT))
    num_groups = array_ops.shape(mask)[1]

    with ops.name_scope('group_features'):
      # For context features, we have shape [batch_size * num_groups, ...].
      large_batch_context_features = {}
      for name, value in six.iteritems(context_features):
        # [batch_size, 1, ...].
        value = array_ops.expand_dims(value, axis=1)
        # [batch_size, num_groups, ...].
        value = array_ops.gather(
            value, array_ops.zeros([num_groups], dtypes.int32), axis=1)
        # [batch_size * num_groups, ...].
        large_batch_context_features[name] = utils.reshape_first_ndims(
            value, 2, [batch_size * num_groups])

      # For example features, we have shape [batch_size * num_groups,
      # group_size, ...].
      large_batch_group_features = {}
      for name, value in six.iteritems(per_example_features):
        # [batch_size, num_groups, group_size, ...].
        value = array_ops.gather_nd(value, indices)
        # [batch_size * num_groups, group_size, ...].
        large_batch_group_features[name] = utils.reshape_first_ndims(
            value, 3, [batch_size * num_groups, group_size])

    # Do the inference and get scores for the large batch.
    # [batch_size * num_groups, group_size].
    scores = _score_fn(
        large_batch_context_features, large_batch_group_features,
        reuse=False)

    with ops.name_scope('accumulate_scores'):
      # [batch_size, num_groups, group_size].
      scores = array_ops.reshape(scores,
                                 [batch_size, num_groups, group_size])
      # Reset invalid scores to 0 based on mask.
      scores = array_ops.where(
          array_ops.gather(
              array_ops.expand_dims(mask, 2),
              array_ops.zeros([group_size], dtypes.int32),
              axis=2), scores, array_ops.zeros_like(scores))
      # [batch_size, list_size].
      list_scores = array_ops.scatter_nd(indices, scores,
                                         [batch_size, list_size])
      # Use average.
      list_scores /= math_ops.to_float(group_size)

  if mode == model_fn.ModeKeys.PREDICT:
    return list_scores
  else:
    features.update(context_features)
    features.update(per_example_features)
    return list_scores
def _gumbel_softmax_sample(labels,
                           logits,
                           weights=None,
                           name=None,
                           sample_size=8,
                           temperature=1.0,
                           seed=None):
  """Samples scores from Concrete(logits).

  Args:
    labels: A `Tensor` of the same shape as `logits` representing graded
      relevance.
    logits: A `Tensor` with shape [batch_size, list_size]. Each value is the
      ranking score of the corresponding item.
    weights: A scalar, a `Tensor` with shape [batch_size, 1] for list-wise
      weights, or a `Tensor` with shape [batch_size, list_size] for item-wise
      weights. If None, the weight of a list in the mini-batch is set to the
      sum of the labels of the items in that list.
    name: A string used as the name for this loss.
    sample_size: An integer representing the number of samples drawn from the
      Concrete distribution defined by scores.
    temperature: The Gumbel-Softmax temperature.
    seed: Seed for pseudo-random number generator.

  Returns:
    A tuple of expanded labels, logits, and weights where the first dimension
    is now batch_size * sample_size. Logit Tensors are sampled from
    Concrete(logits) while labels and weights are simply tiled so the
    resulting Tensor has the updated dimensions.
  """
  with tf.compat.v1.name_scope(name, 'gumbel_softmax_sample',
                               (labels, logits, weights)):
    batch_size = tf.shape(input=labels)[0]
    list_size = tf.shape(input=labels)[1]

    # Expand labels.
    expanded_labels = tf.expand_dims(labels, 1)
    expanded_labels = tf.tile(expanded_labels, [1, sample_size, 1])
    expanded_labels = tf.reshape(expanded_labels,
                                 [batch_size * sample_size, list_size])

    # Sample logits from Concrete(logits).
    sampled_logits = tf.expand_dims(logits, 1)
    sampled_logits = tf.tile(sampled_logits, [1, sample_size, 1])
    sampled_logits += _sample_gumbel(
        [batch_size, sample_size, list_size], seed=seed)
    sampled_logits = tf.reshape(sampled_logits,
                                [batch_size * sample_size, list_size])

    is_label_valid = utils.is_label_valid(expanded_labels)
    sampled_logits = tf.where(
        is_label_valid, sampled_logits / temperature,
        tf.math.log(1e-20) * tf.ones_like(sampled_logits))
    sampled_logits = tf.math.log(tf.nn.softmax(sampled_logits) + 1e-20)

    expanded_weights = weights
    if expanded_weights is not None:
      true_fn = lambda: tf.expand_dims(
          tf.expand_dims(expanded_weights, 1), 1)
      false_fn = lambda: tf.expand_dims(expanded_weights, 1)
      expanded_weights = tf.cond(
          pred=tf.math.equal(tf.rank(expanded_weights), 1),
          true_fn=true_fn,
          false_fn=false_fn)
      expanded_weights = tf.tile(expanded_weights, [1, sample_size, 1])
      expanded_weights = tf.reshape(expanded_weights,
                                    [batch_size * sample_size, -1])

    return expanded_labels, sampled_logits, expanded_weights
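# `_sample_gumbel` is referenced above but not shown in this listing. A
# minimal sketch under the standard Gumbel construction (an assumption, not
# the library source): Gumbel noise is -log(-log(U)) for U ~ Uniform(0, 1),
# with a small epsilon keeping the logs away from log(0).
def _sample_gumbel_sketch(shape, eps=1e-20, seed=None):
  u = tf.random.uniform(
      shape, minval=0, maxval=1, dtype=tf.float32, seed=seed)
  return -tf.math.log(-tf.math.log(u + eps) + eps)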