def decode_streaming(self, xs, params, idx2token, exclude_eos=False, task='ys'):
    """Decode a single utterance chunk by chunk (streaming recognition).

    The input is encoded one chunk at a time. A segment boundary is declared
    when ``streaming.ctc_vad`` reports one or, in chunk-synchronous mode,
    when the best hypothesis ends with ``self.eos``. At every boundary the
    best hypothesis of the segment is appended to the output and the
    streaming/beam state is reset.

    Args:
        xs (list): input features; must hold exactly one utterance
        params (dict): decoding hyperparameters. Keys read directly here are
            ``recog_chunk_sync`` and ``recog_ctc_weight``; the whole dict is
            also forwarded to ``Streaming`` and the beam-search routines.
            A deep copy with ``recog_max_len_ratio`` forced to 1.0 is used
            for the offline (per-segment) beam search.
        idx2token (callable): maps token ids to a printable string
        exclude_eos (bool): accepted for interface compatibility; not used
            by this method
        task (str): must be ``'ys'``
    Returns:
        tuple: ``([ids], [None])`` where ``ids`` is the result of
            ``np.stack`` over all emitted token ids, or ``([[]], [None])``
            when nothing was emitted

    """
    from neural_sp.models.seq2seq.frontends.streaming import Streaming

    # check configurations
    assert task == 'ys'
    assert self.input_type == 'speech'
    assert self.ctc_weight > 0
    assert self.fwd_weight > 0
    assert len(xs) == 1  # batch size
    global_params = copy.deepcopy(params)
    global_params['recog_max_len_ratio'] = 1.0
    chunk_sync = params['recog_chunk_sync']

    streaming = Streaming(xs[0], params, self.enc, idx2token)

    hyps = None
    best_hyp_id_stream = []
    is_reset = True  # for the first chunk

    self.eval()
    with torch.no_grad():
        lm = getattr(self, 'lm_fwd', None)
        lm_second = getattr(self, 'lm_second', None)

        while True:
            # Encode input features chunk by chunk
            x_chunk, is_last_chunk = streaming.extract_feature()
            eout_chunk = self.encode([x_chunk], task,
                                     use_cache=not is_reset,
                                     streaming=True)[task]['xs']
            is_reset = False  # detect the first boundary in the same chunk

            # CTC-based VAD
            ctc_log_probs_chunk = None
            if streaming.is_ctc_vad:
                ctc_probs_chunk = self.dec_fwd.ctc_probs(eout_chunk)
                if params['recog_ctc_weight'] > 0:
                    ctc_log_probs_chunk = torch.log(ctc_probs_chunk)
                is_reset = streaming.ctc_vad(ctc_probs_chunk)

            # Truncate the most right frames
            if is_reset and not is_last_chunk:
                eout_chunk = eout_chunk[:, :streaming.bd_offset + 1]
            streaming.eout_chunks.append(eout_chunk)

            # Chunk-synchronous attention decoding
            if chunk_sync:
                end_hyps, hyps, _ = self.dec_fwd.beam_search_chunk_sync(
                    eout_chunk, params, idx2token, lm,
                    ctc_log_probs=ctc_log_probs_chunk, hyps=hyps,
                    state_carry_over=False,
                    ignore_eos=self.enc.rnn_type in ['lstm', 'conv_lstm'])
                merged_hyps = sorted(end_hyps + hyps,
                                     key=lambda x: x['score'], reverse=True)
                best_hyp_id_prefix = np.array(merged_hyps[0]['hyp'][1:])
                if len(best_hyp_id_prefix) > 0 and best_hyp_id_prefix[-1] == self.eos:
                    # reset beam if <eos> is generated from the best hypothesis
                    best_hyp_id_prefix = best_hyp_id_prefix[:-1]  # exclude <eos>
                    # Segmentation strategy 2:
                    # If <eos> is emitted from the decoder (not CTC),
                    # the current chunk is segmented.
                    if not is_reset:
                        streaming.bd_offset = eout_chunk.size(1) - 1
                        is_reset = True
                print('\rSync MoChA (T:%d, offset:%d, blank:%d frames): %s' %
                      (streaming.offset + eout_chunk.size(1) * streaming.factor,
                       self.dec_fwd.n_frames * streaming.factor,
                       streaming.n_blanks * streaming.factor,
                       idx2token(best_hyp_id_prefix)), end='')

            if is_reset:
                if not chunk_sync:
                    # Global decoding over the segmented region.
                    # The beam search runs exactly once per segment so that
                    # the CTC scores (when enabled) are actually used for
                    # the hypothesis that gets emitted.
                    eout = torch.cat(streaming.eout_chunks, dim=1)
                    elens = torch.IntTensor([eout.size(1)])
                    ctc_log_probs = None
                    if params['recog_ctc_weight'] > 0:
                        ctc_log_probs = torch.log(self.dec_fwd.ctc_probs(eout))
                    nbest_hyps_id_offline, _, _ = self.dec_fwd.beam_search(
                        eout, elens, global_params, idx2token, lm, lm_second,
                        ctc_log_probs=ctc_log_probs)
                    # pick up the best offline hypothesis
                    if len(nbest_hyps_id_offline[0][0]) > 0:
                        best_hyp_id_stream.extend(nbest_hyps_id_offline[0][0])
                else:
                    # pick up the best hyp from ended and active hypotheses
                    if len(best_hyp_id_prefix) > 0:
                        best_hyp_id_stream.extend(best_hyp_id_prefix)

                # reset
                streaming.reset()
                hyps = None

                # next chunk will start from the frame next to the boundary
                if not is_last_chunk and 0 <= streaming.bd_offset * streaming.factor < streaming.N_l - 1:
                    n_back = x_chunk[(streaming.bd_offset + 1) * streaming.factor:
                                     streaming.N_l].shape[0]
                    streaming.offset -= n_back
                    self.dec_fwd.n_frames -= n_back // streaming.factor

            streaming.next_chunk()
            if is_last_chunk:
                break

        # Global decoding over the last chunk
        if not chunk_sync and len(streaming.eout_chunks) > 0:
            eout = torch.cat(streaming.eout_chunks, dim=1)
            elens = torch.IntTensor([eout.size(1)])
            nbest_hyps_id_offline, _, _ = self.dec_fwd.beam_search(
                eout, elens, global_params, idx2token, lm, lm_second, None)
            if len(nbest_hyps_id_offline[0][0]) > 0:
                best_hyp_id_stream.extend(nbest_hyps_id_offline[0][0])

        # pick up the best hyp
        if not is_reset and chunk_sync and len(best_hyp_id_prefix) > 0:
            best_hyp_id_stream.extend(best_hyp_id_prefix)

        if len(best_hyp_id_stream) > 0:
            return [np.stack(best_hyp_id_stream, axis=0)], [None]
        else:
            return [[]], [None]
def decode_streaming(self, xs, params, idx2token, exclude_eos=False, task='ys'):
    """Decode a single utterance block by block (streaming recognition).

    The input is encoded one block at a time. A segment boundary is declared
    when ``streaming.ctc_vad`` reports one or, in block-synchronous mode
    with an ``RNNDecoder``, when the best hypothesis ends with ``self.eos``.
    At every boundary the best hypothesis of the segment is appended to the
    output and the streaming/beam state is reset.

    Args:
        xs (list): input features; must hold exactly one utterance
        params (dict): decoding hyperparameters. Keys read directly here are
            ``recog_block_sync``, ``recog_block_sync_size`` and
            ``recog_ctc_weight``; the whole dict is also forwarded to
            ``Streaming`` and the beam-search routines. A deep copy with
            ``recog_max_len_ratio`` forced to 1.0 is used for the offline
            (per-segment) beam search.
        idx2token (callable): maps token ids to a printable string
        exclude_eos (bool): accepted for interface compatibility; not used
            by this method
        task (str): must be ``'ys'``
    Returns:
        tuple: ``([[ids]], [None])`` where ``ids`` is the result of
            ``np.stack`` over all emitted token ids, or ``([[[]]], [None])``
            when nothing was emitted

    """
    # check configurations
    assert task == 'ys'
    assert self.input_type == 'speech'
    assert self.ctc_weight > 0
    assert self.fwd_weight > 0
    assert len(xs) == 1  # batch size
    global_params = copy.deepcopy(params)
    global_params['recog_max_len_ratio'] = 1.0
    block_sync = params['recog_block_sync']

    streaming = Streaming(xs[0], params, self.enc,
                          params['recog_block_sync_size'])

    hyps = None
    best_hyp_id_stream = []
    # Block-synchronous decoding below only runs for RNNDecoder. Start from
    # an empty prefix so the later reads are always bound, even when
    # block_sync is requested with another decoder type.
    best_hyp_id_prefix = []
    is_reset = True  # for the first block
    stdout = False

    self.eval()
    with torch.no_grad():
        lm = getattr(self, 'lm_fwd', None)
        lm_second = getattr(self, 'lm_second', None)

        while True:
            # Encode input features block by block
            x_chunk, is_last_block, lookback, lookahead = streaming.extract_feature()
            if is_reset:
                self.enc.reset_cache()
            eout_chunk_dict = self.encode([x_chunk], 'all',
                                          streaming=True,
                                          lookback=lookback,
                                          lookahead=lookahead)
            eout_chunk = eout_chunk_dict[task]['xs']
            is_reset = False  # detect the first boundary in the same block

            # CTC-based VAD
            if streaming.is_ctc_vad:
                if self.ctc_weight_sub1 > 0:
                    ctc_probs_chunk = self.dec_fwd_sub1.ctc_probs(
                        eout_chunk_dict['ys_sub1']['xs'])
                    # TODO: consider subsampling
                else:
                    ctc_probs_chunk = self.dec_fwd.ctc_probs(eout_chunk)
                is_reset = streaming.ctc_vad(ctc_probs_chunk, stdout=stdout)

            # Truncate the most right frames
            if is_reset and not is_last_block and streaming.bd_offset >= 0:
                eout_chunk = eout_chunk[:, :streaming.bd_offset]
            streaming.eout_chunks.append(eout_chunk)

            # Block-synchronous attention decoding
            if isinstance(self.dec_fwd, RNNDecoder) and block_sync:
                end_hyps, hyps, _ = self.dec_fwd.beam_search_block_sync(
                    eout_chunk, params, idx2token, lm,
                    hyps=hyps, state_carry_over=False)
                merged_hyps = sorted(end_hyps + hyps,
                                     key=lambda x: x['score'], reverse=True)
                best_hyp_id_prefix = np.array(merged_hyps[0]['hyp'][1:])
                if len(best_hyp_id_prefix) > 0 and best_hyp_id_prefix[-1] == self.eos:
                    # reset beam if <eos> is generated from the best hypothesis
                    best_hyp_id_prefix = best_hyp_id_prefix[:-1]  # exclude <eos>
                    # Segmentation strategy 2:
                    # If <eos> is emitted from the decoder (not CTC),
                    # the current block is segmented.
                    if not is_reset:
                        streaming.bd_offset = eout_chunk.size(1) - 1
                        is_reset = True
                if len(best_hyp_id_prefix) > 0:
                    print('\r%s' % (idx2token(best_hyp_id_prefix)))

            if is_reset:
                if not block_sync:
                    # Global decoding over the segmented region
                    eout = torch.cat(streaming.eout_chunks, dim=1)
                    elens = torch.IntTensor([eout.size(1)])
                    ctc_log_probs = None
                    if params['recog_ctc_weight'] > 0:
                        ctc_log_probs = torch.log(self.dec_fwd.ctc_probs(eout))
                    nbest_hyps_id_offline = self.dec_fwd.beam_search(
                        eout, elens, global_params, idx2token, lm, lm_second,
                        ctc_log_probs=ctc_log_probs)[0]
                    # pick up the best offline hypothesis
                    if len(nbest_hyps_id_offline[0][0]) > 0:
                        best_hyp_id_stream.extend(nbest_hyps_id_offline[0][0])
                else:
                    # pick up the best hyp from ended and active hypotheses
                    if len(best_hyp_id_prefix) > 0:
                        best_hyp_id_stream.extend(best_hyp_id_prefix)

                # reset
                streaming.reset(stdout=stdout)
                hyps = None

            streaming.next_chunk()
            # next block will start from the frame next to the boundary
            if not is_last_block:
                streaming.backoff(x_chunk, self.dec_fwd, stdout=stdout)
            if is_last_block:
                break

        # Global decoding for tail chunks
        if not block_sync and len(streaming.eout_chunks) > 0:
            eout = torch.cat(streaming.eout_chunks, dim=1)
            elens = torch.IntTensor([eout.size(1)])
            nbest_hyps_id_offline = self.dec_fwd.beam_search(
                eout, elens, global_params, idx2token, lm, lm_second)[0]
            if len(nbest_hyps_id_offline[0][0]) > 0:
                best_hyp_id_stream.extend(nbest_hyps_id_offline[0][0])

        # pick up the best hyp
        if not is_reset and block_sync and len(best_hyp_id_prefix) > 0:
            best_hyp_id_stream.extend(best_hyp_id_prefix)

        if len(best_hyp_id_stream) > 0:
            return [[np.stack(best_hyp_id_stream, axis=0)]], [None]
        else:
            return [[[]]], [None]