def load_configs_initialize_training():
    parser = ArgumentParser(add_help=True)
    parser.add_argument("--entity", type=str, default=None, help="entity for wandb logging")
    parser.add_argument("--project", type=str, default=None, help="project name for wandb logging")
    parser.add_argument("-cfg", "--cfg_file", type=str, default="./src/configs/CIFAR10/ContraGAN.yaml")
    parser.add_argument("-data", "--data_dir", type=str, default=None)
    parser.add_argument("-save", "--save_dir", type=str, default="./")
    parser.add_argument("-ckpt", "--ckpt_dir", type=str, default=None)
    parser.add_argument("-best", "--load_best", action="store_true", help="load the best-performing checkpoint")
    parser.add_argument("--seed", type=int, default=-1, help="seed for generating random numbers")
    parser.add_argument("-DDP", "--distributed_data_parallel", action="store_true")
    parser.add_argument("--backend", type=str, default="nccl", help="cuda backend for DDP training \in ['nccl', 'gloo']")
    parser.add_argument("-tn", "--total_nodes", default=1, type=int, help="total number of nodes for training")
    parser.add_argument("-cn", "--current_node", default=0, type=int, help="rank of the current node")
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument("-sync_bn", "--synchronized_bn", action="store_true", help="turn on synchronized batchnorm")
    parser.add_argument("-mpc", "--mixed_precision", action="store_true", help="turn on mixed precision training")
    parser.add_argument("--truncation_factor", type=float, default=-1.0, help="truncation factor for applying truncation trick (-1.0 means not applying truncation trick)")
    parser.add_argument("--truncation_cutoff", type=float, default=None, help="truncation cutoff for stylegan (apply truncation for only w[:truncation_cutoff])")
    parser.add_argument("-batch_stat", "--batch_statistics", action="store_true", help="use the statistics of a batch when evaluating GAN (if false, use the moving average updated statistics)")
    parser.add_argument("-std_stat", "--standing_statistics", action="store_true", help="apply standing statistics for evaluation")
    parser.add_argument("-std_max", "--standing_max_batch", type=int, default=-1, help="maximum batch_size for calculating standing statistics (-1 means not applying standing statistics trick for evaluation)")
    parser.add_argument("-std_step", "--standing_step", type=int, default=-1, help="# of steps for standing statistics (-1 means not applying standing statistics trick for evaluation)")
    parser.add_argument("--freezeD", type=int, default=-1, help="# of frozen blocks in the discriminator for transfer learning")

    # parser arguments to apply langevin sampling for GAN evaluation
    # In the arguments regarding 'decay', -1 means not applying the decay trick by default
    parser.add_argument("-lgv", "--langevin_sampling", action="store_true", help="apply langevin sampling to generate images from an Energy-Based Model")
    parser.add_argument("-lgv_rate", "--langevin_rate", type=float, default=-1, help="an initial update rate for langevin sampling (\epsilon)")
    parser.add_argument("-lgv_std", "--langevin_noise_std", type=float, default=-1, help="standard deviation of the gaussian noise used in langevin sampling (std of n_i)")
    parser.add_argument("-lgv_decay", "--langevin_decay", type=float, default=-1, help="decay strength for langevin_rate and langevin_noise_std")
    parser.add_argument("-lgv_decay_steps", "--langevin_decay_steps", type=int, default=-1, help="langevin_rate and langevin_noise_std decrease every 'langevin_decay_steps'")
    parser.add_argument("-lgv_steps", "--langevin_steps", type=int, default=-1, help="total steps of langevin sampling")

    parser.add_argument("-t", "--train", action="store_true")
    parser.add_argument("-hdf5", "--load_train_hdf5", action="store_true", help="load train images from an hdf5 file for fast I/O")
    parser.add_argument("-l", "--load_data_in_memory", action="store_true", help="put the whole train dataset on the main memory for fast I/O")
    parser.add_argument("-metrics", "--eval_metrics", nargs='+', default=['fid'], help="evaluation metrics to use during training, a subset list of ['fid', 'is', 'prdc'] or none")
    parser.add_argument("--num_eval", type=int, default=1, help="number of runs for final evaluation.")
    parser.add_argument("--resize_fn", type=str, default="legacy", help="which resizing mode to use for calculating clean metrics, in ['legacy', 'clean']")
    parser.add_argument("-sr", "--save_real_images", action="store_true", help="save images sampled from the reference dataset")
    parser.add_argument("-sf", "--save_fake_images", action="store_true", help="save fake images generated by the GAN.")
    parser.add_argument("-sf_num", "--save_fake_images_num", type=int, default=1, help="number of fake images to save")
    parser.add_argument("-v", "--vis_fake_images", action="store_true", help="visualize image canvas")
    parser.add_argument("-knn", "--k_nearest_neighbor", action="store_true", help="conduct k-nearest neighbor analysis")
    parser.add_argument("-itp", "--interpolation", action="store_true", help="conduct interpolation analysis")
    parser.add_argument("-fa", "--frequency_analysis", action="store_true", help="conduct frequency analysis")
    parser.add_argument("-tsne", "--tsne_analysis", action="store_true", help="conduct tsne analysis")
    parser.add_argument("-ifid", "--intra_class_fid", action="store_true", help="calculate intra-class fid")
    parser.add_argument('--GAN_train', action='store_true', help="whether to calculate CAS (Recall)")
    parser.add_argument('--GAN_test', action='store_true', help="whether to calculate CAS (Precision)")
    parser.add_argument('-resume_ct', '--resume_classifier_train', action='store_true', help="whether to resume classifier training for CAS")
    parser.add_argument("-sefa", "--semantic_factorization", action="store_true", help="perform semantic (closed-form) factorization")
    parser.add_argument("-sefa_axis", "--num_semantic_axis", type=int, default=-1, help="number of semantic axes for sefa")
    parser.add_argument("-sefa_max", "--maximum_variations", type=float, default=-1, help="interpolate between z and z + maximum_variations*eigen-vector")
    parser.add_argument("-empty_cache", "--empty_cache", action="store_true", help="empty cuda caches after training step of generator and discriminator; slightly reduces memory usage but slows training speed (not recommended for normal use)")
    parser.add_argument("--print_every", type=int, default=100, help="logging interval")
    parser.add_argument("-every", "--save_every", type=int, default=2000, help="save interval")
    parser.add_argument('--eval_backbone', type=str, default='Inception_V3', help="[SwAV, Inception_V3]")
    parser.add_argument("-ref", "--ref_dataset", type=str, default="train", help="reference dataset for evaluation [train/valid/test]")
    parser.add_argument("--is_ref_dataset", action="store_true", help="whether to calculate an inception score of the ref dataset.")

    args = parser.parse_args()
    run_cfgs = vars(args)

    if not args.train and \
            "none" in args.eval_metrics and \
            not args.save_real_images and \
            not args.save_fake_images and \
            not args.vis_fake_images and \
            not args.k_nearest_neighbor and \
            not args.interpolation and \
            not args.frequency_analysis and \
            not args.tsne_analysis and \
            not args.intra_class_fid and \
            not args.GAN_train and \
            not args.GAN_test and \
            not args.semantic_factorization:
        parser.print_help(sys.stderr)
        sys.exit(1)

    gpus_per_node, rank = torch.cuda.device_count(), torch.cuda.current_device()

    cfgs = config.Configurations(args.cfg_file)
    cfgs.update_cfgs(run_cfgs, super="RUN")
    cfgs.OPTIMIZATION.world_size = gpus_per_node * cfgs.RUN.total_nodes
    cfgs.check_compatability()

    run_name = log.make_run_name(RUN_NAME_FORMAT, data_name=cfgs.DATA.name, framework=cfgs.RUN.cfg_file.split("/")[-1][:-5], phase="train")

    crop_long_edge = False if cfgs.DATA.name in cfgs.MISC.no_proc_data else True
    resize_size = None if cfgs.DATA.name in cfgs.MISC.no_proc_data else cfgs.DATA.img_size
    if cfgs.RUN.load_train_hdf5:
        hdf5_path, crop_long_edge, resize_size = hdf5.make_hdf5(name=cfgs.DATA.name, img_size=cfgs.DATA.img_size, crop_long_edge=crop_long_edge, resize_size=resize_size, data_dir=cfgs.RUN.data_dir, DATA=cfgs.DATA, RUN=cfgs.RUN)
    else:
        hdf5_path = None
    cfgs.PRE.crop_long_edge, cfgs.PRE.resize_size = crop_long_edge, resize_size

    misc.prepare_folder(names=cfgs.MISC.base_folders, save_dir=cfgs.RUN.save_dir)
    misc.download_data_if_possible(data_name=cfgs.DATA.name, data_dir=cfgs.RUN.data_dir)

    if cfgs.RUN.seed == -1:
        cfgs.RUN.seed = random.randint(1, 4096)
        cfgs.RUN.fix_seed = False
    else:
        cfgs.RUN.fix_seed = True

    if cfgs.OPTIMIZATION.world_size == 1:
        print("You have chosen a specific GPU. This will completely disable data parallelism.")
    return cfgs, gpus_per_node, run_name, hdf5_path, rank
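# Illustrative sketch (not part of StudioGAN): the seed convention used above, where
# seed == -1 means "draw a fresh random seed" and any other value pins determinism.
# resolve_seed/apply_seed are hypothetical helper names introduced only for this example.
import random
import torch
import torch.backends.cudnn as cudnn

def resolve_seed(seed):
    # Returns (seed, fix_seed), mirroring how cfgs.RUN.seed is resolved above.
    if seed == -1:
        return random.randint(1, 4096), False
    return seed, True

def apply_seed(seed, fix_seed):
    # Pin Python/PyTorch RNGs; enable cudnn autotuning only when exact reproducibility is not required.
    random.seed(seed)
    torch.manual_seed(seed)
    cudnn.benchmark, cudnn.deterministic = (not fix_seed), fix_seed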
def main(): parser = ArgumentParser(add_help=False) parser.add_argument('-c', '--config_path', type=str, default='./src/configs/CIFAR10/ContraGAN.json') parser.add_argument('--checkpoint_folder', type=str, default=None) parser.add_argument('-current', '--load_current', action='store_true', help='whether you load the current or best checkpoint') parser.add_argument('--log_output_path', type=str, default=None) parser.add_argument('-DDP', '--distributed_data_parallel', action='store_true') parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N') parser.add_argument('-nr', '--nr', default=0, type=int, help='ranking within the nodes') parser.add_argument('--seed', type=int, default=-1, help='seed for generating random numbers') parser.add_argument('--num_workers', type=int, default=8, help='') parser.add_argument('-sync_bn', '--synchronized_bn', action='store_true', help='whether turn on synchronized batchnorm') parser.add_argument('-mpc', '--mixed_precision', action='store_true', help='whether turn on mixed precision training') parser.add_argument('-LARS', '--LARS_optimizer', action='store_true', help='whether turn on LARS optimizer') parser.add_argument('-rm_API', '--disable_debugging_API', action='store_true', help='whether disable pytorch autograd debugging mode') parser.add_argument('--reduce_train_dataset', type=float, default=1.0, help='control the number of train dataset') parser.add_argument('-stat_otf', '--bn_stat_OnTheFly', action='store_true', help='when evaluating, use the statistics of a batch') parser.add_argument('-std_stat', '--standing_statistics', action='store_true') parser.add_argument('--standing_step', type=int, default=-1, help='# of steps for accumulation batchnorm') parser.add_argument('--freeze_layers', type=int, default=-1, help='# of layers for freezing discriminator') parser.add_argument('-l', '--load_all_data_in_memory', action='store_true') parser.add_argument('-t', '--train', action='store_true') parser.add_argument('-e', '--eval', action='store_true') parser.add_argument('-s', '--save_images', action='store_true') parser.add_argument('-iv', '--image_visualization', action='store_true', help='select whether conduct image visualization') parser.add_argument('-knn', '--k_nearest_neighbor', action='store_true', help='select whether conduct k-nearest neighbor analysis') parser.add_argument('-itp', '--interpolation', action='store_true', help='whether conduct interpolation analysis') parser.add_argument('-fa', '--frequency_analysis', action='store_true', help='whether conduct frequency analysis') parser.add_argument('-tsne', '--tsne_analysis', action='store_true', help='whether conduct tsne analysis') parser.add_argument('--nrow', type=int, default=10, help='number of rows to plot image canvas') parser.add_argument('--ncol', type=int, default=8, help='number of cols to plot image canvas') parser.add_argument('--print_every', type=int, default=100, help='control log interval') parser.add_argument('--save_every', type=int, default=2000, help='control evaluation and save interval') parser.add_argument('--eval_type', type=str, default='test', help='[train/valid/test]') args = parser.parse_args() if not args.train and \ not args.eval and \ not args.save_images and \ not args.image_visualization and \ not args.k_nearest_neighbor and \ not args.interpolation and \ not args.frequency_analysis and \ not args.tsne_analysis: parser.print_help(sys.stderr) sys.exit(1) if args.config_path is not None: with open(args.config_path) as f: model_configs = json.load(f) 
train_configs = vars(args) else: raise NotImplementedError hdf5_path_train = make_hdf5(model_configs['data_processing'], train_configs, mode="train") \ if train_configs['load_all_data_in_memory'] else None if train_configs['seed'] == -1: train_configs['seed'] = random.randint(1,4096) cudnn.benchmark, cudnn.deterministic = True, False else: cudnn.benchmark, cudnn.deterministic = False, True fix_all_seed(train_configs['seed']) gpus_per_node, rank = torch.cuda.device_count(), torch.cuda.current_device() world_size = gpus_per_node*train_configs['nodes'] if world_size == 1: warnings.warn('You have chosen a specific GPU. This will completely disable data parallelism.') run_name = make_run_name(RUN_NAME_FORMAT, framework=train_configs['config_path'].split('/')[-1][:-5], phase='train') if train_configs['disable_debugging_API']: torch.autograd.set_detect_anomaly(False) check_flags(train_configs, model_configs, world_size) if train_configs['distributed_data_parallel'] and world_size > 1: print("Train the models through DistributedDataParallel (DDP) mode.") mp.spawn(prepare_train_eval, nprocs=gpus_per_node, args=(gpus_per_node, world_size, run_name, train_configs, model_configs, hdf5_path_train)) else: prepare_train_eval(rank, gpus_per_node, world_size, run_name, train_configs, model_configs, hdf5_path_train=hdf5_path_train)
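# Minimal sketch of the mp.spawn/DDP launch pattern used in main() above, assuming a single
# node. _worker and launch are illustrative names; the real prepare_train_eval receives many
# more arguments through mp.spawn's args tuple, and the backend may be 'nccl' or 'gloo'.
import os
import torch.distributed as dist
import torch.multiprocessing as mp

def _worker(local_rank, world_size, backend="nccl"):
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend=backend, rank=local_rank, world_size=world_size)
    # ... build models, wrap them with DistributedDataParallel, run the training loop ...
    dist.destroy_process_group()

def launch(world_size):
    # mp.spawn passes the process index as the first argument of _worker.
    if world_size > 1:
        mp.spawn(_worker, nprocs=world_size, args=(world_size,))
    else:
        _worker(0, 1)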
def main(): parser = ArgumentParser(add_help=False) parser.add_argument('-c', '--config_path', type=str, default='./src/configs/CIFAR10/ContraGAN.json') parser.add_argument('--checkpoint_folder', type=str, default=None) parser.add_argument('-current', '--load_current', action='store_true', help='whether you load the current or best checkpoint') parser.add_argument('--log_output_path', type=str, default=None) parser.add_argument('--seed', type=int, default=-1, help='seed for generating random numbers') parser.add_argument('-DDP', '--distributed_data_parallel', action='store_true') parser.add_argument('--num_workers', type=int, default=8, help='') parser.add_argument('-sync_bn', '--synchronized_bn', action='store_true', help='whether turn on synchronized batchnorm') parser.add_argument('-mpc', '--mixed_precision', action='store_true', help='whether turn on mixed precision training') parser.add_argument('-rm_API', '--disable_debugging_API', action='store_true', help='whether disable pytorch autograd debugging mode') parser.add_argument('--reduce_train_dataset', type=float, default=1.0, help='control the number of train dataset') parser.add_argument('-std_stat', '--standing_statistics', action='store_true') parser.add_argument('--standing_step', type=int, default=-1, help='# of steps for accumulation batchnorm') parser.add_argument('--freeze_layers', type=int, default=-1, help='# of layers for freezing discriminator') parser.add_argument('-l', '--load_all_data_in_memory', action='store_true') parser.add_argument('-t', '--train', action='store_true') parser.add_argument('-e', '--eval', action='store_true') parser.add_argument('-s', '--save_images', action='store_true') parser.add_argument('-iv', '--image_visualization', action='store_true', help='select whether conduct image visualization') parser.add_argument('-knn', '--k_nearest_neighbor', action='store_true', help='select whether conduct k-nearest neighbor analysis') parser.add_argument('-itp', '--interpolation', action='store_true', help='whether conduct interpolation analysis') parser.add_argument('-fa', '--frequency_analysis', action='store_true', help='whether conduct frequency analysis') parser.add_argument('-tsne', '--tsne_analysis', action='store_true', help='whether conduct tsne analysis') parser.add_argument('--nrow', type=int, default=10, help='number of rows to plot image canvas') parser.add_argument('--ncol', type=int, default=8, help='number of cols to plot image canvas') parser.add_argument('--print_every', type=int, default=100, help='control log interval') parser.add_argument('--save_every', type=int, default=2000, help='control evaluation and save interval') parser.add_argument('--eval_type', type=str, default='test', help='[train/valid/test]') args = parser.parse_args() if not args.train and \ not args.eval and \ not args.save_images and \ not args.image_visualization and \ not args.k_nearest_neighbor and \ not args.interpolation and \ not args.frequency_analysis and \ not args.tsne_analysis: parser.print_help(sys.stderr) sys.exit(1) if args.config_path is not None: with open(args.config_path) as f: model_config = json.load(f) train_config = vars(args) else: raise NotImplementedError if model_config['data_processing']['dataset_name'] == 'cifar10': assert train_config['eval_type'] in ['train', 'test'], "Cifar10 does not contain dataset for validation." 
    elif model_config['data_processing']['dataset_name'] in ['imagenet', 'tiny_imagenet', 'custom']:
        assert train_config['eval_type'] == 'train' or train_config['eval_type'] == 'valid', \
            "StudioGAN does not support the evaluation protocol that uses the test dataset on imagenet, tiny imagenet, and custom datasets"

    if train_config['distributed_data_parallel']:
        msg = "StudioGAN does not support image visualization, k_nearest_neighbor, interpolation, and frequency_analysis with DDP. " + \
              "Please switch from DDP to single-GPU training or use DataParallel instead."
        assert train_config['image_visualization'] + train_config['k_nearest_neighbor'] + \
            train_config['interpolation'] + train_config['frequency_analysis'] + train_config['tsne_analysis'] == 0, msg

    hdf5_path_train = make_hdf5(model_config['data_processing'], train_config, mode="train") \
        if train_config['load_all_data_in_memory'] else None

    if train_config['seed'] == -1:
        cudnn.benchmark, cudnn.deterministic = True, False
    else:
        fix_all_seed(train_config['seed'])
        cudnn.benchmark, cudnn.deterministic = False, True

    world_size, rank = torch.cuda.device_count(), torch.cuda.current_device()
    if world_size == 1:
        warnings.warn('You have chosen a specific GPU. This will completely disable data parallelism.')

    if train_config['disable_debugging_API']:
        torch.autograd.set_detect_anomaly(False)

    check_flag_0(model_config['train']['optimization']['batch_size'], world_size, train_config['freeze_layers'],
                 train_config['checkpoint_folder'], model_config['train']['model']['architecture'],
                 model_config['data_processing']['img_size'])

    run_name = make_run_name(RUN_NAME_FORMAT, framework=train_config['config_path'].split('/')[-1][:-5], phase='train')

    if train_config['distributed_data_parallel'] and world_size > 1:
        print("Train the models through DistributedDataParallel (DDP) mode.")
        mp.spawn(prepare_train_eval, nprocs=world_size,
                 args=(world_size, run_name, train_config, model_config, hdf5_path_train))
    else:
        prepare_train_eval(rank, world_size, run_name, train_config, model_config, hdf5_path_train=hdf5_path_train)
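# Standalone sketch of the dataset/eval_type compatibility rules asserted above.
# check_eval_type is an illustrative helper name, not part of the project.
def check_eval_type(dataset_name, eval_type):
    if dataset_name == 'cifar10':
        assert eval_type in ['train', 'test'], "CIFAR10 has no validation split."
    elif dataset_name in ['imagenet', 'tiny_imagenet', 'custom']:
        assert eval_type in ['train', 'valid'], \
            "The test split is not supported for imagenet, tiny_imagenet, and custom datasets."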
def load_framework(seed, disable_debugging_API, num_workers, config_path, checkpoint_folder, reduce_train_dataset,
                   standing_statistics, standing_step, freeze_layers, load_current, eval_type, dataset_name,
                   num_classes, img_size, data_path, architecture, conditional_strategy, hypersphere_dim,
                   nonlinear_embed, normalize_embed, g_spectral_norm, d_spectral_norm, activation_fn, attention,
                   attention_after_nth_gen_block, attention_after_nth_dis_block, z_dim, shared_dim, g_conv_dim,
                   d_conv_dim, G_depth, D_depth, optimizer, batch_size, d_lr, g_lr, momentum, nesterov, alpha,
                   beta1, beta2, total_step, adv_loss, cr, g_init, d_init, random_flip_preprocessing, prior,
                   truncated_factor, ema, ema_decay, ema_start, synchronized_bn, mixed_precision, hdf5_path_train,
                   train_config, model_config, **_):
    if seed == 0:
        cudnn.benchmark = True
        cudnn.deterministic = False
    else:
        fix_all_seed(seed)
        cudnn.benchmark = False
        cudnn.deterministic = True

    if disable_debugging_API:
        torch.autograd.set_detect_anomaly(False)

    n_gpus = torch.cuda.device_count()
    default_device = torch.cuda.current_device()

    check_flag_0(batch_size, n_gpus, standing_statistics, ema, freeze_layers, checkpoint_folder)
    assert batch_size % n_gpus == 0, "batch_size should be divisible by the number of gpus"
    if n_gpus == 1:
        warnings.warn('You have chosen a specific GPU. This will completely disable data parallelism.')

    prev_ada_p, step, best_step, best_fid, best_fid_checkpoint_path = None, 0, 0, None, None
    standing_step = standing_step if standing_statistics is True else batch_size

    run_name = make_run_name(RUN_NAME_FORMAT, framework=config_path.split('/')[-1][:-5], phase='train')

    logger = make_logger(run_name, None)
    writer = SummaryWriter(log_dir=join('./logs', run_name))
    logger.info('Run name : {run_name}'.format(run_name=run_name))
    logger.info(train_config)
    logger.info(model_config)

    logger.info('Loading train datasets...')
    train_dataset = LoadDataset(dataset_name, data_path, train=True, download=True, resize_size=img_size,
                                hdf5_path=hdf5_path_train, random_flip=random_flip_preprocessing)
    if reduce_train_dataset < 1.0:
        num_train = int(reduce_train_dataset * len(train_dataset))
        train_dataset, _ = torch.utils.data.random_split(train_dataset, [num_train, len(train_dataset) - num_train])
    logger.info('Train dataset size : {dataset_size}'.format(dataset_size=len(train_dataset)))

    logger.info('Loading {mode} datasets...'.format(mode=eval_type))
    eval_mode = True if eval_type == 'train' else False
    eval_dataset = LoadDataset(dataset_name, data_path, train=eval_mode, download=True, resize_size=img_size,
                               hdf5_path=None, random_flip=False)
    logger.info('Eval dataset size : {dataset_size}'.format(dataset_size=len(eval_dataset)))

    logger.info('Building model...')
    if architecture == "dcgan":
        assert img_size == 32, "Sorry, StudioGAN does not support dcgan models for generating images larger than 32x32 resolution."
module = __import__( 'models.{architecture}'.format(architecture=architecture), fromlist=['something']) logger.info('Modules are located on models.{architecture}'.format( architecture=architecture)) Gen = module.Generator(z_dim, shared_dim, img_size, g_conv_dim, g_spectral_norm, attention, attention_after_nth_gen_block, activation_fn, conditional_strategy, num_classes, g_init, G_depth, mixed_precision).to(default_device) Dis = module.Discriminator(img_size, d_conv_dim, d_spectral_norm, attention, attention_after_nth_dis_block, activation_fn, conditional_strategy, hypersphere_dim, num_classes, nonlinear_embed, normalize_embed, d_init, D_depth, mixed_precision).to(default_device) if ema: print('Preparing EMA for G with decay of {}'.format(ema_decay)) Gen_copy = module.Generator( z_dim, shared_dim, img_size, g_conv_dim, g_spectral_norm, attention, attention_after_nth_gen_block, activation_fn, conditional_strategy, num_classes, initialize=False, G_depth=G_depth, mixed_precision=mixed_precision).to(default_device) Gen_ema = ema_(Gen, Gen_copy, ema_decay, ema_start) else: Gen_copy, Gen_ema = None, None logger.info(count_parameters(Gen)) logger.info(Gen) logger.info(count_parameters(Dis)) logger.info(Dis) train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=num_workers, drop_last=True) eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=num_workers, drop_last=False) G_loss = { 'vanilla': loss_dcgan_gen, 'least_square': loss_lsgan_gen, 'hinge': loss_hinge_gen, 'wasserstein': loss_wgan_gen } D_loss = { 'vanilla': loss_dcgan_dis, 'least_square': loss_lsgan_dis, 'hinge': loss_hinge_dis, 'wasserstein': loss_wgan_dis } if optimizer == "SGD": G_optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, Gen.parameters()), g_lr, momentum=momentum, nesterov=nesterov) D_optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, Dis.parameters()), d_lr, momentum=momentum, nesterov=nesterov) elif optimizer == "RMSprop": G_optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad, Gen.parameters()), g_lr, momentum=momentum, alpha=alpha) D_optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad, Dis.parameters()), d_lr, momentum=momentum, alpha=alpha) elif optimizer == "Adam": G_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, Gen.parameters()), g_lr, [beta1, beta2], eps=1e-6) D_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, Dis.parameters()), d_lr, [beta1, beta2], eps=1e-6) elif optimizer == "AdaBelief": G_optimizer = AdaBelief(filter(lambda p: p.requires_grad, Gen.parameters()), g_lr, [beta1, beta2], eps=1e-12, rectify=False) D_optimizer = AdaBelief(filter(lambda p: p.requires_grad, Dis.parameters()), d_lr, [beta1, beta2], eps=1e-12, rectify=False) else: raise NotImplementedError if checkpoint_folder is not None: when = "current" if load_current is True else "best" if not exists(abspath(checkpoint_folder)): raise NotADirectoryError checkpoint_dir = make_checkpoint_dir(checkpoint_folder, run_name) g_checkpoint_dir = glob.glob( join(checkpoint_dir, "model=G-{when}-weights-step*.pth".format(when=when)))[0] d_checkpoint_dir = glob.glob( join(checkpoint_dir, "model=D-{when}-weights-step*.pth".format(when=when)))[0] Gen, G_optimizer, trained_seed, run_name, step, prev_ada_p = load_checkpoint( Gen, G_optimizer, g_checkpoint_dir) Dis, D_optimizer, trained_seed, run_name, step, prev_ada_p, best_step, best_fid, best_fid_checkpoint_path =\ 
load_checkpoint(Dis, D_optimizer, d_checkpoint_dir, metric=True) logger = make_logger(run_name, None) if ema: g_ema_checkpoint_dir = glob.glob( join(checkpoint_dir, "model=G_ema-{when}-weights-step*.pth".format( when=when)))[0] Gen_copy = load_checkpoint(Gen_copy, None, g_ema_checkpoint_dir, ema=True) Gen_ema.source, Gen_ema.target = Gen, Gen_copy writer = SummaryWriter(log_dir=join('./logs', run_name)) if train_config['train']: assert seed == trained_seed, "seed for sampling random numbers should be same!" logger.info('Generator checkpoint is {}'.format(g_checkpoint_dir)) logger.info('Discriminator checkpoint is {}'.format(d_checkpoint_dir)) if freeze_layers > -1: prev_ada_p, step, best_step, best_fid, best_fid_checkpoint_path = None, 0, 0, None, None else: checkpoint_dir = make_checkpoint_dir(checkpoint_folder, run_name) if n_gpus > 1: Gen = DataParallel(Gen, output_device=default_device) Dis = DataParallel(Dis, output_device=default_device) if ema: Gen_copy = DataParallel(Gen_copy, output_device=default_device) if synchronized_bn: Gen = convert_model(Gen).to(default_device) Dis = convert_model(Dis).to(default_device) if ema: Gen_copy = convert_model(Gen_copy).to(default_device) if train_config['eval']: inception_model = InceptionV3().to(default_device) if n_gpus > 1: inception_model = DataParallel(inception_model, output_device=default_device) mu, sigma = prepare_inception_moments(dataloader=eval_dataloader, generator=Gen, eval_mode=eval_type, inception_model=inception_model, splits=1, run_name=run_name, logger=logger, device=default_device) else: mu, sigma, inception_model = None, None, None train_eval = Train_Eval( run_name=run_name, best_step=best_step, dataset_name=dataset_name, eval_type=eval_type, logger=logger, writer=writer, n_gpus=n_gpus, gen_model=Gen, dis_model=Dis, inception_model=inception_model, Gen_copy=Gen_copy, Gen_ema=Gen_ema, train_dataset=train_dataset, eval_dataset=eval_dataset, train_dataloader=train_dataloader, eval_dataloader=eval_dataloader, freeze_layers=freeze_layers, conditional_strategy=conditional_strategy, pos_collected_numerator=model_config['model'] ['pos_collected_numerator'], z_dim=z_dim, num_classes=num_classes, hypersphere_dim=hypersphere_dim, d_spectral_norm=d_spectral_norm, g_spectral_norm=g_spectral_norm, G_optimizer=G_optimizer, D_optimizer=D_optimizer, batch_size=batch_size, g_steps_per_iter=model_config['optimization']['g_steps_per_iter'], d_steps_per_iter=model_config['optimization']['d_steps_per_iter'], accumulation_steps=model_config['optimization']['accumulation_steps'], total_step=total_step, G_loss=G_loss[adv_loss], D_loss=D_loss[adv_loss], contrastive_lambda=model_config['loss_function']['contrastive_lambda'], margin=model_config['loss_function']['margin'], tempering_type=model_config['loss_function']['tempering_type'], tempering_step=model_config['loss_function']['tempering_step'], start_temperature=model_config['loss_function']['start_temperature'], end_temperature=model_config['loss_function']['end_temperature'], weight_clipping_for_dis=model_config['loss_function'] ['weight_clipping_for_dis'], weight_clipping_bound=model_config['loss_function'] ['weight_clipping_bound'], gradient_penalty_for_dis=model_config['loss_function'] ['gradient_penalty_for_dis'], gradient_penalty_lambda=model_config['loss_function'] ['gradient_penalty_lambda'], deep_regret_analysis_for_dis=model_config['loss_function'] ['deep_regret_analysis_for_dis'], regret_penalty_lambda=model_config['loss_function'] ['regret_penalty_lambda'], cr=cr, 
cr_lambda=model_config['loss_function']['cr_lambda'], bcr=model_config['loss_function']['bcr'], real_lambda=model_config['loss_function']['real_lambda'], fake_lambda=model_config['loss_function']['fake_lambda'], zcr=model_config['loss_function']['zcr'], gen_lambda=model_config['loss_function']['gen_lambda'], dis_lambda=model_config['loss_function']['dis_lambda'], sigma_noise=model_config['loss_function']['sigma_noise'], diff_aug=model_config['training_and_sampling_setting']['diff_aug'], ada=model_config['training_and_sampling_setting']['ada'], prev_ada_p=prev_ada_p, ada_target=model_config['training_and_sampling_setting']['ada_target'], ada_length=model_config['training_and_sampling_setting']['ada_length'], prior=prior, truncated_factor=truncated_factor, ema=ema, latent_op=model_config['training_and_sampling_setting']['latent_op'], latent_op_rate=model_config['training_and_sampling_setting'] ['latent_op_rate'], latent_op_step=model_config['training_and_sampling_setting'] ['latent_op_step'], latent_op_step4eval=model_config['training_and_sampling_setting'] ['latent_op_step4eval'], latent_op_alpha=model_config['training_and_sampling_setting'] ['latent_op_alpha'], latent_op_beta=model_config['training_and_sampling_setting'] ['latent_op_beta'], latent_norm_reg_weight=model_config['training_and_sampling_setting'] ['latent_norm_reg_weight'], default_device=default_device, print_every=train_config['print_every'], save_every=train_config['save_every'], checkpoint_dir=checkpoint_dir, evaluate=train_config['eval'], mu=mu, sigma=sigma, best_fid=best_fid, best_fid_checkpoint_path=best_fid_checkpoint_path, mixed_precision=mixed_precision, train_config=train_config, model_config=model_config, ) if train_config['train']: step = train_eval.train(current_step=step, total_step=total_step) if train_config['eval']: is_save = train_eval.evaluation( step=step, standing_statistics=standing_statistics, standing_step=standing_step) if train_config['save_images']: train_eval.save_images(is_generate=True, png=True, npz=True, standing_statistics=standing_statistics, standing_step=standing_step) if train_config['image_visualization']: train_eval.run_image_visualization( nrow=train_config['nrow'], ncol=train_config['ncol'], standing_statistics=standing_statistics, standing_step=standing_step) if train_config['k_nearest_neighbor']: train_eval.run_nearest_neighbor( nrow=train_config['nrow'], ncol=train_config['ncol'], standing_statistics=standing_statistics, standing_step=standing_step) if train_config['interpolation']: assert architecture in [ "big_resnet", "biggan_deep" ], "Not supported except for biggan and biggan_deep." train_eval.run_linear_interpolation( nrow=train_config['nrow'], ncol=train_config['ncol'], fix_z=True, fix_y=False, standing_statistics=standing_statistics, standing_step=standing_step) train_eval.run_linear_interpolation( nrow=train_config['nrow'], ncol=train_config['ncol'], fix_z=False, fix_y=True, standing_statistics=standing_statistics, standing_step=standing_step) if train_config['frequency_analysis']: train_eval.run_frequency_analysis( num_images=len(train_dataset) // num_classes, standing_statistics=standing_statistics, standing_step=standing_step)
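# Simplified sketch of the generator EMA referenced above via ema_(Gen, Gen_copy, ema_decay, ema_start).
# SimpleEMA is an assumed re-implementation for illustration, not the project's ema_ class.
import copy
import torch

class SimpleEMA(object):
    def __init__(self, source, target, decay=0.9999, start_iter=1000):
        self.source, self.target = source, target
        self.decay, self.start_iter = decay, start_iter

    @torch.no_grad()
    def update(self, step):
        # Hard-copy the source weights before start_iter, then blend them with the decay factor.
        decay = 0.0 if step < self.start_iter else self.decay
        for p_src, p_tgt in zip(self.source.parameters(), self.target.parameters()):
            p_tgt.copy_(decay * p_tgt + (1.0 - decay) * p_src)

# Usage sketch: ema = SimpleEMA(Gen, copy.deepcopy(Gen)); call ema.update(step) after each generator update.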
def prepare_train_eval(cfgs, hdf5_path_train, **_): if cfgs.seed == -1: cudnn.benchmark, cudnn.deterministic = True, False else: fix_all_seed(cfgs.seed) cudnn.benchmark, cudnn.deterministic = False, True n_gpus, default_device = torch.cuda.device_count(), torch.cuda.current_device() if n_gpus ==1: warnings.warn('You have chosen a specific GPU. This will completely disable data parallelism.') if cfgs.disable_debugging_API: torch.autograd.set_detect_anomaly(False) check_flag_0(cfgs.batch_size, n_gpus, cfgs.freeze_layers, cfgs.checkpoint_folder, cfgs.architecture, cfgs.img_size) run_name = make_run_name(RUN_NAME_FORMAT, framework=cfgs.config_path.split('/')[3][:-5], phase='train') prev_ada_p, step, best_step, best_fid, best_fid_checkpoint_path, mu, sigma, inception_model = None, 0, 0, None, None, None, None, None logger = make_logger(run_name, None) writer = SummaryWriter(log_dir=join('./logs', run_name)) logger.info('Run name : {run_name}'.format(run_name=run_name)) logger.info(cfgs.train_configs) logger.info(cfgs.model_configs) ##### load dataset ##### logger.info('Loading train datasets...') train_dataset = LoadDataset(cfgs.dataset_name, cfgs.data_path, train=True, download=True, resize_size=cfgs.img_size, hdf5_path=hdf5_path_train, random_flip=cfgs.random_flip_preprocessing) if cfgs.reduce_train_dataset < 1.0: num_train = int(cfgs.reduce_train_dataset*len(train_dataset)) train_dataset, _ = torch.utils.data.random_split(train_dataset, [num_train, len(train_dataset) - num_train]) logger.info('Train dataset size : {dataset_size}'.format(dataset_size=len(train_dataset))) logger.info('Loading {mode} datasets...'.format(mode=cfgs.eval_type)) eval_mode = True if cfgs.eval_type == 'train' else False eval_dataset = LoadDataset(cfgs.dataset_name, cfgs.data_path, train=eval_mode, download=True, resize_size=cfgs.img_size, hdf5_path=None, random_flip=False) logger.info('Eval dataset size : {dataset_size}'.format(dataset_size=len(eval_dataset))) train_dataloader = DataLoader(train_dataset, batch_size=cfgs.batch_size, shuffle=True, pin_memory=True, num_workers=cfgs.num_workers, drop_last=True) eval_dataloader = DataLoader(eval_dataset, batch_size=cfgs.batch_size, shuffle=True, pin_memory=True, num_workers=cfgs.num_workers, drop_last=False) ##### build model ##### logger.info('Building model...') module = __import__('models.{architecture}'.format(architecture=cfgs.architecture), fromlist=['something']) logger.info('Modules are located on models.{architecture}'.format(architecture=cfgs.architecture)) Gen = module.Generator(cfgs.z_dim, cfgs.shared_dim, cfgs.img_size, cfgs.g_conv_dim, cfgs.g_spectral_norm, cfgs.attention, cfgs.attention_after_nth_gen_block, cfgs.activation_fn, cfgs.conditional_strategy, cfgs.num_classes, cfgs.g_init, cfgs.G_depth, cfgs.mixed_precision).to(default_device) Dis = module.Discriminator(cfgs.img_size, cfgs.d_conv_dim, cfgs.d_spectral_norm, cfgs.attention, cfgs.attention_after_nth_dis_block, cfgs.activation_fn, cfgs.conditional_strategy, cfgs.hypersphere_dim, cfgs.num_classes, cfgs.nonlinear_embed, cfgs.normalize_embed, cfgs.d_init, cfgs.D_depth, cfgs.mixed_precision).to(default_device) if cfgs.ema: print('Preparing EMA for G with decay of {}'.format(cfgs.ema_decay)) Gen_copy = module.Generator(cfgs.z_dim, cfgs.shared_dim, cfgs.img_size, cfgs.g_conv_dim, cfgs.g_spectral_norm, cfgs.attention, cfgs.attention_after_nth_gen_block, cfgs.activation_fn, cfgs.conditional_strategy, cfgs.num_classes, initialize=False, G_depth=cfgs.G_depth, 
mixed_precision=cfgs.mixed_precision).to(default_device) Gen_ema = ema(Gen, Gen_copy, cfgs.ema_decay, cfgs.ema_start) else: Gen_copy, Gen_ema = None, None logger.info(count_parameters(Gen)) logger.info(Gen) logger.info(count_parameters(Dis)) logger.info(Dis) ### define loss functions and optimizers G_loss = {'vanilla': loss_dcgan_gen, 'least_square': loss_lsgan_gen, 'hinge': loss_hinge_gen, 'wasserstein': loss_wgan_gen} D_loss = {'vanilla': loss_dcgan_dis, 'least_square': loss_lsgan_dis, 'hinge': loss_hinge_dis, 'wasserstein': loss_wgan_dis} if cfgs.optimizer == "SGD": G_optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, Gen.parameters()), cfgs.g_lr, momentum=cfgs.momentum, nesterov=cfgs.nesterov) D_optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, Dis.parameters()), cfgs.d_lr, momentum=cfgs.momentum, nesterov=cfgs.nesterov) elif cfgs.optimizer == "RMSprop": G_optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad, Gen.parameters()), cfgs.g_lr, momentum=cfgs.momentum, alpha=cfgs.alpha) D_optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad, Dis.parameters()), cfgs.d_lr, momentum=cfgs.momentum, alpha=cfgs.alpha) elif cfgs.optimizer == "Adam": G_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, Gen.parameters()), cfgs.g_lr, [cfgs.beta1, cfgs.beta2], eps=1e-6) D_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, Dis.parameters()), cfgs.d_lr, [cfgs.beta1, cfgs.beta2], eps=1e-6) else: raise NotImplementedError ##### load checkpoints if needed ##### if cfgs.checkpoint_folder is None: checkpoint_dir = make_checkpoint_dir(cfgs.checkpoint_folder, run_name) else: when = "current" if cfgs.load_current is True else "best" if not exists(abspath(cfgs.checkpoint_folder)): raise NotADirectoryError checkpoint_dir = make_checkpoint_dir(cfgs.checkpoint_folder, run_name) g_checkpoint_dir = glob.glob(join(checkpoint_dir,"model=G-{when}-weights-step*.pth".format(when=when)))[0] d_checkpoint_dir = glob.glob(join(checkpoint_dir,"model=D-{when}-weights-step*.pth".format(when=when)))[0] Gen, G_optimizer, trained_seed, run_name, step, prev_ada_p = load_checkpoint(Gen, G_optimizer, g_checkpoint_dir) Dis, D_optimizer, trained_seed, run_name, step, prev_ada_p, best_step, best_fid, best_fid_checkpoint_path =\ load_checkpoint(Dis, D_optimizer, d_checkpoint_dir, metric=True) logger = make_logger(run_name, None) if cfgs.ema: g_ema_checkpoint_dir = glob.glob(join(checkpoint_dir, "model=G_ema-{when}-weights-step*.pth".format(when=when)))[0] Gen_copy = load_checkpoint(Gen_copy, None, g_ema_checkpoint_dir, ema=True) Gen_ema.source, Gen_ema.target = Gen, Gen_copy writer = SummaryWriter(log_dir=join('./logs', run_name)) if cfgs.train_configs['train']: assert cfgs.seed == trained_seed, "seed for sampling random numbers should be same!" 
logger.info('Generator checkpoint is {}'.format(g_checkpoint_dir)) logger.info('Discriminator checkpoint is {}'.format(d_checkpoint_dir)) if cfgs.freeze_layers > -1 : prev_ada_p, step, best_step, best_fid, best_fid_checkpoint_path = None, 0, 0, None, None ##### wrap models with DP and convert BN to Sync BN ##### if n_gpus > 1: Gen = DataParallel(Gen, output_device=default_device) Dis = DataParallel(Dis, output_device=default_device) if cfgs.ema: Gen_copy = DataParallel(Gen_copy, output_device=default_device) if cfgs.synchronized_bn: Gen = convert_model(Gen).to(default_device) Dis = convert_model(Dis).to(default_device) if cfgs.ema: Gen_copy = convert_model(Gen_copy).to(default_device) ##### load the inception network and prepare first/secend moments for calculating FID ##### if cfgs.eval: inception_model = InceptionV3().to(default_device) if n_gpus > 1: inception_model = DataParallel(inception_model, output_device=default_device) mu, sigma = prepare_inception_moments(dataloader=eval_dataloader, generator=Gen, eval_mode=cfgs.eval_type, inception_model=inception_model, splits=1, run_name=run_name, logger=logger, device=default_device) worker = make_worker( cfgs=cfgs, run_name=run_name, best_step=best_step, logger=logger, writer=writer, n_gpus=n_gpus, gen_model=Gen, dis_model=Dis, inception_model=inception_model, Gen_copy=Gen_copy, Gen_ema=Gen_ema, train_dataset=train_dataset, eval_dataset=eval_dataset, train_dataloader=train_dataloader, eval_dataloader=eval_dataloader, G_optimizer=G_optimizer, D_optimizer=D_optimizer, G_loss=G_loss[cfgs.adv_loss], D_loss=D_loss[cfgs.adv_loss], prev_ada_p=prev_ada_p, default_device=default_device, checkpoint_dir=checkpoint_dir, mu=mu, sigma=sigma, best_fid=best_fid, best_fid_checkpoint_path=best_fid_checkpoint_path, ) if cfgs.train_configs['train']: step = worker.train(current_step=step, total_step=cfgs.total_step) if cfgs.eval: is_save = worker.evaluation(step=step, standing_statistics=cfgs.standing_statistics, standing_step=cfgs.standing_step) if cfgs.save_images: worker.save_images(is_generate=True, png=True, npz=True, standing_statistics=cfgs.standing_statistics, standing_step=cfgs.standing_step) if cfgs.image_visualization: worker.run_image_visualization(nrow=cfgs.nrow, ncol=cfgs.ncol, standing_statistics=cfgs.standing_statistics, standing_step=cfgs.standing_step) if cfgs.k_nearest_neighbor: worker.run_nearest_neighbor(nrow=cfgs.nrow, ncol=cfgs.ncol, standing_statistics=cfgs.standing_statistics, standing_step=cfgs.standing_step) if cfgs.interpolation: assert cfgs.architecture in ["big_resnet", "biggan_deep"], "Not supported except for biggan and biggan_deep." worker.run_linear_interpolation(nrow=cfgs.nrow, ncol=cfgs.ncol, fix_z=True, fix_y=False, standing_statistics=cfgs.standing_statistics, standing_step=cfgs.standing_step) worker.run_linear_interpolation(nrow=cfgs.nrow, ncol=cfgs.ncol, fix_z=False, fix_y=True, standing_statistics=cfgs.standing_statistics, standing_step=cfgs.standing_step) if cfgs.frequency_analysis: worker.run_frequency_analysis(num_images=len(train_dataset)//cfgs.num_classes, standing_statistics=cfgs.standing_statistics, standing_step=cfgs.standing_step)
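# Sketch of the checkpoint lookup convention used above: weights are stored as
# "model={G,D,G_ema}-{current,best}-weights-step*.pth" and the first glob match is loaded.
# find_checkpoint is an illustrative helper name, not part of StudioGAN.
import glob
from os.path import join

def find_checkpoint(checkpoint_dir, model="G", which="best"):
    pattern = join(checkpoint_dir, "model={m}-{w}-weights-step*.pth".format(m=model, w=which))
    matches = glob.glob(pattern)
    if not matches:
        raise FileNotFoundError("no checkpoint matching " + pattern)
    return matches[0]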
def train_framework(seed, num_workers, config_path, reduce_train_dataset, load_current, type4eval_dataset,
                    dataset_name, num_classes, img_size, data_path, architecture, conditional_strategy,
                    hypersphere_dim, nonlinear_embed, normalize_embed, g_spectral_norm, d_spectral_norm,
                    activation_fn, attention, attention_after_nth_gen_block, attention_after_nth_dis_block,
                    z_dim, shared_dim, g_conv_dim, d_conv_dim, G_depth, D_depth, optimizer, batch_size, d_lr,
                    g_lr, momentum, nesterov, alpha, beta1, beta2, total_step, adv_loss, consistency_reg,
                    g_init, d_init, random_flip_preprocessing, prior, truncated_factor, latent_op, ema,
                    ema_decay, ema_start, synchronized_bn, hdf5_path_train, train_config, model_config, **_):
    fix_all_seed(seed)
    cudnn.benchmark = False  # cudnn autotuning is not helpful for a Generator with undetermined input sizes
    cudnn.deterministic = True

    n_gpus = torch.cuda.device_count()
    default_device = torch.cuda.current_device()
    second_device = default_device if n_gpus == 1 else default_device + 1

    assert batch_size % n_gpus == 0, "batch_size should be divisible by the number of gpus"
    if n_gpus == 1:
        warnings.warn('You have chosen a specific GPU. This will completely disable data parallelism.')

    start_step, best_step, best_fid, best_fid_checkpoint_path = 0, 0, None, None

    run_name = make_run_name(RUN_NAME_FORMAT, framework=config_path.split('/')[3][:-5], phase='train')

    logger = make_logger(run_name, None)
    writer = SummaryWriter(log_dir=join('./logs', run_name))
    logger.info('Run name : {run_name}'.format(run_name=run_name))
    logger.info(train_config)
    logger.info(model_config)

    logger.info('Loading train datasets...')
    train_dataset = LoadDataset(dataset_name, data_path, train=True, download=True, resize_size=img_size,
                                hdf5_path=hdf5_path_train, consistency_reg=consistency_reg,
                                random_flip=random_flip_preprocessing)
    if reduce_train_dataset < 1.0:
        num_train = int(reduce_train_dataset * len(train_dataset))
        train_dataset, _ = torch.utils.data.random_split(train_dataset, [num_train, len(train_dataset) - num_train])
    logger.info('Train dataset size : {dataset_size}'.format(dataset_size=len(train_dataset)))

    logger.info('Loading {mode} datasets...'.format(mode=type4eval_dataset))
    eval_mode = True if type4eval_dataset == 'train' else False
    eval_dataset = LoadDataset(dataset_name, data_path, train=eval_mode, download=True, resize_size=img_size,
                               hdf5_path=None, random_flip=False)
    logger.info('Eval dataset size : {dataset_size}'.format(dataset_size=len(eval_dataset)))

    logger.info('Building model...')
    if architecture == "dcgan":
        assert img_size == 32, "Sorry, StudioGAN does not support dcgan models for generating images larger than 32x32 resolution."
module = __import__( 'models.{architecture}'.format(architecture=architecture), fromlist=['something']) logger.info('Modules are located on models.{architecture}'.format( architecture=architecture)) Gen = module.Generator(z_dim, shared_dim, img_size, g_conv_dim, g_spectral_norm, attention, attention_after_nth_gen_block, activation_fn, conditional_strategy, num_classes, synchronized_bn, g_init, G_depth).to(default_device) Dis = module.Discriminator(img_size, d_conv_dim, d_spectral_norm, attention, attention_after_nth_dis_block, activation_fn, conditional_strategy, hypersphere_dim, num_classes, nonlinear_embed, normalize_embed, synchronized_bn, d_init, D_depth).to(default_device) if ema: print('Preparing EMA for G with decay of {}'.format(ema_decay)) Gen_copy = module.Generator(z_dim, shared_dim, img_size, g_conv_dim, g_spectral_norm, attention, attention_after_nth_gen_block, activation_fn, conditional_strategy, num_classes, synchronized_bn=False, initialize=False, G_depth=G_depth).to(default_device) Gen_ema = ema_(Gen, Gen_copy, ema_decay, ema_start) else: Gen_copy, Gen_ema = None, None if n_gpus > 1: Gen = DataParallel(Gen, output_device=second_device) Dis = DataParallel(Dis, output_device=second_device) if ema: Gen_copy = DataParallel(Gen_copy, output_device=second_device) if synchronized_bn: patch_replication_callback(Gen) patch_replication_callback(Dis) logger.info(count_parameters(Gen)) logger.info(Gen) logger.info(count_parameters(Dis)) logger.info(Dis) train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=num_workers, drop_last=True) eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=num_workers, drop_last=False) G_loss = { 'vanilla': loss_dcgan_gen, 'hinge': loss_hinge_gen, 'wasserstein': loss_wgan_gen } D_loss = { 'vanilla': loss_dcgan_dis, 'hinge': loss_hinge_dis, 'wasserstein': loss_wgan_dis } if optimizer == "SGD": G_optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, Gen.parameters()), g_lr, momentum=momentum, nesterov=nesterov) D_optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, Dis.parameters()), d_lr, momentum=momentum, nesterov=nesterov) elif optimizer == "RMSprop": G_optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad, Gen.parameters()), g_lr, momentum=momentum, alpha=alpha) D_optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad, Dis.parameters()), d_lr, momentum=momentum, alpha=alpha) elif optimizer == "Adam": G_optimizer = torch.optim.Adam( filter(lambda p: p.requires_grad, Gen.parameters()), g_lr, [beta1, beta2]) D_optimizer = torch.optim.Adam( filter(lambda p: p.requires_grad, Dis.parameters()), d_lr, [beta1, beta2]) elif optimizer == "AdamP": G_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, Gen.parameters()), g_lr, betas=(beta1, beta2)) D_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, Dis.parameters()), d_lr, betas=(beta1, beta2)) else: raise NotImplementedError checkpoint_dir = make_checkpoint_dir(train_config['checkpoint_folder'], run_name) if train_config['checkpoint_folder'] is not None: when = "current" if load_current is True else "best" g_checkpoint_dir = glob.glob( join(checkpoint_dir, "model=G-{when}-weights-step*.pth".format(when=when)))[0] d_checkpoint_dir = glob.glob( join(checkpoint_dir, "model=D-{when}-weights-step*.pth".format(when=when)))[0] Gen, G_optimizer, trained_seed, run_name, start_step, best_step = load_checkpoint( Gen, G_optimizer, g_checkpoint_dir) 
Dis, D_optimizer, trained_seed, run_name, start_step, best_step, best_fid, best_fid_checkpoint_path = load_checkpoint( Dis, D_optimizer, d_checkpoint_dir, metric=True) logger = make_logger(run_name, None) if ema: g_ema_checkpoint_dir = glob.glob( join(checkpoint_dir, "model=G_ema-{when}-weights-step*.pth".format( when=when)))[0] Gen_copy = load_checkpoint(Gen_copy, None, g_ema_checkpoint_dir, ema=True) Gen_ema.source, Gen_ema.target = Gen, Gen_copy writer = SummaryWriter(log_dir=join('./logs', run_name)) assert seed == trained_seed, "seed for sampling random numbers should be same!" logger.info('Generator checkpoint is {}'.format(g_checkpoint_dir)) logger.info('Discriminator checkpoint is {}'.format(d_checkpoint_dir)) if train_config['eval']: inception_model = InceptionV3().to(default_device) inception_model = DataParallel(inception_model, output_device=second_device) mu, sigma, is_score, is_std = prepare_inception_moments_eval_dataset( dataloader=eval_dataloader, generator=Gen, eval_mode=type4eval_dataset, inception_model=inception_model, splits=10, run_name=run_name, logger=logger, device=second_device) else: mu, sigma, inception_model = None, None, None logger.info('Start training...') trainer = Trainer( run_name=run_name, best_step=best_step, dataset_name=dataset_name, type4eval_dataset=type4eval_dataset, logger=logger, writer=writer, n_gpus=n_gpus, gen_model=Gen, dis_model=Dis, inception_model=inception_model, Gen_copy=Gen_copy, Gen_ema=Gen_ema, train_dataloader=train_dataloader, eval_dataloader=eval_dataloader, conditional_strategy=conditional_strategy, z_dim=z_dim, num_classes=num_classes, hypersphere_dim=hypersphere_dim, d_spectral_norm=d_spectral_norm, g_spectral_norm=g_spectral_norm, G_optimizer=G_optimizer, D_optimizer=D_optimizer, batch_size=batch_size, g_steps_per_iter=model_config['optimization']['g_steps_per_iter'], d_steps_per_iter=model_config['optimization']['d_steps_per_iter'], accumulation_steps=model_config['optimization']['accumulation_steps'], total_step=total_step, G_loss=G_loss[adv_loss], D_loss=D_loss[adv_loss], contrastive_lambda=model_config['loss_function']['contrastive_lambda'], tempering_type=model_config['loss_function']['tempering_type'], tempering_step=model_config['loss_function']['tempering_step'], start_temperature=model_config['loss_function']['start_temperature'], end_temperature=model_config['loss_function']['end_temperature'], gradient_penalty_for_dis=model_config['loss_function'] ['gradient_penalty_for_dis'], gradient_penelty_lambda=model_config['loss_function'] ['gradient_penelty_lambda'], weight_clipping_for_dis=model_config['loss_function'] ['weight_clipping_for_dis'], weight_clipping_bound=model_config['loss_function'] ['weight_clipping_bound'], consistency_reg=consistency_reg, consistency_lambda=model_config['loss_function']['consistency_lambda'], diff_aug=model_config['training_and_sampling_setting']['diff_aug'], prior=prior, truncated_factor=truncated_factor, ema=ema, latent_op=latent_op, latent_op_rate=model_config['training_and_sampling_setting'] ['latent_op_rate'], latent_op_step=model_config['training_and_sampling_setting'] ['latent_op_step'], latent_op_step4eval=model_config['training_and_sampling_setting'] ['latent_op_step4eval'], latent_op_alpha=model_config['training_and_sampling_setting'] ['latent_op_alpha'], latent_op_beta=model_config['training_and_sampling_setting'] ['latent_op_beta'], latent_norm_reg_weight=model_config['training_and_sampling_setting'] ['latent_norm_reg_weight'], default_device=default_device, 
second_device=second_device, print_every=train_config['print_every'], save_every=train_config['save_every'], checkpoint_dir=checkpoint_dir, evaluate=train_config['eval'], mu=mu, sigma=sigma, best_fid=best_fid, best_fid_checkpoint_path=best_fid_checkpoint_path, train_config=train_config, model_config=model_config, ) if conditional_strategy == 'ContraGAN' and train_config['train']: trainer.run_ours(current_step=start_step, total_step=total_step) elif train_config['train']: trainer.run(current_step=start_step, total_step=total_step) elif train_config['eval']: is_save = trainer.evaluation(step=start_step) if train_config['k_nearest_neighbor'] > 0: trainer.K_Nearest_Neighbor( train_config['criterion_4_k_nearest_neighbor'], train_config['number_of_nearest_samples'], random.randrange(num_classes))
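# Sketch of the truncation trick implied by the truncated_factor argument used in both frameworks
# above: latent entries whose magnitude exceeds the threshold are re-sampled. This is an assumed
# reference implementation for illustration, not the project's sampling code.
import torch

def truncated_normal(batch_size, z_dim, truncated_factor=1.0, device="cpu"):
    z = torch.randn(batch_size, z_dim, device=device)
    if truncated_factor <= 0:  # the scripts above use -1.0 to mean "no truncation"
        return z
    while True:
        mask = z.abs() > truncated_factor
        if not mask.any():
            return z
        z[mask] = torch.randn(int(mask.sum().item()), device=device)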
def pretrain(data_dir, train_path, val_path, dictionary_path, dataset_limit, vocabulary_size, batch_size, max_len, epochs, clip_grads, device, layers_count, hidden_size, heads_count, d_ff, dropout_prob, log_output, checkpoint_dir, print_every, save_every, config, run_name=None, **_): random.seed(0) np.random.seed(0) torch.manual_seed(0) train_path = train_path if data_dir is None else join(data_dir, train_path) val_path = val_path if data_dir is None else join(data_dir, val_path) dictionary_path = dictionary_path if data_dir is None else join(data_dir, dictionary_path) run_name = run_name if run_name is not None else make_run_name(RUN_NAME_FORMAT, phase='pretrain', config=config) logger = make_logger(run_name, log_output) logger.info('Run name : {run_name}'.format(run_name=run_name)) logger.info(config) logger.info('Constructing dictionaries...') dictionary = IndexDictionary.load(dictionary_path=dictionary_path, vocabulary_size=vocabulary_size) vocabulary_size = len(dictionary) logger.info(f'dictionary vocabulary : {vocabulary_size} tokens') logger.info('Loading datasets...') train_dataset = PairedDataset(data_path=train_path, dictionary=dictionary, dataset_limit=dataset_limit) val_dataset = PairedDataset(data_path=val_path, dictionary=dictionary, dataset_limit=dataset_limit) logger.info('Train dataset size : {dataset_size}'.format(dataset_size=len(train_dataset))) logger.info('Building model...') model = build_model(layers_count, hidden_size, heads_count, d_ff, dropout_prob, max_len, vocabulary_size) logger.info(model) logger.info('{parameters_count} parameters'.format( parameters_count=sum([p.nelement() for p in model.parameters()]))) loss_model = MLMNSPLossModel(model) if torch.cuda.device_count() > 1: loss_model = DataParallel(loss_model, output_device=1) metric_functions = [mlm_accuracy, nsp_accuracy] train_dataloader = DataLoader( train_dataset, batch_size=batch_size, collate_fn=pretraining_collate_function) val_dataloader = DataLoader( val_dataset, batch_size=batch_size, collate_fn=pretraining_collate_function) optimizer = NoamOptimizer(model.parameters(), d_model=hidden_size, factor=2, warmup_steps=10000, betas=(0.9, 0.999), weight_decay=0.01) checkpoint_dir = make_checkpoint_dir(checkpoint_dir, run_name, config) logger.info('Start training...') trainer = Trainer( loss_model=loss_model, train_dataloader=train_dataloader, val_dataloader=val_dataloader, metric_functions=metric_functions, optimizer=optimizer, clip_grads=clip_grads, logger=logger, checkpoint_dir=checkpoint_dir, print_every=print_every, save_every=save_every, device=device ) trainer.run(epochs=epochs) return trainer
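# Sketch of the Noam learning-rate schedule behind the NoamOptimizer used above (factor=2,
# warmup_steps=10000): lr = factor * d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5).
# noam_lr is an illustrative helper; d_model corresponds to the hidden_size passed to pretrain().
def noam_lr(step, d_model, factor=2.0, warmup_steps=10000):
    step = max(step, 1)
    return factor * (d_model ** -0.5) * min(step ** -0.5, step * warmup_steps ** -1.5)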
def finetune(pretrained_checkpoint, data_dir, train_path, val_path, dictionary_path, vocabulary_size, batch_size, max_len, epochs, lr, clip_grads, device, layers_count, hidden_size, heads_count, d_ff, dropout_prob, log_output, checkpoint_dir, print_every, save_every, config, run_name=None, **_): random.seed(0) np.random.seed(0) torch.manual_seed(0) train_path = train_path if data_dir is None else join(data_dir, train_path) val_path = val_path if data_dir is None else join(data_dir, val_path) dictionary_path = dictionary_path if data_dir is None else join(data_dir, dictionary_path) run_name = run_name if run_name is not None else make_run_name(RUN_NAME_FORMAT, phase='finetune', config=config) logger = make_logger(run_name, log_output) logger.info('Run name : {run_name}'.format(run_name=run_name)) logger.info(config) logger.info('Constructing dictionaries...') dictionary = IndexDictionary.load(dictionary_path=dictionary_path, vocabulary_size=vocabulary_size) vocabulary_size = len(dictionary) logger.info(f'dictionary vocabulary : {vocabulary_size} tokens') logger.info('Loading datasets...') train_dataset = SST2IndexedDataset(data_path=train_path, dictionary=dictionary) val_dataset = SST2IndexedDataset(data_path=val_path, dictionary=dictionary) logger.info('Train dataset size : {dataset_size}'.format(dataset_size=len(train_dataset))) logger.info('Building model...') pretrained_model = build_model(layers_count, hidden_size, heads_count, d_ff, dropout_prob, max_len, vocabulary_size) pretrained_model.load_state_dict(torch.load(pretrained_checkpoint, map_location='cpu')['state_dict']) model = FineTuneModel(pretrained_model, hidden_size, num_classes=2) logger.info(model) logger.info('{parameters_count} parameters'.format( parameters_count=sum([p.nelement() for p in model.parameters()]))) loss_model = ClassificationLossModel(model) metric_functions = [classification_accuracy] train_dataloader = DataLoader( train_dataset, batch_size=batch_size, collate_fn=classification_collate_function) val_dataloader = DataLoader( val_dataset, batch_size=batch_size, collate_fn=classification_collate_function) optimizer = Adam(model.parameters(), lr=lr) checkpoint_dir = make_checkpoint_dir(checkpoint_dir, run_name, config) logger.info('Start training...') trainer = Trainer( loss_model=loss_model, train_dataloader=train_dataloader, val_dataloader=val_dataloader, metric_functions=metric_functions, optimizer=optimizer, clip_grads=clip_grads, logger=logger, checkpoint_dir=checkpoint_dir, print_every=print_every, save_every=save_every, device=device ) trainer.run(epochs=epochs) return trainer
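# Sketch of what FineTuneModel above is assumed to do: attach a linear classifier to the
# pretrained encoder's first-token representation. Simplified and hypothetical, not the
# project's actual module (the encoder output shape is an assumption).
import torch.nn as nn

class SimpleFineTuneHead(nn.Module):
    def __init__(self, encoder, hidden_size, num_classes=2):
        super(SimpleFineTuneHead, self).__init__()
        self.encoder = encoder
        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, inputs):
        hidden = self.encoder(inputs)               # assumed shape: (batch, seq_len, hidden_size)
        return self.classifier(hidden[:, 0, :])     # classify from the first token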
def train_framework(dataset_name, architecture, num_classes, img_size, data_path, eval_dataset,
                    hdf5_path_train, hdf5_path_valid, train_rate, auxiliary_classifier, projection_discriminator,
                    contrastive_training, hyper_dim, nonlinear_embed, normalize_embed, g_spectral_norm,
                    d_spectral_norm, attention, reduce_class, at_after_th_gen_block, at_after_th_dis_block,
                    leaky_relu, g_init, d_init, latent_op, consistency_reg, make_positive_aug, synchronized_bn,
                    ema, ema_decay, ema_start, adv_loss, z_dim, shared_dim, g_conv_dim, d_conv_dim, batch_size,
                    total_step, truncated_factor, prior, d_lr, g_lr, beta1, beta2, batch4metrics, config, **_):

    fix_all_seed(config['seed'])
    cudnn.benchmark = True  # cudnn autotuner; only beneficial when input sizes to the networks are fixed
    cudnn.deterministic = False
    n_gpus = torch.cuda.device_count()
    default_device = torch.cuda.current_device()
    second_device = default_device if n_gpus == 1 else default_device + 1
    assert batch_size % n_gpus == 0, "batch_size should be divisible by the number of gpus"
    if n_gpus == 1:
        warnings.warn('You have chosen a specific GPU. This will completely disable data parallelism.')

    start_step = 0
    best_val_fid, best_checkpoint_fid_path, best_val_is, best_checkpoint_is_path = None, None, None, None

    run_name = make_run_name(RUN_NAME_FORMAT,
                             framework=config['config_path'].split('/')[3][:-5],
                             phase='train',
                             config=config)

    logger = make_logger(run_name, None)
    writer = SummaryWriter(log_dir=join('./logs', run_name))
    logger.info('Run name : {run_name}'.format(run_name=run_name))
    logger.info(config)

    logger.info('Loading train datasets...')
    train_dataset = LoadDataset(dataset_name, data_path, train=True, download=True, resize_size=img_size,
                                hdf5_path=hdf5_path_train, consistency_reg=consistency_reg,
                                make_positive_aug=make_positive_aug)
    if train_rate < 1.0:
        num_train = int(train_rate * len(train_dataset))
        train_dataset, _ = torch.utils.data.random_split(train_dataset,
                                                         [num_train, len(train_dataset) - num_train])
    logger.info('Train dataset size : {dataset_size}'.format(dataset_size=len(train_dataset)))

    logger.info('Loading valid datasets...')
    valid_dataset = LoadDataset(dataset_name, data_path, train=False, download=True, resize_size=img_size,
                                hdf5_path=hdf5_path_valid)
    logger.info('Valid dataset size : {dataset_size}'.format(dataset_size=len(valid_dataset)))

    logger.info('Building model...')
    module = __import__('models.{architecture}'.format(architecture=architecture), fromlist=['something'])
    logger.info('Modules are located on models.{architecture}'.format(architecture=architecture))
    num_classes = int(reduce_class * num_classes)
    Gen = module.Generator(z_dim, shared_dim, g_conv_dim, g_spectral_norm, attention, at_after_th_gen_block,
                           leaky_relu, auxiliary_classifier, projection_discriminator, num_classes,
                           contrastive_training, synchronized_bn, g_init).to(default_device)
    Dis = module.Discriminator(d_conv_dim, d_spectral_norm, attention, at_after_th_dis_block, leaky_relu,
                               auxiliary_classifier, projection_discriminator, hyper_dim, num_classes,
                               contrastive_training, nonlinear_embed, normalize_embed, synchronized_bn,
                               d_init).to(default_device)

    if ema:
        print('Preparing EMA for G with decay of {}'.format(ema_decay))
        Gen_copy = module.Generator(z_dim, shared_dim, g_conv_dim, g_spectral_norm, attention,
                                    at_after_th_gen_block, leaky_relu, auxiliary_classifier,
                                    projection_discriminator, num_classes, contrastive_training,
                                    synchronized_bn=False, initialize=False).to(default_device)
        Gen_ema = ema_(Gen, Gen_copy, ema_decay, ema_start)
    else:
        Gen_copy, Gen_ema = None, None

    if n_gpus > 1:
        Gen = DataParallel(Gen, output_device=second_device)
        Dis = DataParallel(Dis, output_device=second_device)
        if ema:
            Gen_copy = DataParallel(Gen_copy, output_device=second_device)
        if config['synchronized_bn']:
            patch_replication_callback(Gen)
            patch_replication_callback(Dis)

    logger.info(count_parameters(Gen))
    logger.info(Gen)
    logger.info(count_parameters(Dis))
    logger.info(Dis)

    if reduce_class != 1.0:
        assert dataset_name in ["TINY_ILSVRC2012", "ILSVRC2012"], \
            "reduce_class mode can not be applied on the CIFAR10 dataset"
        n_train = int(reduce_class * len(train_dataset))
        n_valid = int(reduce_class * len(valid_dataset))
        train_weights = [1.0] * n_train + [0.0] * (len(train_dataset) - n_train)
        valid_weights = [1.0] * n_valid + [0.0] * (len(valid_dataset) - n_valid)
        train_sampler = torch.utils.data.sampler.WeightedRandomSampler(train_weights, len(train_weights))
        valid_sampler = torch.utils.data.sampler.WeightedRandomSampler(valid_weights, len(valid_weights))
        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, shuffle=False,
                                      pin_memory=True, num_workers=config['num_workers'], drop_last=True)
        evaluation_dataloader = DataLoader(valid_dataset, sampler=valid_sampler, batch_size=batch4metrics,
                                           shuffle=False, pin_memory=True, num_workers=config['num_workers'],
                                           drop_last=False)
    else:
        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True,
                                      num_workers=config['num_workers'], drop_last=True)
        evaluation_dataloader = DataLoader(valid_dataset, batch_size=batch4metrics, shuffle=True, pin_memory=True,
                                           num_workers=config['num_workers'], drop_last=False)

    G_loss = {'vanilla': loss_dcgan_gen, 'hinge': loss_hinge_gen, 'wasserstein': loss_wgan_gen}
    D_loss = {'vanilla': loss_dcgan_dis, 'hinge': loss_hinge_dis, 'wasserstein': loss_wgan_dis}

    G_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, Gen.parameters()), g_lr, [beta1, beta2])
    D_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, Dis.parameters()), d_lr, [beta1, beta2])

    checkpoint_dir = make_checkpoint_dir(config['checkpoint_folder'], run_name, config)

    if config['checkpoint_folder'] is not None:
        logger = make_logger(run_name, config['log_output_path'])
        g_checkpoint_dir = glob.glob(os.path.join(checkpoint_dir,
                                                  "model=G-step=" + str(config['step']) + "*.pth"))[0]
        d_checkpoint_dir = glob.glob(os.path.join(checkpoint_dir,
                                                  "model=D-step=" + str(config['step']) + "*.pth"))[0]
        Gen, G_optimizer, seed, run_name, start_step = load_checkpoint(Gen, G_optimizer, g_checkpoint_dir)
        Dis, D_optimizer, seed, run_name, start_step, best_val_fid, best_checkpoint_fid_path,\
            best_val_is, best_checkpoint_is_path = load_checkpoint(Dis, D_optimizer, d_checkpoint_dir, metric=True)
        if ema:
            g_ema_checkpoint_dir = glob.glob(os.path.join(checkpoint_dir,
                                                          "model=G_ema-step=" + str(config['step']) + "*.pth"))[0]
            Gen_copy = load_checkpoint(Gen_copy, None, g_ema_checkpoint_dir, ema=ema)
            Gen_ema.source, Gen_ema.target = Gen, Gen_copy

        writer = SummaryWriter(log_dir=join('./logs', run_name))
        assert config['seed'] == seed, "seed for sampling random numbers should be the same!"
        logger.info('Generator checkpoint is {}'.format(g_checkpoint_dir))
        logger.info('Discriminator checkpoint is {}'.format(d_checkpoint_dir))

    if config['eval']:
        inception_model = InceptionV3().to(default_device)
        inception_model = DataParallel(inception_model, output_device=second_device)
        mu, sigma, is_score, is_std = prepare_inception_moments_eval_dataset(dataloader=evaluation_dataloader,
                                                                             inception_model=inception_model,
                                                                             reduce_class=reduce_class,
                                                                             splits=10,
                                                                             logger=logger,
                                                                             device=second_device,
                                                                             eval_dataset=eval_dataset)
    else:
        mu, sigma, inception_model = None, None, None

    logger.info('Start training...')
    trainer = Trainer(
        run_name=run_name,
        logger=logger,
        writer=writer,
        n_gpus=n_gpus,
        gen_model=Gen,
        dis_model=Dis,
        inception_model=inception_model,
        Gen_copy=Gen_copy,
        Gen_ema=Gen_ema,
        train_dataloader=train_dataloader,
        evaluation_dataloader=evaluation_dataloader,
        G_loss=G_loss[adv_loss],
        D_loss=D_loss[adv_loss],
        auxiliary_classifier=auxiliary_classifier,
        contrastive_training=contrastive_training,
        contrastive_lambda=config['contrastive_lambda'],
        softmax_posterior=config['softmax_posterior'],
        contrastive_softmax=config['contrastive_softmax'],
        hyper_dim=config['hyper_dim'],
        tempering=config['tempering'],
        discrete_tempering=config['discrete_tempering'],
        tempering_times=config['tempering_times'],
        start_temperature=config['start_temperature'],
        end_temperature=config['end_temperature'],
        gradient_penalty_for_dis=config['gradient_penalty_for_dis'],
        lambda4lp=config['lambda4lp'],
        lambda4gp=config['lambda4gp'],
        weight_clipping_for_dis=config['weight_clipping_for_dis'],
        weight_clipping_bound=config['weight_clipping_bound'],
        latent_op=latent_op,
        latent_op_rate=config['latent_op_rate'],
        latent_op_step=config['latent_op_step'],
        latent_op_step4eval=config['latent_op_step4eval'],
        latent_op_alpha=config['latent_op_alpha'],
        latent_op_beta=config['latent_op_beta'],
        latent_norm_reg_weight=config['latent_norm_reg_weight'],
        consistency_reg=consistency_reg,
        consistency_lambda=config['consistency_lambda'],
        make_positive_aug=make_positive_aug,
        G_optimizer=G_optimizer,
        D_optimizer=D_optimizer,
        default_device=default_device,
        second_device=second_device,
        batch_size=batch_size,
        z_dim=z_dim,
        num_classes=num_classes,
        truncated_factor=truncated_factor,
        prior=prior,
        g_steps_per_iter=config['g_steps_per_iter'],
        d_steps_per_iter=config['d_steps_per_iter'],
        accumulation_steps=config['accumulation_steps'],
        lambda4ortho=config['lambda4ortho'],
        print_every=config['print_every'],
        save_every=config['save_every'],
        checkpoint_dir=checkpoint_dir,
        evaluate=config['eval'],
        mu=mu,
        sigma=sigma,
        best_val_fid=best_val_fid,
        best_checkpoint_fid_path=best_checkpoint_fid_path,
        best_val_is=best_val_is,
        best_checkpoint_is_path=best_checkpoint_is_path,
        config=config,
    )

    if contrastive_training:
        trainer.run_ours(current_step=start_step, total_step=total_step)
    else:
        trainer.run(current_step=start_step, total_step=total_step)
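
# Hedged sketch (assumption, not the repository's ema_ helper): the ema_(Gen, Gen_copy, ema_decay,
# ema_start) call above conventionally keeps Gen_copy as an exponential moving average of Gen's
# parameters for more stable evaluation samples. This standalone class only illustrates that update
# rule; the attribute names source/target mirror how Gen_ema is used above, but the method name
# update() and its signature are illustrative.
import torch

class SimpleEMA:
    def __init__(self, source, target, decay=0.9999, start_iter=0):
        self.source = source      # the generator being trained
        self.target = target      # the averaged copy used for evaluation
        self.decay = decay
        self.start_iter = start_iter

    @torch.no_grad()
    def update(self, iteration):
        # before start_iter, keep an exact copy; afterwards blend old and new weights
        decay = 0.0 if iteration < self.start_iter else self.decay
        for p_tgt, p_src in zip(self.target.parameters(), self.source.parameters()):
            p_tgt.copy_(decay * p_tgt + (1.0 - decay) * p_src)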