class ParallelTransformerEncoder(nn.Module):
    """Encoder in 'Attention is all you need'
    
    Args:
        opt: list of options ( see train.py )
        dicts : dictionary (for source language)
        
    """
    
    def __init__(self, opt, dicts, positional_encoder):
    
        super(ParallelTransformerEncoder, self).__init__()
        
        self.model_size = opt.model_size
        self.n_heads = opt.n_heads
        self.inner_size = opt.inner_size
        self.layers = opt.layers
        self.dropout = opt.dropout
        self.word_dropout = opt.word_dropout
        self.attn_dropout = opt.attn_dropout
        self.emb_dropout = opt.emb_dropout
        self.time = opt.time
        
        if hasattr(opt, 'grow_dropout'):
            self.grow_dropout = opt.grow_dropout
        
        self.word_lut = nn.Embedding(dicts.size(),
                                     self.model_size,
                                     padding_idx=onmt.Constants.PAD)
        
        if opt.time == 'positional_encoding':
            self.time_transformer = positional_encoder
        elif opt.time == 'gru':
            self.time_transformer = nn.GRU(self.model_size, self.model_size, 1, batch_first=True)
        elif opt.time == 'lstm':
            self.time_transformer = nn.LSTM(self.model_size, self.model_size, 1, batch_first=True)
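        # Added note: 'positional_encoding' reuses the positional_encoder module passed to the
        # constructor, while 'gru'/'lstm' run a single recurrent layer over the embeddings as a
        # learned temporal encoding (its output replaces the raw embedding in forward()).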
        
        #~ self.preprocess_layer = PrePostProcessing(self.model_size, self.emb_dropout, sequence='d', static=False)
        self.preprocess_layer = PrePostProcessing(self.model_size, self.emb_dropout, sequence='d', static=onmt.Constants.static)
        
        self.postprocess_layer = PrePostProcessing(self.model_size, 0, sequence='n')
        
        self.positional_encoder = positional_encoder
        
        self.layer_modules = nn.ModuleList([ParallelEncoderLayer(self.n_heads, self.model_size, self.dropout, self.inner_size, self.attn_dropout) for _ in range(self.layers)])
    
    def add_layers(self, n_new_layer):
        
        self.new_modules = list()
        self.layers += n_new_layer
        
        for i in range(n_new_layer):
            layer = ParallelEncoderLayer(self.n_heads, self.model_size, self.dropout, self.inner_size, self.attn_dropout) 
            
            # the first new layer initializes its attention preprocessing from the current postprocessing (layer norm)
            if i == 0:
                layer.preprocess_attn.load_state_dict(self.postprocess_layer.state_dict())
                #~ layer.preprocess_attn.layer_norm.function.weight.requires_grad = False
                #~ layer.preprocess_attn.layer_norm.function.bias.requires_grad = False
                #~ if hasattr(layer.postprocess_attn, 'k'):
                    #~ layer.postprocess_attn.k.data.fill_(0.01)
                
                # replace the last postprocessing layer with a new one
                self.postprocess_layer = PrePostProcessing(self.model_size, 0, sequence='n')
            
            self.layer_modules.append(layer)
    
    def mark_pretrained(self):
        
        self.pretrained_point = self.layers
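
    # A minimal usage sketch of the layer-growing workflow (illustrative only; 'pretrained_weights'
    # and 'src_input' are hypothetical names, not part of this module):
    #
    #   encoder = ParallelTransformerEncoder(opt, dicts, positional_encoder)
    #   encoder.load_state_dict(pretrained_weights)
    #   encoder.mark_pretrained()                 # remember how many layers are pretrained
    #   encoder.add_layers(2)                     # stack two fresh layers on top
    #   memory_bank, mask_src = encoder(src_input, grow=True)
    #
    # With grow=True, forward() delegates to forward_grow(): the pretrained layers run under
    # torch.no_grad() and only the newly added layers receive gradients.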
    
    def forward(self, input, grow=False):
        """
        Inputs Shapes: 
            input: batch_size x len_src (wanna transpose)
        
        Outputs Shapes:
            memory_bank: n_layers x batch_size x len_src x d_model
            mask_src: batch_size x 1 x len_src
            
        """
        
        if grow:
            return self.forward_grow(input)
        
        
        """ Embedding: batch_size x len_src x d_model """
        emb = embedded_dropout(self.word_lut, input, dropout=self.word_dropout if self.training else 0)
        """ Scale the emb by sqrt(d_model) """
        
        if self.time == 'positional_encoding':
            emb = emb * math.sqrt(self.model_size)
        """ Adding positional encoding """
        emb = self.time_transformer(emb)
        if isinstance(emb, tuple):
            emb = emb[0]
        emb = self.preprocess_layer(emb)
        
        mask_src = input.data.eq(onmt.Constants.PAD).unsqueeze(1) # batch_size x 1 x len_src for broadcasting
        
        pad_mask = torch.autograd.Variable(input.data.ne(onmt.Constants.PAD)) # batch_size x len_src
        #~ pad_mask = None
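        # Illustrative shape sketch (added comment; toy values, assuming PAD = 0):
        #   input.data = [[5, 8, 2, 0],
        #                 [7, 0, 0, 0]]   # batch_size = 2, len_src = 4
        # mask_src is (2, 1, 4) with ones at padded positions (broadcast over the query axis
        # inside attention), while pad_mask is (2, 4) with ones at real-token positions.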
        
        context = emb.contiguous()
        
        memory_bank = list()
        
        for i, layer in enumerate(self.layer_modules):
            
            
            if len(self.layer_modules) - i <= onmt.Constants.checkpointing and self.training:        
                context, norm_input = checkpoint(custom_layer(layer), context, mask_src, pad_mask)
                
                #~ print(type(context))
            else:
                context, norm_input = layer(context, mask_src, pad_mask)      # batch_size x len_src x d_model
            
            if i > 0: # don't keep the norm input of the first layer (i.e. the embedding)
                memory_bank.append(norm_input)
                
        
        # From Google T2T
        # if normalization is done in layer_preprocess, then it should also be done
        # on the output, since the output can grow very large, being the sum of
        # a whole stack of unnormalized layer outputs.    
        context = self.postprocess_layer(context)
        
        # make a huge memory bank on the encoder side
        memory_bank.append(context)
        
        memory_bank = torch.stack(memory_bank)
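        # At this point memory_bank holds self.layers entries: the normalized inputs of every
        # layer except the first, plus the normalized top-layer output. Assuming each entry is
        # batch_size x len_src x d_model, the stacked shape is
        #   n_layers x batch_size x len_src x d_model,
        # which the parallel decoder below consumes layer-by-layer as context[i].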
            
        
        return memory_bank, mask_src
        
    def forward_grow(self, input):
        """
        Inputs Shapes: 
            input: batch_size x len_src (wanna transpose)
        
        Outputs Shapes:
            memory_bank: n_layers x batch_size x len_src x d_model
            mask_src: batch_size x 1 x len_src
            
        """
        
        with torch.no_grad():
            """ Embedding: batch_size x len_src x d_model """
            emb = embedded_dropout(self.word_lut, input, dropout=self.word_dropout if self.training else 0)
            """ Scale the emb by sqrt(d_model) """
            
            if self.time == 'positional_encoding':
                emb = emb * math.sqrt(self.model_size)
            """ Adding positional encoding """
            emb = self.time_transformer(emb)
            if isinstance(emb, tuple):
                emb = emb[0]
            emb = self.preprocess_layer(emb)
            
            mask_src = input.data.eq(onmt.Constants.PAD).unsqueeze(1) # batch_size x 1 x len_src for broadcasting
            
            pad_mask = torch.autograd.Variable(input.data.ne(onmt.Constants.PAD)) # batch_size x len_src
            #~ pad_mask = None
            
            context = emb.contiguous()
            
            memory_bank = list()
            
            for i in range(self.pretrained_point):
                
                layer = self.layer_modules[i]
                
                context, norm_input = layer(context, mask_src, pad_mask)      # batch_size x len_src x d_model
                
                if i > 0: # don't keep the norm input of the first layer (i.e. the embedding)
                    memory_bank.append(norm_input)
                    
        
        for i in range(self.layers - self.pretrained_point):
            
            res_drop_rate = 0.0
            if i == 0:
                res_drop_rate = self.grow_dropout
            
            layer = self.layer_modules[self.pretrained_point + i]
            
            context, norm_input = layer(context, mask_src, pad_mask, residual_dropout=res_drop_rate)      # batch_size x len_src x d_model
            
            memory_bank.append(norm_input)
        
        # From Google T2T
        # if normalization is done in layer_preprocess, then it should also be done
        # on the output, since the output can grow very large, being the sum of
        # a whole stack of unnormalized layer outputs.    
        context = self.postprocess_layer(context)
        
        # make a huge memory bank on the encoder side
        memory_bank.append(context)
        
        memory_bank = torch.stack(memory_bank)
            
        
        return memory_bank, mask_src


class ParallelTransformerDecoder(nn.Module):
    """Encoder in 'Attention is all you need'
    
    Args:
        opt
        dicts 
        
        
    """
    
    def __init__(self, opt, dicts, positional_encoder):
    
        super(ParallelTransformerDecoder, self).__init__()
        
        self.model_size = opt.model_size
        self.n_heads = opt.n_heads
        self.inner_size = opt.inner_size
        self.layers = opt.layers
        self.dropout = opt.dropout
        self.word_dropout = opt.word_dropout 
        self.attn_dropout = opt.attn_dropout
        self.emb_dropout = opt.emb_dropout
        self.time = opt.time
        
        if hasattr(opt, 'grow_dropout'):
            self.grow_dropout = opt.grow_dropout
        
        if opt.time == 'positional_encoding':
            self.time_transformer = positional_encoder
        elif opt.time == 'gru':
            self.time_transformer = nn.GRU(self.model_size, self.model_size, 1, batch_first=True)
        elif opt.time == 'lstm':
            self.time_transformer = nn.LSTM(self.model_size, self.model_size, 1, batch_first=True)
        
        #~ self.preprocess_layer = PrePostProcessing(self.model_size, self.emb_dropout, sequence='d', static=False)
        self.preprocess_layer = PrePostProcessing(self.model_size, self.emb_dropout, sequence='d', static=onmt.Constants.static)
        self.postprocess_layer = PrePostProcessing(self.model_size, 0, sequence='n')
        
        self.word_lut = nn.Embedding(dicts.size(),
                                     self.model_size,
                                     padding_idx=onmt.Constants.PAD)
        
        self.positional_encoder = positional_encoder
        
        self.layer_modules = nn.ModuleList([DecoderLayer(self.n_heads, self.model_size, self.dropout, self.inner_size, self.attn_dropout) for _ in range(self.layers)])
        
        len_max = self.positional_encoder.len_max
        mask = torch.ByteTensor(np.triu(np.ones((len_max,len_max)), k=1).astype('uint8'))
        self.register_buffer('mask', mask)
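
        # Illustrative example (added comment): for len_max = 4 the buffer registered above is
        #   [[0, 1, 1, 1],
        #    [0, 0, 1, 1],
        #    [0, 0, 0, 1],
        #    [0, 0, 0, 0]]
        # so position t cannot attend to positions > t. forward() slices mask[:len_tgt, :len_tgt],
        # adds the target padding mask, and thresholds with torch.gt(..., 0).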
    
    def renew_buffer(self, new_len):
        
        self.positional_encoder.renew(new_len)
        mask = torch.ByteTensor(np.triu(np.ones((new_len,new_len)), k=1).astype('uint8'))
        self.register_buffer('mask', mask)
    
    def mark_pretrained(self):
        
        self.pretrained_point = self.layers
        
    
    def add_layers(self, n_new_layer):
        
        self.new_modules = list()
        self.layers += n_new_layer
        
        for i in range(n_new_layer):
            layer = DecoderLayer(self.n_heads, self.model_size, self.dropout, self.inner_size, self.attn_dropout) 
            # the first new layer initializes its attention preprocessing from the current postprocessing (layer norm)
            if i == 0:
                # layer.preprocess_attn = self.postprocess_layer
                layer.preprocess_attn.load_state_dict(self.postprocess_layer.state_dict())
                #~ layer.preprocess_attn.layer_norm.function.weight.requires_grad = False
                #~ layer.preprocess_attn.layer_norm.function.bias.requires_grad = False
                # replace the last postprocessing layer with a new one
                #~ if hasattr(layer.postprocess_attn, 'k'):
                    #~ layer.postprocess_attn.k.data.fill_(0.01)
                
                self.postprocess_layer = PrePostProcessing(self.model_size, 0, sequence='n')
            
            self.layer_modules.append(layer)
        
    def forward(self, input, context, src, grow=False):
        """
        Inputs Shapes: 
            input: (Variable) batch_size x len_tgt (wanna transpose)
            context: (Variable) n_layers x batch_size x len_src x d_model (encoder memory bank)
            src: (Tensor) batch_size x len_src
        Outputs Shapes:
            out: batch_size x len_tgt x d_model
            coverage: batch_size x len_tgt x len_src
            
        """
        
        """ Embedding: batch_size x len_tgt x d_model """
        
        if grow:
            return self.forward_grow(input, context, src)

        
        emb = embedded_dropout(self.word_lut, input, dropout=self.word_dropout if self.training else 0)
        if self.time == 'positional_encoding':
            emb = emb * math.sqrt(self.model_size)
        """ Adding positional encoding """
        emb = self.time_transformer(emb)
        if isinstance(emb, tuple):
            emb = emb[0]
        emb = self.preprocess_layer(emb)
        

        mask_src = src.data.eq(onmt.Constants.PAD).unsqueeze(1)
        
        
        len_tgt = input.size(1)
        mask_tgt = input.data.eq(onmt.Constants.PAD).unsqueeze(1) + self.mask[:len_tgt, :len_tgt]
        mask_tgt = torch.gt(mask_tgt, 0)
        
        output = emb.contiguous()
        
        pad_mask_tgt = torch.autograd.Variable(input.data.ne(onmt.Constants.PAD)) # batch_size x len_tgt
        pad_mask_src = torch.autograd.Variable(1 - mask_src.squeeze(1))
        
        #~ memory_bank = None
        
        
        for i, layer in enumerate(self.layer_modules):
            
            if len(self.layer_modules) - i <= onmt.Constants.checkpointing and self.training:           
                
                output, coverage = checkpoint(custom_layer(layer), output, context[i], mask_tgt, mask_src, 
                                            pad_mask_tgt, pad_mask_src) # batch_size x len_src x d_model
                
            else:
                output, coverage = layer(output, context[i], mask_tgt, mask_src, 
                                            pad_mask_tgt, pad_mask_src) # batch_size x len_src x d_model
            
            
        # From Google T2T
        # if normalization is done in layer_preprocess, then it should also be done
        # on the output, since the output can grow very large, being the sum of
        # a whole stack of unnormalized layer outputs.    
        output = self.postprocess_layer(output)
        
        return output, coverage
        
    def forward_grow(self, input, context, src):
        """
        Inputs Shapes: 
            input: (Variable) batch_size x len_tgt (wanna transpose)
            context: (Variable) n_layers x batch_size x len_src x d_model (encoder memory bank)
            src: (Tensor) batch_size x len_src
        Outputs Shapes:
            out: batch_size x len_tgt x d_model
            coverage: batch_size x len_tgt x len_src
            
        """
        
        """ Embedding: batch_size x len_tgt x d_model """
        
        with torch.no_grad():
        
            emb = embedded_dropout(self.word_lut, input, dropout=self.word_dropout if self.training else 0)
            if self.time == 'positional_encoding':
                emb = emb * math.sqrt(self.model_size)
            """ Adding positional encoding """
            emb = self.time_transformer(emb)
            if isinstance(emb, tuple):
                emb = emb[0]
            emb = self.preprocess_layer(emb)
            

            mask_src = src.data.eq(onmt.Constants.PAD).unsqueeze(1)
            
            
            len_tgt = input.size(1)
            mask_tgt = input.data.eq(onmt.Constants.PAD).unsqueeze(1) + self.mask[:len_tgt, :len_tgt]
            mask_tgt = torch.gt(mask_tgt, 0)
            
            output = emb.contiguous()
            
            pad_mask_tgt = torch.autograd.Variable(input.data.ne(onmt.Constants.PAD)) # batch_size x len_tgt
            pad_mask_src = torch.autograd.Variable(1 - mask_src.squeeze(1))
            
            
            for i in range(self.pretrained_point):
                
                layer = self.layer_modules[i]
                
                output, coverage = layer(output, context[i], mask_tgt, mask_src, 
                                                pad_mask_tgt, pad_mask_src) # batch_size x len_src x d_model
            
        
        for i in range(self.layers - self.pretrained_point):
            
            res_drop_rate = 0.0
            if i == 0:
                res_drop_rate = self.grow_dropout
            
            layer = self.layer_modules[self.pretrained_point + i]    
            output, coverage = layer(output, context[self.pretrained_point + i], mask_tgt, mask_src, 
                                                pad_mask_tgt, pad_mask_src, residual_dropout=res_drop_rate) # batch_size x len_src x d_model
        # From Google T2T
        # if normalization is done in layer_preprocess, then it should also be done
        # on the output, since the output can grow very large, being the sum of
        # a whole stack of unnormalized layer outputs.    
        output = self.postprocess_layer(output)
        
        return output, coverage

    #~ def step(self, input, context, src, buffer=None):
    def step(self, input, decoder_state):
        """
        Inputs Shapes: 
            input: (Variable) 1 x batch_size (the most recently generated target token)
            decoder_state: DecoderState holding the encoder context, the source sequence,
                the running target prefix and the per-layer buffers used to avoid
                recomputing self-attention for earlier positions
        Outputs Shapes:
            out: batch_size x 1 x d_model
            coverage: batch_size x 1 x len_src
            
        """
        # note: transpose dims 1-2 because the first dimension (0) is the number of layers
        context = decoder_state.context.transpose(1, 2)
        buffer = decoder_state.buffer
        src = decoder_state.src.transpose(0, 1)
        
        if decoder_state.input_seq is None:
            decoder_state.input_seq = input
        else:
            # concatenate the last input to the previous input sequence
            decoder_state.input_seq = torch.cat([decoder_state.input_seq, input], 0)
        input = decoder_state.input_seq.transpose(0, 1)
        input_ = input[:,-1].unsqueeze(1)
            
        output_buffer = list()
            
        batch_size = input.size(0)
        
        
        """ Embedding: batch_size x 1 x d_model """
        emb = self.word_lut(input_)
        
        
        if self.time == 'positional_encoding':
            emb = emb * math.sqrt(self.model_size)
        """ Adding positional encoding """
        if self.time == 'positional_encoding':
            emb = self.time_transformer(emb, t=input.size(1))
        else:
            prev_h = buffer[0] if buffer is not None else None
            emb = self.time_transformer(emb, prev_h)
            buffer[0] = emb[1]
            
        if isinstance(emb, tuple):
            emb = emb[0] # emb should be batch_size x 1 x dim
        
            
        # Preprocess layer: adding dropout
        emb = self.preprocess_layer(emb)
        
        # batch_size x 1 x len_src
        mask_src = src.data.eq(onmt.Constants.PAD).unsqueeze(1)
        
        
        len_tgt = input.size(1)
        mask_tgt = input.data.eq(onmt.Constants.PAD).unsqueeze(1) + self.mask[:len_tgt, :len_tgt]
        # mask_tgt = self.mask[:len_tgt, :len_tgt].unsqueeze(0).repeat(batch_size, 1, 1)
        mask_tgt = torch.gt(mask_tgt, 0)
        mask_tgt = mask_tgt[:, -1, :].unsqueeze(1)
                
        output = emb.contiguous()
        
        pad_mask_tgt = torch.autograd.Variable(input.data.ne(onmt.Constants.PAD)) # batch_size x len_tgt
        pad_mask_src = torch.autograd.Variable(1 - mask_src.squeeze(1))
        
        
        for i, layer in enumerate(self.layer_modules):
            
            buffer_ = buffer[i] if buffer is not None else None
            assert(output.size(1) == 1)
            output, coverage, buffer_ = layer.step(output, context[i], mask_tgt, mask_src, 
                                        pad_mask_tgt=None, pad_mask_src=None, buffer=buffer_) # batch_size x len_src x d_model
            
            output_buffer.append(buffer_)
            
        
        
        buffer = torch.stack(output_buffer)
        # From Google T2T
        # if normalization is done in layer_preprocess, then it should also be done
        # on the output, since the output can grow very large, being the sum of
        # a whole stack of unnormalized layer outputs.    
        output = self.postprocess_layer(output)
        
        decoder_state._update_state(buffer)    
        
        return output, coverage
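
    # A minimal greedy-decoding sketch using step() (illustrative only; 'decoder_state',
    # 'generator', 'bos' and 'max_len' are hypothetical and assumed to be built elsewhere,
    # e.g. by the translator that wraps the encoder's memory bank in a DecoderState):
    #
    #   input_t = torch.LongTensor(1, batch_size).fill_(bos)        # 1 x batch_size
    #   for t in range(max_len):
    #       out, coverage = decoder.step(input_t, decoder_state)    # out: batch_size x 1 x d_model
    #       scores = generator(out.squeeze(1))                      # project onto the vocabulary
    #       next_token = scores.max(dim=-1)[1]                      # greedy pick, shape: batch_size
    #       input_t = next_token.unsqueeze(0)                       # feed back as 1 x batch_size
    #
    # step() appends input_t to decoder_state.input_seq, embeds only the newest position, and
    # carries per-layer state through the buffers so earlier positions are not recomputed.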