def create_interactions(task, input_dir, output_dir):
  """Turns the raw data of `task` into interactions.

  Tasks other than SQA are first converted with their task-specific utility,
  which writes its results into `output_dir`; SQA data is read from
  `input_dir` unchanged. The resulting directory is then handed to
  `sqa_utils.create_interactions`, which receives
  `get_interaction_dir(output_dir)` as its destination.

  Args:
    task: The current task.
    input_dir: Directory holding the original task data.
    output_dir: Directory that receives all outputs.

  Raises:
    ValueError: If `task` is none of the supported tasks.
  """
  # Every task except SQA is converted into output_dir first.
  tsv_dir = output_dir
  if task == tasks.Task.SQA:
    tsv_dir = input_dir
  elif task == tasks.Task.WTQ:
    wtq_utils.convert(input_dir, output_dir)
  elif task in (tasks.Task.WIKISQL, tasks.Task.WIKISQL_SUPERVISED):
    # Both WikiSQL variants share the same conversion step.
    wikisql_utils.convert(input_dir, output_dir)
  else:
    raise ValueError(f'Unknown task: {task.name}')

  supervision_modes = get_supervision_modes(task)
  interaction_dir = get_interaction_dir(output_dir)
  sqa_utils.create_interactions(supervision_modes, tsv_dir, interaction_dir)
def create_interactions(
    task,
    input_dir,
    output_dir,
    token_selector,
):  # pylint: disable=g-doc-args
  """Turns the raw data of `task` into interactions.

  Two families of tasks are handled:

  * SQA, WTQ and the WikiSQL variants go through
    `sqa_utils.create_interactions`, which receives
    `get_interaction_dir(output_dir)` as its destination. All but SQA are
    first converted into `output_dir` by their task-specific utility.
  * TABFACT, SEM_TAB_FACT and the HybridQA variants produce interactions
    directly; those are passed to `_to_tfrecord` together with `output_dir`
    and `token_selector`.

  Args:
    task: The current task.
    input_dir: Directory holding the original task data.
    output_dir: Directory that receives all outputs.
    token_selector: Optional helper class to keep more relevant tokens in
      input.

  Raises:
    ValueError: If `task` is none of the supported tasks.
  """
  if task == tasks.Task.SQA:
    tsv_dir = input_dir
  elif task == tasks.Task.WTQ:
    wtq_utils.convert(input_dir, output_dir)
    tsv_dir = output_dir
  elif task in (tasks.Task.WIKISQL, tasks.Task.WIKISQL_SUPERVISED):
    # Both WikiSQL variants share the same conversion step.
    wikisql_utils.convert(input_dir, output_dir)
    tsv_dir = output_dir
  else:
    # Remaining tasks build interactions directly instead of TSV files.
    if task == tasks.Task.TABFACT:
      interactions = tabfact_utils.convert(input_dir)
    elif task == tasks.Task.HYBRIDQA:
      interactions = hybridqa_utils.convert(input_dir)
    elif task == tasks.Task.HYBRIDQA_RC:
      interactions = hybridqa_rc_utils.convert(input_dir, output_dir)
    elif task == tasks.Task.HYBRIDQA_E2E:
      # This variant only looks at output_dir, not at input_dir.
      interactions = (
          hybridqa_rc_utils.create_interactions_from_hybridqa_predictions(
              output_dir))
    elif task == tasks.Task.SEM_TAB_FACT:
      interactions = sem_tab_fact_utils.convert(input_dir)
    else:
      raise ValueError(f'Unknown task: {task.name}')
    _to_tfrecord(interactions, output_dir, token_selector)
    return

  sqa_utils.create_interactions(
      get_supervision_modes(task),
      tsv_dir,
      get_interaction_dir(output_dir),
      token_selector,
  )
def create_interactions(
    task,
    input_dir,
    output_dir,
    token_selector,
):  # pylint: disable=g-doc-args
  """Converts original task data to interactions.

  SQA, WTQ and the WikiSQL variants are processed by
  `sqa_utils.create_interactions`, which receives
  `get_interaction_dir(output_dir)` as its destination; all of them except
  SQA are first converted into `output_dir` by their task-specific utility.
  For TABFACT the interactions returned by `tabfact_utils.convert` are passed
  straight to `_to_tfrecord` and `sqa_utils` is not involved.

  Args:
    task: The current task.
    input_dir: Data with original task data.
    output_dir: Outputs are written to this directory.
    token_selector: Optional helper class to keep more relevant tokens in
      input.

  Raises:
    ValueError: If `task` is not one of the supported tasks.
  """

  def to_tfrecord(interactions):
    """Helper function that binds output dir and token_selector arguments."""
    _to_tfrecord(interactions, output_dir, token_selector)

  # NOTE: a nested `to_json(config, output_dir)` helper used to live here. It
  # was never called from this function, so it has been dropped as dead code.

  if task == tasks.Task.SQA:
    # SQA data is consumed as-is from the input directory.
    tsv_dir = input_dir
  elif task == tasks.Task.WTQ:
    wtq_utils.convert(input_dir, output_dir)
    tsv_dir = output_dir
  elif task == tasks.Task.WIKISQL:
    wikisql_utils.convert(input_dir, output_dir)
    tsv_dir = output_dir
  elif task == tasks.Task.WIKISQL_SUPERVISED:
    wikisql_utils.convert(input_dir, output_dir)
    tsv_dir = output_dir
  elif task == tasks.Task.TABFACT:
    # TabFact yields interactions directly; nothing goes through sqa_utils.
    to_tfrecord(tabfact_utils.convert(input_dir))
    return
  else:
    raise ValueError(f'Unknown task: {task.name}')

  sqa_utils.create_interactions(
      get_supervision_modes(task),
      tsv_dir,
      get_interaction_dir(output_dir),
      token_selector,
  )
def test_simple_test(self):
  """Runs `wikisql_utils.convert` on one table with two questions.

  The test checks the CSV written for the table and the first two rows of the
  generated `dev.tsv`.
  """
  tables = [{
      'id': '1-0000001-1',
      'header': ['Text', 'Number'],
      'types': ['text', 'real'],
      'rows': [['A', 1], ['B', 2], ['C', 3]],
  }]
  examples = [
      {
          'question': 'What is text for 2?',
          'table_id': '1-0000001-1',
          'sql': {
              'agg': 0,  # No aggregation
              'sel': 0,  # Text column
              'conds': [[1, 0, 2]],  # Column 1 = 2
          },
      },
      {
          'question': 'What is sum when number is greater than 1?',
          'table_id': '1-0000001-1',
          'sql': {
              'agg': 4,  # SUM
              'sel': 1,  # Number column
              'conds': [[1, 1, 1]],  # Column 1 > 1
          },
      },
  ]
  # All values come back as strings because they are read through csv.
  expected_table_rows = [
      {'Text': 'A', 'Number': '1'},
      {'Text': 'B', 'Number': '2'},
      {'Text': 'C', 'Number': '3'},
  ]
  expected_dev_rows = [
      {
          'id': 'dev-0',
          'annotator': '0',
          'position': '0',
          'question': 'What is text for 2?',
          'table_file': 'table_csv/1-0000001-1.csv',
          'answer_coordinates': "['(1, 0)']",
          'aggregation': '',
          'answer_text': "['B']",
          'float_answer': '',
      },
      {
          'id': 'dev-1',
          'annotator': '0',
          'position': '0',
          'question': 'What is sum when number is greater than 1?',
          'table_file': 'table_csv/1-0000001-1.csv',
          'answer_coordinates': "['(1, 1)', '(2, 1)']",
          'aggregation': 'SUM',
          'answer_text': "['5.0']",
          'float_answer': '5.0',
      },
  ]

  with tempfile.TemporaryDirectory() as input_dir, \
      tempfile.TemporaryDirectory() as output_dir:
    _create_inputs(input_dir, tables=tables, examples=examples)
    wikisql_utils.convert(input_dir=input_dir, output_dir=output_dir)

    # Both output files must be read while the temp directories still exist.
    csv_path = os.path.join(
        output_dir,
        wikisql_utils._TABLE_DIR_NAME,
        '1-0000001-1.csv',
    )
    with tf.io.gfile.GFile(csv_path) as table_file:
      table_rows = [dict(row) for row in csv.DictReader(table_file)]
    self.assertEqual(expected_table_rows, table_rows)

    dev_path = os.path.join(output_dir, 'dev.tsv')
    with tf.io.gfile.GFile(dev_path) as dev_file:
      dev_rows = list(csv.DictReader(dev_file, delimiter='\t'))
    logging.info(dev_rows)
    for index, expected_row in enumerate(expected_dev_rows):
      self.assertEqual(expected_row, dict(dev_rows[index]))