Example #1
def _prepare_and_validate_params(labels, predictions, weights=None, topn=None):
    """Prepares and validates the parameters.

  Args:
    labels: A `Tensor` of the same shape as `predictions`. A value >= 1 means a
      relevant example.
    predictions: A `Tensor` with shape [batch_size, list_size]. Each value is
      the ranking score of the corresponding example.
    weights: A `Tensor` of the same shape as `predictions` or [batch_size, 1]. The
      former case is per-example and the latter case is per-list.
    topn: A cutoff for how many examples to consider for this metric.

  Returns:
    (labels, predictions, weights, topn) ready to be used for metric
    calculation.
  """
    labels = ops.convert_to_tensor(labels)
    predictions = ops.convert_to_tensor(predictions)
    weights = 1.0 if weights is None else ops.convert_to_tensor(weights)
    example_weights = array_ops.ones_like(labels) * weights
    predictions.get_shape().assert_is_compatible_with(
        example_weights.get_shape())
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
    predictions.get_shape().assert_has_rank(2)
    if topn is None:
        topn = array_ops.shape(predictions)[1]

    # All labels should be >= 0. Invalid entries are reset.
    is_label_valid = utils.is_label_valid(labels)
    labels = array_ops.where(is_label_valid, labels,
                             array_ops.zeros_like(labels))
    predictions = array_ops.where(
        is_label_valid, predictions, -1e-6 * array_ops.ones_like(predictions) +
        math_ops.reduce_min(predictions, axis=1, keepdims=True))
    return labels, predictions, example_weights, topn
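A minimal NumPy sketch (my own, not library code) of the invalid-entry reset performed above: entries with label < 0 are treated as padding, their labels are zeroed, and their predictions are pushed just below the per-list minimum so they sort last.

import numpy as np

labels = np.array([[1., 0., -1.]])        # -1 marks an invalid (padded) entry
predictions = np.array([[0.2, 0.5, 0.9]])

is_valid = labels >= 0.
clean_labels = np.where(is_valid, labels, 0.)
# Invalid predictions become min(list) - 1e-6, mirroring the TF code above.
clean_predictions = np.where(
    is_valid, predictions,
    predictions.min(axis=1, keepdims=True) - 1e-6)
print(clean_labels)       # [[1. 0. 0.]]
print(clean_predictions)  # the invalid entry now sorts below every valid one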
Example #2
def _softmax_loss(labels,
                  logits,
                  weights=None,
                  lambda_weight=None,
                  reduction=core_losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
                  name=None):
  """Computes the softmax cross entropy for a list.

  Given the labels l_i and the logits s_i, we sort the examples and obtain ranks
  r_i. The standard softmax loss doesn't need r_i and is defined as
      -sum_i l_i * log(exp(s_i) / (exp(s_1) + ... + exp(s_n))).
  The `lambda_weight` re-weights examples based on l_i and r_i:
      -sum_i w(l_i, r_i) * log(exp(s_i) / (exp(s_1) + ... + exp(s_n))).
  See 'individual_weights' in 'DCGLambdaWeight' for how w(l_i, r_i) is computed.

  Args:
    labels: A `Tensor` of the same shape as `logits` representing graded
      relevance.
    logits: A `Tensor` with shape [batch_size, list_size]. Each value is the
      ranking score of the corresponding item.
    weights: A scalar, a `Tensor` with shape [batch_size, 1] for list-wise
      weights, or a `Tensor` with shape [batch_size, list_size] for item-wise
      weights.
    lambda_weight: A `DCGLambdaWeight` instance.
    reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
      reduce training loss over batch.
    name: A string used as the name for this loss.

  Returns:
    An op for the softmax cross entropy as a loss.
  """
  with ops.name_scope(name, 'softmax_loss', (labels, logits, weights)):
    sorted_labels, sorted_logits, sorted_weights = _sort_and_normalize(
        labels, logits, weights)
    is_label_valid = utils.is_label_valid(sorted_labels)
    # Reset the invalid labels to 0 and reset the invalid logits to a logit with
    # ~= 0 contribution in softmax.
    sorted_labels = array_ops.where(is_label_valid, sorted_labels,
                                    array_ops.zeros_like(sorted_labels))
    sorted_logits = array_ops.where(
        is_label_valid, sorted_logits,
        math_ops.log(_EPSILON) * array_ops.ones_like(sorted_logits))
    if lambda_weight is not None and isinstance(lambda_weight, DCGLambdaWeight):
      sorted_labels = lambda_weight.individual_weights(sorted_labels)
    sorted_labels *= sorted_weights
    label_sum = math_ops.reduce_sum(sorted_labels, 1, keepdims=True)
    nonzero_mask = math_ops.greater(array_ops.reshape(label_sum, [-1]), 0.0)
    label_sum, sorted_labels, sorted_logits = [
        array_ops.boolean_mask(x, nonzero_mask)
        for x in [label_sum, sorted_labels, sorted_logits]
    ]
    return core_losses.softmax_cross_entropy(
        sorted_labels / label_sum,
        sorted_logits,
        weights=array_ops.reshape(label_sum, [-1]),
        reduction=reduction)
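To see what is being computed, here is a hedged NumPy re-derivation (my own sketch, single list, no per-item weights) of the value returned above: labels are normalized into a distribution and used as soft targets in a cross entropy, and each list is re-weighted by its label sum, roughly matching the SUM_BY_NONZERO_WEIGHTS reduction. Lists whose label sum is zero are dropped by the mask in the real code.

import numpy as np

labels = np.array([[2., 1., 0.]])
logits = np.array([[1.0, 0.0, -1.0]])

z = logits - logits.max(axis=1, keepdims=True)          # stabilized logits
log_softmax = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
label_sum = labels.sum(axis=1, keepdims=True)
# -sum_i (l_i / sum_l) * log softmax(s)_i, re-weighted by sum_l, matching
# softmax_cross_entropy(labels / label_sum, ..., weights=label_sum).
per_list = -((labels / label_sum) * log_softmax).sum(axis=1)
loss = (per_list * label_sum[:, 0]).sum() / np.count_nonzero(label_sum)
print(loss)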
Example #3
def _get_valid_pairs_and_clean_labels(self, sorted_labels):
  """Returns a boolean Tensor for valid pairs and cleaned labels."""
  sorted_labels = ops.convert_to_tensor(sorted_labels)
  sorted_labels.get_shape().assert_has_rank(2)
  is_label_valid = utils.is_label_valid(sorted_labels)
  valid_pairs = math_ops.logical_and(
      array_ops.expand_dims(is_label_valid, 2),
      array_ops.expand_dims(is_label_valid, 1))
  sorted_labels = array_ops.where(is_label_valid, sorted_labels,
                                  array_ops.zeros_like(sorted_labels))
  return valid_pairs, sorted_labels
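The pair mask above is just a broadcast AND of the validity vector with itself. A small NumPy illustration (my own, with made-up labels):

import numpy as np

sorted_labels = np.array([[2., 0., -1.]])   # -1 is an invalid (padded) entry
is_valid = sorted_labels >= 0.
# [batch, i, 1] AND [batch, 1, j] -> [batch, i, j]
valid_pairs = is_valid[:, :, None] & is_valid[:, None, :]
print(valid_pairs[0])
# [[ True  True False]
#  [ True  True False]
#  [False False False]]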
Example #4
def individual_weights(self, sorted_labels):
  """See `_LambdaWeight`."""
  with ops.name_scope(None, 'dcg_lambda_weight', (sorted_labels,)):
    sorted_labels = ops.convert_to_tensor(sorted_labels)
    sorted_labels = array_ops.where(
        utils.is_label_valid(sorted_labels), sorted_labels,
        array_ops.zeros_like(sorted_labels))
    gain = self._gain_fn(sorted_labels)
    if self._normalized:
      gain *= self._inverse_max_dcg(sorted_labels)
    rank_discount = self._rank_discount_fn(
        math_ops.to_float(
            math_ops.range(array_ops.shape(sorted_labels)[1]) + 1))
    return gain * rank_discount
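Assuming the conventional DCG choices (gain 2^l - 1 and discount 1/log2(1 + rank); the actual `_gain_fn` and `_rank_discount_fn` are configurable), the per-item weight is a gain/discount product over positions 1..n. A NumPy sketch:

import numpy as np

sorted_labels = np.array([[3., 1., 0.]])
gain = 2.0 ** sorted_labels - 1.0                 # assumed gain_fn
ranks = np.arange(sorted_labels.shape[1]) + 1     # positions 1..n
rank_discount = 1.0 / np.log2(1.0 + ranks)        # assumed rank_discount_fn
print(gain * rank_discount)                       # per-item weights (unnormalized)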
Example #5
def _sigmoid_cross_entropy_loss(
    labels,
    logits,
    weights=None,
    reduction=core_losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    name=None):
  """Computes the sigmoid_cross_entropy loss for a list.

  Given the labels of graded relevance l_i and the logits s_i, we calculate
  the sigmoid cross entropy at each position and aggregate the per-position
  losses.

  Args:
    labels: A `Tensor` of the same shape as `logits` representing graded
      relevance.
    logits: A `Tensor` with shape [batch_size, list_size]. Each value is the
      ranking score of the corresponding item.
    weights: A scalar, a `Tensor` with shape [batch_size, 1] for list-wise
      weights, or a `Tensor` with shape [batch_size, list_size] for item-wise
      weights.
    reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
      reduce training loss over batch.
    name: A string used as the name for this loss.

  Returns:
    An op for the sigmoid cross entropy as a loss.
  """
  with ops.name_scope(name, 'sigmoid_cross_entropy_loss',
                      (labels, logits, weights)):
    is_label_valid = array_ops.reshape(utils.is_label_valid(labels), [-1])
    weights = 1.0 if weights is None else ops.convert_to_tensor(weights)
    weights = array_ops.ones_like(labels) * weights
    label_vector, logit_vector, weight_vector = [
        array_ops.boolean_mask(array_ops.reshape(x, [-1]), is_label_valid)
        for x in [labels, logits, weights]
    ]
    return core_losses.sigmoid_cross_entropy(
        label_vector, logit_vector, weights=weight_vector, reduction=reduction)
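Ignoring weights, the function flattens the batch, drops invalid entries, and averages an element-wise sigmoid cross entropy. A numerically stable NumPy equivalent (my own sketch):

import numpy as np

def sigmoid_xent(labels, logits):
  # max(x, 0) - x * z + log(1 + exp(-|x|)): the standard stable form.
  return (np.maximum(logits, 0) - logits * labels
          + np.log1p(np.exp(-np.abs(logits))))

labels = np.array([[1., 0., -1.]])      # -1 marks an invalid entry
logits = np.array([[2.0, -1.0, 0.3]])
valid = labels.reshape(-1) >= 0.
per_item = sigmoid_xent(labels.reshape(-1)[valid], logits.reshape(-1)[valid])
print(per_item.mean())                  # mean over valid items only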
Example #6
def _list_mle_loss(labels,
                   logits,
                   weights=None,
                   lambda_weight=None,
                   reduction=core_losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
                   name=None,
                   seed=None):
  """Computes the ListMLE loss [Xia et al.

  2008] for a list.

  Given the labels of graded relevance l_i and the logits s_i, we calculate
  the ListMLE loss for the given list.

  The `lambda_weight` re-weights examples based on l_i and r_i.
  The recommended weighting scheme is the formulation presented in the
  "Position-Aware ListMLE" paper (Lan et al.) and is available via the
  create_p_list_mle_lambda_weight() factory function.

  Args:
    labels: A `Tensor` of the same shape as `logits` representing graded
      relevance.
    logits: A `Tensor` with shape [batch_size, list_size]. Each value is the
      ranking score of the corresponding item.
    weights: A scalar, a `Tensor` with shape [batch_size, 1] for list-wise
      weights, or a `Tensor` with shape [batch_size, list_size] for item-wise
      weights.
    lambda_weight: A `ListMLELambdaWeight` instance.
    reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
      reduce training loss over batch.
    name: A string used as the name for this loss.
    seed: A randomization seed used when shuffling ground truth permutations.

  Returns:
    An op for the ListMLE loss.
  """
  with ops.name_scope(name, 'list_mle_loss', (labels, logits, weights)):
    is_label_valid = utils.is_label_valid(labels)
    # Reset the invalid labels to 0 and reset the invalid logits to a logit with
    # ~= 0 contribution.
    labels = array_ops.where(is_label_valid, labels,
                             array_ops.zeros_like(labels))
    logits = array_ops.where(
        is_label_valid, logits,
        math_ops.log(_EPSILON) * array_ops.ones_like(logits))
    weights = 1.0 if weights is None else ops.convert_to_tensor(weights)
    weights = array_ops.squeeze(weights)

    # Shuffle labels and logits to add randomness to sort.
    shuffled_indices = utils.shuffle_valid_indices(is_label_valid, seed)
    shuffled_labels = array_ops.gather_nd(labels, shuffled_indices)
    shuffled_logits = array_ops.gather_nd(logits, shuffled_indices)

    sorted_labels, sorted_logits = utils.sort_by_scores(
        shuffled_labels, [shuffled_labels, shuffled_logits])

    raw_max = math_ops.reduce_max(sorted_logits, axis=1, keepdims=True)
    sorted_logits = sorted_logits - raw_max
    sums = math_ops.cumsum(math_ops.exp(sorted_logits), axis=1, reverse=True)
    sums = math_ops.log(sums) - sorted_logits

    if lambda_weight is not None and isinstance(lambda_weight,
                                                ListMLELambdaWeight):
      sums *= lambda_weight.individual_weights(sorted_labels)

    negative_log_likelihood = math_ops.reduce_sum(sums, 1)

    return core_losses.compute_weighted_loss(
        negative_log_likelihood, weights=weights, reduction=reduction)
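The core of the loss is the reversed cumulative sum: after sorting by ground-truth labels, position i contributes log(sum_{j>=i} exp(s_j)) - s_i. A NumPy check of that step (my own sketch; single list, no lambda weight):

import numpy as np

sorted_logits = np.array([[2.0, 0.5, -1.0]])  # already sorted by labels
z = sorted_logits - sorted_logits.max(axis=1, keepdims=True)
# Reversed cumsum: entry i holds sum over j >= i of exp(z_j).
pool = np.cumsum(np.exp(z)[:, ::-1], axis=1)[:, ::-1]
nll = (np.log(pool) - z).sum(axis=1)          # negative log-likelihood
print(nll)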
Example #7
def _pairwise_comparison(sorted_labels,
                         sorted_logits,
                         sorted_weights,
                         lambda_weight=None):
  r"""Returns pairwise comparison `Tensor`s.

  Given a list of n items, the labels of graded relevance l_i and the logits
  s_i, we sort the items in a list based on s_i and obtain ranks r_i. We form
  n^2 pairs of items. For each pair, we have the following:

                        /
                        | 1   if l_i > l_j
  * `pairwise_labels` = |
                        | 0   if l_i <= l_j
                        \
  * `pairwise_logits` = s_i - s_j
                         /
                         | 0              if l_i <= l_j,
  * `pairwise_weights` = | |l_i - l_j|    if lambda_weight is None,
                         | lambda_weight  otherwise.
                         \

  The `sorted_weights` is item-wise and is applied non-symmetrically to update
  pairwise_weights as
    pairwise_weights(i, j) = w_i * pairwise_weights(i, j).
  This effectively applies to all pairs with l_i > l_j. Note that it is actually
  symmetric when `sorted_weights` are constant per list, i.e., listwise weights.

  Args:
    sorted_labels: A `Tensor` with shape [batch_size, list_size] of labels
      sorted.
    sorted_logits: A `Tensor` with shape [batch_size, list_size] of logits
      sorted.
    sorted_weights: A `Tensor` with shape [batch_size, list_size] of item-wise
      weights sorted.
    lambda_weight: A `_LambdaWeight` object.

  Returns:
    A tuple of (pairwise_labels, pairwise_logits, pairwise_weights) with each
    having the shape [batch_size, list_size, list_size].
  """
  # Compute the difference for all pairs in a list. The output is a Tensor with
  # shape [batch_size, list_size, list_size] where the entry [-1, i, j] stores
  # the information for pair (i, j).
  pairwise_label_diff = array_ops.expand_dims(
      sorted_labels, 2) - array_ops.expand_dims(sorted_labels, 1)
  pairwise_logits = array_ops.expand_dims(
      sorted_logits, 2) - array_ops.expand_dims(sorted_logits, 1)
  pairwise_labels = math_ops.to_float(math_ops.greater(pairwise_label_diff, 0))
  is_label_valid = utils.is_label_valid(sorted_labels)
  valid_pair = math_ops.logical_and(
      array_ops.expand_dims(is_label_valid, 2),
      array_ops.expand_dims(is_label_valid, 1))
  # Only keep the case when l_i > l_j.
  pairwise_weights = pairwise_labels * math_ops.to_float(valid_pair)
  # Apply the item-wise weights along l_i.
  pairwise_weights *= array_ops.expand_dims(sorted_weights, 2)
  if lambda_weight is not None:
    pairwise_weights *= lambda_weight.pair_weights(sorted_labels)
  else:
    pairwise_weights *= math_ops.abs(pairwise_label_diff)
  pairwise_weights = array_ops.stop_gradient(
      pairwise_weights, name='weights_stop_gradient')
  return pairwise_labels, pairwise_logits, pairwise_weights
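All three pairwise tensors come from the same broadcasting trick: expand to [batch, i, 1] and [batch, 1, j], then combine. A NumPy sketch (my own) of the lambda_weight=None branch:

import numpy as np

sorted_labels = np.array([[2., 1., 0.]])
sorted_logits = np.array([[1.5, 0.2, -0.3]])

label_diff = sorted_labels[:, :, None] - sorted_labels[:, None, :]
pairwise_logits = sorted_logits[:, :, None] - sorted_logits[:, None, :]
pairwise_labels = (label_diff > 0).astype(float)         # 1 iff l_i > l_j
pairwise_weights = pairwise_labels * np.abs(label_diff)  # |l_i - l_j| branch
print(pairwise_weights[0])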
Example #8
    def _groupwise_dnn_v2(features, labels, mode, params, config):
        """Defines the dnn for groupwise scoring functions."""
        with ops.name_scope('transform'):
            context_features, per_example_features = _call_transform_fn(
                features, mode, params)

        def _score_fn(context_features, group_features, reuse):
            with variable_scope.variable_scope('group_score', reuse=reuse):
                return group_score_fn(context_features, group_features, mode,
                                      params, config)

        # Scatter/Gather per-example scores through groupwise comparison. Each
        # instance in a mini-batch will form a number of groups. Each group of
        # examples is scored by `_score_fn` and the scores for individual
        # examples are accumulated over groups.
        with ops.name_scope('groupwise_dnn_v2'):
            with ops.name_scope('infer_sizes'):
                if labels is not None:
                    batch_size, list_size = array_ops.unstack(
                        array_ops.shape(labels))
                    is_valid = utils.is_label_valid(labels)
                else:
                    # Infer batch_size and list_size from a feature.
                    example_tensor_shape = array_ops.shape(
                        next(six.itervalues(per_example_features)))
                    batch_size = example_tensor_shape[0]
                    list_size = example_tensor_shape[1]
                    is_valid = utils.is_label_valid(
                        array_ops.ones([batch_size, list_size]))
            if batch_size is None or list_size is None:
                raise ValueError('Invalid batch_size=%s or list_size=%s' %
                                 (batch_size, list_size))

            # For each example feature, assume the shape is [batch_size, list_size,
            # feature_size], the groups are formed along the 2nd dim. Each group has a
            # 'group_size' number of indices in [0, list_size). Based on these
            # indices, we can gather the example feature into a sub-tensor for each
            # group. The total number of groups we have for a mini-batch is batch_size
            # * num_groups. Inside each group, we have a 'group_size' number of
            # examples.
            indices, mask = _form_group_indices_nd(
                is_valid,
                group_size,
                shuffle=(mode != model_fn.ModeKeys.PREDICT))
            num_groups = array_ops.shape(mask)[1]

            with ops.name_scope('group_features'):
                # For context features, we have shape [batch_size * num_groups, ...].
                large_batch_context_features = {}
                for name, value in six.iteritems(context_features):
                    # [batch_size, 1, ...].
                    value = array_ops.expand_dims(value, axis=1)
                    # [batch_size, num_groups, ...]: this gather is effectively a tile.
                    value = array_ops.gather(value,
                                             array_ops.zeros([num_groups],
                                                             dtypes.int32),
                                             axis=1)
                    # [batch_size * num_groups, ...]
                    large_batch_context_features[
                        name] = utils.reshape_first_ndims(
                            value, 2, [batch_size * num_groups])

                # For example feature, we have shape [batch_size * num_groups,
                # group_size, ...].
                large_batch_group_features = {}
                for name, value in six.iteritems(per_example_features):
                    # [batch_size, num_groups, group_size, ...].
                    value = array_ops.gather_nd(value, indices)
                    # [batch_size * num_groups, group_size, ...].
                    large_batch_group_features[
                        name] = utils.reshape_first_ndims(
                            value, 3, [batch_size * num_groups, group_size])

            # Do the inference and get scores for the large batch.
            # [batch_size * num_groups, group_size].
            scores = _score_fn(large_batch_context_features,
                               large_batch_group_features,
                               reuse=False)

            with ops.name_scope('accumulate_scores'):
                scores = array_ops.reshape(
                    scores, [batch_size, num_groups, group_size])
                # Reset invalid scores to 0 based on mask.
                scores = array_ops.where(
                    array_ops.gather(array_ops.expand_dims(mask, 2),
                                     array_ops.zeros([group_size],
                                                     dtypes.int32),
                                     axis=2), scores,
                    array_ops.zeros_like(scores))
                # Scatter the group scores back into shape [batch_size, list_size].
                list_scores = array_ops.scatter_nd(indices, scores,
                                                   [batch_size, list_size])
                # Use average.
                list_scores /= math_ops.to_float(group_size)

        if mode == model_fn.ModeKeys.PREDICT:
            return list_scores
        else:
            features.update(context_features)
            features.update(per_example_features)
            return list_scores
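The scatter/average step at the end is the subtle part: scores from overlapping groups accumulate per item and are then divided by group_size. A toy NumPy equivalent of `scatter_nd` (the group indices here are my own construction; the real ones come from `_form_group_indices_nd`), where every item appears in exactly group_size groups:

import numpy as np

batch_size, list_size, group_size = 1, 3, 2
# Three groups over a 3-item list: (0,1), (1,2), (2,0); shape [1, 3, 2, 2].
indices = np.array([[[[0, 0], [0, 1]],
                     [[0, 1], [0, 2]],
                     [[0, 2], [0, 0]]]])
scores = np.array([[[0.5, 0.2],      # group scores, shape [1, 3, 2]
                    [0.4, 0.9],
                    [0.7, 0.3]]])

list_scores = np.zeros((batch_size, list_size))
# Duplicate indices accumulate, just like scatter_nd.
np.add.at(list_scores, (indices[..., 0], indices[..., 1]), scores)
list_scores /= group_size            # average, as in the code above
print(list_scores)                   # [[0.4 0.3 0.8]]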