Ejemplo n.º 1
0
    async def async_test_ws_text(self):
        """Test the api/events/text websocket endpoint.

        Starts a websocket receiver on ``events/text`` and publishes an
        ``AsrTextCaptured`` message over MQTT. It then checks that the
        received websocket event carries the same text, site id, and
        wakeword id. The receiver task is always cancelled, even when a
        wait times out or an assertion fails.
        """
        # Start listening
        event_queue = asyncio.Queue()
        connected = asyncio.Event()
        receive_task = asyncio.ensure_future(
            self.async_ws_receive("events/text", event_queue, connected))

        try:
            await asyncio.wait_for(connected.wait(), timeout=5)

            # Send in a message
            text_captured = AsrTextCaptured(
                text="this is a test",
                likelihood=1,
                seconds=0,
                site_id=self.site_id,
                session_id=self.session_id,
                wakeword_id=str(uuid4()),
            )

            self.client.publish(text_captured.topic(), text_captured.payload())

            # Wait for response
            event = json.loads(await asyncio.wait_for(event_queue.get(),
                                                      timeout=5))
            self.assertEqual(text_captured.text, event.get("text", ""))
            self.assertEqual(text_captured.site_id, event.get("siteId", ""))
            self.assertEqual(text_captured.wakeword_id,
                             event.get("wakewordId", ""))
        finally:
            # Stop listening even if a timeout or failed assertion above
            # raised, so the receiver task does not outlive this test.
            receive_task.cancel()
Ejemplo n.º 2
0
    async def transcribe_wav(self,
                             wav_bytes: bytes,
                             frames_per_chunk: int = 1024) -> AsrTextCaptured:
        """Transcribe WAV data by streaming it through the ASR topics.

        Under a freshly generated session id, this publishes, in order:
        - an AsrStartListening message;
        - the audio, re-encoded as a series of small WAV files, each wrapped
          in an AudioFrame;
        - an AsrStopListening message.

        It then waits for the AsrTextCaptured message whose sessionId
        matches.

        Args:
            wav_bytes: complete contents of a WAV file.
            frames_per_chunk: number of audio frames per AudioFrame chunk;
                must be positive.

        Returns:
            The first result yielded by ``self.publish_wait`` for this
            session, or None if ``publish_wait`` yields nothing.

        Raises:
            ValueError: if ``frames_per_chunk`` is not positive.
        """
        if frames_per_chunk <= 0:
            # With a non-positive chunk size frames_left never reaches zero
            # and the chunking loop below would spin forever.
            raise ValueError(
                f"frames_per_chunk must be positive, got {frames_per_chunk}")

        sessionId = str(uuid4())

        def handle_captured():
            # Receives (topic, message) pairs via send() and finishes with
            # (True, message) once text for this session arrives.
            # NOTE(review): presumably publish_wait treats that returned
            # tuple as the result -- confirm against publish_wait.
            while True:
                _, message = yield

                if isinstance(message, AsrTextCaptured) and (message.sessionId
                                                             == sessionId):
                    return True, message

        def messages():
            yield AsrStartListening(siteId=self.siteId, sessionId=sessionId)

            # Break WAV into self-contained WAV chunks with the same frame
            # rate, sample width and channel count as the input.
            with io.BytesIO(wav_bytes) as wav_buffer:
                with wave.open(wav_buffer, "rb") as wav_file:
                    frames_left = wav_file.getnframes()
                    while frames_left > 0:
                        with io.BytesIO() as chunk_buffer:
                            chunk_file: wave.Wave_write = wave.open(
                                chunk_buffer, "wb")
                            with chunk_file:
                                chunk_file.setframerate(
                                    wav_file.getframerate())
                                chunk_file.setsampwidth(
                                    wav_file.getsampwidth())
                                chunk_file.setnchannels(
                                    wav_file.getnchannels())
                                chunk_file.writeframes(
                                    wav_file.readframes(frames_per_chunk))

                            # Read the buffer only after chunk_file is closed,
                            # so the WAV header has been finalized.
                            yield (
                                AudioFrame(wav_data=chunk_buffer.getvalue()),
                                {
                                    "siteId": self.siteId
                                },
                            )

                        frames_left -= frames_per_chunk

            yield AsrStopListening(siteId=self.siteId, sessionId=sessionId)

        topics = [AsrTextCaptured.topic()]

        # Expecting only a single result
        async for result in self.publish_wait(handle_captured(), messages(),
                                              topics):
            return result

        # publish_wait finished without yielding anything.
        return None
Ejemplo n.º 3
0
    def on_message(self, client, userdata, msg):
        """Handle a message delivered by the MQTT broker.

        Intent, intent-not-recognized, TTS-finished and text-captured
        payloads are decoded and handed to ``handle_message``. Intent and
        not-recognized results are also pushed onto every queue in
        ``self.message_queues`` as ``("intent", message)``.

        Unless a site id check rejects the payload (which returns early),
        the raw message is then forwarded to every queue as
        ``("mqtt", topic, payload)``. Any exception is logged, not raised.
        """

        def broadcast(item):
            # Hand the item to each queue via the event loop's thread.
            for message_queue in self.message_queues:
                self.loop.call_soon_threadsafe(message_queue.put_nowait, item)

        try:
            topic = msg.topic
            _LOGGER.debug("Received %s byte(s) on %s", len(msg.payload),
                          topic)

            if NluIntent.is_topic(topic):
                # Recognized intent: report to websockets, then handle
                payload = json.loads(msg.payload)
                if not self._check_siteId(payload):
                    return

                recognized = NluIntent(**payload)
                broadcast(("intent", recognized))
                self.handle_message(topic, recognized)
            elif topic == NluIntentNotRecognized.topic():
                # Unrecognized intent: report to websockets, then handle
                payload = json.loads(msg.payload)
                if not self._check_siteId(payload):
                    return

                unrecognized = NluIntentNotRecognized(**payload)
                broadcast(("intent", unrecognized))
                self.handle_message(topic, unrecognized)
            elif topic == TtsSayFinished.topic():
                # Speech synthesis done (no site id filtering here)
                self.handle_message(topic,
                                    TtsSayFinished(**json.loads(msg.payload)))
            elif topic == AsrTextCaptured.topic():
                # Transcription result
                payload = json.loads(msg.payload)
                if not self._check_siteId(payload):
                    return

                self.handle_message(topic, AsrTextCaptured(**payload))

            # Every message that got this far also goes out verbatim
            broadcast(("mqtt", topic, msg.payload))
        except Exception:
            _LOGGER.exception("on_message")
Ejemplo n.º 4
0
 def on_connect(self, client, userdata, flags, rc):
     """Subscribe to every topic this service reacts to after connecting.

     Covers dialogue session start/continue/end, TTS completion, all NLU
     intents and not-recognized results, captured ASR text, and each key
     of ``self.wakeword_topics``. Failures are logged rather than raised.
     """
     try:
         subscriptions = [
             DialogueStartSession.topic(),
             DialogueContinueSession.topic(),
             DialogueEndSession.topic(),
             TtsSayFinished.topic(),
             NluIntent.topic(intent_name="#"),
             NluIntentNotRecognized.topic(),
             AsrTextCaptured.topic(),
             *self.wakeword_topics.keys(),
         ]
         for subscription in subscriptions:
             self.client.subscribe(subscription)
             _LOGGER.debug("Subscribed to %s", subscription)
     except Exception:
         _LOGGER.exception("on_connect")
Ejemplo n.º 5
0
def test_asr_text_captured():
    """AsrTextCaptured must publish on the Hermes ASR textCaptured topic."""
    expected_topic = "hermes/asr/textCaptured"
    assert AsrTextCaptured.topic() == expected_topic
Ejemplo n.º 6
0
    def on_message(self, client, userdata, msg):
        """Dispatch a message received from the MQTT broker.

        Filtering by message type:
        - Session start/continue/end and wake word detections are filtered
          with ``_check_siteId``.
        - TTS-finished, captured text, and NLU results (recognized or not)
          are filtered with ``_check_sessionId``.

        Handlers that need the event loop are scheduled on ``self.loop``;
        any exception they raise is logged. The remaining handlers run
        directly in the caller's thread. Errors raised here are logged,
        not propagated.
        """

        def run_in_loop(coro):
            """Schedule coro on self.loop and log any exception it raises."""
            future = asyncio.run_coroutine_threadsafe(coro, self.loop)

            def log_failure(done):
                # Nothing else awaits this future, so without this callback
                # an exception raised by the coroutine would be discarded.
                if done.cancelled():
                    return
                error = done.exception()
                if error is not None:
                    _LOGGER.error("on_message: scheduled handler failed",
                                  exc_info=error)

            future.add_done_callback(log_failure)

        try:
            _LOGGER.debug("Received %s byte(s) on %s", len(msg.payload),
                          msg.topic)
            if msg.topic == DialogueStartSession.topic():
                # Start session
                json_payload = json.loads(msg.payload)
                if not self._check_siteId(json_payload):
                    return

                # Run in event loop (for TTS)
                run_in_loop(
                    self.handle_start(DialogueStartSession(**json_payload)))
            elif msg.topic == DialogueContinueSession.topic():
                # Continue session
                json_payload = json.loads(msg.payload)
                if not self._check_siteId(json_payload):
                    return

                # Run in event loop (for TTS)
                run_in_loop(
                    self.handle_continue(
                        DialogueContinueSession(**json_payload)))
            elif msg.topic == DialogueEndSession.topic():
                # End session
                json_payload = json.loads(msg.payload)
                if not self._check_siteId(json_payload):
                    return

                # Run outside event loop
                self.handle_end(DialogueEndSession(**json_payload))
            elif msg.topic == TtsSayFinished.topic():
                # TTS finished
                json_payload = json.loads(msg.payload)
                if not self._check_sessionId(json_payload):
                    return

                # Signal event loop
                self.loop.call_soon_threadsafe(self.say_finished_event.set)
            elif msg.topic == AsrTextCaptured.topic():
                # Text captured
                json_payload = json.loads(msg.payload)
                if not self._check_sessionId(json_payload):
                    return

                # Run outside event loop
                self.handle_text_captured(AsrTextCaptured(**json_payload))
            elif msg.topic.startswith(NluIntent.topic(intent_name="")):
                # Intent recognized
                json_payload = json.loads(msg.payload)
                if not self._check_sessionId(json_payload):
                    return

                self.handle_recognized(NluIntent(**json_payload))
            elif msg.topic.startswith(NluIntentNotRecognized.topic()):
                # Intent not recognized
                json_payload = json.loads(msg.payload)
                if not self._check_sessionId(json_payload):
                    return

                # Run in event loop (for TTS)
                run_in_loop(
                    self.handle_not_recognized(
                        NluIntentNotRecognized(**json_payload)))
            elif msg.topic in self.wakeword_topics:
                # Wake word detected
                json_payload = json.loads(msg.payload)
                if not self._check_siteId(json_payload):
                    return

                wakeword_id = self.wakeword_topics[msg.topic]
                run_in_loop(
                    self.handle_wake(wakeword_id,
                                     HotwordDetected(**json_payload)))
        except Exception:
            _LOGGER.exception("on_message")