Beispiel #1
0
def test_hotword_detected():
    """Check that a HotwordDetected topic round-trips its wakeword id."""
    # NOTE(review): wakeword_id is not defined in this function; presumably a
    # module-level constant in the test file -- confirm.
    topic = HotwordDetected.topic(wakeword_id=wakeword_id)

    assert HotwordDetected.is_topic(topic)
    assert HotwordDetected.get_wakeword_id(topic) == wakeword_id
    def test_topics(self):
        """Verify the topic()/is_topic()/get_*() round trip of several messages.

        For each message class a topic is built from known ids, then checked
        to be recognised by ``is_topic`` and to give back the same ids through
        the matching ``get_*`` accessor.
        """
        site_id = "testSiteId"
        request_id = "testRequestId"
        intent_name = "testIntent"
        wakeword_id = "testWakeWord"

        # AudioFrame
        frame_topic = AudioFrame.topic(siteId=site_id)
        self.assertTrue(AudioFrame.is_topic(frame_topic))
        self.assertEqual(AudioFrame.get_siteId(frame_topic), site_id)

        # AudioPlayBytes
        play_topic = AudioPlayBytes.topic(siteId=site_id, requestId=request_id)
        self.assertTrue(AudioPlayBytes.is_topic(play_topic))
        self.assertEqual(AudioPlayBytes.get_siteId(play_topic), site_id)
        self.assertEqual(AudioPlayBytes.get_requestId(play_topic), request_id)

        # AudioPlayFinished
        finished_topic = AudioPlayFinished.topic(siteId=site_id)
        self.assertTrue(AudioPlayFinished.is_topic(finished_topic))
        self.assertEqual(AudioPlayFinished.get_siteId(finished_topic), site_id)

        # NluIntent
        intent_topic = NluIntent.topic(intentName=intent_name)
        self.assertTrue(NluIntent.is_topic(intent_topic))
        self.assertEqual(NluIntent.get_intentName(intent_topic), intent_name)

        # HotwordDetected
        hotword_topic = HotwordDetected.topic(wakewordId=wakeword_id)
        self.assertTrue(HotwordDetected.is_topic(hotword_topic))
        self.assertEqual(
            HotwordDetected.get_wakewordId(hotword_topic), wakeword_id)
Beispiel #3
0
    async def handle_detection(
        self, matching_indexes: typing.List[int]
    ) -> typing.AsyncIterable[
        typing.Union[typing.Tuple[HotwordDetected, TopicArgs], HotwordError]
    ]:
        """Report a hotword detection for the first matching template.

        Yields a ``(HotwordDetected, topic_args)`` pair built from the template
        at ``matching_indexes[0]``. If anything raises while building it, the
        exception is logged and a ``HotwordError`` is yielded instead.
        """
        try:
            matched = self.raven.templates[matching_indexes[0]]

            # Fall back to the template name when no wakeword id is configured
            topic_wakeword_id = self.wakeword_id or matched.name

            detected = HotwordDetected(
                site_id=self.last_audio_site_id,
                model_id=matched.name,
                current_sensitivity=self.raven.distance_threshold,
                model_version="",
                model_type="personal",
            )
            yield detected, {"wakeword_id": topic_wakeword_id}
        except Exception as err:
            _LOGGER.exception("handle_detection")
            yield HotwordError(
                error=str(err),
                context=str(matching_indexes),
                site_id=self.last_audio_site_id,
            )
    async def handle_detection(
        self,
        keyword_index: int,
        wakeword_id: str,
        site_id="default"
    ) -> typing.AsyncIterable[typing.Union[typing.Tuple[
            HotwordDetected, TopicArgs], HotwordError]]:
        """Handle a successful hotword detection.

        Args:
            keyword_index: Index into ``self.model_ids`` / ``self.sensitivities``
                of the keyword that fired.
            wakeword_id: Wakeword id passed along as the ``wakeword_id`` topic
                argument.
            site_id: Site id stamped on the yielded message.

        Yields:
            A ``(HotwordDetected, topic_args)`` tuple on success, or a
            ``HotwordError`` if building the message raised.
        """
        try:
            # Explicit check rather than ``assert`` so the descriptive message
            # is still produced when Python runs with -O (asserts stripped).
            if keyword_index >= len(self.model_ids):
                raise IndexError(f"Missing {keyword_index} in models")

            yield (
                HotwordDetected(
                    site_id=site_id,
                    model_id=self.model_ids[keyword_index],
                    current_sensitivity=self.sensitivities[keyword_index],
                    model_version="",
                    model_type="personal",
                    lang=self.lang,
                ),
                {
                    "wakeword_id": wakeword_id
                },
            )
        except Exception as e:
            _LOGGER.exception("handle_detection")
            yield HotwordError(error=str(e),
                               context=str(keyword_index),
                               site_id=site_id)
    async def handle_detection(
        self,
    ) -> typing.AsyncIterable[typing.Union[typing.Tuple[
            HotwordDetected, TopicArgs], HotwordError]]:
        """Report a hotword detection for the loaded model.

        Yields a ``(HotwordDetected, topic_args)`` pair; if building it raises,
        the exception is logged and a ``HotwordError`` is yielded instead.
        """
        try:
            # Without a configured wakeword id, use the model file's stem
            topic_wakeword_id = self.wakeword_id or self.model_path.stem

            detected = HotwordDetected(
                site_id=self.last_audio_site_id,
                model_id=self.model_id,
                current_sensitivity=self.sensitivity,
                model_version="",
                model_type="personal",
            )
            yield detected, {"wakeword_id": topic_wakeword_id}
        except Exception as err:
            _LOGGER.exception("handle_detection")
            yield HotwordError(
                error=str(err),
                context=str(self.model_path),
                site_id=self.last_audio_site_id,
            )
Beispiel #6
0
 async def handle_detection(
     self,
     wakeword_id: str,
     site_id: str = "default"
 ) -> typing.AsyncIterable[typing.Union[typing.Tuple[
         HotwordDetected, TopicArgs], HotwordError]]:
     """Report a detection of the configured keyphrase.

     Yields a ``(HotwordDetected, topic_args)`` pair for ``site_id``; if
     building it raises, the exception is logged and a ``HotwordError`` is
     yielded instead.
     """
     try:
         detected = HotwordDetected(
             site_id=site_id,
             model_id=self.keyphrase,
             current_sensitivity=self.keyphrase_threshold,
             model_version="",
             model_type="personal",
         )
         yield detected, {"wakeword_id": wakeword_id}
     except Exception as err:
         _LOGGER.exception("handle_detection")
         yield HotwordError(
             error=str(err),
             context=self.keyphrase,
             site_id=site_id,
         )
    async def handle_detection(
        self, model_index: int, wakeword_id: str, site_info: SiteInfo
    ) -> typing.AsyncIterable[typing.Union[typing.Tuple[
            HotwordDetected, TopicArgs], HotwordError]]:
        """Handle a successful hotword detection.

        Args:
            model_index: Index of the model that fired, used to look up
                ``site_info.model_ids`` and (when in range) ``self.models``.
            wakeword_id: Wakeword id passed along as the ``wakeword_id`` topic
                argument.
            site_info: Per-site state providing ``site_id`` and ``model_ids``.

        Yields:
            A ``(HotwordDetected, topic_args)`` tuple on success, or a
            ``HotwordError`` if building the message raised.
        """
        try:
            # Explicit check rather than ``assert`` so the descriptive message
            # is still produced when Python runs with -O (asserts stripped).
            if model_index >= len(site_info.model_ids):
                raise IndexError(f"Missing {model_index} in models")

            # Default sensitivity, used when no loaded model matches the index
            sensitivity = 0.5

            if model_index < len(self.models):
                sensitivity = self.models[model_index].float_sensitivity()

            yield (
                HotwordDetected(
                    site_id=site_info.site_id,
                    model_id=site_info.model_ids[model_index],
                    current_sensitivity=sensitivity,
                    model_version="",
                    model_type="personal",
                    lang=self.lang,
                ),
                {
                    "wakeword_id": wakeword_id
                },
            )
        except Exception as e:
            _LOGGER.exception("handle_detection")
            yield HotwordError(error=str(e),
                               context=str(model_index),
                               site_id=site_info.site_id)
Beispiel #8
0
    async def async_test_basic_wake(self):
        """Test wake/asr/nlu workflow without a satellite.

        Publishes a HotwordDetected message and waits until both the
        "started" and "ended" events have been set (by the message handler
        installed below) or the timeout expires.
        """
        for event_name in ["started", "ended"]:
            self.events[event_name] = asyncio.Event()

        # Wait until connected
        self.hermes.on_message = self.on_message_test_basic_wake
        await asyncio.wait_for(self.hermes.mqtt_connected_event.wait(),
                               timeout=5)

        # Start listening
        self.hermes.subscribe(DialogueSessionStarted, DialogueSessionEnded)
        message_task = asyncio.create_task(self.hermes.handle_messages_async())

        try:
            # Send wake up signal
            self.hermes.publish(
                HotwordDetected(model_id="default", site_id=self.base_id),
                wakeword_id="default",
            )

            # Wait for up to 10 seconds
            await asyncio.wait_for(
                asyncio.gather(*[e.wait() for e in self.events.values()]),
                timeout=10)
        finally:
            # Always stop the background task, even when the wait times out,
            # so a failing test does not leave it running.
            message_task.cancel()
    async def handle_detection(
        self, matching_indexes: typing.List[int], raven: Raven
    ) -> typing.AsyncIterable[typing.Union[typing.Tuple[
            HotwordDetected, TopicArgs], HotwordError]]:
        """Handle a successful hotword detection.

        Args:
            matching_indexes: Indexes of the matching templates in
                ``raven.templates``; only the first one is used.
            raven: Detector instance that produced the match.

        Yields:
            A ``(HotwordDetected, topic_args)`` tuple on success, or a
            ``HotwordError`` if building the message raised.
        """
        # Bound before the try block so the error handler can always refer to
        # it, even when the template lookup itself is what failed.
        template = None
        try:
            template = raven.templates[matching_indexes[0]]

            # Prefer the keyword name, then the template name, then "default"
            wakeword_id = raven.keyword_name or template.name
            if not wakeword_id:
                wakeword_id = "default"

            yield (
                HotwordDetected(
                    site_id=self.last_audio_site_id,
                    model_id=template.name,
                    current_sensitivity=raven.probability_threshold,
                    model_version="",
                    model_type="personal",
                    lang=self.lang,
                ),
                {
                    "wakeword_id": wakeword_id
                },
            )
        except Exception as e:
            _LOGGER.exception("handle_detection")

            # Without a template, describe the failure by the indexes instead
            # of raising a second error from inside the handler.
            if template is not None:
                template_name = template.name
            else:
                template_name = str(matching_indexes)

            yield HotwordError(
                error=str(e),
                context=f"{raven.keyword_name}: {template_name}",
                site_id=self.last_audio_site_id,
            )
Beispiel #10
0
    async def async_test_ws_wake(self):
        """Test api/events/wake endpoint.

        Publishes a HotwordDetected message over MQTT and checks that the
        event received from the "events/wake" websocket carries the same
        wakeword id and site id.
        """
        # Start listening
        event_queue = asyncio.Queue()
        connected = asyncio.Event()
        receive_task = asyncio.ensure_future(
            self.async_ws_receive("events/wake", event_queue, connected)
        )

        try:
            await asyncio.wait_for(connected.wait(), timeout=5)

            # Send in a message
            detected = HotwordDetected(model_id=str(uuid4()), site_id=self.site_id)
            wakeword_id = str(uuid4())

            self.client.publish(
                detected.topic(wakeword_id=wakeword_id), detected.payload()
            )

            # Wait for response
            event = json.loads(await asyncio.wait_for(event_queue.get(), timeout=5))
            self.assertEqual(wakeword_id, event.get("wakewordId", ""))
            self.assertEqual(detected.site_id, event.get("siteId", ""))
        finally:
            # Stop listening, also when a timeout or failed assertion aborts
            # the test, so the receiver task is not left running.
            receive_task.cancel()
    def _subscribe_callbacks(self):
        """Subscribe to every MQTT topic that has a registered callback.

        Collects the intent topics, the hotword and intent-not-recognized
        topics (only when callbacks are registered for them), the explicitly
        registered topic names and the additional topics, then passes them all
        to ``subscribe_topics``.
        """
        # Dictionary keys are already unique, so no extra de-duplication is
        # needed; iterating the dict also keeps a stable (insertion) order.
        topics = [
            NluIntent.topic(intent_name=intent_name)
            for intent_name in self._callbacks_intent
        ]

        if self._callbacks_hotword:
            topics.append(HotwordDetected.topic())

        if self._callbacks_intent_not_recognized:
            topics.append(NluIntentNotRecognized.topic())

        topics.extend(self._callbacks_topic)
        topics.extend(self._additional_topic)

        self.subscribe_topics(*topics)
Beispiel #12
0
    def handle_detection(
        self, keyword_index, siteId="default"
    ) -> typing.Union[HotwordDetected, HotwordError]:
        """Handle a successful hotword detection.

        Args:
            keyword_index: Index into ``self.model_ids`` / ``self.sensitivities``
                of the keyword that fired.
            siteId: Site id stamped on the returned message.

        Returns:
            A ``HotwordDetected`` message, or a ``HotwordError`` if building
            the message raised.
        """
        try:
            # Explicit check rather than ``assert`` so the descriptive message
            # is still produced when Python runs with -O (asserts stripped).
            if keyword_index >= len(self.model_ids):
                raise IndexError(f"Missing {keyword_index} in models")

            return HotwordDetected(
                siteId=siteId,
                modelId=self.model_ids[keyword_index],
                currentSensitivity=self.sensitivities[keyword_index],
                modelVersion="",
                modelType="personal",
            )
        except Exception as e:
            _LOGGER.exception("handle_detection")
            return HotwordError(error=str(e), context=str(keyword_index), siteId=siteId)
async def test_callbacks_hotword(mocker):
    """Verify that a registered hotword callback receives the parsed message."""
    hermes_app = HermesApp("Test HotwordDetected", mqtt_client=mocker.MagicMock())

    # Register a mock as the hotword callback through on_hotword.
    wake_callback = mocker.MagicMock()
    hermes_app.on_hotword(wake_callback)

    # Stand in for app.run(), which is not used here (no MQTT client).
    hermes_app._subscribe_callbacks()

    # Exactly one callback must be registered, and it must be ours.
    assert len(hermes_app._callbacks_hotword) == 1
    assert hermes_app._callbacks_hotword[0] == wake_callback

    # Feed a raw hotword message into the app.
    await hermes_app.on_raw_message(HOTWORD_TOPIC, HOTWORD_PAYLOAD)

    # The callback must have been called once with the message parsed from
    # the same payload.
    expected = HotwordDetected.from_json(HOTWORD_PAYLOAD)
    wake_callback.assert_called_once_with(expected)
Beispiel #14
0
    def __init__(
        self,
        client,
        siteIds: typing.Optional[typing.List[str]] = None,
        wakewordIds: typing.Optional[typing.List[str]] = None,
        loop=None,
    ):
        """Store the client and set up session and wakeword state.

        ``siteIds`` defaults to an empty list, ``loop`` to the current asyncio
        event loop. One HotwordDetected topic is precomputed for every id in
        ``wakewordIds``.
        """
        self.client = client
        self.siteIds = siteIds if siteIds else []
        self.loop = loop if loop else asyncio.get_event_loop()

        self.session: typing.Optional[SessionInfo] = None
        self.session_queue = deque()

        # Maps each wakeword topic back to the wakeword id it was built from
        topics_by_name = {}
        for wakeword in wakewordIds or []:
            topics_by_name[HotwordDetected.topic(wakewordId=wakeword)] = wakeword
        self.wakeword_topics = topics_by_name

        # Set when TtsSayFinished comes back
        self.say_finished_event = asyncio.Event()
        self.say_finished_timeout: float = 10
Beispiel #15
0
"""Tests for rhasspyhermes_app hotword."""
# pylint: disable=protected-access,too-many-function-args
import asyncio

import pytest
from rhasspyhermes.wake import HotwordDetected

from rhasspyhermes_app import HermesApp

# Topics and messages shared by the tests in this module. The topic strings
# have no placeholders, so they are plain literals rather than f-strings.
HOTWORD_TOPIC = "hermes/hotword/test/detected"
HOTWORD = HotwordDetected("test_model")
HOTWORD_TOPIC2 = "hermes/hotword/test2/detected"
HOTWORD2 = HotwordDetected("test_model2")

_LOOP = asyncio.get_event_loop()


@pytest.mark.asyncio
async def test_callbacks_hotword(mocker):
    """Feed a hotword message to an app that has a hotword callback."""
    hermes_app = HermesApp("Test HotwordDetected", mqtt_client=mocker.MagicMock())

    # Register a mock as the hotword callback through on_hotword.
    wake_callback = mocker.MagicMock()
    hermes_app.on_hotword(wake_callback)

    # Stand in for app.run(), which is not used here (no MQTT client).
    hermes_app._subscribe_callbacks()

    # Feed a raw hotword message into the app.
    # NOTE(review): nothing visible here asserts that wake_callback was
    # called -- confirm whether a check is missing.
    raw_payload = HOTWORD.to_json()
    await hermes_app.on_raw_message(HOTWORD_TOPIC, raw_payload)
Beispiel #16
0
    def on_message(self, client, userdata, msg):
        """Dispatch a message from the MQTT broker to its dialogue handler.

        The JSON payload is decoded only for topics handled here. Session
        start/continue/end and wakeword messages are filtered with
        ``_check_siteId``; TTS, ASR and NLU messages with ``_check_sessionId``.
        Coroutine handlers are scheduled on ``self.loop``; the others run
        directly. Any exception is logged and not propagated.
        """
        try:
            topic = msg.topic
            _LOGGER.debug("Received %s byte(s) on %s", len(msg.payload), topic)

            if topic == DialogueStartSession.topic():
                # Start session: run in event loop (for TTS)
                data = json.loads(msg.payload)
                if self._check_siteId(data):
                    asyncio.run_coroutine_threadsafe(
                        self.handle_start(DialogueStartSession(**data)),
                        self.loop,
                    )
            elif topic == DialogueContinueSession.topic():
                # Continue session: run in event loop (for TTS)
                data = json.loads(msg.payload)
                if self._check_siteId(data):
                    asyncio.run_coroutine_threadsafe(
                        self.handle_continue(DialogueContinueSession(**data)),
                        self.loop,
                    )
            elif topic == DialogueEndSession.topic():
                # End session: run outside event loop
                data = json.loads(msg.payload)
                if self._check_siteId(data):
                    self.handle_end(DialogueEndSession(**data))
            elif topic == TtsSayFinished.topic():
                # TTS finished: signal event loop
                data = json.loads(msg.payload)
                if self._check_sessionId(data):
                    self.loop.call_soon_threadsafe(self.say_finished_event.set)
            elif topic == AsrTextCaptured.topic():
                # Text captured: run outside event loop
                data = json.loads(msg.payload)
                if self._check_sessionId(data):
                    self.handle_text_captured(AsrTextCaptured(**data))
            elif topic.startswith(NluIntent.topic(intent_name="")):
                # Intent recognized
                data = json.loads(msg.payload)
                if self._check_sessionId(data):
                    self.handle_recognized(NluIntent(**data))
            elif topic.startswith(NluIntentNotRecognized.topic()):
                # Intent not recognized: run in event loop (for TTS)
                data = json.loads(msg.payload)
                if self._check_sessionId(data):
                    asyncio.run_coroutine_threadsafe(
                        self.handle_not_recognized(
                            NluIntentNotRecognized(**data)),
                        self.loop,
                    )
            elif topic in self.wakeword_topics:
                # Wakeword detected on one of the precomputed topics
                data = json.loads(msg.payload)
                if self._check_siteId(data):
                    wakeword_id = self.wakeword_topics[topic]
                    asyncio.run_coroutine_threadsafe(
                        self.handle_wake(wakeword_id, HotwordDetected(**data)),
                        self.loop,
                    )
        except Exception:
            _LOGGER.exception("on_message")
    def on_message(self, client, userdata, msg):
        """Received message from MQTT broker.

        Parses the raw MQTT message into Hermes messages and, for each one,
        either handles it as a message for this site or, for selected message
        types from known satellite sites, reports it to the websocket queues.
        The raw topic and payload are finally forwarded to every queue as an
        ("mqtt", topic, payload) event. Any exception is logged and not
        propagated.

        Every queue insertion is scheduled with ``loop.call_soon_threadsafe``.
        NOTE(review): this assumes the callback may run outside the event loop
        thread, as the other queue insertions in this method already implied.
        """

        def report(*event):
            """Schedule ``event`` onto every websocket queue via the loop."""
            for queue in self.message_queues:
                self.loop.call_soon_threadsafe(queue.put_nowait, event)

        try:
            topic, payload = msg.topic, msg.payload

            for message, site_id, _ in HermesClient.parse_mqtt_message(
                topic, payload, self.subscribed_types
            ):
                is_self_site_id = True
                if site_id and self.site_ids and (site_id not in self.site_ids):
                    # Invalid site id
                    is_self_site_id = False

                if is_self_site_id:
                    # Base-only messages
                    if isinstance(message, AsrAudioCaptured):
                        # Audio data from ASR session
                        assert site_id, "Missing site id"
                        self.last_audio_captured = message
                        self.handle_message(topic, message)
                    elif isinstance(message, AsrError):
                        # ASR service error
                        self.handle_message(topic, message)
                    elif isinstance(message, AsrTextCaptured):
                        # Successful transcription
                        self.handle_message(topic, message)

                        # Report to websockets
                        report("text", message)
                    elif isinstance(message, AudioDevices):
                        # Microphones or speakers
                        self.handle_message(topic, message)
                    elif isinstance(message, AudioPlayBytes):
                        # Request to play audio
                        assert site_id, "Missing site id"
                        self.handle_message(topic, message)
                    elif isinstance(message, AudioPlayError):
                        # Error playing audio
                        self.handle_message(topic, message)
                    elif isinstance(message, AudioPlayFinished):
                        # Audio finished playing
                        assert site_id, "Missing site id"
                        self.handle_message(topic, message)
                    elif isinstance(message, AudioRecordError):
                        # Error recording audio
                        self.handle_message(topic, message)
                    elif isinstance(message, AudioFrame):
                        # Recorded audio frame
                        assert site_id, "Missing site id"

                        self.compute_audio_energies(message.wav_bytes)

                        # Report to websockets
                        report("audiosummary", self.audio_energies)
                    elif isinstance(message, DialogueSessionStarted):
                        # Dialogue session started
                        self.handle_message(topic, message)
                    elif isinstance(message, G2pPhonemes):
                        # Word pronunciations
                        self.handle_message(topic, message)
                    elif isinstance(message, Hotwords):
                        # Hotword list
                        self.handle_message(topic, message)
                    elif isinstance(message, HotwordDetected):
                        _LOGGER.debug("<- %s", message)

                        # Hotword detected
                        wakeword_id = HotwordDetected.get_wakeword_id(topic)
                        self.handle_message(topic, message)

                        # Report to websockets
                        report("wake", message, wakeword_id)

                        # Warn user if they're expecting wake -> ASR -> NLU workflow
                        if (self.dialogue_system == "dummy") and (
                            self.asr_system != "dummy"
                        ):
                            _LOGGER.warning(
                                "Dialogue management is disabled. ASR will NOT be automatically enabled."
                            )
                    elif isinstance(message, (HotwordError, HotwordExampleRecorded)):
                        # Other hotword message
                        self.handle_message(topic, message)
                    elif isinstance(message, NluError):
                        # NLU service error
                        self.handle_message(topic, message)
                    elif isinstance(message, NluIntent):
                        _LOGGER.debug("<- %s", message)

                        # Successful intent recognition
                        self.handle_message(topic, message)

                        # Report to websockets
                        report("intent", message)
                    elif isinstance(message, NluIntentNotRecognized):
                        _LOGGER.debug("<- %s", message)

                        # Failed intent recognition
                        self.handle_message(topic, message)

                        # Report to websockets (through the loop, like every
                        # other report in this method)
                        report("intent", message)
                    elif isinstance(message, TtsSayFinished):
                        # Text to speech complete
                        self.handle_message(topic, message)
                    elif isinstance(message, AsrTrainSuccess):
                        # ASR training success
                        assert site_id, "Missing site id"
                        self.handle_message(topic, message)
                    elif isinstance(message, NluTrainSuccess):
                        # NLU training success
                        assert site_id, "Missing site id"
                        self.handle_message(topic, message)
                    elif isinstance(message, TtsError):
                        # Error during text to speech
                        self.handle_message(topic, message)
                    elif isinstance(message, Voices):
                        # Text to speech voices
                        self.handle_message(topic, message)
                    else:
                        _LOGGER.warning("Unexpected message: %s", message)
                else:
                    # Check for satellite messages.
                    # This ensures that websocket events are reported on the base
                    # station as well as the satellite.
                    if isinstance(message, (NluIntent, NluIntentNotRecognized)) and (
                        site_id in self.satellite_site_ids["intent"]
                    ):
                        # Report satellite message to base websockets
                        report("intent", message)
                    elif isinstance(message, AsrTextCaptured) and (
                        site_id in self.satellite_site_ids["speech_to_text"]
                    ):
                        # Report satellite message to base websockets
                        report("text", message)
                    elif isinstance(message, HotwordDetected) and (
                        site_id in self.satellite_site_ids["wake"]
                    ):
                        # Report satellite message to base websockets
                        wakeword_id = HotwordDetected.get_wakeword_id(topic)
                        report("wake", message, wakeword_id)
                    elif isinstance(message, (TtsSayFinished, AudioPlayFinished)) and (
                        site_id in self.satellite_site_ids["text_to_speech"]
                    ):
                        # Satellite text to speech/audio finished
                        self.handle_message(topic, message)

            # -----------------------------------------------------------------

            # Forward to external message queues (through the loop as well,
            # which also keeps this event ordered after the reports above)
            report("mqtt", topic, payload)

        except Exception:
            _LOGGER.exception("on_message")
Beispiel #18
0
    async def on_raw_message(self, topic: str, payload: bytes):
        """This method handles messages from the MQTT broker.

        The topic is matched against the hotword, intent and
        intent-not-recognized message types first; the payload is parsed and
        every callback registered for that type is awaited. Any other topic is
        passed to the callbacks registered for that exact topic, or otherwise
        to the regex-topic callbacks whose pattern matches. A topic nobody
        handled is logged as unexpected. Any exception is logged and not
        propagated.

        Arguments:
            topic: The topic of the received MQTT message.

            payload: The payload of the received MQTT message.

        .. warning:: Don't override this method in your app. This is where all the magic happens in Rhasspy Hermes App.
        """
        import inspect

        try:
            if HotwordDetected.is_topic(topic):
                # hermes/hotword/<wakeword_id>/detected
                try:
                    hotword_detected = HotwordDetected.from_json(payload)
                    for function_h in self._callbacks_hotword:
                        await function_h(hotword_detected)
                except KeyError as key:
                    _LOGGER.error("Missing key %s in JSON payload for %s: %s",
                                  key, topic, payload)
            elif NluIntent.is_topic(topic):
                # hermes/intent/<intent_name>
                try:
                    nlu_intent = NluIntent.from_json(payload)
                    intent_name = nlu_intent.intent.intent_name
                    if intent_name in self._callbacks_intent:
                        for function_i in self._callbacks_intent[intent_name]:
                            await function_i(nlu_intent)
                except KeyError as key:
                    _LOGGER.error("Missing key %s in JSON payload for %s: %s",
                                  key, topic, payload)
            elif NluIntentNotRecognized.is_topic(topic):
                # hermes/nlu/intentNotRecognized
                try:
                    nlu_intent_not_recognized = NluIntentNotRecognized.from_json(
                        payload)
                    for function_inr in self._callbacks_intent_not_recognized:
                        await function_inr(nlu_intent_not_recognized)
                except KeyError as key:
                    _LOGGER.error("Missing key %s in JSON payload for %s: %s",
                                  key, topic, payload)
            elif DialogueIntentNotRecognized.is_topic(topic):
                # hermes/dialogueManager/intentNotRecognized
                try:
                    dialogue_intent_not_recognized = DialogueIntentNotRecognized.from_json(
                        payload)
                    for function_dinr in self._callbacks_dialogue_intent_not_recognized:
                        await function_dinr(dialogue_intent_not_recognized)
                except KeyError as key:
                    _LOGGER.error("Missing key %s in JSON payload for %s: %s",
                                  key, topic, payload)
            else:
                unexpected_topic = True
                if topic in self._callbacks_topic:
                    # Callbacks registered for this exact topic
                    for function_1 in self._callbacks_topic[topic]:
                        await function_1(TopicData(topic, {}), payload)
                        unexpected_topic = False
                else:
                    # Callbacks registered with topic patterns
                    for function_2 in self._callbacks_topic_regex:
                        if hasattr(function_2, "topic_extras"):
                            topic_extras = getattr(function_2, "topic_extras")
                            for pattern, named_positions in topic_extras:
                                if re.match(pattern, topic) is not None:
                                    data = TopicData(topic, {})
                                    parts = topic.split(sep="/")
                                    if named_positions is not None:
                                        # Copy the named topic levels into
                                        # the callback data
                                        for name, position in named_positions.items(
                                        ):
                                            data.data[name] = parts[position]

                                    # Every other callback in this method is
                                    # awaited; await this one too when it
                                    # returns an awaitable, otherwise a
                                    # coroutine callback would never run.
                                    result = function_2(data, payload)
                                    if inspect.isawaitable(result):
                                        await result
                                    unexpected_topic = False

                if unexpected_topic:
                    _LOGGER.warning("Unexpected topic: %s", topic)

        except Exception:
            _LOGGER.exception("on_raw_message")
    async def on_message(
        self,
        message: Message,
        site_id: typing.Optional[str] = None,
        session_id: typing.Optional[str] = None,
        topic: typing.Optional[str] = None,
    ) -> GeneratorType:
        """Route one incoming message to its handler and yield the handler's output.

        The message type selects the handler. ``topic`` is only consulted for
        ``HotwordDetected`` messages, where the wake word id is extracted from
        it. ``site_id`` and ``session_id`` are accepted but not read here.
        Messages of any other type are logged as unexpected.
        """
        if isinstance(message, AsrTextCaptured):
            # Transcription arrived: only proceed for a session we know about
            if not (
                message.session_id and self.valid_session_id(message.session_id)
            ):
                _LOGGER.warning("Ignoring unknown session %s", message.session_id)
                return

            # Optional "recorded" feedback sound first, then the transcription
            async for outgoing in self.maybe_play_sound(
                "recorded", site_id=message.site_id
            ):
                yield outgoing

            async for outgoing in self.handle_text_captured(message):
                yield outgoing

            return

        if isinstance(message, AudioPlayFinished):
            # Audio playback done: wake up whoever is waiting on this id
            waiter = self.message_events[AudioPlayFinished].get(message.id)
            if waiter:
                waiter.set()

            return

        if isinstance(message, DialogueConfigure):
            # Intent filter configuration
            self.handle_configure(message)
            return

        if isinstance(message, DialogueStartSession):
            # New session requested
            async for outgoing in self.handle_start(message):
                yield outgoing

            return

        if isinstance(message, DialogueContinueSession):
            # Existing session continues
            async for outgoing in self.handle_continue(message):
                yield outgoing

            return

        if isinstance(message, DialogueEndSession):
            # Session is being closed
            async for outgoing in self.handle_end(message):
                yield outgoing

            return

        if isinstance(message, HotwordDetected):
            # Wake word fired: the id is carried in the topic, not the payload
            assert topic, "Missing topic"
            wakeword_id = HotwordDetected.get_wakeword_id(topic)

            # An empty wakeword_ids collection means "accept every wake word"
            if self.wakeword_ids and (wakeword_id not in self.wakeword_ids):
                _LOGGER.warning("Ignoring wake word id=%s", wakeword_id)
                return

            async for outgoing in self.handle_wake(wakeword_id, message):
                yield outgoing

            return

        if isinstance(message, NluIntent):
            # Intent was recognized
            await self.handle_recognized(message)
            return

        if isinstance(message, NluIntentNotRecognized):
            # No intent recognized: optional "error" sound, then the handler
            async for outgoing in self.maybe_play_sound(
                "error", site_id=message.site_id
            ):
                yield outgoing

            async for outgoing in self.handle_not_recognized(message):
                yield outgoing

            return

        if isinstance(message, TtsSayFinished):
            # Text to speech done: remove the pending event and signal it
            waiter = self.message_events[TtsSayFinished].pop(message.id, None)
            if waiter:
                waiter.set()

            return

        _LOGGER.warning("Unexpected message: %s", message)
Beispiel #20
0
    async def handle_audio_frame(
        self,
        wav_bytes: bytes,
        site_id: str = "default",
        session_id: typing.Optional[str] = None,
    ) -> typing.AsyncIterable[
        typing.Union[
            typing.Tuple[HotwordDetected, TopicArgs],
            AsrTextCaptured,
            typing.Tuple[AsrAudioCaptured, TopicArgs],
            AsrError,
        ]
    ]:
        """Add audio frame to open ASR sessions and feed the wake command.

        Args:
            wav_bytes: A complete WAV chunk (header plus frames).
            site_id: Only ASR sessions whose ``start_listening.site_id``
                equals this value receive the audio.
            session_id: When given, only that ASR session is targeted and the
                wake command is not fed; when ``None``, every open session is
                a candidate.

        Yields:
            Whatever ``handle_stop_listening`` yields when a session's
            recorder reports a successful voice command, and a
            ``(HotwordDetected, {"wakeword_id": ...})`` tuple when the wake
            command process has exited.

        Any exception raised while processing is logged and swallowed, so a
        bad frame does not propagate to the caller.
        """
        try:
            if self.asr_enabled:
                if session_id is None:
                    # Add to every open session (copy: sessions may be removed
                    # while we iterate, e.g. by handle_stop_listening)
                    target_sessions = list(self.asr_sessions.items())
                else:
                    # Add to single session
                    target_sessions = [(session_id, self.asr_sessions[session_id])]

                with io.BytesIO(wav_bytes) as in_io:
                    with wave.open(in_io) as in_wav:
                        # Get WAV details from first frame
                        sample_rate = in_wav.getframerate()
                        sample_width = in_wav.getsampwidth()
                        channels = in_wav.getnchannels()
                        audio_data = in_wav.readframes(in_wav.getnframes())

                # Audio converted to the recorder's format. Kept in its own
                # variable so that the raw frames in ``audio_data`` are never
                # overwritten, and computed at most once per frame because the
                # conversion inputs do not depend on the session.
                recorder_audio = None

                # Add to target ASR sessions
                for target_id, session in target_sessions:
                    # Skip non-matching site_id
                    if session.start_listening.site_id != site_id:
                        continue

                    session.sample_rate = sample_rate
                    session.sample_width = sample_width
                    session.channels = channels
                    session.audio_data += audio_data

                    if session.start_listening.stop_on_silence:
                        # Detect silence (end of command)
                        if recorder_audio is None:
                            recorder_audio = self.maybe_convert_wav(
                                wav_bytes,
                                self.recorder_sample_rate,
                                self.recorder_sample_width,
                                self.recorder_channels,
                            )

                        command = session.recorder.process_chunk(recorder_audio)
                        if command and (command.result == VoiceCommandResult.SUCCESS):
                            # Complete session
                            stop_listening = AsrStopListening(
                                site_id=site_id, session_id=target_id
                            )
                            async for message in self.handle_stop_listening(
                                stop_listening
                            ):
                                yield message

            if self.wake_enabled and (session_id is None) and self.wake_proc:
                # Convert and send to wake command
                audio_bytes = self.maybe_convert_wav(
                    wav_bytes,
                    self.wake_sample_rate,
                    self.wake_sample_width,
                    self.wake_channels,
                )
                assert self.wake_proc.stdin
                self.wake_proc.stdin.write(audio_bytes)

                # poll() returns None while the process is running and the
                # exit code afterwards. Compare against None explicitly: a
                # successful exit code of 0 is falsy and would otherwise be
                # mistaken for "still running".
                if self.wake_proc.poll() is not None:
                    stdout, stderr = self.wake_proc.communicate()
                    if stderr:
                        _LOGGER.debug(stderr.decode())

                    # The wake command prints the wake word id before exiting
                    wakeword_id = stdout.decode().strip()
                    _LOGGER.debug("Detected wake word %s", wakeword_id)
                    yield (
                        HotwordDetected(
                            model_id=wakeword_id,
                            model_version="",
                            model_type="personal",
                            current_sensitivity=1.0,
                            site_id=site_id,
                        ),
                        {"wakeword_id": wakeword_id},
                    )

                    # Restart wake process
                    self.start_wake_command()

        except Exception:
            _LOGGER.exception("handle_audio_frame")