def __init__(
    self,
    corpus_directory: str,
    speaker_characters: typing.Union[int, str] = 0,
    ignore_speakers: bool = False,
    **kwargs,
):
    if not os.path.exists(corpus_directory):
        raise CorpusError(f"The directory '{corpus_directory}' does not exist.")
    if not os.path.isdir(corpus_directory):
        raise CorpusError(
            f"The specified path for the corpus ({corpus_directory}) is not a directory."
        )
    self.corpus_directory = corpus_directory
    self.speaker_characters = speaker_characters
    self.ignore_speakers = ignore_speakers
    self.word_counts = Counter()
    self.stopped = Stopped()
    self.decode_error_files = []
    self.textgrid_read_errors = []
    self.jobs: typing.List[Job] = []
    self._num_speakers = None
    self._num_utterances = None
    self._num_files = None
    super().__init__(**kwargs)
    os.makedirs(self.corpus_output_directory, exist_ok=True)
    self.imported = False
    self._current_speaker_index = 1
    self._current_file_index = 1
    self._speaker_ids = {}
def convert_alignments(self) -> None:
    """
    Multiprocessing function that converts alignments from previous training.

    See Also
    --------
    :func:`~montreal_forced_aligner.acoustic_modeling.triphone.ConvertAlignmentsFunction`
        Multiprocessing helper function for each job
    :meth:`.TriphoneTrainer.convert_alignments_arguments`
        Job method for generating arguments for the helper function
    :kaldi_steps:`train_deltas`
        Reference Kaldi script
    :kaldi_steps:`train_lda_mllt`
        Reference Kaldi script
    :kaldi_steps:`train_sat`
        Reference Kaldi script
    """
    self.log_info("Converting alignments...")
    arguments = self.convert_alignments_arguments()
    with tqdm.tqdm(
        total=self.num_current_utterances, disable=getattr(self, "quiet", False)
    ) as pbar:
        if self.use_mp:
            error_dict = {}
            return_queue = mp.Queue()
            stopped = Stopped()
            procs = []
            for i, args in enumerate(arguments):
                function = ConvertAlignmentsFunction(args)
                p = KaldiProcessWorker(i, return_queue, function, stopped)
                procs.append(p)
                p.start()
            while True:
                try:
                    result = return_queue.get(timeout=1)
                    if isinstance(result, Exception):
                        error_dict[getattr(result, "job_name", 0)] = result
                        continue
                    if stopped.stop_check():
                        continue
                except Empty:
                    for proc in procs:
                        if not proc.finished.stop_check():
                            break
                    else:
                        break
                    continue
                num_utterances, errors = result
                pbar.update(num_utterances + errors)
            for p in procs:
                p.join()
            if error_dict:
                for v in error_dict.values():
                    raise v
        else:
            for args in arguments:
                function = ConvertAlignmentsFunction(args)
                for num_utterances, errors in function.run():
                    pbar.update(num_utterances + errors)
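# --- Illustration (not part of MFA) ---------------------------------------
# The training functions in this module share one multiprocessing skeleton:
# start a KaldiProcessWorker per job, drain a return queue with a short
# timeout, and stop once every worker has flagged itself finished. Below is
# a minimal, self-contained sketch of that skeleton using only the standard
# library; the `_square_worker` payload is a stand-in for a Kaldi job, not
# anything MFA defines.
import multiprocessing
from queue import Empty


def _square_worker(job, return_q, finished):
    # Stand-in for a Kaldi job's per-utterance work.
    return_q.put(job * job)
    finished.set()


def drain_example():
    # On spawn-based platforms (Windows, macOS), call this under
    # `if __name__ == "__main__":`.
    return_q = multiprocessing.Queue()
    flags, procs, results = [], [], []
    for i in range(4):
        finished = multiprocessing.Event()
        flags.append(finished)
        p = multiprocessing.Process(target=_square_worker, args=(i, return_q, finished))
        procs.append(p)
        p.start()
    while True:
        try:
            results.append(return_q.get(timeout=1))
        except Empty:
            # Same idiom as convert_alignments: only stop once all workers
            # report finished and the queue has gone quiet.
            if all(f.is_set() for f in flags):
                break
    for p in procs:
        p.join()
    return results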
def compute_vad(self) -> None:
    """
    Compute Voice Activity Detection features over the corpus.

    See Also
    --------
    :class:`~montreal_forced_aligner.corpus.features.ComputeVadFunction`
        Multiprocessing helper function for each job
    :meth:`.AcousticCorpusMixin.compute_vad_arguments`
        Job method for generating arguments for the helper function
    """
    if os.path.exists(os.path.join(self.split_directory, "vad.0.scp")):
        self.log_info("VAD already computed, skipping!")
        return
    begin = time.time()
    self.log_info("Computing VAD...")
    arguments = self.compute_vad_arguments()
    with tqdm.tqdm(
        total=self.num_speakers, disable=getattr(self, "quiet", False)
    ) as pbar:
        if self.use_mp:
            error_dict = {}
            return_queue = mp.Queue()
            stopped = Stopped()
            procs = []
            for i, args in enumerate(arguments):
                function = ComputeVadFunction(args)
                p = KaldiProcessWorker(i, return_queue, function, stopped)
                procs.append(p)
                p.start()
            while True:
                try:
                    result = return_queue.get(timeout=1)
                    if stopped.stop_check():
                        continue
                except Empty:
                    for proc in procs:
                        if not proc.finished.stop_check():
                            break
                    else:
                        break
                    continue
                if isinstance(result, KaldiProcessingError):
                    error_dict[result.job_name] = result
                    continue
                done, no_feats, unvoiced = result
                pbar.update(done + no_feats + unvoiced)
            for p in procs:
                p.join()
            if error_dict:
                for v in error_dict.values():
                    raise v
        else:
            for args in arguments:
                function = ComputeVadFunction(args)
                for done, no_feats, unvoiced in function.run():
                    pbar.update(done + no_feats + unvoiced)
    self.log_debug(f"VAD computation took {time.time() - begin}")
def gmm_gselect(self) -> None:
    """
    Multiprocessing function that stores Gaussian selection indices on disk.

    See Also
    --------
    :func:`~montreal_forced_aligner.ivector.trainer.GmmGselectFunction`
        Multiprocessing helper function for each job
    :meth:`.DubmTrainer.gmm_gselect_arguments`
        Job method for generating arguments for the helper function
    :kaldi_steps:`train_diag_ubm`
        Reference Kaldi script
    """
    begin = time.time()
    self.log_info("Selecting Gaussians...")
    arguments = self.gmm_gselect_arguments()
    if self.use_mp:
        error_dict = {}
        return_queue = mp.Queue()
        stopped = Stopped()
        procs = []
        for i, args in enumerate(arguments):
            function = GmmGselectFunction(args)
            p = KaldiProcessWorker(i, return_queue, function, stopped)
            procs.append(p)
            p.start()
        while True:
            try:
                result = return_queue.get(timeout=1)
                if isinstance(result, Exception):
                    error_dict[getattr(result, "job_name", 0)] = result
                    continue
                if stopped.stop_check():
                    continue
            except queue.Empty:
                for proc in procs:
                    if not proc.finished.stop_check():
                        break
                else:
                    break
                continue
        for p in procs:
            p.join()
        if error_dict:
            for v in error_dict.values():
                raise v
    else:
        self.log_debug("Not using multiprocessing...")
        for args in arguments:
            function = GmmGselectFunction(args)
            for _ in function.run():
                pass
    self.log_debug(f"Gaussian selection took {time.time() - begin}")
def extract_ivectors(self) -> None:
    """
    Multiprocessing function that extracts ivectors.

    See Also
    --------
    :class:`~montreal_forced_aligner.corpus.features.ExtractIvectorsFunction`
        Multiprocessing helper function for each job
    :meth:`.IvectorCorpusMixin.extract_ivectors_arguments`
        Job method for generating arguments for the helper function
    :kaldi_steps_sid:`extract_ivectors`
        Reference Kaldi script
    """
    begin = time.time()
    log_dir = self.working_log_directory
    os.makedirs(log_dir, exist_ok=True)
    arguments = self.extract_ivectors_arguments()
    with tqdm.tqdm(
        total=self.num_speakers, disable=getattr(self, "quiet", False)
    ) as pbar:
        if self.use_mp:
            error_dict = {}
            return_queue = mp.Queue()
            stopped = Stopped()
            procs = []
            for i, args in enumerate(arguments):
                function = ExtractIvectorsFunction(args)
                p = KaldiProcessWorker(i, return_queue, function, stopped)
                procs.append(p)
                p.start()
            while True:
                try:
                    result = return_queue.get(timeout=1)
                    if isinstance(result, Exception):
                        error_dict[getattr(result, "job_name", 0)] = result
                        continue
                    if stopped.stop_check():
                        continue
                except Empty:
                    for proc in procs:
                        if not proc.finished.stop_check():
                            break
                    else:
                        break
                    continue
                pbar.update(1)
            for p in procs:
                p.join()
            if error_dict:
                for v in error_dict.values():
                    raise v
        else:
            for args in arguments:
                function = ExtractIvectorsFunction(args)
                for _ in function.run():
                    pbar.update(1)
    self.log_debug(f"Ivector extraction took {time.time() - begin}")
def gauss_to_post(self) -> None:
    """
    Multiprocessing function that does Gaussian selection and posterior extraction.

    See Also
    --------
    :func:`~montreal_forced_aligner.ivector.trainer.GaussToPostFunction`
        Multiprocessing helper function for each job
    :meth:`.IvectorTrainer.gauss_to_post_arguments`
        Job method for generating arguments for the helper function
    :kaldi_steps_sid:`train_ivector_extractor`
        Reference Kaldi script
    """
    begin = time.time()
    self.log_info("Extracting posteriors...")
    arguments = self.gauss_to_post_arguments()
    if self.use_mp:
        error_dict = {}
        return_queue = mp.Queue()
        stopped = Stopped()
        procs = []
        for i, args in enumerate(arguments):
            function = GaussToPostFunction(args)
            p = KaldiProcessWorker(i, return_queue, function, stopped)
            procs.append(p)
            p.start()
        while True:
            try:
                result = return_queue.get(timeout=1)
                if stopped.stop_check():
                    continue
            except queue.Empty:
                for proc in procs:
                    if not proc.finished.stop_check():
                        break
                else:
                    break
                continue
            if isinstance(result, KaldiProcessingError):
                error_dict[result.job_name] = result
                continue
        for p in procs:
            p.join()
        if error_dict:
            for v in error_dict.values():
                raise v
    else:
        self.log_debug("Not using multiprocessing...")
        for args in arguments:
            function = GaussToPostFunction(args)
            for _ in function.run():
                pass
    self.log_debug(f"Extracting posteriors took {time.time() - begin}")
def __init__(self, audio_directory: Optional[str] = None, **kwargs):
    super().__init__(**kwargs)
    self.audio_directory = audio_directory
    self.sound_file_errors = []
    self.transcriptions_without_wavs = []
    self.no_transcription_files = []
    self.stopped = Stopped()
    self.features_generated = False
    self.alignment_done = False
    self.transcription_done = False
    self.has_reference_alignments = False
    self.alignment_evaluation_done = False
def __init__(
    self,
    job_queue: mp.Queue,
    return_queue: mp.Queue,
    rewriter: Rewriter,
    stopped: Stopped,
):
    mp.Process.__init__(self)
    self.job_queue = job_queue
    self.return_queue = return_queue
    self.rewriter = rewriter
    self.stopped = stopped
    self.finished = Stopped()
def __init__(
    self,
    job_name: int,
    job_q: mp.Queue,
    return_queue: mp.Queue,
    log_file: str,
    stopped: Stopped,
):
    mp.Process.__init__(self)
    self.job_name = job_name
    self.job_q = job_q
    self.return_queue = return_queue
    self.log_file = log_file
    self.stopped = stopped
    self.finished = Stopped()
def __init__(
    self,
    name: int,
    job_q: mp.Queue,
    return_q: mp.Queue,
    stopped: Stopped,
    finished_adding: Stopped,
    speaker_characters: Union[int, str],
    sanitize_function: Optional[MultispeakerSanitizationFunction],
    sample_rate: Optional[int],
):
    mp.Process.__init__(self)
    self.name = str(name)
    self.job_q = job_q
    self.return_q = return_q
    self.stopped = stopped
    self.finished_adding = finished_adding
    self.finished_processing = Stopped()
    self.sanitize_function = sanitize_function
    self.speaker_characters = speaker_characters
    self.sample_rate = sample_rate
def __init__(
    self,
    db_path: str,
    for_write_queue: mp.Queue,
    return_queue: mp.Queue,
    stopped: Stopped,
    finished_adding: Stopped,
    arguments: ExportTextGridArguments,
    exported_file_count: Counter,
):
    mp.Process.__init__(self)
    self.db_path = db_path
    self.for_write_queue = for_write_queue
    self.return_queue = return_queue
    self.stopped = stopped
    self.finished_adding = finished_adding
    self.finished_processing = Stopped()
    self.output_directory = arguments.output_directory
    self.output_format = arguments.output_format
    self.frame_shift = arguments.frame_shift
    self.log_path = arguments.log_path
    self.include_original_text = arguments.include_original_text
    self.exported_file_count = exported_file_count
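# --- Illustration (not part of MFA) ---------------------------------------
# Every worker above takes one or more `Stopped` objects. Judging from how
# they are used in this file (stop(), stop_check(), set_sigint_source(),
# source()), `Stopped` is essentially a process-shared boolean flag. Here is
# a minimal sketch of that interface built on shared-memory Values; MFA's
# real implementation may differ.
import multiprocessing


class StoppedSketch:
    def __init__(self, initval: bool = False):
        self._value = multiprocessing.Value("i", int(initval))
        self._sigint = multiprocessing.Value("i", 0)

    def stop(self) -> None:
        # Signal all processes holding this object to wind down.
        with self._value.get_lock():
            self._value.value = 1

    def stop_check(self) -> bool:
        with self._value.get_lock():
            return bool(self._value.value)

    def set_sigint_source(self) -> None:
        # Record that the stop came from a keyboard interrupt.
        with self._sigint.get_lock():
            self._sigint.value = 1

    def source(self) -> bool:
        with self._sigint.get_lock():
            return bool(self._sigint.value)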
def calc_fmllr(self, iteration: Optional[int] = None) -> None:
    """
    Multiprocessing function that computes speaker adaptation transforms via
    feature-space Maximum Likelihood Linear Regression (fMLLR).

    See Also
    --------
    :class:`~montreal_forced_aligner.corpus.features.CalcFmllrFunction`
        Multiprocessing helper function for each job
    :meth:`.AcousticCorpusMixin.calc_fmllr_arguments`
        Job method for generating arguments for the helper function
    :kaldi_steps:`align_fmllr`
        Reference Kaldi script
    :kaldi_steps:`train_sat`
        Reference Kaldi script
    """
    begin = time.time()
    self.log_info("Calculating fMLLR for speaker adaptation...")
    arguments = self.calc_fmllr_arguments(iteration=iteration)
    with tqdm.tqdm(
        total=self.num_speakers, disable=getattr(self, "quiet", False)
    ) as pbar:
        if self.use_mp:
            error_dict = {}
            return_queue = mp.Queue()
            stopped = Stopped()
            procs = []
            for i, args in enumerate(arguments):
                function = CalcFmllrFunction(args)
                p = KaldiProcessWorker(i, return_queue, function, stopped)
                procs.append(p)
                p.start()
            while True:
                try:
                    result = return_queue.get(timeout=1)
                    if isinstance(result, Exception):
                        error_dict[getattr(result, "job_name", 0)] = result
                        continue
                    if stopped.stop_check():
                        continue
                except Empty:
                    for proc in procs:
                        if not proc.finished.stop_check():
                            break
                    else:
                        break
                    continue
                pbar.update(1)
            for p in procs:
                p.join()
            if error_dict:
                for v in error_dict.values():
                    raise v
        else:
            for args in arguments:
                function = CalcFmllrFunction(args)
                for _ in function.run():
                    pbar.update(1)
    self.speaker_independent = False
    self.log_debug(f"fMLLR calculation took {time.time() - begin}")
def segment_vad(self) -> None:
    """
    Run segmentation based on VAD.

    See Also
    --------
    :class:`~montreal_forced_aligner.segmenter.SegmentVadFunction`
        Multiprocessing helper function for each job
    segment_vad_arguments
        Job method for generating arguments for the helper function
    """
    arguments = self.segment_vad_arguments()
    old_utts = set()
    new_utts = []
    with tqdm.tqdm(
        total=self.num_utterances, disable=getattr(self, "quiet", False)
    ) as pbar, self.session() as session:
        if self.use_mp:
            error_dict = {}
            return_queue = mp.Queue()
            stopped = Stopped()
            procs = []
            for i, args in enumerate(arguments):
                function = SegmentVadFunction(args)
                p = KaldiProcessWorker(i, return_queue, function, stopped)
                procs.append(p)
                p.start()
            while True:
                try:
                    result = return_queue.get(timeout=1)
                    if isinstance(result, Exception):
                        error_dict[getattr(result, "job_name", 0)] = result
                        continue
                    if stopped.stop_check():
                        continue
                except Empty:
                    for proc in procs:
                        if not proc.finished.stop_check():
                            break
                    else:
                        break
                    continue
                utt, begin, end = result
                old_utts.add(utt)
                channel, speaker_id, file_id = (
                    session.query(
                        Utterance.channel, Utterance.speaker_id, Utterance.file_id
                    )
                    .filter(Utterance.id == utt)
                    .first()
                )
                new_utts.append(
                    {
                        "begin": begin,
                        "end": end,
                        "text": "speech",
                        "speaker_id": speaker_id,
                        "file_id": file_id,
                        "oovs": "",
                        "normalized_text": "",
                        "normalized_text_int": "",
                        "features": "",
                        "in_subset": False,
                        "ignored": False,
                        "channel": channel,
                        "duration": end - begin,
                    }
                )
                pbar.update(1)
            for p in procs:
                p.join()
            if error_dict:
                for v in error_dict.values():
                    raise v
        else:
            for args in arguments:
                function = SegmentVadFunction(args)
                for utt, begin, end in function.run():
                    old_utts.add(utt)
                    channel, speaker_id, file_id = (
                        session.query(
                            Utterance.channel, Utterance.speaker_id, Utterance.file_id
                        )
                        .filter(Utterance.id == utt)
                        .first()
                    )
                    new_utts.append(
                        {
                            "begin": begin,
                            "end": end,
                            "text": "speech",
                            "speaker_id": speaker_id,
                            "file_id": file_id,
                            "oovs": "",
                            "normalized_text": "",
                            "normalized_text_int": "",
                            "features": "",
                            "in_subset": False,
                            "ignored": False,
                            "channel": channel,
                            "duration": end - begin,
                        }
                    )
                    pbar.update(1)
        session.query(Utterance).filter(Utterance.id.in_(old_utts)).delete()
        session.bulk_insert_mappings(
            Utterance, new_utts, return_defaults=False, render_nulls=True
        )
        session.commit()
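# --- Illustration (not part of MFA) ---------------------------------------
# segment_vad replaces utterances wholesale: delete the re-segmented rows,
# then bulk_insert_mappings() the new segments as plain dicts. The same
# delete-then-bulk-insert pattern on a toy table; the `Segment` model and
# the in-memory SQLite URL are illustrative, and SQLAlchemy 1.4+ is assumed.
import sqlalchemy as sa
from sqlalchemy.orm import declarative_base, sessionmaker

_Base = declarative_base()


class Segment(_Base):
    __tablename__ = "segment"
    id = sa.Column(sa.Integer, primary_key=True)
    begin = sa.Column(sa.Float)
    end = sa.Column(sa.Float)


def bulk_replace_example() -> int:
    engine = sa.create_engine("sqlite://")
    _Base.metadata.create_all(engine)
    with sessionmaker(bind=engine)() as session:
        session.add(Segment(id=1, begin=0.0, end=5.0))
        session.commit()
        # Drop the rows that were re-segmented...
        session.query(Segment).filter(Segment.id.in_([1])).delete(
            synchronize_session=False
        )
        # ...and insert their replacements without building ORM objects.
        session.bulk_insert_mappings(
            Segment, [{"begin": 0.0, "end": 2.4}, {"begin": 2.6, "end": 5.0}]
        )
        session.commit()
        return session.query(Segment).count()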
def align_utterances(self, training=False) -> None:
    """
    Multiprocessing function that aligns based on the current model.

    See Also
    --------
    :class:`~montreal_forced_aligner.alignment.multiprocessing.AlignFunction`
        Multiprocessing helper function for each job
    :meth:`.AlignMixin.align_arguments`
        Job method for generating arguments for the helper function
    :kaldi_steps:`align_si`
        Reference Kaldi script
    :kaldi_steps:`align_fmllr`
        Reference Kaldi script
    """
    begin = time.time()
    self.log_info("Generating alignments...")
    with tqdm.tqdm(
        total=self.num_current_utterances, disable=getattr(self, "quiet", False)
    ) as pbar, self.session() as session:
        if not training:
            utterances = session.query(Utterance)
            if hasattr(self, "subset"):
                utterances = utterances.filter(Utterance.in_subset == True)  # noqa
            utterances.update({"alignment_log_likelihood": None})
            session.commit()
        update_mappings = []
        if self.use_mp:
            error_dict = {}
            return_queue = mp.Queue()
            stopped = Stopped()
            procs = []
            for i, args in enumerate(self.align_arguments()):
                function = AlignFunction(args)
                p = KaldiProcessWorker(i, return_queue, function, stopped)
                procs.append(p)
                p.start()
            while True:
                try:
                    result = return_queue.get(timeout=1)
                    if isinstance(result, Exception):
                        error_dict[getattr(result, "job_name", 0)] = result
                        continue
                    if stopped.stop_check():
                        continue
                except Empty:
                    for proc in procs:
                        if not proc.finished.stop_check():
                            break
                    else:
                        break
                    continue
                if not training:
                    utterance, log_likelihood = result
                    update_mappings.append(
                        {"id": utterance, "alignment_log_likelihood": log_likelihood}
                    )
                pbar.update(1)
            for p in procs:
                p.join()
            if not training and len(update_mappings) == 0:
                raise NoAlignmentsError(
                    self.num_current_utterances, self.beam, self.retry_beam
                )
            if error_dict:
                for v in error_dict.values():
                    raise v
        else:
            self.log_debug("Not using multiprocessing...")
            for args in self.align_arguments():
                function = AlignFunction(args)
                for utterance, log_likelihood in function.run():
                    if not training:
                        update_mappings.append(
                            {
                                "id": utterance,
                                "alignment_log_likelihood": log_likelihood,
                            }
                        )
                    pbar.update(1)
            if not training and len(update_mappings) == 0:
                raise NoAlignmentsError(
                    self.num_current_utterances, self.beam, self.retry_beam
                )
        if not training:
            session.bulk_update_mappings(Utterance, update_mappings)
            session.query(Utterance).filter(
                Utterance.alignment_log_likelihood != None  # noqa
            ).update(
                {
                    Utterance.alignment_log_likelihood: Utterance.alignment_log_likelihood
                    / Utterance.num_frames
                },
                synchronize_session="fetch",
            )
            session.commit()
    self.log_debug(f"Alignment round took {time.time() - begin}")
def _load_corpus_from_source_mp(self) -> None:
    """
    Load a corpus using multiprocessing.
    """
    if self.stopped is None:
        self.stopped = Stopped()
    sanitize_function = getattr(self, "sanitize_function", None)
    begin_time = time.time()
    job_queue = mp.Queue()
    return_queue = mp.Queue()
    error_dict = {}
    finished_adding = Stopped()
    procs = []
    for i in range(self.num_jobs):
        p = CorpusProcessWorker(
            i,
            job_queue,
            return_queue,
            self.stopped,
            finished_adding,
            self.speaker_characters,
            sanitize_function,
            sample_rate=0,
        )
        procs.append(p)
        p.start()
    import_data = DatabaseImportData()
    try:
        file_count = 0
        with tqdm.tqdm(
            total=1, disable=getattr(self, "quiet", False)
        ) as pbar, self.session() as session:
            for root, _, files in os.walk(self.corpus_directory, followlinks=True):
                exts = find_exts(files)
                relative_path = (
                    root.replace(self.corpus_directory, "").lstrip("/").lstrip("\\")
                )
                if self.stopped.stop_check():
                    break
                for file_name in exts.identifiers:
                    if self.stopped.stop_check():
                        break
                    wav_path = None
                    if file_name in exts.lab_files:
                        lab_name = exts.lab_files[file_name]
                        transcription_path = os.path.join(root, lab_name)
                    elif file_name in exts.textgrid_files:
                        tg_name = exts.textgrid_files[file_name]
                        transcription_path = os.path.join(root, tg_name)
                    else:
                        continue
                    job_queue.put(
                        (file_name, wav_path, transcription_path, relative_path)
                    )
                    file_count += 1
            pbar.total = file_count
            finished_adding.stop()
            while True:
                try:
                    file = return_queue.get(timeout=1)
                    if isinstance(file, tuple):
                        error_type = file[0]
                        error = file[1]
                        if error_type == "error":
                            error_dict[error_type] = error
                        else:
                            if error_type not in error_dict:
                                error_dict[error_type] = []
                            error_dict[error_type].append(error)
                        continue
                    if self.stopped.stop_check():
                        continue
                except Empty:
                    for proc in procs:
                        if not proc.finished_processing.stop_check():
                            break
                    else:
                        break
                    continue
                pbar.update(1)
                import_data.add_objects(self.generate_import_objects(file))
            self.log_debug("Waiting for workers to finish...")
            for p in procs:
                p.join()
            if "error" in error_dict:
                session.rollback()
                raise error_dict["error"]
            self._finalize_load(session, import_data)
        for k in ["decode_error_files", "textgrid_read_errors"]:
            if hasattr(self, k):
                if k in error_dict:
                    self.log_info(
                        "There were some issues with files in the corpus. "
                        "Please look at the log file or run the validator for more information."
                    )
                    self.log_debug(f"{k} showed {len(error_dict[k])} errors:")
                    if k == "textgrid_read_errors":
                        getattr(self, k).update(error_dict[k])
                        for e in error_dict[k]:
                            self.log_debug(f"{e.file_name}: {e.error}")
                    else:
                        self.log_debug(", ".join(error_dict[k]))
                        setattr(self, k, error_dict[k])
    except KeyboardInterrupt:
        self.log_info(
            "Detected ctrl-c, please wait a moment while we clean everything up..."
        )
        self.stopped.stop()
        finished_adding.stop()
        self.stopped.set_sigint_source()
        while True:
            try:
                _ = return_queue.get(timeout=1)
                if self.stopped.stop_check():
                    continue
            except Empty:
                for proc in procs:
                    if not proc.finished_processing.stop_check():
                        break
                else:
                    break
    finally:
        finished_adding.stop()
        for p in procs:
            p.join()
    if self.stopped.stop_check():
        self.log_info(f"Stopped parsing early ({time.time() - begin_time} seconds)")
        if self.stopped.source():
            sys.exit(0)
    else:
        self.log_debug(
            f"Parsed corpus directory with {self.num_jobs} jobs in {time.time() - begin_time} seconds"
        )
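# --- Illustration (not part of MFA) ---------------------------------------
# The relative-path computation above can be checked in isolation; on POSIX
# paths it matches os.path.relpath. The paths below are made up.
import os

_corpus_directory = "/data/corpus"
_root = "/data/corpus/speaker1/session2"
_relative = _root.replace(_corpus_directory, "").lstrip("/").lstrip("\\")
assert _relative == "speaker1/session2"
assert os.path.relpath(_root, _corpus_directory) == _relative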
def acc_stats(self, alignment: bool = False) -> None:
    """
    Accumulate stats for the mapped model.

    Parameters
    ----------
    alignment: bool
        Flag for whether to accumulate stats for the mapped alignment model
    """
    arguments = self.map_acc_stats_arguments(alignment)
    if alignment:
        initial_mdl_path = os.path.join(self.working_directory, "unadapted.alimdl")
        final_mdl_path = os.path.join(self.working_directory, "final.alimdl")
    else:
        initial_mdl_path = os.path.join(self.working_directory, "unadapted.mdl")
        final_mdl_path = os.path.join(self.working_directory, "final.mdl")
    if not os.path.exists(initial_mdl_path):
        return
    self.log_info("Accumulating statistics...")
    with tqdm.tqdm(
        total=self.num_current_utterances, disable=getattr(self, "quiet", False)
    ) as pbar:
        if self.use_mp:
            error_dict = {}
            return_queue = mp.Queue()
            stopped = Stopped()
            procs = []
            for i, args in enumerate(arguments):
                function = AccStatsFunction(args)
                p = KaldiProcessWorker(i, return_queue, function, stopped)
                procs.append(p)
                p.start()
            while True:
                try:
                    result = return_queue.get(timeout=1)
                    if isinstance(result, Exception):
                        error_dict[getattr(result, "job_name", 0)] = result
                        continue
                    if stopped.stop_check():
                        continue
                except Empty:
                    for proc in procs:
                        if not proc.finished.stop_check():
                            break
                    else:
                        break
                    continue
                num_utterances, errors = result
                pbar.update(num_utterances + errors)
            for p in procs:
                p.join()
            if error_dict:
                for v in error_dict.values():
                    raise v
        else:
            for args in arguments:
                function = AccStatsFunction(args)
                for num_utterances, errors in function.run():
                    pbar.update(num_utterances + errors)
    log_path = os.path.join(self.working_log_directory, "map_model_est.log")
    occs_path = os.path.join(self.working_directory, "final.occs")
    with open(log_path, "w", encoding="utf8") as log_file:
        acc_files = []
        for j in arguments:
            acc_files.extend(j.acc_paths.values())
        sum_proc = subprocess.Popen(
            [thirdparty_binary("gmm-sum-accs"), "-"] + acc_files,
            stderr=log_file,
            stdout=subprocess.PIPE,
            env=os.environ,
        )
        ismooth_proc = subprocess.Popen(
            [
                thirdparty_binary("gmm-ismooth-stats"),
                "--smooth-from-model",
                f"--tau={self.mapping_tau}",
                initial_mdl_path,
                "-",
                "-",
            ],
            stderr=log_file,
            stdin=sum_proc.stdout,
            stdout=subprocess.PIPE,
            env=os.environ,
        )
        est_proc = subprocess.Popen(
            [
                thirdparty_binary("gmm-est"),
                "--update-flags=m",
                f"--write-occs={occs_path}",
                "--remove-low-count-gaussians=false",
                initial_mdl_path,
                "-",
                final_mdl_path,
            ],
            stdin=ismooth_proc.stdout,
            stderr=log_file,
            env=os.environ,
        )
        est_proc.communicate()
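# --- Illustration (not part of MFA) ---------------------------------------
# acc_stats wires three Kaldi binaries into a pipeline by passing one Popen's
# stdout as the next one's stdin (gmm-sum-accs -> gmm-ismooth-stats ->
# gmm-est). The same wiring with ubiquitous POSIX tools, to make the data
# flow explicit; the commands here are placeholders, not Kaldi.
import subprocess


def pipeline_example() -> str:
    first = subprocess.Popen(["printf", "b\\na\\nb\\n"], stdout=subprocess.PIPE)
    second = subprocess.Popen(["sort"], stdin=first.stdout, stdout=subprocess.PIPE)
    third = subprocess.Popen(["uniq", "-c"], stdin=second.stdout, stdout=subprocess.PIPE)
    # Close our handles so SIGPIPE propagates if a downstream stage exits early.
    first.stdout.close()
    second.stdout.close()
    out, _ = third.communicate()
    return out.decode()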
def compile_train_graphs(self) -> None:
    """
    Multiprocessing function that compiles training graphs for utterances.

    See Also
    --------
    :class:`~montreal_forced_aligner.alignment.multiprocessing.CompileTrainGraphsFunction`
        Multiprocessing helper function for each job
    :meth:`.AlignMixin.compile_train_graphs_arguments`
        Job method for generating arguments for the helper function
    :kaldi_steps:`align_si`
        Reference Kaldi script
    :kaldi_steps:`align_fmllr`
        Reference Kaldi script
    """
    begin = time.time()
    log_directory = self.working_log_directory
    os.makedirs(log_directory, exist_ok=True)
    self.log_info("Compiling training graphs...")
    error_sum = 0
    arguments = self.compile_train_graphs_arguments()
    with tqdm.tqdm(
        total=self.num_current_utterances, disable=getattr(self, "quiet", False)
    ) as pbar:
        if self.use_mp:
            error_dict = {}
            return_queue = mp.Queue()
            stopped = Stopped()
            procs = []
            for i, args in enumerate(arguments):
                function = CompileTrainGraphsFunction(args)
                p = KaldiProcessWorker(i, return_queue, function, stopped)
                procs.append(p)
                p.start()
            while True:
                try:
                    result = return_queue.get(timeout=1)
                    if isinstance(result, Exception):
                        error_dict[getattr(result, "job_name", 0)] = result
                        continue
                    if stopped.stop_check():
                        continue
                except Empty:
                    for proc in procs:
                        if not proc.finished.stop_check():
                            break
                    else:
                        break
                    continue
                done, errors = result
                pbar.update(done + errors)
                error_sum += errors
            for p in procs:
                p.join()
            if error_dict:
                for v in error_dict.values():
                    raise v
        else:
            self.log_debug("Not using multiprocessing...")
            for args in arguments:
                function = CompileTrainGraphsFunction(args)
                for done, errors in function.run():
                    pbar.update(done + errors)
                    error_sum += errors
    if error_sum:
        self.log_warning(
            f"Compilation of training graphs failed for {error_sum} utterances."
        )
    self.log_debug(f"Compiling training graphs took {time.time() - begin}")
def calc_lda_mllt(self) -> None:
    """
    Multiprocessing function that calculates LDA+MLLT transformations.

    See Also
    --------
    :func:`~montreal_forced_aligner.acoustic_modeling.lda.CalcLdaMlltFunction`
        Multiprocessing helper function for each job
    :meth:`.LdaTrainer.calc_lda_mllt_arguments`
        Job method for generating arguments for the helper function
    :kaldi_src:`est-mllt`
        Relevant Kaldi binary
    :kaldi_src:`gmm-transform-means`
        Relevant Kaldi binary
    :kaldi_src:`compose-transforms`
        Relevant Kaldi binary
    :kaldi_steps:`train_lda_mllt`
        Reference Kaldi script
    """
    self.log_info("Re-calculating LDA...")
    arguments = self.calc_lda_mllt_arguments()
    with tqdm.tqdm(
        total=self.num_current_utterances, disable=getattr(self, "quiet", False)
    ) as pbar:
        if self.use_mp:
            error_dict = {}
            return_queue = mp.Queue()
            stopped = Stopped()
            procs = []
            for i, args in enumerate(arguments):
                function = CalcLdaMlltFunction(args)
                p = KaldiProcessWorker(i, return_queue, function, stopped)
                procs.append(p)
                p.start()
            while True:
                try:
                    result = return_queue.get(timeout=1)
                    if isinstance(result, Exception):
                        error_dict[getattr(result, "job_name", 0)] = result
                        continue
                    if stopped.stop_check():
                        continue
                except Empty:
                    for proc in procs:
                        if not proc.finished.stop_check():
                            break
                    else:
                        break
                    continue
                pbar.update(1)
            for p in procs:
                p.join()
            if error_dict:
                for v in error_dict.values():
                    raise v
        else:
            for args in arguments:
                function = CalcLdaMlltFunction(args)
                for _ in function.run():
                    pbar.update(1)
    log_path = os.path.join(
        self.working_log_directory, f"transform_means.{self.iteration}.log"
    )
    previous_mat_path = os.path.join(self.working_directory, "lda.mat")
    new_mat_path = os.path.join(self.working_directory, "lda_new.mat")
    composed_path = os.path.join(self.working_directory, "lda_composed.mat")
    with open(log_path, "a", encoding="utf8") as log_file:
        macc_list = []
        for x in arguments:
            macc_list.extend(x.macc_paths.values())
        subprocess.call(
            [thirdparty_binary("est-mllt"), new_mat_path] + macc_list,
            stderr=log_file,
            env=os.environ,
        )
        subprocess.call(
            [
                thirdparty_binary("gmm-transform-means"),
                new_mat_path,
                self.model_path,
                self.model_path,
            ],
            stderr=log_file,
            env=os.environ,
        )
        if os.path.exists(previous_mat_path):
            subprocess.call(
                [
                    thirdparty_binary("compose-transforms"),
                    new_mat_path,
                    previous_mat_path,
                    composed_path,
                ],
                stderr=log_file,
                env=os.environ,
            )
            os.remove(previous_mat_path)
            os.rename(composed_path, previous_mat_path)
        else:
            os.rename(new_mat_path, previous_mat_path)
def lda_acc_stats(self) -> None:
    """
    Multiprocessing function that accumulates LDA statistics.

    See Also
    --------
    :func:`~montreal_forced_aligner.acoustic_modeling.lda.LdaAccStatsFunction`
        Multiprocessing helper function for each job
    :meth:`.LdaTrainer.lda_acc_stats_arguments`
        Job method for generating arguments for the helper function
    :kaldi_src:`est-lda`
        Relevant Kaldi binary
    :kaldi_steps:`train_lda_mllt`
        Reference Kaldi script
    """
    worker_lda_path = os.path.join(self.worker.working_directory, "lda.mat")
    lda_path = os.path.join(self.working_directory, "lda.mat")
    if os.path.exists(worker_lda_path):
        os.remove(worker_lda_path)
    arguments = self.lda_acc_stats_arguments()
    with tqdm.tqdm(
        total=self.num_current_utterances, disable=getattr(self, "quiet", False)
    ) as pbar:
        if self.use_mp:
            error_dict = {}
            return_queue = mp.Queue()
            stopped = Stopped()
            procs = []
            for i, args in enumerate(arguments):
                function = LdaAccStatsFunction(args)
                p = KaldiProcessWorker(i, return_queue, function, stopped)
                procs.append(p)
                p.start()
            while True:
                try:
                    result = return_queue.get(timeout=1)
                    if isinstance(result, Exception):
                        error_dict[getattr(result, "job_name", 0)] = result
                        continue
                    if stopped.stop_check():
                        continue
                except Empty:
                    for proc in procs:
                        if not proc.finished.stop_check():
                            break
                    else:
                        break
                    continue
                done, errors = result
                pbar.update(done + errors)
            for p in procs:
                p.join()
            if error_dict:
                for v in error_dict.values():
                    raise v
        else:
            for args in arguments:
                function = LdaAccStatsFunction(args)
                for done, errors in function.run():
                    pbar.update(done + errors)
    log_path = os.path.join(self.working_log_directory, "lda_est.log")
    acc_list = []
    for x in arguments:
        acc_list.extend(x.acc_paths.values())
    with open(log_path, "w", encoding="utf8") as log_file:
        est_lda_proc = subprocess.Popen(
            [
                thirdparty_binary("est-lda"),
                f"--dim={self.lda_dimension}",
                lda_path,
            ]
            + acc_list,
            stderr=log_file,
            env=os.environ,
        )
        est_lda_proc.communicate()
    shutil.copyfile(
        lda_path,
        worker_lda_path,
    )
def train_g2p_lexicon(self) -> None:
    """Generate a G2P lexicon based on aligned transcripts."""
    arguments = self.worker.generate_pronunciations_arguments()
    working_dir = super(PronunciationProbabilityTrainer, self).working_directory
    texts = {}
    with self.worker.session() as session:
        query = session.query(Utterance.id, Utterance.normalized_character_text)
        query = query.filter(Utterance.ignored == False)  # noqa
        initial_brackets = "".join(x[0] for x in self.worker.brackets)
        query = query.filter(
            ~Utterance.oovs.regexp_match(f"(^| )[^{initial_brackets}]")
        )
        if self.subset:
            query = query.filter_by(in_subset=True)
        for utt_id, text in query:
            texts[utt_id] = text
        input_files = {
            x: open(
                os.path.join(
                    working_dir, f"input_{self.worker.dictionary_base_names[x]}.txt"
                ),
                "w",
                encoding="utf8",
            )
            for x in self.worker.dictionary_lookup.values()
        }
        output_files = {
            x: open(
                os.path.join(
                    working_dir, f"output_{self.worker.dictionary_base_names[x]}.txt"
                ),
                "w",
                encoding="utf8",
            )
            for x in self.worker.dictionary_lookup.values()
        }
        with tqdm.tqdm(
            total=self.num_current_utterances, disable=getattr(self, "quiet", False)
        ) as pbar:
            if self.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(arguments):
                    args.for_g2p = True
                    function = GeneratePronunciationsFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if isinstance(result, Exception):
                            error_dict[getattr(result, "job_name", 0)] = result
                            continue
                        if stopped.stop_check():
                            continue
                    except Empty:
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    dict_id, utt_id, phones = result
                    utt_id = int(utt_id.split("-")[-1])
                    pbar.update(1)
                    if utt_id not in texts or not texts[utt_id]:
                        continue
                    print(phones, file=output_files[dict_id])
                    print(f"<s> {texts[utt_id]} </s>", file=input_files[dict_id])
                for p in procs:
                    p.join()
                if error_dict:
                    for v in error_dict.values():
                        raise v
            else:
                self.log_debug("Not using multiprocessing...")
                for args in arguments:
                    function = GeneratePronunciationsFunction(args)
                    for dict_id, utt_id, phones in function.run():
                        print(phones, file=output_files[dict_id])
                        print(f"<s> {texts[utt_id]} </s>", file=input_files[dict_id])
                        pbar.update(1)
        for f in input_files.values():
            f.close()
        for f in output_files.values():
            f.close()
        self.pronunciations_complete = True
        os.makedirs(self.working_log_directory, exist_ok=True)
        dictionaries = session.query(Dictionary)
        shutil.copyfile(
            self.phone_symbol_table_path,
            os.path.join(self.working_directory, "phones.txt"),
        )
        shutil.copyfile(
            self.grapheme_symbol_table_path,
            os.path.join(self.working_directory, "graphemes.txt"),
        )
        self.input_token_type = self.grapheme_symbol_table_path
        self.output_token_type = self.phone_symbol_table_path
        for d in dictionaries:
            self.log_info(f"Training G2P for {d.name}...")
            self._data_source = self.worker.dictionary_base_names[d.id]
            begin = time.time()
            if os.path.exists(self.far_path) and os.path.exists(self.encoder_path):
                self.log_info("Alignment already done, skipping!")
            else:
                self.align_g2p()
                self.log_debug(
                    f"Aligning utterances for {d.name} took {time.time() - begin} seconds"
                )
            begin = time.time()
            self.generate_model()
            self.log_debug(
                f"Generating model for {d.name} took {time.time() - begin} seconds"
            )
            os.rename(d.lexicon_fst_path, d.lexicon_fst_path + ".backup")
            shutil.copy(self.fst_path, d.lexicon_fst_path)
            d.use_g2p = True
        session.commit()
    self.worker.use_g2p = True
def generate_pronunciations(self) -> Dict[str, List[str]]:
    """
    Generate pronunciations.

    Returns
    -------
    dict[str, list[str]]
        Mappings of keys to their generated pronunciations
    """
    fst = pynini.Fst.read(self.g2p_model.fst_path)
    if self.g2p_model.meta["architecture"] == "phonetisaurus":
        output_token_type = pynini.SymbolTable.read_text(self.g2p_model.sym_path)
        input_token_type = pynini.SymbolTable.read_text(
            self.g2p_model.grapheme_sym_path
        )
        fst.set_input_symbols(input_token_type)
        fst.set_output_symbols(output_token_type)
        rewriter = PhonetisaurusRewriter(
            fst,
            input_token_type,
            output_token_type,
            num_pronunciations=self.num_pronunciations,
            threshold=self.g2p_threshold,
            grapheme_order=self.g2p_model.meta["grapheme_order"],
        )
    else:
        output_token_type = "utf8"
        input_token_type = "utf8"
        if self.g2p_model.sym_path is not None and os.path.exists(
            self.g2p_model.sym_path
        ):
            output_token_type = pynini.SymbolTable.read_text(self.g2p_model.sym_path)
        rewriter = Rewriter(
            fst,
            input_token_type,
            output_token_type,
            num_pronunciations=self.num_pronunciations,
            threshold=self.g2p_threshold,
        )
    num_words = len(self.words_to_g2p)
    begin = time.time()
    missing_graphemes = set()
    self.log_info("Generating pronunciations...")
    to_return = {}
    skipped_words = 0
    if num_words < 30 or self.num_jobs == 1:
        with tqdm.tqdm(
            total=num_words, disable=getattr(self, "quiet", False)
        ) as pbar:
            for word in self.words_to_g2p:
                w, m = clean_up_word(word, self.g2p_model.meta["graphemes"])
                pbar.update(1)
                missing_graphemes = missing_graphemes | m
                if self.strict_graphemes and m:
                    skipped_words += 1
                    continue
                if not w:
                    skipped_words += 1
                    continue
                try:
                    prons = rewriter(w)
                except rewrite.Error:
                    continue
                to_return[word] = prons
            self.log_debug(
                f"Skipping {skipped_words} words for containing the following graphemes: "
                f"{comma_join(sorted(missing_graphemes))}"
            )
    else:
        stopped = Stopped()
        job_queue = mp.Queue()
        for word in self.words_to_g2p:
            w, m = clean_up_word(word, self.g2p_model.meta["graphemes"])
            missing_graphemes = missing_graphemes | m
            if self.strict_graphemes and m:
                skipped_words += 1
                continue
            if not w:
                skipped_words += 1
                continue
            job_queue.put(w)
        self.log_debug(
            f"Skipping {skipped_words} words for containing the following graphemes: "
            f"{comma_join(sorted(missing_graphemes))}"
        )
        error_dict = {}
        return_queue = mp.Queue()
        procs = []
        for _ in range(self.num_jobs):
            p = RewriterWorker(
                job_queue,
                return_queue,
                rewriter,
                stopped,
            )
            procs.append(p)
            p.start()
        num_words -= skipped_words
        with tqdm.tqdm(
            total=num_words, disable=getattr(self, "quiet", False)
        ) as pbar:
            while True:
                try:
                    word, result = return_queue.get(timeout=1)
                    if stopped.stop_check():
                        continue
                except queue.Empty:
                    for proc in procs:
                        if not proc.finished.stop_check():
                            break
                    else:
                        break
                    continue
                pbar.update(1)
                if isinstance(result, Exception):
                    error_dict[word] = result
                    continue
                to_return[word] = result
        for p in procs:
            p.join()
        if error_dict:
            raise PyniniGenerationError(error_dict)
    self.log_debug(f"Processed {num_words} in {time.time() - begin} seconds")
    return to_return
def _alignments(self) -> None:
    """Trains the aligner and constructs the alignments FAR."""
    if not os.path.exists(self.align_path):
        self.log_info("Training aligner")
        train_opts = []
        if self.batch_size:
            train_opts.append(f"--batch_size={self.batch_size}")
        if self.delta:
            train_opts.append(f"--delta={self.delta}")
        if self.fst_default_cache_gc:
            train_opts.append(f"--fst_default_cache_gc={self.fst_default_cache_gc}")
        if self.fst_default_cache_gc_limit:
            train_opts.append(
                f"--fst_default_cache_gc_limit={self.fst_default_cache_gc_limit}"
            )
        if self.alpha:
            train_opts.append(f"--alpha={self.alpha}")
        if self.num_iterations:
            train_opts.append(f"--max_iters={self.num_iterations}")
        # Constructs the actual command vectors (plus an index for logging
        # purposes).
        random.seed(self.seed)
        starts = [
            RandomStart(
                idx,
                seed,
                self.input_far_path,
                self.output_far_path,
                self.cg_path,
                self.working_directory,
                train_opts,
            )
            for (idx, seed) in enumerate(
                random.sample(range(1, RAND_MAX), self.random_starts), 1
            )
        ]
        stopped = Stopped()
        num_commands = len(starts)
        job_queue = mp.JoinableQueue()
        fst_likelihoods = {}
        # Actually runs starts.
        self.log_info("Calculating alignments...")
        begin = time.time()
        with tqdm.tqdm(
            total=num_commands * self.num_iterations,
            disable=getattr(self, "quiet", False),
        ) as pbar:
            for start in starts:
                job_queue.put(start)
            error_dict = {}
            return_queue = mp.Queue()
            procs = []
            for i in range(self.num_jobs):
                log_path = os.path.join(
                    self.working_log_directory, f"baumwelch.{i}.log"
                )
                p = RandomStartWorker(
                    i,
                    job_queue,
                    return_queue,
                    log_path,
                    stopped,
                )
                procs.append(p)
                p.start()
            while True:
                try:
                    result = return_queue.get(timeout=1)
                    if isinstance(result, Exception):
                        error_dict[getattr(result, "job_name", 0)] = result
                        continue
                    if stopped.stop_check():
                        continue
                except queue.Empty:
                    for proc in procs:
                        if not proc.finished.stop_check():
                            break
                    else:
                        break
                    continue
                if isinstance(result, int):
                    pbar.update(result)
                else:
                    fst_likelihoods[result[0]] = result[1]
            for p in procs:
                p.join()
        if error_dict:
            raise PyniniAlignmentError(error_dict)
        (best_fst, best_likelihood) = min(
            fst_likelihoods.items(), key=operator.itemgetter(1)
        )
        self.log_info(f"Best likelihood: {best_likelihood}")
        self.log_debug(
            f"Ran {self.random_starts} random starts in {time.time() - begin} seconds"
        )
        # Moves the best-likelihood solution to the requested location.
        shutil.move(best_fst, self.align_path)
    cmd = [thirdparty_binary("baumwelchdecode")]
    if self.fst_default_cache_gc:
        cmd.append(f"--fst_default_cache_gc={self.fst_default_cache_gc}")
    if self.fst_default_cache_gc_limit:
        cmd.append(f"--fst_default_cache_gc_limit={self.fst_default_cache_gc_limit}")
    cmd.append(self.input_far_path)
    cmd.append(self.output_far_path)
    cmd.append(self.align_path)
    cmd.append(self.afst_path)
    self.log_debug(f"Subprocess call: {cmd}")
    subprocess.check_call(cmd, env=os.environ)
    self.log_info("Completed computing alignments!")
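# --- Illustration (not part of MFA) ---------------------------------------
# _alignments keeps the random start whose reported likelihood value is the
# minimum; the selection step in isolation, with toy numbers and made-up
# file names:
import operator

_fst_likelihoods = {"start_1.fst": 1243.7, "start_2.fst": 1198.2}
_best_fst, _best_likelihood = min(_fst_likelihoods.items(), key=operator.itemgetter(1))
assert _best_fst == "start_2.fst"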
def mfcc(self) -> None:
    """
    Multiprocessing function that converts sound files into MFCCs.

    See :kaldi_docs:`feat` for an overview on feature generation in Kaldi.

    See Also
    --------
    :class:`~montreal_forced_aligner.corpus.features.MfccFunction`
        Multiprocessing helper function for each job
    :meth:`.AcousticCorpusMixin.mfcc_arguments`
        Job method for generating arguments for the helper function
    :kaldi_steps:`make_mfcc`
        Reference Kaldi script
    """
    self.log_info("Generating MFCCs...")
    log_directory = os.path.join(self.split_directory, "log")
    os.makedirs(log_directory, exist_ok=True)
    arguments = self.mfcc_arguments()
    with tqdm.tqdm(
        total=self.num_utterances, disable=getattr(self, "quiet", False)
    ) as pbar:
        if self.use_mp:
            error_dict = {}
            return_queue = mp.Queue()
            stopped = Stopped()
            procs = []
            for i, args in enumerate(arguments):
                function = MfccFunction(args)
                p = KaldiProcessWorker(i, return_queue, function, stopped)
                procs.append(p)
                p.start()
            while True:
                try:
                    result = return_queue.get(timeout=1)
                    if isinstance(result, Exception):
                        error_dict[getattr(result, "job_name", 0)] = result
                        continue
                    if stopped.stop_check():
                        continue
                except Empty:
                    for proc in procs:
                        if not proc.finished.stop_check():
                            break
                    else:
                        break
                    continue
                pbar.update(result)
            for p in procs:
                p.join()
            if error_dict:
                for v in error_dict.values():
                    print(v)
                self.dirty = True
                sys.exit(1)
        else:
            for args in arguments:
                function = MfccFunction(args)
                for num_utterances in function.run():
                    pbar.update(num_utterances)
    with self.session() as session:
        update_mapping = []
        session.query(Utterance).update({"ignored": True})
        for j in arguments:
            with open(j.feats_scp_path, "r", encoding="utf8") as f:
                for line in f:
                    line = line.strip()
                    if line == "":
                        continue
                    fields = line.split(maxsplit=1)
                    utt_id = int(fields[0].split("-")[-1])
                    feats = fields[1]
                    update_mapping.append(
                        {"id": utt_id, "features": feats, "ignored": False}
                    )
        session.bulk_update_mappings(Utterance, update_mapping)
        session.commit()
def acc_stats(self) -> None:
    """
    Multiprocessing function that accumulates stats for GMM training.

    See Also
    --------
    :class:`~montreal_forced_aligner.alignment.multiprocessing.AccStatsFunction`
        Multiprocessing helper function for each job
    :meth:`.AcousticModelTrainingMixin.acc_stats_arguments`
        Job method for generating arguments for the helper function
    :kaldi_src:`gmm-sum-accs`
        Relevant Kaldi binary
    :kaldi_src:`gmm-est`
        Relevant Kaldi binary
    :kaldi_steps:`train_mono`
        Reference Kaldi script
    :kaldi_steps:`train_deltas`
        Reference Kaldi script
    """
    self.log_info("Accumulating statistics...")
    arguments = self.acc_stats_arguments()
    with tqdm.tqdm(
        total=self.num_current_utterances, disable=getattr(self, "quiet", False)
    ) as pbar:
        if self.use_mp:
            error_dict = {}
            return_queue = mp.Queue()
            stopped = Stopped()
            procs = []
            for i, args in enumerate(arguments):
                function = AccStatsFunction(args)
                p = KaldiProcessWorker(i, return_queue, function, stopped)
                procs.append(p)
                p.start()
            while True:
                try:
                    result = return_queue.get(timeout=1)
                    if isinstance(result, Exception):
                        error_dict[getattr(result, "job_name", 0)] = result
                        continue
                    if stopped.stop_check():
                        continue
                except Empty:
                    for proc in procs:
                        if not proc.finished.stop_check():
                            break
                    else:
                        break
                    continue
                num_utterances, errors = result
                pbar.update(num_utterances + errors)
            for p in procs:
                p.join()
            if error_dict:
                for v in error_dict.values():
                    raise v
        else:
            for args in arguments:
                function = AccStatsFunction(args)
                for num_utterances, errors in function.run():
                    pbar.update(num_utterances + errors)
    log_path = os.path.join(self.working_log_directory, f"update.{self.iteration}.log")
    with open(log_path, "w") as log_file:
        acc_files = []
        for a in arguments:
            acc_files.extend(a.acc_paths.values())
        sum_proc = subprocess.Popen(
            [thirdparty_binary("gmm-sum-accs"), "-"] + acc_files,
            stdout=subprocess.PIPE,
            stderr=log_file,
            env=os.environ,
        )
        est_command = [
            thirdparty_binary("gmm-est"),
            f"--write-occs={self.next_occs_path}",
            f"--mix-up={self.current_gaussians}",
        ]
        if self.power > 0:
            est_command.append(f"--power={self.power}")
        est_command.extend(
            [
                self.model_path,
                "-",
                self.next_model_path,
            ]
        )
        est_proc = subprocess.Popen(
            est_command,
            stdin=sum_proc.stdout,
            stderr=log_file,
            env=os.environ,
        )
        est_proc.communicate()
    avg_like_pattern = re.compile(
        r"Overall avg like per frame.* = (?P<like>[-.,\d]+) over (?P<frames>[.\d+e]+) frames"
    )
    average_logdet_pattern = re.compile(
        r"Overall average logdet is (?P<logdet>[-.,\d]+) over (?P<frames>[.\d+e]+) frames"
    )
    avg_like_sum = 0
    avg_like_frames = 0
    average_logdet_sum = 0
    average_logdet_frames = 0
    for a in arguments:
        with open(a.log_path, "r", encoding="utf8") as f:
            for line in f:
                m = avg_like_pattern.search(line)
                if m:
                    like = float(m.group("like"))
                    frames = float(m.group("frames"))
                    avg_like_sum += like * frames
                    avg_like_frames += frames
                m = average_logdet_pattern.search(line)
                if m:
                    logdet = float(m.group("logdet"))
                    frames = float(m.group("frames"))
                    average_logdet_sum += logdet * frames
                    average_logdet_frames += frames
    if avg_like_frames:
        log_like = avg_like_sum / avg_like_frames
        if average_logdet_frames:
            log_like += average_logdet_sum / average_logdet_frames
        self.log_debug(f"Likelihood for iteration {self.iteration}: {log_like}")
    if not self.debug:
        for f in acc_files:
            os.remove(f)
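# --- Illustration (not part of MFA) ---------------------------------------
# The frame-weighted likelihood averaging above can be exercised on its own.
# The regex is copied from acc_stats; the sample line below imitates the
# Kaldi log format it matches and is not real output.
import re

_avg_like_pattern = re.compile(
    r"Overall avg like per frame.* = (?P<like>[-.,\d]+) over (?P<frames>[.\d+e]+) frames"
)


def weighted_avg_like(log_lines) -> float:
    # Weight each reported average by its frame count, then normalize.
    total = 0.0
    frames_total = 0.0
    for line in log_lines:
        m = _avg_like_pattern.search(line)
        if m:
            like = float(m.group("like"))
            frames = float(m.group("frames"))
            total += like * frames
            frames_total += frames
    return total / frames_total if frames_total else 0.0


assert weighted_avg_like(
    ["Overall avg like per frame (Gaussian only) = -52.5 over 1 frames"]
) == -52.5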
def acc_global_stats(self) -> None:
    """
    Multiprocessing function that accumulates global GMM stats.

    See Also
    --------
    :func:`~montreal_forced_aligner.ivector.trainer.AccGlobalStatsFunction`
        Multiprocessing helper function for each job
    :meth:`.DubmTrainer.acc_global_stats_arguments`
        Job method for generating arguments for the helper function
    :kaldi_src:`gmm-global-sum-accs`
        Relevant Kaldi binary
    :kaldi_steps:`train_diag_ubm`
        Reference Kaldi script
    """
    begin = time.time()
    self.log_info("Accumulating global stats...")
    arguments = self.acc_global_stats_arguments()
    if self.use_mp:
        error_dict = {}
        return_queue = mp.Queue()
        stopped = Stopped()
        procs = []
        for i, args in enumerate(arguments):
            function = AccGlobalStatsFunction(args)
            p = KaldiProcessWorker(i, return_queue, function, stopped)
            procs.append(p)
            p.start()
        while True:
            try:
                result = return_queue.get(timeout=1)
                if isinstance(result, Exception):
                    error_dict[getattr(result, "job_name", 0)] = result
                    continue
                if stopped.stop_check():
                    continue
            except queue.Empty:
                for proc in procs:
                    if not proc.finished.stop_check():
                        break
                else:
                    break
                continue
        for p in procs:
            p.join()
        if error_dict:
            for v in error_dict.values():
                raise v
    else:
        self.log_debug("Not using multiprocessing...")
        for args in arguments:
            function = AccGlobalStatsFunction(args)
            for _ in function.run():
                pass
    self.log_debug(f"Accumulating stats took {time.time() - begin}")
    # Don't remove low-count Gaussians until the last iteration,
    # or the gselect info won't be valid anymore.
    if self.iteration < self.num_iterations:
        opt = "--remove-low-count-gaussians=false"
    else:
        opt = f"--remove-low-count-gaussians={self.remove_low_count_gaussians}"
    log_path = os.path.join(self.working_log_directory, f"update.{self.iteration}.log")
    with open(log_path, "w") as log_file:
        acc_files = []
        for j in arguments:
            acc_files.append(j.acc_path)
        sum_proc = subprocess.Popen(
            [thirdparty_binary("gmm-global-sum-accs"), "-"] + acc_files,
            stderr=log_file,
            stdout=subprocess.PIPE,
            env=os.environ,
        )
        gmm_global_est_proc = subprocess.Popen(
            [
                thirdparty_binary("gmm-global-est"),
                opt,
                f"--min-gaussian-weight={self.min_gaussian_weight}",
                self.model_path,
                "-",
                self.next_model_path,
            ],
            stderr=log_file,
            stdin=sum_proc.stdout,
            env=os.environ,
        )
        gmm_global_est_proc.communicate()
        # Clean up the per-job accumulators.
        if not self.debug:
            for p in acc_files:
                os.remove(p)
def _load_corpus_from_source_mp(self) -> None:
    """
    Load a corpus using multiprocessing.
    """
    begin_time = time.process_time()
    job_queue = mp.Queue()
    return_queue = mp.Queue()
    finished_adding = Stopped()
    stopped = Stopped()
    file_counts = Counter()
    sanitize_function = getattr(self, "sanitize_function", None)
    error_dict = {}
    procs = []
    self.db_engine.dispose()
    parser = AcousticDirectoryParser(
        self.corpus_directory,
        job_queue,
        self.audio_directory,
        stopped,
        finished_adding,
        file_counts,
    )
    parser.start()
    for i in range(self.num_jobs):
        p = CorpusProcessWorker(
            i,
            job_queue,
            return_queue,
            stopped,
            finished_adding,
            self.speaker_characters,
            sanitize_function,
            self.sample_frequency,
        )
        procs.append(p)
        p.start()
    last_poll = time.time() - 30
    import_data = DatabaseImportData()
    try:
        with self.session() as session:
            with tqdm.tqdm(total=100, disable=getattr(self, "quiet", False)) as pbar:
                while True:
                    try:
                        file = return_queue.get(timeout=1)
                        if isinstance(file, tuple):
                            error_type = file[0]
                            error = file[1]
                            if error_type == "error":
                                error_dict[error_type] = error
                            else:
                                if error_type not in error_dict:
                                    error_dict[error_type] = []
                                error_dict[error_type].append(error)
                            continue
                        if self.stopped.stop_check():
                            continue
                    except Empty:
                        for proc in procs:
                            if not proc.finished_processing.stop_check():
                                break
                        else:
                            break
                        continue
                    if time.time() - last_poll > 15:
                        pbar.total = file_counts.value()
                        last_poll = time.time()
                    pbar.update(1)
                    import_data.add_objects(self.generate_import_objects(file))
            self.log_debug(f"Processing queue: {time.process_time() - begin_time}")
            if "error" in error_dict:
                session.rollback()
                raise error_dict["error"]
            self._finalize_load(session, import_data)
        for k in [
            "sound_file_errors",
            "decode_error_files",
            "textgrid_read_errors",
        ]:
            if hasattr(self, k):
                if k in error_dict:
                    self.log_info(
                        "There were some issues with files in the corpus. "
                        "Please look at the log file or run the validator for more information."
                    )
                    self.log_debug(f"{k} showed {len(error_dict[k])} errors:")
                    if k in {"textgrid_read_errors", "sound_file_errors"}:
                        getattr(self, k).update(error_dict[k])
                        for e in error_dict[k]:
                            self.log_debug(f"{e.file_name}: {e.error}")
                    else:
                        self.log_debug(", ".join(error_dict[k]))
                        setattr(self, k, error_dict[k])
    except Exception as e:
        if isinstance(e, KeyboardInterrupt):
            self.log_info(
                "Detected ctrl-c, please wait a moment while we clean everything up..."
            )
            self.stopped.set_sigint_source()
        self.stopped.stop()
        finished_adding.stop()
        while True:
            try:
                _ = job_queue.get(timeout=1)
                if self.stopped.stop_check():
                    continue
            except Empty:
                for proc in procs:
                    if not proc.finished_processing.stop_check():
                        break
                else:
                    break
            try:
                _ = return_queue.get(timeout=1)
                _ = job_queue.get(timeout=1)
                if self.stopped.stop_check():
                    continue
            except Empty:
                for proc in procs:
                    if not proc.finished_processing.stop_check():
                        break
                else:
                    break
    finally:
        parser.join()
        for p in procs:
            p.join()
    if self.stopped.stop_check():
        self.log_info(
            f"Stopped parsing early ({time.process_time() - begin_time} seconds)"
        )
        if self.stopped.source():
            sys.exit(0)
    else:
        self.log_debug(
            f"Parsed corpus directory with {self.num_jobs} jobs in {time.process_time() - begin_time} seconds"
        )
def acc_ivector_stats(self) -> None:
    """
    Multiprocessing function that accumulates ivector extraction stats.

    See Also
    --------
    :func:`~montreal_forced_aligner.ivector.trainer.AccIvectorStatsFunction`
        Multiprocessing helper function for each job
    :meth:`.IvectorTrainer.acc_ivector_stats_arguments`
        Job method for generating arguments for the helper function
    :kaldi_src:`ivector-extractor-sum-accs`
        Relevant Kaldi binary
    :kaldi_src:`ivector-extractor-est`
        Relevant Kaldi binary
    :kaldi_steps_sid:`train_ivector_extractor`
        Reference Kaldi script
    """
    begin = time.time()
    self.log_info("Accumulating ivector stats...")
    arguments = self.acc_ivector_stats_arguments()
    if self.use_mp:
        error_dict = {}
        return_queue = mp.Queue()
        stopped = Stopped()
        procs = []
        for i, args in enumerate(arguments):
            function = AccIvectorStatsFunction(args)
            p = KaldiProcessWorker(i, return_queue, function, stopped)
            procs.append(p)
            p.start()
        while True:
            try:
                result = return_queue.get(timeout=1)
                if stopped.stop_check():
                    continue
            except queue.Empty:
                for proc in procs:
                    if not proc.finished.stop_check():
                        break
                else:
                    break
                continue
            if isinstance(result, KaldiProcessingError):
                error_dict[result.job_name] = result
                continue
        for p in procs:
            p.join()
        if error_dict:
            for v in error_dict.values():
                raise v
    else:
        self.log_debug("Not using multiprocessing...")
        for args in arguments:
            function = AccIvectorStatsFunction(args)
            for _ in function.run():
                pass
    self.log_debug(f"Accumulating stats took {time.time() - begin}")
    log_path = os.path.join(
        self.working_log_directory, f"sum_acc.{self.iteration}.log"
    )
    acc_path = os.path.join(self.working_directory, f"acc.{self.iteration}")
    with open(log_path, "w", encoding="utf8") as log_file:
        accinits = []
        for j in arguments:
            accinits.append(j.acc_init_path)
        sum_accs_proc = subprocess.Popen(
            [thirdparty_binary("ivector-extractor-sum-accs"), "--parallel=true"]
            + accinits
            + [acc_path],
            stderr=log_file,
            env=os.environ,
        )
        sum_accs_proc.communicate()
    # Clean up the per-job accumulators.
    for p in accinits:
        os.remove(p)
    # Estimate the extractor.
    log_path = os.path.join(self.working_log_directory, f"update.{self.iteration}.log")
    with open(log_path, "w") as log_file:
        extractor_est_proc = subprocess.Popen(
            [
                thirdparty_binary("ivector-extractor-est"),
                f"--num-threads={len(self.jobs)}",
                f"--gaussian-min-count={self.gaussian_min_count}",
                self.ie_path,
                os.path.join(self.working_directory, f"acc.{self.iteration}"),
                self.next_ie_path,
            ],
            stderr=log_file,
            env=os.environ,
        )
        extractor_est_proc.communicate()
def create_align_model(self) -> None:
    """
    Create an alignment model for speaker-adapted training that will use
    speaker-independent features in later aligning.

    See Also
    --------
    :func:`~montreal_forced_aligner.acoustic_modeling.sat.AccStatsTwoFeatsFunction`
        Multiprocessing helper function for each job
    :meth:`.SatTrainer.acc_stats_two_feats_arguments`
        Job method for generating arguments for the helper function
    :kaldi_src:`gmm-est`
        Relevant Kaldi binary
    :kaldi_src:`gmm-sum-accs`
        Relevant Kaldi binary
    :kaldi_steps:`train_sat`
        Reference Kaldi script
    """
    self.log_info("Creating alignment model for speaker-independent features...")
    begin = time.time()
    arguments = self.acc_stats_two_feats_arguments()
    with tqdm.tqdm(
        total=self.num_current_utterances, disable=getattr(self, "quiet", False)
    ) as pbar:
        if self.use_mp:
            error_dict = {}
            return_queue = mp.Queue()
            stopped = Stopped()
            procs = []
            for i, args in enumerate(arguments):
                function = AccStatsTwoFeatsFunction(args)
                p = KaldiProcessWorker(i, return_queue, function, stopped)
                procs.append(p)
                p.start()
            while True:
                try:
                    result = return_queue.get(timeout=1)
                    if isinstance(result, Exception):
                        error_dict[getattr(result, "job_name", 0)] = result
                        continue
                    if stopped.stop_check():
                        continue
                except Empty:
                    for proc in procs:
                        if not proc.finished.stop_check():
                            break
                    else:
                        break
                    continue
                pbar.update(1)
            for p in procs:
                p.join()
            if error_dict:
                for v in error_dict.values():
                    raise v
        else:
            for args in arguments:
                function = AccStatsTwoFeatsFunction(args)
                for _ in function.run():
                    pbar.update(1)
    log_path = os.path.join(self.working_log_directory, "align_model_est.log")
    with open(log_path, "w", encoding="utf8") as log_file:
        acc_files = []
        for x in arguments:
            acc_files.extend(x.acc_paths.values())
        sum_proc = subprocess.Popen(
            [thirdparty_binary("gmm-sum-accs"), "-"] + acc_files,
            stderr=log_file,
            stdout=subprocess.PIPE,
            env=os.environ,
        )
        est_command = [
            thirdparty_binary("gmm-est"),
            "--remove-low-count-gaussians=false",
        ]
        if not self.quick:
            est_command.append(f"--power={self.power}")
        else:
            est_command.append(
                f"--write-occs={os.path.join(self.working_directory, 'final.occs')}"
            )
        est_command.extend(
            [
                self.model_path,
                "-",
                self.model_path.replace(".mdl", ".alimdl"),
            ]
        )
        est_proc = subprocess.Popen(
            est_command,
            stdin=sum_proc.stdout,
            stderr=log_file,
            env=os.environ,
        )
        est_proc.communicate()
    parse_logs(self.working_log_directory)
    if not self.debug:
        for f in acc_files:
            os.remove(f)
    self.log_debug(f"Alignment model creation took {time.time() - begin}")