Code Example #1
File: run_task_main.py  Project: vishwajeet93/tapas
def _eval_for_set(
    name,
    task,
    interaction_file,
    prediction_file,
    global_step,
):
  """Computes eval metric from predictions."""
  if task in [
      tasks.Task.SQA, tasks.Task.WTQ, tasks.Task.WIKISQL,
      tasks.Task.WIKISQL_SUPERVISED
  ]:
    if not tf.io.gfile.exists(prediction_file):
      _warn(
          f"Can't evaluate for {name} because {prediction_file} doesn't exist.")
      return
    test_examples = calc_metrics_utils.read_data_examples_from_interactions(
        interaction_file)
    calc_metrics_utils.read_predictions(
        predictions_path=prediction_file,
        examples=test_examples,
    )
    denotation_accuracy = calc_metrics_utils.calc_denotation_accuracy(
        examples=test_examples,
        denotation_errors_path=None,
        predictions_file_name=None,
    )
    _print(f'{name} denotation accuracy: {denotation_accuracy:0.4f}')
  else:
    raise ValueError(f'Unknown task: {task.name}')
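Example #1 follows a three-step pattern: load gold examples from the interactions file, merge in the model's predictions, then score. A minimal standalone sketch of the same pattern; the tapas.scripts import path and the file paths are assumptions, not shown in the snippet above:

from tapas.scripts import calc_metrics_utils

interaction_file = '/tmp/test.tfrecord'        # hypothetical path
prediction_file = '/tmp/test_predictions.tsv'  # hypothetical path

# Step 1: gold examples keyed by example_id.
examples = calc_metrics_utils.read_data_examples_from_interactions(
    interaction_file)
# Step 2: attach predictions to the matching examples in place.
calc_metrics_utils.read_predictions(
    predictions_path=prediction_file, examples=examples)
# Step 3: score; per-example error dumps are disabled, as in the example above.
accuracy = calc_metrics_utils.calc_denotation_accuracy(
    examples=examples,
    denotation_errors_path=None,
    predictions_file_name=None)
print(f'denotation accuracy: {accuracy:0.4f}')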
Code Example #2
def test_calc_denotation_accuracy_without_gold_answer(self):
  table = pd.DataFrame([['a', 'b'], ['0', '1']], columns=['A', 'B'])
  denotation_accuracy = calc_metrics_utils.calc_denotation_accuracy(
      examples={
          '0': calc_metrics_utils.Example(
              example_id='0',
              question='q',
              table_id='tab_0',
              table=table,
              gold_cell_coo=set(),
              gold_agg_function=calc_metrics_utils._AggregationFunction.NONE,
              float_answer=None,
              has_gold_answer=False,
          ),
          '1': calc_metrics_utils.Example(
              example_id='1',
              question='q',
              table_id='tab_0',
              table=table,
              gold_cell_coo={(0, 0)},
              gold_agg_function=calc_metrics_utils._AggregationFunction.NONE,
              float_answer=None,
              has_gold_answer=True,
              pred_cell_coo={(0, 0)},
          ),
      },
      denotation_errors_path=None,
      predictions_file_name=None)
  self.assertEqual(0.5, denotation_accuracy)
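The expected 0.5 pins down how the metric treats missing references: example '1' has a gold answer and a matching pred_cell_coo, so it counts as correct, while example '0' has has_gold_answer=False and is scored as incorrect rather than skipped, giving 1 correct out of 2.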
Code Example #3
def _calc_denotation_accuracy(tables_file, examples, denotation_errors_path,
                              predictions_file_name):
    """Loads the pickled tables, attaches them to the examples, and scores."""
    with tf.io.gfile.GFile(tables_file, 'rb') as f:
        tables = pickle.load(f)
    # Each example carries only a table_id; resolve it to the actual table.
    for example in examples.values():
        example.table = tables[example.table_id]
    return calc_metrics_utils.calc_denotation_accuracy(examples,
                                                       denotation_errors_path,
                                                       predictions_file_name)
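Here the tables live in a separate pickle and are attached just before scoring. A minimal sketch of writing a compatible tables file; the dict-of-DataFrames layout is an assumption inferred from tables[example.table_id] above and the DataFrame in example #2, and the path is hypothetical:

import pickle

import pandas as pd
import tensorflow as tf

# Assumed layout: {table_id: pd.DataFrame}.
tables = {
    'tab_0': pd.DataFrame([['a', 'b'], ['0', '1']], columns=['A', 'B']),
}
with tf.io.gfile.GFile('/tmp/tables.pickle', 'wb') as f:
  pickle.dump(tables, f)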
Code Example #4
def _eval_for_set(
    model_dir,
    name,
    task,
    interaction_file,
    prediction_file,
    global_step,
):
    """Computes eval metric from predictions."""
    if not tf.io.gfile.exists(prediction_file):
        _warn(
            f"Can't evaluate for {name} because {prediction_file} doesn't exist."
        )
        return
    test_examples = calc_metrics_utils.read_data_examples_from_interactions(
        interaction_file)
    calc_metrics_utils.read_predictions(
        predictions_path=prediction_file,
        examples=test_examples,
    )
    if task in [
            tasks.Task.SQA, tasks.Task.WTQ, tasks.Task.WIKISQL,
            tasks.Task.WIKISQL_SUPERVISED
    ]:
        denotation_accuracy = calc_metrics_utils.calc_denotation_accuracy(
            examples=test_examples,
            denotation_errors_path=None,
            predictions_file_name=None,
        )
        if global_step is not None:
            _create_measurements_for_metrics(
                {'denotation_accuracy': denotation_accuracy},
                global_step=global_step,
                model_dir=model_dir,
                name=name,
            )
    elif task == tasks.Task.TABFACT:
        accuracy = calc_metrics_utils.calc_classification_accuracy(
            test_examples)
        if global_step is not None:
            _create_measurements_for_metrics(
                {'accuracy': accuracy},
                global_step=global_step,
                model_dir=model_dir,
                name=name,
            )
    else:
        raise ValueError(f'Unknown task: {task.name}')
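Compared with example #1, this version of _eval_for_set hoists the prediction-file check and example loading ahead of the task dispatch, adds a tasks.Task.TABFACT branch scored with calc_metrics_utils.calc_classification_accuracy, and records each metric via _create_measurements_for_metrics when a global_step is available.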
Code Example #5
File: calc_metrics.py  Project: google-research/tapas
def main(_):
    examples = calc_metrics_utils.read_data_examples_from_interactions(
        FLAGS.interactions_file)

    prediction_file_name = os.path.basename(FLAGS.prediction_files)
    calc_metrics_utils.read_predictions(FLAGS.prediction_files, examples)
    if FLAGS.is_strong_supervision_available:
        results = calc_metrics_utils.calc_structure_metrics(
            examples, FLAGS.denotation_errors_path)
        print('%s: joint_accuracy=%s' %
              (FLAGS.prediction_files, results.joint_acc))

    denotation_accuracy = calc_metrics_utils.calc_denotation_accuracy(
        examples, FLAGS.denotation_errors_path, prediction_file_name)
    print('%s: denotation_accuracy=%s' %
          (FLAGS.prediction_files, denotation_accuracy))
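As a standalone script, --interactions_file and --prediction_files name the gold interactions and the model's predictions; --is_strong_supervision_available additionally computes structure metrics and prints their joint_acc before the denotation accuracy. --denotation_errors_path is passed to both metric calls, presumably to direct per-example error dumps.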