Example 1
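All of the snippets below are methods lifted from a class body, so they reference names they do not define. To run the test methods standalone, a preamble along these lines is needed. Both the import path (AllenNLP 0.x's ``allennlp.state_machines.states``) and the ``is_nonterminal`` helper are best-effort assumptions; the helper is reconstructed here to be consistent with every test in this section, not copied from the library.

import pytest
import torch
from numpy.testing import assert_almost_equal

# Assumed import path (AllenNLP 0.x); adjust for your version.
from allennlp.state_machines.states import LambdaGrammarStatelet


def is_nonterminal(symbol: str) -> bool:
    # Reconstructed helper: 'identity', lambda tokens, and the variables
    # x/y/z are terminals; everything else ('s', 't', 'd', '<s,d>', ...)
    # is a non-terminal.  This matches the expectations of every test below.
    if symbol == 'identity':
        return False
    if 'lambda ' in symbol:
        return False
    if symbol in {'x', 'y', 'z'}:
        return False
    return True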
def test_get_valid_actions_adds_lambda_productions_only_for_correct_type(self):
    state = LambdaGrammarStatelet(
        ["t"],
        {("s", "x"): ["t"]},
        {
            "s": {"global": (torch.Tensor([1, 1]), torch.Tensor([2, 2]), [1, 2])},
            "t": {"global": (torch.Tensor([3, 3]), torch.Tensor([4, 4]), [3, 4])},
        },
        {"s -> x": (torch.Tensor([5]), torch.Tensor([6]), 5)},
        is_nonterminal,
    )
    actions = state.get_valid_actions()
    assert_almost_equal(actions["global"][0].cpu().numpy(), [3, 3])
    assert_almost_equal(actions["global"][1].cpu().numpy(), [4, 4])
    assert actions["global"][2] == [3, 4]
    # We're doing this assert twice to make sure we haven't accidentally modified the state.
    actions = state.get_valid_actions()
    assert_almost_equal(actions["global"][0].cpu().numpy(), [3, 3])
    assert_almost_equal(actions["global"][1].cpu().numpy(), [4, 4])
    assert actions["global"][2] == [3, 4]
Example 2
def test_get_valid_actions_adds_lambda_productions(self):
    state = LambdaGrammarStatelet(
        ['s'],
        {('s', 'x'): ['s']},
        {'s': {'global': (torch.Tensor([1, 1]), torch.Tensor([2, 2]), [1, 2])}},
        {'s -> x': (torch.Tensor([5]), torch.Tensor([6]), 5)},
        is_nonterminal,
    )
    actions = state.get_valid_actions()
    assert_almost_equal(actions['global'][0].cpu().numpy(), [1, 1, 5])
    assert_almost_equal(actions['global'][1].cpu().numpy(), [2, 2, 6])
    assert actions['global'][2] == [1, 2, 5]
    # We're doing this assert twice to make sure we haven't accidentally modified the state.
    actions = state.get_valid_actions()
    assert_almost_equal(actions['global'][0].cpu().numpy(), [1, 1, 5])
    assert_almost_equal(actions['global'][1].cpu().numpy(), [2, 2, 6])
    assert actions['global'][2] == [1, 2, 5]
Example 3
def test_get_valid_actions_uses_top_of_stack(self):
    s_actions = object()
    t_actions = object()
    e_actions = object()
    state = LambdaGrammarStatelet(['s'], {}, {'s': s_actions, 't': t_actions}, {}, is_nonterminal)
    assert state.get_valid_actions() == s_actions
    state = LambdaGrammarStatelet(['t'], {}, {'s': s_actions, 't': t_actions}, {}, is_nonterminal)
    assert state.get_valid_actions() == t_actions
    state = LambdaGrammarStatelet(
        ['e'], {}, {'s': s_actions, 't': t_actions, 'e': e_actions}, {}, is_nonterminal
    )
    assert state.get_valid_actions() == e_actions
Example 4
def test_take_action_gives_correct_next_states_with_non_lambda_productions(self):
    # state.take_action() doesn't read or change these objects, it just passes them through, so
    # we'll use some sentinels to be sure of that.
    valid_actions = object()
    context_actions = object()

    state = LambdaGrammarStatelet(['s'], {}, valid_actions, context_actions, is_nonterminal)
    next_state = state.take_action('s -> [t, r]')
    expected_next_state = LambdaGrammarStatelet(
        ['r', 't'], {}, valid_actions, context_actions, is_nonterminal
    )
    assert next_state.__dict__ == expected_next_state.__dict__

    state = LambdaGrammarStatelet(['r', 't'], {}, valid_actions, context_actions, is_nonterminal)
    next_state = state.take_action('t -> identity')
    expected_next_state = LambdaGrammarStatelet(['r'], {}, valid_actions, context_actions, is_nonterminal)
    assert next_state.__dict__ == expected_next_state.__dict__
Example 5
def test_is_finished_just_uses_nonterminal_stack(self):
    state = LambdaGrammarStatelet(["s"], {}, {}, {}, is_nonterminal)
    assert not state.is_finished()
    state = LambdaGrammarStatelet([], {}, {}, {}, is_nonterminal)
    assert state.is_finished()
Example 6
def test_take_action_gives_correct_next_states_with_lambda_productions(self):
    # state.take_action() doesn't read or change these objects, it just passes them through, so
    # we'll use some sentinels to be sure of that.
    valid_actions = object()
    context_actions = object()

    state = LambdaGrammarStatelet(
        ["t", "<s,d>"], {}, valid_actions, context_actions, is_nonterminal
    )
    next_state = state.take_action("<s,d> -> [lambda x, d]")
    expected_next_state = LambdaGrammarStatelet(
        ["t", "d"],
        {("s", "x"): ["d"]},
        valid_actions,
        context_actions,
        is_nonterminal,
    )
    assert next_state.__dict__ == expected_next_state.__dict__

    state = expected_next_state
    next_state = state.take_action("d -> [<s,r>, d]")
    expected_next_state = LambdaGrammarStatelet(
        ["t", "d", "<s,r>"],
        {("s", "x"): ["d", "<s,r>"]},
        valid_actions,
        context_actions,
        is_nonterminal,
    )
    assert next_state.__dict__ == expected_next_state.__dict__

    state = expected_next_state
    next_state = state.take_action("<s,r> -> [lambda y, r]")
    expected_next_state = LambdaGrammarStatelet(
        ["t", "d", "r"],
        {("s", "x"): ["d", "r"], ("s", "y"): ["r"]},
        valid_actions,
        context_actions,
        is_nonterminal,
    )
    assert next_state.__dict__ == expected_next_state.__dict__

    state = expected_next_state
    next_state = state.take_action("r -> identity")
    expected_next_state = LambdaGrammarStatelet(
        ["t", "d"],
        {("s", "x"): ["d"]},
        valid_actions,
        context_actions,
        is_nonterminal,
    )
    assert next_state.__dict__ == expected_next_state.__dict__

    state = expected_next_state
    next_state = state.take_action("d -> x")
    expected_next_state = LambdaGrammarStatelet(
        ["t"], {}, valid_actions, context_actions, is_nonterminal
    )
    assert next_state.__dict__ == expected_next_state.__dict__
Example 7
def test_take_action_crashes_with_mismatched_types(self):
    with pytest.raises(AssertionError):
        state = LambdaGrammarStatelet(["s"], {}, {}, {}, is_nonterminal)
        state.take_action("t -> identity")
Example 8
def _create_grammar_state(
    self,
    world: WikiTablesWorld,
    possible_actions: List[ProductionRule],
    linking_scores: torch.Tensor,
    entity_types: torch.Tensor,
) -> LambdaGrammarStatelet:
        """
        This method creates the LambdaGrammarStatelet object that's used for decoding.  Part of
        creating that is creating the `valid_actions` dictionary, which contains embedded
        representations of all of the valid actions.  So, we create that here as well.

        The way we represent the valid expansions is a little complicated: we use a
        dictionary of `action types`, where the key is the action type (like "global", "linked", or
        whatever your model is expecting), and the value is a tuple representing all actions of
        that type.  The tuple is (input tensor, output tensor, action id).  The input tensor has
        the representation that is used when `selecting` actions, for all actions of this type.
        The output tensor has the representation that is used when feeding the action to the next
        step of the decoder (this could just be the same as the input tensor).  The action ids are
        a list of indices into the main action list for each batch instance.

        The inputs to this method are for a `single instance in the batch`; none of the tensors we
        create here are batched.  We grab the global action ids from the input
        ``ProductionRules``, and we use those to embed the valid actions for every
        non-terminal type.  We use the input ``linking_scores`` for non-global actions.

        Parameters
        ----------
        world : ``WikiTablesWorld``
            From the input to ``forward`` for a single batch instance.
        possible_actions : ``List[ProductionRule]``
            From the input to ``forward`` for a single batch instance.
        linking_scores : ``torch.Tensor``
            Assumed to have shape ``(num_entities, num_question_tokens)`` (i.e., there is no batch
            dimension).
        entity_types : ``torch.Tensor``
            Assumed to have shape ``(num_entities,)`` (i.e., there is no batch dimension).
        """
    # TODO(mattg): Move the "valid_actions" construction to another method.
    action_map = {}
    for action_index, action in enumerate(possible_actions):
        action_string = action[0]
        action_map[action_string] = action_index
    entity_map = {}
    for entity_index, entity in enumerate(world.table_graph.entities):
        entity_map[entity] = entity_index

    valid_actions = world.get_valid_actions()
    translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {}
    for key, action_strings in valid_actions.items():
        translated_valid_actions[key] = {}
        # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid
        # productions of that non-terminal.  We'll first split those productions by global vs.
        # linked action.
        action_indices = [action_map[action_string] for action_string in action_strings]
        production_rule_arrays = [(possible_actions[index], index) for index in action_indices]
        global_actions = []
        linked_actions = []
        for production_rule_array, action_index in production_rule_arrays:
            if production_rule_array[1]:
                global_actions.append((production_rule_array[2], action_index))
            else:
                linked_actions.append((production_rule_array[0], action_index))
        # Then we get the embedded representations of the global actions.
        global_action_tensors, global_action_ids = zip(*global_actions)
        global_action_tensor = torch.cat(global_action_tensors, dim=0)
        global_input_embeddings = self._action_embedder(global_action_tensor)
        if self._add_action_bias:
            global_action_biases = self._action_biases(global_action_tensor)
            global_input_embeddings = torch.cat([global_input_embeddings, global_action_biases], dim=-1)
        global_output_embeddings = self._output_action_embedder(global_action_tensor)
        translated_valid_actions[key]['global'] = (
            global_input_embeddings,
            global_output_embeddings,
            list(global_action_ids),
        )

        # Then the representations of the linked actions.
        if linked_actions:
            linked_rules, linked_action_ids = zip(*linked_actions)
            entities = [rule.split(' -> ')[1] for rule in linked_rules]
            entity_ids = [entity_map[entity] for entity in entities]
            # (num_linked_actions, num_question_tokens)
            entity_linking_scores = linking_scores[entity_ids]
            # (num_linked_actions,)
            entity_type_tensor = entity_types[entity_ids]
            # (num_linked_actions, entity_type_embedding_dim)
            entity_type_embeddings = self._entity_type_decoder_embedding(entity_type_tensor)
            translated_valid_actions[key]['linked'] = (
                entity_linking_scores,
                entity_type_embeddings,
                list(linked_action_ids),
            )
    # Lastly, we need to also create embedded representations of context-specific actions.  In
    # this case, those are only variable productions, like "r -> x".  Note that our language
    # only permits one lambda at a time, so we don't need to worry about how nested lambdas
    # might impact this.
    context_actions = {}
    for action_id, action in enumerate(possible_actions):
        if action[0].endswith(" -> x"):
            input_embedding = self._action_embedder(action[2])
            if self._add_action_bias:
                input_bias = self._action_biases(action[2])
                input_embedding = torch.cat([input_embedding, input_bias], dim=-1)
            output_embedding = self._output_action_embedder(action[2])
            context_actions[action[0]] = (input_embedding, output_embedding, action_id)
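    # For illustration only (made-up action id), after this loop context_actions
    # might look like: {"r -> x": (input_embedding, output_embedding, 12)}.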

    return LambdaGrammarStatelet(
        [START_SYMBOL],
        {},
        translated_valid_actions,
        context_actions,
        type_declaration.is_nonterminal,
    )