Example #1
    def decoder_topk(self, batch, max_dec_step=30):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        emb_mask = self.embedding(batch["mask_input"])
        encoder_outputs = self.encoder(
            self.embedding(enc_batch) + emb_mask, mask_src)

        ys = torch.ones(1, 1).fill_(config.SOS_idx).long()
        if config.USE_CUDA:
            ys = ys.cuda()
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        decoded_words = []
        for i in range(max_dec_step + 1):
            if (config.project):
                out, attn_dist = self.decoder(
                    self.embedding_proj_in(self.embedding(ys)),
                    self.embedding_proj_in(encoder_outputs),
                    (mask_src, mask_trg))
            else:
                out, attn_dist = self.decoder(self.embedding(ys),
                                              encoder_outputs,
                                              (mask_src, mask_trg))

            logit = self.generator(out,
                                   attn_dist,
                                   enc_batch_extend_vocab,
                                   extra_zeros,
                                   attn_dist_db=None)
            filtered_logit = top_k_top_p_filtering(logit[:, -1],
                                                   top_k=3,
                                                   top_p=0,
                                                   filter_value=-float('Inf'))
            # Sample from the filtered distribution
            next_word = torch.multinomial(F.softmax(filtered_logit, dim=-1),
                                          1).squeeze()
            decoded_words.append([
                '<EOS>' if ni.item() == config.EOS_idx else
                self.vocab.index2word[ni.item()] for ni in next_word.view(-1)
            ])
            next_word = next_word.data[0]

            if config.USE_CUDA:
                ys = torch.cat(
                    [ys, torch.ones(1, 1).long().fill_(next_word).cuda()],
                    dim=1)
                ys = ys.cuda()
            else:
                ys = torch.cat(
                    [ys, torch.ones(1, 1).long().fill_(next_word)], dim=1)
            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

        sent = []
        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>': break
                else: st += e + ' '
            sent.append(st)
        return sent
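
This decoding snippet assumes a top_k_top_p_filtering helper with the same signature as the widely used Hugging Face gist. A minimal sketch for a (batch, vocab) logit matrix, shown here as an assumption rather than the repository's own implementation:

import torch
import torch.nn.functional as F

def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter a (batch, vocab) logit matrix with top-k and/or nucleus (top-p) filtering."""
    if top_k > 0:
        top_k = min(top_k, logits.size(-1))
        # drop every token whose logit is below the k-th largest logit in its row
        kth_values = torch.topk(logits, top_k)[0][..., -1, None]
        logits = logits.masked_fill(logits < kth_values, filter_value)
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # drop tokens once the cumulative probability exceeds top_p, always keeping the best token
        sorted_mask = cumulative_probs > top_p
        sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
        sorted_mask[..., 0] = False
        mask = sorted_mask.scatter(1, sorted_indices, sorted_mask)
        logits = logits.masked_fill(mask, filter_value)
    return logits
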
Example #2
    def train_one_batch(self, batch, train=True):
        ## pad and other stuff
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)
        dec_batch, _, _, _, _ = get_output_from_batch(batch)
        
        if(config.noam):
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        ## Encode
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        encoder_outputs = self.encoder(self.embedding(enc_batch),mask_src)

        # Decode 
        sos_token = torch.LongTensor([config.SOS_idx] * enc_batch.size(0)).unsqueeze(1)
        if config.USE_CUDA: sos_token = sos_token.cuda()
        dec_batch_shift = torch.cat((sos_token,dec_batch[:, :-1]),1)

        mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
        pre_logit, attn_dist = self.decoder(self.embedding(dec_batch_shift),encoder_outputs, (mask_src,mask_trg))
        ## compute output dist
        logit = self.generator(pre_logit,attn_dist,enc_batch_extend_vocab, extra_zeros)
        
        ## loss: NLL if ptr else Cross entropy
        loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1))

        if(train):
            loss.backward()
            self.optimizer.step()
        if(config.label_smoothing): 
            loss = self.criterion_ppl(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1))
        
        return loss.item(), math.exp(min(loss.item(), 100)), loss
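
When config.noam is set, self.optimizer is assumed to be a Noam learning-rate wrapper around a plain optimizer, which is why zero_grad() is called on the inner .optimizer attribute. A minimal sketch in the style of the Annotated Transformer (an assumption, not the repository's exact class):

import torch

class NoamOpt:
    """Optimizer wrapper that sets the learning rate with the Noam warmup schedule."""
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size

    def step(self):
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self.optimizer.step()

    def rate(self, step=None):
        if step is None:
            step = self._step
        return self.factor * (self.model_size ** (-0.5) *
                              min(step ** (-0.5), step * self.warmup ** (-1.5)))

A typical construction would be NoamOpt(config.hidden_dim, 1, 8000, torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)).
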
Example #3
    def train_one_batch(self, batch, train=True):
        if train:
            self.model.train()
        enc_batch, enc_mask, _, enc_batch_extend_vocab, extra_zeros, _, _ = \
            get_input_from_batch(batch)
        dec_batch, dec_mask, _, _, _ = get_output_from_batch(batch)
        dec_batch_input, dec_batch_output = dec_batch[:, :-1], dec_batch[:, 1:]
        dec_mask = dec_mask[:, :-1]

        self.optimizer.zero_grad()
        logit, *_ = self.model(
            enc_batch,
            attention_mask=enc_mask,
            decoder_input_ids=dec_batch_input,
            decoder_attention_mask=dec_mask,
        )

        # loss: NLL if ptr else Cross entropy
        loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                              dec_batch_output.contiguous().view(-1))

        if (config.act):
            loss += self.compute_act_loss(self.encoder)
            loss += self.compute_act_loss(self.decoder)

        if (train):
            loss.backward()
            self.optimizer.step()
        if (config.label_smoothing):
            loss = self.criterion_ppl(
                logit.contiguous().view(-1, logit.size(-1)),
                dec_batch_output.contiguous().view(-1))

        return loss.item(), math.exp(min(loss.item(), 100)), loss
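
criterion and criterion_ppl are assumed to be a label-smoothing loss and a plain NLL loss (the latter only used to report perplexity), which is the usual pairing behind config.label_smoothing in these snippets. A minimal KLDivLoss-based label-smoothing criterion in the style of the Annotated Transformer, assuming the model already returns log-probabilities:

import torch
import torch.nn as nn

class LabelSmoothing(nn.Module):
    """KL divergence against a smoothed one-hot target distribution."""
    def __init__(self, size, padding_idx, smoothing=0.1):
        super().__init__()
        self.criterion = nn.KLDivLoss(reduction='sum')
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size

    def forward(self, x, target):
        # x: (N, vocab) log-probabilities, target: (N,) gold token indices
        true_dist = x.new_full(x.size(), self.smoothing / (self.size - 2))
        true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        true_dist[:, self.padding_idx] = 0
        true_dist.masked_fill_(target.eq(self.padding_idx).unsqueeze(1), 0.0)
        return self.criterion(x, true_dist)
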
Example #4
    def decoder_greedy(self, batch, max_dec_step=50):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)

        ## Encode
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        emb_mask = self.embedding(batch["input_mask"])
        meta = self.embedding(batch["program_label"])
        encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)

        mask_trg = torch.ones((enc_batch.size(0), 50))
        meta_size = meta.size()
        meta = meta.repeat(1, 50).view(meta_size[0], 50, meta_size[1])
        out, attn_dist, _, _ = self.decoder(meta, encoder_outputs, None,
                                            (mask_src, None, mask_trg))
        prob = self.generator(out,
                              attn_dist,
                              enc_batch_extend_vocab,
                              extra_zeros,
                              attn_dist_db=None)
        _, batch_out = torch.max(prob, dim=-1)  # argmax over the (extended) vocabulary at each position
        batch_out = batch_out.data.cpu().numpy()
        sentences = []
        for sent in batch_out:
            st = ''
            for w in sent:
                if w == config.EOS_idx: break
                else: st += self.vocab.index2word[w] + ' '
            sentences.append(st)
        return sentences
Example #5
    def train_one_batch(self, batch, iter, train=True):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)
        dec_batch, _, _, _, _ = get_output_from_batch(batch)

        if (config.noam):
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        ## Response encode
        mask_res = batch["posterior_batch"].data.eq(
            config.PAD_idx).unsqueeze(1)
        posterior_mask = self.embedding(batch["posterior_mask"])
        r_encoder_outputs = self.r_encoder(
            self.embedding(batch["posterior_batch"]), mask_res)
        ## Encode
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        emb_mask = self.embedding(batch["input_mask"])
        encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)
        meta = self.embedding(batch["program_label"])

        # Decode
        mask_trg = dec_batch.data.eq(config.PAD_idx).unsqueeze(1)
        latent_dim = meta.size()[-1]
        meta = meta.repeat(1, dec_batch.size(1)).view(dec_batch.size(0),
                                                      dec_batch.size(1),
                                                      latent_dim)
        pre_logit, attn_dist, mean, log_var = self.decoder(
            meta, encoder_outputs, r_encoder_outputs,
            (mask_src, mask_res, mask_trg))
        if not train:
            pre_logit, attn_dist, _, _ = self.decoder(
                meta, encoder_outputs, None, (mask_src, None, mask_trg))
        ## compute output dist
        logit = self.generator(
            pre_logit,
            attn_dist,
            enc_batch_extend_vocab if config.pointer_gen else None,
            extra_zeros,
            attn_dist_db=None)
        ## loss: NLL if ptr else Cross entropy
        loss_rec = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                                  dec_batch.contiguous().view(-1))
        kld_loss = gaussian_kld(mean["posterior"], log_var["posterior"],
                                mean["prior"], log_var["prior"])
        kld_loss = torch.mean(kld_loss)
        kl_weight = min(iter / config.full_kl_step,
                        1) if config.full_kl_step > 0 else 1.0
        loss = loss_rec + config.kl_ceiling * kl_weight * kld_loss
        if (train):
            loss.backward()
            # clip gradient
            nn.utils.clip_grad_norm_(self.parameters(), config.max_grad_norm)
            self.optimizer.step()

        return loss_rec.item(), math.exp(min(loss_rec.item(),
                                             100)), kld_loss.item()
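
gaussian_kld is assumed to compute the closed-form KL divergence between two diagonal Gaussians, given means and log-variances of the posterior and the prior as above, returning one value per example (the caller then takes the mean). A minimal sketch:

import torch

def gaussian_kld(mu_q, log_var_q, mu_p, log_var_p):
    """KL( N(mu_q, exp(log_var_q)) || N(mu_p, exp(log_var_p)) ), summed over latent dimensions."""
    return 0.5 * torch.sum(
        log_var_p - log_var_q
        + (torch.exp(log_var_q) + (mu_q - mu_p) ** 2) / torch.exp(log_var_p)
        - 1.0,
        dim=-1)
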
Example #6
    def train_one_batch(self, batch, train=True):
        ## pad and other stuff
        enc_batch, _, enc_lens, enc_batch_extend_vocab, extra_zeros, _, _, _ = get_input_from_batch(batch)
        dec_batch, _, _, _, _, _ = get_output_from_batch(batch)
        
        if(config.noam):
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        ## Encode
        self.h0_encoder, self.c0_encoder = self.get_state(enc_batch)
        src_h, (src_h_t, src_c_t) = self.encoder(
            self.embedding(enc_batch), (self.h0_encoder, self.c0_encoder))
        h_t = src_h_t[-1]
        c_t = src_c_t[-1]
        self.h0_encoder_r, self.c0_encoder_r = self.get_state(dec_batch)
        src_h_r, (src_h_t_r, src_c_t_r) = self.encoder_r(
            self.embedding(dec_batch), (self.h0_encoder_r, self.c0_encoder_r))
        h_t_r = src_h_t_r[-1]
        c_t_r = src_c_t_r[-1]
        
        #sample and reparameter
        z_sample, mu, var = self.represent(torch.cat((h_t_r, h_t), 1))
        p_z_sample, p_mu, p_var = self.prior(h_t)
        
        # Decode
        decoder_init_state = nn.Tanh()(self.encoder2decoder(torch.cat((z_sample, h_t), 1)))
        
        sos_token = torch.LongTensor([config.SOS_idx] * enc_batch.size(0)).unsqueeze(1)
        if config.USE_CUDA: sos_token = sos_token.cuda()
        dec_batch_shift = torch.cat((sos_token,dec_batch[:, :-1]),1)
        target_embedding = self.embedding(dec_batch_shift)
        ctx = src_h.transpose(0, 1)
        trg_h, (_, _) = self.decoder(
            target_embedding,
            (decoder_init_state, c_t),
            ctx    
        )
        pre_logit = trg_h
        logit = self.generator(pre_logit)
        
        ## loss: NLL if ptr else Cross entropy
        re_loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1))
        kl_losses = 0.5 * torch.sum(torch.exp(var - p_var) + (mu - p_mu) ** 2 / torch.exp(p_var) - 1. - var + p_var, 1)
        kl_loss = torch.mean(kl_losses)
        latent_logit = self.mlp_b(torch.cat((z_sample, h_t), 1)).unsqueeze(1)
        latent_logit = F.log_softmax(latent_logit,dim=-1)
        latent_logits = latent_logit.repeat(1, logit.size(1), 1)
        bow_loss = self.criterion(latent_logits.contiguous().view(-1, latent_logits.size(-1)), dec_batch.contiguous().view(-1))
        loss = re_loss + 0.48 * kl_loss + bow_loss
        if(train):
            loss.backward()
            self.optimizer.step()
        if(config.label_smoothing):
            s_loss = self.criterion_ppl(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1))
        else:
            s_loss = re_loss  # fall back so s_loss is defined when label smoothing is disabled

        return s_loss.item(), math.exp(min(s_loss.item(), 100)), loss.item(), re_loss.item(), kl_loss.item(), bow_loss.item()
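
self.represent and self.prior are assumed to return a latent sample together with its Gaussian parameters (the KL term above treats var as a log-variance). A minimal sketch of such a module using the reparameterization trick; the layer names are illustrative:

import torch
import torch.nn as nn

class GaussianSampler(nn.Module):
    """Maps a hidden state to (z, mu, log_var), sampling z with the reparameterization trick."""
    def __init__(self, input_dim, latent_dim):
        super().__init__()
        self.mu_layer = nn.Linear(input_dim, latent_dim)
        self.log_var_layer = nn.Linear(input_dim, latent_dim)

    def forward(self, h):
        mu = self.mu_layer(h)
        log_var = self.log_var_layer(h)
        std = torch.exp(0.5 * log_var)
        z = mu + std * torch.randn_like(std)  # differentiable sample
        return z, mu, log_var
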
Example #7
    def decoder_greedy(self, batch, max_dec_step=50):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)

        ## Encode
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        emb_mask = self.embedding(batch["input_mask"])
        meta = self.embedding(batch["program_label"])
        if config.dataset == "empathetic":
            meta = meta - meta
        encoder_outputs = self.encoder(
            self.embedding(enc_batch) + emb_mask, mask_src)
        if config.model == "cvaetrs":
            kld_loss, z = self.latent_layer(encoder_outputs[:, 0],
                                            None,
                                            train=False)

        ys = torch.ones(enc_batch.shape[0], 1).fill_(config.SOS_idx).long()
        if config.USE_CUDA:
            ys = ys.cuda()
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

        decoded_words = []
        for i in range(max_dec_step + 1):
            input_vector = self.embedding(ys)
            if config.model == "cvaetrs":
                input_vector[:, 0] = input_vector[:, 0] + z + meta
            else:
                input_vector[:, 0] = input_vector[:, 0] + meta
            out, attn_dist = self.decoder(input_vector, encoder_outputs,
                                          (mask_src, mask_trg))

            prob = self.generator(out,
                                  attn_dist,
                                  enc_batch_extend_vocab,
                                  extra_zeros,
                                  attn_dist_db=None)
            _, next_word = torch.max(prob[:, -1], dim=1)
            decoded_words.append([
                '<EOS>' if ni.item() == config.EOS_idx else
                self.vocab.index2word[ni.item()] for ni in next_word.view(-1)
            ])

            if config.USE_CUDA:
                ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)
                ys = ys.cuda()
            else:
                ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)

            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        sent = []
        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>': break
                else: st += e + ' '
            sent.append(st)
        return sent
    def decoder_greedy_po(self, batch, max_dec_step=50):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        emb_mask = self.embedding(batch["input_mask"])
        meta = self.embedding(batch["program_label"])
        encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)

        ys = torch.ones(enc_batch.shape[0], 1).fill_(config.SOS_idx).long()
        if config.USE_CUDA:
            ys = ys.cuda()
        # print('=====================ys========================')
        # print(ys)

        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        decoded_words = []
        for i in range(max_dec_step + 1):

            out, attn_dist = self.decoder(
                self.embedding(ys) + meta.unsqueeze(1), encoder_outputs,
                (mask_src, mask_trg))

            prob = self.generator(out,
                                  attn_dist,
                                  enc_batch_extend_vocab,
                                  extra_zeros,
                                  attn_dist_db=None)
            _, next_word = torch.max(prob[:, -1], dim=1)
            # print('=====================next_word1========================')
            # print(next_word)
            decoded_words.append([
                '<EOS>' if ni.item() == config.EOS_idx else
                self.vocab.index2word[ni.item()] for ni in next_word.view(-1)
            ])
            #next_word = next_word.data[0]
            # print('=====================next_word2========================')
            # print(next_word)
            if config.USE_CUDA:
                # print('=====================shape========================')
                # print(ys.shape, next_word.shape)
                ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)
                ys = ys.cuda()
            else:
                ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)

            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
            # print('=====================new_ys========================')
            # print(ys)
        sent = []
        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>': break
                else: st += e + ' '
            sent.append(st)
        return sent
Example #9
    def train_one_batch(self, batch, iter, train=True):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)
        dec_batch, _, _, _, _ = get_output_from_batch(batch)
        
        if(config.noam):
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        ## Encode
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)

        emb_mask = self.embedding(batch["mask_input"])
        encoder_outputs = self.encoder(self.embedding(enc_batch)+emb_mask,mask_src)

        # Decode 
        sos_token = torch.LongTensor([config.SOS_idx] * enc_batch.size(0)).unsqueeze(1)
        if config.USE_CUDA: sos_token = sos_token.cuda()
        dec_batch_shift = torch.cat((sos_token,dec_batch[:, :-1]),1)

        mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
        pre_logit, attn_dist = self.decoder(self.embedding(dec_batch_shift),encoder_outputs, (mask_src,mask_trg))
        
        ## compute output dist
        logit = self.generator(pre_logit,attn_dist,enc_batch_extend_vocab if config.pointer_gen else None, extra_zeros, attn_dist_db=None)
        #logit = F.log_softmax(logit,dim=-1) #fix the name later
        ## loss: NLL if ptr else Cross entropy
        loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1))
        
        #multi-task
        if config.multitask:
            #q_h = torch.mean(encoder_outputs,dim=1)
            q_h = encoder_outputs[:,0]
            logit_prob = self.decoder_key(q_h)
            loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1)) + nn.CrossEntropyLoss()(logit_prob,torch.LongTensor(batch['program_label']).cuda())
            loss_bce_program = nn.CrossEntropyLoss()(logit_prob,torch.LongTensor(batch['program_label']).cuda()).item()
            pred_program = np.argmax(logit_prob.detach().cpu().numpy(), axis=1)
            program_acc = accuracy_score(batch["program_label"], pred_program)

        if(config.label_smoothing): 
            loss_ppl = self.criterion_ppl(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1)).item()

        if(train):
            loss.backward()
            self.optimizer.step()
        if config.multitask:
            if(config.label_smoothing):
                return loss_ppl, math.exp(min(loss_ppl, 100)), loss_bce_program, program_acc
            else:
                return loss.item(), math.exp(min(loss.item(), 100)), loss_bce_program, program_acc
        else:
            if(config.label_smoothing):
                return loss_ppl, math.exp(min(loss_ppl, 100)), 0, 0
            else:
                return loss.item(), math.exp(min(loss.item(), 100)), 0, 0
    def train_n_batch(self, batchs, iter, train=True):
        if(config.noam):
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()
        for batch in batchs:
            enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)
            dec_batch, _, _, _, _ = get_output_from_batch(batch)
            ## Encode
            mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
            encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)

            meta = self.embedding(batch["program_label"])
            if config.dataset=="empathetic":
                meta = meta-meta
            # Decode
            sos_token = torch.LongTensor([config.SOS_idx] * enc_batch.size(0)).unsqueeze(1)
            if config.USE_CUDA: sos_token = sos_token.cuda()
            dec_batch_shift = torch.cat((sos_token,dec_batch[:, :-1]),1)

            mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)

            pre_logit, attn_dist, mean, log_var, probs= self.decoder(self.embedding(dec_batch_shift)+meta.unsqueeze(1),encoder_outputs, True, (mask_src,mask_trg))
            ## compute output dist
            logit = self.generator(pre_logit,attn_dist,enc_batch_extend_vocab if config.pointer_gen else None, extra_zeros, attn_dist_db=None)
            ## loss: NLL if ptr else Cross entropy
            sbow = dec_batch #[batch, seq_len]
            seq_len = sbow.size(1)
            loss_rec = self.criterion(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1))
            if config.model=="cvaetrs":
                loss_aux = 0
                for prob in probs:
                    sbow_mask = _get_attn_subsequent_mask(seq_len).transpose(1,2)
                    sbow.unsqueeze(2).repeat(1,1,seq_len).masked_fill_(sbow_mask,config.PAD_idx)#[batch, seq_len, seq_len]

                    loss_aux+= self.criterion(prob.contiguous().view(-1, prob.size(-1)), sbow.contiguous().view(-1))
                kld_loss = gaussian_kld(mean["posterior"], log_var["posterior"],mean["prior"], log_var["prior"])
                kld_loss = torch.mean(kld_loss)
                kl_weight = min(math.tanh(6 * iter/config.full_kl_step - 3) + 1, 1)
                #kl_weight = min(iter/config.full_kl_step, 1) if config.full_kl_step >0 else 1.0
                loss = loss_rec + config.kl_ceiling * kl_weight*kld_loss + config.aux_ceiling*loss_aux
                elbo = loss_rec+kld_loss
            else:
                loss = loss_rec
                elbo = loss_rec
                kld_loss = torch.Tensor([0])
                loss_aux = torch.Tensor([0])
            loss.backward()
            # clip gradient
        nn.utils.clip_grad_norm_(self.parameters(), config.max_grad_norm)
        self.optimizer.step()

        return loss_rec.item(), math.exp(min(loss_rec.item(), 100)), kld_loss.item(), loss_aux.item(), elbo.item()
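
_get_attn_subsequent_mask, used above to mask future positions in the sequential bag-of-words target, is assumed to be the standard causal mask builder. A minimal sketch:

import torch

def _get_attn_subsequent_mask(size):
    """Causal mask of shape (1, size, size); True above the diagonal marks positions to block."""
    return torch.triu(torch.ones(1, size, size), diagonal=1).bool()
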
Example #11
    def forward(self, batch):
        enc_batch, enc_mask, _, enc_batch_extend_vocab, extra_zeros, _, _ = \
            get_input_from_batch(batch)
        dec_batch, dec_mask, _, _, _ = get_output_from_batch(batch)
        dec_batch_input, dec_batch_output = dec_batch[:, :-1], dec_batch[:, 1:]
        dec_mask = dec_mask[:, :-1]

        self.optimizer.zero_grad()
        loss = self.model(
            input_ids=enc_batch,
            decoder_input_ids=dec_batch_input,
            lm_labels=dec_batch_output,
            attention_mask=enc_mask,
            decoder_attention_mask=dec_mask,
        )[0]

        return loss.item(), math.exp(min(loss.item(), 600)), loss
    def train_one_batch(self, batch, iter, train=True):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)
        dec_batch, _, _, _, _ = get_output_from_batch(batch)

        if (config.noam):
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        ## Encode
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        meta = self.embedding(batch["program_label"])
        emb_mask = self.embedding(batch["input_mask"])
        encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)

        # Decode
        sos_token = torch.LongTensor([config.SOS_idx] *
                                     enc_batch.size(0)).unsqueeze(1)
        if config.USE_CUDA: sos_token = sos_token.cuda()
        dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1)

        mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
        pre_logit, attn_dist = self.decoder(
            self.embedding(dec_batch_shift) + meta.unsqueeze(1),
            encoder_outputs, (mask_src, mask_trg))
        #+meta.unsqueeze(1)
        ## compute output dist
        logit = self.generator(
            pre_logit,
            attn_dist,
            enc_batch_extend_vocab if config.pointer_gen else None,
            extra_zeros,
            attn_dist_db=None)
        #logit = F.log_softmax(logit,dim=-1) #fix the name later
        ## loss: NLL if ptr else Cross entropy
        loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                              dec_batch.contiguous().view(-1))

        if (train):
            loss.backward()
            self.optimizer.step()

        return loss.item(), math.exp(min(loss.item(), 100)), 0
Example #13
    def score_sentence(self, batch):
        # pad and other stuff
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)
        dec_batch, _, _, _, _ = get_output_from_batch(batch)
        cand_batch = batch["cand_index"]
        hit_1 = 0
        for i, b in enumerate(enc_batch):
            # Encode
            mask_src = b.unsqueeze(0).data.eq(config.PAD_idx).unsqueeze(1)
            encoder_outputs = self.encoder(self.embedding(b.unsqueeze(0)),
                                           mask_src)
            rank = {}
            for j, c in enumerate(cand_batch[i]):
                if config.USE_CUDA:
                    c = c.cuda()
                # Decode
                sos_token = torch.LongTensor(
                    [config.SOS_idx] * b.unsqueeze(0).size(0)).unsqueeze(1)
                if config.USE_CUDA:
                    sos_token = sos_token.cuda()
                dec_batch_shift = torch.cat(
                    (sos_token, c.unsqueeze(0)[:, :-1]), 1)

                mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
                pre_logit, attn_dist = self.decoder(
                    self.embedding(dec_batch_shift), encoder_outputs,
                    (mask_src, mask_trg))

                # compute output dist
                logit = self.generator(pre_logit, attn_dist,
                                       enc_batch_extend_vocab[i].unsqueeze(0),
                                       extra_zeros)
                loss = self.criterion(
                    logit.contiguous().view(-1, logit.size(-1)),
                    c.unsqueeze(0).contiguous().view(-1))
                # print("CANDIDATE {}".format(j), loss.item(), math.exp(min(loss.item(), 100)))
                rank[j] = math.exp(min(loss.item(), 100))
            s = sorted(rank.items(), key=lambda x: x[1], reverse=False)
            if s[1][0] == 19:  # candidates are listed in reverse order, so the last one (19) is the correct answer
                hit_1 += 1
        return hit_1 / float(len(enc_batch))
Example #14
    def decoder_greedy(self, batch):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)

        ys = torch.ones(1, 1).fill_(config.SOS_idx).long()
        if config.USE_CUDA:
            ys = ys.cuda()
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        decoded_words = []
        for i in range(config.max_dec_step):
            out, attn_dist = self.decoder(self.embedding(ys), encoder_outputs,
                                          (mask_src, mask_trg))
            prob = self.generator(out, attn_dist, enc_batch_extend_vocab,
                                  extra_zeros)
            _, next_word = torch.max(prob[:, -1], dim=1)
            decoded_words.append([
                '<EOS>' if ni.item() == config.EOS_idx else
                self.vocab.index2word[ni.item()] for ni in next_word.view(-1)
            ])
            next_word = next_word.data[0]

            if config.USE_CUDA:
                ys = torch.cat(
                    [ys, torch.ones(1, 1).long().fill_(next_word).cuda()],
                    dim=1)
                ys = ys.cuda()
            else:
                ys = torch.cat(
                    [ys, torch.ones(1, 1).long().fill_(next_word)], dim=1)
            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

        sent = []
        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>':
                    break
                else:
                    st += e + ' '
            sent.append(st)
        return sent
Example #15
    def decoder_greedy(self, batch, max_dec_step=31):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = get_input_from_batch(
            batch)

        encoder_outputs, encoder_hidden = self.encoder(enc_batch, enc_lens)

        s_t_1 = encoder_hidden.transpose(0, 1).contiguous().view(
            -1, config.hidden_dim * 2)  #b x hidden_dim*2

        kld_loss, z = self.latent(s_t_1, None, False)
        if config.model == "seq2seq":
            z = z - z

        s_t_1 = torch.cat((z, s_t_1), dim=-1)

        batch_size = enc_batch.size(0)
        y_t_1 = torch.LongTensor([config.SOS_idx] * batch_size)
        if config.USE_CUDA:
            y_t_1 = y_t_1.cuda()

        decoded_words = []
        for di in range(max_dec_step):
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, coverage = self.decoder(
                y_t_1, s_t_1, encoder_outputs, enc_padding_mask, c_t_1,
                extra_zeros, enc_batch_extend_vocab, coverage, di)

            _, topk_ids = torch.topk(final_dist, 1)
            decoded_words.append([
                '<EOS>' if ni.item() == config.EOS_idx else
                self.vocab.index2word[ni.item()] for ni in topk_ids.view(-1)
            ])

            if config.USE_CUDA:
                y_t_1 = topk_ids.squeeze(-1).cuda()  # Teacher forcing
            else:
                y_t_1 = topk_ids.squeeze(-1)

        sent = []
        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>': break
                else: st += e + ' '
            sent.append(st)
        return sent
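
The final_dist produced by this decoder is assumed to follow the usual pointer-generator recipe: mix the vocabulary distribution with the copy (attention) distribution via p_gen and scatter copy probabilities onto the extended vocabulary. A minimal sketch of that combination step, with illustrative tensor names:

import torch

def pointer_generator_dist(vocab_dist, attn_dist, p_gen, enc_batch_extend_vocab, extra_zeros):
    """p_gen * P_vocab plus (1 - p_gen) * copy scores scattered onto the extended vocabulary."""
    vocab_dist = p_gen * vocab_dist              # (batch, vocab)
    attn_dist = (1.0 - p_gen) * attn_dist        # (batch, src_len)
    if extra_zeros is not None:                  # extra slots for source-side OOV words
        vocab_dist = torch.cat([vocab_dist, extra_zeros], dim=1)
    # add copy probability mass at each source token's (extended) vocabulary id
    return vocab_dist.scatter_add(1, enc_batch_extend_vocab, attn_dist)
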
Example #16
    def train_one_batch(self, batch, train=True):
        ## pad and other stuff
        input_ids_batch, input_mask_batch, example_index_batch, enc_batch_extend_vocab, extra_zeros, _ = get_input_from_batch(
            batch)
        dec_batch, dec_mask_batch, dec_index_batch, copy_gate, copy_ptr = get_output_from_batch(
            batch)

        if (config.noam):
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        with torch.no_grad():
            # encoder_outputs are hidden states from transformer
            encoder_outputs, _ = self.encoder(
                input_ids_batch,
                token_type_ids=example_index_batch,
                attention_mask=input_mask_batch,
                output_all_encoded_layers=False)

        # # Draft Decoder
        sos_token = torch.LongTensor([config.SOS_idx] *
                                     input_ids_batch.size(0)).unsqueeze(1)
        if config.USE_CUDA: sos_token = sos_token.cuda(device=0)

        dec_batch_shift = torch.cat(
            (sos_token, dec_batch[:, :-1]),
            1)  # shift the decoder input (summary) by one step
        mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
        pre_logit1, attn_dist1 = self.decoder(self.embedding(dec_batch_shift),
                                              encoder_outputs,
                                              (None, mask_trg))

        # print(pre_logit1.size())
        ## compute output dist
        logit1 = self.generator(pre_logit1,
                                attn_dist1,
                                enc_batch_extend_vocab,
                                extra_zeros,
                                copy_gate=copy_gate,
                                copy_ptr=copy_ptr,
                                mask_trg=mask_trg)
        ## loss: NLL if ptr else Cross entropy
        loss1 = self.criterion(logit1.contiguous().view(-1, logit1.size(-1)),
                               dec_batch.contiguous().view(-1))

        # Refine Decoder - train using gold label TARGET
        # TODO: turn gold-target-text into BERT insertable representation
        pre_logit2, attn_dist2 = self.generate_refinement_output(
            encoder_outputs, dec_batch, dec_index_batch, extra_zeros,
            dec_mask_batch)
        # pre_logit2, attn_dist2 = self.decoder(self.embedding(encoded_gold_target),encoder_outputs, (None,mask_trg))

        logit2 = self.generator(pre_logit2,
                                attn_dist2,
                                enc_batch_extend_vocab,
                                extra_zeros,
                                copy_gate=copy_gate,
                                copy_ptr=copy_ptr,
                                mask_trg=None)
        loss2 = self.criterion(logit2.contiguous().view(-1, logit2.size(-1)),
                               dec_batch.contiguous().view(-1))

        loss = loss1 + loss2

        if train:
            loss.backward()
            self.optimizer.step()
        return loss
Example #17
    def train_one_batch(self, batch, train=True, mode="pretrain", task=0):
        enc_batch, _, enc_lens, enc_batch_extend_vocab, extra_zeros, _, _, _ = get_input_from_batch(
            batch)
        dec_batch, _, _, _, _, _ = get_output_from_batch(batch)

        if (config.noam):
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        ## Encode
        self.h0_encoder, self.c0_encoder = self.get_state(enc_batch)
        src_h, (src_h_t,
                src_c_t) = self.encoder(self.embedding(enc_batch),
                                        (self.h0_encoder, self.c0_encoder))
        h_t = src_h_t[-1]
        c_t = src_c_t[-1]

        # Decode
        decoder_init_state = nn.Tanh()(self.encoder2decoder(h_t))

        sos_token = torch.LongTensor([config.SOS_idx] *
                                     enc_batch.size(0)).unsqueeze(1)
        if config.USE_CUDA: sos_token = sos_token.cuda()
        dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1)
        target_embedding = self.embedding(dec_batch_shift)
        ctx = src_h.transpose(0, 1)
        trg_h, (_, _) = self.decoder(target_embedding,
                                     (decoder_init_state, c_t), ctx)

        #Memory
        mem_h_input = torch.cat(
            (decoder_init_state.unsqueeze(1), trg_h[:, 0:-1, :]), 1)
        mem_input = torch.cat((target_embedding, mem_h_input), 2)
        mem_output = self.memory(mem_input)

        #Combine
        gates = self.dec_gate(trg_h) + self.mem_gate(mem_output)
        decoder_gate, memory_gate = gates.chunk(2, 2)
        decoder_gate = torch.sigmoid(decoder_gate)
        memory_gate = torch.sigmoid(memory_gate)
        pre_logit = torch.tanh(decoder_gate * trg_h + memory_gate * mem_output)
        logit = self.generator(pre_logit)

        if mode == "pretrain":
            loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                                  dec_batch.contiguous().view(-1))
            if train:
                loss.backward()
                self.optimizer.step()
            if (config.label_smoothing):
                loss = self.criterion_ppl(
                    logit.contiguous().view(-1, logit.size(-1)),
                    dec_batch.contiguous().view(-1))
            return loss.item(), math.exp(min(loss.item(), 100)), loss

        elif mode == "select":
            loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                                  dec_batch.contiguous().view(-1))
            if (train):
                l1_loss = 0.0
                for p in self.memory.parameters():
                    l1_loss += torch.sum(torch.abs(p))
                loss += 0.0005 * l1_loss
                loss.backward()
                self.optimizer.step()
                self.compute_hooks(task)
            if (config.label_smoothing):
                loss = self.criterion_ppl(
                    logit.contiguous().view(-1, logit.size(-1)),
                    dec_batch.contiguous().view(-1))
            return loss.item(), math.exp(min(loss.item(), 100)), loss

        else:
            loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                                  dec_batch.contiguous().view(-1))
            if (train):
                self.register_hooks(task)
                loss.backward()
                self.optimizer.step()
                self.unhook(task)
            if (config.label_smoothing):
                loss = self.criterion_ppl(
                    logit.contiguous().view(-1, logit.size(-1)),
                    dec_batch.contiguous().view(-1))
            return loss.item(), math.exp(min(loss.item(), 100)), loss
Example #18
    def train_one_batch(self, batch, n_iter, train=True):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch)

        self.optimizer.zero_grad()

        encoder_outputs, encoder_hidden = self.encoder(enc_batch, enc_lens)
        # sort response for lstm
        r_len = np.array(batch["posterior_lengths"])
        r_sort = r_len.argsort()[::-1]
        r_len = r_len[r_sort].tolist()
        unsort = r_sort.argsort()

        _, encoder_hidden_r = self.encoder_r(
            batch["posterior_batch"][r_sort.tolist()], r_len)
        #encoder_hidden_r = encoder_hidden_r[unsort.tolist()] #unsort

        s_t_1 = encoder_hidden.transpose(0, 1).contiguous().view(
            -1, config.hidden_dim * 2)  #b x hidden_dim*2
        s_t_1_r = encoder_hidden_r.transpose(0, 1).contiguous().view(
            -1,
            config.hidden_dim * 2)[unsort.tolist()]  #unsort #b x hidden_dim*2
        batch_size = enc_batch.size(0)

        #meta = self.embedding(batch["program_label"])
        kld_loss, z = self.latent(s_t_1, s_t_1_r, train=True)
        if config.model == "seq2seq":
            z = z - z
            kld_loss = torch.Tensor([0])

        s_t_1 = torch.cat((z, s_t_1), dim=-1)

        if config.model == "cvae":
            z_logit = self.bow(s_t_1)  # [batch_size, vocab_size]
            z_logit = z_logit.unsqueeze(1).repeat(1, dec_batch.size(1), 1)
            loss_aux = self.criterion(
                z_logit.contiguous().view(-1, z_logit.size(-1)),
                dec_batch.contiguous().view(-1))
        y_t_1 = torch.LongTensor([config.SOS_idx] * batch_size)

        if config.USE_CUDA:
            y_t_1 = y_t_1.cuda()
        step_losses = []
        for di in range(max_dec_len):
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.decoder(
                y_t_1, s_t_1, encoder_outputs, enc_padding_mask, c_t_1,
                extra_zeros, enc_batch_extend_vocab, coverage, di)
            target = target_batch[:, di]
            gold_probs = torch.gather(final_dist, 1,
                                      target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs + config.eps)
            if config.is_coverage:
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage),
                                               1)
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                coverage = next_coverage

            step_mask = dec_padding_mask[:, di]
            step_loss = step_loss * step_mask.cuda()
            step_losses.append(step_loss)
            y_t_1 = dec_batch[:, di]  # Teacher forcing

        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses / dec_lens_var.float().cuda()
        loss_rec = torch.mean(batch_avg_loss)
        if config.model == "cvae":
            kl_weight = min(
                math.tanh(6 * n_iter / config.full_kl_step - 3) + 1, 1)
            #kl_weight = min(n_iter/config.full_kl_step, 0.5) if config.full_kl_step >0 else 1.0
            loss = loss_rec + config.kl_ceiling * kl_weight * kld_loss + loss_aux * config.aux_ceiling
            elbo = loss_rec + kld_loss
        else:
            loss = loss_rec
            loss_aux = torch.Tensor([0])
            elbo = loss_rec
        if (train):
            loss.backward()

            self.optimizer.step()
        return loss_rec.item(), math.exp(
            loss_rec.item()), kld_loss.item(), loss_aux.item(), elbo.item()
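
The tanh-based KL annealing weight used above can be factored into a small helper; the schedule below is the same expression as in the code and assumes config.full_kl_step > 0:

import math

def kl_anneal_weight(step, full_kl_step):
    """Smooth schedule: close to 0 early in training, saturating at 1 as step approaches full_kl_step."""
    return min(math.tanh(6 * step / full_kl_step - 3) + 1, 1)
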
Example #19
    def train_one_batch(self, batch, iter, train=True):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)
        dec_batch, _, _, _, _ = get_output_from_batch(batch)

        if (config.noam):
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        ## Response encode
        mask_res = batch["posterior_batch"].data.eq(
            config.PAD_idx).unsqueeze(1)
        posterior_mask = self.embedding(batch["posterior_mask"])
        r_encoder_outputs = self.r_encoder(
            self.embedding(batch["posterior_batch"]) + posterior_mask,
            mask_res)
        ## Encode
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        emb_mask = self.embedding(batch["input_mask"])
        encoder_outputs = self.encoder(
            self.embedding(enc_batch) + emb_mask, mask_src)
        #latent variable
        if config.model == "cvaetrs":
            kld_loss, z = self.latent_layer(encoder_outputs[:, 0],
                                            r_encoder_outputs[:, 0],
                                            train=True)

        meta = self.embedding(batch["program_label"])
        if config.dataset == "empathetic":
            meta = meta - meta
        # Decode
        sos_token = torch.LongTensor([config.SOS_idx] *
                                     enc_batch.size(0)).unsqueeze(1)
        if config.USE_CUDA: sos_token = sos_token.cuda()

        dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]),
                                    1)  #(batch, len, embedding)
        mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
        input_vector = self.embedding(dec_batch_shift)
        if config.model == "cvaetrs":
            input_vector[:, 0] = input_vector[:, 0] + z + meta
        else:
            input_vector[:, 0] = input_vector[:, 0] + meta
        pre_logit, attn_dist = self.decoder(input_vector, encoder_outputs,
                                            (mask_src, mask_trg))

        ## compute output dist
        logit = self.generator(
            pre_logit,
            attn_dist,
            enc_batch_extend_vocab if config.pointer_gen else None,
            extra_zeros,
            attn_dist_db=None)
        ## loss: NLL if ptr else Cross entropy
        loss_rec = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                                  dec_batch.contiguous().view(-1))
        if config.model == "cvaetrs":
            z_logit = self.bow(z + meta)  # [batch_size, vocab_size]
            z_logit = z_logit.unsqueeze(1).repeat(1, logit.size(1), 1)
            loss_aux = self.criterion(
                z_logit.contiguous().view(-1, z_logit.size(-1)),
                dec_batch.contiguous().view(-1))
            if config.multitask:
                emo_logit = self.emo(encoder_outputs[:, 0])
                emo_loss = self.emo_criterion(emo_logit,
                                              batch["program_label"] - 9)
            #kl_weight = min(iter/config.full_kl_step, 0.28) if config.full_kl_step >0 else 1.0
            kl_weight = min(
                math.tanh(6 * iter / config.full_kl_step - 3) + 1, 1)
            loss = loss_rec + config.kl_ceiling * kl_weight * kld_loss + config.aux_ceiling * loss_aux
            if config.multitask:
                loss = loss_rec + config.kl_ceiling * kl_weight * kld_loss + config.aux_ceiling * loss_aux + emo_loss
            aux = loss_aux.item()
            elbo = loss_rec + kld_loss
        else:
            loss = loss_rec
            elbo = loss_rec
            kld_loss = torch.Tensor([0])
            aux = 0
            if config.multitask:
                emo_logit = self.emo(encoder_outputs[:, 0])
                emo_loss = self.emo_criterion(emo_logit,
                                              batch["program_label"] - 9)
                loss = loss_rec + emo_loss
        if (train):
            loss.backward()
            # clip gradient
            nn.utils.clip_grad_norm_(self.parameters(), config.max_grad_norm)
            self.optimizer.step()

        return loss_rec.item(), math.exp(min(
            loss_rec.item(), 100)), kld_loss.item(), aux, elbo.item()
    def train_one_batch(self, batch, iter, train=True):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)
        dec_batch, _, _, _, _ = get_output_from_batch(batch)

        if (config.noam):
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        ## Response encode
        mask_res = batch["posterior_batch"].data.eq(
            config.PAD_idx).unsqueeze(1)
        post_emb = self.embedding(batch["posterior_batch"])
        r_encoder_outputs = self.r_encoder(post_emb, mask_res)

        ## Encode
        num_sentences, enc_seq_len = enc_batch.size()
        batch_size = enc_lens.size(0)
        max_len = enc_lens.data.max().item()
        input_lengths = torch.sum(~enc_batch.data.eq(config.PAD_idx), dim=1)

        # word level encoder
        enc_emb = self.embedding(enc_batch)
        word_encoder_outputs, word_encoder_hidden = self.word_encoder(
            enc_emb, input_lengths)
        word_encoder_hidden = word_encoder_hidden.transpose(1, 0).reshape(
            num_sentences, -1)

        # pad and pack word_encoder_hidden
        start = torch.cumsum(
            torch.cat((enc_lens.data.new(1).zero_(), enc_lens[:-1])), 0)
        word_encoder_hidden = torch.stack([
            pad(word_encoder_hidden.narrow(0, s, l), max_len)
            for s, l in zip(start.data.tolist(), enc_lens.data.tolist())
        ], 0)

        # mask_src = ~(enc_padding_mask.bool()).unsqueeze(1)
        mask_src = (1 - enc_padding_mask.byte()).unsqueeze(1)

        # context level encoder
        if word_encoder_hidden.size(-1) != config.hidden_dim:
            word_encoder_hidden = self.linear(word_encoder_hidden)
        encoder_outputs = self.encoder(word_encoder_hidden, mask_src)

        #latent variable
        if config.model == "cvaetrs":
            kld_loss, z = self.latent_layer(encoder_outputs[:, 0],
                                            r_encoder_outputs[:, 0],
                                            train=True)

        # Decode
        sos_token = torch.LongTensor([config.SOS_idx] *
                                     batch_size).unsqueeze(1)
        if config.USE_CUDA: sos_token = sos_token.cuda()

        dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]),
                                    1)  #(batch, len, embedding)
        mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
        input_vector = self.embedding(dec_batch_shift)
        if config.model == "cvaetrs":
            input_vector[:, 0] = input_vector[:, 0] + z
        else:
            input_vector[:, 0] = input_vector[:, 0]
        pre_logit, attn_dist = self.decoder(input_vector, encoder_outputs,
                                            (mask_src, mask_trg))

        ## compute output dist
        logit = self.generator(
            pre_logit,
            attn_dist,
            enc_batch_extend_vocab if config.pointer_gen else None,
            extra_zeros,
            attn_dist_db=None)
        ## loss: NLL if ptr else Cross entropy
        loss_rec = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                                  dec_batch.contiguous().view(-1))
        if config.model == "cvaetrs":
            z_logit = self.bow(z)  # [batch_size, vocab_size]
            z_logit = z_logit.unsqueeze(1).repeat(1, logit.size(1), 1)
            loss_aux = self.criterion(
                z_logit.contiguous().view(-1, z_logit.size(-1)),
                dec_batch.contiguous().view(-1))

            #kl_weight = min(iter/config.full_kl_step, 0.28) if config.full_kl_step >0 else 1.0
            kl_weight = min(
                math.tanh(6 * iter / config.full_kl_step - 3) + 1, 1)
            loss = loss_rec + config.kl_ceiling * kl_weight * kld_loss + config.aux_ceiling * loss_aux

            aux = loss_aux.item()
            elbo = loss_rec + kld_loss
        else:
            loss = loss_rec
            elbo = loss_rec
            kld_loss = torch.Tensor([0])
            aux = 0
            if config.multitask:
                emo_logit = self.emo(encoder_outputs[:, 0])
                emo_loss = self.emo_criterion(emo_logit,
                                              batch["program_label"] - 9)
                loss = loss_rec + emo_loss
        if (train):
            loss.backward()
            # clip gradient
            nn.utils.clip_grad_norm_(self.parameters(), config.max_grad_norm)
            self.optimizer.step()

        return loss_rec.item(), math.exp(min(
            loss_rec.item(), 100)), kld_loss.item(), aux, elbo.item()
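
pad, used above when re-grouping the word-level encoder states per dialogue, is assumed to right-pad a variable-length tensor with zeros along its first dimension. A minimal sketch:

import torch

def pad(tensor, length):
    """Right-pad tensor with zeros along dim 0 so that it has exactly `length` rows."""
    if tensor.size(0) >= length:
        return tensor
    padding = tensor.new_zeros(length - tensor.size(0), *tensor.size()[1:])
    return torch.cat([tensor, padding], dim=0)
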
Example #21
    def train_one_batch(self, batch, iter, train=True):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)
        dec_batch, _, _, _, _ = get_output_from_batch(batch)

        if (config.noam):
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()
        ## Encode
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        if config.dataset == "empathetic":
            emb_mask = self.embedding(batch["mask_input"])
            encoder_outputs = self.encoder(
                self.embedding(enc_batch) + emb_mask, mask_src)
        else:
            encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)
        ## Attention over decoder
        q_h = torch.mean(encoder_outputs,
                         dim=1) if config.mean_query else encoder_outputs[:, 0]
        #q_h = encoder_outputs[:,0]
        logit_prob = self.decoder_key(q_h)  #(bsz, num_experts)

        if (config.topk > 0):
            k_max_value, k_max_index = torch.topk(logit_prob, config.topk)
            a = np.empty([logit_prob.shape[0], self.decoder_number])
            a.fill(float('-inf'))
            mask = torch.Tensor(a).cuda()
            logit_prob_ = mask.scatter_(1,
                                        k_max_index.cuda().long(), k_max_value)
            attention_parameters = self.attention_activation(logit_prob_)
        else:
            attention_parameters = self.attention_activation(logit_prob)
        # print("===============================================================================")
        # print("listener attention weight:",attention_parameters.data.cpu().numpy())
        # print("===============================================================================")
        if (config.oracle):
            attention_parameters = self.attention_activation(
                torch.FloatTensor(batch['target_program']) * 1000).cuda()
        attention_parameters = attention_parameters.unsqueeze(-1).unsqueeze(
            -1)  # (batch_size, expert_num, 1, 1)

        # Decode
        sos_token = torch.LongTensor([config.SOS_idx] *
                                     enc_batch.size(0)).unsqueeze(1)
        if config.USE_CUDA: sos_token = sos_token.cuda()
        dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1)

        mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)

        pre_logit, attn_dist = self.decoder(self.embedding(dec_batch_shift),
                                            encoder_outputs,
                                            (mask_src, mask_trg),
                                            attention_parameters)
        ## compute output dist
        logit = self.generator(
            pre_logit,
            attn_dist,
            enc_batch_extend_vocab if config.pointer_gen else None,
            extra_zeros,
            attn_dist_db=None)
        #logit = F.log_softmax(logit,dim=-1) #fix the name later
        ## loss: NLL if ptr else Cross entropy
        if (train and config.schedule > 10):
            if (random.uniform(0, 1) <=
                (0.0001 +
                 (1 - 0.0001) * math.exp(-1. * iter / config.schedule))):
                config.oracle = True
            else:
                config.oracle = False

        if config.softmax:
            loss = self.criterion(
                logit.contiguous().view(-1, logit.size(-1)),
                dec_batch.contiguous().view(-1)) + nn.CrossEntropyLoss()(
                    logit_prob, torch.LongTensor(
                        batch['program_label']).cuda())
            loss_bce_program = nn.CrossEntropyLoss()(
                logit_prob,
                torch.LongTensor(batch['program_label']).cuda()).item()
        else:
            loss = self.criterion(
                logit.contiguous().view(-1, logit.size(-1)),
                dec_batch.contiguous().view(-1)) + nn.BCEWithLogitsLoss()(
                    logit_prob, torch.FloatTensor(
                        batch['target_program']).cuda())
            loss_bce_program = nn.BCEWithLogitsLoss()(
                logit_prob,
                torch.FloatTensor(batch['target_program']).cuda()).item()
        pred_program = np.argmax(logit_prob.detach().cpu().numpy(), axis=1)
        program_acc = accuracy_score(batch["program_label"], pred_program)

        if (config.label_smoothing):
            loss_ppl = self.criterion_ppl(
                logit.contiguous().view(-1, logit.size(-1)),
                dec_batch.contiguous().view(-1)).item()

        if (train):
            loss.backward()
            self.optimizer.step()

        if (config.label_smoothing):
            return loss_ppl, math.exp(min(loss_ppl,
                                          100)), loss_bce_program, program_acc
        else:
            return loss.item(), math.exp(min(
                loss.item(), 100)), loss_bce_program, program_acc
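
The top-k expert gating above builds a dense -inf matrix with NumPy, moves it to the GPU, and scatters the k best logits back before applying the attention activation. A compact all-PyTorch equivalent, sketched under the same assumptions:

import torch

def topk_expert_weights(logit_prob, k, attention_activation):
    """Keep only the k largest expert logits, mask the rest to -inf, then normalize."""
    k_max_value, k_max_index = torch.topk(logit_prob, k)
    masked = logit_prob.new_full(logit_prob.size(), float('-inf'))
    masked = masked.scatter(1, k_max_index, k_max_value)
    return attention_activation(masked)
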
Example #22
    def decoder_topk(self, batch, max_dec_step=30):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        emb_mask = self.embedding(batch["mask_input"])
        encoder_outputs = self.encoder(
            self.embedding(enc_batch) + emb_mask, mask_src)

        ## Attention over decoder
        q_h = torch.mean(encoder_outputs,
                         dim=1) if config.mean_query else encoder_outputs[:, 0]
        #q_h = encoder_outputs[:,0]
        logit_prob = self.decoder_key(q_h)

        if (config.topk > 0):
            k_max_value, k_max_index = torch.topk(logit_prob, config.topk)
            a = np.empty([logit_prob.shape[0], self.decoder_number])
            a.fill(float('-inf'))
            mask = torch.Tensor(a).cuda()
            logit_prob = mask.scatter_(1,
                                       k_max_index.cuda().long(), k_max_value)

        attention_parameters = self.attention_activation(logit_prob)

        if (config.oracle):
            attention_parameters = self.attention_activation(
                torch.FloatTensor(batch['target_program']) * 1000).cuda()
        attention_parameters = attention_parameters.unsqueeze(-1).unsqueeze(
            -1)  # (batch_size, expert_num, 1, 1)

        ys = torch.ones(1, 1).fill_(config.SOS_idx).long()
        if config.USE_CUDA:
            ys = ys.cuda()
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        decoded_words = []
        for i in range(max_dec_step + 1):
            if (config.project):
                out, attn_dist = self.decoder(
                    self.embedding_proj_in(self.embedding(ys)),
                    self.embedding_proj_in(encoder_outputs),
                    (mask_src, mask_trg), attention_parameters)
            else:
                out, attn_dist = self.decoder(self.embedding(ys),
                                              encoder_outputs,
                                              (mask_src, mask_trg),
                                              attention_parameters)

            logit = self.generator(out,
                                   attn_dist,
                                   enc_batch_extend_vocab,
                                   extra_zeros,
                                   attn_dist_db=None)
            filtered_logit = top_k_top_p_filtering(logit[:, -1],
                                                   top_k=3,
                                                   top_p=0,
                                                   filter_value=-float('Inf'))
            # Sample from the filtered distribution
            next_word = torch.multinomial(F.softmax(filtered_logit, dim=-1),
                                          1).squeeze()
            decoded_words.append([
                '<EOS>' if ni.item() == config.EOS_idx else
                self.vocab.index2word[ni.item()] for ni in next_word.view(-1)
            ])
            next_word = next_word.data[0]
            if config.USE_CUDA:
                ys = torch.cat(
                    [ys, torch.ones(1, 1).long().fill_(next_word).cuda()],
                    dim=1)
                ys = ys.cuda()
            else:
                ys = torch.cat(
                    [ys, torch.ones(1, 1).long().fill_(next_word)], dim=1)
            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

        sent = []
        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>': break
                else: st += e + ' '
            sent.append(st)
        return sent
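
The sampling decoder above calls a `top_k_top_p_filtering` helper that is not defined in these examples. The sketch below shows the usual top-k / nucleus filtering such a helper performs, with a signature matching the calls above; it is an assumption about the helper, not the repository's exact implementation:

import torch
import torch.nn.functional as F

def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    # Keep only the top_k highest-scoring tokens and/or the smallest set of
    # tokens whose cumulative probability exceeds top_p; everything else is
    # pushed to filter_value so softmax assigns it zero probability.
    if top_k > 0:
        top_k = min(top_k, logits.size(-1))
        kth_value = torch.topk(logits, top_k)[0][..., -1, None]
        logits = logits.masked_fill(logits < kth_value, filter_value)
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_mask = cumulative_probs > top_p
        sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()  # always keep the top token
        sorted_mask[..., 0] = False
        mask = sorted_mask.scatter(-1, sorted_indices, sorted_mask)
        logits = logits.masked_fill(mask, filter_value)
    return logits

# The calls above use top_k=3 and top_p=0, so only the three highest-scoring
# tokens per step survive before F.softmax and torch.multinomial.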
    def decoder_greedy(self, batch, max_dec_step=50):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)
        
        ## Encode
        num_sentences, enc_seq_len = enc_batch.size()
        batch_size = enc_lens.size(0)
        max_len = enc_lens.data.max().item()
        input_lengths = torch.sum(~enc_batch.data.eq(config.PAD_idx), dim=1)
        
        # word level encoder
        enc_emb = self.embedding(enc_batch)
        word_encoder_outputs, word_encoder_hidden = self.word_encoder(enc_emb, input_lengths)
        word_encoder_hidden = word_encoder_hidden.transpose(1, 0).reshape(num_sentences, -1)

        # pad and pack word_encoder_hidden
        start = torch.cumsum(torch.cat((enc_lens.data.new(1).zero_(), enc_lens[:-1])), 0)
        word_encoder_hidden = torch.stack([pad(word_encoder_hidden.narrow(0, s, l), max_len)
                                            for s, l in zip(start.data.tolist(), enc_lens.data.tolist())], 0)
        mask_src = ~(enc_padding_mask.bool()).unsqueeze(1)

        # context level encoder
        if word_encoder_hidden.size(-1) != config.hidden_dim:
            word_encoder_hidden = self.linear(word_encoder_hidden)
        encoder_outputs = self.encoder(word_encoder_hidden, mask_src)
        
        ys = torch.ones(batch_size, 1).fill_(config.SOS_idx).long()
        if config.USE_CUDA:
            ys = ys.cuda()
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        
        decoded_words = []
        for i in range(max_dec_step+1):

            out, attn_dist, _, _, _ = self.decoder(self.embedding(ys), encoder_outputs, None, (mask_src, None, mask_trg))

            prob = self.generator(out, attn_dist, enc_batch_extend_vocab, extra_zeros, attn_dist_db=None)
            _, next_word = torch.max(prob[:, -1], dim=1)
            decoded_words.append(['<EOS>' if ni.item() == config.EOS_idx else self.vocab.index2word[ni.item()] for ni in next_word.view(-1)])

            ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)
            if config.USE_CUDA:
                ys = ys.cuda()
            
            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        sent = []
        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>': break
                else: st += e + ' '
            sent.append(st)
        return sent
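
The hierarchical encoder above flattens every sentence in the batch, encodes each word-by-word, and then regroups the per-sentence hidden states into a (batch, max_sentences, hidden) tensor using cumulative offsets and a `pad` helper that is not shown. A minimal sketch of that regrouping step, assuming `pad` right-pads along the sentence dimension; the function name here is illustrative:

import torch
import torch.nn.functional as F

def regroup_sentences(flat_hidden, sents_per_dialogue):
    # flat_hidden: (total_sentences, hidden), dialogues concatenated back to back.
    # sents_per_dialogue: (batch,) LongTensor of sentence counts per dialogue.
    max_len = int(sents_per_dialogue.max().item())
    start = torch.cumsum(
        torch.cat((sents_per_dialogue.new_zeros(1), sents_per_dialogue[:-1])), 0)
    rows = []
    for s, l in zip(start.tolist(), sents_per_dialogue.tolist()):
        chunk = flat_hidden.narrow(0, s, l)                 # (l, hidden)
        rows.append(F.pad(chunk, (0, 0, 0, max_len - l)))   # zero-pad to max_len rows
    return torch.stack(rows, 0)                             # (batch, max_len, hidden)

# With sentence counts [2, 3], a (5, hidden) input becomes (2, 3, hidden),
# where the first dialogue's last row is zero padding.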
Example #24
    def train_one_batch(self, batch, iter, train=True):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)
        dec_batch, _, _, _, _ = get_output_from_batch(batch)

        if (config.noam):
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()
        ## Encode
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)

        if config.dataset == "empathetic":
            emb_mask = self.embedding(batch["mask_input"])
            encoder_outputs = self.encoder(
                self.embedding(enc_batch) + emb_mask, mask_src)
        else:
            encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)

        q_h = torch.mean(encoder_outputs,
                         dim=1) if config.mean_query else encoder_outputs[:, 0]
        # q_h = torch.max(encoder_outputs, dim=1)
        emotions_mimic, emotions_non_mimic, mu_positive_prior, logvar_positive_prior, mu_negative_prior, logvar_negative_prior = \
            self.vae_sampler(q_h, batch['program_label'], self.emoji_embedding)
        # KLLoss = -0.5 * (torch.sum(1 + logvar_n - mu_n.pow(2) - logvar_n.exp()) + torch.sum(1 + logvar_p - mu_p.pow(2) - logvar_p.exp()))
        m_out = self.emotion_input_encoder_1(emotions_mimic.unsqueeze(1),
                                             encoder_outputs, mask_src)
        m_tilde_out = self.emotion_input_encoder_2(
            emotions_non_mimic.unsqueeze(1), encoder_outputs, mask_src)
        if train:
            emotions_mimic, emotions_non_mimic, mu_positive_posterior, logvar_positive_posterior, mu_negative_posterior, logvar_negative_posterior = \
                self.vae_sampler.forward_train(q_h, batch['program_label'], self.emoji_embedding, M_out=m_out.mean(dim=1), M_tilde_out=m_tilde_out.mean(dim=1))
            KLLoss_positive = self.vae_sampler.kl_div(
                mu_positive_posterior, logvar_positive_posterior,
                mu_positive_prior, logvar_positive_prior)
            KLLoss_negative = self.vae_sampler.kl_div(
                mu_negative_posterior, logvar_negative_posterior,
                mu_negative_prior, logvar_negative_prior)
            KLLoss = KLLoss_positive + KLLoss_negative
        else:
            KLLoss_positive = self.vae_sampler.kl_div(mu_positive_prior,
                                                      logvar_positive_prior)
            KLLoss_negative = self.vae_sampler.kl_div(mu_negative_prior,
                                                      logvar_negative_prior)
            KLLoss = KLLoss_positive + KLLoss_negative

        if config.emo_combine == "att":
            v = self.cdecoder(encoder_outputs, m_out, m_tilde_out, mask_src)
        elif config.emo_combine == "gate":
            v = self.cdecoder(m_out, m_tilde_out)

        x = self.s_weight(q_h)

        # method2: E (W@c)
        logit_prob = torch.matmul(x, self.emoji_embedding.weight.transpose(
            0, 1))  # shape (b_size, 32)

        # Decode
        sos_token = torch.LongTensor([config.SOS_idx] *
                                     enc_batch.size(0)).unsqueeze(1)
        if config.USE_CUDA: sos_token = sos_token.cuda()
        dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1)

        mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)

        pre_logit, attn_dist = self.decoder(self.embedding(dec_batch_shift), v,
                                            v, (mask_src, mask_trg))

        ## compute output dist
        logit = self.generator(
            pre_logit,
            attn_dist,
            enc_batch_extend_vocab if config.pointer_gen else None,
            extra_zeros,
            attn_dist_db=None)

        if (train and config.schedule > 10):
            if (random.uniform(0, 1) <=
                (0.0001 +
                 (1 - 0.0001) * math.exp(-1. * iter / config.schedule))):
                config.oracle = True
            else:
                config.oracle = False

        if config.softmax:
            program_label = torch.LongTensor(batch['program_label'])
            if config.USE_CUDA: program_label = program_label.cuda()

            if config.emo_combine == 'gate':
                L1_loss = nn.CrossEntropyLoss()(logit_prob, program_label)
                loss = self.criterion(
                    logit.contiguous().view(-1, logit.size(-1)),
                    dec_batch.contiguous().view(-1)) + KLLoss + L1_loss
            else:
                L1_loss = nn.CrossEntropyLoss()(logit_prob, program_label)
                loss = self.criterion(
                    logit.contiguous().view(-1, logit.size(-1)),
                    dec_batch.contiguous().view(-1)) + KLLoss + L1_loss

            loss_bce_program = nn.CrossEntropyLoss()(logit_prob,
                                                     program_label).item()
        else:
            loss = self.criterion(
                logit.contiguous().view(-1, logit.size(-1)),
                dec_batch.contiguous().view(-1)) + nn.BCEWithLogitsLoss()(
                    logit_prob, torch.FloatTensor(
                        batch['target_program']).cuda())
            loss_bce_program = nn.BCEWithLogitsLoss()(
                logit_prob,
                torch.FloatTensor(batch['target_program']).cuda()).item()
        pred_program = np.argmax(logit_prob.detach().cpu().numpy(), axis=1)
        program_acc = accuracy_score(batch["program_label"], pred_program)

        if (config.label_smoothing):
            loss_ppl = self.criterion_ppl(
                logit.contiguous().view(-1, logit.size(-1)),
                dec_batch.contiguous().view(-1)).item()

        if (train):
            loss.backward()
            self.optimizer.step()

        if (config.label_smoothing):
            return loss_ppl, math.exp(min(loss_ppl,
                                          100)), loss_bce_program, program_acc
        else:
            return loss.item(), math.exp(min(
                loss.item(), 100)), loss_bce_program, program_acc
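
The KL terms above come from `self.vae_sampler.kl_div`, whose reduction is not shown in this example. The sketch below uses the standard closed form for the KL divergence between two diagonal Gaussians, falling back to a standard-normal prior when only one distribution is given (as in the evaluation branch); the name and the batch-mean reduction are assumptions:

import torch

def gaussian_kl(mu_q, logvar_q, mu_p=None, logvar_p=None):
    # KL( N(mu_q, exp(logvar_q)) || N(mu_p, exp(logvar_p)) ) for diagonal
    # Gaussians, summed over the latent dimension and averaged over the batch.
    if mu_p is None:
        mu_p = torch.zeros_like(mu_q)
    if logvar_p is None:
        logvar_p = torch.zeros_like(logvar_q)
    kl = 0.5 * (logvar_p - logvar_q
                + (logvar_q.exp() + (mu_q - mu_p).pow(2)) / logvar_p.exp()
                - 1)
    return kl.sum(dim=-1).mean()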
    def train_one_batch(self, batch, iter, train=True):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)
        dec_batch, _, _, _, _ = get_output_from_batch(batch)
        
        if(config.noam):
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        ## Response encode
        mask_res = batch["posterior_batch"].data.eq(config.PAD_idx).unsqueeze(1)
        post_emb = self.embedding(batch["posterior_batch"])
        r_encoder_outputs = self.r_encoder(post_emb, mask_res)

        ## Encode
        num_sentences, enc_seq_len = enc_batch.size()
        batch_size = enc_lens.size(0)
        max_len = enc_lens.data.max().item()
        input_lengths = torch.sum(~enc_batch.data.eq(config.PAD_idx), dim=1)
        
        # word level encoder
        enc_emb = self.embedding(enc_batch)
        word_encoder_outputs, word_encoder_hidden = self.word_encoder(enc_emb, input_lengths)
        word_encoder_hidden = word_encoder_hidden.transpose(1, 0).reshape(num_sentences, -1)

        # pad and pack word_encoder_hidden
        start = torch.cumsum(torch.cat((enc_lens.data.new(1).zero_(), enc_lens[:-1])), 0)
        word_encoder_hidden = torch.stack([pad(word_encoder_hidden.narrow(0, s, l), max_len)
                                            for s, l in zip(start.data.tolist(), enc_lens.data.tolist())], 0)
        
        # mask_src = ~(enc_padding_mask.bool()).unsqueeze(1)
        mask_src = (1 - enc_padding_mask.byte()).unsqueeze(1)
        
        # context level encoder
        if word_encoder_hidden.size(-1) != config.hidden_dim:
            word_encoder_hidden = self.linear(word_encoder_hidden)
        encoder_outputs = self.encoder(word_encoder_hidden, mask_src)

        # Decode
        sos_token = torch.LongTensor([config.SOS_idx] * batch_size).unsqueeze(1)
        if config.USE_CUDA: sos_token = sos_token.cuda()

        dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1) #(batch, len, embedding)
        mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
        dec_emb = self.embedding(dec_batch_shift)

        pre_logit, attn_dist, mean, log_var, probs = self.decoder(dec_emb, encoder_outputs, r_encoder_outputs, 
                                                                    (mask_src, mask_res, mask_trg))
        
        ## compute output dist
        logit = self.generator(pre_logit, attn_dist, enc_batch_extend_vocab if config.pointer_gen else None, extra_zeros, attn_dist_db=None)
        ## loss: NLL if ptr else Cross entropy
        sbow = dec_batch  # [batch, seq_len]
        seq_len = sbow.size(1)
        loss_rec = self.criterion(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1))
        if config.model == "cvaetrs":
            loss_aux = 0
            for prob in probs:
                sbow_mask = _get_attn_subsequent_mask(seq_len).transpose(1, 2)
                # Note: the masked copy below is never assigned, so the auxiliary
                # targets passed to the criterion remain the raw tokens in `sbow`.
                sbow.unsqueeze(2).repeat(1, 1, seq_len).masked_fill_(sbow_mask, config.PAD_idx)  # [batch, seq_len, seq_len]
                loss_aux += self.criterion(prob.contiguous().view(-1, prob.size(-1)), sbow.contiguous().view(-1))
            kld_loss = gaussian_kld(mean["posterior"], log_var["posterior"], mean["prior"], log_var["prior"])
            kld_loss = torch.mean(kld_loss)
            kl_weight = min(math.tanh(6 * iter / config.full_kl_step - 3) + 1, 1)
            # kl_weight = min(iter/config.full_kl_step, 1) if config.full_kl_step > 0 else 1.0
            loss = loss_rec + config.kl_ceiling * kl_weight * kld_loss + config.aux_ceiling * loss_aux
            elbo = loss_rec + kld_loss
        else:
            loss = loss_rec
            elbo = loss_rec
            kld_loss = torch.Tensor([0])
            loss_aux = torch.Tensor([0])
        if(train):
            loss.backward()
            # clip gradient
            nn.utils.clip_grad_norm_(self.parameters(), config.max_grad_norm)
            self.optimizer.step()

        return loss_rec.item(), math.exp(min(loss_rec.item(), 100)), kld_loss.item(), loss_aux.item(), elbo.item()
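
The cvaetrs branch above anneals the KL term with a tanh ramp rather than the linear schedule left in the comment. A standalone sketch of that weight; the function name is illustrative:

import math

def kl_weight(step, full_kl_step):
    # tanh ramp: near 0 at the start of training, reaches 1 at the midpoint
    # of full_kl_step, and is clipped at 1 from then on.
    return min(math.tanh(6 * step / full_kl_step - 3) + 1, 1)

# For full_kl_step = 10000:
#   kl_weight(0, 10000)     ~ 0.005
#   kl_weight(5000, 10000)  == 1.0   (tanh(0) + 1)
#   kl_weight(10000, 10000) == 1.0   (clipped)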
Example #26
    def decoder_greedy(self, batch):
        input_ids_batch, input_mask_batch, example_index_batch, enc_batch_extend_vocab, extra_zeros, _ = get_input_from_batch(
            batch)
        # mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        with torch.no_grad():
            encoder_outputs, _ = self.encoder(
                input_ids_batch,
                token_type_ids=enc_batch_extend_vocab,
                attention_mask=input_mask_batch,
                output_all_encoded_layers=False)

        ys = torch.ones(1, 1).fill_(config.SOS_idx).long()
        if config.USE_CUDA: ys = ys.cuda()
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        decoded_words = []
        for i in range(config.max_dec_step):
            out, attn_dist = self.decoder(self.embedding(ys), encoder_outputs,
                                          (None, mask_trg))
            prob = self.generator(out, attn_dist, enc_batch_extend_vocab,
                                  extra_zeros)
            _, next_word = torch.max(prob[:, -1], dim=1)

            decoded_words.append(
                self.tokenizer.convert_ids_to_tokens(next_word.tolist()))

            next_word = next_word.data[0]
            if config.USE_CUDA:
                ys = torch.cat(
                    [ys, torch.ones(1, 1).long().fill_(next_word).cuda()],
                    dim=1)
                ys = ys.cuda()
            else:
                ys = torch.cat(
                    [ys, torch.ones(1, 1).long().fill_(next_word)], dim=1)
            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

        sent = []
        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>' or e == '<PAD>': break
                else: st += e + ' '
            sent.append(st)
        return sent
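
Every decoder in these examples finishes with the same post-processing loop: `decoded_words` is step-major, so it is transposed to example-major and joined up to the first stop token. A minimal standalone version of that step; the helper name is illustrative, and ' '.join also drops the trailing space the in-line loops leave:

import numpy as np

def words_to_sentences(decoded_words, stop_tokens=('<EOS>', '<PAD>')):
    # decoded_words[t][b] is the token generated for example b at step t.
    sentences = []
    for row in np.transpose(decoded_words):
        tokens = []
        for token in row:
            if token in stop_tokens:
                break
            tokens.append(token)
        sentences.append(' '.join(tokens))
    return sentences

# words_to_sentences([['i', 'hello'], ['am', '<EOS>'], ['<EOS>', 'x']])
# -> ['i am', 'hello']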
Example #27
    def decoder_topk(self,
                     batch,
                     max_dec_step=30,
                     emotion_classifier='built_in'):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)

        emotions = batch['program_label']

        context_emo = [
            self.positive_emotions[0]
            if d['compound'] > 0 else self.negative_emotions[0]
            for d in batch['context_emotion_scores']
        ]
        context_emo = torch.Tensor(context_emo)
        if config.USE_CUDA:
            context_emo = context_emo.cuda()

        ## Encode
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)

        emb_mask = self.embedding(batch["mask_input"])
        encoder_outputs = self.encoder(
            self.embedding(enc_batch) + emb_mask, mask_src)

        q_h = torch.mean(encoder_outputs,
                         dim=1) if config.mean_query else encoder_outputs[:, 0]

        x = self.s_weight(q_h)
        # method 2
        logit_prob = torch.matmul(x,
                                  self.emoji_embedding.weight.transpose(0, 1))

        if emotion_classifier == "vader":
            context_emo = [
                self.positive_emotions[0]
                if d['compound'] > 0 else self.negative_emotions[0]
                for d in batch['context_emotion_scores']
            ]
            context_emo = torch.Tensor(context_emo)
            if config.USE_CUDA:
                context_emo = context_emo.cuda()
            emotions_mimic, emotions_non_mimic, mu_p, logvar_p, mu_n, logvar_n = self.vae_sampler(
                q_h, context_emo, self.emoji_embedding)
        elif emotion_classifier is None:
            emotions_mimic, emotions_non_mimic, mu_p, logvar_p, mu_n, logvar_n = self.vae_sampler(
                q_h, batch['program_label'], self.emoji_embedding)
        elif emotion_classifier == "built_in":
            emo_pred = torch.argmax(logit_prob, dim=-1)
            emotions_mimic, emotions_non_mimic, mu_p, logvar_p, mu_n, logvar_n = self.vae_sampler(
                q_h, emo_pred, self.emoji_embedding)

        m_out = self.emotion_input_encoder_1(emotions_mimic.unsqueeze(1),
                                             encoder_outputs, mask_src)
        m_tilde_out = self.emotion_input_encoder_2(
            emotions_non_mimic.unsqueeze(1), encoder_outputs, mask_src)

        if config.emo_combine == "att":
            v = self.cdecoder(encoder_outputs, m_out, m_tilde_out, mask_src)
        elif config.emo_combine == "gate":
            v = self.cdecoder(m_out, m_tilde_out)
        elif config.emo_combine == 'vader':
            # blend the mimicked and non-mimicked states with the VADER-based weights
            m_weight = context_emo_scores.unsqueeze(-1).unsqueeze(-1)
            m_tilde_weight = 1 - m_weight
            v = m_weight * m_out + m_tilde_weight * m_tilde_out

        ys = torch.ones(1, 1).fill_(config.SOS_idx).long()
        if config.USE_CUDA:
            ys = ys.cuda()
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        decoded_words = []
        for i in range(max_dec_step + 1):
            if (config.project):
                # Note: `attention_parameters` is not defined in this method,
                # so this branch raises a NameError if config.project is enabled.
                out, attn_dist = self.decoder(
                    self.embedding_proj_in(self.embedding(ys)),
                    self.embedding_proj_in(encoder_outputs),
                    (mask_src, mask_trg), attention_parameters)
            else:
                out, attn_dist = self.decoder(self.embedding(ys), v, v,
                                              (mask_src, mask_trg))

            logit = self.generator(out,
                                   attn_dist,
                                   enc_batch_extend_vocab,
                                   extra_zeros,
                                   attn_dist_db=None)
            filtered_logit = top_k_top_p_filtering(logit[:, -1],
                                                   top_k=3,
                                                   top_p=0,
                                                   filter_value=-float('Inf'))
            # Sample from the filtered distribution
            next_word = torch.multinomial(F.softmax(filtered_logit, dim=-1),
                                          1).squeeze()
            decoded_words.append([
                '<EOS>' if ni.item() == config.EOS_idx else
                self.vocab.index2word[ni.item()] for ni in next_word.view(-1)
            ])
            next_word = next_word.data.item()

            if config.USE_CUDA:
                ys = torch.cat(
                    [ys, torch.ones(1, 1).long().fill_(next_word).cuda()],
                    dim=1)
                ys = ys.cuda()
            else:
                ys = torch.cat(
                    [ys, torch.ones(1, 1).long().fill_(next_word)], dim=1)
            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

        sent = []
        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>':
                    break
                else:
                    st += e + ' '
            sent.append(st)
        return sent
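
The "built_in" emotion classifier used above scores a projected encoder query against every row of the emotion embedding matrix and takes the argmax. A small, self-contained sketch of that scoring step; the sizes (hidden 300, 32 emotions) and the stand-in module names are illustrative assumptions:

import torch
import torch.nn as nn

hidden_dim, num_emotions = 300, 32
s_weight = nn.Linear(hidden_dim, hidden_dim)           # stands in for self.s_weight
emoji_embedding = nn.Embedding(num_emotions, hidden_dim)

q_h = torch.randn(4, hidden_dim)                       # (batch, hidden) encoder query
x = s_weight(q_h)                                      # project the query
logit_prob = torch.matmul(x, emoji_embedding.weight.transpose(0, 1))  # (batch, 32)
emo_pred = torch.argmax(logit_prob, dim=-1)            # predicted emotion ids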