import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import PercentFormatter

import utils  # Assumed: the project's calibration utilities module (get_discrete_bins, bin, plugin_ce, ...).


def evaluate(function, calibrator, dist, n):
    zs = dist(size=n)
    ps = function(zs)
    phats = calibrator.calibrate(zs)
    bins = utils.get_discrete_bins(phats)
    data = list(zip(phats, ps))
    binned_data = utils.bin(data, bins)
    return utils.plugin_ce(binned_data)**2
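
# Hedged usage sketch (not part of the original file): `dist` is assumed to be
# a sampler taking a `size` keyword, `function` maps samples to the true
# probability P(y=1 | z), and `calibrator` exposes a .calibrate(zs) method
# returning predicted probabilities. The sigmoid/Gaussian choices below are
# illustrative assumptions only.
def example_evaluate_synthetic(calibrator, n=10000):
    true_prob = lambda zs: 1.0 / (1.0 + np.exp(-zs))          # true P(y=1 | z)
    gaussian = lambda size: np.random.normal(0.0, 1.0, size=size)
    # Returns the squared plugin L2 calibration error on n synthetic points.
    return evaluate(true_prob, calibrator, gaussian, n)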
def eval_top_calibration(probs, logits, labels, plugin=True):
    correct = (utils.get_top_predictions(logits) == labels)
    data = list(zip(probs, correct))
    bins = utils.get_discrete_bins(probs)
    binned_data = utils.bin(data, bins)
    if plugin:
        return utils.plugin_ce(binned_data)**2
    else:
        return utils.improved_unbiased_square_ce(binned_data)
def eval_marginal_calibration(probs, logits, labels, plugin=True):
    # Compute the calibration error per class, then take the average.
    ces = []
    k = logits.shape[1]
    labels_one_hot = utils.get_labels_one_hot(np.array(labels), k)
    for c in range(k):
        probs_c = probs[:, c]
        labels_c = labels_one_hot[:, c]
        data_c = list(zip(probs_c, labels_c))
        bins_c = utils.get_discrete_bins(probs_c)
        binned_data_c = utils.bin(data_c, bins_c)
        ce_c = utils.plugin_ce(binned_data_c)**2
        ces.append(ce_c)
    return np.mean(ces)
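
# Hedged usage sketch (not part of the original file), illustrating the shapes
# these helpers expect: `probs` and `logits` are (n, k) arrays, `labels` is a
# length-n array of integer class indices, and utils.get_top_probs is assumed
# to return the maximum probability in each row.
def example_eval_calibration(probs, logits, labels):
    # Squared L2 calibration error of the top (argmax) prediction.
    top_ce = eval_top_calibration(utils.get_top_probs(probs), logits, labels, plugin=True)
    # Average squared per-class (marginal) calibration error.
    marginal_ce = eval_marginal_calibration(probs, logits, labels)
    return top_ce, marginal_ce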
# Note: this helper is a fragment of a nested closure; `bins` (the bins to use)
# and `lp` (the Lp power of the plugin calibration error) must be defined in
# the enclosing scope.
def estimator(data):
    binned_data = utils.bin(data, bins)
    return utils.plugin_ce(binned_data, power=lp)
def calibrate_marginals_experiment(logits, labels, k):
    num_calib = 3000
    num_bin = 3000
    num_cert = 4000
    assert logits.shape[0] == num_calib + num_bin + num_cert
    num_bins = 100
    bootstrap_samples = 100
    # First split by label? To ensure equal class numbers? Do this later.
    labels = utils.get_labels_one_hot(labels[:, 0], k)
    mse = np.mean(np.square(labels - logits))
    print('original mse is ', mse)
    calib_logits = logits[:num_calib, :]
    calib_labels = labels[:num_calib, :]
    bin_logits = logits[num_calib:num_calib + num_bin, :]
    bin_labels = labels[num_calib:num_calib + num_bin, :]
    cert_logits = logits[num_calib + num_bin:, :]
    cert_labels = labels[num_calib + num_bin:, :]
    mses = []
    unbiased_ces = []
    biased_ces = []
    std_unbiased_ces = []
    std_biased_ces = []
    for num_bins in range(10, 21, 10):
        # Train a platt scaler and binner.
        platts = []
        platt_binners_equal_points = []
        for l in range(k):
            platt_l = utils.get_platt_scaler(calib_logits[:, l], calib_labels[:, l])
            platts.append(platt_l)
            cal_logits_l = platt_l(calib_logits[:, l])
            # bins_l = utils.get_equal_bins(cal_logits_l, num_bins=num_bins)  # Get
            # bins_l = utils.get_equal_prob_bins(num_bins=num_bins)
            bins_l = [0.0012, 0.05, 0.01, 0.95, 0.985, 1.0]
            cal_bin_logits_l = platt_l(bin_logits[:, l])
            platt_binner_l = utils.get_discrete_calibrator(cal_bin_logits_l, bins_l)
            platt_binners_equal_points.append(platt_binner_l)

        # Write a function that takes data and outputs the mse, ce.
        def get_mse_ce(logits, labels, ce_est):
            mses = []
            ces = []
            logits = np.array(logits)
            labels = np.array(labels)
            for l in range(k):
                cal_logits_l = platt_binners_equal_points[l](platts[l](logits[:, l]))
                data = list(zip(cal_logits_l, labels[:, l]))
                bins_l = utils.get_discrete_bins(cal_logits_l)
                binned_data = utils.bin(data, bins_l)
                # probs = platts[l](logits[:, l])
                # for p in [1, 5, 10, 20, 50, 85, 88.5, 92, 94, 96, 98, 100]:
                #     print(np.percentile(probs, p))
                # import time
                # time.sleep(100)
                # print('lengths')
                # print([len(d) for d in binned_data])
                ces.append(ce_est(binned_data))
                mses.append(np.mean([(prob - label)**2 for prob, label in data]))
            return np.mean(mses), np.mean(ces)

        def plugin_ce_squared(data):
            logits, labels = zip(*data)
            return get_mse_ce(logits, labels, lambda x: utils.plugin_ce(x)**2)[1]

        def mse(data):
            logits, labels = zip(*data)
            return get_mse_ce(logits, labels, lambda x: utils.plugin_ce(x)**2)[0]

        def unbiased_ce_squared(data):
            logits, labels = zip(*data)
            return get_mse_ce(logits, labels, utils.improved_unbiased_square_ce)[1]

        mse, improved_unbiased_ce = get_mse_ce(
            cert_logits, cert_labels, utils.improved_unbiased_square_ce)
        mse, biased_ce = get_mse_ce(
            cert_logits, cert_labels, lambda x: utils.plugin_ce(x)**2)
        mses.append(mse)
        unbiased_ces.append(improved_unbiased_ce)
        biased_ces.append(biased_ce)
        print('biased ce: ', np.sqrt(biased_ce))
        print('mse: ', mse)
        print('improved ce: ', np.sqrt(improved_unbiased_ce))
        data = list(zip(list(cert_logits), list(cert_labels)))
        std_biased_ces.append(
            utils.bootstrap_std(data, plugin_ce_squared, num_samples=bootstrap_samples))
        std_unbiased_ces.append(
            utils.bootstrap_std(data, unbiased_ce_squared, num_samples=bootstrap_samples))

    std_multiplier = 1.3  # For one-sided 90% confidence interval.
    upper_unbiased_ces = list(
        map(lambda p: np.sqrt(p[0] + std_multiplier * p[1]),
            zip(unbiased_ces, std_unbiased_ces)))
    upper_biased_ces = list(
        map(lambda p: np.sqrt(p[0] + std_multiplier * p[1]),
            zip(biased_ces, std_biased_ces)))

    # Get points on the Pareto curve, and plot them.
    def get_pareto_points(data):
        pareto_points = []

        def dominated(p1, p2):
            return p1[0] >= p2[0] and p1[1] >= p2[1]

        for datum in data:
            num_dominated = sum(map(lambda x: dominated(datum, x), data))
            if num_dominated == 1:
                pareto_points.append(datum)
        return pareto_points

    print(get_pareto_points(
        list(zip(upper_unbiased_ces, mses, list(range(5, 101, 5))))))
    print(get_pareto_points(
        list(zip(upper_biased_ces, mses, list(range(5, 101, 5))))))
    plot_unbiased_ces, plot_unbiased_mses = zip(
        *get_pareto_points(list(zip(upper_unbiased_ces, mses))))
    plot_biased_ces, plot_biased_mses = zip(
        *get_pareto_points(list(zip(upper_biased_ces, mses))))
    plt.title("MSE vs Calibration Error")
    plt.scatter(plot_unbiased_ces, plot_unbiased_mses, c='red', marker='o', label='Ours')
    plt.scatter(plot_biased_ces, plot_biased_mses, c='blue', marker='s', label='Plugin')
    plt.legend(loc='upper left')
    plt.ylim(0.0, 0.013)
    plt.xlabel("L2 Calibration Error")
    plt.ylabel("Mean-Squared Error")
    plt.show()
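
# Hedged usage sketch (assumption, not from the original experiments): the
# function above asserts logits.shape[0] == 3000 + 3000 + 4000 and indexes
# labels[:, 0], so it expects `logits` of shape (10000, k) and `labels` of
# shape (10000, 1) containing integer class indices; k=10 is illustrative.
def example_run_marginals_experiment(logits, labels, k=10):
    assert logits.shape == (10000, k) and labels.shape == (10000, 1)
    calibrate_marginals_experiment(np.array(logits), np.array(labels), k)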
def compare_estimators(logits,
                       labels,
                       platt_data_size,
                       bin_data_size,
                       num_bins,
                       ver_base_size=2000,
                       ver_size_increment=1000,
                       num_resamples=100,
                       save_name='cmp_est'):
    # Convert logits to prediction, probs.
    predictions = utils.get_top_predictions(logits)
    probs = utils.get_top_probs(logits)
    correct = (predictions == labels)
    # Platt scale on first chunk of data.
    platt = utils.get_platt_scaler(probs[:platt_data_size], correct[:platt_data_size])
    platt_probs = platt(probs)
    estimator_names = ['biased', 'unbiased']
    estimators = [lambda x: utils.plugin_ce(x)**2, utils.improved_unbiased_square_ce]
    bins = utils.get_equal_bins(platt_probs[:platt_data_size + bin_data_size],
                                num_bins=num_bins)
    binner = utils.get_discrete_calibrator(
        platt_probs[platt_data_size:platt_data_size + bin_data_size], bins)
    verification_probs = binner(platt_probs[platt_data_size + bin_data_size:])
    verification_correct = correct[platt_data_size + bin_data_size:]
    verification_data = list(zip(verification_probs, verification_correct))
    verification_sizes = list(
        range(ver_base_size, len(verification_probs) + 1, ver_size_increment))
    # We want to compare the two estimators when varying the number of samples.
    # However, a single point of comparison does not tell us much about the estimators.
    # So we use resampling - we resample from the test set many times, and run the
    # estimators on the resamples, and store these values. This gives us a sense of
    # the range of values the estimator might output.
    # So estimates[i][j][k] stores the estimate when using estimator i, with
    # verification_sizes[j] samples, in the k-th resampling.
    estimates = np.zeros((len(estimators), len(verification_sizes), num_resamples))
    # We also store the certified estimates. These represent the upper bounds we get
    # using each estimator. They are computed using the std-dev of the estimator
    # estimated by bootstrap.
    cert_estimates = np.zeros((len(estimators), len(verification_sizes), num_resamples))
    for ver_idx, verification_size in zip(range(len(verification_sizes)),
                                          verification_sizes):
        for k in range(num_resamples):
            # Resample with replacement from the verification set.
            indices = np.random.choice(list(range(len(verification_data))),
                                       size=verification_size,
                                       replace=True)
            cur_verification_data = [verification_data[i] for i in indices]
            cur_verification_probs = [verification_probs[i] for i in indices]
            bins = utils.get_discrete_bins(cur_verification_probs)
            # Compute estimates for each estimator.
            for i in range(len(estimators)):
                def estimator(data):
                    binned_data = utils.bin(data, bins)
                    return estimators[i](binned_data)
                cur_estimate = estimator(cur_verification_data)
                estimates[i][ver_idx][k] = cur_estimate
                # cert_resampling_estimates[j].append(
                #     cur_estimate + utils.bootstrap_std(
                #         cur_verification_data, estimator, num_samples=20))
    estimates = np.sort(estimates, axis=-1)
    lower_bound = int(0.1 * num_resamples)
    median = int(0.5 * num_resamples)
    upper_bound = int(0.9 * num_resamples)
    lower_estimates = estimates[:, :, lower_bound]
    upper_estimates = estimates[:, :, upper_bound]
    median_estimates = estimates[:, :, median]

    # We can also compute the MSEs of the estimator.
    bins = utils.get_discrete_bins(verification_probs)
    true_calibration = utils.plugin_ce(utils.bin(verification_data, bins))**2
    print(true_calibration)
    print(np.sqrt(np.mean(estimates[1, -1, :])))
    errors = np.abs(estimates - true_calibration)
    accumulated_errors = np.mean(errors, axis=-1)
    error_bars_90 = 1.645 * np.std(errors, axis=-1) / np.sqrt(num_resamples)
    print(accumulated_errors)
    plt.errorbar(verification_sizes,
                 accumulated_errors[0],
                 yerr=[error_bars_90[0], error_bars_90[0]],
                 barsabove=True,
                 color='red',
                 capsize=4,
                 label='plugin')
    plt.errorbar(verification_sizes,
                 accumulated_errors[1],
                 yerr=[error_bars_90[1], error_bars_90[1]],
                 barsabove=True,
                 color='blue',
                 capsize=4,
                 label='debiased')
    plt.ylabel("Mean-Squared-Error")
    plt.xlabel("Number of Samples")
    plt.legend(loc='upper right')
    plt.show()
    plt.ylabel("Number of estimates")
    plt.xlabel("Absolute deviation from ground truth")
    bins = np.linspace(np.min(errors[:, 0, :]), np.max(errors[:, 0, :]), 40)
    plt.hist(errors[0][0], bins, alpha=0.5, label='plugin')
    plt.hist(errors[1][0], bins, alpha=0.5, label='debiased')
    plt.legend(loc='upper right')
    plt.gca().yaxis.set_major_formatter(PercentFormatter(xmax=num_resamples))
    plt.show()
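
# Hedged usage sketch (assumption, not from the original experiments): the data
# in compare_estimators is split by position, so the first platt_data_size
# points fit the Platt scaler, the next bin_data_size fit the discrete binner,
# and the remainder is used for verification/resampling. The sizes below are
# illustrative only.
def example_compare_estimators(logits, labels):
    compare_estimators(logits, labels,
                       platt_data_size=2000,
                       bin_data_size=2000,
                       num_bins=100,
                       ver_base_size=2000,
                       ver_size_increment=1000,
                       num_resamples=100)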