def _compute_action_probabilities(
    self,  # type: ignore
    state: CoverageState,
    hidden_state: torch.Tensor,
    attention_weights: torch.Tensor,
    predicted_action_embeddings: torch.Tensor,
) -> Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]]:
    """
    Score every valid action for every group element, biased by the coverage checklist.

    For each group element we compare its predicted action embedding (shifted by a
    checklist-derived addition) against the embeddings of its valid global actions,
    and log-softmax the resulting logits.  Results are grouped by batch index.

    Returns a dict mapping batch index to a list of
    ``(group_index, total_log_probs, step_log_probs, output_action_embeddings, action_ids)``
    tuples.
    """
    # We deliberately score each group element on its own rather than in a
    # batched/grouped fashion: the valid-action sets differ per element, batching
    # adds complexity without a speedup for these small matrix products, and doing
    # it this way means no action masks are needed — every tensor already has
    # exactly the right length.
    group_size = len(state.batch_indices)
    valid_actions = state.get_valid_actions()
    results: Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]] = defaultdict(list)
    for group_index in range(group_size):
        group_actions = valid_actions[group_index]
        predicted_embedding = predicted_action_embeddings[group_index]
        action_embeddings, output_action_embeddings, action_ids = group_actions["global"]

        # Adding a checklist-derived term to the predicted embedding is the only
        # difference between this method and the corresponding super-class logic.
        checklist_addition = self._get_predicted_embedding_addition(
            state.checklist_state[group_index], action_ids, action_embeddings
        )
        predicted_embedding = predicted_embedding + checklist_addition * self._checklist_multiplier

        # (num_actions, embedding_dim) @ (embedding_dim, 1) -> (num_actions,)
        action_logits = action_embeddings.mm(predicted_embedding.unsqueeze(-1)).squeeze(-1)
        step_log_probs = torch.nn.functional.log_softmax(action_logits, dim=-1)

        # Accumulate onto the state's running score: downstream sorting needs the
        # TOTAL score of each successor state, not just this step's score.
        total_log_probs = state.score[group_index] + step_log_probs
        results[state.batch_indices[group_index]].append(
            (group_index, total_log_probs, step_log_probs, output_action_embeddings, action_ids)
        )
    return results
def _compute_action_probabilities(
    self,  # type: ignore
    state: CoverageState,
    hidden_state: torch.Tensor,
    attention_weights: torch.Tensor,
    predicted_action_embeddings: torch.Tensor,
) -> Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]]:
    """
    Score every valid action (global and linked) for every group element, biased
    by the coverage checklist.

    Global actions are scored by comparing their embeddings against the predicted
    action embedding (shifted by a checklist-derived addition).  Linked actions are
    scored through the linking scores and attention weights, with their own
    checklist-based logit addition.  When ``self._mixture_feedforward`` is set, the
    two groups of logits are combined with a learned mixture weight so neither set
    dominates a joint softmax.

    Returns a dict mapping batch index to a list of
    ``(group_index, total_log_probs, step_log_probs, output_action_embeddings, action_ids)``
    tuples.  ``action_ids`` lists global actions before linked ones, matching the
    concatenation order of the log-probability tensors.
    """
    # Each group element is scored independently — the valid-action sets differ
    # per element, batching adds complexity without a speedup for these small
    # matrix products, and unbatched scoring needs no action masks.
    group_size = len(state.batch_indices)
    valid_actions = state.get_valid_actions()
    results: Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]] = defaultdict(list)
    for group_index in range(group_size):
        group_actions = valid_actions[group_index]
        predicted_embedding = predicted_action_embeddings[group_index]
        action_ids: List[int] = []

        if "global" in group_actions:
            action_embeddings, output_action_embeddings, embedded_actions = group_actions["global"]
            # The checklist-derived addition to the predicted embedding is the
            # only difference between this logic and the super-class version.
            checklist_addition = self._get_predicted_embedding_addition(
                state.checklist_state[group_index], embedded_actions, action_embeddings
            )
            predicted_embedding = (
                predicted_embedding + checklist_addition * self._checklist_multiplier
            )
            # (num_actions, embedding_dim) @ (embedding_dim, 1) -> (num_actions,)
            embedded_action_logits = action_embeddings.mm(
                predicted_embedding.unsqueeze(-1)
            ).squeeze(-1)
            action_ids += embedded_actions
        else:
            embedded_action_logits = None
            output_action_embeddings = None

        if "linked" in group_actions:
            linking_scores, type_embeddings, linked_actions = group_actions["linked"]
            action_ids += linked_actions
            # (num_entities, num_question_tokens) @ (num_question_tokens, 1)
            linked_action_logits = linking_scores.mm(
                attention_weights[group_index].unsqueeze(-1)
            ).squeeze(-1)
            checklist_logit_addition = self._get_linked_logits_addition(
                state.checklist_state[group_index], linked_actions, linked_action_logits
            )
            linked_action_logits = (
                linked_action_logits + checklist_logit_addition * self._linked_checklist_multiplier
            )

            # `output_action_embeddings` feeds the next decoder step.  Linked
            # actions have no action embedding, so their entity type embedding
            # stands in; it is appended after the global embeddings.
            if output_action_embeddings is None:
                output_action_embeddings = type_embeddings
            else:
                output_action_embeddings = torch.cat(
                    [output_action_embeddings, type_embeddings], dim=0
                )

            if self._mixture_feedforward is not None:
                # Combine linked and global logits with a learned mixture weight
                # so the linked logits can't dominate the embedded ones the way
                # they could under a single joint softmax.
                mixture_weight = self._mixture_feedforward(hidden_state[group_index])
                log_mixture = torch.log(mixture_weight)
                log_complement = torch.log(1 - mixture_weight)
                entity_action_probs = (
                    torch.nn.functional.log_softmax(linked_action_logits, dim=-1) + log_mixture
                )
                if embedded_action_logits is None:
                    step_log_probs = entity_action_probs
                else:
                    embedded_action_probs = (
                        torch.nn.functional.log_softmax(embedded_action_logits, dim=-1)
                        + log_complement
                    )
                    step_log_probs = torch.cat(
                        [embedded_action_probs, entity_action_probs], dim=-1
                    )
            else:
                if embedded_action_logits is None:
                    step_log_probs = torch.nn.functional.log_softmax(
                        linked_action_logits, dim=-1
                    )
                else:
                    # Global logits first, then linked — matching action_ids order.
                    combined_logits = torch.cat(
                        [embedded_action_logits, linked_action_logits], dim=-1
                    )
                    step_log_probs = torch.nn.functional.log_softmax(combined_logits, dim=-1)
        else:
            step_log_probs = torch.nn.functional.log_softmax(embedded_action_logits, dim=-1)

        # Accumulate onto the state's running score: downstream sorting needs the
        # TOTAL score of each successor state, not just this step's score.
        total_log_probs = state.score[group_index] + step_log_probs
        results[state.batch_indices[group_index]].append(
            (group_index, total_log_probs, step_log_probs, output_action_embeddings, action_ids)
        )
    return results