def test_smooth_fraction(self):
     """Checks DCG pair weights that use absolute rank (smooth_fraction=1.0).

     The same labels are checked once over the whole list and once with
     `topn=1`.
     """
     sorted_labels = [[2.0, 1.0, 0.0]]
     weight_and_expected = [
         (ranking_losses.DCGLambdaWeight(smooth_fraction=1.0),
          [[[0., 1. / 2., 2. * 2. / 3.],
            [1. / 2., 0., 1. / 6.],
            [2. * 2. / 3., 1. / 6., 0.]]]),
         (ranking_losses.DCGLambdaWeight(topn=1, smooth_fraction=1.0),
          [[[0., 1., 2.],
            [1., 0., 0.],
            [2., 0., 0.]]]),
     ]
     for lambda_weight, expected in weight_and_expected:
         with self.cached_session():
             self.assertAllClose(
                 lambda_weight.pair_weights(sorted_labels).eval(), expected)
 def test_invalid_labels(self):
     """Checks that pairs involving a negatively-labeled item get zero weight."""
     labels = [[2.0, 1.0, -1.0]]
     # The last item has label -1., so its row and column are expected to be 0.
     expected = [[[0., 1. / 2., 0.],
                  [1. / 2., 0., 0.],
                  [0., 0., 0.]]]
     dcg_weight = ranking_losses.DCGLambdaWeight()
     with self.cached_session():
         actual = dcg_weight.pair_weights(labels).eval()
         self.assertAllClose(actual, expected)
 def test_softmax_loss(self):
     """Checks `_softmax_loss` against Python-computed reference values.

     Covers the unweighted loss, per-list `weights`, and a DCGLambdaWeight
     with a 1 / log1p(rank) discount. The third list has all-zero labels and
     contributes no term to any expected value.
     """
     scores = [[1., 3., 2.], [1., 2., 3.], [1., 2., 3.]]
     labels = [[0., 0., 1.], [0., 0., 2.], [0., 0., 0.]]
     weights = [[2.], [1.], [1.]]
     # Log softmax probability of the last item, the only one with a nonzero
     # label in each of the first two lists.
     log_p_first = math.log(_softmax(scores[0])[2])
     log_p_second = math.log(_softmax(scores[1])[2])

     def log_discount(r):
         return 1. / math_ops.log1p(r)

     with self.cached_session():
         self.assertAlmostEqual(
             ranking_losses._softmax_loss(labels, scores).eval(),
             -(log_p_first + log_p_second * 2.) / 2.,
             places=5)
         self.assertAlmostEqual(
             ranking_losses._softmax_loss(labels, scores, weights).eval(),
             -(log_p_first * 2. + log_p_second * 2. * 1.) / 2.,
             places=5)
         # Test LambdaWeight.
         lambda_weight = ranking_losses.DCGLambdaWeight(
             rank_discount_fn=log_discount)
         self.assertAlmostEqual(
             ranking_losses._softmax_loss(
                 labels, scores, lambda_weight=lambda_weight).eval(),
             -(log_p_first / math.log(1. + 2.) +
               log_p_second * 2. / math.log(1. + 1.)) / 2.,
             places=5)
 def test_normalized(self):
     """Checks pair weights of a DCGLambdaWeight built with normalized=True."""
     sorted_labels = [[1.0, 2.0]]
     # Normalizer used in the expected values; presumably the ideal DCG of
     # these labels under the default gain and discount -- verify.
     max_dcg = 2.5
     cross_weight = 1. / 2. / max_dcg
     normalized_weight = ranking_losses.DCGLambdaWeight(normalized=True)
     with self.cached_session():
         self.assertAllClose(
             normalized_weight.pair_weights(sorted_labels).eval(),
             [[[0., cross_weight], [cross_weight, 0.]]])
 def test_individual_weights(self):
     """Checks `individual_weights` of a DCGLambdaWeight with normalized=True."""
     sorted_labels = [[1.0, 2.0]]
     max_dcg = 2.5
     # Expected weight per item, written as label / max_dcg / 1-based position.
     expected = [[1. / max_dcg / 1., 2. / max_dcg / 2.]]
     normalized_weight = ranking_losses.DCGLambdaWeight(normalized=True)
     with self.cached_session():
         per_item = normalized_weight.individual_weights(sorted_labels).eval()
         self.assertAllClose(per_item, expected)
 def test_default(self):
     """Checks default DCGLambdaWeight pair weights, which use rank difference."""
     sorted_labels = [[2.0, 1.0, 0.0]]
     # Expected weight for pairs at neighbouring positions vs. two apart.
     adjacent = 1. / 2.
     two_apart = 2. * 1. / 6.
     expected = [[[0., adjacent, two_apart],
                  [adjacent, 0., adjacent],
                  [two_apart, adjacent, 0.]]]
     default_weight = ranking_losses.DCGLambdaWeight()
     with self.cached_session():
         self.assertAllClose(
             default_weight.pair_weights(sorted_labels).eval(), expected)
 def test_gain_and_discount(self):
     """Checks pair weights with gain 2**x - 1 and discount 1 / log1p(rank)."""

     def exponential_gain(x):
         return math_ops.pow(2., x) - 1.

     def log_discount(r):
         return 1. / math_ops.log1p(r)

     sorted_labels = [[2.0, 1.0]]
     lambda_weight = ranking_losses.DCGLambdaWeight(
         gain_fn=exponential_gain, rank_discount_fn=log_discount)
     pair_weight = 2. * (1. / math.log(2.) - 1. / math.log(3.))
     with self.cached_session():
         self.assertAllClose(
             lambda_weight.pair_weights(sorted_labels).eval(),
             [[[0., pair_weight], [pair_weight, 0.]]])
    def _check_make_pairwise_loss(self, loss_key):
        """Helper function to test `make_loss_fn` with a pairwise `loss_key`.

        Compares losses built by `ranking_losses.make_loss_fn` against the
        Python reference `_pairwise_loss`, aggregated with
        `_batch_aggregation`. Checked cases:
          * each list alone, with default weights;
          * each list alone, with itemwise weights read from `features`;
          * both lists, with listwise weights;
          * a DCGLambdaWeight (reference scaled by the list size);
          * SUM and MEAN reductions giving different values.

        Args:
          loss_key: Ranking loss key passed both to `make_loss_fn` and to the
            `_pairwise_loss` reference.
        """
        scores = [[1., 3., 2.], [1., 2., 3.]]
        labels = [[0., 0., 1.], [0., 0., 2.]]
        listwise_weights = [[2.], [1.]]
        listwise_weights_expanded = [[2.] * 3, [1.] * 3]
        itemwise_weights = [[2., 3., 4.], [1., 1., 1.]]
        default_weights = [1.] * 3
        weights_feature_name = 'weights'
        list_size = 3.
        features = {}

        def reference(*cases, **extra_kwargs):
            # Each case is a (list_index, per_item_weights) pair; extra_kwargs
            # are forwarded unchanged to `_pairwise_loss`.
            return _batch_aggregation([
                _pairwise_loss(labels[idx], scores[idx], item_weights,
                               loss_key, **extra_kwargs)
                for idx, item_weights in cases
            ])

        loss_fn = ranking_losses.make_loss_fn(loss_key)
        with self.cached_session():
            # Individual lists with default weights.
            for idx, (list_labels, list_scores) in enumerate(
                    zip(labels, scores)):
                self.assertAlmostEqual(
                    loss_fn([list_labels], [list_scores], features).eval(),
                    reference((idx, default_weights)),
                    places=5)

            # Itemwise weights supplied through `features`.
            loss_fn = ranking_losses.make_loss_fn(
                loss_key, weights_feature_name=weights_feature_name)
            for idx, (list_labels, list_scores) in enumerate(
                    zip(labels, scores)):
                features[weights_feature_name] = [itemwise_weights[idx]]
                self.assertAlmostEqual(
                    loss_fn([list_labels], [list_scores], features).eval(),
                    reference((idx, itemwise_weights[idx])),
                    places=5)

            # Multiple lists, one weight per list.
            features[weights_feature_name] = listwise_weights
            every_list = list(enumerate(listwise_weights_expanded))
            self.assertAlmostEqual(
                loss_fn(labels, scores, features).eval(),
                reference(*every_list),
                places=5)

            # Test LambdaWeight.
            def log_discount(r):
                return 1. / math_ops.log1p(r)

            lambda_weight = ranking_losses.DCGLambdaWeight(
                rank_discount_fn=log_discount, smooth_fraction=1.)
            loss_fn = ranking_losses.make_loss_fn(
                loss_key,
                weights_feature_name=weights_feature_name,
                lambda_weight=lambda_weight)
            self.assertAlmostEqual(
                loss_fn(labels, scores, features).eval(),
                reference(*every_list, rank_discount_form='LOG') * list_size,
                places=5)

            # Test loss reduction method: SUM and MEAN should give different
            # loss values.
            sum_loss_fn = ranking_losses.make_loss_fn(
                loss_key, reduction=core_losses.Reduction.SUM)
            mean_loss_fn = ranking_losses.make_loss_fn(
                loss_key, reduction=core_losses.Reduction.MEAN)
            self.assertNotAlmostEqual(
                sum_loss_fn(labels, scores, features).eval(),
                mean_loss_fn(labels, scores, features).eval())
    def _check_pairwise_loss(self, loss_fn):
        """Helper function to test a pairwise `loss_fn` called directly.

        Compares `loss_fn` against the Python reference `_pairwise_loss`,
        aggregated with `_batch_aggregation`. Checked cases:
          * each list alone, with default weights;
          * each list alone, with itemwise `weights`;
          * both lists, with listwise `weights`;
          * a DCGLambdaWeight (reference scaled by the list size);
          * SUM and MEAN reductions giving different values.

        Args:
          loss_fn: One of `ranking_losses._pairwise_hinge_loss`,
            `ranking_losses._pairwise_logistic_loss` or
            `ranking_losses._pairwise_soft_zero_one_loss`; it is mapped to the
            matching `RankingLossKey` for the Python reference.

        Raises:
          KeyError: If `loss_fn` is not one of the functions listed above.
        """
        scores = [[1., 3., 2.], [1., 2., 3.]]
        labels = [[0., 0., 1.], [0., 0., 2.]]
        listwise_weights = [[2.], [1.]]
        listwise_weights_expanded = [[2.] * 3, [1.] * 3]
        itemwise_weights = [[2., 3., 4.], [1., 1., 1.]]
        default_weights = [1.] * 3
        list_size = 3.
        loss_key_by_fn = {
            ranking_losses._pairwise_hinge_loss:
                ranking_losses.RankingLossKey.PAIRWISE_HINGE_LOSS,
            ranking_losses._pairwise_logistic_loss:
                ranking_losses.RankingLossKey.PAIRWISE_LOGISTIC_LOSS,
            ranking_losses._pairwise_soft_zero_one_loss:
                ranking_losses.RankingLossKey.PAIRWISE_SOFT_ZERO_ONE_LOSS,
        }
        loss_form = loss_key_by_fn[loss_fn]

        def reference(*cases, **extra_kwargs):
            # Each case is a (list_index, per_item_weights) pair; extra_kwargs
            # are forwarded unchanged to `_pairwise_loss`.
            return _batch_aggregation([
                _pairwise_loss(labels[idx], scores[idx], item_weights,
                               loss_form, **extra_kwargs)
                for idx, item_weights in cases
            ])

        with self.cached_session():
            # Individual lists with default weights.
            for idx, (list_labels, list_scores) in enumerate(
                    zip(labels, scores)):
                self.assertAlmostEqual(
                    loss_fn([list_labels], [list_scores]).eval(),
                    reference((idx, default_weights)),
                    places=5)

            # Itemwise weights.
            for idx, (list_labels, list_scores) in enumerate(
                    zip(labels, scores)):
                self.assertAlmostEqual(
                    loss_fn([list_labels], [list_scores],
                            weights=[itemwise_weights[idx]]).eval(),
                    reference((idx, itemwise_weights[idx])),
                    places=5)

            # Multiple lists, one weight per list.
            every_list = list(enumerate(listwise_weights_expanded))
            self.assertAlmostEqual(
                loss_fn(labels, scores, weights=listwise_weights).eval(),
                reference(*every_list),
                places=5)

            # Test LambdaWeight.
            def log_discount(r):
                return 1. / math_ops.log1p(r)

            lambda_weight = ranking_losses.DCGLambdaWeight(
                rank_discount_fn=log_discount, smooth_fraction=1.)
            self.assertAlmostEqual(
                loss_fn(labels,
                        scores,
                        weights=listwise_weights,
                        lambda_weight=lambda_weight).eval(),
                reference(*every_list, rank_discount_form='LOG') * list_size,
                places=5)

            # Test loss reduction method: SUM and MEAN should give different
            # loss values.
            summed = loss_fn(labels,
                             scores,
                             reduction=core_losses.Reduction.SUM).eval()
            averaged = loss_fn(labels,
                               scores,
                               reduction=core_losses.Reduction.MEAN).eval()
            self.assertNotAlmostEqual(summed, averaged)