Example #1
    def decoder_greedy(self, batch, max_dec_step=50):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)
        
        ## Encode
        num_sentences, enc_seq_len = enc_batch.size()
        batch_size = enc_lens.size(0)
        max_len = enc_lens.data.max().item()
        input_lengths = torch.sum(~enc_batch.data.eq(config.PAD_idx), dim=1)
        
        # word level encoder
        enc_emb = self.embedding(enc_batch)
        word_encoder_outputs, word_encoder_hidden = self.word_encoder(enc_emb, input_lengths)
        word_encoder_hidden = word_encoder_hidden.transpose(1, 0).reshape(num_sentences, -1)

        # pad and pack word_encoder_hidden
        start = torch.cumsum(torch.cat((enc_lens.data.new(1).zero_(), enc_lens[:-1])), 0)
        word_encoder_hidden = torch.stack([pad(word_encoder_hidden.narrow(0, s, l), max_len)
                                            for s, l in zip(start.data.tolist(), enc_lens.data.tolist())], 0)
        mask_src = ~(enc_padding_mask.bool()).unsqueeze(1)

        # context level encoder
        if word_encoder_hidden.size(-1) != config.hidden_dim:
            word_encoder_hidden = self.linear(word_encoder_hidden)
        encoder_outputs = self.encoder(word_encoder_hidden, mask_src)
        
        ys = torch.ones(batch_size, 1).fill_(config.SOS_idx).long()
        if config.USE_CUDA:
            ys = ys.cuda()
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        
        decoded_words = []
        for i in range(max_dec_step+1):

            out, attn_dist, _, _, _ = self.decoder(self.embedding(ys), encoder_outputs, None, (mask_src, None, mask_trg))

            prob = self.generator(out, attn_dist, enc_batch_extend_vocab, extra_zeros, attn_dist_db=None)
            _, next_word = torch.max(prob[:, -1], dim=1)
            decoded_words.append(['<EOS>' if ni.item() == config.EOS_idx else self.vocab.index2word[ni.item()] for ni in next_word.view(-1)])

            ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)
            if config.USE_CUDA:
                ys = ys.cuda()
            
            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        sent = []
        for row in np.transpose(decoded_words):
            st = ''
            for e in row:
                if e == '<EOS>':
                    break
                st += e + ' '
            sent.append(st)
        return sent
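
All three examples call a pad helper that is not shown on this page. A minimal sketch of what it presumably does, right-padding a [len, hidden] tensor of per-sentence hidden states with zero rows up to max_len; the implementation is an assumption inferred from the call sites above:

import torch

def pad(tensor, length):
    # Hypothetical reconstruction: zero-pad dim 0 of `tensor` ([len, hidden])
    # up to `length` rows so every dialogue stacks to the same shape.
    if tensor.size(0) >= length:
        return tensor
    padding = tensor.new_zeros(length - tensor.size(0), *tensor.size()[1:])
    return torch.cat([tensor, padding], dim=0)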
Example #2
    def train_one_batch(self, batch, iter, train=True):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)
        dec_batch, _, _, _, _ = get_output_from_batch(batch)
        
        if config.noam:
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        ## Response encode
        mask_res = batch["posterior_batch"].data.eq(config.PAD_idx).unsqueeze(1)
        post_emb = self.embedding(batch["posterior_batch"])
        r_encoder_outputs = self.r_encoder(post_emb, mask_res)

        ## Encode
        num_sentences, enc_seq_len = enc_batch.size()
        batch_size = enc_lens.size(0)
        max_len = enc_lens.data.max().item()
        input_lengths = torch.sum(~enc_batch.data.eq(config.PAD_idx), dim=1)
        
        # word level encoder
        enc_emb = self.embedding(enc_batch)
        word_encoder_outputs, word_encoder_hidden = self.word_encoder(enc_emb, input_lengths)
        word_encoder_hidden = word_encoder_hidden.transpose(1, 0).reshape(num_sentences, -1)

        # pad and pack word_encoder_hidden
        start = torch.cumsum(torch.cat((enc_lens.data.new(1).zero_(), enc_lens[:-1])), 0)
        word_encoder_hidden = torch.stack([pad(word_encoder_hidden.narrow(0, s, l), max_len)
                                            for s, l in zip(start.data.tolist(), enc_lens.data.tolist())], 0)
        
        mask_src = ~(enc_padding_mask.bool()).unsqueeze(1)
        
        # context level encoder
        if word_encoder_hidden.size(-1) != config.hidden_dim:
            word_encoder_hidden = self.linear(word_encoder_hidden)
        encoder_outputs = self.encoder(word_encoder_hidden, mask_src)

        # Decode
        sos_token = torch.LongTensor([config.SOS_idx] * batch_size).unsqueeze(1)
        if config.USE_CUDA: sos_token = sos_token.cuda()

        dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1)  # (batch, len)
        mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
        dec_emb = self.embedding(dec_batch_shift)

        pre_logit, attn_dist, mean, log_var, probs = self.decoder(dec_emb, encoder_outputs, r_encoder_outputs, 
                                                                    (mask_src, mask_res, mask_trg))
        
        ## compute output dist
        logit = self.generator(pre_logit, attn_dist, enc_batch_extend_vocab if config.pointer_gen else None, extra_zeros, attn_dist_db=None)
        ## loss: NLL if pointer_gen else cross-entropy
        sbow = dec_batch  # [batch, seq_len]
        seq_len = sbow.size(1)
        loss_rec = self.criterion(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1))
        if config.model == "cvaetrs":
            loss_aux = 0
            # expand targets to [batch, seq_len, seq_len] and PAD-mask the
            # entries excluded by the transposed subsequent mask
            sbow_mask = _get_attn_subsequent_mask(seq_len).transpose(1, 2)
            sbow = sbow.unsqueeze(2).repeat(1, 1, seq_len).masked_fill_(sbow_mask, config.PAD_idx)
            for prob in probs:
                loss_aux += self.criterion(prob.contiguous().view(-1, prob.size(-1)), sbow.contiguous().view(-1))
            kld_loss = gaussian_kld(mean["posterior"], log_var["posterior"], mean["prior"], log_var["prior"])
            kld_loss = torch.mean(kld_loss)
            kl_weight = min(math.tanh(6 * iter / config.full_kl_step - 3) + 1, 1)
            #kl_weight = min(iter/config.full_kl_step, 1) if config.full_kl_step >0 else 1.0
            loss = loss_rec + config.kl_ceiling * kl_weight * kld_loss + config.aux_ceiling * loss_aux
            elbo = loss_rec + kld_loss
        else:
            loss = loss_rec
            elbo = loss_rec
            kld_loss = torch.Tensor([0])
            loss_aux = torch.Tensor([0])
        if train:
            loss.backward()
            # clip gradient
            nn.utils.clip_grad_norm_(self.parameters(), config.max_grad_norm)
            self.optimizer.step()

        return loss_rec.item(), math.exp(min(loss_rec.item(), 100)), kld_loss.item(), loss_aux.item(), elbo.item()
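
Example #2 leans on two helpers defined elsewhere in the repo. gaussian_kld(mean_post, logvar_post, mean_prior, logvar_prior) is presumably the closed-form KL divergence between two diagonal Gaussians; a sketch consistent with the call above (the actual reduction may differ):

import torch

def gaussian_kld(mu_post, logvar_post, mu_prior, logvar_prior):
    # KL( N(mu_post, exp(logvar_post)) || N(mu_prior, exp(logvar_prior)) ),
    # summed over the latent dimension; torch.mean is applied by the caller.
    return 0.5 * torch.sum(
        logvar_prior - logvar_post
        + (logvar_post.exp() + (mu_post - mu_prior).pow(2)) / logvar_prior.exp()
        - 1.0,
        dim=-1)

_get_attn_subsequent_mask(seq_len) is presumably the standard Transformer causal mask; a hypothetical reconstruction matching how it is transposed and used above:

def _get_attn_subsequent_mask(size):
    # [1, size, size] mask that is True at (i, j) for j > i, i.e. future positions.
    return torch.triu(torch.ones(size, size, dtype=torch.bool), diagonal=1).unsqueeze(0)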
Example #3
    def train_one_batch(self, batch, iter, train=True):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)
        dec_batch, _, _, _, _ = get_output_from_batch(batch)

        if config.noam:
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        ## Response encode
        mask_res = batch["posterior_batch"].data.eq(
            config.PAD_idx).unsqueeze(1)
        post_emb = self.embedding(batch["posterior_batch"])
        r_encoder_outputs = self.r_encoder(post_emb, mask_res)

        ## Encode
        num_sentences, enc_seq_len = enc_batch.size()
        batch_size = enc_lens.size(0)
        max_len = enc_lens.data.max().item()
        input_lengths = torch.sum(~enc_batch.data.eq(config.PAD_idx), dim=1)

        # word level encoder
        enc_emb = self.embedding(enc_batch)
        word_encoder_outputs, word_encoder_hidden = self.word_encoder(
            enc_emb, input_lengths)
        word_encoder_hidden = word_encoder_hidden.transpose(1, 0).reshape(
            num_sentences, -1)

        # pad and pack word_encoder_hidden
        start = torch.cumsum(
            torch.cat((enc_lens.data.new(1).zero_(), enc_lens[:-1])), 0)
        word_encoder_hidden = torch.stack([
            pad(word_encoder_hidden.narrow(0, s, l), max_len)
            for s, l in zip(start.data.tolist(), enc_lens.data.tolist())
        ], 0)

        mask_src = ~(enc_padding_mask.bool()).unsqueeze(1)

        # context level encoder
        if word_encoder_hidden.size(-1) != config.hidden_dim:
            word_encoder_hidden = self.linear(word_encoder_hidden)
        encoder_outputs = self.encoder(word_encoder_hidden, mask_src)

        # latent variable
        if config.model == "cvaetrs":
            kld_loss, z = self.latent_layer(encoder_outputs[:, 0],
                                            r_encoder_outputs[:, 0],
                                            train=True)

        # Decode
        sos_token = torch.LongTensor([config.SOS_idx] *
                                     batch_size).unsqueeze(1)
        if config.USE_CUDA: sos_token = sos_token.cuda()

        dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]),
                                    1)  # (batch, len)
        mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
        input_vector = self.embedding(dec_batch_shift)
        if config.model == "cvaetrs":
            # condition the decoder by adding the latent code z to the first input embedding
            input_vector[:, 0] = input_vector[:, 0] + z
        pre_logit, attn_dist = self.decoder(input_vector, encoder_outputs,
                                            (mask_src, mask_trg))

        ## compute output dist
        logit = self.generator(
            pre_logit,
            attn_dist,
            enc_batch_extend_vocab if config.pointer_gen else None,
            extra_zeros,
            attn_dist_db=None)
        ## loss: NLL if pointer_gen else cross-entropy
        loss_rec = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                                  dec_batch.contiguous().view(-1))
        if config.model == "cvaetrs":
            z_logit = self.bow(z)  # [batch_size, vocab_size]
            z_logit = z_logit.unsqueeze(1).repeat(1, logit.size(1), 1)
            loss_aux = self.criterion(
                z_logit.contiguous().view(-1, z_logit.size(-1)),
                dec_batch.contiguous().view(-1))

            #kl_weight = min(iter/config.full_kl_step, 0.28) if config.full_kl_step >0 else 1.0
            kl_weight = min(
                math.tanh(6 * iter / config.full_kl_step - 3) + 1, 1)
            loss = loss_rec + config.kl_ceiling * kl_weight * kld_loss + config.aux_ceiling * loss_aux

            aux = loss_aux.item()
            elbo = loss_rec + kld_loss
        else:
            loss = loss_rec
            elbo = loss_rec
            kld_loss = torch.Tensor([0])
            aux = 0
            if config.multitask:
                emo_logit = self.emo(encoder_outputs[:, 0])
                emo_loss = self.emo_criterion(emo_logit,
                                              batch["program_label"] - 9)
                loss = loss_rec + emo_loss
        if train:
            loss.backward()
            # clip gradient
            nn.utils.clip_grad_norm_(self.parameters(), config.max_grad_norm)
            self.optimizer.step()

        return loss_rec.item(), math.exp(min(
            loss_rec.item(), 100)), kld_loss.item(), aux, elbo.item()
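
Both train_one_batch variants anneal the KL term with the same tanh schedule: kl_weight starts near 0, reaches 1 at iter == full_kl_step / 2, and is clamped at 1 from then on. A small standalone snippet to inspect the schedule; the step values and full_kl_step=10000 are illustrative:

import math

def kl_anneal_weight(step, full_kl_step):
    # tanh ramp used above: ~0.005 at step 0, exactly 1 from full_kl_step/2 onward
    return min(math.tanh(6 * step / full_kl_step - 3) + 1, 1)

for step in (0, 2500, 5000, 7500, 10000):
    print(step, round(kl_anneal_weight(step, 10000), 4))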