def __init__(self,
                 vocab,
                 emo_number,
                 model_file_path=None,
                 is_eval=False,
                 load_optim=False):
        super(Transformer, self).__init__()
        self.vocab = vocab
        self.vocab_size = vocab.n_words
        self.embedding = share_embedding(self.vocab, config.pretrain_emb)
        self.encoder = Encoder(config.emb_dim,
                               config.hidden_dim,
                               num_layers=config.hop,
                               num_heads=config.heads,
                               total_key_depth=config.depth,
                               total_value_depth=config.depth,
                               filter_size=config.filter,
                               universal=config.universal)

        self.decoder = Decoder(config.emb_dim,
                               hidden_size=config.hidden_dim,
                               num_layers=config.hop,
                               num_heads=config.heads,
                               total_key_depth=config.depth,
                               total_value_depth=config.depth,
                               filter_size=config.filter)

        self.generator = Generator(config.hidden_dim, self.vocab_size)

        if config.weight_sharing:
            # Share the weight matrix between target word embedding & the final logit dense layer
            self.generator.proj.weight = self.embedding.lut.weight

        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)

        self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
        if (config.noam):
            self.optimizer = NoamOpt(
                config.hidden_dim, 1, 8000,
                torch.optim.Adam(self.parameters(),
                                 lr=0,
                                 betas=(0.9, 0.98),
                                 eps=1e-9))

        if model_file_path is not None:
            print("loading weights")
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            self.encoder.load_state_dict(state['encoder_state_dict'])
            self.decoder.load_state_dict(state['decoder_state_dict'])
            self.generator.load_state_dict(state['generator_dict'])
            self.embedding.load_state_dict(state['embedding_dict'])
            if (load_optim):
                self.optimizer.load_state_dict(state['optimizer'])
            self.eval()

        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""
Example #3
    def __init__(self,
                 vocab,
                 model_file_path=None,
                 is_eval=False,
                 load_optim=False):
        super(PGNet, self).__init__()
        self.vocab = vocab
        self.vocab_size = vocab.n_words

        self.embedding = share_embedding(self.vocab, config.pretrain_emb)
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.reduce_state = ReduceState()

        self.generator = Generator(config.rnn_hidden_dim, self.vocab_size)

        if config.weight_sharing:
            # Share the weight matrix between target word embedding & the final logit dense layer
            self.generator.proj.weight = self.embedding.lut.weight

        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)
        if config.label_smoothing:
            self.criterion = LabelSmoothing(size=self.vocab_size,
                                            padding_idx=config.PAD_idx,
                                            smoothing=0.1)
            self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx)

        self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
        if config.noam:
            self.optimizer = NoamOpt(
                config.rnn_hidden_dim, 1, 8000,
                torch.optim.Adam(self.parameters(),
                                 lr=0,
                                 betas=(0.9, 0.98),
                                 eps=1e-9))

        if model_file_path is not None:
            print("loading weights")
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            self.encoder.load_state_dict(state['encoder_state_dict'])
            self.decoder.load_state_dict(state['decoder_state_dict'])
            self.generator.load_state_dict(state['generator_dict'])
            self.embedding.load_state_dict(state['embedding_dict'])
            if load_optim:
                self.optimizer.load_state_dict(state['optimizer'])
            self.eval()

        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""
class CvaeTrans(nn.Module):

    def __init__(self, vocab, emo_number,  model_file_path=None, is_eval=False, load_optim=False):
        super(CvaeTrans, self).__init__()
        self.vocab = vocab
        self.vocab_size = vocab.n_words

        self.embedding = share_embedding(self.vocab, pretrain=False)
        
        self.word_encoder = WordEncoder(config.emb_dim, config.hidden_dim, config.bidirectional)
        self.encoder = Encoder(config.hidden_dim, num_layers=config.hop, num_heads=config.heads, 
                                total_key_depth=config.depth, total_value_depth=config.depth,
                                filter_size=config.filter, universal=config.universal)
        self.r_encoder = Encoder(config.hidden_dim, num_layers=config.hop, num_heads=config.heads, 
                                total_key_depth=config.depth, total_value_depth=config.depth,
                                filter_size=config.filter, universal=config.universal)
        
        self.decoder = VarDecoder(config.emb_dim, hidden_size=config.hidden_dim, num_layers=config.hop, 
                                num_heads=config.heads, total_key_depth=config.depth, total_value_depth=config.depth, 
                                filter_size=config.filter, vocab_size=self.vocab_size)

        self.generator = Generator(config.hidden_dim, self.vocab_size)

        self.linear = nn.Linear(2 * config.hidden_dim, config.hidden_dim)

        if config.weight_sharing:
            # Share the weight matrix between target word embedding & the final logit dense layer
            self.generator.proj.weight = self.embedding.lut.weight

        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)
        
        if model_file_path:
            print("loading weights")
            state = torch.load(model_file_path, map_location= lambda storage, location: storage)
            self.encoder.load_state_dict(state['encoder_state_dict'])
            #self.r_encoder.load_state_dict(state['r_encoder_state_dict'])
            self.decoder.load_state_dict(state['decoder_state_dict'])
            self.generator.load_state_dict(state['generator_dict'])
            self.embedding.load_state_dict(state['embedding_dict'])
            
        if config.USE_CUDA:
            self.cuda()
        if is_eval:
            self.eval()
        else:
            self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
            if config.noam:
                self.optimizer = NoamOpt(
                    config.hidden_dim, 1, 8000,
                    torch.optim.Adam(self.parameters(),
                                     lr=0,
                                     betas=(0.9, 0.98),
                                     eps=1e-9))
            if load_optim:
                self.optimizer.load_state_dict(state['optimizer'])
                if config.USE_CUDA:
                    # move the loaded optimizer state onto the GPU; use a separate
                    # name so the checkpoint dict `state` is not shadowed
                    for opt_state in self.optimizer.state.values():
                        for k, v in opt_state.items():
                            if isinstance(v, torch.Tensor):
                                opt_state[k] = v.cuda()
        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""

    def save_model(self, running_avg_ppl, iter, f1_g,f1_b,ent_g,ent_b):

        state = {
            'iter': iter,
            'encoder_state_dict': self.encoder.state_dict(),
            #'r_encoder_state_dict': self.r_encoder.state_dict(),
            'decoder_state_dict': self.decoder.state_dict(),
            'generator_dict': self.generator.state_dict(),
            'embedding_dict': self.embedding.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_ppl
        }
        model_save_path = os.path.join(self.model_dir, 'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(iter,running_avg_ppl,f1_g,f1_b,ent_g,ent_b) )
        self.best_path = model_save_path
        torch.save(state, model_save_path)

    def train_one_batch(self, batch, iter, train=True):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)
        dec_batch, _, _, _, _ = get_output_from_batch(batch)
        
        if config.noam:
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        ## Response encode
        mask_res = batch["posterior_batch"].data.eq(config.PAD_idx).unsqueeze(1)
        post_emb = self.embedding(batch["posterior_batch"])
        r_encoder_outputs = self.r_encoder(post_emb, mask_res)

        ## Encode
        num_sentences, enc_seq_len = enc_batch.size()
        batch_size = enc_lens.size(0)
        max_len = enc_lens.data.max().item()
        input_lengths = torch.sum(~enc_batch.data.eq(config.PAD_idx), dim=1)
        
        # word level encoder
        enc_emb = self.embedding(enc_batch)
        word_encoder_outputs, word_encoder_hidden = self.word_encoder(enc_emb, input_lengths)
        word_encoder_hidden = word_encoder_hidden.transpose(1, 0).reshape(num_sentences, -1)

        # pad and pack word_encoder_hidden
        start = torch.cumsum(torch.cat((enc_lens.data.new(1).zero_(), enc_lens[:-1])), 0)
        word_encoder_hidden = torch.stack([pad(word_encoder_hidden.narrow(0, s, l), max_len)
                                            for s, l in zip(start.data.tolist(), enc_lens.data.tolist())], 0)
        
        mask_src = ~(enc_padding_mask.bool()).unsqueeze(1)  # same inverted padding mask as in decoder_greedy
        
        # context level encoder
        if word_encoder_hidden.size(-1) != config.hidden_dim:
            word_encoder_hidden = self.linear(word_encoder_hidden)
        encoder_outputs = self.encoder(word_encoder_hidden, mask_src)

        # Decode
        sos_token = torch.LongTensor([config.SOS_idx] * batch_size).unsqueeze(1)
        if config.USE_CUDA: sos_token = sos_token.cuda()

        dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1) #(batch, len, embedding)
        mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
        dec_emb = self.embedding(dec_batch_shift)

        pre_logit, attn_dist, mean, log_var, probs = self.decoder(dec_emb, encoder_outputs, r_encoder_outputs, 
                                                                    (mask_src, mask_res, mask_trg))
        
        ## compute output dist
        logit = self.generator(pre_logit, attn_dist, enc_batch_extend_vocab if config.pointer_gen else None, extra_zeros, attn_dist_db=None)
        ## loss: NNL if ptr else Cross entropy
        sbow = dec_batch #[batch, seq_len]
        seq_len = sbow.size(1)
        loss_rec = self.criterion(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1))
        if config.model=="cvaetrs":
            loss_aux = 0
            for prob in probs:
                sbow_mask = _get_attn_subsequent_mask(seq_len).transpose(1, 2)
                # note: masked_fill_ acts on the temporary returned by repeat(), so
                # sbow itself stays [batch, seq_len] for the auxiliary loss below
                sbow.unsqueeze(2).repeat(1, 1, seq_len).masked_fill_(sbow_mask, config.PAD_idx)

                loss_aux+= self.criterion(prob.contiguous().view(-1, prob.size(-1)), sbow.contiguous().view(-1))
            kld_loss = gaussian_kld(mean["posterior"], log_var["posterior"],mean["prior"], log_var["prior"])
            kld_loss = torch.mean(kld_loss)
            kl_weight = min(math.tanh(6 * iter/config.full_kl_step - 3) + 1, 1)
            #kl_weight = min(iter/config.full_kl_step, 1) if config.full_kl_step >0 else 1.0
            loss = loss_rec + config.kl_ceiling * kl_weight*kld_loss + config.aux_ceiling*loss_aux
            elbo = loss_rec + kld_loss
        else:
            loss = loss_rec
            elbo = loss_rec
            kld_loss = torch.Tensor([0])
            loss_aux = torch.Tensor([0])
        if train:
            loss.backward()
            # clip gradient
            nn.utils.clip_grad_norm_(self.parameters(), config.max_grad_norm)
            self.optimizer.step()

        return loss_rec.item(), math.exp(min(loss_rec.item(), 100)), kld_loss.item(), loss_aux.item(), elbo.item()

    def train_n_batch(self, batchs, iter, train=True):
        if config.noam:
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()
        for batch in batchs:
            enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)
            dec_batch, _, _, _, _ = get_output_from_batch(batch)
            ## Encode
            mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
            encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)

            meta = self.embedding(batch["program_label"])
            if config.dataset=="empathetic":
                meta = meta-meta
            # Decode
            sos_token = torch.LongTensor([config.SOS_idx] * enc_batch.size(0)).unsqueeze(1)
            if config.USE_CUDA: sos_token = sos_token.cuda()
            dec_batch_shift = torch.cat((sos_token,dec_batch[:, :-1]),1)

            mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)

            pre_logit, attn_dist, mean, log_var, probs = self.decoder(
                self.embedding(dec_batch_shift) + meta.unsqueeze(1),
                encoder_outputs, True, (mask_src, mask_trg))
            ## compute output dist
            logit = self.generator(pre_logit,attn_dist,enc_batch_extend_vocab if config.pointer_gen else None, extra_zeros, attn_dist_db=None)
            ## loss: NNL if ptr else Cross entropy
            sbow = dec_batch #[batch, seq_len]
            seq_len = sbow.size(1)
            loss_rec = self.criterion(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1))
            if config.model=="cvaetrs":
                loss_aux = 0
                for prob in probs:
                    sbow_mask = _get_attn_subsequent_mask(seq_len).transpose(1, 2)
                    # note: masked_fill_ acts on the temporary returned by repeat(), so
                    # sbow itself stays [batch, seq_len] for the auxiliary loss below
                    sbow.unsqueeze(2).repeat(1, 1, seq_len).masked_fill_(sbow_mask, config.PAD_idx)

                    loss_aux+= self.criterion(prob.contiguous().view(-1, prob.size(-1)), sbow.contiguous().view(-1))
                kld_loss = gaussian_kld(mean["posterior"], log_var["posterior"],mean["prior"], log_var["prior"])
                kld_loss = torch.mean(kld_loss)
                kl_weight = min(math.tanh(6 * iter/config.full_kl_step - 3) + 1, 1)
                #kl_weight = min(iter/config.full_kl_step, 1) if config.full_kl_step >0 else 1.0
                loss = loss_rec + config.kl_ceiling * kl_weight*kld_loss + config.aux_ceiling*loss_aux
                elbo = loss_rec+kld_loss
            else:
                loss = loss_rec
                elbo = loss_rec
                kld_loss = torch.Tensor([0])
                loss_aux = torch.Tensor([0])
            loss.backward()
            # clip gradient
        nn.utils.clip_grad_norm_(self.parameters(), config.max_grad_norm)
        self.optimizer.step()

        return loss_rec.item(), math.exp(min(loss_rec.item(), 100)), kld_loss.item(), loss_aux.item(), elbo.item()

    def decoder_greedy(self, batch, max_dec_step=50):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)
        
        ## Encode
        num_sentences, enc_seq_len = enc_batch.size()
        batch_size = enc_lens.size(0)
        max_len = enc_lens.data.max().item()
        input_lengths = torch.sum(~enc_batch.data.eq(config.PAD_idx), dim=1)
        
        # word level encoder
        enc_emb = self.embedding(enc_batch)
        word_encoder_outputs, word_encoder_hidden = self.word_encoder(enc_emb, input_lengths)
        word_encoder_hidden = word_encoder_hidden.transpose(1, 0).reshape(num_sentences, -1)

        # pad and pack word_encoder_hidden
        start = torch.cumsum(torch.cat((enc_lens.data.new(1).zero_(), enc_lens[:-1])), 0)
        word_encoder_hidden = torch.stack([pad(word_encoder_hidden.narrow(0, s, l), max_len)
                                            for s, l in zip(start.data.tolist(), enc_lens.data.tolist())], 0)
        mask_src = ~(enc_padding_mask.bool()).unsqueeze(1)

        # context level encoder
        if word_encoder_hidden.size(-1) != config.hidden_dim:
            word_encoder_hidden = self.linear(word_encoder_hidden)
        encoder_outputs = self.encoder(word_encoder_hidden, mask_src)
        
        ys = torch.ones(batch_size, 1).fill_(config.SOS_idx).long()
        if config.USE_CUDA:
            ys = ys.cuda()
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        
        decoded_words = []
        for i in range(max_dec_step+1):

            out, attn_dist, _, _,_ = self.decoder(self.embedding(ys), encoder_outputs, None, (mask_src, None, mask_trg))
            
            prob = self.generator(out,attn_dist,enc_batch_extend_vocab, extra_zeros, attn_dist_db=None)
            _, next_word = torch.max(prob[:, -1], dim = 1)
            decoded_words.append(['<EOS>' if ni.item() == config.EOS_idx else self.vocab.index2word[ni.item()] for ni in next_word.view(-1)])

            ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)
            if config.USE_CUDA:
                ys = ys.cuda()
            
            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        sent = []
        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>': break
                else: st+= e + ' '
            sent.append(st)
        return sent
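train_one_batch and train_n_batch above combine the reconstruction loss with gaussian_kld(...) weighted by a tanh annealing schedule. Those helpers are defined elsewhere in the project; the following is a minimal sketch of what they are assumed to compute (standard closed-form KL between diagonal Gaussians, and the schedule copied from the training code):

import math
import torch

def gaussian_kld_sketch(mu_q, logvar_q, mu_p, logvar_p):
    """Assumed closed-form KL(q || p) between two diagonal Gaussians,
    summed over the latent dimension (one value per batch element)."""
    return 0.5 * torch.sum(
        logvar_p - logvar_q
        + (torch.exp(logvar_q) + (mu_q - mu_p) ** 2) / torch.exp(logvar_p)
        - 1.0,
        dim=-1)

def kl_anneal_weight(step, full_kl_step):
    """The tanh schedule used above: near 0 at step 0 and saturating
    at 1 by roughly full_kl_step / 2."""
    return min(math.tanh(6 * step / full_kl_step - 3) + 1, 1)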
Example #5
File: AOTNet.py Project: qtli/AOT
class AOT(nn.Module):
    def __init__(self,
                 vocab,
                 model_file_path=None,
                 is_eval=False,
                 load_optim=False):
        super(AOT, self).__init__()
        self.vocab = vocab
        self.vocab_size = vocab.n_words

        self.embedding = share_embedding(self.vocab, config.pretrain_emb)
        self.encoder = Encoder(config.emb_dim,
                               config.hidden_dim,
                               num_layers=config.hop,
                               num_heads=config.heads,
                               total_key_depth=config.depth,
                               total_value_depth=config.depth,
                               filter_size=config.filter,
                               universal=config.universal)
        self.sse = SSE(vocab, config.emb_dim, config.dropout,
                       config.rnn_hidden_dim)
        self.rcr = RCR()

        ## multiple decoders
        self.decoder = Decoder(config.emb_dim,
                               hidden_size=config.hidden_dim,
                               num_layers=config.hop,
                               num_heads=config.heads,
                               total_key_depth=config.depth,
                               total_value_depth=config.depth,
                               filter_size=config.filter)

        self.generator = Generator(config.hidden_dim, self.vocab_size)

        if config.weight_sharing:
            # Share the weight matrix between target word embedding & the final logit dense layer
            self.generator.proj.weight = self.embedding.lut.weight

        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)
        if config.label_smoothing:
            self.criterion = LabelSmoothing(size=self.vocab_size,
                                            padding_idx=config.PAD_idx,
                                            smoothing=0.1)
            self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx)

        self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
        if config.noam:
            self.optimizer = NoamOpt(
                config.hidden_dim, 1, 8000,
                torch.optim.Adam(self.parameters(),
                                 lr=0,
                                 betas=(0.9, 0.98),
                                 eps=1e-9))

        if model_file_path is not None:
            print("loading weights")
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            self.encoder.load_state_dict(state['encoder_state_dict'])
            self.decoder.load_state_dict(state['decoder_state_dict'])
            self.generator.load_state_dict(state['generator_dict'])
            self.embedding.load_state_dict(state['embedding_dict'])
            if load_optim:
                self.optimizer.load_state_dict(state['optimizer'])
            self.eval()

        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""

    def save_model(self, running_avg_ppl, iter, f1_g, f1_b, ent_g, ent_b):

        state = {
            'iter': iter,
            'encoder_state_dict': self.encoder.state_dict(),
            'decoder_state_dict': self.decoder.state_dict(),
            'generator_dict': self.generator.state_dict(),
            'embedding_dict': self.embedding.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_ppl
        }
        model_save_path = os.path.join(
            self.model_dir,
            'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(
                iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b))
        self.best_path = model_save_path
        torch.save(state, model_save_path)

    def train_one_batch_slow(self, batch, iter, train=True):
        enc_batch = batch["review_batch"]
        enc_batch_extend_vocab = batch["review_ext_batch"]
        src_batch = batch[
            'reviews_batch']  # reviews sequence (bsz, r_num, r_len)
        src_mask = batch[
            'reviews_mask']  # indicate which review is fake(for padding).  (bsz, r_num)
        src_length = batch['reviews_length']  # (bsz, r_num)
        enc_length_batch = batch[
            'reviews_length_list']  # 2-dim list, 0: len=bsz, 1: lens of reviews and pads
        src_labels = batch['reviews_label']  # (bsz, r_num)
        oovs = batch["oovs"]
        max_oov_length = len(
            sorted(oovs, key=lambda i: len(i), reverse=True)[0])
        extra_zeros = Variable(torch.zeros(
            (enc_batch.size(0), max_oov_length))).to(config.device)

        dec_batch = batch["tags_batch"]
        dec_ext_batch = batch["tags_ext_batch"]
        dec_rank_batch = batch[
            'tags_idx_batch']  # tag indexes sequence (bsz, tgt_len)

        if config.noam:
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        # 1. Sentence-level Salience Estimation (SSE)
        cla_loss, sa_scores, sa_acc = self.sse.salience_estimate(
            src_batch, src_mask, src_length,
            src_labels)  # sa_scores: (bsz, r_num)

        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(
            1)  # (bsz, src_len)->(bsz, 1, src_len)
        # emb_mask = self.embedding(batch["mask_context"])
        # src_emb = self.embedding(enc_batch)+emb_mask
        src_emb = self.embedding(enc_batch)
        encoder_outputs = self.encoder(src_emb,
                                       mask_src)  # (bsz, src_len, emb_dim)

        src_enc_rank = torch.FloatTensor([]).to(
            config.device)  # (bsz, src_len, emb_dim)
        src_ext_rank = torch.LongTensor([]).to(config.device)  # (bsz, src_len)
        aln_rank = torch.LongTensor([]).to(
            config.device)  # (bsz, tgt_len, src_len)
        aln_mask_rank = torch.FloatTensor([]).to(
            config.device)  # (bsz, tgt_len, src_len)

        bsz, max_src_len = enc_batch.size()
        for idx in range(bsz):  # Clustering (by k-means) and Ranking
            item_length = enc_length_batch[idx]
            reviews = torch.split(encoder_outputs[idx], item_length, dim=0)
            reviews_ext = torch.split(enc_batch_extend_vocab[idx],
                                      item_length,
                                      dim=0)

            r_vectors = []  # store the vector repr of each review
            rs_vectors = []  # store the token vectors repr of each review
            r_exts = []
            r_pad_vec, r_ext_pad = None, None
            for r_idx in range(len(item_length)):
                if r_idx == len(item_length) - 1:
                    r_pad_vec = reviews[r_idx]
                    r_ext_pad = reviews_ext[r_idx]
                    break
                r = self.rcr.hierarchical_pooling(reviews[r_idx].unsqueeze(
                    0)).squeeze(0).detach().cpu().numpy() * sa_scores[idx,
                                                                      r_idx]
                r_vectors.append(r)
                rs_vectors.append(reviews[r_idx])
                r_exts.append(reviews_ext[r_idx])

            rs_repr, ext_repr, srctgt_aln_mask, srctgt_aln = \
                self.rcr.perform(r_vectors, rs_vectors, r_exts, r_pad_vec, r_ext_pad, dec_rank_batch[idx], max_src_len)
            # rs_repr: (max_rs_length, embed_dim); ext_repr: (max_rs_length); srctgt_aln_mask/srctgt_aln: (tgt_len, max_rs_length)

            src_enc_rank = torch.cat((src_enc_rank, rs_repr.unsqueeze(0)),
                                     dim=0)  # (1->bsz, max_src_len, embed_dim)
            src_ext_rank = torch.cat((src_ext_rank, ext_repr.unsqueeze(0)),
                                     dim=0)  # (1->bsz, max_src_len)
            aln_rank = torch.cat((aln_rank, srctgt_aln.unsqueeze(0)),
                                 dim=0)  # (1->bsz, max_tgt_len, max_src_len)
            aln_mask_rank = torch.cat(
                (aln_mask_rank, srctgt_aln_mask.unsqueeze(0)), dim=0)

        del encoder_outputs, reviews, reviews_ext, r_vectors, rs_vectors, r_exts, r_pad_vec, r_ext_pad, rs_repr, ext_repr, srctgt_aln_mask, srctgt_aln
        torch.cuda.empty_cache()
        torch.backends.cuda.cufft_plan_cache.clear()

        ys = torch.LongTensor([config.SOS_idx] *
                              enc_batch.size(0)).unsqueeze(1).to(
                                  config.device)  # (bsz, 1)
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        ys_rank = torch.LongTensor([1] * enc_batch.size(0)).unsqueeze(1).to(
            config.device)

        max_tgt_len = dec_batch.size(1)
        loss, loss_ppl = 0, 0
        for t in range(max_tgt_len):
            aln_rank_cur = aln_rank[:, t, :].unsqueeze(1)  # (bsz, 1, src_len)
            aln_mask_cur = aln_mask_rank[:, :(t + 1), :]  # (bsz, src_len)
            pre_logit, attn_dist, aln_loss_cur = self.decoder(
                inputs=self.embedding(ys),
                inputs_rank=ys_rank,
                encoder_output=src_enc_rank,
                aln_rank=aln_rank_cur,
                aln_mask_rank=aln_mask_cur,
                mask=(mask_src, mask_trg),
                speed='slow')
            # todo
            if iter >= 13000:
                loss += (0.1 * aln_loss_cur)
            else:
                loss += aln_loss_cur
            logit = self.generator(
                pre_logit, attn_dist.unsqueeze(1),
                enc_batch_extend_vocab if config.pointer_gen else None,
                extra_zeros)

            if config.pointer_gen:
                loss += self.criterion(
                    logit[:, -1, :].contiguous().view(-1, logit.size(-1)),
                    dec_ext_batch[:, t].contiguous().view(-1))
            else:
                loss += self.criterion(
                    logit[:, -1, :].contiguous().view(-1, logit.size(-1)),
                    dec_batch[:, t].contiguous().view(-1))

            if config.label_smoothing:
                loss_ppl += self.criterion_ppl(
                    logit[:, -1, :].contiguous().view(-1, logit.size(-1)),
                    dec_ext_batch[:, t].contiguous().view(-1) if
                    config.pointer_gen else dec_batch[:,
                                                      t].contiguous().view(-1))

            ys = torch.cat((ys, dec_batch[:, t].unsqueeze(1)), dim=1)
            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
            ys_rank = torch.cat((ys_rank, dec_rank_batch[:, t].unsqueeze(1)),
                                dim=1)

        loss = loss + cla_loss
        if train:
            loss /= max_tgt_len
            loss.backward()
            self.optimizer.step()

        if config.label_smoothing:
            loss_ppl /= max_tgt_len
            if torch.isnan(loss_ppl).sum().item() != 0 or torch.isinf(
                    loss_ppl).sum().item() != 0:
                print("check")
                pdb.set_trace()
            return loss_ppl.item(), math.exp(min(loss_ppl.item(),
                                                 100)), cla_loss.item(), sa_acc
        else:
            return loss.item(), math.exp(min(loss.item(),
                                             100)), cla_loss.item(), sa_acc

    def train_one_batch(self, batch, iter, train=True):
        enc_batch = batch["review_batch"]
        enc_batch_extend_vocab = batch["review_ext_batch"]
        src_batch = batch[
            'reviews_batch']  # reviews sequence (bsz, r_num, r_len)
        src_mask = batch[
            'reviews_mask']  # indicate which review is fake(for padding).  (bsz, r_num)
        src_length = batch['reviews_length']  # (bsz, r_num)
        enc_length_batch = batch[
            'reviews_length_list']  # 2-dim list, 0: len=bsz, 1: lens of reviews and pads
        src_labels = batch['reviews_label']  # (bsz, r_num)
        oovs = batch["oovs"]
        max_oov_length = len(
            sorted(oovs, key=lambda i: len(i), reverse=True)[0])
        extra_zeros = Variable(torch.zeros(
            (enc_batch.size(0), max_oov_length))).to(config.device)

        dec_batch = batch["tags_batch"]
        dec_ext_batch = batch["tags_ext_batch"]
        dec_rank_batch = batch[
            'tags_idx_batch']  # tag indexes sequence (bsz, tgt_len)

        if config.noam:
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        # 1. Sentence-level Salience Estimation (SSE)
        cla_loss, sa_scores, sa_acc = self.sse.salience_estimate(
            src_batch, src_mask, src_length,
            src_labels)  # sa_scores: (bsz, r_num)

        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(
            1)  # (bsz, src_len)->(bsz, 1, src_len)
        # emb_mask = self.embedding(batch["mask_context"])
        # src_emb = self.embedding(enc_batch)+emb_mask
        src_emb = self.embedding(enc_batch)
        encoder_outputs = self.encoder(src_emb,
                                       mask_src)  # (bsz, src_len, emb_dim)

        src_enc_rank = torch.FloatTensor([]).to(
            config.device)  # (bsz, src_len, emb_dim)
        src_ext_rank = torch.LongTensor([]).to(config.device)  # (bsz, src_len)
        aln_rank = torch.LongTensor([]).to(
            config.device)  # (bsz, tgt_len, src_len)
        aln_mask_rank = torch.FloatTensor([]).to(
            config.device)  # (bsz, tgt_len, src_len)

        bsz, max_src_len = enc_batch.size()
        for idx in range(bsz):  # Clustering (by k-means) and Ranking
            item_length = enc_length_batch[idx]
            reviews = torch.split(encoder_outputs[idx], item_length, dim=0)
            reviews_ext = torch.split(enc_batch_extend_vocab[idx],
                                      item_length,
                                      dim=0)

            r_vectors = []  # store the vector repr of each review
            rs_vectors = []  # store the token vectors repr of each review
            r_exts = []
            r_pad_vec, r_ext_pad = None, None
            for r_idx in range(len(item_length)):
                if r_idx == len(item_length) - 1:
                    r_pad_vec = reviews[r_idx]
                    r_ext_pad = reviews_ext[r_idx]
                    break
                r = self.rcr.hierarchical_pooling(reviews[r_idx].unsqueeze(
                    0)).squeeze(0).detach().cpu().numpy() * sa_scores[idx,
                                                                      r_idx]
                r_vectors.append(r)
                rs_vectors.append(reviews[r_idx])
                r_exts.append(reviews_ext[r_idx])

            rs_repr, ext_repr, srctgt_aln_mask, srctgt_aln = \
                self.rcr.perform(r_vectors, rs_vectors, r_exts, r_pad_vec, r_ext_pad, dec_rank_batch[idx], max_src_len)
            # rs_repr: (max_rs_length, embed_dim); ext_repr: (max_rs_length); srctgt_aln_mask/srctgt_aln: (tgt_len, max_rs_length)

            src_enc_rank = torch.cat((src_enc_rank, rs_repr.unsqueeze(0)),
                                     dim=0)  # (1->bsz, max_src_len, embed_dim)
            src_ext_rank = torch.cat((src_ext_rank, ext_repr.unsqueeze(0)),
                                     dim=0)  # (1->bsz, max_src_len)
            aln_rank = torch.cat((aln_rank, srctgt_aln.unsqueeze(0)),
                                 dim=0)  # (1->bsz, max_tgt_len, max_src_len)
            aln_mask_rank = torch.cat(
                (aln_mask_rank, srctgt_aln_mask.unsqueeze(0)), dim=0)

        del encoder_outputs, reviews, reviews_ext, r_vectors, rs_vectors, r_exts, r_pad_vec, r_ext_pad, rs_repr, ext_repr, srctgt_aln_mask, srctgt_aln
        torch.cuda.empty_cache()
        torch.backends.cuda.cufft_plan_cache.clear()

        sos_token = torch.LongTensor([config.SOS_idx] *
                                     enc_batch.size(0)).unsqueeze(1).to(
                                         config.device)  # (bsz, 1)
        dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]),
                                    1)  # (bsz, tgt_len)
        mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)

        sos_rank = torch.LongTensor([1] * enc_batch.size(0)).unsqueeze(1).to(
            config.device)
        dec_rank_batch = torch.cat((sos_rank, dec_rank_batch[:, :-1]), 1)

        aln_rank = aln_rank[:, :-1, :]
        aln_mask_rank = aln_mask_rank[:, :-1, :]

        pre_logit, attn_dist, aln_loss = self.decoder(
            inputs=self.embedding(dec_batch_shift),
            inputs_rank=dec_rank_batch,
            encoder_output=src_enc_rank,
            aln_rank=aln_rank,
            aln_mask_rank=aln_mask_rank,
            mask=(mask_src, mask_trg))
        logit = self.generator(
            pre_logit, attn_dist,
            enc_batch_extend_vocab if config.pointer_gen else None,
            extra_zeros)

        if config.pointer_gen:
            loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                                  dec_ext_batch.contiguous().view(-1))
        else:
            loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                                  dec_batch.contiguous().view(-1))

        if config.label_smoothing:
            loss_ppl = self.criterion_ppl(
                logit.contiguous().view(-1, logit.size(-1)),
                dec_ext_batch.contiguous().view(-1)
                if config.pointer_gen else dec_batch.contiguous().view(-1))

        if train:
            # add the alignment and classification losses once
            # (down-weight the alignment term after 13k iterations)
            if iter >= 13000:
                loss = loss + (0.1 * aln_loss) + cla_loss
            else:
                loss = loss + aln_loss + cla_loss
            loss.backward()
            self.optimizer.step()

        if config.label_smoothing:
            if torch.isnan(loss_ppl).sum().item() != 0 or torch.isinf(
                    loss_ppl).sum().item() != 0:
                print("check")
                pdb.set_trace()
            return loss_ppl.item(), math.exp(min(loss_ppl.item(),
                                                 100)), cla_loss.item(), sa_acc
        else:
            return loss.item(), math.exp(min(loss.item(),
                                             100)), cla_loss.item(), sa_acc

    def compute_act_loss(self, module):
        R_t = module.remainders
        N_t = module.n_updates
        p_t = R_t + N_t
        avg_p_t = torch.sum(torch.sum(p_t, dim=1) / p_t.size(1)) / p_t.size(0)
        loss = config.act_loss_weight * avg_p_t.item()
        return loss

    def decoder_greedy(self, batch, max_dec_step=30):
        enc_batch = batch["review_batch"]
        enc_batch_extend_vocab = batch["review_ext_batch"]
        src_batch = batch[
            'reviews_batch']  # reviews sequence (bsz, r_num, r_len)
        src_mask = batch[
            'reviews_mask']  # indicate which review is fake(for padding).  (bsz, r_num)
        src_length = batch['reviews_length']  # (bsz, r_num)
        enc_length_batch = batch[
            'reviews_length_list']  # 2-dim list, 0: len=bsz, 1: lens of reviews and pads
        src_labels = batch['reviews_label']  # (bsz, r_num)
        oovs = batch["oovs"]
        max_oov_length = len(
            sorted(oovs, key=lambda i: len(i), reverse=True)[0])
        extra_zeros = Variable(torch.zeros(
            (enc_batch.size(0), max_oov_length))).to(config.device)

        # 1. Sentence-level Salience Estimation (SSE)
        cla_loss, sa_scores, sa_acc = self.sse.salience_estimate(
            src_batch, src_mask, src_length,
            src_labels)  # sa_scores: (bsz, r_num)

        ## Encode - context
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(
            1)  # (bsz, src_len)->(bsz, 1, src_len)
        # emb_mask = self.embedding(batch["mask_context"])
        # src_emb = self.embedding(enc_batch) + emb_mask  # todo eos or sentence embedding??
        src_emb = self.embedding(enc_batch)
        encoder_outputs = self.encoder(src_emb,
                                       mask_src)  # (bsz, src_len, emb_dim)

        src_enc_rank = torch.FloatTensor([]).to(
            config.device)  # (bsz, src_len, emb_dim)
        src_ext_rank = torch.LongTensor([]).to(config.device)  # (bsz, src_len)
        aln_rank = torch.LongTensor([]).to(
            config.device)  # (bsz, tgt_len, src_len)
        aln_mask_rank = torch.FloatTensor([]).to(
            config.device)  # (bsz, tgt_len, src_len)

        bsz, max_src_len = enc_batch.size()
        for idx in range(bsz):  # Clustering (by k-means) and Ranking
            item_length = enc_length_batch[idx]
            reviews = torch.split(encoder_outputs[idx], item_length, dim=0)
            reviews_ext = torch.split(enc_batch_extend_vocab[idx],
                                      item_length,
                                      dim=0)

            r_vectors = []  # store the vector repr of each review
            rs_vectors = []  # store the token vectors repr of each review
            r_exts = []
            r_pad_vec, r_ext_pad = None, None
            for r_idx in range(len(item_length)):
                if r_idx == len(item_length) - 1:
                    r_pad_vec = reviews[r_idx]
                    r_ext_pad = reviews_ext[r_idx]
                    break
                r = self.rcr.hierarchical_pooling(reviews[r_idx].unsqueeze(
                    0)).squeeze(0).detach().cpu().numpy() * sa_scores[idx,
                                                                      r_idx]
                r_vectors.append(r)
                rs_vectors.append(reviews[r_idx])
                r_exts.append(reviews_ext[r_idx])

            rs_repr, ext_repr, srctgt_aln_mask, srctgt_aln = self.rcr.perform(
                r_vecs=r_vectors,
                rs_vecs=rs_vectors,
                r_exts=r_exts,
                r_pad_vec=r_pad_vec,
                r_ext_pad=r_ext_pad,
                max_rs_length=max_src_len,
                train=False)
            # rs_repr: (max_rs_length, embed_dim); ext_repr: (max_rs_length); srctgt_aln_mask/srctgt_aln: (tgt_len, max_rs_length)

            src_enc_rank = torch.cat((src_enc_rank, rs_repr.unsqueeze(0)),
                                     dim=0)  # (1->bsz, max_src_len, embed_dim)
            src_ext_rank = torch.cat((src_ext_rank, ext_repr.unsqueeze(0)),
                                     dim=0)  # (1->bsz, max_src_len)
            aln_rank = torch.cat((aln_rank, srctgt_aln.unsqueeze(0)),
                                 dim=0)  # (1->bsz, max_tgt_len, max_src_len)
            aln_mask_rank = torch.cat(
                (aln_mask_rank, srctgt_aln_mask.unsqueeze(0)), dim=0)

        del encoder_outputs, reviews, reviews_ext, r_vectors, rs_vectors, r_exts, r_pad_vec, r_ext_pad, rs_repr, ext_repr, srctgt_aln_mask, srctgt_aln
        torch.cuda.empty_cache()
        torch.backends.cuda.cufft_plan_cache.clear()

        # ys = torch.ones(1, 1).fill_(config.SOS_idx).long()
        ys = torch.zeros(enc_batch.size(0), 1).fill_(config.SOS_idx).long().to(
            config.device)  # when testing, we set bsz into 1
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        ys_rank = torch.ones(enc_batch.size(0), 1).long().to(config.device)
        last_rank = torch.ones(enc_batch.size(0), 1).long().to(config.device)

        pred_attn_dist = torch.FloatTensor([]).to(config.device)
        decoded_words = []
        for i in range(max_dec_step + 1):
            aln_rank_cur = aln_rank[:, last_rank.item(), :].unsqueeze(
                1)  # (bsz, src_len)
            if config.project:
                out, attn_dist, _ = self.decoder(
                    inputs=self.embedding_proj_in(self.embedding(ys)),
                    inputs_rank=ys_rank,
                    encoder_output=self.embedding_proj_in(src_enc_rank),
                    aln_rank=aln_rank_cur,
                    aln_mask_rank=aln_mask_rank,  # nouse
                    mask=(mask_src, mask_trg),
                    speed='slow')
            else:
                out, attn_dist, _ = self.decoder(inputs=self.embedding(ys),
                                                 inputs_rank=ys_rank,
                                                 encoder_output=src_enc_rank,
                                                 aln_rank=aln_rank_cur,
                                                 aln_mask_rank=aln_mask_rank,
                                                 mask=(mask_src, mask_trg),
                                                 speed='slow')
            prob = self.generator(out, attn_dist, enc_batch_extend_vocab,
                                  extra_zeros)
            _, next_word = torch.max(prob[:, -1], dim=1)  # bsz=1, if test

            cur_words = []
            for i_batch, ni in enumerate(next_word.view(-1)):
                if ni.item() == config.EOS_idx:
                    cur_words.append('<EOS>')
                    last_rank[i_batch] = 0
                elif ni.item() in self.vocab.index2word:
                    cur_words.append(self.vocab.index2word[ni.item()])
                    if ni.item() == config.SOS_idx:
                        last_rank[i_batch] += 1
                else:
                    cur_words.append(oovs[i_batch][
                        ni.item() -
                        self.vocab.n_words])  # output non-dict word
                    next_word[i_batch] = config.UNK_idx  # input unk word
            decoded_words.append(cur_words)
            # next_word = next_word.data[0]
            # if next_word.item() not in self.vocab.index2word:
            #     next_word = torch.tensor(config.UNK_idx)

            # if config.USE_CUDA:
            ys = torch.cat([ys, next_word.unsqueeze(1)],
                           dim=1).to(config.device)
            ys_rank = torch.cat([ys_rank, last_rank], dim=1).to(config.device)
            # else:
            #     ys = torch.cat([ys, next_word],dim=1)
            #     ys_rank = torch.cat([ys_rank, last_rank],dim=1)

            # if config.USE_CUDA:
            #     ys = torch.cat([ys, torch.zeros(enc_batch.size(0), 1).long().fill_(next_word).cuda()], dim=1)
            #     ys = ys.cuda()
            # else:
            #     ys = torch.cat([ys, torch.zeros(enc_batch.size(0), 1).long().fill_(next_word)], dim=1)
            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

            if config.attn_analysis:
                pred_attn_dist = torch.cat(
                    (pred_attn_dist, attn_dist.unsqueeze(1)), dim=1)

        sent = []
        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>':
                    break
                else:
                    st += e + ' '
            sent.append(st)

        if config.attn_analysis:
            bsz, tgt_len, src_len = aln_mask_rank.size()
            pred_attn = pred_attn_dist[:, 1:, :].view(bsz * tgt_len, src_len)
            tgt_attn = aln_mask_rank.view(bsz * tgt_len, src_len)
            good_attn_sum = torch.masked_select(
                pred_attn,
                tgt_attn.bool()).sum()  # pred_attn: bsz * tgt_len, src_len
            bad_attn_sum = torch.masked_select(pred_attn,
                                               ~tgt_attn.bool()).sum()
            bad_num = ~tgt_attn.bool()
            ratio = bad_num.sum() / tgt_attn.bool().sum()
            bad_attn_sum /= ratio

            # good_attn_sum / bad_attn_sum are already 0-dim sums over the whole
            # sentence, so average them directly over all target steps and batch items
            good_attn = good_attn_sum / (tgt_len * bsz)
            bad_attn = bad_attn_sum / (tgt_len * bsz)

            return sent, [good_attn, bad_attn]
        else:
            return sent
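Every self.generator(...) call in these examples passes the decoder output, the copy-attention distribution, enc_batch_extend_vocab, and extra_zeros, which is the usual pointer-generator interface (See et al., 2017). The project's Generator is not shown here; the following is a minimal sketch of how such a head typically mixes the two distributions (the function name and the explicit p_gen argument are assumptions for illustration):

import torch
import torch.nn.functional as F

def pointer_generator_dist_sketch(vocab_logits, attn_dist, enc_ext, extra_zeros, p_gen):
    """Sketch of a pointer-generator output layer; the project's Generator may differ.
    vocab_logits: (bsz, tgt_len, vocab)   decoder output projected to the vocabulary
    attn_dist:    (bsz, tgt_len, src_len) attention over source tokens
    enc_ext:      (bsz, src_len)          source ids in the extended vocabulary
    extra_zeros:  (bsz, n_oov)            slots for in-batch OOV words
    p_gen:        (bsz, tgt_len, 1)       generation probability in [0, 1]
    """
    vocab_dist = p_gen * F.softmax(vocab_logits, dim=-1)
    copy_dist = (1.0 - p_gen) * attn_dist
    # extend the vocabulary axis with the OOV slots, then scatter the copy mass
    # onto the extended-vocab ids of the source tokens
    tgt_len = vocab_dist.size(1)
    extended = torch.cat(
        [vocab_dist, extra_zeros.unsqueeze(1).expand(-1, tgt_len, -1)], dim=-1)
    index = enc_ext.unsqueeze(1).expand(-1, tgt_len, -1)
    final_dist = extended.scatter_add(-1, index, copy_dist)
    return torch.log(final_dist + 1e-12)  # the models above feed log-probs to NLLLoss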
Example #6
class woRCR(nn.Module):
    def __init__(self,
                 vocab,
                 model_file_path=None,
                 is_eval=False,
                 load_optim=False):
        super(woRCR, self).__init__()
        self.vocab = vocab
        self.vocab_size = vocab.n_words

        self.embedding = share_embedding(self.vocab, config.pretrain_emb)
        self.encoder = Encoder(config.emb_dim,
                               config.hidden_dim,
                               num_layers=config.hop,
                               num_heads=config.heads,
                               total_key_depth=config.depth,
                               total_value_depth=config.depth,
                               filter_size=config.filter,
                               universal=config.universal)

        self.sse = SSE(vocab, config.emb_dim, config.dropout,
                       config.rnn_hidden_dim)
        self.rcr = RCR()

        self.decoder = Decoder(config.emb_dim,
                               hidden_size=config.hidden_dim,
                               num_layers=config.hop,
                               num_heads=config.heads,
                               total_key_depth=config.depth,
                               total_value_depth=config.depth,
                               filter_size=config.filter)

        self.generator = Generator(config.hidden_dim, self.vocab_size)

        if config.weight_sharing:
            # Share the weight matrix between target word embedding & the final logit dense layer
            self.generator.proj.weight = self.embedding.lut.weight

        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)
        if config.label_smoothing:
            self.criterion = LabelSmoothing(size=self.vocab_size,
                                            padding_idx=config.PAD_idx,
                                            smoothing=0.1)
            self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx)

        self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
        if config.noam:
            self.optimizer = NoamOpt(
                config.hidden_dim, 1, 8000,
                torch.optim.Adam(self.parameters(),
                                 lr=0,
                                 betas=(0.9, 0.98),
                                 eps=1e-9))

        if model_file_path is not None:
            print("loading weights")
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            self.encoder.load_state_dict(state['encoder_state_dict'])
            self.decoder.load_state_dict(state['decoder_state_dict'])
            self.generator.load_state_dict(state['generator_dict'])
            self.embedding.load_state_dict(state['embedding_dict'])
            if load_optim:
                self.optimizer.load_state_dict(state['optimizer'])
            self.eval()

        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""

    def save_model(self, running_avg_ppl, iter, f1_g, f1_b, ent_g, ent_b):

        state = {
            'iter': iter,
            'encoder_state_dict': self.encoder.state_dict(),
            'decoder_state_dict': self.decoder.state_dict(),
            'generator_dict': self.generator.state_dict(),
            'embedding_dict': self.embedding.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_ppl
        }
        model_save_path = os.path.join(
            self.model_dir,
            'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(
                iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b))
        self.best_path = model_save_path
        torch.save(state, model_save_path)

    def train_one_batch(self, batch, iter, train=True):
        enc_batch = batch["review_batch"]
        enc_batch_extend_vocab = batch["review_ext_batch"]

        src_batch = batch[
            'reviews_batch']  # reviews sequence (bsz, r_num, r_len)
        src_mask = batch[
            'reviews_mask']  # indicate which review is fake(for padding).  (bsz, r_num)
        src_length = batch['reviews_length']  # (bsz, r_num)
        enc_length_batch = batch[
            'reviews_length_list']  # 2-dim list, 0: len=bsz, 1: lens of reviews and pads

        src_labels = batch['reviews_label']  # (bsz, r_num)

        oovs = batch["oovs"]
        max_oov_length = len(
            sorted(oovs, key=lambda i: len(i), reverse=True)[0])
        extra_zeros = Variable(torch.zeros(
            (enc_batch.size(0), max_oov_length))).to(config.device)

        dec_batch = batch["tags_batch"]
        dec_ext_batch = batch["tags_ext_batch"]
        tid_batch = batch[
            'tags_idx_batch']  # tag indexes sequence (bsz, tgt_len)

        if config.noam:
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        # 1. Sentence-level Salience Estimation (SSE)
        cla_loss, sa_scores, sa_acc = self.sse.salience_estimate(
            src_batch, src_mask, src_length,
            src_labels)  # sa_scores: (bsz, r_num)

        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(
            1)  # (bsz, src_len)->(bsz, 1, src_len)
        # emb_mask = self.embedding(batch["mask_context"])
        # src_emb = self.embedding(enc_batch)+emb_mask
        src_emb = self.embedding(enc_batch)
        encoder_outputs = self.encoder(src_emb,
                                       mask_src)  # (bsz, src_len, emb_dim)

        sos_token = torch.LongTensor([config.SOS_idx] *
                                     enc_batch.size(0)).unsqueeze(1).to(
                                         config.device)  # (bsz, 1)
        dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]),
                                    1)  # (bsz, tgt_len)

        mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
        pre_logit, attn_dist, aln_loss = self.decoder(
            inputs=self.embedding(dec_batch_shift),
            inputs_rank=tid_batch,
            encoder_output=encoder_outputs,
            mask=(mask_src, mask_trg))

        logit = self.generator(
            pre_logit, attn_dist,
            enc_batch_extend_vocab if config.pointer_gen else None,
            extra_zeros)
        if config.pointer_gen:
            loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                                  dec_ext_batch.contiguous().view(-1))
        else:
            loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                                  dec_batch.contiguous().view(-1))

        if config.label_smoothing:
            loss_ppl = self.criterion_ppl(
                logit.contiguous().view(-1, logit.size(-1)),
                dec_ext_batch.contiguous().view(-1)
                if config.pointer_gen else dec_batch.contiguous().view(-1))
        loss = loss + cla_loss
        if torch.isnan(loss).sum().item() != 0 or torch.isinf(
                loss).sum().item() != 0:
            print("check")
            pdb.set_trace()
        if train:
            loss.backward()
            self.optimizer.step()

        if config.label_smoothing:
            loss_ppl = loss_ppl.item()
            cla_loss = cla_loss.item()
            return loss_ppl, math.exp(min(loss_ppl, 100)), cla_loss, sa_acc
        else:
            return loss.item(), math.exp(min(loss.item(),
                                             100)), cla_loss, sa_acc

    def decoder_greedy(self, batch, max_dec_step=30):
        enc_batch = batch["review_batch"]
        enc_batch_extend_vocab = batch["review_ext_batch"]

        src_batch = batch[
            'reviews_batch']  # review sequences (bsz, r_num, r_len)
        src_mask = batch[
            'reviews_mask']  # marks which reviews are padding  (bsz, r_num)
        src_length = batch['reviews_length']  # (bsz, r_num)
        enc_length_batch = batch[
            'reviews_length_list']  # 2-D list: outer len = bsz, inner = per-review lengths (incl. padding)

        src_labels = batch['reviews_label']  # (bsz, r_num)

        oovs = batch["oovs"]
        max_oov_length = len(
            sorted(oovs, key=lambda i: len(i), reverse=True)[0])
        extra_zeros = Variable(torch.zeros(
            (enc_batch.size(0), max_oov_length))).to(config.device)

        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(
            1)  # (bsz, src_len)->(bsz, 1, src_len)
        # emb_mask = self.embedding(batch["mask_context"])
        # src_emb = self.embedding(enc_batch) + emb_mask  # todo eos or sentence embedding??
        src_emb = self.embedding(enc_batch)
        encoder_outputs = self.encoder(src_emb,
                                       mask_src)  # (bsz, src_len, emb_dim)

        ys = torch.zeros(enc_batch.size(0), 1).fill_(config.SOS_idx).long().to(
            config.device)  # at test time the batch size is set to 1
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

        ys_rank = torch.ones(enc_batch.size(0), 1).long().to(config.device)
        last_rank = torch.ones(enc_batch.size(0), 1).long().to(config.device)

        decoded_words = []
        for i in range(max_dec_step + 1):
            if config.project:
                out, attn_dist, aln_loss = self.decoder(
                    inputs=self.embedding_proj_in(self.embedding(ys)),
                    inputs_rank=ys_rank,
                    encoder_output=self.embedding_proj_in(encoder_outputs),
                    mask=(mask_src, mask_trg))
            else:
                out, attn_dist, aln_loss = self.decoder(
                    inputs=self.embedding(ys),
                    inputs_rank=ys_rank,
                    encoder_output=encoder_outputs,
                    mask=(mask_src, mask_trg))
            prob = self.generator(
                out, attn_dist,
                enc_batch_extend_vocab if config.pointer_gen else None,
                extra_zeros)

            _, next_word = torch.max(prob[:, -1], dim=1)  # bsz=1
            cur_words = []
            for i_batch, ni in enumerate(next_word.view(-1)):
                if ni.item() == config.EOS_idx:
                    cur_words.append('<EOS>')
                    last_rank[i_batch] = 0
                elif ni.item() in self.vocab.index2word:
                    cur_words.append(self.vocab.index2word[ni.item()])
                    if ni.item() == config.SOS_idx:
                        last_rank[i_batch] += 1
                else:
                    cur_words.append(oovs[i_batch][ni.item() -
                                                   self.vocab.n_words])
            decoded_words.append(cur_words)
            next_word = next_word.data[0]

            if next_word.item() not in self.vocab.index2word:
                next_word = torch.tensor(config.UNK_idx)

            ys = torch.cat([
                ys,
                torch.zeros(enc_batch.size(0), 1).long().fill_(next_word).to(
                    config.device)
            ],
                           dim=1).to(config.device)
            ys_rank = torch.cat([ys_rank, last_rank], dim=1).to(config.device)
            # if config.USE_CUDA:
            #     ys = torch.cat([ys, torch.zeros(enc_batch.size(0), 1).long().fill_(next_word).cuda()], dim=1)
            #     ys = ys.cuda()
            # else:
            #     ys = torch.cat([ys, torch.zeros(enc_batch.size(0), 1).long().fill_(next_word)], dim=1)
            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

        sent = []
        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>':
                    break
                else:
                    st += e + ' '
            sent.append(st)
        return sent
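
The Generator calls in this example follow the pointer-generator pattern: the decoder's vocabulary distribution is mixed with the copy (attention) distribution over an OOV-extended vocabulary. A minimal single-step sketch of that mixing, assuming a soft p_gen gate and log-space output to match nn.NLLLoss (pointer_generator_step and its argument names are illustrative, not this repo's API):

import torch
import torch.nn.functional as F

def pointer_generator_step(vocab_logits, attn_dist, src_ext_ids, extra_zeros, p_gen):
    # vocab_logits: (bsz, vocab)    attn_dist: (bsz, src_len)
    # src_ext_ids:  (bsz, src_len)  source token ids in the OOV-extended vocabulary
    # extra_zeros:  (bsz, max_oovs) zero slots appended for in-batch OOV words
    # p_gen:        (bsz, 1)        gate between generating and copying
    vocab_dist = p_gen * F.softmax(vocab_logits, dim=-1)
    copy_dist = (1.0 - p_gen) * attn_dist
    extended = torch.cat([vocab_dist, extra_zeros], dim=1)
    final_dist = extended.scatter_add(1, src_ext_ids, copy_dist)
    return torch.log(final_dist + 1e-12)  # log-probs, as expected by nn.NLLLoss
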
Example #7
    def __init__(self,
                 vocab,
                 model_file_path=None,
                 is_eval=False,
                 load_optim=False):
        super(Seq2SPG, self).__init__()
        self.vocab = vocab
        self.vocab_size = vocab.n_words

        self.embedding = share_embedding(self.vocab, config.preptrained)
        self.encoder = nn.LSTM(config.emb_dim,
                               config.hidden_dim,
                               config.hop,
                               bidirectional=False,
                               batch_first=True,
                               dropout=0.2)
        self.encoder2decoder = nn.Linear(config.hidden_dim, config.hidden_dim)
        self.decoder = LSTMAttentionDot(config.emb_dim,
                                        config.hidden_dim,
                                        batch_first=True)
        self.memory = MLP(
            config.hidden_dim + config.emb_dim,
            [config.private_dim1, config.private_dim2, config.private_dim3],
            config.hidden_dim)
        self.dec_gate = nn.Linear(config.hidden_dim, 2 * config.hidden_dim)
        self.mem_gate = nn.Linear(config.hidden_dim, 2 * config.hidden_dim)
        self.generator = Generator(config.hidden_dim, self.vocab_size)
        self.hooks = {
        }  # Stores each task's model structure as masks over the parameters
        if config.weight_sharing:
            # Share the weight matrix between target word embedding & the final logit dense layer
            self.generator.proj.weight = self.embedding.weight
        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)
        if (config.label_smoothing):
            self.criterion = LabelSmoothing(size=self.vocab_size,
                                            padding_idx=config.PAD_idx,
                                            smoothing=0.1)
            self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx)
        if is_eval:
            self.encoder = self.encoder.eval()
            self.encoder2decoder = self.encoder2decoder.eval()
            self.decoder = self.decoder.eval()
            self.generator = self.generator.eval()
            self.embedding = self.embedding.eval()
            self.memory = self.memory.eval()
            self.dec_gate = self.dec_gate.eval()
            self.mem_gate = self.mem_gate.eval()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
        if (config.noam):
            self.optimizer = NoamOpt(
                config.hidden_dim, 1, 4000,
                torch.optim.Adam(self.parameters(),
                                 lr=0,
                                 betas=(0.9, 0.98),
                                 eps=1e-9))
        if config.use_sgd:
            self.optimizer = torch.optim.SGD(self.parameters(), lr=config.lr)
        if model_file_path is not None:
            print("loading weights")
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            print("LOSS", state['current_loss'])
            self.encoder.load_state_dict(state['encoder_state_dict'])
            self.encoder2decoder.load_state_dict(
                state['encoder2decoder_state_dict'])
            self.decoder.load_state_dict(state['decoder_state_dict'])
            self.generator.load_state_dict(state['generator_dict'])
            self.embedding.load_state_dict(state['embedding_dict'])
            self.memory.load_state_dict(state['memory_dict'])
            self.dec_gate.load_state_dict(state['dec_gate_dict'])
            self.mem_gate.load_state_dict(state['mem_gate_dict'])
            if (load_optim):
                self.optimizer.load_state_dict(state['optimizer'])

        if (config.USE_CUDA):
            self.encoder = self.encoder.cuda()
            self.encoder2decoder = self.encoder2decoder.cuda()
            self.decoder = self.decoder.cuda()
            self.generator = self.generator.cuda()
            self.criterion = self.criterion.cuda()
            self.embedding = self.embedding.cuda()
            self.memory = self.memory.cuda()
            self.dec_gate = self.dec_gate.cuda()
            self.mem_gate = self.mem_gate.cuda()
        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""
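
NoamOpt here wraps Adam (lr=0, betas=(0.9, 0.98), eps=1e-9) with the warmup-then-decay learning-rate schedule from "Attention Is All You Need"; a standalone sketch of the rate such a wrapper typically computes (noam_rate is an illustrative name, and the exact NoamOpt internals are assumed rather than taken from this repo):

def noam_rate(step, model_size, factor=1.0, warmup=4000):
    # lr = factor * model_size^-0.5 * min(step^-0.5, step * warmup^-1.5)
    step = max(step, 1)
    return factor * (model_size ** -0.5) * min(step ** -0.5, step * warmup ** -1.5)

# rises roughly linearly for ~warmup steps, then decays as 1/sqrt(step)
print(noam_rate(100, 300), noam_rate(4000, 300), noam_rate(40000, 300))
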
Example #8
class CvaeTrans(nn.Module):
    def __init__(self,
                 vocab,
                 emo_number,
                 model_file_path=None,
                 is_eval=False,
                 load_optim=False):
        super(CvaeTrans, self).__init__()
        self.vocab = vocab
        self.vocab_size = vocab.n_words

        self.embedding = share_embedding(self.vocab, config.pretrain_emb)
        self.encoder = Encoder(config.emb_dim,
                               config.hidden_dim,
                               num_layers=config.hop,
                               num_heads=config.heads,
                               total_key_depth=config.depth,
                               total_value_depth=config.depth,
                               filter_size=config.filter,
                               universal=config.universal)

        self.r_encoder = Encoder(config.emb_dim,
                                 config.hidden_dim,
                                 num_layers=config.hop,
                                 num_heads=config.heads,
                                 total_key_depth=config.depth,
                                 total_value_depth=config.depth,
                                 filter_size=config.filter,
                                 universal=config.universal)
        self.decoder = Decoder(config.emb_dim,
                               hidden_size=config.hidden_dim,
                               num_layers=config.hop,
                               num_heads=config.heads,
                               total_key_depth=config.depth,
                               total_value_depth=config.depth,
                               filter_size=config.filter)
        self.latent_layer = Latent(is_eval)
        self.bow = SoftmaxOutputLayer(config.hidden_dim, self.vocab_size)
        if config.multitask:
            self.emo = SoftmaxOutputLayer(config.hidden_dim, emo_number)
            self.emo_criterion = nn.NLLLoss()
        self.generator = Generator(config.hidden_dim, self.vocab_size)

        if config.weight_sharing:
            # Share the weight matrix between target word embedding & the final logit dense layer
            self.generator.proj.weight = self.embedding.lut.weight

        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)

        if model_file_path is not None:
            print("loading weights")
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            self.encoder.load_state_dict(state['encoder_state_dict'])
            self.r_encoder.load_state_dict(state['r_encoder_state_dict'])
            self.decoder.load_state_dict(state['decoder_state_dict'])
            self.generator.load_state_dict(state['generator_dict'])
            self.embedding.load_state_dict(state['embedding_dict'])
            self.latent_layer.load_state_dict(state['latent_dict'])
            self.bow.load_state_dict(state['bow'])
        if (config.USE_CUDA):
            self.cuda()
        if is_eval:
            self.eval()
        else:
            self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
            if (config.noam):
                self.optimizer = NoamOpt(
                    config.hidden_dim, 1, 8000,
                    torch.optim.Adam(self.parameters(),
                                     lr=0,
                                     betas=(0.9, 0.98),
                                     eps=1e-9))
            if (load_optim):
                self.optimizer.load_state_dict(state['optimizer'])
                if config.USE_CUDA:
                    for state in self.optimizer.state.values():
                        for k, v in state.items():
                            if isinstance(v, torch.Tensor):
                                state[k] = v.cuda()
        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""

    def save_model(self, running_avg_ppl, iter, f1_g, f1_b, ent_g, ent_b):

        state = {
            'iter': iter,
            'encoder_state_dict': self.encoder.state_dict(),
            'r_encoder_state_dict': self.r_encoder.state_dict(),
            'decoder_state_dict': self.decoder.state_dict(),
            'generator_dict': self.generator.state_dict(),
            'embedding_dict': self.embedding.state_dict(),
            'latent_dict': self.latent_layer.state_dict(),
            'bow': self.bow.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_ppl
        }
        model_save_path = os.path.join(
            self.model_dir,
            'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(
                iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b))
        self.best_path = model_save_path
        torch.save(state, model_save_path)

    def train_one_batch(self, batch, iter, train=True):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)
        dec_batch, _, _, _, _ = get_output_from_batch(batch)

        if (config.noam):
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        ## Response encode
        mask_res = batch["posterior_batch"].data.eq(
            config.PAD_idx).unsqueeze(1)
        posterior_mask = self.embedding(batch["posterior_mask"])
        r_encoder_outputs = self.r_encoder(
            self.embedding(batch["posterior_batch"]) + posterior_mask,
            mask_res)
        ## Encode
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        emb_mask = self.embedding(batch["input_mask"])
        encoder_outputs = self.encoder(
            self.embedding(enc_batch) + emb_mask, mask_src)
        #latent variable
        if config.model == "cvaetrs":
            kld_loss, z = self.latent_layer(encoder_outputs[:, 0],
                                            r_encoder_outputs[:, 0],
                                            train=True)

        meta = self.embedding(batch["program_label"])
        if config.dataset == "empathetic":
            meta = meta - meta  # zero out the program-label embedding for the empathetic dataset
        # Decode
        sos_token = torch.LongTensor([config.SOS_idx] *
                                     enc_batch.size(0)).unsqueeze(1)
        if config.USE_CUDA: sos_token = sos_token.cuda()

        dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]),
                                    1)  # (bsz, tgt_len)
        mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
        input_vector = self.embedding(dec_batch_shift)
        if config.model == "cvaetrs":
            input_vector[:, 0] = input_vector[:, 0] + z + meta
        else:
            input_vector[:, 0] = input_vector[:, 0] + meta
        pre_logit, attn_dist = self.decoder(input_vector, encoder_outputs,
                                            (mask_src, mask_trg))

        ## compute output dist
        logit = self.generator(
            pre_logit,
            attn_dist,
            enc_batch_extend_vocab if config.pointer_gen else None,
            extra_zeros,
            attn_dist_db=None)
        ## loss: NLL if pointer-gen else cross entropy
        loss_rec = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                                  dec_batch.contiguous().view(-1))
        if config.model == "cvaetrs":
            z_logit = self.bow(z + meta)  # [batch_size, vocab_size]
            z_logit = z_logit.unsqueeze(1).repeat(1, logit.size(1), 1)
            loss_aux = self.criterion(
                z_logit.contiguous().view(-1, z_logit.size(-1)),
                dec_batch.contiguous().view(-1))
            if config.multitask:
                emo_logit = self.emo(encoder_outputs[:, 0])
                emo_loss = self.emo_criterion(emo_logit,
                                              batch["program_label"] - 9)
            #kl_weight = min(iter/config.full_kl_step, 0.28) if config.full_kl_step >0 else 1.0
            kl_weight = min(
                math.tanh(6 * iter / config.full_kl_step - 3) + 1, 1)
            loss = loss_rec + config.kl_ceiling * kl_weight * kld_loss + config.aux_ceiling * loss_aux
            if config.multitask:
                loss = loss_rec + config.kl_ceiling * kl_weight * kld_loss + config.aux_ceiling * loss_aux + emo_loss
            aux = loss_aux.item()
            elbo = loss_rec + kld_loss
        else:
            loss = loss_rec
            elbo = loss_rec
            kld_loss = torch.Tensor([0])
            aux = 0
            if config.multitask:
                emo_logit = self.emo(encoder_outputs[:, 0])
                emo_loss = self.emo_criterion(emo_logit,
                                              batch["program_label"] - 9)
                loss = loss_rec + emo_loss
        if (train):
            loss.backward()
            # clip gradient
            nn.utils.clip_grad_norm_(self.parameters(), config.max_grad_norm)
            self.optimizer.step()

        return loss_rec.item(), math.exp(min(
            loss_rec.item(), 100)), kld_loss.item(), aux, elbo.item()

    def decoder_greedy(self, batch, max_dec_step=50):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)

        ## Encode
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        emb_mask = self.embedding(batch["input_mask"])
        meta = self.embedding(batch["program_label"])
        if config.dataset == "empathetic":
            meta = meta - meta  # zero out the program-label embedding for the empathetic dataset
        encoder_outputs = self.encoder(
            self.embedding(enc_batch) + emb_mask, mask_src)
        if config.model == "cvaetrs":
            kld_loss, z = self.latent_layer(encoder_outputs[:, 0],
                                            None,
                                            train=False)

        ys = torch.ones(enc_batch.shape[0], 1).fill_(config.SOS_idx).long()
        if config.USE_CUDA:
            ys = ys.cuda()
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

        decoded_words = []
        for i in range(max_dec_step + 1):
            input_vector = self.embedding(ys)
            if config.model == "cvaetrs":
                input_vector[:, 0] = input_vector[:, 0] + z + meta
            else:
                input_vector[:, 0] = input_vector[:, 0] + meta
            out, attn_dist = self.decoder(input_vector, encoder_outputs,
                                          (mask_src, mask_trg))

            prob = self.generator(out,
                                  attn_dist,
                                  enc_batch_extend_vocab,
                                  extra_zeros,
                                  attn_dist_db=None)
            _, next_word = torch.max(prob[:, -1], dim=1)
            decoded_words.append([
                '<EOS>' if ni.item() == config.EOS_idx else
                self.vocab.index2word[ni.item()] for ni in next_word.view(-1)
            ])

            if config.USE_CUDA:
                ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)
                ys = ys.cuda()
            else:
                ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)

            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        sent = []
        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>': break
                else: st += e + ' '
            sent.append(st)
        return sent
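
The tanh-based annealing in train_one_batch ramps the KL weight smoothly from near 0 toward 1 as iter approaches config.full_kl_step, which keeps the KL term from collapsing the latent variable early in training. A small standalone sketch of the same schedule (the value 10000 for full_kl_step is only for illustration):

import math

def kl_anneal_weight(it, full_kl_step):
    # ~0 at the start, saturates at 1 around full_kl_step / 2
    return min(math.tanh(6 * it / full_kl_step - 3) + 1, 1)

for it in (0, 2500, 5000, 7500, 10000):
    print(it, round(kl_anneal_weight(it, full_kl_step=10000), 4))
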
Example #9
class PGNet(nn.Module):
    '''
    refer: https://github.com/atulkum/pointer_summarizer
    '''
    def __init__(self,
                 vocab,
                 model_file_path=None,
                 is_eval=False,
                 load_optim=False):
        super(PGNet, self).__init__()
        self.vocab = vocab
        self.vocab_size = vocab.n_words

        self.embedding = share_embedding(self.vocab, config.pretrain_emb)
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.reduce_state = ReduceState()

        self.generator = Generator(config.rnn_hidden_dim, self.vocab_size)

        if config.weight_sharing:
            # Share the weight matrix between target word embedding & the final logit dense layer
            self.generator.proj.weight = self.embedding.lut.weight

        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)
        if config.label_smoothing:
            self.criterion = LabelSmoothing(size=self.vocab_size,
                                            padding_idx=config.PAD_idx,
                                            smoothing=0.1)
            self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx)

        self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
        if config.noam:
            self.optimizer = NoamOpt(
                config.rnn_hidden_dim, 1, 8000,
                torch.optim.Adam(self.parameters(),
                                 lr=0,
                                 betas=(0.9, 0.98),
                                 eps=1e-9))

        if model_file_path is not None:
            print("loading weights")
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            self.encoder.load_state_dict(state['encoder_state_dict'])
            self.decoder.load_state_dict(state['decoder_state_dict'])
            self.generator.load_state_dict(state['generator_dict'])
            self.embedding.load_state_dict(state['embedding_dict'])
            if load_optim:
                self.optimizer.load_state_dict(state['optimizer'])
            self.eval()

        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""

    def save_model(self, running_avg_ppl, iter, f1_g, f1_b, ent_g, ent_b):

        state = {
            'iter': iter,
            'encoder_state_dict': self.encoder.state_dict(),
            'decoder_state_dict': self.decoder.state_dict(),
            'generator_dict': self.generator.state_dict(),
            'embedding_dict': self.embedding.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_ppl
        }
        model_save_path = os.path.join(
            self.model_dir,
            'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(
                iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b))
        self.best_path = model_save_path
        torch.save(state, model_save_path)

    def train_one_batch(self, batch, iter, train=True):
        enc_batch = batch["review_batch"]
        enc_lens = batch["review_length"]
        enc_batch_extend_vocab = batch["review_ext_batch"]
        oovs = batch["oovs"]
        max_oov_length = len(
            sorted(oovs, key=lambda i: len(i), reverse=True)[0])

        dec_batch = batch["tags_batch"]
        dec_ext_batch = batch["tags_ext_batch"]
        max_tgt_len = dec_batch.size(1)  # number of target time steps (dec_batch is (bsz, tgt_len))

        if config.noam:
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        ## Embedding - context
        mask_src = enc_batch
        mask_src = ~(mask_src.data.eq(config.PAD_idx))
        # emb_mask = self.embedding(batch["mask_context"])
        # src_emb = self.embedding(enc_batch)+emb_mask
        src_emb = self.embedding(enc_batch)
        encoder_outputs, encoder_feature, encoder_hidden = self.encoder(
            src_emb, enc_lens)

        # reduce bidirectional hidden to one hidden (h and c)
        s_t_1 = self.reduce_state(encoder_hidden)  # 1 x b x hidden_dim
        c_t_1 = Variable(
            torch.zeros((enc_batch.size(0),
                         2 * config.rnn_hidden_dim))).to(config.device)
        coverage = Variable(torch.zeros(enc_batch.size())).to(config.device)
        extra_zeros = Variable(torch.zeros(
            (enc_batch.size(0), max_oov_length))).to(config.device)

        sos_token = torch.LongTensor([config.SOS_idx] *
                                     enc_batch.size(0)).unsqueeze(1).to(
                                         config.device)  # (bsz, 1)
        dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]),
                                    1)  # (bsz, tgt_len)
        dec_batch_embd = self.embedding(dec_batch_shift)

        step_losses = []
        step_loss_ppls = 0
        for di in range(max_tgt_len):
            y_t_1 = dec_batch_embd[:, di, :]
            logit, s_t_1, c_t_1, attn_dist, next_coverage, p_gen = self.decoder(
                y_t_1, s_t_1, encoder_outputs, encoder_feature, mask_src,
                c_t_1, coverage, di)

            logit = self.generator(
                logit, attn_dist,
                enc_batch_extend_vocab if config.pointer_gen else None,
                extra_zeros, 1, p_gen)

            if config.pointer_gen:
                step_loss = self.criterion(
                    logit.contiguous().view(-1, logit.size(-1)),
                    dec_ext_batch[:, di].contiguous().view(-1))
            else:
                step_loss = self.criterion(
                    logit.contiguous().view(-1, logit.size(-1)),
                    dec_batch[:, di].contiguous().view(-1))

            if config.label_smoothing:
                step_loss_ppl = self.criterion_ppl(
                    logit.contiguous().view(-1, logit.size(-1)),
                    dec_batch[:, di].contiguous().view(-1))
                step_loss_ppls += step_loss_ppl

            if config.is_coverage:
                # coverage loss
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage),
                                               1)
                # loss sum
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                # update coverage
                coverage = next_coverage

            step_losses.append(step_loss)
        if config.is_coverage:
            sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
            batch_avg_loss = sum_losses / batch['tags_length'].float()
            loss = torch.mean(batch_avg_loss)
        else:
            loss = sum(step_losses) / max_tgt_len
        if config.label_smoothing:
            loss_ppl = (step_loss_ppls / max_tgt_len).item()

        if train:
            loss.backward()
            self.optimizer.step()

        if config.label_smoothing:
            return loss_ppl, math.exp(min(loss_ppl, 100)), 0, 0
        else:
            return loss.item(), math.exp(min(loss.item(), 100)), 0, 0

    def decoder_greedy(self, batch, max_dec_step=30):
        enc_batch = batch["review_batch"]
        enc_lens = batch["review_length"]
        enc_batch_extend_vocab = batch["review_ext_batch"]
        oovs = batch["oovs"]
        max_oov_length = len(
            sorted(oovs, key=lambda i: len(i), reverse=True)[0])

        dec_batch = batch["tags_batch"]
        dec_ext_batch = batch["tags_ext_batch"]
        max_tgt_len = dec_batch.size(1)  # number of target time steps (dec_batch is (bsz, tgt_len))

        ## Embedding - context
        mask_src = enc_batch
        mask_src = ~(mask_src.data.eq(config.PAD_idx))
        # emb_mask = self.embedding(batch["mask_context"])
        # src_emb = self.embedding(enc_batch)+emb_mask
        src_emb = self.embedding(enc_batch)
        encoder_outputs, encoder_feature, encoder_hidden = self.encoder(
            src_emb, enc_lens)

        # reduce bidirectional hidden to one hidden (h and c)
        s_t_1 = self.reduce_state(encoder_hidden)  # 1 x b x hidden_dim
        c_t_1 = Variable(
            torch.zeros((enc_batch.size(0),
                         2 * config.rnn_hidden_dim))).to(config.device)
        coverage = Variable(torch.zeros(enc_batch.size())).to(config.device)
        extra_zeros = Variable(torch.zeros(
            (enc_batch.size(0), max_oov_length))).to(config.device)

        # ys = torch.ones(1, 1).fill_(config.SOS_idx).long()
        ys = torch.zeros(enc_batch.size(0)).fill_(config.SOS_idx).long().to(
            config.device)  # at test time the batch size is set to 1
        decoded_words = []
        for i in range(max_dec_step + 1):
            logit, s_t_1, c_t_1, attn_dist, next_coverage, p_gen = self.decoder(
                self.embedding(ys), s_t_1, encoder_outputs, encoder_feature,
                mask_src, c_t_1, coverage, i)
            prob = self.generator(
                logit, attn_dist,
                enc_batch_extend_vocab if config.pointer_gen else None,
                extra_zeros, 1, p_gen)

            _, next_word = torch.max(prob, dim=1)  # bsz=1
            cur_words = []
            for i_batch, ni in enumerate(next_word.view(-1)):
                if ni.item() == config.EOS_idx:
                    cur_words.append('<EOS>')
                elif ni.item() in self.vocab.index2word:
                    cur_words.append(self.vocab.index2word[ni.item()])
                else:
                    cur_words.append(oovs[i_batch][ni.item() -
                                                   self.vocab.n_words])
            decoded_words.append(cur_words)
            next_word = next_word.data[0]

            if next_word.item() not in self.vocab.index2word:
                next_word = torch.tensor(config.UNK_idx)

            ys = torch.zeros(enc_batch.size(0)).long().fill_(next_word).to(
                config.device)

        sent = []
        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>':
                    break
                else:
                    st += e + ' '
            sent.append(st)
        return sent
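
PGNet's coverage mechanism penalizes attention that keeps returning to source positions that were already attended to. A minimal per-step sketch under the usual pointer-summarizer formulation, where coverage is accumulated as the running sum of past attention distributions (coverage_step and the cov_loss_wt default are illustrative, not this repo's API):

import torch

def coverage_step(attn_dist, coverage, cov_loss_wt=1.0):
    # attn_dist, coverage: (bsz, src_len)
    # the penalty is large only where current attention overlaps accumulated coverage
    step_cov_loss = torch.sum(torch.min(attn_dist, coverage), dim=1)  # (bsz,)
    new_coverage = coverage + attn_dist                               # running sum of attention
    return cov_loss_wt * step_cov_loss, new_coverage

Example #10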
class Transformer(nn.Module):
    def __init__(self,
                 vocab,
                 emo_number,
                 model_file_path=None,
                 is_eval=False,
                 load_optim=False):
        super(Transformer, self).__init__()
        self.vocab = vocab
        self.vocab_size = vocab.n_words
        self.embedding = share_embedding(self.vocab, config.pretrain_emb)
        self.encoder = Encoder(config.emb_dim,
                               config.hidden_dim,
                               num_layers=config.hop,
                               num_heads=config.heads,
                               total_key_depth=config.depth,
                               total_value_depth=config.depth,
                               filter_size=config.filter,
                               universal=config.universal)

        self.decoder = Decoder(config.emb_dim,
                               hidden_size=config.hidden_dim,
                               num_layers=config.hop,
                               num_heads=config.heads,
                               total_key_depth=config.depth,
                               total_value_depth=config.depth,
                               filter_size=config.filter)

        self.generator = Generator(config.hidden_dim, self.vocab_size)

        if config.weight_sharing:
            # Share the weight matrix between target word embedding & the final logit dense layer
            self.generator.proj.weight = self.embedding.lut.weight

        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)

        self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
        if (config.noam):
            self.optimizer = NoamOpt(
                config.hidden_dim, 1, 8000,
                torch.optim.Adam(self.parameters(),
                                 lr=0,
                                 betas=(0.9, 0.98),
                                 eps=1e-9))

        if model_file_path is not None:
            print("loading weights")
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            self.encoder.load_state_dict(state['encoder_state_dict'])
            self.decoder.load_state_dict(state['decoder_state_dict'])
            self.generator.load_state_dict(state['generator_dict'])
            self.embedding.load_state_dict(state['embedding_dict'])
            if (load_optim):
                self.optimizer.load_state_dict(state['optimizer'])
            self.eval()

        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""

    def save_model(self, running_avg_ppl, iter, f1_g, f1_b, ent_g, ent_b):

        state = {
            'iter': iter,
            'encoder_state_dict': self.encoder.state_dict(),
            'decoder_state_dict': self.decoder.state_dict(),
            'generator_dict': self.generator.state_dict(),
            'embedding_dict': self.embedding.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_ppl
        }
        model_save_path = os.path.join(
            self.model_dir,
            'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(
                iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b))
        self.best_path = model_save_path
        torch.save(state, model_save_path)

    def train_one_batch(self, batch, iter, train=True):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)
        dec_batch, _, _, _, _ = get_output_from_batch(batch)

        if (config.noam):
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        ## Encode
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        meta = self.embedding(batch["program_label"])
        emb_mask = self.embedding(batch["input_mask"])
        encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)

        # Decode
        sos_token = torch.LongTensor([config.SOS_idx] *
                                     enc_batch.size(0)).unsqueeze(1)
        if config.USE_CUDA: sos_token = sos_token.cuda()
        dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1)

        mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
        pre_logit, attn_dist = self.decoder(
            self.embedding(dec_batch_shift) + meta.unsqueeze(1),
            encoder_outputs, (mask_src, mask_trg))
        #+meta.unsqueeze(1)
        ## compute output dist
        logit = self.generator(
            pre_logit,
            attn_dist,
            enc_batch_extend_vocab if config.pointer_gen else None,
            extra_zeros,
            attn_dist_db=None)
        #logit = F.log_softmax(logit,dim=-1) #fix the name later
        ## loss: NLL if pointer-gen else cross entropy
        loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                              dec_batch.contiguous().view(-1))

        if (train):
            loss.backward()
            self.optimizer.step()

        return loss.item(), math.exp(min(loss.item(), 100)), 0

    def compute_act_loss(self, module):
        R_t = module.remainders
        N_t = module.n_updates
        p_t = R_t + N_t
        avg_p_t = torch.sum(torch.sum(p_t, dim=1) / p_t.size(1)) / p_t.size(0)
        loss = config.act_loss_weight * avg_p_t.item()
        return loss

    def decoder_greedy(self, batch, max_dec_step=50):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        emb_mask = self.embedding(batch["input_mask"])
        meta = self.embedding(batch["program_label"])
        encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)

        ys = torch.ones(enc_batch.shape[0], 1).fill_(config.SOS_idx).long()
        if config.USE_CUDA:
            ys = ys.cuda()
        # print('=====================ys========================')
        # print(ys)

        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        decoded_words = []
        for i in range(max_dec_step + 1):

            out, attn_dist = self.decoder(
                self.embedding(ys) + meta.unsqueeze(1), encoder_outputs,
                (mask_src, mask_trg))

            prob = self.generator(out,
                                  attn_dist,
                                  enc_batch_extend_vocab,
                                  extra_zeros,
                                  attn_dist_db=None)
            _, next_word = torch.max(prob[:, -1], dim=1)
            # print('=====================next_word1========================')
            # print(next_word)
            decoded_words.append([
                '<EOS>' if ni.item() == config.EOS_idx else
                self.vocab.index2word[ni.item()] for ni in next_word.view(-1)
            ])
            #next_word = next_word.data[0]
            # print('=====================next_word2========================')
            # print(next_word)
            if config.USE_CUDA:
                # print('=====================shape========================')
                # print(ys.shape, next_word.shape)
                ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)
                ys = ys.cuda()
            else:
                ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)

            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
            # print('=====================new_ys========================')
            # print(ys)
        sent = []
        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>': break
                else: st += e + ' '
            sent.append(st)
        return sent

    def decoder_greedy_po(self, batch, max_dec_step=50):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        emb_mask = self.embedding(batch["input_mask"])
        meta = self.embedding(batch["program_label"])
        encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)

        ys = torch.ones(enc_batch.shape[0], 1).fill_(config.SOS_idx).long()
        if config.USE_CUDA:
            ys = ys.cuda()
        # print('=====================ys========================')
        # print(ys)

        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        decoded_words = []
        for i in range(max_dec_step + 1):

            out, attn_dist = self.decoder(
                self.embedding(ys) + meta.unsqueeze(1), encoder_outputs,
                (mask_src, mask_trg))

            prob = self.generator(out,
                                  attn_dist,
                                  enc_batch_extend_vocab,
                                  extra_zeros,
                                  attn_dist_db=None)
            _, next_word = torch.max(prob[:, -1], dim=1)
            # print('=====================next_word1========================')
            # print(next_word)
            decoded_words.append([
                '<EOS>' if ni.item() == config.EOS_idx else
                self.vocab.index2word[ni.item()] for ni in next_word.view(-1)
            ])
            #next_word = next_word.data[0]
            # print('=====================next_word2========================')
            # print(next_word)
            if config.USE_CUDA:
                # print('=====================shape========================')
                # print(ys.shape, next_word.shape)
                ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)
                ys = ys.cuda()
            else:
                ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)

            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
            # print('=====================new_ys========================')
            # print(ys)
        sent = []
        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>': break
                else: st += e + ' '
            sent.append(st)
        return sent
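
The sentence assembly at the end of every decoder_greedy above follows the same pattern: transpose the per-step token lists into per-example rows and cut each row at the first <EOS>. A standalone helper that does the same thing (assemble_sentences is an illustrative name; it joins with single spaces rather than leaving a trailing one):

import numpy as np

def assemble_sentences(decoded_words):
    # decoded_words: list over time steps, each entry a list of per-example tokens
    sents = []
    for row in np.transpose(decoded_words):  # iterate over batch elements
        tokens = []
        for tok in row:
            if tok == '<EOS>':
                break
            tokens.append(tok)
        sents.append(' '.join(tokens))
    return sents

print(assemble_sentences([['the', 'a'], ['cat', '<EOS>'], ['<EOS>', 'pad']]))
# -> ['the cat', 'a']
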
Example #11
writer = SummaryWriter(log_dir=config.save_path)
# Build model, optimizer, and set states
if not (config.load_frompretrain == 'None'):
    meta_net = Seq2SPG(p.vocab,
                       model_file_path=config.load_frompretrain,
                       is_eval=False)
else:
    meta_net = Seq2SPG(p.vocab)
if config.meta_optimizer == 'sgd':
    meta_optimizer = torch.optim.SGD(meta_net.parameters(), lr=config.meta_lr)
elif config.meta_optimizer == 'adam':
    meta_optimizer = torch.optim.Adam(meta_net.parameters(), lr=config.meta_lr)
elif config.meta_optimizer == 'noam':
    meta_optimizer = NoamOpt(
        config.hidden_dim, 1, 4000,
        torch.optim.Adam(meta_net.parameters(),
                         lr=0,
                         betas=(0.9, 0.98),
                         eps=1e-9))
else:
    raise ValueError("unknown config.meta_optimizer: {}".format(config.meta_optimizer))

meta_batch_size = config.meta_batch_size
tasks = p.get_personas('train')
steps = (len(tasks) //
         meta_batch_size) + int(len(tasks) % meta_batch_size != 0)

# meta early stop
patience = 10
if config.fix_dialnum_train:
    patience = 100
best_loss = 10000000
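
The steps computation above is a ceiling division of the number of persona tasks by the meta batch size; a quick check:

import math

tasks, meta_batch_size = list(range(23)), 5
steps = (len(tasks) // meta_batch_size) + int(len(tasks) % meta_batch_size != 0)
assert steps == math.ceil(len(tasks) / meta_batch_size) == 5
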
Example #12
class VAE(nn.Module):
    def __init__(self, vocab, model_file_path=None, is_eval=False, load_optim=False):
        super(VAE, self).__init__()
        self.vocab = vocab
        self.vocab_size = vocab.n_words

        self.embedding = share_embedding(self.vocab,config.preptrained)
        self.encoder = nn.LSTM(config.emb_dim, config.hidden_dim, config.hop, bidirectional=False, batch_first=True,
                               dropout=0.2)
        self.encoder_r = nn.LSTM(config.emb_dim, config.hidden_dim, config.hop, bidirectional=False, batch_first=True,
                               dropout=0.2)
        self.represent = R_MLP(2 * config.hidden_dim, 68)
        self.prior = P_MLP(config.hidden_dim, 68)
        self.mlp_b = nn.Linear(config.hidden_dim + 68, self.vocab_size)
        self.encoder2decoder = nn.Linear(
            config.hidden_dim + 68,
            config.hidden_dim)
        self.decoder = LSTMAttentionDot(config.emb_dim, config.hidden_dim, batch_first=True)   
        self.generator = Generator(config.hidden_dim,self.vocab_size)

        if config.weight_sharing:
            # Share the weight matrix between target word embedding & the final logit dense layer
            self.generator.proj.weight = self.embedding.weight

        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)
        if (config.label_smoothing):
            self.criterion = LabelSmoothing(size=self.vocab_size, padding_idx=config.PAD_idx, smoothing=0.1)
            self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx)
        if is_eval:
            self.encoder = self.encoder.eval()
            self.encoder_r = self.encoder_r.eval()
            self.represent = self.represent.eval()
            self.prior = self.prior.eval()
            self.mlp_b = self.mlp_b.eval()
            self.encoder2decoder = self.encoder2decoder.eval()
            self.decoder = self.decoder.eval()
            self.generator = self.generator.eval()
            self.embedding = self.embedding.eval()

    
        self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
        if(config.noam):
            self.optimizer = NoamOpt(config.hidden_dim, 1, 4000, torch.optim.Adam(self.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
        if config.use_sgd:
            self.optimizer = torch.optim.SGD(self.parameters(), lr=config.lr)
        if model_file_path is not None:
            print("loading weights")
            state = torch.load(model_file_path, map_location= lambda storage, location: storage)
            print("LOSS",state['current_loss'])
            self.encoder.load_state_dict(state['encoder_state_dict'])
            self.encoder_r.load_state_dict(state['encoder_r_state_dict'])
            self.represent.load_state_dict(state['represent_state_dict'])
            self.prior.load_state_dict(state['prior_state_dict'])
            self.mlp_b.load_state_dict(state['mlp_b_state_dict'])
            self.encoder2decoder.load_state_dict(state['encoder2decoder_state_dict'])
            self.decoder.load_state_dict(state['decoder_state_dict'])
            self.generator.load_state_dict(state['generator_dict'])
            self.embedding.load_state_dict(state['embedding_dict'])
            if (load_optim):
                self.optimizer.load_state_dict(state['optimizer'])

        if (config.USE_CUDA):
            self.encoder = self.encoder.cuda()
            self.encoder_r = self.encoder_r.cuda()
            self.represent = self.represent.cuda()
            self.prior = self.prior.cuda()
            self.mlp_b = self.mlp_b.cuda()
            self.encoder2decoder = self.encoder2decoder.cuda()
            self.decoder = self.decoder.cuda()
            self.generator = self.generator.cuda()
            self.criterion = self.criterion.cuda()
            self.embedding = self.embedding.cuda()
        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""
    
    def save_model(self, running_avg_ppl, iter, f1_g,f1_b,ent_g,ent_b, log=False, d="save/paml_model_sim"):
        state = {
            'iter': iter,
            'encoder_state_dict': self.encoder.state_dict(),
            'encoder_r_state_dict': self.encoder_r.state_dict(),
            'represent_state_dict': self.represent.state_dict(),
            'prior_state_dict': self.prior.state_dict(),
            'mlp_b_state_dict': self.mlp_b.state_dict(),
            'encoder2decoder_state_dict': self.encoder2decoder.state_dict(),
            'decoder_state_dict': self.decoder.state_dict(),
            'generator_dict': self.generator.state_dict(),
            'embedding_dict': self.embedding.state_dict(),
            #'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_ppl
        }
        if log:
            model_save_path = os.path.join(d, 'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(iter,running_avg_ppl,f1_g,f1_b,ent_g,ent_b) )
        else:
            model_save_path = os.path.join(self.model_dir, 'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(iter,running_avg_ppl,f1_g,f1_b,ent_g,ent_b) )
        self.best_path = model_save_path
        torch.save(state, model_save_path)
    
    def get_state(self, batch):
        """Get cell states and hidden states."""
        batch_size = batch.size(0) \
            if self.encoder.batch_first else batch.size(1)
        h0_encoder = Variable(torch.zeros(
            self.encoder.num_layers,
            batch_size,
            config.hidden_dim
        ), requires_grad=False)
        c0_encoder = Variable(torch.zeros(
            self.encoder.num_layers,
            batch_size,
            config.hidden_dim
        ), requires_grad=False)

        return h0_encoder.cuda(), c0_encoder.cuda()
    
    def train_one_batch(self, batch, train=True):
        ## pad and other stuff
        enc_batch, _, enc_lens, enc_batch_extend_vocab, extra_zeros, _, _, _ = get_input_from_batch(batch)
        dec_batch, _, _, _, _, _ = get_output_from_batch(batch)
        
        if(config.noam):
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        ## Encode
        self.h0_encoder, self.c0_encoder = self.get_state(enc_batch)
        src_h, (src_h_t, src_c_t) = self.encoder(
            self.embedding(enc_batch), (self.h0_encoder, self.c0_encoder))
        h_t = src_h_t[-1]
        c_t = src_c_t[-1]
        self.h0_encoder_r, self.c0_encoder_r = self.get_state(dec_batch)
        src_h_r, (src_h_t_r, src_c_t_r) = self.encoder_r(
            self.embedding(dec_batch), (self.h0_encoder_r, self.c0_encoder_r))
        h_t_r = src_h_t_r[-1]
        c_t_r = src_c_t_r[-1]
        
        #sample and reparameter
        z_sample, mu, var = self.represent(torch.cat((h_t_r, h_t), 1))
        p_z_sample, p_mu, p_var = self.prior(h_t)
        
        # Decode
        decoder_init_state = nn.Tanh()(self.encoder2decoder(torch.cat((z_sample, h_t), 1)))
        
        sos_token = torch.LongTensor([config.SOS_idx] * enc_batch.size(0)).unsqueeze(1)
        if config.USE_CUDA: sos_token = sos_token.cuda()
        dec_batch_shift = torch.cat((sos_token,dec_batch[:, :-1]),1)
        target_embedding = self.embedding(dec_batch_shift)
        ctx = src_h.transpose(0, 1)
        trg_h, (_, _) = self.decoder(
            target_embedding,
            (decoder_init_state, c_t),
            ctx    
        )
        pre_logit = trg_h
        logit = self.generator(pre_logit)
        
        ## loss: NLL if ptr else cross entropy
        re_loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1))
        kl_losses = 0.5 * torch.sum(torch.exp(var - p_var) + (mu - p_mu) ** 2 / torch.exp(p_var) - 1. - var + p_var, 1)
        kl_loss = torch.mean(kl_losses)
        latent_logit = self.mlp_b(torch.cat((z_sample, h_t), 1)).unsqueeze(1)
        latent_logit = F.log_softmax(latent_logit,dim=-1)
        latent_logits = latent_logit.repeat(1, logit.size(1), 1)
        bow_loss = self.criterion(latent_logits.contiguous().view(-1, latent_logits.size(-1)), dec_batch.contiguous().view(-1))
        loss = re_loss + 0.48 * kl_loss + bow_loss
        if(train):
            loss.backward()
            self.optimizer.step()
        if(config.label_smoothing):
            s_loss = self.criterion_ppl(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1))
        else:
            s_loss = re_loss  # without label smoothing, the reconstruction NLL doubles as the reported loss

        return s_loss.item(), math.exp(min(s_loss.item(), 100)), loss.item(), re_loss.item(), kl_loss.item(), bow_loss.item()
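
The kl_losses term in VAE.train_one_batch is the closed-form KL divergence between two diagonal Gaussians given their means and log-variances, KL(q || p). A standalone sketch with the same formula (gaussian_kl and the mu/logvar argument names are illustrative):

import torch

def gaussian_kl(mu_q, logvar_q, mu_p, logvar_p):
    # KL( N(mu_q, exp(logvar_q)) || N(mu_p, exp(logvar_p)) ), summed over latent dims
    return 0.5 * torch.sum(
        torch.exp(logvar_q - logvar_p)
        + (mu_q - mu_p) ** 2 / torch.exp(logvar_p)
        - 1.0
        - logvar_q
        + logvar_p,
        dim=1,
    )

q_mu, q_lv = torch.zeros(2, 4), torch.zeros(2, 4)
print(gaussian_kl(q_mu, q_lv, q_mu, q_lv))  # zero KL when the distributions match
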
Example #13
    def __init__(self,
                 vocab,
                 decoder_number,
                 model_file_path=None,
                 is_eval=False,
                 load_optim=False):
        super().__init__()
        self.vocab = vocab
        self.vocab_size = vocab.n_words
        self.embedding = share_embedding(self.vocab, config.pretrain_emb)

        self.encoder = Encoder(config.emb_dim,
                               config.hidden_dim,
                               num_layers=config.hop,
                               num_heads=config.heads,
                               total_key_depth=config.depth,
                               total_value_depth=config.depth,
                               filter_size=config.filter,
                               universal=config.universal)
        self.decoder_number = decoder_number

        self.decoder = DecoderContextV(config.emb_dim,
                                       config.hidden_dim,
                                       num_layers=config.hop,
                                       num_heads=config.heads,
                                       total_key_depth=config.depth,
                                       total_value_depth=config.depth,
                                       filter_size=config.filter)

        self.vae_sampler = VAESampling(config.hidden_dim,
                                       config.hidden_dim,
                                       out_dim=300)

        # outputs m
        self.emotion_input_encoder_1 = EmotionInputEncoder(
            config.emb_dim,
            config.hidden_dim,
            num_layers=config.hop,
            num_heads=config.heads,
            total_key_depth=config.depth,
            total_value_depth=config.depth,
            filter_size=config.filter,
            universal=config.universal,
            emo_input=config.emo_input)
        # outputs m~
        self.emotion_input_encoder_2 = EmotionInputEncoder(
            config.emb_dim,
            config.hidden_dim,
            num_layers=config.hop,
            num_heads=config.heads,
            total_key_depth=config.depth,
            total_value_depth=config.depth,
            filter_size=config.filter,
            universal=config.universal,
            emo_input=config.emo_input)

        if config.emo_combine == "att":
            self.cdecoder = ComplexResDecoder(config.emb_dim,
                                              config.hidden_dim,
                                              num_layers=config.hop,
                                              num_heads=config.heads,
                                              total_key_depth=config.depth,
                                              total_value_depth=config.depth,
                                              filter_size=config.filter,
                                              universal=config.universal)

        elif config.emo_combine == "gate":
            self.cdecoder = ComplexResGate(config.emb_dim)

        self.s_weight = nn.Linear(config.hidden_dim,
                                  config.emb_dim,
                                  bias=False)
        self.decoder_key = nn.Linear(config.hidden_dim,
                                     decoder_number,
                                     bias=False)

        # v^T tanh(W E[i] + H c + b)
        method3 = True
        if method3:
            self.e_weight = nn.Linear(config.emb_dim,
                                      config.emb_dim,
                                      bias=True)
            self.v = torch.rand(config.emb_dim, requires_grad=True)
            if config.USE_CUDA: self.v = self.v.cuda()

        self.generator = Generator(config.hidden_dim, self.vocab_size)
        self.emoji_embedding = nn.Embedding(32, config.emb_dim)
        if config.init_emo_emb: self.init_emoji_embedding_with_glove()

        if config.weight_sharing:
            # Share the weight matrix between target word embedding & the final logit dense layer
            self.generator.proj.weight = self.embedding.lut.weight

        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)
        if (config.label_smoothing):
            self.criterion = LabelSmoothing(size=self.vocab_size,
                                            padding_idx=config.PAD_idx,
                                            smoothing=0.1)
            self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx)

        if config.softmax:
            self.attention_activation = nn.Softmax(dim=1)
        else:
            self.attention_activation = nn.Sigmoid()  # nn.Softmax()

        self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
        if (config.noam):
            self.optimizer = NoamOpt(
                config.hidden_dim, 1, 8000,
                torch.optim.Adam(self.parameters(),
                                 lr=0,
                                 betas=(0.9, 0.98),
                                 eps=1e-9))

        if model_file_path is not None:
            print("loading weights")
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            self.load_state_dict(state['model'])
            if (load_optim):
                self.optimizer.load_state_dict(state['optimizer'])
            self.eval()

        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""

        # Added positive emotions
        self.positive_emotions = [
            11, 16, 6, 8, 3, 1, 28, 13, 31, 17, 24, 0, 27
        ]
        self.negative_emotions = [
            9, 4, 2, 22, 14, 30, 29, 25, 15, 10, 23, 19, 18, 21, 7, 20, 5, 26,
            12
        ]
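
The `v^T tanh(W E[i] + H c + b)` comment describes additive (Bahdanau-style) scoring of each emotion embedding E[i] against a context vector c, using parameters shaped like e_weight, s_weight, and v above. A minimal sketch under that reading; how the class actually combines these tensors is not shown here, so treat the combination as an assumption:

import torch
import torch.nn as nn

emb_dim, hidden_dim, n_emotions = 300, 300, 32
e_weight = nn.Linear(emb_dim, emb_dim, bias=True)       # W, b
s_weight = nn.Linear(hidden_dim, emb_dim, bias=False)   # H
v = torch.rand(emb_dim, requires_grad=True)
emoji_embedding = nn.Embedding(n_emotions, emb_dim)     # E

ctx = torch.randn(4, hidden_dim)                        # c: one context vector per example
scores = torch.tanh(
    e_weight(emoji_embedding.weight).unsqueeze(0)       # (1, 32, emb_dim)
    + s_weight(ctx).unsqueeze(1)                        # (4, 1, emb_dim)
) @ v                                                   # (4, 32) attention logits over emotions
print(scores.shape)  # torch.Size([4, 32])
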
Example #14
class Train_MIME(nn.Module):
    '''
    For emotion attention, the randomly sampled emotion embedding is simply passed
    as the query (Q) to a decoder block of the transformer.
    '''
    def __init__(self,
                 vocab,
                 decoder_number,
                 model_file_path=None,
                 is_eval=False,
                 load_optim=False):
        super().__init__()
        self.vocab = vocab
        self.vocab_size = vocab.n_words
        self.embedding = share_embedding(self.vocab, config.pretrain_emb)

        self.encoder = Encoder(config.emb_dim,
                               config.hidden_dim,
                               num_layers=config.hop,
                               num_heads=config.heads,
                               total_key_depth=config.depth,
                               total_value_depth=config.depth,
                               filter_size=config.filter,
                               universal=config.universal)
        self.decoder_number = decoder_number

        self.decoder = DecoderContextV(config.emb_dim,
                                       config.hidden_dim,
                                       num_layers=config.hop,
                                       num_heads=config.heads,
                                       total_key_depth=config.depth,
                                       total_value_depth=config.depth,
                                       filter_size=config.filter)

        self.vae_sampler = VAESampling(config.hidden_dim,
                                       config.hidden_dim,
                                       out_dim=300)

        # outputs m
        self.emotion_input_encoder_1 = EmotionInputEncoder(
            config.emb_dim,
            config.hidden_dim,
            num_layers=config.hop,
            num_heads=config.heads,
            total_key_depth=config.depth,
            total_value_depth=config.depth,
            filter_size=config.filter,
            universal=config.universal,
            emo_input=config.emo_input)
        # outputs m~
        self.emotion_input_encoder_2 = EmotionInputEncoder(
            config.emb_dim,
            config.hidden_dim,
            num_layers=config.hop,
            num_heads=config.heads,
            total_key_depth=config.depth,
            total_value_depth=config.depth,
            filter_size=config.filter,
            universal=config.universal,
            emo_input=config.emo_input)

        if config.emo_combine == "att":
            self.cdecoder = ComplexResDecoder(config.emb_dim,
                                              config.hidden_dim,
                                              num_layers=config.hop,
                                              num_heads=config.heads,
                                              total_key_depth=config.depth,
                                              total_value_depth=config.depth,
                                              filter_size=config.filter,
                                              universal=config.universal)

        elif config.emo_combine == "gate":
            self.cdecoder = ComplexResGate(config.emb_dim)

        self.s_weight = nn.Linear(config.hidden_dim,
                                  config.emb_dim,
                                  bias=False)
        self.decoder_key = nn.Linear(config.hidden_dim,
                                     decoder_number,
                                     bias=False)

        # v^T tanh(W E[i] + H c + b)
        method3 = True
        if method3:
            self.e_weight = nn.Linear(config.emb_dim,
                                      config.emb_dim,
                                      bias=True)
            # Register v as a parameter so it moves with the module's .cuda()/.to()
            # and is actually updated by the optimizer (a plain tensor copied to the
            # GPU with .cuda() becomes a non-leaf and receives no gradient updates).
            self.v = nn.Parameter(torch.rand(config.emb_dim))

        self.generator = Generator(config.hidden_dim, self.vocab_size)
        self.emoji_embedding = nn.Embedding(32, config.emb_dim)
        if config.init_emo_emb: self.init_emoji_embedding_with_glove()

        if config.weight_sharing:
            # Share the weight matrix between target word embedding & the final logit dense layer
            self.generator.proj.weight = self.embedding.lut.weight

        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)
        if (config.label_smoothing):
            self.criterion = LabelSmoothing(size=self.vocab_size,
                                            padding_idx=config.PAD_idx,
                                            smoothing=0.1)
            self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx)

        if config.softmax:
            self.attention_activation = nn.Softmax(dim=1)
        else:
            self.attention_activation = nn.Sigmoid()  # nn.Softmax()

        self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
        if (config.noam):
            self.optimizer = NoamOpt(
                config.hidden_dim, 1, 8000,
                torch.optim.Adam(self.parameters(),
                                 lr=0,
                                 betas=(0.9, 0.98),
                                 eps=1e-9))

        if model_file_path is not None:
            print("loading weights")
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            self.load_state_dict(state['model'])
            if (load_optim):
                self.optimizer.load_state_dict(state['optimizer'])
            self.eval()

        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""

        # Added positive emotions
        self.positive_emotions = [
            11, 16, 6, 8, 3, 1, 28, 13, 31, 17, 24, 0, 27
        ]
        self.negative_emotions = [
            9, 4, 2, 22, 14, 30, 29, 25, 15, 10, 23, 19, 18, 21, 7, 20, 5, 26,
            12
        ]

    def init_emoji_embedding_with_glove(self):
        self.emotions = [
            'surprised', 'excited', 'annoyed', 'proud', 'angry', 'sad',
            'grateful', 'lonely', 'impressed', 'afraid', 'disgusted',
            'confident', 'terrified', 'hopeful', 'anxious', 'disappointed',
            'joyful', 'prepared', 'guilty', 'furious', 'nostalgic', 'jealous',
            'anticipating', 'embarrassed', 'content', 'devastated',
            'sentimental', 'caring', 'trusting', 'ashamed', 'apprehensive',
            'faithful'
        ]
        self.emotion_index = [self.vocab.word2index[i] for i in self.emotions]
        self.emoji_embedding_init = self.embedding(
            torch.Tensor(self.emotion_index).long())
        self.emoji_embedding.weight.data = self.emoji_embedding_init
        self.emoji_embedding.weight.requires_grad = True

    def save_model(self, running_avg_ppl, iter, f1_g, f1_b, ent_g, ent_b,
                   ent_t):
        state = {
            'iter': iter,
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_ppl,
            'model': self.state_dict()
        }
        model_save_path = os.path.join(
            self.model_dir,
            'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(
                iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b, ent_t))
        self.best_path = model_save_path
        torch.save(state, model_save_path)

    def random_sampling(self, e):
        p = np.random.choice(self.positive_emotions)
        n = np.random.choice(self.negative_emotions)
        if e in self.positive_emotions:
            mimic = p
            mimic_t = n
        else:
            mimic = n
            mimic_t = p
        return mimic, mimic_t

    def train_one_batch(self, batch, iter, train=True):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)
        dec_batch, _, _, _, _ = get_output_from_batch(batch)

        if (config.noam):
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()
        ## Encode
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)

        if config.dataset == "empathetic":
            emb_mask = self.embedding(batch["mask_input"])
            encoder_outputs = self.encoder(
                self.embedding(enc_batch) + emb_mask, mask_src)
        else:
            encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)

        q_h = torch.mean(encoder_outputs,
                         dim=1) if config.mean_query else encoder_outputs[:, 0]
        # q_h = torch.max(encoder_outputs, dim=1)
        emotions_mimic, emotions_non_mimic, mu_positive_prior, logvar_positive_prior, mu_negative_prior, logvar_negative_prior = \
            self.vae_sampler(q_h, batch['program_label'], self.emoji_embedding)
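        # The sampler returns latent emotion vectors for the mimicking and
        # non-mimicking polarity, plus the (mu, logvar) of their prior Gaussians,
        # which the KL terms below compare against the posterior estimates.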
        # KLLoss = -0.5 * (torch.sum(1 + logvar_n - mu_n.pow(2) - logvar_n.exp()) + torch.sum(1 + logvar_p - mu_p.pow(2) - logvar_p.exp()))
        m_out = self.emotion_input_encoder_1(emotions_mimic.unsqueeze(1),
                                             encoder_outputs, mask_src)
        m_tilde_out = self.emotion_input_encoder_2(
            emotions_non_mimic.unsqueeze(1), encoder_outputs, mask_src)
        if train:
            emotions_mimic, emotions_non_mimic, mu_positive_posterior, logvar_positive_posterior, mu_negative_posterior, logvar_negative_posterior = \
                self.vae_sampler.forward_train(q_h, batch['program_label'], self.emoji_embedding, M_out=m_out.mean(dim=1), M_tilde_out=m_tilde_out.mean(dim=1))
            KLLoss_positive = self.vae_sampler.kl_div(
                mu_positive_posterior, logvar_positive_posterior,
                mu_positive_prior, logvar_positive_prior)
            KLLoss_negative = self.vae_sampler.kl_div(
                mu_negative_posterior, logvar_negative_posterior,
                mu_negative_prior, logvar_negative_prior)
            KLLoss = KLLoss_positive + KLLoss_negative
        else:
            KLLoss_positive = self.vae_sampler.kl_div(mu_positive_prior,
                                                      logvar_positive_prior)
            KLLoss_negative = self.vae_sampler.kl_div(mu_negative_prior,
                                                      logvar_negative_prior)
            KLLoss = KLLoss_positive + KLLoss_negative

        if config.emo_combine == "att":
            v = self.cdecoder(encoder_outputs, m_out, m_tilde_out, mask_src)
        elif config.emo_combine == "gate":
            v = self.cdecoder(m_out, m_tilde_out)

        x = self.s_weight(q_h)

        # method2: E (W@c)
        logit_prob = torch.matmul(x, self.emoji_embedding.weight.transpose(
            0, 1))  # shape (b_size, 32)

        # Decode
        sos_token = torch.LongTensor([config.SOS_idx] *
                                     enc_batch.size(0)).unsqueeze(1)
        if config.USE_CUDA: sos_token = sos_token.cuda()
        dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1)

        mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)

        pre_logit, attn_dist = self.decoder(self.embedding(dec_batch_shift), v,
                                            v, (mask_src, mask_trg))

        ## compute output dist
        logit = self.generator(
            pre_logit,
            attn_dist,
            enc_batch_extend_vocab if config.pointer_gen else None,
            extra_zeros,
            attn_dist_db=None)

        if (train and config.schedule > 10):
            if (random.uniform(0, 1) <=
                (0.0001 +
                 (1 - 0.0001) * math.exp(-1. * iter / config.schedule))):
                config.oracle = True
            else:
                config.oracle = False

        if config.softmax:
            program_label = torch.LongTensor(batch['program_label'])
            if config.USE_CUDA: program_label = program_label.cuda()

            # 'gate' and 'att' compute the same loss here; program_label is already on
            # the right device, so the duplicated branching is collapsed.
            L1_loss = nn.CrossEntropyLoss()(logit_prob, program_label)
            loss = self.criterion(
                logit.contiguous().view(-1, logit.size(-1)),
                dec_batch.contiguous().view(-1)) + KLLoss + L1_loss

            loss_bce_program = nn.CrossEntropyLoss()(logit_prob,
                                                     program_label).item()
        else:
            loss = self.criterion(
                logit.contiguous().view(-1, logit.size(-1)),
                dec_batch.contiguous().view(-1)) + nn.BCEWithLogitsLoss()(
                    logit_prob, torch.FloatTensor(
                        batch['target_program']).cuda())
            loss_bce_program = nn.BCEWithLogitsLoss()(
                logit_prob,
                torch.FloatTensor(batch['target_program']).cuda()).item()
        pred_program = np.argmax(logit_prob.detach().cpu().numpy(), axis=1)
        program_acc = accuracy_score(batch["program_label"], pred_program)

        if (config.label_smoothing):
            loss_ppl = self.criterion_ppl(
                logit.contiguous().view(-1, logit.size(-1)),
                dec_batch.contiguous().view(-1)).item()

        if (train):
            loss.backward()
            self.optimizer.step()

        if (config.label_smoothing):
            return loss_ppl, math.exp(min(loss_ppl,
                                          100)), loss_bce_program, program_acc
        else:
            return loss.item(), math.exp(min(
                loss.item(), 100)), loss_bce_program, program_acc

    def compute_act_loss(self, module):
        R_t = module.remainders
        N_t = module.n_updates
        p_t = R_t + N_t
        avg_p_t = torch.sum(torch.sum(p_t, dim=1) / p_t.size(1)) / p_t.size(0)
        loss = config.act_loss_weight * avg_p_t.item()
        return loss

    def decoder_greedy(self,
                       batch,
                       max_dec_step=30,
                       emotion_classifier='built_in'):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)

        emotions = batch['program_label']

        ## Encode
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)

        emb_mask = self.embedding(batch["mask_input"])
        encoder_outputs = self.encoder(
            self.embedding(enc_batch) + emb_mask, mask_src)

        q_h = torch.mean(encoder_outputs,
                         dim=1) if config.mean_query else encoder_outputs[:, 0]

        # method 2
        x = self.s_weight(q_h)
        logit_prob = torch.matmul(x,
                                  self.emoji_embedding.weight.transpose(0, 1))
        emo_pred = torch.argmax(logit_prob, dim=-1)

        if emotion_classifier == "vader":
            context_emo = [
                self.positive_emotions[0]
                if d['compound'] > 0 else self.negative_emotions[0]
                for d in batch['context_emotion_scores']
            ]
            context_emo = torch.Tensor(context_emo)
            if config.USE_CUDA:
                context_emo = context_emo.cuda()
            emotions_mimic, emotions_non_mimic, mu_p, logvar_p, mu_n, logvar_n = self.vae_sampler(
                q_h, context_emo, self.emoji_embedding)
        elif emotion_classifier is None:
            emotions_mimic, emotions_non_mimic, mu_p, logvar_p, mu_n, logvar_n = self.vae_sampler(
                q_h, batch['program_label'], self.emoji_embedding)
        elif emotion_classifier == "built_in":
            emotions_mimic, emotions_non_mimic, mu_p, logvar_p, mu_n, logvar_n = self.vae_sampler(
                q_h, emo_pred, self.emoji_embedding)

        m_out = self.emotion_input_encoder_1(emotions_mimic.unsqueeze(1),
                                             encoder_outputs, mask_src)
        m_tilde_out = self.emotion_input_encoder_2(
            emotions_non_mimic.unsqueeze(1), encoder_outputs, mask_src)

        if config.emo_combine == "att":
            v = self.cdecoder(encoder_outputs, m_out, m_tilde_out, mask_src)
            # v = self.cdecoder(encoder_outputs, m_out, m_tilde_out, mask_src_chosen)
        elif config.emo_combine == "gate":
            v = self.cdecoder(m_out, m_tilde_out)
        elif config.emo_combine == 'vader':
            # assumes batch['context_emotion_scores'] holds VADER score dicts, as above
            context_emo_scores = torch.FloatTensor(
                [d['compound'] for d in batch['context_emotion_scores']])
            if config.USE_CUDA: context_emo_scores = context_emo_scores.cuda()
            m_weight = context_emo_scores.unsqueeze(-1).unsqueeze(-1)
            m_tilde_weight = 1 - m_weight
            v = m_weight * m_out + m_tilde_weight * m_tilde_out

        ys = torch.ones(1, 1).fill_(config.SOS_idx).long()
        if config.USE_CUDA:
            ys = ys.cuda()
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        decoded_words = []
        for i in range(max_dec_step + 1):
            if (config.project):
                out, attn_dist = self.decoder(
                    self.embedding_proj_in(self.embedding(ys)),
                    self.embedding_proj_in(encoder_outputs),
                    self.embedding_proj_in(v), (mask_src, mask_trg),
                    attention_parameters)
            else:
                out, attn_dist = self.decoder(self.embedding(ys), v, v,
                                              (mask_src, mask_trg))

            logit = self.generator(out,
                                   attn_dist,
                                   enc_batch_extend_vocab,
                                   extra_zeros,
                                   attn_dist_db=None)
            _, next_word = torch.max(logit[:, -1], dim=1)
            decoded_words.append([
                '<EOS>' if ni.item() == config.EOS_idx else
                self.vocab.index2word[ni.item()] for ni in next_word.view(-1)
            ])
            next_word = next_word.data[0]
            if config.USE_CUDA:
                ys = torch.cat(
                    [ys, torch.ones(1, 1).long().fill_(next_word).cuda()],
                    dim=1)
                ys = ys.cuda()
            else:
                ys = torch.cat(
                    [ys, torch.ones(1, 1).long().fill_(next_word)], dim=1)
            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

        sent = []
        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>':
                    break
                else:
                    st += e + ' '
            sent.append(st)
        return sent, batch['context_emotion_scores'][0]['compound'], int(
            emo_pred[0].data.cpu())

    def decoder_topk(self,
                     batch,
                     max_dec_step=30,
                     emotion_classifier='built_in'):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)

        emotions = batch['program_label']

        context_emo = [
            self.positive_emotions[0]
            if d['compound'] > 0 else self.negative_emotions[0]
            for d in batch['context_emotion_scores']
        ]
        context_emo = torch.Tensor(context_emo)
        if config.USE_CUDA:
            context_emo = context_emo.cuda()

        ## Encode
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)

        emb_mask = self.embedding(batch["mask_input"])
        encoder_outputs = self.encoder(
            self.embedding(enc_batch) + emb_mask, mask_src)

        q_h = torch.mean(encoder_outputs,
                         dim=1) if config.mean_query else encoder_outputs[:, 0]

        x = self.s_weight(q_h)
        # method 2
        logit_prob = torch.matmul(x,
                                  self.emoji_embedding.weight.transpose(0, 1))

        if emotion_classifier == "vader":
            context_emo = [
                self.positive_emotions[0]
                if d['compound'] > 0 else self.negative_emotions[0]
                for d in batch['context_emotion_scores']
            ]
            context_emo = torch.Tensor(context_emo)
            if config.USE_CUDA:
                context_emo = context_emo.cuda()
            emotions_mimic, emotions_non_mimic, mu_p, logvar_p, mu_n, logvar_n = self.vae_sampler(
                q_h, context_emo, self.emoji_embedding)
        elif emotion_classifier is None:
            emotions_mimic, emotions_non_mimic, mu_p, logvar_p, mu_n, logvar_n = self.vae_sampler(
                q_h, batch['program_label'], self.emoji_embedding)
        elif emotion_classifier == "built_in":
            emo_pred = torch.argmax(logit_prob, dim=-1)
            emotions_mimic, emotions_non_mimic, mu_p, logvar_p, mu_n, logvar_n = self.vae_sampler(
                q_h, emo_pred, self.emoji_embedding)

        m_out = self.emotion_input_encoder_1(emotions_mimic.unsqueeze(1),
                                             encoder_outputs, mask_src)
        m_tilde_out = self.emotion_input_encoder_2(
            emotions_non_mimic.unsqueeze(1), encoder_outputs, mask_src)

        if config.emo_combine == "att":
            v = self.cdecoder(encoder_outputs, m_out, m_tilde_out, mask_src)
        elif config.emo_combine == "gate":
            v = self.cdecoder(m_out, m_tilde_out)
        elif config.emo_combine == 'vader':
            # assumes batch['context_emotion_scores'] holds VADER score dicts, as above
            context_emo_scores = torch.FloatTensor(
                [d['compound'] for d in batch['context_emotion_scores']])
            if config.USE_CUDA: context_emo_scores = context_emo_scores.cuda()
            m_weight = context_emo_scores.unsqueeze(-1).unsqueeze(-1)
            m_tilde_weight = 1 - m_weight
            v = m_weight * m_out + m_tilde_weight * m_tilde_out

        ys = torch.ones(1, 1).fill_(config.SOS_idx).long()
        if config.USE_CUDA:
            ys = ys.cuda()
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        decoded_words = []
        for i in range(max_dec_step + 1):
            if (config.project):
                out, attn_dist = self.decoder(
                    self.embedding_proj_in(self.embedding(ys)),
                    self.embedding_proj_in(encoder_outputs),
                    (mask_src, mask_trg), attention_parameters)
            else:
                out, attn_dist = self.decoder(self.embedding(ys), v, v,
                                              (mask_src, mask_trg))

            logit = self.generator(out,
                                   attn_dist,
                                   enc_batch_extend_vocab,
                                   extra_zeros,
                                   attn_dist_db=None)
            filtered_logit = top_k_top_p_filtering(logit[:, -1],
                                                   top_k=3,
                                                   top_p=0,
                                                   filter_value=-float('Inf'))
            # Sample from the filtered distribution
            next_word = torch.multinomial(F.softmax(filtered_logit, dim=-1),
                                          1).squeeze()
            decoded_words.append([
                '<EOS>' if ni.item() == config.EOS_idx else
                self.vocab.index2word[ni.item()] for ni in next_word.view(-1)
            ])
            next_word = next_word.data.item()

            if config.USE_CUDA:
                ys = torch.cat(
                    [ys, torch.ones(1, 1).long().fill_(next_word).cuda()],
                    dim=1)
                ys = ys.cuda()
            else:
                ys = torch.cat(
                    [ys, torch.ones(1, 1).long().fill_(next_word)], dim=1)
            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

        sent = []
        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>':
                    break
                else:
                    st += e + ' '
            sent.append(st)
        return sent
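
decoder_topk above samples the next word from logits filtered with
top_k_top_p_filtering(top_k=3, top_p=0). The sketch below covers only the top-k part
(the real helper also implements nucleus/top-p filtering) and uses nothing beyond
standard PyTorch:

import torch
import torch.nn.functional as F

def top_k_filter(logits, top_k=3, filter_value=-float('Inf')):
    # keep the top_k largest logits per row, push everything else to -inf
    kth_best = torch.topk(logits, top_k)[0][..., -1, None]
    return logits.masked_fill(logits < kth_best, filter_value)

logits = torch.randn(1, 10)                      # (batch, vocab)
probs = F.softmax(top_k_filter(logits), dim=-1)  # zero probability outside the top k
next_word = torch.multinomial(probs, 1)
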
Example #15
class Transformer_experts(nn.Module):
    def __init__(self,
                 vocab,
                 decoder_number,
                 model_file_path=None,
                 is_eval=False,
                 load_optim=False):
        super(Transformer_experts, self).__init__()
        self.vocab = vocab
        self.vocab_size = vocab.n_words

        self.embedding = share_embedding(self.vocab, config.pretrain_emb)
        self.encoder = Encoder(config.emb_dim,
                               config.hidden_dim,
                               num_layers=config.hop,
                               num_heads=config.heads,
                               total_key_depth=config.depth,
                               total_value_depth=config.depth,
                               filter_size=config.filter,
                               universal=config.universal)
        self.decoder_number = decoder_number
        ## multiple decoders
        self.decoder = MulDecoder(decoder_number,
                                  config.emb_dim,
                                  config.hidden_dim,
                                  num_layers=config.hop,
                                  num_heads=config.heads,
                                  total_key_depth=config.depth,
                                  total_value_depth=config.depth,
                                  filter_size=config.filter)

        self.decoder_key = nn.Linear(config.hidden_dim,
                                     decoder_number,
                                     bias=False)

        self.generator = Generator(config.hidden_dim, self.vocab_size)
        self.emoji_embedding = nn.Linear(64, config.emb_dim, bias=False)

        if config.weight_sharing:
            # Share the weight matrix between target word embedding & the final logit dense layer
            self.generator.proj.weight = self.embedding.lut.weight

        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)
        if (config.label_smoothing):
            self.criterion = LabelSmoothing(size=self.vocab_size,
                                            padding_idx=config.PAD_idx,
                                            smoothing=0.1)
            self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx)

        if config.softmax:
            self.attention_activation = nn.Softmax(dim=1)
        else:
            self.attention_activation = nn.Sigmoid()  #nn.Softmax()

        self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
        if (config.noam):
            self.optimizer = NoamOpt(
                config.hidden_dim, 1, 8000,
                torch.optim.Adam(self.parameters(),
                                 lr=0,
                                 betas=(0.9, 0.98),
                                 eps=1e-9))

        if model_file_path is not None:
            print("loading weights")
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            self.encoder.load_state_dict(state['encoder_state_dict'])
            self.decoder.load_state_dict(state['decoder_state_dict'])
            self.decoder_key.load_state_dict(state['decoder_key_state_dict'])
            #self.emoji_embedding.load_state_dict(state['emoji_embedding_dict'])
            self.generator.load_state_dict(state['generator_dict'])
            self.embedding.load_state_dict(state['embedding_dict'])
            if (load_optim):
                self.optimizer.load_state_dict(state['optimizer'])
            self.eval()

        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""

    def save_model(self, running_avg_ppl, iter, f1_g, f1_b, ent_g, ent_b):
        state = {
            'iter': iter,
            'encoder_state_dict': self.encoder.state_dict(),
            'decoder_state_dict': self.decoder.state_dict(),
            'decoder_key_state_dict': self.decoder_key.state_dict(),
            #'emoji_embedding_dict': self.emoji_embedding.state_dict(),
            'generator_dict': self.generator.state_dict(),
            'embedding_dict': self.embedding.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_ppl
        }
        model_save_path = os.path.join(
            self.model_dir,
            'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(
                iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b))
        self.best_path = model_save_path
        torch.save(state, model_save_path)

    def train_one_batch(self, batch, iter, train=True):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)
        dec_batch, _, _, _, _ = get_output_from_batch(batch)

        if (config.noam):
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()
        ## Encode
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        if config.dataset == "empathetic":
            emb_mask = self.embedding(batch["mask_input"])
            encoder_outputs = self.encoder(
                self.embedding(enc_batch) + emb_mask, mask_src)
        else:
            encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)
        ## Attention over decoder
        q_h = torch.mean(encoder_outputs,
                         dim=1) if config.mean_query else encoder_outputs[:, 0]
        #q_h = encoder_outputs[:,0]
        logit_prob = self.decoder_key(q_h)  #(bsz, num_experts)

        if (config.topk > 0):
            k_max_value, k_max_index = torch.topk(logit_prob, config.topk)
            a = np.empty([logit_prob.shape[0], self.decoder_number])
            a.fill(float('-inf'))
            mask = torch.Tensor(a).cuda()
            logit_prob_ = mask.scatter_(1,
                                        k_max_index.cuda().long(), k_max_value)
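            # logit_prob_ keeps the original scores only for the top-k experts and
            # -inf elsewhere, so the activation below gives the rest (near-)zero weight.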
            attention_parameters = self.attention_activation(logit_prob_)
        else:
            attention_parameters = self.attention_activation(logit_prob)
        # print("===============================================================================")
        # print("listener attention weight:",attention_parameters.data.cpu().numpy())
        # print("===============================================================================")
        if (config.oracle):
            attention_parameters = self.attention_activation(
                torch.FloatTensor(batch['target_program']) * 1000).cuda()
        attention_parameters = attention_parameters.unsqueeze(-1).unsqueeze(
            -1)  # (batch_size, expert_num, 1, 1)

        # Decode
        sos_token = torch.LongTensor([config.SOS_idx] *
                                     enc_batch.size(0)).unsqueeze(1)
        if config.USE_CUDA: sos_token = sos_token.cuda()
        dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1)

        mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)

        pre_logit, attn_dist = self.decoder(self.embedding(dec_batch_shift),
                                            encoder_outputs,
                                            (mask_src, mask_trg),
                                            attention_parameters)
        ## compute output dist
        logit = self.generator(
            pre_logit,
            attn_dist,
            enc_batch_extend_vocab if config.pointer_gen else None,
            extra_zeros,
            attn_dist_db=None)
        #logit = F.log_softmax(logit,dim=-1) #fix the name later
        ## loss: NNL if ptr else Cross entropy
        if (train and config.schedule > 10):
            if (random.uniform(0, 1) <=
                (0.0001 +
                 (1 - 0.0001) * math.exp(-1. * iter / config.schedule))):
                config.oracle = True
            else:
                config.oracle = False
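        # attention_parameters was computed earlier in this call, so flipping
        # config.oracle here only takes effect from the next training batch onward.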

        if config.softmax:
            loss = self.criterion(
                logit.contiguous().view(-1, logit.size(-1)),
                dec_batch.contiguous().view(-1)) + nn.CrossEntropyLoss()(
                    logit_prob, torch.LongTensor(
                        batch['program_label']).cuda())
            loss_bce_program = nn.CrossEntropyLoss()(
                logit_prob,
                torch.LongTensor(batch['program_label']).cuda()).item()
        else:
            loss = self.criterion(
                logit.contiguous().view(-1, logit.size(-1)),
                dec_batch.contiguous().view(-1)) + nn.BCEWithLogitsLoss()(
                    logit_prob, torch.FloatTensor(
                        batch['target_program']).cuda())
            loss_bce_program = nn.BCEWithLogitsLoss()(
                logit_prob,
                torch.FloatTensor(batch['target_program']).cuda()).item()
        pred_program = np.argmax(logit_prob.detach().cpu().numpy(), axis=1)
        program_acc = accuracy_score(batch["program_label"], pred_program)

        if (config.label_smoothing):
            loss_ppl = self.criterion_ppl(
                logit.contiguous().view(-1, logit.size(-1)),
                dec_batch.contiguous().view(-1)).item()

        if (train):
            loss.backward()
            self.optimizer.step()

        if (config.label_smoothing):
            return loss_ppl, math.exp(min(loss_ppl,
                                          100)), loss_bce_program, program_acc
        else:
            return loss.item(), math.exp(min(
                loss.item(), 100)), loss_bce_program, program_acc

    def compute_act_loss(self, module):
        R_t = module.remainders
        N_t = module.n_updates
        p_t = R_t + N_t
        avg_p_t = torch.sum(torch.sum(p_t, dim=1) / p_t.size(1)) / p_t.size(0)
        loss = config.act_loss_weight * avg_p_t.item()
        return loss

    def decoder_greedy(self, batch, max_dec_step=30):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        emb_mask = self.embedding(batch["mask_input"])
        encoder_outputs = self.encoder(
            self.embedding(enc_batch) + emb_mask, mask_src)
        ## Attention over decoder
        q_h = torch.mean(encoder_outputs,
                         dim=1) if config.mean_query else encoder_outputs[:, 0]
        #q_h = encoder_outputs[:,0]
        logit_prob = self.decoder_key(q_h)

        if (config.topk > 0):
            k_max_value, k_max_index = torch.topk(logit_prob, config.topk)
            a = np.empty([logit_prob.shape[0], self.decoder_number])
            a.fill(float('-inf'))
            mask = torch.Tensor(a).cuda()
            logit_prob = mask.scatter_(1,
                                       k_max_index.cuda().long(), k_max_value)

        attention_parameters = self.attention_activation(logit_prob)

        if (config.oracle):
            attention_parameters = self.attention_activation(
                torch.FloatTensor(batch['target_program']) * 1000).cuda()
        attention_parameters = attention_parameters.unsqueeze(-1).unsqueeze(
            -1)  # (batch_size, expert_num, 1, 1)

        ys = torch.ones(1, 1).fill_(config.SOS_idx).long()
        if config.USE_CUDA:
            ys = ys.cuda()
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        decoded_words = []
        for i in range(max_dec_step + 1):
            if (config.project):
                out, attn_dist = self.decoder(
                    self.embedding_proj_in(self.embedding(ys)),
                    self.embedding_proj_in(encoder_outputs),
                    (mask_src, mask_trg), attention_parameters)
            else:

                out, attn_dist = self.decoder(self.embedding(ys),
                                              encoder_outputs,
                                              (mask_src, mask_trg),
                                              attention_parameters)

            logit = self.generator(out,
                                   attn_dist,
                                   enc_batch_extend_vocab,
                                   extra_zeros,
                                   attn_dist_db=None)
            #logit = F.log_softmax(logit,dim=-1) #fix the name later
            _, next_word = torch.max(logit[:, -1], dim=1)
            decoded_words.append([
                '<EOS>' if ni.item() == config.EOS_idx else
                self.vocab.index2word[ni.item()] for ni in next_word.view(-1)
            ])
            next_word = next_word.data[0]
            if config.USE_CUDA:
                ys = torch.cat(
                    [ys, torch.ones(1, 1).long().fill_(next_word).cuda()],
                    dim=1)
                ys = ys.cuda()
            else:
                ys = torch.cat(
                    [ys, torch.ones(1, 1).long().fill_(next_word)], dim=1)
            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

        sent = []
        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>': break
                else: st += e + ' '
            sent.append(st)
        return sent

    def decoder_topk(self, batch, max_dec_step=30):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        emb_mask = self.embedding(batch["mask_input"])
        encoder_outputs = self.encoder(
            self.embedding(enc_batch) + emb_mask, mask_src)

        ## Attention over decoder
        q_h = torch.mean(encoder_outputs,
                         dim=1) if config.mean_query else encoder_outputs[:, 0]
        #q_h = encoder_outputs[:,0]
        logit_prob = self.decoder_key(q_h)

        if (config.topk > 0):
            k_max_value, k_max_index = torch.topk(logit_prob, config.topk)
            a = np.empty([logit_prob.shape[0], self.decoder_number])
            a.fill(float('-inf'))
            mask = torch.Tensor(a).cuda()
            logit_prob = mask.scatter_(1,
                                       k_max_index.cuda().long(), k_max_value)

        attention_parameters = self.attention_activation(logit_prob)

        if (config.oracle):
            attention_parameters = self.attention_activation(
                torch.FloatTensor(batch['target_program']) * 1000).cuda()
        attention_parameters = attention_parameters.unsqueeze(-1).unsqueeze(
            -1)  # (batch_size, expert_num, 1, 1)

        ys = torch.ones(1, 1).fill_(config.SOS_idx).long()
        if config.USE_CUDA:
            ys = ys.cuda()
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        decoded_words = []
        for i in range(max_dec_step + 1):
            if (config.project):
                out, attn_dist = self.decoder(
                    self.embedding_proj_in(self.embedding(ys)),
                    self.embedding_proj_in(encoder_outputs),
                    (mask_src, mask_trg), attention_parameters)
            else:

                out, attn_dist = self.decoder(self.embedding(ys),
                                              encoder_outputs,
                                              (mask_src, mask_trg),
                                              attention_parameters)

            logit = self.generator(out,
                                   attn_dist,
                                   enc_batch_extend_vocab,
                                   extra_zeros,
                                   attn_dist_db=None)
            filtered_logit = top_k_top_p_filtering(logit[:, -1],
                                                   top_k=3,
                                                   top_p=0,
                                                   filter_value=-float('Inf'))
            # Sample from the filtered distribution
            next_word = torch.multinomial(F.softmax(filtered_logit, dim=-1),
                                          1).squeeze()
            decoded_words.append([
                '<EOS>' if ni.item() == config.EOS_idx else
                self.vocab.index2word[ni.item()] for ni in next_word.view(-1)
            ])
            next_word = next_word.data.item()  # squeeze() yields a 0-dim tensor; [0] would fail
            if config.USE_CUDA:
                ys = torch.cat(
                    [ys, torch.ones(1, 1).long().fill_(next_word).cuda()],
                    dim=1)
                ys = ys.cuda()
            else:
                ys = torch.cat(
                    [ys, torch.ones(1, 1).long().fill_(next_word)], dim=1)
            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

        sent = []
        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>': break
                else: st += e + ' '
            sent.append(st)
        return sent
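
All of these models wrap Adam in NoamOpt(config.hidden_dim, 1, 8000, ...). A minimal
sketch of that warmup schedule, assuming NoamOpt follows the usual Transformer rate
lr = factor * d_model^(-0.5) * min(step^(-0.5), step * warmup^(-1.5)):

def noam_lr(step, d_model, factor=1, warmup=8000):
    # linear warmup for `warmup` steps, then decay proportional to 1/sqrt(step)
    return factor * (d_model ** -0.5) * min(step ** -0.5, step * warmup ** -1.5)

print(noam_lr(1, 300), noam_lr(8000, 300), noam_lr(80000, 300))
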
Example #16
    def __init__(self,
                 vocab,
                 decoder_number,
                 model_file_path=None,
                 is_eval=False,
                 load_optim=False):
        super(Transformer_experts, self).__init__()
        self.vocab = vocab
        self.vocab_size = vocab.n_words

        self.embedding = share_embedding(self.vocab, config.pretrain_emb)
        self.encoder = Encoder(config.emb_dim,
                               config.hidden_dim,
                               num_layers=config.hop,
                               num_heads=config.heads,
                               total_key_depth=config.depth,
                               total_value_depth=config.depth,
                               filter_size=config.filter,
                               universal=config.universal)
        self.decoder_number = decoder_number
        ## multiple decoders
        self.decoder = MulDecoder(decoder_number,
                                  config.emb_dim,
                                  config.hidden_dim,
                                  num_layers=config.hop,
                                  num_heads=config.heads,
                                  total_key_depth=config.depth,
                                  total_value_depth=config.depth,
                                  filter_size=config.filter)

        self.decoder_key = nn.Linear(config.hidden_dim,
                                     decoder_number,
                                     bias=False)

        self.generator = Generator(config.hidden_dim, self.vocab_size)
        self.emoji_embedding = nn.Linear(64, config.emb_dim, bias=False)

        if config.weight_sharing:
            # Share the weight matrix between target word embedding & the final logit dense layer
            self.generator.proj.weight = self.embedding.lut.weight

        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)
        if (config.label_smoothing):
            self.criterion = LabelSmoothing(size=self.vocab_size,
                                            padding_idx=config.PAD_idx,
                                            smoothing=0.1)
            self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx)

        if config.softmax:
            self.attention_activation = nn.Softmax(dim=1)
        else:
            self.attention_activation = nn.Sigmoid()  #nn.Softmax()

        self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
        if (config.noam):
            self.optimizer = NoamOpt(
                config.hidden_dim, 1, 8000,
                torch.optim.Adam(self.parameters(),
                                 lr=0,
                                 betas=(0.9, 0.98),
                                 eps=1e-9))

        if model_file_path is not None:
            print("loading weights")
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            self.encoder.load_state_dict(state['encoder_state_dict'])
            self.decoder.load_state_dict(state['decoder_state_dict'])
            self.decoder_key.load_state_dict(state['decoder_key_state_dict'])
            #self.emoji_embedding.load_state_dict(state['emoji_embedding_dict'])
            self.generator.load_state_dict(state['generator_dict'])
            self.embedding.load_state_dict(state['embedding_dict'])
            if (load_optim):
                self.optimizer.load_state_dict(state['optimizer'])
            self.eval()

        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""
Example #17
class CvaeNAD(nn.Module):
    def __init__(self,
                 vocab,
                 emo_number,
                 model_file_path=None,
                 is_eval=False,
                 load_optim=False):
        super(CvaeNAD, self).__init__()
        self.vocab = vocab
        self.vocab_size = vocab.n_words

        self.embedding = share_embedding(self.vocab, config.pretrain_emb)
        self.encoder = Encoder(config.emb_dim,
                               config.hidden_dim,
                               num_layers=config.hop,
                               num_heads=config.heads,
                               total_key_depth=config.depth,
                               total_value_depth=config.depth,
                               filter_size=config.filter,
                               universal=config.universal)

        self.r_encoder = Encoder(config.emb_dim,
                                 config.hidden_dim,
                                 num_layers=config.hop,
                                 num_heads=config.heads,
                                 total_key_depth=config.depth,
                                 total_value_depth=config.depth,
                                 filter_size=config.filter,
                                 universal=config.universal)
        if config.num_var_layers > 0:
            self.decoder = VarDecoder2(config.emb_dim,
                                       hidden_size=config.hidden_dim,
                                       num_layers=config.hop,
                                       num_heads=config.heads,
                                       total_key_depth=config.depth,
                                       total_value_depth=config.depth,
                                       filter_size=config.filter)
        else:
            self.decoder = VarDecoder3(config.emb_dim,
                                       hidden_size=config.hidden_dim,
                                       num_layers=config.hop,
                                       num_heads=config.heads,
                                       total_key_depth=config.depth,
                                       total_value_depth=config.depth,
                                       filter_size=config.filter)

        self.generator = Generator(config.hidden_dim, self.vocab_size)

        if config.weight_sharing:
            # Share the weight matrix between target word embedding & the final logit dense layer
            self.generator.proj.weight = self.embedding.lut.weight

        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)

        self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
        if (config.noam):
            self.optimizer = NoamOpt(
                config.hidden_dim, 1, 8000,
                torch.optim.Adam(self.parameters(),
                                 lr=0,
                                 betas=(0.9, 0.98),
                                 eps=1e-9))

        if model_file_path is not None:
            print("loading weights")
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            self.encoder.load_state_dict(state['encoder_state_dict'])
            self.r_encoder.load_state_dict(state['r_encoder_state_dict'])
            self.decoder.load_state_dict(state['decoder_state_dict'])
            self.generator.load_state_dict(state['generator_dict'])
            self.embedding.load_state_dict(state['embedding_dict'])
            if (load_optim):
                self.optimizer.load_state_dict(state['optimizer'])
            self.eval()

        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""

    def save_model(self, running_avg_ppl, iter, f1_g, f1_b, ent_g, ent_b):

        state = {
            'iter': iter,
            'encoder_state_dict': self.encoder.state_dict(),
            'r_encoder_state_dict': self.r_encoder.state_dict(),
            'decoder_state_dict': self.decoder.state_dict(),
            'generator_dict': self.generator.state_dict(),
            'embedding_dict': self.embedding.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_ppl
        }
        model_save_path = os.path.join(
            self.model_dir,
            'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(
                iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b))
        self.best_path = model_save_path
        torch.save(state, model_save_path)

    def train_one_batch(self, batch, iter, train=True):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)
        dec_batch, _, _, _, _ = get_output_from_batch(batch)

        if (config.noam):
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        ## Response encode
        mask_res = batch["posterior_batch"].data.eq(
            config.PAD_idx).unsqueeze(1)
        posterior_mask = self.embedding(batch["posterior_mask"])
        r_encoder_outputs = self.r_encoder(
            self.embedding(batch["posterior_batch"]), mask_res)
        ## Encode
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        emb_mask = self.embedding(batch["input_mask"])
        encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)
        meta = self.embedding(batch["program_label"])

        # Decode
        mask_trg = dec_batch.data.eq(config.PAD_idx).unsqueeze(1)
        latent_dim = meta.size()[-1]
        meta = meta.repeat(1, dec_batch.size(1)).view(dec_batch.size(0),
                                                      dec_batch.size(1),
                                                      latent_dim)
        pre_logit, attn_dist, mean, log_var = self.decoder(
            meta, encoder_outputs, r_encoder_outputs,
            (mask_src, mask_res, mask_trg))
        if not train:
            pre_logit, attn_dist, _, _ = self.decoder(
                meta, encoder_outputs, None, (mask_src, None, mask_trg))
        ## compute output dist
        logit = self.generator(
            pre_logit,
            attn_dist,
            enc_batch_extend_vocab if config.pointer_gen else None,
            extra_zeros,
            attn_dist_db=None)
        ## loss: NNL if ptr else Cross entropy
        loss_rec = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                                  dec_batch.contiguous().view(-1))
        kld_loss = gaussian_kld(mean["posterior"], log_var["posterior"],
                                mean["prior"], log_var["prior"])
        kld_loss = torch.mean(kld_loss)
        kl_weight = min(iter / config.full_kl_step,
                        1) if config.full_kl_step > 0 else 1.0
        loss = loss_rec + config.kl_ceiling * kl_weight * kld_loss
        if (train):
            loss.backward()
            # clip gradient
            nn.utils.clip_grad_norm_(self.parameters(), config.max_grad_norm)
            self.optimizer.step()

        return loss_rec.item(), math.exp(min(loss_rec.item(),
                                             100)), kld_loss.item()

    def decoder_greedy(self, batch, max_dec_step=50):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)

        ## Encode
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        emb_mask = self.embedding(batch["input_mask"])
        meta = self.embedding(batch["program_label"])
        encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)

        mask_trg = torch.ones((enc_batch.size(0), 50))
        meta_size = meta.size()
        meta = meta.repeat(1, 50).view(meta_size[0], 50, meta_size[1])
        out, attn_dist, _, _ = self.decoder(meta, encoder_outputs, None,
                                            (mask_src, None, mask_trg))
        prob = self.generator(out,
                              attn_dist,
                              enc_batch_extend_vocab,
                              extra_zeros,
                              attn_dist_db=None)
        _, batch_out = torch.max(prob, dim=-1)  # argmax over the vocabulary -> (batch, seq_len) ids
        batch_out = batch_out.data.cpu().numpy()
        sentences = []
        for sent in batch_out:
            st = ''
            for w in sent:
                if w == config.EOS_idx: break
                else: st += self.vocab.index2word[w] + ' '
            sentences.append(st)
        return sentences

    def decoder_greedy_po(self, batch, max_dec_step=50):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)

        ## Response encode
        mask_res = batch["posterior_batch"].data.eq(
            config.PAD_idx).unsqueeze(1)
        posterior_mask = self.embedding(batch["posterior_mask"])
        r_encoder_outputs = self.r_encoder(
            self.embedding(batch["posterior_batch"]), mask_res)
        ## Encode
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        emb_mask = self.embedding(batch["input_mask"])
        encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)
        meta = self.embedding(batch["program_label"])

        mask_trg = torch.ones((enc_batch.size(0), 50))
        meta_size = meta.size()
        meta = meta.repeat(1, 50).view(meta_size[0], 50, meta_size[1])
        out, attn_dist, mean, log_var = self.decoder(
            meta, encoder_outputs, r_encoder_outputs,
            (mask_src, mask_res, mask_trg))
        prob = self.generator(out,
                              attn_dist,
                              enc_batch_extend_vocab,
                              extra_zeros,
                              attn_dist_db=None)
        _, batch_out = torch.max(prob, dim=-1)  # argmax over the vocabulary -> (batch, seq_len) ids
        batch_out = batch_out.data.cpu().numpy()
        sentences = []
        for sent in batch_out:
            st = ''
            for w in sent:
                if w == config.EOS_idx: break
                else: st += self.vocab.index2word[w] + ' '
            sentences.append(st)
        return sentences
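
CvaeNAD's loss adds config.kl_ceiling * kl_weight * kld_loss, where gaussian_kld is
assumed to be the standard closed-form KL divergence between two diagonal Gaussians
given as (mean, log-variance). A minimal sketch of that term and the annealing weight:

import torch

def gaussian_kld_sketch(mu_q, logvar_q, mu_p, logvar_p):
    # KL( N(mu_q, var_q) || N(mu_p, var_p) ), summed over the latent dimension
    return 0.5 * torch.sum(
        logvar_p - logvar_q
        + (logvar_q.exp() + (mu_q - mu_p) ** 2) / logvar_p.exp()
        - 1.0, dim=-1)

mu_q, logvar_q = torch.zeros(2, 8), torch.zeros(2, 8)   # posterior parameters
mu_p, logvar_p = torch.ones(2, 8), torch.zeros(2, 8)    # prior parameters
kld_loss = gaussian_kld_sketch(mu_q, logvar_q, mu_p, logvar_p).mean()
kl_weight = min(1000 / 10000, 1)     # iter / config.full_kl_step, capped at 1.0
kl_term = kl_weight * kld_loss       # scaled again by config.kl_ceiling in the model
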
Example #18
class Summarizer(nn.Module):
    def __init__(self,
                 is_draft,
                 toeknizer,
                 model_file_path=None,
                 is_eval=False,
                 load_optim=False):
        super(Summarizer, self).__init__()
        self.is_draft = is_draft
        self.toeknizer = toeknizer
        if is_draft:
            self.encoder = BertModel.from_pretrained('bert-base-uncased')
        else:
            self.encoder = BertForMaskedLM.from_pretrained('bert-base-uncased')
        self.encoder.eval()  # always in eval mode
        self.embedding = self.encoder.embeddings

        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        self.decoder = Decoder(config.emb_dim,
                               config.hidden_dim,
                               num_layers=config.hop,
                               num_heads=config.heads,
                               total_key_depth=config.depth,
                               total_value_depth=config.depth,
                               filter_size=config.filter)

        self.generator = Generator(config.hidden_dim, config.vocab_size)

        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)
        self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx)
        if (config.label_smoothing):
            self.criterion = LabelSmoothing(size=config.vocab_size,
                                            padding_idx=config.PAD_idx,
                                            smoothing=0.1)

        self.embedding = self.embedding.eval()
        if is_eval:
            self.decoder = self.decoder.eval()
            self.generator = self.generator.eval()

        self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
        if (config.noam):
            self.optimizer = NoamOpt(
                config.hidden_dim, 1, 4000,
                torch.optim.Adam(self.parameters(),
                                 lr=0,
                                 betas=(0.9, 0.98),
                                 eps=1e-9))
        if model_file_path is not None:
            print("loading weights")
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            print("LOSS", state['current_loss'])
            self.decoder.load_state_dict(state['decoder_state_dict'])
            self.generator.load_state_dict(state['generator_dict'])
            self.embedding.load_state_dict(state['embedding_dict'])

        if (config.USE_CUDA):
            self.encoder = self.encoder.cuda(device=0)
            self.decoder = self.decoder.cuda(device=0)
            self.generator = self.generator.cuda(device=0)
            self.criterion = self.criterion.cuda(device=0)
            self.embedding = self.embedding.cuda(device=0)
        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""

    def save_model(self, loss, iter, r_avg):
        state = {
            'iter': iter,
            'decoder_state_dict': self.decoder.state_dict(),
            'generator_dict': self.generator.state_dict(),
            'embedding_dict': self.embedding.state_dict(),
            #'optimizer': self.optimizer.state_dict(),
            'current_loss': loss
        }
        model_save_path = os.path.join(
            self.model_dir, 'model_{}_{:.4f}_{:.4f}'.format(iter, loss, r_avg))
        self.best_path = model_save_path
        torch.save(state, model_save_path)

    def train_one_batch(self, batch, train=True):
        ## pad and other stuff
        input_ids_batch, input_mask_batch, example_index_batch, enc_batch_extend_vocab, extra_zeros, _ = get_input_from_batch(
            batch)
        dec_batch, dec_mask_batch, dec_index_batch, copy_gate, copy_ptr = get_output_from_batch(
            batch)

        if (config.noam):
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        with torch.no_grad():
            # encoder_outputs are hidden states from transformer
            encoder_outputs, _ = self.encoder(
                input_ids_batch,
                token_type_ids=example_index_batch,
                attention_mask=input_mask_batch,
                output_all_encoded_layers=False)

        # # Draft Decoder
        sos_token = torch.LongTensor([config.SOS_idx] *
                                     input_ids_batch.size(0)).unsqueeze(1)
        if config.USE_CUDA: sos_token = sos_token.cuda(device=0)

        dec_batch_shift = torch.cat(
            (sos_token, dec_batch[:, :-1]),
            1)  # shift the decoder input (summary) by one step
        mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
        pre_logit1, attn_dist1 = self.decoder(self.embedding(dec_batch_shift),
                                              encoder_outputs,
                                              (None, mask_trg))

        # print(pre_logit1.size())
        ## compute output dist
        logit1 = self.generator(pre_logit1,
                                attn_dist1,
                                enc_batch_extend_vocab,
                                extra_zeros,
                                copy_gate=copy_gate,
                                copy_ptr=copy_ptr,
                                mask_trg=mask_trg)
        ## loss: NLL if pointer-gen else cross entropy
        loss1 = self.criterion(logit1.contiguous().view(-1, logit1.size(-1)),
                               dec_batch.contiguous().view(-1))

        # Refine Decoder - train using gold label TARGET
        # TODO: turn the gold target text into a BERT-insertable representation
        pre_logit2, attn_dist2 = self.generate_refinement_output(
            encoder_outputs, dec_batch, dec_index_batch, extra_zeros,
            dec_mask_batch)
        # pre_logit2, attn_dist2 = self.decoder(self.embedding(encoded_gold_target),encoder_outputs, (None,mask_trg))

        logit2 = self.generator(pre_logit2,
                                attn_dist2,
                                enc_batch_extend_vocab,
                                extra_zeros,
                                copy_gate=copy_gate,
                                copy_ptr=copy_ptr,
                                mask_trg=None)
        loss2 = self.criterion(logit2.contiguous().view(-1, logit2.size(-1)),
                               dec_batch.contiguous().view(-1))

        loss = loss1 + loss2

        if train:
            loss.backward()
            self.optimizer.step()
        return loss

    def eval_one_batch(self, batch):
        # The refinement pass needs the source encoding and the copy-vocabulary tensors from the batch.
        input_ids_batch, input_mask_batch, example_index_batch, enc_batch_extend_vocab, extra_zeros, _ = get_input_from_batch(
            batch)
        with torch.no_grad():
            encoder_outputs, _ = self.encoder(
                input_ids_batch,
                token_type_ids=example_index_batch,
                attention_mask=input_mask_batch,
                output_all_encoded_layers=False)

        draft_seq_batch = self.decoder_greedy(batch)

        d_seq_input_ids_batch, d_seq_input_mask_batch, d_seq_example_index_batch = text_input2bert_input(
            draft_seq_batch, self.tokenizer)
        pre_logit2, attn_dist2 = self.generate_refinement_output(
            encoder_outputs, d_seq_input_ids_batch, d_seq_example_index_batch,
            extra_zeros, d_seq_input_mask_batch)

        decoded_words, sent = [], []
        for out, attn_dist in zip(pre_logit2, attn_dist2):
            # gold copy gates / pointers are not available at eval time
            prob = self.generator(out,
                                  attn_dist,
                                  enc_batch_extend_vocab,
                                  extra_zeros,
                                  copy_gate=None,
                                  copy_ptr=None,
                                  mask_trg=None)
            _, next_word = torch.max(prob[:, -1], dim=1)
            decoded_words.append(
                self.tokenizer.convert_ids_to_tokens(next_word.tolist()))

        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>' or e.strip() == '<PAD>': break
                else: st += e + ' '
            sent.append(st)
        return sent

    def generate_refinement_output(self, encoder_outputs, input_ids_batch,
                                   example_index_batch, extra_zeros,
                                   input_mask_batch):
        # mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        # decoded_words = []
        logits, attns = [], []

        for i in range(config.max_dec_step):
            # print(i)
            with torch.no_grad():
                # Additionally mask the location of i.
                context_input_mask_batch = []
                # print(context_input_mask_batch.shape) # (2,512) (batch_size, seq_len)
                for mask in input_mask_batch:
                    mask[i] = 0
                    context_input_mask_batch.append(mask)

                context_input_mask_batch = torch.stack(
                    context_input_mask_batch)  #.cuda(device=0)
                # self.embedding = self.embedding.cuda(device=0)
                context_vector, _ = self.encoder(
                    input_ids_batch,
                    token_type_ids=example_index_batch,
                    attention_mask=context_input_mask_batch,
                    output_all_encoded_layers=False)

                if config.USE_CUDA:
                    context_vector = context_vector.cuda(device=0)
            # decoder input size == encoder output size == (batch_size, 512, 768)
            out, attn_dist = self.decoder(context_vector, encoder_outputs,
                                          (None, None))

            logits.append(out[:, i:i + 1, :])
            attns.append(attn_dist[:, i:i + 1, :])

        logits = torch.cat(logits, dim=1)
        attns = torch.cat(attns, dim=1)

        # print(logits.size(), attns.size())
        return logits, attns

    def decoder_greedy(self, batch):
        input_ids_batch, input_mask_batch, example_index_batch, enc_batch_extend_vocab, extra_zeros, _ = get_input_from_batch(
            batch)
        # mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        with torch.no_grad():
            encoder_outputs, _ = self.encoder(
                input_ids_batch,
                token_type_ids=example_index_batch,
                attention_mask=input_mask_batch,
                output_all_encoded_layers=False)

        ys = torch.ones(1, 1).fill_(config.SOS_idx).long()
        if config.USE_CUDA: ys = ys.cuda()
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        decoded_words = []
        for i in range(config.max_dec_step):
            out, attn_dist = self.decoder(self.embedding(ys), encoder_outputs,
                                          (None, mask_trg))
            prob = self.generator(out, attn_dist, enc_batch_extend_vocab,
                                  extra_zeros)
            _, next_word = torch.max(prob[:, -1], dim=1)

            decoded_words.append(
                self.tokenizer.convert_ids_to_tokens(next_word.tolist()))

            next_word = next_word.data[0]
            if config.USE_CUDA:
                ys = torch.cat(
                    [ys, torch.ones(1, 1).long().fill_(next_word).cuda()],
                    dim=1)
                ys = ys.cuda()
            else:
                ys = torch.cat(
                    [ys, torch.ones(1, 1).long().fill_(next_word)], dim=1)
            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

        sent = []
        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>' or e == '<PAD>': break
                else: st += e + ' '
            sent.append(st)
        return sent
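A minimal, hypothetical driver for this class, sketched for orientation only: the data_loader and its batch format are assumptions (they would have to match get_input_from_batch / get_output_from_batch above), and the tokenizer import assumes the pytorch_pretrained_bert package that the output_all_encoded_layers keyword above comes from.

# Hypothetical usage sketch, not part of the original project.
from pytorch_pretrained_bert import BertTokenizer  # assumed package

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = Summarizer(is_draft=True, tokenizer=tokenizer)
for step, batch in enumerate(data_loader):  # data_loader is an assumed iterator of preprocessed batches
    loss = model.train_one_batch(batch, train=True)
    if step % 100 == 0:
        print("step {} loss {:.4f}".format(step, loss.item()))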
Example #20
class Seq2SPG(nn.Module):
    def __init__(self,
                 vocab,
                 model_file_path=None,
                 is_eval=False,
                 load_optim=False):
        super(Seq2SPG, self).__init__()
        self.vocab = vocab
        self.vocab_size = vocab.n_words

        self.embedding = share_embedding(self.vocab, config.preptrained)
        self.encoder = nn.LSTM(config.emb_dim,
                               config.hidden_dim,
                               config.hop,
                               bidirectional=False,
                               batch_first=True,
                               dropout=0.2)
        self.encoder2decoder = nn.Linear(config.hidden_dim, config.hidden_dim)
        self.decoder = LSTMAttentionDot(config.emb_dim,
                                        config.hidden_dim,
                                        batch_first=True)
        self.memory = MLP(
            config.hidden_dim + config.emb_dim,
            [config.private_dim1, config.private_dim2, config.private_dim3],
            config.hidden_dim)
        self.dec_gate = nn.Linear(config.hidden_dim, 2 * config.hidden_dim)
        self.mem_gate = nn.Linear(config.hidden_dim, 2 * config.hidden_dim)
        self.generator = Generator(config.hidden_dim, self.vocab_size)
        self.hooks = {}  # Save the model structure of each task as masks over the parameters
        if config.weight_sharing:
            # Share the weight matrix between target word embedding & the final logit dense layer
            self.generator.proj.weight = self.embedding.weight
        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)
        if (config.label_smoothing):
            self.criterion = LabelSmoothing(size=self.vocab_size,
                                            padding_idx=config.PAD_idx,
                                            smoothing=0.1)
            self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx)
        if is_eval:
            self.encoder = self.encoder.eval()
            self.encoder2decoder = self.encoder2decoder.eval()
            self.decoder = self.decoder.eval()
            self.generator = self.generator.eval()
            self.embedding = self.embedding.eval()
            self.memory = self.memory.eval()
            self.dec_gate = self.dec_gate.eval()
            self.mem_gate = self.mem_gate.eval()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
        if (config.noam):
            self.optimizer = NoamOpt(
                config.hidden_dim, 1, 4000,
                torch.optim.Adam(self.parameters(),
                                 lr=0,
                                 betas=(0.9, 0.98),
                                 eps=1e-9))
        if config.use_sgd:
            self.optimizer = torch.optim.SGD(self.parameters(), lr=config.lr)
        if model_file_path is not None:
            print("loading weights")
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            print("LOSS", state['current_loss'])
            self.encoder.load_state_dict(state['encoder_state_dict'])
            self.encoder2decoder.load_state_dict(
                state['encoder2decoder_state_dict'])
            self.decoder.load_state_dict(state['decoder_state_dict'])
            self.generator.load_state_dict(state['generator_dict'])
            self.embedding.load_state_dict(state['embedding_dict'])
            self.memory.load_state_dict(state['memory_dict'])
            self.dec_gate.load_state_dict(state['dec_gate_dict'])
            self.mem_gate.load_state_dict(state['mem_gate_dict'])
            if (load_optim):
                self.optimizer.load_state_dict(state['optimizer'])

        if (config.USE_CUDA):
            self.encoder = self.encoder.cuda()
            self.encoder2decoder = self.encoder2decoder.cuda()
            self.decoder = self.decoder.cuda()
            self.generator = self.generator.cuda()
            self.criterion = self.criterion.cuda()
            self.embedding = self.embedding.cuda()
            self.memory = self.memory.cuda()
            self.dec_gate = self.dec_gate.cuda()
            self.mem_gate = self.mem_gate.cuda()
        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""

    def save_model(self,
                   running_avg_ppl,
                   iter,
                   f1_g,
                   f1_b,
                   ent_g,
                   ent_b,
                   log=False,
                   d="tmaml_sim_model"):
        state = {
            'iter': iter,
            'encoder_state_dict': self.encoder.state_dict(),
            'encoder2decoder_state_dict': self.encoder2decoder.state_dict(),
            'decoder_state_dict': self.decoder.state_dict(),
            'generator_dict': self.generator.state_dict(),
            'embedding_dict': self.embedding.state_dict(),
            'memory_dict': self.memory.state_dict(),
            'dec_gate_dict': self.dec_gate.state_dict(),
            'mem_gate_dict': self.mem_gate.state_dict(),
            #'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_ppl
        }
        if log:
            model_save_path = os.path.join(
                d, 'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(
                    iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b))
        else:
            model_save_path = os.path.join(
                self.model_dir,
                'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(
                    iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b))
        self.best_path = model_save_path
        torch.save(state, model_save_path)

    def get_state(self, batch):
        """Get cell states and hidden states for LSTM"""
        batch_size = batch.size(0) \
            if self.encoder.batch_first else batch.size(1)
        h0_encoder = Variable(torch.zeros(self.encoder.num_layers, batch_size,
                                          config.hidden_dim),
                              requires_grad=False)
        c0_encoder = Variable(torch.zeros(self.encoder.num_layers, batch_size,
                                          config.hidden_dim),
                              requires_grad=False)

        if config.USE_CUDA:
            return h0_encoder.cuda(), c0_encoder.cuda()
        return h0_encoder, c0_encoder

    def compute_hooks(self, task):
        """Compute the masks of the private module"""
        current_layer = 3
        out_mask = torch.ones(self.memory.output_size)
        self.hooks[task] = {}
        self.hooks[task]["w_hooks"] = {}
        self.hooks[task]["b_hooks"] = {}
        while (current_layer >= 0):
            connections = self.memory.layers[current_layer].weight.data
            output_size, input_size = connections.shape
            mask = connections.abs() > 0.05
            in_mask = torch.zeros(input_size)
            for index, line in enumerate(mask):
                if (out_mask[index] == 1):
                    torch.max(in_mask, (line.cpu() != 0).float(), out=in_mask)
            if (config.USE_CUDA):
                self.hooks[task]["b_hooks"][current_layer] = out_mask.cuda()
                self.hooks[task]["w_hooks"][current_layer] = torch.mm(
                    out_mask.unsqueeze(1), in_mask.unsqueeze(0)).cuda()
            else:
                self.hooks[task]["b_hooks"][current_layer] = out_mask
                self.hooks[task]["w_hooks"][current_layer] = torch.mm(
                    out_mask.unsqueeze(1), in_mask.unsqueeze(0))
            out_mask = in_mask
            current_layer -= 1

    def register_hooks(self, task):
        if "hook_handles" not in self.hooks[task]:
            self.hooks[task]["hook_handles"] = []
        for i, l in enumerate(self.memory.layers):
            self.hooks[task]["hook_handles"].append(
                l.bias.register_hook(make_hook(
                    self.hooks[task]["b_hooks"][i])))
            self.hooks[task]["hook_handles"].append(
                l.weight.register_hook(
                    make_hook(self.hooks[task]["w_hooks"][i])))

    def unhook(self, task):
        for handle in self.hooks[task]["hook_handles"]:
            handle.remove()
        self.hooks[task]["hook_handles"] = []

    def train_one_batch(self, batch, train=True, mode="pretrain", task=0):
        enc_batch, _, enc_lens, enc_batch_extend_vocab, extra_zeros, _, _, _ = get_input_from_batch(
            batch)
        dec_batch, _, _, _, _, _ = get_output_from_batch(batch)

        if (config.noam):
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        ## Encode
        self.h0_encoder, self.c0_encoder = self.get_state(enc_batch)
        src_h, (src_h_t,
                src_c_t) = self.encoder(self.embedding(enc_batch),
                                        (self.h0_encoder, self.c0_encoder))
        h_t = src_h_t[-1]
        c_t = src_c_t[-1]

        # Decode
        decoder_init_state = nn.Tanh()(self.encoder2decoder(h_t))

        sos_token = torch.LongTensor([config.SOS_idx] *
                                     enc_batch.size(0)).unsqueeze(1)
        if config.USE_CUDA: sos_token = sos_token.cuda()
        dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1)
        target_embedding = self.embedding(dec_batch_shift)
        ctx = src_h.transpose(0, 1)
        trg_h, (_, _) = self.decoder(target_embedding,
                                     (decoder_init_state, c_t), ctx)

        #Memory
        mem_h_input = torch.cat(
            (decoder_init_state.unsqueeze(1), trg_h[:, 0:-1, :]), 1)
        mem_input = torch.cat((target_embedding, mem_h_input), 2)
        mem_output = self.memory(mem_input)

        #Combine
        gates = self.dec_gate(trg_h) + self.mem_gate(mem_output)
        decoder_gate, memory_gate = gates.chunk(2, 2)
        decoder_gate = torch.sigmoid(decoder_gate)
        memory_gate = torch.sigmoid(memory_gate)
        pre_logit = torch.tanh(decoder_gate * trg_h + memory_gate * mem_output)
        logit = self.generator(pre_logit)

        if mode == "pretrain":
            loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                                  dec_batch.contiguous().view(-1))
            if train:
                loss.backward()
                self.optimizer.step()
            if (config.label_smoothing):
                loss = self.criterion_ppl(
                    logit.contiguous().view(-1, logit.size(-1)),
                    dec_batch.contiguous().view(-1))
            return loss.item(), math.exp(min(loss.item(), 100)), loss

        elif mode == "select":
            loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                                  dec_batch.contiguous().view(-1))
            if (train):
                l1_loss = 0.0
                for p in self.memory.parameters():
                    l1_loss += torch.sum(torch.abs(p))
                loss += 0.0005 * l1_loss
                loss.backward()
                self.optimizer.step()
                self.compute_hooks(task)
            if (config.label_smoothing):
                loss = self.criterion_ppl(
                    logit.contiguous().view(-1, logit.size(-1)),
                    dec_batch.contiguous().view(-1))
            return loss.item(), math.exp(min(loss.item(), 100)), loss

        else:
            loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                                  dec_batch.contiguous().view(-1))
            if (train):
                self.register_hooks(task)
                loss.backward()
                self.optimizer.step()
                self.unhook(task)
            if (config.label_smoothing):
                loss = self.criterion_ppl(
                    logit.contiguous().view(-1, logit.size(-1)),
                    dec_batch.contiguous().view(-1))
            return loss.item(), math.exp(min(loss.item(), 100)), loss
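The mode argument of train_one_batch implies a three-stage schedule: a shared pretraining pass, a "select" pass whose L1 penalty sparsifies the private MLP and recomputes the per-task masks, and a final pass in which gradients of the private module are masked by the registered hooks. A hypothetical outer loop, with loader and task names invented for illustration:

# Hypothetical training schedule for Seq2SPG; pretrain_loader, task_loaders and
# num_tasks are illustrative names, not part of the original code.
model = Seq2SPG(vocab)
for batch in pretrain_loader:
    model.train_one_batch(batch, train=True, mode="pretrain")
for task in range(num_tasks):
    for batch in task_loaders[task]:
        # L1-regularized pass; compute_hooks(task) is called inside after each step
        model.train_one_batch(batch, train=True, mode="select", task=task)
    for batch in task_loaders[task]:
        # any other mode takes the final branch: hooks mask the private-module gradients
        model.train_one_batch(batch, train=True, mode="finetune", task=task)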
Example #21
class Transformer(nn.Module):
    def __init__(self,
                 vocab,
                 model_file_path=None,
                 is_eval=False,
                 load_optim=False):
        super(Transformer, self).__init__()
        self.vocab = vocab
        self.vocab_size = vocab.n_words

        self.embedding = share_embedding(self.vocab, config.pretrain_emb)
        self.encoder = Encoder(config.emb_dim,
                               config.hidden_dim,
                               num_layers=config.hop,
                               num_heads=config.heads,
                               total_key_depth=config.depth,
                               total_value_depth=config.depth,
                               filter_size=config.filter,
                               universal=config.universal)

        ## multiple decoders
        self.decoder = Decoder(config.emb_dim,
                               hidden_size=config.hidden_dim,
                               num_layers=config.hop,
                               num_heads=config.heads,
                               total_key_depth=config.depth,
                               total_value_depth=config.depth,
                               filter_size=config.filter)

        self.generator = Generator(config.hidden_dim, self.vocab_size)

        if config.weight_sharing:
            # Share the weight matrix between target word embedding & the final logit dense layer
            self.generator.proj.weight = self.embedding.lut.weight

        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)
        if config.label_smoothing:
            self.criterion = LabelSmoothing(size=self.vocab_size,
                                            padding_idx=config.PAD_idx,
                                            smoothing=0.1)
            self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx)

        self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
        if config.noam:
            self.optimizer = NoamOpt(
                config.hidden_dim, 1, 8000,
                torch.optim.Adam(self.parameters(),
                                 lr=0,
                                 betas=(0.9, 0.98),
                                 eps=1e-9))

        if model_file_path is not None:
            print("loading weights")
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            self.encoder.load_state_dict(state['encoder_state_dict'])
            self.decoder.load_state_dict(state['decoder_state_dict'])
            self.generator.load_state_dict(state['generator_dict'])
            self.embedding.load_state_dict(state['embedding_dict'])
            if load_optim:
                self.optimizer.load_state_dict(state['optimizer'])
            self.eval()

        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""

    def save_model(self, running_avg_ppl, iter, f1_g, f1_b, ent_g, ent_b):

        state = {
            'iter': iter,
            'encoder_state_dict': self.encoder.state_dict(),
            'decoder_state_dict': self.decoder.state_dict(),
            'generator_dict': self.generator.state_dict(),
            'embedding_dict': self.embedding.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_ppl
        }
        model_save_path = os.path.join(
            self.model_dir,
            'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(
                iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b))
        self.best_path = model_save_path
        torch.save(state, model_save_path)

    def train_one_batch(self, batch, iter, train=True):
        enc_batch = batch["review_batch"]
        enc_batch_extend_vocab = batch["review_ext_batch"]
        oovs = batch["oovs"]
        max_oov_length = len(
            sorted(oovs, key=lambda i: len(i), reverse=True)[0])
        extra_zeros = Variable(torch.zeros(
            (enc_batch.size(0), max_oov_length))).to(config.device)

        dec_batch = batch["tags_batch"]
        dec_ext_batch = batch["tags_ext_batch"]

        if config.noam:
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        ## Embedding - context
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(
            1)  # (bsz, src_len)->(bsz, 1, src_len)
        # emb_mask = self.embedding(batch["mask_context"])
        # src_emb = self.embedding(enc_batch)+emb_mask
        src_emb = self.embedding(enc_batch)
        encoder_outputs = self.encoder(src_emb,
                                       mask_src)  # (bsz, src_len, emb_dim)

        sos_token = torch.LongTensor([config.SOS_idx] *
                                     enc_batch.size(0)).unsqueeze(1).to(
                                         config.device)  # (bsz, 1)
        dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]),
                                    1)  # (bsz, tgt_len)

        mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
        pre_logit, attn_dist = self.decoder(self.embedding(dec_batch_shift),
                                            encoder_outputs,
                                            (mask_src, mask_trg))

        logit = self.generator(
            pre_logit, attn_dist,
            enc_batch_extend_vocab if config.pointer_gen else None,
            extra_zeros)
        #logit = F.log_softmax(logit,dim=-1) #fix the name later
        ## loss: NLL if pointer-gen else cross entropy
        if config.pointer_gen:
            loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                                  dec_ext_batch.contiguous().view(-1))
        else:
            loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                                  dec_batch.contiguous().view(-1))

        if config.label_smoothing:
            loss_ppl = self.criterion_ppl(
                logit.contiguous().view(-1, logit.size(-1)),
                dec_ext_batch.contiguous().view(-1)
                if config.pointer_gen else dec_batch.contiguous().view(-1))

        if train:
            loss.backward()
            self.optimizer.step()

        if config.label_smoothing:
            if torch.isnan(loss_ppl).sum().item() != 0 or torch.isinf(
                    loss_ppl).sum().item() != 0:
                print("check")
                pdb.set_trace()
            return loss_ppl.item(), math.exp(min(loss_ppl.item(), 100)), 0, 0
        else:
            return loss.item(), math.exp(min(loss.item(), 100)), 0, 0

    def compute_act_loss(self, module):
        R_t = module.remainders
        N_t = module.n_updates
        p_t = R_t + N_t
        avg_p_t = torch.sum(torch.sum(p_t, dim=1) / p_t.size(1)) / p_t.size(0)
        loss = config.act_loss_weight * avg_p_t.item()
        return loss

    def decoder_greedy(self, batch, max_dec_step=30):
        enc_batch = batch["review_batch"]
        enc_batch_extend_vocab = batch["review_ext_batch"]

        oovs = batch["oovs"]
        max_oov_length = len(
            sorted(oovs, key=lambda i: len(i), reverse=True)[0])
        extra_zeros = Variable(torch.zeros(
            (enc_batch.size(0), max_oov_length))).to(config.device)

        ## Encode - context
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(
            1)  # (bsz, src_len)->(bsz, 1, src_len)
        # emb_mask = self.embedding(batch["mask_context"])
        # src_emb = self.embedding(enc_batch) + emb_mask  # todo eos or sentence embedding??
        src_emb = self.embedding(enc_batch)
        encoder_outputs = self.encoder(src_emb,
                                       mask_src)  # (bsz, src_len, emb_dim)
        enc_ext_batch = enc_batch_extend_vocab

        # ys = torch.ones(1, 1).fill_(config.SOS_idx).long()
        ys = torch.zeros(enc_batch.size(0), 1).fill_(config.SOS_idx).long().to(
            config.device)  # at test time the batch size is set to 1
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        decoded_words = []
        for i in range(max_dec_step + 1):
            if config.project:
                out, attn_dist = self.decoder(
                    self.embedding_proj_in(self.embedding(ys)),
                    self.embedding_proj_in(encoder_outputs),
                    (mask_src, mask_trg))
            else:
                out, attn_dist = self.decoder(self.embedding(ys),
                                              encoder_outputs,
                                              (mask_src, mask_trg))
            prob = self.generator(out, attn_dist, enc_ext_batch, extra_zeros)

            _, next_word = torch.max(prob[:, -1], dim=1)  # bsz=1
            cur_words = []
            for i_batch, ni in enumerate(next_word.view(-1)):
                if ni.item() == config.EOS_idx:
                    cur_words.append('<EOS>')
                elif ni.item() in self.vocab.index2word:
                    cur_words.append(self.vocab.index2word[ni.item()])
                else:
                    cur_words.append(oovs[i_batch][ni.item() -
                                                   self.vocab.n_words])
            decoded_words.append(cur_words)
            next_word = next_word.data[0]

            if next_word.item() not in self.vocab.index2word:
                next_word = torch.tensor(config.UNK_idx)

            ys = torch.cat([
                ys,
                torch.zeros(enc_batch.size(0), 1).long().fill_(next_word).to(
                    config.device)
            ],
                           dim=1).to(config.device)

            # if config.USE_CUDA:
            #     ys = torch.cat([ys, torch.zeros(enc_batch.size(0), 1).long().fill_(next_word).cuda()], dim=1)
            #     ys = ys.cuda()
            # else:
            #     ys = torch.cat([ys, torch.zeros(enc_batch.size(0), 1).long().fill_(next_word)], dim=1)
            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

        sent = []
        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>':
                    break
                else:
                    st += e + ' '
            sent.append(st)
        return sent
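The compute_act_loss helper used by the Universal Transformer variants here is the Adaptive Computation Time ponder cost: the mean, over batch and positions, of remainders plus update counts, scaled by config.act_loss_weight. A restatement for clarity (illustrative only); note that, like the original, it returns a plain Python float via .item(), so the term affects the reported loss value but not the gradient:

# Illustrative restatement of compute_act_loss: the ACT ponder cost.
def act_ponder_cost(remainders, n_updates, act_loss_weight):
    p_t = remainders + n_updates  # both of shape (batch, seq_len)
    return act_loss_weight * p_t.mean().item()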
Example #22
class Transformer(nn.Module):

    def __init__(self, vocab, model_file_path=None, is_eval=False, load_optim=False):
        super(Transformer, self).__init__()
        self.vocab = vocab
        self.vocab_size = vocab.n_words

        self.embedding = share_embedding(self.vocab,config.preptrained)
        self.encoder = Encoder(config.emb_dim, config.hidden_dim, num_layers=config.hop, num_heads=config.heads, 
                                total_key_depth=config.depth, total_value_depth=config.depth,
                                filter_size=config.filter,universal=config.universal)
            
        self.decoder = Decoder(config.emb_dim, config.hidden_dim, num_layers=config.hop, num_heads=config.heads, 
                                total_key_depth=config.depth,total_value_depth=config.depth,
                                filter_size=config.filter,universal=config.universal)
        self.generator = Generator(config.hidden_dim,self.vocab_size)

        if config.weight_sharing:
            # Share the weight matrix between target word embedding & the final logit dense layer
            self.generator.proj.weight = self.embedding.weight

        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)
        if (config.label_smoothing):
            self.criterion = LabelSmoothing(size=self.vocab_size, padding_idx=config.PAD_idx, smoothing=0.1)
            self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx)
        if is_eval:
            self.encoder = self.encoder.eval()
            self.decoder = self.decoder.eval()
            self.generator = self.generator.eval()
            self.embedding = self.embedding.eval()

    
        self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
        if(config.noam):
            self.optimizer = NoamOpt(config.hidden_dim, 1, 4000, torch.optim.Adam(self.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
        if config.use_sgd:
            self.optimizer = torch.optim.SGD(self.parameters(), lr=config.lr)
        if model_file_path is not None:
            print("loading weights")
            state = torch.load(model_file_path, map_location= lambda storage, location: storage)
            print("LOSS",state['current_loss'])
            self.encoder.load_state_dict(state['encoder_state_dict'])
            self.decoder.load_state_dict(state['decoder_state_dict'])
            self.generator.load_state_dict(state['generator_dict'])
            self.embedding.load_state_dict(state['embedding_dict'])
            if (load_optim):
                self.optimizer.load_state_dict(state['optimizer'])

        if (config.USE_CUDA):
            self.encoder = self.encoder.cuda()
            self.decoder = self.decoder.cuda()
            self.generator = self.generator.cuda()
            self.criterion = self.criterion.cuda()
            self.embedding = self.embedding.cuda()
        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""

    def save_model(self, running_avg_ppl, iter, f1_g,f1_b,ent_g,ent_b):
        state = {
            'iter': iter,
            'encoder_state_dict': self.encoder.state_dict(),
            'decoder_state_dict': self.decoder.state_dict(),
            'generator_dict': self.generator.state_dict(),
            'embedding_dict': self.embedding.state_dict(),
            #'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_ppl
        }
        model_save_path = os.path.join(self.model_dir, 'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(iter,running_avg_ppl,f1_g,f1_b,ent_g,ent_b) )
        self.best_path = model_save_path
        torch.save(state, model_save_path)

    def train_one_batch(self, batch, train=True):
        ## pad and other stuff
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)
        dec_batch, _, _, _, _ = get_output_from_batch(batch)
        
        if(config.noam):
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        ## Encode
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        encoder_outputs = self.encoder(self.embedding(enc_batch),mask_src)

        # Decode 
        sos_token = torch.LongTensor([config.SOS_idx] * enc_batch.size(0)).unsqueeze(1)
        if config.USE_CUDA: sos_token = sos_token.cuda()
        dec_batch_shift = torch.cat((sos_token,dec_batch[:, :-1]),1)

        mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
        pre_logit, attn_dist = self.decoder(self.embedding(dec_batch_shift),encoder_outputs, (mask_src,mask_trg))
        ## compute output dist
        logit = self.generator(pre_logit,attn_dist,enc_batch_extend_vocab, extra_zeros)
        
        ## loss: NLL if ptr else cross entropy
        loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1))

        if(train):
            loss.backward()
            self.optimizer.step()
        if(config.label_smoothing): 
            loss = self.criterion_ppl(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1))
        
        return loss.item(), math.exp(min(loss.item(), 100)), loss
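Several constructors above wrap Adam in NoamOpt(config.hidden_dim, 1, warmup, ...). Assuming NoamOpt implements the standard warmup schedule from "Attention Is All You Need" (as in the Annotated Transformer), the learning rate at step s is factor * d_model^-0.5 * min(s^-0.5, s * warmup^-1.5), i.e. linear warmup followed by inverse-square-root decay:

# Sketch of the learning-rate schedule assumed to be implemented by NoamOpt
# (factor=1, model_size=config.hidden_dim, warmup=4000 in the example above).
def noam_rate(step, model_size, factor=1, warmup=4000):
    # step counts from 1; warmup phase grows linearly, then decays as 1/sqrt(step)
    return factor * (model_size ** -0.5) * min(step ** -0.5, step * warmup ** -1.5)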
Example #23
class Transformer(nn.Module):
    def __init__(self,
                 vocab,
                 model_file_path=None,
                 is_eval=False,
                 load_optim=False):
        super(Transformer, self).__init__()
        self.vocab = vocab
        self.vocab_size = vocab.n_words

        self.embedding = share_embedding(self.vocab, config.preptrained)
        self.encoder = Encoder(config.emb_dim,
                               config.hidden_dim,
                               num_layers=config.hop,
                               num_heads=config.heads,
                               total_key_depth=config.depth,
                               total_value_depth=config.depth,
                               filter_size=config.filter,
                               universal=config.universal)

        self.decoder = Decoder(config.emb_dim,
                               config.hidden_dim,
                               num_layers=config.hop,
                               num_heads=config.heads,
                               total_key_depth=config.depth,
                               total_value_depth=config.depth,
                               filter_size=config.filter,
                               universal=config.universal)
        self.generator = Generator(config.hidden_dim, self.vocab_size)

        if config.weight_sharing:
            # Share the weight matrix between target word embedding & the final
            # logit dense layer
            self.generator.proj.weight = self.embedding.weight

        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)

        if (config.label_smoothing):
            self.criterion = LabelSmoothing(size=self.vocab_size,
                                            padding_idx=config.PAD_idx,
                                            smoothing=0.1)
            self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx)

        if is_eval:
            self.encoder = self.encoder.eval()
            self.decoder = self.decoder.eval()
            self.generator = self.generator.eval()
            self.embedding = self.embedding.eval()

        self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)

        if (config.noam):
            self.optimizer = NoamOpt(
                config.hidden_dim, 1, 4000,
                torch.optim.Adam(self.parameters(),
                                 lr=0,
                                 betas=(0.9, 0.98),
                                 eps=1e-9))

        if config.use_sgd:
            self.optimizer = torch.optim.SGD(self.parameters(), lr=config.lr)

        if model_file_path is not None:
            print("loading weights")
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            print("LOSS", state['current_loss'])
            self.encoder.load_state_dict(state['encoder_state_dict'])
            self.decoder.load_state_dict(state['decoder_state_dict'])
            self.generator.load_state_dict(state['generator_dict'])
            self.embedding.load_state_dict(state['embedding_dict'])
            if (load_optim):
                self.optimizer.load_state_dict(state['optimizer'])

        if (config.USE_CUDA):
            self.encoder = self.encoder.cuda()
            self.decoder = self.decoder.cuda()
            self.generator = self.generator.cuda()
            self.criterion = self.criterion.cuda()
            self.embedding = self.embedding.cuda()

        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""

    def save_model(self, running_avg_ppl, iter, f1_g, f1_b, ent_g, ent_b):
        state = {
            'iter': iter,
            'encoder_state_dict': self.encoder.state_dict(),
            'decoder_state_dict': self.decoder.state_dict(),
            'generator_dict': self.generator.state_dict(),
            'embedding_dict': self.embedding.state_dict(),
            # 'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_ppl
        }
        model_save_path = os.path.join(
            self.model_dir,
            'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(
                iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b))
        self.best_path = model_save_path
        torch.save(state, model_save_path)

    def train_one_batch(self, batch, train=True):
        # pad and other stuff
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = \
            get_input_from_batch(batch)
        dec_batch, _, _, _, _ = get_output_from_batch(batch)

        if (config.noam):
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        # Encode
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)

        # Decode
        sos_token = torch.LongTensor([config.SOS_idx] *
                                     enc_batch.size(0)).unsqueeze(1)
        if config.USE_CUDA:
            sos_token = sos_token.cuda()
        dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1)

        mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
        pre_logit, attn_dist = self.decoder(self.embedding(dec_batch_shift),
                                            encoder_outputs,
                                            (mask_src, mask_trg))
        # compute output dist
        logit = self.generator(pre_logit, attn_dist, enc_batch_extend_vocab,
                               extra_zeros)

        # loss: NLL if ptr else cross entropy
        loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                              dec_batch.contiguous().view(-1))

        if (config.act):
            loss += self.compute_act_loss(self.encoder)
            loss += self.compute_act_loss(self.decoder)

        if (train):
            loss.backward()
            self.optimizer.step()
        if (config.label_smoothing):
            loss = self.criterion_ppl(
                logit.contiguous().view(-1, logit.size(-1)),
                dec_batch.contiguous().view(-1))

        return loss.item(), math.exp(min(loss.item(), 100)), loss

    def compute_act_loss(self, module):
        R_t = module.remainders
        N_t = module.n_updates
        p_t = R_t + N_t
        avg_p_t = torch.sum(torch.sum(p_t, dim=1) / p_t.size(1)) / p_t.size(0)
        loss = config.act_loss_weight * avg_p_t.item()
        return loss

    def decoder_greedy(self, batch):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)

        ys = torch.ones(1, 1).fill_(config.SOS_idx).long()
        if config.USE_CUDA:
            ys = ys.cuda()
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        decoded_words = []
        for i in range(config.max_dec_step):
            out, attn_dist = self.decoder(self.embedding(ys), encoder_outputs,
                                          (mask_src, mask_trg))
            prob = self.generator(out, attn_dist, enc_batch_extend_vocab,
                                  extra_zeros)
            _, next_word = torch.max(prob[:, -1], dim=1)
            decoded_words.append([
                '<EOS>' if ni.item() == config.EOS_idx else
                self.vocab.index2word[ni.item()] for ni in next_word.view(-1)
            ])
            next_word = next_word.data[0]

            if config.USE_CUDA:
                ys = torch.cat(
                    [ys, torch.ones(1, 1).long().fill_(next_word).cuda()],
                    dim=1)
                ys = ys.cuda()
            else:
                ys = torch.cat(
                    [ys, torch.ones(1, 1).long().fill_(next_word)], dim=1)
            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

        sent = []
        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>':
                    break
                else:
                    st += e + ' '
            sent.append(st)
        return sent

    def score_sentence(self, batch):
        # pad and other stuff
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)
        dec_batch, _, _, _, _ = get_output_from_batch(batch)
        cand_batch = batch["cand_index"]
        hit_1 = 0
        for i, b in enumerate(enc_batch):
            # Encode
            mask_src = b.unsqueeze(0).data.eq(config.PAD_idx).unsqueeze(1)
            encoder_outputs = self.encoder(self.embedding(b.unsqueeze(0)),
                                           mask_src)
            rank = {}
            for j, c in enumerate(cand_batch[i]):
                if config.USE_CUDA:
                    c = c.cuda()
                # Decode
                sos_token = torch.LongTensor(
                    [config.SOS_idx] * b.unsqueeze(0).size(0)).unsqueeze(1)
                if config.USE_CUDA:
                    sos_token = sos_token.cuda()
                dec_batch_shift = torch.cat(
                    (sos_token, c.unsqueeze(0)[:, :-1]), 1)

                mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
                pre_logit, attn_dist = self.decoder(
                    self.embedding(dec_batch_shift), encoder_outputs,
                    (mask_src, mask_trg))

                # compute output dist
                logit = self.generator(pre_logit, attn_dist,
                                       enc_batch_extend_vocab[i].unsqueeze(0),
                                       extra_zeros)
                loss = self.criterion(
                    logit.contiguous().view(-1, logit.size(-1)),
                    c.unsqueeze(0).contiguous().view(-1))
                # print("CANDIDATE {}".format(j), loss.item(), math.exp(min(loss.item(), 100)))
                rank[j] = math.exp(min(loss.item(), 100))
            s = sorted(rank.items(), key=lambda x: x[1], reverse=False)
            # candidates are sorted in reverse order, so the last one (index 19) is the correct one
            if s[1][0] == 19:
                hit_1 += 1
        return hit_1 / float(len(enc_batch))
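Throughout these examples perplexity is reported as math.exp(min(loss, 100)); the clamp simply guards against float overflow when the per-token NLL is still large early in training. A one-line helper capturing that convention (the name is illustrative):

import math

def ppl_from_nll(nll_loss):
    # exp of the per-token negative log-likelihood, clamped to avoid overflow
    return math.exp(min(nll_loss, 100))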