def test_accuracy_top_k(self):
  num_samples = 20
  num_classes = 10
  probs = np.random.rand(num_samples, num_classes)
  probs /= np.expand_dims(probs.sum(axis=1), axis=-1)
  # Sort each row ascending so the k most probable classes are always the
  # last k class indices; with labels tiled over 0..9, exactly k of every
  # 10 labels land in the top k.
  probs = np.apply_along_axis(sorted, 1, probs)
  labels = np.tile(np.arange(num_classes), 2)

  top_2_accuracy = metrics_lib.accuracy_top_k(probs, labels, 2)
  top_5_accuracy = metrics_lib.accuracy_top_k(probs, labels, 5)

  self.assertEqual(top_2_accuracy, .2)
  self.assertEqual(top_5_accuracy, .5)
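
# For reference, the semantics the test above assumes for
# metrics_lib.accuracy_top_k can be expressed directly in NumPy. This is an
# illustrative sketch of that assumed behavior (fraction of examples whose
# true label is among the k highest-probability classes), not the library's
# actual implementation.
def _reference_accuracy_top_k(probs, labels, k):
  """Fraction of examples whose true label is among the k largest probs."""
  top_k_classes = np.argsort(probs, axis=-1)[:, -k:]
  return np.mean([label in row for label, row in zip(labels, top_k_classes)])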
def run(prediction_path, validation_path, model_dir_ensemble,
        use_temp_scaling=False):
  """Loads stored predictions and computes accuracy and calibration metrics."""
  logging.info('Loading predictions...')
  out_file_prefix = 'metrics_'
  if model_dir_ensemble:
    probs, labels = get_ensemble_stats(model_dir_ensemble, prediction_path)
    out_file_prefix = 'metrics_ensemble_'
  else:
    stats = array_utils.load_stats_from_tfrecords(
        prediction_path, max_records=_MAX_PREDICTIONS)
    probs = stats['probs'].astype(np.float32)
    labels = stats['labels'].astype(np.int32)

  if len(labels.shape) > 1:
    labels = np.squeeze(labels, -1)

  probs = metrics_lib.soften_probabilities(probs=probs)
  logits = inverse_softmax(probs)

  if use_temp_scaling:
    # Use the most recently written temperature hyperparameter found under
    # the predictions base directory.
    predictions_base_dir = os.path.dirname(
        os.path.dirname(os.path.dirname(prediction_path)))
    json_dir = os.path.join(predictions_base_dir,
                            '*/*/temperature_hparam.json')
    paths = gfile.glob(json_dir)
    filestats = [gfile.stat(path) for path in paths]
    idx = np.argmax([s.mtime_nsecs for s in filestats])
    temperature_hparam_path = paths[idx]
    with gfile.GFile(temperature_hparam_path) as fh:
      temp = json.loads(fh.read())['temperature']

    logits /= temp
    probs = tf.nn.softmax(logits).numpy()
    probs = metrics_lib.soften_probabilities(probs=probs)

    out_file_prefix = 'metrics_temp_scaled_'

  # Confidence vs. accuracy.
  thresholds = np.linspace(0, 1, 10, endpoint=False)
  accuracies, counts = metrics_lib.compute_accuracies_at_confidences(
      labels, probs, thresholds)

  overall_accuracy = (probs.argmax(-1) == labels).mean()
  accuracy_top5 = metrics_lib.accuracy_top_k(probs, labels, 5)
  accuracy_top10 = metrics_lib.accuracy_top_k(probs, labels, 10)

  probs_correct_class = probs[np.arange(probs.shape[0]), labels]
  negative_log_likelihood = -np.log(probs_correct_class).mean()
  entropy_per_example = tfd.Categorical(probs=probs).entropy().numpy()

  uncertainty, resolution, reliability = metrics_lib.brier_decomposition(
      labels=labels, logits=logits)
  brier = tf.reduce_mean(
      metrics_lib.brier_scores(labels=labels, logits=logits))

  # Expected calibration error with equal-width confidence bins.
  ece = metrics_lib.expected_calibration_error_multiclass(
      probs, labels, _ECE_BINS)
  ece_top5 = metrics_lib.expected_calibration_error_multiclass(
      probs, labels, _ECE_BINS, top_k=5)
  ece_top10 = metrics_lib.expected_calibration_error_multiclass(
      probs, labels, _ECE_BINS, top_k=10)

  # Expected calibration error with quantile bins derived from the
  # validation-set confidences.
  validation_stats = array_utils.load_stats_from_tfrecords(
      validation_path, max_records=20000)
  validation_probs = validation_stats['probs'].astype(np.float32)

  bins = metrics_lib.get_quantile_bins(_ECE_BINS, validation_probs)
  q_ece = metrics_lib.expected_calibration_error_multiclass(
      probs, labels, bins)
  bins = metrics_lib.get_quantile_bins(_ECE_BINS, validation_probs, top_k=5)
  q_ece_top5 = metrics_lib.expected_calibration_error_multiclass(
      probs, labels, bins, top_k=5)
  bins = metrics_lib.get_quantile_bins(_ECE_BINS, validation_probs, top_k=10)
  q_ece_top10 = metrics_lib.expected_calibration_error_multiclass(
      probs, labels, bins, top_k=10)

  metrics = {
      'accuracy': overall_accuracy,
      'accuracy_top5': accuracy_top5,
      'accuracy_top10': accuracy_top10,
      'accuracy_at_confidence': accuracies,
      'confidence_thresholds': thresholds,
      'confidence_counts': counts,
      'ece': ece,
      'ece_top5': ece_top5,
      'ece_top10': ece_top10,
      'q_ece': q_ece,
      'q_ece_top5': q_ece_top5,
      'q_ece_top10': q_ece_top10,
      'ece_nbins': _ECE_BINS,
      'entropy_per_example': entropy_per_example,
      'brier_uncertainty': uncertainty.numpy(),
      'brier_resolution': resolution.numpy(),
      'brier_reliability': reliability.numpy(),
      'brier_score': brier.numpy(),
      'true_labels': labels,
      'pred_labels': probs.argmax(-1),
      'prob_true_label': probs[np.arange(len(labels)), labels],
      'prob_pred_label': probs.max(-1),
      'negative_log_likelihood': negative_log_likelihood,
  }

  save_dir = os.path.dirname(prediction_path)
  split_path = split_prediction_path(prediction_path)
  prediction_file = split_path[-1]
  dataset_name = '-'.join(prediction_file.split('_')[2:])

  out_file = out_file_prefix + dataset_name + '.npz'
  array_utils.write_npz(save_dir, out_file, metrics)
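
# The inverse_softmax helper called in run() is not defined in this excerpt.
# A minimal sketch, assuming probs have already been pushed away from zero by
# metrics_lib.soften_probabilities, is an elementwise log: softmax is
# invariant to a per-row additive shift, so log-probabilities act as logits
# for both the temperature scaling and the Brier computations above. This is
# an illustrative stand-in, not the repo's actual definition.
def _reference_inverse_softmax(probs):
  """Illustrative logits for probs, recovered up to a per-row constant."""
  return np.log(probs)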