# Imports needed by the methods below.  `GrammarBasedState` is the decoder state class defined
# elsewhere in this codebase; these methods live on a transition-function class, so `self` refers
# to an instance of that class.
from collections import defaultdict
from typing import Any, Dict, List, Set, Tuple

import torch


def _compute_action_probabilities(self,
                                  state: GrammarBasedState,
                                  hidden_state: torch.Tensor,
                                  attention_weights: torch.Tensor,
                                  predicted_action_embeddings: torch.Tensor
                                 ) -> Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]]:
    # We take a couple of extra arguments here because subclasses might use them.
    # pylint: disable=unused-argument,no-self-use

    # In this section we take our predicted action embedding and compare it to the available
    # actions in our current state (which might be different for each group element).  For
    # computing action scores, we'll forget about doing batched / grouped computation, as it
    # adds too much complexity and doesn't speed things up anyway with the operations we're
    # doing here.  This means we don't need any action masks, as we'll only get the right
    # lengths for what we're computing.
    group_size = len(state.batch_indices)
    actions = state.get_valid_actions()

    batch_results: Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]] = defaultdict(list)
    for group_index in range(group_size):
        instance_actions = actions[group_index]
        predicted_action_embedding = predicted_action_embeddings[group_index]
        action_embeddings, output_action_embeddings, action_ids = instance_actions['global']
        # This is just a matrix product between a (num_actions, embedding_dim) matrix and an
        # (embedding_dim, 1) matrix.
        action_logits = action_embeddings.mm(predicted_action_embedding.unsqueeze(-1)).squeeze(-1)
        current_log_probs = torch.nn.functional.log_softmax(action_logits, dim=-1)

        # This is now the total score for each state after taking each action.  We're going to
        # sort by this later, so it's important that this is the total score, not just the
        # score for the current action.
        log_probs = state.score[group_index] + current_log_probs
        batch_results[state.batch_indices[group_index]].append((group_index,
                                                                log_probs,
                                                                current_log_probs,
                                                                output_action_embeddings,
                                                                action_ids))
    return batch_results
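
# Illustrative sketch (not part of the original class): toy shapes for the global action scoring
# step above.  The sizes and the running score are made up for demonstration; in the real model
# `action_embeddings` comes from instance_actions['global'] and `predicted_action_embedding`
# from the decoder output.
def _example_global_action_scoring():
    import torch
    num_actions, embedding_dim = 4, 6
    action_embeddings = torch.randn(num_actions, embedding_dim)
    predicted_action_embedding = torch.randn(embedding_dim)
    # (num_actions, embedding_dim) @ (embedding_dim, 1) -> (num_actions, 1) -> (num_actions,)
    action_logits = action_embeddings.mm(predicted_action_embedding.unsqueeze(-1)).squeeze(-1)
    current_log_probs = torch.nn.functional.log_softmax(action_logits, dim=-1)
    # Adding the group element's running score gives the total score of each successor state,
    # which is what the search later sorts on.
    running_score = torch.tensor(-1.5)
    total_log_probs = running_score + current_log_probs
    return total_log_probs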
def _compute_action_probabilities(self,
                                  state: GrammarBasedState,
                                  hidden_state: torch.Tensor,
                                  attention_weights: torch.Tensor,
                                  predicted_action_embeddings: torch.Tensor
                                 ) -> Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]]:
    # In this section we take our predicted action embedding and compare it to the available
    # actions in our current state (which might be different for each group element).  For
    # computing action scores, we'll forget about doing batched / grouped computation, as it
    # adds too much complexity and doesn't speed things up anyway with the operations we're
    # doing here.  This means we don't need any action masks, as we'll only get the right
    # lengths for what we're computing.
    group_size = len(state.batch_indices)
    actions = state.get_valid_actions()

    batch_results: Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]] = defaultdict(list)
    for group_index in range(group_size):
        instance_actions = actions[group_index]
        predicted_action_embedding = predicted_action_embeddings[group_index]
        embedded_actions: List[int] = []

        output_action_embeddings = None
        embedded_action_logits = None
        current_log_probs = None

        if 'global' in instance_actions:
            action_embeddings, output_action_embeddings, embedded_actions = instance_actions['global']
            # This is just a matrix product between a (num_actions, embedding_dim) matrix and an
            # (embedding_dim, 1) matrix.
            embedded_action_logits = action_embeddings.mm(predicted_action_embedding.unsqueeze(-1)).squeeze(-1)
            action_ids = embedded_actions

        if 'linked' in instance_actions:
            linking_scores, type_embeddings, linked_actions = instance_actions['linked']
            action_ids = embedded_actions + linked_actions
            # (num_question_tokens, 1)
            linked_action_logits = linking_scores.mm(attention_weights[group_index].unsqueeze(-1)).squeeze(-1)

            # The `output_action_embeddings` tensor gets used later as the input to the next
            # decoder step.  For linked actions, we don't have any action embedding, so we use
            # the entity type instead.
            if output_action_embeddings is not None:
                output_action_embeddings = torch.cat([output_action_embeddings, type_embeddings], dim=0)
            else:
                output_action_embeddings = type_embeddings

            if self._mixture_feedforward is not None:
                # The linked and global logits are combined with a mixture weight to prevent the
                # linked_action_logits from dominating the embedded_action_logits if a softmax
                # was applied on both together.
                mixture_weight = self._mixture_feedforward(hidden_state[group_index])
                mix1 = torch.log(mixture_weight)
                mix2 = torch.log(1 - mixture_weight)

                entity_action_probs = torch.nn.functional.log_softmax(linked_action_logits, dim=-1) + mix1
                if embedded_action_logits is not None:
                    embedded_action_probs = torch.nn.functional.log_softmax(embedded_action_logits, dim=-1) + mix2
                    current_log_probs = torch.cat([embedded_action_probs, entity_action_probs], dim=-1)
                else:
                    current_log_probs = entity_action_probs
            else:
                if embedded_action_logits is not None:
                    action_logits = torch.cat([embedded_action_logits, linked_action_logits], dim=-1)
                else:
                    action_logits = linked_action_logits
                current_log_probs = torch.nn.functional.log_softmax(action_logits, dim=-1)
        else:
            action_logits = embedded_action_logits
            current_log_probs = torch.nn.functional.log_softmax(action_logits, dim=-1)

        # This is now the total score for each state after taking each action.  We're going to
        # sort by this later, so it's important that this is the total score, not just the
        # score for the current action.
        log_probs = state.score[group_index] + current_log_probs
        batch_results[state.batch_indices[group_index]].append((group_index,
                                                                log_probs,
                                                                current_log_probs,
                                                                output_action_embeddings,
                                                                action_ids))
    return batch_results
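
# Illustrative sketch (not part of the original class): how the mixture weight above combines
# linked and global action scores.  The logit sizes and the mixture value are made up; in the
# real model `mixture_weight` comes from self._mixture_feedforward(hidden_state[group_index]).
def _example_mixture_of_log_probs():
    import torch
    embedded_action_logits = torch.randn(5)   # scores for 5 global (embedded) actions
    linked_action_logits = torch.randn(3)     # scores for 3 linked (entity) actions
    mixture_weight = torch.tensor(0.3)        # probability mass assigned to linked actions
    # Each group of actions is normalized separately, then shifted by the log mixture weight,
    # so the concatenated result is a single valid log distribution over all 8 actions.
    entity_log_probs = (torch.nn.functional.log_softmax(linked_action_logits, dim=-1)
                        + torch.log(mixture_weight))
    embedded_log_probs = (torch.nn.functional.log_softmax(embedded_action_logits, dim=-1)
                          + torch.log(1 - mixture_weight))
    current_log_probs = torch.cat([embedded_log_probs, entity_log_probs], dim=-1)
    # Sanity check: the total probability mass still sums to 1.
    assert torch.isclose(current_log_probs.exp().sum(), torch.tensor(1.0))
    return current_log_probs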
def _take_first_step(self,
                     state: GrammarBasedState,
                     allowed_actions: List[Set[int]] = None) -> List[GrammarBasedState]:
    # We'll just do a projection from the current hidden state (which was initialized with the
    # final encoder output) to the number of start actions that we have, normalize those
    # logits, and use that as our score.  We end up duplicating some of the logic from
    # `_compute_new_states` here, but we do things slightly differently, and it's easier to
    # just copy the parts we need than to try to re-use that code.

    # (group_size, hidden_dim)
    hidden_state = torch.stack([rnn_state.hidden_state for rnn_state in state.rnn_state])
    # (group_size, num_start_types)
    start_action_logits = self._start_type_predictor(hidden_state)
    log_probs = torch.nn.functional.log_softmax(start_action_logits, dim=-1)
    sorted_log_probs, sorted_actions = log_probs.sort(dim=-1, descending=True)

    sorted_actions = sorted_actions.detach().cpu().numpy().tolist()
    if state.debug_info is not None:
        probs_cpu = log_probs.exp().detach().cpu().numpy().tolist()
    else:
        probs_cpu = [None] * len(state.batch_indices)

    # state.get_valid_actions() will return a list that is consistently sorted, so as long as
    # the set of valid start actions never changes, we can just match up the log prob indices
    # above with the position of each considered action, and we're good.
    valid_actions = state.get_valid_actions()
    considered_actions = [actions['global'][2] for actions in valid_actions]
    if len(considered_actions[0]) != self._num_start_types:
        raise RuntimeError("Calculated wrong number of initial actions.  Expected "
                           f"{self._num_start_types}, found {len(considered_actions[0])}.")

    best_next_states: Dict[int, List[Tuple[int, int, int]]] = defaultdict(list)
    for group_index, (batch_index, group_actions) in enumerate(zip(state.batch_indices,
                                                                   sorted_actions)):
        for action_index, action in enumerate(group_actions):
            # `action` is currently the index in `log_probs`, not the actual action ID.  To get
            # the action ID, we need to go through `considered_actions`.
            action = considered_actions[group_index][action]
            if allowed_actions is not None and action not in allowed_actions[group_index]:
                # This happens when our _decoder trainer_ wants us to only evaluate certain
                # actions, likely because they are the gold actions in this state.  We just skip
                # emitting any state that isn't allowed by the trainer, because constructing the
                # new state can be expensive.
                continue
            best_next_states[batch_index].append((group_index, action_index, action))

    new_states = []
    for batch_index, best_states in sorted(best_next_states.items()):
        for group_index, action_index, action in best_states:
            # We'll yield a bunch of states here that all have a `group_size` of 1, so that the
            # learning algorithm can decide how many of these it wants to keep, and it can just
            # regroup them later, as that's a really easy operation.
            new_score = state.score[group_index] + sorted_log_probs[group_index, action_index]

            # This part is different from `_compute_new_states` - we're just passing through
            # the previous RNN state, as predicting the start type wasn't included in the
            # decoder RNN in the original model.
            new_rnn_state = state.rnn_state[group_index]
            new_state = state.new_state_from_group_index(group_index,
                                                         action,
                                                         new_score,
                                                         new_rnn_state,
                                                         considered_actions[group_index],
                                                         probs_cpu[group_index],
                                                         None)
            new_states.append(new_state)
    return new_states
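
# Illustrative sketch (not part of the original class): how the sorted log-prob indices in
# `_take_first_step` are translated back into action IDs via `considered_actions`.  The action
# IDs and logits here are made up for demonstration.
def _example_sorted_action_lookup():
    import torch
    # Suppose one group element has three valid start actions with IDs 7, 2, and 11, listed in
    # the consistent order returned by state.get_valid_actions().
    considered_actions = [7, 2, 11]
    start_action_logits = torch.tensor([0.1, 2.0, -1.0])
    log_probs = torch.nn.functional.log_softmax(start_action_logits, dim=-1)
    sorted_log_probs, sorted_indices = log_probs.sort(dim=-1, descending=True)
    # `sorted_indices` are positions into `log_probs`, not action IDs; translate them.
    best_action_ids = [considered_actions[index] for index in sorted_indices.tolist()]
    # best_action_ids == [2, 7, 11]: action 2 had the highest start-type logit.
    return best_action_ids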