    def pad_token_sequence(
            self, tokens: Dict[str, List[List[int]]],
            desired_num_tokens: Dict[str, int],
            padding_lengths: Dict[str, int]) -> Dict[str, List[List[int]]]:
        # Pad the tokens.
        # tokens has only one key...
        key = list(tokens.keys())[0]

        padded_tokens = pad_sequence_to_length(
            tokens[key],
            desired_num_tokens[key],
            default_value=self.get_padding_token)

        # Pad the characters within the tokens.
        desired_token_length = padding_lengths['num_token_characters']
        longest_token: List[int] = max(tokens[key], key=len, default=[])
        padding_value = 0
        if desired_token_length > len(longest_token):
            # Since we want to pad to greater than the longest token, we add a
            # "dummy token" so we can take advantage of the fast implementation of itertools.zip_longest.
            padded_tokens.append([padding_value] * desired_token_length)
        # Pad the list of lists to the longest sublist, appending 0s.
        padded_tokens = list(
            zip(*itertools.zip_longest(*padded_tokens,
                                       fillvalue=padding_value)))
        if desired_token_length > len(longest_token):
            # Removes the "dummy token".
            padded_tokens.pop()
        # Truncate all the tokens to the desired length and return the result.
        return {
            key:
            [list(token[:desired_token_length]) for token in padded_tokens]
        }
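The dummy-row trick above is easy to miss, so here is a minimal standalone sketch of the same idea (not the indexer itself; the helper name is made up): pad a ragged list of character-id lists to a fixed width with itertools.zip_longest.

import itertools

def pad_char_ids(rows, desired_width, pad_value=0):
    # Append a dummy row of the desired width so that zip_longest pads
    # every real row out to at least that width.
    needs_dummy = desired_width > max((len(row) for row in rows), default=0)
    if needs_dummy:
        rows = rows + [[pad_value] * desired_width]
    padded = list(zip(*itertools.zip_longest(*rows, fillvalue=pad_value)))
    if needs_dummy:
        padded.pop()  # drop the dummy row again
    # Truncate in case some row was longer than desired_width.
    return [list(row[:desired_width]) for row in padded]

# pad_char_ids([[3, 1], [5]], 4) -> [[3, 1, 0, 0], [5, 0, 0, 0]]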
    def as_tensor(self,
                  padding_lengths: Dict[str, int],
                  cuda_device: int = -1) -> torch.Tensor:
        desired_num_tokens = padding_lengths['num_tokens']
        padded_tags = pad_sequence_to_length(self._indexed_labels,
                                             desired_num_tokens)
        tensor = torch.LongTensor(padded_tags)
        return tensor if cuda_device == -1 else tensor.cuda(cuda_device)
Example #3
    def pad_token_sequence(
            self, tokens: Dict[str, List[int]],
            desired_num_tokens: Dict[str, int],
            padding_lengths: Dict[str, int]
    ) -> Dict[str, List[int]]:  # pylint: disable=unused-argument
        return {
            key: pad_sequence_to_length(val, desired_num_tokens[key])
            for key, val in tokens.items()
        }
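As a quick illustration of the dictionary-valued contract above, a hedged sketch with a local stand-in for the padding helper (the real function, used throughout these examples as util.pad_sequence_to_length / common_util.pad_sequence_to_length, also accepts a default_value callable):

def pad_to_length(sequence, desired_length, default=0):
    # Stand-in for pad_sequence_to_length: right-pad with `default`, then truncate.
    return (list(sequence) + [default] * desired_length)[:desired_length]

tokens = {'tokens': [4, 8, 2]}
desired_num_tokens = {'tokens': 5}
padded = {key: pad_to_length(val, desired_num_tokens[key]) for key, val in tokens.items()}
# padded == {'tokens': [4, 8, 2, 0, 0]}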
Example #4
    def _get_neighbor_indices(worlds: List[WikiTablesWorld], num_entities: int,
                              tensor: torch.Tensor) -> torch.LongTensor:
        """
        This method returns the indices of each entity's neighbors. A tensor
        is accepted as a parameter for copying purposes.

        Parameters
        ----------
        worlds : ``List[WikiTablesWorld]``
        num_entities : ``int``
        tensor : ``torch.Tensor``
            Used for copying the constructed list onto the right device.

        Returns
        -------
        A ``torch.LongTensor`` with shape ``(batch_size, num_entities, num_neighbors)``. It is padded
        with -1 instead of 0, since 0 is a valid neighbor index.
        """

        num_neighbors = 0
        for world in worlds:
            for entity in world.table_graph.entities:
                if len(world.table_graph.neighbors[entity]) > num_neighbors:
                    num_neighbors = len(world.table_graph.neighbors[entity])

        batch_neighbors = []
        for world in worlds:
            # Each batch instance has its own world, which has a corresponding table.
            entities = world.table_graph.entities
            entity2index = {entity: i for i, entity in enumerate(entities)}
            entity2neighbors = world.table_graph.neighbors
            neighbor_indexes = []
            for entity in entities:
                entity_neighbors = [
                    entity2index[n] for n in entity2neighbors[entity]
                ]
                # Pad with -1 instead of 0, since 0 represents a neighbor index.
                padded = pad_sequence_to_length(entity_neighbors,
                                                num_neighbors, lambda: -1)
                neighbor_indexes.append(padded)
            neighbor_indexes = pad_sequence_to_length(
                neighbor_indexes, num_entities, lambda: [-1] * num_neighbors)
            batch_neighbors.append(neighbor_indexes)
        return tensor.new_tensor(batch_neighbors, dtype=torch.long)
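To make the double padding above concrete, here is a toy version with a made-up three-entity graph: each neighbor list is padded to num_neighbors and the entity list to num_entities, with -1 as the fill value since 0 is a valid neighbor index.

def pad(sequence, desired_length, default):
    # Simplified stand-in for pad_sequence_to_length with a constant default.
    return (list(sequence) + [default] * desired_length)[:desired_length]

neighbors = [[1, 2], [0], [0]]     # entity 0 borders entities 1 and 2, which each border 0
num_neighbors, num_entities = 2, 4

rows = [pad(entity_neighbors, num_neighbors, -1) for entity_neighbors in neighbors]
rows = pad(rows, num_entities, [-1] * num_neighbors)
# rows == [[1, 2], [0, -1], [0, -1], [-1, -1]]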
Example #5
    def as_tensor(self,
                  padding_lengths: Dict[str, int],
                  cuda_device: int = -1) -> Dict[str, torch.Tensor]:
        tensors = {}
        desired_num_entities = padding_lengths['num_entities']
        desired_num_entity_tokens = padding_lengths['num_entity_tokens']
        desired_num_utterance_tokens = padding_lengths['num_utterance_tokens']
        for indexer_name, indexer in self._token_indexers.items():
            padded_entities = util.pad_sequence_to_length(
                self._indexed_entity_texts[indexer_name],
                desired_num_entities,
                default_value=lambda: [])
            padded_arrays = []
            for padded_entity in padded_entities:
                padded_array = indexer.pad_token_sequence(
                    {'key': padded_entity}, {'key': desired_num_entity_tokens},
                    padding_lengths)['key']
                padded_arrays.append(padded_array)
            tensor = torch.LongTensor(padded_arrays)
            tensors[indexer_name] = tensor if cuda_device == -1 else tensor.cuda(cuda_device)
        padded_linking_features = util.pad_sequence_to_length(
            self.linking_features,
            desired_num_entities,
            default_value=lambda: [])
        padded_linking_arrays = []
        default_feature_value = lambda: [0.0] * len(self._feature_extractors)
        for linking_features in padded_linking_features:
            padded_features = util.pad_sequence_to_length(
                linking_features,
                desired_num_utterance_tokens,
                default_value=default_feature_value)
            padded_linking_arrays.append(padded_features)
        linking_features_tensor = torch.FloatTensor(padded_linking_arrays)
        if cuda_device != -1:
            linking_features_tensor = linking_features_tensor.cuda(cuda_device)
        return {'text': tensors, 'linking': linking_features_tensor}
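The nested padding above (first pad the entity list with empty token lists, then pad each entity's token ids) can be seen in miniature below; the indexer call is replaced by a plain per-list pad, purely for illustration.

def pad(sequence, desired_length, default_value):
    # Simplified stand-in for pad_sequence_to_length with a callable default.
    return (list(sequence) + [default_value() for _ in range(desired_length)])[:desired_length]

entity_texts = [[7, 3], [9]]                     # token ids for two entities
padded_entities = pad(entity_texts, 3, list)     # -> [[7, 3], [9], []]
padded_arrays = [pad(tokens, 4, lambda: 0) for tokens in padded_entities]
# padded_arrays == [[7, 3, 0, 0], [9, 0, 0, 0], [0, 0, 0, 0]]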
Example #6
    def _get_type_vector(
            worlds: List[WikiTablesWorld], num_entities: int,
            tensor: torch.Tensor) -> Tuple[torch.LongTensor, Dict[int, int]]:
        """
        Produces a one-hot encoding for each entity's type. In addition,
        a map from each flattened entity index to its type is returned, so that
        entity type operations can be combined into one method.

        Parameters
        ----------
        worlds : ``List[WikiTablesWorld]``
        num_entities : ``int``
        tensor : ``torch.Tensor``
            Used for copying the constructed list onto the right device.

        Returns
        -------
        A ``torch.LongTensor`` with shape ``(batch_size, num_entities, num_types)``.
        entity_types : ``Dict[int, int]``
            This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.
        """
        entity_types = {}
        batch_types = []
        for batch_index, world in enumerate(worlds):
            types = []
            for entity_index, entity in enumerate(world.table_graph.entities):
                one_hot_vectors = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0],
                                   [0, 0, 0, 1]]
                # We need numbers to be first, then cells, then parts, then row, because our
                # entities are going to be sorted.  We do a split by type and then a merge later,
                # and it relies on this sorting.
                if entity.startswith('fb:cell'):
                    entity_type = 1
                elif entity.startswith('fb:part'):
                    entity_type = 2
                elif entity.startswith('fb:row'):
                    entity_type = 3
                else:
                    entity_type = 0
                types.append(one_hot_vectors[entity_type])

                # For easier lookups later, we're actually using a _flattened_ version
                # of (batch_index, entity_index) for the key, because this is how the
                # linking scores are stored.
                flattened_entity_index = batch_index * num_entities + entity_index
                entity_types[flattened_entity_index] = entity_type
            padded = pad_sequence_to_length(types, num_entities,
                                            lambda: [0, 0, 0, 0])
            batch_types.append(padded)
        return tensor.new_tensor(batch_types), entity_types
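Stated compactly, the entity-type logic above maps a name prefix to a type id (0 = number, 1 = cell, 2 = part, 3 = row) and then looks up a one-hot row; a small sketch with invented entity names:

ONE_HOT = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]

def entity_type_id(entity: str) -> int:
    # Numbers come first, then cells, parts and rows, matching the sort
    # order the later split-and-merge relies on.
    if entity.startswith('fb:cell'):
        return 1
    if entity.startswith('fb:part'):
        return 2
    if entity.startswith('fb:row'):
        return 3
    return 0  # anything else (e.g. a number like '-1') is type 0

types = [ONE_HOT[entity_type_id(e)] for e in ['-1', 'fb:cell.paris', 'fb:row.row.year']]
# types == [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]]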
Example #7
    def as_tensor(self,
                  padding_lengths: Dict[str, int],
                  cuda_device: int = -1) -> DataArray:
        padded_field_list = pad_sequence_to_length(
            self.field_list, padding_lengths['num_fields'],
            self.field_list[0].empty_field)
        # Here we're removing the scoping on the padding length keys that we added in
        # `get_padding_lengths`; see the note there for more detail.
        child_padding_lengths = {
            key.replace('list_', '', 1): value
            for key, value in padding_lengths.items()
            if key.startswith('list_')
        }
        padded_fields = [
            field.as_tensor(child_padding_lengths, cuda_device)
            for field in padded_field_list
        ]
        return self.field_list[0].batch_tensors(padded_fields)
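For instance (keys invented), the de-scoping comprehension above turns the list-prefixed padding keys back into the child field's own keys:

padding_lengths = {'num_fields': 3, 'list_num_tokens': 7, 'list_num_token_characters': 4}
child_padding_lengths = {
    key.replace('list_', '', 1): value
    for key, value in padding_lengths.items()
    if key.startswith('list_')
}
# child_padding_lengths == {'num_tokens': 7, 'num_token_characters': 4}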
Example #8
    def _get_action_embeddings(state: NlvrDecoderState,
                               actions_to_embed: List[List[int]]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        This method is identical to ``WikiTablesDecoderStep._get_action_embeddings``.
        It returns an embedded representation for all actions in ``actions_to_embed``, using the
        state in ``NlvrDecoderState``.

        Parameters
        ----------
        state : ``NlvrDecoderState``
            The current state.  We'll use this to get the global action embeddings.
        actions_to_embed : ``List[List[int]]``
            A list of _global_ action indices for each group element.  Should have shape
            (group_size, num_actions), unpadded.

        Returns
        -------
        action_embeddings : ``torch.FloatTensor``
            An embedded representation of all of the given actions.  Shape is ``(group_size,
            num_actions, action_embedding_dim)``, where ``num_actions`` is the maximum number of
            considered actions for any group element.
        action_mask : ``torch.LongTensor``
            A mask of shape ``(group_size, num_actions)`` indicating which ``(group_index,
            action_index)`` pairs were merely added as padding.
        """
        num_actions = [len(action_list) for action_list in actions_to_embed]
        max_num_actions = max(num_actions)
        padded_actions = [common_util.pad_sequence_to_length(action_list, max_num_actions)
                          for action_list in actions_to_embed]
        # Shape: (group_size, num_actions)
        action_tensor = state.score[0].new_tensor(padded_actions, dtype=torch.long)
        # `state.action_embeddings` is shape (total_num_actions, action_embedding_dim).
        # We want to select from state.action_embeddings using `action_tensor` to get a tensor of
        # shape (group_size, num_actions, action_embedding_dim).  Unfortunately, the index_select
        # functions in nn.util don't do this operation.  So we'll do some reshapes and do the
        # index_select ourselves.
        group_size = len(state.batch_indices)
        action_embedding_dim = state.action_embeddings.size(-1)
        flattened_actions = action_tensor.view(-1)
        flattened_action_embeddings = state.action_embeddings.index_select(0, flattened_actions)
        action_embeddings = flattened_action_embeddings.view(group_size, max_num_actions, action_embedding_dim)
        sequence_lengths = action_embeddings.new_tensor(num_actions)
        action_mask = nn_util.get_mask_from_sequence_lengths(sequence_lengths, max_num_actions)
        return action_embeddings, action_mask
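The flatten / index_select / view trick above is a general way to gather rows of an embedding matrix with a 2-D index tensor; a self-contained sketch with arbitrary shapes (all names here are illustrative):

import torch

group_size, num_actions, embedding_dim, total_num_actions = 2, 3, 4, 10
action_embeddings = torch.randn(total_num_actions, embedding_dim)
action_tensor = torch.tensor([[0, 3, 5], [7, 2, 0]])        # (group_size, num_actions)

flattened_actions = action_tensor.view(-1)                  # (group_size * num_actions,)
flattened_embeddings = action_embeddings.index_select(0, flattened_actions)
grouped = flattened_embeddings.view(group_size, num_actions, embedding_dim)
assert torch.equal(grouped[1, 0], action_embeddings[7])

Plain advanced indexing (action_embeddings[action_tensor]) would produce the same tensor; the snippets spell the steps out because the same reshape pattern is reused for output embeddings and biases in the later examples.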
Example #9
    def test_pad_sequence_to_length(self):
        assert util.pad_sequence_to_length([1, 2, 3], 5) == [1, 2, 3, 0, 0]
        assert util.pad_sequence_to_length(
            [1, 2, 3], 5, default_value=lambda: 2) == [1, 2, 3, 2, 2]
        assert util.pad_sequence_to_length(
            [1, 2, 3], 5, padding_on_right=False) == [0, 0, 1, 2, 3]
Example #10
    def _get_entity_action_logits(
        self,
        state: WikiTablesDecoderState,
        actions_to_link: List[List[int]],
        attention_weights: torch.Tensor,
        linked_checklist_balance: torch.Tensor = None
    ) -> Tuple[torch.FloatTensor, torch.LongTensor, torch.FloatTensor]:
        """
        Returns scores for each action in ``actions_to_link`` that are derived from the linking
        scores between the question and the table entities, and the current attention on the
        question.  The intuition is that if we're paying attention to a particular word in the
        question, we should tend to select entity productions that we think that word refers to.
        We additionally return a mask representing which elements in the returned ``action_logits``
        tensor are just padding, and an embedded representation of each action that can be used as
        input to the next step of the encoder.  That embedded representation is derived from the
        type of the entity produced by the action.

        The ``actions_to_link`` are in terms of the `batch` action list passed to
        ``model.forward()``.  We need to convert these integers into indices into the linking score
        tensor, which has shape (batch_size, num_entities, num_question_tokens), look up the
        linking score for each entity, then aggregate the scores using the current question
        attention.

        Parameters
        ----------
        state : ``WikiTablesDecoderState``
            The current state.  We'll use this to get the linking scores.
        actions_to_link : ``List[List[int]]``
            A list of _batch_ action indices for each group element.  Should have shape
            (group_size, num_actions), unpadded.  This is expected to be output from
            :func:`_get_actions_to_consider`.
        attention_weights : ``torch.Tensor``
            The current attention weights over the question tokens.  Should have shape
            ``(group_size, num_question_tokens)``.
        linked_checklist_balance : ``torch.Tensor``, optional (default=None)
            If the parser is being trained to maximize coverage over an agenda, this is the balance
            vector corresponding to entity actions, containing 1s and 0s, with 1s showing the
            actions that are yet to be produced. Required only if the parser is being trained to
            maximize coverage.

        Returns
        -------
        action_logits : ``torch.FloatTensor``
            A score for each of the given actions.  Shape is ``(group_size, num_actions)``, where
            ``num_actions`` is the maximum number of considered actions for any group element.
        action_mask : ``torch.LongTensor``
            A mask of shape ``(group_size, num_actions)`` indicating which ``(group_index,
            action_index)`` pairs were merely added as padding.
        type_embeddings : ``torch.FloatTensor``
            A tensor of shape ``(group_size, num_actions, action_embedding_dim)``, with an embedded
            representation of the `type` of the entity corresponding to each action.
        """
        # First we map the actions to entity indices, using state.actions_to_entities, and find the
        # type of each entity using state.entity_types.
        action_entities: List[List[int]] = []
        entity_types: List[List[int]] = []
        for batch_index, action_list in zip(state.batch_indices,
                                            actions_to_link):
            action_entities.append([])
            entity_types.append([])
            for action_index in action_list:
                entity_index = state.actions_to_entities[(batch_index,
                                                          action_index)]
                action_entities[-1].append(entity_index)
                entity_types[-1].append(state.entity_types[entity_index])

        # Then we create a padded tensor suitable for use with
        # `state.flattened_linking_scores.index_select()`.
        num_actions = [len(action_list) for action_list in action_entities]
        max_num_actions = max(num_actions)
        padded_actions = [
            common_util.pad_sequence_to_length(action_list, max_num_actions)
            for action_list in action_entities
        ]
        padded_types = [
            common_util.pad_sequence_to_length(type_list, max_num_actions)
            for type_list in entity_types
        ]
        # Shape: (group_size, num_actions)
        action_tensor = state.score[0].new_tensor(padded_actions,
                                                  dtype=torch.long)
        type_tensor = state.score[0].new_tensor(padded_types, dtype=torch.long)

        # To get the type embedding tensor, we just use an embedding matrix on the list of entity
        # types.
        type_embeddings = self._entity_type_embedding(type_tensor)
        # `state.flattened_linking_scores` is shape (batch_size * num_entities, num_question_tokens).
        # We want to select from this using `action_tensor` to get a tensor of shape (group_size,
        # num_actions, num_question_tokens).  Unfortunately, the index_select functions in nn.util
        # don't do this operation.  So we'll do some reshapes and do the index_select ourselves.
        group_size = len(state.batch_indices)
        num_question_tokens = state.flattened_linking_scores.size(-1)
        flattened_actions = action_tensor.view(-1)
        # (group_size * num_actions, num_question_tokens)
        flattened_action_linking = state.flattened_linking_scores.index_select(
            0, flattened_actions)
        # (group_size, num_actions, num_question_tokens)
        action_linking = flattened_action_linking.view(group_size,
                                                       max_num_actions,
                                                       num_question_tokens)

        # Now we get action logits by weighting these entity x token scores by the attention over
        # the question tokens.  We can do this efficiently with torch.bmm.
        action_logits = action_linking.bmm(
            attention_weights.unsqueeze(-1)).squeeze(-1)
        if linked_checklist_balance is not None:
            # ``linked_checklist_balance`` is a binary tensor of size (group_size, num_actions) with
            # 1s indicating the linked actions that the agenda wants the decoder to produce, but
            # haven't been produced yet. We're simply doubling the logits of those actions here.
            action_logits_addition = action_logits * linked_checklist_balance
            action_logits = action_logits + self._linked_checklist_multiplier * action_logits_addition
        # Finally, we make a mask for our action logit tensor.
        sequence_lengths = action_linking.new_tensor(num_actions)
        action_mask = util.get_mask_from_sequence_lengths(
            sequence_lengths, max_num_actions)
        return action_logits, action_mask, type_embeddings
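The bmm near the end is just a batched weighted sum of each entity's linking scores under the question attention; a tiny numeric sketch (shapes invented):

import torch

group_size, num_actions, num_question_tokens = 2, 4, 5
action_linking = torch.randn(group_size, num_actions, num_question_tokens)
attention_weights = torch.softmax(torch.randn(group_size, num_question_tokens), dim=-1)

# (group, actions, tokens) x (group, tokens, 1) -> (group, actions, 1) -> (group, actions)
action_logits = action_linking.bmm(attention_weights.unsqueeze(-1)).squeeze(-1)
expected = (action_linking * attention_weights.unsqueeze(1)).sum(-1)
assert torch.allclose(action_logits, expected, atol=1e-5)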
Example #11
    def _get_action_embeddings(
        state: WikiTablesDecoderState, actions_to_embed: List[List[int]]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Returns an embedded representation for all actions in ``actions_to_embed``, using the state
        in ``WikiTablesDecoderState``.

        Parameters
        ----------
        state : ``WikiTablesDecoderState``
            The current state.  We'll use this to get the global action embeddings.
        actions_to_embed : ``List[List[int]]``
            A list of _global_ action indices for each group element.  Should have shape
            (group_size, num_actions), unpadded.  This is expected to be output from
            :func:`_get_actions_to_consider`.

        Returns
        -------
        action_embeddings : ``torch.FloatTensor``
            An embedded representation of all of the given actions.  Shape is ``(group_size,
            num_actions, action_embedding_dim)``, where ``num_actions`` is the maximum number of
            considered actions for any group element.
        output_action_embeddings : ``torch.FloatTensor``
            A second embedded representation of all of the given actions.  The first is used when
            selecting actions, the second is used as the decoder output (which is the input at the
            next timestep).  This is similar to having separate word embeddings and softmax layer
            weights in a language model or MT model.
        action_biases : ``torch.FloatTensor``
            A bias weight for predicting each action.  Shape is ``(group_size, num_actions, 1)``.
        action_mask : ``torch.LongTensor``
            A mask of shape ``(group_size, num_actions)`` indicating which ``(group_index,
            action_index)`` pairs were merely added as padding.
        """
        num_actions = [len(action_list) for action_list in actions_to_embed]
        max_num_actions = max(num_actions)
        padded_actions = [
            common_util.pad_sequence_to_length(action_list, max_num_actions)
            for action_list in actions_to_embed
        ]
        # Shape: (group_size, num_actions)
        action_tensor = state.score[0].new_tensor(padded_actions,
                                                  dtype=torch.long)
        # `state.action_embeddings` is shape (total_num_actions, action_embedding_dim).
        # We want to select from state.action_embeddings using `action_tensor` to get a tensor of
        # shape (group_size, num_actions, action_embedding_dim).  Unfortunately, the index_select
        # functions in nn.util don't do this operation.  So we'll do some reshapes and do the
        # index_select ourselves.
        group_size = len(state.batch_indices)
        action_embedding_dim = state.action_embeddings.size(-1)

        flattened_actions = action_tensor.view(-1)
        flattened_action_embeddings = state.action_embeddings.index_select(
            0, flattened_actions)
        action_embeddings = flattened_action_embeddings.view(
            group_size, max_num_actions, action_embedding_dim)

        flattened_output_embeddings = state.output_action_embeddings.index_select(
            0, flattened_actions)
        output_embeddings = flattened_output_embeddings.view(
            group_size, max_num_actions, action_embedding_dim)

        flattened_biases = state.action_biases.index_select(
            0, flattened_actions)
        biases = flattened_biases.view(group_size, max_num_actions, 1)

        sequence_lengths = action_embeddings.new_tensor(num_actions)
        action_mask = util.get_mask_from_sequence_lengths(
            sequence_lengths, max_num_actions)
        return action_embeddings, output_embeddings, biases, action_mask
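The mask built at the end via get_mask_from_sequence_lengths simply records which of the max_num_actions slots hold real (unpadded) actions; an equivalent hand-rolled sketch:

import torch

num_actions = [3, 1, 2]                 # unpadded action counts per group element
max_num_actions = max(num_actions)
lengths = torch.tensor(num_actions)
positions = torch.arange(max_num_actions).unsqueeze(0)       # (1, max_num_actions)
action_mask = (positions < lengths.unsqueeze(1)).long()
# action_mask == tensor([[1, 1, 1],
#                        [1, 0, 0],
#                        [1, 1, 0]])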
Example #12
    def _get_checklist_balance(
        state: WikiTablesDecoderState, unlinked_terminal_indices: List[int],
        actions_to_link: List[List[int]]
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        # This holds a list of checklist balances for this state. Each balance is a float vector
        # containing just 1s and 0s showing which of the items are filled. We clamp the min at 0
        # to ignore the number of times an action is taken. The value at an index will be 1 iff
        # the target wants an unmasked action to be taken, and it is not yet taken. All elements
        # in each balance corresponding to masked actions will be 0.
        checklist_balances = []
        for instance_checklist_state in state.checklist_state:
            checklist_balance = torch.clamp(
                instance_checklist_state.get_balance(), min=0.0)
            checklist_balances.append(checklist_balance)

        checklist_balance = torch.stack(checklist_balances)
        checklist_balance = checklist_balance.squeeze(2)  # (group_size, num_terminals)
        # We now need to split the ``checklist_balance`` into two tensors, one corresponding to
        # linked actions and the other to unlinked actions because they affect the output action
        # logits differently. We use ``unlinked_terminal_indices`` and ``actions_to_link`` to do that, but
        # the indices in those lists are indices of all actions, and the checklist balance
        # corresponds only to the terminal actions.
        # To make things more confusing, ``actions_to_link`` has batch action indices, and
        # ``unlinked_terminal_indices`` has global action indices.
        mapped_actions_to_link = []
        mapped_actions_to_embed = []
        # Mapping from batch action indices to checklist indices for each instance in group.
        batch_actions_to_checklist = [
            checklist_state.terminal_indices_dict
            for checklist_state in state.checklist_state
        ]
        for group_index, batch_index in enumerate(state.batch_indices):
            instance_mapped_embedded_actions = []
            for action in unlinked_terminal_indices:
                batch_action_index = state.global_to_batch_action_indices[(batch_index, action)]
                if batch_action_index in batch_actions_to_checklist[group_index]:
                    checklist_index = batch_actions_to_checklist[group_index][batch_action_index]
                else:
                    # This means that the embedded action is not a terminal, because the checklist
                    # indices only correspond to terminal actions.
                    checklist_index = -1
                instance_mapped_embedded_actions.append(checklist_index)
            mapped_actions_to_embed.append(instance_mapped_embedded_actions)
        # We don't need to pad the unlinked actions because they're all currently the
        # same size as ``unlinked_terminal_indices``.
        unlinked_action_indices = checklist_balance.new_tensor(
            mapped_actions_to_embed, dtype=torch.long)
        unlinked_actions_mask = (unlinked_action_indices != -1).long()
        # torch.gather would complain if the indices are -1. So making them all 0 now. We'll use the
        # mask again on the balances.
        unlinked_action_indices = unlinked_action_indices * unlinked_actions_mask

        unlinked_checklist_balance = torch.gather(checklist_balance, 1,
                                                  unlinked_action_indices)
        unlinked_checklist_balance = unlinked_checklist_balance * unlinked_actions_mask.float()
        # If actions_to_link is None, it means that all the valid actions in the current state need
        # to be embedded. We simply return None for checklist balance corresponding to linked
        # actions then.
        linked_checklist_balance = None
        if actions_to_link:
            for group_index, instance_actions_to_link in enumerate(
                    actions_to_link):
                mapped_actions_to_link.append([
                    batch_actions_to_checklist[group_index][action]
                    for action in instance_actions_to_link
                ])
            # We need to pad the linked action indices before we use them to gather appropriate balances.
            # Some of the indices may be 0s. So we need to make the padding index -1.
            max_num_linked_actions = max(
                [len(indices) for indices in mapped_actions_to_link])
            padded_actions_to_link = [
                common_util.pad_sequence_to_length(indices,
                                                   max_num_linked_actions,
                                                   default_value=lambda: -1)
                for indices in mapped_actions_to_link
            ]
            linked_action_indices = checklist_balance.new_tensor(
                padded_actions_to_link, dtype=torch.long)
            linked_actions_mask = (linked_action_indices != -1).long()
            linked_action_indices = linked_action_indices * linked_actions_mask
            linked_checklist_balance = torch.gather(checklist_balance, 1,
                                                    linked_action_indices)
            linked_checklist_balance = linked_checklist_balance * linked_actions_mask.float()
        return linked_checklist_balance, unlinked_checklist_balance
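The pad-with--1, mask, gather pattern used twice above is a standard way to batch ragged index lists when 0 is itself a legitimate index; a standalone sketch with toy numbers:

import torch

checklist_balance = torch.tensor([[1., 0., 1.], [0., 1., 1.]])   # (group_size, num_terminals)
indices = torch.tensor([[2, 0], [1, -1]])                        # -1 marks padding

mask = (indices != -1).long()
safe_indices = indices * mask                  # turn the -1s into 0s so gather doesn't complain
gathered = torch.gather(checklist_balance, 1, safe_indices)
gathered = gathered * mask.float()             # zero the padded slots back out
# gathered == tensor([[1., 1.], [1., 0.]])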
Example #13
    def take_step(self,  # type: ignore
                  state: NlvrDecoderState,
                  max_actions: int = None,
                  allowed_actions: List[Set[int]] = None) -> List[NlvrDecoderState]:
        """
        Given a ``NlvrDecoderState``, returns a list of next states that are sorted by their scores.
        This method is very similar to ``WikiTablesDecoderStep._take_step``. The differences are
        that depending on the type of supervision being used, we may not have a notion of
        "allowed actions" here, and we do not perform entity linking here.
        """
        # Outline here: first we'll construct the input to the decoder, which is a concatenation of
        # an embedding of the decoder input (the last action taken) and an attention over the
        # sentence.  Then we'll update our decoder's hidden state given this input, and recompute
        # an attention over the sentence given our new hidden state.  We'll use a concatenation of
        # the new hidden state, the new attention, and optionally the checklist balance to predict an
        # output, then yield new states. We will compute and use a checklist balance when
        # ``allowed_actions`` is None, with the assumption that the ``DecoderTrainer`` that is
        # calling this method is trying to train a parser without logical form supervision.
        # TODO (pradeep): Make the distinction between the two kinds of trainers in the way they
        # call this method more explicit.

        # Each new state corresponds to one valid action that can be taken from the current state,
        # and they are ordered by model scores.
        attended_sentence = torch.stack([rnn_state.attended_input for rnn_state in state.rnn_state])
        hidden_state = torch.stack([rnn_state.hidden_state for rnn_state in state.rnn_state])
        memory_cell = torch.stack([rnn_state.memory_cell for rnn_state in state.rnn_state])
        previous_action_embedding = torch.stack([rnn_state.previous_action_embedding
                                                 for rnn_state in state.rnn_state])

        # (group_size, decoder_input_dim)
        decoder_input = self._input_projection_layer(torch.cat([attended_sentence,
                                                                previous_action_embedding], -1))
        decoder_input = torch.tanh(decoder_input)
        hidden_state, memory_cell = self._decoder_cell(decoder_input, (hidden_state, memory_cell))

        hidden_state = self._dropout(hidden_state)
        # (group_size, encoder_output_dim)
        encoder_outputs = torch.stack([state.rnn_state[0].encoder_outputs[i] for i in state.batch_indices])
        encoder_output_mask = torch.stack([state.rnn_state[0].encoder_output_mask[i] for i in state.batch_indices])
        attended_sentence = self.attend_on_sentence(hidden_state, encoder_outputs, encoder_output_mask)

        # We get global indices of actions to embed here. The following logic is similar to
        # ``WikiTablesDecoderStep._get_actions_to_consider``, except that we do not have any actions
        # to link.
        valid_actions = state.get_valid_actions()
        global_valid_actions: List[List[Tuple[int, int]]] = []
        for batch_index, valid_action_list in zip(state.batch_indices, valid_actions):
            global_valid_actions.append([])
            for action_index in valid_action_list:
                # state.action_indices is a dictionary that maps (batch_index, batch_action_index)
                # to global_action_index
                global_action_index = state.action_indices[(batch_index, action_index)]
                global_valid_actions[-1].append((global_action_index, action_index))
        global_actions_to_embed: List[List[int]] = []
        local_actions: List[List[int]] = []
        for global_action_list in global_valid_actions:
            global_action_list.sort()
            global_actions_to_embed.append([])
            local_actions.append([])
            for global_action_index, action_index in global_action_list:
                global_actions_to_embed[-1].append(global_action_index)
                local_actions[-1].append(action_index)
        max_num_actions = max([len(action_list) for action_list in global_actions_to_embed])
        # We pad the local actions with -1 to get the considered actions.
        considered_actions = [common_util.pad_sequence_to_length(action_list, max_num_actions,
                                                                 default_value=lambda: -1)
                              for action_list in local_actions]

        # action_embeddings: (group_size, num_embedded_actions, action_embedding_dim)
        # action_mask: (group_size, num_embedded_actions)
        action_embeddings, embedded_action_mask = self._get_action_embeddings(state,
                                                                              global_actions_to_embed)
        action_query = torch.cat([hidden_state, attended_sentence], dim=-1)
        # (group_size, action_embedding_dim)
        predicted_action_embedding = self._output_projection_layer(action_query)
        predicted_action_embedding = self._dropout(torch.tanh(predicted_action_embedding))
        if state.checklist_state[0] is not None:
            embedding_addition = self._get_predicted_embedding_addition(state)
            addition = embedding_addition * self._checklist_embedding_multiplier
            predicted_action_embedding = predicted_action_embedding + addition
        # We'll do a batch dot product here with `bmm`.  We want `dot(predicted_action_embedding,
        # action_embedding)` for each `action_embedding`, and we can get that efficiently with
        # `bmm` and some squeezing.
        # Shape: (group_size, num_embedded_actions)
        action_logits = action_embeddings.bmm(predicted_action_embedding.unsqueeze(-1)).squeeze(-1)

        action_mask = embedded_action_mask.float()
        if state.checklist_state[0] is not None:
            # We will compute the logprobs and the checklists of potential next states together for
            # efficiency.
            logprobs, new_checklist_states = self._get_next_state_info_with_agenda(state,
                                                                                   considered_actions,
                                                                                   action_logits,
                                                                                   action_mask)
        else:
            logprobs = self._get_next_state_info_without_agenda(state,
                                                                considered_actions,
                                                                action_logits,
                                                                action_mask)
            new_checklist_states = None
        return self._compute_new_states(state,
                                        logprobs,
                                        hidden_state,
                                        memory_cell,
                                        action_embeddings,
                                        attended_sentence,
                                        considered_actions,
                                        allowed_actions,
                                        new_checklist_states,
                                        max_actions)
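The action scoring inside take_step is a batched dot product between the predicted action embedding and every candidate action embedding; isolated, it looks like this (shapes invented):

import torch

group_size, num_actions, action_embedding_dim = 2, 3, 4
action_embeddings = torch.randn(group_size, num_actions, action_embedding_dim)
predicted_action_embedding = torch.randn(group_size, action_embedding_dim)

# dot(predicted, candidate) for every candidate, via bmm and a squeeze
action_logits = action_embeddings.bmm(predicted_action_embedding.unsqueeze(-1)).squeeze(-1)
assert torch.allclose(action_logits[0, 1],
                      action_embeddings[0, 1].dot(predicted_action_embedding[0]),
                      atol=1e-5)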