from pathlib import Path

import mlflow
import torch
from hydra.utils import to_absolute_path
from tqdm import tqdm

# NOTE: train_step, eval_spss_model, save_checkpoint, and
# compute_batch_pitch_regularization_weight are assumed to be provided by the
# surrounding package.


# CycleGAN-style training loop: two generators (A2B/B2A), two discriminators,
# cycle-consistency and identity losses.
def train_loop(
    config,
    logger,
    device,
    netG_A2B,
    netG_B2A,
    optG,
    schedulerG,
    netD_A,
    netD_B,
    optD,
    schedulerD,
    grad_scaler,
    data_loaders,
    writer,
    out_scaler,
    use_mlflow,
):
    out_dir = Path(to_absolute_path(config.train.out_dir))
    best_dev_loss = torch.finfo(torch.float32).max
    last_dev_loss = torch.finfo(torch.float32).max

    adv_streams = config.train.adv_streams
    if len(adv_streams) != len(config.model.stream_sizes):
        raise ValueError("adv_streams must be specified for all streams")

    for epoch in tqdm(range(1, config.train.nepochs + 1)):
        for phase in data_loaders.keys():
            train = phase.startswith("train")
            running_loss = 0
            running_metrics = {}
            evaluated = False
            for in_feats, out_feats, lengths in data_loaders[phase]:
                # NOTE: This is needed for pytorch's PackedSequence
                lengths, indices = torch.sort(lengths, dim=0, descending=True)
                in_feats, out_feats = (
                    in_feats[indices].to(device),
                    out_feats[indices].to(device),
                )

                if (not train) and (not evaluated):
                    eval_spss_model(
                        epoch,
                        netG_A2B,
                        in_feats,
                        out_feats,
                        lengths,
                        config.model,
                        out_scaler,
                        writer,
                        sr=config.data.sample_rate,
                    )
                    evaluated = True

                # Disable the identity mapping loss after id_loss_until epochs
                if (
                    config.train.id_loss_until > 0
                    and epoch > config.train.id_loss_until
                ):
                    use_id_loss = False
                else:
                    use_id_loss = True

                loss, log_metrics = train_step(
                    model_config=config.model,
                    optim_config=config.train.optim,
                    netG_A2B=netG_A2B,
                    netG_B2A=netG_B2A,
                    optG=optG,
                    netD_A=netD_A,
                    netD_B=netD_B,
                    optD=optD,
                    grad_scaler=grad_scaler,
                    train=train,
                    in_feats=in_feats,
                    out_feats=out_feats,
                    lengths=lengths,
                    out_scaler=out_scaler,
                    adv_weight=config.train.adv_weight,
                    adv_streams=adv_streams,
                    fm_weight=config.train.fm_weight,
                    mask_nth_mgc_for_adv_loss=config.train.mask_nth_mgc_for_adv_loss,
                    gan_type=config.train.gan_type,
                    vuv_mask=config.train.vuv_mask,
                    cycle_weight=config.train.cycle_weight,
                    id_weight=config.train.id_weight,
                    use_id_loss=use_id_loss,
                )
                running_loss += loss.item()
                for k, v in log_metrics.items():
                    try:
                        running_metrics[k] += float(v)
                    except KeyError:
                        running_metrics[k] = float(v)

            ave_loss = running_loss / len(data_loaders[phase])
            logger.info("[%s] [Epoch %s]: loss %s", phase, epoch, ave_loss)

            for k, v in running_metrics.items():
                ave_v = v / len(data_loaders[phase])
                if writer is not None:
                    writer.add_scalar(f"{k}/{phase}", ave_v, epoch)
                if use_mlflow:
                    mlflow.log_metric(f"{phase}_{k}", ave_v, step=epoch)

            if not train:
                last_dev_loss = ave_loss
            if not train and ave_loss < best_dev_loss:
                best_dev_loss = ave_loss
                for model, opt, scheduler, postfix in [
                    (netG_A2B, optG, schedulerG, "_A2B"),
                    (netG_B2A, optG, schedulerG, "_B2A"),
                    (netD_A, optD, schedulerD, "_D_A"),
                    (netD_B, optD, schedulerD, "_D_B"),
                ]:
                    save_checkpoint(
                        logger,
                        out_dir,
                        model,
                        opt,
                        scheduler,
                        epoch,
                        is_best=True,
                        postfix=postfix,
                    )

        schedulerG.step()
        schedulerD.step()

        if epoch % config.train.checkpoint_epoch_interval == 0:
            for model, opt, scheduler, postfix in [
                (netG_A2B, optG, schedulerG, "_A2B"),
                (netG_B2A, optG, schedulerG, "_B2A"),
                (netD_A, optD, schedulerD, "_D_A"),
                (netD_B, optD, schedulerD, "_D_B"),
            ]:
                save_checkpoint(
                    logger,
                    out_dir,
                    model,
                    opt,
                    scheduler,
                    epoch,
                    is_best=False,
                    postfix=postfix,
                )

    for model, opt, scheduler, postfix in [
        (netG_A2B, optG, schedulerG, "_A2B"),
        (netG_B2A, optG, schedulerG, "_B2A"),
        (netD_A, optD, schedulerD, "_D_A"),
        (netD_B, optD, schedulerD, "_D_B"),
    ]:
        save_checkpoint(
            logger,
            out_dir,
            model,
            opt,
            scheduler,
            config.train.nepochs,
            postfix=postfix,
        )

    logger.info("The best loss was %s", best_dev_loss)
    if use_mlflow:
        mlflow.log_metric("best_dev_loss", best_dev_loss, step=epoch)
        mlflow.log_artifacts(out_dir)

    return last_dev_loss
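
# Illustrative sketch (not part of the training loops; the helper name is made
# up): the batches above are re-sorted because
# torch.nn.utils.rnn.pack_padded_sequence expects lengths in descending order
# unless enforce_sorted=False is passed. Sorting once per batch and permuting
# the tensors to match keeps the default fast path. The shapes below are
# invented for the example.
def _demo_sorted_packing():
    from torch.nn.utils.rnn import pack_padded_sequence

    feats = torch.randn(3, 10, 4)  # (batch, max_frames, feat_dim)
    lengths = torch.tensor([5, 10, 7])

    # Same idiom as the inner loops above: sort lengths in descending
    # order, then apply the identical permutation to the batch.
    lengths, indices = torch.sort(lengths, dim=0, descending=True)
    feats = feats[indices]

    return pack_padded_sequence(feats, lengths, batch_first=True)
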
# GAN-based training loop for a single generator/discriminator pair with a
# time-variant pitch regularization term.
def train_loop(
    config,
    logger,
    device,
    netG,
    optG,
    schedulerG,
    netD,
    optD,
    schedulerD,
    grad_scaler,
    data_loaders,
    writer,
    in_scaler,
    out_scaler,
    use_mlflow,
):
    out_dir = Path(to_absolute_path(config.train.out_dir))
    best_dev_loss = torch.finfo(torch.float32).max
    last_dev_loss = torch.finfo(torch.float32).max

    in_lf0_idx = config.data.in_lf0_idx
    in_rest_idx = config.data.in_rest_idx
    if in_lf0_idx is None or in_rest_idx is None:
        raise ValueError("in_lf0_idx and in_rest_idx must be specified")

    pitch_reg_weight = config.train.pitch_reg_weight
    fm_weight = config.train.fm_weight
    adv_streams = config.train.adv_streams
    if len(adv_streams) != len(config.model.stream_sizes):
        raise ValueError("adv_streams must be specified for all streams")

    if "sample_rate" not in config.data:
        logger.warning(
            "sample_rate is not found in the data config. Fallback to 48000."
        )
        sr = 48000
    else:
        sr = config.data.sample_rate

    if "feats_criterion" not in config.train:
        logger.warning(
            "feats_criterion is not found in the train config. Fallback to MSE."
        )
        feats_criterion = "mse"
    else:
        feats_criterion = config.train.feats_criterion

    for epoch in tqdm(range(1, config.train.nepochs + 1)):
        for phase in data_loaders.keys():
            train = phase.startswith("train")
            running_loss = 0
            running_metrics = {}
            evaluated = False
            for in_feats, out_feats, lengths in data_loaders[phase]:
                # NOTE: This is needed for pytorch's PackedSequence
                lengths, indices = torch.sort(lengths, dim=0, descending=True)
                in_feats, out_feats = (
                    in_feats[indices].to(device),
                    out_feats[indices].to(device),
                )

                if (not train) and (not evaluated):
                    eval_spss_model(
                        epoch,
                        netG,
                        in_feats,
                        out_feats,
                        lengths,
                        config.model,
                        out_scaler,
                        writer,
                        sr=sr,
                    )
                    evaluated = True

                # Compute denormalized log-F0 in the musical scores
                lf0_score_denorm = (
                    in_feats[:, :, in_lf0_idx] - in_scaler.min_[in_lf0_idx]
                ) / in_scaler.scale_[in_lf0_idx]
                # Fill zeros for rest and padded frames
                lf0_score_denorm *= (in_feats[:, :, in_rest_idx] <= 0).float()
                for idx, length in enumerate(lengths):
                    lf0_score_denorm[idx, length:] = 0
                # Compute time-variant pitch regularization weight vector
                pitch_reg_dyn_ws = compute_batch_pitch_regularization_weight(
                    lf0_score_denorm
                )

                loss, log_metrics = train_step(
                    model_config=config.model,
                    optim_config=config.train.optim,
                    netG=netG,
                    optG=optG,
                    netD=netD,
                    optD=optD,
                    grad_scaler=grad_scaler,
                    train=train,
                    in_feats=in_feats,
                    out_feats=out_feats,
                    lengths=lengths,
                    out_scaler=out_scaler,
                    feats_criterion=feats_criterion,
                    pitch_reg_dyn_ws=pitch_reg_dyn_ws,
                    pitch_reg_weight=pitch_reg_weight,
                    adv_weight=config.train.adv_weight,
                    adv_streams=adv_streams,
                    fm_weight=fm_weight,
                    adv_use_static_feats_only=config.train.adv_use_static_feats_only,
                    mask_nth_mgc_for_adv_loss=config.train.mask_nth_mgc_for_adv_loss,
                    gan_type=config.train.gan_type,
                )
                running_loss += loss.item()
                for k, v in log_metrics.items():
                    try:
                        running_metrics[k] += float(v)
                    except KeyError:
                        running_metrics[k] = float(v)

            ave_loss = running_loss / len(data_loaders[phase])
            logger.info("[%s] [Epoch %s]: loss %s", phase, epoch, ave_loss)

            for k, v in running_metrics.items():
                ave_v = v / len(data_loaders[phase])
                if writer is not None:
                    writer.add_scalar(f"{k}/{phase}", ave_v, epoch)
                if use_mlflow:
                    mlflow.log_metric(f"{phase}_{k}", ave_v, step=epoch)

            if not train:
                last_dev_loss = ave_loss
            if not train and ave_loss < best_dev_loss:
                best_dev_loss = ave_loss
                for model, opt, scheduler, postfix in [
                    (netG, optG, schedulerG, ""),
                    (netD, optD, schedulerD, "_D"),
                ]:
                    save_checkpoint(
                        logger,
                        out_dir,
                        model,
                        opt,
                        scheduler,
                        epoch,
                        is_best=True,
                        postfix=postfix,
                    )

        schedulerG.step()
        schedulerD.step()

        if epoch % config.train.checkpoint_epoch_interval == 0:
            for model, opt, scheduler, postfix in [
                (netG, optG, schedulerG, ""),
                (netD, optD, schedulerD, "_D"),
            ]:
                save_checkpoint(
                    logger,
                    out_dir,
                    model,
                    opt,
                    scheduler,
                    epoch,
                    is_best=False,
                    postfix=postfix,
                )

    for model, opt, scheduler, postfix in [
        (netG, optG, schedulerG, ""),
        (netD, optD, schedulerD, "_D"),
    ]:
        save_checkpoint(
            logger,
            out_dir,
            model,
            opt,
            scheduler,
            config.train.nepochs,
            postfix=postfix,
        )

    logger.info("The best loss was %s", best_dev_loss)
    if use_mlflow:
        mlflow.log_metric("best_dev_loss", best_dev_loss, step=epoch)
        mlflow.log_artifacts(out_dir)

    return last_dev_loss
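
# Illustrative sketch (the helper name is made up): lf0_score_denorm above
# undoes a MinMaxScaler-style normalization. Assuming in_scaler exposes
# sklearn-compatible min_/scale_ attributes, the forward transform is
# x_scaled = x * scale_ + min_, so the inverse is
# x = (x_scaled - min_) / scale_, which is exactly the expression in the loop.
# The numbers below are invented for the example.
def _demo_minmax_inverse():
    import numpy as np
    from sklearn.preprocessing import MinMaxScaler

    x = np.array([[4.0], [5.0], [6.0]])  # made-up log-F0 values
    scaler = MinMaxScaler().fit(x)
    x_scaled = scaler.transform(x)

    # Same inverse expression as lf0_score_denorm above
    x_denorm = (x_scaled - scaler.min_) / scaler.scale_
    assert np.allclose(x_denorm, x)
    return x_denorm
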
# Plain (non-adversarial) training loop for a single acoustic model.
def train_loop(
    config,
    logger,
    device,
    model,
    optimizer,
    lr_scheduler,
    grad_scaler,
    data_loaders,
    writer,
    out_scaler,
    use_mlflow,
):
    out_dir = Path(to_absolute_path(config.train.out_dir))
    best_dev_loss = torch.finfo(torch.float32).max
    last_dev_loss = torch.finfo(torch.float32).max

    if "feats_criterion" not in config.train:
        logger.warning(
            "feats_criterion is not found in the train config. Fallback to MSE."
        )
        feats_criterion = "mse"
    else:
        feats_criterion = config.train.feats_criterion

    for epoch in tqdm(range(1, config.train.nepochs + 1)):
        for phase in data_loaders.keys():
            train = phase.startswith("train")
            running_loss = 0
            running_metrics = {}
            for in_feats, out_feats, lengths in data_loaders[phase]:
                # NOTE: This is needed for pytorch's PackedSequence
                lengths, indices = torch.sort(lengths, dim=0, descending=True)
                in_feats, out_feats = (
                    in_feats[indices].to(device),
                    out_feats[indices].to(device),
                )

                loss, distortions = train_step(
                    model=model,
                    optimizer=optimizer,
                    grad_scaler=grad_scaler,
                    train=train,
                    in_feats=in_feats,
                    out_feats=out_feats,
                    lengths=lengths,
                    out_scaler=out_scaler,
                    feats_criterion=feats_criterion,
                    stream_wise_loss=config.train.stream_wise_loss,
                    stream_weights=config.model.stream_weights,
                    stream_sizes=config.model.stream_sizes,
                )
                running_loss += loss.item()
                for k, v in distortions.items():
                    try:
                        running_metrics[k] += float(v)
                    except KeyError:
                        running_metrics[k] = float(v)

            ave_loss = running_loss / len(data_loaders[phase])
            logger.info("[%s] [Epoch %s]: loss %s", phase, epoch, ave_loss)
            if writer is not None:
                writer.add_scalar(f"Loss/{phase}", ave_loss, epoch)
            if use_mlflow:
                mlflow.log_metric(f"{phase}_loss", ave_loss, step=epoch)

            for k, v in running_metrics.items():
                ave_v = v / len(data_loaders[phase])
                if writer is not None:
                    writer.add_scalar(f"{k}/{phase}", ave_v, epoch)
                if use_mlflow:
                    mlflow.log_metric(f"{phase}_{k}", ave_v, step=epoch)

            if not train:
                last_dev_loss = ave_loss
            if not train and ave_loss < best_dev_loss:
                best_dev_loss = ave_loss
                save_checkpoint(
                    logger, out_dir, model, optimizer, lr_scheduler, epoch, is_best=True
                )

        lr_scheduler.step()

        if epoch % config.train.checkpoint_epoch_interval == 0:
            save_checkpoint(
                logger, out_dir, model, optimizer, lr_scheduler, epoch, is_best=False
            )

    save_checkpoint(
        logger, out_dir, model, optimizer, lr_scheduler, config.train.nepochs
    )

    logger.info("The best loss was %s", best_dev_loss)
    if use_mlflow:
        mlflow.log_metric("best_dev_loss", best_dev_loss, step=epoch)
        mlflow.log_artifacts(out_dir)

    return last_dev_loss
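
# Illustrative sketch (the helper name is made up): the try/except accumulation
# of running_metrics used in all three loops can equivalently be written with
# dict.get and a default; the two forms behave identically, so keeping the
# try/except version is purely a style choice.
def _demo_accumulate_metrics(running_metrics, log_metrics):
    for k, v in log_metrics.items():
        running_metrics[k] = running_metrics.get(k, 0.0) + float(v)
    return running_metrics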