Example #1
    def test_get_actions_to_consider(self):
        # pylint: disable=protected-access
        valid_actions_1 = {'e': [0, 1, 2, 4]}
        valid_actions_2 = {'e': [0, 1, 3]}
        valid_actions_3 = {'e': [2, 3, 4]}
        self.state.grammar_state[0] = GrammarState(['e'], {}, valid_actions_1, {}, is_nonterminal)
        self.state.grammar_state[1] = GrammarState(['e'], {}, valid_actions_2, {}, is_nonterminal)
        self.state.grammar_state[2] = GrammarState(['e'], {}, valid_actions_3, {}, is_nonterminal)

        # We're making a bunch of the actions linked actions here, pretending that there are only
        # two global actions.
        self.state.action_indices = {
                (0, 0): 1,
                (0, 1): 0,
                (0, 2): -1,
                (0, 3): -1,
                (0, 4): -1,
                (1, 0): -1,
                (1, 1): 0,
                (1, 2): -1,
                (1, 3): -1,
                }

        considered, to_embed, to_link = WikiTablesDecoderStep._get_actions_to_consider(self.state)
        # These are _global_ action indices.  They come from actions [[(0, 0), (0, 1)], [(1, 1)], []].
        expected_to_embed = [[1, 0], [0], []]
        assert to_embed == expected_to_embed
        # These are _batch_ action indices with a _global_ action index of -1.
        # They come from actions [[(0, 2), (0, 4)], [(1, 0), (1, 3)], [(0, 2), (0, 3), (0, 4)]].
        expected_to_link = [[2, 4], [0, 3], [2, 3, 4]]
        assert to_link == expected_to_link
        # These are _batch_ action indices, with padding in between the embedded actions and the
        # linked actions (and after the linked actions, if necessary).
        expected_considered = [[0, 1, 2, 4, -1], [1, -1, 0, 3, -1], [-1, -1, 2, 3, 4]]
        assert considered == expected_considered
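These examples all pass an `is_nonterminal` predicate into GrammarState. As a point of reference, here is a minimal stand-in, a sketch matched only to the symbols these examples use; the real tests import their predicate from the language's type declaration module:

NONTERMINALS = {'s', 't', 'r', 'd', 'e'}

def is_nonterminal(symbol: str) -> bool:
    # Minimal stand-in for illustration only, not the library's
    # implementation.  Bracketed function types like '<s,d>' and the basic
    # types used in these examples count as non-terminals; everything else
    # ('identity', lambda variables like 'x', entity productions like 'f')
    # is terminal.
    return (symbol.startswith('<') and symbol.endswith('>')) or symbol in NONTERMINALS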
Example #2
 def test_get_valid_actions_adds_lambda_productions_only_for_correct_type(
         self):
     state = GrammarState(['t'],
                          {('s', 'x'): ['t']},
                          {'s': [1, 2], 't': [3, 4]},
                          {'s -> x': 5},
                          is_nonterminal)
     assert state.get_valid_actions() == [3, 4]
     # We're doing this assert twice to make sure we haven't accidentally modified the state.
     assert state.get_valid_actions() == [3, 4]
Example #3
 def test_get_valid_actions_adds_lambda_productions_only_for_correct_type(self):
     state = GrammarState(['t'],
                          {('s', 'x'): ['t']},
                          {'s': [1, 2], 't': [3, 4]},
                          {'s -> x': 5},
                          is_nonterminal)
     assert state.get_valid_actions() == [3, 4]
     # We're doing this assert twice to make sure we haven't accidentally modified the state.
     assert state.get_valid_actions() == [3, 4]
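A hedged companion sketch to the test above (cf. Example #22): when the top of the stack does have the lambda variable's type, the production is appended to the valid actions.

# Sketch only: with 's' on top of the stack, the lambda variable production
# 's -> x' (action 5) becomes valid alongside the global productions for 's'.
state = GrammarState(['s'], {('s', 'x'): ['s']},
                     {'s': [1, 2], 't': [3, 4]},
                     {'s -> x': 5}, is_nonterminal)
assert state.get_valid_actions() == [1, 2, 5]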
Example #4
    def test_take_action_gives_correct_next_states_with_non_lambda_productions(self):
        # state.take_action() doesn't read or change these objects, it just passes them through, so
        # we'll use some sentinels to be sure of that.
        valid_actions = object()
        action_indices = object()

        state = GrammarState(['s'], {}, valid_actions, action_indices, is_nonterminal)
        next_state = state.take_action('s -> [t, r]')
        expected_next_state = GrammarState(['r', 't'], {}, valid_actions, action_indices, is_nonterminal)
        assert next_state.__dict__ == expected_next_state.__dict__

        state = GrammarState(['r', 't'], {}, valid_actions, action_indices, is_nonterminal)
        next_state = state.take_action('t -> identity')
        expected_next_state = GrammarState(['r'], {}, valid_actions, action_indices, is_nonterminal)
        assert next_state.__dict__ == expected_next_state.__dict__
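Note the stack order in the expected states above: take_action() pushes the right-hand-side non-terminals in reverse, so the leftmost child ('t' in 's -> [t, r]') ends up on top and is expanded first. A hedged sketch of that invariant, assuming the stack attribute is named `_nonterminal_stack` (the tests only compare `__dict__`s, so the name is an assumption):

state = GrammarState(['s'], {}, {}, {}, is_nonterminal)
next_state = state.take_action('s -> [t, r]')
# The top of the stack is the last list element, so 't' is expanded next.
assert next_state._nonterminal_stack == ['r', 't']  # assumed attribute name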
Example #5
    def _create_grammar_state(
            self, world: NlvrWorld,
            possible_actions: List[ProductionRuleArray]) -> GrammarState:
        valid_actions = world.get_valid_actions()
        action_mapping = {}
        for i, action in enumerate(possible_actions):
            action_mapping[action[0]] = i
        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor,
                                                            torch.Tensor,
                                                            List[int]]]] = {}
        for key, action_strings in valid_actions.items():
            translated_valid_actions[key] = {}
            # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid
            # productions of that non-terminal.
            action_indices = [
                action_mapping[action_string]
                for action_string in action_strings
            ]
            # All actions in NLVR are global actions.
            global_actions = [(possible_actions[index][2], index)
                              for index in action_indices]

            # Then we get the embedded representations of the global actions.
            global_action_tensors, global_action_ids = zip(*global_actions)
            global_action_tensor = torch.cat(global_action_tensors, dim=0)
            global_input_embeddings = self._action_embedder(
                global_action_tensor)
            translated_valid_actions[key]['global'] = (global_input_embeddings,
                                                       global_input_embeddings,
                                                       list(global_action_ids))
        return GrammarState([START_SYMBOL], {}, translated_valid_actions, {},
                            type_declaration.is_nonterminal)
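One design point worth noting in this NLVR version: because every action here is global, the 'global' triple reuses `global_input_embeddings` as both the input and the output representation, whereas the WikiTables version in Example #23 computes a separate `global_output_embeddings` with a second embedder.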
Example #6
 def test_get_valid_actions_adds_lambda_productions(self):
     state = GrammarState(['s'],
                          {('s', 'x'): ['s']},
                          {'s': {'global': (torch.Tensor([1, 1]),
                                            torch.Tensor([2, 2]),
                                            [1, 2])}},
                          {'s -> x': (torch.Tensor([5]), torch.Tensor([6]), 5)},
                          is_nonterminal)
     actions = state.get_valid_actions()
     assert_almost_equal(actions['global'][0].cpu().numpy(), [1, 1, 5])
     assert_almost_equal(actions['global'][1].cpu().numpy(), [2, 2, 6])
     assert actions['global'][2] == [1, 2, 5]
     # We're doing this assert twice to make sure we haven't accidentally modified the state.
     actions = state.get_valid_actions()
     assert_almost_equal(actions['global'][0].cpu().numpy(), [1, 1, 5])
     assert_almost_equal(actions['global'][1].cpu().numpy(), [2, 2, 6])
     assert actions['global'][2] == [1, 2, 5]
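The expected values in the asserts above come from simple concatenation: get_valid_actions() appends the lambda production's representations and action id onto the global ones. The same arithmetic in isolation, as a sketch:

import torch

# [1, 1] from the global input rows joined with [5] from the lambda
# production gives the asserted [1, 1, 5].
assert torch.cat([torch.Tensor([1, 1]), torch.Tensor([5])]).tolist() == [1.0, 1.0, 5.0]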
Example #7
 def test_get_actions_to_consider_returns_none_if_no_linked_actions(self):
     # pylint: disable=protected-access
     valid_actions_1 = {'e': [0, 1, 2, 4]}
     valid_actions_2 = {'e': [0, 1, 3]}
     valid_actions_3 = {'e': [2, 3, 4]}
     self.state.grammar_state[0] = GrammarState(['e'], {}, valid_actions_1, {}, is_nonterminal)
     self.state.grammar_state[1] = GrammarState(['e'], {}, valid_actions_2, {}, is_nonterminal)
     self.state.grammar_state[2] = GrammarState(['e'], {}, valid_actions_3, {}, is_nonterminal)
     considered, to_embed, to_link = WikiTablesDecoderStep._get_actions_to_consider(self.state)
     # These are _global_ action indices.  All of the actions in this case are embedded, so this
     # is just a mapping from the valid actions above to their global ids.
     expected_to_embed = [[1, 0, 2, 5], [4, 0, 3], [2, 3, 5]]
     assert to_embed == expected_to_embed
     # There are no linked actions (all of them are embedded), so this should be None.
     assert to_link is None
     # These are _batch_ action indices, with padding in between the embedded actions and the
     # linked actions.  Because there are no linked actions, this is basically just the
     # valid_actions for each group element padded with -1s.
     expected_considered = [[0, 1, 2, 4], [0, 1, 3, -1], [2, 3, 4, -1]]
     assert considered == expected_considered
Example #8
 def test_get_valid_actions_uses_top_of_stack(self):
     state = GrammarState(['s'], {}, {
         's': [1, 2],
         't': [3, 4]
     }, {}, is_nonterminal)
     assert state.get_valid_actions() == [1, 2]
     state = GrammarState(['t'], {}, {
         's': [1, 2],
         't': [3, 4]
     }, {}, is_nonterminal)
     assert state.get_valid_actions() == [3, 4]
     state = GrammarState(['e'], {}, {
         's': [1, 2],
         't': [3, 4],
         'e': []
     }, {}, is_nonterminal)
     assert state.get_valid_actions() == []
Example #9
 def test_get_valid_actions_uses_top_of_stack(self):
     state = GrammarState(['s'], {}, {'s': [1, 2], 't': [3, 4]}, {}, is_nonterminal)
     assert state.get_valid_actions() == [1, 2]
     state = GrammarState(['t'], {}, {'s': [1, 2], 't': [3, 4]}, {}, is_nonterminal)
     assert state.get_valid_actions() == [3, 4]
     state = GrammarState(['e'], {}, {'s': [1, 2], 't': [3, 4], 'e': []}, {}, is_nonterminal)
     assert state.get_valid_actions() == []
Example #10
 def _create_grammar_state(world, possible_actions):
     valid_actions = world.get_valid_actions()
     action_mapping = {}
     for i, action in enumerate(possible_actions):
         action_string = action[0]
         action_mapping[action_string] = i
     translated_valid_actions = {}
     for key, action_strings in list(valid_actions.items()):
         translated_valid_actions[key] = [
             action_mapping[action_string]
             for action_string in action_strings
         ]
     return GrammarState([START_SYMBOL], {}, translated_valid_actions,
                         action_mapping, type_declaration.is_nonterminal)
Example #11
 def test_get_valid_actions_uses_top_of_stack(self):
     s_actions = object()
     t_actions = object()
     e_actions = object()
     state = GrammarState(['s'], {}, {
         's': s_actions,
         't': t_actions
     }, {}, is_nonterminal)
     assert state.get_valid_actions() == s_actions
     state = GrammarState(['t'], {}, {
         's': s_actions,
         't': t_actions
     }, {}, is_nonterminal)
     assert state.get_valid_actions() == t_actions
     state = GrammarState(['e'], {}, {
         's': s_actions,
         't': t_actions,
         'e': e_actions
     }, {}, is_nonterminal)
     assert state.get_valid_actions() == e_actions
Example #12
 def _create_grammar_state(
         world: NlvrWorld,
         possible_actions: List[ProductionRuleArray]) -> GrammarState:
     valid_actions = world.get_valid_actions()
     action_mapping = {}
     for i, action in enumerate(possible_actions):
         action_mapping[action[0]] = i
     translated_valid_actions = {}
     for key, action_strings in valid_actions.items():
         translated_valid_actions[key] = [
             action_mapping[action_string]
             for action_string in action_strings
         ]
     return GrammarState([START_SYMBOL], {}, translated_valid_actions,
                         action_mapping, type_declaration.is_nonterminal)
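Note that this variant (like Example #10) stores plain action indices in `translated_valid_actions` and passes the whole `action_mapping` as the fourth constructor argument, the same slot that the lambda-production tests above fill with entries like {'s -> x': 5}.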
Example #13
    def test_take_action_gives_correct_next_states_with_non_lambda_productions(
            self):
        # state.take_action() doesn't read or change these objects, it just passes them through, so
        # we'll use some sentinels to be sure of that.
        valid_actions = object()
        context_actions = object()

        state = GrammarState(['s'], {}, valid_actions, context_actions,
                             is_nonterminal)
        next_state = state.take_action('s -> [t, r]')
        expected_next_state = GrammarState(['r', 't'], {}, valid_actions,
                                           context_actions, is_nonterminal)
        assert next_state.__dict__ == expected_next_state.__dict__

        state = GrammarState(['r', 't'], {}, valid_actions, context_actions,
                             is_nonterminal)
        next_state = state.take_action('t -> identity')
        expected_next_state = GrammarState(['r'], {}, valid_actions,
                                           context_actions, is_nonterminal)
        assert next_state.__dict__ == expected_next_state.__dict__
Example #14
    def test_take_action_gives_correct_next_states_with_lambda_productions(self):
        # state.take_action() doesn't read or change these objects, it just passes them through, so
        # we'll use some sentinels to be sure of that.
        valid_actions = object()
        action_indices = object()

        state = GrammarState(['t', '<s,d>'], {}, valid_actions, action_indices, is_nonterminal)
        next_state = state.take_action('<s,d> -> [lambda x, d]')
        expected_next_state = GrammarState(['t', 'd'],
                                           {('s', 'x'): ['d']},
                                           valid_actions,
                                           action_indices,
                                           is_nonterminal)
        assert next_state.__dict__ == expected_next_state.__dict__

        state = expected_next_state
        next_state = state.take_action('d -> [<s,r>, d]')
        expected_next_state = GrammarState(['t', 'd', '<s,r>'],
                                           {('s', 'x'): ['d', '<s,r>']},
                                           valid_actions,
                                           action_indices,
                                           is_nonterminal)
        assert next_state.__dict__ == expected_next_state.__dict__

        state = expected_next_state
        next_state = state.take_action('<s,r> -> [lambda y, r]')
        expected_next_state = GrammarState(['t', 'd', 'r'],
                                           {('s', 'x'): ['d', 'r'], ('s', 'y'): ['r']},
                                           valid_actions,
                                           action_indices,
                                           is_nonterminal)
        assert next_state.__dict__ == expected_next_state.__dict__

        state = expected_next_state
        next_state = state.take_action('r -> identity')
        expected_next_state = GrammarState(['t', 'd'],
                                           {('s', 'x'): ['d']},
                                           valid_actions,
                                           action_indices,
                                           is_nonterminal)
        assert next_state.__dict__ == expected_next_state.__dict__

        state = expected_next_state
        next_state = state.take_action('d -> x')
        expected_next_state = GrammarState(['t'],
                                           {},
                                           valid_actions,
                                           action_indices,
                                           is_nonterminal)
        assert next_state.__dict__ == expected_next_state.__dict__
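Taken together, these steps show the lambda bookkeeping: every take_action() pops the top non-terminal, pushes any right-hand-side non-terminals in reverse, and mirrors those pushes and pops on each open lambda's stack; when a lambda's stack empties (as after 'r -> identity' and 'd -> x' above), its variable goes out of scope and its {(type, variable): stack} entry disappears.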
Example #15
    def setUp(self):
        super().setUp()

        batch_indices = [0, 1, 0]
        action_history = [[1], [3, 4], []]
        score = [Variable(torch.FloatTensor([x])) for x in [.1, 1.1, 2.2]]
        hidden_state = torch.FloatTensor([[i, i]
                                          for i in range(len(batch_indices))])
        memory_cell = torch.FloatTensor([[i, i]
                                         for i in range(len(batch_indices))])
        previous_action_embedding = torch.FloatTensor(
            [[i, i] for i in range(len(batch_indices))])
        attended_question = torch.FloatTensor(
            [[i, i] for i in range(len(batch_indices))])
        grammar_state = [
            GrammarState(['e'], {}, {}, {}, is_nonterminal)
            for _ in batch_indices
        ]
        self.encoder_outputs = torch.FloatTensor([[1, 2], [3, 4], [5, 6]])
        self.encoder_output_mask = Variable(
            torch.FloatTensor([[1, 1], [1, 0], [1, 1]]))
        self.action_embeddings = torch.FloatTensor([[0, 0], [1, 1], [2, 2],
                                                    [3, 3], [4, 4], [5, 5]])
        self.action_indices = {
            (0, 0): 1,
            (0, 1): 0,
            (0, 2): 2,
            (0, 3): 3,
            (0, 4): 5,
            (1, 0): 4,
            (1, 1): 0,
            (1, 2): 2,
            (1, 3): 3,
        }
        self.possible_actions = [[('e -> f', False, None),
                                  ('e -> g', True, None),
                                  ('e -> h', True, None),
                                  ('e -> i', True, None),
                                  ('e -> j', True, None)],
                                 [('e -> q', True, None),
                                  ('e -> g', True, None),
                                  ('e -> h', True, None),
                                  ('e -> i', True, None)]]

        # (batch_size, num_entities, num_question_tokens) = (2, 5, 3)
        linking_scores = Variable(
            torch.Tensor([[[.1, .2, .3], [.4, .5, .6], [.7, .8, .9],
                           [1.0, 1.1, 1.2], [1.3, 1.4, 1.5]],
                          [[-.1, -.2, -.3], [-.4, -.5, -.6], [-.7, -.8, -.9],
                           [-1.0, -1.1, -1.2], [-1.3, -1.4, -1.5]]]))
        flattened_linking_scores = linking_scores.view(2 * 5, 3)

        # Maps (batch_index, action_index) to indices into the flattened linking score tensor,
        # which has shape (batch_size * num_entities, num_question_tokens).
        actions_to_entities = {
            (0, 0): 0,
            (0, 1): 1,
            (0, 2): 2,
            (0, 6): 3,
            (1, 3): 6,
            (1, 4): 7,
            (1, 5): 8,
        }
        entity_types = {
            0: 0,
            1: 2,
            2: 1,
            3: 0,
            4: 0,
            5: 1,
            6: 0,
            7: 1,
            8: 2,
        }
        rnn_state = []
        for i in range(len(batch_indices)):
            rnn_state.append(
                RnnState(hidden_state[i], memory_cell[i],
                         previous_action_embedding[i], attended_question[i],
                         self.encoder_outputs, self.encoder_output_mask))
        self.state = WikiTablesDecoderState(
            batch_indices=batch_indices,
            action_history=action_history,
            score=score,
            rnn_state=rnn_state,
            grammar_state=grammar_state,
            action_embeddings=self.action_embeddings,
            action_indices=self.action_indices,
            possible_actions=self.possible_actions,
            flattened_linking_scores=flattened_linking_scores,
            actions_to_entities=actions_to_entities,
            entity_types=entity_types)
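A hedged sketch of the indexing scheme behind actions_to_entities, runnable at the end of the setUp above: flat row `batch_index * num_entities + entity_index`, so the entry (1, 3) -> 6 points at entity 1 of batch 1 (6 == 1 * 5 + 1).

        assert flattened_linking_scores[actions_to_entities[(1, 3)]].data.tolist() == \
               linking_scores[1, 1].data.tolist()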
Example #16
 def test_is_finished_just_uses_nonterminal_stack(self):
     state = GrammarState(['s'], {}, {}, {}, is_nonterminal)
     assert not state.is_finished()
     state = GrammarState([], {}, {}, {}, is_nonterminal)
     assert state.is_finished()
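A hedged sketch of driving a state to completion, reusing the stand-in is_nonterminal from the note under Example #1: expanding the last non-terminal with a terminal production empties the stack and flips is_finished().

state = GrammarState(['s'], {}, {}, {}, is_nonterminal)
state = state.take_action('s -> identity')  # 'identity' is a terminal
assert state.is_finished()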
Example #17
 def test_is_finished_just_uses_nonterminal_stack(self):
     state = GrammarState(['s'], {}, {}, {}, is_nonterminal)
     assert not state.is_finished()
     state = GrammarState([], {}, {}, {}, is_nonterminal)
     assert state.is_finished()
Example #18
    def test_take_action_gives_correct_next_states_with_lambda_productions(
            self):
        # state.take_action() doesn't read or change these objects, it just passes them through, so
        # we'll use some sentinels to be sure of that.
        valid_actions = object()
        context_actions = object()

        state = GrammarState(['t', '<s,d>'], {}, valid_actions,
                             context_actions, is_nonterminal)
        next_state = state.take_action('<s,d> -> [lambda x, d]')
        expected_next_state = GrammarState(['t', 'd'], {('s', 'x'): ['d']},
                                           valid_actions, context_actions,
                                           is_nonterminal)
        assert next_state.__dict__ == expected_next_state.__dict__

        state = expected_next_state
        next_state = state.take_action('d -> [<s,r>, d]')
        expected_next_state = GrammarState(['t', 'd', '<s,r>'],
                                           {('s', 'x'): ['d', '<s,r>']},
                                           valid_actions, context_actions,
                                           is_nonterminal)
        assert next_state.__dict__ == expected_next_state.__dict__

        state = expected_next_state
        next_state = state.take_action('<s,r> -> [lambda y, r]')
        expected_next_state = GrammarState(['t', 'd', 'r'], {
            ('s', 'x'): ['d', 'r'],
            ('s', 'y'): ['r']
        }, valid_actions, context_actions, is_nonterminal)
        assert next_state.__dict__ == expected_next_state.__dict__

        state = expected_next_state
        next_state = state.take_action('r -> identity')
        expected_next_state = GrammarState(['t', 'd'], {('s', 'x'): ['d']},
                                           valid_actions, context_actions,
                                           is_nonterminal)
        assert next_state.__dict__ == expected_next_state.__dict__

        state = expected_next_state
        next_state = state.take_action('d -> x')
        expected_next_state = GrammarState(['t'], {}, valid_actions,
                                           context_actions, is_nonterminal)
        assert next_state.__dict__ == expected_next_state.__dict__
Example #19
 def test_take_action_crashes_with_mismatched_types(self):
     with pytest.raises(AssertionError):
         state = GrammarState(['s'], {}, {}, {}, is_nonterminal)
         state.take_action('t -> identity')
Example #20
    def setUp(self):
        super().setUp()
        self.decoder_step = BasicTransitionFunction(
            encoder_output_dim=2,
            action_embedding_dim=2,
            input_attention=Attention.by_name('dot_product')(),
            num_start_types=3,
            add_action_bias=False)

        batch_indices = [0, 1, 0]
        action_history = [[1], [3, 4], []]
        score = [torch.FloatTensor([x]) for x in [.1, 1.1, 2.2]]
        hidden_state = torch.FloatTensor([[i, i]
                                          for i in range(len(batch_indices))])
        memory_cell = torch.FloatTensor([[i, i]
                                         for i in range(len(batch_indices))])
        previous_action_embedding = torch.FloatTensor(
            [[i, i] for i in range(len(batch_indices))])
        attended_question = torch.FloatTensor(
            [[i, i] for i in range(len(batch_indices))])
        # This maps non-terminals to valid actions, where the valid actions are grouped by _type_.
        # We have "global" actions, which are from the global grammar, and "linked" actions, which
        # are instance-specific and are generated based on question attention.  Each action type
        # has a tuple which is (input representation, output representation, action ids).
        valid_actions = {
            'e': {
                'global': (torch.FloatTensor([[0, 0], [-1, -1], [-2, -2]]),
                           torch.FloatTensor([[-1, -1], [-2, -2], [-3, -3]]),
                           [0, 1, 2]),
                'linked': (torch.FloatTensor([[.1, .2, .3], [.4, .5, .6]]),
                           torch.FloatTensor([[3, 3], [4, 4]]),
                           [3, 4])
            },
            'd': {
                'global': (torch.FloatTensor([[0, 0]]),
                           torch.FloatTensor([[-1, -1]]),
                           [0]),
                'linked': (torch.FloatTensor([[-.1, -.2, -.3], [-.4, -.5, -.6],
                                              [-.7, -.8, -.9]]),
                           torch.FloatTensor([[5, 5], [6, 6], [7, 7]]),
                           [1, 2, 3])
            }
        }
        grammar_state = [
            GrammarState([nonterminal], {}, valid_actions, {}, is_nonterminal)
            for _, nonterminal in zip(batch_indices, ['e', 'd', 'e'])
        ]
        self.encoder_outputs = torch.FloatTensor([[[1, 2], [3, 4], [5, 6]],
                                                  [[10, 11], [12, 13],
                                                   [14, 15]]])
        self.encoder_output_mask = torch.FloatTensor([[1, 1, 1], [1, 1, 0]])
        self.possible_actions = [[('e -> f', False, None),
                                  ('e -> g', True, None),
                                  ('e -> h', True, None),
                                  ('e -> i', True, None),
                                  ('e -> j', True, None)],
                                 [('d -> q', True, None),
                                  ('d -> g', True, None),
                                  ('d -> h', True, None),
                                  ('d -> i', True, None)]]

        rnn_state = []
        for i in range(len(batch_indices)):
            rnn_state.append(
                RnnState(hidden_state[i], memory_cell[i],
                         previous_action_embedding[i], attended_question[i],
                         self.encoder_outputs, self.encoder_output_mask))
        self.state = GrammarBasedDecoderState(
            batch_indices=batch_indices,
            action_history=action_history,
            score=score,
            rnn_state=rnn_state,
            grammar_state=grammar_state,
            possible_actions=self.possible_actions)
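A hedged sketch of unpacking one entry of the valid_actions structure described in the comment above (using the fixture values from this setUp): for 'e', the linked input representations are linking scores over the three question tokens, while the output representations are 2-dimensional embeddings.

        linked_input, linked_output, linked_ids = valid_actions['e']['linked']
        assert tuple(linked_input.size()) == (2, 3)   # scores over question tokens
        assert tuple(linked_output.size()) == (2, 2)  # output embedding per action
        assert linked_ids == [3, 4]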
Example #21
 def test_take_action_crashes_with_mismatched_types(self):
     with pytest.raises(AssertionError):
         state = GrammarState(['s'], {}, {}, {}, is_nonterminal)
         state.take_action('t -> identity')
Example #22
 def test_get_valid_actions_adds_lambda_productions(self):
     state = GrammarState([u's'], {(u's', u'x'): [u's']}, {u's': [1, 2]},
                          {u's -> x': 5}, is_nonterminal)
     assert state.get_valid_actions() == [1, 2, 5]
     # We're doing this assert twice to make sure we haven't accidentally modified the state.
     assert state.get_valid_actions() == [1, 2, 5]
Example #23
    def _create_grammar_state(self, world: WikiTablesWorld,
                              possible_actions: List[ProductionRuleArray],
                              linking_scores: torch.Tensor,
                              entity_types: torch.Tensor) -> GrammarState:
        """
        This method creates the GrammarState object that's used for decoding.  Part of creating
        that is creating the `valid_actions` dictionary, which contains embedded representations of
        all of the valid actions.  So, we create that here as well.

        The inputs to this method are for a `single instance in the batch`; none of the tensors we
        create here are batched.  We grab the global action ids from the input
        ``ProductionRuleArrays``, and we use those to embed the valid actions for every
        non-terminal type.  We use the input ``linking_scores`` for non-global actions.

        Parameters
        ----------
        world : ``WikiTablesWorld``
            From the input to ``forward`` for a single batch instance.
        possible_actions : ``List[ProductionRuleArray]``
            From the input to ``forward`` for a single batch instance.
        linking_scores : ``torch.Tensor``
            Assumed to have shape ``(num_entities, num_question_tokens)`` (i.e., there is no batch
            dimension).
        entity_types : ``torch.Tensor``
            Assumed to have shape ``(num_entities,)`` (i.e., there is no batch dimension).
        """
        action_map = {}
        for action_index, action in enumerate(possible_actions):
            action_string = action[0]
            action_map[action_string] = action_index
        entity_map = {}
        for entity_index, entity in enumerate(world.table_graph.entities):
            entity_map[entity] = entity_index

        valid_actions = world.get_valid_actions()
        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor,
                                                            torch.Tensor,
                                                            List[int]]]] = {}
        for key, action_strings in valid_actions.items():
            translated_valid_actions[key] = {}
            # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid
            # productions of that non-terminal.  We'll first split those productions by global vs.
            # linked action.
            action_indices = [
                action_map[action_string] for action_string in action_strings
            ]
            production_rule_arrays = [(possible_actions[index], index)
                                      for index in action_indices]
            global_actions = []
            linked_actions = []
            for production_rule_array, action_index in production_rule_arrays:
                if production_rule_array[1]:
                    global_actions.append(
                        (production_rule_array[2], action_index))
                else:
                    linked_actions.append(
                        (production_rule_array[0], action_index))

            # Then we get the embedded representations of the global actions.
            global_action_tensors, global_action_ids = zip(*global_actions)
            global_action_tensor = torch.cat(global_action_tensors, dim=0)
            global_input_embeddings = self._action_embedder(
                global_action_tensor)
            if self._add_action_bias:
                global_action_biases = self._action_biases(
                    global_action_tensor)
                global_input_embeddings = torch.cat(
                    [global_input_embeddings, global_action_biases], dim=-1)
            global_output_embeddings = self._output_action_embedder(
                global_action_tensor)
            translated_valid_actions[key]['global'] = (
                global_input_embeddings, global_output_embeddings,
                list(global_action_ids))

            # Then the representations of the linked actions.
            if linked_actions:
                linked_rules, linked_action_ids = zip(*linked_actions)
                entities = [rule.split(' -> ')[1] for rule in linked_rules]
                entity_ids = [entity_map[entity] for entity in entities]
                # (num_linked_actions, num_question_tokens)
                entity_linking_scores = linking_scores[entity_ids]
                # (num_linked_actions,)
                entity_type_tensor = entity_types[entity_ids]
                # (num_linked_actions, entity_type_embedding_dim)
                entity_type_embeddings = self._entity_type_decoder_embedding(
                    entity_type_tensor)
                translated_valid_actions[key]['linked'] = (
                    entity_linking_scores, entity_type_embeddings,
                    list(linked_action_ids))

        # Lastly, we need to also create embedded representations of context-specific actions.  In
        # this case, those are only variable productions, like "r -> x".  Note that our language
        # only permits one lambda at a time, so we don't need to worry about how nested lambdas
        # might impact this.
        context_actions = {}
        for action_id, action in enumerate(possible_actions):
            if action[0].endswith(" -> x"):
                input_embedding = self._action_embedder(action[2])
                if self._add_action_bias:
                    input_bias = self._action_biases(action[2])
                    input_embedding = torch.cat([input_embedding, input_bias],
                                                dim=-1)
                output_embedding = self._output_action_embedder(action[2])
                context_actions[action[0]] = (input_embedding,
                                              output_embedding, action_id)

        return GrammarState([START_SYMBOL], {}, translated_valid_actions,
                            context_actions, type_declaration.is_nonterminal)
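The GrammarState returned here thus bundles all three kinds of action representations seen across these examples: embedded global actions, linking-score-based linked actions, and context actions for lambda variables (the fourth constructor argument, which the lambda-production tests above exercise).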