def get_loss(self, logits: torch.Tensor, target_actions: torch.Tensor, context: torch.Tensor):
    """
    Shapes:
        logits[1]: action scores: (1, sequence_length, number_of_actions)
        target_actions: (1, sequence_length)
    """
    # Support beam search output: with top-K predictions, logits is a list of
    # (actions, scores) tuples, which adds an extra level of nesting. Probing
    # four levels deep succeeds only for the top-K layout; the plain layout
    # raises TypeError or IndexError.
    try:
        logits[0][0][0][0]
        action_scores = logits[0][1].squeeze(0)
        target_actions = target_actions[0].squeeze(0)
    except (TypeError, IndexError):
        # Get rid of the batch dimension
        action_scores = logits[1].squeeze(0)
        target_actions = target_actions.squeeze(0)

    action_scores_list = torch.chunk(action_scores, action_scores.size()[0])
    target_vars = [
        cuda_utils.Variable(torch.LongTensor([t])) for t in target_actions
    ]
    losses = [
        self.loss_func(action, target).view(1)
        for action, target in zip(action_scores_list, target_vars)
    ]
    total_loss = torch.sum(torch.cat(losses)) if len(losses) > 0 else None
    return total_loss
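
# Illustrative sketch (made-up shapes, plain torch; not part of the module) of
# the two logits layouts the try/except above distinguishes: a plain
# (actions, scores) tuple, and a top-K list of such tuples from beam search.
def _example_logits_layouts():
    import torch

    seq_len, num_actions = 4, 6
    # Plain layout: probing logits[0][0][0][0] raises IndexError because
    # logits[0][0][0] is already a 0-d tensor.
    plain = (
        torch.zeros(1, seq_len, dtype=torch.long),  # predicted actions
        torch.randn(1, seq_len, num_actions),       # action scores
    )
    # Top-K layout: one (actions, scores) tuple per hypothesis, so the
    # four-level probe succeeds and logits[0][1] holds the best hypothesis'
    # scores.
    top_k = [plain, plain]
    return plain, top_k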
def _word_forward(self, inputs: torch.Tensor, word_idx: int) -> torch.Tensor:
    # inputs -> (batch, words, embed_dim)
    start_idx = word_idx - self.fwd_bwd_ctxt_len
    word_with_bwd_context = inputs.narrow(1, start_idx, self.fwd_bwd_ctxt_len + 1)
    word_with_fwd_context = inputs.narrow(1, word_idx, self.fwd_bwd_ctxt_len + 1)
    start_idx = word_idx - self.surr_ctxt_len
    word_with_surr_context = inputs.narrow(1, start_idx, 2 * self.surr_ctxt_len + 1)

    padding = cuda_utils.Variable(
        torch.cat([self.padding_tensor] * inputs.size()[0])
    )
    conv_in_bwd_context = torch.cat((word_with_bwd_context, padding), dim=1)
    conv_in_fwd_context = torch.cat((padding, word_with_fwd_context), dim=1)

    bwd_ctxt_rep = self._conv_maxpool(conv_in_bwd_context, self.convs_bwd)
    fwd_ctxt_rep = self._conv_maxpool(conv_in_fwd_context, self.convs_fwd)
    surr_ctxt_rep = self._conv_maxpool(word_with_surr_context, self.convs_surr)

    # Full representation combining all contextual representations.
    return torch.cat(
        (
            self.bwd_fc(bwd_ctxt_rep),
            self.fwd_fc(fwd_ctxt_rep),
            self.surr_fc(surr_ctxt_rep),
        ),
        dim=1,
    )
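
# Minimal sketch (hypothetical shapes; the real _conv_maxpool is assumed to
# behave like this) of the conv-and-maxpool step applied to one context
# window of embeddings.
def _example_conv_maxpool():
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    batch, ctxt_len, embed_dim, kernel_num = 2, 5, 8, 3
    window = torch.randn(batch, ctxt_len, embed_dim)
    convs = nn.ModuleList(
        [nn.Conv2d(1, kernel_num, (k, embed_dim)) for k in (2, 3)]
    )

    x = window.unsqueeze(1)  # add a channel dim: (batch, 1, ctxt_len, embed_dim)
    pooled = []
    for conv in convs:
        feat = F.relu(conv(x)).squeeze(3)  # (batch, kernel_num, ctxt_len - k + 1)
        pooled.append(F.max_pool1d(feat, feat.size(2)).squeeze(2))
    # One (batch, kernel_num) vector per kernel size, concatenated.
    return torch.cat(pooled, dim=1)  # (batch, len(kernel_sizes) * kernel_num)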
def get_loss(
    self, logits: torch.Tensor, target_actions: torch.Tensor, context: torch.Tensor
):
    # action_scores is a 2D tensor of dims sequence_length x number_of_actions
    # target_actions is a 1D tensor of integers of length sequence_length
    # Get rid of the batch dimension
    action_scores = logits[1].squeeze(0)
    target_actions = target_actions.squeeze(0)

    action_scores_list = torch.chunk(action_scores, action_scores.size()[0])
    target_vars = [
        cuda_utils.Variable(torch.LongTensor([t])) for t in target_actions
    ]
    losses = [
        self.loss_func(action, target).view(1)
        for action, target in zip(action_scores_list, target_vars)
    ]
    total_loss = torch.sum(torch.cat(losses)) if len(losses) > 0 else None
    return total_loss
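
# Minimal sketch (made-up shapes, plain torch) of the computation above: chunk
# the per-sequence score matrix into per-step rows and sum the cross-entropy
# against the oracle actions.
def _example_stepwise_loss():
    import torch
    import torch.nn as nn

    loss_func = nn.CrossEntropyLoss()
    action_scores = torch.randn(5, 8)               # (sequence_length, number_of_actions)
    target_actions = torch.tensor([1, 0, 3, 7, 2])  # oracle action per step
    rows = torch.chunk(action_scores, action_scores.size(0))  # five (1, 8) rows
    losses = [
        loss_func(row, target.view(1)).view(1)
        for row, target in zip(rows, target_actions)
    ]
    return torch.sum(torch.cat(losses))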
def __init__(self, config: Config, embed_dim: int) -> None:
    super().__init__(config)
    self.fwd_bwd_ctxt_len = config.fwd_bwd_context_len
    self.surr_ctxt_len = config.surrounding_context_len
    self.ctxt_pad_len = max(self.fwd_bwd_ctxt_len, self.surr_ctxt_len)

    self.padding_tensor = cuda_utils.Variable(
        torch.zeros(1, self.fwd_bwd_ctxt_len, embed_dim), requires_grad=False
    )

    bwd_convs, fwd_convs, surr_convs = [], [], []
    in_channels = 1
    out_channels = config.cnn.kernel_num
    kernel_sizes = config.cnn.kernel_sizes
    for k in kernel_sizes:
        bwd_convs.append(nn.Conv2d(in_channels, out_channels, (k, embed_dim)))
        fwd_convs.append(nn.Conv2d(in_channels, out_channels, (k, embed_dim)))
        surr_convs.append(nn.Conv2d(in_channels, out_channels, (k, embed_dim)))
    self.convs_bwd = nn.ModuleList(bwd_convs)
    self.convs_fwd = nn.ModuleList(fwd_convs)
    self.convs_surr = nn.ModuleList(surr_convs)

    # Token representation size with each context.
    token_rep_len = len(kernel_sizes) * out_channels
    self.bwd_fc = nn.Linear(token_rep_len, token_rep_len)
    self.fwd_fc = nn.Linear(token_rep_len, token_rep_len)
    self.surr_fc = nn.Linear(token_rep_len, token_rep_len)

    self.ctxt_pad = nn.ConstantPad1d((self.ctxt_pad_len, self.ctxt_pad_len), 0)
    self.representation_dim = 3 * len(kernel_sizes) * out_channels
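
# Worked example (hypothetical config values) of the sizes computed above:
# each context yields len(kernel_sizes) * kernel_num features after conv and
# maxpool, and the three contexts are concatenated.
def _example_representation_dim():
    kernel_num = 50                                 # config.cnn.kernel_num
    kernel_sizes = [3, 5]                           # config.cnn.kernel_sizes
    token_rep_len = len(kernel_sizes) * kernel_num  # 100 features per context
    representation_dim = 3 * token_rep_len          # 300 for bwd + fwd + surr
    return token_rep_len, representation_dim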
def push_action(self, state: ParserState, target_action_idx: int) -> None:
    """Update the state with a target next action.

    Args:
        state (ParserState): The state of the stack, buffer and action
        target_action_idx (int): Index of the action to process
    """
    # Update action_stackrnn
    action_embedding = self.actions_lookup(
        cuda_utils.Variable(torch.LongTensor([target_action_idx]))
    )
    state.action_stackrnn.push(action_embedding, Element(target_action_idx))

    # Update stack_stackrnn
    if target_action_idx == self.shift_idx:
        # To SHIFT:
        # 1. Pop token T from the buffer
        # 2. Push T onto the stack
        state.is_open_NT.append(False)
        token_embedding, token = state.buffer_stackrnn.pop()
        state.stack_stackrnn.push(token_embedding, Element(token))
    elif target_action_idx == self.reduce_idx:
        # To REDUCE:
        # 1. Pop elements from the stack until an open NT is hit
        # 2. Pop the open NT from the stack and close it
        # 3. Compute the compositional representation and push it onto the stack
        state.num_open_NT -= 1
        popped_rep = []
        nt_tree = []
        while not state.is_open_NT[-1]:
            assert len(state.stack_stackrnn) > 0, "Stack is unexpectedly empty"
            state.is_open_NT.pop()
            top_of_stack = state.stack_stackrnn.pop()
            popped_rep.append(top_of_stack[0])
            nt_tree.append(top_of_stack[1])

        # Pop the open NT and close it
        top_of_stack = state.stack_stackrnn.pop()
        popped_rep.append(top_of_stack[0])
        nt_tree.append(top_of_stack[1])
        state.is_open_NT.pop()
        state.is_open_NT.append(False)

        compositional_rep = self.p_compositional(popped_rep)
        combined_element = Element(nt_tree)
        state.stack_stackrnn.push(compositional_rep, combined_element)
    elif target_action_idx in self.valid_NT_idxs:
        # If this is the root prediction and that root is one of the
        # unsupported intents, flag the state.
        if (
            len(state.predicted_actions_idx) == 1
            and target_action_idx in self.ignore_subNTs_roots
        ):
            state.found_unsupported = True
        state.is_open_NT.append(True)
        state.num_open_NT += 1
        state.stack_stackrnn.push(action_embedding, Element(target_action_idx))
    else:
        raise ValueError(
            "not a valid action: {}".format(
                self.actions_vocab.itos[target_action_idx]
            )
        )
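
# Toy trace (plain Python lists, no tensors or RNNs) of the stack discipline
# above, for the action sequence NT(intent) SHIFT SHIFT REDUCE over the
# hypothetical tokens ["play", "song"].
def _example_shift_reduce_trace():
    buffer = ["song", "play"]  # reversed, so pop() yields "play" first
    stack, is_open_NT = [], []

    stack.append("intent"); is_open_NT.append(True)       # NT opens a bracket
    stack.append(buffer.pop()); is_open_NT.append(False)  # SHIFT "play"
    stack.append(buffer.pop()); is_open_NT.append(False)  # SHIFT "song"

    # REDUCE: pop until the open NT, pop the NT itself, push the closed subtree
    children = []
    while not is_open_NT[-1]:
        is_open_NT.pop()
        children.append(stack.pop())
    children.append(stack.pop())
    is_open_NT.pop()
    stack.append(children); is_open_NT.append(False)
    return stack  # [["song", "play", "intent"]]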
def forward(
    self,
    tokens: torch.Tensor,
    seq_lens: torch.Tensor,
    dict_feat: Optional[Tuple[torch.Tensor, ...]] = None,
    actions: Optional[List[List[int]]] = None,
    beam_size: int = 1,
    topk: int = 1,
):
    """RNNG forward function.

    Args:
        tokens (torch.Tensor): token indices (batch size is always 1)
        seq_lens (torch.Tensor): sequence length of each instance
        dict_feat (Optional[Tuple[torch.Tensor, ...]]): dictionary or
            gazetteer features for each token
        actions (Optional[List[List[int]]]): oracle actions for the
            instances; used only during training
        beam_size (int): beam size; used only during inference
        topk (int): number of top results returned. If beam_size is 1,
            this is 1.

    Returns:
        if topk == 1:
            tuple of the list of predicted actions and the list of
            corresponding scores
        else:
            list of tuples of predicted actions and corresponding scores
    """
    if self.training:
        assert beam_size == 1, "beam_size must be 1 during training"
        assert actions is not None, "actions must be provided for training"
        actions_idx_rev = list(reversed(actions[0]))
    else:
        torch.manual_seed(0)
        beam_size = max(beam_size, 1)

    # Reverse the order of indices along the last axis before the embedding
    # lookup.
    tokens_list_rev = torch.flip(tokens, [len(tokens.size()) - 1])
    dict_feat_rev = None
    if dict_feat:
        dict_ids, dict_weights, dict_lengths = dict_feat
        dict_ids_rev = torch.flip(dict_ids, [len(dict_ids.size()) - 1])
        dict_weights_rev = torch.flip(dict_weights, [len(dict_weights.size()) - 1])
        dict_lengths_rev = torch.flip(dict_lengths, [len(dict_lengths.size()) - 1])
        dict_feat_rev = (dict_ids_rev, dict_weights_rev, dict_lengths_rev)

    embedding_input = (
        [tokens_list_rev, dict_feat_rev]
        if dict_feat_rev is not None
        else [tokens_list_rev]
    )
    tok_embeddings = self.embedding(*embedding_input)
    # Batch size is always 1, so squeeze out the batch dimension.
    tok_embeddings = tok_embeddings.squeeze(0)
    tokens_list_rev = tokens_list_rev.squeeze(0)

    initial_state = ParserState(self)
    for i in range(tok_embeddings.size()[0]):
        tok_embedding = tok_embeddings[i].unsqueeze(0)
        tok = tokens_list_rev[i]
        initial_state.buffer_stackrnn.push(tok_embedding, Element(tok))

    beam = [initial_state]
    while beam and any(not state.finished() for state in beam):
        # Stores plans for expansion as (score, state, action)
        plans: List[Tuple[float, ParserState, int]] = []
        # Expand current beam states
        for state in beam:
            # Keep terminal states without expanding them further
            if state.finished():
                plans.append((state.neg_prob, state, -1))
                continue

            # translating Expression p_t = affine_transform({pbias, S,
            # stack_summary, B, buffer_summary, A, action_summary});
            stack = state.stack_stackrnn
            stack_summary = stack.embedding()
            action_summary = state.action_stackrnn.embedding()
            buffer_summary = state.buffer_stackrnn.embedding()
            if self.dropout_layer.p > 0:
                stack_summary = self.dropout_layer(stack_summary)
                action_summary = self.dropout_layer(action_summary)
                buffer_summary = self.dropout_layer(buffer_summary)

            # Feature for the index of the last open non-terminal
            last_open_NT_feature = torch.zeros(len(self.actions_vocab))
            open_NT_exists = state.num_open_NT > 0
            if (
                len(stack) > 0
                and open_NT_exists
                and self.ablation_use_last_open_NT_feature
            ):
                last_open_NT = None
                try:
                    open_NT = state.is_open_NT[::-1].index(True)
                    last_open_NT = stack.ele_from_top(open_NT)
                except ValueError:
                    pass
                if last_open_NT:
                    last_open_NT_feature[last_open_NT.node] = 1.0
            last_open_NT_feature = last_open_NT_feature.unsqueeze(0)

            summaries = []
            if self.ablation_use_buffer:
                summaries.append(buffer_summary)
            if self.ablation_use_stack:
                summaries.append(stack_summary)
            if self.ablation_use_action:
                summaries.append(action_summary)
            if self.ablation_use_last_open_NT_feature:
                summaries.append(last_open_NT_feature)

            action_p = self.action_linear(torch.cat(summaries, dim=1))
            log_probs = F.log_softmax(action_p, dim=1)[0]
            for action in self.valid_actions(state):
                plans.append((state.neg_prob - log_probs[action], state, action))

        beam = []
        # Take actions to regenerate the beam; a toy sketch of this
        # sort-and-prune step follows the function.
        for neg_prob, state, predicted_action_idx in sorted(plans)[:beam_size]:
            # Terminal states pass through unchanged
            if state.finished():
                beam.append(state)
                continue
            # Only branch out states when needed
            if beam_size > 1:
                state = state.copy()

            state.predicted_actions_idx.append(predicted_action_idx)

            target_action_idx = predicted_action_idx
            if self.training:
                assert (
                    len(actions_idx_rev) > 0
                ), "Actions and tokens may not be in sync."
                target_action_idx = actions_idx_rev[-1]
                actions_idx_rev = actions_idx_rev[:-1]

            if not (
                self.constraints_ignore_loss_for_unsupported
                and state.found_unsupported
            ):
                state.action_scores.append(action_p)

            # Apply the action; push_action handles the SHIFT/REDUCE/NT
            # stack updates (the same logic was previously inlined here).
            self.push_action(state, target_action_idx)
            state.neg_prob = neg_prob
            beam.append(state)
        # End for
    # End while

    assert len(beam) > 0, "Beam is unexpectedly empty"
    assert len(state.stack_stackrnn) == 1, "Unexpected stack length: " + str(
        len(state.stack_stackrnn)
    )
    assert len(state.buffer_stackrnn) == 0, "Unexpected buffer length: " + str(
        len(state.buffer_stackrnn)
    )

    # Add batch dimension before returning.
    if topk <= 1:
        state = min(beam)
        return (
            torch.LongTensor(state.predicted_actions_idx).unsqueeze(0),
            torch.cat(state.action_scores).unsqueeze(0),
        )
    else:
        return [
            (
                torch.LongTensor(state.predicted_actions_idx).unsqueeze(0),
                torch.cat(state.action_scores).unsqueeze(0),
            )
            for state in sorted(beam)[:topk]
        ]
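
# Toy sketch (plain floats and strings standing in for ParserState;
# hypothetical scores) of the sort-and-prune beam step in forward: every live
# state proposes (cumulative_neg_log_prob, state, action) plans, the cheapest
# beam_size plans survive, and min(beam) is the most probable hypothesis.
def _example_beam_step():
    import math

    beam_size = 2
    beam = [(0.1, "stateA"), (0.7, "stateB")]  # (neg_log_prob, state id)
    action_log_probs = {"SHIFT": math.log(0.6), "REDUCE": math.log(0.4)}

    plans = [
        (neg_prob - log_p, state_id, action)
        for neg_prob, state_id in beam
        for action, log_p in action_log_probs.items()
    ]
    pruned = sorted(plans)[:beam_size]  # lowest cumulative cost first
    best = min(pruned)
    return pruned, best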