Example #1
0
    def get_command_recorder(self):
        """Create a voice command recorder configured from the profile.

        Every value is looked up in ``self.profile`` with ``pydash.get`` using
        the dotted paths below, falls back to the listed default when absent,
        and is converted with ``int``/``float`` before being passed to
        ``WebRtcVadRecorder``.
        """
        from rhasspysilence import WebRtcVadRecorder

        # (constructor keyword, profile path, default, converter), looked up
        # in exactly this order.
        settings = (
            ("vad_mode", "voice-command.vad-mode", 3, int),
            ("min_seconds", "voice-command.minimum-seconds", 1, float),
            ("max_seconds", "voice-command.maximum-seconds", 30, float),
            ("speech_seconds", "voice-command.speech-seconds", 0.3, float),
            ("silence_seconds", "voice-command.silence-seconds", 0.5, float),
            ("before_seconds", "voice-command.before-seconds", 0.5, float),
            ("skip_seconds", "voice-command.skip-seconds", 0, float),
            ("chunk_size", "voice-command.chunk-size", 960, int),
            ("sample_rate", "audio.format.sample-rate-hertz", 16000, int),
        )

        recorder_kwargs = {
            keyword: convert(pydash.get(self.profile, path, default))
            for keyword, path, default, convert in settings
        }

        return WebRtcVadRecorder(**recorder_kwargs)
    def __init__(
        self,
        client,
        transcriber: Transcriber,
        siteIds: typing.Optional[typing.List[str]] = None,
        enabled: bool = True,
        sample_rate: int = 16000,
        sample_width: int = 2,
        channels: int = 1,
        make_recorder: typing.Optional[
            typing.Callable[[], VoiceCommandRecorder]
        ] = None,
    ):
        """Set up ASR state for an MQTT client and a transcriber.

        Args:
            client: MQTT client object, stored as-is.
            transcriber: Transcriber used for recognized audio.
            siteIds: Site ids to listen on; ``None`` means an empty list.
            enabled: Initial enabled state.
            sample_rate: Required audio sample rate (Hz).
            sample_width: Required sample width in bytes.
            channels: Required channel count.
            make_recorder: Zero-argument factory returning a new
                ``VoiceCommandRecorder``. Defaults to a ``WebRtcVadRecorder``
                with no maximum duration.
        """
        self.client = client
        self.transcriber = transcriber
        self.siteIds = siteIds or []
        self.enabled = enabled

        # Required audio format
        self.sample_rate = sample_rate
        self.sample_width = sample_width
        self.channels = channels

        # Recorder factory (default: VAD recorder with no timeout)
        self.make_recorder = make_recorder or (
            lambda: WebRtcVadRecorder(max_seconds=None)
        )

        # One recorder per session, created on first access through the
        # configured factory. Previously VoiceCommandRecorder itself was used
        # as the factory, which ignored make_recorder (and cannot produce a
        # working recorder if it is the abstract base class).
        self.session_recorders: typing.Dict[
            str, VoiceCommandRecorder
        ] = defaultdict(self.make_recorder)

        # Topics to listen for WAV chunks on (one per site id)
        self.audioframe_topics: typing.List[str] = [
            AudioFrame.topic(siteId=siteId) for siteId in self.siteIds
        ]

        self.first_audio: bool = True
class RhasspySilenceTestCase(unittest.TestCase):
    """Tests for rhasspysilence."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.recorder = None
        # Bytes of audio handed to the recorder per process_chunk call
        self.chunk_size = 2048

    def setUp(self):
        """Create and start a fresh recorder for each test."""
        self.recorder = WebRtcVadRecorder()
        self.recorder.start()

    def _first_command(self, wav_path):
        """Feed a WAV file's frames to the recorder in ``self.chunk_size`` pieces.

        Returns the first truthy result of ``process_chunk``. If no chunk
        produced one, returns the last (falsy) result, or ``None`` when the
        file has no frames.
        """
        with wave.open(wav_path, "r") as wav_file:
            audio_data = wav_file.readframes(wav_file.getnframes())

        command = None
        # Walk offsets instead of re-slicing the remaining buffer each time,
        # which copied the rest of the audio on every iteration.
        for offset in range(0, len(audio_data), self.chunk_size):
            command = self.recorder.process_chunk(
                audio_data[offset : offset + self.chunk_size]
            )
            if command:
                break

        return command

    def test_command(self):
        """Verify voice command sample WAV file."""
        command = self._first_command("etc/turn_on_living_room_lamp.wav")

        self.assertTrue(command)
        self.assertEqual(command.result, VoiceCommandResult.SUCCESS)
        self.assertGreater(len(command.audio_data), 0)

    def test_noise(self):
        """Verify no command in noise WAV file."""
        command = self._first_command("etc/noise.wav")

        self.assertFalse(command)
def test_noise():
    """A noise-only recording must not produce a voice command."""
    vad = WebRtcVadRecorder()
    vad.start()

    with wave.open("etc/noise.wav", "r") as noise_wav:
        pending = noise_wav.readframes(noise_wav.getnframes())
        detected = None
        while pending:
            piece, pending = pending[:CHUNK_SIZE], pending[CHUNK_SIZE:]
            detected = vad.process_chunk(piece)
            if detected:
                break

        assert not detected
Example #5
0
 def make_webrtcvad():
     """Build a VAD recorder without a maximum duration from the enclosing settings."""
     vad_settings = dict(
         max_seconds=None,
         vad_mode=vad_mode,
         skip_seconds=skip_seconds,
         min_seconds=min_seconds,
         speech_seconds=speech_seconds,
         silence_seconds=silence_seconds,
         before_seconds=before_seconds,
     )
     return WebRtcVadRecorder(**vad_settings)
Example #6
0
 def default_recorder():
     """Return a new untimed WebRtcVadRecorder using the surrounding VAD settings."""
     recorder = WebRtcVadRecorder(
         before_seconds=before_seconds,
         max_seconds=None,
         min_seconds=min_seconds,
         silence_seconds=silence_seconds,
         skip_seconds=skip_seconds,
         speech_seconds=speech_seconds,
         vad_mode=vad_mode,
     )
     return recorder
def test_command():
    """The lamp command recording must produce a successful voice command."""
    vad = WebRtcVadRecorder()
    vad.start()
    detected = None

    with wave.open("etc/turn_on_living_room_lamp.wav", "r") as lamp_wav:
        frames = lamp_wav.readframes(lamp_wav.getnframes())
        for offset in range(0, len(frames), CHUNK_SIZE):
            detected = vad.process_chunk(frames[offset:offset + CHUNK_SIZE])
            if detected:
                break

        assert detected
        assert detected.result == VoiceCommandResult.SUCCESS
        assert detected.audio_data
Example #8
0
def trim_silence(
    audio_bytes: bytes,
    ratio_threshold: float = 20.0,
    chunk_size: int = 960,
    skip_first_chunk=True,
    *,
    energy_fn=None,
) -> bytes:
    """Trim silence from start and end of audio using ratio of max/current energy.

    The audio is split into consecutive ``chunk_size``-byte chunks. A trailing
    partial chunk is dropped, and the very first chunk is ignored when
    ``skip_first_chunk`` is true. Each chunk's energy is clamped to at least 1.
    A chunk counts as "loud" when ``max_energy / energy < ratio_threshold``.
    The result keeps everything from one chunk before the first loud chunk
    through one chunk after the first quiet chunk that follows the last loud
    chunk, clamped to the available chunks.

    Args:
        audio_bytes: Raw audio bytes.
        ratio_threshold: Max/current energy ratio below which a chunk is loud.
        chunk_size: Chunk size in bytes; must be positive.
        skip_first_chunk: If true, the first chunk is not analyzed or kept.
        energy_fn: Optional callable mapping a chunk to its energy. Defaults to
            ``WebRtcVadRecorder.get_debiased_energy``.

    Returns:
        The kept chunks concatenated as ``bytes``.

    Raises:
        ValueError: If ``chunk_size`` is not positive (previously this looped
            forever).
        AssertionError: If no chunk was analyzed (input too short). This stays
            an ``assert`` for compatibility with existing callers; under
            ``python -O`` an empty result is returned instead.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    if energy_fn is None:
        energy_fn = WebRtcVadRecorder.get_debiased_energy

    # Visit whole chunks by offset rather than re-slicing the remaining
    # buffer each iteration (which copied the rest of the audio every time).
    num_chunks = len(audio_bytes) // chunk_size
    first_index = 1 if skip_first_chunk else 0

    energies = []
    max_energy = None
    for chunk_index in range(first_index, num_chunks):
        offset = chunk_index * chunk_size
        chunk = audio_bytes[offset : offset + chunk_size]

        # Clamp to 1 so the max/current ratio never divides by zero
        energy = max(1, energy_fn(chunk))
        energies.append((energy, chunk))

        if (max_energy is None) or (energy > max_energy):
            max_energy = energy

    # Determine chunks below threshold
    assert max_energy is not None, "No maximum energy"
    start_index = None
    end_index = None

    for i, (energy, _chunk) in enumerate(energies):
        ratio = max_energy / energy
        if ratio < ratio_threshold:
            # Loud chunk: any quiet run seen so far was not the tail
            end_index = None
            if start_index is None:
                start_index = i
        elif end_index is None:
            # First quiet chunk after a loud run
            end_index = i

    if start_index is None:
        start_index = 0

    if end_index is None:
        end_index = len(energies) - 1

    # Pad by one chunk on each side, staying within bounds
    start_index = max(0, start_index - 1)
    end_index = min(len(energies) - 1, end_index + 1)

    return b"".join(chunk for _, chunk in energies[start_index : end_index + 1])
 def default_recorder():
     """Build a WebRtcVadRecorder from the enclosing silence/energy settings."""
     recorder_options = dict(
         max_seconds=max_seconds,
         vad_mode=vad_mode,
         skip_seconds=skip_seconds,
         min_seconds=min_seconds,
         speech_seconds=speech_seconds,
         silence_seconds=silence_seconds,
         before_seconds=before_seconds,
         silence_method=silence_method,
         current_energy_threshold=current_energy_threshold,
         max_energy=max_energy,
         max_current_ratio_threshold=max_current_energy_ratio_threshold,
     )
     return WebRtcVadRecorder(**recorder_options)
Example #10
0
def record_templates(
    record_dir: Path,
    name_format: str,
    recorder: WebRtcVadRecorder,
    args: argparse.Namespace,
):
    """Interactively record wake word templates from raw audio on stdin.

    Reads ``recorder.chunk_size`` bytes at a time from ``sys.stdin.buffer``.
    Whenever the recorder reports a result, the captured audio is passed
    through ``trim_silence`` and ``buffer_to_wav``, then written to
    ``record_dir / name_format.format(n=<template index>)``. Recording stops
    at end of input or on CTRL+C. ``args`` is accepted but not used here.
    """

    def announce(template_index: int) -> None:
        # Prompt the user for the next template
        print(
            f"Recording template {template_index}. Please speak your wake word. Press CTRL+C to exit."
        )

    print("Reading 16-bit 16Khz mono audio from stdin...", file=sys.stderr)

    template_index = 0

    try:
        announce(template_index)
        recorder.start()

        while True:
            chunk = sys.stdin.buffer.read(recorder.chunk_size)
            if not chunk:
                # End of input
                break

            if not recorder.process_chunk(chunk):
                continue

            trimmed_audio = trim_silence(recorder.stop())

            template_path = record_dir / name_format.format(n=template_index)
            template_path.parent.mkdir(parents=True, exist_ok=True)

            wav_bytes = buffer_to_wav(trimmed_audio)
            template_path.write_bytes(wav_bytes)
            _LOGGER.debug(
                "Wrote %s byte(s) of WAV audio to %s", len(wav_bytes), template_path
            )

            template_index += 1
            announce(template_index)
            recorder.start()
    except KeyboardInterrupt:
        print("Done")
 def make_webrtcvad():
     """Create a WebRtcVadRecorder that has no maximum recording length."""
     untimed_recorder = WebRtcVadRecorder(max_seconds=None)
     return untimed_recorder
    def __init__(
        self,
        templates: typing.List[Template],
        keyword_name: str = "",
        probability_threshold: float = 0.5,
        minimum_matches: int = 0,
        distance_threshold: float = 0.22,
        template_dtw: typing.Optional[DynamicTimeWarping] = None,
        dtw_window_size: int = 5,
        dtw_step_pattern: float = 2,
        shift_sec: float = DEFAULT_SHIFT_SECONDS,
        refractory_sec: float = 2.0,
        skip_probability_threshold: float = 0.0,
        failed_matches_to_refractory: typing.Optional[int] = None,
        recorder: typing.Optional[WebRtcVadRecorder] = None,
        debug: bool = False,
        benchmark: bool = False,
    ):
        """Prepare a template-matching keyword detector.

        Args:
            templates: Non-empty list of templates; the mean of their
                ``duration_sec`` sets the template window size.
            keyword_name: Keyword name stored on the instance.
            probability_threshold: Stored matching threshold.
            minimum_matches: Stored matching threshold.
            distance_threshold: Stored matching threshold.
            template_dtw: DynamicTimeWarping instance to reuse; a new one is
                created when omitted.
            dtw_window_size: Stored DTW setting.
            dtw_step_pattern: Stored DTW setting.
            shift_sec: Seconds to shift the template window by.
            refractory_sec: Refractory period, also converted to a number of
                VAD chunks.
            skip_probability_threshold: Stored matching threshold.
            failed_matches_to_refractory: Stored as-is.
            recorder: Silence detector to reuse; a default
                ``WebRtcVadRecorder()`` is created when omitted. Its
                ``chunk_size`` and ``sample_rate`` define the audio framing.
            debug: Stored flag (the original notes it enables DTW prediction
                logging).
            benchmark: Stored flag; timing lists are initialized empty.
        """
        self.templates = templates
        assert self.templates, "No templates"

        self.keyword_name = keyword_name

        # Silence detector (caller-supplied or default) defines the framing
        self.recorder = recorder or WebRtcVadRecorder()
        vad = self.recorder
        self.vad_chunk_bytes = vad.chunk_size
        self.sample_rate = vad.sample_rate

        # 16-bit samples are assumed throughout
        self.sample_width = 2
        self.bytes_per_second = int(self.sample_rate * self.sample_width)

        # Thresholds and limits used when deciding on a match
        self.probability_threshold = probability_threshold
        self.minimum_matches = minimum_matches
        self.distance_threshold = distance_threshold
        self.skip_probability_threshold = skip_probability_threshold
        self.refractory_sec = refractory_sec
        self.failed_matches_to_refractory = failed_matches_to_refractory

        # Dynamic time warping: reuse the caller's calculator or create one
        self.dtw = template_dtw or DynamicTimeWarping()
        self.dtw_window_size = dtw_window_size
        self.dtw_step_pattern = dtw_step_pattern

        # Window sizing is based on the mean template duration
        mean_duration_sec = sum(t.duration_sec for t in templates) / len(templates)

        self.template_shift_sec = shift_sec
        self.shifts_per_template = int(math.floor(mean_duration_sec / shift_sec)) - 1

        # Bytes of audio per template window, rounded up to whole samples
        window_bytes = int(math.ceil(mean_duration_sec * self.bytes_per_second))
        window_bytes += -window_bytes % self.sample_width
        self.template_chunk_bytes = window_bytes

        # Audio buffers and template window state
        self.vad_audio_buffer = b""
        self.template_audio_buffer = b""
        self.example_audio_buffer = b""
        self.template_mfcc: typing.Optional[np.ndarray] = None
        self.template_chunks_left = 0
        self.num_template_chunks = int(
            math.ceil((self.template_chunk_bytes / self.vad_chunk_bytes) / 2)
        )

        # Refractory / match state (refractory period measured in VAD chunks)
        self.num_refractory_chunks = int(
            math.ceil(
                self.sample_rate
                * self.sample_width
                * (refractory_sec / self.vad_chunk_bytes)
            )
        )
        self.refractory_chunks_left = 0
        self.failed_matches = 0
        self.match_seconds: typing.Optional[float] = None

        # Debugging: flag plus most recent per-template distance/probability
        self.debug = debug
        self.last_distances: typing.List[typing.Optional[float]] = [None] * len(
            self.templates
        )
        self.last_probabilities: typing.List[typing.Optional[float]] = [
            None
        ] * len(self.templates)

        # Benchmark timings in seconds: whole VAD chunk, single MFCC,
        # template-window match check, and DTW cost
        self.benchmark = benchmark
        self.time_process_vad_chunk: typing.List[float] = []
        self.time_mfcc: typing.List[float] = []
        self.time_match: typing.List[float] = []
        self.time_dtw: typing.List[float] = []
Example #13
0
    def __init__(
        self,
        client,
        raven: Raven,
        minimum_matches: int = 1,
        wakeword_id: str = "",
        site_ids: typing.Optional[typing.List[str]] = None,
        enabled: bool = True,
        sample_rate: int = 16000,
        sample_width: int = 2,
        channels: int = 1,
        chunk_size: int = 960,
        udp_audio: typing.Optional[typing.List[typing.Tuple[str, int, str]]] = None,
        udp_chunk_size: int = 2048,
        log_predictions: bool = False,
    ):
        """Initialize the Raven wake word Hermes client and start its threads.

        Args:
            client: MQTT client, passed to the base class constructor.
            raven: Raven detector instance, stored as ``self.raven``.
            minimum_matches: Stored as-is.
            wakeword_id: Stored as-is.
            site_ids: Site ids passed to the base class constructor.
            enabled: Initial enabled state.
            sample_rate: Required audio sample rate (Hz).
            sample_width: Required sample width in bytes.
            channels: Required channel count.
            chunk_size: Stored as ``self.chunk_size``.
            udp_audio: Optional (host, port, site_id) tuples; one daemon thread
                running ``udp_thread_proc`` is started per tuple.
            udp_chunk_size: Stored for use by the UDP threads.
            log_predictions: Accepted but not referenced in this constructor.
        """
        super().__init__(
            "rhasspywake_raven_hermes",
            client,
            sample_rate=sample_rate,
            sample_width=sample_width,
            channels=channels,
            site_ids=site_ids,
        )

        # Message types this client handles
        self.subscribe(
            AudioFrame,
            HotwordToggleOn,
            HotwordToggleOff,
            GetHotwords,
            RecordHotwordExample,
        )

        self.raven = raven
        self.minimum_matches = minimum_matches
        self.wakeword_id = wakeword_id

        self.enabled = enabled
        # Reasons the detector was disabled (presumably keyed by toggle
        # reason — confirm against the toggle handlers)
        self.disabled_reasons: typing.Set[str] = set()

        # Required audio format
        self.sample_rate = sample_rate
        self.sample_width = sample_width
        self.channels = channels

        self.chunk_size = chunk_size

        # Queue of WAV audio chunks to process (plus site_id)
        self.wav_queue: queue.Queue = queue.Queue()

        self.first_audio: bool = True
        self.audio_buffer = bytes()

        self.last_audio_site_id: str = "default"

        # Fields for recording examples (example recorder capped at 10 seconds)
        self.recording_example = False
        self.example_recorder = WebRtcVadRecorder(max_seconds=10)
        self.example_future: typing.Optional[asyncio.Future] = None

        # Start threads.
        # NOTE(review): the detection thread starts running here, concurrently
        # with the rest of __init__; keep state it may read initialized above.
        self.detection_thread = threading.Thread(
            target=self.detection_thread_proc, daemon=True
        )
        self.detection_thread.start()

        # Listen for raw audio on UDP too (chunk size is set before the UDP
        # threads are started below)
        self.udp_chunk_size = udp_chunk_size

        if udp_audio:
            for udp_host, udp_port, udp_site_id in udp_audio:
                threading.Thread(
                    target=self.udp_thread_proc,
                    args=(udp_host, udp_port, udp_site_id),
                    daemon=True,
                ).start()
 def setUp(self):
     """Give each test its own started WebRtcVadRecorder."""
     self.recorder = recorder = WebRtcVadRecorder()
     recorder.start()
Example #15
0
    def __init__(
        self,
        templates: typing.List[Template],
        probability_threshold: float = 0.5,
        minimum_matches: int = 0,
        distance_threshold: float = 0.22,
        frame_dtw: typing.Optional[DynamicTimeWarping] = None,
        dtw_window_size: int = 5,
        dtw_step_pattern: float = 2,
        sample_rate: int = 16000,
        chunk_size: int = 960,
        shift_sec: float = 0.01,
        before_chunks: int = 0,
        refractory_sec: float = 2.0,
        skip_probability_threshold: float = 0.0,
        recorder: typing.Optional[WebRtcVadRecorder] = None,
        debug: bool = False,
    ):
        """Prepare a frame-based keyword detector over the given templates.

        Args:
            templates: Non-empty list of templates; the mean of their
                ``duration_sec`` sets the frame duration.
            probability_threshold: Stored matching threshold.
            minimum_matches: Stored matching threshold.
            distance_threshold: Stored matching threshold.
            frame_dtw: DynamicTimeWarping instance to reuse; a new one is
                created when omitted.
            dtw_window_size: Stored DTW setting.
            dtw_step_pattern: Stored DTW setting.
            sample_rate: Audio sample rate (Hz).
            chunk_size: Audio chunk size (treated as bytes by the
                ``chunk_seconds`` formula below).
            shift_sec: Seconds to shift each frame by.
            before_chunks: If positive, size of a bounded deque of recent
                chunks; otherwise no such buffer is kept.
            refractory_sec: Refractory period, converted to chunks.
            skip_probability_threshold: Stored matching threshold.
            recorder: Silence detector to reuse; a default
                ``WebRtcVadRecorder()`` is created when omitted.
            debug: Stored debug flag.
        """
        self.templates = templates
        assert self.templates, "No templates"

        # Matching thresholds
        self.probability_threshold = probability_threshold
        self.minimum_matches = minimum_matches
        self.distance_threshold = distance_threshold
        self.skip_probability_threshold = skip_probability_threshold

        # Audio framing
        self.chunk_size = chunk_size
        self.shift_sec = shift_sec
        self.sample_rate = sample_rate

        # Optional bounded history of recent chunks
        self.before_buffer: typing.Optional[typing.Deque[bytes]] = (
            deque(maxlen=before_chunks) if before_chunks > 0 else None
        )

        # 16-bit samples assumed; this formula treats chunk_size as bytes
        self.sample_width = 2
        self.chunk_seconds = (self.chunk_size / self.sample_width) / self.sample_rate

        # Silence detector: caller-supplied or default
        self.recorder = recorder or WebRtcVadRecorder()

        # Dynamic time warping: caller-supplied or new
        self.dtw = frame_dtw or DynamicTimeWarping()
        self.dtw_window_size = dtw_window_size
        self.dtw_step_pattern = dtw_step_pattern

        # Most recent per-template distance/probability, for debugging
        self.last_distances: typing.List[typing.Optional[float]] = [None] * len(
            self.templates
        )
        self.last_probabilities: typing.List[typing.Optional[float]] = [
            None
        ] * len(self.templates)

        # Frame duration is the mean template duration
        self.frame_duration_sec = sum(t.duration_sec for t in templates) / len(templates)

        # Frame size in bytes.
        # NOTE(review): chunk_size already appears to be a byte count (see
        # chunk_seconds), so the extra sample_width factor may double the
        # window — confirm intent before relying on this size.
        self.window_chunk_size = (
            self.seconds_to_chunks(self.frame_duration_sec)
            * self.chunk_size
            * self.sample_width
        )

        # Round up to a whole number of samples
        while self.window_chunk_size % self.sample_width != 0:
            self.window_chunk_size += 1

        # Bytes to shift each frame by; should be smaller than a frame so that
        # consecutive frames overlap
        self.shift_size = (
            self.seconds_to_chunks(self.shift_sec) * self.chunk_size * self.sample_width
        )

        # State machine
        self.audio_buffer = b""
        self.state = RavenState.IN_SILENCE
        self.num_silence_chunks = self.seconds_to_chunks(self.frame_duration_sec)
        self.silence_chunks_left = 0
        self.num_refractory_chunks = self.seconds_to_chunks(refractory_sec)
        self.refractory_chunks_left = 0
        self.match_seconds: typing.Optional[float] = None

        self.debug = debug
    def __init__(
        self,
        client,
        ravens: typing.List[Raven],
        examples_dir: typing.Optional[Path] = None,
        examples_format: str = "{keyword}/examples/%Y%m%d-%H%M%S.wav",
        wakeword_id: str = "",
        site_ids: typing.Optional[typing.List[str]] = None,
        enabled: bool = True,
        sample_rate: int = 16000,
        sample_width: int = 2,
        channels: int = 1,
        chunk_size: int = 1920,
        udp_audio: typing.Optional[typing.List[typing.Tuple[str, int,
                                                            str]]] = None,
        udp_chunk_size: int = 2048,
        log_predictions: bool = False,
        lang: typing.Optional[str] = None,
    ):
        """Initialize the multi-keyword Raven Hermes client and start its threads.

        Args:
            client: MQTT client, passed to the base class constructor.
            ravens: One Raven detector per keyword; each gets its own chunk
                queue and detection thread.
            examples_dir: Optional directory for saved example WAV files.
            examples_format: File name format for examples (strftime-style
                with a ``{keyword}`` placeholder, per the default value).
            wakeword_id: Stored as-is.
            site_ids: Site ids passed to the base class constructor.
            enabled: Initial enabled state.
            sample_rate: Required audio sample rate (Hz).
            sample_width: Required sample width in bytes.
            channels: Required channel count.
            chunk_size: Stored as ``self.chunk_size``.
            udp_audio: Optional (host, port, site_id) tuples; one daemon thread
                running ``udp_thread_proc`` is started per tuple.
            udp_chunk_size: Stored for use by the UDP threads.
            log_predictions: Accepted but not referenced in this constructor.
            lang: Stored as ``self.lang``.
        """
        super().__init__(
            "rhasspywake_raven_hermes",
            client,
            sample_rate=sample_rate,
            sample_width=sample_width,
            channels=channels,
            site_ids=site_ids,
        )

        # Message types this client handles
        self.subscribe(
            AudioFrame,
            HotwordToggleOn,
            HotwordToggleOff,
            GetHotwords,
            RecordHotwordExample,
        )

        self.ravens = ravens
        self.wakeword_id = wakeword_id

        self.examples_dir = examples_dir
        self.examples_format = examples_format

        self.enabled = enabled
        self.disabled_reasons: typing.Set[str] = set()

        # Required audio format
        self.sample_rate = sample_rate
        self.sample_width = sample_width
        self.channels = channels

        self.chunk_size = chunk_size

        # Queue of WAV audio chunks to process (plus site_id)
        self.wav_queue: queue.Queue = queue.Queue()

        self.first_audio: bool = True

        self.last_audio_site_id: str = "default"

        self.lang = lang

        # Fields for recording examples (example recorder capped at 10 seconds)
        self.recording_example = False
        self.example_recorder = WebRtcVadRecorder(max_seconds=10)
        self.example_future: typing.Optional[asyncio.Future] = None

        # Raw audio chunk queues for Raven threads (one per Raven, same order)
        self.chunk_queues: typing.List[queue.Queue] = [
            queue.Queue() for raven in ravens
        ]

        # Start main thread to convert audio from MQTT/UDP.
        # NOTE(review): threads start running here, concurrently with the rest
        # of __init__; keep the state they may read initialized above.
        self.audio_thread = threading.Thread(target=self.audio_thread_proc,
                                             daemon=True)
        self.audio_thread.start()

        # Start a thread per Raven instance (per-keyword), pairing each Raven
        # with its chunk queue by index
        self.detection_threads = [
            threading.Thread(
                target=self.detection_thread_proc,
                args=(self.chunk_queues[i], self.ravens[i]),
                daemon=True,
            ) for i in range(len(self.ravens))
        ]

        for thread in self.detection_threads:
            thread.start()

        # Listen for raw audio on UDP too (chunk size is set before the UDP
        # threads are started below)
        self.udp_chunk_size = udp_chunk_size

        if udp_audio:
            for udp_host, udp_port, udp_site_id in udp_audio:
                threading.Thread(
                    target=self.udp_thread_proc,
                    args=(udp_host, udp_port, udp_site_id),
                    daemon=True,
                ).start()
Example #17
0
def main():
    """Main method.

    Parses command-line arguments, builds one Raven detector per keyword
    template directory (falling back to the built-in templates when none of
    the given directories contain WAV files), then runs the Hermes MQTT wake
    word client until interrupted.
    """
    parser = argparse.ArgumentParser(prog="rhasspy-wake-raven-hermes")
    parser.add_argument(
        "--keyword",
        action="append",
        nargs="+",
        default=[],
        help="Directory with WAV templates and settings (setting-name=value)",
    )
    parser.add_argument(
        "--probability-threshold",
        type=float,
        default=0.5,
        help="Probability above which detection occurs (default: 0.5)",
    )
    parser.add_argument(
        "--distance-threshold",
        type=float,
        default=0.22,
        help=
        "Normalized dynamic time warping distance threshold for template matching (default: 0.22)",
    )
    parser.add_argument(
        "--minimum-matches",
        type=int,
        default=1,
        help=
        "Number of templates that must match to produce output (default: 1)",
    )
    parser.add_argument(
        "--refractory-seconds",
        type=float,
        default=2.0,
        help="Seconds before wake word can be activated again (default: 2)",
    )
    parser.add_argument(
        "--window-shift-seconds",
        type=float,
        default=Raven.DEFAULT_SHIFT_SECONDS,
        help=
        f"Seconds to shift sliding time window on audio buffer (default: {Raven.DEFAULT_SHIFT_SECONDS})",
    )
    parser.add_argument(
        "--dtw-window-size",
        type=int,
        default=5,
        help=
        "Size of band around slanted diagonal during dynamic time warping calculation (default: 5)",
    )
    parser.add_argument(
        "--vad-sensitivity",
        type=int,
        choices=[1, 2, 3],
        default=3,
        help="Webrtcvad VAD sensitivity (1-3)",
    )
    parser.add_argument(
        "--current-threshold",
        type=float,
        help="Debiased energy threshold of current audio frame",
    )
    parser.add_argument(
        "--max-energy",
        type=float,
        help="Fixed maximum energy for ratio calculation (default: observed)",
    )
    parser.add_argument(
        "--max-current-ratio-threshold",
        type=float,
        help="Threshold of ratio between max energy and current audio frame",
    )
    parser.add_argument(
        "--silence-method",
        choices=[e.value for e in SilenceMethod],
        default=SilenceMethod.VAD_ONLY,
        help="Method for detecting silence",
    )
    parser.add_argument(
        "--average-templates",
        action="store_true",
        help=
        "Average wakeword templates together to reduce number of calculations",
    )
    parser.add_argument(
        "--udp-audio",
        nargs=3,
        action="append",
        help="Host/port/siteId for UDP audio input",
    )
    parser.add_argument(
        "--examples-dir",
        help="Save positive example audio to directory as WAV files")
    parser.add_argument(
        "--examples-format",
        default="{keyword}/examples/%Y%m%d-%H%M%S.wav",
        help=
        "Format of positive example WAV file names using strftime (relative to examples-dir)",
    )
    parser.add_argument(
        "--log-predictions",
        action="store_true",
        help="Log prediction probabilities for each audio chunk (very verbose)",
    )
    parser.add_argument("--lang", help="Set lang in hotword detected message")

    hermes_cli.add_hermes_args(parser)
    args = parser.parse_args()

    hermes_cli.setup_logging(args)
    _LOGGER.debug(args)
    hermes: typing.Optional[WakeHermesMqtt] = None

    # -------------------------------------------------------------------------

    if args.examples_dir:
        # Directory to save positive example WAV files
        args.examples_dir = Path(args.examples_dir)
        args.examples_dir.mkdir(parents=True, exist_ok=True)

    if args.keyword:
        missing_keywords = not any(
            list(Path(k[0]).glob("*.wav")) for k in args.keyword)
    else:
        missing_keywords = True

    if missing_keywords:
        args.keyword = [[_DIR / "templates"]]
        _LOGGER.debug(
            "No keywords provided. Use built-in 'okay rhasspy' templates.")

    # Create silence detector.
    # argparse yields a plain string when --silence-method is given but the
    # enum member by default; normalize so the recorder always gets the enum
    # (calling the Enum with a member returns that member unchanged).
    recorder = WebRtcVadRecorder(
        vad_mode=args.vad_sensitivity,
        silence_method=SilenceMethod(args.silence_method),
        current_energy_threshold=args.current_threshold,
        max_energy=args.max_energy,
        max_current_ratio_threshold=args.max_current_ratio_threshold,
    )

    # Load audio templates
    ravens: typing.List[Raven] = []

    for keyword_settings in args.keyword:
        template_dir = Path(keyword_settings[0])
        wav_paths = list(template_dir.glob("*.wav"))
        if not wav_paths:
            _LOGGER.warning("No WAV files found in %s", template_dir)
            continue

        keyword_name = template_dir.name if not missing_keywords else "okay-rhasspy"

        # Load audio templates
        keyword_templates = [
            Raven.wav_to_template(p,
                                  name=str(p),
                                  shift_sec=args.window_shift_seconds)
            for p in wav_paths
        ]

        raven_args = {
            "templates": keyword_templates,
            "keyword_name": keyword_name,
            "recorder": recorder,
            "probability_threshold": args.probability_threshold,
            "minimum_matches": args.minimum_matches,
            "distance_threshold": args.distance_threshold,
            "refractory_sec": args.refractory_seconds,
            "shift_sec": args.window_shift_seconds,
            "debug": args.log_predictions,
        }

        # Apply per-keyword settings given as setting-name=value
        average_templates = args.average_templates
        for setting_str in keyword_settings[1:]:
            setting_name, separator, setting_value = setting_str.strip().partition("=")
            if not separator:
                # Previously crashed with an opaque tuple-unpacking ValueError
                parser.error(
                    f"Invalid keyword setting '{setting_str}' for {template_dir} (expected setting-name=value)"
                )

            setting_name = setting_name.lower().replace("_", "-")

            if setting_name == "name":
                raven_args["keyword_name"] = setting_value
            elif setting_name == "probability-threshold":
                raven_args["probability_threshold"] = float(setting_value)
            elif setting_name == "minimum-matches":
                raven_args["minimum_matches"] = int(setting_value)
            elif setting_name == "average-templates":
                average_templates = setting_value.lower().strip() == "true"
            else:
                _LOGGER.warning("Ignoring unknown setting '%s' for keyword %s",
                                setting_name, template_dir)

        if average_templates:
            _LOGGER.debug("Averaging %s templates for %s",
                          len(keyword_templates), template_dir)
            raven_args["templates"] = [
                Template.average_templates(keyword_templates)
            ]

        # Create instance of Raven in a separate thread for keyword
        ravens.append(Raven(**raven_args))

    udp_audio = []
    if args.udp_audio:
        udp_audio = [(host, int(port), site_id)
                     for host, port, site_id in args.udp_audio]

    # Listen for messages
    client = mqtt.Client()
    hermes = WakeHermesMqtt(
        client,
        ravens=ravens,
        examples_dir=args.examples_dir,
        examples_format=args.examples_format,
        udp_audio=udp_audio,
        site_ids=args.site_id,
        lang=args.lang,
    )

    _LOGGER.debug("Connecting to %s:%s", args.host, args.port)
    hermes_cli.connect(client, args)
    client.loop_start()

    try:
        # Run event loop
        asyncio.run(hermes.handle_messages_async())
    except KeyboardInterrupt:
        pass
    finally:
        _LOGGER.debug("Shutting down")
        client.loop_stop()
        hermes.stop()
class WakeHermesMqtt(HermesClient):
    """Hermes MQTT server for Rhasspy wakeword with Raven.

    Audio arrives as MQTT ``AudioFrame`` messages or as raw WAV datagrams on
    optional UDP sockets. Every chunk is put on ``wav_queue``; a single audio
    thread converts it and fans it out to one chunk queue per ``Raven``
    instance (one instance per keyword), each consumed by its own detection
    thread.
    """
    def __init__(
        self,
        client,
        ravens: typing.List[Raven],
        examples_dir: typing.Optional[Path] = None,
        examples_format: str = "{keyword}/examples/%Y%m%d-%H%M%S.wav",
        wakeword_id: str = "",
        site_ids: typing.Optional[typing.List[str]] = None,
        enabled: bool = True,
        sample_rate: int = 16000,
        sample_width: int = 2,
        channels: int = 1,
        chunk_size: int = 1920,
        udp_audio: typing.Optional[typing.List[typing.Tuple[str, int,
                                                            str]]] = None,
        udp_chunk_size: int = 2048,
        log_predictions: bool = False,
        lang: typing.Optional[str] = None,
    ):
        """Create the service and start its worker threads.

        Args:
            client: MQTT client, passed through to ``HermesClient``.
            ravens: one Raven detector per keyword; each gets its own
                detection thread and chunk queue.
            examples_dir: when set, the audio of every positive detection is
                written as a WAV file below this directory.
            examples_format: relative path of saved examples; expanded with
                ``time.strftime`` and then ``str.format(keyword=...)``.
            wakeword_id: stored on the instance (not read elsewhere in this
                class).
            site_ids: site ids, passed through to ``HermesClient``.
            enabled: whether detection starts enabled.
            sample_rate: required audio sample rate.
            sample_width: required audio sample width in bytes.
            channels: required number of audio channels.
            chunk_size: stored on the instance (not read elsewhere in this
                class).
            udp_audio: ``(host, port, site_id)`` triples to listen on for raw
                WAV datagrams; one thread is started per entry.
            udp_chunk_size: audio payload size read per UDP datagram (the WAV
                header size is added on top).
            log_predictions: accepted for interface compatibility; not used
                by this class.
            lang: language reported in ``HotwordDetected`` messages.
        """
        super().__init__(
            "rhasspywake_raven_hermes",
            client,
            sample_rate=sample_rate,
            sample_width=sample_width,
            channels=channels,
            site_ids=site_ids,
        )

        self.subscribe(
            AudioFrame,
            HotwordToggleOn,
            HotwordToggleOff,
            GetHotwords,
            RecordHotwordExample,
        )

        self.ravens = ravens
        self.wakeword_id = wakeword_id

        self.examples_dir = examples_dir
        self.examples_format = examples_format

        self.enabled = enabled
        # Reasons the hotword is currently disabled; re-enabled only when empty
        self.disabled_reasons: typing.Set[str] = set()

        # Required audio format
        self.sample_rate = sample_rate
        self.sample_width = sample_width
        self.channels = channels

        self.chunk_size = chunk_size

        # Queue of WAV audio chunks to process (plus site_id)
        self.wav_queue: queue.Queue = queue.Queue()

        self.first_audio: bool = True

        self.last_audio_site_id: str = "default"

        self.lang = lang

        # Fields for recording examples
        self.recording_example = False
        self.example_recorder = WebRtcVadRecorder(max_seconds=10)
        self.example_future: typing.Optional[asyncio.Future] = None

        # Raw audio chunk queues for Raven threads (one per Raven instance)
        self.chunk_queues: typing.List[queue.Queue] = [
            queue.Queue() for _ in ravens
        ]

        # Start main thread to convert audio from MQTT/UDP
        self.audio_thread = threading.Thread(target=self.audio_thread_proc,
                                             daemon=True)
        self.audio_thread.start()

        # Start a thread per Raven instance (per-keyword)
        self.detection_threads = [
            threading.Thread(
                target=self.detection_thread_proc,
                args=(chunk_queue, raven),
                daemon=True,
            ) for chunk_queue, raven in zip(self.chunk_queues, self.ravens)
        ]

        for thread in self.detection_threads:
            thread.start()

        # Listen for raw audio on UDP too
        self.udp_chunk_size = udp_chunk_size

        if udp_audio:
            for udp_host, udp_port, udp_site_id in udp_audio:
                threading.Thread(
                    target=self.udp_thread_proc,
                    args=(udp_host, udp_port, udp_site_id),
                    daemon=True,
                ).start()

    # -------------------------------------------------------------------------

    async def handle_audio_frame(self,
                                 wav_bytes: bytes,
                                 site_id: str = "default") -> None:
        """Queue a single WAV audio frame for the audio thread."""
        self.wav_queue.put((wav_bytes, site_id))

    async def handle_detection(
        self, matching_indexes: typing.List[int], raven: Raven
    ) -> typing.AsyncIterable[typing.Union[typing.Tuple[
            HotwordDetected, TopicArgs], HotwordError]]:
        """Handle a successful hotword detection.

        The first matching template determines the reported ``model_id``.
        The wakeword id in the topic is the Raven keyword name, else the
        template name, else ``"default"``.

        Yields:
            ``(HotwordDetected, topic args)`` on success, or ``HotwordError``
            if building the message fails.
        """
        # Bound before the try so the error path can always reference it,
        # even when the template lookup itself is what failed.
        template_name: typing.Optional[str] = None
        try:
            template = raven.templates[matching_indexes[0]]
            template_name = template.name

            wakeword_id = raven.keyword_name or template.name
            if not wakeword_id:
                wakeword_id = "default"

            yield (
                HotwordDetected(
                    site_id=self.last_audio_site_id,
                    model_id=template.name,
                    current_sensitivity=raven.probability_threshold,
                    model_version="",
                    model_type="personal",
                    lang=self.lang,
                ),
                {
                    "wakeword_id": wakeword_id
                },
            )
        except Exception as e:
            _LOGGER.exception("handle_detection")
            yield HotwordError(
                error=str(e),
                context=f"{raven.keyword_name}: {template_name}",
                site_id=self.last_audio_site_id,
            )

    async def handle_get_hotwords(
        self, get_hotwords: GetHotwords
    ) -> typing.AsyncIterable[typing.Union[Hotwords, HotwordError]]:
        """Report available hotwords (one per Raven instance)."""
        try:
            models: typing.List[Hotword] = []

            # Each keyword is in a separate Raven instance
            for raven in self.ravens:
                # Assume that the directory name is something like
                # "okay-rhasspy" for the keyword "okay rhasspy".
                models.append(
                    Hotword(
                        model_id=raven.keyword_name,
                        model_words=re.sub(r"[_-]+", " ", raven.keyword_name),
                    ))

            yield Hotwords(models=models,
                           id=get_hotwords.id,
                           site_id=get_hotwords.site_id)

        except Exception as e:
            _LOGGER.exception("handle_get_hotwords")
            yield HotwordError(error=str(e),
                               context=str(get_hotwords),
                               site_id=get_hotwords.site_id)

    async def handle_record_example(
        self, record_example: RecordHotwordExample
    ) -> typing.AsyncIterable[typing.Union[typing.Tuple[
            HotwordExampleRecorded, TopicArgs], HotwordError]]:
        """Record an example of a hotword.

        Starts the VAD recorder and waits until ``add_example_audio`` (audio
        thread) delivers the finished recording. The audio is then trimmed of
        silence and published as WAV. A new request while another is still
        recording cancels the earlier one, and the earlier requester receives
        a ``HotwordError`` instead of waiting forever.
        """
        try:
            if self.recording_example:
                _LOGGER.warning("Cancelling previous recording")
                self.example_recorder.stop()

                # Fail the previous request's future. Otherwise its coroutine
                # would await a future that is replaced below and never
                # resolved by anyone.
                previous_future = self.example_future
                if previous_future is not None and not previous_future.done():
                    previous_future.set_exception(
                        RuntimeError("Recording cancelled by a newer request"))

            # Start recording
            assert self.loop, "No loop"
            self.example_future = self.loop.create_future()
            self.example_recorder.start()
            self.recording_example = True

            # Wait for result
            _LOGGER.debug("Recording example (id=%s)", record_example.id)
            example_audio = await self.example_future
            assert isinstance(example_audio, bytes)

            # Trim silence
            _LOGGER.debug("Trimming silence from example")
            example_audio = trim_silence(example_audio)

            # Convert to WAV format
            wav_data = self.to_wav_bytes(example_audio)

            yield (
                HotwordExampleRecorded(wav_bytes=wav_data),
                {
                    "site_id": record_example.site_id,
                    "request_id": record_example.id
                },
            )

        except Exception as e:
            _LOGGER.exception("handle_record_example")
            yield HotwordError(
                error=str(e),
                context=str(record_example),
                site_id=record_example.site_id,
            )

    def add_example_audio(self, audio_data: bytes):
        """Add an audio frame to the currently recording example.

        Called from the audio thread. When the recorder reports the recording
        is complete, stop it and hand the audio to the coroutine waiting on
        the event loop.
        """
        result = self.example_recorder.process_chunk(audio_data)
        if result:
            self.recording_example = False
            future = self.example_future
            assert future is not None, "No future"
            example_audio = self.example_recorder.stop()
            _LOGGER.debug("Recorded %s byte(s) for audio for example",
                          len(example_audio))

            # Signal waiting coroutine with audio. The callback runs later on
            # the loop thread, by which time the future may already be done
            # (e.g. failed by a newer record request), so it is guarded.
            assert self.loop, "No loop"
            self.loop.call_soon_threadsafe(self._resolve_future, future,
                                           example_audio)

    @staticmethod
    def _resolve_future(future: asyncio.Future, result: typing.Any) -> None:
        """Set *result* on *future* unless it is already done or cancelled."""
        if not future.done():
            future.set_result(result)

    # -------------------------------------------------------------------------

    def audio_thread_proc(self):
        """Convert queued WAV chunks and fan them out to detection threads.

        A ``None`` WAV payload is the shutdown signal. It is forwarded to
        every detection thread, and this thread waits for them to exit.
        While an example is being recorded, audio goes only to the example
        recorder and not to wake word detection.
        """
        try:
            while True:
                wav_bytes, site_id = self.wav_queue.get()
                if wav_bytes is None:
                    # Shutdown signal
                    for chunk_queue in self.chunk_queues:
                        chunk_queue.put(None)

                    # Wait for detection threads to exit
                    for thread in self.detection_threads:
                        thread.join()

                    break

                self.last_audio_site_id = site_id

                # Handle audio frames
                if self.first_audio:
                    _LOGGER.debug("Receiving audio")
                    self.first_audio = False

                # Extract/convert audio data
                audio_data = self.maybe_convert_wav(wav_bytes)

                if self.recording_example:
                    # Add to currently recording example
                    self.add_example_audio(audio_data)

                    # Don't process audio for wake word while recording
                    continue

                # Add to queues for detection threads
                for chunk_queue in self.chunk_queues:
                    chunk_queue.put(audio_data)
        except Exception:
            _LOGGER.exception("audio_thread_proc")

    def detection_thread_proc(self, chunk_queue: queue.Queue, raven: Raven):
        """Run Raven detection on audio chunks until a ``None`` arrives.

        A detection is published when at least ``raven.minimum_matches``
        templates match. When ``examples_dir`` is set, the audio kept by Raven
        for that detection is also saved as a WAV file.
        """
        try:
            _LOGGER.debug(
                "Listening for keyword %s (probability_threshold=%s, minimum_matches=%s, num_templates=%s)",
                raven.keyword_name,
                raven.probability_threshold,
                raven.minimum_matches,
                len(raven.templates),
            )

            while True:
                audio_data = chunk_queue.get()
                if audio_data is None:
                    # Shutdown signal
                    break

                if audio_data:
                    try:
                        keep_audio = bool(self.examples_dir)
                        matching_indexes = raven.process_chunk(
                            audio_data, keep_audio=keep_audio)
                        if len(matching_indexes) >= raven.minimum_matches:
                            # Report detection on the event loop
                            assert self.loop is not None, "No loop"
                            asyncio.run_coroutine_threadsafe(
                                self.publish_all(
                                    self.handle_detection(
                                        matching_indexes, raven)),
                                self.loop,
                            )

                            if keep_audio:
                                # Save positive example
                                assert self.examples_dir is not None
                                example_path = self.examples_dir / time.strftime(
                                    self.examples_format).format(
                                        keyword=raven.keyword_name)

                                example_path.parent.mkdir(parents=True,
                                                          exist_ok=True)

                                with open(example_path, "wb") as example_file:
                                    example_wav_bytes = self.to_wav_bytes(
                                        raven.example_audio_buffer)
                                    example_file.write(example_wav_bytes)

                                _LOGGER.debug("Wrote example to %s",
                                              example_path)
                    except Exception:
                        _LOGGER.exception("process_chunk")
        except Exception:
            _LOGGER.exception("detection_thread_proc")

    # -------------------------------------------------------------------------

    def stop(self):
        """Stop audio and detection threads and wait for them to exit."""
        self.wav_queue.put((None, ""))

        _LOGGER.debug("Waiting for detection threads to stop...")
        self.audio_thread.join()

    # -------------------------------------------------------------------------

    def udp_thread_proc(self, host: str, port: int, site_id: str):
        """Receive WAV chunks from a UDP socket and queue them while enabled.

        The socket is closed if this thread stops because of an error.
        """
        try:
            with socket.socket(socket.AF_INET,
                               socket.SOCK_DGRAM) as udp_socket:
                udp_socket.bind((host, port))
                _LOGGER.debug("Listening for audio on UDP %s:%s", host, port)

                while True:
                    wav_bytes, _ = udp_socket.recvfrom(self.udp_chunk_size +
                                                       WAV_HEADER_BYTES)

                    if self.enabled:
                        self.wav_queue.put((wav_bytes, site_id))
        except Exception:
            _LOGGER.exception("udp_thread_proc")

    # -------------------------------------------------------------------------

    async def on_message_blocking(
        self,
        message: Message,
        site_id: typing.Optional[str] = None,
        session_id: typing.Optional[str] = None,
        topic: typing.Optional[str] = None,
    ) -> GeneratorType:
        """Received message from MQTT broker.

        ``HotwordToggleOn`` removes its reason (all reasons if UNKNOWN) and
        re-enables only once no reasons remain. ``HotwordToggleOff`` disables
        and records its reason. Audio frames are queued only while enabled.
        ``GetHotwords`` results are yielded for publishing.
        ``RecordHotwordExample`` is handled by ``on_message``.
        """
        # Check enable/disable messages
        if isinstance(message, HotwordToggleOn):
            if message.reason == HotwordToggleReason.UNKNOWN:
                # Always enable on unknown
                self.disabled_reasons.clear()
            else:
                self.disabled_reasons.discard(message.reason)

            if self.disabled_reasons:
                _LOGGER.debug("Still disabled: %s", self.disabled_reasons)
            else:
                self.enabled = True
                self.first_audio = True
                _LOGGER.debug("Enabled")
        elif isinstance(message, HotwordToggleOff):
            self.enabled = False
            self.disabled_reasons.add(message.reason)
            _LOGGER.debug("Disabled")
        elif isinstance(message, AudioFrame):
            if self.enabled:
                assert site_id, "Missing site_id"
                await self.handle_audio_frame(message.wav_bytes,
                                              site_id=site_id)
        elif isinstance(message, GetHotwords):
            async for hotword_result in self.handle_get_hotwords(message):
                yield hotword_result
        elif isinstance(message, RecordHotwordExample):
            # Handled in on_message
            pass
        else:
            _LOGGER.warning("Unexpected message: %s", message)

    async def on_message(
        self,
        message: Message,
        site_id: typing.Optional[str] = None,
        session_id: typing.Optional[str] = None,
        topic: typing.Optional[str] = None,
    ) -> GeneratorType:
        """Received message from MQTT broker (non-blocking).

        Only ``RecordHotwordExample`` is handled here, because recording
        waits for audio that arrives through other messages.
        """
        if isinstance(message, RecordHotwordExample):
            async for example_result in self.handle_record_example(message):
                yield example_result
Exemple #19
0
def main():
    """Main entry point.

    Parses command-line arguments. With ``--record``, records example
    templates instead of detecting. Otherwise loads WAV templates, reads
    16-bit 16Khz raw audio from stdin, and feeds it chunk by chunk to a
    background detection thread.

    Returns:
        The result of ``record_templates`` when ``--record`` is given,
        otherwise ``None``.
    """
    parser = argparse.ArgumentParser(prog="rhasspy-wake-raven")
    parser.add_argument("templates",
                        nargs="+",
                        help="Path to WAV file templates or directories")
    parser.add_argument(
        "--record",
        help=
        "Record example templates with given name format (e.g., 'okay-rhasspy-{n:02d}.wav')",
    )
    parser.add_argument(
        "--probability-threshold",
        type=float,
        default=0.5,
        help="Probability above which detection occurs (default: 0.5)",
    )
    parser.add_argument(
        "--distance-threshold",
        type=float,
        default=0.22,
        help=
        "Normalized dynamic time warping distance threshold for template matching (default: 0.22)",
    )
    parser.add_argument(
        "--minimum-matches",
        type=int,
        default=1,
        help=
        "Number of templates that must match to produce output (default: 1)",
    )
    parser.add_argument(
        "--refractory-seconds",
        type=float,
        default=2.0,
        help="Seconds before wake word can be activated again (default: 2)",
    )
    parser.add_argument(
        "--print-all-matches",
        action="store_true",
        help=
        "Print JSON for all matching templates instead of just the first one",
    )
    parser.add_argument(
        "--window-shift-seconds",
        type=float,
        default=0.01,
        help=
        "Seconds to shift sliding time window on audio buffer (default: 0.01)",
    )
    parser.add_argument(
        "--dtw-window-size",
        type=int,
        default=5,
        help=
        "Size of band around slanted diagonal during dynamic time warping calculation (default: 5)",
    )
    parser.add_argument(
        "--vad-sensitivity",
        type=int,
        choices=[1, 2, 3],
        default=3,
        help="Webrtcvad VAD sensitivity (1-3)",
    )
    parser.add_argument(
        "--current-threshold",
        type=float,
        help="Debiased energy threshold of current audio frame",
    )
    parser.add_argument(
        "--max-energy",
        type=float,
        help="Fixed maximum energy for ratio calculation (default: observed)",
    )
    parser.add_argument(
        "--max-current-ratio-threshold",
        type=float,
        help="Threshold of ratio between max energy and current audio frame",
    )
    parser.add_argument(
        "--silence-method",
        choices=[e.value for e in SilenceMethod],
        default=SilenceMethod.VAD_ONLY,
        help="Method for detecting silence",
    )
    parser.add_argument(
        "--average-templates",
        action="store_true",
        help=
        "Average wakeword templates together to reduce number of calculations",
    )
    parser.add_argument(
        "--exit-count",
        type=int,
        help="Exit after some number of detections (default: never)",
    )
    parser.add_argument(
        "--read-entire-input",
        action="store_true",
        help="Read entire audio input at start and exit after processing",
    )
    parser.add_argument(
        "--max-chunks-in-queue",
        type=int,
        help=
        "Maximum number of audio chunks waiting for processing before being dropped",
    )
    parser.add_argument(
        "--skip-probability-threshold",
        type=float,
        default=0,
        help=
        "Skip additional template calculations if probability is below this threshold",
    )
    parser.add_argument("--debug",
                        action="store_true",
                        help="Print DEBUG messages to the console")
    args = parser.parse_args()

    logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)

    # Create silence detector
    recorder = WebRtcVadRecorder(
        vad_mode=args.vad_sensitivity,
        silence_method=args.silence_method,
        current_energy_threshold=args.current_threshold,
        max_energy=args.max_energy,
        max_current_ratio_threshold=args.max_current_ratio_threshold,
        min_seconds=0.5,
        before_seconds=1,
    )

    if args.record:
        # Do recording instead of recognizing
        return record_templates(args.record, recorder, args)

    # Load audio templates (directories contribute all of their *.wav files)
    template_paths: typing.List[Path] = []
    for template_path_str in args.templates:
        template_path = Path(template_path_str)
        if template_path.is_dir():
            # Add all WAV files from directory
            _LOGGER.debug("Adding WAV files from directory %s", template_path)
            template_paths.extend(template_path.glob("*.wav"))
        elif template_path.is_file():
            # Add file directly
            template_paths.append(template_path)

    templates = [Raven.wav_to_template(p, name=str(p)) for p in template_paths]
    if args.average_templates:
        _LOGGER.debug("Averaging %s templates", len(templates))
        templates = [Template.average_templates(templates)]

    # Create Raven object
    # NOTE(review): --dtw-window-size is parsed but never passed to Raven;
    # confirm whether Raven accepts it before wiring it through.
    raven = Raven(
        templates=templates,
        recorder=recorder,
        probability_threshold=args.probability_threshold,
        minimum_matches=args.minimum_matches,
        distance_threshold=args.distance_threshold,
        refractory_sec=args.refractory_seconds,
        shift_sec=args.window_shift_seconds,
        skip_probability_threshold=args.skip_probability_threshold,
        debug=args.debug,
    )

    print("Reading 16-bit 16Khz raw audio from stdin...", file=sys.stderr)

    if args.read_entire_input:
        audio_buffer = FakeStdin(sys.stdin.buffer.read())
    else:
        audio_buffer = sys.stdin.buffer

    chunk_queue = Queue()
    detect_thread = threading.Thread(target=detect_thread_proc,
                                     args=(chunk_queue, raven, args),
                                     daemon=True)

    detect_thread.start()

    chunk_size = raven.chunk_size
    reached_eof = False
    try:
        while True:
            # Read raw audio chunk
            chunk = audio_buffer.read(chunk_size)
            if not chunk:
                # End of input
                break

            # Top up short reads until the chunk is full. The EOF test must
            # look at the newly read bytes: ``chunk`` is already non-empty
            # here, so testing it would never break and would spin forever.
            while len(chunk) < chunk_size:
                more = audio_buffer.read(chunk_size - len(chunk))
                if not more:
                    break
                chunk += more

            chunk_queue.put(chunk)

        reached_eof = True
    except KeyboardInterrupt:
        pass
    finally:
        if not reached_eof:
            # Interrupted or failed: discard unprocessed audio so the
            # detection thread can exit promptly. On a normal end of input,
            # queued chunks are kept so every chunk gets processed.
            while not chunk_queue.empty():
                chunk_queue.get()

        # Signal thread to quit (after any remaining chunks)
        chunk_queue.put(None)
        detect_thread.join()
Exemple #20
0
def main():
    """Main method.

    Parses command-line arguments and loads Raven WAV templates. Templates
    come from ``--template-dir`` if it contains any, otherwise from the
    bundled ``templates`` directory. Then runs the Hermes MQTT wake word
    service until interrupted.
    """
    parser = argparse.ArgumentParser(prog="rhasspy-wake-raven-hermes")
    parser.add_argument(
        "--template-dir",
        help=
        "Directory with Raven WAV templates (default: templates in Python module)",
    )
    parser.add_argument(
        "--probability-threshold",
        type=float,
        default=0.5,
        help="Probability above which detection occurs (default: 0.5)",
    )
    parser.add_argument(
        "--distance-threshold",
        type=float,
        default=0.22,
        help=
        "Normalized dynamic time warping distance threshold for template matching (default: 0.22)",
    )
    parser.add_argument(
        "--minimum-matches",
        type=int,
        default=1,
        help=
        "Number of templates that must match to produce output (default: 1)",
    )
    parser.add_argument(
        "--refractory-seconds",
        type=float,
        default=2.0,
        help="Seconds before wake word can be activated again (default: 2)",
    )
    parser.add_argument(
        "--window-shift-seconds",
        type=float,
        default=0.05,
        help=
        "Seconds to shift sliding time window on audio buffer (default: 0.05)",
    )
    parser.add_argument(
        "--dtw-window-size",
        type=int,
        default=5,
        help=
        "Size of band around slanted diagonal during dynamic time warping calculation (default: 5)",
    )
    parser.add_argument(
        "--vad-sensitivity",
        type=int,
        choices=[1, 2, 3],
        default=3,
        help="Webrtcvad VAD sensitivity (1-3)",
    )
    parser.add_argument(
        "--current-threshold",
        type=float,
        help="Debiased energy threshold of current audio frame",
    )
    parser.add_argument(
        "--max-energy",
        type=float,
        help="Fixed maximum energy for ratio calculation (default: observed)",
    )
    parser.add_argument(
        "--max-current-ratio-threshold",
        type=float,
        help="Threshold of ratio between max energy and current audio frame",
    )
    parser.add_argument(
        "--silence-method",
        choices=[e.value for e in SilenceMethod],
        default=SilenceMethod.VAD_ONLY,
        help="Method for detecting silence",
    )
    parser.add_argument(
        "--average-templates",
        action="store_true",
        help=
        "Average wakeword templates together to reduce number of calculations",
    )
    parser.add_argument(
        "--wakeword-id",
        default="",
        help="Wakeword ID for model (default: use file name)",
    )
    parser.add_argument(
        "--udp-audio",
        nargs=3,
        action="append",
        help="Host/port/siteId for UDP audio input",
    )
    parser.add_argument(
        "--log-predictions",
        action="store_true",
        help="Log prediction probabilities for each audio chunk (very verbose)",
    )

    hermes_cli.add_hermes_args(parser)
    args = parser.parse_args()

    hermes_cli.setup_logging(args)
    _LOGGER.debug(args)

    wav_paths: typing.List[Path] = []
    if args.template_dir:
        args.template_dir = Path(args.template_dir)

        if args.template_dir.is_dir():
            _LOGGER.debug("Loading WAV templates from %s", args.template_dir)
            wav_paths = list(args.template_dir.glob("*.wav"))

            if not wav_paths:
                _LOGGER.warning("No WAV templates found!")
        else:
            # Without this warning a mistyped --template-dir would silently
            # fall back to the bundled templates below.
            _LOGGER.warning(
                "Template directory %s is not a directory; using default templates",
                args.template_dir,
            )

    if not wav_paths:
        args.template_dir = _DIR / "templates"
        _LOGGER.debug("Loading WAV templates from %s", args.template_dir)
        wav_paths = list(args.template_dir.glob("*.wav"))

    # Create silence detector
    recorder = WebRtcVadRecorder(
        vad_mode=args.vad_sensitivity,
        silence_method=args.silence_method,
        current_energy_threshold=args.current_threshold,
        max_energy=args.max_energy,
        max_current_ratio_threshold=args.max_current_ratio_threshold,
    )

    # Load audio templates
    templates = [Raven.wav_to_template(p, name=p.name) for p in wav_paths]
    if args.average_templates:
        _LOGGER.debug("Averaging %s templates", len(templates))
        templates = [Template.average_templates(templates)]

    # NOTE(review): --dtw-window-size is parsed but never passed to Raven;
    # confirm whether Raven accepts it before wiring it through.
    raven = Raven(
        templates=templates,
        recorder=recorder,
        probability_threshold=args.probability_threshold,
        minimum_matches=args.minimum_matches,
        distance_threshold=args.distance_threshold,
        refractory_sec=args.refractory_seconds,
        shift_sec=args.window_shift_seconds,
        debug=args.log_predictions,
    )

    udp_audio = []
    if args.udp_audio:
        udp_audio = [(host, int(port), site_id)
                     for host, port, site_id in args.udp_audio]

    # Listen for messages
    client = mqtt.Client()
    hermes = WakeHermesMqtt(
        client,
        raven=raven,
        minimum_matches=args.minimum_matches,
        wakeword_id=args.wakeword_id,
        udp_audio=udp_audio,
        site_ids=args.site_id,
    )

    _LOGGER.debug("Connecting to %s:%s", args.host, args.port)
    hermes_cli.connect(client, args)
    client.loop_start()

    try:
        # Run event loop
        asyncio.run(hermes.handle_messages_async())
    except KeyboardInterrupt:
        pass
    finally:
        _LOGGER.debug("Shutting down")
        client.loop_stop()
Exemple #21
0
class WakeHermesMqtt(HermesClient):
    """Hermes MQTT server for Rhasspy wakeword with Raven.

    Audio arrives either as Hermes ``AudioFrame`` messages or as raw WAV
    chunks on optional UDP sockets. Both sources push ``(wav_bytes, site_id)``
    onto ``wav_queue``; a single background detection thread consumes the
    queue, feeds fixed-size chunks to ``raven.process_chunk`` and publishes a
    ``HotwordDetected`` message when enough templates match. While a
    ``RecordHotwordExample`` request is active, audio goes to
    ``example_recorder`` instead of wake word detection.
    """

    def __init__(
        self,
        client,
        raven: Raven,
        minimum_matches: int = 1,
        wakeword_id: str = "",
        site_ids: typing.Optional[typing.List[str]] = None,
        enabled: bool = True,
        sample_rate: int = 16000,
        sample_width: int = 2,
        channels: int = 1,
        chunk_size: int = 960,
        udp_audio: typing.Optional[typing.List[typing.Tuple[str, int, str]]] = None,
        udp_chunk_size: int = 2048,
        log_predictions: bool = False,
    ):
        """Create the service and start its background threads.

        Args:
            client: MQTT client, passed through to ``HermesClient``.
            raven: detector whose ``process_chunk`` returns matching template
                indexes.
            minimum_matches: number of matching templates required before a
                detection is published.
            wakeword_id: wakeword id reported on detection; when empty, the
                matching template's name is used instead.
            site_ids: site ids passed through to ``HermesClient``.
            enabled: initial enabled state.
            sample_rate: required audio sample rate (Hz).
            sample_width: required sample width (bytes).
            channels: required channel count.
            chunk_size: number of bytes of converted audio handed to
                ``raven.process_chunk`` at a time.
            udp_audio: optional ``(host, port, site_id)`` triples; one UDP
                listener thread is started per entry.
            udp_chunk_size: payload size of each UDP read (a WAV header's worth
                of bytes is added on top).
            log_predictions: accepted for interface compatibility; not used by
                this class.
        """
        super().__init__(
            "rhasspywake_raven_hermes",
            client,
            sample_rate=sample_rate,
            sample_width=sample_width,
            channels=channels,
            site_ids=site_ids,
        )

        self.subscribe(
            AudioFrame,
            HotwordToggleOn,
            HotwordToggleOff,
            GetHotwords,
            RecordHotwordExample,
        )

        self.raven = raven
        self.minimum_matches = minimum_matches
        self.wakeword_id = wakeword_id

        self.enabled = enabled
        # Reasons from HotwordToggleOff messages that are still in effect
        self.disabled_reasons: typing.Set[str] = set()

        # Required audio format
        self.sample_rate = sample_rate
        self.sample_width = sample_width
        self.channels = channels

        self.chunk_size = chunk_size

        # Queue of WAV audio chunks to process (plus site_id)
        self.wav_queue: queue.Queue = queue.Queue()

        self.first_audio: bool = True
        # Converted audio not yet handed to raven (less than chunk_size bytes)
        self.audio_buffer = bytes()

        self.last_audio_site_id: str = "default"

        # Fields for recording examples
        self.recording_example = False
        self.example_recorder = WebRtcVadRecorder(max_seconds=10)
        self.example_future: typing.Optional[asyncio.Future] = None

        # Start threads
        self.detection_thread = threading.Thread(
            target=self.detection_thread_proc, daemon=True
        )
        self.detection_thread.start()

        # Listen for raw audio on UDP too
        self.udp_chunk_size = udp_chunk_size

        if udp_audio:
            for udp_host, udp_port, udp_site_id in udp_audio:
                threading.Thread(
                    target=self.udp_thread_proc,
                    args=(udp_host, udp_port, udp_site_id),
                    daemon=True,
                ).start()

    # -------------------------------------------------------------------------

    async def handle_audio_frame(
        self, wav_bytes: bytes, site_id: str = "default"
    ) -> None:
        """Queue a single WAV audio frame for the detection thread."""
        self.wav_queue.put((wav_bytes, site_id))

    async def handle_detection(
        self, matching_indexes: typing.List[int]
    ) -> typing.AsyncIterable[
        typing.Union[typing.Tuple[HotwordDetected, TopicArgs], HotwordError]
    ]:
        """Handle a successful hotword detection.

        Reports the first matching template. Yields ``HotwordDetected`` with
        the wakeword id as a topic argument, or ``HotwordError`` if building
        the message fails (e.g. ``matching_indexes`` is empty).
        """
        try:
            template = self.raven.templates[matching_indexes[0]]
            wakeword_id = self.wakeword_id
            if not wakeword_id:
                wakeword_id = template.name

            yield (
                HotwordDetected(
                    site_id=self.last_audio_site_id,
                    model_id=template.name,
                    current_sensitivity=self.raven.distance_threshold,
                    model_version="",
                    model_type="personal",
                ),
                {"wakeword_id": wakeword_id},
            )
        except Exception as e:
            _LOGGER.exception("handle_detection")
            yield HotwordError(
                error=str(e),
                context=str(matching_indexes),
                site_id=self.last_audio_site_id,
            )

    async def handle_get_hotwords(
        self, get_hotwords: GetHotwords
    ) -> typing.AsyncIterable[typing.Union[Hotwords, HotwordError]]:
        """Report available hotwords (always an empty model list)."""
        try:
            yield Hotwords(models=[], id=get_hotwords.id, site_id=get_hotwords.site_id)

        except Exception as e:
            _LOGGER.exception("handle_get_hotwords")
            yield HotwordError(
                error=str(e), context=str(get_hotwords), site_id=get_hotwords.site_id
            )

    async def handle_record_example(
        self, record_example: RecordHotwordExample
    ) -> typing.AsyncIterable[
        typing.Union[typing.Tuple[HotwordExampleRecorded, TopicArgs], HotwordError]
    ]:
        """Record an example of a hotword.

        Starts ``example_recorder`` and waits until the detection thread
        resolves ``example_future`` with the recorded audio. Then it trims
        silence and yields the audio as WAV. Any failure is reported as
        ``HotwordError``.
        """
        try:
            # Explicit check rather than ``assert``: asserts are stripped
            # under ``python -O``, which would let a second request replace
            # example_future and leave the first caller waiting forever.
            if self.recording_example:
                raise RuntimeError("Only one example can be recorded at a time")

            # Start recording
            if not self.loop:
                raise RuntimeError("No loop")

            self.example_future = self.loop.create_future()
            self.example_recorder.start()
            self.recording_example = True

            # Wait for result
            _LOGGER.debug("Recording example (id=%s)", record_example.id)
            example_audio = await self.example_future
            assert isinstance(example_audio, bytes)

            # Trim silence
            _LOGGER.debug("Trimming silence from example")
            example_audio = trim_silence(example_audio)

            # Convert to WAV format
            wav_data = self.to_wav_bytes(example_audio)

            yield (
                HotwordExampleRecorded(wav_bytes=wav_data),
                {"site_id": record_example.site_id, "request_id": record_example.id},
            )

        except Exception as e:
            _LOGGER.exception("handle_record_example")
            yield HotwordError(
                error=str(e),
                context=str(record_example),
                site_id=record_example.site_id,
            )

    def add_example_audio(self, audio_data: bytes):
        """Add an audio frame to the currently recording example.

        Runs on the detection thread. When the recorder reports completion,
        recording stops and the waiting coroutine's future is resolved on the
        event loop via ``call_soon_threadsafe``.
        """
        result = self.example_recorder.process_chunk(audio_data)
        if result:
            self.recording_example = False
            assert self.example_future is not None, "No future"
            example_audio = self.example_recorder.stop()
            _LOGGER.debug(
                "Recorded %s byte(s) for audio for example", len(example_audio)
            )

            # Signal waiting coroutine with audio
            assert self.loop, "No loop"
            self.loop.call_soon_threadsafe(
                self.example_future.set_result, example_audio
            )

    # -------------------------------------------------------------------------

    def detection_thread_proc(self):
        """Consume WAV audio chunks from ``wav_queue`` until shutdown.

        A queue item whose ``wav_bytes`` is ``None`` stops the thread. Errors
        while handling a single frame are logged and that frame is dropped,
        so one malformed frame (e.g. a bad UDP packet) does not permanently
        stop wake word detection.
        """
        try:
            while True:
                wav_bytes, site_id = self.wav_queue.get()
                if wav_bytes is None:
                    # Shutdown signal
                    break

                try:
                    self.last_audio_site_id = site_id

                    # Handle audio frames
                    if self.first_audio:
                        _LOGGER.debug("Receiving audio")
                        self.first_audio = False

                    # Extract/convert audio data
                    audio_data = self.maybe_convert_wav(wav_bytes)

                    if self.recording_example:
                        # Add to currently recording example
                        self.add_example_audio(audio_data)

                        # Don't process audio for wake word while recording
                        self.audio_buffer = bytes()
                        continue

                    # Add to persistent buffer
                    self.audio_buffer += audio_data

                    # Process in chunks.
                    # Any remaining audio data will be kept in buffer.
                    while len(self.audio_buffer) >= self.chunk_size:
                        chunk = self.audio_buffer[: self.chunk_size]
                        self.audio_buffer = self.audio_buffer[self.chunk_size :]

                        if chunk:
                            try:
                                matching_indexes = self.raven.process_chunk(chunk)
                                if len(matching_indexes) >= self.minimum_matches:
                                    # Publish from the event loop, not this thread
                                    asyncio.run_coroutine_threadsafe(
                                        self.publish_all(
                                            self.handle_detection(matching_indexes)
                                        ),
                                        self.loop,
                                    )
                            except Exception:
                                _LOGGER.exception("process_chunk")
                except Exception:
                    # Drop this frame but keep the detection thread alive
                    _LOGGER.exception("detection_thread_proc")
        except Exception:
            _LOGGER.exception("detection_thread_proc")

    # -------------------------------------------------------------------------

    def udp_thread_proc(self, host: str, port: int, site_id: str):
        """Receive WAV chunks on a UDP socket and queue them for detection.

        Packets are queued only while the service is enabled. The socket is
        closed when the loop exits because of an error.
        """
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as udp_socket:
                udp_socket.bind((host, port))
                _LOGGER.debug("Listening for audio on UDP %s:%s", host, port)

                while True:
                    wav_bytes, _ = udp_socket.recvfrom(
                        self.udp_chunk_size + WAV_HEADER_BYTES
                    )

                    if self.enabled:
                        self.wav_queue.put((wav_bytes, site_id))
        except Exception:
            _LOGGER.exception("udp_thread_proc")

    # -------------------------------------------------------------------------

    async def on_message_blocking(
        self,
        message: Message,
        site_id: typing.Optional[str] = None,
        session_id: typing.Optional[str] = None,
        topic: typing.Optional[str] = None,
    ) -> GeneratorType:
        """Received message from MQTT broker.

        Handles enable/disable toggles, audio frames and hotword listing.
        ``RecordHotwordExample`` is handled in ``on_message`` instead.
        """
        # Check enable/disable messages
        if isinstance(message, HotwordToggleOn):
            if message.reason == HotwordToggleReason.UNKNOWN:
                # Always enable on unknown
                self.disabled_reasons.clear()
            else:
                self.disabled_reasons.discard(message.reason)

            if self.disabled_reasons:
                _LOGGER.debug("Still disabled: %s", self.disabled_reasons)
            else:
                self.enabled = True
                self.first_audio = True
                _LOGGER.debug("Enabled")
        elif isinstance(message, HotwordToggleOff):
            self.enabled = False
            self.disabled_reasons.add(message.reason)
            _LOGGER.debug("Disabled")
        elif isinstance(message, AudioFrame):
            if self.enabled:
                assert site_id, "Missing site_id"
                await self.handle_audio_frame(message.wav_bytes, site_id=site_id)
        elif isinstance(message, GetHotwords):
            async for hotword_result in self.handle_get_hotwords(message):
                yield hotword_result
        elif isinstance(message, RecordHotwordExample):
            # Handled in on_message
            pass
        else:
            _LOGGER.warning("Unexpected message: %s", message)

    async def on_message(
        self,
        message: Message,
        site_id: typing.Optional[str] = None,
        session_id: typing.Optional[str] = None,
        topic: typing.Optional[str] = None,
    ) -> GeneratorType:
        """Received message from MQTT broker (non-blocking).

        Only ``RecordHotwordExample`` is handled here, because recording
        awaits audio delivered by the detection thread.
        """
        if isinstance(message, RecordHotwordExample):
            async for example_result in self.handle_record_example(message):
                yield example_result
Exemple #22
0
def main():
    """Main entry point.

    Parses command-line arguments. With ``--record``, records example
    templates and returns. Otherwise, creates one Raven detector per
    ``--keyword`` directory, runs each in its own thread, and feeds raw audio
    read from stdin to every detector. Detections go to a separate output
    thread via ``output_queue``.
    """
    parser = argparse.ArgumentParser(prog="rhasspy-wake-raven")
    parser.add_argument(
        "--keyword",
        action="append",
        nargs="+",
        default=[],
        help="Directory with WAV templates and settings (setting-name=value)",
    )
    parser.add_argument(
        "--chunk-size",
        # Without type=int a command-line value stays a str and
        # audio_buffer.read() raises TypeError.
        type=int,
        default=1920,
        help=
        "Number of bytes to read at a time from standard in (default: 1920)",
    )
    parser.add_argument(
        "--record",
        nargs="+",
        help=
        "Record example templates to a directory, optionally with given name format (e.g., 'my-keyword-{n:02d}.wav')",
    )
    parser.add_argument(
        "--probability-threshold",
        type=float,
        default=0.5,
        help="Probability above which detection occurs (default: 0.5)",
    )
    parser.add_argument(
        "--distance-threshold",
        type=float,
        default=0.22,
        help=
        "Normalized dynamic time warping distance threshold for template matching (default: 0.22)",
    )
    parser.add_argument(
        "--minimum-matches",
        type=int,
        default=1,
        help=
        "Number of templates that must match to produce output (default: 1)",
    )
    parser.add_argument(
        "--refractory-seconds",
        type=float,
        default=2.0,
        help="Seconds before wake word can be activated again (default: 2)",
    )
    parser.add_argument(
        "--print-all-matches",
        action="store_true",
        help=
        "Print JSON for all matching templates instead of just the first one",
    )
    parser.add_argument(
        "--window-shift-seconds",
        type=float,
        default=Raven.DEFAULT_SHIFT_SECONDS,
        help=
        f"Seconds to shift sliding time window on audio buffer (default: {Raven.DEFAULT_SHIFT_SECONDS})",
    )
    parser.add_argument(
        "--dtw-window-size",
        type=int,
        default=5,
        help=
        "Size of band around slanted diagonal during dynamic time warping calculation (default: 5)",
    )
    parser.add_argument(
        "--vad-sensitivity",
        type=int,
        choices=[1, 2, 3],
        default=1,
        help="Webrtcvad VAD sensitivity (1-3)",
    )
    parser.add_argument(
        "--current-threshold",
        type=float,
        help="Debiased energy threshold of current audio frame",
    )
    parser.add_argument(
        "--max-energy",
        type=float,
        help="Fixed maximum energy for ratio calculation (default: observed)",
    )
    parser.add_argument(
        "--max-current-ratio-threshold",
        type=float,
        help="Threshold of ratio between max energy and current audio frame",
    )
    parser.add_argument(
        "--silence-method",
        choices=[e.value for e in SilenceMethod],
        default=SilenceMethod.VAD_ONLY,
        help="Method for detecting silence",
    )
    parser.add_argument(
        "--average-templates",
        action="store_true",
        help=
        "Average wakeword templates together to reduce number of calculations",
    )
    parser.add_argument(
        "--exit-count",
        type=int,
        help="Exit after some number of detections (default: never)",
    )
    parser.add_argument(
        "--read-entire-input",
        action="store_true",
        help="Read entire audio input at start and exit after processing",
    )
    parser.add_argument(
        "--max-chunks-in-queue",
        type=int,
        help=
        "Maximum number of audio chunks waiting for processing before being dropped",
    )
    parser.add_argument(
        "--skip-probability-threshold",
        type=float,
        default=0,
        help=
        "Skip additional template calculations if probability is below this threshold",
    )
    parser.add_argument(
        "--failed-matches-to-refractory",
        type=int,
        help=
        "Number of failed template matches before entering refractory period (default: disabled)",
    )
    parser.add_argument(
        "--benchmark",
        action="store_true",
        help="Track timings and report benchmark results on exit",
    )
    parser.add_argument("--debug",
                        action="store_true",
                        help="Print DEBUG messages to the console")
    args = parser.parse_args()

    # The default is an enum member, but a command-line value is one of the
    # string choices. Normalize so downstream code always sees the enum.
    args.silence_method = SilenceMethod(args.silence_method)

    if args.debug:
        logging.basicConfig(level=logging.DEBUG)
    else:
        logging.basicConfig(level=logging.INFO)

    # Create silence detector.
    # This can be shared by Raven instances because it's not maintaining state.
    recorder = WebRtcVadRecorder(
        vad_mode=args.vad_sensitivity,
        silence_method=args.silence_method,
        current_energy_threshold=args.current_threshold,
        max_energy=args.max_energy,
        max_current_ratio_threshold=args.max_current_ratio_threshold,
        min_seconds=0.5,
        before_seconds=1,
    )

    if args.record:
        # Do recording instead of recognizing
        record_dir = Path(args.record[0])
        if len(args.record) > 1:
            record_format = args.record[1]
        else:
            record_format = "example-{n:02d}.wav"

        return record_templates(record_dir, record_format, recorder, args)

    # Explicit check rather than ``assert`` (which is stripped under -O)
    if not args.keyword:
        parser.error("--keyword is required")

    # Instances of Raven that will run in separate threads
    ravens: typing.List[RavenInstance] = []

    # Queue for detections. Handled in separate thread.
    output_queue = Queue()

    # Load one or more keywords.
    # Each --keyword is [template_dir, setting=value, ...].
    for keyword_settings in args.keyword:
        template_dir = Path(keyword_settings[0])
        wav_paths = list(template_dir.glob("*.wav"))
        if not wav_paths:
            _LOGGER.warning("No WAV files found in %s", template_dir)
            continue

        keyword_name = template_dir.name

        # Load audio templates
        keyword_templates = [
            Raven.wav_to_template(p,
                                  name=str(p),
                                  shift_sec=args.window_shift_seconds)
            for p in wav_paths
        ]

        raven_args = {
            "templates": keyword_templates,
            "keyword_name": keyword_name,
            "recorder": recorder,
            "probability_threshold": args.probability_threshold,
            "minimum_matches": args.minimum_matches,
            "distance_threshold": args.distance_threshold,
            "refractory_sec": args.refractory_seconds,
            "shift_sec": args.window_shift_seconds,
            "skip_probability_threshold": args.skip_probability_threshold,
            "failed_matches_to_refractory": args.failed_matches_to_refractory,
            "debug": args.debug,
            "benchmark": args.benchmark,
        }

        # Apply per-keyword settings (override the global arguments)
        average_templates = args.average_templates
        for setting_str in keyword_settings[1:]:
            setting_name, separator, setting_value = setting_str.strip(
            ).partition("=")
            if not separator:
                _LOGGER.warning(
                    "Ignoring keyword setting without '=' (expected name=value): %s",
                    setting_str,
                )
                continue

            setting_name = setting_name.lower()

            if setting_name == "name":
                raven_args["keyword_name"] = setting_value
            elif setting_name == "probability-threshold":
                raven_args["probability_threshold"] = float(setting_value)
            elif setting_name == "minimum-matches":
                raven_args["minimum_matches"] = int(setting_value)
            elif setting_name == "average-templates":
                average_templates = setting_value.lower().strip() == "true"
            else:
                _LOGGER.warning("Ignoring unknown keyword setting: %s",
                                setting_name)

        if average_templates:
            _LOGGER.debug("Averaging %s templates for %s",
                          len(keyword_templates), template_dir)
            raven_args["templates"] = [
                Template.average_templates(keyword_templates)
            ]

        # Create instance of Raven in a separate thread for keyword
        raven = Raven(**raven_args)
        chunk_queue: "Queue[bytes]" = Queue()

        ravens.append(
            RavenInstance(
                thread=threading.Thread(
                    target=detect_thread_proc,
                    args=(chunk_queue, raven, output_queue, args),
                    daemon=True,
                ),
                raven=raven,
                chunk_queue=chunk_queue,
            ))

    # Start all threads
    for raven_inst in ravens:
        raven_inst.thread.start()

    output_thread = threading.Thread(target=output_thread_proc,
                                     args=(output_queue, ),
                                     daemon=True)

    output_thread.start()

    # -------------------------------------------------------------------------

    print("Reading 16-bit 16Khz raw audio from stdin...", file=sys.stderr)

    if args.read_entire_input:
        audio_buffer = FakeStdin(sys.stdin.buffer.read())
    else:
        audio_buffer = sys.stdin.buffer

    try:
        while True:
            # Read raw audio chunk
            chunk = audio_buffer.read(args.chunk_size)
            if not chunk or _EXIT_NOW:
                # Empty chunk
                break

            # Add to all detector threads
            for raven_inst in ravens:
                raven_inst.chunk_queue.put(chunk)

    except KeyboardInterrupt:
        pass
    finally:
        if not args.read_entire_input:
            # Exhaust queues so detector threads reach the stop signal quickly
            _LOGGER.debug("Emptying audio queues...")
            for raven_inst in ravens:
                while not raven_inst.chunk_queue.empty():
                    raven_inst.chunk_queue.get()

        for raven_inst in ravens:
            # Signal thread to quit
            raven_inst.chunk_queue.put(None)
            _LOGGER.debug("Waiting for %s thread...",
                          raven_inst.raven.keyword_name)
            raven_inst.thread.join()

        # Stop output thread
        output_queue.put(None)
        _LOGGER.debug("Waiting for output thread...")
        output_thread.join()

        if args.benchmark:
            print_benchmark(ravens)