def test_soft_trimmed_degenerated(self):
    """Checks boundary quantile configurations of SoftTrimmedRegressionLoss.

    An inverted range (start_quantile > end_quantile) must be rejected at
    construction time. For the remaining ranges (including ones touching
    0.0, 1.0, or both), the private ``_get_target_weights_and_indices``
    helper must return the expected number of weights and the expected
    index.
    """
    # NOTE(review): another method named test_soft_trimmed_degenerated with a
    # (start, end) signature exists in this source; if both are defined in
    # the same class, the later definition shadows this one and it never
    # runs. Confirm they belong to different TestCase classes.
    with self.assertRaises(ValueError):
        losses.SoftTrimmedRegressionLoss(start_quantile=0.6, end_quantile=0.1)

    # (start_quantile, end_quantile, expected len(weights), expected index),
    # checked in the same order as the original sequential assertions.
    cases = (
        (0.1, 0.3, 3, 1),
        (0.0, 0.2, 2, 0),
        (0.4, 1.0, 2, 1),
        (0.0, 1.0, 1, 0),
    )
    for lo, hi, want_len, want_index in cases:
        trimmed = losses.SoftTrimmedRegressionLoss(
            start_quantile=lo, end_quantile=hi)
        got_weights, got_index = trimmed._get_target_weights_and_indices()
        self.assertLen(got_weights, want_len)
        self.assertEqual(got_index, want_index)
def test_soft_trimmed_degenerated(self, start, end):
    """Checks that degenerate-but-valid quantile ranges do not raise.

    Both constructing the loss and evaluating it on the test fixtures are
    guarded. A ValueError from either step is reported as a test failure
    with a clear message rather than as an unexpected error.

    Args:
      start: value passed as ``start_quantile`` (presumably supplied by a
        parameterization decorator outside this view).
      end: value passed as ``end_quantile``.
    """
    try:
        # The constructor validates its quantiles (an inverted range raises
        # ValueError), so it must be inside the guard too; otherwise a
        # constructor regression would bypass the intended failure message.
        loss_fn = losses.SoftTrimmedRegressionLoss(
            start_quantile=start, end_quantile=end)
        loss_fn(self._y_true, self._y_pred)
    except ValueError:
        self.fail(
            'SoftTrimmedRegressionLoss raised ValueError unexpectedly!')
def test_soft_trimmed(self, start, end, power):
    """Compares the loss with a hand-computed trimmed mean of powered values.

    The reference slices ``self._values`` between the positions given by the
    start/end quantiles times ``self._num_points``, raises that slice to
    ``power`` and averages it. The loss must match within a relative and
    absolute tolerance of 0.2.
    """
    trimmed_loss = losses.SoftTrimmedRegressionLoss(
        start_quantile=start, end_quantile=end, power=power)
    actual = trimmed_loss(self._y_true, self._y_pred)

    # NOTE(review): int() truncates, so floating-point error in these
    # products (e.g. 0.29 * 100 == 28.999999999999996) can drop a point from
    # the reference window. Confirm this matches the loss's own indexing.
    lo = int(start * self._num_points)
    hi = int(end * self._num_points)
    reference = tf.math.reduce_mean(tf.pow(self._values[lo:hi], power))

    self.assertAllClose(actual, reference, rtol=0.2, atol=0.2)