Example #1
    def generate(self, batch_data, eval_data):
        generate_corpus = []
        idx2token = eval_data.idx2token
        batch_size = len(batch_data['target_text'])

        for _ in range(batch_size):
            # draw the initial hidden state as noise from a standard Gaussian distribution
            if self.rnn_type == "lstm":
                hidden_states = torch.randn(size=(1, 2 * self.hidden_size), device=self.device)
                hidden_states = torch.chunk(hidden_states, 2, dim=-1)
                h_0 = hidden_states[0].unsqueeze(0).expand(self.num_dec_layers, -1, -1).contiguous()
                c_0 = hidden_states[1].unsqueeze(0).expand(self.num_dec_layers, -1, -1).contiguous()
                hidden_states = (h_0, c_0)
            else:
                hidden_states = torch.randn(size=(self.num_dec_layers, 1, self.hidden_size), device=self.device)
            generate_tokens = []
            input_seq = torch.LongTensor([[self.sos_token_idx]]).to(self.device)
            for _ in range(self.max_length):
                decoder_input = self.token_embedder(input_seq)
                outputs, hidden_states = self.decoder(input_embeddings=decoder_input, hidden_states=hidden_states)
                token_logits = self.vocab_linear(outputs)
                token_idx = topk_sampling(token_logits)
                token_idx = token_idx.item()
                if token_idx == self.eos_token_idx:
                    break
                else:
                    generate_tokens.append(idx2token[token_idx])
                    input_seq = torch.LongTensor([[token_idx]]).to(self.device)
            generate_corpus.append(generate_tokens)
        return generate_corpus
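
Every example on this page calls a topk_sampling helper that the excerpts never define. Below is a minimal sketch of one plausible implementation, assuming the logits arrive with the vocabulary dimension last (e.g. [1, 1, vocab_size]) and using an arbitrary cutoff k=5; the helper actually shipped with the source library may differ.

import torch

def topk_sampling(logits, k=5):
    # Keep only the k highest-scoring tokens, then sample one of them
    # from the renormalized softmax distribution.
    scores = logits.squeeze()                         # [vocab_size]
    top_values, top_indices = torch.topk(scores, k)
    probs = torch.softmax(top_values, dim=-1)
    choice = torch.multinomial(probs, num_samples=1)  # index into the top-k
    return top_indices[choice]                        # 1-element tensor, so .item() works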
Example #2
    def generate(self, eval_data):
        generate_corpus = []
        idx2token = eval_data.idx2token

        with torch.no_grad():
            for _ in range(self.eval_generate_num):
                z = torch.randn(size=(1, self.latent_size), device=self.device)
                generate_tokens = []
                input_seq = torch.LongTensor([[self.sos_token_idx]]).to(self.device)
                for _ in range(self.max_length):
                    decoder_input = self.token_embedder(input_seq)
                    outputs = self.decoder(decoder_input=decoder_input,
                                           noise=z)
                    token_logits = self.vocab_linear(outputs)
                    token_idx = topk_sampling(token_logits)
                    token_idx = token_idx.item()
                    if token_idx == self.eos_token_idx:
                        break
                    else:
                        generate_tokens.append(idx2token[token_idx])
                        input_seq = torch.LongTensor([[token_idx]]).to(self.device)
                generate_corpus.append(generate_tokens)
        return generate_corpus
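
For context, a generate method of this unconditional flavor (see also example 4) would typically be driven from an evaluation script. A hypothetical invocation, assuming a trained model and an eval_data object that exposes idx2token:

model.eval()
corpus = model.generate(eval_data)  # list of token lists, one per sampled sequence
print(' '.join(corpus[0]))          # detokenized first sample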
Example #3
    def generate(self, eval_dataloader):
        generate_corpus = []
        idx2token = eval_dataloader.idx2token

        for batch_data in eval_dataloader:
            source_idx = batch_data['attribute_idx']
            self.batch_size = source_idx.size(0)

            encoder_outputs, encoder_states = self.encoder(source_idx)

            for bid in range(self.batch_size):
                c = torch.zeros(self.num_dec_layers, 1,
                                self.hidden_size).to(self.device)
                decoder_states = (encoder_states[:, bid, :].unsqueeze(1), c)
                encoder_output = encoder_outputs[bid, :, :].unsqueeze(0)
                generate_tokens = []
                input_seq = torch.LongTensor([[self.sos_token_idx]]).to(self.device)

                if (self.strategy == 'beam_search'):
                    hypothesis = Beam_Search_Hypothesis(
                        self.beam_size, self.sos_token_idx, self.eos_token_idx,
                        self.device, idx2token)

                for gen_idx in range(self.max_target_length):
                    decoder_input = self.target_token_embedder(input_seq)
                    decoder_outputs, decoder_states, _ = self.decoder(
                        decoder_input, decoder_states, encoder_output)

                    token_logits = self.vocab_linear(decoder_outputs)
                    if (self.strategy == 'topk_sampling'):
                        token_idx = topk_sampling(token_logits).item()
                    elif (self.strategy == 'greedy_search'):
                        token_idx = greedy_search(token_logits).item()
                    elif (self.strategy == 'beam_search'):
                        input_seq, decoder_states, encoder_output = \
                            hypothesis.step(gen_idx, token_logits, decoder_states, encoder_output)

                    if (self.strategy in ['topk_sampling', 'greedy_search']):
                        if token_idx == self.eos_token_idx:
                            break
                        else:
                            generate_tokens.append(idx2token[token_idx])
                            input_seq = torch.LongTensor([[token_idx]]).to(self.device)
                    elif (self.strategy == 'beam_search'):
                        if (hypothesis.stop()):
                            break

                if (self.strategy == 'beam_search'):
                    generate_tokens = hypothesis.generate()

                generate_corpus.append(generate_tokens)

        return generate_corpus
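
From this example onward the code also dispatches on a greedy_search helper, again not defined in the excerpts. A minimal sketch under the same shape assumption as the topk_sampling sketch above:

def greedy_search(logits):
    # Deterministic decoding: pick the single highest-scoring token.
    return logits.squeeze().argmax(dim=-1)  # 0-dim tensor, so .item() works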
Example #4
    def generate(self, eval_data):
        generate_corpus = []
        idx2token = eval_data.idx2token

        with torch.no_grad():
            for _ in range(self.eval_generate_num):
                z = torch.randn(size=(1, self.latent_size), device=self.device)
                cnn_out = self.decoder.conv_decoder(z)
                if self.rnn_type == "lstm":
                    hidden_states = torch.randn(size=(1, 2 * self.hidden_size),
                                                device=self.device)
                    hidden_states = torch.chunk(hidden_states, 2, dim=-1)
                    h_0 = hidden_states[0].unsqueeze(0).expand(
                        self.num_dec_layers, -1, -1).contiguous()
                    c_0 = hidden_states[1].unsqueeze(0).expand(
                        self.num_dec_layers, -1, -1).contiguous()
                    hidden_states = (h_0, c_0)
                else:
                    hidden_states = torch.randn(size=(self.num_dec_layers, 1,
                                                      self.hidden_size),
                                                device=self.device)
                generate_tokens = []
                input_seq = torch.LongTensor([[self.sos_token_idx]]).to(self.device)
                for gen_idx in range(self.max_length):
                    decoder_input = self.token_embedder(input_seq)

                    token_logits, hidden_states = self.decoder.rnn_decoder(
                        cnn_out[:, gen_idx, :].unsqueeze(1),
                        decoder_input=decoder_input,
                        initial_state=hidden_states)
                    token_idx = topk_sampling(token_logits)
                    token_idx = token_idx.item()
                    if token_idx == self.eos_token_idx:
                        break
                    else:
                        generate_tokens.append(idx2token[token_idx])
                        input_seq = torch.LongTensor([[token_idx]]).to(self.device)
                generate_corpus.append(generate_tokens)
        return generate_corpus
Example #5
    def generate(self, eval_dataloader):
        generate_corpus = []
        idx2token = eval_dataloader.target_idx2token

        for batch_data in eval_dataloader:
            source_text = batch_data['source_idx']
            source_length = batch_data['source_length']
            source_embeddings = self.source_token_embedder(source_text)
            encoder_outputs, encoder_states = self.encoder(
                source_embeddings, source_length)

            if self.bidirectional:
                # sum the forward and backward halves of the bidirectional outputs
                encoder_outputs = encoder_outputs[:, :, self.hidden_size:] \
                                  + encoder_outputs[:, :, :self.hidden_size]
                if (self.rnn_type == 'lstm'):
                    encoder_states = (encoder_states[0][::2],
                                      encoder_states[1][::2])
                else:
                    encoder_states = encoder_states[::2]

            encoder_masks = torch.ne(source_text, self.padding_token_idx)
            for bid in range(source_text.size(0)):
                decoder_states = encoder_states[:, bid, :].unsqueeze(1)
                encoder_output = encoder_outputs[bid, :, :].unsqueeze(0)
                encoder_mask = encoder_masks[bid, :].unsqueeze(0)
                generate_tokens = []
                input_seq = torch.LongTensor([[self.sos_token_idx]]).to(self.device)

                if (self.strategy == 'beam_search'):
                    hypothesis = Beam_Search_Hypothesis(
                        self.beam_size, self.sos_token_idx, self.eos_token_idx,
                        self.device, idx2token)

                for gen_idx in range(self.max_target_length):
                    decoder_input = self.target_token_embedder(input_seq)
                    if self.attention_type is not None:
                        decoder_outputs, decoder_states, _ = self.decoder(
                            decoder_input, decoder_states, encoder_output,
                            encoder_mask)
                    else:
                        decoder_outputs, decoder_states = self.decoder(
                            decoder_input, decoder_states)

                    token_logits = self.vocab_linear(decoder_outputs)
                    if (self.strategy == 'topk_sampling'):
                        token_idx = topk_sampling(token_logits).item()
                    elif (self.strategy == 'greedy_search'):
                        token_idx = greedy_search(token_logits).item()
                    elif (self.strategy == 'beam_search'):
                        if self.attention_type is not None:
                            input_seq, decoder_states, encoder_output, encoder_mask = \
                                hypothesis.step(gen_idx, token_logits, decoder_states, encoder_output, encoder_mask)
                        else:
                            input_seq, decoder_states = hypothesis.step(
                                gen_idx, token_logits, decoder_states)

                    if (self.strategy in ['topk_sampling', 'greedy_search']):
                        if token_idx == self.eos_token_idx:
                            break
                        else:
                            generate_tokens.append(idx2token[token_idx])
                            input_seq = torch.LongTensor([[token_idx]]).to(self.device)
                    elif (self.strategy == 'beam_search'):
                        if (hypothesis.stop()):
                            break

                if (self.strategy == 'beam_search'):
                    generate_tokens = hypothesis.generate()

                generate_corpus.append(generate_tokens)

        return generate_corpus
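
The beam-search branches rely on a Beam_Search_Hypothesis class whose definition is also not shown. Judging purely from the call sites in examples 3 and 5 (and example 6, which passes input_type='whole'), its interface looks roughly like the stub below; the method bodies, and any parameters beyond those visible here, are assumptions:

class Beam_Search_Hypothesis:
    def __init__(self, beam_size, sos_token_idx, eos_token_idx, device, idx2token):
        ...  # beam bookkeeping

    def step(self, gen_idx, token_logits, *states, **kwargs):
        # Expand and prune the beams for one decoding step; returns the
        # next input_seq plus the passed-in states, re-tiled to the
        # current number of live beams.
        ...

    def stop(self):
        ...  # True once every live beam has emitted the <eos> token

    def generate(self):
        ...  # token list of the best finished hypothesis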
Example #6
    def generate(self, eval_dataloader):
        generate_corpus = []
        idx2token = eval_dataloader.target_idx2token

        for batch_data in eval_dataloader:
            source_text = batch_data['source_idx']
            source_embeddings = self.source_token_embedder(source_text) + \
                                self.position_embedder(source_text).to(self.device)
            source_padding_mask = torch.eq(source_text, self.padding_token_idx).to(self.device)
            encoder_outputs = self.encoder(
                source_embeddings, self_padding_mask=source_padding_mask, output_all_encoded_layers=False
            )

            for bid in range(source_text.size(0)):
                encoder_output = encoder_outputs[bid, :, :].unsqueeze(0)
                encoder_mask = source_padding_mask[bid, :].unsqueeze(0)
                generate_tokens = []
                prev_token_ids = [self.sos_token_idx]
                input_seq = torch.LongTensor([prev_token_ids]).to(self.device)

                if (self.decoding_strategy == 'beam_search'):
                    hypothesis = Beam_Search_Hypothesis(
                        self.beam_size, self.sos_token_idx, self.eos_token_idx, self.device, idx2token
                    )

                for gen_idx in range(self.max_target_length):
                    self_attn_mask = self.self_attn_mask(input_seq.size(-1)).bool().to(self.device)
                    decoder_input = self.target_token_embedder(input_seq) + \
                                    self.position_embedder(input_seq).to(self.device)
                    decoder_outputs = self.decoder(
                        decoder_input,
                        self_attn_mask=self_attn_mask,
                        external_states=encoder_output,
                        external_padding_mask=encoder_mask
                    )

                    token_logits = self.vocab_linear(decoder_outputs[:, -1, :].unsqueeze(1))

                    if (self.decoding_strategy == 'topk_sampling'):
                        token_idx = topk_sampling(token_logits).item()
                    elif (self.decoding_strategy == 'greedy_search'):
                        token_idx = greedy_search(token_logits).item()
                    elif (self.decoding_strategy == 'beam_search'):
                        input_seq, encoder_output, encoder_mask = \
                            hypothesis.step(gen_idx, token_logits, encoder_output=encoder_output, encoder_mask=encoder_mask, input_type='whole')

                    if (self.decoding_strategy in ['topk_sampling', 'greedy_search']):
                        if token_idx == self.eos_token_idx:
                            break
                        else:
                            generate_tokens.append(idx2token[token_idx])
                            prev_token_ids.append(token_idx)
                            input_seq = torch.LongTensor([prev_token_ids]).to(self.device)
                    elif (self.decoding_strategy == 'beam_search'):
                        if (hypothesis.stop()):
                            break

                if (self.decoding_strategy == 'beam_search'):
                    generate_tokens = hypothesis.generate()

                generate_corpus.append(generate_tokens)

        return generate_corpus
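
Example 6 builds its square decoder mask with self.self_attn_mask(...). A common realization is a causal (upper-triangular) mask, sketched here under a hypothetical name; whether the library's mask follows exactly this convention is an assumption:

import torch

def causal_self_attn_mask(length):
    # True above the diagonal: position i is blocked from attending to
    # any later position j > i, matching the .bool() usage above.
    return torch.triu(torch.ones(length, length, dtype=torch.bool), diagonal=1)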
Example #7
    def generate(self, batch_data, eval_data):
        generate_corpus = []
        idx2token = eval_data.idx2token

        utt_states, context_states = self.encode(
            batch_data)  # [b, t, nd * h], [nl, b, h]
        source_length = batch_data['source_length']  # [b, t]
        utt_masks = torch.ne(source_length, 0)  # [b, t]

        for bid in range(utt_states.size(0)):
            encoder_states = utt_states[bid].unsqueeze(0)  # [1, t, nd * h]
            decoder_states = context_states[:, bid, :].unsqueeze(1)  # [nl, 1, h]
            context_state = decoder_states[-1].unsqueeze(0)  # [1, 1, h]
            encoder_masks = utt_masks[bid].unsqueeze(0)  # [1, t]

            generate_tokens = []
            input_seq = torch.LongTensor([[self.sos_token_idx]]).to(self.device)

            if (self.strategy == 'beam_search'):
                hypothesis = Beam_Search_Hypothesis(self.beam_size,
                                                    self.sos_token_idx,
                                                    self.eos_token_idx,
                                                    self.device, idx2token)

            for gen_idx in range(self.target_max_length):
                input_embedding = self.token_embedder(input_seq)  # [beam, 1, e]
                decoder_input = torch.cat(
                    (input_embedding,
                     context_state.repeat(input_embedding.size(0), 1, 1)),
                    dim=-1)  # [beam, 1, e + h]

                if self.attention_type is not None:
                    decoder_outputs, decoder_states, _ = self.decoder(
                        decoder_input, decoder_states, encoder_states,
                        encoder_masks)
                else:
                    decoder_outputs, decoder_states = self.decoder(
                        decoder_input, decoder_states)

                token_logits = self.vocab_linear(decoder_outputs)
                if (self.strategy == 'topk_sampling'):
                    token_idx = topk_sampling(token_logits).item()
                elif (self.strategy == 'greedy_search'):
                    token_idx = greedy_search(token_logits).item()
                elif (self.strategy == 'beam_search'):
                    input_seq, decoder_states, encoder_states, encoder_masks = hypothesis.step(
                        gen_idx, token_logits, decoder_states, encoder_states,
                        encoder_masks)

                if (self.strategy in ['topk_sampling', 'greedy_search']):
                    if token_idx == self.eos_token_idx:
                        break
                    else:
                        generate_tokens.append(idx2token[token_idx])
                        input_seq = torch.LongTensor([[token_idx]]).to(self.device)
                elif (self.strategy == 'beam_search'):
                    if (hypothesis.stop()):
                        break

            if (self.strategy == 'beam_search'):
                generate_tokens = hypothesis.generate()

            generate_corpus.append(generate_tokens)
        return generate_corpus
Example #8
    def generate(self, batch_data, eval_data):
        generate_corpus = []
        idx2token = eval_data.dataset.target_idx2token
        source_idx = batch_data['source_idx']
        source_entity = batch_data['source_entity']
        target_dict = batch_data['target_dict']
        self.batch_size = source_idx.size(0)

        entity_embeddings, source_embeddings, root_embeddings, entity_len = self.encoder(
            batch_data)
        root_embeddings = root_embeddings.unsqueeze(0)

        c = root_embeddings.clone().detach()

        encoder_title_masks = torch.eq(source_idx,
                                       self.padding_token_idx).to(self.device)
        encoder_entity_masks = [
            torch.cat(
                [torch.zeros(i),
                 torch.ones(entity_embeddings.size(1) - i)]).unsqueeze(0)
            for i in entity_len
        ]
        encoder_entity_masks = torch.cat(encoder_entity_masks,
                                         dim=0).bool().to(self.device)

        for bid in range(self.batch_size):
            decoder_states = (root_embeddings[:, bid, :].unsqueeze(1),
                              c[:, bid, :].unsqueeze(1))
            entity_embeddings_ = entity_embeddings[bid, :, :].unsqueeze(0)
            source_embeddings_ = source_embeddings[bid, :, :].unsqueeze(0)
            encoder_title_masks_ = encoder_title_masks[bid, :].unsqueeze(0)
            encoder_entity_masks_ = encoder_entity_masks[bid, :].unsqueeze(0)
            generate_tokens = []
            input_seq = torch.LongTensor([[self.sos_token_idx]]).to(self.device)

            if self.strategy == 'beam_search':
                hypothesis = Beam_Search_Hypothesis(self.beam_size,
                                                    self.sos_token_idx,
                                                    self.eos_token_idx,
                                                    self.device, idx2token)

            for gen_idx in range(self.target_max_length):
                decoder_input = self.target_token_embedder(input_seq)
                decoder_outputs, decoder_states = self.decoder(
                    decoder_input, decoder_states, entity_embeddings_,
                    source_embeddings_, encoder_entity_masks_,
                    encoder_title_masks_)
                copy_prob = torch.sigmoid(self.copy_linear(decoder_outputs))

                EPSI = torch.tensor(1e-6)
                pred_vocab = torch.log(copy_prob + EPSI) + torch.log_softmax(
                    self.vocab_linear(decoder_outputs), -1)
                douts = self.d_linear(decoder_outputs)
                attn_weight = self.copy_attn(douts, entity_embeddings_,
                                             entity_embeddings_,
                                             encoder_entity_masks_)[2]
                pred_copy = torch.log((1. - copy_prob) +
                                      EPSI) + attn_weight.squeeze(1)
                token_logits = torch.cat([pred_vocab, pred_copy], -1)

                # we only support greedy search for this task
                if self.strategy == 'topk_sampling':
                    token_idx = topk_sampling(token_logits).item()
                elif self.strategy == 'greedy_search':
                    token_idx = greedy_search(token_logits).item()
                elif self.strategy == 'beam_search':
                    # beam search is not wired up for this copy-based decoder
                    raise NotImplementedError(
                        'beam search is not supported for this task')

                if self.strategy in ['topk_sampling', 'greedy_search']:
                    if token_idx == self.eos_token_idx:
                        break
                    elif token_idx >= self.target_vocab_size:
                        entity_tokens = source_entity[bid][
                            token_idx - self.target_vocab_size]
                        entity_tokens = entity_tokens.split(" ")
                        generate_tokens.extend(entity_tokens)
                        # retrieve next token
                        next_token = self.target_token2idx[target_dict[bid][
                            token_idx - self.target_vocab_size]]
                        input_seq = torch.LongTensor([[next_token]]).to(self.device)
                    else:
                        generate_tokens.append(idx2token[token_idx])
                        input_seq = torch.LongTensor([[token_idx]]).to(self.device)
                elif self.strategy == 'beam_search':
                    if (hypothesis.stop()):
                        break

            if self.strategy == 'beam_search':
                generate_tokens = hypothesis.generate()

            generate_corpus.append(generate_tokens)

        return generate_corpus
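
Example 8 mixes generation and copying by concatenating two log-score blocks: vocabulary scores gated by copy_prob, and attention scores over the entities gated by its complement, so that a sampled index >= target_vocab_size selects an entity to copy rather than a vocabulary token. A self-contained sketch of just that mixing step, with made-up shapes:

import torch

vocab_size, num_entities = 6, 3
copy_prob = torch.sigmoid(torch.randn(1, 1, 1))            # gate in (0, 1)
vocab_logits = torch.randn(1, 1, vocab_size)
attn_weight = torch.log_softmax(torch.randn(1, 1, num_entities), dim=-1)

EPSI = torch.tensor(1e-6)                                  # avoid log(0)
pred_vocab = torch.log(copy_prob + EPSI) + torch.log_softmax(vocab_logits, dim=-1)
pred_copy = torch.log((1. - copy_prob) + EPSI) + attn_weight
token_scores = torch.cat([pred_vocab, pred_copy], dim=-1)  # [1, 1, vocab_size + num_entities]
# an argmax/sample index >= vocab_size means "copy entity (index - vocab_size)"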