예제 #1
0
    async def end_session(
        self, reason: DialogueSessionTerminationReason, site_id: str
    ) -> typing.AsyncIterable[
        typing.Union[EndSessionType, StartSessionType, SayType]
    ]:
        """Finish the active session, then start the next queued one or re-arm the hotword.

        Args:
            reason: termination reason reported in DialogueSessionEnded.
            site_id: site id placed on the DialogueSessionEnded message.

        Yields:
            AsrStopListening (only for non-notification sessions), then
            DialogueSessionEnded. After that come either the messages produced
            by starting the next queued session, or HotwordToggleOn when the
            queue is empty.
        """
        assert self.session is not None, "No session"
        ending = self.session
        is_notification = (
            ending.start_session.init.type == DialogueActionType.NOTIFICATION
        )

        if not is_notification:
            # Stop listening (only non-notification sessions are told to)
            yield AsrStopListening(
                site_id=ending.site_id, session_id=ending.session_id
            )

        yield DialogueSessionEnded(
            site_id=site_id,
            session_id=ending.session_id,
            custom_data=ending.custom_data,
            termination=DialogueSessionTermination(reason=reason),
        )

        # Clear the active session before a queued one can take its place
        self.session = None

        if not self.session_queue:
            # No queued sessions: enable the hotword again
            yield HotwordToggleOn(
                site_id=ending.site_id,
                reason=HotwordToggleReason.DIALOGUE_SESSION,
            )
            return

        _LOGGER.debug("Handling queued session")
        next_session = self.session_queue.popleft()
        async for start_result in self.start_session(next_session):
            yield start_result
예제 #2
0
    async def say(
        self,
        text: str,
        site_id="default",
        session_id="",
        request_id: typing.Optional[str] = None,
        block: bool = True,
    ) -> typing.AsyncIterable[
        typing.Union[
            TtsSay, HotwordToggleOn, HotwordToggleOff, AsrToggleOn, AsrToggleOff
        ]
    ]:
        """Send text to TTS, muting hotword/ASR at the site around the request.

        Args:
            text: text to speak.
            site_id: site the TtsSay and toggle messages are addressed to.
            session_id: session id placed on the TtsSay message.
            request_id: id for the TtsSay; a UUID4 string is used when omitted.
            block: when true, wait for the TtsSayFinished event registered under
                the request id. The wait lasts at most 10 s, or
                ``len(text) / self.say_chars_per_second`` if that is larger and
                the rate is positive.

        Yields:
            HotwordToggleOff, AsrToggleOff, TtsSay, then HotwordToggleOn and
            AsrToggleOn. The final two are yielded from ``finally``, so they
            are emitted even when waiting fails. Timeouts and other exceptions
            are logged, not raised.
        """
        say_id = request_id or str(uuid4())
        done = asyncio.Event()
        self.message_events[TtsSayFinished][say_id] = done

        # Turn hotword/ASR off at the site while the text is spoken
        yield HotwordToggleOff(site_id=site_id, reason=HotwordToggleReason.TTS_SAY)
        yield AsrToggleOff(site_id=site_id, reason=AsrToggleReason.TTS_SAY)

        # Give the toggle messages time to be delivered
        await asyncio.sleep(self.toggle_delay)

        try:
            _LOGGER.debug("Say: %s", text)
            yield TtsSay(id=say_id, site_id=site_id, session_id=session_id, text=text)

            if block:
                rate = self.say_chars_per_second
                # 10 s minimum; longer texts scale with the configured rate
                timeout = max(10.0, len(text) / rate) if rate > 0 else 10.0

                _LOGGER.debug(
                    "Waiting for sayFinished (id=%s, timeout=%s)", say_id, timeout
                )
                await asyncio.wait_for(done.wait(), timeout=timeout)
        except asyncio.TimeoutError:
            _LOGGER.warning("Did not receive sayFinished before timeout")
        except Exception:
            _LOGGER.exception("say")
        finally:
            # Pause before turning hotword/ASR back on
            await asyncio.sleep(self.toggle_delay)

            yield HotwordToggleOn(site_id=site_id, reason=HotwordToggleReason.TTS_SAY)
            yield AsrToggleOn(site_id=site_id, reason=AsrToggleReason.TTS_SAY)
예제 #3
0
    async def play_wav_data(
        self, wav_bytes: bytes, site_id: typing.Optional[str] = None
    ) -> AudioPlayFinished:
        """Play WAV audio at a site and wait for the audio server's reply.

        Hotword and ASR are switched off at the site before the audio is sent.
        They are switched back on in ``finally``, so this also happens on error.

        Args:
            wav_bytes: WAV audio to play.
            site_id: target site; falls back to ``self.site_id`` when empty.

        Returns:
            The AudioPlayFinished message whose id matches this request.

        Raises:
            RuntimeError: when ``self.sound_system`` is ``"dummy"``, or when an
                AudioPlayError is received.
        """
        if self.sound_system == "dummy":
            raise RuntimeError("No audio output system configured")

        site_id = site_id or self.site_id
        request_id = str(uuid4())

        def wait_for_reply():
            # Receives 2-tuples via send(); only the second item is inspected.
            # Stops on any AudioPlayError or on the AudioPlayFinished for our id.
            while True:
                _, reply = yield
                if isinstance(reply, AudioPlayError):
                    return reply
                if isinstance(reply, AudioPlayFinished) and reply.id == request_id:
                    return reply

        def outgoing():
            yield (
                AudioPlayBytes(wav_bytes=wav_bytes),
                {"site_id": site_id, "request_id": request_id},
            )

        reply_types: typing.List[typing.Type[Message]] = [
            AudioPlayFinished,
            AudioPlayError,
        ]

        # Switch hotword/ASR off while audio plays
        self.publish(
            HotwordToggleOff(site_id=site_id, reason=HotwordToggleReason.PLAY_AUDIO)
        )
        self.publish(AsrToggleOff(site_id=site_id, reason=AsrToggleReason.PLAY_AUDIO))

        try:
            # A single response is expected; keep whichever arrives last
            result = None
            async for result in self.publish_wait(
                wait_for_reply(), outgoing(), reply_types
            ):
                pass

            if isinstance(result, AudioPlayError):
                _LOGGER.error(result)
                raise RuntimeError(result.error)

            assert isinstance(result, AudioPlayFinished)
            return result
        finally:
            # Always switch hotword/ASR back on
            self.publish(
                HotwordToggleOn(site_id=site_id, reason=HotwordToggleReason.PLAY_AUDIO)
            )
            self.publish(
                AsrToggleOn(site_id=site_id, reason=AsrToggleReason.PLAY_AUDIO)
            )

    async def handle_end(
        self, end_session: DialogueEndSession
    ) -> typing.AsyncIterable[
        typing.Union[EndSessionType, StartSessionType, SayType]
    ]:
        """End the session named by a DialogueEndSession request.

        If the request carries text, it is spoken through ``self.say`` first.
        Non-None custom data replaces the session's custom data. The session is
        then ended with reason NOMINAL and ``start_next_session=True``. Any
        exception is logged and reported as a DialogueError, followed by a
        HotwordToggleOn for the session's site. An unknown session id only
        logs a warning.
        """
        session_id = end_session.session_id
        site_session = self.all_sessions.get(session_id)
        if not site_session:
            _LOGGER.warning("No session for id %s. Cannot end.", session_id)
            return

        try:
            if end_session.text:
                # Speak the closing text before ending the session
                async for tts_result in self.say(
                    end_session.text,
                    site_id=site_session.site_id,
                    session_id=session_id,
                ):
                    yield tts_result

            if end_session.custom_data is not None:
                site_session.custom_data = end_session.custom_data

            _LOGGER.debug("Session ended nominally: %s", site_session.session_id)
            async for end_result in self.end_session(
                DialogueSessionTerminationReason.NOMINAL,
                site_id=site_session.site_id,
                session_id=site_session.session_id,
                start_next_session=True,
            ):
                yield end_result
        except Exception as err:
            _LOGGER.exception("handle_end")
            yield DialogueError(
                error=str(err),
                context=str(end_session),
                site_id=site_session.site_id,
                session_id=session_id,
            )

            # Re-enable the hotword so the site can start a new session
            yield HotwordToggleOn(
                site_id=site_session.site_id,
                reason=HotwordToggleReason.DIALOGUE_SESSION,
            )
예제 #5
0
    def on_connect(self, client, userdata, flags, rc):
        """Subscribe to hotword toggle and audio frame topics after connecting.

        Audio frame subscriptions are ``self.audioframe_topics`` when that is
        non-empty. Otherwise the audio frame topic with a ``#`` siteId wildcard
        is used. Subscriptions go through ``self.client``, not the ``client``
        argument. Any exception is logged, not raised.
        """
        try:
            toggle_topics = [HotwordToggleOn.topic(), HotwordToggleOff.topic()]
            # Specific siteIds if configured, otherwise all siteIds
            frame_topics = self.audioframe_topics or [AudioFrame.topic(siteId="#")]

            for topic in [*toggle_topics, *frame_topics]:
                self.client.subscribe(topic)
                _LOGGER.debug("Subscribed to %s", topic)
        except Exception:
            _LOGGER.exception("on_connect")
예제 #6
0
    async def handle_text_captured(
        self, text_captured: AsrTextCaptured
    ) -> typing.AsyncIterable[
        typing.Union[AsrStopListening, HotwordToggleOn, NluQuery]
    ]:
        """Forward ASR-captured text to NLU for the session it belongs to.

        The captured text is stored on the session. ASR is stopped and the
        hotword is re-enabled at the captured message's site, then an NluQuery
        is emitted. Messages with no session id or an unknown session are
        dropped with a warning. Any exception is logged, not raised.
        """
        try:
            session_id = text_captured.session_id
            if not session_id:
                _LOGGER.warning("Missing session id on text captured message.")
                return

            site_session = self.all_sessions.get(session_id)
            if site_session is None:
                _LOGGER.warning(
                    "No session for id %s. Dropping captured text from ASR.",
                    session_id,
                )
                return

            _LOGGER.debug("Received text: %s", text_captured.text)
            site_session.text_captured = text_captured

            captured_site = text_captured.site_id
            yield AsrStopListening(site_id=captured_site,
                                   session_id=site_session.session_id)
            yield HotwordToggleOn(
                site_id=captured_site,
                reason=HotwordToggleReason.DIALOGUE_SESSION,
            )

            # Values on the captured text win; session values are fallbacks
            intent_filter = (site_session.intent_filter
                             or self.default_intent_filter)
            wakeword_id = text_captured.wakeword_id or site_session.wakeword_id
            lang = text_captured.lang or site_session.lang

            yield NluQuery(
                input=text_captured.text,
                intent_filter=intent_filter,
                session_id=site_session.session_id,
                site_id=site_session.site_id,
                wakeword_id=wakeword_id,
                lang=lang,
            )
        except Exception:
            _LOGGER.exception("handle_text_captured")
예제 #7
0
    def on_message(self, client, userdata, msg):
        """Handle an incoming MQTT message.

        HotwordToggleOn/Off messages that pass ``self._check_siteId`` enable or
        disable processing. While enabled, audio frames on accepted topics go
        to ``self.handle_audio_frame`` and each result is published;
        HotwordDetected results are published with their wake word id. Any
        exception is logged, not raised.
        """
        try:
            topic = msg.topic
            if not topic.endswith("/audioFrame"):
                _LOGGER.debug("Received %s byte(s) on %s", len(msg.payload), topic)

            turning_on = topic == HotwordToggleOn.topic()
            if turning_on or topic == HotwordToggleOff.topic():
                toggle_payload = json.loads(msg.payload or "{}")
                if self._check_siteId(toggle_payload):
                    self.enabled = turning_on
                    if turning_on:
                        # Log the first audio frame received after re-enabling
                        self.first_audio = True
                    _LOGGER.debug("Enabled" if turning_on else "Disabled")

            if not self.enabled:
                return

            if not AudioFrame.is_topic(topic):
                return

            # With configured topics, ignore frames from any other topic
            if self.audioframe_topics and (topic not in self.audioframe_topics):
                return

            if self.first_audio:
                _LOGGER.debug("Receiving audio")
                self.first_audio = False

            frame_site_id = AudioFrame.get_siteId(topic)
            for wakeword_id, result in self.handle_audio_frame(
                msg.payload, siteId=frame_site_id
            ):
                if isinstance(result, HotwordDetected):
                    # Topic contains wake word id
                    self.publish(result, wakewordId=wakeword_id)
                else:
                    self.publish(result)
        except Exception:
            _LOGGER.exception("on_message")
예제 #8
0
    async def end_session(
        self,
        reason: DialogueSessionTerminationReason,
        site_id: str,
        session_id: str,
        start_next_session: bool,
    ) -> typing.AsyncIterable[typing.Union[EndSessionType, StartSessionType, SayType]]:
        """End a session by id, then handle the site's session queue.

        The session is removed from ``self.all_sessions`` and from
        ``self.session_by_site``. For non-notification sessions AsrStopListening
        is yielded, followed by DialogueSessionEnded. An unknown id only logs a
        warning.

        Next, ``self.session_queue_by_site[site_id]`` is checked. If it is
        empty, HotwordToggleOn is yielded. Otherwise the next queued session
        is started, but only when ``start_next_session`` is true.
        """
        site_session = self.all_sessions.pop(session_id, None)
        if not site_session:
            _LOGGER.warning("No session for id %s", session_id)
        else:
            self.session_by_site.pop(site_session.site_id, None)

            if site_session.start_session.init.type != DialogueActionType.NOTIFICATION:
                yield AsrStopListening(
                    site_id=site_session.site_id, session_id=site_session.session_id
                )

            yield DialogueSessionEnded(
                site_id=site_id,
                session_id=site_session.session_id,
                custom_data=site_session.custom_data,
                termination=DialogueSessionTermination(reason=reason),
            )

        queued = self.session_queue_by_site[site_id]
        if not queued:
            # Nothing waiting at this site: enable the hotword again
            yield HotwordToggleOn(
                site_id=site_id, reason=HotwordToggleReason.DIALOGUE_SESSION
            )
        elif start_next_session:
            _LOGGER.debug("Handling queued session")
            async for start_result in self.start_session(queued.popleft()):
                yield start_result
예제 #9
0
    async def maybe_play_sound(
        self,
        sound_name: str,
        site_id: typing.Optional[str] = None,
        request_id: typing.Optional[str] = None,
        block: bool = True,
    ) -> typing.AsyncIterable[SoundsType]:
        """Play a named sound through audio out at a site, if it exists.

        ``self.sound_paths[sound_name]`` may be a file or a directory. For a
        directory, a random file (searched recursively) whose suffix is in
        ``self.sound_suffixes`` is chosen. The sound is converted to WAV and,
        when ``self.volume`` is set and not 1.0, volume-adjusted. Hotword and
        ASR are toggled off at the site before playback and back on afterwards.

        Args:
            sound_name: key into ``self.sound_paths``.
            site_id: target site; falls back to ``self.site_id`` when empty.
            request_id: id for the AudioPlayBytes request; a UUID4 string is
                generated when omitted.
            block: when true, wait for the AudioPlayFinished event registered
                under the request id, for at most the WAV duration plus
                ``self.sound_timeout_extra``.

        Yields:
            HotwordToggleOff, AsrToggleOff, then a tuple of AudioPlayBytes and a
            dict with ``site_id``/``request_id``, then HotwordToggleOn and
            AsrToggleOn. Nothing is yielded when sound is disabled for the
            resolved site or no usable sound file is found. Timeouts and other
            errors while waiting are logged, not raised.
        """
        # Resolve the default site *before* the no_sound check. Otherwise a
        # call with site_id=None would ignore a disabled default site.
        site_id = site_id or self.site_id
        if site_id in self.no_sound:
            _LOGGER.debug("Sound is disabled for site %s", site_id)
            return

        sound_path = self.sound_paths.get(sound_name)
        if not sound_path:
            return

        if sound_path.is_dir():
            sound_file_paths = [
                p
                for p in sound_path.rglob("*")
                if p.is_file() and (p.suffix in self.sound_suffixes)
            ]
            if not sound_file_paths:
                _LOGGER.debug("No sound files found in %s", str(sound_path))
                return

            sound_path = random.choice(sound_file_paths)
        elif not sound_path.is_file():
            _LOGGER.error("Sound does not exist: %s", str(sound_path))
            return

        _LOGGER.debug("Playing sound %s", str(sound_path))

        # Convert to WAV
        wav_bytes = DialogueHermesMqtt.convert_to_wav(sound_path)

        if (self.volume is not None) and (self.volume != 1.0):
            wav_bytes = DialogueHermesMqtt.change_volume(wav_bytes, self.volume)

        # Register the completion event under the request id before sending audio
        request_id = request_id or str(uuid4())
        finished_event = asyncio.Event()
        self.message_events[AudioPlayFinished][request_id] = finished_event

        # Disable ASR/hotword at site
        yield HotwordToggleOff(site_id=site_id, reason=HotwordToggleReason.PLAY_AUDIO)
        yield AsrToggleOff(site_id=site_id, reason=AsrToggleReason.PLAY_AUDIO)

        # Wait for messages to be delivered
        await asyncio.sleep(self.toggle_delay)

        try:
            yield (
                AudioPlayBytes(wav_bytes=wav_bytes),
                {"site_id": site_id, "request_id": request_id},
            )

            # Wait for finished event or WAV duration
            if block:
                wav_timeout = get_wav_duration(wav_bytes) + self.sound_timeout_extra
                _LOGGER.debug(
                    "Waiting for playFinished (id=%s, timeout=%s)",
                    request_id,
                    wav_timeout,
                )
                await asyncio.wait_for(finished_event.wait(), timeout=wav_timeout)
        except asyncio.TimeoutError:
            _LOGGER.warning(
                "Did not receive playFinished before timeout (id=%s)", request_id
            )
        except Exception:
            _LOGGER.exception("maybe_play_sound")
        finally:
            # Pause before re-enabling input at the site
            await asyncio.sleep(self.toggle_delay)

            # Re-enable ASR/hotword at site
            yield HotwordToggleOn(site_id=site_id, reason=HotwordToggleReason.PLAY_AUDIO)
            yield AsrToggleOn(site_id=site_id, reason=AsrToggleReason.PLAY_AUDIO)
예제 #10
0
    async def handle_text_captured(
        self, text_captured: AsrTextCaptured
    ) -> typing.AsyncIterable[
        typing.Union[
            AsrStopListening, HotwordToggleOn, NluQuery, NluIntentNotRecognized
        ]
    ]:
        """Route ASR-captured text to NLU, or reject it when confidence is too low.

        The captured text is stored on its session. ASR is stopped and the
        hotword re-enabled at the captured message's site.

        If ``self.min_asr_confidence`` is set and the transcription likelihood
        is below it, NluIntentNotRecognized is yielded without querying NLU.
        Otherwise an NluQuery is yielded; it includes custom entities from the
        session's hotword detection, if any. Messages with no session id or an
        unknown session are dropped with a warning. Any exception is logged,
        not raised.
        """
        try:
            session_id = text_captured.session_id
            if not session_id:
                _LOGGER.warning("Missing session id on text captured message.")
                return

            site_session = self.all_sessions.get(session_id)
            if site_session is None:
                _LOGGER.warning(
                    "No session for id %s. Dropping captured text from ASR.",
                    session_id,
                )
                return

            _LOGGER.debug("Received text: %s", text_captured.text)
            site_session.text_captured = text_captured

            yield AsrStopListening(
                site_id=text_captured.site_id, session_id=site_session.session_id
            )
            yield HotwordToggleOn(
                site_id=text_captured.site_id,
                reason=HotwordToggleReason.DIALOGUE_SESSION,
            )

            threshold = self.min_asr_confidence
            confidence = text_captured.likelihood
            if (threshold is not None) and (confidence < threshold):
                # Transcription is below threshold: skip NLU and reject the
                # input as "not recognized".
                _LOGGER.debug(
                    "Transcription is below confidence threshold (%s < %s): %s",
                    confidence,
                    threshold,
                    text_captured.text,
                )
                yield NluIntentNotRecognized(
                    input=text_captured.text,
                    site_id=site_session.site_id,
                    session_id=site_session.session_id,
                )
                return

            # Carry custom entities over from the hotword detection, if any
            detected = site_session.detected
            custom_entities: typing.Optional[typing.Dict[str, typing.Any]] = (
                detected.custom_entities if detected else None
            )

            yield NluQuery(
                input=text_captured.text,
                intent_filter=site_session.intent_filter or self.default_intent_filter,
                session_id=site_session.session_id,
                site_id=site_session.site_id,
                wakeword_id=text_captured.wakeword_id or site_session.wakeword_id,
                lang=text_captured.lang or site_session.lang,
                custom_data=site_session.custom_data,
                asr_confidence=confidence,
                custom_entities=custom_entities,
            )
        except Exception:
            _LOGGER.exception("handle_text_captured")
예제 #11
0
    async def maybe_play_sound(
        self,
        sound_name: str,
        site_id: typing.Optional[str] = None,
        request_id: typing.Optional[str] = None,
        block: bool = True,
    ) -> typing.AsyncIterable[SoundsType]:
        """Play a configured WAV sound through audio out at a site, if it exists.

        Hotword and ASR are toggled off at the site before playback and back
        on afterwards.

        Args:
            sound_name: key into ``self.sound_paths``.
            site_id: target site; falls back to ``self.site_id`` when empty.
            request_id: id for the AudioPlayBytes request; a UUID4 string is
                generated when omitted.
            block: when true, wait for the AudioPlayFinished event registered
                under the request id, for at most the WAV duration plus
                ``self.sound_timeout_extra``.

        Yields:
            HotwordToggleOff, AsrToggleOff, then a tuple of AudioPlayBytes and a
            dict with ``site_id``/``request_id``, then HotwordToggleOn and
            AsrToggleOn. Nothing is yielded when no path is configured for
            ``sound_name`` or the path is not a file. Timeouts and other errors
            while waiting are logged, not raised.
        """
        site_id = site_id or self.site_id
        wav_path = self.sound_paths.get(sound_name)
        if not wav_path:
            return

        if not wav_path.is_file():
            _LOGGER.error("WAV does not exist: %s", str(wav_path))
            return

        _LOGGER.debug("Playing WAV %s", str(wav_path))
        wav_bytes = wav_path.read_bytes()

        # Register the completion event under the request id before sending audio
        request_id = request_id or str(uuid4())
        finished_event = asyncio.Event()
        self.message_events[AudioPlayFinished][request_id] = finished_event

        # Disable ASR/hotword at site
        yield HotwordToggleOff(site_id=site_id,
                               reason=HotwordToggleReason.PLAY_AUDIO)
        yield AsrToggleOff(site_id=site_id,
                           reason=AsrToggleReason.PLAY_AUDIO)

        # Wait for messages to be delivered
        await asyncio.sleep(self.toggle_delay)

        try:
            yield (
                AudioPlayBytes(wav_bytes=wav_bytes),
                {
                    "site_id": site_id,
                    "request_id": request_id
                },
            )

            # Wait for finished event or WAV duration
            if block:
                wav_duration = get_wav_duration(wav_bytes)
                wav_timeout = wav_duration + self.sound_timeout_extra
                _LOGGER.debug("Waiting for playFinished (id=%s, timeout=%s)",
                              request_id, wav_timeout)
                await asyncio.wait_for(finished_event.wait(),
                                       timeout=wav_timeout)
        except asyncio.TimeoutError:
            _LOGGER.warning(
                "Did not receive playFinished before timeout (id=%s)",
                request_id)
        except Exception:
            _LOGGER.exception("maybe_play_sound")
        finally:
            # Pause before re-enabling input at the site
            await asyncio.sleep(self.toggle_delay)

            # Re-enable ASR/hotword at site
            yield HotwordToggleOn(site_id=site_id,
                                  reason=HotwordToggleReason.PLAY_AUDIO)
            yield AsrToggleOn(site_id=site_id,
                              reason=AsrToggleReason.PLAY_AUDIO)
예제 #12
0
def test_hotword_toggle_on():
    """Check that HotwordToggleOn.topic() is hermes/hotword/toggleOn."""
    expected_topic = "hermes/hotword/toggleOn"
    assert HotwordToggleOn.topic() == expected_topic