Example 1
  def test_expected_calibration_error_all_right(self):
    # Perfectly confident and perfectly correct predictions should have
    # zero calibration error, regardless of the number of bins.
    num_bins = 90
    ece = metrics_lib.expected_calibration_error(
        np.ones(10), np.ones(10), bins=num_bins)
    self.assertAlmostEqual(ece, 0.)
    ece = metrics_lib.expected_calibration_error(
        np.zeros(10), np.zeros(10), bins=num_bins)
    self.assertAlmostEqual(ece, 0.)
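For context, a minimal sketch of the binned ECE these tests exercise. This is an assumption about how metrics_lib computes the metric, not its actual implementation; ece_sketch and its equal-width binning are illustrative choices.

import numpy as np

def ece_sketch(probs, labels, bins=10):
  # Weighted average over bins of |mean confidence - empirical accuracy|,
  # weighted by the fraction of samples falling in each bin.
  probs = np.asarray(probs, dtype=float)
  labels = np.asarray(labels, dtype=float)
  edges = np.linspace(0.0, 1.0, bins + 1)
  ece = 0.0
  for i, (lo, hi) in enumerate(zip(edges[:-1], edges[1:])):
    # Close the last bin on the right so probs == 1.0 are counted.
    mask = (probs >= lo) & ((probs < hi) if i < bins - 1 else (probs <= hi))
    if mask.any():
      ece += mask.mean() * abs(probs[mask].mean() - labels[mask].mean())
  return ece

Under this sketch, ece_sketch(np.ones(10), np.ones(10)) is exactly 0.0: every sample falls in the top bin, where mean confidence and accuracy both equal 1, which is the behavior Example 1 asserts.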
Example 2
  def test_expected_calibration_error(self):
    np.random.seed(1)
    nsamples = 100
    probs = np.linspace(0, 1, nsamples)
    # Draw labels so that P(label = 1) equals the predicted probability,
    # making `probs` well calibrated; halving the probabilities yields a
    # miscalibrated predictor whose ECE should be strictly larger.
    labels = np.random.rand(nsamples) < probs
    ece = metrics_lib.expected_calibration_error(probs, labels)
    bad_ece = metrics_lib.expected_calibration_error(probs / 2, labels)

    self.assertBetween(ece, 0, 1)
    self.assertBetween(bad_ece, 0, 1)
    self.assertLess(ece, bad_ece)

    # Repeat the comparison with equal-mass (quantile) bins instead of the
    # default equal-width bins.
    bins = metrics_lib.get_quantile_bins(10, probs)
    quantile_ece = metrics_lib.expected_calibration_error(probs, labels, bins)
    bad_quantile_ece = metrics_lib.expected_calibration_error(
        probs / 2, labels, bins)

    self.assertBetween(quantile_ece, 0, 1)
    self.assertBetween(bad_quantile_ece, 0, 1)
    self.assertLess(quantile_ece, bad_quantile_ece)
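A plausible stand-in for metrics_lib.get_quantile_bins, mirroring the (num_bins, probs) call above. The name quantile_bins_sketch and the exact edge convention are assumptions; only the equal-mass intent is implied by the test.

import numpy as np

def quantile_bins_sketch(num_bins, probs):
  # Bin edges at evenly spaced quantiles of the predicted probabilities,
  # so each bin holds roughly the same number of samples.
  quantiles = np.linspace(0.0, 1.0, num_bins + 1)
  return np.quantile(np.asarray(probs, dtype=float), quantiles)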
Example 3
  def test_expected_calibration_error_bad_input(self):
    # A single sample is rejected (presumably too few points to bin).
    with self.assertRaises(ValueError):
      metrics_lib.expected_calibration_error(np.ones(1), np.ones(1))
    # Mismatched shapes between probs and labels.
    with self.assertRaises(ValueError):
      metrics_lib.expected_calibration_error(np.ones(100), np.ones(1))
    # Labels must be binary outcomes; 0.5 is not a valid label.
    with self.assertRaises(ValueError):
      metrics_lib.expected_calibration_error(np.ones(100), np.ones(100) * 0.5)
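The three ValueErrors above suggest input checks roughly like the hypothetical validate_ece_inputs below. This is inferred from the test alone, not from metrics_lib's documented behavior; the minimum-sample threshold in particular is a guess.

import numpy as np

def validate_ece_inputs(probs, labels):
  probs = np.asarray(probs)
  labels = np.asarray(labels)
  # Second case: probs and labels must pair up one-to-one.
  if probs.shape != labels.shape:
    raise ValueError('probs and labels must have the same shape.')
  # First case: a single sample cannot be meaningfully binned.
  if probs.size < 2:
    raise ValueError('need at least two samples.')
  # Third case: labels must be binary outcomes.
  if not np.all(np.isin(labels, (0, 1))):
    raise ValueError('labels must be 0 or 1.')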