import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rc

import calibration as cal

# Note: get_estimates, plot_mse_curve, and plot_histograms are helper
# functions defined elsewhere in this module.


def compare_scaling_ce(probs, labels, platt_data_size, bin_data_size, num_bins,
                       ver_base_size=2000, ver_size_increment=1000,
                       max_ver_size=7000, num_resamples=1000,
                       save_prefix='./saved_files/debiased_estimator/',
                       lp=1, Calibrator=cal.PlattTopCalibrator):
    # Train the calibrator on the first platt_data_size points.
    calibrator = Calibrator(num_calibration=platt_data_size, num_bins=num_bins)
    calibrator.train_calibration(probs[:platt_data_size], labels[:platt_data_size])
    predictions = cal.get_top_predictions(probs)
    correct = (predictions == labels).astype(np.int32)
    # Everything after the binning data is held out for verification.
    verification_correct = correct[bin_data_size:]
    verification_probs = calibrator.calibrate(probs[bin_data_size:])
    verification_sizes = list(range(
        ver_base_size, 1 + min(max_ver_size, len(verification_probs)),
        ver_size_increment))
    # Construct equal-mass bins from the calibrated binning data.
    binning_probs = calibrator.calibrate(probs[:bin_data_size])
    bins = cal.get_equal_bins(binning_probs, num_bins=num_bins)

    def plugin_estimator(p, l):
        data = list(zip(p, l))
        binned_data = cal.bin(data, bins)
        return cal.plugin_ce(binned_data, power=lp)

    def debiased_estimator(p, l):
        data = list(zip(p, l))
        binned_data = cal.bin(data, bins)
        if lp == 2:
            return cal.unbiased_l2_ce(binned_data)
        else:
            return cal.normal_debiased_ce(binned_data, power=lp)

    estimators = [plugin_estimator, debiased_estimator]
    estimates = get_estimates(estimators, verification_probs,
                              verification_correct, verification_sizes,
                              num_resamples)
    # Treat the plugin estimate on all verification data as the ground truth.
    true_calibration = plugin_estimator(verification_probs, verification_correct)
    print(true_calibration)
    print(np.sqrt(np.mean(estimates[1, -1, :])))
    errors = np.abs(estimates - true_calibration)
    plot_mse_curve(errors, verification_sizes, num_resamples, save_prefix, num_bins)
    plot_histograms(errors, num_resamples, save_prefix, num_bins)
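# A minimal usage sketch for compare_scaling_ce, assuming synthetic softmax
# outputs in place of real saved model probabilities. The names below
# (_demo_compare_scaling_ce, rng, n, k) are illustrative and not part of the
# calibration library.
def _demo_compare_scaling_ce():
    rng = np.random.RandomState(0)
    n, k = 20000, 10
    raw = rng.randn(n, k)
    probs = np.exp(raw) / np.exp(raw).sum(axis=1, keepdims=True)
    labels = rng.randint(0, k, size=n)
    compare_scaling_ce(probs, labels, platt_data_size=3000,
                       bin_data_size=3000, num_bins=100)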
def compare_scaling_binning_squared_ce(
        probs, labels, platt_data_size, bin_data_size, num_bins,
        ver_base_size=2000, ver_size_increment=1000, max_ver_size=7000,
        num_resamples=1000,
        save_prefix='./saved_files/debiased_estimator/',
        lp=2, Calibrator=cal.PlattBinnerTopCalibrator):
    calibrator = Calibrator(num_calibration=platt_data_size, num_bins=num_bins)
    calibrator.train_calibration(probs[:platt_data_size], labels[:platt_data_size])
    predictions = cal.get_top_predictions(probs)
    correct = (predictions == labels).astype(np.int32)
    verification_correct = correct[bin_data_size:]
    verification_probs = calibrator.calibrate(probs[bin_data_size:])
    verification_sizes = list(range(
        ver_base_size, 1 + min(max_ver_size, len(verification_probs)),
        ver_size_increment))
    # Compare the plugin (debias=False) and debiased (debias=True) estimates
    # of the lp calibration error, raised to the lp power.
    estimators = [
        lambda p, l: cal.get_calibration_error(p, l, p=lp, debias=False)**lp,
        lambda p, l: cal.get_calibration_error(p, l, p=lp, debias=True)**lp
    ]
    estimates = get_estimates(estimators, verification_probs,
                              verification_correct, verification_sizes,
                              num_resamples)
    true_calibration = cal.get_calibration_error(
        verification_probs, verification_correct, p=lp, debias=False)**lp
    print(true_calibration)
    # sqrt of the mean debiased estimate at the largest verification size.
    print(np.sqrt(np.mean(estimates[1, -1, :])))
    errors = np.abs(estimates - true_calibration)
    plot_mse_curve(errors, verification_sizes, num_resamples, save_prefix, num_bins)
    plot_histograms(errors, num_resamples, save_prefix, num_bins)
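# The same sketch applies to compare_scaling_binning_squared_ce; the only
# differences are the calibrator (Platt scaling followed by binning) and that
# the squared (lp=2) calibration error is estimated. Data and sizes are again
# synthetic placeholders.
def _demo_compare_scaling_binning_squared_ce():
    rng = np.random.RandomState(0)
    n, k = 20000, 10
    raw = rng.randn(n, k)
    probs = np.exp(raw) / np.exp(raw).sum(axis=1, keepdims=True)
    labels = rng.randint(0, k, size=n)
    compare_scaling_binning_squared_ce(probs, labels, platt_data_size=3000,
                                       bin_data_size=3000, num_bins=100)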
def lower_bound_experiment(logits, labels, calibration_data_size, bin_data_size,
                           bins_list, save_name='cmp_est',
                           binning_func=cal.get_equal_bins, lp=2,
                           num_samples=1000):
    # Shuffle the logits and labels.
    np.random.seed(0)  # Keep results consistent.
    indices = np.random.choice(list(range(len(logits))), size=len(logits),
                               replace=False)
    logits = [logits[i] for i in indices]
    labels = [labels[i] for i in indices]
    predictions = cal.get_top_predictions(logits)
    probs = cal.get_top_probs(logits)
    correct = (predictions == labels)
    print('num_correct: ', sum(correct))
    # Platt scale on the first chunk of data.
    platt = cal.get_platt_scaler(probs[:calibration_data_size],
                                 correct[:calibration_data_size])
    platt_probs = platt(probs)
    lower, middle, upper = [], [], []
    for num_bins in bins_list:
        bins = binning_func(platt_probs[:calibration_data_size + bin_data_size],
                            num_bins=num_bins)
        verification_probs = platt_probs[calibration_data_size + bin_data_size:]
        verification_correct = correct[calibration_data_size + bin_data_size:]
        verification_data = list(zip(verification_probs, verification_correct))

        def estimator(data):
            binned_data = cal.bin(data, bins)
            return cal.plugin_ce(binned_data, power=lp)

        print('estimate: ', estimator(verification_data))
        estimate_interval = cal.bootstrap_uncertainty(
            verification_data, estimator, num_samples=num_samples)
        lower.append(estimate_interval[0])
        middle.append(estimate_interval[1])
        upper.append(estimate_interval[2])
        print('interval: ', estimate_interval)
    # Plot the results.
    lower_errors = np.array(middle) - np.array(lower)
    upper_errors = np.array(upper) - np.array(middle)
    plt.clf()
    font = {'family': 'normal', 'size': 18}
    rc('font', **font)
    plt.errorbar(bins_list, middle, yerr=[lower_errors, upper_errors],
                 barsabove=True, fmt='none', color='black', capsize=4)
    plt.scatter(bins_list, middle, color='black')
    plt.xlabel(r"No. of bins")
    if lp == 2:
        plt.ylabel("Calibration error")
    else:
        plt.ylabel("l%d Calibration error" % lp)
    plt.xscale('log', base=2)
    ax = plt.gca()
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
    plt.tight_layout()
    plt.savefig(save_name)
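# A sketch of running the lower-bound experiment over a range of bin counts.
# Since the `logits` argument is consumed via cal.get_top_probs, each row
# should already be a probability vector, so we pass synthetic softmax
# outputs; bin counts and data sizes below are illustrative placeholders.
def _demo_lower_bound_experiment():
    rng = np.random.RandomState(0)
    n, k = 10000, 10
    raw = rng.randn(n, k)
    probs = np.exp(raw) / np.exp(raw).sum(axis=1, keepdims=True)
    labels = rng.randint(0, k, size=n)
    lower_bound_experiment(probs, labels, calibration_data_size=3000,
                           bin_data_size=3000, bins_list=[2, 4, 8, 16, 32],
                           save_name='cmp_est.png')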
def eval_top_calibration(probs, logits, labels):
    # `probs` are the (possibly recalibrated) top-label probabilities;
    # `logits` are the original model outputs used to get predictions.
    correct = (cal.get_top_predictions(logits) == labels)
    data = list(zip(probs, correct))
    bins = cal.get_discrete_bins(probs)
    binned_data = cal.bin(data, bins)
    return cal.plugin_ce(binned_data)**2
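# Usage sketch: estimate the squared plugin calibration error of already
# discretized top-label probabilities. The rounding step below is a stand-in
# for a binning calibrator's discrete outputs, and all inputs are synthetic.
def _demo_eval_top_calibration():
    rng = np.random.RandomState(0)
    n, k = 5000, 10
    raw = rng.randn(n, k)
    full_probs = np.exp(raw) / np.exp(raw).sum(axis=1, keepdims=True)
    labels = rng.randint(0, k, size=n)
    top_probs = np.round(cal.get_top_probs(full_probs), 1)  # discretize
    print(eval_top_calibration(top_probs, full_probs, labels))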