class DeepAPI(nn.Module):
    """Seq2seq model mapping a natural-language description to an API sequence.

    An RNN encoder compresses the description; an (optionally attentional)
    RNN decoder emits the API token sequence. Training minimizes
    teacher-forced cross-entropy over non-pad target positions.
    """

    def __init__(self, config, vocab_size):
        super(DeepAPI, self).__init__()
        self.vocab_size = vocab_size
        self.maxlen = config['maxlen']   # decode-length cap used by sample()
        self.clip = config['clip']       # gradient-norm clip threshold
        self.temp = config['temp']       # softmax temperature for the loss

        # Separate embedding tables for description tokens and API tokens.
        self.desc_embedder = nn.Embedding(vocab_size, config['emb_size'],
                                          padding_idx=PAD_ID)
        self.api_embedder = nn.Embedding(vocab_size, config['emb_size'],
                                         padding_idx=PAD_ID)

        # Utterance encoder: encodes the description into a context vector.
        self.encoder = Encoder(self.desc_embedder, config['emb_size'],
                               config['n_hidden'], True,
                               config['n_layers'], config['noise_radius'])
        # Utterance decoder: P(x|c,z).
        self.decoder = Decoder(self.api_embedder, config['emb_size'],
                               config['n_hidden'] * 2, vocab_size,
                               config['use_attention'], 1, config['dropout'])

        trainable = list(self.encoder.parameters()) + list(self.decoder.parameters())
        self.optimizer = optim.Adadelta(trainable, lr=config['lr_ae'], rho=0.95)
        self.criterion_ce = nn.CrossEntropyLoss()

    def forward(self, descs, desc_lens, apiseqs, api_lens):
        """Return the teacher-forced CE loss over non-pad target tokens."""
        context, enc_hids = self.encoder(descs, desc_lens)
        # Feed targets shifted right; predict targets shifted left.
        logits, _ = self.decoder(context, enc_hids, None,
                                 apiseqs[:, :-1], (api_lens - 1))
        # logits: [batch, seq_len, n_tokens] -> flatten for the criterion.
        flat_logits = logits.view(-1, self.vocab_size)
        flat_target = apiseqs[:, 1:].contiguous().view(-1)

        keep = flat_target.gt(0)  # True at non-pad positions
        kept_target = flat_target.masked_select(keep)
        keep_rows = keep.unsqueeze(1).expand(keep.size(0), self.vocab_size)
        kept_logits = flat_logits.masked_select(keep_rows).view(-1, self.vocab_size)
        return self.criterion_ce(kept_logits / self.temp, kept_target)

    def train_AE(self, descs, desc_lens, apiseqs, api_lens):
        """One optimizer step on a batch; returns the training loss."""
        self.encoder.train()
        self.decoder.train()
        loss = self.forward(descs, desc_lens, apiseqs, api_lens)
        self.optimizer.zero_grad()
        loss.backward()
        # Clip gradients to prevent explosion in RNNs / LSTMs.
        params = list(self.encoder.parameters()) + list(self.decoder.parameters())
        torch.nn.utils.clip_grad_norm_(params, self.clip)
        self.optimizer.step()
        return {'train_loss': loss.item()}

    def valid(self, descs, desc_lens, apiseqs, api_lens):
        """Evaluate the loss on a batch without updating weights."""
        self.encoder.eval()
        self.decoder.eval()
        loss = self.forward(descs, desc_lens, apiseqs, api_lens)
        return {'valid_loss': loss.item()}

    def sample(self, descs, desc_lens, n_samples, mode='beamsearch'):
        """Decode API sequences for descriptions via beam search or sampling."""
        self.encoder.eval()
        self.decoder.eval()
        context, enc_hids = self.encoder(descs, desc_lens)
        if mode == 'beamsearch':
            # sample_words: [batch_size x n_samples x seq_len]
            sample_words, sample_lens, _ = self.decoder.beam_decode(
                context, enc_hids, None, 12, self.maxlen, n_samples)
            sample_words, sample_lens = sample_words[0], sample_lens[0]
        else:
            sample_words, sample_lens = self.decoder.sampling(
                context, enc_hids, None, n_samples, self.maxlen, mode)
        return sample_words, sample_lens

    def adjust_lr(self):
        # Learning-rate scheduling is disabled; kept for interface parity.
        # self.lr_scheduler_AE.step()
        return None
class Transformer_EncoderDecoder(nn.Module):
    """Standard Transformer encoder-decoder architecture.

    Wraps the classic "Annotated Transformer" components (multi-head
    attention, position-wise FFN, positional encoding) into a single
    trainable module whose forward() returns the masked CE loss.
    """

    def __init__(self, config):
        super(Transformer_EncoderDecoder, self).__init__()
        c = copy.deepcopy
        # Kept so the loss never hard-codes the vocabulary size.
        self.vocab_size = config['vocab_size']

        self.attn = MultiHeadedAttention(config['head'], config['emb_dim'])
        self.ff = PositionwiseFeedForward(config['emb_dim'], config['d_ff'],
                                          config['drop_out'])
        self.position = PositionalEncoding(config['emb_dim'], config['drop_out'])
        self.encoder = Encoder(
            EncoderLayer(config['emb_dim'], c(self.attn), c(self.ff),
                         config['drop_out']), config['N_layers'])
        self.decoder = Decoder(
            DecoderLayer(config['emb_dim'], c(self.attn), c(self.attn),
                         c(self.ff), config['drop_out']), config['N_layers'])
        self.src_embed = nn.Sequential(
            Embeddings(config['emb_dim'], config['vocab_size']), c(self.position))
        self.tgt_embed = nn.Sequential(
            Embeddings(config['emb_dim'], config['vocab_size']), c(self.position))
        self.generator = Generator(config['emb_dim'], config['vocab_size'])
        self.fc_out = nn.Linear(config['emb_dim'], config['vocab_size'])
        self.model = EncoderDecoder(self.encoder, self.decoder, self.src_embed,
                                    self.tgt_embed, self.generator)

        # BUG FIX: parameter initialization used to live inside forward(),
        # which re-randomized every weight on every call, so the model could
        # never learn. It must run exactly once, here. Also uses the
        # non-deprecated in-place xavier_uniform_.
        for p in self.model.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, src_lens, tgt, tar_lens, pad=0):
        """Take in masked src/tgt sequences and return the CE loss.

        src, tgt: [batch_size, max_length] token-id tensors. src_lens and
        tar_lens are accepted for interface parity with the RNN model but
        are unused (masking is derived from `pad`).
        """
        self.src = src
        self.src_mask = (src != pad).unsqueeze(-2)
        if tgt is not None:
            # self.tgt drops the last token of each row (input at step t-1);
            # self.tgt_y drops the first (target at step t). The decoder
            # predicts t from the encoder output and tokens up to t-1.
            self.tgt = tgt[:, :-1]
            self.tgt_y = tgt[:, 1:]
            self.tgt_mask = self.make_std_mask(self.tgt, pad)
            self.ntokens = (self.tgt_y != pad).data.sum()

        # Model workflow: encode, then decode with teacher forcing.
        output = self.decode(self.encode(src, self.src_mask), self.src_mask,
                             tgt[:, :-1], self.tgt_mask)
        output = self.fc_out(output)  # [batch, seq_len-1, vocab_size]

        dec_tgt = tgt[:, 1:].clone()
        # -100 is CrossEntropyLoss's default ignore_index: pad positions
        # contribute nothing to the loss.
        dec_tgt[tgt[:, 1:] == PAD_ID] = -100
        # FIX: use the configured vocab size instead of a hard-coded 10000,
        # and drop the no-op "/ 1.0" temperature.
        loss = nn.CrossEntropyLoss()(output.view(-1, self.vocab_size),
                                     dec_tgt.view(-1))
        return loss

    def init_weights(self, m):
        """Initialize Linear weights (used externally, e.g. for GAN setups)."""
        if isinstance(m, nn.Linear):
            m.weight.data.uniform_(-0.08, 0.08)
            # nn.init.xavier_normal_(m.weight)
            nn.init.constant_(m.bias, 0.)

    def valid(self, src_seqs, src_lens, target, tar_lens):
        """Evaluate the loss on a batch without updating weights."""
        self.eval()
        loss = self.forward(src_seqs, src_lens, target, tar_lens)
        return {'valid_loss': loss.item()}

    def sample(self, src_seqs, src_lens, n_samples, decode_mode='beamsearch'):
        """Decode output sequences for the given inputs.

        NOTE(review): this body appears copy-pasted from the RNN seq2seq
        model — `self.ctx2dec` and `self.maxlen` are never defined on this
        class, and `self.encoder` here is a Transformer encoder that expects
        (embedded_src, mask), not (src_seqs, src_lens). As written it will
        raise AttributeError; it needs a transformer-specific decode path
        (e.g. greedy/beam decoding over self.encode/self.decode). Left
        intact and flagged rather than guessed at.
        """
        self.eval()
        src_pad_mask = src_seqs.eq(PAD_ID)
        c, hids = self.encoder(src_seqs, src_lens)
        init_h, hids = self.ctx2dec(c), self.ctx2dec(hids)
        if decode_mode == 'beamsearch':
            # sample_words: [batch_size x n_samples x seq_len]
            sample_words, sample_lens, _ = self.decoder.beam_decode(
                init_h, hids, src_pad_mask, None, 12, self.maxlen, n_samples)
            sample_words, sample_lens = sample_words[0], sample_lens[0]
        else:
            sample_words, sample_lens = self.decoder.sampling(
                init_h, hids, src_pad_mask, None, self.maxlen, decode_mode)
        return sample_words, sample_lens

    def adjust_lr(self):
        # Learning-rate scheduling is disabled; kept for interface parity.
        # self.lr_scheduler_AE.step()
        return None

    def encode(self, src, src_mask):
        """Embed and encode the source sequence."""
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        """Embed the target and decode against the encoder memory."""
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)

    @staticmethod
    def make_std_mask(tgt, pad):
        """Create a mask that hides pad tokens and future positions."""
        tgt_mask = (tgt != pad).unsqueeze(-2)
        tgt_mask = tgt_mask & Variable(
            subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data))
        return tgt_mask