Example #1
 def forward(self, input, grow=False):
     """
     Inputs Shapes: 
         input: batch_size x len_src (wanna tranpose)
     
     Outputs Shapes:
         out: batch_size x len_src x d_model
         mask_src 
         
     """
     
     if grow:
         return self.forward_grow(input)
     
     
     """ Embedding: batch_size x len_src x d_model """
     emb = embedded_dropout(self.word_lut, input, dropout=self.word_dropout if self.training else 0)
     """ Scale the emb by sqrt(d_model) """
     
     if self.time == 'positional_encoding':
         emb = emb * math.sqrt(self.model_size)
     """ Adding positional encoding """
     emb = self.time_transformer(emb)
     if isinstance(emb, tuple):
         emb = emb[0]
     emb = self.preprocess_layer(emb)
     
     mask_src = input.data.eq(onmt.Constants.PAD).unsqueeze(1) # batch_size x 1 x len_src for broadcasting
     
     pad_mask = torch.autograd.Variable(input.data.ne(onmt.Constants.PAD)) # batch_size x len_src
     
     context = emb.contiguous()
     
     for i, layer in enumerate(self.layer_modules):
         
         if len(self.layer_modules) - i <= onmt.Constants.checkpointing and self.training:
             # gradient checkpointing on the last few layers: recompute activations
             # in the backward pass instead of storing them
             context = checkpoint(custom_layer(layer), context, mask_src, pad_mask)
         else:
             context = layer(context, mask_src, pad_mask)      # batch_size x len_src x d_model
     
     # From Google T2T
     # if normalization is done in layer_preprocess, then it should also be done
     # on the output, since the output can grow very large, being the sum of
     # a whole stack of unnormalized layer outputs.    
     context = self.postprocess_layer(context)
     
    
     return context, mask_src
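
Both examples feed the tokens through embedded_dropout before the positional encoding. That helper is not shown in this file; in this family of codebases it is usually the word-level embedding dropout from the AWD-LSTM recipe, which zeroes entire rows of the embedding matrix and rescales the survivors. A minimal sketch of that behaviour, with hypothetical names, is:

import torch.nn.functional as F

def embedded_dropout_sketch(embed, words, dropout=0.1):
    # Illustrative re-implementation only; the real embedded_dropout may differ.
    # Drop whole word vectors: one Bernoulli keep-mask per vocabulary row,
    # zero the dropped rows, rescale the kept ones by 1 / (1 - dropout).
    if dropout > 0:
        keep = embed.weight.new_empty((embed.weight.size(0), 1)).bernoulli_(1 - dropout)
        masked_weight = keep / (1 - dropout) * embed.weight
    else:
        masked_weight = embed.weight
    return F.embedding(words, masked_weight, embed.padding_idx, embed.max_norm,
                       embed.norm_type, embed.scale_grad_by_freq, embed.sparse)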
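The checkpointing branch wraps each layer with custom_layer before handing it to torch.utils.checkpoint.checkpoint, which re-runs the layer during the backward pass instead of keeping its activations. custom_layer is defined elsewhere in the repository; it is presumably a thin closure that turns an nn.Module into the plain callable checkpoint expects, along these lines (a sketch, not the actual definition):

from torch.utils.checkpoint import checkpoint

def custom_layer(module):
    # Wrap an nn.Module in a plain function, which is what
    # torch.utils.checkpoint.checkpoint takes as its first argument.
    def custom_forward(*args):
        return module(*args)
    return custom_forward

# usage mirrors the loop above: only the last onmt.Constants.checkpointing
# layers are checkpointed, and only while training
# context = checkpoint(custom_layer(layer), context, mask_src, pad_mask)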
Example #2
    def forward(self, input, context, src, grow=False):
        """
        Inputs Shapes: 
            input: (Variable) batch_size x len_tgt (wanna tranpose)
            context: (Variable) batch_size x len_src x d_model
            mask_src (Tensor) batch_size x len_src
        Outputs Shapes:
            out: batch_size x len_tgt x d_model
            coverage: batch_size x len_tgt x len_src
            
        """
        
        """ Embedding: batch_size x len_tgt x d_model """
        
        if grow:
            return self.forward_grow(input, context, src)

        
        emb = embedded_dropout(self.word_lut, input, dropout=self.word_dropout if self.training else 0)
        if self.time == 'positional_encoding':
            emb = emb * math.sqrt(self.model_size)
        """ Adding positional encoding """
        emb = self.time_transformer(emb)
        if isinstance(emb, tuple):
            emb = emb[0]
        emb = self.preprocess_layer(emb)
        

        mask_src = src.data.eq(onmt.Constants.PAD).unsqueeze(1)                  # batch_size x 1 x len_src
        
        pad_mask_src = torch.autograd.Variable(src.data.ne(onmt.Constants.PAD))  # batch_size x len_src
        
        len_tgt = input.size(1)
        # combine the target padding mask with the causal mask that blocks future positions
        mask_tgt = input.data.eq(onmt.Constants.PAD).unsqueeze(1) + self.mask[:len_tgt, :len_tgt]
        mask_tgt = torch.gt(mask_tgt, 0)
        
        output = emb.contiguous()
        
        pad_mask_tgt = torch.autograd.Variable(input.data.ne(onmt.Constants.PAD)) # batch_size x len_tgt
        
        
        for i, layer in enumerate(self.layer_modules):
            
            if len(self.layer_modules) - i <= onmt.Constants.checkpointing and self.training:
                # gradient checkpointing on the last few layers: recompute activations
                # in the backward pass instead of storing them
                output, coverage = checkpoint(custom_layer(layer), output, context[i], mask_tgt, mask_src,
                                              pad_mask_tgt, pad_mask_src)  # batch_size x len_tgt x d_model
            else:
                output, coverage = layer(output, context[i], mask_tgt, mask_src,
                                         pad_mask_tgt, pad_mask_src)       # batch_size x len_tgt x d_model
            
        # From Google T2T
        # if normalization is done in layer_preprocess, then it should also be done
        # on the output, since the output can grow very large, being the sum of
        # a whole stack of unnormalized layer outputs.    
        output = self.postprocess_layer(output)
        
        return output, coverage
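
The decoder builds mask_tgt by adding the target padding mask to self.mask, a pre-computed causal mask that hides future positions; torch.gt(mask_tgt, 0) then collapses the sum back into a 0/1 mask. self.mask is registered elsewhere in the module; a buffer like the following (an assumption about its construction, consistent with how it is sliced above) would give the expected shape:

import torch

max_len = 512  # hypothetical upper bound on the target length

# 1s strictly above the diagonal: position t may not attend to positions > t.
# Slicing self.mask[:len_tgt, :len_tgt] in forward() recovers the square causal
# mask, which broadcasts against the batch_size x 1 x len_tgt padding mask.
causal_mask = torch.triu(torch.ones(max_len, max_len), diagonal=1).byte()
# e.g. registered in the decoder's __init__ as: self.register_buffer('mask', causal_mask)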