Example #1
    def _compute_action_probabilities(
        self, state: GrammarBasedState, hidden_state: torch.Tensor,
        attention_weights: torch.Tensor,
        predicted_action_embeddings: torch.Tensor
    ) -> Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]]:
        # We take a couple of extra arguments here because subclasses might use them.
        # pylint: disable=unused-argument,no-self-use

        # In this section we take our predicted action embedding and compare it to the available
        # actions in our current state (which might be different for each group element).  For
        # computing action scores, we'll forget about doing batched / grouped computation, as it
        # adds too much complexity and doesn't speed things up, anyway, with the operations we're
        # doing here.  This means we don't need any action masks, as we'll only get the right
        # lengths for what we're computing.

        group_size = len(state.batch_indices)
        actions = state.get_valid_actions()

        batch_results: Dict[int, List[Tuple[int, Any, Any, Any,
                                            List[int]]]] = defaultdict(list)
        for group_index in range(group_size):
            instance_actions = actions[group_index]
            predicted_action_embedding = predicted_action_embeddings[
                group_index]
            action_embeddings, output_action_embeddings, action_ids = instance_actions[
                'global']
            # This is just a matrix product between a (num_actions, embedding_dim) matrix and an
            # (embedding_dim, 1) matrix.
            action_logits = action_embeddings.mm(
                predicted_action_embedding.unsqueeze(-1)).squeeze(-1)
            current_log_probs = torch.nn.functional.log_softmax(action_logits,
                                                                dim=-1)

            # This is now the total score for each state after taking each action.  We're going to
            # sort by this later, so it's important that this is the total score, not just the
            # score for the current action.
            log_probs = state.score[group_index] + current_log_probs
            batch_results[state.batch_indices[group_index]].append(
                (group_index, log_probs, current_log_probs,
                 output_action_embeddings, action_ids))
        return batch_results
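
The heart of this example is the scoring step: the decoder's predicted action embedding is compared against every candidate action embedding with a single matrix product, and the resulting logits are normalized with a log-softmax. A minimal standalone sketch of just that computation, with made-up tensor sizes (nothing here comes from the example beyond the two torch calls):

    import torch

    # Illustrative sizes only; real values come from the model's vocabulary and config.
    num_actions, embedding_dim = 5, 8
    action_embeddings = torch.randn(num_actions, embedding_dim)
    predicted_action_embedding = torch.randn(embedding_dim)

    # (num_actions, embedding_dim) @ (embedding_dim, 1) -> (num_actions, 1) -> (num_actions,)
    action_logits = action_embeddings.mm(
        predicted_action_embedding.unsqueeze(-1)).squeeze(-1)
    current_log_probs = torch.nn.functional.log_softmax(action_logits, dim=-1)

    # The normalized scores form a proper distribution over the candidates.
    assert torch.allclose(current_log_probs.exp().sum(), torch.tensor(1.0))
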
Example #2
    def _compute_action_probabilities(self,
                                      state: GrammarBasedState,
                                      hidden_state: torch.Tensor,
                                      attention_weights: torch.Tensor,
                                      predicted_action_embeddings: torch.Tensor
                                     ) -> Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]]:
        # We take a couple of extra arguments here because subclasses might use them.
        # pylint: disable=unused-argument,no-self-use

        # In this section we take our predicted action embedding and compare it to the available
        # actions in our current state (which might be different for each group element).  For
        # computing action scores, we'll forget about doing batched / grouped computation, as it
        # adds too much complexity and doesn't speed things up, anyway, with the operations we're
        # doing here.  This means we don't need any action masks, as we'll only get the right
        # lengths for what we're computing.

        group_size = len(state.batch_indices)
        actions = state.get_valid_actions()

        batch_results: Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]] = defaultdict(list)
        for group_index in range(group_size):
            instance_actions = actions[group_index]
            predicted_action_embedding = predicted_action_embeddings[group_index]
            action_embeddings, output_action_embeddings, action_ids = instance_actions['global']
            # This is just a matrix product between a (num_actions, embedding_dim) matrix and an
            # (embedding_dim, 1) matrix.
            action_logits = action_embeddings.mm(predicted_action_embedding.unsqueeze(-1)).squeeze(-1)
            current_log_probs = torch.nn.functional.log_softmax(action_logits, dim=-1)

            # This is now the total score for each state after taking each action.  We're going to
            # sort by this later, so it's important that this is the total score, not just the
            # score for the current action.
            log_probs = state.score[group_index] + current_log_probs
            batch_results[state.batch_indices[group_index]].append((group_index,
                                                                    log_probs,
                                                                    current_log_probs,
                                                                    output_action_embeddings,
                                                                    action_ids))
        return batch_results
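
One detail worth calling out in this example is `state.score[group_index] + current_log_probs`: the score attached to each candidate is the running log probability of the whole action sequence, which is what later sorting ranks. A small sketch of that accumulation, with invented numbers:

    import torch

    # Invented running score for one group element (log prob of the sequence so far).
    sequence_score = torch.tensor(-1.2)
    # Invented step distribution over three candidate next actions.
    current_log_probs = torch.log(torch.tensor([0.7, 0.2, 0.1]))

    # Adding in log space multiplies the probabilities, so each total scores the
    # whole extended sequence, not just the newest action.
    log_probs = sequence_score + current_log_probs
    print(log_probs.sort(descending=True).values)
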
Example #3
    def _compute_action_probabilities(
        self, state: GrammarBasedState, hidden_state: torch.Tensor,
        attention_weights: torch.Tensor,
        predicted_action_embeddings: torch.Tensor
    ) -> Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]]:
        # In this section we take our predicted action embedding and compare it to the available
        # actions in our current state (which might be different for each group element).  For
        # computing action scores, we'll forget about doing batched / grouped computation, as it
        # adds too much complexity and doesn't speed things up, anyway, with the operations we're
        # doing here.  This means we don't need any action masks, as we'll only get the right
        # lengths for what we're computing.

        group_size = len(state.batch_indices)
        actions = state.get_valid_actions()

        batch_results: Dict[int, List[Tuple[int, Any, Any, Any,
                                            List[int]]]] = defaultdict(list)
        for group_index in range(group_size):
            instance_actions = actions[group_index]
            predicted_action_embedding = predicted_action_embeddings[
                group_index]
            embedded_actions: List[int] = []

            output_action_embeddings = None
            embedded_action_logits = None
            current_log_probs = None

            if 'global' in instance_actions:
                action_embeddings, output_action_embeddings, embedded_actions = instance_actions[
                    'global']
                # This is just a matrix product between a (num_actions, embedding_dim) matrix and an
                # (embedding_dim, 1) matrix.
                embedded_action_logits = action_embeddings.mm(
                    predicted_action_embedding.unsqueeze(-1)).squeeze(-1)
                action_ids = embedded_actions

            if 'linked' in instance_actions:
                linking_scores, type_embeddings, linked_actions = instance_actions[
                    'linked']
                action_ids = embedded_actions + linked_actions
                # linking_scores is (num_linked_actions, num_question_tokens); after the
                # unsqueeze the attention weights are (num_question_tokens, 1), so this
                # gives one logit per linked action.
                linked_action_logits = linking_scores.mm(
                    attention_weights[group_index].unsqueeze(-1)).squeeze(-1)

                # The `output_action_embeddings` tensor gets used later as the input to the next
                # decoder step.  For linked actions, we don't have any action embedding, so we use
                # the entity type instead.
                if output_action_embeddings is not None:
                    output_action_embeddings = torch.cat(
                        [output_action_embeddings, type_embeddings], dim=0)
                else:
                    output_action_embeddings = type_embeddings

                if self._mixture_feedforward is not None:
                    # The linked and global logits are combined with a mixture weight to prevent the
                    # linked_action_logits from dominating the embedded_action_logits if a softmax
                    # was applied on both together.
                    mixture_weight = self._mixture_feedforward(
                        hidden_state[group_index])
                    mix1 = torch.log(mixture_weight)
                    mix2 = torch.log(1 - mixture_weight)

                    entity_action_probs = torch.nn.functional.log_softmax(
                        linked_action_logits, dim=-1) + mix1
                    if embedded_action_logits is not None:
                        embedded_action_probs = torch.nn.functional.log_softmax(
                            embedded_action_logits, dim=-1) + mix2
                        current_log_probs = torch.cat(
                            [embedded_action_probs, entity_action_probs],
                            dim=-1)
                    else:
                        current_log_probs = entity_action_probs
                else:
                    if embedded_action_logits is not None:
                        action_logits = torch.cat(
                            [embedded_action_logits, linked_action_logits],
                            dim=-1)
                    else:
                        action_logits = linked_action_logits
                    current_log_probs = torch.nn.functional.log_softmax(
                        action_logits, dim=-1)
            else:
                action_logits = embedded_action_logits
                current_log_probs = torch.nn.functional.log_softmax(
                    action_logits, dim=-1)

            # This is now the total score for each state after taking each action.  We're going to
            # sort by this later, so it's important that this is the total score, not just the
            # score for the current action.
            log_probs = state.score[group_index] + current_log_probs
            batch_results[state.batch_indices[group_index]].append(
                (group_index, log_probs, current_log_probs,
                 output_action_embeddings, action_ids))
        return batch_results
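
When `self._mixture_feedforward` is present, this example normalizes the linked and global logits separately and shifts them by log(w) and log(1 - w) before concatenating, rather than taking one softmax over everything, so the (often larger) linked logits cannot swamp the global ones. A hedged sketch of just that combination; the mixture weight here is a made-up constant standing in for the feedforward's output:

    import torch
    import torch.nn.functional as F

    embedded_action_logits = torch.randn(4)  # illustrative global-action logits
    linked_action_logits = torch.randn(3)    # illustrative linked-action logits
    mixture_weight = torch.tensor(0.3)       # stand-in for a feedforward output in (0, 1)

    entity_action_probs = (F.log_softmax(linked_action_logits, dim=-1)
                           + torch.log(mixture_weight))
    embedded_action_probs = (F.log_softmax(embedded_action_logits, dim=-1)
                             + torch.log(1 - mixture_weight))
    current_log_probs = torch.cat([embedded_action_probs, entity_action_probs], dim=-1)

    # The two blocks together still form one proper distribution: w + (1 - w) = 1.
    assert torch.allclose(current_log_probs.exp().sum(), torch.tensor(1.0))
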
Example #4
    def _take_first_step(
            self,
            state: GrammarBasedState,
            allowed_actions: List[Set[int]] = None) -> List[GrammarBasedState]:
        # We'll just do a projection from the current hidden state (which was initialized with the
        # final encoder output) to the number of start actions that we have, normalize those
        # logits, and use that as our score.  We end up duplicating some of the logic from
        # `_compute_new_states` here, but we do things slightly differently, and it's easier to
        # just copy the parts we need than to try to re-use that code.

        # (group_size, hidden_dim)
        hidden_state = torch.stack(
            [rnn_state.hidden_state for rnn_state in state.rnn_state])
        # (group_size, num_start_types)
        start_action_logits = self._start_type_predictor(hidden_state)
        log_probs = torch.nn.functional.log_softmax(start_action_logits,
                                                    dim=-1)
        sorted_log_probs, sorted_actions = log_probs.sort(dim=-1,
                                                          descending=True)

        sorted_actions = sorted_actions.detach().cpu().numpy().tolist()
        if state.debug_info is not None:
            probs_cpu = log_probs.exp().detach().cpu().numpy().tolist()
        else:
            probs_cpu = [None] * len(state.batch_indices)

        # state.get_valid_actions() will return a list that is consistently sorted, so as long as
        # the set of valid start actions never changes, we can just match up the log prob indices
        # above with the position of each considered action, and we're good.
        valid_actions = state.get_valid_actions()
        considered_actions = [
            actions['global'][2] for actions in valid_actions
        ]
        if len(considered_actions[0]) != self._num_start_types:
            raise RuntimeError(
                "Calculated wrong number of initial actions.  Expected "
                f"{self._num_start_types}, found {len(considered_actions[0])}."
            )

        best_next_states: Dict[int, List[Tuple[int, int,
                                               int]]] = defaultdict(list)
        for group_index, (batch_index, group_actions) in enumerate(
                zip(state.batch_indices, sorted_actions)):
            for action_index, action in enumerate(group_actions):
                # `action` is currently the index in `log_probs`, not the actual action ID.  To get
                # the action ID, we need to go through `considered_actions`.
                action = considered_actions[group_index][action]
                if allowed_actions is not None and action not in allowed_actions[
                        group_index]:
                    # This happens when our _decoder trainer_ wants us to only evaluate certain
                    # actions, likely because they are the gold actions in this state.  We just skip
                    # emitting any state that isn't allowed by the trainer, because constructing the
                    # new state can be expensive.
                    continue
                best_next_states[batch_index].append(
                    (group_index, action_index, action))

        new_states = []
        for batch_index, best_states in sorted(best_next_states.items()):
            for group_index, action_index, action in best_states:
                # We'll yield a bunch of states here that all have a `group_size` of 1, so that the
                # learning algorithm can decide how many of these it wants to keep, and it can just
                # regroup them later, as that's a really easy operation.
                new_score = state.score[group_index] + sorted_log_probs[
                    group_index, action_index]

                # This part is different from `_compute_new_states` - we're just passing through
                # the previous RNN state, as predicting the start type wasn't included in the
                # decoder RNN in the original model.
                new_rnn_state = state.rnn_state[group_index]
                new_state = state.new_state_from_group_index(
                    group_index, action, new_score, new_rnn_state,
                    considered_actions[group_index], probs_cpu[group_index],
                    None)
                new_states.append(new_state)
        return new_states
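
A subtle point in the loop above is that `sorted_actions` holds positions into the log-prob vector, not action IDs; `considered_actions` is what maps one to the other. A self-contained sketch of that translation, with invented action IDs:

    import torch

    # Invented global action IDs for one group element, in the order the logits
    # were computed.
    considered_actions = [11, 42, 7]
    log_probs = torch.log(torch.tensor([0.2, 0.5, 0.3]))

    sorted_log_probs, sorted_positions = log_probs.sort(dim=-1, descending=True)
    # Positions in the sorted tensor are translated back into real action IDs.
    ranked_action_ids = [considered_actions[i] for i in sorted_positions.tolist()]
    print(ranked_action_ids)  # [42, 7, 11] - best action first
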
Example #5
    def _take_first_step(self,
                         state: GrammarBasedState,
                         allowed_actions: List[Set[int]] = None) -> List[GrammarBasedState]:
        # We'll just do a projection from the current hidden state (which was initialized with the
        # final encoder output) to the number of start actions that we have, normalize those
        # logits, and use that as our score.  We end up duplicating some of the logic from
        # `_compute_new_states` here, but we do things slightly differently, and it's easier to
        # just copy the parts we need than to try to re-use that code.

        # (group_size, hidden_dim)
        hidden_state = torch.stack([rnn_state.hidden_state for rnn_state in state.rnn_state])
        # (group_size, num_start_types)
        start_action_logits = self._start_type_predictor(hidden_state)
        log_probs = torch.nn.functional.log_softmax(start_action_logits, dim=-1)
        sorted_log_probs, sorted_actions = log_probs.sort(dim=-1, descending=True)

        sorted_actions = sorted_actions.detach().cpu().numpy().tolist()
        if state.debug_info is not None:
            probs_cpu = log_probs.exp().detach().cpu().numpy().tolist()
        else:
            probs_cpu = [None] * len(state.batch_indices)

        # state.get_valid_actions() will return a list that is consistently sorted, so as long as
        # the set of valid start actions never changes, we can just match up the log prob indices
        # above with the position of each considered action, and we're good.
        valid_actions = state.get_valid_actions()
        considered_actions = [actions['global'][2] for actions in valid_actions]
        if len(considered_actions[0]) != self._num_start_types:
            raise RuntimeError("Calculated wrong number of initial actions.  Expected "
                               f"{self._num_start_types}, found {len(considered_actions[0])}.")

        best_next_states: Dict[int, List[Tuple[int, int, int]]] = defaultdict(list)
        for group_index, (batch_index, group_actions) in enumerate(zip(state.batch_indices, sorted_actions)):
            for action_index, action in enumerate(group_actions):
                # `action` is currently the index in `log_probs`, not the actual action ID.  To get
                # the action ID, we need to go through `considered_actions`.
                action = considered_actions[group_index][action]
                if allowed_actions is not None and action not in allowed_actions[group_index]:
                    # This happens when our _decoder trainer_ wants us to only evaluate certain
                    # actions, likely because they are the gold actions in this state.  We just skip
                    # emitting any state that isn't allowed by the trainer, because constructing the
                    # new state can be expensive.
                    continue
                best_next_states[batch_index].append((group_index, action_index, action))

        new_states = []
        for batch_index, best_states in sorted(best_next_states.items()):
            for group_index, action_index, action in best_states:
                # We'll yield a bunch of states here that all have a `group_size` of 1, so that the
                # learning algorithm can decide how many of these it wants to keep, and it can just
                # regroup them later, as that's a really easy operation.
                new_score = state.score[group_index] + sorted_log_probs[group_index, action_index]

                # This part is different from `_compute_new_states` - we're just passing through
                # the previous RNN state, as predicting the start type wasn't included in the
                # decoder RNN in the original model.
                new_rnn_state = state.rnn_state[group_index]
                new_state = state.new_state_from_group_index(group_index,
                                                             action,
                                                             new_score,
                                                             new_rnn_state,
                                                             considered_actions[group_index],
                                                             probs_cpu[group_index],
                                                             None)
                new_states.append(new_state)
        return new_states
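
The `allowed_actions` check above is how a decoder trainer constrains the search, e.g. to gold action sequences: disallowed candidates are dropped before the expensive new-state construction. A minimal sketch of the same filtering pattern, with made-up data:

    from collections import defaultdict
    from typing import Dict, List, Set, Tuple

    # Invented (group_index, action_index, action_id) candidates for two group
    # elements that both belong to batch instance 3.
    candidates = [(0, 0, 42), (0, 1, 7), (1, 0, 11)]
    batch_indices = [3, 3]
    allowed_actions: List[Set[int]] = [{42}, {11}]  # e.g. gold actions per element

    best_next_states: Dict[int, List[Tuple[int, int, int]]] = defaultdict(list)
    for group_index, action_index, action in candidates:
        # Skip anything the trainer did not ask us to evaluate.
        if allowed_actions is not None and action not in allowed_actions[group_index]:
            continue
        best_next_states[batch_indices[group_index]].append(
            (group_index, action_index, action))

    print(dict(best_next_states))  # {3: [(0, 0, 42), (1, 0, 11)]}
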