Example #1
 def greedy_decode(self, input_tensor_with_lengths):
     """
     :param input_tensor_with_lengths: tuple(max_seq_length * batch_size, batch_size: actual sequence lengths)
     """
     input_tensor, input_lengths = input_tensor_with_lengths
     input_tensor = input_tensor.transpose(0, 1)
     batch_size, input_sequence_length = input_tensor.size()
     target_length = min(int(cfg.maximum_decoding_length * 1.1),
                         input_sequence_length * 2)
     memory, src_mask = self.encode(input_tensor_with_lengths)
     ys = torch.ones(batch_size, 1).fill_(
         self.TGT.vocab.stoi[cfg.bos_token]).type_as(input_tensor.data)
     for i in range(target_length - 1):  # autoregressively extend ys one argmax token at a time
         output_tensor, output_mask = ys.clone().detach(), \
                                           subsequent_mask(ys.size(1)).type_as(input_tensor.data).clone().detach()
         x = self.tgt_embed(output_tensor)
         for layer in self.dec_layers:
             x = layer(x, memory, src_mask, output_mask)
         out = self.dec_norm(x)
         prob = self.generator(out[:, -1])
         _, next_word = torch.max(prob, dim=1)
         ys = torch.cat([ys, next_word.view(batch_size, 1)], dim=1)
     max_attention_indices = None
     return ys.transpose(0, 1), max_attention_indices, torch.zeros(
         1, device=device), 1, 1
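
Both the greedy and beam-search decoders on this page rely on a subsequent_mask helper that is not shown here. Below is a minimal sketch of such a helper, following the standard lower-triangular (causal) mask used in Transformer decoders; the exact name, shape convention, and dtype in this codebase are assumptions.

    import torch

    def subsequent_mask(size):
        # Assumed helper: a (1, size, size) mask where position i may only attend
        # to positions <= i; the entries that are kept are marked True.
        upper = torch.triu(torch.ones(1, size, size, dtype=torch.uint8), diagonal=1)
        return upper == 0
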
Example #2
 def generate_tgt_mask(self, output_tensor):
     """
     Create a mask to hide padding and future words
     """
     output_tensor = output_tensor[:, :-1]  # align the mask with the shifted decoder input (last token is not fed in)
     tgt_mask = (output_tensor != self.TGT.vocab.stoi[cfg.pad_token]).unsqueeze(-2)  # hide padding positions
     tgt_mask = tgt_mask & subsequent_mask(output_tensor.size(-1)).type_as(tgt_mask.data).clone().detach()  # hide future positions
     return tgt_mask
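
For context, a hypothetical sketch of how generate_tgt_mask would typically be used when preparing a teacher-forcing batch; the variable names here are illustrative and not taken from this codebase.

    tgt_mask = model.generate_tgt_mask(output_tensor)  # mask over output_tensor[:, :-1]
    decoder_input = output_tensor[:, :-1]              # tokens fed into the decoder
    gold_labels = output_tensor[:, 1:]                 # tokens the decoder is trained to predict
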
Example #3
 def extract_output_probabilities(self, ys, memory, src_mask, input_tensor):
     """
     Runs the decoder stack over the partial hypothesis ys and returns the generator
     output (next-token prediction scores) for its last position.
     """
     output_tensor, output_mask = ys.clone().detach(), subsequent_mask(ys.size(1)).type_as(input_tensor.data).clone().detach()
     x = self.tgt_embed(output_tensor)
     for layer in self.dec_layers:
         x = layer(x, memory, src_mask, output_mask)
     out = self.dec_norm(x)
     prob = self.generator(out[:, -1])
     return prob
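
A hypothetical sketch of this helper inside a single greedy decoding step, assuming ys, memory, src_mask, and input_tensor are set up as in Example #1:

    prob = model.extract_output_probabilities(ys, memory, src_mask, input_tensor)
    _, next_word = torch.max(prob, dim=1)                        # most likely next token per batch entry
    ys = torch.cat([ys, next_word.view(ys.size(0), 1)], dim=1)   # append it to the running hypothesis
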
Example #4
    def beam_search_decode(self, input_tensor_with_lengths, beam_size=1):
        """
        :param input_tensor_with_lengths: tuple(max_seq_length * batch_size, batch_size: actual sequence lengths)
        :param beam_size: number of the hypothesis expansions during inference
        """
        input_tensor, input_lengths = input_tensor_with_lengths
        input_tensor = input_tensor.transpose(0, 1)
        batch_size, input_sequence_length = input_tensor.size()
        target_length = min(int(cfg.maximum_decoding_length * 1.1),
                            input_sequence_length * 2)
        memory, src_mask = self.encode(input_tensor_with_lengths)

        # #################################INITIALIZATION OF DECODING PARAMETERS#######################################
        init_ys = torch.ones(batch_size, 1).fill_(
            self.TGT.vocab.stoi[cfg.bos_token]).type_as(input_tensor.data)
        nodes = [(init_ys, torch.zeros(batch_size, device=device),
                  torch.zeros(batch_size, device=device).byte())]
        final_results = []

        for i in range(target_length - 1):
            k = beam_size - len(final_results)
            if k < 1:
                break
            all_predictions = torch.zeros(batch_size, len(nodes) * k, device=device).long()
            all_lm_scores = torch.zeros(batch_size, len(nodes) * k, device=device).float()
            # iterating over all the available hypotheses to expand the beams
            for n_id, (ys, lm_scores, eos_predicted) in enumerate(nodes):
                output_tensor, output_mask = ys.clone().detach(), \
                                             subsequent_mask(ys.size(1)).type_as(input_tensor.data).clone().detach()
                x = self.tgt_embed(output_tensor)
                for layer in self.dec_layers:
                    x = layer(x, memory, src_mask, output_mask)
                out = self.dec_norm(x)
                prob = self.generator(out[:, -1])
                k_values, k_indices = torch.topk(prob, dim=1, k=k)
                for beam_index in range(k):
                    overall_index = n_id * k + beam_index
                    all_predictions[:, overall_index] = k_indices[:, beam_index]
                    all_lm_scores[:, overall_index] = lm_scores + k_values[:, beam_index]
            k_values, k_indices = torch.topk(all_lm_scores, dim=1, k=k)
            temp_next_nodes = []
            # creating the next k hypotheses
            for beam_index in range(k):
                node_ids = k_indices[:, beam_index] // k  # integer division recovers the source hypothesis index
                node_ids = list(node_ids.cpu().numpy())  # list of size batch_size
                pred_ids = list(k_indices[:, beam_index].cpu().numpy())
                lm_score = k_values[:, beam_index]

                next_word = torch.zeros((batch_size, ), device=device).long()
                for b in range(batch_size):
                    next_word[b] = all_predictions[b, pred_ids[b]]

                eos_p = torch.cat([nodes[n_id][2][b_id].unsqueeze(0)
                                   for b_id, n_id in enumerate(node_ids)], dim=0)
                eos_predicted = torch.max(
                    eos_p, (next_word == self.TGT.vocab.stoi[cfg.eos_token]))

                ys = torch.cat([nodes[n_id][0][b_id].unsqueeze(0)
                                for b_id, n_id in enumerate(node_ids)], dim=0)
                ys = torch.cat([ys, next_word.view(batch_size, 1)], dim=1)
                next_step_node = (ys, lm_score, eos_predicted)
                if sum(eos_predicted.int()) == batch_size:
                    final_results.append(next_step_node)
                else:
                    temp_next_nodes.append(next_step_node)
            del nodes[:]
            nodes = temp_next_nodes
        if not len(final_results):
            for node in nodes:
                final_results.append(node)
        # creating the final result based on the best scoring hypotheses
        result = torch.zeros(target_length, batch_size, device=device)
        lp = lambda l: ((5 + l) ** self.beam_search_length_norm_factor) / \
                       ((5 + 1) ** self.beam_search_length_norm_factor)
        for b_ind in range(batch_size):
            best_score = float('-inf')
            best_tokens = None
            for node in final_results:
                tokens = node[0][b_ind]
                eos_ind = (tokens == self.TGT.vocab.stoi[cfg.eos_token]).nonzero().view(-1)
                if eos_ind.size(0):
                    tsize = eos_ind[0].item()
                else:
                    tsize = tokens.size(0)
                # length normalization based on Google's NMT system paper [https://arxiv.org/pdf/1609.08144.pdf];
                # since coverage is not tracked here, the coverage penalty is not included in this formula
                lms = node[1][b_ind].item() / lp(tsize)
                if lms > best_score:
                    best_score = lms
                    best_tokens = tokens
            result[:best_tokens[1:].size(0), b_ind] = best_tokens[1:]
        max_attention_indices = None
        return result, max_attention_indices, torch.zeros(1,
                                                          device=device), 1, 1
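
The lp lambda in Example #4 is the length-normalization term from Google's NMT paper, lp(l) = (5 + l)^alpha / (5 + 1)^alpha. The standalone sketch below shows how it rescales the summed log-probabilities of hypotheses of different lengths; alpha = 0.6 is only an illustrative value, not necessarily what beam_search_length_norm_factor is set to.

    def length_penalty(length, alpha=0.6):
        # GNMT-style length normalization: lp(l) = (5 + l)^alpha / (5 + 1)^alpha
        return ((5 + length) ** alpha) / ((5 + 1) ** alpha)

    # Without normalization the 5-token hypothesis (-3.0) outranks the 12-token one (-3.5);
    # after dividing by lp(length), the longer hypothesis is no longer unfairly penalized and wins.
    print(-3.0 / length_penalty(5))   # ~ -2.21
    print(-3.5 / length_penalty(12))  # ~ -1.87
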