Ejemplo n.º 1
0
    async def stop_listening(
        self, message: AsrStopListening
    ) -> typing.AsyncIterable[
        typing.Union[
            AsrTextCaptured,
            AsrError,
            typing.Tuple[AsrAudioCaptured, typing.Dict[str, typing.Any]],
        ]
    ]:
        """Close the session named in ``message`` and publish its results.

        The session is removed from ``self.sessions``. Its audio comes from
        the recorder when one is attached, otherwise from the session's
        audio buffer. If no transcription was sent for the session yet, the
        transcription is yielded, optionally followed by an
        ``(AsrAudioCaptured, topic args)`` tuple when the start message
        requested captured audio.

        Any exception is logged and reported as a yielded ``AsrError``.
        """
        try:
            closed = self.sessions.pop(message.session_id, None)
            if closed:
                # Pick the audio source for this session
                if closed.recorder:
                    captured = closed.recorder.stop()
                else:
                    assert closed.audio_buffer is not None
                    captured = closed.audio_buffer

                wav_bytes = self.to_wav_bytes(captured)

                _LOGGER.debug(
                    "Received a total of %s byte(s) for WAV data for session %s",
                    closed.num_wav_bytes,
                    message.session_id,
                )

                if not closed.transcription_sent:
                    # Mark first so the transcription is only sent once
                    closed.transcription_sent = True

                    transcription = await self.transcribe(
                        wav_bytes,
                        site_id=message.site_id,
                        session_id=message.session_id,
                        lang=closed.start_listening.lang,
                    )
                    yield transcription

                    if closed.start_listening.send_audio_captured:
                        # Caller asked for the raw audio as well
                        audio_message = AsrAudioCaptured(wav_bytes=wav_bytes)
                        topic_args = {
                            "site_id": message.site_id,
                            "session_id": message.session_id,
                        }
                        yield (audio_message, topic_args)

            _LOGGER.debug("Stopping listening (session_id=%s)", message.session_id)
        except Exception as err:
            _LOGGER.exception("stop_listening")
            yield AsrError(
                error=str(err),
                context="stop_listening",
                site_id=message.site_id,
                session_id=message.session_id,
            )
Ejemplo n.º 2
0
    async def handle_audio_frame(
        self,
        frame_wav_bytes: bytes,
        site_id: str = "default",
        session_id: typing.Optional[str] = None,
    ) -> typing.AsyncIterable[typing.Union[AsrTextCaptured, AsrError,
                                           typing.Tuple[AsrAudioCaptured,
                                                        TopicArgs]]]:
        """Process a single frame of WAV audio.

        The frame is routed to one session (when ``session_id`` is given) or
        to every open session (when it is ``None``). Sessions whose start
        message has a different ``site_id`` are skipped. The converted audio
        is pushed onto the session's frame queue. When the session has a
        recorder and ``process_chunk`` reports a command, everything produced
        by ``finish_session`` is yielded; without a recorder the audio is
        appended to the session's buffer.

        A frame addressed to a ``session_id`` that is not open (for example
        one that arrives after the session was stopped) is ignored instead of
        raising ``KeyError``.

        Per-session failures are logged and yielded as ``AsrError``.
        """

        # Don't process audio if no sessions
        if not self.sessions:
            return

        if session_id is None:
            # Add to every open session (copy: sessions may change while we yield)
            target_sessions = list(self.sessions.items())
        else:
            # Add to single session, if it is still open
            single_session = self.sessions.get(session_id)
            if single_session is None:
                _LOGGER.debug(
                    "Ignoring audio frame for unknown session %s", session_id
                )
                return

            target_sessions = [(session_id, single_session)]

        audio_data = self.maybe_convert_wav(frame_wav_bytes)

        # Add to every open session with matching site_id
        for target_id, info in target_sessions:
            try:
                assert info.start_listening is not None

                # Match site_id
                if info.start_listening.site_id != site_id:
                    continue

                # Push to transcription thread
                info.frame_queue.put(audio_data)

                if info.recorder is not None:
                    # Check for voice command end
                    command = info.recorder.process_chunk(audio_data)

                    if command:
                        # Trigger publishing of transcription on silence
                        async for result in self.finish_session(
                                info, site_id=site_id, session_id=target_id):
                            yield result
                else:
                    # Use session audio buffer
                    assert info.audio_buffer is not None
                    info.audio_buffer += audio_data
            except Exception as e:
                _LOGGER.exception("handle_audio_frame")
                yield AsrError(
                    error=str(e),
                    context=repr(info.transcriber),
                    site_id=site_id,
                    session_id=target_id,
                )
Ejemplo n.º 3
0
    async def handle_asr_train(
        self, train: AsrTrain, site_id: str = "default"
    ) -> typing.AsyncIterable[
        typing.Union[typing.Tuple[AsrTrainSuccess, TopicArgs], AsrError]
    ]:
        """Re-train the ASR system via a remote URL or a local command.

        The intent graph at ``train.graph_path`` is loaded and converted to
        JSON. If ``self.asr_train_url`` is set, the JSON is POSTed there and
        a non-success HTTP status raises. Otherwise, if
        ``self.asr_train_command`` is set, the command is run with the JSON
        on its standard input and its stdout/stderr are logged at debug
        level. With neither configured, only a warning is logged.

        Yields ``(AsrTrainSuccess, {"site_id": ...})`` on completion, or an
        ``AsrError`` if anything raised.
        """
        try:
            # Load gzipped graph pickle
            # NOTE(review): read_gpickle presumably unpickles the file, so
            # graph_path should only ever point at trusted data - confirm.
            _LOGGER.debug("Loading %s", train.graph_path)
            with gzip.GzipFile(train.graph_path, mode="rb") as graph_gzip:
                intent_graph = nx.readwrite.gpickle.read_gpickle(graph_gzip)

            # Get JSON intent graph
            json_graph = rhasspynlu.graph_to_json(intent_graph)

            if self.asr_train_url:
                # Remote ASR server
                _LOGGER.debug(self.asr_train_url)

                async with self.http_session.post(
                    self.asr_train_url, json=json_graph, ssl=self.ssl_context
                ) as response:
                    # No data expected back
                    response.raise_for_status()
            elif self.asr_train_command:
                # Local ASR training command
                _LOGGER.debug(self.asr_train_command)

                # stdout and stderr must both be piped for communicate() to
                # return them; the keyword is "stderr" (a misspelling is
                # forwarded to Popen and raises TypeError).
                proc = await asyncio.create_subprocess_exec(
                    *self.asr_train_command,
                    stdin=asyncio.subprocess.PIPE,
                    stdout=asyncio.subprocess.PIPE,
                    stderr=asyncio.subprocess.PIPE,
                )

                output, error = await proc.communicate(json.dumps(json_graph).encode())

                if output:
                    _LOGGER.debug(output.decode())

                if error:
                    _LOGGER.debug(error.decode())
            else:
                _LOGGER.warning("Can't train ASR system. No train URL or command.")

            # Report success
            yield (AsrTrainSuccess(id=train.id), {"site_id": site_id})
        except Exception as e:
            _LOGGER.exception("handle_asr_train")
            yield AsrError(
                error=str(e),
                context=f"url='{self.asr_train_url}', command='{self.asr_train_command}'",
                site_id=site_id,
                session_id=train.id,
            )
Ejemplo n.º 4
0
    async def handle_train(
        self, train: AsrTrain, site_id: str = "default"
    ) -> typing.AsyncIterable[
        typing.Union[typing.Tuple[AsrTrainSuccess, TopicArgs], AsrError]
    ]:
        """Re-train the ASR system from the graph in ``train.graph_path``.

        Training only runs when overwriting is allowed and the model, scorer
        and alphabet paths are all set; otherwise a warning is logged. In
        both cases ``self.transcriber`` is reset to ``None`` and an
        ``(AsrTrainSuccess, {"site_id": ...})`` tuple is yielded. Exceptions
        are logged and yielded as ``AsrError``.
        """
        try:
            can_train = (
                not self.no_overwrite_train
                and self.model_path
                and self.scorer_path
                and self.alphabet_path
            )

            # Same settings dump regardless of which branch is taken
            _LOGGER.debug("no overwrite: %s, model_path: %s, scorer_path: %s, alphabet: %s", self.no_overwrite_train, self.model_path, self.scorer_path, self.alphabet_path)

            if can_train:
                _LOGGER.debug("Loading %s", train.graph_path)
                with gzip.GzipFile(train.graph_path, mode="rb") as graph_gzip:
                    intent_graph = nx.readwrite.gpickle.read_gpickle(graph_gzip)

                # Generate language model/trie
                _LOGGER.debug("Starting training")
                rhasspyasr_deepspeech.train(
                    intent_graph,
                    self.model_path,
                    self.scorer_path,
                    self.alphabet_path,
                    base_language_model_fst=self.base_language_model_fst,
                    base_language_model_weight=self.base_language_model_weight,
                    mixed_language_model_fst=self.mixed_language_model_fst,
                )
            else:
                _LOGGER.warning("Not overwriting language model/trie")

            # Dropping the transcriber lets the model be reloaded later
            self.transcriber = None

            success = AsrTrainSuccess(id=train.id)
            yield (success, {"site_id": site_id})
        except Exception as err:
            _LOGGER.exception("handle_train")
            yield AsrError(
                error=str(err),
                context="handle_train",
                site_id=site_id,
                session_id=train.id,
            )
    async def handle_train(
        self, train: AsrTrain, site_id: str = "default"
    ) -> typing.AsyncIterable[
        typing.Union[typing.Tuple[AsrTrainSuccess, TopicArgs], AsrError]
    ]:
        """Re-train the ASR system from the graph in ``train.graph_path``.

        Training only runs when overwriting is allowed and the language
        model, scorer and alphabet paths are all set; otherwise a warning is
        logged. Afterwards the pool of free transcribers is emptied and every
        open session is marked as not reusable, then an
        ``(AsrTrainSuccess, {"site_id": ...})`` tuple is yielded. Exceptions
        are logged and yielded as ``AsrError``.
        """
        try:
            can_train = (
                not self.no_overwrite_train
                and self.language_model_path
                and self.scorer_path
                and self.alphabet_path
            )

            if can_train:
                _LOGGER.debug("Loading %s", train.graph_path)
                with gzip.GzipFile(train.graph_path, mode="rb") as graph_gzip:
                    intent_graph = nx.readwrite.gpickle.read_gpickle(graph_gzip)

                # Generate language model/scorer
                _LOGGER.debug("Starting training")
                rhasspyasr_deepspeech.train(
                    graph=intent_graph,
                    language_model=self.language_model_path,
                    scorer_path=self.scorer_path,
                    alphabet_path=self.alphabet_path,
                    base_language_model_fst=self.base_language_model_fst,
                    base_language_model_weight=self.base_language_model_weight,
                    mixed_language_model_fst=self.mixed_language_model_fst,
                )
            else:
                _LOGGER.warning("Not overwriting language model/scorer")

            # Forget idle transcribers and stop in-flight sessions from being
            # pooled again, so models can reload on the next voice command
            self.free_transcribers = []
            for open_session in self.sessions.values():
                open_session.reuse = False

            success = AsrTrainSuccess(id=train.id)
            yield (success, {"site_id": site_id})
        except Exception as err:
            _LOGGER.exception("handle_train")
            yield AsrError(
                error=str(err),
                context="handle_train",
                site_id=site_id,
                session_id=train.id,
            )
Ejemplo n.º 6
0
    async def handle_start_listening(
        self, start_listening: AsrStartListening
    ) -> typing.AsyncIterable[AsrError]:
        """Open an ASR session for ``start_listening`` and start its recorder.

        The new session is stored in ``self.asr_sessions`` under the
        message's session id. Nothing is yielded on success; a failure is
        logged and yielded as an ``AsrError``.
        """
        _LOGGER.debug("<- %s", start_listening)

        try:
            recorder = self.make_recorder()
            new_session = AsrSession(
                start_listening=start_listening, recorder=recorder
            )

            # Register before starting, matching the lookup key used on stop
            self.asr_sessions[start_listening.session_id] = new_session
            new_session.recorder.start()
        except Exception as err:
            _LOGGER.exception("handle_start_listening")
            yield AsrError(
                error=str(err),
                context="",
                site_id=start_listening.site_id,
                session_id=start_listening.session_id,
            )
    async def stop_listening(
        self, message: AsrStopListening
    ) -> typing.AsyncIterable[StopListeningType]:
        """Stop recording audio data for a session.

        The session is removed from ``self.sessions``. If it existed,
        everything produced by ``finish_session`` is yielded. When the
        session is marked reusable and still has a transcriber, its
        per-session state is reset, leftover frames are drained from its
        queue, and it is appended to ``self.free_transcribers``.

        Failures are logged and yielded as ``AsrError``.
        """
        # Local import keeps this block self-contained
        import queue

        info = self.sessions.pop(message.session_id, None)
        if info:
            try:
                # Trigger publishing of transcription on end of session
                async for result in self.finish_session(
                    info, message.site_id, message.session_id
                ):
                    yield result

                if info.reuse and (info.transcriber is not None):
                    # Reset state
                    info.result = None
                    info.result_event.clear()
                    info.result_sent = False
                    info.start_listening = None
                    info.audio_buffer = None

                    # Drop leftover frames. Another consumer may take a frame
                    # between the size check and the get, so an empty queue
                    # here is not an error.
                    while info.frame_queue.qsize() > 0:
                        try:
                            info.frame_queue.get_nowait()
                        except queue.Empty:
                            break

                    # Add to free pool
                    self.free_transcribers.append(info)
            except Exception as e:
                _LOGGER.exception("stop_listening")
                yield AsrError(
                    error=str(e),
                    context=repr(info.transcriber),
                    site_id=message.site_id,
                    session_id=message.session_id,
                )

        _LOGGER.debug("Stopping listening (session_id=%s)", message.session_id)
Ejemplo n.º 8
0
    async def handle_audio_frame(
        self,
        frame_wav_bytes: bytes,
        site_id: str = "default",
        session_id: typing.Optional[str] = None,
    ) -> typing.AsyncIterable[typing.Union[
            AsrTextCaptured, AsrError, typing.Tuple[
                AsrAudioCaptured, typing.Dict[str, typing.Any]], ]]:
        """Process a single frame of WAV audio.

        The frame is routed to one session (when ``session_id`` is given) or
        to every open session (when it is ``None``). Sessions whose start
        message has a different ``site_id`` are skipped. For a session with a
        recorder, a successfully recorded voice command is transcribed and
        yielded, optionally followed by an ``(AsrAudioCaptured, topic args)``
        tuple, and the recorder is restarted so the session stays open.
        Without a recorder the audio is appended to the session's buffer.

        A frame addressed to a ``session_id`` that is not open (for example
        one that arrives after the session was stopped) is ignored instead of
        raising ``KeyError``.

        Per-session failures are logged and yielded as ``AsrError``.
        """
        # Don't process audio if no sessions
        if not self.sessions:
            return

        if session_id is None:
            # Add to every open session
            target_sessions = list(self.sessions.items())
        else:
            # Add to single session, if it is still open
            single_session = self.sessions.get(session_id)
            if single_session is None:
                _LOGGER.debug(
                    "Ignoring audio frame for unknown session %s", session_id
                )
                return

            target_sessions = [(session_id, single_session)]

        audio_data = self.maybe_convert_wav(frame_wav_bytes)

        # Add audio to session(s)
        for target_id, session in target_sessions:
            try:
                # Skip if site_id doesn't match
                if session.start_listening.site_id != site_id:
                    continue

                session.num_wav_bytes += len(frame_wav_bytes)
                if session.recorder:
                    # Check for end of voice command
                    command = session.recorder.process_chunk(audio_data)
                    if command and (command.result
                                    == VoiceCommandResult.SUCCESS):
                        assert command.audio_data is not None
                        _LOGGER.debug(
                            "Voice command recorded for session %s (%s byte(s))",
                            target_id,
                            len(command.audio_data),
                        )

                        # Set before transcribing so a later stop does not
                        # send a second transcription for this session
                        session.transcription_sent = True
                        wav_bytes = self.to_wav_bytes(command.audio_data)

                        yield (await self.transcribe(
                            wav_bytes,
                            site_id=site_id,
                            session_id=target_id,
                            wakeword_id=session.start_listening.wakeword_id))

                        if session.start_listening.send_audio_captured:
                            # Send audio data
                            yield (
                                AsrAudioCaptured(wav_bytes=wav_bytes),
                                {
                                    "site_id": site_id,
                                    "session_id": target_id
                                },
                            )

                        # Reset session (but keep open)
                        session.recorder.stop()
                        session.recorder.start()
                else:
                    # Add to buffer
                    assert session.audio_buffer is not None
                    session.audio_buffer += audio_data
            except Exception as e:
                _LOGGER.exception("handle_audio_frame")
                yield AsrError(
                    error=str(e),
                    context=repr(self.transcriber),
                    site_id=site_id,
                    session_id=target_id,
                )
Ejemplo n.º 9
0
    async def handle_train(
        self, train: AsrTrain, site_id: str = "default"
    ) -> typing.AsyncIterable[
        typing.Union[typing.Tuple[AsrTrainSuccess, TopicArgs], AsrError]
    ]:
        """Re-train the ASR system from the graph in ``train.graph_path``.

        Base dictionaries are (re-)read when their modification time changed
        and merged into one pronunciation table. Unless overwriting is
        disabled, the graph is loaded and ``rhasspyasr_kaldi.train`` is run;
        otherwise only online decoding is prepared. Yields
        ``(AsrTrainSuccess, {"site_id": ...})`` on completion, or an
        ``AsrError`` if anything raised.
        """
        try:
            assert (
                self.model_dir and self.graph_dir
            ), "Model and graph dirs are required to train"

            # Merge pronunciations from all base dictionaries
            pronunciations: PronunciationsType = defaultdict(list)
            for base_dict in self.base_dictionaries:
                dict_path = base_dict.path
                if not os.path.exists(dict_path):
                    _LOGGER.warning("Base dictionary does not exist: %s", dict_path)
                    continue

                # Only re-read the file when it changed since the last load
                current_mtime_ns = os.stat(dict_path).st_mtime_ns
                is_stale = (base_dict.mtime_ns is None) or (
                    base_dict.mtime_ns != current_mtime_ns
                )
                if is_stale:
                    base_dict.mtime_ns = current_mtime_ns
                    _LOGGER.debug("Loading base dictionary from %s", dict_path)
                    with open(dict_path, "r") as dict_file:
                        rhasspynlu.g2p.read_pronunciations(
                            dict_file, word_dict=base_dict.pronunciations
                        )

                for word, word_prons in base_dict.pronunciations.items():
                    pronunciations[word].extend(word_prons)

            if self.no_overwrite_train:
                _LOGGER.warning("Not overwriting HCLG.fst")
                kaldi_dir = rhasspyasr_kaldi.get_kaldi_dir()
                rhasspyasr_kaldi.train_prepare_online_decoding(
                    self.model_dir, self.graph_dir, kaldi_dir
                )
            else:
                _LOGGER.debug("Loading %s", train.graph_path)
                with gzip.GzipFile(train.graph_path, mode="rb") as graph_gzip:
                    intent_graph = nx.readwrite.gpickle.read_gpickle(graph_gzip)

                # Re-generate HCLG.fst
                _LOGGER.debug("Starting training")
                rhasspyasr_kaldi.train(
                    intent_graph,
                    pronunciations,
                    self.model_dir,
                    self.graph_dir,
                    dictionary=self.dictionary_path,
                    language_model=self.language_model_path,
                    dictionary_word_transform=self.dictionary_word_transform,
                    g2p_model=self.g2p_model,
                    g2p_word_transform=self.g2p_word_transform,
                    missing_words_path=self.unknown_words,
                    base_language_model_fst=self.base_language_model_fst,
                    base_language_model_weight=self.base_language_model_weight,
                    mixed_language_model_fst=self.mixed_language_model_fst,
                )

            success = AsrTrainSuccess(id=train.id)
            yield (success, {"site_id": site_id})
        except Exception as err:
            _LOGGER.exception("train")
            yield AsrError(error=str(err), site_id=site_id, session_id=train.id)
Ejemplo n.º 10
0
    async def start_listening(
        self, message: AsrStartListening
    ) -> typing.AsyncIterable[typing.Union[StopListeningType, AsrError]]:
        """Start recording audio data for a session.

        If a session with the same id is already open it is stopped first and
        the messages produced by ``stop_listening`` are yielded. A transcriber
        is then taken from ``self.free_transcribers`` or, if none is idle, a
        new ``TranscriberInfo`` is created together with a daemon thread that
        runs the transcription loop. Finally the session is configured
        (recorder for silence detection, or a plain byte buffer) and stored in
        ``self.sessions``.

        Any exception is logged and yielded as an ``AsrError``.
        """
        try:
            if message.session_id in self.sessions:
                # Stop existing session
                async for stop_message in self.stop_listening(
                        AsrStopListening(site_id=message.site_id,
                                         session_id=message.session_id)):
                    yield stop_message

            if self.free_transcribers:
                # Re-use existing transcriber
                info = self.free_transcribers.pop()

                _LOGGER.debug("Re-using existing transcriber (session_id=%s)",
                              message.session_id)
            else:
                # Create new transcriber
                info = TranscriberInfo(reuse=self.reuse_transcribers)
                _LOGGER.debug("Creating new transcriber session %s",
                              message.session_id)

                # Body of the worker thread: owns the transcriber and serves
                # one transcription per ready_event signal.
                def transcribe_proc(info, transcriber_factory, sample_rate,
                                    sample_width, channels):
                    # Generator that feeds queued frames to the transcriber
                    # until a falsy item (e.g. None or empty bytes) is pulled.
                    def audio_stream(frame_queue) -> typing.Iterable[bytes]:
                        # Pull frames from the queue
                        frames = frame_queue.get()
                        while frames:
                            yield frames
                            frames = frame_queue.get()

                    try:
                        info.transcriber = transcriber_factory(
                            port_num=self.kaldi_port)

                        assert (info.transcriber
                                is not None), "Failed to create transcriber"

                        while True:
                            # Wait for session to start
                            info.ready_event.wait()
                            info.ready_event.clear()

                            # Get result of transcription
                            result = info.transcriber.transcribe_stream(
                                audio_stream(info.frame_queue),
                                sample_rate,
                                sample_width,
                                channels,
                            )

                            _LOGGER.debug("Transcription result: %s", result)

                            # A missing or empty transcription is treated as a
                            # failure and handled by the except branch below.
                            assert (result is not None
                                    and result.text), "Null transcription"

                            # Signal completion
                            info.result = result
                            info.result_event.set()

                            # Without re-use, the thread serves exactly one
                            # transcription and then exits.
                            if not self.reuse_transcribers:
                                try:
                                    info.transcriber.stop()
                                except Exception:
                                    _LOGGER.exception("Transcriber stop")

                                break
                    except Exception:
                        _LOGGER.exception("session proc")

                        # Mark as not reusable
                        info.reuse = False

                        # Stop transcriber
                        if info.transcriber is not None:
                            try:
                                info.transcriber.stop()
                            except Exception:
                                _LOGGER.exception("Transcriber stop")

                        # Signal failure
                        # An empty Transcription is published and result_event
                        # is set so anything waiting on the event is released.
                        info.transcriber = None
                        info.result = Transcription(text="",
                                                    likelihood=0,
                                                    transcribe_seconds=0,
                                                    wav_seconds=0)
                        info.result_event.set()

                # Run in separate thread
                info.thread = threading.Thread(
                    target=transcribe_proc,
                    args=(
                        info,
                        self.transcriber_factory,
                        self.sample_rate,
                        self.sample_width,
                        self.channels,
                    ),
                    daemon=True,
                )

                info.thread.start()

            # ---------------------------------------------------------------------

            # Settings for session
            info.start_listening = message

            # Signal session thread to start
            info.ready_event.set()

            if message.stop_on_silence:
                # Begin silence detection
                # The recorder is created lazily and kept on the info object
                if info.recorder is None:
                    info.recorder = self.recorder_factory()

                info.recorder.start()
            else:
                # Use internal buffer (no silence detection)
                info.audio_buffer = bytes()

            # Register the session under its id
            self.sessions[message.session_id] = info
            _LOGGER.debug("Starting listening (session_id=%s)",
                          message.session_id)
            self.first_audio = True
        except Exception as e:
            _LOGGER.exception("start_listening")
            yield AsrError(
                error=str(e),
                context=repr(message),
                site_id=message.site_id,
                session_id=message.session_id,
            )
    async def handle_train(
        self, train: AsrTrain, site_id: str = "default"
    ) -> typing.AsyncIterable[
        typing.Union[typing.Tuple[AsrTrainSuccess, TopicArgs], AsrError]
    ]:
        """Re-train the ASR system from the graph in ``train.graph_path``.

        Base dictionaries are (re-)read when their modification time changed
        and merged into one pronunciation table. The intent graph is loaded
        into ``self.intent_graph``, the language model cache is cleared
        (files removed on a best-effort basis), and, unless overwriting is
        disabled, ``rhasspyasr_pocketsphinx.train`` is run. The transcriber
        is then rebuilt from ``self.language_model``.

        Yields ``(AsrTrainSuccess, {"site_id": ...})`` on completion, or an
        ``AsrError`` if anything raised.
        """
        try:
            if not self.base_dictionaries:
                _LOGGER.warning(
                    "No base dictionaries provided. Training will likely fail."
                )

            # Load base dictionaries
            pronunciations: PronunciationsType = defaultdict(list)
            for base_dict in self.base_dictionaries:
                if not os.path.exists(base_dict.path):
                    _LOGGER.warning(
                        "Base dictionary does not exist: %s", base_dict.path
                    )
                    continue

                # Re-load dictionary if modification time has changed
                dict_mtime_ns = os.stat(base_dict.path).st_mtime_ns
                if (base_dict.mtime_ns is None) or (
                    base_dict.mtime_ns != dict_mtime_ns
                ):
                    base_dict.mtime_ns = dict_mtime_ns
                    _LOGGER.debug("Loading base dictionary from %s", base_dict.path)
                    with open(base_dict.path, "r") as base_dict_file:
                        rhasspynlu.g2p.read_pronunciations(
                            base_dict_file, word_dict=base_dict.pronunciations
                        )

                for word in base_dict.pronunciations:
                    pronunciations[word].extend(base_dict.pronunciations[word])

            # Load intent graph
            _LOGGER.debug("Loading %s", train.graph_path)
            with gzip.GzipFile(train.graph_path, mode="rb") as graph_gzip:
                self.intent_graph = nx.readwrite.gpickle.read_gpickle(graph_gzip)

            # Clean LM cache completely
            for lm_path in self.lm_cache_paths.values():
                try:
                    lm_path.unlink()
                except Exception:
                    # Best effort: a leftover cache file must not abort
                    # training, but leave a trace so failures can be diagnosed.
                    _LOGGER.debug(
                        "Could not remove cached language model %s",
                        lm_path,
                        exc_info=True,
                    )

            self.lm_cache_paths = {}
            self.lm_cache_transcribers = {}

            # Generate dictionary/language model
            if not self.no_overwrite_train:
                _LOGGER.debug("Starting training")
                rhasspyasr_pocketsphinx.train(
                    self.intent_graph,
                    self.dictionary,
                    self.language_model,
                    pronunciations,
                    dictionary_word_transform=self.dictionary_word_transform,
                    g2p_model=self.g2p_model,
                    g2p_word_transform=self.g2p_word_transform,
                    missing_words_path=self.unknown_words,
                    base_language_model_fst=self.base_language_model_fst,
                    base_language_model_weight=self.base_language_model_weight,
                    mixed_language_model_fst=self.mixed_language_model_fst,
                )
            else:
                _LOGGER.warning("Not overwriting dictionary/language model")

            _LOGGER.debug("Re-loading transcriber")
            self.transcriber = self.make_transcriber(self.language_model)

            yield (AsrTrainSuccess(id=train.id), {"site_id": site_id})
        except Exception as e:
            _LOGGER.exception("handle_train")
            yield AsrError(
                error=str(e),
                context=repr(self.transcriber),
                site_id=site_id,
                session_id=train.id,
            )
Ejemplo n.º 12
0
    async def handle_stop_listening(
        self, stop_listening: AsrStopListening
    ) -> typing.AsyncIterable[
        typing.Union[
            AsrTextCaptured, typing.Tuple[AsrAudioCaptured, TopicArgs], AsrError
        ]
    ]:
        """Close an ASR session and publish its transcription.

        The session is removed from ``self.asr_sessions``; an unknown session
        id only logs a warning. The session audio (recorded voice command
        when ``stop_on_silence`` was requested, otherwise all audio) is
        converted to WAV and transcribed by POSTing to ``self.asr_url`` or by
        piping it to ``self.asr_command``. With neither configured an empty
        transcription is used.

        Yields an ``AsrTextCaptured``, then optionally an
        ``(AsrAudioCaptured, topic args)`` tuple, or an ``AsrError`` if
        anything raised.
        """
        _LOGGER.debug("<- %s", stop_listening)

        try:
            session = self.asr_sessions.pop(stop_listening.session_id, None)
            if session is None:
                _LOGGER.warning("Session not found for %s", stop_listening.session_id)
                return

            assert session.sample_rate is not None, "No sample rate"
            assert session.sample_width is not None, "No sample width"
            assert session.channels is not None, "No channels"

            # Recorded voice command when silence detection was on,
            # otherwise everything that was received
            audio_data = (
                session.recorder.stop()
                if session.start_listening.stop_on_silence
                else session.audio_data
            )

            # Process entire WAV file
            wav_bytes = self.to_wav_bytes(
                audio_data, session.sample_rate, session.sample_width, session.channels
            )
            _LOGGER.debug("Received %s byte(s) of WAV data", len(wav_bytes))

            if self.asr_url:
                # Remote ASR server
                _LOGGER.debug(self.asr_url)

                request_headers = {
                    "Content-Type": "audio/wav",
                    "Accept": "application/json",
                }
                async with self.http_session.post(
                    self.asr_url,
                    data=wav_bytes,
                    headers=request_headers,
                    ssl=self.ssl_context,
                ) as response:
                    response.raise_for_status()
                    transcription_dict = await response.json()
            elif self.asr_command:
                # Local ASR command
                _LOGGER.debug(self.asr_command)

                started = time.perf_counter()
                proc = await asyncio.create_subprocess_exec(
                    *self.asr_command,
                    stdin=asyncio.subprocess.PIPE,
                    stdout=asyncio.subprocess.PIPE,
                    stderr=asyncio.subprocess.PIPE,
                )

                stdout_bytes, stderr_bytes = await proc.communicate(wav_bytes)

                if stderr_bytes:
                    _LOGGER.debug(stderr_bytes.decode())

                # The command's stdout is the transcription text
                text = stdout_bytes.decode()
                elapsed = time.perf_counter() - started

                transcription_dict = {"text": text, "transcribe_seconds": elapsed}
            else:
                # Empty transcription
                _LOGGER.warning(
                    "No ASR URL or command. Only empty transcriptions will be returned."
                )
                transcription_dict = {}

            # Publish transcription, defaulting any field that is missing
            captured_text = transcription_dict.get("text", "")
            likelihood = float(transcription_dict.get("likelihood", 0))
            seconds = float(transcription_dict.get("transcribe_seconds", 0))
            yield AsrTextCaptured(
                text=captured_text,
                likelihood=likelihood,
                seconds=seconds,
                site_id=stop_listening.site_id,
                session_id=stop_listening.session_id,
                lang=session.start_listening.lang,
            )

            if session.start_listening.send_audio_captured:
                # Send audio data
                audio_message = AsrAudioCaptured(wav_bytes=wav_bytes)
                topic_args = {
                    "site_id": stop_listening.site_id,
                    "session_id": stop_listening.session_id,
                }
                yield (audio_message, topic_args)

        except Exception as err:
            _LOGGER.exception("handle_stop_listening")
            yield AsrError(
                error=str(err),
                context=f"url='{self.asr_url}', command='{self.asr_command}'",
                site_id=stop_listening.site_id,
                session_id=stop_listening.session_id,
            )