Example 1
0
  def test_pairwise_logistic_loss_sum_with_invalid_labels(self):
    """SUM-reduced pairwise logistic loss with one negatively-labelled item.

    The asserted value, ln(1 + exp(-1)), is the logistic term for the single
    pair (label 1, score 2) vs. (label 0, score 1). This means the item
    labelled -1 is expected to contribute nothing to the loss.
    """
    predicted_scores = [[1., 3., 2.]]
    labels_with_invalid = [[0., -1., 1.]]

    summed_loss = losses.PairwiseLogisticLoss(
        reduction=tf.losses.Reduction.SUM)
    # Score gap of the only valid pair is 2 - 1 = 1.
    expected = ln(1 + math.exp(-1.))
    actual = summed_loss(labels_with_invalid, predicted_scores).numpy()
    self.assertAlmostEqual(actual, expected, places=5)
Example 2
0
    def test_pairwise_losses_are_serializable(self):
        """Every pairwise loss built with `self._lambda_weight` is serializable.

        The hinge, logistic and soft zero-one losses are checked in that order.
        """
        pairwise_loss_classes = (
            losses.PairwiseHingeLoss,
            losses.PairwiseLogisticLoss,
            losses.PairwiseSoftZeroOneLoss,
        )
        for loss_class in pairwise_loss_classes:
            self.assertIsLossSerializable(
                loss_class(lambda_weight=self._lambda_weight))
Example 3
0
    def _check_pairwise_loss(self, loss_form):
        """Checks a pairwise loss against the reference helpers.

        Args:
          loss_form: One of 'hinge', 'logistic' or 'soft_zero_one'. It selects
            the `losses.Pairwise*Loss` class under test and is also passed to
            `_pairwise_loss` as the reference loss form.

        The loss is compared, to 5 decimal places, with
        `_batch_aggregation(_pairwise_loss(...))` in four settings. First,
        each single list without weights. Second, each single list with
        per-item weights. Third, both lists with per-list weights. Fourth,
        both lists with per-list weights and a `DCGLambdaWeight`.
        """
        scores = [[1., 3., 2.], [1., 2., 3.]]
        labels = [[0., 0., 1.], [0., 0., 2.]]
        list_size = 3.
        listwise_weights = [[2.], [1.]]
        # The reference helper takes one weight per item, so each per-list
        # weight is repeated across the list.
        listwise_weights_expanded = [[2.] * 3, [1.] * 3]
        itemwise_weights = [[2., 3., 4.], [1., 1., 1.]]
        default_weights = [1.] * 3

        # Each loss instance is named after its loss form.
        loss_class = {
            'hinge': losses.PairwiseHingeLoss,
            'logistic': losses.PairwiseLogisticLoss,
            'soft_zero_one': losses.PairwiseSoftZeroOneLoss,
        }[loss_form]
        loss_fn = loss_class(name=loss_form)

        def _expected_batch_loss(**reference_kwargs):
            """Returns the reference loss for the full batch with per-list weights."""
            return _batch_aggregation([
                _pairwise_loss(list_labels, list_scores, list_weights,
                               loss_form, **reference_kwargs)
                for list_labels, list_scores, list_weights in zip(
                    labels, scores, listwise_weights_expanded)
            ])

        # Individual lists, unweighted.
        for list_labels, list_scores in zip(labels, scores):
            actual = loss_fn([list_labels], [list_scores]).numpy()
            expected = _batch_aggregation([
                _pairwise_loss(list_labels, list_scores, default_weights,
                               loss_form)
            ])
            self.assertAlmostEqual(actual, expected, places=5)

        # Individual lists, with per-item weights.
        for list_labels, list_scores, item_weights in zip(
                labels, scores, itemwise_weights):
            actual = loss_fn([list_labels], [list_scores],
                             sample_weight=[item_weights]).numpy()
            expected = _batch_aggregation([
                _pairwise_loss(list_labels, list_scores, item_weights,
                               loss_form)
            ])
            self.assertAlmostEqual(actual, expected, places=5)

        # Whole batch, with per-list weights.
        self.assertAlmostEqual(
            loss_fn(labels, scores, sample_weight=listwise_weights).numpy(),
            _expected_batch_loss(),
            places=5)

        # Whole batch, with per-list weights and a DCG lambda weight.
        def rank_discount_fn(rank):
            return 1. / tf.math.log1p(rank)

        lambda_weight = losses.DCGLambdaWeight(
            rank_discount_fn=rank_discount_fn, smooth_fraction=1.)
        lambda_loss_fn = loss_class(name=loss_form,
                                    lambda_weight=lambda_weight)
        # NOTE(review): the reference value is multiplied by list_size here
        # only. Presumably the lambda-weighted loss is normalized differently
        # from the reference helper; confirm against `_pairwise_loss`.
        self.assertAlmostEqual(
            lambda_loss_fn(labels, scores,
                           sample_weight=listwise_weights).numpy(),
            _expected_batch_loss(rank_discount_form='LOG') * list_size,
            places=5)