def test_get_binary_ce(self, p, true_ce):
    """Check calibration error for 1-D (binary) probabilities.

    Asserts that:
      * `_get_ce` without debiasing matches `true_ce`;
      * `get_calibration_error` with `debias=False` gives the same value;
      * `_get_ce` with debiasing gives a value strictly below `true_ce`.
    """
    probs = [0.5, 0.5, 0.5, 0.6, 0.5, 0.6, 0.6, 0.6, 0.6]
    labels = [0, 1, 0, 1, 0, 1, 1, 1, 0]
    # Every _get_ce call here uses discrete bins with no fixed bin count.
    discrete_kwargs = {'num_bins': None, 'binning_scheme': get_discrete_bins}

    raw_ce = _get_ce(probs, labels, p, debias=False, **discrete_kwargs)
    self.assertAlmostEqual(raw_ce, true_ce)

    # The public wrapper should agree with the direct call.
    self.assertAlmostEqual(
        raw_ce, get_calibration_error(probs, labels, p=p, debias=False))

    debiased_ce = _get_ce(probs, labels, p, debias=True, **discrete_kwargs)
    self.assertLess(debiased_ce, true_ce)
def test_get_two_label_ce(self, p, true_ce):
    """Check calibration error for 2-column probabilities.

    The second column of each row equals the 1-D probabilities used in
    `test_get_binary_ce`, and the labels are identical. Asserts that:
      * `_get_ce` without debiasing matches `true_ce`;
      * `get_calibration_error` with `debias=False` gives the same value;
      * top-label mode matches the marginal estimate (two labels);
      * both debiased estimates come out strictly lower.
    """
    even_row = [0.5, 0.5]
    tilted_row = [0.4, 0.6]
    probs = [even_row, even_row, even_row, tilted_row, even_row,
             tilted_row, tilted_row, tilted_row, tilted_row]
    labels = [0, 1, 0, 1, 0, 1, 1, 1, 0]
    # Every _get_ce call here uses discrete bins with no fixed bin count.
    discrete_kwargs = {'num_bins': None, 'binning_scheme': get_discrete_bins}

    marginal_ce = _get_ce(probs, labels, p, debias=False, **discrete_kwargs)
    self.assertAlmostEqual(marginal_ce, true_ce)

    # The wrapper must call _get_ce with matching options.
    self.assertAlmostEqual(
        marginal_ce, get_calibration_error(probs, labels, p=p, debias=False))

    # With only two labels, marginal and top-label calibration coincide.
    self.assertAlmostEqual(
        get_calibration_error(probs, labels, p=p, debias=False,
                              mode='top-label'),
        marginal_ce)
    self.assertLess(
        get_calibration_error(probs, labels, p=p, debias=True,
                              mode='top-label'),
        marginal_ce)

    self.assertLess(
        _get_ce(probs, labels, p, debias=True, **discrete_kwargs),
        true_ce)
def test_get_three_label_ce(self, p, true_marginal_ce, true_top_ce):
    """Check marginal and top-label calibration error with three labels.

    With three labels the marginal and top-label estimates can differ, so
    each is checked against its own expected value. Asserts that:
      * `_get_ce` without debiasing matches `true_marginal_ce`;
      * `get_calibration_error` with `debias=False` gives the same value;
      * top-label mode matches `true_top_ce`;
      * both debiased estimates are strictly below their expected values.
    """
    l0 = [0.6, 0.3, 0.1]
    l1 = [0.1, 0.8, 0.1]
    l2 = [0.1, 0.0, 0.9]
    probs = np.array([l0, l0, l0, l0, l0, l0, l1, l1, l1, l1, l2, l2, l2, l2])
    labels = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2])
    # Shuffle rows so examples are not grouped by label. A fixed seed keeps
    # the order (and therefore any failure) reproducible across runs.
    perm = np.random.RandomState(0).permutation(len(labels))
    probs, labels = probs[perm], labels[perm]

    pred_ce = _get_ce(probs, labels, p, debias=False, num_bins=None,
                      binning_scheme=get_discrete_bins)
    self.assertAlmostEqual(pred_ce, true_marginal_ce)

    # Check that the wrapper calls _get_ce with the right options.
    wrapper_ce = get_calibration_error(probs, labels, p=p, debias=False)
    self.assertAlmostEqual(pred_ce, wrapper_ce)

    top_label_ce = get_calibration_error(probs, labels, p=p, debias=False,
                                         mode='top-label')
    self.assertAlmostEqual(top_label_ce, true_top_ce)

    debiased_top_label_ce = get_calibration_error(probs, labels, p=p,
                                                  debias=True,
                                                  mode='top-label')
    self.assertLess(debiased_top_label_ce, true_top_ce)

    debiased_pred_ce = _get_ce(probs, labels, p, debias=True, num_bins=None,
                               binning_scheme=get_discrete_bins)
    self.assertLess(debiased_pred_ce, true_marginal_ce)