def main(argv):
    """Evaluates retrieval predictions and prints the resulting metrics.

    Loads the solution from `cmd_args.solution_path` and the predictions from
    `cmd_args.predictions_path` (`cmd_args` is defined outside this function),
    then prints, for the public and private splits in turn: mean average
    precision, mean precisions at ranks 1, 5, 10, 50 and 100 (scaled by
    100.0), and the mean/median position of first correct.

    Args:
      argv: Command-line arguments; at most one entry is accepted.

    Raises:
      RuntimeError: If `argv` holds more than one entry.
    """
    if len(argv) > 1:
        raise RuntimeError('Too many command-line arguments.')

    # Banner printed ahead of every metric section.
    separator = '**********************************************'
    # Zero-based indices for P@1, P@5, P@10, P@50 and P@100.
    reported_indices = (0, 4, 9, 49, 99)

    # Load the solution.
    print('Reading solution...')
    solution_parts = dataset_file_io.ReadSolution(
        cmd_args.solution_path, dataset_file_io.RETRIEVAL_TASK_ID)
    public_truth, private_truth, skipped_ids = solution_parts
    print('done!')

    # Load the predictions, split the same way as the solution.
    print('Reading predictions...')
    public_preds, private_preds = dataset_file_io.ReadPredictions(
        cmd_args.predictions_path,
        set(public_truth.keys()),
        set(private_truth.keys()),
        set(skipped_ids),
        dataset_file_io.RETRIEVAL_TASK_ID)
    print('done!')

    # Section 1: mean average precision.
    print(separator)
    print('(Public) Mean Average Precision: %f' %
          metrics.MeanAveragePrecision(public_preds, public_truth))
    print('(Private) Mean Average Precision: %f' %
          metrics.MeanAveragePrecision(private_preds, private_truth))

    # Section 2: mean precision@k. Both splits are computed before either
    # one is printed.
    print(separator)
    public_pk = 100.0 * metrics.MeanPrecisions(public_preds, public_truth)
    private_pk = 100.0 * metrics.MeanPrecisions(private_preds, private_truth)
    print('(Public) Mean precisions: P@1: %.2f, P@5: %.2f, P@10: %.2f, '
          'P@50: %.2f, P@100: %.2f' %
          tuple(public_pk[i] for i in reported_indices))
    print('(Private) Mean precisions: P@1: %.2f, P@5: %.2f, P@10: %.2f, '
          'P@50: %.2f, P@100: %.2f' %
          tuple(private_pk[i] for i in reported_indices))

    # Section 3: mean/median position of first correct.
    print(separator)
    public_position_stats = metrics.MeanMedianPosition(
        public_preds, public_truth)
    public_mean_pos, public_median_pos = public_position_stats
    private_position_stats = metrics.MeanMedianPosition(
        private_preds, private_truth)
    private_mean_pos, private_median_pos = private_position_stats
    print('(Public) Mean position: %.2f, median position: %.2f' %
          (public_mean_pos, public_median_pos))
    print('(Private) Mean position: %.2f, median position: %.2f' %
          (private_mean_pos, private_median_pos))
def testMeanAveragePrecisionWorks(self):
    """Checks `metrics.MeanAveragePrecision` on the retrieval fixtures."""
    # Expected result for the fixture predictions/solution pair.
    expected_mean_ap = 0.458333

    # Run the tested function on freshly created fixtures.
    actual_mean_ap = metrics.MeanAveragePrecision(
        _CreateRetrievalPredictions(), _CreateRetrievalSolution())

    # Compare actual and expected results.
    self.assertAllClose(actual_mean_ap, expected_mean_ap)