Code example #1
    def _get_initial_state(
        self,
        encoder_outputs: torch.Tensor,
        utterance_mask: torch.Tensor,
        actions: List[ProductionRule],
    ) -> GrammarBasedState:
        batch_size = encoder_outputs.size(0)

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, utterance_mask, self._encoder.is_bidirectional())
        # Use CLS states as final encoder outputs
        # final_encoder_output = encoder_outputs[:,0,:]
        memory_cell = encoder_outputs.new_zeros(batch_size,
                                                encoder_outputs.shape[-1])
        initial_score = encoder_outputs.data.new_zeros(batch_size)
        attended_sentence, _ = self._transition_function.attend_on_question(
            final_encoder_output, encoder_outputs, utterance_mask)

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, utterance_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        utterance_mask_list = [utterance_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            if self._decoder_num_layers > 1:
                encoder_output = final_encoder_output[i].repeat(
                    self._decoder_num_layers, 1)
                cell = memory_cell[i].repeat(self._decoder_num_layers, 1)
            else:
                encoder_output = final_encoder_output[i]
                cell = memory_cell[i]
            initial_rnn_state.append(
                RnnStatelet(
                    encoder_output,
                    cell,
                    self._first_action_embedding,
                    attended_sentence[i],
                    encoder_output_list,
                    utterance_mask_list,
                ))

        initial_grammar_state = [
            self._create_grammar_state(actions[i]) for i in range(batch_size)
        ]

        initial_state = GrammarBasedState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            possible_actions=actions,
            debug_info=[[] for _ in range(batch_size)],
        )
        return initial_state
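
A note on the comment about converting the batch dimension into an outer list, which recurs throughout these examples: the point is that per-instance tensors can later be regrouped with ``torch.stack``/``torch.cat`` instead of ``index_select``. A minimal plain-PyTorch sketch (the shapes here are illustrative assumptions, not values from the example above):

    import torch

    batch_size, utterance_length, encoder_output_dim = 2, 5, 4
    encoder_outputs = torch.randn(batch_size, utterance_length, encoder_output_dim)

    # Each list element has shape (utterance_length, encoder_output_dim), so the
    # decoder can regroup an arbitrary subset of instances by stacking, with no
    # index_select on the batched tensor.
    encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
    regrouped = torch.stack([encoder_output_list[i] for i in (1, 0)])
    assert regrouped.shape == (batch_size, utterance_length, encoder_output_dim)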
Code example #2
    def _get_initial_state(self, utterance: Dict[str, torch.LongTensor],
                           worlds: List[AtisWorld],
                           actions: List[List[ProductionRuleArray]],
                           linking_scores: torch.Tensor) -> GrammarBasedState:
        embedded_utterance = self._utterance_embedder(utterance)
        utterance_mask = util.get_text_field_mask(utterance).float()

        batch_size = embedded_utterance.size(0)
        num_entities = max([len(world.entities) for world in worlds])

        # entity_types: tensor with shape (batch_size, num_entities)
        entity_types, _ = self._get_type_vector(worlds, num_entities,
                                                embedded_utterance)

        # (batch_size, num_utterance_tokens, embedding_dim)
        encoder_input = embedded_utterance

        # (batch_size, utterance_length, encoder_output_dim)
        encoder_outputs = self._dropout(
            self._encoder(encoder_input, utterance_mask))

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, utterance_mask, self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size,
                                                self._encoder.get_output_dim())
        initial_score = embedded_utterance.data.new_zeros(batch_size)

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, utterance_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        utterance_mask_list = [utterance_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(
                RnnStatelet(final_encoder_output[i], memory_cell[i],
                            self._first_action_embedding,
                            self._first_attended_utterance,
                            encoder_output_list, utterance_mask_list))

        initial_grammar_state = [
            self._create_grammar_state(worlds[i], actions[i],
                                       linking_scores[i], entity_types[i])
            for i in range(batch_size)
        ]

        initial_state = GrammarBasedState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            possible_actions=actions,
            debug_info=None)
        return initial_state
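
Both examples so far build the decoder's initial hidden state with ``util.get_final_encoder_states``. The following is a rough standalone sketch of what that helper computes, reconstructed from its documented behavior; treat it as an approximation rather than the library's exact implementation:

    import torch

    def final_encoder_states_sketch(encoder_outputs: torch.Tensor,
                                    mask: torch.Tensor,
                                    bidirectional: bool) -> torch.Tensor:
        # encoder_outputs: (batch_size, seq_length, encoder_output_dim)
        # mask: (batch_size, seq_length), nonzero for real tokens, zero for padding.
        last_indices = mask.long().sum(1) - 1  # last unpadded timestep per instance
        batch_range = torch.arange(encoder_outputs.size(0))
        final_states = encoder_outputs[batch_range, last_indices]
        if bidirectional:
            # Concatenate the forward half of the last timestep with the backward
            # half of the first timestep.
            half = encoder_outputs.size(-1) // 2
            final_states = torch.cat(
                [final_states[:, :half], encoder_outputs[:, 0, half:]], dim=-1)
        return final_states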
Code example #3
 def _check_state_denotations(self, state: GrammarBasedState, worlds: List[NlvrWorld]) -> List[bool]:
     """
     Returns whether the action history in the state evaluates to the correct denotations over all
     worlds. Only defined when the state is finished.
     """
     assert state.is_finished(), "Cannot compute denotations for unfinished states!"
     # Since this is a finished state, its group size must be 1.
     batch_index = state.batch_indices[0]
     instance_label_strings = state.extras[batch_index]
     history = state.action_history[0]
     all_actions = state.possible_actions[0]
     action_sequence = [all_actions[action][0] for action in history]
     return self._check_denotation(action_sequence, instance_label_strings, worlds)
Code example #4
 def _check_state_denotations(self, state: GrammarBasedState,
                              worlds: List[NlvrWorld]) -> List[bool]:
     """
     Returns whether the action history in the state evaluates to the correct denotations over all
     worlds. Only defined when the state is finished.
     """
     assert state.is_finished(
     ), "Cannot compute denotations for unfinished states!"
     # Since this is a finished state, its group size must be 1.
     batch_index = state.batch_indices[0]
     instance_label_strings = state.extras[batch_index]
     history = state.action_history[0]
     all_actions = state.possible_actions[0]
     action_sequence = [all_actions[action][0] for action in history]
     return self._check_denotation(action_sequence, instance_label_strings,
                                   worlds)
Code example #5
    def _get_initial_state(
            self, encoder_outputs: torch.Tensor, mask: torch.Tensor,
            actions: List[List[ProductionRule]]) -> GrammarBasedState:

        batch_size = encoder_outputs.size(0)
        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, mask, self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size,
                                                self._encoder.get_output_dim())
        initial_score = encoder_outputs.data.new_zeros(batch_size)

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, utterance_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        utterance_mask_list = [mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(
                RnnStatelet(
                    final_encoder_output[i],
                    memory_cell[i],
                    self._first_action_embedding,
                    self._first_attended_utterance,
                    encoder_output_list,
                    utterance_mask_list,
                ))

        initial_grammar_state = [
            self._create_grammar_state(actions[i]) for i in range(batch_size)
        ]

        initial_state = GrammarBasedState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            possible_actions=actions,
            debug_info=None,
        )
        return initial_state
Code example #6
    def _compute_action_probabilities(
        self, state: GrammarBasedState, hidden_state: torch.Tensor,
        attention_weights: torch.Tensor,
        predicted_action_embeddings: torch.Tensor
    ) -> Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]]:
        # We take a couple of extra arguments here because subclasses might use them.
        # pylint: disable=unused-argument,no-self-use

        # In this section we take our predicted action embedding and compare it to the available
        # actions in our current state (which might be different for each group element).  For
        # computing action scores, we'll forget about doing batched / grouped computation, as it
        # adds too much complexity and doesn't speed things up, anyway, with the operations we're
        # doing here.  This means we don't need any action masks, as we'll only get the right
        # lengths for what we're computing.

        group_size = len(state.batch_indices)
        actions = state.get_valid_actions()

        batch_results: Dict[int, List[Tuple[int, Any, Any, Any,
                                            List[int]]]] = defaultdict(list)
        for group_index in range(group_size):
            instance_actions = actions[group_index]
            predicted_action_embedding = predicted_action_embeddings[
                group_index]
            action_embeddings, output_action_embeddings, action_ids = instance_actions[
                'global']
            # This is just a matrix product between a (num_actions, embedding_dim) matrix and an
            # (embedding_dim, 1) matrix.
            action_logits = action_embeddings.mm(
                predicted_action_embedding.unsqueeze(-1)).squeeze(-1)
            current_log_probs = torch.nn.functional.log_softmax(action_logits,
                                                                dim=-1)

            # This is now the total score for each state after taking each action.  We're going to
            # sort by this later, so it's important that this is the total score, not just the
            # score for the current action.
            log_probs = state.score[group_index] + current_log_probs
            batch_results[state.batch_indices[group_index]].append(
                (group_index, log_probs, current_log_probs,
                 output_action_embeddings, action_ids))
        return batch_results
Code example #7
    def _compute_action_probabilities(self,
                                      state: GrammarBasedState,
                                      hidden_state: torch.Tensor,
                                      attention_weights: torch.Tensor,
                                      predicted_action_embeddings: torch.Tensor
                                     ) -> Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]]:
        # We take a couple of extra arguments here because subclasses might use them.
        # pylint: disable=unused-argument,no-self-use

        # In this section we take our predicted action embedding and compare it to the available
        # actions in our current state (which might be different for each group element).  For
        # computing action scores, we'll forget about doing batched / grouped computation, as it
        # adds too much complexity and doesn't speed things up, anyway, with the operations we're
        # doing here.  This means we don't need any action masks, as we'll only get the right
        # lengths for what we're computing.

        group_size = len(state.batch_indices)
        actions = state.get_valid_actions()

        batch_results: Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]] = defaultdict(list)
        for group_index in range(group_size):
            instance_actions = actions[group_index]
            predicted_action_embedding = predicted_action_embeddings[group_index]
            action_embeddings, output_action_embeddings, action_ids = instance_actions['global']
            # This is just a matrix product between a (num_actions, embedding_dim) matrix and an
            # (embedding_dim, 1) matrix.
            action_logits = action_embeddings.mm(predicted_action_embedding.unsqueeze(-1)).squeeze(-1)
            current_log_probs = torch.nn.functional.log_softmax(action_logits, dim=-1)

            # This is now the total score for each state after taking each action.  We're going to
            # sort by this later, so it's important that this is the total score, not just the
            # score for the current action.
            log_probs = state.score[group_index] + current_log_probs
            batch_results[state.batch_indices[group_index]].append((group_index,
                                                                    log_probs,
                                                                    current_log_probs,
                                                                    output_action_embeddings,
                                                                    action_ids))
        return batch_results
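
Examples #6 and #7 are the same routine with different line wrapping. The core of the scoring is one matrix-vector product per group element followed by a log-softmax, plus a running-score update in log space. A minimal standalone sketch with assumed shapes:

    import torch

    embedding_dim, num_actions = 4, 3
    action_embeddings = torch.randn(num_actions, embedding_dim)
    predicted_action_embedding = torch.randn(embedding_dim)

    # (num_actions, embedding_dim) x (embedding_dim, 1) -> (num_actions,)
    action_logits = action_embeddings.mm(
        predicted_action_embedding.unsqueeze(-1)).squeeze(-1)
    current_log_probs = torch.nn.functional.log_softmax(action_logits, dim=-1)

    # Adding the state's running score in log space yields the total sequence
    # score, which is what the search later sorts on.
    state_score = torch.tensor(-1.2)
    log_probs = state_score + current_log_probs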
Code example #8
    def _compute_action_probabilities(
        self, state: GrammarBasedState, hidden_state: torch.Tensor,
        attention_weights: torch.Tensor,
        predicted_action_embeddings: torch.Tensor
    ) -> Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]]:
        # In this section we take our predicted action embedding and compare it to the available
        # actions in our current state (which might be different for each group element).  For
        # computing action scores, we'll forget about doing batched / grouped computation, as it
        # adds too much complexity and doesn't speed things up, anyway, with the operations we're
        # doing here.  This means we don't need any action masks, as we'll only get the right
        # lengths for what we're computing.

        group_size = len(state.batch_indices)
        actions = state.get_valid_actions()

        batch_results: Dict[int, List[Tuple[int, Any, Any, Any,
                                            List[int]]]] = defaultdict(list)
        for group_index in range(group_size):
            instance_actions = actions[group_index]
            predicted_action_embedding = predicted_action_embeddings[
                group_index]
            embedded_actions: List[int] = []

            output_action_embeddings = None
            embedded_action_logits = None
            current_log_probs = None

            if 'global' in instance_actions:
                action_embeddings, output_action_embeddings, embedded_actions = instance_actions[
                    'global']
                # This is just a matrix product between a (num_actions, embedding_dim) matrix and an
                # (embedding_dim, 1) matrix.
                embedded_action_logits = action_embeddings.mm(
                    predicted_action_embedding.unsqueeze(-1)).squeeze(-1)
                action_ids = embedded_actions

            if 'linked' in instance_actions:
                linking_scores, type_embeddings, linked_actions = instance_actions[
                    'linked']
                action_ids = embedded_actions + linked_actions
                # (num_question_tokens, 1)
                linked_action_logits = linking_scores.mm(
                    attention_weights[group_index].unsqueeze(-1)).squeeze(-1)

                # The `output_action_embeddings` tensor gets used later as the input to the next
                # decoder step.  For linked actions, we don't have any action embedding, so we use
                # the entity type instead.
                if output_action_embeddings is not None:
                    output_action_embeddings = torch.cat(
                        [output_action_embeddings, type_embeddings], dim=0)
                else:
                    output_action_embeddings = type_embeddings

                if self._mixture_feedforward is not None:
                    # The linked and global logits are combined with a mixture weight to prevent the
                    # linked_action_logits from dominating the embedded_action_logits if a softmax
                    # was applied on both together.
                    mixture_weight = self._mixture_feedforward(
                        hidden_state[group_index])
                    mix1 = torch.log(mixture_weight)
                    mix2 = torch.log(1 - mixture_weight)

                    entity_action_probs = torch.nn.functional.log_softmax(
                        linked_action_logits, dim=-1) + mix1
                    if embedded_action_logits is not None:
                        embedded_action_probs = torch.nn.functional.log_softmax(
                            embedded_action_logits, dim=-1) + mix2
                        current_log_probs = torch.cat(
                            [embedded_action_probs, entity_action_probs],
                            dim=-1)
                    else:
                        current_log_probs = entity_action_probs
                else:
                    if embedded_action_logits is not None:
                        action_logits = torch.cat(
                            [embedded_action_logits, linked_action_logits],
                            dim=-1)
                    else:
                        action_logits = linked_action_logits
                    current_log_probs = torch.nn.functional.log_softmax(
                        action_logits, dim=-1)
            else:
                action_logits = embedded_action_logits
                current_log_probs = torch.nn.functional.log_softmax(
                    action_logits, dim=-1)

            # This is now the total score for each state after taking each action.  We're going to
            # sort by this later, so it's important that this is the total score, not just the
            # score for the current action.
            log_probs = state.score[group_index] + current_log_probs
            batch_results[state.batch_indices[group_index]].append(
                (group_index, log_probs, current_log_probs,
                 output_action_embeddings, action_ids))
        return batch_results
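
The mixture branch in example #8 combines the linked and global action distributions in log space. A small self-contained sketch of why the concatenated result is still one proper distribution; the scalar ``w`` stands in for the output of ``self._mixture_feedforward``:

    import torch

    w = torch.tensor(0.3)                    # mixture weight in (0, 1)
    linked_action_logits = torch.randn(5)    # one logit per linked action
    embedded_action_logits = torch.randn(7)  # one logit per global action

    entity_log_probs = torch.nn.functional.log_softmax(
        linked_action_logits, dim=-1) + torch.log(w)
    embedded_log_probs = torch.nn.functional.log_softmax(
        embedded_action_logits, dim=-1) + torch.log(1 - w)
    current_log_probs = torch.cat([embedded_log_probs, entity_log_probs], dim=-1)

    # Each softmax sums to one, so the exponentiated concatenation sums to
    # w + (1 - w) = 1: a single valid distribution over all actions.
    assert torch.isclose(current_log_probs.exp().sum(), torch.tensor(1.0))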
Code example #9
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            table: Dict[str, torch.LongTensor],
            world: List[QuarelWorld],
            actions: List[List[ProductionRule]],
            entity_bits: torch.Tensor = None,
            denotation_target: torch.Tensor = None,
            target_action_sequences: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        # pylint: disable=unused-argument
        """
        In this method we encode the table entities, link them to words in the question, then
        encode the question. Then we set up the initial state for the decoder, and pass that
        state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference,
        if we're not.

        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
           The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
           be passed through a ``TextFieldEmbedder`` and then through an encoder.
        table : ``Dict[str, torch.LongTensor]``
            The output of ``KnowledgeGraphField.as_array()`` applied on the table
            ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
            entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
            get embeddings for each entity.
        world : ``List[QuarelWorld]``
            We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
            how ``MetadataField`` works, this gets passed to us as a ``List[QuarelWorld]``.
        actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        entity_bits : ``torch.Tensor``, optional (default=None)
            Tensor encoding bits for the world entities.
        denotation_target : ``torch.Tensor``, optional (default=None)
            If model's field ``denotation_only`` is True, this is the tensor target denotation.
        target_action_sequences : torch.Tensor, optional (default=None)
           A list of possibly valid action sequences, where each action is an index into the list
           of possible actions.  This tensor has shape ``(batch_size, num_action_sequences,
           sequence_length)``.
        metadata : List[Dict[str, Any]], optional (default=None).
            A dictionary of metadata for each batch element which has keys:
                question_tokens : ``List[str]``, optional.
                    The original string tokens in the question.
                world_extractions : ``nltk.Tree``, optional.
                    Extracted worlds from the question.
                answer_index : ``List[str]``, optional.
                    Index of the correct answer.
        """

        table_text = table['text']

        self._debug_count -= 1

        # (batch_size, question_length, embedding_dim)
        embedded_question = self._question_embedder(question)
        question_mask = util.get_text_field_mask(question).float()
        num_question_tokens = embedded_question.size(1)

        # (batch_size, num_entities, num_entity_tokens, embedding_dim)
        embedded_table = self._question_embedder(table_text,
                                                 num_wrapping_dims=1)

        batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()

        # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types)
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(
            world, num_entities, embedded_table)

        if self._use_entities:

            if self._entity_similarity_mode == "dot_product":
                # Compute entity and question word cosine similarity.  Need to add a small value
                # to the table norm since there are padding values which cause a divide by 0.
                embedded_table = embedded_table / (
                    embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
                embedded_question = embedded_question / (
                    embedded_question.norm(dim=-1, keepdim=True) + 1e-13)
                question_entity_similarity = torch.bmm(
                    embedded_table.view(batch_size,
                                        num_entities * num_entity_tokens,
                                        self._embedding_dim),
                    torch.transpose(embedded_question, 1, 2))

                question_entity_similarity = question_entity_similarity.view(
                    batch_size, num_entities, num_entity_tokens,
                    num_question_tokens)

                # (batch_size, num_entities, num_question_tokens)
                question_entity_similarity_max_score, _ = torch.max(
                    question_entity_similarity, 2)

                linking_scores = question_entity_similarity_max_score
            elif self._entity_similarity_mode == "weighted_dot_product":
                embedded_table = embedded_table / (
                    embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
                embedded_question = embedded_question / (
                    embedded_question.norm(dim=-1, keepdim=True) + 1e-13)
                eqe = embedded_question.unsqueeze(1).expand(
                    -1, num_entities * num_entity_tokens, -1, -1)
                ete = embedded_table.view(batch_size,
                                          num_entities * num_entity_tokens,
                                          self._embedding_dim)
                ete = ete.unsqueeze(2).expand(-1, -1, num_question_tokens, -1)
                product = torch.mul(eqe, ete)
                product = product.view(
                    batch_size,
                    num_question_tokens * num_entities * num_entity_tokens,
                    self._embedding_dim)
                question_entity_similarity = self._entity_similarity_layer(
                    product)
                question_entity_similarity = question_entity_similarity.view(
                    batch_size, num_entities, num_entity_tokens,
                    num_question_tokens)

                # (batch_size, num_entities, num_question_tokens)
                question_entity_similarity_max_score, _ = torch.max(
                    question_entity_similarity, 2)
                linking_scores = question_entity_similarity_max_score

            # (batch_size, num_entities, num_question_tokens, num_features)
            linking_features = table['linking']

            if self._linking_params is not None:
                feature_scores = self._linking_params(
                    linking_features).squeeze(3)
                linking_scores = linking_scores + feature_scores

            # (batch_size, num_question_tokens, num_entities)
            linking_probabilities = self._get_linking_probabilities(
                world, linking_scores.transpose(1, 2), question_mask,
                entity_type_dict)
            encoder_input = embedded_question
        else:
            if entity_bits is not None and not self._entity_bits_output:
                encoder_input = torch.cat([embedded_question, entity_bits], 2)
            else:
                encoder_input = embedded_question

            # Fake linking_scores, added so that downstream code does not object
            linking_scores = question_mask.clone().fill_(0).unsqueeze(1)
            linking_probabilities = None

        # (batch_size, question_length, encoder_output_dim)
        encoder_outputs = self._dropout(
            self._encoder(encoder_input, question_mask))

        if self._entity_bits_output and entity_bits is not None:
            encoder_outputs = torch.cat([encoder_outputs, entity_bits], 2)

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, question_mask, self._encoder.is_bidirectional())
        # For predicting a categorical denotation directly
        if self._denotation_only:
            denotation_logits = self._denotation_classifier(
                final_encoder_output)
            loss = torch.nn.functional.cross_entropy(
                denotation_logits, denotation_target.view(-1))
            self._denotation_accuracy_cat(denotation_logits, denotation_target)
            return {"loss": loss}

        memory_cell = encoder_outputs.new_zeros(batch_size,
                                                self._encoder_output_dim)

        _, num_entities, num_question_tokens = linking_scores.size()

        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
        else:
            target_mask = None

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        question_mask_list = [question_mask[i] for i in range(batch_size)]
        initial_rnn_state = []

        for i in range(batch_size):
            initial_rnn_state.append(
                RnnStatelet(final_encoder_output[i], memory_cell[i],
                            self._first_action_embedding,
                            self._first_attended_question, encoder_output_list,
                            question_mask_list))

        initial_grammar_state = [
            self._create_grammar_state(world[i], actions[i], linking_scores[i],
                                       entity_types[i])
            for i in range(batch_size)
        ]

        initial_score = initial_rnn_state[0].hidden_state.new_zeros(batch_size)
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        initial_state = GrammarBasedState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            possible_actions=actions,
            extras=None,
            debug_info=None)

        if self.training:
            outputs = self._decoder_trainer.decode(
                initial_state, self._decoder_step,
                (target_action_sequences, target_mask))
            return outputs

        else:
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            outputs = {'action_mapping': action_mapping}
            if target_action_sequences is not None:
                outputs['loss'] = self._decoder_trainer.decode(
                    initial_state, self._decoder_step,
                    (target_action_sequences, target_mask))['loss']

            num_steps = self._max_decoding_steps
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(
                num_steps,
                initial_state,
                self._decoder_step,
                keep_final_unfinished_states=False)
            outputs['best_action_sequence'] = []
            outputs['debug_info'] = []
            outputs['entities'] = []
            if self._linking_params is not None:
                outputs['linking_scores'] = linking_scores
                outputs['feature_scores'] = feature_scores
                outputs['linking_features'] = linking_features
            if self._use_entities:
                outputs['linking_probabilities'] = linking_probabilities
            if entity_bits is not None:
                outputs['entity_bits'] = entity_bits
            # outputs['similarity_scores'] = question_entity_similarity_max_score
            outputs['logical_form'] = []
            outputs['denotation_acc'] = []
            outputs['score'] = []
            outputs['parse_acc'] = []
            outputs['answer_index'] = []
            if metadata is not None:
                outputs['question_tokens'] = []
                outputs['world_extractions'] = []
            for i in range(batch_size):
                if metadata is not None:
                    outputs['question_tokens'].append(metadata[i].get(
                        'question_tokens', []))
                if metadata is not None:
                    outputs['world_extractions'].append(metadata[i].get(
                        'world_extractions', {}))
                outputs['entities'].append(world[i].table_graph.entities)
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i in best_final_states:
                    best_action_indices = best_final_states[i][
                        0].action_history[0]
                    sequence_in_targets = 0
                    if target_action_sequences is not None:
                        targets = target_action_sequences[i].data
                        sequence_in_targets = self._action_history_match(
                            best_action_indices, targets)
                        self._action_sequence_accuracy(sequence_in_targets)
                    action_strings = [
                        action_mapping[(i, action_index)]
                        for action_index in best_action_indices
                    ]
                    try:
                        self._has_logical_form(1.0)
                        logical_form = world[i].get_logical_form(
                            action_strings, add_var_function=False)
                    except ParsingError:
                        self._has_logical_form(0.0)
                        logical_form = 'Error producing logical form'
                    denotation_accuracy = 0.0
                    predicted_answer_index = world[i].execute(logical_form)
                    if metadata is not None and 'answer_index' in metadata[i]:
                        answer_index = metadata[i]['answer_index']
                        denotation_accuracy = self._denotation_match(
                            predicted_answer_index, answer_index)
                        self._denotation_accuracy(denotation_accuracy)
                    score = math.exp(
                        best_final_states[i][0].score[0].data.cpu().item())
                    outputs['answer_index'].append(predicted_answer_index)
                    outputs['score'].append(score)
                    outputs['parse_acc'].append(sequence_in_targets)
                    outputs['best_action_sequence'].append(action_strings)
                    outputs['logical_form'].append(logical_form)
                    outputs['denotation_acc'].append(denotation_accuracy)
                    outputs['debug_info'].append(
                        best_final_states[i][0].debug_info[0])  # type: ignore
                else:
                    outputs['parse_acc'].append(0)
                    outputs['logical_form'].append('')
                    outputs['denotation_acc'].append(0)
                    outputs['score'].append(0)
                    outputs['answer_index'].append(-1)
                    outputs['best_action_sequence'].append([])
                    outputs['debug_info'].append([])
                    self._has_logical_form(0.0)
            return outputs
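
The ``dot_product`` similarity branch in example #9 reduces to: L2-normalize both embeddings, take a batched matrix product, then max over each entity's tokens. A standalone sketch with assumed shapes:

    import torch

    batch_size, num_entities, num_entity_tokens = 2, 3, 4
    num_question_tokens, embedding_dim = 5, 8
    embedded_table = torch.randn(batch_size, num_entities, num_entity_tokens, embedding_dim)
    embedded_question = torch.randn(batch_size, num_question_tokens, embedding_dim)

    # Normalizing first makes the dot product below a cosine similarity; the
    # epsilon guards against zero-norm padding vectors.
    embedded_table = embedded_table / (embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
    embedded_question = embedded_question / (embedded_question.norm(dim=-1, keepdim=True) + 1e-13)

    similarity = torch.bmm(
        embedded_table.view(batch_size, num_entities * num_entity_tokens, embedding_dim),
        embedded_question.transpose(1, 2),
    ).view(batch_size, num_entities, num_entity_tokens, num_question_tokens)

    # An entity's score against a question token is the max over its tokens.
    linking_scores, _ = similarity.max(dim=2)  # (batch_size, num_entities, num_question_tokens)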
Code example #10
    def forward(self,  # type: ignore
                sentence: Dict[str, torch.LongTensor],
                worlds: List[List[NlvrWorld]],
                actions: List[List[ProductionRule]],
                identifier: List[str] = None,
                target_action_sequences: torch.LongTensor = None,
                labels: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Decoder logic for producing type-constrained target sequences, trained to maximize the
        marginal likelihood over a set of approximate logical forms.
        """
        batch_size = len(worlds)

        initial_rnn_state = self._get_initial_rnn_state(sentence)
        initial_score_list = [next(iter(sentence.values())).new_zeros(1, dtype=torch.float)
                              for i in range(batch_size)]
        label_strings = self._get_label_strings(labels) if labels is not None else None
        # TODO (pradeep): Assuming all worlds give the same set of valid actions.
        initial_grammar_state = [self._create_grammar_state(worlds[i][0], actions[i]) for i in
                                 range(batch_size)]

        initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),
                                          action_history=[[] for _ in range(batch_size)],
                                          score=initial_score_list,
                                          rnn_state=initial_rnn_state,
                                          grammar_state=initial_grammar_state,
                                          possible_actions=actions,
                                          extras=label_strings)

        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
        else:
            target_mask = None

        outputs: Dict[str, torch.Tensor] = {}
        if identifier is not None:
            outputs["identifier"] = identifier
        if target_action_sequences is not None:
            outputs = self._decoder_trainer.decode(initial_state,
                                                   self._decoder_step,
                                                   (target_action_sequences, target_mask))
        if not self.training:
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._decoder_beam_search.search(self._max_decoding_steps,
                                                                 initial_state,
                                                                 self._decoder_step,
                                                                 keep_final_unfinished_states=False)
            best_action_sequences: Dict[int, List[List[int]]] = {}
            for i in range(batch_size):
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i in best_final_states:
                    best_action_indices = [best_final_states[i][0].action_history[0]]
                    best_action_sequences[i] = best_action_indices
            batch_action_strings = self._get_action_strings(actions, best_action_sequences)
            batch_denotations = self._get_denotations(batch_action_strings, worlds)
            if target_action_sequences is not None:
                self._update_metrics(action_strings=batch_action_strings,
                                     worlds=worlds,
                                     label_strings=label_strings)
            else:
                if metadata is not None:
                    outputs["sentence_tokens"] = [x["sentence_tokens"] for x in metadata]
                outputs['debug_info'] = []
                for i in range(batch_size):
                    outputs['debug_info'].append(best_final_states[i][0].debug_info[0])  # type: ignore
                outputs["best_action_strings"] = batch_action_strings
                outputs["denotations"] = batch_denotations
                action_mapping = {}
                for batch_index, batch_actions in enumerate(actions):
                    for action_index, action in enumerate(batch_actions):
                        action_mapping[(batch_index, action_index)] = action[0]
                outputs['action_mapping'] = action_mapping
        return outputs
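
Several of these ``forward`` methods build ``target_mask`` the same way. A minimal sketch of that step; the padding index and the tensor contents are hypothetical:

    import torch

    action_padding_index = -1
    # Shape (batch_size, num_action_sequences, sequence_length, 1), as produced
    # by ListField[ListField[IndexField]]; the trailing dimension gets squeezed.
    target_action_sequences = torch.tensor([[[[2], [5], [action_padding_index]]]])
    target_action_sequences = target_action_sequences.squeeze(-1)
    # True at real decoding steps, False at padded positions, so padding is
    # excluded from the loss.
    target_mask = target_action_sequences != action_padding_index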
Code example #11
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                table: Dict[str, torch.LongTensor],
                world: List[WikiTablesWorld],
                actions: List[List[ProductionRule]],
                example_lisp_string: List[str] = None,
                target_action_sequences: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        In this method we encode the table entities, link them to words in the question, then
        encode the question. Then we set up the initial state for the decoder, and pass that
        state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference,
        if we're not.

        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
           The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
           be passed through a ``TextFieldEmbedder`` and then through an encoder.
        table : ``Dict[str, torch.LongTensor]``
            The output of ``KnowledgeGraphField.as_array()`` applied on the table
            ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
            entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
            get embeddings for each entity.
        world : ``List[WikiTablesWorld]``
            We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
            how ``MetadataField`` works, this gets passed to us as a ``List[WikiTablesWorld]``.
        actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        example_lisp_string : ``List[str]``, optional (default = None)
            The example (lisp-formatted) string corresponding to the given input.  This comes
            directly from the ``.examples`` file provided with the dataset.  We pass this to SEMPRE
            when evaluating denotation accuracy; it is otherwise unused.
        target_action_sequences : torch.Tensor, optional (default = None)
           A list of possibly valid action sequences, where each action is an index into the list
           of possible actions.  This tensor has shape ``(batch_size, num_action_sequences,
           sequence_length)``.
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original tokenized question within a 'question_tokens' key.
        """
        outputs: Dict[str, Any] = {}
        rnn_state, grammar_state = self._get_initial_rnn_and_grammar_state(question,
                                                                           table,
                                                                           world,
                                                                           actions,
                                                                           outputs)
        batch_size = len(rnn_state)
        initial_score = rnn_state[0].hidden_state.new_zeros(batch_size)
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),  # type: ignore
                                          action_history=[[] for _ in range(batch_size)],
                                          score=initial_score_list,
                                          rnn_state=rnn_state,
                                          grammar_state=grammar_state,
                                          possible_actions=actions,
                                          extras=example_lisp_string,
                                          debug_info=None)

        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
        else:
            target_mask = None

        if self.training:
            return self._decoder_trainer.decode(initial_state,
                                                self._decoder_step,
                                                (target_action_sequences, target_mask))
        else:
            if target_action_sequences is not None:
                outputs['loss'] = self._decoder_trainer.decode(initial_state,
                                                               self._decoder_step,
                                                               (target_action_sequences, target_mask))['loss']
            num_steps = self._max_decoding_steps
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(num_steps,
                                                         initial_state,
                                                         self._decoder_step,
                                                         keep_final_unfinished_states=False)
            for i in range(batch_size):
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i in best_final_states:
                    best_action_indices = best_final_states[i][0].action_history[0]
                    if target_action_sequences is not None:
                        # Use a Tensor, not a Variable, to avoid a memory leak.
                        targets = target_action_sequences[i].data
                        sequence_in_targets = 0
                        sequence_in_targets = self._action_history_match(best_action_indices, targets)
                        self._action_sequence_accuracy(sequence_in_targets)

            self._compute_validation_outputs(actions,
                                             best_final_states,
                                             world,
                                             example_lisp_string,
                                             metadata,
                                             outputs)

            return outputs
Code example #12
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                table: Dict[str, torch.LongTensor],
                world: List[WikiTablesVariableFreeWorld],
                actions: List[List[ProductionRuleArray]],
                target_values: List[List[str]] = None,
                target_action_sequences: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        In this method we encode the table entities, link them to words in the question, then
        encode the question. Then we set up the initial state for the decoder, and pass that
        state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference,
        if we're not.

        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
           The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
           be passed through a ``TextFieldEmbedder`` and then through an encoder.
        table : ``Dict[str, torch.LongTensor]``
            The output of ``KnowledgeGraphField.as_array()`` applied on the table
            ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
            entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
            get embeddings for each entity.
        world : ``List[WikiTablesWorld]``
            We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
            how ``MetadataField`` works, this gets passed to us as a ``List[WikiTablesWorld]``.
        actions : ``List[List[ProductionRuleArray]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRuleArray`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        target_values : ``List[List[str]]``, optional (default = None)
            For each instance, a list of target values taken from the example lisp string. We pass
            this list to the evaluator along with logical forms to compute denotation accuracy.
        target_action_sequences : torch.Tensor, optional (default = None)
           A list of possibly valid action sequences, where each action is an index into the list
           of possible actions.  This tensor has shape ``(batch_size, num_action_sequences,
           sequence_length)``.
        """
        outputs: Dict[str, Any] = {}
        rnn_state, grammar_state = self._get_initial_rnn_and_grammar_state(question,
                                                                           table,
                                                                           world,
                                                                           actions,
                                                                           outputs)
        batch_size = len(rnn_state)
        initial_score = rnn_state[0].hidden_state.new_zeros(batch_size)
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),  # type: ignore
                                          action_history=[[] for _ in range(batch_size)],
                                          score=initial_score_list,
                                          rnn_state=rnn_state,
                                          grammar_state=grammar_state,
                                          possible_actions=actions,
                                          extras=target_values,
                                          debug_info=None)

        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
        else:
            target_mask = None

        if self.training:
            return self._decoder_trainer.decode(initial_state,
                                                self._decoder_step,
                                                (target_action_sequences, target_mask))
        else:
            if target_action_sequences is not None:
                outputs['loss'] = self._decoder_trainer.decode(initial_state,
                                                               self._decoder_step,
                                                               (target_action_sequences, target_mask))['loss']
            num_steps = self._max_decoding_steps
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(num_steps,
                                                         initial_state,
                                                         self._decoder_step,
                                                         keep_final_unfinished_states=False)
            for i in range(batch_size):
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i in best_final_states:
                    best_action_indices = best_final_states[i][0].action_history[0]
                    if target_action_sequences is not None:
                        # Use a Tensor, not a Variable, to avoid a memory leak.
                        targets = target_action_sequences[i].data
                        sequence_in_targets = 0
                        sequence_in_targets = self._action_history_match(best_action_indices, targets)
                        self._action_sequence_accuracy(sequence_in_targets)

            metadata = None
            self._compute_validation_outputs(actions,
                                             best_final_states,
                                             world,
                                             target_values,
                                             metadata,
                                             outputs)

            return outputs
Code example #13
    def _take_first_step(
            self,
            state: GrammarBasedState,
            allowed_actions: List[Set[int]] = None) -> List[GrammarBasedState]:
        # We'll just do a projection from the current hidden state (which was initialized with the
        # final encoder output) to the number of start actions that we have, normalize those
        # logits, and use that as our score.  We end up duplicating some of the logic from
        # `_compute_new_states` here, but we do things slightly differently, and it's easier to
        # just copy the parts we need than to try to re-use that code.

        # (group_size, hidden_dim)
        hidden_state = torch.stack(
            [rnn_state.hidden_state for rnn_state in state.rnn_state])
        # (group_size, num_start_type)
        start_action_logits = self._start_type_predictor(hidden_state)
        log_probs = torch.nn.functional.log_softmax(start_action_logits,
                                                    dim=-1)
        sorted_log_probs, sorted_actions = log_probs.sort(dim=-1,
                                                          descending=True)

        sorted_actions = sorted_actions.detach().cpu().numpy().tolist()
        if state.debug_info is not None:
            probs_cpu = log_probs.exp().detach().cpu().numpy().tolist()
        else:
            probs_cpu = [None] * len(state.batch_indices)

        # state.get_valid_actions() will return a list that is consistently sorted, so as long as
        # the set of valid start actions never changes, we can just match up the log prob indices
        # above with the position of each considered action, and we're good.
        valid_actions = state.get_valid_actions()
        considered_actions = [
            actions['global'][2] for actions in valid_actions
        ]
        if len(considered_actions[0]) != self._num_start_types:
            raise RuntimeError(
                "Calculated wrong number of initial actions.  Expected "
                f"{self._num_start_types}, found {len(considered_actions[0])}."
            )

        best_next_states: Dict[int, List[Tuple[int, int, int]]] = defaultdict(list)
        for group_index, (batch_index, group_actions) in enumerate(
                zip(state.batch_indices, sorted_actions)):
            for action_index, action in enumerate(group_actions):
                # `action` is currently the index in `log_probs`, not the actual action ID.  To get
                # the action ID, we need to go through `considered_actions`.
                action = considered_actions[group_index][action]
                if allowed_actions is not None and action not in allowed_actions[group_index]:
                    # This happens when our _decoder trainer_ wants us to only evaluate certain
                    # actions, likely because they are the gold actions in this state.  We just skip
                    # emitting any state that isn't allowed by the trainer, because constructing the
                    # new state can be expensive.
                    continue
                best_next_states[batch_index].append(
                    (group_index, action_index, action))

        new_states = []
        for batch_index, best_states in sorted(best_next_states.items()):
            for group_index, action_index, action in best_states:
                # We'll yield a bunch of states here that all have a `group_size` of 1, so that the
                # learning algorithm can decide how many of these it wants to keep, and it can just
                # regroup them later, as that's a really easy operation.
                new_score = state.score[group_index] + sorted_log_probs[group_index, action_index]

                # This part is different from `_compute_new_states` - we're just passing through
                # the previous RNN state, as predicting the start type wasn't included in the
                # decoder RNN in the original model.
                new_rnn_state = state.rnn_state[group_index]
                new_state = state.new_state_from_group_index(
                    group_index, action, new_score, new_rnn_state,
                    considered_actions[group_index], probs_cpu[group_index],
                    None)
                new_states.append(new_state)
        return new_states
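The head of this method — normalize the start-action logits, sort descending, then map each sorted position back to an action ID through the consistently ordered `considered_actions` list — can be checked in isolation. A small sketch with made-up numbers (the action IDs are illustrative):

    import torch

    # Fake logits for one group member over three start types.
    start_action_logits = torch.tensor([[2.0, 0.5, 1.0]])
    log_probs = torch.nn.functional.log_softmax(start_action_logits, dim=-1)
    sorted_log_probs, sorted_positions = log_probs.sort(dim=-1, descending=True)

    # Hypothetical action IDs, in the consistent order that
    # state.get_valid_actions() would return them.
    considered_actions = [7, 2, 5]
    ranked = [considered_actions[p] for p in sorted_positions[0].tolist()]
    assert ranked == [7, 5, 2]  # best start action first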
Code example #14
    def setUp(self):
        super().setUp()
        self.decoder_step = BasicTransitionFunction(
            encoder_output_dim=2,
            action_embedding_dim=2,
            input_attention=Attention.by_name('dot_product')(),
            num_start_types=3,
            add_action_bias=False)

        batch_indices = [0, 1, 0]
        action_history = [[1], [3, 4], []]
        score = [torch.FloatTensor([x]) for x in [.1, 1.1, 2.2]]
        hidden_state = torch.FloatTensor([[i, i]
                                          for i in range(len(batch_indices))])
        memory_cell = torch.FloatTensor([[i, i]
                                         for i in range(len(batch_indices))])
        previous_action_embedding = torch.FloatTensor(
            [[i, i] for i in range(len(batch_indices))])
        attended_question = torch.FloatTensor(
            [[i, i] for i in range(len(batch_indices))])
        # This maps non-terminals to valid actions, where the valid actions are grouped by _type_.
        # We have "global" actions, which are from the global grammar, and "linked" actions, which
        # are instance-specific and are generated based on question attention.  Each action type
        # has a tuple which is (input representation, output representation, action ids).
        valid_actions = {
            'e': {
                'global': (torch.FloatTensor([[0, 0], [-1, -1], [-2, -2]]),
                           torch.FloatTensor([[-1, -1], [-2, -2],
                                              [-3, -3]]), [0, 1, 2]),
                'linked': (torch.FloatTensor([[.1, .2, .3], [.4, .5, .6]]),
                           torch.FloatTensor([[3, 3], [4, 4]]), [3, 4])
            },
            'd': {
                'global': (torch.FloatTensor([[0, 0]]),
                           torch.FloatTensor([[-1, -1]]), [0]),
                'linked': (torch.FloatTensor([[-.1, -.2, -.3], [-.4, -.5, -.6],
                                              [-.7, -.8, -.9]]),
                           torch.FloatTensor([[5, 5], [6, 6], [7, 7]]), [1, 2, 3])
            }
        }
        grammar_state = [
            GrammarStatelet([nonterminal], {}, valid_actions, {},
                            is_nonterminal)
            for _, nonterminal in zip(batch_indices, ['e', 'd', 'e'])
        ]
        self.encoder_outputs = torch.FloatTensor([[[1, 2], [3, 4], [5, 6]],
                                                  [[10, 11], [12, 13],
                                                   [14, 15]]])
        self.encoder_output_mask = torch.FloatTensor([[1, 1, 1], [1, 1, 0]])
        self.possible_actions = [
            [('e -> f', False, None), ('e -> g', True, None),
             ('e -> h', True, None), ('e -> i', True, None),
             ('e -> j', True, None)],
            [('d -> q', True, None), ('d -> g', True, None),
             ('d -> h', True, None), ('d -> i', True, None)]
        ]

        rnn_state = []
        for i in range(len(batch_indices)):
            rnn_state.append(
                RnnStatelet(hidden_state[i], memory_cell[i],
                            previous_action_embedding[i], attended_question[i],
                            self.encoder_outputs, self.encoder_output_mask))
        self.state = GrammarBasedState(batch_indices=batch_indices,
                                       action_history=action_history,
                                       score=score,
                                       rnn_state=rnn_state,
                                       grammar_state=grammar_state,
                                       possible_actions=self.possible_actions)
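As the comment in the fixture says, each action type maps to an `(input representation, output representation, action ids)` tuple. A pared-down sketch of how one such entry unpacks:

    import torch

    # A trimmed copy of the fixture's 'global' actions for non-terminal 'e'.
    valid_actions = {
        'e': {
            'global': (torch.FloatTensor([[0, 0], [-1, -1], [-2, -2]]),    # input reps
                       torch.FloatTensor([[-1, -1], [-2, -2], [-3, -3]]),  # output reps
                       [0, 1, 2]),                                         # action ids
        }
    }
    action_input, action_output, action_ids = valid_actions['e']['global']
    assert action_input.shape == (3, 2)
    assert action_ids == [0, 1, 2]  # indices into possible_actions[batch_index]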
Code example #15
    def forward(
            self,  # type: ignore
            sentence: Dict[str, torch.LongTensor],
            worlds: List[List[NlvrLanguage]],
            actions: List[List[ProductionRule]],
            identifier: List[str] = None,
            target_action_sequences: torch.LongTensor = None,
            labels: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Decoder logic for producing type constrained target sequences, trained to maximize marginal
        likelihood over a set of approximate logical forms.
        """
        batch_size = len(worlds)

        initial_rnn_state = self._get_initial_rnn_state(sentence)
        initial_score_list = [
            next(iter(sentence.values())).new_zeros(1, dtype=torch.float)
            for i in range(batch_size)
        ]
        label_strings = self._get_label_strings(
            labels) if labels is not None else None
        # TODO (pradeep): Assuming all worlds give the same set of valid actions.
        initial_grammar_state = [
            self._create_grammar_state(worlds[i][0], actions[i])
            for i in range(batch_size)
        ]

        initial_state = GrammarBasedState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            possible_actions=actions,
            extras=label_strings)

        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
        else:
            target_mask = None

        outputs: Dict[str, torch.Tensor] = {}
        if identifier is not None:
            outputs["identifier"] = identifier
        if target_action_sequences is not None:
            outputs = self._decoder_trainer.decode(
                initial_state, self._decoder_step,
                (target_action_sequences, target_mask))
        if not self.training:
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._decoder_beam_search.search(
                self._max_decoding_steps,
                initial_state,
                self._decoder_step,
                keep_final_unfinished_states=False)
            best_action_sequences: Dict[int, List[List[int]]] = {}
            for i in range(batch_size):
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i in best_final_states:
                    best_action_indices = [
                        best_final_states[i][0].action_history[0]
                    ]
                    best_action_sequences[i] = best_action_indices
            batch_action_strings = self._get_action_strings(
                actions, best_action_sequences)
            batch_denotations = self._get_denotations(batch_action_strings,
                                                      worlds)
            if target_action_sequences is not None:
                self._update_metrics(action_strings=batch_action_strings,
                                     worlds=worlds,
                                     label_strings=label_strings)
            else:
                if metadata is not None:
                    outputs["sentence_tokens"] = [
                        x["sentence_tokens"] for x in metadata
                    ]
                outputs['debug_info'] = []
                for i in range(batch_size):
                    outputs['debug_info'].append(
                        best_final_states[i][0].debug_info[0])  # type: ignore
                outputs["best_action_strings"] = batch_action_strings
                outputs["denotations"] = batch_denotations
                action_mapping = {}
                for batch_index, batch_actions in enumerate(actions):
                    for action_index, action in enumerate(batch_actions):
                        action_mapping[(batch_index, action_index)] = action[0]
                outputs['action_mapping'] = action_mapping
        return outputs
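The target-sequence preprocessing near the top of `forward` is mechanical but worth seeing on a concrete tensor: drop the trailing dimension left by `ListField[ListField[IndexField]]`, then mask out padding. A minimal sketch, assuming an action padding index of -1:

    import torch

    action_padding_index = -1  # assumed value, for illustration only
    # Shape (batch_size=1, num_action_sequences=2, sequence_length=3, 1).
    target_action_sequences = torch.tensor([[[[2], [5], [-1]],
                                             [[2], [4], [3]]]])
    target_action_sequences = target_action_sequences.squeeze(-1)
    target_mask = target_action_sequences != action_padding_index
    # target_mask: [[[True, True, False], [True, True, True]]]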
Code example #16
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                table: Dict[str, torch.LongTensor],
                world: List[QuarelWorld],
                actions: List[List[ProductionRule]],
                entity_bits: torch.Tensor = None,
                denotation_target: torch.Tensor = None,
                target_action_sequences: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        # pylint: disable=unused-argument
        """
        In this method we encode the table entities, link them to words in the question, then
        encode the question. Then we set up the initial state for the decoder, and pass that
        state off to either a DecoderTrainer, if we're training, or a BeamSearch for inference,
        if we're not.

        Parameters
        ----------
        question : ``Dict[str, torch.LongTensor]``
            The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
            be passed through a ``TextFieldEmbedder`` and then through an encoder.
        table : ``Dict[str, torch.LongTensor]``
            The output of ``KnowledgeGraphField.as_array()`` applied on the table
            ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
            entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
            get embeddings for each entity.
        world : ``List[QuarelWorld]``
            We use a ``MetadataField`` to get the ``World`` for each input instance.  Because of
            how ``MetadataField`` works, this gets passed to us as a ``List[QuarelWorld]``.
        actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        target_action_sequences : ``torch.LongTensor``, optional (default=None)
            A list of possibly valid action sequences, where each action is an index into the list
            of possible actions.  This tensor has shape ``(batch_size, num_action_sequences,
            sequence_length)``.
        """

        table_text = table['text']

        self._debug_count -= 1

        # (batch_size, question_length, embedding_dim)
        embedded_question = self._question_embedder(question)
        question_mask = util.get_text_field_mask(question).float()
        num_question_tokens = embedded_question.size(1)

        # (batch_size, num_entities, num_entity_tokens, embedding_dim)
        embedded_table = self._question_embedder(table_text, num_wrapping_dims=1)

        batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()

        # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types)
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(world, num_entities, embedded_table)

        if self._use_entities:

            if self._entity_similarity_mode == "dot_product":
                # Compute entity and question word cosine similarity. Need to add a small value
                # to the table norm since there are padding values which cause a divide by 0.
                embedded_table = embedded_table / (embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
                embedded_question = embedded_question / (embedded_question.norm(dim=-1, keepdim=True) + 1e-13)
                question_entity_similarity = torch.bmm(embedded_table.view(batch_size,
                                                                           num_entities * num_entity_tokens,
                                                                           self._embedding_dim),
                                                       torch.transpose(embedded_question, 1, 2))

                question_entity_similarity = question_entity_similarity.view(batch_size,
                                                                             num_entities,
                                                                             num_entity_tokens,
                                                                             num_question_tokens)

                # (batch_size, num_entities, num_question_tokens)
                question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)

                linking_scores = question_entity_similarity_max_score
            elif self._entity_similarity_mode == "weighted_dot_product":
                embedded_table = embedded_table / (embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
                embedded_question = embedded_question / (embedded_question.norm(dim=-1, keepdim=True) + 1e-13)
                eqe = embedded_question.unsqueeze(1).expand(-1, num_entities*num_entity_tokens, -1, -1)
                ete = embedded_table.view(batch_size, num_entities*num_entity_tokens, self._embedding_dim)
                ete = ete.unsqueeze(2).expand(-1, -1, num_question_tokens, -1)
                product = torch.mul(eqe, ete)
                product = product.view(batch_size,
                                       num_question_tokens*num_entities*num_entity_tokens,
                                       self._embedding_dim)
                question_entity_similarity = self._entity_similarity_layer(product)
                question_entity_similarity = question_entity_similarity.view(batch_size,
                                                                             num_entities,
                                                                             num_entity_tokens,
                                                                             num_question_tokens)

                # (batch_size, num_entities, num_question_tokens)
                question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)
                linking_scores = question_entity_similarity_max_score

            # (batch_size, num_entities, num_question_tokens, num_features)
            linking_features = table['linking']

            if self._linking_params is not None:
                feature_scores = self._linking_params(linking_features).squeeze(3)
                linking_scores = linking_scores + feature_scores

            # (batch_size, num_question_tokens, num_entities)
            linking_probabilities = self._get_linking_probabilities(world, linking_scores.transpose(1, 2),
                                                                    question_mask, entity_type_dict)
            encoder_input = embedded_question
        else:
            if entity_bits is not None and not self._entity_bits_output:
                encoder_input = torch.cat([embedded_question, entity_bits], 2)
            else:
                encoder_input = embedded_question

            # Fake linking_scores added so that downstream code doesn't object.
            linking_scores = question_mask.clone().fill_(0).unsqueeze(1)
            linking_probabilities = None

        # (batch_size, question_length, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask))

        if self._entity_bits_output and entity_bits is not None:
            encoder_outputs = torch.cat([encoder_outputs, entity_bits], 2)

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                             question_mask,
                                                             self._encoder.is_bidirectional())
        # For predicting a categorical denotation directly
        if self._denotation_only:
            denotation_logits = self._denotation_classifier(final_encoder_output)
            loss = torch.nn.functional.cross_entropy(denotation_logits, denotation_target.view(-1))
            self._denotation_accuracy_cat(denotation_logits, denotation_target)
            return {"loss": loss}

        memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder_output_dim)

        _, num_entities, num_question_tokens = linking_scores.size()

        if target_action_sequences is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequences = target_action_sequences.squeeze(-1)
            target_mask = target_action_sequences != self._action_padding_index
        else:
            target_mask = None

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        question_mask_list = [question_mask[i] for i in range(batch_size)]
        initial_rnn_state = []

        for i in range(batch_size):
            initial_rnn_state.append(RnnStatelet(final_encoder_output[i],
                                                 memory_cell[i],
                                                 self._first_action_embedding,
                                                 self._first_attended_question,
                                                 encoder_output_list,
                                                 question_mask_list))

        initial_grammar_state = [self._create_grammar_state(world[i], actions[i],
                                                            linking_scores[i], entity_types[i])
                                 for i in range(batch_size)]

        initial_score = initial_rnn_state[0].hidden_state.new_zeros(batch_size)
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),
                                          action_history=[[] for _ in range(batch_size)],
                                          score=initial_score_list,
                                          rnn_state=initial_rnn_state,
                                          grammar_state=initial_grammar_state,
                                          possible_actions=actions,
                                          extras=None,
                                          debug_info=None)

        if self.training:
            outputs = self._decoder_trainer.decode(initial_state,
                                                   self._decoder_step,
                                                   (target_action_sequences, target_mask))
            return outputs

        else:
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            outputs = {'action_mapping': action_mapping}
            if target_action_sequences is not None:
                outputs['loss'] = self._decoder_trainer.decode(initial_state,
                                                               self._decoder_step,
                                                               (target_action_sequences, target_mask))['loss']

            num_steps = self._max_decoding_steps
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(num_steps,
                                                         initial_state,
                                                         self._decoder_step,
                                                         keep_final_unfinished_states=False)
            outputs['best_action_sequence'] = []
            outputs['debug_info'] = []
            outputs['entities'] = []
            if self._use_entities and self._linking_params is not None:
                # feature_scores and linking_features are only defined when entities
                # are used and linking params are present, so guard on both.
                outputs['linking_scores'] = linking_scores
                outputs['feature_scores'] = feature_scores
                outputs['linking_features'] = linking_features
            if self._use_entities:
                outputs['linking_probabilities'] = linking_probabilities
            if entity_bits is not None:
                outputs['entity_bits'] = entity_bits
            # outputs['similarity_scores'] = question_entity_similarity_max_score
            outputs['logical_form'] = []
            outputs['denotation_acc'] = []
            outputs['score'] = []
            outputs['parse_acc'] = []
            outputs['answer_index'] = []
            if metadata is not None:
                outputs['question_tokens'] = []
                outputs['world_extractions'] = []
            for i in range(batch_size):
                if metadata is not None:
                    outputs['question_tokens'].append(metadata[i].get('question_tokens', []))
                    outputs['world_extractions'].append(metadata[i].get('world_extractions', {}))
                outputs['entities'].append(world[i].table_graph.entities)
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i in best_final_states:
                    best_action_indices = best_final_states[i][0].action_history[0]
                    sequence_in_targets = 0
                    if target_action_sequences is not None:
                        targets = target_action_sequences[i].data
                        sequence_in_targets = self._action_history_match(best_action_indices, targets)
                        self._action_sequence_accuracy(sequence_in_targets)
                    action_strings = [action_mapping[(i, action_index)] for action_index in best_action_indices]
                    try:
                        self._has_logical_form(1.0)
                        logical_form = world[i].get_logical_form(action_strings, add_var_function=False)
                    except ParsingError:
                        self._has_logical_form(0.0)
                        logical_form = 'Error producing logical form'
                    denotation_accuracy = 0.0
                    predicted_answer_index = world[i].execute(logical_form)
                    if metadata is not None and 'answer_index' in metadata[i]:
                        answer_index = metadata[i]['answer_index']
                        denotation_accuracy = self._denotation_match(predicted_answer_index, answer_index)
                        self._denotation_accuracy(denotation_accuracy)
                    score = math.exp(best_final_states[i][0].score[0].data.cpu().item())
                    outputs['answer_index'].append(predicted_answer_index)
                    outputs['score'].append(score)
                    outputs['parse_acc'].append(sequence_in_targets)
                    outputs['best_action_sequence'].append(action_strings)
                    outputs['logical_form'].append(logical_form)
                    outputs['denotation_acc'].append(denotation_accuracy)
                    outputs['debug_info'].append(best_final_states[i][0].debug_info[0])  # type: ignore
                else:
                    outputs['parse_acc'].append(0)
                    outputs['logical_form'].append('')
                    outputs['denotation_acc'].append(0)
                    outputs['score'].append(0)
                    outputs['answer_index'].append(-1)
                    outputs['best_action_sequence'].append([])
                    outputs['debug_info'].append([])
                    self._has_logical_form(0.0)
            return outputs
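The `dot_product` branch above is a max-over-entity-tokens cosine similarity between every entity token and every question token. The same computation in isolation, with random tensors standing in for the model's embedders:

    import torch

    batch_size, num_entities, num_entity_tokens = 2, 3, 4
    num_question_tokens, embedding_dim = 5, 8
    embedded_table = torch.randn(batch_size, num_entities, num_entity_tokens, embedding_dim)
    embedded_question = torch.randn(batch_size, num_question_tokens, embedding_dim)

    # Normalize (the epsilon guards all-zero padding rows), then batched dot products.
    embedded_table = embedded_table / (embedded_table.norm(dim=-1, keepdim=True) + 1e-13)
    embedded_question = embedded_question / (embedded_question.norm(dim=-1, keepdim=True) + 1e-13)
    similarity = torch.bmm(
        embedded_table.view(batch_size, num_entities * num_entity_tokens, embedding_dim),
        embedded_question.transpose(1, 2))
    similarity = similarity.view(batch_size, num_entities, num_entity_tokens,
                                 num_question_tokens)

    # (batch_size, num_entities, num_question_tokens): best entity token per question token.
    linking_scores, _ = similarity.max(dim=2)
    assert linking_scores.shape == (batch_size, num_entities, num_question_tokens)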
Code example #17
    def forward(
            self,
            question: Dict[str, torch.LongTensor],
            question_predicates,
            # labelled_results,
            world: List[LCQuADLanguage],
            actions: List[List[ProductionRule]],
            question_entities=None,
            target_action_sequences: torch.LongTensor = None,
            labels: torch.LongTensor = None,
            logical_forms=None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Decoder logic for producing type constrained target sequences, trained to maximize marginal
        likelihood over a set of approximate logical forms.
        """
        assert target_action_sequences is not None

        batch_size = question['tokens'].size()[0]
        # Remove the trailing dimension (from ListField[ListField[IndexField]]).
        # assert target_action_sequences.dim() == 3
        target_action_sequences = target_action_sequences.squeeze(-1)
        target_mask = target_action_sequences != self._action_padding_index

        # if self._kg_embedder:
        #     embedded_entities = self._kg_embedder(question_entities, input_type="entity")
        #     embedded_type_entities = self._kg_embedder(question_type_entities, input_type="entity")
        #     embedded_predicates = self._kg_embedder(question_predicates, input_type="predicate")

        initial_rnn_state = self._get_initial_rnn_state(question)
        initial_score_list = [
            next(iter(question.values())).new_zeros(1, dtype=torch.float)
            for _ in range(batch_size)
        ]

        # TODO (pradeep): Assuming all worlds give the same set of valid actions.
        initial_grammar_statelet = [
            self._create_grammar_state(world[i], actions[i])
            for i in range(batch_size)
        ]
        initial_state = GrammarBasedState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_statelet,
            possible_actions=actions)

        outputs = self._decoder_trainer.decode(
            initial_state, self._decoder_step,
            (target_action_sequences, target_mask))

        if not self.training:
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._decoder_beam_search.search(
                self._max_decoding_steps,
                initial_state,
                self._decoder_step,
                keep_final_unfinished_states=False)
            best_action_sequences: Dict[int, List[List[int]]] = {}
            for i in range(batch_size):
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i in best_final_states:
                    best_action_indices = [
                        best_final_states[i][0].action_history[0]
                    ]
                    best_action_sequences[i] = best_action_indices
            batch_action_strings = self._get_action_strings(
                actions, best_action_sequences)

            # self._update_metrics(action_strings=batch_action_strings,
            #                     worlds=world,
            #                     labelled_results=labelled_results)

            debug_infos = []
            for i in range(batch_size):
                debug_infos.append(best_final_states[i][0].debug_info[0])

            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]

            self._update_seq_metrics(action_strings=batch_action_strings,
                                     worlds=world,
                                     gold_logical_forms=logical_forms,
                                     train=self.training)

            outputs["predicted queries"] = batch_action_strings

            best_actions = batch_action_strings
            batch_action_info = []
            for batch_index, (predicted_actions, debug_info) in enumerate(
                    zip(best_actions, debug_infos)):
                instance_action_info = []
                for predicted_action, action_debug_info in zip(
                        predicted_actions[0], debug_info):
                    action_info = {}
                    action_info['predicted_action'] = predicted_action
                    considered_actions = action_debug_info['considered_actions']
                    probabilities = action_debug_info['probabilities']
                    # Collect (rule string, probability) pairs under a name that
                    # doesn't shadow the `actions` argument of this method.
                    action_probs = []
                    for action, probability in zip(considered_actions,
                                                   probabilities):
                        if action != -1:
                            action_probs.append(
                                (action_mapping[(batch_index, action)],
                                 probability))
                    action_probs.sort()
                    considered_actions, probabilities = zip(*action_probs)
                    action_info['considered_actions'] = considered_actions
                    action_info['action_probabilities'] = probabilities
                    action_info['question_attention'] = action_debug_info.get(
                        'question_attention', [])
                    instance_action_info.append(action_info)
                batch_action_info.append(instance_action_info)
            outputs["predicted_actions"] = batch_action_info

        return outputs
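The `action_mapping` built in the validation branch is just a lookup from `(batch_index, action_index)` to the rule string stored first in each `ProductionRule`. The same pattern with plain tuples standing in for `ProductionRule` objects:

    # Each tuple is (rule string, is_global_rule, ...), as in possible_actions above.
    actions = [[('e -> f', False, None), ('e -> g', True, None)],
               [('d -> q', True, None)]]

    action_mapping = {}
    for batch_index, batch_actions in enumerate(actions):
        for action_index, action in enumerate(batch_actions):
            action_mapping[(batch_index, action_index)] = action[0]

    assert action_mapping[(0, 1)] == 'e -> g'
    assert action_mapping[(1, 0)] == 'd -> q'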
Code example #18
    def setUp(self):
        super().setUp()
        self.decoder_step = BasicTransitionFunction(
            encoder_output_dim=2,
            action_embedding_dim=2,
            input_attention=Attention.by_name("dot_product")(),
            add_action_bias=False,
        )

        batch_indices = [0, 1, 0]
        action_history = [[1], [3, 4], []]
        score = [torch.FloatTensor([x]) for x in [0.1, 1.1, 2.2]]
        hidden_state = torch.FloatTensor([[i, i]
                                          for i in range(len(batch_indices))])
        memory_cell = torch.FloatTensor([[i, i]
                                         for i in range(len(batch_indices))])
        previous_action_embedding = torch.FloatTensor(
            [[i, i] for i in range(len(batch_indices))])
        attended_question = torch.FloatTensor(
            [[i, i] for i in range(len(batch_indices))])
        # This maps non-terminals to valid actions, where the valid actions are grouped by _type_.
        # We have "global" actions, which are from the global grammar, and "linked" actions, which
        # are instance-specific and are generated based on question attention.  Each action type
        # has a tuple which is (input representation, output representation, action ids).
        valid_actions = {
            "e": {
                "global": (
                    torch.FloatTensor([[0, 0], [-1, -1], [-2, -2]]),
                    torch.FloatTensor([[-1, -1], [-2, -2], [-3, -3]]),
                    [0, 1, 2],
                ),
                "linked": (
                    torch.FloatTensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]),
                    torch.FloatTensor([[3, 3], [4, 4]]),
                    [3, 4],
                ),
            },
            "d": {
                "global":
                (torch.FloatTensor([[0, 0]]), torch.FloatTensor([[-1,
                                                                  -1]]), [0]),
                "linked": (
                    torch.FloatTensor([[-0.1, -0.2, -0.3], [-0.4, -0.5, -0.6],
                                       [-0.7, -0.8, -0.9]]),
                    torch.FloatTensor([[5, 5], [6, 6], [7, 7]]),
                    [1, 2, 3],
                ),
            },
        }
        grammar_state = [
            GrammarStatelet([nonterminal], valid_actions, is_nonterminal)
            for _, nonterminal in zip(batch_indices, ["e", "d", "e"])
        ]
        self.encoder_outputs = torch.FloatTensor([[[1, 2], [3, 4], [5, 6]],
                                                  [[10, 11], [12, 13],
                                                   [14, 15]]])
        self.encoder_output_mask = torch.FloatTensor([[1, 1, 1], [1, 1, 0]])
        self.possible_actions = [
            [
                ("e -> f", False, None),
                ("e -> g", True, None),
                ("e -> h", True, None),
                ("e -> i", True, None),
                ("e -> j", True, None),
            ],
            [
                ("d -> q", True, None),
                ("d -> g", True, None),
                ("d -> h", True, None),
                ("d -> i", True, None),
            ],
        ]

        rnn_state = []
        for i in range(len(batch_indices)):
            rnn_state.append(
                RnnStatelet(
                    hidden_state[i],
                    memory_cell[i],
                    previous_action_embedding[i],
                    attended_question[i],
                    self.encoder_outputs,
                    self.encoder_output_mask,
                ))
        self.state = GrammarBasedState(
            batch_indices=batch_indices,
            action_history=action_history,
            score=score,
            rnn_state=rnn_state,
            grammar_state=grammar_state,
            possible_actions=self.possible_actions,
        )
Code example #19
    def _take_first_step(self,
                         state: GrammarBasedState,
                         allowed_actions: List[Set[int]] = None) -> List[GrammarBasedState]:
        # We'll just do a projection from the current hidden state (which was initialized with the
        # final encoder output) to the number of start actions that we have, normalize those
        # logits, and use that as our score.  We end up duplicating some of the logic from
        # `_compute_new_states` here, but we do things slightly differently, and it's easier to
        # just copy the parts we need than to try to re-use that code.

        # (group_size, hidden_dim)
        hidden_state = torch.stack([rnn_state.hidden_state for rnn_state in state.rnn_state])
        # (group_size, num_start_type)
        start_action_logits = self._start_type_predictor(hidden_state)
        log_probs = torch.nn.functional.log_softmax(start_action_logits, dim=-1)
        sorted_log_probs, sorted_actions = log_probs.sort(dim=-1, descending=True)

        sorted_actions = sorted_actions.detach().cpu().numpy().tolist()
        if state.debug_info is not None:
            probs_cpu = log_probs.exp().detach().cpu().numpy().tolist()
        else:
            probs_cpu = [None] * len(state.batch_indices)

        # state.get_valid_actions() will return a list that is consistently sorted, so as long as
        # the set of valid start actions never changes, we can just match up the log prob indices
        # above with the position of each considered action, and we're good.
        valid_actions = state.get_valid_actions()
        considered_actions = [actions['global'][2] for actions in valid_actions]
        if len(considered_actions[0]) != self._num_start_types:
            raise RuntimeError("Calculated wrong number of initial actions.  Expected "
                               f"{self._num_start_types}, found {len(considered_actions[0])}.")

        best_next_states: Dict[int, List[Tuple[int, int, int]]] = defaultdict(list)
        for group_index, (batch_index, group_actions) in enumerate(zip(state.batch_indices, sorted_actions)):
            for action_index, action in enumerate(group_actions):
                # `action` is currently the index in `log_probs`, not the actual action ID.  To get
                # the action ID, we need to go through `considered_actions`.
                action = considered_actions[group_index][action]
                if allowed_actions is not None and action not in allowed_actions[group_index]:
                    # This happens when our _decoder trainer_ wants us to only evaluate certain
                    # actions, likely because they are the gold actions in this state.  We just skip
                    # emitting any state that isn't allowed by the trainer, because constructing the
                    # new state can be expensive.
                    continue
                best_next_states[batch_index].append((group_index, action_index, action))

        new_states = []
        for batch_index, best_states in sorted(best_next_states.items()):
            for group_index, action_index, action in best_states:
                # We'll yield a bunch of states here that all have a `group_size` of 1, so that the
                # learning algorithm can decide how many of these it wants to keep, and it can just
                # regroup them later, as that's a really easy operation.
                new_score = state.score[group_index] + sorted_log_probs[group_index, action_index]

                # This part is different from `_compute_new_states` - we're just passing through
                # the previous RNN state, as predicting the start type wasn't included in the
                # decoder RNN in the original model.
                new_rnn_state = state.rnn_state[group_index]
                new_state = state.new_state_from_group_index(group_index,
                                                             action,
                                                             new_score,
                                                             new_rnn_state,
                                                             considered_actions[group_index],
                                                             probs_cpu[group_index],
                                                             None)
                new_states.append(new_state)
        return new_states
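The `best_next_states` bookkeeping above groups candidate `(group_index, action_index, action)` triples by batch index so later code can emit and regroup `group_size`-1 states per instance. The grouping pattern in isolation, with toy values:

    from collections import defaultdict
    from typing import Dict, List, Tuple

    batch_indices = [0, 1, 0]  # two group entries for batch 0, one for batch 1
    sorted_actions = [[2, 0], [1, 2], [0, 2]]

    best_next_states: Dict[int, List[Tuple[int, int, int]]] = defaultdict(list)
    for group_index, (batch_index, group_actions) in enumerate(
            zip(batch_indices, sorted_actions)):
        for action_index, action in enumerate(group_actions):
            best_next_states[batch_index].append((group_index, action_index, action))

    assert sorted(best_next_states) == [0, 1]
    assert len(best_next_states[0]) == 4 and len(best_next_states[1]) == 2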