Example #1
0
    def test_gumbel_approx_ndcg_loss(self):
        """Checks GumbelApproxNDCGLoss against a hand-computed value (unweighted).

        With sample_size=2 each of the 3 input lists is sampled twice, giving
        6 sampled lists. The expected loss is minus the mean NDCG over those 6
        samples. Gains are 2**label - 1: labels [0, 2, 1] give gains [0, 3, 1]
        and labels [1, 0, 3] give gains [1, 0, 7]. The all-zero third list
        contributes 0 but still counts toward the divisor of 6. The log base of
        `ln` cancels inside each DCG / ideal-DCG ratio.
        """
        scores = [[1.4, -2.8, -0.4], [0., 1.8, 10.2], [1., 1.2, -3.2]]
        labels = [[0., 2., 1.], [1., 0., 3.], [0., 0., 0.]]

        # Sampled scores for seed=1 (two samples per input list):
        # sampled_scores = [[-.291, -1.643, -2.826],
        #                   [-.0866, -2.924, -3.530],
        #                   [-12.42, -9.492, -7.939e-5],
        #                   [-8.859, -6.830, -1.223e-3],
        #                   [-.8930, -.5266, -45.80183],
        #                   [-.6650, -.7220, -45.94149]]
        # ranks    =     [[1,      2,      3],
        #                 [1,      2,      3],
        #                 [3,      2,      1],
        #                 [3,      2,      1],
        #                 [2,      1,      3],
        #                 [1,      2,      3]]
        # expanded_labels = [[0., 2., 1.],
        #                    [0., 2., 1.],
        #                    [1., 0., 3.],
        #                    [1., 0., 3.],
        #                    [0., 0., 0.],
        #                    [0., 0., 0.]]
        # No weights are passed, so every sampled list has weight 1:
        # expanded_weights = [[1.], [1.],
        #                     [1.], [1.],
        #                     [1.], [1.]]

        loss = losses.GumbelApproxNDCGLoss(sample_size=2, seed=1)
        self.assertAlmostEqual(
            loss(labels, scores).numpy(),
            # 2 samples of list 1 (gains [0, 3, 1] at ranks [1, 2, 3]) plus
            # 2 samples of list 2 (gains [1, 0, 7] at ranks [3, 2, 1]),
            # averaged over all 6 samples.
            -(2 * (1 / (3 / ln(2) + 1 / ln(3))) * (3 / ln(3) + 1 / ln(4)) + 2 *
              (1 / (7 / ln(2) + 1 / ln(3))) * (7 / ln(2) + 1 / ln(4))) / 6,
            places=3)
Example #2
0
    def test_gumbel_approx_ndcg_weighted_loss(self):
        """Checks GumbelApproxNDCGLoss with per-list weights [2, 1, 1].

        Each of the 3 lists is sampled twice (sample_size=2), and each sample
        inherits its list's weight. The expected loss is minus the weighted sum
        of per-sample NDCG values divided by the 6 samples. The all-zero third
        list contributes nothing to the sum.
        """
        scores = [[1.4, -2.8, -0.4], [0., 1.8, 10.2], [1., 1.2, -3.2]]
        labels = [[0., 2., 1.], [1., 0., 3.], [0., 0., 0.]]
        weights = [[2.], [1.], [1.]]

        gumbel_loss = losses.GumbelApproxNDCGLoss(sample_size=2, seed=1)
        actual = gumbel_loss(labels, scores, weights).numpy()

        # NDCG of each sample of the first list (gains [0, 3, 1]).
        ndcg_first = (1 / (3 / ln(2) + 1 / ln(3))) * (3 / ln(3) + 1 / ln(4))
        # NDCG of each sample of the second list (gains [1, 0, 7]).
        ndcg_second = (1 / (7 / ln(2) + 1 / ln(3))) * (7 / ln(2) + 1 / ln(4))
        # weight * number_of_samples * ndcg, averaged over the 6 samples.
        expected = -(2 * 2 * ndcg_first + 1 * 2 * ndcg_second) / 6

        self.assertAlmostEqual(actual, expected, places=3)
Example #3
0
 def test_listwise_losses_are_serializable(self):
     """Checks that the listwise losses can be serialized.

     The lambda-weighted losses go through assertIsLossSerializable. The
     Gumbel loss only gets the weaker assertIsSerializable check (see TODO).
     """
     lambda_weighted_loss_classes = (
         losses.SoftmaxLoss,
         losses.ListMLELoss,
         losses.ApproxMRRLoss,
         losses.ApproxNDCGLoss,
     )
     for loss_class in lambda_weighted_loss_classes:
         self.assertIsLossSerializable(
             loss_class(lambda_weight=self._lambda_weight))

     # TODO: Debug assertIsLossSerializable for Gumbel loss. Right now,
     # the loss values got from obj and the deserialized don't match exactly.
     self.assertIsSerializable(losses.GumbelApproxNDCGLoss(seed=1))