Ejemplo n.º 1
0
 def test_parse_example_line(self):
     """Check that the first line of the sample examples fixture parses into the expected dict."""
     # pylint: disable=no-self-use,protected-access
     examples_path = self.FIXTURES_ROOT / "data" / "wikitables" / "sample_data.examples"
     with open(examples_path) as examples_file:
         # Only the first example line is inspected by this test.
         first_line = examples_file.readlines()[0]
     expected_info = {
             'id': 'nt-0',
             'question': 'what was the last year where this team was a part of the usl a-league?',
             'table_filename': 'tables/590.csv',
     }
     assert WikiTablesDatasetReader._parse_example_line(first_line) == expected_info
Ejemplo n.º 2
0
 def test_parse_example_line(self):
     """Parse the first sample example line and compare its id, question and table filename."""
     # pylint: disable=no-self-use,protected-access
     fixture_path = self.FIXTURES_ROOT / "data" / "wikitables" / "sample_data.examples"
     with open(fixture_path) as fixture_file:
         example_lines = fixture_file.readlines()
     parsed = WikiTablesDatasetReader._parse_example_line(example_lines[0])
     expected_question = 'what was the last year where this team was a part of the usl a-league?'
     expected = dict(id='nt-0',
                     question=expected_question,
                     table_filename='tables/590.csv')
     assert parsed == expected
Ejemplo n.º 3
0
def make_data(input_examples_file: str,
              tables_directory: str,
              archived_model_file: str,
              output_dir: str,
              num_logical_forms: int) -> None:
    """
    Decode every example with an archived model and save the logical forms that evaluate
    correctly.

    For each example, the model's output logical forms are checked in order with
    ``model._denotation_accuracy.evaluate_logical_form`` and at most ``num_logical_forms``
    accepted ones are kept.  They are written, one per line and UTF-8 encoded, to
    ``<output_dir>/<example_id>.gz``.  Examples with no accepted logical form produce no file.
    A "<count> found for <example_id>" line is printed for every example.

    Parameters
    ----------
    input_examples_file : ``str``
        Path of the examples file; it is read by the dataset reader and also line by line here.
    tables_directory : ``str``
        Tables directory passed to the dataset reader and used to override the archived
        model's ``tables_directory`` setting.
    archived_model_file : ``str``
        Path of the model archive passed to ``load_archive``.
    output_dir : ``str``
        Directory for the ``.gz`` outputs; created on demand when the first file is written.
    num_logical_forms : ``int``
        Maximum number of correct logical forms to keep per example.
    """
    reader = WikiTablesDatasetReader(tables_directory=tables_directory,
                                     keep_if_no_dpd=True,
                                     output_agendas=True)
    dataset = reader.read(input_examples_file)
    with open(input_examples_file) as input_file:
        input_lines = input_file.readlines()
    # Note: Double { for escaping {.
    new_tables_config = f"{{model: {{tables_directory: {tables_directory}}}}}"
    archive = load_archive(archived_model_file,
                           overrides=new_tables_config)
    model = archive.model
    # NOTE(review): this only sets the flag on the top-level object; if the model is a
    # ``torch.nn.Module`` with submodules, ``model.eval()`` may be what is intended - confirm.
    model.training = False
    model._decoder_trainer._max_num_decoded_sequences = 100
    # NOTE(review): assumes ``reader.read`` yields exactly one instance per line of the input
    # file, in file order; otherwise instances and lines would be misaligned - confirm.
    for instance, example_line in zip(dataset, input_lines):
        outputs = model.forward_on_instance(instance)
        parsed_info = reader._parse_example_line(example_line)
        example_id = parsed_info["id"]
        logical_forms = outputs["logical_form"]
        correct_logical_forms = []
        for logical_form in logical_forms:
            if model._denotation_accuracy.evaluate_logical_form(logical_form, example_line):
                correct_logical_forms.append(logical_form)
                # Stop evaluating as soon as enough correct logical forms were collected.
                if len(correct_logical_forms) >= num_logical_forms:
                    break
        num_found = len(correct_logical_forms)
        print(f"{num_found} found for {example_id}")
        if num_found == 0:
            continue
        # ``exist_ok`` avoids the check-then-create race of a separate ``os.path.exists`` test.
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, f"{example_id}.gz")
        # The context manager closes the gzip stream even if a write fails.
        with gzip.open(output_path, "wb") as output_file:
            for logical_form in correct_logical_forms:
                output_file.write((logical_form + "\n").encode('utf-8'))