def restore_model(base_dir: str, train_name: str, checkpoint_dir: str) -> None:
    """Copy the newest checkpoint found in *checkpoint_dir* into the training
    directory for *train_name* (the train dir is created if missing)."""
    train_dir = get_train_dir(base_dir, train_name, create=True)
    logs_dir = get_train_logs_dir(train_dir)
    logger = prepare_logger(get_train_log_file(logs_dir), reset=True)

    dest_dir = get_checkpoints_dir(train_dir)
    src_checkpoint, iteration = get_last_checkpoint(checkpoint_dir)

    logger.info(f"Restoring checkpoint {iteration} from {checkpoint_dir}...")
    # copy2 preserves file metadata (timestamps) alongside the contents
    shutil.copy2(src_checkpoint, dest_dir)
    logger.info("Restoring done.")
def try_load_checkpoint(base_dir: str, train_name: Optional[str], checkpoint: Optional[int], logger: Logger) -> Optional[CheckpointTacotron]:
    """Load a checkpoint for *train_name* if one was requested.

    Returns None when *train_name* is empty/None; otherwise loads the given
    checkpoint iteration (or the last one when *checkpoint* is None).
    """
    if not train_name:
        return None

    train_dir = get_train_dir(base_dir, train_name, False)
    checkpoint_path, _ = get_custom_or_last_checkpoint(
        get_checkpoints_dir(train_dir), checkpoint)
    result = CheckpointTacotron.load(checkpoint_path, logger)
    logger.info(f"Using warm start model: {checkpoint_path}")
    return result
def plot_embeddings(base_dir: str, train_name: str, custom_checkpoint: Optional[int] = None):
    """Plot the symbol-embedding weights of a checkpoint and save a similarity
    CSV plus 2D/3D projection figures into the run's analysis directory."""
    train_dir = get_train_dir(base_dir, train_name, create=False)
    assert os.path.isdir(train_dir)

    analysis_dir = get_analysis_root_dir(train_dir)
    logger = prepare_logger()

    checkpoints_dir = get_checkpoints_dir(train_dir)
    checkpoint_path, checkpoint_it = get_custom_or_last_checkpoint(
        checkpoints_dir, custom_checkpoint)
    checkpoint = CheckpointTacotron.load(checkpoint_path, logger)

    # pylint: disable=no-member
    text, fig_2d, fig_3d = plot_embeddings_core(
        symbols=checkpoint.get_symbols(),
        emb=checkpoint.get_symbol_embedding_weights(),
        logger=logger,
    )

    _save_similarities_csv(analysis_dir, checkpoint_it, text)
    _save_2d_plot(analysis_dir, checkpoint_it, fig_2d)
    _save_3d_plot(analysis_dir, checkpoint_it, fig_3d)
    logger.info(f"Saved analysis to: {analysis_dir}")
def continue_train(base_dir: str, train_name: str, custom_hparams: Optional[Dict[str, str]] = None) -> None:
    """Resume an existing training run from its most recent checkpoint,
    reusing the preprocessing settings stored in the train directory."""
    train_dir = get_train_dir(base_dir, train_name, create=False)
    assert os.path.isdir(train_dir)

    logs_dir = get_train_logs_dir(train_dir)
    taco_logger = Tacotron2Logger(logs_dir)
    logger = prepare_logger(get_train_log_file(logs_dir))
    checkpoint_logger = prepare_logger(
        log_file_path=get_train_checkpoints_log_file(logs_dir),
        logger=logging.getLogger("checkpoint-logger"))

    ckpt_dir = get_checkpoints_dir(train_dir)
    last_ckpt_path, _ = get_last_checkpoint(ckpt_dir)
    last_checkpoint = CheckpointTacotron.load(last_ckpt_path, logger)

    # new checkpoints continue to be written into the same directory
    save_callback = partial(
        save_checkpoint,
        save_checkpoint_dir=ckpt_dir,
        logger=logger,
    )

    ttsp_dir, merge_name, prep_name = load_prep_settings(train_dir)
    merge_dir = get_merged_dir(ttsp_dir, merge_name, create=False)
    prep_dir = get_prep_dir(merge_dir, prep_name, create=False)

    continue_train_core(
        checkpoint=last_checkpoint,
        custom_hparams=custom_hparams,
        taco_logger=taco_logger,
        trainset=load_trainset(prep_dir),
        valset=load_valset(prep_dir),
        logger=logger,
        checkpoint_logger=checkpoint_logger,
        save_callback=save_callback,
    )
def train(base_dir: str, ttsp_dir: str, train_name: str, merge_name: str, prep_name: str, warm_start_train_name: Optional[str] = None, warm_start_checkpoint: Optional[int] = None, custom_hparams: Optional[Dict[str, str]] = None, weights_train_name: Optional[str] = None, weights_checkpoint: Optional[int] = None, use_weights_map: Optional[bool] = None, map_from_speaker: Optional[str] = None) -> None:
    """Start a new Tacotron training run.

    Optionally warm-starts from another run's checkpoint
    (*warm_start_train_name*/*warm_start_checkpoint*) and/or transfers symbol
    embedding weights from another run (*weights_train_name* et al.), with an
    optional symbol weights map when *use_weights_map* is truthy.
    """
    merge_dir = get_merged_dir(ttsp_dir, merge_name, create=False)
    prep_dir = get_prep_dir(merge_dir, prep_name, create=False)

    train_dir = get_train_dir(base_dir, train_name, create=True)
    logs_dir = get_train_logs_dir(train_dir)
    taco_logger = Tacotron2Logger(logs_dir)
    logger = prepare_logger(get_train_log_file(logs_dir), reset=True)
    checkpoint_logger = prepare_logger(
        log_file_path=get_train_checkpoints_log_file(logs_dir),
        logger=logging.getLogger("checkpoint-logger"),
        reset=True)

    # persist the prep settings so continue_train/infer/validate can find them
    save_prep_settings(train_dir, ttsp_dir, merge_name, prep_name)

    trainset = load_trainset(prep_dir)
    valset = load_valset(prep_dir)

    weights_model = try_load_checkpoint(
        base_dir=base_dir,
        train_name=weights_train_name,
        checkpoint=weights_checkpoint,
        logger=logger)

    weights_map = None
    # FIX: was `use_weights_map is not None and use_weights_map` — `None` is
    # already falsy, so plain truthiness is equivalent and clearer.
    if use_weights_map:
        # NOTE(review): this branch assumes weights_train_name is set whenever
        # use_weights_map is — confirm callers enforce that.
        weights_train_dir = get_train_dir(base_dir, weights_train_name, False)
        _, weights_merge_name, _ = load_prep_settings(weights_train_dir)
        weights_map = load_weights_map(merge_dir, weights_merge_name)

    warm_model = try_load_checkpoint(
        base_dir=base_dir,
        train_name=warm_start_train_name,
        checkpoint=warm_start_checkpoint,
        logger=logger)

    save_callback = partial(
        save_checkpoint,
        save_checkpoint_dir=get_checkpoints_dir(train_dir),
        logger=logger,
    )

    train_core(
        custom_hparams=custom_hparams,
        taco_logger=taco_logger,
        symbols=load_merged_symbol_converter(merge_dir),
        speakers=load_merged_speakers_json(merge_dir),
        accents=load_merged_accents_ids(merge_dir),
        trainset=trainset,
        valset=valset,
        save_callback=save_callback,
        weights_map=weights_map,
        weights_checkpoint=weights_model,
        warm_model=warm_model,
        map_from_speaker_name=map_from_speaker,
        logger=logger,
        checkpoint_logger=checkpoint_logger,
    )
def infer(base_dir: str, train_name: str, text_name: str, speaker: str, sentence_ids: Optional[Set[int]] = None, custom_checkpoint: Optional[int] = None, full_run: bool = True, custom_hparams: Optional[Dict[str, str]] = None, max_decoder_steps: int = DEFAULT_MAX_DECODER_STEPS, seed: int = DEFAULT_SEED, copy_mel_info_to: Optional[str] = DEFAULT_SAVE_MEL_INFO_COPY_PATH):
    """Synthesize mels for the sentences of *text_name* with a trained model,
    write plots/stats into a per-run infer directory, and record the paths of
    the produced mel .npy files (optionally copying that list elsewhere)."""
    train_dir = get_train_dir(base_dir, train_name, create=False)
    assert os.path.isdir(train_dir)

    logger = get_default_logger()
    init_logger(logger)
    add_console_out_to_logger(logger)
    logger.info("Inferring...")

    checkpoint_path, iteration = get_custom_or_last_checkpoint(
        get_checkpoints_dir(train_dir), custom_checkpoint)
    taco_checkpoint = CheckpointTacotron.load(checkpoint_path, logger)

    ttsp_dir, merge_name, _ = load_prep_settings(train_dir)
    # merge_dir = get_merged_dir(ttsp_dir, merge_name, create=False)
    infer_sents = get_infer_sentences(ttsp_dir, merge_name, text_name)

    run_name = get_run_name(
        input_name=text_name,
        full_run=full_run,
        iteration=iteration,
        speaker_name=speaker,
    )
    infer_dir = get_infer_dir(
        train_dir=train_dir,
        run_name=run_name,
    )
    add_file_out_to_logger(logger, get_infer_log_new(infer_dir))

    # save_results appends each written mel .npy path into this list
    mel_postnet_npy_paths: List[Dict[str, Any]] = []
    save_callback = partial(
        save_results,
        infer_dir=infer_dir,
        mel_postnet_npy_paths=mel_postnet_npy_paths)

    inference_results = infer_core(
        checkpoint=taco_checkpoint,
        sentences=infer_sents,
        custom_hparams=custom_hparams,
        full_run=full_run,
        save_callback=save_callback,
        sentence_ids=sentence_ids,
        speaker_name=speaker,
        train_name=train_name,
        max_decoder_steps=max_decoder_steps,
        seed=seed,
        logger=logger,
    )

    # produce the summary plots, logging each file name before writing it
    for plot_name, save_plot in (
        ("mel_postnet_v.png", save_mel_postnet_v_plot),
        ("mel_postnet_h.png", save_mel_postnet_h_plot),
        ("mel_v.png", save_mel_v_plot),
        ("alignments_v.png", save_alignments_v_plot),
    ):
        logger.info(f"Creating {plot_name}")
        save_plot(infer_dir, inference_results)

    logger.info("Creating total.csv")
    save_stats(infer_dir, inference_results)

    npy_path = save_mel_postnet_npy_paths(
        infer_dir=infer_dir,
        name=run_name,
        mel_postnet_npy_paths=mel_postnet_npy_paths
    )
    logger.info("Wrote all inferred mel paths including sampling rate into these file(s):")
    logger.info(npy_path)

    if copy_mel_info_to is not None:
        create_parent_folder(copy_mel_info_to)
        copyfile(npy_path, copy_mel_info_to)
        logger.info(copy_mel_info_to)

    logger.info(f"Saved output to: {infer_dir}")
def validate(base_dir: str, train_name: str, entry_ids: Optional[Set[int]] = None, speaker: Optional[str] = None, ds: str = "val", custom_checkpoints: Optional[Set[int]] = None, custom_hparams: Optional[Dict[str, str]] = None, full_run: bool = False, max_decoder_steps: int = DEFAULT_MAX_DECODER_STEPS, mcd_no_of_coeffs_per_frame: int = DEFAULT_MCD_NO_OF_COEFFS_PER_FRAME, copy_mel_info_to: Optional[str] = DEFAULT_SAVE_MEL_INFO_COPY_PATH, fast: bool = False, repetitions: int = DEFAULT_REPETITIONS, select_best_from: Optional[str] = None, seed: Optional[int] = DEFAULT_SEED) -> None:
    """Run validation over one or more checkpoints of a training run.

    Param: custom checkpoints: empty => all; None => random; ids
    """
    assert repetitions > 0

    train_dir = get_train_dir(base_dir, train_name, create=False)
    assert os.path.isdir(train_dir)

    ttsp_dir, merge_name, prep_name = load_prep_settings(train_dir)
    merge_dir = get_merged_dir(ttsp_dir, merge_name, create=False)
    prep_dir = get_prep_dir(merge_dir, prep_name, create=False)

    # pick the dataset split to validate on
    loaders = {"val": load_valset, "test": load_testset}
    assert ds in loaders
    data = loaders[ds](prep_dir)

    checkpoint_dir = get_checkpoints_dir(train_dir)
    iterations: Set[int] = set()
    if custom_checkpoints is None:
        # default: just the newest checkpoint
        _, last_it = get_last_checkpoint(checkpoint_dir)
        iterations = {last_it}
    elif len(custom_checkpoints) == 0:
        # empty set => validate every checkpoint on disk
        iterations = set(get_all_checkpoint_iterations(checkpoint_dir))
    else:
        iterations = custom_checkpoints

    run_name = get_run_name(
        ds=ds,
        entry_ids=entry_ids,
        full_run=full_run,
        iterations=iterations,
        speaker=speaker,
    )
    val_dir = get_val_dir(
        train_dir=train_dir,
        run_name=run_name,
    )

    val_log_path = os.path.join(val_dir, "log.txt")
    logger = prepare_logger(val_log_path)
    logger.info("Validating...")
    logger.info(f"Checkpoints: {','.join(str(x) for x in sorted(iterations))}")

    result = ValidationEntries()
    save_callback = None
    trainset = load_trainset(prep_dir)

    select_best_from_df = None
    if select_best_from is not None:
        select_best_from_df = pd.read_csv(select_best_from, sep="\t")

    for iteration in tqdm(sorted(iterations)):
        # NOTE(review): rebound per iteration but read after the loop — only
        # the last checkpoint's paths end up in the npy list; confirm intended.
        mel_postnet_npy_paths: List[str] = []
        logger.info(f"Current checkpoint: {iteration}")
        checkpoint_path = get_checkpoint(checkpoint_dir, iteration)
        taco_checkpoint = CheckpointTacotron.load(checkpoint_path, logger)

        if not fast:
            save_callback = partial(
                save_results,
                val_dir=val_dir,
                iteration=iteration,
                mel_postnet_npy_paths=mel_postnet_npy_paths)

        validation_entries = validate_core(
            checkpoint=taco_checkpoint,
            data=data,
            trainset=trainset,
            custom_hparams=custom_hparams,
            entry_ids=entry_ids,
            full_run=full_run,
            speaker_name=speaker,
            train_name=train_name,
            logger=logger,
            max_decoder_steps=max_decoder_steps,
            fast=fast,
            save_callback=save_callback,
            mcd_no_of_coeffs_per_frame=mcd_no_of_coeffs_per_frame,
            repetitions=repetitions,
            seed=seed,
            select_best_from=select_best_from_df,
        )
        result.extend(validation_entries)

    if len(result) == 0:
        return

    save_stats(val_dir, result)

    if not fast:
        logger.info("Wrote all inferred mel paths including sampling rate into these file(s):")
        npy_path = save_mel_postnet_npy_paths(
            val_dir=val_dir,
            name=run_name,
            mel_postnet_npy_paths=mel_postnet_npy_paths
        )
        logger.info(npy_path)

        if copy_mel_info_to is not None:
            create_parent_folder(copy_mel_info_to)
            copyfile(npy_path, copy_mel_info_to)
            logger.info(copy_mel_info_to)

    logger.info(f"Saved output to: {val_dir}")