def encode_streaming(self, xs, params, task='ys'):
    """Encode a single utterance block by block to simulate streaming.

    Only the encoder runs in streaming mode here; decoding on the returned
    outputs is expected to be done offline by the caller.

    Args:
        xs (list): one-element batch of input features; the element is
            presumably a `[T, idim]` feature array (TODO confirm)
        params (dict): hyper-parameters, forwarded to `Streaming`
        task (str): task to evaluate; only 'ys' is supported
    Returns:
        eout (FloatTensor): encoder outputs collected via
            `streaming.pop_eouts()`, presumably `[1, T', D]`
        elens (IntTensor): `[1]`, holding `eout.size(1)`

    """
    assert task == 'ys'
    assert self.input_type == 'speech'
    assert self.fwd_weight > 0
    assert len(xs) == 1  # batch size

    streaming = Streaming(xs[0], params, self.enc)
    self.enc.reset_cache()

    reached_end = False
    while not reached_end:
        # Fetch the next feature block together with its CNN context sizes
        (feat_block, reached_end, lookback, lookahead,
         feat_block_len) = streaming.extract_feature()
        enc_out = self.encode([feat_block], 'all',
                              streaming=True,
                              cnn_lookback=lookback,
                              cnn_lookahead=lookahead,
                              xlen_block=feat_block_len)
        streaming.cache_eout(enc_out[task]['xs'])
        streaming.next_block()

    eout = streaming.pop_eouts()
    return eout, torch.IntTensor([eout.size(1)])
def decode_streaming(self, xs, params, idx2token, exclude_eos=False,
                     task='ys'):
    """Simulate streaming decoding.

    Both encoding and decoding are performed in the online mode. The input
    is encoded block by block. Segments are closed (``is_reset``) by
    CTC-based VAD or, in block-synchronous mode, when the best hypothesis
    emits <eos>. Without block-sync, each closed segment (and the remaining
    tail) is decoded with an offline beam search over its cached encoder
    outputs.

    Args:
        xs (list): one-element batch of input features
        params (dict): decoding hyper-parameters (``recog_*`` keys)
        idx2token (callable): converts token IDs to text (used for logging
            and passed to the beam searches)
        exclude_eos (bool): forwarded to the offline beam search
        task (str): task to evaluate; only 'ys' is supported
    Returns:
        tuple: ``([[best_token_ids]], [None])`` where ``best_token_ids`` is
            the ``np.stack`` of the per-segment best token IDs, or ``[]``
            when nothing was recognized

    """
    assert task == 'ys'
    assert self.input_type == 'speech'
    assert self.ctc_weight > 0
    assert self.fwd_weight > 0
    assert len(xs) == 1  # batch size
    # assert params['recog_length_norm']
    global_params = copy.deepcopy(params)
    global_params['recog_max_len_ratio'] = 1.0

    block_sync = params['recog_block_sync']
    block_size = params['recog_block_sync_size']  # before subsampling

    streaming = Streaming(xs[0], params, self.enc)
    factor = self.enc.subsampling_factor
    block_size //= factor

    hyps = None
    # Best partial hypothesis of the current segment (block-sync mode).
    # Initialized so it is always bound even if no hypothesis was produced.
    best_hyp_id_prefix = []
    best_hyp_id_stream = []
    is_reset = True  # for the first block

    stdout = False

    self.eval()
    lm = getattr(self, 'lm_fwd', None)
    lm_second = getattr(self, 'lm_second', None)

    def _offline_best_hyp(eout):
        """Run the offline beam search over `eout`; return the 1-best IDs."""
        elens = torch.IntTensor([eout.size(1)])
        ctc_log_probs = None
        # Use CTC scores for every global decoding (including the tail),
        # not only for segments closed inside the loop.
        if params.get('recog_ctc_weight', 0) > 0:
            ctc_log_probs = torch.log(self.dec_fwd.ctc_probs(eout))
        nbest_hyps_id = self.dec_fwd.beam_search(
            eout, elens, global_params, idx2token, lm, lm_second,
            ctc_log_probs=ctc_log_probs, exclude_eos=exclude_eos)[0]
        return nbest_hyps_id[0][0]

    # with torch.no_grad():
    while True:
        # Encode input features block by block
        (x_block, is_last_block, cnn_lookback, cnn_lookahead,
         xlen_block) = streaming.extract_feature()
        if is_reset:
            self.enc.reset_cache()
        eout_block_dict = self.encode([x_block], 'all',
                                      streaming=True,
                                      cnn_lookback=cnn_lookback,
                                      cnn_lookahead=cnn_lookahead,
                                      xlen_block=xlen_block)
        eout_block = eout_block_dict[task]['xs']
        is_reset = False  # detect the first boundary in the same block

        # CTC-based VAD
        if streaming.is_ctc_vad:
            if self.ctc_weight_sub1 > 0:
                ctc_probs_block = self.dec_fwd_sub1.ctc_probs(
                    eout_block_dict['ys_sub1']['xs'])
                # TODO: consider subsampling
            else:
                ctc_probs_block = self.dec_fwd.ctc_probs(eout_block)
            is_reset = streaming.ctc_vad(ctc_probs_block, stdout=stdout)

        # Truncate the most right frames
        if is_reset and not is_last_block and streaming.bd_offset >= 0:
            eout_block = eout_block[:, :streaming.bd_offset]
        streaming.cache_eout(eout_block)

        # Block-synchronous attention decoding
        if isinstance(self.dec_fwd, RNNT):
            raise NotImplementedError
        elif isinstance(self.dec_fwd, RNNDecoder) and block_sync:
            # An empty block (e.g. truncated at bd_offset == 0) would run the
            # loop zero times and leave end_hyps/hyps stale or unbound.
            if eout_block.size(1) > 0:
                for i in range(math.ceil(eout_block.size(1) / block_size)):
                    eout_block_i = eout_block[:, i * block_size:(i + 1) * block_size]
                    end_hyps, hyps, _ = self.dec_fwd.beam_search_block_sync(
                        eout_block_i, params, idx2token, hyps, lm,
                        state_carry_over=False)

                merged_hyps = sorted(end_hyps + hyps,
                                     key=lambda x: x['score'], reverse=True)
                if len(merged_hyps) > 0:
                    best_hyp_id_prefix = np.array(merged_hyps[0]['hyp'][1:])
                    if (len(best_hyp_id_prefix) > 0
                            and best_hyp_id_prefix[-1] == self.eos):
                        # reset beam if <eos> is generated from the best hypothesis
                        best_hyp_id_prefix = best_hyp_id_prefix[:-1]  # exclude <eos>
                        # Segmentation strategy 2:
                        # If <eos> is emitted from the decoder (not CTC),
                        # the current block is segmented.
                        if not is_reset:
                            streaming._bd_offset = eout_block.size(1) - 1  # TODO: fix later
                            is_reset = True
                    if len(best_hyp_id_prefix) > 0:
                        print('\rStreaming (T:%d [10ms], offset:%d [10ms], blank:%d [10ms]): %s' %
                              (streaming.offset + eout_block.size(1) * factor,
                               self.dec_fwd.n_frames * factor,
                               streaming.n_blanks * factor,
                               idx2token(best_hyp_id_prefix)))
        elif isinstance(self.dec_fwd, TransformerDecoder):
            raise NotImplementedError

        if is_reset:
            # Global decoding over the segmented region
            if block_sync:
                # pick up the best hyp from ended and active hypotheses
                if len(best_hyp_id_prefix) > 0:
                    best_hyp_id_stream.extend(best_hyp_id_prefix)
            else:
                best_hyp_offline = _offline_best_hyp(streaming.pop_eouts())
                if len(best_hyp_offline) > 0:
                    best_hyp_id_stream.extend(best_hyp_offline)
            # The prefix is now committed; clear it so it cannot be added again
            best_hyp_id_prefix = []

            # reset
            streaming.reset(stdout=stdout)
            hyps = None

        streaming.next_block()
        if is_last_block:
            break
        # next block will start from the frame next to the boundary
        streaming.backoff(x_block, self.dec_fwd, stdout=stdout)

    # Global decoding for tail blocks
    if not block_sync and streaming.n_cache_block > 0:
        best_hyp_offline = _offline_best_hyp(streaming.pop_eouts())
        if len(best_hyp_offline) > 0:
            best_hyp_id_stream.extend(best_hyp_offline)

    # pick up the best hyp
    if not is_reset and block_sync and len(best_hyp_id_prefix) > 0:
        best_hyp_id_stream.extend(best_hyp_id_prefix)

    if len(best_hyp_id_stream) > 0:
        return [[np.stack(best_hyp_id_stream, axis=0)]], [None]
    else:
        return [[[]]], [None]
def decode_streaming(self, xs, params, idx2token, exclude_eos=False,
                     task='ys'):
    """Simulate streaming encoding+decoding.

    Both encoding and decoding are performed in the online mode. The input
    is encoded block by block and decoded block-synchronously (CTC-only,
    RNN-T or attention RNN decoder). A segment is closed (``is_reset``) by
    CTC-based VAD, or when the best hypothesis ends with <eos> / no active
    hypothesis remains. The segment's best token IDs are then appended to
    the output.

    Args:
        xs (list): one-element batch of input features
        params (dict): decoding hyper-parameters (``recog_*`` keys)
        idx2token (callable): converts token IDs to text (used for logging
            and passed to the beam searches)
        exclude_eos (bool): not referenced here; kept for API compatibility
        task (str): task to evaluate; only 'ys' is supported
    Returns:
        tuple: ``([[best_token_ids]], [None])`` where ``best_token_ids`` is
            the ``np.stack`` of the per-segment best token IDs, or ``[]``
            when nothing was recognized

    """
    block_size = params.get('recog_block_sync_size')  # before subsampling
    cache_emb = params.get('recog_cache_embedding')
    ctc_weight = params.get('recog_ctc_weight')

    assert task == 'ys'
    assert self.input_type == 'speech'
    assert self.ctc_weight > 0
    assert self.fwd_weight > 0 or self.ctc_weight == 1.0
    assert len(xs) == 1  # batch size
    assert params.get('recog_block_sync')
    # assert params.get('recog_length_norm')

    streaming = Streaming(xs[0], params, self.enc)
    factor = self.enc.subsampling_factor
    block_size //= factor
    assert block_size >= 1, "block_size is too small."
    # Decode with the CTC branch alone when either weight selects pure CTC
    ctc_only = ctc_weight == 1 or self.ctc_weight == 1

    hyps = None
    # Best partial hypothesis of the current segment; initialized so it is
    # always bound even if no hypothesis has been produced yet.
    best_hyp_id_prefix = []
    best_hyp_id_stream = []
    is_reset = True  # for the first block

    stdout = False
    self.eval()
    helper = BeamSearch(params.get('recog_beam_width'),
                        self.eos,
                        params.get('recog_ctc_weight'),
                        params.get('recog_lm_weight'),
                        self.device)

    lm = getattr(self, 'lm_fwd', None)
    lm_second = getattr(self, 'lm_second', None)
    lm = helper.verify_lm_eval_mode(lm, params.get('recog_lm_weight'), cache_emb)
    if lm is not None:
        assert isinstance(lm, RNNLM)
    lm_second = helper.verify_lm_eval_mode(
        lm_second, params.get('recog_lm_second_weight'), cache_emb)

    # cache token embeddings
    if cache_emb and self.fwd_weight > 0:
        self.dec_fwd.cache_embedding(self.device)

    while True:
        # Encode input features block by block
        (x_block, is_last_block, cnn_lookback, cnn_lookahead,
         xlen_block) = streaming.extract_feature()
        if is_reset:
            self.enc.reset_cache()
        eout_block_dict = self.encode([x_block], 'all',
                                      streaming=True,
                                      cnn_lookback=cnn_lookback,
                                      cnn_lookahead=cnn_lookahead,
                                      xlen_block=xlen_block)
        eout_block = eout_block_dict[task]['xs']
        is_reset = False  # detect the first boundary in the same block

        # CTC-based VAD
        if streaming.is_ctc_vad:
            if self.ctc_weight_sub1 > 0:
                ctc_probs_block = self.dec_fwd_sub1.ctc.probs(
                    eout_block_dict['ys_sub1']['xs'])
                # TODO: consider subsampling
            else:
                ctc_probs_block = self.dec_fwd.ctc.probs(eout_block)
            is_reset = streaming.ctc_vad(ctc_probs_block, stdout=stdout)

        # Truncate the most right frames
        if is_reset and not is_last_block and streaming.bd_offset >= 0:
            eout_block = eout_block[:, :streaming.bd_offset]
        streaming.cache_eout(eout_block)

        # Block-synchronous decoding. An empty block (e.g. truncated at
        # bd_offset == 0) is skipped: the RNN decoder loop would not run and
        # end_hyps/hyps would be stale, unbound or None after a reset.
        if eout_block.size(1) > 0:
            if ctc_only:
                end_hyps, hyps = self.dec_fwd.ctc.beam_search_block_sync(
                    eout_block, params, helper, idx2token, hyps, lm)
            elif isinstance(self.dec_fwd, RNNT):
                end_hyps, hyps = self.dec_fwd.beam_search_block_sync(
                    eout_block, params, helper, idx2token, hyps, lm)
            elif isinstance(self.dec_fwd, RNNDecoder):
                for i in range(math.ceil(eout_block.size(1) / block_size)):
                    eout_block_i = eout_block[:, i * block_size:(i + 1) * block_size]
                    end_hyps, hyps, _ = self.dec_fwd.beam_search_block_sync(
                        eout_block_i, params, helper, idx2token, hyps, lm)
            elif isinstance(self.dec_fwd, TransformerDecoder):
                raise NotImplementedError
            else:
                raise NotImplementedError(self.dec_fwd)

            merged_hyps = sorted(end_hyps + hyps,
                                 key=lambda x: x['score'], reverse=True)
            if len(merged_hyps) > 0:
                best_hyp_id_prefix = np.array(merged_hyps[0]['hyp'][1:])
                ends_with_eos = (len(best_hyp_id_prefix) > 0
                                 and best_hyp_id_prefix[-1] == self.eos)
                if len(hyps) == 0 or ends_with_eos:
                    # reset beam if <eos> is generated from the best hypothesis
                    # (or no active hypothesis is left). Only strip the last
                    # token when it really is <eos>; otherwise a genuine token
                    # would be dropped.
                    if ends_with_eos:
                        best_hyp_id_prefix = best_hyp_id_prefix[:-1]  # exclude <eos>
                    # Segmentation strategy 2:
                    # If <eos> is emitted from the decoder (not CTC),
                    # the current block is segmented.
                    if not is_reset:
                        streaming._bd_offset = eout_block.size(1) - 1  # TODO: fix later
                        is_reset = True
                if len(best_hyp_id_prefix) > 0:
                    n_frames = self.dec_fwd.ctc.n_frames if ctc_only else self.dec_fwd.n_frames
                    print('\rStreaming (T:%d [10ms], offset:%d [10ms], blank:%d [10ms]): %s' %
                          (streaming.offset + eout_block.size(1) * factor,
                           n_frames * factor,
                           streaming.n_blanks * factor,
                           idx2token(best_hyp_id_prefix)))

        if is_reset:
            # pick up the best hyp from ended and active hypotheses
            if len(best_hyp_id_prefix) > 0:
                best_hyp_id_stream.extend(best_hyp_id_prefix)
            # The prefix is now committed; clear it so it cannot be added again
            best_hyp_id_prefix = []

            # reset
            streaming.reset(stdout=stdout)
            hyps = None

        streaming.next_block()
        if is_last_block:
            break
        # next block will start from the frame next to the boundary
        streaming.backoff(x_block, self.dec_fwd, stdout=stdout)

    # pick up the best hyp
    if not is_reset and len(best_hyp_id_prefix) > 0:
        best_hyp_id_stream.extend(best_hyp_id_prefix)

    if len(best_hyp_id_stream) > 0:
        return [[np.stack(best_hyp_id_stream, axis=0)]], [None]
    else:
        return [[[]]], [None]
def decode_streaming(self, xs, params, idx2token, exclude_eos=False,
                     speaker=None, task='ys'):
    """Simulate streaming encoding+decoding.

    Both encoding and decoding are performed in the online mode. The input
    is encoded block by block and decoded block-synchronously. A session is
    closed (``is_reset``) by CTC-based reset point detection, or when the
    best hypothesis ends with <eos> (unless ``streaming.safeguard_reset``).
    The session's best token IDs are then appended to the output.

    Args:
        xs (list): one-element batch of input features
        params (dict): decoding hyper-parameters (``recog_*`` keys)
        idx2token (callable): converts token IDs to text (used for logging,
            passed to ``Streaming`` and the beam searches)
        exclude_eos (bool): not referenced here; kept for API compatibility
        speaker: forwarded to the RNN decoder's block-sync beam search
        task (str): task to evaluate; only 'ys' is supported
    Returns:
        tuple: ``([[best_token_ids]], [None])`` where ``best_token_ids`` is
            the ``np.stack`` of the per-session best token IDs, or ``[]``
            when nothing was recognized

    """
    self.eval()
    block_size = params.get('recog_block_sync_size')  # before subsampling
    cache_emb = params.get('recog_cache_embedding')
    ctc_weight = params.get('recog_ctc_weight')
    backoff = True

    assert task == 'ys'
    assert self.input_type == 'speech'
    assert self.ctc_weight > 0
    assert self.fwd_weight > 0 or self.ctc_weight == 1.0
    assert len(xs) == 1  # batch size
    assert params.get('recog_block_sync')
    # assert params.get('recog_length_norm')

    streaming = Streaming(xs[0], params, self.enc, idx2token)
    factor = self.enc.subsampling_factor
    block_size //= factor
    assert block_size >= 1, "block_size is too small."
    is_transformer_enc = 'former' in self.enc.enc_type
    # Decode with the CTC branch alone when either weight selects pure CTC
    ctc_only = ctc_weight == 1 or self.ctc_weight == 1

    hyps = None
    hyps_nobd = []
    # Best partial hypothesis of the current session; initialized so it is
    # always bound even if no hypothesis has been produced yet.
    best_hyp_id_prefix = []
    best_hyp_id_session = []
    is_reset = False

    helper = BeamSearch(params.get('recog_beam_width'),
                        self.eos,
                        params.get('recog_ctc_weight'),
                        params.get('recog_lm_weight'),
                        self.device)

    lm = getattr(self, 'lm_fwd', None)
    lm_second = getattr(self, 'lm_second', None)
    lm = helper.verify_lm_eval_mode(lm, params.get('recog_lm_weight'), cache_emb)
    if lm is not None:
        assert isinstance(lm, RNNLM)
    lm_second = helper.verify_lm_eval_mode(
        lm_second, params.get('recog_lm_second_weight'), cache_emb)

    # cache token embeddings
    if cache_emb and self.fwd_weight > 0:
        self.dec_fwd.cache_embedding(self.device)

    self.enc.reset_cache()
    # NOTE(review): never set to a non-None value in this method; the
    # concatenation below is currently a no-op.
    eout_block_tail = None
    x_block_prev, xlen_block_prev = None, None
    while True:
        # Encode input features block by block
        (x_block, is_last_block, cnn_lookback, cnn_lookahead,
         xlen_block) = streaming.extract_feat()
        if not is_transformer_enc and is_reset:
            self.enc.reset_cache()
            if backoff:
                # Re-feed the previous block to warm up the reset encoder state
                self.encode([x_block_prev], 'all',
                            streaming=True,
                            cnn_lookback=cnn_lookback,
                            cnn_lookahead=cnn_lookahead,
                            xlen_block=xlen_block_prev)
        x_block_prev = x_block
        xlen_block_prev = xlen_block
        # Clear the previous block's reset flag here (not only inside the
        # non-empty branch below): otherwise an empty block would re-trigger
        # the reset and re-commit the already committed hypothesis.
        is_reset = False

        eout_block_dict = self.encode([x_block], 'all',
                                      streaming=True,
                                      cnn_lookback=cnn_lookback,
                                      cnn_lookahead=cnn_lookahead,
                                      xlen_block=xlen_block)
        eout_block = eout_block_dict[task]['xs']
        if eout_block_tail is not None:
            eout_block = torch.cat([eout_block_tail, eout_block], dim=1)
            eout_block_tail = None

        if eout_block.size(1) > 0:
            streaming.cache_eout(eout_block)

            # Block-synchronous decoding
            if ctc_only:
                end_hyps, hyps = self.dec_fwd.ctc.beam_search_block_sync(
                    eout_block, params, helper, idx2token, hyps, lm)
            elif isinstance(self.dec_fwd, RNNT):
                end_hyps, hyps = self.dec_fwd.beam_search_block_sync(
                    eout_block, params, helper, idx2token, hyps, lm)
            elif isinstance(self.dec_fwd, RNNDecoder):
                for i in range(math.ceil(eout_block.size(1) / block_size)):
                    eout_block_i = eout_block[:, i * block_size:(i + 1) * block_size]
                    end_hyps, hyps, hyps_nobd = self.dec_fwd.beam_search_block_sync(
                        eout_block_i, params, helper, idx2token, hyps, hyps_nobd, lm,
                        speaker=speaker)
            elif isinstance(self.dec_fwd, TransformerDecoder):
                raise NotImplementedError
            else:
                raise NotImplementedError(self.dec_fwd)

            # CTC-based reset point detection
            if streaming.enable_ctc_reset_point_detection:
                if self.ctc_weight_sub1 > 0:
                    ctc_probs_block = self.dec_fwd_sub1.ctc.probs(
                        eout_block_dict['ys_sub1']['xs'])
                    # TODO: consider subsampling
                else:
                    ctc_probs_block = self.dec_fwd.ctc.probs(eout_block)
                is_reset = streaming.ctc_reset_point_detection(ctc_probs_block)

            merged_hyps = sorted(end_hyps + hyps + hyps_nobd,
                                 key=lambda x: x['score'], reverse=True)
            if len(merged_hyps) > 0:
                best_hyp_id_prefix = np.array(merged_hyps[0]['hyp'][1:])
                best_hyp_id_prefix_viz = np.array(merged_hyps[0]['hyp'][1:])

                if (len(best_hyp_id_prefix) > 0
                        and best_hyp_id_prefix[-1] == self.eos):
                    # reset beam if <eos> is generated from the best hypothesis
                    best_hyp_id_prefix = best_hyp_id_prefix[:-1]  # exclude <eos>
                    # Condition 2:
                    # If <eos> is emitted from the decoder (not CTC),
                    # the current block is segmented.
                    if (not is_reset) and (not streaming.safeguard_reset):
                        is_reset = True

                if len(best_hyp_id_prefix_viz) > 0:
                    n_frames = self.dec_fwd.ctc.n_frames if ctc_only else self.dec_fwd.n_frames
                    print('\rStreaming (T:%.3f [sec], offset:%d [frame], blank:%d [frame]): %s' %
                          ((streaming.offset + eout_block.size(1) * factor) / 100,
                           n_frames * factor,
                           streaming.n_blanks * factor,
                           idx2token(best_hyp_id_prefix_viz)))

        if is_reset:
            # pick up the best hyp from ended and active hypotheses
            if len(best_hyp_id_prefix) > 0:
                best_hyp_id_session.extend(best_hyp_id_prefix)
            # The prefix is now committed; clear it so it cannot be added again
            best_hyp_id_prefix = []

            # reset
            streaming.reset()
            hyps = None
            hyps_nobd = []

        streaming.next_block()
        if is_last_block:
            break

    # pick up the best hyp
    if not is_reset and len(best_hyp_id_prefix) > 0:
        best_hyp_id_session.extend(best_hyp_id_prefix)

    if len(best_hyp_id_session) > 0:
        return [[np.stack(best_hyp_id_session, axis=0)]], [None]
    else:
        return [[[]]], [None]