Пример #1
0
    def test_list_mle_loss(self):
        """Checks ListMLE loss, both unweighted and with per-list weights.

        Scores are logs of small integers (assuming `ln` is the natural log),
        so each expected term is a log of a simple ratio of 1, 2 and 3. Each
        list contributes the sum of those log ratios along its ranked order;
        the loss is the negated sum averaged over the three lists.

        The third list has all-zero labels, so its order is not determined
        by the labels; presumably it is fixed by the `seed=1` argument.
        """
        with tf.Graph().as_default():
            scores = [[0., ln(3), ln(2)],
                      [0., ln(2), ln(3)],
                      [0., ln(2), ln(3)]]
            labels = [[0., 2., 1.],
                      [1., 0., 2.],
                      [0., 0., 0.]]
            weights = [[2.], [1.], [1.]]

            # Log-likelihood of each list's ranked order.
            first_list = ln(3. / (3 + 2 + 1)) + ln(2. / (2 + 1)) + ln(1. / 1)
            second_list = ln(3. / (3 + 2 + 1)) + ln(1. / (1 + 2)) + ln(2. / 2)
            third_list = ln(3. / (3 + 2 + 1)) + ln(2. / (2 + 1)) + ln(1. / 1)

            expected_loss = -(first_list + second_list + third_list) / 3
            # The per-list weights scale each list's term; still divided by 3.
            expected_weighted_loss = -(
                2 * first_list + 1 * second_list + 1 * third_list) / 3

            with self.cached_session():
                loss = ranking_losses._list_mle_loss(labels, scores, seed=1)
                self.assertAlmostEqual(loss.eval(), expected_loss, places=5)

                weighted_loss = ranking_losses._list_mle_loss(
                    labels, scores, weights, seed=1)
                self.assertAlmostEqual(
                    weighted_loss.eval(), expected_weighted_loss, places=5)
Пример #2
0
 def test_list_mle_loss_tie(self):
   """Checks ListMLE loss on a single list whose labels are all tied.

   Every label is zero, so the order used by the loss cannot come from the
   labels. This test sets only a graph-level random seed and calls
   `_list_mle_loss` without an explicit `seed`. The expected value is the
   negated log-likelihood of the order (index 2, index 1, index 0).
   """
   with tf.Graph().as_default():
     # Graph-level seed only; no op-level `seed` is passed below.
     tf.compat.v1.set_random_seed(3)
     labels = [[0., 0., 0.]]
     scores = [[0., ln(2), ln(3)]]
     expected = -(ln(3. / (3 + 2 + 1)) + ln(2. / (2 + 1)) + ln(1. / 1))
     with self.cached_session():
       loss = ranking_losses._list_mle_loss(labels, scores)
       self.assertAlmostEqual(loss.eval(), expected, places=5)
Пример #3
0
  def test_list_mle_loss_lambda_weight(self):
    """Checks ListMLE loss with the p-ListMLE position lambda weight.

    The expected value scales each list's per-rank log-likelihood terms by
    3, 1 and 0 (ranks 1, 2, 3). Those are the weights this test expects
    `create_p_list_mle_lambda_weight(3)` to produce. The result is negated
    and averaged over the three lists.

    The body runs inside a fresh `tf.Graph()`. `Tensor.eval()` only works in
    graph mode; under TF2's default eager execution `cached_session()` does
    not provide a real session, and `.eval()` on an eager tensor raises.
    """
    with tf.Graph().as_default():
      scores = [[0., ln(3), ln(2)], [0., ln(2), ln(3)], [0., ln(2), ln(3)]]
      labels = [[0., 2., 1.], [1., 0., 2.], [0., 0., 0.]]

      lw = ranking_losses.create_p_list_mle_lambda_weight(3)

      # Weighted log-likelihood of each list's ranked order; the last-rank
      # term carries weight 0.
      first_list = (3 * ln(3. / (3 + 2 + 1)) + 1 * ln(2. / (2 + 1)) +
                    0 * ln(1. / 1))
      second_list = (3 * ln(3. / (3 + 2 + 1)) + 1 * ln(1. / (1 + 2)) +
                     0 * ln(2. / 2))
      # All-zero labels: order presumably fixed by `seed=1` below.
      third_list = (3 * ln(3. / (3 + 2 + 1)) + 1 * ln(2. / (2 + 1)) +
                    0 * ln(1. / 1))
      expected = -(first_list + second_list + third_list) / 3

      with self.cached_session():
        loss = ranking_losses._list_mle_loss(
            labels, scores, lambda_weight=lw, seed=1)
        self.assertAlmostEqual(loss.eval(), expected, places=5)