コード例 #1
0
    def test_mean_squared_loss_with_invalid_labels(self):
        """Checks that an entry labelled -1 adds nothing to MeanSquaredLoss.

        Squared errors per entry would be (1-0)^2 = 1, (3-(-1))^2 = 16 and
        (2-1)^2 = 1. The expected value keeps only the two valid terms, but
        still divides by 3, the total number of entries in the list.
        """
        labels = [[0., -1., 1.]]
        scores = [[1., 3., 2.]]
        expected_loss = (1. + 1.) / 3.

        mse = losses.MeanSquaredLoss()
        actual_loss = mse(labels, scores).numpy()
        self.assertAlmostEqual(actual_loss, expected_loss, places=5)
コード例 #2
0
    def test_mean_squared_loss(self):
        """Compares MeanSquaredLoss against a per-list reference computation.

        The reference for each list comes from the module-level helper
        `_mean_squared_error` (defined elsewhere in this file). Both the
        unweighted and the per-list-weighted expectations are normalized by
        the total number of entries in the batch (3 lists x 3 items = 9).
        """
        scores = [[0.2, 0.5, 0.3], [0.2, 0.3, 0.5], [0.2, 0.3, 0.5]]
        labels = [[0., 0., 1.], [0., 0., 2.], [0., 0., 0.]]
        weights = [[2.], [1.], [1.]]

        loss = losses.MeanSquaredLoss()
        unweighted_loss = loss(labels, scores).numpy()

        per_list_errors = [
            _mean_squared_error(list_labels, list_scores)
            for list_labels, list_scores in zip(labels, scores)
        ]
        num_entries = sum(len(row) for row in labels)

        self.assertAlmostEqual(unweighted_loss,
                               sum(per_list_errors) / num_entries,
                               places=5)

        # Each list's reference error is scaled by its single per-list weight.
        weighted_errors = sum(
            list_weight[0] * list_error
            for list_weight, list_error in zip(weights, per_list_errors))
        self.assertAlmostEqual(
            loss(labels, scores, weights).numpy(),
            weighted_errors / num_entries,
            places=5)
コード例 #3
0
 def test_pointwise_losses_are_serializable(self):
     """Runs the serialization check on each pointwise loss in turn.

     `assertIsLossSerializable` is defined elsewhere in the test class
     (presumably a config round trip — confirm against its definition).
     Each loss is built right before it is checked, in a fixed order.
     """
     loss_specs = (
         (losses.ClickEMLoss, {'exam_loss_weight': 2.0, 'rel_loss_weight': 5.0}),
         (losses.SigmoidCrossEntropyLoss, {}),
         (losses.MeanSquaredLoss, {}),
     )
     for loss_cls, loss_kwargs in loss_specs:
         self.assertIsLossSerializable(loss_cls(**loss_kwargs))