Example #1
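
This example regenerates a Rhasspy profile's speech-to-text and intent-recognition artifacts: it reassembles large profile files that were split and gzipped, parses sentences.ini into an intent graph, and trains the configured acoustic model backend (Pocketsphinx, Kaldi, DeepSpeech, or Julius).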
import asyncio
import logging
import typing
from pathlib import Path

import pydash
import rhasspynlu
from rhasspynlu.jsgf import Expression, Word

# Helpers assumed from the surrounding Rhasspy module: utils_ppath,
# PronunciationAction, AcousticModelType, WordCasing, PronunciationsType,
# and load_pronunciations.

_LOGGER = logging.getLogger(__name__)


async def train_profile(profile_dir: Path,
                        profile: typing.Dict[str, typing.Any]) -> None:
    """Re-generate speech/intent artifacts for profile."""

    # Shorthand: resolve a profile setting to a path relative to profile_dir
    def ppath(query, default=None):
        return utils_ppath(profile, profile_dir, query, default)

    language_code = pydash.get(profile, "language.code", "en-US")

    sentences_ini = ppath("training.sentences-file", "sentences.ini")
    slots_dir = ppath("training.slots-directory", "slots")
    slot_programs = ppath("training.slot-programs-directory", "slot_programs")

    # Profile files that are split into parts and gzipped
    large_paths = [
        Path(p) for p in pydash.get(profile, "training.large-files", [])
    ]

    # -------------------
    # Speech to text
    # -------------------
    base_dictionary = ppath("training.base-dictionary", "base_dictionary.txt")
    custom_words = ppath("training.custom-words-file", "custom_words.txt")
    custom_words_action = PronunciationAction(
        pydash.get(profile, "training.custom-words-action", "append"))
    sounds_like = ppath("training.sounds-like-file", "sounds_like.txt")
    sounds_like_action = PronunciationAction(
        pydash.get(profile, "training.sounds-like-action", "append"))

    acoustic_model = ppath("training.acoustic-model", "acoustic_model")
    acoustic_model_type = AcousticModelType(
        pydash.get(profile, "training.acoustic-model-type",
                   AcousticModelType.DUMMY))

    # Replace numbers with words
    replace_numbers = bool(
        pydash.get(profile, "training.replace-numbers", True))

    # ignore/upper/lower
    word_casing = WordCasing(
        pydash.get(profile, "training.word-casing", WordCasing.IGNORE))

    # Large pre-built language model
    base_language_model_fst = ppath("training.base-language-model-fst",
                                    "base_language_model.fst")
    base_language_model_weight = float(
        pydash.get(profile, "training.base-language-model-weight", 0))

    # -------------------
    # Grapheme to phoneme
    # -------------------
    g2p_model = ppath("training.grapheme-to-phoneme-model", "g2p.fst")
    g2p_corpus = ppath("training.grapheme-to-phoneme-corpus", "g2p.corpus")

    # default/ignore/upper/lower
    g2p_word_casing = WordCasing(
        pydash.get(profile, "training.g2p-word-casing", word_casing))

    # -------
    # Outputs
    # -------
    dictionary_path = ppath("training.dictionary", "dictionary.txt")
    language_model_path = ppath("training.language-model",
                                "language_model.txt")
    language_model_fst_path = ppath("training.language-model-fst",
                                    "language_model.fst")
    mixed_language_model_fst_path = ppath("training.mixed-language-model-fst",
                                          "mixed_language_model.fst")
    intent_graph_path = ppath("training.intent-graph", "intent.pickle.gz")
    vocab_path = ppath("training.vocabulary-file", "vocab.txt")
    unknown_words_path = ppath("training.unknown-words-file",
                               "unknown_words.txt")

    async def run(command: typing.List[str], **kwargs):
        """Run a command asynchronously."""
        process = await asyncio.create_subprocess_exec(*command, **kwargs)
        await process.wait()
        assert process.returncode == 0, f"Command failed: {command}"

    # -------------------------------------------------------------------------
    # 1. Reassemble large files
    # -------------------------------------------------------------------------

    for target_path in large_paths:
        gzip_path = Path(str(target_path) + ".gz")
        part_paths = sorted(
            list(gzip_path.parent.glob(f"{gzip_path.name}.part-*")))
        if part_paths:
            # Concatenate parts together
            cat_command = ["cat"] + [str(p) for p in part_paths]
            _LOGGER.debug(cat_command)

            with open(gzip_path, "wb") as gzip_file:
                await run(cat_command, stdout=gzip_file)

        if gzip_path.is_file():
            # Unzip single file
            unzip_command = ["gunzip", "-f", "--stdout", str(gzip_path)]
            _LOGGER.debug(unzip_command)

            with open(target_path, "wb") as target_file:
                await run(unzip_command, stdout=target_file)

            # Delete zip file
            gzip_path.unlink()

        # Delete unneeded .gz-part files
        for part_path in part_paths:
            part_path.unlink()

    # -------------------------------------------------------------------------
    # 2. Generate intent graph
    # -------------------------------------------------------------------------

    # Parse JSGF sentences
    _LOGGER.debug("Parsing %s", sentences_ini)
    intents = rhasspynlu.parse_ini(sentences_ini)

    # Split into sentences and rule/slot replacements
    sentences, replacements = rhasspynlu.ini_jsgf.split_rules(intents)

    word_transform = None
    if word_casing == WordCasing.UPPER:
        word_transform = str.upper
    elif word_casing == WordCasing.LOWER:
        word_transform = str.lower

    word_visitor: typing.Optional[typing.Callable[[Expression], typing.Union[
        bool, Expression]]] = None

    if word_transform:
        # Apply transformation to words

        def transform_visitor(word: Expression):
            if isinstance(word, Word):
                assert word_transform
                new_text = word_transform(word.text)

                # Preserve case by using original text as substitution
                if (word.substitution is None) and (new_text != word.text):
                    word.substitution = word.text

                word.text = new_text

            return word

        word_visitor = transform_visitor

    # Apply case/number transforms
    if word_visitor or replace_numbers:
        for intent_sentences in sentences.values():
            for sentence in intent_sentences:
                if replace_numbers:
                    # Replace number ranges with slot references
                    rhasspynlu.jsgf.walk_expression(  # type: ignore
                        sentence, rhasspynlu.number_range_transform,
                        replacements)

                if word_visitor:
                    # Do case transformation
                    rhasspynlu.jsgf.walk_expression(  # type: ignore
                        sentence, word_visitor, replacements)

    # Load slot values
    slot_replacements = rhasspynlu.get_slot_replacements(
        intents,
        slots_dirs=[slots_dir],
        slot_programs_dirs=[slot_programs],
        slot_visitor=word_visitor,
    )

    # Merge with existing replacements
    for slot_key, slot_values in slot_replacements.items():
        replacements[slot_key] = slot_values

    if replace_numbers:
        # Do single number transformations
        for intent_sentences in sentences.values():
            for sentence in intent_sentences:
                rhasspynlu.jsgf.walk_expression(
                    sentence,
                    lambda w: rhasspynlu.number_transform(w, language_code),
                    replacements,
                )

    # Convert to directed graph
    intent_graph = rhasspynlu.sentences_to_graph(sentences,
                                                 replacements=replacements)

    # Convert to gzipped pickle
    intent_graph_path.parent.mkdir(exist_ok=True)
    with open(intent_graph_path, mode="wb") as intent_graph_file:
        rhasspynlu.graph_to_gzip_pickle(intent_graph, intent_graph_file)

    _LOGGER.debug("Wrote intent graph to %s", intent_graph_path)

    g2p_word_transform = None
    if g2p_word_casing == WordCasing.UPPER:
        g2p_word_transform = str.upper
    elif g2p_word_casing == WordCasing.LOWER:
        g2p_word_transform = str.lower

    # Load phonetic dictionaries
    pronunciations: PronunciationsType = {}
    if acoustic_model_type in [
            AcousticModelType.POCKETSPHINX,
            AcousticModelType.KALDI,
            AcousticModelType.JULIUS,
    ]:
        pronunciations, _ = load_pronunciations(
            base_dictionary=base_dictionary,
            custom_words=custom_words,
            custom_words_action=custom_words_action,
            sounds_like=sounds_like,
            sounds_like_action=sounds_like_action,
            g2p_corpus=g2p_corpus,
        )

    # -------------------------------------------------------------------------
    # Speech to Text Training
    # -------------------------------------------------------------------------

    if acoustic_model_type == AcousticModelType.POCKETSPHINX:
        # Pocketsphinx
        import rhasspyasr_pocketsphinx

        rhasspyasr_pocketsphinx.train(
            intent_graph,
            dictionary_path,
            language_model_path,
            pronunciations,
            dictionary_word_transform=word_transform,
            g2p_model=g2p_model,
            g2p_word_transform=g2p_word_transform,
            missing_words_path=unknown_words_path,
            vocab_path=vocab_path,
            language_model_fst=language_model_fst_path,
            base_language_model_fst=base_language_model_fst,
            base_language_model_weight=base_language_model_weight,
            mixed_language_model_fst=mixed_language_model_fst_path,
        )
    elif acoustic_model_type == AcousticModelType.KALDI:
        # Kaldi
        import rhasspyasr_kaldi
        from rhasspyasr_kaldi.train import LanguageModelType

        graph_dir = ppath("training.kaldi.graph-directory") or (
            acoustic_model / "graph")

        # Type of language model to generate
        language_model_type = LanguageModelType(
            pydash.get(profile, "training.kaldi.language-model-type", "arpa"))

        rhasspyasr_kaldi.train(
            intent_graph,
            pronunciations,
            acoustic_model,
            graph_dir,
            dictionary_path,
            language_model_path,
            language_model_type=language_model_type,
            dictionary_word_transform=word_transform,
            g2p_model=g2p_model,
            g2p_word_transform=g2p_word_transform,
            missing_words_path=unknown_words_path,
            vocab_path=vocab_path,
            language_model_fst=language_model_fst_path,
            base_language_model_fst=base_language_model_fst,
            base_language_model_weight=base_language_model_weight,
            mixed_language_model_fst=mixed_language_model_fst_path,
        )
    elif acoustic_model_type == AcousticModelType.DEEPSPEECH:
        # DeepSpeech
        import rhasspyasr_deepspeech

        trie_path = ppath("training.deepspeech.trie", "trie")
        alphabet_path = ppath("training.deepspeech.alphabet",
                              "model/alphabet.txt")

        rhasspyasr_deepspeech.train(
            intent_graph,
            language_model_path,
            trie_path,
            alphabet_path,
            vocab_path=vocab_path,
            language_model_fst=language_model_fst_path,
            base_language_model_fst=base_language_model_fst,
            base_language_model_weight=base_language_model_weight,
            mixed_language_model_fst=mixed_language_model_fst_path,
        )
    elif acoustic_model_type == AcousticModelType.JULIUS:
        # Julius
        from .julius import train as train_julius

        train_julius(
            intent_graph,
            dictionary_path,
            language_model_path,
            pronunciations,
            dictionary_word_transform=word_transform,
            silence_words={"<s>", "</s>"},
            g2p_model=g2p_model,
            g2p_word_transform=g2p_word_transform,
            missing_words_path=unknown_words_path,
            vocab_path=vocab_path,
            language_model_fst=language_model_fst_path,
            base_language_model_fst=base_language_model_fst,
            base_language_model_weight=base_language_model_weight,
            mixed_language_model_fst=mixed_language_model_fst_path,
        )
    else:
        _LOGGER.warning("Not training speech to text system (%s)",
                        acoustic_model_type)
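
A minimal usage sketch, not part of the original example: it assumes the profile's settings are stored as JSON in a "profile.json" file inside the profile directory (the file name, location, and loading code here are illustrative assumptions).

import asyncio
import json
from pathlib import Path

async def main() -> None:
    # Hypothetical profile location; adjust to your setup.
    profile_dir = Path.home() / ".config" / "rhasspy" / "profiles" / "en"

    # Load the profile settings dictionary (assumed to be plain JSON).
    with open(profile_dir / "profile.json", "r") as profile_file:
        profile = json.load(profile_file)

    await train_profile(profile_dir, profile)

asyncio.run(main())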

Example #2

This variant is a class method: it relies on attributes such as self.profile, self.site_id, and self.publish_wait, and on the ASR/NLU train and error message types (AsrTrain, NluTrain, and their success/error counterparts) imported in its enclosing module.

    async def train(self):
        """Regenerate the intent graph and send train requests to the ASR and NLU systems."""

        # Load sentences.ini files
        sentences_ini = self.profile.read_path(
            self.profile.get("speech_to_text.sentences_ini", "sentences.ini")
        )
        sentences_dir: typing.Optional[Path] = self.profile.read_path(
            self.profile.get("speech_to_text.sentences_dir", "intents")
        )

        assert sentences_dir is not None
        if not sentences_dir.is_dir():
            sentences_dir = None

        ini_paths = get_ini_paths(sentences_ini, sentences_dir)
        _LOGGER.debug("Loading sentences from %s", ini_paths)

        sentences_dict = {str(p): p.read_text() for p in ini_paths}

        # Load settings
        language = self.profile.get("language", "en")
        dictionary_casing = self.profile.get(
            "speech_to_text.dictionary_casing", "ignore"
        ).lower()
        word_transform = None
        if dictionary_casing == "upper":
            word_transform = str.upper
        elif dictionary_casing == "lower":
            word_transform = str.lower

        slots_dir = self.profile.write_path(
            self.profile.get("speech_to_text.slots_dir", "slots")
        )
        system_slots_dir = (
            self.profile.system_profiles_dir / self.profile.name / "slots"
        )
        slot_programs_dir = self.profile.write_path(
            self.profile.get("speech_to_text.slot_programs_dir", "slot_programs")
        )
        system_slot_programs_dir = (
            self.profile.system_profiles_dir / self.profile.name / "slot_programs"
        )

        # Convert to graph
        _LOGGER.debug("Generating intent graph")

        intent_graph, slot_replacements = sentences_to_graph(
            sentences_dict,
            slots_dirs=[slots_dir, system_slots_dir],
            slot_programs_dirs=[slot_programs_dir, system_slot_programs_dir],
            language=language,
            word_transform=word_transform,
        )

        # Convert to dict for train messages
        slots_dict = {
            slot_name: [value.text for value in values]
            for slot_name, values in slot_replacements.items()
        }

        # Convert to gzipped pickle
        graph_path = self.profile.write_path(
            self.profile.get("intent.fsticuffs.intent_json", "intent_graph.pickle.gz")
        )
        _LOGGER.debug("Writing %s", graph_path)
        with open(graph_path, mode="wb") as graph_file:
            rhasspynlu.graph_to_gzip_pickle(intent_graph, graph_file)

        _LOGGER.debug("Finished writing %s", graph_path)

        # Send to ASR/NLU systems
        has_speech = self.asr_system != "dummy"
        has_intent = self.nlu_system != "dummy"

        if has_speech or has_intent:
            request_id = str(uuid4())

            def handle_train():
                # A disabled ("dummy") system counts as already having responded
                asr_response = None if has_speech else True
                nlu_response = None if has_intent else True

                while True:
                    _, message = yield

                    if isinstance(message, NluTrainSuccess) and (
                        message.id == request_id
                    ):
                        nlu_response = message
                    elif isinstance(message, AsrTrainSuccess) and (
                        message.id == request_id
                    ):
                        asr_response = message

                    if isinstance(message, NluError) and (
                        message.session_id == request_id
                    ):
                        nlu_response = message
                    elif isinstance(message, AsrError) and (
                        message.session_id == request_id
                    ):
                        asr_response = message

                    if asr_response and nlu_response:
                        return [asr_response, nlu_response]

            messages: typing.List[
                typing.Tuple[
                    typing.Union[NluTrain, AsrTrain], typing.Dict[str, typing.Any]
                ],
            ] = []

            message_types: typing.List[typing.Type[Message]] = []

            if has_speech:
                # Request ASR training
                messages.append(
                    (
                        AsrTrain(
                            id=request_id,
                            graph_path=str(graph_path.absolute()),
                            sentences=sentences_dict,
                            slots=slots_dict,
                        ),
                        {"site_id": self.site_id},
                    )
                )
                message_types.extend([AsrTrainSuccess, AsrError])

            if has_intent:
                # Request NLU training
                messages.append(
                    (
                        NluTrain(
                            id=request_id,
                            graph_path=str(graph_path.absolute()),
                            sentences=sentences_dict,
                            slots=slots_dict,
                        ),
                        {"site_id": self.site_id},
                    )
                )
                message_types.extend([NluTrainSuccess, NluError])

            # Expecting only a single result
            result = None
            async for response in self.publish_wait(
                handle_train(),
                messages,
                message_types,
                timeout_seconds=self.training_timeout_seconds,
            ):
                result = response

            # Check result
            assert isinstance(result, list), f"Expected list, got {result}"
            asr_response, nlu_response = result

            if isinstance(asr_response, AsrError):
                _LOGGER.error(asr_response)
                raise TrainingFailedException(reason=asr_response.error)

            if isinstance(nlu_response, NluError):
                _LOGGER.error(nlu_response)
                raise TrainingFailedException(reason=nlu_response.error)

            return result

        return None
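
A short sketch of how a caller might drive this method and surface failures; "core" is a hypothetical instance of the enclosing class:

async def run_training(core) -> None:
    # Illustrative only: core.train() is the method defined above.
    try:
        results = await core.train()
        _LOGGER.debug("Training responses: %s", results)
    except TrainingFailedException:
        _LOGGER.exception("Training failed")
        raise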