        def make_state(group_index: int, action: int, new_score: torch.Tensor,
                       action_embedding: torch.Tensor) -> GrammarBasedState:
            batch_index = state.batch_indices[group_index]

            decoder_outputs = state.rnn_state[group_index].decoder_outputs
            is_linked_action = not state.possible_actions[batch_index][action][1]
            if is_linked_action:
                if decoder_outputs is None:
                    decoder_outputs = (hidden_state[group_index].unsqueeze(0),
                                       [action],
                                       predicted_action_embedding[group_index].unsqueeze(0))
                else:
                    decoder_outputs_states, decoder_outputs_ids, predicted_action_embeddings = decoder_outputs
                    decoder_outputs = (
                        torch.cat((decoder_outputs_states,
                                   hidden_state[group_index].unsqueeze(0)), dim=0),
                        decoder_outputs_ids + [action],
                        torch.cat((predicted_action_embeddings,
                                   predicted_action_embedding[group_index].unsqueeze(0)), dim=0))

            for i, _, current_log_probs, _, actions, lsq, lsp, gate, attended_dec in batch_action_probs[
                    batch_index]:
                if i == group_index:
                    considered_actions = actions
                    probabilities = current_log_probs.exp().cpu()
                    considered_lsq = lsq
                    considered_lsp = lsp
                    gate_value = gate
                    attended_decoder = attended_dec
                    break

            new_rnn_state = RnnStatelet(
                hidden_state[group_index], memory_cell[group_index],
                action_embedding, attended_question[group_index],
                state.rnn_state[group_index].encoder_outputs,
                state.rnn_state[group_index].encoder_output_mask,
                decoder_outputs)

            return state.new_state_from_group_index(
                group_index, action, new_score, new_rnn_state,
                considered_actions, probabilities,
                updated_rnn_state['attention_weights'], considered_lsq,
                considered_lsp, gate_value)
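# A minimal standalone sketch (not part of the original parser) of how `decoder_outputs`
# is accumulated above for linked (non-global) actions: a growing tensor of decoder hidden
# states, a list of action ids, and a growing tensor of predicted action embeddings.
# All names and shapes here are illustrative only.
import torch

def accumulate_decoder_outputs(decoder_outputs, hidden, action_id, predicted_embedding):
    """Append one decoding step to the (states, action_ids, predicted_embeddings) triple."""
    if decoder_outputs is None:
        return hidden.unsqueeze(0), [action_id], predicted_embedding.unsqueeze(0)
    states, action_ids, predicted = decoder_outputs
    return (torch.cat((states, hidden.unsqueeze(0)), dim=0),
            action_ids + [action_id],
            torch.cat((predicted, predicted_embedding.unsqueeze(0)), dim=0))

# Usage: after three linked actions the first element has shape (3, hidden_dim).
outputs = None
for step_action in [7, 12, 3]:
    outputs = accumulate_decoder_outputs(outputs, torch.randn(8), step_action, torch.randn(8))
assert outputs[0].shape == (3, 8) and outputs[1] == [7, 12, 3]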
        def make_state(group_index: int,
                       action: int,
                       new_score: torch.Tensor,
                       action_embedding: torch.Tensor) -> GrammarBasedState:
            batch_index = state.batch_indices[group_index]
            action_entity_id = state.action_entity_mapping[batch_index][action] + 1  # add 1 so that -1 becomes 0 (pad)
            if not state.action_history[0]:
                decoded_item_embeddings = state.rnn_state[group_index].item_embeddings[action_entity_id].unsqueeze(0)
            else:
                decoded_item_embeddings = torch.cat((
                    state.rnn_state[group_index].decoded_item_embeddings,
                    state.rnn_state[group_index].item_embeddings[action_entity_id].unsqueeze(0)
                ), dim=0)

            new_rnn_state = RnnStatelet(hidden_state[group_index],
                                        memory_cell[group_index],
                                        action_embedding,
                                        attended_question[group_index],
                                        state.rnn_state[group_index].encoder_outputs,
                                        state.rnn_state[group_index].encoder_output_mask,
                                        state.rnn_state[group_index].item_embeddings,
                                        decoder_outputs[group_index],
                                        decoded_item_embeddings)
            for i, _, current_log_probs, _, actions in batch_action_probs[batch_index]:
                if i == group_index:
                    considered_actions = actions
                    probabilities = current_log_probs.exp().cpu()
                    break
            return state.new_state_from_group_index(group_index,
                                                    action,
                                                    new_score,
                                                    new_rnn_state,
                                                    considered_actions,
                                                    probabilities,
                                                    updated_rnn_state['attention_weights'],
                                                    updated_rnn_state['output_attention_weights'])
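# A minimal standalone sketch (not from the original code) of the index shift used above:
# `action_entity_mapping` stores -1 for actions that do not refer to a schema entity, so
# adding 1 maps -1 to 0, and row 0 of `item_embeddings` is assumed to be a padding vector.
import torch

num_entities, dim = 4, 6
# Row 0 is the pad embedding; rows 1..num_entities correspond to real schema items.
item_embeddings = torch.cat([torch.zeros(1, dim), torch.randn(num_entities, dim)], dim=0)

action_entity_mapping = {0: -1, 1: 2, 2: 0}   # action id -> entity id (-1 = no entity)
for action, entity_id in action_entity_mapping.items():
    embedding = item_embeddings[entity_id + 1]  # -1 becomes 0, i.e. the pad row
    assert (entity_id != -1) or torch.all(embedding == 0)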
Example no. 3
    def _get_initial_state(
            self, utterance: Dict[str, torch.LongTensor],
            worlds: List[SpiderWorld], schema: Dict[str, torch.LongTensor],
            actions: List[List[ProductionRule]]) -> GrammarBasedState:
        schema_text = schema['text']
        """KAIMARY"""
        # TextFieldEmbedder needs a "token" key in the Dict
        """
        embedded_schema:torch.Size([batch_size, num_entities, max_num_entity_tokens, embedding_dim])
        schema_mask:torch.Size([batch_size, num_entities, max_num_entity_tokens])
        embedded_utterance:torch.Size([batch_size, max_utterance_size, embedding_dim])
        entity_type_embeddings:torch.Size([batch_size, num_entities, embedding_dim])
        """
        embedded_schema = self._question_embedder(schema_text,
                                                  num_wrapping_dims=1)
        schema_mask = util.get_text_field_mask(schema_text,
                                               num_wrapping_dims=1).float()

        embedded_utterance = self._question_embedder(utterance)
        utterance_mask = util.get_text_field_mask(utterance).float()

        batch_size, num_entities, num_entity_tokens, _ = embedded_schema.size()
        num_entities = max([
            len(world.db_context.knowledge_graph.entities) for world in worlds
        ])
        num_question_tokens = utterance['tokens'].size(1)

        # entity_types: tensor with shape (batch_size, num_entities), where each entry is the
        # entity's type id.
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(
            worlds, num_entities, embedded_schema.device)

        entity_type_embeddings = self._entity_type_encoder_embedding(
            entity_types)

        # Compute entity and question word similarity.  We tried using cosine distance here, but
        # because this similarity is the main mechanism that the model can use to push apart logit
        # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger
        # output range than [-1, 1].
        question_entity_similarity = torch.bmm(
            embedded_schema.view(batch_size, num_entities * num_entity_tokens,
                                 self._embedding_dim),
            torch.transpose(embedded_utterance, 1, 2))

        question_entity_similarity = question_entity_similarity.view(
            batch_size, num_entities, num_entity_tokens, num_question_tokens)
        # (batch_size, num_entities, num_question_tokens)
        question_entity_similarity_max_score, _ = torch.max(
            question_entity_similarity, 2)
        """KAIMARY"""
        # Variable: linking_scores
        # The entitiy linking score s(e, i) in the Krishnamurthy 2017
        # (batch_size, num_entities, num_question_tokens, num_features)
        linking_features = schema['linking']

        linking_scores = question_entity_similarity_max_score

        feature_scores = self._linking_params(linking_features).squeeze(3)

        linking_scores = linking_scores + feature_scores
        """KAIMARY"""
        # linking_probabilities
        # The scores s(e,i) are then fed into a softmax layer over all entities e of the same type
        # (batch_size, num_question_tokens, num_entities)
        linking_probabilities = self._get_linking_probabilities(
            worlds, linking_scores.transpose(1, 2), utterance_mask,
            entity_type_dict)

        # (batch_size, num_entities, num_neighbors) or None
        neighbor_indices = self._get_neighbor_indices(worlds, num_entities,
                                                      linking_scores.device)

        if self._use_neighbor_similarity_for_linking and neighbor_indices is not None:
            """KAIMARY"""
            # Seq2VecEncoder get the hidden state of the last step as the unique output
            # (batch_size, num_entities, embedding_dim)
            encoded_table = self._entity_encoder(embedded_schema, schema_mask)

            # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
            # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
            # be added for the mask since that method expects 0 for padding.
            # (batch_size, num_entities, num_neighbors, embedding_dim)
            embedded_neighbors = util.batched_index_select(
                encoded_table, torch.abs(neighbor_indices))

            neighbor_mask = util.get_text_field_mask(
                {
                    'ignored': neighbor_indices + 1
                }, num_wrapping_dims=1).float()

            # Encoder initialized to easily obtain a masked average.
            neighbor_encoder = TimeDistributed(
                BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
            # (batch_size, num_entities, embedding_dim)
            embedded_neighbors = neighbor_encoder(embedded_neighbors,
                                                  neighbor_mask)
            projected_neighbor_embeddings = self._neighbor_params(
                embedded_neighbors.float())
            """KAIMARY"""
            # Variable: entity_embedding
            # Rv in B Bogin 2019
            # Is a learned embedding for the schema item v, which base the embedding on the type of v and its schema neighbors only
            # (batch_size, num_entities, embedding_dim)
            entity_embeddings = torch.tanh(entity_type_embeddings +
                                           projected_neighbor_embeddings)
        else:
            # (batch_size, num_entities, embedding_dim)
            entity_embeddings = torch.tanh(entity_type_embeddings)
        """KAIMARY"""
        # Variable: link_embedding
        # Li in B Bogin 2019
        # Is an average of entity vectors weighted by the resulting distribution
        link_embedding = util.weighted_sum(entity_embeddings,
                                           linking_probabilities)
        """KAIMARY"""
        # Variable: encoder_input
        # [Wi, Li] in B Bogin 2019
        encoder_input = torch.cat([link_embedding, embedded_utterance], 2)

        # (batch_size, utterance_length, encoder_output_dim)
        encoder_outputs = self._dropout(
            self._encoder(encoder_input, utterance_mask))
        """KAIMARY"""
        # Variable: max_entities_relevance
        # ρv = maxi plink(v | xi) in B Bogin 2019
        # Is the maximum probability of v for any word xi
        max_entities_relevance = linking_probabilities.max(dim=1)[0]
        entities_relevance = max_entities_relevance.unsqueeze(-1).detach()
        """KAIMARY"""
        # entity_type_embeddings ???
        # Variable: graph_initial_embedding
        # hv(0) in B Bogin 2019
        # Is an initial embedding conditioned on the relevance score, and then used to be fed into GNN
        graph_initial_embedding = entity_type_embeddings * entities_relevance

        encoder_output_dim = self._encoder.get_output_dim()
        if self._gnn:
            """KAIMARY"""
            # Variable: entities_graph_encoding
            # φv in  B Bogin 2019
            # Is the final representation of each schema item after L steps
            entities_graph_encoding = self._get_schema_graph_encoding(
                worlds, graph_initial_embedding)
            """KAIMARY"""
            # Variable: graph_link_embedding
            # Lφ,i in B Bogin 2019
            graph_link_embedding = util.weighted_sum(entities_graph_encoding,
                                                     linking_probabilities)
            encoder_outputs = torch.cat(
                (encoder_outputs, graph_link_embedding), dim=-1)
            encoder_output_dim = self._action_embedding_dim + self._encoder.get_output_dim()
        else:
            entities_graph_encoding = None

        if self._self_attend:
            # linked_actions_linking_scores = self._get_linked_actions_linking_scores(actions, entities_graph_encoding)
            entities_ff = self._ent2ent_ff(entities_graph_encoding)
            linked_actions_linking_scores = torch.bmm(
                entities_ff, entities_ff.transpose(1, 2))
        else:
            linked_actions_linking_scores = [None] * batch_size

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, utterance_mask, self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size, encoder_output_dim)
        initial_score = embedded_utterance.data.new_zeros(batch_size)

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, utterance_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        utterance_mask_list = [utterance_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(
                RnnStatelet(final_encoder_output[i], memory_cell[i],
                            self._first_action_embedding,
                            self._first_attended_utterance,
                            encoder_output_list, utterance_mask_list))

        initial_grammar_state = [
            self._create_grammar_state(
                worlds[i], actions[i], linking_scores[i],
                linked_actions_linking_scores[i], entity_types[i],
                entities_graph_encoding[i]
                if entities_graph_encoding is not None else None)
            for i in range(batch_size)
        ]

        initial_sql_state = [
            SqlState(actions[i], self.parse_sql_on_decoding)
            for i in range(batch_size)
        ]

        initial_state = GrammarBasedState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            sql_state=initial_sql_state,
            possible_actions=actions,
            action_entity_mapping=[
                w.get_action_entity_mapping() for w in worlds
            ])

        return initial_state
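# A self-contained sketch (toy shapes, not the model's real tensors) of the linking
# computation used in _get_initial_state above: token-level dot-product similarity via
# torch.bmm, a max over each entity's tokens, adding featurized linking scores, and a
# probability-weighted sum of entity embeddings to get a per-word link embedding.
import torch

batch, n_entities, n_entity_tokens, n_question_tokens, dim = 2, 3, 4, 5, 8

embedded_schema = torch.randn(batch, n_entities, n_entity_tokens, dim)
embedded_utterance = torch.randn(batch, n_question_tokens, dim)
feature_scores = torch.randn(batch, n_entities, n_question_tokens)
entity_embeddings = torch.randn(batch, n_entities, dim)

# (batch, n_entities * n_entity_tokens, n_question_tokens)
similarity = torch.bmm(embedded_schema.view(batch, n_entities * n_entity_tokens, dim),
                       embedded_utterance.transpose(1, 2))
similarity = similarity.view(batch, n_entities, n_entity_tokens, n_question_tokens)

# Max over each entity's tokens, then add the feature-based scores: s(e, i).
linking_scores = similarity.max(dim=2)[0] + feature_scores

# In the model the softmax is restricted to entities of the same type; a plain softmax
# over all entities is used here for brevity.
linking_probabilities = linking_scores.transpose(1, 2).softmax(dim=-1)

# (batch, n_question_tokens, dim): weighted average of entity embeddings per question word.
link_embedding = torch.bmm(linking_probabilities, entity_embeddings)
assert link_embedding.shape == (batch, n_question_tokens, dim)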
Example no. 4
        def make_state(group_index: int, action: int, new_score: torch.Tensor,
                       action_embedding: torch.Tensor) -> GrammarBasedState:
            """
            group_index:
                batch index
            action:
                next action index (created based on the allowed actions[supervision])
            new_score:
                is log probability (log_probs) of the action to now. So it must be a negative score.
                Calc log_probs in _compute_action_probabilities function, as follow:
                (new_score = ) log_probs = state.score[group_index] + current_log_probs
                log_probs is a accumulated value that accumulate the scores (new_score will become the scores in next round).
                Actually, the score is also used as the loss value.
            action_embedding:
                come from 'output_action_embeddings' in '_compute_action_probabilities'
            """

            # `state` is the current state (initially just the encoder state).  As a closure,
            # this function can access the local variables of _construct_next_states.
            batch_index = state.batch_indices[group_index]

            # The initial state.rnn_state does not contain decoder_outputs, because the initial
            # state is created by the encoder in spider_parser.py, before any decoding has
            # happened.  This function attaches decoder_outputs to the new state, so from the
            # next step onwards it is no longer None; on the first step it is None, and it stays
            # None as long as only global actions are taken.  See
            # state_machines/states/rnn_statelet.py for what decoder_outputs contains.
            decoder_outputs = state.rnn_state[group_index].decoder_outputs
            is_linked_action = not state.possible_actions[batch_index][action][1]
            if is_linked_action:  # non-global action, so store the hidden state
                if decoder_outputs is None:
                    decoder_outputs = (hidden_state[group_index].unsqueeze(0), [action])
                else:  # accumulate all hidden states, hence torch.cat
                    decoder_outputs_states, decoder_outputs_ids = decoder_outputs
                    decoder_outputs = (
                        torch.cat((decoder_outputs_states,
                                   hidden_state[group_index].unsqueeze(0)), dim=0),
                        decoder_outputs_ids + [action])

            # Create a new RnnStatelet to replace the old one.
            new_rnn_state = RnnStatelet(
                hidden_state[group_index],       # updated_state['hidden_state']
                memory_cell[group_index],        # updated_state['memory_cell']
                # action_embedding comes from 'output_action_embeddings' in
                # '_compute_action_probabilities'; it is the embedding produced by
                # self._output_action_embedder when the last action is a 'global' action.
                action_embedding,
                attended_question[group_index],  # updated_state['attended_question']
                state.rnn_state[group_index].encoder_outputs,
                state.rnn_state[group_index].encoder_output_mask,
                decoder_outputs)
            for i, _, current_log_probs, _, actions, lsq, lsp in batch_action_probs[batch_index]:
                if i == group_index:
                    considered_actions = actions
                    probabilities = current_log_probs.exp().cpu()
                    considered_lsq = lsq
                    considered_lsp = lsp
                    break
                else:
                    # Entries for other group indices are skipped.  (An `assert False` used to
                    # be here, but it turned out to be reachable after training.)
                    pass
            return state.new_state_from_group_index(
                group_index, action, new_score, new_rnn_state,
                considered_actions, probabilities,
                updated_rnn_state['attention_weights'], considered_lsq,
                considered_lsp)
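# A small illustrative sketch (not the parser's code) of how `new_score` accumulates: the
# score of a partial action sequence is the running sum of per-step log probabilities, so
# it is the log probability of the whole prefix, and the training loss is its negation.
import torch

step_log_probs = [torch.log(torch.tensor(p)) for p in (0.5, 0.25, 0.8)]
score = torch.tensor(0.0)
for log_prob in step_log_probs:
    score = score + log_prob  # new_score = state.score[group_index] + current_log_probs
assert torch.isclose(score.exp(), torch.tensor(0.5 * 0.25 * 0.8))
loss = -score                 # maximizing sequence probability = minimizing -score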
        def make_state(
                group_index: int, action: int, new_score: torch.Tensor,
                action_input_embedding: torch.Tensor,
                action_output_embedding: torch.Tensor) -> GrammarBasedState:
            batch_index = state.batch_indices[group_index]

            decoder_outputs = state.rnn_state[group_index].decoder_outputs
            is_linked_action = not state.possible_actions[batch_index][action][1]
            if is_linked_action:
                if decoder_outputs is None:
                    decoder_outputs = (hidden_state[group_index].unsqueeze(0), [action])
                else:
                    decoder_outputs_states, decoder_outputs_ids = decoder_outputs
                    decoder_outputs = (
                        torch.cat((decoder_outputs_states,
                                   hidden_state[group_index].unsqueeze(0)), dim=0),
                        decoder_outputs_ids + [action])

            # Temporal encoding for both input and output action embeddings
            decoding_step = state.rnn_state[group_index].decoding_step + 1
            decoder_input_action_embeddings = state.rnn_state[
                group_index].decoder_input_action_embeddings
            if decoder_input_action_embeddings is None:
                decoder_input_action_embeddings = (
                    action_input_embedding + self.A_t(
                        torch.tensor(decoding_step,
                                     dtype=torch.long,
                                     device=action_input_embedding.device))
                ).unsqueeze(0)
            else:
                decoder_input_action_embeddings = torch.cat(
                    (decoder_input_action_embeddings,
                     (action_input_embedding + self.A_t(
                         torch.tensor(decoding_step,
                                      dtype=torch.long,
                                      device=action_input_embedding.device))
                      ).unsqueeze(0)),
                    dim=0)

            decoder_output_action_embeddings = state.rnn_state[
                group_index].decoder_output_action_embeddings
            if decoder_output_action_embeddings is None:
                decoder_output_action_embeddings = (
                    action_output_embedding + self.B_t(
                        torch.tensor(decoding_step,
                                     dtype=torch.long,
                                     device=action_output_embedding.device))
                ).unsqueeze(0)
            else:
                decoder_output_action_embeddings = torch.cat(
                    (decoder_output_action_embeddings,
                     (action_output_embedding + self.B_t(
                         torch.tensor(decoding_step,
                                      dtype=torch.long,
                                      device=action_output_embedding.device))
                      ).unsqueeze(0)),
                    dim=0)

            new_rnn_state = RnnStatelet(
                hidden_state[group_index], memory_cell[group_index],
                action_output_embedding, attended_question[group_index],
                state.rnn_state[group_index].encoder_outputs,
                state.rnn_state[group_index].encoder_output_mask,
                decoding_step, decoder_outputs,
                decoder_input_action_embeddings,
                decoder_output_action_embeddings)
            for i, _, current_log_probs, _, _, actions, lsq, lsp in batch_action_probs[
                    batch_index]:
                if i == group_index:
                    considered_actions = actions
                    probabilities = current_log_probs.exp().cpu()
                    considered_lsq = lsq
                    considered_lsp = lsp
                    break
            return state.new_state_from_group_index(
                group_index, action, new_score, new_rnn_state,
                considered_actions, probabilities,
                updated_rnn_state['attention_weights'], considered_lsq,
                considered_lsp)
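# A minimal sketch (assumed module names; not the original A_t / B_t parameters) of the
# temporal encoding used above: a learned per-step embedding is added to each action
# embedding before it is appended to the running sequence of decoder action embeddings.
import torch
from torch import nn

action_dim, max_steps = 16, 50
step_embedding = nn.Embedding(max_steps, action_dim)  # plays the role of A_t / B_t

def append_with_step(sequence, action_embedding, decoding_step):
    """Add the step embedding and concatenate onto the (num_steps, action_dim) sequence."""
    step = torch.tensor(decoding_step, dtype=torch.long)
    encoded = (action_embedding + step_embedding(step)).unsqueeze(0)
    return encoded if sequence is None else torch.cat((sequence, encoded), dim=0)

sequence = None
for t in range(1, 4):
    sequence = append_with_step(sequence, torch.randn(action_dim), t)
assert sequence.shape == (3, action_dim)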
Example no. 6
    def _get_initial_state(self,
                           utterance: Dict[str, torch.LongTensor],
                           worlds: List[SpiderWorld],
                           schema: Dict[str, torch.LongTensor],
                           actions: List[List[ProductionRule]]) -> GrammarBasedState:
        device = utterance['tokens'].device
        batch_size, num_entities, _, _ = schema['linking'].size()
        num_entity_tokens = max([max(len(t) for t in world.db_context.entity_tokens) for world in worlds])

        # Hack the token type ids: the input contains multiple chunks (utterance plus entity
        # token chunks), but all of the entity-token chunks should share the same segment type.
        utterance['tokens-type-ids'][utterance['tokens-type-ids'] > 1] = 1
        utterance['tokens-offsets'] = torch.cat((utterance['tokens-offsets'].new_zeros((batch_size, 1)), utterance['tokens-offsets']), dim=-1)

        u_lens = [len(world.db_context.tokenized_utterance) for world in worlds]
        num_question_tokens = max(u_lens)

        # Reset the segment id of separator tokens (id 102 is [SEP] in the standard BERT vocab).
        for i in range(batch_size):
            for j, t in enumerate(utterance['tokens'][i]):
                if t == 102:
                    utterance['tokens-type-ids'][i, j] = 0

        combined_embedding = self._question_embedder(utterance)
        u_embeddings = []
        u_e_embeddings = []
        embedding_padding_fn = lambda:torch.zeros(self._bert_embedding_dim, device=device)
        for i, world in enumerate(worlds):
            u_len = u_lens[i]
            u_embedding = combined_embedding[i][1:u_len+1]
            u_embeddings.append(torch.cat([u_embedding, torch.zeros((num_question_tokens-u_len, self._bert_embedding_dim), device=device)]))

            e_embeddings = []
            pos = u_len + 2
            for t in world.db_context.entity_tokens:
                e_embeddings.append(torch.cat([combined_embedding[i][pos:pos+len(t)], 
                    torch.zeros((num_entity_tokens-len(t), self._bert_embedding_dim), device=device)]))
                pos += len(t) + 1
            u_e_embeddings.append(torch.stack(pad_sequence_to_length(e_embeddings, desired_length=num_entities, 
                default_value=lambda:torch.zeros((num_entity_tokens, self._bert_embedding_dim), device=device))))
        
        embedded_schema = torch.stack(u_e_embeddings)
        schema_mask = (embedded_schema.sum(dim=-1) != 0).float()
        embedded_utterance = torch.stack(u_embeddings)
        utterance_mask = util.get_mask_from_sequence_lengths(torch.as_tensor(u_lens, device=device).long(), max_length=max(u_lens))

        # entity_types: tensor with shape (batch_size, num_entities), where each entry is the
        # entity's type id.
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(worlds, num_entities, embedded_schema.device)

        entity_type_embeddings = self._entity_type_encoder_embedding(entity_types)

        # Compute entity and question word similarity.  We tried using cosine distance here, but
        # because this similarity is the main mechanism that the model can use to push apart logit
        # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger
        # output range than [-1, 1].
        embedded_schema = self._entity_encoder(embedded_schema, schema_mask)

        entity_question_similarity = self._embedding_sim_attn(embedded_schema, embedded_utterance)

        feature_scores = self._linking_params(schema['linking']).squeeze(3)

        # (batch_size, num_entities, num_question_tokens)
        linking_scores = entity_question_similarity * 10 + feature_scores

        # (batch_size, num_question_tokens, num_entities)
        linking_prob_by_type = self._get_linking_probabilities(worlds, linking_scores.transpose(1, 2),
                                                                utterance_mask, entity_type_dict)

        entity_embeddings = entity_type_embeddings

        # (batch_size, utterance_length, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(embedded_utterance, utterance_mask))

        # compute the relevance of each entity with the relevance GNN
        ent_relevance, ent_relevance_logits, ent_to_qst_lnk_probs = self._graph_pruning(worlds,
                                                                                        encoder_outputs,
                                                                                        entity_embeddings,
                                                                                        linking_scores,
                                                                                        utterance_mask,
                                                                                        self._get_graph_adj_lists)
        # save this for loss calculation
        self.predicted_relevance_logits = ent_relevance_logits

        # multiply the embedding with the computed relevance
        graph_initial_embedding = entity_embeddings * ent_relevance

        encoder_output_dim = self._encoder.get_output_dim()
        if self._gnn:
            entities_graph_encoding = self._get_schema_graph_encoding(worlds,
                                                                      graph_initial_embedding)
            graph_link_embedding = util.weighted_sum(entities_graph_encoding, linking_prob_by_type)
            encoder_outputs = torch.cat((
                encoder_outputs,
                graph_link_embedding
            ), dim=-1)
            encoder_output_dim = self._action_embedding_dim + self._encoder.get_output_dim()
        else:
            entities_graph_encoding = None

        if self._self_attend:
            linked_actions_linking_scores = torch.stack(
                [self._graph_attention(entities_graph_encoding[:, i], entities_graph_encoding)
                 for i in range(num_entities)]).transpose(0, 1)
        else:
            linked_actions_linking_scores = [None] * batch_size

        # This will be our initial hidden state and memory cell for the decoder LSTM
        # (both are initialized to zeros in this variant).
        final_encoder_output = encoder_outputs.new_zeros(batch_size, encoder_output_dim)
        memory_cell = encoder_outputs.new_zeros(batch_size, encoder_output_dim)
        initial_score = embedded_utterance.data.new_zeros(batch_size)

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, utterance_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        utterance_mask_list = [utterance_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(RnnStatelet(final_encoder_output[i],
                                                 memory_cell[i],
                                                 self._first_action_embedding,
                                                 self._first_attended_utterance,
                                                 encoder_output_list,
                                                 utterance_mask_list))

        initial_grammar_state = [self._create_grammar_state(worlds[i],
                                                            actions[i],
                                                            linking_scores[i],
                                                            linked_actions_linking_scores[i],
                                                            entity_types[i],
                                                            entities_graph_encoding[
                                                                i] if entities_graph_encoding is not None else None)
                                 for i in range(batch_size)]

        initial_sql_state = [SqlState(actions[i], self._parse_sql_on_decoding) for i in range(batch_size)]

        initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),
                                          action_history=[[] for _ in range(batch_size)],
                                          score=initial_score_list,
                                          rnn_state=initial_rnn_state,
                                          grammar_state=initial_grammar_state,
                                          sql_state=initial_sql_state,
                                          possible_actions=actions,
                                          action_entity_mapping=[w.get_action_entity_mapping() for w in worlds])

        return initial_state
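# A simplified, illustrative sketch (toy tensors, assumed segment layout) of the slicing done
# in _get_initial_state above: the wordpiece embedder runs once over
# "[CLS] utterance [SEP] entity_1 [SEP] entity_2 [SEP] ...", and the per-utterance and
# per-entity token embeddings are then sliced back out and padded to common lengths.
import torch

dim = 8
utterance_len = 5            # number of utterance wordpieces
entity_token_lens = [2, 3]   # wordpieces per schema entity
max_entity_tokens = max(entity_token_lens)

total_len = 1 + utterance_len + 1 + sum(l + 1 for l in entity_token_lens)
combined = torch.randn(total_len, dim)  # output of the wordpiece embedder

# Utterance tokens sit right after [CLS].
utterance_embedding = combined[1:1 + utterance_len]

# Entity chunks start after "utterance [SEP]", each followed by its own [SEP].
entity_embeddings = []
pos = utterance_len + 2
for length in entity_token_lens:
    chunk = combined[pos:pos + length]
    padding = torch.zeros(max_entity_tokens - length, dim)
    entity_embeddings.append(torch.cat([chunk, padding], dim=0))
    pos += length + 1
entity_embeddings = torch.stack(entity_embeddings)  # (num_entities, max_entity_tokens, dim)
assert entity_embeddings.shape == (2, 3, 8)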