Ejemplo n.º 1
0
    def testReadRetrievalSolutionWorks(self):
        """Checks ReadSolution on a small retrieval-task solution CSV.

        Writes a header plus five query rows to the test temp dir, parses the
        file with the retrieval task id, and verifies that each row lands in
        the public, private or ignored output according to its Usage column.
        Space-separated image ids are expected back as a list of strings.
        """
        csv_lines = (
            'id,images,Usage\n',
            '0123456789abcdef,None,Ignored\n',
            '0223456789abcdef,fedcba9876543210 fedcba9876543200,Public\n',
            '0323456789abcdef,fedcba9876543200,Private\n',
            '0423456789abcdef,fedcba9876543220,Private\n',
            '0523456789abcdef,None,Ignored\n',
        )
        csv_path = os.path.join(tf.test.get_temp_dir(),
                                'retrieval_solution.csv')
        with tf.gfile.GFile(csv_path, 'w') as csv_file:
            for line in csv_lines:
                csv_file.write(line)

        got_public, got_private, got_ignored = dataset_file_io.ReadSolution(
            csv_path, dataset_file_io.RETRIEVAL_TASK_ID)

        self.assertEqual(
            got_public,
            {'0223456789abcdef': ['fedcba9876543210', 'fedcba9876543200']})
        self.assertEqual(
            got_private,
            {
                '0323456789abcdef': ['fedcba9876543200'],
                '0423456789abcdef': ['fedcba9876543220'],
            })
        self.assertEqual(got_ignored,
                         ['0123456789abcdef', '0523456789abcdef'])
Ejemplo n.º 2
0
    def testReadRecognitionSolutionWorks(self):
        """Checks ReadSolution on a small recognition-task solution CSV.

        Writes a header plus five query rows to the test temp dir, parses the
        file with the recognition task id, and verifies the Usage-based split
        into public, private and ignored ids. Space-separated landmark ids are
        expected back as a list of ints, and an empty field as an empty list.
        """
        csv_lines = (
            'id,landmarks,Usage\n',
            '0123456789abcdef,0 12,Public\n',
            '0223456789abcdef,,Public\n',
            '0323456789abcdef,100,Ignored\n',
            '0423456789abcdef,1,Private\n',
            '0523456789abcdef,,Ignored\n',
        )
        csv_path = os.path.join(tf.test.get_temp_dir(),
                                'recognition_solution.csv')
        with tf.gfile.GFile(csv_path, 'w') as csv_file:
            for line in csv_lines:
                csv_file.write(line)

        got_public, got_private, got_ignored = dataset_file_io.ReadSolution(
            csv_path, dataset_file_io.RECOGNITION_TASK_ID)

        self.assertEqual(
            got_public,
            {
                '0123456789abcdef': [0, 12],
                '0223456789abcdef': [],
            })
        self.assertEqual(got_private, {'0423456789abcdef': [1]})
        self.assertEqual(got_ignored,
                         ['0323456789abcdef', '0523456789abcdef'])
Ejemplo n.º 3
0
def main(argv):
  """Prints retrieval-task metrics for a predictions file.

  Reads the solution and predictions files named by the module-level
  `cmd_args` (`solution_path`, `predictions_path`) with the retrieval task id.
  For the public and private splits it then prints the mean average
  precision, the mean precision@k for k in {1, 5, 10, 50, 100} as
  percentages, and the mean/median position of the first correct prediction.

  Args:
    argv: Positional command-line arguments. Only a single element is
      accepted (presumably the program name).

  Raises:
    RuntimeError: If `argv` has more than one element.
  """
  if len(argv) > 1:
    raise RuntimeError('Too many command-line arguments.')

  task_id = dataset_file_io.RETRIEVAL_TASK_ID
  divider = '**********************************************'

  print('Reading solution...')
  public_solution, private_solution, ignored_ids = dataset_file_io.ReadSolution(
      cmd_args.solution_path, task_id)
  print('done!')

  print('Reading predictions...')
  public_predictions, private_predictions = dataset_file_io.ReadPredictions(
      cmd_args.predictions_path, set(public_solution.keys()),
      set(private_solution.keys()), set(ignored_ids), task_id)
  print('done!')

  # (predictions, solution) pairs: public first, then private.
  splits = ((public_predictions, public_solution),
            (private_predictions, private_solution))

  # Mean average precision, printed as each split is scored.
  print(divider)
  map_templates = ('(Public)  Mean Average Precision: %f',
                   '(Private) Mean Average Precision: %f')
  for template, (predictions, solution) in zip(map_templates, splits):
    print(template % metrics.MeanAveragePrecision(predictions, solution))

  # Mean precision@k as percentages; both splits are scored before printing.
  print(divider)
  all_precisions = [100.0 * metrics.MeanPrecisions(predictions, solution)
                    for predictions, solution in splits]
  precision_templates = (
      '(Public)  Mean precisions: P@1: %.2f, P@5: %.2f, P@10: %.2f, '
      'P@50: %.2f, P@100: %.2f',
      '(Private) Mean precisions: P@1: %.2f, P@5: %.2f, P@10: %.2f, '
      'P@50: %.2f, P@100: %.2f')
  reported_ks = (1, 5, 10, 50, 100)
  for template, precisions in zip(precision_templates, all_precisions):
    # Entry k - 1 of the precisions sequence holds precision@k.
    print(template % tuple(precisions[k - 1] for k in reported_ks))

  # Mean/median position of the first correct prediction; both splits are
  # scored before printing.
  print(divider)
  all_positions = [metrics.MeanMedianPosition(predictions, solution)
                   for predictions, solution in splits]
  position_templates = (
      '(Public)  Mean position: %.2f, median position: %.2f',
      '(Private) Mean position: %.2f, median position: %.2f')
  for template, (mean_position, median_position) in zip(position_templates,
                                                         all_positions):
    print(template % (mean_position, median_position))
def main(argv):
    """Prints recognition-task metrics for a predictions file.

    Reads the solution and predictions files named by the module-level
    `cmd_args` (`solution_path`, `predictions_path`) with the recognition task
    id. For the public and private splits it then prints the Global Average
    Precision, the Global Average Precision computed with
    `ignore_non_gt_test_images=True` (reported as ignoring non-landmark
    queries), and the top-1 accuracy as a percentage.

    Args:
        argv: Positional command-line arguments. Only a single element is
            accepted (presumably the program name).

    Raises:
        RuntimeError: If `argv` has more than one element.
    """
    if len(argv) > 1:
        raise RuntimeError('Too many command-line arguments.')

    task_id = dataset_file_io.RECOGNITION_TASK_ID
    divider = '**********************************************'

    print('Reading solution...')
    public_solution, private_solution, ignored_ids = (
        dataset_file_io.ReadSolution(cmd_args.solution_path, task_id))
    print('done!')

    print('Reading predictions...')
    public_predictions, private_predictions = dataset_file_io.ReadPredictions(
        cmd_args.predictions_path, set(public_solution.keys()),
        set(private_solution.keys()), set(ignored_ids), task_id)
    print('done!')

    # (predictions, solution) pairs: public first, then private.
    splits = ((public_predictions, public_solution),
              (private_predictions, private_solution))

    def report(templates, score_fn):
        """Scores each split with score_fn and prints it with its template."""
        for template, (predictions, solution) in zip(templates, splits):
            print(template % score_fn(predictions, solution))

    # Global Average Precision.
    print(divider)
    report(('(Public)  Global Average Precision: %f',
            '(Private) Global Average Precision: %f'),
           metrics.GlobalAveragePrecision)

    # Global Average Precision ignoring non-landmark queries.
    print(divider)
    report(('(Public)  Global Average Precision ignoring non-landmark '
            'queries: %f',
            '(Private) Global Average Precision ignoring non-landmark '
            'queries: %f'),
           lambda predictions, solution: metrics.GlobalAveragePrecision(
               predictions, solution, ignore_non_gt_test_images=True))

    # Top-1 accuracy, reported as a percentage.
    print(divider)
    report(('(Public)  Top-1 accuracy: %.2f',
            '(Private) Top-1 accuracy: %.2f'),
           lambda predictions, solution: (
               100.0 * metrics.Top1Accuracy(predictions, solution)))