def finish_session(
    self, info: TranscriberInfo, siteId: str, sessionId: str
) -> typing.Iterable[AsrTextCaptured]:
    """Publish transcription result for a session if not already published"""
    # Stop silence detection
    assert info.recorder is not None
    info.recorder.stop()

    if not info.result_sent:
        # Last chunk
        info.frame_queue.put(None)

        # Wait for result
        info.result_event.wait(timeout=self.session_result_timeout)

        transcription = info.result
        if transcription:
            # Successful transcription
            yield (
                AsrTextCaptured(
                    text=transcription.text,
                    likelihood=transcription.likelihood,
                    seconds=transcription.transcribe_seconds,
                    siteId=siteId,
                    sessionId=sessionId,
                )
            )
        else:
            # Empty transcription
            yield AsrTextCaptured(
                text="", likelihood=0, seconds=0, siteId=siteId, sessionId=sessionId
            )

        # Avoid re-sending transcription
        info.result_sent = True
async def async_test_ws_text(self):
    """Test api/events/text endpoint"""
    # Start listening
    event_queue = asyncio.Queue()
    connected = asyncio.Event()
    receive_task = asyncio.ensure_future(
        self.async_ws_receive("events/text", event_queue, connected)
    )
    await asyncio.wait_for(connected.wait(), timeout=5)

    # Send in a message
    text_captured = AsrTextCaptured(
        text="this is a test",
        likelihood=1,
        seconds=0,
        site_id=self.site_id,
        session_id=self.session_id,
        wakeword_id=str(uuid4()),
    )

    self.client.publish(text_captured.topic(), text_captured.payload())

    # Wait for response
    event = json.loads(await asyncio.wait_for(event_queue.get(), timeout=5))
    self.assertEqual(text_captured.text, event.get("text", ""))
    self.assertEqual(text_captured.site_id, event.get("siteId", ""))
    self.assertEqual(text_captured.wakeword_id, event.get("wakewordId", ""))

    # Stop listening
    receive_task.cancel()
def on_message(self, client, userdata, msg):
    """Received message from MQTT broker."""
    try:
        _LOGGER.debug("Received %s byte(s) on %s", len(msg.payload), msg.topic)

        if NluIntent.is_topic(msg.topic):
            # Intent from query
            json_payload = json.loads(msg.payload)
            if not self._check_siteId(json_payload):
                return

            intent = NluIntent(**json_payload)

            # Report to websockets
            for queue in self.message_queues:
                self.loop.call_soon_threadsafe(queue.put_nowait, ("intent", intent))

            self.handle_message(msg.topic, intent)
        elif msg.topic == NluIntentNotRecognized.topic():
            # Intent not recognized
            json_payload = json.loads(msg.payload)
            if not self._check_siteId(json_payload):
                return

            not_recognized = NluIntentNotRecognized(**json_payload)

            # Report to websockets
            for queue in self.message_queues:
                self.loop.call_soon_threadsafe(
                    queue.put_nowait, ("intent", not_recognized)
                )

            self.handle_message(msg.topic, not_recognized)
        elif msg.topic == TtsSayFinished.topic():
            # Text to speech finished
            json_payload = json.loads(msg.payload)
            self.handle_message(msg.topic, TtsSayFinished(**json_payload))
        elif msg.topic == AsrTextCaptured.topic():
            # Speech to text result
            json_payload = json.loads(msg.payload)
            if not self._check_siteId(json_payload):
                return

            self.handle_message(msg.topic, AsrTextCaptured(**json_payload))

        # Forward to external message queues
        for queue in self.message_queues:
            self.loop.call_soon_threadsafe(
                queue.put_nowait, ("mqtt", msg.topic, msg.payload)
            )
    except Exception:
        _LOGGER.exception("on_message")
async def transcribe(
    self,
    wav_bytes: bytes,
    site_id: str,
    session_id: typing.Optional[str] = None,
    wakeword_id: typing.Optional[str] = None,
) -> AsrTextCaptured:
    """Transcribe audio data and publish captured text."""
    _LOGGER.debug("Transcribing %s byte(s) of audio data", len(wav_bytes))
    transcription = self.transcriber.transcribe_wav(wav_bytes)

    if transcription:
        _LOGGER.debug(transcription)

        asr_tokens: typing.Optional[typing.List[typing.List[AsrToken]]] = None

        if transcription.tokens:
            # Only one level of ASR tokens
            asr_inner_tokens: typing.List[AsrToken] = []
            asr_tokens = [asr_inner_tokens]

            range_start = 0
            for ps_token in transcription.tokens:
                range_end = range_start + len(ps_token.token) + 1
                asr_inner_tokens.append(
                    AsrToken(
                        value=ps_token.token,
                        confidence=ps_token.likelihood,
                        range_start=range_start,
                        range_end=range_start + len(ps_token.token) + 1,
                        time=AsrTokenTime(
                            start=ps_token.start_time, end=ps_token.end_time
                        ),
                    )
                )

                range_start = range_end

        # Actual transcription
        return AsrTextCaptured(
            text=transcription.text,
            likelihood=transcription.likelihood,
            seconds=transcription.transcribe_seconds,
            site_id=site_id,
            session_id=session_id,
            asr_tokens=asr_tokens,
            wakeword_id=wakeword_id,
        )

    _LOGGER.warning("Received empty transcription")

    return AsrTextCaptured(
        text="",
        likelihood=0,
        seconds=0,
        site_id=site_id,
        session_id=session_id,
        wakeword_id=wakeword_id,
    )
async def async_test_silence(self):
    """Check start/stop session with silence detection."""
    fake_transcription = Transcription(
        text="turn on the living room lamp",
        likelihood=1,
        transcribe_seconds=0,
        wav_seconds=0,
    )

    def fake_transcribe(stream, *args):
        """Return test transcription."""
        for chunk in stream:
            if not chunk:
                break

        return fake_transcription

    self.transcriber.transcribe_stream = fake_transcribe

    # Start session
    start_listening = AsrStartListening(
        site_id=self.site_id,
        session_id=self.session_id,
        stop_on_silence=True,
        send_audio_captured=False,
    )

    result = None
    async for response in self.hermes.on_message_blocking(start_listening):
        result = response

    # No response expected
    self.assertIsNone(result)

    # Send in "audio"
    wav_path = Path("etc/turn_on_the_living_room_lamp.wav")

    results = []
    with open(wav_path, "rb") as wav_file:
        for wav_bytes in AudioFrame.iter_wav_chunked(wav_file, 4096):
            frame = AudioFrame(wav_bytes=wav_bytes)
            async for response in self.hermes.on_message_blocking(
                frame, site_id=self.site_id
            ):
                results.append(response)

    # Expect transcription
    self.assertEqual(
        results,
        [
            AsrRecordingFinished(site_id=self.site_id, session_id=self.session_id),
            AsrTextCaptured(
                text=fake_transcription.text,
                likelihood=fake_transcription.likelihood,
                seconds=fake_transcription.transcribe_seconds,
                site_id=self.site_id,
                session_id=self.session_id,
            ),
        ],
    )
def transcribe(
    self, audio_data: bytes, siteId: str = "default", sessionId: str = ""
) -> typing.Union[AsrTextCaptured, AsrError]:
    """Transcribe audio data and publish captured text."""
    try:
        with io.BytesIO() as wav_buffer:
            wav_file: wave.Wave_write = wave.open(wav_buffer, mode="wb")
            with wav_file:
                wav_file.setframerate(self.sample_rate)
                wav_file.setsampwidth(self.sample_width)
                wav_file.setnchannels(self.channels)
                wav_file.writeframesraw(audio_data)

            transcription = self.transcriber.transcribe_wav(wav_buffer.getvalue())
            if transcription:
                # Actual transcription
                return AsrTextCaptured(
                    text=transcription.text,
                    likelihood=transcription.likelihood,
                    seconds=transcription.transcribe_seconds,
                    siteId=siteId,
                    sessionId=sessionId,
                )

            _LOGGER.warning("Received empty transcription")

            # Empty transcription
            return AsrTextCaptured(
                text="", likelihood=0, seconds=0, siteId=siteId, sessionId=sessionId
            )
    except Exception as e:
        _LOGGER.exception("transcribe")
        return AsrError(
            error=str(e),
            context=repr(self.transcriber),
            siteId=siteId,
            sessionId=sessionId,
        )
async def async_test_transcriber_error(self):
    """Check start/stop session with error in transcriber."""

    def fake_transcribe(stream, *args):
        """Raise an exception."""
        raise FakeException()

    self.transcriber.transcribe_stream = fake_transcribe

    # Start session
    start_listening = AsrStartListening(
        site_id=self.site_id, session_id=self.session_id, stop_on_silence=False
    )

    result = None
    async for response in self.hermes.on_message_blocking(start_listening):
        result = response

    # No response expected
    self.assertIsNone(result)

    # Send in "audio"
    fake_wav_bytes = self.hermes.to_wav_bytes(secrets.token_bytes(100))
    fake_frame = AudioFrame(wav_bytes=fake_wav_bytes)
    async for response in self.hermes.on_message_blocking(
        fake_frame, site_id=self.site_id
    ):
        result = response

    # No response expected
    self.assertIsNone(result)

    # Stop session
    stop_listening = AsrStopListening(
        site_id=self.site_id, session_id=self.session_id
    )

    results = []
    async for response in self.hermes.on_message_blocking(stop_listening):
        results.append(response)

    # Check results for empty transcription
    self.assertEqual(
        results,
        [
            AsrRecordingFinished(site_id=self.site_id, session_id=self.session_id),
            AsrTextCaptured(
                text="",
                likelihood=0,
                seconds=0,
                site_id=self.site_id,
                session_id=self.session_id,
            ),
        ],
    )
async def transcribe_wav(
    self, wav_bytes: bytes, frames_per_chunk: int = 1024
) -> AsrTextCaptured:
    """Transcribe WAV data"""
    sessionId = str(uuid4())

    def handle_captured():
        while True:
            _, message = yield

            if isinstance(message, AsrTextCaptured) and (
                message.sessionId == sessionId
            ):
                return True, message

    def messages():
        yield AsrStartListening(siteId=self.siteId, sessionId=sessionId)

        # Break WAV into chunks
        with io.BytesIO(wav_bytes) as wav_buffer:
            with wave.open(wav_buffer, "rb") as wav_file:
                frames_left = wav_file.getnframes()
                while frames_left > 0:
                    with io.BytesIO() as chunk_buffer:
                        chunk_file: wave.Wave_write = wave.open(chunk_buffer, "wb")
                        with chunk_file:
                            chunk_file.setframerate(wav_file.getframerate())
                            chunk_file.setsampwidth(wav_file.getsampwidth())
                            chunk_file.setnchannels(wav_file.getnchannels())
                            chunk_file.writeframes(
                                wav_file.readframes(frames_per_chunk)
                            )

                        yield (
                            AudioFrame(wav_data=chunk_buffer.getvalue()),
                            {"siteId": self.siteId},
                        )

                    frames_left -= frames_per_chunk

        yield AsrStopListening(siteId=self.siteId, sessionId=sessionId)

    topics = [AsrTextCaptured.topic()]

    # Expecting only a single result
    async for result in self.publish_wait(handle_captured(), messages(), topics):
        return result
def test_http_speech_to_text_hermes(self):
    """Test speech-to-text HTTP endpoint (Hermes format)"""
    response = requests.post(
        self.api_url("speech-to-text"),
        data=self.wav_bytes,
        params={"outputFormat": "hermes"},
    )
    self.check_status(response)

    result = response.json()
    self.assertEqual(result["type"], "textCaptured")

    text_captured = AsrTextCaptured.from_dict(result["value"])
    self.assertEqual(text_captured.text, "turn on the living room lamp")
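Outside the test harness, the same endpoint can be exercised directly. The sketch below is illustrative only: the server address (Rhasspy's usual default of localhost:12101) and the WAV file path are assumptions, while the outputFormat=hermes parameter and the textCaptured response shape mirror the test above.

# Hedged sketch: POST WAV audio to the speech-to-text endpoint and parse the
# Hermes-format response into an AsrTextCaptured. Host, port, and the WAV
# path are assumptions for illustration.
import requests

from rhasspyhermes.asr import AsrTextCaptured

with open("turn_on_the_living_room_lamp.wav", "rb") as wav_file:
    response = requests.post(
        "http://localhost:12101/api/speech-to-text",
        data=wav_file.read(),
        params={"outputFormat": "hermes"},
        headers={"Content-Type": "audio/wav"},
    )

response.raise_for_status()
result = response.json()
assert result["type"] == "textCaptured"

text_captured = AsrTextCaptured.from_dict(result["value"])
print(text_captured.text)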
def on_connect(self, client, userdata, flags, rc):
    """Connected to MQTT broker."""
    try:
        topics = [
            DialogueStartSession.topic(),
            DialogueContinueSession.topic(),
            DialogueEndSession.topic(),
            TtsSayFinished.topic(),
            NluIntent.topic(intent_name="#"),
            NluIntentNotRecognized.topic(),
            AsrTextCaptured.topic(),
        ] + list(self.wakeword_topics.keys())

        for topic in topics:
            self.client.subscribe(topic)
            _LOGGER.debug("Subscribed to %s", topic)
    except Exception:
        _LOGGER.exception("on_connect")
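The handlers above run inside larger services. For quickly inspecting hermes/asr/textCaptured traffic on its own, a minimal standalone listener can be sketched as follows; it assumes a broker on localhost:1883, paho-mqtt 1.x style callbacks (matching the on_connect/on_message signatures shown here), and the snake_case rhasspyhermes API with from_dict().

# Minimal standalone listener sketch: subscribe to hermes/asr/textCaptured
# and print each captured transcription. Broker host/port are assumptions.
import json

import paho.mqtt.client as mqtt
from rhasspyhermes.asr import AsrTextCaptured


def on_connect(client, userdata, flags, rc):
    """Subscribe once connected."""
    client.subscribe(AsrTextCaptured.topic())


def on_message(client, userdata, msg):
    """Decode the JSON payload into an AsrTextCaptured message."""
    captured = AsrTextCaptured.from_dict(json.loads(msg.payload))
    print(captured.site_id, captured.session_id, captured.text)


client = mqtt.Client()
client.on_connect = on_connect
client.on_message = on_message
client.connect("localhost", 1883)
client.loop_forever()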
def test_asr_text_captured():
    """Test AsrTextCaptured."""
    assert AsrTextCaptured.topic() == "hermes/asr/textCaptured"
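Beyond the fixed-topic check above, a slightly fuller round-trip can be sketched: build a message, serialize it with payload() as in the websocket test earlier, and restore it with from_dict() as in the HTTP test. Field names follow the snake_case snippets; the field values are illustrative, and the assumption that payload() yields a JSON document is hedged in the comments.

# Round-trip sketch: construct an AsrTextCaptured, serialize it, restore it.
# Assumes payload() returns the JSON document published on topic().
import json

from rhasspyhermes.asr import AsrTextCaptured


def roundtrip_example() -> None:
    captured = AsrTextCaptured(
        text="turn on the living room lamp",
        likelihood=1.0,
        seconds=0.5,
        site_id="default",
        session_id="test-session",
    )

    # Topic is the same for every ASR result
    assert AsrTextCaptured.topic() == "hermes/asr/textCaptured"

    # Parse the published payload back into a message object
    restored = AsrTextCaptured.from_dict(json.loads(captured.payload()))
    assert restored.text == captured.text
    assert restored.site_id == captured.site_id


if __name__ == "__main__":
    roundtrip_example()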
async def finish_session(
    self, info: TranscriberInfo, site_id: str, session_id: typing.Optional[str]
) -> typing.AsyncIterable[typing.Union[AsrTextCaptured, AudioCapturedType]]:
    """Publish transcription result for a session if not already published"""
    if info.recorder is not None:
        # Stop silence detection and get trimmed audio
        audio_data = info.recorder.stop()
    else:
        # Use complete audio buffer
        assert info.audio_buffer is not None
        audio_data = info.audio_buffer

    if not info.result_sent:
        # Avoid re-sending transcription
        info.result_sent = True

        # Last chunk
        info.frame_queue.put(None)

        # Wait for result
        result_success = info.result_event.wait(timeout=self.session_result_timeout)
        if not result_success:
            # Mark transcription as non-reusable
            info.reuse = False

        transcription = info.result
        assert info.start_listening is not None

        if transcription:
            # Successful transcription
            yield (
                AsrTextCaptured(
                    text=transcription.text,
                    likelihood=transcription.likelihood,
                    seconds=transcription.transcribe_seconds,
                    site_id=site_id,
                    session_id=session_id,
                    lang=info.start_listening.lang,
                )
            )
        else:
            # Empty transcription
            yield AsrTextCaptured(
                text="",
                likelihood=0,
                seconds=0,
                site_id=site_id,
                session_id=session_id,
                lang=info.start_listening.lang,
            )

        if info.start_listening.send_audio_captured:
            wav_bytes = self.to_wav_bytes(audio_data)

            # Send audio data
            yield (  # pylint: disable=E1121
                AsrAudioCaptured(wav_bytes),
                {"site_id": site_id, "session_id": session_id},
            )
async def finish_session(
    self, info: TranscriberInfo, site_id: str, session_id: typing.Optional[str]
) -> typing.AsyncIterable[
    typing.Union[AsrRecordingFinished, AsrTextCaptured, AudioCapturedType]
]:
    """Publish transcription result for a session if not already published"""
    if info.recorder is not None:
        # Stop silence detection and get trimmed audio
        audio_data = info.recorder.stop()
    else:
        # Use complete audio buffer
        assert info.audio_buffer is not None
        audio_data = info.audio_buffer

    if not info.result_sent:
        # Send recording finished message
        yield AsrRecordingFinished(site_id=site_id, session_id=session_id)

        # Avoid re-sending transcription
        info.result_sent = True

        # Last chunk
        info.frame_queue.put(None)

        # Wait for result
        result_success = info.result_event.wait(timeout=self.session_result_timeout)
        if not result_success:
            # Mark transcription as non-reusable
            info.reuse = False

        transcription = info.result
        assert info.start_listening is not None

        if transcription:
            # Successful transcription
            asr_tokens: typing.Optional[typing.List[typing.List[AsrToken]]] = None

            if transcription.tokens:
                # Only one level of ASR tokens
                asr_inner_tokens: typing.List[AsrToken] = []
                asr_tokens = [asr_inner_tokens]

                range_start = 0
                for ps_token in transcription.tokens:
                    range_end = range_start + len(ps_token.token) + 1
                    asr_inner_tokens.append(
                        AsrToken(
                            value=ps_token.token,
                            confidence=ps_token.likelihood,
                            range_start=range_start,
                            range_end=range_start + len(ps_token.token) + 1,
                            time=AsrTokenTime(
                                start=ps_token.start_time, end=ps_token.end_time
                            ),
                        )
                    )

                    range_start = range_end

            yield (
                AsrTextCaptured(
                    text=transcription.text,
                    likelihood=transcription.likelihood,
                    seconds=transcription.transcribe_seconds,
                    site_id=site_id,
                    session_id=session_id,
                    asr_tokens=asr_tokens,
                    lang=(info.start_listening.lang or self.lang),
                )
            )
        else:
            # Empty transcription
            yield AsrTextCaptured(
                text="",
                likelihood=0,
                seconds=0,
                site_id=site_id,
                session_id=session_id,
                lang=(info.start_listening.lang or self.lang),
            )

        if info.start_listening.send_audio_captured:
            wav_bytes = self.to_wav_bytes(audio_data)

            # Send audio data
            yield (  # pylint: disable=E1121
                AsrAudioCaptured(wav_bytes),
                {"site_id": site_id, "session_id": session_id},
            )
def on_message(self, client, userdata, msg):
    """Received message from MQTT broker."""
    try:
        _LOGGER.debug("Received %s byte(s) on %s", len(msg.payload), msg.topic)

        if msg.topic == DialogueStartSession.topic():
            # Start session
            json_payload = json.loads(msg.payload)
            if not self._check_siteId(json_payload):
                return

            # Run in event loop (for TTS)
            asyncio.run_coroutine_threadsafe(
                self.handle_start(DialogueStartSession(**json_payload)), self.loop
            )
        elif msg.topic == DialogueContinueSession.topic():
            # Continue session
            json_payload = json.loads(msg.payload)
            if not self._check_siteId(json_payload):
                return

            # Run in event loop (for TTS)
            asyncio.run_coroutine_threadsafe(
                self.handle_continue(DialogueContinueSession(**json_payload)),
                self.loop,
            )
        elif msg.topic == DialogueEndSession.topic():
            # End session
            json_payload = json.loads(msg.payload)
            if not self._check_siteId(json_payload):
                return

            # Run outside event loop
            self.handle_end(DialogueEndSession(**json_payload))
        elif msg.topic == TtsSayFinished.topic():
            # TTS finished
            json_payload = json.loads(msg.payload)
            if not self._check_sessionId(json_payload):
                return

            # Signal event loop
            self.loop.call_soon_threadsafe(self.say_finished_event.set)
        elif msg.topic == AsrTextCaptured.topic():
            # Text captured
            json_payload = json.loads(msg.payload)
            if not self._check_sessionId(json_payload):
                return

            # Run outside event loop
            self.handle_text_captured(AsrTextCaptured(**json_payload))
        elif msg.topic.startswith(NluIntent.topic(intent_name="")):
            # Intent recognized
            json_payload = json.loads(msg.payload)
            if not self._check_sessionId(json_payload):
                return

            self.handle_recognized(NluIntent(**json_payload))
        elif msg.topic.startswith(NluIntentNotRecognized.topic()):
            # Intent not recognized
            json_payload = json.loads(msg.payload)
            if not self._check_sessionId(json_payload):
                return

            # Run in event loop (for TTS)
            asyncio.run_coroutine_threadsafe(
                self.handle_not_recognized(NluIntentNotRecognized(**json_payload)),
                self.loop,
            )
        elif msg.topic in self.wakeword_topics:
            json_payload = json.loads(msg.payload)
            if not self._check_siteId(json_payload):
                return

            wakeword_id = self.wakeword_topics[msg.topic]
            asyncio.run_coroutine_threadsafe(
                self.handle_wake(wakeword_id, HotwordDetected(**json_payload)),
                self.loop,
            )
    except Exception:
        _LOGGER.exception("on_message")
async def async_test_session(self):
    """Check good start/stop session."""
    fake_transcription = Transcription(
        text="this is a test", likelihood=1, transcribe_seconds=0, wav_seconds=0
    )

    def fake_transcribe(stream, *args):
        """Return test transcription."""
        for chunk in stream:
            if not chunk:
                break

        return fake_transcription

    self.transcriber.transcribe_stream = fake_transcribe

    # Start session
    start_listening = AsrStartListening(
        site_id=self.site_id,
        session_id=self.session_id,
        stop_on_silence=False,
        send_audio_captured=True,
    )

    result = None
    async for response in self.hermes.on_message_blocking(start_listening):
        result = response

    # No response expected
    self.assertIsNone(result)

    # Send in "audio"
    fake_wav_bytes = self.hermes.to_wav_bytes(secrets.token_bytes(100))
    fake_frame = AudioFrame(wav_bytes=fake_wav_bytes)
    async for response in self.hermes.on_message_blocking(
        fake_frame, site_id=self.site_id
    ):
        result = response

    # No response expected
    self.assertIsNone(result)

    # Stop session
    stop_listening = AsrStopListening(
        site_id=self.site_id, session_id=self.session_id
    )

    results = []
    async for response in self.hermes.on_message_blocking(stop_listening):
        results.append(response)

    # Check results
    self.assertEqual(
        results,
        [
            AsrRecordingFinished(site_id=self.site_id, session_id=self.session_id),
            AsrTextCaptured(
                text=fake_transcription.text,
                likelihood=fake_transcription.likelihood,
                seconds=fake_transcription.transcribe_seconds,
                site_id=self.site_id,
                session_id=self.session_id,
            ),
            (
                AsrAudioCaptured(wav_bytes=fake_wav_bytes),
                {"site_id": self.site_id, "session_id": self.session_id},
            ),
        ],
    )
async def handle_stop_listening(
    self, stop_listening: AsrStopListening
) -> typing.AsyncIterable[
    typing.Union[
        AsrTextCaptured, typing.Tuple[AsrAudioCaptured, TopicArgs], AsrError
    ]
]:
    """Stop ASR session."""
    _LOGGER.debug("<- %s", stop_listening)

    try:
        session = self.asr_sessions.pop(stop_listening.session_id, None)
        if session is None:
            _LOGGER.warning("Session not found for %s", stop_listening.session_id)
            return

        assert session.sample_rate is not None, "No sample rate"
        assert session.sample_width is not None, "No sample width"
        assert session.channels is not None, "No channels"

        if session.start_listening.stop_on_silence:
            # Use recorded voice command
            audio_data = session.recorder.stop()
        else:
            # Use entire audio
            audio_data = session.audio_data

        # Process entire WAV file
        wav_bytes = self.to_wav_bytes(
            audio_data, session.sample_rate, session.sample_width, session.channels
        )
        _LOGGER.debug("Received %s byte(s) of WAV data", len(wav_bytes))

        if self.asr_url:
            _LOGGER.debug(self.asr_url)

            # Remote ASR server
            async with self.http_session.post(
                self.asr_url,
                data=wav_bytes,
                headers={"Content-Type": "audio/wav", "Accept": "application/json"},
                ssl=self.ssl_context,
            ) as response:
                response.raise_for_status()
                transcription_dict = await response.json()
        elif self.asr_command:
            # Local ASR command
            _LOGGER.debug(self.asr_command)

            start_time = time.perf_counter()
            proc = await asyncio.create_subprocess_exec(
                *self.asr_command,
                stdin=asyncio.subprocess.PIPE,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )

            output, error = await proc.communicate(wav_bytes)

            if error:
                _LOGGER.debug(error.decode())

            text = output.decode()
            end_time = time.perf_counter()

            transcription_dict = {
                "text": text,
                "transcribe_seconds": (end_time - start_time),
            }
        else:
            # Empty transcription
            _LOGGER.warning(
                "No ASR URL or command. Only empty transcriptions will be returned."
            )
            transcription_dict = {}

        # Publish transcription
        yield AsrTextCaptured(
            text=transcription_dict.get("text", ""),
            likelihood=float(transcription_dict.get("likelihood", 0)),
            seconds=float(transcription_dict.get("transcribe_seconds", 0)),
            site_id=stop_listening.site_id,
            session_id=stop_listening.session_id,
            lang=session.start_listening.lang,
        )

        if session.start_listening.send_audio_captured:
            # Send audio data
            yield (
                AsrAudioCaptured(wav_bytes=wav_bytes),
                {
                    "site_id": stop_listening.site_id,
                    "session_id": stop_listening.session_id,
                },
            )
    except Exception as e:
        _LOGGER.exception("handle_stop_listening")
        yield AsrError(
            error=str(e),
            context=f"url='{self.asr_url}', command='{self.asr_command}'",
            site_id=stop_listening.site_id,
            session_id=stop_listening.session_id,
        )