class ActionSpaceWalkerTest(AllenNlpTestCase):
    """
    Tests for the logical forms that ``ActionSpaceWalker`` enumerates over a
    ``FakeWorldWithAssertions`` with paths limited to length 10.
    """

    def setUp(self):
        super().setUp()
        self.world = FakeWorldWithAssertions()
        self.walker = ActionSpaceWalker(self.world, max_path_length=10)

    def test_get_logical_forms_with_agenda(self):
        black_action = '<o,o> -> black'
        forms_with_black = self.walker.get_logical_forms_with_agenda([black_action])
        # Every logical form within the path limit that uses ``black``.
        assert len(forms_with_black) == 25

        # Asking for a single logical form should give the shortest complete one with ``black``.
        top_black_form = self.walker.get_logical_forms_with_agenda([black_action], 1)[0]
        assert top_black_form == '(object_exists (black all_objects))'

        three_item_agenda = [black_action,
                             '<o,o> -> triangle',
                             '<o,o> -> touch_wall']
        forms_with_all_three = self.walker.get_logical_forms_with_agenda(three_item_agenda)
        # We expect exactly the orderings of the three functions; the path length limit of 10
        # set in ``setUp`` leaves no room for any function to be repeated.
        expected_orderings = {
            '(object_exists (black (triangle (touch_wall all_objects))))',
            '(object_exists (black (touch_wall (triangle all_objects))))',
            '(object_exists (triangle (black (touch_wall all_objects))))',
            '(object_exists (triangle (touch_wall (black all_objects))))',
            '(object_exists (touch_wall (black (triangle all_objects))))',
            '(object_exists (touch_wall (triangle (black all_objects))))',
        }
        assert set(forms_with_all_three) == expected_orderings

    def test_get_all_logical_forms(self):
        # The walker is expected to return logical forms ordered by length.
        shortest_ten = self.walker.get_all_logical_forms(max_num_logical_forms=10)
        assert shortest_ten[0] == '(object_exists all_objects)'
        expected_length_three = {
            '(object_exists (black all_objects))',
            '(object_exists (touch_wall all_objects))',
            '(object_exists (triangle all_objects))',
        }
        assert set(shortest_ten[1:4]) == expected_length_three
def search(tables_directory: str,
           input_examples_file: str,
           output_file: str,
           max_path_length: int,
           max_num_logical_forms: int,
           use_agenda: bool) -> None:
    """
    Searches for logical forms whose denotations match the target values of each example in
    ``input_examples_file`` and writes the results to ``output_file``.

    For every example, the output contains a line with the question id and utterance, an
    ``Agenda:`` line if ``use_agenda`` is set, and then at most ``max_num_logical_forms``
    matching logical forms (or ``NO LOGICAL FORMS FOUND!``), followed by a blank line.

    Parameters
    ----------
    tables_directory : ``str``
        Directory that the tagged table file paths are resolved against.
    input_examples_file : ``str``
        File with one example per line, parsed with ``wikitables_util.parse_example_line``.
    output_file : ``str``
        Path of the text file to write.
    max_path_length : ``int``
        Passed to ``ActionSpaceWalker`` as the limit on path length.
    max_num_logical_forms : ``int``
        Maximum number of matching logical forms written per example.
    use_agenda : ``bool``
        If set, only logical forms obtained from the world's agenda are considered.
    """
    # Read all the examples up front so that the input file is closed before the search starts.
    with open(input_examples_file) as input_file_pointer:
        data = [wikitables_util.parse_example_line(example_line)
                for example_line in input_file_pointer]
    tokenizer = WordTokenizer()
    with open(output_file, "w") as output_file_pointer:
        for instance_data in data:
            utterance = instance_data["question"]
            question_id = instance_data["id"]
            if utterance.startswith('"') and utterance.endswith('"'):
                utterance = utterance[1:-1]
            # For example: csv/200-csv/47.csv -> tagged/200-tagged/47.tagged
            table_file = instance_data["table_filename"].replace("csv", "tagged")
            # pylint: disable=protected-access
            target_list = [TableQuestionContext._normalize_string(value)
                           for value in instance_data["target_values"]]
            try:
                target_value_list = evaluator.to_value_list(target_list)
            except Exception:
                # Show the offending targets before propagating the original error.
                print(target_list, file=sys.stderr)
                raise
            tokenized_question = tokenizer.tokenize(utterance)
            table_file = f"{tables_directory}/{table_file}"
            context = TableQuestionContext.read_from_file(table_file, tokenized_question)
            world = WikiTablesVariableFreeWorld(context)
            walker = ActionSpaceWalker(world, max_path_length=max_path_length)
            correct_logical_forms = []
            print(f"{question_id} {utterance}", file=output_file_pointer)
            if use_agenda:
                agenda = world.get_agenda()
                print(f"Agenda: {agenda}", file=output_file_pointer)
                all_logical_forms = walker.get_logical_forms_with_agenda(agenda=agenda,
                                                                         max_num_logical_forms=10000)
            else:
                all_logical_forms = walker.get_all_logical_forms(max_num_logical_forms=10000)
            for logical_form in all_logical_forms:
                try:
                    denotation = world.execute(logical_form)
                except ExecutionError:
                    print(f"Failed to execute: {logical_form}", file=sys.stderr)
                    continue
                if isinstance(denotation, list):
                    denotation_list = [str(denotation_item) for denotation_item in denotation]
                else:
                    # For numbers and dates
                    denotation_list = [str(denotation)]
                denotation_value_list = evaluator.to_value_list(denotation_list)
                if evaluator.check_denotation(target_value_list, denotation_value_list):
                    correct_logical_forms.append(logical_form)
            if not correct_logical_forms:
                print("NO LOGICAL FORMS FOUND!", file=output_file_pointer)
            for logical_form in correct_logical_forms[:max_num_logical_forms]:
                print(logical_form, file=output_file_pointer)
            print(file=output_file_pointer)
def search(
    tables_directory: str,
    data: JsonDict,
    output_path: str,
    max_path_length: int,
    max_num_logical_forms: int,
    use_agenda: bool,
    output_separate_files: bool,
    conservative_agenda: bool,
) -> None:
    """
    Searches for logical forms that evaluate to the target values of each instance in ``data``
    and writes them out.

    If ``output_separate_files`` is set, ``output_path`` is treated as a directory and every
    instance with at least one correct logical form gets its own ``{question_id}.gz`` file
    containing all of its correct logical forms. Otherwise ``output_path`` is a single text
    file containing, per instance, the question id and utterance, the agenda (if
    ``use_agenda``), and at most ``max_num_logical_forms`` correct logical forms (or
    ``NO LOGICAL FORMS FOUND!``), followed by a blank line.

    Parameters
    ----------
    tables_directory : ``str``
        Directory that the tagged table file paths are resolved against.
    data : ``JsonDict``
        Iterable of instances with ``question``, ``id``, ``table_filename`` and
        ``target_values`` keys.
    output_path : ``str``
        Output file, or output directory when ``output_separate_files`` is set.
    max_path_length : ``int``
        Passed to ``ActionSpaceWalker`` as the limit on path length.
    max_num_logical_forms : ``int``
        Cap on logical forms written per instance; only applied to the single-file output.
    use_agenda : ``bool``
        If set, logical forms are obtained from the world's agenda.
    output_separate_files : ``bool``
        Selects between the per-question gzip files and the single text file.
    conservative_agenda : ``bool``
        Passed to ``world.get_agenda``; partial agenda matches are allowed only when this is
        not set.
    """
    print(f"Starting search with {len(data)} instances", file=sys.stderr)
    language_logger = logging.getLogger("allennlp.semparse.domain_languages.wikitables_language")
    language_logger.setLevel(logging.ERROR)
    tokenizer = SpacyTokenizer()
    if output_separate_files:
        # ``exist_ok`` avoids the race between checking for the directory and creating it.
        os.makedirs(output_path, exist_ok=True)
    # The shared file is only needed (and only opened) in the single-file mode.
    shared_output_pointer = None if output_separate_files else open(output_path, "w")
    try:
        for instance_data in data:
            utterance = instance_data["question"]
            question_id = instance_data["id"]
            if utterance.startswith('"') and utterance.endswith('"'):
                utterance = utterance[1:-1]
            # For example: csv/200-csv/47.csv -> tagged/200-tagged/47.tagged
            table_file = instance_data["table_filename"].replace("csv", "tagged")
            target_list = instance_data["target_values"]
            tokenized_question = tokenizer.tokenize(utterance)
            table_file = f"{tables_directory}/{table_file}"
            context = TableQuestionContext.read_from_file(table_file, tokenized_question)
            world = WikiTablesLanguage(context)
            walker = ActionSpaceWalker(world, max_path_length=max_path_length)
            if use_agenda:
                agenda = world.get_agenda(conservative=conservative_agenda)
                allow_partial_match = not conservative_agenda
                all_logical_forms = walker.get_logical_forms_with_agenda(
                    agenda=agenda, max_num_logical_forms=10000, allow_partial_match=allow_partial_match
                )
            else:
                all_logical_forms = walker.get_all_logical_forms(max_num_logical_forms=10000)
            correct_logical_forms = [
                logical_form
                for logical_form in all_logical_forms
                if world.evaluate_logical_form(logical_form, target_list)
            ]
            if output_separate_files:
                # Nothing is written for questions without any correct logical form.
                if correct_logical_forms:
                    with gzip.open(f"{output_path}/{question_id}.gz", "wt") as gzip_pointer:
                        for logical_form in correct_logical_forms:
                            print(logical_form, file=gzip_pointer)
                continue
            print(f"{question_id} {utterance}", file=shared_output_pointer)
            if use_agenda:
                print(f"Agenda: {agenda}", file=shared_output_pointer)
            if not correct_logical_forms:
                print("NO LOGICAL FORMS FOUND!", file=shared_output_pointer)
            for logical_form in correct_logical_forms[:max_num_logical_forms]:
                print(logical_form, file=shared_output_pointer)
            print(file=shared_output_pointer)
    finally:
        # Close the shared file even if processing one of the instances fails.
        if shared_output_pointer is not None:
            shared_output_pointer.close()
def search(tables_directory: str,
           input_examples_file: str,
           output_path: str,
           max_path_length: int,
           max_num_logical_forms: int,
           use_agenda: bool,
           output_separate_files: bool) -> None:
    """
    Searches for logical forms that evaluate to the target values of each example in
    ``input_examples_file`` and writes them out.

    If ``output_separate_files`` is set, ``output_path`` is treated as a directory and every
    example with at least one correct logical form gets its own ``{question_id}.gz`` file
    containing all of its correct logical forms. Otherwise ``output_path`` is a single text
    file containing, per example, the question id and utterance, the agenda (if
    ``use_agenda``), and at most ``max_num_logical_forms`` correct logical forms (or
    ``NO LOGICAL FORMS FOUND!``), followed by a blank line.

    Parameters
    ----------
    tables_directory : ``str``
        Directory that the tagged table file paths are resolved against.
    input_examples_file : ``str``
        File with one example per line, parsed with ``wikitables_util.parse_example_line``.
    output_path : ``str``
        Output file, or output directory when ``output_separate_files`` is set.
    max_path_length : ``int``
        Passed to ``ActionSpaceWalker`` as the limit on path length.
    max_num_logical_forms : ``int``
        Cap on logical forms written per example; only applied to the single-file output.
    use_agenda : ``bool``
        If set, logical forms are obtained from the world's agenda.
    output_separate_files : ``bool``
        Selects between the per-question gzip files and the single text file.
    """
    # Read all the examples up front so that the input file is closed before the search starts.
    with open(input_examples_file) as input_file_pointer:
        data = [
            wikitables_util.parse_example_line(example_line)
            for example_line in input_file_pointer
        ]
    tokenizer = WordTokenizer()
    if output_separate_files:
        # ``exist_ok`` avoids the race between checking for the directory and creating it.
        os.makedirs(output_path, exist_ok=True)
    # The shared file is only needed (and only opened) in the single-file mode.
    shared_output_pointer = None if output_separate_files else open(output_path, "w")
    try:
        for instance_data in data:
            utterance = instance_data["question"]
            question_id = instance_data["id"]
            if utterance.startswith('"') and utterance.endswith('"'):
                utterance = utterance[1:-1]
            # For example: csv/200-csv/47.csv -> tagged/200-tagged/47.tagged
            table_file = instance_data["table_filename"].replace("csv", "tagged")
            target_list = instance_data["target_values"]
            tokenized_question = tokenizer.tokenize(utterance)
            table_file = f"{tables_directory}/{table_file}"
            context = TableQuestionContext.read_from_file(table_file, tokenized_question)
            world = WikiTablesVariableFreeWorld(context)
            walker = ActionSpaceWalker(world, max_path_length=max_path_length)
            if use_agenda:
                agenda = world.get_agenda()
                all_logical_forms = walker.get_logical_forms_with_agenda(
                    agenda=agenda, max_num_logical_forms=10000)
            else:
                all_logical_forms = walker.get_all_logical_forms(
                    max_num_logical_forms=10000)
            correct_logical_forms = [
                logical_form
                for logical_form in all_logical_forms
                if world.evaluate_logical_form(logical_form, target_list)
            ]
            if output_separate_files:
                # Nothing is written for questions without any correct logical form.
                if correct_logical_forms:
                    with gzip.open(f"{output_path}/{question_id}.gz", "wt") as gzip_pointer:
                        for logical_form in correct_logical_forms:
                            print(logical_form, file=gzip_pointer)
                continue
            print(f"{question_id} {utterance}", file=shared_output_pointer)
            if use_agenda:
                print(f"Agenda: {agenda}", file=shared_output_pointer)
            if not correct_logical_forms:
                print("NO LOGICAL FORMS FOUND!", file=shared_output_pointer)
            for logical_form in correct_logical_forms[:max_num_logical_forms]:
                print(logical_form, file=shared_output_pointer)
            print(file=shared_output_pointer)
    finally:
        # Close the shared file even if processing one of the examples fails.
        if shared_output_pointer is not None:
            shared_output_pointer.close()
def search(tables_directory: str,
           input_examples_file: str,
           output_path: str,
           max_path_length: int,
           max_num_logical_forms: int,
           use_agenda: bool,
           output_separate_files: bool) -> None:
    """
    Finds, for each example in ``input_examples_file``, the logical forms that evaluate to the
    example's target values, and writes them to ``output_path``.

    With ``output_separate_files`` set, ``output_path`` is a directory that receives one
    ``{question_id}.gz`` file per example that has at least one correct logical form; each
    file lists all correct logical forms of that example. Without it, ``output_path`` is one
    text file with, per example: the question id and utterance, the agenda (if ``use_agenda``),
    and at most ``max_num_logical_forms`` correct logical forms (or ``NO LOGICAL FORMS
    FOUND!``), followed by a blank line.

    Parameters
    ----------
    tables_directory : ``str``
        Directory that the tagged table file paths are resolved against.
    input_examples_file : ``str``
        File with one example per line, parsed with ``wikitables_util.parse_example_line``.
    output_path : ``str``
        Output file, or output directory when ``output_separate_files`` is set.
    max_path_length : ``int``
        Passed to ``ActionSpaceWalker`` as the limit on path length.
    max_num_logical_forms : ``int``
        Cap on logical forms written per example; only applied to the single-file output.
    use_agenda : ``bool``
        If set, logical forms are obtained from the world's agenda.
    output_separate_files : ``bool``
        Selects between the per-question gzip files and the single text file.
    """
    # The examples are read eagerly so the input handle can be released right away.
    with open(input_examples_file) as examples_pointer:
        data = [wikitables_util.parse_example_line(example_line)
                for example_line in examples_pointer]
    tokenizer = WordTokenizer()
    if output_separate_files:
        # ``exist_ok`` makes this safe when the directory is already there.
        os.makedirs(output_path, exist_ok=True)
        single_file_pointer = None
    else:
        single_file_pointer = open(output_path, "w")
    try:
        for instance_data in data:
            utterance = instance_data["question"]
            question_id = instance_data["id"]
            if utterance.startswith('"') and utterance.endswith('"'):
                utterance = utterance[1:-1]
            # For example: csv/200-csv/47.csv -> tagged/200-tagged/47.tagged
            table_file = instance_data["table_filename"].replace("csv", "tagged")
            target_list = instance_data["target_values"]
            tokenized_question = tokenizer.tokenize(utterance)
            table_file = f"{tables_directory}/{table_file}"
            context = TableQuestionContext.read_from_file(table_file, tokenized_question)
            world = WikiTablesVariableFreeWorld(context)
            walker = ActionSpaceWalker(world, max_path_length=max_path_length)
            if use_agenda:
                agenda = world.get_agenda()
                all_logical_forms = walker.get_logical_forms_with_agenda(agenda=agenda,
                                                                         max_num_logical_forms=10000)
            else:
                all_logical_forms = walker.get_all_logical_forms(max_num_logical_forms=10000)
            correct_logical_forms = [logical_form for logical_form in all_logical_forms
                                     if world.evaluate_logical_form(logical_form, target_list)]
            if single_file_pointer is None:
                # Separate-file mode: examples without correct logical forms produce no file.
                if correct_logical_forms:
                    with gzip.open(f"{output_path}/{question_id}.gz", "wt") as gzip_pointer:
                        for logical_form in correct_logical_forms:
                            print(logical_form, file=gzip_pointer)
            else:
                print(f"{question_id} {utterance}", file=single_file_pointer)
                if use_agenda:
                    print(f"Agenda: {agenda}", file=single_file_pointer)
                if not correct_logical_forms:
                    print("NO LOGICAL FORMS FOUND!", file=single_file_pointer)
                for logical_form in correct_logical_forms[:max_num_logical_forms]:
                    print(logical_form, file=single_file_pointer)
                print(file=single_file_pointer)
    finally:
        # Release the single output file even when an example raises.
        if single_file_pointer is not None:
            single_file_pointer.close()
class ActionSpaceWalkerTest(AllenNlpTestCase):
    """
    Exercises ``ActionSpaceWalker`` (maximum path length 10) on a ``FakeWorldWithAssertions``.
    """

    def setUp(self):
        super().setUp()
        self.world = FakeWorldWithAssertions()
        self.walker = ActionSpaceWalker(self.world, max_path_length=10)

    def test_get_logical_forms_with_agenda(self):
        black_only = self.walker.get_logical_forms_with_agenda(['<o,o> -> black'])
        # Number of logical forms within the path limit that mention ``black``.
        assert len(black_only) == 25

        # Limiting the output to one logical form should yield the shortest complete one.
        single_result = self.walker.get_logical_forms_with_agenda(['<o,o> -> black'], 1)
        assert single_result[0] == '(object_exists (black all_objects))'

        agenda = ['<o,o> -> black', '<o,o> -> triangle', '<o,o> -> touch_wall']
        all_three = self.walker.get_logical_forms_with_agenda(agenda)
        # With paths capped at length 10 no function can occur twice, so the result is
        # exactly the six orderings of the three agenda functions.
        assert set(all_three) == {
            '(object_exists (black (triangle (touch_wall all_objects))))',
            '(object_exists (black (touch_wall (triangle all_objects))))',
            '(object_exists (triangle (black (touch_wall all_objects))))',
            '(object_exists (triangle (touch_wall (black all_objects))))',
            '(object_exists (touch_wall (black (triangle all_objects))))',
            '(object_exists (touch_wall (triangle (black all_objects))))',
        }

    def test_get_all_logical_forms(self):
        # Results of ``get_all_logical_forms`` should come back sorted by length.
        first_ten = self.walker.get_all_logical_forms(max_num_logical_forms=10)
        assert first_ten[0] == '(object_exists all_objects)'
        next_three = set(first_ten[1:4])
        assert next_three == {
            '(object_exists (black all_objects))',
            '(object_exists (touch_wall all_objects))',
            '(object_exists (triangle all_objects))',
        }
def setUp(self):
    """
    Runs the base class set-up, then creates the fake world and a walker over it with
    paths limited to length 10.
    """
    super(ActionSpaceWalkerTest, self).setUp()
    fake_world = FakeWorldWithAssertions()
    self.world = fake_world
    self.walker = ActionSpaceWalker(fake_world, max_path_length=10)
class ActionSpaceWalkerTest(AllenNlpTestCase):
    """
    Tests for ``ActionSpaceWalker`` (maximum path length 10) on a ``FakeWorldWithAssertions``,
    including the warnings it logs for empty agendas and unmatched agenda items.
    """

    def setUp(self):
        super().setUp()
        self.world = FakeWorldWithAssertions()
        self.walker = ActionSpaceWalker(self.world, max_path_length=10)

    def test_get_logical_forms_with_agenda(self):
        forms_with_black = self.walker.get_logical_forms_with_agenda(['<o,o> -> black'])
        # Count of all logical forms within the path limit that use ``black``.
        assert len(forms_with_black) == 25

        # Requesting one logical form should return the shortest complete one with ``black``.
        best_black_form = self.walker.get_logical_forms_with_agenda(['<o,o> -> black'], 1)[0]
        assert best_black_form == '(object_exists (black all_objects))'

        agenda = ['<o,o> -> black', '<o,o> -> triangle', '<o,o> -> touch_wall']
        forms_with_all_three = self.walker.get_logical_forms_with_agenda(agenda)
        # Only the orderings of the three functions are expected: the path length limit
        # of 10 from ``setUp`` rules out repeating any of them.
        assert set(forms_with_all_three) == {
            '(object_exists (black (triangle (touch_wall all_objects))))',
            '(object_exists (black (touch_wall (triangle all_objects))))',
            '(object_exists (triangle (black (touch_wall all_objects))))',
            '(object_exists (triangle (touch_wall (black all_objects))))',
            '(object_exists (touch_wall (black (triangle all_objects))))',
            '(object_exists (touch_wall (triangle (black all_objects))))',
        }

    def test_get_logical_forms_with_empty_agenda_returns_all_logical_forms(self):
        with self.assertLogs("allennlp.semparse.action_space_walker") as log:
            all_forms = self.walker.get_logical_forms_with_agenda([])
            assert set(all_forms[:4]) == {
                '(object_exists all_objects)',
                '(object_exists (black all_objects))',
                '(object_exists (touch_wall all_objects))',
                '(object_exists (triangle all_objects))',
            }
        expected_warning = ("WARNING:allennlp.semparse.action_space_walker:"
                            "Agenda is empty! Returning all paths instead.")
        self.assertEqual(log.output, [expected_warning])

    def test_get_logical_forms_with_agenda_ignores_null_set_item(self):
        with self.assertLogs("allennlp.semparse.action_space_walker") as log:
            agenda_with_yellow = ['<o,o> -> yellow',
                                  '<o,o> -> black',
                                  '<o,o> -> triangle',
                                  '<o,o> -> touch_wall']
            resulting_forms = self.walker.get_logical_forms_with_agenda(agenda_with_yellow)
            # ``yellow`` gets dropped, so we are left with the orderings of the other three
            # functions; the path length limit of 10 rules out repetitions.
            assert set(resulting_forms) == {
                '(object_exists (black (triangle (touch_wall all_objects))))',
                '(object_exists (black (touch_wall (triangle all_objects))))',
                '(object_exists (triangle (black (touch_wall all_objects))))',
                '(object_exists (triangle (touch_wall (black all_objects))))',
                '(object_exists (touch_wall (black (triangle all_objects))))',
                '(object_exists (touch_wall (triangle (black all_objects))))',
            }
        expected_warning = ("WARNING:allennlp.semparse.action_space_walker:"
                            "<o,o> -> yellow is not in any of the paths found! Ignoring it.")
        self.assertEqual(log.output, [expected_warning])

    def test_get_all_logical_forms(self):
        # ``get_all_logical_forms`` is expected to order its output by length.
        shortest_ten = self.walker.get_all_logical_forms(max_num_logical_forms=10)
        assert shortest_ten[0] == '(object_exists all_objects)'
        assert set(shortest_ten[1:4]) == {
            '(object_exists (black all_objects))',
            '(object_exists (touch_wall all_objects))',
            '(object_exists (triangle all_objects))',
        }
def process_data(
    input_file: str,
    output_file: str,
    max_path_length: int,
    max_num_logical_forms: int,
    ignore_agenda: bool,
    write_sequences: bool,
) -> None:
    """
    Reads an NLVR dataset from ``input_file`` and writes one JSON object per line to
    ``output_file``, containing the sentence, the worlds, the labels, and the logical forms
    found to be correct and incorrect for that sentence. The output will contain at most
    ``max_num_logical_forms`` logical forms each in both correct and incorrect lists.

    A logical form counts as correct if executing it in every world gives that world's label.

    The output format of each line is:
        ``{"id": str, "sentence": str, "correct_logical_forms": List[str],
        "incorrect_logical_forms": List[str], "worlds": ..., "labels": List[str]}``
    If ``write_sequences`` is set, the two logical form keys are replaced by
    ``"correct_sequences"`` and ``"incorrect_sequences"``, holding the action sequences of
    the logical forms instead.

    Parameters
    ----------
    input_file : ``str``
        File with one instance per line, parsed with ``read_json_line``.
    output_file : ``str``
        Path of the JSON-lines file to write.
    max_path_length : ``int``
        Passed to ``ActionSpaceWalker`` as the limit on path length.
    max_num_logical_forms : ``int``
        Maximum size of each of the correct and incorrect lists.
    ignore_agenda : ``bool``
        If set, the 1000 shortest logical forms are considered instead of the ones obtained
        from the sentence's agenda.
    write_sequences : ``bool``
        If set, action sequences are written instead of logical forms.
    """
    processed_data: JsonDict = []
    # We can instantiate the ``ActionSpaceWalker`` with any world because the action space is the
    # same for all the ``NlvrLanguage`` objects. It is just the execution that differs.
    walker = ActionSpaceWalker(NlvrLanguage({}), max_path_length=max_path_length)
    with open(input_file) as data_file:
        for line in data_file:
            instance_id, sentence, structured_reps, label_strings = read_json_line(line)
            worlds = []
            for structured_representation in structured_reps:
                boxes = {
                    Box(object_list, box_id)
                    for box_id, object_list in enumerate(structured_representation)
                }
                worlds.append(NlvrLanguage(boxes))
            labels = [label_string == "true" for label_string in label_strings]
            correct_logical_forms = []
            incorrect_logical_forms = []
            if ignore_agenda:
                # Get 1000 shortest logical forms.
                logical_forms = walker.get_all_logical_forms(max_num_logical_forms=1000)
            else:
                # TODO (pradeep): Assuming all worlds give the same agenda.
                sentence_agenda = worlds[0].get_agenda_for_sentence(sentence)
                logical_forms = walker.get_logical_forms_with_agenda(
                    sentence_agenda, max_num_logical_forms * 10
                )
            for logical_form in logical_forms:
                if all(
                    world.execute(logical_form) == label for world, label in zip(worlds, labels)
                ):
                    # Strict comparison, so that a list never grows beyond the limit.
                    if len(correct_logical_forms) < max_num_logical_forms:
                        correct_logical_forms.append(logical_form)
                else:
                    if len(incorrect_logical_forms) < max_num_logical_forms:
                        incorrect_logical_forms.append(logical_form)
                if (
                    len(correct_logical_forms) >= max_num_logical_forms
                    and len(incorrect_logical_forms) >= max_num_logical_forms
                ):
                    # Both lists are full; no need to look at more logical forms.
                    break
            if write_sequences:
                correct_sequences = [
                    worlds[0].logical_form_to_action_sequence(logical_form)
                    for logical_form in correct_logical_forms
                ]
                incorrect_sequences = [
                    worlds[0].logical_form_to_action_sequence(logical_form)
                    for logical_form in incorrect_logical_forms
                ]
                processed_data.append(
                    {
                        "id": instance_id,
                        "sentence": sentence,
                        "correct_sequences": correct_sequences,
                        "incorrect_sequences": incorrect_sequences,
                        "worlds": structured_reps,
                        "labels": label_strings,
                    }
                )
            else:
                processed_data.append(
                    {
                        "id": instance_id,
                        "sentence": sentence,
                        "correct_logical_forms": correct_logical_forms,
                        "incorrect_logical_forms": incorrect_logical_forms,
                        "worlds": structured_reps,
                        "labels": label_strings,
                    }
                )
    # The ``with`` statement closes the file, so no explicit ``close`` is needed.
    with open(output_file, "w") as outfile:
        for instance_processed_data in processed_data:
            json.dump(instance_processed_data, outfile)
            outfile.write("\n")
def setUp(self):
    """
    Runs the base class set-up, then creates a fake language with ``bool`` as its only
    start type and a walker over it with paths limited to length 10.
    """
    super(ActionSpaceWalkerTest, self).setUp()
    fake_language = FakeLanguageWithAssertions(start_types={bool})
    self.world = fake_language
    self.walker = ActionSpaceWalker(fake_language, max_path_length=10)
class ActionSpaceWalkerTest(AllenNlpTestCase):
    """
    Tests for ``ActionSpaceWalker`` (maximum path length 10) on a
    ``FakeLanguageWithAssertions``, covering full and partial agenda matches and the
    warnings logged for empty or unmatched agendas.
    """

    def setUp(self):
        super().setUp()
        self.world = FakeLanguageWithAssertions(start_types={bool})
        self.walker = ActionSpaceWalker(self.world, max_path_length=10)

    def test_get_logical_forms_with_agenda(self):
        black_action = '<Set[Object]:Set[Object]> -> black'
        forms_with_black = self.walker.get_logical_forms_with_agenda([black_action])
        # Count of all logical forms within the path limit that use ``black``.
        assert len(forms_with_black) == 25

        # Requesting one logical form should return the shortest complete one with ``black``.
        best_black_form = self.walker.get_logical_forms_with_agenda([black_action], 1)[0]
        assert best_black_form == '(object_exists (black all_objects))'

        three_item_agenda = [black_action,
                             '<Set[Object]:Set[Object]> -> triangle',
                             '<Set[Object]:Set[Object]> -> touch_wall']
        forms_with_all_three = self.walker.get_logical_forms_with_agenda(three_item_agenda)
        # Only the orderings of the three functions are expected: the path length limit
        # of 10 from ``setUp`` rules out repeating any of them.
        assert set(forms_with_all_three) == {
            '(object_exists (black (triangle (touch_wall all_objects))))',
            '(object_exists (black (touch_wall (triangle all_objects))))',
            '(object_exists (triangle (black (touch_wall all_objects))))',
            '(object_exists (triangle (touch_wall (black all_objects))))',
            '(object_exists (touch_wall (black (triangle all_objects))))',
            '(object_exists (touch_wall (triangle (black all_objects))))',
        }

    def test_get_logical_forms_with_agenda_and_partial_match(self):
        black_action = '<Set[Object]:Set[Object]> -> black'
        forms_with_black = self.walker.get_logical_forms_with_agenda([black_action])
        # Count of all logical forms within the path limit that use ``black``.
        assert len(forms_with_black) == 25

        # Requesting one logical form should return the shortest complete one with ``black``.
        best_black_form = self.walker.get_logical_forms_with_agenda([black_action], 1)[0]
        assert best_black_form == '(object_exists (black all_objects))'

        three_item_agenda = [black_action,
                             '<Set[Object]:Set[Object]> -> triangle',
                             '<Set[Object]:Set[Object]> -> touch_wall']
        partial_match_forms = self.walker.get_logical_forms_with_agenda(three_item_agenda,
                                                                        allow_partial_match=True)
        # First come the six orderings that use all three agenda functions.
        assert set(partial_match_forms[:6]) == {
            '(object_exists (black (triangle (touch_wall all_objects))))',
            '(object_exists (black (touch_wall (triangle all_objects))))',
            '(object_exists (triangle (black (touch_wall all_objects))))',
            '(object_exists (triangle (touch_wall (black all_objects))))',
            '(object_exists (touch_wall (black (triangle all_objects))))',
            '(object_exists (touch_wall (triangle (black all_objects))))',
        }
        # They are followed by the six shortest logical forms using two agenda items.
        assert set(partial_match_forms[6:12]) == {
            '(object_exists (black (triangle all_objects)))',
            '(object_exists (black (touch_wall all_objects)))',
            '(object_exists (triangle (black all_objects)))',
            '(object_exists (triangle (touch_wall all_objects)))',
            '(object_exists (touch_wall (black all_objects)))',
            '(object_exists (touch_wall (triangle all_objects)))',
        }
        # Longer logical forms with two agenda items fill the positions up to 30, after
        # which the three shortest logical forms with a single agenda item show up.
        assert set(partial_match_forms[30:33]) == {
            '(object_exists (black all_objects))',
            '(object_exists (triangle all_objects))',
            '(object_exists (touch_wall all_objects))',
        }

    def test_get_logical_forms_with_empty_agenda_returns_all_logical_forms(self):
        with self.assertLogs("allennlp.semparse.action_space_walker") as log:
            all_forms = self.walker.get_logical_forms_with_agenda([], allow_partial_match=True)
            assert set(all_forms[:4]) == {
                '(object_exists all_objects)',
                '(object_exists (black all_objects))',
                '(object_exists (touch_wall all_objects))',
                '(object_exists (triangle all_objects))',
            }
        expected_warning = ("WARNING:allennlp.semparse.action_space_walker:"
                            "Agenda is empty! Returning all paths instead.")
        self.assertEqual(log.output, [expected_warning])

    def test_get_logical_forms_with_unmatched_agenda_returns_all_logical_forms(self):
        agenda = ['<Set[Object]:Set[Object]> -> purple']
        with self.assertLogs("allennlp.semparse.action_space_walker") as log:
            all_forms = self.walker.get_logical_forms_with_agenda(agenda,
                                                                  allow_partial_match=True)
            assert set(all_forms[:4]) == {
                '(object_exists all_objects)',
                '(object_exists (black all_objects))',
                '(object_exists (touch_wall all_objects))',
                '(object_exists (triangle all_objects))',
            }
        expected_warning = ("WARNING:allennlp.semparse.action_space_walker:"
                            "Agenda items not in any of the paths found. Returning all paths.")
        self.assertEqual(log.output, [expected_warning])
        # Without partial matching, an agenda that matches nothing gives no logical forms.
        no_forms = self.walker.get_logical_forms_with_agenda(agenda, allow_partial_match=False)
        assert no_forms == []

    def test_get_logical_forms_with_agenda_ignores_null_set_item(self):
        with self.assertLogs("allennlp.semparse.action_space_walker") as log:
            agenda_with_yellow = ['<Set[Object]:Set[Object]> -> yellow',
                                  '<Set[Object]:Set[Object]> -> black',
                                  '<Set[Object]:Set[Object]> -> triangle',
                                  '<Set[Object]:Set[Object]> -> touch_wall']
            resulting_forms = self.walker.get_logical_forms_with_agenda(agenda_with_yellow)
            # ``yellow`` gets dropped, so we are left with the orderings of the other three
            # functions; the path length limit of 10 rules out repetitions.
            assert set(resulting_forms) == {
                '(object_exists (black (triangle (touch_wall all_objects))))',
                '(object_exists (black (touch_wall (triangle all_objects))))',
                '(object_exists (triangle (black (touch_wall all_objects))))',
                '(object_exists (triangle (touch_wall (black all_objects))))',
                '(object_exists (touch_wall (black (triangle all_objects))))',
                '(object_exists (touch_wall (triangle (black all_objects))))',
            }
        expected_warning = ("WARNING:allennlp.semparse.action_space_walker:"
                            "<Set[Object]:Set[Object]> -> yellow is not in any of the paths found! "
                            "Ignoring it.")
        self.assertEqual(log.output, [expected_warning])

    def test_get_all_logical_forms(self):
        # ``get_all_logical_forms`` is expected to order its output by length.
        shortest_ten = self.walker.get_all_logical_forms(max_num_logical_forms=10)
        assert shortest_ten[0] == '(object_exists all_objects)'
        assert set(shortest_ten[1:4]) == {
            '(object_exists (black all_objects))',
            '(object_exists (touch_wall all_objects))',
            '(object_exists (triangle all_objects))',
        }
def process_data(input_file: str,
                 output_file: str,
                 max_path_length: int,
                 max_num_logical_forms: int,
                 ignore_agenda: bool,
                 write_sequences: bool) -> None:
    """
    Reads an NLVR dataset from ``input_file`` and writes one JSON object per line to
    ``output_file``, containing the sentence, the worlds, the labels, and the logical forms
    found to be correct and incorrect for that sentence. The output will contain at most
    ``max_num_logical_forms`` logical forms each in both correct and incorrect lists.

    A logical form counts as correct if executing it in every world gives that world's label.

    The ``ActionSpaceWalker`` is cached in a pickle file in the current working directory,
    named after ``max_path_length``, and is read from there if the file exists.

    The output format of each line is:
        ``{"id": str, "sentence": str, "correct_logical_forms": List[str],
        "incorrect_logical_forms": List[str], "worlds": ..., "labels": List[str]}``
    If ``write_sequences`` is set, the two logical form keys are replaced by
    ``"correct_sequences"`` and ``"incorrect_sequences"``, holding the action sequences of
    the logical forms instead.

    Parameters
    ----------
    input_file : ``str``
        File with one instance per line, parsed with ``read_json_line``.
    output_file : ``str``
        Path of the JSON-lines file to write.
    max_path_length : ``int``
        Passed to ``ActionSpaceWalker`` as the limit on path length.
    max_num_logical_forms : ``int``
        Maximum size of each of the correct and incorrect lists.
    ignore_agenda : ``bool``
        If set, the 1000 shortest logical forms are considered instead of the ones obtained
        from the sentence's agenda.
    write_sequences : ``bool``
        If set, action sequences are written instead of logical forms.
    """
    processed_data: JsonDict = []
    # We can instantiate the ``ActionSpaceWalker`` with any world because the action space is the
    # same for all the ``NlvrWorlds``. It is just the execution that differs.
    serialized_walker_path = f"serialized_action_space_walker_pl={max_path_length}.pkl"
    if os.path.isfile(serialized_walker_path):
        print("Reading walker from serialized file", file=sys.stderr)
        # NOTE: unpickling can execute arbitrary code. This is only safe as long as the
        # cache file was written by this script and is not under anybody else's control.
        with open(serialized_walker_path, "rb") as serialized_walker_file:
            walker = pickle.load(serialized_walker_file)
    else:
        walker = ActionSpaceWalker(NlvrWorld({}), max_path_length=max_path_length)
        # Closing the file explicitly makes sure the cache is completely written to disk.
        with open(serialized_walker_path, "wb") as serialized_walker_file:
            pickle.dump(walker, serialized_walker_file)
    with open(input_file) as data_file:
        for line in data_file:
            instance_id, sentence, structured_reps, label_strings = read_json_line(line)
            worlds = [NlvrWorld(structured_rep) for structured_rep in structured_reps]
            labels = [label_string == "true" for label_string in label_strings]
            correct_logical_forms = []
            incorrect_logical_forms = []
            if ignore_agenda:
                # Get 1000 shortest logical forms.
                logical_forms = walker.get_all_logical_forms(max_num_logical_forms=1000)
            else:
                # TODO (pradeep): Assuming all worlds give the same agenda.
                sentence_agenda = worlds[0].get_agenda_for_sentence(sentence, add_paths_to_agenda=False)
                logical_forms = walker.get_logical_forms_with_agenda(sentence_agenda,
                                                                     max_num_logical_forms * 10)
            for logical_form in logical_forms:
                if all(world.execute(logical_form) == label for world, label in zip(worlds, labels)):
                    # Strict comparison, so that a list never grows beyond the limit.
                    if len(correct_logical_forms) < max_num_logical_forms:
                        correct_logical_forms.append(logical_form)
                else:
                    if len(incorrect_logical_forms) < max_num_logical_forms:
                        incorrect_logical_forms.append(logical_form)
                if len(correct_logical_forms) >= max_num_logical_forms \
                   and len(incorrect_logical_forms) >= max_num_logical_forms:
                    # Both lists are full; no need to look at more logical forms.
                    break
            if write_sequences:
                parsed_correct_forms = [worlds[0].parse_logical_form(logical_form)
                                        for logical_form in correct_logical_forms]
                correct_sequences = [worlds[0].get_action_sequence(parsed_form)
                                     for parsed_form in parsed_correct_forms]
                parsed_incorrect_forms = [worlds[0].parse_logical_form(logical_form)
                                          for logical_form in incorrect_logical_forms]
                incorrect_sequences = [worlds[0].get_action_sequence(parsed_form)
                                       for parsed_form in parsed_incorrect_forms]
                processed_data.append({"id": instance_id,
                                       "sentence": sentence,
                                       "correct_sequences": correct_sequences,
                                       "incorrect_sequences": incorrect_sequences,
                                       "worlds": structured_reps,
                                       "labels": label_strings})
            else:
                processed_data.append({"id": instance_id,
                                       "sentence": sentence,
                                       "correct_logical_forms": correct_logical_forms,
                                       "incorrect_logical_forms": incorrect_logical_forms,
                                       "worlds": structured_reps,
                                       "labels": label_strings})
    # The ``with`` statement closes the file, so no explicit ``close`` is needed.
    with open(output_file, "w") as outfile:
        for instance_processed_data in processed_data:
            json.dump(instance_processed_data, outfile)
            outfile.write('\n')
class ActionSpaceWalkerTest(AllenNlpTestCase):
    """
    Checks ``ActionSpaceWalker`` (maximum path length 10) against a
    ``FakeWorldWithAssertions``, including the warnings logged for an empty agenda and for
    agenda items that match no path.
    """

    def setUp(self):
        super().setUp()
        self.world = FakeWorldWithAssertions()
        self.walker = ActionSpaceWalker(self.world, max_path_length=10)

    def test_get_logical_forms_with_agenda(self):
        black_only = self.walker.get_logical_forms_with_agenda(['<o,o> -> black'])
        # Number of logical forms within the path limit that mention ``black``.
        assert len(black_only) == 25

        # Limiting the output to one logical form should yield the shortest complete one.
        single_result = self.walker.get_logical_forms_with_agenda(['<o,o> -> black'], 1)
        assert single_result[0] == '(object_exists (black all_objects))'

        agenda = ['<o,o> -> black', '<o,o> -> triangle', '<o,o> -> touch_wall']
        all_three = self.walker.get_logical_forms_with_agenda(agenda)
        # With paths capped at length 10 no function can occur twice, so the result is
        # exactly the six orderings of the three agenda functions.
        assert set(all_three) == {
            '(object_exists (black (triangle (touch_wall all_objects))))',
            '(object_exists (black (touch_wall (triangle all_objects))))',
            '(object_exists (triangle (black (touch_wall all_objects))))',
            '(object_exists (triangle (touch_wall (black all_objects))))',
            '(object_exists (touch_wall (black (triangle all_objects))))',
            '(object_exists (touch_wall (triangle (black all_objects))))',
        }

    def test_get_logical_forms_with_empty_agenda_returns_all_logical_forms(self):
        with self.assertLogs("allennlp.semparse.action_space_walker") as log:
            unconstrained_forms = self.walker.get_logical_forms_with_agenda([])
            leading_four = set(unconstrained_forms[:4])
            assert leading_four == {
                '(object_exists all_objects)',
                '(object_exists (black all_objects))',
                '(object_exists (touch_wall all_objects))',
                '(object_exists (triangle all_objects))',
            }
        logged_lines = [
            "WARNING:allennlp.semparse.action_space_walker:"
            "Agenda is empty! Returning all paths instead."
        ]
        self.assertEqual(log.output, logged_lines)

    def test_get_logical_forms_with_agenda_ignores_null_set_item(self):
        with self.assertLogs("allennlp.semparse.action_space_walker") as log:
            agenda = ['<o,o> -> yellow',
                      '<o,o> -> black',
                      '<o,o> -> triangle',
                      '<o,o> -> touch_wall']
            forms_without_yellow = self.walker.get_logical_forms_with_agenda(agenda)
            # Once ``yellow`` is ignored, what remains are the orderings of the other three
            # functions, without repetitions because paths are limited to length 10.
            assert set(forms_without_yellow) == {
                '(object_exists (black (triangle (touch_wall all_objects))))',
                '(object_exists (black (touch_wall (triangle all_objects))))',
                '(object_exists (triangle (black (touch_wall all_objects))))',
                '(object_exists (triangle (touch_wall (black all_objects))))',
                '(object_exists (touch_wall (black (triangle all_objects))))',
                '(object_exists (touch_wall (triangle (black all_objects))))',
            }
        logged_lines = [
            "WARNING:allennlp.semparse.action_space_walker:"
            "<o,o> -> yellow is not in any of the paths found! Ignoring it."
        ]
        self.assertEqual(log.output, logged_lines)

    def test_get_all_logical_forms(self):
        # Results of ``get_all_logical_forms`` should come back sorted by length.
        first_ten = self.walker.get_all_logical_forms(max_num_logical_forms=10)
        assert first_ten[0] == '(object_exists all_objects)'
        next_three = set(first_ten[1:4])
        assert next_three == {
            '(object_exists (black all_objects))',
            '(object_exists (touch_wall all_objects))',
            '(object_exists (triangle all_objects))',
        }