def compute_unreduced_loss(self, labels, logits, weights):
  """See `_RankingLoss`.

  Sorts each list by its labels (after a random shuffle so that ties are not
  resolved by input order) and accumulates, per position, the log of the
  reverse cumulative sum of `exp(logit)` minus the logit itself.

  Args:
    labels: Labels tensor; entries flagged invalid by `utils.is_label_valid`
      are replaced with 0.
    logits: Logits tensor; entries at invalid positions are replaced with
      `log(_EPSILON)`.
    weights: `None` (treated as 1.0) or a value convertible to a tensor. It is
      squeezed and returned, not applied to the loss here.

  Returns:
    A tuple `(negative_log_likelihood, weights)`, where the first element is
    the per-position terms summed along axis 1.
  """
  valid_mask = utils.is_label_valid(labels)

  # Neutralize invalid entries: label 0, and a logit of log(_EPSILON) so the
  # entry contributes approximately nothing after exponentiation.
  zero_labels = tf.zeros_like(labels)
  labels = tf.where(valid_mask, labels, zero_labels)
  masked_logit_fill = tf.math.log(_EPSILON) * tf.ones_like(logits)
  logits = tf.where(valid_mask, logits, masked_logit_fill)

  if weights is None:
    weights = 1.0
  else:
    weights = tf.convert_to_tensor(value=weights)
  weights = tf.squeeze(weights)

  # Randomly permute before sorting so the sort order has randomness.
  permutation = utils.shuffle_valid_indices(valid_mask, self._seed)
  permuted_labels = tf.gather_nd(labels, permutation)
  permuted_logits = tf.gather_nd(logits, permutation)
  sorted_labels, sorted_logits = utils.sort_by_scores(
      permuted_labels, [permuted_labels, permuted_logits])

  # Shift by the per-row maximum before exponentiating; the shift cancels in
  # `log(cumsum) - logit` below.
  row_max = tf.reduce_max(input_tensor=sorted_logits, axis=1, keepdims=True)
  sorted_logits = sorted_logits - row_max
  tail_exp_sums = tf.cumsum(tf.exp(sorted_logits), axis=1, reverse=True)
  per_position = tf.math.log(tail_exp_sums) - sorted_logits

  # Position weights are only applied for a ListMLELambdaWeight instance.
  has_list_mle_weight = (
      self._lambda_weight is not None and
      isinstance(self._lambda_weight, ListMLELambdaWeight))
  if has_list_mle_weight:
    per_position = per_position * self._lambda_weight.individual_weights(
        sorted_labels)

  negative_log_likelihood = tf.reduce_sum(input_tensor=per_position, axis=1)
  return negative_log_likelihood, weights
def test_shuffle_valid_indices(self):
  """Checks `utils.shuffle_valid_indices` output under graph seed 1.

  In the expected result the positions holding the negative labels
  (column 2 of row 0, column 0 of row 1) come last in each row.
  """
  random_seed.set_random_seed(1)
  labels = [[1.0, 0.0, -1.0], [-1.0, 1.0, 2.0]]
  expected_indices = [[[0, 1], [0, 0], [0, 2]], [[1, 1], [1, 2], [1, 0]]]

  # Keep op creation in this order: no op-level seed is passed, so the
  # expected values may depend on the graph state at creation time.
  validity = utils.is_label_valid(labels)
  indices_op = utils.shuffle_valid_indices(validity)

  with session.Session() as sess:
    actual_indices = sess.run(indices_op)
    self.assertAllEqual(actual_indices, expected_indices)
def test_organize_valid_indices(self):
  """Checks shuffled and non-shuffled index organization under graph seed 1.

  With `shuffle=False` the expected output keeps the remaining positions in
  their original order while the negative-label position still comes last.
  """
  tf.compat.v1.set_random_seed(1)
  labels = [[1.0, 0.0, -1.0], [-1.0, 1.0, 2.0]]
  expected_shuffled = [[[0, 1], [0, 0], [0, 2]], [[1, 1], [1, 2], [1, 0]]]
  expected_organized = [[[0, 0], [0, 1], [0, 2]], [[1, 1], [1, 2], [1, 0]]]

  # Op creation order is kept as-is: the shuffle op has no explicit op-level
  # seed, so its result may depend on when it is added to the graph.
  validity = utils.is_label_valid(labels)
  shuffled_op = utils.shuffle_valid_indices(validity)
  organized_op = utils.organize_valid_indices(validity, shuffle=False)

  with tf.compat.v1.Session() as sess:
    shuffled_result = sess.run(shuffled_op)
    self.assertAllEqual(shuffled_result, expected_shuffled)
    organized_result = sess.run(organized_op)
    self.assertAllEqual(organized_result, expected_organized)
def parse(self, serialized):
  """See `_RankingDataParser`.

  Decodes `serialized` into a serialized context, a serialized example list
  and per-list sizes, optionally shuffles the examples, aligns the list
  dimension to `self._list_size` (when set) by truncating or padding with
  empty strings, and parses both parts into a feature dictionary.

  Args:
    serialized: Input passed to `self._decode_as_serialized_example_list`.

  Returns:
    A dict of parsed features. Example features are reshaped so that their
    first dimension becomes `[batch_size, list_size]`; context features are
    added when a context feature spec is configured; the list sizes are added
    under `self._size_feature_name` when that name is set.
  """
  decoded = self._decode_as_serialized_example_list(serialized)
  serialized_context, serialized_list, sizes = decoded

  # Prefer the static batch size; fall back to the dynamic one otherwise.
  batch_size = serialized_context.get_shape().as_list()[0]
  if not batch_size:
    batch_size = tf.shape(input=serialized_list)[0]
  cur_list_size = tf.shape(input=serialized_list)[1]
  list_size = self._list_size

  if self._shuffle_examples:
    is_valid = tf.sequence_mask(sizes, cur_list_size)
    indices = utils.shuffle_valid_indices(is_valid, seed=self._seed)
    serialized_list = tf.gather_nd(serialized_list, indices)

  # Truncate or pad the list dimension so it matches `list_size`.
  if list_size:

    def _truncate():
      return tf.slice(serialized_list, [0, 0], [batch_size, list_size])

    def _pad():
      return tf.pad(
          tensor=serialized_list,
          paddings=[[0, 0], [0, list_size - cur_list_size]],
          constant_values="")

    serialized_list = tf.cond(
        pred=cur_list_size > list_size, true_fn=_truncate, false_fn=_pad)
    cur_list_size = list_size

  # Parse all examples as one flat batch, then restore the list dimension.
  flat_examples = tf.reshape(serialized_list, [-1])
  example_features = tf.compat.v1.io.parse_example(
      flat_examples, self._example_feature_spec)
  features = {
      name: utils.reshape_first_ndims(tensor, 1, [batch_size, cur_list_size])
      for name, tensor in six.iteritems(example_features)
  }

  if self._context_feature_spec:
    flat_context = tf.reshape(serialized_context, [batch_size])
    context_features = tf.compat.v1.io.parse_example(
        flat_context, self._context_feature_spec)
    features.update(context_features)

  # Expose the example list sizes as a feature, if requested.
  if self._size_feature_name:
    features[self._size_feature_name] = sizes
  return features
def compute_unreduced_loss(self, labels, logits, weights):
  """See `_RankingLoss`.

  Computes a per-list negative log-likelihood over the label-sorted logits
  (the ListMLE loss, per the original author's notes).

  Args:
    labels: Labels tensor; entries flagged invalid by `utils.is_label_valid`
      are replaced with 0.
    logits: Logits tensor; entries at invalid positions are replaced with
      `log(_EPSILON)`.
    weights: `None` (treated as 1.0) or a value convertible to a tensor. It is
      squeezed and returned alongside the loss, not applied here.

  Returns:
    A tuple `(negative_log_likelihood, weights)`; the first element is the
    per-position terms summed along axis 1.
  """
  # Flag which labels are valid (original note: valid means label >= 0;
  # confirm against `utils.is_label_valid`).
  valid_mask = utils.is_label_valid(labels)

  # Replace invalid labels with 0 and the matching logits with log(_EPSILON),
  # a logit with ~= 0 contribution.
  zero_labels = tf.zeros_like(labels)
  labels = tf.where(valid_mask, labels, zero_labels)
  masked_logit_fill = tf.math.log(_EPSILON) * tf.ones_like(logits)
  logits = tf.where(valid_mask, logits, masked_logit_fill)

  # Original note: `weights` is the coefficient used when several losses are
  # combined. It is only converted and squeezed in this method.
  if weights is None:
    weights = 1.0
  else:
    weights = tf.convert_to_tensor(value=weights)
  weights = tf.squeeze(weights)

  # Shuffle labels and logits together to add randomness to the sort.
  permutation = utils.shuffle_valid_indices(valid_mask, self._seed)
  permuted_labels = tf.gather_nd(labels, permutation)
  permuted_logits = tf.gather_nd(logits, permutation)

  # Sort by label, carrying the logits along in the same order.
  sorted_labels, sorted_logits = utils.sort_by_scores(
      permuted_labels, [permuted_labels, permuted_logits])

  # Per-row maximum of the sorted logits (keepdims, so one value per row),
  # subtracted before exponentiating.
  row_max = tf.reduce_max(input_tensor=sorted_logits, axis=1, keepdims=True)
  sorted_logits = sorted_logits - row_max

  # Reverse cumulative sum: position i holds the sum of exp(logit) over
  # positions i..end, i.e. the denominator for that position.
  tail_exp_sums = tf.cumsum(tf.exp(sorted_logits), axis=1, reverse=True)
  # log(denominator) - logit is the per-position negative log-probability.
  per_position = tf.math.log(tail_exp_sums) - sorted_logits

  # Optional position weights (original note: not used at the moment).
  has_list_mle_weight = (
      self._lambda_weight is not None and
      isinstance(self._lambda_weight, ListMLELambdaWeight))
  if has_list_mle_weight:
    per_position = per_position * self._lambda_weight.individual_weights(
        sorted_labels)

  # Sum across each row to get one value per list.
  negative_log_likelihood = tf.reduce_sum(input_tensor=per_position, axis=1)
  return negative_log_likelihood, weights
def _form_group_indices_nd(is_valid, group_size):
  """Forms the indices for groups for gather_nd or scatter_nd.

  Args:
    is_valid: A boolean `Tensor` for entry validity with shape [batch_size,
      list_size].
    group_size: A scalar int `Tensor` for the number of examples in a group.

  Returns:
    A tuple of Tensors (indices, mask). The first has shape [batch_size,
    num_groups, group_size, 2] and it can be used in gather_nd or scatter_nd
    for group features. The second has the shape of [batch_size, num_groups]
    with value True for valid groups.
  """
  with ops.name_scope(None, 'form_group_indices', (is_valid, group_size)):
    is_valid = ops.convert_to_tensor(is_valid)
    validity_shape = array_ops.shape(is_valid)
    batch_size, list_size = array_ops.unstack(validity_shape)
    valid_counts = math_ops.reduce_sum(math_ops.to_int32(is_valid), axis=1)
    window_indices, mask = _rolling_window_indices(list_size, group_size,
                                                   valid_counts)

    # Valid indices of the tensor are shuffled and put on the top.
    # [batch_size, list_size, 2]. A deterministic op-level seed is set mainly
    # for unit-test purposes; ideally this explicit seed would not be needed.
    shuffled_indices = utils.shuffle_valid_indices(is_valid, seed=87124)

    # Build gather_nd indices by pairing every rolling-window position with
    # the index of the batch row it belongs to.
    # [batch_size, num_groups, group_size, 2]
    list_positions = array_ops.expand_dims(window_indices, axis=3)
    batch_positions = (
        array_ops.reshape(math_ops.range(batch_size), [-1, 1, 1, 1]) *
        array_ops.ones_like(list_positions))
    gather_indices = array_ops.concat([batch_positions, list_positions], 3)

    indices = array_ops.gather_nd(shuffled_indices, gather_indices)
    return indices, mask
def _list_mle_loss(labels,
                   logits,
                   weights=None,
                   lambda_weight=None,
                   reduction=core_losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
                   name=None,
                   seed=None):
  """Computes the ListMLE loss [Xia et al. 2008] for a list.

  Given the labels of graded relevance l_i and the logits s_i, we calculate
  the ListMLE loss for the given list.

  The `lambda_weight` re-weights examples based on l_i and r_i. The
  recommended weighting scheme is the formulation presented in the
  "Position-Aware ListMLE" paper (Lan et al.), available through the
  create_p_list_mle_lambda_weight() factory function.

  Args:
    labels: A `Tensor` of the same shape as `logits` representing graded
      relevance.
    logits: A `Tensor` with shape [batch_size, list_size]. Each value is the
      ranking score of the corresponding item.
    weights: A scalar, a `Tensor` with shape [batch_size, 1] for list-wise
      weights, or a `Tensor` with shape [batch_size, list_size] for item-wise
      weights.
    lambda_weight: Optional weighting object. It is applied only when it is a
      `ListMLELambdaWeight` instance; any other value is ignored.
    reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
      reduce training loss over batch.
    name: A string used as the name for this loss.
    seed: A randomization seed used when shuffling ground truth permutations.

  Returns:
    An op for the ListMLE loss.
  """
  with ops.name_scope(name, 'list_mle_loss', (labels, logits, weights)):
    valid_mask = utils.is_label_valid(labels)

    # Neutralize invalid entries: label 0, and a logit of log(_EPSILON) so
    # the entry has ~= 0 contribution.
    zero_labels = array_ops.zeros_like(labels)
    labels = array_ops.where(valid_mask, labels, zero_labels)
    masked_logit_fill = math_ops.log(_EPSILON) * array_ops.ones_like(logits)
    logits = array_ops.where(valid_mask, logits, masked_logit_fill)

    if weights is None:
      weights = 1.0
    else:
      weights = ops.convert_to_tensor(weights)
    weights = array_ops.squeeze(weights)

    # Shuffle labels and logits to add randomness to sort.
    permutation = utils.shuffle_valid_indices(valid_mask, seed)
    permuted_labels = array_ops.gather_nd(labels, permutation)
    permuted_logits = array_ops.gather_nd(logits, permutation)
    sorted_labels, sorted_logits = utils.sort_by_scores(
        permuted_labels, [permuted_labels, permuted_logits])

    # Shift by the per-row maximum before exponentiating; the shift cancels
    # in `log(cumsum) - logit` below.
    row_max = math_ops.reduce_max(sorted_logits, axis=1, keepdims=True)
    sorted_logits = sorted_logits - row_max
    tail_exp_sums = math_ops.cumsum(
        math_ops.exp(sorted_logits), axis=1, reverse=True)
    per_position = math_ops.log(tail_exp_sums) - sorted_logits

    has_list_mle_weight = (
        lambda_weight is not None and
        isinstance(lambda_weight, ListMLELambdaWeight))
    if has_list_mle_weight:
      per_position = per_position * lambda_weight.individual_weights(
          sorted_labels)

    negative_log_likelihood = math_ops.reduce_sum(per_position, 1)
    return core_losses.compute_weighted_loss(
        negative_log_likelihood, weights=weights, reduction=reduction)