async def handle_asr_train(
    self, train: AsrTrain, site_id: str = "default"
) -> typing.AsyncIterable[
    typing.Union[typing.Tuple[AsrTrainSuccess, TopicArgs], AsrError]
]:
    """Re-trains ASR system.

    Loads the gzipped networkx intent graph from ``train.graph_path``,
    converts it to JSON, and hands it to either a remote training URL
    (HTTP POST) or a local training command (JSON on stdin).

    Args:
        train: training request carrying the path to the gzipped
            intent-graph pickle and the request id.
        site_id: Hermes site id echoed back in the result topic args.

    Yields:
        (AsrTrainSuccess, topic args) on success, or AsrError on failure.
    """
    try:
        # Load gzipped graph pickle
        _LOGGER.debug("Loading %s", train.graph_path)
        with gzip.GzipFile(train.graph_path, mode="rb") as graph_gzip:
            intent_graph = nx.readwrite.gpickle.read_gpickle(graph_gzip)

        # Get JSON intent graph
        json_graph = rhasspynlu.graph_to_json(intent_graph)

        if self.asr_train_url:
            # Remote ASR server
            _LOGGER.debug(self.asr_train_url)

            async with self.http_session.post(
                self.asr_train_url, json=json_graph, ssl=self.ssl_context
            ) as response:
                # No data expected back
                response.raise_for_status()
        elif self.asr_train_command:
            # Local ASR training command
            _LOGGER.debug(self.asr_train_command)

            proc = await asyncio.create_subprocess_exec(
                *self.asr_train_command,
                stdin=asyncio.subprocess.PIPE,
                # BUG FIX: "sterr" was a typo (invalid keyword -> TypeError),
                # and stdout was never piped, so `output` below was always None.
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )

            output, error = await proc.communicate(json.dumps(json_graph).encode())

            if output:
                _LOGGER.debug(output.decode())

            if error:
                _LOGGER.debug(error.decode())
        else:
            _LOGGER.warning("Can't train ASR system. No train URL or command.")

        # Report success
        yield (AsrTrainSuccess(id=train.id), {"site_id": site_id})
    except Exception as e:
        _LOGGER.exception("handle_asr_train")
        yield AsrError(
            error=str(e),
            context=f"url='{self.asr_train_url}', command='{self.asr_train_command}'",
            site_id=site_id,
            session_id=train.id,
        )
async def async_test_train_success(self):
    """Check successful training."""
    # Build a training request pointing at a fake graph artifact
    request = AsrTrain(id=self.session_id, graph_path="fake.pickle.gz")

    # Feed the request through the handler, keeping only the last response
    last_response = None
    async for message in self.hermes.on_message(request, site_id=self.site_id):
        last_response = message

    expected = (AsrTrainSuccess(id=self.session_id), {"site_id": self.site_id})
    self.assertEqual(last_response, expected)
async def handle_train(
    self, train: AsrTrain, site_id: str = "default"
) -> typing.AsyncIterable[
    typing.Union[typing.Tuple[AsrTrainSuccess, TopicArgs], AsrError]
]:
    """Acknowledge a training request without doing any local work.

    Training is performed by Google's service, so there is nothing to do
    here — Rhasspy still requires a training handler, hence this no-op
    that immediately reports success.
    """
    success = AsrTrainSuccess(id=train.id)
    yield success, {"site_id": site_id}
async def handle_train(
    self, train: AsrTrain, site_id: str = "default"
) -> typing.AsyncIterable[
    typing.Union[typing.Tuple[AsrTrainSuccess, TopicArgs], AsrError]
]:
    """Re-trains ASR system.

    When overwriting is enabled and all model paths are configured, loads
    the gzipped intent graph and regenerates the DeepSpeech language
    model/trie; otherwise only logs a warning. Always clears the cached
    transcriber so the model reloads on next use.

    Args:
        train: training request carrying the gzipped intent-graph path.
        site_id: Hermes site id echoed back in the result topic args.

    Yields:
        (AsrTrainSuccess, topic args) on success, or AsrError on failure.
    """
    try:
        # DRY fix: this debug line was duplicated verbatim in both the
        # if- and else-branches; hoisted so it is stated once.
        _LOGGER.debug(
            "no overwrite: %s, model_path: %s, scorer_path: %s, alphabet: %s",
            self.no_overwrite_train,
            self.model_path,
            self.scorer_path,
            self.alphabet_path,
        )

        if (
            not self.no_overwrite_train
            and self.model_path
            and self.scorer_path
            and self.alphabet_path
        ):
            _LOGGER.debug("Loading %s", train.graph_path)
            with gzip.GzipFile(train.graph_path, mode="rb") as graph_gzip:
                graph = nx.readwrite.gpickle.read_gpickle(graph_gzip)

            # Generate language model/trie
            _LOGGER.debug("Starting training")
            rhasspyasr_deepspeech.train(
                graph,
                self.model_path,
                self.scorer_path,
                self.alphabet_path,
                base_language_model_fst=self.base_language_model_fst,
                base_language_model_weight=self.base_language_model_weight,
                mixed_language_model_fst=self.mixed_language_model_fst,
            )
        else:
            _LOGGER.warning("Not overwriting language model/trie")

        # Model will reload
        self.transcriber = None

        yield (AsrTrainSuccess(id=train.id), {"site_id": site_id})
    except Exception as e:
        _LOGGER.exception("handle_train")
        yield AsrError(
            error=str(e),
            context="handle_train",
            site_id=site_id,
            session_id=train.id,
        )
async def handle_train(
    self, train: AsrTrain, site_id: str = "default"
) -> typing.AsyncIterable[
    typing.Union[typing.Tuple[AsrTrainSuccess, TopicArgs], AsrError]
]:
    """Re-train the DeepSpeech language model/scorer from an intent graph.

    Skips regeneration when overwriting is disabled or any model path is
    missing; always invalidates cached transcribers so models reload on
    the next voice command.
    """
    try:
        # Training requires overwrite to be allowed and all paths configured
        can_train = (
            not self.no_overwrite_train
            and self.language_model_path
            and self.scorer_path
            and self.alphabet_path
        )

        if can_train:
            _LOGGER.debug("Loading %s", train.graph_path)
            with gzip.GzipFile(train.graph_path, mode="rb") as graph_gzip:
                intent_graph = nx.readwrite.gpickle.read_gpickle(graph_gzip)

            # Generate language model/scorer
            _LOGGER.debug("Starting training")
            rhasspyasr_deepspeech.train(
                graph=intent_graph,
                language_model=self.language_model_path,
                scorer_path=self.scorer_path,
                alphabet_path=self.alphabet_path,
                base_language_model_fst=self.base_language_model_fst,
                base_language_model_weight=self.base_language_model_weight,
                mixed_language_model_fst=self.mixed_language_model_fst,
            )
        else:
            _LOGGER.warning("Not overwriting language model/scorer")

        # Clear out existing transcribers so models can reload on next voice command
        self.free_transcribers = []
        for session_info in self.sessions.values():
            session_info.reuse = False

        yield AsrTrainSuccess(id=train.id), {"site_id": site_id}
    except Exception as e:
        _LOGGER.exception("handle_train")
        yield AsrError(
            error=str(e),
            context="handle_train",
            site_id=site_id,
            session_id=train.id,
        )
async def handle_train(
    self, train: AsrTrain, site_id: str = "default"
) -> typing.AsyncIterable[
    typing.Union[typing.Tuple[AsrTrainSuccess, TopicArgs], AsrError]
]:
    """Re-trains ASR system.

    Loads the base pronunciation dictionaries (re-reading only files
    whose mtime changed since the last load), optionally regenerates
    HCLG.fst from the gzipped intent graph, then prepares the Kaldi
    model for online decoding.

    Args:
        train: training request carrying the gzipped intent-graph path.
        site_id: Hermes site id echoed back in the result topic args.

    Yields:
        (AsrTrainSuccess, topic args) on success, or AsrError on failure.
    """
    try:
        # FIX: was `assert`, which is stripped under `python -O` and
        # would silently skip this validation. Same message preserved.
        if not (self.model_dir and self.graph_dir):
            raise ValueError("Model and graph dirs are required to train")

        # Load base dictionaries
        pronunciations: PronunciationsType = defaultdict(list)
        for base_dict in self.base_dictionaries:
            if not os.path.exists(base_dict.path):
                _LOGGER.warning("Base dictionary does not exist: %s", base_dict.path)
                continue

            # Re-load dictionary if modification time has changed
            dict_mtime_ns = os.stat(base_dict.path).st_mtime_ns
            if (base_dict.mtime_ns is None) or (base_dict.mtime_ns != dict_mtime_ns):
                base_dict.mtime_ns = dict_mtime_ns
                _LOGGER.debug("Loading base dictionary from %s", base_dict.path)
                with open(base_dict.path, "r") as base_dict_file:
                    rhasspynlu.g2p.read_pronunciations(
                        base_dict_file, word_dict=base_dict.pronunciations
                    )

            for word in base_dict.pronunciations:
                pronunciations[word].extend(base_dict.pronunciations[word])

        if not self.no_overwrite_train:
            _LOGGER.debug("Loading %s", train.graph_path)
            with gzip.GzipFile(train.graph_path, mode="rb") as graph_gzip:
                graph = nx.readwrite.gpickle.read_gpickle(graph_gzip)

            # Re-generate HCLG.fst
            _LOGGER.debug("Starting training")
            rhasspyasr_kaldi.train(
                graph,
                pronunciations,
                self.model_dir,
                self.graph_dir,
                dictionary=self.dictionary_path,
                language_model=self.language_model_path,
                dictionary_word_transform=self.dictionary_word_transform,
                g2p_model=self.g2p_model,
                g2p_word_transform=self.g2p_word_transform,
                missing_words_path=self.unknown_words,
                base_language_model_fst=self.base_language_model_fst,
                base_language_model_weight=self.base_language_model_weight,
                mixed_language_model_fst=self.mixed_language_model_fst,
            )
        else:
            _LOGGER.warning("Not overwriting HCLG.fst")

        kaldi_dir = rhasspyasr_kaldi.get_kaldi_dir()
        rhasspyasr_kaldi.train_prepare_online_decoding(
            self.model_dir, self.graph_dir, kaldi_dir
        )

        yield (AsrTrainSuccess(id=train.id), {"site_id": site_id})
    except Exception as e:
        # Consistency with the sibling train handlers: log under the
        # handler's full name and include a context string in the error.
        _LOGGER.exception("handle_train")
        yield AsrError(
            error=str(e),
            context="handle_train",
            site_id=site_id,
            session_id=train.id,
        )
async def handle_train(
    self, train: AsrTrain, site_id: str = "default"
) -> typing.AsyncIterable[
    typing.Union[typing.Tuple[AsrTrainSuccess, TopicArgs], AsrError]
]:
    """Re-trains ASR system"""
    try:
        if not self.base_dictionaries:
            _LOGGER.warning(
                "No base dictionaries provided. Training will likely fail."
            )

        # Merge pronunciations from every base dictionary
        pronunciations: PronunciationsType = defaultdict(list)
        for base_dict in self.base_dictionaries:
            if not os.path.exists(base_dict.path):
                _LOGGER.warning(
                    "Base dictionary does not exist: %s", base_dict.path
                )
                continue

            # Only re-read the file when its mtime differs from the cached one
            current_mtime_ns = os.stat(base_dict.path).st_mtime_ns
            if base_dict.mtime_ns is None or base_dict.mtime_ns != current_mtime_ns:
                base_dict.mtime_ns = current_mtime_ns
                _LOGGER.debug("Loading base dictionary from %s", base_dict.path)
                with open(base_dict.path, "r") as base_dict_file:
                    rhasspynlu.g2p.read_pronunciations(
                        base_dict_file, word_dict=base_dict.pronunciations
                    )

            for word, word_prons in base_dict.pronunciations.items():
                pronunciations[word].extend(word_prons)

        # Load intent graph
        _LOGGER.debug("Loading %s", train.graph_path)
        with gzip.GzipFile(train.graph_path, mode="rb") as graph_gzip:
            self.intent_graph = nx.readwrite.gpickle.read_gpickle(graph_gzip)

        # Clean LM cache completely
        for cached_lm_path in self.lm_cache_paths.values():
            try:
                cached_lm_path.unlink()
            except Exception:
                # Best effort: a missing cache file is not an error
                pass

        self.lm_cache_paths = {}
        self.lm_cache_transcribers = {}

        # Generate dictionary/language model
        if self.no_overwrite_train:
            _LOGGER.warning("Not overwriting dictionary/language model")
        else:
            _LOGGER.debug("Starting training")
            rhasspyasr_pocketsphinx.train(
                self.intent_graph,
                self.dictionary,
                self.language_model,
                pronunciations,
                dictionary_word_transform=self.dictionary_word_transform,
                g2p_model=self.g2p_model,
                g2p_word_transform=self.g2p_word_transform,
                missing_words_path=self.unknown_words,
                base_language_model_fst=self.base_language_model_fst,
                base_language_model_weight=self.base_language_model_weight,
                mixed_language_model_fst=self.mixed_language_model_fst,
            )

        _LOGGER.debug("Re-loading transcriber")
        self.transcriber = self.make_transcriber(self.language_model)

        yield AsrTrainSuccess(id=train.id), {"site_id": site_id}
    except Exception as e:
        _LOGGER.exception("handle_train")
        yield AsrError(
            error=str(e),
            context=repr(self.transcriber),
            site_id=site_id,
            session_id=train.id,
        )