class ActionSpaceWalkerTest(AllenNlpTestCase):
    """
    Tests for ``ActionSpaceWalker`` built on a ``FakeWorldWithAssertions`` with paths limited to
    length 10.
    """

    def setUp(self):
        super().setUp()
        self.world = FakeWorldWithAssertions()
        self.walker = ActionSpaceWalker(self.world, max_path_length=10)

    def test_get_logical_forms_with_agenda(self):
        # Every logical form that can be built with ``black`` in it.
        forms_with_black = self.walker.get_logical_forms_with_agenda(['<o,o> -> black'])
        assert len(forms_with_black) == 25

        # Limiting the output to a single form should give the shortest complete one.
        single_form = self.walker.get_logical_forms_with_agenda(['<o,o> -> black'], 1)
        assert single_form[0] == '(object_exists (black all_objects))'

        three_function_agenda = ['<o,o> -> black',
                                 '<o,o> -> triangle',
                                 '<o,o> -> touch_wall']
        three_function_forms = self.walker.get_logical_forms_with_agenda(three_function_agenda)
        # Only the permutations of the three functions are expected. No function is repeated
        # because the walker was built with paths limited to length 10.
        expected_permutations = {
                '(object_exists (black (triangle (touch_wall all_objects))))',
                '(object_exists (black (touch_wall (triangle all_objects))))',
                '(object_exists (triangle (black (touch_wall all_objects))))',
                '(object_exists (triangle (touch_wall (black all_objects))))',
                '(object_exists (touch_wall (black (triangle all_objects))))',
                '(object_exists (touch_wall (triangle (black all_objects))))',
        }
        assert set(three_function_forms) == expected_permutations

    def test_get_all_logical_forms(self):
        # The output of get_all_logical_forms is expected to be ordered by length.
        shortest_ten = self.walker.get_all_logical_forms(max_num_logical_forms=10)
        assert shortest_ten[0] == '(object_exists all_objects)'
        expected_length_three = {'(object_exists (black all_objects))',
                                 '(object_exists (touch_wall all_objects))',
                                 '(object_exists (triangle all_objects))'}
        assert set(shortest_ten[1:4]) == expected_length_three
def search(tables_directory: str,
           input_examples_file: str,
           output_file: str,
           max_path_length: int,
           max_num_logical_forms: int,
           use_agenda: bool) -> None:
    """
    Enumerates logical forms for every example in ``input_examples_file`` and writes those whose
    denotation matches the example's target values to ``output_file``.

    Parameters
    ----------
    tables_directory : ``str``
        Directory that is prefixed to each example's tagged table file path.
    input_examples_file : ``str``
        File with one example per line; each line is parsed with
        ``wikitables_util.parse_example_line``.
    output_file : ``str``
        Text file to write. For each example it receives the question id and utterance, the
        agenda (when ``use_agenda`` is set), the matching logical forms (or
        ``NO LOGICAL FORMS FOUND!``) and a blank separator line.
    max_path_length : ``int``
        Passed on to ``ActionSpaceWalker`` as ``max_path_length``.
    max_num_logical_forms : ``int``
        Maximum number of matching logical forms written per example.
    use_agenda : ``bool``
        If set, candidates come from ``get_logical_forms_with_agenda`` using the world's agenda;
        otherwise from ``get_all_logical_forms``.
    """
    # Read all examples up front inside a context manager so the input file gets closed.
    with open(input_examples_file) as input_file_pointer:
        data = [wikitables_util.parse_example_line(example_line)
                for example_line in input_file_pointer]
    tokenizer = WordTokenizer()
    with open(output_file, "w") as output_file_pointer:
        for instance_data in data:
            utterance = instance_data["question"]
            question_id = instance_data["id"]
            # Strip one pair of surrounding double quotes from the question.
            if utterance.startswith('"') and utterance.endswith('"'):
                utterance = utterance[1:-1]
            # For example: csv/200-csv/47.csv -> tagged/200-tagged/47.tagged
            table_file = instance_data["table_filename"].replace("csv", "tagged")
            # pylint: disable=protected-access
            target_list = [TableQuestionContext._normalize_string(value) for value in
                           instance_data["target_values"]]
            try:
                target_value_list = evaluator.to_value_list(target_list)
            except Exception:
                # Report the offending targets and let the original error propagate, instead of
                # re-running the failing conversion just to raise again.
                print(target_list, file=sys.stderr)
                raise
            tokenized_question = tokenizer.tokenize(utterance)
            table_file = f"{tables_directory}/{table_file}"
            context = TableQuestionContext.read_from_file(table_file, tokenized_question)
            world = WikiTablesVariableFreeWorld(context)
            walker = ActionSpaceWalker(world, max_path_length=max_path_length)
            correct_logical_forms = []
            print(f"{question_id} {utterance}", file=output_file_pointer)
            # At most 10000 candidates are requested from the walker before filtering.
            if use_agenda:
                agenda = world.get_agenda()
                print(f"Agenda: {agenda}", file=output_file_pointer)
                all_logical_forms = walker.get_logical_forms_with_agenda(agenda=agenda,
                                                                         max_num_logical_forms=10000)
            else:
                all_logical_forms = walker.get_all_logical_forms(max_num_logical_forms=10000)
            for logical_form in all_logical_forms:
                try:
                    denotation = world.execute(logical_form)
                except ExecutionError:
                    # Candidates that cannot be executed are reported and skipped.
                    print(f"Failed to execute: {logical_form}", file=sys.stderr)
                    continue
                if isinstance(denotation, list):
                    denotation_list = [str(denotation_item) for denotation_item in denotation]
                else:
                    # For numbers and dates
                    denotation_list = [str(denotation)]
                denotation_value_list = evaluator.to_value_list(denotation_list)
                if evaluator.check_denotation(target_value_list, denotation_value_list):
                    correct_logical_forms.append(logical_form)
            if not correct_logical_forms:
                print("NO LOGICAL FORMS FOUND!", file=output_file_pointer)
            for logical_form in correct_logical_forms[:max_num_logical_forms]:
                print(logical_form, file=output_file_pointer)
            # Blank line separating consecutive examples.
            print(file=output_file_pointer)
def search(
    tables_directory: str,
    data: JsonDict,
    output_path: str,
    max_path_length: int,
    max_num_logical_forms: int,
    use_agenda: bool,
    output_separate_files: bool,
    conservative_agenda: bool,
) -> None:
    """
    Enumerates logical forms for every instance in ``data`` and writes those for which
    ``world.evaluate_logical_form`` succeeds against the instance's target values.

    Parameters
    ----------
    tables_directory : ``str``
        Directory that is prefixed to each instance's tagged table file path.
    data : ``JsonDict``
        Iterable of instances providing the keys ``question``, ``id``, ``table_filename`` and
        ``target_values``.
    output_path : ``str``
        A directory when ``output_separate_files`` is set (one ``<question_id>.gz`` file per
        instance that has matching logical forms), otherwise a single text file.
    max_path_length : ``int``
        Passed on to ``ActionSpaceWalker`` as ``max_path_length``.
    max_num_logical_forms : ``int``
        Maximum number of logical forms written per instance to the single text file. The
        per-instance gzip files are not truncated.
    use_agenda : ``bool``
        If set, candidates come from ``get_logical_forms_with_agenda``; otherwise from
        ``get_all_logical_forms``.
    output_separate_files : ``bool``
        Selects between the per-instance gzip files and the single text file.
    conservative_agenda : ``bool``
        Passed to ``world.get_agenda`` as ``conservative``; partial agenda matches are allowed
        only when this is not set.
    """
    print(f"Starting search with {len(data)} instances", file=sys.stderr)
    # Only let errors from the language module through while searching.
    language_logger = logging.getLogger("allennlp.semparse.domain_languages.wikitables_language")
    language_logger.setLevel(logging.ERROR)
    tokenizer = SpacyTokenizer()
    if output_separate_files:
        # ``exist_ok`` avoids the race between checking for the directory and creating it.
        os.makedirs(output_path, exist_ok=True)
        output_file_pointer = None
    else:
        output_file_pointer = open(output_path, "w")
    try:
        for instance_data in data:
            utterance = instance_data["question"]
            question_id = instance_data["id"]
            # Strip one pair of surrounding double quotes from the question.
            if utterance.startswith('"') and utterance.endswith('"'):
                utterance = utterance[1:-1]
            # For example: csv/200-csv/47.csv -> tagged/200-tagged/47.tagged
            table_file = instance_data["table_filename"].replace("csv", "tagged")
            target_list = instance_data["target_values"]
            tokenized_question = tokenizer.tokenize(utterance)
            table_file = f"{tables_directory}/{table_file}"
            context = TableQuestionContext.read_from_file(table_file, tokenized_question)
            world = WikiTablesLanguage(context)
            walker = ActionSpaceWalker(world, max_path_length=max_path_length)
            # At most 10000 candidates are requested from the walker before filtering.
            if use_agenda:
                agenda = world.get_agenda(conservative=conservative_agenda)
                allow_partial_match = not conservative_agenda
                all_logical_forms = walker.get_logical_forms_with_agenda(
                    agenda=agenda, max_num_logical_forms=10000, allow_partial_match=allow_partial_match
                )
            else:
                all_logical_forms = walker.get_all_logical_forms(max_num_logical_forms=10000)
            correct_logical_forms = []
            for logical_form in all_logical_forms:
                if world.evaluate_logical_form(logical_form, target_list):
                    correct_logical_forms.append(logical_form)
            if output_separate_files:
                # No file is created for instances without any matching logical form.
                if correct_logical_forms:
                    with gzip.open(f"{output_path}/{question_id}.gz", "wt") as gzip_file_pointer:
                        for logical_form in correct_logical_forms:
                            print(logical_form, file=gzip_file_pointer)
            else:
                print(f"{question_id} {utterance}", file=output_file_pointer)
                if use_agenda:
                    print(f"Agenda: {agenda}", file=output_file_pointer)
                if not correct_logical_forms:
                    print("NO LOGICAL FORMS FOUND!", file=output_file_pointer)
                for logical_form in correct_logical_forms[:max_num_logical_forms]:
                    print(logical_form, file=output_file_pointer)
                # Blank line separating consecutive instances.
                print(file=output_file_pointer)
    finally:
        # Close the combined output file even if processing an instance raised.
        if output_file_pointer is not None:
            output_file_pointer.close()
Exemple #4
0
def search(tables_directory: str, input_examples_file: str, output_path: str,
           max_path_length: int, max_num_logical_forms: int, use_agenda: bool,
           output_separate_files: bool) -> None:
    """
    Enumerates logical forms for every example in ``input_examples_file`` and writes those for
    which ``world.evaluate_logical_form`` succeeds against the example's target values.

    Parameters
    ----------
    tables_directory : ``str``
        Directory that is prefixed to each example's tagged table file path.
    input_examples_file : ``str``
        File with one example per line; each line is parsed with
        ``wikitables_util.parse_example_line``.
    output_path : ``str``
        A directory when ``output_separate_files`` is set (one ``<question_id>.gz`` file per
        example that has matching logical forms), otherwise a single text file.
    max_path_length : ``int``
        Passed on to ``ActionSpaceWalker`` as ``max_path_length``.
    max_num_logical_forms : ``int``
        Maximum number of logical forms written per example to the single text file. The
        per-example gzip files are not truncated.
    use_agenda : ``bool``
        If set, candidates come from ``get_logical_forms_with_agenda`` using the world's agenda;
        otherwise from ``get_all_logical_forms``.
    output_separate_files : ``bool``
        Selects between the per-example gzip files and the single text file.
    """
    # Read all examples up front inside a context manager so the input file gets closed.
    with open(input_examples_file) as input_file_pointer:
        data = [
            wikitables_util.parse_example_line(example_line)
            for example_line in input_file_pointer
        ]
    tokenizer = WordTokenizer()
    if output_separate_files:
        # ``exist_ok`` avoids the race between checking for the directory and creating it.
        os.makedirs(output_path, exist_ok=True)
        output_file_pointer = None
    else:
        output_file_pointer = open(output_path, "w")
    try:
        for instance_data in data:
            utterance = instance_data["question"]
            question_id = instance_data["id"]
            # Strip one pair of surrounding double quotes from the question.
            if utterance.startswith('"') and utterance.endswith('"'):
                utterance = utterance[1:-1]
            # For example: csv/200-csv/47.csv -> tagged/200-tagged/47.tagged
            table_file = instance_data["table_filename"].replace("csv", "tagged")
            target_list = instance_data["target_values"]
            tokenized_question = tokenizer.tokenize(utterance)
            table_file = f"{tables_directory}/{table_file}"
            context = TableQuestionContext.read_from_file(table_file,
                                                          tokenized_question)
            world = WikiTablesVariableFreeWorld(context)
            walker = ActionSpaceWalker(world, max_path_length=max_path_length)
            # At most 10000 candidates are requested from the walker before filtering.
            if use_agenda:
                agenda = world.get_agenda()
                all_logical_forms = walker.get_logical_forms_with_agenda(
                    agenda=agenda, max_num_logical_forms=10000)
            else:
                all_logical_forms = walker.get_all_logical_forms(
                    max_num_logical_forms=10000)
            correct_logical_forms = []
            for logical_form in all_logical_forms:
                if world.evaluate_logical_form(logical_form, target_list):
                    correct_logical_forms.append(logical_form)
            if output_separate_files:
                # No file is created for examples without any matching logical form.
                if correct_logical_forms:
                    with gzip.open(f"{output_path}/{question_id}.gz",
                                   "wt") as gzip_file_pointer:
                        for logical_form in correct_logical_forms:
                            print(logical_form, file=gzip_file_pointer)
            else:
                print(f"{question_id} {utterance}", file=output_file_pointer)
                if use_agenda:
                    print(f"Agenda: {agenda}", file=output_file_pointer)
                if not correct_logical_forms:
                    print("NO LOGICAL FORMS FOUND!", file=output_file_pointer)
                for logical_form in correct_logical_forms[:max_num_logical_forms]:
                    print(logical_form, file=output_file_pointer)
                # Blank line separating consecutive examples.
                print(file=output_file_pointer)
    finally:
        # Close the combined output file even if processing an example raised.
        if output_file_pointer is not None:
            output_file_pointer.close()
def search(tables_directory: str,
           input_examples_file: str,
           output_path: str,
           max_path_length: int,
           max_num_logical_forms: int,
           use_agenda: bool,
           output_separate_files: bool) -> None:
    """
    Enumerates logical forms for every example in ``input_examples_file`` and writes those for
    which ``world.evaluate_logical_form`` succeeds against the example's target values.

    Parameters
    ----------
    tables_directory : ``str``
        Directory that is prefixed to each example's tagged table file path.
    input_examples_file : ``str``
        File with one example per line; each line is parsed with
        ``wikitables_util.parse_example_line``.
    output_path : ``str``
        A directory when ``output_separate_files`` is set (one ``<question_id>.gz`` file per
        example that has matching logical forms), otherwise a single text file.
    max_path_length : ``int``
        Passed on to ``ActionSpaceWalker`` as ``max_path_length``.
    max_num_logical_forms : ``int``
        Maximum number of logical forms written per example to the single text file. The
        per-example gzip files are not truncated.
    use_agenda : ``bool``
        If set, candidates come from ``get_logical_forms_with_agenda`` using the world's agenda;
        otherwise from ``get_all_logical_forms``.
    output_separate_files : ``bool``
        Selects between the per-example gzip files and the single text file.
    """
    # Read all examples up front inside a context manager so the input file gets closed.
    with open(input_examples_file) as input_file_pointer:
        data = [wikitables_util.parse_example_line(example_line) for example_line in
                input_file_pointer]
    tokenizer = WordTokenizer()
    if output_separate_files:
        # ``exist_ok`` avoids the race between checking for the directory and creating it.
        os.makedirs(output_path, exist_ok=True)
        output_file_pointer = None
    else:
        output_file_pointer = open(output_path, "w")
    try:
        for instance_data in data:
            utterance = instance_data["question"]
            question_id = instance_data["id"]
            # Strip one pair of surrounding double quotes from the question.
            if utterance.startswith('"') and utterance.endswith('"'):
                utterance = utterance[1:-1]
            # For example: csv/200-csv/47.csv -> tagged/200-tagged/47.tagged
            table_file = instance_data["table_filename"].replace("csv", "tagged")
            target_list = instance_data["target_values"]
            tokenized_question = tokenizer.tokenize(utterance)
            table_file = f"{tables_directory}/{table_file}"
            context = TableQuestionContext.read_from_file(table_file, tokenized_question)
            world = WikiTablesVariableFreeWorld(context)
            walker = ActionSpaceWalker(world, max_path_length=max_path_length)
            # At most 10000 candidates are requested from the walker before filtering.
            if use_agenda:
                agenda = world.get_agenda()
                all_logical_forms = walker.get_logical_forms_with_agenda(agenda=agenda,
                                                                         max_num_logical_forms=10000)
            else:
                all_logical_forms = walker.get_all_logical_forms(max_num_logical_forms=10000)
            correct_logical_forms = []
            for logical_form in all_logical_forms:
                if world.evaluate_logical_form(logical_form, target_list):
                    correct_logical_forms.append(logical_form)
            if output_separate_files:
                # No file is created for examples without any matching logical form.
                if correct_logical_forms:
                    with gzip.open(f"{output_path}/{question_id}.gz", "wt") as gzip_file_pointer:
                        for logical_form in correct_logical_forms:
                            print(logical_form, file=gzip_file_pointer)
            else:
                print(f"{question_id} {utterance}", file=output_file_pointer)
                if use_agenda:
                    print(f"Agenda: {agenda}", file=output_file_pointer)
                if not correct_logical_forms:
                    print("NO LOGICAL FORMS FOUND!", file=output_file_pointer)
                for logical_form in correct_logical_forms[:max_num_logical_forms]:
                    print(logical_form, file=output_file_pointer)
                # Blank line separating consecutive examples.
                print(file=output_file_pointer)
    finally:
        # Close the combined output file even if processing an example raised.
        if output_file_pointer is not None:
            output_file_pointer.close()
Exemple #6
0
class ActionSpaceWalkerTest(AllenNlpTestCase):
    """
    Checks the logical forms produced by an ``ActionSpaceWalker`` that walks a
    ``FakeWorldWithAssertions`` with paths limited to length 10.
    """

    def setUp(self):
        super().setUp()
        self.world = FakeWorldWithAssertions()
        self.walker = ActionSpaceWalker(self.world, max_path_length=10)

    def test_get_logical_forms_with_agenda(self):
        # All logical forms that can be produced with ``black`` on the agenda.
        all_black_forms = self.walker.get_logical_forms_with_agenda(['<o,o> -> black'])
        assert len(all_black_forms) == 25

        # With a limit of one, the shortest complete logical form with ``black`` is expected.
        limited_black_forms = self.walker.get_logical_forms_with_agenda(['<o,o> -> black'], 1)
        assert limited_black_forms[0] == '(object_exists (black all_objects))'

        agenda_items = ['<o,o> -> black', '<o,o> -> triangle', '<o,o> -> touch_wall']
        forms_with_all_items = self.walker.get_logical_forms_with_agenda(agenda_items)
        # The result should be exactly the permutations of the three functions. Functions are
        # not repeated because paths are limited to length 10 in ``setUp``.
        expected_forms = {
            '(object_exists (black (triangle (touch_wall all_objects))))',
            '(object_exists (black (touch_wall (triangle all_objects))))',
            '(object_exists (triangle (black (touch_wall all_objects))))',
            '(object_exists (triangle (touch_wall (black all_objects))))',
            '(object_exists (touch_wall (black (triangle all_objects))))',
            '(object_exists (touch_wall (triangle (black all_objects))))',
        }
        assert set(forms_with_all_items) == expected_forms

    def test_get_all_logical_forms(self):
        # Logical forms from get_all_logical_forms are expected in order of length.
        first_ten_forms = self.walker.get_all_logical_forms(max_num_logical_forms=10)
        assert first_ten_forms[0] == '(object_exists all_objects)'
        expected_next_three = {
            '(object_exists (black all_objects))',
            '(object_exists (touch_wall all_objects))',
            '(object_exists (triangle all_objects))',
        }
        assert set(first_ten_forms[1:4]) == expected_next_three
 def setUp(self):
     """Builds the fake world and a walker over it with paths limited to length 10."""
     super(ActionSpaceWalkerTest, self).setUp()
     fake_world = FakeWorldWithAssertions()
     self.world = fake_world
     self.walker = ActionSpaceWalker(fake_world, max_path_length=10)
class ActionSpaceWalkerTest(AllenNlpTestCase):
    """
    Tests for ``ActionSpaceWalker`` built on a ``FakeWorldWithAssertions`` with paths limited to
    length 10, including the warnings it logs for empty or unmatched agenda items.
    """

    def setUp(self):
        super().setUp()
        self.world = FakeWorldWithAssertions()
        self.walker = ActionSpaceWalker(self.world, max_path_length=10)

    def test_get_logical_forms_with_agenda(self):
        # Every logical form that can be built with ``black`` in it.
        forms_with_black = self.walker.get_logical_forms_with_agenda(['<o,o> -> black'])
        assert len(forms_with_black) == 25

        # Limiting the output to a single form should give the shortest complete one.
        single_form = self.walker.get_logical_forms_with_agenda(['<o,o> -> black'], 1)
        assert single_form[0] == '(object_exists (black all_objects))'

        three_function_agenda = ['<o,o> -> black', '<o,o> -> triangle', '<o,o> -> touch_wall']
        three_function_forms = self.walker.get_logical_forms_with_agenda(three_function_agenda)
        # Only the permutations of the three functions are expected. No function is repeated
        # because the walker was built with paths limited to length 10.
        expected_permutations = {
            '(object_exists (black (triangle (touch_wall all_objects))))',
            '(object_exists (black (touch_wall (triangle all_objects))))',
            '(object_exists (triangle (black (touch_wall all_objects))))',
            '(object_exists (triangle (touch_wall (black all_objects))))',
            '(object_exists (touch_wall (black (triangle all_objects))))',
            '(object_exists (touch_wall (triangle (black all_objects))))',
        }
        assert set(three_function_forms) == expected_permutations

    def test_get_logical_forms_with_empty_agenda_returns_all_logical_forms(self):
        # An empty agenda should behave like asking for all paths, and log a warning saying so.
        with self.assertLogs("allennlp.semparse.action_space_walker") as log:
            forms_for_empty_agenda = self.walker.get_logical_forms_with_agenda([])
            expected_first_four = {
                '(object_exists all_objects)',
                '(object_exists (black all_objects))',
                '(object_exists (touch_wall all_objects))',
                '(object_exists (triangle all_objects))',
            }
            assert set(forms_for_empty_agenda[:4]) == expected_first_four
        expected_log = [
            "WARNING:allennlp.semparse.action_space_walker:"
            "Agenda is empty! Returning all paths instead."
        ]
        self.assertEqual(log.output, expected_log)

    def test_get_logical_forms_with_agenda_ignores_null_set_item(self):
        agenda_with_yellow = [
            '<o,o> -> yellow', '<o,o> -> black', '<o,o> -> triangle',
            '<o,o> -> touch_wall'
        ]
        with self.assertLogs("allennlp.semparse.action_space_walker") as log:
            forms_ignoring_yellow = self.walker.get_logical_forms_with_agenda(agenda_with_yellow)
            # With yellow ignored, only the permutations of the other three functions remain.
            # No function is repeated because paths are limited to length 10.
            expected_permutations = {
                '(object_exists (black (triangle (touch_wall all_objects))))',
                '(object_exists (black (touch_wall (triangle all_objects))))',
                '(object_exists (triangle (black (touch_wall all_objects))))',
                '(object_exists (triangle (touch_wall (black all_objects))))',
                '(object_exists (touch_wall (black (triangle all_objects))))',
                '(object_exists (touch_wall (triangle (black all_objects))))',
            }
            assert set(forms_ignoring_yellow) == expected_permutations
        expected_log = [
            "WARNING:allennlp.semparse.action_space_walker:"
            "<o,o> -> yellow is not in any of the paths found! Ignoring it."
        ]
        self.assertEqual(log.output, expected_log)

    def test_get_all_logical_forms(self):
        # The output of get_all_logical_forms is expected to be ordered by length.
        shortest_ten = self.walker.get_all_logical_forms(max_num_logical_forms=10)
        assert shortest_ten[0] == '(object_exists all_objects)'
        expected_length_three = {
            '(object_exists (black all_objects))',
            '(object_exists (touch_wall all_objects))',
            '(object_exists (triangle all_objects))',
        }
        assert set(shortest_ten[1:4]) == expected_length_three
Exemple #9
0
def process_data(
    input_file: str,
    output_file: str,
    max_path_length: int,
    max_num_logical_forms: int,
    ignore_agenda: bool,
    write_sequences: bool,
) -> None:
    """
    Reads an NLVR dataset and writes, one JSON object per line, the sentences, labels, worlds and
    the logical forms that are correct and incorrect for each instance. The output will contain
    at most `max_num_logical_forms` entries each in both the correct and the incorrect lists.

    Each output line has the keys ``"id"``, ``"sentence"``, ``"worlds"`` and ``"labels"``, plus
    either ``"correct_logical_forms"`` and ``"incorrect_logical_forms"`` or, when
    ``write_sequences`` is set, ``"correct_sequences"`` and ``"incorrect_sequences"``.

    Parameters
    ----------
    input_file : ``str``
        File with one instance per line, parsed with ``read_json_line``.
    output_file : ``str``
        File that the processed instances are written to.
    max_path_length : ``int``
        Passed on to ``ActionSpaceWalker`` as ``max_path_length``.
    max_num_logical_forms : ``int``
        Upper bound on the size of each of the correct and incorrect lists. When the agenda is
        used, ten times this many candidates are requested from the walker.
    ignore_agenda : ``bool``
        If set, the 1000 shortest logical forms are used as candidates instead of the ones
        matching the sentence's agenda.
    write_sequences : ``bool``
        If set, logical forms are converted to action sequences before being written.
    """
    processed_data: JsonDict = []
    # We can instantiate the ``ActionSpaceWalker`` with any world because the action space is the
    # same for all the ``NlvrLanguage`` objects. It is just the execution that differs.
    walker = ActionSpaceWalker(NlvrLanguage({}),
                               max_path_length=max_path_length)
    # Use a context manager so the input file is closed once all lines are processed.
    with open(input_file) as input_file_pointer:
        for line in input_file_pointer:
            instance_id, sentence, structured_reps, label_strings = read_json_line(
                line)
            worlds = []
            for structured_representation in structured_reps:
                boxes = {
                    Box(object_list, box_id)
                    for box_id, object_list in enumerate(structured_representation)
                }
                worlds.append(NlvrLanguage(boxes))
            labels = [label_string == "true" for label_string in label_strings]
            correct_logical_forms = []
            incorrect_logical_forms = []
            if ignore_agenda:
                # Get 1000 shortest logical forms.
                logical_forms = walker.get_all_logical_forms(
                    max_num_logical_forms=1000)
            else:
                # TODO (pradeep): Assuming all worlds give the same agenda.
                sentence_agenda = worlds[0].get_agenda_for_sentence(sentence)
                logical_forms = walker.get_logical_forms_with_agenda(
                    sentence_agenda, max_num_logical_forms * 10)
            for logical_form in logical_forms:
                # A logical form is correct only if it gives the right label in every world.
                is_correct = all(
                    world.execute(logical_form) == label
                    for world, label in zip(worlds, labels)
                )
                target_forms = correct_logical_forms if is_correct else incorrect_logical_forms
                # Strict comparison keeps each list at no more than ``max_num_logical_forms``.
                if len(target_forms) < max_num_logical_forms:
                    target_forms.append(logical_form)
                if (len(correct_logical_forms) >= max_num_logical_forms
                        and len(incorrect_logical_forms) >= max_num_logical_forms):
                    break
            if write_sequences:
                correct_sequences = [
                    worlds[0].logical_form_to_action_sequence(logical_form)
                    for logical_form in correct_logical_forms
                ]
                incorrect_sequences = [
                    worlds[0].logical_form_to_action_sequence(logical_form)
                    for logical_form in incorrect_logical_forms
                ]
                processed_data.append({
                    "id": instance_id,
                    "sentence": sentence,
                    "correct_sequences": correct_sequences,
                    "incorrect_sequences": incorrect_sequences,
                    "worlds": structured_reps,
                    "labels": label_strings,
                })
            else:
                processed_data.append({
                    "id": instance_id,
                    "sentence": sentence,
                    "correct_logical_forms": correct_logical_forms,
                    "incorrect_logical_forms": incorrect_logical_forms,
                    "worlds": structured_reps,
                    "labels": label_strings,
                })
    # The ``with`` block closes the output file; no explicit ``close`` is needed.
    with open(output_file, "w") as outfile:
        for instance_processed_data in processed_data:
            json.dump(instance_processed_data, outfile)
            outfile.write("\n")
Exemple #10
0
 def setUp(self):
     """Builds the fake language (start types ``{bool}``) and a walker over it with paths limited to length 10."""
     super(ActionSpaceWalkerTest, self).setUp()
     fake_language = FakeLanguageWithAssertions(start_types={bool})
     self.world = fake_language
     self.walker = ActionSpaceWalker(fake_language, max_path_length=10)
Exemple #11
0
class ActionSpaceWalkerTest(AllenNlpTestCase):
    """
    Tests for ``ActionSpaceWalker`` built on a ``FakeLanguageWithAssertions`` (start types
    ``{bool}``) with paths limited to length 10. Covers full and partial agenda matching and the
    warnings logged for empty or unmatched agenda items.
    """

    def setUp(self):
        super().setUp()
        self.world = FakeLanguageWithAssertions(start_types={bool})
        self.walker = ActionSpaceWalker(self.world, max_path_length=10)

    def test_get_logical_forms_with_agenda(self):
        # Every logical form that can be built with ``black`` in it.
        forms_with_black = self.walker.get_logical_forms_with_agenda(
            ['<Set[Object]:Set[Object]> -> black'])
        assert len(forms_with_black) == 25

        # Limiting the output to a single form should give the shortest complete one.
        single_form = self.walker.get_logical_forms_with_agenda(
            ['<Set[Object]:Set[Object]> -> black'], 1)
        assert single_form[0] == '(object_exists (black all_objects))'

        three_function_agenda = [
            '<Set[Object]:Set[Object]> -> black',
            '<Set[Object]:Set[Object]> -> triangle',
            '<Set[Object]:Set[Object]> -> touch_wall'
        ]
        three_function_forms = self.walker.get_logical_forms_with_agenda(three_function_agenda)
        # Only the permutations of the three functions are expected. No function is repeated
        # because the walker was built with paths limited to length 10.
        expected_permutations = {
            '(object_exists (black (triangle (touch_wall all_objects))))',
            '(object_exists (black (touch_wall (triangle all_objects))))',
            '(object_exists (triangle (black (touch_wall all_objects))))',
            '(object_exists (triangle (touch_wall (black all_objects))))',
            '(object_exists (touch_wall (black (triangle all_objects))))',
            '(object_exists (touch_wall (triangle (black all_objects))))',
        }
        assert set(three_function_forms) == expected_permutations

    def test_get_logical_forms_with_agenda_and_partial_match(self):
        # Every logical form that can be built with ``black`` in it.
        forms_with_black = self.walker.get_logical_forms_with_agenda(
            ['<Set[Object]:Set[Object]> -> black'])
        assert len(forms_with_black) == 25

        # Limiting the output to a single form should give the shortest complete one.
        single_form = self.walker.get_logical_forms_with_agenda(
            ['<Set[Object]:Set[Object]> -> black'], 1)
        assert single_form[0] == '(object_exists (black all_objects))'

        three_function_agenda = [
            '<Set[Object]:Set[Object]> -> black',
            '<Set[Object]:Set[Object]> -> triangle',
            '<Set[Object]:Set[Object]> -> touch_wall'
        ]
        partially_matched_forms = self.walker.get_logical_forms_with_agenda(
            three_function_agenda, allow_partial_match=True)

        # Positions 0-5: the permutations that use all three functions.
        expected_with_three_items = {
            '(object_exists (black (triangle (touch_wall all_objects))))',
            '(object_exists (black (touch_wall (triangle all_objects))))',
            '(object_exists (triangle (black (touch_wall all_objects))))',
            '(object_exists (triangle (touch_wall (black all_objects))))',
            '(object_exists (touch_wall (black (triangle all_objects))))',
            '(object_exists (touch_wall (triangle (black all_objects))))',
        }
        assert set(partially_matched_forms[:6]) == expected_with_three_items

        # Positions 6-11: the six shortest forms that use two of the agenda items.
        expected_with_two_items = {
            '(object_exists (black (triangle all_objects)))',
            '(object_exists (black (touch_wall all_objects)))',
            '(object_exists (triangle (black all_objects)))',
            '(object_exists (triangle (touch_wall all_objects)))',
            '(object_exists (touch_wall (black all_objects)))',
            '(object_exists (touch_wall (triangle all_objects)))',
        }
        assert set(partially_matched_forms[6:12]) == expected_with_two_items

        # Positions 30-32: following the longer forms with two agenda items come the three
        # shortest forms that use a single agenda item.
        expected_with_one_item = {
            '(object_exists (black all_objects))',
            '(object_exists (triangle all_objects))',
            '(object_exists (touch_wall all_objects))',
        }
        assert set(partially_matched_forms[30:33]) == expected_with_one_item

    def test_get_logical_forms_with_empty_agenda_returns_all_logical_forms(self):
        # An empty agenda should behave like asking for all paths, and log a warning saying so.
        with self.assertLogs("allennlp.semparse.action_space_walker") as log:
            forms_for_empty_agenda = self.walker.get_logical_forms_with_agenda(
                [], allow_partial_match=True)
            expected_first_four = {
                '(object_exists all_objects)',
                '(object_exists (black all_objects))',
                '(object_exists (touch_wall all_objects))',
                '(object_exists (triangle all_objects))',
            }
            assert set(forms_for_empty_agenda[:4]) == expected_first_four
        expected_log = [
            "WARNING:allennlp.semparse.action_space_walker:"
            "Agenda is empty! Returning all paths instead."
        ]
        self.assertEqual(log.output, expected_log)

    def test_get_logical_forms_with_unmatched_agenda_returns_all_logical_forms(self):
        agenda = ['<Set[Object]:Set[Object]> -> purple']
        # With partial matching allowed, an agenda that matches nothing should give all paths
        # and log a warning.
        with self.assertLogs("allennlp.semparse.action_space_walker") as log:
            forms_for_unmatched_agenda = self.walker.get_logical_forms_with_agenda(
                agenda, allow_partial_match=True)
            expected_first_four = {
                '(object_exists all_objects)',
                '(object_exists (black all_objects))',
                '(object_exists (touch_wall all_objects))',
                '(object_exists (triangle all_objects))',
            }
            assert set(forms_for_unmatched_agenda[:4]) == expected_first_four
        expected_log = [
            "WARNING:allennlp.semparse.action_space_walker:"
            "Agenda items not in any of the paths found. Returning all paths."
        ]
        self.assertEqual(log.output, expected_log)
        # Without partial matching, the same agenda should give nothing at all.
        forms_without_partial_match = self.walker.get_logical_forms_with_agenda(
            agenda, allow_partial_match=False)
        assert forms_without_partial_match == []

    def test_get_logical_forms_with_agenda_ignores_null_set_item(self):
        agenda_with_yellow = [
            '<Set[Object]:Set[Object]> -> yellow',
            '<Set[Object]:Set[Object]> -> black',
            '<Set[Object]:Set[Object]> -> triangle',
            '<Set[Object]:Set[Object]> -> touch_wall'
        ]
        with self.assertLogs("allennlp.semparse.action_space_walker") as log:
            forms_ignoring_yellow = self.walker.get_logical_forms_with_agenda(agenda_with_yellow)
            # With yellow ignored, only the permutations of the other three functions remain.
            # No function is repeated because paths are limited to length 10.
            expected_permutations = {
                '(object_exists (black (triangle (touch_wall all_objects))))',
                '(object_exists (black (touch_wall (triangle all_objects))))',
                '(object_exists (triangle (black (touch_wall all_objects))))',
                '(object_exists (triangle (touch_wall (black all_objects))))',
                '(object_exists (touch_wall (black (triangle all_objects))))',
                '(object_exists (touch_wall (triangle (black all_objects))))',
            }
            assert set(forms_ignoring_yellow) == expected_permutations
        expected_log = [
            "WARNING:allennlp.semparse.action_space_walker:"
            "<Set[Object]:Set[Object]> -> yellow is not in any of the paths found! Ignoring it."
        ]
        self.assertEqual(log.output, expected_log)

    def test_get_all_logical_forms(self):
        # The output of get_all_logical_forms is expected to be ordered by length.
        shortest_ten = self.walker.get_all_logical_forms(max_num_logical_forms=10)
        assert shortest_ten[0] == '(object_exists all_objects)'
        expected_length_three = {
            '(object_exists (black all_objects))',
            '(object_exists (touch_wall all_objects))',
            '(object_exists (triangle all_objects))',
        }
        assert set(shortest_ten[1:4]) == expected_length_three
Exemple #12
0
def process_data(input_file: str,
                 output_file: str,
                 max_path_length: int,
                 max_num_logical_forms: int,
                 ignore_agenda: bool,
                 write_sequences: bool) -> None:
    """
    Reads an NLVR dataset (one JSON instance per line) and writes one JSON object per instance to
    ``output_file``. Each object contains the sentence, the worlds, the labels, and the candidate
    logical forms split into correct ones (their execution matches the label in every world of the
    instance) and incorrect ones. Each of the two lists contains at most ``max_num_logical_forms``
    entries. The format of every output line is:
        ``{"id": ..., "sentence": ..., "correct_logical_forms": List[str],
        "incorrect_logical_forms": List[str], "worlds": ..., "labels": ...}``
    When ``write_sequences`` is set, the two ``*_logical_forms`` keys are replaced by
    ``"correct_sequences"`` and ``"incorrect_sequences"``, which hold the action sequences of the
    same logical forms.

    Parameters
    ----------
    input_file : ``str``
        Path to the NLVR data. Every line is parsed with ``read_json_line``.
    output_file : ``str``
        Path the processed instances are written to, one JSON object per line.
    max_path_length : ``int``
        Maximum path length passed to the ``ActionSpaceWalker``. The walker is cached as a pickle
        file (named after this value) in the current working directory and reused when present.
    max_num_logical_forms : ``int``
        Upper bound on the size of each of the correct and incorrect lists.
    ignore_agenda : ``bool``
        If set, the 1000 shortest logical forms are used as candidates for every instance.
        Otherwise the candidates are the logical forms matching the sentence's agenda, and the
        walker is asked for up to ``10 * max_num_logical_forms`` of them.
    write_sequences : ``bool``
        If set, action sequences are written instead of logical form strings.
    """
    processed_data: JsonDict = []
    # We can instantiate the ``ActionSpaceWalker`` with any world because the action space is the
    # same for all the ``NlvrWorlds``. It is just the execution that differs.
    serialized_walker_path = f"serialized_action_space_walker_pl={max_path_length}.pkl"
    if os.path.isfile(serialized_walker_path):
        print("Reading walker from serialized file", file=sys.stderr)
        # NOTE(review): unpickling can execute arbitrary code, so this cache file must only ever
        # be one that this script wrote itself.
        with open(serialized_walker_path, "rb") as walker_file:
            walker = pickle.load(walker_file)
    else:
        walker = ActionSpaceWalker(NlvrWorld({}), max_path_length=max_path_length)
        with open(serialized_walker_path, "wb") as walker_file:
            pickle.dump(walker, walker_file)
    with open(input_file) as data_file:
        for line in data_file:
            instance_id, sentence, structured_reps, label_strings = read_json_line(line)
            worlds = [NlvrWorld(structured_rep) for structured_rep in structured_reps]
            labels = [label_string == "true" for label_string in label_strings]
            correct_logical_forms = []
            incorrect_logical_forms = []
            if ignore_agenda:
                # Get 1000 shortest logical forms.
                logical_forms = walker.get_all_logical_forms(max_num_logical_forms=1000)
            else:
                # TODO (pradeep): Assuming all worlds give the same agenda.
                sentence_agenda = worlds[0].get_agenda_for_sentence(sentence,
                                                                    add_paths_to_agenda=False)
                logical_forms = walker.get_logical_forms_with_agenda(sentence_agenda,
                                                                     max_num_logical_forms * 10)
            for logical_form in logical_forms:
                # A logical form is correct only if it evaluates to the label in every world.
                is_correct = all(world.execute(logical_form) == label
                                 for world, label in zip(worlds, labels))
                # Strict comparison so that neither list ever grows past the requested maximum.
                if is_correct:
                    if len(correct_logical_forms) < max_num_logical_forms:
                        correct_logical_forms.append(logical_form)
                elif len(incorrect_logical_forms) < max_num_logical_forms:
                    incorrect_logical_forms.append(logical_form)
                # Stop scanning candidates as soon as both lists are full.
                if len(correct_logical_forms) >= max_num_logical_forms \
                   and len(incorrect_logical_forms) >= max_num_logical_forms:
                    break
            if write_sequences:
                # Parsing and action sequence extraction are done with the first world only.
                parsed_correct_forms = [worlds[0].parse_logical_form(logical_form)
                                        for logical_form in correct_logical_forms]
                correct_sequences = [worlds[0].get_action_sequence(parsed_form)
                                     for parsed_form in parsed_correct_forms]
                parsed_incorrect_forms = [worlds[0].parse_logical_form(logical_form)
                                          for logical_form in incorrect_logical_forms]
                incorrect_sequences = [worlds[0].get_action_sequence(parsed_form)
                                       for parsed_form in parsed_incorrect_forms]
                processed_data.append({"id": instance_id,
                                       "sentence": sentence,
                                       "correct_sequences": correct_sequences,
                                       "incorrect_sequences": incorrect_sequences,
                                       "worlds": structured_reps,
                                       "labels": label_strings})
            else:
                processed_data.append({"id": instance_id,
                                       "sentence": sentence,
                                       "correct_logical_forms": correct_logical_forms,
                                       "incorrect_logical_forms": incorrect_logical_forms,
                                       "worlds": structured_reps,
                                       "labels": label_strings})
    # The ``with`` block closes the file; no explicit ``close`` is needed.
    with open(output_file, "w") as outfile:
        for instance_processed_data in processed_data:
            json.dump(instance_processed_data, outfile)
            outfile.write('\n')
 def setUp(self):
     """Build the fake world and the walker over its action space used by the tests."""
     super(ActionSpaceWalkerTest, self).setUp()
     fake_world = FakeWorldWithAssertions()
     self.world = fake_world
     # The walker is limited to paths of length at most 10.
     self.walker = ActionSpaceWalker(fake_world, max_path_length=10)
class ActionSpaceWalkerTest(AllenNlpTestCase):
    """Tests for logical form enumeration by ``ActionSpaceWalker`` over a fake world."""

    def setUp(self):
        super().setUp()
        self.world = FakeWorldWithAssertions()
        # Paths are limited to length 10; several expectations below depend on this limit.
        self.walker = ActionSpaceWalker(self.world, max_path_length=10)

    def test_get_logical_forms_with_agenda(self):
        """An agenda restricts the returned logical forms to those containing its actions."""
        forms_with_black = self.walker.get_logical_forms_with_agenda(['<o,o> -> black'])
        # This is the number of all possible logical forms containing black.
        assert len(forms_with_black) == 25

        # Asking for a single logical form should give the shortest complete one with black.
        single_form = self.walker.get_logical_forms_with_agenda(['<o,o> -> black'], 1)
        assert single_form[0] == '(object_exists (black all_objects))'

        three_function_agenda = ['<o,o> -> black',
                                 '<o,o> -> triangle',
                                 '<o,o> -> touch_wall']
        forms_with_all_three = self.walker.get_logical_forms_with_agenda(three_function_agenda)
        # Every ordering of the three functions. No function is repeated, since the path length
        # is limited to 10 in ``setUp``.
        expected_forms = {
            '(object_exists (black (triangle (touch_wall all_objects))))',
            '(object_exists (black (touch_wall (triangle all_objects))))',
            '(object_exists (triangle (black (touch_wall all_objects))))',
            '(object_exists (triangle (touch_wall (black all_objects))))',
            '(object_exists (touch_wall (black (triangle all_objects))))',
            '(object_exists (touch_wall (triangle (black all_objects))))',
        }
        assert set(forms_with_all_three) == expected_forms

    def test_get_logical_forms_with_empty_agenda_returns_all_logical_forms(self):
        """With an empty agenda the walker warns and returns all paths."""
        expected_log = ["WARNING:allennlp.semparse.action_space_walker:"
                        "Agenda is empty! Returning all paths instead."]
        expected_first_four = {
            '(object_exists all_objects)',
            '(object_exists (black all_objects))',
            '(object_exists (touch_wall all_objects))',
            '(object_exists (triangle all_objects))',
        }
        with self.assertLogs("allennlp.semparse.action_space_walker") as log:
            all_forms = self.walker.get_logical_forms_with_agenda([])
            assert set(all_forms[:4]) == expected_first_four
        self.assertEqual(log.output, expected_log)

    def test_get_logical_forms_with_agenda_ignores_null_set_item(self):
        """An agenda item that occurs in no path is dropped with a warning."""
        expected_log = ["WARNING:allennlp.semparse.action_space_walker:"
                        "<o,o> -> yellow is not in any of the paths found! Ignoring it."]
        # Every ordering of the three remaining functions once yellow is ignored. No function is
        # repeated, since the path length is limited to 10 in ``setUp``.
        expected_forms = {
            '(object_exists (black (triangle (touch_wall all_objects))))',
            '(object_exists (black (touch_wall (triangle all_objects))))',
            '(object_exists (triangle (black (touch_wall all_objects))))',
            '(object_exists (triangle (touch_wall (black all_objects))))',
            '(object_exists (touch_wall (black (triangle all_objects))))',
            '(object_exists (touch_wall (triangle (black all_objects))))',
        }
        agenda_with_yellow = ['<o,o> -> yellow',
                              '<o,o> -> black',
                              '<o,o> -> triangle',
                              '<o,o> -> touch_wall']
        with self.assertLogs("allennlp.semparse.action_space_walker") as log:
            returned_forms = self.walker.get_logical_forms_with_agenda(agenda_with_yellow)
            assert set(returned_forms) == expected_forms
        self.assertEqual(log.output, expected_log)

    def test_get_all_logical_forms(self):
        """``get_all_logical_forms`` is expected to return logical forms shortest-first."""
        logical_forms = self.walker.get_all_logical_forms(max_num_logical_forms=10)
        # The very first entry must be the form that applies no filter at all.
        assert logical_forms[0] == '(object_exists all_objects)'
        # The three entries that follow each apply exactly one filter. Their relative order is
        # not checked, so they are compared as a set.
        expected_single_filter_forms = {
            '(object_exists (black all_objects))',
            '(object_exists (touch_wall all_objects))',
            '(object_exists (triangle all_objects))',
        }
        assert set(logical_forms[1:4]) == expected_single_filter_forms