async def end_session(
    self, reason: DialogueSessionTerminationReason, site_id: str
) -> typing.AsyncIterable[typing.Union[EndSessionType, StartSessionType, SayType]]:
    """Close the active session, then move on to the next queued one.

    Yields, in order:

    * ``AsrStopListening`` for the session's site (skipped when the session
      was started as a notification),
    * ``DialogueSessionEnded`` carrying ``reason`` and the session's custom
      data, addressed to ``site_id``,
    * either every message produced by ``start_session`` for the next queued
      session, or ``HotwordToggleOn`` when the queue is empty.

    ``self.session`` is cleared before the queue is inspected.
    """
    assert self.session is not None, "No session"
    ending = self.session

    needs_asr_stop = (
        ending.start_session.init.type != DialogueActionType.NOTIFICATION
    )
    if needs_asr_stop:
        # Only non-notification sessions get a stop-listening message
        yield AsrStopListening(site_id=ending.site_id, session_id=ending.session_id)

    termination = DialogueSessionTermination(reason=reason)
    yield DialogueSessionEnded(
        site_id=site_id,
        session_id=ending.session_id,
        custom_data=ending.custom_data,
        termination=termination,
    )

    self.session = None

    if not self.session_queue:
        # Nothing waiting: hand control back to the hotword detector
        yield HotwordToggleOn(
            site_id=ending.site_id, reason=HotwordToggleReason.DIALOGUE_SESSION
        )
        return

    _LOGGER.debug("Handling queued session")
    queued = self.session_queue.popleft()
    async for started in self.start_session(queued):
        yield started
async def say(
    self,
    text: str,
    site_id="default",
    session_id="",
    request_id: typing.Optional[str] = None,
    block: bool = True,
) -> typing.AsyncIterable[
    typing.Union[
        TtsSay, HotwordToggleOn, HotwordToggleOff, AsrToggleOn, AsrToggleOff
    ]
]:
    """Speak ``text`` via TTS with hotword and ASR switched off meanwhile.

    An ``asyncio.Event`` is registered in ``self.message_events`` under
    ``TtsSayFinished`` and the request id (``request_id`` or a fresh UUID).
    The generator then yields the toggle-off messages, pauses for
    ``self.toggle_delay``, and yields the ``TtsSay`` request.

    When ``block`` is true it waits for the registered event. The timeout is
    at least 10 seconds, extended to ``len(text) / self.say_chars_per_second``
    when that rate is positive. Timeouts and other errors are logged, not
    raised. Hotword and ASR are always toggled back on at the end, after
    another ``self.toggle_delay`` pause.
    """
    say_id = request_id or str(uuid4())
    done = asyncio.Event()
    self.message_events[TtsSayFinished][say_id] = done

    # Mute hotword/ASR at the site while speaking
    yield HotwordToggleOff(site_id=site_id, reason=HotwordToggleReason.TTS_SAY)
    yield AsrToggleOff(site_id=site_id, reason=AsrToggleReason.TTS_SAY)

    # Give the toggle messages time to be delivered
    await asyncio.sleep(self.toggle_delay)

    try:
        _LOGGER.debug("Say: %s", text)
        yield TtsSay(id=say_id, site_id=site_id, session_id=session_id, text=text)

        if block:
            # Fixed floor, stretched for long texts when a rate is configured
            timeout = 10.0
            if self.say_chars_per_second > 0:
                estimated = len(text) / self.say_chars_per_second
                timeout = max(timeout, estimated)

            _LOGGER.debug(
                "Waiting for sayFinished (id=%s, timeout=%s)", say_id, timeout
            )
            await asyncio.wait_for(done.wait(), timeout=timeout)
    except asyncio.TimeoutError:
        _LOGGER.warning("Did not receive sayFinished before timeout")
    except Exception:
        _LOGGER.exception("say")
    finally:
        # Let the audio finish playing before listening again
        await asyncio.sleep(self.toggle_delay)

        yield HotwordToggleOn(site_id=site_id, reason=HotwordToggleReason.TTS_SAY)
        yield AsrToggleOn(site_id=site_id, reason=AsrToggleReason.TTS_SAY)
async def play_wav_data(
    self, wav_bytes: bytes, site_id: typing.Optional[str] = None
) -> AudioPlayFinished:
    """Publish WAV audio for playback and wait for the matching reply.

    Hotword and ASR are toggled off before the audio is published and toggled
    back on afterwards, whether or not playback succeeded.

    Args:
        wav_bytes: WAV data to send in an ``AudioPlayBytes`` message.
        site_id: Target site; falls back to ``self.site_id`` when falsy.

    Returns:
        The ``AudioPlayFinished`` message whose id matches this request.

    Raises:
        RuntimeError: If ``self.sound_system`` is ``"dummy"``, or if an
            ``AudioPlayError`` is received instead of a finished message.
    """
    if self.sound_system == "dummy":
        raise RuntimeError("No audio output system configured")

    site_id = site_id or self.site_id
    request_id = str(uuid4())

    def handle_finished():
        # Consume incoming (topic, message) pairs until one ends this request
        while True:
            _, message = yield

            if isinstance(message, AudioPlayError):
                return message

            if isinstance(message, AudioPlayFinished) and (
                message.id == request_id
            ):
                return message

    def messages():
        topic_args = {"site_id": site_id, "request_id": request_id}
        yield (AudioPlayBytes(wav_bytes=wav_bytes), topic_args)

    message_types: typing.List[typing.Type[Message]] = [
        AudioPlayFinished,
        AudioPlayError,
    ]

    # Silence hotword/ASR for the duration of playback
    self.publish(
        HotwordToggleOff(site_id=site_id, reason=HotwordToggleReason.PLAY_AUDIO)
    )
    self.publish(AsrToggleOff(site_id=site_id, reason=AsrToggleReason.PLAY_AUDIO))

    try:
        # Only one reply is expected; keep the last one seen
        outcome = None
        replies = self.publish_wait(handle_finished(), messages(), message_types)
        async for reply in replies:
            outcome = reply

        if isinstance(outcome, AudioPlayError):
            _LOGGER.error(outcome)
            raise RuntimeError(outcome.error)

        assert isinstance(outcome, AudioPlayFinished)
        return outcome
    finally:
        # Restore hotword/ASR
        self.publish(
            HotwordToggleOn(site_id=site_id, reason=HotwordToggleReason.PLAY_AUDIO)
        )
        self.publish(AsrToggleOn(site_id=site_id, reason=AsrToggleReason.PLAY_AUDIO))
async def handle_end(
    self, end_session: DialogueEndSession
) -> typing.AsyncIterable[typing.Union[EndSessionType, StartSessionType, SayType]]:
    """Process an end-session request for a known session.

    Looks the session up in ``self.all_sessions``. Unknown ids are logged
    and ignored. Otherwise the optional closing text is spoken through
    ``self.say``, custom data from the request replaces the session's, and
    the session is ended with a nominal termination reason (allowing the
    next queued session to start).

    Any exception is logged and reported as a ``DialogueError`` followed by
    a ``HotwordToggleOn`` for the session's site.
    """
    session_id = end_session.session_id
    site_session = self.all_sessions.get(session_id)
    if not site_session:
        _LOGGER.warning("No session for id %s. Cannot end.", end_session.session_id)
        return

    try:
        if end_session.text:
            # Speak the closing text before the session is torn down
            spoken = self.say(
                end_session.text,
                site_id=site_session.site_id,
                session_id=end_session.session_id,
            )
            async for tts_message in spoken:
                yield tts_message

        # A request without custom data leaves the session's data untouched
        if end_session.custom_data is not None:
            site_session.custom_data = end_session.custom_data

        _LOGGER.debug("Session ended nominally: %s", site_session.session_id)
        ending = self.end_session(
            DialogueSessionTerminationReason.NOMINAL,
            site_id=site_session.site_id,
            session_id=site_session.session_id,
            start_next_session=True,
        )
        async for end_message in ending:
            yield end_message
    except Exception as err:
        _LOGGER.exception("handle_end")
        yield DialogueError(
            error=str(err),
            context=str(end_session),
            site_id=site_session.site_id,
            session_id=end_session.session_id,
        )

        # Re-arm the hotword after a failure
        yield HotwordToggleOn(
            site_id=site_session.site_id,
            reason=HotwordToggleReason.DIALOGUE_SESSION,
        )
def on_connect(self, client, userdata, flags, rc):
    """Subscribe to hotword toggle and audio frame topics on connect.

    Audio frames are taken from ``self.audioframe_topics`` when that is
    non-empty; otherwise a wildcard topic covering all siteIds is used.
    Failures are logged rather than raised.
    """
    try:
        wanted = [HotwordToggleOn.topic(), HotwordToggleOff.topic()]

        # Specific siteIds if configured, otherwise every siteId
        wanted += self.audioframe_topics or [AudioFrame.topic(siteId="#")]

        for topic in wanted:
            self.client.subscribe(topic)
            _LOGGER.debug("Subscribed to %s", topic)
    except Exception:
        _LOGGER.exception("on_connect")
async def handle_text_captured(
    self, text_captured: AsrTextCaptured
) -> typing.AsyncIterable[typing.Union[AsrStopListening, HotwordToggleOn, NluQuery]]:
    """Turn an ASR transcription into an NLU query for its session.

    Messages without a session id, or for a session not present in
    ``self.all_sessions``, are logged and dropped. Otherwise the
    transcription is stored on the session and three messages are yielded:
    ``AsrStopListening``, ``HotwordToggleOn`` and ``NluQuery``.
    Exceptions are logged rather than raised.
    """
    try:
        captured_session_id = text_captured.session_id
        if not captured_session_id:
            _LOGGER.warning("Missing session id on text captured message.")
            return

        session = self.all_sessions.get(text_captured.session_id)
        if session is None:
            _LOGGER.warning(
                "No session for id %s. Dropping captured text from ASR.",
                text_captured.session_id,
            )
            return

        _LOGGER.debug("Received text: %s", text_captured.text)

        # Remember the transcription on the session
        session.text_captured = text_captured

        # ASR is done for this turn; the hotword may listen again
        yield AsrStopListening(
            site_id=text_captured.site_id, session_id=session.session_id
        )
        yield HotwordToggleOn(
            site_id=text_captured.site_id,
            reason=HotwordToggleReason.DIALOGUE_SESSION,
        )

        # Values on the captured message win over the session's own
        yield NluQuery(
            input=text_captured.text,
            intent_filter=session.intent_filter or self.default_intent_filter,
            session_id=session.session_id,
            site_id=session.site_id,
            wakeword_id=text_captured.wakeword_id or session.wakeword_id,
            lang=text_captured.lang or session.lang,
        )
    except Exception:
        _LOGGER.exception("handle_text_captured")
def on_message(self, client, userdata, msg):
    """Dispatch an incoming MQTT message.

    Toggle-on/off messages update ``self.enabled`` when
    ``self._check_siteId`` accepts their JSON payload. While enabled, audio
    frames on an accepted topic are passed to ``self.handle_audio_frame``
    and every result is published; ``HotwordDetected`` results are published
    with their wake word id. Exceptions are logged rather than raised.
    """
    try:
        topic = msg.topic

        # Audio frames are too frequent to log individually
        if not topic.endswith("/audioFrame"):
            _LOGGER.debug("Received %s byte(s) on %s", len(msg.payload), msg.topic)

        # Enable/disable requests
        if topic == HotwordToggleOn.topic():
            toggle_payload = json.loads(msg.payload or "{}")
            if self._check_siteId(toggle_payload):
                self.enabled = True
                self.first_audio = True
                _LOGGER.debug("Enabled")
        elif topic == HotwordToggleOff.topic():
            toggle_payload = json.loads(msg.payload or "{}")
            if self._check_siteId(toggle_payload):
                self.enabled = False
                _LOGGER.debug("Disabled")

        if not self.enabled:
            return

        # Everything below concerns audio frames only
        if not AudioFrame.is_topic(topic):
            return

        # With an explicit topic list, ignore frames from other topics
        if self.audioframe_topics and (topic not in self.audioframe_topics):
            return

        if self.first_audio:
            _LOGGER.debug("Receiving audio")
            self.first_audio = False

        siteId = AudioFrame.get_siteId(topic)
        detections = self.handle_audio_frame(msg.payload, siteId=siteId)
        for wakewordId, result in detections:
            if isinstance(result, HotwordDetected):
                # Detection topic contains the wake word id
                self.publish(result, wakewordId=wakewordId)
            else:
                self.publish(result)
    except Exception:
        _LOGGER.exception("on_message")
async def end_session(
    self,
    reason: DialogueSessionTerminationReason,
    site_id: str,
    session_id: str,
    start_next_session: bool,
) -> typing.AsyncIterable[typing.Union[EndSessionType, StartSessionType, SayType]]:
    """End the session with ``session_id`` and service the site's queue.

    The session is removed from ``self.all_sessions`` and
    ``self.session_by_site``. For a known session this yields
    ``AsrStopListening`` (skipped for notification sessions) and then
    ``DialogueSessionEnded``; an unknown id is only logged.

    Afterwards the queue for ``site_id`` is checked:

    * empty queue: ``HotwordToggleOn`` is yielded;
    * non-empty queue and ``start_next_session``: the next queued session is
      started and its messages are yielded;
    * non-empty queue without ``start_next_session``: nothing more happens.
    """
    ended = self.all_sessions.pop(session_id, None)
    if not ended:
        _LOGGER.warning("No session for id %s", session_id)
    else:
        # Forget the site -> session mapping as well
        self.session_by_site.pop(ended.site_id, None)

        if ended.start_session.init.type != DialogueActionType.NOTIFICATION:
            # Only non-notification sessions get a stop-listening message
            yield AsrStopListening(
                site_id=ended.site_id, session_id=ended.session_id
            )

        yield DialogueSessionEnded(
            site_id=site_id,
            session_id=ended.session_id,
            custom_data=ended.custom_data,
            termination=DialogueSessionTermination(reason=reason),
        )

    pending = self.session_queue_by_site[site_id]
    if not pending:
        # No queued sessions: re-enable the hotword
        yield HotwordToggleOn(
            site_id=site_id, reason=HotwordToggleReason.DIALOGUE_SESSION
        )
    elif start_next_session:
        _LOGGER.debug("Handling queued session")
        async for started in self.start_session(pending.popleft()):
            yield started
async def maybe_play_sound(
    self,
    sound_name: str,
    site_id: typing.Optional[str] = None,
    request_id: typing.Optional[str] = None,
    block: bool = True,
) -> typing.AsyncIterable[SoundsType]:
    """Play the sound registered under ``sound_name``, if there is one.

    Nothing is yielded when sound is disabled for the site, when no path is
    registered for ``sound_name``, when the path does not exist, or when a
    sound directory contains no file with a suffix in
    ``self.sound_suffixes``.

    Otherwise the sound is converted to WAV (and volume-adjusted when
    ``self.volume`` is set and not 1.0). Hotword and ASR are toggled off,
    the audio is yielded as an ``(AudioPlayBytes, topic-args)`` pair, and
    hotword and ASR are toggled back on at the end. Timeouts and other
    errors during playback are logged, not raised.

    Args:
        sound_name: Key into ``self.sound_paths``. The path may be a file or
            a directory; a random matching file is picked from a directory.
        site_id: Target site; falls back to ``self.site_id`` when falsy.
        request_id: Id for the play request; a UUID is generated when falsy.
        block: When true, wait for the play-finished event for at most the
            WAV duration plus ``self.sound_timeout_extra``.
    """
    # Resolve the default site first so the no-sound check also applies when
    # the caller did not name a site explicitly.
    site_id = site_id or self.site_id

    if site_id in self.no_sound:
        _LOGGER.debug("Sound is disabled for site %s", site_id)
        return

    sound_path = self.sound_paths.get(sound_name)
    if not sound_path:
        return

    if sound_path.is_dir():
        # Pick one of the sound files below the directory at random
        sound_file_paths = [
            p
            for p in sound_path.rglob("*")
            if p.is_file() and (p.suffix in self.sound_suffixes)
        ]

        if not sound_file_paths:
            _LOGGER.debug("No sound files found in %s", str(sound_path))
            return

        sound_path = random.choice(sound_file_paths)
    elif not sound_path.is_file():
        _LOGGER.error("Sound does not exist: %s", str(sound_path))
        return

    _LOGGER.debug("Playing sound %s", str(sound_path))

    # Convert to WAV
    wav_bytes = DialogueHermesMqtt.convert_to_wav(sound_path)

    if (self.volume is not None) and (self.volume != 1.0):
        wav_bytes = DialogueHermesMqtt.change_volume(wav_bytes, self.volume)

    # Register the event that signals the end of playback
    request_id = request_id or str(uuid4())
    finished_event = asyncio.Event()
    finished_id = request_id
    self.message_events[AudioPlayFinished][finished_id] = finished_event

    # Disable ASR/hotword at site
    yield HotwordToggleOff(site_id=site_id, reason=HotwordToggleReason.PLAY_AUDIO)
    yield AsrToggleOff(site_id=site_id, reason=AsrToggleReason.PLAY_AUDIO)

    # Wait for messages to be delivered
    await asyncio.sleep(self.toggle_delay)

    try:
        yield (
            AudioPlayBytes(wav_bytes=wav_bytes),
            {"site_id": site_id, "request_id": request_id},
        )

        # Wait for finished event or WAV duration
        if block:
            wav_duration = get_wav_duration(wav_bytes)
            wav_timeout = wav_duration + self.sound_timeout_extra
            _LOGGER.debug(
                "Waiting for playFinished (id=%s, timeout=%s)",
                finished_id,
                wav_timeout,
            )
            await asyncio.wait_for(finished_event.wait(), timeout=wav_timeout)
    except asyncio.TimeoutError:
        # The awaited event is playFinished, not sayFinished
        _LOGGER.warning("Did not receive playFinished before timeout")
    except Exception:
        _LOGGER.exception("maybe_play_sound")
    finally:
        # Wait for audio to finish playing
        await asyncio.sleep(self.toggle_delay)

        # Re-enable ASR/hotword at site
        yield HotwordToggleOn(site_id=site_id, reason=HotwordToggleReason.PLAY_AUDIO)
        yield AsrToggleOn(site_id=site_id, reason=AsrToggleReason.PLAY_AUDIO)
async def handle_text_captured(
    self, text_captured: AsrTextCaptured
) -> typing.AsyncIterable[
    typing.Union[
        AsrStopListening, HotwordToggleOn, NluQuery, NluIntentNotRecognized
    ]
]:
    """Turn an ASR transcription into an NLU query, or reject it.

    Messages without a session id, or for a session not present in
    ``self.all_sessions``, are logged and dropped. Otherwise the
    transcription is stored on the session, and ``AsrStopListening`` and
    ``HotwordToggleOn`` are yielded.

    If ``self.min_asr_confidence`` is set and the transcription's likelihood
    is below it, ``NluIntentNotRecognized`` is yielded instead of querying
    the NLU. Otherwise an ``NluQuery`` is yielded, carrying the session's
    custom data, the ASR confidence, and any custom entities from the
    session's hotword detection. Exceptions are logged rather than raised.
    """
    try:
        if not text_captured.session_id:
            _LOGGER.warning("Missing session id on text captured message.")
            return

        session = self.all_sessions.get(text_captured.session_id)
        if session is None:
            _LOGGER.warning(
                "No session for id %s. Dropping captured text from ASR.",
                text_captured.session_id,
            )
            return

        _LOGGER.debug("Received text: %s", text_captured.text)

        # Remember the transcription on the session
        session.text_captured = text_captured

        # ASR is done for this turn; the hotword may listen again
        yield AsrStopListening(
            site_id=text_captured.site_id, session_id=session.session_id
        )
        yield HotwordToggleOn(
            site_id=text_captured.site_id,
            reason=HotwordToggleReason.DIALOGUE_SESSION,
        )

        below_threshold = (self.min_asr_confidence is not None) and (
            text_captured.likelihood < self.min_asr_confidence
        )

        if below_threshold:
            # Transcription is below threshold.
            # Skip the NLU query and reject as "not recognized".
            _LOGGER.debug(
                "Transcription is below confidence threshold (%s < %s): %s",
                text_captured.likelihood,
                self.min_asr_confidence,
                text_captured.text,
            )

            yield NluIntentNotRecognized(
                input=text_captured.text,
                site_id=session.site_id,
                session_id=session.session_id,
            )
            return

        # Carry over custom entities from the hotword detection, if any
        detected = session.detected
        custom_entities: typing.Optional[typing.Dict[str, typing.Any]] = (
            detected.custom_entities if detected else None
        )

        # Values on the captured message win over the session's own
        yield NluQuery(
            input=text_captured.text,
            intent_filter=session.intent_filter or self.default_intent_filter,
            session_id=session.session_id,
            site_id=session.site_id,
            wakeword_id=text_captured.wakeword_id or session.wakeword_id,
            lang=text_captured.lang or session.lang,
            custom_data=session.custom_data,
            asr_confidence=text_captured.likelihood,
            custom_entities=custom_entities,
        )
    except Exception:
        _LOGGER.exception("handle_text_captured")
async def maybe_play_sound(
    self,
    sound_name: str,
    site_id: typing.Optional[str] = None,
    request_id: typing.Optional[str] = None,
    block: bool = True,
) -> typing.AsyncIterable[SoundsType]:
    """Play the WAV file registered under ``sound_name``, if there is one.

    Nothing is yielded when no path is registered for ``sound_name`` or when
    the registered path is not a file. Otherwise hotword and ASR are toggled
    off, the file's bytes are yielded as an ``(AudioPlayBytes, topic-args)``
    pair, and hotword and ASR are toggled back on at the end. Timeouts and
    other errors during playback are logged, not raised.

    Args:
        sound_name: Key into ``self.sound_paths``.
        site_id: Target site; falls back to ``self.site_id`` when falsy.
        request_id: Id for the play request; a UUID is generated when falsy.
        block: When true, wait for the play-finished event for at most the
            WAV duration plus ``self.sound_timeout_extra``.
    """
    site_id = site_id or self.site_id
    wav_path = self.sound_paths.get(sound_name)
    if not wav_path:
        return

    if not wav_path.is_file():
        _LOGGER.error("WAV does not exist: %s", str(wav_path))
        return

    _LOGGER.debug("Playing WAV %s", str(wav_path))
    wav_bytes = wav_path.read_bytes()

    # Register the event that signals the end of playback
    request_id = request_id or str(uuid4())
    finished_event = asyncio.Event()
    finished_id = request_id
    self.message_events[AudioPlayFinished][finished_id] = finished_event

    # Disable ASR/hotword at site
    yield HotwordToggleOff(site_id=site_id, reason=HotwordToggleReason.PLAY_AUDIO)
    yield AsrToggleOff(site_id=site_id, reason=AsrToggleReason.PLAY_AUDIO)

    # Wait for messages to be delivered
    await asyncio.sleep(self.toggle_delay)

    try:
        yield (
            AudioPlayBytes(wav_bytes=wav_bytes),
            {"site_id": site_id, "request_id": request_id},
        )

        # Wait for finished event or WAV duration
        if block:
            wav_duration = get_wav_duration(wav_bytes)
            wav_timeout = wav_duration + self.sound_timeout_extra
            _LOGGER.debug("Waiting for playFinished (timeout=%s)", wav_timeout)
            await asyncio.wait_for(finished_event.wait(), timeout=wav_timeout)
    except asyncio.TimeoutError:
        # The awaited event is playFinished, not sayFinished
        _LOGGER.warning("Did not receive playFinished before timeout")
    except Exception:
        _LOGGER.exception("maybe_play_sound")
    finally:
        # Wait for audio to finish playing
        await asyncio.sleep(self.toggle_delay)

        # Re-enable ASR/hotword at site
        yield HotwordToggleOn(site_id=site_id, reason=HotwordToggleReason.PLAY_AUDIO)
        yield AsrToggleOn(site_id=site_id, reason=AsrToggleReason.PLAY_AUDIO)
def test_hotword_toggle_on():
    """Check the MQTT topic reported by HotwordToggleOn."""
    expected_topic = "hermes/hotword/toggleOn"
    actual_topic = HotwordToggleOn.topic()
    assert actual_topic == expected_topic