Example 1
    def train_AE(self, global_t, title, context, target, target_lens, sent_name):
        self.train()
        self.optimizer_AE[sent_name].zero_grad()
        # The sentiment of each sentence is predicted by a second classifier,
        # which takes the current m_hidden as input and outputs class scores.
        title_last_hidden, _ = self.layers["seq_encoder"](title)
        context_last_hidden, _ = self.layers["seq_encoder"](context)
        target_hidden, _ = self.layers["seq_encoder"](target[:, 1:], target_lens - 1)

        mu, logsigma, z_post = self.layers["{}_vae".format(sent_name)](target_hidden)
        final_info = torch.cat([title_last_hidden, context_last_hidden, z_post], dim=1)

        output = self.layers["vae_decoder_{}".format(sent_name)](init_hidden=self.layers["init_decoder"](final_info),
                                                                 context=None, inputs=target[:, :-1])
        flattened_output = output.view(-1, self.vocab_size)
        dec_target = target[:, 1:].contiguous().view(-1)

        mask = dec_target.gt(0)  # True for non-pad tokens (pad id is 0)
        masked_target = dec_target.masked_select(mask)  # keep non-pad targets
        output_mask = mask.unsqueeze(1).expand(mask.size(0), self.vocab_size)  # [(batch_sz * seq_len) x n_tokens]
        masked_output = flattened_output.masked_select(output_mask).view(-1, self.vocab_size)
        self.rc_loss = self.criterion_ce(masked_output / self.temp, masked_target)
        kld = gaussian_kld(mu, logsigma)
        self.avg_kld = torch.mean(kld)
        self.kl_weights = min(global_t / self.full_kl_step, 1.0)  # KL annealing
        self.kl_loss = self.kl_weights * self.avg_kld

        self.bow_logits = self.layers["bow_project_{}".format(sent_name)](z_post)
        labels = target[:, 1:]
        label_mask = torch.sign(labels).detach().float()
        bow_loss = -F.log_softmax(self.bow_logits, dim=1).gather(1, labels) * label_mask
        sum_bow_loss = torch.sum(bow_loss, 1)
        self.avg_bow_loss = torch.mean(sum_bow_loss)

        self.aug_elbo_loss = self.avg_bow_loss + self.kl_loss + self.rc_loss

        self.aug_elbo_loss.backward()

        self.optimizer_AE[sent_name].step()

        avg_aug_elbo_loss = self.aug_elbo_loss.item()
        avg_rc_loss = self.rc_loss.item()
        avg_kl_loss = self.kl_loss.item()
        avg_bow_loss = self.avg_bow_loss.item()
        global_t += 1

        return [('avg_aug_elbo_loss', avg_aug_elbo_loss),
                ('avg_kl_loss', avg_kl_loss),
                ('avg_rc_loss', avg_rc_loss),
                ('avg_bow_loss', avg_bow_loss),
                ('kl_weight', self.kl_weights)], global_t
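
Note that gaussian_kld is not defined in this snippet. For the two-argument call above, a minimal sketch, assuming logsigma actually stores log-variances (the usual VAE convention), is the closed-form KL between the diagonal Gaussian posterior and a standard normal prior:

import torch

def gaussian_kld(mu, logvar):
    # KL( N(mu, diag(exp(logvar))) || N(0, I) ), summed over the latent
    # dimension; the caller averages the result over the batch.
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
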
Example 2
    def valid_VAE(self, global_t, target, target_lens, sent_name):
        self.eval()

        target_hidden, _ = self.layers["seq_encoder"](target[:, 1:], target_lens - 1)
        mu, logsigma, z_post = self.layers["{}_vae".format(sent_name)](target_hidden)
        output = self.layers["vae_decoder_{}".format(sent_name)](init_hidden=self.layers["init_decoder_hidden"](z_post),
                                                                 context=None, inputs=target[:, :-1])

        flattened_output = output.view(-1, self.vocab_size)
        dec_target = target[:, 1:].contiguous().view(-1)

        mask = dec_target.gt(0)  # True for non-pad tokens (pad id is 0)
        masked_target = dec_target.masked_select(mask)  # keep non-pad targets
        output_mask = mask.unsqueeze(1).expand(mask.size(0), self.vocab_size)  # [(batch_sz * seq_len) x n_tokens]
        masked_output = flattened_output.masked_select(output_mask).view(-1, self.vocab_size)
        self.rc_loss = self.criterion_ce(masked_output / self.temp, masked_target)

        kld = gaussian_kld(mu, logsigma)
        self.avg_kld = torch.mean(kld)
        self.kl_weights = min(global_t / self.full_kl_step, 1.0)  # KL annealing
        self.kl_loss = self.kl_weights * self.avg_kld

        self.bow_logits = self.layers["bow_project_{}".format(sent_name)](z_post)
        labels = target[:, 1:]
        label_mask = torch.sign(labels).detach().float()
        bow_loss = -F.log_softmax(self.bow_logits, dim=1).gather(1, labels) * label_mask
        sum_bow_loss = torch.sum(bow_loss, 1)
        self.avg_bow_loss = torch.mean(sum_bow_loss)

        self.aug_elbo_loss = self.avg_bow_loss + self.kl_loss + self.rc_loss

        avg_aug_elbo_loss = self.aug_elbo_loss.item()
        avg_rc_loss = self.rc_loss.item()
        avg_kl_loss = self.kl_loss.item()
        avg_bow_loss = self.avg_bow_loss.item()

        return [('val_aug_elbo_loss', avg_aug_elbo_loss),
                ('val_kl_loss', avg_kl_loss),
                ('val_rc_loss', avg_rc_loss),
                ('val_bow_loss', avg_bow_loss),
                ('kl_weight', self.kl_weights)], global_t
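
Aside: if criterion_ce is a mean-reduced nn.CrossEntropyLoss (an assumption; it is not shown in these snippets), the mask/masked_select block can be collapsed into a single call with ignore_index, avoiding the [(batch * seq_len) x vocab] boolean mask entirely. A self-contained sketch with toy shapes:

import torch
import torch.nn.functional as F

# Toy shapes for illustration; the pad id is assumed to be 0.
vocab_size, temp = 10, 1.0
flattened_output = torch.randn(6, vocab_size)   # [(batch * seq_len), vocab]
dec_target = torch.tensor([4, 2, 0, 7, 0, 1])   # 0 marks padding

# ignore_index drops pad positions from the mean, like the manual masking.
rc_loss = F.cross_entropy(flattened_output / temp, dec_target, ignore_index=0)
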
Example 3
    def valid_AE(self,
                 global_t,
                 title,
                 context,
                 target,
                 target_lens,
                 sentiment_mask=None,
                 sentiment_lead=None):
        self.seq_encoder.eval()
        self.decoder.eval()

        title_last_hidden, _ = self.seq_encoder(title)
        context_last_hidden, _ = self.seq_encoder(context)

        x, _ = self.seq_encoder(target[:, 1:], target_lens - 1)
        condition_prior = torch.cat((title_last_hidden, context_last_hidden),
                                    dim=1)

        z_prior, prior_mu, prior_logvar, pi, pi_final = self.sample_code_prior(
            condition_prior, sentiment_mask=sentiment_mask)
        z_post, post_mu, post_logvar = self.sample_code_post(
            x, condition_prior)
        if sentiment_lead is not None:
            self.sent_lead_loss = self.criterion_sent_lead(
                input=pi, target=sentiment_lead)
        else:
            self.sent_lead_loss = 0

        final_info = torch.cat((z_post, condition_prior), dim=1)

        output = self.decoder(init_hidden=self.init_decoder_hidden(final_info),
                              context=None,
                              inputs=target[:, :-1])
        flattened_output = output.view(-1, self.vocab_size)
        dec_target = target[:, 1:].contiguous().view(-1)

        mask = dec_target.gt(0)  # True for non-pad tokens (pad id is 0)
        masked_target = dec_target.masked_select(mask)  # keep non-pad targets
        output_mask = mask.unsqueeze(1).expand(
            mask.size(0), self.vocab_size)  # [(batch_sz * seq_len) x n_tokens]
        masked_output = flattened_output.masked_select(output_mask).view(
            -1, self.vocab_size)
        self.rc_loss = self.criterion_ce(masked_output / self.temp,
                                         masked_target)

        # KL divergence between posterior and prior
        kld = gaussian_kld(post_mu, post_logvar, prior_mu, prior_logvar)
        self.avg_kld = torch.mean(kld)
        self.kl_weights = min(global_t / self.full_kl_step, 1.0)  # KL annealing
        self.kl_loss = self.kl_weights * self.avg_kld
        # avg_bow_loss
        self.bow_logits = self.bow_project(final_info)
        # i.e., sum the prediction loss over every target token
        labels = target[:, 1:]
        label_mask = torch.sign(labels).detach().float()

        bow_loss = -F.log_softmax(self.bow_logits, dim=1).gather(
            1, labels) * label_mask
        sum_bow_loss = torch.sum(bow_loss, 1)
        self.avg_bow_loss = torch.mean(sum_bow_loss)
        self.aug_elbo_loss = self.avg_bow_loss + self.kl_loss + self.rc_loss

        avg_aug_elbo_loss = self.aug_elbo_loss.item()
        avg_kl_loss = self.kl_loss.item()
        avg_rc_loss = self.rc_loss.item()
        avg_bow_loss = self.avg_bow_loss.item()
        avg_lead_loss = 0 if sentiment_lead is None else self.sent_lead_loss.item()

        return [('valid_lead_loss', avg_lead_loss),
                ('valid_aug_elbo_loss', avg_aug_elbo_loss),
                ('valid_kl_loss', avg_kl_loss), ('valid_rc_loss', avg_rc_loss),
                ('valid_bow_loss', avg_bow_loss)], global_t
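
Here gaussian_kld takes four arguments: the KL between the recognition posterior and a learned conditional prior, rather than a fixed standard normal. A minimal sketch, under the same log-variance assumption as above:

import torch

def gaussian_kld(post_mu, post_logvar, prior_mu, prior_logvar):
    # Closed-form KL( N(post_mu, e^post_logvar) || N(prior_mu, e^prior_logvar) )
    # for diagonal Gaussians, summed over the latent dimension.
    return -0.5 * torch.sum(
        1 + (post_logvar - prior_logvar)
        - (post_mu - prior_mu).pow(2) / prior_logvar.exp()
        - post_logvar.exp() / prior_logvar.exp(),
        dim=1)
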
Example 4
    def train_AE(self,
                 global_t,
                 title,
                 context,
                 target,
                 target_lens,
                 sentiment_mask=None,
                 sentiment_lead=None):
        self.seq_encoder.train()
        self.decoder.train()
        # The sentiment of each sentence is predicted by a second classifier,
        # which takes the current m_hidden as input and outputs class scores.

        title_last_hidden, _ = self.seq_encoder(title)
        context_last_hidden, _ = self.seq_encoder(context)

        x, _ = self.seq_encoder(target[:, 1:], target_lens - 1)
        condition_prior = torch.cat((title_last_hidden, context_last_hidden),
                                    dim=1)

        z_prior, prior_mu, prior_logvar, pi, pi_final = self.sample_code_prior(
            condition_prior, sentiment_mask=sentiment_mask)
        z_post, post_mu, post_logvar = self.sample_code_post(
            x, condition_prior)
        if sentiment_lead is not None:
            self.sent_lead_loss = self.criterion_sent_lead(
                input=pi, target=sentiment_lead)
        else:
            self.sent_lead_loss = 0

        final_info = torch.cat((z_post, condition_prior), dim=1)
        # reconstruction loss
        output = self.decoder(init_hidden=self.init_decoder_hidden(final_info),
                              context=None,
                              inputs=target[:, :-1])
        flattened_output = output.view(-1, self.vocab_size)
        # flattened_output = self.softmax(flattened_output) + 1e-10
        # flattened_output = torch.log(flattened_output)
        dec_target = target[:, 1:].contiguous().view(-1)

        mask = dec_target.gt(0)  # True for non-pad tokens (pad id is 0)
        masked_target = dec_target.masked_select(mask)  # keep non-pad targets
        output_mask = mask.unsqueeze(1).expand(
            mask.size(0), self.vocab_size)  # [(batch_sz * seq_len) x n_tokens]
        masked_output = flattened_output.masked_select(output_mask).view(
            -1, self.vocab_size)
        self.rc_loss = self.criterion_ce(masked_output / self.temp,
                                         masked_target)

        # KL divergence between posterior and prior
        kld = gaussian_kld(post_mu, post_logvar, prior_mu, prior_logvar)
        self.avg_kld = torch.mean(kld)
        self.kl_weights = min(global_t / self.full_kl_step, 1.0)  # KL annealing
        self.kl_loss = self.kl_weights * self.avg_kld

        # avg_bow_loss
        self.bow_logits = self.bow_project(final_info)
        # i.e., sum the prediction loss over every target token
        labels = target[:, 1:]
        label_mask = torch.sign(labels).detach().float()
        # the negative sign turns the log-likelihood positive so it can be minimized
        bow_loss = -F.log_softmax(self.bow_logits, dim=1).gather(
            1, labels) * label_mask
        sum_bow_loss = torch.sum(bow_loss, 1)
        self.avg_bow_loss = torch.mean(sum_bow_loss)
        self.aug_elbo_loss = self.avg_bow_loss + self.kl_loss + self.rc_loss
        self.total_loss = self.aug_elbo_loss + self.sent_lead_loss

        # effectively raises the learning rate on the labeled (sentiment-masked) batches
        if sentiment_mask is not None:
            self.total_loss = self.total_loss * 13.33

        self.optimizer_AE.zero_grad()
        self.total_loss.backward()
        self.optimizer_AE.step()

        avg_total_loss = self.total_loss.item()
        avg_lead_loss = 0 if sentiment_lead is None else self.sent_lead_loss.item()
        avg_aug_elbo_loss = self.aug_elbo_loss.item()
        avg_kl_loss = self.kl_loss.item()
        avg_rc_loss = self.rc_loss.item()
        avg_bow_loss = self.avg_bow_loss.item()
        global_t += 1

        return [('avg_total_loss', avg_total_loss),
                ('avg_lead_loss', avg_lead_loss),
                ('avg_aug_elbo_loss', avg_aug_elbo_loss),
                ('avg_kl_loss', avg_kl_loss), ('avg_rc_loss', avg_rc_loss),
                ('avg_bow_loss', avg_bow_loss),
                ('kl_weight', self.kl_weights)], global_t
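
The bag-of-words term deserves a gloss: a single bow_logits vector per example must assign probability mass to every non-pad target token, regardless of order, which pressures the latent code (via final_info) to encode the content of the whole target. A self-contained sketch of the same computation with toy shapes (pad id assumed to be 0):

import torch
import torch.nn.functional as F

bow_logits = torch.randn(2, 10)                  # [batch, vocab]
labels = torch.tensor([[3, 5, 0], [1, 1, 2]])    # [batch, seq_len - 1]
log_probs = F.log_softmax(bow_logits, dim=1)     # one distribution per example
nll = -log_probs.gather(1, labels)               # NLL of each target token
mask = labels.ne(0).float()                      # zero out pad positions
avg_bow_loss = (nll * mask).sum(1).mean()        # sum per example, batch mean
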