def encode_streaming(self, xs, params, task='ys'):
    """Run the encoder over a single utterance block by block.

    Only the encoder is streamed here; the concatenated encoder outputs are
    returned so that decoding can be done afterwards in the offline mode.

    Args:
        xs (FloatTensor): `[B, T, idim]`; only `B == 1` is accepted
        params (dict): hyper-parameters for decoding (forwarded to `Streaming`)
        task (str): task to evaluate; must be 'ys'
    Returns:
        eout (FloatTensor): encoder outputs of all blocks, as returned by
            `streaming.pop_eouts()`
        elens (IntTensor): `[1]`, holding `eout.size(1)`

    """
    assert task == 'ys'
    assert self.input_type == 'speech'
    assert self.fwd_weight > 0
    assert len(xs) == 1  # batch size

    stream = Streaming(xs[0], params, self.enc)
    self.enc.reset_cache()

    # Feed the encoder one block at a time until the last block is consumed.
    finished = False
    while not finished:
        feats, finished, lookback, lookahead, n_frames = stream.extract_feature()
        enc_out = self.encode([feats], 'all',
                              streaming=True,
                              cnn_lookback=lookback,
                              cnn_lookahead=lookahead,
                              xlen_block=n_frames)
        stream.cache_eout(enc_out[task]['xs'])
        stream.next_block()

    eout = stream.pop_eouts()
    elens = torch.IntTensor([eout.size(1)])
    return eout, elens
def decode_streaming(self, xs, params, idx2token, exclude_eos=False, task='ys'):
    """Simulate streaming decoding.

    Both encoding and decoding are performed in the online mode. The input is
    encoded block by block. A segment boundary is detected either by the
    CTC-based VAD of `Streaming` or, in block-synchronous mode, when the best
    hypothesis ends with <eos>. At every boundary the best hypothesis of the
    segment is appended to the output stream and the decoding state is reset.

    Args:
        xs (list): batch holding exactly one utterance; `xs[0]` is handed to
            `Streaming`
        params (dict): hyper-parameters for decoding. Keys read here:
            `recog_block_sync`, `recog_block_sync_size` (before subsampling)
            and `recog_ctc_weight`. A deep copy with `recog_max_len_ratio`
            forced to 1.0 is used for the global (per-segment) beam search.
        idx2token (callable): converts token ids to text; used for the
            progress print and forwarded to the beam search
        exclude_eos (bool): forwarded to `self.dec_fwd.beam_search`
        task (str): task to evaluate; must be 'ys'
    Returns:
        tuple: `([[ids]], [None])` where `ids` is a 1-D array of the token ids
            of all segments concatenated, or `([[[]]], [None])` when nothing
            was decoded. The second element is always `[None]`
            (presumably a placeholder for attention weights; confirm against
            callers).
    Raises:
        NotImplementedError: if `self.dec_fwd` is an `RNNT` or a
            `TransformerDecoder`

    """
    assert task == 'ys'
    assert self.input_type == 'speech'
    assert self.ctc_weight > 0
    assert self.fwd_weight > 0
    assert len(xs) == 1  # batch size
    # assert params['recog_length_norm']

    # The global beam search runs over an already segmented region, so the
    # maximum output length ratio is fixed to 1.0 on a private copy.
    global_params = copy.deepcopy(params)
    global_params['recog_max_len_ratio'] = 1.0
    block_sync = params['recog_block_sync']
    block_size = params['recog_block_sync_size']  # before subsampling

    streaming = Streaming(xs[0], params, self.enc)
    factor = self.enc.subsampling_factor
    block_size //= factor  # after subsampling

    hyps = None
    best_hyp_id_stream = []
    # Bound up-front: the block-synchronous loop below may run zero times
    # (e.g. the block is truncated to zero frames) or yield no hypothesis,
    # and the prefix is still read afterwards.
    best_hyp_id_prefix = []
    is_reset = True  # for the first block

    stdout = False

    self.eval()
    lm = getattr(self, 'lm_fwd', None)
    lm_second = getattr(self, 'lm_second', None)

    # Inference only: no gradients are needed, so do not build the autograd graph.
    with torch.no_grad():
        while True:
            # Encode input features block by block
            x_block, is_last_block, cnn_lookback, cnn_lookahead, xlen_block = \
                streaming.extract_feature()
            if is_reset:
                self.enc.reset_cache()
            eout_block_dict = self.encode([x_block], 'all',
                                          streaming=True,
                                          cnn_lookback=cnn_lookback,
                                          cnn_lookahead=cnn_lookahead,
                                          xlen_block=xlen_block)
            eout_block = eout_block_dict[task]['xs']
            is_reset = False  # detect the first boundary in the same block

            # CTC-based VAD
            if streaming.is_ctc_vad:
                if self.ctc_weight_sub1 > 0:
                    ctc_probs_block = self.dec_fwd_sub1.ctc_probs(
                        eout_block_dict['ys_sub1']['xs'])
                    # TODO: consider subsampling
                else:
                    ctc_probs_block = self.dec_fwd.ctc_probs(eout_block)
                is_reset = streaming.ctc_vad(ctc_probs_block, stdout=stdout)

            # Truncate the right-most frames (those after the detected boundary)
            if is_reset and not is_last_block and streaming.bd_offset >= 0:
                eout_block = eout_block[:, :streaming.bd_offset]
            streaming.cache_eout(eout_block)

            # Block-synchronous attention decoding
            if isinstance(self.dec_fwd, RNNT):
                raise NotImplementedError
            elif isinstance(self.dec_fwd, RNNDecoder) and block_sync:
                for i in range(math.ceil(eout_block.size(1) / block_size)):
                    eout_block_i = eout_block[:, i * block_size:(i + 1) * block_size]
                    end_hyps, hyps, _ = self.dec_fwd.beam_search_block_sync(
                        eout_block_i, params, idx2token, hyps, lm,
                        state_carry_over=False)
                    merged_hyps = sorted(end_hyps + hyps,
                                         key=lambda x: x['score'], reverse=True)
                    if len(merged_hyps) > 0:
                        # Drop the first token of the hypothesis
                        best_hyp_id_prefix = np.array(merged_hyps[0]['hyp'][1:])
                        if len(best_hyp_id_prefix) > 0 and best_hyp_id_prefix[-1] == self.eos:
                            # reset beam if <eos> is generated from the best hypothesis
                            best_hyp_id_prefix = best_hyp_id_prefix[:-1]  # exclude <eos>
                            # Segmentation strategy 2:
                            # If <eos> is emitted from the decoder (not CTC),
                            # the current block is segmented.
                            if not is_reset:
                                streaming._bd_offset = eout_block.size(1) - 1  # TODO: fix later
                                is_reset = True
                        if len(best_hyp_id_prefix) > 0:
                            print(
                                '\rStreaming (T:%d [10ms], offset:%d [10ms], blank:%d [10ms]): %s'
                                % (streaming.offset + eout_block.size(1) * factor,
                                   self.dec_fwd.n_frames * factor,
                                   streaming.n_blanks * factor,
                                   idx2token(best_hyp_id_prefix)))
            elif isinstance(self.dec_fwd, TransformerDecoder):
                raise NotImplementedError

            if is_reset:
                if block_sync:
                    # pick up the best hyp from ended and active hypotheses
                    if len(best_hyp_id_prefix) > 0:
                        best_hyp_id_stream.extend(best_hyp_id_prefix)
                    # The segment is committed: clear the prefix so that it can
                    # never be appended a second time by a later block that
                    # produces no hypothesis.
                    best_hyp_id_prefix = []
                else:
                    # Global decoding over the segmented region
                    eout = streaming.pop_eouts()
                    elens = torch.IntTensor([eout.size(1)])
                    ctc_log_probs = None
                    if params['recog_ctc_weight'] > 0:
                        ctc_log_probs = torch.log(self.dec_fwd.ctc_probs(eout))
                    nbest_hyps_id_offline = self.dec_fwd.beam_search(
                        eout, elens, global_params, idx2token, lm, lm_second,
                        ctc_log_probs=ctc_log_probs, exclude_eos=exclude_eos)[0]
                    if len(nbest_hyps_id_offline[0][0]) > 0:
                        best_hyp_id_stream.extend(nbest_hyps_id_offline[0][0])

                # reset
                streaming.reset(stdout=stdout)
                hyps = None

            streaming.next_block()
            if is_last_block:
                break
            # next block will start from the frame next to the boundary
            streaming.backoff(x_block, self.dec_fwd, stdout=stdout)

        # Global decoding for tail blocks
        # NOTE(review): unlike the in-loop global decoding, no `ctc_log_probs`
        # is passed here -- confirm this is intended.
        if not block_sync and streaming.n_cache_block > 0:
            eout = streaming.pop_eouts()
            elens = torch.IntTensor([eout.size(1)])
            nbest_hyps_id_offline = self.dec_fwd.beam_search(
                eout, elens, global_params, idx2token, lm, lm_second,
                exclude_eos=exclude_eos)[0]
            if len(nbest_hyps_id_offline[0][0]) > 0:
                best_hyp_id_stream.extend(nbest_hyps_id_offline[0][0])

    # pick up the best hyp of the last, still open segment
    if not is_reset and block_sync and len(best_hyp_id_prefix) > 0:
        best_hyp_id_stream.extend(best_hyp_id_prefix)

    if len(best_hyp_id_stream) > 0:
        return [[np.stack(best_hyp_id_stream, axis=0)]], [None]
    return [[[]]], [None]