def infer_main(base_dir: str, train_name: str, text_name: str, speaker: str, waveglow: str = DEFAULT_WAVEGLOW, custom_checkpoint: Optional[int] = None, sentence_pause_s: float = DEFAULT_SENTENCE_PAUSE_S, sigma: float = DEFAULT_SIGMA, denoiser_strength: float = DEFAULT_DENOISER_STRENGTH, analysis: bool = True, custom_tacotron_hparams: Optional[Dict[str, str]] = None, custom_waveglow_hparams: Optional[Dict[str, str]] = None):
    """Synthesize the sentences of a prepared text with a trained Tacotron model,
    vocode them with a WaveGlow model, and save the resulting wav plus optional
    per-sentence analysis plots into the inference directory.
    """
    train_dir = get_train_dir(base_dir, train_name, create=False)
    assert os.path.isdir(train_dir)

    # Console logging first; the file handler is attached once the
    # inference directory is known.
    logger = get_default_logger()
    init_logger(logger)
    add_console_out_to_logger(logger)
    logger.info("Inferring...")

    taco_ckpt_path, iteration = get_custom_or_last_checkpoint(
        get_checkpoints_dir(train_dir), custom_checkpoint)
    taco_checkpoint = CheckpointTacotron.load(taco_ckpt_path, logger)

    prep_name = load_prep_name(train_dir)
    infer_sents = get_infer_sentences(base_dir, prep_name, text_name)

    infer_dir = get_infer_dir(train_dir, text_name, iteration, speaker)
    add_file_out_to_logger(logger, get_infer_log(infer_dir))

    # The vocoder checkpoint always comes from the latest WaveGlow training state.
    wg_train_dir = get_wg_train_dir(base_dir, waveglow, create=False)
    wg_ckpt_path, _ = get_last_checkpoint(get_checkpoints_dir(wg_train_dir))
    wg_checkpoint = CheckpointWaveglow.load(wg_ckpt_path, logger)

    wav, inference_results = infer(
        tacotron_checkpoint=taco_checkpoint,
        waveglow_checkpoint=wg_checkpoint,
        speaker=speaker,
        sentence_pause_s=sentence_pause_s,
        sigma=sigma,
        denoiser_strength=denoiser_strength,
        sentences=infer_sents,
        custom_taco_hparams=custom_tacotron_hparams,
        custom_wg_hparams=custom_waveglow_hparams,
        logger=logger
    )

    logger.info("Saving wav...")
    sampling_rate = inference_results[0].sampling_rate
    save_infer_wav(infer_dir, sampling_rate, wav)

    if analysis:
        logger.info("Analysing...")

        res: InferenceResult
        for res in tqdm(inference_results):
            # Per-sentence artifacts: wav, mel, pre/post-net mel, alignments.
            save_infer_wav_sentence(infer_dir, res)
            save_infer_sentence_plot(infer_dir, res)
            save_infer_pre_postnet_sentence_plot(infer_dir, res)
            save_infer_alignments_sentence_plot(infer_dir, res)

        # Combined overview plots across all sentences.
        sent_ids = [x.sentence.sent_id for x in inference_results]
        save_infer_v_plot(infer_dir, sent_ids)
        save_infer_h_plot(infer_dir, sent_ids)
        save_infer_v_pre_post(infer_dir, sent_ids)
        save_infer_v_alignments(infer_dir, sent_ids)

    logger.info(f"Saved output to {infer_dir}")
def validate_main(base_dir: str, train_name: str, waveglow: str = DEFAULT_WAVEGLOW, entry_id: Optional[int] = None, speaker: Optional[str] = None, ds: str = "val", custom_checkpoint: Optional[int] = None, sigma: float = DEFAULT_SIGMA, denoiser_strength: float = DEFAULT_DENOISER_STRENGTH, custom_tacotron_hparams: Optional[Dict[str, str]] = None, custom_waveglow_hparams: Optional[Dict[str, str]] = None):
    """Validate a trained Tacotron model on one entry of the val/test split.

    Synthesizes the entry with Tacotron + WaveGlow, then saves the original
    and synthesized wavs along with mel, pre/post-net and alignment plots
    plus a comparison plot into the validation directory.

    Raises:
        ValueError: if ``ds`` is neither ``"val"`` nor ``"test"``.
    """
    train_dir = get_train_dir(base_dir, train_name, create=False)
    assert os.path.isdir(train_dir)

    if ds == "val":
        data = load_valset(train_dir)
    elif ds == "test":
        data = load_testset(train_dir)
    else:
        # Was `assert False`: asserts are stripped under `python -O` and give
        # no hint which argument was wrong — raise an explicit error instead.
        raise ValueError(f"Invalid dataset '{ds}', expected 'val' or 'test'.")

    # Optionally restrict the validation entry lookup to one speaker.
    speaker_id: Optional[int] = None
    if speaker is not None:
        prep_name = load_prep_name(train_dir)
        prep_dir = get_prepared_dir(base_dir, prep_name, create=False)
        speakers = load_prep_speakers_json(prep_dir)
        speaker_id = speakers.get_id(speaker)

    entry = data.get_for_validation(entry_id, speaker_id)

    checkpoint_path, iteration = get_custom_or_last_checkpoint(
        get_checkpoints_dir(train_dir), custom_checkpoint)
    val_dir = get_val_dir(train_dir, entry, iteration)

    logger = prepare_logger(get_val_log(val_dir))
    logger.info("Validating...")

    taco_checkpoint = CheckpointTacotron.load(checkpoint_path, logger)

    # The vocoder checkpoint always comes from the latest WaveGlow training state.
    train_dir_wg = get_wg_train_dir(base_dir, waveglow, create=False)
    wg_checkpoint_path, _ = get_last_checkpoint(get_checkpoints_dir(train_dir_wg))
    wg_checkpoint = CheckpointWaveglow.load(wg_checkpoint_path, logger)

    result = validate(
        tacotron_checkpoint=taco_checkpoint,
        waveglow_checkpoint=wg_checkpoint,
        sigma=sigma,
        denoiser_strength=denoiser_strength,
        entry=entry,
        logger=logger,
        custom_taco_hparams=custom_tacotron_hparams,
        custom_wg_hparams=custom_waveglow_hparams
    )

    # NOTE(review): the ground-truth mel is extracted with the *waveglow*
    # hparams overrides — presumably because mel extraction settings live in
    # the vocoder hparams; confirm this is intentional.
    orig_mel = get_mel(entry.wav_path, custom_hparams=custom_waveglow_hparams)

    save_val_orig_wav(val_dir, entry.wav_path)
    save_val_orig_plot(val_dir, orig_mel)
    save_val_wav(val_dir, result.sampling_rate, result.wav)
    save_val_plot(val_dir, result.mel_outputs)
    save_val_pre_postnet_plot(val_dir, result.mel_outputs_postnet)
    save_val_alignments_sentence_plot(val_dir, result.alignments)
    save_val_comparison(val_dir)

    logger.info(f"Saved output to: {val_dir}")
def dl_pretrained(base_dir: str, train_name: str = DEFAULT_WAVEGLOW, prep_name: Optional[str] = None, version: int = 3): train_dir = get_train_dir(base_dir, train_name, create=True) assert os.path.isdir(train_dir) checkpoints_dir = get_checkpoints_dir(train_dir) dest_path = get_checkpoint_pretrained(checkpoints_dir) print("Downloading pretrained waveglow model from Nvida...") dl_wg( destination=dest_path, version=version ) print("Pretrained model is now beeing converted to be able to use it...") convert_glow( origin=dest_path, destination=dest_path, keep_orig=False ) if prep_name is not None: prep_dir = get_prepared_dir(base_dir, prep_name) wholeset = load_filelist(prep_dir) save_testset(train_dir, wholeset) save_valset(train_dir, wholeset) save_prep_name(train_dir, prep_name=prep_name)
def continue_train_main(base_dir: str, train_name: str, custom_hparams: Optional[Dict[str, str]] = None): train_dir = get_train_dir(base_dir, train_name, create=False) assert os.path.isdir(train_dir) logs_dir = get_train_logs_dir(train_dir) taco_logger = Tacotron2Logger(logs_dir) logger = prepare_logger(get_train_log_file(logs_dir)) checkpoint_logger = prepare_logger( log_file_path=get_train_checkpoints_log_file(logs_dir), logger=logging.getLogger("checkpoint-logger")) checkpoints_dir = get_checkpoints_dir(train_dir) last_checkpoint_path, _ = get_last_checkpoint(checkpoints_dir) last_checkpoint = CheckpointTacotron.load(last_checkpoint_path, logger) save_callback = partial( save_checkpoint, save_checkpoint_dir=checkpoints_dir, logger=logger, ) continue_train(checkpoint=last_checkpoint, custom_hparams=custom_hparams, taco_logger=taco_logger, trainset=load_trainset(train_dir), valset=load_valset(train_dir), logger=logger, checkpoint_logger=checkpoint_logger, save_callback=save_callback)
def infer(base_dir: str, train_name: str, wav_path: str, custom_checkpoint: Optional[int] = None, sigma: float = 0.666, denoiser_strength: float = 0.00, custom_hparams: Optional[Dict[str, str]] = None): train_dir = get_train_dir(base_dir, train_name, create=False) assert os.path.isdir(train_dir) checkpoint_path, iteration = get_custom_or_last_checkpoint( get_checkpoints_dir(train_dir), custom_checkpoint) infer_dir = get_infer_dir(train_dir, wav_path, iteration) logger = prepare_logger(get_infer_log(infer_dir)) logger.info(f"Inferring {wav_path}...") checkpoint = CheckpointWaveglow.load(checkpoint_path, logger) wav, wav_sr, wav_mel, orig_mel = infer_core( wav_path=wav_path, denoiser_strength=denoiser_strength, sigma=sigma, checkpoint=checkpoint, custom_hparams=custom_hparams, logger=logger ) save_infer_wav(infer_dir, wav_sr, wav) save_infer_plot(infer_dir, wav_mel) save_infer_orig_wav(infer_dir, wav_path) save_infer_orig_plot(infer_dir, orig_mel) score = save_diff_plot(infer_dir) save_v(infer_dir) logger.info(f"Imagescore: {score*100}%") logger.info(f"Saved output to: {infer_dir}")
def try_load_checkpoint(base_dir: str, train_name: Optional[str], checkpoint: Optional[int], logger: Logger) -> Optional[CheckpointWaveglow]: result = None if train_name: train_dir = get_train_dir(base_dir, train_name, False) checkpoint_path, _ = get_custom_or_last_checkpoint( get_checkpoints_dir(train_dir), checkpoint) result = CheckpointWaveglow.load(checkpoint_path, logger) return result
def try_load_checkpoint(base_dir: str, train_name: Optional[str], checkpoint: Optional[int], logger: Logger) -> Optional[CheckpointTacotron]: result = None if train_name: train_dir = get_train_dir(base_dir, train_name, False) checkpoint_path, _ = get_custom_or_last_checkpoint( get_checkpoints_dir(train_dir), checkpoint) result = CheckpointTacotron.load(checkpoint_path, logger) logger.info(f"Using warm start model: {checkpoint_path}") return result
def validate(base_dir: str, train_name: str, entry_id: Optional[int] = None, speaker: Optional[str] = None, ds: str = "val", custom_checkpoint: Optional[int] = None, sigma: float = 0.666, denoiser_strength: float = 0.00, custom_hparams: Optional[Dict[str, str]] = None): train_dir = get_train_dir(base_dir, train_name, create=False) assert os.path.isdir(train_dir) if ds == "val": data = load_valset(train_dir) elif ds == "test": data = load_testset(train_dir) else: raise Exception() speaker_id: Optional[int] = None if speaker is not None: prep_name = load_prep_name(train_dir) prep_dir = get_prepared_dir(base_dir, prep_name, create=False) speakers = load_prep_speakers_json(prep_dir) speaker_id = speakers.get_id(speaker) entry = data.get_for_validation(entry_id, speaker_id) checkpoint_path, iteration = get_custom_or_last_checkpoint( get_checkpoints_dir(train_dir), custom_checkpoint) val_dir = get_val_dir(train_dir, entry, iteration) logger = prepare_logger(get_val_log(val_dir)) logger.info(f"Validating {entry.wav_path}...") checkpoint = CheckpointWaveglow.load(checkpoint_path, logger) wav, wav_sr, wav_mel, orig_mel = infer(wav_path=entry.wav_path, denoiser_strength=denoiser_strength, sigma=sigma, checkpoint=checkpoint, custom_hparams=custom_hparams, logger=logger) save_val_wav(val_dir, wav_sr, wav) save_val_plot(val_dir, wav_mel) save_val_orig_wav(val_dir, entry.wav_path) save_val_orig_plot(val_dir, orig_mel) score = save_diff_plot(val_dir) save_v(val_dir) logger.info(f"Imagescore: {score*100}%") logger.info(f"Saved output to: {val_dir}")
def continue_training(base_dir: str, train_name: str, custom_hparams: Optional[Dict[str, str]] = None): train_dir = get_train_dir(base_dir, train_name, create=False) assert os.path.isdir(train_dir) logs_dir = get_train_logs_dir(train_dir) logger = prepare_logger(get_train_log_file(logs_dir)) continue_train(custom_hparams=custom_hparams, logdir=logs_dir, trainset=load_trainset(train_dir), valset=load_valset(train_dir), save_checkpoint_dir=get_checkpoints_dir(train_dir), debug_logger=logger)
def start_new_training(base_dir: str, train_name: str, prep_name: str, test_size: float = 0.01, validation_size: float = 0.01, custom_hparams: Optional[Dict[str, str]] = None, split_seed: int = 1234, warm_start_train_name: Optional[str] = None, warm_start_checkpoint: Optional[int] = None): prep_dir = get_prepared_dir(base_dir, prep_name) wholeset = load_filelist(prep_dir) trainset, testset, valset = split_prepared_data_train_test_val( wholeset, test_size=test_size, validation_size=validation_size, seed=split_seed, shuffle=True) train_dir = get_train_dir(base_dir, train_name, create=True) save_trainset(train_dir, trainset) save_testset(train_dir, testset) save_valset(train_dir, valset) logs_dir = get_train_logs_dir(train_dir) logger = prepare_logger(get_train_log_file(logs_dir), reset=True) warm_model = try_load_checkpoint(base_dir=base_dir, train_name=warm_start_train_name, checkpoint=warm_start_checkpoint, logger=logger) save_prep_name(train_dir, prep_name) train( custom_hparams=custom_hparams, logdir=logs_dir, trainset=trainset, valset=valset, save_checkpoint_dir=get_checkpoints_dir(train_dir), debug_logger=logger, warm_model=warm_model, )
def plot_embeddings(base_dir: str, train_name: str, custom_checkpoint: Optional[int] = None): train_dir = get_train_dir(base_dir, train_name, create=False) assert os.path.isdir(train_dir) analysis_dir = get_analysis_root_dir(train_dir) logger = prepare_logger() checkpoint_path, checkpoint_it = get_custom_or_last_checkpoint( get_checkpoints_dir(train_dir), custom_checkpoint) checkpoint = CheckpointTacotron.load(checkpoint_path, logger) # pylint: disable=no-member text, fig_2d, fig_3d = plot_embeddings_core( symbols=checkpoint.get_symbols(), emb=checkpoint.get_symbol_embedding_weights(), logger=logger ) _save_similarities_csv(analysis_dir, checkpoint_it, text) _save_2d_plot(analysis_dir, checkpoint_it, fig_2d) _save_3d_plot(analysis_dir, checkpoint_it, fig_3d) logger.info(f"Saved analysis to: {analysis_dir}")
def eval_checkpoints_main(base_dir: str, train_name: str, select: int, min_it: int, max_it: int): train_dir = get_train_dir(base_dir, train_name, create=False) assert os.path.isdir(train_dir) prep_name = load_prep_name(train_dir) prep_dir = get_prepared_dir(base_dir, prep_name) symbols_conv = load_prep_symbol_converter(prep_dir) speakers = load_prep_speakers_json(prep_dir) accents = load_prep_accents_ids(prep_dir) logger = prepare_logger() eval_checkpoints(custom_hparams=None, checkpoint_dir=get_checkpoints_dir(train_dir), select=select, min_it=min_it, max_it=max_it, n_symbols=len(symbols_conv), n_speakers=len(speakers), n_accents=len(accents), valset=load_valset(train_dir), logger=logger)
def train_main(base_dir: str, train_name: str, prep_name: str, warm_start_train_name: Optional[str] = None, warm_start_checkpoint: Optional[int] = None, test_size: float = 0.01, validation_size: float = 0.05, custom_hparams: Optional[Dict[str, str]] = None, split_seed: int = 1234, weights_train_name: Optional[str] = None, weights_checkpoint: Optional[int] = None, use_weights_map: Optional[bool] = None, map_from_speaker: Optional[str] = None):
    """Start a new Tacotron training run on a prepared dataset.

    Supports warm-starting from another run's checkpoint and transferring
    symbol embedding weights (optionally through a symbol weights map).
    """
    prep_dir = get_prepared_dir(base_dir, prep_name)
    train_dir = get_train_dir(base_dir, train_name, create=True)
    logs_dir = get_train_logs_dir(train_dir)

    # reset=True: a new run starts with fresh log files.
    taco_logger = Tacotron2Logger(logs_dir)
    logger = prepare_logger(get_train_log_file(logs_dir), reset=True)
    checkpoint_logger = prepare_logger(
        log_file_path=get_train_checkpoints_log_file(logs_dir),
        logger=logging.getLogger("checkpoint-logger"),
        reset=True
    )

    save_prep_name(train_dir, prep_name)

    trainset, valset = split_dataset(
        prep_dir=prep_dir,
        train_dir=train_dir,
        test_size=test_size,
        validation_size=validation_size,
        split_seed=split_seed)

    # Checkpoint whose embedding weights are transferred into the new model.
    weights_model = try_load_checkpoint(
        base_dir=base_dir,
        train_name=weights_train_name,
        checkpoint=weights_checkpoint,
        logger=logger)

    weights_map = None
    if use_weights_map:  # None and False both mean "no mapping"
        weights_train_dir = get_train_dir(base_dir, weights_train_name, False)
        weights_prep_name = load_prep_name(weights_train_dir)
        weights_map = load_weights_map(prep_dir, weights_prep_name)

    # Checkpoint used to warm-start the whole model.
    warm_model = try_load_checkpoint(
        base_dir=base_dir,
        train_name=warm_start_train_name,
        checkpoint=warm_start_checkpoint,
        logger=logger)

    # Pre-bind the destination directory and logger for checkpoint saving.
    save_callback = partial(
        save_checkpoint,
        save_checkpoint_dir=get_checkpoints_dir(train_dir),
        logger=logger,
    )

    train(
        custom_hparams=custom_hparams,
        taco_logger=taco_logger,
        symbols=load_prep_symbol_converter(prep_dir),
        speakers=load_prep_speakers_json(prep_dir),
        accents=load_prep_accents_ids(prep_dir),
        trainset=trainset,
        valset=valset,
        save_callback=save_callback,
        weights_map=weights_map,
        weights_checkpoint=weights_model,
        warm_model=warm_model,
        map_from_speaker_name=map_from_speaker,
        logger=logger,
        checkpoint_logger=checkpoint_logger,
    )