Example #1
class Transformer(nn.Module):

    def __init__(self, vocab, model_file_path=None, is_eval=False, load_optim=False):
        super(Transformer, self).__init__()
        self.vocab = vocab
        self.vocab_size = vocab.n_words

        self.embedding = share_embedding(self.vocab,config.preptrained)
        self.encoder = Encoder(config.emb_dim, config.hidden_dim, num_layers=config.hop, num_heads=config.heads, 
                                total_key_depth=config.depth, total_value_depth=config.depth,
                                filter_size=config.filter,universal=config.universal)
            
        self.decoder = Decoder(config.emb_dim, config.hidden_dim, num_layers=config.hop, num_heads=config.heads, 
                                total_key_depth=config.depth,total_value_depth=config.depth,
                                filter_size=config.filter,universal=config.universal)
        self.generator = Generator(config.hidden_dim,self.vocab_size)

        if config.weight_sharing:
            # Share the weight matrix between target word embedding & the final logit dense layer
            self.generator.proj.weight = self.embedding.weight

        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)
        if (config.label_smoothing):
            self.criterion = LabelSmoothing(size=self.vocab_size, padding_idx=config.PAD_idx, smoothing=0.1)
            self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx)
        if is_eval:
            self.encoder = self.encoder.eval()
            self.decoder = self.decoder.eval()
            self.generator = self.generator.eval()
            self.embedding = self.embedding.eval()

    
        self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
        if(config.noam):
            self.optimizer = NoamOpt(config.hidden_dim, 1, 4000, torch.optim.Adam(self.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
        if config.use_sgd:
            self.optimizer = torch.optim.SGD(self.parameters(), lr=config.lr)
        if model_file_path is not None:
            print("loading weights")
            state = torch.load(model_file_path, map_location= lambda storage, location: storage)
            print("LOSS",state['current_loss'])
            self.encoder.load_state_dict(state['encoder_state_dict'])
            self.decoder.load_state_dict(state['decoder_state_dict'])
            self.generator.load_state_dict(state['generator_dict'])
            self.embedding.load_state_dict(state['embedding_dict'])
            if (load_optim):
                self.optimizer.load_state_dict(state['optimizer'])

        if (config.USE_CUDA):
            self.encoder = self.encoder.cuda()
            self.decoder = self.decoder.cuda()
            self.generator = self.generator.cuda()
            self.criterion = self.criterion.cuda()
            self.embedding = self.embedding.cuda()
        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""

    def save_model(self, running_avg_ppl, iter, f1_g,f1_b,ent_g,ent_b):
        state = {
            'iter': iter,
            'encoder_state_dict': self.encoder.state_dict(),
            'decoder_state_dict': self.decoder.state_dict(),
            'generator_dict': self.generator.state_dict(),
            'embedding_dict': self.embedding.state_dict(),
            #'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_ppl
        }
        model_save_path = os.path.join(self.model_dir, 'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(iter,running_avg_ppl,f1_g,f1_b,ent_g,ent_b) )
        self.best_path = model_save_path
        torch.save(state, model_save_path)

    def train_one_batch(self, batch, train=True):
        ## pad and other stuff
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)
        dec_batch, _, _, _, _ = get_output_from_batch(batch)
        
        if(config.noam):
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        ## Encode
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        encoder_outputs = self.encoder(self.embedding(enc_batch),mask_src)

        # Decode 
        sos_token = torch.LongTensor([config.SOS_idx] * enc_batch.size(0)).unsqueeze(1)
        if config.USE_CUDA: sos_token = sos_token.cuda()
        dec_batch_shift = torch.cat((sos_token,dec_batch[:, :-1]),1)

        mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
        pre_logit, attn_dist = self.decoder(self.embedding(dec_batch_shift),encoder_outputs, (mask_src,mask_trg))
        ## compute output dist
        logit = self.generator(pre_logit,attn_dist,enc_batch_extend_vocab, extra_zeros)
        
        ## loss: NLL if pointer-generator else cross entropy
        loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1))

        if(train):
            loss.backward()
            self.optimizer.step()
        if(config.label_smoothing): 
            loss = self.criterion_ppl(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1))
        
        return loss.item(), math.exp(min(loss.item(), 100)), loss
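
A minimal, hypothetical usage sketch for this example (it assumes a `vocab` object exposing `n_words`, an iterable `data_loader` that yields batches in the format expected by `get_input_from_batch`/`get_output_from_batch`, and the same `config` module used above):

# Hypothetical driver loop; `vocab`, `data_loader`, and `n_epochs` are assumptions, not part of the example.
model = Transformer(vocab)
n_epochs = 10
for epoch in range(n_epochs):
    for batch in data_loader:
        # train_one_batch returns (NLL value, perplexity, loss tensor)
        nll, ppl, _ = model.train_one_batch(batch, train=True)
    print("epoch {}: nll {:.4f}, ppl {:.2f}".format(epoch, nll, ppl))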
Example #2
class Seq2SPG(nn.Module):
    def __init__(self,
                 vocab,
                 model_file_path=None,
                 is_eval=False,
                 load_optim=False):
        super(Seq2SPG, self).__init__()
        self.vocab = vocab
        self.vocab_size = vocab.n_words

        self.embedding = share_embedding(self.vocab, config.preptrained)
        self.encoder = nn.LSTM(config.emb_dim,
                               config.hidden_dim,
                               config.hop,
                               bidirectional=False,
                               batch_first=True,
                               dropout=0.2)
        self.encoder2decoder = nn.Linear(config.hidden_dim, config.hidden_dim)
        self.decoder = LSTMAttentionDot(config.emb_dim,
                                        config.hidden_dim,
                                        batch_first=True)
        self.memory = MLP(
            config.hidden_dim + config.emb_dim,
            [config.private_dim1, config.private_dim2, config.private_dim3],
            config.hidden_dim)
        self.dec_gate = nn.Linear(config.hidden_dim, 2 * config.hidden_dim)
        self.mem_gate = nn.Linear(config.hidden_dim, 2 * config.hidden_dim)
        self.generator = Generator(config.hidden_dim, self.vocab_size)
        # Save each task's model structure as masks over the parameters
        self.hooks = {}
        if config.weight_sharing:
            # Share the weight matrix between target word embedding & the final logit dense layer
            self.generator.proj.weight = self.embedding.weight
        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)
        if (config.label_smoothing):
            self.criterion = LabelSmoothing(size=self.vocab_size,
                                            padding_idx=config.PAD_idx,
                                            smoothing=0.1)
            self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx)
        if is_eval:
            self.encoder = self.encoder.eval()
            self.encoder2decoder = self.encoder2decoder.eval()
            self.decoder = self.decoder.eval()
            self.generator = self.generator.eval()
            self.embedding = self.embedding.eval()
            self.memory = self.memory.eval()
            self.dec_gate = self.dec_gate.eval()
            self.mem_gate = self.mem_gate.eval()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
        if (config.noam):
            self.optimizer = NoamOpt(
                config.hidden_dim, 1, 4000,
                torch.optim.Adam(self.parameters(),
                                 lr=0,
                                 betas=(0.9, 0.98),
                                 eps=1e-9))
        if config.use_sgd:
            self.optimizer = torch.optim.SGD(self.parameters(), lr=config.lr)
        if model_file_path is not None:
            print("loading weights")
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            print("LOSS", state['current_loss'])
            self.encoder.load_state_dict(state['encoder_state_dict'])
            self.encoder2decoder.load_state_dict(
                state['encoder2decoder_state_dict'])
            self.decoder.load_state_dict(state['decoder_state_dict'])
            self.generator.load_state_dict(state['generator_dict'])
            self.embedding.load_state_dict(state['embedding_dict'])
            self.memory.load_state_dict(state['memory_dict'])
            self.dec_gate.load_state_dict(state['dec_gate_dict'])
            self.mem_gate.load_state_dict(state['mem_gate_dict'])
            if (load_optim):
                self.optimizer.load_state_dict(state['optimizer'])

        if (config.USE_CUDA):
            self.encoder = self.encoder.cuda()
            self.encoder2decoder = self.encoder2decoder.cuda()
            self.decoder = self.decoder.cuda()
            self.generator = self.generator.cuda()
            self.criterion = self.criterion.cuda()
            self.embedding = self.embedding.cuda()
            self.memory = self.memory.cuda()
            self.dec_gate = self.dec_gate.cuda()
            self.mem_gate = self.mem_gate.cuda()
        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""

    def save_model(self,
                   running_avg_ppl,
                   iter,
                   f1_g,
                   f1_b,
                   ent_g,
                   ent_b,
                   log=False,
                   d="tmaml_sim_model"):
        state = {
            'iter': iter,
            'encoder_state_dict': self.encoder.state_dict(),
            'encoder2decoder_state_dict': self.encoder2decoder.state_dict(),
            'decoder_state_dict': self.decoder.state_dict(),
            'generator_dict': self.generator.state_dict(),
            'embedding_dict': self.embedding.state_dict(),
            'memory_dict': self.memory.state_dict(),
            'dec_gate_dict': self.dec_gate.state_dict(),
            'mem_gate_dict': self.mem_gate.state_dict(),
            #'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_ppl
        }
        if log:
            model_save_path = os.path.join(
                d, 'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(
                    iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b))
        else:
            model_save_path = os.path.join(
                self.model_dir,
                'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(
                    iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b))
        self.best_path = model_save_path
        torch.save(state, model_save_path)

    def get_state(self, batch):
        """Get cell states and hidden states for LSTM"""
        batch_size = batch.size(0) \
            if self.encoder.batch_first else batch.size(1)
        h0_encoder = Variable(torch.zeros(self.encoder.num_layers, batch_size,
                                          config.hidden_dim),
                              requires_grad=False)
        c0_encoder = Variable(torch.zeros(self.encoder.num_layers, batch_size,
                                          config.hidden_dim),
                              requires_grad=False)

        if config.USE_CUDA:
            return h0_encoder.cuda(), c0_encoder.cuda()
        return h0_encoder, c0_encoder

    def compute_hooks(self, task):
        """Compute the masks of the private module"""
        current_layer = 3
        out_mask = torch.ones(self.memory.output_size)
        self.hooks[task] = {}
        self.hooks[task]["w_hooks"] = {}
        self.hooks[task]["b_hooks"] = {}
        while (current_layer >= 0):
            connections = self.memory.layers[current_layer].weight.data
            output_size, input_size = connections.shape
            mask = connections.abs() > 0.05
            in_mask = torch.zeros(input_size)
            for index, line in enumerate(mask):
                if (out_mask[index] == 1):
                    torch.max(in_mask, (line.cpu() != 0).float(), out=in_mask)
            if (config.USE_CUDA):
                self.hooks[task]["b_hooks"][current_layer] = out_mask.cuda()
                self.hooks[task]["w_hooks"][current_layer] = torch.mm(
                    out_mask.unsqueeze(1), in_mask.unsqueeze(0)).cuda()
            else:
                self.hooks[task]["b_hooks"][current_layer] = out_mask
                self.hooks[task]["w_hooks"][current_layer] = torch.mm(
                    out_mask.unsqueeze(1), in_mask.unsqueeze(0))
            out_mask = in_mask
            current_layer -= 1

    def register_hooks(self, task):
        if "hook_handles" not in self.hooks[task]:
            self.hooks[task]["hook_handles"] = []
        for i, l in enumerate(self.memory.layers):
            self.hooks[task]["hook_handles"].append(
                l.bias.register_hook(make_hook(
                    self.hooks[task]["b_hooks"][i])))
            self.hooks[task]["hook_handles"].append(
                l.weight.register_hook(
                    make_hook(self.hooks[task]["w_hooks"][i])))

    def unhook(self, task):
        for handle in self.hooks[task]["hook_handles"]:
            handle.remove()
        self.hooks[task]["hook_handles"] = []

    def train_one_batch(self, batch, train=True, mode="pretrain", task=0):
        enc_batch, _, enc_lens, enc_batch_extend_vocab, extra_zeros, _, _, _ = get_input_from_batch(
            batch)
        dec_batch, _, _, _, _, _ = get_output_from_batch(batch)

        if (config.noam):
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        ## Encode
        self.h0_encoder, self.c0_encoder = self.get_state(enc_batch)
        src_h, (src_h_t,
                src_c_t) = self.encoder(self.embedding(enc_batch),
                                        (self.h0_encoder, self.c0_encoder))
        h_t = src_h_t[-1]
        c_t = src_c_t[-1]

        # Decode
        decoder_init_state = nn.Tanh()(self.encoder2decoder(h_t))

        sos_token = torch.LongTensor([config.SOS_idx] *
                                     enc_batch.size(0)).unsqueeze(1)
        if config.USE_CUDA: sos_token = sos_token.cuda()
        dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1)
        target_embedding = self.embedding(dec_batch_shift)
        ctx = src_h.transpose(0, 1)
        trg_h, (_, _) = self.decoder(target_embedding,
                                     (decoder_init_state, c_t), ctx)

        #Memory
        mem_h_input = torch.cat(
            (decoder_init_state.unsqueeze(1), trg_h[:, 0:-1, :]), 1)
        mem_input = torch.cat((target_embedding, mem_h_input), 2)
        mem_output = self.memory(mem_input)

        #Combine
        gates = self.dec_gate(trg_h) + self.mem_gate(mem_output)
        decoder_gate, memory_gate = gates.chunk(2, 2)
        decoder_gate = torch.sigmoid(decoder_gate)
        memory_gate = torch.sigmoid(memory_gate)
        pre_logit = torch.tanh(decoder_gate * trg_h + memory_gate * mem_output)
        logit = self.generator(pre_logit)

        if mode == "pretrain":
            loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                                  dec_batch.contiguous().view(-1))
            if train:
                loss.backward()
                self.optimizer.step()
            if (config.label_smoothing):
                loss = self.criterion_ppl(
                    logit.contiguous().view(-1, logit.size(-1)),
                    dec_batch.contiguous().view(-1))
            return loss.item(), math.exp(min(loss.item(), 100)), loss

        elif mode == "select":
            loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                                  dec_batch.contiguous().view(-1))
            if (train):
                l1_loss = 0.0
                for p in self.memory.parameters():
                    l1_loss += torch.sum(torch.abs(p))
                loss += 0.0005 * l1_loss
                loss.backward()
                self.optimizer.step()
                self.compute_hooks(task)
            if (config.label_smoothing):
                loss = self.criterion_ppl(
                    logit.contiguous().view(-1, logit.size(-1)),
                    dec_batch.contiguous().view(-1))
            return loss.item(), math.exp(min(loss.item(), 100)), loss

        else:
            loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                                  dec_batch.contiguous().view(-1))
            if (train):
                self.register_hooks(task)
                loss.backward()
                self.optimizer.step()
                self.unhook(task)
            if (config.label_smoothing):
                loss = self.criterion_ppl(
                    logit.contiguous().view(-1, logit.size(-1)),
                    dec_batch.contiguous().view(-1))
            return loss.item(), math.exp(min(loss.item(), 100)), loss
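
The `register_hooks` method above passes each mask to a `make_hook` helper that is not included in this example. A minimal sketch consistent with how it is called (returning a gradient hook that keeps only the connections selected by `compute_hooks`) might look like the following; the actual helper in the source codebase may differ:

def make_hook(mask):
    # Gradient hook: zero the gradient wherever mask == 0 so that only the
    # task-specific sub-network chosen by compute_hooks keeps updating.
    def hook(grad):
        return grad * mask
    return hook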
Example #3
class VAE(nn.Module):
    def __init__(self, vocab, model_file_path=None, is_eval=False, load_optim=False):
        super(VAE, self).__init__()
        self.vocab = vocab
        self.vocab_size = vocab.n_words

        self.embedding = share_embedding(self.vocab,config.preptrained)
        self.encoder = nn.LSTM(config.emb_dim, config.hidden_dim, config.hop, bidirectional=False, batch_first=True,
                               dropout=0.2)
        self.encoder_r = nn.LSTM(config.emb_dim, config.hidden_dim, config.hop, bidirectional=False, batch_first=True,
                               dropout=0.2)
        self.represent = R_MLP(2 * config.hidden_dim, 68)
        self.prior = P_MLP(config.hidden_dim, 68)
        self.mlp_b = nn.Linear(config.hidden_dim + 68, self.vocab_size)
        self.encoder2decoder = nn.Linear(
            config.hidden_dim + 68,
            config.hidden_dim)
        self.decoder = LSTMAttentionDot(config.emb_dim, config.hidden_dim, batch_first=True)   
        self.generator = Generator(config.hidden_dim,self.vocab_size)

        if config.weight_sharing:
            # Share the weight matrix between target word embedding & the final logit dense layer
            self.generator.proj.weight = self.embedding.weight

        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)
        if (config.label_smoothing):
            self.criterion = LabelSmoothing(size=self.vocab_size, padding_idx=config.PAD_idx, smoothing=0.1)
            self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx)
        if is_eval:
            self.encoder = self.encoder.eval()
            self.encoder_r = self.encoder_r.eval()
            self.represent = self.represent.eval()
            self.prior = self.prior.eval()
            self.mlp_b = self.mlp_b.eval()
            self.encoder2decoder = self.encoder2decoder.eval()
            self.decoder = self.decoder.eval()
            self.generator = self.generator.eval()
            self.embedding = self.embedding.eval()

    
        self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
        if(config.noam):
            self.optimizer = NoamOpt(config.hidden_dim, 1, 4000, torch.optim.Adam(self.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
        if config.use_sgd:
            self.optimizer = torch.optim.SGD(self.parameters(), lr=config.lr)
        if model_file_path is not None:
            print("loading weights")
            state = torch.load(model_file_path, map_location= lambda storage, location: storage)
            print("LOSS",state['current_loss'])
            self.encoder.load_state_dict(state['encoder_state_dict'])
            self.encoder_r.load_state_dict(state['encoder_r_state_dict'])
            self.represent.load_state_dict(state['represent_state_dict'])
            self.prior.load_state_dict(state['prior_state_dict'])
            self.mlp_b.load_state_dict(state['mlp_b_state_dict'])
            self.encoder2decoder.load_state_dict(state['encoder2decoder_state_dict'])
            self.decoder.load_state_dict(state['decoder_state_dict'])
            self.generator.load_state_dict(state['generator_dict'])
            self.embedding.load_state_dict(state['embedding_dict'])
            if (load_optim):
                self.optimizer.load_state_dict(state['optimizer'])

        if (config.USE_CUDA):
            self.encoder = self.encoder.cuda()
            self.encoder_r = self.encoder_r.cuda()
            self.represent = self.represent.cuda()
            self.prior = self.prior.cuda()
            self.mlp_b = self.mlp_b.cuda()
            self.encoder2decoder = self.encoder2decoder.cuda()
            self.decoder = self.decoder.cuda()
            self.generator = self.generator.cuda()
            self.criterion = self.criterion.cuda()
            self.embedding = self.embedding.cuda()
        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""
    
    def save_model(self, running_avg_ppl, iter, f1_g,f1_b,ent_g,ent_b, log=False, d="save/paml_model_sim"):
        state = {
            'iter': iter,
            'encoder_state_dict': self.encoder.state_dict(),
            'encoder_r_state_dict': self.encoder_r.state_dict(),
            'represent_state_dict': self.represent.state_dict(),
            'prior_state_dict': self.prior.state_dict(),
            'mlp_b_state_dict': self.mlp_b.state_dict(),
            'encoder2decoder_state_dict': self.encoder2decoder.state_dict(),
            'decoder_state_dict': self.decoder.state_dict(),
            'generator_dict': self.generator.state_dict(),
            'embedding_dict': self.embedding.state_dict(),
            #'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_ppl
        }
        if log:
            model_save_path = os.path.join(d, 'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(iter,running_avg_ppl,f1_g,f1_b,ent_g,ent_b) )
        else:
            model_save_path = os.path.join(self.model_dir, 'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(iter,running_avg_ppl,f1_g,f1_b,ent_g,ent_b) )
        self.best_path = model_save_path
        torch.save(state, model_save_path)
    
    def get_state(self, batch):
        """Get cell states and hidden states."""
        batch_size = batch.size(0) \
            if self.encoder.batch_first else batch.size(1)
        h0_encoder = Variable(torch.zeros(
            self.encoder.num_layers,
            batch_size,
            config.hidden_dim
        ), requires_grad=False)
        c0_encoder = Variable(torch.zeros(
            self.encoder.num_layers,
            batch_size,
            config.hidden_dim
        ), requires_grad=False)

        if config.USE_CUDA:
            return h0_encoder.cuda(), c0_encoder.cuda()
        return h0_encoder, c0_encoder
    
    def train_one_batch(self, batch, train=True):
        ## pad and other stuff
        enc_batch, _, enc_lens, enc_batch_extend_vocab, extra_zeros, _, _, _ = get_input_from_batch(batch)
        dec_batch, _, _, _, _, _ = get_output_from_batch(batch)
        
        if(config.noam):
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        ## Encode
        self.h0_encoder, self.c0_encoder = self.get_state(enc_batch)
        src_h, (src_h_t, src_c_t) = self.encoder(
            self.embedding(enc_batch), (self.h0_encoder, self.c0_encoder))
        h_t = src_h_t[-1]
        c_t = src_c_t[-1]
        self.h0_encoder_r, self.c0_encoder_r = self.get_state(dec_batch)
        src_h_r, (src_h_t_r, src_c_t_r) = self.encoder_r(
            self.embedding(dec_batch), (self.h0_encoder_r, self.c0_encoder_r))
        h_t_r = src_h_t_r[-1]
        c_t_r = src_c_t_r[-1]
        
        # sample the latent variable and reparameterize
        z_sample, mu, var = self.represent(torch.cat((h_t_r, h_t), 1))
        p_z_sample, p_mu, p_var = self.prior(h_t)
        
        # Decode
        decoder_init_state = nn.Tanh()(self.encoder2decoder(torch.cat((z_sample, h_t), 1)))
        
        sos_token = torch.LongTensor([config.SOS_idx] * enc_batch.size(0)).unsqueeze(1)
        if config.USE_CUDA: sos_token = sos_token.cuda()
        dec_batch_shift = torch.cat((sos_token,dec_batch[:, :-1]),1)
        target_embedding = self.embedding(dec_batch_shift)
        ctx = src_h.transpose(0, 1)
        trg_h, (_, _) = self.decoder(
            target_embedding,
            (decoder_init_state, c_t),
            ctx    
        )
        pre_logit = trg_h
        logit = self.generator(pre_logit)
        
        ## loss: NLL if pointer-generator else cross entropy
        re_loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1))
        kl_losses = 0.5 * torch.sum(torch.exp(var - p_var) + (mu - p_mu) ** 2 / torch.exp(p_var) - 1. - var + p_var, 1)
        kl_loss = torch.mean(kl_losses)
        latent_logit = self.mlp_b(torch.cat((z_sample, h_t), 1)).unsqueeze(1)
        latent_logit = F.log_softmax(latent_logit,dim=-1)
        latent_logits = latent_logit.repeat(1, logit.size(1), 1)
        bow_loss = self.criterion(latent_logits.contiguous().view(-1, latent_logits.size(-1)), dec_batch.contiguous().view(-1))
        loss = re_loss + 0.48 * kl_loss + bow_loss
        if(train):
            loss.backward()
            self.optimizer.step()
        if (config.label_smoothing):
            s_loss = self.criterion_ppl(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1))
        else:
            # without label smoothing, the reconstruction NLL already serves as the perplexity loss
            s_loss = re_loss

        return s_loss.item(), math.exp(min(s_loss.item(), 100)), loss.item(), re_loss.item(), kl_loss.item(), bow_loss.item()
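
In `train_one_batch` above, `kl_losses` is the closed-form KL divergence between the recognition distribution q = N(mu, exp(var)) and the prior p = N(p_mu, exp(p_var)), where `var` and `p_var` are log-variances and the sum runs over latent dimensions:

    KL(q || p) = 0.5 * sum_d [ exp(var_d - p_var_d) + (mu_d - p_mu_d)^2 * exp(-p_var_d) - 1 - var_d + p_var_d ]

The total objective is re_loss + 0.48 * kl_loss + bow_loss: reconstruction NLL, a KL term with a hard-coded 0.48 weight, and the bag-of-words auxiliary loss computed from the latent code.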
Example #4
class Summarizer(nn.Module):
    def __init__(self,
                 is_draft,
                 toeknizer,
                 model_file_path=None,
                 is_eval=False,
                 load_optim=False):
        super(Summarizer, self).__init__()
        self.is_draft = is_draft
        self.toeknizer = toeknizer
        if is_draft:
            self.encoder = BertModel.from_pretrained('bert-base-uncased')
        else:
            self.encoder = BertForMaskedLM.from_pretrained('bert-base-uncased')
        self.encoder.eval()  # always in eval mode
        self.embedding = self.encoder.embeddings

        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        self.decoder = Decoder(config.emb_dim,
                               config.hidden_dim,
                               num_layers=config.hop,
                               num_heads=config.heads,
                               total_key_depth=config.depth,
                               total_value_depth=config.depth,
                               filter_size=config.filter)

        self.generator = Generator(config.hidden_dim, config.vocab_size)

        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)
        self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx)
        if (config.label_smoothing):
            self.criterion = LabelSmoothing(size=config.vocab_size,
                                            padding_idx=config.PAD_idx,
                                            smoothing=0.1)

        self.embedding = self.embedding.eval()
        if is_eval:
            self.decoder = self.decoder.eval()
            self.generator = self.generator.eval()

        self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
        if (config.noam):
            self.optimizer = NoamOpt(
                config.hidden_dim, 1, 4000,
                torch.optim.Adam(self.parameters(),
                                 lr=0,
                                 betas=(0.9, 0.98),
                                 eps=1e-9))
        if model_file_path is not None:
            print("loading weights")
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            print("LOSS", state['current_loss'])
            self.decoder.load_state_dict(state['decoder_state_dict'])
            self.generator.load_state_dict(state['generator_dict'])
            self.embedding.load_state_dict(state['embedding_dict'])

        if (config.USE_CUDA):
            self.encoder = self.encoder.cuda(device=0)
            self.decoder = self.decoder.cuda(device=0)
            self.generator = self.generator.cuda(device=0)
            self.criterion = self.criterion.cuda(device=0)
            self.embedding = self.embedding.cuda(device=0)
        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""

    def save_model(self, loss, iter, r_avg):
        state = {
            'iter': iter,
            'decoder_state_dict': self.decoder.state_dict(),
            'generator_dict': self.generator.state_dict(),
            'embedding_dict': self.embedding.state_dict(),
            #'optimizer': self.optimizer.state_dict(),
            'current_loss': loss
        }
        model_save_path = os.path.join(
            self.model_dir, 'model_{}_{:.4f}_{:.4f}'.format(iter, loss, r_avg))
        self.best_path = model_save_path
        torch.save(state, model_save_path)

    def train_one_batch(self, batch, train=True):
        ## pad and other stuff
        input_ids_batch, input_mask_batch, example_index_batch, enc_batch_extend_vocab, extra_zeros, _ = get_input_from_batch(
            batch)
        dec_batch, dec_mask_batch, dec_index_batch, copy_gate, copy_ptr = get_output_from_batch(
            batch)

        if (config.noam):
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        with torch.no_grad():
            # encoder_outputs are hidden states from transformer
            encoder_outputs, _ = self.encoder(
                input_ids_batch,
                token_type_ids=example_index_batch,
                attention_mask=input_mask_batch,
                output_all_encoded_layers=False)

        # # Draft Decoder
        sos_token = torch.LongTensor([config.SOS_idx] *
                                     input_ids_batch.size(0)).unsqueeze(1)
        if config.USE_CUDA: sos_token = sos_token.cuda(device=0)

        dec_batch_shift = torch.cat(
            (sos_token, dec_batch[:, :-1]),
            1)  # shift the decoder input (summary) by one step
        mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
        pre_logit1, attn_dist1 = self.decoder(self.embedding(dec_batch_shift),
                                              encoder_outputs,
                                              (None, mask_trg))

        # print(pre_logit1.size())
        ## compute output dist
        logit1 = self.generator(pre_logit1,
                                attn_dist1,
                                enc_batch_extend_vocab,
                                extra_zeros,
                                copy_gate=copy_gate,
                                copy_ptr=copy_ptr,
                                mask_trg=mask_trg)
        ## loss: NLL if pointer-generator else cross entropy
        loss1 = self.criterion(logit1.contiguous().view(-1, logit1.size(-1)),
                               dec_batch.contiguous().view(-1))

        # Refine Decoder - train using gold label TARGET
        # TODO: turn the gold target text into a BERT-insertable representation
        pre_logit2, attn_dist2 = self.generate_refinement_output(
            encoder_outputs, dec_batch, dec_index_batch, extra_zeros,
            dec_mask_batch)
        # pre_logit2, attn_dist2 = self.decoder(self.embedding(encoded_gold_target),encoder_outputs, (None,mask_trg))

        logit2 = self.generator(pre_logit2,
                                attn_dist2,
                                enc_batch_extend_vocab,
                                extra_zeros,
                                copy_gate=copy_gate,
                                copy_ptr=copy_ptr,
                                mask_trg=None)
        loss2 = self.criterion(logit2.contiguous().view(-1, logit2.size(-1)),
                               dec_batch.contiguous().view(-1))

        loss = loss1 + loss2

        if train:
            loss.backward()
            self.optimizer.step()
        return loss

    def eval_one_batch(self, batch):
        # Gather the encoder inputs and the copy supervision signals for this batch
        input_ids_batch, input_mask_batch, example_index_batch, enc_batch_extend_vocab, extra_zeros, _ = get_input_from_batch(
            batch)
        _, _, _, copy_gate, copy_ptr = get_output_from_batch(batch)

        with torch.no_grad():
            encoder_outputs, _ = self.encoder(
                input_ids_batch,
                token_type_ids=example_index_batch,
                attention_mask=input_mask_batch,
                output_all_encoded_layers=False)

        draft_seq_batch = self.decoder_greedy(batch)

        d_seq_input_ids_batch, d_seq_input_mask_batch, d_seq_example_index_batch = text_input2bert_input(
            draft_seq_batch, self.tokenizer)
        pre_logit2, attn_dist2 = self.generate_refinement_output(
            encoder_outputs, d_seq_input_ids_batch, d_seq_example_index_batch,
            extra_zeros, d_seq_input_mask_batch)

        decoded_words, sent = [], []
        for out, attn_dist in zip(pre_logit2, attn_dist2):
            prob = self.generator(out,
                                  attn_dist,
                                  enc_batch_extend_vocab,
                                  extra_zeros,
                                  copy_gate=copy_gate,
                                  copy_ptr=copy_ptr,
                                  mask_trg=None)
            _, next_word = torch.max(prob[:, -1], dim=1)
            decoded_words.append(
                self.tokenizer.convert_ids_to_tokens(next_word.tolist()))

        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>' or e.strip() == '<PAD>': break
                else: st += e + ' '
            sent.append(st)
        return sent

    def generate_refinement_output(self, encoder_outputs, input_ids_batch,
                                   example_index_batch, extra_zeros,
                                   input_mask_batch):
        # mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        # decoded_words = []
        logits, attns = [], []

        for i in range(config.max_dec_step):
            # print(i)
            with torch.no_grad():
                # Additionally mask the location of i.
                context_input_mask_batch = []
                # print(context_input_mask_batch.shape) # (2,512) (batch_size, seq_len)
                for mask in input_mask_batch:
                    mask[i] = 0
                    context_input_mask_batch.append(mask)

                context_input_mask_batch = torch.stack(
                    context_input_mask_batch)  #.cuda(device=0)
                # self.embedding = self.embedding.cuda(device=0)
                context_vector, _ = self.encoder(
                    input_ids_batch,
                    token_type_ids=example_index_batch,
                    attention_mask=context_input_mask_batch,
                    output_all_encoded_layers=False)

                if config.USE_CUDA:
                    context_vector = context_vector.cuda(device=0)
            # decoder input size == encoder output size == (batch_size, 512, 768)
            out, attn_dist = self.decoder(context_vector, encoder_outputs,
                                          (None, None))

            logits.append(out[:, i:i + 1, :])
            attns.append(attn_dist[:, i:i + 1, :])

        logits = torch.cat(logits, dim=1)
        attns = torch.cat(attns, dim=1)

        # print(logits.size(), attns.size())
        return logits, attns

    def decoder_greedy(self, batch):
        input_ids_batch, input_mask_batch, example_index_batch, enc_batch_extend_vocab, extra_zeros, _ = get_input_from_batch(
            batch)
        # mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        with torch.no_grad():
            encoder_outputs, _ = self.encoder(
                input_ids_batch,
                token_type_ids=example_index_batch,
                attention_mask=input_mask_batch,
                output_all_encoded_layers=False)

        ys = torch.ones(1, 1).fill_(config.SOS_idx).long()
        if config.USE_CUDA: ys = ys.cuda()
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        decoded_words = []
        for i in range(config.max_dec_step):
            out, attn_dist = self.decoder(self.embedding(ys), encoder_outputs,
                                          (None, mask_trg))
            prob = self.generator(out, attn_dist, enc_batch_extend_vocab,
                                  extra_zeros)
            _, next_word = torch.max(prob[:, -1], dim=1)

            decoded_words.append(
                self.tokenizer.convert_ids_to_tokens(next_word.tolist()))

            next_word = next_word.data[0]
            if config.USE_CUDA:
                ys = torch.cat(
                    [ys, torch.ones(1, 1).long().fill_(next_word).cuda()],
                    dim=1)
                ys = ys.cuda()
            else:
                ys = torch.cat(
                    [ys, torch.ones(1, 1).long().fill_(next_word)], dim=1)
            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

        sent = []
        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>' or e == '<PAD>': break
                else: st += e + ' '
            sent.append(st)
        return sent
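
`eval_one_batch` above relies on a `text_input2bert_input` helper that is not shown in this example. The sketch below is only an assumption about its behavior, inferred from the call site: it should return padded token ids, an attention mask, and segment ids for the BERT encoder (the real helper and its maximum sequence length may differ):

import torch

def text_input2bert_input(texts, tokenizer, max_seq_len=512):
    # Hypothetical helper: convert a list of strings into padded BERT input tensors.
    input_ids, input_masks, segment_ids = [], [], []
    for text in texts:
        tokens = ['[CLS]'] + tokenizer.tokenize(text)[:max_seq_len - 2] + ['[SEP]']
        ids = tokenizer.convert_tokens_to_ids(tokens)
        pad = [0] * (max_seq_len - len(ids))
        input_ids.append(ids + pad)
        input_masks.append([1] * len(ids) + pad)
        segment_ids.append([0] * max_seq_len)
    return (torch.LongTensor(input_ids),
            torch.LongTensor(input_masks),
            torch.LongTensor(segment_ids))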
Example #5
class Transformer(nn.Module):
    def __init__(self,
                 vocab,
                 model_file_path=None,
                 is_eval=False,
                 load_optim=False):
        super(Transformer, self).__init__()
        self.vocab = vocab
        self.vocab_size = vocab.n_words

        self.embedding = share_embedding(self.vocab, config.preptrained)
        self.encoder = Encoder(config.emb_dim,
                               config.hidden_dim,
                               num_layers=config.hop,
                               num_heads=config.heads,
                               total_key_depth=config.depth,
                               total_value_depth=config.depth,
                               filter_size=config.filter,
                               universal=config.universal)

        self.decoder = Decoder(config.emb_dim,
                               config.hidden_dim,
                               num_layers=config.hop,
                               num_heads=config.heads,
                               total_key_depth=config.depth,
                               total_value_depth=config.depth,
                               filter_size=config.filter,
                               universal=config.universal)
        self.generator = Generator(config.hidden_dim, self.vocab_size)

        if config.weight_sharing:
            # Share the weight matrix between target word embedding & the final
            # logit dense layer
            self.generator.proj.weight = self.embedding.weight

        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)

        if (config.label_smoothing):
            self.criterion = LabelSmoothing(size=self.vocab_size,
                                            padding_idx=config.PAD_idx,
                                            smoothing=0.1)
            self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx)

        if is_eval:
            self.encoder = self.encoder.eval()
            self.decoder = self.decoder.eval()
            self.generator = self.generator.eval()
            self.embedding = self.embedding.eval()

        self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)

        if (config.noam):
            self.optimizer = NoamOpt(
                config.hidden_dim, 1, 4000,
                torch.optim.Adam(self.parameters(),
                                 lr=0,
                                 betas=(0.9, 0.98),
                                 eps=1e-9))

        if config.use_sgd:
            self.optimizer = torch.optim.SGD(self.parameters(), lr=config.lr)

        if model_file_path is not None:
            print("loading weights")
            state = torch.load(model_file_path,
                               map_location=lambda storage, location: storage)
            print("LOSS", state['current_loss'])
            self.encoder.load_state_dict(state['encoder_state_dict'])
            self.decoder.load_state_dict(state['decoder_state_dict'])
            self.generator.load_state_dict(state['generator_dict'])
            self.embedding.load_state_dict(state['embedding_dict'])
            if (load_optim):
                self.optimizer.load_state_dict(state['optimizer'])

        if (config.USE_CUDA):
            self.encoder = self.encoder.cuda()
            self.decoder = self.decoder.cuda()
            self.generator = self.generator.cuda()
            self.criterion = self.criterion.cuda()
            self.embedding = self.embedding.cuda()

        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""

    def save_model(self, running_avg_ppl, iter, f1_g, f1_b, ent_g, ent_b):
        state = {
            'iter': iter,
            'encoder_state_dict': self.encoder.state_dict(),
            'decoder_state_dict': self.decoder.state_dict(),
            'generator_dict': self.generator.state_dict(),
            'embedding_dict': self.embedding.state_dict(),
            # 'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_ppl
        }
        model_save_path = os.path.join(
            self.model_dir,
            'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(
                iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b))
        self.best_path = model_save_path
        torch.save(state, model_save_path)

    def train_one_batch(self, batch, train=True):
        # pad and other stuff
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = \
            get_input_from_batch(batch)
        dec_batch, _, _, _, _ = get_output_from_batch(batch)

        if (config.noam):
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        # Encode
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)

        # Decode
        sos_token = torch.LongTensor([config.SOS_idx] *
                                     enc_batch.size(0)).unsqueeze(1)
        if config.USE_CUDA:
            sos_token = sos_token.cuda()
        dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1)

        mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
        pre_logit, attn_dist = self.decoder(self.embedding(dec_batch_shift),
                                            encoder_outputs,
                                            (mask_src, mask_trg))
        # compute output dist
        logit = self.generator(pre_logit, attn_dist, enc_batch_extend_vocab,
                               extra_zeros)

        # loss: NLL if pointer-generator else cross entropy
        loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                              dec_batch.contiguous().view(-1))

        if (config.act):
            loss += self.compute_act_loss(self.encoder)
            loss += self.compute_act_loss(self.decoder)

        if (train):
            loss.backward()
            self.optimizer.step()
        if (config.label_smoothing):
            loss = self.criterion_ppl(
                logit.contiguous().view(-1, logit.size(-1)),
                dec_batch.contiguous().view(-1))

        return loss.item(), math.exp(min(loss.item(), 100)), loss

    def compute_act_loss(self, module):
        R_t = module.remainders
        N_t = module.n_updates
        p_t = R_t + N_t
        avg_p_t = torch.sum(torch.sum(p_t, dim=1) / p_t.size(1)) / p_t.size(0)
        loss = config.act_loss_weight * avg_p_t.item()
        return loss

    def decoder_greedy(self, batch):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)

        ys = torch.ones(1, 1).fill_(config.SOS_idx).long()
        if config.USE_CUDA:
            ys = ys.cuda()
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        decoded_words = []
        for i in range(config.max_dec_step):
            out, attn_dist = self.decoder(self.embedding(ys), encoder_outputs,
                                          (mask_src, mask_trg))
            prob = self.generator(out, attn_dist, enc_batch_extend_vocab,
                                  extra_zeros)
            _, next_word = torch.max(prob[:, -1], dim=1)
            decoded_words.append([
                '<EOS>' if ni.item() == config.EOS_idx else
                self.vocab.index2word[ni.item()] for ni in next_word.view(-1)
            ])
            next_word = next_word.data[0]

            if config.USE_CUDA:
                ys = torch.cat(
                    [ys, torch.ones(1, 1).long().fill_(next_word).cuda()],
                    dim=1)
                ys = ys.cuda()
            else:
                ys = torch.cat(
                    [ys, torch.ones(1, 1).long().fill_(next_word)], dim=1)
            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

        sent = []
        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>':
                    break
                else:
                    st += e + ' '
            sent.append(st)
        return sent

    def score_sentence(self, batch):
        # pad and other stuff
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)
        dec_batch, _, _, _, _ = get_output_from_batch(batch)
        cand_batch = batch["cand_index"]
        hit_1 = 0
        for i, b in enumerate(enc_batch):
            # Encode
            mask_src = b.unsqueeze(0).data.eq(config.PAD_idx).unsqueeze(1)
            encoder_outputs = self.encoder(self.embedding(b.unsqueeze(0)),
                                           mask_src)
            rank = {}
            for j, c in enumerate(cand_batch[i]):
                if config.USE_CUDA:
                    c = c.cuda()
                # Decode
                sos_token = torch.LongTensor(
                    [config.SOS_idx] * b.unsqueeze(0).size(0)).unsqueeze(1)
                if config.USE_CUDA:
                    sos_token = sos_token.cuda()
                dec_batch_shift = torch.cat(
                    (sos_token, c.unsqueeze(0)[:, :-1]), 1)

                mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
                pre_logit, attn_dist = self.decoder(
                    self.embedding(dec_batch_shift), encoder_outputs,
                    (mask_src, mask_trg))

                # compute output dist
                logit = self.generator(pre_logit, attn_dist,
                                       enc_batch_extend_vocab[i].unsqueeze(0),
                                       extra_zeros)
                loss = self.criterion(
                    logit.contiguous().view(-1, logit.size(-1)),
                    c.unsqueeze(0).contiguous().view(-1))
                # print("CANDIDATE {}".format(j), loss.item(), math.exp(min(loss.item(), 100)))
                rank[j] = math.exp(min(loss.item(), 100))
            s = sorted(rank.items(), key=lambda x: x[1], reverse=False)
            # the candidates are sorted in reverse order, so the last one (index 19) is the correct answer
            if s[1][0] == 19:
                hit_1 += 1
        return hit_1 / float(len(enc_batch))
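
All five examples wrap Adam in a `NoamOpt` object when `config.noam` is set, but the class itself is not reproduced here. A common implementation that matches the `NoamOpt(model_size, factor, warmup, optimizer)` call signature and the `.optimizer.zero_grad()` / `.step()` usage above is the warmup schedule from "Attention Is All You Need"; the sketch below follows that convention and may differ from the original codebase:

class NoamOpt:
    """Optimizer wrapper that anneals the learning rate with the Noam warmup schedule."""

    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    def step(self):
        # Update the learning rate of every parameter group, then step the wrapped optimizer.
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        # lr = factor * model_size^(-0.5) * min(step^(-0.5), step * warmup^(-1.5))
        if step is None:
            step = self._step
        return self.factor * (self.model_size ** (-0.5) *
                              min(step ** (-0.5), step * self.warmup ** (-1.5)))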