def on_message(self, client, userdata, msg):
    """Received message from MQTT broker.

    Handles ASR toggle on/off messages, audio frames, and start/stop
    listening requests. Every result produced by a handler is passed to
    ``self.publish``. Exceptions are logged and never propagated, so one
    bad message cannot break the MQTT callback.
    """
    try:
        if not msg.topic.endswith("/audioFrame"):
            # Audio frames arrive continuously; only log other topics
            _LOGGER.debug("Received %s byte(s) on %s", len(msg.payload), msg.topic)

        # Check enable/disable messages
        if msg.topic == AsrToggleOn.topic():
            json_payload = json.loads(msg.payload or "{}")
            if self._check_siteId(json_payload):
                self.enabled = True
                _LOGGER.debug("Enabled")
        elif msg.topic == AsrToggleOff.topic():
            # This branch previously compared against AsrToggleOn.topic() a
            # second time, which made it unreachable.
            json_payload = json.loads(msg.payload or "{}")
            if self._check_siteId(json_payload):
                self.enabled = False
                _LOGGER.debug("Disabled")

        if not self.enabled:
            # Disabled
            return

        if AudioFrame.is_topic(msg.topic):
            # Check siteId
            if (not self.audioframe_topics) or (msg.topic in self.audioframe_topics):
                # Add to all active sessions
                if self.first_audio:
                    _LOGGER.debug("Receiving audio")
                    self.first_audio = False

                siteId = AudioFrame.get_siteId(msg.topic)
                for result in self.handle_audio_frame(msg.payload, siteId=siteId):
                    self.publish(result)
        elif msg.topic == AsrStartListening.topic():
            # hermes/asr/startListening
            json_payload = json.loads(msg.payload)
            if self._check_siteId(json_payload):
                for result in self.start_listening(AsrStartListening(**json_payload)):
                    self.publish(result)
        elif msg.topic == AsrStopListening.topic():
            # hermes/asr/stopListening
            json_payload = json.loads(msg.payload)
            if self._check_siteId(json_payload):
                for result in self.stop_listening(AsrStopListening(**json_payload)):
                    self.publish(result)
    except Exception:
        _LOGGER.exception("on_message")
def messages(): yield AsrStartListening(siteId=self.siteId, sessionId=sessionId) # Break WAV into chunks with io.BytesIO(wav_bytes) as wav_buffer: with wave.open(wav_buffer, "rb") as wav_file: frames_left = wav_file.getnframes() while frames_left > 0: with io.BytesIO() as chunk_buffer: chunk_file: wave.Wave_write = wave.open( chunk_buffer, "wb") with chunk_file: chunk_file.setframerate( wav_file.getframerate()) chunk_file.setsampwidth( wav_file.getsampwidth()) chunk_file.setnchannels( wav_file.getnchannels()) chunk_file.writeframes( wav_file.readframes(frames_per_chunk)) yield ( AudioFrame(wav_data=chunk_buffer.getvalue()), { "siteId": self.siteId }, ) frames_left -= frames_per_chunk yield AsrStopListening(siteId=self.siteId, sessionId=sessionId)
async def end_session(
    self, reason: DialogueSessionTerminationReason, site_id: str
) -> typing.AsyncIterable[typing.Union[EndSessionType, StartSessionType, SayType]]:
    """End current session and start queued session.

    Yields ``AsrStopListening`` unless the session was started as a
    notification, then ``DialogueSessionEnded`` with the given ``reason``.
    Clears ``self.session``, then either starts the next queued session
    (yielding its messages) or, if the queue is empty, yields
    ``HotwordToggleOn`` for the ended session's site.

    Raises:
        AssertionError: if there is no current session.
    """
    assert self.session is not None, "No session"
    session = self.session

    if session.start_session.init.type != DialogueActionType.NOTIFICATION:
        # Stop listening
        yield AsrStopListening(site_id=session.site_id, session_id=session.session_id)

    # NOTE(review): this message uses the site_id argument, while the other
    # messages use session.site_id; presumably they are always equal -- confirm.
    yield DialogueSessionEnded(
        site_id=site_id,
        session_id=session.session_id,
        custom_data=session.custom_data,
        termination=DialogueSessionTermination(reason=reason),
    )

    self.session = None

    # Check session queue
    if self.session_queue:
        _LOGGER.debug("Handling queued session")
        async for start_result in self.start_session(self.session_queue.popleft()):
            yield start_result
    else:
        # Enable hotword if no queued sessions
        yield HotwordToggleOn(
            site_id=session.site_id, reason=HotwordToggleReason.DIALOGUE_SESSION
        )
async def async_test_transcriber_error(self):
    """Check start/stop session with error in transcriber.

    The fake transcriber always raises. Starting the session and sending
    audio must produce no messages. Stopping must still yield
    ``AsrRecordingFinished`` followed by an empty ``AsrTextCaptured``
    (text "", likelihood 0, seconds 0).
    """

    def fake_transcribe(stream, *args):
        """Raise an exception."""
        raise FakeException()

    self.transcriber.transcribe_stream = fake_transcribe

    # Start session
    start_listening = AsrStartListening(
        site_id=self.site_id, session_id=self.session_id, stop_on_silence=False
    )
    result = None
    async for response in self.hermes.on_message_blocking(start_listening):
        result = response

    # No response expected
    self.assertIsNone(result)

    # Send in "audio"
    fake_wav_bytes = self.hermes.to_wav_bytes(secrets.token_bytes(100))
    fake_frame = AudioFrame(wav_bytes=fake_wav_bytes)
    async for response in self.hermes.on_message_blocking(
        fake_frame, site_id=self.site_id
    ):
        result = response

    # No response expected (result is still None from the start phase)
    self.assertIsNone(result)

    # Stop session
    stop_listening = AsrStopListening(
        site_id=self.site_id, session_id=self.session_id
    )
    results = []
    async for response in self.hermes.on_message_blocking(stop_listening):
        results.append(response)

    # Check results for empty transcription
    self.assertEqual(
        results,
        [
            AsrRecordingFinished(site_id=self.site_id, session_id=self.session_id),
            AsrTextCaptured(
                text="",
                likelihood=0,
                seconds=0,
                site_id=self.site_id,
                session_id=self.session_id,
            ),
        ],
    )
async def handle_text_captured(
    self, text_captured: AsrTextCaptured
) -> typing.AsyncIterable[typing.Union[AsrStopListening, HotwordToggleOn, NluQuery]]:
    """Forward captured ASR text for its dialogue session to NLU.

    For a known session: stores the captured text on the session, then yields
    ``AsrStopListening``, ``HotwordToggleOn`` and finally an ``NluQuery``.
    Missing or unknown session ids are logged and the text is dropped.
    Exceptions are logged, not raised.
    """
    try:
        session_id = text_captured.session_id
        if not session_id:
            _LOGGER.warning("Missing session id on text captured message.")
            return

        session = self.all_sessions.get(session_id)
        if session is None:
            _LOGGER.warning(
                "No session for id %s. Dropping captured text from ASR.",
                session_id,
            )
            return

        _LOGGER.debug("Received text: %s", text_captured.text)

        # Remember what the user said for this session
        session.text_captured = text_captured

        # ASR is finished for this turn
        yield AsrStopListening(
            site_id=text_captured.site_id, session_id=session.session_id
        )

        # Let the wake word be detected again
        yield HotwordToggleOn(
            site_id=text_captured.site_id,
            reason=HotwordToggleReason.DIALOGUE_SESSION,
        )

        # Values from the ASR message take priority over session defaults
        intent_filter = session.intent_filter or self.default_intent_filter
        wakeword_id = text_captured.wakeword_id or session.wakeword_id
        lang = text_captured.lang or session.lang

        yield NluQuery(
            input=text_captured.text,
            intent_filter=intent_filter,
            session_id=session.session_id,
            site_id=session.site_id,
            wakeword_id=wakeword_id,
            lang=lang,
        )
    except Exception:
        _LOGGER.exception("handle_text_captured")
def handle_text_captured(
    self, text_captured: AsrTextCaptured
) -> typing.Iterable[typing.Union[AsrStopListening, NluQuery]]:
    """Turn captured ASR text for the current session into an NLU query.

    Yields ``AsrStopListening`` and then ``NluQuery`` using the current
    session's intent filter. Raises nothing: any error (including having
    no current session) is logged.
    """
    try:
        assert self.session, "No session"
        _LOGGER.debug("Received text: %s", text_captured.text)

        # ASR is done for this session
        stop_message = AsrStopListening(
            siteId=self.siteId, sessionId=self.session.sessionId
        )
        yield stop_message

        # Re-read the session after resuming; ask NLU to interpret the text
        current = self.session
        query = NluQuery(
            input=text_captured.text,
            intentFilter=current.intentFilter,
            sessionId=current.sessionId,
        )
        yield query
    except Exception:
        _LOGGER.exception("handle_text_captured")
async def end_session(
    self,
    reason: DialogueSessionTerminationReason,
    site_id: str,
    session_id: str,
    start_next_session: bool,
) -> typing.AsyncIterable[typing.Union[EndSessionType, StartSessionType, SayType]]:
    """End current session and start queued session.

    Removes session ``session_id`` from ``all_sessions`` and its site's
    entry from ``session_by_site``. If the session existed, yields
    ``AsrStopListening`` (skipped for notification sessions) and then
    ``DialogueSessionEnded``. An unknown id only logs a warning.

    Then checks the queue for ``site_id``:

    * queue non-empty and ``start_next_session`` is true: pops the next
      session and yields its start messages;
    * queue non-empty and ``start_next_session`` is false: yields nothing
      more (the hotword is not re-enabled here);
    * queue empty: yields ``HotwordToggleOn`` for ``site_id``.
    """
    site_session = self.all_sessions.pop(session_id, None)
    if site_session:
        # Remove session for site
        self.session_by_site.pop(site_session.site_id, None)

        # End the existing session
        if site_session.start_session.init.type != DialogueActionType.NOTIFICATION:
            # Stop listening
            yield AsrStopListening(
                site_id=site_session.site_id, session_id=site_session.session_id
            )

        yield DialogueSessionEnded(
            site_id=site_id,
            session_id=site_session.session_id,
            custom_data=site_session.custom_data,
            termination=DialogueSessionTermination(reason=reason),
        )
    else:
        _LOGGER.warning("No session for id %s", session_id)

    # Check session queue
    # NOTE(review): indexed without .get(); presumably session_queue_by_site
    # is a defaultdict -- confirm.
    session_queue = self.session_queue_by_site[site_id]
    if session_queue:
        if start_next_session:
            _LOGGER.debug("Handling queued session")
            async for start_result in self.start_session(session_queue.popleft()):
                yield start_result
    else:
        # Enable hotword if no queued sessions
        yield HotwordToggleOn(
            site_id=site_id, reason=HotwordToggleReason.DIALOGUE_SESSION
        )
def on_connect(self, client, userdata, flags, rc):
    """Subscribe to ASR control and audio topics once connected.

    Always subscribes to toggle on/off and start/stop listening. For audio,
    subscribes to each configured frame topic or, if none are configured,
    to every site through the ``+`` wildcard. Errors are logged.
    """
    try:
        control_topics = [
            AsrToggleOn.topic(),
            AsrToggleOff.topic(),
            AsrStartListening.topic(),
            AsrStopListening.topic(),
        ]

        # Specific siteIds when configured, otherwise all siteIds
        audio_topics = (
            list(self.audioframe_topics)
            if self.audioframe_topics
            else [AudioFrame.topic(siteId="+")]
        )

        for topic_name in control_topics + audio_topics:
            self.client.subscribe(topic_name)
            _LOGGER.debug("Subscribed to %s", topic_name)
    except Exception:
        _LOGGER.exception("on_connect")
def messages():
    """Produce the full ASR exchange for ``wav_bytes``: start, audio, stop.

    Closure: reads ``self``, ``session_id``, ``stop_on_silence``,
    ``send_audio_captured``, ``intent_filter``, ``wav_bytes`` and
    ``frames_per_chunk`` from the enclosing scope. Audio is sent as
    ``(AudioSessionFrame, topic_args)`` tuples.
    """
    start = AsrStartListening(
        site_id=self.site_id,
        session_id=session_id,
        stop_on_silence=stop_on_silence,
        send_audio_captured=send_audio_captured,
        intent_filter=intent_filter,
    )
    yield start

    # Stream the WAV as a series of smaller WAV chunks
    total_bytes: int = 0
    with io.BytesIO(wav_bytes) as source:
        chunks = AudioFrame.iter_wav_chunked(source, frames_per_chunk)
        for chunk in chunks:
            total_bytes += len(chunk)
            frame = AudioSessionFrame(wav_bytes=chunk)
            yield frame, {"site_id": self.site_id, "session_id": session_id}

    _LOGGER.debug("Sent %s byte(s) of WAV data", total_bytes)

    yield AsrStopListening(site_id=self.site_id, session_id=session_id)
async def handle_text_captured(
    self, text_captured: AsrTextCaptured
) -> typing.AsyncIterable[
    typing.Union[
        AsrStopListening, HotwordToggleOn, NluQuery, NluIntentNotRecognized
    ]
]:
    """Handle ASR text captured for session.

    For a known session: stores the captured text, yields
    ``AsrStopListening`` and ``HotwordToggleOn``, and then either:

    * ``NluIntentNotRecognized``, if ``self.min_asr_confidence`` is set and
      the transcription's likelihood is below it; or
    * ``NluQuery``, carrying the session's context and any custom entities
      from the session's hotword detection.

    Missing or unknown session ids are logged and the text is dropped.
    Exceptions are logged, not raised.
    """
    try:
        if not text_captured.session_id:
            _LOGGER.warning("Missing session id on text captured message.")
            return

        site_session = self.all_sessions.get(text_captured.session_id)
        if site_session is None:
            _LOGGER.warning(
                "No session for id %s. Dropping captured text from ASR.",
                text_captured.session_id,
            )
            return

        _LOGGER.debug("Received text: %s", text_captured.text)

        # Record result
        site_session.text_captured = text_captured

        # Stop listening
        yield AsrStopListening(
            site_id=text_captured.site_id, session_id=site_session.session_id
        )

        # Enable hotword
        yield HotwordToggleOn(
            site_id=text_captured.site_id,
            reason=HotwordToggleReason.DIALOGUE_SESSION,
        )

        if (self.min_asr_confidence is not None) and (
            text_captured.likelihood < self.min_asr_confidence
        ):
            # Transcription is below threshold.
            # Don't actually do an NLU query, just reject as "not recognized".
            _LOGGER.debug(
                "Transcription is below confidence threshold (%s < %s): %s",
                text_captured.likelihood,
                self.min_asr_confidence,
                text_captured.text,
            )

            yield NluIntentNotRecognized(
                input=text_captured.text,
                site_id=site_session.site_id,
                session_id=site_session.session_id,
            )
        else:
            # Perform query
            custom_entities: typing.Optional[typing.Dict[str, typing.Any]] = None

            # Copy custom entities from hotword detected
            if site_session.detected:
                custom_entities = site_session.detected.custom_entities

            # Values on the ASR message take priority over session values
            yield NluQuery(
                input=text_captured.text,
                intent_filter=site_session.intent_filter
                or self.default_intent_filter,
                session_id=site_session.session_id,
                site_id=site_session.site_id,
                wakeword_id=text_captured.wakeword_id or site_session.wakeword_id,
                lang=text_captured.lang or site_session.lang,
                custom_data=site_session.custom_data,
                asr_confidence=text_captured.likelihood,
                custom_entities=custom_entities,
            )
    except Exception:
        _LOGGER.exception("handle_text_captured")
async def handle_continue(
    self, continue_session: DialogueContinueSession
) -> typing.AsyncIterable[
    typing.Union[
        AsrStartListening, AsrStopListening, HotwordToggleOff, SayType, DialogueError
    ]
]:
    """Continue the existing session.

    Updates the session from ``continue_session``:

    * custom data and language are overwritten only when provided;
    * the intent filter and ``send_intent_not_recognized`` are always
      overwritten;
    * the step counter is incremented.

    It then yields, in order: ``AsrStopListening``, ``HotwordToggleOff``,
    any TTS messages for ``continue_session.text``, and
    ``AsrStartListening``. Finally it schedules
    ``handle_session_timeout`` for the new step. An unknown session id only
    logs a warning. Other errors are logged and yielded as ``DialogueError``.
    """
    site_session = self.all_sessions.get(continue_session.session_id)

    if site_session is None:
        _LOGGER.warning(
            "No session for id %s. Cannot continue.", continue_session.session_id
        )
        return

    try:
        if continue_session.custom_data is not None:
            # Overwrite custom data
            site_session.custom_data = continue_session.custom_data

        if continue_session.lang is not None:
            # Overwrite language
            site_session.lang = continue_session.lang

        site_session.intent_filter = continue_session.intent_filter

        site_session.send_intent_not_recognized = (
            continue_session.send_intent_not_recognized
        )

        # The step number is passed to the timeout task below; presumably
        # this lets it ignore timeouts from earlier steps -- confirm.
        site_session.step += 1

        _LOGGER.debug(
            "Continuing session %s (step=%s)",
            site_session.session_id,
            site_session.step,
        )

        # Stop listening
        yield AsrStopListening(
            site_id=site_session.site_id, session_id=site_session.session_id
        )

        # Ensure hotword is disabled for session
        yield HotwordToggleOff(
            site_id=site_session.site_id,
            reason=HotwordToggleReason.DIALOGUE_SESSION,
        )

        if continue_session.text:
            # Forward to TTS
            async for tts_result in self.say(
                continue_session.text,
                site_id=site_session.site_id,
                session_id=continue_session.session_id,
            ):
                yield tts_result

        # Start ASR listening
        _LOGGER.debug("Listening for session %s", site_session.session_id)
        yield AsrStartListening(
            site_id=site_session.site_id,
            session_id=site_session.session_id,
            send_audio_captured=site_session.send_audio_captured,
            lang=site_session.lang,
        )

        # Set up timeout
        # NOTE(review): the task returned by create_task is not stored. The
        # asyncio docs advise keeping a reference so the task cannot be
        # garbage-collected mid-execution.
        asyncio.create_task(
            self.handle_session_timeout(
                site_session.site_id, site_session.session_id, site_session.step
            )
        )
    except Exception as e:
        _LOGGER.exception("handle_continue")
        yield DialogueError(
            error=str(e),
            context=str(continue_session),
            site_id=site_session.site_id,
            session_id=continue_session.session_id,
        )
def test_asr_stop_listening():
    """AsrStopListening is published on the Hermes ASR stopListening topic."""
    expected_topic = "hermes/asr/stopListening"
    assert AsrStopListening.topic() == expected_topic
async def start_listening(
    self, message: AsrStartListening
) -> typing.AsyncIterable[typing.Union[StopListeningType, AsrError]]:
    """Start recording audio data for a session.

    If ``message.session_id`` already has an active session, it is stopped
    first and the stop results are yielded. A free transcriber is re-used if
    one is available. Otherwise a new ``TranscriberInfo`` is created and a
    daemon thread is started that:

    * builds the transcriber;
    * for each session, waits on ``ready_event`` and transcribes frames
      pulled from ``frame_queue`` until a falsy item arrives;
    * publishes the result through ``result``/``result_event``.

    The session is then configured:

    * ``stop_on_silence``: a recorder is created if needed and started;
    * otherwise: an empty ``audio_buffer`` is used.

    Finally the session is registered in ``self.sessions``. Errors are logged
    and yielded as a single ``AsrError``.
    """
    try:
        if message.session_id in self.sessions:
            # Stop existing session
            async for stop_message in self.stop_listening(
                AsrStopListening(site_id=message.site_id, session_id=message.session_id)
            ):
                yield stop_message

        if self.free_transcribers:
            # Re-use existing transcriber
            info = self.free_transcribers.pop()

            _LOGGER.debug(
                "Re-using existing transcriber (session_id=%s)", message.session_id
            )
        else:
            # Create new transcriber
            info = TranscriberInfo(reuse=self.reuse_transcribers)
            _LOGGER.debug("Creating new transcriber session %s", message.session_id)

            def transcribe_proc(
                info, transcriber_factory, sample_rate, sample_width, channels
            ):
                def audio_stream(frame_queue) -> typing.Iterable[bytes]:
                    # Pull frames from the queue; a falsy item ends the stream
                    frames = frame_queue.get()
                    while frames:
                        yield frames
                        frames = frame_queue.get()

                try:
                    info.transcriber = transcriber_factory(port_num=self.kaldi_port)

                    assert (
                        info.transcriber is not None
                    ), "Failed to create transcriber"

                    while True:
                        # Wait for session to start
                        info.ready_event.wait()
                        info.ready_event.clear()

                        # Get result of transcription
                        result = info.transcriber.transcribe_stream(
                            audio_stream(info.frame_queue),
                            sample_rate,
                            sample_width,
                            channels,
                        )

                        _LOGGER.debug("Transcription result: %s", result)

                        # NOTE(review): an empty transcription (result.text == "")
                        # also fails this check. That routes to the failure path
                        # below, which stops the transcriber and marks it not
                        # reusable. Confirm this is intended.
                        assert (
                            result is not None and result.text
                        ), "Null transcription"

                        # Signal completion
                        info.result = result
                        info.result_event.set()

                        if not self.reuse_transcribers:
                            try:
                                info.transcriber.stop()
                            except Exception:
                                _LOGGER.exception("Transcriber stop")

                            # One-shot transcriber: exit the thread
                            break
                except Exception:
                    _LOGGER.exception("session proc")

                    # Mark as not reusable
                    info.reuse = False

                    # Stop transcriber
                    if info.transcriber is not None:
                        try:
                            info.transcriber.stop()
                        except Exception:
                            _LOGGER.exception("Transcriber stop")

                    # Signal failure with an empty transcription so any waiter
                    # on result_event is released
                    info.transcriber = None
                    info.result = Transcription(
                        text="", likelihood=0, transcribe_seconds=0, wav_seconds=0
                    )
                    info.result_event.set()

            # Run in separate thread
            info.thread = threading.Thread(
                target=transcribe_proc,
                args=(
                    info,
                    self.transcriber_factory,
                    self.sample_rate,
                    self.sample_width,
                    self.channels,
                ),
                daemon=True,
            )

            info.thread.start()

        # ---------------------------------------------------------------------

        # Settings for session
        info.start_listening = message

        # Signal session thread to start
        info.ready_event.set()

        if message.stop_on_silence:
            # Begin silence detection
            if info.recorder is None:
                info.recorder = self.recorder_factory()

            info.recorder.start()
        else:
            # Use internal buffer (no silence detection)
            info.audio_buffer = bytes()

        self.sessions[message.session_id] = info
        _LOGGER.debug("Starting listening (session_id=%s)", message.session_id)
        self.first_audio = True
    except Exception as e:
        _LOGGER.exception("start_listening")
        yield AsrError(
            error=str(e),
            context=repr(message),
            site_id=message.site_id,
            session_id=message.session_id,
        )
async def async_test_session(self):
    """Check good start/stop session.

    The fake transcriber drains the audio stream and returns a fixed
    transcription. Starting the session and sending audio must produce no
    messages. Stopping must yield, in order:

    * ``AsrRecordingFinished``;
    * ``AsrTextCaptured`` with the fake transcription;
    * ``(AsrAudioCaptured, topic_args)``, because ``send_audio_captured``
      is true.
    """
    fake_transcription = Transcription(
        text="this is a test", likelihood=1, transcribe_seconds=0, wav_seconds=0
    )

    def fake_transcribe(stream, *args):
        """Return test trancription."""
        # Consume the stream until the end-of-audio marker
        for chunk in stream:
            if not chunk:
                break

        return fake_transcription

    self.transcriber.transcribe_stream = fake_transcribe

    # Start session
    start_listening = AsrStartListening(
        site_id=self.site_id,
        session_id=self.session_id,
        stop_on_silence=False,
        send_audio_captured=True,
    )
    result = None
    async for response in self.hermes.on_message_blocking(start_listening):
        result = response

    # No response expected
    self.assertIsNone(result)

    # Send in "audio"
    fake_wav_bytes = self.hermes.to_wav_bytes(secrets.token_bytes(100))
    fake_frame = AudioFrame(wav_bytes=fake_wav_bytes)
    async for response in self.hermes.on_message_blocking(
        fake_frame, site_id=self.site_id
    ):
        result = response

    # No response expected (result is still None from the start phase)
    self.assertIsNone(result)

    # Stop session
    stop_listening = AsrStopListening(
        site_id=self.site_id, session_id=self.session_id
    )
    results = []
    async for response in self.hermes.on_message_blocking(stop_listening):
        results.append(response)

    # Check results
    self.assertEqual(
        results,
        [
            AsrRecordingFinished(site_id=self.site_id, session_id=self.session_id),
            AsrTextCaptured(
                text=fake_transcription.text,
                likelihood=fake_transcription.likelihood,
                seconds=fake_transcription.transcribe_seconds,
                site_id=self.site_id,
                session_id=self.session_id,
            ),
            (
                AsrAudioCaptured(wav_bytes=fake_wav_bytes),
                {"site_id": self.site_id, "session_id": self.session_id},
            ),
        ],
    )
def start_listening(
    self, message: AsrStartListening
) -> typing.Iterable[typing.Union[AsrTextCaptured, AsrError]]:
    """Start recording audio data for a session.

    If ``message.sessionId`` already has an active session, it is stopped
    first and the stop results are yielded. A free transcriber is re-used if
    one is available. Otherwise a new ``TranscriberInfo`` (with its own
    recorder) is created, and a daemon thread is started. That thread builds
    the transcriber and then, for each session, waits on ``ready_event``,
    transcribes frames from ``frame_queue`` and publishes the result through
    ``result``/``result_event``. The session's recorder is started and the
    session registered in ``self.sessions``.

    Yields:
        Results from stopping a pre-existing session with the same id, or a
        single ``AsrError`` if anything here raises (the exception is logged,
        not propagated).
    """
    try:
        if message.sessionId in self.sessions:
            # Stop existing session. Pass the siteId along so the stop message
            # identifies the same site as the start message (the async
            # implementation does the same).
            for result in self.stop_listening(
                AsrStopListening(siteId=message.siteId, sessionId=message.sessionId)
            ):
                yield result

        if self.free_transcribers:
            # Re-use existing transcriber
            info = self.free_transcribers.pop()

            _LOGGER.debug(
                "Re-using existing transcriber (sessionId=%s)", message.sessionId
            )
        else:
            # Create new transcriber
            info = TranscriberInfo(recorder=self.recorder_factory())  # type: ignore
            _LOGGER.debug("Creating new transcriber session %s", message.sessionId)

            def transcribe_proc(
                info, transcriber_factory, sample_rate, sample_width, channels
            ):
                def audio_stream(frame_queue):
                    # Pull frames from the queue; a falsy item ends the stream
                    frames = frame_queue.get()
                    while frames:
                        yield frames
                        frames = frame_queue.get()

                try:
                    # Create transcriber in this thread
                    info.transcriber = transcriber_factory()

                    while True:
                        # Wait for session to start
                        info.ready_event.wait()
                        info.ready_event.clear()

                        # Get result of transcription
                        result = info.transcriber.transcribe_stream(
                            audio_stream(info.frame_queue),
                            sample_rate,
                            sample_width,
                            channels,
                        )

                        _LOGGER.debug(result)

                        # Signal completion
                        info.result = result
                        info.result_event.set()
                except Exception:
                    # NOTE(review): result_event is never set on this path. If
                    # stop_listening waits on it without a timeout, it may
                    # block -- confirm.
                    _LOGGER.exception("session proc")

            # Run in separate thread
            info.thread = threading.Thread(
                target=transcribe_proc,
                args=(
                    info,
                    self.transcriber_factory,
                    self.sample_rate,
                    self.sample_width,
                    self.channels,
                ),
                daemon=True,
            )

            info.thread.start()

        # ---------------------------------------------------------------------

        # Signal session thread to start
        info.ready_event.set()

        # Begin silence detection
        assert info.recorder is not None
        info.recorder.start()

        self.sessions[message.sessionId] = info
        _LOGGER.debug("Starting listening (sessionId=%s)", message.sessionId)
        self.first_audio = True
    except Exception as e:
        _LOGGER.exception("start_listening")
        yield AsrError(
            error=str(e),
            context=repr(message),
            siteId=message.siteId,
            sessionId=message.sessionId,
        )
async def handle_audio_frame(
    self,
    wav_bytes: bytes,
    site_id: str = "default",
    session_id: typing.Optional[str] = None,
) -> typing.AsyncIterable[
    typing.Union[
        typing.Tuple[HotwordDetected, TopicArgs],
        AsrTextCaptured,
        typing.Tuple[AsrAudioCaptured, TopicArgs],
        AsrError,
    ]
]:
    """Add audio frame to open sessions.

    ASR: when ``self.asr_enabled`` is set, the raw PCM frames of
    ``wav_bytes`` are appended to every open session for ``site_id``, or
    only to ``session_id`` when given. Sessions that stop on silence are
    also fed converted audio through their recorder. When the recorder
    reports success, the session is completed via ``handle_stop_listening``
    and its messages are yielded.

    Wake word: when enabled and no ``session_id`` is given, converted audio
    is written to the wake command's stdin. Once that process exits, its
    stdout is taken as the wake word id, a ``HotwordDetected`` is yielded,
    and the command is restarted.

    Exceptions are logged, not raised.
    """
    try:
        if self.asr_enabled:
            if session_id is None:
                # Add to every open session
                target_sessions = list(self.asr_sessions.items())
            else:
                # Add to single session (drop frames for unknown sessions
                # instead of raising KeyError)
                target_session = self.asr_sessions.get(session_id)
                if target_session is None:
                    _LOGGER.warning(
                        "No ASR session for id %s. Dropping audio frame.", session_id
                    )
                    target_sessions = []
                else:
                    target_sessions = [(session_id, target_session)]

            with io.BytesIO(wav_bytes) as in_io:
                with wave.open(in_io) as in_wav:
                    # Get WAV details from first frame
                    sample_rate = in_wav.getframerate()
                    sample_width = in_wav.getsampwidth()
                    channels = in_wav.getnchannels()
                    audio_data = in_wav.readframes(in_wav.getnframes())

            # Add to target ASR sessions
            for target_id, session in target_sessions:
                # Skip non-matching site_id
                if session.start_listening.site_id != site_id:
                    continue

                session.sample_rate = sample_rate
                session.sample_width = sample_width
                session.channels = channels
                session.audio_data += audio_data

                if session.start_listening.stop_on_silence:
                    # Detect silence (end of command). The converted audio gets
                    # its own name: reassigning audio_data here would corrupt
                    # the raw audio appended to later sessions in this loop.
                    recorder_audio = self.maybe_convert_wav(
                        wav_bytes,
                        self.recorder_sample_rate,
                        self.recorder_sample_width,
                        self.recorder_channels,
                    )
                    command = session.recorder.process_chunk(recorder_audio)
                    if command and (command.result == VoiceCommandResult.SUCCESS):
                        # Complete session
                        stop_listening = AsrStopListening(
                            site_id=site_id, session_id=target_id
                        )
                        async for message in self.handle_stop_listening(
                            stop_listening
                        ):
                            yield message

        if self.wake_enabled and (session_id is None) and self.wake_proc:
            # Convert and send to wake command
            audio_bytes = self.maybe_convert_wav(
                wav_bytes,
                self.wake_sample_rate,
                self.wake_sample_width,
                self.wake_channels,
            )
            assert self.wake_proc.stdin
            self.wake_proc.stdin.write(audio_bytes)

            # poll() returns None while the process runs and its exit code
            # once it has finished. Exit code 0 is falsy, so compare with None
            # explicitly; otherwise a clean exit would be missed.
            if self.wake_proc.poll() is not None:
                stdout, stderr = self.wake_proc.communicate()
                if stderr:
                    _LOGGER.debug(stderr.decode())

                wakeword_id = stdout.decode().strip()
                _LOGGER.debug("Detected wake word %s", wakeword_id)
                yield (
                    HotwordDetected(
                        model_id=wakeword_id,
                        model_version="",
                        model_type="personal",
                        current_sensitivity=1.0,
                        site_id=site_id,
                    ),
                    {"wakeword_id": wakeword_id},
                )

                # Restart wake process
                self.start_wake_command()

    except Exception:
        _LOGGER.exception("handle_audio_frame")