def test_mean_squared_loss_with_invalid_labels(self):
    """Checks that an item labelled -1 contributes nothing to MeanSquaredLoss.

    Only the two remaining items add squared errors, (0 - 1)^2 and (1 - 2)^2.
    The expected value still divides by all 3 items in the list.
    """
    expected = (1. + 1.) / 3.
    loss_fn = losses.MeanSquaredLoss()
    # Argument order is (labels, scores); the middle label -1 is the invalid one.
    actual = loss_fn([[0., -1., 1.]], [[1., 3., 2.]]).numpy()
    self.assertAlmostEqual(actual, expected, places=5)
def test_mean_squared_loss(self):
    """Checks MeanSquaredLoss against per-list errors, with and without weights.

    Each list's expected error comes from the `_mean_squared_error` helper.
    In the weighted case, each list's error is scaled by its single per-list
    weight. Both expectations are divided by 9, the total number of items
    (3 lists of 3 items each).
    """
    scores = [[0.2, 0.5, 0.3], [0.2, 0.3, 0.5], [0.2, 0.3, 0.5]]
    labels = [[0., 0., 1.], [0., 0., 2.], [0., 0., 0.]]
    # One weight per list: the first list counts double, the others once.
    weights = [[2.], [1.], [1.]]
    num_items = 9.

    per_list_errors = [
        _mean_squared_error(list_labels, list_scores)
        for list_labels, list_scores in zip(labels, scores)
    ]
    unweighted_expected = sum(per_list_errors) / num_items
    weighted_expected = sum(
        list_weight[0] * error
        for list_weight, error in zip(weights, per_list_errors)) / num_items

    loss = losses.MeanSquaredLoss()
    self.assertAlmostEqual(
        loss(labels, scores).numpy(), unweighted_expected, places=5)
    self.assertAlmostEqual(
        loss(labels, scores, weights).numpy(), weighted_expected, places=5)
def test_pointwise_losses_are_serializable(self):
    """Checks that each pointwise loss passes `assertIsLossSerializable`."""
    pointwise_losses = (
        losses.ClickEMLoss(exam_loss_weight=2.0, rel_loss_weight=5.0),
        losses.SigmoidCrossEntropyLoss(),
        losses.MeanSquaredLoss(),
    )
    for pointwise_loss in pointwise_losses:
        self.assertIsLossSerializable(pointwise_loss)