Exemplo n.º 1
0
 def test_dcg_lambda_weight_is_serializable(self):
     """Checks that DCG and NDCG lambda weights pass `assertIsSerializable`.

     Each class is checked twice: once with its default constructor and once
     with explicit `utils.identity` gain and `utils.log2_inverse` discount.
     """
     for weight_cls in (losses.DCGLambdaWeight, losses.NDCGLambdaWeight):
         self.assertIsSerializable(weight_cls())
         self.assertIsSerializable(
             weight_cls(gain_fn=utils.identity,
                        rank_discount_fn=utils.log2_inverse))
Exemplo n.º 2
0
    def test_softmax_loss(self):
        """Checks SOFTMAX_LOSS against hand-computed expected values.

        Covers the unweighted loss, per-list `sample_weight`, and a
        `DCGLambdaWeight` whose rank discount is `1 / log1p(rank)`.
        """
        scores = [[1., 3., 2.], [1., 2., 3.], [1., 2., 3.]]
        labels = [[0., 0., 1.], [0., 0., 2.], [0., 0., 0.]]
        weights = [[2.], [1.], [1.]]

        # Log-probability of the labelled item (index 2) in the first two
        # lists. Only those lists have non-zero labels, and only they appear
        # in the expected values; every expectation is divided by 3, the
        # number of lists in the batch.
        log_p0 = ln(_softmax(scores[0])[2])
        log_p1 = ln(_softmax(scores[1])[2])

        softmax_loss = losses.get(loss=losses.RankingLossKey.SOFTMAX_LOSS)

        # The second list's term carries a factor of 2., matching its label of
        # 2. (presumably the loss scales each list by its label sum).
        unweighted = softmax_loss(labels, scores).numpy()
        self.assertAlmostEqual(unweighted,
                               -(log_p0 + log_p1 * 2.) / 3.,
                               places=5)

        # With per-list weights [2., 1., 1.] the first term picks up a factor
        # of 2. and the second keeps its label factor 2. times weight 1.
        weighted = softmax_loss(labels, scores, weights).numpy()
        self.assertAlmostEqual(weighted,
                               -(log_p0 * 2. + log_p1 * 2. * 1.) / 3.,
                               places=5)

        # Lambda weight: the labelled item is 2nd by score in list 0 and 1st
        # in list 1, hence the ln(1 + 2) and ln(1 + 1) discount divisors.
        def rank_discount_fn(rank):
            return 1. / tf.math.log1p(rank)

        dcg_weight = losses.DCGLambdaWeight(rank_discount_fn=rank_discount_fn)
        softmax_loss = losses.get(loss=losses.RankingLossKey.SOFTMAX_LOSS,
                                  lambda_weight=dcg_weight)
        discounted = softmax_loss(labels, scores).numpy()
        self.assertAlmostEqual(
            discounted,
            -(log_p0 / ln(1. + 2.) + log_p1 * 2. / ln(1. + 1.)) / 3.,
            places=5)
Exemplo n.º 3
0
    def _check_pairwise_loss(self, loss_form):
        """Compares one pairwise ranking loss with the `_pairwise_loss` reference.

        Args:
          loss_form: One of 'hinge', 'logistic' or 'soft_zero_one'. It selects
            the `losses.RankingLossKey` under test and is also forwarded to
            `_pairwise_loss` for the expected value.
        """
        scores = [[1., 3., 2.], [1., 2., 3.]]
        labels = [[0., 0., 1.], [0., 0., 2.]]
        listwise_weights = [[2.], [1.]]
        listwise_weights_expanded = [[2.] * 3, [1.] * 3]
        itemwise_weights = [[2., 3., 4.], [1., 1., 1.]]
        default_weights = [1.] * 3
        list_size = 3.

        # Each supported `loss_form` paired with its RankingLossKey; the form
        # string doubles as the loss `name`.
        loss_keys = (
            ('hinge', losses.RankingLossKey.PAIRWISE_HINGE_LOSS),
            ('logistic', losses.RankingLossKey.PAIRWISE_LOGISTIC_LOSS),
            ('soft_zero_one', losses.RankingLossKey.PAIRWISE_SOFT_ZERO_ONE_LOSS),
        )

        def build_loss_fn(**extra_kwargs):
            # Instantiate all three losses with identical kwargs, then pick
            # the one under test (unknown forms raise KeyError here).
            built = {
                form_name: losses.get(loss=key, name=form_name, **extra_kwargs)
                for form_name, key in loss_keys
            }
            return built[loss_form]

        loss_fn = build_loss_fn()

        # One list per batch, no sample weights.
        for label_row, score_row in zip(labels, scores):
            actual = loss_fn([label_row], [score_row]).numpy()
            expected = _batch_aggregation(
                [_pairwise_loss(label_row, score_row, default_weights, loss_form)])
            self.assertAlmostEqual(actual, expected, places=5)

        # One list per batch, with a separate weight for every item.
        for label_row, score_row, item_weights in zip(labels, scores,
                                                      itemwise_weights):
            actual = loss_fn([label_row], [score_row],
                             sample_weight=[item_weights]).numpy()
            expected = _batch_aggregation(
                [_pairwise_loss(label_row, score_row, item_weights, loss_form)])
            self.assertAlmostEqual(actual, expected, places=5)

        # Whole batch with one weight per list; the reference helper is given
        # that weight repeated for every item in the list.
        actual = loss_fn(labels, scores, sample_weight=listwise_weights).numpy()
        expected = _batch_aggregation([
            _pairwise_loss(label_row, score_row, row_weights, loss_form)
            for label_row, score_row, row_weights in zip(
                labels, scores, listwise_weights_expanded)
        ])
        self.assertAlmostEqual(actual, expected, places=5)

        # Same batch with a DCGLambdaWeight (1 / log1p(rank) discount,
        # smooth_fraction=1.). The reference mirrors the discount via
        # rank_discount_form='LOG' and its aggregate is scaled by list_size.
        def rank_discount_fn(rank):
            return 1. / tf.math.log1p(rank)

        lambda_weight = losses.DCGLambdaWeight(
            rank_discount_fn=rank_discount_fn, smooth_fraction=1.)
        loss_fn = build_loss_fn(lambda_weight=lambda_weight)

        actual = loss_fn(labels, scores, sample_weight=listwise_weights).numpy()
        expected = _batch_aggregation([
            _pairwise_loss(label_row, score_row, row_weights, loss_form,
                           rank_discount_form='LOG')
            for label_row, score_row, row_weights in zip(
                labels, scores, listwise_weights_expanded)
        ]) * list_size
        self.assertAlmostEqual(actual, expected, places=5)