def generate(self, eval_dataloader):
    generate_corpus = []
    idx2token = eval_dataloader.idx2token
    for batch_data in eval_dataloader:
        source_idx = batch_data['attribute_idx']
        self.batch_size = source_idx.size(0)
        encoder_outputs, encoder_states = self.encoder(source_idx)
        for bid in range(self.batch_size):
            c = torch.zeros(self.num_dec_layers, 1, self.hidden_size).to(self.device)
            decoder_states = (encoder_states[:, bid, :].unsqueeze(1), c)
            encoder_output = encoder_outputs[bid, :, :].unsqueeze(0)
            generate_tokens = []
            input_seq = torch.LongTensor([[self.sos_token_idx]]).to(self.device)
            if self.strategy == 'beam_search':
                hypothesis = Beam_Search_Hypothesis(
                    self.beam_size, self.sos_token_idx, self.eos_token_idx, self.device, idx2token
                )
            for gen_idx in range(self.max_target_length):
                decoder_input = self.target_token_embedder(input_seq)
                decoder_outputs, decoder_states, _ = self.decoder(decoder_input, decoder_states, encoder_output)
                token_logits = self.vocab_linear(decoder_outputs)
                if self.strategy == 'topk_sampling':
                    token_idx = topk_sampling(token_logits).item()
                elif self.strategy == 'greedy_search':
                    token_idx = greedy_search(token_logits).item()
                elif self.strategy == 'beam_search':
                    input_seq, decoder_states, encoder_output = \
                        hypothesis.step(gen_idx, token_logits, decoder_states, encoder_output)
                if self.strategy in ['topk_sampling', 'greedy_search']:
                    if token_idx == self.eos_token_idx:
                        break
                    else:
                        generate_tokens.append(idx2token[token_idx])
                        input_seq = torch.LongTensor([[token_idx]]).to(self.device)
                elif self.strategy == 'beam_search':
                    if hypothesis.stop():
                        break
            if self.strategy == 'beam_search':
                generate_tokens = hypothesis.generate()
            generate_corpus.append(generate_tokens)
    return generate_corpus
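# The decoding helpers used above (topk_sampling, greedy_search) are not defined in
# this file. A minimal sketch of what they are assumed to do; the signatures and the
# default top-k size are assumptions, not the library's actual implementation:
import torch
import torch.nn.functional as F

def greedy_search(token_logits):
    # token_logits: [1, 1, vocab_size]; pick the id of the highest-scoring token
    return token_logits.view(-1).argmax(dim=-1)

def topk_sampling(token_logits, k=50):
    # keep the k largest logits, renormalize them, and sample one token id
    logits = token_logits.view(-1)
    topk_logits, topk_indices = torch.topk(logits, k)
    probs = F.softmax(topk_logits, dim=-1)
    return topk_indices[torch.multinomial(probs, 1)]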
def generate(self, batch_data, eval_data):
    generate_corpus = []
    idx2token = eval_data.idx2token
    title_text = batch_data['source_idx']
    title_length = batch_data['source_length']
    sentence_length = batch_data['target_length']
    sentence_length = torch.Tensor([sentence_length[i][0].item() for i in range(len(sentence_length))])
    batch_size = title_text.size(0)
    pad_text = torch.full((batch_size, self.target_max_length + 2), self.padding_token_idx).to(self.device)
    pad_emb = self.token_embedder(pad_text)
    title_emb = self.token_embedder(title_text)
    title_o, title_hidden = self.encoder(title_emb, title_length)
    pre_o, pre_hidden = self.encoder(pad_emb, sentence_length)
    if self.rnn_type == "lstm":
        title_h, title_c = title_hidden
        fir_h, fir_c = pre_hidden
    elif self.rnn_type == 'gru' or self.rnn_type == 'rnn':
        title_h = title_hidden
        fir_h = pre_hidden
    else:
        raise NotImplementedError("No such rnn type {} for CVAE.".format(self.rnn_type))
    if self.bidirectional:
        title_h = title_h.view(self.num_enc_layers, 2, batch_size, self.hidden_size)
        title_h = title_h[-1]
        title_h = torch.cat([title_h[0], title_h[1]], dim=1)
        fir_h = fir_h.view(self.num_enc_layers, 2, batch_size, self.hidden_size)
        fir_h = fir_h[-1]
        fir_h = torch.cat([fir_h[0], fir_h[1]], dim=1)
    else:
        # title_h (Tensor): shape: [batch_size, num_directions * hidden_size]
        title_h = title_h[-1]
        fir_h = fir_h[-1]
    for bid in range(batch_size):
        poem = []
        pre_h = torch.unsqueeze(fir_h[bid], 0)
        single_title_h = torch.unsqueeze(title_h[bid], 0)
        for i in range(self.target_max_num):
            generate_sentence = []
            generate_sentence_idx = []
            condition = torch.cat((single_title_h, pre_h), 1)
            # mean and log-variance of the prior p(z | condition)
            prior_mean = self.prior_mean_linear1(condition)
            prior_mean = self.prior_mean_linear2(torch.tanh(prior_mean))
            prior_logvar = self.prior_logvar_linear1(condition)
            prior_logvar = self.prior_logvar_linear2(torch.tanh(prior_logvar))
            # sample from the prior via the reparameterization trick
            prior_z = torch.randn([1, self.latent_size]).to(self.device)
            prior_z = prior_mean + prior_z * torch.exp(0.5 * prior_logvar)
            hidden = self.latent_to_hidden1(torch.cat((condition, prior_z), 1))
            hidden = self.latent_to_hidden2(torch.tanh(hidden))
            if self.rnn_type == "lstm":
                decoder_hidden = torch.chunk(hidden, 2, dim=-1)
                h_0 = decoder_hidden[0].unsqueeze(0).expand(self.num_dec_layers, -1, -1).contiguous()
                c_0 = decoder_hidden[1].unsqueeze(0).expand(self.num_dec_layers, -1, -1).contiguous()
                decoder_hidden = (h_0, c_0)
            else:
                # decoder_hidden (Tensor): shape: [num_dec_layers, 1, hidden_size]
                decoder_hidden = hidden.unsqueeze(0).expand(self.num_dec_layers, -1, -1).contiguous()
            input_seq = torch.LongTensor([[self.sos_token_idx]]).to(self.device)
            # generate until reaching the maximum number of words in a sentence
            for _ in range(int(sentence_length[bid].item()) - 2):
                # decoder_input (Tensor): shape: [1, 1, embedding_size]
                decoder_input = self.token_embedder(input_seq)
                outputs, hidden_states = self.decoder(input_embeddings=decoder_input, hidden_states=decoder_hidden)
                # token_logits (Tensor): shape: [1, 1, vocab_size]
                token_logits = self.vocab_linear(outputs)
                # exclude the first five invalid tokens: <|pad|> <|unk|> <|startoftext|> <|endoftext|> __eol__
                token_idx = greedy_search(token_logits[:, :, 5:])
                token_idx = token_idx.item() + 5
                generate_sentence.append(idx2token[token_idx])
                generate_sentence_idx.append(token_idx)
                input_seq = torch.LongTensor([[token_idx]]).to(self.device)
            poem.extend(generate_sentence)
            # re-encode the generated sentence to condition the next one
            generate_sentence_idx = torch.tensor(generate_sentence_idx).to(self.device).to(torch.int64)
            generate_sentence_length = torch.tensor(len(generate_sentence)).to(self.device).expand(1, 1)
            pre_emb = self.token_embedder(generate_sentence_idx)
            pre_emb = torch.unsqueeze(pre_emb, 0)
            pre_o, pre_hidden = self.encoder(pre_emb, generate_sentence_length[0])
            if self.rnn_type == "lstm":
                pre_h, pre_c = pre_hidden
            else:
                pre_h = pre_hidden
            if self.bidirectional:
                pre_h = pre_h.view(self.num_enc_layers, 2, 1, self.hidden_size)
                pre_h = pre_h[-1]
                pre_h = torch.cat([pre_h[0], pre_h[1]], dim=1)
            else:
                pre_h = pre_h[-1]
        generate_corpus.append(poem)
    return generate_corpus
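# The prior sampling above uses the reparameterization trick: rather than sampling
# z ~ N(mu, sigma^2) directly, it samples eps ~ N(0, I) and computes z = mu + sigma * eps,
# which keeps the sampling step differentiable w.r.t. mu and logvar. A minimal
# standalone sketch (the function name is illustrative, not part of the model):
import torch

def reparameterize(mean, logvar):
    # logvar is log(sigma^2), so sigma = exp(0.5 * logvar)
    eps = torch.randn_like(mean)
    return mean + eps * torch.exp(0.5 * logvar)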
def generate(self, eval_dataloader):
    generate_corpus = []
    idx2token = eval_dataloader.target_idx2token
    for batch_data in eval_dataloader:
        source_text = batch_data['source_idx']
        source_length = batch_data['source_length']
        source_embeddings = self.source_token_embedder(source_text)
        encoder_outputs, encoder_states = self.encoder(source_embeddings, source_length)
        if self.bidirectional:
            encoder_outputs = encoder_outputs[:, :, self.hidden_size:] + encoder_outputs[:, :, :self.hidden_size]
            if self.rnn_type == 'lstm':
                encoder_states = (encoder_states[0][::2], encoder_states[1][::2])
            else:
                encoder_states = encoder_states[::2]
        encoder_masks = torch.ne(source_text, self.padding_token_idx)
        for bid in range(source_text.size(0)):
            decoder_states = encoder_states[:, bid, :].unsqueeze(1)
            encoder_output = encoder_outputs[bid, :, :].unsqueeze(0)
            encoder_mask = encoder_masks[bid, :].unsqueeze(0)
            generate_tokens = []
            input_seq = torch.LongTensor([[self.sos_token_idx]]).to(self.device)
            if self.strategy == 'beam_search':
                hypothesis = Beam_Search_Hypothesis(
                    self.beam_size, self.sos_token_idx, self.eos_token_idx, self.device, idx2token
                )
            for gen_idx in range(self.max_target_length):
                decoder_input = self.target_token_embedder(input_seq)
                if self.attention_type is not None:
                    decoder_outputs, decoder_states, _ = self.decoder(
                        decoder_input, decoder_states, encoder_output, encoder_mask
                    )
                else:
                    decoder_outputs, decoder_states = self.decoder(decoder_input, decoder_states)
                token_logits = self.vocab_linear(decoder_outputs)
                if self.strategy == 'topk_sampling':
                    token_idx = topk_sampling(token_logits).item()
                elif self.strategy == 'greedy_search':
                    token_idx = greedy_search(token_logits).item()
                elif self.strategy == 'beam_search':
                    if self.attention_type is not None:
                        input_seq, decoder_states, encoder_output, encoder_mask = \
                            hypothesis.step(gen_idx, token_logits, decoder_states, encoder_output, encoder_mask)
                    else:
                        input_seq, decoder_states = hypothesis.step(gen_idx, token_logits, decoder_states)
                if self.strategy in ['topk_sampling', 'greedy_search']:
                    if token_idx == self.eos_token_idx:
                        break
                    else:
                        generate_tokens.append(idx2token[token_idx])
                        input_seq = torch.LongTensor([[token_idx]]).to(self.device)
                elif self.strategy == 'beam_search':
                    if hypothesis.stop():
                        break
            if self.strategy == 'beam_search':
                generate_tokens = hypothesis.generate()
            generate_corpus.append(generate_tokens)
    return generate_corpus
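# For a bidirectional PyTorch RNN, the final hidden states are stacked as
# [layer0-fwd, layer0-bwd, layer1-fwd, layer1-bwd, ...] along dim 0, so the slice
# encoder_states[::2] above keeps the forward-direction state of every layer,
# matching a unidirectional decoder. A quick shape check (sizes are illustrative):
import torch
import torch.nn as nn

gru = nn.GRU(input_size=8, hidden_size=16, num_layers=2, bidirectional=True, batch_first=True)
_, h_n = gru(torch.randn(4, 10, 8))   # h_n: [num_layers * 2, batch, hidden] = [4, 4, 16]
forward_states = h_n[::2]             # [num_layers, batch, hidden] = [2, 4, 16]
assert forward_states.shape == (2, 4, 16)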
def generate(self, eval_dataloader):
    generate_corpus = []
    idx2token = eval_dataloader.target_idx2token
    for batch_data in eval_dataloader:
        source_text = batch_data['source_idx']
        source_embeddings = self.source_token_embedder(source_text) + \
            self.position_embedder(source_text).to(self.device)
        source_padding_mask = torch.eq(source_text, self.padding_token_idx).to(self.device)
        encoder_outputs = self.encoder(
            source_embeddings, self_padding_mask=source_padding_mask, output_all_encoded_layers=False
        )
        for bid in range(source_text.size(0)):
            encoder_output = encoder_outputs[bid, :, :].unsqueeze(0)
            encoder_mask = source_padding_mask[bid, :].unsqueeze(0)
            generate_tokens = []
            prev_token_ids = [self.sos_token_idx]
            input_seq = torch.LongTensor([prev_token_ids]).to(self.device)
            if self.decoding_strategy == 'beam_search':
                hypothesis = Beam_Search_Hypothesis(
                    self.beam_size, self.sos_token_idx, self.eos_token_idx, self.device, idx2token
                )
            for gen_idx in range(self.max_target_length):
                self_attn_mask = self.self_attn_mask(input_seq.size(-1)).bool().to(self.device)
                decoder_input = self.target_token_embedder(input_seq) + \
                    self.position_embedder(input_seq).to(self.device)
                decoder_outputs = self.decoder(
                    decoder_input,
                    self_attn_mask=self_attn_mask,
                    external_states=encoder_output,
                    external_padding_mask=encoder_mask
                )
                token_logits = self.vocab_linear(decoder_outputs[:, -1, :].unsqueeze(1))
                if self.decoding_strategy == 'topk_sampling':
                    token_idx = topk_sampling(token_logits).item()
                elif self.decoding_strategy == 'greedy_search':
                    token_idx = greedy_search(token_logits).item()
                elif self.decoding_strategy == 'beam_search':
                    input_seq, encoder_output, encoder_mask = hypothesis.step(
                        gen_idx, token_logits, encoder_output=encoder_output,
                        encoder_mask=encoder_mask, input_type='whole'
                    )
                if self.decoding_strategy in ['topk_sampling', 'greedy_search']:
                    if token_idx == self.eos_token_idx:
                        break
                    else:
                        generate_tokens.append(idx2token[token_idx])
                        prev_token_ids.append(token_idx)
                        input_seq = torch.LongTensor([prev_token_ids]).to(self.device)
                elif self.decoding_strategy == 'beam_search':
                    if hypothesis.stop():
                        break
            if self.decoding_strategy == 'beam_search':
                generate_tokens = hypothesis.generate()
            generate_corpus.append(generate_tokens)
    return generate_corpus
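# self.self_attn_mask above is assumed to build a causal (subsequent-position) mask
# so that position i can only attend to positions <= i. A common way to construct
# such a mask; this is an assumption about the module's behavior, not its code:
import torch

def causal_mask(seq_len):
    # True above the diagonal marks the future positions to be masked out
    return torch.triu(torch.ones(seq_len, seq_len), diagonal=1)

# causal_mask(4).bool() ->
# [[False,  True,  True,  True],
#  [False, False,  True,  True],
#  [False, False, False,  True],
#  [False, False, False, False]]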
def generate(self, batch_data, eval_data):
    generate_corpus = []
    idx2token = eval_data.idx2token
    utt_states, context_states = self.encode(batch_data)  # [b, t, nd * h], [nl, b, h]
    source_length = batch_data['source_length']  # [b, t]
    utt_masks = torch.ne(source_length, 0)  # [b, t]
    for bid in range(utt_states.size(0)):
        encoder_states = utt_states[bid].unsqueeze(0)  # [1, t, nd * h]
        decoder_states = context_states[:, bid, :].unsqueeze(1)  # [nl, 1, h]
        context_state = decoder_states[-1].unsqueeze(0)  # [1, 1, h]
        encoder_masks = utt_masks[bid].unsqueeze(0)  # [1, t]
        generate_tokens = []
        input_seq = torch.LongTensor([[self.sos_token_idx]]).to(self.device)
        if self.strategy == 'beam_search':
            hypothesis = Beam_Search_Hypothesis(
                self.beam_size, self.sos_token_idx, self.eos_token_idx, self.device, idx2token
            )
        for gen_idx in range(self.target_max_length):
            input_embedding = self.token_embedder(input_seq)  # [beam, 1, e]
            # condition each decoding step on the dialogue context state
            decoder_input = torch.cat(
                (input_embedding, context_state.repeat(input_embedding.size(0), 1, 1)), dim=-1
            )  # [beam, 1, e + h]
            if self.attention_type is not None:
                decoder_outputs, decoder_states, _ = self.decoder(
                    decoder_input, decoder_states, encoder_states, encoder_masks
                )
            else:
                decoder_outputs, decoder_states = self.decoder(decoder_input, decoder_states)
            token_logits = self.vocab_linear(decoder_outputs)
            if self.strategy == 'topk_sampling':
                token_idx = topk_sampling(token_logits).item()
            elif self.strategy == 'greedy_search':
                token_idx = greedy_search(token_logits).item()
            elif self.strategy == 'beam_search':
                input_seq, decoder_states, encoder_states, encoder_masks = hypothesis.step(
                    gen_idx, token_logits, decoder_states, encoder_states, encoder_masks
                )
            if self.strategy in ['topk_sampling', 'greedy_search']:
                if token_idx == self.eos_token_idx:
                    break
                else:
                    generate_tokens.append(idx2token[token_idx])
                    input_seq = torch.LongTensor([[token_idx]]).to(self.device)
            elif self.strategy == 'beam_search':
                if hypothesis.stop():
                    break
        if self.strategy == 'beam_search':
            generate_tokens = hypothesis.generate()
        generate_corpus.append(generate_tokens)
    return generate_corpus
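# Each decoding step above conditions on the dialogue context by concatenating the
# context state onto the token embedding along the feature dimension. A minimal
# shape sketch of that concatenation (sizes are illustrative):
import torch

embedding = torch.randn(1, 1, 128)      # [beam, 1, embedding_size]
context_state = torch.randn(1, 1, 256)  # [1, 1, hidden_size]
decoder_input = torch.cat((embedding, context_state.repeat(embedding.size(0), 1, 1)), dim=-1)
assert decoder_input.shape == (1, 1, 128 + 256)  # [beam, 1, e + h]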
def generate(self, batch_data, eval_data):
    generate_corpus = []
    idx2token = eval_data.dataset.target_idx2token
    source_idx = batch_data['source_idx']
    source_entity = batch_data['source_entity']
    target_dict = batch_data['target_dict']
    self.batch_size = source_idx.size(0)
    entity_embeddings, source_embeddings, root_embeddings, entity_len = self.encoder(batch_data)
    root_embeddings = root_embeddings.unsqueeze(0)
    c = root_embeddings.clone().detach()
    encoder_title_masks = torch.eq(source_idx, self.padding_token_idx).to(self.device)
    encoder_entity_masks = [
        torch.cat([torch.zeros(i), torch.ones(entity_embeddings.size(1) - i)]).unsqueeze(0)
        for i in entity_len
    ]
    encoder_entity_masks = torch.cat(encoder_entity_masks, dim=0).bool().to(self.device)
    for bid in range(self.batch_size):
        decoder_states = (root_embeddings[:, bid, :].unsqueeze(1), c[:, bid, :].unsqueeze(1))
        entity_embeddings_ = entity_embeddings[bid, :, :].unsqueeze(0)
        source_embeddings_ = source_embeddings[bid, :, :].unsqueeze(0)
        encoder_title_masks_ = encoder_title_masks[bid, :].unsqueeze(0)
        encoder_entity_masks_ = encoder_entity_masks[bid, :].unsqueeze(0)
        generate_tokens = []
        input_seq = torch.LongTensor([[self.sos_token_idx]]).to(self.device)
        if self.strategy == 'beam_search':
            hypothesis = Beam_Search_Hypothesis(
                self.beam_size, self.sos_token_idx, self.eos_token_idx, self.device, idx2token
            )
        for gen_idx in range(self.target_max_length):
            decoder_input = self.target_token_embedder(input_seq)
            decoder_outputs, decoder_states = self.decoder(
                decoder_input, decoder_states, entity_embeddings_, source_embeddings_,
                encoder_entity_masks_, encoder_title_masks_
            )
            # mix the generation and copy distributions in log space
            copy_prob = torch.sigmoid(self.copy_linear(decoder_outputs))
            EPSI = torch.tensor(1e-6)
            pred_vocab = torch.log(copy_prob + EPSI) + torch.log_softmax(self.vocab_linear(decoder_outputs), -1)
            douts = self.d_linear(decoder_outputs)
            attn_weight = self.copy_attn(douts, entity_embeddings_, entity_embeddings_, encoder_entity_masks_)[2]
            pred_copy = torch.log((1. - copy_prob) + EPSI) + attn_weight.squeeze(1)
            token_logits = torch.cat([pred_vocab, pred_copy], -1)
            # we only support greedy search for this task; note that the beam-search
            # branch below references an `encoder_output` that is never defined here
            if self.strategy == 'topk_sampling':
                token_idx = topk_sampling(token_logits).item()
            elif self.strategy == 'greedy_search':
                token_idx = greedy_search(token_logits).item()
            elif self.strategy == 'beam_search':
                input_seq, decoder_states, encoder_output = \
                    hypothesis.step(gen_idx, token_logits, decoder_states, encoder_output)
            if self.strategy in ['topk_sampling', 'greedy_search']:
                if token_idx == self.eos_token_idx:
                    break
                elif token_idx >= self.target_vocab_size:
                    # the index falls in the copy region: emit the copied entity's tokens
                    entity_tokens = source_entity[bid][token_idx - self.target_vocab_size]
                    entity_tokens = entity_tokens.split(" ")
                    generate_tokens.extend(entity_tokens)
                    # feed the canonical token of the copied entity as the next input
                    next_token = self.target_token2idx[target_dict[bid][token_idx - self.target_vocab_size]]
                    input_seq = torch.LongTensor([[next_token]]).to(self.device)
                else:
                    generate_tokens.append(idx2token[token_idx])
                    input_seq = torch.LongTensor([[token_idx]]).to(self.device)
            elif self.strategy == 'beam_search':
                if hypothesis.stop():
                    break
        if self.strategy == 'beam_search':
            generate_tokens = hypothesis.generate()
        generate_corpus.append(generate_tokens)
    return generate_corpus
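# The concatenation token_logits = torch.cat([pred_vocab, pred_copy], -1) above builds
# an extended vocabulary: indices < target_vocab_size generate from the target
# vocabulary, while indices >= target_vocab_size copy the (idx - target_vocab_size)-th
# source entity. A minimal worked example of that index arithmetic (values illustrative):
target_vocab_size = 30000
source_entity_example = ["new york city", "empire state building"]

token_idx = 30001  # falls in the copy region of the extended vocabulary
entity = source_entity_example[token_idx - target_vocab_size]  # "empire state building"
assert entity.split(" ") == ["empire", "state", "building"]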