def cifar10_experiment_marginal_2_2_3000():
    """Sweep bin counts for marginal calibrators on CIFAR-10 (3000 calib points).

    Loads the CIFAR-10 test logits/labels from ``cifar_logits.dat`` and
    compares ``HistogramMarginalCalibrator`` with
    ``PlattBinnerMarginalCalibrator`` for 10, 20, ..., 100 bins over 20
    trials. The results are plotted with ``plot_mse_ce_curve`` and
    ``plot_ces``.
    """
    logits_file = 'cifar_logits.dat'
    logits, labels = utils.load_test_logits_labels(logits_file)
    bins_list = list(range(10, 101, 10))
    num_trials = 20
    num_calibration = 3000
    # The previous positional call (logits, labels, num_calibration,
    # bins_list, eval_calibration=..., resample=True) does not match the
    # keyword interface used by the other experiments in this file, so it is
    # rewritten to the same pattern as the ImageNet top-label experiment.
    # NOTE(review): the original asked for resample=True; this assumes
    # make_calibration_data_sampler draws a fresh calibration sample per
    # trial — confirm.
    l2_ces, l2_stddevs, mses = vary_bin_calibration(
        data_sampler=make_calibration_data_sampler(
            logits, labels, num_calibration),
        num_bins_list=bins_list,
        Calibrators=[
            calibrators.HistogramMarginalCalibrator,
            calibrators.PlattBinnerMarginalCalibrator,
        ],
        # One evaluator per calibrator, mirroring the other experiments.
        calibration_evaluators=[
            eval_marginal_calibration,
            eval_marginal_calibration,
        ],
        # The module-qualified MSE helper, as used by the other marginal
        # experiment.
        eval_mse=utils.eval_marginal_mse,
        num_trials=num_trials)
    plot_mse_ce_curve(bins_list, l2_ces, mses,
                      xlim=(0.0, 0.001), ylim=(0.0, 0.0075))
    plot_ces(bins_list, l2_ces, l2_stddevs)
def cifar10_experiment_marginal_3_1_1000():
    """Sweep bin counts for marginal calibrators on CIFAR-10 with an eval split.

    Builds a data sampler from 1000 calibration and 1000 evaluation examples
    of ``cifar_logits.dat``. It then compares ``HistogramMarginalCalibrator``
    and ``PlattBinnerMarginalCalibrator`` for 10, 20, ..., 100 bins over 100
    trials, using the biased and unbiased marginal upper-bound evaluators.
    The results are plotted with ``plot_mse_ce_curve`` and ``plot_ces``.
    """
    cifar_logits, cifar_labels = utils.load_test_logits_labels('cifar_logits.dat')
    candidate_bins = [10 * k for k in range(1, 11)]  # 10, 20, ..., 100
    calib_size, eval_size = 1000, 1000
    sampler = make_calibration_eval_data_sampler(
        cifar_logits, cifar_labels, calib_size, eval_size)
    marginal_calibrators = [
        calibrators.HistogramMarginalCalibrator,
        calibrators.PlattBinnerMarginalCalibrator,
    ]
    evaluators = [
        upper_bound_marginal_calibration_biased,
        upper_bound_marginal_calibration_unbiased,
    ]
    ces, ce_stddevs, mse_values = vary_bin_calibration(
        data_sampler=sampler,
        num_bins_list=candidate_bins,
        Calibrators=marginal_calibrators,
        calibration_evaluators=evaluators,
        eval_mse=utils.eval_marginal_mse,
        num_trials=100,
    )
    plot_mse_ce_curve(candidate_bins, ces, mse_values,
                      xlim=(0.0, 0.002), ylim=(0.0, 0.04))
    plot_ces(candidate_bins, ces, ce_stddevs)
def cifar_experiment(savefile, binning_func=utils.get_equal_bins, lp=2):
    """Run ``lower_bound_experiment`` on the CIFAR-10 test logits.

    Seeds NumPy's global RNG with 0 and loads ``cifar_logits.dat``. It then
    uses 1000 examples for the calibration data size and 1000 for the bin
    data size, and sweeps 2, 4, ..., 128 bins.

    Args:
        savefile: forwarded to ``lower_bound_experiment`` as ``save_name``.
        binning_func: forwarded as ``binning_func``.
        lp: forwarded as ``lp``.
    """
    # Fixed global seed; presumably lower_bound_experiment samples via
    # np.random, which makes runs repeatable — confirm.
    np.random.seed(0)
    test_logits, test_labels = utils.load_test_logits_labels('cifar_logits.dat')
    power_of_two_bins = [2 ** exponent for exponent in range(1, 8)]  # 2..128
    calib_size = 1000
    binning_size = 1000
    lower_bound_experiment(
        test_logits,
        test_labels,
        calib_size,
        binning_size,
        bins_list=power_of_two_bins,
        save_name=savefile,
        binning_func=binning_func,
        lp=lp,
    )
def imagenet_experiment_top_1_1_1000():
    """Sweep bin counts for top-label calibrators on ImageNet (1000 calib points).

    Loads ``imagenet_logits.dat`` and builds a calibration data sampler of
    size 1000. It compares ``HistogramTopCalibrator`` and
    ``PlattBinnerTopCalibrator`` for 10, 20, ..., 100 bins over 100 trials,
    scoring both with ``eval_top_calibration``. The results are plotted with
    ``plot_mse_ce_curve`` and ``plot_ces``.
    """
    imagenet_logits, imagenet_labels = utils.load_test_logits_labels(
        'imagenet_logits.dat')
    candidate_bins = [10 * k for k in range(1, 11)]  # 10, 20, ..., 100
    calib_size = 1000
    sampler = make_calibration_data_sampler(
        imagenet_logits, imagenet_labels, calib_size)
    top_calibrators = [
        calibrators.HistogramTopCalibrator,
        calibrators.PlattBinnerTopCalibrator,
    ]
    # The same top-label evaluator is used for every calibrator.
    evaluators = [eval_top_calibration] * len(top_calibrators)
    ces, ce_stddevs, mse_values = vary_bin_calibration(
        data_sampler=sampler,
        num_bins_list=candidate_bins,
        Calibrators=top_calibrators,
        calibration_evaluators=evaluators,
        eval_mse=utils.eval_top_mse,
        num_trials=100,
    )
    plot_mse_ce_curve(candidate_bins, ces, mse_values)
    plot_ces(candidate_bins, ces, ce_stddevs)
color='red', capsize=4, label='plugin') plt.errorbar(verification_sizes, accumulated_errors[1], yerr=[error_bars_90[1], error_bars_90[1]], barsabove=True, color='blue', capsize=4, label='debiased') plt.ylabel("Mean-Squared-Error") plt.xlabel("Number of Samples") plt.legend(loc='upper right') plt.show() plt.ylabel("Number of estimates") plt.xlabel("Absolute deviation from ground truth") bins = np.linspace(np.min(errors[:, 0, :]), np.max(errors[:, 0, :]), 40) plt.hist(errors[0][0], bins, alpha=0.5, label='plugin') plt.hist(errors[1][0], bins, alpha=0.5, label='debiased') plt.legend(loc='upper right') plt.gca().yaxis.set_major_formatter(PercentFormatter(xmax=num_resamples)) plt.show() if __name__ == "__main__": args = parser.parse_args() logits, labels = utils.load_test_logits_labels(args.logits_file) compare_estimators(logits, labels, args.platt_data_size, args.bin_data_size, args.num_bins)