def test_expected_calibration_error_all_right(self):
    """ECE is ~0 when predictions and labels agree exactly (all 1s / all 0s)."""
    num_bins = 90
    for make_array in (np.ones, np.zeros):
        ece = metrics_lib.expected_calibration_error(
            make_array(10), make_array(10), bins=num_bins)
        self.assertAlmostEqual(ece, 0.)
def test_expected_calibration_error(self):
    """Calibrated probabilities must score a lower ECE than miscalibrated ones."""
    np.random.seed(1)
    nsamples = 100
    predictions = np.linspace(0, 1, nsamples)
    # Sample binary outcomes with success probability equal to the prediction,
    # so `predictions` is (approximately) calibrated by construction.
    outcomes = np.random.rand(nsamples) < predictions

    good_ece = metrics_lib.expected_calibration_error(predictions, outcomes)
    # Halving the probabilities breaks calibration.
    bad_ece = metrics_lib.expected_calibration_error(predictions / 2, outcomes)
    self.assertBetween(good_ece, 0, 1)
    self.assertBetween(bad_ece, 0, 1)
    self.assertLess(good_ece, bad_ece)

    # Same comparison with quantile-derived bin edges.
    qbins = metrics_lib.get_quantile_bins(10, predictions)
    good_qece = metrics_lib.expected_calibration_error(
        predictions, outcomes, qbins)
    bad_qece = metrics_lib.expected_calibration_error(
        predictions / 2, outcomes, qbins)
    self.assertBetween(good_qece, 0, 1)
    self.assertBetween(bad_qece, 0, 1)
    self.assertLess(good_qece, bad_qece)
def test_expected_calibration_error_bad_input(self):
    """Each malformed (probs, labels) pair must raise ValueError."""
    bad_inputs = [
        (np.ones(1), np.ones(1)),            # NOTE(review): presumably too few samples
        (np.ones(100), np.ones(1)),          # mismatched array lengths
        (np.ones(100), np.ones(100) * 0.5),  # NOTE(review): presumably non-binary labels
    ]
    for probs, labels in bad_inputs:
        with self.assertRaises(ValueError):
            metrics_lib.expected_calibration_error(probs, labels)