Example No. 1

import copy

import torch
import torch.nn as nn
import torch.optim as optim

# NOTE: project-specific helpers used below (PAD_ID, Encoder, Decoder,
# MultiHeadedAttention, PositionwiseFeedForward, PositionalEncoding,
# EncoderLayer, DecoderLayer, Embeddings, Generator, EncoderDecoder,
# subsequent_mask) are assumed to be importable from the surrounding codebase.
class DeepAPI(nn.Module):
    '''Encoder-decoder model that generates an API call sequence from a natural-language description.'''
    def __init__(self, config, vocab_size):
        super(DeepAPI, self).__init__()
        self.vocab_size = vocab_size
        self.maxlen = config['maxlen']
        self.clip = config['clip']
        self.temp = config['temp']

        self.desc_embedder = nn.Embedding(vocab_size,
                                          config['emb_size'],
                                          padding_idx=PAD_ID)
        self.api_embedder = nn.Embedding(vocab_size,
                                         config['emb_size'],
                                         padding_idx=PAD_ID)
        # description encoder: encodes the NL description into a context vector
        self.encoder = Encoder(self.desc_embedder, config['emb_size'],
                               config['n_hidden'], True, config['n_layers'],
                               config['noise_radius'])
        self.decoder = Decoder(self.api_embedder, config['emb_size'],
                               config['n_hidden'] * 2, vocab_size,
                               config['use_attention'], 1,
                               config['dropout'])  # utter decoder: P(x|c,z)
        self.optimizer = optim.Adadelta(list(self.encoder.parameters()) +
                                        list(self.decoder.parameters()),
                                        lr=config['lr_ae'],
                                        rho=0.95)
        self.criterion_ce = nn.CrossEntropyLoss()

    def forward(self, descs, desc_lens, apiseqs, api_lens):
        c, hids = self.encoder(descs, desc_lens)
        output, _ = self.decoder(c, hids, None, apiseqs[:, :-1],
                                 (api_lens - 1))
        # decode from z, c  # output: [batch x seq_len x n_tokens]
        output = output.view(-1, self.vocab_size)  # [batch*seq_len x n_tokens]

        dec_target = apiseqs[:, 1:].contiguous().view(-1)
        mask = dec_target.gt(0)  # non-pad positions (assumes PAD_ID == 0); [(batch_sz*seq_len)]
        masked_target = dec_target.masked_select(mask)  # target tokens at non-pad positions
        output_mask = mask.unsqueeze(1).expand(
            mask.size(0), self.vocab_size)  # [(batch_sz*seq_len) x n_tokens]

        masked_output = output.masked_select(output_mask).view(
            -1, self.vocab_size)
        loss = self.criterion_ce(masked_output / self.temp, masked_target)
        return loss

    def train_AE(self, descs, desc_lens, apiseqs, api_lens):
        self.encoder.train()
        self.decoder.train()

        loss = self.forward(descs, desc_lens, apiseqs, api_lens)

        self.optimizer.zero_grad()
        loss.backward()
        # `clip_grad_norm` to prevent exploding gradient in RNNs / LSTMs
        torch.nn.utils.clip_grad_norm_(
            list(self.encoder.parameters()) + list(self.decoder.parameters()),
            self.clip)
        self.optimizer.step()
        return {'train_loss': loss.item()}

    def valid(self, descs, desc_lens, apiseqs, api_lens):
        self.encoder.eval()
        self.decoder.eval()
        with torch.no_grad():
            loss = self.forward(descs, desc_lens, apiseqs, api_lens)
        return {'valid_loss': loss.item()}

    def sample(self, descs, desc_lens, n_samples, mode='beamsearch'):
        self.encoder.eval()
        self.decoder.eval()
        c, hids = self.encoder(descs, desc_lens)
        if mode == 'beamsearch':
            sample_words, sample_lens, _ = self.decoder.beam_decode(
                c, hids, None, 12, self.maxlen, n_samples)
            #[batch_size x n_samples x seq_len]
            sample_words, sample_lens = sample_words[0], sample_lens[0]
        else:
            sample_words, sample_lens = self.decoder.sampling(
                c, hids, None, n_samples, self.maxlen, mode)
        return sample_words, sample_lens

    def adjust_lr(self):
        #self.lr_scheduler_AE.step()
        return None
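For orientation, here is a minimal usage sketch for DeepAPI. The config values, vocabulary size, and tensor shapes below are illustrative assumptions, and running it additionally requires the project's Encoder and Decoder modules.

# Hypothetical smoke test for DeepAPI (all values below are assumptions).
config = {'maxlen': 50, 'clip': 1.0, 'temp': 1.0, 'emb_size': 120,
          'n_hidden': 256, 'n_layers': 1, 'noise_radius': 0.2,
          'use_attention': True, 'dropout': 0.1, 'lr_ae': 1.0}
model = DeepAPI(config, vocab_size=10000)

# Random token ids stand in for real batches; PAD_ID is assumed to be 0.
descs = torch.randint(1, 10000, (4, 20))      # [batch, desc_len]
desc_lens = torch.full((4,), 20)              # true length of each description
apiseqs = torch.randint(1, 10000, (4, 30))    # [batch, api_len]
api_lens = torch.full((4,), 30)               # true length of each API sequence

print(model.train_AE(descs, desc_lens, apiseqs, api_lens))   # one optimizer step
print(model.valid(descs, desc_lens, apiseqs, api_lens))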
class Transformer_EncoderDecoder(nn.Module):
    """
    A standard Transformer encoder-decoder architecture.
    """
    def __init__(self, config):
        super(Transformer_EncoderDecoder, self).__init__()
        c = copy.deepcopy
        self.attn = MultiHeadedAttention(config['head'], config['emb_dim'])
        self.ff = PositionwiseFeedForward(config['emb_dim'], config['d_ff'],
                                          config['drop_out'])
        self.position = PositionalEncoding(config['emb_dim'],
                                           config['drop_out'])
        self.encoder = Encoder(
            EncoderLayer(config['emb_dim'], c(self.attn), c(self.ff),
                         config['drop_out']), config['N_layers'])
        self.decoder = Decoder(
            DecoderLayer(config['emb_dim'], c(self.attn), c(self.attn),
                         c(self.ff), config['drop_out']), config['N_layers'])
        self.src_embed = nn.Sequential(
            Embeddings(config['emb_dim'], config['vocab_size']),
            c(self.position))
        self.tgt_embed = nn.Sequential(
            Embeddings(config['emb_dim'], config['vocab_size']),
            c(self.position))
        self.generator = Generator(config['emb_dim'], config['vocab_size'])
        self.fc_out = nn.Linear(config['emb_dim'], config['vocab_size'])

        self.model = EncoderDecoder(self.encoder, self.decoder, self.src_embed,
                                    self.tgt_embed, self.generator)
        self.maxlen = config['maxlen']  # max decode length ('maxlen' key assumed; used by sample())
        # Xavier-initialize the parameters once at construction time (doing this
        # inside forward() would re-initialize the weights on every call and
        # prevent the model from learning).
        for p in self.model.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    # forward() runs the model's encode() for the encoder pass and then decode()
    # for the decoder pass.
    def forward(self, src, src_lens, tgt, tar_lens, pad=0):
        "Take in and process masked src and target sequences."
        # src and tgt both have shape [batch_size, max_length]

        # Build the source mask.
        self.src = src
        self.src_mask = (src != pad).unsqueeze(-2)
        if tgt is not None:
            # self.tgt drops the last token of each row  => decoder input (step t-1)
            self.tgt = tgt[:, :-1]
            # self.tgt_y drops the first token of each row => prediction target (step t);
            # decoding uses the encoder output and step t-1 to predict step t
            self.tgt_y = tgt[:, 1:]
            self.tgt_mask = self.make_std_mask(self.tgt, pad)
            self.ntokens = (self.tgt_y != pad).sum()

        # Model workflow: encode the source, then decode conditioned on it.
        output = self.decode(self.encode(src, self.src_mask), self.src_mask,
                             self.tgt, self.tgt_mask)

        output = self.fc_out(output)  # [batch_size, tgt_len - 1, vocab_size]

        dec_tgt = tgt[:, 1:].clone()  # [batch_size, tgt_len - 1]
        # CrossEntropyLoss ignores targets equal to -100 (its default ignore_index),
        # so pad positions are excluded from the loss.
        dec_tgt[tgt[:, 1:] == PAD_ID] = -100

        loss = nn.CrossEntropyLoss()(output.reshape(-1, output.size(-1)),
                                     dec_tgt.reshape(-1))
        return loss

    def init_weights(self, m):  # uniform GAN-style init for Linear layers (for use with Module.apply)
        if isinstance(m, nn.Linear):
            m.weight.data.uniform_(-0.08, 0.08)  # alternative: nn.init.xavier_normal_(m.weight)
            nn.init.constant_(m.bias, 0.)

    def valid(self, src_seqs, src_lens, target, tar_lens):
        self.eval()
        with torch.no_grad():
            loss = self.forward(src_seqs, src_lens, target, tar_lens)
        return {'valid_loss': loss.item()}

    def sample(self, src_seqs, src_lens, n_samples, decode_mode='beamsearch'):
        # Greedy autoregressive decoding sketch. NOTE: this assumes a
        # start-of-sequence id SOS_ID is defined alongside PAD_ID; n_samples and
        # decode_mode are ignored here (no beam search is implemented for the
        # Transformer decoder).
        self.eval()
        with torch.no_grad():
            src_mask = (src_seqs != PAD_ID).unsqueeze(-2)
            memory = self.encode(src_seqs, src_mask)
            ys = torch.full((src_seqs.size(0), 1), SOS_ID,
                            dtype=torch.long, device=src_seqs.device)
            for _ in range(self.maxlen - 1):
                tgt_mask = self.make_std_mask(ys, PAD_ID)
                out = self.decode(memory, src_mask, ys, tgt_mask)
                next_word = self.fc_out(out[:, -1]).argmax(dim=-1, keepdim=True)
                ys = torch.cat([ys, next_word], dim=1)
            sample_lens = ys.ne(PAD_ID).sum(dim=1)
        return ys, sample_lens

    def adjust_lr(self):
        #self.lr_scheduler_AE.step()
        return None

    def encode(self, src, src_mask):
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)

    @staticmethod
    def make_std_mask(tgt, pad):
        "Create a mask so the decoder cannot attend to padding or to future words."
        tgt_mask = (tgt != pad).unsqueeze(-2)
        tgt_mask = tgt_mask & subsequent_mask(tgt.size(-1)).type_as(tgt_mask)
        return tgt_mask
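To make the target masking concrete, the short check below exercises make_std_mask on a toy batch. It assumes PAD_ID is 0 and that the project's subsequent_mask helper returns the usual lower-triangular "no peeking at future positions" mask.

# Toy check of make_std_mask (illustrative; pad id assumed to be 0).
tgt = torch.tensor([[5, 7, 9, 0]])            # one target sequence whose last token is padding
mask = Transformer_EncoderDecoder.make_std_mask(tgt, pad=0)
print(mask.shape)                             # torch.Size([1, 4, 4])
print(mask[0].int())
# Row t allows attention only to positions <= t, and the padded column stays 0:
# [[1, 0, 0, 0],
#  [1, 1, 0, 0],
#  [1, 1, 1, 0],
#  [1, 1, 1, 0]]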