Code example #1
    def test_expected_calibration_error_quantile_multiclass(self, top_k):
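        # Expected ECEs for the deliberately miscalibrated (class-flipped)
        # probabilities below, keyed by the top_k test parameter.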
        bad_quantile_eces = {1: .5, 2: .25, None: .2}
        num_samples = int(1e4)
        num_classes = 5
        probabilities, labels = _make_perfectly_calibrated_multiclass(
            num_samples, num_classes)

        bins = metrics_lib.get_quantile_bins(10, probabilities, top_k=top_k)
        good_quantile_ece = metrics_lib.expected_calibration_error_multiclass(
            probabilities, labels, bins, top_k)
        bad_quantile_ece = metrics_lib.expected_calibration_error_multiclass(
            np.fliplr(probabilities), labels, bins, top_k)
        self.assertAllClose(good_quantile_ece, 0, atol=0.05)
        self.assertAllClose(bad_quantile_ece,
                            bad_quantile_eces[top_k],
                            atol=0.05)
Code example #2
  def test_expected_calibration_error(self):
    np.random.seed(1)
    nsamples = 100
    probs = np.linspace(0, 1, nsamples)
    labels = np.random.rand(nsamples) < probs
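    # labels ~ Bernoulli(probs), so probs are well calibrated by construction;
    # halving them should yield a larger calibration error.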
    ece = metrics_lib.expected_calibration_error(probs, labels)
    bad_ece = metrics_lib.expected_calibration_error(probs / 2, labels)

    self.assertBetween(ece, 0, 1)
    self.assertBetween(bad_ece, 0, 1)
    self.assertLess(ece, bad_ece)

    bins = metrics_lib.get_quantile_bins(10, probs)
    quantile_ece = metrics_lib.expected_calibration_error(probs, labels, bins)
    bad_quantile_ece = metrics_lib.expected_calibration_error(
        probs / 2, labels, bins)

    self.assertBetween(quantile_ece, 0, 1)
    self.assertBetween(bad_quantile_ece, 0, 1)
    self.assertLess(quantile_ece, bad_quantile_ece)
Code example #3
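# Note: this excerpt relies on module-level helpers and constants from the
# surrounding codebase (e.g. metrics_lib, array_utils, gfile, tfd, _ECE_BINS).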
def run(prediction_path, validation_path, model_dir_ensemble,
        use_temp_scaling=False):
  """Runs predictions on the given dataset using the specified model."""

  logging.info('Loading predictions...')
  out_file_prefix = 'metrics_'
  if model_dir_ensemble:
    probs, labels = get_ensemble_stats(model_dir_ensemble, prediction_path)
    out_file_prefix = 'metrics_ensemble_'
  else:
    stats = array_utils.load_stats_from_tfrecords(
        prediction_path, max_records=_MAX_PREDICTIONS)
    probs = stats['probs'].astype(np.float32)
    labels = stats['labels'].astype(np.int32)

  if len(labels.shape) > 1:
    labels = np.squeeze(labels, -1)

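  # Soften probabilities away from exact 0/1 so the logit transform is finite.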
  probs = metrics_lib.soften_probabilities(probs=probs)
  logits = inverse_softmax(probs)

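  # Optionally apply temperature scaling, using the most recently written
  # temperature_hparam.json found under the predictions base directory.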
  if use_temp_scaling:
    predictions_base_dir = os.path.dirname(
        os.path.dirname(os.path.dirname(prediction_path)))
    json_dir = os.path.join(predictions_base_dir, '*/*/temperature_hparam.json')

    paths = gfile.glob(json_dir)
    filestats = [gfile.stat(path) for path in paths]
    idx = np.argmax([s.mtime_nsecs for s in filestats])
    temperature_hparam_path = paths[idx]
    with gfile.GFile(temperature_hparam_path) as fh:
      temp = json.loads(fh.read())['temperature']

    logits /= temp
    probs = tf.nn.softmax(logits).numpy()
    probs = metrics_lib.soften_probabilities(probs=probs)

    out_file_prefix = 'metrics_temp_scaled_'

  # confidence vs accuracy
  thresholds = np.linspace(0, 1, 10, endpoint=False)
  accuracies, counts = metrics_lib.compute_accuracies_at_confidences(
      labels, probs, thresholds)

  overall_accuracy = (probs.argmax(-1) == labels).mean()
  accuracy_top5 = metrics_lib.accuracy_top_k(probs, labels, 5)
  accuracy_top10 = metrics_lib.accuracy_top_k(probs, labels, 10)

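  # Negative log-likelihood of the true labels and per-example predictive
  # entropy.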
  probs_correct_class = probs[np.arange(probs.shape[0]), labels]
  negative_log_likelihood = -np.log(probs_correct_class).mean()
  entropy_per_example = tfd.Categorical(probs=probs).entropy().numpy()

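  # Brier score and its uncertainty / resolution / reliability decomposition.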
  uncertainty, resolution, reliability = metrics_lib.brier_decomposition(
      labels=labels, logits=logits)
  brier = tf.reduce_mean(metrics_lib.brier_scores(labels=labels, logits=logits))

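  # Expected calibration error with _ECE_BINS bins, for top-1, top-5 and
  # top-10 predictions.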
  ece = metrics_lib.expected_calibration_error_multiclass(
      probs, labels, _ECE_BINS)
  ece_top5 = metrics_lib.expected_calibration_error_multiclass(
      probs, labels, _ECE_BINS, top_k=5)
  ece_top10 = metrics_lib.expected_calibration_error_multiclass(
      probs, labels, _ECE_BINS, top_k=10)

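  # Quantile-binned ECE: bin edges are estimated on a held-out validation
  # split and then applied to the evaluation probabilities.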
  validation_stats = array_utils.load_stats_from_tfrecords(
      validation_path, max_records=20000)
  validation_probs = validation_stats['probs'].astype(np.float32)

  bins = metrics_lib.get_quantile_bins(_ECE_BINS, validation_probs)
  q_ece = metrics_lib.expected_calibration_error_multiclass(
      probs, labels, bins)

  bins = metrics_lib.get_quantile_bins(_ECE_BINS, validation_probs, top_k=5)
  q_ece_top5 = metrics_lib.expected_calibration_error_multiclass(
      probs, labels, bins, top_k=5)

  bins = metrics_lib.get_quantile_bins(_ECE_BINS, validation_probs, top_k=10)
  q_ece_top10 = metrics_lib.expected_calibration_error_multiclass(
      probs, labels, bins, top_k=10)

  metrics = {
      'accuracy': overall_accuracy,
      'accuracy_top5': accuracy_top5,
      'accuracy_top10': accuracy_top10,
      'accuracy_at_confidence': accuracies,
      'confidence_thresholds': thresholds,
      'confidence_counts': counts,
      'ece': ece,
      'ece_top5': ece_top5,
      'ece_top10': ece_top10,
      'q_ece': q_ece,
      'q_ece_top5': q_ece_top5,
      'q_ece_top10': q_ece_top10,
      'ece_nbins': _ECE_BINS,
      'entropy_per_example': entropy_per_example,
      'brier_uncertainty': uncertainty.numpy(),
      'brier_resolution': resolution.numpy(),
      'brier_reliability': reliability.numpy(),
      'brier_score': brier.numpy(),
      'true_labels': labels,
      'pred_labels': probs.argmax(-1),
      'prob_true_label': probs[np.arange(len(labels)), labels],
      'prob_pred_label': probs.max(-1),
      'negative_log_likelihood': negative_log_likelihood,
  }

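  # Name the output after the dataset encoded in the prediction filename and
  # write all metrics alongside the predictions as an .npz file.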
  save_dir = os.path.dirname(prediction_path)
  split_path = split_prediction_path(prediction_path)
  prediction_file = split_path[-1]
  dataset_name = '-'.join(prediction_file.split('_')[2:])

  out_file = out_file_prefix + dataset_name + '.npz'
  array_utils.write_npz(save_dir, out_file, metrics)