예제 #1
0
def create_interactions(task, input_dir, output_dir):
    """Converts original task data to interactions.

    Interactions will be written to f'{output_dir}/interactions'. Other files
    might be written as well.

    Args:
      task: The current task.
      input_dir: Data with original task data.
      output_dir: Outputs are written to this directory.

    Raises:
      ValueError: If `task` is not SQA, WTQ, WIKISQL or WIKISQL_SUPERVISED.
    """
    if task == tasks.Task.SQA:
        # SQA input is used as-is: input_dir is handed directly to the TSV step.
        tsv_dir = input_dir
    elif task == tasks.Task.WTQ:
        wtq_utils.convert(input_dir, output_dir)
        tsv_dir = output_dir
    elif task in (tasks.Task.WIKISQL, tasks.Task.WIKISQL_SUPERVISED):
        # Both WikiSQL variants share one converter; within this function the
        # only task-dependent difference is get_supervision_modes(task) below.
        wikisql_utils.convert(input_dir, output_dir)
        tsv_dir = output_dir
    else:
        # getattr keeps this a ValueError even when `task` is not an enum
        # member (plain `task.name` would raise AttributeError instead).
        raise ValueError(f'Unknown task: {getattr(task, "name", task)}')
    sqa_utils.create_interactions(
        get_supervision_modes(task),
        tsv_dir,
        get_interaction_dir(output_dir),
    )
예제 #2
0
def create_interactions(
    task,
    input_dir,
    output_dir,
    token_selector,
):  # pylint: disable=g-doc-args
  """Converts original task data to interactions.

  Interactions will be written to f'{output_dir}/interactions'. Other files
  might be written as well.

  Args:
    task: The current task.
    input_dir: Data with original task data.
    output_dir: Outputs are written to this directory.
    token_selector: Optional helper class to keep more relevant tokens in input.

  Raises:
    ValueError: If `task` is not one of the tasks handled below.
  """

  def to_tfrecord(interactions):
    """Forwards `interactions` to _to_tfrecord with this call's output_dir and token_selector."""
    _to_tfrecord(interactions, output_dir, token_selector)

  if task == tasks.Task.SQA:
    # SQA input is used as-is: input_dir is handed directly to the TSV step.
    tsv_dir = input_dir
  elif task == tasks.Task.WTQ:
    wtq_utils.convert(input_dir, output_dir)
    tsv_dir = output_dir
  elif task in (tasks.Task.WIKISQL, tasks.Task.WIKISQL_SUPERVISED):
    # Both WikiSQL variants share one converter; within this function the
    # only task-dependent difference is get_supervision_modes(task) below.
    wikisql_utils.convert(input_dir, output_dir)
    tsv_dir = output_dir
  # The tasks below write their output through to_tfrecord and return early,
  # skipping the shared sqa_utils.create_interactions step at the end.
  elif task == tasks.Task.TABFACT:
    to_tfrecord(tabfact_utils.convert(input_dir))
    return
  elif task == tasks.Task.HYBRIDQA:
    to_tfrecord(hybridqa_utils.convert(input_dir))
    return
  elif task == tasks.Task.HYBRIDQA_RC:
    to_tfrecord(hybridqa_rc_utils.convert(input_dir, output_dir))
    return
  elif task == tasks.Task.HYBRIDQA_E2E:
    # Reads from output_dir rather than input_dir; presumably it holds
    # predictions from an earlier HybridQA run -- TODO confirm.
    to_tfrecord(
        hybridqa_rc_utils.create_interactions_from_hybridqa_predictions(
            output_dir))
    return
  elif task == tasks.Task.SEM_TAB_FACT:
    to_tfrecord(sem_tab_fact_utils.convert(input_dir))
    return
  else:
    # getattr keeps this a ValueError even when `task` is not an enum
    # member (plain `task.name` would raise AttributeError instead).
    raise ValueError(f'Unknown task: {getattr(task, "name", task)}')
  sqa_utils.create_interactions(
      get_supervision_modes(task),
      tsv_dir,
      get_interaction_dir(output_dir),
      token_selector,
  )
예제 #3
0
파일: task_utils.py 프로젝트: dsbip/tapas
def create_interactions(
    task,
    input_dir,
    output_dir,
    token_selector,
):  # pylint: disable=g-doc-args
    """Converts original task data to interactions.

    Interactions will be written to f'{output_dir}/interactions'. Other files
    might be written as well.

    Args:
      task: The current task.
      input_dir: Data with original task data.
      output_dir: Outputs are written to this directory.
      token_selector: Optional helper class to keep more relevant tokens in input.

    Raises:
      ValueError: If `task` is not SQA, WTQ, WIKISQL, WIKISQL_SUPERVISED or
        TABFACT.
    """
    # NOTE(review): an unused local helper `to_json(config, output_dir)` was
    # removed here; it was never called and passed the uncalled
    # `config.asdict` to json.dump.

    def to_tfrecord(interactions):
        """Forwards `interactions` to _to_tfrecord with this call's output_dir and token_selector."""
        _to_tfrecord(interactions, output_dir, token_selector)

    if task == tasks.Task.SQA:
        # SQA input is used as-is: input_dir is handed directly to the TSV step.
        tsv_dir = input_dir
    elif task == tasks.Task.WTQ:
        wtq_utils.convert(input_dir, output_dir)
        tsv_dir = output_dir
    elif task in (tasks.Task.WIKISQL, tasks.Task.WIKISQL_SUPERVISED):
        # Both WikiSQL variants share one converter; within this function the
        # only task-dependent difference is get_supervision_modes(task) below.
        wikisql_utils.convert(input_dir, output_dir)
        tsv_dir = output_dir
    elif task == tasks.Task.TABFACT:
        # TabFact goes straight through to_tfrecord and skips the shared
        # sqa_utils.create_interactions step below.
        to_tfrecord(tabfact_utils.convert(input_dir))
        return
    else:
        # getattr keeps this a ValueError even when `task` is not an enum
        # member (plain `task.name` would raise AttributeError instead).
        raise ValueError(f'Unknown task: {getattr(task, "name", task)}')
    sqa_utils.create_interactions(
        get_supervision_modes(task),
        tsv_dir,
        get_interaction_dir(output_dir),
        token_selector,
    )
예제 #4
0
    def test_simple_test(self):
        """WikiSQL conversion writes the table CSV and the expected dev TSV rows."""
        # A single table with one text column and one numeric column.
        tables = [{
            'id': '1-0000001-1',
            'header': ['Text', 'Number'],
            'types': ['text', 'real'],
            'rows': [['A', 1], ['B', 2], ['C', 3]],
        }]
        examples = [
            {
                'question': 'What is text for 2?',
                'table_id': '1-0000001-1',
                'sql': {
                    'agg': 0,  # No aggregation
                    'sel': 0,  # Text column
                    'conds': [[1, 0, 2]],  # Column 1 = 2
                },
            },
            {
                'question': 'What is sum when number is greater than 1?',
                'table_id': '1-0000001-1',
                'sql': {
                    'agg': 4,  # SUM
                    'sel': 1,  # Number column
                    'conds': [[1, 1, 1]],  # Column 1 > 1
                },
            },
        ]

        # Rows come back through csv.DictReader, so every cell is a string.
        expected_table = [
            {'Text': 'A', 'Number': '1'},
            {'Text': 'B', 'Number': '2'},
            {'Text': 'C', 'Number': '3'},
        ]
        expected_dev_rows = [
            {
                'id': 'dev-0',
                'annotator': '0',
                'position': '0',
                'question': 'What is text for 2?',
                'table_file': 'table_csv/1-0000001-1.csv',
                'answer_coordinates': "['(1, 0)']",
                'aggregation': '',
                'answer_text': "['B']",
                'float_answer': '',
            },
            {
                'id': 'dev-1',
                'annotator': '0',
                'position': '0',
                'question': 'What is sum when number is greater than 1?',
                'table_file': 'table_csv/1-0000001-1.csv',
                'answer_coordinates': "['(1, 1)', '(2, 1)']",
                'aggregation': 'SUM',
                'answer_text': "['5.0']",
                'float_answer': '5.0',
            },
        ]

        with tempfile.TemporaryDirectory() as input_dir:
            with tempfile.TemporaryDirectory() as output_dir:
                _create_inputs(input_dir, tables=tables, examples=examples)
                wikisql_utils.convert(
                    input_dir=input_dir, output_dir=output_dir)

                csv_path = os.path.join(
                    output_dir, wikisql_utils._TABLE_DIR_NAME,
                    '1-0000001-1.csv')
                with tf.io.gfile.GFile(csv_path) as csv_handle:
                    table_rows = [
                        dict(record) for record in csv.DictReader(csv_handle)
                    ]
                    self.assertEqual(expected_table, table_rows)

                dev_path = os.path.join(output_dir, 'dev.tsv')
                with tf.io.gfile.GFile(dev_path) as dev_handle:
                    dev_rows = list(csv.DictReader(dev_handle, delimiter='\t'))
                    logging.info(dev_rows)
                    # Checked one row at a time, in order, as individual asserts.
                    for index, expected_row in enumerate(expected_dev_rows):
                        self.assertEqual(expected_row, dict(dev_rows[index]))