import torch

# `topk_sampling`, `greedy_search`, and `Beam_Search_Hypothesis` are assumed to
# be imported from the library's decoding-strategy module; every `generate`
# method in this file relies on them.


def generate(self, eval_dataloader):
    generate_corpus = []
    idx2token = eval_dataloader.idx2token
    for batch_data in eval_dataloader:
        source_idx = batch_data['attribute_idx']
        self.batch_size = source_idx.size(0)
        encoder_outputs, encoder_states = self.encoder(source_idx)
        for bid in range(self.batch_size):
            # The attribute encoder supplies the hidden state; the LSTM cell
            # state starts from zeros.
            c = torch.zeros(self.num_dec_layers, 1, self.hidden_size).to(self.device)
            decoder_states = (encoder_states[:, bid, :].unsqueeze(1), c)
            encoder_output = encoder_outputs[bid, :, :].unsqueeze(0)
            generate_tokens = []
            input_seq = torch.LongTensor([[self.sos_token_idx]]).to(self.device)
            if self.strategy == 'beam_search':
                hypothesis = Beam_Search_Hypothesis(
                    self.beam_size, self.sos_token_idx, self.eos_token_idx, self.device, idx2token
                )
            for gen_idx in range(self.max_target_length):
                decoder_input = self.target_token_embedder(input_seq)
                decoder_outputs, decoder_states, _ = self.decoder(decoder_input, decoder_states, encoder_output)
                token_logits = self.vocab_linear(decoder_outputs)
                if self.strategy == 'topk_sampling':
                    token_idx = topk_sampling(token_logits).item()
                elif self.strategy == 'greedy_search':
                    token_idx = greedy_search(token_logits).item()
                elif self.strategy == 'beam_search':
                    input_seq, decoder_states, encoder_output = \
                        hypothesis.step(gen_idx, token_logits, decoder_states, encoder_output)
                if self.strategy in ['topk_sampling', 'greedy_search']:
                    if token_idx == self.eos_token_idx:
                        break
                    generate_tokens.append(idx2token[token_idx])
                    input_seq = torch.LongTensor([[token_idx]]).to(self.device)
                elif self.strategy == 'beam_search':
                    if hypothesis.stop():
                        break
            if self.strategy == 'beam_search':
                generate_tokens = hypothesis.generate()
            generate_corpus.append(generate_tokens)
    return generate_corpus
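# The sampling helpers used above are not defined in this file. Below is a
# minimal sketch of what they are assumed to do, inferred only from the call
# sites (last-step logits of shape [1, 1, vocab_size] in, a single
# token-index tensor out); the library's actual implementations may differ.

def greedy_search(token_logits):
    # Pick the single highest-scoring token from the last decoding step.
    return token_logits.view(-1).argmax(dim=-1)


def topk_sampling(token_logits, top_k=50):
    # Restrict to the k most likely tokens, renormalize, and sample one.
    logits = token_logits.view(-1)
    top_values, top_indices = torch.topk(logits, k=min(top_k, logits.size(-1)))
    probs = torch.softmax(top_values, dim=-1)
    sampled = torch.multinomial(probs, num_samples=1)
    return top_indices[sampled]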
def generate(self, batch_data, eval_data):
    generate_corpus = []
    idx2token = eval_data.idx2token
    batch_size = batch_data['attribute_idx'].size(0)
    attr_embeddings, h_c_1D = self.encoder(batch_data['attribute_idx'])
    # Broadcast the single-layer attribute context to every decoder layer.
    h_c = h_c_1D.repeat(self.num_dec_layers, 1, 1)
    for bid in range(batch_size):
        hidden_states = h_c[:, bid, :].unsqueeze(1).contiguous()
        generate_tokens = []
        input_seq = torch.LongTensor([[self.sos_token_idx]]).to(self.device)
        if self.decoding_strategy == 'beam_search':
            hypothesis = Beam_Search_Hypothesis(
                self.beam_size, self.sos_token_idx, self.eos_token_idx, self.device, idx2token
            )
        for gen_idx in range(self.max_length):
            decoder_input = self.token_embedder(input_seq)
            outputs, hidden_states = self.decoder(decoder_input, hidden_states)
            if self.is_gated:
                # Gate the attribute context back into the decoder output.
                m_t = torch.sigmoid(self.gate_linear(outputs))
                outputs = outputs + m_t * h_c_1D[bid]
            token_logits = self.vocab_linear(outputs)
            if self.decoding_strategy == 'topk_sampling':
                token_idx = topk_sampling(token_logits).item()
            elif self.decoding_strategy == 'greedy_search':
                token_idx = greedy_search(token_logits).item()
            elif self.decoding_strategy == 'beam_search':
                input_seq, hidden_states = \
                    hypothesis.step(gen_idx, token_logits, hidden_states)
            if self.decoding_strategy in ['topk_sampling', 'greedy_search']:
                if token_idx == self.eos_token_idx:
                    break
                generate_tokens.append(idx2token[token_idx])
                input_seq = torch.LongTensor([[token_idx]]).to(self.device)
            elif self.decoding_strategy == 'beam_search':
                if hypothesis.stop():
                    break
        if self.decoding_strategy == 'beam_search':
            generate_tokens = hypothesis.generate()
        generate_corpus.append(generate_tokens)
    return generate_corpus
def generate(self, eval_dataloader):
    generate_corpus = []
    idx2token = eval_dataloader.target_idx2token
    for batch_data in eval_dataloader:
        source_text = batch_data['source_idx']
        source_length = batch_data['source_length']
        source_embeddings = self.source_token_embedder(source_text)
        encoder_outputs, encoder_states = self.encoder(source_embeddings, source_length)

        if self.bidirectional:
            # Sum the forward and backward encoder directions, and keep one
            # direction of the final states per layer.
            encoder_outputs = encoder_outputs[:, :, self.hidden_size:] \
                + encoder_outputs[:, :, :self.hidden_size]
            if self.rnn_type == 'lstm':
                encoder_states = (encoder_states[0][::2], encoder_states[1][::2])
            else:
                encoder_states = encoder_states[::2]

        encoder_masks = torch.ne(source_text, self.padding_token_idx)
        for bid in range(source_text.size(0)):
            decoder_states = encoder_states[:, bid, :].unsqueeze(1)
            encoder_output = encoder_outputs[bid, :, :].unsqueeze(0)
            encoder_mask = encoder_masks[bid, :].unsqueeze(0)
            generate_tokens = []
            input_seq = torch.LongTensor([[self.sos_token_idx]]).to(self.device)
            if self.strategy == 'beam_search':
                hypothesis = Beam_Search_Hypothesis(
                    self.beam_size, self.sos_token_idx, self.eos_token_idx, self.device, idx2token
                )
            for gen_idx in range(self.max_target_length):
                decoder_input = self.target_token_embedder(input_seq)
                if self.attention_type is not None:
                    decoder_outputs, decoder_states, _ = self.decoder(
                        decoder_input, decoder_states, encoder_output, encoder_mask
                    )
                else:
                    decoder_outputs, decoder_states = self.decoder(decoder_input, decoder_states)
                token_logits = self.vocab_linear(decoder_outputs)
                if self.strategy == 'topk_sampling':
                    token_idx = topk_sampling(token_logits).item()
                elif self.strategy == 'greedy_search':
                    token_idx = greedy_search(token_logits).item()
                elif self.strategy == 'beam_search':
                    # With attention, the encoder output and mask must also be
                    # re-gathered to track the surviving beams.
                    if self.attention_type is not None:
                        input_seq, decoder_states, encoder_output, encoder_mask = \
                            hypothesis.step(gen_idx, token_logits, decoder_states, encoder_output, encoder_mask)
                    else:
                        input_seq, decoder_states = hypothesis.step(gen_idx, token_logits, decoder_states)
                if self.strategy in ['topk_sampling', 'greedy_search']:
                    if token_idx == self.eos_token_idx:
                        break
                    generate_tokens.append(idx2token[token_idx])
                    input_seq = torch.LongTensor([[token_idx]]).to(self.device)
                elif self.strategy == 'beam_search':
                    if hypothesis.stop():
                        break
            if self.strategy == 'beam_search':
                generate_tokens = hypothesis.generate()
            generate_corpus.append(generate_tokens)
    return generate_corpus
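# A hedged usage sketch for the sequence-to-sequence `generate` above:
# `model` and `eval_dataloader` are hypothetical stand-ins for a trained
# model instance and an evaluation dataloader yielding batches with
# 'source_idx' and 'source_length'; the surrounding evaluation harness is
# not shown in this file.
model.eval()
with torch.no_grad():
    corpus = model.generate(eval_dataloader)
for tokens in corpus[:3]:
    # Each entry is a list of decoded target tokens.
    print(' '.join(tokens))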
def generate(self, eval_dataloader):
    generate_corpus = []
    idx2token = eval_dataloader.target_idx2token
    for batch_data in eval_dataloader:
        source_text = batch_data['source_idx']
        source_embeddings = self.source_token_embedder(source_text) \
            + self.position_embedder(source_text).to(self.device)
        source_padding_mask = torch.eq(source_text, self.padding_token_idx).to(self.device)
        encoder_outputs = self.encoder(
            source_embeddings, self_padding_mask=source_padding_mask, output_all_encoded_layers=False
        )
        for bid in range(source_text.size(0)):
            encoder_output = encoder_outputs[bid, :, :].unsqueeze(0)
            encoder_mask = source_padding_mask[bid, :].unsqueeze(0)
            generate_tokens = []
            # The Transformer decoder re-consumes the whole generated prefix
            # at every step, so the full token history is kept.
            prev_token_ids = [self.sos_token_idx]
            input_seq = torch.LongTensor([prev_token_ids]).to(self.device)
            if self.decoding_strategy == 'beam_search':
                hypothesis = Beam_Search_Hypothesis(
                    self.beam_size, self.sos_token_idx, self.eos_token_idx, self.device, idx2token
                )
            for gen_idx in range(self.max_target_length):
                self_attn_mask = self.self_attn_mask(input_seq.size(-1)).bool().to(self.device)
                decoder_input = self.target_token_embedder(input_seq) \
                    + self.position_embedder(input_seq).to(self.device)
                decoder_outputs = self.decoder(
                    decoder_input,
                    self_attn_mask=self_attn_mask,
                    external_states=encoder_output,
                    external_padding_mask=encoder_mask
                )
                # Only the last position's logits determine the next token.
                token_logits = self.vocab_linear(decoder_outputs[:, -1, :].unsqueeze(1))
                if self.decoding_strategy == 'topk_sampling':
                    token_idx = topk_sampling(token_logits).item()
                elif self.decoding_strategy == 'greedy_search':
                    token_idx = greedy_search(token_logits).item()
                elif self.decoding_strategy == 'beam_search':
                    input_seq, encoder_output, encoder_mask = \
                        hypothesis.step(
                            gen_idx, token_logits, encoder_output=encoder_output,
                            encoder_mask=encoder_mask, input_type='whole'
                        )
                if self.decoding_strategy in ['topk_sampling', 'greedy_search']:
                    if token_idx == self.eos_token_idx:
                        break
                    generate_tokens.append(idx2token[token_idx])
                    prev_token_ids.append(token_idx)
                    input_seq = torch.LongTensor([prev_token_ids]).to(self.device)
                elif self.decoding_strategy == 'beam_search':
                    if hypothesis.stop():
                        break
            if self.decoding_strategy == 'beam_search':
                generate_tokens = hypothesis.generate()
            generate_corpus.append(generate_tokens)
    return generate_corpus
def generate(self, batch_data, eval_data):
    generate_corpus = []
    idx2token = eval_data.idx2token
    utt_states, context_states = self.encode(batch_data)  # [b, t, nd * h], [nl, b, h]
    source_length = batch_data['source_length']  # [b, t]
    utt_masks = torch.ne(source_length, 0)  # [b, t]
    for bid in range(utt_states.size(0)):
        encoder_states = utt_states[bid].unsqueeze(0)  # [1, t, nd * h]
        decoder_states = context_states[:, bid, :].unsqueeze(1)  # [nl, 1, h]
        context_state = decoder_states[-1].unsqueeze(0)  # [1, 1, h]
        encoder_masks = utt_masks[bid].unsqueeze(0)  # [1, t]
        generate_tokens = []
        input_seq = torch.LongTensor([[self.sos_token_idx]]).to(self.device)
        if self.strategy == 'beam_search':
            hypothesis = Beam_Search_Hypothesis(
                self.beam_size, self.sos_token_idx, self.eos_token_idx, self.device, idx2token
            )
        for gen_idx in range(self.target_max_length):
            input_embedding = self.token_embedder(input_seq)  # [beam, 1, e]
            # Concatenate the dialogue context state onto every decoder input.
            decoder_input = torch.cat(
                (input_embedding, context_state.repeat(input_embedding.size(0), 1, 1)), dim=-1
            )  # [beam, 1, e + h]
            if self.attention_type is not None:
                decoder_outputs, decoder_states, _ = self.decoder(
                    decoder_input, decoder_states, encoder_states, encoder_masks
                )
            else:
                decoder_outputs, decoder_states = self.decoder(decoder_input, decoder_states)
            token_logits = self.vocab_linear(decoder_outputs)
            if self.strategy == 'topk_sampling':
                token_idx = topk_sampling(token_logits).item()
            elif self.strategy == 'greedy_search':
                token_idx = greedy_search(token_logits).item()
            elif self.strategy == 'beam_search':
                input_seq, decoder_states, encoder_states, encoder_masks = hypothesis.step(
                    gen_idx, token_logits, decoder_states, encoder_states, encoder_masks
                )
            if self.strategy in ['topk_sampling', 'greedy_search']:
                if token_idx == self.eos_token_idx:
                    break
                generate_tokens.append(idx2token[token_idx])
                input_seq = torch.LongTensor([[token_idx]]).to(self.device)
            elif self.strategy == 'beam_search':
                if hypothesis.stop():
                    break
        if self.strategy == 'beam_search':
            generate_tokens = hypothesis.generate()
        generate_corpus.append(generate_tokens)
    return generate_corpus
def generate(self, batch_data, eval_data):
    generate_corpus = []
    idx2token = eval_data.dataset.target_idx2token
    source_idx = batch_data['source_idx']
    source_entity = batch_data['source_entity']
    target_dict = batch_data['target_dict']
    self.batch_size = source_idx.size(0)
    entity_embeddings, source_embeddings, root_embeddings, entity_len = self.encoder(batch_data)
    root_embeddings = root_embeddings.unsqueeze(0)
    # Initialize the LSTM cell state from the root embedding.
    c = root_embeddings.clone().detach()
    encoder_title_masks = torch.eq(source_idx, self.padding_token_idx).to(self.device)
    # Mask out padded entity slots beyond each example's true entity count.
    encoder_entity_masks = [
        torch.cat([torch.zeros(i), torch.ones(entity_embeddings.size(1) - i)]).unsqueeze(0)
        for i in entity_len
    ]
    encoder_entity_masks = torch.cat(encoder_entity_masks, dim=0).bool().to(self.device)
    for bid in range(self.batch_size):
        decoder_states = (root_embeddings[:, bid, :].unsqueeze(1), c[:, bid, :].unsqueeze(1))
        entity_embeddings_ = entity_embeddings[bid, :, :].unsqueeze(0)
        source_embeddings_ = source_embeddings[bid, :, :].unsqueeze(0)
        encoder_title_masks_ = encoder_title_masks[bid, :].unsqueeze(0)
        encoder_entity_masks_ = encoder_entity_masks[bid, :].unsqueeze(0)
        generate_tokens = []
        input_seq = torch.LongTensor([[self.sos_token_idx]]).to(self.device)
        for gen_idx in range(self.target_max_length):
            decoder_input = self.target_token_embedder(input_seq)
            decoder_outputs, decoder_states = self.decoder(
                decoder_input, decoder_states, entity_embeddings_, source_embeddings_,
                encoder_entity_masks_, encoder_title_masks_
            )
            # Pointer-generator style mixture: `copy_prob` weighs generating
            # from the vocabulary against copying an input entity.
            copy_prob = torch.sigmoid(self.copy_linear(decoder_outputs))
            EPSI = torch.tensor(1e-6)
            pred_vocab = torch.log(copy_prob + EPSI) \
                + torch.log_softmax(self.vocab_linear(decoder_outputs), -1)
            douts = self.d_linear(decoder_outputs)
            attn_weight = self.copy_attn(douts, entity_embeddings_, entity_embeddings_, encoder_entity_masks_)[2]
            pred_copy = torch.log((1. - copy_prob) + EPSI) + attn_weight.squeeze(1)
            # Indices past the vocabulary size point into the entity list.
            token_logits = torch.cat([pred_vocab, pred_copy], -1)
            # Only greedy search is supported for this task; the original
            # beam-search branch referenced an undefined `encoder_output`,
            # so that path now fails explicitly instead.
            if self.strategy == 'topk_sampling':
                token_idx = topk_sampling(token_logits).item()
            elif self.strategy == 'greedy_search':
                token_idx = greedy_search(token_logits).item()
            elif self.strategy == 'beam_search':
                raise NotImplementedError("beam search is not supported for this task")
            if token_idx == self.eos_token_idx:
                break
            elif token_idx >= self.target_vocab_size:
                # Copied entity: emit its (possibly multi-word) surface form,
                # then feed back the entity's vocabulary token as next input.
                entity_tokens = source_entity[bid][token_idx - self.target_vocab_size]
                entity_tokens = entity_tokens.split(" ")
                generate_tokens.extend(entity_tokens)
                next_token = self.target_token2idx[target_dict[bid][token_idx - self.target_vocab_size]]
                input_seq = torch.LongTensor([[next_token]]).to(self.device)
            else:
                generate_tokens.append(idx2token[token_idx])
                input_seq = torch.LongTensor([[token_idx]]).to(self.device)
        generate_corpus.append(generate_tokens)
    return generate_corpus
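# All of the `generate` methods above drive beam decoding through a shared
# `Beam_Search_Hypothesis` object whose implementation is not shown in this
# file. The skeleton below only records the interface implied by the call
# sites; the method bodies are placeholders, not the library's code.
class Beam_Search_Hypothesis:
    def __init__(self, beam_size, sos_token_idx, eos_token_idx, device, idx2token):
        self.beam_size = beam_size
        self.sos_token_idx = sos_token_idx
        self.eos_token_idx = eos_token_idx
        self.device = device
        self.idx2token = idx2token

    def step(self, gen_idx, token_logits, *states, **kwargs):
        # Expand the beams by `token_logits`, prune to `beam_size`, and return
        # the next decoder input followed by the passed-in states, each
        # re-gathered to match the surviving hypotheses.
        raise NotImplementedError

    def stop(self):
        # True once every beam has produced `eos_token_idx`.
        raise NotImplementedError

    def generate(self):
        # Return the token list of the highest-scoring finished hypothesis.
        raise NotImplementedError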