def encode_streaming(self, xs, params, task='ys'):
    """Encode one utterance block by block to emulate a streaming encoder.

    Only the encoder runs incrementally here; any decoding performed on the
    returned encoder outputs happens offline.

    Args:
        xs (list): batch of input feature sequences; exactly one utterance is
            allowed (presumably a `[T, idim]` array -- TODO confirm)
        params (dict): decoding hyper-parameters handed to `Streaming`
        task (str): task to evaluate; only 'ys' is supported
    Returns:
        eout (FloatTensor): encoder outputs collected by
            `Streaming.pop_eouts()` (presumably `[B, T, enc_dim]`)
        elens (IntTensor): one-element tensor holding `eout.size(1)`

    """
    assert task == 'ys'
    assert self.input_type == 'speech'
    assert self.fwd_weight > 0
    assert len(xs) == 1  # batch size

    streamer = Streaming(xs[0], params, self.enc)
    self.enc.reset_cache()

    finished = False
    while not finished:
        # fetch the next feature block together with its CNN context
        (feat_block, finished, lookback, lookahead,
         block_len) = streamer.extract_feat()
        enc_outputs = self.encode([feat_block], 'all',
                                  streaming=True,
                                  cnn_lookback=lookback,
                                  cnn_lookahead=lookahead,
                                  xlen_block=block_len)
        streamer.cache_eout(enc_outputs[task]['xs'])
        streamer.next_block()

    eout = streamer.pop_eouts()
    elens = torch.IntTensor([eout.size(1)])
    return eout, elens
def decode_streaming(self, xs, params, idx2token, exclude_eos=False,
                     speaker=None, task='ys'):
    """Simulate streaming encoding+decoding.

    Both encoding and decoding are performed in the online mode. Input
    features are encoded block by block, and each encoder block is decoded
    block-synchronously. When a reset point is detected (by CTC-based
    detection, or because the best hypothesis ends with <eos>), the best
    hypothesis of the current segment is committed to the session output and
    the beam is cleared.

    Args:
        xs (list): batch of input feature sequences; exactly one utterance
            is allowed
        params (dict): decoding hyper-parameters (`recog_*` keys)
        idx2token (callable): converter from token IDs to text, used for
            progress printing and passed to the beam searches
        exclude_eos (bool): not used by this method; kept for interface
            compatibility
        speaker: forwarded to the RNN decoder's block-synchronous beam search
        task (str): task to evaluate; only 'ys' is supported
    Returns:
        list: `[[ids]]`, where `ids` is the token-ID array of all committed
            segments stacked with `np.stack`, or `[[[]]]` if nothing was
            decoded
        list: always `[None]`

    """
    self.eval()
    block_size = params.get('recog_block_sync_size')  # before subsampling
    cache_emb = params.get('recog_cache_embedding')
    ctc_weight = params.get('recog_ctc_weight')
    backoff = True

    assert task == 'ys'
    assert self.input_type == 'speech'
    assert self.ctc_weight > 0
    assert self.fwd_weight > 0 or self.ctc_weight == 1.0
    assert len(xs) == 1  # batch size
    assert params.get('recog_block_sync')
    # assert params.get('recog_length_norm')

    streaming = Streaming(xs[0], params, self.enc, idx2token)
    factor = self.enc.subsampling_factor
    block_size //= factor
    assert block_size >= 1, "block_size is too small."
    is_transformer_enc = 'former' in self.enc.enc_type
    # decode with the CTC branch only
    ctc_only = ctc_weight == 1 or self.ctc_weight == 1

    hyps = None
    hyps_nobd = []
    best_hyp_id_session = []
    # Best hypothesis of the current, not yet committed, segment. It is
    # initialized here so it is always bound, even when no block yields a
    # hypothesis (e.g. every encoder block is empty).
    best_hyp_id_prefix = []
    is_reset = False

    helper = BeamSearch(params.get('recog_beam_width'),
                        self.eos,
                        params.get('recog_ctc_weight'),
                        params.get('recog_lm_weight'),
                        self.device)

    lm = getattr(self, 'lm_fwd', None)
    lm_second = getattr(self, 'lm_second', None)
    lm = helper.verify_lm_eval_mode(lm, params.get('recog_lm_weight'), cache_emb)
    if lm is not None:
        assert isinstance(lm, RNNLM)
    # NOTE(review): the second LM is verified but not used below.
    lm_second = helper.verify_lm_eval_mode(
        lm_second, params.get('recog_lm_second_weight'), cache_emb)

    # cache token embeddings
    if cache_emb and self.fwd_weight > 0:
        self.dec_fwd.cache_embedding(self.device)

    self.enc.reset_cache()
    x_block_prev, xlen_block_prev = None, None
    while True:
        # Encode input features block by block
        x_block, is_last_block, cnn_lookback, cnn_lookahead, xlen_block = streaming.extract_feat()
        if not is_transformer_enc and is_reset:
            # For non-Transformer encoders the encoder cache is cleared at a
            # segment boundary. With backoff, the previous block is encoded
            # again (outputs discarded), presumably to rebuild the cache.
            self.enc.reset_cache()
            if backoff:
                self.encode([x_block_prev], 'all',
                            streaming=True,
                            cnn_lookback=cnn_lookback,
                            cnn_lookahead=cnn_lookahead,
                            xlen_block=xlen_block_prev)
        x_block_prev = x_block
        xlen_block_prev = xlen_block
        eout_block_dict = self.encode([x_block], 'all',
                                      streaming=True,
                                      cnn_lookback=cnn_lookback,
                                      cnn_lookahead=cnn_lookahead,
                                      xlen_block=xlen_block)
        eout_block = eout_block_dict[task]['xs']

        if eout_block.size(1) > 0:
            streaming.cache_eout(eout_block)

            # Block-synchronous decoding
            if ctc_only:
                end_hyps, hyps = self.dec_fwd.ctc.beam_search_block_sync(
                    eout_block, params, helper, idx2token, hyps, lm)
            elif isinstance(self.dec_fwd, RNNT):
                end_hyps, hyps = self.dec_fwd.beam_search_block_sync(
                    eout_block, params, helper, idx2token, hyps, lm)
            elif isinstance(self.dec_fwd, RNNDecoder):
                # feed the encoder block to the decoder in `block_size` chunks
                for i in range(math.ceil(eout_block.size(1) / block_size)):
                    eout_block_i = eout_block[:, i * block_size:(i + 1) * block_size]
                    end_hyps, hyps, hyps_nobd = self.dec_fwd.beam_search_block_sync(
                        eout_block_i, params, helper, idx2token, hyps, hyps_nobd, lm,
                        speaker=speaker)
            elif isinstance(self.dec_fwd, TransformerDecoder):
                raise NotImplementedError
            else:
                raise NotImplementedError(self.dec_fwd)

            # CTC-based reset point detection
            is_reset = False
            if streaming.enable_ctc_reset_point_detection:
                if self.ctc_weight_sub1 > 0:
                    ctc_probs_block = self.dec_fwd_sub1.ctc.probs(
                        eout_block_dict['ys_sub1']['xs'])
                    # TODO: consider subsampling
                else:
                    ctc_probs_block = self.dec_fwd.ctc.probs(eout_block)
                is_reset = streaming.ctc_reset_point_detection(ctc_probs_block)

            merged_hyps = sorted(end_hyps + hyps + hyps_nobd,
                                 key=lambda x: x['score'], reverse=True)
            if len(merged_hyps) > 0:
                # drop the leading start token of the best hypothesis
                best_hyp_id_prefix = np.array(merged_hyps[0]['hyp'][1:])
                # slicing below rebinds the name, so this keeps the full array
                best_hyp_id_prefix_viz = best_hyp_id_prefix

                if (len(best_hyp_id_prefix) > 0
                        and best_hyp_id_prefix[-1] == self.eos):
                    # reset beam if <eos> is generated from the best hypothesis
                    best_hyp_id_prefix = best_hyp_id_prefix[:-1]  # exclude <eos>
                    # Condition 2:
                    # If <eos> is emitted from the decoder (not CTC),
                    # the current block is segmented.
                    if (not is_reset) and (not streaming.safeguard_reset):
                        is_reset = True

                if len(best_hyp_id_prefix_viz) > 0:
                    n_frames = (self.dec_fwd.ctc.n_frames if ctc_only
                                else self.dec_fwd.n_frames)
                    print('\rStreaming (T:%.3f [sec], offset:%d [frame], blank:%d [frame]): %s' %
                          ((streaming.offset + eout_block.size(1) * factor) / 100,
                           n_frames * factor,
                           streaming.n_blanks * factor,
                           idx2token(best_hyp_id_prefix_viz)))

        if is_reset:
            # pick up the best hyp from ended and active hypotheses
            if len(best_hyp_id_prefix) > 0:
                best_hyp_id_session.extend(best_hyp_id_prefix)
            # Clear the committed prefix. Otherwise a following block without
            # new hypotheses, or the final flush below, commits it again.
            best_hyp_id_prefix = []

            # reset
            streaming.reset()
            hyps = None
            hyps_nobd = []

        streaming.next_block()
        if is_last_block:
            break

    # Pick up the best hyp of the last segment. A committed prefix was
    # cleared above, so only uncommitted tokens remain here.
    if len(best_hyp_id_prefix) > 0:
        best_hyp_id_session.extend(best_hyp_id_prefix)

    if len(best_hyp_id_session) > 0:
        return [[np.stack(best_hyp_id_session, axis=0)]], [None]
    else:
        return [[[]]], [None]