Example No. 1
    def generate(self, eval_dataloader):
        generate_corpus = []
        idx2token = eval_dataloader.idx2token

        for batch_data in eval_dataloader:
            source_idx = batch_data['attribute_idx']
            self.batch_size = source_idx.size(0)

            encoder_outputs, encoder_states = self.encoder(source_idx)

            for bid in range(self.batch_size):
                c = torch.zeros(self.num_dec_layers, 1,
                                self.hidden_size).to(self.device)
                decoder_states = (encoder_states[:, bid, :].unsqueeze(1), c)
                encoder_output = encoder_outputs[bid, :, :].unsqueeze(0)
                generate_tokens = []
                input_seq = torch.LongTensor([[self.sos_token_idx]]).to(self.device)

                if (self.strategy == 'beam_search'):
                    hypothesis = Beam_Search_Hypothesis(
                        self.beam_size, self.sos_token_idx, self.eos_token_idx,
                        self.device, idx2token)

                for gen_idx in range(self.max_target_length):
                    decoder_input = self.target_token_embedder(input_seq)
                    decoder_outputs, decoder_states, _ = self.decoder(
                        decoder_input, decoder_states, encoder_output)

                    token_logits = self.vocab_linear(decoder_outputs)
                    if (self.strategy == 'topk_sampling'):
                        token_idx = topk_sampling(token_logits).item()
                    elif (self.strategy == 'greedy_search'):
                        token_idx = greedy_search(token_logits).item()
                    elif (self.strategy == 'beam_search'):
                        input_seq, decoder_states, encoder_output = \
                            hypothesis.step(gen_idx, token_logits, decoder_states, encoder_output)

                    if (self.strategy in ['topk_sampling', 'greedy_search']):
                        if token_idx == self.eos_token_idx:
                            break
                        else:
                            generate_tokens.append(idx2token[token_idx])
                            input_seq = torch.LongTensor([[token_idx]]).to(self.device)
                    elif (self.strategy == 'beam_search'):
                        if (hypothesis.stop()):
                            break

                if (self.strategy == 'beam_search'):
                    generate_tokens = hypothesis.generate()

                generate_corpus.append(generate_tokens)

        return generate_corpus
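These `generate` methods call decoding helpers such as `greedy_search(token_logits)` and `topk_sampling(token_logits)` that are not shown in the snippets. Below is a minimal sketch of what such helpers could look like, assuming logits of shape [1, 1, vocab_size]; the function names match the calls above, but the bodies and the `k` default are illustrative, not the library's actual implementation.

import torch
import torch.nn.functional as F


def greedy_search(token_logits):
    # pick the id of the highest-scoring token
    return token_logits.reshape(-1).argmax(dim=-1)


def topk_sampling(token_logits, k=50):
    # keep the k largest logits, renormalize them, and sample one token id
    logits = token_logits.reshape(-1)
    top_values, top_indices = torch.topk(logits, k)
    probs = F.softmax(top_values, dim=-1)
    sampled = torch.multinomial(probs, num_samples=1)
    return top_indices[sampled]

Both return a tensor on which the callers above invoke `.item()`.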
Example No. 2
    def generate(self, batch_data, eval_data):
        generate_corpus = []
        idx2token = eval_data.idx2token

        title_text = batch_data['source_idx']
        title_length = batch_data['source_length']
        sentence_length = batch_data['target_length']
        sentence_length = torch.Tensor([sentence_length[i][0].item() for i in range(len(sentence_length))])
        batch_size = title_text.size(0)

        pad_text = torch.full((batch_size, self.target_max_length + 2), self.padding_token_idx).to(self.device)
        pad_emb = self.token_embedder(pad_text)
        title_emb = self.token_embedder(title_text)
        title_o, title_hidden = self.encoder(title_emb, title_length)
        pre_o, pre_hidden = self.encoder(pad_emb, sentence_length)

        if self.rnn_type == "lstm":
            title_h, title_c = title_hidden
            fir_h, fir_c = pre_hidden
        elif self.rnn_type == 'gru' or self.rnn_type == 'rnn':
            title_h = title_hidden
            fir_h = pre_hidden
        else:
            raise NotImplementedError("No such rnn type {} for CVAE.".format(self.rnn_type))
        if self.bidirectional:
            title_h = title_h.view(self.num_enc_layers, 2, batch_size, self.hidden_size)
            title_h = title_h[-1]
            title_h = torch.cat([title_h[0], title_h[1]], dim=1)
            fir_h = fir_h.view(self.num_enc_layers, 2, batch_size, self.hidden_size)
            fir_h = fir_h[-1]
            fir_h = torch.cat([fir_h[0], fir_h[1]], dim=1)
        else:
            # title (Tensor): shape: [batch_size, num_direction*hidden_size]
            title_h = title_h[-1]
            fir_h = fir_h[-1]

        for bid in range(batch_size):
            poem = []
            pre_h = torch.unsqueeze(fir_h[bid], 0)
            single_title_h = torch.unsqueeze(title_h[bid], 0)
            for i in range(self.target_max_num):
                generate_sentence = []
                generate_sentence_idx = []
                condition = torch.cat((single_title_h, pre_h), 1)
                # mean and logvar of prior:
                prior_mean = self.prior_mean_linear1(condition)
                prior_mean = self.prior_mean_linear2(torch.tanh(prior_mean))
                prior_logvar = self.prior_logvar_linear1(condition)
                prior_logvar = self.prior_logvar_linear2(torch.tanh(prior_logvar))
                # sample from prior
                prior_z = torch.randn([1, self.latent_size]).to(self.device)
                prior_z = prior_mean + prior_z * torch.exp(0.5 * prior_logvar)
                hidden = self.latent_to_hidden1(torch.cat((condition, prior_z), 1))
                hidden = self.latent_to_hidden2(torch.tanh(hidden))

                # hidden = self.latent_to_hidden(torch.cat((condition, prior_z), 1))
                if self.rnn_type == "lstm":
                    decoder_hidden = torch.chunk(hidden, 2, dim=-1)
                    h_0 = decoder_hidden[0].unsqueeze(0).expand(self.num_dec_layers, -1, -1).contiguous()
                    c_0 = decoder_hidden[1].unsqueeze(0).expand(self.num_dec_layers, -1, -1).contiguous()
                    decoder_hidden = (h_0, c_0)
                else:
                    # decoder_hidden (Torch.tensor): shape: [num_dec_layers,1,hidden_size]
                    decoder_hidden = hidden.unsqueeze(0).expand(self.num_dec_layers, -1, -1).contiguous()
                input_seq = torch.LongTensor([[self.sos_token_idx]]).to(self.device)
                # generate until the sentence reaches its maximum number of words
                for _ in range(int(sentence_length[bid].item()) - 2):
                    # decoder_input (Torch.tensor): shape: [1,1,embedding_size]
                    decoder_input = self.token_embedder(input_seq)
                    outputs, hidden_states = self.decoder(input_embeddings=decoder_input, hidden_states=decoder_hidden)
                    # token_logits (Tensor): shape [1,1,vocab_size]
                    token_logits = self.vocab_linear(outputs)
                    # exclude invalid tokens (the first five entries of idx2token:
                    # <|pad|> <|unk|> <|startoftext|> <|endoftext|> __eol__)
                    token_idx = greedy_search(token_logits[:, :, 5:])
                    token_idx = token_idx.item() + 5
                    generate_sentence.append(idx2token[token_idx])
                    generate_sentence_idx.append(token_idx)
                    input_seq = torch.LongTensor([[token_idx]]).to(self.device)
                poem.extend(generate_sentence)
                generate_sentence_idx = torch.tensor(generate_sentence_idx).to(self.device).to(torch.int64)
                generate_sentence_length = torch.tensor(len(generate_sentence)).to(self.device).expand(1, 1)
                pre_emb = self.token_embedder(generate_sentence_idx)
                pre_emb = torch.unsqueeze(pre_emb, 0)
                pre_o, pre_hidden = self.encoder(pre_emb, generate_sentence_length[0])
                if self.rnn_type == "lstm":
                    pre_h, pre_c = pre_hidden
                else:
                    pre_h = pre_hidden
                if self.bidirectional:
                    pre_h = pre_h.view(self.num_enc_layers, 2, 1, self.hidden_size)
                    pre_h = pre_h[-1]
                    pre_h = torch.cat([pre_h[0], pre_h[1]], dim=1)
                else:
                    pre_h = pre_h[-1]
            generate_corpus.append(poem)
        return generate_corpus
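The prior sampling in this example (`prior_z = prior_mean + prior_z * torch.exp(0.5 * prior_logvar)`) is the usual reparameterization trick; as a standalone sketch with illustrative names:

import torch


def sample_latent(mean, logvar):
    # reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I)
    eps = torch.randn_like(mean)
    return mean + eps * torch.exp(0.5 * logvar)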
Example No. 3
    def generate(self, eval_dataloader):
        generate_corpus = []
        idx2token = eval_dataloader.target_idx2token

        for batch_data in eval_dataloader:
            source_text = batch_data['source_idx']
            source_length = batch_data['source_length']
            source_embeddings = self.source_token_embedder(source_text)
            encoder_outputs, encoder_states = self.encoder(
                source_embeddings, source_length)

            if self.bidirectional:
                encoder_outputs = (encoder_outputs[:, :, self.hidden_size:] +
                                   encoder_outputs[:, :, :self.hidden_size])
                if (self.rnn_type == 'lstm'):
                    encoder_states = (encoder_states[0][::2],
                                      encoder_states[1][::2])
                else:
                    encoder_states = encoder_states[::2]

            encoder_masks = torch.ne(source_text, self.padding_token_idx)
            for bid in range(source_text.size(0)):
                decoder_states = encoder_states[:, bid, :].unsqueeze(1)
                encoder_output = encoder_outputs[bid, :, :].unsqueeze(0)
                encoder_mask = encoder_masks[bid, :].unsqueeze(0)
                generate_tokens = []
                input_seq = torch.LongTensor([[self.sos_token_idx]]).to(self.device)

                if (self.strategy == 'beam_search'):
                    hypothesis = Beam_Search_Hypothesis(
                        self.beam_size, self.sos_token_idx, self.eos_token_idx,
                        self.device, idx2token)

                for gen_idx in range(self.max_target_length):
                    decoder_input = self.target_token_embedder(input_seq)
                    if self.attention_type is not None:
                        decoder_outputs, decoder_states, _ = self.decoder(
                            decoder_input, decoder_states, encoder_output,
                            encoder_mask)
                    else:
                        decoder_outputs, decoder_states = self.decoder(
                            decoder_input, decoder_states)

                    token_logits = self.vocab_linear(decoder_outputs)
                    if (self.strategy == 'topk_sampling'):
                        token_idx = topk_sampling(token_logits).item()
                    elif (self.strategy == 'greedy_search'):
                        token_idx = greedy_search(token_logits).item()
                    elif (self.strategy == 'beam_search'):
                        if self.attention_type is not None:
                            input_seq, decoder_states, encoder_output, encoder_mask = \
                                hypothesis.step(gen_idx, token_logits, decoder_states, encoder_output, encoder_mask)
                        else:
                            input_seq, decoder_states = hypothesis.step(
                                gen_idx, token_logits, decoder_states)

                    if (self.strategy in ['topk_sampling', 'greedy_search']):
                        if token_idx == self.eos_token_idx:
                            break
                        else:
                            generate_tokens.append(idx2token[token_idx])
                            input_seq = torch.LongTensor([[token_idx]]).to(self.device)
                    elif (self.strategy == 'beam_search'):
                        if (hypothesis.stop()):
                            break

                if (self.strategy == 'beam_search'):
                    generate_tokens = hypothesis.generate()

                generate_corpus.append(generate_tokens)

        return generate_corpus
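When the encoder is bidirectional, this example folds both directions back into a single `hidden_size`: the forward and backward halves of each output vector are summed, and only every other layer state (`[::2]`, the forward direction of each layer under PyTorch's layer-major state layout) is kept so the state shape matches the unidirectional decoder. A standalone sketch of that merge, with illustrative names:

import torch


def merge_bidirectional(encoder_outputs, encoder_states, hidden_size, rnn_type='lstm'):
    # sum the forward and backward halves of every output vector
    encoder_outputs = (encoder_outputs[:, :, :hidden_size]
                       + encoder_outputs[:, :, hidden_size:])
    # keep one direction per layer so the state matches a unidirectional decoder
    if rnn_type == 'lstm':
        encoder_states = (encoder_states[0][::2], encoder_states[1][::2])
    else:
        encoder_states = encoder_states[::2]
    return encoder_outputs, encoder_states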
Example No. 4
    def generate(self, eval_dataloader):
        generate_corpus = []
        idx2token = eval_dataloader.target_idx2token

        for batch_data in eval_dataloader:
            source_text = batch_data['source_idx']
            source_embeddings = self.source_token_embedder(source_text) + \
                                self.position_embedder(source_text).to(self.device)
            source_padding_mask = torch.eq(source_text, self.padding_token_idx).to(self.device)
            encoder_outputs = self.encoder(
                source_embeddings, self_padding_mask=source_padding_mask, output_all_encoded_layers=False
            )

            for bid in range(source_text.size(0)):
                encoder_output = encoder_outputs[bid, :, :].unsqueeze(0)
                encoder_mask = source_padding_mask[bid, :].unsqueeze(0)
                generate_tokens = []
                prev_token_ids = [self.sos_token_idx]
                input_seq = torch.LongTensor([prev_token_ids]).to(self.device)

                if (self.decoding_strategy == 'beam_search'):
                    hypothesis = Beam_Search_Hypothesis(
                        self.beam_size, self.sos_token_idx, self.eos_token_idx, self.device, idx2token
                    )

                for gen_idx in range(self.max_target_length):
                    self_attn_mask = self.self_attn_mask(input_seq.size(-1)).bool().to(self.device)
                    decoder_input = self.target_token_embedder(input_seq) + \
                                    self.position_embedder(input_seq).to(self.device)
                    decoder_outputs = self.decoder(
                        decoder_input,
                        self_attn_mask=self_attn_mask,
                        external_states=encoder_output,
                        external_padding_mask=encoder_mask
                    )

                    token_logits = self.vocab_linear(decoder_outputs[:, -1, :].unsqueeze(1))

                    if (self.decoding_strategy == 'topk_sampling'):
                        token_idx = topk_sampling(token_logits).item()
                    elif (self.decoding_strategy == 'greedy_search'):
                        token_idx = greedy_search(token_logits).item()
                    elif (self.decoding_strategy == 'beam_search'):
                        input_seq, encoder_output, encoder_mask = \
                            hypothesis.step(gen_idx, token_logits, encoder_output=encoder_output, encoder_mask=encoder_mask, input_type='whole')

                    if (self.decoding_strategy in ['topk_sampling', 'greedy_search']):
                        if token_idx == self.eos_token_idx:
                            break
                        else:
                            generate_tokens.append(idx2token[token_idx])
                            prev_token_ids.append(token_idx)
                            input_seq = torch.LongTensor([prev_token_ids]).to(self.device)
                    elif (self.decoding_strategy == 'beam_search'):
                        if (hypothesis.stop()):
                            break

                if (self.decoding_strategy == 'beam_search'):
                    generate_tokens = hypothesis.generate()

                generate_corpus.append(generate_tokens)

        return generate_corpus
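This Transformer example relies on a `self.self_attn_mask(seq_len)` helper that is not shown; judging from how its result is used (converted with `.bool()` and passed as `self_attn_mask`), it is presumably a causal mask over the generated prefix. A minimal sketch, assuming the convention that nonzero entries mark positions the decoder may not attend to:

import torch


def causal_self_attn_mask(seq_len):
    # upper-triangular matrix: entry (i, j) is 1 when j > i, so position i
    # cannot attend to later positions
    return torch.triu(torch.ones(seq_len, seq_len), diagonal=1)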
Example No. 5
    def generate(self, batch_data, eval_data):
        generate_corpus = []
        idx2token = eval_data.idx2token

        utt_states, context_states = self.encode(
            batch_data)  # [b, t, nd * h], [nl, b, h]
        source_length = batch_data['source_length']  # [b, t]
        utt_masks = torch.ne(source_length, 0)  # [b, t]

        for bid in range(utt_states.size(0)):
            encoder_states = utt_states[bid].unsqueeze(0)  # [1, t, nd * h]
            decoder_states = context_states[:, bid, :].unsqueeze(1)  # [nl, 1, h]
            context_state = decoder_states[-1].unsqueeze(0)  # [1, 1, h]
            encoder_masks = utt_masks[bid].unsqueeze(0)  # [1, t]

            generate_tokens = []
            input_seq = torch.LongTensor([[self.sos_token_idx]]).to(self.device)

            if (self.strategy == 'beam_search'):
                hypothesis = Beam_Search_Hypothesis(self.beam_size,
                                                    self.sos_token_idx,
                                                    self.eos_token_idx,
                                                    self.device, idx2token)

            for gen_idx in range(self.target_max_length):
                input_embedding = self.token_embedder(input_seq)  # [beam, 1, e]
                decoder_input = torch.cat(
                    (input_embedding, context_state.repeat(input_embedding.size(0), 1, 1)),
                    dim=-1)  # [beam, 1, e + h]

                if self.attention_type is not None:
                    decoder_outputs, decoder_states, _ = self.decoder(
                        decoder_input, decoder_states, encoder_states,
                        encoder_masks)
                else:
                    decoder_outputs, decoder_states = self.decoder(
                        decoder_input, decoder_states)

                token_logits = self.vocab_linear(decoder_outputs)
                if (self.strategy == 'topk_sampling'):
                    token_idx = topk_sampling(token_logits).item()
                elif (self.strategy == 'greedy_search'):
                    token_idx = greedy_search(token_logits).item()
                elif (self.strategy == 'beam_search'):
                    input_seq, decoder_states, encoder_states, encoder_masks = hypothesis.step(
                        gen_idx, token_logits, decoder_states, encoder_states,
                        encoder_masks)

                if (self.strategy in ['topk_sampling', 'greedy_search']):
                    if token_idx == self.eos_token_idx:
                        break
                    else:
                        generate_tokens.append(idx2token[token_idx])
                        input_seq = torch.LongTensor([[token_idx]]).to(self.device)
                elif (self.strategy == 'beam_search'):
                    if (hypothesis.stop()):
                        break

            if (self.strategy == 'beam_search'):
                generate_tokens = hypothesis.generate()

            generate_corpus.append(generate_tokens)
        return generate_corpus
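At every step this decoder conditions on both the previous token and a fixed context state by concatenation, broadcasting the context across the beam. The same operation in isolation (shapes as in the inline comments above, function name illustrative):

import torch


def build_decoder_input(token_embedding, context_state):
    # token_embedding: [beam, 1, embedding_size]
    # context_state:   [1, 1, hidden_size], repeated across the beam
    context = context_state.repeat(token_embedding.size(0), 1, 1)
    return torch.cat((token_embedding, context), dim=-1)  # [beam, 1, e + h]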
Example No. 6
    def generate(self, batch_data, eval_data):
        generate_corpus = []
        idx2token = eval_data.dataset.target_idx2token
        source_idx = batch_data['source_idx']
        source_entity = batch_data['source_entity']
        target_dict = batch_data['target_dict']
        self.batch_size = source_idx.size(0)

        entity_embeddings, source_embeddings, root_embeddings, entity_len = self.encoder(
            batch_data)
        root_embeddings = root_embeddings.unsqueeze(0)

        c = root_embeddings.clone().detach()

        encoder_title_masks = torch.eq(source_idx,
                                       self.padding_token_idx).to(self.device)
        encoder_entity_masks = [
            torch.cat(
                [torch.zeros(i),
                 torch.ones(entity_embeddings.size(1) - i)]).unsqueeze(0)
            for i in entity_len
        ]
        encoder_entity_masks = torch.cat(encoder_entity_masks,
                                         dim=0).bool().to(self.device)

        for bid in range(self.batch_size):
            decoder_states = (root_embeddings[:, bid, :].unsqueeze(1),
                              c[:, bid, :].unsqueeze(1))
            entity_embeddings_ = entity_embeddings[bid, :, :].unsqueeze(0)
            source_embeddings_ = source_embeddings[bid, :, :].unsqueeze(0)
            encoder_title_masks_ = encoder_title_masks[bid, :].unsqueeze(0)
            encoder_entity_masks_ = encoder_entity_masks[bid, :].unsqueeze(0)
            generate_tokens = []
            input_seq = torch.LongTensor([[self.sos_token_idx]]).to(self.device)

            if self.strategy == 'beam_search':
                hypothesis = Beam_Search_Hypothesis(self.beam_size,
                                                    self.sos_token_idx,
                                                    self.eos_token_idx,
                                                    self.device, idx2token)

            for gen_idx in range(self.target_max_length):
                decoder_input = self.target_token_embedder(input_seq)
                decoder_outputs, decoder_states = self.decoder(
                    decoder_input, decoder_states, entity_embeddings_,
                    source_embeddings_, encoder_entity_masks_,
                    encoder_title_masks_)
                copy_prob = torch.sigmoid(self.copy_linear(decoder_outputs))

                EPSI = torch.tensor(1e-6)
                pred_vocab = torch.log(copy_prob + EPSI) + torch.log_softmax(
                    self.vocab_linear(decoder_outputs), -1)
                douts = self.d_linear(decoder_outputs)
                attn_weight = self.copy_attn(douts, entity_embeddings_,
                                             entity_embeddings_,
                                             encoder_entity_masks_)[2]
                pred_copy = torch.log((1. - copy_prob) +
                                      EPSI) + attn_weight.squeeze(1)
                token_logits = torch.cat([pred_vocab, pred_copy], -1)

                # we only support greedy search for this task
                if self.strategy == 'topk_sampling':
                    token_idx = topk_sampling(token_logits).item()
                elif self.strategy == 'greedy_search':
                    token_idx = greedy_search(token_logits).item()
                elif self.strategy == 'beam_search':
                    # only greedy search is supported for this task (see note above)
                    raise NotImplementedError("Beam search is not supported for this model.")

                if self.strategy in ['topk_sampling', 'greedy_search']:
                    if token_idx == self.eos_token_idx:
                        break
                    elif token_idx >= self.target_vocab_size:
                        entity_tokens = source_entity[bid][token_idx - self.target_vocab_size]
                        entity_tokens = entity_tokens.split(" ")
                        generate_tokens.extend(entity_tokens)
                        # retrieve next token
                        next_token = self.target_token2idx[target_dict[bid][token_idx - self.target_vocab_size]]
                        input_seq = torch.LongTensor([[next_token]]).to(self.device)
                    else:
                        generate_tokens.append(idx2token[token_idx])
                        input_seq = torch.LongTensor([[token_idx]]).to(self.device)
                elif self.strategy == 'beam_search':
                    if (hypothesis.stop()):
                        break

            if self.strategy == 'beam_search':
                generate_tokens = hypothesis.generate()

            generate_corpus.append(generate_tokens)

        return generate_corpus
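The last example extends the vocabulary logits with per-entity copy scores: the gate produced by `copy_linear` weights the vocabulary branch, its complement weights the copy branch, and both are combined in log space before the concatenated logits are searched. A standalone sketch of that mixing with illustrative names (the copy-attention scores are assumed to already be log-weights over the entities, as in the snippet above):

import torch


def mix_vocab_and_copy_scores(vocab_logits, copy_attn_scores, gate, eps=1e-6):
    # gate: sigmoid output in (0, 1) scaling the vocabulary branch;
    # (1 - gate) scales the copy branch over the source entities
    pred_vocab = torch.log(gate + eps) + torch.log_softmax(vocab_logits, dim=-1)
    pred_copy = torch.log((1.0 - gate) + eps) + copy_attn_scores
    # ids >= vocab_size select an entity to copy instead of a vocabulary token
    return torch.cat([pred_vocab, pred_copy], dim=-1)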