def test_dcg_lambda_weight_is_serializable(self):
    """DCG and NDCG lambda weights serialize with default and custom fns.

    For each lambda-weight class, checks a default-constructed instance and
    then one built with `utils.identity` as gain and `utils.log2_inverse` as
    rank discount.
    """
    for lambda_weight_cls in (losses.DCGLambdaWeight, losses.NDCGLambdaWeight):
        self.assertIsSerializable(lambda_weight_cls())
        self.assertIsSerializable(
            lambda_weight_cls(
                gain_fn=utils.identity,
                rank_discount_fn=utils.log2_inverse,
            ))
def test_softmax_loss(self):
    """SOFTMAX_LOSS matches hand-computed values.

    Three cases are checked: no weights, per-list weights, and a DCG lambda
    weight. Only lists 0 and 1 appear in the expected expressions (list 2
    has all-zero labels). The divisor is 3., the number of lists in the
    batch.
    """
    scores = [[1., 3., 2.], [1., 2., 3.], [1., 2., 3.]]
    labels = [[0., 0., 1.], [0., 0., 2.], [0., 0., 0.]]
    weights = [[2.], [1.], [1.]]

    # Log of the softmax probability of the labelled item (index 2).
    log_prob_0 = ln(_softmax(scores[0])[2])
    log_prob_1 = ln(_softmax(scores[1])[2])

    softmax_loss = losses.get(loss=losses.RankingLossKey.SOFTMAX_LOSS)
    self.assertAlmostEqual(
        softmax_loss(labels, scores).numpy(),
        -(log_prob_0 + log_prob_1 * 2.) / 3.,
        places=5)
    self.assertAlmostEqual(
        softmax_loss(labels, scores, weights).numpy(),
        -(log_prob_0 * 2. + log_prob_1 * 2. * 1.) / 3.,
        places=5)

    # Test LambdaWeight.
    def rank_discount_fn(rank):
        return 1. / tf.math.log1p(rank)

    dcg_lambda_weight = losses.DCGLambdaWeight(
        rank_discount_fn=rank_discount_fn)
    lambda_softmax_loss = losses.get(
        loss=losses.RankingLossKey.SOFTMAX_LOSS,
        lambda_weight=dcg_lambda_weight)
    # The labelled item is ranked 2nd in list 0 (score 2 of [1, 3, 2]) and
    # 1st in list 1 (score 3 of [1, 2, 3]), giving the 1 / ln(1 + rank)
    # discounts below.
    self.assertAlmostEqual(
        lambda_softmax_loss(labels, scores).numpy(),
        -(log_prob_0 / ln(1. + 2.) + log_prob_1 * 2. / ln(1. + 1.)) / 3.,
        places=5)
def _check_pairwise_loss(self, loss_form):
    """Check a pairwise ranking loss against the `_pairwise_loss` reference.

    The checks cover single lists (unweighted and with itemwise weights),
    a two-list batch with listwise weights, and the same batch with a DCG
    lambda weight.

    Args:
      loss_form: One of 'hinge', 'logistic' or 'soft_zero_one'. It selects
        the pairwise loss under test and is also used as the loss `name`.

    Raises:
      KeyError: If `loss_form` is not one of the supported forms.
    """
    scores = [[1., 3., 2.], [1., 2., 3.]]
    labels = [[0., 0., 1.], [0., 0., 2.]]
    listwise_weights = [[2.], [1.]]
    listwise_weights_expanded = [[2.] * 3, [1.] * 3]
    itemwise_weights = [[2., 3., 4.], [1., 1., 1.]]
    default_weights = [1.] * 3
    list_size = 3.

    # Resolve only the requested loss. Both the plain and lambda-weighted
    # variants are built from this single key, instead of constructing all
    # three losses in two duplicated tables.
    loss_keys = {
        'hinge': losses.RankingLossKey.PAIRWISE_HINGE_LOSS,
        'logistic': losses.RankingLossKey.PAIRWISE_LOGISTIC_LOSS,
        'soft_zero_one': losses.RankingLossKey.PAIRWISE_SOFT_ZERO_ONE_LOSS,
    }
    loss_key = loss_keys[loss_form]
    loss_fn = losses.get(loss=loss_key, name=loss_form)

    # Individual lists, unweighted and with itemwise weights.
    for label, score, item_weights in zip(labels, scores, itemwise_weights):
        self.assertAlmostEqual(
            loss_fn([label], [score]).numpy(),
            _batch_aggregation(
                [_pairwise_loss(label, score, default_weights, loss_form)]),
            places=5)
        self.assertAlmostEqual(
            loss_fn([label], [score], sample_weight=[item_weights]).numpy(),
            _batch_aggregation(
                [_pairwise_loss(label, score, item_weights, loss_form)]),
            places=5)

    # Multiple lists. Each listwise weight is broadcast to every item for
    # the reference computation.
    self.assertAlmostEqual(
        loss_fn(labels, scores, sample_weight=listwise_weights).numpy(),
        _batch_aggregation([
            _pairwise_loss(label, score, weights, loss_form)
            for label, score, weights in zip(
                labels, scores, listwise_weights_expanded)
        ]),
        places=5)

    # Test LambdaWeight. The 1 / ln(1 + rank) discount is paired with
    # rank_discount_form='LOG' in the reference computation.
    lambda_weight = losses.DCGLambdaWeight(
        rank_discount_fn=lambda r: 1. / tf.math.log1p(r),
        smooth_fraction=1.)
    lambda_loss_fn = losses.get(
        loss=loss_key, name=loss_form, lambda_weight=lambda_weight)
    # NOTE(review): the reference aggregate is scaled by list_size here
    # (and only here). Presumably this offsets a normalization inside the
    # lambda-weighted loss; confirm against the loss implementation.
    self.assertAlmostEqual(
        lambda_loss_fn(
            labels, scores, sample_weight=listwise_weights).numpy(),
        _batch_aggregation([
            _pairwise_loss(
                label, score, weights, loss_form, rank_discount_form='LOG')
            for label, score, weights in zip(
                labels, scores, listwise_weights_expanded)
        ]) * list_size,
        places=5)