def create_interactions(task, input_dir, output_dir):
  """Converts original task data to interactions.

  Interactions will be written to f'{output_dir}/interactions'. Other files
  might be written as well.

  Args:
    task: The current task.
    input_dir: Data with original task data.
    output_dir: Outputs are written to this directory.

  Raises:
    ValueError: If the task is not supported.
  """
  # Tasks whose raw data must first be converted into SQA-style TSV files.
  tsv_converters = {
      tasks.Task.WTQ: wtq_utils.convert,
      tasks.Task.WIKISQL: wikisql_utils.convert,
      tasks.Task.WIKISQL_SUPERVISED: wikisql_utils.convert,
  }
  if task == tasks.Task.SQA:
    # SQA data is already in TSV form; read it straight from the input.
    tsv_dir = input_dir
  elif task in tsv_converters:
    tsv_converters[task](input_dir, output_dir)
    tsv_dir = output_dir
  else:
    raise ValueError(f'Unknown task: {task.name}')
  sqa_utils.create_interactions(
      get_supervision_modes(task),
      tsv_dir,
      get_interaction_dir(output_dir),
  )
def create_interactions(
    task,
    input_dir,
    output_dir,
    token_selector,
):  # pylint: disable=g-doc-args
  """Converts original task data to interactions.

  Interactions will be written to f'{output_dir}/interactions'. Other files
  might be written as well.

  Args:
    task: The current task.
    input_dir: Data with original task data.
    output_dir: Outputs are written to this directory.
    token_selector: Optional helper class to keep more relevant tokens in
      input.
  """

  def to_tfrecord(interactions):
    """Helper function that binds output dir and token_selector arguments."""
    _to_tfrecord(interactions, output_dir, token_selector)

  # Tasks whose raw data is first rewritten as SQA-style TSV files; the
  # shared SQA pipeline below then turns the TSVs into interactions.
  tsv_converters = {
      tasks.Task.WTQ: wtq_utils.convert,
      tasks.Task.WIKISQL: wikisql_utils.convert,
      tasks.Task.WIKISQL_SUPERVISED: wikisql_utils.convert,
  }
  if task == tasks.Task.SQA:
    tsv_dir = input_dir
  elif task in tsv_converters:
    tsv_converters[task](input_dir, output_dir)
    tsv_dir = output_dir
  elif task == tasks.Task.TABFACT:
    # The remaining tasks produce interactions directly and return early.
    to_tfrecord(tabfact_utils.convert(input_dir))
    return
  elif task == tasks.Task.HYBRIDQA:
    to_tfrecord(hybridqa_utils.convert(input_dir))
    return
  elif task == tasks.Task.HYBRIDQA_RC:
    to_tfrecord(hybridqa_rc_utils.convert(input_dir, output_dir))
    return
  elif task == tasks.Task.HYBRIDQA_E2E:
    to_tfrecord(
        hybridqa_rc_utils.create_interactions_from_hybridqa_predictions(
            output_dir))
    return
  elif task == tasks.Task.SEM_TAB_FACT:
    to_tfrecord(sem_tab_fact_utils.convert(input_dir))
    return
  else:
    raise ValueError(f'Unknown task: {task.name}')
  sqa_utils.create_interactions(
      get_supervision_modes(task),
      tsv_dir,
      get_interaction_dir(output_dir),
      token_selector,
  )
def create_interactions(
    task,
    input_dir,
    output_dir,
    token_selector,
):  # pylint: disable=g-doc-args
  """Converts original task data to interactions.

  Interactions will be written to f'{output_dir}/interactions'. Other files
  might be written as well.

  Args:
    task: The current task.
    input_dir: Data with original task data.
    output_dir: Outputs are written to this directory.
    token_selector: Optional helper class to keep more relevant tokens in
      input.

  Raises:
    ValueError: If the task is not supported.
  """

  def to_tfrecord(interactions):
    """Helper function that binds output dir and token_selector arguments."""
    _to_tfrecord(interactions, output_dir, token_selector)

  def to_json(config, output_dir):
    """Serializes `config` to hybridqa_rc_config.json under `output_dir`.

    NOTE(review): currently unused in this function — presumably kept for a
    task branch added later; confirm before deleting.
    """
    config_filename = os.path.join(output_dir, 'hybridqa_rc_config.json')
    with tf.io.gfile.GFile(config_filename, 'w') as fp:
      # Bug fix: the method must be *called* — passing the bound method
      # `config.asdict` to json.dump raises TypeError (method objects are
      # not JSON serializable).
      json.dump(config.asdict(), fp, indent=4, sort_keys=True)

  if task == tasks.Task.SQA:
    # SQA data is already in TSV form; read it straight from the input.
    tsv_dir = input_dir
  elif task == tasks.Task.WTQ:
    wtq_utils.convert(input_dir, output_dir)
    tsv_dir = output_dir
  elif task == tasks.Task.WIKISQL:
    wikisql_utils.convert(input_dir, output_dir)
    tsv_dir = output_dir
  elif task == tasks.Task.WIKISQL_SUPERVISED:
    wikisql_utils.convert(input_dir, output_dir)
    tsv_dir = output_dir
  elif task == tasks.Task.TABFACT:
    # TABFACT produces interactions directly; no TSV intermediate step.
    to_tfrecord(tabfact_utils.convert(input_dir))
    return
  else:
    raise ValueError(f'Unknown task: {task.name}')
  sqa_utils.create_interactions(
      get_supervision_modes(task),
      tsv_dir,
      get_interaction_dir(output_dir),
      token_selector,
  )
def test_simple_test(self):
  # End-to-end check of wtq_utils.convert: builds a tiny WTQ-style corpus
  # (one table, one example), runs the conversion, then verifies the emitted
  # directory layout, the rewritten table CSV, and the generated test.tsv row.
  with tempfile.TemporaryDirectory() as input_dir:
    with tempfile.TemporaryDirectory() as output_dir:
      # One table with a mixed ASCII / non-ASCII text column, and a single
      # example whose answer ('B') lives in row 2 of that table.
      _create_inputs(
          input_dir,
          tables=[{
              'table_id': 'csv/203-csv/515.csv',
              'columns': ['Text', 'Number'],
              'rows': [['A', 1], ['B', 2], ['тапас', 3]],
          }],
          examples=[
              {
                  'id': 'nt-2',
                  'utterance': 'What is text for 2?',
                  'context': 'csv/203-csv/515.csv',
                  'targetValue': 'B',
              },
          ])
      wtq_utils.convert(input_dir=input_dir, output_dir=output_dir)
      table_dir = os.path.join(output_dir, wtq_utils._TABLE_DIR_NAME)
      # convert() is expected to emit train/test TSVs, five random
      # dev/train split pairs, and a table_csv directory.
      self.assertCountEqual(tf.io.gfile.listdir(output_dir), [
          'random-split-1-dev.tsv',
          'random-split-1-train.tsv',
          'random-split-2-dev.tsv',
          'random-split-2-train.tsv',
          'random-split-3-dev.tsv',
          'random-split-3-train.tsv',
          'random-split-4-dev.tsv',
          'random-split-4-train.tsv',
          'random-split-5-dev.tsv',
          'random-split-5-train.tsv',
          'table_csv',
          'test.tsv',
          'train.tsv',
      ])
      # Table id 'csv/203-csv/515.csv' is flattened to '203-515.csv'.
      self.assertEqual(tf.io.gfile.listdir(table_dir), ['203-515.csv'])
      table_path = os.path.join(table_dir, '203-515.csv')
      with tf.io.gfile.GFile(table_path) as table_file:
        actual = [dict(row) for row in csv.DictReader(table_file)]
      # Per these expectations, conversion lowercases ASCII cell text
      # ('A' -> 'a') and passes the Cyrillic string through unchanged;
      # numbers are stringified.
      self.assertEqual([{
          'Text': 'a',
          'Number': '1',
      }, {
          'Text': 'b',
          'Number': '2',
      }, {
          'Text': 'тапас',
          'Number': '3'
      }], actual)
      filename = os.path.join(output_dir, 'test.tsv')
      with tf.io.gfile.GFile(filename) as dev_file:
        actual = list(csv.DictReader(dev_file, delimiter='\t'))
      logging.info(actual)
      # SQA-format row: (-1, -1) coordinates mark an answer given only as
      # text (no cell coordinates); float_answer is empty for a text answer.
      self.assertEqual(
          {
              'id': 'nt-2',
              'annotator': '0',
              'position': '0',
              'question': 'What is text for 2?',
              'table_file': 'table_csv/203-515.csv',
              'answer_coordinates': "['(-1, -1)']",
              'aggregation': 'NONE',
              'answer_text': "['B']",
              'float_answer': '',
          }, dict(actual[0]))