Code example #1
File: poemwae.py  Project: iAlexKai/headHider
class PoemWAE(nn.Module):
    def __init__(self, config, api, PAD_token=0, pretrain_weight=None):
        super(PoemWAE, self).__init__()
        self.vocab = api.vocab
        self.vocab_size = len(self.vocab)
        self.rev_vocab = api.rev_vocab
        self.go_id = self.rev_vocab["<s>"]
        self.eos_id = self.rev_vocab["</s>"]
        self.maxlen = config.maxlen
        self.clip = config.clip
        self.lambda_gp = config.lambda_gp
        self.lr_gan_g = config.lr_gan_g
        self.lr_gan_d = config.lr_gan_d
        self.n_d_loss = config.n_d_loss
        self.temp = config.temp
        self.init_w = config.init_weight

        self.embedder = nn.Embedding(self.vocab_size,
                                     config.emb_size,
                                     padding_idx=PAD_token)
        if pretrain_weight is not None:
            self.embedder.weight.data.copy_(torch.from_numpy(pretrain_weight))
        # The same seq_encoder is used to encode the title and the surrounding lines
        self.seq_encoder = Encoder(self.embedder, config.emb_size,
                                   config.n_hidden, True, config.n_layers,
                                   config.noise_radius)
        # For the poem task the context is the direct cat of the bi-GRU encodings of the title and the last sentence, hence 4*hidden
        # Note: when Poemwar_gmp is used, the prior_net of that subclass (a Gaussian-mixture prior) is used instead
        self.prior_net = Variation(config.n_hidden * 4,
                                   config.z_size,
                                   dropout_rate=config.dropout,
                                   init_weight=self.init_w)  # p(e|c)

        # Note: this was originally 3*hidden for the dialog task
        # On the poem dataset the title, the previous line and x are each encoded with a bi-GRU and concatenated, hence 6*hidden
        self.post_net = Variation(config.n_hidden * 6,
                                  config.z_size,
                                  dropout_rate=config.dropout,
                                  init_weight=self.init_w)

        self.post_generator = nn.Sequential(
            nn.Linear(config.z_size, config.z_size),
            nn.BatchNorm1d(config.z_size, eps=1e-05, momentum=0.1), nn.ReLU(),
            nn.Linear(config.z_size, config.z_size),
            nn.BatchNorm1d(config.z_size, eps=1e-05, momentum=0.1), nn.ReLU(),
            nn.Linear(config.z_size, config.z_size))
        self.post_generator.apply(self.init_weights)

        self.prior_generator = nn.Sequential(
            nn.Linear(config.z_size, config.z_size),
            nn.BatchNorm1d(config.z_size, eps=1e-05, momentum=0.1), nn.ReLU(),
            nn.Linear(config.z_size, config.z_size),
            nn.BatchNorm1d(config.z_size, eps=1e-05, momentum=0.1), nn.ReLU(),
            nn.Linear(config.z_size, config.z_size))
        self.prior_generator.apply(self.init_weights)

        self.init_decoder_hidden = nn.Sequential(
            nn.Linear(config.n_hidden * 4 + config.z_size,
                      config.n_hidden * 4),
            nn.BatchNorm1d(config.n_hidden * 4, eps=1e-05, momentum=0.1),
            nn.ReLU())

        # For the poem task the context is the direct cat of the bi-GRU encodings of the title and the last sentence, so the hidden size becomes z_size + 4*hidden
        # Change: keep the decoder hidden_size at n_hidden and let init_hidden map the concatenation to n_hidden with an MLP
        self.decoder = Decoder(self.embedder,
                               config.emb_size,
                               config.n_hidden * 4,
                               self.vocab_size,
                               n_layers=1)

        self.discriminator = nn.Sequential(
            # Because the poem context cats two bidirectional encodings, this becomes 4*n_hidden + z_size
            nn.Linear(config.n_hidden * 4 + config.z_size,
                      config.n_hidden * 2),
            nn.BatchNorm1d(config.n_hidden * 2, eps=1e-05, momentum=0.1),
            nn.LeakyReLU(0.2),
            nn.Linear(config.n_hidden * 2, config.n_hidden * 2),
            nn.BatchNorm1d(config.n_hidden * 2, eps=1e-05, momentum=0.1),
            nn.LeakyReLU(0.2),
            nn.Linear(config.n_hidden * 2, 1),
        )
        self.discriminator.apply(self.init_weights)

        # Optimizer definitions, one per training module. Note: the three modules use different optimizers.
        # self.optimizer_AE = optim.SGD(list(self.seq_encoder.parameters())
        self.optimizer_AE = optim.SGD(
            list(self.seq_encoder.parameters()) +
            list(self.post_net.parameters()) +
            list(self.post_generator.parameters()) +
            list(self.init_decoder_hidden.parameters()) +
            list(self.decoder.parameters()),
            lr=config.lr_ae)
        self.optimizer_G = optim.RMSprop(
            list(self.post_net.parameters()) +
            list(self.post_generator.parameters()) +
            list(self.prior_net.parameters()) +
            list(self.prior_generator.parameters()),
            lr=self.lr_gan_g)
        self.optimizer_D = optim.RMSprop(self.discriminator.parameters(),
                                         lr=self.lr_gan_d)

        self.lr_scheduler_AE = optim.lr_scheduler.StepLR(self.optimizer_AE,
                                                         step_size=10,
                                                         gamma=0.8)

        self.criterion_ce = nn.CrossEntropyLoss()

    def init_weights(self, m):
        if isinstance(m, nn.Linear):
            m.weight.data.uniform_(-self.init_w, self.init_w)
            # nn.init.kaiming_normal_(m.weight.data)
            # nn.init.kaiming_uniform_(m.weight.data)
            m.bias.data.fill_(0)

    # x: (batch, 2*n_hidden)
    # c: (batch, 2*2*n_hidden)
    def sample_code_post(self, x, c):
        z, _, _ = self.post_net(torch.cat((x, c),
                                          1))  # input: (batch, 3*2*n_hidden)
        z = self.post_generator(z)
        return z

    def sample_code_prior_sentiment(self, c, align):
        choice_statistic = self.prior_net(c, align)  # e: (batch, z_size)
        return choice_statistic

    def sample_code_prior(self, c):
        z, _, _ = self.prior_net(c)  # e: (batch, z_size)
        z = self.prior_generator(z)  # z: (batch, z_size)
        return z

    # Inputs: title, context, target, target_lens.
    # c is the concatenation of the encoded hidden states of title and context.
    def train_AE(self, title, context, target, target_lens):
        self.seq_encoder.train()
        self.decoder.train()
        # import pdb
        # pdb.set_trace()
        # (batch, 2 * hidden_size)
        title_last_hidden, _ = self.seq_encoder(title)
        context_last_hidden, _ = self.seq_encoder(context)

        # (batch, 2 * hidden_size)
        x, _ = self.seq_encoder(target[:, 1:], target_lens - 1)
        # context_embedding
        c = torch.cat((title_last_hidden, context_last_hidden),
                      1)  # (batch, 2 * hidden_size * 2)
        z = self.sample_code_post(x, c)  # (batch, z_size)

        # Standard autoencoder decoding: the decoder is initialized from the cat of z and c, and the target is fed in shifted by one position (teacher forcing)
        # output: (batch, len, vocab_size); len is 9, i.e. 7 characters + punctuation + </s>

        output = self.decoder(self.init_decoder_hidden(torch.cat((z, c), 1)),
                              None, target[:, :-1], target_lens - 1)
        flattened_output = output.view(-1, self.vocab_size)

        dec_target = target[:, 1:].contiguous().view(-1)
        mask = dec_target.gt(0)  # True where the target token is non-zero, i.e. not PAD
        masked_target = dec_target.masked_select(mask)  # keep only the non-pad targets
        output_mask = mask.unsqueeze(1).expand(
            mask.size(0), self.vocab_size)  # [(batch_sz * seq_len) x n_tokens]
        masked_output = flattened_output.masked_select(output_mask).view(
            -1, self.vocab_size)

        self.optimizer_AE.zero_grad()
        loss = self.criterion_ce(masked_output / self.temp, masked_target)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(
            list(self.seq_encoder.parameters()) +
            list(self.decoder.parameters()), self.clip)
        self.optimizer_AE.step()

        return [('train_loss_AE', loss.item())]

    # G shortens the Wasserstein distance, analogous to minimizing the KL term in a VAE
    def train_G(self,
                title,
                context,
                target,
                target_lens,
                sentiment_mask=None,
                mask_type=None):
        self.seq_encoder.eval()
        self.optimizer_G.zero_grad()

        for p in self.discriminator.parameters():
            p.requires_grad = False
        title_last_hidden, _ = self.seq_encoder(title)
        context_last_hidden, _ = self.seq_encoder(context)
        c = torch.cat((title_last_hidden, context_last_hidden),
                      1)  # (batch, 2 * hidden_size * 2)

        # -----------------posterior samples ---------------------------
        x, _ = self.seq_encoder(target[:, 1:], target_lens - 1)
        z_post = self.sample_code_post(
            x.detach(), c.detach())  # detach to keep gradients from flowing back into the encoder; (batch, z_size)

        errG_post = torch.mean(
            self.discriminator(torch.cat(
                (z_post, c.detach()),
                1))) * self.n_d_loss  # (batch, z_size + 4 * hidden)
        errG_post.backward(minus_one)

        # ----------------- prior samples ---------------------------
        prior_z = self.sample_code_prior(c.detach())
        errG_prior = torch.mean(
            self.discriminator(torch.cat(
                (prior_z, c.detach()), 1))) * self.n_d_loss
        # import pdb
        # pdb.set_trace()
        errG_prior.backward(one)
        self.optimizer_G.step()

        for p in self.discriminator.parameters():
            p.requires_grad = True

        costG = errG_prior - errG_post

        return [('train_loss_G', costG.item())]

    # D approximates the Wasserstein distance; a falling loss means a better fit, and increasing the gradient penalty can improve that fit to some extent
    # The larger n_iters_n is, the more often D is trained and the better the approximation becomes
    def train_D(self, title, context, target, target_lens):
        self.seq_encoder.eval()
        self.discriminator.train()
        self.optimizer_D.zero_grad()

        batch_size = context.size(0)

        title_last_hidden, _ = self.seq_encoder(title)
        context_last_hidden, _ = self.seq_encoder(context)
        c = torch.cat((title_last_hidden, context_last_hidden),
                      1)  # (batch, 2 * hidden_size * 2)
        x, _ = self.seq_encoder(target[:, 1:], target_lens - 1)
        post_z = self.sample_code_post(x, c)
        errD_post = torch.mean(
            self.discriminator(torch.cat(
                (post_z.detach(), c.detach()), 1))) * self.n_d_loss
        errD_post.backward(one)

        prior_z = self.sample_code_prior(c)
        errD_prior = torch.mean(
            self.discriminator(torch.cat(
                (prior_z.detach(), c.detach()), 1))) * self.n_d_loss
        errD_prior.backward(minus_one)
        # import pdb
        # pdb.set_trace()

        alpha = to_tensor(torch.rand(batch_size, 1))
        alpha = alpha.expand(prior_z.size())
        interpolates = alpha * prior_z.data + ((1 - alpha) * post_z.data)
        interpolates = Variable(interpolates, requires_grad=True)

        d_input = torch.cat((interpolates, c.detach()), 1)
        disc_interpolates = torch.mean(self.discriminator(d_input))
        gradients = torch.autograd.grad(
            outputs=disc_interpolates,
            inputs=interpolates,
            grad_outputs=to_tensor(torch.ones(disc_interpolates.size())),
            create_graph=True,
            retain_graph=True,
            only_inputs=True)[0]
        gradient_penalty = (
            (gradients.contiguous().view(gradients.size(0), -1).norm(2, dim=1)
             - 1)**2).mean() * self.lambda_gp
        gradient_penalty.backward()

        self.optimizer_D.step()
        costD = -(errD_prior - errD_post) + gradient_penalty
        return [('train_loss_D', costD.item())]

    def valid(self, title, context, target, target_lens, sentiment_mask=None):
        self.seq_encoder.eval()
        self.discriminator.eval()
        self.decoder.eval()

        title_last_hidden, _ = self.seq_encoder(title)
        context_last_hidden, _ = self.seq_encoder(context)
        c = torch.cat((title_last_hidden, context_last_hidden),
                      1)  # (batch, 2 * hidden_size * 2)
        x, _ = self.seq_encoder(target[:, 1:], target_lens - 1)

        post_z = self.sample_code_post(x, c)
        prior_z = self.sample_code_prior(c)
        errD_post = torch.mean(self.discriminator(torch.cat((post_z, c), 1)))
        errD_prior = torch.mean(self.discriminator(torch.cat((prior_z, c), 1)))
        costD = -(errD_prior - errD_post)
        costG = -costD

        dec_target = target[:, 1:].contiguous().view(-1)  # (batch_size * len)
        mask = dec_target.gt(0)  # True where the target token is non-zero, i.e. not PAD
        masked_target = dec_target.masked_select(mask)  # keep only the non-pad targets
        output_mask = mask.unsqueeze(1).expand(mask.size(0), self.vocab_size)

        output = self.decoder(
            self.init_decoder_hidden(torch.cat((post_z, c), 1)), None,
            target[:, :-1], (target_lens - 1))
        flattened_output = output.view(-1, self.vocab_size)
        masked_output = flattened_output.masked_select(output_mask).view(
            -1, self.vocab_size)
        lossAE = self.criterion_ce(masked_output / self.temp, masked_target)
        return [('valid_loss_AE', lossAE.item()),
                ('valid_loss_G', costG.item()), ('valid_loss_D', costD.item())]

    # As the paper describes, at test time noise is drawn from the prior network and G produces prior_z (sample_code_prior(c) in this code)
    # The decoder then takes the cat of prior_z and c as input and decodes the poem line (slightly different from the paper, which feeds only prior_z to the decoder)
    # batch_size is 1: one line is generated at a time

    # title: the poem title
    # context: the previous line
    def test(self, title_tensor, title_words, headers):
        self.seq_encoder.eval()
        self.discriminator.eval()
        self.decoder.eval()
        # tem is initialized to [2, 3, 0, 0, 0, 0, 0, 0, 0]

        tem = [[2, 3] + [0] * (self.maxlen - 2)]
        pred_poems = []

        title_tokens = [
            self.vocab[e] for e in title_words[0].tolist()
            if e not in [0, self.eos_id, self.go_id]
        ]
        pred_poems.append(title_tokens)
        for sent_id in range(4):
            tem = to_tensor(np.array(tem))
            context = tem

            # vec_context = np.zeros((batch_size, self.maxlen), dtype=np.int64)
            # for b_id in range(batch_size):
            #     vec_context[b_id, :] = np.array(context[b_id])
            # context = to_tensor(vec_context)

            title_last_hidden, _ = self.seq_encoder(
                title_tensor)  # (batch=1, 2*hidden)
            if sent_id == 0:
                context_last_hidden, _ = self.seq_encoder(
                    title_tensor)  # (batch=1, 2*hidden)
            else:
                context_last_hidden, _ = self.seq_encoder(
                    context)  # (batch=1, 2*hidden)
            c = torch.cat((title_last_hidden, context_last_hidden),
                          1)  # (batch, 4*hidden_size)
            # Only one poem at a time (batch_size = 1), so no repeat is needed
            prior_z = self.sample_code_prior(c)

            # decode_words is a complete poem line
            decode_words = self.decoder.testing(
                init_hidden=self.init_decoder_hidden(torch.cat((prior_z, c),
                                                               1)),
                maxlen=self.maxlen,
                go_id=self.go_id,
                mode="greedy",
                header=headers[sent_id])

            decode_words = decode_words[0].tolist()
            # import pdb
            # pdb.set_trace()
            if len(decode_words) > self.maxlen:
                tem = [decode_words[0:self.maxlen]]
            else:
                tem = [[0] * (self.maxlen - len(decode_words)) + decode_words]

            pred_tokens = [
                self.vocab[e] for e in decode_words[:-1]
                if e != self.eos_id and e != 0
            ]
            pred_poems.append(pred_tokens)

        gen = ''
        for line in pred_poems:
            true_str = " ".join(line)
            gen = gen + true_str + '\n'

        return gen

    def sample(self, title, context, repeat, go_id, end_id):
        self.seq_encoder.eval()
        self.decoder.eval()

        title_last_hidden, _ = self.seq_encoder(title)
        context_last_hidden, _ = self.seq_encoder(context)
        c = torch.cat((title_last_hidden, context_last_hidden),
                      1)  # (batch, 2 * hidden_size * 2)

        c_repeated = c.expand(
            repeat, -1)  # Note: the input batch_size is 1; it is expanded `repeat` times for the later BLEU computation

        prior_z = self.sample_code_prior(
            c_repeated)  # c_repeated: (batch_size=repeat, 4*hidden_size)

        # (batch, max_len, 1)  (batch_size, 1)
        sample_words, sample_lens = self.decoder.sampling(
            self.init_decoder_hidden(torch.cat((prior_z, c_repeated), 1)),
            self.maxlen, go_id, end_id, "greedy")
        return sample_words, sample_lens
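
The comments in train_D above describe the discriminator as fitting the Wasserstein distance, with the gradient penalty enforcing the 1-Lipschitz constraint on points interpolated between prior and posterior latent samples. Below is a minimal, self-contained sketch of that penalty term under the same interpolation scheme; `critic`, `z_prior`, `z_post` and `cond` are hypothetical stand-ins, not names from the project above.

# Sketch of the WGAN-GP gradient penalty computed in train_D, on plain tensors.
import torch
import torch.nn as nn

def gradient_penalty(critic, z_prior, z_post, cond, lambda_gp=10.0):
    batch_size = z_prior.size(0)
    alpha = torch.rand(batch_size, 1, device=z_prior.device).expand_as(z_prior)
    # Random points on the segment between prior and posterior samples.
    interpolates = (alpha * z_prior + (1 - alpha) * z_post).requires_grad_(True)
    d_out = critic(torch.cat((interpolates, cond), 1)).mean()
    grads = torch.autograd.grad(outputs=d_out, inputs=interpolates,
                                grad_outputs=torch.ones_like(d_out),
                                create_graph=True, retain_graph=True,
                                only_inputs=True)[0]
    # Penalize deviation of the gradient norm from 1 (1-Lipschitz constraint).
    return ((grads.contiguous().view(batch_size, -1).norm(2, dim=1) - 1) ** 2).mean() * lambda_gp

if __name__ == "__main__":
    critic = nn.Sequential(nn.Linear(16 + 8, 32), nn.LeakyReLU(0.2), nn.Linear(32, 1))
    z_prior, z_post, cond = torch.randn(4, 16), torch.randn(4, 16), torch.randn(4, 8)
    print(gradient_penalty(critic, z_prior, z_post, cond).item())
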
Code example #2
File: dialogwae.py  Project: lindsey98/DialogWAE
class DialogWAE(nn.Module):
    def __init__(self, config, vocab_size, PAD_token=0):
        super(DialogWAE, self).__init__()
        self.vocab_size = vocab_size
        self.maxlen = config['maxlen']
        self.clip = config['clip']
        self.lambda_gp = config['lambda_gp']
        self.temp = config['temp']

        self.embedder = nn.Embedding(vocab_size,
                                     config['emb_size'],
                                     padding_idx=PAD_token)
        self.utt_encoder = Encoder(self.embedder, config['emb_size'],
                                   config['n_hidden'], True,
                                   config['n_layers'], config['noise_radius'])
        self.context_encoder = ContextEncoder(self.utt_encoder,
                                              config['n_hidden'] * 2 + 2,
                                              config['n_hidden'], 1,
                                              config['noise_radius'])
        self.prior_net = Variation(config['n_hidden'],
                                   config['z_size'])  # p(e|c)
        self.post_net = Variation(config['n_hidden'] * 3,
                                  config['z_size'])  # q(e|c,x)

        self.post_generator = nn.Sequential(
            nn.Linear(config['z_size'], config['z_size']),
            nn.BatchNorm1d(config['z_size'], eps=1e-05,
                           momentum=0.1), nn.ReLU(),
            nn.Linear(config['z_size'], config['z_size']),
            nn.BatchNorm1d(config['z_size'], eps=1e-05, momentum=0.1),
            nn.ReLU(), nn.Linear(config['z_size'], config['z_size']))
        self.post_generator.apply(self.init_weights)

        self.prior_generator = nn.Sequential(
            nn.Linear(config['z_size'], config['z_size']),
            nn.BatchNorm1d(config['z_size'], eps=1e-05,
                           momentum=0.1), nn.ReLU(),
            nn.Linear(config['z_size'], config['z_size']),
            nn.BatchNorm1d(config['z_size'], eps=1e-05, momentum=0.1),
            nn.ReLU(), nn.Linear(config['z_size'], config['z_size']))
        self.prior_generator.apply(self.init_weights)

        self.decoder = Decoder(self.embedder,
                               config['emb_size'],
                               config['n_hidden'] + config['z_size'],
                               vocab_size,
                               n_layers=1)

        self.discriminator = nn.Sequential(
            nn.Linear(config['n_hidden'] + config['z_size'],
                      config['n_hidden'] * 2),
            nn.BatchNorm1d(config['n_hidden'] * 2, eps=1e-05, momentum=0.1),
            nn.LeakyReLU(0.2),
            nn.Linear(config['n_hidden'] * 2, config['n_hidden'] * 2),
            nn.BatchNorm1d(config['n_hidden'] * 2, eps=1e-05, momentum=0.1),
            nn.LeakyReLU(0.2),
            nn.Linear(config['n_hidden'] * 2, 1),
        )
        self.discriminator.apply(self.init_weights)

        self.optimizer_AE = optim.SGD(list(self.context_encoder.parameters()) +
                                      list(self.post_net.parameters()) +
                                      list(self.post_generator.parameters()) +
                                      list(self.decoder.parameters()),
                                      lr=config['lr_ae'])
        self.optimizer_G = optim.RMSprop(
            list(self.post_net.parameters()) +
            list(self.post_generator.parameters()) +
            list(self.prior_net.parameters()) +
            list(self.prior_generator.parameters()),
            lr=config['lr_gan_g'])
        self.optimizer_D = optim.RMSprop(self.discriminator.parameters(),
                                         lr=config['lr_gan_d'])

        self.lr_scheduler_AE = optim.lr_scheduler.StepLR(self.optimizer_AE,
                                                         step_size=10,
                                                         gamma=0.6)

        self.criterion_ce = nn.CrossEntropyLoss()

    def init_weights(self, m):
        if isinstance(m, nn.Linear):
            m.weight.data.uniform_(-0.02, 0.02)
            m.bias.data.fill_(0)

    def sample_code_post(self, x, c):
        e, _, _ = self.post_net(torch.cat((x, c), 1))
        z = self.post_generator(e)
        return z

    def sample_code_prior(self, c):
        e, _, _ = self.prior_net(c)
        z = self.prior_generator(e)
        return z

    def train_AE(self, context, context_lens, utt_lens, floors, response,
                 res_lens):
        self.context_encoder.train()
        self.decoder.train()
        c = self.context_encoder(context, context_lens, utt_lens, floors)
        x, _ = self.utt_encoder(response[:, 1:], res_lens - 1)
        z = self.sample_code_post(x, c)
        output = self.decoder(torch.cat((z, c), 1), None, response[:, :-1],
                              (res_lens - 1))
        flattened_output = output.view(-1, self.vocab_size)

        dec_target = response[:, 1:].contiguous().view(-1)
        mask = dec_target.gt(0)  # [(batch_sz*seq_len)]
        masked_target = dec_target.masked_select(mask)  #
        output_mask = mask.unsqueeze(1).expand(
            mask.size(0), self.vocab_size)  # [(batch_sz*seq_len) x n_tokens]
        masked_output = flattened_output.masked_select(output_mask).view(
            -1, self.vocab_size)

        self.optimizer_AE.zero_grad()
        loss = self.criterion_ce(masked_output / self.temp, masked_target)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(
            list(self.context_encoder.parameters()) +
            list(self.decoder.parameters()), self.clip)
        self.optimizer_AE.step()

        return [('train_loss_AE', loss.item())]

    def train_G(self, context, context_lens, utt_lens, floors, response,
                res_lens):
        self.context_encoder.eval()
        self.optimizer_G.zero_grad()

        for p in self.discriminator.parameters():
            p.requires_grad = False

        c = self.context_encoder(context, context_lens, utt_lens, floors)
        # -----------------posterior samples ---------------------------
        x, _ = self.utt_encoder(response[:, 1:], res_lens - 1)
        z_post = self.sample_code_post(x.detach(), c.detach())
        errG_post = torch.mean(
            self.discriminator(torch.cat((z_post, c.detach()), 1)))
        errG_post.backward(minus_one)

        # ----------------- prior samples ---------------------------
        prior_z = self.sample_code_prior(c.detach())
        errG_prior = torch.mean(
            self.discriminator(torch.cat((prior_z, c.detach()), 1)))
        errG_prior.backward(one)

        self.optimizer_G.step()

        for p in self.discriminator.parameters():
            p.requires_grad = True

        costG = errG_prior - errG_post
        return [('train_loss_G', costG.item())]

    def train_D(self, context, context_lens, utt_lens, floors, response,
                res_lens):
        self.context_encoder.eval()
        self.discriminator.train()

        self.optimizer_D.zero_grad()

        batch_size = context.size(0)

        c = self.context_encoder(context, context_lens, utt_lens, floors)
        x, _ = self.utt_encoder(response[:, 1:], res_lens - 1)
        post_z = self.sample_code_post(x, c)
        errD_post = torch.mean(
            self.discriminator(torch.cat((post_z.detach(), c.detach()), 1)))
        errD_post.backward(one)

        prior_z = self.sample_code_prior(c)
        errD_prior = torch.mean(
            self.discriminator(torch.cat((prior_z.detach(), c.detach()), 1)))
        errD_prior.backward(minus_one)

        alpha = gData(torch.rand(batch_size, 1))
        alpha = alpha.expand(prior_z.size())
        interpolates = alpha * prior_z.data + ((1 - alpha) * post_z.data)
        interpolates = Variable(interpolates, requires_grad=True)
        d_input = torch.cat((interpolates, c.detach()), 1)
        disc_interpolates = torch.mean(self.discriminator(d_input))
        gradients = torch.autograd.grad(
            outputs=disc_interpolates,
            inputs=interpolates,
            grad_outputs=gData(torch.ones(disc_interpolates.size())),
            create_graph=True,
            retain_graph=True,
            only_inputs=True)[0]
        gradient_penalty = (
            (gradients.contiguous().view(gradients.size(0), -1).norm(2, dim=1)
             - 1)**2).mean() * self.lambda_gp
        gradient_penalty.backward()

        self.optimizer_D.step()
        costD = -(errD_prior - errD_post) + gradient_penalty
        return [('train_loss_D', costD.item())]

    def valid(self, context, context_lens, utt_lens, floors, response,
              res_lens):
        self.context_encoder.eval()
        self.discriminator.eval()
        self.decoder.eval()

        c = self.context_encoder(context, context_lens, utt_lens, floors)
        x, _ = self.utt_encoder(response[:, 1:], res_lens - 1)
        post_z = self.sample_code_post(x, c)
        prior_z = self.sample_code_prior(c)
        errD_post = torch.mean(self.discriminator(torch.cat((post_z, c), 1)))
        errD_prior = torch.mean(self.discriminator(torch.cat((prior_z, c), 1)))
        costD = -(errD_prior - errD_post)
        costG = -costD

        dec_target = response[:, 1:].contiguous().view(-1)
        mask = dec_target.gt(0)  # [(batch_sz*seq_len)]
        masked_target = dec_target.masked_select(mask)
        output_mask = mask.unsqueeze(1).expand(mask.size(0), self.vocab_size)
        output = self.decoder(torch.cat((post_z, c), 1), None,
                              response[:, :-1], (res_lens - 1))
        flattened_output = output.view(-1, self.vocab_size)
        masked_output = flattened_output.masked_select(output_mask).view(
            -1, self.vocab_size)
        lossAE = self.criterion_ce(masked_output / self.temp, masked_target)
        return [('valid_loss_AE', lossAE.item()),
                ('valid_loss_G', costG.item()), ('valid_loss_D', costD.item())]

    def sample(self, context, context_lens, utt_lens, floors, repeat, SOS_tok,
               EOS_tok):
        self.context_encoder.eval()
        self.decoder.eval()

        c = self.context_encoder(context, context_lens, utt_lens,
                                 floors)  # encode context into embedding
        c_repeated = c.expand(repeat, -1)
        prior_z = self.sample_code_prior(c_repeated)
        #         print(prior_z.shape)
        #         print(prior_z)
        sample_words, sample_lens = self.decoder.sampling(
            torch.cat((prior_z, c_repeated), 1), None, self.maxlen, SOS_tok,
            EOS_tok, "greedy")
        return sample_words, sample_lens

    def adjust_lr(self):
        self.lr_scheduler_AE.step()
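
Both train_AE methods above build the same PAD-masked, temperature-scaled reconstruction loss by hand: flatten the decoder output, drop positions whose target is the PAD id 0, and divide the remaining logits by a temperature before nn.CrossEntropyLoss. A compact sketch of that pattern follows, using boolean indexing instead of masked_select; the function name and arguments are illustrative, not project API.

# Sketch of the masked reconstruction loss used by train_AE/valid above.
import torch
import torch.nn as nn

def masked_reconstruction_loss(logits, targets, temp=1.0, pad_id=0):
    # logits: (batch, seq_len, vocab_size); targets: (batch, seq_len)
    vocab_size = logits.size(-1)
    flat_logits = logits.view(-1, vocab_size)
    flat_targets = targets.contiguous().view(-1)
    mask = flat_targets.gt(pad_id)            # True at non-pad positions
    masked_logits = flat_logits[mask]         # same effect as masked_select + view
    masked_targets = flat_targets[mask]
    return nn.CrossEntropyLoss()(masked_logits / temp, masked_targets)

if __name__ == "__main__":
    logits = torch.randn(2, 5, 11)
    targets = torch.tensor([[4, 7, 3, 0, 0], [9, 2, 3, 5, 0]])
    print(masked_reconstruction_loss(logits, targets, temp=0.8).item())
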
class DFVAE(nn.Module):
    def __init__(self, config, vocab_size, PAD_token=0):
        super(DFVAE, self).__init__()
        self.vocab_size = vocab_size
        self.maxlen = config['maxlen']
        self.clip = config['clip']
        self.lambda_gp = config['lambda_gp']
        self.temp = config['temp']

        self.embedder = nn.Embedding(vocab_size,
                                     config['emb_size'],
                                     padding_idx=PAD_token)
        self.utt_encoder = Encoder(self.embedder, config['emb_size'],
                                   config['n_hidden'], True,
                                   config['n_layers'], config['noise_radius'])
        self.context_encoder = ContextEncoder(self.utt_encoder,
                                              config['n_hidden'] * 2 + 2,
                                              config['n_hidden'], 1,
                                              config['noise_radius'])
        self.prior_net = Variation(config['n_hidden'],
                                   config['z_size'])  # p(e|c)
        self.post_net = Variation(config['n_hidden'] * 3,
                                  config['z_size'])  # q(e|c,x)

        #self.prior_highway = nn.Linear(config['n_hidden'], config['n_hidden'])
        #self.post_highway = nn.Linear(config['n_hidden'] * 3, config['n_hidden'])
        self.postflow1 = flow.myIAF(config['z_size'], config['z_size'] * 2,
                                    config['n_hidden'], 3)
        self.postflow2 = flow.myIAF(config['z_size'], config['z_size'] * 2,
                                    config['n_hidden'], 3)
        self.postflow3 = flow.myIAF(config['z_size'], config['z_size'] * 2,
                                    config['n_hidden'], 3)
        self.priorflow1 = flow.IAF(config['z_size'], config['z_size'] * 2,
                                   config['n_hidden'], 3)
        self.priorflow2 = flow.IAF(config['z_size'], config['z_size'] * 2,
                                   config['n_hidden'], 3)
        self.priorflow3 = flow.IAF(config['z_size'], config['z_size'] * 2,
                                   config['n_hidden'], 3)

        self.post_generator = nn_.SequentialFlow(self.postflow1,
                                                 self.postflow2,
                                                 self.postflow3)
        self.prior_generator = nn_.SequentialFlow(self.priorflow1,
                                                  self.priorflow2,
                                                  self.priorflow3)

        self.decoder = Decoder(self.embedder,
                               config['emb_size'],
                               config['n_hidden'] + config['z_size'],
                               vocab_size,
                               n_layers=1)

        self.optimizer_AE = optim.SGD(
            list(self.context_encoder.parameters()) +
            list(self.post_net.parameters()) +
            list(self.post_generator.parameters()) +
            list(self.decoder.parameters()) +
            list(self.prior_net.parameters()) +
            list(self.prior_generator.parameters())
            #+list(self.prior_highway.parameters())
            #+list(self.post_highway.parameters())
            ,
            lr=config['lr_ae'])
        self.optimizer_G = optim.RMSprop(
            list(self.post_net.parameters()) +
            list(self.post_generator.parameters()) +
            list(self.prior_net.parameters()) +
            list(self.prior_generator.parameters())
            #+list(self.prior_highway.parameters())
            #+list(self.post_highway.parameters())
            ,
            lr=config['lr_gan_g'])

        #self.optimizer_D = optim.RMSprop(self.discriminator.parameters(), lr=config['lr_gan_d'])

        self.lr_scheduler_AE = optim.lr_scheduler.StepLR(self.optimizer_AE,
                                                         step_size=10,
                                                         gamma=0.6)

        self.criterion_ce = nn.CrossEntropyLoss()

    def init_weights(self, m):
        if isinstance(m, nn.Linear):
            m.weight.data.uniform_(-0.02, 0.02)
            m.bias.data.fill_(0)

    def sample_post(self, x, c):
        xc = torch.cat((x, c), 1)
        e, mu, log_s = self.post_net(xc)
        #h_post = self.post_highway(xc)
        z, det_f, _, _ = self.post_generator((e, torch.eye(e.shape[1]), c, mu))
        #h_prior = self.prior_highway(c)
        tilde_z, det_g, _ = self.prior_generator((z, det_f, c))
        return tilde_z, z, mu, log_s, det_f, det_g

    def sample_code_post(self, x, c):
        xc = torch.cat((x, c), 1)
        e, mu, log_s = self.post_net(xc)
        #h_post = self.post_highway(xc)
        z, det_f, _, _ = self.post_generator((e, torch.eye(e.shape[1]), c, mu))
        #h_prior = self.prior_highway(c)
        tilde_z, det_g, _ = self.prior_generator((z, det_f, c))
        return tilde_z, mu, log_s, det_f, det_g

    def sample_post2(self, x, c):
        xc = torch.cat((x, c), 1)
        e, mu, log_s = self.post_net(xc)
        #h_post = self.post_highway(xc)
        z, det_f, _, _ = self.post_generator((e, torch.eye(e.shape[1]), c, mu))
        return e, mu, log_s, z, det_f

    def sample_code_prior(self, c):
        e, mu, log_s = self.prior_net(c)
        #z = self.prior_generator(e)
        #h_prior = self.prior_highway(c)
        #tilde_z, det_g, _ = self.prior_generator((e, 0, h_prior))
        return e, mu, log_s  #, det_g

    def sample_prior(self, c):
        e, mu, log_s = self.prior_net(c)
        #h_prior = self.prior_highway(c)
        z, det_prior, _ = self.prior_generator((e, 0, c))
        return z, det_prior

    def train_AE(self, context, context_lens, utt_lens, floors, response,
                 res_lens):
        self.context_encoder.train()
        self.decoder.train()
        c = self.context_encoder(context, context_lens, utt_lens, floors)
        x, _ = self.utt_encoder(response[:, 1:], res_lens - 1)
        z, _, _, _, _ = self.sample_code_post(x, c)
        z_post, mu_post, log_s_post, det_f, det_g = self.sample_code_post(x, c)
        #prior_z, mu_prior, log_s_prior = self.sample_code_prior(c)
        #KL_loss = torch.sum(log_s_prior - log_s_post + (torch.exp(log_s_post) + (mu_post - mu_prior)**2)/torch.exp(log_s_prior),1) / 2 - 100
        #kloss = KL_loss - det_f #+ det_g
        #KL_loss = log_Normal_diag(z_post, mu_post, log_s_post) - log_Normal_diag(prior_z, mu_prior, log_s_prior)
        output = self.decoder(torch.cat((z_post, c), 1), None,
                              response[:, :-1], (res_lens - 1))
        flattened_output = output.view(-1, self.vocab_size)

        dec_target = response[:, 1:].contiguous().view(-1)
        mask = dec_target.gt(0)  # [(batch_sz*seq_len)]
        masked_target = dec_target.masked_select(mask)  #
        output_mask = mask.unsqueeze(1).expand(
            mask.size(0), self.vocab_size)  # [(batch_sz*seq_len) x n_tokens]
        masked_output = flattened_output.masked_select(output_mask).view(
            -1, self.vocab_size)
        #print(KL_loss.mean())
        #print(det_f.mean())
        self.optimizer_AE.zero_grad()
        AE_term = self.criterion_ce(masked_output / self.temp, masked_target)
        loss = AE_term  #+ KL_loss.mean()
        loss.backward()

        #torch.nn.utils.clip_grad_norm_(list(self.context_encoder.parameters())+list(self.decoder.parameters()), self.clip)
        torch.nn.utils.clip_grad_norm_(
            list(self.context_encoder.parameters()) +
            list(self.decoder.parameters()) +
            list(self.post_generator.parameters()) +
            list(self.prior_generator.parameters()) +
            list(self.post_net.parameters()), self.clip)
        self.optimizer_AE.step()

        return [
            ('train_loss_AE', AE_term.item())
        ]  #,('KL_loss', KL_loss.mean().item())]#,('det_f', det_f.mean().item()),('det_g', det_g.mean().item())]

    def train_G(self, context, context_lens, utt_lens, floors, response,
                res_lens):
        self.context_encoder.eval()
        self.optimizer_G.zero_grad()
        c = self.context_encoder(context, context_lens, utt_lens, floors)
        # -----------------posterior samples ---------------------------
        x, _ = self.utt_encoder(response[:, 1:], res_lens - 1)
        z_0, mu_post, log_s_post, z_post, weight = self.sample_post2(
            x.detach(), c.detach())
        # ----------------- prior samples ---------------------------
        prior_z, mu_prior, log_s_prior = self.sample_code_prior(c.detach())
        KL_loss = torch.sum(
            log_s_prior - log_s_post + torch.exp(log_s_post) /
            torch.exp(log_s_prior) * torch.sum(weight**2, dim=2) +
            (mu_post)**2 / torch.exp(log_s_prior), 1) / 2 - 100
        #KL_loss = abs(log_Normal_diag(z_0, mu_post, log_s_post) - log_Normal_diag(z_post, mu_prior, log_s_prior))
        #KL_loss2 = torch.sum((prior_z - mu_post.detach())**2 / (2 * torch.exp(log_s_post.detach())),1)
        #print(mu_post.shape, prior_z.shape)
        loss = KL_loss
        #print(-det_f , KL_loss )
        #loss = abs(loss)
        loss.mean().backward()
        torch.nn.utils.clip_grad_norm_(
            list(self.post_generator.parameters()) +
            list(self.prior_generator.parameters()) +
            list(self.post_net.parameters()) +
            list(self.prior_generator.parameters()), self.clip)
        self.optimizer_G.step()
        #costG = errG_prior - errG_post
        return [
            ('KL_loss', KL_loss.mean().item())
        ]  #,('det_f', det_f.mean().item()),('det_g', det_g.sum().item())]

    def valid(self, context, context_lens, utt_lens, floors, response,
              res_lens):
        self.context_encoder.eval()
        #self.discriminator.eval()
        self.decoder.eval()

        c = self.context_encoder(context, context_lens, utt_lens, floors)
        x, _ = self.utt_encoder(response[:, 1:], res_lens - 1)
        post_z, mu_post, log_s_post, det_f, det_g = self.sample_code_post(x, c)
        prior_z, mu_prior, log_s_prior = self.sample_code_prior(c)
        #errD_post = torch.mean(self.discriminator(torch.cat((post_z, c),1)))
        #errD_prior = torch.mean(self.discriminator(torch.cat((prior_z, c),1)))
        KL_loss = torch.sum(
            log_s_prior - log_s_post +
            (torch.exp(log_s_post) +
             (mu_post)**2) / torch.exp(log_s_prior), 1) / 2
        #KL_loss = log_Normal_diag(post_z, mu_post, log_s_post) - log_Normal_diag(prior_z, mu_prior, log_s_prior)
        #KL_loss2 = torch.sum((prior_z - mu_post)**2 / (2 * torch.exp(log_s_post)),1)
        loss = KL_loss  # -det_f
        costG = loss.sum()
        dec_target = response[:, 1:].contiguous().view(-1)
        mask = dec_target.gt(0)  # [(batch_sz*seq_len)]
        masked_target = dec_target.masked_select(mask)
        output_mask = mask.unsqueeze(1).expand(mask.size(0), self.vocab_size)
        output = self.decoder(torch.cat((post_z, c), 1), None,
                              response[:, :-1], (res_lens - 1))
        flattened_output = output.view(-1, self.vocab_size)
        masked_output = flattened_output.masked_select(output_mask).view(
            -1, self.vocab_size)
        lossAE = self.criterion_ce(masked_output / self.temp, masked_target)
        return [('valid_loss_AE', lossAE.item()),
                ('valid_loss_G', costG.item())]

    def sample(self, context, context_lens, utt_lens, floors, repeat, SOS_tok,
               EOS_tok):
        self.context_encoder.eval()
        self.decoder.eval()

        c = self.context_encoder(context, context_lens, utt_lens, floors)
        c_repeated = c.expand(repeat, -1)
        prior_z, _ = self.sample_prior(c_repeated)
        sample_words, sample_lens = self.decoder.sampling(
            torch.cat((prior_z, c_repeated), 1), None, self.maxlen, SOS_tok,
            EOS_tok, "greedy")
        return sample_words, sample_lens

    def gen(self, context, prior_z, context_lens, utt_lens, floors, repeat,
            SOS_tok, EOS_tok):
        self.context_encoder.eval()
        self.decoder.eval()
        c = self.context_encoder(context, context_lens, utt_lens, floors)
        c_repeated = c.expand(repeat, -1)
        sample_words, sample_lens = self.decoder.sampling(
            torch.cat((prior_z, c_repeated), 1), None, self.maxlen, SOS_tok,
            EOS_tok, "greedy")
        return sample_words, sample_lens

    def sample_latent(self, context, context_lens, utt_lens, floors, repeat,
                      SOS_tok, EOS_tok):
        self.context_encoder.eval()
        #self.decoder.eval()
        c = self.context_encoder(context, context_lens, utt_lens, floors)
        c_repeated = c.expand(repeat, -1)
        e, _, _ = self.sample_code_prior(c_repeated)
        prior_z, _, _ = self.prior_generator((e, 0, c_repeated))
        return prior_z, e

    def sample_latent_post(self, context, context_lens, utt_lens, floors,
                           response, res_lens, repeat):
        self.context_encoder.eval()
        c = self.context_encoder(context, context_lens, utt_lens, floors)
        x, _ = self.utt_encoder(response[:, 1:], res_lens - 1)
        c_repeated = c.expand(repeat, -1)
        x_repeated = x.expand(repeat, -1)
        z_post, z, mu_post, log_s_post, det_f, det_g = self.sample_post(
            x_repeated, c_repeated)
        return z_post, z

    def adjust_lr(self):
        self.lr_scheduler_AE.step()
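
The KL_loss expressions in DFVAE.train_G and DFVAE.valid above are variants of the analytical KL divergence between diagonal Gaussians parameterized by (mu, logvar). For reference, here is a self-contained sketch of the standard closed form; diag_gaussian_kld is a hypothetical name, and helpers in the project such as gaussian_kld (used by the CVAE example below) may differ in detail.

# Closed-form KL( N(mu_q, exp(logvar_q)) || N(mu_p, exp(logvar_p)) ), summed over latent dims.
import torch

def diag_gaussian_kld(mu_q, logvar_q, mu_p, logvar_p):
    return 0.5 * torch.sum(
        logvar_p - logvar_q
        + (torch.exp(logvar_q) + (mu_q - mu_p) ** 2) / torch.exp(logvar_p)
        - 1.0,
        dim=1)

if __name__ == "__main__":
    mu_q, logvar_q = torch.randn(3, 8), torch.zeros(3, 8)
    mu_p, logvar_p = torch.zeros(3, 8), torch.zeros(3, 8)
    print(diag_gaussian_kld(mu_q, logvar_q, mu_p, logvar_p))  # one KL value per batch element
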
Code example #4
class CVAE(nn.Module):
    def __init__(self, config, api, PAD_token=0):
        super(CVAE, self).__init__()
        self.vocab = api.vocab
        self.vocab_size = len(self.vocab)
        self.rev_vocab = api.rev_vocab
        self.go_id = self.rev_vocab["<s>"]
        self.eos_id = self.rev_vocab["</s>"]
        self.maxlen = config.maxlen
        self.clip = config.clip
        self.temp = config.temp
        self.full_kl_step = config.full_kl_step
        self.z_size = config.z_size
        self.init_w = config.init_weight
        self.softmax = nn.Softmax(dim=1)

        self.embedder = nn.Embedding(self.vocab_size,
                                     config.emb_size,
                                     padding_idx=PAD_token)
        # Encodes the title and each poem line
        self.seq_encoder = Encoder(embedder=self.embedder,
                                   input_size=config.emb_size,
                                   hidden_size=config.n_hidden,
                                   bidirectional=True,
                                   n_layers=config.n_layers,
                                   noise_radius=config.noise_radius)

        # The prior network input is the encoded title + the encoded previous line + the encoded sentiment of the previous line
        self.prior_net = Variation(config.n_hidden * 4,
                                   config.z_size,
                                   dropout_rate=config.dropout,
                                   init_weight=self.init_w)
        # The posterior network additionally takes the 2*hidden encoding of x
        # self.post_net = Variation(config.n_hidden * 6, config.z_size*2)
        self.post_net = Variation(config.n_hidden * 6,
                                  config.z_size,
                                  dropout_rate=config.dropout,
                                  init_weight=self.init_w)
        # MLP for the bag-of-words (BoW) loss
        self.bow_project = nn.Sequential(
            nn.Linear(config.n_hidden * 4 + config.z_size, 400),
            nn.LeakyReLU(), nn.Dropout(config.dropout),
            nn.Linear(400, self.vocab_size))
        self.init_decoder_hidden = nn.Sequential(
            nn.Linear(config.n_hidden * 4 + config.z_size, config.n_hidden),
            nn.BatchNorm1d(config.n_hidden, eps=1e-05, momentum=0.1),
            nn.LeakyReLU())
        # self.post_generator = nn.Sequential(
        #     nn.Linear(config.z_size, config.z_size),
        #     nn.BatchNorm1d(config.z_size, eps=1e-05, momentum=0.1),
        #     nn.LeakyReLU(),
        #     nn.Linear(config.z_size, config.z_size),
        #     nn.BatchNorm1d(config.z_size, eps=1e-05, momentum=0.1),
        #     nn.LeakyReLU(),
        #     nn.Linear(config.z_size, config.z_size)
        # )
        # self.post_generator.apply(self.init_weights)

        # self.prior_generator = nn.Sequential(
        #     nn.Linear(config.z_size, config.z_size),
        #     nn.BatchNorm1d(config.z_size, eps=1e-05, momentum=0.1),
        #     nn.ReLU(),
        #     nn.Dropout(config.dropout),
        #     nn.Linear(config.z_size, config.z_size),
        #     nn.BatchNorm1d(config.z_size, eps=1e-05, momentum=0.1),
        #     nn.ReLU(),
        #     nn.Dropout(config.dropout),
        #     nn.Linear(config.z_size, config.z_size)
        # )
        # self.prior_generator.apply(self.init_weights)

        self.init_decoder_hidden.apply(self.init_weights)
        self.bow_project.apply(self.init_weights)
        self.post_net.apply(self.init_weights)

        self.decoder = Decoder(embedder=self.embedder,
                               input_size=config.emb_size,
                               hidden_size=config.n_hidden,
                               vocab_size=self.vocab_size,
                               n_layers=1)

        # self.optimizer_lead = optim.Adam(list(self.seq_encoder.parameters())\
        #                                + list(self.prior_net.parameters()), lr=config.lr_lead)
        self.optimizer_AE = optim.Adam(list(self.seq_encoder.parameters())\
                                       + list(self.prior_net.parameters())\
                                       # + list(self.prior_generator.parameters())

                                       + list(self.post_net.parameters())\
                                       # + list(self.post_generator.parameters())

                                       + list(self.bow_project.parameters())\
                                       + list(self.init_decoder_hidden.parameters())\
                                       + list(self.decoder.parameters()), lr=config.lr_ae)

        # self.lr_scheduler_AE = optim.lr_scheduler.StepLR(self.optimizer_AE, step_size=10, gamma=0.6)

        self.criterion_ce = nn.CrossEntropyLoss()
        self.softmax = nn.Softmax(dim=1)
        self.criterion_sent_lead = nn.CrossEntropyLoss()

    def set_full_kl_step(self, kl_full_step):
        self.full_kl_step = kl_full_step

    def force_change_lr(self, new_init_lr_ae):
        self.optimizer_AE = optim.Adam(list(self.seq_encoder.parameters()) \
                                       + list(self.prior_net.parameters()) \
                                       # + list(self.prior_generator.parameters())

                                       + list(self.post_net.parameters()) \
                                       # + list(self.post_generator.parameters())

                                       + list(self.bow_project.parameters()) \
                                       + list(self.init_decoder_hidden.parameters()) \
                                       + list(self.decoder.parameters()), lr=new_init_lr_ae)

    def init_weights(self, m):
        if isinstance(m, nn.Linear):
            m.weight.data.uniform_(-self.init_w, self.init_w)
            m.bias.data.fill_(0)

    def sample_code_post(self, x, c):
        # import pdb
        # pdb.set_trace()
        # mulogsigma = self.post_net(torch.cat((x, c), dim=1))
        # mu, logsigma = torch.chunk(mulogsigma, chunks=2, dim=1)
        # batch_size = c.size(0)
        # std = torch.exp(0.5 * logsigma)
        # epsilon = to_tensor(torch.randn([batch_size, self.z_size]))
        # z = epsilon * std + mu
        z, mu, logsigma = self.post_net(torch.cat(
            (x, c), 1))  # input: (batch, 3*2*n_hidden)
        # z = self.post_generator(z)
        return z, mu, logsigma

    def sample_code_prior(self, c, sentiment_mask=None, mask_type=None):
        return self.prior_net(c,
                              sentiment_mask=sentiment_mask,
                              mask_type=mask_type)

    # # input: (batch, 3)
    # # target: (batch, 3)
    # def criterion_sent_lead(self, input, target):
    #     softmax_res = self.softmax(input)
    #     negative_log_softmax_res = -torch.log(softmax_res + 1e-10)  # (batch, 3)
    #     cross_entropy_loss = torch.sum(negative_log_softmax_res * target, dim=1)
    #     avg_cross_entropy = torch.mean(cross_entropy_loss)
    #     return avg_cross_entropy

    # sentiment_lead: (batch, 3)
    def train_AE(self,
                 global_t,
                 title,
                 context,
                 target,
                 target_lens,
                 sentiment_mask=None,
                 sentiment_lead=None):
        self.seq_encoder.train()
        self.decoder.train()
        # batch_size = title.size(0)
        # The sentiment of each line is predicted by a second classifier that takes the current m_hidden and outputs class scores

        title_last_hidden, _ = self.seq_encoder(title)
        context_last_hidden, _ = self.seq_encoder(context)

        # import pdb
        # pdb.set_trace()
        x, _ = self.seq_encoder(target[:, 1:], target_lens - 1)
        condition_prior = torch.cat((title_last_hidden, context_last_hidden),
                                    dim=1)

        z_prior, prior_mu, prior_logvar, pi, pi_final = self.sample_code_prior(
            condition_prior, sentiment_mask=sentiment_mask)
        z_post, post_mu, post_logvar = self.sample_code_post(
            x, condition_prior)
        # import pdb
        # pdb.set_trace()
        if sentiment_lead is not None:
            self.sent_lead_loss = self.criterion_sent_lead(
                input=pi, target=sentiment_lead)
        else:
            self.sent_lead_loss = 0

        # if sentiment_lead is not None:
        #     self.optimizer_lead.zero_grad()
        #     self.sent_lead_loss.backward()
        #     self.optimizer_lead.step()
        #     return [('lead_loss', self.sent_lead_loss.item())], global_t

        final_info = torch.cat((z_post, condition_prior), dim=1)
        # reconstruct_loss
        # import pdb
        # pdb.set_trace()
        output = self.decoder(init_hidden=self.init_decoder_hidden(final_info),
                              context=None,
                              inputs=target[:, :-1])
        flattened_output = output.view(-1, self.vocab_size)
        # flattened_output = self.softmax(flattened_output) + 1e-10
        # flattened_output = torch.log(flattened_output)
        dec_target = target[:, 1:].contiguous().view(-1)

        mask = dec_target.gt(0)  # True where the target token is non-zero, i.e. not PAD
        masked_target = dec_target.masked_select(mask)  # keep only the non-pad targets
        output_mask = mask.unsqueeze(1).expand(
            mask.size(0), self.vocab_size)  # [(batch_sz * seq_len) x n_tokens]
        masked_output = flattened_output.masked_select(output_mask).view(
            -1, self.vocab_size)
        self.rc_loss = self.criterion_ce(masked_output / self.temp,
                                         masked_target)

        # KL divergence
        kld = gaussian_kld(post_mu, post_logvar, prior_mu, prior_logvar)
        self.avg_kld = torch.mean(kld)
        self.kl_weights = min(global_t / self.full_kl_step, 1.0)  # annealing
        self.kl_loss = self.kl_weights * self.avg_kld

        # avg_bow_loss
        self.bow_logits = self.bow_project(final_info)
        # In effect, sum the prediction loss over all target tokens
        labels = target[:, 1:]
        label_mask = torch.sign(labels).detach().float()
        # taking the sign turns non-pad token ids into 1, giving a positive mask so the loss can be optimized by minimization
        # soft_result = self.softmax(self.bow_logits) + 1e-10
        # bow_loss = -torch.log(soft_result).gather(1, labels) * label_mask
        bow_loss = -F.log_softmax(self.bow_logits, dim=1).gather(
            1, labels) * label_mask
        sum_bow_loss = torch.sum(bow_loss, 1)
        self.avg_bow_loss = torch.mean(sum_bow_loss)
        self.aug_elbo_loss = self.avg_bow_loss + self.kl_loss + self.rc_loss
        self.total_loss = self.aug_elbo_loss + self.sent_lead_loss

        # Effectively raises the learning rate on the sentiment-labeled subset
        if sentiment_mask is not None:
            self.total_loss = self.total_loss * 13.33

        self.optimizer_AE.zero_grad()
        self.total_loss.backward()
        self.optimizer_AE.step()

        avg_total_loss = self.total_loss.item()
        avg_lead_loss = 0 if sentiment_lead is None else self.sent_lead_loss.item(
        )
        avg_aug_elbo_loss = self.aug_elbo_loss.item()
        avg_kl_loss = self.kl_loss.item()
        avg_rc_loss = self.rc_loss.data.item()
        avg_bow_loss = self.avg_bow_loss.item()
        global_t += 1

        return [('avg_total_loss', avg_total_loss),
                ('avg_lead_loss', avg_lead_loss),
                ('avg_aug_elbo_loss', avg_aug_elbo_loss),
                ('avg_kl_loss', avg_kl_loss), ('avg_rc_loss', avg_rc_loss),
                ('avg_bow_loss', avg_bow_loss),
                ('kl_weight', self.kl_weights)], global_t

    def valid_AE(self,
                 global_t,
                 title,
                 context,
                 target,
                 target_lens,
                 sentiment_mask=None,
                 sentiment_lead=None):
        self.seq_encoder.eval()
        self.decoder.eval()

        title_last_hidden, _ = self.seq_encoder(title)
        context_last_hidden, _ = self.seq_encoder(context)

        # import pdb
        # pdb.set_trace()
        x, _ = self.seq_encoder(target[:, 1:], target_lens - 1)
        condition_prior = torch.cat((title_last_hidden, context_last_hidden),
                                    dim=1)

        z_prior, prior_mu, prior_logvar, pi, pi_final = self.sample_code_prior(
            condition_prior, sentiment_mask=sentiment_mask)
        z_post, post_mu, post_logvar = self.sample_code_post(
            x, condition_prior)
        if sentiment_lead is not None:
            self.sent_lead_loss = self.criterion_sent_lead(
                input=pi, target=sentiment_lead)
        else:
            self.sent_lead_loss = 0

        # if sentiment_lead is not None:
        #     return [('valid_lead_loss', self.sent_lead_loss.item())], global_t

        final_info = torch.cat((z_post, condition_prior), dim=1)

        output = self.decoder(init_hidden=self.init_decoder_hidden(final_info),
                              context=None,
                              inputs=target[:, :-1])
        flattened_output = output.view(-1, self.vocab_size)
        # flattened_output = self.softmax(flattened_output) + 1e-10
        # flattened_output = torch.log(flattened_output)
        dec_target = target[:, 1:].contiguous().view(-1)

        mask = dec_target.gt(0)  # True where the target token is non-zero, i.e. not PAD
        masked_target = dec_target.masked_select(mask)  # keep only the non-pad targets
        output_mask = mask.unsqueeze(1).expand(
            mask.size(0), self.vocab_size)  # [(batch_sz * seq_len) x n_tokens]
        masked_output = flattened_output.masked_select(output_mask).view(
            -1, self.vocab_size)
        self.rc_loss = self.criterion_ce(masked_output / self.temp,
                                         masked_target)

        # KL divergence
        kld = gaussian_kld(post_mu, post_logvar, prior_mu, prior_logvar)
        self.avg_kld = torch.mean(kld)
        self.kl_weights = min(global_t / self.full_kl_step, 1.0)  # annealing
        self.kl_loss = self.kl_weights * self.avg_kld
        # avg_bow_loss
        self.bow_logits = self.bow_project(final_info)
        # In effect, sum the prediction loss over all target tokens
        labels = target[:, 1:]
        label_mask = torch.sign(labels).detach().float()

        bow_loss = -F.log_softmax(self.bow_logits, dim=1).gather(
            1, labels) * label_mask
        sum_bow_loss = torch.sum(bow_loss, 1)
        self.avg_bow_loss = torch.mean(sum_bow_loss)
        self.aug_elbo_loss = self.avg_bow_loss + self.kl_loss + self.rc_loss

        avg_aug_elbo_loss = self.aug_elbo_loss.item()
        avg_kl_loss = self.kl_loss.item()
        avg_rc_loss = self.rc_loss.data.item()
        avg_bow_loss = self.avg_bow_loss.item()
        avg_lead_loss = 0 if sentiment_lead is None else self.sent_lead_loss.item(
        )

        return [('valid_lead_loss', avg_lead_loss),
                ('valid_aug_elbo_loss', avg_aug_elbo_loss),
                ('valid_kl_loss', avg_kl_loss), ('valid_rc_loss', avg_rc_loss),
                ('valid_bow_loss', avg_bow_loss)], global_t

    # batch_size = 1: only a single title is fed in
    # At test time only the prior is used; there is no posterior and hence no KL term
    def test(self, title_tensor, title_words, mask_type=None):
        self.seq_encoder.eval()
        self.decoder.eval()
        assert title_tensor.size(0) == 1
        tem = [[2, 3] + [0] * (self.maxlen - 2)]
        pred_poems = []
        # filter <s>, </s> and 0 out of the title tokens, for printing only
        title_tokens = [
            self.vocab[e] for e in title_words[0].tolist()
            if e not in [0, self.eos_id, self.go_id]
        ]
        pred_poems.append(title_tokens)

        for i in range(4):
            tem = to_tensor(np.array(tem))
            context = tem
            if i == 0:
                context_last_hidden, _ = self.seq_encoder(title_tensor)
            else:
                context_last_hidden, _ = self.seq_encoder(context)
            title_last_hidden, _ = self.seq_encoder(title_tensor)

            condition_prior = torch.cat(
                (title_last_hidden, context_last_hidden), dim=1)
            z_prior, prior_mu, prior_logvar, _, _ = self.sample_code_prior(
                condition_prior, mask_type=mask_type)
            final_info = torch.cat((z_prior, condition_prior), 1)

            decode_words = self.decoder.testing(
                init_hidden=self.init_decoder_hidden(final_info),
                maxlen=self.maxlen,
                go_id=self.go_id,
                mode="greedy")
            decode_words = decode_words[0].tolist()

            if len(decode_words) >= self.maxlen:
                tem = [decode_words[0:self.maxlen]]
            else:
                tem = [[0] * (self.maxlen - len(decode_words)) + decode_words]
            pred_tokens = [
                self.vocab[e] for e in decode_words[:-1]
                if e != self.eos_id and e != 0 and e != self.go_id
            ]
            pred_poems.append(pred_tokens)

        gen = ""
        for line in pred_poems:
            cur_line = " ".join(line)
            gen = gen + cur_line + '\n'

        return gen

    def sample(self, title, context, repeat, go_id, end_id):
        self.seq_encoder.eval()
        self.decoder.eval()

        assert title.size(0) == 1
        title_last_hidden, _ = self.seq_encoder(title)
        context_last_hidden, _ = self.seq_encoder(context)
        condition_prior = torch.cat((title_last_hidden, context_last_hidden),
                                    1)
        condition_prior_repeat = condition_prior.expand(repeat, -1)

        z_prior_repeat, _, _, _, _ = self.sample_code_prior(
            condition_prior_repeat)

        final_info = torch.cat((z_prior_repeat, condition_prior_repeat), dim=1)
        sample_words, sample_lens = self.decoder.sampling(
            init_hidden=self.init_decoder_hidden(final_info),
            maxlen=self.maxlen,
            go_id=self.go_id,
            eos_id=self.eos_id,
            mode="greedy")

        return sample_words, sample_lens
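
CVAE.train_AE above adds two terms on top of the masked reconstruction loss: a KL term whose weight is annealed linearly via global_t / full_kl_step, and a bag-of-words loss that asks bow_logits to predict every target token independently. The following is a compact, self-contained sketch of those two pieces; the function names are illustrative stand-ins, and only the shapes mirror the code above.

# Sketch of the KL annealing weight and the bag-of-words auxiliary loss from CVAE.train_AE.
import torch
import torch.nn.functional as F

def kl_weight(global_t, full_kl_step):
    # Linear annealing: ramps from 0 to 1 over full_kl_step updates.
    return min(global_t / full_kl_step, 1.0)

def bag_of_words_loss(bow_logits, target, pad_id=0):
    # bow_logits: (batch, vocab_size); target: (batch, seq_len), first token is the go symbol.
    labels = target[:, 1:]
    label_mask = labels.ne(pad_id).float()           # 1 at real tokens, 0 at PAD
    log_probs = F.log_softmax(bow_logits, dim=1)     # (batch, vocab_size)
    token_nll = -log_probs.gather(1, labels) * label_mask
    return token_nll.sum(1).mean()                   # sum over tokens, mean over batch

if __name__ == "__main__":
    bow_logits = torch.randn(2, 20)
    target = torch.tensor([[2, 5, 7, 3, 0], [2, 9, 4, 6, 3]])
    print(kl_weight(global_t=350, full_kl_step=1000), bag_of_words_loss(bow_logits, target).item())
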