Code Example #1
def run(method, architecture, output_dir, test_level):
  """Trains a model and records its predictions on configured datasets.

  Args:
    method: Name of modeling method (vanilla, dropout, svi, ll_svi).
    architecture: Name of DNN architecture (mlp or dropout).
    output_dir: Directory to record the trained model and output stats.
    test_level: Zero indicates no testing, one indicates testing with real
      data, and two indicates testing with fake data.
  """
  fake_data = test_level > 1
  gfile.makedirs(output_dir)
  model_opts, data_opts_list = get_experiment_config(method, architecture,
                                                     test_level=test_level,
                                                     output_dir=output_dir)

  # Separately build dataset[0] with shuffle=True for training.
  dataset_train = data_lib.build_dataset(data_opts_list[0], fake_data=fake_data)
  dataset_eval = data_lib.build_dataset(data_opts_list[1], fake_data=fake_data)
  model = models_lib.build_and_train(model_opts,
                                     dataset_train, dataset_eval, output_dir)
  logging.info('Saving model to output_dir.')
  model.save_weights(output_dir + '/model.ckpt')

  for idx, data_opts in enumerate(data_opts_list):
    dataset = data_lib.build_dataset(data_opts, fake_data=fake_data)
    logging.info('Running predictions for dataset #%d', idx)
    stats = models_lib.make_predictions(model_opts, model, dataset)
    array_utils.write_npz(output_dir, 'stats_%d.npz' % idx, stats)
    del stats['logits_samples']
    array_utils.write_npz(output_dir, 'stats_small_%d.npz' % idx, stats)
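
A minimal usage sketch for this function; the argument values are assumptions for illustration (the method and architecture names come from the docstring, the output path is a placeholder):

# Hypothetical call: train an MLP with MC dropout, with testing disabled.
run(method='dropout',
    architecture='mlp',
    output_dir='/tmp/uq_experiment',
    test_level=0)
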
Code Example #2
def run(dataset_name,
        model_dir,
        predictions_per_example,
        max_examples,
        output_dir,
        fake_data=False):
    """Runs predictions on the given dataset using the specified model."""
    tf.io.gfile.makedirs(output_dir)
    data_config = image_data_utils.get_data_config(dataset_name)
    dataset = data_lib.build_dataset(data_config, fake_data=fake_data)
    if max_examples:
        dataset = dataset.take(max_examples)

    model = models_lib.load_model(model_dir)
    logging.info('Starting predictions.')
    predictions = experiment_utils.make_predictions(model,
                                                    dataset.batch(_BATCH_SIZE),
                                                    predictions_per_example)

    logging.info('Done computing predictions; recording results to disk.')
    array_utils.write_npz(output_dir, 'predictions_%s.npz' % dataset_name,
                          predictions)
    del predictions['logits_samples']
    array_utils.write_npz(output_dir,
                          'predictions_small_%s.npz' % dataset_name,
                          predictions)
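
A hedged usage sketch for this variant; the dataset name and directories below are illustrative placeholders, not values confirmed by the snippet:

# Hypothetical call: 64 Monte Carlo samples per example, capped at 1000 examples.
run(dataset_name='cifar10',
    model_dir='/tmp/trained_model',
    predictions_per_example=64,
    max_examples=1000,
    output_dir='/tmp/predictions',
    fake_data=False)
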
Code Example #3
def run(model_dir,
        dataset_name,
        predictions_per_example,
        max_examples,
        output_dir,
        fake_data=False):
    """Runs predictions on the given dataset using the specified model."""
    tf.io.gfile.makedirs(output_dir)
    data_config = data_lib.DataConfig.from_name(dataset_name,
                                                fake_data=fake_data)
    dataset = data_lib.build_dataset(data_config, batch_size=_BATCH_SIZE)
    if max_examples:
        dataset = dataset.take(max_examples // _BATCH_SIZE)

    model = models_lib.load_trained_model(model_dir)

    logging.info('Starting predictions.')
    predictions = models_lib.make_predictions(model, dataset,
                                              predictions_per_example)

    array_utils.write_npz(output_dir, 'predictions_%s.npz' % dataset_name,
                          predictions)
    del predictions['probs_samples']
    array_utils.write_npz(output_dir,
                          'predictions_small_%s.npz' % dataset_name,
                          predictions)
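
The main difference from Code Example #2 is that build_dataset here returns an already batched dataset, so take() counts batches rather than individual examples, hence the division by _BATCH_SIZE. A standalone sketch in plain tf.data (not the project's data_lib) illustrating the distinction:

import tensorflow as tf

batch_size = 32
max_examples = 96

# Unbatched pipeline (as in Code Example #2): take() counts examples.
ds_examples = tf.data.Dataset.range(1000).take(max_examples).batch(batch_size)

# Batched pipeline (as in Code Example #3): take() counts batches.
ds_batches = tf.data.Dataset.range(1000).batch(batch_size).take(
    max_examples // batch_size)

# Both pipelines yield 96 examples split into 3 batches of 32.
assert sum(int(b.shape[0]) for b in ds_batches) == max_examples
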
Code Example #4
def run(prediction_path, validation_path, model_dir_ensemble,
        use_temp_scaling=False):
  """Runs predictions on the given dataset using the specified model."""

  logging.info('Loading predictions...')
  out_file_prefix = 'metrics_'
  if model_dir_ensemble:
    probs, labels = get_ensemble_stats(model_dir_ensemble, prediction_path)
    out_file_prefix = 'metrics_ensemble_'
  else:
    stats = array_utils.load_stats_from_tfrecords(
        prediction_path, max_records=_MAX_PREDICTIONS)
    probs = stats['probs'].astype(np.float32)
    labels = stats['labels'].astype(np.int32)

  if len(labels.shape) > 1:
    labels = np.squeeze(labels, -1)

  probs = metrics_lib.soften_probabilities(probs=probs)
  logits = inverse_softmax(probs)

  if use_temp_scaling:
    predictions_base_dir = os.path.dirname(
        os.path.dirname(os.path.dirname(prediction_path)))
    json_dir = os.path.join(predictions_base_dir, '*/*/temperature_hparam.json')

    paths = gfile.glob(json_dir)
    # Pick the most recently written temperature file among the matches.
    filestats = [gfile.stat(path) for path in paths]
    idx = np.argmax([s.mtime_nsecs for s in filestats])
    temperature_hparam_path = paths[idx]
    with gfile.GFile(temperature_hparam_path) as fh:
      temp = json.loads(fh.read())['temperature']

    logits /= temp
    probs = tf.nn.softmax(logits).numpy()
    probs = metrics_lib.soften_probabilities(probs=probs)

    out_file_prefix = 'metrics_temp_scaled_'

  # Accuracy and example counts at each confidence threshold.
  thresholds = np.linspace(0, 1, 10, endpoint=False)
  accuracies, counts = metrics_lib.compute_accuracies_at_confidences(
      labels, probs, thresholds)

  overall_accuracy = (probs.argmax(-1) == labels).mean()
  accuracy_top5 = metrics_lib.accuracy_top_k(probs, labels, 5)
  accuracy_top10 = metrics_lib.accuracy_top_k(probs, labels, 10)

  probs_correct_class = probs[np.arange(probs.shape[0]), labels]
  negative_log_likelihood = -np.log(probs_correct_class).mean()
  entropy_per_example = tfd.Categorical(probs=probs).entropy().numpy()

  uncertainty, resolution, reliability = metrics_lib.brier_decomposition(
      labels=labels, logits=logits)
  brier = tf.reduce_mean(metrics_lib.brier_scores(labels=labels, logits=logits))

  ece = metrics_lib.expected_calibration_error_multiclass(
      probs, labels, _ECE_BINS)
  ece_top5 = metrics_lib.expected_calibration_error_multiclass(
      probs, labels, _ECE_BINS, top_k=5)
  ece_top10 = metrics_lib.expected_calibration_error_multiclass(
      probs, labels, _ECE_BINS, top_k=10)

  validation_stats = array_utils.load_stats_from_tfrecords(
      validation_path, max_records=20000)
  validation_probs = validation_stats['probs'].astype(np.float32)

  bins = metrics_lib.get_quantile_bins(_ECE_BINS, validation_probs)
  q_ece = metrics_lib.expected_calibration_error_multiclass(
      probs, labels, bins)

  bins = metrics_lib.get_quantile_bins(_ECE_BINS, validation_probs, top_k=5)
  q_ece_top5 = metrics_lib.expected_calibration_error_multiclass(
      probs, labels, bins, top_k=5)

  bins = metrics_lib.get_quantile_bins(_ECE_BINS, validation_probs, top_k=10)
  q_ece_top10 = metrics_lib.expected_calibration_error_multiclass(
      probs, labels, bins, top_k=10)

  metrics = {
      'accuracy': overall_accuracy,
      'accuracy_top5': accuracy_top5,
      'accuracy_top10': accuracy_top10,
      'accuracy_at_confidence': accuracies,
      'confidence_thresholds': thresholds,
      'confidence_counts': counts,
      'ece': ece,
      'ece_top5': ece_top5,
      'ece_top10': ece_top10,
      'q_ece': q_ece,
      'q_ece_top5': q_ece_top5,
      'q_ece_top10': q_ece_top10,
      'ece_nbins': _ECE_BINS,
      'entropy_per_example': entropy_per_example,
      'brier_uncertainty': uncertainty.numpy(),
      'brier_resolution': resolution.numpy(),
      'brier_reliability': reliability.numpy(),
      'brier_score': brier.numpy(),
      'true_labels': labels,
      'pred_labels': probs.argmax(-1),
      'prob_true_label': probs[np.arange(len(labels)), labels],
      'prob_pred_label': probs.max(-1),
      'negative_log_likelihood': negative_log_likelihood,
  }

  save_dir = os.path.dirname(prediction_path)
  split_path = split_prediction_path(prediction_path)
  prediction_file = split_path[-1]
  dataset_name = '-'.join(prediction_file.split('_')[2:])

  out_file = out_file_prefix + dataset_name + '.npz'
  array_utils.write_npz(save_dir, out_file, metrics)
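
To make the temperature-scaling step above concrete, here is a small standalone sketch. The helpers used in the example (inverse_softmax, soften_probabilities) are project-internal; the function below is an illustrative stand-in, not the project's implementation:

import numpy as np
import tensorflow as tf

def inverse_softmax_sketch(probs):
  # Recovers logits up to an additive constant per row; any constant shift
  # leaves the softmax unchanged, so log-probabilities suffice.
  return np.log(probs)

probs = np.array([[0.7, 0.2, 0.1],
                  [0.5, 0.3, 0.2]], dtype=np.float32)
logits = inverse_softmax_sketch(probs)

temp = 2.0  # temperature > 1 flattens the distribution, < 1 sharpens it
scaled_probs = tf.nn.softmax(logits / temp).numpy()
print(scaled_probs)  # less confident (closer to uniform) than the original probs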