    def convert_alignments(self) -> None:
        """
        Multiprocessing function that converts alignments from previous training

        See Also
        --------
        :func:`~montreal_forced_aligner.acoustic_modeling.triphone.ConvertAlignmentsFunction`
            Multiprocessing helper function for each job
        :meth:`.TriphoneTrainer.convert_alignments_arguments`
            Job method for generating arguments for the helper function
        :kaldi_steps:`train_deltas`
            Reference Kaldi script
        :kaldi_steps:`train_lda_mllt`
            Reference Kaldi script
        :kaldi_steps:`train_sat`
            Reference Kaldi script

        """
        self.log_info("Converting alignments...")
        arguments = self.convert_alignments_arguments()
        with tqdm.tqdm(total=self.num_current_utterances,
                       disable=getattr(self, "quiet", False)) as pbar:
            if self.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(arguments):
                    function = ConvertAlignmentsFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if isinstance(result, Exception):
                            error_dict[getattr(result, "job_name", 0)] = result
                            continue
                        if stopped.stop_check():
                            continue
                    except Empty:
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    num_utterances, errors = result
                    pbar.update(num_utterances + errors)
                for p in procs:
                    p.join()
                if error_dict:
                    for v in error_dict.values():
                        raise v
            else:
                for args in arguments:
                    function = ConvertAlignmentsFunction(args)
                    for num_utterances, errors in function.run():
                        pbar.update(num_utterances + errors)
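
# The same consumer loop appears in most of the multiprocessing functions in this
# section: the parent polls return_queue with a short timeout, and when the queue is
# momentarily empty it only exits once every worker has flagged its `finished` check.
# Below is a minimal, self-contained sketch of that idiom; StopFlag and Worker are
# simplified stand-ins, not the montreal_forced_aligner classes.
import multiprocessing as mp
from queue import Empty


class StopFlag:
    """Stand-in for montreal_forced_aligner.utils.Stopped: a shared boolean flag."""

    def __init__(self):
        self._event = mp.Event()

    def stop(self):
        self._event.set()

    def stop_check(self):
        return self._event.is_set()


class Worker(mp.Process):
    """Stand-in for KaldiProcessWorker: produces results, then flags `finished`."""

    def __init__(self, job_name, return_queue, items):
        super().__init__()
        self.job_name = job_name
        self.return_queue = return_queue
        self.items = list(items)
        self.finished = StopFlag()

    def run(self):
        for item in self.items:
            self.return_queue.put(item * item)  # pretend per-utterance result
        self.finished.stop()  # no more results will come from this worker


if __name__ == "__main__":
    return_queue = mp.Queue()
    procs = [Worker(i, return_queue, range(i * 5, i * 5 + 5)) for i in range(3)]
    for p in procs:
        p.start()
    results = []
    while True:
        try:
            results.append(return_queue.get(timeout=1))
        except Empty:
            # An empty queue is ambiguous: workers may still be busy, or all done.
            # The for/else only breaks the outer loop once every worker has finished.
            for proc in procs:
                if not proc.finished.stop_check():
                    break
            else:
                break
            continue
    for p in procs:
        p.join()
    print(sorted(results))
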
    def compute_vad(self) -> None:
        """
        Compute Voice Activity Detection features over the corpus

        See Also
        --------
        :class:`~montreal_forced_aligner.corpus.features.ComputeVadFunction`
            Multiprocessing helper function for each job
        :meth:`.AcousticCorpusMixin.compute_vad_arguments`
            Job method for generating arguments for helper function
        """
        if os.path.exists(os.path.join(self.split_directory, "vad.0.scp")):
            self.log_info("VAD already computed, skipping!")
            return
        begin = time.time()
        self.log_info("Computing VAD...")

        arguments = self.compute_vad_arguments()
        with tqdm.tqdm(total=self.num_speakers,
                       disable=getattr(self, "quiet", False)) as pbar:
            if self.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(arguments):
                    function = ComputeVadFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if stopped.stop_check():
                            continue
                    except Empty:
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    if isinstance(result, KaldiProcessingError):
                        error_dict[result.job_name] = result
                        continue
                    done, no_feats, unvoiced = result
                    pbar.update(done + no_feats + unvoiced)
                for p in procs:
                    p.join()
                if error_dict:
                    for v in error_dict.values():
                        raise v
            else:
                for args in arguments:
                    function = ComputeVadFunction(args)
                    for done, no_feats, unvoiced in function.run():
                        pbar.update(done + no_feats + unvoiced)
        self.log_debug(f"VAD computation took {time.time() - begin}")
    def gmm_gselect(self) -> None:
        """
        Multiprocessing function that stores Gaussian selection indices on disk

        See Also
        --------
        :func:`~montreal_forced_aligner.ivector.trainer.GmmGselectFunction`
            Multiprocessing helper function for each job
        :meth:`.DubmTrainer.gmm_gselect_arguments`
            Job method for generating arguments for the helper function
        :kaldi_steps:`train_diag_ubm`
            Reference Kaldi script

        """
        begin = time.time()
        self.log_info("Selecting gaussians...")
        arguments = self.gmm_gselect_arguments()

        if self.use_mp:
            error_dict = {}
            return_queue = mp.Queue()
            stopped = Stopped()
            procs = []
            for i, args in enumerate(arguments):
                function = GmmGselectFunction(args)
                p = KaldiProcessWorker(i, return_queue, function, stopped)
                procs.append(p)
                p.start()
            while True:
                try:
                    result = return_queue.get(timeout=1)
                    if isinstance(result, Exception):
                        error_dict[getattr(result, "job_name", 0)] = result
                        continue
                    if stopped.stop_check():
                        continue
                except queue.Empty:
                    for proc in procs:
                        if not proc.finished.stop_check():
                            break
                    else:
                        break
                    continue
            for p in procs:
                p.join()
            if error_dict:
                for v in error_dict.values():
                    raise v

        else:
            self.log_debug("Not using multiprocessing...")
            for args in arguments:
                function = GmmGselectFunction(args)
                for _ in function.run():
                    pass

        self.log_debug(f"Gaussian selection took {time.time() - begin}")
    def extract_ivectors(self) -> None:
        """
        Multiprocessing function that extracts ivectors.

        See Also
        --------
        :class:`~montreal_forced_aligner.corpus.features.ExtractIvectorsFunction`
            Multiprocessing helper function for each job
        :meth:`.IvectorCorpusMixin.extract_ivectors_arguments`
            Job method for generating arguments for helper function
        :kaldi_steps_sid:`extract_ivectors`
            Reference Kaldi script
        """
        begin = time.time()

        log_dir = self.working_log_directory
        os.makedirs(log_dir, exist_ok=True)

        arguments = self.extract_ivectors_arguments()
        with tqdm.tqdm(total=self.num_speakers, disable=getattr(self, "quiet", False)) as pbar:
            if self.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(arguments):
                    function = ExtractIvectorsFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if isinstance(result, Exception):
                            error_dict[getattr(result, "job_name", 0)] = result
                            continue
                        if stopped.stop_check():
                            continue
                    except Empty:
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    pbar.update(1)
                for p in procs:
                    p.join()
                if error_dict:
                    for v in error_dict.values():
                        raise v
            else:
                for args in arguments:
                    function = ExtractIvectorsFunction(args)
                    for _ in function.run():
                        pbar.update(1)
        self.log_debug(f"Ivector extraction took {time.time() - begin}")
    def gauss_to_post(self) -> None:
        """
        Multiprocessing function that does Gaussian selection and posterior extraction

        See Also
        --------
        :func:`~montreal_forced_aligner.ivector.trainer.GaussToPostFunction`
            Multiprocessing helper function for each job
        :meth:`.IvectorTrainer.gauss_to_post_arguments`
            Job method for generating arguments for the helper function
        :kaldi_steps_sid:`train_ivector_extractor`
            Reference Kaldi script
        """
        begin = time.time()
        self.log_info("Extracting posteriors...")
        arguments = self.gauss_to_post_arguments()

        if self.use_mp:
            error_dict = {}
            return_queue = mp.Queue()
            stopped = Stopped()
            procs = []
            for i, args in enumerate(arguments):
                function = GaussToPostFunction(args)
                p = KaldiProcessWorker(i, return_queue, function, stopped)
                procs.append(p)
                p.start()
            while True:
                try:
                    result = return_queue.get(timeout=1)
                    if stopped.stop_check():
                        continue
                except queue.Empty:
                    for proc in procs:
                        if not proc.finished.stop_check():
                            break
                    else:
                        break
                    continue
                if isinstance(result, KaldiProcessingError):
                    error_dict[result.job_name] = result
                    continue
            for p in procs:
                p.join()
            if error_dict:
                for v in error_dict.values():
                    raise v

        else:
            self.log_debug("Not using multiprocessing...")
            for args in arguments:
                function = GaussToPostFunction(args)
                for _ in function.run():
                    pass

        self.log_debug(f"Extracting posteriors took {time.time() - begin}")
 def __init__(self, audio_directory: Optional[str] = None, **kwargs):
     super().__init__(**kwargs)
     self.audio_directory = audio_directory
     self.sound_file_errors = []
     self.transcriptions_without_wavs = []
     self.no_transcription_files = []
     self.stopped = Stopped()
     self.features_generated = False
     self.alignment_done = False
     self.transcription_done = False
     self.has_reference_alignments = False
     self.alignment_evaluation_done = False
 def __init__(
     self,
     job_queue: mp.Queue,
     return_queue: mp.Queue,
     rewriter: Rewriter,
     stopped: Stopped,
 ):
     mp.Process.__init__(self)
     self.job_queue = job_queue
     self.return_queue = return_queue
     self.rewriter = rewriter
     self.stopped = stopped
     self.finished = Stopped()
class RewriterWorker(mp.Process):
    """
    Rewriter process

    Parameters
    ----------
    job_queue: :class:`~multiprocessing.Queue`
        Queue to pull words from
    return_queue: :class:`~multiprocessing.Queue`
        Queue to put pronunciations
    rewriter: :class:`~montreal_forced_aligner.g2p.generator.Rewriter`
        Function to generate pronunciations of words
    stopped: :class:`~montreal_forced_aligner.utils.Stopped`
        Stop check
    """
    def __init__(
        self,
        job_queue: mp.Queue,
        return_queue: mp.Queue,
        rewriter: Rewriter,
        stopped: Stopped,
    ):
        mp.Process.__init__(self)
        self.job_queue = job_queue
        self.return_queue = return_queue
        self.rewriter = rewriter
        self.stopped = stopped
        self.finished = Stopped()

    def run(self) -> None:
        """Run the rewriting function"""
        while True:
            try:
                word = self.job_queue.get(timeout=1)
            except queue.Empty:
                break
            if self.stopped.stop_check():
                continue
            try:
                rep = self.rewriter(word)
                self.return_queue.put((word, rep))
            except rewrite.Error:
                pass
            except Exception as e:  # noqa
                self.stopped.stop()
                self.return_queue.put(e)
                raise
        self.finished.stop()
        return
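
# A hedged usage sketch for the worker above. It assumes RewriterWorker can be
# imported from montreal_forced_aligner.g2p.generator (the module the Rewriter
# cross-reference points to) and that MFA and its dependencies are installed; the
# toy_rewriter below is just a stand-in callable, not a real G2P model.
import multiprocessing as mp
import queue

from montreal_forced_aligner.g2p.generator import RewriterWorker  # assumed location
from montreal_forced_aligner.utils import Stopped


def toy_rewriter(word: str) -> str:
    return " ".join(word)  # pretend pronunciation: one "phone" per letter


if __name__ == "__main__":
    job_queue = mp.Queue()
    return_queue = mp.Queue()
    stopped = Stopped()
    for w in ["cat", "dog", "fish"]:
        job_queue.put(w)  # the queue is filled before the workers start
    procs = [RewriterWorker(job_queue, return_queue, toy_rewriter, stopped) for _ in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()  # each worker exits once the job queue stays empty for a second
    pronunciations = {}
    while True:
        try:
            item = return_queue.get(timeout=1)
        except queue.Empty:
            break
        if isinstance(item, Exception):
            raise item  # workers forward unexpected exceptions on the return queue
        word, rep = item
        pronunciations[word] = rep
    print(pronunciations)
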
 def __init__(
     self,
     job_name: int,
     job_q: mp.Queue,
     return_queue: mp.Queue,
     log_file: str,
     stopped: Stopped,
 ):
     mp.Process.__init__(self)
     self.job_name = job_name
     self.job_q = job_q
     self.return_queue = return_queue
     self.log_file = log_file
     self.stopped = stopped
     self.finished = Stopped()
 def __init__(
     self,
     corpus_directory: str,
     speaker_characters: typing.Union[int, str] = 0,
     ignore_speakers: bool = False,
     **kwargs,
 ):
     if not os.path.exists(corpus_directory):
         raise CorpusError(
             f"The directory '{corpus_directory}' does not exist.")
     if not os.path.isdir(corpus_directory):
         raise CorpusError(
             f"The specified path for the corpus ({corpus_directory}) is not a directory."
         )
     self._speaker_ids = {}
     self.corpus_directory = corpus_directory
     self.speaker_characters = speaker_characters
     self.ignore_speakers = ignore_speakers
     self.word_counts = Counter()
     self.stopped = Stopped()
     self.decode_error_files = []
     self.textgrid_read_errors = []
     self.jobs: typing.List[Job] = []
     self._num_speakers = None
     self._num_utterances = None
     self._num_files = None
     super().__init__(**kwargs)
     os.makedirs(self.corpus_output_directory, exist_ok=True)
     self.imported = False
     self._current_speaker_index = 1
     self._current_file_index = 1
     self._speaker_ids = {}
 def __init__(
     self,
     name: int,
     job_q: mp.Queue,
     return_q: mp.Queue,
     stopped: Stopped,
     finished_adding: Stopped,
     speaker_characters: Union[int, str],
     sanitize_function: Optional[MultispeakerSanitizationFunction],
     sample_rate: Optional[int],
 ):
     mp.Process.__init__(self)
     self.name = str(name)
     self.job_q = job_q
     self.return_q = return_q
     self.stopped = stopped
     self.finished_adding = finished_adding
     self.finished_processing = Stopped()
     self.sanitize_function = sanitize_function
     self.speaker_characters = speaker_characters
     self.sample_rate = sample_rate
    def __init__(
        self,
        db_path: str,
        for_write_queue: mp.Queue,
        return_queue: mp.Queue,
        stopped: Stopped,
        finished_adding: Stopped,
        arguments: ExportTextGridArguments,
        exported_file_count: Counter,
    ):
        mp.Process.__init__(self)
        self.db_path = db_path
        self.for_write_queue = for_write_queue
        self.return_queue = return_queue
        self.stopped = stopped
        self.finished_adding = finished_adding
        self.finished_processing = Stopped()

        self.output_directory = arguments.output_directory
        self.output_format = arguments.output_format
        self.frame_shift = arguments.frame_shift
        self.log_path = arguments.log_path
        self.include_original_text = arguments.include_original_text
        self.exported_file_count = exported_file_count
class RandomStartWorker(mp.Process):
    """
    Random start worker
    """
    def __init__(
        self,
        job_name: int,
        job_q: mp.Queue,
        return_queue: mp.Queue,
        log_file: str,
        stopped: Stopped,
    ):
        mp.Process.__init__(self)
        self.job_name = job_name
        self.job_q = job_q
        self.return_queue = return_queue
        self.log_file = log_file
        self.stopped = stopped
        self.finished = Stopped()

    def run(self) -> None:
        """Run the random start worker"""
        with open(self.log_file, "w", encoding="utf8") as log_file:
            while True:
                try:
                    args = self.job_q.get(timeout=1)
                except queue.Empty:
                    break
                if self.stopped.stop_check():
                    continue
                try:
                    start = time.time()
                    # Randomize channel model.
                    rfst_path = os.path.join(args.tempdir,
                                             f"random-{args.seed:05d}.fst")
                    afst_path = os.path.join(args.tempdir,
                                             f"aligner-{args.seed:05d}.fst")
                    likelihood_path = afst_path.replace(".fst", ".like")
                    if not os.path.exists(afst_path):
                        cmd = [
                            thirdparty_binary("baumwelchrandomize"),
                            f"--seed={args.seed}",
                            args.cg_path,
                            rfst_path,
                        ]
                        subprocess.check_call(cmd,
                                              stderr=log_file,
                                              env=os.environ)
                        random_end = time.time()
                        log_file.write(
                            f"{args.seed} randomization took {random_end - start} seconds\n"
                        )
                        # Train on randomized channel model.

                        likelihood = INF
                        cmd = [
                            thirdparty_binary("baumwelchtrain"),
                            *args.train_opts,
                            args.input_far_path,
                            args.output_far_path,
                            rfst_path,
                            afst_path,
                        ]
                        log_file.write(
                            f"{args.seed} train command: {' '.join(cmd)}\n")
                        log_file.flush()
                        with subprocess.Popen(cmd,
                                              stderr=subprocess.PIPE,
                                              text=True,
                                              env=os.environ) as proc:
                            # Parses STDERR to capture the likelihood.
                            for line in proc.stderr:  # type: ignore
                                log_file.write(line)
                                log_file.flush()
                                line = line.rstrip()
                                match = re.match(
                                    r"INFO: Iteration \d+: (-?\d*(\.\d*)?)",
                                    line)
                                assert match, line
                                likelihood = float(match.group(1))
                                self.return_queue.put(1)
                            with open(likelihood_path, "w") as f:
                                f.write(str(likelihood))
                        log_file.write(
                            f"{args.seed} training took {time.time() - random_end} seconds\n"
                        )
                    else:
                        with open(likelihood_path, "r") as f:
                            likelihood = float(f.read().strip())
                    self.return_queue.put((afst_path, likelihood))
                except Exception:
                    self.stopped.stop()
                    e = KaldiProcessingError([self.log_file])
                    e.job_name = self.job_name
                    self.return_queue.put(e)
        self.finished.stop()
        return
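
# Quick, self-contained check of the likelihood-parsing regex used above. The sample
# line mimics the stderr format implied by the pattern; the exact wording of real
# baumwelchtrain log lines is an assumption here.
import re

pattern = re.compile(r"INFO: Iteration \d+: (-?\d*(\.\d*)?)")
sample = "INFO: Iteration 7: -12.3456"
match = pattern.match(sample)
assert match is not None
print(float(match.group(1)))  # -12.3456
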
class CorpusProcessWorker(mp.Process):
    """
    Multiprocessing corpus loading worker

    Attributes
    ----------
    job_q: :class:`~multiprocessing.Queue`
        Job queue for files to process
    return_q: :class:`~multiprocessing.Queue`
        Return queue for processed files and errors
    stopped: :class:`~montreal_forced_aligner.utils.Stopped`
        Stop check for whether corpus loading should exit
    finished_adding: :class:`~montreal_forced_aligner.utils.Stopped`
        Signal that the main thread has stopped adding new files to be processed
    """

    def __init__(
        self,
        name: int,
        job_q: mp.Queue,
        return_q: mp.Queue,
        stopped: Stopped,
        finished_adding: Stopped,
        speaker_characters: Union[int, str],
        sanitize_function: Optional[MultispeakerSanitizationFunction],
        sample_rate: Optional[int],
    ):
        mp.Process.__init__(self)
        self.name = str(name)
        self.job_q = job_q
        self.return_q = return_q
        self.stopped = stopped
        self.finished_adding = finished_adding
        self.finished_processing = Stopped()
        self.sanitize_function = sanitize_function
        self.speaker_characters = speaker_characters
        self.sample_rate = sample_rate

    def run(self) -> None:
        """
        Run the corpus loading job
        """
        while True:
            try:
                file_name, wav_path, text_path, relative_path = self.job_q.get(timeout=1)
            except Empty:
                if self.finished_adding.stop_check():
                    break
                continue
            if self.stopped.stop_check():
                continue
            try:
                file = FileData.parse_file(
                    file_name,
                    wav_path,
                    text_path,
                    relative_path,
                    self.speaker_characters,
                    self.sanitize_function,
                    self.sample_rate,
                )
                self.return_q.put(file)
            except TextParseError as e:
                self.return_q.put(("decode_error_files", e))
            except TextGridParseError as e:
                self.return_q.put(("textgrid_read_errors", e))
            except SoundFileError as e:
                self.return_q.put(("sound_file_errors", e))
            except Exception as e:
                self.stopped.stop()
                self.return_q.put(("error", e))
        self.finished_processing.stop()
        return
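
# The detail worth calling out in the worker above is the finished_adding flag: an
# empty job queue only means "done" once the producer has signalled that it will not
# add any more files. Below is a compact, self-contained sketch of that handshake,
# using bare mp.Event objects as stand-ins for Stopped; FileWorker is not the MFA class.
import multiprocessing as mp
from queue import Empty


class FileWorker(mp.Process):
    def __init__(self, job_q, return_q, finished_adding):
        super().__init__()
        self.job_q = job_q
        self.return_q = return_q
        self.finished_adding = finished_adding  # set by the producer when done adding
        self.finished_processing = mp.Event()   # set by this worker when it exits

    def run(self):
        while True:
            try:
                name = self.job_q.get(timeout=1)
            except Empty:
                if self.finished_adding.is_set():
                    break  # producer is done and the queue has drained
                continue  # producer may still be walking the corpus directory
            self.return_q.put(name.upper())  # pretend file parse
        self.finished_processing.set()


if __name__ == "__main__":
    job_q, return_q = mp.Queue(), mp.Queue()
    finished_adding = mp.Event()
    procs = [FileWorker(job_q, return_q, finished_adding) for _ in range(2)]
    for p in procs:
        p.start()
    for name in ["a", "b", "c", "d"]:  # stands in for the os.walk loop in the loader
        job_q.put(name)
    finished_adding.set()  # no more files will be added
    results = []
    while True:
        try:
            results.append(return_q.get(timeout=1))
        except Empty:
            if all(p.finished_processing.is_set() for p in procs):
                break
    for p in procs:
        p.join()
    print(sorted(results))
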
    def acc_stats(self) -> None:
        """
        Multiprocessing function that accumulates stats for GMM training.

        See Also
        --------
        :class:`~montreal_forced_aligner.alignment.multiprocessing.AccStatsFunction`
            Multiprocessing helper function for each job
        :meth:`.AcousticModelTrainingMixin.acc_stats_arguments`
            Job method for generating arguments for the helper function
        :kaldi_src:`gmm-sum-accs`
            Relevant Kaldi binary
        :kaldi_src:`gmm-est`
            Relevant Kaldi binary
        :kaldi_steps:`train_mono`
            Reference Kaldi script
        :kaldi_steps:`train_deltas`
            Reference Kaldi script
        """
        self.log_info("Accumulating statistics...")
        arguments = self.acc_stats_arguments()
        with tqdm.tqdm(total=self.num_current_utterances,
                       disable=getattr(self, "quiet", False)) as pbar:
            if self.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(arguments):
                    function = AccStatsFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if isinstance(result, Exception):
                            error_dict[getattr(result, "job_name", 0)] = result
                            continue
                        if stopped.stop_check():
                            continue
                    except Empty:
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    num_utterances, errors = result
                    pbar.update(num_utterances + errors)
                for p in procs:
                    p.join()
                if error_dict:
                    for v in error_dict.values():
                        raise v
            else:
                for args in arguments:
                    function = AccStatsFunction(args)
                    for num_utterances, errors in function.run():
                        pbar.update(num_utterances + errors)

        log_path = os.path.join(self.working_log_directory,
                                f"update.{self.iteration}.log")
        with open(log_path, "w") as log_file:
            acc_files = []
            for a in arguments:
                acc_files.extend(a.acc_paths.values())
            sum_proc = subprocess.Popen(
                [thirdparty_binary("gmm-sum-accs"), "-"] + acc_files,
                stdout=subprocess.PIPE,
                stderr=log_file,
                env=os.environ,
            )
            est_command = [
                thirdparty_binary("gmm-est"),
                f"--write-occs={self.next_occs_path}",
                f"--mix-up={self.current_gaussians}",
            ]
            if self.power > 0:
                est_command.append(f"--power={self.power}")
            est_command.extend([
                self.model_path,
                "-",
                self.next_model_path,
            ])
            est_proc = subprocess.Popen(
                est_command,
                stdin=sum_proc.stdout,
                stderr=log_file,
                env=os.environ,
            )
            est_proc.communicate()
        avg_like_pattern = re.compile(
            r"Overall avg like per frame.* = (?P<like>[-.,\d]+) over (?P<frames>[.\d+e]+) frames"
        )
        average_logdet_pattern = re.compile(
            r"Overall average logdet is (?P<logdet>[-.,\d]+) over (?P<frames>[.\d+e]+) frames"
        )
        avg_like_sum = 0
        avg_like_frames = 0
        average_logdet_sum = 0
        average_logdet_frames = 0
        for a in arguments:
            with open(a.log_path, "r", encoding="utf8") as f:
                for line in f:
                    m = avg_like_pattern.search(line)
                    if m:
                        like = float(m.group("like"))
                        frames = float(m.group("frames"))
                        avg_like_sum += like * frames
                        avg_like_frames += frames
                    m = average_logdet_pattern.search(line)
                    if m:
                        logdet = float(m.group("logdet"))
                        frames = float(m.group("frames"))
                        average_logdet_sum += logdet * frames
                        average_logdet_frames += frames
        if avg_like_frames:
            log_like = avg_like_sum / avg_like_frames
            if average_logdet_frames:
                log_like += average_logdet_sum / average_logdet_frames
            self.log_debug(
                f"Likelihood for iteration {self.iteration}: {log_like}")

        if not self.debug:
            for f in acc_files:
                os.remove(f)
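
# A worked example of the frame-weighted averaging done above, with made-up numbers:
# each job's log contributes an average likelihood over some number of frames, and the
# overall figure weights each job by its frame count; the logdet term, which shows up
# when feature transforms are applied, is averaged the same way and added on.
per_job_like = [(-95.2, 120000.0), (-97.8, 80000.0)]    # hypothetical (avg like, frames)
per_job_logdet = [(1.3, 120000.0), (1.1, 80000.0)]      # hypothetical (avg logdet, frames)

avg_like_sum = sum(like * frames for like, frames in per_job_like)
avg_like_frames = sum(frames for _, frames in per_job_like)
logdet_sum = sum(logdet * frames for logdet, frames in per_job_logdet)
logdet_frames = sum(frames for _, frames in per_job_logdet)

log_like = avg_like_sum / avg_like_frames + logdet_sum / logdet_frames
print(round(log_like, 2))  # -96.24 + 1.22 = -95.02
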
    def _load_corpus_from_source_mp(self) -> None:
        """
        Load a corpus using multiprocessing
        """
        if self.stopped is None:
            self.stopped = Stopped()
        sanitize_function = getattr(self, "sanitize_function", None)
        begin_time = time.time()
        job_queue = mp.Queue()
        return_queue = mp.Queue()
        error_dict = {}
        finished_adding = Stopped()
        procs = []
        for i in range(self.num_jobs):
            p = CorpusProcessWorker(
                i,
                job_queue,
                return_queue,
                self.stopped,
                finished_adding,
                self.speaker_characters,
                sanitize_function,
                sample_rate=0,
            )
            procs.append(p)
            p.start()
        import_data = DatabaseImportData()
        try:
            file_count = 0
            with tqdm.tqdm(total=1, disable=getattr(
                    self, "quiet", False)) as pbar, self.session() as session:
                for root, _, files in os.walk(self.corpus_directory,
                                              followlinks=True):
                    exts = find_exts(files)
                    relative_path = (root.replace(self.corpus_directory,
                                                  "").lstrip("/").lstrip("\\"))

                    if self.stopped.stop_check():
                        break
                    for file_name in exts.identifiers:
                        if self.stopped.stop_check():
                            break
                        wav_path = None
                        if file_name in exts.lab_files:
                            lab_name = exts.lab_files[file_name]
                            transcription_path = os.path.join(root, lab_name)

                        elif file_name in exts.textgrid_files:
                            tg_name = exts.textgrid_files[file_name]
                            transcription_path = os.path.join(root, tg_name)
                        else:
                            continue
                        job_queue.put((file_name, wav_path, transcription_path,
                                       relative_path))
                        file_count += 1
                        pbar.total = file_count

                finished_adding.stop()

                while True:
                    try:
                        file = return_queue.get(timeout=1)
                        if isinstance(file, tuple):
                            error_type = file[0]
                            error = file[1]
                            if error_type == "error":
                                error_dict[error_type] = error
                            else:
                                if error_type not in error_dict:
                                    error_dict[error_type] = []
                                error_dict[error_type].append(error)
                            continue
                        if self.stopped.stop_check():
                            continue
                    except Empty:
                        for proc in procs:
                            if not proc.finished_processing.stop_check():
                                break
                        else:
                            break
                        continue
                    pbar.update(1)
                    import_data.add_objects(self.generate_import_objects(file))

                self.log_debug("Waiting for workers to finish...")
                for p in procs:
                    p.join()

                if "error" in error_dict:
                    session.rollback()
                    raise error_dict["error"]

                self._finalize_load(session, import_data)

                for k in ["decode_error_files", "textgrid_read_errors"]:
                    if hasattr(self, k):
                        if k in error_dict:
                            self.log_info(
                                "There were some issues with files in the corpus. "
                                "Please look at the log file or run the validator for more information."
                            )
                            self.log_debug(
                                f"{k} showed {len(error_dict[k])} errors:")
                            if k == "textgrid_read_errors":
                                getattr(self, k).update(error_dict[k])
                                for e in error_dict[k]:
                                    self.log_debug(f"{e.file_name}: {e.error}")
                            else:
                                self.log_debug(", ".join(str(e) for e in error_dict[k]))
                                setattr(self, k, error_dict[k])

        except KeyboardInterrupt:
            self.log_info(
                "Detected ctrl-c, please wait a moment while we clean everything up..."
            )
            self.stopped.stop()
            finished_adding.stop()
            job_queue.join()
            self.stopped.set_sigint_source()
            while True:
                try:
                    _ = return_queue.get(timeout=1)
                    if self.stopped.stop_check():
                        continue
                except Empty:
                    for proc in procs:
                        if not proc.finished_processing.stop_check():
                            break
                    else:
                        break
        finally:

            finished_adding.stop()
            for p in procs:
                p.join()
            if self.stopped.stop_check():
                self.log_info(
                    f"Stopped parsing early ({time.time() - begin_time} seconds)"
                )
                if self.stopped.source():
                    sys.exit(0)
            else:
                self.log_debug(
                    f"Parsed corpus directory with {self.num_jobs} jobs in {time.time() - begin_time} seconds"
                )
class ExportTextGridProcessWorker(mp.Process):
    """
    Multiprocessing worker for exporting TextGrids

    See Also
    --------
    :meth:`.CorpusAligner.collect_alignments`
        Main function that runs this worker in parallel

    Parameters
    ----------
    db_path: str
        Path to the corpus database file
    for_write_queue: :class:`~multiprocessing.Queue`
        Input queue of files to export
    return_queue: :class:`~multiprocessing.Queue`
        Queue for returning progress updates and errors
    stopped: :class:`~montreal_forced_aligner.utils.Stopped`
        Stop check for processing
    finished_adding: :class:`~montreal_forced_aligner.utils.Stopped`
        Input signal that all jobs have been added and no more new ones will come in
    arguments: :class:`~montreal_forced_aligner.alignment.multiprocessing.ExportTextGridArguments`
        Arguments to pass to the TextGrid export function
    exported_file_count: :class:`~montreal_forced_aligner.utils.Counter`
        Counter for exported files
    """
    def __init__(
        self,
        db_path: str,
        for_write_queue: mp.Queue,
        return_queue: mp.Queue,
        stopped: Stopped,
        finished_adding: Stopped,
        arguments: ExportTextGridArguments,
        exported_file_count: Counter,
    ):
        mp.Process.__init__(self)
        self.db_path = db_path
        self.for_write_queue = for_write_queue
        self.return_queue = return_queue
        self.stopped = stopped
        self.finished_adding = finished_adding
        self.finished_processing = Stopped()

        self.output_directory = arguments.output_directory
        self.output_format = arguments.output_format
        self.frame_shift = arguments.frame_shift
        self.log_path = arguments.log_path
        self.include_original_text = arguments.include_original_text
        self.exported_file_count = exported_file_count

    def run(self) -> None:
        """Run the exporter function"""
        db_engine = sqlalchemy.create_engine(
            f"sqlite:///{self.db_path}?mode=ro&nolock=1")
        with open(self.log_path, "w",
                  encoding="utf8") as log_file, Session(db_engine) as session:

            while True:
                try:
                    (
                        file_id,
                        name,
                        relative_path,
                        duration,
                        text_file_path,
                    ) = self.for_write_queue.get(timeout=1)
                except Empty:
                    if self.finished_adding.stop_check():
                        self.finished_processing.stop()
                        break
                    continue

                if self.stopped.stop_check():
                    continue
                try:
                    output_path = construct_output_path(
                        name,
                        relative_path,
                        self.output_directory,
                        text_file_path,
                        self.output_format,
                    )
                    utterances = (session.query(Utterance).options(
                        joinedload(Utterance.speaker,
                                   innerjoin=True).load_only(Speaker.name),
                        selectinload(Utterance.phone_intervals),
                        selectinload(Utterance.word_intervals),
                    ).filter(Utterance.file_id == file_id))
                    data = {}
                    for utt in utterances:
                        if utt.speaker.name not in data:
                            data[utt.speaker.name] = {
                                "words": [],
                                "phones": []
                            }
                            if self.include_original_text:
                                data[utt.speaker.name]["utterances"] = []

                        if self.include_original_text:
                            data[utt.speaker.name]["utterances"].append(
                                CtmInterval(utt.begin, utt.end, utt.text,
                                            utt.id))
                        for wi in utt.word_intervals:
                            data[utt.speaker.name]["words"].append(
                                CtmInterval(wi.begin, wi.end, wi.label,
                                            utt.id))

                        for pi in utt.phone_intervals:
                            data[utt.speaker.name]["phones"].append(
                                CtmInterval(pi.begin, pi.end, pi.label,
                                            utt.id))
                    export_textgrid(data, output_path, duration,
                                    self.frame_shift, self.output_format)
                    self.return_queue.put(1)
                except Exception:
                    exc_type, exc_value, exc_traceback = sys.exc_info()
                    self.return_queue.put(
                        AlignmentExportError(
                            output_path,
                            traceback.format_exception(exc_type, exc_value,
                                                       exc_traceback),
                        ))
                    self.stopped.stop()
                    raise
            log_file.write("Done!\n")
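
# Hedged sketch of the per-speaker payload the exporter assembles before calling
# export_textgrid. A namedtuple stands in for MFA's CtmInterval here, with field names
# read off the positional construction above; treat the exact class API as an assumption.
from collections import namedtuple

CtmInterval = namedtuple("CtmInterval", ["begin", "end", "label", "utterance_id"])

data = {
    "speaker1": {
        "words": [
            CtmInterval(0.00, 0.42, "hello", 1),
            CtmInterval(0.42, 0.85, "world", 1),
        ],
        "phones": [
            CtmInterval(0.00, 0.21, "HH", 1),
            CtmInterval(0.21, 0.42, "AH0", 1),
        ],
        # an "utterances" list of full-text intervals is added when
        # include_original_text is set
    }
}
# export_textgrid(data, output_path, duration, frame_shift, output_format)
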
    def compile_train_graphs(self) -> None:
        """
        Multiprocessing function that compiles training graphs for utterances.

        See Also
        --------
        :class:`~montreal_forced_aligner.alignment.multiprocessing.CompileTrainGraphsFunction`
            Multiprocessing helper function for each job
        :meth:`.AlignMixin.compile_train_graphs_arguments`
            Job method for generating arguments for the helper function
        :kaldi_steps:`align_si`
            Reference Kaldi script
        :kaldi_steps:`align_fmllr`
            Reference Kaldi script
        """
        begin = time.time()
        log_directory = self.working_log_directory
        os.makedirs(log_directory, exist_ok=True)
        self.log_info("Compiling training graphs...")
        error_sum = 0
        arguments = self.compile_train_graphs_arguments()
        with tqdm.tqdm(total=self.num_current_utterances,
                       disable=getattr(self, "quiet", False)) as pbar:
            if self.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(arguments):
                    function = CompileTrainGraphsFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if isinstance(result, Exception):
                            error_dict[getattr(result, "job_name", 0)] = result
                            continue
                        if stopped.stop_check():
                            continue
                    except Empty:
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    done, errors = result
                    pbar.update(done + errors)
                    error_sum += errors
                for p in procs:
                    p.join()
                if error_dict:
                    for v in error_dict.values():
                        raise v
            else:
                self.log_debug("Not using multiprocessing...")
                for args in arguments:
                    function = CompileTrainGraphsFunction(args)
                    for done, errors in function.run():
                        pbar.update(done + errors)
                        error_sum += errors
        if error_sum:
            self.log_warning(
                f"Compilation of training graphs failed for {error_sum} utterances."
            )
        self.log_debug(f"Compiling training graphs took {time.time() - begin}")
    def align_utterances(self, training: bool = False) -> None:
        """
        Multiprocessing function that aligns based on the current model.

        See Also
        --------
        :class:`~montreal_forced_aligner.alignment.multiprocessing.AlignFunction`
            Multiprocessing helper function for each job
        :meth:`.AlignMixin.align_arguments`
            Job method for generating arguments for the helper function
        :kaldi_steps:`align_si`
            Reference Kaldi script
        :kaldi_steps:`align_fmllr`
            Reference Kaldi script
        """
        begin = time.time()
        self.log_info("Generating alignments...")
        with tqdm.tqdm(
                total=self.num_current_utterances,
                disable=getattr(self, "quiet",
                                False)) as pbar, self.session() as session:
            if not training:
                utterances = session.query(Utterance)
                if hasattr(self, "subset"):
                    utterances = utterances.filter(
                        Utterance.in_subset == True)  # noqa
                utterances.update({"alignment_log_likelihood": None})
                session.commit()
            update_mappings = []
            if self.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(self.align_arguments()):
                    function = AlignFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if isinstance(result, Exception):
                            error_dict[getattr(result, "job_name", 0)] = result
                            continue
                        if stopped.stop_check():
                            continue
                    except Empty:
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    if not training:
                        utterance, log_likelihood = result
                        update_mappings.append({
                            "id":
                            utterance,
                            "alignment_log_likelihood":
                            log_likelihood
                        })
                    pbar.update(1)
                for p in procs:
                    p.join()

                if not training and len(update_mappings) == 0:
                    raise NoAlignmentsError(self.num_current_utterances,
                                            self.beam, self.retry_beam)
                if error_dict:
                    for v in error_dict.values():
                        raise v

            else:
                self.log_debug("Not using multiprocessing...")
                for args in self.align_arguments():
                    function = AlignFunction(args)
                    for utterance, log_likelihood in function.run():
                        if not training:
                            update_mappings.append({
                                "id":
                                utterance,
                                "alignment_log_likelihood":
                                log_likelihood
                            })
                        pbar.update(1)
                if not training and len(update_mappings) == 0:
                    raise NoAlignmentsError(self.num_current_utterances,
                                            self.beam, self.retry_beam)
            if not training:
                session.bulk_update_mappings(Utterance, update_mappings)
                session.query(Utterance).filter(
                    Utterance.alignment_log_likelihood != None  # noqa
                ).update(
                    {
                        Utterance.alignment_log_likelihood:
                        Utterance.alignment_log_likelihood /
                        Utterance.num_frames
                    },
                    synchronize_session="fetch",
                )
                session.commit()
            self.log_debug(f"Alignment round took {time.time() - begin}")
    def _alignments(self) -> None:
        """Trains the aligner and constructs the alignments FAR."""
        if not os.path.exists(self.align_path):
            self.log_info("Training aligner")
            train_opts = []
            if self.batch_size:
                train_opts.append(f"--batch_size={self.batch_size}")
            if self.delta:
                train_opts.append(f"--delta={self.delta}")
            if self.fst_default_cache_gc:
                train_opts.append(
                    f"--fst_default_cache_gc={self.fst_default_cache_gc}")
            if self.fst_default_cache_gc_limit:
                train_opts.append(
                    f"--fst_default_cache_gc_limit={self.fst_default_cache_gc_limit}"
                )
            if self.alpha:
                train_opts.append(f"--alpha={self.alpha}")
            if self.num_iterations:
                train_opts.append(f"--max_iters={self.num_iterations}")
            # Constructs the actual command vectors (plus an index for logging
            # purposes).
            random.seed(self.seed)
            starts = [(RandomStart(
                idx,
                seed,
                self.input_far_path,
                self.output_far_path,
                self.cg_path,
                self.working_directory,
                train_opts,
            )) for (idx, seed) in enumerate(
                random.sample(range(1, RAND_MAX), self.random_starts), 1)]
            stopped = Stopped()
            num_commands = len(starts)
            job_queue = mp.JoinableQueue()
            fst_likelihoods = {}
            # Actually runs starts.
            self.log_info("Calculating alignments...")
            begin = time.time()
            with tqdm.tqdm(total=num_commands * self.num_iterations,
                           disable=getattr(self, "quiet", False)) as pbar:
                for start in starts:
                    job_queue.put(start)
                error_dict = {}
                return_queue = mp.Queue()
                procs = []
                for i in range(self.num_jobs):
                    log_path = os.path.join(self.working_log_directory,
                                            f"baumwelch.{i}.log")
                    p = RandomStartWorker(
                        i,
                        job_queue,
                        return_queue,
                        log_path,
                        stopped,
                    )
                    procs.append(p)
                    p.start()

                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if isinstance(result, Exception):
                            error_dict[getattr(result, "job_name", 0)] = result
                            continue
                        if stopped.stop_check():
                            continue
                    except queue.Empty:
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    if isinstance(result, int):
                        pbar.update(result)
                    else:
                        fst_likelihoods[result[0]] = result[1]
                for p in procs:
                    p.join()
            if error_dict:
                raise PyniniAlignmentError(error_dict)
            (best_fst, best_likelihood) = min(fst_likelihoods.items(),
                                              key=operator.itemgetter(1))
            self.log_info(f"Best likelihood: {best_likelihood}")
            self.log_debug(
                f"Ran {self.random_starts} random starts in {time.time() - begin} seconds"
            )
            # Moves best likelihood solution to the requested location.
            shutil.move(best_fst, self.align_path)
        cmd = [thirdparty_binary("baumwelchdecode")]
        if self.fst_default_cache_gc:
            cmd.append(f"--fst_default_cache_gc={self.fst_default_cache_gc}")
        if self.fst_default_cache_gc_limit:
            cmd.append(
                f"--fst_default_cache_gc_limit={self.fst_default_cache_gc_limit}"
            )
        cmd.append(self.input_far_path)
        cmd.append(self.output_far_path)
        cmd.append(self.align_path)
        cmd.append(self.afst_path)
        self.log_debug(f"Subprocess call: {cmd}")
        subprocess.check_call(cmd, env=os.environ)
        self.log_info("Completed computing alignments!")
    def acc_stats(self, alignment: bool = False) -> None:
        """
        Accumulate stats for the mapped model

        Parameters
        ----------
        alignment: bool
            Flag for whether to accumulate stats for the mapped alignment model
        """
        arguments = self.map_acc_stats_arguments(alignment)
        if alignment:
            initial_mdl_path = os.path.join(self.working_directory,
                                            "unadapted.alimdl")
            final_mdl_path = os.path.join(self.working_directory,
                                          "final.alimdl")
        else:
            initial_mdl_path = os.path.join(self.working_directory,
                                            "unadapted.mdl")
            final_mdl_path = os.path.join(self.working_directory, "final.mdl")
        if not os.path.exists(initial_mdl_path):
            return
        self.log_info("Accumulating statistics...")
        with tqdm.tqdm(total=self.num_current_utterances,
                       disable=getattr(self, "quiet", False)) as pbar:
            if self.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(arguments):
                    function = AccStatsFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if isinstance(result, Exception):
                            error_dict[getattr(result, "job_name", 0)] = result
                            continue
                        if stopped.stop_check():
                            continue
                    except Empty:
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    num_utterances, errors = result
                    pbar.update(num_utterances + errors)
                for p in procs:
                    p.join()
                if error_dict:
                    for v in error_dict.values():
                        raise v
            else:
                for args in arguments:
                    function = AccStatsFunction(args)
                    for num_utterances, errors in function.run():
                        pbar.update(num_utterances + errors)
        log_path = os.path.join(self.working_log_directory,
                                "map_model_est.log")
        occs_path = os.path.join(self.working_directory, "final.occs")
        with open(log_path, "w", encoding="utf8") as log_file:
            acc_files = []
            for j in arguments:
                acc_files.extend(j.acc_paths.values())
            sum_proc = subprocess.Popen(
                [thirdparty_binary("gmm-sum-accs"), "-"] + acc_files,
                stderr=log_file,
                stdout=subprocess.PIPE,
                env=os.environ,
            )
            ismooth_proc = subprocess.Popen(
                [
                    thirdparty_binary("gmm-ismooth-stats"),
                    "--smooth-from-model",
                    f"--tau={self.mapping_tau}",
                    initial_mdl_path,
                    "-",
                    "-",
                ],
                stderr=log_file,
                stdin=sum_proc.stdout,
                stdout=subprocess.PIPE,
                env=os.environ,
            )
            est_proc = subprocess.Popen(
                [
                    thirdparty_binary("gmm-est"),
                    "--update-flags=m",
                    f"--write-occs={occs_path}",
                    "--remove-low-count-gaussians=false",
                    initial_mdl_path,
                    "-",
                    final_mdl_path,
                ],
                stdin=ismooth_proc.stdout,
                stderr=log_file,
                env=os.environ,
            )
            est_proc.communicate()
    def calc_lda_mllt(self) -> None:
        """
        Multiprocessing function that calculates LDA+MLLT transformations.

        See Also
        --------
        :func:`~montreal_forced_aligner.acoustic_modeling.lda.CalcLdaMlltFunction`
            Multiprocessing helper function for each job
        :meth:`.LdaTrainer.calc_lda_mllt_arguments`
            Job method for generating arguments for the helper function
        :kaldi_src:`est-mllt`
            Relevant Kaldi binary
        :kaldi_src:`gmm-transform-means`
            Relevant Kaldi binary
        :kaldi_src:`compose-transforms`
            Relevant Kaldi binary
        :kaldi_steps:`train_lda_mllt`
            Reference Kaldi script

        """
        self.log_info("Re-calculating LDA...")
        arguments = self.calc_lda_mllt_arguments()
        with tqdm.tqdm(
            total=self.num_current_utterances, disable=getattr(self, "quiet", False)
        ) as pbar:
            if self.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(arguments):
                    function = CalcLdaMlltFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if isinstance(result, Exception):
                            error_dict[getattr(result, "job_name", 0)] = result
                            continue
                        if stopped.stop_check():
                            continue
                    except Empty:
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    pbar.update(1)
                for p in procs:
                    p.join()
                if error_dict:
                    for v in error_dict.values():
                        raise v
            else:
                for args in arguments:
                    function = CalcLdaMlltFunction(args)
                    for _ in function.run():
                        pbar.update(1)

        log_path = os.path.join(
            self.working_log_directory, f"transform_means.{self.iteration}.log"
        )
        previous_mat_path = os.path.join(self.working_directory, "lda.mat")
        new_mat_path = os.path.join(self.working_directory, "lda_new.mat")
        composed_path = os.path.join(self.working_directory, "lda_composed.mat")
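        # The calls below estimate a new MLLT/STC matrix from the accumulated
        # stats (est-mllt), rotate the model means into that feature space
        # (gmm-transform-means), and then fold the new matrix into any existing
        # LDA matrix (compose-transforms) so a single lda.mat is kept on disk.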
        with open(log_path, "a", encoding="utf8") as log_file:
            macc_list = []
            for x in arguments:
                macc_list.extend(x.macc_paths.values())
            subprocess.call(
                [thirdparty_binary("est-mllt"), new_mat_path] + macc_list,
                stderr=log_file,
                env=os.environ,
            )
            subprocess.call(
                [
                    thirdparty_binary("gmm-transform-means"),
                    new_mat_path,
                    self.model_path,
                    self.model_path,
                ],
                stderr=log_file,
                env=os.environ,
            )

            if os.path.exists(previous_mat_path):
                subprocess.call(
                    [
                        thirdparty_binary("compose-transforms"),
                        new_mat_path,
                        previous_mat_path,
                        composed_path,
                    ],
                    stderr=log_file,
                    env=os.environ,
                )
                os.remove(previous_mat_path)
                os.rename(composed_path, previous_mat_path)
            else:
                os.rename(new_mat_path, previous_mat_path)
Example #23
    def mfcc(self) -> None:
        """
        Multiprocessing function that converts sound files into MFCCs.

        See :kaldi_docs:`feat` for an overview on feature generation in Kaldi.

        See Also
        --------
        :class:`~montreal_forced_aligner.corpus.features.MfccFunction`
            Multiprocessing helper function for each job
        :meth:`.AcousticCorpusMixin.mfcc_arguments`
            Job method for generating arguments for helper function
        :kaldi_steps:`make_mfcc`
            Reference Kaldi script
        """
        self.log_info("Generating MFCCs...")
        log_directory = os.path.join(self.split_directory, "log")
        os.makedirs(log_directory, exist_ok=True)
        arguments = self.mfcc_arguments()
        with tqdm.tqdm(total=self.num_utterances,
                       disable=getattr(self, "quiet", False)) as pbar:
            if self.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(arguments):
                    function = MfccFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if isinstance(result, Exception):
                            error_dict[getattr(result, "job_name", 0)] = result
                            continue
                        if stopped.stop_check():
                            continue
                    except Empty:
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    pbar.update(result)
                for p in procs:
                    p.join()
                if error_dict:
                    for v in error_dict.values():
                        print(v)
                    self.dirty = True
                    sys.exit(1)
            else:
                for args in arguments:
                    function = MfccFunction(args)
                    for num_utterances in function.run():
                        pbar.update(num_utterances)
        with self.session() as session:
            update_mapping = []
            session.query(Utterance).update({"ignored": True})
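            # Each feats.scp line is "<kaldi_utterance_id> <ark path:offset>"; the
            # database utterance id is the integer suffix after the final "-", so
            # every utterance that received a feature row is un-ignored below.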
            for j in arguments:
                with open(j.feats_scp_path, "r", encoding="utf8") as f:
                    for line in f:
                        line = line.strip()
                        if line == "":
                            continue
                        # parse without rebinding the open file handle name "f"
                        parts = line.split(maxsplit=1)
                        utt_id = int(parts[0].split("-")[-1])
                        feats = parts[1]
                        update_mapping.append({
                            "id": utt_id,
                            "features": feats,
                            "ignored": False
                        })
            session.bulk_update_mappings(Utterance, update_mapping)
            session.commit()
Example #24
    def generate_pronunciations(self) -> Dict[str, List[str]]:
        """
        Generate pronunciations

        Returns
        -------
        dict[str, list[str]]
            Mapping of words to lists of their generated pronunciations
        """

        fst = pynini.Fst.read(self.g2p_model.fst_path)
        if self.g2p_model.meta["architecture"] == "phonetisaurus":
            output_token_type = pynini.SymbolTable.read_text(
                self.g2p_model.sym_path)
            input_token_type = pynini.SymbolTable.read_text(
                self.g2p_model.grapheme_sym_path)
            fst.set_input_symbols(input_token_type)
            fst.set_output_symbols(output_token_type)
            rewriter = PhonetisaurusRewriter(
                fst,
                input_token_type,
                output_token_type,
                num_pronunciations=self.num_pronunciations,
                threshold=self.g2p_threshold,
                grapheme_order=self.g2p_model.meta["grapheme_order"],
            )
        else:
            output_token_type = "utf8"
            input_token_type = "utf8"
            if self.g2p_model.sym_path is not None and os.path.exists(
                    self.g2p_model.sym_path):
                output_token_type = pynini.SymbolTable.read_text(
                    self.g2p_model.sym_path)
            rewriter = Rewriter(
                fst,
                input_token_type,
                output_token_type,
                num_pronunciations=self.num_pronunciations,
                threshold=self.g2p_threshold,
            )
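        # Phonetisaurus-trained models carry separate grapheme and phone symbol
        # tables, while the default pynini models rewrite utf8 byte strings and
        # only use a phone symbol table when one was saved with the model.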

        num_words = len(self.words_to_g2p)
        begin = time.time()
        missing_graphemes = set()
        self.log_info("Generating pronunciations...")
        to_return = {}
        skipped_words = 0
        if num_words < 30 or self.num_jobs == 1:
            with tqdm.tqdm(total=num_words,
                           disable=getattr(self, "quiet", False)) as pbar:
                for word in self.words_to_g2p:
                    w, m = clean_up_word(word,
                                         self.g2p_model.meta["graphemes"])
                    pbar.update(1)
                    missing_graphemes = missing_graphemes | m
                    if self.strict_graphemes and m:
                        skipped_words += 1
                        continue
                    if not w:
                        skipped_words += 1
                        continue
                    try:
                        prons = rewriter(w)
                    except rewrite.Error:
                        continue
                    to_return[word] = prons
                self.log_debug(
                    f"Skipping {skipped_words} words for containing the following graphemes: "
                    f"{comma_join(sorted(missing_graphemes))}")
        else:
            stopped = Stopped()
            job_queue = mp.Queue()
            for word in self.words_to_g2p:
                w, m = clean_up_word(word, self.g2p_model.meta["graphemes"])
                missing_graphemes = missing_graphemes | m
                if self.strict_graphemes and m:
                    skipped_words += 1
                    continue
                if not w:
                    skipped_words += 1
                    continue
                job_queue.put(w)
            self.log_debug(
                f"Skipping {skipped_words} words for containing the following graphemes: "
                f"{comma_join(sorted(missing_graphemes))}")
            error_dict = {}
            return_queue = mp.Queue()
            procs = []
            for _ in range(self.num_jobs):
                p = RewriterWorker(
                    job_queue,
                    return_queue,
                    rewriter,
                    stopped,
                )
                procs.append(p)
                p.start()
            num_words -= skipped_words
            with tqdm.tqdm(total=num_words,
                           disable=getattr(self, "quiet", False)) as pbar:
                while True:
                    try:
                        word, result = return_queue.get(timeout=1)
                        if stopped.stop_check():
                            continue
                    except queue.Empty:
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    pbar.update(1)
                    if isinstance(result, Exception):
                        error_dict[word] = result
                        continue
                    to_return[word] = result

            for p in procs:
                p.join()
            if error_dict:
                raise PyniniGenerationError(error_dict)
        self.log_debug(
            f"Processed {num_words} words in {time.time() - begin} seconds")
        return to_return
Example #25
class AcousticCorpusMixin(CorpusMixin, FeatureConfigMixin, metaclass=ABCMeta):
    """
    Mixin class for acoustic corpora

    Parameters
    ----------
    audio_directory: str
        Extra directory to look for audio files

    See Also
    --------
    :class:`~montreal_forced_aligner.corpus.base.CorpusMixin`
        For corpus parsing parameters
    :class:`~montreal_forced_aligner.corpus.features.FeatureConfigMixin`
        For feature generation parameters

    Attributes
    ----------
    sound_file_errors: list[str]
        List of sound files with errors in loading
    transcriptions_without_wavs: list[str]
        List of text files without sound files
    no_transcription_files: list[str]
        List of sound files without transcription files
    stopped: Stopped
        Stop check for loading the corpus
    """
    def __init__(self, audio_directory: Optional[str] = None, **kwargs):
        super().__init__(**kwargs)
        self.audio_directory = audio_directory
        self.sound_file_errors = []
        self.transcriptions_without_wavs = []
        self.no_transcription_files = []
        self.stopped = Stopped()
        self.features_generated = False
        self.alignment_done = False
        self.transcription_done = False
        self.has_reference_alignments = False
        self.alignment_evaluation_done = False

    def inspect_database(self) -> None:
        """Check if a database file exists and create the necessary metadata"""
        exist_check = os.path.exists(self.db_path)
        if not exist_check:
            self.initialize_database()
        with self.session() as session:
            corpus = session.query(Corpus).first()
            if corpus:
                self.imported = corpus.imported
                self.features_generated = corpus.features_generated
                self.alignment_done = corpus.alignment_done
                self.transcription_done = corpus.transcription_done
                self.has_reference_alignments = corpus.has_reference_alignments
                self.alignment_evaluation_done = corpus.alignment_evaluation_done
            else:
                session.add(Corpus(name=self.data_source_identifier))
                session.commit()

    def load_reference_alignments(self, reference_directory: str) -> None:
        """
        Load reference alignments to use in alignment evaluation from a directory

        Parameters
        ----------
        reference_directory: str
            Directory containing reference alignments

        """
        if self.has_reference_alignments:
            self.log_info("Reference alignments already loaded!")
            return
        self.log_info("Loading reference files...")
        indices = []
        jobs = []
        reference_intervals = []
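        # Reference TextGrids are matched to corpus files by file name; each
        # parsed interval is attached to the utterance on the same speaker tier
        # whose end time covers the interval.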
        with tqdm.tqdm(total=self.num_files,
                       disable=getattr(self, "quiet", False)) as pbar, Session(
                           self.db_engine, autoflush=False) as session:
            for root, _, files in os.walk(reference_directory,
                                          followlinks=True):
                root_speaker = os.path.basename(root)
                for f in files:
                    if f.endswith(".TextGrid"):
                        file_name = f.replace(".TextGrid", "")
                        file_id = session.query(
                            File.id).filter_by(name=file_name).scalar()
                        if not file_id:
                            continue
                        if self.use_mp:
                            indices.append(file_id)
                            jobs.append((os.path.join(root, f), root_speaker))
                        else:
                            intervals = parse_aligned_textgrid(
                                os.path.join(root, f), root_speaker)
                            utterances = (session.query(
                                Utterance.id, Speaker.name,
                                Utterance.end).join(Utterance.speaker).join(
                                    Utterance.file).filter(File.id == file_id))
                            for u_id, speaker_name, end in utterances:
                                if speaker_name not in intervals:
                                    continue
                                while (intervals[speaker_name] and
                                       intervals[speaker_name][0].end <= end):
                                    interval = intervals[speaker_name].pop(0)
                                    reference_intervals.append({
                                        "begin": interval.begin,
                                        "end": interval.end,
                                        "label": interval.label,
                                        "utterance_id": u_id,
                                    })

                            pbar.update(1)
            if self.use_mp:
                with mp.Pool(self.num_jobs) as pool:
                    gen = pool.starmap(parse_aligned_textgrid, jobs)
                    for i, intervals in enumerate(gen):
                        pbar.update(1)
                        file_id = indices[i]
                        utterances = (session.query(
                            Utterance.id, Speaker.name,
                            Utterance.end).join(Utterance.speaker).filter(
                                Utterance.file_id == file_id))
                        for u_id, speaker_name, end in utterances:
                            if speaker_name not in intervals:
                                continue
                            while (intervals[speaker_name]
                                   and intervals[speaker_name][0].end <= end):
                                interval = intervals[speaker_name].pop(0)
                                reference_intervals.append({
                                    "begin": interval.begin,
                                    "end": interval.end,
                                    "label": interval.label,
                                    "utterance_id": u_id,
                                })
            with session.bind.begin() as conn:
                conn.execute(
                    sqlalchemy.insert(ReferencePhoneInterval.__table__),
                    reference_intervals)
                session.commit()
            session.query(Corpus).update({"has_reference_alignments": True})
            session.commit()

    def load_corpus(self) -> None:
        """
        Load the corpus
        """
        self.initialize_database()
        self._load_corpus()

        self.initialize_jobs()
        self.create_corpus_split()
        self.generate_features()

    def generate_features(self, compute_cmvn: bool = True) -> None:
        """
        Generate features for the corpus

        Parameters
        ----------
        compute_cmvn: bool
            Flag for whether to compute CMVN, defaults to True
        """
        if self.features_generated:
            return
        self.log_info(f"Generating base features ({self.feature_type})...")
        if self.feature_type == "mfcc":
            self.mfcc()
        self.combine_feats()
        if compute_cmvn:
            self.log_info("Calculating CMVN...")
            self.calc_cmvn()
        self.features_generated = True
        with self.session() as session:
            session.query(Corpus).update({"features_generated": True})
            session.commit()
        self.create_corpus_split()

    def create_corpus_split(self) -> None:
        """Create the split directory for the corpus"""
        if self.features_generated:
            self.log_info("Creating corpus split with features...")
            super().create_corpus_split()
        else:
            self.log_info("Creating corpus split for feature generation...")
            split_dir = self.split_directory
            os.makedirs(os.path.join(split_dir, "log"), exist_ok=True)
            with self.session() as session:
                for job in self.jobs:
                    job.output_for_features(split_dir, session)

    def construct_base_feature_string(self, all_feats: bool = False) -> str:
        """
        Construct the base feature string independent of job name

        Used in initialization of MonophoneTrainer (to get dimension size) and IvectorTrainer (uses all feats)

        Parameters
        ----------
        all_feats: bool
            Flag for whether all features across all jobs should be taken into account

        Returns
        -------
        str
            Base feature string
        """
        j = self.jobs[0]
        if all_feats:
            feat_path = os.path.join(self.base_data_directory, "feats.scp")
            utt2spk_path = os.path.join(self.base_data_directory,
                                        "utt2spk.scp")
            cmvn_path = os.path.join(self.base_data_directory, "cmvn.scp")
            feats = f'ark,s,cs:apply-cmvn --utt2spk=ark:"{utt2spk_path}" scp:"{cmvn_path}" scp:"{feat_path}" ark:- |'
            feats += " add-deltas ark:- ark:- |"
            return feats
        utt2spks = j.construct_path_dictionary(self.data_directory, "utt2spk",
                                               "scp")
        cmvns = j.construct_path_dictionary(self.data_directory, "cmvn", "scp")
        features = j.construct_path_dictionary(self.data_directory, "feats",
                                               "scp")
        for dict_id in j.dictionary_ids:
            feat_path = features[dict_id]
            cmvn_path = cmvns[dict_id]
            utt2spk_path = utt2spks[dict_id]
            feats = f'ark,s,cs:apply-cmvn --utt2spk=ark:"{utt2spk_path}" scp:"{cmvn_path}" scp:"{feat_path}" ark:- |'
            if self.uses_deltas:
                feats += " add-deltas ark:- ark:- |"

            return feats
        else:
            utt2spk_path = j.construct_path(self.data_directory, "utt2spk",
                                            "scp")
            cmvn_path = j.construct_path(self.data_directory, "cmvn", "scp")
            feat_path = j.construct_path(self.data_directory, "feats", "scp")
            feats = f'ark,s,cs:apply-cmvn --utt2spk=ark:"{utt2spk_path}" scp:"{cmvn_path}" scp:"{feat_path}" ark:- |'
            if self.uses_deltas:
                feats += " add-deltas ark:- ark:- |"
            return feats

    def construct_feature_proc_strings(
        self,
        speaker_independent: bool = False,
    ) -> typing.Union[List[Dict[str, str]], List[str]]:
        """
        Constructs a feature processing string to supply to Kaldi binaries, taking into account corpus features and the
        current working directory of the aligner (whether fMLLR or LDA transforms should be used, etc).

        Parameters
        ----------
        speaker_independent: bool
            Flag for whether features should be speaker-independent regardless of the presence of fMLLR transforms

        Returns
        -------
        list[dict[str, str]]
            Feature strings per job
        """
        strings = []
        for j in self.jobs:
            lda_mat_path = None
            fmllrs = {}
            if self.working_directory is not None:
                lda_mat_path = os.path.join(self.working_directory, "lda.mat")
                if not os.path.exists(lda_mat_path):
                    lda_mat_path = None

                fmllrs = j.construct_path_dictionary(self.working_directory,
                                                     "trans", "ark")
            utt2spks = j.construct_path_dictionary(self.data_directory,
                                                   "utt2spk", "scp")
            cmvns = j.construct_path_dictionary(self.data_directory, "cmvn",
                                                "scp")
            features = j.construct_path_dictionary(self.data_directory,
                                                   "feats", "scp")
            vads = j.construct_path_dictionary(self.data_directory, "vad",
                                               "scp")
            feat_strings = {}
            if not j.dictionary_ids:
                utt2spk_path = j.construct_path(self.data_directory, "utt2spk",
                                                "scp")
                cmvn_path = j.construct_path(self.data_directory, "cmvn",
                                             "scp")
                feat_path = j.construct_path(self.data_directory, "feats",
                                             "scp")
                feats = f'ark,s,cs:apply-cmvn --utt2spk=ark:"{utt2spk_path}" scp:"{cmvn_path}" scp:"{feat_path}" ark:- |'
                if self.uses_deltas:
                    feats += " add-deltas ark:- ark:- |"

                strings.append(feats)
                continue

            for dict_id in j.dictionary_ids:
                feat_path = features[dict_id]
                cmvn_path = cmvns[dict_id]
                utt2spk_path = utt2spks[dict_id]
                fmllr_trans_path = None
                try:
                    fmllr_trans_path = fmllrs[dict_id]
                    if not os.path.exists(fmllr_trans_path):
                        fmllr_trans_path = None
                except KeyError:
                    pass
                vad_path = vads[dict_id]
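                # Three cases follow: VAD-selected voiced frames (ivector-style
                # features), sliding-window CMVN when no per-speaker CMVN stats
                # exist yet, or the standard per-speaker CMVN path with optional
                # LDA/splice/deltas and an fMLLR transform when one is available.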
                if self.uses_voiced:
                    feats = f'ark,s,cs:add-deltas scp:"{feat_path}" ark:- |'
                    if self.uses_cmvn:
                        feats += " apply-cmvn-sliding --norm-vars=false --center=true --cmn-window=300 ark:- ark:- |"
                    feats += f' select-voiced-frames ark:- scp,s,cs:"{vad_path}" ark:- |'
                elif not os.path.exists(cmvn_path) and self.uses_cmvn:
                    feats = f'ark,s,cs:add-deltas scp:"{feat_path}" ark:- |'
                    if self.uses_cmvn:
                        feats += " apply-cmvn-sliding --norm-vars=false --center=true --cmn-window=300 ark:- ark:- |"
                else:
                    feats = f'ark,s,cs:apply-cmvn --utt2spk=ark:"{utt2spk_path}" scp:"{cmvn_path}" scp:"{feat_path}" ark:- |'
                    if lda_mat_path is not None:
                        feats += f" splice-feats --left-context={self.splice_left_context} --right-context={self.splice_right_context} ark:- ark:- |"
                        feats += f' transform-feats "{lda_mat_path}" ark:- ark:- |'
                    elif self.uses_splices:
                        feats += f" splice-feats --left-context={self.splice_left_context} --right-context={self.splice_right_context} ark:- ark:- |"
                    elif self.uses_deltas:
                        feats += " add-deltas ark:- ark:- |"
                    if fmllr_trans_path is not None and not (
                            self.speaker_independent or speaker_independent):
                        if not os.path.exists(fmllr_trans_path):
                            raise Exception(
                                f"Could not find {fmllr_trans_path}")
                        feats += f' transform-feats --utt2spk=ark:"{utt2spk_path}" ark:"{fmllr_trans_path}" ark:- ark:- |'
                feat_strings[dict_id] = feats
            strings.append(feat_strings)
        return strings

    def compute_vad_arguments(self) -> List[VadArguments]:
        """
        Generate Job arguments for :class:`~montreal_forced_aligner.corpus.features.ComputeVadFunction`

        Returns
        -------
        list[:class:`~montreal_forced_aligner.corpus.features.VadArguments`]
            Arguments for processing
        """
        return [
            VadArguments(
                j.name,
                getattr(self, "db_engine", ""),
                os.path.join(self.split_directory, "log",
                             f"compute_vad.{j.name}.log"),
                j.construct_path(self.split_directory, "feats", "scp"),
                j.construct_path(self.split_directory, "vad", "scp"),
                self.vad_options,
            ) for j in self.jobs if j.has_data
        ]

    def calc_fmllr_arguments(self,
                             iteration: Optional[int] = None
                             ) -> List[CalcFmllrArguments]:
        """
        Generate Job arguments for :class:`~montreal_forced_aligner.corpus.features.CalcFmllrFunction`

        Returns
        -------
        list[:class:`~montreal_forced_aligner.corpus.features.CalcFmllrArguments`]
            Arguments for processing
        """
        feature_strings = self.construct_feature_proc_strings()
        base_log = "calc_fmllr"
        if iteration is not None:
            base_log += f".{iteration}"
        return [
            CalcFmllrArguments(
                j.name,
                getattr(self, "db_path", ""),
                os.path.join(self.working_log_directory,
                             f"{base_log}.{j.name}.log"),
                j.dictionary_ids,
                feature_strings[j.name],
                j.construct_path_dictionary(self.working_directory, "ali",
                                            "ark"),
                self.alignment_model_path,
                self.model_path,
                j.construct_path_dictionary(self.data_directory, "spk2utt",
                                            "scp"),
                j.construct_path_dictionary(self.working_directory, "trans",
                                            "ark"),
                self.fmllr_options,
            ) for j in self.jobs if j.has_data
        ]

    def mfcc_arguments(self) -> List[MfccArguments]:
        """
        Generate Job arguments for :class:`~montreal_forced_aligner.corpus.features.MfccFunction`

        Returns
        -------
        list[:class:`~montreal_forced_aligner.corpus.features.MfccArguments`]
            Arguments for processing
        """
        return [
            MfccArguments(
                j.name,
                self.db_path,
                os.path.join(self.split_directory, "log",
                             f"make_mfcc.{j.name}.log"),
                j.construct_path(self.split_directory, "wav", "scp"),
                j.construct_path(self.split_directory, "segments", "scp"),
                j.construct_path(self.split_directory, "feats", "scp"),
                self.mfcc_options,
                self.pitch_options,
            ) for j in self.jobs if j.has_data
        ]

    def mfcc(self) -> None:
        """
        Multiprocessing function that converts sound files into MFCCs.

        See :kaldi_docs:`feat` for an overview on feature generation in Kaldi.

        See Also
        --------
        :class:`~montreal_forced_aligner.corpus.features.MfccFunction`
            Multiprocessing helper function for each job
        :meth:`.AcousticCorpusMixin.mfcc_arguments`
            Job method for generating arguments for helper function
        :kaldi_steps:`make_mfcc`
            Reference Kaldi script
        """
        self.log_info("Generating MFCCs...")
        log_directory = os.path.join(self.split_directory, "log")
        os.makedirs(log_directory, exist_ok=True)
        arguments = self.mfcc_arguments()
        with tqdm.tqdm(total=self.num_utterances,
                       disable=getattr(self, "quiet", False)) as pbar:
            if self.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(arguments):
                    function = MfccFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if isinstance(result, Exception):
                            error_dict[getattr(result, "job_name", 0)] = result
                            continue
                        if stopped.stop_check():
                            continue
                    except Empty:
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    pbar.update(result)
                for p in procs:
                    p.join()
                if error_dict:
                    for v in error_dict.values():
                        print(v)
                    self.dirty = True
                    sys.exit(1)
            else:
                for args in arguments:
                    function = MfccFunction(args)
                    for num_utterances in function.run():
                        pbar.update(num_utterances)
        with self.session() as session:
            update_mapping = []
            session.query(Utterance).update({"ignored": True})
            for j in arguments:
                with open(j.feats_scp_path, "r", encoding="utf8") as f:
                    for line in f:
                        line = line.strip()
                        if line == "":
                            continue
                        # parse without rebinding the open file handle name "f"
                        parts = line.split(maxsplit=1)
                        utt_id = int(parts[0].split("-")[-1])
                        feats = parts[1]
                        update_mapping.append({
                            "id": utt_id,
                            "features": feats,
                            "ignored": False
                        })
            session.bulk_update_mappings(Utterance, update_mapping)
            session.commit()

    def calc_cmvn(self) -> None:
        """
        Calculate CMVN statistics for speakers

        See Also
        --------
        :kaldi_src:`compute-cmvn-stats`
            Relevant Kaldi binary
        """
        self._write_feats()
        self._write_spk2utt()
        spk2utt = os.path.join(self.corpus_output_directory, "spk2utt.scp")
        feats = os.path.join(self.corpus_output_directory, "feats.scp")
        cmvn_ark = os.path.join(self.corpus_output_directory, "cmvn.ark")
        cmvn_scp = os.path.join(self.corpus_output_directory, "cmvn.scp")
        log_path = os.path.join(self.features_log_directory, "cmvn.log")
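        # compute-cmvn-stats writes per-speaker mean/variance stats keyed by the
        # speaker id from spk2utt; the resulting cmvn.scp entries are then stored
        # on the Speaker rows below.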
        with open(log_path, "w") as logf:
            subprocess.call(
                [
                    thirdparty_binary("compute-cmvn-stats"),
                    f"--spk2utt=ark:{spk2utt}",
                    f"scp:{feats}",
                    f"ark,scp:{cmvn_ark},{cmvn_scp}",
                ],
                stderr=logf,
                env=os.environ,
            )
        update_mapping = []
        with self.session() as session:
            for s, cmvn in load_scp(cmvn_scp).items():
                if isinstance(cmvn, list):
                    cmvn = " ".join(cmvn)
                update_mapping.append({"id": int(s), "cmvn": cmvn})
            session.bulk_update_mappings(Speaker, update_mapping)
            session.commit()

    def calc_fmllr(self, iteration: Optional[int] = None) -> None:
        """
        Multiprocessing function that computes speaker adaptation transforms via
        feature-space Maximum Likelihood Linear Regression (fMLLR).

        See Also
        --------
        :class:`~montreal_forced_aligner.corpus.features.CalcFmllrFunction`
            Multiprocessing helper function for each job
        :meth:`.AcousticCorpusMixin.calc_fmllr_arguments`
            Job method for generating arguments for the helper function
        :kaldi_steps:`align_fmllr`
            Reference Kaldi script
        :kaldi_steps:`train_sat`
            Reference Kaldi script
        """
        begin = time.time()
        self.log_info("Calculating fMLLR for speaker adaptation...")

        arguments = self.calc_fmllr_arguments(iteration=iteration)
        with tqdm.tqdm(total=self.num_speakers,
                       disable=getattr(self, "quiet", False)) as pbar:
            if self.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(arguments):
                    function = CalcFmllrFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if isinstance(result, Exception):
                            error_dict[getattr(result, "job_name", 0)] = result
                            continue
                        if stopped.stop_check():
                            continue
                    except Empty:
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    pbar.update(1)
                for p in procs:
                    p.join()
                if error_dict:
                    for v in error_dict.values():
                        raise v
            else:
                for args in arguments:
                    function = CalcFmllrFunction(args)
                    for _ in function.run():
                        pbar.update(1)

        self.speaker_independent = False
        self.log_debug(f"Fmllr calculation took {time.time() - begin}")

    def compute_vad(self) -> None:
        """
        Compute Voice Activity Detection features over the corpus

        See Also
        --------
        :class:`~montreal_forced_aligner.corpus.features.ComputeVadFunction`
            Multiprocessing helper function for each job
        :meth:`.AcousticCorpusMixin.compute_vad_arguments`
            Job method for generating arguments for helper function
        """
        if os.path.exists(os.path.join(self.split_directory, "vad.0.scp")):
            self.log_info("VAD already computed, skipping!")
            return
        begin = time.time()
        self.log_info("Computing VAD...")

        arguments = self.compute_vad_arguments()
        with tqdm.tqdm(total=self.num_speakers,
                       disable=getattr(self, "quiet", False)) as pbar:
            if self.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(arguments):
                    function = ComputeVadFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if stopped.stop_check():
                            continue
                    except Empty:
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    if isinstance(result, KaldiProcessingError):
                        error_dict[result.job_name] = result
                        continue
                    done, no_feats, unvoiced = result
                    pbar.update(done + no_feats + unvoiced)
                for p in procs:
                    p.join()
                if error_dict:
                    for v in error_dict.values():
                        raise v
            else:
                for args in arguments:
                    function = ComputeVadFunction(args)
                    for done, no_feats, unvoiced in function.run():
                        pbar.update(done + no_feats + unvoiced)
        self.log_debug(f"VAD computation took {time.time() - begin}")

    def combine_feats(self) -> None:
        """
        Combine feature generation results and store relevant information
        """

        with self.session() as session:
            ignored_utterances = (
                session.query(
                    SoundFile.sound_file_path,
                    Speaker.name,
                    Utterance.begin,
                    Utterance.end,
                    Utterance.text,
                ).join(Utterance.speaker).join(Utterance.file).join(
                    File.sound_file).filter(Utterance.ignored == True)  # noqa
            )
            ignored_count = 0
            for sound_file_path, speaker_name, begin, end, text in ignored_utterances:
                self.log_debug(f"  - Ignored File: {sound_file_path}")
                self.log_debug(f"    - Speaker: {speaker_name}")
                self.log_debug(f"    - Begin: {begin}")
                self.log_debug(f"    - End: {end}")
                self.log_debug(f"    - Text: {text}")
                ignored_count += 1
            if ignored_count:
                self.log_warning(
                    f"There were {ignored_count} utterances ignored due to an issue in feature generation, see the log file for full "
                    "details or run `mfa validate` on the corpus.")

    def _write_feats(self) -> None:
        """Write feats scp file for Kaldi"""
        feats_path = os.path.join(self.corpus_output_directory, "feats.scp")
        with self.session() as session, open(feats_path, "w",
                                             encoding="utf8") as f:
            utterances = (session.query(
                Utterance.kaldi_id,
                Utterance.features).filter_by(ignored=False).order_by(
                    Utterance.kaldi_id))
            for u_id, features in utterances:
                f.write(f"{u_id} {features}\n")

    def get_feat_dim(self) -> int:
        """
        Calculate the feature dimension for the corpus

        Returns
        -------
        int
            Dimension of feature vectors
        """
        feature_string = self.construct_base_feature_string()
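        # Pipe a single utterance's processed features (subset-feats --n=1) into
        # feat-to-dim and read the final feature dimension off stdout.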
        with open(os.path.join(self.features_log_directory, "feat-to-dim.log"),
                  "w") as log_file:
            subset_proc = subprocess.Popen(
                [
                    thirdparty_binary("subset-feats"),
                    "--n=1",
                    feature_string,
                    "ark:-",
                ],
                stderr=log_file,
                stdout=subprocess.PIPE,
            )
            dim_proc = subprocess.Popen(
                [thirdparty_binary("feat-to-dim"), "ark:-", "-"],
                stdin=subset_proc.stdout,
                stdout=subprocess.PIPE,
                stderr=log_file,
            )
            stdout, stderr = dim_proc.communicate()
            feats = stdout.decode("utf8").strip()
        return int(feats)

    def _load_corpus_from_source_mp(self) -> None:
        """
        Load a corpus using multiprocessing
        """
        begin_time = time.process_time()
        job_queue = mp.Queue()
        return_queue = mp.Queue()
        finished_adding = Stopped()
        stopped = Stopped()
        file_counts = Counter()
        sanitize_function = getattr(self, "sanitize_function", None)
        error_dict = {}
        procs = []
        self.db_engine.dispose()
        parser = AcousticDirectoryParser(
            self.corpus_directory,
            job_queue,
            self.audio_directory,
            stopped,
            finished_adding,
            file_counts,
        )
        parser.start()
        for i in range(self.num_jobs):
            p = CorpusProcessWorker(
                i,
                job_queue,
                return_queue,
                stopped,
                finished_adding,
                self.speaker_characters,
                sanitize_function,
                self.sample_frequency,
            )
            procs.append(p)
            p.start()
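        # Producer/consumer layout: the AcousticDirectoryParser process walks the
        # corpus and fills job_queue, while each CorpusProcessWorker parses files
        # from it and pushes results (or errors as tuples) onto return_queue.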
        last_poll = time.time() - 30
        import_data = DatabaseImportData()
        try:
            with self.session() as session:
                with tqdm.tqdm(total=100,
                               disable=getattr(self, "quiet", False)) as pbar:
                    while True:
                        try:
                            file = return_queue.get(timeout=1)
                            if isinstance(file, tuple):
                                error_type = file[0]
                                error = file[1]
                                if error_type == "error":
                                    error_dict[error_type] = error
                                else:
                                    if error_type not in error_dict:
                                        error_dict[error_type] = []
                                    error_dict[error_type].append(error)
                                continue
                            if self.stopped.stop_check():
                                continue
                        except Empty:
                            for proc in procs:
                                if not proc.finished_processing.stop_check():
                                    break
                            else:
                                break
                            continue
                        if time.time() - last_poll > 15:
                            pbar.total = file_counts.value()
                            last_poll = time.time()
                        pbar.update(1)
                        import_data.add_objects(
                            self.generate_import_objects(file))

                    self.log_debug(
                        f"Processing queue: {time.process_time() - begin_time}"
                    )

                    if "error" in error_dict:
                        session.rollback()
                        raise error_dict["error"][1]
                    self._finalize_load(session, import_data)
            for k in [
                    "sound_file_errors", "decode_error_files",
                    "textgrid_read_errors"
            ]:
                if hasattr(self, k):
                    if k in error_dict:
                        self.log_info(
                            "There were some issues with files in the corpus. "
                            "Please look at the log file or run the validator for more information."
                        )
                        self.log_debug(
                            f"{k} showed {len(error_dict[k])} errors:")
                        if k in {"textgrid_read_errors", "sound_file_errors"}:
                            getattr(self, k).update(error_dict[k])
                            for e in error_dict[k]:
                                self.log_debug(f"{e.file_name}: {e.error}")
                        else:
                            self.log_debug(", ".join(error_dict[k]))
                            setattr(self, k, error_dict[k])

        except Exception as e:
            if isinstance(e, KeyboardInterrupt):
                self.log_info(
                    "Detected ctrl-c, please wait a moment while we clean everything up..."
                )
                self.stopped.set_sigint_source()
            self.stopped.stop()
            finished_adding.stop()
            while True:
                try:
                    _ = job_queue.get(timeout=1)
                    if self.stopped.stop_check():
                        continue
                except Empty:
                    for proc in procs:
                        if not proc.finished_processing.stop_check():
                            break
                    else:
                        break
                try:
                    _ = return_queue.get(timeout=1)
                    _ = job_queue.get(timeout=1)
                    if self.stopped.stop_check():
                        continue
                except Empty:
                    for proc in procs:
                        if not proc.finished_processing.stop_check():
                            break
                    else:
                        break
        finally:
            parser.join()
            for p in procs:
                p.join()
            if self.stopped.stop_check():
                self.log_info(
                    f"Stopped parsing early ({time.process_time() - begin_time} seconds)"
                )
                if self.stopped.source():
                    sys.exit(0)
            else:
                self.log_debug(
                    f"Parsed corpus directory with {self.num_jobs} jobs in {time.process_time() - begin_time} seconds"
                )

    def _load_corpus_from_source(self) -> None:
        """
        Load a corpus without using multiprocessing
        """
        begin_time = time.time()
        sanitize_function = None
        if hasattr(self, "sanitize_function"):
            sanitize_function = self.sanitize_function

        all_sound_files = {}
        use_audio_directory = False
        if self.audio_directory and os.path.exists(self.audio_directory):
            use_audio_directory = True
            for root, _, files in os.walk(self.audio_directory,
                                          followlinks=True):
                if self.stopped.stop_check():
                    return
                exts = find_exts(files)
                exts.wav_files = {
                    k: os.path.join(root, v)
                    for k, v in exts.wav_files.items()
                }
                exts.other_audio_files = {
                    k: os.path.join(root, v)
                    for k, v in exts.other_audio_files.items()
                }
                all_sound_files.update(exts.other_audio_files)
                all_sound_files.update(exts.wav_files)
        self.log_debug(f"Walking through {self.corpus_directory}...")
        import_data = DatabaseImportData()
        with self.session() as session:
            for root, _, files in os.walk(self.corpus_directory,
                                          followlinks=True):
                exts = find_exts(files)
                relative_path = root.replace(self.corpus_directory,
                                             "").lstrip("/").lstrip("\\")
                if self.stopped.stop_check():
                    return
                if not use_audio_directory:
                    all_sound_files = {}
                    wav_files = {
                        k: os.path.join(root, v)
                        for k, v in exts.wav_files.items()
                    }
                    other_audio_files = {
                        k: os.path.join(root, v)
                        for k, v in exts.other_audio_files.items()
                    }
                    all_sound_files.update(other_audio_files)
                    all_sound_files.update(wav_files)
                for file_name in exts.identifiers:

                    wav_path = None
                    transcription_path = None
                    if file_name in all_sound_files:
                        wav_path = all_sound_files[file_name]
                    if file_name in exts.lab_files:
                        lab_name = exts.lab_files[file_name]
                        transcription_path = os.path.join(root, lab_name)
                    elif file_name in exts.textgrid_files:
                        tg_name = exts.textgrid_files[file_name]
                        transcription_path = os.path.join(root, tg_name)
                    if wav_path is None and transcription_path is None:  # Not a file for MFA
                        continue
                    if wav_path is None:
                        self.transcriptions_without_wavs.append(
                            transcription_path)
                        continue
                    if transcription_path is None:
                        self.no_transcription_files.append(wav_path)
                    try:
                        file = FileData.parse_file(
                            file_name,
                            wav_path,
                            transcription_path,
                            relative_path,
                            self.speaker_characters,
                            sanitize_function,
                            self.sample_frequency,
                        )
                        import_data.add_objects(
                            self.generate_import_objects(file))
                    except TextParseError as e:
                        self.decode_error_files.append(e)
                    except TextGridParseError as e:
                        self.textgrid_read_errors.append(e)
                    except SoundFileError as e:
                        self.sound_file_errors.append(e)
            self._finalize_load(session, import_data)
        if self.decode_error_files or self.textgrid_read_errors:
            self.log_info(
                "There were some issues with files in the corpus. "
                "Please look at the log file or run the validator for more information."
            )
            if self.decode_error_files:
                self.log_debug(
                    f"There were {len(self.decode_error_files)} errors decoding text files:"
                )
                self.log_debug(", ".join(self.decode_error_files))
            if self.textgrid_read_errors:
                self.log_debug(
                    f"There were {len(self.textgrid_read_errors)} errors decoding reading TextGrid files:"
                )
                for e in self.textgrid_read_errors:
                    self.log_debug(f"{e.file_name}: {e.error}")

        self.log_debug(
            f"Parsed corpus directory in {time.time() - begin_time} seconds")
Example #26
    def calc_fmllr(self, iteration: Optional[int] = None) -> None:
        """
        Multiprocessing function that computes speaker adaptation transforms via
        feature-space Maximum Likelihood Linear Regression (fMLLR).

        See Also
        --------
        :class:`~montreal_forced_aligner.corpus.features.CalcFmllrFunction`
            Multiprocessing helper function for each job
        :meth:`.AcousticCorpusMixin.calc_fmllr_arguments`
            Job method for generating arguments for the helper function
        :kaldi_steps:`align_fmllr`
            Reference Kaldi script
        :kaldi_steps:`train_sat`
            Reference Kaldi script
        """
        begin = time.time()
        self.log_info("Calculating fMLLR for speaker adaptation...")

        arguments = self.calc_fmllr_arguments(iteration=iteration)
        with tqdm.tqdm(total=self.num_speakers,
                       disable=getattr(self, "quiet", False)) as pbar:
            if self.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(arguments):
                    function = CalcFmllrFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if isinstance(result, Exception):
                            error_dict[getattr(result, "job_name", 0)] = result
                            continue
                        if stopped.stop_check():
                            continue
                    except Empty:
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    pbar.update(1)
                for p in procs:
                    p.join()
                if error_dict:
                    for v in error_dict.values():
                        raise v
            else:
                for args in arguments:
                    function = CalcFmllrFunction(args)
                    for _ in function.run():
                        pbar.update(1)

        self.speaker_independent = False
        self.log_debug(f"Fmllr calculation took {time.time() - begin}")
    def acc_global_stats(self) -> None:
        """
        Multiprocessing function that accumulates global GMM stats

        See Also
        --------
        :func:`~montreal_forced_aligner.ivector.trainer.AccGlobalStatsFunction`
            Multiprocessing helper function for each job
        :meth:`.DubmTrainer.acc_global_stats_arguments`
            Job method for generating arguments for the helper function
        :kaldi_src:`gmm-global-sum-accs`
            Relevant Kaldi binary
        :kaldi_steps:`train_diag_ubm`
            Reference Kaldi script

        """
        begin = time.time()
        self.log_info("Accumulating global stats...")
        arguments = self.acc_global_stats_arguments()

        if self.use_mp:
            error_dict = {}
            return_queue = mp.Queue()
            stopped = Stopped()
            procs = []
            for i, args in enumerate(arguments):
                function = AccGlobalStatsFunction(args)
                p = KaldiProcessWorker(i, return_queue, function, stopped)
                procs.append(p)
                p.start()
            while True:
                try:
                    result = return_queue.get(timeout=1)
                    if isinstance(result, Exception):
                        error_dict[getattr(result, "job_name", 0)] = result
                        continue
                    if stopped.stop_check():
                        continue
                except queue.Empty:
                    for proc in procs:
                        if not proc.finished.stop_check():
                            break
                    else:
                        break
                    continue
            for p in procs:
                p.join()
            if error_dict:
                for v in error_dict.values():
                    raise v

        else:
            self.log_debug("Not using multiprocessing...")
            for args in arguments:
                function = AccGlobalStatsFunction(args)
                for _ in function.run():
                    pass

        self.log_debug(f"Accumulating stats took {time.time() - begin}")

        # Don't remove low-count Gaussians until the last iteration,
        # or the gselect info won't be valid anymore
        if self.iteration < self.num_iterations:
            opt = "--remove-low-count-gaussians=false"
        else:
            opt = f"--remove-low-count-gaussians={self.remove_low_count_gaussians}"
        log_path = os.path.join(self.working_log_directory,
                                f"update.{self.iteration}.log")
        with open(log_path, "w") as log_file:
            acc_files = []
            for j in arguments:
                acc_files.append(j.acc_path)
            sum_proc = subprocess.Popen(
                [thirdparty_binary("gmm-global-sum-accs"), "-"] + acc_files,
                stderr=log_file,
                stdout=subprocess.PIPE,
                env=os.environ,
            )
            gmm_global_est_proc = subprocess.Popen(
                [
                    thirdparty_binary("gmm-global-est"),
                    opt,
                    f"--min-gaussian-weight={self.min_gaussian_weight}",
                    self.model_path,
                    "-",
                    self.next_model_path,
                ],
                stderr=log_file,
                stdin=sum_proc.stdout,
                env=os.environ,
            )
            gmm_global_est_proc.communicate()
            # Clean up
            if not self.debug:
                for p in acc_files:
                    os.remove(p)
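
# The two Popen calls above reproduce the shell pipeline
#   gmm-global-sum-accs - <acc files> | gmm-global-est <opts> <model> - <next model>
# where the summed accumulators are streamed on stdout ("-") straight into the
# estimation binary's stdin. A minimal sketch of that piping idiom, assuming the
# Kaldi binaries are on PATH; the accumulator and model file names are hypothetical.
import subprocess

acc_files = ["acc.0.1", "acc.0.2"]  # hypothetical per-job accumulators
sum_proc = subprocess.Popen(
    ["gmm-global-sum-accs", "-"] + acc_files,
    stdout=subprocess.PIPE,
)
est_proc = subprocess.Popen(
    ["gmm-global-est", "--remove-low-count-gaussians=false", "0.dubm", "-", "1.dubm"],
    stdin=sum_proc.stdout,
)
sum_proc.stdout.close()  # allow SIGPIPE to reach gmm-global-sum-accs if gmm-global-est exits early
est_proc.communicate()
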
Example #28
    def _load_corpus_from_source_mp(self) -> None:
        """
        Load a corpus using multiprocessing
        """
        begin_time = time.process_time()
        job_queue = mp.Queue()
        return_queue = mp.Queue()
        finished_adding = Stopped()
        stopped = Stopped()
        file_counts = Counter()
        sanitize_function = getattr(self, "sanitize_function", None)
        error_dict = {}
        procs = []
        self.db_engine.dispose()
        parser = AcousticDirectoryParser(
            self.corpus_directory,
            job_queue,
            self.audio_directory,
            stopped,
            finished_adding,
            file_counts,
        )
        parser.start()
        for i in range(self.num_jobs):
            p = CorpusProcessWorker(
                i,
                job_queue,
                return_queue,
                stopped,
                finished_adding,
                self.speaker_characters,
                sanitize_function,
                self.sample_frequency,
            )
            procs.append(p)
            p.start()
        last_poll = time.time() - 30
        import_data = DatabaseImportData()
        try:
            with self.session() as session:
                with tqdm.tqdm(total=100,
                               disable=getattr(self, "quiet", False)) as pbar:
                    while True:
                        try:
                            file = return_queue.get(timeout=1)
                            if isinstance(file, tuple):
                                error_type = file[0]
                                error = file[1]
                                if error_type == "error":
                                    error_dict[error_type] = error
                                else:
                                    if error_type not in error_dict:
                                        error_dict[error_type] = []
                                    error_dict[error_type].append(error)
                                continue
                            if self.stopped.stop_check():
                                continue
                        except Empty:
                            for proc in procs:
                                if not proc.finished_processing.stop_check():
                                    break
                            else:
                                break
                            continue
                        if time.time() - last_poll > 15:
                            pbar.total = file_counts.value()
                            last_poll = time.time()
                        pbar.update(1)
                        import_data.add_objects(
                            self.generate_import_objects(file))

                    self.log_debug(
                        f"Processing queue: {time.process_time() - begin_time}"
                    )

                    if "error" in error_dict:
                        session.rollback()
                        raise error_dict["error"][1]
                    self._finalize_load(session, import_data)
            for k in [
                    "sound_file_errors", "decode_error_files",
                    "textgrid_read_errors"
            ]:
                if hasattr(self, k):
                    if k in error_dict:
                        self.log_info(
                            "There were some issues with files in the corpus. "
                            "Please look at the log file or run the validator for more information."
                        )
                        self.log_debug(
                            f"{k} showed {len(error_dict[k])} errors:")
                        if k in {"textgrid_read_errors", "sound_file_errors"}:
                            getattr(self, k).update(error_dict[k])
                            for e in error_dict[k]:
                                self.log_debug(f"{e.file_name}: {e.error}")
                        else:
                            self.log_debug(", ".join(error_dict[k]))
                            setattr(self, k, error_dict[k])

        except Exception as e:
            if isinstance(e, KeyboardInterrupt):
                self.log_info(
                    "Detected ctrl-c, please wait a moment while we clean everything up..."
                )
                self.stopped.set_sigint_source()
            self.stopped.stop()
            finished_adding.stop()
            while True:
                try:
                    _ = job_queue.get(timeout=1)
                    if self.stopped.stop_check():
                        continue
                except Empty:
                    for proc in procs:
                        if not proc.finished_processing.stop_check():
                            break
                    else:
                        break
                try:
                    _ = return_queue.get(timeout=1)
                    _ = job_queue.get(timeout=1)
                    if self.stopped.stop_check():
                        continue
                except Empty:
                    for proc in procs:
                        if not proc.finished_processing.stop_check():
                            break
                    else:
                        break
        finally:
            parser.join()
            for p in procs:
                p.join()
            if self.stopped.stop_check():
                self.log_info(
                    f"Stopped parsing early ({time.process_time() - begin_time} seconds)"
                )
                if self.stopped.source():
                    sys.exit(0)
            else:
                self.log_debug(
                    f"Parsed corpus directory with {self.num_jobs} jobs in {time.process_time() - begin_time} seconds"
                )
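
# The except/finally blocks above implement a graceful shutdown: on Ctrl-C the
# stop flag is set, the queues are drained so blocked workers can finish their
# put()/get() calls, and only then are the parser and worker processes joined.
# A compact sketch of the same idea, with a multiprocessing.Event standing in
# for MFA's Stopped helper; produce_files and consume are hypothetical
# placeholders, not MFA APIs.
import multiprocessing as mp
from queue import Empty


def produce_files():
    # Hypothetical stand-in for AcousticDirectoryParser: yields file paths.
    for i in range(100):
        yield f"file_{i}.wav"


def consume(job_queue, stop_event):
    # Drain the queue; exit once it is empty and the stop flag has been set.
    while True:
        try:
            item = job_queue.get(timeout=1)
        except Empty:
            if stop_event.is_set():
                break
            continue
        print(f"processed {item}")


if __name__ == "__main__":
    stop_event = mp.Event()
    job_queue = mp.Queue()
    workers = [mp.Process(target=consume, args=(job_queue, stop_event)) for _ in range(2)]
    for w in workers:
        w.start()
    try:
        for path in produce_files():
            job_queue.put(path)
    except KeyboardInterrupt:
        # Drain anything still queued so no worker keeps processing stale work.
        while True:
            try:
                job_queue.get(timeout=1)
            except Empty:
                break
    finally:
        stop_event.set()  # tell workers to exit their polling loop
        for w in workers:
            w.join()
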
    def acc_ivector_stats(self) -> None:
        """
        Multiprocessing function that accumulates ivector extraction stats.

        See Also
        --------
        :func:`~montreal_forced_aligner.ivector.trainer.AccIvectorStatsFunction`
            Multiprocessing helper function for each job
        :meth:`.IvectorTrainer.acc_ivector_stats_arguments`
            Job method for generating arguments for the helper function
        :kaldi_src:`ivector-extractor-sum-accs`
            Relevant Kaldi binary
        :kaldi_src:`ivector-extractor-est`
            Relevant Kaldi binary
        :kaldi_steps_sid:`train_ivector_extractor`
            Reference Kaldi script
        """

        begin = time.time()
        self.log_info("Accumulating ivector stats...")
        arguments = self.acc_ivector_stats_arguments()

        if self.use_mp:
            error_dict = {}
            return_queue = mp.Queue()
            stopped = Stopped()
            procs = []
            for i, args in enumerate(arguments):
                function = AccIvectorStatsFunction(args)
                p = KaldiProcessWorker(i, return_queue, function, stopped)
                procs.append(p)
                p.start()
            while True:
                try:
                    result = return_queue.get(timeout=1)
                    if stopped.stop_check():
                        continue
                except queue.Empty:
                    for proc in procs:
                        if not proc.finished.stop_check():
                            break
                    else:
                        break
                    continue
                if isinstance(result, KaldiProcessingError):
                    error_dict[result.job_name] = result
                    continue
            for p in procs:
                p.join()
            if error_dict:
                for v in error_dict.values():
                    raise v

        else:
            self.log_debug("Not using multiprocessing...")
            for args in arguments:
                function = AccIvectorStatsFunction(args)
                for _ in function.run():
                    pass

        self.log_debug(f"Accumulating stats took {time.time() - begin}")

        log_path = os.path.join(self.working_log_directory,
                                f"sum_acc.{self.iteration}.log")
        acc_path = os.path.join(self.working_directory,
                                f"acc.{self.iteration}")
        with open(log_path, "w", encoding="utf8") as log_file:
            accinits = []
            for j in arguments:
                accinits.append(j.acc_init_path)
            sum_accs_proc = subprocess.Popen(
                [
                    thirdparty_binary("ivector-extractor-sum-accs"),
                    "--parallel=true"
                ] + accinits + [acc_path],
                stderr=log_file,
                env=os.environ,
            )

            sum_accs_proc.communicate()
        # clean up
        for p in accinits:
            os.remove(p)
        # Est extractor
        log_path = os.path.join(self.working_log_directory,
                                f"update.{self.iteration}.log")
        with open(log_path, "w") as log_file:
            extractor_est_proc = subprocess.Popen(
                [
                    thirdparty_binary("ivector-extractor-est"),
                    f"--num-threads={len(self.jobs)}",
                    f"--gaussian-min-count={self.gaussian_min_count}",
                    self.ie_path,
                    os.path.join(self.working_directory,
                                 f"acc.{self.iteration}"),
                    self.next_ie_path,
                ],
                stderr=log_file,
                env=os.environ,
            )
            extractor_est_proc.communicate()
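
# Like the other steps in this listing, the Kaldi binaries here write their
# stderr to a per-iteration log file rather than raising directly, so failures
# have to be picked up from those logs afterwards. A post-hoc check of the kind
# sketched below (a hypothetical helper, not part of MFA) can scan a log for
# Kaldi's "ERROR (...)" lines and raise:
import os


def check_kaldi_log(log_path):
    # Raise if the Kaldi binary reported an error in its stderr log.
    if not os.path.exists(log_path):
        return
    with open(log_path, "r", encoding="utf8") as f:
        errors = [line.strip() for line in f if line.startswith("ERROR")]
    if errors:
        raise RuntimeError(f"{log_path}: " + "; ".join(errors))
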
Example #30
    def lda_acc_stats(self) -> None:
        """
        Multiprocessing function that accumulates LDA statistics.

        See Also
        --------
        :func:`~montreal_forced_aligner.acoustic_modeling.lda.LdaAccStatsFunction`
            Multiprocessing helper function for each job
        :meth:`.LdaTrainer.lda_acc_stats_arguments`
            Job method for generating arguments for the helper function
        :kaldi_src:`est-lda`
            Relevant Kaldi binary
        :kaldi_steps:`train_lda_mllt`
            Reference Kaldi script

        """
        worker_lda_path = os.path.join(self.worker.working_directory, "lda.mat")
        lda_path = os.path.join(self.working_directory, "lda.mat")
        if os.path.exists(worker_lda_path):
            os.remove(worker_lda_path)
        arguments = self.lda_acc_stats_arguments()
        with tqdm.tqdm(
            total=self.num_current_utterances, disable=getattr(self, "quiet", False)
        ) as pbar:
            if self.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(arguments):
                    function = LdaAccStatsFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if isinstance(result, Exception):
                            error_dict[getattr(result, "job_name", 0)] = result
                            continue
                        if stopped.stop_check():
                            continue
                    except Empty:
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    done, errors = result
                    pbar.update(done + errors)
                for p in procs:
                    p.join()
                if error_dict:
                    for v in error_dict.values():
                        raise v
            else:
                for args in arguments:
                    function = LdaAccStatsFunction(args)
                    for done, errors in function.run():
                        pbar.update(done + errors)

        log_path = os.path.join(self.working_log_directory, "lda_est.log")
        acc_list = []
        for x in arguments:
            acc_list.extend(x.acc_paths.values())
        with open(log_path, "w", encoding="utf8") as log_file:
            est_lda_proc = subprocess.Popen(
                [
                    thirdparty_binary("est-lda"),
                    f"--dim={self.lda_dimension}",
                    lda_path,
                ]
                + acc_list,
                stderr=log_file,
                env=os.environ,
            )
            est_lda_proc.communicate()
        shutil.copyfile(
            lda_path,
            worker_lda_path,
        )
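
# Once est-lda has written lda.mat, downstream feature pipelines typically
# splice neighbouring frames and project them with the estimated matrix, i.e.
# the Kaldi pipeline `splice-feats ... | transform-feats lda.mat ...`. A
# minimal sketch of wiring that up with subprocess, assuming the Kaldi binaries
# are on PATH; the feats.scp and output archive names are hypothetical.
import subprocess

splice_proc = subprocess.Popen(
    ["splice-feats", "--left-context=3", "--right-context=3",
     "scp:feats.scp", "ark:-"],
    stdout=subprocess.PIPE,
)
transform_proc = subprocess.Popen(
    ["transform-feats", "lda.mat", "ark:-", "ark:feats_lda.ark"],
    stdin=splice_proc.stdout,
)
splice_proc.stdout.close()  # allow SIGPIPE to reach splice-feats if transform-feats exits early
transform_proc.communicate()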