Exemplo n.º 1
0
def create_interactions(task, input_dir, output_dir):
    """Turns the original data of a task into interaction files.

    Interactions are written to f'{output_dir}/interactions'. For tasks that
    require a conversion step, the converter also writes into output_dir.

    Args:
      task: The task to process.
      input_dir: Directory holding the original task data.
      output_dir: Directory that receives the outputs.

    Raises:
      ValueError: If the task is not handled by this function.
    """
    if task == tasks.Task.SQA:
        # SQA needs no conversion step; its input directory is used directly.
        tsv_dir = input_dir
    else:
        if task == tasks.Task.WTQ:
            converter = wtq_utils.convert
        elif (task == tasks.Task.WIKISQL
              or task == tasks.Task.WIKISQL_SUPERVISED):
            # Both WikiSQL variants go through the same converter.
            converter = wikisql_utils.convert
        else:
            raise ValueError(f'Unknown task: {task.name}')
        converter(input_dir, output_dir)
        tsv_dir = output_dir

    supervision_modes = get_supervision_modes(task)
    interaction_dir = get_interaction_dir(output_dir)
    sqa_utils.create_interactions(supervision_modes, tsv_dir, interaction_dir)
Exemplo n.º 2
0
def create_interactions(
    task,
    input_dir,
    output_dir,
    token_selector,
):  # pylint: disable=g-doc-args
  """Turns the original data of a task into interactions.

  Interactions are written to f'{output_dir}/interactions'. Other files might
  be written as well.

  Tasks fall into two groups: those converted to TSV files and then handed to
  sqa_utils.create_interactions, and those whose converter output is passed
  straight to _to_tfrecord.

  Args:
    task: The task to process.
    input_dir: Directory holding the original task data.
    output_dir: Directory that receives the outputs.
    token_selector: Optional helper class to keep more relevant tokens in
      input; forwarded unchanged to the writer.

  Raises:
    ValueError: If the task is not handled by this function.
  """
  if task == tasks.Task.SQA:
    # SQA needs no conversion step; its input directory is used directly.
    tsv_dir = input_dir
  elif task == tasks.Task.WTQ:
    wtq_utils.convert(input_dir, output_dir)
    tsv_dir = output_dir
  elif (task == tasks.Task.WIKISQL or
        task == tasks.Task.WIKISQL_SUPERVISED):
    # Both WikiSQL variants go through the same converter.
    wikisql_utils.convert(input_dir, output_dir)
    tsv_dir = output_dir
  else:
    # Remaining tasks produce their result directly from a converter; it is
    # written via _to_tfrecord and the TSV-based step below is skipped.
    if task == tasks.Task.TABFACT:
      converted = tabfact_utils.convert(input_dir)
    elif task == tasks.Task.HYBRIDQA:
      converted = hybridqa_utils.convert(input_dir)
    elif task == tasks.Task.HYBRIDQA_RC:
      converted = hybridqa_rc_utils.convert(input_dir, output_dir)
    elif task == tasks.Task.HYBRIDQA_E2E:
      # This variant reads from output_dir rather than input_dir.
      converted = (
          hybridqa_rc_utils.create_interactions_from_hybridqa_predictions(
              output_dir))
    elif task == tasks.Task.SEM_TAB_FACT:
      converted = sem_tab_fact_utils.convert(input_dir)
    else:
      raise ValueError(f'Unknown task: {task.name}')
    _to_tfrecord(converted, output_dir, token_selector)
    return

  supervision_modes = get_supervision_modes(task)
  interaction_dir = get_interaction_dir(output_dir)
  sqa_utils.create_interactions(
      supervision_modes,
      tsv_dir,
      interaction_dir,
      token_selector,
  )
Exemplo n.º 3
0
def create_interactions(
    task,
    input_dir,
    output_dir,
    token_selector,
):  # pylint: disable=g-doc-args
    """Converts original task data to interactions.

    Interactions will be written to f'{output_dir}/interactions'. Other files
    might be written as well.

    Args:
      task: The current task.
      input_dir: Directory with the original task data.
      output_dir: Outputs are written to this directory.
      token_selector: Optional helper class to keep more relevant tokens in
        input.

    Raises:
      ValueError: If the task is not handled by this function.
    """

    def to_tfrecord(interactions):
        """Helper that binds the output_dir and token_selector arguments."""
        _to_tfrecord(interactions, output_dir, token_selector)

    if task == tasks.Task.SQA:
        # SQA needs no conversion step; its input directory is used directly.
        tsv_dir = input_dir
    elif task == tasks.Task.WTQ:
        wtq_utils.convert(input_dir, output_dir)
        tsv_dir = output_dir
    elif task == tasks.Task.WIKISQL:
        wikisql_utils.convert(input_dir, output_dir)
        tsv_dir = output_dir
    elif task == tasks.Task.WIKISQL_SUPERVISED:
        wikisql_utils.convert(input_dir, output_dir)
        tsv_dir = output_dir
    elif task == tasks.Task.TABFACT:
        # TabFact output goes straight to the tfrecord writer, so the
        # TSV-based step below is skipped.
        to_tfrecord(tabfact_utils.convert(input_dir))
        return
    else:
        raise ValueError(f'Unknown task: {task.name}')
    sqa_utils.create_interactions(
        get_supervision_modes(task),
        tsv_dir,
        get_interaction_dir(output_dir),
        token_selector,
    )
Exemplo n.º 4
0
    def test_simple_test(self):
        """Runs wtq_utils.convert on one table and one example.

        Checks the set of files created in the output directory, the content
        of the written table CSV and the first row of 'test.tsv'.
        """
        table = {
            'table_id': 'csv/203-csv/515.csv',
            'columns': ['Text', 'Number'],
            'rows': [['A', 1], ['B', 2], ['тапас', 3]],
        }
        example = {
            'id': 'nt-2',
            'utterance': 'What is text for 2?',
            'context': 'csv/203-csv/515.csv',
            'targetValue': 'B',
        }

        # Five random train/dev splits plus the table directory and the
        # plain train/test files.
        expected_outputs = [
            'random-split-1-dev.tsv',
            'random-split-1-train.tsv',
            'random-split-2-dev.tsv',
            'random-split-2-train.tsv',
            'random-split-3-dev.tsv',
            'random-split-3-train.tsv',
            'random-split-4-dev.tsv',
            'random-split-4-train.tsv',
            'random-split-5-dev.tsv',
            'random-split-5-train.tsv',
            'table_csv',
            'test.tsv',
            'train.tsv',
        ]
        # The ASCII text cells are expected in lower case; numbers are read
        # back as strings by csv.DictReader.
        expected_table_rows = [
            {'Text': 'a', 'Number': '1'},
            {'Text': 'b', 'Number': '2'},
            {'Text': 'тапас', 'Number': '3'},
        ]
        expected_first_example = {
            'id': 'nt-2',
            'annotator': '0',
            'position': '0',
            'question': 'What is text for 2?',
            'table_file': 'table_csv/203-515.csv',
            'answer_coordinates': "['(-1, -1)']",
            'aggregation': 'NONE',
            'answer_text': "['B']",
            'float_answer': '',
        }

        with tempfile.TemporaryDirectory() as input_dir, \
                tempfile.TemporaryDirectory() as output_dir:
            _create_inputs(input_dir, tables=[table], examples=[example])

            wtq_utils.convert(input_dir=input_dir, output_dir=output_dir)

            table_dir = os.path.join(output_dir, wtq_utils._TABLE_DIR_NAME)
            self.assertCountEqual(tf.io.gfile.listdir(output_dir),
                                  expected_outputs)
            self.assertEqual(tf.io.gfile.listdir(table_dir),
                             ['203-515.csv'])

            table_path = os.path.join(table_dir, '203-515.csv')
            with tf.io.gfile.GFile(table_path) as table_file:
                table_rows = [
                    dict(row) for row in csv.DictReader(table_file)
                ]
                self.assertEqual(expected_table_rows, table_rows)

            examples_path = os.path.join(output_dir, 'test.tsv')
            with tf.io.gfile.GFile(examples_path) as examples_file:
                example_rows = list(
                    csv.DictReader(examples_file, delimiter='\t'))
                logging.info(example_rows)
                self.assertEqual(expected_first_example,
                                 dict(example_rows[0]))