import numpy as np

from kaldi.lat.sausages import MinimumBayesRisk
from kaldi.online2 import (OnlineIvectorExtractorAdaptationState,
                           OnlineNnetFeaturePipeline,
                           OnlineSilenceWeighting)
from kaldi.util.table import SequentialWaveReader


def initNnetFeatPipeline(adaptation_state, asr, decodable_opts, feat_info):
    # Build a fresh feature pipeline that carries over the i-vector adaptation
    # state, attach it to the recognizer, and reset the decoder. Also create
    # the silence weighting object used to down-weight silence frames during
    # i-vector estimation.
    feat_pipeline = OnlineNnetFeaturePipeline(feat_info)
    feat_pipeline.set_adaptation_state(adaptation_state)
    asr.set_input_pipeline(feat_pipeline)
    asr.init_decoding()
    sil_weighting = OnlineSilenceWeighting(
        asr.transition_model, feat_info.silence_weighting_config,
        decodable_opts.frame_subsampling_factor)
    return feat_pipeline, sil_weighting
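
# Usage sketch for the helper above (illustrative only: `asr`, `feat_info`
# and `decodable_opts` are assumed to be built as in the __main__ sketch at
# the bottom of this file):
#
#   adaptation_state = OnlineIvectorExtractorAdaptationState.from_info(
#       feat_info.ivector_extractor_info)
#   feat_pipeline, sil_weighting = initNnetFeatPipeline(
#       adaptation_state, asr, decodable_opts, feat_info)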


def decode_chunked_partial_endpointing(asr, feat_info, decodable_opts, scp,
                                       chunk_size=1024,
                                       compute_confidences=True,
                                       asr_client=None, speaker="Speaker",
                                       pad_confidences=True):
    # Decode (chunked + partial output + endpointing
    # + ivector adaptation + silence weighting)
    adaptation_state = OnlineIvectorExtractorAdaptationState.from_info(
        feat_info.ivector_extractor_info)
    for key, wav in SequentialWaveReader(scp):
        feat_pipeline, sil_weighting = initNnetFeatPipeline(
            adaptation_state, asr, decodable_opts, feat_info)
        data = wav.data()[0]
        last_chunk = False
        utt, part = 1, 1
        prev_num_frames_decoded, offset = 0, 0
        for i in range(0, len(data), chunk_size):
            if i + chunk_size >= len(data):
                last_chunk = True
            feat_pipeline.accept_waveform(wav.samp_freq,
                                          data[i:i + chunk_size])
            if last_chunk:
                feat_pipeline.input_finished()
            if sil_weighting.active():
                # Down-weight frames classified as silence so they contribute
                # less to the i-vector estimate.
                sil_weighting.compute_current_traceback(asr.decoder)
                feat_pipeline.ivector_feature().update_frame_weights(
                    sil_weighting.get_delta_weights(
                        feat_pipeline.num_frames_ready()))
            asr.advance_decoding()
            num_frames_decoded = asr.decoder.num_frames_decoded()
            if not last_chunk:
                if asr.endpoint_detected():
                    # Endpoint detected: finalize the current utterance and
                    # emit it with per-word confidences.
                    asr.finalize_decoding()
                    out = asr.get_output()
                    mbr = MinimumBayesRisk(out["lattice"])
                    confd = mbr.get_one_best_confidences()
                    if pad_confidences:
                        token_length = len(out["text"].split())
                        # The computed confidence array can be shorter or
                        # longer than the actual token sequence.
                        if len(confd) < token_length:
                            print("WARNING: fewer computed confidences than"
                                  " tokens! Fixing this with padding!")
                            confd = np.pad(confd,
                                           [0, token_length - len(confd)],
                                           mode='constant',
                                           constant_values=1.0)
                        elif len(confd) > token_length:
                            print("WARNING: more computed confidences than"
                                  " tokens! Fixing this with slicing!")
                            confd = confd[:token_length]
                    print(confd)
                    # print(key + "-utt%d-final" % utt, out["text"],
                    #       flush=True)
                    if asr_client is not None:
                        asr_client.completeUtterance(
                            utterance=out["text"],
                            key=key + "-utt%d-part%d" % (utt, part),
                            confidences=confd)

                    # Advance the sample offset past the decoded frames and
                    # restart decoding on the remainder of the current chunk.
                    offset += int(num_frames_decoded
                                  * decodable_opts.frame_subsampling_factor
                                  * feat_pipeline.frame_shift_in_seconds()
                                  * wav.samp_freq)
                    feat_pipeline.get_adaptation_state(adaptation_state)
                    feat_pipeline, sil_weighting = initNnetFeatPipeline(
                        adaptation_state, asr, decodable_opts, feat_info)
                    remainder = data[offset:i + chunk_size]
                    feat_pipeline.accept_waveform(wav.samp_freq, remainder)
                    utt += 1
                    part = 1
                    prev_num_frames_decoded = 0
                elif num_frames_decoded > prev_num_frames_decoded:
                    # New frames were decoded: emit a partial result.
                    prev_num_frames_decoded = num_frames_decoded
                    out = asr.get_partial_output()
                    # print(key + "-utt%d-part%d" % (utt, part),
                    #       out["text"], flush=True)
                    if asr_client is not None:
                        asr_client.partialUtterance(
                            utterance=out["text"],
                            key=key + "-utt%d-part%d" % (utt, part))
                    part += 1
        # End of the waveform: finalize and emit the last utterance.
        asr.finalize_decoding()
        out = asr.get_output()
        mbr = MinimumBayesRisk(out["lattice"])
        confd = mbr.get_one_best_confidences()
        print(out)
        # print(key + "-utt%d-final" % utt, out["text"], flush=True)
        if asr_client is not None:
            asr_client.completeUtterance(
                utterance=out["text"],
                key=key + "-utt%d-part%d" % (utt, part),
                confidences=confd)
        feat_pipeline.get_adaptation_state(adaptation_state)
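

# A minimal end-to-end sketch of how decode_chunked_partial_endpointing is
# typically driven, following the standard PyKaldi online nnet3 recipe. The
# paths ("final.mdl", "HCLG.fst", "words.txt", "online.conf", "wav.scp") and
# the option values below are assumptions for illustration, not requirements
# of the code above; adjust them to your own model directory.
if __name__ == "__main__":
    from kaldi.asr import NnetLatticeFasterOnlineRecognizer
    from kaldi.decoder import LatticeFasterDecoderOptions
    from kaldi.nnet3 import NnetSimpleLoopedComputationOptions
    from kaldi.online2 import (OnlineEndpointConfig,
                               OnlineNnetFeaturePipelineConfig,
                               OnlineNnetFeaturePipelineInfo)
    from kaldi.util.options import ParseOptions

    # Feature pipeline and endpointing options, read from a Kaldi-style
    # online config file.
    feat_opts = OnlineNnetFeaturePipelineConfig()
    endpoint_opts = OnlineEndpointConfig()
    po = ParseOptions("")
    feat_opts.register(po)
    endpoint_opts.register(po)
    po.read_config_file("online.conf")
    feat_info = OnlineNnetFeaturePipelineInfo.from_config(feat_opts)

    # Decoder and decodable options (illustrative values).
    decoder_opts = LatticeFasterDecoderOptions()
    decoder_opts.beam = 13
    decoder_opts.max_active = 7000
    decodable_opts = NnetSimpleLoopedComputationOptions()
    decodable_opts.acoustic_scale = 1.0
    decodable_opts.frame_subsampling_factor = 3
    decodable_opts.frames_per_chunk = 150

    asr = NnetLatticeFasterOnlineRecognizer.from_files(
        "final.mdl", "HCLG.fst", "words.txt",
        decoder_opts=decoder_opts,
        decodable_opts=decodable_opts,
        endpoint_opts=endpoint_opts)

    decode_chunked_partial_endpointing(asr, feat_info, decodable_opts,
                                       "scp:wav.scp")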