def on_message(self, client, userdata, msg):
        """Received message from MQTT broker.

        Handles ASR toggle on/off, audio frames, and start/stop listening
        messages. Any exception is logged rather than propagated so a bad
        message cannot break the MQTT client callback.

        Args:
            client: MQTT client that delivered the message (not used here).
            userdata: User data supplied by the MQTT client (not used here).
            msg: Received message; ``msg.topic`` and ``msg.payload`` are read.
        """
        try:
            if not msg.topic.endswith("/audioFrame"):
                # Audio frames arrive continuously; keep them out of the debug log
                _LOGGER.debug("Received %s byte(s) on %s", len(msg.payload), msg.topic)

            # Check enable/disable messages (an empty payload counts as "{}")
            if msg.topic == AsrToggleOn.topic():
                json_payload = json.loads(msg.payload or "{}")
                if self._check_siteId(json_payload):
                    self.enabled = True
                    _LOGGER.debug("Enabled")
            elif msg.topic == AsrToggleOff.topic():
                # Previously compared against AsrToggleOn again, which made
                # this branch unreachable and ASR impossible to disable.
                json_payload = json.loads(msg.payload or "{}")
                if self._check_siteId(json_payload):
                    self.enabled = False
                    _LOGGER.debug("Disabled")

            if not self.enabled:
                # Disabled: ignore everything except the toggle messages above
                return

            if AudioFrame.is_topic(msg.topic):
                # Check siteId: no configured topics means accept all sites
                if (not self.audioframe_topics) or (
                    msg.topic in self.audioframe_topics
                ):
                    # Add to all active sessions
                    if self.first_audio:
                        _LOGGER.debug("Receiving audio")
                        self.first_audio = False

                    siteId = AudioFrame.get_siteId(msg.topic)
                    for result in self.handle_audio_frame(msg.payload, siteId=siteId):
                        self.publish(result)

            elif msg.topic == AsrStartListening.topic():
                # hermes/asr/startListening
                json_payload = json.loads(msg.payload)
                if self._check_siteId(json_payload):
                    for result in self.start_listening(
                        AsrStartListening(**json_payload)
                    ):
                        self.publish(result)
            elif msg.topic == AsrStopListening.topic():
                # hermes/asr/stopListening
                json_payload = json.loads(msg.payload)
                if self._check_siteId(json_payload):
                    for result in self.stop_listening(AsrStopListening(**json_payload)):
                        self.publish(result)
        except Exception:
            _LOGGER.exception("on_message")
예제 #2
0
        def messages():
            """Yield start-listening, chunked audio frames, then stop-listening."""
            # Open the ASR session first
            yield AsrStartListening(siteId=self.siteId, sessionId=sessionId)

            # Re-encode the source WAV as a series of small WAV files, each
            # holding at most frames_per_chunk frames in the source's format.
            with io.BytesIO(wav_bytes) as source_buffer, wave.open(
                source_buffer, "rb"
            ) as source_wav:
                remaining = source_wav.getnframes()
                while remaining > 0:
                    with io.BytesIO() as out_buffer:
                        out_wav: wave.Wave_write = wave.open(out_buffer, "wb")
                        with out_wav:
                            out_wav.setframerate(source_wav.getframerate())
                            out_wav.setsampwidth(source_wav.getsampwidth())
                            out_wav.setnchannels(source_wav.getnchannels())
                            out_wav.writeframes(
                                source_wav.readframes(frames_per_chunk)
                            )

                        # Read the buffer only after out_wav is closed so the
                        # WAV header has been finalized.
                        yield (
                            AudioFrame(wav_data=out_buffer.getvalue()),
                            {"siteId": self.siteId},
                        )

                    remaining -= frames_per_chunk

            # Close the ASR session
            yield AsrStopListening(siteId=self.siteId, sessionId=sessionId)
예제 #3
0
    async def handle_continue(self, continue_session: DialogueContinueSession):
        """Continue the existing session.

        Updates the current session from ``continue_session``:

        - custom data is replaced only when new data is given;
        - the intent filter is replaced;
        - the intent-not-recognized flag is replaced.

        Then speaks ``continue_session.text`` (if any) and publishes an
        ``AsrStartListening`` for this site and session. Errors, including
        the "No session" assertion, are logged rather than raised.

        Args:
            continue_session: Continue request for the current session.
        """
        try:
            assert self.session, "No session"

            # Update fields; keep the existing custom data when none is given
            self.session.customData = (
                continue_session.customData or self.session.customData
            )

            # Overwrite intent filter. The old guard checked the *session's*
            # filter, so a session started without a filter could never get one.
            self.session.intentFilter = continue_session.intentFilter

            self.session.sendIntentNotRecognized = (
                continue_session.sendIntentNotRecognized
            )

            _LOGGER.debug("Continuing session %s", self.session.sessionId)
            if continue_session.text:
                # Forward to TTS
                await self.say_and_wait(continue_session.text)

            # Start ASR listening for this site/session (matches start_session)
            _LOGGER.debug("Listening for session %s", self.session.sessionId)
            self.publish(
                AsrStartListening(
                    siteId=self.siteId, sessionId=self.session.sessionId
                )
            )
        except Exception:
            _LOGGER.exception("handle_continue")
예제 #4
0
    async def async_test_silence(self):
        """Start a stop-on-silence session and stream a real WAV file into it.

        The transcriber is faked. Once silence ends the recording, the
        service should emit AsrRecordingFinished followed by AsrTextCaptured.
        """
        expected = Transcription(
            text="turn on the living room lamp",
            likelihood=1,
            transcribe_seconds=0,
            wav_seconds=0,
        )

        def fake_transcribe(stream, *args):
            """Drain audio up to the first empty chunk; return the canned result."""
            for audio_chunk in stream:
                if not audio_chunk:
                    break

            return expected

        self.transcriber.transcribe_stream = fake_transcribe

        # Begin a session that ends automatically on silence
        start_request = AsrStartListening(
            site_id=self.site_id,
            session_id=self.session_id,
            stop_on_silence=True,
            send_audio_captured=False,
        )
        last_response = None
        async for last_response in self.hermes.on_message_blocking(start_request):
            pass

        # Starting a session produces no messages
        self.assertIsNone(last_response)

        # Stream the recorded WAV file in chunks
        responses = []
        with Path("etc/turn_on_the_living_room_lamp.wav").open("rb") as wav_file:
            for chunk_bytes in AudioFrame.iter_wav_chunked(wav_file, 4096):
                audio_frame = AudioFrame(wav_bytes=chunk_bytes)
                responses.extend(
                    [
                        response
                        async for response in self.hermes.on_message_blocking(
                            audio_frame, site_id=self.site_id
                        )
                    ]
                )

        # Expect the recording to finish and the transcription to be reported
        self.assertEqual(
            responses,
            [
                AsrRecordingFinished(site_id=self.site_id, session_id=self.session_id),
                AsrTextCaptured(
                    text=expected.text,
                    likelihood=expected.likelihood,
                    seconds=expected.transcribe_seconds,
                    site_id=self.site_id,
                    session_id=self.session_id,
                ),
            ],
        )
예제 #5
0
    async def async_test_transcriber_error(self):
        """A failing transcriber should still produce an empty transcription.

        After start, audio, and stop, the service is expected to report
        AsrRecordingFinished followed by an empty AsrTextCaptured.
        """

        def failing_transcribe(stream, *args):
            """Always raise to simulate a transcriber failure."""
            raise FakeException()

        self.transcriber.transcribe_stream = failing_transcribe

        # Begin a session that does not stop on silence
        start_request = AsrStartListening(
            site_id=self.site_id, session_id=self.session_id, stop_on_silence=False
        )
        result = None
        async for result in self.hermes.on_message_blocking(start_request):
            pass

        # Starting a session produces no messages
        self.assertIsNone(result)

        # Feed 100 random bytes of audio, wrapped as WAV
        frame = AudioFrame(
            wav_bytes=self.hermes.to_wav_bytes(secrets.token_bytes(100))
        )
        async for result in self.hermes.on_message_blocking(
            frame, site_id=self.site_id
        ):
            pass

        # Audio frames produce no messages either
        self.assertIsNone(result)

        # Stop the session and collect every response
        stop_request = AsrStopListening(
            site_id=self.site_id, session_id=self.session_id
        )
        responses = [
            response
            async for response in self.hermes.on_message_blocking(stop_request)
        ]

        # The transcriber error must surface as an empty transcription
        expected_responses = [
            AsrRecordingFinished(site_id=self.site_id, session_id=self.session_id),
            AsrTextCaptured(
                text="",
                likelihood=0,
                seconds=0,
                site_id=self.site_id,
                session_id=self.session_id,
            ),
        ]
        self.assertEqual(responses, expected_responses)
    def on_connect(self, client, userdata, flags, rc):
        """Subscribe to ASR control topics and audio-frame topics on connect.

        Any exception is logged rather than raised.
        """
        try:
            # Control messages for toggling and starting/stopping ASR
            subscriptions = [
                AsrToggleOn.topic(),
                AsrToggleOff.topic(),
                AsrStartListening.topic(),
                AsrStopListening.topic(),
            ]

            # Audio: only the configured siteIds, otherwise a wildcard for all
            subscriptions += (
                self.audioframe_topics
                if self.audioframe_topics
                else [AudioFrame.topic(siteId="+")]
            )

            for subscription in subscriptions:
                self.client.subscribe(subscription)
                _LOGGER.debug("Subscribed to %s", subscription)
        except Exception:
            _LOGGER.exception("on_connect")
예제 #7
0
        def messages():
            """Yield start-listening, WAV chunks as session frames, then stop-listening."""
            # Begin the ASR session with the caller's options
            yield AsrStartListening(
                site_id=self.site_id,
                session_id=session_id,
                stop_on_silence=stop_on_silence,
                send_audio_captured=send_audio_captured,
                intent_filter=intent_filter,
            )

            # Stream the WAV as chunked session audio frames, counting bytes sent
            total_bytes: int = 0
            with io.BytesIO(wav_bytes) as buffer:
                chunks = AudioFrame.iter_wav_chunked(buffer, frames_per_chunk)
                for chunk in chunks:
                    total_bytes += len(chunk)
                    frame = AudioSessionFrame(wav_bytes=chunk)
                    yield frame, {"site_id": self.site_id, "session_id": session_id}

            _LOGGER.debug("Sent %s byte(s) of WAV data", total_bytes)

            # End the ASR session
            yield AsrStopListening(site_id=self.site_id, session_id=session_id)
예제 #8
0
    async def handle_continue(
        self, continue_session: DialogueContinueSession
    ) -> typing.AsyncIterable[
        typing.Union[AsrStartListening, AsrStopListening, SayType, DialogueError]
    ]:
        """Continue the existing session.

        Looks up the session by ``continue_session.session_id``. If it is not
        found, a warning is logged and nothing is yielded.

        Otherwise the session is updated:

        - custom data and language are replaced only when provided;
        - the intent filter and intent-not-recognized flag are replaced;
        - the step counter is incremented.

        Then, in order, yields:

        - an ``AsrStopListening``;
        - a ``HotwordToggleOff``;
        - any messages from ``self.say`` (when ``continue_session.text`` is set);
        - an ``AsrStartListening``.

        Finally, ``handle_session_timeout`` is scheduled as a task for the new
        step. Any exception is logged and reported by yielding a
        ``DialogueError``.
        """
        site_session = self.all_sessions.get(continue_session.session_id)

        if site_session is None:
            # Unknown session id: nothing to continue, yield nothing
            _LOGGER.warning(
                "No session for id %s. Cannot continue.", continue_session.session_id
            )
            return

        try:
            if continue_session.custom_data is not None:
                # Overwrite custom data
                site_session.custom_data = continue_session.custom_data

            if continue_session.lang is not None:
                # Overwrite language
                site_session.lang = continue_session.lang

            # Replaced unconditionally (including with None)
            site_session.intent_filter = continue_session.intent_filter

            site_session.send_intent_not_recognized = (
                continue_session.send_intent_not_recognized
            )

            # Same step value is passed to the timeout task below; presumably
            # lets a timeout from an earlier step detect the session moved on
            # (TODO confirm in handle_session_timeout).
            site_session.step += 1

            _LOGGER.debug(
                "Continuing session %s (step=%s)",
                site_session.session_id,
                site_session.step,
            )

            # Stop listening
            yield AsrStopListening(
                site_id=site_session.site_id, session_id=site_session.session_id
            )

            # Ensure hotword is disabled for session
            yield HotwordToggleOff(
                site_id=site_session.site_id,
                reason=HotwordToggleReason.DIALOGUE_SESSION,
            )

            if continue_session.text:
                # Forward to TTS
                async for tts_result in self.say(
                    continue_session.text,
                    site_id=site_session.site_id,
                    session_id=continue_session.session_id,
                ):
                    yield tts_result

            # Start ASR listening
            _LOGGER.debug("Listening for session %s", site_session.session_id)
            yield AsrStartListening(
                site_id=site_session.site_id,
                session_id=site_session.session_id,
                send_audio_captured=site_session.send_audio_captured,
                lang=site_session.lang,
            )

            # Set up timeout
            # NOTE(review): the Task returned by create_task is not stored; the
            # asyncio docs recommend keeping a reference so it cannot be
            # garbage-collected before it completes.
            asyncio.create_task(
                self.handle_session_timeout(
                    site_session.site_id, site_session.session_id, site_session.step
                )
            )

        except Exception as e:
            # Report the failure to consumers instead of raising
            _LOGGER.exception("handle_continue")
            yield DialogueError(
                error=str(e),
                context=str(continue_session),
                site_id=site_session.site_id,
                session_id=continue_session.session_id,
            )
예제 #9
0
    async def start_session(
        self, new_session: SessionInfo
    ) -> typing.AsyncIterable[typing.Union[StartSessionType, EndSessionType, SayType]]:
        """Start a new session.

        Notification sessions (``start_session.init.type`` is NOTIFICATION):

        - If the site has no session, ``new_session`` is registered and
          ``DialogueSessionStarted`` is yielded.
        - The notification text (if any) is spoken via ``self.say``.
        - The site's session is then ended with reason NOMINAL
          (``start_next_session=True``).

        Action sessions:

        - ``new_session`` takes custom data, intent filter and the
          intent-not-recognized flag from the request.
        - If the site already has a session, the new one is either queued
          (``can_be_enqueued``; yields ``DialogueSessionQueued``), or the
          existing session is ended with ABORTED_BY_USER.
        - When the new session starts, it is registered and
          ``DialogueSessionStarted`` and ``HotwordToggleOff`` are yielded.
        - The action text (if any) is spoken and ``AsrStartListening`` is
          yielded.
        - ``handle_session_timeout`` is scheduled as a task.
        """
        start_session = new_session.start_session
        # Session currently active on this site, if any
        site_session = self.session_by_site.get(new_session.site_id)

        if start_session.init.type == DialogueActionType.NOTIFICATION:
            # Notification session
            notification = start_session.init
            assert isinstance(
                notification, DialogueNotification
            ), "Not a DialogueNotification"

            if not site_session:
                # Create new session just for TTS
                _LOGGER.debug("Starting new session (id=%s)", new_session.session_id)
                self.all_sessions[new_session.session_id] = new_session
                self.session_by_site[new_session.site_id] = new_session

                yield DialogueSessionStarted(
                    site_id=new_session.site_id,
                    session_id=new_session.session_id,
                    custom_data=new_session.custom_data,
                    lang=new_session.lang,
                )

                site_session = new_session

            if notification.text:
                async for say_result in self.say(
                    notification.text,
                    site_id=site_session.site_id,
                    session_id=site_session.session_id,
                ):
                    yield say_result

            # End notification session immedately
            # NOTE(review): if a session was already active on this site, it is
            # that existing session (not new_session) that gets ended here —
            # confirm this is intended.
            _LOGGER.debug("Session ended nominally: %s", site_session.session_id)
            async for end_result in self.end_session(
                DialogueSessionTerminationReason.NOMINAL,
                site_id=site_session.site_id,
                session_id=site_session.session_id,
                start_next_session=True,
            ):
                yield end_result
        else:
            # Action session
            action = start_session.init
            assert isinstance(action, DialogueAction), "Not a DialogueAction"

            new_session.custom_data = start_session.custom_data
            new_session.intent_filter = action.intent_filter
            new_session.send_intent_not_recognized = action.send_intent_not_recognized

            # Cleared only when the new session is queued behind an active one
            start_new_session = True

            if site_session:
                if action.can_be_enqueued:
                    # Queue session for later
                    session_queue = self.session_queue_by_site[new_session.site_id]

                    start_new_session = False
                    session_queue.append(new_session)

                    yield DialogueSessionQueued(
                        session_id=new_session.session_id,
                        site_id=new_session.site_id,
                        custom_data=new_session.custom_data,
                    )
                else:
                    # Abort existing session; the new one then starts below
                    _LOGGER.debug("Session aborted: %s", site_session.session_id)
                    async for end_result in self.end_session(
                        DialogueSessionTerminationReason.ABORTED_BY_USER,
                        site_id=site_session.site_id,
                        session_id=site_session.session_id,
                        start_next_session=False,
                    ):
                        yield end_result

            if start_new_session:
                # Start new session
                _LOGGER.debug("Starting new session (id=%s)", new_session.session_id)
                self.all_sessions[new_session.session_id] = new_session
                self.session_by_site[new_session.site_id] = new_session

                yield DialogueSessionStarted(
                    site_id=new_session.site_id,
                    session_id=new_session.session_id,
                    custom_data=new_session.custom_data,
                    lang=new_session.lang,
                )

                # Disable hotword for session
                yield HotwordToggleOff(
                    site_id=new_session.site_id,
                    reason=HotwordToggleReason.DIALOGUE_SESSION,
                )

                if action.text:
                    # Forward to TTS
                    async for say_result in self.say(
                        action.text,
                        site_id=new_session.site_id,
                        session_id=new_session.session_id,
                    ):
                        yield say_result

                # Start ASR listening
                _LOGGER.debug("Listening for session %s", new_session.session_id)
                if (
                    new_session.detected
                    and new_session.detected.send_audio_captured is not None
                ):
                    # Use setting from hotword detection
                    new_session.send_audio_captured = (
                        new_session.detected.send_audio_captured
                    )

                yield AsrStartListening(
                    site_id=new_session.site_id,
                    session_id=new_session.session_id,
                    send_audio_captured=new_session.send_audio_captured,
                    wakeword_id=new_session.wakeword_id,
                    lang=new_session.lang,
                )

                # Set up timeout
                # NOTE(review): Task reference is not kept; asyncio docs
                # recommend storing it so it is not garbage-collected early.
                asyncio.create_task(
                    self.handle_session_timeout(
                        new_session.site_id, new_session.session_id, new_session.step
                    )
                )
예제 #10
0
def test_asr_start_listening():
    """AsrStartListening must use the Hermes ASR startListening topic."""
    expected_topic = "hermes/asr/startListening"
    assert AsrStartListening.topic() == expected_topic
예제 #11
0
    async def start_session(self, new_session: SessionInfo):
        """Start a new session.

        A dict-valued ``start_session.init`` is first converted to a
        ``DialogueNotification`` or ``DialogueAction``.

        Notification sessions:

        - If no session is active, ``new_session`` becomes the active session
          and ``DialogueSessionStarted`` is yielded.
        - The notification text (if any) is spoken.
        - The active session is ended with reason NOMINAL.

        Action sessions, when a session is already active:

        - the new one is queued (``canBeEnqueued``; yields
          ``DialogueSessionQueued``);
        - otherwise it is dropped.

        Action sessions, when no session is active:

        - ``new_session`` becomes active and ``DialogueSessionStarted`` is
          yielded.
        - The action text (if any) is spoken and ``AsrStartListening`` is
          yielded.

        Queued or dropped sessions never replace the active session.
        Previously every path fell through to an unconditional
        ``self.session = new_session``.
        """
        start_session = new_session.start_session

        if isinstance(start_session.init, Mapping):
            # Convert to object
            if start_session.init["type"] == DialogueActionType.NOTIFICATION:
                start_session.init = DialogueNotification(**start_session.init)
            else:
                start_session.init = DialogueAction(**start_session.init)

        if start_session.init.type == DialogueActionType.NOTIFICATION:
            # Notification session
            notification = start_session.init
            assert isinstance(notification, DialogueNotification)

            if not self.session:
                # Create new session just for TTS
                _LOGGER.debug("Starting new session (id=%s)", new_session.sessionId)
                self.session = new_session
                yield DialogueSessionStarted(
                    siteId=self.siteId,
                    sessionId=new_session.sessionId,
                    customData=new_session.customData,
                )

            if notification.text:
                # Forward to TTS (same as the action branch: wait, don't yield)
                await self.say_and_wait(notification.text)

            # End notification session immediately.
            # NOTE(review): if another session was already active, that session
            # is the one ended here — confirm this is intended.
            _LOGGER.debug("Session ended nominally: %s", self.session.sessionId)
            await self.end_session(DialogueSessionTerminationReason.NOMINAL)
        else:
            # Action session
            action = start_session.init
            assert isinstance(action, DialogueAction)

            new_session.customData = start_session.customData
            new_session.intentFilter = action.intentFilter
            new_session.sendIntentNotRecognized = action.sendIntentNotRecognized

            if self.session:
                # Existing session: never replace it here
                if action.canBeEnqueued:
                    # Queue session for later
                    self.session_queue.append(new_session)
                    yield DialogueSessionQueued(
                        sessionId=new_session.sessionId,
                        siteId=self.siteId,
                        customData=new_session.customData,
                    )
                else:
                    # Drop session
                    _LOGGER.warning("Session was dropped: %s", start_session)
            else:
                # Start new session
                _LOGGER.debug("Starting new session (id=%s)", new_session.sessionId)
                self.session = new_session
                yield DialogueSessionStarted(
                    siteId=self.siteId,
                    sessionId=new_session.sessionId,
                    customData=new_session.customData,
                )

                if action.text:
                    # Forward to TTS
                    await self.say_and_wait(action.text)

                # Start ASR listening
                _LOGGER.debug("Listening for session %s", new_session.sessionId)
                yield AsrStartListening(
                    siteId=self.siteId, sessionId=new_session.sessionId
                )
예제 #12
0
    async def async_test_session(self):
        """Run a full start/audio/stop cycle with a fake transcriber.

        Expects AsrRecordingFinished, then AsrTextCaptured, then the captured
        audio (since send_audio_captured is enabled).
        """
        expected = Transcription(
            text="this is a test", likelihood=1, transcribe_seconds=0, wav_seconds=0
        )

        def fake_transcribe(stream, *args):
            """Drain audio up to the first empty chunk; return the canned result."""
            for audio_chunk in stream:
                if not audio_chunk:
                    break

            return expected

        self.transcriber.transcribe_stream = fake_transcribe

        # Begin a session that keeps listening until told to stop
        start_request = AsrStartListening(
            site_id=self.site_id,
            session_id=self.session_id,
            stop_on_silence=False,
            send_audio_captured=True,
        )
        result = None
        async for result in self.hermes.on_message_blocking(start_request):
            pass

        # Starting a session produces no messages
        self.assertIsNone(result)

        # Feed 100 random bytes of audio, wrapped as WAV
        wav_payload = self.hermes.to_wav_bytes(secrets.token_bytes(100))
        frame = AudioFrame(wav_bytes=wav_payload)
        async for result in self.hermes.on_message_blocking(
            frame, site_id=self.site_id
        ):
            pass

        # Audio frames produce no messages either
        self.assertIsNone(result)

        # Stop the session and collect every response
        stop_request = AsrStopListening(
            site_id=self.site_id, session_id=self.session_id
        )
        responses = [
            response
            async for response in self.hermes.on_message_blocking(stop_request)
        ]

        # Transcription first, then the captured audio with its topic args
        expected_responses = [
            AsrRecordingFinished(site_id=self.site_id, session_id=self.session_id),
            AsrTextCaptured(
                text=expected.text,
                likelihood=expected.likelihood,
                seconds=expected.transcribe_seconds,
                site_id=self.site_id,
                session_id=self.session_id,
            ),
            (
                AsrAudioCaptured(wav_bytes=wav_payload),
                {"site_id": self.site_id, "session_id": self.session_id},
            ),
        ]
        self.assertEqual(responses, expected_responses)