def test_atis_grammar_statelet(self):
        """Applying a complete ATIS action sequence should leave the nonterminal stack empty."""
        world = AtisWorld([("give me all flights from boston to "
                            "philadelphia next week arriving after lunch")])
        # The productions are applied in exactly this order, starting from 'statement'.
        action_sequence = [
            'statement -> [query, ";"]',
            'query -> ["(", "SELECT", distinct, select_results, "FROM", table_refs, '
            'where_clause, ")"]',
            'distinct -> ["DISTINCT"]',
            'select_results -> [col_refs]',
            'col_refs -> [col_ref, ",", col_refs]',
            'col_ref -> ["city", ".", "city_code"]',
            'col_refs -> [col_ref]',
            'col_ref -> ["city", ".", "city_name"]',
            'table_refs -> [table_name]',
            'table_name -> ["city"]',
            'where_clause -> ["WHERE", "(", conditions, ")"]',
            'conditions -> [condition]',
            'condition -> [biexpr]',
            'biexpr -> ["city", ".", "city_name", binaryop, city_city_name_string]',
            'binaryop -> ["="]',
            'city_city_name_string -> ["\'BOSTON\'"]',
        ]

        grammar_state = GrammarStatelet(['statement'],
                                        world.valid_actions,
                                        AtisSemanticParser.is_nonterminal)
        for action in action_sequence:
            grammar_state = grammar_state.take_action(action)
        # Nothing should remain to be expanded once the whole sequence has been applied.
        assert grammar_state._nonterminal_stack == []
Ejemplo n.º 2
0
    def test_atis_grammar_statelet(self):
        """Applying a complete ATIS action sequence should leave the nonterminal stack empty."""
        world = AtisWorld([("give me all flights from boston to "
                            "philadelphia next week arriving after lunch")])

        # The productions are applied in exactly this order, starting from 'statement'; the
        # statelet below is built with ``reverse_productions=False``.
        action_seq = [
            'statement -> [query, ";"]',
            'query -> ["(", "SELECT", distinct, select_results, "FROM", table_refs, '
            'where_clause, ")"]',
            'where_clause -> ["WHERE", "(", conditions, ")"]',
            'conditions -> [condition]',
            'condition -> [biexpr]',
            'biexpr -> ["city", ".", "city_name", binaryop, city_city_name_string]',
            'city_city_name_string -> ["\'BOSTON\'"]',
            'binaryop -> ["="]',
            'table_refs -> [table_name]',
            'table_name -> ["city"]',
            'select_results -> [col_refs]',
            'col_refs -> [col_ref, ",", col_refs]',
            'col_refs -> [col_ref]',
            'col_ref -> ["city", ".", "city_name"]',
            'col_ref -> ["city", ".", "city_code"]',
            'distinct -> ["DISTINCT"]',
        ]

        grammar_state = GrammarStatelet(['statement'], {},
                                        world.valid_actions, {},
                                        AtisSemanticParser.is_nonterminal,
                                        reverse_productions=False)
        for action in action_seq:
            grammar_state = grammar_state.take_action(action)
        # Nothing should remain to be expanded once the whole sequence has been applied.
        assert grammar_state._nonterminal_stack == []
    def test_grammar_statelet(self):
        """Replaying the world's action sequence for a query should empty the nonterminal stack."""
        world = Text2SqlWorld(self.schema)

        sql = ["SELECT", "COUNT", "(", "*", ")", "FROM", "LOCATION", ",", "RESTAURANT", ";"]
        # The world returns both the action sequence for ``sql`` and the actions handed to
        # the statelet.
        action_sequence, valid_actions = world.get_action_sequence_and_all_actions(sql)

        grammar_state = GrammarStatelet(
            ["statement"], valid_actions, Text2SqlParser.is_nonterminal, reverse_productions=True
        )
        for action in action_sequence:
            grammar_state = grammar_state.take_action(action)
        # Nothing should remain to be expanded once the whole sequence has been applied.
        assert grammar_state._nonterminal_stack == []
Ejemplo n.º 4
0
    def test_grammar_statelet(self):
        """Replaying the world's action sequence for a query should empty the nonterminal stack."""
        world = Text2SqlWorld(self.schema)

        sql = ['SELECT', 'COUNT', '(', '*', ')', 'FROM', 'LOCATION', ',', 'RESTAURANT', ';']
        # The world returns both the action sequence for ``sql`` and the actions handed to
        # the statelet.
        action_sequence, valid_actions = world.get_action_sequence_and_all_actions(sql)

        grammar_state = GrammarStatelet(['statement'],
                                        valid_actions,
                                        Text2SqlParser.is_nonterminal,
                                        reverse_productions=True)
        for action in action_sequence:
            grammar_state = grammar_state.take_action(action)
        # Nothing should remain to be expanded once the whole sequence has been applied.
        assert grammar_state._nonterminal_stack == [] # pylint: disable=protected-access
Ejemplo n.º 5
0
 def test_get_valid_actions_uses_top_of_stack(self):
     """``get_valid_actions`` should return the entry stored for the symbol on the stack."""
     # Opaque sentinels, so the assertions can only pass if the exact stored value comes back.
     s_actions, t_actions, e_actions = object(), object(), object()
     # (symbol on the stack, actions passed to the statelet, expected result)
     cases = [
         ("s", {"s": s_actions, "t": t_actions}, s_actions),
         ("t", {"s": s_actions, "t": t_actions}, t_actions),
         ("e", {"s": s_actions, "t": t_actions, "e": e_actions}, e_actions),
     ]
     for symbol, actions_by_symbol, expected in cases:
         state = GrammarStatelet([symbol], actions_by_symbol, is_nonterminal)
         assert state.get_valid_actions() == expected
    def _create_grammar_state(
            self, world: NlvrWorld,
            possible_actions: List[ProductionRuleArray]) -> GrammarStatelet:
        """
        Build the initial ``GrammarStatelet`` for one instance, embedding its valid actions.

        Parameters
        ----------
        world : ``NlvrWorld``
            Supplies the valid productions per non-terminal via ``get_valid_actions()``.
        possible_actions : ``List[ProductionRuleArray]``
            Each entry is read as ``action[0]`` (the production string) and ``action[2]`` (the
            tensor that is concatenated and passed to the action embedder).

        Returns
        -------
        ``GrammarStatelet``
            Starts from ``START_SYMBOL``.  For each non-terminal that has productions, its
            ``'global'`` entry is ``(embeddings, embeddings, action ids)``; the same embedding
            is used in both tensor slots.
        """
        valid_actions = world.get_valid_actions()
        # Production string -> index into ``possible_actions``.
        action_mapping = {action[0]: index for index, action in enumerate(possible_actions)}
        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor,
                                                            torch.Tensor,
                                                            List[int]]]] = {}
        for key, action_strings in valid_actions.items():
            translated_valid_actions[key] = {}
            # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid
            # productions of that non-terminal.
            action_indices = [action_mapping[action_string]
                              for action_string in action_strings]
            # All actions in NLVR are global actions.
            global_actions = [(possible_actions[index][2], index)
                              for index in action_indices]
            if not global_actions:
                # ``zip(*[])`` yields nothing to unpack and ``torch.cat`` rejects an empty list,
                # so a non-terminal without productions keeps an empty entry instead.
                continue

            # Then we get the embedded representations of the global actions.
            global_action_tensors, global_action_ids = zip(*global_actions)
            global_action_tensor = torch.cat(global_action_tensors, dim=0)
            global_input_embeddings = self._action_embedder(global_action_tensor)
            translated_valid_actions[key]['global'] = (global_input_embeddings,
                                                       global_input_embeddings,
                                                       list(global_action_ids))
        return GrammarStatelet([START_SYMBOL], {}, translated_valid_actions,
                               {}, type_declaration.is_nonterminal)
    def test_grammar_statelet(self):
        """Replaying the world's action sequence for a query should empty the nonterminal stack."""
        world = Text2SqlWorld(self.schema)

        sql = [
            'SELECT', 'COUNT', '(', '*', ')', 'FROM', 'LOCATION', ',',
            'RESTAURANT', ';'
        ]
        # The world returns both the action sequence for ``sql`` and the actions handed to
        # the statelet.
        action_sequence, valid_actions = world.get_action_sequence_and_all_actions(
            sql)

        grammar_state = GrammarStatelet(['statement'],
                                        valid_actions,
                                        Text2SqlParser.is_nonterminal,
                                        reverse_productions=True)
        for action in action_sequence:
            grammar_state = grammar_state.take_action(action)
        # Nothing should remain to be expanded once the whole sequence has been applied.
        assert grammar_state._nonterminal_stack == []  # pylint: disable=protected-access
Ejemplo n.º 8
0
    def _create_grammar_state(
            self, possible_actions: List[ProductionRule]) -> GrammarStatelet:
        """
        This method creates the GrammarStatelet object that's used for decoding.  Part of creating
        that is creating the `valid_actions` dictionary, which contains embedded representations of
        all of the valid actions.  So, we create that here as well.

        The inputs to this method are for a `single instance in the batch`; none of the tensors we
        create here are batched.  We grab the global action ids from the input ``ProductionRules``,
        and we use those to embed the valid actions for every non-terminal type.

        Parameters
        ----------
        possible_actions : ``List[ProductionRule]``
            From the input to ``forward`` for a single batch instance.  Each entry is read as
            ``action[0]`` (the production string) and ``action[2]`` (the tensor that is
            concatenated, cast to ``long`` and embedded).

        Returns
        -------
        ``GrammarStatelet``
            Starts from ``START_SYMBOL``.  For each non-terminal that has productions, its
            ``"global"`` entry is ``(input embeddings, output embeddings, action ids)``.
        """
        # Production string -> index into ``possible_actions``.
        action_map = {action[0]: action_index
                      for action_index, action in enumerate(possible_actions)}

        valid_actions = self._world.get_nonterminal_productions()

        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor,
                                                            torch.Tensor,
                                                            List[int]]]] = {}
        for key, action_strings in valid_actions.items():
            translated_valid_actions[key] = {}
            # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid
            # productions of that non-terminal.  Every production is treated as a global action.
            action_indices = [action_map[action_string]
                              for action_string in action_strings]
            global_actions = [(possible_actions[index][2], index)
                              for index in action_indices]
            if not global_actions:
                # ``zip(*[])`` yields nothing to unpack and ``torch.cat`` rejects an empty list,
                # so a non-terminal without productions keeps an empty entry instead.
                continue

            global_action_tensors, global_action_ids = zip(*global_actions)
            global_action_tensor = torch.cat(global_action_tensors,
                                             dim=0).long()
            global_input_embeddings = self._action_embedder(
                global_action_tensor)
            global_output_embeddings = self._output_action_embedder(
                global_action_tensor)
            translated_valid_actions[key]["global"] = (
                global_input_embeddings,
                global_output_embeddings,
                list(global_action_ids),
            )

        return GrammarStatelet([START_SYMBOL], translated_valid_actions,
                               self._world.is_nonterminal)
 def test_get_valid_actions_adds_lambda_productions(self):
     """The context action for ``'s -> x'`` should be appended to the global actions of ``'s'``."""
     lambda_stacks = {('s', 'x'): ['s']}
     global_actions = {
         's': {'global': (torch.Tensor([1, 1]), torch.Tensor([2, 2]), [1, 2])}
     }
     context_actions = {'s -> x': (torch.Tensor([5]), torch.Tensor([6]), 5)}
     state = GrammarStatelet(['s'], lambda_stacks, global_actions,
                             context_actions, is_nonterminal)

     # We're checking twice to make sure the first call hasn't accidentally modified the state.
     for _ in range(2):
         actions = state.get_valid_actions()
         input_part, output_part, action_ids = actions['global'][0], actions['global'][1], actions['global'][2]
         assert_almost_equal(input_part.cpu().numpy(), [1, 1, 5])
         assert_almost_equal(output_part.cpu().numpy(), [2, 2, 6])
         assert action_ids == [1, 2, 5]
    def _create_grammar_state(self, possible_actions: List[ProductionRule]) -> GrammarStatelet:
        """
        Build the ``GrammarStatelet`` used for decoding, together with the embedded
        ``valid_actions`` mapping it carries.

        Rules whose ``rule`` is the empty string are skipped.  Every remaining rule must be a
        global rule; it is grouped under its ``nonterminal`` and its ``rule_id`` is embedded with
        both the input and the output action embedder.

        Parameters
        ----------
        possible_actions : ``List[ProductionRule]``
            From the input to ``forward`` for a single batch instance.

        Returns
        -------
        ``GrammarStatelet``
            Starts from ``'statement'`` and is built with ``reverse_productions=True``.

        Raises
        ------
        ValueError
            If a non-empty rule is not a global rule.
        """
        device = util.get_device_of(self._action_embedder.weight)

        # Group (rule, index into ``possible_actions``) pairs by the rule's nonterminal.
        rules_by_nonterminal: Dict[str, List[Tuple[ProductionRule, int]]] = defaultdict(list)
        for position, rule in enumerate(possible_actions):
            if rule.rule == "":
                continue
            if not rule.is_global_rule:
                raise ValueError("The sql parser doesn't support non-global actions yet.")
            rules_by_nonterminal[rule.nonterminal].append((rule, position))

        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {}
        for nonterminal, indexed_rules in rules_by_nonterminal.items():
            translated_valid_actions[nonterminal] = {}
            # All rules that reach this point are global, so there is only a 'global' group.
            global_actions = [(rule.rule_id, position) for rule, position in indexed_rules]
            if not global_actions:
                continue

            rule_id_tensors, action_ids = zip(*global_actions)
            rule_ids = torch.cat(rule_id_tensors, dim=0).long()
            if device >= 0:
                rule_ids = rule_ids.to(device)

            translated_valid_actions[nonterminal]['global'] = (
                self._action_embedder(rule_ids),
                self._output_action_embedder(rule_ids),
                list(action_ids),
            )

        return GrammarStatelet(['statement'],
                               translated_valid_actions,
                               self.is_nonterminal,
                               reverse_productions=True)
Ejemplo n.º 11
0
 def test_get_valid_actions_uses_top_of_stack(self):
     """``get_valid_actions`` should return the entry stored for the symbol on the stack."""
     # Opaque sentinels, so the assertions can only pass if the exact stored value comes back.
     s_actions = object()
     t_actions = object()
     e_actions = object()

     def valid_actions_for(symbol, actions_by_symbol):
         # Build a statelet whose stack holds only ``symbol`` and query it.
         statelet = GrammarStatelet([symbol], actions_by_symbol, is_nonterminal)
         return statelet.get_valid_actions()

     assert valid_actions_for('s', {'s': s_actions, 't': t_actions}) == s_actions
     assert valid_actions_for('t', {'s': s_actions, 't': t_actions}) == t_actions
     assert valid_actions_for(
         'e', {'s': s_actions, 't': t_actions, 'e': e_actions}) == e_actions
Ejemplo n.º 12
0
 def test_get_valid_actions_uses_top_of_stack(self):
     """``get_valid_actions`` should return the entry stored for the symbol on the stack."""
     # Opaque sentinels, so the assertions can only pass if the exact stored value comes back.
     s_actions, t_actions, e_actions = object(), object(), object()

     state = GrammarStatelet(['s'], dict(s=s_actions, t=t_actions), is_nonterminal)
     returned = state.get_valid_actions()
     assert returned == s_actions

     state = GrammarStatelet(['t'], dict(s=s_actions, t=t_actions), is_nonterminal)
     returned = state.get_valid_actions()
     assert returned == t_actions

     state = GrammarStatelet(['e'], dict(s=s_actions, t=t_actions, e=e_actions), is_nonterminal)
     returned = state.get_valid_actions()
     assert returned == e_actions
    def test_take_action_gives_correct_next_states_with_non_lambda_productions(
            self):
        """``take_action`` should yield a statelet equal, attribute for attribute, to the expected one."""
        # state.take_action() doesn't read or change these objects, it just passes them through, so
        # we'll use some sentinels to be sure of that.
        valid_actions = object()
        context_actions = object()

        def statelet_with_stack(stack):
            return GrammarStatelet(stack, {}, valid_actions, context_actions,
                                   is_nonterminal)

        # (starting stack, production applied, stack of the expected resulting statelet)
        transitions = [
            (['s'], 's -> [t, r]', ['r', 't']),
            (['r', 't'], 't -> identity', ['r']),
        ]
        for starting_stack, production, resulting_stack in transitions:
            next_state = statelet_with_stack(starting_stack).take_action(production)
            expected_next_state = statelet_with_stack(resulting_stack)
            assert next_state.__dict__ == expected_next_state.__dict__
Ejemplo n.º 14
0
 def test_take_action_crashes_with_mismatched_types(self):
     """Taking an action for ``t`` while ``s`` is on the stack should raise ``AssertionError``."""
     # Build the statelet outside the ``raises`` block so that only ``take_action`` can satisfy
     # the expectation; an AssertionError during construction would otherwise pass the test.
     state = GrammarStatelet(['s'], {}, is_nonterminal)
     with pytest.raises(AssertionError):
         state.take_action('t -> identity')
Ejemplo n.º 15
0
    def _create_grammar_state(
            self, world: SpiderWorld, possible_actions: List[ProductionRule],
            linking_scores: torch.Tensor,
            linked_actions_linking_scores: torch.Tensor,
            entity_types: torch.Tensor,
            entity_graph_encoding: torch.Tensor) -> GrammarStatelet:
        """
        Build the ``GrammarStatelet`` used for decoding, splitting every non-terminal's
        productions into embedded ``'global'`` actions and entity-based ``'linked'`` actions.

        Parameters
        ----------
        world : ``SpiderWorld``
            Supplies ``valid_actions`` (non-terminal -> production strings) and
            ``entities_names``.
        possible_actions : ``List[ProductionRule]``
            Each entry is read as ``[0]`` (production string), ``[1]`` (truthy for a global
            rule) and ``[2]`` (the tensor that is embedded for a global rule).
        linking_scores : ``torch.Tensor``
            Indexed by entity id along its first dimension.
        linked_actions_linking_scores : ``torch.Tensor``
            Indexed by entity id along its first dimension.  May be ``None`` unless
            ``self._self_attend`` is set.
        entity_types : ``torch.Tensor``
            Indexed by entity id; used when ``self._decoder_use_graph_entities`` is falsy.
        entity_graph_encoding : ``torch.Tensor``
            Rows are selected by entity id when ``self._decoder_use_graph_entities`` is truthy.

        Returns
        -------
        ``GrammarStatelet``
            Starts from ``'statement'``.

        Raises
        ------
        ValueError
            If ``self._self_attend`` is set, linked actions exist, and
            ``linked_actions_linking_scores`` is ``None``.
        """
        # Production string -> index into ``possible_actions``.
        action_map = {action[0]: action_index
                      for action_index, action in enumerate(possible_actions)}

        valid_actions = world.valid_actions
        # Entity name -> position in ``world.entities_names``.
        entity_map = {entity: entity_index
                      for entity_index, entity in enumerate(world.entities_names)}

        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor,
                                                            torch.Tensor,
                                                            List[int]]]] = {}
        for key, action_strings in valid_actions.items():
            translated_valid_actions[key] = {}
            # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid
            # productions of that non-terminal.  We'll first split those productions by global vs.
            # linked action.
            action_indices = [
                action_map[action_string] for action_string in action_strings
            ]
            global_actions = []
            linked_actions = []
            for action_index in action_indices:
                production_rule_array = possible_actions[action_index]
                if production_rule_array[1]:
                    global_actions.append(
                        (production_rule_array[2], action_index))
                else:
                    linked_actions.append(
                        (production_rule_array[0], action_index))

            if global_actions:
                global_action_tensors, global_action_ids = zip(*global_actions)
                global_action_tensor = torch.cat(
                    global_action_tensors,
                    dim=0).to(global_action_tensors[0].device).long()
                global_input_embeddings = self._action_embedder(
                    global_action_tensor)
                global_output_embeddings = self._output_action_embedder(
                    global_action_tensor)
                translated_valid_actions[key]['global'] = (
                    global_input_embeddings, global_output_embeddings,
                    list(global_action_ids))

            if linked_actions:
                linked_rules, linked_action_ids = zip(*linked_actions)
                # The entity name is the right-hand side of the rule with brackets/quotes removed.
                linked_entities = [
                    rule.split(' -> ')[1].strip('[]\"')
                    for rule in linked_rules
                ]
                entity_ids = [entity_map[entity] for entity in linked_entities]

                entity_linking_scores = linking_scores[entity_ids]

                if not self._decoder_use_graph_entities:
                    entity_type_tensor = entity_types[entity_ids]
                    entity_type_embeddings = (
                        self._entity_type_decoder_embedding(
                            entity_type_tensor).to(
                                entity_types.device).float())
                else:
                    entity_type_embeddings = entity_graph_encoding.index_select(
                        dim=0,
                        index=torch.tensor(
                            entity_ids, device=entity_graph_encoding.device))

                if self._self_attend:
                    # Fail with a clear message instead of an UnboundLocalError when the scores
                    # that the self-attending variant needs were not supplied.
                    if linked_actions_linking_scores is None:
                        raise ValueError(
                            "linked_actions_linking_scores must be provided when "
                            "self-attention over linked actions is enabled.")
                    entity_action_linking_scores = linked_actions_linking_scores[
                        entity_ids]
                    translated_valid_actions[key]['linked'] = (
                        entity_linking_scores, entity_type_embeddings,
                        list(linked_action_ids), entity_action_linking_scores)
                else:
                    translated_valid_actions[key]['linked'] = (
                        entity_linking_scores, entity_type_embeddings,
                        list(linked_action_ids))

        return GrammarStatelet(['statement'], translated_valid_actions,
                               self.is_nonterminal)
Ejemplo n.º 16
0
    def setUp(self):
        """Create the transition function and a three-element ``GrammarBasedState`` fixture."""
        super().setUp()
        self.decoder_step = BasicTransitionFunction(
            encoder_output_dim=2,
            action_embedding_dim=2,
            input_attention=Attention.by_name('dot_product')(),
            num_start_types=3,
            add_action_bias=False)

        batch_indices = [0, 1, 0]
        group_size = len(batch_indices)
        action_history = [[1], [3, 4], []]
        score = [torch.FloatTensor([value]) for value in (.1, 1.1, 2.2)]

        def index_rows():
            # A fresh tensor whose row ``k`` is ``[k, k]``, one row per state in the group.
            return torch.FloatTensor([[k, k] for k in range(group_size)])

        hidden_state = index_rows()
        memory_cell = index_rows()
        previous_action_embedding = index_rows()
        attended_question = index_rows()

        # This maps non-terminals to valid actions, where the valid actions are grouped by _type_.
        # We have "global" actions, which are from the global grammar, and "linked" actions, which
        # are instance-specific and are generated based on question attention.  Each action type
        # has a tuple which is (input representation, output representation, action ids).
        e_actions = {
            'global': (torch.FloatTensor([[0, 0], [-1, -1], [-2, -2]]),
                       torch.FloatTensor([[-1, -1], [-2, -2], [-3, -3]]),
                       [0, 1, 2]),
            'linked': (torch.FloatTensor([[.1, .2, .3], [.4, .5, .6]]),
                       torch.FloatTensor([[3, 3], [4, 4]]),
                       [3, 4]),
        }
        d_actions = {
            'global': (torch.FloatTensor([[0, 0]]),
                       torch.FloatTensor([[-1, -1]]),
                       [0]),
            'linked': (torch.FloatTensor([[-.1, -.2, -.3],
                                          [-.4, -.5, -.6],
                                          [-.7, -.8, -.9]]),
                       torch.FloatTensor([[5, 5], [6, 6], [7, 7]]),
                       [1, 2, 3]),
        }
        valid_actions = {'e': e_actions, 'd': d_actions}

        # One statelet per group element, with 'e', 'd', 'e' on their respective stacks.
        grammar_state = [
            GrammarStatelet([nonterminal], {}, valid_actions, {}, is_nonterminal)
            for nonterminal in ['e', 'd', 'e']
        ]

        self.encoder_outputs = torch.FloatTensor([
            [[1, 2], [3, 4], [5, 6]],
            [[10, 11], [12, 13], [14, 15]],
        ])
        self.encoder_output_mask = torch.FloatTensor([[1, 1, 1], [1, 1, 0]])
        self.possible_actions = [
            [('e -> f', False, None),
             ('e -> g', True, None),
             ('e -> h', True, None),
             ('e -> i', True, None),
             ('e -> j', True, None)],
            [('d -> q', True, None),
             ('d -> g', True, None),
             ('d -> h', True, None),
             ('d -> i', True, None)],
        ]

        rnn_state = [
            RnnStatelet(hidden_state[k], memory_cell[k],
                        previous_action_embedding[k], attended_question[k],
                        self.encoder_outputs, self.encoder_output_mask)
            for k in range(group_size)
        ]
        self.state = GrammarBasedState(batch_indices=batch_indices,
                                       action_history=action_history,
                                       score=score,
                                       rnn_state=rnn_state,
                                       grammar_state=grammar_state,
                                       possible_actions=self.possible_actions)
Ejemplo n.º 17
0
    def _create_grammar_statelet(self,
                                 language: DropLanguage,
                                 possible_actions: List[ProductionRule]) -> Tuple[GrammarStatelet, Dict[str, int],
                                                                                  List[str]]:
        """ Make the grammar state for a particular instance in the batch from its global actions.

        Only global rules are supported here; a non-global rule raises ``NotImplementedError``.

        Parameters:
        ------------
        language: `DropLanguage` Supplies the productions of every non-terminal through
            ``get_nonterminal_productions()`` and the ``is_nonterminal`` predicate.
        possible_actions: All possible actions. Each one is read as ``[0]`` (rule string),
            ``[1]`` (is_global_rule) and ``[2]`` (rule_id tensor).

        Returns:
        ------------
        A tuple of (the ``GrammarStatelet`` starting from ``START_SYMBOL``, a dict from rule
        string to its index in ``possible_actions``, the list of rule strings in index order).
        """
        # ProductionRule: (rule, is_global_rule, rule_id, nonterminal)
        actionidx2actionstr: List[str] = [action[0] for action in possible_actions]
        action2actionidx = {rule_string: position
                            for position, rule_string in enumerate(actionidx2actionstr)}

        valid_actions = language.get_nonterminal_productions()
        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]] = {}

        for nonterminal, rule_strings in valid_actions.items():
            translated_valid_actions[nonterminal] = {}
            # Resolve every production of this non-terminal to its index before inspecting any.
            rule_positions = [action2actionidx[rule_string] for rule_string in rule_strings]

            # Pairs of (rule_vocab_id_tensor, action_index) for the global rules.
            global_actions = []
            for position in rule_positions:
                production_rule = possible_actions[position]
                if not production_rule[1]:
                    raise NotImplementedError
                global_actions.append((production_rule[2], position))

            if not global_actions:
                continue

            # Embed the global actions.
            rule_id_tensors, action_ids = zip(*global_actions)
            # TODO(nitish): Figure out if need action_bias and separate input/output action embeddings
            action_embeddings = self._action_embedder(torch.cat(rule_id_tensors, dim=0))
            # The same embedding fills both the input and the output slot.
            translated_valid_actions[nonterminal]['global'] = (action_embeddings,
                                                               action_embeddings,
                                                               list(action_ids))

        statelet = GrammarStatelet([START_SYMBOL],
                                   translated_valid_actions,
                                   language.is_nonterminal)
        return (statelet, action2actionidx, actionidx2actionstr)
Ejemplo n.º 18
0
    def _create_grammar_state(self, world: WikiTablesLanguage,
                              possible_actions: List[ProductionRuleArray],
                              linking_scores: torch.Tensor,
                              entity_types: torch.Tensor) -> GrammarStatelet:
        """
        This method creates the GrammarStatelet object that's used for decoding.  Part of
        creating that is creating the `valid_actions` dictionary, which contains embedded
        representations of all of the valid actions.  So, we create that here as well.

        The way we represent the valid expansions is a little complicated: we use a
        dictionary of `action types`, where the key is the action type (like "global", "linked", or
        whatever your model is expecting), and the value is a tuple representing all actions of
        that type.  The tuple is (input tensor, output tensor, action id).  The input tensor has
        the representation that is used when `selecting` actions, for all actions of this type.
        The output tensor has the representation that is used when feeding the action to the next
        step of the decoder (this could just be the same as the input tensor).  The action ids are
        a list of indices into the main action list for each batch instance.

        The inputs to this method are for a `single instance in the batch`; none of the tensors we
        create here are batched.  We grab the global action ids from the input
        ``ProductionRuleArrays``, and we use those to embed the valid actions for every
        non-terminal type.  We use the input ``linking_scores`` for non-global actions.

        Parameters
        ----------
        world : ``WikiTablesLanguage``
            From the input to ``forward`` for a single batch instance.
        possible_actions : ``List[ProductionRuleArray]``
            From the input to ``forward`` for a single batch instance.
        linking_scores : ``torch.Tensor``
            Assumed to have shape ``(num_entities, num_question_tokens)`` (i.e., there is no batch
            dimension).
        entity_types : ``torch.Tensor``
            Assumed to have shape ``(num_entities,)`` (i.e., there is no batch dimension).
        """
        # TODO(mattg): Move the "valid_actions" construction to another method.
        action_map = {}
        for action_index, action in enumerate(possible_actions):
            action_string = action[0]
            action_map[action_string] = action_index
        entity_map = {}
        for entity_index, entity in enumerate(world.table_graph.entities):
            entity_map[entity] = entity_index

        valid_actions = world.get_nonterminal_productions()
        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor,
                                                            torch.Tensor,
                                                            List[int]]]] = {}
        for key, action_strings in valid_actions.items():
            translated_valid_actions[key] = {}
            # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid
            # productions of that non-terminal.  We'll first split those productions by global vs.
            # linked action.
            action_indices = [
                action_map[action_string] for action_string in action_strings
            ]
            production_rule_arrays = [(possible_actions[index], index)
                                      for index in action_indices]
            global_actions = []
            linked_actions = []
            for production_rule_array, action_index in production_rule_arrays:
                if production_rule_array[1]:
                    global_actions.append(
                        (production_rule_array[2], action_index))
                else:
                    linked_actions.append(
                        (production_rule_array[0], action_index))

            # Then we get the embedded representations of the global actions if any.
            if global_actions:
                global_action_tensors, global_action_ids = zip(*global_actions)
                global_action_tensor = torch.cat(global_action_tensors, dim=0)
                global_input_embeddings = self._action_embedder(
                    global_action_tensor)
                if self._add_action_bias:
                    global_action_biases = self._action_biases(
                        global_action_tensor)
                    global_input_embeddings = torch.cat(
                        [global_input_embeddings, global_action_biases],
                        dim=-1)
                global_output_embeddings = self._output_action_embedder(
                    global_action_tensor)
                translated_valid_actions[key]['global'] = (
                    global_input_embeddings, global_output_embeddings,
                    list(global_action_ids))

            # Then the representations of the linked actions.
            if linked_actions:
                linked_rules, linked_action_ids = zip(*linked_actions)
                entities = [rule.split(' -> ')[1] for rule in linked_rules]
                entity_ids = [entity_map[entity] for entity in entities]
                # (num_linked_actions, num_question_tokens)
                entity_linking_scores = linking_scores[entity_ids]
                # (num_linked_actions,)
                entity_type_tensor = entity_types[entity_ids]
                # (num_linked_actions, entity_type_embedding_dim)
                entity_type_embeddings = self._entity_type_decoder_embedding(
                    entity_type_tensor)
                translated_valid_actions[key]['linked'] = (
                    entity_linking_scores, entity_type_embeddings,
                    list(linked_action_ids))
        return GrammarStatelet([START_SYMBOL], translated_valid_actions,
                               world.is_nonterminal)
Ejemplo n.º 19
0
    def _create_grammar_state(self, world: WikiTablesWorld,
                              possible_actions: List[ProductionRuleArray],
                              linking_scores: torch.Tensor,
                              entity_types: torch.Tensor) -> GrammarStatelet:
        """
        This method creates the GrammarStatelet object that's used for decoding.  Part of creating
        that is creating the `valid_actions` dictionary, which contains embedded representations of
        all of the valid actions.  So, we create that here as well.

        The inputs to this method are for a `single instance in the batch`; none of the tensors we
        create here are batched.  We grab the global action ids from the input
        ``ProductionRuleArrays``, and we use those to embed the valid actions for every
        non-terminal type.  We use the input ``linking_scores`` for non-global actions.

        For every non-terminal, the translated dictionary holds at most two entries, ``'global'``
        and ``'linked'``.  Each is a tuple of (input representation, output representation,
        action ids), and an entry is only present when the non-terminal has at least one
        production of that kind.

        Parameters
        ----------
        world : ``WikiTablesWorld``
            From the input to ``forward`` for a single batch instance.
        possible_actions : ``List[ProductionRuleArray]``
            From the input to ``forward`` for a single batch instance.
        linking_scores : ``torch.Tensor``
            Assumed to have shape ``(num_entities, num_question_tokens)`` (i.e., there is no batch
            dimension).
        entity_types : ``torch.Tensor``
            Assumed to have shape ``(num_entities,)`` (i.e., there is no batch dimension).

        Returns
        -------
        ``GrammarStatelet``
            A statelet whose non-terminal stack holds only ``START_SYMBOL``, carrying the
            translated valid actions and the context-specific (``" -> x"``) actions.
        """
        # Production rule string -> index into ``possible_actions``.
        action_map = {action[0]: action_index
                      for action_index, action in enumerate(possible_actions)}
        # Entity name -> row index into ``linking_scores`` / ``entity_types``.
        entity_map = {entity: entity_index
                      for entity_index, entity in enumerate(world.table_graph.entities)}

        valid_actions = world.get_valid_actions()
        translated_valid_actions: Dict[str, Dict[str, Tuple[torch.Tensor,
                                                            torch.Tensor,
                                                            List[int]]]] = {}
        for key, action_strings in valid_actions.items():
            translated_valid_actions[key] = {}
            # `key` here is a non-terminal from the grammar, and `action_strings` are all the valid
            # productions of that non-terminal.  We'll first split those productions by global vs.
            # linked action.
            action_indices = [
                action_map[action_string] for action_string in action_strings
            ]
            global_actions = []
            linked_actions = []
            for action_index in action_indices:
                production_rule_array = possible_actions[action_index]
                if production_rule_array[1]:
                    # Global rule: keep its tensor so it can be embedded below.
                    global_actions.append(
                        (production_rule_array[2], action_index))
                else:
                    # Linked rule: keep the rule string so we can look up the entity.
                    linked_actions.append(
                        (production_rule_array[0], action_index))

            # Then we get the embedded representations of the global actions, if any.  A
            # non-terminal may have only linked productions, and unpacking ``zip(*[])`` into two
            # names would raise a ``ValueError``, so this has to be guarded.
            if global_actions:
                global_action_tensors, global_action_ids = zip(*global_actions)
                global_action_tensor = torch.cat(global_action_tensors, dim=0)
                global_input_embeddings = self._action_embedder(
                    global_action_tensor)
                if self._add_action_bias:
                    # The bias is appended as extra trailing feature(s) on the input embedding.
                    global_action_biases = self._action_biases(
                        global_action_tensor)
                    global_input_embeddings = torch.cat(
                        [global_input_embeddings, global_action_biases],
                        dim=-1)
                global_output_embeddings = self._output_action_embedder(
                    global_action_tensor)
                translated_valid_actions[key]['global'] = (
                    global_input_embeddings, global_output_embeddings,
                    list(global_action_ids))

            # Then the representations of the linked actions.
            if linked_actions:
                linked_rules, linked_action_ids = zip(*linked_actions)
                # The right-hand side of a linked rule is the entity name.
                entities = [rule.split(' -> ')[1] for rule in linked_rules]
                entity_ids = [entity_map[entity] for entity in entities]
                # (num_linked_actions, num_question_tokens)
                entity_linking_scores = linking_scores[entity_ids]
                # (num_linked_actions,)
                entity_type_tensor = entity_types[entity_ids]
                # (num_linked_actions, entity_type_embedding_dim)
                entity_type_embeddings = self._entity_type_decoder_embedding(
                    entity_type_tensor)
                translated_valid_actions[key]['linked'] = (
                    entity_linking_scores, entity_type_embeddings,
                    list(linked_action_ids))

        # Lastly, we need to also create embedded representations of context-specific actions.  In
        # this case, those are only variable productions, like "r -> x".  Note that our language
        # only permits one lambda at a time, so we don't need to worry about how nested lambdas
        # might impact this.
        context_actions = {}
        for action_id, action in enumerate(possible_actions):
            if action[0].endswith(" -> x"):
                input_embedding = self._action_embedder(action[2])
                if self._add_action_bias:
                    input_bias = self._action_biases(action[2])
                    input_embedding = torch.cat([input_embedding, input_bias],
                                                dim=-1)
                output_embedding = self._output_action_embedder(action[2])
                context_actions[action[0]] = (input_embedding,
                                              output_embedding, action_id)

        return GrammarStatelet([START_SYMBOL], {}, translated_valid_actions,
                               context_actions,
                               type_declaration.is_nonterminal)
Ejemplo n.º 20
0
    def setUp(self):
        """
        Build the fixtures shared by the tests: a ``BasicTransitionFunction`` stored on
        ``self.decoder_step`` and a ``GrammarBasedState`` with three group elements stored on
        ``self.state``.  The encoder outputs, encoder mask and possible actions are also kept on
        ``self`` so that tests can refer to them.
        """
        super().setUp()
        self.decoder_step = BasicTransitionFunction(
            encoder_output_dim=2,
            action_embedding_dim=2,
            input_attention=Attention.by_name("dot_product")(),
            add_action_bias=False,
        )

        batch_indices = [0, 1, 0]
        group_size = len(batch_indices)
        action_history = [[1], [3, 4], []]
        score = [torch.FloatTensor([value]) for value in [0.1, 1.1, 2.2]]

        def index_rows():
            # A fresh (group_size, 2) tensor whose i-th row is [i, i].
            return torch.FloatTensor([[i, i] for i in range(group_size)])

        hidden_state = index_rows()
        memory_cell = index_rows()
        previous_action_embedding = index_rows()
        attended_question = index_rows()

        # This maps non-terminals to valid actions, where the valid actions are grouped by _type_.
        # We have "global" actions, which are from the global grammar, and "linked" actions, which
        # are instance-specific and are generated based on question attention.  Each action type
        # has a tuple which is (input representation, output representation, action ids).
        e_actions = {
            "global": (
                torch.FloatTensor([[0, 0], [-1, -1], [-2, -2]]),
                torch.FloatTensor([[-1, -1], [-2, -2], [-3, -3]]),
                [0, 1, 2],
            ),
            "linked": (
                torch.FloatTensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]),
                torch.FloatTensor([[3, 3], [4, 4]]),
                [3, 4],
            ),
        }
        d_actions = {
            "global": (
                torch.FloatTensor([[0, 0]]),
                torch.FloatTensor([[-1, -1]]),
                [0],
            ),
            "linked": (
                torch.FloatTensor([[-0.1, -0.2, -0.3],
                                   [-0.4, -0.5, -0.6],
                                   [-0.7, -0.8, -0.9]]),
                torch.FloatTensor([[5, 5], [6, 6], [7, 7]]),
                [1, 2, 3],
            ),
        }
        valid_actions = {"e": e_actions, "d": d_actions}

        # One statelet per group element; each starts on its own non-terminal.
        grammar_state = [
            GrammarStatelet([nonterminal], valid_actions, is_nonterminal)
            for nonterminal in ["e", "d", "e"]
        ]

        self.encoder_outputs = torch.FloatTensor([
            [[1, 2], [3, 4], [5, 6]],
            [[10, 11], [12, 13], [14, 15]],
        ])
        self.encoder_output_mask = torch.FloatTensor([[1, 1, 1], [1, 1, 0]])

        first_instance_actions = [
            ("e -> f", False, None),
            ("e -> g", True, None),
            ("e -> h", True, None),
            ("e -> i", True, None),
            ("e -> j", True, None),
        ]
        second_instance_actions = [
            ("d -> q", True, None),
            ("d -> g", True, None),
            ("d -> h", True, None),
            ("d -> i", True, None),
        ]
        self.possible_actions = [first_instance_actions, second_instance_actions]

        # Every group element shares the same encoder outputs and mask.
        rnn_state = [
            RnnStatelet(
                hidden_state[index],
                memory_cell[index],
                previous_action_embedding[index],
                attended_question[index],
                self.encoder_outputs,
                self.encoder_output_mask,
            )
            for index in range(group_size)
        ]
        self.state = GrammarBasedState(
            batch_indices=batch_indices,
            action_history=action_history,
            score=score,
            rnn_state=rnn_state,
            grammar_state=grammar_state,
            possible_actions=self.possible_actions,
        )
    def test_take_action_gives_correct_next_states_with_lambda_productions(
            self):
        """
        Walk a fixed sequence of productions that includes lambda productions and check, after
        every step, that the resulting statelet has the expected non-terminal stack and lambda
        stacks.  Each step starts from the `expected` state of the previous step.
        """
        # state.take_action() doesn't read or change these objects, it just passes them through, so
        # we'll use some sentinels to be sure of that.
        valid_actions = object()
        context_actions = object()

        def make_state(nonterminal_stack, lambda_stacks):
            return GrammarStatelet(nonterminal_stack, lambda_stacks,
                                   valid_actions, context_actions,
                                   is_nonterminal)

        # (action to take, expected non-terminal stack, expected lambda stacks)
        steps = [
            ('<s,d> -> [lambda x, d]',
             ['t', 'd'],
             {('s', 'x'): ['d']}),
            ('d -> [<s,r>, d]',
             ['t', 'd', '<s,r>'],
             {('s', 'x'): ['d', '<s,r>']}),
            ('<s,r> -> [lambda y, r]',
             ['t', 'd', 'r'],
             {('s', 'x'): ['d', 'r'], ('s', 'y'): ['r']}),
            ('r -> identity',
             ['t', 'd'],
             {('s', 'x'): ['d']}),
            ('d -> x',
             ['t'],
             {}),
        ]

        current = make_state(['t', '<s,d>'], {})
        for action, expected_stack, expected_lambda_stacks in steps:
            actual = current.take_action(action)
            expected = make_state(expected_stack, expected_lambda_stacks)
            assert actual.__dict__ == expected.__dict__
            current = expected
Ejemplo n.º 22
0
 def test_is_finished_just_uses_nonterminal_stack(self):
     """Check ``is_finished`` with a non-empty and then an empty non-terminal stack."""
     # One non-terminal is still pending, so the statelet is not finished.
     assert not GrammarStatelet(['s'], {}, is_nonterminal).is_finished()
     # Nothing is left on the stack, so the statelet is finished.
     assert GrammarStatelet([], {}, is_nonterminal).is_finished()
    def _create_grammar_state(
        self,
        world: AtisWorld,
        possible_actions: List[ProductionRule],
        linking_scores: torch.Tensor,
        entity_types: torch.Tensor,
    ) -> GrammarStatelet:
        """
        Build the ``GrammarStatelet`` used for decoding, together with the dictionary of embedded
        valid actions that it carries.

        Everything here is for a `single instance in the batch`; no tensor created in this method
        has a batch dimension.  Global productions are embedded from the tensors stored in the
        input ``ProductionRules``, while non-global (linked) productions are represented with the
        input ``linking_scores`` and an embedding of their entity type.

        Parameters
        ----------
        world : ``AtisWorld``
            From the input to ``forward`` for a single batch instance.
        possible_actions : ``List[ProductionRule]``
            From the input to ``forward`` for a single batch instance.
        linking_scores : ``torch.Tensor``
            Assumed to have shape ``(num_entities, num_utterance_tokens)`` (i.e., there is no batch
            dimension).
        entity_types : ``torch.Tensor``
            Assumed to have shape ``(num_entities,)`` (i.e., there is no batch dimension).
        """
        # Rule string -> position of that rule in ``possible_actions``.
        rule_positions = {
            rule[0]: position for position, rule in enumerate(possible_actions)
        }

        productions_by_nonterminal = world.valid_actions
        # Entity -> row index into ``linking_scores`` / ``entity_types``.
        entity_positions = {
            entity: position for position, entity in enumerate(world.entities)
        }

        target_device = entity_types.device
        embedded_actions: Dict[str, Dict[str, Tuple[torch.Tensor,
                                                    torch.Tensor,
                                                    List[int]]]] = {}
        for nonterminal, rule_strings in productions_by_nonterminal.items():
            embedded_actions[nonterminal] = {}

            # Resolve every production of this non-terminal up front, then sort them into the
            # global and the linked bucket.
            positions = [rule_positions[rule_string] for rule_string in rule_strings]
            global_pairs = []
            linked_pairs = []
            for position in positions:
                rule = possible_actions[position]
                if rule[1]:
                    global_pairs.append((rule[2], position))
                else:
                    linked_pairs.append((rule[0], position))

            if global_pairs:
                rule_tensors = tuple(tensor for tensor, _ in global_pairs)
                global_ids = [position for _, position in global_pairs]
                stacked_rules = torch.cat(rule_tensors, dim=0)
                stacked_rules = stacked_rules.to(target_device).long()
                input_embeddings = self._action_embedder(stacked_rules)
                output_embeddings = self._output_action_embedder(stacked_rules)
                embedded_actions[nonterminal]["global"] = (
                    input_embeddings,
                    output_embeddings,
                    global_ids,
                )

            if linked_pairs:
                # For linked rules the whole rule string is the key into the entity map.
                entity_ids = [entity_positions[rule_string]
                              for rule_string, _ in linked_pairs]
                linked_ids = [position for _, position in linked_pairs]
                selected_scores = linking_scores[entity_ids]
                selected_types = entity_types[entity_ids]
                type_embeddings = self._entity_type_decoder_embedding(selected_types)
                type_embeddings = type_embeddings.to(target_device).float()
                embedded_actions[nonterminal]["linked"] = (
                    selected_scores,
                    type_embeddings,
                    linked_ids,
                )

        return GrammarStatelet(["statement"], embedded_actions,
                               self.is_nonterminal)
Ejemplo n.º 24
0
 def test_take_action_crashes_with_mismatched_types(self):
     """
     Taking an action whose left-hand side (``t``) does not match the non-terminal on the stack
     (``s``) is expected to raise an ``AssertionError``.
     """
     state = GrammarStatelet(['s'], {}, is_nonterminal)
     # Only the call under test goes inside ``pytest.raises``; otherwise an AssertionError
     # raised while constructing the statelet would make this test pass for the wrong reason.
     with pytest.raises(AssertionError):
         state.take_action('t -> identity')
Ejemplo n.º 25
0
 def test_is_finished_just_uses_nonterminal_stack(self):
     """Check ``is_finished`` with a non-empty and then an empty non-terminal stack."""
     # A statelet with one non-terminal left on its stack is not finished.
     state = GrammarStatelet(['s'], {}, is_nonterminal)
     assert not state.is_finished()
     # A statelet with an empty stack is finished.
     state = GrammarStatelet([], {}, is_nonterminal)
     assert state.is_finished()