def calc_metrics(ctx, network_pkl, metrics, data, mirror, gpus, verbose):
    """Calculate quality metrics for previous training run or pretrained network pickle.

    Examples:

    \b
    # Previous training run: look up options automatically, save result to JSONL file.
    python calc_metrics.py --metrics=pr50k3_full \\
        --network=~/training-runs/00000-ffhq10k-res64-auto1/network-snapshot-000000.pkl

    \b
    # Pre-trained network pickle: specify dataset explicitly, print result to stdout.
    python calc_metrics.py --metrics=fid50k_full --data=~/datasets/ffhq.zip --mirror=1 \\
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl

    Available metrics:

    \b
      ADA paper:
        fid50k_full  Frechet inception distance against the full dataset.
        kid50k_full  Kernel inception distance against the full dataset.
        pr50k3_full  Precision and recall against the full dataset.
        is50k        Inception score for CIFAR-10.

    \b
      StyleGAN and StyleGAN2 papers:
        fid50k       Frechet inception distance against 50k real images.
        kid50k       Kernel inception distance against 50k real images.
        pr50k3       Precision and recall against 50k real images.
        ppl2_wend    Perceptual path length in W at path endpoints against full image.
        ppl_zfull    Perceptual path length in Z for full paths against cropped image.
        ppl_wfull    Perceptual path length in W for full paths against cropped image.
        ppl_zend     Perceptual path length in Z at path endpoints against cropped image.
        ppl_wend     Perceptual path length in W at path endpoints against cropped image.
    """
    dnnlib.util.Logger(should_flush=True)

    # Validate arguments.
    args = dnnlib.EasyDict(metrics=metrics, num_gpus=gpus, network_pkl=network_pkl, verbose=verbose)
    if not all(metric_main.is_valid_metric(metric) for metric in args.metrics):
        ctx.fail('\n'.join(['--metrics can only contain the following values:'] +
                           metric_main.list_valid_metrics()))
    if not args.num_gpus >= 1:
        ctx.fail('--gpus must be at least 1')

    # Load network.
    if not dnnlib.util.is_url(network_pkl, allow_file_urls=True) and not os.path.isfile(network_pkl):
        ctx.fail('--network must point to a file or URL')
    if args.verbose:
        print(f'Loading network from "{network_pkl}"...')
    with dnnlib.util.open_url(network_pkl, verbose=args.verbose) as f:
        network_dict = legacy.load_network_pkl(f)
        args.G = network_dict['G_ema']  # subclass of torch.nn.Module

    # Initialize dataset options: explicit --data wins, otherwise fall back to
    # the training-set options stored inside the pickle.
    if data is not None:
        args.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=data)
    elif network_dict['training_set_kwargs'] is not None:
        args.dataset_kwargs = dnnlib.EasyDict(network_dict['training_set_kwargs'])
    else:
        ctx.fail('Could not look up dataset options; please specify --data')

    # Finalize dataset options. NOTE(review): this variant uses separate
    # height/width resolutions (img_resolution_h / img_resolution_w) — assumes
    # a generator patched for non-square images; confirm against training.networks.
    args.dataset_kwargs.resolution_h = args.G.img_resolution_h
    args.dataset_kwargs.resolution_w = args.G.img_resolution_w
    args.dataset_kwargs.use_labels = (args.G.c_dim != 0)
    if mirror is not None:
        args.dataset_kwargs.xflip = mirror

    # Print dataset options.
    if args.verbose:
        print('Dataset options:')
        print(json.dumps(args.dataset_kwargs, indent=2))

    # Locate run dir: only when the pickle sits next to its training_options.json,
    # which lets results be appended to that run's JSONL log.
    args.run_dir = None
    if os.path.isfile(network_pkl):
        pkl_dir = os.path.dirname(network_pkl)
        if os.path.isfile(os.path.join(pkl_dir, 'training_options.json')):
            args.run_dir = pkl_dir

    # Launch processes: run inline for a single GPU, spawn one process per GPU otherwise.
    if args.verbose:
        print('Launching processes...')
    torch.multiprocessing.set_start_method('spawn')
    with tempfile.TemporaryDirectory() as temp_dir:
        if args.num_gpus == 1:
            subprocess_fn(rank=0, args=args, temp_dir=temp_dir)
        else:
            torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus)
def setup_training_loop_kwargs(
    # General options (not included in desc).
    gpus=None,  # Number of GPUs: <int>, default = 1 gpu
    snap=None,  # Snapshot interval: <int>, default = 50 ticks
    metrics=None,  # List of metric names: [], ['fid50k_full'] (default), ...
    seed=None,  # Random seed: <int>, default = 0
    # Dataset.
    data=None,  # Training dataset (required): <path>
    cond=None,  # Train conditional model based on dataset labels: <bool>, default = False
    subset=None,  # Train with only N images: <int>, default = all
    mirror=None,  # Augment dataset with x-flips: <bool>, default = False
    # Base config.
    cfg=None,  # Base config: 'auto' (default), 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar'
    gamma=None,  # Override R1 gamma: <float>
    kimg=None,  # Override training duration: <int>
    batch=None,  # Override batch size: <int>
    # Discriminator augmentation.
    aug=None,  # Augmentation mode: 'ada' (default), 'noaug', 'fixed'
    p=None,  # Specify p for 'fixed' (required): <float>
    target=None,  # Override ADA target for 'ada': <float>, default = depends on aug
    augpipe=None,  # Augmentation pipeline: 'blit', 'geom', 'color', 'filter', 'noise', 'cutout', 'bg', 'bgc' (default), ..., 'bgcfnc'
    # Transfer learning.
    resume=None,  # Load previous network: 'noresume' (default), 'ffhq256', 'ffhq512', 'ffhq1024', 'celebahq256', 'lsundog256', <file>, <url>
    freezed=None,  # Freeze-D: <int>, default = 0 discriminator layers
    # Performance options (not included in desc).
    fp32=None,  # Disable mixed-precision training: <bool>, default = False
    nhwc=None,  # Use NHWC memory format with FP16: <bool>, default = False
    nobench=None,  # Disable cuDNN benchmarking: <bool>, default = False
    workers=None,  # Override number of DataLoader workers: <int>, default = 3
    **kwargs,
):
    """Translate command-line options into training-loop keyword arguments.

    Returns:
        (desc, args): `desc` is a short run description used to name the output
        directory; `args` is a dnnlib.EasyDict of keyword arguments for the
        training loop.

    Raises:
        UserError: on any invalid option value.
    """
    # Fix: only warn when there actually are unrecognized arguments; the
    # original printed an empty dict on every call.
    if kwargs:
        print("Unrecognized arguments:", kwargs)
    args = dnnlib.EasyDict()

    # ------------------------------------------
    # General options: gpus, snap, metrics, seed
    # ------------------------------------------

    if gpus is None:
        gpus = 1
    assert isinstance(gpus, int)
    if not (gpus >= 1 and gpus & (gpus - 1) == 0):
        raise UserError("--gpus must be a power of two")
    args.num_gpus = gpus

    if snap is None:
        snap = 50
    assert isinstance(snap, int)
    if snap < 1:
        raise UserError("--snap must be at least 1")
    # NOTE(review): `snap` is validated above but intentionally(?) not used —
    # snapshot intervals are pinned to 1/5 ticks here. Confirm this override
    # is still wanted; the sibling variant of this function uses `snap` for both.
    args.image_snapshot_ticks = 1
    args.network_snapshot_ticks = 5

    if metrics is None:
        metrics = ["fid50k_full"]
    assert isinstance(metrics, list)
    if not all(metric_main.is_valid_metric(metric) for metric in metrics):
        raise UserError(
            "\n".join(["--metrics can only contain the following values:"] + metric_main.list_valid_metrics())
        )
    args.metrics = metrics

    if seed is None:
        seed = 0
    assert isinstance(seed, int)
    args.random_seed = seed

    # -----------------------------------
    # Dataset: data, cond, subset, mirror
    # -----------------------------------

    assert data is not None
    assert isinstance(data, str)
    args.training_set_kwargs = dnnlib.EasyDict(
        class_name="training.dataset.ImageFolderDataset", path=data, use_labels=True, max_size=None, xflip=False
    )
    args.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=3, prefetch_factor=2)
    try:
        # Instantiate the dataset once to resolve resolution/labels/size, then
        # discard it — only the kwargs are carried into the training loop.
        training_set = dnnlib.util.construct_class_by_name(
            **args.training_set_kwargs
        )  # subclass of training.dataset.Dataset
        args.training_set_kwargs.resolution = training_set.resolution  # be explicit about resolution
        args.training_set_kwargs.use_labels = training_set.has_labels  # be explicit about labels
        args.training_set_kwargs.max_size = len(training_set)  # be explicit about dataset size
        desc = training_set.name
        del training_set  # conserve memory
    except IOError as err:
        raise UserError(f"--data: {err}")

    if cond is None:
        cond = False
    assert isinstance(cond, bool)
    if cond:
        if not args.training_set_kwargs.use_labels:
            raise UserError("--cond=True requires labels specified in dataset.json")
        desc += "-cond"
    else:
        args.training_set_kwargs.use_labels = False

    if subset is not None:
        assert isinstance(subset, int)
        if not 1 <= subset <= args.training_set_kwargs.max_size:
            raise UserError(f"--subset must be between 1 and {args.training_set_kwargs.max_size}")
        desc += f"-subset{subset}"
        if subset < args.training_set_kwargs.max_size:
            args.training_set_kwargs.max_size = subset
            args.training_set_kwargs.random_seed = args.random_seed

    if mirror is None:
        mirror = False
    assert isinstance(mirror, bool)
    if mirror:
        desc += "-mirror"
        args.training_set_kwargs.xflip = True

    # ------------------------------------
    # Base config: cfg, gamma, kimg, batch
    # ------------------------------------

    if cfg is None:
        cfg = "auto"
    assert isinstance(cfg, str)
    desc += f"-{cfg}"

    cfg_specs = {
        "auto": dict(
            ref_gpus=-1, kimg=25000, mb=-1, mbstd=-1, fmaps=-1, lrate=-1, gamma=-1, ema=-1, ramp=0.05, map=2
        ),  # Populated dynamically based on resolution and GPU count.
        "stylegan2": dict(
            ref_gpus=8, kimg=25000, mb=32, mbstd=4, fmaps=1, lrate=0.002, gamma=10, ema=10, ramp=None, map=8
        ),  # Uses mixed-precision, unlike the original StyleGAN2.
        "paper256": dict(
            ref_gpus=8, kimg=25000, mb=64, mbstd=8, fmaps=0.5, lrate=0.0025, gamma=1, ema=20, ramp=None, map=8
        ),
        "paper512": dict(
            ref_gpus=8, kimg=25000, mb=64, mbstd=8, fmaps=1, lrate=0.0025, gamma=0.5, ema=20, ramp=None, map=8
        ),
        "paper1024": dict(
            ref_gpus=8, kimg=25000, mb=32, mbstd=4, fmaps=1, lrate=0.002, gamma=2, ema=10, ramp=None, map=8
        ),
        "cifar": dict(
            ref_gpus=2, kimg=100000, mb=64, mbstd=32, fmaps=1, lrate=0.0025, gamma=0.01, ema=500, ramp=0.05, map=2
        ),
        "wav": dict(ref_gpus=2, kimg=1000, mb=8, mbstd=4, fmaps=1, gamma=10, lrate=0.002, ema=10, ramp=None, map=8),
    }

    assert cfg in cfg_specs
    spec = dnnlib.EasyDict(cfg_specs[cfg])
    if cfg == "auto":
        desc += f"{gpus:d}"
        spec.ref_gpus = gpus
        res = args.training_set_kwargs.resolution
        spec.mb = max(min(gpus * min(4096 // res, 32), 64), gpus)  # keep gpu memory consumption at bay
        spec.mbstd = min(
            spec.mb // gpus, 4
        )  # other hyperparams behave more predictably if mbstd group size remains fixed
        spec.fmaps = 1 if res >= 512 else 0.5
        spec.lrate = 0.002 if res >= 1024 else 0.0025
        spec.gamma = 0.0002 * (res ** 2) / spec.mb  # heuristic formula
        spec.ema = spec.mb * 10 / 32

    args.G_kwargs = dnnlib.EasyDict(
        class_name="training.networks.Generator",
        z_dim=512,
        w_dim=512,
        mapping_kwargs=dnnlib.EasyDict(),
        synthesis_kwargs=dnnlib.EasyDict(),
    )
    args.D_kwargs = dnnlib.EasyDict(
        class_name="training.networks.Discriminator",
        block_kwargs=dnnlib.EasyDict(),
        mapping_kwargs=dnnlib.EasyDict(),
        epilogue_kwargs=dnnlib.EasyDict(),
    )
    args.G_kwargs.synthesis_kwargs.channel_base = args.D_kwargs.channel_base = int(spec.fmaps * 32768)
    args.G_kwargs.synthesis_kwargs.channel_max = args.D_kwargs.channel_max = 512
    args.G_kwargs.mapping_kwargs.num_layers = spec.map
    args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 3  # enable mixed-precision training
    args.G_kwargs.synthesis_kwargs.conv_clamp = (
        args.D_kwargs.conv_clamp
    ) = 256  # clamp activations to avoid float16 overflow
    args.D_kwargs.epilogue_kwargs.mbstd_group_size = spec.mbstd

    args.G_opt_kwargs = dnnlib.EasyDict(class_name="torch.optim.Adam", lr=spec.lrate, betas=[0, 0.99], eps=1e-8)
    args.D_opt_kwargs = dnnlib.EasyDict(class_name="torch.optim.Adam", lr=spec.lrate, betas=[0, 0.99], eps=1e-8)
    args.loss_kwargs = dnnlib.EasyDict(class_name="training.loss.StyleGAN2Loss", r1_gamma=spec.gamma)

    args.total_kimg = spec.kimg
    # NOTE(review): batch sizes are hardcoded here instead of spec.mb /
    # spec.mb // ref_gpus — presumably a deliberate memory-saving override;
    # confirm, since --batch can still override both below.
    args.batch_size = 8
    args.batch_gpu = 4
    args.ema_kimg = spec.ema
    args.ema_rampup = spec.ramp
    if cfg == "cifar" or cfg == "wav":
        args.loss_kwargs.pl_weight = 0  # disable path length regularization
        args.loss_kwargs.style_mixing_prob = 0  # disable style mixing
        # args.D_kwargs.architecture = "orig"  # disable residual skip connections

    if gamma is not None:
        assert isinstance(gamma, float)
        if not gamma >= 0:
            raise UserError("--gamma must be non-negative")
        desc += f"-gamma{gamma:g}"
        args.loss_kwargs.r1_gamma = gamma

    if kimg is not None:
        assert isinstance(kimg, int)
        if not kimg >= 1:
            raise UserError("--kimg must be at least 1")
        desc += f"-kimg{kimg:d}"
        args.total_kimg = kimg

    if batch is not None:
        assert isinstance(batch, int)
        if not (batch >= 1 and batch % gpus == 0):
            raise UserError("--batch must be at least 1 and divisible by --gpus")
        desc += f"-batch{batch}"
        args.batch_size = batch
        args.batch_gpu = batch // gpus

    # ---------------------------------------------------
    # Discriminator augmentation: aug, p, target, augpipe
    # ---------------------------------------------------

    if aug is None:
        aug = "ada"
    else:
        assert isinstance(aug, str)
        desc += f"-{aug}"

    if aug == "ada":
        args.ada_target = 0.6
    elif aug == "noaug":
        pass
    elif aug == "fixed":
        if p is None:
            raise UserError(f"--aug={aug} requires specifying --p")
    else:
        raise UserError(f"--aug={aug} not supported")

    if p is not None:
        assert isinstance(p, float)
        if aug != "fixed":
            raise UserError("--p can only be specified with --aug=fixed")
        if not 0 <= p <= 1:
            raise UserError("--p must be between 0 and 1")
        desc += f"-p{p:g}"
        args.augment_p = p

    if target is not None:
        assert isinstance(target, float)
        if aug != "ada":
            raise UserError("--target can only be specified with --aug=ada")
        if not 0 <= target <= 1:
            raise UserError("--target must be between 0 and 1")
        desc += f"-target{target:g}"
        args.ada_target = target

    assert augpipe is None or isinstance(augpipe, str)
    if augpipe is None:
        augpipe = "bgc"
    else:
        if aug == "noaug":
            raise UserError("--augpipe cannot be specified with --aug=noaug")
        desc += f"-{augpipe}"

    augpipe_specs = {
        "blit": dict(xflip=1, rotate90=1, xint=1),
        "geom": dict(scale=1, rotate=1, aniso=1, xfrac=1),
        "color": dict(brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
        "filter": dict(imgfilter=1),
        "noise": dict(noise=1),
        "cutout": dict(cutout=1),
        "bg": dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1),
        "bgc": dict(
            xflip=1,
            rotate90=1,
            xint=1,
            scale=1,
            rotate=1,
            aniso=1,
            xfrac=1,
            brightness=1,
            contrast=1,
            lumaflip=1,
            hue=1,
            saturation=1,
        ),
        "bgcf": dict(
            xflip=1,
            rotate90=1,
            xint=1,
            scale=1,
            rotate=1,
            aniso=1,
            xfrac=1,
            brightness=1,
            contrast=1,
            lumaflip=1,
            hue=1,
            saturation=1,
            imgfilter=1,
        ),
        "bgcfn": dict(
            xflip=1,
            rotate90=1,
            xint=1,
            scale=1,
            rotate=1,
            aniso=1,
            xfrac=1,
            brightness=1,
            contrast=1,
            lumaflip=1,
            hue=1,
            saturation=1,
            imgfilter=1,
            noise=1,
        ),
        "bgcfnc": dict(
            xflip=1,
            rotate90=1,
            xint=1,
            scale=1,
            rotate=1,
            aniso=1,
            xfrac=1,
            brightness=1,
            contrast=1,
            lumaflip=1,
            hue=1,
            saturation=1,
            imgfilter=1,
            noise=1,
            cutout=1,
        ),
    }

    assert augpipe in augpipe_specs
    if aug != "noaug":
        args.augment_kwargs = dnnlib.EasyDict(class_name="training.augment.AugmentPipe", **augpipe_specs[augpipe])

    # ----------------------------------
    # Transfer learning: resume, freezed
    # ----------------------------------

    resume_specs = {
        "ffhq256": "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res256-mirror-paper256-noaug.pkl",
        "ffhq512": "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res512-mirror-stylegan2-noaug.pkl",
        "ffhq1024": "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res1024-mirror-stylegan2-noaug.pkl",
        "celebahq256": "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/celebahq-res256-mirror-paper256-kimg100000-ada-target0.5.pkl",
        "lsundog256": "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/lsundog-res256-paper256-kimg100000-noaug.pkl",
    }

    assert resume is None or isinstance(resume, str)
    if resume is None:
        resume = "noresume"
    elif resume == "noresume":
        desc += "-noresume"
    elif resume in resume_specs:
        desc += f"-resume{resume}"
        args.resume_pkl = resume_specs[resume]  # predefined url
    else:
        desc += "-resumecustom"
        args.resume_pkl = resume  # custom path or url

    if resume != "noresume":
        args.ada_kimg = 100  # make ADA react faster at the beginning
        args.ema_rampup = None  # disable EMA rampup

    if freezed is not None:
        assert isinstance(freezed, int)
        if not freezed >= 0:
            raise UserError("--freezed must be non-negative")
        desc += f"-freezed{freezed:d}"
        args.D_kwargs.block_kwargs.freeze_layers = freezed

    # -------------------------------------------------
    # Performance options: fp32, nhwc, nobench, workers
    # -------------------------------------------------

    if fp32 is None:
        fp32 = False
    assert isinstance(fp32, bool)
    if fp32:
        args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 0
        args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = None

    if nhwc is None:
        nhwc = False
    assert isinstance(nhwc, bool)
    if nhwc:
        args.G_kwargs.synthesis_kwargs.fp16_channels_last = args.D_kwargs.block_kwargs.fp16_channels_last = True

    if nobench is None:
        nobench = False
    assert isinstance(nobench, bool)
    if nobench:
        args.cudnn_benchmark = False

    if workers is not None:
        assert isinstance(workers, int)
        if not workers >= 1:
            raise UserError("--workers must be at least 1")
        args.data_loader_kwargs.num_workers = workers

    return desc, args
def setup_training_loop_kwargs(
    # General options (not included in desc).
    gpus=None,  # Number of GPUs: <int>, default = 1 gpu
    snap=None,  # Snapshot interval: <int>, default = 50 ticks
    metrics=None,  # List of metric names: [], ['fid50k_full'] (default), ...
    comet_api_key=None,  # API key for Comet.ml: <str>, default = '' (don't use comet)
    comet_name=None,  # Experiment name for Comet.ml: <str>, default = '' (default Comet naming)
    gradient_clipping=None,  # None or max global gradient L2 norm: <float>, default = None (no clipping)
    seed=None,  # Random seed: <int>, default = 0
    # Train dataset.
    data=None,  # Training dataset (required): <path>
    cond=None,  # Train conditional model based on dataset labels: <bool>, default = False
    subset=None,  # Train with only N images: <int>, default = all
    mirror=None,  # Augment dataset with x-flips: <bool>, default = False
    # Validation dataset.
    testdata=None,  # Validation dataset (optional): <path>
    # Base config.
    cfg=None,  # Base config: 'low_shot' (default), 'custom_low_shot', 'auto', 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar'
    gamma=None,  # Override R1 gamma: <float>
    kimg=None,  # Override training duration: <int>
    batch=None,  # Override batch size: <int>
    # Discriminator augmentation.
    diffaugment=None,  # Comma-separated list of DiffAugment policy, default = 'color,translation,cutout'
    diffaugment_placement=None,  # Comma-separated list of DiffAugment applying placement, default = 'real,generated,backprop'
    aug=None,  # Augmentation mode: 'ada' (default), 'noaug', 'fixed'
    p=None,  # Specify p for 'fixed' (required): <float>
    target=None,  # Override ADA target for 'ada': <float>, default = depends on aug
    augpipe=None,  # Augmentation pipeline: 'blit', 'geom', 'color', 'filter', 'noise', 'cutout', 'bg', 'bgc' (default), ..., 'bgcfnc'
    # Transfer learning.
    resume=None,  # Load previous network: 'noresume' (default), 'ffhq256', 'ffhq512', 'ffhq1024', 'celebahq256', 'lsundog256', <file>, <url>
    freezed=None,  # Freeze-D: <int>, default = 0 discriminator layers
    # Performance options (not included in desc).
    fp32=None,  # Disable mixed-precision training: <bool>, default = False
    nhwc=None,  # Use NHWC memory format with FP16: <bool>, default = False
    allow_tf32=None,  # Allow PyTorch to use TF32 for matmul and convolutions: <bool>, default = False
    nobench=None,  # Disable cuDNN benchmarking: <bool>, default = False
    workers=None,  # Override number of DataLoader workers: <int>, default = 3
):
    """Translate command-line options into training-loop keyword arguments.

    This variant adds Comet.ml logging, gradient clipping, an optional
    validation dataset, DiffAugment support, and low-shot base configs on
    top of the stock StyleGAN2-ADA option setup.

    Returns:
        (desc, args): `desc` is a short run description used to name the
        output directory; `args` is a dnnlib.EasyDict of keyword arguments
        for the training loop.

    Raises:
        UserError: on any invalid option value.
    """
    args = dnnlib.EasyDict()

    # -----------------------------------------------------------
    # General options: gpus, snap, metrics, comet, clipping, seed
    # -----------------------------------------------------------

    if gpus is None:
        gpus = 1
    assert isinstance(gpus, int)
    if not (gpus >= 1 and gpus & (gpus - 1) == 0):
        raise UserError('--gpus must be a power of two')
    args.num_gpus = gpus

    if snap is None:
        snap = 50
    assert isinstance(snap, int)
    if snap < 1:
        raise UserError('--snap must be at least 1')
    args.image_snapshot_ticks = snap
    args.network_snapshot_ticks = snap

    if metrics is None:
        metrics = ['fid50k_full']
    assert isinstance(metrics, list)
    if not all(metric_main.is_valid_metric(metric) for metric in metrics):
        raise UserError(
            '\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
    args.metrics = metrics

    if comet_api_key is None:
        comet_api_key = ''
    if comet_name is None:
        comet_name = ''
    if comet_api_key:
        # Degrade gracefully when comet_ml could not be imported at module load.
        if comet_ml is None:
            print('comet_ml is not imported! Proceeding without comet.ml logging')
            args.comet_api_key = ''
            args.comet_experiment_key = ''
        else:
            args.comet_api_key = comet_api_key
            experiment = comet_ml.Experiment(api_key=comet_api_key,
                                             project_name='Sirius SOTA GANs',
                                             auto_output_logging='simple',
                                             auto_log_co2=False,
                                             auto_metric_logging=False,
                                             auto_param_logging=False)
            experiment.set_name(comet_name)
            # Only the experiment key is carried forward; worker processes can
            # re-attach to the experiment with it.
            args.comet_experiment_key = experiment.get_key()
    else:
        args.comet_api_key = ''
        args.comet_experiment_key = ''

    args.gradient_clipping = gradient_clipping

    if seed is None:
        seed = 0
    assert isinstance(seed, int)
    args.random_seed = seed

    # -----------------------------------
    # Dataset: data, cond, subset, mirror
    # -----------------------------------

    assert data is not None
    assert isinstance(data, str)
    args.training_set_kwargs = dnnlib.EasyDict(
        class_name='training.dataset.ImageFolderDataset', path=data, use_labels=True, max_size=None, xflip=False)
    args.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=3, prefetch_factor=2)
    try:
        # Instantiate the dataset(s) once to resolve resolution/labels/size,
        # then discard them — only the kwargs travel to the training loop.
        training_set = dnnlib.util.construct_class_by_name(
            **args.training_set_kwargs)  # subclass of training.dataset.Dataset
        args.training_set_kwargs.resolution = training_set.resolution  # be explicit about resolution
        args.training_set_kwargs.use_labels = training_set.has_labels  # be explicit about labels
        args.training_set_kwargs.max_size = len(training_set)  # be explicit about dataset size
        desc = training_set.name
        if testdata is not None:
            args.validation_set_kwargs = dnnlib.EasyDict(
                class_name='training.dataset.ImageFolderDataset',
                path=testdata, use_labels=True, max_size=None, xflip=False)
            validation_set = dnnlib.util.construct_class_by_name(**args.validation_set_kwargs)
            # NOTE(review): resolution and labels are copied from the TRAINING
            # set (not validation_set) — presumably to force both to agree;
            # confirm that is intended.
            args.validation_set_kwargs.resolution = training_set.resolution  # be explicit about resolution
            args.validation_set_kwargs.use_labels = training_set.has_labels  # be explicit about labels
            args.validation_set_kwargs.max_size = len(validation_set)  # be explicit about dataset size
            del validation_set
        else:
            args.validation_set_kwargs = {}
        del training_set  # conserve memory
    except IOError as err:
        raise UserError(f'--data: {err}')

    if cond is None:
        cond = False
    assert isinstance(cond, bool)
    if cond:
        if not args.training_set_kwargs.use_labels:
            raise UserError('--cond=True requires labels specified in dataset.json')
        desc += '-cond'
    else:
        args.training_set_kwargs.use_labels = False
        if args.validation_set_kwargs != {}:
            args.validation_set_kwargs.use_labels = False

    if subset is not None:
        assert isinstance(subset, int)
        if not 1 <= subset <= args.training_set_kwargs.max_size:
            raise UserError(f'--subset must be between 1 and {args.training_set_kwargs.max_size}')
        desc += f'-subset{subset}'
        if subset < args.training_set_kwargs.max_size:
            args.training_set_kwargs.max_size = subset
            args.training_set_kwargs.random_seed = args.random_seed

    if mirror is None:
        mirror = False
    assert isinstance(mirror, bool)
    if mirror:
        desc += '-mirror'
        args.training_set_kwargs.xflip = True

    # ------------------------------------
    # Base config: cfg, gamma, kimg, batch
    # ------------------------------------

    if cfg is None:
        cfg = 'low_shot'
    assert isinstance(cfg, str)
    desc += f'-{cfg}'

    cfg_specs = {
        'custom_low_shot': dict(ref_gpus=-1, kimg=500, mb=8, mbstd=4, fmaps=0.5, lrate=5e-6, gamma=10,
                                ema=10, ramp=None, map=2, snap=10),
        'low_shot': dict(ref_gpus=-1, kimg=300, mb=8, mbstd=4, fmaps=1, lrate=0.002, gamma=10,
                         ema=10, ramp=None, map=2, snap=10),
        'auto': dict(ref_gpus=-1, kimg=25000, mb=-1, mbstd=-1, fmaps=-1, lrate=-1, gamma=-1, ema=-1,
                     ramp=0.05, map=2),  # Populated dynamically based on resolution and GPU count.
        'stylegan2': dict(ref_gpus=8, kimg=25000, mb=32, mbstd=4, fmaps=1, lrate=0.002, gamma=10,
                          ema=10, ramp=None, map=8),  # Uses mixed-precision, unlike the original StyleGAN2.
        'paper256': dict(ref_gpus=8, kimg=25000, mb=64, mbstd=8, fmaps=0.5, lrate=0.0025, gamma=1,
                         ema=20, ramp=None, map=8),
        'paper512': dict(ref_gpus=8, kimg=25000, mb=64, mbstd=8, fmaps=1, lrate=0.0025, gamma=0.5,
                         ema=20, ramp=None, map=8),
        'paper1024': dict(ref_gpus=8, kimg=25000, mb=32, mbstd=4, fmaps=1, lrate=0.002, gamma=2,
                          ema=10, ramp=None, map=8),
        'cifar': dict(ref_gpus=2, kimg=100000, mb=64, mbstd=32, fmaps=1, lrate=0.0025, gamma=0.01,
                      ema=500, ramp=0.05, map=2),
    }

    assert cfg in cfg_specs
    spec = dnnlib.EasyDict(cfg_specs[cfg])
    if cfg == 'auto':
        desc += f'{gpus:d}'
        spec.ref_gpus = gpus
        res = args.training_set_kwargs.resolution
        spec.mb = max(min(gpus * min(4096 // res, 32), 64), gpus)  # keep gpu memory consumption at bay
        spec.mbstd = min(
            spec.mb // gpus, 4
        )  # other hyperparams behave more predictably if mbstd group size remains fixed
        spec.fmaps = 1 if res >= 512 else 0.5
        spec.lrate = 0.002 if res >= 1024 else 0.0025
        spec.gamma = 0.0002 * (res**2) / spec.mb  # heuristic formula
        spec.ema = spec.mb * 10 / 32

    # Low-shot configs declare ref_gpus=-1, meaning "use however many GPUs we have".
    if spec.ref_gpus < 0:
        spec.ref_gpus = gpus

    # Low-shot configs also carry their own snapshot interval, overriding --snap.
    if spec.get('snap', None):
        args.image_snapshot_ticks = args.network_snapshot_ticks = spec.snap

    args.G_kwargs = dnnlib.EasyDict(class_name='training.networks.Generator',
                                    z_dim=512,
                                    w_dim=512,
                                    mapping_kwargs=dnnlib.EasyDict(),
                                    synthesis_kwargs=dnnlib.EasyDict())
    args.D_kwargs = dnnlib.EasyDict(
        class_name='training.networks.Discriminator',
        block_kwargs=dnnlib.EasyDict(),
        mapping_kwargs=dnnlib.EasyDict(),
        epilogue_kwargs=dnnlib.EasyDict())
    args.G_kwargs.synthesis_kwargs.channel_base = args.D_kwargs.channel_base = int(spec.fmaps * 32768)
    args.G_kwargs.synthesis_kwargs.channel_max = args.D_kwargs.channel_max = 512
    args.G_kwargs.mapping_kwargs.num_layers = spec.map
    args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 4  # enable mixed-precision training
    args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = 256  # clamp activations to avoid float16 overflow
    args.D_kwargs.epilogue_kwargs.mbstd_group_size = spec.mbstd

    args.G_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=spec.lrate, betas=[0, 0.99], eps=1e-8)
    args.D_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=spec.lrate, betas=[0, 0.99], eps=1e-8)
    args.loss_kwargs = dnnlib.EasyDict(class_name='training.loss.StyleGAN2Loss', r1_gamma=spec.gamma)

    args.total_kimg = spec.kimg
    args.batch_size = spec.mb
    args.batch_gpu = spec.mb // spec.ref_gpus
    args.ema_kimg = spec.ema
    args.ema_rampup = spec.ramp
    if cfg == 'cifar':
        args.loss_kwargs.pl_weight = 0  # disable path length regularization
        args.loss_kwargs.style_mixing_prob = 0  # disable style mixing
        args.D_kwargs.architecture = 'orig'  # disable residual skip connections

    if gamma is not None:
        assert isinstance(gamma, float)
        if not gamma >= 0:
            raise UserError('--gamma must be non-negative')
        desc += f'-gamma{gamma:g}'
        args.loss_kwargs.r1_gamma = gamma

    if kimg is not None:
        assert isinstance(kimg, int)
        if not kimg >= 1:
            raise UserError('--kimg must be at least 1')
        desc += f'-kimg{kimg:d}'
        args.total_kimg = kimg

    if batch is not None:
        assert isinstance(batch, int)
        if not (batch >= 1 and batch % gpus == 0):
            raise UserError('--batch must be at least 1 and divisible by --gpus')
        desc += f'-batch{batch}'
        args.batch_size = batch
        args.batch_gpu = batch // gpus

    # ---------------------------------------------------
    # Discriminator augmentation: aug, p, target, augpipe
    # ---------------------------------------------------

    if diffaugment is None:
        diffaugment = 'color,translation,cutout'

    # DiffAugment and ADA are mutually exclusive: a non-empty policy forces aug='noaug'.
    if diffaugment:
        args.loss_kwargs.diffaugment = diffaugment
        aug = 'noaug'
        desc += '-{}'.format(diffaugment.replace(',', '-'))
    elif aug is None:
        aug = 'ada'
    else:
        assert isinstance(aug, str)
        desc += f'-{aug}'

    if diffaugment_placement is None:
        if diffaugment:
            diffaugment_placement = 'real,generated,backprop'
        else:
            diffaugment_placement = ''
    if diffaugment_placement:
        args.loss_kwargs.diffaugment_placement = diffaugment_placement

    if aug == 'ada':
        args.ada_target = 0.6
    elif aug == 'noaug':
        pass
    elif aug == 'fixed':
        if p is None:
            raise UserError(f'--aug={aug} requires specifying --p')
    else:
        raise UserError(f'--aug={aug} not supported')

    if p is not None:
        assert isinstance(p, float)
        if aug != 'fixed':
            raise UserError('--p can only be specified with --aug=fixed')
        if not 0 <= p <= 1:
            raise UserError('--p must be between 0 and 1')
        desc += f'-p{p:g}'
        args.augment_p = p

    if target is not None:
        assert isinstance(target, float)
        if aug != 'ada':
            raise UserError('--target can only be specified with --aug=ada')
        if not 0 <= target <= 1:
            raise UserError('--target must be between 0 and 1')
        desc += f'-target{target:g}'
        args.ada_target = target

    assert augpipe is None or isinstance(augpipe, str)
    if augpipe is None:
        augpipe = 'bgc'
    else:
        if aug == 'noaug':
            raise UserError('--augpipe cannot be specified with --aug=noaug')
        desc += f'-{augpipe}'

    augpipe_specs = {
        'blit': dict(xflip=1, rotate90=1, xint=1),
        'geom': dict(scale=1, rotate=1, aniso=1, xfrac=1),
        'color': dict(brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
        'filter': dict(imgfilter=1),
        'noise': dict(noise=1),
        'cutout': dict(cutout=1),
        'bg': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1),
        'bgc': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1,
                    brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
        'bgcf': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1,
                     brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1),
        'bgcfn': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1,
                      brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1),
        'bgcfnc': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1,
                       brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1, cutout=1),
    }

    assert augpipe in augpipe_specs
    if aug != 'noaug':
        args.augment_kwargs = dnnlib.EasyDict(
            class_name='training.augment.AugmentPipe', **augpipe_specs[augpipe])

    # ----------------------------------
    # Transfer learning: resume, freezed
    # ----------------------------------

    resume_specs = {
        'ffhq256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res256-mirror-paper256-noaug.pkl',
        'ffhq512': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res512-mirror-stylegan2-noaug.pkl',
        'ffhq1024': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res1024-mirror-stylegan2-noaug.pkl',
        'celebahq256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/celebahq-res256-mirror-paper256-kimg100000-ada-target0.5.pkl',
        'lsundog256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/lsundog-res256-paper256-kimg100000-noaug.pkl',
    }

    assert resume is None or isinstance(resume, str)
    if resume is None:
        resume = 'noresume'
    elif resume == 'noresume':
        desc += '-noresume'
    elif resume in resume_specs:
        desc += f'-resume{resume}'
        args.resume_pkl = resume_specs[resume]  # predefined url
    else:
        desc += '-resumecustom'
        args.resume_pkl = resume  # custom path or url

    if resume != 'noresume':
        args.ada_kimg = 100  # make ADA react faster at the beginning
        args.ema_rampup = None  # disable EMA rampup

    if freezed is not None:
        assert isinstance(freezed, int)
        if not freezed >= 0:
            raise UserError('--freezed must be non-negative')
        desc += f'-freezed{freezed:d}'
        args.D_kwargs.block_kwargs.freeze_layers = freezed

    # -------------------------------------------------
    # Performance options: fp32, nhwc, nobench, workers
    # -------------------------------------------------

    if fp32 is None:
        fp32 = False
    assert isinstance(fp32, bool)
    if fp32:
        args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 0
        args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = None

    if nhwc is None:
        nhwc = False
    assert isinstance(nhwc, bool)
    if nhwc:
        args.G_kwargs.synthesis_kwargs.fp16_channels_last = args.D_kwargs.block_kwargs.fp16_channels_last = True

    if nobench is None:
        nobench = False
    assert isinstance(nobench, bool)
    if nobench:
        args.cudnn_benchmark = False

    if allow_tf32 is None:
        allow_tf32 = False
    assert isinstance(allow_tf32, bool)
    if allow_tf32:
        args.allow_tf32 = True

    if workers is not None:
        assert isinstance(workers, int)
        if not workers >= 1:
            raise UserError('--workers must be at least 1')
        args.data_loader_kwargs.num_workers = workers

    return desc, args
def setup_training_loop_kwargs(
    # General options (not included in desc).
    gpus       = None, # Number of GPUs: <int>, default = 1 gpu
    snap       = None, # Snapshot interval: <int>, default = 50 ticks
    metrics    = None, # List of metric names: [], ['fid50k_full'] (default), ...
    seed       = None, # Random seed: <int>, default = 0

    # Dataset.
    data       = None, # Training dataset (required): <path>
    cond       = None, # Train conditional model based on dataset labels: <bool>, default = False
    subset     = None, # Train with only N images: <int>, default = all
    mirror     = None, # Augment dataset with x-flips: <bool>, default = False

    # Base config.
    cfg        = None, # Base config: 'auto' (default), 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar'
    kimg       = None, # Override training duration: <int>
    batch      = None, # Override batch size: <int>

    # Discriminator augmentation.
    aug        = None, # Augmentation mode: 'ada' (default), 'noaug', 'fixed'
    p          = None, # Specify p for 'fixed' (required): <float>
    target     = None, # Override ADA target for 'ada': <float>, default = depends on aug
    augpipe    = None, # Augmentation pipeline: 'blit', 'geom', 'color', 'filter', 'noise', 'cutout', 'bg', 'bgc' (default), ..., 'bgcfnc'

    # Transfer learning.
    resume     = None, # Load previous network: 'noresume' (default), 'ffhq256', 'ffhq512', 'ffhq1024', 'celebahq256', 'lsundog256', <file>, <url>
    freezed    = None, # Freeze-D: <int>, default = 0 discriminator layers

    # Performance options (not included in desc).
    fp32       = None, # Disable mixed-precision training: <bool>, default = False
    nhwc       = None, # Use NHWC memory format with FP16: <bool>, default = False
    nobench    = None, # Disable cuDNN benchmarking: <bool>, default = False
    workers    = None, # Override number of DataLoader workers: <int>, default = 3

    # Let's add the ability to overwrite g_fmaps/d_fmaps
    g_fmaps    = None, # Override generator feature-map multiplier: <float>
    d_fmaps    = None, # Override discriminator feature-map multiplier: <float>

    # Specify z_dim and w_dim
    z_dim      = 512,  # Generator latent (Z) dimensionality.
    w_dim      = 512,  # Generator intermediate latent (W) dimensionality.

    hydra_cfg_name = None, # Name of the Hydra config (under ../configs) to compose.
):
    """Validate command-line options and build the training-loop configuration.

    Returns a tuple `(desc, args)` where `desc` is a short human-readable run
    descriptor (used for the output directory name) and `args` is a
    dnnlib.EasyDict holding every keyword argument for the training loop
    (network kwargs, dataset kwargs, optimizer kwargs, augmentation, etc.).

    Raises UserError for invalid user-supplied option values; plain `assert`
    is used for programmer errors (wrong types from the CLI layer).
    """
    args = dnnlib.EasyDict()

    # Hydra config.
    # NOTE(review): `initialize` mutates Hydra's global state — calling this
    # function twice in one process will fail; confirm callers only invoke once.
    initialize(config_path="../configs", job_name="test_app")
    hydra_cfg = compose(config_name=hydra_cfg_name)

    # ------------------------------------------
    # General options: gpus, snap, metrics, seed
    # ------------------------------------------

    if gpus is None:
        gpus = 1
    assert isinstance(gpus, int)
    if not (gpus >= 1 and gpus & (gpus - 1) == 0):
        raise UserError('--gpus must be a power of two')
    args.num_gpus = gpus

    if snap is None:
        snap = 50
    assert isinstance(snap, int)
    if snap < 1:
        raise UserError('--snap must be at least 1')
    args.image_snapshot_ticks = snap
    args.network_snapshot_ticks = snap

    if metrics is None:
        metrics = ['fid50k_full']
    assert isinstance(metrics, list)
    if not all(metric_main.is_valid_metric(metric) for metric in metrics):
        raise UserError('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
    args.metrics = metrics

    if seed is None:
        seed = 0
    assert isinstance(seed, int)
    args.random_seed = seed

    # -----------------------------------
    # Dataset: data, cond, subset, mirror
    # -----------------------------------

    assert data is not None
    assert isinstance(data, str)
    # Dataset class is selected by the Hydra config: the multi-embedding
    # variant carries extra per-image embeddings alongside labels.
    if hydra_cfg.dataset.use_multi_embs:
        args.training_set_kwargs = dnnlib.EasyDict(
            class_name='training.dataset.ImageFolderDatasetWithMultiEmbs',
            path=data, use_labels=True, max_size=None, xflip=False)
    else:
        args.training_set_kwargs = dnnlib.EasyDict(
            class_name='training.dataset.ImageFolderDataset',
            path=data, use_labels=True, max_size=None, xflip=False)
    args.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=3, prefetch_factor=2)
    try:
        # Instantiate the dataset once up front to validate the path and to
        # pin down resolution / labels / size explicitly in the kwargs.
        training_set = dnnlib.util.construct_class_by_name(**args.training_set_kwargs) # subclass of training.dataset.Dataset
        args.training_set_kwargs.resolution = training_set.resolution # be explicit about resolution
        args.training_set_kwargs.use_labels = training_set.has_labels # be explicit about labels
        args.training_set_kwargs.max_size = len(training_set) # be explicit about dataset size
        desc = training_set.name
        del training_set # conserve memory
    except IOError as err:
        raise UserError('--data: {}'.format(err))

    if cond is None:
        cond = False
    assert isinstance(cond, bool)
    if cond:
        if not args.training_set_kwargs.use_labels:
            raise UserError('--cond=True requires labels specified in dataset.json')
        desc += '-cond'
    else:
        args.training_set_kwargs.use_labels = False

    # if emb_cond:
    #     desc += '-emb_cond'
    # else:
    #     args.training_set_kwargs.emb_cond = False

    if subset is not None:
        assert isinstance(subset, int)
        if not 1 <= subset <= args.training_set_kwargs.max_size:
            raise UserError('--subset must be between 1 and {}'.format(args.training_set_kwargs.max_size))
        desc += '-subset{}'.format(subset)
        if subset < args.training_set_kwargs.max_size:
            args.training_set_kwargs.max_size = subset
            args.training_set_kwargs.random_seed = args.random_seed

    if mirror is None:
        mirror = False
    assert isinstance(mirror, bool)
    if mirror:
        desc += '-mirror'
        args.training_set_kwargs.xflip = True

    # ------------------------------------
    # Base config: cfg, gamma, kimg, batch
    # ------------------------------------

    if cfg is None:
        cfg = 'auto'
    assert isinstance(cfg, str)
    desc += '-{}'.format(cfg)

    # Per-config hyperparameter presets; 'auto' entries (-1) are filled in
    # below based on resolution and GPU count.
    cfg_specs = {
        'auto':      dict(ref_gpus=-1, kimg=25000,  mb=-1, mbstd=-1, g_fmaps=-1,  d_fmaps=-1,  lrate=-1,     gamma=-1,   ema=-1,  ramp=0.05, map=2), # Populated dynamically based on resolution and GPU count.
        'stylegan2': dict(ref_gpus=8,  kimg=25000,  mb=32, mbstd=4,  g_fmaps=1,   d_fmaps=1,   lrate=0.002,  gamma=10,   ema=10,  ramp=None, map=8), # Uses mixed-precision, unlike the original StyleGAN2.
        'paper256':  dict(ref_gpus=8,  kimg=25000,  mb=64, mbstd=8,  g_fmaps=0.5, d_fmaps=0.5, lrate=0.0025, gamma=1,    ema=20,  ramp=None, map=8),
        'paper512':  dict(ref_gpus=8,  kimg=25000,  mb=64, mbstd=8,  g_fmaps=1,   d_fmaps=1,   lrate=0.0025, gamma=0.5,  ema=20,  ramp=None, map=8),
        'paper1024': dict(ref_gpus=8,  kimg=25000,  mb=32, mbstd=4,  g_fmaps=1,   d_fmaps=1,   lrate=0.002,  gamma=2,    ema=10,  ramp=None, map=8),
        'cifar':     dict(ref_gpus=2,  kimg=100000, mb=64, mbstd=32, g_fmaps=1,   d_fmaps=1,   lrate=0.0025, gamma=0.01, ema=500, ramp=0.05, map=2),
    }

    assert cfg in cfg_specs
    spec = dnnlib.EasyDict(cfg_specs[cfg])
    if cfg == 'auto':
        desc += '{:d}'.format(gpus)
        spec.ref_gpus = gpus
        res = args.training_set_kwargs.resolution
        # NOTE(review): the trailing //2 halves the reference batch size
        # relative to the standard auto heuristic — presumably a deliberate
        # memory-saving tweak; confirm.
        spec.mb = max(min(gpus * min(4096 // res, 32), 64), gpus)//2 # keep gpu memory consumption at bay
        spec.mbstd = min(spec.mb // gpus, 4) # other hyperparams behave more predictably if mbstd group size remains fixed
        spec.g_fmaps = 1 if res >= 512 else 0.5
        spec.d_fmaps = 1 if res >= 512 else 0.5
        spec.lrate = 0.002 if res >= 1024 else 0.0025
        spec.gamma = 0.0002 * (res ** 2) / spec.mb # heuristic formula
        spec.ema = spec.mb * 10 / 32

    # Optional explicit overrides for the feature-map multipliers.
    if not (g_fmaps is None):
        spec.g_fmaps = g_fmaps
    if not (d_fmaps is None):
        spec.d_fmaps = d_fmaps

    # Hydra config may override R1 gamma.
    if 'gamma' in hydra_cfg.loss_kwargs:
        spec.gamma = hydra_cfg.loss_kwargs.gamma

    args.G_kwargs = dnnlib.EasyDict(
        class_name='training.networks.Generator', z_dim=z_dim, w_dim=w_dim,
        mapping_kwargs=dnnlib.EasyDict(), synthesis_kwargs=dnnlib.EasyDict())
    args.D_kwargs = dnnlib.EasyDict(
        class_name='training.networks.Discriminator', block_kwargs=dnnlib.EasyDict(),
        mapping_kwargs=dnnlib.EasyDict(), epilogue_kwargs=dnnlib.EasyDict())
    args.G_kwargs.synthesis_kwargs.channel_base = int(spec.g_fmaps * 32768)
    args.G_kwargs.synthesis_kwargs.channel_max = args.D_kwargs.channel_max = 512
    args.G_kwargs.mapping_kwargs.num_layers = spec.map
    args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 4 # enable mixed-precision training
    args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = 256 # clamp activations to avoid float16 overflow
    # Pass the Hydra synthesis section through as a plain dict (picklable).
    args.G_kwargs.synthesis_cfg = OmegaConf.to_container(hydra_cfg.synthesis)
    args.D_kwargs.channel_base = int(spec.d_fmaps * 32768)
    args.D_kwargs.epilogue_kwargs.mbstd_group_size = spec.mbstd

    # Multi-embedding datasets use a modulated discriminator with a small
    # mapping network for the embeddings.
    if hydra_cfg.dataset.use_multi_embs:
        args.D_kwargs.mapping_kwargs.num_layers = 2
        args.D_kwargs.block_kwargs.is_modulated = True
    else:
        args.D_kwargs.block_kwargs.is_modulated = False

    args.G_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=spec.lrate, betas=[0,0.99], eps=1e-8)
    args.D_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=spec.lrate, betas=[0,0.99], eps=1e-8)
    args.loss_kwargs = dnnlib.EasyDict(class_name='training.loss.StyleGAN2Loss', r1_gamma=spec.gamma)

    args.total_kimg = spec.kimg
    args.batch_size = spec.mb
    args.batch_gpu = spec.mb // spec.ref_gpus
    args.ema_kimg = spec.ema
    args.ema_rampup = spec.ramp

    if cfg == 'cifar':
        args.loss_kwargs.pl_weight = 0 # disable path length regularization
        args.loss_kwargs.style_mixing_prob = 0 # disable style mixing
        args.D_kwargs.architecture = 'orig' # disable residual skip connections

    # Hydra loss overrides take precedence over the cfg preset (including 'cifar').
    if 'style_mixing_prob' in hydra_cfg.loss_kwargs:
        args.loss_kwargs.style_mixing_prob = hydra_cfg.loss_kwargs.style_mixing_prob
    if 'pl_weight' in hydra_cfg.loss_kwargs:
        args.loss_kwargs.pl_weight = hydra_cfg.loss_kwargs.pl_weight
    args.loss_kwargs.synthesis_cfg = OmegaConf.to_container(hydra_cfg.synthesis)

    if kimg is not None:
        assert isinstance(kimg, int)
        if not kimg >= 1:
            raise UserError('--kimg must be at least 1')
        desc += '-kimg{:d}'.format(kimg)
        args.total_kimg = kimg

    if batch is not None:
        assert isinstance(batch, int)
        if not (batch >= 1 and batch % gpus == 0):
            raise UserError('--batch must be at least 1 and divisible by --gpus')
        desc += '-batch{}'.format(batch)
        args.batch_size = batch
        args.batch_gpu = batch // gpus

    # ---------------------------------------------------
    # Discriminator augmentation: aug, p, target, augpipe
    # ---------------------------------------------------

    if aug is None:
        aug = 'ada'
    else:
        assert isinstance(aug, str)
        desc += '-{}'.format(aug)

    if aug == 'ada':
        args.ada_target = 0.6
    elif aug == 'noaug':
        pass
    elif aug == 'fixed':
        if p is None:
            raise UserError('--aug={} requires specifying --p'.format(aug))
    else:
        raise UserError('--aug={} not supported'.format(aug))

    if p is not None:
        assert isinstance(p, float)
        if aug != 'fixed':
            raise UserError('--p can only be specified with --aug=fixed')
        if not 0 <= p <= 1:
            raise UserError('--p must be between 0 and 1')
        desc += '-p{:g}'.format(p)
        args.augment_p = p

    if target is not None:
        assert isinstance(target, float)
        if aug != 'ada':
            raise UserError('--target can only be specified with --aug=ada')
        if not 0 <= target <= 1:
            raise UserError('--target must be between 0 and 1')
        desc += '-target{:g}'.format(target)
        args.ada_target = target

    assert augpipe is None or isinstance(augpipe, str)
    if augpipe is None:
        augpipe = 'bgc'
    else:
        if aug == 'noaug':
            raise UserError('--augpipe cannot be specified with --aug=noaug')
        desc += '-{}'.format(augpipe)

    # Each entry enables a group of augmentations in training.augment.AugmentPipe.
    augpipe_specs = {
        'blit':   dict(xflip=1, rotate90=1, xint=1),
        'geom':   dict(scale=1, rotate=1, aniso=1, xfrac=1),
        'color':  dict(brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
        'filter': dict(imgfilter=1),
        'noise':  dict(noise=1),
        'cutout': dict(cutout=1),
        'bg':     dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1),
        'bgc':    dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
        'bgcf':   dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1),
        'bgcfn':  dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1),
        'bgcfnc': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1, cutout=1),
    }

    assert augpipe in augpipe_specs
    if aug != 'noaug':
        args.augment_kwargs = dnnlib.EasyDict(class_name='training.augment.AugmentPipe', **augpipe_specs[augpipe])

    # ----------------------------------
    # Transfer learning: resume, freezed
    # ----------------------------------

    # Shorthand names for the official pretrained transfer-learning source nets.
    resume_specs = {
        'ffhq256':     'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res256-mirror-paper256-noaug.pkl',
        'ffhq512':     'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res512-mirror-stylegan2-noaug.pkl',
        'ffhq1024':    'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res1024-mirror-stylegan2-noaug.pkl',
        'celebahq256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/celebahq-res256-mirror-paper256-kimg100000-ada-target0.5.pkl',
        'lsundog256':  'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/lsundog-res256-paper256-kimg100000-noaug.pkl',
    }

    assert resume is None or isinstance(resume, str)
    if resume is None:
        resume = 'noresume'
    elif resume == 'noresume':
        desc += '-noresume'
    elif resume in resume_specs:
        desc += '-resume{}'.format(resume)
        args.resume_pkl = resume_specs[resume] # predefined url
    else:
        desc += '-resumecustom'
        args.resume_pkl = resume # custom path or url

    if resume != 'noresume':
        args.ada_kimg = 100 # make ADA react faster at the beginning
        args.ema_rampup = None # disable EMA rampup

    if freezed is not None:
        assert isinstance(freezed, int)
        if not freezed >= 0:
            raise UserError('--freezed must be non-negative')
        desc += '-freezed{:d}'.format(freezed)
        args.D_kwargs.block_kwargs.freeze_layers = freezed

    # -------------------------------------------------
    # Performance options: fp32, nhwc, nobench, workers
    # -------------------------------------------------

    if fp32 is None:
        fp32 = False
    assert isinstance(fp32, bool)
    if fp32:
        # Revert the mixed-precision settings applied above.
        args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 0
        args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = None

    if nhwc is None:
        nhwc = False
    assert isinstance(nhwc, bool)
    if nhwc:
        args.G_kwargs.synthesis_kwargs.fp16_channels_last = args.D_kwargs.block_kwargs.fp16_channels_last = True

    if nobench is None:
        nobench = False
    assert isinstance(nobench, bool)
    if nobench:
        args.cudnn_benchmark = False

    if workers is not None:
        assert isinstance(workers, int)
        if not workers >= 1:
            raise UserError('--workers must be at least 1')
        args.data_loader_kwargs.num_workers = workers

    return desc, args
def main(**kwargs):
    """Entry point: assemble the ProjectedGAN training config and launch training.

    Takes CLI options as keyword arguments (opts), builds the main config
    dict `c` (networks, optimizers, dataset, loss, discriminator backbone),
    launches training, and finally exits with code 3 if the run stopped
    before `total_kimg` so an outer wrapper can restart it.
    """
    # Initialize config.
    opts = dnnlib.EasyDict(kwargs) # Command line arguments.
    c = dnnlib.EasyDict() # Main config dict.
    c.G_kwargs = dnnlib.EasyDict(class_name=None, z_dim=64, w_dim=128, mapping_kwargs=dnnlib.EasyDict())
    c.G_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', betas=[0,0.99], eps=1e-8)
    c.D_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', betas=[0,0.99], eps=1e-8)
    c.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, prefetch_factor=2)

    # Training set.
    c.training_set_kwargs, dataset_name = init_dataset_kwargs(data=opts.data)
    if opts.cond and not c.training_set_kwargs.use_labels:
        raise click.ClickException('--cond=True requires labels specified in dataset.json')
    c.training_set_kwargs.use_labels = opts.cond
    c.training_set_kwargs.xflip = opts.mirror

    # Hyperparameters & settings.
    c.num_gpus = opts.gpus
    c.batch_size = opts.batch
    c.batch_gpu = opts.batch_gpu or opts.batch // opts.gpus # per-GPU batch defaults to an even split
    c.G_kwargs.channel_base = opts.cbase
    c.G_kwargs.channel_max = opts.cmax
    c.G_kwargs.mapping_kwargs.num_layers = 2
    c.G_opt_kwargs.lr = (0.002 if opts.cfg == 'stylegan2' else 0.0025) if opts.glr is None else opts.glr
    c.D_opt_kwargs.lr = opts.dlr
    c.metrics = opts.metrics
    c.total_kimg = opts.kimg
    c.kimg_per_tick = opts.tick
    c.image_snapshot_ticks = c.network_snapshot_ticks = opts.snap
    c.random_seed = c.training_set_kwargs.random_seed = opts.seed
    c.data_loader_kwargs.num_workers = opts.workers

    # Sanity checks.
    if c.batch_size % c.num_gpus != 0:
        raise click.ClickException('--batch must be a multiple of --gpus')
    if c.batch_size % (c.num_gpus * c.batch_gpu) != 0:
        raise click.ClickException('--batch must be a multiple of --gpus times --batch-gpu')
    if any(not metric_main.is_valid_metric(metric) for metric in c.metrics):
        raise click.ClickException('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))

    # Base configuration.
    # NOTE(review): `use_separable_discs` is only bound in the two branches
    # below — any other --cfg value would raise NameError later; presumably
    # the CLI restricts --cfg to these choices. Confirm at the click option.
    c.ema_kimg = c.batch_size * 10 / 32
    if opts.cfg == 'stylegan2':
        c.G_kwargs.class_name = 'pg_modules.networks_stylegan2.Generator'
        c.G_kwargs.fused_modconv_default = 'inference_only' # Speed up training by using regular convolutions instead of grouped convolutions.
        use_separable_discs = True

    elif opts.cfg in ['fastgan', 'fastgan_lite']:
        c.G_kwargs = dnnlib.EasyDict(class_name='pg_modules.networks_fastgan.Generator', cond=opts.cond, synthesis_kwargs=dnnlib.EasyDict())
        c.G_kwargs.synthesis_kwargs.lite = (opts.cfg == 'fastgan_lite')
        c.G_opt_kwargs.lr = c.D_opt_kwargs.lr = 0.0002
        use_separable_discs = False

    # Resume.
    if opts.resume is not None:
        c.resume_pkl = opts.resume
        c.ema_rampup = None # Disable EMA rampup.

    # Restart.
    c.restart_every = opts.restart_every

    # Performance-related toggles.
    if opts.fp32:
        c.G_kwargs.num_fp16_res = 0
        c.G_kwargs.conv_clamp = None
    if opts.nobench:
        c.cudnn_benchmark = False

    # Description string.
    desc = f'{opts.cfg:s}-{dataset_name:s}-gpus{c.num_gpus:d}-batch{c.batch_size:d}'
    if opts.desc is not None:
        desc += f'-{opts.desc}'

    # Projected and Multi-Scale Discriminators
    c.loss_kwargs = dnnlib.EasyDict(class_name='training.loss.ProjectedGANLoss')
    c.D_kwargs = dnnlib.EasyDict(
        class_name='pg_modules.discriminator.ProjectedDiscriminator',
        diffaug=True,
        interp224=(c.training_set_kwargs.resolution < 224), # upsample small images for the pretrained backbone
        backbone_kwargs=dnnlib.EasyDict(),
    )
    c.D_kwargs.backbone_kwargs.cout = 64
    c.D_kwargs.backbone_kwargs.expand = True
    c.D_kwargs.backbone_kwargs.proj_type = 2
    c.D_kwargs.backbone_kwargs.num_discs = 4
    c.D_kwargs.backbone_kwargs.separable = use_separable_discs
    c.D_kwargs.backbone_kwargs.cond = opts.cond

    # Launch.
    launch_training(c=c, desc=desc, outdir=opts.outdir, dry_run=opts.dry_run)

    # Check for restart
    # NOTE(review): relies on launch_training populating c.run_dir — confirm.
    last_snapshot = misc.get_ckpt_path(c.run_dir)
    if os.path.isfile(last_snapshot):
        # get current number of training images
        with dnnlib.util.open_url(last_snapshot) as f:
            cur_nimg = legacy.load_network_pkl(f)['progress']['cur_nimg'].item()
        if (cur_nimg//1000) < c.total_kimg:
            # Training has not yet reached total_kimg: signal the outer
            # wrapper (e.g. a shell loop / scheduler) to restart this run.
            print('Restart: exit with code 3')
            exit(3)
def setup_config(run_dir, **args):
    """Build the full GANformer experiment configuration from CLI options.

    Returns an EasyDict `config` containing the training-loop options plus
    `cG`/`cD` network configs, `loss_args`, `dataset_args` and `vis_args`.
    The config is also saved to <run_dir>/training_options.json; if --reload
    is set and that file already exists, it is loaded and returned as-is
    (command-line arguments are then ignored).

    Fix vs. previous revision: the --latent-size divisibility error message
    referenced an undefined name `k`, so triggering that validation raised
    NameError instead of reporting the intended error. It now interpolates
    `args.components_num`.
    """
    args  = EasyDict(args)              # command-line options
    train = EasyDict(run_dir=run_dir)   # training loop options
    vis   = EasyDict(run_dir=run_dir)   # visualization loop options

    if args.reload:
        config_fn = os.path.join(run_dir, "training_options.json")
        if os.path.exists(config_fn):
            # Load config from the experiment's existing file (and so ignore command-line arguments)
            with open(config_fn, "rt") as f:
                config = json.load(f)
            return config
        misc.log(
            f"Warning: --reload is set for a new experiment {args.expname}," +
            f" but configuration file to reload from {config_fn} doesn't exist.", "red")

    # GANformer and baselines default settings
    # ----------------------------------------------------------------------------
    # nset() only fills in values the user did not set explicitly.
    if args.ganformer_default:
        task = args.dataset
        nset(args, "mirror_augment", task in ["cityscapes", "ffhq"])
        nset(args, "transformer", True)
        nset(args, "components_num", {"clevr": 8}.get(task, 16))
        nset(args, "latent_size", {"clevr": 128}.get(task, 512))
        nset(args, "normalize", "layer")
        nset(args, "integration", "mul")
        nset(args, "kmeans", True)
        nset(args, "use_pos", True)
        nset(args, "mapping_ltnt2ltnt", task != "clevr")
        nset(args, "style", task != "clevr")
        nset(args, "g_arch", "resnet")
        nset(args, "mapping_resnet", True)
        gammas = {"ffhq": 10, "cityscapes": 20, "clevr": 40, "bedrooms": 100}
        nset(args, "gamma", gammas.get(task, 10))

    if args.baseline == "GAN":
        nset(args, "style", False)
        nset(args, "latent_stem", True)

    ## k-GAN and SAGAN are not currently supported in the pytorch version.
    ## See the TF version for implementation of these baselines!
    # if args.baseline == "SAGAN":
    #     nset(args, "style", False)
    #     nset(args, "latent_stem", True)
    #     nset(args, "g_img2img", 5)
    # if args.baseline == "kGAN":
    #     nset(args, "kgan", True)
    #     nset(args, "merge_layer", 5)
    #     nset(args, "merge_type", "softmax")
    #     nset(args, "components_num", 8)

    # General setup
    # ----------------------------------------------------------------------------
    # If the flag is specified without arguments (--arg), set to True
    for arg in ["cuda_bench", "allow_tf32", "keep_samples", "style", "local_noise"]:
        if args[arg] is None:
            args[arg] = True

    if not any([args.train, args.eval, args.vis]):
        misc.log("Warning: None of --train, --eval or --vis are provided. Therefore, we only print network shapes", "red")
    for arg in ["train", "eval", "vis", "last_snapshots"]:
        cset(train, arg, args[arg])

    # NOTE(review): if args.gpus == "" then num_gpus is never bound and the
    # power-of-two check below raises NameError — confirm the CLI always
    # supplies a non-empty --gpus string.
    if args.gpus != "":
        num_gpus = len(args.gpus.split(","))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
    if not (num_gpus >= 1 and num_gpus & (num_gpus - 1) == 0):
        misc.error("Number of GPUs must be a power of two")
    args.num_gpus = num_gpus

    # CUDA settings
    for arg in ["batch_size", "batch_gpu", "allow_tf32"]:
        cset(train, arg, args[arg])
    cset(train, "cudnn_benchmark", args.cuda_bench)

    # Data setup
    # ----------------------------------------------------------------------------
    # For bedrooms, we choose the most common ratio in the
    # dataset and crop the other images into that ratio.
    ratios = {"clevr": 0.75, "bedrooms": 188 / 256, "cityscapes": 0.5, "ffhq": 1.0}
    args.ratio = args.ratio or ratios.get(args.dataset, 1.0)
    args.crop_ratio = 0.5 if args.resolution > 256 and args.ratio < 0.5 else None

    args.printname = args.expname
    for arg in ["total_kimg", "printname"]:
        cset(train, arg, args[arg])

    dataset_args = EasyDict(
        class_name="training.dataset.ImageFolderDataset",
        path=f"{args.data_dir}/{args.dataset}",
        max_items=args.train_images_num,
        resolution=args.resolution,
        ratio=args.ratio,
        mirror_augment=args.mirror_augment)
    dataset_args.loader_args = EasyDict(
        num_workers=args.num_threads,
        pin_memory=True,
        prefetch_factor=2)

    # Optimization setup
    # ----------------------------------------------------------------------------
    cG = set_net("Generator", ["mapping", "synthesis"], args.g_lr, 4)
    cD = set_net("Discriminator", ["mapping", "block", "epilogue"], args.d_lr, 16)
    cset([cG, cD], "crop_ratio", args.crop_ratio)

    mbstd = min(args.batch_gpu, 4)  # other hyperparams behave more predictably if mbstd group size remains fixed
    cset(cD.epilogue_kwargs, "mbstd_group_size", mbstd)

    # Automatic tuning
    if args.autotune:
        batch_size = max(min(args.num_gpus * min(4096 // args.resolution, 32), 64), args.num_gpus)  # keep gpu memory consumption at bay
        batch_gpu = args.batch_size // args.num_gpus
        nset(args, "batch_size", batch_size)
        nset(args, "batch_gpu", batch_gpu)

        fmap_decay = 1 if args.resolution >= 512 else 0.5
        lr = 0.002 if args.resolution >= 1024 else 0.0025
        gamma = 0.0002 * (args.resolution ** 2) / args.batch_size  # heuristic formula

        cset([cG.synthesis_kwargs, cD], "dim_base", int(fmap_decay * 32768))
        nset(args, "g_lr", lr)
        cset(cG.opt_args, "lr", args.g_lr)
        nset(args, "d_lr", lr)
        cset(cD.opt_args, "lr", args.d_lr)
        nset(args, "gamma", gamma)

        train.ema_rampup = 0.05
        train.ema_kimg = batch_size * 10 / 32

    if args.batch_size % (args.batch_gpu * args.num_gpus) != 0:
        misc.error("--batch-size should be divided by --batch-gpu * 'num_gpus'")

    # Loss and regularization settings
    loss_args = EasyDict(class_name="training.loss.StyleGAN2Loss",
                         g_loss=args.g_loss, d_loss=args.d_loss,
                         r1_gamma=args.gamma, pl_weight=args.pl_weight)

    # if args.fp16:
    #     cset([cG.synthesis_kwargs, cD], "num_fp16_layers", 4) # enable mixed-precision training
    #     cset([cG.synthesis_kwargs, cD], "conv_clamp", 256) # clamp activations to avoid float16 overflow
    #     cset([cG.synthesis_kwargs, cD.block_args], "fp16_channels_last", args.nhwc)

    # Evaluation and visualization
    # ----------------------------------------------------------------------------
    from metrics import metric_main
    for metric in args.metrics:
        if not metric_main.is_valid_metric(metric):
            misc.error(f"Unknown metric: {metric}. The valid metrics are: {metric_main.list_valid_metrics()}")

    for arg in ["num_gpus", "metrics", "eval_images_num", "truncation_psi"]:
        cset(train, arg, args[arg])
    for arg in ["keep_samples", "num_heads"]:
        cset(vis, arg, args[arg])

    args.vis_imgs = args.vis_images
    args.vis_ltnts = args.vis_latents
    vis_types = ["imgs", "ltnts", "maps", "layer_maps", "interpolations", "noise_var", "style_mix"]
    # Set of all the set visualization types option
    vis.vis_types = list({arg for arg in vis_types if args[f"vis_{arg}"]})

    vis_args = {
        "attention": "transformer",
        "grid": "vis_grid",
        "num": "vis_num",
        "rich_num": "vis_rich_num",
        "section_size": "vis_section_size",
        "intrp_density": "interpolation_density",
        # "intrp_per_component": "interpolation_per_component",
        "alpha": "blending_alpha"
    }
    for arg, cmd_arg in vis_args.items():
        cset(vis, arg, args[cmd_arg])

    # Networks setup
    # ----------------------------------------------------------------------------
    # Networks architecture
    cset(cG.synthesis_kwargs, "architecture", args.g_arch)
    cset(cD, "architecture", args.d_arch)

    # Latent sizes: the overall latent is split evenly across components.
    if args.components_num > 0:
        if not args.transformer:  # or args.kgan):
            misc.error("--components-num > 0 but the model is not using components. " +
                       "Add --transformer for GANformer (which uses latent components).")
        if args.latent_size % args.components_num != 0:
            # Bug fix: previously interpolated an undefined name `k`.
            misc.error(f"--latent-size ({args.latent_size}) should be divisible by --components-num (k={args.components_num})")
        args.latent_size = int(args.latent_size / args.components_num)

    cG.z_dim = cG.w_dim = args.latent_size
    cset([cG, vis], "k", args.components_num + 1)  # We add a component to modulate features globally

    # Mapping network
    args.mapping_layer_dim = args.mapping_dim
    for arg in ["num_layers", "layer_dim", "resnet", "shared", "ltnt2ltnt"]:
        field = f"mapping_{arg}"
        cset(cG.mapping_kwargs, arg, args[field])

    # StyleGAN settings
    for arg in ["style", "latent_stem", "local_noise"]:
        cset(cG.synthesis_kwargs, arg, args[arg])

    # GANformer
    cset([cG.synthesis_kwargs, cG.mapping_kwargs], "transformer", args.transformer)

    # Attention related settings
    for arg in ["use_pos", "num_heads", "ltnt_gate", "attention_dropout"]:
        cset([cG.mapping_kwargs, cG.synthesis_kwargs], arg, args[arg])

    # Attention types and layers
    for arg in ["start_res", "end_res"]:  # , "local_attention" , "ltnt2ltnt", "img2img", "img2ltnt"
        cset(cG.synthesis_kwargs, arg, args[f"g_{arg}"])

    # Mixing and dropout
    for arg in ["style_mixing", "component_mixing"]:
        cset(loss_args, arg, args[arg])
    cset(cG, "component_dropout", args["component_dropout"])

    # Extra transformer options
    args.norm = args.normalize
    for arg in ["norm", "integration", "img_gate", "iterative", "kmeans", "kmeans_iters"]:
        cset(cG.synthesis_kwargs, arg, args[arg])

    # Positional encoding
    # args.pos_dim = args.pos_dim or args.latent_size
    for arg in ["dim", "type", "init", "directions_num"]:
        field = f"pos_{arg}"
        cset(cG.synthesis_kwargs, field, args[field])

    # k-GAN
    # for arg in ["layer", "type", "same"]:
    #     field = "merge_{}".format(arg)
    #     cset(cG.args, field, args[field])
    # cset(cG.synthesis_kwargs, "merge", args.kgan)
    # if args.kgan and args.transformer:
    #     misc.error("Either have --transformer for GANformer or --kgan for k-GAN, not both")

    config = EasyDict(train)
    config.update(cG=cG, cD=cD, loss_args=loss_args, dataset_args=dataset_args, vis_args=vis)

    # Save config file
    with open(os.path.join(run_dir, "training_options.json"), "wt") as f:
        json.dump(config, f, indent=2)

    return config
def setup_training_loop_kwargs(
    # Data.
    data       = None, # Training dataset (required): <path>
    resume     = None, # Load previous network: 'noresume' (default), 'ffhq256', 'ffhq512', 'ffhq1024', 'celebahq256', 'lsundog256', <file>, <url>
    mirror     = None, # Augment dataset with x-flips: <bool>, default = False
    cond       = None, # Train conditional model based on dataset labels: <bool>, default = False
    # Training.
    cfg        = None, # Base config: 'auto', 'eps', 'big' (must be a key of cfg_specs below; the old 'stylegan2'/'paper256'/... presets no longer exist here)
    batch      = None, # Override batch size: <int>
    lrate      = None, # Override learning rate: <float>
    kimg       = None, # Override training duration: <int>
    snap       = None, # Snapshot interval: <int>, default = 5 ticks
    gamma      = None, # Override R1 gamma: <float>
    freezed    = None, # Freeze-D: <int>, default = 0 discriminator layers
    seed       = None, # Random seed: <int>, default = 0
    # Discriminator augmentation.
    aug        = None, # Augmentation mode: 'ada' (default), 'noaug', 'fixed'
    p          = None, # Specify p for 'fixed' (required): <float>
    target     = None, # Override ADA target for 'ada': <float>, default = depends on aug
    augpipe    = None, # Augmentation pipeline: 'blit', 'geom', 'color', 'filter', 'noise', 'cutout', 'bg', 'bgc' (default), ..., 'bgcfnc'
    # General & performance options.
    gpus       = None, # Number of GPUs: <int>, default = 1 gpu
    fp32       = None, # Disable mixed-precision training: <bool>, default = False
    nhwc       = None, # Use NHWC memory format with FP16: <bool>, default = False
    workers    = None, # Override number of DataLoader workers: <int>, default = 3
    nobench    = None, # Disable cuDNN benchmarking: <bool>, default = False
    metrics    = None, # List of metric names: [], ['fid50k_full'] (default), ...
):
    """Translate command-line options into kwargs for the training loop.

    Returns:
        (desc, args): `desc` is a short run-descriptor string used to name the
        run directory and snapshots; `args` is a dnnlib.EasyDict holding the
        fully resolved training-loop options (network/optimizer/loss kwargs,
        dataset kwargs, augmentation settings, resume pickle, etc.).

    Raises:
        UserError: for invalid values or invalid option combinations.
        AssertionError: for out-of-range numeric options.
    """
    args = dnnlib.EasyDict()
    args.G_kwargs = dnnlib.EasyDict(class_name='training.networks.Generator', z_dim=512, w_dim=512,
                                    mapping_kwargs=dnnlib.EasyDict(), synthesis_kwargs=dnnlib.EasyDict())
    args.D_kwargs = dnnlib.EasyDict(class_name='training.networks.Discriminator',
                                    block_kwargs=dnnlib.EasyDict(), mapping_kwargs=dnnlib.EasyDict(),
                                    epilogue_kwargs=dnnlib.EasyDict())

    # General options: gpus, snap, seed
    # ---------------------------------
    assert (gpus >= 1 and gpus & (gpus - 1) == 0), '--gpus must be a power of two'
    args.num_gpus = gpus
    # Fixed: was `snap > 1`, which rejected --snap=1 despite the message below.
    assert snap >= 1, '--snap must be at least 1'
    args.image_snapshot_ticks = 1
    args.network_snapshot_ticks = snap

    if metrics is None:
        metrics = ['fid50k_full']
    assert isinstance(metrics, list)
    if not all(metric_main.is_valid_metric(metric) for metric in metrics):
        raise UserError('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
    args.metrics = metrics
    args.random_seed = seed

    # Dataset: data, cond, mirror
    # ---------------------------
    assert data is not None
    args.training_set_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset',
                                               path=data, use_labels=True, max_size=None, xflip=False)
    args.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=3, prefetch_factor=2)
    try:
        # Instantiate the dataset once to read its true properties, then be
        # explicit about them in the kwargs so the run is reproducible.
        training_set = dnnlib.util.construct_class_by_name(**args.training_set_kwargs) # subclass of training.dataset.Dataset
        args.training_set_kwargs.resolution = res = training_set.resolution # be explicit about resolution
        args.training_set_kwargs.use_labels = training_set.has_labels       # be explicit about labels
        args.training_set_kwargs.max_size = len(training_set)               # be explicit about dataset size
        image_shape = training_set.image_shape # (channels, height, width) — used for non-square descriptors below
        init_res = training_set.init_res       # !!! custom initial grid resolution (fork extension)
        desc = dataname = training_set.name
        del training_set # conserve memory
    except IOError as err:
        raise UserError(f'--data: {err}')

    # !!! custom init res: datasets whose base grid is not the standard 4x4
    # carry their init_res into both networks and get a WxH descriptor.
    if list(init_res) == [4, 4]:
        desc += '-%d' % res
    else:
        print(' custom init resolution', init_res)
        args.G_kwargs.init_res = args.D_kwargs.init_res = list(init_res)
        desc += '-%dx%d' % (image_shape[2], image_shape[1])
    args.savenames = [desc.replace(dataname, 'snapshot'), desc]

    if cond:
        if not args.training_set_kwargs.use_labels:
            raise UserError(' put images in flat subdirectories for conditional training')
        desc += '-cond'
    else:
        args.training_set_kwargs.use_labels = False

    if mirror:
        args.training_set_kwargs.xflip = True

    # Base config: cfg, gamma, kimg, batch
    # ------------------------------------
    desc += f'-{cfg}'
    if gpus > 1:
        desc += f'{gpus:d}'
    cfg_specs = {
        'auto': dict(ramp=0.05, map=8),
        'eps':  dict(lrate=0.001, ema=10, ramp=0.05, map=8),
        'big':  dict(mb=4, fmaps=1, lrate=0.002, gamma=10, ema=10, ramp=None, map=8), # aydao etc
    }
    assert cfg in cfg_specs
    spec = dnnlib.EasyDict(cfg_specs[cfg])
    if cfg == 'auto':
        spec.mb = max(min(gpus * min(3072 // res, 32), 64), gpus) # keep memory usage within ~11gb RAM
        spec.fmaps = 1 if res >= 512 else 0.5
        # Fixed: a user-supplied --lrate now always wins; previously
        # `lrate is not None` silently discarded the override at res >= 1024,
        # and a missing lrate at res >= 1024 fell through to 0.0025 instead of 0.002.
        spec.lrate = lrate if lrate is not None else (0.002 if res >= 1024 else 0.0025)
        spec.gamma = 0.0002 * (res ** 2) / spec.mb # heuristic formula
        spec.ema = spec.mb * 10 / 32
    elif cfg == 'eps':
        spec.mb = max(min(gpus * min(3072 // res, 32), 64), gpus)
        spec.fmaps = 1 if res >= 512 else 0.5
        spec.gamma = 0.00001 * (res ** 2) / spec.mb # !!! my mb 3~4 instead of 32~64
    spec.ref_gpus = gpus
    spec.mbstd = spec.mb // gpus # other hyperparams behave more predictably if mbstd group size remains fixed

    args.G_kwargs.synthesis_kwargs.channel_base = args.D_kwargs.channel_base = int(spec.fmaps * 32768) # TWICE MORE than sg2 on tf !!!
    args.G_kwargs.synthesis_kwargs.channel_max = args.D_kwargs.channel_max = 512
    args.G_kwargs.mapping_kwargs.num_layers = spec.map
    args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 4 # enable mixed-precision training
    args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = 256 # clamp activations to avoid float16 overflow
    args.D_kwargs.epilogue_kwargs.mbstd_group_size = spec.mbstd

    args.G_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=spec.lrate, betas=[0, 0.99], eps=1e-8)
    args.D_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=spec.lrate, betas=[0, 0.99], eps=1e-8)
    args.loss_kwargs = dnnlib.EasyDict(class_name='training.loss.StyleGAN2Loss', r1_gamma=spec.gamma)

    args.total_kimg = kimg
    args.batch_size = spec.mb
    args.batch_gpu = spec.mb // spec.ref_gpus
    args.ema_kimg = spec.ema
    args.ema_rampup = spec.ramp

    if gamma is not None:
        assert gamma >= 0, '--gamma must be non-negative'
        desc += f'-gamma{gamma:g}'
        args.loss_kwargs.r1_gamma = gamma

    if batch is not None:
        assert (batch >= 1 and batch % gpus == 0), '--batch must be at least 1 and divisible by --gpus'
        desc += f'-batch{batch}'
        args.batch_size = batch
        args.batch_gpu = batch // gpus

    # Discriminator augmentation: aug, p, target, augpipe
    # ---------------------------------------------------
    if aug != 'ada':
        desc += f'-{aug}'
    if aug == 'ada':
        args.ada_target = 0.6
    elif aug == 'noaug':
        pass
    elif aug == 'fixed':
        if p is None:
            raise UserError(f'--aug={aug} requires specifying --p')
    else:
        raise UserError(f'--aug={aug} not supported')

    if p is not None:
        assert aug == 'fixed', '--p can only be specified with --aug=fixed'
        assert 0 <= p <= 1, '--p must be between 0 and 1'
        desc += f'-p{p:g}'
        args.augment_p = p

    if target is not None:
        assert aug == 'ada', '--target can only be specified with --aug=ada'
        assert 0 <= target <= 1, '--target must be between 0 and 1'
        desc += f'-target{target:g}'
        args.ada_target = target

    # Fixed: the old check raised when augpipe was *omitted* together with
    # --aug=noaug (contradicting its own message) and left augpipe=None in every
    # other case, crashing on the spec lookup below. Now an *explicit* pipeline
    # conflicts with noaug, and an omitted one falls back to the 'bgc' default.
    if augpipe is not None and aug == 'noaug':
        raise UserError('--augpipe cannot be specified with --aug=noaug')
    if augpipe is None:
        augpipe = 'bgc'
    desc += f'-{augpipe}'
    args.augpipe = augpipe

    augpipe_specs = {
        'blit':    dict(xflip=1, rotate90=1, xint=1),
        'geom':    dict(scale=1, rotate=1, aniso=1, xfrac=1),
        'color':   dict(brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
        'filter':  dict(imgfilter=1),
        'noise':   dict(noise=1),
        'cutout':  dict(cutout=1),
        'bg':      dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1),
        'bgc':     dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1,
                        brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
        'bgcf':    dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1,
                        brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1),
        'bgcfn':   dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1,
                        brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1),
        'bgcfnc':  dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1,
                        brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1, cutout=1),
        # !!! custom pipelines (fork extensions)
        'bgf_cnc': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1,
                        contrast=0.23, imgfilter=1, noise=0.11, cutout=0.11),
        'gf_bnc':  dict(xflip=.5, xint=.5, scale=1, rotate=1, aniso=1, xfrac=1, rotate_max=.25,
                        imgfilter=1, noise=.5, cutout=.5), # aug0
    }
    assert augpipe in augpipe_specs, ' unknown augpipe specs: %s' % augpipe
    if aug != 'noaug':
        args.augment_kwargs = dnnlib.EasyDict(class_name='training.augment.AugmentPipe', **augpipe_specs[augpipe])

    # Transfer learning: resume, freezed
    # ----------------------------------
    resume_specs = {
        'ffhq256':
            'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/'
            'transfer-learning-source-nets/ffhq-res256-mirror-paper256-noaug.pkl',
        'ffhq512':
            'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/'
            'transfer-learning-source-nets/ffhq-res512-mirror-stylegan2-noaug.pkl',
        'ffhq1024':
            'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/'
            'transfer-learning-source-nets/ffhq-res1024-mirror-stylegan2-noaug.pkl',
        'celebahq256':
            'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/'
            'transfer-learning-source-nets/celebahq-res256-mirror-paper256-kimg100000-ada-target0.5.pkl',
        'lsundog256':
            'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/'
            'transfer-learning-source-nets/lsundog-res256-paper256-kimg100000-noaug.pkl',
    }
    # Fixed: the documented 'noresume' value is now a no-op; previously it was
    # only None that disabled resuming, so --resume=noresume would have been
    # treated as a custom pickle path.
    if resume is None or resume == 'noresume':
        pass
    elif resume in resume_specs:
        args.resume_pkl = resume_specs[resume] # predefined url
    else:
        args.resume_pkl = resume # custom path or url
    if 'resume_pkl' in args:
        args.ada_kimg = 100    # make ADA react faster at the beginning
        args.ema_rampup = None # disable EMA rampup

    if freezed is not None:
        assert freezed >= 0, '--freezed must be non-negative'
        desc += f'-freezed{freezed:d}'
        args.D_kwargs.block_kwargs.freeze_layers = freezed

    # Performance options: fp32, nhwc, nobench, workers
    # -------------------------------------------------
    if fp32:
        args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 0
        args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = None
    if nhwc:
        args.G_kwargs.synthesis_kwargs.fp16_channels_last = args.D_kwargs.block_kwargs.fp16_channels_last = True
    if nobench:
        args.cudnn_benchmark = False
    if workers is not None:
        assert workers >= 1, '--workers must be at least 1'
        args.data_loader_kwargs.num_workers = workers

    return desc, args