def _eval_for_set(
    name,
    task,
    interaction_file,
    prediction_file,
    global_step,
):
  """Computes eval metric from predictions."""
  if task in [
      tasks.Task.SQA, tasks.Task.WTQ, tasks.Task.WIKISQL,
      tasks.Task.WIKISQL_SUPERVISED
  ]:
    if not tf.io.gfile.exists(prediction_file):
      _warn(
          f"Can't evaluate for {name} because {prediction_file} doesn't exist.")
      return
    test_examples = calc_metrics_utils.read_data_examples_from_interactions(
        interaction_file)
    calc_metrics_utils.read_predictions(
        predictions_path=prediction_file,
        examples=test_examples,
    )
    denotation_accuracy = calc_metrics_utils.calc_denotation_accuracy(
        examples=test_examples,
        denotation_errors_path=None,
        predictions_file_name=None,
    )
    _print(f'{name} denotation accuracy: {denotation_accuracy:0.4f}')
  else:
    raise ValueError(f'Unknown task: {task.name}')
def test_calc_denotation_accuracy_without_gold_answer(self):
  table = pd.DataFrame([['a', 'b'], ['0', '1']], columns=['A', 'B'])
  denotation_accuracy = calc_metrics_utils.calc_denotation_accuracy(
      examples={
          '0': calc_metrics_utils.Example(
              example_id='0',
              question='q',
              table_id='tab_0',
              table=table,
              gold_cell_coo=set(),
              gold_agg_function=calc_metrics_utils._AggregationFunction.NONE,
              float_answer=None,
              has_gold_answer=False,
          ),
          '1': calc_metrics_utils.Example(
              example_id='1',
              question='q',
              table_id='tab_0',
              table=table,
              gold_cell_coo={(0, 0)},
              gold_agg_function=calc_metrics_utils._AggregationFunction.NONE,
              float_answer=None,
              has_gold_answer=True,
              pred_cell_coo={(0, 0)},
          ),
      },
      denotation_errors_path=None,
      predictions_file_name=None)
  self.assertEqual(0.5, denotation_accuracy)
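# The test above pins down the counting behaviour: an example without a gold
# answer counts as incorrect, and an example whose prediction matches the gold
# denotation counts as correct, so the expected accuracy is 1 / 2 = 0.5.  The
# snippet below is a minimal standalone sketch of that behaviour (simplified to
# comparing cell coordinates); it is not the library's implementation.
def _sketch_denotation_accuracy(examples):
  correct = 0
  for example in examples:
    if not example['has_gold_answer']:
      continue  # No reference denotation, so the example counts as incorrect.
    if example.get('pred_cell_coo') == example['gold_cell_coo']:
      correct += 1
  return correct / len(examples)


assert _sketch_denotation_accuracy([
    {'has_gold_answer': False, 'gold_cell_coo': set()},
    {'has_gold_answer': True, 'gold_cell_coo': {(0, 0)},
     'pred_cell_coo': {(0, 0)}},
]) == 0.5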
def _calc_denotation_accuracy(tables_file, examples, denotation_errors_path,
                              predictions_file_name):
  """Attaches pickled tables to examples and computes denotation accuracy."""
  with tf.io.gfile.GFile(tables_file, 'rb') as f:
    tables = pickle.load(f)
  for example in examples.values():
    example.table = tables[example.table_id]
  return calc_metrics_utils.calc_denotation_accuracy(examples,
                                                     denotation_errors_path,
                                                     predictions_file_name)
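# Hedged usage sketch for the helper above: it expects `tables_file` to be a
# pickled dict mapping each table_id to its pandas DataFrame (matching the
# DataFrame tables used in the test above).  The file path below is
# hypothetical, and `examples` would come from
# read_data_examples_from_interactions / read_predictions as elsewhere here.
import pickle

import pandas as pd
import tensorflow as tf

tables = {'tab_0': pd.DataFrame([['a', 'b'], ['0', '1']], columns=['A', 'B'])}
with tf.io.gfile.GFile('/tmp/tables.pickle', 'wb') as f:
  pickle.dump(tables, f)
# _calc_denotation_accuracy('/tmp/tables.pickle', examples,
#                           denotation_errors_path=None,
#                           predictions_file_name=None)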
def _eval_for_set(
    model_dir,
    name,
    task,
    interaction_file,
    prediction_file,
    global_step,
):
  """Computes eval metric from predictions."""
  if not tf.io.gfile.exists(prediction_file):
    _warn(
        f"Can't evaluate for {name} because {prediction_file} doesn't exist.")
    return
  test_examples = calc_metrics_utils.read_data_examples_from_interactions(
      interaction_file)
  calc_metrics_utils.read_predictions(
      predictions_path=prediction_file,
      examples=test_examples,
  )
  if task in [
      tasks.Task.SQA, tasks.Task.WTQ, tasks.Task.WIKISQL,
      tasks.Task.WIKISQL_SUPERVISED
  ]:
    denotation_accuracy = calc_metrics_utils.calc_denotation_accuracy(
        examples=test_examples,
        denotation_errors_path=None,
        predictions_file_name=None,
    )
    if global_step is not None:
      _create_measurements_for_metrics(
          {'denotation_accuracy': denotation_accuracy},
          global_step=global_step,
          model_dir=model_dir,
          name=name,
      )
  elif task == tasks.Task.TABFACT:
    accuracy = calc_metrics_utils.calc_classification_accuracy(test_examples)
    if global_step is not None:
      _create_measurements_for_metrics(
          {'accuracy': accuracy},
          global_step=global_step,
          model_dir=model_dir,
          name=name,
      )
  else:
    raise ValueError(f'Unknown task: {task.name}')
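# Hedged usage sketch: invoking the evaluation helper above from within this
# module, assuming its imports (tasks, tf, calc_metrics_utils) and the _warn /
# _create_measurements_for_metrics helpers are in scope.  The paths, model
# directory, and step value below are hypothetical.
_eval_for_set(
    model_dir='/tmp/tapas_model',
    name='dev',
    task=tasks.Task.SQA,
    interaction_file='/tmp/tapas_model/dev_interactions.tfrecord',
    prediction_file='/tmp/tapas_model/dev_predictions.tsv',
    global_step=10000,
)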
def main(_):
  examples = calc_metrics_utils.read_data_examples_from_interactions(
      FLAGS.interactions_file)
  prediction_file_name = os.path.basename(FLAGS.prediction_files)
  calc_metrics_utils.read_predictions(FLAGS.prediction_files, examples)
  if FLAGS.is_strong_supervision_available:
    results = calc_metrics_utils.calc_structure_metrics(
        examples, FLAGS.denotation_errors_path)
    print('%s: joint_accuracy=%s' %
          (FLAGS.prediction_files, results.joint_acc))
  denotation_accuracy = calc_metrics_utils.calc_denotation_accuracy(
      examples, FLAGS.denotation_errors_path, prediction_file_name)
  print('%s: denotation_accuracy=%s' %
        (FLAGS.prediction_files, denotation_accuracy))
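# Context sketch (an assumption, not the original definitions): main() above
# reads absl flags with the names used in its body.  A compatible setup could
# look like the following; defaults and help strings here are illustrative.
from absl import app
from absl import flags

flags.DEFINE_string('interactions_file', None,
                    'Path to the interactions file holding gold answers.')
flags.DEFINE_string('prediction_files', None,
                    'Path to the model predictions file.')
flags.DEFINE_string('denotation_errors_path', None,
                    'Optional path for writing denotation error details.')
flags.DEFINE_bool('is_strong_supervision_available', False,
                  'Whether gold cell selection / aggregation is available.')

FLAGS = flags.FLAGS

if __name__ == '__main__':
  app.run(main)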