def test_reader_reads_preprocessed_file(self):
    # We should get the exact same results when reading a pre-processed file as we get when
    # we read the original data.
    reader = WikiTablesDatasetReader()
    dataset = reader.read(
        "tests/fixtures/data/wikitables/sample_data_preprocessed.jsonl")
    assert_dataset_correct(dataset)
Example #2
    def test_production_rule_field_with_multiple_gpus(self):
        wikitables_dir = 'allennlp/tests/fixtures/data/wikitables/'
        offline_lf_directory = wikitables_dir + 'action_space_walker_output/'
        wikitables_reader = WikiTablesDatasetReader(
            tables_directory=wikitables_dir,
            offline_logical_forms_directory=offline_lf_directory)
        instances = wikitables_reader.read(wikitables_dir +
                                           'sample_data.examples')
        archive_path = self.FIXTURES_ROOT / 'semantic_parsing' / 'wikitables' / 'serialization' / 'model.tar.gz'
        model = load_archive(archive_path).model
        model.cuda()

        multigpu_iterator = BasicIterator(batch_size=4)
        multigpu_iterator.index_with(model.vocab)

        trainer = CallbackTrainer(model,
                                  self.optimizer,
                                  num_epochs=2,
                                  cuda_device=[0, 1],
                                  callbacks=[
                                      GenerateTrainingBatches(
                                          instances, multigpu_iterator),
                                      TrainSupervised()
                                  ])
        trainer.train()
Example #3
    def test_production_rule_field_with_multiple_gpus(self):
        wikitables_dir = "allennlp/tests/fixtures/data/wikitables/"
        offline_lf_directory = wikitables_dir + "action_space_walker_output/"
        wikitables_reader = WikiTablesDatasetReader(
            tables_directory=wikitables_dir,
            offline_logical_forms_directory=offline_lf_directory)
        instances = wikitables_reader.read(wikitables_dir +
                                           "sample_data.examples")
        archive_path = (self.FIXTURES_ROOT / "semantic_parsing" /
                        "wikitables" / "serialization" / "model.tar.gz")
        model = load_archive(archive_path).model
        model.cuda()

        multigpu_iterator = BasicIterator(batch_size=4)
        multigpu_iterator.index_with(model.vocab)

        trainer = CallbackTrainer(
            model,
            instances,
            multigpu_iterator,
            self.optimizer,
            num_epochs=2,
            cuda_device=[0, 1],
            callbacks=[GradientNormAndClip()],
        )
        trainer.train()
Example #4
def test_reader_reads_preprocessed_file(self):
    # We should get the exact same results when reading a pre-processed file as we get when
    # we read the original data.
    reader = WikiTablesDatasetReader()
    dataset = reader.read(
        str(self.FIXTURES_ROOT / "data" / "wikitables" /
            "sample_data_preprocessed.jsonl"))
    assert_dataset_correct(dataset)
Example #5
    def test_production_rule_field_with_multiple_gpus(self):
        wikitables_dir = 'allennlp/tests/fixtures/data/wikitables/'
        wikitables_reader = WikiTablesDatasetReader(tables_directory=wikitables_dir,
                                                    dpd_output_directory=wikitables_dir + 'dpd_output/')
        instances = wikitables_reader.read(wikitables_dir + 'sample_data.examples')
        archive_path = self.FIXTURES_ROOT / 'semantic_parsing' / 'wikitables' / 'serialization' / 'model.tar.gz'
        model = load_archive(archive_path).model
        model.cuda()

        multigpu_iterator = BasicIterator(batch_size=4)
        multigpu_iterator.index_with(model.vocab)
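        # Old positional Trainer API: (model, optimizer, iterator, train_dataset);
        # cuda_device=[0, 1] requests data-parallel training on both GPUs.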
        trainer = Trainer(model, self.optimizer, multigpu_iterator, instances, num_epochs=2, cuda_device=[0, 1])
        trainer.train()
Example #6
def test_reader_reads(self):
    params = {
        'lazy': False,
        'tables_directory': self.FIXTURES_ROOT / "data" / "wikitables",
        'dpd_output_directory': self.FIXTURES_ROOT / "data" / "wikitables" / "dpd_output",
    }
    reader = WikiTablesDatasetReader.from_params(Params(params))
    dataset = reader.read(str(self.FIXTURES_ROOT / "data" / "wikitables" / "sample_data.examples"))
    assert_dataset_correct(dataset)
Example #7
def test_parse_example_line(self):
    # pylint: disable=no-self-use,protected-access
    with open(self.FIXTURES_ROOT / "data" / "wikitables" / "sample_data.examples") as filename:
        lines = filename.readlines()
    example_info = WikiTablesDatasetReader._parse_example_line(lines[0])
    question = 'what was the last year where this team was a part of the usl a-league?'
    assert example_info == {'id': 'nt-0',
                            'question': question,
                            'table_filename': 'tables/590.csv'}
Example #10
def make_data(input_examples_file: str, tables_directory: str,
              archived_model_file: str, output_dir: str,
              num_logical_forms: int, variable_free: bool) -> None:
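    # For each example, decode candidate logical forms with a trained model, keep those
    # that evaluate to the correct denotation (up to `num_logical_forms` of them), and
    # write them gzipped to `output_dir`, one file per example id.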
    if variable_free:
        reader = WikiTablesVariableFreeDatasetReader(
            tables_directory=tables_directory,
            keep_if_no_logical_forms=True,
            output_agendas=True)
    else:
        reader = WikiTablesDatasetReader(tables_directory=tables_directory,
                                         keep_if_no_dpd=True,
                                         output_agendas=True)
    dataset = reader.read(input_examples_file)
    input_lines = []
    with open(input_examples_file) as input_file:
        input_lines = input_file.readlines()
    if variable_free:
        new_tables_config = {}
    else:
        # Note: Double { for escaping {.
        new_tables_config = f"{{model: {{tables_directory: {tables_directory}}}}}"
    archive = load_archive(archived_model_file, overrides=new_tables_config)
    model = archive.model
    model.training = False
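    # Private attribute (assumed to exist on this model's decoder trainer): decode a
    # larger pool of candidate sequences per instance, so the denotation check below
    # has more logical forms to filter.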
    model._decoder_trainer._max_num_decoded_sequences = 100
    for instance, example_line in zip(dataset, input_lines):
        outputs = model.forward_on_instance(instance)
        parsed_info = parse_example_line(example_line)
        example_id = parsed_info["id"]
        logical_forms = outputs["logical_form"]
        correct_logical_forms = []
        for logical_form in logical_forms:
            if variable_free:
                world = instance.fields["world"].metadata
                target_values = instance.fields["target_values"].metadata
                logical_form_is_correct = world.evaluate_logical_form(
                    logical_form, target_values)
            else:
                logical_form_is_correct = model._executor.evaluate_logical_form(
                    logical_form, example_line)
            if logical_form_is_correct:
                correct_logical_forms.append(logical_form)
                if len(correct_logical_forms) >= num_logical_forms:
                    break
        num_found = len(correct_logical_forms)
        print(f"{num_found} found for {example_id}")
        if num_found == 0:
            continue
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        output_file = gzip.open(os.path.join(output_dir, f"{example_id}.gz"),
                                "wb")
        for logical_form in correct_logical_forms:
            logical_form_line = (logical_form + "\n").encode('utf-8')
            output_file.write(logical_form_line)
        output_file.close()
Example #11
def make_data(
    input_examples_file: str,
    tables_directory: str,
    archived_model_file: str,
    output_dir: str,
    num_logical_forms: int,
) -> None:
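    # Newer variant of the same script: the reader keeps examples that have no logical
    # forms, candidates are checked with `world.evaluate_logical_form` against the target
    # values parsed from each example line, and correct ones are written out gzipped.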
    reader = WikiTablesDatasetReader(
        tables_directory=tables_directory, keep_if_no_logical_forms=True, output_agendas=True
    )
    dataset = reader.read(input_examples_file)
    input_lines = []
    with open(input_examples_file) as input_file:
        input_lines = input_file.readlines()
    archive = load_archive(archived_model_file)
    model = archive.model
    model.training = False
    model._decoder_trainer._max_num_decoded_sequences = 100
    for instance, example_line in zip(dataset, input_lines):
        outputs = model.forward_on_instance(instance)
        world = instance.fields["world"].metadata
        parsed_info = util.parse_example_line(example_line)
        example_id = parsed_info["id"]
        target_list = parsed_info["target_values"]
        logical_forms = outputs["logical_form"]
        correct_logical_forms = []
        for logical_form in logical_forms:
            if world.evaluate_logical_form(logical_form, target_list):
                correct_logical_forms.append(logical_form)
                if len(correct_logical_forms) >= num_logical_forms:
                    break
        num_found = len(correct_logical_forms)
        print(f"{num_found} found for {example_id}")
        if num_found == 0:
            continue
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        output_file = gzip.open(os.path.join(output_dir, f"{example_id}.gz"), "wb")
        for logical_form in correct_logical_forms:
            logical_form_line = (logical_form + "\n").encode("utf-8")
            output_file.write(logical_form_line)
        output_file.close()
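A minimal invocation sketch for the function above; every path below is hypothetical and
only illustrates the expected argument types:

if __name__ == "__main__":
    # Hypothetical paths -- substitute your own examples file, tables directory,
    # trained model archive, and output location.
    make_data(
        input_examples_file="data/wikitables/train.examples",
        tables_directory="data/wikitables/",
        archived_model_file="models/wikitables_parser.tar.gz",
        output_dir="offline_search_output/",
        num_logical_forms=10,
    )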
Example #12
def test_reader_reads(self):
    offline_search_directory = self.FIXTURES_ROOT / "data" / "wikitables" / "action_space_walker_output"
    params = {
        'lazy': False,
        'tables_directory': self.FIXTURES_ROOT / "data" / "wikitables",
        'offline_logical_forms_directory': offline_search_directory,
    }
    reader = WikiTablesDatasetReader.from_params(Params(params))
    dataset = reader.read(self.FIXTURES_ROOT / "data" / "wikitables" /
                          "sample_data.examples")
    assert_dataset_correct(dataset)
Example #13
def test_reader_reads(self):
    params = {
        'lazy': False,
        'tables_directory': "tests/fixtures/data/wikitables",
        'dpd_output_directory': "tests/fixtures/data/wikitables/dpd_output",
    }
    reader = WikiTablesDatasetReader.from_params(Params(params))
    dataset = reader.read(
        "tests/fixtures/data/wikitables/sample_data.examples")
    assert_dataset_correct(dataset)
Example #14
    def test_read_respects_max_dpd_tries_when_not_sorting(self):
        tables_directory = self.FIXTURES_ROOT / "data" / "wikitables"
        dpd_output_directory = self.FIXTURES_ROOT / "data" / "wikitables" / "dpd_output"
        reader = WikiTablesDatasetReader(lazy=False,
                                         sort_dpd_logical_forms=False,
                                         max_dpd_logical_forms=1,
                                         max_dpd_tries=1,
                                         tables_directory=tables_directory,
                                         dpd_output_directory=dpd_output_directory)
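        # With sort_dpd_logical_forms=False and max_dpd_tries=1, the reader looks at only
        # the first logical form in each DPD file, in file order.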
        dataset = reader.read(str(self.FIXTURES_ROOT / "data" / "wikitables" / "sample_data.examples"))
        instances = list(dataset)
        instance = instances[0]
        actions = [action_field.rule for action_field in instance.fields['actions'].field_list]

        # We should have just taken the first logical form from the file, which has the following
        # action sequence.
        action_sequence = instance.fields["target_action_sequences"].field_list[0]
        action_indices = [l.sequence_index for l in action_sequence.field_list]
        action_strings = [actions[i] for i in action_indices]
        assert action_strings == [
                '@start@ -> d',
                'd -> [<c,d>, c]',
                '<c,d> -> [<<#1,#2>,<#2,#1>>, <d,c>]',
                '<<#1,#2>,<#2,#1>> -> reverse',
                '<d,c> -> fb:cell.cell.date',
                'c -> [<r,c>, r]',
                '<r,c> -> [<<#1,#2>,<#2,#1>>, <c,r>]',
                '<<#1,#2>,<#2,#1>> -> reverse',
                '<c,r> -> fb:row.row.year',
                'r -> [<n,r>, n]',
                '<n,r> -> fb:row.row.index',
                'n -> [<nd,nd>, n]',
                '<nd,nd> -> max',
                'n -> [<r,n>, r]',
                '<r,n> -> [<<#1,#2>,<#2,#1>>, <n,r>]',
                '<<#1,#2>,<#2,#1>> -> reverse',
                '<n,r> -> fb:row.row.index',
                'r -> [<c,r>, c]',
                '<c,r> -> fb:row.row.league',
                'c -> fb:cell.usl_a_league'
                ]
Example #15
def make_data(input_examples_file: str,
              tables_directory: str,
              archived_model_file: str,
              output_dir: str,
              num_logical_forms: int) -> None:
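    # DPD variant of the same script: keep examples even when no DPD output exists, check
    # each decoded logical form with the model's denotation-accuracy helper, and write
    # the correct ones out gzipped, one file per example id.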
    reader = WikiTablesDatasetReader(tables_directory=tables_directory,
                                     keep_if_no_dpd=True,
                                     output_agendas=True)
    dataset = reader.read(input_examples_file)
    input_lines = []
    with open(input_examples_file) as input_file:
        input_lines = input_file.readlines()
    # Note: Double { for escaping {.
    new_tables_config = f"{{model: {{tables_directory: {tables_directory}}}}}"
    archive = load_archive(archived_model_file,
                           overrides=new_tables_config)
    model = archive.model
    model.training = False
    model._decoder_trainer._max_num_decoded_sequences = 100
    for instance, example_line in zip(dataset, input_lines):
        outputs = model.forward_on_instance(instance)
        parsed_info = reader._parse_example_line(example_line)
        example_id = parsed_info["id"]
        logical_forms = outputs["logical_form"]
        correct_logical_forms = []
        for logical_form in logical_forms:
            if model._denotation_accuracy.evaluate_logical_form(logical_form, example_line):
                correct_logical_forms.append(logical_form)
                if len(correct_logical_forms) >= num_logical_forms:
                    break
        num_found = len(correct_logical_forms)
        print(f"{num_found} found for {example_id}")
        if num_found == 0:
            continue
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        output_file = gzip.open(os.path.join(output_dir, f"{example_id}.gz"), "wb")
        for logical_form in correct_logical_forms:
            logical_form_line = (logical_form + "\n").encode('utf-8')
            output_file.write(logical_form_line)
        output_file.close()
Example #17
def test_reader_reads_with_lfs_in_tarball(self):
    # Here the offline search output is packed into a single tarball instead of one
    # gzipped file per example.
    offline_search_directory = (
        self.FIXTURES_ROOT / "data" / "wikitables" /
        "action_space_walker_output_with_single_tarball")
    params = {
        "lazy": False,
        "tables_directory": self.FIXTURES_ROOT / "data" / "wikitables",
        "offline_logical_forms_directory": offline_search_directory,
    }
    reader = WikiTablesDatasetReader.from_params(Params(params))
    dataset = reader.read(self.FIXTURES_ROOT / "data" / "wikitables" /
                          "sample_data.examples")
    assert_dataset_correct(dataset)
Example #18
def test_reader_reads_preprocessed_file(self):
    # We should get the exact same results when reading a pre-processed file as we get when
    # we read the original data.
    reader = WikiTablesDatasetReader()
    dataset = reader.read(str(self.FIXTURES_ROOT / "data" / "wikitables" / "sample_data_preprocessed.jsonl"))
    assert_dataset_correct(dataset)
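Every test above ends with assert_dataset_correct(dataset), a helper defined elsewhere in
the test suite. A rough, purely hypothetical sketch of the kind of checks such a helper
performs (not the actual implementation):

def assert_dataset_correct(dataset):
    # Hypothetical sketch -- the real helper inspects fixture-specific values.
    instances = list(dataset)
    assert len(instances) > 0
    first = instances[0]
    # Field names observed in the examples above.
    assert "actions" in first.fields
    assert "target_action_sequences" in first.fields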