예제 #1
0
    async def handle_wake(self, wakeword_id: str, detected: HotwordDetected):
        """Handle a wake word detection for this site.

        A new session is built whose id combines the site id, the wake word
        id and a random UUID, and whose start message carries the wake word
        id as custom data and cannot be enqueued.

        When no session is active, the new session is started immediately.
        Otherwise it is pushed to the front of ``self.session_queue`` and the
        active session is ended with reason ``ABORTED_BY_USER``.

        Any exception is logged and not re-raised.
        """
        try:
            _LOGGER.debug("Hotword detected: %s", wakeword_id)

            woken_session_id = f"{self.siteId}-{wakeword_id}-{uuid4()}"
            start_message = DialogueStartSession(
                siteId=self.siteId,
                customData=wakeword_id,
                init=DialogueAction(canBeEnqueued=False),
            )
            woken_session = SessionInfo(
                sessionId=woken_session_id,
                start_session=start_message,
            )

            if not self.session:
                # Nothing is active: begin the new session right away
                await self.start_session(woken_session)
                return

            # Something is active: put the new session first in line,
            # then abort the current one
            self.session_queue.appendleft(woken_session)
            await self.end_session(DialogueSessionTerminationReason.ABORTED_BY_USER)
        except Exception:
            _LOGGER.exception("handle_wake")
예제 #2
0
    async def async_test_not_recognized(self):
        """Test start/end/not recognized without a satellite.

        Publishes a single, non-enqueueable ``DialogueStartSession`` for the
        base site. It then waits up to 10 seconds for every event in
        ``self.events`` to be set, presumably by
        ``on_message_test_not_recognized``.

        The background message-handling task is always cancelled, even when
        publishing fails or the wait times out.
        """
        self.custom_data[self.base_id] = str(uuid4())

        for event_name in ["started", "ended"]:
            self.events[event_name] = asyncio.Event()

        # Wait until connected
        self.hermes.on_message = self.on_message_test_not_recognized
        await asyncio.wait_for(self.hermes.mqtt_connected_event.wait(),
                               timeout=5)

        # Start listening
        self.hermes.subscribe(DialogueSessionStarted, DialogueSessionEnded)
        message_task = asyncio.create_task(self.hermes.handle_messages_async())

        try:
            # Start a new session
            self.hermes.publish(
                DialogueStartSession(
                    init=DialogueAction(can_be_enqueued=False),
                    site_id=self.base_id,
                    custom_data=self.custom_data[self.base_id],
                ))

            # Wait for up to 10 seconds
            await asyncio.wait_for(
                asyncio.gather(*[e.wait() for e in self.events.values()]),
                timeout=10)
        finally:
            # Without this, a timeout (or a failed publish) would leave the
            # message loop task running after the test has finished.
            message_task.cancel()
예제 #3
0
    def notify(self, text: str, site_id: str = "default"):
        """Speak a one-shot notification on a site.

        Publishes a ``DialogueStartSession`` whose ``init`` is a
        ``DialogueNotification``. Use it to tell the user something when no
        response is expected.

        Arguments:
            text: The text to say.
            site_id: The ID of the site where the text should be said.
        """
        self.publish(
            DialogueStartSession(
                init=DialogueNotification(text),
                site_id=site_id,
            )
        )
예제 #4
0
    async def start_checklist(self, start_message: StartChecklist):
        """Begin a checklist and yield the session start for its first item.

        Setup steps:
        - Store *start_message*.
        - Prepare a ``ChecklistFinished`` message with status ``UNKNOWN``.
        - Reset ``self.session_id``.
        - Queue the checklist items.
        - Fill in each item's confirm/disconfirm/cancel intent from the
          checklist-wide value whenever the item leaves it empty.

        The first item is then popped into ``self.current_item``. A
        ``DialogueStartSession`` is yielded that speaks the item's text and
        restricts recognition to the item's non-empty intents.

        Note: items are updated in place, so the objects inside
        ``start_message.items`` are modified as well.
        """
        assert start_message.items, "No checklist items"

        self.start_message = start_message
        self.finished_message = ChecklistFinished(
            id=start_message.id,
            status=ChecklistFinishStatus.UNKNOWN,
            site_id=start_message.site_id,
        )
        self.session_id = ""
        self.checklist_items = deque(start_message.items)

        defaults = self.start_message
        for pending in self.checklist_items:
            # Per-item intents take priority over the checklist-wide ones
            pending.confirm_intent = (
                pending.confirm_intent or defaults.confirm_intent
            )
            pending.disconfirm_intent = (
                pending.disconfirm_intent or defaults.disconfirm_intent
            )
            pending.cancel_intent = (
                pending.cancel_intent or defaults.cancel_intent
            )

        _LOGGER.debug(self.checklist_items)

        # First item
        first_item = self.checklist_items.popleft()
        self.current_item = first_item

        # Keep only the intents that are actually set
        intent_filter = list(
            filter(
                None,
                (
                    first_item.confirm_intent,
                    first_item.disconfirm_intent,
                    first_item.cancel_intent,
                ),
            )
        )

        assert intent_filter, "Need confirm/disconfirm/cancel intent"

        opening_action = DialogueAction(
            can_be_enqueued=True,
            text=first_item.text,
            intent_filter=intent_filter,
            send_intent_not_recognized=True,
        )
        yield DialogueStartSession(
            init=opening_action,
            custom_data=self.start_message.id,
            site_id=start_message.site_id,
        )
예제 #5
0
 def on_connect(self, client, userdata, flags, rc):
     """Subscribe to every topic this service listens on.

     Called when connected to the MQTT broker. It subscribes to the dialogue
     start/continue/end topics, TTS "say finished", all NLU intents, "intent
     not recognized", ASR "text captured", and each configured wake word
     topic. Each subscription is logged at debug level. Errors are logged
     and not re-raised.
     """
     try:
         service_topics = [
             DialogueStartSession.topic(),
             DialogueContinueSession.topic(),
             DialogueEndSession.topic(),
             TtsSayFinished.topic(),
             NluIntent.topic(intent_name="#"),
             NluIntentNotRecognized.topic(),
             AsrTextCaptured.topic(),
         ]
         wake_topics = list(self.wakeword_topics.keys())

         for subscribe_topic in service_topics + wake_topics:
             self.client.subscribe(subscribe_topic)
             _LOGGER.debug("Subscribed to %s", subscribe_topic)
     except Exception:
         _LOGGER.exception("on_connect")
예제 #6
0
    async def handle_wake(
        self, wakeword_id: str, detected: HotwordDetected
    ) -> typing.AsyncIterable[typing.Union[EndSessionType, StartSessionType,
                                           SayType, SoundsType]]:
        """Turn a wake word detection into a new dialogue session.

        The session id is ``detected.session_id`` when present. Otherwise it
        is generated from the site id, the wake word id and a random UUID.

        The "wake" sound messages are yielded first, before ASR starts
        listening. Then:
        - If no session is active, the new session's start messages are
          yielded.
        - Otherwise the new session jumps to the front of
          ``self.session_queue`` and the active session is aborted
          (``ABORTED_BY_USER``), yielding the end messages.

        Any exception is logged and yielded as a ``DialogueError``.
        """
        try:
            woken_site = detected.site_id
            session_id = (
                detected.session_id
                or f"{woken_site}-{wakeword_id}-{uuid4()}"
            )

            start_message = DialogueStartSession(
                site_id=woken_site,
                custom_data=wakeword_id,
                init=DialogueAction(can_be_enqueued=False),
            )
            new_session = SessionInfo(
                session_id=session_id,
                site_id=woken_site,
                start_session=start_message,
                detected=detected,
                wakeword_id=wakeword_id,
                lang=detected.lang,
            )

            # Wake sound is played before ASR starts listening
            async for sound_message in self.maybe_play_sound(
                    "wake", site_id=woken_site):
                yield sound_message

            if not self.session:
                # Nothing active: begin the new session now
                async for started_message in self.start_session(new_session):
                    yield started_message
                return

            # Something active: queue the new session first, then abort
            # the current one
            self.session_queue.appendleft(new_session)
            async for ended_message in self.end_session(
                    DialogueSessionTerminationReason.ABORTED_BY_USER,
                    site_id=self.session.site_id,
            ):
                yield ended_message
        except Exception as e:
            _LOGGER.exception("handle_wake")
            yield DialogueError(error=str(e),
                                context=str(detected),
                                site_id=detected.site_id)
예제 #7
0
    async def async_test_multi_session(self):
        """Test simultaneous sessions on the base site and on every satellite.

        For each site (base plus all satellites) this creates fresh custom
        data and a pair of ``<site_id>_started`` / ``<site_id>_ended``
        events. It then publishes one non-enqueueable
        ``DialogueStartSession`` per site and waits up to 10 seconds for
        every event in ``self.events`` to be set.

        The background message-handling task is always cancelled, even when
        publishing fails or the wait times out.
        """
        # NOTE(review): presumably read by on_message_test_multi_session to
        # pick the satellite whose session is continued — confirm.
        self.continue_site_id = self.satellite_ids[0]
        all_site_ids = [self.base_id] + self.satellite_ids

        for site_id in all_site_ids:
            self.custom_data[site_id] = str(uuid4())

            for event_name in ["started", "ended"]:
                self.events[f"{site_id}_{event_name}"] = asyncio.Event()

        # Wait until connected
        self.hermes.on_message = self.on_message_test_multi_session
        await asyncio.wait_for(self.hermes.mqtt_connected_event.wait(),
                               timeout=5)

        # Start listening
        self.hermes.subscribe(
            DialogueSessionStarted,
            DialogueSessionEnded,
            AudioPlayBytes,
            AsrStartListening,
        )

        message_task = asyncio.create_task(self.hermes.handle_messages_async())

        try:
            # Start a new session on the base and all satellites
            for site_id in all_site_ids:
                self.hermes.publish(
                    DialogueStartSession(
                        init=DialogueAction(can_be_enqueued=False),
                        site_id=site_id,
                        custom_data=self.custom_data[site_id],
                    ))

            # Wait for up to 10 seconds
            await asyncio.wait_for(
                asyncio.gather(*[e.wait() for e in self.events.values()]),
                timeout=10)
        finally:
            # Without this, a timeout (or a failed publish) would leave the
            # message loop task running after the test has finished.
            message_task.cancel()
예제 #8
0
    async def handle_wake(
        self, wakeword_id: str, detected: HotwordDetected
    ) -> typing.AsyncIterable[
        typing.Union[EndSessionType, StartSessionType, SayType, SoundsType]
    ]:
        """Wake word was detected.

        Creates a new session for ``detected.site_id``. The session id is
        ``detected.session_id`` or a generated id. The "wake" sound messages
        are yielded first. Then:
        - If the site already has a session, the new one is put at the front
          of that site's queue, and the existing session is aborted
          (``ABORTED_BY_USER``) with ``start_next_session=True``.
        - Otherwise the new session is started.

        Group handling: when ``self.group_separator`` is set and the site id
        looks like ``<GROUP><separator><NAME>``, a per-group
        ``asyncio.Lock`` is acquired first and released when this generator
        finishes. While another session of the same group has not captured
        text yet, the wake-up is ignored.

        Any exception is logged and yielded as a ``DialogueError``.
        """
        # Set only once the group lock has actually been acquired, so the
        # ``finally`` clause never releases a lock this call does not hold.
        group_lock: typing.Optional[asyncio.Lock] = None

        try:
            group_id = ""

            if self.group_separator:
                # Split site_id into <GROUP>[separator]<NAME>
                site_id_parts = detected.site_id.split(self.group_separator, maxsplit=1)
                if len(site_id_parts) > 1:
                    group_id = site_id_parts[0]

            if group_id:
                # Use a lock per group id to prevent multiple satellites from
                # starting sessions while the wake up sound is being played.
                async with self.global_wake_lock:
                    candidate_lock = self.group_wake_lock.get(group_id)
                    if candidate_lock is None:
                        # Create new lock for group
                        candidate_lock = asyncio.Lock()
                        self.group_wake_lock[group_id] = candidate_lock

                # asyncio.Lock has no owner. If this acquire() is cancelled and
                # the lock were still released in ``finally``, it would unlock
                # the lock held by another wake-up, or raise RuntimeError and
                # mask the cancellation if the lock is free. So only record it
                # after acquire() has returned. There is no await in between,
                # so nothing can interrupt the assignment.
                await candidate_lock.acquire()
                group_lock = candidate_lock

                # Check if a session from the same group is already active.
                # If so, ignore this wake up.
                for session in self.all_sessions.values():
                    # Also check if text has already been captured for this session.
                    # This prevents a new session for a group from being blocked
                    # because a previous (completed) one has not timed out yet.
                    if (session.group_id == group_id) and (
                        session.text_captured is None
                    ):
                        _LOGGER.debug(
                            "Group %s already has a session (%s). Ignoring wake word detection from %s.",
                            group_id,
                            session.site_id,
                            detected.site_id,
                        )
                        return

            # Create new session
            session_id = (
                detected.session_id or f"{detected.site_id}-{wakeword_id}-{uuid4()}"
            )
            new_session = SessionInfo(
                session_id=session_id,
                site_id=detected.site_id,
                start_session=DialogueStartSession(
                    site_id=detected.site_id,
                    custom_data=wakeword_id,
                    init=DialogueAction(can_be_enqueued=False),
                ),
                detected=detected,
                wakeword_id=wakeword_id,
                lang=detected.lang,
                group_id=group_id,
            )

            # Play wake sound before ASR starts listening
            async for play_wake_result in self.maybe_play_sound(
                "wake", site_id=detected.site_id
            ):
                yield play_wake_result

            site_session = self.session_by_site.get(detected.site_id)
            if site_session:
                # Jump the queue
                self.session_queue_by_site[site_session.site_id].appendleft(new_session)

                # Abort previous session and start queued session
                async for end_result in self.end_session(
                    DialogueSessionTerminationReason.ABORTED_BY_USER,
                    site_id=site_session.site_id,
                    session_id=site_session.session_id,
                    start_next_session=True,
                ):
                    yield end_result
            else:
                # Start new session
                async for start_result in self.start_session(new_session):
                    yield start_result
        except Exception as e:
            _LOGGER.exception("handle_wake")
            yield DialogueError(
                error=str(e), context=str(detected), site_id=detected.site_id
            )
        finally:
            if group_lock is not None:
                group_lock.release()
예제 #9
0
    async def handle_wake(
        self, wakeword_id: str, detected: HotwordDetected
    ) -> typing.AsyncIterable[typing.Union[EndSessionType, StartSessionType,
                                           SayType, SoundsType]]:
        """Turn a wake word detection into a new dialogue session.

        Group check: when ``self.group_separator`` is set and the site id
        looks like ``<GROUP><separator><NAME>``, the wake-up is ignored
        (logged at debug level) if any session in ``self.all_sessions``
        already belongs to that group.

        Otherwise a session is created. Its id is ``detected.session_id`` or
        a generated id. The "wake" sound messages are yielded first. Then:
        - If the site already has a session, the new one jumps to the front
          of that site's queue, and the existing session is aborted
          (``ABORTED_BY_USER``) with ``start_next_session=True``.
        - Otherwise the new session is started.

        Any exception is logged and yielded as a ``DialogueError``.
        """
        try:
            group_id = ""
            if self.group_separator:
                # site_id is <GROUP>[separator]<NAME>; keep only <GROUP>
                head_and_rest = detected.site_id.split(self.group_separator,
                                                       maxsplit=1)
                if len(head_and_rest) > 1:
                    group_id = head_and_rest[0]

            if group_id:
                # One active session per group: ignore this wake-up if the
                # group is already busy
                busy_session = next(
                    (
                        other
                        for other in self.all_sessions.values()
                        if other.group_id == group_id
                    ),
                    None,
                )
                if busy_session is not None:
                    _LOGGER.debug(
                        "Group %s already has a session (%s). Ignoring wake word detection from %s.",
                        group_id,
                        busy_session.site_id,
                        detected.site_id,
                    )
                    return

            session_id = (detected.session_id
                          or f"{detected.site_id}-{wakeword_id}-{uuid4()}")
            start_message = DialogueStartSession(
                site_id=detected.site_id,
                custom_data=wakeword_id,
                init=DialogueAction(can_be_enqueued=False),
            )
            new_session = SessionInfo(
                session_id=session_id,
                site_id=detected.site_id,
                start_session=start_message,
                detected=detected,
                wakeword_id=wakeword_id,
                lang=detected.lang,
                group_id=group_id,
            )

            # Wake sound is played before ASR starts listening
            async for sound_message in self.maybe_play_sound(
                    "wake", site_id=detected.site_id):
                yield sound_message

            site_session = self.session_by_site.get(detected.site_id)
            if not site_session:
                # Site is idle: begin the new session now
                async for started_message in self.start_session(new_session):
                    yield started_message
                return

            # Site is busy: queue the new session first, then abort the
            # current one so the queued session starts next
            self.session_queue_by_site[site_session.site_id].appendleft(
                new_session)
            async for ended_message in self.end_session(
                    DialogueSessionTerminationReason.ABORTED_BY_USER,
                    site_id=site_session.site_id,
                    session_id=site_session.session_id,
                    start_next_session=True,
            ):
                yield ended_message
        except Exception as e:
            _LOGGER.exception("handle_wake")
            yield DialogueError(error=str(e),
                                context=str(detected),
                                site_id=detected.site_id)
예제 #10
0
    def on_message(self, client, userdata, msg):
        """Route a message received from the MQTT broker to its handler.

        The payload is decoded as JSON only for topics this service handles.

        Filtering:
        - Session start/continue/end messages and wake word detections must
          pass ``_check_siteId``.
        - TTS "say finished", ASR "text captured" and NLU results must pass
          ``_check_sessionId``.

        Dispatch:
        - Coroutine handlers are submitted to ``self.loop`` with
          ``asyncio.run_coroutine_threadsafe``.
        - ``handle_end``, ``handle_text_captured`` and ``handle_recognized``
          are called directly.

        Any exception is logged and not re-raised.
        """
        try:
            _LOGGER.debug("Received %s byte(s) on %s", len(msg.payload),
                          msg.topic)
            topic = msg.topic

            if topic == DialogueStartSession.topic():
                # Start session, scheduled on the event loop (for TTS)
                payload = json.loads(msg.payload)
                if self._check_siteId(payload):
                    asyncio.run_coroutine_threadsafe(
                        self.handle_start(DialogueStartSession(**payload)),
                        self.loop,
                    )
            elif topic == DialogueContinueSession.topic():
                # Continue session, scheduled on the event loop (for TTS)
                payload = json.loads(msg.payload)
                if self._check_siteId(payload):
                    asyncio.run_coroutine_threadsafe(
                        self.handle_continue(DialogueContinueSession(**payload)),
                        self.loop,
                    )
            elif topic == DialogueEndSession.topic():
                # End session, handled directly (not on the event loop)
                payload = json.loads(msg.payload)
                if self._check_siteId(payload):
                    self.handle_end(DialogueEndSession(**payload))
            elif topic == TtsSayFinished.topic():
                # TTS finished: set say_finished_event from the loop's thread
                payload = json.loads(msg.payload)
                if self._check_sessionId(payload):
                    self.loop.call_soon_threadsafe(self.say_finished_event.set)
            elif topic == AsrTextCaptured.topic():
                # Text captured, handled directly (not on the event loop)
                payload = json.loads(msg.payload)
                if self._check_sessionId(payload):
                    self.handle_text_captured(AsrTextCaptured(**payload))
            elif topic.startswith(NluIntent.topic(intent_name="")):
                # Intent recognized, handled directly
                payload = json.loads(msg.payload)
                if self._check_sessionId(payload):
                    self.handle_recognized(NluIntent(**payload))
            elif topic.startswith(NluIntentNotRecognized.topic()):
                # Intent NOT recognized, scheduled on the event loop (for TTS)
                payload = json.loads(msg.payload)
                if self._check_sessionId(payload):
                    asyncio.run_coroutine_threadsafe(
                        self.handle_not_recognized(
                            NluIntentNotRecognized(**payload)),
                        self.loop,
                    )
            elif topic in self.wakeword_topics:
                # Wake word detected, scheduled on the event loop
                payload = json.loads(msg.payload)
                if self._check_siteId(payload):
                    wakeword_id = self.wakeword_topics[topic]
                    asyncio.run_coroutine_threadsafe(
                        self.handle_wake(wakeword_id,
                                         HotwordDetected(**payload)),
                        self.loop,
                    )
        except Exception:
            _LOGGER.exception("on_message")
예제 #11
0
def test_dialogue_start_session():
    """DialogueStartSession uses the dialogue manager's startSession topic."""
    expected_topic = "hermes/dialogueManager/startSession"
    assert DialogueStartSession.topic() == expected_topic