from functools import partial

import mlflow
import numpy as np
import torch
from omegaconf import DictConfig

# NOTE: project-local helpers (get_static_stream_sizes, collate_fn_default,
# collate_fn_random_segments, setup_gan, setup, check_resf0_config,
# PyTorchStandardScaler, save_configs, log_params_from_omegaconf_dict,
# train_loop) are assumed to be imported from the surrounding package.


def my_app(config: DictConfig) -> float:
    # NOTE: set discriminator's in_dim automatically
    if config.model.netD.in_dim is None:
        if config.train.adv_use_static_feats_only:
            # Count only the static part of each feature stream
            stream_sizes = get_static_stream_sizes(
                config.model.stream_sizes,
                config.model.has_dynamic_features,
                config.model.num_windows,
            )
        else:
            stream_sizes = np.asarray(config.model.stream_sizes)
        # Sum the sizes of the streams selected for the adversarial loss
        D_in_dim = int((stream_sizes * np.asarray(config.train.adv_streams)).sum())
        if config.train.mask_nth_mgc_for_adv_loss > 0:
            # Masked mgc coefficients are excluded from the discriminator input
            D_in_dim -= config.train.mask_nth_mgc_for_adv_loss
        config.model.netD.in_dim = D_in_dim

    # Train on random fixed-length segments if max_time_frames is configured
    if "max_time_frames" in config.data and config.data.max_time_frames > 0:
        collate_fn = partial(
            collate_fn_random_segments, max_time_frames=config.data.max_time_frames
        )
    else:
        collate_fn = collate_fn_default

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    (
        (netG, optG, schedulerG),
        (netD, optD, schedulerD),
        grad_scaler,
        data_loaders,
        writer,
        logger,
        in_scaler,
        out_scaler,
    ) = setup_gan(config, device, collate_fn)

    check_resf0_config(logger, netG, config, in_scaler, out_scaler)

    # Wrap the output scaler as a torch module so it can run on the device
    out_scaler = PyTorchStandardScaler(
        torch.from_numpy(out_scaler.mean_), torch.from_numpy(out_scaler.scale_)
    ).to(device)
    use_mlflow = config.mlflow.enabled

    if use_mlflow:
        with mlflow.start_run() as run:
            # NOTE: modify out_dir when running with mlflow
            config.train.out_dir = f"{config.train.out_dir}/{run.info.run_id}"
            save_configs(config)
            log_params_from_omegaconf_dict(config)
            last_dev_loss = train_loop(
                config,
                logger,
                device,
                netG,
                optG,
                schedulerG,
                netD,
                optD,
                schedulerD,
                grad_scaler,
                data_loaders,
                writer,
                in_scaler,
                out_scaler,
                use_mlflow,
            )
    else:
        save_configs(config)
        last_dev_loss = train_loop(
            config,
            logger,
            device,
            netG,
            optG,
            schedulerG,
            netD,
            optD,
            schedulerD,
            grad_scaler,
            data_loaders,
            writer,
            in_scaler,
            out_scaler,
            use_mlflow,
        )

    return last_dev_loss
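# A minimal numeric sketch of the in_dim computation above. The stream layout
# and adv_streams values here are illustrative assumptions, not project
# defaults; the function is never called and exists only as documentation.
def _demo_discriminator_in_dim() -> int:
    stream_sizes = np.asarray([60, 1, 1, 5])  # e.g. mgc/lf0/vuv/bap (hypothetical)
    adv_streams = np.asarray([1, 0, 0, 0])  # adversarial loss on mgc only
    in_dim = int((stream_sizes * adv_streams).sum())  # 60
    mask_nth_mgc_for_adv_loss = 2  # exclude the first two mgc coefficients
    in_dim -= mask_nth_mgc_for_adv_loss  # 58
    return in_dim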
# NOTE: non-adversarial counterpart of the GAN entry point above; since the two
# share a name, they would live in separate training scripts.
def my_app(config: DictConfig) -> float:
    # Train on random fixed-length segments if max_time_frames is configured
    if "max_time_frames" in config.data and config.data.max_time_frames > 0:
        collate_fn = partial(
            collate_fn_random_segments, max_time_frames=config.data.max_time_frames
        )
    elif "reduction_factor" in config.model.netG:
        # Respect the generator's reduction factor when batching, if it defines one
        collate_fn = partial(
            collate_fn_default,
            reduction_factor=config.model.netG.reduction_factor,
        )
    else:
        collate_fn = collate_fn_default

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    (
        model,
        optimizer,
        lr_scheduler,
        grad_scaler,
        data_loaders,
        writer,
        logger,
        in_scaler,
        out_scaler,
    ) = setup(config, device, collate_fn)
    check_resf0_config(logger, model, config, in_scaler, out_scaler)

    # Wrap the output scaler as a torch module so it can run on the device
    out_scaler = PyTorchStandardScaler(
        torch.from_numpy(out_scaler.mean_), torch.from_numpy(out_scaler.scale_)
    ).to(device)
    use_mlflow = config.mlflow.enabled

    if use_mlflow:
        with mlflow.start_run() as run:
            # NOTE: modify out_dir when running with mlflow
            config.train.out_dir = f"{config.train.out_dir}/{run.info.run_id}"
            save_configs(config)
            log_params_from_omegaconf_dict(config)
            last_dev_loss = train_loop(
                config,
                logger,
                device,
                model,
                optimizer,
                lr_scheduler,
                grad_scaler,
                data_loaders,
                writer,
                in_scaler,
                out_scaler,
                use_mlflow,
            )
    else:
        save_configs(config)
        last_dev_loss = train_loop(
            config,
            logger,
            device,
            model,
            optimizer,
            lr_scheduler,
            grad_scaler,
            data_loaders,
            writer,
            in_scaler,
            out_scaler,
            use_mlflow,
        )

    return last_dev_loss
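# A minimal sketch of how an entry point like this is typically wired up with
# Hydra (which supplies the DictConfig). The config_path/config_name values
# below are assumptions about the project layout, not confirmed paths.
import hydra


@hydra.main(config_path="conf/train", config_name="config")
def entry(config: DictConfig) -> None:
    my_app(config)


if __name__ == "__main__":
    entry()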