Example #1
0
    def test_http_mqtt_text_to_speech(self):
        """Exercise the text-to-speech HTTP endpoint end to end.

        Posts a sentence for this site/session, then checks that the MQTT
        messages captured in ``self.mqtt_messages`` arrive in order:
        tts/say (echoing the request), audioServer/playBytes (carrying the
        same WAV as the HTTP response body) and tts/sayFinished. Finally a
        ``repeat=true`` request must return the identical WAV.
        """
        sentence = "This is a test."
        for subscription in (
            TtsSay.topic(),
            AudioPlayBytes.topic(site_id=self.site_id),
            TtsSayFinished.topic(),
        ):
            self.client.subscribe(subscription)

        query = {"siteId": self.site_id, "sessionId": self.session_id}
        http_response = requests.post(
            self.api_url("text-to-speech"), data=sentence, params=query
        )
        self.check_status(http_response)

        wav_data = http_response.content
        self.assertGreater(len(wav_data), 0)

        def next_message():
            """Pop the next captured MQTT message (5 second timeout)."""
            return self.mqtt_messages.get(timeout=5)

        # 1. tts/say echoes the request
        say_msg = next_message()
        self.assertTrue(TtsSay.is_topic(say_msg.topic))
        say = TtsSay.from_dict(json.loads(say_msg.payload))
        self.assertEqual(say.site_id, self.site_id)
        self.assertEqual(say.session_id, self.session_id)
        self.assertEqual(say.text, sentence)

        # 2. audioServer/playBytes carries the same WAV as the HTTP body
        play_msg = next_message()
        self.assertTrue(AudioPlayBytes.is_topic(play_msg.topic))
        self.assertEqual(
            AudioPlayBytes.get_site_id(play_msg.topic), self.site_id
        )
        self.assertEqual(play_msg.payload, wav_data)

        # 3. tts/sayFinished for the same session
        finished_msg = next_message()
        self.assertTrue(TtsSayFinished.is_topic(finished_msg.topic))
        finished = TtsSayFinished.from_dict(json.loads(finished_msg.payload))
        self.assertEqual(finished.session_id, self.session_id)

        # 4. repeat=true must return the previously synthesized WAV
        repeat_response = requests.post(
            self.api_url("text-to-speech"), params={"repeat": "true"}
        )
        self.check_status(repeat_response)
        self.assertEqual(wav_data, repeat_response.content)
        def handle_finished():
            """Collect the messages that mark one TTS request as complete.

            Generator-based coroutine: after priming with ``next()``, each
            ``(topic, message)`` pair is sent in with ``send()``. It returns
            ``(say_finished, play_bytes)`` as the generator's return value
            (``StopIteration.value``) once all three conditions hold:
            - a finishing TTS message was seen;
            - the playBytes (or a playback error) was seen;
            - playback counts as finished.

            NOTE(review): ``capture_audio``, ``wait_play_finished`` and
            ``tts_id`` are not defined in the enclosing test method, and this
            generator is never used there. It looks like a fragment pasted from
            elsewhere and would raise NameError if driven. Confirm its source.
            """
            # May hold a TtsError instead of a TtsSayFinished (see below)
            say_finished: typing.Optional[
                typing.Union[TtsSayFinished, TtsError]
            ] = None
            # True means "not required": when audio is not captured, the
            # playBytes condition is satisfied from the start.
            play_bytes: typing.Union[
                None, bool, AudioPlayBytes, AudioPlayError
            ] = None if capture_audio else True
            # Pre-satisfied when the caller does not wait for playback to finish
            play_finished = not wait_play_finished

            while True:
                topic, message = yield

                if isinstance(message, TtsSayFinished) and (message.id == tts_id):
                    # sayFinished for our request also counts as playback finished
                    say_finished = message
                    play_finished = True
                elif isinstance(message, TtsError):
                    # Assume audio playback didn't happen
                    # NOTE(review): not filtered by tts_id, so an error for any
                    # request ends the wait; confirm this is intended.
                    say_finished = message
                    play_bytes = True
                    play_finished = True
                elif isinstance(message, AudioPlayBytes):
                    # Only accept audio whose topic carries our request id
                    request_id = AudioPlayBytes.get_request_id(topic)
                    if request_id == tts_id:
                        play_bytes = message
                elif isinstance(message, AudioPlayError):
                    # NOTE(review): also not filtered by request id
                    play_bytes = message

                if say_finished and play_bytes and play_finished:
                    return (say_finished, play_bytes)
Example #3
0
 async def on_message(
     self,
     message: Message,
     site_id: typing.Optional[str] = None,
     session_id: typing.Optional[str] = None,
     topic: typing.Optional[str] = None,
 ) -> GeneratorType:
     """Dispatch one message received from the MQTT broker.

     playBytes requests go to ``handle_play`` and device queries to
     ``handle_get_devices``; their results are re-yielded. toggleOff and
     toggleOn set ``self.enabled``. Anything else is logged as unexpected.
     """
     if isinstance(message, AudioPlayBytes):
         assert site_id and topic, "Missing site_id or topic"
         play_results = self.handle_play(
             AudioPlayBytes.get_request_id(topic),
             message.wav_bytes,
             site_id=site_id,
             session_id=session_id or "",
         )
         async for result in play_results:
             yield result
         return

     if isinstance(message, AudioGetDevices):
         async for result in self.handle_get_devices(message):
             yield result
         return

     if isinstance(message, AudioToggleOff):
         self.enabled = False
         _LOGGER.debug("Disabled audio")
     elif isinstance(message, AudioToggleOn):
         self.enabled = True
         _LOGGER.debug("Enabled audio")
     else:
         _LOGGER.warning("Unexpected message: %s", message)
 def messages():
     """Yield a single ``(AudioPlayBytes, topic_args)`` pair.

     NOTE(review): ``self``, ``wav_data`` and ``requestId`` are free names
     taken from an enclosing scope that is not shown here.
     """
     play_bytes = AudioPlayBytes(wav_data=wav_data)
     topic_args = {"siteId": self.siteId, "requestId": requestId}
     yield play_bytes, topic_args
Example #5
0
def test_audio_play_bytes():
    """Round-trip site_id and request_id through an AudioPlayBytes topic."""
    play_topic = AudioPlayBytes.topic(site_id=site_id, request_id=request_id)

    assert AudioPlayBytes.is_topic(play_topic)
    assert AudioPlayBytes.get_site_id(play_topic) == site_id
    assert AudioPlayBytes.get_request_id(play_topic) == request_id
Example #6
0
    async def handle_say(
        self, say: TtsSay
    ) -> typing.AsyncIterable[
        typing.Union[typing.Tuple[AudioPlayBytes, TopicArgs], TtsSayFinished, TtsError]
    ]:
        """Do text to speech using a remote HTTP server.

        When ``self.tts_url`` is set, POSTs ``say.text`` to it with
        ``play=false`` (plus ``language=<say.lang>`` if given). The returned
        WAV is yielded as an ``(AudioPlayBytes, topic_args)`` pair addressed
        to ``say.site_id`` with request id ``say.id``. An empty response body
        is only logged. Any exception is reported as a ``TtsError``.
        ``TtsSayFinished`` is yielded last.
        """
        try:
            if self.tts_url:
                # Remote text to speech server
                _LOGGER.debug(self.tts_url)

                params = {"play": "false"}
                if say.lang:
                    # Add ?language=<lang> query parameter
                    params["language"] = say.lang

                async with self.http_session.post(
                    self.tts_url, data=say.text, params=params, ssl=self.ssl_context
                ) as response:
                    response.raise_for_status()

                    # Content-Type is optional in HTTP and may carry
                    # parameters (e.g. "audio/wav; codecs=1"). A missing
                    # header must not abort synthesis with a KeyError, and
                    # parameters must not trigger a spurious warning.
                    content_type = response.headers.get("Content-Type", "")
                    media_type = content_type.split(";", 1)[0].strip().lower()
                    if media_type != "audio/wav":
                        _LOGGER.warning(
                            "Expected audio/wav content type, got %s", content_type
                        )

                    wav_bytes = await response.read()
                    if wav_bytes:
                        yield (
                            AudioPlayBytes(wav_bytes=wav_bytes),
                            {"site_id": say.site_id, "request_id": say.id},
                        )
                    else:
                        _LOGGER.error("Received empty response")
        except Exception as e:
            _LOGGER.exception("handle_say")
            yield TtsError(
                error=str(e),
                context=say.id,
                site_id=say.site_id,
                session_id=say.session_id,
            )
        finally:
            # NOTE(review): yielding inside ``finally`` of an async generator
            # raises RuntimeError if the generator is closed (aclose) while
            # suspended in the ``try`` body; fine when consumed to completion.
            yield TtsSayFinished(
                id=say.id, site_id=say.site_id, session_id=say.session_id
            )
Example #7
0
    def test_topics(self):
        """Round-trip ids through each message class's topic()/get_* helpers."""
        site = "testSiteId"
        request = "testRequestId"
        intent = "testIntent"
        wakeword = "testWakeWord"

        # AudioFrame: topic carries the site id
        frame_topic = AudioFrame.topic(siteId=site)
        self.assertTrue(AudioFrame.is_topic(frame_topic))
        self.assertEqual(AudioFrame.get_siteId(frame_topic), site)

        # AudioPlayBytes: topic carries site id and request id
        play_topic = AudioPlayBytes.topic(siteId=site, requestId=request)
        self.assertTrue(AudioPlayBytes.is_topic(play_topic))
        self.assertEqual(AudioPlayBytes.get_siteId(play_topic), site)
        self.assertEqual(AudioPlayBytes.get_requestId(play_topic), request)

        # AudioPlayFinished: topic carries the site id
        finished_topic = AudioPlayFinished.topic(siteId=site)
        self.assertTrue(AudioPlayFinished.is_topic(finished_topic))
        self.assertEqual(AudioPlayFinished.get_siteId(finished_topic), site)

        # NluIntent: topic carries the intent name
        intent_topic = NluIntent.topic(intentName=intent)
        self.assertTrue(NluIntent.is_topic(intent_topic))
        self.assertEqual(NluIntent.get_intentName(intent_topic), intent)

        # HotwordDetected: topic carries the wake word id
        hotword_topic = HotwordDetected.topic(wakewordId=wakeword)
        self.assertTrue(HotwordDetected.is_topic(hotword_topic))
        self.assertEqual(
            HotwordDetected.get_wakewordId(hotword_topic), wakeword
        )
Example #8
0
    def handle_say(self, say: TtsSay):
        """Synthesize ``say.text`` with ``self.tts_command`` and emit the audio.

        The command is formatted with ``lang=say.lang`` and the text is
        appended as its final argument; WAV data is read from stdout. A
        ``TtsSayFinished`` is published whether or not synthesis succeeded.
        Non-empty audio is then either piped to ``self.play_command`` or, when
        no play command is configured, published on the playBytes topic for
        ``self.siteId`` with request id ``say.id`` (or a fresh UUID).
        """
        wav_bytes: typing.Optional[bytes] = None

        try:
            say_command = shlex.split(self.tts_command.format(lang=say.lang))
            say_command.append(say.text)
            _LOGGER.debug(say_command)

            wav_bytes = subprocess.check_output(say_command)
            _LOGGER.debug("Got %s byte(s) of WAV data", len(wav_bytes))
        except Exception:
            _LOGGER.exception("tts_command")
        finally:
            # Published on both success and failure
            self.publish(TtsSayFinished(id=say.id, sessionId=say.sessionId))

        if not wav_bytes:
            # Nothing synthesized, nothing to play
            return

        if not self.play_command:
            # No local player configured: publish the raw WAV bytes instead
            request_id = say.id or str(uuid4())
            play_topic = AudioPlayBytes.topic(
                site_id=self.siteId, request_id=request_id
            )
            self.client.publish(play_topic, wav_bytes)
            return

        try:
            # Pipe the WAV into the local play command
            player = shlex.split(self.play_command.format(lang=say.lang))
            _LOGGER.debug(player)
            subprocess.run(player, input=wav_bytes, check=True)
        except Exception:
            _LOGGER.exception("play_command")
    async def maybe_play_sound(
        self,
        sound_name: str,
        site_id: typing.Optional[str] = None,
        request_id: typing.Optional[str] = None,
        block: bool = True,
    ) -> typing.AsyncIterable[SoundsType]:
        """Play a named sound through audio out, if one is configured.

        ``sound_name`` is looked up in ``self.sound_paths``. A directory
        resolves to a random file inside it whose suffix is in
        ``self.sound_suffixes``. The sound is converted to WAV, volume-adjusted
        when ``self.volume`` is set and not 1.0, and yielded as an
        ``(AudioPlayBytes, topic_args)`` pair. That pair is bracketed by
        hotword/ASR toggle-off and toggle-on messages for the site.

        Args:
            sound_name: key into ``self.sound_paths``.
            site_id: target site; defaults to ``self.site_id``. Nothing is
                played if the resolved site is in ``self.no_sound``.
            request_id: playBytes request id; a UUID is generated if omitted.
            block: when true, wait for the AudioPlayFinished event registered
                under ``request_id``, at most the WAV duration plus
                ``self.sound_timeout_extra`` seconds.
        """
        # Resolve the default site *before* consulting ``no_sound``.
        # Otherwise a call without ``site_id`` checks ``None`` and plays
        # sounds even when they are disabled for the default site.
        site_id = site_id or self.site_id
        if site_id in self.no_sound:
            _LOGGER.debug("Sound is disabled for site %s", site_id)
            return

        sound_path = self.sound_paths.get(sound_name)
        if not sound_path:
            return

        if sound_path.is_dir():
            # Pick a random sound file from the directory
            sound_file_paths = [
                p
                for p in sound_path.rglob("*")
                if p.is_file() and (p.suffix in self.sound_suffixes)
            ]
            if not sound_file_paths:
                _LOGGER.debug("No sound files found in %s", str(sound_path))
                return

            sound_path = random.choice(sound_file_paths)
        elif not sound_path.is_file():
            _LOGGER.error("Sound does not exist: %s", str(sound_path))
            return

        _LOGGER.debug("Playing sound %s", str(sound_path))

        # Convert to WAV
        wav_bytes = DialogueHermesMqtt.convert_to_wav(sound_path)

        if (self.volume is not None) and (self.volume != 1.0):
            wav_bytes = DialogueHermesMqtt.change_volume(wav_bytes, self.volume)

        # Register an event under the request id. It is presumably set
        # elsewhere when the matching AudioPlayFinished arrives (handler not
        # shown here).
        request_id = request_id or str(uuid4())
        finished_event = asyncio.Event()
        finished_id = request_id
        self.message_events[AudioPlayFinished][finished_id] = finished_event

        # Disable ASR/hotword at site so the sound is not picked up
        yield HotwordToggleOff(
            site_id=site_id, reason=HotwordToggleReason.PLAY_AUDIO
        )
        yield AsrToggleOff(site_id=site_id, reason=AsrToggleReason.PLAY_AUDIO)

        # Wait for messages to be delivered
        await asyncio.sleep(self.toggle_delay)

        try:
            yield (
                AudioPlayBytes(wav_bytes=wav_bytes),
                {"site_id": site_id, "request_id": request_id},
            )

            # Wait for finished event or WAV duration
            if block:
                wav_duration = get_wav_duration(wav_bytes)
                wav_timeout = wav_duration + self.sound_timeout_extra
                _LOGGER.debug(
                    "Waiting for playFinished (id=%s, timeout=%s)",
                    finished_id,
                    wav_timeout,
                )
                await asyncio.wait_for(finished_event.wait(), timeout=wav_timeout)
        except asyncio.TimeoutError:
            _LOGGER.warning("Did not receive playFinished before timeout")
        except Exception:
            _LOGGER.exception("maybe_play_sound")
        finally:
            # Wait for audio to finish playing
            await asyncio.sleep(self.toggle_delay)

            # Re-enable ASR/hotword at site
            yield HotwordToggleOn(
                site_id=site_id, reason=HotwordToggleReason.PLAY_AUDIO
            )
            yield AsrToggleOn(site_id=site_id, reason=AsrToggleReason.PLAY_AUDIO)
Example #10
0
    async def handle_say(
        self, say: TtsSay
    ) -> typing.AsyncIterable[typing.Union[TtsSayFinished, typing.Tuple[
            AudioPlayBytes, TopicArgs], TtsError, AudioPlayError, ]]:
        """Run the configured TTS command and play or publish the WAV result.

        Command construction and I/O:
        - The command line comes from ``self.tts_command``: a ``str.format``
          template, or a Jinja2 template when ``self.use_jinja2``.
        - It is rendered with ``lang`` (``say.lang`` or ``self.language``) and,
          when ``self.use_temp_wav``, ``file`` (a temporary WAV path).
        - Text is passed on stdin (``self.text_on_stdin``) or as the final
          argument.
        - WAV audio is read from stdout or from the temporary file.

        Volume: the message volume, when set, overrides ``self.volume``.

        Playback: the WAV is either piped to ``self.play_command`` (formatted
        with the same resolved ``lang``) or yielded as an
        ``(AudioPlayBytes, topic_args)`` pair. Afterwards the method waits for
        the playback-finished event, at most the WAV duration plus
        ``self.finished_timeout_extra`` seconds.

        Yields ``AudioPlayError`` if local playback fails, ``TtsError`` for
        any other failure, and ``TtsSayFinished`` last.
        """
        wav_bytes: typing.Optional[bytes] = None
        temp_wav_path: typing.Optional[str] = None

        try:
            language = say.lang or self.language
            format_args = {"lang": language}

            if self.use_temp_wav:
                # The TTS program writes its WAV audio to a temporary file.
                # Only the name is needed, so close our handle at once;
                # delete=False keeps the file until the cleanup in ``finally``.
                with tempfile.NamedTemporaryFile(suffix=".wav",
                                                 delete=False) as temp_wav_file:
                    temp_wav_path = temp_wav_file.name

                # Path to WAV file
                format_args["file"] = temp_wav_path

            if self.use_jinja2:
                # Interpret TTS command as a Jinja2 template (compiled once and
                # cached on the instance)
                if not self.jinja2_template:
                    from jinja2 import Environment

                    self.jinja2_template = Environment().from_string(
                        self.tts_command)

                tts_command_str = self.jinja2_template.render(**format_args)
            else:
                # Interpret TTS command as a formatted string
                tts_command_str = self.tts_command.format(**format_args)

            say_command = shlex.split(tts_command_str)

            if not self.text_on_stdin:
                # Text as command-line arguments
                say_command += [say.text]

            _LOGGER.debug(say_command)

            # Default: WAV audio on stdout, text as command-line argument
            proc_stdin: typing.Optional[int] = None
            proc_stdout: typing.Optional[int] = subprocess.PIPE
            proc_input: typing.Optional[bytes] = None

            if self.use_temp_wav:
                # WAV audio from file
                proc_stdout = None

            if self.text_on_stdin:
                # Text from standard in
                proc_stdin = subprocess.PIPE
                proc_input = say.text.encode()

            # Run TTS process; communicate() also waits for it to exit.
            # NOTE(review): this is a blocking call inside a coroutine and
            # stalls the event loop while synthesizing.
            proc = subprocess.Popen(say_command,
                                    stdin=proc_stdin,
                                    stdout=proc_stdout)
            wav_bytes, _ = proc.communicate(input=proc_input)

            # Explicit checks instead of ``assert`` so they still run under -O
            if proc.returncode != 0:
                raise RuntimeError(f"Non-zero exit code: {proc.returncode}")

            if self.use_temp_wav and temp_wav_path:
                with open(temp_wav_path, "rb") as wav_file:
                    wav_bytes = wav_file.read()

            if not wav_bytes:
                raise RuntimeError("No WAV data received")

            _LOGGER.debug("Got %s byte(s) of WAV data", len(wav_bytes))

            volume = self.volume
            if say.volume is not None:
                # Override with message volume
                volume = say.volume

            if volume is not None:
                wav_bytes = TtsHermesMqtt.change_volume(wav_bytes, volume)

            finished_event = asyncio.Event()

            # Play WAV
            if self.play_command:
                try:
                    # Play locally. Use the same resolved language as the TTS
                    # command (say.lang alone may be None).
                    play_command = shlex.split(
                        self.play_command.format(lang=language))
                    _LOGGER.debug(play_command)

                    subprocess.run(play_command,
                                   input=wav_bytes,
                                   check=True)

                    # Don't wait for playFinished
                    finished_event.set()
                except Exception as e:
                    _LOGGER.exception("play_command")
                    yield AudioPlayError(
                        error=str(e),
                        context=say.id,
                        site_id=say.site_id,
                        session_id=say.session_id,
                    )
            else:
                # Publish playBytes; the event is looked up by request id
                request_id = say.id or str(uuid4())
                self.play_finished_events[request_id] = finished_event

                yield (
                    AudioPlayBytes(wav_bytes=wav_bytes),
                    {
                        "site_id": say.site_id,
                        "request_id": request_id
                    },
                )

            try:
                # Wait for audio to finished playing or timeout
                wav_duration = get_wav_duration(wav_bytes)
                wav_timeout = wav_duration + self.finished_timeout_extra

                _LOGGER.debug("Waiting for play finished (timeout=%s)",
                              wav_timeout)
                await asyncio.wait_for(finished_event.wait(),
                                       timeout=wav_timeout)
            except asyncio.TimeoutError:
                _LOGGER.warning(
                    "Did not receive playFinished before timeout")

        except Exception as e:
            _LOGGER.exception("handle_say")
            yield TtsError(
                error=str(e),
                context=say.id,
                site_id=say.site_id,
                session_id=say.session_id,
            )
        finally:
            # Remove the temporary WAV *before* yielding. If the consumer
            # stops iterating at the yield below, code after it never runs.
            if temp_wav_path:
                try:
                    os.unlink(temp_wav_path)
                except OSError:
                    # Best effort: file may already be gone
                    pass

            yield TtsSayFinished(id=say.id,
                                 site_id=say.site_id,
                                 session_id=say.session_id)
Example #11
0
    def test_no_play(self):
        """Exercise the text-to-speech HTTP endpoint with ``play=false``.

        The WAV is still returned over HTTP and published on playBytes, but
        the captured MQTT sequence must be bracketed by audioServer toggle
        messages for this site. Expected order: toggleOff, tts/say, playBytes,
        tts/sayFinished, toggleOn.
        """
        sentence = "This is a test."
        for subscription in (
            TtsSay.topic(),
            AudioPlayBytes.topic(site_id=self.site_id),
            TtsSayFinished.topic(),
            AudioToggleOff.topic(),
            AudioToggleOn.topic(),
        ):
            self.client.subscribe(subscription)

        query = {
            "siteId": self.site_id,
            "sessionId": self.session_id,
            "play": "false",
        }
        http_response = requests.post(
            self.api_url("text-to-speech"), data=sentence, params=query
        )
        self.check_status(http_response)

        wav_data = http_response.content
        self.assertGreater(len(wav_data), 0)

        def next_message():
            """Pop the next captured MQTT message (5 second timeout)."""
            return self.mqtt_messages.get(timeout=5)

        # 1. audioServer/toggleOff for this site
        off_msg = next_message()
        self.assertTrue(AudioToggleOff.is_topic(off_msg.topic))
        toggle_off = AudioToggleOff.from_dict(json.loads(off_msg.payload))
        self.assertEqual(toggle_off.site_id, self.site_id)

        # 2. tts/say echoes the request
        say_msg = next_message()
        self.assertTrue(TtsSay.is_topic(say_msg.topic))
        say = TtsSay.from_dict(json.loads(say_msg.payload))
        self.assertEqual(say.site_id, self.site_id)
        self.assertEqual(say.session_id, self.session_id)
        self.assertEqual(say.text, sentence)

        # 3. audioServer/playBytes (ignored by the audio output system)
        play_msg = next_message()
        self.assertTrue(AudioPlayBytes.is_topic(play_msg.topic))
        self.assertEqual(
            AudioPlayBytes.get_site_id(play_msg.topic), self.site_id
        )
        self.assertEqual(play_msg.payload, wav_data)

        # 4. tts/sayFinished for this site and session
        finished_msg = next_message()
        self.assertTrue(TtsSayFinished.is_topic(finished_msg.topic))
        finished = TtsSayFinished.from_dict(json.loads(finished_msg.payload))
        self.assertEqual(finished.site_id, self.site_id)
        self.assertEqual(finished.session_id, self.session_id)

        # 5. audioServer/toggleOn for this site
        on_msg = next_message()
        self.assertTrue(AudioToggleOn.is_topic(on_msg.topic))
        toggle_on = AudioToggleOn.from_dict(json.loads(on_msg.payload))
        self.assertEqual(toggle_on.site_id, self.site_id)
Example #12
0
    async def maybe_play_sound(
        self,
        sound_name: str,
        site_id: typing.Optional[str] = None,
        request_id: typing.Optional[str] = None,
        block: bool = True,
    ) -> typing.AsyncIterable[SoundsType]:
        """Play a WAV sound through audio out if it exists.

        ``sound_name`` is looked up in ``self.sound_paths``. The WAV file is
        yielded as an ``(AudioPlayBytes, topic_args)`` pair, bracketed by
        hotword/ASR toggle-off and toggle-on messages for the site.

        Args:
            sound_name: key into ``self.sound_paths``.
            site_id: target site; defaults to ``self.site_id``.
            request_id: playBytes request id; a UUID is generated if omitted.
            block: when true, wait for the AudioPlayFinished event registered
                under ``request_id``, at most the WAV duration plus
                ``self.sound_timeout_extra`` seconds.
        """
        site_id = site_id or self.site_id
        wav_path = self.sound_paths.get(sound_name)
        if wav_path:
            if not wav_path.is_file():
                _LOGGER.error("WAV does not exist: %s", str(wav_path))
                return

            _LOGGER.debug("Playing WAV %s", str(wav_path))
            wav_bytes = wav_path.read_bytes()

            # Register an event under the request id. It is presumably set
            # elsewhere when the matching AudioPlayFinished arrives (handler
            # not shown here).
            request_id = request_id or str(uuid4())
            finished_event = asyncio.Event()
            finished_id = request_id
            self.message_events[AudioPlayFinished][
                finished_id] = finished_event

            # Disable ASR/hotword at site so the sound is not picked up
            yield HotwordToggleOff(site_id=site_id,
                                   reason=HotwordToggleReason.PLAY_AUDIO)
            yield AsrToggleOff(site_id=site_id,
                               reason=AsrToggleReason.PLAY_AUDIO)

            # Wait for messages to be delivered
            await asyncio.sleep(self.toggle_delay)

            try:
                yield (
                    AudioPlayBytes(wav_bytes=wav_bytes),
                    {
                        "site_id": site_id,
                        "request_id": request_id
                    },
                )

                # Wait for finished event or WAV duration
                if block:
                    wav_duration = get_wav_duration(wav_bytes)
                    wav_timeout = wav_duration + self.sound_timeout_extra
                    _LOGGER.debug("Waiting for playFinished (timeout=%s)",
                                  wav_timeout)
                    await asyncio.wait_for(finished_event.wait(),
                                           timeout=wav_timeout)
            except asyncio.TimeoutError:
                # The wait above is for AudioPlayFinished, not tts/sayFinished
                _LOGGER.warning("Did not receive playFinished before timeout")
            except Exception:
                _LOGGER.exception("maybe_play_sound")
            finally:
                # Wait for audio to finish playing
                await asyncio.sleep(self.toggle_delay)

                # Re-enable ASR/hotword at site
                yield HotwordToggleOn(site_id=site_id,
                                      reason=HotwordToggleReason.PLAY_AUDIO)
                yield AsrToggleOn(site_id=site_id,
                                  reason=AsrToggleReason.PLAY_AUDIO)
 def messages():
     """Yield a single ``(AudioPlayBytes, topic_args)`` pair.

     NOTE(review): ``wav_bytes``, ``site_id`` and ``request_id`` are free
     names taken from an enclosing scope that is not shown here.
     """
     play_bytes = AudioPlayBytes(wav_bytes=wav_bytes)
     topic_args = {"site_id": site_id, "request_id": request_id}
     yield play_bytes, topic_args
Example #14
0
    async def handle_say(
        self, say: TtsSay
    ) -> typing.AsyncIterable[typing.Union[TtsSayFinished, typing.Tuple[
            AudioPlayBytes, TopicArgs], TtsError, AudioPlayError, ]]:
        """Synthesize ``say.text`` with a named voice, using an on-disk cache.

        Voice and cache:
        - ``say.lang`` selects the voice from ``self.voices``, falling back to
          ``self.default_voice``.
        - When ``self.cache_dir`` is set, WAV audio is looked up by a hash of
          the voice's ``cache_id`` and the text.
        - Newly synthesized audio is saved there before volume adjustment.

        Volume: the message volume, when set, overrides ``self.volume``.

        Playback: the WAV is either piped to ``self.play_command`` or yielded
        as an ``(AudioPlayBytes, topic_args)`` pair. Afterwards the method
        waits for the playback-finished event, at most the WAV duration plus
        ``self.finished_timeout_extra`` seconds.

        Yields ``AudioPlayError`` if local playback fails, ``TtsError`` for
        any other failure, and ``TtsSayFinished`` last.
        """
        wav_bytes: typing.Optional[bytes] = None

        try:
            # Resolve the voice; explicit raise (not ``assert``) so the check
            # still runs under -O
            voice_name = say.lang or self.default_voice
            voice = self.voices.get(voice_name)
            if voice is None:
                raise ValueError(f"No voice named {voice_name}")

            # Cache key is derived from the voice's cache id and the text
            sentence_hash = TtsHermesMqtt.get_sentence_hash(
                voice.cache_id, say.text)
            from_cache = False
            cached_wav_path = None

            if self.cache_dir:
                # Create cache directory in profile if it doesn't exist
                self.cache_dir.mkdir(parents=True, exist_ok=True)

                # Load from cache
                cached_wav_path = self.cache_dir / f"{sentence_hash.hexdigest()}.wav"

                if cached_wav_path.is_file():
                    # Use WAV file from cache
                    _LOGGER.debug("Using WAV from cache: %s", cached_wav_path)
                    wav_bytes = cached_wav_path.read_bytes()
                    # An empty cache file is not a usable hit. Treat it as a
                    # miss so the freshly synthesized audio overwrites it
                    # below.
                    from_cache = bool(wav_bytes)

            if not wav_bytes:
                # Run text to speech
                _LOGGER.debug("Synthesizing '%s' (voice=%s)", say.text,
                              voice_name)
                wav_bytes = self.synthesize(voice, say.text)

                if not wav_bytes:
                    raise RuntimeError("No WAV data synthesized")

                _LOGGER.debug("Got %s byte(s) of WAV data", len(wav_bytes))

            # Adjust volume
            volume = self.volume
            if say.volume is not None:
                # Override with message volume
                volume = say.volume

            # The cache stores the unadjusted audio
            original_wav_bytes = wav_bytes
            if volume is not None:
                wav_bytes = TtsHermesMqtt.change_volume(wav_bytes, volume)

            finished_event = asyncio.Event()

            # Play WAV
            if self.play_command:
                try:
                    # Play locally
                    play_command = shlex.split(
                        self.play_command.format(lang=say.lang))
                    _LOGGER.debug(play_command)

                    subprocess.run(play_command, input=wav_bytes, check=True)

                    # Don't wait for playFinished
                    finished_event.set()
                except Exception as e:
                    _LOGGER.exception("play_command")
                    yield AudioPlayError(
                        error=str(e),
                        context=say.id,
                        site_id=say.site_id,
                        session_id=say.session_id,
                    )
            else:
                # Publish playBytes; the event is looked up by request id
                request_id = say.id or str(uuid4())
                self.play_finished_events[request_id] = finished_event

                yield (
                    AudioPlayBytes(wav_bytes=wav_bytes),
                    {
                        "site_id": say.site_id,
                        "request_id": request_id
                    },
                )

            # Save to cache (best effort). Write to a sibling temp file and
            # rename it into place, so an interrupted write never leaves a
            # truncated WAV that would later be served as a cache hit. A
            # failure here must not turn already-played audio into a TtsError.
            if (not from_cache) and cached_wav_path:
                temp_cache_path = cached_wav_path.with_name(
                    cached_wav_path.name + ".tmp")
                try:
                    temp_cache_path.write_bytes(original_wav_bytes)
                    temp_cache_path.replace(cached_wav_path)
                except OSError:
                    _LOGGER.exception("Failed to save WAV to cache: %s",
                                      cached_wav_path)

            try:
                # Wait for audio to finished playing or timeout
                wav_duration = TtsHermesMqtt.get_wav_duration(wav_bytes)
                wav_timeout = wav_duration + self.finished_timeout_extra

                _LOGGER.debug("Waiting for play finished (timeout=%s)",
                              wav_timeout)
                await asyncio.wait_for(finished_event.wait(),
                                       timeout=wav_timeout)
            except asyncio.TimeoutError:
                _LOGGER.warning("Did not receive playFinished before timeout")

        except Exception as e:
            _LOGGER.exception("handle_say")
            yield TtsError(
                error=str(e),
                context=say.id,
                site_id=say.site_id,
                session_id=say.session_id,
            )
        finally:
            yield TtsSayFinished(id=say.id,
                                 site_id=say.site_id,
                                 session_id=say.session_id)