Example #1
import oneflow as flow

def top_k_top_p_filtering(logits,
                          top_k=0,
                          top_p=0.0,
                          filter_value=-float("Inf")):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (vocabulary size)
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < flow.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    # TODO: support top_p; a hedged sketch follows this example
    # if top_p > 0.0:
    #     sorted_logits, sorted_indices = flow.sort(logits, descending=True)
    #     cumulative_probs = flow.cumsum(flow.softmax(sorted_logits, dim=-1), dim=-1)

    #     # Remove tokens with cumulative probability above the threshold
    #     sorted_indices_to_remove = cumulative_probs > top_p
    #     # Shift the indices to the right to keep also the first token above the threshold
    #     sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    #     sorted_indices_to_remove[..., 0] = 0

    #     indices_to_remove = sorted_indices[sorted_indices_to_remove]
    #     logits[indices_to_remove] = filter_value
    return logits
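
The disabled branch above matches the original gist, which operates on a 1-D logits vector (see the docstring). A minimal sketch of the missing branch, assuming flow.sort, flow.cumsum, and boolean-mask indexing behave like their PyTorch counterparts:

    # Hedged sketch; assumes PyTorch-compatible flow.sort/flow.cumsum
    if top_p > 0.0:
        sorted_logits, sorted_indices = flow.sort(logits, descending=True)
        cumulative_probs = flow.cumsum(flow.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right so the first token above the threshold is also kept
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # sorted_indices holds vocabulary ids in descending-score order, so
        # masking it with the boolean tensor yields the token ids to drop
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value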
Example #2
    def decode(self, ctc_matrix):
        # ctc_matrix: (T, vocab) per-frame scores; take the top-1 id per frame
        top = flow.topk(ctc_matrix, k=1, dim=1)

        # top[1] holds the (T, 1) argmax indices; flatten them one frame at a
        # time into a 1-D tensor of per-frame token ids
        new_top = top[1][0].detach()
        for i in range(1, top[1].size(0)):
            cur = top[1][i].detach()
            new_top = flow.cat((new_top, cur), 0)

        return new_top
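
The loop above simply flattens the per-frame argmax (equivalently top[1].view(-1), or flow.argmax(ctc_matrix, dim=1)). A full CTC greedy decode would additionally collapse repeated tokens and drop blanks; a minimal sketch, assuming blank id 0 (ctc_collapse is a hypothetical helper, not part of this repo):

def ctc_collapse(frame_ids, blank_id=0):
    # frame_ids: 1-D tensor of per-frame argmax token ids
    out, prev = [], None
    for idx in frame_ids.numpy().tolist():
        # keep a token only when it differs from the previous frame
        # and is not the blank label
        if idx != prev and idx != blank_id:
            out.append(idx)
        prev = idx
    return out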
Example #3
    def decode_step(self, preds, memory, memory_mask, cache, scores, flag):
        """ decode an utterance in a stepwise way"""

        batch_size = int(scores.size(0) / self.beam_width)

        batch_log_probs, dec_cache, dec_attn_weights = self.decode(
            preds, memory, memory_mask, cache["decoder"])

        if self.lm is not None:
            batch_lm_log_probs, lm_hidden = self.lm_decode(preds, cache["lm"])
            batch_lm_log_probs = batch_lm_log_probs.squeeze(1)
            batch_log_probs = batch_log_probs + self.lm_weight * batch_lm_log_probs
        else:
            lm_hidden = None

        if batch_log_probs.dim() == 3:
            batch_log_probs = batch_log_probs.squeeze(1)

        last_k_scores, last_k_preds = batch_log_probs.topk(self.beam_width)

        last_k_scores = mask_finished_scores(last_k_scores, flag)
        last_k_preds = mask_finished_preds(last_k_preds, flag)

        # update scores
        scores = scores + last_k_scores
        scores = scores.view(batch_size, self.beam_width * self.beam_width)

        # pruning
        scores, offset_k_indices = flow.topk(scores, k=self.beam_width)
        scores = scores.view(-1, 1)

        device = scores.device
        base_k_indices = (flow.arange(batch_size, device=device).view(
            -1, 1).repeat([1, self.beam_width]))
        base_k_indices *= self.beam_width**2
        best_k_indices = base_k_indices.view(-1) + offset_k_indices.view(-1)

        # update predictions
        best_k_preds = flow.index_select(last_k_preds.view(-1),
                                         dim=0,
                                         index=best_k_indices).to(flow.int64)

        preds_index = best_k_indices.floor_divide(self.beam_width)
        preds_symbol = flow.index_select(preds, dim=0, index=preds_index)
        preds_symbol = flow.cat(
            [preds_symbol, best_k_preds.view(-1, 1)], dim=1)

        # carry the updated decoder/LM states forward for the next step
        cache = {"decoder": dec_cache, "lm": lm_hidden}

        # mark beams whose last emitted symbol is EOS as finished
        end_flag = flow.eq(preds_symbol[:, -1], EOS).view(-1, 1).to(flow.uint8)

        return preds_symbol, cache, scores, end_flag
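
mask_finished_scores and mask_finished_preds come from the surrounding repository. A typical implementation, sketched here as an assumption rather than the repo's exact code, keeps exactly one live copy of each finished beam (its running score frozen) and forces its next prediction to EOS:

def mask_finished_scores(scores, flag):
    # scores: (batch*beam, beam_width) expansion scores;
    # flag: (batch*beam, 1), nonzero where the beam already emitted EOS
    beam_width = scores.size(-1)
    ended = flag.to(flow.bool).repeat(1, beam_width)
    first = flow.cat(
        [flow.ones_like(flag), flow.zeros_like(flag).repeat(1, beam_width - 1)],
        dim=1).to(flow.bool)
    # finished beams: first expansion adds 0 (total unchanged), rest get -inf
    scores = flow.where(flow.logical_and(ended, first),
                        flow.zeros_like(scores), scores)
    return flow.where(flow.logical_and(ended, flow.logical_not(first)),
                      flow.ones_like(scores) * -float("inf"), scores)

def mask_finished_preds(preds, flag):
    # finished beams may only emit EOS again
    ended = flag.to(flow.bool).repeat(1, preds.size(-1))
    return flow.where(ended, flow.ones_like(preds) * EOS, preds)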
Example #4
from typing import Optional

def _topk(self, k, dim: Optional[int] = None, largest: bool = True, sorted: bool = True):
    return flow.topk(self, k, dim, largest, sorted)
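
Once bound as Tensor.topk, usage mirrors PyTorch; flow.topk returns a (values, indices) pair, as example #3 above relies on:

import oneflow as flow

x = flow.tensor([1.0, 3.0, 2.0, 5.0, 4.0])
values, indices = x.topk(2)   # same as flow.topk(x, 2)
# values  -> [5., 4.]  (the two largest entries, in descending order)
# indices -> [3, 4]    (their positions in x)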
Example #5
    def recognize_beam(self, encoder_outputs, char_list, args):
        """
        Beam search, decode one utterence now.
        Args:
            encoder_outputs: T x H #418 x 512
            char_list: list of character #4233
            args: args.beam #5

        Returns:
            nbest_hyps:
        """
        # search params
        beam = args.beam_size
        nbest = args.nbest
        if args.decode_max_len == 0:
            maxlen = encoder_outputs.size(0)
        else:
            maxlen = args.decode_max_len

        encoder_outputs = encoder_outputs.unsqueeze(0)
        # prepare sos
        ys = flow.ones(1, 1).fill_(self.sos_id).type_as(encoder_outputs).long()
        hyp = {"score": 0.0, "yseq": ys}
        hyps = [hyp]
        ended_hyps = []

        for i in range(maxlen):
            hyps_best_kept = []
            for hyp in hyps:
                ys = hyp["yseq"]
                ys = ys.to(device=encoder_outputs.device)
                # -- Prepare masks
                non_pad_mask = flow.ones_like(ys).to(
                    dtype=flow.float32).unsqueeze(-1)
                slf_attn_mask = get_subsequent_mask(ys)
                # -- Forward
                dec_output = self.dropout(
                    self.tgt_word_emb(ys) * self.x_logit_scale +
                    self.positional_encoding(ys))

                for dec_layer in self.layer_stack:
                    dec_output, _, _ = dec_layer(
                        dec_output,
                        encoder_outputs,
                        non_pad_mask=non_pad_mask,
                        slf_attn_mask=slf_attn_mask,
                        dec_enc_attn_mask=None,
                    )

                seq_logit = self.tgt_word_prj(dec_output[:, -1])
                local_scores = F.log_softmax(seq_logit, dim=-1)
                # topk scores
                local_best_scores, local_best_ids = flow.topk(local_scores,
                                                              beam,
                                                              dim=1)

                for j in range(beam):
                    new_hyp = {}
                    new_hyp["score"] = hyp["score"] + local_best_scores[0, j]
                    new_hyp["yseq"] = (flow.ones(
                        1, (1 + ys.size(1))).type_as(encoder_outputs).long())
                    new_hyp["yseq"][:, :ys.size(1)] = hyp["yseq"]
                    new_hyp["yseq"][:, ys.size(1)] = int(
                        float(local_best_ids[0, j].numpy()))
                    hyps_best_kept.append(new_hyp)

                hyps_best_kept = sorted(hyps_best_kept,
                                        key=lambda x: x["score"],
                                        reverse=True)[:beam]
            # end for hyp in hyps
            hyps = hyps_best_kept
            # append eos at the final step so that at least one hypothesis ends
            if i == maxlen - 1:
                for hyp in hyps:
                    hyp["yseq"] = flow.cat(
                        [
                            hyp["yseq"],
                            flow.ones(1, 1).fill_(
                                self.eos_id).type_as(encoder_outputs).long(),
                        ],
                        dim=1,
                    )

            # move ended hypotheses to the final list and drop them from the
            # current set (this can shrink the number of hyps below beam)
            remained_hyps = []
            for hyp in hyps:
                if hyp["yseq"][0, -1] == self.eos_id:
                    ended_hyps.append(hyp)
                else:
                    remained_hyps.append(hyp)

            hyps = remained_hyps
            if len(hyps) > 0:
                print("remeined hypothes: " + str(len(hyps)))
            else:
                print("no hypothesis. Finish decoding.")
                break
            for hyp in hyps:
                print("hypo: " + "".join(
                    [char_list[int(x.numpy())] for x in hyp["yseq"][0, 1:]]))

        nbest_hyps = sorted(ended_hyps, key=lambda x: x["score"],
                            reverse=True)[:min(len(ended_hyps), nbest)]
        for hyp in nbest_hyps:
            hyp["yseq"] = hyp["yseq"][0].cpu().numpy().tolist()
        return nbest_hyps
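
get_subsequent_mask is imported from the surrounding repository; a common implementation, assumed here rather than taken from that repo, builds the causal self-attention mask from a lower-triangular matrix:

def get_subsequent_mask(seq):
    # seq: (batch, length). Returns (batch, length, length) with ones where
    # key position j <= query position i, so each step attends only to
    # itself and earlier tokens.
    batch, length = seq.size(0), seq.size(1)
    mask = flow.tril(flow.ones((length, length), device=seq.device, dtype=flow.uint8))
    return mask.unsqueeze(0).expand(batch, length, length)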