def encode_streaming(self, xs, params, task='ys'):
    """Simulate streaming encoding; decoding is left to the offline mode.

    The single utterance in ``xs`` is fed to the encoder block by block
    through a ``Streaming`` helper. Every encoded block is cached, and the
    cached outputs are returned as one tensor once the last block is done.

    Args:
        xs (list): batch of input features; only a batch of one is supported
        params (dict): hyper-parameters for decoding (passed to ``Streaming``)
        task (str): task to evaluate; only ``'ys'`` is supported
    Returns:
        eout (FloatTensor): encoder outputs gathered from all blocks
        elens (IntTensor): `[1]`, the length of ``eout`` along dim 1

    """
    assert task == 'ys'
    assert self.input_type == 'speech'
    assert self.fwd_weight > 0
    assert len(xs) == 1  # batch size

    streaming = Streaming(xs[0], params, self.enc)
    # start from an empty encoder cache for this utterance
    self.enc.reset_cache()

    is_last_block = False
    while not is_last_block:
        (x_block, is_last_block, cnn_lookback,
         cnn_lookahead, xlen_block) = streaming.extract_feature()
        encoded = self.encode([x_block], 'all',
                              streaming=True,
                              cnn_lookback=cnn_lookback,
                              cnn_lookahead=cnn_lookahead,
                              xlen_block=xlen_block)
        streaming.cache_eout(encoded[task]['xs'])
        streaming.next_block()

    eout = streaming.pop_eouts()
    return eout, torch.IntTensor([eout.size(1)])
def decode_streaming(self, xs, params, idx2token, exclude_eos=False, task='ys'):
    """Simulate streaming decoding.

    Both encoding and decoding are performed in the online mode. The input
    is encoded block by block. A segment boundary ("reset") is declared by
    the CTC-based VAD of ``streaming`` or, in block-synchronous mode, when
    the best decoder hypothesis ends with <eos>. At every boundary the
    tokens of the finished segment are appended to the output:

    - block-synchronous mode (``params['recog_block_sync']``): the best
      prefix of the block-synchronous beam search;
    - otherwise: the result of a global beam search over the encoder
      outputs cached for the segment.

    Args:
        xs (list): batch of input features; only a batch of one is supported
        params (dict): hyper-parameters for decoding
        idx2token: converts token IDs to text (used for progress printing)
        exclude_eos (bool): forwarded to the global beam search
        task (str): task to evaluate; only ``'ys'`` is supported
    Returns:
        tuple: ``([[hyp_ids]], [None])`` with the stacked token IDs of all
            segments, or ``([[[]]], [None])`` when nothing was recognized

    """
    assert task == 'ys'
    assert self.input_type == 'speech'
    assert self.ctc_weight > 0
    assert self.fwd_weight > 0
    assert len(xs) == 1  # batch size
    # assert params['recog_length_norm']
    global_params = copy.deepcopy(params)
    global_params['recog_max_len_ratio'] = 1.0

    block_sync = params['recog_block_sync']
    block_size = params['recog_block_sync_size']  # before subsampling

    streaming = Streaming(xs[0], params, self.enc)
    factor = self.enc.subsampling_factor
    block_size //= factor

    hyps = None
    best_hyp_id_stream = []
    # Best partial hypothesis of the current segment. Bound up-front so it
    # exists even if no block-synchronous hypothesis has been produced yet.
    best_hyp_id_prefix = []
    is_reset = True  # for the first block

    stdout = False

    self.eval()
    lm = getattr(self, 'lm_fwd', None)
    lm_second = getattr(self, 'lm_second', None)

    # with torch.no_grad():
    while True:
        # Encode input features block by block
        (x_block, is_last_block, cnn_lookback,
         cnn_lookahead, xlen_block) = streaming.extract_feature()
        if is_reset:
            self.enc.reset_cache()
        eout_block_dict = self.encode([x_block], 'all',
                                      streaming=True,
                                      cnn_lookback=cnn_lookback,
                                      cnn_lookahead=cnn_lookahead,
                                      xlen_block=xlen_block)
        eout_block = eout_block_dict[task]['xs']
        is_reset = False  # detect the first boundary in the same block

        # CTC-based VAD
        if streaming.is_ctc_vad:
            if self.ctc_weight_sub1 > 0:
                ctc_probs_block = self.dec_fwd_sub1.ctc_probs(
                    eout_block_dict['ys_sub1']['xs'])
                # TODO: consider subsampling
            else:
                ctc_probs_block = self.dec_fwd.ctc_probs(eout_block)
            is_reset = streaming.ctc_vad(ctc_probs_block, stdout=stdout)

        # Truncate the most right frames
        if is_reset and not is_last_block and streaming.bd_offset >= 0:
            eout_block = eout_block[:, :streaming.bd_offset]
        streaming.cache_eout(eout_block)

        # Block-synchronous attention decoding
        if isinstance(self.dec_fwd, RNNT):
            raise NotImplementedError
        elif isinstance(self.dec_fwd, RNNDecoder) and block_sync:
            # The truncation above can leave zero frames. The sub-block loop
            # would then not run, leaving `end_hyps` unbound (first block) or
            # `hyps` as None (after a reset), so decoding is skipped.
            if eout_block.size(1) > 0:
                for i in range(math.ceil(eout_block.size(1) / block_size)):
                    eout_block_i = eout_block[:, i * block_size:(i + 1) * block_size]
                    end_hyps, hyps, _ = self.dec_fwd.beam_search_block_sync(
                        eout_block_i, params, idx2token, hyps, lm,
                        state_carry_over=False)

                merged_hyps = sorted(end_hyps + hyps,
                                     key=lambda x: x['score'], reverse=True)
                if len(merged_hyps) > 0:
                    # drop the first token of the hypothesis (presumably <sos>)
                    best_hyp_id_prefix = np.array(merged_hyps[0]['hyp'][1:])
                    if len(best_hyp_id_prefix) > 0 and best_hyp_id_prefix[-1] == self.eos:
                        # reset beam if <eos> is generated from the best hypothesis
                        best_hyp_id_prefix = best_hyp_id_prefix[:-1]  # exclude <eos>
                        # Segmentation strategy 2:
                        # If <eos> is emitted from the decoder (not CTC),
                        # the current block is segmented.
                        if not is_reset:
                            streaming._bd_offset = eout_block.size(1) - 1  # TODO: fix later
                            is_reset = True
                    if len(best_hyp_id_prefix) > 0:
                        print(
                            '\rStreaming (T:%d [10ms], offset:%d [10ms], blank:%d [10ms]): %s' %
                            (streaming.offset + eout_block.size(1) * factor,
                             self.dec_fwd.n_frames * factor,
                             streaming.n_blanks * factor,
                             idx2token(best_hyp_id_prefix)))
        elif isinstance(self.dec_fwd, TransformerDecoder):
            raise NotImplementedError

        if is_reset:
            if not block_sync:
                # Global decoding over the segmented region
                eout = streaming.pop_eouts()
                elens = torch.IntTensor([eout.size(1)])
                ctc_log_probs = None
                if params['recog_ctc_weight'] > 0:
                    ctc_log_probs = torch.log(self.dec_fwd.ctc_probs(eout))
                nbest_hyps_id_offline = self.dec_fwd.beam_search(
                    eout, elens, global_params, idx2token, lm, lm_second,
                    ctc_log_probs=ctc_log_probs, exclude_eos=exclude_eos)[0]
                if len(nbest_hyps_id_offline[0][0]) > 0:
                    best_hyp_id_stream.extend(nbest_hyps_id_offline[0][0])
            elif len(best_hyp_id_prefix) > 0:
                # pick up the best hyp from ended and active hypotheses
                best_hyp_id_stream.extend(best_hyp_id_prefix)
            # The finished segment has been emitted. Forget its prefix so it
            # cannot be appended again by a later boundary or at the end.
            best_hyp_id_prefix = []

            # reset
            streaming.reset(stdout=stdout)
            hyps = None

        streaming.next_block()
        if is_last_block:
            break
        # next block will start from the frame next to the boundary
        streaming.backoff(x_block, self.dec_fwd, stdout=stdout)

    # Global decoding for tail blocks
    if not block_sync and streaming.n_cache_block > 0:
        eout = streaming.pop_eouts()
        elens = torch.IntTensor([eout.size(1)])
        nbest_hyps_id_offline = self.dec_fwd.beam_search(
            eout, elens, global_params, idx2token, lm, lm_second,
            exclude_eos=exclude_eos)[0]
        if len(nbest_hyps_id_offline[0][0]) > 0:
            best_hyp_id_stream.extend(nbest_hyps_id_offline[0][0])

    # pick up the best hyp of the unfinished last segment
    if not is_reset and block_sync and len(best_hyp_id_prefix) > 0:
        best_hyp_id_stream.extend(best_hyp_id_prefix)

    if len(best_hyp_id_stream) > 0:
        return [[np.stack(best_hyp_id_stream, axis=0)]], [None]
    return [[[]]], [None]
def decode_streaming(self, xs, params, idx2token, exclude_eos=False, task='ys'):
    """Simulate streaming decoding with chunk-wise encoding.

    The input is encoded chunk by chunk, reusing the encoder cache until a
    segment boundary ("reset") occurs. Boundaries come from the CTC-based
    VAD of ``streaming`` or, in chunk-synchronous mode, from the best
    decoder hypothesis ending with <eos>. At every boundary the tokens of
    the finished segment are appended to the output:

    - chunk-synchronous mode (``params['recog_chunk_sync']``): the best
      prefix of the chunk-synchronous beam search;
    - otherwise: the result of a global beam search over the encoder
      outputs collected for the segment.

    Args:
        xs (list): batch of input features; only a batch of one is supported
        params (dict): hyper-parameters for decoding
        idx2token: converts token IDs to text (used for progress printing)
        exclude_eos (bool): unused here; kept for interface compatibility
        task (str): task to evaluate; only ``'ys'`` is supported
    Returns:
        tuple: ``([hyp_ids], [None])`` with the stacked token IDs of all
            segments, or ``([[]], [None])`` when nothing was recognized

    """
    from neural_sp.models.seq2seq.frontends.streaming import Streaming

    # check configurations
    assert task == 'ys'
    assert self.input_type == 'speech'
    assert self.ctc_weight > 0
    assert self.fwd_weight > 0
    assert len(xs) == 1  # batch size
    # assert params['recog_length_norm']
    global_params = copy.deepcopy(params)
    global_params['recog_max_len_ratio'] = 1.0
    chunk_sync = params['recog_chunk_sync']

    streaming = Streaming(xs[0], params, self.enc, idx2token)

    hyps = None
    best_hyp_id_stream = []
    # best partial hypothesis of the current segment (always bound)
    best_hyp_id_prefix = []
    is_reset = True  # for the first chunk

    self.eval()
    with torch.no_grad():
        lm = getattr(self, 'lm_fwd', None)
        lm_second = getattr(self, 'lm_second', None)

        while True:
            # Encode input features chunk by chunk
            x_chunk, is_last_chunk = streaming.extract_feature()
            eout_chunk = self.encode([x_chunk], task,
                                     use_cache=not is_reset,
                                     streaming=True)[task]['xs']
            is_reset = False  # detect the first boundary in the same chunk

            # CTC-based VAD
            ctc_log_probs_chunk = None
            if streaming.is_ctc_vad:
                ctc_probs_chunk = self.dec_fwd.ctc_probs(eout_chunk)
                if params['recog_ctc_weight'] > 0:
                    ctc_log_probs_chunk = torch.log(ctc_probs_chunk)
                is_reset = streaming.ctc_vad(ctc_probs_chunk)

            # Truncate the most right frames
            if is_reset and not is_last_chunk:
                eout_chunk = eout_chunk[:, :streaming.bd_offset + 1]
            streaming.eout_chunks.append(eout_chunk)

            # Chunk-synchronous attention decoding
            if chunk_sync:
                end_hyps, hyps, _ = self.dec_fwd.beam_search_chunk_sync(
                    eout_chunk, params, idx2token, lm,
                    ctc_log_probs=ctc_log_probs_chunk, hyps=hyps,
                    state_carry_over=False,
                    ignore_eos=self.enc.rnn_type in ['lstm', 'conv_lstm'])
                merged_hyps = sorted(end_hyps + hyps,
                                     key=lambda x: x['score'], reverse=True)
                # guard: indexing [0] of an empty list would raise IndexError
                if len(merged_hyps) > 0:
                    # drop the first token of the hypothesis (presumably <sos>)
                    best_hyp_id_prefix = np.array(merged_hyps[0]['hyp'][1:])
                    if len(best_hyp_id_prefix) > 0 and best_hyp_id_prefix[-1] == self.eos:
                        # reset beam if <eos> is generated from the best hypothesis
                        best_hyp_id_prefix = best_hyp_id_prefix[:-1]  # exclude <eos>
                        # Segmentation strategy 2:
                        # If <eos> is emitted from the decoder (not CTC),
                        # the current chunk is segmented.
                        if not is_reset:
                            streaming.bd_offset = eout_chunk.size(1) - 1
                            is_reset = True
                    print(
                        '\rSync MoChA (T:%d, offset:%d, blank:%d frames): %s' %
                        (streaming.offset + eout_chunk.size(1) * streaming.factor,
                         self.dec_fwd.n_frames * streaming.factor,
                         streaming.n_blanks * streaming.factor,
                         idx2token(best_hyp_id_prefix)), end='')

            if is_reset:
                if chunk_sync:
                    # pick up the best hyp from ended and active hypotheses
                    if len(best_hyp_id_prefix) > 0:
                        best_hyp_id_stream.extend(best_hyp_id_prefix)
                else:
                    # Global decoding over the segmented region. This is run
                    # once; it used to be repeated without CTC, which
                    # discarded the CTC-fused result.
                    eout = torch.cat(streaming.eout_chunks, dim=1)
                    elens = torch.IntTensor([eout.size(1)])
                    ctc_log_probs = None
                    if params['recog_ctc_weight'] > 0:
                        ctc_log_probs = torch.log(self.dec_fwd.ctc_probs(eout))
                    nbest_hyps_id_offline, _, _ = self.dec_fwd.beam_search(
                        eout, elens, global_params, idx2token, lm, lm_second,
                        ctc_log_probs=ctc_log_probs)
                    if len(nbest_hyps_id_offline[0][0]) > 0:
                        best_hyp_id_stream.extend(nbest_hyps_id_offline[0][0])
                # the finished segment has been emitted; do not emit it again
                best_hyp_id_prefix = []

                # reset
                streaming.reset()
                hyps = None

                # next chunk will start from the frame next to the boundary
                if (not is_last_chunk and
                        0 <= streaming.bd_offset * streaming.factor < streaming.N_l - 1):
                    n_back = x_chunk[
                        (streaming.bd_offset + 1) * streaming.factor:streaming.N_l].shape[0]
                    streaming.offset -= n_back
                    self.dec_fwd.n_frames -= n_back // streaming.factor

            streaming.next_chunk()
            if is_last_chunk:
                break

        # Global decoding over the last chunk
        if not chunk_sync and len(streaming.eout_chunks) > 0:
            eout = torch.cat(streaming.eout_chunks, dim=1)
            elens = torch.IntTensor([eout.size(1)])
            nbest_hyps_id_offline, _, _ = self.dec_fwd.beam_search(
                eout, elens, global_params, idx2token, lm, lm_second, None)
            if len(nbest_hyps_id_offline[0][0]) > 0:
                best_hyp_id_stream.extend(nbest_hyps_id_offline[0][0])

        # pick up the best hyp of the unfinished last segment
        if not is_reset and chunk_sync and len(best_hyp_id_prefix) > 0:
            best_hyp_id_stream.extend(best_hyp_id_prefix)

        if len(best_hyp_id_stream) > 0:
            return [np.stack(best_hyp_id_stream, axis=0)], [None]
        return [[]], [None]
def decode_streaming(self, xs, params, idx2token, exclude_eos=False, task='ys'):
    """Simulate streaming decoding with block-wise encoding.

    The input is encoded block by block (with lookback/lookahead context
    provided by ``streaming``). Segment boundaries ("resets") come from the
    CTC-based VAD or, in block-synchronous mode, from the best RNN-decoder
    hypothesis ending with <eos>. At every boundary the finished segment
    is emitted: in block-synchronous mode as the best running prefix,
    otherwise as the result of a global beam search over the segment's
    encoder outputs.

    Args:
        xs (list): batch of input features; only a batch of one is supported
        params (dict): hyper-parameters for decoding
        idx2token: converts token IDs to text (used for progress printing)
        exclude_eos (bool): unused here; kept for interface compatibility
        task (str): task to evaluate; only ``'ys'`` is supported
    Returns:
        tuple: ``([[hyp_ids]], [None])`` with the stacked token IDs of all
            segments, or ``([[[]]], [None])`` when nothing was recognized

    """
    assert task == 'ys'
    assert self.input_type == 'speech'
    assert self.ctc_weight > 0
    assert self.fwd_weight > 0
    assert len(xs) == 1  # batch size
    # assert params['recog_length_norm']
    global_params = copy.deepcopy(params)
    global_params['recog_max_len_ratio'] = 1.0

    block_sync = params['recog_block_sync']

    streaming = Streaming(xs[0], params, self.enc, params['recog_block_sync_size'])

    hyps = None
    best_hyp_id_stream = []
    # best partial hypothesis of the current segment (always bound)
    best_hyp_id_prefix = []
    is_reset = True  # for the first block

    stdout = False

    self.eval()
    with torch.no_grad():
        lm = getattr(self, 'lm_fwd', None)
        lm_second = getattr(self, 'lm_second', None)

        while True:
            # Encode input features block by block
            x_chunk, is_last_block, lookback, lookahead = streaming.extract_feature()
            if is_reset:
                self.enc.reset_cache()
            eout_chunk_dict = self.encode([x_chunk], 'all',
                                          streaming=True,
                                          lookback=lookback,
                                          lookahead=lookahead)
            eout_chunk = eout_chunk_dict[task]['xs']
            is_reset = False  # detect the first boundary in the same block

            # CTC-based VAD
            if streaming.is_ctc_vad:
                if self.ctc_weight_sub1 > 0:
                    ctc_probs_chunk = self.dec_fwd_sub1.ctc_probs(
                        eout_chunk_dict['ys_sub1']['xs'])
                    # TODO: consider subsampling
                else:
                    ctc_probs_chunk = self.dec_fwd.ctc_probs(eout_chunk)
                is_reset = streaming.ctc_vad(ctc_probs_chunk, stdout=stdout)

            # Truncate the most right frames
            if is_reset and not is_last_block and streaming.bd_offset >= 0:
                eout_chunk = eout_chunk[:, :streaming.bd_offset]
            streaming.eout_chunks.append(eout_chunk)

            # Chunk-synchronous attention decoding
            if isinstance(self.dec_fwd, RNNDecoder) and block_sync:
                end_hyps, hyps, _ = self.dec_fwd.beam_search_block_sync(
                    eout_chunk, params, idx2token, lm,
                    hyps=hyps, state_carry_over=False)
                merged_hyps = sorted(end_hyps + hyps,
                                     key=lambda x: x['score'], reverse=True)
                # guard: indexing [0] of an empty list would raise IndexError
                if len(merged_hyps) > 0:
                    # drop the first token of the hypothesis (presumably <sos>)
                    best_hyp_id_prefix = np.array(merged_hyps[0]['hyp'][1:])
                    if len(best_hyp_id_prefix) > 0 and best_hyp_id_prefix[-1] == self.eos:
                        # reset beam if <eos> is generated from the best hypothesis
                        best_hyp_id_prefix = best_hyp_id_prefix[:-1]  # exclude <eos>
                        # Segmentation strategy 2:
                        # If <eos> is emitted from the decoder (not CTC),
                        # the current block is segmented.
                        if not is_reset:
                            streaming.bd_offset = eout_chunk.size(1) - 1
                            is_reset = True
                    if len(best_hyp_id_prefix) > 0:
                        print('\r%s' % (idx2token(best_hyp_id_prefix)))

            if is_reset:
                if not block_sync:
                    # Global decoding over the segmented region
                    eout = torch.cat(streaming.eout_chunks, dim=1)
                    elens = torch.IntTensor([eout.size(1)])
                    ctc_log_probs = None
                    if params['recog_ctc_weight'] > 0:
                        ctc_log_probs = torch.log(self.dec_fwd.ctc_probs(eout))
                    nbest_hyps_id_offline = self.dec_fwd.beam_search(
                        eout, elens, global_params, idx2token, lm, lm_second,
                        ctc_log_probs=ctc_log_probs)[0]
                    if len(nbest_hyps_id_offline[0][0]) > 0:
                        best_hyp_id_stream.extend(nbest_hyps_id_offline[0][0])
                elif len(best_hyp_id_prefix) > 0:
                    # pick up the best hyp from ended and active hypotheses
                    best_hyp_id_stream.extend(best_hyp_id_prefix)
                # the finished segment has been emitted; do not emit it again
                best_hyp_id_prefix = []

                # reset
                streaming.reset(stdout=stdout)
                hyps = None

            streaming.next_chunk()
            if is_last_block:
                break
            # next block will start from the frame next to the boundary
            streaming.backoff(x_chunk, self.dec_fwd, stdout=stdout)

        # Global decoding for tail chunks
        if not block_sync and len(streaming.eout_chunks) > 0:
            eout = torch.cat(streaming.eout_chunks, dim=1)
            elens = torch.IntTensor([eout.size(1)])
            nbest_hyps_id_offline = self.dec_fwd.beam_search(
                eout, elens, global_params, idx2token, lm, lm_second)[0]
            if len(nbest_hyps_id_offline[0][0]) > 0:
                best_hyp_id_stream.extend(nbest_hyps_id_offline[0][0])

        # pick up the best hyp of the unfinished last segment
        if not is_reset and block_sync and len(best_hyp_id_prefix) > 0:
            best_hyp_id_stream.extend(best_hyp_id_prefix)

        if len(best_hyp_id_stream) > 0:
            return [[np.stack(best_hyp_id_stream, axis=0)]], [None]
        return [[[]]], [None]
def decode_streaming(self, xs, params, idx2token, exclude_eos=False, task='ys'):
    """Simulate streaming encoding+decoding.

    Both encoding and decoding are performed in the online mode, and only
    block-synchronous decoding is supported (``params['recog_block_sync']``
    must be set). Each encoded block is decoded with the CTC branch (when a
    CTC weight equals 1), an RNN-T decoder, or an RNN attention decoder (in
    sub-blocks of ``recog_block_sync_size`` frames after subsampling).

    A segment boundary ("reset") comes from the CTC-based VAD, or from the
    decoder when the best hypothesis ends with <eos> or no active
    hypotheses remain. At each boundary the best prefix of the segment is
    emitted.

    Args:
        xs (list): batch of input features; only a batch of one is supported
        params (dict): hyper-parameters for decoding
        idx2token: converts token IDs to text (used for progress printing)
        exclude_eos (bool): unused here; kept for interface compatibility
        task (str): task to evaluate; only ``'ys'`` is supported
    Returns:
        tuple: ``([[hyp_ids]], [None])`` with the stacked token IDs of all
            segments, or ``([[[]]], [None])`` when nothing was recognized

    """
    block_size = params.get('recog_block_sync_size')  # before subsampling
    cache_emb = params.get('recog_cache_embedding')
    ctc_weight = params.get('recog_ctc_weight')

    assert task == 'ys'
    assert self.input_type == 'speech'
    assert self.ctc_weight > 0
    assert self.fwd_weight > 0 or self.ctc_weight == 1.0
    assert len(xs) == 1  # batch size
    assert params.get('recog_block_sync')
    # assert params.get('recog_length_norm')

    streaming = Streaming(xs[0], params, self.enc)
    factor = self.enc.subsampling_factor
    block_size //= factor
    assert block_size >= 1, "block_size is too small."
    # decode with the CTC branch alone
    ctc_only = ctc_weight == 1 or self.ctc_weight == 1

    hyps = None
    best_hyp_id_stream = []
    # best partial hypothesis of the current segment (always bound)
    best_hyp_id_prefix = []
    is_reset = True  # for the first block

    stdout = False

    self.eval()
    helper = BeamSearch(params.get('recog_beam_width'), self.eos,
                        params.get('recog_ctc_weight'),
                        params.get('recog_lm_weight'), self.device)

    lm = getattr(self, 'lm_fwd', None)
    lm_second = getattr(self, 'lm_second', None)
    lm = helper.verify_lm_eval_mode(lm, params.get('recog_lm_weight'), cache_emb)
    if lm is not None:
        assert isinstance(lm, RNNLM)
    lm_second = helper.verify_lm_eval_mode(
        lm_second, params.get('recog_lm_second_weight'), cache_emb)

    # cache token embeddings
    if cache_emb and self.fwd_weight > 0:
        self.dec_fwd.cache_embedding(self.device)

    while True:
        # Encode input features block by block
        (x_block, is_last_block, cnn_lookback,
         cnn_lookahead, xlen_block) = streaming.extract_feature()
        if is_reset:
            self.enc.reset_cache()
        eout_block_dict = self.encode([x_block], 'all',
                                      streaming=True,
                                      cnn_lookback=cnn_lookback,
                                      cnn_lookahead=cnn_lookahead,
                                      xlen_block=xlen_block)
        eout_block = eout_block_dict[task]['xs']
        is_reset = False  # detect the first boundary in the same block

        # CTC-based VAD
        if streaming.is_ctc_vad:
            if self.ctc_weight_sub1 > 0:
                ctc_probs_block = self.dec_fwd_sub1.ctc.probs(
                    eout_block_dict['ys_sub1']['xs'])
                # TODO: consider subsampling
            else:
                ctc_probs_block = self.dec_fwd.ctc.probs(eout_block)
            is_reset = streaming.ctc_vad(ctc_probs_block, stdout=stdout)

        # Truncate the most right frames
        if is_reset and not is_last_block and streaming.bd_offset >= 0:
            eout_block = eout_block[:, :streaming.bd_offset]
        streaming.cache_eout(eout_block)

        # Block-synchronous decoding. Skipped when the truncation left zero
        # frames: the RNN sub-block loop would then not run and `end_hyps` /
        # `hyps` would be unbound or None.
        if eout_block.size(1) > 0:
            if ctc_only:
                end_hyps, hyps = self.dec_fwd.ctc.beam_search_block_sync(
                    eout_block, params, helper, idx2token, hyps, lm)
            elif isinstance(self.dec_fwd, RNNT):
                end_hyps, hyps = self.dec_fwd.beam_search_block_sync(
                    eout_block, params, helper, idx2token, hyps, lm)
            elif isinstance(self.dec_fwd, RNNDecoder):
                for i in range(math.ceil(eout_block.size(1) / block_size)):
                    eout_block_i = eout_block[:, i * block_size:(i + 1) * block_size]
                    end_hyps, hyps, _ = self.dec_fwd.beam_search_block_sync(
                        eout_block_i, params, helper, idx2token, hyps, lm)
            elif isinstance(self.dec_fwd, TransformerDecoder):
                raise NotImplementedError
            else:
                raise NotImplementedError(self.dec_fwd)

            merged_hyps = sorted(end_hyps + hyps,
                                 key=lambda x: x['score'], reverse=True)
            if len(merged_hyps) > 0:
                # drop the first token of the hypothesis (presumably <sos>)
                best_hyp_id_prefix = np.array(merged_hyps[0]['hyp'][1:])
                ends_with_eos = (len(best_hyp_id_prefix) > 0 and
                                 best_hyp_id_prefix[-1] == self.eos)
                if ends_with_eos:
                    # Exclude <eos> only when it is actually there. Slicing
                    # unconditionally dropped a real token when the beam ran
                    # out of active hypotheses.
                    best_hyp_id_prefix = best_hyp_id_prefix[:-1]
                if len(hyps) == 0 or ends_with_eos:
                    # Segmentation strategy 2:
                    # If <eos> is emitted from the decoder (not CTC) or no
                    # active hypotheses remain, the current block is segmented.
                    if not is_reset:
                        streaming._bd_offset = eout_block.size(1) - 1  # TODO: fix later
                        is_reset = True
                if len(best_hyp_id_prefix) > 0:
                    n_frames = self.dec_fwd.ctc.n_frames if ctc_only else self.dec_fwd.n_frames
                    print(
                        '\rStreaming (T:%d [10ms], offset:%d [10ms], blank:%d [10ms]): %s' %
                        (streaming.offset + eout_block.size(1) * factor,
                         n_frames * factor,
                         streaming.n_blanks * factor,
                         idx2token(best_hyp_id_prefix)))

        if is_reset:
            # pick up the best hyp from ended and active hypotheses
            if len(best_hyp_id_prefix) > 0:
                best_hyp_id_stream.extend(best_hyp_id_prefix)
            # the finished segment has been emitted; do not emit it again
            best_hyp_id_prefix = []

            # reset
            streaming.reset(stdout=stdout)
            hyps = None

        streaming.next_block()
        if is_last_block:
            break
        # next block will start from the frame next to the boundary
        streaming.backoff(x_block, self.dec_fwd, stdout=stdout)

    # pick up the best hyp of the unfinished last segment
    if not is_reset and len(best_hyp_id_prefix) > 0:
        best_hyp_id_stream.extend(best_hyp_id_prefix)

    if len(best_hyp_id_stream) > 0:
        return [[np.stack(best_hyp_id_stream, axis=0)]], [None]
    return [[[]]], [None]
def decode_streaming(self, xs, params, idx2token, exclude_eos=False,
                     speaker=None, task='ys'):
    """Simulate streaming encoding+decoding.

    Both encoding and decoding are performed in the online mode, and only
    block-synchronous decoding is supported (``params['recog_block_sync']``
    must be set). Each non-empty encoded block is decoded with the CTC
    branch (when a CTC weight equals 1), an RNN-T decoder, or an RNN
    attention decoder (in sub-blocks of ``recog_block_sync_size`` frames
    after subsampling).

    A reset point comes from CTC-based reset point detection, or from the
    decoder when the best hypothesis ends with <eos> (unless
    ``streaming.safeguard_reset`` is set). At each reset point the best
    prefix of the finished segment is appended to the session output.

    Args:
        xs (list): batch of input features; only a batch of one is supported
        params (dict): hyper-parameters for decoding
        idx2token: converts token IDs to text (used for progress printing)
        exclude_eos (bool): unused here; kept for interface compatibility
        speaker: forwarded to the RNN decoder's block-synchronous search
        task (str): task to evaluate; only ``'ys'`` is supported
    Returns:
        tuple: ``([[hyp_ids]], [None])`` with the stacked token IDs of the
            session, or ``([[[]]], [None])`` when nothing was recognized

    """
    self.eval()
    block_size = params.get('recog_block_sync_size')  # before subsampling
    cache_emb = params.get('recog_cache_embedding')
    ctc_weight = params.get('recog_ctc_weight')
    backoff = True

    assert task == 'ys'
    assert self.input_type == 'speech'
    assert self.ctc_weight > 0
    assert self.fwd_weight > 0 or self.ctc_weight == 1.0
    assert len(xs) == 1  # batch size
    assert params.get('recog_block_sync')
    # assert params.get('recog_length_norm')

    streaming = Streaming(xs[0], params, self.enc, idx2token)
    factor = self.enc.subsampling_factor
    block_size //= factor
    assert block_size >= 1, "block_size is too small."
    is_transformer_enc = 'former' in self.enc.enc_type
    # decode with the CTC branch alone
    ctc_only = ctc_weight == 1 or self.ctc_weight == 1

    hyps = None
    hyps_nobd = []
    best_hyp_id_session = []
    # best partial hypothesis of the current segment (always bound)
    best_hyp_id_prefix = []
    is_reset = False

    helper = BeamSearch(params.get('recog_beam_width'), self.eos,
                        params.get('recog_ctc_weight'),
                        params.get('recog_lm_weight'), self.device)

    lm = getattr(self, 'lm_fwd', None)
    lm_second = getattr(self, 'lm_second', None)
    lm = helper.verify_lm_eval_mode(lm, params.get('recog_lm_weight'), cache_emb)
    if lm is not None:
        assert isinstance(lm, RNNLM)
    lm_second = helper.verify_lm_eval_mode(
        lm_second, params.get('recog_lm_second_weight'), cache_emb)

    # cache token embeddings
    if cache_emb and self.fwd_weight > 0:
        self.dec_fwd.cache_embedding(self.device)

    self.enc.reset_cache()
    x_block_prev, xlen_block_prev = None, None
    while True:
        # Encode input features block by block
        (x_block, is_last_block, cnn_lookback,
         cnn_lookahead, xlen_block) = streaming.extract_feat()
        if not is_transformer_enc and is_reset:
            self.enc.reset_cache()
            if backoff:
                # Output discarded: the previous block is re-encoded only for
                # its effect on the encoder state (presumably its cache).
                self.encode([x_block_prev], 'all',
                            streaming=True,
                            cnn_lookback=cnn_lookback,
                            cnn_lookahead=cnn_lookahead,
                            xlen_block=xlen_block_prev)
        x_block_prev = x_block
        xlen_block_prev = xlen_block

        eout_block_dict = self.encode([x_block], 'all',
                                      streaming=True,
                                      cnn_lookback=cnn_lookback,
                                      cnn_lookahead=cnn_lookahead,
                                      xlen_block=xlen_block)
        eout_block = eout_block_dict[task]['xs']

        # Clear the flag for every block, including empty ones. Previously it
        # was only cleared for non-empty blocks, so an empty block following a
        # reset re-ran the reset branch and emitted the same tokens twice.
        is_reset = False

        if eout_block.size(1) > 0:
            streaming.cache_eout(eout_block)

            # Block-synchronous decoding
            if ctc_only:
                end_hyps, hyps = self.dec_fwd.ctc.beam_search_block_sync(
                    eout_block, params, helper, idx2token, hyps, lm)
            elif isinstance(self.dec_fwd, RNNT):
                end_hyps, hyps = self.dec_fwd.beam_search_block_sync(
                    eout_block, params, helper, idx2token, hyps, lm)
            elif isinstance(self.dec_fwd, RNNDecoder):
                for i in range(math.ceil(eout_block.size(1) / block_size)):
                    eout_block_i = eout_block[:, i * block_size:(i + 1) * block_size]
                    end_hyps, hyps, hyps_nobd = self.dec_fwd.beam_search_block_sync(
                        eout_block_i, params, helper, idx2token, hyps, hyps_nobd, lm,
                        speaker=speaker)
            elif isinstance(self.dec_fwd, TransformerDecoder):
                raise NotImplementedError
            else:
                raise NotImplementedError(self.dec_fwd)

            # CTC-based reset point detection
            if streaming.enable_ctc_reset_point_detection:
                if self.ctc_weight_sub1 > 0:
                    ctc_probs_block = self.dec_fwd_sub1.ctc.probs(
                        eout_block_dict['ys_sub1']['xs'])
                    # TODO: consider subsampling
                else:
                    ctc_probs_block = self.dec_fwd.ctc.probs(eout_block)
                is_reset = streaming.ctc_reset_point_detection(ctc_probs_block)

            merged_hyps = sorted(end_hyps + hyps + hyps_nobd,
                                 key=lambda x: x['score'], reverse=True)
            if len(merged_hyps) > 0:
                # drop the first token of the hypothesis (presumably <sos>)
                best_hyp_id_prefix = np.array(merged_hyps[0]['hyp'][1:])
                # Kept for display: still contains a trailing <eos>, if any.
                # Safe to alias because the slice below rebinds the name
                # instead of mutating the array.
                best_hyp_id_prefix_viz = best_hyp_id_prefix
                if len(best_hyp_id_prefix) > 0 and best_hyp_id_prefix[-1] == self.eos:
                    # reset beam if <eos> is generated from the best hypothesis
                    best_hyp_id_prefix = best_hyp_id_prefix[:-1]  # exclude <eos>
                    # Condition 2:
                    # If <eos> is emitted from the decoder (not CTC),
                    # the current block is segmented.
                    if (not is_reset) and (not streaming.safeguard_reset):
                        is_reset = True

                if len(best_hyp_id_prefix_viz) > 0:
                    n_frames = self.dec_fwd.ctc.n_frames if ctc_only else self.dec_fwd.n_frames
                    print(
                        '\rStreaming (T:%.3f [sec], offset:%d [frame], blank:%d [frame]): %s' %
                        ((streaming.offset + eout_block.size(1) * factor) / 100,
                         n_frames * factor,
                         streaming.n_blanks * factor,
                         idx2token(best_hyp_id_prefix_viz)))

        if is_reset:
            # pick up the best hyp from ended and active hypotheses
            if len(best_hyp_id_prefix) > 0:
                best_hyp_id_session.extend(best_hyp_id_prefix)
            # the finished segment has been emitted; do not emit it again
            best_hyp_id_prefix = []

            # reset
            streaming.reset()
            hyps = None
            hyps_nobd = []

        streaming.next_block()
        if is_last_block:
            break

    # pick up the best hyp of the unfinished last segment
    if not is_reset and len(best_hyp_id_prefix) > 0:
        best_hyp_id_session.extend(best_hyp_id_prefix)

    if len(best_hyp_id_session) > 0:
        return [[np.stack(best_hyp_id_session, axis=0)]], [None]
    return [[[]]], [None]