Code example #1
def evaluate(function, calibrator, dist, n):
    # Draw n inputs from dist, get the true probabilities from `function` and the
    # calibrator's estimates, bin by the calibrator's discrete output values, and
    # return the squared plugin calibration error.
    zs = dist(size=n)
    ps = function(zs)
    phats = calibrator.calibrate(zs)
    bins = utils.get_discrete_bins(phats)
    data = list(zip(phats, ps))
    binned_data = utils.bin(data, bins)
    return utils.plugin_ce(binned_data)**2
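A minimal usage sketch for `evaluate`, assuming the same `utils` module used above is importable and that the calibrator exposes a `calibrate` method as in the snippet; the uniform sampling distribution, the ground-truth probability function, and the rounding calibrator below are hypothetical stand-ins, not part of the library.

import numpy as np

class RoundingCalibrator:
    # Hypothetical calibrator: rounds to one decimal place so the calibrated
    # values take finitely many distinct values, as utils.get_discrete_bins expects.
    def calibrate(self, zs):
        return np.round(zs, 1)

def true_probs(zs):
    # Hypothetical ground-truth probability function.
    return np.clip(zs + 0.05, 0.0, 1.0)

def uniform(size):
    return np.random.uniform(0.0, 1.0, size=size)

squared_ce = evaluate(true_probs, RoundingCalibrator(), uniform, n=10000)
print(np.sqrt(squared_ce))  # plugin l2 calibration error of the rounding calibrator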
Code example #2
def eval_top_calibration(probs, logits, labels, plugin=True):
    # Top-label calibration: compare the top-class confidences in `probs`
    # against whether the top prediction was correct.
    correct = (utils.get_top_predictions(logits) == labels)
    data = list(zip(probs, correct))
    bins = utils.get_discrete_bins(probs)
    binned_data = utils.bin(data, bins)
    if plugin:
        return utils.plugin_ce(binned_data)**2
    else:
        return utils.improved_unbiased_square_ce(binned_data)
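A hedged usage sketch for `eval_top_calibration`: the confidences passed in must take finitely many distinct values, since `utils.get_discrete_bins` builds one bin per distinct value. The synthetic scores and labels below, and the rounding used to discretize the confidences, are illustrative assumptions.

import numpy as np

rng = np.random.default_rng(0)
scores = rng.normal(size=(1000, 10))                      # synthetic 10-class scores
probs_full = np.exp(scores) / np.exp(scores).sum(axis=1, keepdims=True)  # softmax
labels = rng.integers(0, 10, size=1000)                   # synthetic labels
top_probs = np.round(utils.get_top_probs(probs_full), 1)  # discretized top confidences

plugin_sq = eval_top_calibration(top_probs, probs_full, labels, plugin=True)
debiased_sq = eval_top_calibration(top_probs, probs_full, labels, plugin=False)
print(np.sqrt(plugin_sq), debiased_sq)  # the debiased estimate can be slightly negative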
Code example #3
def eval_marginal_calibration(probs, logits, labels, plugin=True):
    ces = []  # Compute the calibration error per class, then take the average.
    k = logits.shape[1]
    labels_one_hot = utils.get_labels_one_hot(np.array(labels), k)
    for c in range(k):
        probs_c = probs[:, c]
        labels_c = labels_one_hot[:, c]
        data_c = list(zip(probs_c, labels_c))
        bins_c = utils.get_discrete_bins(probs_c)
        binned_data_c = utils.bin(data_c, bins_c)
        ce_c = utils.plugin_ce(binned_data_c)**2
        ces.append(ce_c)
    return np.mean(ces)
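The value returned by `eval_marginal_calibration` is the average over classes of the squared per-class plugin error, so its square root is an l2-style marginal calibration error. A hypothetical call, reusing the synthetic probs_full and labels from the sketch above and again rounding so every class's probabilities take finitely many values:

marginal_sq = eval_marginal_calibration(np.round(probs_full, 1), probs_full, labels)
print(np.sqrt(marginal_sq))  # marginal l2 calibration error across the 10 classes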
Code example #4
def estimator(data):
    # Nested helper: `bins` and `lp` come from the enclosing scope where this
    # function is defined.
    binned_data = utils.bin(data, bins)
    return utils.plugin_ce(binned_data, power=lp)
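Code example #4 is a nested helper whose free variables are resolved from its enclosing scope. A standalone sketch of the same idea, with `bins` and `lp` passed in explicitly (the factory name is hypothetical):

def make_lp_plugin_estimator(bins, lp):
    # Returns an estimator that bins (probability, label) pairs using the given
    # bin edges and reports the plugin lp calibration error.
    def estimator(data):
        binned_data = utils.bin(data, bins)
        return utils.plugin_ce(binned_data, power=lp)
    return estimator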
Code example #5
def mse(data):
    # Nested helper: `get_mse_ce` is defined in the enclosing scope (see code
    # example #7); index [0] selects the mean-squared-error component.
    logits, labels = zip(*data)
    return get_mse_ce(logits, labels,
                      lambda x: utils.plugin_ce(x)**2)[0]
Code example #6
def plugin_ce_squared(data):
    # Nested helper: `get_mse_ce` is defined in the enclosing scope (see code
    # example #7); index [1] selects the squared plugin calibration error.
    logits, labels = zip(*data)
    return get_mse_ce(logits, labels,
                      lambda x: utils.plugin_ce(x)**2)[1]
Code example #7
def calibrate_marginals_experiment(logits, labels, k):
    num_calib = 3000
    num_bin = 3000
    num_cert = 4000
    assert (logits.shape[0] == num_calib + num_bin + num_cert)
    num_bins = 100
    bootstrap_samples = 100
    # TODO: consider splitting by label first to ensure equal class counts.
    labels = utils.get_labels_one_hot(labels[:, 0], k)
    mse = np.mean(np.square(labels - logits))
    print('original mse is ', mse)
    calib_logits = logits[:num_calib, :]
    calib_labels = labels[:num_calib, :]
    bin_logits = logits[num_calib:num_calib + num_bin, :]
    bin_labels = labels[num_calib:num_calib + num_bin, :]
    cert_logits = logits[num_calib + num_bin:, :]
    cert_labels = labels[num_calib + num_bin:, :]
    mses = []
    unbiased_ces = []
    biased_ces = []
    std_unbiased_ces = []
    std_biased_ces = []
    for num_bins in range(10, 21, 10):
        # Train a platt scaler and binner.
        platts = []
        platt_binners_equal_points = []
        for l in range(k):
            platt_l = utils.get_platt_scaler(calib_logits[:, l],
                                             calib_labels[:, l])
            platts.append(platt_l)
            cal_logits_l = platt_l(calib_logits[:, l])
            # Alternative binning schemes considered:
            # bins_l = utils.get_equal_bins(cal_logits_l, num_bins=num_bins)
            # bins_l = utils.get_equal_prob_bins(num_bins=num_bins)
            bins_l = [0.0012, 0.05, 0.01, 0.95, 0.985, 1.0]
            cal_bin_logits_l = platt_l(bin_logits[:, l])
            platt_binner_l = utils.get_discrete_calibrator(
                cal_bin_logits_l, bins_l)
            platt_binners_equal_points.append(platt_binner_l)

        # Helper: recalibrate each class with its Platt scaler and binner, then
        # return the mean per-class MSE and the mean per-class value of ce_est.
        def get_mse_ce(logits, labels, ce_est):
            mses = []
            ces = []
            logits = np.array(logits)
            labels = np.array(labels)
            for l in range(k):
                cal_logits_l = platt_binners_equal_points[l](platts[l](
                    logits[:, l]))
                data = list(zip(cal_logits_l, labels[:, l]))
                bins_l = utils.get_discrete_bins(cal_logits_l)
                binned_data = utils.bin(data, bins_l)
                ces.append(ce_est(binned_data))
                mses.append(
                    np.mean([(prob - label)**2 for prob, label in data]))
            return np.mean(mses), np.mean(ces)

        def plugin_ce_squared(data):
            logits, labels = zip(*data)
            return get_mse_ce(logits, labels,
                              lambda x: utils.plugin_ce(x)**2)[1]

        def mse(data):
            logits, labels = zip(*data)
            return get_mse_ce(logits, labels,
                              lambda x: utils.plugin_ce(x)**2)[0]

        def unbiased_ce_squared(data):
            logits, labels = zip(*data)
            return get_mse_ce(logits, labels,
                              utils.improved_unbiased_square_ce)[1]

        mse, improved_unbiased_ce = get_mse_ce(
            cert_logits, cert_labels, utils.improved_unbiased_square_ce)
        mse, biased_ce = get_mse_ce(cert_logits, cert_labels,
                                    lambda x: utils.plugin_ce(x)**2)
        mses.append(mse)
        unbiased_ces.append(improved_unbiased_ce)
        biased_ces.append(biased_ce)
        print('biased ce: ', np.sqrt(biased_ce))
        print('mse: ', mse)
        print('improved ce: ', np.sqrt(improved_unbiased_ce))
        data = list(zip(list(cert_logits), list(cert_labels)))
        std_biased_ces.append(
            utils.bootstrap_std(data,
                                plugin_ce_squared,
                                num_samples=bootstrap_samples))
        std_unbiased_ces.append(
            utils.bootstrap_std(data,
                                unbiased_ce_squared,
                                num_samples=bootstrap_samples))

    std_multiplier = 1.3  # For one sided 90% confidence interval.
    upper_unbiased_ces = list(
        map(lambda p: np.sqrt(p[0] + std_multiplier * p[1]),
            zip(unbiased_ces, std_unbiased_ces)))
    upper_biased_ces = list(
        map(lambda p: np.sqrt(p[0] + std_multiplier * p[1]),
            zip(biased_ces, std_biased_ces)))

    # Get points on the Pareto curve, and plot them.
    def get_pareto_points(data):
        pareto_points = []

        def dominated(p1, p2):
            return p1[0] >= p2[0] and p1[1] >= p2[1]

        for datum in data:
            # dominated(datum, x) is true when x is at least as good as datum in
            # both coordinates; the self-comparison always counts, so a total of
            # exactly 1 means no other point dominates datum.
            num_dominated = sum(map(lambda x: dominated(datum, x), data))
            if num_dominated == 1:
                pareto_points.append(datum)
        return pareto_points

    print(
        get_pareto_points(
            list(zip(upper_unbiased_ces, mses, list(range(5, 101, 5))))))
    print(
        get_pareto_points(
            list(zip(upper_biased_ces, mses, list(range(5, 101, 5))))))
    plot_unbiased_ces, plot_unbiased_mses = zip(
        *get_pareto_points(list(zip(upper_unbiased_ces, mses))))
    plot_biased_ces, plot_biased_mses = zip(
        *get_pareto_points(list(zip(upper_biased_ces, mses))))
    plt.title("MSE vs Calibration Error")
    plt.scatter(plot_unbiased_ces,
                plot_unbiased_mses,
                c='red',
                marker='o',
                label='Ours')
    plt.scatter(plot_biased_ces,
                plot_biased_mses,
                c='blue',
                marker='s',
                label='Plugin')
    plt.legend(loc='upper left')
    plt.ylim(0.0, 0.013)
    plt.xlabel("L2 Calibration Error")
    plt.ylabel("Mean-Squared Error")
    plt.show()
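The upper confidence bounds near the end of code example #7 add a multiple of the bootstrap standard deviation to each squared-error estimate before taking the square root. A small sketch of that step in isolation (the helper name is hypothetical; 1.3 approximates the one-sided 90% multiplier used above):

import numpy as np

def one_sided_upper_bound(squared_ce, bootstrap_std, std_multiplier=1.3):
    # Approximate one-sided 90% upper confidence bound on the l2 calibration
    # error: inflate the squared estimate by std_multiplier bootstrap standard
    # deviations, then take the square root.
    return np.sqrt(squared_ce + std_multiplier * bootstrap_std)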
Code example #8
def compare_estimators(logits,
                       labels,
                       platt_data_size,
                       bin_data_size,
                       num_bins,
                       ver_base_size=2000,
                       ver_size_increment=1000,
                       num_resamples=100,
                       save_name='cmp_est'):
    # Convert logits to prediction, probs.
    predictions = utils.get_top_predictions(logits)
    probs = utils.get_top_probs(logits)
    correct = (predictions == labels)
    # Platt scale on first chunk of data
    platt = utils.get_platt_scaler(probs[:platt_data_size],
                                   correct[:platt_data_size])
    platt_probs = platt(probs)
    estimator_names = ['biased', 'unbiased']
    estimators = [
        lambda x: utils.plugin_ce(x)**2, utils.improved_unbiased_square_ce
    ]

    bins = utils.get_equal_bins(platt_probs[:platt_data_size + bin_data_size],
                                num_bins=num_bins)
    binner = utils.get_discrete_calibrator(
        platt_probs[platt_data_size:platt_data_size + bin_data_size], bins)
    verification_probs = binner(platt_probs[platt_data_size + bin_data_size:])
    verification_correct = correct[platt_data_size + bin_data_size:]
    verification_data = list(zip(verification_probs, verification_correct))
    verification_sizes = list(
        range(ver_base_size,
              len(verification_probs) + 1, ver_size_increment))
    # We want to compare the two estimators while varying the number of samples.
    # A single point of comparison does not tell us much, so we resample from the
    # test set many times and run the estimators on the resamples. Storing these
    # values gives us a sense of the range of values each estimator might output.
    # estimates[i][j][k] stores the estimate from estimator i, with
    # verification_sizes[j] samples, in the k-th resampling.
    estimates = np.zeros(
        (len(estimators), len(verification_sizes), num_resamples))
    # We also store the certified estimates. These represent the upper bounds we
    # get using each estimator, computed from the bootstrap estimate of the
    # estimator's standard deviation.
    cert_estimates = np.zeros(
        (len(estimators), len(verification_sizes), num_resamples))
    for ver_idx, verification_size in enumerate(verification_sizes):
        for k in range(num_resamples):
            # Resample
            indices = np.random.choice(list(range(len(verification_data))),
                                       size=verification_size,
                                       replace=True)
            cur_verification_data = [verification_data[i] for i in indices]
            cur_verification_probs = [verification_probs[i] for i in indices]
            bins = utils.get_discrete_bins(cur_verification_probs)
            # Compute estimates for each estimator.
            for i in range(len(estimators)):

                def estimator(data):
                    binned_data = utils.bin(data, bins)
                    return estimators[i](binned_data)

                cur_estimate = estimator(cur_verification_data)
                estimates[i][ver_idx][k] = cur_estimate
                # cert_resampling_estimates[j].append(
                # 	cur_estimate + utils.bootstrap_std(cur_verification_data, estimator, num_samples=20))

    estimates = np.sort(estimates, axis=-1)
    lower_bound = int(0.1 * num_resamples)
    median = int(0.5 * num_resamples)
    upper_bound = int(0.9 * num_resamples)
    lower_estimates = estimates[:, :, lower_bound]
    upper_estimates = estimates[:, :, upper_bound]
    median_estimates = estimates[:, :, median]

    # We can also measure each estimator's error against the plugin estimate
    # computed on the full verification set.
    bins = utils.get_discrete_bins(verification_probs)
    true_calibration = utils.plugin_ce(utils.bin(verification_data, bins))**2
    print(true_calibration)
    print(np.sqrt(np.mean(estimates[1, -1, :])))
    errors = np.abs(estimates - true_calibration)
    accumulated_errors = np.mean(errors, axis=-1)
    error_bars_90 = 1.645 * np.std(errors, axis=-1) / np.sqrt(num_resamples)
    print(accumulated_errors)
    plt.errorbar(verification_sizes,
                 accumulated_errors[0],
                 yerr=[error_bars_90[0], error_bars_90[0]],
                 barsabove=True,
                 color='red',
                 capsize=4,
                 label='plugin')
    plt.errorbar(verification_sizes,
                 accumulated_errors[1],
                 yerr=[error_bars_90[1], error_bars_90[1]],
                 barsabove=True,
                 color='blue',
                 capsize=4,
                 label='debiased')
    plt.ylabel("Mean-Squared-Error")
    plt.xlabel("Number of Samples")
    plt.legend(loc='upper right')
    plt.show()

    plt.ylabel("Number of estimates")
    plt.xlabel("Absolute deviation from ground truth")
    bins = np.linspace(np.min(errors[:, 0, :]), np.max(errors[:, 0, :]), 40)
    plt.hist(errors[0][0], bins, alpha=0.5, label='plugin')
    plt.hist(errors[1][0], bins, alpha=0.5, label='debiased')
    plt.legend(loc='upper right')
    plt.gca().yaxis.set_major_formatter(PercentFormatter(xmax=num_resamples))
    plt.show()
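A hedged invocation sketch for `compare_estimators`: the synthetic scores, labels, and split sizes below are placeholders (in practice the logits would come from a real model on held-out data), and the call assumes matplotlib is set up as in the surrounding module, since the function plots its results.

import numpy as np

rng = np.random.default_rng(0)
n, k = 9000, 10
logits = rng.normal(size=(n, k))     # placeholder scores; use real model outputs
labels = rng.integers(0, k, size=n)  # placeholder labels
compare_estimators(logits, labels,
                   platt_data_size=2000,
                   bin_data_size=2000,
                   num_bins=15,
                   num_resamples=50)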
Code example #9
def eval_top_calibration(probs, logits, labels):
    # Plugin-only variant of the function in code example #2.
    correct = (utils.get_top_predictions(logits) == labels)
    data = list(zip(probs, correct))
    bins = utils.get_discrete_bins(probs)
    binned_data = utils.bin(data, bins)
    return utils.plugin_ce(binned_data)**2