Пример #1
0
    def test_list_mle_loss_tie(self):
        """ListMLE on a single list whose two lowest labels are tied."""
        # Fix the global TF seed before the loss is built and evaluated. The
        # expected value below assumes one specific order for the tied items
        # 0 and 1 (presumably chosen by seeded tie-breaking inside the loss —
        # TODO confirm).
        tf.random.set_seed(1)
        logits = [[0., ln(2), ln(3)]]
        relevance = [[0., 0., 1.]]

        list_mle = losses.ListMLELoss()
        # Log-likelihood of visiting the items with exp-scores 3, 2, 1 in that
        # order (assuming `ln` is the natural log, exp(logits) == [1, 2, 3]).
        log_likelihood = ln(3. / (3 + 2 + 1)) + ln(2. / (2 + 1)) + ln(1. / 1)
        self.assertAlmostEqual(
            list_mle(relevance, logits).numpy(), -log_likelihood, places=5)
Пример #2
0
 def test_listwise_losses_are_serializable(self):
     """Checks that the listwise losses can be serialized."""
     # Each of these losses is built with the fixture's lambda weight and
     # verified with assertIsLossSerializable.
     lambda_weighted_loss_classes = (
         losses.SoftmaxLoss,
         losses.ListMLELoss,
         losses.ApproxMRRLoss,
         losses.ApproxNDCGLoss,
     )
     for loss_cls in lambda_weighted_loss_classes:
         self.assertIsLossSerializable(
             loss_cls(lambda_weight=self._lambda_weight))
     # TODO: Debug assertIsLossSerializable for the Gumbel loss. Right now the
     # loss values obtained from the original object and from the deserialized
     # one don't match exactly, so it falls back to assertIsSerializable.
     self.assertIsSerializable(losses.GumbelApproxNDCGLoss(seed=1))
Пример #3
0
  def test_list_mle_loss_lambda_weight(self):
    """ListMLE with a rank-discount lambda weight, averaged over two lists."""
    logits = [[0., ln(3), ln(2)], [0., ln(2), ln(3)]]
    relevance = [[0., 2., 1.], [1., 0., 2.]]
    # 2^(3 - rank) - 1: the expected value below weights positions 1, 2, 3
    # by 3, 1, 0 respectively.
    rank_weight = losses.ListMLELambdaWeight(
        rank_discount_fn=lambda rank: tf.pow(2., 3 - rank) - 1.)

    list_mle = losses.ListMLELoss(lambda_weight=rank_weight)
    # Weighted log-likelihood of each list under its label-sorted order.
    first_list = (3 * ln(3. / (3 + 2 + 1)) + 1 * ln(2. / (2 + 1)) +
                  0 * ln(1. / 1))
    second_list = (3 * ln(3. / (3 + 2 + 1)) + 1 * ln(1. / (1 + 2)) +
                   0 * ln(2. / 2))
    self.assertAlmostEqual(
        list_mle(relevance, logits).numpy(),
        -(first_list + second_list) / 2,
        places=5)
Пример #4
0
    def test_list_mle_loss(self):
        """ListMLE on two lists, first unweighted and then with list weights."""
        logits = [[0., ln(3), ln(2)], [0., ln(2), ln(3)]]
        relevance = [[0., 2., 1.], [1., 0., 2.]]
        list_weights = [[2.], [1.]]

        # Log-likelihood of each list under its label-sorted order.
        first_list = ln(3. / (3 + 2 + 1)) + ln(2. / (2 + 1)) + ln(1. / 1)
        second_list = ln(3. / (3 + 2 + 1)) + ln(1. / (1 + 2)) + ln(2. / 2)

        list_mle = losses.ListMLELoss()
        # Unweighted: the two negated log-likelihoods averaged over 2 lists.
        self.assertAlmostEqual(
            list_mle(relevance, logits).numpy(),
            -(first_list + second_list) / 2,
            places=5)
        # Weighted: each list's term is scaled by its weight (2 and 1); the
        # expected sum is still divided by 2.
        self.assertAlmostEqual(
            list_mle(relevance, logits, list_weights).numpy(),
            -(2 * first_list + 1 * second_list) / 2,
            places=5)