예제 #1
0
    async def handle_asr_train(
        self, train: AsrTrain, site_id: str = "default"
    ) -> typing.AsyncIterable[
        typing.Union[typing.Tuple[AsrTrainSuccess, TopicArgs], AsrError]
    ]:
        """Re-trains ASR system.

        Loads the gzipped intent graph from ``train.graph_path``, converts it
        to JSON, then either POSTs it to a remote training URL or pipes it to
        a local training command on stdin.

        Yields:
            (AsrTrainSuccess, topic args) on success, or AsrError on failure.
        """
        try:
            # Load gzipped graph pickle
            _LOGGER.debug("Loading %s", train.graph_path)
            with gzip.GzipFile(train.graph_path, mode="rb") as graph_gzip:
                intent_graph = nx.readwrite.gpickle.read_gpickle(graph_gzip)

            # Get JSON intent graph
            json_graph = rhasspynlu.graph_to_json(intent_graph)

            if self.asr_train_url:
                # Remote ASR server
                _LOGGER.debug(self.asr_train_url)

                async with self.http_session.post(
                    self.asr_train_url, json=json_graph, ssl=self.ssl_context
                ) as response:
                    # No data expected back
                    response.raise_for_status()
            elif self.asr_train_command:
                # Local ASR training command
                _LOGGER.debug(self.asr_train_command)

                # FIX: "sterr" was a typo (TypeError at runtime) and stdout
                # was never piped, so communicate() could not capture output.
                proc = await asyncio.create_subprocess_exec(
                    *self.asr_train_command,
                    stdin=asyncio.subprocess.PIPE,
                    stdout=asyncio.subprocess.PIPE,
                    stderr=asyncio.subprocess.PIPE,
                )

                output, error = await proc.communicate(json.dumps(json_graph).encode())

                if output:
                    _LOGGER.debug(output.decode())

                if error:
                    _LOGGER.debug(error.decode())
            else:
                _LOGGER.warning("Can't train ASR system. No train URL or command.")

            # Report success
            yield (AsrTrainSuccess(id=train.id), {"site_id": site_id})
        except Exception as e:
            _LOGGER.exception("handle_asr_train")
            yield AsrError(
                error=str(e),
                context=f"url='{self.asr_train_url}', command='{self.asr_train_command}'",
                site_id=site_id,
                session_id=train.id,
            )
예제 #2
0
    async def async_test_train_success(self):
        """Verify a training request produces a success response."""
        train = AsrTrain(id=self.session_id, graph_path="fake.pickle.gz")

        # Drain the handler's async generator, keeping the last message.
        last_response = None
        async for message in self.hermes.on_message(train, site_id=self.site_id):
            last_response = message

        expected = (AsrTrainSuccess(id=self.session_id), {"site_id": self.site_id})
        self.assertEqual(last_response, expected)
예제 #3
0
    async def handle_train(
        self, train: AsrTrain, site_id: str = "default"
    ) -> typing.AsyncIterable[
        typing.Union[typing.Tuple[AsrTrainSuccess, TopicArgs], AsrError]
    ]:
        """Acknowledge a training request without doing any work.

        Actual training happens on Google's side, so nothing is trained here;
        Rhasspy still requires every ASR service to expose a train handler.
        """
        yield (AsrTrainSuccess(id=train.id), {"site_id": site_id})
예제 #4
0
    async def handle_train(
        self, train: AsrTrain, site_id: str = "default"
    ) -> typing.AsyncIterable[
        typing.Union[typing.Tuple[AsrTrainSuccess, TopicArgs], AsrError]
    ]:
        """Re-trains ASR system.

        Rebuilds the DeepSpeech language model/trie from the intent graph when
        training is enabled and all required paths are configured; otherwise
        leaves the existing model untouched. Either way, drops the cached
        transcriber so the (possibly new) model is loaded on next use.

        Yields:
            (AsrTrainSuccess, topic args) on success, or AsrError on failure.
        """
        try:
            # Log the configuration once up front (was duplicated verbatim in
            # both branches of the if/else below).
            _LOGGER.debug(
                "no overwrite: %s, model_path: %s, scorer_path: %s, alphabet: %s",
                self.no_overwrite_train,
                self.model_path,
                self.scorer_path,
                self.alphabet_path,
            )

            if (
                not self.no_overwrite_train
                and self.model_path
                and self.scorer_path
                and self.alphabet_path
            ):
                _LOGGER.debug("Loading %s", train.graph_path)
                with gzip.GzipFile(train.graph_path, mode="rb") as graph_gzip:
                    graph = nx.readwrite.gpickle.read_gpickle(graph_gzip)

                # Generate language model/trie
                _LOGGER.debug("Starting training")
                rhasspyasr_deepspeech.train(
                    graph,
                    self.model_path,
                    self.scorer_path,
                    self.alphabet_path,
                    base_language_model_fst=self.base_language_model_fst,
                    base_language_model_weight=self.base_language_model_weight,
                    mixed_language_model_fst=self.mixed_language_model_fst,
                )
            else:
                _LOGGER.warning("Not overwriting language model/trie")

            # Clear cached transcriber so the model will reload on next use
            self.transcriber = None

            yield (AsrTrainSuccess(id=train.id), {"site_id": site_id})
        except Exception as e:
            _LOGGER.exception("handle_train")
            yield AsrError(
                error=str(e),
                context="handle_train",
                site_id=site_id,
                session_id=train.id,
            )
    async def handle_train(
        self, train: AsrTrain, site_id: str = "default"
    ) -> typing.AsyncIterable[
        typing.Union[typing.Tuple[AsrTrainSuccess, TopicArgs], AsrError]
    ]:
        """Re-trains ASR system"""
        try:
            # Training only proceeds when it is enabled and every required
            # artifact path has been configured.
            should_train = (
                not self.no_overwrite_train
                and self.language_model_path
                and self.scorer_path
                and self.alphabet_path
            )

            if should_train:
                _LOGGER.debug("Loading %s", train.graph_path)
                with gzip.GzipFile(train.graph_path, mode="rb") as graph_file:
                    intent_graph = nx.readwrite.gpickle.read_gpickle(graph_file)

                # Generate language model/scorer
                _LOGGER.debug("Starting training")
                rhasspyasr_deepspeech.train(
                    graph=intent_graph,
                    language_model=self.language_model_path,
                    scorer_path=self.scorer_path,
                    alphabet_path=self.alphabet_path,
                    base_language_model_fst=self.base_language_model_fst,
                    base_language_model_weight=self.base_language_model_weight,
                    mixed_language_model_fst=self.mixed_language_model_fst,
                )
            else:
                _LOGGER.warning("Not overwriting language model/scorer")

            # Drop pooled transcribers and mark live sessions non-reusable so
            # fresh transcribers pick up the new model on the next command.
            self.free_transcribers = []
            for session_info in self.sessions.values():
                session_info.reuse = False

            yield (AsrTrainSuccess(id=train.id), {"site_id": site_id})
        except Exception as e:
            _LOGGER.exception("handle_train")
            yield AsrError(
                error=str(e),
                context="handle_train",
                site_id=site_id,
                session_id=train.id,
            )
예제 #6
0
    async def handle_train(
        self, train: AsrTrain, site_id: str = "default"
    ) -> typing.AsyncIterable[
        typing.Union[typing.Tuple[AsrTrainSuccess, TopicArgs], AsrError]
    ]:
        """Re-trains ASR system.

        Merges all base pronunciation dictionaries (re-reading any whose file
        changed on disk), then either regenerates the Kaldi HCLG.fst from the
        intent graph or, when overwriting is disabled, only re-prepares online
        decoding for the existing graph.

        Yields:
            (AsrTrainSuccess, topic args) on success, or AsrError on failure.
        """
        try:
            # FIX: was an `assert`, which is stripped under `python -O`;
            # an explicit raise is still caught below and reported as AsrError.
            if not (self.model_dir and self.graph_dir):
                raise ValueError("Model and graph dirs are required to train")

            # Load base dictionaries
            pronunciations: PronunciationsType = defaultdict(list)
            for base_dict in self.base_dictionaries:
                if not os.path.exists(base_dict.path):
                    _LOGGER.warning(
                        "Base dictionary does not exist: %s", base_dict.path
                    )
                    continue

                # Re-load dictionary if modification time has changed
                dict_mtime_ns = os.stat(base_dict.path).st_mtime_ns
                if (base_dict.mtime_ns is None) or (
                    base_dict.mtime_ns != dict_mtime_ns
                ):
                    base_dict.mtime_ns = dict_mtime_ns
                    _LOGGER.debug("Loading base dictionary from %s", base_dict.path)
                    with open(base_dict.path, "r") as base_dict_file:
                        rhasspynlu.g2p.read_pronunciations(
                            base_dict_file, word_dict=base_dict.pronunciations
                        )

                for word in base_dict.pronunciations:
                    pronunciations[word].extend(base_dict.pronunciations[word])

            if not self.no_overwrite_train:
                _LOGGER.debug("Loading %s", train.graph_path)
                with gzip.GzipFile(train.graph_path, mode="rb") as graph_gzip:
                    graph = nx.readwrite.gpickle.read_gpickle(graph_gzip)

                # Re-generate HCLG.fst
                _LOGGER.debug("Starting training")
                rhasspyasr_kaldi.train(
                    graph,
                    pronunciations,
                    self.model_dir,
                    self.graph_dir,
                    dictionary=self.dictionary_path,
                    language_model=self.language_model_path,
                    dictionary_word_transform=self.dictionary_word_transform,
                    g2p_model=self.g2p_model,
                    g2p_word_transform=self.g2p_word_transform,
                    missing_words_path=self.unknown_words,
                    base_language_model_fst=self.base_language_model_fst,
                    base_language_model_weight=self.base_language_model_weight,
                    mixed_language_model_fst=self.mixed_language_model_fst,
                )
            else:
                _LOGGER.warning("Not overwriting HCLG.fst")
                kaldi_dir = rhasspyasr_kaldi.get_kaldi_dir()
                rhasspyasr_kaldi.train_prepare_online_decoding(
                    self.model_dir, self.graph_dir, kaldi_dir
                )

            yield (AsrTrainSuccess(id=train.id), {"site_id": site_id})
        except Exception as e:
            _LOGGER.exception("train")
            yield AsrError(error=str(e), site_id=site_id, session_id=train.id)
    async def handle_train(
        self, train: AsrTrain, site_id: str = "default"
    ) -> typing.AsyncIterable[
        typing.Union[typing.Tuple[AsrTrainSuccess, TopicArgs], AsrError]
    ]:
        """Re-trains ASR system.

        Merges all base pronunciation dictionaries, loads the new intent
        graph, clears the language-model cache, optionally regenerates the
        Pocketsphinx dictionary/language model, and reloads the transcriber.

        Yields:
            (AsrTrainSuccess, topic args) on success, or AsrError on failure.
        """
        try:
            if not self.base_dictionaries:
                _LOGGER.warning(
                    "No base dictionaries provided. Training will likely fail."
                )

            # Load base dictionaries, merging all pronunciations per word
            pronunciations: PronunciationsType = defaultdict(list)
            for base_dict in self.base_dictionaries:
                if not os.path.exists(base_dict.path):
                    _LOGGER.warning(
                        "Base dictionary does not exist: %s", base_dict.path
                    )
                    continue

                # Re-load dictionary if modification time has changed
                dict_mtime_ns = os.stat(base_dict.path).st_mtime_ns
                if (base_dict.mtime_ns is None) or (
                    base_dict.mtime_ns != dict_mtime_ns
                ):
                    base_dict.mtime_ns = dict_mtime_ns
                    _LOGGER.debug("Loading base dictionary from %s", base_dict.path)
                    with open(base_dict.path, "r") as base_dict_file:
                        # Parsed entries accumulate in base_dict.pronunciations
                        rhasspynlu.g2p.read_pronunciations(
                            base_dict_file, word_dict=base_dict.pronunciations
                        )

                for word in base_dict.pronunciations:
                    pronunciations[word].extend(base_dict.pronunciations[word])

            # Load intent graph
            _LOGGER.debug("Loading %s", train.graph_path)
            with gzip.GzipFile(train.graph_path, mode="rb") as graph_gzip:
                self.intent_graph = nx.readwrite.gpickle.read_gpickle(graph_gzip)

            # Clean LM cache completely: stale cached models would shadow the
            # newly trained one. Deletion is best-effort by design.
            for lm_path in self.lm_cache_paths.values():
                try:
                    lm_path.unlink()
                except Exception:
                    pass

            self.lm_cache_paths = {}
            self.lm_cache_transcribers = {}

            # Generate dictionary/language model (skipped when overwriting is
            # disabled, e.g. for user-managed models)
            if not self.no_overwrite_train:
                _LOGGER.debug("Starting training")
                rhasspyasr_pocketsphinx.train(
                    self.intent_graph,
                    self.dictionary,
                    self.language_model,
                    pronunciations,
                    dictionary_word_transform=self.dictionary_word_transform,
                    g2p_model=self.g2p_model,
                    g2p_word_transform=self.g2p_word_transform,
                    missing_words_path=self.unknown_words,
                    base_language_model_fst=self.base_language_model_fst,
                    base_language_model_weight=self.base_language_model_weight,
                    mixed_language_model_fst=self.mixed_language_model_fst,
                )
            else:
                _LOGGER.warning("Not overwriting dictionary/language model")

            # Replace the live transcriber so the new model takes effect
            _LOGGER.debug("Re-loading transcriber")
            self.transcriber = self.make_transcriber(self.language_model)

            yield (AsrTrainSuccess(id=train.id), {"site_id": site_id})
        except Exception as e:
            _LOGGER.exception("handle_train")
            yield AsrError(
                error=str(e),
                context=repr(self.transcriber),
                site_id=site_id,
                session_id=train.id,
            )