def test_parse_example_line(self):  # pylint: disable=no-self-use,protected-access
    """Check that ``_parse_example_line`` extracts id, question and table path from a line."""
    examples_path = self.FIXTURES_ROOT / "data" / "wikitables" / "sample_data.examples"
    with open(examples_path) as examples_file:
        all_lines = examples_file.readlines()
    parsed = WikiTablesDatasetReader._parse_example_line(all_lines[0])
    # Expected fields for the first fixture example.
    expected = {
            'id': 'nt-0',
            'question': 'what was the last year where this team was a part of the usl a-league?',
            'table_filename': 'tables/590.csv',
    }
    assert parsed == expected
def make_data(input_examples_file: str,
              tables_directory: str,
              archived_model_file: str,
              output_dir: str,
              num_logical_forms: int,
              max_num_decoded_sequences: int = 100) -> None:
    """
    Decode each example with an archived model and save the logical forms judged correct.

    Every example in ``input_examples_file`` is read into an instance by a
    ``WikiTablesDatasetReader`` and run through the model loaded from
    ``archived_model_file``. Decoded logical forms for which
    ``model._denotation_accuracy.evaluate_logical_form`` returns a truthy value are kept,
    up to ``num_logical_forms`` per example. They are written UTF-8 encoded, one per
    line, to the gzip file ``<output_dir>/<example_id>.gz``. The number found is printed
    for every example, and examples with none found produce no file.

    Parameters
    ----------
    input_examples_file : str
        Examples file. It is read both by the dataset reader and line by line here.
    tables_directory : str
        Passed to the reader and overridden into the archived model's config.
    archived_model_file : str
        Path of the model archive to load.
    output_dir : str
        Directory for the per-example ``.gz`` files. It is created on first write if
        missing.
    num_logical_forms : int
        Stop collecting correct logical forms for an example once this many are found.
    max_num_decoded_sequences : int, optional (default = 100)
        Value assigned to the decoder trainer's ``_max_num_decoded_sequences`` before
        decoding. Presumably this bounds the candidate forms per example (TODO confirm).
    """
    # pylint: disable=protected-access
    reader = WikiTablesDatasetReader(tables_directory=tables_directory,
                                     keep_if_no_dpd=True,
                                     output_agendas=True)
    dataset = reader.read(input_examples_file)
    with open(input_examples_file) as input_file:
        input_lines = input_file.readlines()
    # Note: Double { for escaping {.
    new_tables_config = f"{{model: {{tables_directory: {tables_directory}}}}}"
    archive = load_archive(archived_model_file, overrides=new_tables_config)
    model = archive.model
    # ``eval()`` also switches every submodule (e.g. dropout) to inference mode.
    # Assigning ``model.training = False`` would only change the root module's flag.
    model.eval()
    model._decoder_trainer._max_num_decoded_sequences = max_num_decoded_sequences
    # NOTE(review): zip assumes the reader yields exactly one instance per input line,
    # in the same order, so instances and raw lines stay aligned. Confirm with the reader.
    for instance, example_line in zip(dataset, input_lines):
        outputs = model.forward_on_instance(instance)
        parsed_info = reader._parse_example_line(example_line)
        example_id = parsed_info["id"]
        correct_logical_forms = []
        for logical_form in outputs["logical_form"]:
            if model._denotation_accuracy.evaluate_logical_form(logical_form, example_line):
                correct_logical_forms.append(logical_form)
                if len(correct_logical_forms) >= num_logical_forms:
                    break
        num_found = len(correct_logical_forms)
        print(f"{num_found} found for {example_id}")
        if num_found == 0:
            continue
        # exist_ok avoids the exists()/makedirs() race of a separate existence check.
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, f"{example_id}.gz")
        # The context manager guarantees the gzip stream is flushed and closed even on error.
        with gzip.open(output_path, "wb") as output_file:
            for logical_form in correct_logical_forms:
                output_file.write((logical_form + "\n").encode('utf-8'))