Example #1
    async def stop_listening(
        self, message: AsrStopListening
    ) -> typing.AsyncIterable[
        typing.Union[
            AsrTextCaptured,
            AsrError,
            typing.Tuple[AsrAudioCaptured, typing.Dict[str, typing.Any]],
        ]
    ]:
        """Stop recording audio data for a session."""
        try:
            session = self.sessions.pop(message.session_id, None)
            if session:
                # Stop session
                if session.recorder:
                    audio_data = session.recorder.stop()
                else:
                    assert session.audio_buffer is not None
                    audio_data = session.audio_buffer

                wav_bytes = self.to_wav_bytes(audio_data)

                _LOGGER.debug(
                    "Received a total of %s byte(s) for WAV data for session %s",
                    session.num_wav_bytes,
                    message.session_id,
                )

                if not session.transcription_sent:
                    # Send transcription
                    session.transcription_sent = True

                    yield (
                        await self.transcribe(
                            wav_bytes,
                            site_id=message.site_id,
                            session_id=message.session_id,
                            lang=session.start_listening.lang,
                        )
                    )

                    if session.start_listening.send_audio_captured:
                        # Send audio data
                        yield (
                            AsrAudioCaptured(wav_bytes=wav_bytes),
                            {
                                "site_id": message.site_id,
                                "session_id": message.session_id,
                            },
                        )

            _LOGGER.debug("Stopping listening (session_id=%s)", message.session_id)
        except Exception as e:
            _LOGGER.exception("stop_listening")
            yield AsrError(
                error=str(e),
                context="stop_listening",
                site_id=message.site_id,
                session_id=message.session_id,
            )
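
The generator above yields either a bare Hermes message (AsrTextCaptured, AsrError) or an (AsrAudioCaptured, topic-arguments) tuple. A minimal consumption sketch; the driver function and the print stand-in are hypothetical, not part of the snippet:

async def publish_results(handler, stop_message):
    """Hypothetical driver: drain stop_listening and split messages from tuples."""
    async for result in handler.stop_listening(stop_message):
        if isinstance(result, tuple):
            message, topic_args = result        # AsrAudioCaptured plus site/session ids
        else:
            message, topic_args = result, {}    # AsrTextCaptured or AsrError
        print(type(message).__name__, topic_args)  # stand-in for an MQTT publish
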
Example #2
    async def handle_audio_frame(
        self,
        frame_wav_bytes: bytes,
        site_id: str = "default",
        session_id: typing.Optional[str] = None,
    ) -> typing.AsyncIterable[
        typing.Union[
            AsrTextCaptured,
            AsrError,
            typing.Tuple[AsrAudioCaptured, typing.Dict[str, typing.Any]],
        ]
    ]:
        """Process single frame of WAV audio"""
        # Don't process audio if no sessions
        if not self.sessions:
            return

        audio_data = self.maybe_convert_wav(frame_wav_bytes)

        if session_id is None:
            # Add to every open session
            target_sessions = list(self.sessions.items())
        else:
            # Add to single session
            target_sessions = [(session_id, self.sessions[session_id])]

        # Add audio to session(s)
        for target_id, session in target_sessions:
            try:
                # Skip if site_id doesn't match
                if session.start_listening.site_id != site_id:
                    continue

                session.num_wav_bytes += len(frame_wav_bytes)
                if session.recorder:
                    # Check for end of voice command
                    command = session.recorder.process_chunk(audio_data)
                    if command and (command.result
                                    == VoiceCommandResult.SUCCESS):
                        assert command.audio_data is not None
                        _LOGGER.debug(
                            "Voice command recorded for session %s (%s byte(s))",
                            target_id,
                            len(command.audio_data),
                        )

                        session.transcription_sent = True
                        wav_bytes = self.to_wav_bytes(command.audio_data)

                        yield (await self.transcribe(
                            wav_bytes,
                            site_id=site_id,
                            session_id=target_id,
                            wakeword_id=session.start_listening.wakeword_id))

                        if session.start_listening.send_audio_captured:
                            # Send audio data
                            yield (
                                AsrAudioCaptured(wav_bytes=wav_bytes),
                                {
                                    "site_id": site_id,
                                    "session_id": target_id
                                },
                            )

                        # Reset session (but keep open)
                        session.recorder.stop()
                        session.recorder.start()
                else:
                    # Add to buffer
                    assert session.audio_buffer is not None
                    session.audio_buffer += audio_data
            except Exception as e:
                _LOGGER.exception("handle_audio_frame")
                yield AsrError(
                    error=str(e),
                    context=repr(self.transcriber),
                    site_id=site_id,
                    session_id=target_id,
                )
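
handle_audio_frame expects each chunk as WAV bytes (the snippet passes frame_wav_bytes through maybe_convert_wav before buffering). A minimal sketch of wrapping raw PCM into a per-frame WAV container, assuming 16 kHz, 16-bit, mono audio; the parameters are illustrative only:

import io
import wave

def pcm_chunk_to_wav(pcm_chunk: bytes, rate: int = 16000, width: int = 2, channels: int = 1) -> bytes:
    """Wrap one raw PCM chunk in a minimal WAV header (illustrative defaults)."""
    with io.BytesIO() as wav_io:
        with wave.open(wav_io, "wb") as wav_file:
            wav_file.setnchannels(channels)
            wav_file.setsampwidth(width)
            wav_file.setframerate(rate)
            wav_file.writeframes(pcm_chunk)
        return wav_io.getvalue()

Each WAV-wrapped chunk could then be passed to handle_audio_frame as frame_wav_bytes.
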
Example #3
    async def finish_session(
        self, info: TranscriberInfo, site_id: str,
        session_id: typing.Optional[str]
    ) -> typing.AsyncIterable[typing.Union[AsrTextCaptured,
                                           AudioCapturedType]]:
        """Publish transcription result for a session if not already published"""

        if info.recorder is not None:
            # Stop silence detection and get trimmed audio
            audio_data = info.recorder.stop()
        else:
            # Use complete audio buffer
            assert info.audio_buffer is not None
            audio_data = info.audio_buffer

        if not info.result_sent:
            # Avoid re-sending transcription
            info.result_sent = True

            # Last chunk
            info.frame_queue.put(None)

            # Wait for result
            result_success = info.result_event.wait(
                timeout=self.session_result_timeout)
            if not result_success:
                # Mark transcription as non-reusable
                info.reuse = False

            transcription = info.result
            assert info.start_listening is not None

            if transcription:
                # Successful transcription
                yield (AsrTextCaptured(
                    text=transcription.text,
                    likelihood=transcription.likelihood,
                    seconds=transcription.transcribe_seconds,
                    site_id=site_id,
                    session_id=session_id,
                    lang=info.start_listening.lang,
                ))
            else:
                # Empty transcription
                yield AsrTextCaptured(
                    text="",
                    likelihood=0,
                    seconds=0,
                    site_id=site_id,
                    session_id=session_id,
                    lang=info.start_listening.lang,
                )

            if info.start_listening.send_audio_captured:
                wav_bytes = self.to_wav_bytes(audio_data)

                # Send audio data
                yield (
                    # pylint: disable=E1121
                    AsrAudioCaptured(wav_bytes),
                    {
                        "site_id": site_id,
                        "session_id": session_id
                    },
                )
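
finish_session relies on a background transcriber loop that drains info.frame_queue, treats None as the end-of-stream sentinel, stores its transcription on info.result, and sets info.result_event. A sketch of that contract with stand-in types; FakeInfo and transcriber_loop are hypothetical names:

import queue
import threading
from dataclasses import dataclass, field
from typing import Any, Optional

@dataclass
class FakeInfo:
    """Stand-in for the threading fields finish_session relies on."""
    frame_queue: "queue.Queue[Optional[bytes]]" = field(default_factory=queue.Queue)
    result_event: threading.Event = field(default_factory=threading.Event)
    result: Optional[Any] = None

def transcriber_loop(info: FakeInfo, transcribe_wav) -> None:
    """Collect frames until the None sentinel, then store the result and signal."""
    audio = bytearray()
    while True:
        chunk = info.frame_queue.get()
        if chunk is None:              # sentinel pushed by finish_session
            break
        audio += chunk
    info.result = transcribe_wav(bytes(audio))
    info.result_event.set()            # wakes result_event.wait(timeout=...) above
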
Example #4
    async def finish_session(
        self, info: TranscriberInfo, site_id: str,
        session_id: typing.Optional[str]
    ) -> typing.AsyncIterable[typing.Union[
            AsrRecordingFinished, AsrTextCaptured, AudioCapturedType]]:
        """Publish transcription result for a session if not already published"""

        if info.recorder is not None:
            # Stop silence detection and get trimmed audio
            audio_data = info.recorder.stop()
        else:
            # Use complete audio buffer
            assert info.audio_buffer is not None
            audio_data = info.audio_buffer

        if not info.result_sent:
            # Send recording finished message
            yield AsrRecordingFinished(site_id=site_id, session_id=session_id)

            # Avoid re-sending transcription
            info.result_sent = True

            # Last chunk
            info.frame_queue.put(None)

            # Wait for result
            result_success = info.result_event.wait(
                timeout=self.session_result_timeout)
            if not result_success:
                # Mark transcription as non-reusable
                info.reuse = False

            transcription = info.result
            assert info.start_listening is not None

            if transcription:
                # Successful transcription
                asr_tokens: typing.Optional[typing.List[
                    typing.List[AsrToken]]] = None

                if transcription.tokens:
                    # Only one level of ASR tokens
                    asr_inner_tokens: typing.List[AsrToken] = []
                    asr_tokens = [asr_inner_tokens]
                    range_start = 0
                    for ps_token in transcription.tokens:
                        range_end = range_start + len(ps_token.token) + 1
                        asr_inner_tokens.append(
                            AsrToken(
                                value=ps_token.token,
                                confidence=ps_token.likelihood,
                                range_start=range_start,
                                range_end=range_end,
                                time=AsrTokenTime(start=ps_token.start_time,
                                                  end=ps_token.end_time),
                            ))

                        range_start = range_end

                yield (AsrTextCaptured(
                    text=transcription.text,
                    likelihood=transcription.likelihood,
                    seconds=transcription.transcribe_seconds,
                    site_id=site_id,
                    session_id=session_id,
                    asr_tokens=asr_tokens,
                    lang=(info.start_listening.lang or self.lang),
                ))
            else:
                # Empty transcription
                yield AsrTextCaptured(
                    text="",
                    likelihood=0,
                    seconds=0,
                    site_id=site_id,
                    session_id=session_id,
                    lang=(info.start_listening.lang or self.lang),
                )

            if info.start_listening.send_audio_captured:
                wav_bytes = self.to_wav_bytes(audio_data)

                # Send audio data
                yield (
                    # pylint: disable=E1121
                    AsrAudioCaptured(wav_bytes),
                    {
                        "site_id": site_id,
                        "session_id": session_id
                    },
                )
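
The asr_tokens bookkeeping gives each token a character range in the space-joined transcript: the +1 accounts for the separating space, and the next token starts where the previous range ended. A standalone illustration of the same arithmetic on plain strings:

def token_ranges(tokens):
    """Mirror the range_start/range_end arithmetic from the loop above."""
    ranges = []
    range_start = 0
    for token in tokens:
        range_end = range_start + len(token) + 1   # token plus one separating space
        ranges.append((token, range_start, range_end))
        range_start = range_end
    return ranges

# token_ranges(["turn", "on", "the", "light"])
# -> [('turn', 0, 5), ('on', 5, 8), ('the', 8, 12), ('light', 12, 18)]
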
Example #5
    async def async_test_session(self):
        """Check good start/stop session."""
        fake_transcription = Transcription(
            text="this is a test", likelihood=1, transcribe_seconds=0, wav_seconds=0
        )

        def fake_transcribe(stream, *args):
            """Return test trancription."""
            for chunk in stream:
                if not chunk:
                    break

            return fake_transcription

        self.transcriber.transcribe_stream = fake_transcribe

        # Start session
        start_listening = AsrStartListening(
            site_id=self.site_id,
            session_id=self.session_id,
            stop_on_silence=False,
            send_audio_captured=True,
        )
        result = None
        async for response in self.hermes.on_message_blocking(start_listening):
            result = response

        # No response expected
        self.assertIsNone(result)

        # Send in "audio"
        fake_wav_bytes = self.hermes.to_wav_bytes(secrets.token_bytes(100))
        fake_frame = AudioFrame(wav_bytes=fake_wav_bytes)
        async for response in self.hermes.on_message_blocking(
            fake_frame, site_id=self.site_id
        ):
            result = response

        # No response expected
        self.assertIsNone(result)

        # Stop session
        stop_listening = AsrStopListening(
            site_id=self.site_id, session_id=self.session_id
        )

        results = []
        async for response in self.hermes.on_message_blocking(stop_listening):
            results.append(response)

        # Check results
        self.assertEqual(
            results,
            [
                AsrRecordingFinished(site_id=self.site_id, session_id=self.session_id),
                AsrTextCaptured(
                    text=fake_transcription.text,
                    likelihood=fake_transcription.likelihood,
                    seconds=fake_transcription.transcribe_seconds,
                    site_id=self.site_id,
                    session_id=self.session_id,
                ),
                (
                    AsrAudioCaptured(wav_bytes=fake_wav_bytes),
                    {"site_id": self.site_id, "session_id": self.session_id},
                ),
            ],
        )
Example #6
    async def handle_stop_listening(
        self, stop_listening: AsrStopListening
    ) -> typing.AsyncIterable[
        typing.Union[
            AsrTextCaptured, typing.Tuple[AsrAudioCaptured, TopicArgs], AsrError
        ]
    ]:
        """Stop ASR session."""
        _LOGGER.debug("<- %s", stop_listening)

        try:
            session = self.asr_sessions.pop(stop_listening.session_id, None)
            if session is None:
                _LOGGER.warning("Session not found for %s", stop_listening.session_id)
                return

            assert session.sample_rate is not None, "No sample rate"
            assert session.sample_width is not None, "No sample width"
            assert session.channels is not None, "No channels"

            if session.start_listening.stop_on_silence:
                # Use recorded voice command
                audio_data = session.recorder.stop()
            else:
                # Use entire audio
                audio_data = session.audio_data

            # Process entire WAV file
            wav_bytes = self.to_wav_bytes(
                audio_data, session.sample_rate, session.sample_width, session.channels
            )
            _LOGGER.debug("Received %s byte(s) of WAV data", len(wav_bytes))

            if self.asr_url:
                _LOGGER.debug(self.asr_url)

                # Remote ASR server
                async with self.http_session.post(
                    self.asr_url,
                    data=wav_bytes,
                    headers={"Content-Type": "audio/wav", "Accept": "application/json"},
                    ssl=self.ssl_context,
                ) as response:
                    response.raise_for_status()
                    transcription_dict = await response.json()
            elif self.asr_command:
                # Local ASR command
                _LOGGER.debug(self.asr_command)

                start_time = time.perf_counter()
                proc = await asyncio.create_subprocess_exec(
                    *self.asr_command,
                    stdin=asyncio.subprocess.PIPE,
                    stdout=asyncio.subprocess.PIPE,
                    stderr=asyncio.subprocess.PIPE,
                )

                output, error = await proc.communicate(wav_bytes)

                if error:
                    _LOGGER.debug(error.decode())

                text = output.decode()
                end_time = time.perf_counter()

                transcription_dict = {
                    "text": text,
                    "transcribe_seconds": (end_time - start_time),
                }
            else:
                # Empty transcription
                _LOGGER.warning(
                    "No ASR URL or command. Only empty transcriptions will be returned."
                )
                transcription_dict = {}

            # Publish transcription
            yield AsrTextCaptured(
                text=transcription_dict.get("text", ""),
                likelihood=float(transcription_dict.get("likelihood", 0)),
                seconds=float(transcription_dict.get("transcribe_seconds", 0)),
                site_id=stop_listening.site_id,
                session_id=stop_listening.session_id,
                lang=session.start_listening.lang,
            )

            if session.start_listening.send_audio_captured:
                # Send audio data
                yield (
                    AsrAudioCaptured(wav_bytes=wav_bytes),
                    {
                        "site_id": stop_listening.site_id,
                        "session_id": stop_listening.session_id,
                    },
                )

        except Exception as e:
            _LOGGER.exception("handle_stop_listening")
            yield AsrError(
                error=str(e),
                context=f"url='{self.asr_url}', command='{self.asr_command}'",
                site_id=stop_listening.site_id,
                session_id=stop_listening.session_id,
            )
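
The asr_command branch defines a simple contract: the program receives the session's WAV bytes on stdin, and whatever it prints to stdout is used verbatim as the transcript (stderr, if any, only reaches the debug log). A hypothetical stand-in command that satisfies that contract, useful for exercising the plumbing:

#!/usr/bin/env python3
"""fake_asr.py (hypothetical): read WAV from stdin, print a canned transcript."""
import sys

wav_bytes = sys.stdin.buffer.read()                                  # piped in by proc.communicate(wav_bytes)
print(f"received {len(wav_bytes)} byte(s) of WAV", file=sys.stderr)  # surfaces via _LOGGER.debug
print("turn on the living room lamp")                                # becomes transcription_dict["text"]

self.asr_command would then point at such a script, e.g. ["python3", "fake_asr.py"] (path illustrative).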