def test_expected_calibration_error_quantile_multiclass(self, top_k):
  bad_quantile_eces = {1: .5, 2: .25, None: .2}
  num_samples = int(1e4)
  num_classes = 5
  probabilities, labels = _make_perfectly_calibrated_multiclass(
      num_samples, num_classes)
  bins = metrics_lib.get_quantile_bins(10, probabilities, top_k=top_k)
  good_quantile_ece = metrics_lib.expected_calibration_error_multiclass(
      probabilities, labels, bins, top_k)
  bad_quantile_ece = metrics_lib.expected_calibration_error_multiclass(
      np.fliplr(probabilities), labels, bins, top_k)
  self.assertAllClose(good_quantile_ece, 0, atol=0.05)
  self.assertAllClose(bad_quantile_ece, bad_quantile_eces[top_k], atol=0.05)
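
# For orientation, a minimal sketch of the quantile-binning idea the test
# above relies on: bin edges are empirical quantiles of the confidence scores,
# so each bin receives roughly the same number of examples. Illustrative only,
# not metrics_lib.get_quantile_bins itself; the top_k handling and the edge
# clamping are assumptions, and it relies on the module's `import numpy as np`.
def _quantile_bins_sketch(num_bins, probs, top_k=1):
  probs = np.asarray(probs)
  if probs.ndim == 2:
    # For multiclass inputs, bin the top_k largest per-example probabilities
    # (all classes when top_k is None).
    k = probs.shape[1] if top_k is None else top_k
    confidences = np.sort(probs, axis=-1)[:, -k:].ravel()
  else:
    confidences = probs
  edges = np.quantile(confidences, np.linspace(0., 1., num_bins + 1))
  edges[0], edges[-1] = 0., 1.  # Clamp the outer edges to cover [0, 1].
  return edges
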
def test_expected_calibration_error(self):
  np.random.seed(1)
  nsamples = 100
  probs = np.linspace(0, 1, nsamples)
  labels = np.random.rand(nsamples) < probs
  ece = metrics_lib.expected_calibration_error(probs, labels)
  bad_ece = metrics_lib.expected_calibration_error(probs / 2, labels)
  self.assertBetween(ece, 0, 1)
  self.assertBetween(bad_ece, 0, 1)
  self.assertLess(ece, bad_ece)

  bins = metrics_lib.get_quantile_bins(10, probs)
  quantile_ece = metrics_lib.expected_calibration_error(probs, labels, bins)
  bad_quantile_ece = metrics_lib.expected_calibration_error(
      probs / 2, labels, bins)
  self.assertBetween(quantile_ece, 0, 1)
  self.assertBetween(bad_quantile_ece, 0, 1)
  self.assertLess(quantile_ece, bad_quantile_ece)
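
# A minimal sketch of the binned expected calibration error that the test
# above exercises, for the binary case: ECE is the bin-weight-averaged gap
# between empirical accuracy and mean confidence. Illustrative only, not the
# metrics_lib implementation; the np.digitize bin convention is an assumption.
def _binary_ece_sketch(probs, labels, bin_edges):
  probs = np.asarray(probs, dtype=np.float64)
  labels = np.asarray(labels, dtype=np.float64)
  bin_ids = np.digitize(probs, bin_edges[1:-1])  # Assign each example a bin.
  ece = 0.
  for b in range(len(bin_edges) - 1):
    in_bin = bin_ids == b
    if in_bin.any():
      # Weight each bin's |accuracy - confidence| gap by its example fraction.
      ece += in_bin.mean() * abs(labels[in_bin].mean() - probs[in_bin].mean())
  return ece

# Example: _binary_ece_sketch(probs, labels, np.linspace(0., 1., 11)) uses ten
# equal-width bins; quantile edges from get_quantile_bins drop in the same way.
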
def run(prediction_path, validation_path, model_dir_ensemble,
        use_temp_scaling=False):
  """Loads model predictions and computes and saves calibration metrics."""
  logging.info('Loading predictions...')
  out_file_prefix = 'metrics_'
  if model_dir_ensemble:
    probs, labels = get_ensemble_stats(model_dir_ensemble, prediction_path)
    out_file_prefix = 'metrics_ensemble_'
  else:
    stats = array_utils.load_stats_from_tfrecords(
        prediction_path, max_records=_MAX_PREDICTIONS)
    probs = stats['probs'].astype(np.float32)
    labels = stats['labels'].astype(np.int32)

  if len(labels.shape) > 1:
    labels = np.squeeze(labels, -1)

  probs = metrics_lib.soften_probabilities(probs=probs)
  logits = inverse_softmax(probs)

  if use_temp_scaling:
    predictions_base_dir = os.path.dirname(
        os.path.dirname(os.path.dirname(prediction_path)))
    json_dir = os.path.join(predictions_base_dir,
                            '*/*/temperature_hparam.json')
    paths = gfile.glob(json_dir)
    # Pick the most recently written temperature hyperparameter file.
    filestats = [gfile.stat(path) for path in paths]
    idx = np.argmax([s.mtime_nsecs for s in filestats])
    temperature_hparam_path = paths[idx]
    with gfile.GFile(temperature_hparam_path) as fh:
      temp = json.loads(fh.read())['temperature']

    logits /= temp
    probs = tf.nn.softmax(logits).numpy()
    probs = metrics_lib.soften_probabilities(probs=probs)

    out_file_prefix = 'metrics_temp_scaled_'

  # Accuracy at each confidence threshold.
  thresholds = np.linspace(0, 1, 10, endpoint=False)
  accuracies, counts = metrics_lib.compute_accuracies_at_confidences(
      labels, probs, thresholds)

  overall_accuracy = (probs.argmax(-1) == labels).mean()
  accuracy_top5 = metrics_lib.accuracy_top_k(probs, labels, 5)
  accuracy_top10 = metrics_lib.accuracy_top_k(probs, labels, 10)

  probs_correct_class = probs[np.arange(probs.shape[0]), labels]
  negative_log_likelihood = -np.log(probs_correct_class).mean()
  entropy_per_example = tfd.Categorical(probs=probs).entropy().numpy()

  uncertainty, resolution, reliability = metrics_lib.brier_decomposition(
      labels=labels, logits=logits)
  brier = tf.reduce_mean(metrics_lib.brier_scores(labels=labels,
                                                  logits=logits))

  # ECE with equal-width bins.
  ece = metrics_lib.expected_calibration_error_multiclass(
      probs, labels, _ECE_BINS)
  ece_top5 = metrics_lib.expected_calibration_error_multiclass(
      probs, labels, _ECE_BINS, top_k=5)
  ece_top10 = metrics_lib.expected_calibration_error_multiclass(
      probs, labels, _ECE_BINS, top_k=10)

  # ECE with quantile bins derived from the validation-set probabilities.
  validation_stats = array_utils.load_stats_from_tfrecords(
      validation_path, max_records=20000)
  validation_probs = validation_stats['probs'].astype(np.float32)

  bins = metrics_lib.get_quantile_bins(_ECE_BINS, validation_probs)
  q_ece = metrics_lib.expected_calibration_error_multiclass(
      probs, labels, bins)
  bins = metrics_lib.get_quantile_bins(_ECE_BINS, validation_probs, top_k=5)
  q_ece_top5 = metrics_lib.expected_calibration_error_multiclass(
      probs, labels, bins, top_k=5)
  bins = metrics_lib.get_quantile_bins(_ECE_BINS, validation_probs, top_k=10)
  q_ece_top10 = metrics_lib.expected_calibration_error_multiclass(
      probs, labels, bins, top_k=10)

  metrics = {
      'accuracy': overall_accuracy,
      'accuracy_top5': accuracy_top5,
      'accuracy_top10': accuracy_top10,
      'accuracy_at_confidence': accuracies,
      'confidence_thresholds': thresholds,
      'confidence_counts': counts,
      'ece': ece,
      'ece_top5': ece_top5,
      'ece_top10': ece_top10,
      'q_ece': q_ece,
      'q_ece_top5': q_ece_top5,
      'q_ece_top10': q_ece_top10,
      'ece_nbins': _ECE_BINS,
      'entropy_per_example': entropy_per_example,
      'brier_uncertainty': uncertainty.numpy(),
      'brier_resolution': resolution.numpy(),
      'brier_reliability': reliability.numpy(),
      'brier_score': brier.numpy(),
      'true_labels': labels,
      'pred_labels': probs.argmax(-1),
      'prob_true_label': probs[np.arange(len(labels)), labels],
      'prob_pred_label': probs.max(-1),
      'negative_log_likelihood': negative_log_likelihood,
  }

  save_dir = os.path.dirname(prediction_path)
  split_path = split_prediction_path(prediction_path)
  prediction_file = split_path[-1]
  dataset_name = '-'.join(prediction_file.split('_')[2:])

  out_file = out_file_prefix + dataset_name + '.npz'
  array_utils.write_npz(save_dir, out_file, metrics)
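
# For completeness, a small sketch of reading the saved metrics back. It
# assumes array_utils.write_npz produces a standard .npz archive that
# numpy.load can open; the helper name and the example path are hypothetical.
def load_metrics_npz(path):
  """Loads a metrics .npz file written by run() into a plain dict."""
  with np.load(path, allow_pickle=True) as data:
    return {key: data[key] for key in data.files}

# Example (hypothetical filename following the out_file pattern above):
#   metrics = load_metrics_npz('/tmp/eval/metrics_temp_scaled_cifar10.npz')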