def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions. It converts the best action string sequences into logical
        forms and collects per-step information about the predicted actions.

        Keys read from ``output_dict``:

        - ``"best_action_strings"``: for each instance, a list of action sequences (an empty
          sequence yields the logical form ``""``).
        - ``"action_mapping"``: looked up with ``(batch_index, action_index)`` keys to obtain the
          representation of each considered action.
        - ``"debug_info"``: for each instance, one entry per decoding step with
          ``"considered_actions"`` and ``"probabilities"`` (and optionally
          ``"question_attention"``).

        Keys added: ``"predicted_actions"`` and ``"logical_form"``. The same (mutated)
        ``output_dict`` is returned.
        """
        best_action_strings = output_dict["best_action_strings"]
        # Instantiating an empty world for getting logical forms; only the conversion from action
        # sequences to logical forms is needed here, not execution.
        world = NlvrLanguage(set())
        logical_forms = [
            [world.action_sequence_to_logical_form(action_strings) if action_strings else ""
             for action_strings in instance_action_sequences]
            for instance_action_sequences in best_action_strings
        ]

        action_mapping = output_dict["action_mapping"]
        debug_infos = output_dict["debug_info"]
        batch_action_info = []
        for batch_index, (predicted_actions, debug_info) in enumerate(
                zip(best_action_strings, debug_infos)):
            # Only the first predicted sequence of each instance is paired with the per-step debug
            # info. An instance without any predicted sequence contributes no steps instead of
            # raising an IndexError.
            first_sequence = predicted_actions[0] if predicted_actions else []
            instance_action_info = []
            for predicted_action, action_debug_info in zip(first_sequence, debug_info):
                # Considered actions equal to -1 are skipped (presumably padding).
                actions = sorted(
                    (action_mapping[(batch_index, action)], probability)
                    for action, probability in zip(action_debug_info["considered_actions"],
                                                   action_debug_info["probabilities"])
                    if action != -1)
                # ``zip(*[])`` cannot be unpacked into two names, so fall back to empty tuples when
                # every considered action was skipped.
                considered_actions, probabilities = zip(*actions) if actions else ((), ())
                instance_action_info.append({
                    "predicted_action": predicted_action,
                    "considered_actions": considered_actions,
                    "action_probabilities": probabilities,
                    "question_attention": action_debug_info.get("question_attention", []),
                })
            batch_action_info.append(instance_action_info)
        output_dict["predicted_actions"] = batch_action_info
        output_dict["logical_form"] = logical_forms
        return output_dict
Example #2
0
 def setUp(self):
     """Build executors for the sample NLVR fixture worlds and for a small hand-made world."""
     super().setUp()
     test_filename = self.FIXTURES_ROOT / "data" / "nlvr" / "sample_ungrouped_data.jsonl"
     # Read the fixture inside a context manager so the file handle is closed promptly instead
     # of being left for the garbage collector.
     with open(test_filename) as data_file:
         data = [json.loads(line)["structured_rep"] for line in data_file]
     box_lists = [[Box(object_reps, i) for i, object_reps in enumerate(box_rep)]
                  for box_rep in data]
     self.languages = [NlvrLanguage(boxes) for boxes in box_lists]
     # y_loc increases as we go down from top to bottom, and x_loc from left to right. That is,
     # the origin is at the top-left corner.
     custom_rep = [
         [
             {"y_loc": 79, "size": 20, "type": "triangle", "x_loc": 27, "color": "Yellow"},
             {"y_loc": 55, "size": 10, "type": "circle", "x_loc": 47, "color": "Black"},
         ],
         [
             {"y_loc": 44, "size": 30, "type": "square", "x_loc": 10, "color": "#0099ff"},
             {"y_loc": 74, "size": 30, "type": "square", "x_loc": 40, "color": "Yellow"},
         ],
         [
             {"y_loc": 60, "size": 10, "type": "triangle", "x_loc": 12, "color": "#0099ff"},
         ],
     ]
     self.custom_language = NlvrLanguage(
         [Box(object_rep, i) for i, object_rep in enumerate(custom_rep)])
    def _create_grammar_state(
            self, world: NlvrLanguage,
            possible_actions: List[ProductionRule]) -> GrammarStatelet:
        """
        Build the initial ``GrammarStatelet`` for ``world``.

        Every NLVR production is treated as a global action. For each non-terminal, the rule
        tensors of its valid productions are concatenated and embedded with
        ``self._action_embedder``. The result is stored under the ``"global"`` key, with the same
        embedding tensor used in both of the first two tuple slots.
        """
        nonterminal_productions = world.get_nonterminal_productions()
        # Position of each production string inside ``possible_actions``. A repeated string keeps
        # its last position, just like a sequential overwrite would.
        rule_position = {rule[0]: position for position, rule in enumerate(possible_actions)}
        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor,
                                                            torch.Tensor,
                                                            List[int]]]] = {}
        for nonterminal, productions in nonterminal_productions.items():
            positions = [rule_position[production] for production in productions]
            rule_tensors, rule_ids = zip(
                *[(possible_actions[position][2], position) for position in positions])
            embedded_rules = self._action_embedder(torch.cat(rule_tensors, dim=0))
            translated_valid_actions[nonterminal] = {
                "global": (embedded_rules, embedded_rules, list(rule_ids)),
            }
        return GrammarStatelet([START_SYMBOL], translated_valid_actions,
                               world.is_nonterminal)
Example #4
0
def process_data(
    input_file: str,
    output_file: str,
    max_path_length: int,
    max_num_logical_forms: int,
    ignore_agenda: bool,
    write_sequences: bool,
) -> None:
    """
    Reads an NLVR dataset and writes, one JSON object per line, the sentence, labels, worlds, and
    correct and incorrect logical forms for every instance. Each of the correct and incorrect
    lists contains at most ``max_num_logical_forms`` entries. Each output line has the form:
        ``{"id": str, "sentence": str, "correct_logical_forms": List[str],
        "incorrect_logical_forms": List[str], "worlds": ..., "labels": List[str]}``
    When ``write_sequences`` is true, the logical forms are replaced by their action sequences
    under the keys ``"correct_sequences"`` and ``"incorrect_sequences"``.

    A logical form is "correct" when it executes to the given label in every world of the
    instance. With ``ignore_agenda`` the 1000 shortest logical forms are considered; otherwise
    up to ``10 * max_num_logical_forms`` logical forms matching the sentence agenda are considered.
    """
    processed_data: List[JsonDict] = []
    # We can instantiate the ``ActionSpaceWalker`` with any world because the action space is the
    # same for all the ``NlvrLanguage`` objects. It is just the execution that differs.
    walker = ActionSpaceWalker(NlvrLanguage(set()),
                               max_path_length=max_path_length)
    with open(input_file) as data_file:
        for line in data_file:
            instance_id, sentence, structured_reps, label_strings = read_json_line(line)
            worlds = [
                NlvrLanguage({Box(object_list, box_id)
                              for box_id, object_list in enumerate(structured_representation)})
                for structured_representation in structured_reps
            ]
            labels = [label_string == "true" for label_string in label_strings]
            if ignore_agenda:
                # Get 1000 shortest logical forms.
                logical_forms = walker.get_all_logical_forms(max_num_logical_forms=1000)
            else:
                # TODO (pradeep): Assuming all worlds give the same agenda.
                sentence_agenda = worlds[0].get_agenda_for_sentence(sentence)
                logical_forms = walker.get_logical_forms_with_agenda(
                    sentence_agenda, max_num_logical_forms * 10)
            correct_logical_forms: List[str] = []
            incorrect_logical_forms: List[str] = []
            for logical_form in logical_forms:
                is_correct = all(world.execute(logical_form) == label
                                 for world, label in zip(worlds, labels))
                bucket = correct_logical_forms if is_correct else incorrect_logical_forms
                # Strict comparison: the bucket never grows beyond ``max_num_logical_forms``.
                if len(bucket) < max_num_logical_forms:
                    bucket.append(logical_form)
                if (len(correct_logical_forms) >= max_num_logical_forms
                        and len(incorrect_logical_forms) >= max_num_logical_forms):
                    break
            instance_data = {"id": instance_id, "sentence": sentence}
            if write_sequences:
                instance_data["correct_sequences"] = [
                    worlds[0].logical_form_to_action_sequence(logical_form)
                    for logical_form in correct_logical_forms
                ]
                instance_data["incorrect_sequences"] = [
                    worlds[0].logical_form_to_action_sequence(logical_form)
                    for logical_form in incorrect_logical_forms
                ]
            else:
                instance_data["correct_logical_forms"] = correct_logical_forms
                instance_data["incorrect_logical_forms"] = incorrect_logical_forms
            instance_data["worlds"] = structured_reps
            instance_data["labels"] = label_strings
            processed_data.append(instance_data)
    # The ``with`` block closes the output file; no explicit ``close()`` is needed.
    with open(output_file, "w") as outfile:
        for instance_processed_data in processed_data:
            json.dump(instance_processed_data, outfile)
            outfile.write("\n")
Example #5
0
class TestNlvrLanguage(AllenNlpTestCase):
    """Tests for executing logical forms with ``NlvrLanguage`` and for its action sequences and agendas."""

    def setUp(self):
        super().setUp()
        test_filename = self.FIXTURES_ROOT / "data" / "nlvr" / "sample_ungrouped_data.jsonl"
        # Read the fixture inside a context manager so the file handle is closed promptly
        # instead of being left for the garbage collector.
        with open(test_filename) as data_file:
            data = [json.loads(line)["structured_rep"] for line in data_file]
        box_lists = [[Box(object_reps, i) for i, object_reps in enumerate(box_rep)]
                     for box_rep in data]
        self.languages = [NlvrLanguage(boxes) for boxes in box_lists]
        # y_loc increases as we go down from top to bottom, and x_loc from left to right. That is,
        # the origin is at the top-left corner.
        custom_rep = [
            [
                {"y_loc": 79, "size": 20, "type": "triangle", "x_loc": 27, "color": "Yellow"},
                {"y_loc": 55, "size": 10, "type": "circle", "x_loc": 47, "color": "Black"},
            ],
            [
                {"y_loc": 44, "size": 30, "type": "square", "x_loc": 10, "color": "#0099ff"},
                {"y_loc": 74, "size": 30, "type": "square", "x_loc": 40, "color": "Yellow"},
            ],
            [
                {"y_loc": 60, "size": 10, "type": "triangle", "x_loc": 12, "color": "#0099ff"},
            ],
        ]
        self.custom_language = NlvrLanguage(
            [Box(object_rep, i) for i, object_rep in enumerate(custom_rep)])

    def test_logical_form_with_assert_executes_correctly(self):
        executor = self.languages[0]
        # Utterance is "There is a circle closely touching a corner of a box." and label is "True".
        logical_form_true = "(object_count_greater_equals (touch_corner (circle (all_objects))) 1)"
        assert executor.execute(logical_form_true) is True
        logical_form_false = "(object_count_equals (touch_corner (circle (all_objects))) 9)"
        assert executor.execute(logical_form_false) is False

    def test_logical_form_with_box_filter_executes_correctly(self):
        executor = self.languages[2]
        # Utterance is "There is a box without a blue item." and label is "False".
        logical_form = "(box_exists (member_color_none_equals all_boxes color_blue))"
        assert executor.execute(logical_form) is False

    def test_logical_form_with_box_filter_within_object_filter_executes_correctly(
            self):
        executor = self.languages[2]
        # Utterance is "There are at least three blue items in boxes with blue items" and label
        # is "True".
        logical_form = "(object_count_greater_equals \
                            (object_in_box (member_color_any_equals all_boxes color_blue)) 3)"

        assert executor.execute(logical_form) is True

    def test_logical_form_with_same_color_executes_correctly(self):
        executor = self.languages[1]
        # Utterance is "There are exactly two blocks of the same color." and label is "True".
        logical_form = "(object_count_equals (same_color all_objects) 2)"
        assert executor.execute(logical_form) is True

    def test_logical_form_with_same_shape_executes_correctly(self):
        executor = self.languages[0]
        # Utterance is "There are less than three black objects of the same shape" and label is "False".
        logical_form = "(object_count_lesser (same_shape (black (all_objects))) 3)"
        assert executor.execute(logical_form) is False

    def test_logical_form_with_touch_wall_executes_correctly(self):
        executor = self.languages[0]
        # Utterance is "There are two black circles touching a wall" and label is "False".
        logical_form = "(object_count_greater_equals (touch_wall (black (circle (all_objects)))) 2)"
        assert executor.execute(logical_form) is False

    def test_logical_form_with_not_executes_correctly(self):
        executor = self.languages[2]
        # Utterance is "There are at most two medium triangles not touching a wall." and label is "True".
        logical_form = (
            "(object_count_lesser_equals ((negate_filter touch_wall) "
            "(medium (triangle (all_objects)))) 2)")
        assert executor.execute(logical_form) is True

    def test_logical_form_with_color_comparison_executes_correctly(self):
        executor = self.languages[0]
        # Utterance is "The color of the circle touching the wall is black." and label is "True".
        logical_form = "(object_color_all_equals (circle (touch_wall (all_objects))) color_black)"
        assert executor.execute(logical_form) is True

    def test_spatial_relations_return_objects_in_the_same_box(self):
        # "above", "below", "top", "bottom" are relations defined only for objects within the same
        # box. So they should not return objects from other boxes.
        # Asserting that the color of the objects above the yellow triangle is only black (it is not
        # yellow or blue, which are colors of objects from other boxes)
        assert self.custom_language.execute(
            "(object_color_all_equals (above (yellow (triangle all_objects)))"
            " color_black)") is True
        # Asserting that the only shape below the blue square is a square.
        assert self.custom_language.execute(
            "(object_shape_all_equals (below (blue (square all_objects)))"
            " shape_square)") is True
        # Asserting the shape of the object at the bottom in the box with a circle is triangle.
        logical_form = (
            "(object_shape_all_equals (bottom (object_in_box"
            " (member_shape_any_equals all_boxes shape_circle))) shape_triangle)"
        )
        assert self.custom_language.execute(logical_form) is True

        # Asserting the shape of the object at the top of the box with all squares is a square (!).
        logical_form = (
            "(object_shape_all_equals (top (object_in_box"
            " (member_shape_all_equals all_boxes shape_square))) shape_square)"
        )
        assert self.custom_language.execute(logical_form) is True

    def test_touch_object_executes_correctly(self):
        # Assert that there is a yellow square touching a blue square.
        assert self.custom_language.execute(
            "(object_exists (yellow (square (touch_object (blue "
            "(square all_objects))))))") is True
        # Assert that the triangle does not touch the circle (they are out of vertical range).
        assert self.custom_language.execute(
            "(object_shape_none_equals (touch_object (triangle all_objects))"
            " shape_circle)") is True

    def test_spatial_relations_with_objects_from_different_boxes(self):
        # When the objects are from different boxes, top and bottom should return objects from
        # respective boxes.
        # There are triangles in two boxes, so top should return the top objects from both boxes.
        assert self.custom_language.execute(
            "(object_count_equals (top (object_in_box (member_shape_any_equals "
            "all_boxes shape_triangle))) 2)") is True

    def test_same_and_different_execute_correctly(self):
        # All the objects in the box with two objects of the same shape are squares.
        assert self.custom_language.execute(
            "(object_shape_all_equals "
            "(object_in_box (member_shape_same (member_count_equals all_boxes 2)))"
            " shape_square)") is True
        # There is a circle in the box with objects of different shapes.
        assert self.custom_language.execute(
            "(object_shape_any_equals (object_in_box "
            "(member_shape_different all_boxes)) shape_circle)") is True

    def test_get_action_sequence_handles_multi_arg_functions(self):
        language = self.languages[0]
        # box_color_filter
        logical_form = "(box_exists (member_color_all_equals all_boxes color_blue))"
        action_sequence = language.logical_form_to_action_sequence(
            logical_form)
        assert 'Set[Box] -> [<Set[Box],Color:Set[Box]>, Set[Box], Color]' in action_sequence

        # box_shape_filter
        logical_form = "(box_exists (member_shape_all_equals all_boxes shape_square))"
        action_sequence = language.logical_form_to_action_sequence(
            logical_form)
        assert 'Set[Box] -> [<Set[Box],Shape:Set[Box]>, Set[Box], Shape]' in action_sequence

        # box_count_filter
        logical_form = "(box_exists (member_count_equals all_boxes 3))"
        action_sequence = language.logical_form_to_action_sequence(
            logical_form)
        assert 'Set[Box] -> [<Set[Box],int:Set[Box]>, Set[Box], int]' in action_sequence

        # assert_color
        logical_form = "(object_color_all_equals all_objects color_blue)"
        action_sequence = language.logical_form_to_action_sequence(
            logical_form)
        assert 'bool -> [<Set[Object],Color:bool>, Set[Object], Color]' in action_sequence

        # assert_shape
        logical_form = "(object_shape_all_equals all_objects shape_square)"
        action_sequence = language.logical_form_to_action_sequence(
            logical_form)
        assert 'bool -> [<Set[Object],Shape:bool>, Set[Object], Shape]' in action_sequence

        # assert_box_count
        logical_form = "(box_count_equals all_boxes 1)"
        action_sequence = language.logical_form_to_action_sequence(
            logical_form)
        assert 'bool -> [<Set[Box],int:bool>, Set[Box], int]' in action_sequence

        # assert_object_count
        logical_form = "(object_count_equals all_objects 1)"
        action_sequence = language.logical_form_to_action_sequence(
            logical_form)
        assert 'bool -> [<Set[Object],int:bool>, Set[Object], int]' in action_sequence

    def test_logical_form_with_object_filter_returns_correct_action_sequence(
            self):
        language = self.languages[0]
        logical_form = "(object_color_all_equals (circle (touch_wall all_objects)) color_black)"
        action_sequence = language.logical_form_to_action_sequence(
            logical_form)
        assert action_sequence == [
            '@start@ -> bool',
            'bool -> [<Set[Object],Color:bool>, Set[Object], Color]',
            '<Set[Object],Color:bool> -> object_color_all_equals',
            'Set[Object] -> [<Set[Object]:Set[Object]>, Set[Object]]',
            '<Set[Object]:Set[Object]> -> circle',
            'Set[Object] -> [<Set[Object]:Set[Object]>, Set[Object]]',
            '<Set[Object]:Set[Object]> -> touch_wall',
            'Set[Object] -> all_objects', 'Color -> color_black'
        ]

    def test_logical_form_with_negate_filter_returns_correct_action_sequence(
            self):
        language = self.languages[0]
        logical_form = "(object_exists ((negate_filter touch_wall) all_objects))"
        action_sequence = language.logical_form_to_action_sequence(
            logical_form)
        negate_filter_production = (
            '<Set[Object]:Set[Object]> -> '
            '[<<Set[Object]:Set[Object]>:<Set[Object]:Set[Object]>>, '
            '<Set[Object]:Set[Object]>]')
        assert action_sequence == [
            '@start@ -> bool', 'bool -> [<Set[Object]:bool>, Set[Object]]',
            '<Set[Object]:bool> -> object_exists',
            'Set[Object] -> [<Set[Object]:Set[Object]>, Set[Object]]',
            negate_filter_production,
            '<<Set[Object]:Set[Object]>:<Set[Object]:Set[Object]>> -> negate_filter',
            '<Set[Object]:Set[Object]> -> touch_wall',
            'Set[Object] -> all_objects'
        ]

    def test_logical_form_with_box_filter_returns_correct_action_sequence(
            self):
        language = self.languages[0]
        logical_form = "(box_exists (member_color_none_equals all_boxes color_blue))"
        action_sequence = language.logical_form_to_action_sequence(
            logical_form)
        assert action_sequence == [
            '@start@ -> bool', 'bool -> [<Set[Box]:bool>, Set[Box]]',
            '<Set[Box]:bool> -> box_exists',
            'Set[Box] -> [<Set[Box],Color:Set[Box]>, Set[Box], Color]',
            '<Set[Box],Color:Set[Box]> -> member_color_none_equals',
            'Set[Box] -> all_boxes', 'Color -> color_blue'
        ]

    def test_get_agenda_for_sentence(self):
        language = self.languages[0]
        agenda = language.get_agenda_for_sentence(
            "there is a tower with exactly two yellow blocks")
        assert set(agenda) == set([
            'Color -> color_yellow', '<Set[Box]:bool> -> box_exists',
            'int -> 2'
        ])
        agenda = language.get_agenda_for_sentence(
            "There is at most one yellow item closely touching "
            "the bottom of a box.")
        assert set(agenda) == set([
            '<Set[Object]:Set[Object]> -> yellow',
            '<Set[Object]:Set[Object]> -> touch_bottom', 'int -> 1'
        ])
        agenda = language.get_agenda_for_sentence(
            "There is at most one yellow item closely touching "
            "the right wall of a box.")
        assert set(agenda) == set([
            '<Set[Object]:Set[Object]> -> yellow',
            '<Set[Object]:Set[Object]> -> touch_right', 'int -> 1'
        ])
        agenda = language.get_agenda_for_sentence(
            "There is at most one yellow item closely touching "
            "the left wall of a box.")
        assert set(agenda) == set([
            '<Set[Object]:Set[Object]> -> yellow',
            '<Set[Object]:Set[Object]> -> touch_left', 'int -> 1'
        ])
        agenda = language.get_agenda_for_sentence(
            "There is at most one yellow item closely touching "
            "a wall of a box.")
        assert set(agenda) == set([
            '<Set[Object]:Set[Object]> -> yellow',
            '<Set[Object]:Set[Object]> -> touch_wall', 'int -> 1'
        ])
        agenda = language.get_agenda_for_sentence(
            "There is exactly one square touching any edge")
        assert set(agenda) == set([
            '<Set[Object]:Set[Object]> -> square',
            '<Set[Object]:Set[Object]> -> touch_wall', 'int -> 1'
        ])
        agenda = language.get_agenda_for_sentence(
            "There is exactly one square not touching any edge")
        assert set(agenda) == set([
            '<Set[Object]:Set[Object]> -> square',
            '<Set[Object]:Set[Object]> -> touch_wall', 'int -> 1',
            '<<Set[Object]:Set[Object]>:<Set[Object]:Set[Object]>> -> negate_filter'
        ])
        agenda = language.get_agenda_for_sentence(
            "There is only 1 tower with 1 blue block at the base")
        assert set(agenda) == set([
            '<Set[Object]:Set[Object]> -> blue', 'int -> 1',
            '<Set[Object]:Set[Object]> -> bottom', 'int -> 1'
        ])
        agenda = language.get_agenda_for_sentence(
            "There is only 1 tower that has 1 blue block at the top")
        assert set(agenda) == set([
            '<Set[Object]:Set[Object]> -> blue', 'int -> 1',
            '<Set[Object]:Set[Object]> -> top', 'int -> 1',
            'Set[Box] -> all_boxes'
        ])
        agenda = language.get_agenda_for_sentence(
            "There is exactly one square touching the blue "
            "triangle")
        assert set(agenda) == set([
            '<Set[Object]:Set[Object]> -> square',
            '<Set[Object]:Set[Object]> -> blue',
            '<Set[Object]:Set[Object]> -> triangle',
            '<Set[Object]:Set[Object]> -> touch_object', 'int -> 1'
        ])

    def test_get_agenda_for_sentence_correctly_adds_object_filters(self):
        # In logical forms that contain "box_exists" at the top, there can never be object filtering
        # operations like "blue", "square" etc. In those cases, strings like "blue" and "square" in
        # sentences should map to "color_blue" and "shape_square" respectively.
        language = self.languages[0]
        agenda = language.get_agenda_for_sentence(
            "there is a box with exactly two yellow triangles "
            "touching the top edge")
        assert "<Set[Object]:Set[Object]> -> yellow" not in agenda
        assert "Color -> color_yellow" in agenda
        assert "<Set[Object]:Set[Object]> -> triangle" not in agenda
        assert "Shape -> shape_triangle" in agenda
        assert "<Set[Object]:Set[Object]> -> touch_top" not in agenda
        agenda = language.get_agenda_for_sentence(
            "there are exactly two yellow triangles touching the"
            " top edge")
        assert "<Set[Object]:Set[Object]> -> yellow" in agenda
        assert "Color -> color_yellow" not in agenda
        assert "<Set[Object]:Set[Object]> -> triangle" in agenda
        assert "Shape -> shape_triangle" not in agenda
        assert "<Set[Object]:Set[Object]> -> touch_top" in agenda
Example #6
0
File: nlvr.py Project: Jaynil1611/syn-qg
    def text_to_instance(
            self,  # type: ignore
            sentence: str,
            structured_representations: List[List[List[JsonDict]]],
            labels: List[str] = None,
            target_sequences: List[List[str]] = None,
            identifier: str = None) -> Instance:
        """
        Convert one NLVR example into an ``Instance``.

        Parameters
        ----------
        sentence : ``str``
            The query sentence.
        structured_representations : ``List[List[List[JsonDict]]]``
            Json representations of all the worlds, one per world. Each world is a list of boxes,
            and each box a list of object dicts. See expected format in this class' docstring.
        labels : ``List[str]`` (optional)
            String labels (true or false), one per entry of ``structured_representations``. Not
            required while testing.
        target_sequences : ``List[List[str]]`` (optional)
            Target action sequences which lead to the correct denotation in the worlds built from
            ``structured_representations``.
        identifier : ``str`` (optional)
            The identifier from the dataset if available.

        Returns
        -------
        ``Instance``
            Always contains ``"sentence"``, ``"worlds"``, ``"actions"`` and ``"metadata"``.
            It also contains:

            - ``"identifier"`` when an identifier is given;
            - ``"target_action_sequences"`` when target sequences are given, or else ``"agenda"``
              when ``self._output_agendas`` is set;
            - ``"labels"`` when labels are given.
        """
        # pylint: disable=arguments-differ
        worlds = [
            NlvrLanguage({Box(object_list, box_id)
                          for box_id, object_list in enumerate(world_representation)})
            for world_representation in structured_representations
        ]
        tokenized_sentence = self._tokenizer.tokenize(sentence)
        sentence_field = TextField(tokenized_sentence, self._sentence_token_indexers)

        # Every production becomes a global ``ProductionRuleField``. ``instance_action_ids``
        # records the position of each production inside the "actions" list field.
        # TODO(pradeep): Assuming that possible actions are the same in all worlds. This may change
        # later.
        production_rule_fields: List[Field] = []
        instance_action_ids: Dict[str, int] = {}
        for rule in worlds[0].all_possible_productions():
            instance_action_ids[rule] = len(instance_action_ids)
            production_rule_fields.append(ProductionRuleField(rule, is_global_rule=True))
        action_field = ListField(production_rule_fields)

        def as_index_fields(action_strings):
            # A list of pointers from each action string into ``action_field``.
            return ListField([IndexField(instance_action_ids[action_string], action_field)
                              for action_string in action_strings])

        metadata: Dict[str, Any] = {
            "sentence_tokens": [token.text for token in tokenized_sentence]
        }
        fields: Dict[str, Field] = {
            "sentence": sentence_field,
            "worlds": ListField([MetadataField(world) for world in worlds]),
            "actions": action_field,
            "metadata": MetadataField(metadata),
        }
        if identifier is not None:
            fields["identifier"] = MetadataField(identifier)
        # Supervision comes either from target action sequences (when provided) or from an agenda
        # derived from the sentence.
        if target_sequences:
            # TODO(pradeep): Define a max length for this field.
            fields["target_action_sequences"] = ListField(
                [as_index_fields(sequence) for sequence in target_sequences])
        elif self._output_agendas:
            # TODO(pradeep): Assuming every world gives the same agenda for a sentence. This is true
            # now, but may change later too.
            agenda = worlds[0].get_agenda_for_sentence(sentence)
            assert agenda, "No agenda found for sentence: %s" % sentence
            fields["agenda"] = as_index_fields(agenda)
        if labels:
            fields["labels"] = ListField(
                [LabelField(label, label_namespace='denotations') for label in labels])

        return Instance(fields)
    def __init__(
        self,
        vocab: Vocabulary,
        sentence_embedder: TextFieldEmbedder,
        action_embedding_dim: int,
        encoder: Seq2SeqEncoder,
        attention: Attention,
        beam_size: int,
        max_decoding_steps: int,
        max_num_finished_states: int = None,
        dropout: float = 0.0,
        normalize_beam_score_by_length: bool = False,
        checklist_cost_weight: float = 0.6,
        dynamic_cost_weight: Dict[str, Union[int, float]] = None,
        penalize_non_agenda_actions: bool = False,
        initial_mml_model_file: str = None,
    ) -> None:
        """
        Set up a parser whose decoder trainer is ``ExpectedRiskMinimization`` and whose decoder
        step is a ``CoverageTransitionFunction``.

        Parameters
        ----------
        vocab, sentence_embedder, action_embedding_dim, encoder, dropout
            Forwarded to the parent constructor. ``action_embedding_dim`` and ``dropout`` are also
            passed to the ``CoverageTransitionFunction``.
        attention : ``Attention``
            Used as the ``input_attention`` of the ``CoverageTransitionFunction``.
        beam_size, max_decoding_steps, max_num_finished_states, normalize_beam_score_by_length
            Forwarded to ``ExpectedRiskMinimization``. ``normalize_beam_score_by_length`` is passed
            as its ``normalize_by_length`` argument.
        checklist_cost_weight : ``float``
            Stored on the model; it is consumed outside this method (presumably to weight the
            checklist cost against the denotation cost — TODO confirm).
        dynamic_cost_weight : ``Dict[str, Union[int, float]]`` (optional)
            If given, must contain the keys ``"wait_num_epochs"`` and ``"rate"``. They are stored
            as ``self._dynamic_cost_wait_epochs`` and ``self._dynamic_cost_rate``.
        penalize_non_agenda_actions : ``bool``
            Stored on the model for use outside this method.
        initial_mml_model_file : ``str`` (optional)
            Path to an archived model. If the file exists, it is loaded with ``load_archive`` and
            used to initialize this model's weights. If the path does not exist, only a warning
            is logged.
        """
        super().__init__(
            vocab=vocab,
            sentence_embedder=sentence_embedder,
            action_embedding_dim=action_embedding_dim,
            encoder=encoder,
            dropout=dropout,
        )
        # Running average metric; it is only created here and updated elsewhere (presumably with
        # the fraction of agenda actions covered — TODO confirm).
        self._agenda_coverage = Average()
        self._decoder_trainer: DecoderTrainer[
            Callable[[CoverageState], torch.Tensor]
        ] = ExpectedRiskMinimization(
            beam_size=beam_size,
            normalize_by_length=normalize_beam_score_by_length,
            max_decoding_steps=max_decoding_steps,
            max_num_finished_states=max_num_finished_states,
        )

        # Instantiating an empty NlvrLanguage just to get the set of its terminal productions
        # (the action space does not depend on the boxes in the world).
        self._terminal_productions = set(NlvrLanguage(set()).terminal_productions.values())
        # NOTE: the order in which submodules are created here determines parameter registration
        # order, so statements in this method should not be reordered casually.
        self._decoder_step = CoverageTransitionFunction(
            encoder_output_dim=self._encoder.get_output_dim(),
            action_embedding_dim=action_embedding_dim,
            input_attention=attention,
            activation=Activation.by_name("tanh")(),
            add_action_bias=False,
            dropout=dropout,
        )
        self._checklist_cost_weight = checklist_cost_weight
        # Both stay ``None`` unless ``dynamic_cost_weight`` is provided.
        self._dynamic_cost_wait_epochs = None
        self._dynamic_cost_rate = None
        if dynamic_cost_weight:
            self._dynamic_cost_wait_epochs = dynamic_cost_weight["wait_num_epochs"]
            self._dynamic_cost_rate = dynamic_cost_weight["rate"]
        self._penalize_non_agenda_actions = penalize_non_agenda_actions
        self._last_epoch_in_forward: int = None
        # TODO (pradeep): Checking whether file exists here to avoid raising an error when we've
        # copied a trained ERM model from a different machine and the original MML model that was
        # used to initialize it does not exist on the current machine. This may not be the best
        # solution for the problem.
        if initial_mml_model_file is not None:
            if os.path.isfile(initial_mml_model_file):
                archive = load_archive(initial_mml_model_file)
                self._initialize_weights_from_archive(archive)
            else:
                # A model file is passed, but it does not exist. This is expected to happen when
                # you're using a trained ERM model to decode. But it may also happen if the path to
                # the file is really just incorrect. So throwing a warning.
                logger.warning(
                    "MML model file for initializing weights is passed, but does not exist."
                    " This is fine if you're just decoding."
                )