def _get_initial_rnn_state(self, sentence: Dict[str, torch.LongTensor]):
        embedded_input = self._sentence_embedder(sentence)
        # (batch_size, sentence_length)
        sentence_mask = util.get_text_field_mask(sentence).float()

        batch_size = embedded_input.size(0)

        # (batch_size, sentence_length, encoder_output_dim)
        encoder_outputs = self._encoder(embedded_input, sentence_mask)

        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, sentence_mask, self._encoder.is_bidirectional())
        memory_cell = util.new_variable_with_size(
            encoder_outputs, (batch_size, self._encoder.get_output_dim()), 0)
        attended_sentence = self._decoder_step.attend_on_sentence(
            final_encoder_output, encoder_outputs, sentence_mask)
        encoder_outputs_list = [encoder_outputs[i] for i in range(batch_size)]
        sentence_mask_list = [sentence_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(
                RnnState(final_encoder_output[i], memory_cell[i],
                         self._first_action_embedding, attended_sentence[i],
                         encoder_outputs_list, sentence_mask_list))
        return initial_rnn_state
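
# A minimal, self-contained sketch (hypothetical shapes, not the parser's real data) of
# the pattern above: every per-instance RnnState shares the *same* list objects for the
# encoder outputs and masks, so regrouping states later only needs torch.cat() over the
# per-instance tensors instead of repeated index_select calls.
import torch

batch_size, seq_len, dim = 2, 3, 4
encoder_outputs = torch.randn(batch_size, seq_len, dim)
sentence_mask = torch.ones(batch_size, seq_len)
encoder_outputs_list = [encoder_outputs[i] for i in range(batch_size)]  # shared by all states
sentence_mask_list = [sentence_mask[i] for i in range(batch_size)]
assert encoder_outputs_list[0].shape == (seq_len, dim)
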
    def make_state(group_index: int,
                   action: int,
                   new_score: torch.Tensor,
                   action_embedding: torch.Tensor) -> GrammarBasedDecoderState:
        new_rnn_state = RnnState(hidden_state[group_index],
                                 memory_cell[group_index],
                                 action_embedding,
                                 attended_question[group_index],
                                 state.rnn_state[group_index].encoder_outputs,
                                 state.rnn_state[group_index].encoder_output_mask)
        batch_index = state.batch_indices[group_index]
        for i, log_probs, _, actions in batch_action_probs[batch_index]:
            if i == group_index:
                considered_actions = actions
                probabilities = log_probs.exp().cpu()
                break
        return state.new_state_from_group_index(group_index,
                                                action,
                                                new_score,
                                                new_rnn_state,
                                                considered_actions,
                                                probabilities,
                                                updated_rnn_state['attention_weights'])
    def _compute_new_states(state: WikiTablesDecoderState,
                            log_probs: torch.Tensor,
                            hidden_state: torch.Tensor,
                            memory_cell: torch.Tensor,
                            action_embeddings: torch.Tensor,
                            attended_question: torch.Tensor,
                            attention_weights: torch.Tensor,
                            considered_actions: List[List[int]],
                            allowed_actions: List[Set[int]],
                            max_actions: int = None) -> List[WikiTablesDecoderState]:
        # Each group index here might get accessed multiple times, and doing the slicing operation
        # each time is more expensive than doing it once upfront.  These three lines give about a
        # 10% speedup in training time.  I also tried this with sorted_log_probs and
        # action_embeddings, but those get accessed for _each action_, so doing the splits there
        # didn't help.
        hidden_state = [x.squeeze(0) for x in hidden_state.split(1, 0)]
        memory_cell = [x.squeeze(0) for x in memory_cell.split(1, 0)]
        attended_question = [x.squeeze(0) for x in attended_question.split(1, 0)]

        sorted_log_probs, sorted_actions = log_probs.sort(dim=-1, descending=True)
        if max_actions is not None:
            # We might need a version of `sorted_log_probs` on the CPU later, but only if we need
            # to truncate the best states to `max_actions`.
            sorted_log_probs_cpu = sorted_log_probs.data.cpu().numpy()
        if state.debug_info is not None:
            probs_cpu = log_probs.exp().data.cpu().numpy().tolist()
        sorted_actions = sorted_actions.data.cpu().numpy().tolist()
        best_next_states: Dict[int, List[Tuple[int, int, int]]] = defaultdict(list)
        for group_index, (batch_index, group_actions) in enumerate(zip(state.batch_indices, sorted_actions)):
            for action_index, action in enumerate(group_actions):
                # `action` is currently the index in `log_probs`, not the actual action ID.  To get
                # the action ID, we need to go through `considered_actions`.
                action = considered_actions[group_index][action]
                if action == -1:
                    # This was padding.
                    continue
                if allowed_actions is not None and action not in allowed_actions[group_index]:
                    # This happens when our _decoder trainer_ wants us to only evaluate certain
                    # actions, likely because they are the gold actions in this state.  We just skip
                    # emitting any state that isn't allowed by the trainer, because constructing the
                    # new state can be expensive.
                    continue
                best_next_states[batch_index].append((group_index, action_index, action))
        new_states = []
        for batch_index, best_states in sorted(best_next_states.items()):
            if max_actions is not None:
                # We sorted previously by _group_index_, but we then combined by _batch_index_.  We
                # need to get the top next states for each _batch_ instance, so we sort all of the
                # instance's states again (across group index) by score.  We don't need to do this
                # if `max_actions` is None, because we'll be keeping all of the next states,
                # anyway.
                best_states.sort(key=lambda x: sorted_log_probs_cpu[x[:2]], reverse=True)
                best_states = best_states[:max_actions]
            for group_index, action_index, action in best_states:
                # We'll yield a bunch of states here that all have a `group_size` of 1, so that the
                # learning algorithm can decide how many of these it wants to keep, and it can just
                # regroup them later, as that's a really easy operation.
                batch_index = state.batch_indices[group_index]
                new_action_history = state.action_history[group_index] + [action]
                new_score = sorted_log_probs[group_index, action_index]

                # `action_index` is the index in the _sorted_ tensors, but the action embedding
                # matrix is _not_ sorted, so we need to get back the original, non-sorted action
                # index before we get the action embedding.
                action_embedding_index = sorted_actions[group_index][action_index]
                action_embedding = action_embeddings[group_index, action_embedding_index, :]
                production_rule = state.possible_actions[batch_index][action][0]
                new_grammar_state = state.grammar_state[group_index].take_action(production_rule)
                if state.debug_info is not None:
                    debug_info = {
                            'considered_actions': considered_actions[group_index],
                            'question_attention': attention_weights[group_index],
                            'probabilities': probs_cpu[group_index],
                            }
                    new_debug_info = [state.debug_info[group_index] + [debug_info]]
                else:
                    new_debug_info = None

                new_rnn_state = RnnState(hidden_state[group_index],
                                         memory_cell[group_index],
                                         action_embedding,
                                         attended_question[group_index],
                                         state.rnn_state[group_index].encoder_outputs,
                                         state.rnn_state[group_index].encoder_output_mask)

                new_state = WikiTablesDecoderState(batch_indices=[batch_index],
                                                   action_history=[new_action_history],
                                                   score=[new_score],
                                                   rnn_state=[new_rnn_state],
                                                   grammar_state=[new_grammar_state],
                                                   action_embeddings=state.action_embeddings,
                                                   output_action_embeddings=state.output_action_embeddings,
                                                   action_biases=state.action_biases,
                                                   action_indices=state.action_indices,
                                                   possible_actions=state.possible_actions,
                                                   flattened_linking_scores=state.flattened_linking_scores,
                                                   actions_to_entities=state.actions_to_entities,
                                                   entity_types=state.entity_types,
                                                   debug_info=new_debug_info)
                new_states.append(new_state)
        return new_states
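
# A self-contained sketch (made-up scores) of the "group by batch, re-sort, truncate"
# logic above: candidate states arrive sorted per group, get bucketed by batch instance,
# and only the top max_actions per instance survive.  Indexing a numpy array with the
# (group_index, action_index) tuple prefix is what `sorted_log_probs_cpu[x[:2]]` does.
from collections import defaultdict
import numpy as np

sorted_log_probs_cpu = np.array([[-0.1, -1.0], [-0.5, -2.0]])  # (group, sorted action index)
candidates = [(0, 0, 11), (0, 1, 12), (1, 0, 13), (1, 1, 14)]  # (group, action_index, action)
batch_indices = [0, 0]  # both groups belong to batch instance 0
best_next_states = defaultdict(list)
for group_index, action_index, action in candidates:
    best_next_states[batch_indices[group_index]].append((group_index, action_index, action))
max_actions = 2
for batch_index, best_states in sorted(best_next_states.items()):
    best_states.sort(key=lambda x: sorted_log_probs_cpu[x[:2]], reverse=True)
    assert best_states[:max_actions] == [(0, 0, 11), (1, 0, 13)]
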
    def _get_initial_state_and_scores(self,
                                      question,
                                      table,
                                      world,
                                      actions,
                                      example_lisp_string=None,
                                      add_world_to_initial_state=False,
                                      checklist_states=None):
        u"""
        Does initial preparation and creates an intiial state for both the semantic parsers. Note
        that the checklist state is optional, and the ``WikiTablesMmlParser`` is not expected to
        pass it.
        """
        table_text = table['text']
        # (batch_size, question_length, embedding_dim)
        embedded_question = self._question_embedder(question)
        question_mask = util.get_text_field_mask(question).float()
        # (batch_size, num_entities, num_entity_tokens, embedding_dim)
        embedded_table = self._question_embedder(table_text,
                                                 num_wrapping_dims=1)
        table_mask = util.get_text_field_mask(table_text,
                                              num_wrapping_dims=1).float()

        batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()
        num_question_tokens = embedded_question.size(1)

        # (batch_size, num_entities, embedding_dim)
        encoded_table = self._entity_encoder(embedded_table, table_mask)
        # (batch_size, num_entities, num_neighbors)
        neighbor_indices = self._get_neighbor_indices(world, num_entities,
                                                      encoded_table)

        # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
        # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
        # be added for the mask since that method expects 0 for padding.
        # (batch_size, num_entities, num_neighbors, embedding_dim)
        embedded_neighbors = util.batched_index_select(
            encoded_table, torch.abs(neighbor_indices))

        neighbor_mask = util.get_text_field_mask(
            {
                'ignored': neighbor_indices + 1
            }, num_wrapping_dims=1).float()

        # Encoder initialized to easily obtain a masked average.
        neighbor_encoder = TimeDistributed(
            BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
        # (batch_size, num_entities, embedding_dim)
        embedded_neighbors = neighbor_encoder(embedded_neighbors,
                                              neighbor_mask)

        # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types)
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(
            world, num_entities, encoded_table)

        entity_type_embeddings = self._type_params(entity_types.float())
        projected_neighbor_embeddings = self._neighbor_params(
            embedded_neighbors.float())
        # (batch_size, num_entities, embedding_dim)
        entity_embeddings = torch.tanh(entity_type_embeddings +
                                       projected_neighbor_embeddings)

        # Compute entity and question word similarity.  We tried using cosine distance here, but
        # because this similarity is the main mechanism that the model can use to push apart logit
        # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger
        # output range than [-1, 1].
        question_entity_similarity = torch.bmm(
            embedded_table.view(batch_size, num_entities * num_entity_tokens,
                                self._embedding_dim),
            torch.transpose(embedded_question, 1, 2))

        question_entity_similarity = question_entity_similarity.view(
            batch_size, num_entities, num_entity_tokens, num_question_tokens)

        # (batch_size, num_entities, num_question_tokens)
        question_entity_similarity_max_score, _ = torch.max(
            question_entity_similarity, 2)

        # (batch_size, num_entities, num_question_tokens, num_features)
        linking_features = table['linking']

        linking_scores = question_entity_similarity_max_score

        if self._use_neighbor_similarity_for_linking:
            # The linking score is computed as a linear projection of two terms. The first is the
            # maximum similarity score over the entity's words and the question token. The second
            # is the maximum similarity over the words in the entity's neighbors and the question
            # token.
            #
            # The second term, projected_question_neighbor_similarity, is useful when a column
            # needs to be selected. For example, the question token might have no similarity with
            # the column name, but is similar with the cells in the column.
            #
            # Note that projected_question_neighbor_similarity is intended to capture the same
            # information as the related_column feature.
            #
            # Also note that this block needs to be _before_ the `linking_params` block, because
            # we're overwriting `linking_scores`, not adding to it.

            # (batch_size, num_entities, num_neighbors, num_question_tokens)
            question_neighbor_similarity = util.batched_index_select(
                question_entity_similarity_max_score,
                torch.abs(neighbor_indices))
            # (batch_size, num_entities, num_question_tokens)
            question_neighbor_similarity_max_score, _ = torch.max(
                question_neighbor_similarity, 2)
            projected_question_entity_similarity = self._question_entity_params(
                question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            projected_question_neighbor_similarity = self._question_neighbor_params(
                question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze(
                    -1)
            linking_scores = projected_question_entity_similarity + projected_question_neighbor_similarity

        feature_scores = None
        if self._linking_params is not None:
            feature_scores = self._linking_params(linking_features).squeeze(3)
            linking_scores = linking_scores + feature_scores

        # (batch_size, num_question_tokens, num_entities)
        linking_probabilities = self._get_linking_probabilities(
            world, linking_scores.transpose(1, 2), question_mask,
            entity_type_dict)

        # (batch_size, num_question_tokens, embedding_dim)
        link_embedding = util.weighted_sum(entity_embeddings,
                                           linking_probabilities)
        encoder_input = torch.cat([link_embedding, embedded_question], 2)

        # (batch_size, question_length, encoder_output_dim)
        encoder_outputs = self._dropout(
            self._encoder(encoder_input, question_mask))

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, question_mask, self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size,
                                                self._encoder.get_output_dim())

        initial_score = embedded_question.data.new_zeros(batch_size)

        action_embeddings, output_action_embeddings, action_biases, action_indices = self._embed_actions(
            actions)

        _, num_entities, num_question_tokens = linking_scores.size()
        flattened_linking_scores, actions_to_entities = self._map_entity_productions(
            linking_scores, world, actions)
        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        question_mask_list = [question_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(
                RnnState(final_encoder_output[i], memory_cell[i],
                         self._first_action_embedding,
                         self._first_attended_question, encoder_output_list,
                         question_mask_list))
        initial_grammar_state = [
            self._create_grammar_state(world[i], actions[i])
            for i in range(batch_size)
        ]
        initial_state_world = world if add_world_to_initial_state else None
        initial_state = WikiTablesDecoderState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            action_embeddings=action_embeddings,
            output_action_embeddings=output_action_embeddings,
            action_biases=action_biases,
            action_indices=action_indices,
            possible_actions=actions,
            flattened_linking_scores=flattened_linking_scores,
            actions_to_entities=actions_to_entities,
            entity_types=entity_type_dict,
            world=initial_state_world,
            example_lisp_string=example_lisp_string,
            checklist_state=checklist_states,
            debug_info=None)
        return {
            "initial_state": initial_state,
            "linking_scores": linking_scores,
            "feature_scores": feature_scores,
            "similarity_scores": question_entity_similarity_max_score
        }
    def _compute_new_states(cls,
                            state: NlvrDecoderState,
                            action_logprobs: List[List[Tuple[int, torch.Tensor]]],
                            hidden_state: torch.Tensor,
                            memory_cell: torch.Tensor,
                            action_embeddings: torch.Tensor,
                            attended_sentence: torch.Tensor,
                            considered_actions: List[List[int]],
                            allowed_actions: List[Set[int]] = None,
                            new_checklist_states: List[List[ChecklistState]] = None,
                            max_actions: int = None) -> List[NlvrDecoderState]:
        """
        This method is very similar to ``WikiTablesDecoderStep._compute_new_states``.
        The difference here is that we also keep track of checklists if they are passed to this
        method.
        """
        # batch_index -> group_index, action_index, checklist, score
        states_info: Dict[int, List[Tuple[int, int, torch.Tensor, torch.Tensor]]] = defaultdict(list)
        if new_checklist_states is None:
            # We do not have checklist states, so we make a list of lists of Nones of the
            # appropriate size for the zips below.
            new_checklist_states = [[None for _ in instance_logprobs]
                                    for instance_logprobs in action_logprobs]
        for group_index, instance_info in enumerate(zip(state.batch_indices,
                                                        action_logprobs,
                                                        new_checklist_states)):
            batch_index, instance_logprobs, instance_new_checklist_states = instance_info
            for (action_index, score), checklist_state in zip(instance_logprobs,
                                                              instance_new_checklist_states):
                states_info[batch_index].append((group_index, action_index, checklist_state, score))

        new_states = []
        for batch_index, instance_states_info in states_info.items():
            batch_scores = torch.cat([info[-1] for info in instance_states_info])
            _, sorted_indices = batch_scores.sort(-1, descending=True)
            sorted_states_info = [instance_states_info[i]
                                  for i in sorted_indices.data.cpu().numpy()]
            allowed_states_info = []
            for i, (group_index, action_index, _, _) in enumerate(sorted_states_info):
                action = considered_actions[group_index][action_index]
                if allowed_actions is not None and action not in allowed_actions[group_index]:
                    continue
                allowed_states_info.append(sorted_states_info[i])
            sorted_states_info = allowed_states_info
            if max_actions is not None:
                sorted_states_info = sorted_states_info[:max_actions]
            for group_index, action_index, new_checklist_state, new_score in sorted_states_info:
                # This is the actual index of the action from the original list of actions.
                # We do not have to check whether it is the padding index because ``take_step``
                # already took care of that.
                action = considered_actions[group_index][action_index]
                action_embedding = action_embeddings[group_index, action_index, :]
                new_action_history = state.action_history[group_index] + [action]
                production_rule = state.possible_actions[batch_index][action][0]
                new_grammar_state = state.grammar_state[group_index].take_action(production_rule)
                new_rnn_state = RnnState(hidden_state[group_index],
                                         memory_cell[group_index],
                                         action_embedding,
                                         attended_sentence[group_index],
                                         state.rnn_state[group_index].encoder_outputs,
                                         state.rnn_state[group_index].encoder_output_mask)
                new_state = NlvrDecoderState(
                    batch_indices=[batch_index],
                    action_history=[new_action_history],
                    score=[new_score],
                    rnn_state=[new_rnn_state],
                    grammar_state=[new_grammar_state],
                    action_embeddings=state.action_embeddings,
                    action_indices=state.action_indices,
                    possible_actions=state.possible_actions,
                    worlds=state.worlds,
                    label_strings=state.label_strings,
                    checklist_state=[new_checklist_state])
                new_states.append(new_state)
        return new_states
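
# A small sketch (toy values) of the None-padding above: when no checklist states are
# given, we build a list of lists of Nones matching action_logprobs, so the zip over
# (logprob, checklist) pairs works unchanged.
action_logprobs = [[(0, -0.1), (1, -0.7)], [(2, -0.2)]]  # per group: (action_index, score)
new_checklist_states = [[None for _ in instance_logprobs]
                        for instance_logprobs in action_logprobs]
assert new_checklist_states == [[None, None], [None]]
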
    def setUp(self):
        super().setUp()
        self.decoder_step = BasicTransitionFunction(
            encoder_output_dim=2,
            action_embedding_dim=2,
            input_attention=Attention.by_name('dot_product')(),
            num_start_types=3,
            add_action_bias=False)

        batch_indices = [0, 1, 0]
        action_history = [[1], [3, 4], []]
        score = [torch.FloatTensor([x]) for x in [.1, 1.1, 2.2]]
        hidden_state = torch.FloatTensor([[i, i]
                                          for i in range(len(batch_indices))])
        memory_cell = torch.FloatTensor([[i, i]
                                         for i in range(len(batch_indices))])
        previous_action_embedding = torch.FloatTensor(
            [[i, i] for i in range(len(batch_indices))])
        attended_question = torch.FloatTensor(
            [[i, i] for i in range(len(batch_indices))])
        # This maps non-terminals to valid actions, where the valid actions are grouped by _type_.
        # We have "global" actions, which are from the global grammar, and "linked" actions, which
        # are instance-specific and are generated based on question attention.  Each action type
        # has a tuple which is (input representation, output representation, action ids).
        valid_actions = {
            'e': {
                'global': (torch.FloatTensor([[0, 0], [-1, -1], [-2, -2]]),
                           torch.FloatTensor([[-1, -1], [-2, -2],
                                              [-3, -3]]), [0, 1, 2]),
                'linked': (torch.FloatTensor([[.1, .2, .3], [.4, .5, .6]]),
                           torch.FloatTensor([[3, 3], [4, 4]]), [3, 4])
            },
            'd': {
                'global':
                (torch.FloatTensor([[0, 0]]), torch.FloatTensor([[-1,
                                                                  -1]]), [0]),
                'linked': (torch.FloatTensor([[-.1, -.2, -.3], [-.4, -.5, -.6],
                                              [-.7, -.8, -.9]]),
                           torch.FloatTensor([[5, 5], [6, 6], [7,
                                                               7]]), [1, 2, 3])
            }
        }
        grammar_state = [
            GrammarState([nonterminal], {}, valid_actions, {}, is_nonterminal)
            for _, nonterminal in zip(batch_indices, ['e', 'd', 'e'])
        ]
        self.encoder_outputs = torch.FloatTensor([[[1, 2], [3, 4], [5, 6]],
                                                  [[10, 11], [12, 13], [14, 15]]])
        self.encoder_output_mask = torch.FloatTensor([[1, 1, 1], [1, 1, 0]])
        self.possible_actions = [[('e -> f', False, None), ('e -> g', True, None),
                                  ('e -> h', True, None), ('e -> i', True, None),
                                  ('e -> j', True, None)],
                                 [('d -> q', True, None), ('d -> g', True, None),
                                  ('d -> h', True, None), ('d -> i', True, None)]]

        rnn_state = []
        for i in range(len(batch_indices)):
            rnn_state.append(
                RnnState(hidden_state[i], memory_cell[i],
                         previous_action_embedding[i], attended_question[i],
                         self.encoder_outputs, self.encoder_output_mask))
        self.state = GrammarBasedDecoderState(
            batch_indices=batch_indices,
            action_history=action_history,
            score=score,
            rnn_state=rnn_state,
            grammar_state=grammar_state,
            possible_actions=self.possible_actions)
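
# A self-contained sketch of the grouped valid-actions layout used in the fixture above:
# per nonterminal, actions are bucketed by type ("global" vs. "linked"), and each bucket
# is an (input representation, output representation, action ids) tuple.
import torch

valid_actions = {'e': {'global': (torch.zeros(3, 2), torch.ones(3, 2), [0, 1, 2])}}
input_rep, output_rep, action_ids = valid_actions['e']['global']
assert input_rep.shape == (3, 2) and action_ids == [0, 1, 2]
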
    def setUp(self):
        super().setUp()

        batch_indices = [0, 1, 0]
        action_history = [[1], [3, 4], []]
        score = [Variable(torch.FloatTensor([x])) for x in [.1, 1.1, 2.2]]
        hidden_state = torch.FloatTensor([[i, i]
                                          for i in range(len(batch_indices))])
        memory_cell = torch.FloatTensor([[i, i]
                                         for i in range(len(batch_indices))])
        previous_action_embedding = torch.FloatTensor(
            [[i, i] for i in range(len(batch_indices))])
        attended_question = torch.FloatTensor(
            [[i, i] for i in range(len(batch_indices))])
        grammar_state = [
            GrammarState(['e'], {}, {}, {}, is_nonterminal)
            for _ in batch_indices
        ]
        self.encoder_outputs = torch.FloatTensor([[1, 2], [3, 4], [5, 6]])
        self.encoder_output_mask = Variable(
            torch.FloatTensor([[1, 1], [1, 0], [1, 1]]))
        self.action_embeddings = torch.FloatTensor([[0, 0], [1, 1], [2, 2],
                                                    [3, 3], [4, 4], [5, 5]])
        self.action_indices = {
            (0, 0): 1,
            (0, 1): 0,
            (0, 2): 2,
            (0, 3): 3,
            (0, 4): 5,
            (1, 0): 4,
            (1, 1): 0,
            (1, 2): 2,
            (1, 3): 3,
        }
        self.possible_actions = [[('e -> f', False, None), ('e -> g', True, None),
                                  ('e -> h', True, None), ('e -> i', True, None),
                                  ('e -> j', True, None)],
                                 [('e -> q', True, None), ('e -> g', True, None),
                                  ('e -> h', True, None), ('e -> i', True, None)]]

        # (batch_size, num_entities, num_question_tokens) = (2, 5, 3)
        linking_scores = Variable(
            torch.Tensor([[[.1, .2, .3], [.4, .5, .6], [.7, .8, .9],
                           [1.0, 1.1, 1.2], [1.3, 1.4, 1.5]],
                          [[-.1, -.2, -.3], [-.4, -.5, -.6], [-.7, -.8, -.9],
                           [-1.0, -1.1, -1.2], [-1.3, -1.4, -1.5]]]))
        flattened_linking_scores = linking_scores.view(2 * 5, 3)

        # Maps (batch_index, action_index) to indices into the flattened linking score tensor,
        # which has shape (batch_size * num_entities, num_question_tokens).
        actions_to_entities = {
            (0, 0): 0,
            (0, 1): 1,
            (0, 2): 2,
            (0, 6): 3,
            (1, 3): 6,
            (1, 4): 7,
            (1, 5): 8,
        }
        entity_types = {
            0: 0,
            1: 2,
            2: 1,
            3: 0,
            4: 0,
            5: 1,
            6: 0,
            7: 1,
            8: 2,
        }
        rnn_state = []
        for i in range(len(batch_indices)):
            rnn_state.append(
                RnnState(hidden_state[i], memory_cell[i],
                         previous_action_embedding[i], attended_question[i],
                         self.encoder_outputs, self.encoder_output_mask))
        self.state = WikiTablesDecoderState(
            batch_indices=batch_indices,
            action_history=action_history,
            score=score,
            rnn_state=rnn_state,
            grammar_state=grammar_state,
            action_embeddings=self.action_embeddings,
            action_indices=self.action_indices,
            possible_actions=self.possible_actions,
            flattened_linking_scores=flattened_linking_scores,
            actions_to_entities=actions_to_entities,
            entity_types=entity_types)
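
# A toy sketch of the flattened linking-score lookup set up above: actions_to_entities
# maps (batch_index, action_index) into rows of a tensor that was reshaped from
# (batch_size, num_entities, num_question_tokens) to
# (batch_size * num_entities, num_question_tokens).
import torch

batch_size, num_entities, num_question_tokens = 2, 5, 3
linking_scores = torch.arange(30, dtype=torch.float).view(batch_size, num_entities, num_question_tokens)
flattened = linking_scores.view(batch_size * num_entities, num_question_tokens)
actions_to_entities = {(1, 3): 6}  # batch 1, action 3 -> flattened row 6 (entity 1 of batch 1)
assert torch.equal(flattened[actions_to_entities[(1, 3)]], linking_scores[1, 1])
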
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                table: Dict[str, torch.LongTensor],
                world: List[WikiTablesWorld],
                actions: List[List[ProductionRuleArray]],
                example_lisp_string: List[str] = None,
                target_action_sequences: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        # pylint: disable=unused-argument
        """
        In this method we encode the table entities, link them to words in the question, then
        encode the question. Then we set up the initial state for the decoder, and pass that
        state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference,
        if we're not.

        Parameters
        ----------
        question : ``Dict[str, torch.LongTensor]``
            The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
            be passed through a ``TextFieldEmbedder`` and then through an encoder.
        table : ``Dict[str, torch.LongTensor]``
            The output of ``KnowledgeGraphField.as_array()`` applied on the table
            ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
            entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
            get embeddings for each entity.
        world : ``List[WikiTablesWorld]``
            We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
            how ``MetadataField`` works, this gets passed to us as a ``List[WikiTablesWorld]``.
        actions : ``List[List[ProductionRuleArray]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRuleArray`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        example_lisp_string : ``List[str]``, optional (default=None)
            The example (lisp-formatted) string corresponding to the given input.  This comes
            directly from the ``.examples`` file provided with the dataset.  We pass this to SEMPRE
            when evaluating denotation accuracy; it is otherwise unused.
        target_action_sequences : ``torch.LongTensor``, optional (default=None)
            A list of possibly valid action sequences, where each action is an index into the list
            of possible actions.  This tensor has shape ``(batch_size, num_action_sequences,
            sequence_length)``.
        """

        table_text = table['text']

        # (batch_size, question_length, embedding_dim)
        embedded_question = self._question_embedder(question)
        question_mask = util.get_text_field_mask(question).float()
        # (batch_size, num_entities, num_entity_tokens, embedding_dim)
        embedded_table = self._question_embedder(table_text, num_wrapping_dims=1)
        table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1).float()

        batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()
        num_question_tokens = embedded_question.size(1)

        # (batch_size, num_entities, embedding_dim)
        encoded_table = self._entity_encoder(embedded_table, table_mask)
        # (batch_size, num_entities, num_neighbors)
        neighbor_indices = self._get_neighbor_indices(world, num_entities, encoded_table)

        # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
        # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
        # be added for the mask since that method expects 0 for padding.
        # (batch_size, num_entities, num_neighbors, embedding_dim)
        embedded_neighbors = util.batched_index_select(encoded_table, torch.abs(neighbor_indices))

        neighbor_mask = util.get_text_field_mask({'ignored': neighbor_indices + 1},
                                                 num_wrapping_dims=1).float()

        # Encoder initialized to easily obtain a masked average.
        neighbor_encoder = TimeDistributed(BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
        # (batch_size, num_entities, embedding_dim)
        embedded_neighbors = neighbor_encoder(embedded_neighbors, neighbor_mask)

        # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types)
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(world, num_entities, encoded_table)

        entity_type_embeddings = self._type_params(entity_types.float())
        projected_neighbor_embeddings = self._neighbor_params(embedded_neighbors.float())
        # (batch_size, num_entities, embedding_dim)
        entity_embeddings = torch.tanh(entity_type_embeddings + projected_neighbor_embeddings)

        # Compute entity and question word cosine similarity.  We need to add a small value to
        # the table norm since there are padding values which would cause a divide by zero.
        embedded_table = embedded_table / (embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
        embedded_question = embedded_question / (embedded_question.norm(dim=-1, keepdim=True) + 1e-13)
        question_entity_similarity = torch.bmm(embedded_table.view(batch_size,
                                                                   num_entities * num_entity_tokens,
                                                                   self._embedding_dim),
                                               torch.transpose(embedded_question, 1, 2))

        question_entity_similarity = question_entity_similarity.view(batch_size,
                                                                     num_entities,
                                                                     num_entity_tokens,
                                                                     num_question_tokens)

        # (batch_size, num_entities, num_question_tokens)
        question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)

        # (batch_size, num_entities, num_question_tokens, num_features)
        linking_features = table['linking']

        if self._linking_params is not None:
            feature_scores = self._linking_params(linking_features).squeeze(3)
            linking_scores = question_entity_similarity_max_score + feature_scores
        else:
            # The linking score is computed as a linear projection of two terms. The first is the
            # maximum similarity score over the entity's words and the question token. The second
            # is the maximum similarity over the words in the entity's neighbors and the question
            # token.
            #
            # The second term, projected_question_neighbor_similarity, is useful when a column
            # needs to be selected. For example, the question token might have no similarity with
            # the column name, but is similar with the cells in the column.
            #
            # Note that projected_question_neighbor_similarity is intended to capture the same
            # information as the related_column feature.
            # (batch_size, num_entities, num_neighbors, num_question_tokens)
            question_neighbor_similarity = util.batched_index_select(question_entity_similarity_max_score,
                                                                     torch.abs(neighbor_indices))
            # (batch_size, num_entities, num_question_tokens)
            question_neighbor_similarity_max_score, _ = torch.max(question_neighbor_similarity, 2)
            projected_question_entity_similarity = self._question_entity_params(
                    question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            projected_question_neighbor_similarity = self._question_neighbor_params(
                    question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            linking_scores = projected_question_entity_similarity + projected_question_neighbor_similarity

        # (batch_size, num_question_tokens, num_entities)
        linking_probabilities = self._get_linking_probabilities(world, linking_scores.transpose(1, 2),
                                                                question_mask, entity_type_dict)

        # (batch_size, num_question_tokens, embedding_dim)
        link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities)
        encoder_input = torch.cat([link_embedding, embedded_question], 2)

        # (batch_size, question_length, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask))

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                             question_mask,
                                                             self._encoder.is_bidirectional())
        memory_cell = Variable(encoder_outputs.data.new(batch_size, self._encoder.get_output_dim()).fill_(0))

        initial_score = Variable(embedded_question.data.new(batch_size).fill_(0))

        action_embeddings, action_indices = self._embed_actions(actions)

        _, num_entities, num_question_tokens = linking_scores.size()
        flattened_linking_scores, actions_to_entities = self._map_entity_productions(linking_scores,
                                                                                     world,
                                                                                     actions)

        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
        else:
            target_mask = None

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        question_mask_list = [question_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(RnnState(final_encoder_output[i],
                                              memory_cell[i],
                                              self._first_action_embedding,
                                              self._first_attended_question,
                                              encoder_output_list,
                                              question_mask_list))
        initial_grammar_state = [self._create_grammar_state(world[i], actions[i])
                                 for i in range(batch_size)]
        initial_state = WikiTablesDecoderState(batch_indices=list(range(batch_size)),
                                               action_history=[[] for _ in range(batch_size)],
                                               score=initial_score_list,
                                               rnn_state=initial_rnn_state,
                                               grammar_state=initial_grammar_state,
                                               action_embeddings=action_embeddings,
                                               action_indices=action_indices,
                                               possible_actions=actions,
                                               flattened_linking_scores=flattened_linking_scores,
                                               actions_to_entities=actions_to_entities,
                                               entity_types=entity_type_dict,
                                               debug_info=None)
        if self.training:
            return self._decoder_trainer.decode(initial_state,
                                                self._decoder_step,
                                                (target_action_sequences, target_mask))
        else:
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            outputs: Dict[str, Any] = {'action_mapping': action_mapping}
            if target_action_sequences is not None:
                outputs['loss'] = self._decoder_trainer.decode(initial_state,
                                                               self._decoder_step,
                                                               (target_action_sequences, target_mask))['loss']
            num_steps = self._max_decoding_steps
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(num_steps,
                                                         initial_state,
                                                         self._decoder_step,
                                                         keep_final_unfinished_states=False)
            outputs['best_action_sequence'] = []
            outputs['debug_info'] = []
            outputs['entities'] = []
            outputs['linking_scores'] = linking_scores
            if self._linking_params is not None:
                outputs['feature_scores'] = feature_scores
            outputs['similarity_scores'] = question_entity_similarity_max_score
            outputs['logical_form'] = []
            for i in range(batch_size):
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i in best_final_states:
                    best_action_indices = best_final_states[i][0].action_history[0]
                    if target_action_sequences is not None:
                        # Use a Tensor, not a Variable, to avoid a memory leak.
                        targets = target_action_sequences[i].data
                        sequence_in_targets = self._action_history_match(best_action_indices, targets)
                        self._action_sequence_accuracy(sequence_in_targets)
                    action_strings = [action_mapping[(i, action_index)] for action_index in best_action_indices]
                    try:
                        logical_form = world[i].get_logical_form(action_strings, add_var_function=False)
                        self._has_logical_form(1.0)
                    except ParsingError:
                        self._has_logical_form(0.0)
                        logical_form = 'Error producing logical form'
                    if example_lisp_string:
                        self._denotation_accuracy(logical_form, example_lisp_string[i])
                    outputs['best_action_sequence'].append(action_strings)
                    outputs['logical_form'].append(logical_form)
                    outputs['debug_info'].append(best_final_states[i][0].debug_info[0])  # type: ignore
                    outputs['entities'].append(world[i].table_graph.entities)
                else:
                    outputs['logical_form'].append('')
                    self._has_logical_form(0.0)
                    if example_lisp_string:
                        self._denotation_accuracy(None, example_lisp_string[i])
            return outputs
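
# A tiny illustration (toy production rules) of the action_mapping built above for
# evaluation: (batch_index, action_index) -> production rule string, used to turn the
# decoder's integer action history back into grammar rules.
actions = [[('e -> f',), ('e -> g',)], [('d -> q',)]]  # stand-ins for ProductionRuleArrays
action_mapping = {}
for batch_index, batch_actions in enumerate(actions):
    for action_index, action in enumerate(batch_actions):
        action_mapping[(batch_index, action_index)] = action[0]
best_action_indices = [1, 0]
action_strings = [action_mapping[(0, idx)] for idx in best_action_indices]
assert action_strings == ['e -> g', 'e -> f']
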
    def _get_initial_rnn_and_grammar_state(
            self, question: Dict[str, torch.LongTensor],
            table: Dict[str, torch.LongTensor], world: List[WikiTablesWorld],
            actions: List[List[ProductionRuleArray]],
            outputs: Dict[str,
                          Any]) -> Tuple[List[RnnState], List[GrammarState]]:
        """
        Encodes the question and table, computes a linking between the two, and constructs an
        initial RnnState and GrammarState for each batch instance to pass to the decoder.

        We take ``outputs`` as a parameter here and `modify` it, adding things that we want to
        visualize in a demo.
        """
        table_text = table['text']
        # (batch_size, question_length, embedding_dim)
        embedded_question = self._question_embedder(question)
        question_mask = util.get_text_field_mask(question).float()
        # (batch_size, num_entities, num_entity_tokens, embedding_dim)
        embedded_table = self._question_embedder(table_text,
                                                 num_wrapping_dims=1)
        table_mask = util.get_text_field_mask(table_text,
                                              num_wrapping_dims=1).float()

        batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()
        num_question_tokens = embedded_question.size(1)

        # (batch_size, num_entities, embedding_dim)
        encoded_table = self._entity_encoder(embedded_table, table_mask)
        # (batch_size, num_entities, num_neighbors)
        neighbor_indices = self._get_neighbor_indices(world, num_entities,
                                                      encoded_table)

        # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
        # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
        # be added for the mask since that method expects 0 for padding.
        # (batch_size, num_entities, num_neighbors, embedding_dim)
        embedded_neighbors = util.batched_index_select(
            encoded_table, torch.abs(neighbor_indices))

        neighbor_mask = util.get_text_field_mask(
            {
                'ignored': neighbor_indices + 1
            }, num_wrapping_dims=1).float()

        # Encoder initialized to easily obtain a masked average.
        neighbor_encoder = TimeDistributed(
            BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
        # (batch_size, num_entities, embedding_dim)
        embedded_neighbors = neighbor_encoder(embedded_neighbors,
                                              neighbor_mask)

        # entity_types: tensor with shape (batch_size, num_entities), where each entry is the
        # entity's type id.
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(
            world, num_entities, encoded_table)

        entity_type_embeddings = self._entity_type_encoder_embedding(
            entity_types)
        projected_neighbor_embeddings = self._neighbor_params(
            embedded_neighbors.float())
        # (batch_size, num_entities, embedding_dim)
        entity_embeddings = torch.tanh(entity_type_embeddings +
                                       projected_neighbor_embeddings)

        # Compute entity and question word similarity.  We tried using cosine distance here, but
        # because this similarity is the main mechanism that the model can use to push apart logit
        # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger
        # output range than [-1, 1].
        question_entity_similarity = torch.bmm(
            embedded_table.view(batch_size, num_entities * num_entity_tokens,
                                self._embedding_dim),
            torch.transpose(embedded_question, 1, 2))

        question_entity_similarity = question_entity_similarity.view(
            batch_size, num_entities, num_entity_tokens, num_question_tokens)

        # (batch_size, num_entities, num_question_tokens)
        question_entity_similarity_max_score, _ = torch.max(
            question_entity_similarity, 2)

        # (batch_size, num_entities, num_question_tokens, num_features)
        linking_features = table['linking']

        linking_scores = question_entity_similarity_max_score

        if self._use_neighbor_similarity_for_linking:
            # The linking score is computed as a linear projection of two terms. The first is the
            # maximum similarity score over the entity's words and the question token. The second
            # is the maximum similarity over the words in the entity's neighbors and the question
            # token.
            #
            # The second term, projected_question_neighbor_similarity, is useful when a column
            # needs to be selected. For example, the question token might have no similarity with
            # the column name, but is similar with the cells in the column.
            #
            # Note that projected_question_neighbor_similarity is intended to capture the same
            # information as the related_column feature.
            #
            # Also note that this block needs to be _before_ the `linking_params` block, because
            # we're overwriting `linking_scores`, not adding to it.

            # (batch_size, num_entities, num_neighbors, num_question_tokens)
            question_neighbor_similarity = util.batched_index_select(
                question_entity_similarity_max_score,
                torch.abs(neighbor_indices))
            # (batch_size, num_entities, num_question_tokens)
            question_neighbor_similarity_max_score, _ = torch.max(
                question_neighbor_similarity, 2)
            projected_question_entity_similarity = self._question_entity_params(
                question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            projected_question_neighbor_similarity = self._question_neighbor_params(
                question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze(
                    -1)
            linking_scores = projected_question_entity_similarity + projected_question_neighbor_similarity

        feature_scores = None
        if self._linking_params is not None:
            feature_scores = self._linking_params(linking_features).squeeze(3)
            linking_scores = linking_scores + feature_scores

        # (batch_size, num_question_tokens, num_entities)
        linking_probabilities = self._get_linking_probabilities(
            world, linking_scores.transpose(1, 2), question_mask,
            entity_type_dict)

        # (batch_size, num_question_tokens, embedding_dim)
        link_embedding = util.weighted_sum(entity_embeddings,
                                           linking_probabilities)
        encoder_input = torch.cat([link_embedding, embedded_question], 2)

        # (batch_size, question_length, encoder_output_dim)
        encoder_outputs = self._dropout(
            self._encoder(encoder_input, question_mask))

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, question_mask, self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size,
                                                self._encoder.get_output_dim())

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        question_mask_list = [question_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(
                RnnState(final_encoder_output[i], memory_cell[i],
                         self._first_action_embedding,
                         self._first_attended_question, encoder_output_list,
                         question_mask_list))
        initial_grammar_state = [
            self._create_grammar_state(world[i], actions[i], linking_scores[i],
                                       entity_types[i])
            for i in range(batch_size)
        ]
        if not self.training:
            # We add a few things to the outputs that will be returned from `forward` at evaluation
            # time, for visualization in a demo.
            outputs['linking_scores'] = linking_scores
            if feature_scores is not None:
                outputs['feature_scores'] = feature_scores
            outputs['similarity_scores'] = question_entity_similarity_max_score
        return initial_rnn_state, initial_grammar_state
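
# A minimal sketch (hypothetical indices) of the -1 padding convention used above:
# neighbor indices are padded with -1, so we take abs() before the batched index select
# and add 1 before computing the mask, which zeroes out exactly the padded slots.
import torch

neighbor_indices = torch.tensor([[[1, -1], [0, -1]]])  # (batch=1, entities=2, neighbors=2)
mask = (neighbor_indices + 1).clamp(max=1).float()     # 1 for real neighbors, 0 for padding
assert mask.tolist() == [[[1.0, 0.0], [1.0, 0.0]]]
gather_indices = neighbor_indices.abs()                # now safe to use as indices
assert int(gather_indices.min()) >= 0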