示例#1
0
  def test_sorted_ranks(self):
    """Checks `utils.sorted_ranks` on distinct scores and on tied scores."""
    with tf.Graph().as_default():
      # All scores are distinct here; the expected ranks are 1-based with the
      # highest score receiving rank 1.
      distinct_scores = [[1., 3., 2.]]
      with tf.compat.v1.Session() as sess:
        ranks_op = utils.sorted_ranks(distinct_scores, seed=1)
        self.assertAllEqual(sess.run(ranks_op), [[3, 1, 2]])

      # Pin the graph-level seed before building the ops that involve ties, so
      # the shuffled outcome asserted below is reproducible.
      tf.compat.v1.set_random_seed(3)
      tied_scores = [[1., 2., 1.]]
      with tf.compat.v1.Session() as sess:
        # Without shuffling, the two tied items are expected in input order.
        unshuffled_op = utils.sorted_ranks(
            tied_scores, shuffle_ties=False, seed=1)
        self.assertAllEqual(sess.run(unshuffled_op), [[2, 1, 3]])
        # With shuffling and these seeds, the tied items are expected swapped.
        shuffled_op = utils.sorted_ranks(
            tied_scores, shuffle_ties=True, seed=1)
        self.assertAllEqual(sess.run(shuffled_op), [[3, 1, 2]])
示例#2
0
def _compute_ranks(logits, is_valid):
    """Ranks each list's entries by logit, with invalid entries pushed down.

    Entries whose `is_valid` flag is false have their score replaced by a
    value slightly below the smallest logit of their list before the scores
    are handed to `utils.sorted_ranks`.

    Args:
      logits: A `Tensor` with shape [batch_size, list_size]. Each value is
        the ranking score of the corresponding item.
      is_valid: A `Tensor` of the same shape as `logits` representing the
        validity of each entry.

    Returns:
      The `ranks` Tensor produced by `utils.sorted_ranks`.
    """
    _check_tensor_shapes([logits, is_valid])

    # Build the replacement score for invalid entries: the per-list minimum
    # logit (taken over all entries of the list) lowered by a small offset.
    offset = -1e-6 * tf.ones_like(logits)
    list_min = tf.reduce_min(input_tensor=logits, axis=1, keepdims=True)
    invalid_fill = offset + list_min

    # Valid entries keep their logit; invalid ones take the fill value.
    masked_scores = tf.compat.v1.where(is_valid, logits, invalid_fill)
    return utils.sorted_ranks(masked_scores)