def load_frameowrk(
        seed, disable_debugging_API, num_workers, config_path, checkpoint_folder, reduce_train_dataset,
        standing_statistics, standing_step, freeze_layers, load_current, eval_type, dataset_name,
        num_classes, img_size, data_path, architecture, conditional_strategy, hypersphere_dim,
        nonlinear_embed, normalize_embed, g_spectral_norm, d_spectral_norm, activation_fn, attention,
        attention_after_nth_gen_block, attention_after_nth_dis_block, z_dim, shared_dim, g_conv_dim,
        d_conv_dim, G_depth, D_depth, optimizer, batch_size, d_lr, g_lr, momentum, nesterov, alpha,
        beta1, beta2, total_step, adv_loss, cr, g_init, d_init, random_flip_preprocessing, prior,
        truncated_factor, ema, ema_decay, ema_start, synchronized_bn, mixed_precision, hdf5_path_train,
        train_config, model_config, **_):
    """Build datasets, models and optimizers, then run the phases enabled in ``train_config``.

    Steps, in order:

    1. Configure cudnn (``seed == 0`` selects benchmark mode, any other value
       is passed to ``fix_all_seed`` and cudnn is made deterministic).
    2. Create the run name, logger and TensorBoard writer.
    3. Load the train/eval datasets with ``LoadDataset``; the train set is
       subsampled with ``random_split`` when ``reduce_train_dataset < 1.0``.
    4. Import ``models.<architecture>`` and instantiate its ``Generator`` and
       ``Discriminator`` (plus a second generator used for EMA when ``ema``).
    5. Build the optimizers ("SGD", "RMSprop", "Adam" or "AdaBelief").
    6. Restore generator/discriminator (and EMA generator) state from
       ``checkpoint_folder`` when it is not ``None``.
    7. Wrap the models in ``DataParallel`` when more than one GPU is visible.
    8. When ``train_config['eval']`` is set, build ``InceptionV3`` and call
       ``prepare_inception_moments``.
    9. Construct ``Train_Eval`` and call the methods selected by the
       ``train``, ``eval``, ``save_images``, ``image_visualization``,
       ``k_nearest_neighbor``, ``interpolation`` and ``frequency_analysis``
       keys of ``train_config``.

    Args:
        seed: ``0`` enables cudnn benchmark mode; otherwise forwarded to
            ``fix_all_seed``. Must equal the checkpoint's seed when resuming
            training.
        config_path: '/'-separated path; its last component, minus the final
            five characters, is used as the framework name of the run.
        checkpoint_folder: ``None`` to start fresh, otherwise an existing
            folder to resume from.
        load_current: ``True`` loads the "current" checkpoint files, anything
            else loads the "best" ones.
        freeze_layers: when greater than ``-1`` the counters restored from the
            checkpoint (step, best step, best FID, ADA probability) are reset.
        standing_statistics, standing_step: forwarded to the evaluation and
            analysis calls; ``standing_step`` is replaced by ``batch_size``
            unless ``standing_statistics is True``.
        train_config, model_config: nested configuration mappings; only the
            keys read in the body are required.
        **_: extra keyword arguments are accepted and ignored.

        The remaining arguments are forwarded unchanged to ``LoadDataset``,
        the model constructors, the optimizers or ``Train_Eval``.

    Raises:
        AssertionError: if ``batch_size`` is not divisible by the GPU count,
            if "dcgan" is requested with ``img_size != 32``, if the seed
            differs from the checkpoint's seed while training, or if
            interpolation is requested for an unsupported architecture.
        NotImplementedError: for an unknown ``optimizer`` name.
        NotADirectoryError: if ``checkpoint_folder`` does not exist.
        IndexError: if no matching checkpoint file is found by ``glob``.
    """
    if seed == 0:
        cudnn.benchmark = True
        cudnn.deterministic = False
    else:
        fix_all_seed(seed)
        cudnn.benchmark = False
        cudnn.deterministic = True

    if disable_debugging_API:
        torch.autograd.set_detect_anomaly(False)

    n_gpus = torch.cuda.device_count()
    default_device = torch.cuda.current_device()

    check_flag_0(batch_size, n_gpus, standing_statistics, ema, freeze_layers, checkpoint_folder)
    assert batch_size % n_gpus == 0, "batch_size should be divided by the number of gpus "

    if n_gpus == 1:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    prev_ada_p, step, best_step, best_fid, best_fid_checkpoint_path = None, 0, 0, None, None
    standing_step = standing_step if standing_statistics is True else batch_size

    run_name = make_run_name(RUN_NAME_FORMAT, framework=config_path.split('/')[-1][:-5], phase='train')

    logger = make_logger(run_name, None)
    writer = SummaryWriter(log_dir=join('./logs', run_name))
    logger.info('Run name : {run_name}'.format(run_name=run_name))
    logger.info(train_config)
    logger.info(model_config)

    ##### load dataset #####
    logger.info('Loading train datasets...')
    train_dataset = LoadDataset(dataset_name, data_path, train=True, download=True, resize_size=img_size,
                                hdf5_path=hdf5_path_train, random_flip=random_flip_preprocessing)
    if reduce_train_dataset < 1.0:
        num_train = int(reduce_train_dataset * len(train_dataset))
        train_dataset, _ = torch.utils.data.random_split(
            train_dataset, [num_train, len(train_dataset) - num_train])
    logger.info('Train dataset size : {dataset_size}'.format(dataset_size=len(train_dataset)))

    logger.info('Loading {mode} datasets...'.format(mode=eval_type))
    eval_mode = eval_type == 'train'
    eval_dataset = LoadDataset(dataset_name, data_path, train=eval_mode, download=True, resize_size=img_size,
                               hdf5_path=None, random_flip=False)
    logger.info('Eval dataset size : {dataset_size}'.format(dataset_size=len(eval_dataset)))

    ##### build model #####
    logger.info('Building model...')
    if architecture == "dcgan":
        assert img_size == 32, "Sry, StudioGAN does not support dcgan models for generation of images larger than 32 resolution."
    module = __import__('models.{architecture}'.format(architecture=architecture), fromlist=['something'])
    logger.info('Modules are located on models.{architecture}'.format(architecture=architecture))

    Gen = module.Generator(z_dim, shared_dim, img_size, g_conv_dim, g_spectral_norm, attention,
                           attention_after_nth_gen_block, activation_fn, conditional_strategy, num_classes,
                           g_init, G_depth, mixed_precision).to(default_device)

    Dis = module.Discriminator(img_size, d_conv_dim, d_spectral_norm, attention, attention_after_nth_dis_block,
                               activation_fn, conditional_strategy, hypersphere_dim, num_classes,
                               nonlinear_embed, normalize_embed, d_init, D_depth,
                               mixed_precision).to(default_device)

    if ema:
        print('Preparing EMA for G with decay of {}'.format(ema_decay))
        Gen_copy = module.Generator(z_dim, shared_dim, img_size, g_conv_dim, g_spectral_norm, attention,
                                    attention_after_nth_gen_block, activation_fn, conditional_strategy,
                                    num_classes, initialize=False, G_depth=G_depth,
                                    mixed_precision=mixed_precision).to(default_device)
        # `ema_` is the helper; the bare name `ema` is this function's flag.
        Gen_ema = ema_(Gen, Gen_copy, ema_decay, ema_start)
    else:
        Gen_copy, Gen_ema = None, None

    logger.info(count_parameters(Gen))
    logger.info(Gen)

    logger.info(count_parameters(Dis))
    logger.info(Dis)

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True,
                                  num_workers=num_workers, drop_last=True)
    eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=True, pin_memory=True,
                                 num_workers=num_workers, drop_last=False)

    ### define loss functions and optimizers
    G_loss = {'vanilla': loss_dcgan_gen, 'least_square': loss_lsgan_gen,
              'hinge': loss_hinge_gen, 'wasserstein': loss_wgan_gen}
    D_loss = {'vanilla': loss_dcgan_dis, 'least_square': loss_lsgan_dis,
              'hinge': loss_hinge_dis, 'wasserstein': loss_wgan_dis}

    # Only parameters that require gradients are handed to the optimizers.
    g_params = filter(lambda p: p.requires_grad, Gen.parameters())
    d_params = filter(lambda p: p.requires_grad, Dis.parameters())

    if optimizer == "SGD":
        G_optimizer = torch.optim.SGD(g_params, g_lr, momentum=momentum, nesterov=nesterov)
        D_optimizer = torch.optim.SGD(d_params, d_lr, momentum=momentum, nesterov=nesterov)
    elif optimizer == "RMSprop":
        G_optimizer = torch.optim.RMSprop(g_params, g_lr, momentum=momentum, alpha=alpha)
        D_optimizer = torch.optim.RMSprop(d_params, d_lr, momentum=momentum, alpha=alpha)
    elif optimizer == "Adam":
        G_optimizer = torch.optim.Adam(g_params, g_lr, [beta1, beta2], eps=1e-6)
        D_optimizer = torch.optim.Adam(d_params, d_lr, [beta1, beta2], eps=1e-6)
    elif optimizer == "AdaBelief":
        G_optimizer = AdaBelief(g_params, g_lr, [beta1, beta2], eps=1e-12, rectify=False)
        D_optimizer = AdaBelief(d_params, d_lr, [beta1, beta2], eps=1e-12, rectify=False)
    else:
        raise NotImplementedError

    ##### load checkpoints if needed #####
    if checkpoint_folder is not None:
        when = "current" if load_current is True else "best"
        if not exists(abspath(checkpoint_folder)):
            raise NotADirectoryError
        checkpoint_dir = make_checkpoint_dir(checkpoint_folder, run_name)
        # The first glob match is used; an empty match list raises IndexError.
        g_checkpoint_dir = glob.glob(
            join(checkpoint_dir, "model=G-{when}-weights-step*.pth".format(when=when)))[0]
        d_checkpoint_dir = glob.glob(
            join(checkpoint_dir, "model=D-{when}-weights-step*.pth".format(when=when)))[0]
        # Both loads overwrite run_name/step/prev_ada_p; the discriminator's values win.
        Gen, G_optimizer, trained_seed, run_name, step, prev_ada_p = load_checkpoint(
            Gen, G_optimizer, g_checkpoint_dir)
        Dis, D_optimizer, trained_seed, run_name, step, prev_ada_p, best_step, best_fid, best_fid_checkpoint_path =\
            load_checkpoint(Dis, D_optimizer, d_checkpoint_dir, metric=True)
        logger = make_logger(run_name, None)
        if ema:
            g_ema_checkpoint_dir = glob.glob(
                join(checkpoint_dir, "model=G_ema-{when}-weights-step*.pth".format(when=when)))[0]
            Gen_copy = load_checkpoint(Gen_copy, None, g_ema_checkpoint_dir, ema=True)
            Gen_ema.source, Gen_ema.target = Gen, Gen_copy

        # The writer opened earlier targets the freshly generated run name.
        # Close it before switching to the checkpoint's run name so that its
        # event file and background writer are not leaked.
        writer.close()
        writer = SummaryWriter(log_dir=join('./logs', run_name))
        if train_config['train']:
            assert seed == trained_seed, "seed for sampling random numbers should be same!"
        logger.info('Generator checkpoint is {}'.format(g_checkpoint_dir))
        logger.info('Discriminator checkpoint is {}'.format(d_checkpoint_dir))
        if freeze_layers > -1:
            # Restart the counters while keeping the restored weights.
            prev_ada_p, step, best_step, best_fid, best_fid_checkpoint_path = None, 0, 0, None, None
    else:
        checkpoint_dir = make_checkpoint_dir(checkpoint_folder, run_name)

    ##### wrap models with DP and convert BN to Sync BN #####
    if n_gpus > 1:
        Gen = DataParallel(Gen, output_device=default_device)
        Dis = DataParallel(Dis, output_device=default_device)
        if ema:
            Gen_copy = DataParallel(Gen_copy, output_device=default_device)

        if synchronized_bn:
            Gen = convert_model(Gen).to(default_device)
            Dis = convert_model(Dis).to(default_device)
            if ema:
                Gen_copy = convert_model(Gen_copy).to(default_device)

    ##### load the inception network and prepare first/second moments for calculating FID #####
    if train_config['eval']:
        inception_model = InceptionV3().to(default_device)
        if n_gpus > 1:
            inception_model = DataParallel(inception_model, output_device=default_device)
        mu, sigma = prepare_inception_moments(dataloader=eval_dataloader,
                                              generator=Gen,
                                              eval_mode=eval_type,
                                              inception_model=inception_model,
                                              splits=1,
                                              run_name=run_name,
                                              logger=logger,
                                              device=default_device)
    else:
        mu, sigma, inception_model = None, None, None

    optim_cfg = model_config['optimization']
    loss_cfg = model_config['loss_function']
    sampling_cfg = model_config['training_and_sampling_setting']

    train_eval = Train_Eval(
        run_name=run_name,
        best_step=best_step,
        dataset_name=dataset_name,
        eval_type=eval_type,
        logger=logger,
        writer=writer,
        n_gpus=n_gpus,
        gen_model=Gen,
        dis_model=Dis,
        inception_model=inception_model,
        Gen_copy=Gen_copy,
        Gen_ema=Gen_ema,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        freeze_layers=freeze_layers,
        conditional_strategy=conditional_strategy,
        pos_collected_numerator=model_config['model']['pos_collected_numerator'],
        z_dim=z_dim,
        num_classes=num_classes,
        hypersphere_dim=hypersphere_dim,
        d_spectral_norm=d_spectral_norm,
        g_spectral_norm=g_spectral_norm,
        G_optimizer=G_optimizer,
        D_optimizer=D_optimizer,
        batch_size=batch_size,
        g_steps_per_iter=optim_cfg['g_steps_per_iter'],
        d_steps_per_iter=optim_cfg['d_steps_per_iter'],
        accumulation_steps=optim_cfg['accumulation_steps'],
        total_step=total_step,
        G_loss=G_loss[adv_loss],
        D_loss=D_loss[adv_loss],
        contrastive_lambda=loss_cfg['contrastive_lambda'],
        margin=loss_cfg['margin'],
        tempering_type=loss_cfg['tempering_type'],
        tempering_step=loss_cfg['tempering_step'],
        start_temperature=loss_cfg['start_temperature'],
        end_temperature=loss_cfg['end_temperature'],
        weight_clipping_for_dis=loss_cfg['weight_clipping_for_dis'],
        weight_clipping_bound=loss_cfg['weight_clipping_bound'],
        gradient_penalty_for_dis=loss_cfg['gradient_penalty_for_dis'],
        gradient_penalty_lambda=loss_cfg['gradient_penalty_lambda'],
        deep_regret_analysis_for_dis=loss_cfg['deep_regret_analysis_for_dis'],
        regret_penalty_lambda=loss_cfg['regret_penalty_lambda'],
        cr=cr,
        cr_lambda=loss_cfg['cr_lambda'],
        bcr=loss_cfg['bcr'],
        real_lambda=loss_cfg['real_lambda'],
        fake_lambda=loss_cfg['fake_lambda'],
        zcr=loss_cfg['zcr'],
        gen_lambda=loss_cfg['gen_lambda'],
        dis_lambda=loss_cfg['dis_lambda'],
        sigma_noise=loss_cfg['sigma_noise'],
        diff_aug=sampling_cfg['diff_aug'],
        ada=sampling_cfg['ada'],
        prev_ada_p=prev_ada_p,
        ada_target=sampling_cfg['ada_target'],
        ada_length=sampling_cfg['ada_length'],
        prior=prior,
        truncated_factor=truncated_factor,
        ema=ema,
        latent_op=sampling_cfg['latent_op'],
        latent_op_rate=sampling_cfg['latent_op_rate'],
        latent_op_step=sampling_cfg['latent_op_step'],
        latent_op_step4eval=sampling_cfg['latent_op_step4eval'],
        latent_op_alpha=sampling_cfg['latent_op_alpha'],
        latent_op_beta=sampling_cfg['latent_op_beta'],
        latent_norm_reg_weight=sampling_cfg['latent_norm_reg_weight'],
        default_device=default_device,
        print_every=train_config['print_every'],
        save_every=train_config['save_every'],
        checkpoint_dir=checkpoint_dir,
        evaluate=train_config['eval'],
        mu=mu,
        sigma=sigma,
        best_fid=best_fid,
        best_fid_checkpoint_path=best_fid_checkpoint_path,
        mixed_precision=mixed_precision,
        train_config=train_config,
        model_config=model_config,
    )

    ##### run the requested phases #####
    if train_config['train']:
        step = train_eval.train(current_step=step, total_step=total_step)

    if train_config['eval']:
        train_eval.evaluation(step=step, standing_statistics=standing_statistics, standing_step=standing_step)

    if train_config['save_images']:
        train_eval.save_images(is_generate=True, png=True, npz=True,
                               standing_statistics=standing_statistics, standing_step=standing_step)

    if train_config['image_visualization']:
        train_eval.run_image_visualization(nrow=train_config['nrow'], ncol=train_config['ncol'],
                                           standing_statistics=standing_statistics,
                                           standing_step=standing_step)

    if train_config['k_nearest_neighbor']:
        train_eval.run_nearest_neighbor(nrow=train_config['nrow'], ncol=train_config['ncol'],
                                        standing_statistics=standing_statistics,
                                        standing_step=standing_step)

    if train_config['interpolation']:
        assert architecture in ["big_resnet", "biggan_deep"], "Not supported except for biggan and biggan_deep."
        # Two sweeps: one with the latent fixed, one with the label fixed.
        train_eval.run_linear_interpolation(nrow=train_config['nrow'], ncol=train_config['ncol'],
                                            fix_z=True, fix_y=False,
                                            standing_statistics=standing_statistics,
                                            standing_step=standing_step)
        train_eval.run_linear_interpolation(nrow=train_config['nrow'], ncol=train_config['ncol'],
                                            fix_z=False, fix_y=True,
                                            standing_statistics=standing_statistics,
                                            standing_step=standing_step)

    if train_config['frequency_analysis']:
        train_eval.run_frequency_analysis(num_images=len(train_dataset) // num_classes,
                                          standing_statistics=standing_statistics,
                                          standing_step=standing_step)
def prepare_train_eval(rank, world_size, run_name, train_config, model_config, hdf5_path_train):
    """Per-process entry point: build everything for one rank and run the enabled phases.

    ``rank`` is used both as the process rank handed to ``setup`` and as the
    CUDA device index (``torch.cuda.set_device(rank)``, ``.to(rank)``). Only
    rank 0 owns a logger and a TensorBoard writer; on every other rank both
    are ``None``.

    Steps, in order: merge the two configs with ``dict2clsattr``; optionally
    initialise the distributed backend; load the datasets; build the data
    loaders (with a ``DistributedSampler`` and a per-process batch size when
    ``cfgs.distributed_data_parallel`` is set); instantiate the generator,
    discriminator and optional EMA generator; build the optimizers
    (optionally wrapped in ``LARS``); optionally restore from a checkpoint;
    wrap the models in ``DDP`` or ``DataParallel``; prepare the Inception
    moments when ``cfgs.eval`` is set; and finally create the worker via
    ``make_worker`` and call the methods selected by the config flags.

    NOTE(review): another ``prepare_train_eval`` with a different signature is
    defined later in this source; if both live in the same module the later
    definition shadows this one - confirm which one callers expect.

    Args:
        rank: index of this process / GPU.
        world_size: number of processes; also passed to the worker as
            ``n_gpus``.
        run_name: initial run name; replaced by the checkpoint's run name when
            resuming.
        train_config, model_config: configuration objects merged by
            ``dict2clsattr`` into ``cfgs``.
        hdf5_path_train: forwarded to ``LoadDataset`` for the train split.

    Raises:
        NotImplementedError: for an unknown ``cfgs.optimizer``.
        NotADirectoryError: if ``cfgs.checkpoint_folder`` does not exist.
        IndexError: if no matching checkpoint file is found by ``glob``.
        AssertionError: if the seed differs from the checkpoint's seed while
            training, or interpolation is requested for an unsupported
            architecture.
    """
    cfgs = dict2clsattr(train_config, model_config)
    prev_ada_p, step, best_step, best_fid, best_fid_checkpoint_path = None, 0, 0, None, None
    mu, sigma, inception_model = None, None, None
    is_main = rank == 0

    if cfgs.distributed_data_parallel:
        print("Use GPU: {} for training.".format(rank))
        setup(rank, world_size)
        torch.cuda.set_device(rank)

    writer = SummaryWriter(log_dir=join('./logs', run_name)) if is_main else None
    if is_main:
        logger = make_logger(run_name, None)
        logger.info('Run name : {run_name}'.format(run_name=run_name))
        logger.info(train_config)
        logger.info(model_config)
    else:
        logger = None

    ##### load dataset #####
    if is_main:
        logger.info('Load train datasets...')
    train_dataset = LoadDataset(cfgs.dataset_name, cfgs.data_path, train=True, download=True,
                                resize_size=cfgs.img_size, hdf5_path=hdf5_path_train,
                                random_flip=cfgs.random_flip_preprocessing)
    if cfgs.reduce_train_dataset < 1.0:
        num_train = int(cfgs.reduce_train_dataset * len(train_dataset))
        train_dataset, _ = torch.utils.data.random_split(
            train_dataset, [num_train, len(train_dataset) - num_train])
    if is_main:
        logger.info('Train dataset size : {dataset_size}'.format(dataset_size=len(train_dataset)))
        logger.info('Load {mode} datasets...'.format(mode=cfgs.eval_type))

    eval_mode = cfgs.eval_type == 'train'
    eval_dataset = LoadDataset(cfgs.dataset_name, cfgs.data_path, train=eval_mode, download=True,
                               resize_size=cfgs.img_size, hdf5_path=None, random_flip=False)
    if is_main:
        logger.info('Eval dataset size : {dataset_size}'.format(dataset_size=len(eval_dataset)))

    if cfgs.distributed_data_parallel:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        # From here on cfgs.batch_size is the per-process batch size.
        cfgs.batch_size = cfgs.batch_size // world_size
    else:
        train_sampler = None

    # DataLoader does not accept shuffle=True together with a sampler.
    train_dataloader = DataLoader(train_dataset, batch_size=cfgs.batch_size, shuffle=(train_sampler is None),
                                  pin_memory=True, num_workers=cfgs.num_workers, sampler=train_sampler,
                                  drop_last=True)
    eval_dataloader = DataLoader(eval_dataset, batch_size=cfgs.batch_size, shuffle=False, pin_memory=True,
                                 num_workers=cfgs.num_workers, drop_last=False)

    ##### build model #####
    if is_main:
        logger.info('Build model...')
    module = __import__('models.{architecture}'.format(architecture=cfgs.architecture),
                        fromlist=['something'])
    if is_main:
        logger.info('Modules are located on models.{architecture}.'.format(architecture=cfgs.architecture))

    Gen = module.Generator(cfgs.z_dim, cfgs.shared_dim, cfgs.img_size, cfgs.g_conv_dim, cfgs.g_spectral_norm,
                           cfgs.attention, cfgs.attention_after_nth_gen_block, cfgs.activation_fn,
                           cfgs.conditional_strategy, cfgs.num_classes, cfgs.g_init, cfgs.G_depth,
                           cfgs.mixed_precision).to(rank)

    Dis = module.Discriminator(cfgs.img_size, cfgs.d_conv_dim, cfgs.d_spectral_norm, cfgs.attention,
                               cfgs.attention_after_nth_dis_block, cfgs.activation_fn,
                               cfgs.conditional_strategy, cfgs.hypersphere_dim, cfgs.num_classes,
                               cfgs.nonlinear_embed, cfgs.normalize_embed, cfgs.d_init, cfgs.D_depth,
                               cfgs.mixed_precision).to(rank)

    if cfgs.ema:
        if is_main:
            logger.info('Prepare EMA for G with decay of {}.'.format(cfgs.ema_decay))
        Gen_copy = module.Generator(cfgs.z_dim, cfgs.shared_dim, cfgs.img_size, cfgs.g_conv_dim,
                                    cfgs.g_spectral_norm, cfgs.attention, cfgs.attention_after_nth_gen_block,
                                    cfgs.activation_fn, cfgs.conditional_strategy, cfgs.num_classes,
                                    initialize=False, G_depth=cfgs.G_depth,
                                    mixed_precision=cfgs.mixed_precision).to(rank)
        Gen_ema = ema(Gen, Gen_copy, cfgs.ema_decay, cfgs.ema_start)
    else:
        Gen_copy, Gen_ema = None, None

    if is_main:
        logger.info(count_parameters(Gen))
        logger.info(Gen)
        logger.info(count_parameters(Dis))
        logger.info(Dis)

    ### define loss functions and optimizers
    G_loss = {'vanilla': loss_dcgan_gen, 'least_square': loss_lsgan_gen,
              'hinge': loss_hinge_gen, 'wasserstein': loss_wgan_gen}
    D_loss = {'vanilla': loss_dcgan_dis, 'least_square': loss_lsgan_dis,
              'hinge': loss_hinge_dis, 'wasserstein': loss_wgan_dis}

    # Only parameters that require gradients are handed to the optimizers.
    g_params = filter(lambda p: p.requires_grad, Gen.parameters())
    d_params = filter(lambda p: p.requires_grad, Dis.parameters())

    if cfgs.optimizer == "SGD":
        G_optimizer = torch.optim.SGD(g_params, cfgs.g_lr, momentum=cfgs.momentum, nesterov=cfgs.nesterov)
        D_optimizer = torch.optim.SGD(d_params, cfgs.d_lr, momentum=cfgs.momentum, nesterov=cfgs.nesterov)
    elif cfgs.optimizer == "RMSprop":
        G_optimizer = torch.optim.RMSprop(g_params, cfgs.g_lr, momentum=cfgs.momentum, alpha=cfgs.alpha)
        D_optimizer = torch.optim.RMSprop(d_params, cfgs.d_lr, momentum=cfgs.momentum, alpha=cfgs.alpha)
    elif cfgs.optimizer == "Adam":
        G_optimizer = torch.optim.Adam(g_params, cfgs.g_lr, [cfgs.beta1, cfgs.beta2], eps=1e-6)
        D_optimizer = torch.optim.Adam(d_params, cfgs.d_lr, [cfgs.beta1, cfgs.beta2], eps=1e-6)
    else:
        raise NotImplementedError

    if cfgs.LARS_optimizer:
        G_optimizer = LARS(optimizer=G_optimizer, eps=1e-8, trust_coef=0.001)
        D_optimizer = LARS(optimizer=D_optimizer, eps=1e-8, trust_coef=0.001)

    ##### load checkpoints if needed #####
    if cfgs.checkpoint_folder is None:
        checkpoint_dir = make_checkpoint_dir(cfgs.checkpoint_folder, run_name)
    else:
        when = "current" if cfgs.load_current is True else "best"
        if not exists(abspath(cfgs.checkpoint_folder)):
            raise NotADirectoryError
        checkpoint_dir = make_checkpoint_dir(cfgs.checkpoint_folder, run_name)
        # The first glob match is used; an empty match list raises IndexError.
        g_checkpoint_dir = glob.glob(
            join(checkpoint_dir, "model=G-{when}-weights-step*.pth".format(when=when)))[0]
        d_checkpoint_dir = glob.glob(
            join(checkpoint_dir, "model=D-{when}-weights-step*.pth".format(when=when)))[0]
        # Both loads overwrite run_name/step/prev_ada_p; the discriminator's values win.
        Gen, G_optimizer, trained_seed, run_name, step, prev_ada_p = load_checkpoint(
            Gen, G_optimizer, g_checkpoint_dir)
        Dis, D_optimizer, trained_seed, run_name, step, prev_ada_p, best_step, best_fid, best_fid_checkpoint_path =\
            load_checkpoint(Dis, D_optimizer, d_checkpoint_dir, metric=True)
        if is_main:
            logger = make_logger(run_name, None)
        if cfgs.ema:
            g_ema_checkpoint_dir = glob.glob(
                join(checkpoint_dir, "model=G_ema-{when}-weights-step*.pth".format(when=when)))[0]
            Gen_copy = load_checkpoint(Gen_copy, None, g_ema_checkpoint_dir, ema=True)
            Gen_ema.source, Gen_ema.target = Gen, Gen_copy

        # Rank 0 already opened a writer for the initial run name. Close it
        # before switching to the checkpoint's run name so that its event
        # file and background writer are not leaked.
        if writer is not None:
            writer.close()
        writer = SummaryWriter(log_dir=join('./logs', run_name)) if is_main else None
        if cfgs.train_configs['train']:
            assert cfgs.seed == trained_seed, "Seed for sampling random numbers should be same!"

        if is_main:
            logger.info('Generator checkpoint is {}'.format(g_checkpoint_dir))
            logger.info('Discriminator checkpoint is {}'.format(d_checkpoint_dir))
        if cfgs.freeze_layers > -1:
            # Restart the counters while keeping the restored weights.
            prev_ada_p, step, best_step, best_fid, best_fid_checkpoint_path = None, 0, 0, None, None

    ##### wrap models with DP and convert BN to Sync BN #####
    if world_size > 1:
        if cfgs.distributed_data_parallel:
            if cfgs.synchronized_bn:
                process_group = torch.distributed.new_group(list(range(world_size)))
                Gen = torch.nn.SyncBatchNorm.convert_sync_batchnorm(Gen, process_group)
                Dis = torch.nn.SyncBatchNorm.convert_sync_batchnorm(Dis, process_group)
                if cfgs.ema:
                    Gen_copy = torch.nn.SyncBatchNorm.convert_sync_batchnorm(Gen_copy, process_group)

            Gen = DDP(Gen, device_ids=[rank], broadcast_buffers=False, find_unused_parameters=True)
            Dis = DDP(Dis, device_ids=[rank], broadcast_buffers=False, find_unused_parameters=True)
            if cfgs.ema:
                Gen_copy = DDP(Gen_copy, device_ids=[rank], broadcast_buffers=False,
                               find_unused_parameters=True)
        else:
            Gen = DataParallel(Gen, output_device=rank)
            Dis = DataParallel(Dis, output_device=rank)
            if cfgs.ema:
                Gen_copy = DataParallel(Gen_copy, output_device=rank)

            if cfgs.synchronized_bn:
                Gen = convert_model(Gen).to(rank)
                Dis = convert_model(Dis).to(rank)
                if cfgs.ema:
                    Gen_copy = convert_model(Gen_copy).to(rank)

    ##### load the inception network and prepare first/second moments for calculating FID #####
    if cfgs.eval:
        inception_model = InceptionV3().to(rank)
        if world_size > 1 and cfgs.distributed_data_parallel:
            # Gradients are switched on for the Inception parameters before the DDP wrap.
            toggle_grad(inception_model, on=True)
            inception_model = DDP(inception_model, device_ids=[rank], broadcast_buffers=False,
                                  find_unused_parameters=True)
        elif world_size > 1 and cfgs.distributed_data_parallel is False:
            inception_model = DataParallel(inception_model, output_device=rank)

        mu, sigma = prepare_inception_moments(dataloader=eval_dataloader,
                                              generator=Gen,
                                              eval_mode=cfgs.eval_type,
                                              inception_model=inception_model,
                                              splits=1,
                                              run_name=run_name,
                                              logger=logger,
                                              device=rank)

    worker = make_worker(
        cfgs=cfgs,
        run_name=run_name,
        best_step=best_step,
        logger=logger,
        writer=writer,
        n_gpus=world_size,
        gen_model=Gen,
        dis_model=Dis,
        inception_model=inception_model,
        Gen_copy=Gen_copy,
        Gen_ema=Gen_ema,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        G_optimizer=G_optimizer,
        D_optimizer=D_optimizer,
        G_loss=G_loss[cfgs.adv_loss],
        D_loss=D_loss[cfgs.adv_loss],
        prev_ada_p=prev_ada_p,
        rank=rank,
        checkpoint_dir=checkpoint_dir,
        mu=mu,
        sigma=sigma,
        best_fid=best_fid,
        best_fid_checkpoint_path=best_fid_checkpoint_path,
    )

    ##### run the requested phases #####
    if cfgs.train_configs['train']:
        step = worker.train(current_step=step, total_step=cfgs.total_step)

    if cfgs.eval:
        worker.evaluation(step=step, standing_statistics=cfgs.standing_statistics,
                          standing_step=cfgs.standing_step)

    if cfgs.save_images:
        worker.save_images(is_generate=True, png=True, npz=True,
                           standing_statistics=cfgs.standing_statistics, standing_step=cfgs.standing_step)

    if cfgs.image_visualization:
        worker.run_image_visualization(nrow=cfgs.nrow, ncol=cfgs.ncol,
                                       standing_statistics=cfgs.standing_statistics,
                                       standing_step=cfgs.standing_step)

    if cfgs.k_nearest_neighbor:
        worker.run_nearest_neighbor(nrow=cfgs.nrow, ncol=cfgs.ncol,
                                    standing_statistics=cfgs.standing_statistics,
                                    standing_step=cfgs.standing_step)

    if cfgs.interpolation:
        assert cfgs.architecture in ["big_resnet", "biggan_deep"], "StudioGAN does not support interpolation analysis except for biggan and biggan_deep."
        # Two sweeps: one with the latent fixed, one with the label fixed.
        worker.run_linear_interpolation(nrow=cfgs.nrow, ncol=cfgs.ncol, fix_z=True, fix_y=False,
                                        standing_statistics=cfgs.standing_statistics,
                                        standing_step=cfgs.standing_step)
        worker.run_linear_interpolation(nrow=cfgs.nrow, ncol=cfgs.ncol, fix_z=False, fix_y=True,
                                        standing_statistics=cfgs.standing_statistics,
                                        standing_step=cfgs.standing_step)

    if cfgs.frequency_analysis:
        worker.run_frequency_analysis(num_images=len(train_dataset) // cfgs.num_classes,
                                      standing_statistics=cfgs.standing_statistics,
                                      standing_step=cfgs.standing_step)

    if cfgs.tsne_analysis:
        worker.run_tsne(dataloader=eval_dataloader,
                        standing_statistics=cfgs.standing_statistics,
                        standing_step=cfgs.standing_step)
def prepare_train_eval(cfgs, hdf5_path_train, **_):
    """Single-process entry point: build everything from ``cfgs`` and run the enabled phases.

    Steps, in order: configure cudnn (``cfgs.seed == -1`` selects benchmark
    mode, any other value is passed to ``fix_all_seed`` and cudnn is made
    deterministic); create the run name, logger and TensorBoard writer; load
    the datasets and data loaders; instantiate the generator, discriminator
    and optional EMA generator from ``models.<cfgs.architecture>``; build the
    optimizers; optionally restore from ``cfgs.checkpoint_folder``; wrap the
    models in ``DataParallel`` when more than one GPU is visible; prepare the
    Inception moments when ``cfgs.eval`` is set; and finally create the worker
    via ``make_worker`` and call the methods selected by the config flags.

    NOTE(review): another ``prepare_train_eval`` with a different signature is
    defined earlier in this source; if both live in the same module this
    later definition is the one bound to the name - confirm that is intended.

    Args:
        cfgs: configuration object exposing the attributes read in the body
            (``seed``, ``batch_size``, ``architecture``, ``optimizer``, ...).
        hdf5_path_train: forwarded to ``LoadDataset`` for the train split.
        **_: extra keyword arguments are accepted and ignored.

    Raises:
        NotImplementedError: for an unknown ``cfgs.optimizer``.
        NotADirectoryError: if ``cfgs.checkpoint_folder`` does not exist.
        IndexError: if no matching checkpoint file is found by ``glob``.
        AssertionError: if the seed differs from the checkpoint's seed while
            training, or interpolation is requested for an unsupported
            architecture.
    """
    if cfgs.seed == -1:
        cudnn.benchmark, cudnn.deterministic = True, False
    else:
        fix_all_seed(cfgs.seed)
        cudnn.benchmark, cudnn.deterministic = False, True

    n_gpus, default_device = torch.cuda.device_count(), torch.cuda.current_device()
    if n_gpus == 1:
        warnings.warn('You have chosen a specific GPU. This will completely disable data parallelism.')

    if cfgs.disable_debugging_API:
        torch.autograd.set_detect_anomaly(False)
    check_flag_0(cfgs.batch_size, n_gpus, cfgs.freeze_layers, cfgs.checkpoint_folder, cfgs.architecture,
                 cfgs.img_size)

    # Use the last path component (the config file name) and drop its final
    # five characters. Indexing from the end keeps this working regardless of
    # how deep the config file sits; a fixed index only works for paths with
    # exactly four '/'-separated components.
    run_name = make_run_name(RUN_NAME_FORMAT, framework=cfgs.config_path.split('/')[-1][:-5], phase='train')
    prev_ada_p, step, best_step, best_fid, best_fid_checkpoint_path = None, 0, 0, None, None
    mu, sigma, inception_model = None, None, None

    logger = make_logger(run_name, None)
    writer = SummaryWriter(log_dir=join('./logs', run_name))
    logger.info('Run name : {run_name}'.format(run_name=run_name))
    logger.info(cfgs.train_configs)
    logger.info(cfgs.model_configs)

    ##### load dataset #####
    logger.info('Loading train datasets...')
    train_dataset = LoadDataset(cfgs.dataset_name, cfgs.data_path, train=True, download=True,
                                resize_size=cfgs.img_size, hdf5_path=hdf5_path_train,
                                random_flip=cfgs.random_flip_preprocessing)
    if cfgs.reduce_train_dataset < 1.0:
        num_train = int(cfgs.reduce_train_dataset * len(train_dataset))
        train_dataset, _ = torch.utils.data.random_split(
            train_dataset, [num_train, len(train_dataset) - num_train])
    logger.info('Train dataset size : {dataset_size}'.format(dataset_size=len(train_dataset)))

    logger.info('Loading {mode} datasets...'.format(mode=cfgs.eval_type))
    eval_mode = cfgs.eval_type == 'train'
    eval_dataset = LoadDataset(cfgs.dataset_name, cfgs.data_path, train=eval_mode, download=True,
                               resize_size=cfgs.img_size, hdf5_path=None, random_flip=False)
    logger.info('Eval dataset size : {dataset_size}'.format(dataset_size=len(eval_dataset)))

    train_dataloader = DataLoader(train_dataset, batch_size=cfgs.batch_size, shuffle=True, pin_memory=True,
                                  num_workers=cfgs.num_workers, drop_last=True)
    eval_dataloader = DataLoader(eval_dataset, batch_size=cfgs.batch_size, shuffle=True, pin_memory=True,
                                 num_workers=cfgs.num_workers, drop_last=False)

    ##### build model #####
    logger.info('Building model...')
    module = __import__('models.{architecture}'.format(architecture=cfgs.architecture),
                        fromlist=['something'])
    logger.info('Modules are located on models.{architecture}'.format(architecture=cfgs.architecture))

    Gen = module.Generator(cfgs.z_dim, cfgs.shared_dim, cfgs.img_size, cfgs.g_conv_dim, cfgs.g_spectral_norm,
                           cfgs.attention, cfgs.attention_after_nth_gen_block, cfgs.activation_fn,
                           cfgs.conditional_strategy, cfgs.num_classes, cfgs.g_init, cfgs.G_depth,
                           cfgs.mixed_precision).to(default_device)

    Dis = module.Discriminator(cfgs.img_size, cfgs.d_conv_dim, cfgs.d_spectral_norm, cfgs.attention,
                               cfgs.attention_after_nth_dis_block, cfgs.activation_fn,
                               cfgs.conditional_strategy, cfgs.hypersphere_dim, cfgs.num_classes,
                               cfgs.nonlinear_embed, cfgs.normalize_embed, cfgs.d_init, cfgs.D_depth,
                               cfgs.mixed_precision).to(default_device)

    if cfgs.ema:
        print('Preparing EMA for G with decay of {}'.format(cfgs.ema_decay))
        Gen_copy = module.Generator(cfgs.z_dim, cfgs.shared_dim, cfgs.img_size, cfgs.g_conv_dim,
                                    cfgs.g_spectral_norm, cfgs.attention, cfgs.attention_after_nth_gen_block,
                                    cfgs.activation_fn, cfgs.conditional_strategy, cfgs.num_classes,
                                    initialize=False, G_depth=cfgs.G_depth,
                                    mixed_precision=cfgs.mixed_precision).to(default_device)
        Gen_ema = ema(Gen, Gen_copy, cfgs.ema_decay, cfgs.ema_start)
    else:
        Gen_copy, Gen_ema = None, None

    logger.info(count_parameters(Gen))
    logger.info(Gen)

    logger.info(count_parameters(Dis))
    logger.info(Dis)

    ### define loss functions and optimizers
    G_loss = {'vanilla': loss_dcgan_gen, 'least_square': loss_lsgan_gen,
              'hinge': loss_hinge_gen, 'wasserstein': loss_wgan_gen}
    D_loss = {'vanilla': loss_dcgan_dis, 'least_square': loss_lsgan_dis,
              'hinge': loss_hinge_dis, 'wasserstein': loss_wgan_dis}

    # Only parameters that require gradients are handed to the optimizers.
    g_params = filter(lambda p: p.requires_grad, Gen.parameters())
    d_params = filter(lambda p: p.requires_grad, Dis.parameters())

    if cfgs.optimizer == "SGD":
        G_optimizer = torch.optim.SGD(g_params, cfgs.g_lr, momentum=cfgs.momentum, nesterov=cfgs.nesterov)
        D_optimizer = torch.optim.SGD(d_params, cfgs.d_lr, momentum=cfgs.momentum, nesterov=cfgs.nesterov)
    elif cfgs.optimizer == "RMSprop":
        G_optimizer = torch.optim.RMSprop(g_params, cfgs.g_lr, momentum=cfgs.momentum, alpha=cfgs.alpha)
        D_optimizer = torch.optim.RMSprop(d_params, cfgs.d_lr, momentum=cfgs.momentum, alpha=cfgs.alpha)
    elif cfgs.optimizer == "Adam":
        G_optimizer = torch.optim.Adam(g_params, cfgs.g_lr, [cfgs.beta1, cfgs.beta2], eps=1e-6)
        D_optimizer = torch.optim.Adam(d_params, cfgs.d_lr, [cfgs.beta1, cfgs.beta2], eps=1e-6)
    else:
        raise NotImplementedError

    ##### load checkpoints if needed #####
    if cfgs.checkpoint_folder is None:
        checkpoint_dir = make_checkpoint_dir(cfgs.checkpoint_folder, run_name)
    else:
        when = "current" if cfgs.load_current is True else "best"
        if not exists(abspath(cfgs.checkpoint_folder)):
            raise NotADirectoryError
        checkpoint_dir = make_checkpoint_dir(cfgs.checkpoint_folder, run_name)
        # The first glob match is used; an empty match list raises IndexError.
        g_checkpoint_dir = glob.glob(
            join(checkpoint_dir, "model=G-{when}-weights-step*.pth".format(when=when)))[0]
        d_checkpoint_dir = glob.glob(
            join(checkpoint_dir, "model=D-{when}-weights-step*.pth".format(when=when)))[0]
        # Both loads overwrite run_name/step/prev_ada_p; the discriminator's values win.
        Gen, G_optimizer, trained_seed, run_name, step, prev_ada_p = load_checkpoint(
            Gen, G_optimizer, g_checkpoint_dir)
        Dis, D_optimizer, trained_seed, run_name, step, prev_ada_p, best_step, best_fid, best_fid_checkpoint_path =\
            load_checkpoint(Dis, D_optimizer, d_checkpoint_dir, metric=True)
        logger = make_logger(run_name, None)
        if cfgs.ema:
            g_ema_checkpoint_dir = glob.glob(
                join(checkpoint_dir, "model=G_ema-{when}-weights-step*.pth".format(when=when)))[0]
            Gen_copy = load_checkpoint(Gen_copy, None, g_ema_checkpoint_dir, ema=True)
            Gen_ema.source, Gen_ema.target = Gen, Gen_copy

        # The writer opened earlier targets the freshly generated run name.
        # Close it before switching to the checkpoint's run name so that its
        # event file and background writer are not leaked.
        writer.close()
        writer = SummaryWriter(log_dir=join('./logs', run_name))
        if cfgs.train_configs['train']:
            assert cfgs.seed == trained_seed, "seed for sampling random numbers should be same!"
        logger.info('Generator checkpoint is {}'.format(g_checkpoint_dir))
        logger.info('Discriminator checkpoint is {}'.format(d_checkpoint_dir))
        if cfgs.freeze_layers > -1:
            # Restart the counters while keeping the restored weights.
            prev_ada_p, step, best_step, best_fid, best_fid_checkpoint_path = None, 0, 0, None, None

    ##### wrap models with DP and convert BN to Sync BN #####
    if n_gpus > 1:
        Gen = DataParallel(Gen, output_device=default_device)
        Dis = DataParallel(Dis, output_device=default_device)
        if cfgs.ema:
            Gen_copy = DataParallel(Gen_copy, output_device=default_device)

        if cfgs.synchronized_bn:
            Gen = convert_model(Gen).to(default_device)
            Dis = convert_model(Dis).to(default_device)
            if cfgs.ema:
                Gen_copy = convert_model(Gen_copy).to(default_device)

    ##### load the inception network and prepare first/second moments for calculating FID #####
    if cfgs.eval:
        inception_model = InceptionV3().to(default_device)
        if n_gpus > 1:
            inception_model = DataParallel(inception_model, output_device=default_device)
        mu, sigma = prepare_inception_moments(dataloader=eval_dataloader,
                                              generator=Gen,
                                              eval_mode=cfgs.eval_type,
                                              inception_model=inception_model,
                                              splits=1,
                                              run_name=run_name,
                                              logger=logger,
                                              device=default_device)

    worker = make_worker(
        cfgs=cfgs,
        run_name=run_name,
        best_step=best_step,
        logger=logger,
        writer=writer,
        n_gpus=n_gpus,
        gen_model=Gen,
        dis_model=Dis,
        inception_model=inception_model,
        Gen_copy=Gen_copy,
        Gen_ema=Gen_ema,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        G_optimizer=G_optimizer,
        D_optimizer=D_optimizer,
        G_loss=G_loss[cfgs.adv_loss],
        D_loss=D_loss[cfgs.adv_loss],
        prev_ada_p=prev_ada_p,
        default_device=default_device,
        checkpoint_dir=checkpoint_dir,
        mu=mu,
        sigma=sigma,
        best_fid=best_fid,
        best_fid_checkpoint_path=best_fid_checkpoint_path,
    )

    ##### run the requested phases #####
    if cfgs.train_configs['train']:
        step = worker.train(current_step=step, total_step=cfgs.total_step)

    if cfgs.eval:
        worker.evaluation(step=step, standing_statistics=cfgs.standing_statistics,
                          standing_step=cfgs.standing_step)

    if cfgs.save_images:
        worker.save_images(is_generate=True, png=True, npz=True,
                           standing_statistics=cfgs.standing_statistics, standing_step=cfgs.standing_step)

    if cfgs.image_visualization:
        worker.run_image_visualization(nrow=cfgs.nrow, ncol=cfgs.ncol,
                                       standing_statistics=cfgs.standing_statistics,
                                       standing_step=cfgs.standing_step)

    if cfgs.k_nearest_neighbor:
        worker.run_nearest_neighbor(nrow=cfgs.nrow, ncol=cfgs.ncol,
                                    standing_statistics=cfgs.standing_statistics,
                                    standing_step=cfgs.standing_step)

    if cfgs.interpolation:
        assert cfgs.architecture in ["big_resnet", "biggan_deep"], "Not supported except for biggan and biggan_deep."
        # Two sweeps: one with the latent fixed, one with the label fixed.
        worker.run_linear_interpolation(nrow=cfgs.nrow, ncol=cfgs.ncol, fix_z=True, fix_y=False,
                                        standing_statistics=cfgs.standing_statistics,
                                        standing_step=cfgs.standing_step)
        worker.run_linear_interpolation(nrow=cfgs.nrow, ncol=cfgs.ncol, fix_z=False, fix_y=True,
                                        standing_statistics=cfgs.standing_statistics,
                                        standing_step=cfgs.standing_step)

    if cfgs.frequency_analysis:
        worker.run_frequency_analysis(num_images=len(train_dataset) // cfgs.num_classes,
                                      standing_statistics=cfgs.standing_statistics,
                                      standing_step=cfgs.standing_step)