示例#1
0
 def wrapped(intent: NluIntent):
     """Call the wrapped intent handler and act on the session reply it returns.

     The reply of ``function(intent)`` decides what is sent through
     ``self.publish`` (both names come from the enclosing scope):

     * ``EndSession``      -> ``DialogueEndSession`` for the intent's session/site.
     * ``ContinueSession`` -> ``DialogueContinueSession`` for the intent's
       session/site, carrying the reply's text, filter and flags.
     * anything else       -> nothing is published.

     An intent without a session ID cannot be ended or continued; in that
     case an error is logged and nothing is published.
     """
     reply = function(intent)

     if isinstance(reply, EndSession):
         if intent.session_id is None:
             _LOGGER.error(
                 "Cannot end session of intent without session ID.")
             return

         end_session = DialogueEndSession(
             session_id=intent.session_id,
             site_id=intent.site_id,
             text=reply.text,
             custom_data=reply.custom_data,
         )
         self.publish(end_session)
     elif isinstance(reply, ContinueSession):
         if intent.session_id is None:
             _LOGGER.error(
                 "Cannot continue session of intent without session ID.")
             return

         continue_session = DialogueContinueSession(
             session_id=intent.session_id,
             site_id=intent.site_id,
             text=reply.text,
             intent_filter=reply.intent_filter,
             custom_data=reply.custom_data,
             send_intent_not_recognized=reply.send_intent_not_recognized,
         )
         self.publish(continue_session)
示例#2
0
 async def wrapped(inr: DialogueIntentNotRecognized) -> None:
     """Await the wrapped not-recognized handler and act on its session reply.

     The awaited reply of ``function(inr)`` selects what is sent through
     ``self.publish`` (both names come from the enclosing scope):
     an ``EndSession`` becomes a ``DialogueEndSession`` and a
     ``ContinueSession`` becomes a ``DialogueContinueSession``, each addressed
     to the message's session and site.  Any other reply publishes nothing.
     Without a session ID the session cannot be addressed, so an error is
     logged instead.
     """
     outcome = await function(inr)

     if isinstance(outcome, EndSession):
         if inr.session_id is None:
             _LOGGER.error(
                 "Cannot end session of dialogue intent not recognized message without session ID."
             )
             return

         end_session = DialogueEndSession(
             session_id=inr.session_id,
             site_id=inr.site_id,
             text=outcome.text,
             custom_data=outcome.custom_data,
         )
         self.publish(end_session)
     elif isinstance(outcome, ContinueSession):
         if inr.session_id is None:
             _LOGGER.error(
                 "Cannot continue session of dialogue intent not recognized message without session ID."
             )
             return

         continue_session = DialogueContinueSession(
             session_id=inr.session_id,
             site_id=inr.site_id,
             text=outcome.text,
             intent_filter=outcome.intent_filter,
             custom_data=outcome.custom_data,
             send_intent_not_recognized=outcome.send_intent_not_recognized,
         )
         self.publish(continue_session)
示例#3
0
 def on_connect(self, client, userdata, flags, rc):
     """Connected to MQTT broker.

     Subscribes through ``self.client`` to the fixed dialogue/TTS/NLU/ASR
     topics, then to every key of ``self.wakeword_topics``, logging each
     subscription at debug level.  Any exception is logged, not raised.
     """
     try:
         fixed_topics = (
             DialogueStartSession.topic(),
             DialogueContinueSession.topic(),
             DialogueEndSession.topic(),
             TtsSayFinished.topic(),
             # "#" is the MQTT multi-level wildcard: every intent name.
             NluIntent.topic(intent_name="#"),
             NluIntentNotRecognized.topic(),
             AsrTextCaptured.topic(),
         )
         all_topics = (*fixed_topics, *self.wakeword_topics.keys())

         for subscription in all_topics:
             self.client.subscribe(subscription)
             _LOGGER.debug("Subscribed to %s", subscription)
     except Exception:
         _LOGGER.exception("on_connect")
示例#4
0
    async def repeat_item(self):
        """Continues session with current item.

        Async generator that yields a single ``DialogueContinueSession`` for
        ``self.session_id``, re-speaking ``self.current_item.text``.  The
        session is restricted to whichever of the item's confirm, disconfirm
        and cancel intents are set (truthy), and asks to be told when no
        intent is recognized.

        Raises AssertionError (only while assertions are enabled) when there
        is no current item or when the item defines none of those intents.
        """
        assert self.current_item, "No current item"
        item = self.current_item

        # Keep only the intents the item actually defines, in this order.
        candidate_intents = (
            item.confirm_intent,
            item.disconfirm_intent,
            item.cancel_intent,
        )
        intent_filter = list(filter(None, candidate_intents))

        assert intent_filter, "Need confirm/disconfirm/cancel intent"

        yield DialogueContinueSession(
            session_id=self.session_id,
            text=item.text,
            intent_filter=intent_filter,
            send_intent_not_recognized=True,
        )
示例#5
0
    def on_message(self, client, userdata, msg):
        """Received message from MQTT broker.

        Dispatches on ``msg.topic``.  For every handled topic the payload is
        parsed as JSON and passed to ``self._check_siteId`` (session-control
        and wake word messages) or ``self._check_sessionId`` (TTS, ASR and
        NLU results); the message is dropped when that check is falsy.
        Otherwise the payload is turned into its message class and handed to
        the matching ``handle_*`` method, either directly on the calling
        thread or scheduled onto ``self.loop``.  Other topics are ignored.
        Any exception raised while dispatching is logged, not propagated.
        """

        def log_handler_failure(future):
            """Log the exception of a completed scheduled handler, if any."""
            if future.cancelled():
                return

            error = future.exception()
            if error is not None:
                _LOGGER.error("on_message handler failed (topic=%s)",
                              msg.topic,
                              exc_info=error)

        def run_in_loop(coro):
            """Schedule *coro* on ``self.loop`` and report any failure."""
            # Nothing else keeps the future returned here, so without this
            # callback an exception raised inside the handler coroutine would
            # never be reported anywhere.
            future = asyncio.run_coroutine_threadsafe(coro, self.loop)
            future.add_done_callback(log_handler_failure)

        try:
            _LOGGER.debug("Received %s byte(s) on %s", len(msg.payload),
                          msg.topic)
            if msg.topic == DialogueStartSession.topic():
                # Start session
                json_payload = json.loads(msg.payload)
                if not self._check_siteId(json_payload):
                    return

                # Run in event loop (for TTS)
                run_in_loop(
                    self.handle_start(DialogueStartSession(**json_payload)))
            elif msg.topic == DialogueContinueSession.topic():
                # Continue session
                json_payload = json.loads(msg.payload)
                if not self._check_siteId(json_payload):
                    return

                # Run in event loop (for TTS)
                run_in_loop(
                    self.handle_continue(
                        DialogueContinueSession(**json_payload)))
            elif msg.topic == DialogueEndSession.topic():
                # End session
                json_payload = json.loads(msg.payload)
                if not self._check_siteId(json_payload):
                    return

                # Run outside event loop
                self.handle_end(DialogueEndSession(**json_payload))
            elif msg.topic == TtsSayFinished.topic():
                # TTS finished
                json_payload = json.loads(msg.payload)
                if not self._check_sessionId(json_payload):
                    return

                # Signal event loop (the event must be set on the loop's
                # thread, hence call_soon_threadsafe)
                self.loop.call_soon_threadsafe(self.say_finished_event.set)
            elif msg.topic == AsrTextCaptured.topic():
                # Text captured
                json_payload = json.loads(msg.payload)
                if not self._check_sessionId(json_payload):
                    return

                # Run outside event loop
                self.handle_text_captured(AsrTextCaptured(**json_payload))
            elif msg.topic.startswith(NluIntent.topic(intent_name="")):
                # Intent recognized (topic is prefix + intent name)
                json_payload = json.loads(msg.payload)
                if not self._check_sessionId(json_payload):
                    return

                self.handle_recognized(NluIntent(**json_payload))
            elif msg.topic.startswith(NluIntentNotRecognized.topic()):
                # Intent not recognized
                json_payload = json.loads(msg.payload)
                if not self._check_sessionId(json_payload):
                    return

                # Run in event loop (for TTS)
                run_in_loop(
                    self.handle_not_recognized(
                        NluIntentNotRecognized(**json_payload)))
            elif msg.topic in self.wakeword_topics:
                # Wake word detected; the topic maps to its wake word id
                json_payload = json.loads(msg.payload)
                if not self._check_siteId(json_payload):
                    return

                wakeword_id = self.wakeword_topics[msg.topic]
                run_in_loop(
                    self.handle_wake(wakeword_id,
                                     HotwordDetected(**json_payload)))
        except Exception:
            _LOGGER.exception("on_message")
示例#6
0
    async def on_message_test_multi_session(
        self,
        message: Message,
        site_id: typing.Optional[str] = None,
        session_id: typing.Optional[str] = None,
        topic: typing.Optional[str] = None,
    ):
        """Receive messages for test_multi_session.

        Async generator called once per incoming message.  It checks the
        dialogue manager's session lifecycle on the base and on each
        satellite site, and yields replies to be sent back.  A reply is
        either a message, or a ``(message, kwargs)`` tuple; the tuple form is
        presumably extra publish arguments (here ``site_id``) — confirm
        against the test harness.  It always finishes by yielding ``None``.

        Per site the expected flow is:

        * ``DialogueSessionStarted``: check the site and custom data, record
          the session ID, then either continue the session once (for
          ``self.continue_site_id``) or abort it with an intent-not-recognized.
        * ``AsrStartListening`` on the continue site after the continue step:
          abort it with an intent-not-recognized.
        * ``DialogueSessionEnded``: check site, custom data, the
          INTENT_NOT_RECOGNIZED termination reason and the recorded session ID.
        """
        # Audio payloads are skipped here to keep the debug log readable.
        if not isinstance(message, AudioPlayBytes):
            _LOGGER.debug(message)

        if isinstance(message, DialogueSessionStarted):
            # Verify session was started on the base or a satellite
            self.assertIn(message.site_id, [self.base_id] + self.satellite_ids)
            self.assertEqual(message.custom_data,
                             self.custom_data[message.site_id])

            self.events[f"{message.site_id}_started"].set()

            # Remembered so the ended-session check can match the session ID.
            self.session_ids[message.site_id] = message.session_id

            if message.site_id == self.continue_site_id:
                # Make session continue one more step
                # The new custom data is what the DialogueSessionEnded branch
                # will expect; it also marks that the continue step was sent.
                self.custom_data[message.site_id] = "done"

                yield DialogueContinueSession(session_id=message.session_id,
                                              custom_data="done")
            else:
                # Fresh custom data that the ended session must echo back.
                self.custom_data[message.site_id] = str(uuid4())

                # Publish an intent not recognized message to abort the session
                yield NluIntentNotRecognized(
                    input=f"test intent (site={message.site_id})",
                    site_id=message.site_id,
                    session_id=message.session_id,
                    custom_data=self.custom_data[message.site_id],
                )
        elif isinstance(message, DialogueSessionEnded):

            # Verify session was aborted on the base or a satellite
            self.assertIn(message.site_id, [self.base_id] + self.satellite_ids)
            self.assertEqual(message.custom_data,
                             self.custom_data[message.site_id])

            # Set before the remaining assertions run.
            self.events[f"{message.site_id}_ended"].set()

            self.assertEqual(
                message.termination.reason,
                DialogueSessionTerminationReason.INTENT_NOT_RECOGNIZED,
            )

            self.assertEqual(self.session_ids[message.site_id],
                             message.session_id)
        elif isinstance(message, AsrStartListening):
            # Follow on from continue session
            # NOTE(review): no custom_data is sent here, yet the ended check
            # expects "done" — presumably the dialogue manager keeps the
            # custom data from the continue-session request; confirm.
            if (message.site_id == self.continue_site_id) and (
                    self.custom_data[message.site_id] == "done"):
                # Publish an intent not recognized message to abort the session
                yield NluIntentNotRecognized(
                    input=f"test intent (site={message.site_id})",
                    site_id=message.site_id,
                    session_id=message.session_id,
                )
        elif isinstance(message, AudioPlayBytes):
            # Acknowledge the audio right away, presumably so the session can
            # proceed without real playback.  ``session_id`` here is the
            # handler argument, not message.session_id; presumably the
            # harness fills it from the topic (the request id) — confirm.
            yield (AudioPlayFinished(id=session_id), {"site_id": site_id})

        # Always yield a final None (apparently "no further reply").
        yield None
示例#7
0
def test_dialogue_continue_session():
    """DialogueContinueSession uses the dialogue manager's continueSession topic."""
    expected_topic = "hermes/dialogueManager/continueSession"
    actual_topic = DialogueContinueSession.topic()
    assert actual_topic == expected_topic