def finish_session(
        self, info: TranscriberInfo, siteId: str, sessionId: str
    ) -> typing.Iterable[AsrTextCaptured]:
        """Publish transcription result for a session if not already published"""
        # Stop silence detection
        assert info.recorder is not None
        info.recorder.stop()

        if not info.result_sent:
            # Last chunk
            info.frame_queue.put(None)

            # Wait for result
            info.result_event.wait(timeout=self.session_result_timeout)

            transcription = info.result
            if transcription:
                # Successful transcription
                yield (
                    AsrTextCaptured(
                        text=transcription.text,
                        likelihood=transcription.likelihood,
                        seconds=transcription.transcribe_seconds,
                        siteId=siteId,
                        sessionId=sessionId,
                    )
                )
            else:
                # Empty transcription
                yield AsrTextCaptured(
                    text="", likelihood=0, seconds=0, siteId=siteId, sessionId=sessionId
                )

            # Avoid re-sending transcription
            info.result_sent = True
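
This version of finish_session uses the camelCase fields (siteId, sessionId); most of the later examples use the snake_case fields (site_id, session_id) of newer rhasspy-hermes releases. Either way, the caller is expected to publish the yielded messages itself. A minimal sketch of that step, assuming a paho-style MQTT client; asr_service, info, and mqtt_client are illustrative names, not part of the example:

# Illustrative only: publish whatever finish_session() yields.
# asr_service, info, and mqtt_client are placeholder names.
for message in asr_service.finish_session(
    info, siteId="default", sessionId="abc123"
):
    # topic() is "hermes/asr/textCaptured"; payload() serializes to JSON.
    mqtt_client.publish(message.topic(), message.payload())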

Example #2

    async def async_test_ws_text(self):
        """Test api/events/text endpoint"""
        # Start listening
        event_queue = asyncio.Queue()
        connected = asyncio.Event()
        receive_task = asyncio.ensure_future(
            self.async_ws_receive("events/text", event_queue, connected))
        await asyncio.wait_for(connected.wait(), timeout=5)

        # Send in a message
        text_captured = AsrTextCaptured(
            text="this is a test",
            likelihood=1,
            seconds=0,
            site_id=self.site_id,
            session_id=self.session_id,
            wakeword_id=str(uuid4()),
        )

        self.client.publish(text_captured.topic(), text_captured.payload())

        # Wait for response
        event = json.loads(await asyncio.wait_for(event_queue.get(),
                                                  timeout=5))
        self.assertEqual(text_captured.text, event.get("text", ""))
        self.assertEqual(text_captured.site_id, event.get("siteId", ""))
        self.assertEqual(text_captured.wakeword_id,
                         event.get("wakewordId", ""))

        # Stop listening
        receive_task.cancel()

Example #3

    def on_message(self, client, userdata, msg):
        """Received message from MQTT broker."""
        try:
            _LOGGER.debug("Received %s byte(s) on %s", len(msg.payload),
                          msg.topic)

            if NluIntent.is_topic(msg.topic):
                # Intent from query
                json_payload = json.loads(msg.payload)
                if not self._check_siteId(json_payload):
                    return

                intent = NluIntent(**json_payload)

                # Report to websockets
                for queue in self.message_queues:
                    self.loop.call_soon_threadsafe(queue.put_nowait,
                                                   ("intent", intent))

                self.handle_message(msg.topic, intent)
            elif msg.topic == NluIntentNotRecognized.topic():
                # Intent not recognized
                json_payload = json.loads(msg.payload)
                if not self._check_siteId(json_payload):
                    return

                not_recognized = NluIntentNotRecognized(**json_payload)
                # Report to websockets
                for queue in self.message_queues:
                    self.loop.call_soon_threadsafe(queue.put_nowait,
                                                   ("intent", not_recognized))

                self.handle_message(msg.topic, not_recognized)
            elif msg.topic == TtsSayFinished.topic():
                # Text to speech finished
                json_payload = json.loads(msg.payload)
                self.handle_message(msg.topic, TtsSayFinished(**json_payload))
            elif msg.topic == AsrTextCaptured.topic():
                # Speech to text result
                json_payload = json.loads(msg.payload)
                if not self._check_siteId(json_payload):
                    return

                self.handle_message(msg.topic, AsrTextCaptured(**json_payload))

            # Forward to external message queues
            for queue in self.message_queues:
                self.loop.call_soon_threadsafe(
                    queue.put_nowait, ("mqtt", msg.topic, msg.payload))
        except Exception:
            _LOGGER.exception("on_message")

Example #4

    async def transcribe(
        self,
        wav_bytes: bytes,
        site_id: str,
        session_id: typing.Optional[str] = None,
        wakeword_id: typing.Optional[str] = None,
    ) -> AsrTextCaptured:
        """Transcribe audio data and publish captured text."""
        _LOGGER.debug("Transcribing %s byte(s) of audio data", len(wav_bytes))
        transcription = self.transcriber.transcribe_wav(wav_bytes)
        if transcription:
            _LOGGER.debug(transcription)
            asr_tokens: typing.Optional[typing.List[
                typing.List[AsrToken]]] = None

            if transcription.tokens:
                # Only one level of ASR tokens
                asr_inner_tokens: typing.List[AsrToken] = []
                asr_tokens = [asr_inner_tokens]
                range_start = 0
                for ps_token in transcription.tokens:
                    range_end = range_start + len(ps_token.token) + 1
                    asr_inner_tokens.append(
                        AsrToken(
                            value=ps_token.token,
                            confidence=ps_token.likelihood,
                            range_start=range_start,
                            range_end=range_start + len(ps_token.token) + 1,
                            time=AsrTokenTime(start=ps_token.start_time,
                                              end=ps_token.end_time),
                        ))

                    range_start = range_end

            # Actual transcription
            return AsrTextCaptured(text=transcription.text,
                                   likelihood=transcription.likelihood,
                                   seconds=transcription.transcribe_seconds,
                                   site_id=site_id,
                                   session_id=session_id,
                                   asr_tokens=asr_tokens,
                                   wakeword_id=wakeword_id)

        _LOGGER.warning("Received empty transcription")
        return AsrTextCaptured(text="",
                               likelihood=0,
                               seconds=0,
                               site_id=site_id,
                               session_id=session_id,
                               wakeword_id=wakeword_id)
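
The range_start/range_end values above are character offsets into the space-joined transcription text: each token advances the offset by its length plus one for the separating space. A tiny illustration of the same arithmetic, using made-up tokens:

# Illustrative only: how range_start/range_end advance per ASR token.
tokens = ["turn", "on", "the", "lamp"]
range_start = 0
for token in tokens:
    range_end = range_start + len(token) + 1  # +1 for the trailing space
    print(f"{token}: [{range_start}, {range_end})")  # turn: [0, 5), on: [5, 8), ...
    range_start = range_end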

Example #5

    async def async_test_silence(self):
        """Check start/stop session with silence detection."""
        fake_transcription = Transcription(
            text="turn on the living room lamp",
            likelihood=1,
            transcribe_seconds=0,
            wav_seconds=0,
        )

        def fake_transcribe(stream, *args):
            """Return test trancription."""
            for chunk in stream:
                if not chunk:
                    break

            return fake_transcription

        self.transcriber.transcribe_stream = fake_transcribe

        # Start session
        start_listening = AsrStartListening(
            site_id=self.site_id,
            session_id=self.session_id,
            stop_on_silence=True,
            send_audio_captured=False,
        )
        result = None
        async for response in self.hermes.on_message_blocking(start_listening):
            result = response

        # No response expected
        self.assertIsNone(result)

        # Send in "audio"
        wav_path = Path("etc/turn_on_the_living_room_lamp.wav")

        results = []
        with open(wav_path, "rb") as wav_file:
            for wav_bytes in AudioFrame.iter_wav_chunked(wav_file, 4096):
                frame = AudioFrame(wav_bytes=wav_bytes)
                async for response in self.hermes.on_message_blocking(
                    frame, site_id=self.site_id
                ):
                    results.append(response)

        # Expect transcription
        self.assertEqual(
            results,
            [
                AsrRecordingFinished(site_id=self.site_id, session_id=self.session_id),
                AsrTextCaptured(
                    text=fake_transcription.text,
                    likelihood=fake_transcription.likelihood,
                    seconds=fake_transcription.transcribe_seconds,
                    site_id=self.site_id,
                    session_id=self.session_id,
                ),
            ],
        )

Example #6

    def transcribe(
            self,
            audio_data: bytes,
            siteId: str = "default",
            sessionId: str = "") -> typing.Union[AsrTextCaptured, AsrError]:
        """Transcribe audio data and publish captured text."""
        try:
            with io.BytesIO() as wav_buffer:
                wav_file: wave.Wave_write = wave.open(wav_buffer, mode="wb")
                with wav_file:
                    wav_file.setframerate(self.sample_rate)
                    wav_file.setsampwidth(self.sample_width)
                    wav_file.setnchannels(self.channels)
                    wav_file.writeframesraw(audio_data)

                transcription = self.transcriber.transcribe_wav(
                    wav_buffer.getvalue())
                if transcription:
                    # Actual transcription
                    return AsrTextCaptured(
                        text=transcription.text,
                        likelihood=transcription.likelihood,
                        seconds=transcription.transcribe_seconds,
                        siteId=siteId,
                        sessionId=sessionId,
                    )

                _LOGGER.warning("Received empty transcription")

                # Empty transcription
                return AsrTextCaptured(text="",
                                       likelihood=0,
                                       seconds=0,
                                       siteId=siteId,
                                       sessionId=sessionId)
        except Exception as e:
            _LOGGER.exception("transcribe")
            return AsrError(
                error=str(e),
                context=repr(self.transcriber),
                siteId=siteId,
                sessionId=sessionId,
            )

Example #7

    async def async_test_transcriber_error(self):
        """Check start/stop session with error in transcriber."""

        def fake_transcribe(stream, *args):
            """Raise an exception."""
            raise FakeException()

        self.transcriber.transcribe_stream = fake_transcribe

        # Start session
        start_listening = AsrStartListening(
            site_id=self.site_id, session_id=self.session_id, stop_on_silence=False
        )
        result = None
        async for response in self.hermes.on_message_blocking(start_listening):
            result = response

        # No response expected
        self.assertIsNone(result)

        # Send in "audio"
        fake_wav_bytes = self.hermes.to_wav_bytes(secrets.token_bytes(100))
        fake_frame = AudioFrame(wav_bytes=fake_wav_bytes)
        async for response in self.hermes.on_message_blocking(
            fake_frame, site_id=self.site_id
        ):
            result = response

        # No response expected
        self.assertIsNone(result)

        # Stop session
        stop_listening = AsrStopListening(
            site_id=self.site_id, session_id=self.session_id
        )

        results = []
        async for response in self.hermes.on_message_blocking(stop_listening):
            results.append(response)

        # Check results for empty transcription
        self.assertEqual(
            results,
            [
                AsrRecordingFinished(site_id=self.site_id, session_id=self.session_id),
                AsrTextCaptured(
                    text="",
                    likelihood=0,
                    seconds=0,
                    site_id=self.site_id,
                    session_id=self.session_id,
                ),
            ],
        )

Example #8

    async def transcribe_wav(self,
                             wav_bytes: bytes,
                             frames_per_chunk: int = 1024) -> AsrTextCaptured:
        """Transcribe WAV data"""
        sessionId = str(uuid4())

        def handle_captured():
            while True:
                _, message = yield

                if isinstance(message, AsrTextCaptured) and (message.sessionId
                                                             == sessionId):
                    return True, message

        def messages():
            yield AsrStartListening(siteId=self.siteId, sessionId=sessionId)

            # Break WAV into chunks
            with io.BytesIO(wav_bytes) as wav_buffer:
                with wave.open(wav_buffer, "rb") as wav_file:
                    frames_left = wav_file.getnframes()
                    while frames_left > 0:
                        with io.BytesIO() as chunk_buffer:
                            chunk_file: wave.Wave_write = wave.open(
                                chunk_buffer, "wb")
                            with chunk_file:
                                chunk_file.setframerate(
                                    wav_file.getframerate())
                                chunk_file.setsampwidth(
                                    wav_file.getsampwidth())
                                chunk_file.setnchannels(
                                    wav_file.getnchannels())
                                chunk_file.writeframes(
                                    wav_file.readframes(frames_per_chunk))

                            yield (
                                AudioFrame(wav_data=chunk_buffer.getvalue()),
                                {
                                    "siteId": self.siteId
                                },
                            )

                        frames_left -= frames_per_chunk

            yield AsrStopListening(siteId=self.siteId, sessionId=sessionId)

        topics = [AsrTextCaptured.topic()]

        # Expecting only a single result
        async for result in self.publish_wait(handle_captured(), messages(),
                                              topics):
            return result

Example #9

    def test_http_speech_to_text_hermes(self):
        """Text speech-to-text HTTP endpoint (Hermes format)"""
        response = requests.post(
            self.api_url("speech-to-text"),
            data=self.wav_bytes,
            params={"outputFormat": "hermes"},
        )
        self.check_status(response)

        result = response.json()
        self.assertEqual(result["type"], "textCaptured")

        text_captured = AsrTextCaptured.from_dict(result["value"])

        self.assertEqual(text_captured.text, "turn on the living room lamp")

Example #10

    def on_connect(self, client, userdata, flags, rc):
        """Connected to MQTT broker."""
        try:
            topics = [
                DialogueStartSession.topic(),
                DialogueContinueSession.topic(),
                DialogueEndSession.topic(),
                TtsSayFinished.topic(),
                NluIntent.topic(intent_name="#"),
                NluIntentNotRecognized.topic(),
                AsrTextCaptured.topic(),
            ] + list(self.wakeword_topics.keys())
            for topic in topics:
                self.client.subscribe(topic)
                _LOGGER.debug("Subscribed to %s", topic)
        except Exception:
            _LOGGER.exception("on_connect")

Example #11

def test_asr_text_captured():
    """Test AsrTextCaptured."""
    assert AsrTextCaptured.topic() == "hermes/asr/textCaptured"
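
Because the topic is fixed, captured text can also be consumed with a plain MQTT subscriber outside the harnesses shown above. A minimal sketch, assuming a paho-mqtt 1.x client and the rhasspyhermes.asr import path; the broker address is illustrative:

import json

import paho.mqtt.client as mqtt
from rhasspyhermes.asr import AsrTextCaptured


def on_message(client, userdata, msg):
    # Decode the JSON payload into an AsrTextCaptured message
    captured = AsrTextCaptured.from_dict(json.loads(msg.payload))
    print(captured.text, captured.site_id, captured.session_id)


client = mqtt.Client()
client.on_message = on_message
client.connect("localhost", 1883)  # illustrative broker address
client.subscribe(AsrTextCaptured.topic())
client.loop_forever()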

Example #12

    async def finish_session(
        self, info: TranscriberInfo, site_id: str,
        session_id: typing.Optional[str]
    ) -> typing.AsyncIterable[typing.Union[AsrTextCaptured,
                                           AudioCapturedType]]:
        """Publish transcription result for a session if not already published"""

        if info.recorder is not None:
            # Stop silence detection and get trimmed audio
            audio_data = info.recorder.stop()
        else:
            # Use complete audio buffer
            assert info.audio_buffer is not None
            audio_data = info.audio_buffer

        if not info.result_sent:
            # Avoid re-sending transcription
            info.result_sent = True

            # Last chunk
            info.frame_queue.put(None)

            # Wait for result
            result_success = info.result_event.wait(
                timeout=self.session_result_timeout)
            if not result_success:
                # Mark transcription as non-reusable
                info.reuse = False

            transcription = info.result
            assert info.start_listening is not None

            if transcription:
                # Successful transcription
                yield (AsrTextCaptured(
                    text=transcription.text,
                    likelihood=transcription.likelihood,
                    seconds=transcription.transcribe_seconds,
                    site_id=site_id,
                    session_id=session_id,
                    lang=info.start_listening.lang,
                ))
            else:
                # Empty transcription
                yield AsrTextCaptured(
                    text="",
                    likelihood=0,
                    seconds=0,
                    site_id=site_id,
                    session_id=session_id,
                    lang=info.start_listening.lang,
                )

            if info.start_listening.send_audio_captured:
                wav_bytes = self.to_wav_bytes(audio_data)

                # Send audio data
                yield (
                    # pylint: disable=E1121
                    AsrAudioCaptured(wav_bytes),
                    {
                        "site_id": site_id,
                        "session_id": session_id
                    },
                )

Example #13

    async def finish_session(
        self, info: TranscriberInfo, site_id: str,
        session_id: typing.Optional[str]
    ) -> typing.AsyncIterable[typing.Union[
            AsrRecordingFinished, AsrTextCaptured, AudioCapturedType]]:
        """Publish transcription result for a session if not already published"""

        if info.recorder is not None:
            # Stop silence detection and get trimmed audio
            audio_data = info.recorder.stop()
        else:
            # Use complete audio buffer
            assert info.audio_buffer is not None
            audio_data = info.audio_buffer

        if not info.result_sent:
            # Send recording finished message
            yield AsrRecordingFinished(site_id=site_id, session_id=session_id)

            # Avoid re-sending transcription
            info.result_sent = True

            # Last chunk
            info.frame_queue.put(None)

            # Wait for result
            result_success = info.result_event.wait(
                timeout=self.session_result_timeout)
            if not result_success:
                # Mark transcription as non-reusable
                info.reuse = False

            transcription = info.result
            assert info.start_listening is not None

            if transcription:
                # Successful transcription
                asr_tokens: typing.Optional[typing.List[
                    typing.List[AsrToken]]] = None

                if transcription.tokens:
                    # Only one level of ASR tokens
                    asr_inner_tokens: typing.List[AsrToken] = []
                    asr_tokens = [asr_inner_tokens]
                    range_start = 0
                    for ps_token in transcription.tokens:
                        range_end = range_start + len(ps_token.token) + 1
                        asr_inner_tokens.append(
                            AsrToken(
                                value=ps_token.token,
                                confidence=ps_token.likelihood,
                                range_start=range_start,
                                range_end=range_start + len(ps_token.token) +
                                1,
                                time=AsrTokenTime(start=ps_token.start_time,
                                                  end=ps_token.end_time),
                            ))

                        range_start = range_end

                yield (AsrTextCaptured(
                    text=transcription.text,
                    likelihood=transcription.likelihood,
                    seconds=transcription.transcribe_seconds,
                    site_id=site_id,
                    session_id=session_id,
                    asr_tokens=asr_tokens,
                    lang=(info.start_listening.lang or self.lang),
                ))
            else:
                # Empty transcription
                yield AsrTextCaptured(
                    text="",
                    likelihood=0,
                    seconds=0,
                    site_id=site_id,
                    session_id=session_id,
                    lang=(info.start_listening.lang or self.lang),
                )

            if info.start_listening.send_audio_captured:
                wav_bytes = self.to_wav_bytes(audio_data)

                # Send audio data
                yield (
                    # pylint: disable=E1121
                    AsrAudioCaptured(wav_bytes),
                    {
                        "site_id": site_id,
                        "session_id": session_id
                    },
                )

Example #14

    def on_message(self, client, userdata, msg):
        """Received message from MQTT broker."""
        try:
            _LOGGER.debug("Received %s byte(s) on %s", len(msg.payload),
                          msg.topic)
            if msg.topic == DialogueStartSession.topic():
                # Start session
                json_payload = json.loads(msg.payload)
                if not self._check_siteId(json_payload):
                    return

                # Run in event loop (for TTS)
                asyncio.run_coroutine_threadsafe(
                    self.handle_start(DialogueStartSession(**json_payload)),
                    self.loop)
            elif msg.topic == DialogueContinueSession.topic():
                # Continue session
                json_payload = json.loads(msg.payload)
                if not self._check_siteId(json_payload):
                    return

                # Run in event loop (for TTS)
                asyncio.run_coroutine_threadsafe(
                    self.handle_continue(
                        DialogueContinueSession(**json_payload)),
                    self.loop,
                )
            elif msg.topic == DialogueEndSession.topic():
                # End session
                json_payload = json.loads(msg.payload)
                if not self._check_siteId(json_payload):
                    return

                # Run outside event loop
                self.handle_end(DialogueEndSession(**json_payload))
            elif msg.topic == TtsSayFinished.topic():
                # TTS finished
                json_payload = json.loads(msg.payload)
                if not self._check_sessionId(json_payload):
                    return

                # Signal event loop
                self.loop.call_soon_threadsafe(self.say_finished_event.set)
            elif msg.topic == AsrTextCaptured.topic():
                # Text captured
                json_payload = json.loads(msg.payload)
                if not self._check_sessionId(json_payload):
                    return

                # Run outside event loop
                self.handle_text_captured(AsrTextCaptured(**json_payload))
            elif msg.topic.startswith(NluIntent.topic(intent_name="")):
                # Intent recognized
                json_payload = json.loads(msg.payload)
                if not self._check_sessionId(json_payload):
                    return

                self.handle_recognized(NluIntent(**json_payload))
            elif msg.topic.startswith(NluIntentNotRecognized.topic()):
                # Intent not recognized
                json_payload = json.loads(msg.payload)
                if not self._check_sessionId(json_payload):
                    return

                # Run in event loop (for TTS)
                asyncio.run_coroutine_threadsafe(
                    self.handle_not_recognized(
                        NluIntentNotRecognized(**json_payload)),
                    self.loop,
                )
            elif msg.topic in self.wakeword_topics:
                json_payload = json.loads(msg.payload)
                if not self._check_siteId(json_payload):
                    return

                wakeword_id = self.wakeword_topics[msg.topic]
                asyncio.run_coroutine_threadsafe(
                    self.handle_wake(wakeword_id,
                                     HotwordDetected(**json_payload)),
                    self.loop,
                )
        except Exception:
            _LOGGER.exception("on_message")

Example #15

    async def async_test_session(self):
        """Check good start/stop session."""
        fake_transcription = Transcription(
            text="this is a test", likelihood=1, transcribe_seconds=0, wav_seconds=0
        )

        def fake_transcribe(stream, *args):
            """Return test trancription."""
            for chunk in stream:
                if not chunk:
                    break

            return fake_transcription

        self.transcriber.transcribe_stream = fake_transcribe

        # Start session
        start_listening = AsrStartListening(
            site_id=self.site_id,
            session_id=self.session_id,
            stop_on_silence=False,
            send_audio_captured=True,
        )
        result = None
        async for response in self.hermes.on_message_blocking(start_listening):
            result = response

        # No response expected
        self.assertIsNone(result)

        # Send in "audio"
        fake_wav_bytes = self.hermes.to_wav_bytes(secrets.token_bytes(100))
        fake_frame = AudioFrame(wav_bytes=fake_wav_bytes)
        async for response in self.hermes.on_message_blocking(
            fake_frame, site_id=self.site_id
        ):
            result = response

        # No response expected
        self.assertIsNone(result)

        # Stop session
        stop_listening = AsrStopListening(
            site_id=self.site_id, session_id=self.session_id
        )

        results = []
        async for response in self.hermes.on_message_blocking(stop_listening):
            results.append(response)

        # Check results
        self.assertEqual(
            results,
            [
                AsrRecordingFinished(site_id=self.site_id, session_id=self.session_id),
                AsrTextCaptured(
                    text=fake_transcription.text,
                    likelihood=fake_transcription.likelihood,
                    seconds=fake_transcription.transcribe_seconds,
                    site_id=self.site_id,
                    session_id=self.session_id,
                ),
                (
                    AsrAudioCaptured(wav_bytes=fake_wav_bytes),
                    {"site_id": self.site_id, "session_id": self.session_id},
                ),
            ],
        )

Example #16

    async def handle_stop_listening(
        self, stop_listening: AsrStopListening
    ) -> typing.AsyncIterable[
        typing.Union[
            AsrTextCaptured, typing.Tuple[AsrAudioCaptured, TopicArgs], AsrError
        ]
    ]:
        """Stop ASR session."""
        _LOGGER.debug("<- %s", stop_listening)

        try:
            session = self.asr_sessions.pop(stop_listening.session_id, None)
            if session is None:
                _LOGGER.warning("Session not found for %s", stop_listening.session_id)
                return

            assert session.sample_rate is not None, "No sample rate"
            assert session.sample_width is not None, "No sample width"
            assert session.channels is not None, "No channels"

            if session.start_listening.stop_on_silence:
                # Use recorded voice command
                audio_data = session.recorder.stop()
            else:
                # Use entire audio
                audio_data = session.audio_data

            # Process entire WAV file
            wav_bytes = self.to_wav_bytes(
                audio_data, session.sample_rate, session.sample_width, session.channels
            )
            _LOGGER.debug("Received %s byte(s) of WAV data", len(wav_bytes))

            if self.asr_url:
                _LOGGER.debug(self.asr_url)

                # Remote ASR server
                async with self.http_session.post(
                    self.asr_url,
                    data=wav_bytes,
                    headers={"Content-Type": "audio/wav", "Accept": "application/json"},
                    ssl=self.ssl_context,
                ) as response:
                    response.raise_for_status()
                    transcription_dict = await response.json()
            elif self.asr_command:
                # Local ASR command
                _LOGGER.debug(self.asr_command)

                start_time = time.perf_counter()
                proc = await asyncio.create_subprocess_exec(
                    *self.asr_command,
                    stdin=asyncio.subprocess.PIPE,
                    stdout=asyncio.subprocess.PIPE,
                    stderr=asyncio.subprocess.PIPE,
                )

                output, error = await proc.communicate(wav_bytes)

                if error:
                    _LOGGER.debug(error.decode())

                text = output.decode()
                end_time = time.perf_counter()

                transcription_dict = {
                    "text": text,
                    "transcribe_seconds": (end_time - start_time),
                }
            else:
                # Empty transcription
                _LOGGER.warning(
                    "No ASR URL or command. Only empty transcriptions will be returned."
                )
                transcription_dict = {}

            # Publish transcription
            yield AsrTextCaptured(
                text=transcription_dict.get("text", ""),
                likelihood=float(transcription_dict.get("likelihood", 0)),
                seconds=float(transcription_dict.get("transcribe_seconds", 0)),
                site_id=stop_listening.site_id,
                session_id=stop_listening.session_id,
                lang=session.start_listening.lang,
            )

            if session.start_listening.send_audio_captured:
                # Send audio data
                yield (
                    AsrAudioCaptured(wav_bytes=wav_bytes),
                    {
                        "site_id": stop_listening.site_id,
                        "session_id": stop_listening.session_id,
                    },
                )

        except Exception as e:
            _LOGGER.exception("handle_stop_listening")
            yield AsrError(
                error=str(e),
                context=f"url='{self.asr_url}', command='{self.asr_command}'",
                site_id=stop_listening.site_id,
                session_id=stop_listening.session_id,
            )
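
Example #16 only requires that the remote ASR endpoint answer a POSTed WAV body with JSON containing text and, optionally, likelihood and transcribe_seconds. A minimal stub of such an endpoint, sketched with aiohttp; the route is illustrative and the transcription is left empty:

from aiohttp import web


async def speech_to_text(request: web.Request) -> web.Response:
    wav_bytes = await request.read()  # raw audio/wav request body
    # A real implementation would transcribe wav_bytes here.
    return web.json_response({"text": "", "likelihood": 0, "transcribe_seconds": 0})


app = web.Application()
app.add_routes([web.post("/api/speech-to-text", speech_to_text)])

if __name__ == "__main__":
    web.run_app(app)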