def test_sorted_ranks(self):
  """Checks `utils.sorted_ranks` with distinct scores and with tied scores."""
  with tf.Graph().as_default():
    # Distinct scores: the highest score (3.) gets rank 1 and the lowest (1.)
    # gets rank 3.
    scores = [[1., 3., 2.]]
    with tf.compat.v1.Session() as sess:
      ranks = sess.run(utils.sorted_ranks(scores, seed=1))
      self.assertAllEqual(ranks, [[3, 1, 2]])

    # Fix the graph-level seed; together with the op-level `seed=1` it pins the
    # tie-shuffled ranks asserted below (those expected values are specific to
    # these seeds).
    # NOTE(review): placed inside the `tf.Graph()` context so the seed applies
    # to the graph the ops below are built in -- confirm against the original
    # layout.
    tf.compat.v1.set_random_seed(3)
    # Entries 0 and 2 tie at 1.
    scores = [[1., 2., 1.]]
    with tf.compat.v1.Session() as sess:
      # Without shuffling, tied entries keep their original relative order.
      ranks = sess.run(utils.sorted_ranks(scores, shuffle_ties=False, seed=1))
      self.assertAllEqual(ranks, [[2, 1, 3]])
      # With shuffling, the tie may be broken either way; for these seeds the
      # later of the two tied entries is ranked first.
      ranks = sess.run(utils.sorted_ranks(scores, shuffle_ties=True, seed=1))
      self.assertAllEqual(ranks, [[3, 1, 2]])
def _compute_ranks(logits, is_valid):
  """Computes ranks by sorting valid logits.

  Invalid entries are given a score strictly below every finite logit in their
  row, so they are always ranked after all valid entries of that row.

  Args:
    logits: A `Tensor` with shape [batch_size, list_size]. Each value is the
      ranking score of the corresponding item.
    is_valid: A `Tensor` of the same shape as `logits` representing validity of
      each entry.

  Returns:
    The `ranks` Tensor, as computed by `utils.sorted_ranks` on the masked
    scores.
  """
  _check_tensor_shapes([logits, is_valid])
  # Only sort entries with is_valid = True. Invalid entries get a score strictly
  # below the row minimum. A fixed `row_min - 1e-6` offset is absorbed by float
  # rounding once |row_min| is large (e.g. >= 32 for float32 logits), which lets
  # invalid entries tie with -- and, with tie shuffling, outrank -- the lowest
  # valid entry. `2 * min(row_min, 0) - 1` is strictly smaller than any finite
  # row minimum; if it overflows to -inf it is still smaller.
  row_min = tf.reduce_min(input_tensor=logits, axis=1, keepdims=True)
  invalid_score = 2. * tf.minimum(row_min, 0.) - 1.
  # Broadcast the per-row score to the full [batch_size, list_size] shape, as
  # `tf.compat.v1.where` requires `x` and `y` to match.
  scores = tf.compat.v1.where(is_valid, logits,
                              invalid_score * tf.ones_like(logits))
  return utils.sorted_ranks(scores)