def decoder_topk(self, batch, max_dec_step=30):
    enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)
    mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
    emb_mask = self.embedding(batch["mask_input"])
    encoder_outputs = self.encoder(self.embedding(enc_batch) + emb_mask, mask_src)

    ys = torch.ones(1, 1).fill_(config.SOS_idx).long()
    if config.USE_CUDA:
        ys = ys.cuda()
    mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
    decoded_words = []
    for i in range(max_dec_step + 1):
        if config.project:
            out, attn_dist = self.decoder(
                self.embedding_proj_in(self.embedding(ys)),
                self.embedding_proj_in(encoder_outputs),
                (mask_src, mask_trg))
        else:
            out, attn_dist = self.decoder(self.embedding(ys), encoder_outputs,
                                          (mask_src, mask_trg))
        logit = self.generator(out, attn_dist, enc_batch_extend_vocab,
                               extra_zeros, attn_dist_db=None)
        filtered_logit = top_k_top_p_filtering(logit[:, -1], top_k=3, top_p=0,
                                               filter_value=-float('Inf'))
        # Sample from the filtered distribution
        next_word = torch.multinomial(F.softmax(filtered_logit, dim=-1), 1).squeeze()
        decoded_words.append([
            '<EOS>' if ni.item() == config.EOS_idx else self.vocab.index2word[ni.item()]
            for ni in next_word.view(-1)
        ])
        next_word = next_word.data[0]

        if config.USE_CUDA:
            ys = torch.cat([ys, torch.ones(1, 1).long().fill_(next_word).cuda()], dim=1)
            ys = ys.cuda()
        else:
            ys = torch.cat([ys, torch.ones(1, 1).long().fill_(next_word)], dim=1)
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

    sent = []
    for _, row in enumerate(np.transpose(decoded_words)):
        st = ''
        for e in row:
            if e == '<EOS>':
                break
            else:
                st += e + ' '
        sent.append(st)
    return sent

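# For reference, a minimal sketch of the kind of filtering `top_k_top_p_filtering` is assumed
# to apply above (only the top-k path, since it is called with top_p=0). The repo's actual
# helper may differ in details; `top_k_filtering_sketch` is a hypothetical name.
def top_k_filtering_sketch(logits, top_k=3, filter_value=-float('Inf')):
    """Keep the top_k largest logits per row and push the rest to filter_value,
    so the subsequent softmax puts (almost) zero mass on them."""
    if top_k > 0:
        kth_value = torch.topk(logits, top_k)[0][..., -1, None]  # value of the k-th largest logit per row
        logits = logits.masked_fill(logits < kth_value, filter_value)
    return logits
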
def train_one_batch(self, batch, train=True):
    ## pad and other stuff
    enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)
    dec_batch, _, _, _, _ = get_output_from_batch(batch)

    if config.noam:
        self.optimizer.optimizer.zero_grad()
    else:
        self.optimizer.zero_grad()

    ## Encode
    mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
    encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)

    # Decode
    sos_token = torch.LongTensor([config.SOS_idx] * enc_batch.size(0)).unsqueeze(1)
    if config.USE_CUDA:
        sos_token = sos_token.cuda()
    dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1)
    mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
    pre_logit, attn_dist = self.decoder(self.embedding(dec_batch_shift), encoder_outputs,
                                        (mask_src, mask_trg))

    ## compute output dist
    logit = self.generator(pre_logit, attn_dist, enc_batch_extend_vocab, extra_zeros)

    ## loss: NLL if pointer-generator else cross entropy
    loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                          dec_batch.contiguous().view(-1))

    if train:
        loss.backward()
        self.optimizer.step()

    if config.label_smoothing:
        loss = self.criterion_ppl(logit.contiguous().view(-1, logit.size(-1)),
                                  dec_batch.contiguous().view(-1))

    return loss.item(), math.exp(min(loss.item(), 100)), loss

def train_one_batch(self, batch, train=True):
    if train:
        self.model.train()
    enc_batch, enc_mask, _, enc_batch_extend_vocab, extra_zeros, _, _ = \
        get_input_from_batch(batch)
    dec_batch, dec_mask, _, _, _ = get_output_from_batch(batch)
    dec_batch_input, dec_batch_output = dec_batch[:, :-1], dec_batch[:, 1:]
    dec_mask = dec_mask[:, :-1]

    self.optimizer.zero_grad()
    logit, *_ = self.model(
        enc_batch,
        attention_mask=enc_mask,
        decoder_input_ids=dec_batch_input,
        decoder_attention_mask=dec_mask,
    )

    # loss: NLL if pointer-generator else cross entropy
    loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                          dec_batch_output.contiguous().view(-1))

    if config.act:
        loss += self.compute_act_loss(self.encoder)
        loss += self.compute_act_loss(self.decoder)

    if train:
        loss.backward()
        self.optimizer.step()

    if config.label_smoothing:
        loss = self.criterion_ppl(logit.contiguous().view(-1, logit.size(-1)),
                                  dec_batch_output.contiguous().view(-1))

    return loss.item(), math.exp(min(loss.item(), 100)), loss

def decoder_greedy(self, batch, max_dec_step=50):
    enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)

    ## Encode
    mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
    emb_mask = self.embedding(batch["input_mask"])
    meta = self.embedding(batch["program_label"])
    encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)

    mask_trg = torch.ones((enc_batch.size(0), 50))
    meta_size = meta.size()
    meta = meta.repeat(1, 50).view(meta_size[0], 50, meta_size[1])
    out, attn_dist, _, _ = self.decoder(meta, encoder_outputs, None,
                                        (mask_src, None, mask_trg))
    prob = self.generator(out, attn_dist, enc_batch_extend_vocab, extra_zeros,
                          attn_dist_db=None)
    _, batch_out = torch.max(prob, dim=-1)  # argmax over the vocabulary at every position
    batch_out = batch_out.data.cpu().numpy()

    sentences = []
    for sent in batch_out:
        st = ''
        for w in sent:
            if w == config.EOS_idx:
                break
            else:
                st += self.vocab.index2word[w] + ' '
        sentences.append(st)
    return sentences

def train_one_batch(self, batch, iter, train=True):
    enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)
    dec_batch, _, _, _, _ = get_output_from_batch(batch)

    if config.noam:
        self.optimizer.optimizer.zero_grad()
    else:
        self.optimizer.zero_grad()

    ## Response encode
    mask_res = batch["posterior_batch"].data.eq(config.PAD_idx).unsqueeze(1)
    posterior_mask = self.embedding(batch["posterior_mask"])
    r_encoder_outputs = self.r_encoder(self.embedding(batch["posterior_batch"]), mask_res)

    ## Encode
    mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
    emb_mask = self.embedding(batch["input_mask"])
    encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)
    meta = self.embedding(batch["program_label"])

    # Decode
    mask_trg = dec_batch.data.eq(config.PAD_idx).unsqueeze(1)
    latent_dim = meta.size()[-1]
    meta = meta.repeat(1, dec_batch.size(1)).view(dec_batch.size(0), dec_batch.size(1),
                                                  latent_dim)
    pre_logit, attn_dist, mean, log_var = self.decoder(
        meta, encoder_outputs, r_encoder_outputs, (mask_src, mask_res, mask_trg))
    if not train:
        pre_logit, attn_dist, _, _ = self.decoder(meta, encoder_outputs, None,
                                                  (mask_src, None, mask_trg))

    ## compute output dist
    logit = self.generator(pre_logit, attn_dist,
                           enc_batch_extend_vocab if config.pointer_gen else None,
                           extra_zeros, attn_dist_db=None)

    ## loss: NLL if pointer-generator else cross entropy
    loss_rec = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                              dec_batch.contiguous().view(-1))
    kld_loss = gaussian_kld(mean["posterior"], log_var["posterior"],
                            mean["prior"], log_var["prior"])
    kld_loss = torch.mean(kld_loss)
    kl_weight = min(iter / config.full_kl_step, 1) if config.full_kl_step > 0 else 1.0
    loss = loss_rec + config.kl_ceiling * kl_weight * kld_loss

    if train:
        loss.backward()
        # clip gradient
        nn.utils.clip_grad_norm_(self.parameters(), config.max_grad_norm)
        self.optimizer.step()

    return loss_rec.item(), math.exp(min(loss_rec.item(), 100)), kld_loss.item()

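# A minimal sketch of what `gaussian_kld` is assumed to compute above: the KL divergence
# KL(q || p) between two diagonal Gaussians given their means and log-variances (the caller
# then averages over the batch). The repo's helper may differ; `gaussian_kld_sketch` is a
# hypothetical name.
def gaussian_kld_sketch(mu_q, logvar_q, mu_p, logvar_p):
    """Per-example KL( N(mu_q, exp(logvar_q)) || N(mu_p, exp(logvar_p)) ), summed over latent dims."""
    return 0.5 * torch.sum(
        logvar_p - logvar_q
        + (torch.exp(logvar_q) + (mu_q - mu_p) ** 2) / torch.exp(logvar_p)
        - 1.0,
        dim=-1)
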
def train_one_batch(self, batch, train=True):
    ## pad and other stuff
    enc_batch, _, enc_lens, enc_batch_extend_vocab, extra_zeros, _, _, _ = get_input_from_batch(batch)
    dec_batch, _, _, _, _, _ = get_output_from_batch(batch)

    if config.noam:
        self.optimizer.optimizer.zero_grad()
    else:
        self.optimizer.zero_grad()

    ## Encode
    self.h0_encoder, self.c0_encoder = self.get_state(enc_batch)
    src_h, (src_h_t, src_c_t) = self.encoder(self.embedding(enc_batch),
                                             (self.h0_encoder, self.c0_encoder))
    h_t = src_h_t[-1]
    c_t = src_c_t[-1]

    self.h0_encoder_r, self.c0_encoder_r = self.get_state(dec_batch)
    src_h_r, (src_h_t_r, src_c_t_r) = self.encoder_r(self.embedding(dec_batch),
                                                     (self.h0_encoder_r, self.c0_encoder_r))
    h_t_r = src_h_t_r[-1]
    c_t_r = src_c_t_r[-1]

    # sample and reparameterize
    z_sample, mu, var = self.represent(torch.cat((h_t_r, h_t), 1))
    p_z_sample, p_mu, p_var = self.prior(h_t)

    # Decode
    decoder_init_state = nn.Tanh()(self.encoder2decoder(torch.cat((z_sample, h_t), 1)))
    sos_token = torch.LongTensor([config.SOS_idx] * enc_batch.size(0)).unsqueeze(1)
    if config.USE_CUDA:
        sos_token = sos_token.cuda()
    dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1)
    target_embedding = self.embedding(dec_batch_shift)
    ctx = src_h.transpose(0, 1)
    trg_h, (_, _) = self.decoder(target_embedding, (decoder_init_state, c_t), ctx)
    pre_logit = trg_h
    logit = self.generator(pre_logit)

    ## loss: NLL if pointer-generator else cross entropy
    re_loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                             dec_batch.contiguous().view(-1))
    # KL between the recognition (posterior) and prior Gaussians, with log-variances var / p_var
    kl_losses = 0.5 * torch.sum(
        torch.exp(var - p_var) + (mu - p_mu) ** 2 / torch.exp(p_var) - 1. - var + p_var, 1)
    kl_loss = torch.mean(kl_losses)

    latent_logit = self.mlp_b(torch.cat((z_sample, h_t), 1)).unsqueeze(1)
    latent_logit = F.log_softmax(latent_logit, dim=-1)
    latent_logits = latent_logit.repeat(1, logit.size(1), 1)
    bow_loss = self.criterion(latent_logits.contiguous().view(-1, latent_logits.size(-1)),
                              dec_batch.contiguous().view(-1))
    loss = re_loss + 0.48 * kl_loss + bow_loss

    if train:
        loss.backward()
        self.optimizer.step()

    if config.label_smoothing:
        s_loss = self.criterion_ppl(logit.contiguous().view(-1, logit.size(-1)),
                                    dec_batch.contiguous().view(-1))

    return s_loss.item(), math.exp(min(s_loss.item(), 100)), loss.item(), re_loss.item(), kl_loss.item(), bow_loss.item()

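# The kl_losses expression above is the closed-form KL divergence between the recognition
# network's Gaussian q = N(mu, exp(var)) and the prior p = N(p_mu, exp(p_var)), where `var`
# and `p_var` hold log-variances:
#   KL(q || p) = 0.5 * sum_d [ exp(var - p_var) + (mu - p_mu)^2 / exp(p_var) - 1 - var + p_var ]
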
def decoder_greedy(self, batch, max_dec_step=50):
    enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)

    ## Encode
    mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
    emb_mask = self.embedding(batch["input_mask"])
    meta = self.embedding(batch["program_label"])
    if config.dataset == "empathetic":
        meta = meta - meta  # zero out the program embedding for EmpatheticDialogues
    encoder_outputs = self.encoder(self.embedding(enc_batch) + emb_mask, mask_src)
    if config.model == "cvaetrs":
        kld_loss, z = self.latent_layer(encoder_outputs[:, 0], None, train=False)

    ys = torch.ones(enc_batch.shape[0], 1).fill_(config.SOS_idx).long()
    if config.USE_CUDA:
        ys = ys.cuda()
    mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
    decoded_words = []
    for i in range(max_dec_step + 1):
        input_vector = self.embedding(ys)
        if config.model == "cvaetrs":
            input_vector[:, 0] = input_vector[:, 0] + z + meta
        else:
            input_vector[:, 0] = input_vector[:, 0] + meta
        out, attn_dist = self.decoder(input_vector, encoder_outputs, (mask_src, mask_trg))
        prob = self.generator(out, attn_dist, enc_batch_extend_vocab, extra_zeros,
                              attn_dist_db=None)
        _, next_word = torch.max(prob[:, -1], dim=1)
        decoded_words.append([
            '<EOS>' if ni.item() == config.EOS_idx else self.vocab.index2word[ni.item()]
            for ni in next_word.view(-1)
        ])

        if config.USE_CUDA:
            ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)
            ys = ys.cuda()
        else:
            ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

    sent = []
    for _, row in enumerate(np.transpose(decoded_words)):
        st = ''
        for e in row:
            if e == '<EOS>':
                break
            else:
                st += e + ' '
        sent.append(st)
    return sent

def decoder_greedy_po(self, batch, max_dec_step=50):
    enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)

    mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
    emb_mask = self.embedding(batch["input_mask"])
    meta = self.embedding(batch["program_label"])
    encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)

    ys = torch.ones(enc_batch.shape[0], 1).fill_(config.SOS_idx).long()
    if config.USE_CUDA:
        ys = ys.cuda()
    mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
    decoded_words = []
    for i in range(max_dec_step + 1):
        out, attn_dist = self.decoder(self.embedding(ys) + meta.unsqueeze(1),
                                      encoder_outputs, (mask_src, mask_trg))
        prob = self.generator(out, attn_dist, enc_batch_extend_vocab, extra_zeros,
                              attn_dist_db=None)
        _, next_word = torch.max(prob[:, -1], dim=1)
        decoded_words.append([
            '<EOS>' if ni.item() == config.EOS_idx else self.vocab.index2word[ni.item()]
            for ni in next_word.view(-1)
        ])

        if config.USE_CUDA:
            ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)
            ys = ys.cuda()
        else:
            ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

    sent = []
    for _, row in enumerate(np.transpose(decoded_words)):
        st = ''
        for e in row:
            if e == '<EOS>':
                break
            else:
                st += e + ' '
        sent.append(st)
    return sent

def train_one_batch(self, batch, iter, train=True):
    enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)
    dec_batch, _, _, _, _ = get_output_from_batch(batch)

    if config.noam:
        self.optimizer.optimizer.zero_grad()
    else:
        self.optimizer.zero_grad()

    ## Encode
    mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
    emb_mask = self.embedding(batch["mask_input"])
    encoder_outputs = self.encoder(self.embedding(enc_batch) + emb_mask, mask_src)

    # Decode
    sos_token = torch.LongTensor([config.SOS_idx] * enc_batch.size(0)).unsqueeze(1)
    if config.USE_CUDA:
        sos_token = sos_token.cuda()
    dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1)
    mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
    pre_logit, attn_dist = self.decoder(self.embedding(dec_batch_shift), encoder_outputs,
                                        (mask_src, mask_trg))

    ## compute output dist
    logit = self.generator(pre_logit, attn_dist,
                           enc_batch_extend_vocab if config.pointer_gen else None,
                           extra_zeros, attn_dist_db=None)

    ## loss: NLL if pointer-generator else cross entropy
    loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                          dec_batch.contiguous().view(-1))

    # multi-task: also predict the program (emotion) label from the first encoder state
    if config.multitask:
        q_h = encoder_outputs[:, 0]
        logit_prob = self.decoder_key(q_h)
        loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                              dec_batch.contiguous().view(-1)) + nn.CrossEntropyLoss()(
                                  logit_prob, torch.LongTensor(batch['program_label']).cuda())
        loss_bce_program = nn.CrossEntropyLoss()(
            logit_prob, torch.LongTensor(batch['program_label']).cuda()).item()
        pred_program = np.argmax(logit_prob.detach().cpu().numpy(), axis=1)
        program_acc = accuracy_score(batch["program_label"], pred_program)

    if config.label_smoothing:
        loss_ppl = self.criterion_ppl(logit.contiguous().view(-1, logit.size(-1)),
                                      dec_batch.contiguous().view(-1)).item()

    if train:
        loss.backward()
        self.optimizer.step()

    if config.multitask:
        if config.label_smoothing:
            return loss_ppl, math.exp(min(loss_ppl, 100)), loss_bce_program, program_acc
        else:
            return loss.item(), math.exp(min(loss.item(), 100)), loss_bce_program, program_acc
    else:
        if config.label_smoothing:
            return loss_ppl, math.exp(min(loss_ppl, 100)), 0, 0
        else:
            return loss.item(), math.exp(min(loss.item(), 100)), 0, 0

def train_n_batch(self, batchs, iter, train=True):
    if config.noam:
        self.optimizer.optimizer.zero_grad()
    else:
        self.optimizer.zero_grad()

    for batch in batchs:
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)
        dec_batch, _, _, _, _ = get_output_from_batch(batch)

        ## Encode
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)
        meta = self.embedding(batch["program_label"])
        if config.dataset == "empathetic":
            meta = meta - meta

        # Decode
        sos_token = torch.LongTensor([config.SOS_idx] * enc_batch.size(0)).unsqueeze(1)
        if config.USE_CUDA:
            sos_token = sos_token.cuda()
        dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1)
        mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
        pre_logit, attn_dist, mean, log_var, probs = self.decoder(
            self.embedding(dec_batch_shift) + meta.unsqueeze(1), encoder_outputs, True,
            (mask_src, mask_trg))

        ## compute output dist
        logit = self.generator(pre_logit, attn_dist,
                               enc_batch_extend_vocab if config.pointer_gen else None,
                               extra_zeros, attn_dist_db=None)

        ## loss: NLL if pointer-generator else cross entropy
        sbow = dec_batch  # [batch, seq_len]
        seq_len = sbow.size(1)
        loss_rec = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                                  dec_batch.contiguous().view(-1))
        if config.model == "cvaetrs":
            loss_aux = 0
            for prob in probs:
                sbow_mask = _get_attn_subsequent_mask(seq_len).transpose(1, 2)
                sbow.unsqueeze(2).repeat(1, 1, seq_len).masked_fill_(
                    sbow_mask, config.PAD_idx)  # [batch, seq_len, seq_len]
                loss_aux += self.criterion(prob.contiguous().view(-1, prob.size(-1)),
                                           sbow.contiguous().view(-1))
            kld_loss = gaussian_kld(mean["posterior"], log_var["posterior"],
                                    mean["prior"], log_var["prior"])
            kld_loss = torch.mean(kld_loss)
            kl_weight = min(math.tanh(6 * iter / config.full_kl_step - 3) + 1, 1)
            loss = loss_rec + config.kl_ceiling * kl_weight * kld_loss + config.aux_ceiling * loss_aux
            elbo = loss_rec + kld_loss
        else:
            loss = loss_rec
            elbo = loss_rec
            kld_loss = torch.Tensor([0])
            loss_aux = torch.Tensor([0])

        # accumulate gradients over the batches before the single optimizer step below
        loss.backward()

    # clip gradient
    nn.utils.clip_grad_norm_(self.parameters(), config.max_grad_norm)
    self.optimizer.step()

    return loss_rec.item(), math.exp(min(loss_rec.item(), 100)), kld_loss.item(), loss_aux.item(), elbo.item()

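# A minimal sketch of the causal mask `_get_attn_subsequent_mask` is assumed to return above:
# a (1, size, size) mask where position (i, j) is 1 for j > i, i.e. future positions are hidden.
# The repo's helper may differ; `subsequent_mask_sketch` is a hypothetical name.
def subsequent_mask_sketch(size):
    """Return a (1, size, size) upper-triangular mask that hides future positions (1 = masked)."""
    mask = torch.from_numpy(np.triu(np.ones((1, size, size), dtype='uint8'), k=1))
    return mask.cuda() if config.USE_CUDA else mask
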
def forward(self, batch):
    enc_batch, enc_mask, _, enc_batch_extend_vocab, extra_zeros, _, _ = \
        get_input_from_batch(batch)
    dec_batch, dec_mask, _, _, _ = get_output_from_batch(batch)
    dec_batch_input, dec_batch_output = dec_batch[:, :-1], dec_batch[:, 1:]
    dec_mask = dec_mask[:, :-1]

    self.optimizer.zero_grad()
    loss = self.model(
        input_ids=enc_batch,
        decoder_input_ids=dec_batch_input,
        lm_labels=dec_batch_output,
        attention_mask=enc_mask,
        decoder_attention_mask=dec_mask,
    )[0]
    return loss.item(), math.exp(min(loss.item(), 600)), loss

def train_one_batch(self, batch, iter, train=True):
    enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)
    dec_batch, _, _, _, _ = get_output_from_batch(batch)

    if config.noam:
        self.optimizer.optimizer.zero_grad()
    else:
        self.optimizer.zero_grad()

    ## Encode
    mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
    meta = self.embedding(batch["program_label"])
    emb_mask = self.embedding(batch["input_mask"])
    encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)

    # Decode
    sos_token = torch.LongTensor([config.SOS_idx] * enc_batch.size(0)).unsqueeze(1)
    if config.USE_CUDA:
        sos_token = sos_token.cuda()
    dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1)
    mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
    pre_logit, attn_dist = self.decoder(self.embedding(dec_batch_shift) + meta.unsqueeze(1),
                                        encoder_outputs, (mask_src, mask_trg))

    ## compute output dist
    logit = self.generator(pre_logit, attn_dist,
                           enc_batch_extend_vocab if config.pointer_gen else None,
                           extra_zeros, attn_dist_db=None)

    ## loss: NLL if pointer-generator else cross entropy
    loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                          dec_batch.contiguous().view(-1))

    if train:
        loss.backward()
        self.optimizer.step()

    return loss.item(), math.exp(min(loss.item(), 100)), 0

def score_sentence(self, batch):
    # pad and other stuff
    enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)
    dec_batch, _, _, _, _ = get_output_from_batch(batch)
    cand_batch = batch["cand_index"]
    hit_1 = 0
    for i, b in enumerate(enc_batch):
        # Encode
        mask_src = b.unsqueeze(0).data.eq(config.PAD_idx).unsqueeze(1)
        encoder_outputs = self.encoder(self.embedding(b.unsqueeze(0)), mask_src)
        rank = {}
        for j, c in enumerate(cand_batch[i]):
            if config.USE_CUDA:
                c = c.cuda()
            # Decode
            sos_token = torch.LongTensor([config.SOS_idx] * b.unsqueeze(0).size(0)).unsqueeze(1)
            if config.USE_CUDA:
                sos_token = sos_token.cuda()
            dec_batch_shift = torch.cat((sos_token, c.unsqueeze(0)[:, :-1]), 1)
            mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
            pre_logit, attn_dist = self.decoder(self.embedding(dec_batch_shift),
                                                encoder_outputs, (mask_src, mask_trg))
            # compute output dist
            logit = self.generator(pre_logit, attn_dist,
                                   enc_batch_extend_vocab[i].unsqueeze(0), extra_zeros)
            loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                                  c.unsqueeze(0).contiguous().view(-1))
            rank[j] = math.exp(min(loss.item(), 100))
        # rank the candidates by perplexity (ascending)
        s = sorted(rank.items(), key=lambda x: x[1], reverse=False)
        # the candidates are sorted in reverse order, so the last one (index 19) is the gold response
        if s[0][0] == 19:
            hit_1 += 1
    return hit_1 / float(len(enc_batch))

def decoder_greedy(self, batch):
    enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)

    mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
    encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)

    ys = torch.ones(1, 1).fill_(config.SOS_idx).long()
    if config.USE_CUDA:
        ys = ys.cuda()
    mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
    decoded_words = []
    for i in range(config.max_dec_step):
        out, attn_dist = self.decoder(self.embedding(ys), encoder_outputs,
                                      (mask_src, mask_trg))
        prob = self.generator(out, attn_dist, enc_batch_extend_vocab, extra_zeros)
        _, next_word = torch.max(prob[:, -1], dim=1)
        decoded_words.append([
            '<EOS>' if ni.item() == config.EOS_idx else self.vocab.index2word[ni.item()]
            for ni in next_word.view(-1)
        ])
        next_word = next_word.data[0]

        if config.USE_CUDA:
            ys = torch.cat([ys, torch.ones(1, 1).long().fill_(next_word).cuda()], dim=1)
            ys = ys.cuda()
        else:
            ys = torch.cat([ys, torch.ones(1, 1).long().fill_(next_word)], dim=1)
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

    sent = []
    for _, row in enumerate(np.transpose(decoded_words)):
        st = ''
        for e in row:
            if e == '<EOS>':
                break
            else:
                st += e + ' '
        sent.append(st)
    return sent

def decoder_greedy(self, batch, max_dec_step=31):
    enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
        get_input_from_batch(batch)

    encoder_outputs, encoder_hidden = self.encoder(enc_batch, enc_lens)
    s_t_1 = encoder_hidden.transpose(0, 1).contiguous().view(
        -1, config.hidden_dim * 2)  # b x hidden_dim*2
    kld_loss, z = self.latent(s_t_1, None, False)
    if config.model == "seq2seq":
        z = z - z
    s_t_1 = torch.cat((z, s_t_1), dim=-1)

    batch_size = enc_batch.size(0)
    y_t_1 = torch.LongTensor([config.SOS_idx] * batch_size)
    if config.USE_CUDA:
        y_t_1 = y_t_1.cuda()

    decoded_words = []
    for di in range(max_dec_step):
        final_dist, s_t_1, c_t_1, attn_dist, p_gen, coverage = self.decoder(
            y_t_1, s_t_1, encoder_outputs, enc_padding_mask, c_t_1, extra_zeros,
            enc_batch_extend_vocab, coverage, di)
        _, topk_ids = torch.topk(final_dist, 1)
        decoded_words.append([
            '<EOS>' if ni.item() == config.EOS_idx else self.vocab.index2word[ni.item()]
            for ni in topk_ids.view(-1)
        ])
        # feed the model's own prediction back in at the next step
        if config.USE_CUDA:
            y_t_1 = topk_ids.squeeze(-1).cuda()
        else:
            y_t_1 = topk_ids.squeeze(-1)

    sent = []
    for _, row in enumerate(np.transpose(decoded_words)):
        st = ''
        for e in row:
            if e == '<EOS>':
                break
            else:
                st += e + ' '
        sent.append(st)
    return sent

def train_one_batch(self, batch, train=True):
    ## pad and other stuff
    input_ids_batch, input_mask_batch, example_index_batch, enc_batch_extend_vocab, extra_zeros, _ = \
        get_input_from_batch(batch)
    dec_batch, dec_mask_batch, dec_index_batch, copy_gate, copy_ptr = get_output_from_batch(batch)

    if config.noam:
        self.optimizer.optimizer.zero_grad()
    else:
        self.optimizer.zero_grad()

    with torch.no_grad():
        # encoder_outputs are hidden states from the transformer encoder
        encoder_outputs, _ = self.encoder(input_ids_batch,
                                          token_type_ids=example_index_batch,
                                          attention_mask=input_mask_batch,
                                          output_all_encoded_layers=False)

    # Draft decoder
    sos_token = torch.LongTensor([config.SOS_idx] * input_ids_batch.size(0)).unsqueeze(1)
    if config.USE_CUDA:
        sos_token = sos_token.cuda(device=0)
    dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1)  # shift the decoder input (summary) by one step
    mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
    pre_logit1, attn_dist1 = self.decoder(self.embedding(dec_batch_shift), encoder_outputs,
                                          (None, mask_trg))

    ## compute output dist
    logit1 = self.generator(pre_logit1, attn_dist1, enc_batch_extend_vocab, extra_zeros,
                            copy_gate=copy_gate, copy_ptr=copy_ptr, mask_trg=mask_trg)
    ## loss: NLL if pointer-generator else cross entropy
    loss1 = self.criterion(logit1.contiguous().view(-1, logit1.size(-1)),
                           dec_batch.contiguous().view(-1))

    # Refine decoder - train using the gold TARGET
    # TODO: turn gold-target-text into BERT insertable representation
    pre_logit2, attn_dist2 = self.generate_refinement_output(
        encoder_outputs, dec_batch, dec_index_batch, extra_zeros, dec_mask_batch)
    logit2 = self.generator(pre_logit2, attn_dist2, enc_batch_extend_vocab, extra_zeros,
                            copy_gate=copy_gate, copy_ptr=copy_ptr, mask_trg=None)
    loss2 = self.criterion(logit2.contiguous().view(-1, logit2.size(-1)),
                           dec_batch.contiguous().view(-1))

    loss = loss1 + loss2

    if train:
        loss.backward()
        self.optimizer.step()

    return loss

def train_one_batch(self, batch, train=True, mode="pretrain", task=0):
    enc_batch, _, enc_lens, enc_batch_extend_vocab, extra_zeros, _, _, _ = get_input_from_batch(batch)
    dec_batch, _, _, _, _, _ = get_output_from_batch(batch)

    if config.noam:
        self.optimizer.optimizer.zero_grad()
    else:
        self.optimizer.zero_grad()

    ## Encode
    self.h0_encoder, self.c0_encoder = self.get_state(enc_batch)
    src_h, (src_h_t, src_c_t) = self.encoder(self.embedding(enc_batch),
                                             (self.h0_encoder, self.c0_encoder))
    h_t = src_h_t[-1]
    c_t = src_c_t[-1]

    # Decode
    decoder_init_state = nn.Tanh()(self.encoder2decoder(h_t))
    sos_token = torch.LongTensor([config.SOS_idx] * enc_batch.size(0)).unsqueeze(1)
    if config.USE_CUDA:
        sos_token = sos_token.cuda()
    dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1)
    target_embedding = self.embedding(dec_batch_shift)
    ctx = src_h.transpose(0, 1)
    trg_h, (_, _) = self.decoder(target_embedding, (decoder_init_state, c_t), ctx)

    # Memory
    mem_h_input = torch.cat((decoder_init_state.unsqueeze(1), trg_h[:, 0:-1, :]), 1)
    mem_input = torch.cat((target_embedding, mem_h_input), 2)
    mem_output = self.memory(mem_input)

    # Combine
    gates = self.dec_gate(trg_h) + self.mem_gate(mem_output)
    decoder_gate, memory_gate = gates.chunk(2, 2)
    decoder_gate = torch.sigmoid(decoder_gate)
    memory_gate = torch.sigmoid(memory_gate)
    pre_logit = torch.tanh(decoder_gate * trg_h + memory_gate * mem_output)
    logit = self.generator(pre_logit)

    if mode == "pretrain":
        loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                              dec_batch.contiguous().view(-1))
        if train:
            loss.backward()
            self.optimizer.step()
        if config.label_smoothing:
            loss = self.criterion_ppl(logit.contiguous().view(-1, logit.size(-1)),
                                      dec_batch.contiguous().view(-1))
        return loss.item(), math.exp(min(loss.item(), 100)), loss
    elif mode == "select":
        loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                              dec_batch.contiguous().view(-1))
        if train:
            # L1 regularization on the memory parameters to encourage sparsity
            l1_loss = 0.0
            for p in self.memory.parameters():
                l1_loss += torch.sum(torch.abs(p))
            loss += 0.0005 * l1_loss
            loss.backward()
            self.optimizer.step()
            self.compute_hooks(task)
        if config.label_smoothing:
            loss = self.criterion_ppl(logit.contiguous().view(-1, logit.size(-1)),
                                      dec_batch.contiguous().view(-1))
        return loss.item(), math.exp(min(loss.item(), 100)), loss
    else:
        loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                              dec_batch.contiguous().view(-1))
        if train:
            self.register_hooks(task)
            loss.backward()
            self.optimizer.step()
            self.unhook(task)
        if config.label_smoothing:
            loss = self.criterion_ppl(logit.contiguous().view(-1, logit.size(-1)),
                                      dec_batch.contiguous().view(-1))
        return loss.item(), math.exp(min(loss.item(), 100)), loss

def train_one_batch(self, batch, n_iter, train=True):
    enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
        get_input_from_batch(batch)
    dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
        get_output_from_batch(batch)

    self.optimizer.zero_grad()
    encoder_outputs, encoder_hidden = self.encoder(enc_batch, enc_lens)

    # sort responses by length for the LSTM response encoder, then restore the original order below
    r_len = np.array(batch["posterior_lengths"])
    r_sort = r_len.argsort()[::-1]
    r_len = r_len[r_sort].tolist()
    unsort = r_sort.argsort()
    _, encoder_hidden_r = self.encoder_r(batch["posterior_batch"][r_sort.tolist()], r_len)

    s_t_1 = encoder_hidden.transpose(0, 1).contiguous().view(
        -1, config.hidden_dim * 2)  # b x hidden_dim*2
    s_t_1_r = encoder_hidden_r.transpose(0, 1).contiguous().view(
        -1, config.hidden_dim * 2)[unsort.tolist()]  # unsort; b x hidden_dim*2
    batch_size = enc_batch.size(0)

    kld_loss, z = self.latent(s_t_1, s_t_1_r, train=True)
    if config.model == "seq2seq":
        z = z - z
        kld_loss = torch.Tensor([0])
    s_t_1 = torch.cat((z, s_t_1), dim=-1)
    if config.model == "cvae":
        z_logit = self.bow(s_t_1)  # [batch_size, vocab_size]
        z_logit = z_logit.unsqueeze(1).repeat(1, dec_batch.size(1), 1)
        loss_aux = self.criterion(z_logit.contiguous().view(-1, z_logit.size(-1)),
                                  dec_batch.contiguous().view(-1))

    y_t_1 = torch.LongTensor([config.SOS_idx] * batch_size)
    if config.USE_CUDA:
        y_t_1 = y_t_1.cuda()

    step_losses = []
    for di in range(max_dec_len):
        final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = self.decoder(
            y_t_1, s_t_1, encoder_outputs, enc_padding_mask, c_t_1, extra_zeros,
            enc_batch_extend_vocab, coverage, di)
        target = target_batch[:, di]
        gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze()
        step_loss = -torch.log(gold_probs + config.eps)
        if config.is_coverage:
            step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
            step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
            coverage = next_coverage
        step_mask = dec_padding_mask[:, di]
        step_loss = step_loss * step_mask.cuda()
        step_losses.append(step_loss)
        y_t_1 = dec_batch[:, di]  # Teacher forcing

    sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
    batch_avg_loss = sum_losses / dec_lens_var.float().cuda()
    loss_rec = torch.mean(batch_avg_loss)

    if config.model == "cvae":
        kl_weight = min(math.tanh(6 * n_iter / config.full_kl_step - 3) + 1, 1)
        loss = loss_rec + config.kl_ceiling * kl_weight * kld_loss + loss_aux * config.aux_ceiling
        elbo = loss_rec + kld_loss
    else:
        loss = loss_rec
        loss_aux = torch.Tensor([0])
        elbo = loss_rec

    if train:
        loss.backward()
        self.optimizer.step()

    return loss_rec.item(), math.exp(min(loss_rec.item(), 100)), kld_loss.item(), loss_aux.item(), elbo.item()

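# The tanh schedule above anneals the KL weight smoothly from ~0 to 1:
#   kl_weight(n) = min(tanh(6 * n / full_kl_step - 3) + 1, 1)
# For example, with full_kl_step = 10000:
#   n = 0    -> ~0.005,   n = 2500 -> ~0.09,   n = 4000 -> ~0.46,   n >= 5000 -> 1.0 (capped)
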
def train_one_batch(self, batch, iter, train=True):
    enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)
    dec_batch, _, _, _, _ = get_output_from_batch(batch)

    if config.noam:
        self.optimizer.optimizer.zero_grad()
    else:
        self.optimizer.zero_grad()

    ## Response encode
    mask_res = batch["posterior_batch"].data.eq(config.PAD_idx).unsqueeze(1)
    posterior_mask = self.embedding(batch["posterior_mask"])
    r_encoder_outputs = self.r_encoder(self.embedding(batch["posterior_batch"]) + posterior_mask,
                                       mask_res)

    ## Encode
    mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
    emb_mask = self.embedding(batch["input_mask"])
    encoder_outputs = self.encoder(self.embedding(enc_batch) + emb_mask, mask_src)

    # latent variable
    if config.model == "cvaetrs":
        kld_loss, z = self.latent_layer(encoder_outputs[:, 0], r_encoder_outputs[:, 0],
                                        train=True)
    meta = self.embedding(batch["program_label"])
    if config.dataset == "empathetic":
        meta = meta - meta

    # Decode
    sos_token = torch.LongTensor([config.SOS_idx] * enc_batch.size(0)).unsqueeze(1)
    if config.USE_CUDA:
        sos_token = sos_token.cuda()
    dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1)  # (batch, len)
    mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
    input_vector = self.embedding(dec_batch_shift)
    if config.model == "cvaetrs":
        input_vector[:, 0] = input_vector[:, 0] + z + meta
    else:
        input_vector[:, 0] = input_vector[:, 0] + meta
    pre_logit, attn_dist = self.decoder(input_vector, encoder_outputs, (mask_src, mask_trg))

    ## compute output dist
    logit = self.generator(pre_logit, attn_dist,
                           enc_batch_extend_vocab if config.pointer_gen else None,
                           extra_zeros, attn_dist_db=None)

    ## loss: NLL if pointer-generator else cross entropy
    loss_rec = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                              dec_batch.contiguous().view(-1))
    if config.model == "cvaetrs":
        z_logit = self.bow(z + meta)  # [batch_size, vocab_size]
        z_logit = z_logit.unsqueeze(1).repeat(1, logit.size(1), 1)
        loss_aux = self.criterion(z_logit.contiguous().view(-1, z_logit.size(-1)),
                                  dec_batch.contiguous().view(-1))
        if config.multitask:
            emo_logit = self.emo(encoder_outputs[:, 0])
            emo_loss = self.emo_criterion(emo_logit, batch["program_label"] - 9)
        kl_weight = min(math.tanh(6 * iter / config.full_kl_step - 3) + 1, 1)
        loss = loss_rec + config.kl_ceiling * kl_weight * kld_loss + config.aux_ceiling * loss_aux
        if config.multitask:
            loss = loss_rec + config.kl_ceiling * kl_weight * kld_loss + config.aux_ceiling * loss_aux + emo_loss
        aux = loss_aux.item()
        elbo = loss_rec + kld_loss
    else:
        loss = loss_rec
        elbo = loss_rec
        kld_loss = torch.Tensor([0])
        aux = 0
        if config.multitask:
            emo_logit = self.emo(encoder_outputs[:, 0])
            emo_loss = self.emo_criterion(emo_logit, batch["program_label"] - 9)
            loss = loss_rec + emo_loss

    if train:
        loss.backward()
        # clip gradient
        nn.utils.clip_grad_norm_(self.parameters(), config.max_grad_norm)
        self.optimizer.step()

    return loss_rec.item(), math.exp(min(loss_rec.item(), 100)), kld_loss.item(), aux, elbo.item()

def train_one_batch(self, batch, iter, train=True):
    enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, _, _ = \
        get_input_from_batch(batch)
    dec_batch, _, _, _, _ = get_output_from_batch(batch)

    if config.noam:
        self.optimizer.optimizer.zero_grad()
    else:
        self.optimizer.zero_grad()

    ## Response encode
    mask_res = batch["posterior_batch"].data.eq(config.PAD_idx).unsqueeze(1)
    post_emb = self.embedding(batch["posterior_batch"])
    r_encoder_outputs = self.r_encoder(post_emb, mask_res)

    ## Encode
    num_sentences, enc_seq_len = enc_batch.size()
    batch_size = enc_lens.size(0)
    max_len = enc_lens.data.max().item()
    input_lengths = torch.sum(~enc_batch.data.eq(config.PAD_idx), dim=1)

    # word level encoder
    enc_emb = self.embedding(enc_batch)
    word_encoder_outputs, word_encoder_hidden = self.word_encoder(enc_emb, input_lengths)
    word_encoder_hidden = word_encoder_hidden.transpose(1, 0).reshape(num_sentences, -1)

    # pad and pack word_encoder_hidden
    start = torch.cumsum(torch.cat((enc_lens.data.new(1).zero_(), enc_lens[:-1])), 0)
    word_encoder_hidden = torch.stack([
        pad(word_encoder_hidden.narrow(0, s, l), max_len)
        for s, l in zip(start.data.tolist(), enc_lens.data.tolist())
    ], 0)
    # mask_src = ~(enc_padding_mask.bool()).unsqueeze(1)
    mask_src = (1 - enc_padding_mask.byte()).unsqueeze(1)

    # context level encoder
    if word_encoder_hidden.size(-1) != config.hidden_dim:
        word_encoder_hidden = self.linear(word_encoder_hidden)
    encoder_outputs = self.encoder(word_encoder_hidden, mask_src)

    # latent variable
    if config.model == "cvaetrs":
        kld_loss, z = self.latent_layer(encoder_outputs[:, 0], r_encoder_outputs[:, 0],
                                        train=True)

    # Decode
    sos_token = torch.LongTensor([config.SOS_idx] * batch_size).unsqueeze(1)
    if config.USE_CUDA:
        sos_token = sos_token.cuda()
    dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1)
    mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
    input_vector = self.embedding(dec_batch_shift)
    if config.model == "cvaetrs":
        input_vector[:, 0] = input_vector[:, 0] + z
    pre_logit, attn_dist = self.decoder(input_vector, encoder_outputs, (mask_src, mask_trg))

    ## compute output dist
    logit = self.generator(pre_logit, attn_dist,
                           enc_batch_extend_vocab if config.pointer_gen else None,
                           extra_zeros, attn_dist_db=None)

    ## loss: NLL if pointer-generator else cross entropy
    loss_rec = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                              dec_batch.contiguous().view(-1))
    if config.model == "cvaetrs":
        z_logit = self.bow(z)  # [batch_size, vocab_size]
        z_logit = z_logit.unsqueeze(1).repeat(1, logit.size(1), 1)
        loss_aux = self.criterion(z_logit.contiguous().view(-1, z_logit.size(-1)),
                                  dec_batch.contiguous().view(-1))
        kl_weight = min(math.tanh(6 * iter / config.full_kl_step - 3) + 1, 1)
        loss = loss_rec + config.kl_ceiling * kl_weight * kld_loss + config.aux_ceiling * loss_aux
        aux = loss_aux.item()
        elbo = loss_rec + kld_loss
    else:
        loss = loss_rec
        elbo = loss_rec
        kld_loss = torch.Tensor([0])
        aux = 0

    if config.multitask:
        emo_logit = self.emo(encoder_outputs[:, 0])
        emo_loss = self.emo_criterion(emo_logit, batch["program_label"] - 9)
        loss = loss_rec + emo_loss

    if train:
        loss.backward()
        # clip gradient
        nn.utils.clip_grad_norm_(self.parameters(), config.max_grad_norm)
        self.optimizer.step()

    return loss_rec.item(), math.exp(min(loss_rec.item(), 100)), kld_loss.item(), aux, elbo.item()

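# A minimal sketch of the `pad` helper used above to bring each example's word-level hidden
# states up to `max_len` rows. This is an assumption about its behavior (zero-padding along
# dim 0); the repo's helper may differ, and `pad_to_length_sketch` is a hypothetical name.
def pad_to_length_sketch(tensor, length):
    """Zero-pad a (len, hidden) tensor along dim 0 to (length, hidden)."""
    if tensor.size(0) >= length:
        return tensor[:length]
    filler = tensor.new_zeros(length - tensor.size(0), *tensor.size()[1:])
    return torch.cat([tensor, filler], dim=0)
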
def train_one_batch(self, batch, iter, train=True):
    enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)
    dec_batch, _, _, _, _ = get_output_from_batch(batch)

    if config.noam:
        self.optimizer.optimizer.zero_grad()
    else:
        self.optimizer.zero_grad()

    ## Encode
    mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
    if config.dataset == "empathetic":
        emb_mask = self.embedding(batch["mask_input"])
        encoder_outputs = self.encoder(self.embedding(enc_batch) + emb_mask, mask_src)
    else:
        encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)

    ## Attention over decoder
    q_h = torch.mean(encoder_outputs, dim=1) if config.mean_query else encoder_outputs[:, 0]
    logit_prob = self.decoder_key(q_h)  # (bsz, num_experts)

    if config.topk > 0:
        k_max_value, k_max_index = torch.topk(logit_prob, config.topk)
        a = np.empty([logit_prob.shape[0], self.decoder_number])
        a.fill(float('-inf'))
        mask = torch.Tensor(a).cuda()
        logit_prob_ = mask.scatter_(1, k_max_index.cuda().long(), k_max_value)
        attention_parameters = self.attention_activation(logit_prob_)
    else:
        attention_parameters = self.attention_activation(logit_prob)

    if config.oracle:
        attention_parameters = self.attention_activation(
            torch.FloatTensor(batch['target_program']) * 1000).cuda()
    attention_parameters = attention_parameters.unsqueeze(-1).unsqueeze(-1)  # (batch_size, expert_num, 1, 1)

    # Decode
    sos_token = torch.LongTensor([config.SOS_idx] * enc_batch.size(0)).unsqueeze(1)
    if config.USE_CUDA:
        sos_token = sos_token.cuda()
    dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1)
    mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
    pre_logit, attn_dist = self.decoder(self.embedding(dec_batch_shift), encoder_outputs,
                                        (mask_src, mask_trg), attention_parameters)

    ## compute output dist
    logit = self.generator(pre_logit, attn_dist,
                           enc_batch_extend_vocab if config.pointer_gen else None,
                           extra_zeros, attn_dist_db=None)

    ## loss: NLL if pointer-generator else cross entropy
    if train and config.schedule > 10:
        if random.uniform(0, 1) <= (0.0001 + (1 - 0.0001) * math.exp(-1. * iter / config.schedule)):
            config.oracle = True
        else:
            config.oracle = False

    if config.softmax:
        loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                              dec_batch.contiguous().view(-1)) + nn.CrossEntropyLoss()(
                                  logit_prob, torch.LongTensor(batch['program_label']).cuda())
        loss_bce_program = nn.CrossEntropyLoss()(
            logit_prob, torch.LongTensor(batch['program_label']).cuda()).item()
    else:
        loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                              dec_batch.contiguous().view(-1)) + nn.BCEWithLogitsLoss()(
                                  logit_prob, torch.FloatTensor(batch['target_program']).cuda())
        loss_bce_program = nn.BCEWithLogitsLoss()(
            logit_prob, torch.FloatTensor(batch['target_program']).cuda()).item()

    pred_program = np.argmax(logit_prob.detach().cpu().numpy(), axis=1)
    program_acc = accuracy_score(batch["program_label"], pred_program)

    if config.label_smoothing:
        loss_ppl = self.criterion_ppl(logit.contiguous().view(-1, logit.size(-1)),
                                      dec_batch.contiguous().view(-1)).item()

    if train:
        loss.backward()
        self.optimizer.step()

    if config.label_smoothing:
        return loss_ppl, math.exp(min(loss_ppl, 100)), loss_bce_program, program_acc
    else:
        return loss.item(), math.exp(min(loss.item(), 100)), loss_bce_program, program_acc

def decoder_topk(self, batch, max_dec_step=30):
    enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)

    mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
    emb_mask = self.embedding(batch["mask_input"])
    encoder_outputs = self.encoder(self.embedding(enc_batch) + emb_mask, mask_src)

    ## Attention over decoder
    q_h = torch.mean(encoder_outputs, dim=1) if config.mean_query else encoder_outputs[:, 0]
    logit_prob = self.decoder_key(q_h)

    if config.topk > 0:
        k_max_value, k_max_index = torch.topk(logit_prob, config.topk)
        a = np.empty([logit_prob.shape[0], self.decoder_number])
        a.fill(float('-inf'))
        mask = torch.Tensor(a).cuda()
        logit_prob = mask.scatter_(1, k_max_index.cuda().long(), k_max_value)

    attention_parameters = self.attention_activation(logit_prob)

    if config.oracle:
        attention_parameters = self.attention_activation(
            torch.FloatTensor(batch['target_program']) * 1000).cuda()
    attention_parameters = attention_parameters.unsqueeze(-1).unsqueeze(-1)  # (batch_size, expert_num, 1, 1)

    ys = torch.ones(1, 1).fill_(config.SOS_idx).long()
    if config.USE_CUDA:
        ys = ys.cuda()
    mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
    decoded_words = []
    for i in range(max_dec_step + 1):
        if config.project:
            out, attn_dist = self.decoder(
                self.embedding_proj_in(self.embedding(ys)),
                self.embedding_proj_in(encoder_outputs),
                (mask_src, mask_trg), attention_parameters)
        else:
            out, attn_dist = self.decoder(self.embedding(ys), encoder_outputs,
                                          (mask_src, mask_trg), attention_parameters)
        logit = self.generator(out, attn_dist, enc_batch_extend_vocab, extra_zeros,
                               attn_dist_db=None)
        filtered_logit = top_k_top_p_filtering(logit[:, -1], top_k=3, top_p=0,
                                               filter_value=-float('Inf'))
        # Sample from the filtered distribution
        next_word = torch.multinomial(F.softmax(filtered_logit, dim=-1), 1).squeeze()
        decoded_words.append([
            '<EOS>' if ni.item() == config.EOS_idx else self.vocab.index2word[ni.item()]
            for ni in next_word.view(-1)
        ])
        next_word = next_word.data[0]

        if config.USE_CUDA:
            ys = torch.cat([ys, torch.ones(1, 1).long().fill_(next_word).cuda()], dim=1)
            ys = ys.cuda()
        else:
            ys = torch.cat([ys, torch.ones(1, 1).long().fill_(next_word)], dim=1)
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

    sent = []
    for _, row in enumerate(np.transpose(decoded_words)):
        st = ''
        for e in row:
            if e == '<EOS>':
                break
            else:
                st += e + ' '
        sent.append(st)
    return sent

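# Worked example of the top-k expert gating above (illustrative numbers only): with expert
# logits [2.0, 0.5, 1.0, -1.0] and config.topk = 2, everything outside the top 2 is set to
# -inf, giving [2.0, -inf, 1.0, -inf]; if attention_activation is a softmax, the resulting
# weights are roughly [0.73, 0.0, 0.27, 0.0], so only the two selected experts receive mass.
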
def decoder_greedy(self, batch, max_dec_step=50):
    enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, _, _ = \
        get_input_from_batch(batch)

    ## Encode
    num_sentences, enc_seq_len = enc_batch.size()
    batch_size = enc_lens.size(0)
    max_len = enc_lens.data.max().item()
    input_lengths = torch.sum(~enc_batch.data.eq(config.PAD_idx), dim=1)

    # word level encoder
    enc_emb = self.embedding(enc_batch)
    word_encoder_outputs, word_encoder_hidden = self.word_encoder(enc_emb, input_lengths)
    word_encoder_hidden = word_encoder_hidden.transpose(1, 0).reshape(num_sentences, -1)

    # pad and pack word_encoder_hidden
    start = torch.cumsum(torch.cat((enc_lens.data.new(1).zero_(), enc_lens[:-1])), 0)
    word_encoder_hidden = torch.stack([
        pad(word_encoder_hidden.narrow(0, s, l), max_len)
        for s, l in zip(start.data.tolist(), enc_lens.data.tolist())
    ], 0)
    mask_src = ~(enc_padding_mask.bool()).unsqueeze(1)

    # context level encoder
    if word_encoder_hidden.size(-1) != config.hidden_dim:
        word_encoder_hidden = self.linear(word_encoder_hidden)
    encoder_outputs = self.encoder(word_encoder_hidden, mask_src)

    ys = torch.ones(batch_size, 1).fill_(config.SOS_idx).long()
    if config.USE_CUDA:
        ys = ys.cuda()
    mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
    decoded_words = []
    for i in range(max_dec_step + 1):
        out, attn_dist, _, _, _ = self.decoder(self.embedding(ys), encoder_outputs, None,
                                               (mask_src, None, mask_trg))
        prob = self.generator(out, attn_dist, enc_batch_extend_vocab, extra_zeros,
                              attn_dist_db=None)
        _, next_word = torch.max(prob[:, -1], dim=1)
        decoded_words.append([
            '<EOS>' if ni.item() == config.EOS_idx else self.vocab.index2word[ni.item()]
            for ni in next_word.view(-1)
        ])
        if config.USE_CUDA:
            ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)
            ys = ys.cuda()
        else:
            ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

    sent = []
    for _, row in enumerate(np.transpose(decoded_words)):
        st = ''
        for e in row:
            if e == '<EOS>':
                break
            else:
                st += e + ' '
        sent.append(st)
    return sent

def train_one_batch(self, batch, iter, train=True):
    enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)
    dec_batch, _, _, _, _ = get_output_from_batch(batch)

    if config.noam:
        self.optimizer.optimizer.zero_grad()
    else:
        self.optimizer.zero_grad()

    ## Encode
    mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
    if config.dataset == "empathetic":
        emb_mask = self.embedding(batch["mask_input"])
        encoder_outputs = self.encoder(self.embedding(enc_batch) + emb_mask, mask_src)
    else:
        encoder_outputs = self.encoder(self.embedding(enc_batch), mask_src)

    q_h = torch.mean(encoder_outputs, dim=1) if config.mean_query else encoder_outputs[:, 0]

    emotions_mimic, emotions_non_mimic, mu_positive_prior, logvar_positive_prior, mu_negative_prior, logvar_negative_prior = \
        self.vae_sampler(q_h, batch['program_label'], self.emoji_embedding)
    m_out = self.emotion_input_encoder_1(emotions_mimic.unsqueeze(1), encoder_outputs, mask_src)
    m_tilde_out = self.emotion_input_encoder_2(emotions_non_mimic.unsqueeze(1), encoder_outputs,
                                               mask_src)
    if train:
        emotions_mimic, emotions_non_mimic, mu_positive_posterior, logvar_positive_posterior, mu_negative_posterior, logvar_negative_posterior = \
            self.vae_sampler.forward_train(q_h, batch['program_label'], self.emoji_embedding,
                                           M_out=m_out.mean(dim=1), M_tilde_out=m_tilde_out.mean(dim=1))
        KLLoss_positive = self.vae_sampler.kl_div(mu_positive_posterior, logvar_positive_posterior,
                                                  mu_positive_prior, logvar_positive_prior)
        KLLoss_negative = self.vae_sampler.kl_div(mu_negative_posterior, logvar_negative_posterior,
                                                  mu_negative_prior, logvar_negative_prior)
        KLLoss = KLLoss_positive + KLLoss_negative
    else:
        KLLoss_positive = self.vae_sampler.kl_div(mu_positive_prior, logvar_positive_prior)
        KLLoss_negative = self.vae_sampler.kl_div(mu_negative_prior, logvar_negative_prior)
        KLLoss = KLLoss_positive + KLLoss_negative

    if config.emo_combine == "att":
        v = self.cdecoder(encoder_outputs, m_out, m_tilde_out, mask_src)
    elif config.emo_combine == "gate":
        v = self.cdecoder(m_out, m_tilde_out)

    x = self.s_weight(q_h)
    # method2: E (W@c)
    logit_prob = torch.matmul(x, self.emoji_embedding.weight.transpose(0, 1))  # shape (b_size, 32)

    # Decode
    sos_token = torch.LongTensor([config.SOS_idx] * enc_batch.size(0)).unsqueeze(1)
    if config.USE_CUDA:
        sos_token = sos_token.cuda()
    dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1)
    mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
    pre_logit, attn_dist = self.decoder(self.embedding(dec_batch_shift), v, v,
                                        (mask_src, mask_trg))

    ## compute output dist
    logit = self.generator(pre_logit, attn_dist,
                           enc_batch_extend_vocab if config.pointer_gen else None,
                           extra_zeros, attn_dist_db=None)

    if train and config.schedule > 10:
        if random.uniform(0, 1) <= (0.0001 + (1 - 0.0001) * math.exp(-1. * iter / config.schedule)):
            config.oracle = True
        else:
            config.oracle = False

    if config.softmax:
        program_label = torch.LongTensor(batch['program_label'])
        if config.USE_CUDA:
            program_label = program_label.cuda()
        L1_loss = nn.CrossEntropyLoss()(logit_prob, program_label)
        loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                              dec_batch.contiguous().view(-1)) + KLLoss + L1_loss
        loss_bce_program = nn.CrossEntropyLoss()(logit_prob, program_label).item()
    else:
        loss = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                              dec_batch.contiguous().view(-1)) + nn.BCEWithLogitsLoss()(
                                  logit_prob, torch.FloatTensor(batch['target_program']).cuda())
        loss_bce_program = nn.BCEWithLogitsLoss()(
            logit_prob, torch.FloatTensor(batch['target_program']).cuda()).item()

    pred_program = np.argmax(logit_prob.detach().cpu().numpy(), axis=1)
    program_acc = accuracy_score(batch["program_label"], pred_program)

    if config.label_smoothing:
        loss_ppl = self.criterion_ppl(logit.contiguous().view(-1, logit.size(-1)),
                                      dec_batch.contiguous().view(-1)).item()

    if train:
        loss.backward()
        self.optimizer.step()

    if config.label_smoothing:
        return loss_ppl, math.exp(min(loss_ppl, 100)), loss_bce_program, program_acc
    else:
        return loss.item(), math.exp(min(loss.item(), 100)), loss_bce_program, program_acc

def train_one_batch(self, batch, iter, train=True):
    enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, _, _ = \
        get_input_from_batch(batch)
    dec_batch, _, _, _, _ = get_output_from_batch(batch)

    if config.noam:
        self.optimizer.optimizer.zero_grad()
    else:
        self.optimizer.zero_grad()

    ## Response encode
    mask_res = batch["posterior_batch"].data.eq(config.PAD_idx).unsqueeze(1)
    post_emb = self.embedding(batch["posterior_batch"])
    r_encoder_outputs = self.r_encoder(post_emb, mask_res)

    ## Encode
    num_sentences, enc_seq_len = enc_batch.size()
    batch_size = enc_lens.size(0)
    max_len = enc_lens.data.max().item()
    input_lengths = torch.sum(~enc_batch.data.eq(config.PAD_idx), dim=1)

    # word level encoder
    enc_emb = self.embedding(enc_batch)
    word_encoder_outputs, word_encoder_hidden = self.word_encoder(enc_emb, input_lengths)
    word_encoder_hidden = word_encoder_hidden.transpose(1, 0).reshape(num_sentences, -1)

    # pad and pack word_encoder_hidden
    start = torch.cumsum(torch.cat((enc_lens.data.new(1).zero_(), enc_lens[:-1])), 0)
    word_encoder_hidden = torch.stack([
        pad(word_encoder_hidden.narrow(0, s, l), max_len)
        for s, l in zip(start.data.tolist(), enc_lens.data.tolist())
    ], 0)
    # mask_src = ~(enc_padding_mask.bool()).unsqueeze(1)
    mask_src = (1 - enc_padding_mask.byte()).unsqueeze(1)

    # context level encoder
    if word_encoder_hidden.size(-1) != config.hidden_dim:
        word_encoder_hidden = self.linear(word_encoder_hidden)
    encoder_outputs = self.encoder(word_encoder_hidden, mask_src)

    # Decode
    sos_token = torch.LongTensor([config.SOS_idx] * batch_size).unsqueeze(1)
    if config.USE_CUDA:
        sos_token = sos_token.cuda()
    dec_batch_shift = torch.cat((sos_token, dec_batch[:, :-1]), 1)
    mask_trg = dec_batch_shift.data.eq(config.PAD_idx).unsqueeze(1)
    dec_emb = self.embedding(dec_batch_shift)
    pre_logit, attn_dist, mean, log_var, probs = self.decoder(dec_emb, encoder_outputs,
                                                              r_encoder_outputs,
                                                              (mask_src, mask_res, mask_trg))

    ## compute output dist
    logit = self.generator(pre_logit, attn_dist,
                           enc_batch_extend_vocab if config.pointer_gen else None,
                           extra_zeros, attn_dist_db=None)

    ## loss: NLL if pointer-generator else cross entropy
    sbow = dec_batch  # [batch, seq_len]
    seq_len = sbow.size(1)
    loss_rec = self.criterion(logit.contiguous().view(-1, logit.size(-1)),
                              dec_batch.contiguous().view(-1))
    if config.model == "cvaetrs":
        loss_aux = 0
        for prob in probs:
            sbow_mask = _get_attn_subsequent_mask(seq_len).transpose(1, 2)
            sbow.unsqueeze(2).repeat(1, 1, seq_len).masked_fill_(
                sbow_mask, config.PAD_idx)  # [batch, seq_len, seq_len]
            loss_aux += self.criterion(prob.contiguous().view(-1, prob.size(-1)),
                                       sbow.contiguous().view(-1))
        kld_loss = gaussian_kld(mean["posterior"], log_var["posterior"],
                                mean["prior"], log_var["prior"])
        kld_loss = torch.mean(kld_loss)
        kl_weight = min(math.tanh(6 * iter / config.full_kl_step - 3) + 1, 1)
        loss = loss_rec + config.kl_ceiling * kl_weight * kld_loss + config.aux_ceiling * loss_aux
        elbo = loss_rec + kld_loss
    else:
        loss = loss_rec
        elbo = loss_rec
        kld_loss = torch.Tensor([0])
        loss_aux = torch.Tensor([0])

    if train:
        loss.backward()
        # clip gradient
        nn.utils.clip_grad_norm_(self.parameters(), config.max_grad_norm)
        self.optimizer.step()

    return loss_rec.item(), math.exp(min(loss_rec.item(), 100)), kld_loss.item(), loss_aux.item(), elbo.item()

def decoder_greedy(self, batch):
    input_ids_batch, input_mask_batch, example_index_batch, enc_batch_extend_vocab, extra_zeros, _ = \
        get_input_from_batch(batch)

    with torch.no_grad():
        encoder_outputs, _ = self.encoder(input_ids_batch,
                                          token_type_ids=enc_batch_extend_vocab,
                                          attention_mask=input_mask_batch,
                                          output_all_encoded_layers=False)

    ys = torch.ones(1, 1).fill_(config.SOS_idx).long()
    if config.USE_CUDA:
        ys = ys.cuda()
    mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
    decoded_words = []
    for i in range(config.max_dec_step):
        out, attn_dist = self.decoder(self.embedding(ys), encoder_outputs, (None, mask_trg))
        prob = self.generator(out, attn_dist, enc_batch_extend_vocab, extra_zeros)
        _, next_word = torch.max(prob[:, -1], dim=1)
        decoded_words.append(self.tokenizer.convert_ids_to_tokens(next_word.tolist()))
        next_word = next_word.data[0]

        if config.USE_CUDA:
            ys = torch.cat([ys, torch.ones(1, 1).long().fill_(next_word).cuda()], dim=1)
            ys = ys.cuda()
        else:
            ys = torch.cat([ys, torch.ones(1, 1).long().fill_(next_word)], dim=1)
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

    sent = []
    for _, row in enumerate(np.transpose(decoded_words)):
        st = ''
        for e in row:
            if e == '<EOS>' or e == '<PAD>':
                break
            else:
                st += e + ' '
        sent.append(st)
    return sent

def decoder_topk(self, batch, max_dec_step=30, emotion_classifier='built_in'):
    enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)
    emotions = batch['program_label']
    context_emo = [
        self.positive_emotions[0] if d['compound'] > 0 else self.negative_emotions[0]
        for d in batch['context_emotion_scores']
    ]
    context_emo = torch.Tensor(context_emo)
    if config.USE_CUDA:
        context_emo = context_emo.cuda()

    ## Encode
    mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
    emb_mask = self.embedding(batch["mask_input"])
    encoder_outputs = self.encoder(self.embedding(enc_batch) + emb_mask, mask_src)

    q_h = torch.mean(encoder_outputs, dim=1) if config.mean_query else encoder_outputs[:, 0]
    x = self.s_weight(q_h)
    # method 2
    logit_prob = torch.matmul(x, self.emoji_embedding.weight.transpose(0, 1))

    if emotion_classifier == "vader":
        context_emo = [
            self.positive_emotions[0] if d['compound'] > 0 else self.negative_emotions[0]
            for d in batch['context_emotion_scores']
        ]
        context_emo = torch.Tensor(context_emo)
        if config.USE_CUDA:
            context_emo = context_emo.cuda()
        emotions_mimic, emotions_non_mimic, mu_p, logvar_p, mu_n, logvar_n = self.vae_sampler(
            q_h, context_emo, self.emoji_embedding)
    elif emotion_classifier is None:
        emotions_mimic, emotions_non_mimic, mu_p, logvar_p, mu_n, logvar_n = self.vae_sampler(
            q_h, batch['program_label'], self.emoji_embedding)
    elif emotion_classifier == "built_in":
        emo_pred = torch.argmax(logit_prob, dim=-1)
        emotions_mimic, emotions_non_mimic, mu_p, logvar_p, mu_n, logvar_n = self.vae_sampler(
            q_h, emo_pred, self.emoji_embedding)

    m_out = self.emotion_input_encoder_1(emotions_mimic.unsqueeze(1), encoder_outputs, mask_src)
    m_tilde_out = self.emotion_input_encoder_2(emotions_non_mimic.unsqueeze(1), encoder_outputs,
                                               mask_src)
    if config.emo_combine == "att":
        v = self.cdecoder(encoder_outputs, m_out, m_tilde_out, mask_src)
    elif config.emo_combine == "gate":
        v = self.cdecoder(m_out, m_tilde_out)
    elif config.emo_combine == 'vader':
        m_weight = context_emo_scores.unsqueeze(-1).unsqueeze(-1)
        m_tilde_weight = 1 - m_weight
        v = m_weight * m_out + m_tilde_weight * m_tilde_out

    ys = torch.ones(1, 1).fill_(config.SOS_idx).long()
    if config.USE_CUDA:
        ys = ys.cuda()
    mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
    decoded_words = []
    for i in range(max_dec_step + 1):
        if config.project:
            out, attn_dist = self.decoder(
                self.embedding_proj_in(self.embedding(ys)),
                self.embedding_proj_in(encoder_outputs),
                (mask_src, mask_trg), attention_parameters)
        else:
            out, attn_dist = self.decoder(self.embedding(ys), v, v, (mask_src, mask_trg))
        logit = self.generator(out, attn_dist, enc_batch_extend_vocab, extra_zeros,
                               attn_dist_db=None)
        filtered_logit = top_k_top_p_filtering(logit[:, -1], top_k=3, top_p=0,
                                               filter_value=-float('Inf'))
        # Sample from the filtered distribution
        next_word = torch.multinomial(F.softmax(filtered_logit, dim=-1), 1).squeeze()
        decoded_words.append([
            '<EOS>' if ni.item() == config.EOS_idx else self.vocab.index2word[ni.item()]
            for ni in next_word.view(-1)
        ])
        next_word = next_word.data.item()

        if config.USE_CUDA:
            ys = torch.cat([ys, torch.ones(1, 1).long().fill_(next_word).cuda()], dim=1)
            ys = ys.cuda()
        else:
            ys = torch.cat([ys, torch.ones(1, 1).long().fill_(next_word)], dim=1)
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

    sent = []
    for _, row in enumerate(np.transpose(decoded_words)):
        st = ''
        for e in row:
            if e == '<EOS>':
                break
            else:
                st += e + ' '
        sent.append(st)
    return sent