Example #1
0
    def decoder_topk(self, batch, max_dec_step=30):
        """Autoregressive decoding with top-k sampling.

        Encodes the batch once, then repeatedly samples the next token
        from a top-k (k=3) filtered distribution for up to
        ``max_dec_step + 1`` steps.

        Args:
            batch: dict-like batch; must contain "mask_input" and the
                fields consumed by ``get_input_from_batch``.
            max_dec_step: maximum number of decoding steps.

        Returns:
            list[str]: one space-joined sentence per batch element,
            truncated at the first '<EOS>'.
        """
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        # Mask embedding is added to token embeddings before encoding.
        emb_mask = self.embedding(batch["mask_input"])
        encoder_outputs = self.encoder(
            self.embedding(enc_batch) + emb_mask, mask_src)

        # Start decoding from a single SOS token.
        ys = torch.ones(1, 1).fill_(config.SOS_idx).long()
        if config.USE_CUDA:
            ys = ys.cuda()
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        decoded_words = []
        for _ in range(max_dec_step + 1):
            if config.project:
                out, attn_dist = self.decoder(
                    self.embedding_proj_in(self.embedding(ys)),
                    self.embedding_proj_in(encoder_outputs),
                    (mask_src, mask_trg))
            else:
                out, attn_dist = self.decoder(self.embedding(ys),
                                              encoder_outputs,
                                              (mask_src, mask_trg))

            logit = self.generator(out,
                                   attn_dist,
                                   enc_batch_extend_vocab,
                                   extra_zeros,
                                   attn_dist_db=None)
            filtered_logit = top_k_top_p_filtering(logit[:, -1],
                                                   top_k=3,
                                                   top_p=0,
                                                   filter_value=-float('Inf'))
            # Sample the next token from the filtered distribution.
            next_word = torch.multinomial(F.softmax(filtered_logit, dim=-1),
                                          1).squeeze()
            decoded_words.append([
                '<EOS>' if ni.item() == config.EOS_idx else
                self.vocab.index2word[ni.item()] for ni in next_word.view(-1)
            ])
            # BUG FIX: the original used next_word.data[0], which raises on
            # a 0-dim tensor (the squeeze() result for a single sequence)
            # in modern PyTorch. view(-1)[0].item() handles both 0-dim and
            # 1-dim tensors.
            next_word = next_word.view(-1)[0].item()

            # Append the sampled token; ys is already on the right device,
            # so no extra .cuda() on ys is needed.
            next_tok = torch.ones(1, 1).long().fill_(next_word)
            if config.USE_CUDA:
                next_tok = next_tok.cuda()
            ys = torch.cat([ys, next_tok], dim=1)
            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

        sent = []
        # decoded_words is (steps, batch); transpose to (batch, steps).
        for row in np.transpose(decoded_words):
            st = ''
            for e in row:
                if e == '<EOS>':
                    break
                st += e + ' '
            sent.append(st)
        return sent
Example #2
0
    def decoder_topk(self, batch, max_dec_step=30):
        """Top-k sampling decoding with a mixture-of-experts attention.

        Encodes the batch, derives per-expert attention weights from an
        encoder summary (optionally restricted to the top
        ``config.topk`` experts), then autoregressively samples up to
        ``max_dec_step + 1`` tokens from a top-k (k=3) filtered
        distribution.

        Args:
            batch: dict-like batch with "mask_input" (and, when
                ``config.oracle``, "target_program") plus the fields
                consumed by ``get_input_from_batch``.
            max_dec_step: maximum number of decoding steps.

        Returns:
            list[str]: one space-joined sentence per batch element,
            truncated at the first '<EOS>'.
        """
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        emb_mask = self.embedding(batch["mask_input"])
        encoder_outputs = self.encoder(
            self.embedding(enc_batch) + emb_mask, mask_src)

        ## Attention over decoder: summarize the encoding either by mean
        ## pooling or by the first position.
        q_h = torch.mean(encoder_outputs,
                         dim=1) if config.mean_query else encoder_outputs[:, 0]
        logit_prob = self.decoder_key(q_h)

        if config.topk > 0:
            # Keep only the top-k expert logits; all others become -inf so
            # the attention activation zeroes them out.
            k_max_value, k_max_index = torch.topk(logit_prob, config.topk)
            mask = torch.Tensor(
                np.full([logit_prob.shape[0], self.decoder_number],
                        float('-inf')))
            if config.USE_CUDA:
                # BUG FIX: the original called .cuda() unconditionally,
                # crashing CPU-only runs.
                mask = mask.cuda()
                k_max_index = k_max_index.cuda()
            logit_prob = mask.scatter_(1, k_max_index.long(), k_max_value)

        attention_parameters = self.attention_activation(logit_prob)

        if config.oracle:
            # Use the gold program distribution instead of the prediction;
            # * 1000 makes the activation effectively one-hot.
            oracle_logits = torch.FloatTensor(batch['target_program']) * 1000
            if config.USE_CUDA:
                # BUG FIX: guard .cuda() with config.USE_CUDA (original was
                # unconditional).
                oracle_logits = oracle_logits.cuda()
            attention_parameters = self.attention_activation(oracle_logits)
        attention_parameters = attention_parameters.unsqueeze(-1).unsqueeze(
            -1)  # (batch_size, expert_num, 1, 1)

        # Start decoding from a single SOS token.
        ys = torch.ones(1, 1).fill_(config.SOS_idx).long()
        if config.USE_CUDA:
            ys = ys.cuda()
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        decoded_words = []
        for _ in range(max_dec_step + 1):
            if config.project:
                out, attn_dist = self.decoder(
                    self.embedding_proj_in(self.embedding(ys)),
                    self.embedding_proj_in(encoder_outputs),
                    (mask_src, mask_trg), attention_parameters)
            else:
                out, attn_dist = self.decoder(self.embedding(ys),
                                              encoder_outputs,
                                              (mask_src, mask_trg),
                                              attention_parameters)

            logit = self.generator(out,
                                   attn_dist,
                                   enc_batch_extend_vocab,
                                   extra_zeros,
                                   attn_dist_db=None)
            filtered_logit = top_k_top_p_filtering(logit[:, -1],
                                                   top_k=3,
                                                   top_p=0,
                                                   filter_value=-float('Inf'))
            # Sample the next token from the filtered distribution.
            next_word = torch.multinomial(F.softmax(filtered_logit, dim=-1),
                                          1).squeeze()
            decoded_words.append([
                '<EOS>' if ni.item() == config.EOS_idx else
                self.vocab.index2word[ni.item()] for ni in next_word.view(-1)
            ])
            # BUG FIX: .data[0] raises on 0-dim tensors (single sequence
            # after squeeze) in modern PyTorch; view(-1)[0].item() handles
            # both 0-dim and 1-dim.
            next_word = next_word.view(-1)[0].item()

            # Append the sampled token; ys is already on the right device.
            next_tok = torch.ones(1, 1).long().fill_(next_word)
            if config.USE_CUDA:
                next_tok = next_tok.cuda()
            ys = torch.cat([ys, next_tok], dim=1)
            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

        sent = []
        # decoded_words is (steps, batch); transpose to (batch, steps).
        for row in np.transpose(decoded_words):
            st = ''
            for e in row:
                if e == '<EOS>':
                    break
                st += e + ' '
            sent.append(st)
        return sent
Example #3
0
    def decoder_topk(self,
                     batch,
                     max_dec_step=30,
                     emotion_classifier='built_in'):
        """Emotion-conditioned decoding with top-k sampling.

        Picks an emotion signal ('vader' sentiment, the gold program
        label, or the built-in classifier), samples mimicked and
        non-mimicked emotion embeddings from the VAE sampler, fuses them
        with the encoding per ``config.emo_combine``, and then
        autoregressively samples up to ``max_dec_step + 1`` tokens from a
        top-k (k=3) filtered distribution.

        Args:
            batch: dict-like batch with "mask_input",
                "context_emotion_scores", "program_label", plus the
                fields consumed by ``get_input_from_batch``.
            max_dec_step: maximum number of decoding steps.
            emotion_classifier: 'vader', 'built_in', or None (use the
                gold program label).

        Returns:
            list[str]: one space-joined sentence per batch element,
            truncated at the first '<EOS>'.

        Raises:
            ValueError: on an unknown ``emotion_classifier`` or
                ``config.emo_combine`` value (the original fell through
                silently and crashed later with a NameError).
        """
        enc_batch, _, _, enc_batch_extend_vocab, extra_zeros, _, _ = get_input_from_batch(
            batch)

        # Binary sentiment from the VADER compound score: a representative
        # positive emotion id when positive, else a negative one.
        # (The original computed this twice — once here and once inside the
        # 'vader' branch — with identical inputs; computed once now.)
        context_emo = [
            self.positive_emotions[0]
            if d['compound'] > 0 else self.negative_emotions[0]
            for d in batch['context_emotion_scores']
        ]
        context_emo = torch.Tensor(context_emo)
        if config.USE_CUDA:
            context_emo = context_emo.cuda()

        ## Encode
        mask_src = enc_batch.data.eq(config.PAD_idx).unsqueeze(1)
        emb_mask = self.embedding(batch["mask_input"])
        encoder_outputs = self.encoder(
            self.embedding(enc_batch) + emb_mask, mask_src)

        q_h = torch.mean(encoder_outputs,
                         dim=1) if config.mean_query else encoder_outputs[:, 0]

        x = self.s_weight(q_h)
        # Emotion logits via similarity against the emoji embedding table.
        logit_prob = torch.matmul(x,
                                  self.emoji_embedding.weight.transpose(0, 1))

        if emotion_classifier == "vader":
            emotions_mimic, emotions_non_mimic, mu_p, logvar_p, mu_n, logvar_n = self.vae_sampler(
                q_h, context_emo, self.emoji_embedding)
        elif emotion_classifier is None:
            emotions_mimic, emotions_non_mimic, mu_p, logvar_p, mu_n, logvar_n = self.vae_sampler(
                q_h, batch['program_label'], self.emoji_embedding)
        elif emotion_classifier == "built_in":
            emo_pred = torch.argmax(logit_prob, dim=-1)
            emotions_mimic, emotions_non_mimic, mu_p, logvar_p, mu_n, logvar_n = self.vae_sampler(
                q_h, emo_pred, self.emoji_embedding)
        else:
            raise ValueError(
                "unknown emotion_classifier: %r" % (emotion_classifier,))

        m_out = self.emotion_input_encoder_1(emotions_mimic.unsqueeze(1),
                                             encoder_outputs, mask_src)
        m_tilde_out = self.emotion_input_encoder_2(
            emotions_non_mimic.unsqueeze(1), encoder_outputs, mask_src)

        if config.emo_combine == "att":
            v = self.cdecoder(encoder_outputs, m_out, m_tilde_out, mask_src)
        elif config.emo_combine == "gate":
            v = self.cdecoder(m_out, m_tilde_out)
        elif config.emo_combine == 'vader':
            # BUG FIX: the original referenced the undefined name
            # 'context_emo_scores' (NameError) and computed
            # 'm_weight * m_weight', dropping m_out entirely; weight the
            # mimicked and non-mimicked encodings instead.
            # NOTE(review): confirm context_emo is the intended weight —
            # it holds emotion ids, not raw scores.
            m_weight = context_emo.unsqueeze(-1).unsqueeze(-1)
            m_tilde_weight = 1 - m_weight
            v = m_weight * m_out + m_tilde_weight * m_tilde_out
        else:
            raise ValueError(
                "unknown config.emo_combine: %r" % (config.emo_combine,))

        # Start decoding from a single SOS token.
        ys = torch.ones(1, 1).fill_(config.SOS_idx).long()
        if config.USE_CUDA:
            ys = ys.cuda()
        mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)
        decoded_words = []
        for _ in range(max_dec_step + 1):
            if config.project:
                # NOTE(review): the original passed an undefined
                # 'attention_parameters' here (NameError whenever
                # config.project); this mirrors the non-projected call with
                # projected inputs instead — confirm against training code.
                out, attn_dist = self.decoder(
                    self.embedding_proj_in(self.embedding(ys)),
                    self.embedding_proj_in(v),
                    self.embedding_proj_in(v),
                    (mask_src, mask_trg))
            else:
                out, attn_dist = self.decoder(self.embedding(ys), v, v,
                                              (mask_src, mask_trg))

            logit = self.generator(out,
                                   attn_dist,
                                   enc_batch_extend_vocab,
                                   extra_zeros,
                                   attn_dist_db=None)
            filtered_logit = top_k_top_p_filtering(logit[:, -1],
                                                   top_k=3,
                                                   top_p=0,
                                                   filter_value=-float('Inf'))
            # Sample the next token from the filtered distribution.
            next_word = torch.multinomial(F.softmax(filtered_logit, dim=-1),
                                          1).squeeze()
            decoded_words.append([
                '<EOS>' if ni.item() == config.EOS_idx else
                self.vocab.index2word[ni.item()] for ni in next_word.view(-1)
            ])
            # view(-1)[0].item() handles both 0-dim and 1-dim tensors
            # (.data.item() raises when more than one sequence is decoded).
            next_word = next_word.view(-1)[0].item()

            # Append the sampled token; ys is already on the right device.
            next_tok = torch.ones(1, 1).long().fill_(next_word)
            if config.USE_CUDA:
                next_tok = next_tok.cuda()
            ys = torch.cat([ys, next_tok], dim=1)
            mask_trg = ys.data.eq(config.PAD_idx).unsqueeze(1)

        sent = []
        # decoded_words is (steps, batch); transpose to (batch, steps).
        for row in np.transpose(decoded_words):
            st = ''
            for e in row:
                if e == '<EOS>':
                    break
                st += e + ' '
            sent.append(st)
        return sent