def beam_search(self, eouts, elens, params, idx2token=None,
                lm=None, lm_second=None, lm_bwd=None, ctc_log_probs=None,
                nbest=1, exclude_eos=False, refs_id=None, utt_ids=None, speakers=None,
                ensmbl_eouts=None, ensmbl_elens=None, ensmbl_decs=[], cache_states=True):
    """Beam search decoding.

    Args:
        eouts (FloatTensor): `[B, T, d_model]`
        elens (IntTensor): `[B]`
        params (dict): hyperparameters for decoding
        idx2token (): converter from index to token
        lm: first-pass LM
        lm_second: second-pass LM
        lm_bwd: first/second-pass backward LM
        ctc_log_probs (FloatTensor):
        nbest (int):
        exclude_eos (bool): exclude <eos> from hypothesis
        refs_id (list): reference list
        utt_ids (list): utterance id list
        speakers (list): speaker list
        ensmbl_eouts (list): list of FloatTensor
        ensmbl_elens (list): list of list
        ensmbl_decs (list): list of torch.nn.Module
        cache_states (bool): cache decoder states for fast decoding
    Returns:
        nbest_hyps_idx (list): length `B`, each of which contains list of N hypotheses
        aws (list): length `B`, each of which contains arrays of size `[H, L, T]`
        scores (list):

    """
    bs, xmax, _ = eouts.size()
    n_models = len(ensmbl_decs) + 1

    beam_width = params['recog_beam_width']
    assert 1 <= nbest <= beam_width
    ctc_weight = params['recog_ctc_weight']
    max_len_ratio = params['recog_max_len_ratio']
    min_len_ratio = params['recog_min_len_ratio']
    lp_weight = params['recog_length_penalty']
    length_norm = params['recog_length_norm']
    lm_weight = params['recog_lm_weight']
    lm_weight_second = params['recog_lm_second_weight']
    lm_weight_bwd = params['recog_lm_bwd_weight']
    eos_threshold = params['recog_eos_threshold']
    lm_state_carry_over = params['recog_lm_state_carry_over']
    softmax_smoothing = params['recog_softmax_smoothing']
    eps_wait = params['recog_mma_delay_threshold']

    if lm is not None:
        assert lm_weight > 0
        lm.eval()
    if lm_second is not None:
        assert lm_weight_second > 0
        lm_second.eval()
    if lm_bwd is not None:
        assert lm_weight_bwd > 0
        lm_bwd.eval()

    if ctc_log_probs is not None:
        assert ctc_weight > 0
        ctc_log_probs = tensor2np(ctc_log_probs)

    nbest_hyps_idx, aws, scores = [], [], []
    eos_flags = []
    for b in range(bs):
        # Initialization per utterance
        lmstate = None
        ys = eouts.new_zeros(1, 1).fill_(self.eos).long()

        # For joint CTC-Attention decoding
        ctc_prefix_scorer = None
        if ctc_log_probs is not None:
            if self.bwd:
                ctc_prefix_scorer = CTCPrefixScore(ctc_log_probs[b][::-1], self.blank, self.eos)
            else:
                ctc_prefix_scorer = CTCPrefixScore(ctc_log_probs[b], self.blank, self.eos)

        if speakers is not None:
            if speakers[b] == self.prev_spk:
                if lm_state_carry_over and isinstance(lm, RNNLM):
                    lmstate = self.lmstate_final
            self.prev_spk = speakers[b]

        helper = BeamSearch(beam_width, self.eos, ctc_weight, self.device_id)

        end_hyps = []
        ymax = int(math.floor(elens[b] * max_len_ratio)) + 1
        hyps = [{'hyp': [self.eos],
                 'ys': ys,
                 'cache': None,
                 'score': 0.,
                 'score_attn': 0.,
                 'score_ctc': 0.,
                 'score_lm': 0.,
                 'aws': [None],
                 'lmstate': lmstate,
                 'ensmbl_aws': [[None]] * (n_models - 1),
                 'ctc_state': ctc_prefix_scorer.initial_state() if ctc_prefix_scorer is not None else None,
                 'streamable': True,
                 'streaming_failed_point': 1000}]
        streamable_global = True
        for t in range(ymax):
            # batchify all hypotheses for batch decoding
            cache = [None] * self.n_layers
            if cache_states and t > 0:
                for lth in range(self.n_layers):
                    cache[lth] = torch.cat([beam['cache'][lth] for beam in hyps], dim=0)
            ys = eouts.new_zeros(len(hyps), t + 1).long()
            for j, beam in enumerate(hyps):
                ys[j, :] = beam['ys']
            if t > 0:
                xy_aws_prev = torch.cat([beam['aws'][-1] for beam in hyps], dim=0)  # `[B, n_layers, H_ma, 1, klen]`
            else:
                xy_aws_prev = None

            # Update LM states for shallow fusion
            lmstate, scores_lm = None, None
            if lm is not None:
                if hyps[0]['lmstate'] is not None:
                    lm_hxs = torch.cat([beam['lmstate']['hxs'] for beam in hyps], dim=1)
                    lm_cxs = torch.cat([beam['lmstate']['cxs'] for beam in hyps], dim=1)
                    lmstate = {'hxs': lm_hxs, 'cxs': lm_cxs}
                y = ys[:, -1:].clone()  # NOTE: this is important
                _, lmstate, scores_lm = lm.predict(y, lmstate)

            # for the main model
            causal_mask = eouts.new_ones(t + 1, t + 1).byte()
            causal_mask = torch.tril(causal_mask, out=causal_mask).unsqueeze(0).repeat([ys.size(0), 1, 1])

            out = self.pos_enc(self.embed(ys))  # scaled

            mlen = 0  # TODO: fix later
            if self.memory_transformer:
                # NOTE: TransformerXL does not use positional encoding in the token embedding
                mems = self.init_memory()
                # adopt zero-centered offset
                pos_idxs = torch.arange(mlen - 1, -(t + 1) - 1, -1.0, dtype=torch.float)
                pos_embs = self.pos_emb(pos_idxs, self.device_id)
                out = self.dropout_emb(out)
                hidden_states = [out]

            n_heads_total = 0
            eouts_b = eouts[b:b + 1, :elens[b]].repeat([ys.size(0), 1, 1])
            new_cache = [None] * self.n_layers
            xy_aws_all_layers = []
            lth_s = self.mocha_first_layer - 1
            for lth, layer in enumerate(self.layers):
                if self.memory_transformer:
                    out = layer(
                        out, causal_mask, eouts_b, None,
                        cache=cache[lth],
                        pos_embs=pos_embs, memory=mems[lth], u=self.u, v=self.v)
                    hidden_states.append(out)
                else:
                    out = layer(
                        out, causal_mask, eouts_b, None,
                        cache=cache[lth],
                        xy_aws_prev=xy_aws_prev[:, lth - lth_s] if lth >= lth_s and t > 0 else None,
                        eps_wait=eps_wait)
                new_cache[lth] = out
                if layer.xy_aws is not None:
                    xy_aws_all_layers.append(layer.xy_aws)
            logits = self.output(self.norm_out(out))
            probs = torch.softmax(logits[:, -1] * softmax_smoothing, dim=1)
            xy_aws_all_layers = torch.stack(xy_aws_all_layers, dim=1)  # `[B, H, n_layers, L, T]`

            # for the ensemble
            ensmbl_new_cache = []
            if n_models > 1:
                # Ensemble initialization
                # ensmbl_cache = []
                # cache_e = [None] * self.n_layers
                # if cache_states and t > 0:
                #     for lth in range(self.n_layers):
                #         cache_e[lth] = torch.cat([beam['ensmbl_cache'][lth] for beam in hyps], dim=0)
                for i_e, dec in enumerate(ensmbl_decs):
                    out_e = dec.pos_enc(dec.embed(ys))  # scaled
                    eouts_e = ensmbl_eouts[i_e][b:b + 1, :elens[b]].repeat([ys.size(0), 1, 1])
                    new_cache_e = [None] * dec.n_layers
                    for lth in range(dec.n_layers):
                        out_e, _, xy_aws_e, _, _ = dec.layers[lth](out_e, causal_mask, eouts_e, None,
                                                                   cache=cache[lth])
                        new_cache_e[lth] = out_e
                    ensmbl_new_cache.append(new_cache_e)
                    logits_e = dec.output(dec.norm_out(out_e))
                    probs += torch.softmax(logits_e[:, -1] * softmax_smoothing, dim=1)
                    # NOTE: sum in the probability scale (not log-scale)

            # Ensemble in log-scale
            scores_attn = torch.log(probs) / n_models

            new_hyps = []
            for j, beam in enumerate(hyps):
                # Attention scores
                total_scores_attn = beam['score_attn'] + scores_attn[j:j + 1]
                total_scores = total_scores_attn * (1 - ctc_weight)

                # Add LM score <before> top-K selection
                if lm is not None:
                    total_scores_lm = beam['score_lm'] + scores_lm[j:j + 1, -1]
                    total_scores += total_scores_lm * lm_weight
                else:
                    total_scores_lm = eouts.new_zeros(1, self.vocab)

                total_scores_topk, topk_ids = torch.topk(
                    total_scores, k=beam_width, dim=1, largest=True, sorted=True)

                # Add length penalty
                if lp_weight > 0:
                    total_scores_topk += (len(beam['hyp'][1:]) + 1) * lp_weight

                # Add CTC score
                new_ctc_states, total_scores_ctc, total_scores_topk = helper.add_ctc_score(
                    beam['hyp'], topk_ids, beam['ctc_state'],
                    total_scores_topk, ctc_prefix_scorer)

                new_aws = beam['aws'] + [xy_aws_all_layers[j:j + 1, :, :, -1:]]
                aws_j = torch.cat(new_aws[1:], dim=3)  # `[1, H, n_layers, L, T]`
                streaming_failed_point = beam['streaming_failed_point']

                # forward direction
                for k in range(beam_width):
                    idx = topk_ids[0, k].item()
                    length_norm_factor = len(beam['hyp'][1:]) + 1 if length_norm else 1
                    total_score = total_scores_topk[0, k].item() / length_norm_factor

                    if idx == self.eos:
                        # Exclude short hypotheses
                        if len(beam['hyp']) - 1 < elens[b] * min_len_ratio:
                            continue
                        # EOS threshold
                        max_score_no_eos = scores_attn[j, :idx].max(0)[0].item()
                        max_score_no_eos = max(max_score_no_eos, scores_attn[j, idx + 1:].max(0)[0].item())
                        if scores_attn[j, idx].item() <= eos_threshold * max_score_no_eos:
                            continue

                    quantity_rate = 1.
                    if 'mocha' in self.attn_type:
                        n_tokens_hyp_k = t + 1
                        n_quantity_k = aws_j[:, :, :, :n_tokens_hyp_k].int().sum().item()
                        quantity_diff = n_tokens_hyp_k * n_heads_total - n_quantity_k

                        if quantity_diff != 0:
                            if idx == self.eos:
                                n_tokens_hyp_k -= 1  # NOTE: do not count <eos> for streamability
                                n_quantity_k = aws_j[:, :, :, :n_tokens_hyp_k].int().sum().item()
                            else:
                                streamable_global = False
                            if n_tokens_hyp_k * n_heads_total == 0:
                                quantity_rate = 0
                            else:
                                quantity_rate = n_quantity_k / (n_tokens_hyp_k * n_heads_total)

                        if beam['streamable'] and not streamable_global:
                            streaming_failed_point = t

                    new_hyps.append(
                        {'hyp': beam['hyp'] + [idx],
                         'ys': torch.cat([beam['ys'], eouts.new_zeros(1, 1).fill_(idx).long()], dim=-1),
                         'cache': [new_cache_l[j:j + 1] for new_cache_l in new_cache] if cache_states else cache,
                         'score': total_score,
                         'score_attn': total_scores_attn[0, idx].item(),
                         'score_ctc': total_scores_ctc[k].item(),
                         'score_lm': total_scores_lm[0, idx].item(),
                         'aws': new_aws,
                         'lmstate': {'hxs': lmstate['hxs'][:, j:j + 1],
                                     'cxs': lmstate['cxs'][:, j:j + 1]} if lmstate is not None else None,
                         'ctc_state': new_ctc_states[k] if ctc_prefix_scorer is not None else None,
                         'ensmbl_cache': ensmbl_new_cache,
                         'streamable': streamable_global,
                         'streaming_failed_point': streaming_failed_point,
                         'quantity_rate': quantity_rate})

            # Local pruning
            new_hyps_sorted = sorted(new_hyps, key=lambda x: x['score'], reverse=True)[:beam_width]

            # Remove complete hypotheses
            new_hyps, end_hyps, is_finish = helper.remove_complete_hyp(
                new_hyps_sorted, end_hyps, prune=True)
            hyps = new_hyps[:]
            if is_finish:
                break

        # Global pruning
        if len(end_hyps) == 0:
            end_hyps = hyps[:]
        elif len(end_hyps) < nbest and nbest > 1:
            end_hyps.extend(hyps[:nbest - len(end_hyps)])

        # forward second-pass LM rescoring
        if lm_second is not None:
            self.lm_rescoring(end_hyps, lm_second, lm_weight_second, tag='second')

        # backward second-pass LM rescoring
        if lm_bwd is not None and lm_weight_bwd > 0:
            self.lm_rescoring(end_hyps, lm_bwd, lm_weight_bwd, tag='second_bwd')

        # Sort by score
        end_hyps = sorted(end_hyps, key=lambda x: x['score'], reverse=True)

        for j in range(len(end_hyps[0]['aws'][1:])):
            tmp = end_hyps[0]['aws'][j + 1]
            end_hyps[0]['aws'][j + 1] = tmp.view(1, -1, tmp.size(-2), tmp.size(-1))

        # metrics for streaming inference
        self.streamable = end_hyps[0]['streamable']
        self.quantity_rate = end_hyps[0]['quantity_rate']
        self.last_success_frame_ratio = None

        if idx2token is not None:
            if utt_ids is not None:
                logger.info('Utt-id: %s' % utt_ids[b])
            assert self.vocab == idx2token.vocab
            logger.info('=' * 200)
            for k in range(len(end_hyps)):
                if refs_id is not None:
                    logger.info('Ref: %s' % idx2token(refs_id[b]))
                logger.info('Hyp: %s' % idx2token(
                    end_hyps[k]['hyp'][1:][::-1] if self.bwd else end_hyps[k]['hyp'][1:]))
                logger.info('num tokens (hyp): %d' % len(end_hyps[k]['hyp'][1:]))
                logger.info('log prob (hyp): %.7f' % end_hyps[k]['score'])
                logger.info('log prob (hyp, att): %.7f' % (end_hyps[k]['score_attn'] * (1 - ctc_weight)))
                if ctc_prefix_scorer is not None:
                    logger.info('log prob (hyp, ctc): %.7f' % (end_hyps[k]['score_ctc'] * ctc_weight))
                if lm is not None:
                    logger.info('log prob (hyp, first-pass lm): %.7f' % (end_hyps[k]['score_lm'] * lm_weight))
                if lm_second is not None:
                    logger.info('log prob (hyp, second-pass lm): %.7f' % (end_hyps[k]['score_lm_second'] * lm_weight_second))
                if lm_bwd is not None:
                    logger.info('log prob (hyp, second-pass lm-bwd): %.7f' % (end_hyps[k]['score_lm_second_bwd'] * lm_weight_bwd))
                if 'mocha' in self.attn_type:
                    logger.info('streamable: %s' % end_hyps[k]['streamable'])
                    logger.info('streaming failed point: %d' % (end_hyps[k]['streaming_failed_point'] + 1))
                    logger.info('quantity rate [%%]: %.2f' % (end_hyps[k]['quantity_rate'] * 100))
                logger.info('-' * 50)

            if 'mocha' in self.attn_type and end_hyps[0]['streaming_failed_point'] < 1000:
                assert not self.streamable
                aws_last_success = end_hyps[0]['aws'][1:][end_hyps[0]['streaming_failed_point'] - 1]
                rightmost_frame = max(0, aws_last_success[0, :, 0].nonzero()[:, -1].max().item()) + 1
                frame_ratio = rightmost_frame * 100 / xmax
                self.last_success_frame_ratio = frame_ratio
                logger.info('streaming last success frame ratio: %.2f' % frame_ratio)

        # N-best list
        if self.bwd:
            # Reverse the order
            nbest_hyps_idx += [[np.array(end_hyps[n]['hyp'][1:][::-1]) for n in range(nbest)]]
            aws += [tensor2np(torch.cat(end_hyps[0]['aws'][1:][::-1], dim=2).squeeze(0))]
        else:
            nbest_hyps_idx += [[np.array(end_hyps[n]['hyp'][1:]) for n in range(nbest)]]
            aws += [tensor2np(torch.cat(end_hyps[0]['aws'][1:], dim=2).squeeze(0))]
        scores += [[end_hyps[n]['score_attn'] for n in range(nbest)]]

        # Check <eos>
        eos_flags.append([(end_hyps[n]['hyp'][-1] == self.eos) for n in range(nbest)])

    # Exclude <eos> (<sos> in case of the backward decoder)
    if exclude_eos:
        if self.bwd:
            nbest_hyps_idx = [[nbest_hyps_idx[b][n][1:] if eos_flags[b][n]
                               else nbest_hyps_idx[b][n] for n in range(nbest)] for b in range(bs)]
        else:
            nbest_hyps_idx = [[nbest_hyps_idx[b][n][:-1] if eos_flags[b][n]
                               else nbest_hyps_idx[b][n] for n in range(nbest)] for b in range(bs)]

    # Store ASR/LM state
    if len(end_hyps) > 0:
        self.lmstate_final = end_hyps[0]['lmstate']

    return nbest_hyps_idx, aws, scores
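
# A minimal sketch of the decoding-hyperparameter dict consumed above. The
# keys are exactly those read at the top of beam_search(); the values here
# are illustrative placeholders, not recommended defaults.
example_recog_params = {
    'recog_beam_width': 10,
    'recog_ctc_weight': 0.1,           # weight for joint CTC scoring
    'recog_max_len_ratio': 1.0,        # ymax = floor(elens[b] * ratio) + 1
    'recog_min_len_ratio': 0.0,        # reject <eos> for too-short hypotheses
    'recog_length_penalty': 0.0,
    'recog_length_norm': True,
    'recog_lm_weight': 0.3,            # first-pass shallow fusion
    'recog_lm_second_weight': 0.3,     # second-pass rescoring
    'recog_lm_bwd_weight': 0.0,        # backward LM rescoring
    'recog_eos_threshold': 1.5,
    'recog_lm_state_carry_over': False,
    'recog_softmax_smoothing': 1.0,
    'recog_mma_delay_threshold': -1,   # eps_wait for MMA
}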
def beam_search(self, eouts, elens, params, idx2token=None,
                lm=None, lm_second=None, lm_second_bwd=None, ctc_log_probs=None,
                nbest=1, exclude_eos=False, refs_id=None, utt_ids=None, speakers=None,
                ensmbl_eouts=[], ensmbl_elens=[], ensmbl_decs=[], cache_states=True):
    """Beam search decoding.

    Args:
        eouts (FloatTensor): `[B, T, d_model]`
        elens (IntTensor): `[B]`
        params (dict): decoding hyperparameters
        idx2token (): converter from index to token
        lm (torch.nn.module): first-pass LM
        lm_second (torch.nn.module): second-pass LM
        lm_second_bwd (torch.nn.module): second-pass backward LM
        ctc_log_probs (FloatTensor):
        nbest (int): number of N-best hypotheses
        exclude_eos (bool): exclude <eos> from hypothesis
        refs_id (List): reference list
        utt_ids (List): utterance id list
        speakers (List): speaker list
        ensmbl_eouts (List[FloatTensor]): encoder outputs for ensemble models
        ensmbl_elens (List[IntTensor]): encoder output lengths for ensemble models
        ensmbl_decs (List[torch.nn.Module]): decoders for ensemble models
        cache_states (bool): cache decoder states for fast decoding
    Returns:
        nbest_hyps_idx (List): length `[B]`, each of which contains list of N hypotheses
        aws (List): length `[B]`, each of which contains arrays of size `[H, L, T]`
        scores (List):

    """
    bs, xmax, _ = eouts.size()
    n_models = len(ensmbl_decs) + 1

    beam_width = params.get('recog_beam_width')
    assert 1 <= nbest <= beam_width
    ctc_weight = params.get('recog_ctc_weight')
    max_len_ratio = params.get('recog_max_len_ratio')
    min_len_ratio = params.get('recog_min_len_ratio')
    lp_weight = params.get('recog_length_penalty')
    length_norm = params.get('recog_length_norm')
    cache_emb = params.get('recog_cache_embedding')
    lm_weight = params.get('recog_lm_weight')
    lm_weight_second = params.get('recog_lm_second_weight')
    lm_weight_second_bwd = params.get('recog_lm_bwd_weight')
    eos_threshold = params.get('recog_eos_threshold')
    lm_state_carry_over = params.get('recog_lm_state_carry_over')
    softmax_smoothing = params.get('recog_softmax_smoothing')
    eps_wait = params.get('recog_mma_delay_threshold')

    helper = BeamSearch(beam_width, self.eos, ctc_weight, lm_weight, self.device)
    lm = helper.verify_lm_eval_mode(lm, lm_weight, cache_emb)
    lm_second = helper.verify_lm_eval_mode(lm_second, lm_weight_second, cache_emb)
    lm_second_bwd = helper.verify_lm_eval_mode(lm_second_bwd, lm_weight_second_bwd, cache_emb)

    # cache token embeddings
    if cache_emb:
        self.cache_embedding(eouts.device)

    if ctc_log_probs is not None:
        assert ctc_weight > 0
        ctc_log_probs = tensor2np(ctc_log_probs)

    nbest_hyps_idx, aws, scores = [], [], []
    eos_flags = []
    for b in range(bs):
        # Initialization per utterance
        lmstate = None
        ys = eouts.new_zeros((1, 1), dtype=torch.int64).fill_(self.eos)
        for layer in self.layers:
            layer.reset()

        # For joint CTC-Attention decoding
        ctc_prefix_scorer = None
        if ctc_log_probs is not None:
            if self.bwd:
                ctc_prefix_scorer = CTCPrefixScore(ctc_log_probs[b][::-1], self.blank, self.eos)
            else:
                ctc_prefix_scorer = CTCPrefixScore(ctc_log_probs[b], self.blank, self.eos)

        if speakers is not None:
            if speakers[b] == self.prev_spk:
                if lm_state_carry_over and isinstance(lm, RNNLM):
                    lmstate = self.lmstate_final
            self.prev_spk = speakers[b]

        end_hyps = []
        hyps = [{'hyp': [self.eos],
                 'ys': ys,
                 'cache': None,
                 'score': 0.,
                 'score_att': 0.,
                 'score_ctc': 0.,
                 'score_lm': 0.,
                 'aws': [None],
                 'lmstate': lmstate,
                 'ensmbl_cache': [[None] * dec.n_layers for dec in ensmbl_decs] if n_models > 1 else None,
                 'ctc_state': ctc_prefix_scorer.initial_state() if ctc_prefix_scorer is not None else None,
                 'quantity_rate': 1.,
                 'streamable': True,
                 'streaming_failed_point': 1000}]
        streamable_global = True
        ymax = math.ceil(elens[b] * max_len_ratio)
        for i in range(ymax):
            # batchify all hypotheses for batch decoding
            cache = [None] * self.n_layers
            if cache_states and i > 0:
                for lth in range(self.n_layers):
                    cache[lth] = torch.cat([beam['cache'][lth] for beam in hyps], dim=0)
            ys = eouts.new_zeros((len(hyps), i + 1), dtype=torch.int64)
            for j, beam in enumerate(hyps):
                ys[j, :] = beam['ys']
            if i > 0:
                xy_aws_prev = torch.cat([beam['aws'][-1] for beam in hyps], dim=0)  # `[B, n_layers, H_ma, 1, klen]`
            else:
                xy_aws_prev = None

            # Update LM states for shallow fusion
            y_lm = ys[:, -1:].clone()  # NOTE: this is important
            _, lmstate, scores_lm = helper.update_rnnlm_state_batch(lm, hyps, y_lm)

            # for the main model
            causal_mask = eouts.new_ones(i + 1, i + 1, dtype=torch.uint8)
            causal_mask = torch.tril(causal_mask).unsqueeze(0).repeat([ys.size(0), 1, 1])

            out = self.pos_enc(self.embed_token_id(ys), scale=True)  # scaled + dropout

            n_heads_total = 0
            eouts_b = eouts[b:b + 1, :elens[b]].repeat([ys.size(0), 1, 1])  # `[n_hyps, T, dim]`
            new_cache = [None] * self.n_layers
            xy_aws_layers = []
            xy_aws = None
            lth_s = self.mma_first_layer - 1
            # autoregressive decoding over the decoder layers
            for lth, layer in enumerate(self.layers):
                out = layer(
                    out, causal_mask, eouts_b, None,
                    cache=cache[lth],
                    xy_aws_prev=xy_aws_prev[:, lth - lth_s] if lth >= lth_s and i > 0 else None,
                    eps_wait=eps_wait)
                xy_aws = layer.xy_aws
                new_cache[lth] = out
                if xy_aws is not None:
                    xy_aws_layers.append(xy_aws)
            logits = self.output(self.norm_out(out[:, -1]))  # output distribution at the current step
            probs = torch.softmax(logits * softmax_smoothing, dim=1)
            xy_aws_layers = torch.stack(xy_aws_layers, dim=1)  # `[B, H, n_layers, L, T]`

            # Ensemble initialization
            ensmbl_cache = [[None] * dec.n_layers for dec in ensmbl_decs]
            if n_models > 1 and cache_states and i > 0:
                for i_e, dec in enumerate(ensmbl_decs):
                    for lth in range(dec.n_layers):
                        ensmbl_cache[i_e][lth] = torch.cat(
                            [beam['ensmbl_cache'][i_e][lth] for beam in hyps], dim=0)

            # for the ensemble
            ensmbl_new_cache = [[None] * dec.n_layers for dec in ensmbl_decs]
            for i_e, dec in enumerate(ensmbl_decs):
                out_e = dec.pos_enc(dec.embed(ys))  # scaled + dropout
                eouts_e = ensmbl_eouts[i_e][b:b + 1, :elens[b]].repeat([ys.size(0), 1, 1])
                for lth in range(dec.n_layers):
                    out_e = dec.layers[lth](out_e, causal_mask, eouts_e, None,
                                            cache=ensmbl_cache[i_e][lth])
                    ensmbl_new_cache[i_e][lth] = out_e
                logits_e = dec.output(dec.norm_out(out_e[:, -1]))
                probs += torch.softmax(logits_e * softmax_smoothing, dim=1)
                # NOTE: sum in the probability scale (not log-scale)

            # Ensemble of multiple models
            scores_att = torch.log(probs / n_models)  # `[n_hyps, vocab]`

            new_hyps = []
            # expand each hypothesis in the beam
            for j, beam in enumerate(hyps):
                # Attention scores
                total_scores_att = beam['score_att'] + scores_att[j:j + 1]  # `[1, vocab]`
                total_scores = total_scores_att * (1 - ctc_weight)

                # Add LM score <before> top-K selection
                if lm is not None:
                    total_scores_lm = beam['score_lm'] + scores_lm[j:j + 1, -1]
                    total_scores += total_scores_lm * lm_weight
                else:
                    total_scores_lm = eouts.new_zeros(1, self.vocab)

                total_scores_topk, topk_ids = torch.topk(
                    total_scores, k=beam_width, dim=1, largest=True, sorted=True)

                # Add length penalty
                if lp_weight > 0:
                    total_scores_topk += (len(beam['hyp'][1:]) + 1) * lp_weight

                # Add CTC score
                new_ctc_states, total_scores_ctc, total_scores_topk = helper.add_ctc_score(
                    beam['hyp'], topk_ids, beam['ctc_state'],
                    total_scores_topk, ctc_prefix_scorer)

                new_aws = beam['aws'] + [xy_aws_layers[j:j + 1, :, :, -1:]]
                aws_j = torch.cat(new_aws[1:], dim=3)  # `[1, H, n_layers, L, T]`

                # forward direction
                for k in range(beam_width):
                    idx = topk_ids[0, k].item()  # token index of the k-th candidate
                    length_norm_factor = len(beam['hyp'][1:]) + 1 if length_norm else 1
                    total_score = total_scores_topk[0, k].item() / length_norm_factor

                    if idx == self.eos:
                        # Exclude short hypotheses (e.g., silence in the middle of an utterance)
                        if len(beam['hyp'][1:]) < elens[b] * min_len_ratio:
                            continue
                        # EOS threshold: find the best score among non-<eos> tokens
                        max_score_no_eos = scores_att[j, :idx].max(0)[0].item()
                        max_score_no_eos = max(max_score_no_eos, scores_att[j, idx + 1:].max(0)[0].item())
                        if scores_att[j, idx].item() <= eos_threshold * max_score_no_eos:
                            continue  # keep decoding; skip <eos> at this step

                    streaming_failed_point = beam['streaming_failed_point']
                    quantity_rate = 1.
                    # streamability check (for MoChA)
                    if self.attn_type == 'mocha':
                        n_tokens_hyp_k = i + 1
                        n_quantity_k = aws_j[:, :, :, :n_tokens_hyp_k].int().sum().item()
                        quantity_diff = n_tokens_hyp_k * n_heads_total - n_quantity_k

                        if quantity_diff != 0:
                            if idx == self.eos:
                                n_tokens_hyp_k -= 1  # NOTE: do not count <eos> for streamability
                                n_quantity_k = aws_j[:, :, :, :n_tokens_hyp_k].int().sum().item()
                            else:
                                streamable_global = False
                            if n_tokens_hyp_k * n_heads_total == 0:
                                quantity_rate = 0
                            else:
                                quantity_rate = n_quantity_k / (n_tokens_hyp_k * n_heads_total)

                        if beam['streamable'] and not streamable_global:
                            streaming_failed_point = i

                    new_hyps.append(
                        {'hyp': beam['hyp'] + [idx],
                         'ys': torch.cat([beam['ys'], eouts.new_zeros((1, 1), dtype=torch.int64).fill_(idx)], dim=-1),
                         'cache': [new_cache_l[j:j + 1] for new_cache_l in new_cache] if cache_states else cache,
                         'score': total_score,
                         'score_att': total_scores_att[0, idx].item(),
                         'score_ctc': total_scores_ctc[k].item(),
                         'score_lm': total_scores_lm[0, idx].item(),
                         'aws': new_aws,
                         'lmstate': {'hxs': lmstate['hxs'][:, j:j + 1],
                                     'cxs': lmstate['cxs'][:, j:j + 1]} if lmstate is not None else None,
                         'ctc_state': new_ctc_states[k] if ctc_prefix_scorer is not None else None,
                         'ensmbl_cache': [[new_cache_e_l[j:j + 1] for new_cache_e_l in new_cache_e]
                                          for new_cache_e in ensmbl_new_cache] if cache_states else None,
                         'streamable': streamable_global,
                         'streaming_failed_point': streaming_failed_point,
                         'quantity_rate': quantity_rate})

            # Local pruning to the top beam_width candidates
            new_hyps_sorted = sorted(new_hyps, key=lambda x: x['score'], reverse=True)[:beam_width]

            # Remove complete hypotheses
            new_hyps, end_hyps, is_finish = helper.remove_complete_hyp(
                new_hyps_sorted, end_hyps, prune=True)
            hyps = new_hyps[:]
            if is_finish:
                break

        # Global pruning at the end of the utterance
        if len(end_hyps) == 0:
            end_hyps = hyps[:]
        elif len(end_hyps) < nbest and nbest > 1:
            end_hyps.extend(hyps[:nbest - len(end_hyps)])

        # forward/backward second-pass LM rescoring
        end_hyps = helper.lm_rescoring(end_hyps, lm_second, lm_weight_second,
                                       length_norm=length_norm, tag='second')
        end_hyps = helper.lm_rescoring(end_hyps, lm_second_bwd, lm_weight_second_bwd,
                                       length_norm=length_norm, tag='second_bwd')

        # Sort by score
        end_hyps = sorted(end_hyps, key=lambda x: x['score'], reverse=True)

        # TODO:
        for j in range(len(end_hyps[0]['aws'][1:])):
            tmp = end_hyps[0]['aws'][j + 1]
            end_hyps[0]['aws'][j + 1] = tmp.view(1, -1, tmp.size(-2), tmp.size(-1))

        # metrics for streaming inference
        self.streamable = end_hyps[0]['streamable']
        self.quantity_rate = end_hyps[0]['quantity_rate']
        self.last_success_frame_ratio = None

        if idx2token is not None:
            if utt_ids is not None:
                logger.info('Utt-id: %s' % utt_ids[b])
            assert self.vocab == idx2token.vocab
            logger.info('=' * 200)
            for k in range(len(end_hyps)):
                if refs_id is not None:
                    logger.info('Ref: %s' % idx2token(refs_id[b]))
                logger.info('Hyp: %s' % idx2token(
                    end_hyps[k]['hyp'][1:][::-1] if self.bwd else end_hyps[k]['hyp'][1:]))
                logger.info('num tokens (hyp): %d' % len(end_hyps[k]['hyp'][1:]))
                logger.info('log prob (hyp): %.7f' % end_hyps[k]['score'])
                logger.info('log prob (hyp, att): %.7f' % (end_hyps[k]['score_att'] * (1 - ctc_weight)))
                if ctc_prefix_scorer is not None:
                    logger.info('log prob (hyp, ctc): %.7f' % (end_hyps[k]['score_ctc'] * ctc_weight))
                if lm is not None:
                    logger.info('log prob (hyp, first-pass lm): %.7f' % (end_hyps[k]['score_lm'] * lm_weight))
                if lm_second is not None:
                    logger.info('log prob (hyp, second-pass lm): %.7f' % (end_hyps[k]['score_lm_second'] * lm_weight_second))
                if lm_second_bwd is not None:
                    logger.info('log prob (hyp, second-pass lm, reverse): %.7f' % (end_hyps[k]['score_lm_second_bwd'] * lm_weight_second_bwd))
                if self.attn_type == 'mocha':
                    logger.info('streamable: %s' % end_hyps[k]['streamable'])
                    logger.info('streaming failed point: %d' % (end_hyps[k]['streaming_failed_point'] + 1))
                    logger.info('quantity rate [%%]: %.2f' % (end_hyps[k]['quantity_rate'] * 100))
                logger.info('-' * 50)

            if self.attn_type == 'mocha' and end_hyps[0]['streaming_failed_point'] < 1000:
                assert not self.streamable
                aws_last_success = end_hyps[0]['aws'][1:][end_hyps[0]['streaming_failed_point'] - 1]
                rightmost_frame = max(0, aws_last_success[0, :, 0].nonzero()[:, -1].max().item()) + 1
                frame_ratio = rightmost_frame * 100 / xmax
                self.last_success_frame_ratio = frame_ratio
                logger.info('streaming last success frame ratio: %.2f' % frame_ratio)

        # N-best list
        if self.bwd:
            # Reverse the order
            nbest_hyps_idx += [[np.array(end_hyps[n]['hyp'][1:][::-1]) for n in range(nbest)]]
            aws += [[tensor2np(torch.cat(end_hyps[n]['aws'][1:][::-1], dim=2).squeeze(0)) for n in range(nbest)]]
        else:
            nbest_hyps_idx += [[np.array(end_hyps[n]['hyp'][1:]) for n in range(nbest)]]
            aws += [[tensor2np(torch.cat(end_hyps[n]['aws'][1:], dim=2).squeeze(0)) for n in range(nbest)]]
        scores += [[end_hyps[n]['score_att'] for n in range(nbest)]]

        # Check <eos>
        eos_flags.append([(end_hyps[n]['hyp'][-1] == self.eos) for n in range(nbest)])

    # Exclude <eos> (<sos> in case of the backward decoder)
    if exclude_eos:
        if self.bwd:
            nbest_hyps_idx = [[nbest_hyps_idx[b][n][1:] if eos_flags[b][n]
                               else nbest_hyps_idx[b][n] for n in range(nbest)] for b in range(bs)]
            aws = [[aws[b][n][:, 1:] if eos_flags[b][n] else aws[b][n]
                    for n in range(nbest)] for b in range(bs)]
        else:
            nbest_hyps_idx = [[nbest_hyps_idx[b][n][:-1] if eos_flags[b][n]
                               else nbest_hyps_idx[b][n] for n in range(nbest)] for b in range(bs)]
            aws = [[aws[b][n][:, :-1] if eos_flags[b][n] else aws[b][n]
                    for n in range(nbest)] for b in range(bs)]

    # Store ASR/LM state
    if bs == 1:
        self.lmstate_final = end_hyps[0]['lmstate']

    return nbest_hyps_idx, aws, scores
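
# A simplified, standalone sketch of the score interpolation performed in the
# inner loop above. In the real loop the LM term is added before top-K
# selection and the CTC term comes from a prefix scorer via
# helper.add_ctc_score(); here all three terms are combined for a single
# candidate to make the weighting explicit.
def interpolate_scores(score_att, score_ctc, score_lm,
                       ctc_weight, lm_weight, hyp_len, length_norm=False):
    total = score_att * (1 - ctc_weight) + score_ctc * ctc_weight + score_lm * lm_weight
    # normalizing by the current hypothesis length counters the bias of
    # accumulated log probabilities toward short transcripts
    return total / max(hyp_len, 1) if length_norm else total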
def beam_search(self, eouts, elens, params, idx2token=None,
                lm=None, lm_second=None, lm_second_bwd=None, ctc_log_probs=None,
                nbest=1, exclude_eos=False, refs_id=None, utt_ids=None, speakers=None,
                ensmbl_eouts=None, ensmbl_elens=None, ensmbl_decs=[]):
    """Beam search decoding.

    Args:
        eouts (FloatTensor): `[B, T, enc_n_units]`
        elens (IntTensor): `[B]`
        params (dict): hyperparameters for decoding
        idx2token (): converter from index to token
        lm: first-pass LM
        lm_second: second-pass LM
        lm_second_bwd: second-pass backward LM
        ctc_log_probs (FloatTensor): `[B, T, vocab]`
        nbest (int): number of N-best hypotheses
        exclude_eos (bool): exclude <eos> from hypothesis
        refs_id (list): reference list
        utt_ids (list): utterance id list
        speakers (list): speaker list
        ensmbl_eouts (list): list of FloatTensor
        ensmbl_elens (list): list of list
        ensmbl_decs (list): list of torch.nn.Module
    Returns:
        nbest_hyps_idx (list): length `B`, each of which contains list of N hypotheses
        aws: dummy
        scores: dummy

    """
    bs = eouts.size(0)

    beam_width = params['recog_beam_width']
    assert 1 <= nbest <= beam_width
    ctc_weight = params['recog_ctc_weight']
    assert ctc_weight == 0
    assert ctc_log_probs is None
    # length_norm = params['recog_length_norm']
    lm_weight = params['recog_lm_weight']
    lm_weight_second = params['recog_lm_second_weight']
    lm_weight_second_bwd = params['recog_lm_bwd_weight']
    # asr_state_carry_over = params['recog_asr_state_carry_over']
    lm_state_carry_over = params['recog_lm_state_carry_over']

    if lm is not None:
        assert lm_weight > 0
        lm.eval()
    if lm_second is not None:
        assert lm_weight_second > 0
        lm_second.eval()
    if lm_second_bwd is not None:
        assert lm_weight_second_bwd > 0
        lm_second_bwd.eval()

    nbest_hyps_idx = []
    eos_flags = []
    for b in range(bs):
        # Initialization per utterance
        y = eouts.new_zeros((1, 1), dtype=torch.int64).fill_(self.eos)
        y_emb = self.dropout_emb(self.embed(y))
        dout, dstate = self.recurrency(y_emb, None)
        lmstate = None

        if speakers is not None:
            if speakers[b] == self.prev_spk:
                if lm_state_carry_over and isinstance(lm, RNNLM):
                    lmstate = self.lmstate_final
            self.prev_spk = speakers[b]

        helper = BeamSearch(beam_width, self.eos, ctc_weight, eouts.device)

        end_hyps = []
        hyps = [{'hyp': [self.eos],
                 'hyp_str': '',
                 'ys': [self.eos],
                 'score': 0.,
                 'score_rnnt': 0.,
                 'score_lm': 0.,
                 'dout': dout,
                 'dstate': dstate,
                 'lmstate': lmstate}]
        for t in range(elens[b]):
            # batchify all hypotheses for batch decoding
            douts = torch.cat([beam['dout'] for beam in hyps], dim=0)
            logits = self.joint(eouts[b:b + 1, t:t + 1].repeat([douts.size(0), 1, 1]), douts)
            scores_rnnt = torch.log_softmax(logits.squeeze(2).squeeze(1), dim=-1)  # `[B, vocab]`

            new_hyps = []
            for j, beam in enumerate(hyps):
                # Transducer scores
                total_scores_rnnt = beam['score_rnnt'] + scores_rnnt[j:j + 1]
                total_scores_topk, topk_ids = torch.topk(
                    total_scores_rnnt, k=beam_width, dim=-1, largest=True, sorted=True)

                for k in range(beam_width):
                    idx = topk_ids[0, k].item()
                    # length_norm_factor = len(beam['hyp'][1:]) + 1 if length_norm else 1
                    # total_score = total_scores_topk[0, k].item() / length_norm_factor
                    total_score = total_scores_topk[0, k].item()
                    total_score_lm = beam['score_lm']

                    if idx == self.blank:
                        new_hyps.append(beam.copy())
                        new_hyps[-1]['score'] += scores_rnnt[j, self.blank].item()
                        new_hyps[-1]['score_rnnt'] += scores_rnnt[j, self.blank].item()
                        continue

                    # Update prediction network only when predicting non-blank labels
                    hyp_ids = beam['hyp'] + [idx]
                    hyp_str = ' '.join(list(map(str, hyp_ids)))
                    if hyp_str in self.state_cache.keys():
                        # from cache
                        dout = self.state_cache[hyp_str]['dout']
                        dstate = self.state_cache[hyp_str]['dstate']
                        lmstate = self.state_cache[hyp_str]['lmstate']
                        total_score_lm = self.state_cache[hyp_str]['total_score_lm']
                    else:
                        y = eouts.new_zeros((1, 1), dtype=torch.int64).fill_(idx)
                        y_emb = self.dropout_emb(self.embed(y))
                        dout, dstate = self.recurrency(y_emb, beam['dstate'])

                        # Update LM states for shallow fusion
                        y_prev = eouts.new_zeros((1, 1), dtype=torch.int64).fill_(beam['hyp'][-1])
                        _, lmstate, scores_lm = helper.update_rnnlm_state(lm, beam, y_prev)
                        if lm is not None:
                            total_score_lm += scores_lm[0, -1, idx].item()
                            # total_score_lm /= length_norm_factor

                        self.state_cache[hyp_str] = {
                            'dout': dout,
                            'dstate': dstate,
                            'lmstate': {'hxs': lmstate['hxs'],
                                        'cxs': lmstate['cxs']} if lmstate is not None else None,
                            'total_score_lm': total_score_lm,
                        }

                    if lm is not None:
                        total_score += total_score_lm * lm_weight

                    new_hyps.append({'hyp': hyp_ids,
                                     'hyp_str': hyp_str,
                                     'score': total_score,
                                     'score_rnnt': total_scores_rnnt[0, idx].item(),
                                     'score_lm': total_score_lm,
                                     'dout': dout,
                                     'dstate': dstate,
                                     'lmstate': {'hxs': lmstate['hxs'],
                                                 'cxs': lmstate['cxs']} if lmstate is not None else None})

            # Merge hypotheses having the same token sequences
            new_hyps_merged = {}
            for beam in new_hyps:
                hyp_str = ' '.join(list(map(str, beam['hyp'])))
                if hyp_str not in new_hyps_merged.keys():
                    new_hyps_merged[hyp_str] = beam
                elif beam['score'] > new_hyps_merged[hyp_str]['score']:
                    new_hyps_merged[hyp_str] = beam
            new_hyps = [v for v in new_hyps_merged.values()]

            # Local pruning
            new_hyps_sorted = sorted(new_hyps, key=lambda x: x['score'], reverse=True)[:beam_width]

            # Remove complete hypotheses
            new_hyps, end_hyps, is_finish = helper.remove_complete_hyp(new_hyps_sorted, end_hyps)
            hyps = new_hyps[:]
            if is_finish:
                break

        # Global pruning
        if len(end_hyps) == 0:
            end_hyps = hyps[:]
        elif len(end_hyps) < nbest and nbest > 1:
            end_hyps.extend(hyps[:nbest - len(end_hyps)])

        # forward second-pass LM rescoring
        if lm_second is not None:
            self.lm_rescoring(end_hyps, lm_second, lm_weight_second, tag='second')

        # backward second-pass LM rescoring
        if lm_second_bwd is not None:
            self.lm_rescoring(end_hyps, lm_second_bwd, lm_weight_second_bwd, tag='second_bwd')

        # Sort by score, normalized by length
        end_hyps = sorted(end_hyps, key=lambda x: x['score'] / max(len(x['hyp'][1:]), 1), reverse=True)

        # Reset state cache
        self.state_cache = OrderedDict()

        if idx2token is not None:
            if utt_ids is not None:
                logger.info('Utt-id: %s' % utt_ids[b])
            assert self.vocab == idx2token.vocab
            logger.info('=' * 200)
            for k in range(len(end_hyps)):
                if refs_id is not None:
                    logger.info('Ref: %s' % idx2token(refs_id[b]))
                logger.info('Hyp: %s' % idx2token(end_hyps[k]['hyp'][1:]))
                logger.info('log prob (hyp): %.7f' % end_hyps[k]['score'])
                logger.info('log prob (hyp, rnnt): %.7f' % end_hyps[k]['score_rnnt'])
                if lm is not None:
                    logger.info('log prob (hyp, first-pass lm): %.7f' % (end_hyps[k]['score_lm'] * lm_weight))
                if lm_second is not None:
                    logger.info('log prob (hyp, second-pass lm): %.7f' % (end_hyps[k]['score_lm_second'] * lm_weight_second))
                if lm_second_bwd is not None:
                    logger.info('log prob (hyp, second-pass lm, reverse): %.7f' % (end_hyps[k]['score_lm_second_bwd'] * lm_weight_second_bwd))
                logger.info('-' * 50)

        # N-best list
        nbest_hyps_idx += [[np.array(end_hyps[n]['hyp'][1:]) for n in range(nbest)]]

        # Check <eos>
        eos_flags.append([(end_hyps[n]['hyp'][-1] == self.eos) for n in range(nbest)])

    return nbest_hyps_idx, None, None
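
# A minimal sketch of the prediction-network cache used above: states are
# memoized under the space-joined token-id string (`hyp_str`), so hypotheses
# that reach the same label sequence through different blank alignments reuse
# a single forward pass. `compute_state` is a hypothetical stand-in for the
# embed -> recurrency update performed on a cache miss.
def lookup_or_compute(state_cache, hyp_ids, compute_state):
    hyp_str = ' '.join(map(str, hyp_ids))  # same key format as above
    if hyp_str not in state_cache:
        state_cache[hyp_str] = compute_state()
    return state_cache[hyp_str]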
def beam_search(self, eouts, elens, params, idx2token,
                lm=None, lm_second=None, lm_second_bwd=None, ctc_log_probs=None,
                nbest=1, exclude_eos=False, refs_id=None, utt_ids=None, speakers=None,
                ensmbl_eouts=None, ensmbl_elens=None, ensmbl_decs=[]):
    """Beam search decoding.

    Args:
        eouts (FloatTensor): `[B, T, enc_n_units]`
        elens (IntTensor): `[B]`
        params (dict):
            recog_beam_width (int): size of beam
            recog_max_len_ratio (int): maximum sequence length of tokens
            recog_min_len_ratio (float): minimum sequence length of tokens
            recog_length_penalty (float): length penalty
            recog_coverage_penalty (float): coverage penalty
            recog_coverage_threshold (float): threshold for coverage penalty
            recog_lm_weight (float): weight of LM score
        idx2token (): converter from index to token
        lm: first-pass LM
        lm_second: second-pass LM
        lm_second_bwd: second-pass backward LM
        ctc_log_probs (FloatTensor):
        nbest (int):
        exclude_eos (bool): exclude <eos> from hypothesis
        refs_id (list): reference list
        utt_ids (list): utterance id list
        speakers (list): speaker list
        ensmbl_eouts (list): list of FloatTensor
        ensmbl_elens (list): list of list
        ensmbl_decs (list): list of torch.nn.Module
    Returns:
        nbest_hyps_idx (list): length `B`, each of which contains list of N hypotheses
        aws: dummy
        scores: dummy

    """
    bs = eouts.size(0)

    beam_width = params['recog_beam_width']
    ctc_weight = params['recog_ctc_weight']
    lm_weight = params['recog_lm_weight']
    lm_weight_second = params['recog_lm_second_weight']
    lm_weight_second_bwd = params['recog_lm_bwd_weight']
    asr_state_carry_over = params['recog_asr_state_carry_over']
    lm_state_carry_over = params['recog_lm_state_carry_over']

    if lm is not None:
        assert lm_weight > 0
        lm.eval()
    if lm_second is not None:
        assert lm_weight_second > 0
        lm_second.eval()
    if lm_second_bwd is not None:
        assert lm_weight_second_bwd > 0
        lm_second_bwd.eval()

    if ctc_log_probs is not None:
        assert ctc_weight > 0
        ctc_log_probs = tensor2np(ctc_log_probs)

    nbest_hyps_idx = []
    eos_flags = []
    for b in range(bs):
        # Initialization per utterance
        y = eouts.new_zeros(1, 1).fill_(self.eos).long()
        y_emb = self.dropout_emb(self.embed(y))
        dout, dstate = self.recurrency(y_emb, None)
        lmstate = None

        # For joint CTC-Attention decoding
        ctc_prefix_scorer = None
        if ctc_log_probs is not None:
            ctc_prefix_scorer = CTCPrefixScore(ctc_log_probs[b], self.blank, self.eos)

        if speakers is not None:
            if speakers[b] == self.prev_spk:
                if lm_state_carry_over and isinstance(lm, RNNLM):
                    lmstate = self.lmstate_final
            self.prev_spk = speakers[b]

        helper = BeamSearch(beam_width, self.eos, ctc_weight, self.device_id)

        end_hyps = []
        hyps = [{'hyp': [self.eos],
                 'ref_id': [self.eos],
                 'score': 0.,
                 'score_rnnt': 0.,
                 'score_lm': 0.,
                 'score_ctc': 0.,
                 'dout': dout,
                 'dstate': dstate,
                 'lmstate': lmstate,
                 'ctc_state': ctc_prefix_scorer.initial_state() if ctc_prefix_scorer is not None else None}]
        for t in range(elens[b]):
            # preprocess for batch decoding
            douts = torch.cat([beam['dout'] for beam in hyps], dim=0)
            outs = self.joint(eouts[b:b + 1, t:t + 1].repeat([douts.size(0), 1, 1]), douts)
            scores_rnnt = torch.log_softmax(outs.squeeze(2).squeeze(1), dim=-1)

            # Update LM states for shallow fusion
            y = eouts.new_zeros(len(hyps), 1).long()
            for j, beam in enumerate(hyps):
                y[j, 0] = beam['hyp'][-1]
            lmstate, scores_lm = None, None
            if lm is not None:
                if hyps[0]['lmstate'] is not None:
                    lm_hxs = torch.cat([beam['lmstate']['hxs'] for beam in hyps], dim=1)
                    lm_cxs = torch.cat([beam['lmstate']['cxs'] for beam in hyps], dim=1)
                    lmstate = {'hxs': lm_hxs, 'cxs': lm_cxs}
                lmout, lmstate, scores_lm = lm.predict(y, lmstate)

            new_hyps = []
            for j, beam in enumerate(hyps):
                dout = douts[j:j + 1]
                dstate = beam['dstate']
                lmstate = beam['lmstate']

                # Transducer scores
                total_scores_rnnt = beam['score_rnnt'] + scores_rnnt[j:j + 1]
                total_scores = total_scores_rnnt * (1 - ctc_weight)

                # Add LM score <after> top-K selection
                total_scores_topk, topk_ids = torch.topk(
                    total_scores, k=beam_width, dim=-1, largest=True, sorted=True)
                if lm is not None:
                    total_scores_lm = beam['score_lm'] + scores_lm[j, -1, topk_ids[0]]
                    total_scores_topk += total_scores_lm * lm_weight
                else:
                    total_scores_lm = eouts.new_zeros(beam_width)

                # Add CTC score
                new_ctc_states, total_scores_ctc, total_scores_topk = helper.add_ctc_score(
                    beam['hyp'], topk_ids, beam['ctc_state'],
                    total_scores_topk, ctc_prefix_scorer)

                for k in range(beam_width):
                    idx = topk_ids[0, k].item()

                    if idx == self.blank:
                        beam['score'] = total_scores_topk[0, k].item()
                        beam['score_rnnt'] = total_scores_topk[0, k].item()
                        new_hyps.append(beam.copy())
                        continue

                    # skip blank-dominant frames
                    # if total_scores_topk[0, self.blank].item() > 0.7:
                    #     continue

                    # Update prediction network only when predicting non-blank labels
                    hyp_id = beam['hyp'] + [idx]
                    hyp_str = ' '.join(list(map(str, hyp_id)))
                    # if hyp_str in self.state_cache.keys():
                    #     # from cache
                    #     dout = self.state_cache[hyp_str]['dout']
                    #     new_dstate = self.state_cache[hyp_str]['dstate']
                    #     lmstate = self.state_cache[hyp_str]['lmstate']
                    # else:
                    y = eouts.new_zeros(1, 1).fill_(idx).long()
                    y_emb = self.dropout_emb(self.embed(y))
                    dout, new_dstate = self.recurrency(y_emb, dstate)

                    # store in cache
                    self.state_cache[hyp_str] = {
                        'dout': dout,
                        'dstate': new_dstate,
                        'lmstate': {'hxs': lmstate['hxs'][:, j:j + 1],
                                    'cxs': lmstate['cxs'][:, j:j + 1]} if lmstate is not None else None,
                    }

                    new_hyps.append({'hyp': hyp_id,
                                     'score': total_scores_topk[0, k].item(),
                                     'score_rnnt': total_scores_rnnt[0, idx].item(),
                                     'score_ctc': total_scores_ctc[k].item(),
                                     'score_lm': total_scores_lm[k].item(),
                                     'dout': dout,
                                     'dstate': new_dstate,
                                     'lmstate': {'hxs': lmstate['hxs'][:, j:j + 1],
                                                 'cxs': lmstate['cxs'][:, j:j + 1]} if lmstate is not None else None,
                                     'ctc_state': new_ctc_states[k] if ctc_prefix_scorer is not None else None})

            # Merge hypotheses having the same token sequences
            new_hyps_merged = {}
            for beam in new_hyps:
                hyp_str = ' '.join(list(map(str, beam['hyp'])))
                if hyp_str not in new_hyps_merged.keys():
                    new_hyps_merged[hyp_str] = beam
                elif beam['score'] > new_hyps_merged[hyp_str]['score']:
                    new_hyps_merged[hyp_str] = beam
            new_hyps = [v for v in new_hyps_merged.values()]

            # Local pruning
            new_hyps_tmp = sorted(new_hyps, key=lambda x: x['score'], reverse=True)[:beam_width]

            # Remove complete hypotheses
            new_hyps = []
            for hyp in new_hyps_tmp:
                new_hyps += [hyp]
            if len(end_hyps) >= beam_width:
                end_hyps = end_hyps[:beam_width]
                break
            hyps = new_hyps[:]

        # Global pruning
        if len(end_hyps) == 0:
            end_hyps = hyps[:]
        elif len(end_hyps) < nbest and nbest > 1:
            end_hyps.extend(hyps[:nbest - len(end_hyps)])

        # forward second-pass LM rescoring
        if lm_second is not None:
            self.lm_rescoring(end_hyps, lm_second, lm_weight_second, tag='second')

        # backward second-pass LM rescoring
        if lm_second_bwd is not None:
            self.lm_rescoring(end_hyps, lm_second_bwd, lm_weight_second_bwd, tag='second_rev')

        end_hyps = sorted(end_hyps, key=lambda x: x['score'], reverse=True)

        # Reset state cache
        self.state_cache = OrderedDict()

        if utt_ids is not None:
            logger.info('Utt-id: %s' % utt_ids[b])
        if idx2token is not None:
            logger.info('=' * 200)
            for k in range(len(end_hyps)):
                if refs_id is not None and self.vocab == idx2token.vocab:
                    logger.info('Ref: %s' % idx2token(refs_id[b]))
                logger.info('Hyp: %s' % idx2token(end_hyps[k]['hyp'][1:]))
                logger.info('log prob (hyp): %.7f' % end_hyps[k]['score'])
                if ctc_log_probs is not None:
                    logger.info('log prob (hyp, ctc): %.7f' % (end_hyps[k]['score_ctc']))
                if lm is not None:
                    logger.info('log prob (hyp, lm): %.7f' % (end_hyps[k]['score_lm']))
                logger.info('-' * 50)

        # N-best list
        nbest_hyps_idx += [[np.array(end_hyps[n]['hyp'][1:]) for n in range(nbest)]]

        # Check <eos>
        eos_flags.append([(end_hyps[n]['hyp'][-1] == self.eos) for n in range(nbest)])

    return nbest_hyps_idx, None, None
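
# The recombination step above, as a standalone sketch: hypotheses with
# identical token sequences are merged and only the highest-scoring one
# survives, which keeps the beam from filling up with duplicates that differ
# only in their blank alignment.
def merge_duplicate_hyps(hyps):
    merged = {}
    for beam in hyps:
        key = ' '.join(map(str, beam['hyp']))
        if key not in merged or beam['score'] > merged[key]['score']:
            merged[key] = beam
    return list(merged.values())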
def beam_search(self, eouts, elens, params, idx2token,
                lm=None, lm_second=None, lm_second_bwd=None,
                nbest=1, refs_id=None, utt_ids=None, speakers=None):
    """Beam search decoding.

    Args:
        eouts (FloatTensor): `[B, T, enc_n_units]`
        elens (List): length `B`
        params (dict):
            recog_beam_width (int): size of beam
            recog_length_penalty (float): length penalty
            recog_lm_weight (float): weight of first-pass LM score
            recog_lm_second_weight (float): weight of second-pass LM score
            recog_lm_bwd_weight (float): weight of second-pass backward LM score
        idx2token (): converter from index to token
        lm: first-pass LM
        lm_second: second-pass LM
        lm_second_bwd: second-pass backward LM
        nbest (int):
        refs_id (List): reference list
        utt_ids (List): utterance id list
        speakers (List): speaker list
    Returns:
        nbest_hyps_idx (List[List[List]]): best path hypotheses

    """
    bs = eouts.size(0)

    beam_width = params['recog_beam_width']
    lp_weight = params['recog_length_penalty']
    lm_weight = params['recog_lm_weight']
    lm_weight_second = params['recog_lm_second_weight']
    lm_weight_second_bwd = params['recog_lm_bwd_weight']

    helper = BeamSearch(beam_width, self.eos, 1.0, eouts.device)
    lm = helper.verify_lm_eval_mode(lm, lm_weight)
    lm_second = helper.verify_lm_eval_mode(lm_second, lm_weight_second)
    lm_second_bwd = helper.verify_lm_eval_mode(lm_second_bwd, lm_weight_second_bwd)

    nbest_hyps_idx = []
    log_probs = torch.log_softmax(self.output(eouts), dim=-1)
    for b in range(bs):
        # Elements in the beam are (prefix, (p_b, p_no_blank))
        # Initialize the beam with the empty sequence, a probability of
        # 1 for ending in blank and zero for ending in non-blank (in log space).
        beam = [{'hyp': [self.eos],  # <eos> is used for LM
                 'p_b': LOG_1,
                 'p_nb': LOG_0,
                 'score_lm': LOG_1,
                 'lmstate': None}]

        for t in range(elens[b]):
            new_beam = []

            # Pick up the top-k scores
            log_probs_topk, topk_ids = torch.topk(
                log_probs[b:b + 1, t], k=min(beam_width, self.vocab), dim=-1, largest=True, sorted=True)

            for i_beam in range(len(beam)):
                hyp = beam[i_beam]['hyp'][:]
                p_b = beam[i_beam]['p_b']
                p_nb = beam[i_beam]['p_nb']
                score_lm = beam[i_beam]['score_lm']

                # case 1. hyp is not extended
                new_p_b = np.logaddexp(p_b + log_probs[b, t, self.blank].item(),
                                       p_nb + log_probs[b, t, self.blank].item())
                if len(hyp) > 1:
                    new_p_nb = p_nb + log_probs[b, t, hyp[-1]].item()
                else:
                    new_p_nb = LOG_0
                score_ctc = np.logaddexp(new_p_b, new_p_nb)
                score_lp = len(hyp[1:]) * lp_weight
                new_beam.append({'hyp': hyp,
                                 'score': score_ctc + score_lm + score_lp,
                                 'p_b': new_p_b,
                                 'p_nb': new_p_nb,
                                 'score_ctc': score_ctc,
                                 'score_lm': score_lm,
                                 'score_lp': score_lp,
                                 'lmstate': beam[i_beam]['lmstate']})

                # Update LM states for shallow fusion
                if lm is not None:
                    _, lmstate, lm_log_probs = lm.predict(
                        eouts.new_zeros(1, 1).fill_(hyp[-1]), beam[i_beam]['lmstate'])
                else:
                    lmstate = None

                # case 2. hyp is extended
                new_p_b = LOG_0
                for c in tensor2np(topk_ids)[0]:
                    p_t = log_probs[b, t, c].item()

                    if c == self.blank:
                        continue

                    c_prev = hyp[-1] if len(hyp) > 1 else None
                    if c == c_prev:
                        new_p_nb = p_b + p_t
                        # TODO(hirofumi): apply character LM here
                    else:
                        new_p_nb = np.logaddexp(p_b + p_t, p_nb + p_t)
                        # TODO(hirofumi): apply character LM here
                    if c == self.space:
                        pass
                        # TODO(hirofumi): apply word LM here

                    score_ctc = np.logaddexp(new_p_b, new_p_nb)
                    score_lp = (len(hyp[1:]) + 1) * lp_weight
                    if lm_weight > 0 and lm is not None:
                        local_score_lm = lm_log_probs[0, 0, c].item() * lm_weight
                        score_lm += local_score_lm
                    new_beam.append({'hyp': hyp + [c],
                                     'score': score_ctc + score_lm + score_lp,
                                     'p_b': new_p_b,
                                     'p_nb': new_p_nb,
                                     'score_ctc': score_ctc,
                                     'score_lm': score_lm,
                                     'score_lp': score_lp,
                                     'lmstate': lmstate})

            # Pruning
            beam = sorted(new_beam, key=lambda x: x['score'], reverse=True)[:beam_width]

        # forward second-pass LM rescoring
        helper.lm_rescoring(beam, lm_second, lm_weight_second, tag='second')

        # backward second-pass LM rescoring
        helper.lm_rescoring(beam, lm_second_bwd, lm_weight_second_bwd, tag='second_bwd')

        # Exclude <eos>
        nbest_hyps_idx.append([hyp['hyp'][1:] for hyp in beam])

        if idx2token is not None:
            if utt_ids is not None:
                logger.info('Utt-id: %s' % utt_ids[b])
            assert self.vocab == idx2token.vocab
            logger.info('=' * 200)
            for k in range(len(beam)):
                if refs_id is not None:
                    logger.info('Ref: %s' % idx2token(refs_id[b]))
                logger.info('Hyp: %s' % idx2token(beam[k]['hyp'][1:]))
                logger.info('log prob (hyp): %.7f' % beam[k]['score'])
                logger.info('log prob (hyp, ctc): %.7f' % (beam[k]['score_ctc']))
                logger.info('log prob (hyp, lp): %.7f' % (beam[k]['score_lp'] * lp_weight))
                if lm is not None:
                    logger.info('log prob (hyp, first-pass lm): %.7f' % (beam[k]['score_lm'] * lm_weight))
                if lm_second is not None:
                    logger.info('log prob (hyp, second-pass lm): %.7f' % (beam[k]['score_lm_second'] * lm_weight_second))
                logger.info('-' * 50)

    return nbest_hyps_idx
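
# A worked example of the blank/non-blank bookkeeping above, with assumed
# frame posteriors (not real model outputs). Each prefix keeps two log
# probabilities: ending in blank (p_b) and ending in non-blank (p_nb).
# Keeping the prefix unchanged at frame t folds the blank mass of both
# states into the new p_b.
_p_b, _p_nb = LOG_1, LOG_0                # empty prefix: all mass ends in blank
_log_p_blank = np.log(0.6)                # assumed P(blank) at this frame
_new_p_b = np.logaddexp(_p_b + _log_p_blank, _p_nb + _log_p_blank)
# _new_p_b == log(0.6): both paths re-enter the "ends in blank" state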
def beam_search(self, eouts, elens, params, idx2token,
                lm=None, lm_second=None, lm_second_bwd=None,
                nbest=1, refs_id=None, utt_ids=None, speakers=None):
    """Beam search decoding.

    Args:
        eouts (FloatTensor): `[B, T, enc_n_units]`
        elens (List): length `[B]`
        params (dict): decoding hyperparameters
        idx2token (): converter from index to token
        lm (torch.nn.module): first-pass LM
        lm_second (torch.nn.module): second-pass LM
        lm_second_bwd (torch.nn.module): second-pass backward LM
        nbest (int): number of N-best hypotheses
        refs_id (List): reference list
        utt_ids (List): utterance id list
        speakers (List): speaker list
    Returns:
        nbest_hyps_idx (List[List[List]]): best path hypotheses

    """
    bs = eouts.size(0)

    beam_width = params.get('recog_beam_width')
    lp_weight = params.get('recog_length_penalty')
    cache_emb = params.get('recog_cache_embedding')
    lm_weight = params.get('recog_lm_weight')
    lm_weight_second = params.get('recog_lm_second_weight')
    lm_weight_second_bwd = params.get('recog_lm_bwd_weight')
    lm_state_CO = params.get('recog_lm_state_carry_over')
    softmax_smoothing = params.get('recog_softmax_smoothing')

    helper = BeamSearch(beam_width, self.eos, 1.0, lm_weight, eouts.device)
    lm = helper.verify_lm_eval_mode(lm, lm_weight, cache_emb)
    if lm is not None:
        assert isinstance(lm, RNNLM)
    lm_second = helper.verify_lm_eval_mode(lm_second, lm_weight_second, cache_emb)
    lm_second_bwd = helper.verify_lm_eval_mode(lm_second_bwd, lm_weight_second_bwd, cache_emb)

    log_probs = torch.log_softmax(self.output(eouts) * softmax_smoothing, dim=-1)

    nbest_hyps_idx = []
    for b in range(bs):
        # Initialization per utterance
        lmstate = {'hxs': eouts.new_zeros(lm.n_layers, 1, lm.n_units),
                   'cxs': eouts.new_zeros(lm.n_layers, 1, lm.n_units)} if lm is not None else None

        if speakers is not None:
            if speakers[b] == self.prev_spk:
                if lm_state_CO:
                    lmstate = self.lmstate_final
            self.prev_spk = speakers[b]

        hyps = self.initialize_beam([self.eos], lmstate)
        self.state_cache = OrderedDict()

        hyps, new_hyps_sorted = self._beam_search(hyps, helper, log_probs[b], lm, lp_weight)

        # Global pruning
        end_hyps = hyps[:]
        if len(end_hyps) < nbest and nbest > 1:
            end_hyps.extend(new_hyps_sorted[:nbest - len(end_hyps)])

        # forward/backward second-pass LM rescoring
        end_hyps = helper.lm_rescoring(end_hyps, lm_second, lm_weight_second, tag='second')
        end_hyps = helper.lm_rescoring(end_hyps, lm_second_bwd, lm_weight_second_bwd, tag='second_bwd')

        # Normalize by length
        end_hyps = sorted(end_hyps, key=lambda x: x['score'] / max(len(x['hyp'][1:]), 1), reverse=True)

        if idx2token is not None:
            if utt_ids is not None:
                logger.info('Utt-id: %s' % utt_ids[b])
            assert self.vocab == idx2token.vocab
            logger.info('=' * 200)
            for k in range(len(end_hyps)):
                if refs_id is not None:
                    logger.info('Ref: %s' % idx2token(refs_id[b]))
                logger.info('Hyp: %s' % idx2token(end_hyps[k]['hyp'][1:]))
                logger.info('log prob (hyp): %.7f' % end_hyps[k]['score'])
                logger.info('log prob (hyp, ctc): %.7f' % (end_hyps[k]['score_ctc']))
                logger.info('log prob (hyp, lp): %.7f' % (end_hyps[k]['score_lp'] * lp_weight))
                if lm is not None:
                    logger.info('log prob (hyp, first-pass lm): %.7f' % (end_hyps[k]['score_lm'] * lm_weight))
                if lm_second is not None:
                    logger.info('log prob (hyp, second-pass lm): %.7f' % (end_hyps[k]['score_lm_second'] * lm_weight_second))
                if lm_second_bwd is not None:
                    logger.info('log prob (hyp, second-pass lm, reverse): %.7f' % (end_hyps[k]['score_lm_second_bwd'] * lm_weight_second_bwd))
                logger.info('-' * 50)

        # N-best list (exclude <eos>)
        nbest_hyps_idx += [[np.array(end_hyps[n]['hyp'][1:]) for n in range(nbest)]]

    # Store LM state
    if bs == 1:
        self.lmstate_final = end_hyps[0]['lmstate']

    return nbest_hyps_idx
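
# Why the final sort divides by hypothesis length: log probabilities are
# negative and accumulate per token, so an unnormalized sort systematically
# prefers shorter transcripts. Illustrative scores:
_short = {'hyp': [2, 10, 11], 'score': -4.0}           # 2 tokens after <eos>
_long = {'hyp': [2, 10, 11, 12, 13], 'score': -6.0}    # 4 tokens after <eos>
_ranked = sorted([_short, _long],
                 key=lambda x: x['score'] / max(len(x['hyp'][1:]), 1),
                 reverse=True)
# _ranked[0] is _long: -1.5 per token beats -2.0 per token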
def beam_search(self, eouts, elens, params, idx2token=None,
                lm=None, lm_second=None, lm_second_bwd=None, ctc_log_probs=None,
                nbest=1, exclude_eos=False, refs_id=None, utt_ids=None, speakers=None,
                ensmbl_eouts=[], ensmbl_elens=[], ensmbl_decs=[]):
    """Beam search decoding.

    Args:
        eouts (FloatTensor): `[B, T, enc_n_units]`
        elens (IntTensor): `[B]`
        params (dict): decoding hyperparameters
        idx2token (): converter from index to token
        lm (torch.nn.module): first-pass LM
        lm_second (torch.nn.module): second-pass LM
        lm_second_bwd (torch.nn.module): second-pass backward LM
        ctc_log_probs (FloatTensor): `[B, T, vocab]`
        nbest (int): number of N-best hypotheses
        exclude_eos (bool): exclude <eos> from hypothesis
        refs_id (List): reference list
        utt_ids (List): utterance id list
        speakers (List): speaker list
        ensmbl_eouts (List[FloatTensor]): encoder outputs for ensemble models
        ensmbl_elens (List[IntTensor]): encoder output lengths for ensemble models
        ensmbl_decs (List[torch.nn.Module]): decoders for ensemble models
    Returns:
        nbest_hyps_idx (List): length `[B]`, each of which contains list of N hypotheses
        aws: dummy
        scores: dummy

    """
    bs = eouts.size(0)

    beam_width = params.get('recog_beam_width')
    assert 1 <= nbest <= beam_width
    ctc_weight = params.get('recog_ctc_weight')
    assert ctc_weight == 0
    assert ctc_log_probs is None
    cache_emb = params.get('recog_cache_embedding')
    lm_weight = params.get('recog_lm_weight')
    lm_weight_second = params.get('recog_lm_second_weight')
    lm_weight_second_bwd = params.get('recog_lm_bwd_weight')
    lm_state_CO = params.get('recog_lm_state_carry_over')
    softmax_smoothing = params.get('recog_softmax_smoothing')
    beam_search_type = params.get('recog_rnnt_beam_search_type')

    helper = BeamSearch(beam_width, self.eos, ctc_weight, lm_weight, eouts.device)
    lm = helper.verify_lm_eval_mode(lm, lm_weight, cache_emb)
    if lm is not None:
        assert isinstance(lm, RNNLM)
    lm_second = helper.verify_lm_eval_mode(lm_second, lm_weight_second, cache_emb)
    lm_second_bwd = helper.verify_lm_eval_mode(lm_second_bwd, lm_weight_second_bwd, cache_emb)

    # cache token embeddings
    if cache_emb:
        self.cache_embedding(eouts.device)

    nbest_hyps_idx = []
    for b in range(bs):
        # Initialization per utterance
        dstate = {'hxs': eouts.new_zeros(self.n_layers, 1, self.dec_n_units),
                  'cxs': eouts.new_zeros(self.n_layers, 1, self.dec_n_units)}
        lmstate = {'hxs': eouts.new_zeros(lm.n_layers, 1, lm.n_units),
                   'cxs': eouts.new_zeros(lm.n_layers, 1, lm.n_units)} if lm is not None else None

        if speakers is not None:
            if speakers[b] == self.prev_spk:
                if lm_state_CO:
                    lmstate = self.lmstate_final
            self.prev_spk = speakers[b]

        end_hyps = []
        hyps = self.initialize_beam([self.eos], dstate, lmstate)
        self.state_cache = OrderedDict()

        if beam_search_type == 'time_sync_mono':
            hyps, new_hyps = self._time_sync_mono(
                hyps, helper, eouts[b:b + 1, :elens[b]], softmax_smoothing, lm)
        elif beam_search_type == 'time_sync':
            hyps, new_hyps = self._time_sync(
                hyps, helper, eouts[b:b + 1, :elens[b]], softmax_smoothing, lm)
        else:
            raise NotImplementedError(beam_search_type)

        # Global pruning
        end_hyps = hyps[:]
        if len(end_hyps) < nbest and nbest > 1:
            end_hyps.extend(new_hyps[:nbest - len(end_hyps)])

        # forward/backward second-pass LM rescoring
        end_hyps = helper.lm_rescoring(end_hyps, lm_second, lm_weight_second, tag='second')
        end_hyps = helper.lm_rescoring(end_hyps, lm_second_bwd, lm_weight_second_bwd, tag='second_bwd')

        # Normalize by length
        end_hyps = sorted(end_hyps, key=lambda x: x['score'] / max(len(x['hyp'][1:]), 1), reverse=True)
        # NOTE: See Algorithm 1 in https://arxiv.org/abs/1211.3711

        if idx2token is not None:
            if utt_ids is not None:
                logger.info('Utt-id: %s' % utt_ids[b])
            assert self.vocab == idx2token.vocab
            logger.info('=' * 200)
            for k in range(len(end_hyps)):
                if refs_id is not None:
                    logger.info('Ref: %s' % idx2token(refs_id[b]))
                logger.info('Hyp: %s' % idx2token(end_hyps[k]['hyp'][1:]))
                if len(end_hyps[k]['hyp']) > 1:
                    logger.info('num tokens (hyp): %d' % len(end_hyps[k]['hyp'][1:]))
                logger.info('log prob (hyp): %.7f' % end_hyps[k]['score'])
                logger.info('log prob (hyp, rnnt): %.7f' % end_hyps[k]['score_rnnt'])
                if lm is not None:
                    logger.info('log prob (hyp, first-pass lm): %.7f' % (end_hyps[k]['score_lm'] * lm_weight))
                if lm_second is not None:
                    logger.info('log prob (hyp, second-pass lm): %.7f' % (end_hyps[k]['score_lm_second'] * lm_weight_second))
                if lm_second_bwd is not None:
                    logger.info('log prob (hyp, second-pass lm, reverse): %.7f' % (end_hyps[k]['score_lm_second_bwd'] * lm_weight_second_bwd))
                logger.info('-' * 50)

        # N-best list (exclude <eos>)
        nbest_hyps_idx += [[np.array(end_hyps[n]['hyp'][1:]) for n in range(nbest)]]

    # Store ASR/LM state
    if bs == 1:
        self.dstates_final = end_hyps[0]['dstate']
        self.lmstate_final = end_hyps[0]['lmstate']

    return nbest_hyps_idx, None, None
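
# A sketch of the per-frame expansion rule behind _time_sync() (Algorithm 1
# of Graves 2012, cited above): emitting blank freezes the hypothesis and
# advances to the next encoder frame, while emitting a label extends the
# prefix and stays on the same frame. Hypothetical, simplified signature.
def expand_one_frame(hyp, log_probs_t, blank_id, beam_width):
    topk_scores, topk_ids = torch.topk(log_probs_t, k=beam_width)
    frozen, extended = [], []
    for score, idx in zip(topk_scores.tolist(), topk_ids.tolist()):
        if idx == blank_id:
            frozen.append({'hyp': hyp['hyp'], 'score': hyp['score'] + score})
        else:
            extended.append({'hyp': hyp['hyp'] + [idx], 'score': hyp['score'] + score})
    return frozen, extended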
def decode_streaming(self, xs, params, idx2token, exclude_eos=False, task='ys'):
    """Simulate streaming encoding and decoding, both performed in online mode."""
    block_size = params.get('recog_block_sync_size')  # before subsampling
    cache_emb = params.get('recog_cache_embedding')
    ctc_weight = params.get('recog_ctc_weight')

    assert task == 'ys'
    assert self.input_type == 'speech'
    assert self.ctc_weight > 0
    assert self.fwd_weight > 0 or self.ctc_weight == 1.0
    assert len(xs) == 1  # batch size
    assert params.get('recog_block_sync')
    # assert params.get('recog_length_norm')

    streaming = Streaming(xs[0], params, self.enc)
    factor = self.enc.subsampling_factor
    block_size //= factor
    assert block_size >= 1, "block_size is too small."

    hyps = None
    best_hyp_id_stream = []
    best_hyp_id_prefix = []  # initialized here; the loop below may consult it before the first hypothesis appears
    is_reset = True  # for the first block
    stdout = False

    self.eval()
    helper = BeamSearch(params.get('recog_beam_width'), self.eos,
                        params.get('recog_ctc_weight'),
                        params.get('recog_lm_weight'), self.device)

    lm = getattr(self, 'lm_fwd', None)
    lm_second = getattr(self, 'lm_second', None)
    lm = helper.verify_lm_eval_mode(lm, params.get('recog_lm_weight'), cache_emb)
    if lm is not None:
        assert isinstance(lm, RNNLM)
    lm_second = helper.verify_lm_eval_mode(
        lm_second, params.get('recog_lm_second_weight'), cache_emb)

    # cache token embeddings
    if cache_emb and self.fwd_weight > 0:
        self.dec_fwd.cache_embedding(self.device)

    while True:
        # Encode input features block by block
        x_block, is_last_block, cnn_lookback, cnn_lookahead, xlen_block = streaming.extract_feature()
        if is_reset:
            self.enc.reset_cache()
        eout_block_dict = self.encode([x_block], 'all', streaming=True,
                                      cnn_lookback=cnn_lookback,
                                      cnn_lookahead=cnn_lookahead,
                                      xlen_block=xlen_block)
        eout_block = eout_block_dict[task]['xs']
        is_reset = False  # detect the first boundary in the same block

        # CTC-based VAD
        if streaming.is_ctc_vad:
            if self.ctc_weight_sub1 > 0:
                ctc_probs_block = self.dec_fwd_sub1.ctc.probs(
                    eout_block_dict['ys_sub1']['xs'])
                # TODO: consider subsampling
            else:
                ctc_probs_block = self.dec_fwd.ctc.probs(eout_block)
            is_reset = streaming.ctc_vad(ctc_probs_block, stdout=stdout)

        # Truncate the rightmost frames
        if is_reset and not is_last_block and streaming.bd_offset >= 0:
            eout_block = eout_block[:, :streaming.bd_offset]
        streaming.cache_eout(eout_block)

        # Block-synchronous decoding
        if ctc_weight == 1 or self.ctc_weight == 1:
            end_hyps, hyps = self.dec_fwd.ctc.beam_search_block_sync(
                eout_block, params, helper, idx2token, hyps, lm)
        elif isinstance(self.dec_fwd, RNNT):
            end_hyps, hyps = self.dec_fwd.beam_search_block_sync(
                eout_block, params, helper, idx2token, hyps, lm)
        elif isinstance(self.dec_fwd, RNNDecoder):
            for i in range(math.ceil(eout_block.size(1) / block_size)):
                eout_block_i = eout_block[:, i * block_size:(i + 1) * block_size]
                end_hyps, hyps, _ = self.dec_fwd.beam_search_block_sync(
                    eout_block_i, params, helper, idx2token, hyps, lm)
        elif isinstance(self.dec_fwd, TransformerDecoder):
            raise NotImplementedError
        else:
            raise NotImplementedError(self.dec_fwd)

        merged_hyps = sorted(end_hyps + hyps, key=lambda x: x['score'], reverse=True)
        if len(merged_hyps) > 0:
            best_hyp_id_prefix = np.array(merged_hyps[0]['hyp'][1:])

            if len(hyps) == 0 or (len(best_hyp_id_prefix) > 0
                                  and best_hyp_id_prefix[-1] == self.eos):
                # reset beam if <eos> is generated from the best hypothesis
                best_hyp_id_prefix = best_hyp_id_prefix[:-1]  # exclude <eos>
                # Segmentation strategy 2:
                # If <eos> is emitted from the decoder (not CTC),
                # the current block is segmented.
                if not is_reset:
                    streaming._bd_offset = eout_block.size(1) - 1  # TODO: fix later
                    is_reset = True

            if len(best_hyp_id_prefix) > 0:
                n_frames = self.dec_fwd.ctc.n_frames \
                    if ctc_weight == 1 or self.ctc_weight == 1 else self.dec_fwd.n_frames
                print('\rStreaming (T:%d [10ms], offset:%d [10ms], blank:%d [10ms]): %s' %
                      (streaming.offset + eout_block.size(1) * factor,
                       n_frames * factor,
                       streaming.n_blanks * factor,
                       idx2token(best_hyp_id_prefix)))

        if is_reset:
            # pick up the best hyp from ended and active hypotheses
            if len(best_hyp_id_prefix) > 0:
                best_hyp_id_stream.extend(best_hyp_id_prefix)
            # reset
            streaming.reset(stdout=stdout)
            hyps = None

        streaming.next_block()
        if is_last_block:
            break
        # next block will start from the frame next to the boundary
        streaming.backoff(x_block, self.dec_fwd, stdout=stdout)

    # pick up the best hyp
    if not is_reset and len(best_hyp_id_prefix) > 0:
        best_hyp_id_stream.extend(best_hyp_id_prefix)

    if len(best_hyp_id_stream) > 0:
        return [[np.stack(best_hyp_id_stream, axis=0)]], [None]
    else:
        return [[[]]], [None]
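# ---------------------------------------------------------------------------
# Illustration (hypothetical; not the toolkit's implementation): the
# streaming.ctc_vad() call above decides segment boundaries from CTC
# posteriors. A minimal stand-in, sketched below, resets once the block ends
# in a long run of blank frames; the real Streaming class additionally tracks
# spike positions and sets a boundary offset (bd_offset). `blank_id=0` and
# the 40-frame threshold are assumptions made for this sketch.
import torch


def naive_ctc_vad(ctc_probs_block, blank_id=0, max_n_blanks=40):
    """Return True if the block ends with >= `max_n_blanks` blank frames.

    ctc_probs_block (FloatTensor): `[1, T_block, vocab]` CTC posteriors
    """
    best_ids = ctc_probs_block[0].argmax(dim=-1)  # `[T_block]` greedy path
    n_trailing_blanks = 0
    for t in range(best_ids.size(0) - 1, -1, -1):  # scan from the block tail
        if best_ids[t].item() != blank_id:
            break
        n_trailing_blanks += 1
    return n_trailing_blanks >= max_n_blanks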
def decode_streaming(self, xs, params, idx2token, exclude_eos=False, speaker=None, task='ys'):
    """Simulate streaming encoding and decoding, both performed in online mode."""
    self.eval()
    block_size = params.get('recog_block_sync_size')  # before subsampling
    cache_emb = params.get('recog_cache_embedding')
    ctc_weight = params.get('recog_ctc_weight')
    backoff = True

    assert task == 'ys'
    assert self.input_type == 'speech'
    assert self.ctc_weight > 0
    assert self.fwd_weight > 0 or self.ctc_weight == 1.0
    assert len(xs) == 1  # batch size
    assert params.get('recog_block_sync')
    # assert params.get('recog_length_norm')

    streaming = Streaming(xs[0], params, self.enc, idx2token)
    factor = self.enc.subsampling_factor
    block_size //= factor
    assert block_size >= 1, "block_size is too small."
    is_transformer_enc = 'former' in self.enc.enc_type

    hyps = None
    hyps_nobd = []
    best_hyp_id_session = []
    best_hyp_id_prefix = []  # initialized here; may be consulted before the first hypothesis appears
    is_reset = False

    helper = BeamSearch(params.get('recog_beam_width'), self.eos,
                        params.get('recog_ctc_weight'),
                        params.get('recog_lm_weight'), self.device)

    lm = getattr(self, 'lm_fwd', None)
    lm_second = getattr(self, 'lm_second', None)
    lm = helper.verify_lm_eval_mode(lm, params.get('recog_lm_weight'), cache_emb)
    if lm is not None:
        assert isinstance(lm, RNNLM)
    lm_second = helper.verify_lm_eval_mode(
        lm_second, params.get('recog_lm_second_weight'), cache_emb)

    # cache token embeddings
    if cache_emb and self.fwd_weight > 0:
        self.dec_fwd.cache_embedding(self.device)

    self.enc.reset_cache()
    eout_block_tail = None
    x_block_prev, xlen_block_prev = None, None

    while True:
        # Encode input features block by block
        x_block, is_last_block, cnn_lookback, cnn_lookahead, xlen_block = streaming.extract_feat()
        if not is_transformer_enc and is_reset:
            self.enc.reset_cache()
            if backoff:
                self.encode([x_block_prev], 'all', streaming=True,
                            cnn_lookback=cnn_lookback,
                            cnn_lookahead=cnn_lookahead,
                            xlen_block=xlen_block_prev)
        x_block_prev = x_block
        xlen_block_prev = xlen_block

        eout_block_dict = self.encode([x_block], 'all', streaming=True,
                                      cnn_lookback=cnn_lookback,
                                      cnn_lookahead=cnn_lookahead,
                                      xlen_block=xlen_block)
        eout_block = eout_block_dict[task]['xs']
        if eout_block_tail is not None:
            eout_block = torch.cat([eout_block_tail, eout_block], dim=1)
            eout_block_tail = None

        if eout_block.size(1) > 0:
            streaming.cache_eout(eout_block)

            # Block-synchronous decoding
            if ctc_weight == 1 or self.ctc_weight == 1:
                end_hyps, hyps = self.dec_fwd.ctc.beam_search_block_sync(
                    eout_block, params, helper, idx2token, hyps, lm)
            elif isinstance(self.dec_fwd, RNNT):
                end_hyps, hyps = self.dec_fwd.beam_search_block_sync(
                    eout_block, params, helper, idx2token, hyps, lm)
            elif isinstance(self.dec_fwd, RNNDecoder):
                n_frames = getattr(self.dec_fwd, 'n_frames', 0)
                for i in range(math.ceil(eout_block.size(1) / block_size)):
                    eout_block_i = eout_block[:, i * block_size:(i + 1) * block_size]
                    end_hyps, hyps, hyps_nobd = self.dec_fwd.beam_search_block_sync(
                        eout_block_i, params, helper, idx2token, hyps, hyps_nobd, lm,
                        speaker=speaker)
            elif isinstance(self.dec_fwd, TransformerDecoder):
                raise NotImplementedError
            else:
                raise NotImplementedError(self.dec_fwd)

            # CTC-based reset point detection
            is_reset = False
            if streaming.enable_ctc_reset_point_detection:
                if self.ctc_weight_sub1 > 0:
                    ctc_probs_block = self.dec_fwd_sub1.ctc.probs(
                        eout_block_dict['ys_sub1']['xs'])
                    # TODO: consider subsampling
                else:
                    ctc_probs_block = self.dec_fwd.ctc.probs(eout_block)
                is_reset = streaming.ctc_reset_point_detection(ctc_probs_block)

            merged_hyps = sorted(end_hyps + hyps + hyps_nobd,
                                 key=lambda x: x['score'], reverse=True)
            if len(merged_hyps) > 0:
                best_hyp_id_prefix = np.array(merged_hyps[0]['hyp'][1:])
                best_hyp_id_prefix_viz = np.array(merged_hyps[0]['hyp'][1:])

                if len(best_hyp_id_prefix) > 0 and best_hyp_id_prefix[-1] == self.eos:
                    # reset beam if <eos> is generated from the best hypothesis
                    best_hyp_id_prefix = best_hyp_id_prefix[:-1]  # exclude <eos>
                    # Condition 2:
                    # If <eos> is emitted from the decoder (not CTC),
                    # the current block is segmented.
                    if (not is_reset) and (not streaming.safeguard_reset):
                        is_reset = True

                if len(best_hyp_id_prefix_viz) > 0:
                    n_frames = self.dec_fwd.ctc.n_frames \
                        if ctc_weight == 1 or self.ctc_weight == 1 else self.dec_fwd.n_frames
                    print('\rStreaming (T:%.3f [sec], offset:%d [frame], blank:%d [frame]): %s' %
                          ((streaming.offset + eout_block.size(1) * factor) / 100,
                           n_frames * factor,
                           streaming.n_blanks * factor,
                           idx2token(best_hyp_id_prefix_viz)))

        if is_reset:
            # pick up the best hyp from ended and active hypotheses
            if len(best_hyp_id_prefix) > 0:
                best_hyp_id_session.extend(best_hyp_id_prefix)
            # reset
            streaming.reset()
            hyps = None
            hyps_nobd = []

        streaming.next_block()
        if is_last_block:
            break

    # pick up the best hyp
    if not is_reset and len(best_hyp_id_prefix) > 0:
        best_hyp_id_session.extend(best_hyp_id_prefix)

    if len(best_hyp_id_session) > 0:
        return [[np.stack(best_hyp_id_session, axis=0)]], [None]
    else:
        return [[[]]], [None]
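# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): driving decode_streaming() above. The
# dict keys are exactly the ones read via params.get() in the method; the
# values are arbitrary examples, `model`, `x` (a `[T, input_dim]` feature
# array), and `idx2token` are assumed to be supplied by the caller, and the
# Streaming helper may read additional keys not shown here.
def run_streaming_decode(model, x, idx2token):
    params = {
        'recog_block_sync': True,
        'recog_block_sync_size': 40,   # frames before subsampling
        'recog_cache_embedding': True,
        'recog_ctc_weight': 0.3,
        'recog_beam_width': 10,
        'recog_lm_weight': 0.0,
        'recog_lm_second_weight': 0.0,
    }
    best_hyps_id, _ = model.decode_streaming([x], params, idx2token, task='ys')
    hyp_ids = best_hyps_id[0][0]  # `[[array]]` per the return structure above
    return idx2token(hyp_ids) if len(hyp_ids) > 0 else ''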