from collections import defaultdict
from typing import Any, Dict, List, Tuple

import torch


def _compute_action_probabilities(
    self,
    state: GrammarBasedState,
    hidden_state: torch.Tensor,
    attention_weights: torch.Tensor,
    predicted_action_embeddings: torch.Tensor,
) -> Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]]:
    # We take a couple of extra arguments here because subclasses might use them.

    # In this section we take our predicted action embedding and compare it to the available
    # actions in our current state (which might be different for each group element). For
    # computing action scores, we'll forget about doing batched / grouped computation, as it
    # adds too much complexity and doesn't speed things up anyway with the operations we're
    # doing here. This means we don't need any action masks, as we'll only get the right
    # lengths for what we're computing.
    group_size = len(state.batch_indices)
    actions = state.get_valid_actions()

    batch_results: Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]] = defaultdict(list)
    for group_index in range(group_size):
        instance_actions = actions[group_index]
        predicted_action_embedding = predicted_action_embeddings[group_index]
        action_embeddings, output_action_embeddings, action_ids = instance_actions["global"]
        # This is just a matrix product between a (num_actions, embedding_dim) matrix and an
        # (embedding_dim, 1) matrix.
        action_logits = action_embeddings.mm(
            predicted_action_embedding.unsqueeze(-1)
        ).squeeze(-1)
        current_log_probs = torch.nn.functional.log_softmax(action_logits, dim=-1)

        # This is now the total score for each state after taking each action. We're going to
        # sort by this later, so it's important that this is the total score, not just the
        # score for the current action.
        log_probs = state.score[group_index] + current_log_probs
        batch_results[state.batch_indices[group_index]].append(
            (group_index, log_probs, current_log_probs, output_action_embeddings, action_ids)
        )
    return batch_results
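
# A minimal sketch (independent of the method above) of the global-action scoring step,
# using dummy tensors with made-up shapes. The helper name and all dimensions here are
# illustrative assumptions, not part of the library. It shows that comparing one predicted
# action embedding against all candidate action embeddings is a single matrix-vector
# product, and that the resulting log-probabilities form a proper distribution over the
# valid actions.
def _demo_global_action_scoring() -> None:
    num_actions, embedding_dim = 5, 8
    action_embeddings = torch.randn(num_actions, embedding_dim)
    predicted_action_embedding = torch.randn(embedding_dim)

    # (num_actions, embedding_dim) @ (embedding_dim, 1) -> (num_actions, 1) -> (num_actions,)
    action_logits = action_embeddings.mm(
        predicted_action_embedding.unsqueeze(-1)
    ).squeeze(-1)
    current_log_probs = torch.nn.functional.log_softmax(action_logits, dim=-1)

    assert current_log_probs.shape == (num_actions,)
    assert torch.allclose(current_log_probs.exp().sum(), torch.tensor(1.0))
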
def _compute_action_probabilities(
    self,
    state: GrammarBasedState,
    hidden_state: torch.Tensor,
    attention_weights: torch.Tensor,
    predicted_action_embeddings: torch.Tensor,
) -> Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]]:
    # In this section we take our predicted action embedding and compare it to the available
    # actions in our current state (which might be different for each group element). For
    # computing action scores, we'll forget about doing batched / grouped computation, as it
    # adds too much complexity and doesn't speed things up anyway with the operations we're
    # doing here. This means we don't need any action masks, as we'll only get the right
    # lengths for what we're computing.
    group_size = len(state.batch_indices)
    actions = state.get_valid_actions()

    batch_results: Dict[int, List[Tuple[int, Any, Any, Any, List[int]]]] = defaultdict(list)
    for group_index in range(group_size):
        instance_actions = actions[group_index]
        predicted_action_embedding = predicted_action_embeddings[group_index]
        embedded_actions: List[int] = []

        output_action_embeddings = None
        embedded_action_logits = None
        current_log_probs = None

        if "global" in instance_actions:
            action_embeddings, output_action_embeddings, embedded_actions = instance_actions[
                "global"
            ]
            # This is just a matrix product between a (num_actions, embedding_dim) matrix and
            # an (embedding_dim, 1) matrix.
            embedded_action_logits = action_embeddings.mm(
                predicted_action_embedding.unsqueeze(-1)
            ).squeeze(-1)
            action_ids = embedded_actions

        if "linked" in instance_actions:
            linking_scores, type_embeddings, linked_actions = instance_actions["linked"]
            action_ids = embedded_actions + linked_actions
            # linking_scores: (num_entities, num_question_tokens); multiplying by the
            # (num_question_tokens, 1) attention vector and squeezing gives
            # linked_action_logits: (num_entities,).
            linked_action_logits = linking_scores.mm(
                attention_weights[group_index].unsqueeze(-1)
            ).squeeze(-1)

            # The `output_action_embeddings` tensor gets used later as the input to the next
            # decoder step. For linked actions, we don't have any action embedding, so we use
            # the entity type instead.
            if output_action_embeddings is not None:
                output_action_embeddings = torch.cat(
                    [output_action_embeddings, type_embeddings], dim=0
                )
            else:
                output_action_embeddings = type_embeddings

            if self._mixture_feedforward is not None:
                # The linked and global logits are combined with a mixture weight to prevent
                # the linked_action_logits from dominating the embedded_action_logits if a
                # softmax were applied to both together.
                mixture_weight = self._mixture_feedforward(hidden_state[group_index])
                mix1 = torch.log(mixture_weight)
                mix2 = torch.log(1 - mixture_weight)

                entity_action_probs = (
                    torch.nn.functional.log_softmax(linked_action_logits, dim=-1) + mix1
                )
                if embedded_action_logits is not None:
                    embedded_action_probs = (
                        torch.nn.functional.log_softmax(embedded_action_logits, dim=-1) + mix2
                    )
                    current_log_probs = torch.cat(
                        [embedded_action_probs, entity_action_probs], dim=-1
                    )
                else:
                    current_log_probs = entity_action_probs
            else:
                if embedded_action_logits is not None:
                    action_logits = torch.cat(
                        [embedded_action_logits, linked_action_logits], dim=-1
                    )
                else:
                    action_logits = linked_action_logits
                current_log_probs = torch.nn.functional.log_softmax(action_logits, dim=-1)
        else:
            action_logits = embedded_action_logits
            current_log_probs = torch.nn.functional.log_softmax(action_logits, dim=-1)

        # This is now the total score for each state after taking each action. We're going to
        # sort by this later, so it's important that this is the total score, not just the
        # score for the current action.
        log_probs = state.score[group_index] + current_log_probs
        batch_results[state.batch_indices[group_index]].append(
            (group_index, log_probs, current_log_probs, output_action_embeddings, action_ids)
        )
    return batch_results
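
# A minimal sketch (independent of the method above) of the linked-action scoring and the
# mixture combination, with dummy tensors and a fixed scalar standing in for the output of
# `self._mixture_feedforward`; the helper name and all shapes are illustrative assumptions.
# It shows why the two log-softmaxes are shifted by log(mixture_weight) and
# log(1 - mixture_weight) before concatenation: each side is normalized separately, so the
# combined vector is still a proper distribution and neither side's logit scale can
# dominate the other.
def _demo_linked_action_mixture() -> None:
    linking_scores = torch.randn(3, 6)  # (num_entities, num_question_tokens)
    attention = torch.softmax(torch.randn(6), dim=-1)  # attention over question tokens
    # (num_entities, num_question_tokens) @ (num_question_tokens, 1) -> (num_entities,)
    linked_action_logits = linking_scores.mm(attention.unsqueeze(-1)).squeeze(-1)

    embedded_action_logits = torch.randn(4)  # scores for 4 global actions
    mixture_weight = torch.tensor(0.3)  # assumed feedforward output in (0, 1)

    entity_action_probs = (
        torch.nn.functional.log_softmax(linked_action_logits, dim=-1)
        + torch.log(mixture_weight)
    )
    embedded_action_probs = (
        torch.nn.functional.log_softmax(embedded_action_logits, dim=-1)
        + torch.log(1 - mixture_weight)
    )
    current_log_probs = torch.cat([embedded_action_probs, entity_action_probs], dim=-1)

    # 70% of the probability mass goes to global actions, 30% to linked actions.
    assert torch.allclose(current_log_probs.exp().sum(), torch.tensor(1.0))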