def test_list_mle_loss_tie(self):
    """ListMLE on one list where the two lowest-labelled items share a label.

    The global TF seed is pinned before anything else. Presumably this makes
    the loss's handling of the tied labels reproducible. The expected value
    encodes one specific ordering: exponentiated score 3 first, then 2, then 1.
    """
    tf.random.set_seed(1)
    labels = [[0., 0., 1.]]
    scores = [[0., ln(2), ln(3)]]
    list_mle = losses.ListMLELoss()

    actual = list_mle(labels, scores).numpy()
    # Assuming ln is the natural log, exp(scores) == [1, 2, 3]. Each term is
    # log(exp(score of next item) / sum of exp(scores) still remaining).
    log_likelihood = (ln(3. / (3 + 2 + 1)) +
                      ln(2. / (2 + 1)) +
                      ln(1. / 1))
    self.assertAlmostEqual(actual, -log_likelihood, places=5)
def test_listwise_losses_are_serializable(self):
    """Check that each listwise loss survives a serialization round trip."""
    # These four losses all accept the shared fixture lambda weight and are
    # checked with the loss-level serialization assertion.
    lambda_weighted_loss_classes = (
        losses.SoftmaxLoss,
        losses.ListMLELoss,
        losses.ApproxMRRLoss,
        losses.ApproxNDCGLoss,
    )
    for loss_cls in lambda_weighted_loss_classes:
        self.assertIsLossSerializable(
            loss_cls(lambda_weight=self._lambda_weight))

    # TODO: Debug assertIsLossSerializable for Gumbel loss. Right now, the
    # loss values from the original object and the deserialized one don't
    # match exactly, so this loss only gets assertIsSerializable (which
    # presumably does not compare loss values -- confirm).
    self.assertIsSerializable(losses.GumbelApproxNDCGLoss(seed=1))
def test_list_mle_loss_lambda_weight(self):
    """ListMLE with a rank-discounting ListMLELambdaWeight over two lists.

    The discount is 2**(3 - rank) - 1. Assuming 1-based ranks, this gives
    3, 1 and 0 for ranks 1, 2 and 3. Those are the per-position coefficients
    in the expected value below.
    """
    labels = [[0., 2., 1.], [1., 0., 2.]]
    scores = [[0., ln(3), ln(2)], [0., ln(2), ln(3)]]

    def rank_discount(rank):
        return tf.pow(2., 3 - rank) - 1.

    lambda_weight = losses.ListMLELambdaWeight(rank_discount_fn=rank_discount)
    list_mle = losses.ListMLELoss(lambda_weight=lambda_weight)
    actual = list_mle(labels, scores).numpy()

    # Discount-weighted log-likelihood of the label-sorted order of each list.
    first_list = (3 * ln(3. / (3 + 2 + 1)) +
                  1 * ln(2. / (2 + 1)) +
                  0 * ln(1. / 1))
    second_list = (3 * ln(3. / (3 + 2 + 1)) +
                   1 * ln(1. / (1 + 2)) +
                   0 * ln(2. / 2))
    # The loss is averaged over the two lists in the batch.
    self.assertAlmostEqual(actual, -(first_list + second_list) / 2, places=5)
def test_list_mle_loss(self):
    """ListMLE over a batch of two lists, without and with per-list weights."""
    labels = [[0., 2., 1.], [1., 0., 2.]]
    scores = [[0., ln(3), ln(2)], [0., ln(2), ln(3)]]
    list_weights = [[2.], [1.]]
    list_mle = losses.ListMLELoss()

    # Log-likelihood of each list's label-sorted order under the scores.
    first_list = ln(3. / (3 + 2 + 1)) + ln(2. / (2 + 1)) + ln(1. / 1)
    second_list = ln(3. / (3 + 2 + 1)) + ln(1. / (1 + 2)) + ln(2. / 2)

    unweighted = list_mle(labels, scores).numpy()
    self.assertAlmostEqual(
        unweighted, -(first_list + second_list) / 2, places=5)

    # The weighted sum is still divided by 2 (the number of lists), not by the
    # sum of the weights. Presumably this is the loss's default batch reduction.
    weighted = list_mle(labels, scores, list_weights).numpy()
    self.assertAlmostEqual(
        weighted, -(2 * first_list + 1 * second_list) / 2, places=5)