def convert_alignments(self) -> None:
        """
        Convert the alignments of a previous training stage, one job at a time

        Each job reports pairs of (converted utterances, failed utterances); both
        counts advance the progress bar. When multiprocessing is enabled, worker
        exceptions are collected and the first one is re-raised after all
        workers have been joined.

        See Also
        --------
        :func:`~montreal_forced_aligner.acoustic_modeling.triphone.ConvertAlignmentsFunction`
            Multiprocessing helper function for each job
        :meth:`.TriphoneTrainer.convert_alignments_arguments`
            Job method for generating arguments for the helper function
        :kaldi_steps:`train_deltas`
            Reference Kaldi script
        :kaldi_steps:`train_lda_mllt`
            Reference Kaldi script
        :kaldi_steps:`train_sat`
            Reference Kaldi script

        """
        self.log_info("Converting alignments...")
        job_arguments = self.convert_alignments_arguments()
        with tqdm.tqdm(total=self.num_current_utterances,
                       disable=getattr(self, "quiet", False)) as progress:
            if not self.use_mp:
                for job_args in job_arguments:
                    for converted, failed in ConvertAlignmentsFunction(
                            job_args).run():
                        progress.update(converted + failed)
                return

            worker_errors = {}
            results = mp.Queue()
            stop_flag = Stopped()
            workers = []
            for index, job_args in enumerate(job_arguments):
                worker = KaldiProcessWorker(index, results,
                                            ConvertAlignmentsFunction(job_args),
                                            stop_flag)
                workers.append(worker)
                worker.start()

            while True:
                try:
                    result = results.get(timeout=1)
                except Empty:
                    # Nothing queued: stop polling once every worker is done
                    if all(w.finished.stop_check() for w in workers):
                        break
                    continue
                if isinstance(result, Exception):
                    worker_errors[getattr(result, "job_name", 0)] = result
                    continue
                if stop_flag.stop_check():
                    continue
                converted, failed = result
                progress.update(converted + failed)

            for worker in workers:
                worker.join()
            for error in worker_errors.values():
                raise error
示例#2
0
    def compute_vad(self) -> None:
        """
        Compute Voice Activity Detection features over the corpus

        Does nothing if ``vad.0.scp`` already exists in the split directory.
        When multiprocessing is enabled, any exception reported by a worker is
        collected, and the first one is re-raised once every worker has been
        joined.

        See Also
        --------
        :class:`~montreal_forced_aligner.corpus.features.ComputeVadFunction`
            Multiprocessing helper function for each job
        :meth:`.AcousticCorpusMixin.compute_vad_arguments`
            Job method for generating arguments for helper function
        """
        if os.path.exists(os.path.join(self.split_directory, "vad.0.scp")):
            self.log_info("VAD already computed, skipping!")
            return
        begin = time.time()
        self.log_info("Computing VAD...")

        arguments = self.compute_vad_arguments()
        # NOTE(review): total is num_speakers while updates sum per-result
        # counts (done + no_feats + unvoiced) -- confirm the units match
        with tqdm.tqdm(total=self.num_speakers,
                       disable=getattr(self, "quiet", False)) as pbar:
            if self.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(arguments):
                    function = ComputeVadFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                    except Empty:
                        # Keep polling until every worker reports it finished
                        if all(proc.finished.stop_check() for proc in procs):
                            break
                        continue
                    # Check for errors before anything else: otherwise errors
                    # arriving after a stop were dropped, and exceptions other
                    # than KaldiProcessingError were unpacked as result tuples
                    if isinstance(result, Exception):
                        error_dict[getattr(result, "job_name", 0)] = result
                        continue
                    if stopped.stop_check():
                        continue
                    done, no_feats, unvoiced = result
                    pbar.update(done + no_feats + unvoiced)
                for p in procs:
                    p.join()
                for v in error_dict.values():
                    raise v
            else:
                for args in arguments:
                    function = ComputeVadFunction(args)
                    for done, no_feats, unvoiced in function.run():
                        pbar.update(done + no_feats + unvoiced)
        self.log_debug(f"VAD computation took {time.time() - begin}")
    def gmm_gselect(self) -> None:
        """
        Multiprocessing function that stores Gaussian selection indices on disk

        Job results carry no progress information. On the multiprocessing path
        they are only inspected for exceptions, and the first collected worker
        exception is re-raised once every worker has been joined. The elapsed
        time is logged at debug level for both the serial and multiprocessing
        paths.

        See Also
        --------
        :func:`~montreal_forced_aligner.ivector.trainer.GmmGselectFunction`
            Multiprocessing helper function for each job
        :meth:`.DubmTrainer.gmm_gselect_arguments`
            Job method for generating arguments for the helper function
        :kaldi_steps:`train_diag_ubm`
            Reference Kaldi script

        """
        begin = time.time()
        self.log_info("Selecting gaussians...")
        arguments = self.gmm_gselect_arguments()

        if self.use_mp:
            error_dict = {}
            return_queue = mp.Queue()
            stopped = Stopped()
            procs = []
            for i, args in enumerate(arguments):
                function = GmmGselectFunction(args)
                p = KaldiProcessWorker(i, return_queue, function, stopped)
                procs.append(p)
                p.start()
            while True:
                try:
                    result = return_queue.get(timeout=1)
                except queue.Empty:
                    # Keep polling until every worker reports it finished
                    if all(proc.finished.stop_check() for proc in procs):
                        break
                    continue
                if isinstance(result, Exception):
                    error_dict[getattr(result, "job_name", 0)] = result
            for p in procs:
                p.join()
            for v in error_dict.values():
                raise v
        else:
            self.log_debug("Not using multiprocessing...")
            for args in arguments:
                function = GmmGselectFunction(args)
                for _ in function.run():
                    pass

        # Previously nested inside the serial branch, so the multiprocessing
        # path never reported its timing
        self.log_debug(f"Gaussian selection took {time.time() - begin}")
示例#4
0
    def extract_ivectors(self) -> None:
        """
        Multiprocessing function that extracts i-vectors.

        Ensures the working log directory exists, then runs one extraction job
        per argument set. Each result yielded by a job advances the progress
        bar by one. When multiprocessing is enabled, worker exceptions are
        collected and the first one is re-raised after all workers have been
        joined.

        See Also
        --------
        :class:`~montreal_forced_aligner.corpus.features.ExtractIvectorsFunction`
            Multiprocessing helper function for each job
        :meth:`.IvectorCorpusMixin.extract_ivectors_arguments`
            Job method for generating arguments for helper function
        :kaldi_steps_sid:`extract_ivectors`
            Reference Kaldi script
        """
        begin = time.time()

        log_dir = self.working_log_directory
        os.makedirs(log_dir, exist_ok=True)

        arguments = self.extract_ivectors_arguments()
        with tqdm.tqdm(total=self.num_speakers, disable=getattr(self, "quiet", False)) as pbar:
            if self.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(arguments):
                    function = ExtractIvectorsFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if isinstance(result, Exception):
                            error_dict[getattr(result, "job_name", 0)] = result
                            continue
                        if stopped.stop_check():
                            continue
                    except Empty:
                        # Keep polling until every worker reports it finished
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    pbar.update(1)
                for p in procs:
                    p.join()
                if error_dict:
                    for v in error_dict.values():
                        raise v
            else:
                for args in arguments:
                    function = ExtractIvectorsFunction(args)
                    for _ in function.run():
                        pbar.update(1)
        self.log_debug(f"Ivector extraction took {time.time() - begin}")
    def gauss_to_post(self) -> None:
        """
        Multiprocessing function that does Gaussian selection and posterior extraction

        Job results carry no progress information. On the multiprocessing path
        they are only inspected for exceptions. Every worker exception, not just
        ``KaldiProcessingError``, is collected, and the first one is re-raised
        once every worker has been joined.

        See Also
        --------
        :func:`~montreal_forced_aligner.ivector.trainer.GaussToPostFunction`
            Multiprocessing helper function for each job
        :meth:`.IvectorTrainer.gauss_to_post_arguments`
            Job method for generating arguments for the helper function
        :kaldi_steps_sid:`train_ivector_extractor`
            Reference Kaldi script
        """
        begin = time.time()
        self.log_info("Extracting posteriors...")
        arguments = self.gauss_to_post_arguments()

        if self.use_mp:
            error_dict = {}
            return_queue = mp.Queue()
            stopped = Stopped()
            procs = []
            for i, args in enumerate(arguments):
                function = GaussToPostFunction(args)
                p = KaldiProcessWorker(i, return_queue, function, stopped)
                procs.append(p)
                p.start()
            while True:
                try:
                    result = return_queue.get(timeout=1)
                except queue.Empty:
                    # Keep polling until every worker reports it finished
                    if all(proc.finished.stop_check() for proc in procs):
                        break
                    continue
                # Previously only KaldiProcessingError was recorded, and only
                # when not stopped, so other worker failures were silently lost
                if isinstance(result, Exception):
                    error_dict[getattr(result, "job_name", 0)] = result
            for p in procs:
                p.join()
            for v in error_dict.values():
                raise v

        else:
            self.log_debug("Not using multiprocessing...")
            for args in arguments:
                function = GaussToPostFunction(args)
                for _ in function.run():
                    pass

        self.log_debug(f"Extracting posteriors took {time.time() - begin}")
    def align_utterances(self, training=False) -> None:
        """
        Multiprocessing function that aligns based on the current model.

        Parameters
        ----------
        training: bool
            If False (default), the stored ``alignment_log_likelihood`` of the
            targeted utterances is cleared first. It is then filled in from the
            job results and divided by each utterance's ``num_frames``. If True,
            job results only advance the progress bar and the database is not
            modified.

        Raises
        ------
        NoAlignmentsError
            If ``training`` is False and no job produced an alignment. Any
            exception reported by a worker takes precedence and is raised
            instead.

        See Also
        --------
        :class:`~montreal_forced_aligner.alignment.multiprocessing.AlignFunction`
            Multiprocessing helper function for each job
        :meth:`.AlignMixin.align_arguments`
            Job method for generating arguments for the helper function
        :kaldi_steps:`align_si`
            Reference Kaldi script
        :kaldi_steps:`align_fmllr`
            Reference Kaldi script
        """
        begin = time.time()
        self.log_info("Generating alignments...")
        with tqdm.tqdm(
                total=self.num_current_utterances,
                disable=getattr(self, "quiet",
                                False)) as pbar, self.session() as session:
            if not training:
                # Clear previous scores of the targeted utterances before re-aligning
                utterances = session.query(Utterance)
                if hasattr(self, "subset"):
                    utterances = utterances.filter(
                        Utterance.in_subset == True)  # noqa
                utterances.update({"alignment_log_likelihood": None})
                session.commit()
            update_mappings = []
            if self.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(self.align_arguments()):
                    function = AlignFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if isinstance(result, Exception):
                            error_dict[getattr(result, "job_name", 0)] = result
                            continue
                        if stopped.stop_check():
                            continue
                    except Empty:
                        # Keep polling until every worker reports it finished
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    if not training:
                        utterance, log_likelihood = result
                        update_mappings.append({
                            "id":
                            utterance,
                            "alignment_log_likelihood":
                            log_likelihood
                        })
                    pbar.update(1)
                for p in procs:
                    p.join()
                # Raise worker failures before checking for an empty result:
                # a crashed job would otherwise be reported as NoAlignmentsError
                for v in error_dict.values():
                    raise v

            else:
                self.log_debug("Not using multiprocessing...")
                for args in self.align_arguments():
                    function = AlignFunction(args)
                    for utterance, log_likelihood in function.run():
                        if not training:
                            update_mappings.append({
                                "id":
                                utterance,
                                "alignment_log_likelihood":
                                log_likelihood
                            })
                        pbar.update(1)
            if not training:
                if not update_mappings:
                    raise NoAlignmentsError(self.num_current_utterances,
                                            self.beam, self.retry_beam)
                session.bulk_update_mappings(Utterance, update_mappings)
                # Normalize each stored log-likelihood by the utterance's frame count
                session.query(Utterance).filter(
                    Utterance.alignment_log_likelihood != None  # noqa
                ).update(
                    {
                        Utterance.alignment_log_likelihood:
                        Utterance.alignment_log_likelihood /
                        Utterance.num_frames
                    },
                    synchronize_session="fetch",
                )
                session.commit()
            self.log_debug(f"Alignment round took {time.time() - begin}")
    def compile_train_graphs(self) -> None:
        """
        Compile the training graphs for the current utterances

        Every job reports pairs of (compiled count, failed count). Both advance
        the progress bar, and the failures are summed and reported as a single
        warning at the end. With multiprocessing, worker exceptions are
        collected and the first one is re-raised after all workers are joined.

        See Also
        --------
        :class:`~montreal_forced_aligner.alignment.multiprocessing.CompileTrainGraphsFunction`
            Multiprocessing helper function for each job
        :meth:`.AlignMixin.compile_train_graphs_arguments`
            Job method for generating arguments for the helper function
        :kaldi_steps:`align_si`
            Reference Kaldi script
        :kaldi_steps:`align_fmllr`
            Reference Kaldi script
        """
        begin = time.time()
        os.makedirs(self.working_log_directory, exist_ok=True)
        self.log_info("Compiling training graphs...")
        error_sum = 0
        job_arguments = self.compile_train_graphs_arguments()
        with tqdm.tqdm(total=self.num_current_utterances,
                       disable=getattr(self, "quiet", False)) as progress:
            if not self.use_mp:
                self.log_debug("Not using multiprocessing...")
                for job_args in job_arguments:
                    for compiled, failed in CompileTrainGraphsFunction(
                            job_args).run():
                        progress.update(compiled + failed)
                        error_sum += failed
            else:
                worker_errors = {}
                results = mp.Queue()
                stop_flag = Stopped()
                workers = []
                for index, job_args in enumerate(job_arguments):
                    worker = KaldiProcessWorker(
                        index, results, CompileTrainGraphsFunction(job_args),
                        stop_flag)
                    workers.append(worker)
                    worker.start()
                while True:
                    try:
                        result = results.get(timeout=1)
                    except Empty:
                        # Queue idle: finish once all workers are done
                        if all(w.finished.stop_check() for w in workers):
                            break
                        continue
                    if isinstance(result, Exception):
                        worker_errors[getattr(result, "job_name", 0)] = result
                        continue
                    if stop_flag.stop_check():
                        continue
                    compiled, failed = result
                    progress.update(compiled + failed)
                    error_sum += failed
                for worker in workers:
                    worker.join()
                for error in worker_errors.values():
                    raise error
        if error_sum:
            self.log_warning(
                f"Compilation of training graphs failed for {error_sum} utterances."
            )
        self.log_debug(f"Compiling training graphs took {time.time() - begin}")
示例#8
0
    def acc_stats(self, alignment: bool = False) -> None:
        """
        Accumulate stats for the mapped model

        Per-job accumulators are summed with ``gmm-sum-accs``, smoothed towards
        the unadapted model with ``gmm-ismooth-stats --smooth-from-model`` (tau
        from ``mapping_tau``), and the Gaussian means are re-estimated with
        ``gmm-est --update-flags=m``. Returns immediately if the unadapted model
        for the requested variant does not exist.

        Parameters
        ----------
        alignment: bool
            Flag for whether to accumulate stats for the mapped alignment model
            (``unadapted.alimdl`` -> ``final.alimdl``) instead of the regular
            model (``unadapted.mdl`` -> ``final.mdl``)
        """
        arguments = self.map_acc_stats_arguments(alignment)
        if alignment:
            initial_mdl_path = os.path.join(self.working_directory,
                                            "unadapted.alimdl")
            final_mdl_path = os.path.join(self.working_directory,
                                          "final.alimdl")
        else:
            initial_mdl_path = os.path.join(self.working_directory,
                                            "unadapted.mdl")
            final_mdl_path = os.path.join(self.working_directory, "final.mdl")
        if not os.path.exists(initial_mdl_path):
            return
        self.log_info("Accumulating statistics...")
        with tqdm.tqdm(total=self.num_current_utterances,
                       disable=getattr(self, "quiet", False)) as pbar:
            if self.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(arguments):
                    function = AccStatsFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if isinstance(result, Exception):
                            error_dict[getattr(result, "job_name", 0)] = result
                            continue
                        if stopped.stop_check():
                            continue
                    except Empty:
                        # Keep polling until every worker reports it finished
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    num_utterances, errors = result
                    pbar.update(num_utterances + errors)
                for p in procs:
                    p.join()
                if error_dict:
                    for v in error_dict.values():
                        raise v
            else:
                for args in arguments:
                    function = AccStatsFunction(args)
                    for num_utterances, errors in function.run():
                        pbar.update(num_utterances + errors)
        log_path = os.path.join(self.working_log_directory,
                                "map_model_est.log")
        occs_path = os.path.join(self.working_directory, "final.occs")
        with open(log_path, "w", encoding="utf8") as log_file:
            acc_files = [
                path for job in arguments for path in job.acc_paths.values()
            ]
            sum_proc = subprocess.Popen(
                [thirdparty_binary("gmm-sum-accs"), "-"] + acc_files,
                stderr=log_file,
                stdout=subprocess.PIPE,
                env=os.environ,
            )
            ismooth_proc = subprocess.Popen(
                [
                    thirdparty_binary("gmm-ismooth-stats"),
                    "--smooth-from-model",
                    f"--tau={self.mapping_tau}",
                    initial_mdl_path,
                    "-",
                    "-",
                ],
                stderr=log_file,
                stdin=sum_proc.stdout,
                stdout=subprocess.PIPE,
                env=os.environ,
            )
            # Drop the parent's copy of the pipe so sum_proc receives SIGPIPE
            # if ismooth_proc exits early
            sum_proc.stdout.close()
            est_proc = subprocess.Popen(
                [
                    thirdparty_binary("gmm-est"),
                    "--update-flags=m",
                    f"--write-occs={occs_path}",
                    "--remove-low-count-gaussians=false",
                    initial_mdl_path,
                    "-",
                    final_mdl_path,
                ],
                stdin=ismooth_proc.stdout,
                stderr=log_file,
                env=os.environ,
            )
            ismooth_proc.stdout.close()
            est_proc.communicate()
            # Reap the upstream pipeline stages so no zombie processes remain
            ismooth_proc.wait()
            sum_proc.wait()
    def train_g2p_lexicon(self) -> None:
        """
        Generate a G2P lexicon based on aligned transcripts

        Generated pronunciations are written, per dictionary, to paired input
        (``<s> text </s>``) and output (phone string) files in the parent
        class's working directory. Utterances without normalized text are
        skipped. A G2P model is then trained for each dictionary, and its FST
        replaces the dictionary's lexicon FST. The previous FST is kept with a
        ``.backup`` suffix.
        """
        import contextlib

        arguments = self.worker.generate_pronunciations_arguments()
        working_dir = super(PronunciationProbabilityTrainer,
                            self).working_directory
        texts = {}
        with self.worker.session() as session:
            query = session.query(Utterance.id,
                                  Utterance.normalized_character_text)
            query = query.filter(Utterance.ignored == False)  # noqa
            initial_brackets = "".join(x[0] for x in self.worker.brackets)
            query = query.filter(
                ~Utterance.oovs.regexp_match(f"(^| )[^{initial_brackets}]"))
            if self.subset:
                query = query.filter_by(in_subset=True)
            for utt_id, text in query:
                texts[utt_id] = text
            # ExitStack guarantees every file is closed, even if a job raises
            with contextlib.ExitStack() as file_stack:
                input_files = {
                    x: file_stack.enter_context(
                        open(
                            os.path.join(
                                working_dir,
                                f"input_{self.worker.dictionary_base_names[x]}.txt"
                            ),
                            "w",
                            encoding="utf8",
                        ))
                    for x in self.worker.dictionary_lookup.values()
                }
                output_files = {
                    x: file_stack.enter_context(
                        open(
                            os.path.join(
                                working_dir,
                                f"output_{self.worker.dictionary_base_names[x]}.txt"
                            ),
                            "w",
                            encoding="utf8",
                        ))
                    for x in self.worker.dictionary_lookup.values()
                }
                with tqdm.tqdm(total=self.num_current_utterances,
                               disable=getattr(self, "quiet", False)) as pbar:

                    def write_result(dict_id, utt_id, phones):
                        """Record one generated pronunciation and advance progress."""
                        # Job results carry utterance ids as dash-separated
                        # strings; the trailing field is the integer key of texts
                        utt_id = int(utt_id.split("-")[-1])
                        pbar.update(1)
                        if not texts.get(utt_id):
                            return
                        print(phones, file=output_files[dict_id])
                        print(f"<s> {texts[utt_id]} </s>",
                              file=input_files[dict_id])

                    if self.use_mp:
                        error_dict = {}
                        return_queue = mp.Queue()
                        stopped = Stopped()
                        procs = []
                        for i, args in enumerate(arguments):
                            args.for_g2p = True
                            function = GeneratePronunciationsFunction(args)
                            p = KaldiProcessWorker(i, return_queue, function,
                                                   stopped)
                            procs.append(p)
                            p.start()
                        while True:
                            try:
                                result = return_queue.get(timeout=1)
                                if isinstance(result, Exception):
                                    error_dict[getattr(result, "job_name",
                                                       0)] = result
                                    continue
                                if stopped.stop_check():
                                    continue
                            except Empty:
                                # Keep polling until every worker has finished
                                for proc in procs:
                                    if not proc.finished.stop_check():
                                        break
                                else:
                                    break
                                continue
                            dict_id, utt_id, phones = result
                            write_result(dict_id, utt_id, phones)

                        for p in procs:
                            p.join()
                        if error_dict:
                            for v in error_dict.values():
                                raise v
                    else:
                        self.log_debug("Not using multiprocessing...")
                        for args in arguments:
                            # Same job configuration and result handling as the
                            # multiprocessing path (previously missing here)
                            args.for_g2p = True
                            function = GeneratePronunciationsFunction(args)
                            for dict_id, utt_id, phones in function.run():
                                write_result(dict_id, utt_id, phones)
            self.pronunciations_complete = True
            os.makedirs(self.working_log_directory, exist_ok=True)
            dictionaries = session.query(Dictionary)
            shutil.copyfile(self.phone_symbol_table_path,
                            os.path.join(self.working_directory, "phones.txt"))
            shutil.copyfile(
                self.grapheme_symbol_table_path,
                os.path.join(self.working_directory, "graphemes.txt"),
            )
            self.input_token_type = self.grapheme_symbol_table_path
            self.output_token_type = self.phone_symbol_table_path
            for d in dictionaries:
                self.log_info(f"Training G2P for {d.name}...")
                self._data_source = self.worker.dictionary_base_names[d.id]
                begin = time.time()
                if os.path.exists(self.far_path) and os.path.exists(
                        self.encoder_path):
                    self.log_info("Alignment already done, skipping!")
                else:
                    self.align_g2p()
                    self.log_debug(
                        f"Aligning utterances for {d.name} took {time.time() - begin} seconds"
                    )
                begin = time.time()
                self.generate_model()
                self.log_debug(
                    f"Generating model for {d.name} took {time.time() - begin} seconds"
                )
                # Keep the original lexicon FST as a backup before replacing it
                os.rename(d.lexicon_fst_path, d.lexicon_fst_path + ".backup")
                shutil.copy(self.fst_path, d.lexicon_fst_path)
                d.use_g2p = True
            session.commit()
            self.worker.use_g2p = True
    def acc_stats(self) -> None:
        """
        Multiprocessing function that accumulates stats for GMM training.

        See Also
        --------
        :class:`~montreal_forced_aligner.alignment.multiprocessing.AccStatsFunction`
            Multiprocessing helper function for each job
        :meth:`.AcousticModelTrainingMixin.acc_stats_arguments`
            Job method for generating arguments for the helper function
        :kaldi_src:`gmm-sum-accs`
            Relevant Kaldi binary
        :kaldi_src:`gmm-est`
            Relevant Kaldi binary
        :kaldi_steps:`train_mono`
            Reference Kaldi script
        :kaldi_steps:`train_deltas`
            Reference Kaldi script
        """
        self.log_info("Accumulating statistics...")
        arguments = self.acc_stats_arguments()
        with tqdm.tqdm(total=self.num_current_utterances,
                       disable=getattr(self, "quiet", False)) as pbar:
            if self.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(arguments):
                    function = AccStatsFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if isinstance(result, Exception):
                            error_dict[getattr(result, "job_name", 0)] = result
                            continue
                        if stopped.stop_check():
                            continue
                    except Empty:
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    num_utterances, errors = result
                    pbar.update(num_utterances + errors)
                for p in procs:
                    p.join()
                if error_dict:
                    for v in error_dict.values():
                        raise v
            else:
                for args in arguments:
                    function = AccStatsFunction(args)
                    for num_utterances, errors in function.run():
                        pbar.update(num_utterances + errors)

        log_path = os.path.join(self.working_log_directory,
                                f"update.{self.iteration}.log")
        with open(log_path, "w") as log_file:
            acc_files = []
            for a in arguments:
                acc_files.extend(a.acc_paths.values())
            sum_proc = subprocess.Popen(
                [thirdparty_binary("gmm-sum-accs"), "-"] + acc_files,
                stdout=subprocess.PIPE,
                stderr=log_file,
                env=os.environ,
            )
            est_command = [
                thirdparty_binary("gmm-est"),
                f"--write-occs={self.next_occs_path}",
                f"--mix-up={self.current_gaussians}",
            ]
            if self.power > 0:
                est_command.append(f"--power={self.power}")
            est_command.extend([
                self.model_path,
                "-",
                self.next_model_path,
            ])
            est_proc = subprocess.Popen(
                est_command,
                stdin=sum_proc.stdout,
                stderr=log_file,
                env=os.environ,
            )
            est_proc.communicate()
        avg_like_pattern = re.compile(
            r"Overall avg like per frame.* = (?P<like>[-.,\d]+) over (?P<frames>[.\d+e]+) frames"
        )
        average_logdet_pattern = re.compile(
            r"Overall average logdet is (?P<logdet>[-.,\d]+) over (?P<frames>[.\d+e]+) frames"
        )
        avg_like_sum = 0
        avg_like_frames = 0
        average_logdet_sum = 0
        average_logdet_frames = 0
        for a in arguments:
            with open(a.log_path, "r", encoding="utf8") as f:
                for line in f:
                    m = avg_like_pattern.search(line)
                    if m:
                        like = float(m.group("like"))
                        frames = float(m.group("frames"))
                        avg_like_sum += like * frames
                        avg_like_frames += frames
                    m = average_logdet_pattern.search(line)
                    if m:
                        logdet = float(m.group("logdet"))
                        frames = float(m.group("frames"))
                        average_logdet_sum += logdet * frames
                        average_logdet_frames += frames
        if avg_like_frames:
            log_like = avg_like_sum / avg_like_frames
            if average_logdet_frames:
                log_like += average_logdet_sum / average_logdet_frames
            self.log_debug(
                f"Likelihood for iteration {self.iteration}: {log_like}")

        if not self.debug:
            for f in acc_files:
                os.remove(f)
    def _alignments(self) -> None:
        """
        Train the Baum-Welch aligner if needed, then decode the alignments FAR.

        When ``self.align_path`` does not exist yet, ``self.random_starts``
        randomly seeded training runs (seeded from ``self.seed``) are queued for
        ``self.num_jobs`` ``RandomStartWorker`` workers. The run with the
        smallest reported likelihood value is moved to ``self.align_path``.
        Finally ``baumwelchdecode`` is run on the input/output FARs with the
        trained aligner, targeting ``self.afst_path``.

        Raises
        ------
        PyniniAlignmentError
            If any worker reported an exception during training
        subprocess.CalledProcessError
            If ``baumwelchdecode`` exits with a non-zero status
        """

        def as_flags(pairs):
            # Render every (name, value) pair whose value is truthy as ``--name=value``.
            return [f"--{name}={value}" for name, value in pairs if value]

        if not os.path.exists(self.align_path):
            self.log_info("Training aligner")
            train_opts = as_flags(
                [
                    ("batch_size", self.batch_size),
                    ("delta", self.delta),
                    ("fst_default_cache_gc", self.fst_default_cache_gc),
                    ("fst_default_cache_gc_limit", self.fst_default_cache_gc_limit),
                    ("alpha", self.alpha),
                    ("max_iters", self.num_iterations),
                ]
            )
            # Seeding first makes the per-start seeds reproducible across runs.
            random.seed(self.seed)
            start_seeds = random.sample(range(1, RAND_MAX), self.random_starts)
            starts = [
                RandomStart(
                    start_index,
                    start_seed,
                    self.input_far_path,
                    self.output_far_path,
                    self.cg_path,
                    self.working_directory,
                    train_opts,
                )
                for start_index, start_seed in enumerate(start_seeds, start=1)
            ]
            stopped = Stopped()
            job_queue = mp.JoinableQueue()
            fst_likelihoods = {}
            self.log_info("Calculating alignments...")
            begin = time.time()
            with tqdm.tqdm(
                total=len(starts) * self.num_iterations,
                disable=getattr(self, "quiet", False),
            ) as pbar:
                for start in starts:
                    job_queue.put(start)
                error_dict = {}
                return_queue = mp.Queue()
                workers = []
                for worker_index in range(self.num_jobs):
                    worker = RandomStartWorker(
                        worker_index,
                        job_queue,
                        return_queue,
                        os.path.join(
                            self.working_log_directory, f"baumwelch.{worker_index}.log"
                        ),
                        stopped,
                    )
                    workers.append(worker)
                    worker.start()

                while True:
                    try:
                        message = return_queue.get(timeout=1)
                    except queue.Empty:
                        # Queue is idle: finish only once every worker reports done.
                        if all(w.finished.stop_check() for w in workers):
                            break
                        continue
                    if isinstance(message, Exception):
                        error_dict[getattr(message, "job_name", 0)] = message
                        continue
                    if stopped.stop_check():
                        continue
                    if isinstance(message, int):
                        # Integer messages are progress increments.
                        pbar.update(message)
                    else:
                        # Anything else is a (fst path, likelihood) pair for one start.
                        fst_likelihoods[message[0]] = message[1]
                for worker in workers:
                    worker.join()
            if error_dict:
                raise PyniniAlignmentError(error_dict)
            best_fst, best_likelihood = min(
                fst_likelihoods.items(), key=operator.itemgetter(1)
            )
            self.log_info(f"Best likelihood: {best_likelihood}")
            self.log_debug(
                f"Ran {self.random_starts} random starts in {time.time() - begin} seconds"
            )
            # Keep only the winning model, at the requested location.
            shutil.move(best_fst, self.align_path)
        cmd = [
            thirdparty_binary("baumwelchdecode"),
            *as_flags(
                [
                    ("fst_default_cache_gc", self.fst_default_cache_gc),
                    ("fst_default_cache_gc_limit", self.fst_default_cache_gc_limit),
                ]
            ),
            self.input_far_path,
            self.output_far_path,
            self.align_path,
            self.afst_path,
        ]
        self.log_debug(f"Subprocess call: {cmd}")
        subprocess.check_call(cmd, env=os.environ)
        self.log_info("Completed computing alignments!")
示例#12
0
    def generate_pronunciations(self) -> Dict[str, List[str]]:
        """
        Generate pronunciations

        Every word in ``self.words_to_g2p`` is cleaned with ``clean_up_word``
        against the model's grapheme set. Words are skipped when they clean to
        an empty string, or when they contain graphemes unknown to the model
        while ``self.strict_graphemes`` is set.

        With fewer than 30 words, or a single job, the rewriter runs in this
        process. Otherwise the cleaned words are distributed to
        ``self.num_jobs`` ``RewriterWorker`` workers. In both modes the result
        is keyed by the original (uncleaned) word.

        Returns
        -------
        dict[str, list[str]]
            Mappings of keys to their generated pronunciations

        Raises
        ------
        PyniniGenerationError
            If any worker reported an exception
        """

        fst = pynini.Fst.read(self.g2p_model.fst_path)
        if self.g2p_model.meta["architecture"] == "phonetisaurus":
            output_token_type = pynini.SymbolTable.read_text(
                self.g2p_model.sym_path)
            input_token_type = pynini.SymbolTable.read_text(
                self.g2p_model.grapheme_sym_path)
            fst.set_input_symbols(input_token_type)
            fst.set_output_symbols(output_token_type)
            rewriter = PhonetisaurusRewriter(
                fst,
                input_token_type,
                output_token_type,
                num_pronunciations=self.num_pronunciations,
                threshold=self.g2p_threshold,
                grapheme_order=self.g2p_model.meta["grapheme_order"],
            )
        else:
            output_token_type = "utf8"
            input_token_type = "utf8"
            # Fall back to utf8 output unless the model ships a phone symbol table.
            if self.g2p_model.sym_path is not None and os.path.exists(
                    self.g2p_model.sym_path):
                output_token_type = pynini.SymbolTable.read_text(
                    self.g2p_model.sym_path)
            rewriter = Rewriter(
                fst,
                input_token_type,
                output_token_type,
                num_pronunciations=self.num_pronunciations,
                threshold=self.g2p_threshold,
            )

        num_words = len(self.words_to_g2p)
        begin = time.time()
        missing_graphemes = set()
        self.log_info("Generating pronunciations...")
        to_return = {}
        skipped_words = 0
        if num_words < 30 or self.num_jobs == 1:
            with tqdm.tqdm(total=num_words,
                           disable=getattr(self, "quiet", False)) as pbar:
                for word in self.words_to_g2p:
                    w, m = clean_up_word(word,
                                         self.g2p_model.meta["graphemes"])
                    pbar.update(1)
                    missing_graphemes = missing_graphemes | m
                    if self.strict_graphemes and m:
                        skipped_words += 1
                        continue
                    if not w:
                        skipped_words += 1
                        continue
                    try:
                        prons = rewriter(w)
                    except rewrite.Error:
                        # Words the model cannot rewrite are left out of the result.
                        continue
                    to_return[word] = prons
                self.log_debug(
                    f"Skipping {skipped_words} words for containing the following graphemes: "
                    f"{comma_join(sorted(missing_graphemes))}")
        else:
            stopped = Stopped()
            job_queue = mp.Queue()
            # Workers only ever see cleaned words. Remember which original
            # word(s) produced each cleaned form, so that results are keyed
            # exactly as in the serial path above.
            cleaned_to_original = {}
            for word in self.words_to_g2p:
                w, m = clean_up_word(word, self.g2p_model.meta["graphemes"])
                missing_graphemes = missing_graphemes | m
                if self.strict_graphemes and m:
                    skipped_words += 1
                    continue
                if not w:
                    skipped_words += 1
                    continue
                cleaned_to_original.setdefault(w, []).append(word)
                job_queue.put(w)
            self.log_debug(
                f"Skipping {skipped_words} words for containing the following graphemes: "
                f"{comma_join(sorted(missing_graphemes))}")
            error_dict = {}
            return_queue = mp.Queue()
            procs = []
            for _ in range(self.num_jobs):
                p = RewriterWorker(
                    job_queue,
                    return_queue,
                    rewriter,
                    stopped,
                )
                procs.append(p)
                p.start()
            num_words -= skipped_words
            with tqdm.tqdm(total=num_words,
                           disable=getattr(self, "quiet", False)) as pbar:
                while True:
                    try:
                        item = return_queue.get(timeout=1)
                        if isinstance(item, Exception):
                            # A bare exception instead of a (word, result) pair:
                            # record it like the other worker loops do, rather
                            # than failing to unpack it.
                            error_dict[getattr(item, "job_name", 0)] = item
                            continue
                        if stopped.stop_check():
                            continue
                    except queue.Empty:
                        # Queue is idle: finish only once every worker reports done.
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    word, result = item
                    pbar.update(1)
                    if isinstance(result, Exception):
                        error_dict[word] = result
                        continue
                    for original_word in cleaned_to_original.get(word, [word]):
                        to_return[original_word] = result

            for p in procs:
                p.join()
            if error_dict:
                raise PyniniGenerationError(error_dict)
        self.log_debug(
            f"Processed {num_words} in {time.time() - begin} seconds")
        return to_return
示例#13
0
    def calc_lda_mllt(self) -> None:
        """
        Multiprocessing function that calculates LDA+MLLT transformations.

        Each job's MLLT statistics are accumulated first. ``est-mllt`` then
        estimates a new transform from them, and ``gmm-transform-means`` applies
        it to ``self.model_path`` in place. If a previous ``lda.mat`` exists, it
        is composed with the new transform and the result replaces ``lda.mat``;
        otherwise the new transform becomes ``lda.mat``.

        See Also
        --------
        :func:`~montreal_forced_aligner.acoustic_modeling.lda.CalcLdaMlltFunction`
            Multiprocessing helper function for each job
        :meth:`.LdaTrainer.calc_lda_mllt_arguments`
            Job method for generating arguments for the helper function
        :kaldi_src:`est-mllt`
            Relevant Kaldi binary
        :kaldi_src:`gmm-transform-means`
            Relevant Kaldi binary
        :kaldi_src:`compose-transforms`
            Relevant Kaldi binary
        :kaldi_steps:`train_lda_mllt`
            Reference Kaldi script

        """
        self.log_info("Re-calculating LDA...")
        arguments = self.calc_lda_mllt_arguments()
        with tqdm.tqdm(
            total=self.num_current_utterances, disable=getattr(self, "quiet", False)
        ) as pbar:
            if self.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(arguments):
                    function = CalcLdaMlltFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if isinstance(result, Exception):
                            error_dict[getattr(result, "job_name", 0)] = result
                            continue
                        if stopped.stop_check():
                            continue
                    except Empty:
                        # Queue is idle: finish only once every worker reports done.
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    pbar.update(1)
                for p in procs:
                    p.join()
                if error_dict:
                    for v in error_dict.values():
                        raise v
            else:
                for args in arguments:
                    function = CalcLdaMlltFunction(args)
                    for _ in function.run():
                        pbar.update(1)

        log_path = os.path.join(
            self.working_log_directory, f"transform_means.{self.iteration}.log"
        )
        previous_mat_path = os.path.join(self.working_directory, "lda.mat")
        new_mat_path = os.path.join(self.working_directory, "lda_new.mat")
        composed_path = os.path.join(self.working_directory, "lda_composed.mat")
        with open(log_path, "a", encoding="utf8") as log_file:
            macc_list = []
            for x in arguments:
                macc_list.extend(x.macc_paths.values())
            subprocess.call(
                [thirdparty_binary("est-mllt"), new_mat_path] + macc_list,
                stderr=log_file,
                env=os.environ,
            )
            subprocess.call(
                [
                    thirdparty_binary("gmm-transform-means"),
                    new_mat_path,
                    self.model_path,
                    self.model_path,
                ],
                stderr=log_file,
                env=os.environ,
            )

            if os.path.exists(previous_mat_path):
                subprocess.call(
                    [
                        thirdparty_binary("compose-transforms"),
                        new_mat_path,
                        previous_mat_path,
                        composed_path,
                    ],
                    stderr=log_file,
                    env=os.environ,
                )
                # os.replace overwrites the destination atomically. If
                # compose-transforms did not produce composed_path, this raises
                # without first deleting the existing lda.mat.
                os.replace(composed_path, previous_mat_path)
            else:
                os.replace(new_mat_path, previous_mat_path)
示例#14
0
    def lda_acc_stats(self) -> None:
        """
        Accumulate LDA statistics for every job and estimate the LDA matrix.

        Any stale worker-side ``lda.mat`` is deleted first. The per-job
        accumulators are then passed to ``est-lda`` (with
        ``--dim=self.lda_dimension``), which writes ``lda.mat`` in the working
        directory. That file is finally copied to the worker's working directory.

        See Also
        --------
        :func:`~montreal_forced_aligner.acoustic_modeling.lda.LdaAccStatsFunction`
            Multiprocessing helper function for each job
        :meth:`.LdaTrainer.lda_acc_stats_arguments`
            Job method for generating arguments for the helper function
        :kaldi_src:`est-lda`
            Relevant Kaldi binary
        :kaldi_steps:`train_lda_mllt`
            Reference Kaldi script

        """
        worker_lda_path = os.path.join(self.worker.working_directory, "lda.mat")
        lda_path = os.path.join(self.working_directory, "lda.mat")
        if os.path.exists(worker_lda_path):
            os.remove(worker_lda_path)
        arguments = self.lda_acc_stats_arguments()
        with tqdm.tqdm(
            total=self.num_current_utterances, disable=getattr(self, "quiet", False)
        ) as pbar:
            if not self.use_mp:
                for job_args in arguments:
                    for done, errors in LdaAccStatsFunction(job_args).run():
                        pbar.update(done + errors)
            else:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                workers = []
                for job_index, job_args in enumerate(arguments):
                    worker = KaldiProcessWorker(
                        job_index, return_queue, LdaAccStatsFunction(job_args), stopped
                    )
                    workers.append(worker)
                    worker.start()
                while True:
                    try:
                        message = return_queue.get(timeout=1)
                    except Empty:
                        # Queue is idle: finish only once every worker reports done.
                        if all(w.finished.stop_check() for w in workers):
                            break
                        continue
                    if isinstance(message, Exception):
                        error_dict[getattr(message, "job_name", 0)] = message
                        continue
                    if stopped.stop_check():
                        continue
                    done, errors = message
                    pbar.update(done + errors)
                for worker in workers:
                    worker.join()
                if error_dict:
                    # Surface the first recorded worker failure.
                    raise next(iter(error_dict.values()))

        log_path = os.path.join(self.working_log_directory, "lda_est.log")
        acc_list = [
            acc_path
            for job_args in arguments
            for acc_path in job_args.acc_paths.values()
        ]
        with open(log_path, "w", encoding="utf8") as log_file:
            subprocess.Popen(
                [
                    thirdparty_binary("est-lda"),
                    f"--dim={self.lda_dimension}",
                    lda_path,
                    *acc_list,
                ],
                stderr=log_file,
                env=os.environ,
            ).communicate()
        shutil.copyfile(lda_path, worker_lda_path)
    def segment_vad(self) -> None:
        """
        Run segmentation based off of VAD.

        Each (utterance id, begin, end) segment produced by the VAD jobs becomes
        a new ``Utterance`` row with text ``"speech"``. The new row copies the
        channel, speaker and file of the utterance it was cut from. All source
        utterances are then deleted, and the new rows are bulk-inserted in a
        single commit.

        See Also
        --------
        :class:`~montreal_forced_aligner.segmenter.SegmentVadFunction`
            Multiprocessing helper function for each job
        segment_vad_arguments
            Job method for generating arguments for helper function
        """

        arguments = self.segment_vad_arguments()
        old_utts = set()
        new_utts = []

        with tqdm.tqdm(total=self.num_utterances,
                       disable=getattr(
                           self, "quiet",
                           False)) as pbar, self.session() as session:

            def add_segment(utt, begin, end) -> None:
                # Queue a new utterance spanning begin..end of ``utt`` and mark
                # ``utt`` itself for deletion.
                old_utts.add(utt)
                channel, speaker_id, file_id = (session.query(
                    Utterance.channel, Utterance.speaker_id,
                    Utterance.file_id).filter(
                        Utterance.id == utt).first())
                new_utts.append({
                    "begin": begin,
                    "end": end,
                    "text": "speech",
                    "speaker_id": speaker_id,
                    "file_id": file_id,
                    "oovs": "",
                    "normalized_text": "",
                    "normalized_text_int": "",
                    "features": "",
                    "in_subset": False,
                    "ignored": False,
                    "channel": channel,
                    "duration": end - begin,
                })
                pbar.update(1)

            if self.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(arguments):
                    function = SegmentVadFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                # Start polling only after every worker has been launched, so
                # that all jobs run concurrently rather than one after another.
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if isinstance(result, Exception):
                            error_dict[getattr(result, "job_name",
                                               0)] = result
                            continue
                        if stopped.stop_check():
                            continue
                    except Empty:
                        # Queue is idle: finish only once every worker reports done.
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    utt, begin, end = result
                    add_segment(utt, begin, end)
                for p in procs:
                    p.join()
                if error_dict:
                    for v in error_dict.values():
                        raise v
            else:
                for args in arguments:
                    function = SegmentVadFunction(args)
                    for utt, begin, end in function.run():
                        add_segment(utt, begin, end)
            session.query(Utterance).filter(
                Utterance.id.in_(old_utts)).delete()
            session.bulk_insert_mappings(Utterance,
                                         new_utts,
                                         return_defaults=False,
                                         render_nulls=True)
            session.commit()
class TextCorpusMixin(CorpusMixin):
    """
    Abstract mixin class for processing text corpora

    See Also
    --------
    :class:`~montreal_forced_aligner.corpus.base.CorpusMixin`
        For corpus parsing parameters
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _load_corpus_from_source_mp(self) -> None:
        """
        Load a corpus using multiprocessing

        Walks ``self.corpus_directory`` and queues every file identifier that has
        a lab or TextGrid transcription for ``self.num_jobs``
        ``CorpusProcessWorker`` workers. No audio path is passed (``wav_path``
        is ``None``).

        Parsed files returned by the workers are converted with
        ``generate_import_objects`` and written by ``_finalize_load``. Per-file
        decode/TextGrid errors are stored on the matching attributes when they
        exist. A returned ``"error"`` message rolls back the session and is
        re-raised.
        """
        if self.stopped is None:
            self.stopped = Stopped()
        sanitize_function = getattr(self, "sanitize_function", None)
        begin_time = time.time()
        job_queue = mp.Queue()
        return_queue = mp.Queue()
        error_dict = {}
        finished_adding = Stopped()
        procs = []
        for i in range(self.num_jobs):
            p = CorpusProcessWorker(
                i,
                job_queue,
                return_queue,
                self.stopped,
                finished_adding,
                self.speaker_characters,
                sanitize_function,
                sample_rate=0,
            )
            procs.append(p)
            p.start()
        import_data = DatabaseImportData()
        try:
            file_count = 0
            with tqdm.tqdm(total=1, disable=getattr(
                    self, "quiet", False)) as pbar, self.session() as session:
                for root, _, files in os.walk(self.corpus_directory,
                                              followlinks=True):
                    exts = find_exts(files)
                    relative_path = (root.replace(self.corpus_directory,
                                                  "").lstrip("/").lstrip("\\"))

                    if self.stopped.stop_check():
                        break
                    for file_name in exts.identifiers:
                        if self.stopped.stop_check():
                            break
                        wav_path = None
                        if file_name in exts.lab_files:
                            lab_name = exts.lab_files[file_name]
                            transcription_path = os.path.join(root, lab_name)

                        elif file_name in exts.textgrid_files:
                            tg_name = exts.textgrid_files[file_name]
                            transcription_path = os.path.join(root, tg_name)
                        else:
                            # Identifiers without a transcription are not loaded.
                            continue
                        job_queue.put((file_name, wav_path, transcription_path,
                                       relative_path))
                        file_count += 1
                        # The total is only known while walking, so grow it as we go.
                        pbar.total = file_count

                # Tell the workers no further jobs will be queued.
                finished_adding.stop()

                while True:
                    try:
                        file = return_queue.get(timeout=1)
                        if isinstance(file, tuple):
                            # Tuples are (error_type, payload) messages rather
                            # than parsed files.
                            error_type = file[0]
                            error = file[1]
                            if error_type == "error":
                                error_dict[error_type] = error
                            else:
                                if error_type not in error_dict:
                                    error_dict[error_type] = []
                                error_dict[error_type].append(error)
                            continue
                        if self.stopped.stop_check():
                            continue
                    except Empty:
                        # Queue is idle: finish only once every worker reports done.
                        for proc in procs:
                            if not proc.finished_processing.stop_check():
                                break
                        else:
                            break
                        continue
                    pbar.update(1)
                    import_data.add_objects(self.generate_import_objects(file))

                self.log_debug("Waiting for workers to finish...")
                for p in procs:
                    p.join()

                if "error" in error_dict:
                    session.rollback()
                    # NOTE(review): this indexes [1] into the stored payload, so it
                    # assumes workers send ("error", <pair whose second item is the
                    # exception>). Confirm against CorpusProcessWorker.
                    raise error_dict["error"][1]

                self._finalize_load(session, import_data)

                for k in ["decode_error_files", "textgrid_read_errors"]:
                    if hasattr(self, k):
                        if k in error_dict:
                            self.log_info(
                                "There were some issues with files in the corpus. "
                                "Please look at the log file or run the validator for more information."
                            )
                            self.log_debug(
                                f"{k} showed {len(error_dict[k])} errors:")
                            if k == "textgrid_read_errors":
                                getattr(self, k).update(error_dict[k])
                                for e in error_dict[k]:
                                    self.log_debug(f"{e.file_name}: {e.error}")
                            else:
                                # Entries may be error objects rather than strings;
                                # str.join only accepts strings.
                                self.log_debug(
                                    ", ".join(str(e) for e in error_dict[k]))
                                setattr(self, k, error_dict[k])

        except KeyboardInterrupt:
            self.log_info(
                "Detected ctrl-c, please wait a moment while we clean everything up..."
            )
            self.stopped.stop()
            finished_adding.stop()
            # job_queue is a multiprocessing.Queue, which has no join(). The loop
            # below drains return_queue until every worker reports finished.
            self.stopped.set_sigint_source()
            while True:
                try:
                    _ = return_queue.get(timeout=1)
                    if self.stopped.stop_check():
                        continue
                except Empty:
                    for proc in procs:
                        if not proc.finished_processing.stop_check():
                            break
                    else:
                        break
        finally:

            finished_adding.stop()
            for p in procs:
                p.join()
            if self.stopped.stop_check():
                self.log_info(
                    f"Stopped parsing early ({time.time() - begin_time} seconds)"
                )
                if self.stopped.source():
                    sys.exit(0)
            else:
                self.log_debug(
                    f"Parsed corpus directory with {self.num_jobs} jobs in {time.time() - begin_time} seconds"
                )

    def _load_corpus_from_source(self) -> None:
        """
        Load a corpus without using multiprocessing

        Parses every file identifier under ``self.corpus_directory`` that has a
        lab or TextGrid transcription with ``FileData.parse_file``, and writes
        the results through ``_finalize_load``. ``TextParseError`` and
        ``TextGridParseError`` are collected on ``self.decode_error_files`` and
        ``self.textgrid_read_errors`` instead of aborting the load.
        """
        begin_time = time.time()
        # NOTE(review): this replaces any Stopped object with a plain bool, so the
        # ``if self.stopped`` check below can never trigger. Confirm intent.
        self.stopped = False

        import_data = DatabaseImportData()
        sanitize_function = getattr(self, "sanitize_function", None)
        with self.session() as session:
            for root, _, files in os.walk(self.corpus_directory,
                                          followlinks=True):
                exts = find_exts(files)
                relative_path = root.replace(self.corpus_directory,
                                             "").lstrip("/").lstrip("\\")
                if self.stopped:
                    return
                for file_name in exts.identifiers:

                    wav_path = None
                    if file_name in exts.lab_files:
                        lab_name = exts.lab_files[file_name]
                        transcription_path = os.path.join(root, lab_name)
                    elif file_name in exts.textgrid_files:
                        tg_name = exts.textgrid_files[file_name]
                        transcription_path = os.path.join(root, tg_name)
                    else:
                        continue
                    try:
                        file = FileData.parse_file(
                            file_name,
                            wav_path,
                            transcription_path,
                            relative_path,
                            self.speaker_characters,
                            sanitize_function,
                        )
                        import_data.add_objects(
                            self.generate_import_objects(file))
                    except TextParseError as e:
                        self.decode_error_files.append(e)
                    except TextGridParseError as e:
                        self.textgrid_read_errors.append(e)
            self._finalize_load(session, import_data)
        if self.decode_error_files or self.textgrid_read_errors:
            self.log_info(
                "There were some issues with files in the corpus. "
                "Please look at the log file or run the validator for more information."
            )
            if self.decode_error_files:
                self.log_debug(
                    f"There were {len(self.decode_error_files)} errors decoding text files:"
                )
                # decode_error_files holds TextParseError objects (appended
                # above); str.join only accepts strings.
                self.log_debug(
                    ", ".join(str(e) for e in self.decode_error_files))
            if self.textgrid_read_errors:
                self.log_debug(
                    f"There were {len(self.textgrid_read_errors)} errors decoding reading TextGrid files:"
                )
                for e in self.textgrid_read_errors:
                    self.log_debug(f"{e.file_name}: {e.error}")

        self.log_debug(
            f"Parsed corpus directory in {time.time()-begin_time} seconds")
示例#17
0
    def mfcc(self) -> None:
        """
        Multiprocessing function that converts sound files into MFCCs.

        See :kaldi_docs:`feat` for an overview on feature generation in Kaldi.

        Once every job has finished, all utterances are first marked as ignored
        in the database. Each utterance listed in a job's ``feats.scp`` is then
        updated with its feature location and un-ignored, so utterances for
        which no features were written stay ignored.

        See Also
        --------
        :class:`~montreal_forced_aligner.corpus.features.MfccFunction`
            Multiprocessing helper function for each job
        :meth:`.AcousticCorpusMixin.mfcc_arguments`
            Job method for generating arguments for helper function
        :kaldi_steps:`make_mfcc`
            Reference Kaldi script
        """
        self.log_info("Generating MFCCs...")
        log_directory = os.path.join(self.split_directory, "log")
        os.makedirs(log_directory, exist_ok=True)
        arguments = self.mfcc_arguments()
        with tqdm.tqdm(total=self.num_utterances,
                       disable=getattr(self, "quiet", False)) as pbar:
            if self.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(arguments):
                    function = MfccFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if isinstance(result, Exception):
                            # Record the failure per job and keep draining
                            error_dict[getattr(result, "job_name", 0)] = result
                            continue
                        if stopped.stop_check():
                            continue
                    except Empty:
                        # Queue idle: stop polling only when every worker is finished
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    # presumably the per-batch utterance count that run() yields
                    # (see the single-process branch) — TODO confirm
                    pbar.update(result)
                for p in procs:
                    p.join()
                if error_dict:
                    for v in error_dict.values():
                        print(v)
                    self.dirty = True
                    sys.exit(1)
            else:
                for args in arguments:
                    function = MfccFunction(args)
                    for num_utterances in function.run():
                        pbar.update(num_utterances)
        with self.session() as session:
            update_mapping = []
            # Ignore everything first; utterances with features are re-enabled below
            session.query(Utterance).update({"ignored": True})
            for j in arguments:
                with open(j.feats_scp_path, "r", encoding="utf8") as f:
                    for line in f:
                        line = line.strip()
                        if line == "":
                            continue
                        # "<utterance key> <feature location>"; a separate name
                        # keeps the open file handle ``f`` from being shadowed
                        fields = line.split(maxsplit=1)
                        # The database id is the last "-"-separated part of the key
                        utt_id = int(fields[0].split("-")[-1])
                        feats = fields[1]
                        update_mapping.append({
                            "id": utt_id,
                            "features": feats,
                            "ignored": False
                        })
            session.bulk_update_mappings(Utterance, update_mapping)
            session.commit()
示例#18
0
    def calc_fmllr(self, iteration: Optional[int] = None) -> None:
        """
        Compute per-speaker fMLLR (feature-space Maximum Likelihood Linear
        Regression) adaptation transforms, one job per worker.

        Progress is reported per speaker. After completion the instance is
        switched out of speaker-independent mode.

        Parameters
        ----------
        iteration: int, optional
            Iteration number forwarded to :meth:`calc_fmllr_arguments`

        Raises
        ------
        Exception
            The first error reported by a worker when multiprocessing is used

        See Also
        --------
        :class:`~montreal_forced_aligner.corpus.features.CalcFmllrFunction`
            Multiprocessing helper function for each job
        :meth:`.AcousticCorpusMixin.calc_fmllr_arguments`
            Job method for generating arguments for the helper function
        :kaldi_steps:`align_fmllr`
            Reference Kaldi script
        :kaldi_steps:`train_sat`
            Reference Kaldi script
        """
        start_time = time.time()
        self.log_info("Calculating fMLLR for speaker adaptation...")

        job_args = self.calc_fmllr_arguments(iteration=iteration)
        with tqdm.tqdm(total=self.num_speakers,
                       disable=getattr(self, "quiet", False)) as progress:
            if not self.use_mp:
                for job_arg in job_args:
                    for _ in CalcFmllrFunction(job_arg).run():
                        progress.update(1)
            else:
                failures = {}
                results = mp.Queue()
                stop_flag = Stopped()
                workers = []
                for index, job_arg in enumerate(job_args):
                    worker = KaldiProcessWorker(index, results,
                                                CalcFmllrFunction(job_arg),
                                                stop_flag)
                    workers.append(worker)
                    worker.start()
                while True:
                    try:
                        item = results.get(timeout=1)
                    except Empty:
                        # Nothing queued: leave only once all workers are done
                        if all(w.finished.stop_check() for w in workers):
                            break
                        continue
                    if isinstance(item, Exception):
                        failures[getattr(item, "job_name", 0)] = item
                    elif not stop_flag.stop_check():
                        progress.update(1)
                for worker in workers:
                    worker.join()
                if failures:
                    raise next(iter(failures.values()))

        self.speaker_independent = False
        self.log_debug(f"Fmllr calculation took {time.time() - start_time}")

    def acc_global_stats(self) -> None:
        """
        Multiprocessing function that accumulates global GMM stats

        Runs ``AccGlobalStatsFunction`` for every job, then pipes the per-job
        accumulators through ``gmm-global-sum-accs`` into ``gmm-global-est``,
        which re-estimates ``self.model_path`` into ``self.next_model_path``.
        The per-job accumulator files are deleted afterwards unless
        ``self.debug`` is set.

        Raises
        ------
        Exception
            The first error reported by a worker when multiprocessing is used

        See Also
        --------
        :func:`~montreal_forced_aligner.ivector.trainer.AccGlobalStatsFunction`
            Multiprocessing helper function for each job
        :meth:`.DubmTrainer.acc_global_stats_arguments`
            Job method for generating arguments for the helper function
        :kaldi_src:`gmm-global-sum-accs`
            Relevant Kaldi binary
        :kaldi_steps:`train_diag_ubm`
            Reference Kaldi script

        """
        begin = time.time()
        self.log_info("Accumulating global stats...")
        arguments = self.acc_global_stats_arguments()

        if self.use_mp:
            error_dict = {}
            return_queue = mp.Queue()
            stopped = Stopped()
            procs = []
            for i, args in enumerate(arguments):
                function = AccGlobalStatsFunction(args)
                p = KaldiProcessWorker(i, return_queue, function, stopped)
                procs.append(p)
                p.start()
            while True:
                try:
                    result = return_queue.get(timeout=1)
                    if isinstance(result, Exception):
                        error_dict[getattr(result, "job_name", 0)] = result
                        continue
                    if stopped.stop_check():
                        continue
                except queue.Empty:
                    # Queue idle: stop polling only when every worker is finished
                    for proc in procs:
                        if not proc.finished.stop_check():
                            break
                    else:
                        break
                    continue
            for p in procs:
                p.join()
            if error_dict:
                for v in error_dict.values():
                    raise v

        else:
            self.log_debug("Not using multiprocessing...")
            for args in arguments:
                function = AccGlobalStatsFunction(args)
                for _ in function.run():
                    pass

        self.log_debug(f"Accumulating stats took {time.time() - begin}")

        # Don't remove low-count Gaussians till the last tier,
        # or gselect info won't be valid anymore
        if self.iteration < self.num_iterations:
            opt = "--remove-low-count-gaussians=false"
        else:
            opt = f"--remove-low-count-gaussians={self.remove_low_count_gaussians}"
        log_path = os.path.join(self.working_log_directory,
                                f"update.{self.iteration}.log")
        with open(log_path, "w") as log_file:
            acc_files = [j.acc_path for j in arguments]
            sum_proc = subprocess.Popen(
                [thirdparty_binary("gmm-global-sum-accs"), "-"] + acc_files,
                stderr=log_file,
                stdout=subprocess.PIPE,
                env=os.environ,
            )
            gmm_global_est_proc = subprocess.Popen(
                [
                    thirdparty_binary("gmm-global-est"),
                    opt,
                    f"--min-gaussian-weight={self.min_gaussian_weight}",
                    self.model_path,
                    "-",
                    self.next_model_path,
                ],
                stderr=log_file,
                stdin=sum_proc.stdout,
                env=os.environ,
            )
            # Release the parent's copy of the pipe so the summing process sees
            # a closed reader if the estimator exits early, then reap both
            # children instead of leaving the first one unwaited
            sum_proc.stdout.close()
            gmm_global_est_proc.communicate()
            sum_proc.wait()
            # Clean up
            if not self.debug:
                for acc_path in acc_files:
                    os.remove(acc_path)
示例#20
0
class AcousticCorpusMixin(CorpusMixin, FeatureConfigMixin, metaclass=ABCMeta):
    """
    Mixin class for acoustic corpora

    Parameters
    ----------
    audio_directory: str
        Extra directory to look for audio files

    See Also
    --------
    :class:`~montreal_forced_aligner.corpus.base.CorpusMixin`
        For corpus parsing parameters
    :class:`~montreal_forced_aligner.corpus.features.FeatureConfigMixin`
        For feature generation parameters

    Attributes
    ----------
    sound_file_errors: list[str]
        List of sound files with errors in loading
    transcriptions_without_wavs: list[str]
        List of text files without sound files
    no_transcription_files: list[str]
        List of sound files without transcription files
    stopped: Stopped
        Stop check for loading the corpus
    """
    def __init__(self, audio_directory: Optional[str] = None, **kwargs):
        """
        Set up acoustic-corpus state on top of the parent mixins.

        Parameters
        ----------
        audio_directory: str, optional
            Extra directory to look for audio files
        **kwargs
            Forwarded unchanged to the parent initializers
        """
        super().__init__(**kwargs)
        self.audio_directory = audio_directory
        # Independent lists collecting problem files found while loading
        self.sound_file_errors, self.transcriptions_without_wavs, self.no_transcription_files = [], [], []
        self.stopped = Stopped()
        # Processing-state flags; inspect_database copies them from the Corpus row
        self.features_generated = self.alignment_done = self.transcription_done = False
        self.has_reference_alignments = self.alignment_evaluation_done = False

    def inspect_database(self) -> None:
        """
        Make sure the database and its ``Corpus`` row exist.

        The database is initialized when ``self.db_path`` is missing on disk. If a
        ``Corpus`` row already exists, its state flags are copied onto this
        instance; otherwise a new row named ``self.data_source_identifier`` is
        added and committed.
        """
        if not os.path.exists(self.db_path):
            self.initialize_database()
        with self.session() as session:
            corpus = session.query(Corpus).first()
            if not corpus:
                session.add(Corpus(name=self.data_source_identifier))
                session.commit()
                return
            for state_flag in (
                    "imported",
                    "features_generated",
                    "alignment_done",
                    "transcription_done",
                    "has_reference_alignments",
                    "alignment_evaluation_done",
            ):
                setattr(self, state_flag, getattr(corpus, state_flag))

    def load_reference_alignments(self, reference_directory: str) -> None:
        """
        Load reference alignments to use in alignment evaluation from a directory

        Every ``.TextGrid`` under ``reference_directory`` whose base name matches a
        file already in the database is parsed. Its intervals are attached to that
        file's utterances, bulk-inserted into the ``ReferencePhoneInterval`` table,
        and the corpus is flagged as having reference alignments, both in the
        database and on this instance. Later calls are then no-ops.

        Parameters
        ----------
        reference_directory: str
            Directory containing reference alignments. The name of the directory
            directly containing each TextGrid is passed to the parser as the
            speaker name.

        """
        if self.has_reference_alignments:
            self.log_info("Reference alignments already loaded!")
            return
        self.log_info("Loading reference files...")
        indices = []
        jobs = []
        reference_intervals = []
        with tqdm.tqdm(total=self.num_files,
                       disable=getattr(self, "quiet", False)) as pbar, Session(
                           self.db_engine, autoflush=False) as session:
            for root, _, files in os.walk(reference_directory,
                                          followlinks=True):
                root_speaker = os.path.basename(root)
                for f in files:
                    if not f.endswith(".TextGrid"):
                        continue
                    file_name = f.replace(".TextGrid", "")
                    file_id = session.query(
                        File.id).filter_by(name=file_name).scalar()
                    if not file_id:
                        # TextGrid with no matching corpus file
                        continue
                    if self.use_mp:
                        # Parsed in bulk by the pool below
                        indices.append(file_id)
                        jobs.append((os.path.join(root, f), root_speaker))
                        continue
                    intervals = parse_aligned_textgrid(
                        os.path.join(root, f), root_speaker)
                    reference_intervals.extend(
                        self._collect_reference_intervals(
                            intervals,
                            self._reference_utterances(session, file_id)))
                    pbar.update(1)
            if self.use_mp:
                with mp.Pool(self.num_jobs) as pool:
                    gen = pool.starmap(parse_aligned_textgrid, jobs)
                    for file_id, intervals in zip(indices, gen):
                        pbar.update(1)
                        reference_intervals.extend(
                            self._collect_reference_intervals(
                                intervals,
                                self._reference_utterances(session, file_id)))
            # Executing an INSERT with an empty parameter list would not be a
            # no-op, so only insert when there is something to insert
            if reference_intervals:
                with session.bind.begin() as conn:
                    conn.execute(
                        sqlalchemy.insert(ReferencePhoneInterval.__table__),
                        reference_intervals)
                    session.commit()
            session.query(Corpus).update({"has_reference_alignments": True})
            session.commit()
        # Keep the in-memory flag in sync with the database, so a repeated call
        # does not insert the same intervals again
        self.has_reference_alignments = True

    @staticmethod
    def _reference_utterances(session, file_id):
        """Query ``(utterance id, speaker name, utterance end)`` rows of one file, ordered by end."""
        # Ordering by end is required by the greedy assignment in
        # _collect_reference_intervals; without it the row order is unspecified
        return (session.query(Utterance.id, Speaker.name,
                              Utterance.end).join(Utterance.speaker).filter(
                                  Utterance.file_id == file_id).order_by(
                                      Utterance.end))

    @staticmethod
    def _collect_reference_intervals(intervals, utterances) -> List[Dict[str, typing.Any]]:
        """Assign each speaker's intervals, in order, to the first utterance ending at or after them."""
        collected = []
        # Index of the next unassigned interval per speaker (avoids O(n) pop(0))
        positions = {}
        for u_id, speaker_name, end in utterances:
            if speaker_name not in intervals:
                continue
            speaker_intervals = intervals[speaker_name]
            position = positions.get(speaker_name, 0)
            while (position < len(speaker_intervals)
                   and speaker_intervals[position].end <= end):
                interval = speaker_intervals[position]
                collected.append({
                    "begin": interval.begin,
                    "end": interval.end,
                    "label": interval.label,
                    "utterance_id": u_id,
                })
                position += 1
            positions[speaker_name] = position
        return collected

    def load_corpus(self) -> None:
        """
        Import the corpus into the database, split it into jobs and make features.
        """
        # Database setup and raw corpus import
        self.initialize_database()
        self._load_corpus()

        # Job partitioning and split directory, then feature generation
        # (generate_features returns immediately if features already exist)
        self.initialize_jobs()
        self.create_corpus_split()
        self.generate_features()

    def generate_features(self, compute_cmvn: bool = True) -> None:
        """
        Produce corpus features unless they are already marked as generated.

        Computes the base features (MFCCs when ``feature_type`` is ``"mfcc"``)
        and combines them. Optionally computes CMVN. Records completion on the
        instance and in the ``Corpus`` table, then rebuilds the corpus split.

        Parameters
        ----------
        compute_cmvn: bool
            Whether CMVN statistics are computed as well, defaults to True
        """
        if not self.features_generated:
            self.log_info(f"Generating base features ({self.feature_type})...")
            if self.feature_type == "mfcc":
                self.mfcc()
            self.combine_feats()
            if compute_cmvn:
                self.log_info("Calculating CMVN...")
                self.calc_cmvn()

            # Persist completion in memory and in the database
            self.features_generated = True
            with self.session() as session:
                session.query(Corpus).update({"features_generated": True})
                session.commit()
            # With features_generated now True this builds the featured split
            self.create_corpus_split()

    def create_corpus_split(self) -> None:
        """
        Write the per-job split directory.

        With features available, the parent implementation is used. Otherwise
        each job writes only what feature generation needs.
        """
        if self.features_generated:
            self.log_info("Creating corpus split with features...")
            super().create_corpus_split()
            return

        self.log_info("Creating corpus split for feature generation...")
        split_directory = self.split_directory
        os.makedirs(os.path.join(split_directory, "log"), exist_ok=True)
        with self.session() as session:
            for job in self.jobs:
                job.output_for_features(split_directory, session)

    def construct_base_feature_string(self, all_feats: bool = False) -> str:
        """
        Construct the base feature string independent of job name

        Used in initialization of MonophoneTrainer (to get dimension size) and IvectorTrainer (uses all feats)

        Parameters
        ----------
        all_feats: bool
            Flag for whether all features across all jobs should be taken into
            account. If True, the corpus-wide files in ``base_data_directory``
            are used and deltas are always appended. Otherwise the first job's
            files are used (its first dictionary's files when it has
            dictionaries), with deltas only when ``uses_deltas`` is set.

        Returns
        -------
        str
            Base feature string
        """
        if all_feats:
            feat_path = os.path.join(self.base_data_directory, "feats.scp")
            utt2spk_path = os.path.join(self.base_data_directory,
                                        "utt2spk.scp")
            cmvn_path = os.path.join(self.base_data_directory, "cmvn.scp")
            feats = f'ark,s,cs:apply-cmvn --utt2spk=ark:"{utt2spk_path}" scp:"{cmvn_path}" scp:"{feat_path}" ark:- |'
            feats += " add-deltas ark:- ark:- |"
            return feats
        # Only the job-specific path needs a job; the corpus-wide branch above
        # works even when no jobs exist
        j = self.jobs[0]
        if j.dictionary_ids:
            # Only the first dictionary's files are used
            dict_id = next(iter(j.dictionary_ids))
            utt2spk_path = j.construct_path_dictionary(
                self.data_directory, "utt2spk", "scp")[dict_id]
            cmvn_path = j.construct_path_dictionary(self.data_directory,
                                                    "cmvn", "scp")[dict_id]
            feat_path = j.construct_path_dictionary(self.data_directory,
                                                    "feats", "scp")[dict_id]
        else:
            utt2spk_path = j.construct_path(self.data_directory, "utt2spk",
                                            "scp")
            cmvn_path = j.construct_path(self.data_directory, "cmvn", "scp")
            feat_path = j.construct_path(self.data_directory, "feats", "scp")
        feats = f'ark,s,cs:apply-cmvn --utt2spk=ark:"{utt2spk_path}" scp:"{cmvn_path}" scp:"{feat_path}" ark:- |'
        if self.uses_deltas:
            feats += " add-deltas ark:- ark:- |"
        return feats

    def construct_feature_proc_strings(
        self,
        speaker_independent: bool = False,
    ) -> typing.Union[List[Dict[str, str]], List[str]]:
        """
        Constructs a feature processing string to supply to Kaldi binaries, taking into account corpus features and the
        current working directory of the aligner (whether fMLLR or LDA transforms should be used, etc).

        Parameters
        ----------
        speaker_independent: bool
            Flag for whether features should be speaker-independent regardless of the presence of fMLLR transforms

        Returns
        -------
        list[dict[str, str] | str]
            One entry per job in ``self.jobs``. For a job with dictionaries this
            is a mapping from dictionary id to feature string; for a job without
            dictionaries it is a single feature string.
        """
        strings = []
        for j in self.jobs:
            lda_mat_path = None
            fmllrs = {}
            if self.working_directory is not None:
                lda_mat_path = os.path.join(self.working_directory, "lda.mat")
                if not os.path.exists(lda_mat_path):
                    lda_mat_path = None

                fmllrs = j.construct_path_dictionary(self.working_directory,
                                                     "trans", "ark")
            utt2spks = j.construct_path_dictionary(self.data_directory,
                                                   "utt2spk", "scp")
            cmvns = j.construct_path_dictionary(self.data_directory, "cmvn",
                                                "scp")
            features = j.construct_path_dictionary(self.data_directory,
                                                   "feats", "scp")
            vads = j.construct_path_dictionary(self.data_directory, "vad",
                                               "scp")
            if not j.dictionary_ids:
                # Job without dictionaries: one plain CMVN (+ deltas) string
                utt2spk_path = j.construct_path(self.data_directory, "utt2spk",
                                                "scp")
                cmvn_path = j.construct_path(self.data_directory, "cmvn",
                                             "scp")
                feat_path = j.construct_path(self.data_directory, "feats",
                                             "scp")
                feats = f'ark,s,cs:apply-cmvn --utt2spk=ark:"{utt2spk_path}" scp:"{cmvn_path}" scp:"{feat_path}" ark:- |'
                if self.uses_deltas:
                    feats += " add-deltas ark:- ark:- |"

                strings.append(feats)
                continue

            feat_strings = {}
            for dict_id in j.dictionary_ids:
                feat_path = features[dict_id]
                cmvn_path = cmvns[dict_id]
                utt2spk_path = utt2spks[dict_id]
                # fMLLR transforms are only used when they exist on disk
                fmllr_trans_path = fmllrs.get(dict_id)
                if fmllr_trans_path is not None and not os.path.exists(
                        fmllr_trans_path):
                    fmllr_trans_path = None
                vad_path = vads[dict_id]
                if self.uses_voiced:
                    feats = f'ark,s,cs:add-deltas scp:"{feat_path}" ark:- |'
                    if self.uses_cmvn:
                        feats += " apply-cmvn-sliding --norm-vars=false --center=true --cmn-window=300 ark:- ark:- |"
                    feats += f' select-voiced-frames ark:- scp,s,cs:"{vad_path}" ark:- |'
                elif not os.path.exists(cmvn_path) and self.uses_cmvn:
                    # No CMVN stats on disk: fall back to sliding-window CMVN
                    feats = f'ark,s,cs:add-deltas scp:"{feat_path}" ark:- |'
                    feats += " apply-cmvn-sliding --norm-vars=false --center=true --cmn-window=300 ark:- ark:- |"
                else:
                    feats = f'ark,s,cs:apply-cmvn --utt2spk=ark:"{utt2spk_path}" scp:"{cmvn_path}" scp:"{feat_path}" ark:- |'
                    if lda_mat_path is not None:
                        feats += f" splice-feats --left-context={self.splice_left_context} --right-context={self.splice_right_context} ark:- ark:- |"
                        feats += f' transform-feats "{lda_mat_path}" ark:- ark:- |'
                    elif self.uses_splices:
                        feats += f" splice-feats --left-context={self.splice_left_context} --right-context={self.splice_right_context} ark:- ark:- |"
                    elif self.uses_deltas:
                        feats += " add-deltas ark:- ark:- |"
                    if fmllr_trans_path is not None and not (
                            self.speaker_independent or speaker_independent):
                        feats += f' transform-feats --utt2spk=ark:"{utt2spk_path}" ark:"{fmllr_trans_path}" ark:- ark:- |'
                feat_strings[dict_id] = feats
            strings.append(feat_strings)
        return strings

    def compute_vad_arguments(self) -> List[VadArguments]:
        """
        Build per-job inputs for :class:`~montreal_forced_aligner.corpus.features.ComputeVadFunction`

        Returns
        -------
        list[:class:`~montreal_forced_aligner.corpus.features.VadArguments`]
            One entry for every job that has data, in job order
        """
        vad_arguments = []
        for job in self.jobs:
            if not job.has_data:
                continue
            vad_arguments.append(
                VadArguments(
                    job.name,
                    # NOTE(review): the sibling *_arguments methods pass db_path
                    # in this position; confirm VadArguments expects an engine
                    getattr(self, "db_engine", ""),
                    os.path.join(self.split_directory, "log",
                                 f"compute_vad.{job.name}.log"),
                    job.construct_path(self.split_directory, "feats", "scp"),
                    job.construct_path(self.split_directory, "vad", "scp"),
                    self.vad_options,
                ))
        return vad_arguments

    def calc_fmllr_arguments(self,
                             iteration: Optional[int] = None
                             ) -> List[CalcFmllrArguments]:
        """
        Build per-job inputs for :class:`~montreal_forced_aligner.corpus.features.CalcFmllrFunction`

        Parameters
        ----------
        iteration: int, optional
            When given, embedded in each job's log file name
            (``calc_fmllr.<iteration>.<job>.log``)

        Returns
        -------
        list[:class:`~montreal_forced_aligner.corpus.features.CalcFmllrArguments`]
            One entry for every job that has data, in job order
        """
        feature_strings = self.construct_feature_proc_strings()
        log_prefix = "calc_fmllr" if iteration is None else f"calc_fmllr.{iteration}"
        fmllr_arguments = []
        for job in self.jobs:
            if not job.has_data:
                continue
            fmllr_arguments.append(
                CalcFmllrArguments(
                    job.name,
                    getattr(self, "db_path", ""),
                    os.path.join(self.working_log_directory,
                                 f"{log_prefix}.{job.name}.log"),
                    job.dictionary_ids,
                    feature_strings[job.name],
                    job.construct_path_dictionary(self.working_directory,
                                                  "ali", "ark"),
                    self.alignment_model_path,
                    self.model_path,
                    job.construct_path_dictionary(self.data_directory,
                                                  "spk2utt", "scp"),
                    job.construct_path_dictionary(self.working_directory,
                                                  "trans", "ark"),
                    self.fmllr_options,
                ))
        return fmllr_arguments

    def mfcc_arguments(self) -> List[MfccArguments]:
        """
        Build per-job inputs for :class:`~montreal_forced_aligner.corpus.features.MfccFunction`

        Returns
        -------
        list[:class:`~montreal_forced_aligner.corpus.features.MfccArguments`]
            One entry for every job that has data, in job order
        """
        job_arguments = []
        for job in self.jobs:
            if not job.has_data:
                continue
            job_arguments.append(
                MfccArguments(
                    job.name,
                    self.db_path,
                    os.path.join(self.split_directory, "log",
                                 f"make_mfcc.{job.name}.log"),
                    job.construct_path(self.split_directory, "wav", "scp"),
                    job.construct_path(self.split_directory, "segments",
                                       "scp"),
                    job.construct_path(self.split_directory, "feats", "scp"),
                    self.mfcc_options,
                    self.pitch_options,
                ))
        return job_arguments

    def mfcc(self) -> None:
        """
        Multiprocessing function that converts sound files into MFCCs.

        See :kaldi_docs:`feat` for an overview on feature generation in Kaldi.

        Once every job has finished, all utterances are first marked as ignored
        in the database. Each utterance listed in a job's ``feats.scp`` is then
        updated with its feature location and un-ignored, so utterances for
        which no features were written stay ignored.

        See Also
        --------
        :class:`~montreal_forced_aligner.corpus.features.MfccFunction`
            Multiprocessing helper function for each job
        :meth:`.AcousticCorpusMixin.mfcc_arguments`
            Job method for generating arguments for helper function
        :kaldi_steps:`make_mfcc`
            Reference Kaldi script
        """
        self.log_info("Generating MFCCs...")
        log_directory = os.path.join(self.split_directory, "log")
        os.makedirs(log_directory, exist_ok=True)
        arguments = self.mfcc_arguments()
        with tqdm.tqdm(total=self.num_utterances,
                       disable=getattr(self, "quiet", False)) as pbar:
            if self.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(arguments):
                    function = MfccFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if isinstance(result, Exception):
                            # Record the failure per job and keep draining
                            error_dict[getattr(result, "job_name", 0)] = result
                            continue
                        if stopped.stop_check():
                            continue
                    except Empty:
                        # Queue idle: stop polling only when every worker is finished
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    # presumably the per-batch utterance count that run() yields
                    # (see the single-process branch) — TODO confirm
                    pbar.update(result)
                for p in procs:
                    p.join()
                if error_dict:
                    for v in error_dict.values():
                        print(v)
                    self.dirty = True
                    sys.exit(1)
            else:
                for args in arguments:
                    function = MfccFunction(args)
                    for num_utterances in function.run():
                        pbar.update(num_utterances)
        with self.session() as session:
            update_mapping = []
            # Ignore everything first; utterances with features are re-enabled below
            session.query(Utterance).update({"ignored": True})
            for j in arguments:
                with open(j.feats_scp_path, "r", encoding="utf8") as f:
                    for line in f:
                        line = line.strip()
                        if line == "":
                            continue
                        # "<utterance key> <feature location>"; a separate name
                        # keeps the open file handle ``f`` from being shadowed
                        fields = line.split(maxsplit=1)
                        # The database id is the last "-"-separated part of the key
                        utt_id = int(fields[0].split("-")[-1])
                        feats = fields[1]
                        update_mapping.append({
                            "id": utt_id,
                            "features": feats,
                            "ignored": False
                        })
            session.bulk_update_mappings(Utterance, update_mapping)
            session.commit()

    def calc_cmvn(self) -> None:
        """
        Compute per-speaker CMVN statistics and store them on the speakers.

        ``feats.scp`` and ``spk2utt.scp`` are (re)written into the corpus output
        directory, ``compute-cmvn-stats`` is run over them (its stderr goes to
        ``cmvn.log`` in the features log directory), and every entry of the
        resulting ``cmvn.scp`` is bulk-written to the ``cmvn`` column of the
        matching :class:`Speaker` row.

        See Also
        --------
        :kaldi_src:`compute-cmvn-stats`
            Relevant Kaldi binary
        """
        self._write_feats()
        self._write_spk2utt()
        out_dir = self.corpus_output_directory
        spk2utt_path = os.path.join(out_dir, "spk2utt.scp")
        feats_path = os.path.join(out_dir, "feats.scp")
        ark_path = os.path.join(out_dir, "cmvn.ark")
        scp_path = os.path.join(out_dir, "cmvn.scp")
        cmvn_log = os.path.join(self.features_log_directory, "cmvn.log")
        with open(cmvn_log, "w") as log_handle:
            command = [
                thirdparty_binary("compute-cmvn-stats"),
                f"--spk2utt=ark:{spk2utt_path}",
                f"scp:{feats_path}",
                f"ark,scp:{ark_path},{scp_path}",
            ]
            subprocess.call(command, stderr=log_handle, env=os.environ)
        with self.session() as session:
            # scp values may come back split into tokens; store them as one string
            speaker_rows = [
                {
                    "id": int(speaker_id),
                    "cmvn": " ".join(value) if isinstance(value, list) else value,
                }
                for speaker_id, value in load_scp(scp_path).items()
            ]
            session.bulk_update_mappings(Speaker, speaker_rows)
            session.commit()

    def calc_fmllr(self, iteration: Optional[int] = None) -> None:
        """
        Estimate fMLLR speaker-adaptation transforms for every job.

        Runs :class:`CalcFmllrFunction` once per job argument set, in worker
        processes when ``use_mp`` is set and serially otherwise. Afterwards the
        object is marked as no longer speaker independent.

        Parameters
        ----------
        iteration: int, optional
            Forwarded to :meth:`calc_fmllr_arguments`

        Raises
        ------
        Exception
            The first error reported by a worker process, after all workers
            have been joined

        See Also
        --------
        :class:`~montreal_forced_aligner.corpus.features.CalcFmllrFunction`
            Multiprocessing helper function for each job
        :meth:`.AcousticCorpusMixin.calc_fmllr_arguments`
            Job method for generating arguments for the helper function
        :kaldi_steps:`align_fmllr`
            Reference Kaldi script
        :kaldi_steps:`train_sat`
            Reference Kaldi script
        """
        start_time = time.time()
        self.log_info("Calculating fMLLR for speaker adaptation...")

        job_arguments = self.calc_fmllr_arguments(iteration=iteration)
        with tqdm.tqdm(total=self.num_speakers,
                       disable=getattr(self, "quiet", False)) as progress:
            if not self.use_mp:
                for job_args in job_arguments:
                    for _ in CalcFmllrFunction(job_args).run():
                        progress.update(1)
            else:
                failures = {}
                results = mp.Queue()
                stop_flag = Stopped()
                workers = []
                for job_index, job_args in enumerate(job_arguments):
                    worker = KaldiProcessWorker(
                        job_index, results, CalcFmllrFunction(job_args), stop_flag
                    )
                    workers.append(worker)
                    worker.start()
                while True:
                    try:
                        item = results.get(timeout=1)
                    except Empty:
                        # Only stop polling once every worker reports completion
                        if all(w.finished.stop_check() for w in workers):
                            break
                        continue
                    if isinstance(item, Exception):
                        failures[getattr(item, "job_name", 0)] = item
                        continue
                    if stop_flag.stop_check():
                        continue
                    progress.update(1)
                for worker in workers:
                    worker.join()
                for failure in failures.values():
                    raise failure

        self.speaker_independent = False
        self.log_debug(f"Fmllr calculation took {time.time() - start_time}")

    def compute_vad(self) -> None:
        """
        Compute Voice Activity Detection features over the corpus

        Does nothing beyond logging if ``vad.0.scp`` already exists in the
        split directory. Otherwise runs :class:`ComputeVadFunction` for every
        job, in worker processes when ``use_mp`` is set and serially otherwise.

        Raises
        ------
        Exception
            The first error reported by a worker process, re-raised after all
            workers have been joined

        See Also
        --------
        :class:`~montreal_forced_aligner.corpus.features.ComputeVadFunction`
            Multiprocessing helper function for each job
        :meth:`.AcousticCorpusMixin.compute_vad_arguments`
            Job method for generating arguments for helper function
        """
        if os.path.exists(os.path.join(self.split_directory, "vad.0.scp")):
            self.log_info("VAD already computed, skipping!")
            return
        begin = time.time()
        self.log_info("Computing VAD...")

        arguments = self.compute_vad_arguments()
        # Progress arrives as (done, no_feats, unvoiced) triples, presumably the
        # utterance counts Kaldi's compute-vad summarises (processed / empty /
        # unvoiced utterances), so the bar is sized by utterances, not speakers.
        with tqdm.tqdm(total=self.num_utterances,
                       disable=getattr(self, "quiet", False)) as pbar:
            if self.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(arguments):
                    function = ComputeVadFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                    except Empty:
                        if all(proc.finished.stop_check() for proc in procs):
                            break
                        continue
                    # Collect every worker failure (not only KaldiProcessingError);
                    # otherwise an arbitrary exception would reach the tuple
                    # unpacking below and crash with an unrelated TypeError.
                    if isinstance(result, Exception):
                        error_dict[getattr(result, "job_name", 0)] = result
                        continue
                    if stopped.stop_check():
                        continue
                    done, no_feats, unvoiced = result
                    pbar.update(done + no_feats + unvoiced)
                for p in procs:
                    p.join()
                for v in error_dict.values():
                    raise v
            else:
                for args in arguments:
                    function = ComputeVadFunction(args)
                    for done, no_feats, unvoiced in function.run():
                        pbar.update(done + no_feats + unvoiced)
        self.log_debug(f"VAD computation took {time.time() - begin}")

    def combine_feats(self) -> None:
        """
        Report utterances that feature generation marked as ignored.

        Every ignored utterance is written to the debug log with its sound file
        path, speaker name, begin/end times and text. If at least one exists, a
        single warning with the total count is logged as well.
        """
        with self.session() as session:
            ignored = (
                session.query(
                    SoundFile.sound_file_path,
                    Speaker.name,
                    Utterance.begin,
                    Utterance.end,
                    Utterance.text,
                )
                .join(Utterance.speaker)
                .join(Utterance.file)
                .join(File.sound_file)
                .filter(Utterance.ignored == True)  # noqa
            )
            num_ignored = 0
            for num_ignored, (path, speaker, start, stop, transcript) in enumerate(
                    ignored, start=1):
                for message in (
                        f"  - Ignored File: {path}",
                        f"    - Speaker: {speaker}",
                        f"    - Begin: {start}",
                        f"    - End: {stop}",
                        f"    - Text: {transcript}",
                ):
                    self.log_debug(message)
            if num_ignored:
                self.log_warning(
                    f"There were {num_ignored} utterances ignored due to an issue in feature generation, see the log file for full "
                    "details or run `mfa validate` on the corpus.")

    def _write_feats(self) -> None:
        """
        Write ``feats.scp`` for Kaldi into the corpus output directory.

        Emits one ``<kaldi_id> <features>`` line per non-ignored utterance,
        ordered by Kaldi id.
        """
        scp_path = os.path.join(self.corpus_output_directory, "feats.scp")
        with self.session() as session, open(scp_path, "w",
                                             encoding="utf8") as scp_file:
            rows = (
                session.query(Utterance.kaldi_id, Utterance.features)
                .filter_by(ignored=False)
                .order_by(Utterance.kaldi_id)
            )
            scp_file.writelines(
                f"{kaldi_id} {feature_path}\n" for kaldi_id, feature_path in rows
            )

    def get_feat_dim(self) -> int:
        """
        Calculate the feature dimension for the corpus

        Pipes the first utterance of the base feature string through
        ``subset-feats --n=1`` into ``feat-to-dim``. Stderr of both programs is
        written to ``feat-to-dim.log`` in the features log directory.

        Returns
        -------
        int
            Dimension of feature vectors

        Raises
        ------
        ValueError
            If ``feat-to-dim`` did not print an integer
        """
        feature_string = self.construct_base_feature_string()
        log_path = os.path.join(self.features_log_directory, "feat-to-dim.log")
        with open(log_path, "w", encoding="utf8") as log_file:
            subset_proc = subprocess.Popen(
                [
                    thirdparty_binary("subset-feats"),
                    "--n=1",
                    feature_string,
                    "ark:-",
                ],
                stderr=log_file,
                stdout=subprocess.PIPE,
            )
            dim_proc = subprocess.Popen(
                [thirdparty_binary("feat-to-dim"), "ark:-", "-"],
                stdin=subset_proc.stdout,
                stdout=subprocess.PIPE,
                stderr=log_file,
            )
            # Drop the parent's copy of the pipe so subset-feats receives SIGPIPE
            # if feat-to-dim exits early.
            subset_proc.stdout.close()
            stdout, _ = dim_proc.communicate()
            # Reap the producer so it does not linger as a zombie process.
            subset_proc.wait()
        output = stdout.decode("utf8").strip()
        try:
            return int(output)
        except ValueError as err:
            raise ValueError(
                f"Could not determine feature dimension from feat-to-dim output "
                f"{output!r}; see {log_path}"
            ) from err

    def _load_corpus_from_source_mp(self) -> None:
        """
        Load a corpus using multiprocessing

        An :class:`AcousticDirectoryParser` process walks the corpus directory
        and feeds file jobs into a job queue. ``num_jobs``
        :class:`CorpusProcessWorker` processes parse those jobs and push their
        results onto a return queue. The main process turns every parsed file
        into database import objects and finalizes the import in one session.

        Worker results that are ``(error_type, error)`` tuples are collected
        instead of imported. A fatal ``"error"`` aborts the load. The per-file
        error kinds (``sound_file_errors``, ``decode_error_files``,
        ``textgrid_read_errors``) are logged and stored on the matching
        attributes when this object has them.

        Raises
        ------
        Exception
            Any error raised while loading (including a fatal worker error) is
            re-raised once the child processes have been stopped and drained.
            On ctrl-c the interpreter exits via :func:`sys.exit` after cleanup.
        """
        begin_time = time.time()
        job_queue = mp.Queue()
        return_queue = mp.Queue()
        finished_adding = Stopped()
        # Stop flag handed to the parser and worker processes
        stopped = Stopped()
        file_counts = Counter()
        sanitize_function = getattr(self, "sanitize_function", None)
        error_dict = {}
        procs = []
        # NOTE(review): presumably drops pooled DB connections so the child
        # processes started below do not inherit them -- confirm
        self.db_engine.dispose()
        parser = AcousticDirectoryParser(
            self.corpus_directory,
            job_queue,
            self.audio_directory,
            stopped,
            finished_adding,
            file_counts,
        )
        parser.start()
        for i in range(self.num_jobs):
            p = CorpusProcessWorker(
                i,
                job_queue,
                return_queue,
                stopped,
                finished_adding,
                self.speaker_characters,
                sanitize_function,
                self.sample_frequency,
            )
            procs.append(p)
            p.start()
        # Start in the past so the progress-bar total is refreshed on the first file
        last_poll = time.time() - 30
        import_data = DatabaseImportData()
        try:
            with self.session() as session:
                with tqdm.tqdm(total=100,
                               disable=getattr(self, "quiet", False)) as pbar:
                    while True:
                        try:
                            file = return_queue.get(timeout=1)
                            if isinstance(file, tuple):
                                error_type = file[0]
                                error = file[1]
                                if error_type == "error":
                                    error_dict[error_type] = error
                                else:
                                    error_dict.setdefault(error_type, []).append(error)
                                continue
                            if self.stopped.stop_check():
                                continue
                        except Empty:
                            if all(proc.finished_processing.stop_check()
                                   for proc in procs):
                                break
                            continue
                        # file_counts is shared with the parser (presumably updated
                        # as it discovers files), so refresh the total periodically
                        if time.time() - last_poll > 15:
                            pbar.total = file_counts.value()
                            last_poll = time.time()
                        pbar.update(1)
                        import_data.add_objects(
                            self.generate_import_objects(file))

                    self.log_debug(
                        f"Processing queue: {time.time() - begin_time}"
                    )

                    if "error" in error_dict:
                        session.rollback()
                        raise error_dict["error"][1]
                    self._finalize_load(session, import_data)
            for k in [
                    "sound_file_errors", "decode_error_files",
                    "textgrid_read_errors"
            ]:
                if not hasattr(self, k) or k not in error_dict:
                    continue
                self.log_info(
                    "There were some issues with files in the corpus. "
                    "Please look at the log file or run the validator for more information."
                )
                self.log_debug(
                    f"{k} showed {len(error_dict[k])} errors:")
                if k in {"textgrid_read_errors", "sound_file_errors"}:
                    existing = getattr(self, k)
                    # The serial loader appends to these attributes (list API);
                    # accept a list as well as a set-like container here.
                    if hasattr(existing, "extend"):
                        existing.extend(error_dict[k])
                    else:
                        existing.update(error_dict[k])
                    for e in error_dict[k]:
                        self.log_debug(f"{e.file_name}: {e.error}")
                else:
                    # Entries may be exception objects, so stringify before joining
                    self.log_debug(", ".join(str(e) for e in error_dict[k]))
                    setattr(self, k, error_dict[k])

        except (Exception, KeyboardInterrupt) as e:
            # KeyboardInterrupt derives from BaseException, so it must be listed
            # explicitly for the ctrl-c branch to be reachable.
            if isinstance(e, KeyboardInterrupt):
                self.log_info(
                    "Detected ctrl-c, please wait a moment while we clean everything up..."
                )
                self.stopped.set_sigint_source()
            self.stopped.stop()
            # The parser and workers were given the local ``stopped`` flag, not
            # ``self.stopped``, so it has to be signalled as well.
            stopped.stop()
            finished_adding.stop()
            # Drain both queues until all workers are done and nothing is left,
            # so no child is stuck flushing queued data while being joined below.
            while True:
                drained_any = False
                for pending in (job_queue, return_queue):
                    try:
                        pending.get(timeout=1)
                        drained_any = True
                    except Empty:
                        pass
                if not drained_any and all(
                        proc.finished_processing.stop_check() for proc in procs):
                    break
            raise
        finally:
            parser.join()
            for p in procs:
                p.join()
            if self.stopped.stop_check():
                self.log_info(
                    f"Stopped parsing early ({time.time() - begin_time} seconds)"
                )
                if self.stopped.source():
                    sys.exit(0)
            else:
                self.log_debug(
                    f"Parsed corpus directory with {self.num_jobs} jobs in {time.time() - begin_time} seconds"
                )

    def _load_corpus_from_source(self) -> None:
        """
        Load a corpus without using multiprocessing

        If an audio directory is configured and exists, sound files are looked
        up there; otherwise they are looked up next to the transcriptions in
        each corpus subdirectory. Every identifier that has a sound file is
        parsed with :meth:`FileData.parse_file` and queued for database import,
        which is finalized in a single session.

        Side effects on this object:

        * transcriptions without audio go to ``transcriptions_without_wavs``
        * audio without a transcription goes to ``no_transcription_files``
        * parse failures go to ``decode_error_files``,
          ``textgrid_read_errors`` and ``sound_file_errors``

        Returns early, without finalizing the import, if :attr:`stopped` is set
        during the directory walk.
        """
        begin_time = time.time()
        sanitize_function = getattr(self, "sanitize_function", None)

        all_sound_files = {}
        use_audio_directory = False
        if self.audio_directory and os.path.exists(self.audio_directory):
            use_audio_directory = True
            for root, _, files in os.walk(self.audio_directory,
                                          followlinks=True):
                if self.stopped.stop_check():
                    return
                exts = find_exts(files)
                # wav files are added last so they win over other audio formats
                # that share the same identifier
                all_sound_files.update({
                    k: os.path.join(root, v)
                    for k, v in exts.other_audio_files.items()
                })
                all_sound_files.update({
                    k: os.path.join(root, v)
                    for k, v in exts.wav_files.items()
                })
        self.log_debug(f"Walking through {self.corpus_directory}...")
        import_data = DatabaseImportData()
        with self.session() as session:
            for root, _, files in os.walk(self.corpus_directory,
                                          followlinks=True):
                exts = find_exts(files)
                # os.walk roots always start with the walked directory; strip only
                # that prefix (str.replace would also remove later repeats of it)
                relative_path = root[len(self.corpus_directory):].lstrip(
                    "/").lstrip("\\")
                if self.stopped.stop_check():
                    return
                if not use_audio_directory:
                    all_sound_files = {}
                    all_sound_files.update({
                        k: os.path.join(root, v)
                        for k, v in exts.other_audio_files.items()
                    })
                    all_sound_files.update({
                        k: os.path.join(root, v)
                        for k, v in exts.wav_files.items()
                    })
                for file_name in exts.identifiers:

                    wav_path = all_sound_files.get(file_name)
                    transcription_path = None
                    if file_name in exts.lab_files:
                        lab_name = exts.lab_files[file_name]
                        transcription_path = os.path.join(root, lab_name)
                    elif file_name in exts.textgrid_files:
                        tg_name = exts.textgrid_files[file_name]
                        transcription_path = os.path.join(root, tg_name)
                    if wav_path is None and transcription_path is None:  # Not a file for MFA
                        continue
                    if wav_path is None:
                        self.transcriptions_without_wavs.append(
                            transcription_path)
                        continue
                    if transcription_path is None:
                        self.no_transcription_files.append(wav_path)
                    try:
                        file = FileData.parse_file(
                            file_name,
                            wav_path,
                            transcription_path,
                            relative_path,
                            self.speaker_characters,
                            sanitize_function,
                            self.sample_frequency,
                        )
                        import_data.add_objects(
                            self.generate_import_objects(file))
                    except TextParseError as e:
                        self.decode_error_files.append(e)
                    except TextGridParseError as e:
                        self.textgrid_read_errors.append(e)
                    except SoundFileError as e:
                        self.sound_file_errors.append(e)
            self._finalize_load(session, import_data)
        if self.decode_error_files or self.textgrid_read_errors or self.sound_file_errors:
            self.log_info(
                "There were some issues with files in the corpus. "
                "Please look at the log file or run the validator for more information."
            )
            if self.decode_error_files:
                self.log_debug(
                    f"There were {len(self.decode_error_files)} errors decoding text files:"
                )
                # Entries may be exception objects (TextParseError is appended
                # above), so stringify them before joining
                self.log_debug(", ".join(str(e) for e in self.decode_error_files))
            if self.textgrid_read_errors:
                self.log_debug(
                    f"There were {len(self.textgrid_read_errors)} errors reading TextGrid files:"
                )
                for e in self.textgrid_read_errors:
                    self.log_debug(f"{e.file_name}: {e.error}")
            if self.sound_file_errors:
                self.log_debug(
                    f"There were {len(self.sound_file_errors)} errors reading sound files:"
                )
                for e in self.sound_file_errors:
                    self.log_debug(f"{e.file_name}: {e.error}")

        self.log_debug(
            f"Parsed corpus directory in {time.time() - begin_time} seconds")

    def acc_ivector_stats(self) -> None:
        """
        Multiprocessing function that accumulates ivector extraction stats.

        Runs :class:`AccIvectorStatsFunction` for every job (in worker
        processes when ``use_mp`` is set). The per-job accumulators are summed
        into ``acc.{iteration}`` with ``ivector-extractor-sum-accs`` and then
        deleted. A new extractor is estimated from ``ie_path`` into
        ``next_ie_path`` with ``ivector-extractor-est``.

        Raises
        ------
        Exception
            The first error reported by a worker process, re-raised after all
            workers have been joined

        See Also
        --------
        :func:`~montreal_forced_aligner.ivector.trainer.AccIvectorStatsFunction`
            Multiprocessing helper function for each job
        :meth:`.IvectorTrainer.acc_ivector_stats_arguments`
            Job method for generating arguments for the helper function
        :kaldi_src:`ivector-extractor-sum-accs`
            Relevant Kaldi binary
        :kaldi_src:`ivector-extractor-est`
            Relevant Kaldi binary
        :kaldi_steps_sid:`train_ivector_extractor`
            Reference Kaldi script
        """

        begin = time.time()
        self.log_info("Accumulating ivector stats...")
        arguments = self.acc_ivector_stats_arguments()

        if self.use_mp:
            error_dict = {}
            return_queue = mp.Queue()
            stopped = Stopped()
            procs = []
            for i, args in enumerate(arguments):
                function = AccIvectorStatsFunction(args)
                p = KaldiProcessWorker(i, return_queue, function, stopped)
                procs.append(p)
                p.start()
            while True:
                try:
                    result = return_queue.get(timeout=1)
                except Empty:
                    if all(proc.finished.stop_check() for proc in procs):
                        break
                    continue
                # Record any worker failure, not only KaldiProcessingError, so
                # it is re-raised below instead of silently dropped (which would
                # leave accumulator files missing for the summing step).
                if isinstance(result, Exception):
                    error_dict[getattr(result, "job_name", 0)] = result
            for p in procs:
                p.join()
            for v in error_dict.values():
                raise v

        else:
            self.log_debug("Not using multiprocessing...")
            for args in arguments:
                function = AccIvectorStatsFunction(args)
                for _ in function.run():
                    pass

        self.log_debug(f"Accumulating stats took {time.time() - begin}")

        log_path = os.path.join(self.working_log_directory,
                                f"sum_acc.{self.iteration}.log")
        acc_path = os.path.join(self.working_directory,
                                f"acc.{self.iteration}")
        accinits = [j.acc_init_path for j in arguments]
        with open(log_path, "w", encoding="utf8") as log_file:
            sum_accs_proc = subprocess.Popen(
                [
                    thirdparty_binary("ivector-extractor-sum-accs"),
                    "--parallel=true"
                ] + accinits + [acc_path],
                stderr=log_file,
                env=os.environ,
            )
            sum_accs_proc.communicate()
        # The per-job accumulators are no longer needed once summed
        for acc_init_path in accinits:
            os.remove(acc_init_path)
        # Est extractor
        log_path = os.path.join(self.working_log_directory,
                                f"update.{self.iteration}.log")
        with open(log_path, "w", encoding="utf8") as log_file:
            extractor_est_proc = subprocess.Popen(
                [
                    thirdparty_binary("ivector-extractor-est"),
                    f"--num-threads={len(self.jobs)}",
                    f"--gaussian-min-count={self.gaussian_min_count}",
                    self.ie_path,
                    acc_path,
                    self.next_ie_path,
                ],
                stderr=log_file,
                env=os.environ,
            )
            extractor_est_proc.communicate()

    def create_align_model(self) -> None:
        """
        Create alignment model for speaker-adapted training that will use speaker-independent
        features in later aligning.

        Accumulates statistics with :class:`AccStatsTwoFeatsFunction` for every
        job (in worker processes when ``use_mp`` is set) and sums them with
        ``gmm-sum-accs``. ``gmm-est`` then writes the alignment model to
        :attr:`model_path` with ``.mdl`` replaced by ``.alimdl``. The
        accumulator files are deleted afterwards unless ``debug`` is set.

        Raises
        ------
        Exception
            The first error reported by a worker process, re-raised after all
            workers have been joined

        See Also
        --------
        :func:`~montreal_forced_aligner.acoustic_modeling.sat.AccStatsTwoFeatsFunction`
            Multiprocessing helper function for each job
        :meth:`.SatTrainer.acc_stats_two_feats_arguments`
            Job method for generating arguments for the helper function
        :kaldi_src:`gmm-est`
            Relevant Kaldi binary
        :kaldi_src:`gmm-sum-accs`
            Relevant Kaldi binary
        :kaldi_steps:`train_sat`
            Reference Kaldi script
        """
        self.log_info(
            "Creating alignment model for speaker-independent features...")
        begin = time.time()

        arguments = self.acc_stats_two_feats_arguments()
        with tqdm.tqdm(total=self.num_current_utterances,
                       disable=getattr(self, "quiet", False)) as pbar:
            if self.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(arguments):
                    function = AccStatsTwoFeatsFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                    except Empty:
                        if all(proc.finished.stop_check() for proc in procs):
                            break
                        continue
                    if isinstance(result, Exception):
                        error_dict[getattr(result, "job_name", 0)] = result
                        continue
                    if stopped.stop_check():
                        continue
                    pbar.update(1)
                for p in procs:
                    p.join()
                for v in error_dict.values():
                    raise v
            else:
                for args in arguments:
                    function = AccStatsTwoFeatsFunction(args)
                    for _ in function.run():
                        pbar.update(1)

        log_path = os.path.join(self.working_log_directory,
                                "align_model_est.log")
        acc_files = []
        for x in arguments:
            acc_files.extend(x.acc_paths.values())
        with open(log_path, "w", encoding="utf8") as log_file:
            sum_proc = subprocess.Popen(
                [thirdparty_binary("gmm-sum-accs"), "-"] + acc_files,
                stderr=log_file,
                stdout=subprocess.PIPE,
                env=os.environ,
            )
            est_command = [
                thirdparty_binary("gmm-est"),
                "--remove-low-count-gaussians=false",
            ]
            if not self.quick:
                est_command.append(f"--power={self.power}")
            else:
                est_command.append(
                    f"--write-occs={os.path.join(self.working_directory, 'final.occs')}"
                )
            est_command.extend([
                self.model_path,
                "-",
                self.model_path.replace(".mdl", ".alimdl"),
            ])
            est_proc = subprocess.Popen(
                est_command,
                stdin=sum_proc.stdout,
                stderr=log_file,
                env=os.environ,
            )
            # Drop the parent's copy of the pipe so gmm-sum-accs receives
            # SIGPIPE if gmm-est exits early.
            sum_proc.stdout.close()
            est_proc.communicate()
            # Reap gmm-sum-accs so it does not linger as a zombie process.
            sum_proc.wait()
        parse_logs(self.working_log_directory)
        if not self.debug:
            for f in acc_files:
                os.remove(f)
        self.log_debug(f"Alignment model creation took {time.time() - begin}")