Example #1
0
    def test_brier_decompositions(self):
        """Checks batched Brier decompositions against per-element results.

        `brier_decompositions` is evaluated once on a (2, 3) batch of
        99-sample, 9-class problems. Each batch element must match
        `brier_decomposition` computed on that element alone.
        """
        batch_shape = (2, 3)
        num_samples, num_classes = 99, 9
        logits = tf.random.uniform(batch_shape + (num_samples, num_classes))
        dist = tfp.distributions.Categorical(logits=logits)
        labels = dist.sample().numpy()
        # A Categorical built from `logits` does not necessarily populate its
        # `probs` attribute (recent TFP returns the constructor argument,
        # i.e. None). Derive the probabilities directly so the test does not
        # depend on the TFP version.
        probs = tf.nn.softmax(logits).numpy()

        all_decomps = metrics_lib.brier_decompositions(labels, probs)
        # One trailing axis of 3 components per batch element.
        self.assertEqual(all_decomps.shape, batch_shape + (3, ))
        for indices in np.ndindex(*batch_shape):
            # The unbatched API takes logits, not probabilities.
            decomp_i = metrics_lib.brier_decomposition(labels[indices],
                                                       logits[indices])
            decomp_i = tf.stack(decomp_i, axis=-1).numpy()
            self.assertAllClose(decomp_i, all_decomps[indices])
Example #2
0
def _latest_temperature(prediction_path):
  """Returns the temperature stored in the newest temperature_hparam.json.

  Searches `<base>/*/*/temperature_hparam.json`, where `<base>` is three
  directory levels above `prediction_path`. Among the matches, it picks the
  file with the most recent modification time.

  Args:
    prediction_path: path of the prediction records being evaluated.

  Returns:
    The value stored under the 'temperature' key of the chosen JSON file.

  Raises:
    FileNotFoundError: if no temperature_hparam.json file matches the pattern.
  """
  predictions_base_dir = os.path.dirname(
      os.path.dirname(os.path.dirname(prediction_path)))
  json_glob = os.path.join(predictions_base_dir, '*/*/temperature_hparam.json')

  paths = gfile.glob(json_glob)
  if not paths:
    raise FileNotFoundError(
        'No temperature_hparam.json found matching %s' % json_glob)
  # Stat each matched file individually so the mtimes line up one-to-one
  # with `paths`; stat() is documented for a single path, not a glob pattern.
  # NOTE(review): the `mtime_nsecs` field name is kept from the original code.
  # tf.io.gfile's FileStatistics may call it `mtime_nsec`; confirm against the
  # gfile module this file imports.
  mtimes = [gfile.stat(path).mtime_nsecs for path in paths]
  latest_path = paths[int(np.argmax(mtimes))]
  with gfile.GFile(latest_path) as fh:
    return json.loads(fh.read())['temperature']


def run(prediction_path, validation_path, model_dir_ensemble,
        use_temp_scaling=False):
  """Computes accuracy/calibration metrics for saved predictions.

  Steps:
    1. Loads predicted probabilities and labels. They come from
       `get_ensemble_stats` if `model_dir_ensemble` is set, otherwise from
       the records at `prediction_path`.
    2. Optionally applies temperature scaling.
    3. Computes accuracy (top-1/5/10), accuracy at confidence thresholds,
       NLL, per-example entropy, Brier score and its decomposition, and
       equal-width and quantile-binned ECE.
    4. Writes the metrics as an .npz file in the directory of
       `prediction_path`.

  Args:
    prediction_path: path of the prediction records to evaluate. It also
      determines the output directory and the dataset name in the output
      file name.
    validation_path: path of validation prediction records. Their
      probabilities define the quantile bins used for the q_ece metrics.
    model_dir_ensemble: if truthy, it is passed with `prediction_path` to
      `get_ensemble_stats` to obtain ensemble probabilities and labels.
    use_temp_scaling: if True, logits are divided by the temperature
      returned by `_latest_temperature` before metrics are computed.
  """

  logging.info('Loading predictions...')
  out_file_prefix = 'metrics_'
  if model_dir_ensemble:
    probs, labels = get_ensemble_stats(model_dir_ensemble, prediction_path)
    out_file_prefix = 'metrics_ensemble_'
  else:
    stats = array_utils.load_stats_from_tfrecords(
        prediction_path, max_records=_MAX_PREDICTIONS)
    probs = stats['probs'].astype(np.float32)
    labels = stats['labels'].astype(np.int32)

  # Labels may carry a trailing singleton axis; flatten to one per example.
  if len(labels.shape) > 1:
    labels = np.squeeze(labels, -1)

  probs = metrics_lib.soften_probabilities(probs=probs)
  logits = inverse_softmax(probs)

  if use_temp_scaling:
    temp = _latest_temperature(prediction_path)
    logits /= temp
    probs = tf.nn.softmax(logits).numpy()
    probs = metrics_lib.soften_probabilities(probs=probs)

    out_file_prefix = 'metrics_temp_scaled_'

  pred_labels = probs.argmax(-1)
  prob_true_label = probs[np.arange(len(labels)), labels]

  # confidence vs accuracy
  thresholds = np.linspace(0, 1, 10, endpoint=False)
  accuracies, counts = metrics_lib.compute_accuracies_at_confidences(
      labels, probs, thresholds)

  overall_accuracy = (pred_labels == labels).mean()
  accuracy_top5 = metrics_lib.accuracy_top_k(probs, labels, 5)
  accuracy_top10 = metrics_lib.accuracy_top_k(probs, labels, 10)

  negative_log_likelihood = -np.log(prob_true_label).mean()
  entropy_per_example = tfd.Categorical(probs=probs).entropy().numpy()

  uncertainty, resolution, reliability = metrics_lib.brier_decomposition(
      labels=labels, logits=logits)
  brier = tf.reduce_mean(metrics_lib.brier_scores(labels=labels, logits=logits))

  ece = metrics_lib.expected_calibration_error_multiclass(
      probs, labels, _ECE_BINS)
  ece_top5 = metrics_lib.expected_calibration_error_multiclass(
      probs, labels, _ECE_BINS, top_k=5)
  ece_top10 = metrics_lib.expected_calibration_error_multiclass(
      probs, labels, _ECE_BINS, top_k=10)

  # Quantile bin edges come from the validation predictions, not the
  # evaluated ones.
  validation_stats = array_utils.load_stats_from_tfrecords(
      validation_path, max_records=20000)
  validation_probs = validation_stats['probs'].astype(np.float32)

  bins = metrics_lib.get_quantile_bins(_ECE_BINS, validation_probs)
  q_ece = metrics_lib.expected_calibration_error_multiclass(
      probs, labels, bins)

  bins = metrics_lib.get_quantile_bins(_ECE_BINS, validation_probs, top_k=5)
  q_ece_top5 = metrics_lib.expected_calibration_error_multiclass(
      probs, labels, bins, top_k=5)

  bins = metrics_lib.get_quantile_bins(_ECE_BINS, validation_probs, top_k=10)
  q_ece_top10 = metrics_lib.expected_calibration_error_multiclass(
      probs, labels, bins, top_k=10)

  metrics = {
      'accuracy': overall_accuracy,
      'accuracy_top5': accuracy_top5,
      'accuracy_top10': accuracy_top10,
      'accuracy_at_confidence': accuracies,
      'confidence_thresholds': thresholds,
      'confidence_counts': counts,
      'ece': ece,
      'ece_top5': ece_top5,
      'ece_top10': ece_top10,
      'q_ece': q_ece,
      'q_ece_top5': q_ece_top5,
      'q_ece_top10': q_ece_top10,
      'ece_nbins': _ECE_BINS,
      'entropy_per_example': entropy_per_example,
      'brier_uncertainty': uncertainty.numpy(),
      'brier_resolution': resolution.numpy(),
      'brier_reliability': reliability.numpy(),
      'brier_score': brier.numpy(),
      'true_labels': labels,
      'pred_labels': pred_labels,
      'prob_true_label': prob_true_label,
      'prob_pred_label': probs.max(-1),
      'negative_log_likelihood': negative_log_likelihood,
  }

  save_dir = os.path.dirname(prediction_path)
  split_path = split_prediction_path(prediction_path)
  prediction_file = split_path[-1]
  # Dataset name: underscore-separated fields after the first two, joined by
  # '-'.
  dataset_name = '-'.join(prediction_file.split('_')[2:])

  out_file = out_file_prefix + dataset_name + '.npz'
  array_utils.write_npz(save_dir, out_file, metrics)