# `dataset_file_io` and `metrics` are the Google Landmarks dataset helpers
# from the DELF codebase; the exact import path varies across DELF versions.
# `cmd_args` is expected to be parsed at module level (see the wiring sketch
# after `main`).
from delf.python.google_landmarks_dataset import dataset_file_io
from delf.python.google_landmarks_dataset import metrics


def main(argv):
    if len(argv) > 1:
        raise RuntimeError('Too many command-line arguments.')

    # Read solution.
    print('Reading solution...')
    public_solution, private_solution, ignored_ids = dataset_file_io.ReadSolution(
        cmd_args.solution_path, dataset_file_io.RECOGNITION_TASK_ID)
    print('done!')

    # Read predictions.
    print('Reading predictions...')
    public_predictions, private_predictions = dataset_file_io.ReadPredictions(
        cmd_args.predictions_path, set(public_solution.keys()),
        set(private_solution.keys()), set(ignored_ids),
        dataset_file_io.RECOGNITION_TASK_ID)
    print('done!')

    # Global Average Precision.
    print('**********************************************')
    print('(Public)  Global Average Precision: %f' %
          metrics.GlobalAveragePrecision(public_predictions, public_solution))
    print(
        '(Private) Global Average Precision: %f' %
        metrics.GlobalAveragePrecision(private_predictions, private_solution))

    # Global Average Precision ignoring non-landmark queries.
    print('**********************************************')
    print(
        '(Public)  Global Average Precision ignoring non-landmark queries: %f'
        % metrics.GlobalAveragePrecision(public_predictions,
                                         public_solution,
                                         ignore_non_gt_test_images=True))
    print(
        '(Private) Global Average Precision ignoring non-landmark queries: %f'
        % metrics.GlobalAveragePrecision(private_predictions,
                                         private_solution,
                                         ignore_non_gt_test_images=True))

    # Top-1 accuracy.
    print('**********************************************')
    print('(Public)  Top-1 accuracy: %.2f' %
          (100.0 * metrics.Top1Accuracy(public_predictions, public_solution)))
    print(
        '(Private) Top-1 accuracy: %.2f' %
        (100.0 * metrics.Top1Accuracy(private_predictions, private_solution)))
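
`main` reads `cmd_args` as a module-level global that the snippet itself never defines. One plausible wiring, following the absl `main(argv)` pattern suggested by the argv check above and the `--solution_path` / `--predictions_path` flag names implied by the attributes used (the parser setup below is an assumption, not the original script):

import argparse
import sys

from absl import app

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--solution_path',
        required=True,
        help='Path to the ground-truth solution file for the recognition task.')
    parser.add_argument(
        '--predictions_path',
        required=True,
        help='Path to the predictions file to score.')
    # The parsed flags become the module-level `cmd_args` that main() reads.
    cmd_args, unparsed = parser.parse_known_args()
    app.run(main=main, argv=[sys.argv[0]] + unparsed)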
Example 2
    def testGlobalAveragePrecisionWorks(self):
        # Define input.
        predictions = _CreateRecognitionPredictions()
        solution = _CreateRecognitionSolution()

        # Run tested function.
        gap = metrics.GlobalAveragePrecision(predictions, solution)

        # Define expected results.
        expected_gap = 0.166667

        # Compare actual and expected results.
        self.assertAllClose(gap, expected_gap)
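
The expected value follows from the Global Average Precision (micro-AP) definition used in the Google Landmarks recognition challenge: rank all predictions by descending confidence and compute GAP = (1/M) * sum_i P(i) * rel(i), where M is the number of queries with at least one ground-truth landmark, P(i) is the precision at rank i, and rel(i) is 1 if the prediction at rank i is correct. A minimal sketch of that definition, assuming predictions of the form {'class': ..., 'score': ...} (this is a sketch, not the library implementation):

def global_average_precision(predictions, solution,
                             ignore_non_gt_test_images=False):
    # M: number of queries that actually depict a landmark.
    num_gt_queries = sum(1 for labels in solution.values() if labels)
    # Rank every prediction by descending confidence.
    ranked = sorted(predictions.items(),
                    key=lambda item: item[1]['score'],
                    reverse=True)
    gap, num_correct, rank = 0.0, 0, 0
    for test_id, prediction in ranked:
        if ignore_non_gt_test_images and not solution[test_id]:
            continue  # Drop predictions made on non-landmark queries.
        rank += 1
        if prediction['class'] in solution[test_id]:
            num_correct += 1
            gap += num_correct / rank  # P(rank) * rel(rank), with rel = 1.
    return gap / num_gt_queries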
Example 3
    def testGlobalAveragePrecisionIgnoreNonGroundTruthWorks(self):
        # Define input.
        predictions = _CreateRecognitionPredictions()
        solution = _CreateRecognitionSolution()

        # Run tested function.
        gap = metrics.GlobalAveragePrecision(predictions,
                                             solution,
                                             ignore_non_gt_test_images=True)

        # Define expected results.
        expected_gap = 0.333333

        # Compare actual and expected results.
        self.assertAllClose(gap, expected_gap)
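
The fixtures `_CreateRecognitionPredictions()` and `_CreateRecognitionSolution()` are defined elsewhere in the test module. Purely for illustration, a hypothetical pair consistent with both expected values above could look like this (IDs, classes, and scores are made up):

def _CreateRecognitionSolution():
    # Three landmark queries and one non-landmark query, so M = 3.
    return {
        'query_a': [0],
        'query_b': [1],
        'query_c': [2],
        'query_d': [],  # Non-landmark query: empty ground truth.
    }


def _CreateRecognitionPredictions():
    return {
        # Incorrect prediction on the non-landmark query, ranked first.
        'query_d': {'class': 7, 'score': 0.9},
        # Correct prediction, ranked second.
        'query_a': {'class': 0, 'score': 0.5},
    }

With these inputs, GAP = (1/3) * (0 + 1/2) ≈ 0.166667; ignoring the non-landmark query promotes the correct prediction to rank 1, giving (1/3) * 1 ≈ 0.333333.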