def _compute_action_probabilities(
        self,
        state: GrammarBasedState,
        hidden_state: torch.Tensor,
        attention_weights: torch.Tensor,
        predicted_action_embeddings: torch.Tensor,
    ) -> Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]]:
        """
        Score every valid global action for each element of the group.

        ``hidden_state`` and ``attention_weights`` are not used in this implementation; they
        are part of the signature because subclasses might use them.

        Each group element is scored on its own rather than with a batched / grouped
        computation.  The valid actions can differ between group elements, and doing them one
        at a time means no action masks are needed, because every tensor already has exactly
        the right length.

        Returns a mapping from batch index to a list of tuples
        ``(group_index, total_log_probs, current_log_probs, output_action_embeddings,
        action_ids)``, one tuple per group element that belongs to that batch instance.
        ``total_log_probs`` is the state's running score plus ``current_log_probs``.
        """
        valid_actions_per_element = state.get_valid_actions()

        grouped: Dict[int, List[Tuple[int, Any, Any, Any,
                                      List[int]]]] = defaultdict(list)
        for group_index, batch_index in enumerate(state.batch_indices):
            embeddings, next_input_embeddings, global_action_ids = \
                valid_actions_per_element[group_index]["global"]
            query = predicted_action_embeddings[group_index].unsqueeze(-1)

            # (num_actions, embedding_dim) x (embedding_dim, 1) -> (num_actions,)
            raw_scores = torch.mm(embeddings, query).squeeze(-1)
            step_log_probs = torch.nn.functional.log_softmax(raw_scores, dim=-1)

            # Candidates are sorted by this value later, so it must be the total score after
            # taking each action, not only the score of the current step.
            total_log_probs = state.score[group_index] + step_log_probs

            entry = (group_index, total_log_probs, step_log_probs,
                     next_input_embeddings, global_action_ids)
            grouped[batch_index].append(entry)
        return grouped
    def _compute_action_probabilities(
        self,
        state: GrammarBasedState,
        hidden_state: torch.Tensor,
        attention_weights: torch.Tensor,
        predicted_action_embeddings: torch.Tensor,
    ) -> Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]]:
        """
        Score every valid action -- embedded ("global") and linked -- for each group element.

        Global actions are scored with a matrix product between their embeddings and the
        predicted action embedding.  Linked actions are scored by multiplying their linking
        scores with the group element's attention weights.  When both kinds are present and
        ``self._mixture_feedforward`` is set, each kind is normalized separately and the two are
        combined with a learned mixture weight.  Otherwise a single log-softmax is taken over
        all actions.

        Group elements are processed one at a time instead of in a batched / grouped
        computation.  The valid actions can differ per element, and this way no action masks
        are needed.

        Returns
        -------
        A mapping from batch index to a list of tuples ``(group_index, log_probs,
        current_log_probs, output_action_embeddings, action_ids)``.  ``log_probs`` is the
        state's running score plus ``current_log_probs`` (the log-probability of each action at
        this step).  ``action_ids`` lists global actions first and linked actions second, in the
        same order as the scores and the rows of ``output_action_embeddings``.

        Raises
        ------
        ValueError
            If a group element has neither "global" nor "linked" actions.
        """
        group_size = len(state.batch_indices)
        actions = state.get_valid_actions()

        batch_results: Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]] = defaultdict(list)
        for group_index in range(group_size):
            instance_actions = actions[group_index]
            has_global = "global" in instance_actions
            has_linked = "linked" in instance_actions
            if not (has_global or has_linked):
                raise ValueError(
                    f"Group element {group_index} has no valid 'global' or 'linked' actions."
                )
            predicted_action_embedding = predicted_action_embeddings[group_index]

            action_ids: List[int] = []
            output_action_embeddings = None
            embedded_action_logits = None

            if has_global:
                action_embeddings, output_action_embeddings, embedded_actions = instance_actions[
                    "global"
                ]
                # This is just a matrix product between a (num_actions, embedding_dim) matrix and an
                # (embedding_dim, 1) matrix.
                embedded_action_logits = action_embeddings.mm(
                    predicted_action_embedding.unsqueeze(-1)
                ).squeeze(-1)
                action_ids = embedded_actions

            if has_linked:
                linking_scores, type_embeddings, linked_actions = instance_actions["linked"]
                # Global actions come first, then linked actions, matching the score order below.
                action_ids = action_ids + linked_actions
                # linking_scores: (num_entities, num_question_tokens)
                # linked_action_logits: (num_entities,)  (after squeezing the trailing 1)
                linked_action_logits = linking_scores.mm(
                    attention_weights[group_index].unsqueeze(-1)
                ).squeeze(-1)

                # The `output_action_embeddings` tensor gets used later as the input to the next
                # decoder step.  For linked actions, we don't have any action embedding, so we use
                # the entity type instead.
                if output_action_embeddings is None:
                    output_action_embeddings = type_embeddings
                else:
                    output_action_embeddings = torch.cat(
                        [output_action_embeddings, type_embeddings], dim=0
                    )

                if embedded_action_logits is None:
                    # Only linked actions are available, so there is nothing to mix with: the
                    # entity distribution is the whole distribution.  Adding the log mixture
                    # weight here would leave step probabilities summing to less than one.
                    current_log_probs = torch.nn.functional.log_softmax(
                        linked_action_logits, dim=-1
                    )
                elif self._mixture_feedforward is not None:
                    # The linked and global logits are combined with a mixture weight to prevent the
                    # linked_action_logits from dominating the embedded_action_logits if a softmax
                    # was applied on both together.
                    # NOTE(review): assumes the feedforward outputs a value strictly in (0, 1)
                    # (e.g. a sigmoid) -- TODO confirm; 0 or 1 would produce -inf below.
                    mixture_weight = self._mixture_feedforward(hidden_state[group_index])
                    log_linked_weight = torch.log(mixture_weight)
                    log_embedded_weight = torch.log(1 - mixture_weight)

                    entity_action_probs = (
                        torch.nn.functional.log_softmax(linked_action_logits, dim=-1)
                        + log_linked_weight
                    )
                    embedded_action_probs = (
                        torch.nn.functional.log_softmax(embedded_action_logits, dim=-1)
                        + log_embedded_weight
                    )
                    current_log_probs = torch.cat(
                        [embedded_action_probs, entity_action_probs], dim=-1
                    )
                else:
                    action_logits = torch.cat(
                        [embedded_action_logits, linked_action_logits], dim=-1
                    )
                    current_log_probs = torch.nn.functional.log_softmax(action_logits, dim=-1)
            else:
                current_log_probs = torch.nn.functional.log_softmax(
                    embedded_action_logits, dim=-1
                )

            # This is now the total score for each state after taking each action.  We're going to
            # sort by this later, so it's important that this is the total score, not just the
            # score for the current action.
            log_probs = state.score[group_index] + current_log_probs
            batch_results[state.batch_indices[group_index]].append(
                (group_index, log_probs, current_log_probs, output_action_embeddings, action_ids)
            )
        return batch_results