def greedy_decode(self, input_tensor_with_lengths):
    """
    :param input_tensor_with_lengths: tuple(input tensor of size [max_seq_length, batch_size],
        actual sequence lengths of size [batch_size])
    """
    input_tensor, input_lengths = input_tensor_with_lengths
    input_tensor = input_tensor.transpose(0, 1)
    batch_size, input_sequence_length = input_tensor.size()
    target_length = min(int(cfg.maximum_decoding_length * 1.1), input_sequence_length * 2)
    memory, src_mask = self.encode(input_tensor_with_lengths)
    # start every sequence in the batch with the <bos> token
    ys = torch.ones(batch_size, 1).fill_(self.TGT.vocab.stoi[cfg.bos_token]).type_as(input_tensor.data)
    for i in range(target_length - 1):
        output_tensor, output_mask = ys.clone().detach(), \
            subsequent_mask(ys.size(1)).type_as(input_tensor.data).clone().detach()
        # run the partial hypothesis through the decoder stack
        x = self.tgt_embed(output_tensor)
        for layer in self.dec_layers:
            x = layer(x, memory, src_mask, output_mask)
        out = self.dec_norm(x)
        prob = self.generator(out[:, -1])
        # greedily pick the most probable next token and append it to the running output
        _, next_word = torch.max(prob, dim=1)
        ys = torch.cat([ys, next_word.view(batch_size, 1)], dim=1)
    max_attention_indices = None
    return ys.transpose(0, 1), max_attention_indices, torch.zeros(1, device=device), 1, 1
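# A minimal usage sketch for greedy_decode (illustrative only; `model`, `src_batch`, and
# `src_lengths` are assumed names, not part of this module): given a trained instance of this
# model and a source batch shaped [max_seq_length, batch_size] with its length vector,
#
#     with torch.no_grad():
#         tokens, _, _, _, _ = model.greedy_decode((src_batch, src_lengths))
#
# `tokens` comes back as a [decoded_length, batch_size] tensor of target vocabulary ids whose
# first row is the <bos> token.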
def generate_tgt_mask(self, output_tensor):
    """ Create a mask to hide padding and future words """
    output_tensor = output_tensor[:, :-1]
    tgt_mask = (output_tensor != self.TGT.vocab.stoi[cfg.pad_token]).unsqueeze(-2)
    tgt_mask = tgt_mask & subsequent_mask(output_tensor.size(-1)).type_as(tgt_mask.data).clone().detach()
    return tgt_mask
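# A small worked example of generate_tgt_mask (illustrative; the concrete ids 2 for <bos>,
# 5 for an ordinary word, 3 for <eos>, and 1 for <pad> are assumptions, and the mask is shown
# as 0/1): for the single target sequence [[2, 5, 3, 1, 1]] the method drops the last column,
# leaving [2, 5, 3, 1], and combines the padding mask with subsequent_mask(4), so it returns
#
#     [[[1, 0, 0, 0],
#       [1, 1, 0, 0],
#       [1, 1, 1, 0],
#       [1, 1, 1, 0]]]
#
# i.e. every decoding position may attend only to itself and to earlier, non-padded positions.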
def extract_output_probabilities(self, ys, memory, src_mask, input_tensor):
    """
    Runs the partial hypothesis `ys` through the decoder stack and returns the generator
    output (the distribution over the target vocabulary) for its last position.
    """
    output_tensor, output_mask = ys.clone().detach(), \
        subsequent_mask(ys.size(1)).type_as(input_tensor.data).clone().detach()
    x = self.tgt_embed(output_tensor)
    for layer in self.dec_layers:
        x = layer(x, memory, src_mask, output_mask)
    out = self.dec_norm(x)
    prob = self.generator(out[:, -1])
    return prob
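# Note (a possible simplification, not applied here): the decoder-forward block inlined inside
# the decoding loops of greedy_decode and beam_search_decode computes exactly what this helper
# returns, so those loops could instead call
#
#     prob = self.extract_output_probabilities(ys, memory, src_mask, input_tensor)
#
# and avoid repeating the embedding / decoder-layer / normalization / generator sequence.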
def beam_search_decode(self, input_tensor_with_lengths, beam_size=1):
    """
    :param input_tensor_with_lengths: tuple(input tensor of size [max_seq_length, batch_size],
        actual sequence lengths of size [batch_size])
    :param beam_size: number of hypothesis expansions kept at each step during inference
    """
    input_tensor, input_lengths = input_tensor_with_lengths
    input_tensor = input_tensor.transpose(0, 1)
    batch_size, input_sequence_length = input_tensor.size()
    target_length = min(int(cfg.maximum_decoding_length * 1.1), input_sequence_length * 2)
    memory, src_mask = self.encode(input_tensor_with_lengths)
    # ################################# INITIALIZATION OF DECODING PARAMETERS #######################################
    init_ys = torch.ones(batch_size, 1).fill_(self.TGT.vocab.stoi[cfg.bos_token]).type_as(input_tensor.data)
    # each node is a tuple of (partial hypotheses, accumulated scores, <eos>-already-predicted flags)
    nodes = [(init_ys, torch.zeros(batch_size, device=device), torch.zeros(batch_size, device=device).byte())]
    final_results = []
    for i in range(target_length - 1):
        k = beam_size - len(final_results)
        if k < 1:
            break
        all_predictions = torch.zeros(batch_size, len(nodes) * k, device=device).long()
        all_lm_scores = torch.zeros(batch_size, len(nodes) * k, device=device).float()
        # iterating over all the available hypotheses to expand the beams
        for n_id, (ys, lm_scores, eos_predicted) in enumerate(nodes):
            output_tensor, output_mask = ys.clone().detach(), \
                subsequent_mask(ys.size(1)).type_as(input_tensor.data).clone().detach()
            x = self.tgt_embed(output_tensor)
            for layer in self.dec_layers:
                x = layer(x, memory, src_mask, output_mask)
            out = self.dec_norm(x)
            prob = self.generator(out[:, -1])
            k_values, k_indices = torch.topk(prob, dim=1, k=k)
            for beam_index in range(k):
                overall_index = n_id * k + beam_index
                all_predictions[:, overall_index] = k_indices[:, beam_index]
                all_lm_scores[:, overall_index] = lm_scores + k_values[:, beam_index]
        k_values, k_indices = torch.topk(all_lm_scores, dim=1, k=k)
        temp_next_nodes = []
        # creating the next k hypotheses
        for beam_index in range(k):
            # integer division recovers the index of the parent hypothesis each candidate was expanded from
            node_ids = k_indices[:, beam_index] // k
            node_ids = list(node_ids.cpu().numpy())  # list of size batch_size
            pred_ids = list(k_indices[:, beam_index].cpu().numpy())
            lm_score = k_values[:, beam_index]
            next_word = torch.zeros((batch_size,), device=device).long()
            for b in range(batch_size):
                next_word[b] = all_predictions[b, pred_ids[b]]
            eos_p = torch.cat([nodes[n_id][2][b_id].unsqueeze(0) for b_id, n_id in enumerate(node_ids)], dim=0)
            eos_predicted = torch.max(eos_p, (next_word == self.TGT.vocab.stoi[cfg.eos_token]))
            ys = torch.cat([nodes[n_id][0][b_id].unsqueeze(0) for b_id, n_id in enumerate(node_ids)], dim=0)
            ys = torch.cat([ys, next_word.view(batch_size, 1)], dim=1)
            next_step_node = (ys, lm_score, eos_predicted)
            if sum(eos_predicted.int()) == batch_size:
                final_results.append(next_step_node)
            else:
                temp_next_nodes.append(next_step_node)
        del nodes[:]
        nodes = temp_next_nodes
    if not len(final_results):
        for node in nodes:
            final_results.append(node)
    # creating the final result based on the best scoring hypotheses
    result = torch.zeros(target_length, batch_size, device=device)
    lp = lambda l: ((5 + l) ** self.beam_search_length_norm_factor) / (5 + 1) ** self.beam_search_length_norm_factor
    for b_ind in range(batch_size):
        best_score = float('-inf')
        best_tokens = None
        for node in final_results:
            tokens = node[0][b_ind]
            eos_ind = (tokens == self.TGT.vocab.stoi[cfg.eos_token]).nonzero().view(-1)
            if eos_ind.size(0):
                tsize = eos_ind[0].item()
            else:
                tsize = tokens.size(0)
            # length normalization based on Google's NMT system paper [https://arxiv.org/pdf/1609.08144.pdf];
            # since coverage is not being tracked here, the coverage penalty is not considered in this formula
            lms = node[1][b_ind].item() / lp(tsize)
            if lms > best_score:
                best_score = lms
                best_tokens = tokens
        result[:best_tokens[1:].size(0), b_ind] = best_tokens[1:]
    max_attention_indices = None
    return result, max_attention_indices, torch.zeros(1, device=device), 1, 1
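# A worked example of the lp length-normalization term used in beam_search_decode above
# (values rounded; beam_search_length_norm_factor = 0.6 is only an assumed setting, following
# the recommendation in the Google NMT paper):
#
#     lp(5)  = (5 + 5) ** 0.6 / (5 + 1) ** 0.6  ≈ 3.981 / 2.930 ≈ 1.36
#     lp(20) = (5 + 20) ** 0.6 / (5 + 1) ** 0.6 ≈ 6.899 / 2.930 ≈ 2.35
#
# Longer finished hypotheses are divided by a larger factor, so the search does not
# systematically prefer short outputs just because they accumulate fewer score terms.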