Example #1
    def _get_state_cost(
        self, batch_worlds: List[List[NlvrLanguage]], state: CoverageState
    ) -> torch.Tensor:
        """
        Return the cost of a finished state. Since it is a finished state, the group size will be
        1, and hence we'll return just one cost.

        The ``batch_worlds`` parameter here is because we need the world to check the denotation
        accuracy of the action sequence in the finished state.  Instead of adding a field to the
        ``State`` object just for this method, we take the ``World`` as a parameter here.
        """
        if not state.is_finished():
            raise RuntimeError("_get_state_cost() is not defined for unfinished states!")
        instance_worlds = batch_worlds[state.batch_indices[0]]
        # Our checklist cost is a sum of squared error from where we want to be, making sure we
        # take into account the mask.
        checklist_balance = state.checklist_state[0].get_balance()
        checklist_cost = torch.sum((checklist_balance) ** 2)

        # This is the number of items on the agenda that we want to see in the decoded sequence.
        # We use this as the denotation cost if the path is incorrect.
        # Note: If we are penalizing the model for producing non-agenda actions, this is not the
        # upper limit on the checklist cost. That would be the number of terminal actions.
        denotation_cost = torch.sum(state.checklist_state[0].checklist_target.float())
        checklist_cost = self._checklist_cost_weight * checklist_cost
        # TODO (pradeep): The denotation-based cost below is strict. Maybe define a cost based on
        # how many worlds the logical form is correct in?
        # extras being None happens when we are testing. We do not care about the cost
        # then.  TODO (pradeep): Make this cleaner.
        if state.extras is None or all(self._check_state_denotations(state, instance_worlds)):
            cost = checklist_cost
        else:
            cost = checklist_cost + (1 - self._checklist_cost_weight) * denotation_cost
        return cost
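
Below is a minimal, self-contained sketch of the cost combination used in Example #1. The names (toy_state_cost, balance) are illustrative and not part of allennlp-semparse; `balance` stands in for the value returned by `ChecklistStatelet.get_balance()`, i.e. how far each agenda action still is from its target count.

import torch

def toy_state_cost(balance: torch.Tensor,
                   checklist_target: torch.Tensor,
                   denotation_correct: bool,
                   checklist_cost_weight: float = 0.6) -> torch.Tensor:
    # Squared error between where the checklist is and where we want it to be.
    checklist_cost = checklist_cost_weight * torch.sum(balance ** 2)
    if denotation_correct:
        return checklist_cost
    # For incorrect denotations, add a cost proportional to the agenda size.
    denotation_cost = torch.sum(checklist_target.float())
    return checklist_cost + (1 - checklist_cost_weight) * denotation_cost

balance = torch.tensor([1.0, 0.0, 1.0])   # two agenda actions still missing
target = torch.tensor([1.0, 1.0, 1.0])    # three actions on the agenda
print(toy_state_cost(balance, target, denotation_correct=False))  # tensor(2.4000)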
Example #2
    def _get_state_cost(self, worlds: List[WikiTablesLanguage], state: CoverageState) -> torch.Tensor:
        if not state.is_finished():
            raise RuntimeError("_get_state_cost() is not defined for unfinished states!")
        world = worlds[state.batch_indices[0]]

        # Our checklist cost is a sum of squared error from where we want to be, making sure we
        # take into account the mask. We clamp the lower limit of the balance at 0 to avoid
        # penalizing agenda actions produced multiple times.
        checklist_balance = torch.clamp(state.checklist_state[0].get_balance(), min=0.0)
        checklist_cost = torch.sum((checklist_balance) ** 2)

        # This is the number of items on the agenda that we want to see in the decoded sequence.
        # We use this as the denotation cost if the path is incorrect.
        denotation_cost = torch.sum(state.checklist_state[0].checklist_target.float())
        checklist_cost = self._checklist_cost_weight * checklist_cost
        action_history = state.action_history[0]
        batch_index = state.batch_indices[0]
        action_strings = [state.possible_actions[batch_index][i][0] for i in action_history]
        target_values = state.extras[batch_index]
        # Silence the executor's logger so evaluation noise doesn't flood the output.
        executor_logger = logging.getLogger(
            'allennlp_semparse.domain_languages.wikitables_language')
        executor_logger.setLevel(logging.ERROR)
        evaluation = world.evaluate_action_sequence(action_strings, target_values)
        if evaluation:
            cost = checklist_cost
        else:
            cost = checklist_cost + (1 - self._checklist_cost_weight) * denotation_cost
        return cost
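
The main substantive difference from Example #1 is the clamp on the balance. A tiny illustration (not library code) of what it changes: an over-produced agenda action drives its balance negative, and clamping at zero keeps that from adding to the checklist cost.

import torch

balance = torch.tensor([-1.0, 1.0, 0.0])  # over-produced, missing, satisfied
cost_unclamped = torch.sum(balance ** 2)                       # tensor(2.)
cost_clamped = torch.sum(torch.clamp(balance, min=0.0) ** 2)   # tensor(1.)
print(cost_unclamped, cost_clamped)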
Example #3
    def _compute_action_probabilities(
        self,  # type: ignore
        state: CoverageState,
        hidden_state: torch.Tensor,
        attention_weights: torch.Tensor,
        predicted_action_embeddings: torch.Tensor
    ) -> Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]]:
        # In this section we take our predicted action embedding and compare it to the available
        # actions in our current state (which might be different for each group element).  For
        # computing action scores, we'll forget about doing batched / grouped computation, as it
        # adds too much complexity and doesn't speed things up, anyway, with the operations we're
        # doing here.  This means we don't need any action masks, as we'll only get the right
        # lengths for what we're computing.

        group_size = len(state.batch_indices)
        actions = state.get_valid_actions()

        batch_results: Dict[int, List[Tuple[int, Any, Any, Any,
                                            List[int]]]] = defaultdict(list)
        for group_index in range(group_size):
            instance_actions = actions[group_index]
            predicted_action_embedding = predicted_action_embeddings[
                group_index]
            action_embeddings, output_action_embeddings, action_ids = instance_actions[
                'global']

            # This embedding addition is the only difference between the logic here and the
            # corresponding logic in the superclass.
            embedding_addition = self._get_predicted_embedding_addition(
                state.checklist_state[group_index], action_ids,
                action_embeddings)
            addition = embedding_addition * self._checklist_multiplier
            predicted_action_embedding = predicted_action_embedding + addition

            # This is just a matrix product between a (num_actions, embedding_dim) matrix and an
            # (embedding_dim, 1) matrix.
            action_logits = action_embeddings.mm(
                predicted_action_embedding.unsqueeze(-1)).squeeze(-1)
            current_log_probs = torch.nn.functional.log_softmax(action_logits,
                                                                dim=-1)

            # This is now the total score for each state after taking each action.  We're going to
            # sort by this later, so it's important that this is the total score, not just the
            # score for the current action.
            log_probs = state.score[group_index] + current_log_probs
            batch_results[state.batch_indices[group_index]].append(
                (group_index, log_probs, current_log_probs,
                 output_action_embeddings, action_ids))
        return batch_results
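
A stripped-down sketch of the scoring step in Example #3, with made-up shapes: each candidate action embedding is dotted with the (checklist-biased) predicted action embedding, the logits are normalized with a log-softmax, and the result is added to the running score of the partial sequence.

import torch

num_actions, embedding_dim = 4, 8
action_embeddings = torch.randn(num_actions, embedding_dim)
predicted_action_embedding = torch.randn(embedding_dim)
state_score = torch.tensor(-1.5)  # log-prob of the partial action sequence so far

# (num_actions, embedding_dim) x (embedding_dim, 1) -> (num_actions,)
action_logits = action_embeddings.mm(predicted_action_embedding.unsqueeze(-1)).squeeze(-1)
current_log_probs = torch.nn.functional.log_softmax(action_logits, dim=-1)
# Total score of each successor state, which is what beam search later sorts by.
log_probs = state_score + current_log_probs
print(log_probs)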
Example #4
    def _get_state_info(
        self, state: CoverageState, batch_worlds: List[List[NlvrLanguage]]
    ) -> Dict[str, List]:
        """
        This method is here for debugging purposes, in case you want to look at what the model
        is learning. It may be inefficient to call it while training the model on real data.
        """
        if len(state.batch_indices) == 1 and state.is_finished():
            costs = [float(self._get_state_cost(batch_worlds, state).detach().cpu().numpy())]
        else:
            costs = []
        model_scores = [float(score.detach().cpu().numpy()) for score in state.score]
        all_actions = state.possible_actions[0]
        action_sequences = [
            [self._get_action_string(all_actions[action]) for action in history]
            for history in state.action_history
        ]
        agenda_sequences = []
        all_agenda_indices = []
        for checklist_state in state.checklist_state:
            agenda_indices = []
            for action, is_wanted in zip(
                checklist_state.terminal_actions, checklist_state.checklist_target
            ):
                action_int = int(action.detach().cpu().numpy())
                is_wanted_int = int(is_wanted.detach().cpu().numpy())
                if is_wanted_int != 0:
                    agenda_indices.append(action_int)
            agenda_sequences.append(
                [self._get_action_string(all_actions[action]) for action in agenda_indices]
            )
            all_agenda_indices.append(agenda_indices)
        return {
            "agenda": agenda_sequences,
            "agenda_indices": all_agenda_indices,
            "history": action_sequences,
            "history_indices": state.action_history,
            "costs": costs,
            "scores": model_scores,
        }
Example #5
    def forward(
            self,  # type: ignore
            sentence: Dict[str, torch.LongTensor],
            worlds: List[List[NlvrLanguage]],
            actions: List[List[ProductionRule]],
            agenda: torch.LongTensor,
            identifier: List[str] = None,
            labels: torch.LongTensor = None,
            epoch_num: List[int] = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Decoder logic for producing type-constrained target sequences that maximize coverage of
        their respective agendas, and minimize a denotation-based loss.
        """
        # We look at the epoch number and adjust the checklist cost weight if needed here.
        instance_epoch_num = epoch_num[0] if epoch_num is not None else None
        if self._dynamic_cost_rate is not None:
            if self.training and instance_epoch_num is None:
                raise RuntimeError(
                    "If you want a dynamic cost weight, use the "
                    "BucketIterator with track_epoch=True.")
            if instance_epoch_num != self._last_epoch_in_forward:
                if instance_epoch_num >= self._dynamic_cost_wait_epochs:
                    decrement = self._checklist_cost_weight * self._dynamic_cost_rate
                    self._checklist_cost_weight -= decrement
                    logger.info("Checklist cost weight is now %f",
                                self._checklist_cost_weight)
                self._last_epoch_in_forward = instance_epoch_num
        batch_size = len(worlds)

        initial_rnn_state = self._get_initial_rnn_state(sentence)
        initial_score_list = [
            next(iter(sentence.values())).new_zeros(1, dtype=torch.float)
            for i in range(batch_size)
        ]
        # TODO (pradeep): Assuming all worlds give the same set of valid actions.
        initial_grammar_state = [
            self._create_grammar_state(worlds[i][0], actions[i])
            for i in range(batch_size)
        ]

        label_strings = self._get_label_strings(
            labels) if labels is not None else None
        # Each instance's agenda is of size (agenda_size, 1)
        # TODO(mattg): It looks like the agenda is only ever used on the CPU.  In that case, it's a
        # waste to copy it to the GPU and then back, and this should probably be a MetadataField.
        agenda_list = [agenda[i] for i in range(batch_size)]
        initial_checklist_states = []
        for instance_actions, instance_agenda in zip(actions, agenda_list):
            checklist_info = self._get_checklist_info(instance_agenda,
                                                      instance_actions)
            checklist_target, terminal_actions, checklist_mask = checklist_info

            initial_checklist = checklist_target.new_zeros(
                checklist_target.size())
            initial_checklist_states.append(
                ChecklistStatelet(terminal_actions=terminal_actions,
                                  checklist_target=checklist_target,
                                  checklist_mask=checklist_mask,
                                  checklist=initial_checklist))
        initial_state = CoverageState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            possible_actions=actions,
            extras=label_strings,
            checklist_state=initial_checklist_states)
        if not self.training:
            initial_state.debug_info = [[] for _ in range(batch_size)]

        agenda_data = [agenda_[:, 0].cpu().data for agenda_ in agenda_list]
        outputs = self._decoder_trainer.decode(
            initial_state,  # type: ignore
            self._decoder_step,
            partial(self._get_state_cost, worlds))
        if identifier is not None:
            outputs['identifier'] = identifier
        best_final_states = outputs['best_final_states']
        best_action_sequences = {}
        for batch_index, states in best_final_states.items():
            best_action_sequences[batch_index] = [
                state.action_history[0] for state in states
            ]
        batch_action_strings = self._get_action_strings(
            actions, best_action_sequences)
        batch_denotations = self._get_denotations(batch_action_strings, worlds)
        if labels is not None:
            # We're either training or validating.
            self._update_metrics(action_strings=batch_action_strings,
                                 worlds=worlds,
                                 label_strings=label_strings,
                                 possible_actions=actions,
                                 agenda_data=agenda_data)
        else:
            # We're testing.
            if metadata is not None:
                outputs["sentence_tokens"] = [
                    x["sentence_tokens"] for x in metadata
                ]
            outputs['debug_info'] = []
            for i in range(batch_size):
                outputs['debug_info'].append(
                    best_final_states[i][0].debug_info[0])  # type: ignore
            outputs["best_action_strings"] = batch_action_strings
            outputs["denotations"] = batch_denotations
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            outputs['action_mapping'] = action_mapping
        return outputs
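
The epoch-based adjustment at the top of this forward() amounts to a geometric decay of the checklist cost weight once the wait period has passed. A toy simulation of that schedule (standalone, no AllenNLP dependencies; the starting values are made up):

checklist_cost_weight = 0.8
dynamic_cost_rate = 0.1
dynamic_cost_wait_epochs = 2

for epoch_num in range(6):
    if epoch_num >= dynamic_cost_wait_epochs:
        decrement = checklist_cost_weight * dynamic_cost_rate
        checklist_cost_weight -= decrement
    print(epoch_num, round(checklist_cost_weight, 4))
# 0 0.8, 1 0.8, 2 0.72, 3 0.648, 4 0.5832, 5 0.5249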
Example #6
    def forward(
        self,  # type: ignore
        question: Dict[str, torch.LongTensor],
        table: Dict[str, torch.LongTensor],
        world: List[WikiTablesLanguage],
        actions: List[List[ProductionRule]],
        agenda: torch.LongTensor,
        target_values: List[List[str]] = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
           The output of ``TextField.as_array()`` applied on the question ``TextField``. This will
           be passed through a ``TextFieldEmbedder`` and then through an encoder.
        table : ``Dict[str, torch.LongTensor]``
            The output of ``KnowledgeGraphField.as_array()`` applied on the table
            ``KnowledgeGraphField``.  This output is similar to a ``TextField`` output, where each
            entity in the table is treated as a "token", and we will use a ``TextFieldEmbedder`` to
            get embeddings for each entity.
        world : ``List[WikiTablesLanguage]``
            We use a ``MetadataField`` to get the ``WikiTablesLanguage`` object for each input
            instance. Because of how ``MetadataField`` works, this gets passed to us as a
            ``List[WikiTablesLanguage]``.
        actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``world`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        agenda : ``torch.LongTensor``
            Agenda vectors that the checklist vectors will be compared against to compute the checklist
            cost.
        target_values : ``List[List[str]]``, optional (default = None)
            For each instance, a list of target values taken from the example lisp string. We pass
            this list to the evaluator along with logical forms to compute denotation accuracy.
        metadata : ``List[Dict[str, Any]]``, optional (default = None)
            Metadata containing the original tokenized question within a 'question_tokens' field.
        """
        # Each instance's agenda is of size (agenda_size, 1)
        agenda_list = [a for a in agenda]
        checklist_states = []
        all_terminal_productions = [
            set(instance_world.terminal_productions.values())
            for instance_world in world
        ]
        max_num_terminals = max(
            [len(terminals) for terminals in all_terminal_productions])
        for instance_actions, instance_agenda, terminal_productions in zip(
                actions, agenda_list, all_terminal_productions):
            checklist_info = self._get_checklist_info(instance_agenda,
                                                      instance_actions,
                                                      terminal_productions,
                                                      max_num_terminals)
            checklist_target, terminal_actions, checklist_mask = checklist_info
            initial_checklist = checklist_target.new_zeros(
                checklist_target.size())
            checklist_states.append(
                ChecklistStatelet(
                    terminal_actions=terminal_actions,
                    checklist_target=checklist_target,
                    checklist_mask=checklist_mask,
                    checklist=initial_checklist,
                ))
        outputs: Dict[str, Any] = {}
        rnn_state, grammar_state = self._get_initial_rnn_and_grammar_state(
            question, table, world, actions, outputs)

        batch_size = len(rnn_state)
        initial_score = rnn_state[0].hidden_state.new_zeros(batch_size)
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        initial_state = CoverageState(
            batch_indices=list(range(batch_size)),  # type: ignore
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=rnn_state,
            grammar_state=grammar_state,
            checklist_state=checklist_states,
            possible_actions=actions,
            extras=target_values,
            debug_info=None,
        )

        if target_values is not None:
            logger.warning(f"TARGET VALUES: {target_values}")
            trainer_outputs = self._decoder_trainer.decode(  # type: ignore
                initial_state, self._decoder_step,
                partial(self._get_state_cost, world))
            outputs.update(trainer_outputs)
        else:
            initial_state.debug_info = [[] for _ in range(batch_size)]
            batch_size = len(actions)
            agenda_indices = [actions_[:, 0].cpu().data for actions_ in agenda]
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            best_final_states = self._beam_search.search(
                self._max_decoding_steps,
                initial_state,
                self._decoder_step,
                keep_final_unfinished_states=False,
            )
            for i in range(batch_size):
                in_agenda_ratio = 0.0
                # Decoding may not have terminated with any completed logical forms, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i in best_final_states:
                    action_sequence = best_final_states[i][0].action_history[0]
                    action_strings = [
                        action_mapping[(i, action_index)]
                        for action_index in action_sequence
                    ]
                    instance_possible_actions = actions[i]
                    agenda_actions = []
                    for rule_id in agenda_indices[i]:
                        rule_id = int(rule_id)
                        if rule_id == -1:
                            continue
                        action_string = instance_possible_actions[rule_id][0]
                        agenda_actions.append(action_string)
                    actions_in_agenda = [
                        action in action_strings for action in agenda_actions
                    ]
                    if actions_in_agenda:
                        # Note: This means that when there are no actions on agenda, agenda coverage
                        # will be 0, not 1.
                        in_agenda_ratio = sum(actions_in_agenda) / len(
                            actions_in_agenda)
                self._agenda_coverage(in_agenda_ratio)

            self._compute_validation_outputs(actions, best_final_states, world,
                                             target_values, metadata, outputs)
        return outputs
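
The validation branch above tracks agenda coverage: the fraction of agenda actions that show up in the decoded action sequence, defined as 0 when the agenda is empty. A minimal sketch with hypothetical rule strings:

def agenda_coverage(decoded_actions, agenda_actions):
    hits = [action in decoded_actions for action in agenda_actions]
    return sum(hits) / len(hits) if hits else 0.0

decoded = ["rule_0", "rule_3", "rule_7"]
agenda = ["rule_3", "rule_9"]
print(agenda_coverage(decoded, agenda))  # 0.5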
Example #7
    def forward(
        self,  # type: ignore
        sentence: Dict[str, torch.LongTensor],
        worlds: List[List[NlvrLanguage]],
        actions: List[List[ProductionRule]],
        agenda: torch.LongTensor,
        identifier: List[str] = None,
        labels: torch.LongTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Decoder logic for producing type-constrained target sequences that maximize coverage of
        their respective agendas, and minimize a denotation-based loss.
        """
        if self._dynamic_cost_rate is not None:
            # This could be added back pretty easily with an EpochCallback passed to the Trainer (it
            # just has to set the epoch number on the model, which could then be queried in here).
            logger.warning(
                "Dynamic cost rate functionality was removed in AllenNLP 1.0. If you want this, "
                "use version 0.9.  We will just use the static checklist cost weight."
            )
        batch_size = len(worlds)

        initial_rnn_state = self._get_initial_rnn_state(sentence)
        initial_score_list = [agenda.new_zeros(1, dtype=torch.float) for i in range(batch_size)]
        # TODO (pradeep): Assuming all worlds give the same set of valid actions.
        initial_grammar_state = [
            self._create_grammar_state(worlds[i][0], actions[i]) for i in range(batch_size)
        ]

        label_strings = self._get_label_strings(labels) if labels is not None else None
        # Each instance's agenda is of size (agenda_size, 1)
        # TODO(mattg): It looks like the agenda is only ever used on the CPU.  In that case, it's a
        # waste to copy it to the GPU and then back, and this should probably be a MetadataField.
        agenda_list = [agenda[i] for i in range(batch_size)]
        initial_checklist_states = []
        for instance_actions, instance_agenda in zip(actions, agenda_list):
            checklist_info = self._get_checklist_info(instance_agenda, instance_actions)
            checklist_target, terminal_actions, checklist_mask = checklist_info

            initial_checklist = checklist_target.new_zeros(checklist_target.size())
            initial_checklist_states.append(
                ChecklistStatelet(
                    terminal_actions=terminal_actions,
                    checklist_target=checklist_target,
                    checklist_mask=checklist_mask,
                    checklist=initial_checklist,
                )
            )
        initial_state = CoverageState(
            batch_indices=list(range(batch_size)),
            action_history=[[] for _ in range(batch_size)],
            score=initial_score_list,
            rnn_state=initial_rnn_state,
            grammar_state=initial_grammar_state,
            possible_actions=actions,
            extras=label_strings,
            checklist_state=initial_checklist_states,
        )
        if not self.training:
            initial_state.debug_info = [[] for _ in range(batch_size)]

        agenda_data = [agenda_[:, 0].cpu().data for agenda_ in agenda_list]
        outputs = self._decoder_trainer.decode(  # type: ignore
            initial_state, self._decoder_step, partial(self._get_state_cost, worlds)
        )
        if identifier is not None:
            outputs["identifier"] = identifier
        best_final_states = outputs["best_final_states"]
        best_action_sequences = {}
        for batch_index, states in best_final_states.items():
            best_action_sequences[batch_index] = [state.action_history[0] for state in states]
        batch_action_strings = self._get_action_strings(actions, best_action_sequences)
        batch_denotations = self._get_denotations(batch_action_strings, worlds)
        if labels is not None:
            # We're either training or validating.
            self._update_metrics(
                action_strings=batch_action_strings,
                worlds=worlds,
                label_strings=label_strings,
                possible_actions=actions,
                agenda_data=agenda_data,
            )
        else:
            # We're testing.
            if metadata is not None:
                outputs["sentence_tokens"] = [x["sentence_tokens"] for x in metadata]
            outputs["debug_info"] = []
            for i in range(batch_size):
                outputs["debug_info"].append(best_final_states[i][0].debug_info[0])  # type: ignore
            outputs["best_action_strings"] = batch_action_strings
            outputs["denotations"] = batch_denotations
            action_mapping = {}
            for batch_index, batch_actions in enumerate(actions):
                for action_index, action in enumerate(batch_actions):
                    action_mapping[(batch_index, action_index)] = action[0]
            outputs["action_mapping"] = action_mapping
        return outputs
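
A toy illustration of the `action_mapping` dictionary built in the test branch above. Each ProductionRule behaves like a tuple whose first element is the rule string, so the mapping lets debug output translate (batch_index, action_index) pairs back into readable rules. The rule strings below are placeholders.

actions = [  # one inner list per instance; plain tuples stand in for ProductionRule objects
    [("rule_a",), ("rule_b",)],
    [("rule_a",), ("rule_c",)],
]
action_mapping = {}
for batch_index, batch_actions in enumerate(actions):
    for action_index, action in enumerate(batch_actions):
        action_mapping[(batch_index, action_index)] = action[0]
print(action_mapping[(1, 1)])  # rule_c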
Example #8
    def _compute_action_probabilities(
        self,  # type: ignore
        state: CoverageState,
        hidden_state: torch.Tensor,
        attention_weights: torch.Tensor,
        predicted_action_embeddings: torch.Tensor
    ) -> Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]]:
        # In this section we take our predicted action embedding and compare it to the available
        # actions in our current state (which might be different for each group element).  For
        # computing action scores, we'll forget about doing batched / grouped computation, as it
        # adds too much complexity and doesn't speed things up, anyway, with the operations we're
        # doing here.  This means we don't need any action masks, as we'll only get the right
        # lengths for what we're computing.

        group_size = len(state.batch_indices)
        actions = state.get_valid_actions()

        batch_results: Dict[int, List[Tuple[int, Any, Any, Any,
                                            List[int]]]] = defaultdict(list)
        for group_index in range(group_size):
            instance_actions = actions[group_index]
            predicted_action_embedding = predicted_action_embeddings[
                group_index]
            action_ids: List[int] = []
            if "global" in instance_actions:
                action_embeddings, output_action_embeddings, embedded_actions = instance_actions[
                    'global']

                # This embedding addition is the only difference between the logic here and the
                # corresponding logic in the superclass.
                embedding_addition = self._get_predicted_embedding_addition(
                    state.checklist_state[group_index], embedded_actions,
                    action_embeddings)
                addition = embedding_addition * self._checklist_multiplier
                predicted_action_embedding = predicted_action_embedding + addition

                # This is just a matrix product between a (num_actions, embedding_dim) matrix and an
                # (embedding_dim, 1) matrix.
                embedded_action_logits = action_embeddings.mm(
                    predicted_action_embedding.unsqueeze(-1)).squeeze(-1)
                action_ids += embedded_actions
            else:
                embedded_action_logits = None
                output_action_embeddings = None

            if 'linked' in instance_actions:
                linking_scores, type_embeddings, linked_actions = instance_actions[
                    'linked']
                action_ids += linked_actions
                # (num_question_tokens, 1)
                linked_action_logits = linking_scores.mm(
                    attention_weights[group_index].unsqueeze(-1)).squeeze(-1)

                linked_logits_addition = self._get_linked_logits_addition(
                    state.checklist_state[group_index], linked_actions,
                    linked_action_logits)

                addition = linked_logits_addition * self._linked_checklist_multiplier
                linked_action_logits = linked_action_logits + addition

                # The `output_action_embeddings` tensor gets used later as the input to the next
                # decoder step.  For linked actions, we don't have any action embedding, so we use
                # the entity type instead.
                if output_action_embeddings is None:
                    output_action_embeddings = type_embeddings
                else:
                    output_action_embeddings = torch.cat(
                        [output_action_embeddings, type_embeddings], dim=0)

                if self._mixture_feedforward is not None:
                    # The linked and global logits are combined with a mixture weight to prevent the
                    # linked_action_logits from dominating the embedded_action_logits if a softmax
                    # was applied on both together.
                    mixture_weight = self._mixture_feedforward(
                        hidden_state[group_index])
                    mix1 = torch.log(mixture_weight)
                    mix2 = torch.log(1 - mixture_weight)

                    entity_action_probs = torch.nn.functional.log_softmax(
                        linked_action_logits, dim=-1) + mix1
                    if embedded_action_logits is None:
                        current_log_probs = entity_action_probs
                    else:
                        embedded_action_probs = torch.nn.functional.log_softmax(
                            embedded_action_logits, dim=-1) + mix2
                        current_log_probs = torch.cat(
                            [embedded_action_probs, entity_action_probs],
                            dim=-1)
                else:
                    if embedded_action_logits is None:
                        current_log_probs = torch.nn.functional.log_softmax(
                            linked_action_logits, dim=-1)
                    else:
                        action_logits = torch.cat(
                            [embedded_action_logits, linked_action_logits],
                            dim=-1)
                        current_log_probs = torch.nn.functional.log_softmax(
                            action_logits, dim=-1)
            else:
                current_log_probs = torch.nn.functional.log_softmax(
                    embedded_action_logits, dim=-1)

            # This is now the total score for each state after taking each action.  We're going to
            # sort by this later, so it's important that this is the total score, not just the
            # score for the current action.
            log_probs = state.score[group_index] + current_log_probs
            batch_results[state.batch_indices[group_index]].append(
                (group_index, log_probs, current_log_probs,
                 output_action_embeddings, action_ids))
        return batch_results
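
When `_mixture_feedforward` is present, the global and linked action scores are each normalized separately and then shifted by the log of the mixture weight, so together they still form one proper distribution over all candidate actions. A self-contained sketch with made-up numbers:

import torch
import torch.nn.functional as F

embedded_action_logits = torch.tensor([2.0, 0.5])      # scores for global (embedded) actions
linked_action_logits = torch.tensor([1.0, 1.0, 0.0])   # scores for linked (entity) actions
mixture_weight = torch.tensor(0.3)                     # probability mass given to linked actions

entity_log_probs = F.log_softmax(linked_action_logits, dim=-1) + torch.log(mixture_weight)
embedded_log_probs = F.log_softmax(embedded_action_logits, dim=-1) + torch.log(1 - mixture_weight)
current_log_probs = torch.cat([embedded_log_probs, entity_log_probs], dim=-1)
print(torch.exp(current_log_probs).sum())  # ~1.0: a single distribution over all actions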