def __init__(
     self,
     corpus_directory: str,
     speaker_characters: typing.Union[int, str] = 0,
     ignore_speakers: bool = False,
     **kwargs,
 ):
     if not os.path.exists(corpus_directory):
         raise CorpusError(
             f"The directory '{corpus_directory}' does not exist.")
     if not os.path.isdir(corpus_directory):
         raise CorpusError(
             f"The specified path for the corpus ({corpus_directory}) is not a directory."
         )
     self._speaker_ids = {}
     self.corpus_directory = corpus_directory
     self.speaker_characters = speaker_characters
     self.ignore_speakers = ignore_speakers
     self.word_counts = Counter()
     self.stopped = Stopped()
     self.decode_error_files = []
     self.textgrid_read_errors = []
     self.jobs: typing.List[Job] = []
     self._num_speakers = None
     self._num_utterances = None
     self._num_files = None
     super().__init__(**kwargs)
     os.makedirs(self.corpus_output_directory, exist_ok=True)
     self.imported = False
     self._current_speaker_index = 1
     self._current_file_index = 1
     self._speaker_ids = {}
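
A minimal usage sketch of the constructor above. The concrete class name (AcousticCorpus) and the meaning of speaker_characters=0 are assumptions for illustration, not guaranteed by the snippet itself:

from montreal_forced_aligner.exceptions import CorpusError  # assumed import path

try:
    corpus = AcousticCorpus(  # hypothetical subclass that mixes in the __init__ above
        corpus_directory="/data/my_corpus",  # must exist and be a directory
        speaker_characters=0,  # assumed: derive speakers from the directory layout
        ignore_speakers=False,
    )
except CorpusError as e:
    print(f"Could not load corpus: {e}")
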
    def convert_alignments(self) -> None:
        """
        Multiprocessing function that converts alignments from the previous training stage

        See Also
        --------
        :func:`~montreal_forced_aligner.acoustic_modeling.triphone.ConvertAlignmentsFunction`
            Multiprocessing helper function for each job
        :meth:`.TriphoneTrainer.convert_alignments_arguments`
            Job method for generating arguments for the helper function
        :kaldi_steps:`train_deltas`
            Reference Kaldi script
        :kaldi_steps:`train_lda_mllt`
            Reference Kaldi script
        :kaldi_steps:`train_sat`
            Reference Kaldi script

        """
        self.log_info("Converting alignments...")
        arguments = self.convert_alignments_arguments()
        with tqdm.tqdm(total=self.num_current_utterances,
                       disable=getattr(self, "quiet", False)) as pbar:
            if self.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(arguments):
                    function = ConvertAlignmentsFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if isinstance(result, Exception):
                            error_dict[getattr(result, "job_name", 0)] = result
                            continue
                        if stopped.stop_check():
                            continue
                    except Empty:
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    num_utterances, errors = result
                    pbar.update(num_utterances + errors)
                for p in procs:
                    p.join()
                if error_dict:
                    for v in error_dict.values():
                        raise v
            else:
                for args in arguments:
                    function = ConvertAlignmentsFunction(args)
                    for num_utterances, errors in function.run():
                        pbar.update(num_utterances + errors)
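
Nearly every method on this page repeats the same multiprocessing skeleton: start one KaldiProcessWorker per job, drain a shared return queue with a one-second timeout, and use each worker's finished flag (via the for/else idiom on queue.Empty) to decide when to stop polling. Below is a stripped-down, self-contained sketch of that skeleton using only the standard library; the worker body and payload are placeholders, not MFA classes:

import multiprocessing as mp
from queue import Empty


def square_worker(job_id, items, return_queue, finished):
    """Placeholder for KaldiProcessWorker: compute results, post them, then flag completion."""
    try:
        for x in items:
            return_queue.put((job_id, x * x))
    finally:
        finished.set()  # plays the role of proc.finished.stop_check() in the snippets above


def run_jobs(jobs):
    return_queue = mp.Queue()
    procs, flags, results = [], [], []
    for i, items in enumerate(jobs):
        finished = mp.Event()
        p = mp.Process(target=square_worker, args=(i, items, return_queue, finished))
        p.start()
        procs.append(p)
        flags.append(finished)
    while True:
        try:
            results.append(return_queue.get(timeout=1))
        except Empty:
            # Same for/else idiom as above: keep polling until every worker has finished.
            for flag in flags:
                if not flag.is_set():
                    break
            else:
                break
            continue
    for p in procs:
        p.join()
    # Drain anything that landed on the queue between the last Empty and the final check.
    while True:
        try:
            results.append(return_queue.get_nowait())
        except Empty:
            break
    return results


if __name__ == "__main__":
    print(run_jobs([[1, 2, 3], [4, 5, 6]]))
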
    def compute_vad(self) -> None:
        """
        Compute Voice Activity Detection features over the corpus

        See Also
        --------
        :class:`~montreal_forced_aligner.corpus.features.ComputeVadFunction`
            Multiprocessing helper function for each job
        :meth:`.AcousticCorpusMixin.compute_vad_arguments`
            Job method for generating arguments for helper function
        """
        if os.path.exists(os.path.join(self.split_directory, "vad.0.scp")):
            self.log_info("VAD already computed, skipping!")
            return
        begin = time.time()
        self.log_info("Computing VAD...")

        arguments = self.compute_vad_arguments()
        with tqdm.tqdm(total=self.num_speakers,
                       disable=getattr(self, "quiet", False)) as pbar:
            if self.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(arguments):
                    function = ComputeVadFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if stopped.stop_check():
                            continue
                    except Empty:
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    if isinstance(result, KaldiProcessingError):
                        error_dict[result.job_name] = result
                        continue
                    done, no_feats, unvoiced = result
                    pbar.update(done + no_feats + unvoiced)
                for p in procs:
                    p.join()
                if error_dict:
                    for v in error_dict.values():
                        raise v
            else:
                for args in arguments:
                    function = ComputeVadFunction(args)
                    for done, no_feats, unvoiced in function.run():
                        pbar.update(done + no_feats + unvoiced)
        self.log_debug(f"VAD computation took {time.time() - begin}")
    def gmm_gselect(self) -> None:
        """
        Multiprocessing function that stores Gaussian selection indices on disk

        See Also
        --------
        :func:`~montreal_forced_aligner.ivector.trainer.GmmGselectFunction`
            Multiprocessing helper function for each job
        :meth:`.DubmTrainer.gmm_gselect_arguments`
            Job method for generating arguments for the helper function
        :kaldi_steps:`train_diag_ubm`
            Reference Kaldi script

        """
        begin = time.time()
        self.log_info("Selecting gaussians...")
        arguments = self.gmm_gselect_arguments()

        if self.use_mp:
            error_dict = {}
            return_queue = mp.Queue()
            stopped = Stopped()
            procs = []
            for i, args in enumerate(arguments):
                function = GmmGselectFunction(args)
                p = KaldiProcessWorker(i, return_queue, function, stopped)
                procs.append(p)
                p.start()
            while True:
                try:
                    result = return_queue.get(timeout=1)
                    if isinstance(result, Exception):
                        error_dict[getattr(result, "job_name", 0)] = result
                        continue
                    if stopped.stop_check():
                        continue
                except queue.Empty:
                    for proc in procs:
                        if not proc.finished.stop_check():
                            break
                    else:
                        break
                    continue
            for p in procs:
                p.join()
            if error_dict:
                for v in error_dict.values():
                    raise v

        else:
            self.log_debug("Not using multiprocessing...")
            for args in arguments:
                function = GmmGselectFunction(args)
                for _ in function.run():
                    pass

            self.log_debug(f"Gaussian selection took {time.time() - begin}")
    def extract_ivectors(self) -> None:
        """
        Multiprocessing function that extracts ivectors.

        See Also
        --------
        :class:`~montreal_forced_aligner.corpus.features.ExtractIvectorsFunction`
            Multiprocessing helper function for each job
        :meth:`.IvectorCorpusMixin.extract_ivectors_arguments`
            Job method for generating arguments for helper function
        :kaldi_steps_sid:`extract_ivectors`
            Reference Kaldi script
        """
        begin = time.time()

        log_dir = self.working_log_directory
        os.makedirs(log_dir, exist_ok=True)

        arguments = self.extract_ivectors_arguments()
        with tqdm.tqdm(total=self.num_speakers, disable=getattr(self, "quiet", False)) as pbar:
            if self.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(arguments):
                    function = ExtractIvectorsFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if isinstance(result, Exception):
                            error_dict[getattr(result, "job_name", 0)] = result
                            continue
                        if stopped.stop_check():
                            continue
                    except Empty:
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    pbar.update(1)
                for p in procs:
                    p.join()
                if error_dict:
                    for v in error_dict.values():
                        raise v
            else:
                for args in arguments:
                    function = ExtractIvectorsFunction(args)
                    for _ in function.run():
                        pbar.update(1)
        self.log_debug(f"Ivector extraction took {time.time() - begin}")
    def gauss_to_post(self) -> None:
        """
        Multiprocessing function that does Gaussian selection and posterior extraction

        See Also
        --------
        :func:`~montreal_forced_aligner.ivector.trainer.GaussToPostFunction`
            Multiprocessing helper function for each job
        :meth:`.IvectorTrainer.gauss_to_post_arguments`
            Job method for generating arguments for the helper function
        :kaldi_steps_sid:`train_ivector_extractor`
            Reference Kaldi script
        """
        begin = time.time()
        self.log_info("Extracting posteriors...")
        arguments = self.gauss_to_post_arguments()

        if self.use_mp:
            error_dict = {}
            return_queue = mp.Queue()
            stopped = Stopped()
            procs = []
            for i, args in enumerate(arguments):
                function = GaussToPostFunction(args)
                p = KaldiProcessWorker(i, return_queue, function, stopped)
                procs.append(p)
                p.start()
            while True:
                try:
                    result = return_queue.get(timeout=1)
                    if stopped.stop_check():
                        continue
                except queue.Empty:
                    for proc in procs:
                        if not proc.finished.stop_check():
                            break
                    else:
                        break
                    continue
                if isinstance(result, KaldiProcessingError):
                    error_dict[result.job_name] = result
                    continue
            for p in procs:
                p.join()
            if error_dict:
                for v in error_dict.values():
                    raise v

        else:
            self.log_debug("Not using multiprocessing...")
            for args in arguments:
                function = GaussToPostFunction(args)
                for _ in function.run():
                    pass

        self.log_debug(f"Extracting posteriors took {time.time() - begin}")
 def __init__(self, audio_directory: Optional[str] = None, **kwargs):
     super().__init__(**kwargs)
     self.audio_directory = audio_directory
     self.sound_file_errors = []
     self.transcriptions_without_wavs = []
     self.no_transcription_files = []
     self.stopped = Stopped()
     self.features_generated = False
     self.alignment_done = False
     self.transcription_done = False
     self.has_reference_alignments = False
     self.alignment_evaluation_done = False
 def __init__(
     self,
     job_queue: mp.Queue,
     return_queue: mp.Queue,
     rewriter: Rewriter,
     stopped: Stopped,
 ):
     mp.Process.__init__(self)
     self.job_queue = job_queue
     self.return_queue = return_queue
     self.rewriter = rewriter
     self.stopped = stopped
     self.finished = Stopped()
 def __init__(
     self,
     job_name: int,
     job_q: mp.Queue,
     return_queue: mp.Queue,
     log_file: str,
     stopped: Stopped,
 ):
     mp.Process.__init__(self)
     self.job_name = job_name
     self.job_q = job_q
     self.return_queue = return_queue
     self.log_file = log_file
     self.stopped = stopped
     self.finished = Stopped()
 def __init__(
     self,
     name: int,
     job_q: mp.Queue,
     return_q: mp.Queue,
     stopped: Stopped,
     finished_adding: Stopped,
     speaker_characters: Union[int, str],
     sanitize_function: Optional[MultispeakerSanitizationFunction],
     sample_rate: Optional[int],
 ):
     mp.Process.__init__(self)
     self.name = str(name)
     self.job_q = job_q
     self.return_q = return_q
     self.stopped = stopped
     self.finished_adding = finished_adding
     self.finished_processing = Stopped()
     self.sanitize_function = sanitize_function
     self.speaker_characters = speaker_characters
     self.sample_rate = sample_rate
    def __init__(
        self,
        db_path: str,
        for_write_queue: mp.Queue,
        return_queue: mp.Queue,
        stopped: Stopped,
        finished_adding: Stopped,
        arguments: ExportTextGridArguments,
        exported_file_count: Counter,
    ):
        mp.Process.__init__(self)
        self.db_path = db_path
        self.for_write_queue = for_write_queue
        self.return_queue = return_queue
        self.stopped = stopped
        self.finished_adding = finished_adding
        self.finished_processing = Stopped()

        self.output_directory = arguments.output_directory
        self.output_format = arguments.output_format
        self.frame_shift = arguments.frame_shift
        self.log_path = arguments.log_path
        self.include_original_text = arguments.include_original_text
        self.exported_file_count = exported_file_count
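
The Stopped helper is used in every snippet on this page but never shown. Judging only from the calls that appear here (stop, stop_check, set_sigint_source, source), a plausible minimal implementation is a thin wrapper around a multiprocessing Event plus a shared flag recording whether the stop came from Ctrl-C. This is a sketch of that assumed interface, not MFA's actual class:

import multiprocessing as mp


class Stopped:
    """Assumed shape of the Stopped flag shared between the processes above."""

    def __init__(self, initval: bool = False):
        self._stop_event = mp.Event()
        if initval:
            self._stop_event.set()
        self._from_sigint = mp.Value("i", 0)  # remembers whether stop() came from Ctrl-C

    def stop(self) -> None:
        self._stop_event.set()

    def stop_check(self) -> bool:
        return self._stop_event.is_set()

    def set_sigint_source(self) -> None:
        with self._from_sigint.get_lock():
            self._from_sigint.value = 1

    def source(self) -> bool:
        return bool(self._from_sigint.value)
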
    def calc_fmllr(self, iteration: Optional[int] = None) -> None:
        """
        Multiprocessing function that computes speaker adaptation transforms via
        feature-space Maximum Likelihood Linear Regression (fMLLR).

        See Also
        --------
        :class:`~montreal_forced_aligner.corpus.features.CalcFmllrFunction`
            Multiprocessing helper function for each job
        :meth:`.AcousticCorpusMixin.calc_fmllr_arguments`
            Job method for generating arguments for the helper function
        :kaldi_steps:`align_fmllr`
            Reference Kaldi script
        :kaldi_steps:`train_sat`
            Reference Kaldi script
        """
        begin = time.time()
        self.log_info("Calculating fMLLR for speaker adaptation...")

        arguments = self.calc_fmllr_arguments(iteration=iteration)
        with tqdm.tqdm(total=self.num_speakers,
                       disable=getattr(self, "quiet", False)) as pbar:
            if self.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(arguments):
                    function = CalcFmllrFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if isinstance(result, Exception):
                            error_dict[getattr(result, "job_name", 0)] = result
                            continue
                        if stopped.stop_check():
                            continue
                    except Empty:
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    pbar.update(1)
                for p in procs:
                    p.join()
                if error_dict:
                    for v in error_dict.values():
                        raise v
            else:
                for args in arguments:
                    function = CalcFmllrFunction(args)
                    for _ in function.run():
                        pbar.update(1)

        self.speaker_independent = False
        self.log_debug(f"Fmllr calculation took {time.time() - begin}")
    def segment_vad(self) -> None:
        """
        Run segmentation based on VAD.

        See Also
        --------
        :class:`~montreal_forced_aligner.segmenter.SegmentVadFunction`
            Multiprocessing helper function for each job
        segment_vad_arguments
            Job method for generating arguments for helper function
        """

        arguments = self.segment_vad_arguments()
        old_utts = set()
        new_utts = []

        with tqdm.tqdm(total=self.num_utterances,
                       disable=getattr(
                           self, "quiet",
                           False)) as pbar, self.session() as session:
            if self.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(arguments):
                    function = SegmentVadFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if isinstance(result, Exception):
                            error_dict[getattr(result, "job_name", 0)] = result
                            continue
                        if stopped.stop_check():
                            continue
                    except Empty:
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    utt, begin, end = result
                    old_utts.add(utt)
                    channel, speaker_id, file_id = (session.query(
                        Utterance.channel, Utterance.speaker_id,
                        Utterance.file_id).filter(
                            Utterance.id == utt).first())
                    new_utts.append({
                        "begin": begin,
                        "end": end,
                        "text": "speech",
                        "speaker_id": speaker_id,
                        "file_id": file_id,
                        "oovs": "",
                        "normalized_text": "",
                        "normalized_text_int": "",
                        "features": "",
                        "in_subset": False,
                        "ignored": False,
                        "channel": channel,
                        "duration": end - begin,
                    })

                    pbar.update(1)
                for p in procs:
                    p.join()
                if error_dict:
                    for v in error_dict.values():
                        raise v
            else:
                for args in arguments:
                    function = SegmentVadFunction(args)
                    for utt, begin, end in function.run():
                        old_utts.add(utt)
                        channel, speaker_id, file_id = (session.query(
                            Utterance.channel, Utterance.speaker_id,
                            Utterance.file_id).filter(
                                Utterance.id == utt).first())
                        new_utts.append({
                            "begin": begin,
                            "end": end,
                            "text": "speech",
                            "speaker_id": speaker_id,
                            "file_id": file_id,
                            "oovs": "",
                            "normalized_text": "",
                            "normalized_text_int": "",
                            "features": "",
                            "in_subset": False,
                            "ignored": False,
                            "channel": channel,
                            "duration": end - begin,
                        })
                        pbar.update(1)
            session.query(Utterance).filter(
                Utterance.id.in_(old_utts)).delete()
            session.bulk_insert_mappings(Utterance,
                                         new_utts,
                                         return_defaults=False,
                                         render_nulls=True)
            session.commit()
    def align_utterances(self, training=False) -> None:
        """
        Multiprocessing function that aligns based on the current model.

        See Also
        --------
        :class:`~montreal_forced_aligner.alignment.multiprocessing.AlignFunction`
            Multiprocessing helper function for each job
        :meth:`.AlignMixin.align_arguments`
            Job method for generating arguments for the helper function
        :kaldi_steps:`align_si`
            Reference Kaldi script
        :kaldi_steps:`align_fmllr`
            Reference Kaldi script
        """
        begin = time.time()
        self.log_info("Generating alignments...")
        with tqdm.tqdm(
                total=self.num_current_utterances,
                disable=getattr(self, "quiet",
                                False)) as pbar, self.session() as session:
            if not training:
                utterances = session.query(Utterance)
                if hasattr(self, "subset"):
                    utterances = utterances.filter(
                        Utterance.in_subset == True)  # noqa
                utterances.update({"alignment_log_likelihood": None})
                session.commit()
            update_mappings = []
            if self.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(self.align_arguments()):
                    function = AlignFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if isinstance(result, Exception):
                            error_dict[getattr(result, "job_name", 0)] = result
                            continue
                        if stopped.stop_check():
                            continue
                    except Empty:
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    if not training:
                        utterance, log_likelihood = result
                        update_mappings.append({
                            "id":
                            utterance,
                            "alignment_log_likelihood":
                            log_likelihood
                        })
                    pbar.update(1)
                for p in procs:
                    p.join()

                if not training and len(update_mappings) == 0:
                    raise NoAlignmentsError(self.num_current_utterances,
                                            self.beam, self.retry_beam)
                if error_dict:
                    for v in error_dict.values():
                        raise v

            else:
                self.log_debug("Not using multiprocessing...")
                for args in self.align_arguments():
                    function = AlignFunction(args)
                    for utterance, log_likelihood in function.run():
                        if not training:
                            update_mappings.append({
                                "id":
                                utterance,
                                "alignment_log_likelihood":
                                log_likelihood
                            })
                        pbar.update(1)
                if not training and len(update_mappings) == 0:
                    raise NoAlignmentsError(self.num_current_utterances,
                                            self.beam, self.retry_beam)
            if not training:
                session.bulk_update_mappings(Utterance, update_mappings)
                session.query(Utterance).filter(
                    Utterance.alignment_log_likelihood != None  # noqa
                ).update(
                    {
                        Utterance.alignment_log_likelihood:
                        Utterance.alignment_log_likelihood /
                        Utterance.num_frames
                    },
                    synchronize_session="fetch",
                )
                session.commit()
            self.log_debug(f"Alignment round took {time.time() - begin}")
    def _load_corpus_from_source_mp(self) -> None:
        """
        Load a corpus using multiprocessing
        """
        if self.stopped is None:
            self.stopped = Stopped()
        sanitize_function = getattr(self, "sanitize_function", None)
        begin_time = time.time()
        job_queue = mp.Queue()
        return_queue = mp.Queue()
        error_dict = {}
        finished_adding = Stopped()
        procs = []
        for i in range(self.num_jobs):
            p = CorpusProcessWorker(
                i,
                job_queue,
                return_queue,
                self.stopped,
                finished_adding,
                self.speaker_characters,
                sanitize_function,
                sample_rate=0,
            )
            procs.append(p)
            p.start()
        import_data = DatabaseImportData()
        try:
            file_count = 0
            with tqdm.tqdm(total=1, disable=getattr(
                    self, "quiet", False)) as pbar, self.session() as session:
                for root, _, files in os.walk(self.corpus_directory,
                                              followlinks=True):
                    exts = find_exts(files)
                    relative_path = (root.replace(self.corpus_directory,
                                                  "").lstrip("/").lstrip("\\"))

                    if self.stopped.stop_check():
                        break
                    for file_name in exts.identifiers:
                        if self.stopped.stop_check():
                            break
                        wav_path = None
                        if file_name in exts.lab_files:
                            lab_name = exts.lab_files[file_name]
                            transcription_path = os.path.join(root, lab_name)

                        elif file_name in exts.textgrid_files:
                            tg_name = exts.textgrid_files[file_name]
                            transcription_path = os.path.join(root, tg_name)
                        else:
                            continue
                        job_queue.put((file_name, wav_path, transcription_path,
                                       relative_path))
                        file_count += 1
                        pbar.total = file_count

                finished_adding.stop()

                while True:
                    try:
                        file = return_queue.get(timeout=1)
                        if isinstance(file, tuple):
                            error_type = file[0]
                            error = file[1]
                            if error_type == "error":
                                error_dict[error_type] = error
                            else:
                                if error_type not in error_dict:
                                    error_dict[error_type] = []
                                error_dict[error_type].append(error)
                            continue
                        if self.stopped.stop_check():
                            continue
                    except Empty:
                        for proc in procs:
                            if not proc.finished_processing.stop_check():
                                break
                        else:
                            break
                        continue
                    pbar.update(1)
                    import_data.add_objects(self.generate_import_objects(file))

                self.log_debug("Waiting for workers to finish...")
                for p in procs:
                    p.join()

                if "error" in error_dict:
                    session.rollback()
                    raise error_dict["error"][1]

                self._finalize_load(session, import_data)

                for k in ["decode_error_files", "textgrid_read_errors"]:
                    if hasattr(self, k):
                        if k in error_dict:
                            self.log_info(
                                "There were some issues with files in the corpus. "
                                "Please look at the log file or run the validator for more information."
                            )
                            self.log_debug(
                                f"{k} showed {len(error_dict[k])} errors:")
                            if k == "textgrid_read_errors":
                                getattr(self, k).update(error_dict[k])
                                for e in error_dict[k]:
                                    self.log_debug(f"{e.file_name}: {e.error}")
                            else:
                                self.log_debug(", ".join(error_dict[k]))
                                setattr(self, k, error_dict[k])

        except KeyboardInterrupt:
            self.log_info(
                "Detected ctrl-c, please wait a moment while we clean everything up..."
            )
            self.stopped.stop()
            finished_adding.stop()
            job_queue.join()
            self.stopped.set_sigint_source()
            while True:
                try:
                    _ = return_queue.get(timeout=1)
                    if self.stopped.stop_check():
                        continue
                except Empty:
                    for proc in procs:
                        if not proc.finished_processing.stop_check():
                            break
                    else:
                        break
        finally:

            finished_adding.stop()
            for p in procs:
                p.join()
            if self.stopped.stop_check():
                self.log_info(
                    f"Stopped parsing early ({time.time() - begin_time} seconds)"
                )
                if self.stopped.source():
                    sys.exit(0)
            else:
                self.log_debug(
                    f"Parsed corpus directory with {self.num_jobs} jobs in {time.time() - begin_time} seconds"
                )
    def acc_stats(self, alignment: bool = False) -> None:
        """
        Accumulate stats for the mapped model

        Parameters
        ----------
        alignment: bool
            Flag for whether to accumulate stats for the mapped alignment model
        """
        arguments = self.map_acc_stats_arguments(alignment)
        if alignment:
            initial_mdl_path = os.path.join(self.working_directory,
                                            "unadapted.alimdl")
            final_mdl_path = os.path.join(self.working_directory,
                                          "final.alimdl")
        else:
            initial_mdl_path = os.path.join(self.working_directory,
                                            "unadapted.mdl")
            final_mdl_path = os.path.join(self.working_directory, "final.mdl")
        if not os.path.exists(initial_mdl_path):
            return
        self.log_info("Accumulating statistics...")
        with tqdm.tqdm(total=self.num_current_utterances,
                       disable=getattr(self, "quiet", False)) as pbar:
            if self.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(arguments):
                    function = AccStatsFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if isinstance(result, Exception):
                            error_dict[getattr(result, "job_name", 0)] = result
                            continue
                        if stopped.stop_check():
                            continue
                    except Empty:
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    num_utterances, errors = result
                    pbar.update(num_utterances + errors)
                for p in procs:
                    p.join()
                if error_dict:
                    for v in error_dict.values():
                        raise v
            else:
                for args in arguments:
                    function = AccStatsFunction(args)
                    for num_utterances, errors in function.run():
                        pbar.update(num_utterances + errors)
        log_path = os.path.join(self.working_log_directory,
                                "map_model_est.log")
        occs_path = os.path.join(self.working_directory, "final.occs")
        with open(log_path, "w", encoding="utf8") as log_file:
            acc_files = []
            for j in arguments:
                acc_files.extend(j.acc_paths.values())
            sum_proc = subprocess.Popen(
                [thirdparty_binary("gmm-sum-accs"), "-"] + acc_files,
                stderr=log_file,
                stdout=subprocess.PIPE,
                env=os.environ,
            )
            ismooth_proc = subprocess.Popen(
                [
                    thirdparty_binary("gmm-ismooth-stats"),
                    "--smooth-from-model",
                    f"--tau={self.mapping_tau}",
                    initial_mdl_path,
                    "-",
                    "-",
                ],
                stderr=log_file,
                stdin=sum_proc.stdout,
                stdout=subprocess.PIPE,
                env=os.environ,
            )
            est_proc = subprocess.Popen(
                [
                    thirdparty_binary("gmm-est"),
                    "--update-flags=m",
                    f"--write-occs={occs_path}",
                    "--remove-low-count-gaussians=false",
                    initial_mdl_path,
                    "-",
                    final_mdl_path,
                ],
                stdin=ismooth_proc.stdout,
                stderr=log_file,
                env=os.environ,
            )
            est_proc.communicate()
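
acc_stats ends by wiring three Kaldi binaries together through anonymous pipes (gmm-sum-accs into gmm-ismooth-stats into gmm-est). The same Popen-chaining pattern with ordinary shell commands, so it can be run without Kaldi installed; the commands are placeholders for the Kaldi pipeline above:

import subprocess

# Equivalent of the shell pipeline `printf ... | sort -r`, built the way
# acc_stats builds gmm-sum-accs | gmm-ismooth-stats | gmm-est.
producer = subprocess.Popen(
    ["printf", "b\\na\\nc\\n"], stdout=subprocess.PIPE)
consumer = subprocess.Popen(
    ["sort", "-r"], stdin=producer.stdout, stdout=subprocess.PIPE)
producer.stdout.close()  # so the producer gets SIGPIPE if the consumer exits early
output, _ = consumer.communicate()
print(output.decode("utf8"))
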
    def compile_train_graphs(self) -> None:
        """
        Multiprocessing function that compiles training graphs for utterances.

        See Also
        --------
        :class:`~montreal_forced_aligner.alignment.multiprocessing.CompileTrainGraphsFunction`
            Multiprocessing helper function for each job
        :meth:`.AlignMixin.compile_train_graphs_arguments`
            Job method for generating arguments for the helper function
        :kaldi_steps:`align_si`
            Reference Kaldi script
        :kaldi_steps:`align_fmllr`
            Reference Kaldi script
        """
        begin = time.time()
        log_directory = self.working_log_directory
        os.makedirs(log_directory, exist_ok=True)
        self.log_info("Compiling training graphs...")
        error_sum = 0
        arguments = self.compile_train_graphs_arguments()
        with tqdm.tqdm(total=self.num_current_utterances,
                       disable=getattr(self, "quiet", False)) as pbar:
            if self.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(arguments):
                    function = CompileTrainGraphsFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if isinstance(result, Exception):
                            error_dict[getattr(result, "job_name", 0)] = result
                            continue
                        if stopped.stop_check():
                            continue
                    except Empty:
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    done, errors = result
                    pbar.update(done + errors)
                    error_sum += errors
                for p in procs:
                    p.join()
                if error_dict:
                    for v in error_dict.values():
                        raise v
            else:
                self.log_debug("Not using multiprocessing...")
                for args in arguments:
                    function = CompileTrainGraphsFunction(args)
                    for done, errors in function.run():
                        pbar.update(done + errors)
                        error_sum += errors
        if error_sum:
            self.log_warning(
                f"Compilation of training graphs failed for {error_sum} utterances."
            )
        self.log_debug(f"Compiling training graphs took {time.time() - begin}")
    def calc_lda_mllt(self) -> None:
        """
        Multiprocessing function that calculates LDA+MLLT transformations.

        See Also
        --------
        :func:`~montreal_forced_aligner.acoustic_modeling.lda.CalcLdaMlltFunction`
            Multiprocessing helper function for each job
        :meth:`.LdaTrainer.calc_lda_mllt_arguments`
            Job method for generating arguments for the helper function
        :kaldi_src:`est-mllt`
            Relevant Kaldi binary
        :kaldi_src:`gmm-transform-means`
            Relevant Kaldi binary
        :kaldi_src:`compose-transforms`
            Relevant Kaldi binary
        :kaldi_steps:`train_lda_mllt`
            Reference Kaldi script

        """
        self.log_info("Re-calculating LDA...")
        arguments = self.calc_lda_mllt_arguments()
        with tqdm.tqdm(
            total=self.num_current_utterances, disable=getattr(self, "quiet", False)
        ) as pbar:
            if self.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(arguments):
                    function = CalcLdaMlltFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if isinstance(result, Exception):
                            error_dict[getattr(result, "job_name", 0)] = result
                            continue
                        if stopped.stop_check():
                            continue
                    except Empty:
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    pbar.update(1)
                for p in procs:
                    p.join()
                if error_dict:
                    for v in error_dict.values():
                        raise v
            else:
                for args in arguments:
                    function = CalcLdaMlltFunction(args)
                    for _ in function.run():
                        pbar.update(1)

        log_path = os.path.join(
            self.working_log_directory, f"transform_means.{self.iteration}.log"
        )
        previous_mat_path = os.path.join(self.working_directory, "lda.mat")
        new_mat_path = os.path.join(self.working_directory, "lda_new.mat")
        composed_path = os.path.join(self.working_directory, "lda_composed.mat")
        with open(log_path, "a", encoding="utf8") as log_file:
            macc_list = []
            for x in arguments:
                macc_list.extend(x.macc_paths.values())
            subprocess.call(
                [thirdparty_binary("est-mllt"), new_mat_path] + macc_list,
                stderr=log_file,
                env=os.environ,
            )
            subprocess.call(
                [
                    thirdparty_binary("gmm-transform-means"),
                    new_mat_path,
                    self.model_path,
                    self.model_path,
                ],
                stderr=log_file,
                env=os.environ,
            )

            if os.path.exists(previous_mat_path):
                subprocess.call(
                    [
                        thirdparty_binary("compose-transforms"),
                        new_mat_path,
                        previous_mat_path,
                        composed_path,
                    ],
                    stderr=log_file,
                    env=os.environ,
                )
                os.remove(previous_mat_path)
                os.rename(composed_path, previous_mat_path)
            else:
                os.rename(new_mat_path, previous_mat_path)
    def lda_acc_stats(self) -> None:
        """
        Multiprocessing function that accumulates LDA statistics.

        See Also
        --------
        :func:`~montreal_forced_aligner.acoustic_modeling.lda.LdaAccStatsFunction`
            Multiprocessing helper function for each job
        :meth:`.LdaTrainer.lda_acc_stats_arguments`
            Job method for generating arguments for the helper function
        :kaldi_src:`est-lda`
            Relevant Kaldi binary
        :kaldi_steps:`train_lda_mllt`
            Reference Kaldi script

        """
        worker_lda_path = os.path.join(self.worker.working_directory, "lda.mat")
        lda_path = os.path.join(self.working_directory, "lda.mat")
        if os.path.exists(worker_lda_path):
            os.remove(worker_lda_path)
        arguments = self.lda_acc_stats_arguments()
        with tqdm.tqdm(
            total=self.num_current_utterances, disable=getattr(self, "quiet", False)
        ) as pbar:
            if self.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(arguments):
                    function = LdaAccStatsFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if isinstance(result, Exception):
                            error_dict[getattr(result, "job_name", 0)] = result
                            continue
                        if stopped.stop_check():
                            continue
                    except Empty:
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    done, errors = result
                    pbar.update(done + errors)
                for p in procs:
                    p.join()
                if error_dict:
                    for v in error_dict.values():
                        raise v
            else:
                for args in arguments:
                    function = LdaAccStatsFunction(args)
                    for done, errors in function.run():
                        pbar.update(done + errors)

        log_path = os.path.join(self.working_log_directory, "lda_est.log")
        acc_list = []
        for x in arguments:
            acc_list.extend(x.acc_paths.values())
        with open(log_path, "w", encoding="utf8") as log_file:
            est_lda_proc = subprocess.Popen(
                [
                    thirdparty_binary("est-lda"),
                    f"--dim={self.lda_dimension}",
                    lda_path,
                ]
                + acc_list,
                stderr=log_file,
                env=os.environ,
            )
            est_lda_proc.communicate()
        shutil.copyfile(
            lda_path,
            worker_lda_path,
        )
    def train_g2p_lexicon(self) -> None:
        """Generate a G2P lexicon based on aligned transcripts"""
        arguments = self.worker.generate_pronunciations_arguments()
        working_dir = super(PronunciationProbabilityTrainer,
                            self).working_directory
        texts = {}
        with self.worker.session() as session:
            query = session.query(Utterance.id,
                                  Utterance.normalized_character_text)
            query = query.filter(Utterance.ignored == False)  # noqa
            initial_brackets = "".join(x[0] for x in self.worker.brackets)
            query = query.filter(
                ~Utterance.oovs.regexp_match(f"(^| )[^{initial_brackets}]"))
            if self.subset:
                query = query.filter_by(in_subset=True)
            for utt_id, text in query:
                texts[utt_id] = text
            input_files = {
                x: open(
                    os.path.join(
                        working_dir,
                        f"input_{self.worker.dictionary_base_names[x]}.txt"),
                    "w",
                    encoding="utf8",
                )
                for x in self.worker.dictionary_lookup.values()
            }
            output_files = {
                x: open(
                    os.path.join(
                        working_dir,
                        f"output_{self.worker.dictionary_base_names[x]}.txt"),
                    "w",
                    encoding="utf8",
                )
                for x in self.worker.dictionary_lookup.values()
            }
            with tqdm.tqdm(total=self.num_current_utterances,
                           disable=getattr(self, "quiet", False)) as pbar:
                if self.use_mp:
                    error_dict = {}
                    return_queue = mp.Queue()
                    stopped = Stopped()
                    procs = []
                    for i, args in enumerate(arguments):
                        args.for_g2p = True
                        function = GeneratePronunciationsFunction(args)
                        p = KaldiProcessWorker(i, return_queue, function,
                                               stopped)
                        procs.append(p)
                        p.start()
                    while True:
                        try:
                            result = return_queue.get(timeout=1)
                            if isinstance(result, Exception):
                                error_dict[getattr(result, "job_name",
                                                   0)] = result
                                continue
                            if stopped.stop_check():
                                continue
                        except Empty:
                            for proc in procs:
                                if not proc.finished.stop_check():
                                    break
                            else:
                                break
                            continue
                        dict_id, utt_id, phones = result
                        utt_id = int(utt_id.split("-")[-1])
                        pbar.update(1)
                        if utt_id not in texts or not texts[utt_id]:
                            continue
                        print(phones, file=output_files[dict_id])
                        print(f"<s> {texts[utt_id]} </s>",
                              file=input_files[dict_id])

                    for p in procs:
                        p.join()
                    if error_dict:
                        for v in error_dict.values():
                            raise v
                else:
                    self.log_debug("Not using multiprocessing...")
                    for args in arguments:
                        function = GeneratePronunciationsFunction(args)
                        for dict_id, utt_id, phones in function.run():
                            print(phones, file=output_files[dict_id])
                            print(f"<s> {texts[utt_id]} </s>",
                                  file=input_files[dict_id])
                            pbar.update(1)
            for f in input_files.values():
                f.close()
            for f in output_files.values():
                f.close()
            self.pronunciations_complete = True
            os.makedirs(self.working_log_directory, exist_ok=True)
            dictionaries = session.query(Dictionary)
            shutil.copyfile(self.phone_symbol_table_path,
                            os.path.join(self.working_directory, "phones.txt"))
            shutil.copyfile(
                self.grapheme_symbol_table_path,
                os.path.join(self.working_directory, "graphemes.txt"),
            )
            self.input_token_type = self.grapheme_symbol_table_path
            self.output_token_type = self.phone_symbol_table_path
            for d in dictionaries:
                self.log_info(f"Training G2P for {d.name}...")
                self._data_source = self.worker.dictionary_base_names[d.id]
                begin = time.time()
                if os.path.exists(self.far_path) and os.path.exists(
                        self.encoder_path):
                    self.log_info("Alignment already done, skipping!")
                else:
                    self.align_g2p()
                    self.log_debug(
                        f"Aligning utterances for {d.name} took {time.time() - begin} seconds"
                    )
                begin = time.time()
                self.generate_model()
                self.log_debug(
                    f"Generating model for {d.name} took {time.time() - begin} seconds"
                )
                os.rename(d.lexicon_fst_path, d.lexicon_fst_path + ".backup")
                shutil.copy(self.fst_path, d.lexicon_fst_path)
                d.use_g2p = True
            session.commit()
            self.worker.use_g2p = True
    def generate_pronunciations(self) -> Dict[str, List[str]]:
        """
        Generate pronunciations

        Returns
        -------
        dict[str, list[str]]
            Mappings of keys to their generated pronunciations
        """

        fst = pynini.Fst.read(self.g2p_model.fst_path)
        if self.g2p_model.meta["architecture"] == "phonetisaurus":
            output_token_type = pynini.SymbolTable.read_text(
                self.g2p_model.sym_path)
            input_token_type = pynini.SymbolTable.read_text(
                self.g2p_model.grapheme_sym_path)
            fst.set_input_symbols(input_token_type)
            fst.set_output_symbols(output_token_type)
            rewriter = PhonetisaurusRewriter(
                fst,
                input_token_type,
                output_token_type,
                num_pronunciations=self.num_pronunciations,
                threshold=self.g2p_threshold,
                grapheme_order=self.g2p_model.meta["grapheme_order"],
            )
        else:
            output_token_type = "utf8"
            input_token_type = "utf8"
            if self.g2p_model.sym_path is not None and os.path.exists(
                    self.g2p_model.sym_path):
                output_token_type = pynini.SymbolTable.read_text(
                    self.g2p_model.sym_path)
            rewriter = Rewriter(
                fst,
                input_token_type,
                output_token_type,
                num_pronunciations=self.num_pronunciations,
                threshold=self.g2p_threshold,
            )

        num_words = len(self.words_to_g2p)
        begin = time.time()
        missing_graphemes = set()
        self.log_info("Generating pronunciations...")
        to_return = {}
        skipped_words = 0
        if num_words < 30 or self.num_jobs == 1:
            with tqdm.tqdm(total=num_words,
                           disable=getattr(self, "quiet", False)) as pbar:
                for word in self.words_to_g2p:
                    w, m = clean_up_word(word,
                                         self.g2p_model.meta["graphemes"])
                    pbar.update(1)
                    missing_graphemes = missing_graphemes | m
                    if self.strict_graphemes and m:
                        skipped_words += 1
                        continue
                    if not w:
                        skipped_words += 1
                        continue
                    try:
                        prons = rewriter(w)
                    except rewrite.Error:
                        continue
                    to_return[word] = prons
                self.log_debug(
                    f"Skipping {skipped_words} words for containing the following graphemes: "
                    f"{comma_join(sorted(missing_graphemes))}")
        else:
            stopped = Stopped()
            job_queue = mp.Queue()
            for word in self.words_to_g2p:
                w, m = clean_up_word(word, self.g2p_model.meta["graphemes"])
                missing_graphemes = missing_graphemes | m
                if self.strict_graphemes and m:
                    skipped_words += 1
                    continue
                if not w:
                    skipped_words += 1
                    continue
                job_queue.put(w)
            self.log_debug(
                f"Skipping {skipped_words} words for containing the following graphemes: "
                f"{comma_join(sorted(missing_graphemes))}")
            error_dict = {}
            return_queue = mp.Queue()
            procs = []
            for _ in range(self.num_jobs):
                p = RewriterWorker(
                    job_queue,
                    return_queue,
                    rewriter,
                    stopped,
                )
                procs.append(p)
                p.start()
            num_words -= skipped_words
            with tqdm.tqdm(total=num_words,
                           disable=getattr(self, "quiet", False)) as pbar:
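                # Drain the return queue until every worker reports finished:
                # the for/else below breaks out of the loop only when no
                # worker is still running and the queue has gone empty.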
                while True:
                    try:
                        word, result = return_queue.get(timeout=1)
                        if stopped.stop_check():
                            continue
                    except queue.Empty:
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    pbar.update(1)
                    if isinstance(result, Exception):
                        error_dict[word] = result
                        continue
                    to_return[word] = result

            for p in procs:
                p.join()
            if error_dict:
                raise PyniniGenerationError(error_dict)
        self.log_debug(
            f"Processed {num_words} in {time.time() - begin} seconds")
        return to_return
    def _alignments(self) -> None:
        """Trains the aligner and constructs the alignments FAR."""
        if not os.path.exists(self.align_path):
            self.log_info("Training aligner")
            train_opts = []
            if self.batch_size:
                train_opts.append(f"--batch_size={self.batch_size}")
            if self.delta:
                train_opts.append(f"--delta={self.delta}")
            if self.fst_default_cache_gc:
                train_opts.append(
                    f"--fst_default_cache_gc={self.fst_default_cache_gc}")
            if self.fst_default_cache_gc_limit:
                train_opts.append(
                    f"--fst_default_cache_gc_limit={self.fst_default_cache_gc_limit}"
                )
            if self.alpha:
                train_opts.append(f"--alpha={self.alpha}")
            if self.num_iterations:
                train_opts.append(f"--max_iters={self.num_iterations}")
            # Constructs the actual command vectors (plus an index for logging
            # purposes).
            random.seed(self.seed)
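            # Baum-Welch EM is sensitive to initialization, so several
            # independently seeded "random starts" are trained in parallel and
            # the start with the best converged score (the min over
            # fst_likelihoods below) is kept as the alignment model.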
            starts = [(RandomStart(
                idx,
                seed,
                self.input_far_path,
                self.output_far_path,
                self.cg_path,
                self.working_directory,
                train_opts,
            )) for (idx, seed) in enumerate(
                random.sample(range(1, RAND_MAX), self.random_starts), 1)]
            stopped = Stopped()
            num_commands = len(starts)
            job_queue = mp.JoinableQueue()
            fst_likelihoods = {}
            # Actually runs starts.
            self.log_info("Calculating alignments...")
            begin = time.time()
            with tqdm.tqdm(total=num_commands * self.num_iterations,
                           disable=getattr(self, "quiet", False)) as pbar:
                for start in starts:
                    job_queue.put(start)
                error_dict = {}
                return_queue = mp.Queue()
                procs = []
                for i in range(self.num_jobs):
                    log_path = os.path.join(self.working_log_directory,
                                            f"baumwelch.{i}.log")
                    p = RandomStartWorker(
                        i,
                        job_queue,
                        return_queue,
                        log_path,
                        stopped,
                    )
                    procs.append(p)
                    p.start()

                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if isinstance(result, Exception):
                            error_dict[getattr(result, "job_name", 0)] = result
                            continue
                        if stopped.stop_check():
                            continue
                    except queue.Empty:
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    if isinstance(result, int):
                        pbar.update(result)
                    else:
                        fst_likelihoods[result[0]] = result[1]
                for p in procs:
                    p.join()
            if error_dict:
                raise PyniniAlignmentError(error_dict)
            (best_fst, best_likelihood) = min(fst_likelihoods.items(),
                                              key=operator.itemgetter(1))
            self.log_info(f"Best likelihood: {best_likelihood}")
            self.log_debug(
                f"Ran {self.random_starts} random starts in {time.time() - begin} seconds"
            )
            # Moves best likelihood solution to the requested location.
            shutil.move(best_fst, self.align_path)
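        # With the best model in place, run baumwelchdecode over the input and
        # output FARs to write the final alignments FAR (``afst_path``).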
        cmd = [thirdparty_binary("baumwelchdecode")]
        if self.fst_default_cache_gc:
            cmd.append(f"--fst_default_cache_gc={self.fst_default_cache_gc}")
        if self.fst_default_cache_gc_limit:
            cmd.append(
                f"--fst_default_cache_gc_limit={self.fst_default_cache_gc_limit}"
            )
        cmd.append(self.input_far_path)
        cmd.append(self.output_far_path)
        cmd.append(self.align_path)
        cmd.append(self.afst_path)
        self.log_debug(f"Subprocess call: {cmd}")
        subprocess.check_call(cmd, env=os.environ)
        self.log_info("Completed computing alignments!")
    def mfcc(self) -> None:
        """
        Multiprocessing function that converts sound files into MFCCs.

        See :kaldi_docs:`feat` for an overview on feature generation in Kaldi.

        See Also
        --------
        :class:`~montreal_forced_aligner.corpus.features.MfccFunction`
            Multiprocessing helper function for each job
        :meth:`.AcousticCorpusMixin.mfcc_arguments`
            Job method for generating arguments for helper function
        :kaldi_steps:`make_mfcc`
            Reference Kaldi script
        """
        self.log_info("Generating MFCCs...")
        log_directory = os.path.join(self.split_directory, "log")
        os.makedirs(log_directory, exist_ok=True)
        arguments = self.mfcc_arguments()
        with tqdm.tqdm(total=self.num_utterances,
                       disable=getattr(self, "quiet", False)) as pbar:
            if self.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(arguments):
                    function = MfccFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if isinstance(result, Exception):
                            error_dict[getattr(result, "job_name", 0)] = result
                            continue
                        if stopped.stop_check():
                            continue
                    except Empty:
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    pbar.update(result)
                for p in procs:
                    p.join()
                if error_dict:
                    for v in error_dict.values():
                        print(v)
                    self.dirty = True
                    sys.exit(1)
            else:
                for args in arguments:
                    function = MfccFunction(args)
                    for num_utterances in function.run():
                        pbar.update(num_utterances)
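        # Features were written per job to feats.scp files; sync the database
        # by first marking every utterance ignored, then re-enabling (and
        # attaching feature specifiers to) those that actually produced MFCCs.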
        with self.session() as session:
            update_mapping = []
            session.query(Utterance).update({"ignored": True})
            for j in arguments:
                with open(j.feats_scp_path, "r", encoding="utf8") as f:
                    for line in f:
                        line = line.strip()
                        if line == "":
                            continue
                        parts = line.split(maxsplit=1)
                        utt_id = int(parts[0].split("-")[-1])
                        feats = parts[1]
                        update_mapping.append({
                            "id": utt_id,
                            "features": feats,
                            "ignored": False
                        })
            session.bulk_update_mappings(Utterance, update_mapping)
            session.commit()
    def acc_stats(self) -> None:
        """
        Multiprocessing function that accumulates stats for GMM training.

        See Also
        --------
        :class:`~montreal_forced_aligner.alignment.multiprocessing.AccStatsFunction`
            Multiprocessing helper function for each job
        :meth:`.AcousticModelTrainingMixin.acc_stats_arguments`
            Job method for generating arguments for the helper function
        :kaldi_src:`gmm-sum-accs`
            Relevant Kaldi binary
        :kaldi_src:`gmm-est`
            Relevant Kaldi binary
        :kaldi_steps:`train_mono`
            Reference Kaldi script
        :kaldi_steps:`train_deltas`
            Reference Kaldi script
        """
        self.log_info("Accumulating statistics...")
        arguments = self.acc_stats_arguments()
        with tqdm.tqdm(total=self.num_current_utterances,
                       disable=getattr(self, "quiet", False)) as pbar:
            if self.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(arguments):
                    function = AccStatsFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if isinstance(result, Exception):
                            error_dict[getattr(result, "job_name", 0)] = result
                            continue
                        if stopped.stop_check():
                            continue
                    except Empty:
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    num_utterances, errors = result
                    pbar.update(num_utterances + errors)
                for p in procs:
                    p.join()
                if error_dict:
                    for v in error_dict.values():
                        raise v
            else:
                for args in arguments:
                    function = AccStatsFunction(args)
                    for num_utterances, errors in function.run():
                        pbar.update(num_utterances + errors)

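        # Combine the per-job accumulators with gmm-sum-accs and pipe the
        # summed stats straight into gmm-est to re-estimate the model for the
        # next iteration, mixing up to ``current_gaussians`` components.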
        log_path = os.path.join(self.working_log_directory,
                                f"update.{self.iteration}.log")
        with open(log_path, "w") as log_file:
            acc_files = []
            for a in arguments:
                acc_files.extend(a.acc_paths.values())
            sum_proc = subprocess.Popen(
                [thirdparty_binary("gmm-sum-accs"), "-"] + acc_files,
                stdout=subprocess.PIPE,
                stderr=log_file,
                env=os.environ,
            )
            est_command = [
                thirdparty_binary("gmm-est"),
                f"--write-occs={self.next_occs_path}",
                f"--mix-up={self.current_gaussians}",
            ]
            if self.power > 0:
                est_command.append(f"--power={self.power}")
            est_command.extend([
                self.model_path,
                "-",
                self.next_model_path,
            ])
            est_proc = subprocess.Popen(
                est_command,
                stdin=sum_proc.stdout,
                stderr=log_file,
                env=os.environ,
            )
            est_proc.communicate()
        avg_like_pattern = re.compile(
            r"Overall avg like per frame.* = (?P<like>[-.,\d]+) over (?P<frames>[.\d+e]+) frames"
        )
        average_logdet_pattern = re.compile(
            r"Overall average logdet is (?P<logdet>[-.,\d]+) over (?P<frames>[.\d+e]+) frames"
        )
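        # Each job logs a per-frame average; recover an overall likelihood by
        # weighting every job's average by its frame count before dividing,
        # e.g. (like, frames) of (-50.0, 100) and (-60.0, 300) gives
        # (-50*100 - 60*300) / 400 = -57.5.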
        avg_like_sum = 0
        avg_like_frames = 0
        average_logdet_sum = 0
        average_logdet_frames = 0
        for a in arguments:
            with open(a.log_path, "r", encoding="utf8") as f:
                for line in f:
                    m = avg_like_pattern.search(line)
                    if m:
                        like = float(m.group("like"))
                        frames = float(m.group("frames"))
                        avg_like_sum += like * frames
                        avg_like_frames += frames
                    m = average_logdet_pattern.search(line)
                    if m:
                        logdet = float(m.group("logdet"))
                        frames = float(m.group("frames"))
                        average_logdet_sum += logdet * frames
                        average_logdet_frames += frames
        if avg_like_frames:
            log_like = avg_like_sum / avg_like_frames
            if average_logdet_frames:
                log_like += average_logdet_sum / average_logdet_frames
            self.log_debug(
                f"Likelihood for iteration {self.iteration}: {log_like}")

        if not self.debug:
            for f in acc_files:
                os.remove(f)
    def acc_global_stats(self) -> None:
        """
        Multiprocessing function that accumulates global GMM stats

        See Also
        --------
        :func:`~montreal_forced_aligner.ivector.trainer.AccGlobalStatsFunction`
            Multiprocessing helper function for each job
        :meth:`.DubmTrainer.acc_global_stats_arguments`
            Job method for generating arguments for the helper function
        :kaldi_src:`gmm-global-sum-accs`
            Relevant Kaldi binary
        :kaldi_steps:`train_diag_ubm`
            Reference Kaldi script

        """
        begin = time.time()
        self.log_info("Accumulating global stats...")
        arguments = self.acc_global_stats_arguments()

        if self.use_mp:
            error_dict = {}
            return_queue = mp.Queue()
            stopped = Stopped()
            procs = []
            for i, args in enumerate(arguments):
                function = AccGlobalStatsFunction(args)
                p = KaldiProcessWorker(i, return_queue, function, stopped)
                procs.append(p)
                p.start()
            while True:
                try:
                    result = return_queue.get(timeout=1)
                    if isinstance(result, Exception):
                        error_dict[getattr(result, "job_name", 0)] = result
                        continue
                    if stopped.stop_check():
                        continue
                except queue.Empty:
                    for proc in procs:
                        if not proc.finished.stop_check():
                            break
                    else:
                        break
                    continue
            for p in procs:
                p.join()
            if error_dict:
                for v in error_dict.values():
                    raise v

        else:
            self.log_debug("Not using multiprocessing...")
            for args in arguments:
                function = AccGlobalStatsFunction(args)
                for _ in function.run():
                    pass

        self.log_debug(f"Accumulating stats took {time.time() - begin}")

        # Don't remove low-count Gaussians until the last iteration,
        # or the gselect info won't be valid anymore
        if self.iteration < self.num_iterations:
            opt = "--remove-low-count-gaussians=false"
        else:
            opt = f"--remove-low-count-gaussians={self.remove_low_count_gaussians}"
        log_path = os.path.join(self.working_log_directory,
                                f"update.{self.iteration}.log")
        with open(log_path, "w") as log_file:
            acc_files = []
            for j in arguments:
                acc_files.append(j.acc_path)
            sum_proc = subprocess.Popen(
                [thirdparty_binary("gmm-global-sum-accs"), "-"] + acc_files,
                stderr=log_file,
                stdout=subprocess.PIPE,
                env=os.environ,
            )
            gmm_global_est_proc = subprocess.Popen(
                [
                    thirdparty_binary("gmm-global-est"),
                    opt,
                    f"--min-gaussian-weight={self.min_gaussian_weight}",
                    self.model_path,
                    "-",
                    self.next_model_path,
                ],
                stderr=log_file,
                stdin=sum_proc.stdout,
                env=os.environ,
            )
            gmm_global_est_proc.communicate()
            # Clean up
            if not self.debug:
                for p in acc_files:
                    os.remove(p)
    def _load_corpus_from_source_mp(self) -> None:
        """
        Load a corpus using multiprocessing
        """
        begin_time = time.process_time()
        job_queue = mp.Queue()
        return_queue = mp.Queue()
        finished_adding = Stopped()
        stopped = Stopped()
        file_counts = Counter()
        sanitize_function = getattr(self, "sanitize_function", None)
        error_dict = {}
        procs = []
        self.db_engine.dispose()
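        # One AcousticDirectoryParser process walks the corpus directory and
        # feeds file paths into ``job_queue``; CorpusProcessWorker consumers
        # parse the sound and transcription files and push results onto
        # ``return_queue`` for import below.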
        parser = AcousticDirectoryParser(
            self.corpus_directory,
            job_queue,
            self.audio_directory,
            stopped,
            finished_adding,
            file_counts,
        )
        parser.start()
        for i in range(self.num_jobs):
            p = CorpusProcessWorker(
                i,
                job_queue,
                return_queue,
                stopped,
                finished_adding,
                self.speaker_characters,
                sanitize_function,
                self.sample_frequency,
            )
            procs.append(p)
            p.start()
        last_poll = time.time() - 30
        import_data = DatabaseImportData()
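        # Import objects are accumulated in memory as the queue is drained and
        # written to the database in one pass by ``_finalize_load``.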
        try:
            with self.session() as session:
                with tqdm.tqdm(total=100,
                               disable=getattr(self, "quiet", False)) as pbar:
                    while True:
                        try:
                            file = return_queue.get(timeout=1)
                            if isinstance(file, tuple):
                                error_type = file[0]
                                error = file[1]
                                if error_type == "error":
                                    error_dict[error_type] = error
                                else:
                                    if error_type not in error_dict:
                                        error_dict[error_type] = []
                                    error_dict[error_type].append(error)
                                continue
                            if self.stopped.stop_check():
                                continue
                        except Empty:
                            for proc in procs:
                                if not proc.finished_processing.stop_check():
                                    break
                            else:
                                break
                            continue
                        if time.time() - last_poll > 15:
                            pbar.total = file_counts.value()
                            last_poll = time.time()
                        pbar.update(1)
                        import_data.add_objects(
                            self.generate_import_objects(file))

                    self.log_debug(
                        f"Processing queue: {time.process_time() - begin_time}"
                    )

                    if "error" in error_dict:
                        session.rollback()
                        raise error_dict["error"][1]
                    self._finalize_load(session, import_data)
            for k in [
                    "sound_file_errors", "decode_error_files",
                    "textgrid_read_errors"
            ]:
                if hasattr(self, k):
                    if k in error_dict:
                        self.log_info(
                            "There were some issues with files in the corpus. "
                            "Please look at the log file or run the validator for more information."
                        )
                        self.log_debug(
                            f"{k} showed {len(error_dict[k])} errors:")
                        if k in {"textgrid_read_errors", "sound_file_errors"}:
                            getattr(self, k).update(error_dict[k])
                            for e in error_dict[k]:
                                self.log_debug(f"{e.file_name}: {e.error}")
                        else:
                            self.log_debug(", ".join(error_dict[k]))
                            setattr(self, k, error_dict[k])

        except Exception as e:
            if isinstance(e, KeyboardInterrupt):
                self.log_info(
                    "Detected ctrl-c, please wait a moment while we clean everything up..."
                )
                self.stopped.set_sigint_source()
            self.stopped.stop()
            finished_adding.stop()
            while True:
                try:
                    _ = job_queue.get(timeout=1)
                    if self.stopped.stop_check():
                        continue
                except Empty:
                    for proc in procs:
                        if not proc.finished_processing.stop_check():
                            break
                    else:
                        break
                try:
                    _ = return_queue.get(timeout=1)
                    _ = job_queue.get(timeout=1)
                    if self.stopped.stop_check():
                        continue
                except Empty:
                    for proc in procs:
                        if not proc.finished_processing.stop_check():
                            break
                    else:
                        break
        finally:
            parser.join()
            for p in procs:
                p.join()
            if self.stopped.stop_check():
                self.log_info(
                    f"Stopped parsing early ({time.process_time() - begin_time} seconds)"
                )
                if self.stopped.source():
                    sys.exit(0)
            else:
                self.log_debug(
                    f"Parsed corpus directory with {self.num_jobs} jobs in {time.process_time() - begin_time} seconds"
                )
    def acc_ivector_stats(self) -> None:
        """
        Multiprocessing function that accumulates ivector extraction stats.

        See Also
        --------
        :func:`~montreal_forced_aligner.ivector.trainer.AccIvectorStatsFunction`
            Multiprocessing helper function for each job
        :meth:`.IvectorTrainer.acc_ivector_stats_arguments`
            Job method for generating arguments for the helper function
        :kaldi_src:`ivector-extractor-sum-accs`
            Relevant Kaldi binary
        :kaldi_src:`ivector-extractor-est`
            Relevant Kaldi binary
        :kaldi_steps_sid:`train_ivector_extractor`
            Reference Kaldi script
        """

        begin = time.time()
        self.log_info("Accumulating ivector stats...")
        arguments = self.acc_ivector_stats_arguments()

        if self.use_mp:
            error_dict = {}
            return_queue = mp.Queue()
            stopped = Stopped()
            procs = []
            for i, args in enumerate(arguments):
                function = AccIvectorStatsFunction(args)
                p = KaldiProcessWorker(i, return_queue, function, stopped)
                procs.append(p)
                p.start()
            while True:
                try:
                    result = return_queue.get(timeout=1)
                    if stopped.stop_check():
                        continue
                except queue.Empty:
                    for proc in procs:
                        if not proc.finished.stop_check():
                            break
                    else:
                        break
                    continue
                if isinstance(result, KaldiProcessingError):
                    error_dict[result.job_name] = result
                    continue
            for p in procs:
                p.join()
            if error_dict:
                for v in error_dict.values():
                    raise v

        else:
            self.log_debug("Not using multiprocessing...")
            for args in arguments:
                function = AccIvectorStatsFunction(args)
                for _ in function.run():
                    pass

        self.log_debug(f"Accumulating stats took {time.time() - begin}")

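        # Merge the per-job accumulators with ivector-extractor-sum-accs, then
        # re-estimate the extractor with ivector-extractor-est to produce the
        # model for the next iteration.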
        log_path = os.path.join(self.working_log_directory,
                                f"sum_acc.{self.iteration}.log")
        acc_path = os.path.join(self.working_directory,
                                f"acc.{self.iteration}")
        with open(log_path, "w", encoding="utf8") as log_file:
            accinits = []
            for j in arguments:
                accinits.append(j.acc_init_path)
            sum_accs_proc = subprocess.Popen(
                [
                    thirdparty_binary("ivector-extractor-sum-accs"),
                    "--parallel=true"
                ] + accinits + [acc_path],
                stderr=log_file,
                env=os.environ,
            )

            sum_accs_proc.communicate()
        # clean up
        for p in accinits:
            os.remove(p)
        # Est extractor
        log_path = os.path.join(self.working_log_directory,
                                f"update.{self.iteration}.log")
        with open(log_path, "w") as log_file:
            extractor_est_proc = subprocess.Popen(
                [
                    thirdparty_binary("ivector-extractor-est"),
                    f"--num-threads={len(self.jobs)}",
                    f"--gaussian-min-count={self.gaussian_min_count}",
                    self.ie_path,
                    os.path.join(self.working_directory,
                                 f"acc.{self.iteration}"),
                    self.next_ie_path,
                ],
                stderr=log_file,
                env=os.environ,
            )
            extractor_est_proc.communicate()
    def create_align_model(self) -> None:
        """
        Create alignment model for speaker-adapted training that will use speaker-independent
        features in later aligning.

        See Also
        --------
        :func:`~montreal_forced_aligner.acoustic_modeling.sat.AccStatsTwoFeatsFunction`
            Multiprocessing helper function for each job
        :meth:`.SatTrainer.acc_stats_two_feats_arguments`
            Job method for generating arguments for the helper function
        :kaldi_src:`gmm-est`
            Relevant Kaldi binary
        :kaldi_src:`gmm-sum-accs`
            Relevant Kaldi binary
        :kaldi_steps:`train_sat`
            Reference Kaldi script
        """
        self.log_info(
            "Creating alignment model for speaker-independent features...")
        begin = time.time()

        arguments = self.acc_stats_two_feats_arguments()
        with tqdm.tqdm(total=self.num_current_utterances,
                       disable=getattr(self, "quiet", False)) as pbar:
            if self.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(arguments):
                    function = AccStatsTwoFeatsFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if isinstance(result, Exception):
                            error_dict[getattr(result, "job_name", 0)] = result
                            continue
                        if stopped.stop_check():
                            continue
                    except Empty:
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    pbar.update(1)
                for p in procs:
                    p.join()
                if error_dict:
                    for v in error_dict.values():
                        raise v
            else:
                for args in arguments:
                    function = AccStatsTwoFeatsFunction(args)
                    for _ in function.run():
                        pbar.update(1)

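        # Sum the two-feature accumulators and re-estimate a parallel ".alimdl"
        # model: the same alignments scored against speaker-independent
        # features, usable for first-pass alignment without fMLLR transforms.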
        log_path = os.path.join(self.working_log_directory,
                                "align_model_est.log")
        with open(log_path, "w", encoding="utf8") as log_file:

            acc_files = []
            for x in arguments:
                acc_files.extend(x.acc_paths.values())
            sum_proc = subprocess.Popen(
                [thirdparty_binary("gmm-sum-accs"), "-"] + acc_files,
                stderr=log_file,
                stdout=subprocess.PIPE,
                env=os.environ,
            )
            est_command = [
                thirdparty_binary("gmm-est"),
                "--remove-low-count-gaussians=false",
            ]
            if not self.quick:
                est_command.append(f"--power={self.power}")
            else:
                est_command.append(
                    f"--write-occs={os.path.join(self.working_directory, 'final.occs')}"
                )
            est_command.extend([
                self.model_path,
                "-",
                self.model_path.replace(".mdl", ".alimdl"),
            ])
            est_proc = subprocess.Popen(
                est_command,
                stdin=sum_proc.stdout,
                stderr=log_file,
                env=os.environ,
            )
            est_proc.communicate()
        parse_logs(self.working_log_directory)
        if not self.debug:
            for f in acc_files:
                os.remove(f)
        self.log_debug(f"Alignment model creation took {time.time() - begin}")