def get_command_recorder(self):
    """Build a voice command recorder configured from the profile.

    Each recorder option is looked up in ``self.profile`` (falling back to
    a built-in default) and converted to ``int`` or ``float`` before being
    handed to ``WebRtcVadRecorder``.
    """
    from rhasspysilence import WebRtcVadRecorder

    # (recorder keyword, profile path, default, converter) -- kept in the
    # order the settings are read so conversion errors surface identically.
    setting_specs = (
        ("vad_mode", "voice-command.vad-mode", 3, int),
        ("min_seconds", "voice-command.minimum-seconds", 1, float),
        ("max_seconds", "voice-command.maximum-seconds", 30, float),
        ("speech_seconds", "voice-command.speech-seconds", 0.3, float),
        ("silence_seconds", "voice-command.silence-seconds", 0.5, float),
        ("before_seconds", "voice-command.before-seconds", 0.5, float),
        ("skip_seconds", "voice-command.skip-seconds", 0, float),
        ("chunk_size", "voice-command.chunk-size", 960, int),
        ("sample_rate", "audio.format.sample-rate-hertz", 16000, int),
    )

    recorder_kwargs = {}
    for keyword, profile_path, fallback, convert in setting_specs:
        recorder_kwargs[keyword] = convert(
            pydash.get(self.profile, profile_path, fallback))

    return WebRtcVadRecorder(**recorder_kwargs)
def __init__(
    self,
    client,
    transcriber: Transcriber,
    siteIds: typing.Optional[typing.List[str]] = None,
    enabled: bool = True,
    sample_rate: int = 16000,
    sample_width: int = 2,
    channels: int = 1,
    make_recorder: typing.Callable[[], VoiceCommandRecorder] = None,
):
    """Store the MQTT client, transcriber and audio settings.

    When no ``make_recorder`` factory is supplied, recorders are created
    as ``WebRtcVadRecorder(max_seconds=None)`` (i.e. without a timeout).
    """
    self.client = client
    self.transcriber = transcriber
    self.siteIds = siteIds if siteIds else []
    self.enabled = enabled

    # Audio format that incoming audio is required to have
    self.sample_rate, self.sample_width, self.channels = (
        sample_rate,
        sample_width,
        channels,
    )

    # Default factory produces a recorder with no timeout
    self.make_recorder = (
        make_recorder
        if make_recorder
        else (lambda: WebRtcVadRecorder(max_seconds=None))
    )

    # One recorder (WAV buffer) per session
    self.session_recorders: typing.Dict[
        str, VoiceCommandRecorder] = defaultdict(VoiceCommandRecorder)

    # Topics on which WAV chunks arrive, one per configured site
    self.audioframe_topics: typing.List[str] = [
        AudioFrame.topic(siteId=siteId) for siteId in self.siteIds
    ]

    self.first_audio: bool = True
class RhasspySilenceTestCase(unittest.TestCase):
    """Tests for rhasspysilence."""

    def __init__(self, *args):
        super().__init__(*args)
        self.recorder = None
        self.chunk_size = 2048

    def setUp(self):
        self.recorder = WebRtcVadRecorder()
        self.recorder.start()

    def _first_command(self, wav_path):
        """Feed a WAV file to the recorder chunk by chunk.

        Returns the first truthy result of ``process_chunk``; otherwise the
        last (falsy) result, or None when the file holds no audio.
        """
        outcome = None
        with wave.open(wav_path, "r") as wav_file:
            frames = wav_file.readframes(wav_file.getnframes())
            step = self.chunk_size
            for offset in range(0, len(frames), step):
                outcome = self.recorder.process_chunk(
                    frames[offset:offset + step])
                if outcome:
                    break

        return outcome

    def test_command(self):
        """Verify voice command sample WAV file."""
        command = self._first_command("etc/turn_on_living_room_lamp.wav")

        self.assertTrue(command)
        self.assertEqual(command.result, VoiceCommandResult.SUCCESS)
        self.assertGreater(len(command.audio_data), 0)

    def test_noise(self):
        """Verify no command in noise WAV file."""
        command = self._first_command("etc/noise.wav")

        self.assertFalse(command)
def test_noise():
    """Run the noise sample through a recorder and expect no command."""
    vad_recorder = WebRtcVadRecorder()
    vad_recorder.start()

    detected = None

    # Feed the noise WAV file in CHUNK_SIZE-byte pieces
    with wave.open("etc/noise.wav", "r") as noise_wav:
        samples = noise_wav.readframes(noise_wav.getnframes())
        for begin in range(0, len(samples), CHUNK_SIZE):
            detected = vad_recorder.process_chunk(
                samples[begin:begin + CHUNK_SIZE])
            if detected:
                break

    assert not detected
def make_webrtcvad():
    """Create a recorder without a maximum duration.

    All other options come from names resolved in the surrounding scope.
    """
    recorder_options = {
        "max_seconds": None,
        "vad_mode": vad_mode,
        "skip_seconds": skip_seconds,
        "min_seconds": min_seconds,
        "speech_seconds": speech_seconds,
        "silence_seconds": silence_seconds,
        "before_seconds": before_seconds,
    }
    return WebRtcVadRecorder(**recorder_options)
def default_recorder():
    """Return a new recorder with ``max_seconds=None``.

    The remaining settings are taken from the surrounding scope.
    """
    new_recorder = WebRtcVadRecorder(
        max_seconds=None,
        vad_mode=vad_mode,
        skip_seconds=skip_seconds,
        min_seconds=min_seconds,
        speech_seconds=speech_seconds,
        silence_seconds=silence_seconds,
        before_seconds=before_seconds,
    )
    return new_recorder
def test_command():
    """Run the sample voice command through a recorder and expect success."""
    vad_recorder = WebRtcVadRecorder()
    vad_recorder.start()

    detected = None

    # Feed the sample WAV file in CHUNK_SIZE-byte pieces
    with wave.open("etc/turn_on_living_room_lamp.wav", "r") as sample_wav:
        samples = sample_wav.readframes(sample_wav.getnframes())
        for begin in range(0, len(samples), CHUNK_SIZE):
            detected = vad_recorder.process_chunk(
                samples[begin:begin + CHUNK_SIZE])
            if detected:
                break

    assert detected
    assert detected.result == VoiceCommandResult.SUCCESS
    assert detected.audio_data
def trim_silence(
    audio_bytes: bytes,
    ratio_threshold: float = 20.0,
    chunk_size: int = 960,
    skip_first_chunk=True,
) -> bytes:
    """Trim silence from start and end of audio using ratio of max/current energy.

    The audio is split into complete ``chunk_size``-byte chunks (a trailing
    partial chunk is dropped, as is the first chunk when
    ``skip_first_chunk`` is set). A chunk counts as "loud" when
    ``max_energy / chunk_energy < ratio_threshold``. The result spans from
    one chunk before the first loud chunk to one chunk after the first
    quiet chunk that follows the last loud chunk.

    Args:
        audio_bytes: Raw audio to trim.
        ratio_threshold: Max/current energy ratio below which a chunk is
            considered loud.
        chunk_size: Size in bytes of each analysed chunk; must be positive.
        skip_first_chunk: Ignore (and discard) the first complete chunk.

    Returns:
        The concatenated bytes of the kept chunks.

    Raises:
        ValueError: If ``chunk_size`` is not positive, or if the audio does
            not contain a single chunk to analyse.
    """
    if chunk_size <= 0:
        # A non-positive size would never consume any audio (endless loop).
        raise ValueError("chunk_size must be positive")

    # Offsets of complete chunks only; optionally start after the first one.
    first_offset = chunk_size if skip_first_chunk else 0
    last_offset = len(audio_bytes) - chunk_size
    chunks = [
        audio_bytes[offset:offset + chunk_size]
        for offset in range(first_offset, last_offset + 1, chunk_size)
    ]

    # Clamp to 1 so the ratio below never divides by zero.
    energies = [
        max(1, WebRtcVadRecorder.get_debiased_energy(chunk)) for chunk in chunks
    ]

    if not energies:
        # Explicit error instead of an assert, which is stripped under -O.
        raise ValueError("No maximum energy: audio has no chunks to analyse")

    max_energy = max(energies)

    # Determine chunks below threshold
    start_index = None
    end_index = None

    for i, energy in enumerate(energies):
        if (max_energy / energy) < ratio_threshold:
            # Loud chunk: any previously seen trailing silence was not the end
            end_index = None
            if start_index is None:
                start_index = i
        elif end_index is None:
            # First quiet chunk since the last loud one
            end_index = i

    if start_index is None:
        start_index = 0

    if end_index is None:
        end_index = len(energies) - 1

    # Keep one chunk of padding on either side
    start_index = max(0, start_index - 1)
    end_index = min(len(energies) - 1, end_index + 1)

    # Single join instead of repeated bytes concatenation
    return b"".join(chunks[start_index:end_index + 1])
def default_recorder():
    """Create a recorder from the settings visible in the surrounding scope.

    Covers both the timing/VAD options and the energy-based silence
    detection options.
    """
    timing_options = {
        "max_seconds": max_seconds,
        "vad_mode": vad_mode,
        "skip_seconds": skip_seconds,
        "min_seconds": min_seconds,
        "speech_seconds": speech_seconds,
        "silence_seconds": silence_seconds,
        "before_seconds": before_seconds,
    }
    energy_options = {
        "silence_method": silence_method,
        "current_energy_threshold": current_energy_threshold,
        "max_energy": max_energy,
        "max_current_ratio_threshold": max_current_energy_ratio_threshold,
    }
    return WebRtcVadRecorder(**timing_options, **energy_options)
def record_templates(
    record_dir: Path,
    name_format: str,
    recorder: WebRtcVadRecorder,
    args: argparse.Namespace,
):
    """Record wake word templates from raw audio on stdin.

    Every time the recorder reports a result, the recorded audio is
    trimmed, converted to WAV and written below ``record_dir`` using
    ``name_format`` (formatted with ``n`` = number of templates so far).
    Stops on end of input or CTRL+C.
    """
    print("Reading 16-bit 16Khz mono audio from stdin...", file=sys.stderr)

    template_count = 0

    def begin_next_template():
        # Prompt the user, then arm the recorder for the next template
        print(
            f"Recording template {template_count}. Please speak your wake word. Press CTRL+C to exit."
        )
        recorder.start()

    try:
        begin_next_template()

        while True:
            # Read raw audio chunk
            raw_chunk = sys.stdin.buffer.read(recorder.chunk_size)
            if not raw_chunk:
                # End of input
                break

            if not recorder.process_chunk(raw_chunk):
                # Still recording
                continue

            trimmed_audio = trim_silence(recorder.stop())

            wav_path = record_dir / name_format.format(n=template_count)
            wav_path.parent.mkdir(parents=True, exist_ok=True)

            wav_data = buffer_to_wav(trimmed_audio)
            wav_path.write_bytes(wav_data)
            _LOGGER.debug("Wrote %s byte(s) of WAV audio to %s",
                          len(wav_data), wav_path)

            template_count += 1
            begin_next_template()
    except KeyboardInterrupt:
        print("Done")
def make_webrtcvad():
    """Create a WebRtcVadRecorder with ``max_seconds=None``."""
    untimed_recorder = WebRtcVadRecorder(max_seconds=None)
    return untimed_recorder
def __init__(
    self,
    templates: typing.List[Template],
    keyword_name: str = "",
    probability_threshold: float = 0.5,
    minimum_matches: int = 0,
    distance_threshold: float = 0.22,
    template_dtw: typing.Optional[DynamicTimeWarping] = None,
    dtw_window_size: int = 5,
    dtw_step_pattern: float = 2,
    shift_sec: float = DEFAULT_SHIFT_SECONDS,
    refractory_sec: float = 2.0,
    skip_probability_threshold: float = 0.0,
    failed_matches_to_refractory: typing.Optional[int] = None,
    recorder: typing.Optional[WebRtcVadRecorder] = None,
    debug: bool = False,
    benchmark: bool = False,
):
    """Set up matching settings, derived buffer sizes and state for detection.

    Args:
        templates: Audio templates to match against; must not be empty.
        keyword_name: Name of the keyword, stored as-is.
        probability_threshold: Stored match setting.
        minimum_matches: Stored match setting.
        distance_threshold: Stored match setting.
        template_dtw: Dynamic time warping implementation; a new
            ``DynamicTimeWarping`` is created when omitted.
        dtw_window_size: Stored DTW setting.
        dtw_step_pattern: Stored DTW setting.
        shift_sec: Seconds to shift the template window by during
            processing.
        refractory_sec: Used to derive ``num_refractory_chunks``.
        skip_probability_threshold: Stored match setting.
        failed_matches_to_refractory: Stored match setting.
        recorder: Silence detector; a default ``WebRtcVadRecorder`` is
            created when omitted. Its ``chunk_size`` and ``sample_rate``
            define the VAD chunk size and sample rate used here.
        debug: If True, log DTW predictions.
        benchmark: Stored flag; the timing lists below are always created.
    """
    self.templates = templates
    # NOTE(review): assert is stripped under "python -O"; an empty template
    # list would then fail below with a ZeroDivisionError instead.
    assert self.templates, "No templates"

    self.keyword_name = keyword_name

    # Use or create silence detector
    self.recorder = recorder or WebRtcVadRecorder()
    self.vad_chunk_bytes = self.recorder.chunk_size
    self.sample_rate = self.recorder.sample_rate

    # Assume 16-bit samples
    self.sample_width = 2
    self.bytes_per_second = int(self.sample_rate * self.sample_width)

    # Match settings
    self.probability_threshold = probability_threshold
    self.minimum_matches = minimum_matches
    self.distance_threshold = distance_threshold
    self.skip_probability_threshold = skip_probability_threshold
    self.refractory_sec = refractory_sec
    self.failed_matches_to_refractory = failed_matches_to_refractory

    # Dynamic time warping calculation
    self.dtw = template_dtw or DynamicTimeWarping()
    self.dtw_window_size = dtw_window_size
    self.dtw_step_pattern = dtw_step_pattern

    # Average duration of templates
    template_duration_sec = sum([t.duration_sec for t in templates]) / len(templates)

    # Seconds to shift template window by during processing
    self.template_shift_sec = shift_sec
    # Whole shifts that fit in an average template, minus one
    self.shifts_per_template = (
        int(math.floor(template_duration_sec / shift_sec)) - 1)

    # Bytes needed for a template
    self.template_chunk_bytes = int(
        math.ceil(template_duration_sec * self.bytes_per_second))

    # Ensure divisible by sample width
    while (self.template_chunk_bytes % self.sample_width) != 0:
        self.template_chunk_bytes += 1

    # Audio
    self.vad_audio_buffer = bytes()
    self.template_audio_buffer = bytes()
    self.example_audio_buffer = bytes()
    self.template_mfcc: typing.Optional[np.ndarray] = None
    self.template_chunks_left = 0
    # Half the number of VAD chunks in one template's worth of bytes, rounded up
    self.num_template_chunks = int(
        math.ceil((self.template_chunk_bytes / self.vad_chunk_bytes) / 2))

    # State machine
    # Number of VAD chunks covering refractory_sec of audio, rounded up
    self.num_refractory_chunks = int(
        math.ceil(self.sample_rate * self.sample_width * (refractory_sec / self.vad_chunk_bytes)))
    self.refractory_chunks_left = 0
    self.failed_matches = 0
    self.match_seconds: typing.Optional[float] = None

    # If True, log DTW predictions
    self.debug = debug

    # Keep previously-computed distances and probabilities for debugging
    # (one slot per template, initially None)
    self.last_distances: typing.List[typing.Optional[float]] = [
        None for _ in self.templates
    ]
    self.last_probabilities: typing.List[typing.Optional[float]] = [
        None for _ in self.templates
    ]

    # ------------
    # Benchmarking
    # ------------
    self.benchmark = benchmark

    # Seconds to process an entire VAD chunk
    self.time_process_vad_chunk: typing.List[float] = []

    # Seconds to compute single MFCC
    self.time_mfcc: typing.List[float] = []

    # Seconds to check template-sized window for a match
    self.time_match: typing.List[float] = []

    # Seconds to compute DTW cost
    self.time_dtw: typing.List[float] = []
def __init__(
    self,
    client,
    raven: Raven,
    minimum_matches: int = 1,
    wakeword_id: str = "",
    site_ids: typing.Optional[typing.List[str]] = None,
    enabled: bool = True,
    sample_rate: int = 16000,
    sample_width: int = 2,
    channels: int = 1,
    chunk_size: int = 960,
    udp_audio: typing.Optional[typing.List[typing.Tuple[str, int, str]]] = None,
    udp_chunk_size: int = 2048,
    log_predictions: bool = False,
):
    """Subscribe to hotword messages and start the background threads.

    Args:
        client: MQTT client, passed to the base class.
        raven: Raven instance stored for detection.
        minimum_matches: Stored as ``self.minimum_matches``.
        wakeword_id: Stored as ``self.wakeword_id``.
        site_ids: Passed to the base class.
        enabled: Initial enabled state.
        sample_rate: Required audio sample rate.
        sample_width: Required audio sample width.
        channels: Required audio channel count.
        chunk_size: Stored as ``self.chunk_size``.
        udp_audio: Optional (host, port, site_id) tuples; one listener
            thread is started per entry.
        udp_chunk_size: Stored as ``self.udp_chunk_size``.
        log_predictions: Not referenced in this constructor.
    """
    super().__init__(
        "rhasspywake_raven_hermes",
        client,
        sample_rate=sample_rate,
        sample_width=sample_width,
        channels=channels,
        site_ids=site_ids,
    )

    # Message types this service listens for
    self.subscribe(
        AudioFrame,
        HotwordToggleOn,
        HotwordToggleOff,
        GetHotwords,
        RecordHotwordExample,
    )

    self.raven = raven
    self.minimum_matches = minimum_matches
    self.wakeword_id = wakeword_id

    self.enabled = enabled
    self.disabled_reasons: typing.Set[str] = set()

    # Required audio format
    self.sample_rate = sample_rate
    self.sample_width = sample_width
    self.channels = channels

    self.chunk_size = chunk_size

    # Queue of WAV audio chunks to process (plus site_id)
    self.wav_queue: queue.Queue = queue.Queue()

    self.first_audio: bool = True
    self.audio_buffer = bytes()

    self.last_audio_site_id: str = "default"

    # Fields for recording examples
    self.recording_example = False
    self.example_recorder = WebRtcVadRecorder(max_seconds=10)
    self.example_future: typing.Optional[asyncio.Future] = None

    # Start threads
    # (done after the attributes above are assigned, since the thread
    # target is a method of this object)
    self.detection_thread = threading.Thread(
        target=self.detection_thread_proc, daemon=True
    )
    self.detection_thread.start()

    # Listen for raw audio on UDP too
    self.udp_chunk_size = udp_chunk_size

    if udp_audio:
        # One daemon thread per (host, port, site_id) entry
        for udp_host, udp_port, udp_site_id in udp_audio:
            threading.Thread(
                target=self.udp_thread_proc,
                args=(udp_host, udp_port, udp_site_id),
                daemon=True,
            ).start()
def setUp(self):
    """Give the test a new recorder and start it."""
    self.recorder = fresh_recorder = WebRtcVadRecorder()
    fresh_recorder.start()
def __init__(
    self,
    templates: typing.List[Template],
    probability_threshold: float = 0.5,
    minimum_matches: int = 0,
    distance_threshold: float = 0.22,
    frame_dtw: typing.Optional[DynamicTimeWarping] = None,
    dtw_window_size: int = 5,
    dtw_step_pattern: float = 2,
    sample_rate: int = 16000,
    chunk_size: int = 960,
    shift_sec: float = 0.01,
    before_chunks: int = 0,
    refractory_sec: float = 2.0,
    skip_probability_threshold: float = 0.0,
    recorder: typing.Optional[WebRtcVadRecorder] = None,
    debug: bool = False,
):
    """Store match settings and derive window/shift sizes and initial state.

    Args:
        templates: Audio templates to match against; must not be empty.
        probability_threshold: Stored match setting.
        minimum_matches: Stored match setting.
        distance_threshold: Stored match setting.
        frame_dtw: Dynamic time warping implementation; a new
            ``DynamicTimeWarping`` is created when omitted.
        dtw_window_size: Stored DTW setting.
        dtw_step_pattern: Stored DTW setting.
        sample_rate: Sample rate used to compute ``chunk_seconds``.
        chunk_size: Size of an audio chunk.
        shift_sec: Seconds to shift each frame.
        before_chunks: When positive, a bounded deque of that length is
            created as ``before_buffer``; otherwise it stays None.
        refractory_sec: Converted to ``num_refractory_chunks``.
        skip_probability_threshold: Stored match setting.
        recorder: Silence detector; a default ``WebRtcVadRecorder`` is
            created when omitted.
        debug: Stored as ``self.debug``.
    """
    self.templates = templates
    assert self.templates, "No templates"

    self.probability_threshold = probability_threshold
    self.minimum_matches = minimum_matches
    self.distance_threshold = distance_threshold
    self.skip_probability_threshold = skip_probability_threshold
    self.chunk_size = chunk_size
    self.shift_sec = shift_sec
    self.sample_rate = sample_rate

    self.before_buffer: typing.Optional[typing.Deque[bytes]] = None
    if before_chunks > 0:
        self.before_buffer = deque(maxlen=before_chunks)

    # Assume 16-bit samples
    self.sample_width = 2
    # Dividing by sample_width treats chunk_size as a byte count
    self.chunk_seconds = (self.chunk_size / self.sample_width) / self.sample_rate

    # Use or create silence detector
    self.recorder = recorder or WebRtcVadRecorder()

    # Dynamic time warping calculation
    self.dtw = frame_dtw or DynamicTimeWarping()
    self.dtw_window_size = dtw_window_size
    self.dtw_step_pattern = dtw_step_pattern

    # Keep previously-computed distances and probabilities for debugging
    self.last_distances: typing.List[typing.Optional[float]] = [
        None for _ in self.templates
    ]
    self.last_probabilities: typing.List[typing.Optional[float]] = [
        None for _ in self.templates
    ]

    # Average duration of templates
    self.frame_duration_sec = sum([t.duration_sec for t in templates]) / len(templates)

    # Size in bytes of a frame
    # NOTE(review): chunk_seconds above treats chunk_size as bytes, yet here
    # it is multiplied by sample_width again -- confirm the intended units.
    self.window_chunk_size = (
        self.seconds_to_chunks(self.frame_duration_sec) * self.chunk_size * self.sample_width)

    # Ensure divisible by sample width
    while (self.window_chunk_size % self.sample_width) != 0:
        self.window_chunk_size += 1

    # Size in bytes to shift each frame.
    # Should be less than the size of a frame to ensure overlap.
    self.shift_size = (self.seconds_to_chunks(self.shift_sec) * self.chunk_size * self.sample_width)

    # State machine
    self.audio_buffer = bytes()
    self.state = RavenState.IN_SILENCE
    self.num_silence_chunks = self.seconds_to_chunks(
        self.frame_duration_sec)
    self.silence_chunks_left = 0
    self.num_refractory_chunks = self.seconds_to_chunks(refractory_sec)
    self.refractory_chunks_left = 0
    self.match_seconds: typing.Optional[float] = None
    self.debug = debug
def __init__(
    self,
    client,
    ravens: typing.List[Raven],
    examples_dir: typing.Optional[Path] = None,
    examples_format: str = "{keyword}/examples/%Y%m%d-%H%M%S.wav",
    wakeword_id: str = "",
    site_ids: typing.Optional[typing.List[str]] = None,
    enabled: bool = True,
    sample_rate: int = 16000,
    sample_width: int = 2,
    channels: int = 1,
    chunk_size: int = 1920,
    udp_audio: typing.Optional[typing.List[typing.Tuple[str, int, str]]] = None,
    udp_chunk_size: int = 2048,
    log_predictions: bool = False,
    lang: typing.Optional[str] = None,
):
    """Subscribe to hotword messages and start audio/detection threads.

    Args:
        client: MQTT client, passed to the base class.
        ravens: Raven instances; each gets its own chunk queue and
            detection thread.
        examples_dir: Stored as ``self.examples_dir``.
        examples_format: Stored as ``self.examples_format``.
        wakeword_id: Stored as ``self.wakeword_id``.
        site_ids: Passed to the base class.
        enabled: Initial enabled state.
        sample_rate: Required audio sample rate.
        sample_width: Required audio sample width.
        channels: Required audio channel count.
        chunk_size: Stored as ``self.chunk_size``.
        udp_audio: Optional (host, port, site_id) tuples; one listener
            thread is started per entry.
        udp_chunk_size: Stored as ``self.udp_chunk_size``.
        log_predictions: Not referenced in this constructor.
        lang: Stored as ``self.lang``.
    """
    super().__init__(
        "rhasspywake_raven_hermes",
        client,
        sample_rate=sample_rate,
        sample_width=sample_width,
        channels=channels,
        site_ids=site_ids,
    )

    # Message types this service listens for
    self.subscribe(
        AudioFrame,
        HotwordToggleOn,
        HotwordToggleOff,
        GetHotwords,
        RecordHotwordExample,
    )

    self.ravens = ravens
    self.wakeword_id = wakeword_id

    self.examples_dir = examples_dir
    self.examples_format = examples_format

    self.enabled = enabled
    self.disabled_reasons: typing.Set[str] = set()

    # Required audio format
    self.sample_rate = sample_rate
    self.sample_width = sample_width
    self.channels = channels

    self.chunk_size = chunk_size

    # Queue of WAV audio chunks to process (plus site_id)
    self.wav_queue: queue.Queue = queue.Queue()

    self.first_audio: bool = True

    self.last_audio_site_id: str = "default"

    self.lang = lang

    # Fields for recording examples
    self.recording_example = False
    self.example_recorder = WebRtcVadRecorder(max_seconds=10)
    self.example_future: typing.Optional[asyncio.Future] = None

    # Raw audio chunk queues for Raven threads
    self.chunk_queues: typing.List[queue.Queue] = [
        queue.Queue() for raven in ravens
    ]

    # Start main thread to convert audio from MQTT/UDP
    # (started only after the queues and flags above have been assigned)
    self.audio_thread = threading.Thread(target=self.audio_thread_proc, daemon=True)
    self.audio_thread.start()

    # Start a thread per Raven instance (per-keyword)
    self.detection_threads = [
        threading.Thread(
            target=self.detection_thread_proc,
            args=(self.chunk_queues[i], self.ravens[i]),
            daemon=True,
        )
        for i in range(len(self.ravens))
    ]

    for thread in self.detection_threads:
        thread.start()

    # Listen for raw audio on UDP too
    self.udp_chunk_size = udp_chunk_size

    if udp_audio:
        # One daemon thread per (host, port, site_id) entry
        for udp_host, udp_port, udp_site_id in udp_audio:
            threading.Thread(
                target=self.udp_thread_proc,
                args=(udp_host, udp_port, udp_site_id),
                daemon=True,
            ).start()
def main():
    """Main method.

    Parses command-line arguments, builds one Raven per keyword directory
    (falling back to the bundled templates), then runs the Hermes MQTT
    service until interrupted.
    """
    parser = argparse.ArgumentParser(prog="rhasspy-wake-raven-hermes")
    # Each --keyword occurrence is a list: [template directory, setting=value, ...]
    parser.add_argument(
        "--keyword",
        action="append",
        nargs="+",
        default=[],
        help="Directory with WAV templates and settings (setting-name=value)",
    )
    parser.add_argument(
        "--probability-threshold",
        type=float,
        default=0.5,
        help="Probability above which detection occurs (default: 0.5)",
    )
    parser.add_argument(
        "--distance-threshold",
        type=float,
        default=0.22,
        help="Normalized dynamic time warping distance threshold for template matching (default: 0.22)",
    )
    parser.add_argument(
        "--minimum-matches",
        type=int,
        default=1,
        help="Number of templates that must match to produce output (default: 1)",
    )
    parser.add_argument(
        "--refractory-seconds",
        type=float,
        default=2.0,
        help="Seconds before wake word can be activated again (default: 2)",
    )
    parser.add_argument(
        "--window-shift-seconds",
        type=float,
        default=Raven.DEFAULT_SHIFT_SECONDS,
        help=f"Seconds to shift sliding time window on audio buffer (default: {Raven.DEFAULT_SHIFT_SECONDS})",
    )
    # NOTE(review): args.dtw_window_size is not read anywhere in this function.
    parser.add_argument(
        "--dtw-window-size",
        type=int,
        default=5,
        help="Size of band around slanted diagonal during dynamic time warping calculation (default: 5)",
    )
    parser.add_argument(
        "--vad-sensitivity",
        type=int,
        choices=[1, 2, 3],
        default=3,
        help="Webrtcvad VAD sensitivity (1-3)",
    )
    parser.add_argument(
        "--current-threshold",
        type=float,
        help="Debiased energy threshold of current audio frame",
    )
    parser.add_argument(
        "--max-energy",
        type=float,
        help="Fixed maximum energy for ratio calculation (default: observed)",
    )
    parser.add_argument(
        "--max-current-ratio-threshold",
        type=float,
        help="Threshold of ratio between max energy and current audio frame",
    )
    # NOTE(review): choices are the enum *values* while the default is the
    # enum member itself, so args.silence_method can be either kind --
    # confirm WebRtcVadRecorder accepts both.
    parser.add_argument(
        "--silence-method",
        choices=[e.value for e in SilenceMethod],
        default=SilenceMethod.VAD_ONLY,
        help="Method for detecting silence",
    )
    parser.add_argument(
        "--average-templates",
        action="store_true",
        help="Average wakeword templates together to reduce number of calculations",
    )
    parser.add_argument(
        "--udp-audio",
        nargs=3,
        action="append",
        help="Host/port/siteId for UDP audio input",
    )
    parser.add_argument(
        "--examples-dir",
        help="Save positive example audio to directory as WAV files")
    parser.add_argument(
        "--examples-format",
        default="{keyword}/examples/%Y%m%d-%H%M%S.wav",
        help="Format of positive example WAV file names using strftime (relative to examples-dir)",
    )
    parser.add_argument(
        "--log-predictions",
        action="store_true",
        help="Log prediction probabilities for each audio chunk (very verbose)",
    )
    parser.add_argument("--lang", help="Set lang in hotword detected message")

    hermes_cli.add_hermes_args(parser)
    args = parser.parse_args()

    hermes_cli.setup_logging(args)
    _LOGGER.debug(args)
    hermes: typing.Optional[WakeHermesMqtt] = None

    # -------------------------------------------------------------------------

    if args.examples_dir:
        # Directory to save positive example WAV files
        args.examples_dir = Path(args.examples_dir)
        args.examples_dir.mkdir(parents=True, exist_ok=True)

    if args.keyword:
        # Keywords count as missing when none of the given directories
        # contains a WAV file
        missing_keywords = not any(
            list(Path(k[0]).glob("*.wav")) for k in args.keyword)
    else:
        missing_keywords = True

    if missing_keywords:
        # Replace whatever was given with the bundled template directory
        args.keyword = [[_DIR / "templates"]]
        _LOGGER.debug(
            "No keywords provided. Use built-in 'okay rhasspy' templates.")

    # Create silence detector
    # NOTE(review): this single recorder is handed to every Raven below --
    # confirm that sharing it between Raven instances is safe.
    recorder = WebRtcVadRecorder(
        vad_mode=args.vad_sensitivity,
        silence_method=args.silence_method,
        current_energy_threshold=args.current_threshold,
        max_energy=args.max_energy,
        max_current_ratio_threshold=args.max_current_ratio_threshold,
    )

    # Load audio templates
    ravens: typing.List[Raven] = []

    for keyword_settings in args.keyword:
        # First item is the template directory, the rest are settings
        template_dir = Path(keyword_settings[0])
        wav_paths = list(template_dir.glob("*.wav"))
        if not wav_paths:
            _LOGGER.warning("No WAV files found in %s", template_dir)
            continue

        keyword_name = template_dir.name if not missing_keywords else "okay-rhasspy"

        # Load audio templates
        keyword_templates = [
            Raven.wav_to_template(p, name=str(p), shift_sec=args.window_shift_seconds)
            for p in wav_paths
        ]

        # Global defaults; per-keyword settings below may override some
        raven_args = {
            "templates": keyword_templates,
            "keyword_name": keyword_name,
            "recorder": recorder,
            "probability_threshold": args.probability_threshold,
            "minimum_matches": args.minimum_matches,
            "distance_threshold": args.distance_threshold,
            "refractory_sec": args.refractory_seconds,
            "shift_sec": args.window_shift_seconds,
            "debug": args.log_predictions,
        }

        # Apply settings
        average_templates = args.average_templates
        for setting_str in keyword_settings[1:]:
            # A setting without "=" makes this unpacking raise ValueError
            setting_name, setting_value = setting_str.strip().split("=", maxsplit=1)
            # Accept both "some_name" and "some-name"
            setting_name = setting_name.lower().replace("_", "-")

            # Unrecognized setting names are ignored
            if setting_name == "name":
                raven_args["keyword_name"] = setting_value
            elif setting_name == "probability-threshold":
                raven_args["probability_threshold"] = float(setting_value)
            elif setting_name == "minimum-matches":
                raven_args["minimum_matches"] = int(setting_value)
            elif setting_name == "average-templates":
                average_templates = setting_value.lower().strip() == "true"

        if average_templates:
            _LOGGER.debug("Averaging %s templates for %s", len(keyword_templates), template_dir)
            raven_args["templates"] = [
                Template.average_templates(keyword_templates)
            ]

        # Create instance of Raven in a separate thread for keyword
        # (only the Raven object is created here; no thread is started in
        # this function)
        ravens.append(Raven(**raven_args))

    udp_audio = []
    if args.udp_audio:
        # Ports arrive as strings from the command line
        udp_audio = [(host, int(port), site_id) for host, port, site_id in args.udp_audio]

    # Listen for messages
    client = mqtt.Client()
    hermes = WakeHermesMqtt(
        client,
        ravens=ravens,
        examples_dir=args.examples_dir,
        examples_format=args.examples_format,
        udp_audio=udp_audio,
        site_ids=args.site_id,
        lang=args.lang,
    )

    _LOGGER.debug("Connecting to %s:%s", args.host, args.port)
    hermes_cli.connect(client, args)
    client.loop_start()

    try:
        # Run event loop
        asyncio.run(hermes.handle_messages_async())
    except KeyboardInterrupt:
        pass
    finally:
        _LOGGER.debug("Shutting down")
        client.loop_stop()
        hermes.stop()
class WakeHermesMqtt(HermesClient):
    """Hermes MQTT server for Rhasspy wakeword with Raven.

    Incoming WAV chunks (from MQTT or UDP) are queued, converted by a
    single audio thread and fanned out to one detection thread per Raven
    instance. Detections and recorded examples are published back through
    the event loop.
    """

    def __init__(
        self,
        client,
        ravens: typing.List[Raven],
        examples_dir: typing.Optional[Path] = None,
        examples_format: str = "{keyword}/examples/%Y%m%d-%H%M%S.wav",
        wakeword_id: str = "",
        site_ids: typing.Optional[typing.List[str]] = None,
        enabled: bool = True,
        sample_rate: int = 16000,
        sample_width: int = 2,
        channels: int = 1,
        chunk_size: int = 1920,
        udp_audio: typing.Optional[typing.List[typing.Tuple[str, int, str]]] = None,
        udp_chunk_size: int = 2048,
        log_predictions: bool = False,
        lang: typing.Optional[str] = None,
    ):
        """Subscribe to hotword messages and start audio/detection threads.

        Args:
            client: MQTT client, passed to the base class.
            ravens: Raven instances; each gets its own chunk queue and
                detection thread.
            examples_dir: When set, positive examples are saved below it.
            examples_format: strftime/format pattern for example paths.
            wakeword_id: Stored as ``self.wakeword_id``.
            site_ids: Passed to the base class.
            enabled: Initial enabled state.
            sample_rate: Required audio sample rate.
            sample_width: Required audio sample width.
            channels: Required audio channel count.
            chunk_size: Stored as ``self.chunk_size``.
            udp_audio: Optional (host, port, site_id) tuples; one listener
                thread is started per entry.
            udp_chunk_size: Payload size expected per UDP datagram.
            log_predictions: Not referenced in this constructor.
            lang: Language reported in detection messages.
        """
        super().__init__(
            "rhasspywake_raven_hermes",
            client,
            sample_rate=sample_rate,
            sample_width=sample_width,
            channels=channels,
            site_ids=site_ids,
        )

        self.subscribe(
            AudioFrame,
            HotwordToggleOn,
            HotwordToggleOff,
            GetHotwords,
            RecordHotwordExample,
        )

        self.ravens = ravens
        self.wakeword_id = wakeword_id

        self.examples_dir = examples_dir
        self.examples_format = examples_format

        self.enabled = enabled
        self.disabled_reasons: typing.Set[str] = set()

        # Required audio format
        self.sample_rate = sample_rate
        self.sample_width = sample_width
        self.channels = channels

        self.chunk_size = chunk_size

        # Queue of WAV audio chunks to process (plus site_id)
        self.wav_queue: queue.Queue = queue.Queue()

        self.first_audio: bool = True

        self.last_audio_site_id: str = "default"

        self.lang = lang

        # Fields for recording examples
        self.recording_example = False
        self.example_recorder = WebRtcVadRecorder(max_seconds=10)
        self.example_future: typing.Optional[asyncio.Future] = None

        # Raw audio chunk queues for Raven threads
        self.chunk_queues: typing.List[queue.Queue] = [
            queue.Queue() for raven in ravens
        ]

        # Start main thread to convert audio from MQTT/UDP
        # (only after the queues and flags above have been assigned)
        self.audio_thread = threading.Thread(target=self.audio_thread_proc,
                                             daemon=True)
        self.audio_thread.start()

        # Start a thread per Raven instance (per-keyword)
        self.detection_threads = [
            threading.Thread(
                target=self.detection_thread_proc,
                args=(self.chunk_queues[i], self.ravens[i]),
                daemon=True,
            ) for i in range(len(self.ravens))
        ]

        for thread in self.detection_threads:
            thread.start()

        # Listen for raw audio on UDP too
        self.udp_chunk_size = udp_chunk_size

        if udp_audio:
            for udp_host, udp_port, udp_site_id in udp_audio:
                threading.Thread(
                    target=self.udp_thread_proc,
                    args=(udp_host, udp_port, udp_site_id),
                    daemon=True,
                ).start()

    # -------------------------------------------------------------------------

    async def handle_audio_frame(self,
                                 wav_bytes: bytes,
                                 site_id: str = "default") -> None:
        """Queue a single audio frame for the audio thread."""
        self.wav_queue.put((wav_bytes, site_id))

    async def handle_detection(
        self, matching_indexes: typing.List[int], raven: Raven
    ) -> typing.AsyncIterable[typing.Union[typing.Tuple[
            HotwordDetected, TopicArgs], HotwordError]]:
        """Handle a successful hotword detection.

        Yields a ``HotwordDetected`` message (with its ``wakeword_id`` topic
        argument) for the first matching template, or a ``HotwordError`` if
        anything goes wrong while building it.
        """
        # Bound before the try block so the error path never touches an
        # unassigned name when the template lookup itself fails.
        template_name = "<unknown>"

        try:
            template = raven.templates[matching_indexes[0]]
            template_name = template.name

            wakeword_id = raven.keyword_name or template.name
            if not wakeword_id:
                wakeword_id = "default"

            yield (
                HotwordDetected(
                    site_id=self.last_audio_site_id,
                    model_id=template.name,
                    current_sensitivity=raven.probability_threshold,
                    model_version="",
                    model_type="personal",
                    lang=self.lang,
                ),
                {
                    "wakeword_id": wakeword_id
                },
            )
        except Exception as e:
            _LOGGER.exception("handle_detection")
            yield HotwordError(
                error=str(e),
                context=f"{raven.keyword_name}: {template_name}",
                site_id=self.last_audio_site_id,
            )

    async def handle_get_hotwords(
        self, get_hotwords: GetHotwords
    ) -> typing.AsyncIterable[typing.Union[Hotwords, HotwordError]]:
        """Report available hotwords (one per Raven instance)."""
        try:
            models: typing.List[Hotword] = []

            # Each keyword is in a separate Raven instance
            for raven in self.ravens:
                # Assume that the directory name is something like
                # "okay-rhasspy" for the keyword "okay rhasspy".
                models.append(
                    Hotword(
                        model_id=raven.keyword_name,
                        model_words=re.sub(r"[_-]+", " ", raven.keyword_name),
                    ))

            yield Hotwords(models=models,
                           id=get_hotwords.id,
                           site_id=get_hotwords.site_id)

        except Exception as e:
            _LOGGER.exception("handle_get_hotwords")
            yield HotwordError(error=str(e),
                               context=str(get_hotwords),
                               site_id=get_hotwords.site_id)

    async def handle_record_example(
        self, record_example: RecordHotwordExample
    ) -> typing.AsyncIterable[typing.Union[typing.Tuple[
            HotwordExampleRecorded, TopicArgs], HotwordError]]:
        """Record an example of a hotword.

        Arms the example recorder, waits for the audio thread to deliver
        the recorded audio through ``self.example_future``, trims silence
        and yields the result as WAV data.
        """
        try:
            if self.recording_example:
                _LOGGER.warning("Cancelling previous recording")
                self.example_recorder.stop()
                # NOTE(review): the previous request's future is replaced
                # below without being resolved -- confirm that request is
                # not left waiting forever.

            # Start recording
            assert self.loop, "No loop"
            self.example_future = self.loop.create_future()
            self.example_recorder.start()
            self.recording_example = True

            # Wait for result
            _LOGGER.debug("Recording example (id=%s)", record_example.id)
            example_audio = await self.example_future
            assert isinstance(example_audio, bytes)

            # Trim silence
            _LOGGER.debug("Trimming silence from example")
            example_audio = trim_silence(example_audio)

            # Convert to WAV format
            wav_data = self.to_wav_bytes(example_audio)

            yield (
                HotwordExampleRecorded(wav_bytes=wav_data),
                {
                    "site_id": record_example.site_id,
                    "request_id": record_example.id
                },
            )

        except Exception as e:
            _LOGGER.exception("handle_record_example")
            yield HotwordError(
                error=str(e),
                context=str(record_example),
                site_id=record_example.site_id,
            )

    def add_example_audio(self, audio_data: bytes):
        """Add an audio frame to the currently recording example."""
        result = self.example_recorder.process_chunk(audio_data)
        if result:
            self.recording_example = False
            assert self.example_future is not None, "No future"
            example_audio = self.example_recorder.stop()
            _LOGGER.debug("Recorded %s byte(s) for audio for example",
                          len(example_audio))

            # Signal waiting coroutine with audio
            # (this runs on the audio thread, hence call_soon_threadsafe)
            assert self.loop, "No loop"
            self.loop.call_soon_threadsafe(self.example_future.set_result,
                                           example_audio)

    # -------------------------------------------------------------------------

    def audio_thread_proc(self):
        """Handle WAV audio chunks.

        Converts each queued WAV chunk and passes it either to the example
        recorder or to every detection thread's queue. A ``None`` chunk
        shuts down this thread and all detection threads.
        """
        try:
            while True:
                wav_bytes, site_id = self.wav_queue.get()
                if wav_bytes is None:
                    # Shutdown signal
                    for chunk_queue in self.chunk_queues:
                        chunk_queue.put(None)

                    # Wait for detection threads to exit
                    for thread in self.detection_threads:
                        thread.join()

                    break

                self.last_audio_site_id = site_id

                # Handle audio frames
                if self.first_audio:
                    _LOGGER.debug("Receiving audio")
                    self.first_audio = False

                # Extract/convert audio data
                audio_data = self.maybe_convert_wav(wav_bytes)

                if self.recording_example:
                    # Add to currently recording example
                    self.add_example_audio(audio_data)

                    # Don't process audio for wake word while recording
                    continue

                # Add to queues for detection threads
                for chunk_queue in self.chunk_queues:
                    chunk_queue.put(audio_data)

        except Exception:
            _LOGGER.exception("audio_thread_proc")

    def detection_thread_proc(self, chunk_queue: queue.Queue, raven: Raven):
        """Run Raven detection on audio chunks.

        Consumes chunks from ``chunk_queue`` until a ``None`` shutdown
        signal arrives. A detection is reported when at least one template
        matched and the number of matches reaches ``raven.minimum_matches``.
        """
        try:
            _LOGGER.debug(
                "Listening for keyword %s (probability_threshold=%s, minimum_matches=%s, num_templates=%s)",
                raven.keyword_name,
                raven.probability_threshold,
                raven.minimum_matches,
                len(raven.templates),
            )

            while True:
                audio_data = chunk_queue.get()
                if audio_data is None:
                    # Shutdown signal
                    break

                if audio_data:
                    try:
                        keep_audio = bool(self.examples_dir)
                        matching_indexes = raven.process_chunk(
                            audio_data, keep_audio=keep_audio)

                        # An empty match list is never a detection, even
                        # when minimum_matches is 0; handle_detection needs
                        # at least one matching template.
                        if matching_indexes and (len(matching_indexes) >=
                                                 raven.minimum_matches):
                            # Report detection
                            assert self.loop is not None, "No loop"
                            asyncio.run_coroutine_threadsafe(
                                self.publish_all(
                                    self.handle_detection(
                                        matching_indexes, raven)),
                                self.loop,
                            )

                            if keep_audio:
                                # Save positive example
                                assert self.examples_dir is not None
                                example_path = self.examples_dir / time.strftime(
                                    self.examples_format).format(
                                        keyword=raven.keyword_name)
                                example_path.parent.mkdir(parents=True,
                                                          exist_ok=True)

                                with open(example_path, "wb") as example_file:
                                    example_wav_bytes = self.to_wav_bytes(
                                        raven.example_audio_buffer)
                                    example_file.write(example_wav_bytes)

                                _LOGGER.debug("Wrote example to %s",
                                              example_path)
                    except Exception:
                        # Keep the thread alive if a single chunk fails
                        _LOGGER.exception("process_chunk")

        except Exception:
            _LOGGER.exception("detection_thread_proc")

    # -------------------------------------------------------------------------

    def stop(self):
        """Stop audio and detection threads."""
        # A None chunk is the shutdown signal for the audio thread, which in
        # turn stops and joins the detection threads.
        self.wav_queue.put((None, ""))
        _LOGGER.debug("Waiting for detection threads to stop...")
        self.audio_thread.join()

    # -------------------------------------------------------------------------

    def udp_thread_proc(self, host: str, port: int, site_id: str):
        """Handle WAV chunks from UDP socket.

        Datagrams received while the service is disabled are dropped.
        """
        try:
            # Context manager closes the socket if the loop exits on error
            with socket.socket(socket.AF_INET,
                               socket.SOCK_DGRAM) as udp_socket:
                udp_socket.bind((host, port))
                _LOGGER.debug("Listening for audio on UDP %s:%s", host, port)

                while True:
                    wav_bytes, _ = udp_socket.recvfrom(self.udp_chunk_size +
                                                       WAV_HEADER_BYTES)

                    if self.enabled:
                        self.wav_queue.put((wav_bytes, site_id))
        except Exception:
            _LOGGER.exception("udp_thread_proc")

    # -------------------------------------------------------------------------

    async def on_message_blocking(
        self,
        message: Message,
        site_id: typing.Optional[str] = None,
        session_id: typing.Optional[str] = None,
        topic: typing.Optional[str] = None,
    ) -> GeneratorType:
        """Received message from MQTT broker.

        Handles enable/disable toggles, audio frames and hotword listing.
        ``RecordHotwordExample`` is deliberately left to ``on_message``.
        """
        # Check enable/disable messages
        if isinstance(message, HotwordToggleOn):
            if message.reason == HotwordToggleReason.UNKNOWN:
                # Always enable on unknown
                self.disabled_reasons.clear()
            else:
                self.disabled_reasons.discard(message.reason)

            if self.disabled_reasons:
                _LOGGER.debug("Still disabled: %s", self.disabled_reasons)
            else:
                self.enabled = True
                self.first_audio = True
                _LOGGER.debug("Enabled")
        elif isinstance(message, HotwordToggleOff):
            self.enabled = False
            self.disabled_reasons.add(message.reason)
            _LOGGER.debug("Disabled")
        elif isinstance(message, AudioFrame):
            if self.enabled:
                assert site_id, "Missing site_id"
                await self.handle_audio_frame(message.wav_bytes,
                                              site_id=site_id)
        elif isinstance(message, GetHotwords):
            async for hotword_result in self.handle_get_hotwords(message):
                yield hotword_result
        elif isinstance(message, RecordHotwordExample):
            # Handled in on_message
            pass
        else:
            _LOGGER.warning("Unexpected message: %s", message)

    async def on_message(
        self,
        message: Message,
        site_id: typing.Optional[str] = None,
        session_id: typing.Optional[str] = None,
        topic: typing.Optional[str] = None,
    ) -> GeneratorType:
        """Received message from MQTT broker (non-blocking).

        Only ``RecordHotwordExample`` is handled here.
        """
        if isinstance(message, RecordHotwordExample):
            async for example_result in self.handle_record_example(message):
                yield example_result
def main():
    """Run the Raven wake word detector on raw audio from standard input.

    Parses the command line, builds a ``WebRtcVadRecorder`` and then either
    records new templates (``--record``) or loads WAV templates, creates a
    ``Raven`` instance and feeds it fixed-size audio chunks read from stdin.
    Chunks are handed to ``detect_thread_proc`` (running in a daemon thread)
    through a queue; a ``None`` item tells that thread to quit.

    Returns:
        Whatever ``record_templates`` returns when ``--record`` is given,
        otherwise ``None``.
    """
    # Local import: only needed for the non-blocking queue drain below.
    from queue import Empty

    parser = argparse.ArgumentParser(prog="rhasspy-wake-raven")
    parser.add_argument(
        "templates", nargs="+", help="Path to WAV file templates or directories"
    )
    parser.add_argument(
        "--record",
        help="Record example templates with given name format (e.g., 'okay-rhasspy-{n:02d}.wav')",
    )
    parser.add_argument(
        "--probability-threshold",
        type=float,
        default=0.5,
        help="Probability above which detection occurs (default: 0.5)",
    )
    parser.add_argument(
        "--distance-threshold",
        type=float,
        default=0.22,
        help="Normalized dynamic time warping distance threshold for template matching (default: 0.22)",
    )
    parser.add_argument(
        "--minimum-matches",
        type=int,
        default=1,
        help="Number of templates that must match to produce output (default: 1)",
    )
    parser.add_argument(
        "--refractory-seconds",
        type=float,
        default=2.0,
        help="Seconds before wake word can be activated again (default: 2)",
    )
    parser.add_argument(
        "--print-all-matches",
        action="store_true",
        help="Print JSON for all matching templates instead of just the first one",
    )
    parser.add_argument(
        "--window-shift-seconds",
        type=float,
        default=0.01,
        help="Seconds to shift sliding time window on audio buffer (default: 0.01)",
    )
    parser.add_argument(
        "--dtw-window-size",
        type=int,
        default=5,
        help="Size of band around slanted diagonal during dynamic time warping calculation (default: 5)",
    )
    parser.add_argument(
        "--vad-sensitivity",
        type=int,
        choices=[1, 2, 3],
        default=3,
        help="Webrtcvad VAD sensitivity (1-3)",
    )
    parser.add_argument(
        "--current-threshold",
        type=float,
        help="Debiased energy threshold of current audio frame",
    )
    parser.add_argument(
        "--max-energy",
        type=float,
        help="Fixed maximum energy for ratio calculation (default: observed)",
    )
    parser.add_argument(
        "--max-current-ratio-threshold",
        type=float,
        help="Threshold of ratio between max energy and current audio frame",
    )
    parser.add_argument(
        "--silence-method",
        choices=[e.value for e in SilenceMethod],
        default=SilenceMethod.VAD_ONLY,
        help="Method for detecting silence",
    )
    parser.add_argument(
        "--average-templates",
        action="store_true",
        help="Average wakeword templates together to reduce number of calculations",
    )
    parser.add_argument(
        "--exit-count",
        type=int,
        help="Exit after some number of detections (default: never)",
    )
    parser.add_argument(
        "--read-entire-input",
        action="store_true",
        help="Read entire audio input at start and exit after processing",
    )
    parser.add_argument(
        "--max-chunks-in-queue",
        type=int,
        help="Maximum number of audio chunks waiting for processing before being dropped",
    )
    parser.add_argument(
        "--skip-probability-threshold",
        type=float,
        default=0,
        help="Skip additional template calculations if probability is below this threshold",
    )
    parser.add_argument(
        "--debug", action="store_true", help="Print DEBUG messages to the console"
    )
    args = parser.parse_args()

    if args.debug:
        logging.basicConfig(level=logging.DEBUG)
    else:
        logging.basicConfig(level=logging.INFO)

    # Create silence detector
    recorder = WebRtcVadRecorder(
        vad_mode=args.vad_sensitivity,
        silence_method=args.silence_method,
        current_energy_threshold=args.current_threshold,
        max_energy=args.max_energy,
        max_current_ratio_threshold=args.max_current_ratio_threshold,
        min_seconds=0.5,
        before_seconds=1,
    )

    if args.record:
        # Do recording instead of recognizing
        return record_templates(args.record, recorder, args)

    # Load audio templates
    template_paths: typing.List[Path] = []
    for template_path_str in args.templates:
        template_path = Path(template_path_str)
        if template_path.is_dir():
            # Add all WAV files from directory
            _LOGGER.debug("Adding WAV files from directory %s", template_path)
            template_paths.extend(template_path.glob("*.wav"))
        elif template_path.is_file():
            # Add file directly
            template_paths.append(template_path)

    templates = [Raven.wav_to_template(p, name=str(p)) for p in template_paths]
    if args.average_templates:
        _LOGGER.debug("Averaging %s templates", len(templates))
        templates = [Template.average_templates(templates)]

    # Create Raven object
    raven = Raven(
        templates=templates,
        recorder=recorder,
        probability_threshold=args.probability_threshold,
        minimum_matches=args.minimum_matches,
        distance_threshold=args.distance_threshold,
        refractory_sec=args.refractory_seconds,
        shift_sec=args.window_shift_seconds,
        skip_probability_threshold=args.skip_probability_threshold,
        debug=args.debug,
    )

    print("Reading 16-bit 16Khz raw audio from stdin...", file=sys.stderr)

    if args.read_entire_input:
        audio_buffer = FakeStdin(sys.stdin.buffer.read())
    else:
        audio_buffer = sys.stdin.buffer

    chunk_queue = Queue()
    detect_thread = threading.Thread(
        target=detect_thread_proc, args=(chunk_queue, raven, args), daemon=True
    )
    detect_thread.start()

    try:
        while True:
            # Read raw audio chunk
            chunk = audio_buffer.read(raven.chunk_size)
            if not chunk:
                # Empty chunk
                break

            # Ensure chunk is the right size. A short read is topped up until
            # the chunk is full or the input is exhausted. The emptiness check
            # must be on the newly read data: `chunk` itself is never empty
            # here, so testing it would loop forever at end of input.
            while len(chunk) < raven.chunk_size:
                more = audio_buffer.read(raven.chunk_size - len(chunk))
                if not more:
                    break
                chunk += more

            if len(chunk) < raven.chunk_size:
                # Input ended in the middle of a chunk; the trailing partial
                # chunk is dropped so only full-size chunks are queued.
                break

            chunk_queue.put(chunk)
    except KeyboardInterrupt:
        pass
    finally:
        if not args.read_entire_input:
            # Exhaust queue so the detection thread sees the quit signal
            # promptly. Skipped for --read-entire-input, where the queued
            # audio is exactly what still has to be processed before exiting.
            # get_nowait() is used because the detection thread may take the
            # last item between an empty() check and a blocking get().
            try:
                while True:
                    chunk_queue.get_nowait()
            except Empty:
                pass

        # Signal thread to quit
        chunk_queue.put(None)
        detect_thread.join()
def main():
    """Start the Raven wake word service on a Hermes MQTT connection.

    Parses the command line (including the shared Hermes options), loads WAV
    templates from ``--template-dir`` or, when none are found there, from the
    ``templates`` directory under ``_DIR``, builds a ``Raven`` detector and a
    ``WakeHermesMqtt`` client, then runs the client's message loop until
    interrupted.
    """
    cli = argparse.ArgumentParser(prog="rhasspy-wake-raven-hermes")

    # Options are registered in the order listed here.
    option_table = (
        (
            "--template-dir",
            dict(
                help="Directory with Raven WAV templates (default: templates in Python module)",
            ),
        ),
        (
            "--probability-threshold",
            dict(
                type=float,
                default=0.5,
                help="Probability above which detection occurs (default: 0.5)",
            ),
        ),
        (
            "--distance-threshold",
            dict(
                type=float,
                default=0.22,
                help="Normalized dynamic time warping distance threshold for template matching (default: 0.22)",
            ),
        ),
        (
            "--minimum-matches",
            dict(
                type=int,
                default=1,
                help="Number of templates that must match to produce output (default: 1)",
            ),
        ),
        (
            "--refractory-seconds",
            dict(
                type=float,
                default=2.0,
                help="Seconds before wake word can be activated again (default: 2)",
            ),
        ),
        (
            "--window-shift-seconds",
            dict(
                type=float,
                default=0.05,
                help="Seconds to shift sliding time window on audio buffer (default: 0.05)",
            ),
        ),
        (
            "--dtw-window-size",
            dict(
                type=int,
                default=5,
                help="Size of band around slanted diagonal during dynamic time warping calculation (default: 5)",
            ),
        ),
        (
            "--vad-sensitivity",
            dict(
                type=int,
                choices=[1, 2, 3],
                default=3,
                help="Webrtcvad VAD sensitivity (1-3)",
            ),
        ),
        (
            "--current-threshold",
            dict(
                type=float,
                help="Debiased energy threshold of current audio frame",
            ),
        ),
        (
            "--max-energy",
            dict(
                type=float,
                help="Fixed maximum energy for ratio calculation (default: observed)",
            ),
        ),
        (
            "--max-current-ratio-threshold",
            dict(
                type=float,
                help="Threshold of ratio between max energy and current audio frame",
            ),
        ),
        (
            "--silence-method",
            dict(
                choices=[e.value for e in SilenceMethod],
                default=SilenceMethod.VAD_ONLY,
                help="Method for detecting silence",
            ),
        ),
        (
            "--average-templates",
            dict(
                action="store_true",
                help="Average wakeword templates together to reduce number of calculations",
            ),
        ),
        (
            "--wakeword-id",
            dict(
                default="",
                help="Wakeword ID for model (default: use file name)",
            ),
        ),
        (
            "--udp-audio",
            dict(
                nargs=3,
                action="append",
                help="Host/port/siteId for UDP audio input",
            ),
        ),
        (
            "--log-predictions",
            dict(
                action="store_true",
                help="Log prediction probabilities for each audio chunk (very verbose)",
            ),
        ),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    hermes_cli.add_hermes_args(cli)
    args = cli.parse_args()

    hermes_cli.setup_logging(args)
    _LOGGER.debug(args)

    # Prefer templates from the user-supplied directory, if it has any
    wav_paths: typing.List[Path] = []
    if args.template_dir:
        args.template_dir = Path(args.template_dir)
        if args.template_dir.is_dir():
            _LOGGER.debug("Loading WAV templates from %s", args.template_dir)
            wav_paths = list(args.template_dir.glob("*.wav"))
            if not wav_paths:
                _LOGGER.warning("No WAV templates found!")

    # Otherwise fall back to the templates directory under _DIR
    if not wav_paths:
        args.template_dir = _DIR / "templates"
        _LOGGER.debug("Loading WAV templates from %s", args.template_dir)
        wav_paths = list(args.template_dir.glob("*.wav"))

    # Silence detector
    recorder_settings = {
        "vad_mode": args.vad_sensitivity,
        "silence_method": args.silence_method,
        "current_energy_threshold": args.current_threshold,
        "max_energy": args.max_energy,
        "max_current_ratio_threshold": args.max_current_ratio_threshold,
    }
    recorder = WebRtcVadRecorder(**recorder_settings)

    # One template per WAV file, optionally collapsed into a single average
    templates = [
        Raven.wav_to_template(wav_path, name=wav_path.name) for wav_path in wav_paths
    ]
    if args.average_templates:
        _LOGGER.debug("Averaging %s templates", len(templates))
        templates = [Template.average_templates(templates)]

    raven_settings = {
        "templates": templates,
        "recorder": recorder,
        "probability_threshold": args.probability_threshold,
        "minimum_matches": args.minimum_matches,
        "distance_threshold": args.distance_threshold,
        "refractory_sec": args.refractory_seconds,
        "shift_sec": args.window_shift_seconds,
        "debug": args.log_predictions,
    }
    raven = Raven(**raven_settings)

    # Each --udp-audio occurrence is (host, port, siteId); port becomes an int
    udp_audio = [
        (udp_host, int(udp_port), udp_site_id)
        for udp_host, udp_port, udp_site_id in (args.udp_audio or [])
    ]

    # Listen for messages
    mqtt_client = mqtt.Client()
    hermes = WakeHermesMqtt(
        mqtt_client,
        raven=raven,
        minimum_matches=args.minimum_matches,
        wakeword_id=args.wakeword_id,
        udp_audio=udp_audio,
        site_ids=args.site_id,
    )

    _LOGGER.debug("Connecting to %s:%s", args.host, args.port)
    hermes_cli.connect(mqtt_client, args)
    mqtt_client.loop_start()

    try:
        # Run event loop
        asyncio.run(hermes.handle_messages_async())
    except KeyboardInterrupt:
        pass
    finally:
        _LOGGER.debug("Shutting down")
        mqtt_client.loop_stop()
class WakeHermesMqtt(HermesClient):
    """Hermes MQTT server for Rhasspy wakeword with Raven.

    Incoming audio (MQTT ``AudioFrame`` messages and, optionally, UDP
    datagrams) is placed on ``wav_queue`` together with its site id. A daemon
    detection thread takes items off that queue, converts them with
    ``maybe_convert_wav``, and feeds ``chunk_size``-byte pieces to
    ``raven.process_chunk``. While an example recording is in progress the
    audio goes to ``example_recorder`` instead of the detector.
    """

    def __init__(
        self,
        client,
        raven: Raven,
        minimum_matches: int = 1,
        wakeword_id: str = "",
        site_ids: typing.Optional[typing.List[str]] = None,
        enabled: bool = True,
        sample_rate: int = 16000,
        sample_width: int = 2,
        channels: int = 1,
        chunk_size: int = 960,
        udp_audio: typing.Optional[typing.List[typing.Tuple[str, int, str]]] = None,
        udp_chunk_size: int = 2048,
        log_predictions: bool = False,
    ):
        """Subscribe to wake word messages and start the worker threads.

        Args:
            client: MQTT client handed to the ``HermesClient`` base class.
            raven: Detector whose ``process_chunk`` is called on each chunk.
            minimum_matches: Number of matching templates required before a
                detection is published.
            wakeword_id: Wakeword id to publish; when empty, the matching
                template's name is used.
            site_ids: Passed through to the base class.
            enabled: Whether incoming audio is processed initially.
            sample_rate: Required audio sample rate (also given to the base).
            sample_width: Required audio sample width (also given to the base).
            channels: Required channel count (also given to the base).
            chunk_size: Number of bytes passed to ``raven.process_chunk`` at
                a time.
            udp_audio: Optional ``(host, port, site_id)`` tuples; one listener
                thread is started per entry.
            udp_chunk_size: Payload size used (plus ``WAV_HEADER_BYTES``) as
                the UDP receive buffer size.
            log_predictions: Accepted, but not used anywhere in this class.
        """
        super().__init__(
            "rhasspywake_raven_hermes",
            client,
            sample_rate=sample_rate,
            sample_width=sample_width,
            channels=channels,
            site_ids=site_ids,
        )

        self.subscribe(
            AudioFrame,
            HotwordToggleOn,
            HotwordToggleOff,
            GetHotwords,
            RecordHotwordExample,
        )

        self.raven = raven
        self.minimum_matches = minimum_matches
        self.wakeword_id = wakeword_id

        self.enabled = enabled
        # Reasons given by HotwordToggleOff messages that are still in effect
        self.disabled_reasons: typing.Set[str] = set()

        # Required audio format
        self.sample_rate = sample_rate
        self.sample_width = sample_width
        self.channels = channels

        self.chunk_size = chunk_size

        # Queue of WAV audio chunks to process (plus site_id)
        self.wav_queue: queue.Queue = queue.Queue()

        self.first_audio: bool = True
        # Converted audio not yet handed to the detector
        self.audio_buffer = bytes()

        self.last_audio_site_id: str = "default"

        # Fields for recording examples
        self.recording_example = False
        self.example_recorder = WebRtcVadRecorder(max_seconds=10)
        self.example_future: typing.Optional[asyncio.Future] = None

        # Start threads
        self.detection_thread = threading.Thread(
            target=self.detection_thread_proc, daemon=True
        )
        self.detection_thread.start()

        # Listen for raw audio on UDP too
        self.udp_chunk_size = udp_chunk_size

        if udp_audio:
            for udp_host, udp_port, udp_site_id in udp_audio:
                threading.Thread(
                    target=self.udp_thread_proc,
                    args=(udp_host, udp_port, udp_site_id),
                    daemon=True,
                ).start()

    # -------------------------------------------------------------------------

    async def handle_audio_frame(
        self, wav_bytes: bytes, site_id: str = "default"
    ) -> None:
        """Queue a single audio frame for the detection thread."""
        self.wav_queue.put((wav_bytes, site_id))

    async def handle_detection(
        self, matching_indexes: typing.List[int]
    ) -> typing.AsyncIterable[
        typing.Union[typing.Tuple[HotwordDetected, TopicArgs], HotwordError]
    ]:
        """Yield the message for a successful hotword detection.

        The first matching template supplies the model id. Yields a
        ``HotwordError`` instead if building the message fails.
        """
        try:
            template = self.raven.templates[matching_indexes[0]]

            # Fall back to the template name when no wakeword id was configured
            wakeword_id = self.wakeword_id
            if not wakeword_id:
                wakeword_id = template.name

            yield (
                HotwordDetected(
                    site_id=self.last_audio_site_id,
                    model_id=template.name,
                    current_sensitivity=self.raven.distance_threshold,
                    model_version="",
                    model_type="personal",
                ),
                {"wakeword_id": wakeword_id},
            )
        except Exception as e:
            _LOGGER.exception("handle_detection")
            yield HotwordError(
                error=str(e),
                context=str(matching_indexes),
                site_id=self.last_audio_site_id,
            )

    async def handle_get_hotwords(
        self, get_hotwords: GetHotwords
    ) -> typing.AsyncIterable[typing.Union[Hotwords, HotwordError]]:
        """Report available hotwords (currently always an empty model list)."""
        try:
            yield Hotwords(models=[], id=get_hotwords.id, site_id=get_hotwords.site_id)
        except Exception as e:
            _LOGGER.exception("handle_get_hotwords")
            yield HotwordError(
                error=str(e), context=str(get_hotwords), site_id=get_hotwords.site_id
            )

    async def handle_record_example(
        self, record_example: RecordHotwordExample
    ) -> typing.AsyncIterable[
        typing.Union[typing.Tuple[HotwordExampleRecorded, TopicArgs], HotwordError]
    ]:
        """Record an example of a hotword.

        Starts the example recorder, waits for the detection thread to
        deliver the recorded audio through ``example_future``, trims silence
        and yields the result as WAV data. Yields a ``HotwordError`` on any
        failure, including a recording already being in progress.
        """
        try:
            assert (
                not self.recording_example
            ), "Only one example can be recorded at a time"

            # Start recording
            assert self.loop, "No loop"
            self.example_future = self.loop.create_future()
            self.example_recorder.start()
            self.recording_example = True

            # Wait for result
            _LOGGER.debug("Recording example (id=%s)", record_example.id)
            example_audio = await self.example_future
            assert isinstance(example_audio, bytes)

            # Trim silence
            _LOGGER.debug("Trimming silence from example")
            example_audio = trim_silence(example_audio)

            # Convert to WAV format
            wav_data = self.to_wav_bytes(example_audio)

            yield (
                HotwordExampleRecorded(wav_bytes=wav_data),
                {"site_id": record_example.site_id, "request_id": record_example.id},
            )
        except Exception as e:
            _LOGGER.exception("handle_record_example")
            yield HotwordError(
                error=str(e),
                context=str(record_example),
                site_id=record_example.site_id,
            )

    def add_example_audio(self, audio_data: bytes):
        """Add an audio frame to the currently recording example.

        When the recorder reports a result, recording stops and the audio is
        passed to the coroutine waiting on ``example_future``.
        """
        result = self.example_recorder.process_chunk(audio_data)
        if result:
            self.recording_example = False
            assert self.example_future is not None, "No future"
            example_audio = self.example_recorder.stop()
            _LOGGER.debug(
                "Recorded %s byte(s) for audio for example", len(example_audio)
            )

            # Signal waiting coroutine with audio. This method is called from
            # the detection thread, so the result is set via the event loop.
            assert self.loop, "No loop"
            self.loop.call_soon_threadsafe(
                self.example_future.set_result, example_audio
            )

    # -------------------------------------------------------------------------

    def detection_thread_proc(self):
        """Handle WAV audio chunks.

        Runs until a ``(None, ...)`` item is taken from ``wav_queue``. Errors
        from ``raven.process_chunk`` are logged and processing continues; any
        other error is logged and ends the thread.
        """
        try:
            while True:
                wav_bytes, site_id = self.wav_queue.get()
                if wav_bytes is None:
                    # Shutdown signal
                    break

                self.last_audio_site_id = site_id

                # Handle audio frames
                if self.first_audio:
                    _LOGGER.debug("Receiving audio")
                    self.first_audio = False

                # Extract/convert audio data
                audio_data = self.maybe_convert_wav(wav_bytes)

                if self.recording_example:
                    # Add to currently recording example
                    self.add_example_audio(audio_data)

                    # Don't process audio for wake word while recording
                    self.audio_buffer = bytes()
                    continue

                # Add to persistent buffer
                self.audio_buffer += audio_data

                # Process in chunks.
                # Any remaining audio data will be kept in buffer.
                while len(self.audio_buffer) >= self.chunk_size:
                    chunk = self.audio_buffer[: self.chunk_size]
                    self.audio_buffer = self.audio_buffer[self.chunk_size :]

                    if chunk:
                        try:
                            matching_indexes = self.raven.process_chunk(chunk)
                            if len(matching_indexes) >= self.minimum_matches:
                                # Publish from the event loop, not this thread
                                asyncio.run_coroutine_threadsafe(
                                    self.publish_all(
                                        self.handle_detection(matching_indexes)
                                    ),
                                    self.loop,
                                )
                        except Exception:
                            _LOGGER.exception("process_chunk")
        except Exception:
            _LOGGER.exception("detection_thread_proc")

    # -------------------------------------------------------------------------

    def udp_thread_proc(self, host: str, port: int, site_id: str):
        """Handle WAV chunks from UDP socket.

        Binds a UDP socket to ``host``/``port`` and queues every datagram
        received while the service is enabled, tagged with ``site_id``. Any
        error is logged and ends the thread.
        """
        try:
            # The context manager closes the socket if bind() or recvfrom()
            # fails, instead of leaving it open after the thread has exited.
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as udp_socket:
                udp_socket.bind((host, port))
                _LOGGER.debug("Listening for audio on UDP %s:%s", host, port)

                while True:
                    wav_bytes, _ = udp_socket.recvfrom(
                        self.udp_chunk_size + WAV_HEADER_BYTES
                    )

                    if self.enabled:
                        self.wav_queue.put((wav_bytes, site_id))
        except Exception:
            _LOGGER.exception("udp_thread_proc")

    # -------------------------------------------------------------------------

    async def on_message_blocking(
        self,
        message: Message,
        site_id: typing.Optional[str] = None,
        session_id: typing.Optional[str] = None,
        topic: typing.Optional[str] = None,
    ) -> GeneratorType:
        """Received message from MQTT broker.

        Handles enable/disable toggles, audio frames and hotword queries.
        ``RecordHotwordExample`` is deliberately ignored here because
        ``on_message`` handles it.
        """
        # Check enable/disable messages
        if isinstance(message, HotwordToggleOn):
            if message.reason == HotwordToggleReason.UNKNOWN:
                # Always enable on unknown
                self.disabled_reasons.clear()
            else:
                self.disabled_reasons.discard(message.reason)

            if self.disabled_reasons:
                _LOGGER.debug("Still disabled: %s", self.disabled_reasons)
            else:
                self.enabled = True
                self.first_audio = True
                _LOGGER.debug("Enabled")
        elif isinstance(message, HotwordToggleOff):
            self.enabled = False
            self.disabled_reasons.add(message.reason)
            _LOGGER.debug("Disabled")
        elif isinstance(message, AudioFrame):
            if self.enabled:
                assert site_id, "Missing site_id"
                await self.handle_audio_frame(message.wav_bytes, site_id=site_id)
        elif isinstance(message, GetHotwords):
            async for hotword_result in self.handle_get_hotwords(message):
                yield hotword_result
        elif isinstance(message, RecordHotwordExample):
            # Handled in on_message
            pass
        else:
            _LOGGER.warning("Unexpected message: %s", message)

    async def on_message(
        self,
        message: Message,
        site_id: typing.Optional[str] = None,
        session_id: typing.Optional[str] = None,
        topic: typing.Optional[str] = None,
    ) -> GeneratorType:
        """Received message from MQTT broker (non-blocking).

        Only ``RecordHotwordExample`` is handled here; every other message
        type is ignored.
        """
        if isinstance(message, RecordHotwordExample):
            async for example_result in self.handle_record_example(message):
                yield example_result
def main():
    """Run one Raven wake word detector per keyword on audio from stdin.

    Parses the command line, builds a shared ``WebRtcVadRecorder`` and then
    either records new templates (``--record``) or, for every ``--keyword``,
    loads the WAV templates from its directory and starts a ``Raven`` instance
    in its own daemon thread (``detect_thread_proc``). Each audio chunk read
    from standard input is put on every detector's queue; detections go
    through ``output_queue`` to ``output_thread_proc``. A ``None`` item tells
    a thread to quit.

    Each ``--keyword`` takes a template directory followed by optional
    ``setting-name=value`` overrides. Recognised names are ``name``,
    ``probability-threshold``, ``minimum-matches`` and ``average-templates``;
    any other name is ignored.

    Returns:
        Whatever ``record_templates`` returns when ``--record`` is given,
        otherwise ``None``.
    """
    # Local import: only needed for the non-blocking queue drain below.
    from queue import Empty

    parser = argparse.ArgumentParser(prog="rhasspy-wake-raven")
    parser.add_argument(
        "--keyword",
        action="append",
        nargs="+",
        default=[],
        help="Directory with WAV templates and settings (setting-name=value)",
    )
    parser.add_argument(
        "--chunk-size",
        # Without type=int a user-supplied value stays a string and
        # audio_buffer.read() below raises TypeError.
        type=int,
        default=1920,
        help="Number of bytes to read at a time from standard in (default: 1920)",
    )
    parser.add_argument(
        "--record",
        nargs="+",
        help="Record example templates to a directory, optionally with given name format (e.g., 'my-keyword-{n:02d}.wav')",
    )
    parser.add_argument(
        "--probability-threshold",
        type=float,
        default=0.5,
        help="Probability above which detection occurs (default: 0.5)",
    )
    parser.add_argument(
        "--distance-threshold",
        type=float,
        default=0.22,
        help="Normalized dynamic time warping distance threshold for template matching (default: 0.22)",
    )
    parser.add_argument(
        "--minimum-matches",
        type=int,
        default=1,
        help="Number of templates that must match to produce output (default: 1)",
    )
    parser.add_argument(
        "--refractory-seconds",
        type=float,
        default=2.0,
        help="Seconds before wake word can be activated again (default: 2)",
    )
    parser.add_argument(
        "--print-all-matches",
        action="store_true",
        help="Print JSON for all matching templates instead of just the first one",
    )
    parser.add_argument(
        "--window-shift-seconds",
        type=float,
        default=Raven.DEFAULT_SHIFT_SECONDS,
        help=f"Seconds to shift sliding time window on audio buffer (default: {Raven.DEFAULT_SHIFT_SECONDS})",
    )
    parser.add_argument(
        "--dtw-window-size",
        type=int,
        default=5,
        help="Size of band around slanted diagonal during dynamic time warping calculation (default: 5)",
    )
    parser.add_argument(
        "--vad-sensitivity",
        type=int,
        choices=[1, 2, 3],
        default=1,
        help="Webrtcvad VAD sensitivity (1-3)",
    )
    parser.add_argument(
        "--current-threshold",
        type=float,
        help="Debiased energy threshold of current audio frame",
    )
    parser.add_argument(
        "--max-energy",
        type=float,
        help="Fixed maximum energy for ratio calculation (default: observed)",
    )
    parser.add_argument(
        "--max-current-ratio-threshold",
        type=float,
        help="Threshold of ratio between max energy and current audio frame",
    )
    parser.add_argument(
        "--silence-method",
        choices=[e.value for e in SilenceMethod],
        default=SilenceMethod.VAD_ONLY,
        help="Method for detecting silence",
    )
    parser.add_argument(
        "--average-templates",
        action="store_true",
        help="Average wakeword templates together to reduce number of calculations",
    )
    parser.add_argument(
        "--exit-count",
        type=int,
        help="Exit after some number of detections (default: never)",
    )
    parser.add_argument(
        "--read-entire-input",
        action="store_true",
        help="Read entire audio input at start and exit after processing",
    )
    parser.add_argument(
        "--max-chunks-in-queue",
        type=int,
        help="Maximum number of audio chunks waiting for processing before being dropped",
    )
    parser.add_argument(
        "--skip-probability-threshold",
        type=float,
        default=0,
        help="Skip additional template calculations if probability is below this threshold",
    )
    parser.add_argument(
        "--failed-matches-to-refractory",
        type=int,
        help="Number of failed template matches before entering refractory period (default: disabled)",
    )
    parser.add_argument(
        "--benchmark",
        action="store_true",
        help="Track timings and report benchmark results on exit",
    )
    parser.add_argument(
        "--debug", action="store_true", help="Print DEBUG messages to the console"
    )
    args = parser.parse_args()

    if args.debug:
        logging.basicConfig(level=logging.DEBUG)
    else:
        logging.basicConfig(level=logging.INFO)

    # Create silence detector.
    # This can be shared by Raven instances because it's not maintaining state.
    recorder = WebRtcVadRecorder(
        vad_mode=args.vad_sensitivity,
        silence_method=args.silence_method,
        current_energy_threshold=args.current_threshold,
        max_energy=args.max_energy,
        max_current_ratio_threshold=args.max_current_ratio_threshold,
        min_seconds=0.5,
        before_seconds=1,
    )

    if args.record:
        # Do recording instead of recognizing
        record_dir = Path(args.record[0])
        if len(args.record) > 1:
            record_format = args.record[1]
        else:
            record_format = "example-{n:02d}.wav"

        return record_templates(record_dir, record_format, recorder, args)

    assert args.keyword, "--keyword is required"

    # Instances of Raven that will run in separate threads
    ravens: typing.List[RavenInstance] = []

    # Queue for detections. Handled in separate thread.
    output_queue = Queue()

    # Load one or more keywords
    for keyword_settings in args.keyword:
        # First value is the template directory; the rest are settings
        template_dir = Path(keyword_settings[0])
        wav_paths = list(template_dir.glob("*.wav"))
        if not wav_paths:
            _LOGGER.warning("No WAV files found in %s", template_dir)
            continue

        keyword_name = template_dir.name

        # Load audio templates
        keyword_templates = [
            Raven.wav_to_template(p, name=str(p), shift_sec=args.window_shift_seconds)
            for p in wav_paths
        ]

        raven_args = {
            "templates": keyword_templates,
            "keyword_name": keyword_name,
            "recorder": recorder,
            "probability_threshold": args.probability_threshold,
            "minimum_matches": args.minimum_matches,
            "distance_threshold": args.distance_threshold,
            "refractory_sec": args.refractory_seconds,
            "shift_sec": args.window_shift_seconds,
            "skip_probability_threshold": args.skip_probability_threshold,
            "failed_matches_to_refractory": args.failed_matches_to_refractory,
            "debug": args.debug,
            "benchmark": args.benchmark,
        }

        # Apply settings (per-keyword overrides of the global options)
        average_templates = args.average_templates
        for setting_str in keyword_settings[1:]:
            setting_name, setting_value = setting_str.strip().split("=", maxsplit=1)
            setting_name = setting_name.lower()

            if setting_name == "name":
                raven_args["keyword_name"] = setting_value
            elif setting_name == "probability-threshold":
                raven_args["probability_threshold"] = float(setting_value)
            elif setting_name == "minimum-matches":
                raven_args["minimum_matches"] = int(setting_value)
            elif setting_name == "average-templates":
                average_templates = setting_value.lower().strip() == "true"

        if average_templates:
            _LOGGER.debug(
                "Averaging %s templates for %s", len(keyword_templates), template_dir
            )
            raven_args["templates"] = [Template.average_templates(keyword_templates)]

        # Create instance of Raven in a separate thread for keyword
        raven = Raven(**raven_args)
        chunk_queue: "Queue[bytes]" = Queue()

        ravens.append(
            RavenInstance(
                thread=threading.Thread(
                    target=detect_thread_proc,
                    args=(chunk_queue, raven, output_queue, args),
                    daemon=True,
                ),
                raven=raven,
                chunk_queue=chunk_queue,
            )
        )

    # Start all threads
    for raven_inst in ravens:
        raven_inst.thread.start()

    output_thread = threading.Thread(
        target=output_thread_proc, args=(output_queue,), daemon=True
    )
    output_thread.start()

    # -------------------------------------------------------------------------

    print("Reading 16-bit 16Khz raw audio from stdin...", file=sys.stderr)

    if args.read_entire_input:
        audio_buffer = FakeStdin(sys.stdin.buffer.read())
    else:
        audio_buffer = sys.stdin.buffer

    try:
        while True:
            # Read raw audio chunk
            chunk = audio_buffer.read(args.chunk_size)
            if not chunk or _EXIT_NOW:
                # Empty chunk
                break

            # Add to all detector threads
            for raven_inst in ravens:
                raven_inst.chunk_queue.put(chunk)
    except KeyboardInterrupt:
        pass
    finally:
        if not args.read_entire_input:
            # Exhaust queues
            _LOGGER.debug("Emptying audio queues...")
            for raven_inst in ravens:
                # get_nowait() is used because the detector thread may take
                # the last item between an empty() check and a blocking get().
                try:
                    while True:
                        raven_inst.chunk_queue.get_nowait()
                except Empty:
                    pass

        for raven_inst in ravens:
            # Signal thread to quit
            raven_inst.chunk_queue.put(None)
            _LOGGER.debug("Waiting for %s thread...", raven_inst.raven.keyword_name)
            raven_inst.thread.join()

        # Stop output thread
        output_queue.put(None)
        _LOGGER.debug("Waiting for output thread...")
        output_thread.join()

        if args.benchmark:
            print_benchmark(ravens)