def cifar10_experiment_marginal_2_2_3000():
    """Compare two marginal calibrators on the CIFAR logits and plot results.

    Loads logits and labels from ``cifar_logits.dat``, runs
    ``vary_bin_calibration`` for the histogram and Platt-binner marginal
    calibrators over the bin counts 10, 20, ..., 100, and then draws the
    MSE-vs-calibration-error curve followed by the calibration-error plot.
    """
    test_logits, test_labels = utils.load_test_logits_labels(
        'cifar_logits.dat')
    bin_counts = list(range(10, 101, 10))
    trial_count = 20
    # Presumably the number of points used to fit each calibrator; the
    # value is passed positionally, so confirm against vary_bin_calibration.
    calibration_size = 3000
    compared_calibrators = [
        calibrators.HistogramMarginalCalibrator,
        calibrators.PlattBinnerMarginalCalibrator,
    ]

    ces, ce_stddevs, mse_values = vary_bin_calibration(
        test_logits,
        test_labels,
        calibration_size,
        bin_counts,
        Calibrators=compared_calibrators,
        eval_calibration=eval_marginal_calibration,
        eval_mse=eval_marginal_mse,
        num_trials=trial_count,
        resample=True,
    )

    # Fixed axis limits chosen by the original author for this experiment.
    plot_mse_ce_curve(
        bin_counts, ces, mse_values, xlim=(0.0, 0.001), ylim=(0.0, 0.0075))
    plot_ces(bin_counts, ces, ce_stddevs)
예제 #2
0
def cifar10_experiment_marginal_3_1_1000():
    """Run the marginal-calibration bin sweep on CIFAR logits and plot it.

    Loads logits and labels from ``cifar_logits.dat``, builds a data sampler
    with ``make_calibration_eval_data_sampler`` (sizes 1000 and 1000), and
    runs ``vary_bin_calibration`` for the histogram and Platt-binner marginal
    calibrators over the bin counts 10, 20, ..., 100 with 100 trials. The
    results are passed to ``plot_mse_ce_curve`` and ``plot_ces``.
    """
    test_logits, test_labels = utils.load_test_logits_labels(
        'cifar_logits.dat')
    bin_counts = list(range(10, 101, 10))
    trial_count = 100
    calib_size = 1000
    eval_size = 1000

    sampler = make_calibration_eval_data_sampler(
        test_logits, test_labels, calib_size, eval_size)
    compared_calibrators = [
        calibrators.HistogramMarginalCalibrator,
        calibrators.PlattBinnerMarginalCalibrator,
    ]
    # Two evaluators for two calibrators -- presumably paired by position;
    # confirm against vary_bin_calibration.
    evaluators = [
        upper_bound_marginal_calibration_biased,
        upper_bound_marginal_calibration_unbiased,
    ]

    ces, ce_stddevs, mse_values = vary_bin_calibration(
        data_sampler=sampler,
        num_bins_list=bin_counts,
        Calibrators=compared_calibrators,
        calibration_evaluators=evaluators,
        eval_mse=utils.eval_marginal_mse,
        num_trials=trial_count,
    )

    # Fixed axis limits chosen by the original author for this experiment.
    plot_mse_ce_curve(
        bin_counts, ces, mse_values, xlim=(0.0, 0.002), ylim=(0.0, 0.04))
    plot_ces(bin_counts, ces, ce_stddevs)
def cifar_experiment(savefile, binning_func=utils.get_equal_bins, lp=2):
    """Run ``lower_bound_experiment`` on the CIFAR logits.

    Args:
        savefile: Forwarded to ``lower_bound_experiment`` as ``save_name``.
        binning_func: Forwarded unchanged as ``binning_func``; defaults to
            ``utils.get_equal_bins``.
        lp: Forwarded unchanged as ``lp`` (presumably the order of the
            norm used for the calibration error -- confirm in callee).
    """
    # Seed NumPy's global RNG before anything else so runs are repeatable.
    np.random.seed(0)

    test_logits, test_labels = utils.load_test_logits_labels(
        'cifar_logits.dat')
    calib_size = 1000
    per_bin_size = 1000
    # Powers of two from 2 through 128.
    bin_counts = [2 ** exponent for exponent in range(1, 8)]

    lower_bound_experiment(
        test_logits,
        test_labels,
        calib_size,
        per_bin_size,
        bins_list=bin_counts,
        save_name=savefile,
        binning_func=binning_func,
        lp=lp,
    )
예제 #4
0
def imagenet_experiment_top_1_1_1000():
    """Run the top-label calibration bin sweep on ImageNet logits and plot it.

    Loads logits and labels from ``imagenet_logits.dat``, builds a data
    sampler with ``make_calibration_data_sampler`` (size 1000), and runs
    ``vary_bin_calibration`` for the histogram and Platt-binner top-label
    calibrators over the bin counts 10, 20, ..., 100 with 100 trials. The
    results are passed to ``plot_mse_ce_curve`` and ``plot_ces``.
    """
    test_logits, test_labels = utils.load_test_logits_labels(
        'imagenet_logits.dat')
    bin_counts = list(range(10, 101, 10))
    trial_count = 100
    calibration_size = 1000

    sampler = make_calibration_data_sampler(
        test_logits, test_labels, calibration_size)
    compared_calibrators = [
        calibrators.HistogramTopCalibrator,
        calibrators.PlattBinnerTopCalibrator,
    ]
    # The same evaluator is listed once per calibrator (presumably paired
    # by position -- confirm against vary_bin_calibration).
    evaluators = [eval_top_calibration] * 2

    ces, ce_stddevs, mse_values = vary_bin_calibration(
        data_sampler=sampler,
        num_bins_list=bin_counts,
        Calibrators=compared_calibrators,
        calibration_evaluators=evaluators,
        eval_mse=utils.eval_top_mse,
        num_trials=trial_count,
    )

    plot_mse_ce_curve(bin_counts, ces, mse_values)
    plot_ces(bin_counts, ces, ce_stddevs)
                 color='red',
                 capsize=4,
                 label='plugin')
    plt.errorbar(verification_sizes,
                 accumulated_errors[1],
                 yerr=[error_bars_90[1], error_bars_90[1]],
                 barsabove=True,
                 color='blue',
                 capsize=4,
                 label='debiased')
    plt.ylabel("Mean-Squared-Error")
    plt.xlabel("Number of Samples")
    plt.legend(loc='upper right')
    plt.show()

    plt.ylabel("Number of estimates")
    plt.xlabel("Absolute deviation from ground truth")
    bins = np.linspace(np.min(errors[:, 0, :]), np.max(errors[:, 0, :]), 40)
    plt.hist(errors[0][0], bins, alpha=0.5, label='plugin')
    plt.hist(errors[1][0], bins, alpha=0.5, label='debiased')
    plt.legend(loc='upper right')
    plt.gca().yaxis.set_major_formatter(PercentFormatter(xmax=num_resamples))
    plt.show()


if __name__ == "__main__":
    # Script entry point: parse the command line, load the logits/labels
    # file it names, and run the estimator comparison.
    # `parser` is defined elsewhere in this file (presumably an
    # argparse.ArgumentParser -- confirm).
    args = parser.parse_args()
    logits, labels = utils.load_test_logits_labels(args.logits_file)
    # The sizes and bin count come straight from the parsed arguments.
    compare_estimators(logits, labels, args.platt_data_size,
                       args.bin_data_size, args.num_bins)