Example #1
    def _get_initial_rnn_state(self, sentence: Dict[str, torch.LongTensor]):
        embedded_input = self._sentence_embedder(sentence)
        # (batch_size, sentence_length)
        sentence_mask = util.get_text_field_mask(sentence).float()

        batch_size = embedded_input.size(0)

        # (batch_size, sentence_length, encoder_output_dim)
        encoder_outputs = self._dropout(
            self._encoder(embedded_input, sentence_mask))

        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, sentence_mask, self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size,
                                                self._encoder.get_output_dim())
        attended_sentence, _ = self._decoder_step.attend_on_question(
            final_encoder_output, encoder_outputs, sentence_mask)
        encoder_outputs_list = [encoder_outputs[i] for i in range(batch_size)]
        sentence_mask_list = [sentence_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(
                RnnStatelet(final_encoder_output[i], memory_cell[i],
                            self._first_action_embedding, attended_sentence[i],
                            encoder_outputs_list, sentence_mask_list))
        return initial_rnn_state
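All of the snippets on this page build RnnStatelet objects with the same six positional arguments: decoder hidden state, memory cell, previous action embedding, attended input representation, and the per-instance lists of encoder outputs and encoder masks. A minimal stand-alone sketch with toy tensors, assuming the allennlp_semparse import path (older AllenNLP releases expose the class under allennlp.state_machines.states); the dimension sizes here are made up for illustration only:

import torch
from allennlp_semparse.state_machines.states import RnnStatelet

encoder_dim, action_dim, seq_len = 4, 3, 5
encoder_outputs = torch.randn(seq_len, encoder_dim)   # one instance's encoder outputs
encoder_mask = torch.ones(seq_len)                    # one instance's token mask
statelet = RnnStatelet(
    torch.zeros(encoder_dim),   # initial decoder hidden state
    torch.zeros(encoder_dim),   # initial memory cell
    torch.zeros(action_dim),    # embedding of the previous (sentinel) action
    torch.zeros(encoder_dim),   # attended representation of the input
    [encoder_outputs],          # encoder outputs, one tensor per batch instance
    [encoder_mask],             # encoder masks, one tensor per batch instance
)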
Example #2
    def make_state(group_index: int, action: int, new_score: torch.Tensor,
                   action_embedding: torch.Tensor) -> GrammarBasedState:
        new_rnn_state = RnnStatelet(
            hidden_state[group_index],
            memory_cell[group_index],
            action_embedding,
            attended_question[group_index],
            state.rnn_state[group_index].encoder_outputs,
            state.rnn_state[group_index].encoder_output_mask,
        )
        batch_index = state.batch_indices[group_index]
        for i, _, current_log_probs, _, actions in batch_action_probs[
                batch_index]:
            if i == group_index:
                considered_actions = actions
                probabilities = current_log_probs.exp().cpu()
                break
        return state.new_state_from_group_index(
            group_index,
            action,
            new_score,
            new_rnn_state,
            considered_actions,
            probabilities,
            updated_rnn_state["attention_weights"],
        )
Example #3
    def _get_initial_state(
            self, encoder_outputs: torch.Tensor, mask: torch.Tensor,
            actions: List[List[ProductionRule]]) -> GrammarBasedState:

        batch_size = encoder_outputs.size(0)
        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, mask, self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size,
                                                self._encoder.get_output_dim())
        initial_score = encoder_outputs.data.new_zeros(batch_size)

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, utterance_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        utterance_mask_list = [mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(
                RnnStatelet(
                    final_encoder_output[i],
                    memory_cell[i],
                    self._first_action_embedding,
                    self._first_attended_utterance,
                    encoder_output_list,
                    utterance_mask_list,
                ))

        initial_grammar_state = [
            self._create_grammar_state(actions[i]) for i in range(batch_size)
        ]

        initial_state = GrammarBasedState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            possible_actions=actions,
            debug_info=None,
        )
        return initial_state
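The comment about converting the batch dimension into an outer list is the trick shared by all of these parsers: keeping one tensor per instance lets the decoder regroup instances later with torch.stack / torch.cat instead of index selects. A small toy sketch (not part of the original model code) showing that the two views agree:

import torch

encoder_outputs = torch.randn(3, 7, 4)                     # (batch_size, utterance_length, encoder_output_dim)
encoder_output_list = [encoder_outputs[i] for i in range(3)]

group_indices = [2, 0]                                     # instances the decoder has grouped together
regrouped = torch.stack([encoder_output_list[i] for i in group_indices])
assert torch.equal(regrouped, encoder_outputs.index_select(0, torch.tensor(group_indices)))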
Example #4
    def _get_initial_rnn_state(
        self,
        question_encoded: torch.FloatTensor,
        question_mask: torch.Tensor,
        question_encoded_finalstate: torch.FloatTensor,
        question_encoded_aslist: List[torch.Tensor],
        question_mask_aslist: List[torch.Tensor],
    ):
        """ Get the initial RnnStatelet for the decoder based on the question encoding

        Parameters:
        -----------
        question_encoded: (B, question_length, D)
        question_mask: (B, question_length)
        question_encoded_finalstate: (B, D)
        question_encoded_aslist: [(question_length, D)] - List of length B
        question_mask_aslist: [(question_length)] - List of length B
        """

        batch_size = question_encoded_finalstate.size(0)
        ques_encoded_dim = question_encoded_finalstate.size()[-1]

        # Shape: (B, D)
        memory_cell = question_encoded_finalstate.new_zeros(
            batch_size, ques_encoded_dim)
        # TODO(nitish): Why does WikiTablesParser use '_first_attended_question' embedding and not this
        attended_sentence, _ = self._decoder_step.attend_on_question(
            question_encoded_finalstate, question_encoded, question_mask)

        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(
                RnnStatelet(
                    question_encoded_finalstate[i],
                    memory_cell[i],
                    self._first_action_embedding,
                    attended_sentence[i],
                    question_encoded_aslist,
                    question_mask_aslist,
                ))
        return initial_rnn_state
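attend_on_question here returns an attended question representation together with attention weights. As a rough, hypothetical illustration of what such a step computes (a plain dot-product attention with masking; this is a toy function, not the library's implementation of attend_on_question):

import torch

def toy_attend(query, encoded, mask):
    # Score every encoder position against the query, mask out padding,
    # and return the attention-weighted sum plus the weights themselves.
    scores = (encoded * query.unsqueeze(1)).sum(-1)                 # (B, question_length)
    scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = torch.softmax(scores, dim=-1)                         # (B, question_length)
    attended = torch.bmm(weights.unsqueeze(1), encoded).squeeze(1)  # (B, D)
    return attended, weights

attended_sentence, attention = toy_attend(
    torch.randn(2, 4), torch.randn(2, 6, 4), torch.ones(2, 6))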
Example #5
    def _get_initial_rnn_and_grammar_state(
        self,
        question: Dict[str, torch.LongTensor],
        table: Dict[str, torch.LongTensor],
        world: List[WikiTablesLanguage],
        actions: List[List[ProductionRuleArray]],
        outputs: Dict[str, Any],
    ) -> Tuple[List[RnnStatelet], List[GrammarStatelet]]:
        """
        Encodes the question and table, computes a linking between the two, and constructs an
        initial RnnStatelet and GrammarStatelet for each batch instance to pass to the
        decoder.

        We take ``outputs`` as a parameter here and `modify` it, adding things that we want to
        visualize in a demo.
        """
        table_text = table["text"]
        # (batch_size, question_length, embedding_dim)
        embedded_question = self._question_embedder(question)
        question_mask = util.get_text_field_mask(question)
        # (batch_size, num_entities, num_entity_tokens, embedding_dim)
        embedded_table = self._question_embedder(table_text,
                                                 num_wrapping_dims=1)
        table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1)

        batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()
        num_question_tokens = embedded_question.size(1)

        # (batch_size, num_entities, embedding_dim)
        encoded_table = self._entity_encoder(embedded_table, table_mask)

        # entity_types: tensor with shape (batch_size, num_entities), where each entry is the
        # entity's type id.
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(
            world, num_entities, encoded_table)

        entity_type_embeddings = self._entity_type_encoder_embedding(
            entity_types)

        # (batch_size, num_entities, num_neighbors) or None
        neighbor_indices = self._get_neighbor_indices(world, num_entities,
                                                      encoded_table)

        if neighbor_indices is not None:
            # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
            # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
            # be added for the mask since that method expects 0 for padding.
            # (batch_size, num_entities, num_neighbors, embedding_dim)
            embedded_neighbors = util.batched_index_select(
                encoded_table, torch.abs(neighbor_indices))

            neighbor_mask = util.get_text_field_mask(
                {
                    "ignored": {
                        "ignored": neighbor_indices + 1
                    }
                },
                num_wrapping_dims=1).float()

            # Encoder initialized to easily obtain a masked average.
            neighbor_encoder = TimeDistributed(
                BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
            # (batch_size, num_entities, embedding_dim)
            embedded_neighbors = neighbor_encoder(embedded_neighbors,
                                                  neighbor_mask)
            projected_neighbor_embeddings = self._neighbor_params(
                embedded_neighbors.float())

            # (batch_size, num_entities, embedding_dim)
            entity_embeddings = torch.tanh(entity_type_embeddings +
                                           projected_neighbor_embeddings)
        else:
            # (batch_size, num_entities, embedding_dim)
            entity_embeddings = torch.tanh(entity_type_embeddings)

        # Compute entity and question word similarity.  We tried using cosine distance here, but
        # because this similarity is the main mechanism that the model can use to push apart logit
        # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger
        # output range than [-1, 1].
        question_entity_similarity = torch.bmm(
            embedded_table.view(batch_size, num_entities * num_entity_tokens,
                                self._embedding_dim),
            torch.transpose(embedded_question, 1, 2),
        )

        question_entity_similarity = question_entity_similarity.view(
            batch_size, num_entities, num_entity_tokens, num_question_tokens)

        # (batch_size, num_entities, num_question_tokens)
        question_entity_similarity_max_score, _ = torch.max(
            question_entity_similarity, 2)

        # (batch_size, num_entities, num_question_tokens, num_features)
        linking_features = table["linking"]

        linking_scores = question_entity_similarity_max_score

        if self._use_neighbor_similarity_for_linking:
            # The linking score is computed as a linear projection of two terms. The first is the
            # maximum similarity score over the entity's words and the question token. The second
            # is the maximum similarity over the words in the entity's neighbors and the question
            # token.
            #
            # The second term, projected_question_neighbor_similarity, is useful when a column
            # needs to be selected. For example, the question token might have no similarity with
            # the column name, but is similar to the cells in the column.
            #
            # Note that projected_question_neighbor_similarity is intended to capture the same
            # information as the related_column feature.
            #
            # Also note that this block needs to be _before_ the `linking_params` block, because
            # we're overwriting `linking_scores`, not adding to it.

            # (batch_size, num_entities, num_neighbors, num_question_tokens)
            question_neighbor_similarity = util.batched_index_select(
                question_entity_similarity_max_score,
                torch.abs(neighbor_indices))
            # (batch_size, num_entities, num_question_tokens)
            question_neighbor_similarity_max_score, _ = torch.max(
                question_neighbor_similarity, 2)
            projected_question_entity_similarity = self._question_entity_params(
                question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            projected_question_neighbor_similarity = self._question_neighbor_params(
                question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze(
                    -1)
            linking_scores = (projected_question_entity_similarity +
                              projected_question_neighbor_similarity)

        feature_scores = None
        if self._linking_params is not None:
            feature_scores = self._linking_params(linking_features).squeeze(3)
            linking_scores = linking_scores + feature_scores

        # (batch_size, num_question_tokens, num_entities)
        linking_probabilities = self._get_linking_probabilities(
            world, linking_scores.transpose(1, 2), question_mask,
            entity_type_dict)

        # (batch_size, num_question_tokens, embedding_dim)
        link_embedding = util.weighted_sum(entity_embeddings,
                                           linking_probabilities)
        encoder_input = torch.cat([link_embedding, embedded_question], 2)

        # (batch_size, question_length, encoder_output_dim)
        encoder_outputs = self._dropout(
            self._encoder(encoder_input, question_mask))

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, question_mask, self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size,
                                                self._encoder.get_output_dim())

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        question_mask_list = [question_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(
                RnnStatelet(
                    final_encoder_output[i],
                    memory_cell[i],
                    self._first_action_embedding,
                    self._first_attended_question,
                    encoder_output_list,
                    question_mask_list,
                ))
        initial_grammar_state = [
            self._create_grammar_state(world[i], actions[i], linking_scores[i],
                                       entity_types[i])
            for i in range(batch_size)
        ]
        if not self.training:
            # We add a few things to the outputs that will be returned from `forward` at evaluation
            # time, for visualization in a demo.
            outputs["linking_scores"] = linking_scores
            if feature_scores is not None:
                outputs["feature_scores"] = feature_scores
            outputs["similarity_scores"] = question_entity_similarity_max_score
        return initial_rnn_state, initial_grammar_state
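Most of the question-entity similarity computation above is shape bookkeeping around a single batched matrix product. A toy walk-through of the bmm / view / max sequence with made-up sizes, separate from the model code:

import torch

batch_size, num_entities, num_entity_tokens, num_question_tokens, dim = 2, 3, 4, 5, 8
embedded_table = torch.randn(batch_size, num_entities, num_entity_tokens, dim)
embedded_question = torch.randn(batch_size, num_question_tokens, dim)

# (batch_size, num_entities * num_entity_tokens, num_question_tokens)
similarity = torch.bmm(
    embedded_table.view(batch_size, num_entities * num_entity_tokens, dim),
    embedded_question.transpose(1, 2),
)
similarity = similarity.view(batch_size, num_entities, num_entity_tokens, num_question_tokens)
# Best-matching entity token per (entity, question token) pair.
max_score, _ = similarity.max(2)      # (batch_size, num_entities, num_question_tokens)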
Example #6
    def _get_initial_state(
        self,
        utterance: Dict[str, torch.LongTensor],
        worlds: List[AtisWorld],
        actions: List[List[ProductionRule]],
        linking_scores: torch.Tensor,
    ) -> GrammarBasedState:
        embedded_utterance = self._utterance_embedder(utterance)
        utterance_mask = util.get_text_field_mask(utterance)

        batch_size = embedded_utterance.size(0)
        num_entities = max([len(world.entities) for world in worlds])

        # entity_types: tensor with shape (batch_size, num_entities)
        entity_types, _ = self._get_type_vector(worlds, num_entities,
                                                embedded_utterance)

        # (batch_size, num_utterance_tokens, embedding_dim)
        encoder_input = embedded_utterance

        # (batch_size, utterance_length, encoder_output_dim)
        encoder_outputs = self._dropout(
            self._encoder(encoder_input, utterance_mask))

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, utterance_mask, self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size,
                                                self._encoder.get_output_dim())
        initial_score = embedded_utterance.data.new_zeros(batch_size)

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, utterance_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        utterance_mask_list = [utterance_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            if self._decoder_num_layers > 1:
                initial_rnn_state.append(
                    RnnStatelet(
                        final_encoder_output[i].repeat(
                            self._decoder_num_layers, 1),
                        memory_cell[i].repeat(self._decoder_num_layers, 1),
                        self._first_action_embedding,
                        self._first_attended_utterance,
                        encoder_output_list,
                        utterance_mask_list,
                    ))
            else:
                initial_rnn_state.append(
                    RnnStatelet(
                        final_encoder_output[i],
                        memory_cell[i],
                        self._first_action_embedding,
                        self._first_attended_utterance,
                        encoder_output_list,
                        utterance_mask_list,
                    ))

        initial_grammar_state = [
            self._create_grammar_state(worlds[i], actions[i],
                                       linking_scores[i], entity_types[i])
            for i in range(batch_size)
        ]

        initial_state = GrammarBasedState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            possible_actions=actions,
            debug_info=None,
        )
        return initial_state
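The only difference between the two branches in the loop above is the multi-layer case, where a single instance's final encoder state and memory cell are tiled once per decoder layer. A toy check of the shapes involved (sizes are arbitrary):

import torch

decoder_num_layers, encoder_output_dim = 2, 4
final_encoder_output_i = torch.randn(encoder_output_dim)        # one instance's final encoder state
stacked = final_encoder_output_i.repeat(decoder_num_layers, 1)  # one row per decoder LSTM layer
assert stacked.shape == (decoder_num_layers, encoder_output_dim)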
Example #7
    def setUp(self):
        super().setUp()
        self.decoder_step = BasicTransitionFunction(
            encoder_output_dim=2,
            action_embedding_dim=2,
            input_attention=Attention.by_name("dot_product")(),
            add_action_bias=False,
        )

        batch_indices = [0, 1, 0]
        action_history = [[1], [3, 4], []]
        score = [torch.FloatTensor([x]) for x in [0.1, 1.1, 2.2]]
        hidden_state = torch.FloatTensor([[i, i]
                                          for i in range(len(batch_indices))])
        memory_cell = torch.FloatTensor([[i, i]
                                         for i in range(len(batch_indices))])
        previous_action_embedding = torch.FloatTensor(
            [[i, i] for i in range(len(batch_indices))])
        attended_question = torch.FloatTensor(
            [[i, i] for i in range(len(batch_indices))])
        # This maps non-terminals to valid actions, where the valid actions are grouped by _type_.
        # We have "global" actions, which are from the global grammar, and "linked" actions, which
        # are instance-specific and are generated based on question attention.  Each action type
        # has a tuple which is (input representation, output representation, action ids).
        valid_actions = {
            "e": {
                "global": (
                    torch.FloatTensor([[0, 0], [-1, -1], [-2, -2]]),
                    torch.FloatTensor([[-1, -1], [-2, -2], [-3, -3]]),
                    [0, 1, 2],
                ),
                "linked": (
                    torch.FloatTensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]),
                    torch.FloatTensor([[3, 3], [4, 4]]),
                    [3, 4],
                ),
            },
            "d": {
                "global":
                (torch.FloatTensor([[0, 0]]), torch.FloatTensor([[-1,
                                                                  -1]]), [0]),
                "linked": (
                    torch.FloatTensor([[-0.1, -0.2, -0.3], [-0.4, -0.5, -0.6],
                                       [-0.7, -0.8, -0.9]]),
                    torch.FloatTensor([[5, 5], [6, 6], [7, 7]]),
                    [1, 2, 3],
                ),
            },
        }
        grammar_state = [
            GrammarStatelet([nonterminal], valid_actions, is_nonterminal)
            for _, nonterminal in zip(batch_indices, ["e", "d", "e"])
        ]
        self.encoder_outputs = torch.FloatTensor([[[1, 2], [3, 4], [5, 6]],
                                                  [[10, 11], [12, 13],
                                                   [14, 15]]])
        self.encoder_output_mask = torch.FloatTensor([[1, 1, 1], [1, 1, 0]])
        self.possible_actions = [
            [
                ("e -> f", False, None),
                ("e -> g", True, None),
                ("e -> h", True, None),
                ("e -> i", True, None),
                ("e -> j", True, None),
            ],
            [
                ("d -> q", True, None),
                ("d -> g", True, None),
                ("d -> h", True, None),
                ("d -> i", True, None),
            ],
        ]

        rnn_state = []
        for i in range(len(batch_indices)):
            rnn_state.append(
                RnnStatelet(
                    hidden_state[i],
                    memory_cell[i],
                    previous_action_embedding[i],
                    attended_question[i],
                    self.encoder_outputs,
                    self.encoder_output_mask,
                ))
        self.state = GrammarBasedState(
            batch_indices=batch_indices,
            action_history=action_history,
            score=score,
            rnn_state=rnn_state,
            grammar_state=grammar_state,
            possible_actions=self.possible_actions,
        )
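For readers unfamiliar with the valid_actions layout described in the comment inside setUp, here is a minimal stand-alone version of that structure with made-up numbers: per nonterminal, "global" actions carry (input embeddings, output embeddings, action ids) for grammar-wide rules, while "linked" actions appear to carry (per-encoder-position linking scores, output embeddings, action ids) for instance-specific rules.

import torch

valid_actions = {
    "e": {
        # (input embeddings, output embeddings, action ids) for global rules
        "global": (torch.zeros(2, 2), torch.zeros(2, 2), [0, 1]),
        # (linking scores over 3 encoder positions, output embeddings, action ids)
        "linked": (torch.tensor([[0.1, 0.2, 0.3]]), torch.zeros(1, 2), [2]),
    }
}
_, _, global_ids = valid_actions["e"]["global"]
linked_scores, _, linked_ids = valid_actions["e"]["linked"]
# Both id lists index into this instance's `possible_actions`.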
Example #8
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                table: Dict[str, torch.LongTensor],
                world: List[QuarelWorld],
                actions: List[List[ProductionRule]],
                entity_bits: torch.Tensor = None,
                denotation_target: torch.Tensor = None,
                target_action_sequences: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        # pylint: disable=unused-argument
        """
        In this method we encode the table entities, link them to words in the question, then
        encode the question. Then we set up the initial state for the decoder, and pass that
        state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference,
        if we're not.

        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
           The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
           be passed through a ``TextFieldEmbedder`` and then through an encoder.
        table : ``Dict[str, torch.LongTensor]``
            The output of ``KnowledgeGraphField.as_array()`` applied on the table
            ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
            entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
            get embeddings for each entity.
        world : ``List[QuarelWorld]``
            We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
            how ``MetadataField`` works, this gets passed to us as a ``List[QuarelWorld]``.
        actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        entity_bits : ``torch.Tensor``, optional (default=None)
            Tensor encoding bits for the world entities.
        denotation_target : ``torch.Tensor``, optional (default=None)
            If model's field ``denotation_only`` is True, this is the tensor target denotation.
        target_action_sequences : torch.Tensor, optional (default=None)
           A list of possibly valid action sequences, where each action is an index into the list
           of possible actions.  This tensor has shape ``(batch_size, num_action_sequences,
           sequence_length)``.
        metadata : List[Dict[str, Any]], optional (default=None).
            A dictionary of metadata for each batch element which has keys:
                question_tokens : ``List[str]``, optional.
                    The original string tokens in the question.
                world_extractions : ``nltk.Tree``, optional.
                    Extracted worlds from the question.
                answer_index : ``List[str]``, optional.
                    Index of the correct answer.
        """

        table_text = table['text']

        self._debug_count -= 1

        # (batch_size, question_length, embedding_dim)
        embedded_question = self._question_embedder(question)
        question_mask = util.get_text_field_mask(question).float()
        num_question_tokens = embedded_question.size(1)

        # (batch_size, num_entities, num_entity_tokens, embedding_dim)
        embedded_table = self._question_embedder(table_text, num_wrapping_dims=1)

        batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()

        # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types)
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(world, num_entities, embedded_table)

        if self._use_entities:

            if self._entity_similarity_mode == "dot_product":
                # Compute entity and question word cosine similarity. Need to add a small value
                # to the table norm since there are padding values which cause a divide by 0.
                embedded_table = embedded_table / (embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
                embedded_question = embedded_question / (embedded_question.norm(dim=-1, keepdim=True) + 1e-13)
                question_entity_similarity = torch.bmm(embedded_table.view(batch_size,
                                                                           num_entities * num_entity_tokens,
                                                                           self._embedding_dim),
                                                       torch.transpose(embedded_question, 1, 2))

                question_entity_similarity = question_entity_similarity.view(batch_size,
                                                                             num_entities,
                                                                             num_entity_tokens,
                                                                             num_question_tokens)

                # (batch_size, num_entities, num_question_tokens)
                question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)

                linking_scores = question_entity_similarity_max_score
            elif self._entity_similarity_mode == "weighted_dot_product":
                embedded_table = embedded_table / (embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
                embedded_question = embedded_question / (embedded_question.norm(dim=-1, keepdim=True) + 1e-13)
                eqe = embedded_question.unsqueeze(1).expand(-1, num_entities*num_entity_tokens, -1, -1)
                ete = embedded_table.view(batch_size, num_entities*num_entity_tokens, self._embedding_dim)
                ete = ete.unsqueeze(2).expand(-1, -1, num_question_tokens, -1)
                product = torch.mul(eqe, ete)
                product = product.view(batch_size,
                                       num_question_tokens*num_entities*num_entity_tokens,
                                       self._embedding_dim)
                question_entity_similarity = self._entity_similarity_layer(product)
                question_entity_similarity = question_entity_similarity.view(batch_size,
                                                                             num_entities,
                                                                             num_entity_tokens,
                                                                             num_question_tokens)

                # (batch_size, num_entities, num_question_tokens)
                question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)
                linking_scores = question_entity_similarity_max_score

            # (batch_size, num_entities, num_question_tokens, num_features)
            linking_features = table['linking']

            if self._linking_params is not None:
                feature_scores = self._linking_params(linking_features).squeeze(3)
                linking_scores = linking_scores + feature_scores

            # (batch_size, num_question_tokens, num_entities)
            linking_probabilities = self._get_linking_probabilities(world, linking_scores.transpose(1, 2),
                                                                    question_mask, entity_type_dict)
            encoder_input = embedded_question
        else:
            if entity_bits is not None and not self._entity_bits_output:
                encoder_input = torch.cat([embedded_question, entity_bits], 2)
            else:
                encoder_input = embedded_question

            # Fake linking_scores added so that downstream code does not complain
            linking_scores = question_mask.clone().fill_(0).unsqueeze(1)
            linking_probabilities = None

        # (batch_size, question_length, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask))

        if self._entity_bits_output and entity_bits is not None:
            encoder_outputs = torch.cat([encoder_outputs, entity_bits], 2)

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                             question_mask,
                                                             self._encoder.is_bidirectional())
        # For predicting a categorical denotation directly
        if self._denotation_only:
            denotation_logits = self._denotation_classifier(final_encoder_output)
            loss = torch.nn.functional.cross_entropy(denotation_logits, denotation_target.view(-1))
            self._denotation_accuracy_cat(denotation_logits, denotation_target)
            return {"loss": loss}

        memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder_output_dim)

        _, num_entities, num_question_tokens = linking_scores.size()

        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
        else:
            target_mask = None

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        question_mask_list = [question_mask[i] for i in range(batch_size)]
        initial_rnn_state = []

        for i in range(batch_size):
            initial_rnn_state.append(RnnStatelet(final_encoder_output[i],
                                                 memory_cell[i],
                                                 self._first_action_embedding,
                                                 self._first_attended_question,
                                                 encoder_output_list,
                                                 question_mask_list))

        initial_grammar_state = [self._create_grammar_state(world[i], actions[i],
                                                            linking_scores[i], entity_types[i])
                                 for i in range(batch_size)]

        initial_score = initial_rnn_state[0].hidden_state.new_zeros(batch_size)
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),
                                          action_history=[[] for _ in range(batch_size)],
                                          score=initial_score_list,
                                          rnn_state=initial_rnn_state,
                                          grammar_state=initial_grammar_state,
                                          possible_actions=actions,
                                          extras=None,
                                          debug_info=None)

        if self.training:
            outputs = self._decoder_trainer.decode(initial_state,
                                                   self._decoder_step,
                                                   (target_action_sequences, target_mask))
            return outputs

        else:
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            outputs = {'action_mapping': action_mapping}
            if target_action_sequences is not None:
                outputs['loss'] = self._decoder_trainer.decode(initial_state,
                                                               self._decoder_step,
                                                               (target_action_sequences, target_mask))['loss']

            num_steps = self._max_decoding_steps
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(num_steps,
                                                         initial_state,
                                                         self._decoder_step,
                                                         keep_final_unfinished_states=False)
            outputs['best_action_sequence'] = []
            outputs['debug_info'] = []
            outputs['entities'] = []
            if self._linking_params is not None:
                outputs['linking_scores'] = linking_scores
                outputs['feature_scores'] = feature_scores
                outputs['linking_features'] = linking_features
            if self._use_entities:
                outputs['linking_probabilities'] = linking_probabilities
            if entity_bits is not None:
                outputs['entity_bits'] = entity_bits
            # outputs['similarity_scores'] = question_entity_similarity_max_score
            outputs['logical_form'] = []
            outputs['denotation_acc'] = []
            outputs['score'] = []
            outputs['parse_acc'] = []
            outputs['answer_index'] = []
            if metadata is not None:
                outputs['question_tokens'] = []
                outputs['world_extractions'] = []
            for i in range(batch_size):
                if metadata is not None:
                    outputs['question_tokens'].append(metadata[i].get('question_tokens', []))
                if metadata is not None:
                    outputs['world_extractions'].append(metadata[i].get('world_extractions', {}))
                outputs['entities'].append(world[i].table_graph.entities)
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i in best_final_states:
                    best_action_indices = best_final_states[i][0].action_history[0]
                    sequence_in_targets = 0
                    if target_action_sequences is not None:
                        targets = target_action_sequences[i].data
                        sequence_in_targets = self._action_history_match(best_action_indices, targets)
                        self._action_sequence_accuracy(sequence_in_targets)
                    action_strings = [action_mapping[(i, action_index)] for action_index in best_action_indices]
                    try:
                        self._has_logical_form(1.0)
                        logical_form = world[i].get_logical_form(action_strings, add_var_function=False)
                    except ParsingError:
                        self._has_logical_form(0.0)
                        logical_form = 'Error producing logical form'
                    denotation_accuracy = 0.0
                    predicted_answer_index = world[i].execute(logical_form)
                    if metadata is not None and 'answer_index' in metadata[i]:
                        answer_index = metadata[i]['answer_index']
                        denotation_accuracy = self._denotation_match(predicted_answer_index, answer_index)
                        self._denotation_accuracy(denotation_accuracy)
                    score = math.exp(best_final_states[i][0].score[0].data.cpu().item())
                    outputs['answer_index'].append(predicted_answer_index)
                    outputs['score'].append(score)
                    outputs['parse_acc'].append(sequence_in_targets)
                    outputs['best_action_sequence'].append(action_strings)
                    outputs['logical_form'].append(logical_form)
                    outputs['denotation_acc'].append(denotation_accuracy)
                    outputs['debug_info'].append(best_final_states[i][0].debug_info[0])  # type: ignore
                else:
                    outputs['parse_acc'].append(0)
                    outputs['logical_form'].append('')
                    outputs['denotation_acc'].append(0)
                    outputs['score'].append(0)
                    outputs['answer_index'].append(-1)
                    outputs['best_action_sequence'].append([])
                    outputs['debug_info'].append([])
                    self._has_logical_form(0.0)
            return outputs
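The "dot_product" entity-similarity branch above differs from the unnormalized version in Example #5 mainly in the L2 normalization (with a small epsilon to guard all-padding rows), which turns the batched matrix product into cosine similarities. A toy sketch of that branch with made-up sizes:

import torch

batch_size, num_entities, num_entity_tokens, num_question_tokens, dim = 2, 3, 4, 5, 8
table = torch.randn(batch_size, num_entities, num_entity_tokens, dim)
question = torch.randn(batch_size, num_question_tokens, dim)

# L2-normalize both sides so the batched matrix product yields cosine similarities.
table = table / (table.norm(dim=-1, keepdim=True) + 1e-13)
question = question / (question.norm(dim=-1, keepdim=True) + 1e-13)
similarity = torch.bmm(
    table.view(batch_size, num_entities * num_entity_tokens, dim),
    question.transpose(1, 2),
).view(batch_size, num_entities, num_entity_tokens, num_question_tokens)
linking_scores, _ = similarity.max(2)   # (batch_size, num_entities, num_question_tokens)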