def main():
    """Run `tts` model training directly by a `config.json` file.

    The training config is resolved as follows:

    * ``--config_path`` given: load that config file.
    * otherwise ``--continue_path`` given: load ``config.json`` from that folder.
    * neither given: build the config from the console arguments alone.

    Command-line arguments the trainer parser does not recognise are forwarded
    to the config object as overrides.
    """
    # init trainer args
    train_args = TrainVocoderArgs()
    parser = train_args.init_argparse(arg_prefix="")

    # override trainer args from command-line args
    args, config_overrides = parser.parse_known_args()
    train_args.parse_args(args)

    # load config.json and register
    if args.config_path or args.continue_path:
        if args.config_path:
            # init from a file
            config_file = args.config_path
        else:
            # continue from a prev experiment
            config_file = os.path.join(args.continue_path, "config.json")
        config = load_config(config_file)
        if config_overrides:
            config.parse_known_args(config_overrides, relaxed_parser=True)
    else:
        # init from console args; this branch must hang off the outer check,
        # otherwise it can never run and `config` stays unbound.
        from TTS.config.shared_configs import BaseTrainingConfig  # pylint: disable=import-outside-toplevel

        config_base = BaseTrainingConfig()
        config_base.parse_known_args(config_overrides)
        config = register_config(config_base.model)()

    # load training samples
    if "feature_path" in config and config.feature_path:
        # load pre-computed features
        print(f" > Loading features from: {config.feature_path}")
        eval_samples, train_samples = load_wav_feat_data(
            config.data_path, config.feature_path, config.eval_split_size
        )
    else:
        # load data raw wav files
        eval_samples, train_samples = load_wav_data(config.data_path, config.eval_split_size)

    # setup audio processor
    ap = AudioProcessor(**config.audio)

    # init the model from config
    model = setup_model(config)

    # init the trainer and 🚀
    trainer = Trainer(
        train_args,
        config,
        config.output_path,
        model=model,
        train_samples=train_samples,
        eval_samples=eval_samples,
        training_assets={"audio_processor": ap},
        parse_command_line_args=False,
    )
    trainer.fit()
def main(args):  # pylint: disable=redefined-outer-name
    """Train a WaveGrad vocoder and keep the best checkpoint.

    Loads the data, builds model / optimizer / optional LR scheduler /
    optional AMP scaler from the module-level config ``c``, optionally
    restores a checkpoint, then alternates ``train`` and ``evaluate`` for
    ``c.epochs`` epochs, calling ``save_best_model`` after each one.

    Args:
        args: parsed command-line arguments. ``restore_path``, ``rank`` and
            ``group_id`` are read; ``restore_step`` is written (checkpoint
            step, or 0 when starting fresh).
    """
    # pylint: disable=global-variable-undefined
    global train_data, eval_data
    print(f" > Loading wavs from: {c.data_path}")
    if c.feature_path is not None:
        print(f" > Loading features from: {c.feature_path}")
        eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size)
    else:
        eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)

    # setup audio processor
    ap = AudioProcessor(**c.audio)

    # DISTRIBUTED
    if num_gpus > 1:
        init_distributed(args.rank, num_gpus, args.group_id,
                         c.distributed["backend"], c.distributed["url"])

    # setup models
    model = setup_generator(c)

    # scaler for mixed_precision
    scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None

    # setup optimizers
    optimizer = Adam(model.parameters(), lr=c.lr, weight_decay=0)

    # schedulers
    scheduler = None
    if 'lr_scheduler' in c:
        scheduler = getattr(torch.optim.lr_scheduler, c.lr_scheduler)
        scheduler = scheduler(optimizer, **c.lr_scheduler_params)

    # setup criterion; it is moved to the GPU below only when `use_cuda` is
    # set, so that CPU-only runs do not crash here.
    criterion = torch.nn.L1Loss()

    if args.restore_path:
        checkpoint = torch.load(args.restore_path, map_location='cpu')
        try:
            print(" > Restoring Model...")
            model.load_state_dict(checkpoint['model'])
            print(" > Restoring Optimizer...")
            optimizer.load_state_dict(checkpoint['optimizer'])
            # The checkpoint may carry a scheduler state even though the
            # current config defines no scheduler; skip it in that case.
            if 'scheduler' in checkpoint and scheduler is not None:
                print(" > Restoring LR Scheduler...")
                scheduler.load_state_dict(checkpoint['scheduler'])
                # NOTE: Not sure if necessary
                scheduler.optimizer = optimizer
            if "scaler" in checkpoint and c.mixed_precision:
                print(" > Restoring AMP Scaler...")
                scaler.load_state_dict(checkpoint["scaler"])
        except RuntimeError:
            # restore only matching layers.
            print(" > Partial model initialization...")
            model_dict = model.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint['model'], c)
            model.load_state_dict(model_dict)
            del model_dict

        # reset lr if not continuing training.
        for group in optimizer.param_groups:
            group['lr'] = c.lr

        print(" > Model restored from step %d" % checkpoint['step'],
              flush=True)
        args.restore_step = checkpoint['step']
    else:
        args.restore_step = 0

    if use_cuda:
        model.cuda()
        criterion.cuda()

    # DISTRIBUTED
    if num_gpus > 1:
        model = DDP_th(model, device_ids=[args.rank])

    num_params = count_parameters(model)
    print(" > WaveGrad has {} parameters".format(num_params), flush=True)

    # No best loss is restored in this script, so always start from +inf.
    best_loss = float('inf')

    global_step = args.restore_step
    for epoch in range(0, c.epochs):
        c_logger.print_epoch_start(epoch, c.epochs)
        _, global_step = train(model, criterion, optimizer, scheduler, scaler,
                               ap, global_step, epoch)
        eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch)
        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
        target_loss = eval_avg_loss_dict[c.target_loss]
        best_loss = save_best_model(
            target_loss,
            best_loss,
            model,
            optimizer,
            scheduler,
            None,
            None,
            None,
            global_step,
            epoch,
            OUT_PATH,
            model_losses=eval_avg_loss_dict,
            scaler=scaler.state_dict() if c.mixed_precision else None)
def main(args):  # pylint: disable=redefined-outer-name
    """Train a GAN vocoder (generator + discriminator) and keep the best checkpoint.

    Builds both models with their optimizers, optional LR schedulers and loss
    criteria from the module-level config ``c``, optionally restores a
    checkpoint, then alternates ``train`` and ``evaluate`` for ``c.epochs``
    epochs, calling ``save_best_model`` after each one.

    Args:
        args: parsed command-line arguments. ``restore_path``, ``best_path``,
            ``rank`` and ``group_id`` are read; ``restore_step`` is written
            (checkpoint step, or 0 when starting fresh).
    """
    # pylint: disable=global-variable-undefined
    global train_data, eval_data
    print(f" > Loading wavs from: {c.data_path}")
    if c.feature_path is not None:
        print(f" > Loading features from: {c.feature_path}")
        eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size)
    else:
        eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)

    # setup audio processor
    ap = AudioProcessor(**c.audio)

    # DISTRIBUTED
    if num_gpus > 1:
        init_distributed(args.rank, num_gpus, args.group_id,
                         c.distributed["backend"], c.distributed["url"])

    # setup models
    model_gen = setup_generator(c)
    model_disc = setup_discriminator(c)

    # setup optimizers
    optimizer_gen = RAdam(model_gen.parameters(), lr=c.lr_gen, weight_decay=0)
    optimizer_disc = RAdam(model_disc.parameters(), lr=c.lr_disc, weight_decay=0)

    # schedulers
    scheduler_gen = None
    scheduler_disc = None
    if 'lr_scheduler_gen' in c:
        scheduler_gen = getattr(torch.optim.lr_scheduler, c.lr_scheduler_gen)
        scheduler_gen = scheduler_gen(optimizer_gen, **c.lr_scheduler_gen_params)
    if 'lr_scheduler_disc' in c:
        scheduler_disc = getattr(torch.optim.lr_scheduler, c.lr_scheduler_disc)
        scheduler_disc = scheduler_disc(optimizer_disc, **c.lr_scheduler_disc_params)

    # setup criterion
    criterion_gen = GeneratorLoss(c)
    criterion_disc = DiscriminatorLoss(c)

    if args.restore_path:
        print(f" > Restoring from {os.path.basename(args.restore_path)}...")
        checkpoint = torch.load(args.restore_path, map_location='cpu')
        try:
            print(" > Restoring Generator Model...")
            model_gen.load_state_dict(checkpoint['model'])
            print(" > Restoring Generator Optimizer...")
            optimizer_gen.load_state_dict(checkpoint['optimizer'])
            print(" > Restoring Discriminator Model...")
            model_disc.load_state_dict(checkpoint['model_disc'])
            print(" > Restoring Discriminator Optimizer...")
            optimizer_disc.load_state_dict(checkpoint['optimizer_disc'])
            # A checkpoint may hold scheduler state although the current
            # config defines no scheduler; skip restoring in that case.
            if 'scheduler' in checkpoint and scheduler_gen is not None:
                print(" > Restoring Generator LR Scheduler...")
                scheduler_gen.load_state_dict(checkpoint['scheduler'])
                # NOTE: Not sure if necessary
                scheduler_gen.optimizer = optimizer_gen
            if 'scheduler_disc' in checkpoint and scheduler_disc is not None:
                print(" > Restoring Discriminator LR Scheduler...")
                scheduler_disc.load_state_dict(checkpoint['scheduler_disc'])
                scheduler_disc.optimizer = optimizer_disc
        except RuntimeError:
            # restore only matching layers.
            print(" > Partial model initialization...")
            model_dict = model_gen.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint['model'], c)
            model_gen.load_state_dict(model_dict)

            model_dict = model_disc.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint['model_disc'], c)
            model_disc.load_state_dict(model_dict)
            del model_dict

        # reset lr if not continuing training.
        for group in optimizer_gen.param_groups:
            group['lr'] = c.lr_gen

        for group in optimizer_disc.param_groups:
            group['lr'] = c.lr_disc

        print(f" > Model restored from step {checkpoint['step']:d}",
              flush=True)
        args.restore_step = checkpoint['step']
    else:
        args.restore_step = 0

    if use_cuda:
        model_gen.cuda()
        criterion_gen.cuda()
        model_disc.cuda()
        criterion_disc.cuda()

    # DISTRIBUTED
    if num_gpus > 1:
        model_gen = DDP_th(model_gen, device_ids=[args.rank])
        model_disc = DDP_th(model_disc, device_ids=[args.rank])

    num_params = count_parameters(model_gen)
    print(" > Generator has {} parameters".format(num_params), flush=True)
    num_params = count_parameters(model_disc)
    print(" > Discriminator has {} parameters".format(num_params), flush=True)

    if args.restore_step == 0 or not args.best_path:
        best_loss = float('inf')
        print(" > Starting with inf best loss.")
    else:
        print(" > Restoring best loss from "
              f"{os.path.basename(args.best_path)} ...")
        best_loss = torch.load(args.best_path,
                               map_location='cpu')['model_loss']
        print(f" > Starting with best loss of {best_loss}.")
    keep_all_best = c.get('keep_all_best', False)
    keep_after = c.get('keep_after', 10000)  # void if keep_all_best False

    global_step = args.restore_step
    for epoch in range(0, c.epochs):
        c_logger.print_epoch_start(epoch, c.epochs)
        _, global_step = train(model_gen, criterion_gen, optimizer_gen,
                               model_disc, criterion_disc, optimizer_disc,
                               scheduler_gen, scheduler_disc, ap, global_step,
                               epoch)
        eval_avg_loss_dict = evaluate(model_gen, criterion_gen, model_disc,
                                      criterion_disc, ap, global_step, epoch)
        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
        target_loss = eval_avg_loss_dict[c.target_loss]
        best_loss = save_best_model(
            target_loss,
            best_loss,
            model_gen,
            optimizer_gen,
            scheduler_gen,
            model_disc,
            optimizer_disc,
            scheduler_disc,
            global_step,
            epoch,
            OUT_PATH,
            keep_all_best=keep_all_best,
            keep_after=keep_after,
            model_losses=eval_avg_loss_dict,
        )
use_noise_augment=True, eval_split_size=10, print_step=25, print_eval=False, mixed_precision=False, lr_gen=1e-4, lr_disc=1e-4, data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"), output_path=output_path, ) # init audio processor ap = AudioProcessor(**config.audio.to_dict()) # load training samples eval_samples, train_samples = load_wav_data(config.data_path, config.eval_split_size) # init model model = GAN(config) # init the trainer and 🚀 trainer = Trainer( TrainingArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples, training_assets={"audio_processor": ap}, ) trainer.fit()
def main(args):  # pylint: disable=redefined-outer-name
    """Train a WaveRNN vocoder and keep the best checkpoint.

    Loads the data, builds the model, the loss matching ``c.mode``, the
    optimizer, an optional LR scheduler and an optional AMP scaler from the
    module-level config ``c``, optionally restores a checkpoint, then
    alternates ``train`` and ``evaluate`` for ``c.epochs`` epochs, calling
    ``save_best_model`` after each one.

    Args:
        args: parsed command-line arguments. ``restore_path`` is read;
            ``restore_step`` is written (checkpoint step, or 0 when starting
            fresh).

    Raises:
        ValueError: if ``c.mode`` is neither ``"mold"``, ``"gauss"`` nor an int.
    """
    # pylint: disable=global-variable-undefined
    global train_data, eval_data

    # setup audio processor
    ap = AudioProcessor(**c.audio)

    # NOTE: an alternative loading path that looked for (or generated via
    # `preprocess_wav_files`) mel features under `OUT_PATH/mel` was commented
    # out here; only the config-driven loading below is active.

    print(f" > Loading wavs from: {c.data_path}")
    if c.feature_path is not None:
        print(f" > Loading features from: {c.feature_path}")
        eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size)
    else:
        eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)

    # setup model
    model_wavernn = setup_wavernn(c)

    # setup amp scaler
    scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None

    # define train functions
    if c.mode == "mold":
        criterion = discretized_mix_logistic_loss
    elif c.mode == "gauss":
        criterion = gaussian_loss
    elif isinstance(c.mode, int):
        criterion = torch.nn.CrossEntropyLoss()
    else:
        # Fail early instead of hitting an unbound `criterion` later on.
        raise ValueError(f" [!] Unknown WaveRNN mode: {c.mode!r}")

    if use_cuda:
        model_wavernn.cuda()
        # only the CrossEntropyLoss module has a `.cuda()`; the other two
        # criteria are plain functions.
        if isinstance(c.mode, int):
            criterion.cuda()

    optimizer = RAdam(model_wavernn.parameters(), lr=c.lr, weight_decay=0)

    scheduler = None
    if "lr_scheduler" in c:
        scheduler = getattr(torch.optim.lr_scheduler, c.lr_scheduler)
        scheduler = scheduler(optimizer, **c.lr_scheduler_params)
    # slow start for the first 5 epochs
    # lr_lambda = lambda epoch: min(epoch / c.warmup_steps, 1)
    # scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # restore any checkpoint
    if args.restore_path:
        checkpoint = torch.load(args.restore_path, map_location="cpu")
        try:
            print(" > Restoring Model...")
            model_wavernn.load_state_dict(checkpoint["model"])
            print(" > Restoring Optimizer...")
            optimizer.load_state_dict(checkpoint["optimizer"])
            # The checkpoint may carry a scheduler state even though the
            # current config defines no scheduler; skip it in that case.
            if "scheduler" in checkpoint and scheduler is not None:
                print(" > Restoring Generator LR Scheduler...")
                scheduler.load_state_dict(checkpoint["scheduler"])
                scheduler.optimizer = optimizer
            if "scaler" in checkpoint and c.mixed_precision:
                print(" > Restoring AMP Scaler...")
                scaler.load_state_dict(checkpoint["scaler"])
        except RuntimeError:
            # restore only matching layers.
            print(" > Partial model initialization...")
            model_dict = model_wavernn.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint["model"], c)
            model_wavernn.load_state_dict(model_dict)

        print(" > Model restored from step %d" % checkpoint["step"],
              flush=True)
        args.restore_step = checkpoint["step"]
    else:
        args.restore_step = 0

    # DISTRIBUTED
    # if num_gpus > 1:
    #     model = apply_gradient_allreduce(model)

    num_parameters = count_parameters(model_wavernn)
    print(" > Model has {} parameters".format(num_parameters), flush=True)

    # No best loss is restored in this script, so always start from +inf.
    best_loss = float("inf")

    global_step = args.restore_step
    for epoch in range(0, c.epochs):
        c_logger.print_epoch_start(epoch, c.epochs)
        _, global_step = train(model_wavernn, optimizer, criterion, scheduler,
                               scaler, ap, global_step, epoch)
        eval_avg_loss_dict = evaluate(model_wavernn, criterion, ap,
                                      global_step, epoch)
        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
        target_loss = eval_avg_loss_dict["avg_model_loss"]
        best_loss = save_best_model(
            target_loss,
            best_loss,
            model_wavernn,
            optimizer,
            scheduler,
            None,
            None,
            None,
            global_step,
            epoch,
            OUT_PATH,
            model_losses=eval_avg_loss_dict,
            scaler=scaler.state_dict() if c.mixed_precision else None)
def gan_dataset_case(batch_size, seq_len, hop_len, conv_pad, return_segments,
                     use_noise_augment, use_cache, num_workers):
    """Build a ``GANDataset`` loader with the given settings and assert its batches are consistent.

    With ``return_segments`` every batch is checked for the expected feature
    shape and feature/waveform length agreement, and (unless noise
    augmentation is on) for agreement with a freshly computed mel
    spectrogram. Without it, only shapes are checked, for at most 10 batches.
    """
    ap = AudioProcessor(**C.audio)
    _, train_items = load_wav_data(test_data_path, 10)
    dataset = GANDataset(
        ap,
        train_items,
        seq_len=seq_len,
        hop_len=hop_len,
        pad_short=2000,
        conv_pad=conv_pad,
        return_segments=return_segments,
        use_noise_augment=use_noise_augment,
        use_cache=use_cache,
    )
    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
    )

    max_iter = 10

    if not return_segments:
        # whole-audio mode: shape checks only, capped at `max_iter` batches
        for step, (feat, wav) in enumerate(loader, start=1):
            expected_feat_shape = (batch_size, ap.num_mels,
                                   (wav.shape[-1] // hop_len) + (conv_pad * 2))
            assert np.all(feat.shape == expected_feat_shape
                          ), f" [!] {feat.shape} vs {expected_feat_shape}"
            assert (feat.shape[2] - conv_pad * 2) * hop_len == wav.shape[2]
            if step == max_iter:
                break
        return

    # random-segment mode: every batch of the loader is visited (no cap);
    # only the first item of each yielded pair is inspected.
    for (feat1, wav1), _ in loader:
        expected_feat_shape = (batch_size, ap.num_mels,
                               seq_len // hop_len + conv_pad * 2)

        # check shapes
        assert np.all(feat1.shape == expected_feat_shape
                      ), f" [!] {feat1.shape} vs {expected_feat_shape}"
        assert (feat1.shape[2] - conv_pad * 2) * hop_len == wav1.shape[2]

        if use_noise_augment:
            # features no longer correspond to the returned audio
            continue

        # check feature vs audio match
        for idx in range(batch_size):
            audio = wav1[idx].squeeze()
            feat = feat1[idx]
            mel = ap.melspectrogram(audio)
            # the first 2 and the last frame is skipped due to the padding
            # applied in spec. computation.
            assert (feat - mel[:, :feat1.shape[-1]])[:, 2:-1].sum(
            ) == 0, f' [!] {(feat - mel[:, :feat1.shape[-1]])[:, 2:-1].sum()}'
'--search_depth', type=int, default=3, help= 'Search granularity. Increasing this increases the run-time exponentially.' ) # load config args = parser.parse_args() config = load_config(args.config_path) # setup audio processor ap = AudioProcessor(**config.audio) # load dataset _, train_data = load_wav_data(args.data_path, 0) train_data = train_data[:args.num_samples] dataset = WaveGradDataset(ap=ap, items=train_data, seq_len=-1, hop_len=ap.hop_length, pad_short=config.pad_short, conv_pad=config.conv_pad, is_training=True, return_segments=False, use_noise_augment=False, use_cache=False, verbose=True) loader = DataLoader(dataset, batch_size=1, shuffle=False,
def gan_dataset_case(batch_size, seq_len, hop_len, conv_pad, return_pairs,
                     return_segments, use_noise_augment, use_cache,
                     num_workers):
    """Run dataloader with given parameters and check conditions.

    With ``return_segments`` each batch (or both batches of a pair when
    ``return_pairs`` is set) goes through ``check_item``. Otherwise only the
    feature shape and feature/waveform length agreement are asserted, for at
    most 10 batches.
    """
    ap = AudioProcessor(**C.audio)
    _, train_items = load_wav_data(test_data_path, 10)
    dataset = GANDataset(
        ap,
        train_items,
        seq_len=seq_len,
        hop_len=hop_len,
        pad_short=2000,
        conv_pad=conv_pad,
        return_pairs=return_pairs,
        return_segments=return_segments,
        use_noise_augment=use_noise_augment,
        use_cache=use_cache,
    )
    loader = DataLoader(dataset=dataset,
                        batch_size=batch_size,
                        shuffle=True,
                        num_workers=num_workers,
                        pin_memory=True,
                        drop_last=True)

    max_iter = 10
    count_iter = 0

    def check_item(feat, wav):
        """Pass a single pair of features and waveform"""
        feat = feat.numpy()
        wav = wav.numpy()
        expected_feat_shape = (batch_size, ap.num_mels,
                               seq_len // hop_len + conv_pad * 2)

        # check shapes
        assert np.all(feat.shape == expected_feat_shape
                      ), f" [!] {feat.shape} vs {expected_feat_shape}"
        assert (feat.shape[2] - conv_pad * 2) * hop_len == wav.shape[2]

        # check feature vs audio match
        if not use_noise_augment:
            for idx in range(batch_size):
                audio = wav[idx].squeeze()
                # Keep the batch array `feat` intact: rebinding it to one
                # sample would make the next iteration index into that
                # sample instead of the batch.
                sample_feat = feat[idx]
                mel = ap.melspectrogram(audio)
                # the first 2 and the last 2 frames are skipped due to the padding
                # differences in stft
                max_diff = abs(
                    (sample_feat - mel[:, :sample_feat.shape[-1]])[:, 2:-2]).max()
                assert max_diff <= 1e-6, f" [!] {max_diff}"

    # return random segments or return the whole audio
    if return_segments:
        if return_pairs:
            for item1, item2 in loader:
                feat1, wav1 = item1
                feat2, wav2 = item2
                check_item(feat1, wav1)
                check_item(feat2, wav2)
                count_iter += 1
        else:
            for item1 in loader:
                feat1, wav1 = item1
                check_item(feat1, wav1)
                count_iter += 1
    else:
        for item in loader:
            feat, wav = item
            expected_feat_shape = (batch_size, ap.num_mels,
                                   (wav.shape[-1] // hop_len) + (conv_pad * 2))
            assert np.all(feat.shape == expected_feat_shape
                          ), f" [!] {feat.shape} vs {expected_feat_shape}"
            assert (feat.shape[2] - conv_pad * 2) * hop_len == wav.shape[2]
            count_iter += 1
            if count_iter == max_iter:
                break
def main(args):  # pylint: disable=redefined-outer-name
    """Train a GAN vocoder with per-network gradient scalers and keep the best checkpoint.

    Builds generator and discriminator with their optimizers, ``GradScaler``
    instances, optional LR schedulers and loss criteria from the module-level
    config ``c``, optionally restores a checkpoint, then alternates ``train``
    and ``evaluate`` for ``c.epochs`` epochs, calling ``save_best_model``
    after each one.

    Args:
        args: parsed command-line arguments. ``restore_path`` is read;
            ``restore_step`` is written (checkpoint step, or 0 when starting
            fresh).
    """
    # pylint: disable=global-variable-undefined
    global train_data, eval_data
    print(f" > Loading wavs from: {c.data_path}")
    if c.feature_path is not None:
        print(f" > Loading features from: {c.feature_path}")
        eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size)
    else:
        #eval_data, train_data = load_file_data(c.data_path, c.eval_split_size)
        eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)

    # setup audio processor
    ap = AudioProcessor(**c.audio)

    # DISTRIBUTED
    # if num_gpus > 1:
    #     init_distributed(args.rank, num_gpus, args.group_id,
    #                      c.distributed["backend"], c.distributed["url"])

    # setup models
    model_gen = setup_generator(c)
    model_disc = setup_discriminator(c)

    # setup optimizers
    optimizer_gen = RAdam(model_gen.parameters(), lr=c.lr_gen, weight_decay=0)
    optimizer_disc = RAdam(model_disc.parameters(), lr=c.lr_disc, weight_decay=0)

    scaler_G = GradScaler()
    scaler_D = GradScaler()

    # schedulers
    scheduler_gen = None
    scheduler_disc = None
    if 'lr_scheduler_gen' in c:
        scheduler_gen = getattr(torch.optim.lr_scheduler, c.lr_scheduler_gen)
        scheduler_gen = scheduler_gen(optimizer_gen, **c.lr_scheduler_gen_params)
    if 'lr_scheduler_disc' in c:
        scheduler_disc = getattr(torch.optim.lr_scheduler, c.lr_scheduler_disc)
        scheduler_disc = scheduler_disc(optimizer_disc, **c.lr_scheduler_disc_params)

    # setup criterion
    criterion_gen = GeneratorLoss(c)
    criterion_disc = DiscriminatorLoss(c)

    if args.restore_path:
        checkpoint = torch.load(args.restore_path, map_location='cpu')
        try:
            print(" > Restoring Generator Model...")
            model_gen.load_state_dict(checkpoint['model'])
            print(" > Restoring Generator Optimizer...")
            optimizer_gen.load_state_dict(checkpoint['optimizer'])
            print(" > Restoring Discriminator Model...")
            model_disc.load_state_dict(checkpoint['model_disc'])
            print(" > Restoring Discriminator Optimizer...")
            optimizer_disc.load_state_dict(checkpoint['optimizer_disc'])
            # A checkpoint may hold scheduler state although the current
            # config defines no scheduler; skip restoring in that case.
            if 'scheduler' in checkpoint and scheduler_gen is not None:
                print(" > Restoring Generator LR Scheduler...")
                scheduler_gen.load_state_dict(checkpoint['scheduler'])
                # NOTE: Not sure if necessary
                scheduler_gen.optimizer = optimizer_gen
            if 'scheduler_disc' in checkpoint and scheduler_disc is not None:
                print(" > Restoring Discriminator LR Scheduler...")
                scheduler_disc.load_state_dict(checkpoint['scheduler_disc'])
                scheduler_disc.optimizer = optimizer_disc
        except RuntimeError:
            # restore only matching layers.
            print(" > Partial model initialization...")
            model_dict = model_gen.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint['model'], c)
            model_gen.load_state_dict(model_dict)

            model_dict = model_disc.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint['model_disc'], c)
            model_disc.load_state_dict(model_dict)
            del model_dict

        # reset lr if not continuing training.
        for group in optimizer_gen.param_groups:
            group['lr'] = c.lr_gen

        for group in optimizer_disc.param_groups:
            group['lr'] = c.lr_disc

        print(" > Model restored from step %d" % checkpoint['step'],
              flush=True)
        args.restore_step = checkpoint['step']
    else:
        args.restore_step = 0

    if use_cuda:
        model_gen.cuda()
        criterion_gen.cuda()
        model_disc.cuda()
        criterion_disc.cuda()

    # DISTRIBUTED
    # if num_gpus > 1:
    #     model = apply_gradient_allreduce(model)

    num_params = count_parameters(model_gen)
    print(" > Generator has {} parameters".format(num_params), flush=True)
    num_params = count_parameters(model_disc)
    print(" > Discriminator has {} parameters".format(num_params), flush=True)

    # No best loss is restored in this script, so always start from +inf.
    best_loss = float('inf')

    global_step = args.restore_step
    for epoch in range(0, c.epochs):
        c_logger.print_epoch_start(epoch, c.epochs)
        _, global_step = train(model_gen, criterion_gen, optimizer_gen,
                               model_disc, criterion_disc, optimizer_disc,
                               scaler_G, scaler_D, scheduler_gen,
                               scheduler_disc, ap, global_step, epoch)
        eval_avg_loss_dict = evaluate(model_gen, criterion_gen, model_disc,
                                      criterion_disc, ap, global_step, epoch)
        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
        target_loss = eval_avg_loss_dict[c.target_loss]
        best_loss = save_best_model(target_loss,
                                    best_loss,
                                    model_gen,
                                    optimizer_gen,
                                    scheduler_gen,
                                    model_disc,
                                    optimizer_disc,
                                    scheduler_disc,
                                    global_step,
                                    epoch,
                                    OUT_PATH,
                                    model_losses=eval_avg_loss_dict)
def main(args):  # pylint: disable=redefined-outer-name
    """Train a GAN vocoder with a configurable optimizer and keep the best checkpoint.

    Everything is driven by the module-level config ``c``: data loading,
    generator / discriminator construction, the optimizer class named by
    ``c.optimizer``, optional LR schedulers, checkpoint restoring and the
    epoch loop of ``train`` / ``evaluate`` / ``save_best_model``.

    Args:
        args: parsed command-line arguments. ``restore_path``,
            ``continue_path``, ``best_path``, ``rank`` and ``group_id`` are
            read; ``restore_step`` is written (checkpoint step, or 0 when
            starting fresh).
    """
    # pylint: disable=global-variable-undefined
    global train_data, eval_data
    print(f" > Loading wavs from: {c.data_path}")
    if c.feature_path is None:
        eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)
    else:
        print(f" > Loading features from: {c.feature_path}")
        eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size)

    # setup audio processor
    ap = AudioProcessor(**c.audio.to_dict())

    # DISTRIBUTED
    if num_gpus > 1:
        init_distributed(
            args.rank,
            num_gpus,
            args.group_id,
            c.distributed["backend"],
            c.distributed["url"],
        )

    # setup models
    model_gen = setup_generator(c)
    model_disc = setup_discriminator(c)

    # setup criterion
    criterion_gen = GeneratorLoss(c)
    criterion_disc = DiscriminatorLoss(c)

    if use_cuda:
        model_gen.cuda()
        criterion_gen.cuda()
        model_disc.cuda()
        criterion_disc.cuda()

    # setup optimizers
    # TODO: allow loading custom optimizers
    optimizer_cls = getattr(torch.optim, c.optimizer)
    optimizer_gen = optimizer_cls(model_gen.parameters(), lr=c.lr_gen, **c.optimizer_params)
    if c.discriminator_model == "hifigan_discriminator":
        # this discriminator exposes its sub-networks as `msd` and `mpd`;
        # their parameters are optimized together.
        disc_params = itertools.chain(model_disc.msd.parameters(), model_disc.mpd.parameters())
    else:
        disc_params = model_disc.parameters()
    optimizer_disc = optimizer_cls(disc_params, lr=c.lr_disc, **c.optimizer_params)

    # schedulers
    scheduler_gen = None
    scheduler_disc = None
    if "lr_scheduler_gen" in c:
        gen_scheduler_cls = getattr(torch.optim.lr_scheduler, c.lr_scheduler_gen)
        scheduler_gen = gen_scheduler_cls(optimizer_gen, **c.lr_scheduler_gen_params)
    if "lr_scheduler_disc" in c:
        disc_scheduler_cls = getattr(torch.optim.lr_scheduler, c.lr_scheduler_disc)
        scheduler_disc = disc_scheduler_cls(optimizer_disc, **c.lr_scheduler_disc_params)

    if not args.restore_path:
        args.restore_step = 0
    else:
        print(f" > Restoring from {os.path.basename(args.restore_path)}...")
        checkpoint = torch.load(args.restore_path, map_location="cpu")
        continuing = args.continue_path != ""
        try:
            print(" > Restoring Generator Model...")
            model_gen.load_state_dict(checkpoint["model"])
            print(" > Restoring Generator Optimizer...")
            optimizer_gen.load_state_dict(checkpoint["optimizer"])
            print(" > Restoring Discriminator Model...")
            model_disc.load_state_dict(checkpoint["model_disc"])
            print(" > Restoring Discriminator Optimizer...")
            optimizer_disc.load_state_dict(checkpoint["optimizer_disc"])
            # restore schedulers if it is a continuing training.
            if continuing:
                if scheduler_gen is not None and "scheduler" in checkpoint:
                    print(" > Restoring Generator LR Scheduler...")
                    scheduler_gen.load_state_dict(checkpoint["scheduler"])
                    # NOTE: Not sure if necessary
                    scheduler_gen.optimizer = optimizer_gen
                if scheduler_disc is not None and "scheduler_disc" in checkpoint:
                    print(" > Restoring Discriminator LR Scheduler...")
                    scheduler_disc.load_state_dict(checkpoint["scheduler_disc"])
                    scheduler_disc.optimizer = optimizer_disc
                    if c.lr_scheduler_disc == "ExponentialLR":
                        scheduler_disc.last_epoch = checkpoint["epoch"]
        except RuntimeError:
            # restore only matching layers.
            print(" > Partial model initialization...")
            gen_state = set_init_dict(model_gen.state_dict(), checkpoint["model"], c)
            model_gen.load_state_dict(gen_state)

            disc_state = set_init_dict(model_disc.state_dict(), checkpoint["model_disc"], c)
            model_disc.load_state_dict(disc_state)
            del gen_state, disc_state

        # reset lr if not continuing training.
        if not continuing:
            for group in optimizer_gen.param_groups:
                group["lr"] = c.lr_gen

            for group in optimizer_disc.param_groups:
                group["lr"] = c.lr_disc

        print(f" > Model restored from step {checkpoint['step']:d}", flush=True)
        args.restore_step = checkpoint["step"]

    # DISTRIBUTED
    if num_gpus > 1:
        model_gen = DDP_th(model_gen, device_ids=[args.rank])
        model_disc = DDP_th(model_disc, device_ids=[args.rank])

    num_params = count_parameters(model_gen)
    print(" > Generator has {} parameters".format(num_params), flush=True)
    num_params = count_parameters(model_disc)
    print(" > Discriminator has {} parameters".format(num_params), flush=True)

    # a stored best loss is only meaningful when resuming from a checkpoint
    if args.restore_step != 0 and args.best_path:
        print(" > Restoring best loss from "
              f"{os.path.basename(args.best_path)} ...")
        best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"]
        print(f" > Starting with best loss of {best_loss}.")
    else:
        best_loss = float("inf")
        print(" > Starting with inf best loss.")
    keep_all_best = c.get("keep_all_best", False)
    keep_after = c.get("keep_after", 10000)  # void if keep_all_best False

    global_step = args.restore_step
    for epoch in range(c.epochs):
        c_logger.print_epoch_start(epoch, c.epochs)
        _, global_step = train(
            model_gen,
            criterion_gen,
            optimizer_gen,
            model_disc,
            criterion_disc,
            optimizer_disc,
            scheduler_gen,
            scheduler_disc,
            ap,
            global_step,
            epoch,
        )
        eval_avg_loss_dict = evaluate(
            model_gen, criterion_gen, model_disc, criterion_disc, ap, global_step, epoch
        )
        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
        best_loss = save_best_model(
            eval_avg_loss_dict[c.target_loss],
            best_loss,
            model_gen,
            optimizer_gen,
            scheduler_gen,
            model_disc,
            optimizer_disc,
            scheduler_disc,
            global_step,
            epoch,
            OUT_PATH,
            keep_all_best=keep_all_best,
            keep_after=keep_after,
            model_losses=eval_avg_loss_dict,
        )