Example #1
class CvaeTrans(nn.Module):
    def __init__(self,
                 vocab,
                 emo_number,
                 model_file_path=None,
                 is_eval=False,
                 load_optim=False):
        super(CvaeTrans, self).__init__()
        self.vocab = vocab
        self.vocab_size = vocab.n_words

        self.embedding = share_embedding(self.vocab, config.pretrain_emb)
        self.encoder = Encoder(config.emb_dim,
                               config.hidden_dim,
                               num_layers=config.hop,
                               num_heads=config.heads,
                               total_key_depth=config.depth,
                               total_value_depth=config.depth,
                               filter_size=config.filter,
                               universal=config.universal)

        self.r_encoder = Encoder(config.emb_dim,
                                 config.hidden_dim,
                                 num_layers=config.hop,
                                 num_heads=config.heads,
                                 total_key_depth=config.depth,
                                 total_value_depth=config.depth,
                                 filter_size=config.filter,
                                 universal=config.universal)
        self.decoder = Decoder(config.emb_dim,
                               hidden_size=config.hidden_dim,
                               num_layers=config.hop,
                               num_heads=config.heads,
                               total_key_depth=config.depth,
                               total_value_depth=config.depth,
                               filter_size=config.filter)
        self.latent_layer = Latent(is_eval)
        self.bow = SoftmaxOutputLayer(config.hidden_dim, self.vocab_size)
        if config.multitask:
            self.emo = SoftmaxOutputLayer(config.hidden_dim, emo_number)
            self.emo_criterion = nn.NLLLoss()
        self.generator = Generator(config.hidden_dim, self.vocab_size)

        if config.weight_sharing:
            # Share the weight matrix between target word embedding & the final logit dense layer
            self.generator.proj.weight = self.embedding.lut.weight

        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)

        if model_file_path is not None:
            print("loading weights")
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            self.encoder.load_state_dict(state['encoder_state_dict'])
            self.r_encoder.load_state_dict(state['r_encoder_state_dict'])
            self.decoder.load_state_dict(state['decoder_state_dict'])
            self.generator.load_state_dict(state['generator_dict'])
            self.embedding.load_state_dict(state['embedding_dict'])
            self.latent_layer.load_state_dict(state['latent_dict'])
            self.bow.load_state_dict(state['bow'])
        if (config.USE_CUDA):
            self.cuda()
        if is_eval:
            self.eval()
        else:
            self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
            if (config.noam):
                self.optimizer = NoamOpt(
                    config.hidden_dim, 1, 8000,
                    torch.optim.Adam(self.parameters(),
                                     lr=0,
                                     betas=(0.9, 0.98),
                                     eps=1e-9))
            if (load_optim):
                self.optimizer.load_state_dict(state['optimizer'])
                if config.USE_CUDA:
                    for state in self.optimizer.state.values():
                        for k, v in state.items():
                            if isinstance(v, torch.Tensor):
                                state[k] = v.cuda()
        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""

    def save_model(self, running_avg_ppl, iter, f1_g, f1_b, ent_g, ent_b):

        state = {
            'iter': iter,
            'encoder_state_dict': self.encoder.state_dict(),
            'r_encoder_state_dict': self.r_encoder.state_dict(),
            'decoder_state_dict': self.decoder.state_dict(),
            'generator_dict': self.generator.state_dict(),
            'embedding_dict': self.embedding.state_dict(),
            'latent_dict': self.latent_layer.state_dict(),
            'bow': self.bow.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_ppl
        }
        model_save_path = os.path.join(
            self.model_dir,
            'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(
                iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b))
        self.best_path = model_save_path
        torch.save(state, model_save_path)

    def train_one_batch(self, batch, iter, train=True):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)
        dec_batch, _, _, _, _ = get_output_from_batch(batch)

        if (config.noam):
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        ## Response encode
        mask_res = batch["posterior_batch"].data.eq(
            config.PAD_idx).unsqueeze(1)
        posterior_mask = self.embedding(batch["posterior_mask"])
        r_encoder_outputs = self.r_encoder(
            self.embedding(batch["posterior_batch"]) + posterior_mask,
            mask_res)
        ## Encode
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        emb_mask = self.embedding(batch["input_mask"])
        encoder_outputs = self.encoder(
            self.embedding(enc_batch) + emb_mask, mask_src)
        #latent variable
        if config.model == "cvaetrs":
            kld_loss, z = self.latent_layer(encoder_outputs[:, 0],
                                            r_encoder_outputs[:, 0],
                                            train=True)

        meta = self.embedding(batch["program_label"])
        if config.dataset == "empathetic":
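            # zero out the program-label embedding (no emotion/program conditioning for the empathetic dataset)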
            meta = meta - meta
        # Decode
        sos_token = torch.LongTensor([config.SOS_idx] *
                                     enc_batch.size(0)).unsqueeze(1)
        if config.USE_CUDA: sos_token = sos_token.cuda()

        dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]),
                                    1)  #(batch, len, embedding)
        mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
        input_vector = self.embedding(dec_batch_shift)
        if config.model == "cvaetrs":
            input_vector[:, 0] = input_vector[:, 0] + z + meta
        else:
            input_vector[:, 0] = input_vector[:, 0] + meta
        pre_logit, attn_dist = self.decoder(input_vector, encoder_outputs,
                                            (mask_src, mask_trg))

        ## compute output dist
        logit = self.generator(
            pre_logit,
            attn_dist,
            enc_batch_extend_vocab if config.pointer_gen else None,
            extra_zeros,
            attn_dist_db=None)
        ## loss: NLL if pointer_gen else cross entropy
        loss_rec = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                                  dec_batch.contiguous().view(-1))
        if config.model == "cvaetrs":
            z_logit = self.bow(z + meta)  # [batch_size, vocab_size]
            z_logit = z_logit.unsqueeze(1).repeat(1, logit.size(1), 1)
            loss_aux = self.criterion(
                z_logit.contiguous().view(-1, z_logit.size(-1)),
                dec_batch.contiguous().view(-1))
            if config.multitask:
                emo_logit = self.emo(encoder_outputs[:, 0])
                emo_loss = self.emo_criterion(emo_logit,
                                              batch["program_label"] - 9)
            #kl_weight = min(iter/config.full_kl_step, 0.28) if config.full_kl_step >0 else 1.0
            kl_weight = min(
                math.tanh(6 * iter / config.full_kl_step - 3) + 1, 1)
            loss = loss_rec + config.kl_ceiling * kl_weight * kld_loss + config.aux_ceiling * loss_aux
            if config.multitask:
                loss = loss_rec + config.kl_ceiling * kl_weight * kld_loss + config.aux_ceiling * loss_aux + emo_loss
            aux = loss_aux.item()
            elbo = loss_rec + kld_loss
        else:
            loss = loss_rec
            elbo = loss_rec
            kld_loss = torch.Tensor([0])
            aux = 0
            if config.multitask:
                emo_logit = self.emo(encoder_outputs[:, 0])
                emo_loss = self.emo_criterion(emo_logit,
                                              batch["program_label"] - 9)
                loss = loss_rec + emo_loss
        if (train):
            loss.backward()
            # clip gradient
            nn.utils.clip_grad_norm_(self.parameters(), config.max_grad_norm)
            self.optimizer.step()

        return loss_rec.item(), math.exp(min(
            loss_rec.item(), 100)), kld_loss.item(), aux, elbo.item()

    def decoder_greedy(self, batch, max_dec_step=50):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)

        ## Encode
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        emb_mask = self.embedding(batch["input_mask"])
        meta = self.embedding(batch["program_label"])
        if config.dataset == "empathetic":
            meta = meta - meta
        encoder_outputs = self.encoder(
            self.embedding(enc_batch) + emb_mask, mask_src)
        if config.model == "cvaetrs":
            kld_loss, z = self.latent_layer(encoder_outputs[:, 0],
                                            None,
                                            train=False)

        ys = torch.ones(enc_batch.shape[0], 1).fill_(config.SOS_idx).long()
        if config.USE_CUDA:
            ys = ys.cuda()
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

        decoded_words = []
        for i in range(max_dec_step + 1):
            input_vector = self.embedding(ys)
            if config.model == "cvaetrs":
                input_vector[:, 0] = input_vector[:, 0] + z + meta
            else:
                input_vector[:, 0] = input_vector[:, 0] + meta
            out, attn_dist = self.decoder(input_vector, encoder_outputs,
                                          (mask_src, mask_trg))

            prob = self.generator(out,
                                  attn_dist,
                                  enc_batch_extend_vocab,
                                  extra_zeros,
                                  attn_dist_db=None)
            _, next_word = torch.max(prob[:, -1], dim=1)
            decoded_words.append([
                '<EOS>' if ni.item() == config.EOS_idx else
                self.vocab.index2word[ni.item()] for ni in next_word.view(-1)
            ])

            if config.USE_CUDA:
                ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)
                ys = ys.cuda()
            else:
                ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)

            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        sent = []
        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>': break
                else: st += e + ' '
            sent.append(st)
        return sent
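
The tanh ramp in train_one_batch anneals the KL term from roughly zero up to its ceiling over the first half of config.full_kl_step iterations; kl_ceiling then sets the final weight. A minimal, self-contained sketch of that schedule (the full_kl_step value below is a hypothetical placeholder):

import math

# Hedged sketch of the KL-annealing weight used in CvaeTrans.train_one_batch.
# full_kl_step is a placeholder; the real value comes from config.
def kl_weight(step, full_kl_step=20000):
    return min(math.tanh(6 * step / full_kl_step - 3) + 1, 1)

print([round(kl_weight(s), 3) for s in (0, 5000, 10000, 20000)])
# roughly [0.005, 0.095, 1.0, 1]; the weight saturates at 1 by full_kl_step / 2 and stays clipped there.
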
Example #2
class Transformer_experts(nn.Module):
    def __init__(self,
                 vocab,
                 decoder_number,
                 model_file_path=None,
                 is_eval=False,
                 load_optim=False):
        super(Transformer_experts, self).__init__()
        self.vocab = vocab
        self.vocab_size = vocab.n_words

        self.embedding = share_embedding(self.vocab, config.pretrain_emb)
        self.encoder = Encoder(config.emb_dim,
                               config.hidden_dim,
                               num_layers=config.hop,
                               num_heads=config.heads,
                               total_key_depth=config.depth,
                               total_value_depth=config.depth,
                               filter_size=config.filter,
                               universal=config.universal)
        self.decoder_number = decoder_number
        ## multiple decoders
        self.decoder = MulDecoder(decoder_number,
                                  config.emb_dim,
                                  config.hidden_dim,
                                  num_layers=config.hop,
                                  num_heads=config.heads,
                                  total_key_depth=config.depth,
                                  total_value_depth=config.depth,
                                  filter_size=config.filter)

        self.decoder_key = nn.Linear(config.hidden_dim,
                                     decoder_number,
                                     bias=False)

        self.generator = Generator(config.hidden_dim, self.vocab_size)
        self.emoji_embedding = nn.Linear(64, config.emb_dim, bias=False)

        if config.weight_sharing:
            # Share the weight matrix between target word embedding & the final logit dense layer
            self.generator.proj.weight = self.embedding.lut.weight

        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)
        if (config.label_smoothing):
            self.criterion = LabelSmoothing(size=self.vocab_size,
                                            padding_idx=config.PAD_idx,
                                            smoothing=0.1)
            self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx)

        if config.softmax:
            self.attention_activation = nn.Softmax(dim=1)
        else:
            self.attention_activation = nn.Sigmoid()  #nn.Softmax()

        self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
        if (config.noam):
            self.optimizer = NoamOpt(
                config.hidden_dim, 1, 8000,
                torch.optim.Adam(self.parameters(),
                                 lr=0,
                                 betas=(0.9, 0.98),
                                 eps=1e-9))

        if model_file_path is not None:
            print("loading weights")
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            self.encoder.load_state_dict(state['encoder_state_dict'])
            self.decoder.load_state_dict(state['decoder_state_dict'])
            self.decoder_key.load_state_dict(state['decoder_key_state_dict'])
            #self.emoji_embedding.load_state_dict(state['emoji_embedding_dict'])
            self.generator.load_state_dict(state['generator_dict'])
            self.embedding.load_state_dict(state['embedding_dict'])
            if (load_optim):
                self.optimizer.load_state_dict(state['optimizer'])
            self.eval()

        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""

    def save_model(self, running_avg_ppl, iter, f1_g, f1_b, ent_g, ent_b):
        state = {
            'iter': iter,
            'encoder_state_dict': self.encoder.state_dict(),
            'decoder_state_dict': self.decoder.state_dict(),
            'decoder_key_state_dict': self.decoder_key.state_dict(),
            #'emoji_embedding_dict': self.emoji_embedding.state_dict(),
            'generator_dict': self.generator.state_dict(),
            'embedding_dict': self.embedding.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_ppl
        }
        model_save_path = os.path.join(
            self.model_dir,
            'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(
                iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b))
        self.best_path = model_save_path
        torch.save(state, model_save_path)

    def train_one_batch(self, batch, iter, train=True):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)
        dec_batch, _, _, _, _ = get_output_from_batch(batch)

        if (config.noam):
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()
        ## Encode
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        if config.dataset == "empathetic":
            emb_mask = self.embedding(batch["mask_input"])
            encoder_outputs = self.encoder(
                self.embedding(enc_batch) + emb_mask, mask_src)
        else:
            encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)
        ## Attention over decoder
        q_h = torch.mean(encoder_outputs,
                         dim=1) if config.mean_query else encoder_outputs[:, 0]
        #q_h = encoder_outputs[:,0]
        logit_prob = self.decoder_key(q_h)  #(bsz, num_experts)

        if (config.topk > 0):
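            # keep only the top-k expert scores; the rest are set to -inf so the activation maps them to zero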
            k_max_value, k_max_index = torch.topk(logit_prob, config.topk)
            a = np.empty([logit_prob.shape[0], self.decoder_number])
            a.fill(float('-inf'))
            mask = torch.Tensor(a).cuda()
            logit_prob_ = mask.scatter_(1,
                                        k_max_index.cuda().long(), k_max_value)
            attention_parameters = self.attention_activation(logit_prob_)
        else:
            attention_parameters = self.attention_activation(logit_prob)
        # print("===============================================================================")
        # print("listener attention weight:",attention_parameters.data.cpu().numpy())
        # print("===============================================================================")
        if (config.oracle):
            attention_parameters = self.attention_activation(
                torch.FloatTensor(batch['target_program']) * 1000).cuda()
        attention_parameters = attention_parameters.unsqueeze(-1).unsqueeze(
            -1)  # (batch_size, expert_num, 1, 1)

        # Decode
        sos_token = torch.LongTensor([config.SOS_idx] *
                                     enc_batch.size(0)).unsqueeze(1)
        if config.USE_CUDA: sos_token = sos_token.cuda()
        dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1)

        mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)

        pre_logit, attn_dist = self.decoder(self.embedding(dec_batch_shift),
                                            encoder_outputs,
                                            (mask_src, mask_trg),
                                            attention_parameters)
        ## compute output dist
        logit = self.generator(
            pre_logit,
            attn_dist,
            enc_batch_extend_vocab if config.pointer_gen else None,
            extra_zeros,
            attn_dist_db=None)
        #logit = F.log_softmax(logit,dim=-1) #fix the name later
        ## loss: NLL if pointer_gen else cross entropy
        if (train and config.schedule > 10):
            if (random.uniform(0, 1) <=
                (0.0001 +
                 (1 - 0.0001) * math.exp(-1. * iter / config.schedule))):
                config.oracle = True
            else:
                config.oracle = False

        if config.softmax:
            loss = self.criterion(
                logit.contiguous().view(-1, logit.size(-1)),
                dec_batch.contiguous().view(-1)) + nn.CrossEntropyLoss()(
                    logit_prob, torch.LongTensor(
                        batch['program_label']).cuda())
            loss_bce_program = nn.CrossEntropyLoss()(
                logit_prob,
                torch.LongTensor(batch['program_label']).cuda()).item()
        else:
            loss = self.criterion(
                logit.contiguous().view(-1, logit.size(-1)),
                dec_batch.contiguous().view(-1)) + nn.BCEWithLogitsLoss()(
                    logit_prob, torch.FloatTensor(
                        batch['target_program']).cuda())
            loss_bce_program = nn.BCEWithLogitsLoss()(
                logit_prob,
                torch.FloatTensor(batch['target_program']).cuda()).item()
        pred_program = np.argmax(logit_prob.detach().cpu().numpy(), axis=1)
        program_acc = accuracy_score(batch["program_label"], pred_program)

        if (config.label_smoothing):
            loss_ppl = self.criterion_ppl(
                logit.contiguous().view(-1, logit.size(-1)),
                dec_batch.contiguous().view(-1)).item()

        if (train):
            loss.backward()
            self.optimizer.step()

        if (config.label_smoothing):
            return loss_ppl, math.exp(min(loss_ppl,
                                          100)), loss_bce_program, program_acc
        else:
            return loss.item(), math.exp(min(
                loss.item(), 100)), loss_bce_program, program_acc

    def compute_act_loss(self, module):
        R_t = module.remainders
        N_t = module.n_updates
        p_t = R_t + N_t
        avg_p_t = torch.sum(torch.sum(p_t, dim=1) / p_t.size(1)) / p_t.size(0)
        loss = config.act_loss_weight * avg_p_t.item()
        return loss

    def decoder_greedy(self, batch, max_dec_step=30):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        emb_mask = self.embedding(batch["mask_input"])
        encoder_outputs = self.encoder(
            self.embedding(enc_batch) + emb_mask, mask_src)
        ## Attention over decoder
        q_h = torch.mean(encoder_outputs,
                         dim=1) if config.mean_query else encoder_outputs[:, 0]
        #q_h = encoder_outputs[:,0]
        logit_prob = self.decoder_key(q_h)

        if (config.topk > 0):
            k_max_value, k_max_index = torch.topk(logit_prob, config.topk)
            a = np.empty([logit_prob.shape[0], self.decoder_number])
            a.fill(float('-inf'))
            mask = torch.Tensor(a).cuda()
            logit_prob = mask.scatter_(1,
                                       k_max_index.cuda().long(), k_max_value)

        attention_parameters = self.attention_activation(logit_prob)

        if (config.oracle):
            attention_parameters = self.attention_activation(
                torch.FloatTensor(batch['target_program']) * 1000).cuda()
        attention_parameters = attention_parameters.unsqueeze(-1).unsqueeze(
            -1)  # (batch_size, expert_num, 1, 1)

        ys = torch.ones(1, 1).fill_(config.SOS_idx).long()
        if config.USE_CUDA:
            ys = ys.cuda()
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        decoded_words = []
        for i in range(max_dec_step + 1):
            if (config.project):
                out, attn_dist = self.decoder(
                    self.embedding_proj_in(self.embedding(ys)),
                    self.embedding_proj_in(encoder_outputs),
                    (mask_src, mask_trg), attention_parameters)
            else:

                out, attn_dist = self.decoder(self.embedding(ys),
                                              encoder_outputs,
                                              (mask_src, mask_trg),
                                              attention_parameters)

            logit = self.generator(out,
                                   attn_dist,
                                   enc_batch_extend_vocab,
                                   extra_zeros,
                                   attn_dist_db=None)
            #logit = F.log_softmax(logit,dim=-1) #fix the name later
            _, next_word = torch.max(logit[:, -1], dim=1)
            decoded_words.append([
                '<EOS>' if ni.item() == config.EOS_idx else
                self.vocab.index2word[ni.item()] for ni in next_word.view(-1)
            ])
            next_word = next_word.data[0]
            if config.USE_CUDA:
                ys = torch.cat(
                    [ys, torch.ones(1, 1).long().fill_(next_word).cuda()],
                    dim=1)
                ys = ys.cuda()
            else:
                ys = torch.cat(
                    [ys, torch.ones(1, 1).long().fill_(next_word)], dim=1)
            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

        sent = []
        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>': break
                else: st += e + ' '
            sent.append(st)
        return sent

    def decoder_topk(self, batch, max_dec_step=30):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        emb_mask = self.embedding(batch["mask_input"])
        encoder_outputs = self.encoder(
            self.embedding(enc_batch) + emb_mask, mask_src)

        ## Attention over decoder
        q_h = torch.mean(encoder_outputs,
                         dim=1) if config.mean_query else encoder_outputs[:, 0]
        #q_h = encoder_outputs[:,0]
        logit_prob = self.decoder_key(q_h)

        if (config.topk > 0):
            k_max_value, k_max_index = torch.topk(logit_prob, config.topk)
            a = np.empty([logit_prob.shape[0], self.decoder_number])
            a.fill(float('-inf'))
            mask = torch.Tensor(a).cuda()
            logit_prob = mask.scatter_(1,
                                       k_max_index.cuda().long(), k_max_value)

        attention_parameters = self.attention_activation(logit_prob)

        if (config.oracle):
            attention_parameters = self.attention_activation(
                torch.FloatTensor(batch['target_program']) * 1000).cuda()
        attention_parameters = attention_parameters.unsqueeze(-1).unsqueeze(
            -1)  # (batch_size, expert_num, 1, 1)

        ys = torch.ones(1, 1).fill_(config.SOS_idx).long()
        if config.USE_CUDA:
            ys = ys.cuda()
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        decoded_words = []
        for i in range(max_dec_step + 1):
            if (config.project):
                out, attn_dist = self.decoder(
                    self.embedding_proj_in(self.embedding(ys)),
                    self.embedding_proj_in(encoder_outputs),
                    (mask_src, mask_trg), attention_parameters)
            else:

                out, attn_dist = self.decoder(self.embedding(ys),
                                              encoder_outputs,
                                              (mask_src, mask_trg),
                                              attention_parameters)

            logit = self.generator(out,
                                   attn_dist,
                                   enc_batch_extend_vocab,
                                   extra_zeros,
                                   attn_dist_db=None)
            filtered_logit = top_k_top_p_filtering(logit[:, -1],
                                                   top_k=3,
                                                   top_p=0,
                                                   filter_value=-float('Inf'))
            # Sample from the filtered distribution
            next_word = torch.multinomial(F.softmax(filtered_logit, dim=-1),
                                          1).squeeze()
            decoded_words.append([
                '<EOS>' if ni.item() == config.EOS_idx else
                self.vocab.index2word[ni.item()] for ni in next_word.view(-1)
            ])
            next_word = next_word.data[0]
            if config.USE_CUDA:
                ys = torch.cat(
                    [ys, torch.ones(1, 1).long().fill_(next_word).cuda()],
                    dim=1)
                ys = ys.cuda()
            else:
                ys = torch.cat(
                    [ys, torch.ones(1, 1).long().fill_(next_word)], dim=1)
            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

        sent = []
        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>': break
                else: st += e + ' '
            sent.append(st)
        return sent
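
The expert-selection logic above (decoder_key scores, top-k masking with -inf, then softmax or sigmoid) yields a sparse mixture over the expert decoders. A standalone sketch of that gating step, with illustrative shapes:

import torch

# Hedged sketch of the top-k expert gating used by Transformer_experts.
# Shapes are illustrative: batch of 4, 10 experts, keep the top 2.
logit_prob = torch.randn(4, 10)                        # stands in for decoder_key(q_h)
k_max_value, k_max_index = torch.topk(logit_prob, 2)
mask = torch.full_like(logit_prob, float('-inf'))
sparse_logits = mask.scatter_(1, k_max_index, k_max_value)
attention_parameters = torch.softmax(sparse_logits, dim=1)
# Each row now has exactly two non-zero weights summing to 1; in the model these
# are reshaped to (batch, experts, 1, 1) and used to weight the expert decoders.
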
Example #3
class PGNet(nn.Module):
    '''
    refer: https://github.com/atulkum/pointer_summarizer
    '''
    def __init__(self,
                 vocab,
                 model_file_path=None,
                 is_eval=False,
                 load_optim=False):
        super(PGNet, self).__init__()
        self.vocab = vocab
        self.vocab_size = vocab.n_words

        self.embedding = share_embedding(self.vocab, config.pretrain_emb)
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.reduce_state = ReduceState()

        self.generator = Generator(config.rnn_hidden_dim, self.vocab_size)

        if config.weight_sharing:
            # Share the weight matrix between target word embedding & the final logit dense layer
            self.generator.proj.weight = self.embedding.lut.weight

        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)
        if config.label_smoothing:
            self.criterion = LabelSmoothing(size=self.vocab_size,
                                            padding_idx=config.PAD_idx,
                                            smoothing=0.1)
            self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx)

        self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
        if config.noam:
            self.optimizer = NoamOpt(
                config.rnn_hidden_dim, 1, 8000,
                torch.optim.Adam(self.parameters(),
                                 lr=0,
                                 betas=(0.9, 0.98),
                                 eps=1e-9))

        if model_file_path is not None:
            print("loading weights")
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            self.encoder.load_state_dict(state['encoder_state_dict'])
            self.decoder.load_state_dict(state['decoder_state_dict'])
            self.generator.load_state_dict(state['generator_dict'])
            self.embedding.load_state_dict(state['embedding_dict'])
            if load_optim:
                self.optimizer.load_state_dict(state['optimizer'])
            self.eval()

        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""

    def save_model(self, running_avg_ppl, iter, f1_g, f1_b, ent_g, ent_b):

        state = {
            'iter': iter,
            'encoder_state_dict': self.encoder.state_dict(),
            'decoder_state_dict': self.decoder.state_dict(),
            'generator_dict': self.generator.state_dict(),
            'embedding_dict': self.embedding.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_ppl
        }
        model_save_path = os.path.join(
            self.model_dir,
            'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(
                iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b))
        self.best_path = model_save_path
        torch.save(state, model_save_path)

    def train_one_batch(self, batch, iter, train=True):
        enc_batch = batch["review_batch"]
        enc_lens = batch["review_length"]
        enc_batch_extend_vocab = batch["review_ext_batch"]
        oovs = batch["oovs"]
        max_oov_length = len(
            sorted(oovs, key=lambda i: len(i), reverse=True)[0])

        dec_batch = batch["tags_batch"]
        dec_ext_batch = batch["tags_ext_batch"]
        max_tgt_len = dec_batch.size(0)

        if config.noam:
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        ## Embedding - context
        mask_src = enc_batch
        mask_src = ~(mask_src.data.eq(config.PAD_idx))
        # emb_mask = self.embedding(batch["mask_context"])
        # src_emb = self.embedding(enc_batch)+emb_mask
        src_emb = self.embedding(enc_batch)
        encoder_outputs, encoder_feature, encoder_hidden = self.encoder(
            src_emb, enc_lens)

        # reduce bidirectional hidden to one hidden (h and c)
        s_t_1 = self.reduce_state(encoder_hidden)  # 1 x b x hidden_dim
        c_t_1 = Variable(
            torch.zeros((enc_batch.size(0),
                         2 * config.rnn_hidden_dim))).to(config.device)
        coverage = Variable(torch.zeros(enc_batch.size())).to(config.device)
        extra_zeros = Variable(torch.zeros(
            (enc_batch.size(0), max_oov_length))).to(config.device)

        sos_token = torch.LongTensor([config.SOS_idx] *
                                     enc_batch.size(0)).unsqueeze(1).to(
                                         config.device)  # (bsz, 1)
        dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]),
                                    1)  # (bsz, tgt_len)
        dec_batch_embd = self.embedding(dec_batch_shift)

        step_losses = []
        step_loss_ppls = 0
        for di in range(max_tgt_len):
            y_t_1 = dec_batch_embd[:, di, :]
            logit, s_t_1, c_t_1, attn_dist, next_coverage, p_gen = self.decoder(
                y_t_1, s_t_1, encoder_outputs, encoder_feature, mask_src,
                c_t_1, coverage, di)

            logit = self.generator(
                logit, attn_dist,
                enc_batch_extend_vocab if config.pointer_gen else None,
                extra_zeros, 1, p_gen)

            if config.pointer_gen:
                step_loss = self.criterion(
                    logit.contiguous().view(-1, logit.size(-1)),
                    dec_ext_batch[:, di].contiguous().view(-1))
            else:
                step_loss = self.criterion(
                    logit.contiguous().view(-1, logit.size(-1)),
                    dec_batch[:, di].contiguous().view(-1))

            if config.label_smoothing:
                step_loss_ppl = self.criterion_ppl(
                    logit.contiguous().view(-1, logit.size(-1)),
                    dec_batch[:, di].contiguous().view(-1))
                step_loss_ppls += step_loss_ppl

            if config.is_coverage:
                # coverage loss
                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage),
                                               1)
                # loss sum
                step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
                # update coverage
                coverage = next_coverage

            step_losses.append(step_loss)
        if config.is_coverage:
            sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
            batch_avg_loss = sum_losses / batch['tags_length'].float()
            loss = torch.mean(batch_avg_loss)
        else:
            loss = sum(step_losses) / max_tgt_len
        if config.label_smoothing:
            loss_ppl = (step_loss_ppls / max_tgt_len).item()

        if train:
            loss.backward()
            self.optimizer.step()

        if config.label_smoothing:
            return loss_ppl, math.exp(min(loss_ppl, 100)), 0, 0
        else:
            return loss.item(), math.exp(min(loss.item(), 100)), 0, 0

    def decoder_greedy(self, batch, max_dec_step=30):
        enc_batch = batch["review_batch"]
        enc_lens = batch["review_length"]
        enc_batch_extend_vocab = batch["review_ext_batch"]
        oovs = batch["oovs"]
        max_oov_length = len(
            sorted(oovs, key=lambda i: len(i), reverse=True)[0])

        dec_batch = batch["tags_batch"]
        dec_ext_batch = batch["tags_ext_batch"]
        max_tgt_len = dec_batch.size(0)

        ## Embedding - context
        mask_src = enc_batch
        mask_src = ~(mask_src.data.eq(config.PAD_idx))
        # emb_mask = self.embedding(batch["mask_context"])
        # src_emb = self.embedding(enc_batch)+emb_mask
        src_emb = self.embedding(enc_batch)
        encoder_outputs, encoder_feature, encoder_hidden = self.encoder(
            src_emb, enc_lens)

        # reduce bidirectional hidden to one hidden (h and c)
        s_t_1 = self.reduce_state(encoder_hidden)  # 1 x b x hidden_dim
        c_t_1 = Variable(
            torch.zeros((enc_batch.size(0),
                         2 * config.rnn_hidden_dim))).to(config.device)
        coverage = Variable(torch.zeros(enc_batch.size())).to(config.device)
        extra_zeros = Variable(torch.zeros(
            (enc_batch.size(0), max_oov_length))).to(config.device)

        # ys = torch.ones(1, 1).fill_(config.SOS_idx).long()
        ys = torch.zeros(enc_batch.size(0)).fill_(config.SOS_idx).long().to(
            config.device)  # when testing, we set bsz to 1
        decoded_words = []
        for i in range(max_dec_step + 1):
            logit, s_t_1, c_t_1, attn_dist, next_coverage, p_gen = self.decoder(
                self.embedding(ys), s_t_1, encoder_outputs, encoder_feature,
                mask_src, c_t_1, coverage, i)
            prob = self.generator(
                logit, attn_dist,
                enc_batch_extend_vocab if config.pointer_gen else None,
                extra_zeros, 1, p_gen)

            _, next_word = torch.max(prob, dim=1)  # bsz=1
            cur_words = []
            for i_batch, ni in enumerate(next_word.view(-1)):
                if ni.item() == config.EOS_idx:
                    cur_words.append('<EOS>')
                elif ni.item() in self.vocab.index2word:
                    cur_words.append(self.vocab.index2word[ni.item()])
                else:
                    cur_words.append(oovs[i_batch][ni.item() -
                                                   self.vocab.n_words])
            decoded_words.append(cur_words)
            next_word = next_word.data[0]

            if next_word.item() not in self.vocab.index2word:
                next_word = torch.tensor(config.UNK_idx)

            ys = torch.zeros(enc_batch.size(0)).long().fill_(next_word).to(
                config.device)

        sent = []
        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>':
                    break
                else:
                    st += e + ' '
            sent.append(st)
        return sent
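
When config.pointer_gen is set, the Generator mixes the decoder's vocabulary distribution with the copy (attention) distribution over an extended vocabulary, in the spirit of See et al.'s pointer-generator network. A minimal sketch of that mixture (names and shapes are illustrative; the repo's actual Generator may differ in detail and presumably returns log-probabilities, which is why nn.NLLLoss is used as the criterion):

import torch

# Hedged sketch of a pointer-generator output distribution.
# vocab_dist: (bsz, vocab), attn_dist: (bsz, src_len),
# src_ext_ids: (bsz, src_len) ids in the extended vocab, p_gen: (bsz, 1).
def pointer_generator_dist(vocab_dist, attn_dist, src_ext_ids, extra_zeros, p_gen):
    gen_dist = p_gen * vocab_dist                         # probability mass for generating from the vocab
    copy_dist = (1.0 - p_gen) * attn_dist                 # probability mass for copying from the source
    extended = torch.cat([gen_dist, extra_zeros], dim=1)  # extra slots for in-batch OOV tokens
    return extended.scatter_add(1, src_ext_ids, copy_dist)
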
Example #4
File: AOTNet.py Project: qtli/AOT
class AOT(nn.Module):
    def __init__(self,
                 vocab,
                 model_file_path=None,
                 is_eval=False,
                 load_optim=False):
        super(AOT, self).__init__()
        self.vocab = vocab
        self.vocab_size = vocab.n_words

        self.embedding = share_embedding(self.vocab, config.pretrain_emb)
        self.encoder = Encoder(config.emb_dim,
                               config.hidden_dim,
                               num_layers=config.hop,
                               num_heads=config.heads,
                               total_key_depth=config.depth,
                               total_value_depth=config.depth,
                               filter_size=config.filter,
                               universal=config.universal)
        self.sse = SSE(vocab, config.emb_dim, config.dropout,
                       config.rnn_hidden_dim)
        self.rcr = RCR()

        ## multiple decoders
        self.decoder = Decoder(config.emb_dim,
                               hidden_size=config.hidden_dim,
                               num_layers=config.hop,
                               num_heads=config.heads,
                               total_key_depth=config.depth,
                               total_value_depth=config.depth,
                               filter_size=config.filter)

        self.generator = Generator(config.hidden_dim, self.vocab_size)

        if config.weight_sharing:
            # Share the weight matrix between target word embedding & the final logit dense layer
            self.generator.proj.weight = self.embedding.lut.weight

        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)
        if config.label_smoothing:
            self.criterion = LabelSmoothing(size=self.vocab_size,
                                            padding_idx=config.PAD_idx,
                                            smoothing=0.1)
            self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx)

        self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
        if config.noam:
            self.optimizer = NoamOpt(
                config.hidden_dim, 1, 8000,
                torch.optim.Adam(self.parameters(),
                                 lr=0,
                                 betas=(0.9, 0.98),
                                 eps=1e-9))

        if model_file_path is not None:
            print("loading weights")
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            self.encoder.load_state_dict(state['encoder_state_dict'])
            self.decoder.load_state_dict(state['decoder_state_dict'])
            self.generator.load_state_dict(state['generator_dict'])
            self.embedding.load_state_dict(state['embedding_dict'])
            if load_optim:
                self.optimizer.load_state_dict(state['optimizer'])
            self.eval()

        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""

    def save_model(self, running_avg_ppl, iter, f1_g, f1_b, ent_g, ent_b):

        state = {
            'iter': iter,
            'encoder_state_dict': self.encoder.state_dict(),
            'decoder_state_dict': self.decoder.state_dict(),
            'generator_dict': self.generator.state_dict(),
            'embedding_dict': self.embedding.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_ppl
        }
        model_save_path = os.path.join(
            self.model_dir,
            'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(
                iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b))
        self.best_path = model_save_path
        torch.save(state, model_save_path)

    def train_one_batch_slow(self, batch, iter, train=True):
        enc_batch = batch["review_batch"]
        enc_batch_extend_vocab = batch["review_ext_batch"]
        src_batch = batch[
            'reviews_batch']  # reviews sequence (bsz, r_num, r_len)
        src_mask = batch[
            'reviews_mask']  # marks which reviews are padding (fake entries).  (bsz, r_num)
        src_length = batch['reviews_length']  # (bsz, r_num)
        enc_length_batch = batch[
            'reviews_length_list']  # 2-dim list, 0: len=bsz, 1: lens of reviews and pads
        src_labels = batch['reviews_label']  # (bsz, r_num)
        oovs = batch["oovs"]
        max_oov_length = len(
            sorted(oovs, key=lambda i: len(i), reverse=True)[0])
        extra_zeros = Variable(torch.zeros(
            (enc_batch.size(0), max_oov_length))).to(config.device)

        dec_batch = batch["tags_batch"]
        dec_ext_batch = batch["tags_ext_batch"]
        dec_rank_batch = batch[
            'tags_idx_batch']  # tag indexes sequence (bsz, tgt_len)

        if config.noam:
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        # 1. Sentence-level Salience Estimation (SSE)
        cla_loss, sa_scores, sa_acc = self.sse.salience_estimate(
            src_batch, src_mask, src_length,
            src_labels)  # sa_scores: (bsz, r_num)

        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(
            1)  # (bsz, src_len)->(bsz, 1, src_len)
        # emb_mask = self.embedding(batch["mask_context"])
        # src_emb = self.embedding(enc_batch)+emb_mask
        src_emb = self.embedding(enc_batch)
        encoder_outputs = self.encoder(src_emb,
                                       mask_src)  # (bsz, src_len, emb_dim)

        src_enc_rank = torch.FloatTensor([]).to(
            config.device)  # (bsz, src_len, emb_dim)
        src_ext_rank = torch.LongTensor([]).to(config.device)  # (bsz, src_len)
        aln_rank = torch.LongTensor([]).to(
            config.device)  # (bsz, tgt_len, src_len)
        aln_mask_rank = torch.FloatTensor([]).to(
            config.device)  # (bsz, tgt_len, src_len)

        bsz, max_src_len = enc_batch.size()
        for idx in range(bsz):  # Clustering (by k-means) and Ranking
            item_length = enc_length_batch[idx]
            reviews = torch.split(encoder_outputs[idx], item_length, dim=0)
            reviews_ext = torch.split(enc_batch_extend_vocab[idx],
                                      item_length,
                                      dim=0)

            r_vectors = []  # store the vector repr of each review
            rs_vectors = []  # store the token vectors repr of each review
            r_exts = []
            r_pad_vec, r_ext_pad = None, None
            for r_idx in range(len(item_length)):
                if r_idx == len(item_length) - 1:
                    r_pad_vec = reviews[r_idx]
                    r_ext_pad = reviews_ext[r_idx]
                    break
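                # pool the review's token vectors into one review vector, then weight it by the review's salience score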
                r = self.rcr.hierarchical_pooling(reviews[r_idx].unsqueeze(
                    0)).squeeze(0).detach().cpu().numpy() * sa_scores[idx,
                                                                      r_idx]
                r_vectors.append(r)
                rs_vectors.append(reviews[r_idx])
                r_exts.append(reviews_ext[r_idx])

            rs_repr, ext_repr, srctgt_aln_mask, srctgt_aln = \
                self.rcr.perform(r_vectors, rs_vectors, r_exts, r_pad_vec, r_ext_pad, dec_rank_batch[idx], max_src_len)
            # rs_repr: (max_rs_length, embed_dim); ext_repr: (max_rs_length); srctgt_aln_mask/srctgt_aln: (tgt_len, max_rs_length)

            src_enc_rank = torch.cat((src_enc_rank, rs_repr.unsqueeze(0)),
                                     dim=0)  # (1->bsz, max_src_len, embed_dim)
            src_ext_rank = torch.cat((src_ext_rank, ext_repr.unsqueeze(0)),
                                     dim=0)  # (1->bsz, max_src_len)
            aln_rank = torch.cat((aln_rank, srctgt_aln.unsqueeze(0)),
                                 dim=0)  # (1->bsz, max_tgt_len, max_src_len)
            aln_mask_rank = torch.cat(
                (aln_mask_rank, srctgt_aln_mask.unsqueeze(0)), dim=0)

        del encoder_outputs, reviews, reviews_ext, r_vectors, rs_vectors, r_exts, r_pad_vec, r_ext_pad, rs_repr, ext_repr, srctgt_aln_mask, srctgt_aln
        torch.cuda.empty_cache()
        torch.backends.cuda.cufft_plan_cache.clear()

        ys = torch.LongTensor([config.SOS_idx] *
                              enc_batch.size(0)).unsqueeze(1).to(
                                  config.device)  # (bsz, 1)
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        ys_rank = torch.LongTensor([1] * enc_batch.size(0)).unsqueeze(1).to(
            config.device)

        max_tgt_len = dec_batch.size(1)
        loss, loss_ppl = 0, 0
        for t in range(max_tgt_len):
            aln_rank_cur = aln_rank[:, t, :].unsqueeze(1)  # (bsz, 1, src_len)
            aln_mask_cur = aln_mask_rank[:, :(t + 1), :]  # (bsz, t+1, src_len)
            pre_logit, attn_dist, aln_loss_cur = self.decoder(
                inputs=self.embedding(ys),
                inputs_rank=ys_rank,
                encoder_output=src_enc_rank,
                aln_rank=aln_rank_cur,
                aln_mask_rank=aln_mask_cur,
                mask=(mask_src, mask_trg),
                speed='slow')
            # todo
            if iter >= 13000:
                loss += (0.1 * aln_loss_cur)
            else:
                loss += aln_loss_cur
            logit = self.generator(
                pre_logit, attn_dist.unsqueeze(1),
                enc_batch_extend_vocab if config.pointer_gen else None,
                extra_zeros)

            if config.pointer_gen:
                loss += self.criterion(
                    logit[:, -1, :].contiguous().view(-1, logit.size(-1)),
                    dec_ext_batch[:, t].contiguous().view(-1))
            else:
                loss += self.criterion(
                    logit[:, -1, :].contiguous().view(-1, logit.size(-1)),
                    dec_batch[:, t].contiguous().view(-1))

            if config.label_smoothing:
                loss_ppl += self.criterion_ppl(
                    logit[:, -1, :].contiguous().view(-1, logit.size(-1)),
                    dec_ext_batch[:, t].contiguous().view(-1) if
                    config.pointer_gen else dec_batch[:,
                                                      t].contiguous().view(-1))

            ys = torch.cat((ys, dec_batch[:, t].unsqueeze(1)), dim=1)
            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
            ys_rank = torch.cat((ys_rank, dec_rank_batch[:, t].unsqueeze(1)),
                                dim=1)

        loss = loss + cla_loss
        if train:
            loss /= max_tgt_len
            loss.backward()
            self.optimizer.step()

        if config.label_smoothing:
            loss_ppl /= max_tgt_len
            if torch.isnan(loss_ppl).sum().item() != 0 or torch.isinf(
                    loss_ppl).sum().item() != 0:
                print("check")
                pdb.set_trace()
            return loss_ppl.item(), math.exp(min(loss_ppl.item(),
                                                 100)), cla_loss.item(), sa_acc
        else:
            return loss.item(), math.exp(min(loss.item(),
                                             100)), cla_loss.item(), sa_acc

    def train_one_batch(self, batch, iter, train=True):
        enc_batch = batch["review_batch"]
        enc_batch_extend_vocab = batch["review_ext_batch"]
        src_batch = batch[
            'reviews_batch']  # reviews sequence (bsz, r_num, r_len)
        src_mask = batch[
            'reviews_mask']  # marks which reviews are padding (fake entries).  (bsz, r_num)
        src_length = batch['reviews_length']  # (bsz, r_num)
        enc_length_batch = batch[
            'reviews_length_list']  # 2-dim list, 0: len=bsz, 1: lens of reviews and pads
        src_labels = batch['reviews_label']  # (bsz, r_num)
        oovs = batch["oovs"]
        max_oov_length = len(
            sorted(oovs, key=lambda i: len(i), reverse=True)[0])
        extra_zeros = Variable(torch.zeros(
            (enc_batch.size(0), max_oov_length))).to(config.device)

        dec_batch = batch["tags_batch"]
        dec_ext_batch = batch["tags_ext_batch"]
        dec_rank_batch = batch[
            'tags_idx_batch']  # tag indexes sequence (bsz, tgt_len)

        if config.noam:
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        # 1. Sentence-level Salience Estimation (SSE)
        cla_loss, sa_scores, sa_acc = self.sse.salience_estimate(
            src_batch, src_mask, src_length,
            src_labels)  # sa_scores: (bsz, r_num)

        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(
            1)  # (bsz, src_len)->(bsz, 1, src_len)
        # emb_mask = self.embedding(batch["mask_context"])
        # src_emb = self.embedding(enc_batch)+emb_mask
        src_emb = self.embedding(enc_batch)
        encoder_outputs = self.encoder(src_emb,
                                       mask_src)  # (bsz, src_len, emb_dim)

        src_enc_rank = torch.FloatTensor([]).to(
            config.device)  # (bsz, src_len, emb_dim)
        src_ext_rank = torch.LongTensor([]).to(config.device)  # (bsz, src_len)
        aln_rank = torch.LongTensor([]).to(
            config.device)  # (bsz, tgt_len, src_len)
        aln_mask_rank = torch.FloatTensor([]).to(
            config.device)  # (bsz, tgt_len, src_len)

        bsz, max_src_len = enc_batch.size()
        for idx in range(bsz):  # Clustering (by k-means) and Ranking
            item_length = enc_length_batch[idx]
            reviews = torch.split(encoder_outputs[idx], item_length, dim=0)
            reviews_ext = torch.split(enc_batch_extend_vocab[idx],
                                      item_length,
                                      dim=0)

            r_vectors = []  # store the vector repr of each review
            rs_vectors = []  # store the token vectors repr of each review
            r_exts = []
            r_pad_vec, r_ext_pad = None, None
            for r_idx in range(len(item_length)):
                if r_idx == len(item_length) - 1:
                    r_pad_vec = reviews[r_idx]
                    r_ext_pad = reviews_ext[r_idx]
                    break
                r = self.rcr.hierarchical_pooling(reviews[r_idx].unsqueeze(
                    0)).squeeze(0).detach().cpu().numpy() * sa_scores[idx,
                                                                      r_idx]
                r_vectors.append(r)
                rs_vectors.append(reviews[r_idx])
                r_exts.append(reviews_ext[r_idx])

            rs_repr, ext_repr, srctgt_aln_mask, srctgt_aln = \
                self.rcr.perform(r_vectors, rs_vectors, r_exts, r_pad_vec, r_ext_pad, dec_rank_batch[idx], max_src_len)
            # rs_repr: (max_rs_length, embed_dim); ext_repr: (max_rs_length); srctgt_aln_mask/srctgt_aln: (tgt_len, max_rs_length)

            src_enc_rank = torch.cat((src_enc_rank, rs_repr.unsqueeze(0)),
                                     dim=0)  # (1->bsz, max_src_len, embed_dim)
            src_ext_rank = torch.cat((src_ext_rank, ext_repr.unsqueeze(0)),
                                     dim=0)  # (1->bsz, max_src_len)
            aln_rank = torch.cat((aln_rank, srctgt_aln.unsqueeze(0)),
                                 dim=0)  # (1->bsz, max_tgt_len, max_src_len)
            aln_mask_rank = torch.cat(
                (aln_mask_rank, srctgt_aln_mask.unsqueeze(0)), dim=0)

        del encoder_outputs, reviews, reviews_ext, r_vectors, rs_vectors, r_exts, r_pad_vec, r_ext_pad, rs_repr, ext_repr, srctgt_aln_mask, srctgt_aln
        torch.cuda.empty_cache()
        torch.backends.cuda.cufft_plan_cache.clear()

        sos_token = torch.LongTensor([config.SOS_idx] *
                                     enc_batch.size(0)).unsqueeze(1).to(
                                         config.device)  # (bsz, 1)
        dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]),
                                    1)  # (bsz, tgt_len)
        mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)

        sos_rank = torch.LongTensor([1] * enc_batch.size(0)).unsqueeze(1).to(
            config.device)
        dec_rank_batch = torch.cat((sos_rank, dec_rank_batch[:, :-1]), 1)

        aln_rank = aln_rank[:, :-1, :]
        aln_mask_rank = aln_mask_rank[:, :-1, :]

        pre_logit, attn_dist, aln_loss = self.decoder(
            inputs=self.embedding(dec_batch_shift),
            inputs_rank=dec_rank_batch,
            encoder_output=src_enc_rank,
            aln_rank=aln_rank,
            aln_mask_rank=aln_mask_rank,
            mask=(mask_src, mask_trg))
        logit = self.generator(
            pre_logit, attn_dist,
            enc_batch_extend_vocab if config.pointer_gen else None,
            extra_zeros)

        if config.pointer_gen:
            loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                                  dec_ext_batch.contiguous().view(-1))
        else:
            loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                                  dec_batch.contiguous().view(-1))

        if config.label_smoothing:
            loss_ppl = self.criterion_ppl(
                logit.contiguous().view(-1, logit.size(-1)),
                dec_ext_batch.contiguous().view(-1)
                if config.pointer_gen else dec_batch.contiguous().view(-1))

        if train:
            if iter >= 13000:
                loss = loss + (0.1 * aln_loss) + cla_loss
            else:
                loss = loss + aln_loss + cla_loss
            loss.backward()
            self.optimizer.step()

        if config.label_smoothing:
            if torch.isnan(loss_ppl).sum().item() != 0 or torch.isinf(
                    loss_ppl).sum().item() != 0:
                print("check")
                pdb.set_trace()
            return loss_ppl.item(), math.exp(min(loss_ppl.item(),
                                                 100)), cla_loss.item(), sa_acc
        else:
            return loss.item(), math.exp(min(loss.item(),
                                             100)), cla_loss.item(), sa_acc

    def compute_act_loss(self, module):
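        # ACT ponder cost: R_t holds the halting remainders and N_t the number of
        # recurrent updates per position; their sum p_t is averaged over positions
        # and over the batch, then scaled by config.act_loss_weight.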
        R_t = module.remainders
        N_t = module.n_updates
        p_t = R_t + N_t
        avg_p_t = torch.sum(torch.sum(p_t, dim=1) / p_t.size(1)) / p_t.size(0)
        loss = config.act_loss_weight * avg_p_t.item()
        return loss

    def decoder_greedy(self, batch, max_dec_step=30):
        enc_batch = batch["review_batch"]
        enc_batch_extend_vocab = batch["review_ext_batch"]
        src_batch = batch[
            'reviews_batch']  # reviews sequence (bsz, r_num, r_len)
        src_mask = batch[
            'reviews_mask']  # indicates which reviews are padding placeholders.  (bsz, r_num)
        src_length = batch['reviews_length']  # (bsz, r_num)
        enc_length_batch = batch[
            'reviews_length_list']  # 2-dim list, 0: len=bsz, 1: lens of reviews and pads
        src_labels = batch['reviews_label']  # (bsz, r_num)
        oovs = batch["oovs"]
        max_oov_length = max(len(oov) for oov in oovs)
        extra_zeros = Variable(torch.zeros(
            (enc_batch.size(0), max_oov_length))).to(config.device)

        # 1. Sentence-level Salience Estimation (SSE)
        cla_loss, sa_scores, sa_acc = self.sse.salience_estimate(
            src_batch, src_mask, src_length,
            src_labels)  # sa_scores: (bsz, r_num)

        ## Encode - context
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(
            1)  # (bsz, src_len)->(bsz, 1, src_len)
        # emb_mask = self.embedding(batch["mask_context"])
        # src_emb = self.embedding(enc_batch) + emb_mask  # todo eos or sentence embedding??
        src_emb = self.embedding(enc_batch)
        encoder_outputs = self.encoder(src_emb,
                                       mask_src)  # (bsz, src_len, emb_dim)

        src_enc_rank = torch.FloatTensor([]).to(
            config.device)  # (bsz, src_len, emb_dim)
        src_ext_rank = torch.LongTensor([]).to(config.device)  # (bsz, src_len)
        aln_rank = torch.LongTensor([]).to(
            config.device)  # (bsz, tgt_len, src_len)
        aln_mask_rank = torch.FloatTensor([]).to(
            config.device)  # (bsz, tgt_len, src_len)

        bsz, max_src_len = enc_batch.size()
        for idx in range(bsz):  # Clustering (by k-means) and Ranking
            item_length = enc_length_batch[idx]
            reviews = torch.split(encoder_outputs[idx], item_length, dim=0)
            reviews_ext = torch.split(enc_batch_extend_vocab[idx],
                                      item_length,
                                      dim=0)

            r_vectors = []  # store the vector repr of each review
            rs_vectors = []  # store the token vectors repr of each review
            r_exts = []
            r_pad_vec, r_ext_pad = None, None
            for r_idx in range(len(item_length)):
                if r_idx == len(item_length) - 1:
                    r_pad_vec = reviews[r_idx]
                    r_ext_pad = reviews_ext[r_idx]
                    break
                r = self.rcr.hierarchical_pooling(reviews[r_idx].unsqueeze(
                    0)).squeeze(0).detach().cpu().numpy() * sa_scores[idx,
                                                                      r_idx]
                r_vectors.append(r)
                rs_vectors.append(reviews[r_idx])
                r_exts.append(reviews_ext[r_idx])

            rs_repr, ext_repr, srctgt_aln_mask, srctgt_aln = self.rcr.perform(
                r_vecs=r_vectors,
                rs_vecs=rs_vectors,
                r_exts=r_exts,
                r_pad_vec=r_pad_vec,
                r_ext_pad=r_ext_pad,
                max_rs_length=max_src_len,
                train=False)
            # rs_repr: (max_rs_length, embed_dim); ext_repr: (max_rs_length); srctgt_aln_mask/srctgt_aln: (tgt_len, max_rs_length)

            src_enc_rank = torch.cat((src_enc_rank, rs_repr.unsqueeze(0)),
                                     dim=0)  # (1->bsz, max_src_len, embed_dim)
            src_ext_rank = torch.cat((src_ext_rank, ext_repr.unsqueeze(0)),
                                     dim=0)  # (1->bsz, max_src_len)
            aln_rank = torch.cat((aln_rank, srctgt_aln.unsqueeze(0)),
                                 dim=0)  # (1->bsz, max_tgt_len, max_src_len)
            aln_mask_rank = torch.cat(
                (aln_mask_rank, srctgt_aln_mask.unsqueeze(0)), dim=0)

        del encoder_outputs, reviews, reviews_ext, r_vectors, rs_vectors, r_exts, r_pad_vec, r_ext_pad, rs_repr, ext_repr, srctgt_aln_mask, srctgt_aln
        torch.cuda.empty_cache()
        torch.backends.cuda.cufft_plan_cache.clear()

        # ys = torch.ones(1, 1).fill_(config.SOS_idx).long()
        ys = torch.zeros(enc_batch.size(0), 1).fill_(config.SOS_idx).long().to(
            config.device)  # at test time, bsz is set to 1
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        ys_rank = torch.ones(enc_batch.size(0), 1).long().to(config.device)
        last_rank = torch.ones(enc_batch.size(0), 1).long().to(config.device)

        pred_attn_dist = torch.FloatTensor([]).to(config.device)
        decoded_words = []
        for i in range(max_dec_step + 1):
            aln_rank_cur = aln_rank[:, last_rank.item(), :].unsqueeze(
                1)  # (bsz, 1, src_len); .item() assumes bsz == 1 at test time
            if config.project:
                out, attn_dist, _ = self.decoder(
                    inputs=self.embedding_proj_in(self.embedding(ys)),
                    inputs_rank=ys_rank,
                    encoder_output=self.embedding_proj_in(src_enc_rank),
                    aln_rank=aln_rank_cur,
                    aln_mask_rank=aln_mask_rank,  # nouse
                    mask=(mask_src, mask_trg),
                    speed='slow')
            else:
                out, attn_dist, _ = self.decoder(inputs=self.embedding(ys),
                                                 inputs_rank=ys_rank,
                                                 encoder_output=src_enc_rank,
                                                 aln_rank=aln_rank_cur,
                                                 aln_mask_rank=aln_mask_rank,
                                                 mask=(mask_src, mask_trg),
                                                 speed='slow')
            prob = self.generator(out, attn_dist, enc_batch_extend_vocab,
                                  extra_zeros)
            _, next_word = torch.max(prob[:, -1], dim=1)  # bsz=1, if test

            cur_words = []
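            # per-example rank pointer: a generated SOS advances last_rank to the
            # next ranked segment, EOS resets it to 0; last_rank selects the
            # alignment row used at the next decoding step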
            for i_batch, ni in enumerate(next_word.view(-1)):
                if ni.item() == config.EOS_idx:
                    cur_words.append('<EOS>')
                    last_rank[i_batch] = 0
                elif ni.item() in self.vocab.index2word:
                    cur_words.append(self.vocab.index2word[ni.item()])
                    if ni.item() == config.SOS_idx:
                        last_rank[i_batch] += 1
                else:
                    cur_words.append(oovs[i_batch][
                        ni.item() -
                        self.vocab.n_words])  # output non-dict word
                    next_word[i_batch] = config.UNK_idx  # input unk word
            decoded_words.append(cur_words)
            # next_word = next_word.data[0]
            # if next_word.item() not in self.vocab.index2word:
            #     next_word = torch.tensor(config.UNK_idx)

            # if config.USE_CUDA:
            ys = torch.cat([ys, next_word.unsqueeze(1)],
                           dim=1).to(config.device)
            ys_rank = torch.cat([ys_rank, last_rank], dim=1).to(config.device)
            # else:
            #     ys = torch.cat([ys, next_word],dim=1)
            #     ys_rank = torch.cat([ys_rank, last_rank],dim=1)

            # if config.USE_CUDA:
            #     ys = torch.cat([ys, torch.zeros(enc_batch.size(0), 1).long().fill_(next_word).cuda()], dim=1)
            #     ys = ys.cuda()
            # else:
            #     ys = torch.cat([ys, torch.zeros(enc_batch.size(0), 1).long().fill_(next_word)], dim=1)
            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

            if config.attn_analysis:
                pred_attn_dist = torch.cat(
                    (pred_attn_dist, attn_dist.unsqueeze(1)), dim=1)

        sent = []
        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>':
                    break
                else:
                    st += e + ' '
            sent.append(st)

        if config.attn_analysis:
            bsz, tgt_len, src_len = aln_mask_rank.size()
            pred_attn = pred_attn_dist[:, 1:, :].view(bsz * tgt_len, src_len)
            tgt_attn = aln_mask_rank.view(bsz * tgt_len, src_len)
            good_attn_sum = torch.masked_select(
                pred_attn,
                tgt_attn.bool()).sum()  # pred_attn: bsz * tgt_len, src_len
            bad_attn_sum = torch.masked_select(pred_attn,
                                               ~tgt_attn.bool()).sum()
            bad_num = ~tgt_attn.bool()
            ratio = bad_num.sum() / tgt_attn.bool().sum()
            bad_attn_sum /= ratio

            # good_attn_sum / bad_attn_sum are 0-dim tensors aggregated over the
            # whole decoded sequence, so use them directly and normalize per step.
            good_attn = good_attn_sum / (tgt_len * bsz)
            bad_attn = bad_attn_sum / (tgt_len * bsz)

            return sent, [good_attn, bad_attn]
        else:
            return sent
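
A minimal sketch of how decoder_greedy above might be driven at evaluation time when config.attn_analysis is enabled; the model and val_batches names are placeholders, not part of the example:

model.eval()
good_total, bad_total, hyps = 0.0, 0.0, []
with torch.no_grad():
    for batch in val_batches:
        sents, (good_attn, bad_attn) = model.decoder_greedy(batch, max_dec_step=30)
        hyps.extend(sents)
        good_total += good_attn.item()  # attention mass on aligned source positions
        bad_total += bad_attn.item()    # (ratio-scaled) attention mass on the rest
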
Example #5
0
class CvaeTrans(nn.Module):

    def __init__(self, vocab, emo_number,  model_file_path=None, is_eval=False, load_optim=False):
        super(CvaeTrans, self).__init__()
        self.vocab = vocab
        self.vocab_size = vocab.n_words

        self.embedding = share_embedding(self.vocab, pretrain=False)
        
        self.word_encoder = WordEncoder(config.emb_dim, config.hidden_dim, config.bidirectional)
        self.encoder = Encoder(config.hidden_dim, num_layers=config.hop, num_heads=config.heads, 
                                total_key_depth=config.depth, total_value_depth=config.depth,
                                filter_size=config.filter, universal=config.universal)
        self.r_encoder = Encoder(config.hidden_dim, num_layers=config.hop, num_heads=config.heads, 
                                total_key_depth=config.depth, total_value_depth=config.depth,
                                filter_size=config.filter, universal=config.universal)
        
        self.decoder = VarDecoder(config.emb_dim, hidden_size=config.hidden_dim, num_layers=config.hop, 
                                num_heads=config.heads, total_key_depth=config.depth, total_value_depth=config.depth, 
                                filter_size=config.filter, vocab_size=self.vocab_size)

        self.generator = Generator(config.hidden_dim, self.vocab_size)

        self.linear = nn.Linear(2 * config.hidden_dim, config.hidden_dim)

        if config.weight_sharing:
            # Share the weight matrix between target word embedding & the final logit dense layer
            self.generator.proj.weight = self.embedding.lut.weight

        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)
        
        if model_file_path:
            print("loading weights")
            state = torch.load(model_file_path, map_location= lambda storage, location: storage)
            self.encoder.load_state_dict(state['encoder_state_dict'])
            #self.r_encoder.load_state_dict(state['r_encoder_state_dict'])
            self.decoder.load_state_dict(state['decoder_state_dict'])
            self.generator.load_state_dict(state['generator_dict'])
            self.embedding.load_state_dict(state['embedding_dict'])
            
        if (config.USE_CUDA):
            self.cuda()
        if is_eval:
            self.eval()
        else:
            self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
            if(config.noam):
                self.optimizer = NoamOpt(config.hidden_dim, 1, 8000, torch.optim.Adam(self.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
            if (load_optim):
                self.optimizer.load_state_dict(state['optimizer'])
                if config.USE_CUDA:
                    # move optimizer state tensors to the GPU
                    for opt_state in self.optimizer.state.values():
                        for k, v in opt_state.items():
                            if isinstance(v, torch.Tensor):
                                opt_state[k] = v.cuda()
        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""

    def save_model(self, running_avg_ppl, iter, f1_g,f1_b,ent_g,ent_b):

        state = {
            'iter': iter,
            'encoder_state_dict': self.encoder.state_dict(),
            #'r_encoder_state_dict': self.r_encoder.state_dict(),
            'decoder_state_dict': self.decoder.state_dict(),
            'generator_dict': self.generator.state_dict(),
            'embedding_dict': self.embedding.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_ppl
        }
        model_save_path = os.path.join(self.model_dir, 'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(iter,running_avg_ppl,f1_g,f1_b,ent_g,ent_b) )
        self.best_path = model_save_path
        torch.save(state, model_save_path)

    def train_one_batch(self, batch, iter, train=True):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)
        dec_batch, _, _, _, _ = get_output_from_batch(batch)
        
        if(config.noam):
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        ## Response encode
        mask_res = batch["posterior_batch"].data.eq(config.PAD_idx).unsqueeze(1)
        post_emb = self.embedding(batch["posterior_batch"])
        r_encoder_outputs = self.r_encoder(post_emb, mask_res)

        ## Encode
        num_sentences, enc_seq_len = enc_batch.size()
        batch_size = enc_lens.size(0)
        max_len = enc_lens.data.max().item()
        input_lengths = torch.sum(~enc_batch.data.eq(config.PAD_idx), dim=1)
        
        # word level encoder
        enc_emb = self.embedding(enc_batch)
        word_encoder_outputs, word_encoder_hidden = self.word_encoder(enc_emb, input_lengths)
        word_encoder_hidden = word_encoder_hidden.transpose(1, 0).reshape(num_sentences, -1)

        # pad and pack word_encoder_hidden
        start = torch.cumsum(torch.cat((enc_lens.data.new(1).zero_(), enc_lens[:-1])), 0)
        word_encoder_hidden = torch.stack([pad(word_encoder_hidden.narrow(0, s, l), max_len)
                                            for s, l in zip(start.data.tolist(), enc_lens.data.tolist())], 0)
        
        mask_src = ~(enc_padding_mask.bool()).unsqueeze(1)
        
        # context level encoder
        if word_encoder_hidden.size(-1) != config.hidden_dim:
            word_encoder_hidden = self.linear(word_encoder_hidden)
        encoder_outputs = self.encoder(word_encoder_hidden, mask_src)

        # Decode
        sos_token = torch.LongTensor([config.SOS_idx] * batch_size).unsqueeze(1)
        if config.USE_CUDA: sos_token = sos_token.cuda()

        dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1) #(batch, len, embedding)
        mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
        dec_emb = self.embedding(dec_batch_shift)

        pre_logit, attn_dist, mean, log_var, probs = self.decoder(dec_emb, encoder_outputs, r_encoder_outputs, 
                                                                    (mask_src, mask_res, mask_trg))
        
        ## compute output dist
        logit = self.generator(pre_logit, attn_dist, enc_batch_extend_vocab if config.pointer_gen else None, extra_zeros, attn_dist_db=None)
        ## loss: NLL if ptr else Cross entropy
        sbow = dec_batch #[batch, seq_len]
        seq_len = sbow.size(1)
        loss_rec = self.criterion(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1))
        if config.model == "cvaetrs":
            loss_aux = 0
            for prob in probs:
                sbow_mask = _get_attn_subsequent_mask(seq_len).transpose(1, 2)
                # expand the gold tokens to [batch, seq_len, seq_len], PAD-masking
                # positions excluded by the transposed subsequent mask
                sbow_t = sbow.unsqueeze(2).repeat(1, 1, seq_len).masked_fill(sbow_mask, config.PAD_idx)
                loss_aux += self.criterion(prob.contiguous().view(-1, prob.size(-1)),
                                           sbow_t.contiguous().view(-1))
            kld_loss = gaussian_kld(mean["posterior"], log_var["posterior"],mean["prior"], log_var["prior"])
            kld_loss = torch.mean(kld_loss)
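            # tanh KL annealing: the weight rises smoothly from ~0 and saturates at 1
            # around iter == config.full_kl_step / 2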
            kl_weight = min(math.tanh(6 * iter/config.full_kl_step - 3) + 1, 1)
            #kl_weight = min(iter/config.full_kl_step, 1) if config.full_kl_step >0 else 1.0
            loss = loss_rec + config.kl_ceiling * kl_weight*kld_loss + config.aux_ceiling*loss_aux
            elbo = loss_rec + kld_loss
        else:
            loss = loss_rec
            elbo = loss_rec
            kld_loss = torch.Tensor([0])
            loss_aux = torch.Tensor([0])
        if(train):
            loss.backward()
            # clip gradient
            nn.utils.clip_grad_norm_(self.parameters(), config.max_grad_norm)
            self.optimizer.step()

        return loss_rec.item(), math.exp(min(loss_rec.item(), 100)), kld_loss.item(), loss_aux.item(), elbo.item()

    def train_n_batch(self, batchs, iter, train=True):
        if(config.noam):
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()
        for batch in batchs:
            enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)
            dec_batch, _, _, _, _ = get_output_from_batch(batch)
            ## Encode
            mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
            encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)

            meta = self.embedding(batch["program_label"])
            if config.dataset=="empathetic":
                meta = meta-meta
            # Decode
            sos_token = torch.LongTensor([config.SOS_idx] * enc_batch.size(0)).unsqueeze(1)
            if config.USE_CUDA: sos_token = sos_token.cuda()
            dec_batch_shift = torch.cat((sos_token,dec_batch[:, :-1]),1)

            mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)

            pre_logit, attn_dist, mean, log_var, probs = self.decoder(
                self.embedding(dec_batch_shift) + meta.unsqueeze(1),
                encoder_outputs, True, (mask_src, mask_trg))
            ## compute output dist
            logit = self.generator(pre_logit,attn_dist,enc_batch_extend_vocab if config.pointer_gen else None, extra_zeros, attn_dist_db=None)
            ## loss: NLL if ptr else Cross entropy
            sbow = dec_batch #[batch, seq_len]
            seq_len = sbow.size(1)
            loss_rec = self.criterion(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1))
            if config.model == "cvaetrs":
                loss_aux = 0
                for prob in probs:
                    sbow_mask = _get_attn_subsequent_mask(seq_len).transpose(1, 2)
                    # expand the gold tokens to [batch, seq_len, seq_len], PAD-masking
                    # positions excluded by the transposed subsequent mask
                    sbow_t = sbow.unsqueeze(2).repeat(1, 1, seq_len).masked_fill(sbow_mask, config.PAD_idx)
                    loss_aux += self.criterion(prob.contiguous().view(-1, prob.size(-1)),
                                               sbow_t.contiguous().view(-1))
                kld_loss = gaussian_kld(mean["posterior"], log_var["posterior"],mean["prior"], log_var["prior"])
                kld_loss = torch.mean(kld_loss)
                kl_weight = min(math.tanh(6 * iter/config.full_kl_step - 3) + 1, 1)
                #kl_weight = min(iter/config.full_kl_step, 1) if config.full_kl_step >0 else 1.0
                loss = loss_rec + config.kl_ceiling * kl_weight*kld_loss + config.aux_ceiling*loss_aux
                elbo = loss_rec+kld_loss
            else:
                loss = loss_rec
                elbo = loss_rec
                kld_loss = torch.Tensor([0])
                loss_aux = torch.Tensor([0])
            loss.backward()
            # clip gradient
        nn.utils.clip_grad_norm_(self.parameters(), config.max_grad_norm)
        self.optimizer.step()

        return loss_rec.item(), math.exp(min(loss_rec.item(), 100)), kld_loss.item(), loss_aux.item(), elbo.item()

    def decoder_greedy(self, batch, max_dec_step=50):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)
        
        ## Encode
        num_sentences, enc_seq_len = enc_batch.size()
        batch_size = enc_lens.size(0)
        max_len = enc_lens.data.max().item()
        input_lengths = torch.sum(~enc_batch.data.eq(config.PAD_idx), dim=1)
        
        # word level encoder
        enc_emb = self.embedding(enc_batch)
        word_encoder_outputs, word_encoder_hidden = self.word_encoder(enc_emb, input_lengths)
        word_encoder_hidden = word_encoder_hidden.transpose(1, 0).reshape(num_sentences, -1)

        # pad and pack word_encoder_hidden
        start = torch.cumsum(torch.cat((enc_lens.data.new(1).zero_(), enc_lens[:-1])), 0)
        word_encoder_hidden = torch.stack([pad(word_encoder_hidden.narrow(0, s, l), max_len)
                                            for s, l in zip(start.data.tolist(), enc_lens.data.tolist())], 0)
        mask_src = ~(enc_padding_mask.bool()).unsqueeze(1)

        # context level encoder
        if word_encoder_hidden.size(-1) != config.hidden_dim:
            word_encoder_hidden = self.linear(word_encoder_hidden)
        encoder_outputs = self.encoder(word_encoder_hidden, mask_src)
        
        ys = torch.ones(batch_size, 1).fill_(config.SOS_idx).long()
        if config.USE_CUDA:
            ys = ys.cuda()
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        
        decoded_words = []
        for i in range(max_dec_step+1):

            out, attn_dist, _, _,_ = self.decoder(self.embedding(ys), encoder_outputs, None, (mask_src, None, mask_trg))
            
            prob = self.generator(out,attn_dist,enc_batch_extend_vocab, extra_zeros, attn_dist_db=None)
            _, next_word = torch.max(prob[:, -1], dim = 1)
            decoded_words.append(['<EOS>' if ni.item() == config.EOS_idx else self.vocab.index2word[ni.item()] for ni in next_word.view(-1)])

            if config.USE_CUDA:
                ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)
                ys = ys.cuda()
            else:
                ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)
            
            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        sent = []
        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>': break
                else: st+= e + ' '
            sent.append(st)
        return sent
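
The training methods above call a gaussian_kld helper that is imported from elsewhere and not shown. A minimal sketch, assuming the usual KL divergence between two diagonal Gaussians parameterized by means and log-variances (an assumption, not the repository's actual implementation):

import torch

def gaussian_kld(mu_q, logvar_q, mu_p, logvar_p):
    # KL( N(mu_q, exp(logvar_q)) || N(mu_p, exp(logvar_p)) ), summed over the latent dim
    return 0.5 * torch.sum(
        logvar_p - logvar_q
        + (logvar_q.exp() + (mu_q - mu_p) ** 2) / logvar_p.exp()
        - 1.0,
        dim=-1)
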
Example #6
0
class Transformer(nn.Module):
    def __init__(self,
                 vocab,
                 model_file_path=None,
                 is_eval=False,
                 load_optim=False):
        super(Transformer, self).__init__()
        self.vocab = vocab
        self.vocab_size = vocab.n_words

        self.embedding = share_embedding(self.vocab, config.pretrain_emb)
        self.encoder = Encoder(config.emb_dim,
                               config.hidden_dim,
                               num_layers=config.hop,
                               num_heads=config.heads,
                               total_key_depth=config.depth,
                               total_value_depth=config.depth,
                               filter_size=config.filter,
                               universal=config.universal)

        ## multiple decoders
        self.decoder = Decoder(config.emb_dim,
                               hidden_size=config.hidden_dim,
                               num_layers=config.hop,
                               num_heads=config.heads,
                               total_key_depth=config.depth,
                               total_value_depth=config.depth,
                               filter_size=config.filter)

        self.generator = Generator(config.hidden_dim, self.vocab_size)

        if config.weight_sharing:
            # Share the weight matrix between target word embedding & the final logit dense layer
            self.generator.proj.weight = self.embedding.lut.weight

        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)
        if config.label_smoothing:
            self.criterion = LabelSmoothing(size=self.vocab_size,
                                            padding_idx=config.PAD_idx,
                                            smoothing=0.1)
            self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx)

        self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
        if config.noam:
            self.optimizer = NoamOpt(
                config.hidden_dim, 1, 8000,
                torch.optim.Adam(self.parameters(),
                                 lr=0,
                                 betas=(0.9, 0.98),
                                 eps=1e-9))

        if model_file_path is not None:
            print("loading weights")
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            self.encoder.load_state_dict(state['encoder_state_dict'])
            self.decoder.load_state_dict(state['decoder_state_dict'])
            self.generator.load_state_dict(state['generator_dict'])
            self.embedding.load_state_dict(state['embedding_dict'])
            if load_optim:
                self.optimizer.load_state_dict(state['optimizer'])
            self.eval()

        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""

    def save_model(self, running_avg_ppl, iter, f1_g, f1_b, ent_g, ent_b):

        state = {
            'iter': iter,
            'encoder_state_dict': self.encoder.state_dict(),
            'decoder_state_dict': self.decoder.state_dict(),
            'generator_dict': self.generator.state_dict(),
            'embedding_dict': self.embedding.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_ppl
        }
        model_save_path = os.path.join(
            self.model_dir,
            'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(
                iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b))
        self.best_path = model_save_path
        torch.save(state, model_save_path)

    def train_one_batch(self, batch, iter, train=True):
        enc_batch = batch["review_batch"]
        enc_batch_extend_vocab = batch["review_ext_batch"]
        oovs = batch["oovs"]
        max_oov_length = max(len(oov) for oov in oovs)
        extra_zeros = Variable(torch.zeros(
            (enc_batch.size(0), max_oov_length))).to(config.device)

        dec_batch = batch["tags_batch"]
        dec_ext_batch = batch["tags_ext_batch"]

        if config.noam:
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        ## Embedding - context
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(
            1)  # (bsz, src_len)->(bsz, 1, src_len)
        # emb_mask = self.embedding(batch["mask_context"])
        # src_emb = self.embedding(enc_batch)+emb_mask
        src_emb = self.embedding(enc_batch)
        encoder_outputs = self.encoder(src_emb,
                                       mask_src)  # (bsz, src_len, emb_dim)

        sos_token = torch.LongTensor([config.SOS_idx] *
                                     enc_batch.size(0)).unsqueeze(1).to(
                                         config.device)  # (bsz, 1)
        dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]),
                                    1)  # (bsz, tgt_len)

        mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
        pre_logit, attn_dist = self.decoder(self.embedding(dec_batch_shift),
                                            encoder_outputs,
                                            (mask_src, mask_trg))

        logit = self.generator(
            pre_logit, attn_dist,
            enc_batch_extend_vocab if config.pointer_gen else None,
            extra_zeros)
        #logit = F.log_softmax(logit,dim=-1) #fix the name later
        ## loss: NLL if ptr else Cross entropy
        if config.pointer_gen:
            loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                                  dec_ext_batch.contiguous().view(-1))
        else:
            loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                                  dec_batch.contiguous().view(-1))

        if config.label_smoothing:
            loss_ppl = self.criterion_ppl(
                logit.contiguous().view(-1, logit.size(-1)),
                dec_ext_batch.contiguous().view(-1)
                if config.pointer_gen else dec_batch.contiguous().view(-1))

        if train:
            loss.backward()
            self.optimizer.step()

        if config.label_smoothing:
            if torch.isnan(loss_ppl).sum().item() != 0 or torch.isinf(
                    loss_ppl).sum().item() != 0:
                print("check")
                pdb.set_trace()
            return loss_ppl.item(), math.exp(min(loss_ppl.item(), 100)), 0, 0
        else:
            return loss.item(), math.exp(min(loss.item(), 100)), 0, 0

    def compute_act_loss(self, module):
        R_t = module.remainders
        N_t = module.n_updates
        p_t = R_t + N_t
        avg_p_t = torch.sum(torch.sum(p_t, dim=1) / p_t.size(1)) / p_t.size(0)
        loss = config.act_loss_weight * avg_p_t.item()
        return loss

    def decoder_greedy(self, batch, max_dec_step=30):
        enc_batch = batch["review_batch"]
        enc_batch_extend_vocab = batch["review_ext_batch"]

        oovs = batch["oovs"]
        max_oov_length = max(len(oov) for oov in oovs)
        extra_zeros = Variable(torch.zeros(
            (enc_batch.size(0), max_oov_length))).to(config.device)

        ## Encode - context
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(
            1)  # (bsz, src_len)->(bsz, 1, src_len)
        # emb_mask = self.embedding(batch["mask_context"])
        # src_emb = self.embedding(enc_batch) + emb_mask  # todo eos or sentence embedding??
        src_emb = self.embedding(enc_batch)
        encoder_outputs = self.encoder(src_emb,
                                       mask_src)  # (bsz, src_len, emb_dim)
        enc_ext_batch = enc_batch_extend_vocab

        # ys = torch.ones(1, 1).fill_(config.SOS_idx).long()
        ys = torch.zeros(enc_batch.size(0), 1).fill_(config.SOS_idx).long().to(
            config.device)  # at test time, bsz is set to 1
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        decoded_words = []
        for i in range(max_dec_step + 1):
            if config.project:
                out, attn_dist = self.decoder(
                    self.embedding_proj_in(self.embedding(ys)),
                    self.embedding_proj_in(encoder_outputs),
                    (mask_src, mask_trg))
            else:
                out, attn_dist = self.decoder(self.embedding(ys),
                                              encoder_outputs,
                                              (mask_src, mask_trg))
            prob = self.generator(out, attn_dist, enc_ext_batch, extra_zeros)

            _, next_word = torch.max(prob[:, -1], dim=1)  # bsz=1
            cur_words = []
            for i_batch, ni in enumerate(next_word.view(-1)):
                if ni.item() == config.EOS_idx:
                    cur_words.append('<EOS>')
                elif ni.item() in self.vocab.index2word:
                    cur_words.append(self.vocab.index2word[ni.item()])
                else:
                    cur_words.append(oovs[i_batch][ni.item() -
                                                   self.vocab.n_words])
            decoded_words.append(cur_words)
            next_word = next_word.data[0]

            if next_word.item() not in self.vocab.index2word:
                next_word = torch.tensor(config.UNK_idx)

            ys = torch.cat([
                ys,
                torch.zeros(enc_batch.size(0), 1).long().fill_(next_word).to(
                    config.device)
            ],
                           dim=1).to(config.device)

            # if config.USE_CUDA:
            #     ys = torch.cat([ys, torch.zeros(enc_batch.size(0), 1).long().fill_(next_word).cuda()], dim=1)
            #     ys = ys.cuda()
            # else:
            #     ys = torch.cat([ys, torch.zeros(enc_batch.size(0), 1).long().fill_(next_word)], dim=1)
            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

        sent = []
        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>':
                    break
                else:
                    st += e + ' '
            sent.append(st)
        return sent
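
Every example builds its optimizer as NoamOpt(config.hidden_dim, 1, 8000, torch.optim.Adam(...)). The NoamOpt wrapper itself is not shown; a minimal sketch of the learning-rate schedule such a wrapper conventionally applies (the standard Transformer warmup schedule, stated here as an assumption):

def noam_rate(step, model_size, factor=1, warmup=8000):
    # step is assumed to start at 1; the rate grows linearly for `warmup` steps,
    # then decays proportionally to 1/sqrt(step)
    return factor * (model_size ** -0.5) * min(step ** -0.5, step * warmup ** -1.5)
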
Example #7
0
class CvaeNAD(nn.Module):
    def __init__(self,
                 vocab,
                 emo_number,
                 model_file_path=None,
                 is_eval=False,
                 load_optim=False):
        super(CvaeNAD, self).__init__()
        self.vocab = vocab
        self.vocab_size = vocab.n_words

        self.embedding = share_embedding(self.vocab, config.pretrain_emb)
        self.encoder = Encoder(config.emb_dim,
                               config.hidden_dim,
                               num_layers=config.hop,
                               num_heads=config.heads,
                               total_key_depth=config.depth,
                               total_value_depth=config.depth,
                               filter_size=config.filter,
                               universal=config.universal)

        self.r_encoder = Encoder(config.emb_dim,
                                 config.hidden_dim,
                                 num_layers=config.hop,
                                 num_heads=config.heads,
                                 total_key_depth=config.depth,
                                 total_value_depth=config.depth,
                                 filter_size=config.filter,
                                 universal=config.universal)
        if config.num_var_layers > 0:
            self.decoder = VarDecoder2(config.emb_dim,
                                       hidden_size=config.hidden_dim,
                                       num_layers=config.hop,
                                       num_heads=config.heads,
                                       total_key_depth=config.depth,
                                       total_value_depth=config.depth,
                                       filter_size=config.filter)
        else:
            self.decoder = VarDecoder3(config.emb_dim,
                                       hidden_size=config.hidden_dim,
                                       num_layers=config.hop,
                                       num_heads=config.heads,
                                       total_key_depth=config.depth,
                                       total_value_depth=config.depth,
                                       filter_size=config.filter)

        self.generator = Generator(config.hidden_dim, self.vocab_size)

        if config.weight_sharing:
            # Share the weight matrix between target word embedding & the final logit dense layer
            self.generator.proj.weight = self.embedding.lut.weight

        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)

        self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
        if (config.noam):
            self.optimizer = NoamOpt(
                config.hidden_dim, 1, 8000,
                torch.optim.Adam(self.parameters(),
                                 lr=0,
                                 betas=(0.9, 0.98),
                                 eps=1e-9))

        if model_file_path is not None:
            print("loading weights")
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            self.encoder.load_state_dict(state['encoder_state_dict'])
            self.r_encoder.load_state_dict(state['r_encoder_state_dict'])
            self.decoder.load_state_dict(state['decoder_state_dict'])
            self.generator.load_state_dict(state['generator_dict'])
            self.embedding.load_state_dict(state['embedding_dict'])
            if (load_optim):
                self.optimizer.load_state_dict(state['optimizer'])
            self.eval()

        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""

    def save_model(self, running_avg_ppl, iter, f1_g, f1_b, ent_g, ent_b):

        state = {
            'iter': iter,
            'encoder_state_dict': self.encoder.state_dict(),
            'r_encoder_state_dict': self.r_encoder.state_dict(),
            'decoder_state_dict': self.decoder.state_dict(),
            'generator_dict': self.generator.state_dict(),
            'embedding_dict': self.embedding.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_ppl
        }
        model_save_path = os.path.join(
            self.model_dir,
            'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(
                iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b))
        self.best_path = model_save_path
        torch.save(state, model_save_path)

    def train_one_batch(self, batch, iter, train=True):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)
        dec_batch, _, _, _, _ = get_output_from_batch(batch)

        if (config.noam):
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        ## Response encode
        mask_res = batch["posterior_batch"].data.eq(
            config.PAD_idx).unsqueeze(1)
        posterior_mask = self.embedding(batch["posterior_mask"])
        r_encoder_outputs = self.r_encoder(
            self.embedding(batch["posterior_batch"]), mask_res)
        ## Encode
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        emb_mask = self.embedding(batch["input_mask"])
        encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)
        meta = self.embedding(batch["program_label"])

        # Decode
        mask_trg = dec_batch.data.eq(config.PAD_idx).unsqueeze(1)
        latent_dim = meta.size()[-1]
        meta = meta.repeat(1, dec_batch.size(1)).view(dec_batch.size(0),
                                                      dec_batch.size(1),
                                                      latent_dim)
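        # non-autoregressive decoding: the program-label embedding, repeated to the
        # target length, is fed to the decoder in place of shifted gold tokens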
        pre_logit, attn_dist, mean, log_var = self.decoder(
            meta, encoder_outputs, r_encoder_outputs,
            (mask_src, mask_res, mask_trg))
        if not train:
            pre_logit, attn_dist, _, _ = self.decoder(
                meta, encoder_outputs, None, (mask_src, None, mask_trg))
        ## compute output dist
        logit = self.generator(
            pre_logit,
            attn_dist,
            enc_batch_extend_vocab if config.pointer_gen else None,
            extra_zeros,
            attn_dist_db=None)
        ## loss: NLL if ptr else Cross entropy
        loss_rec = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                                  dec_batch.contiguous().view(-1))
        kld_loss = gaussian_kld(mean["posterior"], log_var["posterior"],
                                mean["prior"], log_var["prior"])
        kld_loss = torch.mean(kld_loss)
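        # linear KL annealing: the weight ramps from 0 to 1 over config.full_kl_step iterations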
        kl_weight = min(iter / config.full_kl_step,
                        1) if config.full_kl_step > 0 else 1.0
        loss = loss_rec + config.kl_ceiling * kl_weight * kld_loss
        if (train):
            loss.backward()
            # clip gradient
            nn.utils.clip_grad_norm_(self.parameters(), config.max_grad_norm)
            self.optimizer.step()

        return loss_rec.item(), math.exp(min(loss_rec.item(),
                                             100)), kld_loss.item()

    def decoder_greedy(self, batch, max_dec_step=50):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)

        ## Encode
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        emb_mask = self.embedding(batch["input_mask"])
        meta = self.embedding(batch["program_label"])
        encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)

        mask_trg = torch.ones((enc_batch.size(0), 50))
        meta_size = meta.size()
        meta = meta.repeat(1, 50).view(meta_size[0], 50, meta_size[1])
        out, attn_dist, _, _ = self.decoder(meta, encoder_outputs, None,
                                            (mask_src, None, mask_trg))
        prob = self.generator(out,
                              attn_dist,
                              enc_batch_extend_vocab,
                              extra_zeros,
                              attn_dist_db=None)
        _, batch_out = torch.max(prob, dim=-1)  # argmax word id at every position
        batch_out = batch_out.data.cpu().numpy()
        sentences = []
        for sent in batch_out:
            st = ''
            for w in sent:
                if w == config.EOS_idx: break
                else: st += self.vocab.index2word[w] + ' '
            sentences.append(st)
        return sentences

    def decoder_greedy_po(self, batch, max_dec_step=50):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)

        ## Response encode
        mask_res = batch["posterior_batch"].data.eq(
            config.PAD_idx).unsqueeze(1)
        posterior_mask = self.embedding(batch["posterior_mask"])
        r_encoder_outputs = self.r_encoder(
            self.embedding(batch["posterior_batch"]), mask_res)
        ## Encode
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        emb_mask = self.embedding(batch["input_mask"])
        encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)
        meta = self.embedding(batch["program_label"])

        mask_trg = torch.ones((enc_batch.size(0), 50))
        meta_size = meta.size()
        meta = meta.repeat(1, 50).view(meta_size[0], 50, meta_size[1])
        out, attn_dist, mean, log_var = self.decoder(
            meta, encoder_outputs, r_encoder_outputs,
            (mask_src, mask_res, mask_trg))
        prob = self.generator(out,
                              attn_dist,
                              enc_batch_extend_vocab,
                              extra_zeros,
                              attn_dist_db=None)
        _, batch_out = torch.max(prob, dim=-1)  # argmax word id at every position
        batch_out = batch_out.data.cpu().numpy()
        sentences = []
        for sent in batch_out:
            st = ''
            for w in sent:
                if w == config.EOS_idx: break
                else: st += self.vocab.index2word[w] + ' '
            sentences.append(st)
        return sentences
class Transformer(nn.Module):
    def __init__(self,
                 vocab,
                 emo_number,
                 model_file_path=None,
                 is_eval=False,
                 load_optim=False):
        super(Transformer, self).__init__()
        self.vocab = vocab
        self.vocab_size = vocab.n_words
        self.embedding = share_embedding(self.vocab, config.pretrain_emb)
        self.encoder = Encoder(config.emb_dim,
                               config.hidden_dim,
                               num_layers=config.hop,
                               num_heads=config.heads,
                               total_key_depth=config.depth,
                               total_value_depth=config.depth,
                               filter_size=config.filter,
                               universal=config.universal)

        self.decoder = Decoder(config.emb_dim,
                               hidden_size=config.hidden_dim,
                               num_layers=config.hop,
                               num_heads=config.heads,
                               total_key_depth=config.depth,
                               total_value_depth=config.depth,
                               filter_size=config.filter)

        self.generator = Generator(config.hidden_dim, self.vocab_size)

        if config.weight_sharing:
            # Share the weight matrix between target word embedding & the final logit dense layer
            self.generator.proj.weight = self.embedding.lut.weight

        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)

        self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
        if (config.noam):
            self.optimizer = NoamOpt(
                config.hidden_dim, 1, 8000,
                torch.optim.Adam(self.parameters(),
                                 lr=0,
                                 betas=(0.9, 0.98),
                                 eps=1e-9))

        if model_file_path is not None:
            print("loading weights")
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            self.encoder.load_state_dict(state['encoder_state_dict'])
            self.decoder.load_state_dict(state['decoder_state_dict'])
            self.generator.load_state_dict(state['generator_dict'])
            self.embedding.load_state_dict(state['embedding_dict'])
            if (load_optim):
                self.optimizer.load_state_dict(state['optimizer'])
            self.eval()

        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""

    def save_model(self, running_avg_ppl, iter, f1_g, f1_b, ent_g, ent_b):

        state = {
            'iter': iter,
            'encoder_state_dict': self.encoder.state_dict(),
            'decoder_state_dict': self.decoder.state_dict(),
            'generator_dict': self.generator.state_dict(),
            'embedding_dict': self.embedding.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_ppl
        }
        model_save_path = os.path.join(
            self.model_dir,
            'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(
                iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b))
        self.best_path = model_save_path
        torch.save(state, model_save_path)

    def train_one_batch(self, batch, iter, train=True):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)
        dec_batch, _, _, _, _ = get_output_from_batch(batch)

        if (config.noam):
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        ## Encode
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        meta = self.embedding(batch["program_label"])
        emb_mask = self.embedding(batch["input_mask"])
        encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)

        # Decode
        sos_token = torch.LongTensor([config.SOS_idx] *
                                     enc_batch.size(0)).unsqueeze(1)
        if config.USE_CUDA: sos_token = sos_token.cuda()
        dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1)

        mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
        pre_logit, attn_dist = self.decoder(
            self.embedding(dec_batch_shift) + meta.unsqueeze(1),
            encoder_outputs, (mask_src, mask_trg))
        #+meta.unsqueeze(1)
        ## compute output dist
        logit = self.generator(
            pre_logit,
            attn_dist,
            enc_batch_extend_vocab if config.pointer_gen else None,
            extra_zeros,
            attn_dist_db=None)
        #logit = F.log_softmax(logit,dim=-1) #fix the name later
        ## loss: NLL if ptr else Cross entropy
        loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                              dec_batch.contiguous().view(-1))

        if (train):
            loss.backward()
            self.optimizer.step()

        return loss.item(), math.exp(min(loss.item(), 100)), 0

    def compute_act_loss(self, module):
        R_t = module.remainders
        N_t = module.n_updates
        p_t = R_t + N_t
        avg_p_t = torch.sum(torch.sum(p_t, dim=1) / p_t.size(1)) / p_t.size(0)
        loss = config.act_loss_weight * avg_p_t.item()
        return loss

    def decoder_greedy(self, batch, max_dec_step=50):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        emb_mask = self.embedding(batch["input_mask"])
        meta = self.embedding(batch["program_label"])
        encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)

        ys = torch.ones(enc_batch.shape[0], 1).fill_(config.SOS_idx).long()
        if config.USE_CUDA:
            ys = ys.cuda()
        # print('=====================ys========================')
        # print(ys)

        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        decoded_words = []
        for i in range(max_dec_step + 1):

            out, attn_dist = self.decoder(
                self.embedding(ys) + meta.unsqueeze(1), encoder_outputs,
                (mask_src, mask_trg))

            prob = self.generator(out,
                                  attn_dist,
                                  enc_batch_extend_vocab,
                                  extra_zeros,
                                  attn_dist_db=None)
            _, next_word = torch.max(prob[:, -1], dim=1)
            # print('=====================next_word1========================')
            # print(next_word)
            decoded_words.append([
                '<EOS>' if ni.item() == config.EOS_idx else
                self.vocab.index2word[ni.item()] for ni in next_word.view(-1)
            ])
            #next_word = next_word.data[0]
            # print('=====================next_word2========================')
            # print(next_word)
            if config.USE_CUDA:
                # print('=====================shape========================')
                # print(ys.shape, next_word.shape)
                ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)
                ys = ys.cuda()
            else:
                ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)

            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
            # print('=====================new_ys========================')
            # print(ys)
        sent = []
        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>': break
                else: st += e + ' '
            sent.append(st)
        return sent

    def decoder_greedy_po(self, batch, max_dec_step=50):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        emb_mask = self.embedding(batch["input_mask"])
        meta = self.embedding(batch["program_label"])
        encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)

        ys = torch.ones(enc_batch.shape[0], 1).fill_(config.SOS_idx).long()
        if config.USE_CUDA:
            ys = ys.cuda()
        # print('=====================ys========================')
        # print(ys)

        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        decoded_words = []
        for i in range(max_dec_step + 1):

            out, attn_dist = self.decoder(
                self.embedding(ys) + meta.unsqueeze(1), encoder_outputs,
                (mask_src, mask_trg))

            prob = self.generator(out,
                                  attn_dist,
                                  enc_batch_extend_vocab,
                                  extra_zeros,
                                  attn_dist_db=None)
            _, next_word = torch.max(prob[:, -1], dim=1)
            # print('=====================next_word1========================')
            # print(next_word)
            decoded_words.append([
                '<EOS>' if ni.item() == config.EOS_idx else
                self.vocab.index2word[ni.item()] for ni in next_word.view(-1)
            ])
            #next_word = next_word.data[0]
            # print('=====================next_word2========================')
            # print(next_word)
            if config.USE_CUDA:
                # print('=====================shape========================')
                # print(ys.shape, next_word.shape)
                ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)
                ys = ys.cuda()
            else:
                ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)

            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
            # print('=====================new_ys========================')
            # print(ys)
        sent = []
        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>': break
                else: st += e + ' '
            sent.append(st)
        return sent
Example #9
0
class Train_MIME(nn.Module):
    '''
    For emotion attention, simply pass the randomly sampled emotion as the
    query (Q) in a decoder block of the transformer.
    '''
    def __init__(self,
                 vocab,
                 decoder_number,
                 model_file_path=None,
                 is_eval=False,
                 load_optim=False):
        super().__init__()
        self.vocab = vocab
        self.vocab_size = vocab.n_words
        self.embedding = share_embedding(self.vocab, config.pretrain_emb)

        self.encoder = Encoder(config.emb_dim,
                               config.hidden_dim,
                               num_layers=config.hop,
                               num_heads=config.heads,
                               total_key_depth=config.depth,
                               total_value_depth=config.depth,
                               filter_size=config.filter,
                               universal=config.universal)
        self.decoder_number = decoder_number

        self.decoder = DecoderContextV(config.emb_dim,
                                       config.hidden_dim,
                                       num_layers=config.hop,
                                       num_heads=config.heads,
                                       total_key_depth=config.depth,
                                       total_value_depth=config.depth,
                                       filter_size=config.filter)

        self.vae_sampler = VAESampling(config.hidden_dim,
                                       config.hidden_dim,
                                       out_dim=300)

        # outputs m
        self.emotion_input_encoder_1 = EmotionInputEncoder(
            config.emb_dim,
            config.hidden_dim,
            num_layers=config.hop,
            num_heads=config.heads,
            total_key_depth=config.depth,
            total_value_depth=config.depth,
            filter_size=config.filter,
            universal=config.universal,
            emo_input=config.emo_input)
        # outputs m~
        self.emotion_input_encoder_2 = EmotionInputEncoder(
            config.emb_dim,
            config.hidden_dim,
            num_layers=config.hop,
            num_heads=config.heads,
            total_key_depth=config.depth,
            total_value_depth=config.depth,
            filter_size=config.filter,
            universal=config.universal,
            emo_input=config.emo_input)

        if config.emo_combine == "att":
            self.cdecoder = ComplexResDecoder(config.emb_dim,
                                              config.hidden_dim,
                                              num_layers=config.hop,
                                              num_heads=config.heads,
                                              total_key_depth=config.depth,
                                              total_value_depth=config.depth,
                                              filter_size=config.filter,
                                              universal=config.universal)

        elif config.emo_combine == "gate":
            self.cdecoder = ComplexResGate(config.emb_dim)

        self.s_weight = nn.Linear(config.hidden_dim,
                                  config.emb_dim,
                                  bias=False)
        self.decoder_key = nn.Linear(config.hidden_dim,
                                     decoder_number,
                                     bias=False)

        # v^T tanh(W E[i] + H c + b)
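        # (these 'method3' weights are defined here, but the forward passes shown in
        # this example use the simpler 'method2' matmul scoring via s_weight)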
        method3 = True
        if method3:
            self.e_weight = nn.Linear(config.emb_dim,
                                      config.emb_dim,
                                      bias=True)
            # learnable scoring vector; registered as a Parameter so it is optimized
            # and moved to the GPU together with the rest of the module
            self.v = nn.Parameter(torch.rand(config.emb_dim))

        self.generator = Generator(config.hidden_dim, self.vocab_size)
        self.emoji_embedding = nn.Embedding(32, config.emb_dim)
        if config.init_emo_emb: self.init_emoji_embedding_with_glove()

        if config.weight_sharing:
            # Share the weight matrix between target word embedding & the final logit dense layer
            self.generator.proj.weight = self.embedding.lut.weight

        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)
        if (config.label_smoothing):
            self.criterion = LabelSmoothing(size=self.vocab_size,
                                            padding_idx=config.PAD_idx,
                                            smoothing=0.1)
            self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx)

        if config.softmax:
            self.attention_activation = nn.Softmax(dim=1)
        else:
            self.attention_activation = nn.Sigmoid()  # nn.Softmax()

        self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
        if (config.noam):
            self.optimizer = NoamOpt(
                config.hidden_dim, 1, 8000,
                torch.optim.Adam(self.parameters(),
                                 lr=0,
                                 betas=(0.9, 0.98),
                                 eps=1e-9))

        if model_file_path is not None:
            print("loading weights")
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            self.load_state_dict(state['model'])
            if (load_optim):
                self.optimizer.load_state_dict(state['optimizer'])
            self.eval()

        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""

        # emotion indices grouped by polarity
        self.positive_emotions = [
            11, 16, 6, 8, 3, 1, 28, 13, 31, 17, 24, 0, 27
        ]
        self.negative_emotions = [
            9, 4, 2, 22, 14, 30, 29, 25, 15, 10, 23, 19, 18, 21, 7, 20, 5, 26,
            12
        ]

    def init_emoji_embedding_with_glove(self):
        self.emotions = [
            'surprised', 'excited', 'annoyed', 'proud', 'angry', 'sad',
            'grateful', 'lonely', 'impressed', 'afraid', 'disgusted',
            'confident', 'terrified', 'hopeful', 'anxious', 'disappointed',
            'joyful', 'prepared', 'guilty', 'furious', 'nostalgic', 'jealous',
            'anticipating', 'embarrassed', 'content', 'devastated',
            'sentimental', 'caring', 'trusting', 'ashamed', 'apprehensive',
            'faithful'
        ]
        self.emotion_index = [self.vocab.word2index[i] for i in self.emotions]
        self.emoji_embedding_init = self.embedding(
            torch.Tensor(self.emotion_index).long())
        self.emoji_embedding.weight.data = self.emoji_embedding_init
        self.emoji_embedding.weight.requires_grad = True

    def save_model(self, running_avg_ppl, iter, f1_g, f1_b, ent_g, ent_b,
                   ent_t):
        state = {
            'iter': iter,
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_ppl,
            'model': self.state_dict()
        }
        model_save_path = os.path.join(
            self.model_dir,
            'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(
                iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b, ent_t))
        self.best_path = model_save_path
        torch.save(state, model_save_path)

    def random_sampling(self, e):
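        # Return (mimic, mimic_t): a randomly sampled emotion of the same polarity
        # as e and one of the opposite polarity.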
        p = np.random.choice(self.positive_emotions)
        n = np.random.choice(self.negative_emotions)
        if e in self.positive_emotions:
            mimic = p
            mimic_t = n
        else:
            mimic = n
            mimic_t = p
        return mimic, mimic_t

    def train_one_batch(self, batch, iter, train=True):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)
        dec_batch, _, _, _, _ = get_output_from_batch(batch)

        if (config.noam):
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()
        ## Encode
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)

        if config.dataset == "empathetic":
            emb_mask = self.embedding(batch["mask_input"])
            encoder_outputs = self.encoder(
                self.embedding(enc_batch) + emb_mask, mask_src)
        else:
            encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)

        q_h = torch.mean(encoder_outputs,
                         dim=1) if config.mean_query else encoder_outputs[:, 0]
        # q_h = torch.max(encoder_outputs, dim=1)
        emotions_mimic, emotions_non_mimic, mu_positive_prior, logvar_positive_prior, mu_negative_prior, logvar_negative_prior = \
            self.vae_sampler(q_h, batch['program_label'], self.emoji_embedding)
        # KLLoss = -0.5 * (torch.sum(1 + logvar_n - mu_n.pow(2) - logvar_n.exp()) + torch.sum(1 + logvar_p - mu_p.pow(2) - logvar_p.exp()))
        m_out = self.emotion_input_encoder_1(emotions_mimic.unsqueeze(1),
                                             encoder_outputs, mask_src)
        m_tilde_out = self.emotion_input_encoder_2(
            emotions_non_mimic.unsqueeze(1), encoder_outputs, mask_src)
        if train:
            emotions_mimic, emotions_non_mimic, mu_positive_posterior, logvar_positive_posterior, mu_negative_posterior, logvar_negative_posterior = \
                self.vae_sampler.forward_train(q_h, batch['program_label'], self.emoji_embedding, M_out=m_out.mean(dim=1), M_tilde_out=m_tilde_out.mean(dim=1))
            KLLoss_positive = self.vae_sampler.kl_div(
                mu_positive_posterior, logvar_positive_posterior,
                mu_positive_prior, logvar_positive_prior)
            KLLoss_negative = self.vae_sampler.kl_div(
                mu_negative_posterior, logvar_negative_posterior,
                mu_negative_prior, logvar_negative_prior)
            KLLoss = KLLoss_positive + KLLoss_negative
        else:
            KLLoss_positive = self.vae_sampler.kl_div(mu_positive_prior,
                                                      logvar_positive_prior)
            KLLoss_negative = self.vae_sampler.kl_div(mu_negative_prior,
                                                      logvar_negative_prior)
            KLLoss = KLLoss_positive + KLLoss_negative

        if config.emo_combine == "att":
            v = self.cdecoder(encoder_outputs, m_out, m_tilde_out, mask_src)
        elif config.emo_combine == "gate":
            v = self.cdecoder(m_out, m_tilde_out)

        x = self.s_weight(q_h)

        # method2: E (W@c)
        logit_prob = torch.matmul(x, self.emoji_embedding.weight.transpose(
            0, 1))  # shape (b_size, 32)

        # Decode
        sos_token = torch.LongTensor([config.SOS_idx] *
                                     enc_batch.size(0)).unsqueeze(1)
        if config.USE_CUDA: sos_token = sos_token.cuda()
        dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1)

        mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)

        pre_logit, attn_dist = self.decoder(self.embedding(dec_batch_shift), v,
                                            v, (mask_src, mask_trg))

        ## compute output dist
        logit = self.generator(
            pre_logit,
            attn_dist,
            enc_batch_extend_vocab if config.pointer_gen else None,
            extra_zeros,
            attn_dist_db=None)
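
        # the gate below keeps config.oracle on with probability
        # 0.0001 + (1 - 0.0001) * exp(-iter / config.schedule), i.e. it decays over training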

        if (train and config.schedule > 10):
            if (random.uniform(0, 1) <=
                (0.0001 +
                 (1 - 0.0001) * math.exp(-1. * iter / config.schedule))):
                config.oracle = True
            else:
                config.oracle = False

        if config.softmax:
            program_label = torch.LongTensor(batch['program_label'])
            if config.USE_CUDA: program_label = program_label.cuda()

            L1_loss = nn.CrossEntropyLoss()(logit_prob, program_label)
            loss = self.criterion(
                logit.contiguous().view(-1, logit.size(-1)),
                dec_batch.contiguous().view(-1)) + KLLoss + L1_loss

            loss_bce_program = L1_loss.item()
        else:
            target_program = torch.FloatTensor(batch['target_program'])
            if config.USE_CUDA: target_program = target_program.cuda()
            bce_program = nn.BCEWithLogitsLoss()(logit_prob, target_program)
            loss = self.criterion(
                logit.contiguous().view(-1, logit.size(-1)),
                dec_batch.contiguous().view(-1)) + bce_program
            loss_bce_program = bce_program.item()
        pred_program = np.argmax(logit_prob.detach().cpu().numpy(), axis=1)
        program_acc = accuracy_score(batch["program_label"], pred_program)

        if (config.label_smoothing):
            loss_ppl = self.criterion_ppl(
                logit.contiguous().view(-1, logit.size(-1)),
                dec_batch.contiguous().view(-1)).item()

        if (train):
            loss.backward()
            self.optimizer.step()

        if (config.label_smoothing):
            return loss_ppl, math.exp(min(loss_ppl,
                                          100)), loss_bce_program, program_acc
        else:
            return loss.item(), math.exp(min(
                loss.item(), 100)), loss_bce_program, program_acc

    def compute_act_loss(self, module):
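        # ACT ponder-cost regularizer (adaptive computation time): remainders plus the
        # number of updates per position, averaged over positions and batch.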
        R_t = module.remainders
        N_t = module.n_updates
        p_t = R_t + N_t
        avg_p_t = torch.sum(torch.sum(p_t, dim=1) / p_t.size(1)) / p_t.size(0)
        loss = config.act_loss_weight * avg_p_t.item()
        return loss

    def decoder_greedy(self,
                       batch,
                       max_dec_step=30,
                       emotion_classifier='built_in'):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)

        emotions = batch['program_label']

        ## Encode
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)

        emb_mask = self.embedding(batch["mask_input"])
        encoder_outputs = self.encoder(
            self.embedding(enc_batch) + emb_mask, mask_src)

        q_h = torch.mean(encoder_outputs,
                         dim=1) if config.mean_query else encoder_outputs[:, 0]

        # method 2
        x = self.s_weight(q_h)
        logit_prob = torch.matmul(x,
                                  self.emoji_embedding.weight.transpose(0, 1))
        emo_pred = torch.argmax(logit_prob, dim=-1)

        if emotion_classifier == "vader":
            context_emo = [
                self.positive_emotions[0]
                if d['compound'] > 0 else self.negative_emotions[0]
                for d in batch['context_emotion_scores']
            ]
            context_emo = torch.Tensor(context_emo)
            if config.USE_CUDA:
                context_emo = context_emo.cuda()
            emotions_mimic, emotions_non_mimic, mu_p, logvar_p, mu_n, logvar_n = self.vae_sampler(
                q_h, context_emo, self.emoji_embedding)
        elif emotion_classifier is None:
            emotions_mimic, emotions_non_mimic, mu_p, logvar_p, mu_n, logvar_n = self.vae_sampler(
                q_h, batch['program_label'], self.emoji_embedding)
        elif emotion_classifier == "built_in":
            emotions_mimic, emotions_non_mimic, mu_p, logvar_p, mu_n, logvar_n = self.vae_sampler(
                q_h, emo_pred, self.emoji_embedding)

        m_out = self.emotion_input_encoder_1(emotions_mimic.unsqueeze(1),
                                             encoder_outputs, mask_src)
        m_tilde_out = self.emotion_input_encoder_2(
            emotions_non_mimic.unsqueeze(1), encoder_outputs, mask_src)

        if config.emo_combine == "att":
            v = self.cdecoder(encoder_outputs, m_out, m_tilde_out, mask_src)
            # v = self.cdecoder(encoder_outputs, m_out, m_tilde_out, mask_src_chosen)
        elif config.emo_combine == "gate":
            v = self.cdecoder(m_out, m_tilde_out)
        elif config.emo_combine == 'vader':
            # assumption: the per-example VADER compound scores attached to the batch
            # act as the mimicry weight
            context_emo_scores = torch.FloatTensor(
                [d['compound'] for d in batch['context_emotion_scores']])
            if config.USE_CUDA: context_emo_scores = context_emo_scores.cuda()
            m_weight = context_emo_scores.unsqueeze(-1).unsqueeze(-1)
            m_tilde_weight = 1 - m_weight
            v = m_weight * m_out + m_tilde_weight * m_tilde_out

        ys = torch.ones(1, 1).fill_(config.SOS_idx).long()
        if config.USE_CUDA:
            ys = ys.cuda()
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        decoded_words = []
        for i in range(max_dec_step + 1):
            if (config.project):
                out, attn_dist = self.decoder(
                    self.embedding_proj_in(self.embedding(ys)),
                    self.embedding_proj_in(encoder_outputs),
                    self.embedding_proj_in(v), (mask_src, mask_trg),
                    attention_parameters)
            else:
                out, attn_dist = self.decoder(self.embedding(ys), v, v,
                                              (mask_src, mask_trg))

            logit = self.generator(out,
                                   attn_dist,
                                   enc_batch_extend_vocab,
                                   extra_zeros,
                                   attn_dist_db=None)
            _, next_word = torch.max(logit[:, -1], dim=1)
            decoded_words.append([
                '<EOS>' if ni.item() == config.EOS_idx else
                self.vocab.index2word[ni.item()] for ni in next_word.view(-1)
            ])
            next_word = next_word.data[0]
            if config.USE_CUDA:
                ys = torch.cat(
                    [ys, torch.ones(1, 1).long().fill_(next_word).cuda()],
                    dim=1)
                ys = ys.cuda()
            else:
                ys = torch.cat(
                    [ys, torch.ones(1, 1).long().fill_(next_word)], dim=1)
            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

        sent = []
        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>':
                    break
                else:
                    st += e + ' '
            sent.append(st)
        return sent, batch['context_emotion_scores'][0]['compound'], int(
            emo_pred[0].data.cpu())

    def decoder_topk(self,
                     batch,
                     max_dec_step=30,
                     emotion_classifier='built_in'):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)

        emotions = batch['program_label']

        context_emo = [
            self.positive_emotions[0]
            if d['compound'] > 0 else self.negative_emotions[0]
            for d in batch['context_emotion_scores']
        ]
        context_emo = torch.Tensor(context_emo)
        if config.USE_CUDA:
            context_emo = context_emo.cuda()

        ## Encode
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)

        emb_mask = self.embedding(batch["mask_input"])
        encoder_outputs = self.encoder(
            self.embedding(enc_batch) + emb_mask, mask_src)

        q_h = torch.mean(encoder_outputs,
                         dim=1) if config.mean_query else encoder_outputs[:, 0]

        x = self.s_weight(q_h)
        # method 2
        logit_prob = torch.matmul(x,
                                  self.emoji_embedding.weight.transpose(0, 1))

        if emotion_classifier == "vader":
            emotions_mimic, emotions_non_mimic, mu_p, logvar_p, mu_n, logvar_n = self.vae_sampler(
                q_h, context_emo, self.emoji_embedding)
        elif emotion_classifier is None:
            emotions_mimic, emotions_non_mimic, mu_p, logvar_p, mu_n, logvar_n = self.vae_sampler(
                q_h, batch['program_label'], self.emoji_embedding)
        elif emotion_classifier == "built_in":
            emo_pred = torch.argmax(logit_prob, dim=-1)
            emotions_mimic, emotions_non_mimic, mu_p, logvar_p, mu_n, logvar_n = self.vae_sampler(
                q_h, emo_pred, self.emoji_embedding)

        m_out = self.emotion_input_encoder_1(emotions_mimic.unsqueeze(1),
                                             encoder_outputs, mask_src)
        m_tilde_out = self.emotion_input_encoder_2(
            emotions_non_mimic.unsqueeze(1), encoder_outputs, mask_src)

        if config.emo_combine == "att":
            v = self.cdecoder(encoder_outputs, m_out, m_tilde_out, mask_src)
        elif config.emo_combine == "gate":
            v = self.cdecoder(m_out, m_tilde_out)
        elif config.emo_combine == 'vader':
            # assumption: the per-example VADER compound scores attached to the batch
            # act as the mimicry weight
            context_emo_scores = torch.FloatTensor(
                [d['compound'] for d in batch['context_emotion_scores']])
            if config.USE_CUDA: context_emo_scores = context_emo_scores.cuda()
            m_weight = context_emo_scores.unsqueeze(-1).unsqueeze(-1)
            m_tilde_weight = 1 - m_weight
            v = m_weight * m_out + m_tilde_weight * m_tilde_out

        ys = torch.ones(1, 1).fill_(config.SOS_idx).long()
        if config.USE_CUDA:
            ys = ys.cuda()
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        decoded_words = []
        for i in range(max_dec_step + 1):
            if (config.project):
                out, attn_dist = self.decoder(
                    self.embedding_proj_in(self.embedding(ys)),
                    self.embedding_proj_in(encoder_outputs),
                    (mask_src, mask_trg), attention_parameters)
            else:
                out, attn_dist = self.decoder(self.embedding(ys), v, v,
                                              (mask_src, mask_trg))

            logit = self.generator(out,
                                   attn_dist,
                                   enc_batch_extend_vocab,
                                   extra_zeros,
                                   attn_dist_db=None)
            filtered_logit = top_k_top_p_filtering(logit[:, -1],
                                                   top_k=3,
                                                   top_p=0,
                                                   filter_value=-float('Inf'))
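            # top_p=0 leaves the nucleus filter inactive (assuming the usual
            # top_k_top_p_filtering helper), so this is plain top-k sampling with k=3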
            # Sample from the filtered distribution
            next_word = torch.multinomial(F.softmax(filtered_logit, dim=-1),
                                          1).squeeze()
            decoded_words.append([
                '<EOS>' if ni.item() == config.EOS_idx else
                self.vocab.index2word[ni.item()] for ni in next_word.view(-1)
            ])
            next_word = next_word.data.item()

            if config.USE_CUDA:
                ys = torch.cat(
                    [ys, torch.ones(1, 1).long().fill_(next_word).cuda()],
                    dim=1)
                ys = ys.cuda()
            else:
                ys = torch.cat(
                    [ys, torch.ones(1, 1).long().fill_(next_word)], dim=1)
            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

        sent = []
        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>':
                    break
                else:
                    st += e + ' '
            sent.append(st)
        return sent
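# --- Usage sketch (illustrative, not part of the original example) ---
# Assumes `vocab` and a `data_loader` yielding batches in the format consumed by
# get_input_from_batch / get_output_from_batch above (including "mask_input",
# "program_label" and "target_program"); decoder_number=32 matches the 32-way
# emoji_embedding used by this model.
model = Train_MIME(vocab, decoder_number=32)
for it, batch in enumerate(data_loader):
    loss, ppl, bce_program, program_acc = model.train_one_batch(batch, it, train=True)
    if it % 1000 == 0:
        model.save_model(ppl, it, 0.0, 0.0, 0.0, 0.0, 0.0)  # placeholder eval metrics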
Example #10
0
class woRCR(nn.Module):
    def __init__(self,
                 vocab,
                 model_file_path=None,
                 is_eval=False,
                 load_optim=False):
        super(woRCR, self).__init__()
        self.vocab = vocab
        self.vocab_size = vocab.n_words

        self.embedding = share_embedding(self.vocab, config.pretrain_emb)
        self.encoder = Encoder(config.emb_dim,
                               config.hidden_dim,
                               num_layers=config.hop,
                               num_heads=config.heads,
                               total_key_depth=config.depth,
                               total_value_depth=config.depth,
                               filter_size=config.filter,
                               universal=config.universal)

        self.sse = SSE(vocab, config.emb_dim, config.dropout,
                       config.rnn_hidden_dim)
        self.rcr = RCR()

        self.decoder = Decoder(config.emb_dim,
                               hidden_size=config.hidden_dim,
                               num_layers=config.hop,
                               num_heads=config.heads,
                               total_key_depth=config.depth,
                               total_value_depth=config.depth,
                               filter_size=config.filter)

        self.generator = Generator(config.hidden_dim, self.vocab_size)

        if config.weight_sharing:
            # Share the weight matrix between target word embedding & the final logit dense layer
            self.generator.proj.weight = self.embedding.lut.weight

        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)
        if config.label_smoothing:
            self.criterion = LabelSmoothing(size=self.vocab_size,
                                            padding_idx=config.PAD_idx,
                                            smoothing=0.1)
            self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx)

        self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
        if config.noam:
            self.optimizer = NoamOpt(
                config.hidden_dim, 1, 8000,
                torch.optim.Adam(self.parameters(),
                                 lr=0,
                                 betas=(0.9, 0.98),
                                 eps=1e-9))

        if model_file_path is not None:
            print("loading weights")
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            self.encoder.load_state_dict(state['encoder_state_dict'])
            self.decoder.load_state_dict(state['decoder_state_dict'])
            self.generator.load_state_dict(state['generator_dict'])
            self.embedding.load_state_dict(state['embedding_dict'])
            if load_optim:
                self.optimizer.load_state_dict(state['optimizer'])
            self.eval()

        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""

    def save_model(self, running_avg_ppl, iter, f1_g, f1_b, ent_g, ent_b):

        state = {
            'iter': iter,
            'encoder_state_dict': self.encoder.state_dict(),
            'decoder_state_dict': self.decoder.state_dict(),
            'generator_dict': self.generator.state_dict(),
            'embedding_dict': self.embedding.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_ppl
        }
        model_save_path = os.path.join(
            self.model_dir,
            'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(
                iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b))
        self.best_path = model_save_path
        torch.save(state, model_save_path)

    def train_one_batch(self, batch, iter, train=True):
        enc_batch = batch["review_batch"]
        enc_batch_extend_vocab = batch["review_ext_batch"]

        src_batch = batch[
            'reviews_batch']  # reviews sequence (bsz, r_num, r_len)
        src_mask = batch[
            'reviews_mask']  # marks which reviews are padding entries, (bsz, r_num)
        src_length = batch['reviews_length']  # (bsz, r_num)
        enc_length_batch = batch[
            'reviews_length_list']  # nested list: outer length = bsz, inner = lengths of reviews and pads

        src_labels = batch['reviews_label']  # (bsz, r_num)

        oovs = batch["oovs"]
        max_oov_length = max(len(oov) for oov in oovs)
        extra_zeros = Variable(torch.zeros(
            (enc_batch.size(0), max_oov_length))).to(config.device)
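        # extra_zeros reserves probability slots for per-example OOV words so the
        # pointer-generator can copy them from the source reviews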

        dec_batch = batch["tags_batch"]
        dec_ext_batch = batch["tags_ext_batch"]
        tid_batch = batch[
            'tags_idx_batch']  # tag indexes sequence (bsz, tgt_len)

        if config.noam:
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        # 1. Sentence-level Salience Estimation (SSE)
        cla_loss, sa_scores, sa_acc = self.sse.salience_estimate(
            src_batch, src_mask, src_length,
            src_labels)  # sa_scores: (bsz, r_num)

        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(
            1)  # (bsz, src_len)->(bsz, 1, src_len)
        # emb_mask = self.embedding(batch["mask_context"])
        # src_emb = self.embedding(enc_batch)+emb_mask
        src_emb = self.embedding(enc_batch)
        encoder_outputs = self.encoder(src_emb,
                                       mask_src)  # (bsz, src_len, emb_dim)

        sos_token = torch.LongTensor([config.SOS_idx] *
                                     enc_batch.size(0)).unsqueeze(1).to(
                                         config.device)  # (bsz, 1)
        dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]),
                                    1)  # (bsz, tgt_len)

        mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
        pre_logit, attn_dist, aln_loss = self.decoder(
            inputs=self.embedding(dec_batch_shift),
            inputs_rank=tid_batch,
            encoder_output=encoder_outputs,
            mask=(mask_src, mask_trg))

        logit = self.generator(
            pre_logit, attn_dist,
            enc_batch_extend_vocab if config.pointer_gen else None,
            extra_zeros)
        if config.pointer_gen:
            loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                                  dec_ext_batch.contiguous().view(-1))
        else:
            loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                                  dec_batch.contiguous().view(-1))

        if config.label_smoothing:
            loss_ppl = self.criterion_ppl(
                logit.contiguous().view(-1, logit.size(-1)),
                dec_ext_batch.contiguous().view(-1)
                if config.pointer_gen else dec_batch.contiguous().view(-1))
        loss = loss + cla_loss
        if torch.isnan(loss).sum().item() != 0 or torch.isinf(
                loss).sum().item() != 0:
            print("check")
            pdb.set_trace()
        if train:
            loss.backward()
            self.optimizer.step()

        if config.label_smoothing:
            loss_ppl = loss_ppl.item()
            cla_loss = cla_loss.item()
            return loss_ppl, math.exp(min(loss_ppl, 100)), cla_loss, sa_acc
        else:
            return loss.item(), math.exp(min(loss.item(),
                                             100)), cla_loss, sa_acc

    def decoder_greedy(self, batch, max_dec_step=30):
        enc_batch = batch["review_batch"]
        enc_batch_extend_vocab = batch["review_ext_batch"]

        src_batch = batch[
            'reviews_batch']  # reviews sequence (bsz, r_num, r_len)
        src_mask = batch[
            'reviews_mask']  # marks which reviews are padding entries, (bsz, r_num)
        src_length = batch['reviews_length']  # (bsz, r_num)
        enc_length_batch = batch[
            'reviews_length_list']  # nested list: outer length = bsz, inner = lengths of reviews and pads

        src_labels = batch['reviews_label']  # (bsz, r_num)

        oovs = batch["oovs"]
        max_oov_length = max(len(oov) for oov in oovs)
        extra_zeros = Variable(torch.zeros(
            (enc_batch.size(0), max_oov_length))).to(config.device)

        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(
            1)  # (bsz, src_len)->(bsz, 1, src_len)
        # emb_mask = self.embedding(batch["mask_context"])
        # src_emb = self.embedding(enc_batch) + emb_mask  # todo eos or sentence embedding??
        src_emb = self.embedding(enc_batch)
        encoder_outputs = self.encoder(src_emb,
                                       mask_src)  # (bsz, src_len, emb_dim)

        ys = torch.zeros(enc_batch.size(0), 1).fill_(config.SOS_idx).long().to(
            config.device)  # when testing, we set bsz into 1
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

        ys_rank = torch.ones(enc_batch.size(0), 1).long().to(config.device)
        last_rank = torch.ones(enc_batch.size(0), 1).long().to(config.device)

        decoded_words = []
        for i in range(max_dec_step + 1):
            if config.project:
                out, attn_dist, aln_loss = self.decoder(
                    inputs=self.embedding_proj_in(self.embedding(ys)),
                    inputs_rank=ys_rank,
                    encoder_output=self.embedding_proj_in(encoder_outputs),
                    mask=(mask_src, mask_trg))
            else:
                out, attn_dist, aln_loss = self.decoder(
                    inputs=self.embedding(ys),
                    inputs_rank=ys_rank,
                    encoder_output=encoder_outputs,
                    mask=(mask_src, mask_trg))
            prob = self.generator(
                out, attn_dist,
                enc_batch_extend_vocab if config.pointer_gen else None,
                extra_zeros)

            _, next_word = torch.max(prob[:, -1], dim=1)  # bsz=1
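            # indices >= self.vocab.n_words point into the per-example OOV list
            # (pointer-generator extended vocabulary)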
            cur_words = []
            for i_batch, ni in enumerate(next_word.view(-1)):
                if ni.item() == config.EOS_idx:
                    cur_words.append('<EOS>')
                    last_rank[i_batch] = 0
                elif ni.item() in self.vocab.index2word:
                    cur_words.append(self.vocab.index2word[ni.item()])
                    if ni.item() == config.SOS_idx:
                        last_rank[i_batch] += 1
                else:
                    cur_words.append(oovs[i_batch][ni.item() -
                                                   self.vocab.n_words])
            decoded_words.append(cur_words)
            next_word = next_word.data[0]

            if next_word.item() not in self.vocab.index2word:
                next_word = torch.tensor(config.UNK_idx)

            ys = torch.cat([
                ys,
                torch.zeros(enc_batch.size(0), 1).long().fill_(next_word).to(
                    config.device)
            ],
                           dim=1).to(config.device)
            ys_rank = torch.cat([ys_rank, last_rank], dim=1).to(config.device)
            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

        sent = []
        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>':
                    break
                else:
                    st += e + ' '
            sent.append(st)
        return sent
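# --- Inference sketch (illustrative, not part of the original example) ---
# Assumes `vocab`, a previously saved checkpoint path, and a `batch` carrying the
# review-style fields used above (review_batch, review_ext_batch, reviews_*, oovs).
model = woRCR(vocab, model_file_path='path/to/checkpoint', is_eval=True)  # hypothetical path
with torch.no_grad():
    tag_strings = model.decoder_greedy(batch, max_dec_step=30)  # one decoded string per example
print(tag_strings[0])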