def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
    time, to finalize predictions. We only transform the action string sequences into logical
    forms here.

    Reads ``"best_action_strings"``, ``"action_mapping"`` and ``"debug_info"`` from
    ``output_dict`` and adds two keys to it:

    ``"logical_form"``
        For every instance, one logical form per action sequence (``""`` for an empty sequence).
    ``"predicted_actions"``
        For every instance, one dict per step of its first action sequence with the keys
        ``"predicted_action"``, ``"considered_actions"``, ``"action_probabilities"`` and
        ``"question_attention"``.

    Returns the same (mutated) ``output_dict``.
    """
    best_action_strings = output_dict["best_action_strings"]
    # Instantiating an empty world for getting logical forms.
    world = NlvrLanguage(set())
    logical_forms = []
    for instance_action_sequences in best_action_strings:
        logical_forms.append([
            world.action_sequence_to_logical_form(action_strings) if action_strings else ""
            for action_strings in instance_action_sequences
        ])

    action_mapping = output_dict["action_mapping"]
    debug_infos = output_dict["debug_info"]
    batch_action_info = []
    for batch_index, (predicted_actions, debug_info) in enumerate(
            zip(best_action_strings, debug_infos)):
        # Only the first action sequence of an instance is described. An instance without any
        # sequence yields an empty list instead of raising ``IndexError``.
        first_sequence = predicted_actions[0] if predicted_actions else []
        instance_action_info = []
        for predicted_action, action_debug_info in zip(first_sequence, debug_info):
            # Ids equal to -1 are skipped; the remaining ids are translated through
            # ``action_mapping`` and sorted together with their probabilities.
            actions = sorted(
                (action_mapping[(batch_index, action)], probability)
                for action, probability in zip(action_debug_info["considered_actions"],
                                               action_debug_info["probabilities"])
                if action != -1)
            if actions:
                considered_actions, probabilities = zip(*actions)
            else:
                # ``zip(*[])`` yields nothing to unpack, so handle the empty case explicitly.
                considered_actions, probabilities = (), ()
            instance_action_info.append({
                "predicted_action": predicted_action,
                "considered_actions": considered_actions,
                "action_probabilities": probabilities,
                "question_attention": action_debug_info.get("question_attention", []),
            })
        batch_action_info.append(instance_action_info)

    output_dict["predicted_actions"] = batch_action_info
    output_dict["logical_form"] = logical_forms
    return output_dict
def setUp(self):
    """
    Builds the worlds used by the tests: one ``NlvrLanguage`` per line of the sample fixture
    file (``self.languages``) and one hand-written three-box world (``self.custom_language``).
    """
    super().setUp()
    test_filename = self.FIXTURES_ROOT / "data" / "nlvr" / "sample_ungrouped_data.jsonl"
    # The context manager closes the fixture file as soon as it has been parsed.
    with open(test_filename) as data_file:
        data = [json.loads(line)["structured_rep"] for line in data_file]
    box_lists = [
        [Box(object_reps, i) for i, object_reps in enumerate(box_rep)]
        for box_rep in data
    ]
    self.languages = [NlvrLanguage(boxes) for boxes in box_lists]
    # y_loc increases as we go down from top to bottom, and x_loc from left to right. That is,
    # the origin is at the top-left corner.
    custom_rep = [
        [
            {"y_loc": 79, "size": 20, "type": "triangle", "x_loc": 27, "color": "Yellow"},
            {"y_loc": 55, "size": 10, "type": "circle", "x_loc": 47, "color": "Black"},
        ],
        [
            {"y_loc": 44, "size": 30, "type": "square", "x_loc": 10, "color": "#0099ff"},
            {"y_loc": 74, "size": 30, "type": "square", "x_loc": 40, "color": "Yellow"},
        ],
        [
            {"y_loc": 60, "size": 10, "type": "triangle", "x_loc": 12, "color": "#0099ff"},
        ],
    ]
    self.custom_language = NlvrLanguage(
        [Box(object_rep, i) for i, object_rep in enumerate(custom_rep)])
def _create_grammar_state(
        self, world: NlvrLanguage, possible_actions: List[ProductionRule]) -> GrammarStatelet:
    """
    Builds the initial ``GrammarStatelet`` for ``world``.

    For every non-terminal returned by ``world.get_nonterminal_productions()``, the indices of
    its productions within ``possible_actions`` are looked up, the third element of each
    matching ``possible_actions`` entry is concatenated along dimension 0 and passed through
    ``self._action_embedder``. The resulting embeddings are stored twice (as the first and
    second tuple members) next to the list of action indices, under the ``"global"`` key.
    """
    # Position of every production rule string within `possible_actions`.
    rule_to_index = {rule[0]: position for position, rule in enumerate(possible_actions)}

    translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor,
                                                        List[int]]]] = {}
    # Each key is a non-terminal from the grammar, and each value holds all the valid
    # productions of that non-terminal.
    for nonterminal, productions in world.get_nonterminal_productions().items():
        rule_indices = [rule_to_index[production] for production in productions]
        # All actions in NLVR are global actions.
        rule_tensors, rule_ids = zip(
            *[(possible_actions[rule_index][2], rule_index) for rule_index in rule_indices])
        # Embedded representations of the global actions.
        embedded_rules = self._action_embedder(torch.cat(rule_tensors, dim=0))
        translated_valid_actions[nonterminal] = {
            "global": (embedded_rules, embedded_rules, list(rule_ids)),
        }

    return GrammarStatelet([START_SYMBOL], translated_valid_actions, world.is_nonterminal)
def process_data(
        input_file: str,
        output_file: str,
        max_path_length: int,
        max_num_logical_forms: int,
        ignore_agenda: bool,
        write_sequences: bool,
) -> None:
    """
    Reads an NLVR dataset from ``input_file`` and writes one JSON object per instance (one per
    line) to ``output_file``. Each object contains the sentence, the worlds, the labels, and the
    logical forms that are correct and incorrect with respect to all worlds of the instance.
    The output will contain at most ``max_num_logical_forms`` logical forms each in both correct
    and incorrect lists.

    The keys of every written object are ``"id"``, ``"sentence"``, ``"worlds"``, ``"labels"``
    and either ``"correct_logical_forms"`` / ``"incorrect_logical_forms"`` or, when
    ``write_sequences`` is true, ``"correct_sequences"`` / ``"incorrect_sequences"`` (the same
    logical forms converted to action sequences).

    If ``ignore_agenda`` is true, candidates come from ``walker.get_all_logical_forms``;
    otherwise they come from ``walker.get_logical_forms_with_agenda`` using the agenda of the
    first world.
    """
    processed_data: JsonDict = []
    # We can instantiate the ``ActionSpaceWalker`` with any world because the action space is the
    # same for all the ``NlvrLanguage`` objects. It is just the execution that differs.
    walker = ActionSpaceWalker(NlvrLanguage(set()), max_path_length=max_path_length)
    # The context manager guarantees that the input file is closed, even on errors.
    with open(input_file) as data_file:
        for line in data_file:
            instance_id, sentence, structured_reps, label_strings = read_json_line(line)
            worlds = []
            for structured_representation in structured_reps:
                boxes = {
                    Box(object_list, box_id)
                    for box_id, object_list in enumerate(structured_representation)
                }
                worlds.append(NlvrLanguage(boxes))
            labels = [label_string == "true" for label_string in label_strings]
            correct_logical_forms = []
            incorrect_logical_forms = []
            if ignore_agenda:
                # Get 1000 shortest logical forms.
                logical_forms = walker.get_all_logical_forms(max_num_logical_forms=1000)
            else:
                # TODO (pradeep): Assuming all worlds give the same agenda.
                sentence_agenda = worlds[0].get_agenda_for_sentence(sentence)
                logical_forms = walker.get_logical_forms_with_agenda(
                    sentence_agenda, max_num_logical_forms * 10)
            for logical_form in logical_forms:
                is_correct = all(
                    world.execute(logical_form) == label
                    for world, label in zip(worlds, labels))
                target = correct_logical_forms if is_correct else incorrect_logical_forms
                # Strict comparison: each list holds at most `max_num_logical_forms` entries.
                if len(target) < max_num_logical_forms:
                    target.append(logical_form)
                if (len(correct_logical_forms) >= max_num_logical_forms
                        and len(incorrect_logical_forms) >= max_num_logical_forms):
                    break
            if write_sequences:
                correct_sequences = [
                    worlds[0].logical_form_to_action_sequence(logical_form)
                    for logical_form in correct_logical_forms
                ]
                incorrect_sequences = [
                    worlds[0].logical_form_to_action_sequence(logical_form)
                    for logical_form in incorrect_logical_forms
                ]
                processed_data.append({
                    "id": instance_id,
                    "sentence": sentence,
                    "correct_sequences": correct_sequences,
                    "incorrect_sequences": incorrect_sequences,
                    "worlds": structured_reps,
                    "labels": label_strings,
                })
            else:
                processed_data.append({
                    "id": instance_id,
                    "sentence": sentence,
                    "correct_logical_forms": correct_logical_forms,
                    "incorrect_logical_forms": incorrect_logical_forms,
                    "worlds": structured_reps,
                    "labels": label_strings,
                })
    # The output is written only after the whole input has been processed successfully.
    with open(output_file, "w") as outfile:
        for instance_processed_data in processed_data:
            json.dump(instance_processed_data, outfile)
            outfile.write("\n")
class TestNlvrLanguage(AllenNlpTestCase):
    """
    Tests for ``NlvrLanguage``: execution of logical forms against fixture worlds, conversion of
    logical forms to action sequences, and agenda extraction from sentences.
    """

    def setUp(self):
        """
        Builds one ``NlvrLanguage`` per line of the sample fixture file (``self.languages``) and
        one hand-written three-box world (``self.custom_language``).
        """
        super().setUp()
        test_filename = self.FIXTURES_ROOT / "data" / "nlvr" / "sample_ungrouped_data.jsonl"
        # The context manager closes the fixture file as soon as it has been parsed.
        with open(test_filename) as data_file:
            data = [json.loads(line)["structured_rep"] for line in data_file]
        box_lists = [
            [Box(object_reps, i) for i, object_reps in enumerate(box_rep)]
            for box_rep in data
        ]
        self.languages = [NlvrLanguage(boxes) for boxes in box_lists]
        # y_loc increases as we go down from top to bottom, and x_loc from left to right. That is,
        # the origin is at the top-left corner.
        custom_rep = [
            [
                {"y_loc": 79, "size": 20, "type": "triangle", "x_loc": 27, "color": "Yellow"},
                {"y_loc": 55, "size": 10, "type": "circle", "x_loc": 47, "color": "Black"},
            ],
            [
                {"y_loc": 44, "size": 30, "type": "square", "x_loc": 10, "color": "#0099ff"},
                {"y_loc": 74, "size": 30, "type": "square", "x_loc": 40, "color": "Yellow"},
            ],
            [
                {"y_loc": 60, "size": 10, "type": "triangle", "x_loc": 12, "color": "#0099ff"},
            ],
        ]
        self.custom_language = NlvrLanguage(
            [Box(object_rep, i) for i, object_rep in enumerate(custom_rep)])

    def test_logical_form_with_assert_executes_correctly(self):
        executor = self.languages[0]
        # Utterance is "There is a circle closely touching a corner of a box." and label is "True".
        logical_form_true = "(object_count_greater_equals (touch_corner (circle (all_objects))) 1)"
        assert executor.execute(logical_form_true) is True
        logical_form_false = "(object_count_equals (touch_corner (circle (all_objects))) 9)"
        assert executor.execute(logical_form_false) is False

    def test_logical_form_with_box_filter_executes_correctly(self):
        executor = self.languages[2]
        # Utterance is "There is a box without a blue item." and label is "False".
        logical_form = "(box_exists (member_color_none_equals all_boxes color_blue))"
        assert executor.execute(logical_form) is False

    def test_logical_form_with_box_filter_within_object_filter_executes_correctly(self):
        executor = self.languages[2]
        # Utterance is "There are at least three blue items in boxes with blue items" and label
        # is "True".
        # Implicit literal concatenation keeps stray continuation characters out of the string.
        logical_form = (
            "(object_count_greater_equals "
            "(object_in_box (member_color_any_equals all_boxes color_blue)) 3)")
        assert executor.execute(logical_form) is True

    def test_logical_form_with_same_color_executes_correctly(self):
        executor = self.languages[1]
        # Utterance is "There are exactly two blocks of the same color." and label is "True".
        logical_form = "(object_count_equals (same_color all_objects) 2)"
        assert executor.execute(logical_form) is True

    def test_logical_form_with_same_shape_executes_correctly(self):
        executor = self.languages[0]
        # Utterance is "There are less than three black objects of the same shape" and label is
        # "False".
        logical_form = "(object_count_lesser (same_shape (black (all_objects))) 3)"
        assert executor.execute(logical_form) is False

    def test_logical_form_with_touch_wall_executes_correctly(self):
        executor = self.languages[0]
        # Utterance is "There are two black circles touching a wall" and label is "False".
        logical_form = "(object_count_greater_equals (touch_wall (black (circle (all_objects)))) 2)"
        assert executor.execute(logical_form) is False

    def test_logical_form_with_not_executes_correctly(self):
        executor = self.languages[2]
        # Utterance is "There are at most two medium triangles not touching a wall." and label is
        # "True".
        logical_form = (
            "(object_count_lesser_equals ((negate_filter touch_wall) "
            "(medium (triangle (all_objects)))) 2)")
        assert executor.execute(logical_form) is True

    def test_logical_form_with_color_comparison_executes_correctly(self):
        executor = self.languages[0]
        # Utterance is "The color of the circle touching the wall is black." and label is "True".
        logical_form = "(object_color_all_equals (circle (touch_wall (all_objects))) color_black)"
        assert executor.execute(logical_form) is True

    def test_spatial_relations_return_objects_in_the_same_box(self):
        # "above", "below", "top", "bottom" are relations defined only for objects within the same
        # box. So they should not return objects from other boxes.
        # Asserting that the color of the objects above the yellow triangle is only black (it is not
        # yellow or blue, which are colors of objects from other boxes)
        assert self.custom_language.execute(
            "(object_color_all_equals (above (yellow (triangle all_objects)))"
            " color_black)") is True
        # Asserting that the only shape below the blue square is a square.
        assert self.custom_language.execute(
            "(object_shape_all_equals (below (blue (square all_objects)))"
            " shape_square)") is True
        # Asserting the shape of the object at the bottom in the box with a circle is triangle.
        logical_form = (
            "(object_shape_all_equals (bottom (object_in_box"
            " (member_shape_any_equals all_boxes shape_circle))) shape_triangle)")
        assert self.custom_language.execute(logical_form) is True
        # Asserting the shape of the object at the top of the box with all squares is a square (!).
        logical_form = (
            "(object_shape_all_equals (top (object_in_box"
            " (member_shape_all_equals all_boxes shape_square))) shape_square)")
        assert self.custom_language.execute(logical_form) is True

    def test_touch_object_executes_correctly(self):
        # Assert that there is a yellow square touching a blue square.
        assert self.custom_language.execute(
            "(object_exists (yellow (square (touch_object (blue "
            "(square all_objects))))))") is True
        # Assert that the triangle does not touch the circle (they are out of vertical range).
        assert self.custom_language.execute(
            "(object_shape_none_equals (touch_object (triangle all_objects))"
            " shape_circle)") is True

    def test_spatial_relations_with_objects_from_different_boxes(self):
        # When the objects are from different boxes, top and bottom should return objects from
        # respective boxes.
        # There are triangles in two boxes, so top should return the top objects from both boxes.
        assert self.custom_language.execute(
            "(object_count_equals (top (object_in_box (member_shape_any_equals "
            "all_boxes shape_triangle))) 2)") is True

    def test_same_and_different_execute_correctly(self):
        # All the objects in the box with two objects of the same shape are squares.
        assert self.custom_language.execute(
            "(object_shape_all_equals "
            "(object_in_box (member_shape_same (member_count_equals all_boxes 2)))"
            " shape_square)") is True
        # There is a circle in the box with objects of different shapes.
        assert self.custom_language.execute(
            "(object_shape_any_equals (object_in_box "
            "(member_shape_different all_boxes)) shape_circle)") is True

    def test_get_action_sequence_handles_multi_arg_functions(self):
        language = self.languages[0]
        # box_color_filter
        logical_form = "(box_exists (member_color_all_equals all_boxes color_blue))"
        action_sequence = language.logical_form_to_action_sequence(logical_form)
        assert 'Set[Box] -> [<Set[Box],Color:Set[Box]>, Set[Box], Color]' in action_sequence

        # box_shape_filter
        logical_form = "(box_exists (member_shape_all_equals all_boxes shape_square))"
        action_sequence = language.logical_form_to_action_sequence(logical_form)
        assert 'Set[Box] -> [<Set[Box],Shape:Set[Box]>, Set[Box], Shape]' in action_sequence

        # box_count_filter
        logical_form = "(box_exists (member_count_equals all_boxes 3))"
        action_sequence = language.logical_form_to_action_sequence(logical_form)
        assert 'Set[Box] -> [<Set[Box],int:Set[Box]>, Set[Box], int]' in action_sequence

        # assert_color
        logical_form = "(object_color_all_equals all_objects color_blue)"
        action_sequence = language.logical_form_to_action_sequence(logical_form)
        assert 'bool -> [<Set[Object],Color:bool>, Set[Object], Color]' in action_sequence

        # assert_shape
        logical_form = "(object_shape_all_equals all_objects shape_square)"
        action_sequence = language.logical_form_to_action_sequence(logical_form)
        assert 'bool -> [<Set[Object],Shape:bool>, Set[Object], Shape]' in action_sequence

        # assert_box_count
        logical_form = "(box_count_equals all_boxes 1)"
        action_sequence = language.logical_form_to_action_sequence(logical_form)
        assert 'bool -> [<Set[Box],int:bool>, Set[Box], int]' in action_sequence

        # assert_object_count
        logical_form = "(object_count_equals all_objects 1)"
        action_sequence = language.logical_form_to_action_sequence(logical_form)
        assert 'bool -> [<Set[Object],int:bool>, Set[Object], int]' in action_sequence

    def test_logical_form_with_object_filter_returns_correct_action_sequence(self):
        language = self.languages[0]
        logical_form = "(object_color_all_equals (circle (touch_wall all_objects)) color_black)"
        action_sequence = language.logical_form_to_action_sequence(logical_form)
        assert action_sequence == [
            '@start@ -> bool',
            'bool -> [<Set[Object],Color:bool>, Set[Object], Color]',
            '<Set[Object],Color:bool> -> object_color_all_equals',
            'Set[Object] -> [<Set[Object]:Set[Object]>, Set[Object]]',
            '<Set[Object]:Set[Object]> -> circle',
            'Set[Object] -> [<Set[Object]:Set[Object]>, Set[Object]]',
            '<Set[Object]:Set[Object]> -> touch_wall',
            'Set[Object] -> all_objects',
            'Color -> color_black',
        ]

    def test_logical_form_with_negate_filter_returns_correct_action_sequence(self):
        language = self.languages[0]
        logical_form = "(object_exists ((negate_filter touch_wall) all_objects))"
        action_sequence = language.logical_form_to_action_sequence(logical_form)
        negate_filter_production = (
            '<Set[Object]:Set[Object]> -> '
            '[<<Set[Object]:Set[Object]>:<Set[Object]:Set[Object]>>, '
            '<Set[Object]:Set[Object]>]')
        assert action_sequence == [
            '@start@ -> bool',
            'bool -> [<Set[Object]:bool>, Set[Object]]',
            '<Set[Object]:bool> -> object_exists',
            'Set[Object] -> [<Set[Object]:Set[Object]>, Set[Object]]',
            negate_filter_production,
            '<<Set[Object]:Set[Object]>:<Set[Object]:Set[Object]>> -> negate_filter',
            '<Set[Object]:Set[Object]> -> touch_wall',
            'Set[Object] -> all_objects',
        ]

    def test_logical_form_with_box_filter_returns_correct_action_sequence(self):
        language = self.languages[0]
        logical_form = "(box_exists (member_color_none_equals all_boxes color_blue))"
        action_sequence = language.logical_form_to_action_sequence(logical_form)
        assert action_sequence == [
            '@start@ -> bool',
            'bool -> [<Set[Box]:bool>, Set[Box]]',
            '<Set[Box]:bool> -> box_exists',
            'Set[Box] -> [<Set[Box],Color:Set[Box]>, Set[Box], Color]',
            '<Set[Box],Color:Set[Box]> -> member_color_none_equals',
            'Set[Box] -> all_boxes',
            'Color -> color_blue',
        ]

    def test_get_agenda_for_sentence(self):
        language = self.languages[0]
        agenda = language.get_agenda_for_sentence(
            "there is a tower with exactly two yellow blocks")
        assert set(agenda) == {
            'Color -> color_yellow',
            '<Set[Box]:bool> -> box_exists',
            'int -> 2',
        }
        agenda = language.get_agenda_for_sentence(
            "There is at most one yellow item closely touching "
            "the bottom of a box.")
        assert set(agenda) == {
            '<Set[Object]:Set[Object]> -> yellow',
            '<Set[Object]:Set[Object]> -> touch_bottom',
            'int -> 1',
        }
        agenda = language.get_agenda_for_sentence(
            "There is at most one yellow item closely touching "
            "the right wall of a box.")
        assert set(agenda) == {
            '<Set[Object]:Set[Object]> -> yellow',
            '<Set[Object]:Set[Object]> -> touch_right',
            'int -> 1',
        }
        agenda = language.get_agenda_for_sentence(
            "There is at most one yellow item closely touching "
            "the left wall of a box.")
        assert set(agenda) == {
            '<Set[Object]:Set[Object]> -> yellow',
            '<Set[Object]:Set[Object]> -> touch_left',
            'int -> 1',
        }
        agenda = language.get_agenda_for_sentence(
            "There is at most one yellow item closely touching "
            "a wall of a box.")
        assert set(agenda) == {
            '<Set[Object]:Set[Object]> -> yellow',
            '<Set[Object]:Set[Object]> -> touch_wall',
            'int -> 1',
        }
        agenda = language.get_agenda_for_sentence(
            "There is exactly one square touching any edge")
        assert set(agenda) == {
            '<Set[Object]:Set[Object]> -> square',
            '<Set[Object]:Set[Object]> -> touch_wall',
            'int -> 1',
        }
        agenda = language.get_agenda_for_sentence(
            "There is exactly one square not touching any edge")
        assert set(agenda) == {
            '<Set[Object]:Set[Object]> -> square',
            '<Set[Object]:Set[Object]> -> touch_wall',
            'int -> 1',
            '<<Set[Object]:Set[Object]>:<Set[Object]:Set[Object]>> -> negate_filter',
        }
        # "1" occurs twice in the next two sentences; the agenda is compared as a set, so the
        # corresponding production is listed only once.
        agenda = language.get_agenda_for_sentence(
            "There is only 1 tower with 1 blue block at the base")
        assert set(agenda) == {
            '<Set[Object]:Set[Object]> -> blue',
            'int -> 1',
            '<Set[Object]:Set[Object]> -> bottom',
        }
        agenda = language.get_agenda_for_sentence(
            "There is only 1 tower that has 1 blue block at the top")
        assert set(agenda) == {
            '<Set[Object]:Set[Object]> -> blue',
            'int -> 1',
            '<Set[Object]:Set[Object]> -> top',
            'Set[Box] -> all_boxes',
        }
        agenda = language.get_agenda_for_sentence(
            "There is exactly one square touching the blue "
            "triangle")
        assert set(agenda) == {
            '<Set[Object]:Set[Object]> -> square',
            '<Set[Object]:Set[Object]> -> blue',
            '<Set[Object]:Set[Object]> -> triangle',
            '<Set[Object]:Set[Object]> -> touch_object',
            'int -> 1',
        }

    def test_get_agenda_for_sentence_correctly_adds_object_filters(self):
        # In logical forms that contain "box_exists" at the top, there can never be object filtering
        # operations like "blue", "square" etc. In those cases, strings like "blue" and "square" in
        # sentences should map to "color_blue" and "shape_square" respectively.
        language = self.languages[0]
        agenda = language.get_agenda_for_sentence(
            "there is a box with exactly two yellow triangles "
            "touching the top edge")
        assert "<Set[Object]:Set[Object]> -> yellow" not in agenda
        assert "Color -> color_yellow" in agenda
        assert "<Set[Object]:Set[Object]> -> triangle" not in agenda
        assert "Shape -> shape_triangle" in agenda
        assert "<Set[Object]:Set[Object]> -> touch_top" not in agenda
        agenda = language.get_agenda_for_sentence(
            "there are exactly two yellow triangles touching the"
            " top edge")
        assert "<Set[Object]:Set[Object]> -> yellow" in agenda
        assert "Color -> color_yellow" not in agenda
        assert "<Set[Object]:Set[Object]> -> triangle" in agenda
        assert "Shape -> shape_triangle" not in agenda
        assert "<Set[Object]:Set[Object]> -> touch_top" in agenda
def text_to_instance(  # type: ignore
        self,
        sentence: str,
        structured_representations: List[List[List[JsonDict]]],
        labels: List[str] = None,
        target_sequences: List[List[str]] = None,
        identifier: str = None) -> Instance:
    """
    Builds an ``Instance`` from a sentence and the worlds it is paired with.

    Parameters
    ----------
    sentence : ``str``
        The query sentence.
    structured_representations : ``List[List[List[JsonDict]]]``
        A list of Json representations of all the worlds. See expected format in this class'
        docstring.
    labels : ``List[str]`` (optional)
        List of string representations of the labels (true or false) corresponding to the
        ``structured_representations``. Not required while testing.
    target_sequences : ``List[List[str]]`` (optional)
        List of target action sequences for each element which lead to the correct denotation
        in worlds corresponding to the structured representations.
    identifier : ``str`` (optional)
        The identifier from the dataset if available.

    Returns
    -------
    An ``Instance`` with the fields ``sentence``, ``worlds``, ``actions`` and ``metadata``,
    plus ``identifier``, ``target_action_sequences`` or ``agenda``, and ``labels`` when the
    corresponding information is available.
    """
    # pylint: disable=arguments-differ
    worlds = [
        NlvrLanguage({Box(object_list, box_id)
                      for box_id, object_list in enumerate(structured_representation)})
        for structured_representation in structured_representations
    ]
    tokens = self._tokenizer.tokenize(sentence)

    # TODO(pradeep): Assuming that possible actions are the same in all worlds. This may change
    # later.
    rule_fields: List[Field] = []
    instance_action_ids: Dict[str, int] = {}
    for rule in worlds[0].all_possible_productions():
        instance_action_ids[rule] = len(instance_action_ids)
        rule_fields.append(ProductionRuleField(rule, is_global_rule=True))
    action_field = ListField(rule_fields)

    def index_actions(action_strings):
        # Turns action strings into a list of indices pointing into `action_field`.
        return ListField([
            IndexField(instance_action_ids[action_string], action_field)
            for action_string in action_strings
        ])

    metadata: Dict[str, Any] = {"sentence_tokens": [token.text for token in tokens]}
    fields: Dict[str, Field] = {
        "sentence": TextField(tokens, self._sentence_token_indexers),
        "worlds": ListField([MetadataField(world) for world in worlds]),
        "actions": action_field,
        "metadata": MetadataField(metadata),
    }
    if identifier is not None:
        fields["identifier"] = MetadataField(identifier)

    # Depending on the type of supervision used for training the parser, the instance carries
    # either target action sequences (when they are provided) or an agenda for the sentence.
    if target_sequences:
        # TODO(pradeep): Define a max length for this field.
        fields["target_action_sequences"] = ListField(
            [index_actions(target_sequence) for target_sequence in target_sequences])
    elif self._output_agendas:
        # TODO(pradeep): Assuming every world gives the same agenda for a sentence. This is true
        # now, but may change later too.
        agenda = worlds[0].get_agenda_for_sentence(sentence)
        assert agenda, "No agenda found for sentence: %s" % sentence
        # The agenda field contains indices into actions.
        fields["agenda"] = index_actions(agenda)

    if labels:
        fields["labels"] = ListField(
            [LabelField(label, label_namespace='denotations') for label in labels])
    return Instance(fields)
def __init__(
        self,
        vocab: Vocabulary,
        sentence_embedder: TextFieldEmbedder,
        action_embedding_dim: int,
        encoder: Seq2SeqEncoder,
        attention: Attention,
        beam_size: int,
        max_decoding_steps: int,
        max_num_finished_states: int = None,
        dropout: float = 0.0,
        normalize_beam_score_by_length: bool = False,
        checklist_cost_weight: float = 0.6,
        dynamic_cost_weight: Dict[str, Union[int, float]] = None,
        penalize_non_agenda_actions: bool = False,
        initial_mml_model_file: str = None,
) -> None:
    """
    Sets up the trainer (``ExpectedRiskMinimization``), the transition function
    (``CoverageTransitionFunction``) and the cost-related settings, and optionally initializes
    weights from the archive at ``initial_mml_model_file``.

    ``dynamic_cost_weight``, when given, must provide the keys ``"wait_num_epochs"`` and
    ``"rate"``. If ``initial_mml_model_file`` is given but is not an existing file, only a
    warning is logged.
    """
    super().__init__(
        vocab=vocab,
        sentence_embedder=sentence_embedder,
        action_embedding_dim=action_embedding_dim,
        encoder=encoder,
        dropout=dropout,
    )
    self._agenda_coverage = Average()
    self._decoder_trainer: DecoderTrainer[
        Callable[[CoverageState], torch.Tensor]
    ] = ExpectedRiskMinimization(
        beam_size=beam_size,
        normalize_by_length=normalize_beam_score_by_length,
        max_decoding_steps=max_decoding_steps,
        max_num_finished_states=max_num_finished_states,
    )
    # An empty NlvrLanguage is enough to collect the terminal productions.
    empty_world = NlvrLanguage(set())
    self._terminal_productions = set(empty_world.terminal_productions.values())
    tanh = Activation.by_name("tanh")()
    self._decoder_step = CoverageTransitionFunction(
        encoder_output_dim=self._encoder.get_output_dim(),
        action_embedding_dim=action_embedding_dim,
        input_attention=attention,
        activation=tanh,
        add_action_bias=False,
        dropout=dropout,
    )
    self._checklist_cost_weight = checklist_cost_weight

    # Both dynamic-cost settings stay ``None`` unless a non-empty config is passed.
    wait_epochs = None
    rate = None
    if dynamic_cost_weight:
        wait_epochs = dynamic_cost_weight["wait_num_epochs"]
        rate = dynamic_cost_weight["rate"]
    self._dynamic_cost_wait_epochs = wait_epochs
    self._dynamic_cost_rate = rate

    self._penalize_non_agenda_actions = penalize_non_agenda_actions
    self._last_epoch_in_forward: int = None

    if initial_mml_model_file is None:
        return
    # TODO (pradeep): Checking whether file exists here to avoid raising an error when we've
    # copied a trained ERM model from a different machine and the original MML model that was
    # used to initialize it does not exist on the current machine. This may not be the best
    # solution for the problem.
    if not os.path.isfile(initial_mml_model_file):
        # A model file is passed, but it does not exist. This is expected to happen when
        # you're using a trained ERM model to decode. But it may also happen if the path to
        # the file is really just incorrect. So throwing a warning.
        logger.warning(
            "MML model file for initializing weights is passed, but does not exist."
            " This is fine if you're just decoding."
        )
        return
    self._initialize_weights_from_archive(load_archive(initial_mml_model_file))