def run(method, architecture, output_dir, test_level):
  """Trains a model and records its predictions on configured datasets.

  Args:
    method: Name of modeling method (vanilla, dropout, svi, ll_svi).
    architecture: Name of DNN architecture (mlp or dropout).
    output_dir: Directory to record the trained model and output stats.
    test_level: Zero indicates no testing. One indicates testing with real
      data. Two is for testing with fake data.
  """
  fake_data = test_level > 1
  gfile.makedirs(output_dir)
  model_opts, data_opts_list = get_experiment_config(
      method, architecture, test_level=test_level, output_dir=output_dir)

  # Separately build dataset[0] with shuffle=True for training.
  dataset_train = data_lib.build_dataset(data_opts_list[0],
                                         fake_data=fake_data)
  dataset_eval = data_lib.build_dataset(data_opts_list[1],
                                        fake_data=fake_data)
  model = models_lib.build_and_train(
      model_opts, dataset_train, dataset_eval, output_dir)
  logging.info('Saving model to output_dir.')
  model.save_weights(output_dir + '/model.ckpt')

  for idx, data_opts in enumerate(data_opts_list):
    dataset = data_lib.build_dataset(data_opts, fake_data=fake_data)
    logging.info('Running predictions for dataset #%d', idx)
    stats = models_lib.make_predictions(model_opts, model, dataset)
    array_utils.write_npz(output_dir, 'stats_%d.npz' % idx, stats)
    del stats['logits_samples']
    array_utils.write_npz(output_dir, 'stats_small_%d.npz' % idx, stats)

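# A minimal sketch of invoking the training entry point above in a smoke
# test; the output path is hypothetical, and test_level=2 selects fake data
# per the docstring:
#
#   run(method='dropout', architecture='mlp',
#       output_dir='/tmp/uq_experiment', test_level=2)
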
def run(dataset_name, model_dir, predictions_per_example, max_examples,
        output_dir, fake_data=False):
  """Runs predictions on the given dataset using the specified model."""
  tf.io.gfile.makedirs(output_dir)
  data_config = image_data_utils.get_data_config(dataset_name)
  dataset = data_lib.build_dataset(data_config, fake_data=fake_data)
  if max_examples:
    dataset = dataset.take(max_examples)

  model = models_lib.load_model(model_dir)
  logging.info('Starting predictions.')
  predictions = experiment_utils.make_predictions(
      model, dataset.batch(_BATCH_SIZE), predictions_per_example)

  logging.info('Done computing predictions; recording results to disk.')
  array_utils.write_npz(output_dir, 'predictions_%s.npz' % dataset_name,
                        predictions)
  del predictions['logits_samples']
  array_utils.write_npz(output_dir, 'predictions_small_%s.npz' % dataset_name,
                        predictions)

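# This runner and the variant below handle max_examples differently: here
# the dataset holds single examples, so `take(max_examples)` counts examples
# and batching happens afterwards; the variant below receives an
# already-batched dataset, so it takes max_examples // _BATCH_SIZE batches.
# A toy illustration of how the ordering changes what `take` counts:
#
#   ds = tf.data.Dataset.range(8)
#   list(ds.take(4).batch(2))  # 2 batches: [0, 1], [2, 3]
#   list(ds.batch(2).take(4))  # 4 batches: [0, 1], [2, 3], [4, 5], [6, 7]
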
def run(model_dir, dataset_name, predictions_per_example, max_examples,
        output_dir, fake_data=False):
  """Runs predictions on the given dataset using the specified model."""
  tf.io.gfile.makedirs(output_dir)
  data_config = data_lib.DataConfig.from_name(dataset_name,
                                              fake_data=fake_data)
  dataset = data_lib.build_dataset(data_config, batch_size=_BATCH_SIZE)
  if max_examples:
    dataset = dataset.take(max_examples // _BATCH_SIZE)

  model = models_lib.load_trained_model(model_dir)
  logging.info('Starting predictions.')
  predictions = models_lib.make_predictions(
      model, dataset, predictions_per_example)
  array_utils.write_npz(output_dir, 'predictions_%s.npz' % dataset_name,
                        predictions)
  del predictions['probs_samples']
  array_utils.write_npz(output_dir, 'predictions_small_%s.npz' % dataset_name,
                        predictions)

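# Hypothetical invocation; the paths and dataset name are placeholders:
#
#   run(model_dir='/tmp/uq_experiment',
#       dataset_name='cifar10',
#       predictions_per_example=8,
#       max_examples=1024,
#       output_dir='/tmp/uq_predictions')
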
def run(prediction_path, validation_path, model_dir_ensemble,
        use_temp_scaling=False):
  """Computes accuracy and calibration metrics from recorded predictions."""
  logging.info('Loading predictions...')
  out_file_prefix = 'metrics_'
  if model_dir_ensemble:
    probs, labels = get_ensemble_stats(model_dir_ensemble, prediction_path)
    out_file_prefix = 'metrics_ensemble_'
  else:
    stats = array_utils.load_stats_from_tfrecords(
        prediction_path, max_records=_MAX_PREDICTIONS)
    probs = stats['probs'].astype(np.float32)
    labels = stats['labels'].astype(np.int32)

  if len(labels.shape) > 1:
    labels = np.squeeze(labels, -1)

  probs = metrics_lib.soften_probabilities(probs=probs)
  logits = inverse_softmax(probs)

  if use_temp_scaling:
    # Rescale logits by the most recently written temperature hyperparameter
    # found under the predictions base directory.
    predictions_base_dir = os.path.dirname(
        os.path.dirname(os.path.dirname(prediction_path)))
    json_dir = os.path.join(predictions_base_dir,
                            '*/*/temperature_hparam.json')
    paths = gfile.glob(json_dir)
    filestats = [gfile.stat(path) for path in paths]
    idx = np.argmax([s.mtime_nsecs for s in filestats])
    temperature_hparam_path = paths[idx]
    with gfile.GFile(temperature_hparam_path) as fh:
      temp = json.loads(fh.read())['temperature']

    logits /= temp
    probs = tf.nn.softmax(logits).numpy()
    probs = metrics_lib.soften_probabilities(probs=probs)
    out_file_prefix = 'metrics_temp_scaled_'

  # Accuracy at each confidence threshold.
  thresholds = np.linspace(0, 1, 10, endpoint=False)
  accuracies, counts = metrics_lib.compute_accuracies_at_confidences(
      labels, probs, thresholds)

  overall_accuracy = (probs.argmax(-1) == labels).mean()
  accuracy_top5 = metrics_lib.accuracy_top_k(probs, labels, 5)
  accuracy_top10 = metrics_lib.accuracy_top_k(probs, labels, 10)

  probs_correct_class = probs[np.arange(probs.shape[0]), labels]
  negative_log_likelihood = -np.log(probs_correct_class).mean()
  entropy_per_example = tfd.Categorical(probs=probs).entropy().numpy()

  uncertainty, resolution, reliability = metrics_lib.brier_decomposition(
      labels=labels, logits=logits)
  brier = tf.reduce_mean(metrics_lib.brier_scores(labels=labels,
                                                  logits=logits))

  ece = metrics_lib.expected_calibration_error_multiclass(
      probs, labels, _ECE_BINS)
  ece_top5 = metrics_lib.expected_calibration_error_multiclass(
      probs, labels, _ECE_BINS, top_k=5)
  ece_top10 = metrics_lib.expected_calibration_error_multiclass(
      probs, labels, _ECE_BINS, top_k=10)

  # Quantile-binned ECE, with bin edges chosen from validation confidences.
  validation_stats = array_utils.load_stats_from_tfrecords(
      validation_path, max_records=20000)
  validation_probs = validation_stats['probs'].astype(np.float32)

  bins = metrics_lib.get_quantile_bins(_ECE_BINS, validation_probs)
  q_ece = metrics_lib.expected_calibration_error_multiclass(
      probs, labels, bins)
  bins = metrics_lib.get_quantile_bins(_ECE_BINS, validation_probs, top_k=5)
  q_ece_top5 = metrics_lib.expected_calibration_error_multiclass(
      probs, labels, bins, top_k=5)
  bins = metrics_lib.get_quantile_bins(_ECE_BINS, validation_probs, top_k=10)
  q_ece_top10 = metrics_lib.expected_calibration_error_multiclass(
      probs, labels, bins, top_k=10)

  metrics = {
      'accuracy': overall_accuracy,
      'accuracy_top5': accuracy_top5,
      'accuracy_top10': accuracy_top10,
      'accuracy_at_confidence': accuracies,
      'confidence_thresholds': thresholds,
      'confidence_counts': counts,
      'ece': ece,
      'ece_top5': ece_top5,
      'ece_top10': ece_top10,
      'q_ece': q_ece,
      'q_ece_top5': q_ece_top5,
      'q_ece_top10': q_ece_top10,
      'ece_nbins': _ECE_BINS,
      'entropy_per_example': entropy_per_example,
      'brier_uncertainty': uncertainty.numpy(),
      'brier_resolution': resolution.numpy(),
      'brier_reliability': reliability.numpy(),
      'brier_score': brier.numpy(),
      'true_labels': labels,
      'pred_labels': probs.argmax(-1),
      'prob_true_label': probs[np.arange(len(labels)), labels],
      'prob_pred_label': probs.max(-1),
      'negative_log_likelihood': negative_log_likelihood,
  }

  save_dir = os.path.dirname(prediction_path)
  split_path = split_prediction_path(prediction_path)
  prediction_file = split_path[-1]
  dataset_name = '-'.join(prediction_file.split('_')[2:])
  out_file = out_file_prefix + dataset_name + '.npz'
  array_utils.write_npz(save_dir, out_file, metrics)

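# `inverse_softmax` is defined elsewhere in this package; a minimal sketch
# consistent with its use above (probabilities are softened away from exact
# 0/1 first, and logits are only recoverable up to a per-row additive
# constant, which softmax ignores):
#
#   def inverse_softmax(probs, eps=1e-12):
#     return np.log(np.clip(probs, eps, 1.0))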