def decoder_topk(self, batch, max_dec_step=30):
    enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)

    ## Encode the dialogue context; emb_mask adds the speaker/role embeddings.
    mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
    emb_mask = self.embedding(batch["mask_input"])
    encoder_outputs = self.encoder(self.embedding(enc_batch) + emb_mask, mask_src)

    ## Decode autoregressively, starting from SOS.
    ys = torch.ones(1, 1).fill_(config.SOS_idx).long()
    if config.USE_CUDA:
        ys = ys.cuda()
    mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
    decoded_words = []
    for i in range(max_dec_step + 1):
        if config.project:
            out, attn_dist = self.decoder(
                self.embedding_proj_in(self.embedding(ys)),
                self.embedding_proj_in(encoder_outputs),
                (mask_src, mask_trg))
        else:
            out, attn_dist = self.decoder(self.embedding(ys),
                                          encoder_outputs,
                                          (mask_src, mask_trg))

        logit = self.generator(out, attn_dist, enc_batch_extend_vocab,
                               extra_zeros, attn_dist_db=None)
        filtered_logit = top_k_top_p_filtering(logit[:, -1],
                                               top_k=3,
                                               top_p=0,
                                               filter_value=-float('Inf'))
        # Sample the next token from the filtered distribution.
        next_word = torch.multinomial(F.softmax(filtered_logit, dim=-1),
                                      1).squeeze()
        decoded_words.append([
            '<EOS>' if ni.item() == config.EOS_idx else
            self.vocab.index2word[ni.item()] for ni in next_word.view(-1)
        ])
        next_word = next_word.data.item()

        # Append the sampled token and refresh the target padding mask.
        if config.USE_CUDA:
            ys = torch.cat(
                [ys, torch.ones(1, 1).long().fill_(next_word).cuda()], dim=1)
            ys = ys.cuda()
        else:
            ys = torch.cat(
                [ys, torch.ones(1, 1).long().fill_(next_word)], dim=1)
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

    ## Detokenize: join words up to the first <EOS> of each sequence.
    sent = []
    for _, row in enumerate(np.transpose(decoded_words)):
        st = ''
        for e in row:
            if e == '<EOS>':
                break
            else:
                st += e + ' '
        sent.append(st)
    return sent
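
# The sampling loop above depends on a `top_k_top_p_filtering` helper defined
# elsewhere in the repository. The sketch below is a minimal stand-in following
# the widely used top-k / nucleus filtering recipe; it assumes 2D logits of
# shape (batch, vocab_size) and edits them in place, and the project's own
# helper may differ in details.
import torch
import torch.nn.functional as F


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    if top_k > 0:
        top_k = min(top_k, logits.size(-1))
        # Remove every token whose logit is below the k-th largest logit.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        # Nucleus filtering: drop tokens outside the smallest set whose
        # cumulative probability exceeds top_p.
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right so the first token above the threshold is still kept.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits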
def decoder_topk(self, batch, max_dec_step=30):
    enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)

    ## Encode the dialogue context; emb_mask adds the speaker/role embeddings.
    mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
    emb_mask = self.embedding(batch["mask_input"])
    encoder_outputs = self.encoder(self.embedding(enc_batch) + emb_mask, mask_src)

    ## Attention over the expert decoders.
    q_h = torch.mean(encoder_outputs, dim=1) if config.mean_query else encoder_outputs[:, 0]
    # q_h = encoder_outputs[:, 0]
    logit_prob = self.decoder_key(q_h)

    if config.topk > 0:
        # Keep only the top-k expert logits; the rest are masked to -inf so
        # they receive zero weight after the attention activation.
        k_max_value, k_max_index = torch.topk(logit_prob, config.topk)
        a = np.empty([logit_prob.shape[0], self.decoder_number])
        a.fill(float('-inf'))
        mask = torch.Tensor(a).cuda()
        logit_prob = mask.scatter_(1, k_max_index.cuda().long(), k_max_value)

    attention_parameters = self.attention_activation(logit_prob)

    if config.oracle:
        # Oracle mode: use the gold emotion program instead of the prediction.
        attention_parameters = self.attention_activation(
            torch.FloatTensor(batch['target_program']) * 1000).cuda()
    attention_parameters = attention_parameters.unsqueeze(-1).unsqueeze(-1)  # (batch_size, expert_num, 1, 1)

    ## Decode autoregressively, starting from SOS.
    ys = torch.ones(1, 1).fill_(config.SOS_idx).long()
    if config.USE_CUDA:
        ys = ys.cuda()
    mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
    decoded_words = []
    for i in range(max_dec_step + 1):
        if config.project:
            out, attn_dist = self.decoder(
                self.embedding_proj_in(self.embedding(ys)),
                self.embedding_proj_in(encoder_outputs),
                (mask_src, mask_trg),
                attention_parameters)
        else:
            out, attn_dist = self.decoder(self.embedding(ys),
                                          encoder_outputs,
                                          (mask_src, mask_trg),
                                          attention_parameters)

        logit = self.generator(out, attn_dist, enc_batch_extend_vocab,
                               extra_zeros, attn_dist_db=None)
        filtered_logit = top_k_top_p_filtering(logit[:, -1],
                                               top_k=3,
                                               top_p=0,
                                               filter_value=-float('Inf'))
        # Sample the next token from the filtered distribution.
        next_word = torch.multinomial(F.softmax(filtered_logit, dim=-1),
                                      1).squeeze()
        decoded_words.append([
            '<EOS>' if ni.item() == config.EOS_idx else
            self.vocab.index2word[ni.item()] for ni in next_word.view(-1)
        ])
        next_word = next_word.data.item()

        # Append the sampled token and refresh the target padding mask.
        if config.USE_CUDA:
            ys = torch.cat(
                [ys, torch.ones(1, 1).long().fill_(next_word).cuda()], dim=1)
            ys = ys.cuda()
        else:
            ys = torch.cat(
                [ys, torch.ones(1, 1).long().fill_(next_word)], dim=1)
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

    ## Detokenize: join words up to the first <EOS> of each sequence.
    sent = []
    for _, row in enumerate(np.transpose(decoded_words)):
        st = ''
        for e in row:
            if e == '<EOS>':
                break
            else:
                st += e + ' '
        sent.append(st)
    return sent
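
# Self-contained illustration (hypothetical sizes, not from the repository) of
# the expert-selection step used above: keep the top-k expert logits, mask the
# rest with -inf, and apply a softmax so that only k experts receive non-zero
# attention weight.
import torch
import torch.nn.functional as F

logit_prob = torch.randn(2, 5)              # (batch_size, num_experts), made up
k_val, k_idx = torch.topk(logit_prob, 2)    # top-2 experts per example
mask = torch.full_like(logit_prob, float('-inf'))
sparse_logits = mask.scatter_(1, k_idx, k_val)
weights = F.softmax(sparse_logits, dim=1)   # exactly two non-zero weights per row
print(weights)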
def decoder_topk(self, batch, max_dec_step=30, emotion_classifier='built_in'):
    enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(batch)
    emotions = batch['program_label']

    ## Coarse context polarity from the VADER compound score: positive contexts
    ## map to a positive emotion id, negative contexts to a negative one.
    context_emo = [
        self.positive_emotions[0] if d['compound'] > 0 else
        self.negative_emotions[0] for d in batch['context_emotion_scores']
    ]
    context_emo = torch.Tensor(context_emo)
    if config.USE_CUDA:
        context_emo = context_emo.cuda()

    ## Encode the dialogue context; emb_mask adds the speaker/role embeddings.
    mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
    emb_mask = self.embedding(batch["mask_input"])
    encoder_outputs = self.encoder(self.embedding(enc_batch) + emb_mask, mask_src)

    ## Predict the response emotion from the context representation.
    q_h = torch.mean(encoder_outputs, dim=1) if config.mean_query else encoder_outputs[:, 0]
    x = self.s_weight(q_h)
    # method 2
    logit_prob = torch.matmul(x, self.emoji_embedding.weight.transpose(0, 1))

    ## Sample the mimicking / non-mimicking emotion representations, choosing
    ## the conditioning emotion according to `emotion_classifier`.
    if emotion_classifier == "vader":
        emotions_mimic, emotions_non_mimic, mu_p, logvar_p, mu_n, logvar_n = self.vae_sampler(
            q_h, context_emo, self.emoji_embedding)
    elif emotion_classifier is None:
        emotions_mimic, emotions_non_mimic, mu_p, logvar_p, mu_n, logvar_n = self.vae_sampler(
            q_h, batch['program_label'], self.emoji_embedding)
    elif emotion_classifier == "built_in":
        emo_pred = torch.argmax(logit_prob, dim=-1)
        emotions_mimic, emotions_non_mimic, mu_p, logvar_p, mu_n, logvar_n = self.vae_sampler(
            q_h, emo_pred, self.emoji_embedding)

    m_out = self.emotion_input_encoder_1(emotions_mimic.unsqueeze(1),
                                         encoder_outputs, mask_src)
    m_tilde_out = self.emotion_input_encoder_2(emotions_non_mimic.unsqueeze(1),
                                               encoder_outputs, mask_src)

    ## Combine the mimicking and non-mimicking streams.
    if config.emo_combine == "att":
        v = self.cdecoder(encoder_outputs, m_out, m_tilde_out, mask_src)
    elif config.emo_combine == "gate":
        v = self.cdecoder(m_out, m_tilde_out)
    elif config.emo_combine == 'vader':
        m_weight = context_emo.unsqueeze(-1).unsqueeze(-1)
        m_tilde_weight = 1 - m_weight
        v = m_weight * m_out + m_tilde_weight * m_tilde_out

    ## Decode autoregressively, starting from SOS.
    ys = torch.ones(1, 1).fill_(config.SOS_idx).long()
    if config.USE_CUDA:
        ys = ys.cuda()
    mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
    decoded_words = []
    for i in range(max_dec_step + 1):
        if config.project:
            # Mirror the non-projected call below when input/output projection is enabled.
            out, attn_dist = self.decoder(
                self.embedding_proj_in(self.embedding(ys)),
                self.embedding_proj_in(v),
                self.embedding_proj_in(v),
                (mask_src, mask_trg))
        else:
            out, attn_dist = self.decoder(self.embedding(ys), v, v,
                                          (mask_src, mask_trg))

        logit = self.generator(out, attn_dist, enc_batch_extend_vocab,
                               extra_zeros, attn_dist_db=None)
        filtered_logit = top_k_top_p_filtering(logit[:, -1],
                                               top_k=3,
                                               top_p=0,
                                               filter_value=-float('Inf'))
        # Sample the next token from the filtered distribution.
        next_word = torch.multinomial(F.softmax(filtered_logit, dim=-1),
                                      1).squeeze()
        decoded_words.append([
            '<EOS>' if ni.item() == config.EOS_idx else
            self.vocab.index2word[ni.item()] for ni in next_word.view(-1)
        ])
        next_word = next_word.data.item()

        # Append the sampled token and refresh the target padding mask.
        if config.USE_CUDA:
            ys = torch.cat(
                [ys, torch.ones(1, 1).long().fill_(next_word).cuda()], dim=1)
            ys = ys.cuda()
        else:
            ys = torch.cat(
                [ys, torch.ones(1, 1).long().fill_(next_word)], dim=1)
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

    ## Detokenize: join words up to the first <EOS> of each sequence.
    sent = []
    for _, row in enumerate(np.transpose(decoded_words)):
        st = ''
        for e in row:
            if e == '<EOS>':
                break
            else:
                st += e + ' '
        sent.append(st)
    return sent
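
# The third variant reads batch['context_emotion_scores'] as dictionaries with
# a 'compound' field, which matches VADER sentiment output. The snippet below
# is a self-contained illustration of that polarity switch using the
# vaderSentiment package; the emotion ids are hypothetical placeholders, not
# values from the repository.
from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer

analyzer = SentimentIntensityAnalyzer()
positive_id, negative_id = 11, 2    # stand-ins for positive_emotions[0] / negative_emotions[0]
scores = analyzer.polarity_scores("i just got a promotion at work!")
context_emotion = positive_id if scores['compound'] > 0 else negative_id
print(scores['compound'], context_emotion)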