コード例 #1
0
    def _compute_action_probabilities(
        self,  # type: ignore
        state: CoverageState,
        hidden_state: torch.Tensor,
        attention_weights: torch.Tensor,
        predicted_action_embeddings: torch.Tensor
    ) -> Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]]:
        """
        Score every valid global action of each group element, with a checklist-based bias.

        Each group element is scored on its own rather than in one batched / grouped
        computation. Elements may have different sets of valid actions, and handling them
        separately means no action masks are needed while costing nothing extra for the
        operations done here. ``hidden_state`` and ``attention_weights`` belong to the shared
        signature but are not read by this implementation.

        Returns a mapping from batch index to a list of
        ``(group_index, log_probs, current_log_probs, output_action_embeddings, action_ids)``
        tuples. ``log_probs`` is the state's running score plus this step's log-probabilities.
        Later sorting relies on it being the cumulative score, not only the current step's.
        """
        valid_actions = state.get_valid_actions()
        results: Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]] = defaultdict(list)

        for group_index, batch_index in enumerate(state.batch_indices):
            action_embeddings, output_embeddings, action_ids = valid_actions[group_index]['global']

            # Shifting the predicted embedding by the checklist-driven term is the only way this
            # scoring differs from the parent class's logic.
            checklist_shift = self._get_predicted_embedding_addition(
                state.checklist_state[group_index], action_ids, action_embeddings)
            query = predicted_action_embeddings[group_index] + checklist_shift * self._checklist_multiplier

            # (num_actions, embedding_dim) @ (embedding_dim, 1) -> one logit per action.
            logits = action_embeddings.mm(query.unsqueeze(-1)).squeeze(-1)
            step_log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

            cumulative_log_probs = state.score[group_index] + step_log_probs
            results[batch_index].append(
                (group_index, cumulative_log_probs, step_log_probs, output_embeddings, action_ids))

        return results
    def _compute_action_probabilities(
        self,  # type: ignore
        state: CoverageState,
        hidden_state: torch.Tensor,
        attention_weights: torch.Tensor,
        predicted_action_embeddings: torch.Tensor
    ) -> Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]]:
        """
        Score the valid global and linked actions of every group element, applying checklist
        (coverage) biases to both kinds of action.

        Scoring is done one group element at a time rather than as a batched / grouped
        computation. Each element can have a different set of valid actions, so handling them
        separately avoids action masks and does not slow down the operations done here.

        For each group element:

        * Global actions (``instance_actions['global']``) get logits from a matrix product
          between their embeddings and the predicted action embedding. Before the product, that
          embedding is shifted by ``self._get_predicted_embedding_addition`` scaled by
          ``self._checklist_multiplier``.
        * Linked actions (``instance_actions['linked']``) get logits by multiplying their
          linking scores with this element's ``attention_weights``. The logits are then shifted
          by ``self._get_linked_logits_addition`` scaled by
          ``self._linked_checklist_multiplier``.
        * The two logit sets become log-probabilities in one of two ways. Without a mixture,
          a single log-softmax runs over their concatenation. When ``self._mixture_feedforward``
          is set, two separate log-softmaxes are mixed, weighted by that feedforward applied to
          ``hidden_state[group_index]``.

        Global actions always come before linked actions, so ``action_ids`` and
        ``current_log_probs`` line up position-by-position. ``output_action_embeddings`` is
        concatenated in the same (global, then linked) order.

        Returns
        -------
        A mapping from batch index to a list of
        ``(group_index, log_probs, current_log_probs, output_action_embeddings, action_ids)``
        tuples. ``log_probs`` is the state's running score plus ``current_log_probs``. Later
        sorting relies on it being the cumulative score, not only this step's score.

        Raises
        ------
        ValueError
            If a group element has neither ``'global'`` nor ``'linked'`` valid actions, since
            there would be nothing to score.
        """
        group_size = len(state.batch_indices)
        actions = state.get_valid_actions()

        batch_results: Dict[int, List[Tuple[int, Any, Any, Any,
                                            List[int]]]] = defaultdict(list)
        for group_index in range(group_size):
            instance_actions = actions[group_index]
            has_global = 'global' in instance_actions
            has_linked = 'linked' in instance_actions
            if not (has_global or has_linked):
                # Without this check the code below would call log_softmax on None and fail with
                # an error that does not point at the actual problem.
                raise ValueError(f"Group element {group_index} (batch index "
                                 f"{state.batch_indices[group_index]}) has no valid "
                                 f"'global' or 'linked' actions to score.")

            predicted_action_embedding = predicted_action_embeddings[group_index]
            action_ids: List[int] = []
            embedded_action_logits = None
            linked_action_logits = None
            output_action_embeddings = None

            if has_global:
                action_embeddings, output_action_embeddings, embedded_actions = instance_actions[
                    'global']

                # This embedding addition is the only difference between the logic here and the
                # corresponding logic in the super class.
                embedding_addition = self._get_predicted_embedding_addition(
                    state.checklist_state[group_index], embedded_actions,
                    action_embeddings)
                addition = embedding_addition * self._checklist_multiplier
                predicted_action_embedding = predicted_action_embedding + addition

                # A matrix product between a (num_actions, embedding_dim) matrix and an
                # (embedding_dim, 1) matrix, giving one logit per global action.
                embedded_action_logits = action_embeddings.mm(
                    predicted_action_embedding.unsqueeze(-1)).squeeze(-1)
                action_ids += embedded_actions

            if has_linked:
                linking_scores, type_embeddings, linked_actions = instance_actions['linked']
                action_ids += linked_actions
                # attention_weights[group_index].unsqueeze(-1) is (num_question_tokens, 1).
                linked_action_logits = linking_scores.mm(
                    attention_weights[group_index].unsqueeze(-1)).squeeze(-1)

                linked_logits_addition = self._get_linked_logits_addition(
                    state.checklist_state[group_index], linked_actions,
                    linked_action_logits)
                addition = linked_logits_addition * self._linked_checklist_multiplier
                linked_action_logits = linked_action_logits + addition

                # `output_action_embeddings` is used later as the input to the next decoder step.
                # Linked actions have no action embedding, so the entity type embedding is used
                # in its place.
                if output_action_embeddings is None:
                    output_action_embeddings = type_embeddings
                else:
                    output_action_embeddings = torch.cat(
                        [output_action_embeddings, type_embeddings], dim=0)

            if linked_action_logits is None:
                # Only global actions: a plain log-softmax over their logits.
                current_log_probs = torch.nn.functional.log_softmax(
                    embedded_action_logits, dim=-1)
            elif self._mixture_feedforward is not None:
                # The linked and global logits are combined with a mixture weight. This keeps
                # linked_action_logits from dominating embedded_action_logits, which could
                # happen if one softmax were applied to both together.
                mixture_weight = self._mixture_feedforward(hidden_state[group_index])
                mix1 = torch.log(mixture_weight)
                mix2 = torch.log(1 - mixture_weight)

                entity_action_probs = torch.nn.functional.log_softmax(
                    linked_action_logits, dim=-1) + mix1
                if embedded_action_logits is None:
                    # NOTE(review): with no global actions these probabilities sum to
                    # `mixture_weight` rather than 1. This is kept to match existing behavior;
                    # confirm it is intended.
                    current_log_probs = entity_action_probs
                else:
                    embedded_action_probs = torch.nn.functional.log_softmax(
                        embedded_action_logits, dim=-1) + mix2
                    current_log_probs = torch.cat(
                        [embedded_action_probs, entity_action_probs], dim=-1)
            elif embedded_action_logits is None:
                # Only linked actions, no mixture.
                current_log_probs = torch.nn.functional.log_softmax(
                    linked_action_logits, dim=-1)
            else:
                # Both kinds, no mixture: one softmax over the concatenated logits.
                action_logits = torch.cat(
                    [embedded_action_logits, linked_action_logits], dim=-1)
                current_log_probs = torch.nn.functional.log_softmax(
                    action_logits, dim=-1)

            # This is the total score for each state after taking each action. It is sorted on
            # later, so it must be the cumulative score, not only the current action's score.
            log_probs = state.score[group_index] + current_log_probs
            batch_results[state.batch_indices[group_index]].append(
                (group_index, log_probs, current_log_probs,
                 output_action_embeddings, action_ids))
        return batch_results