def push_action(self, state: ParserState, target_action_idx: int) -> None:
    """Update ``state`` in place by applying one transition action.

    The action is always recorded on the action stack-RNN. Its effect on the
    stack/buffer then depends on the kind of action:

    * SHIFT (``self.shift_idx``): pop the next token off the buffer and push
      it onto the stack as a closed element.
    * REDUCE (``self.reduce_idx``): pop closed elements off the stack down to
      and including the most recent open non-terminal, compose their
      representations with ``self.p_compositional`` and push the result back
      as a single closed element.
    * Non-terminal (any index in ``self.valid_NT_idxs``): open a new
      non-terminal on the stack using the action embedding as its
      representation.

    Args:
        state (ParserState): The state of the stack, buffer and action
            stack-RNNs; mutated in place.
        target_action_idx (int): Index of the action to process.

    Raises:
        ValueError: If ``target_action_idx`` is not SHIFT, REDUCE or a valid
            non-terminal action.
    """
    # Record the action itself on the action stack-RNN.
    action_embedding = self.actions_lookup(
        cuda_utils.Variable(torch.LongTensor([target_action_idx]))
    )
    state.action_stackrnn.push(action_embedding, Element(target_action_idx))

    if target_action_idx == self.shift_idx:
        # SHIFT:
        # 1. Pop T from buffer
        # 2. Push T into stack (as a closed element)
        state.is_open_NT.append(False)
        token_embedding, token = state.buffer_stackrnn.pop()
        state.stack_stackrnn.push(token_embedding, Element(token))
    elif target_action_idx == self.reduce_idx:
        # REDUCE:
        # 1. Pop Ts from stack until hitting the most recent open NT
        # 2. Pop the open NT from stack and close it
        # 3. Compute the compositional representation and push it onto stack
        state.num_open_NT -= 1
        popped_rep = []
        nt_tree = []
        # `is_open_NT` mirrors the stack entry-for-entry (one flag per push).
        while not state.is_open_NT[-1]:
            assert len(state.stack_stackrnn) > 0, "How come stack is empty!"
            state.is_open_NT.pop()
            # pop() appears to return (embedding, element), matching the
            # unpacking used in the SHIFT branch.
            top_of_stack = state.stack_stackrnn.pop()
            popped_rep.append(top_of_stack[0])
            nt_tree.append(top_of_stack[1])
        # Pop the open NT itself; its slot in `is_open_NT` becomes closed.
        top_of_stack = state.stack_stackrnn.pop()
        popped_rep.append(top_of_stack[0])
        nt_tree.append(top_of_stack[1])
        state.is_open_NT.pop()
        state.is_open_NT.append(False)
        compositional_rep = self.p_compositional(popped_rep)
        state.stack_stackrnn.push(compositional_rep, Element(nt_tree))
    elif target_action_idx in self.valid_NT_idxs:
        # If this is the root prediction (forward() appends the predicted
        # index before calling us, so length 1 means first action) and that
        # root is one of the unsupported intents, flag the state.
        if (
            len(state.predicted_actions_idx) == 1
            and target_action_idx in self.ignore_subNTs_roots
        ):
            state.found_unsupported = True
        state.is_open_NT.append(True)
        state.num_open_NT += 1
        state.stack_stackrnn.push(action_embedding, Element(target_action_idx))
    else:
        # Previously `assert "<non-empty string>"`, which can never fail.
        raise ValueError(
            "not a valid action: {}".format(
                self.actions_vocab.itos[target_action_idx]
            )
        )
def populate_buffer(self, num_tokens: int = 2, embedding_dim: int = 30) -> "ParserState":
    """Create a ParserState for ``self.parser`` with placeholder buffer tokens.

    Args:
        num_tokens: Number of placeholder tokens to push onto the buffer.
            Defaults to 2 (the original fixed value).
        embedding_dim: Width of each all-zero token embedding. Defaults to 30
            (the original fixed value); presumably this has to match the
            parser's buffer RNN input size -- TODO confirm against the fixture.

    Returns:
        ParserState: A fresh state whose buffer holds ``num_tokens`` zero
        embeddings of shape ``(1, embedding_dim)``, each tagged with
        ``Element("Token")``.
    """
    state = ParserState(self.parser)
    for _ in range(num_tokens):
        state.buffer_stackrnn.push(torch.zeros(1, embedding_dim), Element("Token"))
    return state
def forward(
    self,
    tokens: torch.Tensor,
    seq_lens: torch.Tensor,
    dict_feat: Optional[Tuple[torch.Tensor, ...]] = None,
    actions: Optional[List[List[int]]] = None,
    contextual_token_embeddings: Optional[torch.Tensor] = None,
    beam_size=1,
    top_k=1,
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
    """RNNG forward function.

    Runs a beam search over transition actions. During training the oracle
    ``actions`` decide every transition (teacher forcing), while the model's
    action scores are still recorded. At inference the best-scoring valid
    actions are followed.

    Args:
        tokens (torch.Tensor): list of tokens
        seq_lens (torch.Tensor): list of sequence lengths (not used by this
            method)
        dict_feat (Optional[Tuple[torch.Tensor, ...]]): dictionary or gazetteer
            features for each token
        actions (Optional[List[List[int]]]): Used only during training.
            Oracle actions for the instances.
        contextual_token_embeddings (Optional[torch.Tensor]): optional extra
            input forwarded to the embedding module.
        beam_size (int): beam width; forced to 1 unless ``self.stage`` is
            ``Stage.TEST``.
        top_k (int): number of best beam entries to return; forced to 1 unless
            ``self.stage`` is ``Stage.TEST``.

    Returns:
        list of top k tuple of predicted actions tensor and corresponding
        scores tensor. Tensor shape:
        (batch_size, action_length)
        (batch_size, action_length, number_of_actions)

    Raises:
        AssertionError: if training without ``actions``, if the oracle runs out
            of actions before parsing finishes, or if a final beam entry is not
            a complete parse (stack of one element, empty buffer).
    """
    if self.stage != Stage.TEST:
        beam_size = 1
        top_k = 1

    if self.training:
        assert actions is not None, "actions must be provided for training"
        # Reversed so the next oracle action can be popped from the end.
        actions_idx_rev = list(reversed(actions[0]))
    else:
        # Reseeds the global RNG at inference, presumably for determinism.
        torch.manual_seed(0)

    beam_size = max(beam_size, 1)

    # Reverse the order of input tokens.
    tokens_list_rev = torch.flip(tokens, [len(tokens.size()) - 1])

    # Aggregate inputs for embedding module.
    embedding_input = [tokens]
    if dict_feat is not None:
        embedding_input.append(dict_feat)
    if contextual_token_embeddings is not None:
        embedding_input.append(contextual_token_embeddings)

    # Embed and reverse the order of tokens.
    token_embeddings = self.embedding(*embedding_input)
    token_embeddings = torch.flip(token_embeddings, [len(tokens.size()) - 1])

    # Batch size is always = 1. So we squeeze the batch_size dimension.
    token_embeddings = token_embeddings.squeeze(0)
    tokens_list_rev = tokens_list_rev.squeeze(0)

    # Push tokens in reverse so that (assuming LIFO push/pop) popping the
    # buffer yields them in their original order.
    initial_state = ParserState(self)
    for i in range(token_embeddings.size()[0]):
        token_embedding = token_embeddings[i].unsqueeze(0)
        tok = tokens_list_rev[i]
        initial_state.buffer_stackrnn.push(token_embedding, Element(tok))

    beam = [initial_state]
    while beam and any(not state.finished() for state in beam):
        # Stores plans for expansion as (score, state, action); score is the
        # negative cumulative log-probability, so smaller is better.
        plans: List[Tuple[float, ParserState, int]] = []

        # Expand current beam states.
        for state in beam:
            # Keep terminal states (action -1 marks "no action").
            if state.finished():
                plans.append((state.neg_prob, state, -1))
                continue

            # translating Expression p_t = affine_transform({pbias, S,
            # stack_summary, B, buffer_summary, A, action_summary});
            stack = state.stack_stackrnn
            stack_summary = stack.embedding()
            action_summary = state.action_stackrnn.embedding()
            buffer_summary = state.buffer_stackrnn.embedding()
            if self.dropout_layer.p > 0:
                stack_summary = self.dropout_layer(stack_summary)
                action_summary = self.dropout_layer(action_summary)
                buffer_summary = self.dropout_layer(buffer_summary)

            # One-hot feature over the action vocab marking the last open
            # non-terminal. NOTE(review): created on the default device; confirm
            # this matches the other summaries when running on GPU.
            last_open_NT_feature = torch.zeros(len(self.actions_vocab))
            open_NT_exists = state.num_open_NT > 0

            if (
                len(stack) > 0
                and open_NT_exists
                and self.ablation_use_last_open_NT_feature
            ):
                last_open_NT = None
                try:
                    # Distance from the top of the stack to the last open NT.
                    open_NT = state.is_open_NT[::-1].index(True)
                except ValueError:
                    # No open non-terminal currently on the stack.
                    open_NT = None
                if open_NT is not None:
                    last_open_NT = stack.element_from_top(open_NT)
                if last_open_NT is not None:
                    last_open_NT_feature[last_open_NT.node] = 1.0
            last_open_NT_feature = last_open_NT_feature.unsqueeze(0)

            summaries = []
            if self.ablation_use_buffer:
                summaries.append(buffer_summary)
            if self.ablation_use_stack:
                summaries.append(stack_summary)
            if self.ablation_use_action:
                summaries.append(action_summary)
            if self.ablation_use_last_open_NT_feature:
                summaries.append(last_open_NT_feature)

            # Kept on the state so each plan later records its own logits.
            state.action_p = self.action_linear(torch.cat(summaries, dim=1))

            log_probs = F.log_softmax(state.action_p, dim=1)[0]

            for action in self.valid_actions(state):
                plans.append(
                    (state.neg_prob - log_probs[action].item(), state, action)
                )

        beam = []
        # Take actions to regenerate the beam.
        for neg_prob, state, predicted_action_idx in sorted(plans)[:beam_size]:
            # Terminal states are carried over unchanged.
            if state.finished():
                beam.append(state)
                continue

            # Only branch out states when needed: with beam_size > 1 the same
            # source state may be selected for several different actions.
            if beam_size > 1:
                state = state.copy()

            state.predicted_actions_idx.append(predicted_action_idx)

            target_action_idx = predicted_action_idx
            if self.training:
                assert (
                    len(actions_idx_rev) > 0
                ), "Actions and tokens may not be in sync."
                # Teacher forcing: follow the oracle action instead.
                target_action_idx = actions_idx_rev.pop()

            # Scores are not recorded once an unsupported root intent has been
            # predicted, when that constraint is enabled.
            if not (
                self.constraints_ignore_loss_for_unsupported
                and state.found_unsupported
            ):
                state.action_scores.append(state.action_p)

            self.push_action(state, target_action_idx)

            state.neg_prob = neg_prob
            beam.append(state)
        # End for
    # End while
    assert len(beam) > 0, "How come beam is empty?"
    # Validate every surviving beam entry, not just the last loop variable.
    for final_state in beam:
        assert len(final_state.stack_stackrnn) == 1, "How come stack len is " + str(
            len(final_state.stack_stackrnn)
        )
        assert len(final_state.buffer_stackrnn) == 0, "How come buffer len is " + str(
            len(final_state.buffer_stackrnn)
        )

    # Unsqueeze to add batch dimension before returning.
    return [
        (
            cuda_utils.LongTensor(state.predicted_actions_idx).unsqueeze(0),
            torch.cat(state.action_scores).unsqueeze(0),
        )
        for state in sorted(beam)[:top_k]
    ]
def forward(
    self,
    tokens: torch.Tensor,
    seq_lens: torch.Tensor,
    dict_feat: Optional[Tuple[torch.Tensor, ...]] = None,
    actions: Optional[List[List[int]]] = None,
    beam_size: int = 1,
    topk: int = 1,
):
    """RNNG forward function.

    Runs a beam search over transition actions, applying each chosen
    transition (SHIFT / REDUCE / open non-terminal) inline. During training
    the oracle ``actions`` decide every transition while the model's action
    scores are still recorded.

    Args:
        tokens (torch.Tensor): list of tokens
        seq_lens (torch.Tensor): list of sequence lengths (not used by this
            method)
        dict_feat (Optional[Tuple[torch.Tensor, ...]]): dictionary or gazetteer
            features for each token, unpacked as (ids, weights, lengths)
        actions (Optional[List[List[int]]]): Used only during training.
            Oracle actions for the instances.
        beam_size (int): Beam size; used only during inference
        topk (int) : Number of top results from the method.
            If beam_size is 1 this is 1.

    Returns:
        if topk == 1
            tuple of list of predicted actions and list of corresponding scores
        else
            list of tuple of list of predicted actions and list of \
            corresponding scores

    Raises:
        ValueError: if a transition index is not SHIFT, REDUCE or a valid
            non-terminal action.
        AssertionError: on training misuse, oracle/token desync, or an
            incomplete final parse.
    """
    if self.training:
        assert beam_size == 1, "beam_size must be 1 during training"
        assert actions is not None, "actions must be provided for training"
        # Reversed so the next oracle action can be popped from the end.
        actions_idx_rev = list(reversed(actions[0]))
    else:
        # Reseeds the global RNG at inference, presumably for determinism.
        torch.manual_seed(0)

    beam_size = max(beam_size, 1)

    # Reverse the order of indices along last axis before embedding lookup.
    tokens_list_rev = torch.flip(tokens, [len(tokens.size()) - 1])
    dict_feat_rev = None
    if dict_feat:
        dict_ids, dict_weights, dict_lengths = dict_feat
        dict_ids_rev = torch.flip(dict_ids, [len(dict_ids.size()) - 1])
        dict_weights_rev = torch.flip(dict_weights, [len(dict_weights.size()) - 1])
        dict_lengths_rev = torch.flip(dict_lengths, [len(dict_lengths.size()) - 1])
        dict_feat_rev = (dict_ids_rev, dict_weights_rev, dict_lengths_rev)

    embedding_input = (
        [tokens_list_rev, dict_feat_rev]
        if dict_feat_rev is not None
        else [tokens_list_rev]
    )
    tok_embeddings = self.embedding(*embedding_input)

    # Batch size is always = 1. So we squeeze the batch_size dimension.
    tok_embeddings = tok_embeddings.squeeze(0)
    tokens_list_rev = tokens_list_rev.squeeze(0)

    # Push tokens in reverse so that (assuming LIFO push/pop) popping the
    # buffer yields them in their original order.
    initial_state = ParserState(self)
    for i in range(tok_embeddings.size()[0]):
        tok_embedding = tok_embeddings[i].unsqueeze(0)
        tok = tokens_list_rev[i]
        initial_state.buffer_stackrnn.push(tok_embedding, Element(tok))

    beam = [initial_state]
    while beam and any(not state.finished() for state in beam):
        # Stores plans for expansion as (score, state, action); score is the
        # negative cumulative log-probability, so smaller is better.
        plans: List[Tuple[float, ParserState, int]] = []
        # Action logits of each expanded state, keyed by id() of the source
        # state, so every chosen plan records the logits of its own state.
        action_p_by_state = {}

        # Expand current beam states.
        for state in beam:
            # Keep terminal states (action -1 marks "no action").
            if state.finished():
                plans.append((state.neg_prob, state, -1))
                continue

            # translating Expression p_t = affine_transform({pbias, S,
            # stack_summary, B, buffer_summary, A, action_summary});
            stack = state.stack_stackrnn
            stack_summary = stack.embedding()
            action_summary = state.action_stackrnn.embedding()
            buffer_summary = state.buffer_stackrnn.embedding()
            if self.dropout_layer.p > 0:
                stack_summary = self.dropout_layer(stack_summary)
                action_summary = self.dropout_layer(action_summary)
                buffer_summary = self.dropout_layer(buffer_summary)

            # One-hot feature over the action vocab marking the last open
            # non-terminal.
            last_open_NT_feature = torch.zeros(len(self.actions_vocab))
            open_NT_exists = state.num_open_NT > 0

            if (
                len(stack) > 0
                and open_NT_exists
                and self.ablation_use_last_open_NT_feature
            ):
                last_open_NT = None
                try:
                    # Distance from the top of the stack to the last open NT.
                    open_NT = state.is_open_NT[::-1].index(True)
                except ValueError:
                    # No open non-terminal currently on the stack.
                    open_NT = None
                if open_NT is not None:
                    # NOTE(review): other code in this file calls
                    # `element_from_top`; confirm the stack class used here
                    # really provides `ele_from_top`.
                    last_open_NT = stack.ele_from_top(open_NT)
                if last_open_NT is not None:
                    last_open_NT_feature[last_open_NT.node] = 1.0
            last_open_NT_feature = last_open_NT_feature.unsqueeze(0)

            summaries = []
            if self.ablation_use_buffer:
                summaries.append(buffer_summary)
            if self.ablation_use_stack:
                summaries.append(stack_summary)
            if self.ablation_use_action:
                summaries.append(action_summary)
            if self.ablation_use_last_open_NT_feature:
                summaries.append(last_open_NT_feature)

            action_p = self.action_linear(torch.cat(summaries, dim=1))
            action_p_by_state[id(state)] = action_p

            log_probs = F.log_softmax(action_p, dim=1)[0]

            for action in self.valid_actions(state):
                # .item() keeps plan scores as plain floats (as annotated) and
                # avoids holding the autograd graph in the score.
                plans.append(
                    (state.neg_prob - log_probs[action].item(), state, action)
                )

        beam = []
        # Take actions to regenerate the beam.
        for neg_prob, state, predicted_action_idx in sorted(plans)[:beam_size]:
            # Terminal states are carried over unchanged.
            if state.finished():
                beam.append(state)
                continue

            # Look up logits via the source state before it may be copied.
            action_p = action_p_by_state[id(state)]

            # Only branch out states when needed: with beam_size > 1 the same
            # source state may be selected for several different actions.
            if beam_size > 1:
                state = state.copy()

            state.predicted_actions_idx.append(predicted_action_idx)

            target_action_idx = predicted_action_idx
            if self.training:
                assert (
                    len(actions_idx_rev) > 0
                ), "Actions and tokens may not be in sync."
                # Teacher forcing: follow the oracle action instead.
                target_action_idx = actions_idx_rev.pop()

            # Scores are not recorded once an unsupported root intent has been
            # predicted, when that constraint is enabled.
            if not (
                self.constraints_ignore_loss_for_unsupported
                and state.found_unsupported
            ):
                state.action_scores.append(action_p)

            # Record the action itself on the action stack-RNN.
            action_embedding = self.actions_lookup(
                cuda_utils.Variable(torch.LongTensor([target_action_idx]))
            )
            state.action_stackrnn.push(action_embedding, Element(target_action_idx))

            if target_action_idx == self.shift_idx:
                # SHIFT: move the next token from the buffer onto the stack.
                state.is_open_NT.append(False)
                tok_embedding, token = state.buffer_stackrnn.pop()
                state.stack_stackrnn.push(tok_embedding, Element(token))
            elif target_action_idx == self.reduce_idx:
                # REDUCE: pop closed elements down to the last open NT, close
                # it, and push their composition as one element.
                state.num_open_NT -= 1
                popped_rep = []
                nt_tree = []
                while not state.is_open_NT[-1]:
                    assert len(state.stack_stackrnn) > 0, "How come stack is empty!"
                    state.is_open_NT.pop()
                    top_of_stack = state.stack_stackrnn.pop()
                    popped_rep.append(top_of_stack[0])
                    nt_tree.append(top_of_stack[1])

                # pop the open NT and close it
                top_of_stack = state.stack_stackrnn.pop()
                popped_rep.append(top_of_stack[0])
                nt_tree.append(top_of_stack[1])
                state.is_open_NT.pop()
                state.is_open_NT.append(False)

                compositional_rep = self.p_compositional(popped_rep)
                state.stack_stackrnn.push(compositional_rep, Element(nt_tree))
            elif target_action_idx in self.valid_NT_idxs:
                # if this is root prediction and if that root is one
                # of the unsupported intents
                if (
                    len(state.predicted_actions_idx) == 1
                    and target_action_idx in self.ignore_subNTs_roots
                ):
                    state.found_unsupported = True

                state.is_open_NT.append(True)
                state.num_open_NT += 1
                state.stack_stackrnn.push(
                    action_embedding, Element(target_action_idx)
                )
            else:
                # Previously `assert "<non-empty string>"`, which never fails.
                raise ValueError(
                    "not a valid action: {}".format(
                        self.actions_vocab.itos[target_action_idx]
                    )
                )

            state.neg_prob = neg_prob
            beam.append(state)
        # End for
    # End while
    assert len(beam) > 0, "How come beam is empty?"
    # Validate every surviving beam entry, not just the last loop variable.
    for final_state in beam:
        assert len(final_state.stack_stackrnn) == 1, "How come stack len is " + str(
            len(final_state.stack_stackrnn)
        )
        assert len(final_state.buffer_stackrnn) == 0, "How come buffer len is " + str(
            len(final_state.buffer_stackrnn)
        )

    # Add batch dimension before returning.
    if topk <= 1:
        state = min(beam)
        return (
            torch.LongTensor(state.predicted_actions_idx).unsqueeze(0),
            torch.cat(state.action_scores).unsqueeze(0),
        )
    else:
        return [
            (
                torch.LongTensor(state.predicted_actions_idx).unsqueeze(0),
                torch.cat(state.action_scores).unsqueeze(0),
            )
            for state in sorted(beam)[:topk]
        ]