Example 1

This variant of _eval_for_set handles only the denotation-accuracy tasks (SQA, WTQ, WIKISQL, WIKISQL_SUPERVISED): it reads examples and predictions, prints the denotation accuracy, and raises for any other task.

def _eval_for_set(
    name,
    task,
    interaction_file,
    prediction_file,
    global_step,
):
  """Computes eval metric from predictions."""
  if task in [
      tasks.Task.SQA, tasks.Task.WTQ, tasks.Task.WIKISQL,
      tasks.Task.WIKISQL_SUPERVISED
  ]:
    if not tf.io.gfile.exists(prediction_file):
      _warn(
          f"Can't evaluate for {name} because {prediction_file} doesn't exist.")
      return
    test_examples = calc_metrics_utils.read_data_examples_from_interactions(
        interaction_file)
    calc_metrics_utils.read_predictions(
        predictions_path=prediction_file,
        examples=test_examples,
    )
    denotation_accuracy = calc_metrics_utils.calc_denotation_accuracy(
        examples=test_examples,
        denotation_errors_path=None,
        predictions_file_name=None,
    )
    _print(f'{name} denotation accuracy: {denotation_accuracy:0.4f}')
  else:
    raise ValueError(f'Unknown task: {task.name}')
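
The excerpt above relies on module-level names (tf, tasks, calc_metrics_utils, _warn, _print) that it does not define. A minimal sketch of that context follows; the import paths and helper bodies are assumptions about the surrounding module, not part of the original source.

# Hypothetical module context for the excerpt above; the import paths and
# helper bodies are assumptions, not shown in the original source.
import tensorflow.compat.v1 as tf  # provides tf.io.gfile

from tapas.scripts import calc_metrics_utils  # assumed path
from tapas.utils import tasks  # assumed path


def _warn(msg):
  # Minimal stand-in for the module's warning helper.
  print(f'WARNING: {msg}')


def _print(msg):
  # Minimal stand-in for the module's logging helper.
  print(msg)
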
Example 2

This variant takes an extra model_dir parameter, moves the prediction-file check and prediction loading ahead of the task dispatch, adds classification accuracy for TABFACT, and records metrics via _create_measurements_for_metrics whenever global_step is set.

def _eval_for_set(
    model_dir,
    name,
    task,
    interaction_file,
    prediction_file,
    global_step,
):
    """Computes eval metric from predictions."""
    if not tf.io.gfile.exists(prediction_file):
        _warn(
            f"Can't evaluate for {name} because {prediction_file} doesn't exist."
        )
        return
    test_examples = calc_metrics_utils.read_data_examples_from_interactions(
        interaction_file)
    calc_metrics_utils.read_predictions(
        predictions_path=prediction_file,
        examples=test_examples,
    )
    if task in [
            tasks.Task.SQA, tasks.Task.WTQ, tasks.Task.WIKISQL,
            tasks.Task.WIKISQL_SUPERVISED
    ]:
        denotation_accuracy = calc_metrics_utils.calc_denotation_accuracy(
            examples=test_examples,
            denotation_errors_path=None,
            predictions_file_name=None,
        )
        if global_step is not None:
            _create_measurements_for_metrics(
                {'denotation_accuracy': denotation_accuracy},
                global_step=global_step,
                model_dir=model_dir,
                name=name,
            )
    elif task == tasks.Task.TABFACT:
        accuracy = calc_metrics_utils.calc_classification_accuracy(
            test_examples)
        if global_step is not None:
            _create_measurements_for_metrics(
                {'accuracy': accuracy},
                global_step=global_step,
                model_dir=model_dir,
                name=name,
            )
    else:
        raise ValueError(f'Unknown task: {task.name}')
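
For reference, a minimal sketch of how this variant might be invoked. The keyword arguments match the signature above, but every value is invented for illustration.

# Hypothetical call; all paths and the step value are made up and are not
# part of the original source.
_eval_for_set(
    model_dir='/tmp/tapas_model',
    name='dev',
    task=tasks.Task.TABFACT,
    interaction_file='/tmp/dev_interactions.tfrecord',
    prediction_file='/tmp/dev_predictions.tsv',
    global_step=10000,
)
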
Example 3

A flags-driven entry point: it loads examples and predictions, optionally computes structure metrics when strong supervision is available, and prints denotation accuracy.

def main(_):
    """Reads interactions and predictions and prints evaluation metrics."""
    examples = calc_metrics_utils.read_data_examples_from_interactions(
        FLAGS.interactions_file)

    prediction_file_name = os.path.basename(FLAGS.prediction_files)
    calc_metrics_utils.read_predictions(FLAGS.prediction_files, examples)
    if FLAGS.is_strong_supervision_available:
        results = calc_metrics_utils.calc_structure_metrics(
            examples, FLAGS.denotation_errors_path)
        print('%s: joint_accuracy=%s' %
              (FLAGS.prediction_files, results.joint_acc))

    denotation_accuracy = calc_metrics_utils.calc_denotation_accuracy(
        examples, FLAGS.denotation_errors_path, prediction_file_name)
    print('%s: denotation_accuracy=%s' %
          (FLAGS.prediction_files, denotation_accuracy))
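
Example 3 reads its inputs from command-line flags that the excerpt does not define. Below is a minimal sketch of matching absl flag definitions plus the script entry point; the flag names come from the code above, but the types, defaults, and help strings are assumptions.

# Hypothetical flag definitions matching the names used in main(); the
# defaults and help strings are assumptions, not the original source.
from absl import app
from absl import flags

FLAGS = flags.FLAGS

flags.DEFINE_string('interactions_file', None,
                    'File of interactions to evaluate against.')
flags.DEFINE_string('prediction_files', None,
                    'File of model predictions.')
flags.DEFINE_string('denotation_errors_path', None,
                    'Where to write denotation error details.')
flags.DEFINE_bool('is_strong_supervision_available', False,
                  'Whether gold structure annotations are available.')

if __name__ == '__main__':
    app.run(main)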