Example #1
def _eval_for_set(
    name,
    task,
    interaction_file,
    prediction_file,
    global_step,
):
  """Computes eval metric from predictions."""
  if task in [
      tasks.Task.SQA, tasks.Task.WTQ, tasks.Task.WIKISQL,
      tasks.Task.WIKISQL_SUPERVISED
  ]:
    if not tf.io.gfile.exists(prediction_file):
      _warn(
          f"Can't evaluate for {name} because {prediction_file} doesn't exist.")
      return
    test_examples = calc_metrics_utils.read_data_examples_from_interactions(
        interaction_file)
    calc_metrics_utils.read_predictions(
        predictions_path=prediction_file,
        examples=test_examples,
    )
    denotation_accuracy = calc_metrics_utils.calc_denotation_accuracy(
        examples=test_examples,
        denotation_errors_path=None,
        predictions_file_name=None,
    )
    _print(f'{name} denotation accuracy: {denotation_accuracy:0.4f}')
  else:
    raise ValueError(f'Unknown task: {task.name}')
Example #2
 def test_calc_denotation_accuracy_handles_low_precision_floats(self):
     test_tmpdir, output_tables_file, table_name = _write_tables_dict(
         headers=['FLOAT'], data=[['992.39']])
     data_path = _write_dataset([[
         'dev-0',
         '0',
         '0',
         '-',
         table_name,
         '[]',
         '[]',
         calc_metrics_utils._AggregationFunction.NONE,
         '992.3900146484375',
     ]])
     examples = _read_data_examples(data_path)
     predictions_path = _write_predictions(
         data=[['dev-0', '0', '0', '["(0, 0)"]', '0', '0']])
     calc_metrics_utils.read_predictions(predictions_path, examples)
     denotation_accuracy = _calc_denotation_accuracy(
         output_tables_file,
         examples,
         denotation_errors_path=test_tmpdir,
         predictions_file_name='predictions',
     )
     self.assertEqual(1.0, denotation_accuracy)
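A note on the magic number above: 992.3900146484375 is exactly the value obtained by rounding '992.39' to a 32-bit float and widening it back to float64, which is presumably why the test expects the two to count as a match. A minimal sketch of that round trip, assuming numpy is available:

import numpy as np

# '992.39' has no exact binary representation; storing it as float32 and
# reading it back as float64 yields the long decimal used as the gold
# answer in the test above.
assert float(np.float32('992.39')) == 992.3900146484375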
Example #3
 def test_read_predictions(self):
     data_path = _write_synthetic_dataset('table_1')
     examples = _read_data_examples(data_path)
     predictions_path = _write_synthetic_predictions()
     calc_metrics_utils.read_predictions(predictions_path, examples)
     self.assertEqual(examples['dev-2-0_0'].gold_agg_function,
                      calc_metrics_utils._AggregationFunction.COUNT)
     self.assertEqual(examples['dev-2-0_0'].pred_agg_function,
                      calc_metrics_utils._AggregationFunction.NONE)
Example #4
def _eval_for_set(
    model_dir,
    name,
    task,
    interaction_file,
    prediction_file,
    global_step,
):
    """Computes eval metric from predictions."""
    if not tf.io.gfile.exists(prediction_file):
        _warn(
            f"Can't evaluate for {name} because {prediction_file} doesn't exist."
        )
        return
    test_examples = calc_metrics_utils.read_data_examples_from_interactions(
        interaction_file)
    calc_metrics_utils.read_predictions(
        predictions_path=prediction_file,
        examples=test_examples,
    )
    if task in [
            tasks.Task.SQA, tasks.Task.WTQ, tasks.Task.WIKISQL,
            tasks.Task.WIKISQL_SUPERVISED
    ]:
        denotation_accuracy = calc_metrics_utils.calc_denotation_accuracy(
            examples=test_examples,
            denotation_errors_path=None,
            predictions_file_name=None,
        )
        if global_step is not None:
            _create_measurements_for_metrics(
                {'denotation_accuracy': denotation_accuracy},
                global_step=global_step,
                model_dir=model_dir,
                name=name,
            )
    elif task == tasks.Task.TABFACT:
        accuracy = calc_metrics_utils.calc_classification_accuracy(
            test_examples)
        if global_step is not None:
            _create_measurements_for_metrics(
                {'accuracy': accuracy},
                global_step=global_step,
                model_dir=model_dir,
                name=name,
            )
    else:
        raise ValueError(f'Unknown task: {task.name}')
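For orientation, a hypothetical call to this helper; the model directory, file paths, and step value below are illustrative placeholders, not names from the source:

# Hypothetical usage; all argument values are placeholders.
_eval_for_set(
    model_dir='/tmp/tapas_model',
    name='dev',
    task=tasks.Task.SQA,
    interaction_file='/tmp/tapas_model/dev.tfrecord',
    prediction_file='/tmp/tapas_model/dev_predictions.tsv',
    global_step=10000,
)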
Example #5
def main(_):
    examples = calc_metrics_utils.read_data_examples_from_interactions(
        FLAGS.interactions_file)

    prediction_file_name = os.path.basename(FLAGS.prediction_files)
    calc_metrics_utils.read_predictions(FLAGS.prediction_files, examples)
    if FLAGS.is_strong_supervision_available:
        results = calc_metrics_utils.calc_structure_metrics(
            examples, FLAGS.denotation_errors_path)
        print('%s: joint_accuracy=%s' %
              (FLAGS.prediction_files, results.joint_acc))

    denotation_accuracy = calc_metrics_utils.calc_denotation_accuracy(
        examples, FLAGS.denotation_errors_path, prediction_file_name)
    print('%s: denotation_accuracy=%s' %
          (FLAGS.prediction_files, denotation_accuracy))
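The flags referenced in main() suggest a command line along these lines; the flag names come from the snippet, while the script name and paths are placeholders:

# Hypothetical invocation (script name and paths are placeholders):
#   python calc_metrics.py \
#     --interactions_file=/path/to/dev_interactions.tfrecord \
#     --prediction_files=/path/to/dev_predictions.tsv \
#     --denotation_errors_path=/tmp/denotation_errors \
#     --is_strong_supervision_available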
Example #6
 def test_weighted_denotation_accuracy(self):
   test_tmpdir, output_tables_file, table_name = _write_tables_dict()
   data_path = _write_synthetic_dataset(table_name)
   examples = _read_data_examples(data_path)
   predictions_path = _write_synthetic_predictions()
   calc_metrics_utils.read_predictions(predictions_path, examples)
   predictions_file_name = 'predictions'
   stats = _calc_weighted_denotation_accuracy(
       output_tables_file,
       examples,
       denotation_errors_path=test_tmpdir,
       predictions_file_name=predictions_file_name,
       add_weights=True,
   )
   self.assertEqual(stats['denotation_accuracy'], 0.8)
   self.assertEqual(stats['weighted_denotation_accuracy'], 0.5)
Example #7
 def test_calc_denotation_accuracy_handles_nans(self):
     test_tmpdir, output_tables_file, table_name = _write_tables_dict()
     data_path = _write_dataset([[
         'dev-0', '0', '0', '-', table_name, '[]', '[]',
         calc_metrics_utils._AggregationFunction.SUM, 'NAN'
     ]])
     examples = _read_data_examples(data_path)
     predictions_path = _write_predictions(
         data=[['dev-0', '0', '0', '[]', '0', '1']])
     calc_metrics_utils.read_predictions(predictions_path, examples)
     denotation_accuracy = _calc_denotation_accuracy(
         output_tables_file,
         examples,
         denotation_errors_path=test_tmpdir,
         predictions_file_name='predictions',
     )
     self.assertEqual(1.0, denotation_accuracy)
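This case needs its own test because NaN never compares equal to itself, so a naive float comparison would never count a 'NAN' gold answer as matched; the metric has to detect NaN explicitly. A quick illustration in plain Python:

import math

# NaN fails ordinary equality; isnan() is the reliable check.
assert float('NAN') != float('NAN')
assert math.isnan(float('NAN'))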
Example #8
  def test_calc_structure_metrics(self):
    data_path = _write_synthetic_dataset('table_1')
    examples = _read_data_examples(data_path)
    predictions_path = _write_synthetic_predictions()
    calc_metrics_utils.read_predictions(predictions_path, examples)
    test_tmpdir = tempfile.mkdtemp()
    results = calc_metrics_utils.calc_structure_metrics(
        examples, denotation_errors_path=test_tmpdir)
    self.assertEqual(results.aggregation_acc, 0.6)
    self.assertEqual(results.cell_acc, 0.6)
    self.assertEqual(results.joint_acc, 0.6)

    denotation_errors = pd.read_csv(
        os.path.join(test_tmpdir, 'structured_examples.tsv'), sep='\t')
    self.assertEqual(denotation_errors.iloc[0, 1], 'dev-0-0_0')
    self.assertEqual(denotation_errors.iloc[0, 2],
                     calc_metrics_utils._Answer.NONE)
    self.assertEqual(denotation_errors.iloc[0, 3],
                     calc_metrics_utils._Answer.NONE)
Example #9
 def test_read_predictions_without_pred_aggr(self):
      predictions_path = _write_predictions(
          data=[['dev-0', '0', '0', '["(0,0)"]']],
          # No pred aggregation column here; the reader should fall back to
          # _AggregationFunction.NONE, which the assertion below verifies.
          headers=('id', 'annotator', 'position', 'answer_coordinates'))
     examples = {
         'dev-0-0_0':
         calc_metrics_utils.Example(
             example_id='dev-0-0_0',
             question='q',
             table_id='tab_0',
             table=pd.DataFrame(),
              gold_cell_coo=set(),  # {} would create a dict, not an empty set
             gold_agg_function=calc_metrics_utils._AggregationFunction.NONE,
             float_answer=None,
             has_gold_answer=True,
         )
     }
     calc_metrics_utils.read_predictions(predictions_path, examples)
     self.assertLen(examples, 1)
     self.assertEqual(
         next(iter(examples.values())).pred_agg_function,
         calc_metrics_utils._AggregationFunction.NONE)
Example #10
    def test_denotation_accuracy(self):
        test_tmpdir, output_tables_file, table_name = _write_tables_dict()
        data_path = _write_synthetic_dataset(table_name)
        examples = _read_data_examples(data_path)
        predictions_path = _write_synthetic_predictions()
        calc_metrics_utils.read_predictions(predictions_path, examples)
        predictions_file_name = 'predictions'
        denotation_accuracy = _calc_denotation_accuracy(
            output_tables_file,
            examples,
            denotation_errors_path=test_tmpdir,
            predictions_file_name=predictions_file_name,
        )
        self.assertEqual(denotation_accuracy, 0.8)

        denotation_errors = pd.read_csv(os.path.join(
            test_tmpdir,
            'denotation_examples_{}'.format(predictions_file_name)),
                                        sep='\t')
        self.assertEqual(denotation_errors.iloc[0, 1], 'dev-0-0_0')
        self.assertEqual(denotation_errors.iloc[0, 2], '-')
        self.assertEqual(denotation_errors.iloc[0, 5], "['6.13', 'Richmond']")
        self.assertEqual(denotation_errors.iloc[0, 7], '[(2, 1), (2, 2)]')
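The assertions above pin down parts of the error report's layout: it is a TSV named 'denotation_examples_<predictions_file_name>', column 1 holds the example id, and column 2 appears to hold the question ('-' in this synthetic data); the meaning of the remaining columns is not spelled out in the snippet. A hypothetical way to inspect the report, with the tmpdir path as a placeholder:

import os
import pandas as pd

# Hypothetical inspection of the report written by the test above.
test_tmpdir = '/tmp/denotation_test'  # placeholder for the test's tmpdir
report = pd.read_csv(
    os.path.join(test_tmpdir, 'denotation_examples_predictions'), sep='\t')
print(report.iloc[0, 1])  # per the assertions: 'dev-0-0_0'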