示例#1
0
def main(argv):
  """Scores retrieval predictions and prints the metrics for both splits.

  Reads the ground-truth solution from `cmd_args.solution_path` and the
  predictions from `cmd_args.predictions_path` (`cmd_args` is a module-level
  object defined outside this function). It then prints the following, for
  the public split and then the private split:
    * mean average precision,
    * mean precision at 1, 5, 10, 50 and 100 (scaled to percentages),
    * mean and median position of the first correct result.

  Args:
    argv: Positional command-line arguments. Any list with more than one
      element is rejected (presumably argv[0] is the program name, as with
      absl's app.run -- confirm against the entry point).

  Raises:
    RuntimeError: If `argv` has more than one element.
  """
  if len(argv) > 1:
    raise RuntimeError('Too many command-line arguments.')

  stars = '**********************************************'

  # --- Inputs ---------------------------------------------------------------
  print('Reading solution...')
  solution = dataset_file_io.ReadSolution(cmd_args.solution_path,
                                          dataset_file_io.RETRIEVAL_TASK_ID)
  public_solution, private_solution, ignored_ids = solution
  print('done!')

  print('Reading predictions...')
  # ReadPredictions is handed the key sets of both solution splits plus the
  # ignored ids; how it uses them is defined in dataset_file_io.
  public_ids = set(public_solution.keys())
  private_ids = set(private_solution.keys())
  ignored = set(ignored_ids)
  public_predictions, private_predictions = dataset_file_io.ReadPredictions(
      cmd_args.predictions_path, public_ids, private_ids, ignored,
      dataset_file_io.RETRIEVAL_TASK_ID)
  print('done!')

  # --- Mean average precision -----------------------------------------------
  print(stars)
  public_map = metrics.MeanAveragePrecision(public_predictions,
                                            public_solution)
  print('(Public)  Mean Average Precision: %f' % public_map)
  private_map = metrics.MeanAveragePrecision(private_predictions,
                                             private_solution)
  print('(Private) Mean Average Precision: %f' % private_map)

  # --- Mean precision@k -----------------------------------------------------
  print(stars)
  public_p = 100.0 * metrics.MeanPrecisions(public_predictions,
                                            public_solution)
  private_p = 100.0 * metrics.MeanPrecisions(private_predictions,
                                             private_solution)
  # Entry k-1 of each precision sequence is reported as P@k.
  reported = (0, 4, 9, 49, 99)
  print('(Public)  Mean precisions: P@1: %.2f, P@5: %.2f, P@10: %.2f, '
        'P@50: %.2f, P@100: %.2f' % tuple(public_p[i] for i in reported))
  print('(Private) Mean precisions: P@1: %.2f, P@5: %.2f, P@10: %.2f, '
        'P@50: %.2f, P@100: %.2f' % tuple(private_p[i] for i in reported))

  # --- Mean/median position of the first correct result ---------------------
  print(stars)
  pub_mean, pub_median = metrics.MeanMedianPosition(public_predictions,
                                                    public_solution)
  priv_mean, priv_median = metrics.MeanMedianPosition(private_predictions,
                                                      private_solution)
  print('(Public)  Mean position: %.2f, median position: %.2f' %
        (pub_mean, pub_median))
  print('(Private) Mean position: %.2f, median position: %.2f' %
        (priv_mean, priv_median))
    def testMeanAveragePrecisionWorks(self):
        """Checks metrics.MeanAveragePrecision against a known value.

        The predictions and solution come from the file's
        `_CreateRetrievalPredictions` / `_CreateRetrievalSolution` fixture
        helpers (defined outside this view). They are built in that order.
        """
        predictions = _CreateRetrievalPredictions()
        solution = _CreateRetrievalSolution()

        # Expected value is written to 6 decimal places; the comparison uses
        # assertAllClose rather than exact equality.
        self.assertAllClose(
            metrics.MeanAveragePrecision(predictions, solution), 0.458333)