def load_model(config_file, online_config, models_path='models/', beam_size=10, frames_per_chunk=50):
    """Load a Kaldi nnet3 online recognizer described by a YAML model config.

    Args:
        config_file: Path to a YAML file with a ``decoder`` section naming the
            model files (``model``, ``fst``, ``word-syms``, ``mfcc-config``, ...).
        online_config: Path to the Kaldi online decoding config file. If it
            does not exist it is generated from the YAML settings.
        models_path: Directory prefix prepended to every file named in the YAML.
        beam_size: Decoder beam width (``LatticeFasterDecoderOptions.beam``).
        frames_per_chunk: Frames decoded per chunk for the looped computation.

    Returns:
        Tuple ``(asr, feat_info, decodable_opts)``: the recognizer, the online
        feature pipeline info, and the decodable options used.
    """
    # Read the YAML model description.
    with open(config_file, 'r') as stream:
        model_yaml = yaml.safe_load(stream)
    decoder_yaml_opts = model_yaml['decoder']
    print(decoder_yaml_opts)

    feat_opts = OnlineNnetFeaturePipelineConfig()
    endpoint_opts = OnlineEndpointConfig()

    if not os.path.isfile(online_config):
        # Generate a minimal online.conf from the YAML settings.
        # (Fixed message typo: "does not exists" -> "does not exist".)
        print(online_config + ' does not exist. Trying to create it from yaml file settings.')
        print('See also online_config_options.info.txt for what possible settings are.')
        with open(online_config, 'w') as online_config_file:
            online_config_file.write("--add_pitch=False\n")
            online_config_file.write("--mfcc_config=" + models_path + decoder_yaml_opts['mfcc-config'] + "\n")
            online_config_file.write("--feature_type=mfcc\n")
            online_config_file.write("--ivector_extraction_config=" + models_path + decoder_yaml_opts['ivector-extraction-config'] + '\n')
            online_config_file.write("--endpoint.silence-phones=" + decoder_yaml_opts['endpoint-silence-phones'] + '\n')
    else:
        print("Loading online conf from:", online_config)

    # Parse the online config into the feature / endpoint option structs.
    po = ParseOptions("")
    feat_opts.register(po)
    endpoint_opts.register(po)
    po.read_config_file(online_config)
    feat_info = OnlineNnetFeaturePipelineInfo.from_config(feat_opts)

    # Construct the recognizer.
    decoder_opts = LatticeFasterDecoderOptions()
    decoder_opts.beam = beam_size
    decoder_opts.max_active = 7000
    decodable_opts = NnetSimpleLoopedComputationOptions()
    decodable_opts.acoustic_scale = 1.0
    decodable_opts.frame_subsampling_factor = 3
    decodable_opts.frames_per_chunk = frames_per_chunk
    asr = NnetLatticeFasterOnlineRecognizer.from_files(
        models_path + decoder_yaml_opts["model"],
        models_path + decoder_yaml_opts["fst"],
        models_path + decoder_yaml_opts["word-syms"],
        decoder_opts=decoder_opts,
        decodable_opts=decodable_opts,
        endpoint_opts=endpoint_opts)
    return asr, feat_info, decodable_opts
def LoadModels(self):
    """Load the acoustic model, decoding graph and related resources.

    Reads ``online.conf`` from ``self.CONFIG_FILES_PATH``, builds the online
    feature pipeline, then loads the nnet3 model, HCLG graph, word symbols
    and word-boundary info, and assembles the online recognizer.

    Raises:
        ValueError: if any model file fails to load (details are logged).
    """
    try:
        # Option structs; all register against one shared option parser.
        parse_opts = ParseOptions("")
        lattice_opts = LatticeFasterDecoderOptions()
        self.endpoint_opts = OnlineEndpointConfig()
        self.decodable_opts = NnetSimpleLoopedComputationOptions()
        pipeline_cfg = OnlineNnetFeaturePipelineConfig()

        lattice_opts.register(parse_opts)
        self.endpoint_opts.register(parse_opts)
        self.decodable_opts.register(parse_opts)
        pipeline_cfg.register(parse_opts)
        parse_opts.read_config_file(self.CONFIG_FILES_PATH + "/online.conf")
        self.feat_info = OnlineNnetFeaturePipelineInfo.from_config(pipeline_cfg)

        # Metadata derived from the feature configuration.
        self.samp_freq = self.feat_info.mfcc_opts.frame_opts.samp_freq
        self.frame_shift = self.feat_info.mfcc_opts.frame_opts.frame_shift_ms / 1000
        self.acwt = self.decodable_opts.acoustic_scale

        # Acoustic model, decoding graph, symbol table, word boundaries.
        self.transition_model, self.acoustic_model = NnetRecognizer.read_model(self.AM_PATH + "/final.mdl")
        hclg = _fst.read_fst_kaldi(self.LM_PATH + "/HCLG.fst")
        self.decoder_graph = LatticeFasterOnlineDecoder(hclg, lattice_opts)
        self.symbols = _fst.SymbolTable.read_text(self.LM_PATH + "/words.txt")
        self.info = WordBoundaryInfo.from_file(WordBoundaryInfoNewOpts(), self.LM_PATH + "/word_boundary.int")
        self.asr = NnetLatticeFasterOnlineRecognizer(
            self.transition_model, self.acoustic_model, self.decoder_graph,
            self.symbols, decodable_opts=self.decodable_opts,
            endpoint_opts=self.endpoint_opts)
        del hclg, lattice_opts
    except Exception as e:
        self.log.error(e)
        raise ValueError("AM and LM loading failed!!! (see logs for more details)")
from kaldi.asr import NnetLatticeFasterOnlineRecognizer
from kaldi.decoder import LatticeFasterDecoderOptions
from kaldi.nnet3 import NnetSimpleLoopedComputationOptions
from kaldi.online2 import (OnlineEndpointConfig,
                           OnlineIvectorExtractorAdaptationState,
                           OnlineNnetFeaturePipelineConfig,
                           OnlineNnetFeaturePipelineInfo,
                           OnlineNnetFeaturePipeline,
                           OnlineSilenceWeighting)
from kaldi.util.options import ParseOptions
from kaldi.util.table import SequentialWaveReader

# Number of audio samples consumed per decoding step.
chunk_size = 1440

# Define online feature pipeline: options registered on a shared parser,
# then populated from the on-disk "online.conf" file.
feat_opts = OnlineNnetFeaturePipelineConfig()
endpoint_opts = OnlineEndpointConfig()
po = ParseOptions("")
feat_opts.register(po)
endpoint_opts.register(po)
po.read_config_file("online.conf")
feat_info = OnlineNnetFeaturePipelineInfo.from_config(feat_opts)

# Construct recognizer: decoder beam/active-state limits, acoustic scale
# and chunking for the looped nnet3 computation.
decoder_opts = LatticeFasterDecoderOptions()
decoder_opts.beam = 13
decoder_opts.max_active = 7000
decodable_opts = NnetSimpleLoopedComputationOptions()
decodable_opts.acoustic_scale = 1.0
decodable_opts.frame_subsampling_factor = 3
decodable_opts.frames_per_chunk = 150
# NOTE(review): the call below is truncated in this copy of the source —
# its argument list is missing. Recover the remaining arguments (model,
# graph, word-symbol paths and option kwargs) from version control.
asr = NnetLatticeFasterOnlineRecognizer.from_files(
class ASR:
    """Offline decoding wrapper around a PyKaldi nnet3 online recognizer.

    Loads an acoustic model (``final.mdl``), a decoding graph (``HCLG.fst``),
    word symbols and word-boundary info at construction time, and exposes
    feature extraction, decoding and word-timestamp extraction.

    Fix applied: the two ``ValueError`` messages were broken across physical
    lines in the source (a SyntaxError); they are rejoined here, matching the
    intact duplicate of ``LoadModels`` found elsewhere in the file.
    """

    def __init__(self, AM_PATH, LM_PATH, CONFIG_FILES_PATH):
        self.log = logging.getLogger('__stt-standelone-worker__.ASR')
        self.AM_PATH = AM_PATH                      # acoustic model directory
        self.LM_PATH = LM_PATH                      # language model / graph directory
        self.CONFIG_FILES_PATH = CONFIG_FILES_PATH  # directory holding online.conf
        self.LoadModels()

    def LoadModels(self):
        """Load all model files and build the recognizer.

        Raises:
            ValueError: if any resource fails to load (details are logged).
        """
        try:
            # Define online feature pipeline: all option structs register
            # against one parser, which is filled from online.conf.
            po = ParseOptions("")
            decoder_opts = LatticeFasterDecoderOptions()
            self.endpoint_opts = OnlineEndpointConfig()
            self.decodable_opts = NnetSimpleLoopedComputationOptions()
            feat_opts = OnlineNnetFeaturePipelineConfig()
            decoder_opts.register(po)
            self.endpoint_opts.register(po)
            self.decodable_opts.register(po)
            feat_opts.register(po)
            po.read_config_file(self.CONFIG_FILES_PATH + "/online.conf")
            self.feat_info = OnlineNnetFeaturePipelineInfo.from_config(feat_opts)

            # Metadata parameters derived from the feature configuration.
            self.samp_freq = self.feat_info.mfcc_opts.frame_opts.samp_freq
            self.frame_shift = self.feat_info.mfcc_opts.frame_opts.frame_shift_ms / 1000
            self.acwt = self.decodable_opts.acoustic_scale

            # Load acoustic and graph models and other files.
            self.transition_model, self.acoustic_model = NnetRecognizer.read_model(
                self.AM_PATH + "/final.mdl")
            graph = _fst.read_fst_kaldi(self.LM_PATH + "/HCLG.fst")
            self.decoder_graph = LatticeFasterOnlineDecoder(graph, decoder_opts)
            self.symbols = _fst.SymbolTable.read_text(self.LM_PATH + "/words.txt")
            self.info = WordBoundaryInfo.from_file(
                WordBoundaryInfoNewOpts(), self.LM_PATH + "/word_boundary.int")
            self.asr = NnetLatticeFasterOnlineRecognizer(
                self.transition_model, self.acoustic_model, self.decoder_graph,
                self.symbols, decodable_opts=self.decodable_opts,
                endpoint_opts=self.endpoint_opts)
            del graph, decoder_opts
        except Exception as e:
            self.log.error(e)
            raise ValueError("AM and LM loading failed!!! (see logs for more details)")

    def get_sample_rate(self):
        """Return the sample frequency the feature pipeline expects (Hz)."""
        return self.samp_freq

    def get_frames(self, feat_pipeline):
        """Return (mfcc_features, ivectors) matrices from a feature pipeline.

        Columns up to ``num_ceps`` are the MFCC features; the remainder are
        the appended i-vectors.
        """
        rows = feat_pipeline.num_frames_ready()
        cols = feat_pipeline.dim()
        frames = Matrix(rows, cols)
        feat_pipeline.get_frames(range(rows), frames)
        # feats + ivectors
        return (frames[:, :self.feat_info.mfcc_opts.num_ceps],
                frames[:, self.feat_info.mfcc_opts.num_ceps:])

    def compute_feat(self, wav):
        """Run the online feature pipeline over a full waveform.

        Args:
            wav: waveform samples at ``self.samp_freq``.

        Returns:
            The finished ``OnlineNnetFeaturePipeline``.

        Raises:
            ValueError: if feature extraction fails (details are logged).
        """
        try:
            feat_pipeline = OnlineNnetFeaturePipeline(self.feat_info)
            feat_pipeline.accept_waveform(self.samp_freq, wav)
            feat_pipeline.input_finished()
        except Exception as e:
            self.log.error(e)
            raise ValueError("Feature extraction failed!!!")
        else:
            return feat_pipeline

    def decoder(self, feats):
        """Decode a prepared feature pipeline and return the decode result.

        Raises:
            ValueError: if decoding fails (details are logged).
        """
        try:
            start_time = time.time()
            self.log.info("Start Decoding: %s" % (start_time))
            self.asr.set_input_pipeline(feats)
            decode = self.asr.decode()
            self.log.info("Decode time in seconds: %s" % (time.time() - start_time))
        except Exception as e:
            self.log.error(e)
            raise ValueError("Decoder failed to transcribe the input audio!!!")
        else:
            return decode

    def wordTimestamp(self, text, lattice, frame_shift, frame_subsampling):
        """Compute per-word start/end times from a decoding lattice.

        Scales the lattice by the acoustic weight, takes the shortest path,
        word-aligns it and converts frame indices to seconds using
        ``frame_shift * frame_subsampling``.

        Args:
            text: initial transcript string; decoded words are appended to it.
            lattice: compact lattice from decoding (modified in place by scaling).
            frame_shift: frame shift in seconds.
            frame_subsampling: the model's frame-subsampling factor.

        Returns:
            Dict with "words" (list of {word, start, end}) and "text".

        Raises:
            ValueError: if alignment fails (details are logged).
        """
        try:
            # Apply the acoustic scale, extract the best path, then undo the
            # scaling on that path before word alignment.
            _fst.utils.scale_compact_lattice(
                [[1.0, 0], [0, float(self.acwt)]], lattice)
            bestPath = compact_lattice_shortest_path(lattice)
            _fst.utils.scale_compact_lattice(
                [[1.0, 0], [0, 1.0 / float(self.acwt)]], bestPath)
            bestLattice = word_align_lattice(
                bestPath, self.transition_model, self.info, 0)
            alignment = compact_lattice_to_word_alignment(bestLattice[1])
            words = _fst.indices_to_symbols(self.symbols, alignment[0])
            start = alignment[1]
            dur = alignment[2]
            output = {}
            output["words"] = []
            for i in range(len(words)):
                meta = {}
                meta["word"] = words[i]
                meta["start"] = round(start[i] * frame_shift * frame_subsampling, 2)
                meta["end"] = round((start[i] + dur[i]) * frame_shift * frame_subsampling, 2)
                output["words"].append(meta)
                text += " " + meta["word"]
            output["text"] = text
        except Exception as e:
            self.log.error(e)
            raise ValueError("Decoder failed to create the word timestamps!!!")
        else:
            return output