def dl_pretrained(
    base_dir: str,
    train_name: str = DEFAULT_WAVEGLOW,
    prep_name: Optional[str] = None,
    version: int = 3,
):
    """Download the pretrained WaveGlow model into a (new or existing) training dir.

    The downloaded checkpoint is passed through ``convert_glow`` with the same
    origin and destination and ``keep_orig=False``, i.e. it is converted in place.

    Args:
        base_dir: Root directory under which training dirs are resolved.
        train_name: Name of the training directory to create/use.
        prep_name: Optional prepared-dataset name. When given, its whole filelist
            is saved as both the test set and the validation set, and the name is
            recorded for the training dir.
        version: Forwarded to ``dl_wg`` to select which pretrained release to fetch.
    """
    train_dir = get_train_dir(base_dir, train_name, create=True)
    assert os.path.isdir(train_dir)
    pretrained_path = get_checkpoint_pretrained(get_checkpoints_dir(train_dir))

    print("Downloading pretrained waveglow model from Nvida...")
    dl_wg(destination=pretrained_path, version=version)

    print("Pretrained model is now beeing converted to be able to use it...")
    # Same path for origin and destination: the download is overwritten by the
    # converted checkpoint.
    convert_glow(
        origin=pretrained_path,
        destination=pretrained_path,
        keep_orig=False,
    )

    if prep_name is None:
        return

    all_entries = load_filelist(get_prepared_dir(base_dir, prep_name))
    # No training happens for a pretrained model, so the whole prepared set is
    # used for both test and validation.
    save_testset(train_dir, all_entries)
    save_valset(train_dir, all_entries)
    save_prep_name(train_dir, prep_name=prep_name)
def infer(
    base_dir: str,
    train_name: str,
    wav_path: str,
    custom_checkpoint: Optional[int] = None,
    sigma: float = 0.666,
    denoiser_strength: float = 0.00,
    custom_hparams: Optional[Dict[str, str]] = None,
):
    """Re-synthesize ``wav_path`` with a WaveGlow checkpoint of an existing training run.

    The checkpoint is chosen by ``get_custom_or_last_checkpoint`` (presumably
    ``custom_checkpoint`` when given, otherwise the latest one; verify against
    that helper). The synthesized wav, its mel plot, the original wav, the
    original mel plot, a diff plot and ``save_v`` output are written to the
    inference directory. The diff-plot score is logged as a percentage.
    """
    train_dir = get_train_dir(base_dir, train_name, create=False)
    assert os.path.isdir(train_dir)

    checkpoints_dir = get_checkpoints_dir(train_dir)
    checkpoint_path, iteration = get_custom_or_last_checkpoint(checkpoints_dir, custom_checkpoint)
    out_dir = get_infer_dir(train_dir, wav_path, iteration)

    log = prepare_logger(get_infer_log(out_dir))
    log.info(f"Inferring {wav_path}...")

    model_checkpoint = CheckpointWaveglow.load(checkpoint_path, log)
    synth_wav, synth_sr, synth_mel, source_mel = infer_core(
        wav_path=wav_path,
        denoiser_strength=denoiser_strength,
        sigma=sigma,
        checkpoint=model_checkpoint,
        custom_hparams=custom_hparams,
        logger=log,
    )

    # Persist the synthesized output alongside the original for comparison.
    save_infer_wav(out_dir, synth_sr, synth_wav)
    save_infer_plot(out_dir, synth_mel)
    save_infer_orig_wav(out_dir, wav_path)
    save_infer_orig_plot(out_dir, source_mel)

    image_score = save_diff_plot(out_dir)
    save_v(out_dir)

    log.info(f"Imagescore: {image_score*100}%")
    log.info(f"Saved output to: {out_dir}")
def try_load_checkpoint(
    base_dir: str,
    train_name: Optional[str],
    checkpoint: Optional[int],
    logger: Logger,
) -> Optional[CheckpointWaveglow]:
    """Load a checkpoint from another training run, if one is named.

    Args:
        base_dir: Root directory under which training dirs are resolved.
        train_name: Training run to load from; a falsy value (None or "")
            means "nothing to load".
        checkpoint: Forwarded to ``get_custom_or_last_checkpoint`` to pick the
            checkpoint (None presumably selects the latest one).
        logger: Passed through to ``CheckpointWaveglow.load``.

    Returns:
        The loaded ``CheckpointWaveglow``, or None when ``train_name`` is falsy.
    """
    if not train_name:
        return None

    source_train_dir = get_train_dir(base_dir, train_name, False)
    checkpoint_path, _ = get_custom_or_last_checkpoint(
        get_checkpoints_dir(source_train_dir), checkpoint
    )
    return CheckpointWaveglow.load(checkpoint_path, logger)
def validate(
    base_dir: str,
    train_name: str,
    entry_id: Optional[int] = None,
    speaker: Optional[str] = None,
    ds: str = "val",
    custom_checkpoint: Optional[int] = None,
    sigma: float = 0.666,
    denoiser_strength: float = 0.00,
    custom_hparams: Optional[Dict[str, str]] = None,
):
    """Re-synthesize one utterance of the validation or test split with a trained checkpoint.

    The utterance is selected by ``data.get_for_validation(entry_id, speaker_id)``.
    Presumably a specific entry when ``entry_id`` is given, otherwise one chosen
    by that helper, optionally restricted to ``speaker``; verify against the
    dataset implementation. The synthesized and original wavs/mel plots, a diff
    plot and ``save_v`` output are written to the validation directory. The
    diff-plot score is logged as a percentage.

    Args:
        base_dir: Root directory under which training/prepared dirs are resolved.
        train_name: Existing training run to validate.
        entry_id: Optional id of the entry to validate.
        speaker: Optional speaker name, mapped to an id via the prepared
            dataset's speakers json.
        ds: Which split to draw from: ``"val"`` or ``"test"``.
        custom_checkpoint: Forwarded to ``get_custom_or_last_checkpoint``.
        sigma: Sampling sigma forwarded to the inference core.
        denoiser_strength: Denoiser strength forwarded to the inference core.
        custom_hparams: Hyperparameter overrides forwarded to the inference core.

    Raises:
        ValueError: If ``ds`` is neither ``"val"`` nor ``"test"``.
    """
    train_dir = get_train_dir(base_dir, train_name, create=False)
    assert os.path.isdir(train_dir)

    if ds == "val":
        data = load_valset(train_dir)
    elif ds == "test":
        data = load_testset(train_dir)
    else:
        # Previously a bare ``Exception()`` with no message; ValueError is still
        # an Exception subclass, so existing handlers keep working.
        raise ValueError(f"Unknown dataset {ds!r}; expected 'val' or 'test'.")

    speaker_id: Optional[int] = None
    if speaker is not None:
        prep_name = load_prep_name(train_dir)
        prep_dir = get_prepared_dir(base_dir, prep_name, create=False)
        speakers = load_prep_speakers_json(prep_dir)
        speaker_id = speakers.get_id(speaker)

    entry = data.get_for_validation(entry_id, speaker_id)

    checkpoint_path, iteration = get_custom_or_last_checkpoint(
        get_checkpoints_dir(train_dir), custom_checkpoint
    )
    val_dir = get_val_dir(train_dir, entry, iteration)

    logger = prepare_logger(get_val_log(val_dir))
    logger.info(f"Validating {entry.wav_path}...")
    checkpoint = CheckpointWaveglow.load(checkpoint_path, logger)

    # Must call the core inference routine (``infer_core``), not this module's
    # ``infer`` wrapper. That wrapper takes (base_dir, train_name, wav_path, ...)
    # and has no ``checkpoint``/``logger`` parameters, so calling it here raised
    # TypeError.
    wav, wav_sr, wav_mel, orig_mel = infer_core(
        wav_path=entry.wav_path,
        denoiser_strength=denoiser_strength,
        sigma=sigma,
        checkpoint=checkpoint,
        custom_hparams=custom_hparams,
        logger=logger,
    )

    save_val_wav(val_dir, wav_sr, wav)
    save_val_plot(val_dir, wav_mel)
    save_val_orig_wav(val_dir, entry.wav_path)
    save_val_orig_plot(val_dir, orig_mel)

    score = save_diff_plot(val_dir)
    save_v(val_dir)

    logger.info(f"Imagescore: {score*100}%")
    logger.info(f"Saved output to: {val_dir}")
def continue_training(
    base_dir: str,
    train_name: str,
    custom_hparams: Optional[Dict[str, str]] = None,
):
    """Resume an existing training run from its stored train/validation sets.

    Logs go to the run's train-log file (opened via ``prepare_logger`` without a
    reset). Checkpoints are written to the run's checkpoints directory.
    """
    train_dir = get_train_dir(base_dir, train_name, create=False)
    assert os.path.isdir(train_dir)

    logs_dir = get_train_logs_dir(train_dir)
    debug_logger = prepare_logger(get_train_log_file(logs_dir))

    trainset = load_trainset(train_dir)
    valset = load_valset(train_dir)
    checkpoints_dir = get_checkpoints_dir(train_dir)

    continue_train(
        custom_hparams=custom_hparams,
        logdir=logs_dir,
        trainset=trainset,
        valset=valset,
        save_checkpoint_dir=checkpoints_dir,
        debug_logger=debug_logger,
    )
def start_new_training(
    base_dir: str,
    train_name: str,
    prep_name: str,
    test_size: float = 0.01,
    validation_size: float = 0.01,
    custom_hparams: Optional[Dict[str, str]] = None,
    split_seed: int = 1234,
    warm_start_train_name: Optional[str] = None,
    warm_start_checkpoint: Optional[int] = None,
):
    """Create a new training run from a prepared dataset and start training.

    The prepared filelist is split (shuffled, seeded with ``split_seed``) into
    train/test/validation sets, which are saved in the new training dir. The
    train log is reset. If ``warm_start_train_name`` is given, a checkpoint
    from that run is loaded via ``try_load_checkpoint`` and passed to ``train``
    as ``warm_model``.
    """
    prepared_dir = get_prepared_dir(base_dir, prep_name)
    all_entries = load_filelist(prepared_dir)
    train_entries, test_entries, val_entries = split_prepared_data_train_test_val(
        all_entries,
        test_size=test_size,
        validation_size=validation_size,
        seed=split_seed,
        shuffle=True,
    )

    train_dir = get_train_dir(base_dir, train_name, create=True)
    save_trainset(train_dir, train_entries)
    save_testset(train_dir, test_entries)
    save_valset(train_dir, val_entries)

    logs_dir = get_train_logs_dir(train_dir)
    debug_logger = prepare_logger(get_train_log_file(logs_dir), reset=True)

    # None when no warm-start run was requested.
    warm_model = try_load_checkpoint(
        base_dir=base_dir,
        train_name=warm_start_train_name,
        checkpoint=warm_start_checkpoint,
        logger=debug_logger,
    )

    save_prep_name(train_dir, prep_name)

    train(
        custom_hparams=custom_hparams,
        logdir=logs_dir,
        trainset=train_entries,
        valset=val_entries,
        save_checkpoint_dir=get_checkpoints_dir(train_dir),
        debug_logger=debug_logger,
        warm_model=warm_model,
    )