コード例 #1
0
 def train_calibration(self, logits, labels):
     """Fit ``self._calibrator`` from calibration logits and labels.

     The top probability of every example (``utils.get_top_probs``) is
     paired with the result of comparing the top prediction
     (``utils.get_top_predictions``) against its label. Both are passed,
     together with bins built by ``utils.get_equal_bins`` using
     ``self._num_bins``, to ``utils.get_histogram_calibrator``.

     Args:
         logits: Model outputs; at least ``self._num_calibration`` of
             them are required.
         labels: Ground-truth labels compared with ``==`` against the
             top predictions.

     Raises:
         AssertionError: If fewer than ``self._num_calibration`` logits
             are supplied (only checked when assertions are enabled).
     """
     assert len(logits) >= self._num_calibration
     top_probs = utils.get_top_probs(logits)
     is_correct = utils.get_top_predictions(logits) == labels
     equal_bins = utils.get_equal_bins(top_probs, num_bins=self._num_bins)
     self._calibrator = utils.get_histogram_calibrator(
         top_probs, is_correct, equal_bins)
コード例 #2
0
def eval_top_calibration(probs, logits, labels, plugin=True):
    """Estimate the top-label calibration error of ``probs``.

    Each probability is paired with the result of comparing the top
    prediction for its logits against the label, and the pairs are binned
    with the bins returned by ``utils.get_discrete_bins(probs)``.

    Args:
        probs: Probabilities to evaluate, one per example.
        logits: Model outputs used to derive the top predictions.
        labels: Ground-truth labels compared with ``==`` to the predictions.
        plugin: When true, return ``utils.plugin_ce(...)`` squared;
            otherwise return ``utils.improved_unbiased_square_ce(...)``.

    Returns:
        The value produced by the selected estimator on the binned data.
    """
    hits = utils.get_top_predictions(logits) == labels
    pairs = list(zip(probs, hits))
    discrete_bins = utils.get_discrete_bins(probs)
    binned = utils.bin(pairs, discrete_bins)
    if not plugin:
        return utils.improved_unbiased_square_ce(binned)
    return utils.plugin_ce(binned) ** 2
コード例 #3
0
 def train_calibration(self, logits, labels):
     """Fit ``self._platt`` and ``self._discrete_calibrator``.

     A scaler is first obtained from ``utils.get_platt_scaler`` using the
     top probabilities and the result of comparing top predictions with
     the labels. The scaled probabilities are then binned with
     ``utils.get_equal_bins`` (``self._num_bins`` bins) and handed to
     ``utils.get_discrete_calibrator``.

     Args:
         logits: Model outputs; at least ``self._num_calibration`` of
             them are required.
         labels: Ground-truth labels compared with ``==`` against the
             top predictions.

     Raises:
         AssertionError: If fewer than ``self._num_calibration`` logits
             are supplied (only checked when assertions are enabled).
     """
     assert len(logits) >= self._num_calibration
     top_predictions = utils.get_top_predictions(logits)
     top_probs = utils.get_top_probs(logits)
     is_correct = top_predictions == labels
     self._platt = utils.get_platt_scaler(top_probs, is_correct)
     scaled_probs = self._platt(top_probs)
     equal_bins = utils.get_equal_bins(scaled_probs,
                                       num_bins=self._num_bins)
     self._discrete_calibrator = utils.get_discrete_calibrator(
         scaled_probs, equal_bins)
コード例 #4
0
def lower_bound_experiment(logits,
                           labels,
                           calibration_data_size,
                           bin_data_size,
                           bins_list,
                           save_name='cmp_est',
                           binning_func=utils.get_equal_bins,
                           lp=2):
    """Plot calibration-error estimates against the number of bins.

    The data is shuffled, a Platt scaler is fitted on the first
    ``calibration_data_size`` examples, and for every bin count in
    ``bins_list`` the bins are built from the first
    ``calibration_data_size + bin_data_size`` scaled probabilities. The
    remaining examples are used to compute ``utils.plugin_ce`` and a
    bootstrap interval (``utils.bootstrap_uncertainty``), which are drawn as
    error bars and written to ``save_name``.

    Args:
        logits: Model outputs, indexable by integer position.
        labels: Ground-truth labels, indexable by integer position.
        calibration_data_size: Number of (shuffled) examples used to fit the
            Platt scaler.
        bin_data_size: Number of additional examples that, together with the
            calibration examples, are used to choose the bins.
        bins_list: Bin counts to evaluate; also used as x-coordinates.
        save_name: Path passed to ``plt.savefig``.
        binning_func: Callable ``(probs, num_bins=...)`` returning bins.
        lp: Value forwarded as ``power`` to ``utils.plugin_ce``; also selects
            the y-axis label.
    """
    # Shuffle the logits and labels with one shared random permutation.
    # An int population is documented to behave like np.arange(n).
    indices = np.random.choice(len(logits), size=len(logits), replace=False)
    logits = [logits[i] for i in indices]
    labels = [labels[i] for i in indices]
    predictions = utils.get_top_predictions(logits)
    probs = utils.get_top_probs(logits)
    correct = (predictions == labels)
    print('num_correct: ', sum(correct))
    # Platt scale on first chunk of data.
    platt = utils.get_platt_scaler(probs[:calibration_data_size],
                                   correct[:calibration_data_size])
    platt_probs = platt(probs)

    # Everything past the calibration and binning chunks is held out for
    # verification. It does not depend on the bin count, so build it once.
    holdout_start = calibration_data_size + bin_data_size
    verification_data = list(
        zip(platt_probs[holdout_start:], correct[holdout_start:]))

    lower, middle, upper = [], [], []
    for num_bins in bins_list:
        bins = binning_func(platt_probs[:holdout_start], num_bins=num_bins)

        def estimator(data):
            binned_data = utils.bin(data, bins)
            return utils.plugin_ce(binned_data, power=lp)

        print('estimate: ', estimator(verification_data))
        estimate_interval = utils.bootstrap_uncertainty(verification_data,
                                                        estimator,
                                                        num_samples=1000)
        lower.append(estimate_interval[0])
        middle.append(estimate_interval[1])
        upper.append(estimate_interval[2])
        print('interval: ', estimate_interval)

    # Matplotlib expects error-bar lengths relative to the plotted point.
    lower_errors = np.array(middle) - np.array(lower)
    upper_errors = np.array(upper) - np.array(middle)
    plt.clf()
    font = {'family': 'normal', 'size': 18}
    rc('font', **font)
    plt.errorbar(bins_list,
                 middle,
                 yerr=[lower_errors, upper_errors],
                 barsabove=True,
                 fmt='none',
                 color='black',
                 capsize=4)
    plt.scatter(bins_list, middle, color='black')
    plt.xlabel(r"No. of bins")
    if lp == 2:
        plt.ylabel("Calibration error")
    else:
        plt.ylabel("L%d Calibration error" % lp)
    # `basex` was deprecated in Matplotlib 3.3 and removed in 3.5; `base` is
    # the current keyword. Fall back for releases that predate it.
    try:
        plt.xscale('log', base=2)
    except (TypeError, ValueError):
        plt.xscale('log', basex=2)
    ax = plt.gca()
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
    plt.tight_layout()
    plt.savefig(save_name)
コード例 #5
0
def eval_top_mse(probs, logits, labels):
    """Return the mean squared difference between ``probs`` and correctness.

    Correctness is the result of comparing the top prediction for ``logits``
    (``utils.get_top_predictions``) with ``labels`` using ``==``.
    """
    is_correct = utils.get_top_predictions(logits) == labels
    residuals = probs - is_correct
    return np.mean(np.square(residuals))
コード例 #6
0
def compare_estimators(logits,
                       labels,
                       platt_data_size,
                       bin_data_size,
                       num_bins,
                       ver_base_size=2000,
                       ver_size_increment=1000,
                       num_resamples=100,
                       save_name='cmp_est'):
    """Compare the plugin and debiased estimators under resampling.

    A Platt scaler is fitted on the first ``platt_data_size`` examples, the
    scaled probabilities are discretized with ``utils.get_discrete_calibrator``
    and the remaining examples form the verification set. For a range of
    sample sizes the verification set is resampled with replacement
    ``num_resamples`` times and both estimators are evaluated on every
    resample. Two figures are then shown: the mean absolute deviation from the
    full-verification-set plugin value per sample size, and a histogram of the
    deviations at the smallest sample size.

    Args:
        logits: Model outputs.
        labels: Ground-truth labels compared with ``==`` to top predictions.
        platt_data_size: Number of leading examples used to fit the scaler.
        bin_data_size: Number of following examples used for the binner.
        num_bins: Bin count passed to ``utils.get_equal_bins``.
        ver_base_size: Smallest resample size.
        ver_size_increment: Step between successive resample sizes.
        num_resamples: Resamples drawn per sample size; must be at least 1.
        save_name: Currently unused; figures are displayed with
            ``plt.show()`` rather than saved.

    Raises:
        ValueError: If ``num_resamples`` is below 1, or if the verification
            set is smaller than ``ver_base_size`` so that no sample size can
            be evaluated.
    """
    if num_resamples < 1:
        raise ValueError('num_resamples must be at least 1.')

    # Convert logits to prediction, probs.
    predictions = utils.get_top_predictions(logits)
    probs = utils.get_top_probs(logits)
    correct = (predictions == labels)
    # Platt scale on first chunk of data.
    platt = utils.get_platt_scaler(probs[:platt_data_size],
                                   correct[:platt_data_size])
    platt_probs = platt(probs)
    # Index 0 is the plugin ("biased") estimator, index 1 the debiased one.
    estimators = [
        lambda x: utils.plugin_ce(x)**2, utils.improved_unbiased_square_ce
    ]

    holdout_start = platt_data_size + bin_data_size
    equal_bins = utils.get_equal_bins(platt_probs[:holdout_start],
                                      num_bins=num_bins)
    binner = utils.get_discrete_calibrator(
        platt_probs[platt_data_size:holdout_start], equal_bins)
    verification_probs = binner(platt_probs[holdout_start:])
    verification_correct = correct[holdout_start:]
    verification_data = list(zip(verification_probs, verification_correct))
    verification_sizes = list(
        range(ver_base_size,
              len(verification_probs) + 1, ver_size_increment))
    if not verification_sizes:
        # Without this check the failure would only surface later as an
        # IndexError when the largest sample size is inspected.
        raise ValueError(
            'Verification set has %d examples, fewer than ver_base_size=%d.'
            % (len(verification_probs), ver_base_size))

    # We want to compare the two estimators when varying the number of
    # samples. A single point of comparison does not tell us much, so we
    # resample from the verification set many times and run the estimators on
    # each resample. estimates[i][j][k] stores the estimate of estimator i
    # with verification_sizes[j] samples in the k-th resampling.
    estimates = np.zeros(
        (len(estimators), len(verification_sizes), num_resamples))
    for ver_idx, verification_size in enumerate(verification_sizes):
        for k in range(num_resamples):
            # Resample with replacement. An int population is documented to
            # behave like np.arange(n).
            indices = np.random.choice(len(verification_data),
                                       size=verification_size,
                                       replace=True)
            cur_verification_data = [verification_data[i] for i in indices]
            cur_verification_probs = [verification_probs[i] for i in indices]
            resample_bins = utils.get_discrete_bins(cur_verification_probs)
            # Compute estimates for each estimator.
            for i, estimate_fn in enumerate(estimators):
                binned_data = utils.bin(cur_verification_data, resample_bins)
                estimates[i][ver_idx][k] = estimate_fn(binned_data)

    # The summaries below do not depend on the order of resamples; the sort
    # is kept so floating-point reductions match earlier runs exactly.
    estimates = np.sort(estimates, axis=-1)

    # Reference value: plugin estimator on the full verification set.
    full_bins = utils.get_discrete_bins(verification_probs)
    true_calibration = utils.plugin_ce(
        utils.bin(verification_data, full_bins))**2
    print(true_calibration)
    print(np.sqrt(np.mean(estimates[1, -1, :])))
    errors = np.abs(estimates - true_calibration)
    accumulated_errors = np.mean(errors, axis=-1)
    # 1.645 standard errors on each side of the mean.
    error_bars_90 = 1.645 * np.std(errors, axis=-1) / np.sqrt(num_resamples)
    print(accumulated_errors)
    # NOTE(review): the plotted quantity is a mean absolute deviation, while
    # the axis label says "Mean-Squared-Error" -- confirm the intended label.
    plt.errorbar(verification_sizes,
                 accumulated_errors[0],
                 yerr=[error_bars_90[0], error_bars_90[0]],
                 barsabove=True,
                 color='red',
                 capsize=4,
                 label='plugin')
    plt.errorbar(verification_sizes,
                 accumulated_errors[1],
                 yerr=[error_bars_90[1], error_bars_90[1]],
                 barsabove=True,
                 color='blue',
                 capsize=4,
                 label='debiased')
    plt.ylabel("Mean-Squared-Error")
    plt.xlabel("Number of Samples")
    plt.legend(loc='upper right')
    plt.show()

    # Histogram of deviations at the smallest sample size, with shared edges
    # so both estimators are directly comparable.
    plt.ylabel("Number of estimates")
    plt.xlabel("Absolute deviation from ground truth")
    hist_edges = np.linspace(np.min(errors[:, 0, :]),
                             np.max(errors[:, 0, :]), 40)
    plt.hist(errors[0][0], hist_edges, alpha=0.5, label='plugin')
    plt.hist(errors[1][0], hist_edges, alpha=0.5, label='debiased')
    plt.legend(loc='upper right')
    plt.gca().yaxis.set_major_formatter(PercentFormatter(xmax=num_resamples))
    plt.show()
コード例 #7
0
def eval_top_calibration(probs, logits, labels):
    """Return the squared ``utils.plugin_ce`` value for top-label ``probs``.

    Each probability is paired with the result of comparing the top
    prediction for its logits against the label; the pairs are binned with
    the bins from ``utils.get_discrete_bins(probs)`` before the estimator is
    applied.
    """
    hits = utils.get_top_predictions(logits) == labels
    pairs = list(zip(probs, hits))
    discrete_bins = utils.get_discrete_bins(probs)
    plugin_value = utils.plugin_ce(utils.bin(pairs, discrete_bins))
    return plugin_value ** 2
コード例 #8
0
import argparse
import numpy as np

import utils

# Command-line interface: the only option is the file holding the saved
# (logits, labels) pair.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--logits_file',
    type=str,
    default='cifar_logits.dat',
    help='Name of file to load logits, labels pair.',
)

if __name__ == "__main__":
    cli_args = parser.parse_args()
    logits, labels = utils.load_test_logits_labels(cli_args.logits_file)

    # Prediction accuracy: fraction of top predictions equal to the label.
    top_predictions = utils.get_top_predictions(logits)
    top_probs = utils.get_top_probs(logits)
    is_correct = top_predictions == labels
    accuracy = float(sum(is_correct)) / len(labels)
    print('accuracy: ', accuracy)

    # Top-label MSE.
    print('top mse: ', utils.eval_top_mse(top_probs, logits, labels))

    # Marginal MSE.
    # NOTE(review): `logits` is passed as both the first and the second
    # argument -- confirm against utils.eval_marginal_mse that this is
    # intended and that the first argument should not be probabilities.
    print('marginal mse: ', utils.eval_marginal_mse(logits, logits, labels))