예제 #1
0
def test_hotword_detected():
    """HotwordDetected topics are recognized and round-trip the wake word id."""
    hotword_topic = HotwordDetected.topic(wakeword_id=wakeword_id)

    # The generated topic must be classified as a HotwordDetected topic...
    assert HotwordDetected.is_topic(hotword_topic)

    # ...and the wake word id must be recoverable from it unchanged.
    recovered_id = HotwordDetected.get_wakeword_id(hotword_topic)
    assert recovered_id == wakeword_id
예제 #2
0
    async def on_message(
        self,
        message: Message,
        site_id: typing.Optional[str] = None,
        session_id: typing.Optional[str] = None,
        topic: typing.Optional[str] = None,
    ) -> GeneratorType:
        """Dispatch one incoming Hermes message and re-yield handler output.

        Each recognized message type is routed to its handler. Results
        produced by async-generator handlers are yielded back to the caller.
        Messages that are unknown or cannot be handled are logged as warnings
        and otherwise ignored.

        Args:
            message: Parsed Hermes message to handle.
            site_id: Site id the message was received for. Not read here; the
                message's own ``site_id`` field is used instead.
            session_id: Session id for the message. Not read here.
            topic: Topic the message arrived on. Required for
                ``HotwordDetected``, whose wake word id is taken from the
                topic rather than the message body.

        Yields:
            Whatever the invoked handlers yield (``maybe_play_sound``,
            ``handle_text_captured``, ``handle_start``, ``handle_continue``,
            ``handle_end``, ``handle_wake``, ``handle_not_recognized``).
        """
        if isinstance(message, AsrTextCaptured):
            # ASR transcription received; only accept it for a known session
            if (not message.session_id) or (
                not self.valid_session_id(message.session_id)
            ):
                _LOGGER.warning("Ignoring unknown session %s", message.session_id)
                return

            async for play_recorded_result in self.maybe_play_sound(
                "recorded", site_id=message.site_id
            ):
                yield play_recorded_result

            async for text_result in self.handle_text_captured(message):
                yield text_result

        elif isinstance(message, AudioPlayFinished):
            # Audio output finished: wake up whoever is waiting on this id.
            # NOTE(review): unlike the TtsSayFinished branch, this uses .get()
            # and leaves the event in the dict; presumably the waiter removes
            # it itself — confirm to rule out a leak.
            play_finished_event = self.message_events[AudioPlayFinished].get(message.id)
            if play_finished_event:
                play_finished_event.set()
        elif isinstance(message, DialogueConfigure):
            # Configure intent filter
            self.handle_configure(message)
        elif isinstance(message, DialogueStartSession):
            # Start session
            async for start_result in self.handle_start(message):
                yield start_result
        elif isinstance(message, DialogueContinueSession):
            # Continue session
            async for continue_result in self.handle_continue(message):
                yield continue_result
        elif isinstance(message, DialogueEndSession):
            # End session
            async for end_result in self.handle_end(message):
                yield end_result
        elif isinstance(message, HotwordDetected):
            # Wakeword detected; the wake word id is only available from the
            # topic. Checked explicitly rather than with `assert`, which is
            # stripped when Python runs with -O.
            if not topic:
                _LOGGER.warning("Missing topic for %s", message)
                return

            wakeword_id = HotwordDetected.get_wakeword_id(topic)
            if (not self.wakeword_ids) or (wakeword_id in self.wakeword_ids):
                async for wake_result in self.handle_wake(wakeword_id, message):
                    yield wake_result
            else:
                _LOGGER.warning("Ignoring wake word id=%s", wakeword_id)
        elif isinstance(message, NluIntent):
            # Intent recognized
            await self.handle_recognized(message)
        elif isinstance(message, NluIntentNotRecognized):
            # Intent not recognized: play the error sound first, then handle
            async for play_error_result in self.maybe_play_sound(
                "error", site_id=message.site_id
            ):
                yield play_error_result

            async for not_recognized_result in self.handle_not_recognized(message):
                yield not_recognized_result
        elif isinstance(message, TtsSayFinished):
            # Text to speech finished: remove and signal the waiting event
            say_finished_event = self.message_events[TtsSayFinished].pop(
                message.id, None
            )
            if say_finished_event:
                say_finished_event.set()
        else:
            _LOGGER.warning("Unexpected message: %s", message)
예제 #3
0
    def on_message(self, client, userdata, msg):
        """Received message from MQTT broker.

        Parses the raw MQTT message into Hermes messages and processes them:

        * Messages without a site id, or when no site ids are configured, or
          whose site id is in ``self.site_ids`` are passed to
          ``handle_message``. Some types are also reported to
          ``self.message_queues``.
        * Messages from other site ids are only reported or handled when that
          site id is registered as a satellite for the relevant service in
          ``self.satellite_site_ids``.

        Afterwards the raw ``(topic, payload)`` is forwarded to every queue in
        ``self.message_queues``. Any exception is logged, not propagated.

        Args:
            client: MQTT client that received the message (not read here).
            userdata: User data attached to the client (not read here).
            msg: Received MQTT message; its ``topic`` and ``payload`` are read.
        """
        try:
            topic, payload = msg.topic, msg.payload

            for message, site_id, _ in HermesClient.parse_mqtt_message(
                topic, payload, self.subscribed_types
            ):
                is_self_site_id = True
                if site_id and self.site_ids and (site_id not in self.site_ids):
                    # Invalid site id
                    is_self_site_id = False

                if is_self_site_id:
                    # Base-only messages
                    if isinstance(message, AsrAudioCaptured):
                        # Audio data from ASR session
                        assert site_id, "Missing site id"
                        self.last_audio_captured = message
                        self.handle_message(topic, message)
                    elif isinstance(message, AsrError):
                        # ASR service error
                        self.handle_message(topic, message)
                    elif isinstance(message, AsrTextCaptured):
                        # Successful transcription
                        self.handle_message(topic, message)

                        # Report to websockets
                        self._queue_message_event("text", message)
                    elif isinstance(message, AudioDevices):
                        # Microphones or speakers
                        self.handle_message(topic, message)
                    elif isinstance(message, AudioPlayBytes):
                        # Request to play audio
                        assert site_id, "Missing site id"
                        self.handle_message(topic, message)
                    elif isinstance(message, AudioPlayError):
                        # Error playing audio
                        self.handle_message(topic, message)
                    elif isinstance(message, AudioPlayFinished):
                        # Audio finished playing
                        assert site_id, "Missing site id"
                        self.handle_message(topic, message)
                    elif isinstance(message, AudioRecordError):
                        # Error recording audio
                        self.handle_message(topic, message)
                    elif isinstance(message, AudioFrame):
                        # Recorded audio frame
                        assert site_id, "Missing site id"

                        self.compute_audio_energies(message.wav_bytes)

                        # Report to websockets
                        self._queue_message_event("audiosummary", self.audio_energies)
                    elif isinstance(message, DialogueSessionStarted):
                        # Dialogue session started
                        self.handle_message(topic, message)
                    elif isinstance(message, G2pPhonemes):
                        # Word pronunciations
                        self.handle_message(topic, message)
                    elif isinstance(message, Hotwords):
                        # Hotword list
                        self.handle_message(topic, message)
                    elif isinstance(message, HotwordDetected):
                        _LOGGER.debug("<- %s", message)

                        # Hotword detected; the wake word id comes from the topic
                        wakeword_id = HotwordDetected.get_wakeword_id(topic)
                        self.handle_message(topic, message)

                        # Report to websockets
                        self._queue_message_event("wake", message, wakeword_id)

                        # Warn the user if they expect a wake -> ASR -> NLU
                        # workflow but dialogue management is disabled
                        if (self.dialogue_system == "dummy") and (
                            self.asr_system != "dummy"
                        ):
                            _LOGGER.warning(
                                "Dialogue management is disabled. ASR will NOT be automatically enabled."
                            )
                    elif isinstance(message, (HotwordError, HotwordExampleRecorded)):
                        # Other hotword message
                        self.handle_message(topic, message)
                    elif isinstance(message, NluError):
                        # NLU service error
                        self.handle_message(topic, message)
                    elif isinstance(message, NluIntent):
                        _LOGGER.debug("<- %s", message)

                        # Successful intent recognition
                        self.handle_message(topic, message)

                        # Report to websockets
                        self._queue_message_event("intent", message)
                    elif isinstance(message, NluIntentNotRecognized):
                        _LOGGER.debug("<- %s", message)

                        # Failed intent recognition
                        self.handle_message(topic, message)

                        # Report to websockets (scheduled on the event loop,
                        # like every other report; asyncio queues are not
                        # thread-safe)
                        self._queue_message_event("intent", message)
                    elif isinstance(message, TtsSayFinished):
                        # Text to speech complete
                        self.handle_message(topic, message)
                    elif isinstance(message, AsrTrainSuccess):
                        # ASR training success
                        assert site_id, "Missing site id"
                        self.handle_message(topic, message)
                    elif isinstance(message, NluTrainSuccess):
                        # NLU training success
                        assert site_id, "Missing site id"
                        self.handle_message(topic, message)
                    elif isinstance(message, TtsError):
                        # Error during text to speech
                        self.handle_message(topic, message)
                    elif isinstance(message, Voices):
                        # Text to speech voices
                        self.handle_message(topic, message)
                    else:
                        _LOGGER.warning("Unexpected message: %s", message)
                else:
                    # Check for satellite messages.
                    # This ensures that websocket events are reported on the base
                    # station as well as the satellite.
                    if isinstance(message, (NluIntent, NluIntentNotRecognized)) and (
                        site_id in self.satellite_site_ids["intent"]
                    ):
                        # Report satellite message to base websockets
                        self._queue_message_event("intent", message)
                    elif isinstance(message, AsrTextCaptured) and (
                        site_id in self.satellite_site_ids["speech_to_text"]
                    ):
                        # Report satellite message to base websockets
                        self._queue_message_event("text", message)
                    elif isinstance(message, HotwordDetected) and (
                        site_id in self.satellite_site_ids["wake"]
                    ):
                        # Report satellite message to base websockets
                        wakeword_id = HotwordDetected.get_wakeword_id(topic)
                        self._queue_message_event("wake", message, wakeword_id)
                    elif isinstance(message, (TtsSayFinished, AudioPlayFinished)) and (
                        site_id in self.satellite_site_ids["text_to_speech"]
                    ):
                        # Satellite text to speech/audio finished
                        self.handle_message(topic, message)

            # -----------------------------------------------------------------

            # Forward to external message queues
            self._queue_message_event("mqtt", topic, payload)

        except Exception:
            _LOGGER.exception("on_message")

    def _queue_message_event(self, *event):
        """Put ``event`` (as a tuple) on every queue in ``self.message_queues``.

        Each ``put_nowait`` is scheduled on ``self.loop`` with
        ``call_soon_threadsafe`` instead of being called directly. This
        callback may run outside the event loop's thread, and asyncio queues
        are not thread-safe.
        """
        for queue in self.message_queues:
            self.loop.call_soon_threadsafe(queue.put_nowait, event)