def testReadRetrievalSolutionWorks(self):
  """Checks ReadSolution on a small retrieval-task solution CSV.

  Writes a six-line CSV (header plus Public, Private and Ignored rows) to the
  test temp dir, parses it with the retrieval task id, and verifies the
  returned public dict, private dict and list of ignored ids.
  """
  # Build the input file contents, one CSV line per list entry.
  csv_lines = [
      'id,images,Usage\n',
      '0123456789abcdef,None,Ignored\n',
      '0223456789abcdef,fedcba9876543210 fedcba9876543200,Public\n',
      '0323456789abcdef,fedcba9876543200,Private\n',
      '0423456789abcdef,fedcba9876543220,Private\n',
      '0523456789abcdef,None,Ignored\n',
  ]
  solution_path = os.path.join(FLAGS.test_tmpdir, 'retrieval_solution.csv')
  with tf.io.gfile.GFile(solution_path, 'w') as solution_file:
    solution_file.write(''.join(csv_lines))

  # Parse the file with the function under test.
  public_solution, private_solution, ignored_ids = dataset_file_io.ReadSolution(
      solution_path, dataset_file_io.RETRIEVAL_TASK_ID)

  # Expected: space-separated image ids become lists keyed by query id;
  # Ignored rows contribute only their ids, in file order.
  expected_results = (
      {
          '0223456789abcdef': ['fedcba9876543210', 'fedcba9876543200'],
      },
      {
          '0323456789abcdef': ['fedcba9876543200'],
          '0423456789abcdef': ['fedcba9876543220'],
      },
      ['0123456789abcdef', '0523456789abcdef'],
  )
  actual_results = (public_solution, private_solution, ignored_ids)
  for actual_value, expected_value in zip(actual_results, expected_results):
    self.assertEqual(actual_value, expected_value)
def testReadRecognitionSolutionWorks(self):
  """Checks ReadSolution on a small recognition-task solution CSV.

  Writes a six-line CSV (header plus Public, Private and Ignored rows,
  including rows with an empty landmarks field) to the test temp dir, parses
  it with the recognition task id, and verifies the returned public dict,
  private dict and list of ignored ids.
  """
  # Build the input file contents, one CSV line per list entry.
  csv_lines = [
      'id,landmarks,Usage\n',
      '0123456789abcdef,0 12,Public\n',
      '0223456789abcdef,,Public\n',
      '0323456789abcdef,100,Ignored\n',
      '0423456789abcdef,1,Private\n',
      '0523456789abcdef,,Ignored\n',
  ]
  solution_path = os.path.join(FLAGS.test_tmpdir, 'recognition_solution.csv')
  with tf.io.gfile.GFile(solution_path, 'w') as solution_file:
    solution_file.write(''.join(csv_lines))

  # Parse the file with the function under test.
  public_solution, private_solution, ignored_ids = dataset_file_io.ReadSolution(
      solution_path, dataset_file_io.RECOGNITION_TASK_ID)

  # Expected: landmark ids are parsed as ints, an empty landmarks field gives
  # an empty list, and Ignored rows contribute only their ids, in file order.
  expected_results = (
      {
          '0123456789abcdef': [0, 12],
          '0223456789abcdef': [],
      },
      {
          '0423456789abcdef': [1],
      },
      ['0323456789abcdef', '0523456789abcdef'],
  )
  actual_results = (public_solution, private_solution, ignored_ids)
  for actual_value, expected_value in zip(actual_results, expected_results):
    self.assertEqual(actual_value, expected_value)
def main(argv):
  """Evaluates retrieval predictions against a solution file and prints metrics.

  Reads the solution from `cmd_args.solution_path` and the predictions from
  `cmd_args.predictions_path` using the retrieval task id. For the Public and
  Private splits separately, it then prints:
    * mean average precision;
    * mean precision@k (as a percentage) for k in {1, 5, 10, 50, 100};
    * mean and median position of the first correct result.

  Args:
    argv: Command-line arguments; anything beyond argv[0] is rejected.

  Raises:
    RuntimeError: If more than one command-line argument is given.
  """
  if len(argv) > 1:
    raise RuntimeError('Too many command-line arguments.')

  separator = '**********************************************'

  print('Reading solution...')
  public_solution, private_solution, ignored_ids = dataset_file_io.ReadSolution(
      cmd_args.solution_path, dataset_file_io.RETRIEVAL_TASK_ID)
  print('done!')

  # ReadPredictions receives the solution's id sets and returns the
  # predictions split into public and private parts.
  print('Reading predictions...')
  public_ids = set(public_solution.keys())
  private_ids = set(private_solution.keys())
  ignored_id_set = set(ignored_ids)
  public_predictions, private_predictions = dataset_file_io.ReadPredictions(
      cmd_args.predictions_path, public_ids, private_ids, ignored_id_set,
      dataset_file_io.RETRIEVAL_TASK_ID)
  print('done!')

  # Mean average precision: each split is computed and printed in turn.
  print(separator)
  public_map = metrics.MeanAveragePrecision(public_predictions,
                                            public_solution)
  print('(Public) Mean Average Precision: %f' % public_map)
  private_map = metrics.MeanAveragePrecision(private_predictions,
                                             private_solution)
  print('(Private) Mean Average Precision: %f' % private_map)

  # Mean precision@k. Entry i holds precision@(i + 1), so offsets
  # 0/4/9/49/99 correspond to P@1/P@5/P@10/P@50/P@100.
  print(separator)
  reported_offsets = (0, 4, 9, 49, 99)
  public_precisions = 100.0 * metrics.MeanPrecisions(public_predictions,
                                                     public_solution)
  private_precisions = 100.0 * metrics.MeanPrecisions(private_predictions,
                                                      private_solution)
  print('(Public) Mean precisions: P@1: %.2f, P@5: %.2f, P@10: %.2f, '
        'P@50: %.2f, P@100: %.2f' %
        tuple(public_precisions[offset] for offset in reported_offsets))
  print('(Private) Mean precisions: P@1: %.2f, P@5: %.2f, P@10: %.2f, '
        'P@50: %.2f, P@100: %.2f' %
        tuple(private_precisions[offset] for offset in reported_offsets))

  # Mean/median position of the first correct result; both splits are
  # computed before either is printed.
  print(separator)
  public_positions = metrics.MeanMedianPosition(public_predictions,
                                                public_solution)
  private_positions = metrics.MeanMedianPosition(private_predictions,
                                                 private_solution)
  public_mean_position, public_median_position = public_positions
  private_mean_position, private_median_position = private_positions
  print('(Public) Mean position: %.2f, median position: %.2f' %
        (public_mean_position, public_median_position))
  print('(Private) Mean position: %.2f, median position: %.2f' %
        (private_mean_position, private_median_position))
def main(argv):
  """Evaluates recognition predictions against a solution file and prints metrics.

  Reads the solution from `cmd_args.solution_path` and the predictions from
  `cmd_args.predictions_path` using the recognition task id. For the Public
  and Private splits separately, it then prints:
    * Global Average Precision;
    * Global Average Precision with `ignore_non_gt_test_images=True`
      (reported as "ignoring non-landmark queries");
    * Top-1 accuracy as a percentage.

  Args:
    argv: Command-line arguments; anything beyond argv[0] is rejected.

  Raises:
    RuntimeError: If more than one command-line argument is given.
  """
  if len(argv) > 1:
    raise RuntimeError('Too many command-line arguments.')

  separator = '**********************************************'

  print('Reading solution...')
  public_solution, private_solution, ignored_ids = dataset_file_io.ReadSolution(
      cmd_args.solution_path, dataset_file_io.RECOGNITION_TASK_ID)
  print('done!')

  # ReadPredictions receives the solution's id sets and returns the
  # predictions split into public and private parts.
  print('Reading predictions...')
  public_ids = set(public_solution.keys())
  private_ids = set(private_solution.keys())
  ignored_id_set = set(ignored_ids)
  public_predictions, private_predictions = dataset_file_io.ReadPredictions(
      cmd_args.predictions_path, public_ids, private_ids, ignored_id_set,
      dataset_file_io.RECOGNITION_TASK_ID)
  print('done!')

  # Global Average Precision over all queries.
  print(separator)
  public_gap = metrics.GlobalAveragePrecision(public_predictions,
                                              public_solution)
  print('(Public) Global Average Precision: %f' % public_gap)
  private_gap = metrics.GlobalAveragePrecision(private_predictions,
                                               private_solution)
  print('(Private) Global Average Precision: %f' % private_gap)

  # Global Average Precision with ignore_non_gt_test_images=True.
  print(separator)
  public_gap_gt_only = metrics.GlobalAveragePrecision(
      public_predictions, public_solution, ignore_non_gt_test_images=True)
  print('(Public) Global Average Precision ignoring non-landmark queries: %f' %
        public_gap_gt_only)
  private_gap_gt_only = metrics.GlobalAveragePrecision(
      private_predictions, private_solution, ignore_non_gt_test_images=True)
  print('(Private) Global Average Precision ignoring non-landmark queries: %f' %
        private_gap_gt_only)

  # Top-1 accuracy, scaled by 100 for display as a percentage.
  print(separator)
  public_top1 = 100.0 * metrics.Top1Accuracy(public_predictions,
                                             public_solution)
  print('(Public) Top-1 accuracy: %.2f' % public_top1)
  private_top1 = 100.0 * metrics.Top1Accuracy(private_predictions,
                                              private_solution)
  print('(Private) Top-1 accuracy: %.2f' % private_top1)