class Transformer(nn.Module):
    def __init__(self, vocab, model_file_path=None, is_eval=False, load_optim=False):
        super(Transformer, self).__init__()
        self.vocab = vocab
        self.vocab_size = vocab.n_words

        self.embedding = share_embedding(self.vocab, config.preptrained)
        self.encoder = Encoder(config.emb_dim, config.hidden_dim, num_layers=config.hop,
                               num_heads=config.heads, total_key_depth=config.depth,
                               total_value_depth=config.depth, filter_size=config.filter,
                               universal=config.universal)
        self.decoder = Decoder(config.emb_dim, config.hidden_dim, num_layers=config.hop,
                               num_heads=config.heads, total_key_depth=config.depth,
                               total_value_depth=config.depth, filter_size=config.filter,
                               universal=config.universal)
        self.generator = Generator(config.hidden_dim, self.vocab_size)

        if config.weight_sharing:
            # Share the weight matrix between the target word embedding & the final logit dense layer
            self.generator.proj.weight = self.embedding.weight

        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)
        if config.label_smoothing:
            self.criterion = LabelSmoothing(size=self.vocab_size, padding_idx=config.PAD_idx, smoothing=0.1)
            self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx)

        if is_eval:
            self.encoder = self.encoder.eval()
            self.decoder = self.decoder.eval()
            self.generator = self.generator.eval()
            self.embedding = self.embedding.eval()

        self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
        if config.noam:
            self.optimizer = NoamOpt(config.hidden_dim, 1, 4000,
                                     torch.optim.Adam(self.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
        if config.use_sgd:
            self.optimizer = torch.optim.SGD(self.parameters(), lr=config.lr)

        if model_file_path is not None:
            print("loading weights")
            state = torch.load(model_file_path, map_location=lambda storage, location: storage)
            print("LOSS", state['current_loss'])
            self.encoder.load_state_dict(state['encoder_state_dict'])
            self.decoder.load_state_dict(state['decoder_state_dict'])
            self.generator.load_state_dict(state['generator_dict'])
            self.embedding.load_state_dict(state['embedding_dict'])
            if load_optim:
                self.optimizer.load_state_dict(state['optimizer'])

        if config.USE_CUDA:
            self.encoder = self.encoder.cuda()
            self.decoder = self.decoder.cuda()
            self.generator = self.generator.cuda()
            self.criterion = self.criterion.cuda()
            self.embedding = self.embedding.cuda()

        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""

    def save_model(self, running_avg_ppl, iter, f1_g, f1_b, ent_g, ent_b):
        state = {
            'iter': iter,
            'encoder_state_dict': self.encoder.state_dict(),
            'decoder_state_dict': self.decoder.state_dict(),
            'generator_dict': self.generator.state_dict(),
            'embedding_dict': self.embedding.state_dict(),
            # 'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_ppl
        }
        model_save_path = os.path.join(
            self.model_dir,
            'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b))
        self.best_path = model_save_path
        torch.save(state, model_save_path)

    def train_one_batch(self, batch, train=True):
        # pad and other stuff
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)
        dec_batch, _, _, _, _ = get_output_from_batch(batch)

        if config.noam:
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        # Encode
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)

        # Decode
        sos_token = torch.LongTensor([config.SOS_idx] * enc_batch.size(0)).unsqueeze(1)
        if config.USE_CUDA:
            sos_token = sos_token.cuda()
        dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1)
        mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
        pre_logit, attn_dist = self.decoder(self.embedding(dec_batch_shift), encoder_outputs, (mask_src, mask_trg))

        # compute output dist
        logit = self.generator(pre_logit, attn_dist, enc_batch_extend_vocab, extra_zeros)

        # loss: NLL if pointer-generator else cross entropy
        loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1))

        if train:
            loss.backward()
            self.optimizer.step()

        if config.label_smoothing:
            loss = self.criterion_ppl(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1))

        return loss.item(), math.exp(min(loss.item(), 100)), loss
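
# Illustrative sketch (not part of the original file): the learning-rate schedule that
# NoamOpt(config.hidden_dim, 1, 4000, ...) above is assumed to implement, i.e. the "Noam"
# warmup/decay from "Attention Is All You Need". The function name and standalone form
# are illustrative only.
def _noam_lr_sketch(step, model_size, factor=1, warmup=4000):
    """lr = factor * model_size^-0.5 * min(step^-0.5, step * warmup^-1.5)."""
    step = max(step, 1)  # avoid division by zero on the first call
    return factor * (model_size ** -0.5) * min(step ** -0.5, step * warmup ** -1.5)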
class Seq2SPG(nn.Module):
    def __init__(self, vocab, model_file_path=None, is_eval=False, load_optim=False):
        super(Seq2SPG, self).__init__()
        self.vocab = vocab
        self.vocab_size = vocab.n_words

        self.embedding = share_embedding(self.vocab, config.preptrained)
        self.encoder = nn.LSTM(config.emb_dim, config.hidden_dim, config.hop,
                               bidirectional=False, batch_first=True, dropout=0.2)
        self.encoder2decoder = nn.Linear(config.hidden_dim, config.hidden_dim)
        self.decoder = LSTMAttentionDot(config.emb_dim, config.hidden_dim, batch_first=True)
        self.memory = MLP(config.hidden_dim + config.emb_dim,
                          [config.private_dim1, config.private_dim2, config.private_dim3],
                          config.hidden_dim)
        self.dec_gate = nn.Linear(config.hidden_dim, 2 * config.hidden_dim)
        self.mem_gate = nn.Linear(config.hidden_dim, 2 * config.hidden_dim)
        self.generator = Generator(config.hidden_dim, self.vocab_size)
        self.hooks = {}  # Save the model structure of each task as masks over the parameters

        if config.weight_sharing:
            # Share the weight matrix between the target word embedding & the final logit dense layer
            self.generator.proj.weight = self.embedding.weight

        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)
        if config.label_smoothing:
            self.criterion = LabelSmoothing(size=self.vocab_size, padding_idx=config.PAD_idx, smoothing=0.1)
            self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx)

        if is_eval:
            self.encoder = self.encoder.eval()
            self.encoder2decoder = self.encoder2decoder.eval()
            self.decoder = self.decoder.eval()
            self.generator = self.generator.eval()
            self.embedding = self.embedding.eval()
            self.memory = self.memory.eval()
            self.dec_gate = self.dec_gate.eval()
            self.mem_gate = self.mem_gate.eval()

        self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
        if config.noam:
            self.optimizer = NoamOpt(config.hidden_dim, 1, 4000,
                                     torch.optim.Adam(self.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
        if config.use_sgd:
            self.optimizer = torch.optim.SGD(self.parameters(), lr=config.lr)

        if model_file_path is not None:
            print("loading weights")
            state = torch.load(model_file_path, map_location=lambda storage, location: storage)
            print("LOSS", state['current_loss'])
            self.encoder.load_state_dict(state['encoder_state_dict'])
            self.encoder2decoder.load_state_dict(state['encoder2decoder_state_dict'])
            self.decoder.load_state_dict(state['decoder_state_dict'])
            self.generator.load_state_dict(state['generator_dict'])
            self.embedding.load_state_dict(state['embedding_dict'])
            self.memory.load_state_dict(state['memory_dict'])
            self.dec_gate.load_state_dict(state['dec_gate_dict'])
            self.mem_gate.load_state_dict(state['mem_gate_dict'])
            if load_optim:
                self.optimizer.load_state_dict(state['optimizer'])

        if config.USE_CUDA:
            self.encoder = self.encoder.cuda()
            self.encoder2decoder = self.encoder2decoder.cuda()
            self.decoder = self.decoder.cuda()
            self.generator = self.generator.cuda()
            self.criterion = self.criterion.cuda()
            self.embedding = self.embedding.cuda()
            self.memory = self.memory.cuda()
            self.dec_gate = self.dec_gate.cuda()
            self.mem_gate = self.mem_gate.cuda()

        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""

    def save_model(self, running_avg_ppl, iter, f1_g, f1_b, ent_g, ent_b, log=False, d="tmaml_sim_model"):
        state = {
            'iter': iter,
            'encoder_state_dict': self.encoder.state_dict(),
            'encoder2decoder_state_dict': self.encoder2decoder.state_dict(),
            'decoder_state_dict': self.decoder.state_dict(),
            'generator_dict': self.generator.state_dict(),
            'embedding_dict': self.embedding.state_dict(),
            'memory_dict': self.memory.state_dict(),
            'dec_gate_dict': self.dec_gate.state_dict(),
            'mem_gate_dict': self.mem_gate.state_dict(),
            # 'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_ppl
        }
        if log:
            model_save_path = os.path.join(
                d, 'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b))
        else:
            model_save_path = os.path.join(
                self.model_dir,
                'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b))
        self.best_path = model_save_path
        torch.save(state, model_save_path)

    def get_state(self, batch):
        """Get cell states and hidden states for the LSTM encoder."""
        batch_size = batch.size(0) if self.encoder.batch_first else batch.size(1)
        h0_encoder = Variable(torch.zeros(self.encoder.num_layers, batch_size, config.hidden_dim),
                              requires_grad=False)
        c0_encoder = Variable(torch.zeros(self.encoder.num_layers, batch_size, config.hidden_dim),
                              requires_grad=False)
        if config.USE_CUDA:
            return h0_encoder.cuda(), c0_encoder.cuda()
        return h0_encoder, c0_encoder

    def compute_hooks(self, task):
        """Compute the masks of the private module."""
        current_layer = 3
        out_mask = torch.ones(self.memory.output_size)
        self.hooks[task] = {}
        self.hooks[task]["w_hooks"] = {}
        self.hooks[task]["b_hooks"] = {}
        while current_layer >= 0:
            connections = self.memory.layers[current_layer].weight.data
            output_size, input_size = connections.shape
            mask = connections.abs() > 0.05
            in_mask = torch.zeros(input_size)
            for index, line in enumerate(mask):
                if out_mask[index] == 1:
                    torch.max(in_mask, (line.cpu() != 0).float(), out=in_mask)
            if config.USE_CUDA:
                self.hooks[task]["b_hooks"][current_layer] = out_mask.cuda()
                self.hooks[task]["w_hooks"][current_layer] = torch.mm(out_mask.unsqueeze(1), in_mask.unsqueeze(0)).cuda()
            else:
                self.hooks[task]["b_hooks"][current_layer] = out_mask
                self.hooks[task]["w_hooks"][current_layer] = torch.mm(out_mask.unsqueeze(1), in_mask.unsqueeze(0))
            out_mask = in_mask
            current_layer -= 1

    def register_hooks(self, task):
        if "hook_handles" not in self.hooks[task]:
            self.hooks[task]["hook_handles"] = []
        for i, l in enumerate(self.memory.layers):
            self.hooks[task]["hook_handles"].append(
                l.bias.register_hook(make_hook(self.hooks[task]["b_hooks"][i])))
            self.hooks[task]["hook_handles"].append(
                l.weight.register_hook(make_hook(self.hooks[task]["w_hooks"][i])))

    def unhook(self, task):
        for handle in self.hooks[task]["hook_handles"]:
            handle.remove()
        self.hooks[task]["hook_handles"] = []

    def train_one_batch(self, batch, train=True, mode="pretrain", task=0):
        enc_batch, _, enc_lens, enc_batch_extend_vocab, extra_zeros, _, _, _ = get_input_from_batch(batch)
        dec_batch, _, _, _, _, _ = get_output_from_batch(batch)

        if config.noam:
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        # Encode
        self.h0_encoder, self.c0_encoder = self.get_state(enc_batch)
        src_h, (src_h_t, src_c_t) = self.encoder(self.embedding(enc_batch), (self.h0_encoder, self.c0_encoder))
        h_t = src_h_t[-1]
        c_t = src_c_t[-1]

        # Decode
        decoder_init_state = nn.Tanh()(self.encoder2decoder(h_t))
        sos_token = torch.LongTensor([config.SOS_idx] * enc_batch.size(0)).unsqueeze(1)
        if config.USE_CUDA:
            sos_token = sos_token.cuda()
        dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1)
        target_embedding = self.embedding(dec_batch_shift)
        ctx = src_h.transpose(0, 1)
        trg_h, (_, _) = self.decoder(target_embedding, (decoder_init_state, c_t), ctx)

        # Memory
        mem_h_input = torch.cat((decoder_init_state.unsqueeze(1), trg_h[:, 0:-1, :]), 1)
        mem_input = torch.cat((target_embedding, mem_h_input), 2)
        mem_output = self.memory(mem_input)

        # Combine the decoder and memory outputs through learned gates
        gates = self.dec_gate(trg_h) + self.mem_gate(mem_output)
        decoder_gate, memory_gate = gates.chunk(2, 2)
        decoder_gate = torch.sigmoid(decoder_gate)
        memory_gate = torch.sigmoid(memory_gate)
        pre_logit = torch.tanh(decoder_gate * trg_h + memory_gate * mem_output)
        logit = self.generator(pre_logit)

        if mode == "pretrain":
            loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1))
            if train:
                loss.backward()
                self.optimizer.step()
            if config.label_smoothing:
                loss = self.criterion_ppl(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1))
            return loss.item(), math.exp(min(loss.item(), 100)), loss
        elif mode == "select":
            loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1))
            if train:
                # L1 penalty on the private memory to sparsify its connections
                l1_loss = 0.0
                for p in self.memory.parameters():
                    l1_loss += torch.sum(torch.abs(p))
                loss += 0.0005 * l1_loss
                loss.backward()
                self.optimizer.step()
                self.compute_hooks(task)
            if config.label_smoothing:
                loss = self.criterion_ppl(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1))
            return loss.item(), math.exp(min(loss.item(), 100)), loss
        else:
            loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1))
            if train:
                self.register_hooks(task)
                loss.backward()
                self.optimizer.step()
                self.unhook(task)
            if config.label_smoothing:
                loss = self.criterion_ppl(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1))
            return loss.item(), math.exp(min(loss.item(), 100)), loss
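
# Illustrative sketch (not part of the original file): register_hooks above relies on a
# make_hook() helper defined elsewhere in the repo. The sketch below shows the assumed
# behaviour - a closure that multiplies a parameter's gradient by a fixed 0/1 mask so
# that only the connections selected by compute_hooks() receive updates. The name
# _make_hook_sketch is illustrative only.
def _make_hook_sketch(mask):
    def hook(grad):
        # zero out the gradients of masked-out connections
        return grad * mask
    return hook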
class VAE(nn.Module):
    def __init__(self, vocab, model_file_path=None, is_eval=False, load_optim=False):
        super(VAE, self).__init__()
        self.vocab = vocab
        self.vocab_size = vocab.n_words

        self.embedding = share_embedding(self.vocab, config.preptrained)
        self.encoder = nn.LSTM(config.emb_dim, config.hidden_dim, config.hop,
                               bidirectional=False, batch_first=True, dropout=0.2)
        self.encoder_r = nn.LSTM(config.emb_dim, config.hidden_dim, config.hop,
                                 bidirectional=False, batch_first=True, dropout=0.2)
        self.represent = R_MLP(2 * config.hidden_dim, 68)
        self.prior = P_MLP(config.hidden_dim, 68)
        self.mlp_b = nn.Linear(config.hidden_dim + 68, self.vocab_size)
        self.encoder2decoder = nn.Linear(config.hidden_dim + 68, config.hidden_dim)
        self.decoder = LSTMAttentionDot(config.emb_dim, config.hidden_dim, batch_first=True)
        self.generator = Generator(config.hidden_dim, self.vocab_size)

        if config.weight_sharing:
            # Share the weight matrix between the target word embedding & the final logit dense layer
            self.generator.proj.weight = self.embedding.weight

        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)
        if config.label_smoothing:
            self.criterion = LabelSmoothing(size=self.vocab_size, padding_idx=config.PAD_idx, smoothing=0.1)
            self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx)

        if is_eval:
            self.encoder = self.encoder.eval()
            self.encoder_r = self.encoder_r.eval()
            self.represent = self.represent.eval()
            self.prior = self.prior.eval()
            self.mlp_b = self.mlp_b.eval()
            self.encoder2decoder = self.encoder2decoder.eval()
            self.decoder = self.decoder.eval()
            self.generator = self.generator.eval()
            self.embedding = self.embedding.eval()

        self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
        if config.noam:
            self.optimizer = NoamOpt(config.hidden_dim, 1, 4000,
                                     torch.optim.Adam(self.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
        if config.use_sgd:
            self.optimizer = torch.optim.SGD(self.parameters(), lr=config.lr)

        if model_file_path is not None:
            print("loading weights")
            state = torch.load(model_file_path, map_location=lambda storage, location: storage)
            print("LOSS", state['current_loss'])
            self.encoder.load_state_dict(state['encoder_state_dict'])
            self.encoder_r.load_state_dict(state['encoder_r_state_dict'])
            self.represent.load_state_dict(state['represent_state_dict'])
            self.prior.load_state_dict(state['prior_state_dict'])
            self.mlp_b.load_state_dict(state['mlp_b_state_dict'])
            self.encoder2decoder.load_state_dict(state['encoder2decoder_state_dict'])
            self.decoder.load_state_dict(state['decoder_state_dict'])
            self.generator.load_state_dict(state['generator_dict'])
            self.embedding.load_state_dict(state['embedding_dict'])
            if load_optim:
                self.optimizer.load_state_dict(state['optimizer'])

        if config.USE_CUDA:
            self.encoder = self.encoder.cuda()
            self.encoder_r = self.encoder_r.cuda()
            self.represent = self.represent.cuda()
            self.prior = self.prior.cuda()
            self.mlp_b = self.mlp_b.cuda()
            self.encoder2decoder = self.encoder2decoder.cuda()
            self.decoder = self.decoder.cuda()
            self.generator = self.generator.cuda()
            self.criterion = self.criterion.cuda()
            self.embedding = self.embedding.cuda()

        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""

    def save_model(self, running_avg_ppl, iter, f1_g, f1_b, ent_g, ent_b, log=False, d="save/paml_model_sim"):
        state = {
            'iter': iter,
            'encoder_state_dict': self.encoder.state_dict(),
            'encoder_r_state_dict': self.encoder_r.state_dict(),
            'represent_state_dict': self.represent.state_dict(),
            'prior_state_dict': self.prior.state_dict(),
            'mlp_b_state_dict': self.mlp_b.state_dict(),
            'encoder2decoder_state_dict': self.encoder2decoder.state_dict(),
            'decoder_state_dict': self.decoder.state_dict(),
            'generator_dict': self.generator.state_dict(),
            'embedding_dict': self.embedding.state_dict(),
            # 'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_ppl
        }
        if log:
            model_save_path = os.path.join(
                d, 'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b))
        else:
            model_save_path = os.path.join(
                self.model_dir,
                'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b))
        self.best_path = model_save_path
        torch.save(state, model_save_path)

    def get_state(self, batch):
        """Get cell states and hidden states for the LSTM encoders."""
        batch_size = batch.size(0) if self.encoder.batch_first else batch.size(1)
        h0_encoder = Variable(torch.zeros(self.encoder.num_layers, batch_size, config.hidden_dim),
                              requires_grad=False)
        c0_encoder = Variable(torch.zeros(self.encoder.num_layers, batch_size, config.hidden_dim),
                              requires_grad=False)
        if config.USE_CUDA:
            return h0_encoder.cuda(), c0_encoder.cuda()
        return h0_encoder, c0_encoder

    def train_one_batch(self, batch, train=True):
        # pad and other stuff
        enc_batch, _, enc_lens, enc_batch_extend_vocab, extra_zeros, _, _, _ = get_input_from_batch(batch)
        dec_batch, _, _, _, _, _ = get_output_from_batch(batch)

        if config.noam:
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        # Encode the source with the prior-side encoder
        self.h0_encoder, self.c0_encoder = self.get_state(enc_batch)
        src_h, (src_h_t, src_c_t) = self.encoder(self.embedding(enc_batch), (self.h0_encoder, self.c0_encoder))
        h_t = src_h_t[-1]
        c_t = src_c_t[-1]

        # Encode the target with the recognition-side encoder
        self.h0_encoder_r, self.c0_encoder_r = self.get_state(dec_batch)
        src_h_r, (src_h_t_r, src_c_t_r) = self.encoder_r(self.embedding(dec_batch), (self.h0_encoder_r, self.c0_encoder_r))
        h_t_r = src_h_t_r[-1]
        c_t_r = src_c_t_r[-1]

        # Sample the latent variable with the reparameterization trick
        z_sample, mu, var = self.represent(torch.cat((h_t_r, h_t), 1))
        p_z_sample, p_mu, p_var = self.prior(h_t)

        # Decode
        decoder_init_state = nn.Tanh()(self.encoder2decoder(torch.cat((z_sample, h_t), 1)))
        sos_token = torch.LongTensor([config.SOS_idx] * enc_batch.size(0)).unsqueeze(1)
        if config.USE_CUDA:
            sos_token = sos_token.cuda()
        dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1)
        target_embedding = self.embedding(dec_batch_shift)
        ctx = src_h.transpose(0, 1)
        trg_h, (_, _) = self.decoder(target_embedding, (decoder_init_state, c_t), ctx)
        pre_logit = trg_h
        logit = self.generator(pre_logit)

        # loss: NLL if pointer-generator else cross entropy
        re_loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1))

        # KL divergence between the recognition posterior and the prior (diagonal Gaussians, log-variances)
        kl_losses = 0.5 * torch.sum(
            torch.exp(var - p_var) + (mu - p_mu) ** 2 / torch.exp(p_var) - 1. - var + p_var, 1)
        kl_loss = torch.mean(kl_losses)

        # Bag-of-words auxiliary loss on the latent code
        latent_logit = self.mlp_b(torch.cat((z_sample, h_t), 1)).unsqueeze(1)
        latent_logit = F.log_softmax(latent_logit, dim=-1)
        latent_logits = latent_logit.repeat(1, logit.size(1), 1)
        bow_loss = self.criterion(latent_logits.contiguous().view(-1, latent_logits.size(-1)),
                                  dec_batch.contiguous().view(-1))

        loss = re_loss + 0.48 * kl_loss + bow_loss

        if train:
            loss.backward()
            self.optimizer.step()

        s_loss = re_loss  # reported loss; replaced by the un-smoothed NLL when label smoothing is on
        if config.label_smoothing:
            s_loss = self.criterion_ppl(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1))

        return s_loss.item(), math.exp(min(s_loss.item(), 100)), loss.item(), re_loss.item(), kl_loss.item(), bow_loss.item()
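
# Illustrative sketch (not part of the original file): the kl_losses term in
# VAE.train_one_batch is the closed-form KL divergence between two diagonal Gaussians
# whose means are mu / p_mu and whose log-variances are var / p_var:
#   KL(q || p) = 0.5 * sum( exp(var - p_var) + (mu - p_mu)^2 / exp(p_var) - 1 - var + p_var )
# The helper below simply restates that formula; its name is illustrative only.
def _gaussian_kl_sketch(mu, var, p_mu, p_var):
    return 0.5 * torch.sum(torch.exp(var - p_var) + (mu - p_mu) ** 2 / torch.exp(p_var) - 1. - var + p_var, 1)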
class Summarizer(nn.Module):
    def __init__(self, is_draft, toeknizer, model_file_path=None, is_eval=False, load_optim=False):
        super(Summarizer, self).__init__()
        self.is_draft = is_draft
        self.toeknizer = toeknizer
        if is_draft:
            self.encoder = BertModel.from_pretrained('bert-base-uncased')
        else:
            self.encoder = BertForMaskedLM.from_pretrained('bert-base-uncased')
        self.encoder.eval()  # always in eval mode
        self.embedding = self.encoder.embeddings
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        self.decoder = Decoder(config.emb_dim, config.hidden_dim, num_layers=config.hop,
                               num_heads=config.heads, total_key_depth=config.depth,
                               total_value_depth=config.depth, filter_size=config.filter)
        self.generator = Generator(config.hidden_dim, config.vocab_size)

        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)
        self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx)
        if config.label_smoothing:
            self.criterion = LabelSmoothing(size=config.vocab_size, padding_idx=config.PAD_idx, smoothing=0.1)

        self.embedding = self.embedding.eval()
        if is_eval:
            self.decoder = self.decoder.eval()
            self.generator = self.generator.eval()

        self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
        if config.noam:
            self.optimizer = NoamOpt(config.hidden_dim, 1, 4000,
                                     torch.optim.Adam(self.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))

        if model_file_path is not None:
            print("loading weights")
            state = torch.load(model_file_path, map_location=lambda storage, location: storage)
            print("LOSS", state['current_loss'])
            self.decoder.load_state_dict(state['decoder_state_dict'])
            self.generator.load_state_dict(state['generator_dict'])
            self.embedding.load_state_dict(state['embedding_dict'])

        if config.USE_CUDA:
            self.encoder = self.encoder.cuda(device=0)
            self.decoder = self.decoder.cuda(device=0)
            self.generator = self.generator.cuda(device=0)
            self.criterion = self.criterion.cuda(device=0)
            self.embedding = self.embedding.cuda(device=0)

        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""

    def save_model(self, loss, iter, r_avg):
        state = {
            'iter': iter,
            'decoder_state_dict': self.decoder.state_dict(),
            'generator_dict': self.generator.state_dict(),
            'embedding_dict': self.embedding.state_dict(),
            # 'optimizer': self.optimizer.state_dict(),
            'current_loss': loss
        }
        model_save_path = os.path.join(self.model_dir, 'model_{}_{:.4f}_{:.4f}'.format(iter, loss, r_avg))
        self.best_path = model_save_path
        torch.save(state, model_save_path)

    def train_one_batch(self, batch, train=True):
        # pad and other stuff
        input_ids_batch, input_mask_batch, example_index_batch, enc_batch_extend_vocab, extra_zeros, _ = get_input_from_batch(batch)
        dec_batch, dec_mask_batch, dec_index_batch, copy_gate, copy_ptr = get_output_from_batch(batch)

        if config.noam:
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        with torch.no_grad():
            # encoder_outputs are the hidden states from the (frozen) BERT encoder
            encoder_outputs, _ = self.encoder(input_ids_batch,
                                              token_type_ids=example_index_batch,
                                              attention_mask=input_mask_batch,
                                              output_all_encoded_layers=False)

        # Draft decoder
        sos_token = torch.LongTensor([config.SOS_idx] * input_ids_batch.size(0)).unsqueeze(1)
        if config.USE_CUDA:
            sos_token = sos_token.cuda(device=0)
        dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1)  # shift the decoder input (summary) by one step
        mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
        pre_logit1, attn_dist1 = self.decoder(self.embedding(dec_batch_shift), encoder_outputs, (None, mask_trg))
        # print(pre_logit1.size())

        # compute output dist
        logit1 = self.generator(pre_logit1, attn_dist1, enc_batch_extend_vocab, extra_zeros,
                                copy_gate=copy_gate, copy_ptr=copy_ptr, mask_trg=mask_trg)
        # loss: NLL if pointer-generator else cross entropy
        loss1 = self.criterion(logit1.contiguous().view(-1, logit1.size(-1)), dec_batch.contiguous().view(-1))

        # Refine decoder - train using the gold label TARGET
        # TODO: turn gold-target-text into a BERT-insertable representation
        pre_logit2, attn_dist2 = self.generate_refinement_output(encoder_outputs, dec_batch, dec_index_batch,
                                                                 extra_zeros, dec_mask_batch)
        # pre_logit2, attn_dist2 = self.decoder(self.embedding(encoded_gold_target), encoder_outputs, (None, mask_trg))
        logit2 = self.generator(pre_logit2, attn_dist2, enc_batch_extend_vocab, extra_zeros,
                                copy_gate=copy_gate, copy_ptr=copy_ptr, mask_trg=None)
        loss2 = self.criterion(logit2.contiguous().view(-1, logit2.size(-1)), dec_batch.contiguous().view(-1))

        loss = loss1 + loss2
        if train:
            loss.backward()
            self.optimizer.step()
        return loss

    def eval_one_batch(self, batch):
        # encoder_outputs, extra_zeros, enc_batch_extend_vocab, copy_gate and copy_ptr were
        # previously used here without being defined; they are recovered the same way as in
        # train_one_batch.
        input_ids_batch, input_mask_batch, example_index_batch, enc_batch_extend_vocab, extra_zeros, _ = get_input_from_batch(batch)
        _, _, _, copy_gate, copy_ptr = get_output_from_batch(batch)
        with torch.no_grad():
            encoder_outputs, _ = self.encoder(input_ids_batch,
                                              token_type_ids=example_index_batch,
                                              attention_mask=input_mask_batch,
                                              output_all_encoded_layers=False)

        draft_seq_batch = self.decoder_greedy(batch)
        d_seq_input_ids_batch, d_seq_input_mask_batch, d_seq_example_index_batch = text_input2bert_input(
            draft_seq_batch, self.tokenizer)
        pre_logit2, attn_dist2 = self.generate_refinement_output(encoder_outputs, d_seq_input_ids_batch,
                                                                 d_seq_example_index_batch, extra_zeros,
                                                                 d_seq_input_mask_batch)

        decoded_words, sent = [], []
        for out, attn_dist in zip(pre_logit2, attn_dist2):
            prob = self.generator(out, attn_dist, enc_batch_extend_vocab, extra_zeros,
                                  copy_gate=copy_gate, copy_ptr=copy_ptr, mask_trg=None)
            _, next_word = torch.max(prob[:, -1], dim=1)
            decoded_words.append(self.tokenizer.convert_ids_to_tokens(next_word.tolist()))

        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>' or e.strip() == '<PAD>':
                    break
                else:
                    st += e + ' '
            sent.append(st)
        return sent

    def generate_refinement_output(self, encoder_outputs, input_ids_batch, example_index_batch,
                                   extra_zeros, input_mask_batch):
        # mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        # decoded_words = []
        logits, attns = [], []

        for i in range(config.max_dec_step):
            # print(i)
            with torch.no_grad():
                # Additionally mask out position i so it has to be predicted from its context.
                context_input_mask_batch = []
                # print(context_input_mask_batch.shape)  # (2, 512) (batch_size, seq_len)
                for mask in input_mask_batch:
                    mask[i] = 0
                    context_input_mask_batch.append(mask)
                context_input_mask_batch = torch.stack(context_input_mask_batch)  # .cuda(device=0)
                # self.embedding = self.embedding.cuda(device=0)
                context_vector, _ = self.encoder(input_ids_batch,
                                                 token_type_ids=example_index_batch,
                                                 attention_mask=context_input_mask_batch,
                                                 output_all_encoded_layers=False)
            if config.USE_CUDA:
                context_vector = context_vector.cuda(device=0)

            # decoder input size == encoder output size == (batch_size, 512, 768)
            out, attn_dist = self.decoder(context_vector, encoder_outputs, (None, None))
            logits.append(out[:, i:i + 1, :])
            attns.append(attn_dist[:, i:i + 1, :])

        logits = torch.cat(logits, dim=1)
        attns = torch.cat(attns, dim=1)
        # print(logits.size(), attns.size())
        return logits, attns

    def decoder_greedy(self, batch):
        input_ids_batch, input_mask_batch, example_index_batch, enc_batch_extend_vocab, extra_zeros, _ = get_input_from_batch(batch)
        # mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        with torch.no_grad():
            encoder_outputs, _ = self.encoder(input_ids_batch,
                                              token_type_ids=example_index_batch,  # segment ids, as in train_one_batch
                                              attention_mask=input_mask_batch,
                                              output_all_encoded_layers=False)

        ys = torch.ones(1, 1).fill_(config.SOS_idx).long()
        if config.USE_CUDA:
            ys = ys.cuda()
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        decoded_words = []
        for i in range(config.max_dec_step):
            out, attn_dist = self.decoder(self.embedding(ys), encoder_outputs, (None, mask_trg))
            prob = self.generator(out, attn_dist, enc_batch_extend_vocab, extra_zeros)
            _, next_word = torch.max(prob[:, -1], dim=1)
            decoded_words.append(self.tokenizer.convert_ids_to_tokens(next_word.tolist()))
            next_word = next_word.data[0]

            if config.USE_CUDA:
                ys = torch.cat([ys, torch.ones(1, 1).long().fill_(next_word).cuda()], dim=1)
                ys = ys.cuda()
            else:
                ys = torch.cat([ys, torch.ones(1, 1).long().fill_(next_word)], dim=1)
            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

        sent = []
        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>' or e == '<PAD>':
                    break
                else:
                    st += e + ' '
            sent.append(st)
        return sent
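
# Illustrative usage sketch (not part of the original file): at inference time the
# Summarizer first produces a greedy draft summary and then refines it token by token
# with the masked-context pass in generate_refinement_output(). The `batch` argument is
# assumed to follow the format expected by get_input_from_batch(); the function name is
# illustrative only.
def _summarizer_inference_sketch(summarizer, batch):
    draft = summarizer.decoder_greedy(batch)       # stage 1: greedy draft summary
    refined = summarizer.eval_one_batch(batch)     # stage 2: refine with masked BERT context
    return draft, refined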
class Transformer(nn.Module):
    def __init__(self, vocab, model_file_path=None, is_eval=False, load_optim=False):
        super(Transformer, self).__init__()
        self.vocab = vocab
        self.vocab_size = vocab.n_words

        self.embedding = share_embedding(self.vocab, config.preptrained)
        self.encoder = Encoder(config.emb_dim, config.hidden_dim, num_layers=config.hop,
                               num_heads=config.heads, total_key_depth=config.depth,
                               total_value_depth=config.depth, filter_size=config.filter,
                               universal=config.universal)
        self.decoder = Decoder(config.emb_dim, config.hidden_dim, num_layers=config.hop,
                               num_heads=config.heads, total_key_depth=config.depth,
                               total_value_depth=config.depth, filter_size=config.filter,
                               universal=config.universal)
        self.generator = Generator(config.hidden_dim, self.vocab_size)

        if config.weight_sharing:
            # Share the weight matrix between the target word embedding & the final logit dense layer
            self.generator.proj.weight = self.embedding.weight

        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)
        if config.label_smoothing:
            self.criterion = LabelSmoothing(size=self.vocab_size, padding_idx=config.PAD_idx, smoothing=0.1)
            self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx)

        if is_eval:
            self.encoder = self.encoder.eval()
            self.decoder = self.decoder.eval()
            self.generator = self.generator.eval()
            self.embedding = self.embedding.eval()

        self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
        if config.noam:
            self.optimizer = NoamOpt(config.hidden_dim, 1, 4000,
                                     torch.optim.Adam(self.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
        if config.use_sgd:
            self.optimizer = torch.optim.SGD(self.parameters(), lr=config.lr)

        if model_file_path is not None:
            print("loading weights")
            state = torch.load(model_file_path, map_location=lambda storage, location: storage)
            print("LOSS", state['current_loss'])
            self.encoder.load_state_dict(state['encoder_state_dict'])
            self.decoder.load_state_dict(state['decoder_state_dict'])
            self.generator.load_state_dict(state['generator_dict'])
            self.embedding.load_state_dict(state['embedding_dict'])
            if load_optim:
                self.optimizer.load_state_dict(state['optimizer'])

        if config.USE_CUDA:
            self.encoder = self.encoder.cuda()
            self.decoder = self.decoder.cuda()
            self.generator = self.generator.cuda()
            self.criterion = self.criterion.cuda()
            self.embedding = self.embedding.cuda()

        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""

    def save_model(self, running_avg_ppl, iter, f1_g, f1_b, ent_g, ent_b):
        state = {
            'iter': iter,
            'encoder_state_dict': self.encoder.state_dict(),
            'decoder_state_dict': self.decoder.state_dict(),
            'generator_dict': self.generator.state_dict(),
            'embedding_dict': self.embedding.state_dict(),
            # 'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_ppl
        }
        model_save_path = os.path.join(
            self.model_dir,
            'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b))
        self.best_path = model_save_path
        torch.save(state, model_save_path)

    def train_one_batch(self, batch, train=True):
        # pad and other stuff
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)
        dec_batch, _, _, _, _ = get_output_from_batch(batch)

        if config.noam:
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        # Encode
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)

        # Decode
        sos_token = torch.LongTensor([config.SOS_idx] * enc_batch.size(0)).unsqueeze(1)
        if config.USE_CUDA:
            sos_token = sos_token.cuda()
        dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1)
        mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
        pre_logit, attn_dist = self.decoder(self.embedding(dec_batch_shift), encoder_outputs, (mask_src, mask_trg))

        # compute output dist
        logit = self.generator(pre_logit, attn_dist, enc_batch_extend_vocab, extra_zeros)
        # loss: NLL if pointer-generator else cross entropy
        loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1))

        if config.act:
            loss += self.compute_act_loss(self.encoder)
            loss += self.compute_act_loss(self.decoder)

        if train:
            loss.backward()
            self.optimizer.step()

        if config.label_smoothing:
            loss = self.criterion_ppl(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1))

        return loss.item(), math.exp(min(loss.item(), 100)), loss

    def compute_act_loss(self, module):
        R_t = module.remainders
        N_t = module.n_updates
        p_t = R_t + N_t
        avg_p_t = torch.sum(torch.sum(p_t, dim=1) / p_t.size(1)) / p_t.size(0)
        loss = config.act_loss_weight * avg_p_t.item()
        return loss

    def decoder_greedy(self, batch):
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)

        ys = torch.ones(1, 1).fill_(config.SOS_idx).long()
        if config.USE_CUDA:
            ys = ys.cuda()
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        decoded_words = []
        for i in range(config.max_dec_step):
            out, attn_dist = self.decoder(self.embedding(ys), encoder_outputs, (mask_src, mask_trg))
            prob = self.generator(out, attn_dist, enc_batch_extend_vocab, extra_zeros)
            _, next_word = torch.max(prob[:, -1], dim=1)
            decoded_words.append([
                '<EOS>' if ni.item() == config.EOS_idx else self.vocab.index2word[ni.item()]
                for ni in next_word.view(-1)
            ])
            next_word = next_word.data[0]

            if config.USE_CUDA:
                ys = torch.cat([ys, torch.ones(1, 1).long().fill_(next_word).cuda()], dim=1)
                ys = ys.cuda()
            else:
                ys = torch.cat([ys, torch.ones(1, 1).long().fill_(next_word)], dim=1)
            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

        sent = []
        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>':
                    break
                else:
                    st += e + ' '
            sent.append(st)
        return sent

    def score_sentence(self, batch):
        # pad and other stuff
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)
        dec_batch, _, _, _, _ = get_output_from_batch(batch)
        cand_batch = batch["cand_index"]

        hit_1 = 0
        for i, b in enumerate(enc_batch):
            # Encode
            mask_src = b.unsqueeze(0).data.eq(config.PAD_idx).unsqueeze(1)
            encoder_outputs = self.encoder(self.embedding(b.unsqueeze(0)), mask_src)
            rank = {}
            for j, c in enumerate(cand_batch[i]):
                if config.USE_CUDA:
                    c = c.cuda()
                # Decode
                sos_token = torch.LongTensor([config.SOS_idx] * b.unsqueeze(0).size(0)).unsqueeze(1)
                if config.USE_CUDA:
                    sos_token = sos_token.cuda()
                dec_batch_shift = torch.cat((sos_token, c.unsqueeze(0)[:, :-1]), 1)
                mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
                pre_logit, attn_dist = self.decoder(self.embedding(dec_batch_shift), encoder_outputs, (mask_src, mask_trg))

                # compute output dist
                logit = self.generator(pre_logit, attn_dist, enc_batch_extend_vocab[i].unsqueeze(0), extra_zeros)
                loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)), c.unsqueeze(0).contiguous().view(-1))
                # print("CANDIDATE {}".format(j), loss.item(), math.exp(min(loss.item(), 100)))
                rank[j] = math.exp(min(loss.item(), 100))

            s = sorted(rank.items(), key=lambda x: x[1], reverse=False)
            if s[1][0] == 19:
                # because the candidates are sorted in reverse order ====> the last one (19) is the correct one
                hit_1 += 1
        return hit_1 / float(len(enc_batch))
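
# Illustrative sketch (not part of the original file): score_sentence above scores each of
# the 20 candidate responses by its perplexity (exp of the candidate's NLL under the model,
# capped at exp(100)) and ranks them from best to worst; the gold candidate is assumed to be
# stored last (index 19). The helper restates the ranking step; its name is illustrative only.
def _rank_candidates_sketch(perplexities):
    """perplexities: dict mapping candidate index -> perplexity; lower is better."""
    return sorted(perplexities.items(), key=lambda x: x[1])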