def test_approx_ndcg_loss(self):
  """Checks ApproxNDCGLoss unweighted, with list weights and item weights."""
  scores = [[1.4, -2.8, -0.4], [0., 1.8, 10.2], [1., 1.2, -3.2]]
  # Ranks induced by `scores`: [[1, 3, 2], [3, 2, 1], [2, 1, 3]].
  labels = [[0., 2., 1.], [1., 0., 3.], [0., 0., 0.]]
  list_weights = [[2.], [1.], [1.]]
  example_weights = [[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]
  per_list_weights = [
      normalize_weights(item_weights, item_labels)
      for item_weights, item_labels in zip(example_weights, labels)
  ]

  # Expected NDCG of the first two lists, written as DCG / ideal DCG with
  # gains 2^label - 1 (labels 1, 2, 3 -> gains 1, 3, 7). Any common log base
  # of the rank discount cancels in the ratio, so `ln` can be used directly.
  # The third list has only zero labels, so its NDCG is 0 and it contributes
  # nothing to the sums below (it still counts in the divisor).
  ndcg_first = (1 / (3 / ln(2) + 1 / ln(3))) * (3 / ln(4) + 1 / ln(3))
  ndcg_second = (1 / (7 / ln(2) + 1 / ln(3))) * (7 / ln(2) + 1 / ln(4))
  num_lists = 3.

  loss = losses.ApproxNDCGLoss()

  # No weights: negative mean NDCG over all three lists.
  self.assertAlmostEqual(
      loss(labels, scores).numpy(),
      -(ndcg_first + ndcg_second) / num_lists,
      places=5)

  # List-wise weights scale each list's NDCG.
  self.assertAlmostEqual(
      loss(labels, scores, list_weights).numpy(),
      -(2 * ndcg_first + 1 * ndcg_second) / num_lists,
      places=5)

  # Per-item weights are reduced to one weight per list by the
  # `normalize_weights` helper defined elsewhere in this file.
  self.assertAlmostEqual(
      loss(labels, scores, example_weights).numpy(),
      -(per_list_weights[0] * ndcg_first +
        per_list_weights[1] * ndcg_second) / num_lists,
      places=5)
def test_listwise_losses_are_serializable(self):
  """Checks that every list-wise loss can be serialized and restored."""
  lambda_weighted_loss_classes = (
      losses.SoftmaxLoss,
      losses.ListMLELoss,
      losses.ApproxMRRLoss,
      losses.ApproxNDCGLoss,
  )
  for loss_class in lambda_weighted_loss_classes:
    self.assertIsLossSerializable(
        loss_class(lambda_weight=self._lambda_weight))

  # TODO: Debug assertIsLossSerializable for Gumbel loss. Right now, the loss
  # values computed by the original object and by the deserialized one don't
  # match exactly, so only assertIsSerializable is applied here.
  self.assertIsSerializable(losses.GumbelApproxNDCGLoss(seed=1))
def test_approx_ndcg_loss_sum_batch(self):
  """Checks ApproxNDCGLoss with SUM_OVER_BATCH_SIZE reduction.

  The expected values are the negative NDCG summed over the three lists and
  divided by the number of lists.
  """
  scores = [[1.4, -2.8, -0.4], [0., 1.8, 10.2], [1., 1.2, -3.2]]
  # Ranks induced by `scores`: [[1, 3, 2], [3, 2, 1], [2, 1, 3]].
  labels = [[0., 2., 1.], [1., 0., 3.], [0., 0., 0.]]
  example_weights = [[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]
  per_list_weights = [
      normalize_weights(item_weights, item_labels)
      for item_weights, item_labels in zip(example_weights, labels)
  ]

  # Expected NDCG of the first two lists (DCG / ideal DCG, gains
  # 2^label - 1). The all-zero third list has NDCG 0 but still counts
  # toward the batch size.
  ndcg_first = (1 / (3 / ln(2) + 1 / ln(3))) * (3 / ln(4) + 1 / ln(3))
  ndcg_second = (1 / (7 / ln(2) + 1 / ln(3))) * (7 / ln(2) + 1 / ln(4))
  batch_size = 3.

  loss = losses.ApproxNDCGLoss(
      reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE)

  self.assertAlmostEqual(
      loss(labels, scores).numpy(),
      -(ndcg_first + ndcg_second) / batch_size,
      places=5)

  self.assertAlmostEqual(
      loss(labels, scores, example_weights).numpy(),
      -(per_list_weights[0] * ndcg_first +
        per_list_weights[1] * ndcg_second) / batch_size,
      places=5)