Пример #1
0
    def test_soft_trimmed_degenerated(self):
        """Covers degenerate quantile configurations of the trimmed loss.

        An inverted range (start_quantile > end_quantile) must raise
        ValueError at construction time. For each valid range in the table
        below, the private ``_get_target_weights_and_indices`` helper must
        return a weights sequence of the expected length together with the
        expected index.
        """
        with self.assertRaises(ValueError):
            losses.SoftTrimmedRegressionLoss(start_quantile=0.6,
                                             end_quantile=0.1)

        # Each row: (start_quantile, end_quantile, expected len(weights),
        # expected index). Rows run in order, mirroring one check per case.
        cases = (
            (0.1, 0.3, 3, 1),
            (0.0, 0.2, 2, 0),
            (0.4, 1.0, 2, 1),
            (0.0, 1.0, 1, 0),
        )
        for start, end, expected_len, expected_index in cases:
            trimmed = losses.SoftTrimmedRegressionLoss(start_quantile=start,
                                                       end_quantile=end)
            weights, index = trimmed._get_target_weights_and_indices()
            self.assertLen(weights, expected_len)
            self.assertEqual(index, expected_index)
Пример #2
0
 def test_soft_trimmed_degenerated(self, start, end):
     """Evaluating the loss on a degenerate quantile range must not raise.

     Args:
       start: forwarded as ``start_quantile`` (presumably supplied by a
         parameterization decorator outside this view -- confirm).
       end: forwarded as ``end_quantile``.

     Only the loss evaluation is guarded. A ValueError raised by the
     constructor itself propagates as a test error, not a test failure.
     """
     trimmed_loss = losses.SoftTrimmedRegressionLoss(
         start_quantile=start, end_quantile=end)
     y_true, y_pred = self._y_true, self._y_pred
     try:
         trimmed_loss(y_true, y_pred)
     except ValueError:
         # Turn the exception into a failure with a readable message.
         self.fail(
             'SoftTrimmedRegressionLoss raised ValueError unexpectedly!')
Пример #3
0
 def test_soft_trimmed(self, start, end, power):
     """Compares the loss against a directly computed trimmed mean.

     The reference value is the mean of
     ``self._values[int(start * n):int(end * n)] ** power``, where ``n`` is
     ``self._num_points``. Both tolerances are set to 0.2, presumably because
     the soft trimming only approximates a hard cut (confirm against the
     loss implementation).

     Args:
       start: forwarded as ``start_quantile``.
       end: forwarded as ``end_quantile``.
       power: forwarded as ``power`` and used for the reference values.
     """
     trimmed_loss = losses.SoftTrimmedRegressionLoss(start_quantile=start,
                                                     end_quantile=end,
                                                     power=power)
     actual = trimmed_loss(self._y_true, self._y_pred)

     # Quantiles are scaled to positions and truncated toward zero by int().
     num_points = self._num_points
     lo, hi = int(start * num_points), int(end * num_points)
     expected = tf.math.reduce_mean(tf.pow(self._values[lo:hi], power))

     self.assertAllClose(actual, expected, rtol=0.2, atol=0.2)