import oneflow as flow


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    Args:
        logits: logits distribution of shape (vocabulary size)
        top_k > 0: keep only the top k tokens with highest probability (top-k filtering).
        top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    top_k = min(top_k, logits.size(-1))  # safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < flow.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    # TODO: support top_p
    # if top_p > 0.0:
    #     sorted_logits, sorted_indices = flow.sort(logits, descending=True)
    #     cumulative_probs = flow.cumsum(flow.softmax(sorted_logits, dim=-1), dim=-1)
    #     # Remove tokens with cumulative probability above the threshold
    #     sorted_indices_to_remove = cumulative_probs > top_p
    #     # Shift the indices to the right to keep also the first token above the threshold
    #     sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    #     sorted_indices_to_remove[..., 0] = 0
    #     indices_to_remove = sorted_indices[sorted_indices_to_remove]
    #     logits[indices_to_remove] = filter_value
    return logits
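# Illustrative usage sketch (not from the original source): apply top-k
# filtering to a fake 1-D logits vector and renormalize. The vocabulary size
# of 100 and top_k of 5 are arbitrary; flow.randn / flow.softmax / flow.topk
# are standard oneflow calls.
if __name__ == "__main__":
    logits = flow.randn(100)                           # fake vocabulary logits
    filtered = top_k_top_p_filtering(logits, top_k=5)  # keep the 5 best tokens
    probs = flow.softmax(filtered, dim=-1)             # zero mass outside the top-5
    print(flow.topk(probs, 5))                         # surviving tokens and their probs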
def decode(self, ctc_matrix):
    """Greedy CTC decoding: per-frame argmax over a T x V posterior matrix.

    Note: repeats and blanks are not collapsed here; this returns the raw
    best-path label index for every frame.
    """
    top = flow.topk(ctc_matrix, k=1, dim=1)
    # top[1] holds the argmax index of each frame; concatenate them into a
    # single 1-D tensor of length T.
    new_top = top[1][0].detach()
    for i in range(1, top[1].size(0)):
        cur = top[1][i].detach()
        new_top = flow.cat((new_top, cur), 0)
    return new_top
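# Illustrative usage sketch (commented out because `decode` is a method and its
# enclosing class is not shown here; the shapes are assumptions):
#
#   ctc_matrix = flow.randn(100, 29)      # hypothetical: 100 frames x 29 labels
#   best_path = model.decode(ctc_matrix)  # shape (100,): one label index per frame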
def decode_step(self, preds, memory, memory_mask, cache, scores, flag):
    """Decode an utterance in a stepwise way."""
    batch_size = int(scores.size(0) / self.beam_width)

    batch_log_probs, dec_cache, dec_attn_weights = self.decode(
        preds, memory, memory_mask, cache["decoder"])

    if self.lm is not None:
        batch_lm_log_probs, lm_hidden = self.lm_decode(preds, cache["lm"])
        batch_lm_log_probs = batch_lm_log_probs.squeeze(1)
        batch_log_probs = batch_log_probs + self.lm_weight * batch_lm_log_probs
    else:
        lm_hidden = None

    if batch_log_probs.dim() == 3:
        batch_log_probs = batch_log_probs.squeeze(1)

    last_k_scores, last_k_preds = batch_log_probs.topk(self.beam_width)

    last_k_scores = mask_finished_scores(last_k_scores, flag)
    last_k_preds = mask_finished_preds(last_k_preds, flag)

    # update scores
    scores = scores + last_k_scores
    scores = scores.view(batch_size, self.beam_width * self.beam_width)

    # pruning
    scores, offset_k_indices = flow.topk(scores, k=self.beam_width)
    scores = scores.view(-1, 1)

    device = scores.device
    base_k_indices = (flow.arange(batch_size, device=device).view(
        -1, 1).repeat([1, self.beam_width]))
    base_k_indices *= self.beam_width**2
    best_k_indices = base_k_indices.view(-1) + offset_k_indices.view(-1)

    # update predictions
    best_k_preds = flow.index_select(last_k_preds.view(-1),
                                     dim=0,
                                     index=best_k_indices).to(flow.int64)
    preds_index = best_k_indices.floor_divide(self.beam_width)
    preds_symbol = flow.index_select(preds, dim=0, index=preds_index)
    preds_symbol = flow.cat([preds_symbol, best_k_preds.view(-1, 1)], dim=1)

    # finished or not
    end_flag = flow.eq(preds_symbol[:, -1], EOS).view(-1, 1).to(flow.uint8)

    return preds_symbol, cache, scores, end_flag
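# Illustrative call sketch (commented out; the layouts of `cache`, `scores`,
# and the finished-hypothesis `flag` are assumptions inferred from the code
# above): one step expands each of the batch_size * beam_width live prefixes,
# rescores them, and prunes back to beam_width hypotheses per utterance.
#
#   preds, cache, scores, end_flag = model.decode_step(
#       preds, memory, memory_mask, cache, scores, end_flag)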
def _topk(self, k, dim: int = None, largest: bool = True, sorted: bool = True):
    # Tensor-method wrapper that forwards to the functional flow.topk.
    return flow.topk(self, k, dim, largest, sorted)
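# Illustrative usage sketch (assumes `_topk` has been registered as
# `flow.Tensor.topk`, which is how oneflow exposes the method form):
#
#   x = flow.tensor([[1.0, 3.0, 2.0]])
#   values, indices = x.topk(2, dim=1)  # values [[3., 2.]], indices [[1, 2]]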
def recognize_beam(self, encoder_outputs, char_list, args):
    """Beam search, decoding one utterance at a time.

    Args:
        encoder_outputs: T x H  # e.g. 418 x 512
        char_list: list of characters  # e.g. 4233 entries
        args: args.beam  # e.g. 5
    Returns:
        nbest_hyps:
    """
    # search params
    beam = args.beam_size
    nbest = args.nbest
    if args.decode_max_len == 0:
        maxlen = encoder_outputs.size(0)
    else:
        maxlen = args.decode_max_len

    encoder_outputs = encoder_outputs.unsqueeze(0)

    # prepare sos
    ys = flow.ones(1, 1).fill_(self.sos_id).type_as(encoder_outputs).long()

    hyp = {"score": 0.0, "yseq": ys}
    hyps = [hyp]
    ended_hyps = []

    for i in range(maxlen):
        hyps_best_kept = []
        for hyp in hyps:
            ys = hyp["yseq"]
            ys = ys.to(device=encoder_outputs.device)

            # -- Prepare masks
            non_pad_mask = flow.ones_like(ys).to(dtype=flow.float32).unsqueeze(-1)
            slf_attn_mask = get_subsequent_mask(ys)

            # -- Forward
            dec_output = self.dropout(
                self.tgt_word_emb(ys) * self.x_logit_scale +
                self.positional_encoding(ys))

            for dec_layer in self.layer_stack:
                dec_output, _, _ = dec_layer(
                    dec_output,
                    encoder_outputs,
                    non_pad_mask=non_pad_mask,
                    slf_attn_mask=slf_attn_mask,
                    dec_enc_attn_mask=None,
                )

            seq_logit = self.tgt_word_prj(dec_output[:, -1])

            local_logit = F.softmax(seq_logit, dim=-1)
            local_scores = flow.log(local_logit)

            # topk scores
            local_best_scores, local_best_ids = flow.topk(local_scores, beam, dim=1)

            for j in range(beam):
                new_hyp = {}
                new_hyp["score"] = hyp["score"] + local_best_scores[0, j]
                new_hyp["yseq"] = (flow.ones(
                    1, (1 + ys.size(1))).type_as(encoder_outputs).long())
                new_hyp["yseq"][:, :ys.size(1)] = hyp["yseq"]
                new_hyp["yseq"][:, ys.size(1)] = int(
                    float(local_best_ids[0, j].numpy()))
                hyps_best_kept.append(new_hyp)

            hyps_best_kept = sorted(hyps_best_kept,
                                    key=lambda x: x["score"],
                                    reverse=True)[:beam]
        # end for hyp in hyps
        hyps = hyps_best_kept

        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            for hyp in hyps:
                hyp["yseq"] = flow.cat(
                    [
                        hyp["yseq"],
                        flow.ones(1, 1).fill_(
                            self.eos_id).type_as(encoder_outputs).long(),
                    ],
                    dim=1,
                )

        # add ended hypotheses to a final list, and remove them from the current
        # hypotheses (this can be a problem: the number of hyps drops below beam)
        remained_hyps = []
        for hyp in hyps:
            if hyp["yseq"][0, -1] == self.eos_id:
                ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)

        hyps = remained_hyps
        if len(hyps) > 0:
            print("remained hypotheses: " + str(len(hyps)))
        else:
            print("no hypothesis. Finish decoding.")
            break

        for hyp in hyps:
            print("hypo: " + "".join(
                [char_list[int(x.numpy())] for x in hyp["yseq"][0, 1:]]))

    nbest_hyps = sorted(ended_hyps, key=lambda x: x["score"],
                        reverse=True)[:min(len(ended_hyps), nbest)]

    for hyp in nbest_hyps:
        hyp["yseq"] = hyp["yseq"][0].cpu().numpy().tolist()

    return nbest_hyps
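# Illustrative call sketch (commented out; the `decoder` instance and the
# `args` fields follow the docstring above and are otherwise assumptions):
#
#   nbest_hyps = decoder.recognize_beam(encoder_outputs, char_list, args)
#   # each yseq is [sos, ..., eos]; strip both markers to read the transcript
#   best = "".join(char_list[idx] for idx in nbest_hyps[0]["yseq"][1:-1])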