def __init__(self, vocab, emo_number, model_file_path=None, is_eval=False, load_optim=False):
    super(Transformer, self).__init__()
    self.vocab = vocab
    self.vocab_size = vocab.n_words

    self.embedding = share_embedding(self.vocab, config.pretrain_emb)
    self.encoder = Encoder(config.emb_dim, config.hidden_dim, num_layers=config.hop,
                           num_heads=config.heads, total_key_depth=config.depth,
                           total_value_depth=config.depth, filter_size=config.filter,
                           universal=config.universal)
    self.decoder = Decoder(config.emb_dim, hidden_size=config.hidden_dim, num_layers=config.hop,
                           num_heads=config.heads, total_key_depth=config.depth,
                           total_value_depth=config.depth, filter_size=config.filter)
    self.generator = Generator(config.hidden_dim, self.vocab_size)

    if config.weight_sharing:
        # Share the weight matrix between target word embedding & the final logit dense layer
        self.generator.proj.weight = self.embedding.lut.weight

    self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)
    self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
    if config.noam:
        self.optimizer = NoamOpt(config.hidden_dim, 1, 8000,
                                 torch.optim.Adam(self.parameters(), lr=0,
                                                  betas=(0.9, 0.98), eps=1e-9))

    if model_file_path is not None:
        print("loading weights")
        state = torch.load(model_file_path, map_location=lambda storage, location: storage)
        self.encoder.load_state_dict(state['encoder_state_dict'])
        self.decoder.load_state_dict(state['decoder_state_dict'])
        self.generator.load_state_dict(state['generator_dict'])
        self.embedding.load_state_dict(state['embedding_dict'])
        if load_optim:
            self.optimizer.load_state_dict(state['optimizer'])
        self.eval()

    self.model_dir = config.save_path
    if not os.path.exists(self.model_dir):
        os.makedirs(self.model_dir)
    self.best_path = ""
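# The NoamOpt wrapper used above is assumed to implement the standard Transformer
# warm-up schedule, rate = factor * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5);
# a minimal self-contained sketch, with d_model standing in for config.hidden_dim:
def noam_rate(step, d_model, factor=1, warmup=8000):
    step = max(step, 1)  # avoid 0 ** -0.5 on the very first call
    return factor * (d_model ** -0.5) * min(step ** -0.5, step * warmup ** -1.5)

# The rate grows roughly linearly over the first `warmup` updates, then decays as step^-0.5,
# e.g. noam_rate(100, 300) < noam_rate(8000, 300) > noam_rate(80000, 300).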
class CvaeTrans(nn.Module):
    def __init__(self, vocab, emo_number, model_file_path=None, is_eval=False, load_optim=False):
        super(CvaeTrans, self).__init__()
        self.vocab = vocab
        self.vocab_size = vocab.n_words

        self.embedding = share_embedding(self.vocab, pretrain=False)
        self.word_encoder = WordEncoder(config.emb_dim, config.hidden_dim, config.bidirectional)
        self.encoder = Encoder(config.hidden_dim, num_layers=config.hop, num_heads=config.heads,
                               total_key_depth=config.depth, total_value_depth=config.depth,
                               filter_size=config.filter, universal=config.universal)
        self.r_encoder = Encoder(config.hidden_dim, num_layers=config.hop, num_heads=config.heads,
                                 total_key_depth=config.depth, total_value_depth=config.depth,
                                 filter_size=config.filter, universal=config.universal)
        self.decoder = VarDecoder(config.emb_dim, hidden_size=config.hidden_dim, num_layers=config.hop,
                                  num_heads=config.heads, total_key_depth=config.depth,
                                  total_value_depth=config.depth, filter_size=config.filter,
                                  vocab_size=self.vocab_size)
        self.generator = Generator(config.hidden_dim, self.vocab_size)
        self.linear = nn.Linear(2 * config.hidden_dim, config.hidden_dim)

        if config.weight_sharing:
            # Share the weight matrix between target word embedding & the final logit dense layer
            self.generator.proj.weight = self.embedding.lut.weight

        self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)

        if model_file_path:
            print("loading weights")
            state = torch.load(model_file_path, map_location=lambda storage, location: storage)
            self.encoder.load_state_dict(state['encoder_state_dict'])
            # self.r_encoder.load_state_dict(state['r_encoder_state_dict'])
            self.decoder.load_state_dict(state['decoder_state_dict'])
            self.generator.load_state_dict(state['generator_dict'])
            self.embedding.load_state_dict(state['embedding_dict'])

        if config.USE_CUDA:
            self.cuda()
        if is_eval:
            self.eval()
        else:
            self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
            if config.noam:
                self.optimizer = NoamOpt(config.hidden_dim, 1, 8000,
                                         torch.optim.Adam(self.parameters(), lr=0,
                                                          betas=(0.9, 0.98), eps=1e-9))
            if load_optim:
                self.optimizer.load_state_dict(state['optimizer'])
                if config.USE_CUDA:
                    for state in self.optimizer.state.values():
                        for k, v in state.items():
                            if isinstance(v, torch.Tensor):
                                state[k] = v.cuda()

        self.model_dir = config.save_path
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)
        self.best_path = ""

    def save_model(self, running_avg_ppl, iter, f1_g, f1_b, ent_g, ent_b):
        state = {
            'iter': iter,
            'encoder_state_dict': self.encoder.state_dict(),
            # 'r_encoder_state_dict': self.r_encoder.state_dict(),
            'decoder_state_dict': self.decoder.state_dict(),
            'generator_dict': self.generator.state_dict(),
            'embedding_dict': self.embedding.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'current_loss': running_avg_ppl
        }
        model_save_path = os.path.join(
            self.model_dir,
            'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b))
        self.best_path = model_save_path
        torch.save(state, model_save_path)

    def train_one_batch(self, batch, iter, train=True):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)
        dec_batch, _, _, _, _ = get_output_from_batch(batch)

        if config.noam:
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()

        ## Response encode
        mask_res = batch["posterior_batch"].data.eq(config.PAD_idx).unsqueeze(1)
        post_emb = self.embedding(batch["posterior_batch"])
        r_encoder_outputs = self.r_encoder(post_emb, mask_res)

        ## Encode
        num_sentences, enc_seq_len = enc_batch.size()
        batch_size = enc_lens.size(0)
        max_len = enc_lens.data.max().item()
        input_lengths = torch.sum(~enc_batch.data.eq(config.PAD_idx), dim=1)

        # word level encoder
        enc_emb = self.embedding(enc_batch)
        word_encoder_outputs, word_encoder_hidden = self.word_encoder(enc_emb, input_lengths)
        word_encoder_hidden = word_encoder_hidden.transpose(1, 0).reshape(num_sentences, -1)

        # pad and pack word_encoder_hidden
        start = torch.cumsum(torch.cat((enc_lens.data.new(1).zero_(), enc_lens[:-1])), 0)
        word_encoder_hidden = torch.stack(
            [pad(word_encoder_hidden.narrow(0, s, l), max_len)
             for s, l in zip(start.data.tolist(), enc_lens.data.tolist())], 0)
        # mask_src = ~(enc_padding_mask.bool()).unsqueeze(1)
        mask_src = (1 - enc_padding_mask.byte()).unsqueeze(1)

        # context level encoder
        if word_encoder_hidden.size(-1) != config.hidden_dim:
            word_encoder_hidden = self.linear(word_encoder_hidden)
        encoder_outputs = self.encoder(word_encoder_hidden, mask_src)

        # Decode
        sos_token = torch.LongTensor([config.SOS_idx] * batch_size).unsqueeze(1)
        if config.USE_CUDA:
            sos_token = sos_token.cuda()
        dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1)  # (batch, len, embedding)
        mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
        dec_emb = self.embedding(dec_batch_shift)
        pre_logit, attn_dist, mean, log_var, probs = self.decoder(
            dec_emb, encoder_outputs, r_encoder_outputs, (mask_src, mask_res, mask_trg))

        ## compute output dist
        logit = self.generator(pre_logit, attn_dist,
                               enc_batch_extend_vocab if config.pointer_gen else None,
                               extra_zeros, attn_dist_db=None)

        ## loss: NLL if ptr else Cross entropy
        sbow = dec_batch  # [batch, seq_len]
        seq_len = sbow.size(1)
        loss_rec = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                                  dec_batch.contiguous().view(-1))
        if config.model == "cvaetrs":
            loss_aux = 0
            for prob in probs:
                sbow_mask = _get_attn_subsequent_mask(seq_len).transpose(1, 2)
                # [batch, seq_len, seq_len]; note: the masked result is not assigned back to sbow
                sbow.unsqueeze(2).repeat(1, 1, seq_len).masked_fill_(sbow_mask, config.PAD_idx)
                loss_aux += self.criterion(prob.contiguous().view(-1, prob.size(-1)),
                                           sbow.contiguous().view(-1))
            kld_loss = gaussian_kld(mean["posterior"], log_var["posterior"], mean["prior"], log_var["prior"])
            kld_loss = torch.mean(kld_loss)
            kl_weight = min(math.tanh(6 * iter / config.full_kl_step - 3) + 1, 1)
            # kl_weight = min(iter / config.full_kl_step, 1) if config.full_kl_step > 0 else 1.0
            loss = loss_rec + config.kl_ceiling * kl_weight * kld_loss + config.aux_ceiling * loss_aux
            elbo = loss_rec + kld_loss
        else:
            loss = loss_rec
            elbo = loss_rec
            kld_loss = torch.Tensor([0])
            loss_aux = torch.Tensor([0])

        if train:
            loss.backward()
            # clip gradient
            nn.utils.clip_grad_norm_(self.parameters(), config.max_grad_norm)
            self.optimizer.step()
        return loss_rec.item(), math.exp(min(loss_rec.item(), 100)), kld_loss.item(), loss_aux.item(), elbo.item()

    def train_n_batch(self, batchs, iter, train=True):
        if config.noam:
            self.optimizer.optimizer.zero_grad()
        else:
            self.optimizer.zero_grad()
        for batch in batchs:
            enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)
            dec_batch, _, _, _, _ = get_output_from_batch(batch)

            ## Encode
            mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
            encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)
            meta = self.embedding(batch["program_label"])
            if config.dataset == "empathetic":
                meta = meta - meta

            # Decode
            sos_token = torch.LongTensor([config.SOS_idx] * enc_batch.size(0)).unsqueeze(1)
            if config.USE_CUDA:
                sos_token = sos_token.cuda()
            dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1)
            mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
            pre_logit, attn_dist, mean, log_var, probs = self.decoder(
                self.embedding(dec_batch_shift) + meta.unsqueeze(1), encoder_outputs, True, (mask_src, mask_trg))

            ## compute output dist
            logit = self.generator(pre_logit, attn_dist,
                                   enc_batch_extend_vocab if config.pointer_gen else None,
                                   extra_zeros, attn_dist_db=None)

            ## loss: NLL if ptr else Cross entropy
            sbow = dec_batch  # [batch, seq_len]
            seq_len = sbow.size(1)
            loss_rec = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                                      dec_batch.contiguous().view(-1))
            if config.model == "cvaetrs":
                loss_aux = 0
                for prob in probs:
                    sbow_mask = _get_attn_subsequent_mask(seq_len).transpose(1, 2)
                    # [batch, seq_len, seq_len]; note: the masked result is not assigned back to sbow
                    sbow.unsqueeze(2).repeat(1, 1, seq_len).masked_fill_(sbow_mask, config.PAD_idx)
                    loss_aux += self.criterion(prob.contiguous().view(-1, prob.size(-1)),
                                               sbow.contiguous().view(-1))
                kld_loss = gaussian_kld(mean["posterior"], log_var["posterior"], mean["prior"], log_var["prior"])
                kld_loss = torch.mean(kld_loss)
                kl_weight = min(math.tanh(6 * iter / config.full_kl_step - 3) + 1, 1)
                # kl_weight = min(iter / config.full_kl_step, 1) if config.full_kl_step > 0 else 1.0
                loss = loss_rec + config.kl_ceiling * kl_weight * kld_loss + config.aux_ceiling * loss_aux
                elbo = loss_rec + kld_loss
            else:
                loss = loss_rec
                elbo = loss_rec
                kld_loss = torch.Tensor([0])
                loss_aux = torch.Tensor([0])
            loss.backward()
            # clip gradient
            nn.utils.clip_grad_norm_(self.parameters(), config.max_grad_norm)
            self.optimizer.step()
        return loss_rec.item(), math.exp(min(loss_rec.item(), 100)), kld_loss.item(), loss_aux.item(), elbo.item()

    def decoder_greedy(self, batch, max_dec_step=50):
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)

        ## Encode
        num_sentences, enc_seq_len = enc_batch.size()
        batch_size = enc_lens.size(0)
        max_len = enc_lens.data.max().item()
        input_lengths = torch.sum(~enc_batch.data.eq(config.PAD_idx), dim=1)

        # word level encoder
        enc_emb = self.embedding(enc_batch)
        word_encoder_outputs, word_encoder_hidden = self.word_encoder(enc_emb, input_lengths)
        word_encoder_hidden = word_encoder_hidden.transpose(1, 0).reshape(num_sentences, -1)

        # pad and pack word_encoder_hidden
        start = torch.cumsum(torch.cat((enc_lens.data.new(1).zero_(), enc_lens[:-1])), 0)
        word_encoder_hidden = torch.stack(
            [pad(word_encoder_hidden.narrow(0, s, l), max_len)
             for s, l in zip(start.data.tolist(), enc_lens.data.tolist())], 0)
        mask_src = ~(enc_padding_mask.bool()).unsqueeze(1)

        # context level encoder
        if word_encoder_hidden.size(-1) != config.hidden_dim:
            word_encoder_hidden = self.linear(word_encoder_hidden)
        encoder_outputs = self.encoder(word_encoder_hidden, mask_src)

        ys = torch.ones(batch_size, 1).fill_(config.SOS_idx).long()
        if config.USE_CUDA:
            ys = ys.cuda()
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        decoded_words = []
        for i in range(max_dec_step + 1):
            out, attn_dist, _, _, _ = self.decoder(self.embedding(ys), encoder_outputs, None,
                                                   (mask_src, None, mask_trg))
            prob = self.generator(out, attn_dist, enc_batch_extend_vocab, extra_zeros, attn_dist_db=None)
            _, next_word = torch.max(prob[:, -1], dim=1)
            decoded_words.append(['<EOS>' if ni.item() == config.EOS_idx else self.vocab.index2word[ni.item()]
                                  for ni in next_word.view(-1)])
            if config.USE_CUDA:
                ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)
                ys = ys.cuda()
            else:
                ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)
            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

        sent = []
        for _, row in enumerate(np.transpose(decoded_words)):
            st = ''
            for e in row:
                if e == '<EOS>':
                    break
                else:
                    st += e + ' '
            sent.append(st)
        return sent
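# gaussian_kld above is assumed to be the closed-form KL divergence between two
# diagonal Gaussians, KL( N(mu_q, exp(logvar_q)) || N(mu_p, exp(logvar_p)) ),
# summed over the latent dimension. A hedged sketch (not necessarily the repo's exact signature):
import torch

def gaussian_kld_sketch(mu_q, logvar_q, mu_p, logvar_p):
    # q = posterior, p = prior; all tensors shaped (bsz, latent_dim)
    return 0.5 * torch.sum(
        logvar_p - logvar_q - 1.0
        + (torch.exp(logvar_q) + (mu_q - mu_p) ** 2) / torch.exp(logvar_p),
        dim=-1)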
class AOT(nn.Module): def __init__(self, vocab, model_file_path=None, is_eval=False, load_optim=False): super(AOT, self).__init__() self.vocab = vocab self.vocab_size = vocab.n_words self.embedding = share_embedding(self.vocab, config.pretrain_emb) self.encoder = Encoder(config.emb_dim, config.hidden_dim, num_layers=config.hop, num_heads=config.heads, total_key_depth=config.depth, total_value_depth=config.depth, filter_size=config.filter, universal=config.universal) self.sse = SSE(vocab, config.emb_dim, config.dropout, config.rnn_hidden_dim) self.rcr = RCR() ## multiple decoders self.decoder = Decoder(config.emb_dim, hidden_size=config.hidden_dim, num_layers=config.hop, num_heads=config.heads, total_key_depth=config.depth, total_value_depth=config.depth, filter_size=config.filter) self.generator = Generator(config.hidden_dim, self.vocab_size) if config.weight_sharing: # Share the weight matrix between target word embedding & the final logit dense layer self.generator.proj.weight = self.embedding.lut.weight self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx) if config.label_smoothing: self.criterion = LabelSmoothing(size=self.vocab_size, padding_idx=config.PAD_idx, smoothing=0.1) self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx) self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr) if config.noam: self.optimizer = NoamOpt( config.hidden_dim, 1, 8000, torch.optim.Adam(self.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) if model_file_path is not None: print("loading weights") state = torch.load(model_file_path, map_location=lambda storage, location: storage) self.encoder.load_state_dict(state['encoder_state_dict']) self.decoder.load_state_dict(state['decoder_state_dict']) self.generator.load_state_dict(state['generator_dict']) self.embedding.load_state_dict(state['embedding_dict']) if load_optim: self.optimizer.load_state_dict(state['optimizer']) self.eval() self.model_dir = config.save_path if not os.path.exists(self.model_dir): os.makedirs(self.model_dir) self.best_path = "" def save_model(self, running_avg_ppl, iter, f1_g, f1_b, ent_g, ent_b): state = { 'iter': iter, 'encoder_state_dict': self.encoder.state_dict(), 'decoder_state_dict': self.decoder.state_dict(), 'generator_dict': self.generator.state_dict(), 'embedding_dict': self.embedding.state_dict(), 'optimizer': self.optimizer.state_dict(), 'current_loss': running_avg_ppl } model_save_path = os.path.join( self.model_dir, 'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format( iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b)) self.best_path = model_save_path torch.save(state, model_save_path) def train_one_batch_slow(self, batch, iter, train=True): enc_batch = batch["review_batch"] enc_batch_extend_vocab = batch["review_ext_batch"] src_batch = batch[ 'reviews_batch'] # reviews sequence (bsz, r_num, r_len) src_mask = batch[ 'reviews_mask'] # indicate which review is fake(for padding). 
(bsz, r_num) src_length = batch['reviews_length'] # (bsz, r_num) enc_length_batch = batch[ 'reviews_length_list'] # 2-dim list, 0: len=bsz, 1: lens of reviews and pads src_labels = batch['reviews_label'] # (bsz, r_num) oovs = batch["oovs"] max_oov_length = len( sorted(oovs, key=lambda i: len(i), reverse=True)[0]) extra_zeros = Variable(torch.zeros( (enc_batch.size(0), max_oov_length))).to(config.device) dec_batch = batch["tags_batch"] dec_ext_batch = batch["tags_ext_batch"] dec_rank_batch = batch[ 'tags_idx_batch'] # tag indexes sequence (bsz, tgt_len) if config.noam: self.optimizer.optimizer.zero_grad() else: self.optimizer.zero_grad() # 1. Sentence-level Salience Estimation (SSE) cla_loss, sa_scores, sa_acc = self.sse.salience_estimate( src_batch, src_mask, src_length, src_labels) # sa_scores: (bsz, r_num) mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze( 1) # (bsz, src_len)->(bsz, 1, src_len) # emb_mask = self.embedding(batch["mask_context"]) # src_emb = self.embedding(enc_batch)+emb_mask src_emb = self.embedding(enc_batch) encoder_outputs = self.encoder(src_emb, mask_src) # (bsz, src_len, emb_dim) src_enc_rank = torch.FloatTensor([]).to( config.device) # (bsz, src_len, emb_dim) src_ext_rank = torch.LongTensor([]).to(config.device) # (bsz, src_len) aln_rank = torch.LongTensor([]).to( config.device) # (bsz, tgt_len, src_len) aln_mask_rank = torch.FloatTensor([]).to( config.device) # (bsz, tgt_len, src_len) bsz, max_src_len = enc_batch.size() for idx in range(bsz): # Custering (by k-means) and Ranking item_length = enc_length_batch[idx] reviews = torch.split(encoder_outputs[idx], item_length, dim=0) reviews_ext = torch.split(enc_batch_extend_vocab[idx], item_length, dim=0) r_vectors = [] # store the vector repr of each review rs_vectors = [] # store the token vectors repr of each review r_exts = [] r_pad_vec, r_ext_pad = None, None for r_idx in range(len(item_length)): if r_idx == len(item_length) - 1: r_pad_vec = reviews[r_idx] r_ext_pad = reviews_ext[r_idx] break r = self.rcr.hierarchical_pooling(reviews[r_idx].unsqueeze( 0)).squeeze(0).detach().cpu().numpy() * sa_scores[idx, r_idx] r_vectors.append(r) rs_vectors.append(reviews[r_idx]) r_exts.append(reviews_ext[r_idx]) rs_repr, ext_repr, srctgt_aln_mask, srctgt_aln = \ self.rcr.perform(r_vectors, rs_vectors, r_exts, r_pad_vec, r_ext_pad, dec_rank_batch[idx], max_src_len) # rs_repr: (max_rs_length, embed_dim); ext_repr: (max_rs_length); srctgt_aln_mask/srctgt_aln: (tgt_len, max_rs_length) src_enc_rank = torch.cat((src_enc_rank, rs_repr.unsqueeze(0)), dim=0) # (1->bsz, max_src_len, embed_dim) src_ext_rank = torch.cat((src_ext_rank, ext_repr.unsqueeze(0)), dim=0) # (1->bsz, max_src_len) aln_rank = torch.cat((aln_rank, srctgt_aln.unsqueeze(0)), dim=0) # (1->bsz, max_tgt_len, max_src_len) aln_mask_rank = torch.cat( (aln_mask_rank, srctgt_aln_mask.unsqueeze(0)), dim=0) del encoder_outputs, reviews, reviews_ext, r_vectors, rs_vectors, r_exts, r_pad_vec, r_ext_pad, rs_repr, ext_repr, srctgt_aln_mask, srctgt_aln torch.cuda.empty_cache() torch.backends.cuda.cufft_plan_cache.clear() ys = torch.LongTensor([config.SOS_idx] * enc_batch.size(0)).unsqueeze(1).to( config.device) # (bsz, 1) mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1) ys_rank = torch.LongTensor([1] * enc_batch.size(0)).unsqueeze(1).to( config.device) max_tgt_len = dec_batch.size(1) loss, loss_ppl = 0, 0 for t in range(max_tgt_len): aln_rank_cur = aln_rank[:, t, :].unsqueeze(1) # (bsz, 1, src_len) aln_mask_cur = aln_mask_rank[:, :(t + 1), :] # (bsz, src_len) pre_logit, 
attn_dist, aln_loss_cur = self.decoder( inputs=self.embedding(ys), inputs_rank=ys_rank, encoder_output=src_enc_rank, aln_rank=aln_rank_cur, aln_mask_rank=aln_mask_cur, mask=(mask_src, mask_trg), speed='slow') # todo if iter >= 13000: loss += (0.1 * aln_loss_cur) else: loss += aln_loss_cur logit = self.generator( pre_logit, attn_dist.unsqueeze(1), enc_batch_extend_vocab if config.pointer_gen else None, extra_zeros) if config.pointer_gen: loss += self.criterion( logit[:, -1, :].contiguous().view(-1, logit.size(-1)), dec_ext_batch[:, t].contiguous().view(-1)) else: loss += self.criterion( logit[:, -1, :].contiguous().view(-1, logit.size(-1)), dec_batch[:, t].contiguous().view(-1)) if config.label_smoothing: loss_ppl += self.criterion_ppl( logit[:, -1, :].contiguous().view(-1, logit.size(-1)), dec_ext_batch[:, t].contiguous().view(-1) if config.pointer_gen else dec_batch[:, t].contiguous().view(-1)) ys = torch.cat((ys, dec_batch[:, t].unsqueeze(1)), dim=1) mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1) ys_rank = torch.cat((ys_rank, dec_rank_batch[:, t].unsqueeze(1)), dim=1) loss = loss + cla_loss if train: loss /= max_tgt_len loss.backward() self.optimizer.step() if config.label_smoothing: loss_ppl /= max_tgt_len if torch.isnan(loss_ppl).sum().item() != 0 or torch.isinf( loss_ppl).sum().item() != 0: print("check") pdb.set_trace() return loss_ppl.item(), math.exp(min(loss_ppl.item(), 100)), cla_loss.item(), sa_acc else: return loss.item(), math.exp(min(loss.item(), 100)), cla_loss.item(), sa_acc def train_one_batch(self, batch, iter, train=True): enc_batch = batch["review_batch"] enc_batch_extend_vocab = batch["review_ext_batch"] src_batch = batch[ 'reviews_batch'] # reviews sequence (bsz, r_num, r_len) src_mask = batch[ 'reviews_mask'] # indicate which review is fake(for padding). (bsz, r_num) src_length = batch['reviews_length'] # (bsz, r_num) enc_length_batch = batch[ 'reviews_length_list'] # 2-dim list, 0: len=bsz, 1: lens of reviews and pads src_labels = batch['reviews_label'] # (bsz, r_num) oovs = batch["oovs"] max_oov_length = len( sorted(oovs, key=lambda i: len(i), reverse=True)[0]) extra_zeros = Variable(torch.zeros( (enc_batch.size(0), max_oov_length))).to(config.device) dec_batch = batch["tags_batch"] dec_ext_batch = batch["tags_ext_batch"] dec_rank_batch = batch[ 'tags_idx_batch'] # tag indexes sequence (bsz, tgt_len) if config.noam: self.optimizer.optimizer.zero_grad() else: self.optimizer.zero_grad() # 1. 
Sentence-level Salience Estimation (SSE) cla_loss, sa_scores, sa_acc = self.sse.salience_estimate( src_batch, src_mask, src_length, src_labels) # sa_scores: (bsz, r_num) mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze( 1) # (bsz, src_len)->(bsz, 1, src_len) # emb_mask = self.embedding(batch["mask_context"]) # src_emb = self.embedding(enc_batch)+emb_mask src_emb = self.embedding(enc_batch) encoder_outputs = self.encoder(src_emb, mask_src) # (bsz, src_len, emb_dim) src_enc_rank = torch.FloatTensor([]).to( config.device) # (bsz, src_len, emb_dim) src_ext_rank = torch.LongTensor([]).to(config.device) # (bsz, src_len) aln_rank = torch.LongTensor([]).to( config.device) # (bsz, tgt_len, src_len) aln_mask_rank = torch.FloatTensor([]).to( config.device) # (bsz, tgt_len, src_len) bsz, max_src_len = enc_batch.size() for idx in range(bsz): # Custering (by k-means) and Ranking item_length = enc_length_batch[idx] reviews = torch.split(encoder_outputs[idx], item_length, dim=0) reviews_ext = torch.split(enc_batch_extend_vocab[idx], item_length, dim=0) r_vectors = [] # store the vector repr of each review rs_vectors = [] # store the token vectors repr of each review r_exts = [] r_pad_vec, r_ext_pad = None, None for r_idx in range(len(item_length)): if r_idx == len(item_length) - 1: r_pad_vec = reviews[r_idx] r_ext_pad = reviews_ext[r_idx] break r = self.rcr.hierarchical_pooling(reviews[r_idx].unsqueeze( 0)).squeeze(0).detach().cpu().numpy() * sa_scores[idx, r_idx] r_vectors.append(r) rs_vectors.append(reviews[r_idx]) r_exts.append(reviews_ext[r_idx]) rs_repr, ext_repr, srctgt_aln_mask, srctgt_aln = \ self.rcr.perform(r_vectors, rs_vectors, r_exts, r_pad_vec, r_ext_pad, dec_rank_batch[idx], max_src_len) # rs_repr: (max_rs_length, embed_dim); ext_repr: (max_rs_length); srctgt_aln_mask/srctgt_aln: (tgt_len, max_rs_length) src_enc_rank = torch.cat((src_enc_rank, rs_repr.unsqueeze(0)), dim=0) # (1->bsz, max_src_len, embed_dim) src_ext_rank = torch.cat((src_ext_rank, ext_repr.unsqueeze(0)), dim=0) # (1->bsz, max_src_len) aln_rank = torch.cat((aln_rank, srctgt_aln.unsqueeze(0)), dim=0) # (1->bsz, max_tgt_len, max_src_len) aln_mask_rank = torch.cat( (aln_mask_rank, srctgt_aln_mask.unsqueeze(0)), dim=0) del encoder_outputs, reviews, reviews_ext, r_vectors, rs_vectors, r_exts, r_pad_vec, r_ext_pad, rs_repr, ext_repr, srctgt_aln_mask, srctgt_aln torch.cuda.empty_cache() torch.backends.cuda.cufft_plan_cache.clear() sos_token = torch.LongTensor([config.SOS_idx] * enc_batch.size(0)).unsqueeze(1).to( config.device) # (bsz, 1) dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1) # (bsz, tgt_len) mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1) sos_rank = torch.LongTensor([1] * enc_batch.size(0)).unsqueeze(1).to( config.device) dec_rank_batch = torch.cat((sos_rank, dec_rank_batch[:, :-1]), 1) aln_rank = aln_rank[:, :-1, :] aln_mask_rank = aln_mask_rank[:, :-1, :] pre_logit, attn_dist, aln_loss = self.decoder( inputs=self.embedding(dec_batch_shift), inputs_rank=dec_rank_batch, encoder_output=src_enc_rank, aln_rank=aln_rank, aln_mask_rank=aln_mask_rank, mask=(mask_src, mask_trg)) logit = self.generator( pre_logit, attn_dist, enc_batch_extend_vocab if config.pointer_gen else None, extra_zeros) if config.pointer_gen: loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)), dec_ext_batch.contiguous().view(-1)) else: loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1)) if config.label_smoothing: loss_ppl = self.criterion_ppl( 
logit.contiguous().view(-1, logit.size(-1)), dec_ext_batch.contiguous().view(-1) if config.pointer_gen else dec_batch.contiguous().view(-1)) if train: if iter >= 13000: loss = loss + (0.1 * aln_loss) + cla_loss else: loss = loss + aln_loss + cla_loss loss.backward() self.optimizer.step() if config.label_smoothing: if torch.isnan(loss_ppl).sum().item() != 0 or torch.isinf( loss_ppl).sum().item() != 0: print("check") pdb.set_trace() return loss_ppl.item(), math.exp(min(loss_ppl.item(), 100)), cla_loss.item(), sa_acc else: return loss.item(), math.exp(min(loss.item(), 100)), cla_loss.item(), sa_acc def compute_act_loss(self, module): R_t = module.remainders N_t = module.n_updates p_t = R_t + N_t avg_p_t = torch.sum(torch.sum(p_t, dim=1) / p_t.size(1)) / p_t.size(0) loss = config.act_loss_weight * avg_p_t.item() return loss def decoder_greedy(self, batch, max_dec_step=30): enc_batch = batch["review_batch"] enc_batch_extend_vocab = batch["review_ext_batch"] src_batch = batch[ 'reviews_batch'] # reviews sequence (bsz, r_num, r_len) src_mask = batch[ 'reviews_mask'] # indicate which review is fake (for padding). (bsz, r_num) src_length = batch['reviews_length'] # (bsz, r_num) enc_length_batch = batch[ 'reviews_length_list'] # 2-dim list, 0: len=bsz, 1: lens of reviews and pads src_labels = batch['reviews_label'] # (bsz, r_num) oovs = batch["oovs"] max_oov_length = len( sorted(oovs, key=lambda i: len(i), reverse=True)[0]) extra_zeros = Variable(torch.zeros( (enc_batch.size(0), max_oov_length))).to(config.device) # 1. Sentence-level Salience Estimation (SSE) cla_loss, sa_scores, sa_acc = self.sse.salience_estimate( src_batch, src_mask, src_length, src_labels) # sa_scores: (bsz, r_num) ## Encode - context mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze( 1) # (bsz, src_len)->(bsz, 1, src_len) # emb_mask = self.embedding(batch["mask_context"]) # src_emb = self.embedding(enc_batch) + emb_mask # todo eos or sentence embedding?? 
src_emb = self.embedding(enc_batch) encoder_outputs = self.encoder(src_emb, mask_src) # (bsz, src_len, emb_dim) src_enc_rank = torch.FloatTensor([]).to( config.device) # (bsz, src_len, emb_dim) src_ext_rank = torch.LongTensor([]).to(config.device) # (bsz, src_len) aln_rank = torch.LongTensor([]).to( config.device) # (bsz, tgt_len, src_len) aln_mask_rank = torch.FloatTensor([]).to( config.device) # (bsz, tgt_len, src_len) bsz, max_src_len = enc_batch.size() for idx in range(bsz): # Custering (by k-means) and Ranking item_length = enc_length_batch[idx] reviews = torch.split(encoder_outputs[idx], item_length, dim=0) reviews_ext = torch.split(enc_batch_extend_vocab[idx], item_length, dim=0) r_vectors = [] # store the vector repr of each review rs_vectors = [] # store the token vectors repr of each review r_exts = [] r_pad_vec, r_ext_pad = None, None for r_idx in range(len(item_length)): if r_idx == len(item_length) - 1: r_pad_vec = reviews[r_idx] r_ext_pad = reviews_ext[r_idx] break r = self.rcr.hierarchical_pooling(reviews[r_idx].unsqueeze( 0)).squeeze(0).detach().cpu().numpy() * sa_scores[idx, r_idx] r_vectors.append(r) rs_vectors.append(reviews[r_idx]) r_exts.append(reviews_ext[r_idx]) rs_repr, ext_repr, srctgt_aln_mask, srctgt_aln = self.rcr.perform( r_vecs=r_vectors, rs_vecs=rs_vectors, r_exts=r_exts, r_pad_vec=r_pad_vec, r_ext_pad=r_ext_pad, max_rs_length=max_src_len, train=False) # rs_repr: (max_rs_length, embed_dim); ext_repr: (max_rs_length); srctgt_aln_mask/srctgt_aln: (tgt_len, max_rs_length) src_enc_rank = torch.cat((src_enc_rank, rs_repr.unsqueeze(0)), dim=0) # (1->bsz, max_src_len, embed_dim) src_ext_rank = torch.cat((src_ext_rank, ext_repr.unsqueeze(0)), dim=0) # (1->bsz, max_src_len) aln_rank = torch.cat((aln_rank, srctgt_aln.unsqueeze(0)), dim=0) # (1->bsz, max_tgt_len, max_src_len) aln_mask_rank = torch.cat( (aln_mask_rank, srctgt_aln_mask.unsqueeze(0)), dim=0) del encoder_outputs, reviews, reviews_ext, r_vectors, rs_vectors, r_exts, r_pad_vec, r_ext_pad, rs_repr, ext_repr, srctgt_aln_mask, srctgt_aln torch.cuda.empty_cache() torch.backends.cuda.cufft_plan_cache.clear() # ys = torch.ones(1, 1).fill_(config.SOS_idx).long() ys = torch.zeros(enc_batch.size(0), 1).fill_(config.SOS_idx).long().to( config.device) # when testing, we set bsz into 1 mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1) ys_rank = torch.ones(enc_batch.size(0), 1).long().to(config.device) last_rank = torch.ones(enc_batch.size(0), 1).long().to(config.device) pred_attn_dist = torch.FloatTensor([]).to(config.device) decoded_words = [] for i in range(max_dec_step + 1): aln_rank_cur = aln_rank[:, last_rank.item(), :].unsqueeze( 1) # (bsz, src_len) if config.project: out, attn_dist, _ = self.decoder( inputs=self.embedding_proj_in(self.embedding(ys)), inputs_rank=ys_rank, encoder_output=self.embedding_proj_in(src_enc_rank), aln_rank=aln_rank_cur, aln_mask_rank=aln_mask_rank, # nouse mask=(mask_src, mask_trg), speed='slow') else: out, attn_dist, _ = self.decoder(inputs=self.embedding(ys), inputs_rank=ys_rank, encoder_output=src_enc_rank, aln_rank=aln_rank_cur, aln_mask_rank=aln_mask_rank, mask=(mask_src, mask_trg), speed='slow') prob = self.generator(out, attn_dist, enc_batch_extend_vocab, extra_zeros) _, next_word = torch.max(prob[:, -1], dim=1) # bsz=1, if test cur_words = [] for i_batch, ni in enumerate(next_word.view(-1)): if ni.item() == config.EOS_idx: cur_words.append('<EOS>') last_rank[i_batch] = 0 elif ni.item() in self.vocab.index2word: cur_words.append(self.vocab.index2word[ni.item()]) if ni.item() == 
config.SOS_idx: last_rank[i_batch] += 1 else: cur_words.append(oovs[i_batch][ ni.item() - self.vocab.n_words]) # output non-dict word next_word[i_batch] = config.UNK_idx # input unk word decoded_words.append(cur_words) # next_word = next_word.data[0] # if next_word.item() not in self.vocab.index2word: # next_word = torch.tensor(config.UNK_idx) # if config.USE_CUDA: ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1).to(config.device) ys_rank = torch.cat([ys_rank, last_rank], dim=1).to(config.device) # else: # ys = torch.cat([ys, next_word],dim=1) # ys_rank = torch.cat([ys_rank, last_rank],dim=1) # if config.USE_CUDA: # ys = torch.cat([ys, torch.zeros(enc_batch.size(0), 1).long().fill_(next_word).cuda()], dim=1) # ys = ys.cuda() # else: # ys = torch.cat([ys, torch.zeros(enc_batch.size(0), 1).long().fill_(next_word)], dim=1) mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1) if config.attn_analysis: pred_attn_dist = torch.cat( (pred_attn_dist, attn_dist.unsqueeze(1)), dim=1) sent = [] for _, row in enumerate(np.transpose(decoded_words)): st = '' for e in row: if e == '<EOS>': break else: st += e + ' ' sent.append(st) if config.attn_analysis: bsz, tgt_len, src_len = aln_mask_rank.size() pred_attn = pred_attn_dist[:, 1:, :].view(bsz * tgt_len, src_len) tgt_attn = aln_mask_rank.view(bsz * tgt_len, src_len) good_attn_sum = torch.masked_select( pred_attn, tgt_attn.bool()).sum() # pred_attn: bsz * tgt_len, src_len bad_attn_sum = torch.masked_select(pred_attn, ~tgt_attn.bool()).sum() bad_num = ~tgt_attn.bool() ratio = bad_num.sum() / tgt_attn.bool().sum() bad_attn_sum /= ratio good_attn = good_attn_sum[ 0] # last step (because this's already been the whole sentence length.). bad_attn = bad_attn_sum[1] good_attn /= (tgt_len * bsz) bad_attn /= (tgt_len * bsz) return sent, [good_attn, bad_attn] else: return sent
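# LabelSmoothing (smoothing=0.1) used throughout these models is assumed to follow
# the usual recipe: the gold token keeps 1 - smoothing of the probability mass, the
# remainder is spread uniformly over the vocabulary (PAD excluded), and the loss is
# the KL divergence against the model's log-probabilities. A minimal sketch:
import torch
import torch.nn as nn

class LabelSmoothingSketch(nn.Module):
    def __init__(self, size, padding_idx, smoothing=0.1):
        super().__init__()
        self.criterion = nn.KLDivLoss(reduction='sum')
        self.size = size                  # vocabulary size
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing

    def forward(self, log_probs, target):  # log_probs: (N, size), target: (N,)
        true_dist = log_probs.new_full(log_probs.size(), self.smoothing / (self.size - 2))
        true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        true_dist[:, self.padding_idx] = 0
        true_dist[target == self.padding_idx] = 0   # ignore padded positions entirely
        return self.criterion(log_probs, true_dist)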
class woRCR(nn.Module): def __init__(self, vocab, model_file_path=None, is_eval=False, load_optim=False): super(woRCR, self).__init__() self.vocab = vocab self.vocab_size = vocab.n_words self.embedding = share_embedding(self.vocab, config.pretrain_emb) self.encoder = Encoder(config.emb_dim, config.hidden_dim, num_layers=config.hop, num_heads=config.heads, total_key_depth=config.depth, total_value_depth=config.depth, filter_size=config.filter, universal=config.universal) self.sse = SSE(vocab, config.emb_dim, config.dropout, config.rnn_hidden_dim) self.rcr = RCR() self.decoder = Decoder(config.emb_dim, hidden_size=config.hidden_dim, num_layers=config.hop, num_heads=config.heads, total_key_depth=config.depth, total_value_depth=config.depth, filter_size=config.filter) self.generator = Generator(config.hidden_dim, self.vocab_size) if config.weight_sharing: # Share the weight matrix between target word embedding & the final logit dense layer self.generator.proj.weight = self.embedding.lut.weight self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx) if config.label_smoothing: self.criterion = LabelSmoothing(size=self.vocab_size, padding_idx=config.PAD_idx, smoothing=0.1) self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx) self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr) if config.noam: self.optimizer = NoamOpt( config.hidden_dim, 1, 8000, torch.optim.Adam(self.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) if model_file_path is not None: print("loading weights") state = torch.load(model_file_path, map_location=lambda storage, location: storage) self.encoder.load_state_dict(state['encoder_state_dict']) self.decoder.load_state_dict(state['decoder_state_dict']) self.generator.load_state_dict(state['generator_dict']) self.embedding.load_state_dict(state['embedding_dict']) if load_optim: self.optimizer.load_state_dict(state['optimizer']) self.eval() self.model_dir = config.save_path if not os.path.exists(self.model_dir): os.makedirs(self.model_dir) self.best_path = "" def save_model(self, running_avg_ppl, iter, f1_g, f1_b, ent_g, ent_b): state = { 'iter': iter, 'encoder_state_dict': self.encoder.state_dict(), 'decoder_state_dict': self.decoder.state_dict(), 'generator_dict': self.generator.state_dict(), 'embedding_dict': self.embedding.state_dict(), 'optimizer': self.optimizer.state_dict(), 'current_loss': running_avg_ppl } model_save_path = os.path.join( self.model_dir, 'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format( iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b)) self.best_path = model_save_path torch.save(state, model_save_path) def train_one_batch(self, batch, iter, train=True): enc_batch = batch["review_batch"] enc_batch_extend_vocab = batch["review_ext_batch"] src_batch = batch[ 'reviews_batch'] # reviews sequence (bsz, r_num, r_len) src_mask = batch[ 'reviews_mask'] # indicate which review is fake(for padding). (bsz, r_num) src_length = batch['reviews_length'] # (bsz, r_num) enc_length_batch = batch[ 'reviews_length_list'] # 2-dim list, 0: len=bsz, 1: lens of reviews and pads src_labels = batch['reviews_label'] # (bsz, r_num) oovs = batch["oovs"] max_oov_length = len( sorted(oovs, key=lambda i: len(i), reverse=True)[0]) extra_zeros = Variable(torch.zeros( (enc_batch.size(0), max_oov_length))).to(config.device) dec_batch = batch["tags_batch"] dec_ext_batch = batch["tags_ext_batch"] tid_batch = batch[ 'tags_idx_batch'] # tag indexes sequence (bsz, tgt_len) if config.noam: self.optimizer.optimizer.zero_grad() else: self.optimizer.zero_grad() # 1. 
Sentence-level Salience Estimation (SSE) cla_loss, sa_scores, sa_acc = self.sse.salience_estimate( src_batch, src_mask, src_length, src_labels) # sa_scores: (bsz, r_num) mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze( 1) # (bsz, src_len)->(bsz, 1, src_len) # emb_mask = self.embedding(batch["mask_context"]) # src_emb = self.embedding(enc_batch)+emb_mask src_emb = self.embedding(enc_batch) encoder_outputs = self.encoder(src_emb, mask_src) # (bsz, src_len, emb_dim) sos_token = torch.LongTensor([config.SOS_idx] * enc_batch.size(0)).unsqueeze(1).to( config.device) # (bsz, 1) dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1) # (bsz, tgt_len) mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1) pre_logit, attn_dist, aln_loss = self.decoder( inputs=self.embedding(dec_batch_shift), inputs_rank=tid_batch, encoder_output=encoder_outputs, mask=(mask_src, mask_trg)) logit = self.generator( pre_logit, attn_dist, enc_batch_extend_vocab if config.pointer_gen else None, extra_zeros) if config.pointer_gen: loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)), dec_ext_batch.contiguous().view(-1)) else: loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1)) if config.label_smoothing: loss_ppl = self.criterion_ppl( logit.contiguous().view(-1, logit.size(-1)), dec_ext_batch.contiguous().view(-1) if config.pointer_gen else dec_batch.contiguous().view(-1)) loss = loss + cla_loss if torch.isnan(loss).sum().item() != 0 or torch.isinf( loss).sum().item() != 0: print("check") pdb.set_trace() if train: loss.backward() self.optimizer.step() if config.label_smoothing: loss_ppl = loss_ppl.item() cla_loss = cla_loss.item() return loss_ppl, math.exp(min(loss_ppl, 100)), cla_loss, sa_acc else: return loss.item(), math.exp(min(loss.item(), 100)), cla_loss, sa_acc def decoder_greedy(self, batch, max_dec_step=30): enc_batch = batch["review_batch"] enc_batch_extend_vocab = batch["review_ext_batch"] src_batch = batch[ 'reviews_batch'] # reviews sequence (bsz, r_num, r_len) src_mask = batch[ 'reviews_mask'] # indicate which review is fake(for padding). (bsz, r_num) src_length = batch['reviews_length'] # (bsz, r_num) enc_length_batch = batch[ 'reviews_length_list'] # 2-dim list, 0: len=bsz, 1: lens of reviews and pads src_labels = batch['reviews_label'] # (bsz, r_num) oovs = batch["oovs"] max_oov_length = len( sorted(oovs, key=lambda i: len(i), reverse=True)[0]) extra_zeros = Variable(torch.zeros( (enc_batch.size(0), max_oov_length))).to(config.device) mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze( 1) # (bsz, src_len)->(bsz, 1, src_len) # emb_mask = self.embedding(batch["mask_context"]) # src_emb = self.embedding(enc_batch) + emb_mask # todo eos or sentence embedding?? 
src_emb = self.embedding(enc_batch) encoder_outputs = self.encoder(src_emb, mask_src) # (bsz, src_len, emb_dim) ys = torch.zeros(enc_batch.size(0), 1).fill_(config.SOS_idx).long().to( config.device) # when testing, we set bsz into 1 mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1) ys_rank = torch.ones(enc_batch.size(0), 1).long().to(config.device) last_rank = torch.ones(enc_batch.size(0), 1).long().to(config.device) decoded_words = [] for i in range(max_dec_step + 1): if config.project: out, attn_dist, aln_loss = self.decoder( inputs=self.embedding_proj_in(self.embedding(ys)), inputs_rank=ys_rank, encoder_output=self.embedding_proj_in(encoder_outputs), mask=(mask_src, mask_trg)) else: out, attn_dist, aln_loss = self.decoder( inputs=self.embedding(ys), inputs_rank=ys_rank, encoder_output=encoder_outputs, mask=(mask_src, mask_trg)) prob = self.generator( out, attn_dist, enc_batch_extend_vocab if config.pointer_gen else None, extra_zeros) _, next_word = torch.max(prob[:, -1], dim=1) # bsz=1 cur_words = [] for i_batch, ni in enumerate(next_word.view(-1)): if ni.item() == config.EOS_idx: cur_words.append('<EOS>') last_rank[i_batch] = 0 elif ni.item() in self.vocab.index2word: cur_words.append(self.vocab.index2word[ni.item()]) if ni.item() == config.SOS_idx: last_rank[i_batch] += 1 else: cur_words.append(oovs[i_batch][ni.item() - self.vocab.n_words]) decoded_words.append(cur_words) next_word = next_word.data[0] if next_word.item() not in self.vocab.index2word: next_word = torch.tensor(config.UNK_idx) ys = torch.cat([ ys, torch.zeros(enc_batch.size(0), 1).long().fill_(next_word).to( config.device) ], dim=1).to(config.device) ys_rank = torch.cat([ys_rank, last_rank], dim=1).to(config.device) # if config.USE_CUDA: # ys = torch.cat([ys, torch.zeros(enc_batch.size(0), 1).long().fill_(next_word).cuda()], dim=1) # ys = ys.cuda() # else: # ys = torch.cat([ys, torch.zeros(enc_batch.size(0), 1).long().fill_(next_word)], dim=1) mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1) sent = [] for _, row in enumerate(np.transpose(decoded_words)): st = '' for e in row: if e == '<EOS>': break else: st += e + ' ' sent.append(st) return sent
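# The Generator calls above (attn_dist + enc_batch_extend_vocab + extra_zeros) are
# assumed to implement the pointer-generator copy distribution of See et al. (2017):
# P(w) = p_gen * P_vocab(w) + (1 - p_gen) * sum of attention on source positions equal to w,
# computed over an extended vocabulary that appends the in-batch OOVs. A hedged sketch:
import torch

def copy_distribution(vocab_dist, attn_dist, p_gen, src_ext_ids, extra_zeros):
    # vocab_dist: (bsz, V), attn_dist: (bsz, src_len), p_gen: (bsz, 1),
    # src_ext_ids: (bsz, src_len) ids in the extended vocab, extra_zeros: (bsz, n_oov)
    weighted_vocab = p_gen * vocab_dist
    weighted_copy = (1.0 - p_gen) * attn_dist
    extended = torch.cat([weighted_vocab, extra_zeros], dim=1)   # room for OOV ids
    return extended.scatter_add(1, src_ext_ids, weighted_copy)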
def __init__(self, vocab, model_file_path=None, is_eval=False, load_optim=False):
    super(Seq2SPG, self).__init__()
    self.vocab = vocab
    self.vocab_size = vocab.n_words

    self.embedding = share_embedding(self.vocab, config.preptrained)
    self.encoder = nn.LSTM(config.emb_dim, config.hidden_dim, config.hop,
                           bidirectional=False, batch_first=True, dropout=0.2)
    self.encoder2decoder = nn.Linear(config.hidden_dim, config.hidden_dim)
    self.decoder = LSTMAttentionDot(config.emb_dim, config.hidden_dim, batch_first=True)
    self.memory = MLP(config.hidden_dim + config.emb_dim,
                      [config.private_dim1, config.private_dim2, config.private_dim3],
                      config.hidden_dim)
    self.dec_gate = nn.Linear(config.hidden_dim, 2 * config.hidden_dim)
    self.mem_gate = nn.Linear(config.hidden_dim, 2 * config.hidden_dim)
    self.generator = Generator(config.hidden_dim, self.vocab_size)
    self.hooks = {}  # Save the model structure of each task as masks of the parameters

    if config.weight_sharing:
        # Share the weight matrix between target word embedding & the final logit dense layer
        self.generator.proj.weight = self.embedding.weight

    self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx)
    if config.label_smoothing:
        self.criterion = LabelSmoothing(size=self.vocab_size, padding_idx=config.PAD_idx, smoothing=0.1)
        self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx)

    if is_eval:
        self.encoder = self.encoder.eval()
        self.encoder2decoder = self.encoder2decoder.eval()
        self.decoder = self.decoder.eval()
        self.generator = self.generator.eval()
        self.embedding = self.embedding.eval()
        self.memory = self.memory.eval()
        self.dec_gate = self.dec_gate.eval()
        self.mem_gate = self.mem_gate.eval()

    self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
    if config.noam:
        self.optimizer = NoamOpt(config.hidden_dim, 1, 4000,
                                 torch.optim.Adam(self.parameters(), lr=0,
                                                  betas=(0.9, 0.98), eps=1e-9))
    if config.use_sgd:
        self.optimizer = torch.optim.SGD(self.parameters(), lr=config.lr)

    if model_file_path is not None:
        print("loading weights")
        state = torch.load(model_file_path, map_location=lambda storage, location: storage)
        print("LOSS", state['current_loss'])
        self.encoder.load_state_dict(state['encoder_state_dict'])
        self.encoder2decoder.load_state_dict(state['encoder2decoder_state_dict'])
        self.decoder.load_state_dict(state['decoder_state_dict'])
        self.generator.load_state_dict(state['generator_dict'])
        self.embedding.load_state_dict(state['embedding_dict'])
        self.memory.load_state_dict(state['memory_dict'])
        self.dec_gate.load_state_dict(state['dec_gate_dict'])
        self.mem_gate.load_state_dict(state['mem_gate_dict'])
        if load_optim:
            self.optimizer.load_state_dict(state['optimizer'])

    if config.USE_CUDA:
        self.encoder = self.encoder.cuda()
        self.encoder2decoder = self.encoder2decoder.cuda()
        self.decoder = self.decoder.cuda()
        self.generator = self.generator.cuda()
        self.criterion = self.criterion.cuda()
        self.embedding = self.embedding.cuda()
        self.memory = self.memory.cuda()
        self.dec_gate = self.dec_gate.cuda()
        self.mem_gate = self.mem_gate.cuda()

    self.model_dir = config.save_path
    if not os.path.exists(self.model_dir):
        os.makedirs(self.model_dir)
    self.best_path = ""
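# The weight_sharing branch above ties the generator projection to the embedding table.
# This only works when the projection matrix has the same (vocab_size, dim) shape as the
# embeddings; otherwise an extra projection layer is needed. A small sketch with assumed sizes:
import torch.nn as nn

vocab_size, dim = 24000, 300            # stand-ins for vocab.n_words and the config dims
embedding = nn.Embedding(vocab_size, dim)
proj = nn.Linear(dim, vocab_size, bias=False)
proj.weight = embedding.weight          # both modules now share one (vocab_size, dim) tensor
assert proj.weight.data_ptr() == embedding.weight.data_ptr()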
class CvaeTrans(nn.Module): def __init__(self, vocab, emo_number, model_file_path=None, is_eval=False, load_optim=False): super(CvaeTrans, self).__init__() self.vocab = vocab self.vocab_size = vocab.n_words self.embedding = share_embedding(self.vocab, config.pretrain_emb) self.encoder = Encoder(config.emb_dim, config.hidden_dim, num_layers=config.hop, num_heads=config.heads, total_key_depth=config.depth, total_value_depth=config.depth, filter_size=config.filter, universal=config.universal) self.r_encoder = Encoder(config.emb_dim, config.hidden_dim, num_layers=config.hop, num_heads=config.heads, total_key_depth=config.depth, total_value_depth=config.depth, filter_size=config.filter, universal=config.universal) self.decoder = Decoder(config.emb_dim, hidden_size=config.hidden_dim, num_layers=config.hop, num_heads=config.heads, total_key_depth=config.depth, total_value_depth=config.depth, filter_size=config.filter) self.latent_layer = Latent(is_eval) self.bow = SoftmaxOutputLayer(config.hidden_dim, self.vocab_size) if config.multitask: self.emo = SoftmaxOutputLayer(config.hidden_dim, emo_number) self.emo_criterion = nn.NLLLoss() self.generator = Generator(config.hidden_dim, self.vocab_size) if config.weight_sharing: # Share the weight matrix between target word embedding & the final logit dense layer self.generator.proj.weight = self.embedding.lut.weight self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx) if model_file_path is not None: print("loading weights") state = torch.load(model_file_path, map_location=lambda storage, location: storage) self.encoder.load_state_dict(state['encoder_state_dict']) self.r_encoder.load_state_dict(state['r_encoder_state_dict']) self.decoder.load_state_dict(state['decoder_state_dict']) self.generator.load_state_dict(state['generator_dict']) self.embedding.load_state_dict(state['embedding_dict']) self.latent_layer.load_state_dict(state['latent_dict']) self.bow.load_state_dict(state['bow']) if (config.USE_CUDA): self.cuda() if is_eval: self.eval() else: self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr) if (config.noam): self.optimizer = NoamOpt( config.hidden_dim, 1, 8000, torch.optim.Adam(self.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) if (load_optim): self.optimizer.load_state_dict(state['optimizer']) if config.USE_CUDA: for state in self.optimizer.state.values(): for k, v in state.items(): if isinstance(v, torch.Tensor): state[k] = v.cuda() self.model_dir = config.save_path if not os.path.exists(self.model_dir): os.makedirs(self.model_dir) self.best_path = "" def save_model(self, running_avg_ppl, iter, f1_g, f1_b, ent_g, ent_b): state = { 'iter': iter, 'encoder_state_dict': self.encoder.state_dict(), 'r_encoder_state_dict': self.r_encoder.state_dict(), 'decoder_state_dict': self.decoder.state_dict(), 'generator_dict': self.generator.state_dict(), 'embedding_dict': self.embedding.state_dict(), 'latent_dict': self.latent_layer.state_dict(), 'bow': self.bow.state_dict(), 'optimizer': self.optimizer.state_dict(), 'current_loss': running_avg_ppl } model_save_path = os.path.join( self.model_dir, 'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format( iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b)) self.best_path = model_save_path torch.save(state, model_save_path) def train_one_batch(self, batch, iter, train=True): enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch( batch) dec_batch, _, _, _, _ = get_output_from_batch(batch) if (config.noam): self.optimizer.optimizer.zero_grad() else: 
self.optimizer.zero_grad() ## Response encode mask_res = batch["posterior_batch"].data.eq( config.PAD_idx).unsqueeze(1) posterior_mask = self.embedding(batch["posterior_mask"]) r_encoder_outputs = self.r_encoder( self.embedding(batch["posterior_batch"]) + posterior_mask, mask_res) ## Encode mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1) emb_mask = self.embedding(batch["input_mask"]) encoder_outputs = self.encoder( self.embedding(enc_batch) + emb_mask, mask_src) #latent variable if config.model == "cvaetrs": kld_loss, z = self.latent_layer(encoder_outputs[:, 0], r_encoder_outputs[:, 0], train=True) meta = self.embedding(batch["program_label"]) if config.dataset == "empathetic": meta = meta - meta # Decode sos_token = torch.LongTensor([config.SOS_idx] * enc_batch.size(0)).unsqueeze(1) if config.USE_CUDA: sos_token = sos_token.cuda() dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1) #(batch, len, embedding) mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1) input_vector = self.embedding(dec_batch_shift) if config.model == "cvaetrs": input_vector[:, 0] = input_vector[:, 0] + z + meta else: input_vector[:, 0] = input_vector[:, 0] + meta pre_logit, attn_dist = self.decoder(input_vector, encoder_outputs, (mask_src, mask_trg)) ## compute output dist logit = self.generator( pre_logit, attn_dist, enc_batch_extend_vocab if config.pointer_gen else None, extra_zeros, attn_dist_db=None) ## loss: NNL if ptr else Cross entropy loss_rec = self.criterion(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1)) if config.model == "cvaetrs": z_logit = self.bow(z + meta) # [batch_size, vocab_size] z_logit = z_logit.unsqueeze(1).repeat(1, logit.size(1), 1) loss_aux = self.criterion( z_logit.contiguous().view(-1, z_logit.size(-1)), dec_batch.contiguous().view(-1)) if config.multitask: emo_logit = self.emo(encoder_outputs[:, 0]) emo_loss = self.emo_criterion(emo_logit, batch["program_label"] - 9) #kl_weight = min(iter/config.full_kl_step, 0.28) if config.full_kl_step >0 else 1.0 kl_weight = min( math.tanh(6 * iter / config.full_kl_step - 3) + 1, 1) loss = loss_rec + config.kl_ceiling * kl_weight * kld_loss + config.aux_ceiling * loss_aux if config.multitask: loss = loss_rec + config.kl_ceiling * kl_weight * kld_loss + config.aux_ceiling * loss_aux + emo_loss aux = loss_aux.item() elbo = loss_rec + kld_loss else: loss = loss_rec elbo = loss_rec kld_loss = torch.Tensor([0]) aux = 0 if config.multitask: emo_logit = self.emo(encoder_outputs[:, 0]) emo_loss = self.emo_criterion(emo_logit, batch["program_label"] - 9) loss = loss_rec + emo_loss if (train): loss.backward() # clip gradient nn.utils.clip_grad_norm_(self.parameters(), config.max_grad_norm) self.optimizer.step() return loss_rec.item(), math.exp(min( loss_rec.item(), 100)), kld_loss.item(), aux, elbo.item() def decoder_greedy(self, batch, max_dec_step=50): enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch( batch) ## Encode mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1) emb_mask = self.embedding(batch["input_mask"]) meta = self.embedding(batch["program_label"]) if config.dataset == "empathetic": meta = meta - meta encoder_outputs = self.encoder( self.embedding(enc_batch) + emb_mask, mask_src) if config.model == "cvaetrs": kld_loss, z = self.latent_layer(encoder_outputs[:, 0], None, train=False) ys = torch.ones(enc_batch.shape[0], 1).fill_(config.SOS_idx).long() if config.USE_CUDA: ys = ys.cuda() mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1) decoded_words = 
[] for i in range(max_dec_step + 1): input_vector = self.embedding(ys) if config.model == "cvaetrs": input_vector[:, 0] = input_vector[:, 0] + z + meta else: input_vector[:, 0] = input_vector[:, 0] + meta out, attn_dist = self.decoder(input_vector, encoder_outputs, (mask_src, mask_trg)) prob = self.generator(out, attn_dist, enc_batch_extend_vocab, extra_zeros, attn_dist_db=None) _, next_word = torch.max(prob[:, -1], dim=1) decoded_words.append([ '<EOS>' if ni.item() == config.EOS_idx else self.vocab.index2word[ni.item()] for ni in next_word.view(-1) ]) if config.USE_CUDA: ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1) ys = ys.cuda() else: ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1) mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1) sent = [] for _, row in enumerate(np.transpose(decoded_words)): st = '' for e in row: if e == '<EOS>': break else: st += e + ' ' sent.append(st) return sent
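# The annealing term above, min(tanh(6 * iter / full_kl_step - 3) + 1, 1), is a smooth
# ramp on the KL weight: near 0 at iter 0, exactly 1 from iter = full_kl_step / 2 onward.
# A tiny sketch (the full_kl_step default here is only an example value):
import math

def kl_anneal_weight(step, full_kl_step=10000):
    return min(math.tanh(6 * step / full_kl_step - 3) + 1, 1)

# kl_anneal_weight(0) ~ 0.005, kl_anneal_weight(2500) ~ 0.10, kl_anneal_weight(5000) == 1.0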
class PGNet(nn.Module): ''' refer: https://github.com/atulkum/pointer_summarizer ''' def __init__(self, vocab, model_file_path=None, is_eval=False, load_optim=False): super(PGNet, self).__init__() self.vocab = vocab self.vocab_size = vocab.n_words self.embedding = share_embedding(self.vocab, config.pretrain_emb) self.encoder = Encoder() self.decoder = Decoder() self.reduce_state = ReduceState() self.generator = Generator(config.rnn_hidden_dim, self.vocab_size) if config.weight_sharing: # Share the weight matrix between target word embedding & the final logit dense layer self.generator.proj.weight = self.embedding.lut.weight self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx) if config.label_smoothing: self.criterion = LabelSmoothing(size=self.vocab_size, padding_idx=config.PAD_idx, smoothing=0.1) self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx) self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr) if config.noam: self.optimizer = NoamOpt( config.rnn_hidden_dim, 1, 8000, torch.optim.Adam(self.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) if model_file_path is not None: print("loading weights") state = torch.load(model_file_path, map_location=lambda storage, location: storage) self.encoder.load_state_dict(state['encoder_state_dict']) self.decoder.load_state_dict(state['decoder_state_dict']) self.generator.load_state_dict(state['generator_dict']) self.embedding.load_state_dict(state['embedding_dict']) if load_optim: self.optimizer.load_state_dict(state['optimizer']) self.eval() self.model_dir = config.save_path if not os.path.exists(self.model_dir): os.makedirs(self.model_dir) self.best_path = "" def save_model(self, running_avg_ppl, iter, f1_g, f1_b, ent_g, ent_b): state = { 'iter': iter, 'encoder_state_dict': self.encoder.state_dict(), 'decoder_state_dict': self.decoder.state_dict(), 'generator_dict': self.generator.state_dict(), 'embedding_dict': self.embedding.state_dict(), 'optimizer': self.optimizer.state_dict(), 'current_loss': running_avg_ppl } model_save_path = os.path.join( self.model_dir, 'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format( iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b)) self.best_path = model_save_path torch.save(state, model_save_path) def train_one_batch(self, batch, iter, train=True): enc_batch = batch["review_batch"] enc_lens = batch["review_length"] enc_batch_extend_vocab = batch["review_ext_batch"] oovs = batch["oovs"] max_oov_length = len( sorted(oovs, key=lambda i: len(i), reverse=True)[0]) dec_batch = batch["tags_batch"] dec_ext_batch = batch["tags_ext_batch"] max_tgt_len = dec_batch.size(0) if config.noam: self.optimizer.optimizer.zero_grad() else: self.optimizer.zero_grad() ## Embedding - context mask_src = enc_batch mask_src = ~(mask_src.data.eq(config.PAD_idx)) # emb_mask = self.embedding(batch["mask_context"]) # src_emb = self.embedding(enc_batch)+emb_mask src_emb = self.embedding(enc_batch) encoder_outputs, encoder_feature, encoder_hidden = self.encoder( src_emb, enc_lens) # reduce bidirectional hidden to one hidden (h and c) s_t_1 = self.reduce_state(encoder_hidden) # 1 x b x hidden_dim c_t_1 = Variable( torch.zeros((enc_batch.size(0), 2 * config.rnn_hidden_dim))).to(config.device) coverage = Variable(torch.zeros(enc_batch.size())).to(config.device) extra_zeros = Variable(torch.zeros( (enc_batch.size(0), max_oov_length))).to(config.device) sos_token = torch.LongTensor([config.SOS_idx] * enc_batch.size(0)).unsqueeze(1).to( config.device) # (bsz, 1) dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1) # 
(bsz, tgt_len) dec_batch_embd = self.embedding(dec_batch_shift) step_losses = [] step_loss_ppls = 0 for di in range(max_tgt_len): y_t_1 = dec_batch_embd[:, di, :] logit, s_t_1, c_t_1, attn_dist, next_coverage, p_gen = self.decoder( y_t_1, s_t_1, encoder_outputs, encoder_feature, mask_src, c_t_1, coverage, di) logit = self.generator( logit, attn_dist, enc_batch_extend_vocab if config.pointer_gen else None, extra_zeros, 1, p_gen) if config.pointer_gen: step_loss = self.criterion( logit.contiguous().view(-1, logit.size(-1)), dec_ext_batch[:, di].contiguous().view(-1)) else: step_loss = self.criterion( logit.contiguous().view(-1, logit.size(-1)), dec_batch[:, di].contiguous().view(-1)) if config.label_smoothing: step_loss_ppl = self.criterion_ppl( logit.contiguous().view(-1, logit.size(-1)), dec_batch[:, di].contiguous().view(-1)) step_loss_ppls += step_loss_ppl if config.is_coverage: # coverage loss step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1) # loss sum step_loss = step_loss + config.cov_loss_wt * step_coverage_loss # update coverage coverage = next_coverage step_losses.append(step_loss) if config.is_coverage: sum_losses = torch.sum(torch.stack(step_losses, 1), 1) batch_avg_loss = sum_losses / batch['tags_length'].float() loss = torch.mean(batch_avg_loss) else: loss = sum(step_losses) / max_tgt_len if config.label_smoothing: loss_ppl = (step_loss_ppls / max_tgt_len).item() if train: loss.backward() self.optimizer.step() if config.label_smoothing: return loss_ppl, math.exp(min(loss_ppl, 100)), 0, 0 else: return loss.item(), math.exp(min(loss.item(), 100)), 0, 0 def decoder_greedy(self, batch, max_dec_step=30): enc_batch = batch["review_batch"] enc_lens = batch["review_length"] enc_batch_extend_vocab = batch["review_ext_batch"] oovs = batch["oovs"] max_oov_length = len( sorted(oovs, key=lambda i: len(i), reverse=True)[0]) dec_batch = batch["tags_batch"] dec_ext_batch = batch["tags_ext_batch"] max_tgt_len = dec_batch.size(0) ## Embedding - context mask_src = enc_batch mask_src = ~(mask_src.data.eq(config.PAD_idx)) # emb_mask = self.embedding(batch["mask_context"]) # src_emb = self.embedding(enc_batch)+emb_mask src_emb = self.embedding(enc_batch) encoder_outputs, encoder_feature, encoder_hidden = self.encoder( src_emb, enc_lens) # reduce bidirectional hidden to one hidden (h and c) s_t_1 = self.reduce_state(encoder_hidden) # 1 x b x hidden_dim c_t_1 = Variable( torch.zeros((enc_batch.size(0), 2 * config.rnn_hidden_dim))).to(config.device) coverage = Variable(torch.zeros(enc_batch.size())).to(config.device) extra_zeros = Variable(torch.zeros( (enc_batch.size(0), max_oov_length))).to(config.device) # ys = torch.ones(1, 1).fill_(config.SOS_idx).long() ys = torch.zeros(enc_batch.size(0)).fill_(config.SOS_idx).long().to( config.device) # when testing, we set bsz into 1 decoded_words = [] for i in range(max_dec_step + 1): logit, s_t_1, c_t_1, attn_dist, next_coverage, p_gen = self.decoder( self.embedding(ys), s_t_1, encoder_outputs, encoder_feature, mask_src, c_t_1, coverage, i) prob = self.generator( logit, attn_dist, enc_batch_extend_vocab if config.pointer_gen else None, extra_zeros, 1, p_gen) _, next_word = torch.max(prob, dim=1) # bsz=1 cur_words = [] for i_batch, ni in enumerate(next_word.view(-1)): if ni.item() == config.EOS_idx: cur_words.append('<EOS>') elif ni.item() in self.vocab.index2word: cur_words.append(self.vocab.index2word[ni.item()]) else: cur_words.append(oovs[i_batch][ni.item() - self.vocab.n_words]) decoded_words.append(cur_words) next_word = 
next_word.data[0] if next_word.item() not in self.vocab.index2word: next_word = torch.tensor(config.UNK_idx) ys = torch.zeros(enc_batch.size(0)).long().fill_(next_word).to( config.device) sent = [] for _, row in enumerate(np.transpose(decoded_words)): st = '' for e in row: if e == '<EOS>': break else: st += e + ' ' sent.append(st) return sent
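# Sketch, for reference only: the copy-mechanism mixing that the call
# self.generator(logit, attn_dist, enc_batch_extend_vocab, extra_zeros, 1, p_gen)
# above is assumed to perform when config.pointer_gen is enabled. The function
# below is an illustrative stand-alone version (its name and signature are not
# part of this repo) and relies on the module-level torch import:
#   P(w) = p_gen * P_vocab(w) + (1 - p_gen) * attention mass on source copies of w,
# with extra_zeros widening the vocabulary for in-batch OOV ids.
def pointer_generator_dist(vocab_dist, attn_dist, p_gen, enc_batch_extend_vocab, extra_zeros):
    # vocab_dist: (bsz, V) distribution over the fixed vocabulary
    # attn_dist:  (bsz, src_len) attention over source positions
    # p_gen:      (bsz, 1) generation probability
    # enc_batch_extend_vocab: (bsz, src_len) source token ids in the extended vocab
    # extra_zeros: (bsz, max_oov) zero columns appended for in-batch OOV words
    weighted_vocab = p_gen * vocab_dist
    weighted_attn = (1.0 - p_gen) * attn_dist
    extended = torch.cat([weighted_vocab, extra_zeros], dim=1)
    return extended.scatter_add(1, enc_batch_extend_vocab, weighted_attn)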
class Transformer(nn.Module): def __init__(self, vocab, emo_number, model_file_path=None, is_eval=False, load_optim=False): super(Transformer, self).__init__() self.vocab = vocab self.vocab_size = vocab.n_words self.embedding = share_embedding(self.vocab, config.pretrain_emb) self.encoder = Encoder(config.emb_dim, config.hidden_dim, num_layers=config.hop, num_heads=config.heads, total_key_depth=config.depth, total_value_depth=config.depth, filter_size=config.filter, universal=config.universal) self.decoder = Decoder(config.emb_dim, hidden_size=config.hidden_dim, num_layers=config.hop, num_heads=config.heads, total_key_depth=config.depth, total_value_depth=config.depth, filter_size=config.filter) self.generator = Generator(config.hidden_dim, self.vocab_size) if config.weight_sharing: # Share the weight matrix between target word embedding & the final logit dense layer self.generator.proj.weight = self.embedding.lut.weight self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx) self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr) if (config.noam): self.optimizer = NoamOpt( config.hidden_dim, 1, 8000, torch.optim.Adam(self.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) if model_file_path is not None: print("loading weights") state = torch.load(model_file_path, map_location=lambda storage, location: storage) self.encoder.load_state_dict(state['encoder_state_dict']) self.decoder.load_state_dict(state['decoder_state_dict']) self.generator.load_state_dict(state['generator_dict']) self.embedding.load_state_dict(state['embedding_dict']) if (load_optim): self.optimizer.load_state_dict(state['optimizer']) self.eval() self.model_dir = config.save_path if not os.path.exists(self.model_dir): os.makedirs(self.model_dir) self.best_path = "" def save_model(self, running_avg_ppl, iter, f1_g, f1_b, ent_g, ent_b): state = { 'iter': iter, 'encoder_state_dict': self.encoder.state_dict(), 'decoder_state_dict': self.decoder.state_dict(), 'generator_dict': self.generator.state_dict(), 'embedding_dict': self.embedding.state_dict(), 'optimizer': self.optimizer.state_dict(), 'current_loss': running_avg_ppl } model_save_path = os.path.join( self.model_dir, 'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format( iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b)) self.best_path = model_save_path torch.save(state, model_save_path) def train_one_batch(self, batch, iter, train=True): enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch( batch) dec_batch, _, _, _, _ = get_output_from_batch(batch) if (config.noam): self.optimizer.optimizer.zero_grad() else: self.optimizer.zero_grad() ## Encode mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1) meta = self.embedding(batch["program_label"]) emb_mask = self.embedding(batch["input_mask"]) encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src) # Decode sos_token = torch.LongTensor([config.SOS_idx] * enc_batch.size(0)).unsqueeze(1) if config.USE_CUDA: sos_token = sos_token.cuda() dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1) mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1) pre_logit, attn_dist = self.decoder( self.embedding(dec_batch_shift) + meta.unsqueeze(1), encoder_outputs, (mask_src, mask_trg)) #+meta.unsqueeze(1) ## compute output dist logit = self.generator( pre_logit, attn_dist, enc_batch_extend_vocab if config.pointer_gen else None, extra_zeros, attn_dist_db=None) #logit = F.log_softmax(logit,dim=-1) #fix the name later ## loss: NNL if ptr else Cross entropy loss = 
self.criterion(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1)) if (train): loss.backward() self.optimizer.step() return loss.item(), math.exp(min(loss.item(), 100)), 0 def compute_act_loss(self, module): R_t = module.remainders N_t = module.n_updates p_t = R_t + N_t avg_p_t = torch.sum(torch.sum(p_t, dim=1) / p_t.size(1)) / p_t.size(0) loss = config.act_loss_weight * avg_p_t.item() return loss def decoder_greedy(self, batch, max_dec_step=50): enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch( batch) mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1) emb_mask = self.embedding(batch["input_mask"]) meta = self.embedding(batch["program_label"]) encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src) ys = torch.ones(enc_batch.shape[0], 1).fill_(config.SOS_idx).long() if config.USE_CUDA: ys = ys.cuda() # print('=====================ys========================') # print(ys) mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1) decoded_words = [] for i in range(max_dec_step + 1): out, attn_dist = self.decoder( self.embedding(ys) + meta.unsqueeze(1), encoder_outputs, (mask_src, mask_trg)) prob = self.generator(out, attn_dist, enc_batch_extend_vocab, extra_zeros, attn_dist_db=None) _, next_word = torch.max(prob[:, -1], dim=1) # print('=====================next_word1========================') # print(next_word) decoded_words.append([ '<EOS>' if ni.item() == config.EOS_idx else self.vocab.index2word[ni.item()] for ni in next_word.view(-1) ]) #next_word = next_word.data[0] # print('=====================next_word2========================') # print(next_word) if config.USE_CUDA: # print('=====================shape========================') # print(ys.shape, next_word.shape) ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1) ys = ys.cuda() else: ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1) mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1) # print('=====================new_ys========================') # print(ys) sent = [] for _, row in enumerate(np.transpose(decoded_words)): st = '' for e in row: if e == '<EOS>': break else: st += e + ' ' sent.append(st) return sent def decoder_greedy_po(self, batch, max_dec_step=50): enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch( batch) mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1) emb_mask = self.embedding(batch["input_mask"]) meta = self.embedding(batch["program_label"]) encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src) ys = torch.ones(enc_batch.shape[0], 1).fill_(config.SOS_idx).long() if config.USE_CUDA: ys = ys.cuda() # print('=====================ys========================') # print(ys) mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1) decoded_words = [] for i in range(max_dec_step + 1): out, attn_dist = self.decoder( self.embedding(ys) + meta.unsqueeze(1), encoder_outputs, (mask_src, mask_trg)) prob = self.generator(out, attn_dist, enc_batch_extend_vocab, extra_zeros, attn_dist_db=None) _, next_word = torch.max(prob[:, -1], dim=1) # print('=====================next_word1========================') # print(next_word) decoded_words.append([ '<EOS>' if ni.item() == config.EOS_idx else self.vocab.index2word[ni.item()] for ni in next_word.view(-1) ]) #next_word = next_word.data[0] # print('=====================next_word2========================') # print(next_word) if config.USE_CUDA: # print('=====================shape========================') # print(ys.shape, next_word.shape) ys = torch.cat([ys, 
next_word.unsqueeze(1)], dim=1) ys = ys.cuda() else: ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1) mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1) # print('=====================new_ys========================') # print(ys) sent = [] for _, row in enumerate(np.transpose(decoded_words)): st = '' for e in row: if e == '<EOS>': break else: st += e + ' ' sent.append(st) return sent
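# Sketch: the teacher-forcing shift used by train_one_batch in both PGNet and
# Transformer above. The decoder input is the gold sequence shifted right by one
# position, starting from SOS, so the decoder at step t is trained to predict
# token t of dec_batch. Illustrative helper (not repo API); relies on the
# module-level torch import.
def shift_right_with_sos(dec_batch, sos_idx):
    # dec_batch: (bsz, tgt_len) gold target ids
    sos = torch.full((dec_batch.size(0), 1), sos_idx,
                     dtype=dec_batch.dtype, device=dec_batch.device)
    return torch.cat((sos, dec_batch[:, :-1]), dim=1)   # (bsz, tgt_len)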
writer = SummaryWriter(log_dir=config.save_path)

# Build model, optimizer, and set states
if not (config.load_frompretrain == 'None'):
    meta_net = Seq2SPG(p.vocab, model_file_path=config.load_frompretrain, is_eval=False)
else:
    meta_net = Seq2SPG(p.vocab)

if config.meta_optimizer == 'sgd':
    meta_optimizer = torch.optim.SGD(meta_net.parameters(), lr=config.meta_lr)
elif config.meta_optimizer == 'adam':
    meta_optimizer = torch.optim.Adam(meta_net.parameters(), lr=config.meta_lr)
elif config.meta_optimizer == 'noam':
    meta_optimizer = NoamOpt(
        config.hidden_dim, 1, 4000,
        torch.optim.Adam(meta_net.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
else:
    raise ValueError

meta_batch_size = config.meta_batch_size
tasks = p.get_personas('train')
steps = (len(tasks) // meta_batch_size) + int(len(tasks) % meta_batch_size != 0)

# meta early stop
patience = 10
if config.fix_dialnum_train:
    patience = 100
best_loss = 10000000
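# Sketch of the learning-rate schedule that NoamOpt(d_model, factor, warmup, Adam(...))
# above is assumed to implement (the "Attention Is All You Need" schedule):
#   lr(step) = factor * d_model**-0.5 * min(step**-0.5, step * warmup**-1.5)
# i.e. linear warm-up for `warmup` steps, then inverse-square-root decay.
# Stand-alone for reference; not used by the training loop itself.
def noam_lr(step, d_model=config.hidden_dim, factor=1.0, warmup=4000):
    step = max(step, 1)
    return factor * (d_model ** -0.5) * min(step ** -0.5, step * (warmup ** -1.5))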
class VAE(nn.Module): def __init__(self, vocab, model_file_path=None, is_eval=False, load_optim=False): super(VAE, self).__init__() self.vocab = vocab self.vocab_size = vocab.n_words self.embedding = share_embedding(self.vocab,config.preptrained) self.encoder = nn.LSTM(config.emb_dim, config.hidden_dim, config.hop, bidirectional=False, batch_first=True, dropout=0.2) self.encoder_r = nn.LSTM(config.emb_dim, config.hidden_dim, config.hop, bidirectional=False, batch_first=True, dropout=0.2) self.represent = R_MLP(2 * config.hidden_dim, 68) self.prior = P_MLP(config.hidden_dim, 68) self.mlp_b = nn.Linear(config.hidden_dim + 68, self.vocab_size) self.encoder2decoder = nn.Linear( config.hidden_dim + 68, config.hidden_dim) self.decoder = LSTMAttentionDot(config.emb_dim, config.hidden_dim, batch_first=True) self.generator = Generator(config.hidden_dim,self.vocab_size) if config.weight_sharing: # Share the weight matrix between target word embedding & the final logit dense layer self.generator.proj.weight = self.embedding.weight self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx) if (config.label_smoothing): self.criterion = LabelSmoothing(size=self.vocab_size, padding_idx=config.PAD_idx, smoothing=0.1) self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx) if is_eval: self.encoder = self.encoder.eval() self.encoder_r = self.encoder_r.eval() self.represent = self.represent.eval() self.prior = self.prior.eval() self.mlp_b = self.mlp_b.eval() self.encoder2decoder = self.encoder2decoder.eval() self.decoder = self.decoder.eval() self.generator = self.generator.eval() self.embedding = self.embedding.eval() self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr) if(config.noam): self.optimizer = NoamOpt(config.hidden_dim, 1, 4000, torch.optim.Adam(self.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) if config.use_sgd: self.optimizer = torch.optim.SGD(self.parameters(), lr=config.lr) if model_file_path is not None: print("loading weights") state = torch.load(model_file_path, map_location= lambda storage, location: storage) print("LOSS",state['current_loss']) self.encoder.load_state_dict(state['encoder_state_dict']) self.encoder_r.load_state_dict(state['encoder_r_state_dict']) self.represent.load_state_dict(state['represent_state_dict']) self.prior.load_state_dict(state['prior_state_dict']) self.mlp_b.load_state_dict(state['mlp_b_state_dict']) self.encoder2decoder.load_state_dict(state['encoder2decoder_state_dict']) self.decoder.load_state_dict(state['decoder_state_dict']) self.generator.load_state_dict(state['generator_dict']) self.embedding.load_state_dict(state['embedding_dict']) if (load_optim): self.optimizer.load_state_dict(state['optimizer']) if (config.USE_CUDA): self.encoder = self.encoder.cuda() self.encoder_r = self.encoder_r.cuda() self.represent = self.represent.cuda() self.prior = self.prior.cuda() self.mlp_b = self.mlp_b.cuda() self.encoder2decoder = self.encoder2decoder.cuda() self.decoder = self.decoder.cuda() self.generator = self.generator.cuda() self.criterion = self.criterion.cuda() self.embedding = self.embedding.cuda() self.model_dir = config.save_path if not os.path.exists(self.model_dir): os.makedirs(self.model_dir) self.best_path = "" def save_model(self, running_avg_ppl, iter, f1_g,f1_b,ent_g,ent_b, log=False, d="save/paml_model_sim"): state = { 'iter': iter, 'encoder_state_dict': self.encoder.state_dict(), 'encoder_r_state_dict': self.encoder_r.state_dict(), 'represent_state_dict': self.represent.state_dict(), 'prior_state_dict': self.prior.state_dict(), 
'mlp_b_state_dict': self.mlp_b.state_dict(), 'encoder2decoder_state_dict': self.encoder2decoder.state_dict(), 'decoder_state_dict': self.decoder.state_dict(), 'generator_dict': self.generator.state_dict(), 'embedding_dict': self.embedding.state_dict(), #'optimizer': self.optimizer.state_dict(), 'current_loss': running_avg_ppl } if log: model_save_path = os.path.join(d, 'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(iter,running_avg_ppl,f1_g,f1_b,ent_g,ent_b) ) else: model_save_path = os.path.join(self.model_dir, 'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(iter,running_avg_ppl,f1_g,f1_b,ent_g,ent_b) ) self.best_path = model_save_path torch.save(state, model_save_path) def get_state(self, batch): """Get cell states and hidden states.""" batch_size = batch.size(0) \ if self.encoder.batch_first else batch.size(1) h0_encoder = Variable(torch.zeros( self.encoder.num_layers, batch_size, config.hidden_dim ), requires_grad=False) c0_encoder = Variable(torch.zeros( self.encoder.num_layers, batch_size, config.hidden_dim ), requires_grad=False) return h0_encoder.cuda(), c0_encoder.cuda() def train_one_batch(self, batch, train=True): ## pad and other stuff enc_batch, _, enc_lens, enc_batch_extend_vocab, extra_zeros, _, _, _ = get_input_from_batch(batch) dec_batch, _, _, _, _, _ = get_output_from_batch(batch) if(config.noam): self.optimizer.optimizer.zero_grad() else: self.optimizer.zero_grad() ## Encode self.h0_encoder, self.c0_encoder = self.get_state(enc_batch) src_h, (src_h_t, src_c_t) = self.encoder( self.embedding(enc_batch), (self.h0_encoder, self.c0_encoder)) h_t = src_h_t[-1] c_t = src_c_t[-1] self.h0_encoder_r, self.c0_encoder_r = self.get_state(dec_batch) src_h_r, (src_h_t_r, src_c_t_r) = self.encoder_r( self.embedding(dec_batch), (self.h0_encoder_r, self.c0_encoder_r)) h_t_r = src_h_t_r[-1] c_t_r = src_c_t_r[-1] #sample and reparameter z_sample, mu, var = self.represent(torch.cat((h_t_r, h_t), 1)) p_z_sample, p_mu, p_var = self.prior(h_t) # Decode decoder_init_state = nn.Tanh()(self.encoder2decoder(torch.cat((z_sample, h_t), 1))) sos_token = torch.LongTensor([config.SOS_idx] * enc_batch.size(0)).unsqueeze(1) if config.USE_CUDA: sos_token = sos_token.cuda() dec_batch_shift = torch.cat((sos_token,dec_batch[:, :-1]),1) target_embedding = self.embedding(dec_batch_shift) ctx = src_h.transpose(0, 1) trg_h, (_, _) = self.decoder( target_embedding, (decoder_init_state, c_t), ctx ) pre_logit = trg_h logit = self.generator(pre_logit) ## loss: NNL if ptr else Cross entropy re_loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1)) kl_losses = 0.5 * torch.sum(torch.exp(var - p_var) + (mu - p_mu) ** 2 / torch.exp(p_var) - 1. - var + p_var, 1) kl_loss = torch.mean(kl_losses) latent_logit = self.mlp_b(torch.cat((z_sample, h_t), 1)).unsqueeze(1) latent_logit = F.log_softmax(latent_logit,dim=-1) latent_logits = latent_logit.repeat(1, logit.size(1), 1) bow_loss = self.criterion(latent_logits.contiguous().view(-1, latent_logits.size(-1)), dec_batch.contiguous().view(-1)) loss = re_loss + 0.48 * kl_loss + bow_loss if(train): loss.backward() self.optimizer.step() if(config.label_smoothing): s_loss = self.criterion_ppl(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1)) return s_loss.item(), math.exp(min(s_loss.item(), 100)), loss.item(), re_loss.item(), kl_loss.item(), bow_loss.item()
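# Sketch: the closed-form KL divergence KL(q || p) between diagonal Gaussians
# that the kl_losses expression in train_one_batch above computes, where
# var / p_var hold the log-variances of the recognition (posterior) and prior
# networks respectively:
#   KL = 0.5 * sum_d( exp(var - p_var) + (mu - p_mu)^2 / exp(p_var) - 1 - var + p_var )
# Illustrative stand-alone version (name is not repo API):
def diag_gaussian_kl(mu_q, logvar_q, mu_p, logvar_p):
    return 0.5 * torch.sum(
        torch.exp(logvar_q - logvar_p)
        + (mu_q - mu_p) ** 2 / torch.exp(logvar_p)
        - 1.0
        - logvar_q
        + logvar_p,
        dim=1)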
def __init__(self, vocab, decoder_number, model_file_path=None, is_eval=False, load_optim=False): super().__init__() self.vocab = vocab self.vocab_size = vocab.n_words self.embedding = share_embedding(self.vocab, config.pretrain_emb) self.encoder = Encoder(config.emb_dim, config.hidden_dim, num_layers=config.hop, num_heads=config.heads, total_key_depth=config.depth, total_value_depth=config.depth, filter_size=config.filter, universal=config.universal) self.decoder_number = decoder_number self.decoder = DecoderContextV(config.emb_dim, config.hidden_dim, num_layers=config.hop, num_heads=config.heads, total_key_depth=config.depth, total_value_depth=config.depth, filter_size=config.filter) self.vae_sampler = VAESampling(config.hidden_dim, config.hidden_dim, out_dim=300) # outputs m self.emotion_input_encoder_1 = EmotionInputEncoder( config.emb_dim, config.hidden_dim, num_layers=config.hop, num_heads=config.heads, total_key_depth=config.depth, total_value_depth=config.depth, filter_size=config.filter, universal=config.universal, emo_input=config.emo_input) # outputs m~ self.emotion_input_encoder_2 = EmotionInputEncoder( config.emb_dim, config.hidden_dim, num_layers=config.hop, num_heads=config.heads, total_key_depth=config.depth, total_value_depth=config.depth, filter_size=config.filter, universal=config.universal, emo_input=config.emo_input) if config.emo_combine == "att": self.cdecoder = ComplexResDecoder(config.emb_dim, config.hidden_dim, num_layers=config.hop, num_heads=config.heads, total_key_depth=config.depth, total_value_depth=config.depth, filter_size=config.filter, universal=config.universal) elif config.emo_combine == "gate": self.cdecoder = ComplexResGate(config.emb_dim) self.s_weight = nn.Linear(config.hidden_dim, config.emb_dim, bias=False) self.decoder_key = nn.Linear(config.hidden_dim, decoder_number, bias=False) # v^T tanh(W E[i] + H c + b) method3 = True if method3: self.e_weight = nn.Linear(config.emb_dim, config.emb_dim, bias=True) self.v = torch.rand(config.emb_dim, requires_grad=True) if config.USE_CUDA: self.v = self.v.cuda() self.generator = Generator(config.hidden_dim, self.vocab_size) self.emoji_embedding = nn.Embedding(32, config.emb_dim) if config.init_emo_emb: self.init_emoji_embedding_with_glove() if config.weight_sharing: # Share the weight matrix between target word embedding & the final logit dense layer self.generator.proj.weight = self.embedding.lut.weight self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx) if (config.label_smoothing): self.criterion = LabelSmoothing(size=self.vocab_size, padding_idx=config.PAD_idx, smoothing=0.1) self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx) if config.softmax: self.attention_activation = nn.Softmax(dim=1) else: self.attention_activation = nn.Sigmoid() # nn.Softmax() self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr) if (config.noam): self.optimizer = NoamOpt( config.hidden_dim, 1, 8000, torch.optim.Adam(self.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) if model_file_path is not None: print("loading weights") state = torch.load(model_file_path, map_location=lambda storage, location: storage) self.load_state_dict(state['model']) if (load_optim): self.optimizer.load_state_dict(state['optimizer']) self.eval() self.model_dir = config.save_path if not os.path.exists(self.model_dir): os.makedirs(self.model_dir) self.best_path = "" # Added positive emotions self.positive_emotions = [ 11, 16, 6, 8, 3, 1, 28, 13, 31, 17, 24, 0, 27 ] self.negative_emotions = [ 9, 4, 2, 22, 14, 30, 29, 25, 15, 10, 
23, 19, 18, 21, 7, 20, 5, 26, 12 ]
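# Sketch: one plausible reading of the "v^T tanh(W E[i] + H c + b)" comment in
# the __init__ above, using e_weight (W, b), s_weight (H) and v. The training
# code in this file actually scores emotions with the simpler dot product
# s_weight(q_h) @ emoji_embedding.weight.T, so this helper is illustrative only
# and is not part of the repo.
def additive_emotion_scores(model, q_h):
    # q_h: (bsz, hidden_dim) context summary from the encoder
    c = model.s_weight(q_h)                               # H c        : (bsz, emb_dim)
    e = model.e_weight(model.emoji_embedding.weight)      # W E[i] + b : (32, emb_dim)
    scores = torch.tanh(e.unsqueeze(0) + c.unsqueeze(1))  # (bsz, 32, emb_dim)
    return scores @ model.v                               # v^T tanh(.): (bsz, 32)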
class Train_MIME(nn.Module): ''' for emotion attention, simply pass the randomly sampled emotion as the Q in a decoder block of transformer ''' def __init__(self, vocab, decoder_number, model_file_path=None, is_eval=False, load_optim=False): super().__init__() self.vocab = vocab self.vocab_size = vocab.n_words self.embedding = share_embedding(self.vocab, config.pretrain_emb) self.encoder = Encoder(config.emb_dim, config.hidden_dim, num_layers=config.hop, num_heads=config.heads, total_key_depth=config.depth, total_value_depth=config.depth, filter_size=config.filter, universal=config.universal) self.decoder_number = decoder_number self.decoder = DecoderContextV(config.emb_dim, config.hidden_dim, num_layers=config.hop, num_heads=config.heads, total_key_depth=config.depth, total_value_depth=config.depth, filter_size=config.filter) self.vae_sampler = VAESampling(config.hidden_dim, config.hidden_dim, out_dim=300) # outputs m self.emotion_input_encoder_1 = EmotionInputEncoder( config.emb_dim, config.hidden_dim, num_layers=config.hop, num_heads=config.heads, total_key_depth=config.depth, total_value_depth=config.depth, filter_size=config.filter, universal=config.universal, emo_input=config.emo_input) # outputs m~ self.emotion_input_encoder_2 = EmotionInputEncoder( config.emb_dim, config.hidden_dim, num_layers=config.hop, num_heads=config.heads, total_key_depth=config.depth, total_value_depth=config.depth, filter_size=config.filter, universal=config.universal, emo_input=config.emo_input) if config.emo_combine == "att": self.cdecoder = ComplexResDecoder(config.emb_dim, config.hidden_dim, num_layers=config.hop, num_heads=config.heads, total_key_depth=config.depth, total_value_depth=config.depth, filter_size=config.filter, universal=config.universal) elif config.emo_combine == "gate": self.cdecoder = ComplexResGate(config.emb_dim) self.s_weight = nn.Linear(config.hidden_dim, config.emb_dim, bias=False) self.decoder_key = nn.Linear(config.hidden_dim, decoder_number, bias=False) # v^T tanh(W E[i] + H c + b) method3 = True if method3: self.e_weight = nn.Linear(config.emb_dim, config.emb_dim, bias=True) self.v = torch.rand(config.emb_dim, requires_grad=True) if config.USE_CUDA: self.v = self.v.cuda() self.generator = Generator(config.hidden_dim, self.vocab_size) self.emoji_embedding = nn.Embedding(32, config.emb_dim) if config.init_emo_emb: self.init_emoji_embedding_with_glove() if config.weight_sharing: # Share the weight matrix between target word embedding & the final logit dense layer self.generator.proj.weight = self.embedding.lut.weight self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx) if (config.label_smoothing): self.criterion = LabelSmoothing(size=self.vocab_size, padding_idx=config.PAD_idx, smoothing=0.1) self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx) if config.softmax: self.attention_activation = nn.Softmax(dim=1) else: self.attention_activation = nn.Sigmoid() # nn.Softmax() self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr) if (config.noam): self.optimizer = NoamOpt( config.hidden_dim, 1, 8000, torch.optim.Adam(self.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) if model_file_path is not None: print("loading weights") state = torch.load(model_file_path, map_location=lambda storage, location: storage) self.load_state_dict(state['model']) if (load_optim): self.optimizer.load_state_dict(state['optimizer']) self.eval() self.model_dir = config.save_path if not os.path.exists(self.model_dir): os.makedirs(self.model_dir) self.best_path = "" # Added positive 
emotions self.positive_emotions = [ 11, 16, 6, 8, 3, 1, 28, 13, 31, 17, 24, 0, 27 ] self.negative_emotions = [ 9, 4, 2, 22, 14, 30, 29, 25, 15, 10, 23, 19, 18, 21, 7, 20, 5, 26, 12 ] def init_emoji_embedding_with_glove(self): self.emotions = [ 'surprised', 'excited', 'annoyed', 'proud', 'angry', 'sad', 'grateful', 'lonely', 'impressed', 'afraid', 'disgusted', 'confident', 'terrified', 'hopeful', 'anxious', 'disappointed', 'joyful', 'prepared', 'guilty', 'furious', 'nostalgic', 'jealous', 'anticipating', 'embarrassed', 'content', 'devastated', 'sentimental', 'caring', 'trusting', 'ashamed', 'apprehensive', 'faithful' ] self.emotion_index = [self.vocab.word2index[i] for i in self.emotions] self.emoji_embedding_init = self.embedding( torch.Tensor(self.emotion_index).long()) self.emoji_embedding.weight.data = self.emoji_embedding_init self.emoji_embedding.weight.requires_grad = True def save_model(self, running_avg_ppl, iter, f1_g, f1_b, ent_g, ent_b, ent_t): state = { 'iter': iter, 'optimizer': self.optimizer.state_dict(), 'current_loss': running_avg_ppl, 'model': self.state_dict() } model_save_path = os.path.join( self.model_dir, 'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format( iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b, ent_t)) self.best_path = model_save_path torch.save(state, model_save_path) def random_sampling(self, e): p = np.random.choice(self.positive_emotions) n = np.random.choice(self.negative_emotions) if e in self.positive_emotions: mimic = p mimic_t = n else: mimic = n mimic_t = p return mimic, mimic_t def train_one_batch(self, batch, iter, train=True): enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch( batch) dec_batch, _, _, _, _ = get_output_from_batch(batch) if (config.noam): self.optimizer.optimizer.zero_grad() else: self.optimizer.zero_grad() ## Encode mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1) if config.dataset == "empathetic": emb_mask = self.embedding(batch["mask_input"]) encoder_outputs = self.encoder( self.embedding(enc_batch) + emb_mask, mask_src) else: encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src) q_h = torch.mean(encoder_outputs, dim=1) if config.mean_query else encoder_outputs[:, 0] # q_h = torch.max(encoder_outputs, dim=1) emotions_mimic, emotions_non_mimic, mu_positive_prior, logvar_positive_prior, mu_negative_prior, logvar_negative_prior = \ self.vae_sampler(q_h, batch['program_label'], self.emoji_embedding) # KLLoss = -0.5 * (torch.sum(1 + logvar_n - mu_n.pow(2) - logvar_n.exp()) + torch.sum(1 + logvar_p - mu_p.pow(2) - logvar_p.exp())) m_out = self.emotion_input_encoder_1(emotions_mimic.unsqueeze(1), encoder_outputs, mask_src) m_tilde_out = self.emotion_input_encoder_2( emotions_non_mimic.unsqueeze(1), encoder_outputs, mask_src) if train: emotions_mimic, emotions_non_mimic, mu_positive_posterior, logvar_positive_posterior, mu_negative_posterior, logvar_negative_posterior = \ self.vae_sampler.forward_train(q_h, batch['program_label'], self.emoji_embedding, M_out=m_out.mean(dim=1), M_tilde_out=m_tilde_out.mean(dim=1)) KLLoss_positive = self.vae_sampler.kl_div( mu_positive_posterior, logvar_positive_posterior, mu_positive_prior, logvar_positive_prior) KLLoss_negative = self.vae_sampler.kl_div( mu_negative_posterior, logvar_negative_posterior, mu_negative_prior, logvar_negative_prior) KLLoss = KLLoss_positive + KLLoss_negative else: KLLoss_positive = self.vae_sampler.kl_div(mu_positive_prior, logvar_positive_prior) KLLoss_negative = self.vae_sampler.kl_div(mu_negative_prior, 
logvar_negative_prior) KLLoss = KLLoss_positive + KLLoss_negative if config.emo_combine == "att": v = self.cdecoder(encoder_outputs, m_out, m_tilde_out, mask_src) elif config.emo_combine == "gate": v = self.cdecoder(m_out, m_tilde_out) x = self.s_weight(q_h) # method2: E (W@c) logit_prob = torch.matmul(x, self.emoji_embedding.weight.transpose( 0, 1)) # shape (b_size, 32) # Decode sos_token = torch.LongTensor([config.SOS_idx] * enc_batch.size(0)).unsqueeze(1) if config.USE_CUDA: sos_token = sos_token.cuda() dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1) mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1) pre_logit, attn_dist = self.decoder(self.embedding(dec_batch_shift), v, v, (mask_src, mask_trg)) ## compute output dist logit = self.generator( pre_logit, attn_dist, enc_batch_extend_vocab if config.pointer_gen else None, extra_zeros, attn_dist_db=None) if (train and config.schedule > 10): if (random.uniform(0, 1) <= (0.0001 + (1 - 0.0001) * math.exp(-1. * iter / config.schedule))): config.oracle = True else: config.oracle = False if config.softmax: program_label = torch.LongTensor(batch['program_label']) if config.USE_CUDA: program_label = program_label.cuda() if config.emo_combine == 'gate': L1_loss = nn.CrossEntropyLoss()(logit_prob, program_label) loss = self.criterion( logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1)) + KLLoss + L1_loss else: L1_loss = nn.CrossEntropyLoss()( logit_prob, torch.LongTensor(batch['program_label']) if not config.USE_CUDA else torch.LongTensor( batch['program_label']).cuda()) loss = self.criterion( logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1)) + KLLoss + L1_loss loss_bce_program = nn.CrossEntropyLoss()(logit_prob, program_label).item() else: loss = self.criterion( logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1)) + nn.BCEWithLogitsLoss()( logit_prob, torch.FloatTensor( batch['target_program']).cuda()) loss_bce_program = nn.BCEWithLogitsLoss()( logit_prob, torch.FloatTensor(batch['target_program']).cuda()).item() pred_program = np.argmax(logit_prob.detach().cpu().numpy(), axis=1) program_acc = accuracy_score(batch["program_label"], pred_program) if (config.label_smoothing): loss_ppl = self.criterion_ppl( logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1)).item() if (train): loss.backward() self.optimizer.step() if (config.label_smoothing): return loss_ppl, math.exp(min(loss_ppl, 100)), loss_bce_program, program_acc else: return loss.item(), math.exp(min( loss.item(), 100)), loss_bce_program, program_acc def compute_act_loss(self, module): R_t = module.remainders N_t = module.n_updates p_t = R_t + N_t avg_p_t = torch.sum(torch.sum(p_t, dim=1) / p_t.size(1)) / p_t.size(0) loss = config.act_loss_weight * avg_p_t.item() return loss def decoder_greedy(self, batch, max_dec_step=30, emotion_classifier='built_in'): enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch( batch) emotions = batch['program_label'] ## Encode mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1) emb_mask = self.embedding(batch["mask_input"]) encoder_outputs = self.encoder( self.embedding(enc_batch) + emb_mask, mask_src) q_h = torch.mean(encoder_outputs, dim=1) if config.mean_query else encoder_outputs[:, 0] # method 2 x = self.s_weight(q_h) logit_prob = torch.matmul(x, self.emoji_embedding.weight.transpose(0, 1)) emo_pred = torch.argmax(logit_prob, dim=-1) if emotion_classifier == "vader": context_emo = [ 
self.positive_emotions[0] if d['compound'] > 0 else self.negative_emotions[0] for d in batch['context_emotion_scores'] ] context_emo = torch.Tensor(context_emo) if config.USE_CUDA: context_emo = context_emo.cuda() emotions_mimic, emotions_non_mimic, mu_p, logvar_p, mu_n, logvar_n = self.vae_sampler( q_h, context_emo, self.emoji_embedding) elif emotion_classifier == None: emotions_mimic, emotions_non_mimic, mu_p, logvar_p, mu_n, logvar_n = self.vae_sampler( q_h, batch['program_label'], self.emoji_embedding) elif emotion_classifier == "built_in": emotions_mimic, emotions_non_mimic, mu_p, logvar_p, mu_n, logvar_n = self.vae_sampler( q_h, emo_pred, self.emoji_embedding) m_out = self.emotion_input_encoder_1(emotions_mimic.unsqueeze(1), encoder_outputs, mask_src) m_tilde_out = self.emotion_input_encoder_2( emotions_non_mimic.unsqueeze(1), encoder_outputs, mask_src) if config.emo_combine == "att": v = self.cdecoder(encoder_outputs, m_out, m_tilde_out, mask_src) # v = self.cdecoder(encoder_outputs, m_out, m_tilde_out, mask_src_chosen) elif config.emo_combine == "gate": v = self.cdecoder(m_out, m_tilde_out) elif config.emo_combine == 'vader': m_weight = context_emo_scores.unsqueeze(-1).unsqueeze(-1) m_tilde_weight = 1 - m_weight v = m_weight * m_weight + m_tilde_weight * m_tilde_out ys = torch.ones(1, 1).fill_(config.SOS_idx).long() if config.USE_CUDA: ys = ys.cuda() mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1) decoded_words = [] for i in range(max_dec_step + 1): if (config.project): out, attn_dist = self.decoder( self.embedding_proj_in(self.embedding(ys)), self.embedding_proj_in(encoder_outputs), self.embedding_proj_in(v), (mask_src, mask_trg), attention_parameters) else: out, attn_dist = self.decoder(self.embedding(ys), v, v, (mask_src, mask_trg)) logit = self.generator(out, attn_dist, enc_batch_extend_vocab, extra_zeros, attn_dist_db=None) _, next_word = torch.max(logit[:, -1], dim=1) decoded_words.append([ '<EOS>' if ni.item() == config.EOS_idx else self.vocab.index2word[ni.item()] for ni in next_word.view(-1) ]) next_word = next_word.data[0] if config.USE_CUDA: ys = torch.cat( [ys, torch.ones(1, 1).long().fill_(next_word).cuda()], dim=1) ys = ys.cuda() else: ys = torch.cat( [ys, torch.ones(1, 1).long().fill_(next_word)], dim=1) mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1) sent = [] for _, row in enumerate(np.transpose(decoded_words)): st = '' for e in row: if e == '<EOS>': break else: st += e + ' ' sent.append(st) return sent, batch['context_emotion_scores'][0]['compound'], int( emo_pred[0].data.cpu()) def decoder_topk(self, batch, max_dec_step=30, emotion_classifier='built_in'): enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch( batch) emotions = batch['program_label'] context_emo = [ self.positive_emotions[0] if d['compound'] > 0 else self.negative_emotions[0] for d in batch['context_emotion_scores'] ] context_emo = torch.Tensor(context_emo) if config.USE_CUDA: context_emo = context_emo.cuda() ## Encode mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1) emb_mask = self.embedding(batch["mask_input"]) encoder_outputs = self.encoder( self.embedding(enc_batch) + emb_mask, mask_src) q_h = torch.mean(encoder_outputs, dim=1) if config.mean_query else encoder_outputs[:, 0] x = self.s_weight(q_h) # method 2 logit_prob = torch.matmul(x, self.emoji_embedding.weight.transpose(0, 1)) if emotion_classifier == "vader": context_emo = [ self.positive_emotions[0] if d['compound'] > 0 else self.negative_emotions[0] for d in batch['context_emotion_scores'] ] 
context_emo = torch.Tensor(context_emo) if config.USE_CUDA: context_emo = context_emo.cuda() emotions_mimic, emotions_non_mimic, mu_p, logvar_p, mu_n, logvar_n = self.vae_sampler( q_h, context_emo, self.emoji_embedding) elif emotion_classifier == None: emotions_mimic, emotions_non_mimic, mu_p, logvar_p, mu_n, logvar_n = self.vae_sampler( q_h, batch['program_label'], self.emoji_embedding) elif emotion_classifier == "built_in": emo_pred = torch.argmax(logit_prob, dim=-1) emotions_mimic, emotions_non_mimic, mu_p, logvar_p, mu_n, logvar_n = self.vae_sampler( q_h, emo_pred, self.emoji_embedding) m_out = self.emotion_input_encoder_1(emotions_mimic.unsqueeze(1), encoder_outputs, mask_src) m_tilde_out = self.emotion_input_encoder_2( emotions_non_mimic.unsqueeze(1), encoder_outputs, mask_src) if config.emo_combine == "att": v = self.cdecoder(encoder_outputs, m_out, m_tilde_out, mask_src) elif config.emo_combine == "gate": v = self.cdecoder(m_out, m_tilde_out) elif config.emo_combine == 'vader': m_weight = context_emo_scores.unsqueeze(-1).unsqueeze(-1) m_tilde_weight = 1 - m_weight v = m_weight * m_weight + m_tilde_weight * m_tilde_out ys = torch.ones(1, 1).fill_(config.SOS_idx).long() if config.USE_CUDA: ys = ys.cuda() mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1) decoded_words = [] for i in range(max_dec_step + 1): if (config.project): out, attn_dist = self.decoder( self.embedding_proj_in(self.embedding(ys)), self.embedding_proj_in(encoder_outputs), (mask_src, mask_trg), attention_parameters) else: out, attn_dist = self.decoder(self.embedding(ys), v, v, (mask_src, mask_trg)) logit = self.generator(out, attn_dist, enc_batch_extend_vocab, extra_zeros, attn_dist_db=None) filtered_logit = top_k_top_p_filtering(logit[:, -1], top_k=3, top_p=0, filter_value=-float('Inf')) # Sample from the filtered distribution next_word = torch.multinomial(F.softmax(filtered_logit, dim=-1), 1).squeeze() decoded_words.append([ '<EOS>' if ni.item() == config.EOS_idx else self.vocab.index2word[ni.item()] for ni in next_word.view(-1) ]) next_word = next_word.data.item() if config.USE_CUDA: ys = torch.cat( [ys, torch.ones(1, 1).long().fill_(next_word).cuda()], dim=1) ys = ys.cuda() else: ys = torch.cat( [ys, torch.ones(1, 1).long().fill_(next_word)], dim=1) mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1) sent = [] for _, row in enumerate(np.transpose(decoded_words)): st = '' for e in row: if e == '<EOS>': break else: st += e + ' ' sent.append(st) return sent
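# Sketch: the effect of the top_k_top_p_filtering call in decoder_topk above for
# the configuration used here (top_k=3, top_p=0): keep the k largest logits per
# row and set every other logit to filter_value, so that softmax + multinomial
# samples only among the top k tokens. Illustrative stand-alone version.
def top_k_filter(logits, k=3, filter_value=-float('Inf')):
    # logits: (bsz, vocab_size)
    kth_largest = torch.topk(logits, k, dim=-1).values[..., -1, None]
    return torch.where(logits < kth_largest,
                       torch.full_like(logits, filter_value),
                       logits)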
class Transformer_experts(nn.Module): def __init__(self, vocab, decoder_number, model_file_path=None, is_eval=False, load_optim=False): super(Transformer_experts, self).__init__() self.vocab = vocab self.vocab_size = vocab.n_words self.embedding = share_embedding(self.vocab, config.pretrain_emb) self.encoder = Encoder(config.emb_dim, config.hidden_dim, num_layers=config.hop, num_heads=config.heads, total_key_depth=config.depth, total_value_depth=config.depth, filter_size=config.filter, universal=config.universal) self.decoder_number = decoder_number ## multiple decoders self.decoder = MulDecoder(decoder_number, config.emb_dim, config.hidden_dim, num_layers=config.hop, num_heads=config.heads, total_key_depth=config.depth, total_value_depth=config.depth, filter_size=config.filter) self.decoder_key = nn.Linear(config.hidden_dim, decoder_number, bias=False) self.generator = Generator(config.hidden_dim, self.vocab_size) self.emoji_embedding = nn.Linear(64, config.emb_dim, bias=False) if config.weight_sharing: # Share the weight matrix between target word embedding & the final logit dense layer self.generator.proj.weight = self.embedding.lut.weight self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx) if (config.label_smoothing): self.criterion = LabelSmoothing(size=self.vocab_size, padding_idx=config.PAD_idx, smoothing=0.1) self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx) if config.softmax: self.attention_activation = nn.Softmax(dim=1) else: self.attention_activation = nn.Sigmoid() #nn.Softmax() self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr) if (config.noam): self.optimizer = NoamOpt( config.hidden_dim, 1, 8000, torch.optim.Adam(self.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) if model_file_path is not None: print("loading weights") state = torch.load(model_file_path, map_location=lambda storage, location: storage) self.encoder.load_state_dict(state['encoder_state_dict']) self.decoder.load_state_dict(state['decoder_state_dict']) self.decoder_key.load_state_dict(state['decoder_key_state_dict']) #self.emoji_embedding.load_state_dict(state['emoji_embedding_dict']) self.generator.load_state_dict(state['generator_dict']) self.embedding.load_state_dict(state['embedding_dict']) if (load_optim): self.optimizer.load_state_dict(state['optimizer']) self.eval() self.model_dir = config.save_path if not os.path.exists(self.model_dir): os.makedirs(self.model_dir) self.best_path = "" def save_model(self, running_avg_ppl, iter, f1_g, f1_b, ent_g, ent_b): state = { 'iter': iter, 'encoder_state_dict': self.encoder.state_dict(), 'decoder_state_dict': self.decoder.state_dict(), 'decoder_key_state_dict': self.decoder_key.state_dict(), #'emoji_embedding_dict': self.emoji_embedding.state_dict(), 'generator_dict': self.generator.state_dict(), 'embedding_dict': self.embedding.state_dict(), 'optimizer': self.optimizer.state_dict(), 'current_loss': running_avg_ppl } model_save_path = os.path.join( self.model_dir, 'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format( iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b)) self.best_path = model_save_path torch.save(state, model_save_path) def train_one_batch(self, batch, iter, train=True): enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch( batch) dec_batch, _, _, _, _ = get_output_from_batch(batch) if (config.noam): self.optimizer.optimizer.zero_grad() else: self.optimizer.zero_grad() ## Encode mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1) if config.dataset == "empathetic": emb_mask = 
self.embedding(batch["mask_input"]) encoder_outputs = self.encoder( self.embedding(enc_batch) + emb_mask, mask_src) else: encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src) ## Attention over decoder q_h = torch.mean(encoder_outputs, dim=1) if config.mean_query else encoder_outputs[:, 0] #q_h = encoder_outputs[:,0] logit_prob = self.decoder_key(q_h) #(bsz, num_experts) if (config.topk > 0): k_max_value, k_max_index = torch.topk(logit_prob, config.topk) a = np.empty([logit_prob.shape[0], self.decoder_number]) a.fill(float('-inf')) mask = torch.Tensor(a).cuda() logit_prob_ = mask.scatter_(1, k_max_index.cuda().long(), k_max_value) attention_parameters = self.attention_activation(logit_prob_) else: attention_parameters = self.attention_activation(logit_prob) # print("===============================================================================") # print("listener attention weight:",attention_parameters.data.cpu().numpy()) # print("===============================================================================") if (config.oracle): attention_parameters = self.attention_activation( torch.FloatTensor(batch['target_program']) * 1000).cuda() attention_parameters = attention_parameters.unsqueeze(-1).unsqueeze( -1) # (batch_size, expert_num, 1, 1) # Decode sos_token = torch.LongTensor([config.SOS_idx] * enc_batch.size(0)).unsqueeze(1) if config.USE_CUDA: sos_token = sos_token.cuda() dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1) mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1) pre_logit, attn_dist = self.decoder(self.embedding(dec_batch_shift), encoder_outputs, (mask_src, mask_trg), attention_parameters) ## compute output dist logit = self.generator( pre_logit, attn_dist, enc_batch_extend_vocab if config.pointer_gen else None, extra_zeros, attn_dist_db=None) #logit = F.log_softmax(logit,dim=-1) #fix the name later ## loss: NNL if ptr else Cross entropy if (train and config.schedule > 10): if (random.uniform(0, 1) <= (0.0001 + (1 - 0.0001) * math.exp(-1. 
* iter / config.schedule))): config.oracle = True else: config.oracle = False if config.softmax: loss = self.criterion( logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1)) + nn.CrossEntropyLoss()( logit_prob, torch.LongTensor( batch['program_label']).cuda()) loss_bce_program = nn.CrossEntropyLoss()( logit_prob, torch.LongTensor(batch['program_label']).cuda()).item() else: loss = self.criterion( logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1)) + nn.BCEWithLogitsLoss()( logit_prob, torch.FloatTensor( batch['target_program']).cuda()) loss_bce_program = nn.BCEWithLogitsLoss()( logit_prob, torch.FloatTensor(batch['target_program']).cuda()).item() pred_program = np.argmax(logit_prob.detach().cpu().numpy(), axis=1) program_acc = accuracy_score(batch["program_label"], pred_program) if (config.label_smoothing): loss_ppl = self.criterion_ppl( logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1)).item() if (train): loss.backward() self.optimizer.step() if (config.label_smoothing): return loss_ppl, math.exp(min(loss_ppl, 100)), loss_bce_program, program_acc else: return loss.item(), math.exp(min( loss.item(), 100)), loss_bce_program, program_acc def compute_act_loss(self, module): R_t = module.remainders N_t = module.n_updates p_t = R_t + N_t avg_p_t = torch.sum(torch.sum(p_t, dim=1) / p_t.size(1)) / p_t.size(0) loss = config.act_loss_weight * avg_p_t.item() return loss def decoder_greedy(self, batch, max_dec_step=30): enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch( batch) mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1) emb_mask = self.embedding(batch["mask_input"]) encoder_outputs = self.encoder( self.embedding(enc_batch) + emb_mask, mask_src) ## Attention over decoder q_h = torch.mean(encoder_outputs, dim=1) if config.mean_query else encoder_outputs[:, 0] #q_h = encoder_outputs[:,0] logit_prob = self.decoder_key(q_h) if (config.topk > 0): k_max_value, k_max_index = torch.topk(logit_prob, config.topk) a = np.empty([logit_prob.shape[0], self.decoder_number]) a.fill(float('-inf')) mask = torch.Tensor(a).cuda() logit_prob = mask.scatter_(1, k_max_index.cuda().long(), k_max_value) attention_parameters = self.attention_activation(logit_prob) if (config.oracle): attention_parameters = self.attention_activation( torch.FloatTensor(batch['target_program']) * 1000).cuda() attention_parameters = attention_parameters.unsqueeze(-1).unsqueeze( -1) # (batch_size, expert_num, 1, 1) ys = torch.ones(1, 1).fill_(config.SOS_idx).long() if config.USE_CUDA: ys = ys.cuda() mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1) decoded_words = [] for i in range(max_dec_step + 1): if (config.project): out, attn_dist = self.decoder( self.embedding_proj_in(self.embedding(ys)), self.embedding_proj_in(encoder_outputs), (mask_src, mask_trg), attention_parameters) else: out, attn_dist = self.decoder(self.embedding(ys), encoder_outputs, (mask_src, mask_trg), attention_parameters) logit = self.generator(out, attn_dist, enc_batch_extend_vocab, extra_zeros, attn_dist_db=None) #logit = F.log_softmax(logit,dim=-1) #fix the name later _, next_word = torch.max(logit[:, -1], dim=1) decoded_words.append([ '<EOS>' if ni.item() == config.EOS_idx else self.vocab.index2word[ni.item()] for ni in next_word.view(-1) ]) next_word = next_word.data[0] if config.USE_CUDA: ys = torch.cat( [ys, torch.ones(1, 1).long().fill_(next_word).cuda()], dim=1) ys = ys.cuda() else: ys = torch.cat( [ys, torch.ones(1, 1).long().fill_(next_word)], 
dim=1) mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1) sent = [] for _, row in enumerate(np.transpose(decoded_words)): st = '' for e in row: if e == '<EOS>': break else: st += e + ' ' sent.append(st) return sent def decoder_topk(self, batch, max_dec_step=30): enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch( batch) mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1) emb_mask = self.embedding(batch["mask_input"]) encoder_outputs = self.encoder( self.embedding(enc_batch) + emb_mask, mask_src) ## Attention over decoder q_h = torch.mean(encoder_outputs, dim=1) if config.mean_query else encoder_outputs[:, 0] #q_h = encoder_outputs[:,0] logit_prob = self.decoder_key(q_h) if (config.topk > 0): k_max_value, k_max_index = torch.topk(logit_prob, config.topk) a = np.empty([logit_prob.shape[0], self.decoder_number]) a.fill(float('-inf')) mask = torch.Tensor(a).cuda() logit_prob = mask.scatter_(1, k_max_index.cuda().long(), k_max_value) attention_parameters = self.attention_activation(logit_prob) if (config.oracle): attention_parameters = self.attention_activation( torch.FloatTensor(batch['target_program']) * 1000).cuda() attention_parameters = attention_parameters.unsqueeze(-1).unsqueeze( -1) # (batch_size, expert_num, 1, 1) ys = torch.ones(1, 1).fill_(config.SOS_idx).long() if config.USE_CUDA: ys = ys.cuda() mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1) decoded_words = [] for i in range(max_dec_step + 1): if (config.project): out, attn_dist = self.decoder( self.embedding_proj_in(self.embedding(ys)), self.embedding_proj_in(encoder_outputs), (mask_src, mask_trg), attention_parameters) else: out, attn_dist = self.decoder(self.embedding(ys), encoder_outputs, (mask_src, mask_trg), attention_parameters) logit = self.generator(out, attn_dist, enc_batch_extend_vocab, extra_zeros, attn_dist_db=None) filtered_logit = top_k_top_p_filtering(logit[:, -1], top_k=3, top_p=0, filter_value=-float('Inf')) # Sample from the filtered distribution next_word = torch.multinomial(F.softmax(filtered_logit, dim=-1), 1).squeeze() decoded_words.append([ '<EOS>' if ni.item() == config.EOS_idx else self.vocab.index2word[ni.item()] for ni in next_word.view(-1) ]) next_word = next_word.data[0] if config.USE_CUDA: ys = torch.cat( [ys, torch.ones(1, 1).long().fill_(next_word).cuda()], dim=1) ys = ys.cuda() else: ys = torch.cat( [ys, torch.ones(1, 1).long().fill_(next_word)], dim=1) mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1) sent = [] for _, row in enumerate(np.transpose(decoded_words)): st = '' for e in row: if e == '<EOS>': break else: st += e + ' ' sent.append(st) return sent
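# Sketch: the top-k expert gating used by Transformer_experts above (softmax
# case shown). Only the config.topk highest-scoring experts keep their key
# logits; the rest are set to -inf so the activation puts negligible weight on
# them, and the weights are reshaped for broadcasting over the expert decoders.
# Illustrative stand-alone version (not repo API).
def topk_expert_gate(logit_prob, k):
    # logit_prob: (bsz, num_experts) scores from decoder_key
    k_val, k_idx = torch.topk(logit_prob, k, dim=1)
    masked = torch.full_like(logit_prob, float('-inf')).scatter(1, k_idx, k_val)
    weights = torch.softmax(masked, dim=1)                # (bsz, num_experts)
    return weights.unsqueeze(-1).unsqueeze(-1)            # (bsz, num_experts, 1, 1)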
class CvaeNAD(nn.Module): def __init__(self, vocab, emo_number, model_file_path=None, is_eval=False, load_optim=False): super(CvaeNAD, self).__init__() self.vocab = vocab self.vocab_size = vocab.n_words self.embedding = share_embedding(self.vocab, config.pretrain_emb) self.encoder = Encoder(config.emb_dim, config.hidden_dim, num_layers=config.hop, num_heads=config.heads, total_key_depth=config.depth, total_value_depth=config.depth, filter_size=config.filter, universal=config.universal) self.r_encoder = Encoder(config.emb_dim, config.hidden_dim, num_layers=config.hop, num_heads=config.heads, total_key_depth=config.depth, total_value_depth=config.depth, filter_size=config.filter, universal=config.universal) if config.num_var_layers > 0: self.decoder = VarDecoder2(config.emb_dim, hidden_size=config.hidden_dim, num_layers=config.hop, num_heads=config.heads, total_key_depth=config.depth, total_value_depth=config.depth, filter_size=config.filter) else: self.decoder = VarDecoder3(config.emb_dim, hidden_size=config.hidden_dim, num_layers=config.hop, num_heads=config.heads, total_key_depth=config.depth, total_value_depth=config.depth, filter_size=config.filter) self.generator = Generator(config.hidden_dim, self.vocab_size) if config.weight_sharing: # Share the weight matrix between target word embedding & the final logit dense layer self.generator.proj.weight = self.embedding.lut.weight self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx) self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr) if (config.noam): self.optimizer = NoamOpt( config.hidden_dim, 1, 8000, torch.optim.Adam(self.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) if model_file_path is not None: print("loading weights") state = torch.load(model_file_path, map_location=lambda storage, location: storage) self.encoder.load_state_dict(state['encoder_state_dict']) self.r_encoder.load_state_dict(state['r_encoder_state_dict']) self.decoder.load_state_dict(state['decoder_state_dict']) self.generator.load_state_dict(state['generator_dict']) self.embedding.load_state_dict(state['embedding_dict']) if (load_optim): self.optimizer.load_state_dict(state['optimizer']) self.eval() self.model_dir = config.save_path if not os.path.exists(self.model_dir): os.makedirs(self.model_dir) self.best_path = "" def save_model(self, running_avg_ppl, iter, f1_g, f1_b, ent_g, ent_b): state = { 'iter': iter, 'encoder_state_dict': self.encoder.state_dict(), 'r_encoder_state_dict': self.r_encoder.state_dict(), 'decoder_state_dict': self.decoder.state_dict(), 'generator_dict': self.generator.state_dict(), 'embedding_dict': self.embedding.state_dict(), 'optimizer': self.optimizer.state_dict(), 'current_loss': running_avg_ppl } model_save_path = os.path.join( self.model_dir, 'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format( iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b)) self.best_path = model_save_path torch.save(state, model_save_path) def train_one_batch(self, batch, iter, train=True): enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch( batch) dec_batch, _, _, _, _ = get_output_from_batch(batch) if (config.noam): self.optimizer.optimizer.zero_grad() else: self.optimizer.zero_grad() ## Response encode mask_res = batch["posterior_batch"].data.eq( config.PAD_idx).unsqueeze(1) posterior_mask = self.embedding(batch["posterior_mask"]) r_encoder_outputs = self.r_encoder( self.embedding(batch["posterior_batch"]), mask_res) ## Encode mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1) emb_mask = 
self.embedding(batch["input_mask"]) encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src) meta = self.embedding(batch["program_label"]) # Decode mask_trg = dec_batch.data.eq(config.PAD_idx).unsqueeze(1) latent_dim = meta.size()[-1] meta = meta.repeat(1, dec_batch.size(1)).view(dec_batch.size(0), dec_batch.size(1), latent_dim) pre_logit, attn_dist, mean, log_var = self.decoder( meta, encoder_outputs, r_encoder_outputs, (mask_src, mask_res, mask_trg)) if not train: pre_logit, attn_dist, _, _ = self.decoder( meta, encoder_outputs, None, (mask_src, None, mask_trg)) ## compute output dist logit = self.generator( pre_logit, attn_dist, enc_batch_extend_vocab if config.pointer_gen else None, extra_zeros, attn_dist_db=None) ## loss: NNL if ptr else Cross entropy loss_rec = self.criterion(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1)) kld_loss = gaussian_kld(mean["posterior"], log_var["posterior"], mean["prior"], log_var["prior"]) kld_loss = torch.mean(kld_loss) kl_weight = min(iter / config.full_kl_step, 1) if config.full_kl_step > 0 else 1.0 loss = loss_rec + config.kl_ceiling * kl_weight * kld_loss if (train): loss.backward() # clip gradient nn.utils.clip_grad_norm_(self.parameters(), config.max_grad_norm) self.optimizer.step() return loss_rec.item(), math.exp(min(loss_rec.item(), 100)), kld_loss.item() def decoder_greedy(self, batch, max_dec_step=50): enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch( batch) ## Encode mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1) emb_mask = self.embedding(batch["input_mask"]) meta = self.embedding(batch["program_label"]) encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src) mask_trg = torch.ones((enc_batch.size(0), 50)) meta_size = meta.size() meta = meta.repeat(1, 50).view(meta_size[0], 50, meta_size[1]) out, attn_dist, _, _ = self.decoder(meta, encoder_outputs, None, (mask_src, None, mask_trg)) prob = self.generator(out, attn_dist, enc_batch_extend_vocab, extra_zeros, attn_dist_db=None) _, batch_out = torch.max(prob, dim=1) batch_out = batch_out.data.cpu().numpy() sentences = [] for sent in batch_out: st = '' for w in sent: if w == config.EOS_idx: break else: st += self.vocab.index2word[w] + ' ' sentences.append(st) return sentences def decoder_greedy_po(self, batch, max_dec_step=50): enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch( batch) ## Response encode mask_res = batch["posterior_batch"].data.eq( config.PAD_idx).unsqueeze(1) posterior_mask = self.embedding(batch["posterior_mask"]) r_encoder_outputs = self.r_encoder( self.embedding(batch["posterior_batch"]), mask_res) ## Encode mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1) emb_mask = self.embedding(batch["input_mask"]) encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src) meta = self.embedding(batch["program_label"]) encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src) mask_trg = torch.ones((enc_batch.size(0), 50)) meta_size = meta.size() meta = meta.repeat(1, 50).view(meta_size[0], 50, meta_size[1]) out, attn_dist, mean, log_var = self.decoder( meta, encoder_outputs, r_encoder_outputs, (mask_src, mask_res, mask_trg)) prob = self.generator(out, attn_dist, enc_batch_extend_vocab, extra_zeros, attn_dist_db=None) _, batch_out = torch.max(prob, dim=1) batch_out = batch_out.data.cpu().numpy() sentences = [] for sent in batch_out: st = '' for w in sent: if w == config.EOS_idx: break else: st += self.vocab.index2word[w] + ' ' 
sentences.append(st) return sentences
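# --- Hedged sketch (added for clarity, not part of the original file) ---
# CvaeNAD.train_one_batch combines a reconstruction loss with
# gaussian_kld(mean["posterior"], log_var["posterior"], mean["prior"],
# log_var["prior"]), scaled by the linear annealing weight
# min(iter / config.full_kl_step, 1). The closed-form KL between two diagonal
# Gaussians and the annealing schedule are sketched below; the repository's
# own gaussian_kld may differ in reduction or argument order.
import torch

def gaussian_kld(post_mu, post_logvar, prior_mu, prior_logvar):
    """KL( N(post_mu, exp(post_logvar)) || N(prior_mu, exp(prior_logvar)) ), summed over the latent dim."""
    return -0.5 * torch.sum(
        1 + (post_logvar - prior_logvar)
        - (post_logvar - prior_logvar).exp()
        - (post_mu - prior_mu).pow(2) / prior_logvar.exp(),
        dim=-1)

def kl_weight(iteration, full_kl_step):
    # fade the KL term in so the decoder cannot ignore the latent variable early on
    return min(iteration / full_kl_step, 1.0) if full_kl_step > 0 else 1.0

mu_q, lv_q = torch.randn(4, 300), torch.randn(4, 300)   # posterior
mu_p, lv_p = torch.zeros(4, 300), torch.zeros(4, 300)   # prior
print(gaussian_kld(mu_q, lv_q, mu_p, lv_p).mean().item(), kl_weight(5000, 10000))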
class Summarizer(nn.Module): def __init__(self, is_draft, toeknizer, model_file_path=None, is_eval=False, load_optim=False): super(Summarizer, self).__init__() self.is_draft = is_draft self.toeknizer = toeknizer if is_draft: self.encoder = BertModel.from_pretrained('bert-base-uncased') else: self.encoder = BertForMaskedLM.from_pretrained('bert-base-uncased') self.encoder.eval() # always in eval mode self.embedding = self.encoder.embeddings self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') self.decoder = Decoder(config.emb_dim, config.hidden_dim, num_layers=config.hop, num_heads=config.heads, total_key_depth=config.depth, total_value_depth=config.depth, filter_size=config.filter) self.generator = Generator(config.hidden_dim, config.vocab_size) self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx) self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx) if (config.label_smoothing): self.criterion = LabelSmoothing(size=config.vocab_size, padding_idx=config.PAD_idx, smoothing=0.1) self.embedding = self.embedding.eval() if is_eval: self.decoder = self.decoder.eval() self.generator = self.generator.eval() self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr) if (config.noam): self.optimizer = NoamOpt( config.hidden_dim, 1, 4000, torch.optim.Adam(self.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) if model_file_path is not None: print("loading weights") state = torch.load(model_file_path, map_location=lambda storage, location: storage) print("LOSS", state['current_loss']) self.decoder.load_state_dict(state['decoder_state_dict']) self.generator.load_state_dict(state['generator_dict']) self.embedding.load_state_dict(state['embedding_dict']) if (config.USE_CUDA): self.encoder = self.encoder.cuda(device=0) self.decoder = self.decoder.cuda(device=0) self.generator = self.generator.cuda(device=0) self.criterion = self.criterion.cuda(device=0) self.embedding = self.embedding.cuda(device=0) self.model_dir = config.save_path if not os.path.exists(self.model_dir): os.makedirs(self.model_dir) self.best_path = "" def save_model(self, loss, iter, r_avg): state = { 'iter': iter, 'decoder_state_dict': self.decoder.state_dict(), 'generator_dict': self.generator.state_dict(), 'embedding_dict': self.embedding.state_dict(), #'optimizer': self.optimizer.state_dict(), 'current_loss': loss } model_save_path = os.path.join( self.model_dir, 'model_{}_{:.4f}_{:.4f}'.format(iter, loss, r_avg)) self.best_path = model_save_path torch.save(state, model_save_path) def train_one_batch(self, batch, train=True): ## pad and other stuff input_ids_batch, input_mask_batch, example_index_batch, enc_batch_extend_vocab, extra_zeros, _ = get_input_from_batch( batch) dec_batch, dec_mask_batch, dec_index_batch, copy_gate, copy_ptr = get_output_from_batch( batch) if (config.noam): self.optimizer.optimizer.zero_grad() else: self.optimizer.zero_grad() with torch.no_grad(): # encoder_outputs are hidden states from transformer encoder_outputs, _ = self.encoder( input_ids_batch, token_type_ids=example_index_batch, attention_mask=input_mask_batch, output_all_encoded_layers=False) # # Draft Decoder sos_token = torch.LongTensor([config.SOS_idx] * input_ids_batch.size(0)).unsqueeze(1) if config.USE_CUDA: sos_token = sos_token.cuda(device=0) dec_batch_shift = torch.cat( (sos_token, dec_batch[:, :-1]), 1) # shift the decoder input (summary) by one step mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1) pre_logit1, attn_dist1 = self.decoder(self.embedding(dec_batch_shift), encoder_outputs, (None, mask_trg)) # 
print(pre_logit1.size()) ## compute output dist logit1 = self.generator(pre_logit1, attn_dist1, enc_batch_extend_vocab, extra_zeros, copy_gate=copy_gate, copy_ptr=copy_ptr, mask_trg=mask_trg) ## loss: NNL if ptr else Cross entropy loss1 = self.criterion(logit1.contiguous().view(-1, logit1.size(-1)), dec_batch.contiguous().view(-1)) # Refine Decoder - train using gold label TARGET 'TODO: turn gold-target-text into BERT insertable representation' pre_logit2, attn_dist2 = self.generate_refinement_output( encoder_outputs, dec_batch, dec_index_batch, extra_zeros, dec_mask_batch) # pre_logit2, attn_dist2 = self.decoder(self.embedding(encoded_gold_target),encoder_outputs, (None,mask_trg)) logit2 = self.generator(pre_logit2, attn_dist2, enc_batch_extend_vocab, extra_zeros, copy_gate=copy_gate, copy_ptr=copy_ptr, mask_trg=None) loss2 = self.criterion(logit2.contiguous().view(-1, logit2.size(-1)), dec_batch.contiguous().view(-1)) loss = loss1 + loss2 if train: loss.backward() self.optimizer.step() return loss def eval_one_batch(self, batch): draft_seq_batch = self.decoder_greedy(batch) d_seq_input_ids_batch, d_seq_input_mask_batch, d_seq_example_index_batch = text_input2bert_input( draft_seq_batch, self.tokenizer) pre_logit2, attn_dist2 = self.generate_refinement_output( encoder_outputs, d_seq_input_ids_batch, d_seq_example_index_batch, extra_zeros, d_seq_input_mask_batch) decoded_words, sent = [], [] for out, attn_dist in zip(pre_logit2, attn_dist2): prob = self.generator(out, attn_dist, enc_batch_extend_vocab, extra_zeros, copy_gate=copy_gate, copy_ptr=copy_ptr, mask_trg=None) _, next_word = torch.max(prob[:, -1], dim=1) decoded_words.append( self.tokenizer.convert_ids_to_tokens(next_word.tolist())) for _, row in enumerate(np.transpose(decoded_words)): st = '' for e in row: if e == '<EOS>' or e.strip() == '<PAD>': break else: st += e + ' ' sent.append(st) return sent def generate_refinement_output(self, encoder_outputs, input_ids_batch, example_index_batch, extra_zeros, input_mask_batch): # mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1) # decoded_words = [] logits, attns = [], [] for i in range(config.max_dec_step): # print(i) with torch.no_grad(): # Additionally mask the location of i. 
context_input_mask_batch = [] # print(context_input_mask_batch.shape) # (2,512) (batch_size, seq_len) for mask in input_mask_batch: mask[i] = 0 context_input_mask_batch.append(mask) context_input_mask_batch = torch.stack( context_input_mask_batch) #.cuda(device=0) # self.embedding = self.embedding.cuda(device=0) context_vector, _ = self.encoder( input_ids_batch, token_type_ids=example_index_batch, attention_mask=context_input_mask_batch, output_all_encoded_layers=False) if config.USE_CUDA: context_vector = context_vector.cuda(device=0) # decoder input size == encoder output size == (batch_size, 512, 768) out, attn_dist = self.decoder(context_vector, encoder_outputs, (None, None)) logits.append(out[:, i:i + 1, :]) attns.append(attn_dist[:, i:i + 1, :]) logits = torch.cat(logits, dim=1) attns = torch.cat(attns, dim=1) # print(logits.size(), attns.size()) return logits, attns def decoder_greedy(self, batch): input_ids_batch, input_mask_batch, example_index_batch, enc_batch_extend_vocab, extra_zeros, _ = get_input_from_batch( batch) # mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1) with torch.no_grad(): encoder_outputs, _ = self.encoder( input_ids_batch, token_type_ids=enc_batch_extend_vocab, attention_mask=input_mask_batch, output_all_encoded_layers=False) ys = torch.ones(1, 1).fill_(config.SOS_idx).long() if config.USE_CUDA: ys = ys.cuda() mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1) decoded_words = [] for i in range(config.max_dec_step): out, attn_dist = self.decoder(self.embedding(ys), encoder_outputs, (None, mask_trg)) prob = self.generator(out, attn_dist, enc_batch_extend_vocab, extra_zeros) _, next_word = torch.max(prob[:, -1], dim=1) decoded_words.append( self.tokenizer.convert_ids_to_tokens(next_word.tolist())) next_word = next_word.data[0] if config.USE_CUDA: ys = torch.cat( [ys, torch.ones(1, 1).long().fill_(next_word).cuda()], dim=1) ys = ys.cuda() else: ys = torch.cat( [ys, torch.ones(1, 1).long().fill_(next_word)], dim=1) mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1) sent = [] for _, row in enumerate(np.transpose(decoded_words)): st = '' for e in row: if e == '<EOS>' or e == '<PAD>': break else: st += e + ' ' sent.append(st) return sent
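# --- Hedged sketch (added for clarity, not part of the original file) ---
# eval_one_batch above calls text_input2bert_input(draft_seq_batch,
# self.tokenizer), a helper that is not defined in this file. The sketch below
# shows what it plausibly does -- tokenize the draft summaries, pad to a fixed
# length and return the (input_ids, attention_mask, token_type_ids) triple
# BertModel expects. The padding length and return order are assumptions; the
# repository's real helper may differ.
import torch

def text_input2bert_input(sentences, tokenizer, max_len=512):
    input_ids, input_mask, segment_ids = [], [], []
    for text in sentences:
        tokens = ["[CLS]"] + tokenizer.tokenize(text)[:max_len - 2] + ["[SEP]"]
        ids = tokenizer.convert_tokens_to_ids(tokens)
        pad = [0] * (max_len - len(ids))
        input_ids.append(ids + pad)
        input_mask.append([1] * len(ids) + pad)
        segment_ids.append([0] * max_len)
    return (torch.tensor(input_ids, dtype=torch.long),
            torch.tensor(input_mask, dtype=torch.long),
            torch.tensor(segment_ids, dtype=torch.long))

# usage (requires the pretrained tokenizer):
#   ids, mask, segs = text_input2bert_input(["a draft summary ."], BertTokenizer.from_pretrained('bert-base-uncased'))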
class Seq2SPG(nn.Module): def __init__(self, vocab, model_file_path=None, is_eval=False, load_optim=False): super(Seq2SPG, self).__init__() self.vocab = vocab self.vocab_size = vocab.n_words self.embedding = share_embedding(self.vocab, config.preptrained) self.encoder = nn.LSTM(config.emb_dim, config.hidden_dim, config.hop, bidirectional=False, batch_first=True, dropout=0.2) self.encoder2decoder = nn.Linear(config.hidden_dim, config.hidden_dim) self.decoder = LSTMAttentionDot(config.emb_dim, config.hidden_dim, batch_first=True) self.memory = MLP( config.hidden_dim + config.emb_dim, [config.private_dim1, config.private_dim2, config.private_dim3], config.hidden_dim) self.dec_gate = nn.Linear(config.hidden_dim, 2 * config.hidden_dim) self.mem_gate = nn.Linear(config.hidden_dim, 2 * config.hidden_dim) self.generator = Generator(config.hidden_dim, self.vocab_size) self.hooks = { } #Save the model structure of each task as masks of the parameters if config.weight_sharing: # Share the weight matrix between target word embedding & the final logit dense layer self.generator.proj.weight = self.embedding.weight self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx) if (config.label_smoothing): self.criterion = LabelSmoothing(size=self.vocab_size, padding_idx=config.PAD_idx, smoothing=0.1) self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx) if is_eval: self.encoder = self.encoder.eval() self.encoder2decoder = self.encoder2decoder.eval() self.decoder = self.decoder.eval() self.generator = self.generator.eval() self.embedding = self.embedding.eval() self.memory = self.memory.eval() self.dec_gate = self.dec_gate.eval() self.mem_gate = self.mem_gate.eval() self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr) if (config.noam): self.optimizer = NoamOpt( config.hidden_dim, 1, 4000, torch.optim.Adam(self.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) if config.use_sgd: self.optimizer = torch.optim.SGD(self.parameters(), lr=config.lr) if model_file_path is not None: print("loading weights") state = torch.load(model_file_path, map_location=lambda storage, location: storage) print("LOSS", state['current_loss']) self.encoder.load_state_dict(state['encoder_state_dict']) self.encoder2decoder.load_state_dict( state['encoder2decoder_state_dict']) self.decoder.load_state_dict(state['decoder_state_dict']) self.generator.load_state_dict(state['generator_dict']) self.embedding.load_state_dict(state['embedding_dict']) self.memory.load_state_dict(state['memory_dict']) self.dec_gate.load_state_dict(state['dec_gate_dict']) self.mem_gate.load_state_dict(state['mem_gate_dict']) if (load_optim): self.optimizer.load_state_dict(state['optimizer']) if (config.USE_CUDA): self.encoder = self.encoder.cuda() self.encoder2decoder = self.encoder2decoder.cuda() self.decoder = self.decoder.cuda() self.generator = self.generator.cuda() self.criterion = self.criterion.cuda() self.embedding = self.embedding.cuda() self.memory = self.memory.cuda() self.dec_gate = self.dec_gate.cuda() self.mem_gate = self.mem_gate.cuda() self.model_dir = config.save_path if not os.path.exists(self.model_dir): os.makedirs(self.model_dir) self.best_path = "" def save_model(self, running_avg_ppl, iter, f1_g, f1_b, ent_g, ent_b, log=False, d="tmaml_sim_model"): state = { 'iter': iter, 'encoder_state_dict': self.encoder.state_dict(), 'encoder2decoder_state_dict': self.encoder2decoder.state_dict(), 'decoder_state_dict': self.decoder.state_dict(), 'generator_dict': self.generator.state_dict(), 'embedding_dict': self.embedding.state_dict(), 
'memory_dict': self.memory.state_dict(), 'dec_gate_dict': self.dec_gate.state_dict(), 'mem_gate_dict': self.mem_gate.state_dict(), #'optimizer': self.optimizer.state_dict(), 'current_loss': running_avg_ppl } if log: model_save_path = os.path.join( d, 'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format( iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b)) else: model_save_path = os.path.join( self.model_dir, 'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format( iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b)) self.best_path = model_save_path torch.save(state, model_save_path) def get_state(self, batch): """Get cell states and hidden states for LSTM""" batch_size = batch.size(0) \ if self.encoder.batch_first else batch.size(1) h0_encoder = Variable(torch.zeros(self.encoder.num_layers, batch_size, config.hidden_dim), requires_grad=False) c0_encoder = Variable(torch.zeros(self.encoder.num_layers, batch_size, config.hidden_dim), requires_grad=False) return h0_encoder.cuda(), c0_encoder.cuda() def compute_hooks(self, task): """Compute the masks of the private module""" current_layer = 3 out_mask = torch.ones(self.memory.output_size) self.hooks[task] = {} self.hooks[task]["w_hooks"] = {} self.hooks[task]["b_hooks"] = {} while (current_layer >= 0): connections = self.memory.layers[current_layer].weight.data output_size, input_size = connections.shape mask = connections.abs() > 0.05 in_mask = torch.zeros(input_size) for index, line in enumerate(mask): if (out_mask[index] == 1): torch.max(in_mask, (line.cpu() != 0).float(), out=in_mask) if (config.USE_CUDA): self.hooks[task]["b_hooks"][current_layer] = out_mask.cuda() self.hooks[task]["w_hooks"][current_layer] = torch.mm( out_mask.unsqueeze(1), in_mask.unsqueeze(0)).cuda() else: self.hooks[task]["b_hooks"][current_layer] = out_mask self.hooks[task]["w_hooks"][current_layer] = torch.mm( out_mask.unsqueeze(1), in_mask.unsqueeze(0)) out_mask = in_mask current_layer -= 1 def register_hooks(self, task): if "hook_handles" not in self.hooks[task]: self.hooks[task]["hook_handles"] = [] for i, l in enumerate(self.memory.layers): self.hooks[task]["hook_handles"].append( l.bias.register_hook(make_hook( self.hooks[task]["b_hooks"][i]))) self.hooks[task]["hook_handles"].append( l.weight.register_hook( make_hook(self.hooks[task]["w_hooks"][i]))) def unhook(self, task): for handle in self.hooks[task]["hook_handles"]: handle.remove() self.hooks[task]["hook_handles"] = [] def train_one_batch(self, batch, train=True, mode="pretrain", task=0): enc_batch, _, enc_lens, enc_batch_extend_vocab, extra_zeros, _, _, _ = get_input_from_batch( batch) dec_batch, _, _, _, _, _ = get_output_from_batch(batch) if (config.noam): self.optimizer.optimizer.zero_grad() else: self.optimizer.zero_grad() ## Encode self.h0_encoder, self.c0_encoder = self.get_state(enc_batch) src_h, (src_h_t, src_c_t) = self.encoder(self.embedding(enc_batch), (self.h0_encoder, self.c0_encoder)) h_t = src_h_t[-1] c_t = src_c_t[-1] # Decode decoder_init_state = nn.Tanh()(self.encoder2decoder(h_t)) sos_token = torch.LongTensor([config.SOS_idx] * enc_batch.size(0)).unsqueeze(1) if config.USE_CUDA: sos_token = sos_token.cuda() dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1) target_embedding = self.embedding(dec_batch_shift) ctx = src_h.transpose(0, 1) trg_h, (_, _) = self.decoder(target_embedding, (decoder_init_state, c_t), ctx) #Memory mem_h_input = torch.cat( (decoder_init_state.unsqueeze(1), trg_h[:, 0:-1, :]), 1) mem_input = torch.cat((target_embedding, mem_h_input), 2) mem_output = 
self.memory(mem_input) #Combine gates = self.dec_gate(trg_h) + self.mem_gate(mem_output) decoder_gate, memory_gate = gates.chunk(2, 2) decoder_gate = F.sigmoid(decoder_gate) memory_gate = F.sigmoid(memory_gate) pre_logit = F.tanh(decoder_gate * trg_h + memory_gate * mem_output) logit = self.generator(pre_logit) if mode == "pretrain": loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1)) if train: loss.backward() self.optimizer.step() if (config.label_smoothing): loss = self.criterion_ppl( logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1)) return loss.item(), math.exp(min(loss.item(), 100)), loss elif mode == "select": loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1)) if (train): l1_loss = 0.0 for p in self.memory.parameters(): l1_loss += torch.sum(torch.abs(p)) loss += 0.0005 * l1_loss loss.backward() self.optimizer.step() self.compute_hooks(task) if (config.label_smoothing): loss = self.criterion_ppl( logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1)) return loss.item(), math.exp(min(loss.item(), 100)), loss else: loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1)) if (train): self.register_hooks(task) loss.backward() self.optimizer.step() self.unhook(task) if (config.label_smoothing): loss = self.criterion_ppl( logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1)) return loss.item(), math.exp(min(loss.item(), 100)), loss
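# --- Hedged sketch (added for clarity, not part of the original file) ---
# Seq2SPG.register_hooks relies on a make_hook helper that is not defined in
# this file: it should return a backward hook that zeroes gradients wherever
# the task-specific mask computed by compute_hooks is zero, so only the
# selected "private" connections of the MLP memory are updated for that task.
# A minimal sketch under that assumption:
import torch

def make_hook(mask):
    def hook(grad):
        # keep the gradient only where the task mask is non-zero
        return grad * mask.to(grad.device)
    return hook

# usage: freeze the gradients of the last two rows of a weight matrix
w = torch.nn.Parameter(torch.randn(4, 4))
mask = torch.zeros(4, 4)
mask[:2] = 1.0
handle = w.register_hook(make_hook(mask))
w.sum().backward()
print(w.grad)     # rows 2-3 are all zeros
handle.remove()   # mirrors unhook(), which removes the stored handles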
class Transformer(nn.Module): def __init__(self, vocab, model_file_path=None, is_eval=False, load_optim=False): super(Transformer, self).__init__() self.vocab = vocab self.vocab_size = vocab.n_words self.embedding = share_embedding(self.vocab, config.pretrain_emb) self.encoder = Encoder(config.emb_dim, config.hidden_dim, num_layers=config.hop, num_heads=config.heads, total_key_depth=config.depth, total_value_depth=config.depth, filter_size=config.filter, universal=config.universal) ## multiple decoders self.decoder = Decoder(config.emb_dim, hidden_size=config.hidden_dim, num_layers=config.hop, num_heads=config.heads, total_key_depth=config.depth, total_value_depth=config.depth, filter_size=config.filter) self.generator = Generator(config.hidden_dim, self.vocab_size) if config.weight_sharing: # Share the weight matrix between target word embedding & the final logit dense layer self.generator.proj.weight = self.embedding.lut.weight self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx) if config.label_smoothing: self.criterion = LabelSmoothing(size=self.vocab_size, padding_idx=config.PAD_idx, smoothing=0.1) self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx) self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr) if config.noam: self.optimizer = NoamOpt( config.hidden_dim, 1, 8000, torch.optim.Adam(self.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) if model_file_path is not None: print("loading weights") state = torch.load(model_file_path, map_location=lambda storage, location: storage) self.encoder.load_state_dict(state['encoder_state_dict']) self.decoder.load_state_dict(state['decoder_state_dict']) self.generator.load_state_dict(state['generator_dict']) self.embedding.load_state_dict(state['embedding_dict']) if load_optim: self.optimizer.load_state_dict(state['optimizer']) self.eval() self.model_dir = config.save_path if not os.path.exists(self.model_dir): os.makedirs(self.model_dir) self.best_path = "" def save_model(self, running_avg_ppl, iter, f1_g, f1_b, ent_g, ent_b): state = { 'iter': iter, 'encoder_state_dict': self.encoder.state_dict(), 'decoder_state_dict': self.decoder.state_dict(), 'generator_dict': self.generator.state_dict(), 'embedding_dict': self.embedding.state_dict(), 'optimizer': self.optimizer.state_dict(), 'current_loss': running_avg_ppl } model_save_path = os.path.join( self.model_dir, 'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format( iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b)) self.best_path = model_save_path torch.save(state, model_save_path) def train_one_batch(self, batch, iter, train=True): enc_batch = batch["review_batch"] enc_batch_extend_vocab = batch["review_ext_batch"] oovs = batch["oovs"] max_oov_length = len( sorted(oovs, key=lambda i: len(i), reverse=True)[0]) extra_zeros = Variable(torch.zeros( (enc_batch.size(0), max_oov_length))).to(config.device) dec_batch = batch["tags_batch"] dec_ext_batch = batch["tags_ext_batch"] if config.noam: self.optimizer.optimizer.zero_grad() else: self.optimizer.zero_grad() ## Embedding - context mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze( 1) # (bsz, src_len)->(bsz, 1, src_len) # emb_mask = self.embedding(batch["mask_context"]) # src_emb = self.embedding(enc_batch)+emb_mask src_emb = self.embedding(enc_batch) encoder_outputs = self.encoder(src_emb, mask_src) # (bsz, src_len, emb_dim) sos_token = torch.LongTensor([config.SOS_idx] * enc_batch.size(0)).unsqueeze(1).to( config.device) # (bsz, 1) dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1) # (bsz, tgt_len) 
mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1) pre_logit, attn_dist = self.decoder(self.embedding(dec_batch_shift), encoder_outputs, (mask_src, mask_trg)) logit = self.generator( pre_logit, attn_dist, enc_batch_extend_vocab if config.pointer_gen else None, extra_zeros) #logit = F.log_softmax(logit,dim=-1) #fix the name later ## loss: NNL if ptr else Cross entropy if config.pointer_gen: loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)), dec_ext_batch.contiguous().view(-1)) else: loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1)) if config.label_smoothing: loss_ppl = self.criterion_ppl( logit.contiguous().view(-1, logit.size(-1)), dec_ext_batch.contiguous().view(-1) if config.pointer_gen else dec_batch.contiguous().view(-1)) if train: loss.backward() self.optimizer.step() if config.label_smoothing: if torch.isnan(loss_ppl).sum().item() != 0 or torch.isinf( loss_ppl).sum().item() != 0: print("check") pdb.set_trace() return loss_ppl.item(), math.exp(min(loss_ppl.item(), 100)), 0, 0 else: return loss.item(), math.exp(min(loss.item(), 100)), 0, 0 def compute_act_loss(self, module): R_t = module.remainders N_t = module.n_updates p_t = R_t + N_t avg_p_t = torch.sum(torch.sum(p_t, dim=1) / p_t.size(1)) / p_t.size(0) loss = config.act_loss_weight * avg_p_t.item() return loss def decoder_greedy(self, batch, max_dec_step=30): enc_batch = batch["review_batch"] enc_batch_extend_vocab = batch["review_ext_batch"] oovs = batch["oovs"] max_oov_length = len( sorted(oovs, key=lambda i: len(i), reverse=True)[0]) extra_zeros = Variable(torch.zeros( (enc_batch.size(0), max_oov_length))).to(config.device) ## Encode - context mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze( 1) # (bsz, src_len)->(bsz, 1, src_len) # emb_mask = self.embedding(batch["mask_context"]) # src_emb = self.embedding(enc_batch) + emb_mask # todo eos or sentence embedding?? 
src_emb = self.embedding(enc_batch) encoder_outputs = self.encoder(src_emb, mask_src) # (bsz, src_len, emb_dim) enc_ext_batch = enc_batch_extend_vocab # ys = torch.ones(1, 1).fill_(config.SOS_idx).long() ys = torch.zeros(enc_batch.size(0), 1).fill_(config.SOS_idx).long().to( config.device) # when testing, we set bsz into 1 mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1) decoded_words = [] for i in range(max_dec_step + 1): if config.project: out, attn_dist = self.decoder( self.embedding_proj_in(self.embedding(ys)), self.embedding_proj_in(encoder_outputs), (mask_src, mask_trg)) else: out, attn_dist = self.decoder(self.embedding(ys), encoder_outputs, (mask_src, mask_trg)) prob = self.generator(out, attn_dist, enc_ext_batch, extra_zeros) _, next_word = torch.max(prob[:, -1], dim=1) # bsz=1 cur_words = [] for i_batch, ni in enumerate(next_word.view(-1)): if ni.item() == config.EOS_idx: cur_words.append('<EOS>') elif ni.item() in self.vocab.index2word: cur_words.append(self.vocab.index2word[ni.item()]) else: cur_words.append(oovs[i_batch][ni.item() - self.vocab.n_words]) decoded_words.append(cur_words) next_word = next_word.data[0] if next_word.item() not in self.vocab.index2word: next_word = torch.tensor(config.UNK_idx) ys = torch.cat([ ys, torch.zeros(enc_batch.size(0), 1).long().fill_(next_word).to( config.device) ], dim=1).to(config.device) # if config.USE_CUDA: # ys = torch.cat([ys, torch.zeros(enc_batch.size(0), 1).long().fill_(next_word).cuda()], dim=1) # ys = ys.cuda() # else: # ys = torch.cat([ys, torch.zeros(enc_batch.size(0), 1).long().fill_(next_word)], dim=1) mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1) sent = [] for _, row in enumerate(np.transpose(decoded_words)): st = '' for e in row: if e == '<EOS>': break else: st += e + ' ' sent.append(st) return sent
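# --- Hedged sketch (added for clarity, not part of the original file) ---
# The train_one_batch / decoder_greedy pair above implements pointer-generator
# style copying: source-side out-of-vocabulary words get temporary ids starting
# at vocab.n_words, the output distribution is padded with extra_zeros slots
# for them, and a generated id >= n_words is mapped back to the OOV string.
# The toy vocabulary and probabilities below are made up for illustration.
import torch

n_words = 5                                    # in-vocabulary size
index2word = {0: "PAD", 1: "SOS", 2: "EOS", 3: "good", 4: "service"}
oovs = [["sushi", "tempura"]]                  # per-example OOV list (bsz = 1)

max_oov_length = max(len(o) for o in oovs)
vocab_dist = torch.softmax(torch.randn(1, n_words), dim=-1)
extra_zeros = torch.zeros(1, max_oov_length)                  # room for copied OOVs
extended_dist = torch.cat([vocab_dist, extra_zeros], dim=-1)  # (1, n_words + max_oov_length)
extended_dist[0, n_words] += 0.9               # pretend copy attention points at "sushi"

next_word = int(extended_dist.argmax(dim=-1))
token = index2word[next_word] if next_word < n_words else oovs[0][next_word - n_words]
print(token)  # "sushi": a word outside the decoder vocabulary can still be produced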
class Transformer(nn.Module): def __init__(self, vocab, model_file_path=None, is_eval=False, load_optim=False): super(Transformer, self).__init__() self.vocab = vocab self.vocab_size = vocab.n_words self.embedding = share_embedding(self.vocab,config.preptrained) self.encoder = Encoder(config.emb_dim, config.hidden_dim, num_layers=config.hop, num_heads=config.heads, total_key_depth=config.depth, total_value_depth=config.depth, filter_size=config.filter,universal=config.universal) self.decoder = Decoder(config.emb_dim, config.hidden_dim, num_layers=config.hop, num_heads=config.heads, total_key_depth=config.depth,total_value_depth=config.depth, filter_size=config.filter,universal=config.universal) self.generator = Generator(config.hidden_dim,self.vocab_size) if config.weight_sharing: # Share the weight matrix between target word embedding & the final logit dense layer self.generator.proj.weight = self.embedding.weight self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx) if (config.label_smoothing): self.criterion = LabelSmoothing(size=self.vocab_size, padding_idx=config.PAD_idx, smoothing=0.1) self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx) if is_eval: self.encoder = self.encoder.eval() self.decoder = self.decoder.eval() self.generator = self.generator.eval() self.embedding = self.embedding.eval() self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr) if(config.noam): self.optimizer = NoamOpt(config.hidden_dim, 1, 4000, torch.optim.Adam(self.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) if config.use_sgd: self.optimizer = torch.optim.SGD(self.parameters(), lr=config.lr) if model_file_path is not None: print("loading weights") state = torch.load(model_file_path, map_location= lambda storage, location: storage) print("LOSS",state['current_loss']) self.encoder.load_state_dict(state['encoder_state_dict']) self.decoder.load_state_dict(state['decoder_state_dict']) self.generator.load_state_dict(state['generator_dict']) self.embedding.load_state_dict(state['embedding_dict']) if (load_optim): self.optimizer.load_state_dict(state['optimizer']) if (config.USE_CUDA): self.encoder = self.encoder.cuda() self.decoder = self.decoder.cuda() self.generator = self.generator.cuda() self.criterion = self.criterion.cuda() self.embedding = self.embedding.cuda() self.model_dir = config.save_path if not os.path.exists(self.model_dir): os.makedirs(self.model_dir) self.best_path = "" def save_model(self, running_avg_ppl, iter, f1_g,f1_b,ent_g,ent_b): state = { 'iter': iter, 'encoder_state_dict': self.encoder.state_dict(), 'decoder_state_dict': self.decoder.state_dict(), 'generator_dict': self.generator.state_dict(), 'embedding_dict': self.embedding.state_dict(), #'optimizer': self.optimizer.state_dict(), 'current_loss': running_avg_ppl } model_save_path = os.path.join(self.model_dir, 'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format(iter,running_avg_ppl,f1_g,f1_b,ent_g,ent_b) ) self.best_path = model_save_path torch.save(state, model_save_path) def train_one_batch(self, batch, train=True): ## pad and other stuff enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch) dec_batch, _, _, _, _ = get_output_from_batch(batch) if(config.noam): self.optimizer.optimizer.zero_grad() else: self.optimizer.zero_grad() ## Encode mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1) encoder_outputs = self.encoder(self.embedding(enc_batch),mask_src) # Decode sos_token = torch.LongTensor([config.SOS_idx] * enc_batch.size(0)).unsqueeze(1) if config.USE_CUDA: 
sos_token = sos_token.cuda() dec_batch_shift = torch.cat((sos_token,dec_batch[:, :-1]),1) mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1) pre_logit, attn_dist = self.decoder(self.embedding(dec_batch_shift),encoder_outputs, (mask_src,mask_trg)) ## compute output dist logit = self.generator(pre_logit,attn_dist,enc_batch_extend_vocab, extra_zeros) ## loss: NLL if ptr else Cross entropy loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1)) if(train): loss.backward() self.optimizer.step() if(config.label_smoothing): loss = self.criterion_ppl(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1)) return loss.item(), math.exp(min(loss.item(), 100)), loss
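# --- Hedged sketch (added for clarity, not part of the original file) ---
# Every constructor in this file optionally wraps Adam in NoamOpt(hidden_dim,
# 1, warmup, Adam(lr=0, ...)). The wrapper below follows the well-known
# "Annotated Transformer" recipe -- the learning rate is
# factor * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5) -- which is why
# the inner Adam is built with lr=0. The repository's class may differ in
# minor details.
import torch

class NoamOpt:
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self.model_size = model_size
        self.factor = factor
        self.warmup = warmup
        self._step = 0

    def rate(self, step=None):
        step = self._step if step is None else step
        return (self.factor * self.model_size ** (-0.5)
                * min(step ** (-0.5), step * self.warmup ** (-1.5)))

    def step(self):
        self._step += 1
        for group in self.optimizer.param_groups:
            group['lr'] = self.rate()
        self.optimizer.step()

params = [torch.nn.Parameter(torch.randn(2, 2))]
opt = NoamOpt(300, 1, 8000, torch.optim.Adam(params, lr=0, betas=(0.9, 0.98), eps=1e-9))
print(opt.rate(1), opt.rate(8000))  # the rate ramps up during warmup, then decays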
class Transformer(nn.Module): def __init__(self, vocab, model_file_path=None, is_eval=False, load_optim=False): super(Transformer, self).__init__() self.vocab = vocab self.vocab_size = vocab.n_words self.embedding = share_embedding(self.vocab, config.preptrained) self.encoder = Encoder(config.emb_dim, config.hidden_dim, num_layers=config.hop, num_heads=config.heads, total_key_depth=config.depth, total_value_depth=config.depth, filter_size=config.filter, universal=config.universal) self.decoder = Decoder(config.emb_dim, config.hidden_dim, num_layers=config.hop, num_heads=config.heads, total_key_depth=config.depth, total_value_depth=config.depth, filter_size=config.filter, universal=config.universal) self.generator = Generator(config.hidden_dim, self.vocab_size) if config.weight_sharing: # Share the weight matrix between target word embedding & the final # logit dense layer self.generator.proj.weight = self.embedding.weight self.criterion = nn.NLLLoss(ignore_index=config.PAD_idx) if (config.label_smoothing): self.criterion = LabelSmoothing(size=self.vocab_size, padding_idx=config.PAD_idx, smoothing=0.1) self.criterion_ppl = nn.NLLLoss(ignore_index=config.PAD_idx) if is_eval: self.encoder = self.encoder.eval() self.decoder = self.decoder.eval() self.generator = self.generator.eval() self.embedding = self.embedding.eval() self.optimizer = torch.optim.Adam(self.parameters(), lr=config.lr) if (config.noam): self.optimizer = NoamOpt( config.hidden_dim, 1, 4000, torch.optim.Adam(self.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) if config.use_sgd: self.optimizer = torch.optim.SGD(self.parameters(), lr=config.lr) if model_file_path is not None: print("loading weights") state = torch.load(model_file_path, map_location=lambda storage, location: storage) print("LOSS", state['current_loss']) self.encoder.load_state_dict(state['encoder_state_dict']) self.decoder.load_state_dict(state['decoder_state_dict']) self.generator.load_state_dict(state['generator_dict']) self.embedding.load_state_dict(state['embedding_dict']) if (load_optim): self.optimizer.load_state_dict(state['optimizer']) if (config.USE_CUDA): self.encoder = self.encoder.cuda() self.decoder = self.decoder.cuda() self.generator = self.generator.cuda() self.criterion = self.criterion.cuda() self.embedding = self.embedding.cuda() self.model_dir = config.save_path if not os.path.exists(self.model_dir): os.makedirs(self.model_dir) self.best_path = "" def save_model(self, running_avg_ppl, iter, f1_g, f1_b, ent_g, ent_b): state = { 'iter': iter, 'encoder_state_dict': self.encoder.state_dict(), 'decoder_state_dict': self.decoder.state_dict(), 'generator_dict': self.generator.state_dict(), 'embedding_dict': self.embedding.state_dict(), # 'optimizer': self.optimizer.state_dict(), 'current_loss': running_avg_ppl } model_save_path = os.path.join( self.model_dir, 'model_{}_{:.4f}_{:.4f}_{:.4f}_{:.4f}_{:.4f}'.format( iter, running_avg_ppl, f1_g, f1_b, ent_g, ent_b)) self.best_path = model_save_path torch.save(state, model_save_path) def train_one_batch(self, batch, train=True): # pad and other stuff enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = \ get_input_from_batch(batch) dec_batch, _, _, _, _ = get_output_from_batch(batch) if (config.noam): self.optimizer.optimizer.zero_grad() else: self.optimizer.zero_grad() # Encode mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1) encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src) # Decode sos_token = torch.LongTensor([config.SOS_idx] * enc_batch.size(0)).unsqueeze(1) if 
config.USE_CUDA: sos_token = sos_token.cuda() dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1) mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1) pre_logit, attn_dist = self.decoder(self.embedding(dec_batch_shift), encoder_outputs, (mask_src, mask_trg)) # compute output dist logit = self.generator(pre_logit, attn_dist, enc_batch_extend_vocab, extra_zeros) # loss: NNL if ptr else Cross entropy loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1)) if (config.act): loss += self.compute_act_loss(self.encoder) loss += self.compute_act_loss(self.decoder) if (train): loss.backward() self.optimizer.step() if (config.label_smoothing): loss = self.criterion_ppl( logit.contiguous().view(-1, logit.size(-1)), dec_batch.contiguous().view(-1)) return loss.item(), math.exp(min(loss.item(), 100)), loss def compute_act_loss(self, module): R_t = module.remainders N_t = module.n_updates p_t = R_t + N_t avg_p_t = torch.sum(torch.sum(p_t, dim=1) / p_t.size(1)) / p_t.size(0) loss = config.act_loss_weight * avg_p_t.item() return loss def decoder_greedy(self, batch): enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch( batch) mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1) encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src) ys = torch.ones(1, 1).fill_(config.SOS_idx).long() if config.USE_CUDA: ys = ys.cuda() mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1) decoded_words = [] for i in range(config.max_dec_step): out, attn_dist = self.decoder(self.embedding(ys), encoder_outputs, (mask_src, mask_trg)) prob = self.generator(out, attn_dist, enc_batch_extend_vocab, extra_zeros) _, next_word = torch.max(prob[:, -1], dim=1) decoded_words.append([ '<EOS>' if ni.item() == config.EOS_idx else self.vocab.index2word[ni.item()] for ni in next_word.view(-1) ]) next_word = next_word.data[0] if config.USE_CUDA: ys = torch.cat( [ys, torch.ones(1, 1).long().fill_(next_word).cuda()], dim=1) ys = ys.cuda() else: ys = torch.cat( [ys, torch.ones(1, 1).long().fill_(next_word)], dim=1) mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1) sent = [] for _, row in enumerate(np.transpose(decoded_words)): st = '' for e in row: if e == '<EOS>': break else: st += e + ' ' sent.append(st) return sent def score_sentence(self, batch): # pad and other stuff enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch( batch) dec_batch, _, _, _, _ = get_output_from_batch(batch) cand_batch = batch["cand_index"] hit_1 = 0 for i, b in enumerate(enc_batch): # Encode mask_src = b.unsqueeze(0).data.eq(config.PAD_idx).unsqueeze(1) encoder_outputs = self.encoder(self.embedding(b.unsqueeze(0)), mask_src) rank = {} for j, c in enumerate(cand_batch[i]): if config.USE_CUDA: c = c.cuda() # Decode sos_token = torch.LongTensor( [config.SOS_idx] * b.unsqueeze(0).size(0)).unsqueeze(1) if config.USE_CUDA: sos_token = sos_token.cuda() dec_batch_shift = torch.cat( (sos_token, c.unsqueeze(0)[:, :-1]), 1) mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1) pre_logit, attn_dist = self.decoder( self.embedding(dec_batch_shift), encoder_outputs, (mask_src, mask_trg)) # compute output dist logit = self.generator(pre_logit, attn_dist, enc_batch_extend_vocab[i].unsqueeze(0), extra_zeros) loss = self.criterion( logit.contiguous().view(-1, logit.size(-1)), c.unsqueeze(0).contiguous().view(-1)) # print("CANDIDATE {}".format(j), loss.item(), math.exp(min(loss.item(), 100))) rank[j] = math.exp(min(loss.item(), 100)) s = 
sorted(rank.items(), key=lambda x: x[1], reverse=False) if ( s[1][0] == 19 ): # because the candidates are sorted in reverse order, the last one (19) is the correct one hit_1 += 1 return hit_1 / float(len(enc_batch))
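# --- Hedged sketch (added for clarity, not part of the original file) ---
# Several constructors above swap nn.NLLLoss for LabelSmoothing(size,
# padding_idx, smoothing=0.1) when config.label_smoothing is on. The sketch
# below follows the standard Annotated-Transformer formulation -- compare the
# model's log-probabilities to a smoothed target distribution with KL
# divergence, ignoring padding positions; the repository's implementation may
# differ slightly.
import torch
import torch.nn as nn

class LabelSmoothing(nn.Module):
    def __init__(self, size, padding_idx, smoothing=0.1):
        super(LabelSmoothing, self).__init__()
        self.criterion = nn.KLDivLoss(reduction='sum')
        self.size = size
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing

    def forward(self, x, target):
        # x: (N, size) log-probabilities, target: (N,) gold indices
        true_dist = x.data.clone()
        true_dist.fill_(self.smoothing / (self.size - 2))
        true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        true_dist[:, self.padding_idx] = 0
        pad_rows = torch.nonzero(target.data == self.padding_idx)
        if pad_rows.numel() > 0:
            true_dist.index_fill_(0, pad_rows.squeeze(-1), 0.0)
        return self.criterion(x, true_dist)

crit = LabelSmoothing(size=10, padding_idx=0, smoothing=0.1)
log_probs = torch.log_softmax(torch.randn(4, 10), dim=-1)
print(crit(log_probs, torch.tensor([3, 5, 0, 2])).item())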