def calc_metrics(ctx, network_pkl, metrics, data, mirror, gpus, verbose):
    """Calculate quality metrics for previous training run or pretrained network pickle.

    Examples:

    \b
    # Previous training run: look up options automatically, save result to JSONL file.
    python calc_metrics.py --metrics=pr50k3_full \\
        --network=~/training-runs/00000-ffhq10k-res64-auto1/network-snapshot-000000.pkl

    \b
    # Pre-trained network pickle: specify dataset explicitly, print result to stdout.
    python calc_metrics.py --metrics=fid50k_full --data=~/datasets/ffhq.zip --mirror=1 \\
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl

    Available metrics:

    \b
      ADA paper:
        fid50k_full  Frechet inception distance against the full dataset.
        kid50k_full  Kernel inception distance against the full dataset.
        pr50k3_full  Precision and recall against the full dataset.
        is50k        Inception score for CIFAR-10.

    \b
      StyleGAN and StyleGAN2 papers:
        fid50k       Frechet inception distance against 50k real images.
        kid50k       Kernel inception distance against 50k real images.
        pr50k3       Precision and recall against 50k real images.
        ppl2_wend    Perceptual path length in W at path endpoints against full image.
        ppl_zfull    Perceptual path length in Z for full paths against cropped image.
        ppl_wfull    Perceptual path length in W for full paths against cropped image.
        ppl_zend     Perceptual path length in Z at path endpoints against cropped image.
        ppl_wend     Perceptual path length in W at path endpoints against cropped image.
    """
    dnnlib.util.Logger(should_flush=True)  # flush prints immediately so progress is visible

    # Validate arguments. ctx.fail() aborts with a usage error (click-style context).
    args = dnnlib.EasyDict(metrics=metrics,
                           num_gpus=gpus,
                           network_pkl=network_pkl,
                           verbose=verbose)
    if not all(metric_main.is_valid_metric(metric) for metric in args.metrics):
        ctx.fail(
            '\n'.join(['--metrics can only contain the following values:'] +
                      metric_main.list_valid_metrics()))
    if not args.num_gpus >= 1:
        ctx.fail('--gpus must be at least 1')

    # Load network. --network may be a local file path or an http(s)/file:// URL.
    if not dnnlib.util.is_url(
            network_pkl,
            allow_file_urls=True) and not os.path.isfile(network_pkl):
        ctx.fail('--network must point to a file or URL')
    if args.verbose:
        print(f'Loading network from "{network_pkl}"...')
    with dnnlib.util.open_url(network_pkl, verbose=args.verbose) as f:
        network_dict = legacy.load_network_pkl(f)
        args.G = network_dict['G_ema']  # subclass of torch.nn.Module

    # Initialize dataset options. Explicit --data takes precedence over the
    # training-set options recorded inside the network pickle.
    if data is not None:
        args.dataset_kwargs = dnnlib.EasyDict(
            class_name='training.dataset.ImageFolderDataset', path=data)
    elif network_dict['training_set_kwargs'] is not None:
        args.dataset_kwargs = dnnlib.EasyDict(
            network_dict['training_set_kwargs'])
    else:
        ctx.fail('Could not look up dataset options; please specify --data')

    # Finalize dataset options.
    # NOTE(review): separate height/width fields suggest a non-square-resolution
    # fork -- assumes G exposes img_resolution_h / img_resolution_w and the
    # dataset accepts resolution_h / resolution_w; confirm against the
    # training.networks / training.dataset implementations.
    args.dataset_kwargs.resolution_h = args.G.img_resolution_h
    args.dataset_kwargs.resolution_w = args.G.img_resolution_w
    args.dataset_kwargs.use_labels = (args.G.c_dim != 0)  # conditional iff G takes labels
    if mirror is not None:
        args.dataset_kwargs.xflip = mirror

    # Print dataset options.
    if args.verbose:
        print('Dataset options:')
        print(json.dumps(args.dataset_kwargs, indent=2))

    # Locate run dir: a snapshot living next to training_options.json belongs
    # to a training run, so results are appended to that run's JSONL file.
    args.run_dir = None
    if os.path.isfile(network_pkl):
        pkl_dir = os.path.dirname(network_pkl)
        if os.path.isfile(os.path.join(pkl_dir, 'training_options.json')):
            args.run_dir = pkl_dir

    # Launch processes: one subprocess per GPU ('spawn' is required for CUDA).
    if args.verbose:
        print('Launching processes...')
    torch.multiprocessing.set_start_method('spawn')
    with tempfile.TemporaryDirectory() as temp_dir:
        if args.num_gpus == 1:
            subprocess_fn(rank=0, args=args, temp_dir=temp_dir)
        else:
            torch.multiprocessing.spawn(fn=subprocess_fn,
                                        args=(args, temp_dir),
                                        nprocs=args.num_gpus)
def setup_training_loop_kwargs(
    # General options (not included in desc).
    gpus=None,  # Number of GPUs: <int>, default = 1 gpu
    snap=None,  # Snapshot interval: <int>, default = 50 ticks
    metrics=None,  # List of metric names: [], ['fid50k_full'] (default), ...
    seed=None,  # Random seed: <int>, default = 0
    # Dataset.
    data=None,  # Training dataset (required): <path>
    cond=None,  # Train conditional model based on dataset labels: <bool>, default = False
    subset=None,  # Train with only N images: <int>, default = all
    mirror=None,  # Augment dataset with x-flips: <bool>, default = False
    # Base config.
    cfg=None,  # Base config: 'auto' (default), 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar'
    gamma=None,  # Override R1 gamma: <float>
    kimg=None,  # Override training duration: <int>
    batch=None,  # Override batch size: <int>
    # Discriminator augmentation.
    aug=None,  # Augmentation mode: 'ada' (default), 'noaug', 'fixed'
    p=None,  # Specify p for 'fixed' (required): <float>
    target=None,  # Override ADA target for 'ada': <float>, default = depends on aug
    augpipe=None,  # Augmentation pipeline: 'blit', 'geom', 'color', 'filter', 'noise', 'cutout', 'bg', 'bgc' (default), ..., 'bgcfnc'
    # Transfer learning.
    resume=None,  # Load previous network: 'noresume' (default), 'ffhq256', 'ffhq512', 'ffhq1024', 'celebahq256', 'lsundog256', <file>, <url>
    freezed=None,  # Freeze-D: <int>, default = 0 discriminator layers
    # Performance options (not included in desc).
    fp32=None,  # Disable mixed-precision training: <bool>, default = False
    nhwc=None,  # Use NHWC memory format with FP16: <bool>, default = False
    nobench=None,  # Disable cuDNN benchmarking: <bool>, default = False
    workers=None,  # Override number of DataLoader workers: <int>, default = 3
    **kwargs,
):
    """Translate user-facing options into training-loop keyword arguments.

    Returns:
        (desc, args): `desc` is a short human-readable run descriptor built
        from the chosen options; `args` is a dnnlib.EasyDict consumed by the
        training loop.

    Raises:
        UserError: for invalid option values (type errors trip asserts).

    NOTE(review): this file defines setup_training_loop_kwargs a second time
    further down; at import time the later definition shadows this one --
    confirm which version is intended to be live.
    """
    # Only warn when there actually are stray options; the original printed
    # "Unrecognized arguments: {}" on every normal call.
    if kwargs:
        print("Unrecognized arguments:", kwargs)
    args = dnnlib.EasyDict()

    # ------------------------------------------
    # General options: gpus, snap, metrics, seed
    # ------------------------------------------

    if gpus is None:
        gpus = 1
    assert isinstance(gpus, int)
    if not (gpus >= 1 and gpus & (gpus - 1) == 0):
        raise UserError("--gpus must be a power of two")
    args.num_gpus = gpus

    if snap is None:
        snap = 50
    assert isinstance(snap, int)
    if snap < 1:
        raise UserError("--snap must be at least 1")
    # NOTE(review): `snap` is validated but not applied -- the tick intervals
    # below are hard-coded (1 image / 5 network). Looks like a deliberate
    # local override of the stock behavior; confirm before relying on --snap.
    args.image_snapshot_ticks = 1
    args.network_snapshot_ticks = 5

    if metrics is None:
        metrics = ["fid50k_full"]
    assert isinstance(metrics, list)
    if not all(metric_main.is_valid_metric(metric) for metric in metrics):
        raise UserError(
            "\n".join(["--metrics can only contain the following values:"] + metric_main.list_valid_metrics())
        )
    args.metrics = metrics

    if seed is None:
        seed = 0
    assert isinstance(seed, int)
    args.random_seed = seed

    # -----------------------------------
    # Dataset: data, cond, subset, mirror
    # -----------------------------------

    assert data is not None
    assert isinstance(data, str)
    args.training_set_kwargs = dnnlib.EasyDict(
        class_name="training.dataset.ImageFolderDataset", path=data, use_labels=True, max_size=None, xflip=False
    )
    args.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=3, prefetch_factor=2)
    try:
        # Instantiate the dataset once, up front, to resolve resolution /
        # labels / size, then discard it to conserve memory.
        training_set = dnnlib.util.construct_class_by_name(
            **args.training_set_kwargs
        )  # subclass of training.dataset.Dataset
        args.training_set_kwargs.resolution = training_set.resolution  # be explicit about resolution
        args.training_set_kwargs.use_labels = training_set.has_labels  # be explicit about labels
        args.training_set_kwargs.max_size = len(training_set)  # be explicit about dataset size
        desc = training_set.name
        del training_set  # conserve memory
    except IOError as err:
        raise UserError(f"--data: {err}")

    if cond is None:
        cond = False
    assert isinstance(cond, bool)
    if cond:
        if not args.training_set_kwargs.use_labels:
            raise UserError("--cond=True requires labels specified in dataset.json")
        desc += "-cond"
    else:
        args.training_set_kwargs.use_labels = False

    if subset is not None:
        assert isinstance(subset, int)
        if not 1 <= subset <= args.training_set_kwargs.max_size:
            raise UserError(f"--subset must be between 1 and {args.training_set_kwargs.max_size}")
        desc += f"-subset{subset}"
        if subset < args.training_set_kwargs.max_size:
            args.training_set_kwargs.max_size = subset
            args.training_set_kwargs.random_seed = args.random_seed

    if mirror is None:
        mirror = False
    assert isinstance(mirror, bool)
    if mirror:
        desc += "-mirror"
        args.training_set_kwargs.xflip = True

    # ------------------------------------
    # Base config: cfg, gamma, kimg, batch
    # ------------------------------------

    if cfg is None:
        cfg = "auto"
    assert isinstance(cfg, str)
    desc += f"-{cfg}"

    # Per-config hyperparameter presets. 'auto' entries set to -1 are filled
    # in dynamically below based on resolution and GPU count.
    cfg_specs = {
        "auto": dict(
            ref_gpus=-1, kimg=25000, mb=-1, mbstd=-1, fmaps=-1, lrate=-1, gamma=-1, ema=-1, ramp=0.05, map=2
        ),  # Populated dynamically based on resolution and GPU count.
        "stylegan2": dict(
            ref_gpus=8, kimg=25000, mb=32, mbstd=4, fmaps=1, lrate=0.002, gamma=10, ema=10, ramp=None, map=8
        ),  # Uses mixed-precision, unlike the original StyleGAN2.
        "paper256": dict(
            ref_gpus=8, kimg=25000, mb=64, mbstd=8, fmaps=0.5, lrate=0.0025, gamma=1, ema=20, ramp=None, map=8
        ),
        "paper512": dict(
            ref_gpus=8, kimg=25000, mb=64, mbstd=8, fmaps=1, lrate=0.0025, gamma=0.5, ema=20, ramp=None, map=8
        ),
        "paper1024": dict(
            ref_gpus=8, kimg=25000, mb=32, mbstd=4, fmaps=1, lrate=0.002, gamma=2, ema=10, ramp=None, map=8
        ),
        "cifar": dict(
            ref_gpus=2, kimg=100000, mb=64, mbstd=32, fmaps=1, lrate=0.0025, gamma=0.01, ema=500, ramp=0.05, map=2
        ),
        "wav": dict(ref_gpus=2, kimg=1000, mb=8, mbstd=4, fmaps=1, gamma=10, lrate=0.002, ema=10, ramp=None, map=8),
    }

    assert cfg in cfg_specs
    spec = dnnlib.EasyDict(cfg_specs[cfg])
    if cfg == "auto":
        desc += f"{gpus:d}"
        spec.ref_gpus = gpus
        res = args.training_set_kwargs.resolution
        spec.mb = max(min(gpus * min(4096 // res, 32), 64), gpus)  # keep gpu memory consumption at bay
        spec.mbstd = min(
            spec.mb // gpus, 4
        )  # other hyperparams behave more predictably if mbstd group size remains fixed
        spec.fmaps = 1 if res >= 512 else 0.5
        spec.lrate = 0.002 if res >= 1024 else 0.0025
        spec.gamma = 0.0002 * (res ** 2) / spec.mb  # heuristic formula
        spec.ema = spec.mb * 10 / 32

    args.G_kwargs = dnnlib.EasyDict(
        class_name="training.networks.Generator",
        z_dim=512,
        w_dim=512,
        mapping_kwargs=dnnlib.EasyDict(),
        synthesis_kwargs=dnnlib.EasyDict(),
    )
    args.D_kwargs = dnnlib.EasyDict(
        class_name="training.networks.Discriminator",
        block_kwargs=dnnlib.EasyDict(),
        mapping_kwargs=dnnlib.EasyDict(),
        epilogue_kwargs=dnnlib.EasyDict(),
    )
    args.G_kwargs.synthesis_kwargs.channel_base = args.D_kwargs.channel_base = int(spec.fmaps * 32768)
    args.G_kwargs.synthesis_kwargs.channel_max = args.D_kwargs.channel_max = 512
    args.G_kwargs.mapping_kwargs.num_layers = spec.map
    args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 3  # enable mixed-precision training
    args.G_kwargs.synthesis_kwargs.conv_clamp = (
        args.D_kwargs.conv_clamp
    ) = 256  # clamp activations to avoid float16 overflow
    args.D_kwargs.epilogue_kwargs.mbstd_group_size = spec.mbstd

    args.G_opt_kwargs = dnnlib.EasyDict(class_name="torch.optim.Adam", lr=spec.lrate, betas=[0, 0.99], eps=1e-8)
    args.D_opt_kwargs = dnnlib.EasyDict(class_name="torch.optim.Adam", lr=spec.lrate, betas=[0, 0.99], eps=1e-8)
    args.loss_kwargs = dnnlib.EasyDict(class_name="training.loss.StyleGAN2Loss", r1_gamma=spec.gamma)

    args.total_kimg = spec.kimg
    # NOTE(review): batch sizes are hard-coded here instead of using spec.mb /
    # spec.mb // spec.ref_gpus -- presumably a deliberate memory-saving
    # override (--batch below still replaces them); confirm before changing.
    args.batch_size = 8
    args.batch_gpu = 4
    args.ema_kimg = spec.ema
    args.ema_rampup = spec.ramp

    if cfg == "cifar" or cfg == "wav":
        args.loss_kwargs.pl_weight = 0  # disable path length regularization
        args.loss_kwargs.style_mixing_prob = 0  # disable style mixing
        # args.D_kwargs.architecture = "orig"  # disable residual skip connections

    if gamma is not None:
        assert isinstance(gamma, float)
        if not gamma >= 0:
            raise UserError("--gamma must be non-negative")
        desc += f"-gamma{gamma:g}"
        args.loss_kwargs.r1_gamma = gamma

    if kimg is not None:
        assert isinstance(kimg, int)
        if not kimg >= 1:
            raise UserError("--kimg must be at least 1")
        desc += f"-kimg{kimg:d}"
        args.total_kimg = kimg

    if batch is not None:
        assert isinstance(batch, int)
        if not (batch >= 1 and batch % gpus == 0):
            raise UserError("--batch must be at least 1 and divisible by --gpus")
        desc += f"-batch{batch}"
        args.batch_size = batch
        args.batch_gpu = batch // gpus

    # ---------------------------------------------------
    # Discriminator augmentation: aug, p, target, augpipe
    # ---------------------------------------------------

    if aug is None:
        aug = "ada"
    else:
        assert isinstance(aug, str)
        desc += f"-{aug}"

    if aug == "ada":
        args.ada_target = 0.6

    elif aug == "noaug":
        pass

    elif aug == "fixed":
        if p is None:
            raise UserError(f"--aug={aug} requires specifying --p")

    else:
        raise UserError(f"--aug={aug} not supported")

    if p is not None:
        assert isinstance(p, float)
        if aug != "fixed":
            raise UserError("--p can only be specified with --aug=fixed")
        if not 0 <= p <= 1:
            raise UserError("--p must be between 0 and 1")
        desc += f"-p{p:g}"
        args.augment_p = p

    if target is not None:
        assert isinstance(target, float)
        if aug != "ada":
            raise UserError("--target can only be specified with --aug=ada")
        if not 0 <= target <= 1:
            raise UserError("--target must be between 0 and 1")
        desc += f"-target{target:g}"
        args.ada_target = target

    assert augpipe is None or isinstance(augpipe, str)
    if augpipe is None:
        augpipe = "bgc"
    else:
        if aug == "noaug":
            raise UserError("--augpipe cannot be specified with --aug=noaug")
        desc += f"-{augpipe}"

    # Each entry enables a subset of AugmentPipe transforms; names are
    # initials of the groups (b=blit, g=geom, c=color, f=filter, n=noise,
    # c=cutout).
    augpipe_specs = {
        "blit": dict(xflip=1, rotate90=1, xint=1),
        "geom": dict(scale=1, rotate=1, aniso=1, xfrac=1),
        "color": dict(brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
        "filter": dict(imgfilter=1),
        "noise": dict(noise=1),
        "cutout": dict(cutout=1),
        "bg": dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1),
        "bgc": dict(
            xflip=1,
            rotate90=1,
            xint=1,
            scale=1,
            rotate=1,
            aniso=1,
            xfrac=1,
            brightness=1,
            contrast=1,
            lumaflip=1,
            hue=1,
            saturation=1,
        ),
        "bgcf": dict(
            xflip=1,
            rotate90=1,
            xint=1,
            scale=1,
            rotate=1,
            aniso=1,
            xfrac=1,
            brightness=1,
            contrast=1,
            lumaflip=1,
            hue=1,
            saturation=1,
            imgfilter=1,
        ),
        "bgcfn": dict(
            xflip=1,
            rotate90=1,
            xint=1,
            scale=1,
            rotate=1,
            aniso=1,
            xfrac=1,
            brightness=1,
            contrast=1,
            lumaflip=1,
            hue=1,
            saturation=1,
            imgfilter=1,
            noise=1,
        ),
        "bgcfnc": dict(
            xflip=1,
            rotate90=1,
            xint=1,
            scale=1,
            rotate=1,
            aniso=1,
            xfrac=1,
            brightness=1,
            contrast=1,
            lumaflip=1,
            hue=1,
            saturation=1,
            imgfilter=1,
            noise=1,
            cutout=1,
        ),
    }

    assert augpipe in augpipe_specs
    if aug != "noaug":
        args.augment_kwargs = dnnlib.EasyDict(class_name="training.augment.AugmentPipe", **augpipe_specs[augpipe])

    # ----------------------------------
    # Transfer learning: resume, freezed
    # ----------------------------------

    resume_specs = {
        "ffhq256": "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res256-mirror-paper256-noaug.pkl",
        "ffhq512": "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res512-mirror-stylegan2-noaug.pkl",
        "ffhq1024": "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res1024-mirror-stylegan2-noaug.pkl",
        "celebahq256": "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/celebahq-res256-mirror-paper256-kimg100000-ada-target0.5.pkl",
        "lsundog256": "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/lsundog-res256-paper256-kimg100000-noaug.pkl",
    }

    assert resume is None or isinstance(resume, str)
    if resume is None:
        resume = "noresume"
    elif resume == "noresume":
        desc += "-noresume"
    elif resume in resume_specs:
        desc += f"-resume{resume}"
        args.resume_pkl = resume_specs[resume]  # predefined url
    else:
        desc += "-resumecustom"
        args.resume_pkl = resume  # custom path or url

    if resume != "noresume":
        args.ada_kimg = 100  # make ADA react faster at the beginning
        args.ema_rampup = None  # disable EMA rampup

    if freezed is not None:
        assert isinstance(freezed, int)
        if not freezed >= 0:
            raise UserError("--freezed must be non-negative")
        desc += f"-freezed{freezed:d}"
        args.D_kwargs.block_kwargs.freeze_layers = freezed

    # -------------------------------------------------
    # Performance options: fp32, nhwc, nobench, workers
    # -------------------------------------------------

    if fp32 is None:
        fp32 = False
    assert isinstance(fp32, bool)
    if fp32:
        args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 0
        args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = None

    if nhwc is None:
        nhwc = False
    assert isinstance(nhwc, bool)
    if nhwc:
        args.G_kwargs.synthesis_kwargs.fp16_channels_last = args.D_kwargs.block_kwargs.fp16_channels_last = True

    if nobench is None:
        nobench = False
    assert isinstance(nobench, bool)
    if nobench:
        args.cudnn_benchmark = False

    if workers is not None:
        assert isinstance(workers, int)
        if not workers >= 1:
            raise UserError("--workers must be at least 1")
        args.data_loader_kwargs.num_workers = workers

    return desc, args
def setup_training_loop_kwargs(
        # General options (not included in desc).
        gpus=None,  # Number of GPUs: <int>, default = 1 gpu
        snap=None,  # Snapshot interval: <int>, default = 50 ticks
        metrics=None,  # List of metric names: [], ['fid50k_full'] (default), ...
        comet_api_key=None,  # API key for Comet.ml: <str>, default = '' (don't use comet)
        comet_name=None,  # Experiment name for Comet.ml: <str>, default = '' (default Comet naming)
        gradient_clipping=None,  # None or max global gradient L2 norm: <float>, default = None (no clipping)
        seed=None,  # Random seed: <int>, default = 0

        # Train ataset.
    data=None,  # Training dataset (required): <path>
        cond=None,  # Train conditional model based on dataset labels: <bool>, default = False
        subset=None,  # Train with only N images: <int>, default = all
        mirror=None,  # Augment dataset with x-flips: <bool>, default = False

        # ValidationDataset.
    testdata=None,  # Validation dataset (optional): <path>

        # Base config.
    cfg=None,  # Base config: 'auto' (default), 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar'
        gamma=None,  # Override R1 gamma: <float>
        kimg=None,  # Override training duration: <int>
        batch=None,  # Override batch size: <int>

        # Discriminator augmentation.
    diffaugment=None,  # Comma-separated list of DiffAugment policy, default = 'color,translation,cutout'
        diffaugment_placement=None,  # Comma-separated list of DiffAugment applying placement, default = 'real,generated,backprop'
        aug=None,  # Augmentation mode: 'ada' (default), 'noaug', 'fixed'
        p=None,  # Specify p for 'fixed' (required): <float>
        target=None,  # Override ADA target for 'ada': <float>, default = depends on aug
        augpipe=None,  # Augmentation pipeline: 'blit', 'geom', 'color', 'filter', 'noise', 'cutout', 'bg', 'bgc' (default), ..., 'bgcfnc'

        # Transfer learning.
    resume=None,  # Load previous network: 'noresume' (default), 'ffhq256', 'ffhq512', 'ffhq1024', 'celebahq256', 'lsundog256', <file>, <url>
        freezed=None,  # Freeze-D: <int>, default = 0 discriminator layers

        # Performance options (not included in desc).
    fp32=None,  # Disable mixed-precision training: <bool>, default = False
        nhwc=None,  # Use NHWC memory format with FP16: <bool>, default = False
        allow_tf32=None,  # Allow PyTorch to use TF32 for matmul and convolutions: <bool>, default = False
        nobench=None,  # Disable cuDNN benchmarking: <bool>, default = False
        workers=None,  # Override number of DataLoader workers: <int>, default = 3
):
    args = dnnlib.EasyDict()

    # -----------------------------------------------------------
    # General options: gpus, snap, metrics, comet, clipping, seed
    # -----------------------------------------------------------

    if gpus is None:
        gpus = 1
    assert isinstance(gpus, int)
    if not (gpus >= 1 and gpus & (gpus - 1) == 0):
        raise UserError('--gpus must be a power of two')
    args.num_gpus = gpus

    if snap is None:
        snap = 50
    assert isinstance(snap, int)
    if snap < 1:
        raise UserError('--snap must be at least 1')
    args.image_snapshot_ticks = snap
    args.network_snapshot_ticks = snap

    if metrics is None:
        metrics = ['fid50k_full']
    assert isinstance(metrics, list)
    if not all(metric_main.is_valid_metric(metric) for metric in metrics):
        raise UserError(
            '\n'.join(['--metrics can only contain the following values:'] +
                      metric_main.list_valid_metrics()))
    args.metrics = metrics

    if comet_api_key is None:
        comet_api_key = ''
    if comet_name is None:
        comet_name = ''
    if comet_api_key:
        if comet_ml is None:
            print(
                'comet_ml is not imported! Proceeding without comet.ml logging'
            )
            args.comet_api_key = ''
            args.comet_experiment_key = ''
        else:
            args.comet_api_key = comet_api_key
            experiment = comet_ml.Experiment(api_key=comet_api_key,
                                             project_name='Sirius SOTA GANs',
                                             auto_output_logging='simple',
                                             auto_log_co2=False,
                                             auto_metric_logging=False,
                                             auto_param_logging=False)
            experiment.set_name(comet_name)
            args.comet_experiment_key = experiment.get_key()
    else:
        args.comet_api_key = ''
        args.comet_experiment_key = ''

    args.gradient_clipping = gradient_clipping

    if seed is None:
        seed = 0
    assert isinstance(seed, int)
    args.random_seed = seed

    # -----------------------------------
    # Dataset: data, cond, subset, mirror
    # -----------------------------------

    assert data is not None
    assert isinstance(data, str)
    args.training_set_kwargs = dnnlib.EasyDict(
        class_name='training.dataset.ImageFolderDataset',
        path=data,
        use_labels=True,
        max_size=None,
        xflip=False)
    args.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True,
                                              num_workers=3,
                                              prefetch_factor=2)
    try:
        training_set = dnnlib.util.construct_class_by_name(
            **args.training_set_kwargs)  # subclass of training.dataset.Dataset
        args.training_set_kwargs.resolution = training_set.resolution  # be explicit about resolution
        args.training_set_kwargs.use_labels = training_set.has_labels  # be explicit about labels
        args.training_set_kwargs.max_size = len(
            training_set)  # be explicit about dataset size
        desc = training_set.name
        if testdata is not None:
            args.validation_set_kwargs = dnnlib.EasyDict(
                class_name='training.dataset.ImageFolderDataset',
                path=testdata,
                use_labels=True,
                max_size=None,
                xflip=False)
            validation_set = dnnlib.util.construct_class_by_name(
                **args.validation_set_kwargs)
            args.validation_set_kwargs.resolution = training_set.resolution  # be explicit about resolution
            args.validation_set_kwargs.use_labels = training_set.has_labels  # be explicit about labels
            args.validation_set_kwargs.max_size = len(
                validation_set)  # be explicit about dataset size
            del validation_set
        else:
            args.validation_set_kwargs = {}
        del training_set  # conserve memory
    except IOError as err:
        raise UserError(f'--data: {err}')

    if cond is None:
        cond = False
    assert isinstance(cond, bool)
    if cond:
        if not args.training_set_kwargs.use_labels:
            raise UserError(
                '--cond=True requires labels specified in dataset.json')
        desc += '-cond'
    else:
        args.training_set_kwargs.use_labels = False
        if args.validation_set_kwargs != {}:
            args.validation_set_kwargs.use_labels = False

    if subset is not None:
        assert isinstance(subset, int)
        if not 1 <= subset <= args.training_set_kwargs.max_size:
            raise UserError(
                f'--subset must be between 1 and {args.training_set_kwargs.max_size}'
            )
        desc += f'-subset{subset}'
        if subset < args.training_set_kwargs.max_size:
            args.training_set_kwargs.max_size = subset
            args.training_set_kwargs.random_seed = args.random_seed

    if mirror is None:
        mirror = False
    assert isinstance(mirror, bool)
    if mirror:
        desc += '-mirror'
        args.training_set_kwargs.xflip = True

    # ------------------------------------
    # Base config: cfg, gamma, kimg, batch
    # ------------------------------------

    if cfg is None:
        cfg = 'low_shot'
    assert isinstance(cfg, str)
    desc += f'-{cfg}'

    cfg_specs = {
        'custom_low_shot':
        dict(ref_gpus=-1,
             kimg=500,
             mb=8,
             mbstd=4,
             fmaps=0.5,
             lrate=5e-6,
             gamma=10,
             ema=10,
             ramp=None,
             map=2,
             snap=10),
        'low_shot':
        dict(ref_gpus=-1,
             kimg=300,
             mb=8,
             mbstd=4,
             fmaps=1,
             lrate=0.002,
             gamma=10,
             ema=10,
             ramp=None,
             map=2,
             snap=10),
        'auto':
        dict(
            ref_gpus=-1,
            kimg=25000,
            mb=-1,
            mbstd=-1,
            fmaps=-1,
            lrate=-1,
            gamma=-1,
            ema=-1,
            ramp=0.05,
            map=2),  # Populated dynamically based on resolution and GPU count.
        'stylegan2':
        dict(ref_gpus=8,
             kimg=25000,
             mb=32,
             mbstd=4,
             fmaps=1,
             lrate=0.002,
             gamma=10,
             ema=10,
             ramp=None,
             map=8),  # Uses mixed-precision, unlike the original StyleGAN2.
        'paper256':
        dict(ref_gpus=8,
             kimg=25000,
             mb=64,
             mbstd=8,
             fmaps=0.5,
             lrate=0.0025,
             gamma=1,
             ema=20,
             ramp=None,
             map=8),
        'paper512':
        dict(ref_gpus=8,
             kimg=25000,
             mb=64,
             mbstd=8,
             fmaps=1,
             lrate=0.0025,
             gamma=0.5,
             ema=20,
             ramp=None,
             map=8),
        'paper1024':
        dict(ref_gpus=8,
             kimg=25000,
             mb=32,
             mbstd=4,
             fmaps=1,
             lrate=0.002,
             gamma=2,
             ema=10,
             ramp=None,
             map=8),
        'cifar':
        dict(ref_gpus=2,
             kimg=100000,
             mb=64,
             mbstd=32,
             fmaps=1,
             lrate=0.0025,
             gamma=0.01,
             ema=500,
             ramp=0.05,
             map=2),
    }

    assert cfg in cfg_specs
    spec = dnnlib.EasyDict(cfg_specs[cfg])
    if cfg == 'auto':
        desc += f'{gpus:d}'
        spec.ref_gpus = gpus
        res = args.training_set_kwargs.resolution
        spec.mb = max(min(gpus * min(4096 // res, 32), 64),
                      gpus)  # keep gpu memory consumption at bay
        spec.mbstd = min(
            spec.mb // gpus, 4
        )  # other hyperparams behave more predictably if mbstd group size remains fixed
        spec.fmaps = 1 if res >= 512 else 0.5
        spec.lrate = 0.002 if res >= 1024 else 0.0025
        spec.gamma = 0.0002 * (res**2) / spec.mb  # heuristic formula
        spec.ema = spec.mb * 10 / 32

    if spec.ref_gpus < 0:
        spec.ref_gpus = gpus

    if spec.get('snap', None):
        args.image_snapshot_ticks = args.network_snapshot_ticks = spec.snap

    args.G_kwargs = dnnlib.EasyDict(class_name='training.networks.Generator',
                                    z_dim=512,
                                    w_dim=512,
                                    mapping_kwargs=dnnlib.EasyDict(),
                                    synthesis_kwargs=dnnlib.EasyDict())
    args.D_kwargs = dnnlib.EasyDict(
        class_name='training.networks.Discriminator',
        block_kwargs=dnnlib.EasyDict(),
        mapping_kwargs=dnnlib.EasyDict(),
        epilogue_kwargs=dnnlib.EasyDict())
    args.G_kwargs.synthesis_kwargs.channel_base = args.D_kwargs.channel_base = int(
        spec.fmaps * 32768)
    args.G_kwargs.synthesis_kwargs.channel_max = args.D_kwargs.channel_max = 512
    args.G_kwargs.mapping_kwargs.num_layers = spec.map
    args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 4  # enable mixed-precision training
    args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = 256  # clamp activations to avoid float16 overflow
    args.D_kwargs.epilogue_kwargs.mbstd_group_size = spec.mbstd

    args.G_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam',
                                        lr=spec.lrate,
                                        betas=[0, 0.99],
                                        eps=1e-8)
    args.D_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam',
                                        lr=spec.lrate,
                                        betas=[0, 0.99],
                                        eps=1e-8)
    args.loss_kwargs = dnnlib.EasyDict(
        class_name='training.loss.StyleGAN2Loss', r1_gamma=spec.gamma)

    args.total_kimg = spec.kimg
    args.batch_size = spec.mb
    args.batch_gpu = spec.mb // spec.ref_gpus
    args.ema_kimg = spec.ema
    args.ema_rampup = spec.ramp

    if cfg == 'cifar':
        args.loss_kwargs.pl_weight = 0  # disable path length regularization
        args.loss_kwargs.style_mixing_prob = 0  # disable style mixing
        args.D_kwargs.architecture = 'orig'  # disable residual skip connections

    if gamma is not None:
        assert isinstance(gamma, float)
        if not gamma >= 0:
            raise UserError('--gamma must be non-negative')
        desc += f'-gamma{gamma:g}'
        args.loss_kwargs.r1_gamma = gamma

    if kimg is not None:
        assert isinstance(kimg, int)
        if not kimg >= 1:
            raise UserError('--kimg must be at least 1')
        desc += f'-kimg{kimg:d}'
        args.total_kimg = kimg

    if batch is not None:
        assert isinstance(batch, int)
        if not (batch >= 1 and batch % gpus == 0):
            raise UserError(
                '--batch must be at least 1 and divisible by --gpus')
        desc += f'-batch{batch}'
        args.batch_size = batch
        args.batch_gpu = batch // gpus

    # ---------------------------------------------------
    # Discriminator augmentation: aug, p, target, augpipe
    # ---------------------------------------------------

    if diffaugment is None:
        diffaugment = 'color,translation,cutout'

    if diffaugment:
        args.loss_kwargs.diffaugment = diffaugment
        aug = 'noaug'
        desc += '-{}'.format(diffaugment.replace(',', '-'))
    elif aug is None:
        aug = 'ada'
    else:
        assert isinstance(aug, str)
        desc += f'-{aug}'

    if diffaugment_placement is None:
        if diffaugment:
            diffaugment_placement = 'real,generated,backprop'
        else:
            diffaugment_placement = ''

    if diffaugment_placement:
        args.loss_kwargs.diffaugment_placement = diffaugment_placement

    if aug == 'ada':
        args.ada_target = 0.6

    elif aug == 'noaug':
        pass

    elif aug == 'fixed':
        if p is None:
            raise UserError(f'--aug={aug} requires specifying --p')

    else:
        raise UserError(f'--aug={aug} not supported')

    if p is not None:
        assert isinstance(p, float)
        if aug != 'fixed':
            raise UserError('--p can only be specified with --aug=fixed')
        if not 0 <= p <= 1:
            raise UserError('--p must be between 0 and 1')
        desc += f'-p{p:g}'
        args.augment_p = p

    if target is not None:
        assert isinstance(target, float)
        if aug != 'ada':
            raise UserError('--target can only be specified with --aug=ada')
        if not 0 <= target <= 1:
            raise UserError('--target must be between 0 and 1')
        desc += f'-target{target:g}'
        args.ada_target = target

    assert augpipe is None or isinstance(augpipe, str)
    if augpipe is None:
        augpipe = 'bgc'
    else:
        if aug == 'noaug':
            raise UserError('--augpipe cannot be specified with --aug=noaug')
        desc += f'-{augpipe}'

    augpipe_specs = {
        'blit':
        dict(xflip=1, rotate90=1, xint=1),
        'geom':
        dict(scale=1, rotate=1, aniso=1, xfrac=1),
        'color':
        dict(brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
        'filter':
        dict(imgfilter=1),
        'noise':
        dict(noise=1),
        'cutout':
        dict(cutout=1),
        'bg':
        dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1),
        'bgc':
        dict(xflip=1,
             rotate90=1,
             xint=1,
             scale=1,
             rotate=1,
             aniso=1,
             xfrac=1,
             brightness=1,
             contrast=1,
             lumaflip=1,
             hue=1,
             saturation=1),
        'bgcf':
        dict(xflip=1,
             rotate90=1,
             xint=1,
             scale=1,
             rotate=1,
             aniso=1,
             xfrac=1,
             brightness=1,
             contrast=1,
             lumaflip=1,
             hue=1,
             saturation=1,
             imgfilter=1),
        'bgcfn':
        dict(xflip=1,
             rotate90=1,
             xint=1,
             scale=1,
             rotate=1,
             aniso=1,
             xfrac=1,
             brightness=1,
             contrast=1,
             lumaflip=1,
             hue=1,
             saturation=1,
             imgfilter=1,
             noise=1),
        'bgcfnc':
        dict(xflip=1,
             rotate90=1,
             xint=1,
             scale=1,
             rotate=1,
             aniso=1,
             xfrac=1,
             brightness=1,
             contrast=1,
             lumaflip=1,
             hue=1,
             saturation=1,
             imgfilter=1,
             noise=1,
             cutout=1),
    }

    assert augpipe in augpipe_specs
    if aug != 'noaug':
        args.augment_kwargs = dnnlib.EasyDict(
            class_name='training.augment.AugmentPipe',
            **augpipe_specs[augpipe])

    # ----------------------------------
    # Transfer learning: resume, freezed
    # ----------------------------------

    resume_specs = {
        'ffhq256':
        'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res256-mirror-paper256-noaug.pkl',
        'ffhq512':
        'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res512-mirror-stylegan2-noaug.pkl',
        'ffhq1024':
        'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res1024-mirror-stylegan2-noaug.pkl',
        'celebahq256':
        'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/celebahq-res256-mirror-paper256-kimg100000-ada-target0.5.pkl',
        'lsundog256':
        'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/lsundog-res256-paper256-kimg100000-noaug.pkl',
    }

    assert resume is None or isinstance(resume, str)
    if resume is None:
        resume = 'noresume'
    elif resume == 'noresume':
        desc += '-noresume'
    elif resume in resume_specs:
        desc += f'-resume{resume}'
        args.resume_pkl = resume_specs[resume]  # predefined url
    else:
        desc += '-resumecustom'
        args.resume_pkl = resume  # custom path or url

    if resume != 'noresume':
        args.ada_kimg = 100  # make ADA react faster at the beginning
        args.ema_rampup = None  # disable EMA rampup

    if freezed is not None:
        assert isinstance(freezed, int)
        if not freezed >= 0:
            raise UserError('--freezed must be non-negative')
        desc += f'-freezed{freezed:d}'
        args.D_kwargs.block_kwargs.freeze_layers = freezed

    # -------------------------------------------------
    # Performance options: fp32, nhwc, nobench, workers
    # -------------------------------------------------

    if fp32 is None:
        fp32 = False
    assert isinstance(fp32, bool)
    if fp32:
        args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 0
        args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = None

    if nhwc is None:
        nhwc = False
    assert isinstance(nhwc, bool)
    if nhwc:
        args.G_kwargs.synthesis_kwargs.fp16_channels_last = args.D_kwargs.block_kwargs.fp16_channels_last = True

    if nobench is None:
        nobench = False
    assert isinstance(nobench, bool)
    if nobench:
        args.cudnn_benchmark = False

    if allow_tf32 is None:
        allow_tf32 = False
    assert isinstance(allow_tf32, bool)
    if allow_tf32:
        args.allow_tf32 = True

    if workers is not None:
        assert isinstance(workers, int)
        if not workers >= 1:
            raise UserError('--workers must be at least 1')
        args.data_loader_kwargs.num_workers = workers

    return desc, args
# --- Example #4 ---
def setup_training_loop_kwargs(
    # General options (not included in desc).
    gpus       = None, # Number of GPUs: <int>, default = 1 gpu
    snap       = None, # Snapshot interval: <int>, default = 50 ticks
    metrics    = None, # List of metric names: [], ['fid50k_full'] (default), ...
    seed       = None, # Random seed: <int>, default = 0

    # Dataset.
    data       = None, # Training dataset (required): <path>
    cond       = None, # Train conditional model based on dataset labels: <bool>, default = False
    subset     = None, # Train with only N images: <int>, default = all
    mirror     = None, # Augment dataset with x-flips: <bool>, default = False

    # Base config.
    cfg        = None, # Base config: 'auto' (default), 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar'
    kimg       = None, # Override training duration: <int>
    batch      = None, # Override batch size: <int>

    # Discriminator augmentation.
    aug        = None, # Augmentation mode: 'ada' (default), 'noaug', 'fixed'
    p          = None, # Specify p for 'fixed' (required): <float>
    target     = None, # Override ADA target for 'ada': <float>, default = depends on aug
    augpipe    = None, # Augmentation pipeline: 'blit', 'geom', 'color', 'filter', 'noise', 'cutout', 'bg', 'bgc' (default), ..., 'bgcfnc'

    # Transfer learning.
    resume     = None, # Load previous network: 'noresume' (default), 'ffhq256', 'ffhq512', 'ffhq1024', 'celebahq256', 'lsundog256', <file>, <url>
    freezed    = None, # Freeze-D: <int>, default = 0 discriminator layers

    # Performance options (not included in desc).
    fp32       = None, # Disable mixed-precision training: <bool>, default = False
    nhwc       = None, # Use NHWC memory format with FP16: <bool>, default = False
    nobench    = None, # Disable cuDNN benchmarking: <bool>, default = False
    workers    = None, # Override number of DataLoader workers: <int>, default = 3

    # Let's add the ability to overwrite g_fmaps/d_fmaps
    g_fmaps = None,
    d_fmaps = None,

    # Specify z_dim and w_dim
    z_dim = 512,
    w_dim = 512,

    hydra_cfg_name = None,
):
    """Validate command-line options and translate them into training-loop kwargs.

    Returns:
        (desc, args) where `desc` is a short human-readable run description
        (used to name the output directory) and `args` is a dnnlib.EasyDict
        of keyword arguments for the training loop.

    Raises:
        UserError: on any invalid option value or invalid combination.
    """
    args = dnnlib.EasyDict()

    # Hydra config.
    # NOTE(review): hydra.initialize() raises if Hydra was already initialized
    # in this process -- confirm this function is only ever called once.
    initialize(config_path="../configs", job_name="test_app")
    hydra_cfg = compose(config_name=hydra_cfg_name)

    # ------------------------------------------
    # General options: gpus, snap, metrics, seed
    # ------------------------------------------

    if gpus is None:
        gpus = 1
    assert isinstance(gpus, int)
    if not (gpus >= 1 and gpus & (gpus - 1) == 0):
        raise UserError('--gpus must be a power of two')
    args.num_gpus = gpus

    if snap is None:
        snap = 50
    assert isinstance(snap, int)
    if snap < 1:
        raise UserError('--snap must be at least 1')
    args.image_snapshot_ticks = snap
    args.network_snapshot_ticks = snap

    if metrics is None:
        metrics = ['fid50k_full']
    assert isinstance(metrics, list)
    if not all(metric_main.is_valid_metric(metric) for metric in metrics):
        raise UserError('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
    args.metrics = metrics

    if seed is None:
        seed = 0
    assert isinstance(seed, int)
    args.random_seed = seed

    # -----------------------------------
    # Dataset: data, cond, subset, mirror
    # -----------------------------------

    assert data is not None
    assert isinstance(data, str)

    # Both dataset variants take identical kwargs; only the class differs.
    dataset_class = ('training.dataset.ImageFolderDatasetWithMultiEmbs'
                     if hydra_cfg.dataset.use_multi_embs
                     else 'training.dataset.ImageFolderDataset')
    args.training_set_kwargs = dnnlib.EasyDict(
        class_name=dataset_class,
        path=data,
        use_labels=True,
        max_size=None,
        xflip=False)

    args.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=3, prefetch_factor=2)

    # Instantiate the dataset once to discover resolution/labels/size, then
    # discard it -- the training loop re-creates it from the kwargs.
    try:
        training_set = dnnlib.util.construct_class_by_name(**args.training_set_kwargs) # subclass of training.dataset.Dataset
        args.training_set_kwargs.resolution = training_set.resolution # be explicit about resolution
        args.training_set_kwargs.use_labels = training_set.has_labels # be explicit about labels
        args.training_set_kwargs.max_size = len(training_set) # be explicit about dataset size
        desc = training_set.name
        del training_set # conserve memory
    except IOError as err:
        raise UserError('--data: {}'.format(err))

    if cond is None:
        cond = False

    assert isinstance(cond, bool)

    if cond:
        if not args.training_set_kwargs.use_labels:
            raise UserError('--cond=True requires labels specified in dataset.json')
        desc += '-cond'
    else:
        args.training_set_kwargs.use_labels = False

    # if emb_cond:
    #     desc += '-emb_cond'
    # else:
    #     args.training_set_kwargs.emb_cond = False

    if subset is not None:
        assert isinstance(subset, int)
        if not 1 <= subset <= args.training_set_kwargs.max_size:
            raise UserError('--subset must be between 1 and {}'.format(args.training_set_kwargs.max_size))
        desc += '-subset{}'.format(subset)
        if subset < args.training_set_kwargs.max_size:
            args.training_set_kwargs.max_size = subset
            args.training_set_kwargs.random_seed = args.random_seed

    if mirror is None:
        mirror = False
    assert isinstance(mirror, bool)
    if mirror:
        desc += '-mirror'
        args.training_set_kwargs.xflip = True

    # ------------------------------------
    # Base config: cfg, gamma, kimg, batch
    # ------------------------------------

    if cfg is None:
        cfg = 'auto'
    assert isinstance(cfg, str)
    desc += '-{}'.format(cfg)

    cfg_specs = {
        'auto':      dict(ref_gpus=-1, kimg=25000,  mb=-1, mbstd=-1, g_fmaps=-1,  d_fmaps=-1,  lrate=-1,     gamma=-1,   ema=-1,  ramp=0.05, map=2), # Populated dynamically based on resolution and GPU count.
        'stylegan2': dict(ref_gpus=8,  kimg=25000,  mb=32, mbstd=4,  g_fmaps=1,   d_fmaps=1,   lrate=0.002,  gamma=10,   ema=10,  ramp=None, map=8), # Uses mixed-precision, unlike the original StyleGAN2.
        'paper256':  dict(ref_gpus=8,  kimg=25000,  mb=64, mbstd=8,  g_fmaps=0.5, d_fmaps=0.5, lrate=0.0025, gamma=1,    ema=20,  ramp=None, map=8),
        'paper512':  dict(ref_gpus=8,  kimg=25000,  mb=64, mbstd=8,  g_fmaps=1,   d_fmaps=1,   lrate=0.0025, gamma=0.5,  ema=20,  ramp=None, map=8),
        'paper1024': dict(ref_gpus=8,  kimg=25000,  mb=32, mbstd=4,  g_fmaps=1,   d_fmaps=1,   lrate=0.002,  gamma=2,    ema=10,  ramp=None, map=8),
        'cifar':     dict(ref_gpus=2,  kimg=100000, mb=64, mbstd=32, g_fmaps=1,   d_fmaps=1,   lrate=0.0025, gamma=0.01, ema=500, ramp=0.05, map=2),
    }

    assert cfg in cfg_specs
    spec = dnnlib.EasyDict(cfg_specs[cfg])
    if cfg == 'auto':
        desc += '{:d}'.format(gpus)
        spec.ref_gpus = gpus
        res = args.training_set_kwargs.resolution
        # NOTE: batch size halved vs. the stock StyleGAN2-ADA auto heuristic.
        spec.mb = max(min(gpus * min(4096 // res, 32), 64), gpus)//2 # keep gpu memory consumption at bay
        spec.mbstd = min(spec.mb // gpus, 4) # other hyperparams behave more predictably if mbstd group size remains fixed
        spec.g_fmaps = 1 if res >= 512 else 0.5
        spec.d_fmaps = 1 if res >= 512 else 0.5
        spec.lrate = 0.002 if res >= 1024 else 0.0025
        spec.gamma = 0.0002 * (res ** 2) / spec.mb # heuristic formula
        spec.ema = spec.mb * 10 / 32

    # Optional per-network fmaps overrides (applied after the base config).
    if g_fmaps is not None:
        spec.g_fmaps = g_fmaps

    if d_fmaps is not None:
        spec.d_fmaps = d_fmaps

    # Hydra config may override the R1 gamma chosen above.
    if 'gamma' in hydra_cfg.loss_kwargs:
        spec.gamma = hydra_cfg.loss_kwargs.gamma

    args.G_kwargs = dnnlib.EasyDict(
        class_name='training.networks.Generator',
        z_dim=z_dim,
        w_dim=w_dim,
        mapping_kwargs=dnnlib.EasyDict(),
        synthesis_kwargs=dnnlib.EasyDict())
    args.D_kwargs = dnnlib.EasyDict(
        class_name='training.networks.Discriminator',
        block_kwargs=dnnlib.EasyDict(),
        mapping_kwargs=dnnlib.EasyDict(),
        epilogue_kwargs=dnnlib.EasyDict())
    args.G_kwargs.synthesis_kwargs.channel_base = int(spec.g_fmaps * 32768)
    args.G_kwargs.synthesis_kwargs.channel_max = args.D_kwargs.channel_max = 512
    args.G_kwargs.mapping_kwargs.num_layers = spec.map
    args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 4 # enable mixed-precision training
    args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = 256 # clamp activations to avoid float16 overflow
    args.G_kwargs.synthesis_cfg = OmegaConf.to_container(hydra_cfg.synthesis)
    args.D_kwargs.channel_base = int(spec.d_fmaps * 32768)
    args.D_kwargs.epilogue_kwargs.mbstd_group_size = spec.mbstd

    # Multi-embedding datasets use a modulated discriminator with its own mapping.
    if hydra_cfg.dataset.use_multi_embs:
        args.D_kwargs.mapping_kwargs.num_layers = 2
        args.D_kwargs.block_kwargs.is_modulated = True
    else:
        args.D_kwargs.block_kwargs.is_modulated = False

    args.G_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=spec.lrate, betas=[0,0.99], eps=1e-8)
    args.D_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=spec.lrate, betas=[0,0.99], eps=1e-8)
    args.loss_kwargs = dnnlib.EasyDict(class_name='training.loss.StyleGAN2Loss', r1_gamma=spec.gamma)

    args.total_kimg = spec.kimg
    args.batch_size = spec.mb
    args.batch_gpu = spec.mb // spec.ref_gpus
    args.ema_kimg = spec.ema
    args.ema_rampup = spec.ramp

    if cfg == 'cifar':
        args.loss_kwargs.pl_weight = 0 # disable path length regularization
        args.loss_kwargs.style_mixing_prob = 0 # disable style mixing
        args.D_kwargs.architecture = 'orig' # disable residual skip connections

    # Hydra overrides for individual loss settings (take precedence over cfg).
    if 'style_mixing_prob' in hydra_cfg.loss_kwargs:
        args.loss_kwargs.style_mixing_prob = hydra_cfg.loss_kwargs.style_mixing_prob

    if 'pl_weight' in hydra_cfg.loss_kwargs:
        args.loss_kwargs.pl_weight = hydra_cfg.loss_kwargs.pl_weight

    args.loss_kwargs.synthesis_cfg = OmegaConf.to_container(hydra_cfg.synthesis)

    if kimg is not None:
        assert isinstance(kimg, int)
        if not kimg >= 1:
            raise UserError('--kimg must be at least 1')
        desc += '-kimg{:d}'.format(kimg)
        args.total_kimg = kimg

    if batch is not None:
        assert isinstance(batch, int)
        if not (batch >= 1 and batch % gpus == 0):
            raise UserError('--batch must be at least 1 and divisible by --gpus')
        desc += '-batch{}'.format(batch)
        args.batch_size = batch
        args.batch_gpu = batch // gpus

    # ---------------------------------------------------
    # Discriminator augmentation: aug, p, target, augpipe
    # ---------------------------------------------------

    if aug is None:
        aug = 'ada'
    else:
        assert isinstance(aug, str)
        desc += '-{}'.format(aug)

    if aug == 'ada':
        args.ada_target = 0.6

    elif aug == 'noaug':
        pass

    elif aug == 'fixed':
        if p is None:
            raise UserError('--aug={} requires specifying --p'.format(aug))

    else:
        raise UserError('--aug={} not supported'.format(aug))

    if p is not None:
        assert isinstance(p, float)
        if aug != 'fixed':
            raise UserError('--p can only be specified with --aug=fixed')
        if not 0 <= p <= 1:
            raise UserError('--p must be between 0 and 1')
        desc += '-p{:g}'.format(p)
        args.augment_p = p

    if target is not None:
        assert isinstance(target, float)
        if aug != 'ada':
            raise UserError('--target can only be specified with --aug=ada')
        if not 0 <= target <= 1:
            raise UserError('--target must be between 0 and 1')
        desc += '-target{:g}'.format(target)
        args.ada_target = target

    assert augpipe is None or isinstance(augpipe, str)
    if augpipe is None:
        augpipe = 'bgc'
    else:
        if aug == 'noaug':
            raise UserError('--augpipe cannot be specified with --aug=noaug')
        desc += '-{}'.format(augpipe)

    augpipe_specs = {
        'blit':   dict(xflip=1, rotate90=1, xint=1),
        'geom':   dict(scale=1, rotate=1, aniso=1, xfrac=1),
        'color':  dict(brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
        'filter': dict(imgfilter=1),
        'noise':  dict(noise=1),
        'cutout': dict(cutout=1),
        'bg':     dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1),
        'bgc':    dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
        'bgcf':   dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1),
        'bgcfn':  dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1),
        'bgcfnc': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1, cutout=1),
    }

    assert augpipe in augpipe_specs
    if aug != 'noaug':
        args.augment_kwargs = dnnlib.EasyDict(class_name='training.augment.AugmentPipe', **augpipe_specs[augpipe])

    # ----------------------------------
    # Transfer learning: resume, freezed
    # ----------------------------------

    resume_specs = {
        'ffhq256':     'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res256-mirror-paper256-noaug.pkl',
        'ffhq512':     'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res512-mirror-stylegan2-noaug.pkl',
        'ffhq1024':    'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res1024-mirror-stylegan2-noaug.pkl',
        'celebahq256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/celebahq-res256-mirror-paper256-kimg100000-ada-target0.5.pkl',
        'lsundog256':  'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/lsundog-res256-paper256-kimg100000-noaug.pkl',
    }

    assert resume is None or isinstance(resume, str)
    if resume is None:
        resume = 'noresume'
    elif resume == 'noresume':
        desc += '-noresume'
    elif resume in resume_specs:
        desc += '-resume{}'.format(resume)
        args.resume_pkl = resume_specs[resume] # predefined url
    else:
        desc += '-resumecustom'
        args.resume_pkl = resume # custom path or url

    if resume != 'noresume':
        args.ada_kimg = 100 # make ADA react faster at the beginning
        args.ema_rampup = None # disable EMA rampup

    if freezed is not None:
        assert isinstance(freezed, int)
        if not freezed >= 0:
            raise UserError('--freezed must be non-negative')
        desc += '-freezed{:d}'.format(freezed)
        args.D_kwargs.block_kwargs.freeze_layers = freezed

    # -------------------------------------------------
    # Performance options: fp32, nhwc, nobench, workers
    # -------------------------------------------------

    if fp32 is None:
        fp32 = False
    assert isinstance(fp32, bool)
    if fp32:
        args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 0
        args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = None

    if nhwc is None:
        nhwc = False
    assert isinstance(nhwc, bool)
    if nhwc:
        args.G_kwargs.synthesis_kwargs.fp16_channels_last = args.D_kwargs.block_kwargs.fp16_channels_last = True

    if nobench is None:
        nobench = False
    assert isinstance(nobench, bool)
    if nobench:
        args.cudnn_benchmark = False

    if workers is not None:
        assert isinstance(workers, int)
        if not workers >= 1:
            raise UserError('--workers must be at least 1')
        args.data_loader_kwargs.num_workers = workers

    return desc, args
# --- Example #5 ---
def main(**kwargs):
    """Build the projected-GAN training config from CLI options and launch training.

    After training returns, inspects the latest snapshot and exits with code 3
    if fewer than --kimg thousand images were processed, so that an outer
    restart wrapper can resume the run.
    """
    # Initialize config.
    opts = dnnlib.EasyDict(kwargs) # Command line arguments.
    c = dnnlib.EasyDict() # Main config dict.
    c.G_kwargs = dnnlib.EasyDict(class_name=None, z_dim=64, w_dim=128, mapping_kwargs=dnnlib.EasyDict())
    c.G_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', betas=[0,0.99], eps=1e-8)
    c.D_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', betas=[0,0.99], eps=1e-8)
    c.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, prefetch_factor=2)

    # Training set.
    c.training_set_kwargs, dataset_name = init_dataset_kwargs(data=opts.data)
    if opts.cond and not c.training_set_kwargs.use_labels:
        raise click.ClickException('--cond=True requires labels specified in dataset.json')
    c.training_set_kwargs.use_labels = opts.cond
    c.training_set_kwargs.xflip = opts.mirror

    # Hyperparameters & settings.
    c.num_gpus = opts.gpus
    c.batch_size = opts.batch
    c.batch_gpu = opts.batch_gpu or opts.batch // opts.gpus
    c.G_kwargs.channel_base = opts.cbase
    c.G_kwargs.channel_max = opts.cmax
    c.G_kwargs.mapping_kwargs.num_layers = 2
    c.G_opt_kwargs.lr = (0.002 if opts.cfg == 'stylegan2' else 0.0025) if opts.glr is None else opts.glr
    c.D_opt_kwargs.lr = opts.dlr
    c.metrics = opts.metrics
    c.total_kimg = opts.kimg
    c.kimg_per_tick = opts.tick
    c.image_snapshot_ticks = c.network_snapshot_ticks = opts.snap
    c.random_seed = c.training_set_kwargs.random_seed = opts.seed
    c.data_loader_kwargs.num_workers = opts.workers

    # Sanity checks.
    if c.batch_size % c.num_gpus != 0:
        raise click.ClickException('--batch must be a multiple of --gpus')
    if c.batch_size % (c.num_gpus * c.batch_gpu) != 0:
        raise click.ClickException('--batch must be a multiple of --gpus times --batch-gpu')
    if any(not metric_main.is_valid_metric(metric) for metric in c.metrics):
        raise click.ClickException('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))

    # Base configuration.
    c.ema_kimg = c.batch_size * 10 / 32
    if opts.cfg == 'stylegan2':
        c.G_kwargs.class_name = 'pg_modules.networks_stylegan2.Generator'
        c.G_kwargs.fused_modconv_default = 'inference_only' # Speed up training by using regular convolutions instead of grouped convolutions.
        use_separable_discs = True

    elif opts.cfg in ['fastgan', 'fastgan_lite']:
        c.G_kwargs = dnnlib.EasyDict(class_name='pg_modules.networks_fastgan.Generator', cond=opts.cond, synthesis_kwargs=dnnlib.EasyDict())
        c.G_kwargs.synthesis_kwargs.lite = (opts.cfg == 'fastgan_lite')
        c.G_opt_kwargs.lr = c.D_opt_kwargs.lr = 0.0002
        use_separable_discs = False

    else:
        # Previously an unknown --cfg fell through and crashed later with a
        # NameError on use_separable_discs; fail fast with a clear message.
        raise click.ClickException('--cfg={} not supported'.format(opts.cfg))

    # Resume.
    if opts.resume is not None:
        c.resume_pkl = opts.resume
        c.ema_rampup = None  # Disable EMA rampup.

    # Restart.
    c.restart_every = opts.restart_every

    # Performance-related toggles.
    if opts.fp32:
        c.G_kwargs.num_fp16_res = 0
        c.G_kwargs.conv_clamp = None
    if opts.nobench:
        c.cudnn_benchmark = False

    # Description string.
    desc = f'{opts.cfg:s}-{dataset_name:s}-gpus{c.num_gpus:d}-batch{c.batch_size:d}'
    if opts.desc is not None:
        desc += f'-{opts.desc}'

    # Projected and Multi-Scale Discriminators
    c.loss_kwargs = dnnlib.EasyDict(class_name='training.loss.ProjectedGANLoss')
    c.D_kwargs = dnnlib.EasyDict(
        class_name='pg_modules.discriminator.ProjectedDiscriminator',
        diffaug=True,
        interp224=(c.training_set_kwargs.resolution < 224),  # upsample small datasets for the pretrained backbone
        backbone_kwargs=dnnlib.EasyDict(),
    )

    c.D_kwargs.backbone_kwargs.cout = 64
    c.D_kwargs.backbone_kwargs.expand = True
    c.D_kwargs.backbone_kwargs.proj_type = 2
    c.D_kwargs.backbone_kwargs.num_discs = 4
    c.D_kwargs.backbone_kwargs.separable = use_separable_discs
    c.D_kwargs.backbone_kwargs.cond = opts.cond

    # Launch.
    launch_training(c=c, desc=desc, outdir=opts.outdir, dry_run=opts.dry_run)

    # Check for restart
    # NOTE(review): assumes launch_training stores the run directory in
    # c.run_dir -- confirm against launch_training's implementation.
    last_snapshot = misc.get_ckpt_path(c.run_dir)
    if os.path.isfile(last_snapshot):
        # get current number of training images
        with dnnlib.util.open_url(last_snapshot) as f:
            cur_nimg = legacy.load_network_pkl(f)['progress']['cur_nimg'].item()
        if (cur_nimg//1000) < c.total_kimg:
            print('Restart: exit with code 3')
            exit(3)
# --- Example #6 ---
def setup_config(run_dir, **args):
    """Assemble the training, evaluation and visualization configuration.

    Merges command-line options with per-dataset/per-model defaults (GANformer
    or baseline GAN), validates GPU and batch settings, builds the dataset,
    network, loss and visualization sub-configurations, and saves the final
    config to <run_dir>/training_options.json.

    Args:
        run_dir: Output directory of the experiment; also where the
            training_options.json file is written to (or reloaded from).
        **args: Command-line options as keyword arguments.

    Returns:
        EasyDict with the complete configuration for the training loop.
    """
    args = EasyDict(args)  # command-line options
    train = EasyDict(run_dir=run_dir)  # training loop options
    vis = EasyDict(run_dir=run_dir)  # visualization loop options

    if args.reload:
        config_fn = os.path.join(run_dir, "training_options.json")
        if os.path.exists(config_fn):
            # Load config from the experiment's existing file
            # (and so ignore command-line arguments).
            with open(config_fn, "rt") as f:
                config = json.load(f)
            return config
        misc.log(
            f"Warning: --reload is set for a new experiment {args.expname}," +
            f" but configuration file to reload from {config_fn} doesn't exist.",
            "red")

    # GANformer and baselines default settings
    # ----------------------------------------------------------------------------

    if args.ganformer_default:
        task = args.dataset
        nset(args, "mirror_augment", task in ["cityscapes", "ffhq"])

        nset(args, "transformer", True)
        nset(args, "components_num", {"clevr": 8}.get(task, 16))
        nset(args, "latent_size", {"clevr": 128}.get(task, 512))

        nset(args, "normalize", "layer")
        nset(args, "integration", "mul")
        nset(args, "kmeans", True)
        nset(args, "use_pos", True)
        nset(args, "mapping_ltnt2ltnt", task != "clevr")
        nset(args, "style", task != "clevr")

        nset(args, "g_arch", "resnet")
        nset(args, "mapping_resnet", True)

        gammas = {"ffhq": 10, "cityscapes": 20, "clevr": 40, "bedrooms": 100}
        nset(args, "gamma", gammas.get(task, 10))

    if args.baseline == "GAN":
        nset(args, "style", False)
        nset(args, "latent_stem", True)

    ## k-GAN and SAGAN are not currently supported in the pytorch version.
    ## See the TF version for implementation of these baselines!
    # if args.baseline == "SAGAN":
    #     nset(args, "style", False)
    #     nset(args, "latent_stem", True)
    #     nset(args, "g_img2img", 5)

    # if args.baseline == "kGAN":
    #     nset(args, "kgan", True)
    #     nset(args, "merge_layer", 5)
    #     nset(args, "merge_type", "softmax")
    #     nset(args, "components_num", 8)

    # General setup
    # ----------------------------------------------------------------------------

    # If the flag is specified without arguments (--arg), set to True
    for arg in [
            "cuda_bench", "allow_tf32", "keep_samples", "style", "local_noise"
    ]:
        if args[arg] is None:
            args[arg] = True

    if not any([args.train, args.eval, args.vis]):
        misc.log(
            "Warning: None of --train, --eval or --vis are provided. Therefore, we only print network shapes",
            "red")
    for arg in ["train", "eval", "vis", "last_snapshots"]:
        cset(train, arg, args[arg])

    # BUGFIX: default to a single GPU so num_gpus is always defined.
    # Previously, an empty --gpus string left num_gpus unbound and the
    # power-of-two check below raised a NameError.
    num_gpus = 1
    if args.gpus != "":
        num_gpus = len(args.gpus.split(","))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
    if not (num_gpus >= 1 and num_gpus & (num_gpus - 1) == 0):
        misc.error("Number of GPUs must be a power of two")
    args.num_gpus = num_gpus

    # CUDA settings
    for arg in ["batch_size", "batch_gpu", "allow_tf32"]:
        cset(train, arg, args[arg])
    cset(train, "cudnn_benchmark", args.cuda_bench)

    # Data setup
    # ----------------------------------------------------------------------------

    # For bedrooms, we choose the most common ratio in the
    # dataset and crop the other images into that ratio.
    ratios = {
        "clevr": 0.75,
        "bedrooms": 188 / 256,
        "cityscapes": 0.5,
        "ffhq": 1.0
    }
    args.ratio = args.ratio or ratios.get(args.dataset, 1.0)
    args.crop_ratio = 0.5 if args.resolution > 256 and args.ratio < 0.5 else None

    args.printname = args.expname
    for arg in ["total_kimg", "printname"]:
        cset(train, arg, args[arg])

    dataset_args = EasyDict(class_name="training.dataset.ImageFolderDataset",
                            path=f"{args.data_dir}/{args.dataset}",
                            max_items=args.train_images_num,
                            resolution=args.resolution,
                            ratio=args.ratio,
                            mirror_augment=args.mirror_augment)
    dataset_args.loader_args = EasyDict(num_workers=args.num_threads,
                                        pin_memory=True,
                                        prefetch_factor=2)

    # Optimization setup
    # ----------------------------------------------------------------------------

    cG = set_net("Generator", ["mapping", "synthesis"], args.g_lr, 4)
    cD = set_net("Discriminator", ["mapping", "block", "epilogue"], args.d_lr,
                 16)
    cset([cG, cD], "crop_ratio", args.crop_ratio)

    # Other hyperparams behave more predictably if mbstd group size remains fixed.
    # NOTE(review): assumes args.batch_gpu is set (by CLI default or flag) before
    # autotune runs — confirm against the CLI definition.
    mbstd = min(args.batch_gpu, 4)
    cset(cD.epilogue_kwargs, "mbstd_group_size", mbstd)

    # Automatic tuning
    if args.autotune:
        batch_size = max(
            min(args.num_gpus * min(4096 // args.resolution, 32), 64),
            args.num_gpus)  # keep gpu memory consumption at bay
        # BUGFIX: derive batch_gpu from the freshly computed local batch_size;
        # args.batch_size may still be None here (it is only filled in by the
        # nset calls below).
        batch_gpu = batch_size // args.num_gpus
        nset(args, "batch_size", batch_size)
        nset(args, "batch_gpu", batch_gpu)

        fmap_decay = 1 if args.resolution >= 512 else 0.5  # wider nets for higher resolutions
        lr = 0.002 if args.resolution >= 1024 else 0.0025
        gamma = 0.0002 * (args.resolution**
                          2) / args.batch_size  # heuristic formula

        cset([cG.synthesis_kwargs, cD], "dim_base", int(fmap_decay * 32768))
        nset(args, "g_lr", lr)
        cset(cG.opt_args, "lr", args.g_lr)
        nset(args, "d_lr", lr)
        cset(cD.opt_args, "lr", args.d_lr)
        nset(args, "gamma", gamma)

        train.ema_rampup = 0.05
        train.ema_kimg = batch_size * 10 / 32

    if args.batch_size % (args.batch_gpu * args.num_gpus) != 0:
        misc.error(
            "--batch-size should be divided by --batch-gpu * 'num_gpus'")

    # Loss and regularization settings
    loss_args = EasyDict(class_name="training.loss.StyleGAN2Loss",
                         g_loss=args.g_loss,
                         d_loss=args.d_loss,
                         r1_gamma=args.gamma,
                         pl_weight=args.pl_weight)

    # if args.fp16:
    #     cset([cG.synthesis_kwargs, cD], "num_fp16_layers", 4) # enable mixed-precision training
    #     cset([cG.synthesis_kwargs, cD], "conv_clamp", 256) # clamp activations to avoid float16 overflow

    # cset([cG.synthesis_kwargs, cD.block_args], "fp16_channels_last", args.nhwc)

    # Evaluation and visualization
    # ----------------------------------------------------------------------------

    from metrics import metric_main
    for metric in args.metrics:
        if not metric_main.is_valid_metric(metric):
            misc.error(
                f"Unknown metric: {metric}. The valid metrics are: {metric_main.list_valid_metrics()}"
            )

    for arg in ["num_gpus", "metrics", "eval_images_num", "truncation_psi"]:
        cset(train, arg, args[arg])
    for arg in ["keep_samples", "num_heads"]:
        cset(vis, arg, args[arg])

    args.vis_imgs = args.vis_images
    args.vis_ltnts = args.vis_latents
    vis_types = [
        "imgs", "ltnts", "maps", "layer_maps", "interpolations", "noise_var",
        "style_mix"
    ]
    # Set of all the enabled visualization types
    vis.vis_types = list({arg for arg in vis_types if args[f"vis_{arg}"]})

    vis_args = {
        "attention": "transformer",
        "grid": "vis_grid",
        "num": "vis_num",
        "rich_num": "vis_rich_num",
        "section_size": "vis_section_size",
        "intrp_density": "interpolation_density",
        # "intrp_per_component": "interpolation_per_component",
        "alpha": "blending_alpha"
    }
    for arg, cmd_arg in vis_args.items():
        cset(vis, arg, args[cmd_arg])

    # Networks setup
    # ----------------------------------------------------------------------------

    # Networks architecture
    cset(cG.synthesis_kwargs, "architecture", args.g_arch)
    cset(cD, "architecture", args.d_arch)

    # Latent sizes
    if args.components_num > 0:
        if not args.transformer:  # or args.kgan):
            misc.error(
                "--components-num > 0 but the model is not using components. "
                +
                "Add --transformer for GANformer (which uses latent components)."
            )
        if args.latent_size % args.components_num != 0:
            # BUGFIX: the message referenced an undefined name `k`, raising a
            # NameError instead of the intended error.
            misc.error(
                f"--latent-size ({args.latent_size}) should be divisible by --components-num (k={args.components_num})"
            )
        args.latent_size = int(args.latent_size / args.components_num)

    cG.z_dim = cG.w_dim = args.latent_size
    cset([cG, vis], "k", args.components_num +
         1)  # We add a component to modulate features globally

    # Mapping network
    args.mapping_layer_dim = args.mapping_dim
    for arg in ["num_layers", "layer_dim", "resnet", "shared", "ltnt2ltnt"]:
        field = f"mapping_{arg}"
        cset(cG.mapping_kwargs, arg, args[field])

    # StyleGAN settings
    for arg in ["style", "latent_stem", "local_noise"]:
        cset(cG.synthesis_kwargs, arg, args[arg])

    # GANformer
    cset([cG.synthesis_kwargs, cG.mapping_kwargs], "transformer",
         args.transformer)

    # Attention related settings
    for arg in ["use_pos", "num_heads", "ltnt_gate", "attention_dropout"]:
        cset([cG.mapping_kwargs, cG.synthesis_kwargs], arg, args[arg])

    # Attention types and layers
    for arg in ["start_res", "end_res"
                ]:  # , "local_attention" , "ltnt2ltnt", "img2img", "img2ltnt"
        cset(cG.synthesis_kwargs, arg, args[f"g_{arg}"])

    # Mixing and dropout
    for arg in ["style_mixing", "component_mixing"]:
        cset(loss_args, arg, args[arg])
    cset(cG, "component_dropout", args["component_dropout"])

    # Extra transformer options
    args.norm = args.normalize
    for arg in [
            "norm", "integration", "img_gate", "iterative", "kmeans",
            "kmeans_iters"
    ]:
        cset(cG.synthesis_kwargs, arg, args[arg])

    # Positional encoding
    # args.pos_dim = args.pos_dim or args.latent_size
    for arg in ["dim", "type", "init", "directions_num"]:
        field = f"pos_{arg}"
        cset(cG.synthesis_kwargs, field, args[field])

    # k-GAN
    # for arg in ["layer", "type", "same"]:
    #     field = "merge_{}".format(arg)
    #     cset(cG.args, field, args[field])
    # cset(cG.synthesis_kwargs, "merge", args.kgan)
    # if args.kgan and args.transformer:
    # misc.error("Either have --transformer for GANformer or --kgan for k-GAN, not both")

    config = EasyDict(train)
    config.update(cG=cG,
                  cD=cD,
                  loss_args=loss_args,
                  dataset_args=dataset_args,
                  vis_args=vis)

    # Save config file
    with open(os.path.join(run_dir, "training_options.json"), "wt") as f:
        json.dump(config, f, indent=2)

    return config
# ---------------------------------------------------------------------------
# Example #7 (score: 0)
# ---------------------------------------------------------------------------
def setup_training_loop_kwargs(
        # data
        data=None,      # Training dataset (required): <path>
        resume=None,    # Load previous network: 'noresume' (default), 'ffhq256', 'ffhq512', 'ffhq1024', 'celebahq256', 'lsundog256', <file>, <url>
        mirror=None,    # Augment dataset with x-flips: <bool>, default = False
        cond=None,      # Train conditional model based on dataset labels: <bool>, default = False
        # training
        cfg=None,       # Base config: 'auto' (default), 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar'
        batch=None,     # Override batch size: <int>
        lrate=None,     # Override learning rate size: <float>
        kimg=None,      # Override training duration: <int>
        snap=None,      # Snapshot interval: <int>, default = 5 ticks
        gamma=None,     # Override R1 gamma: <float>
        freezed=None,   # Freeze-D: <int>, default = 0 discriminator layers
        seed=None,      # Random seed: <int>, default = 0
        # discriminator augmentation
        aug=None,       # Augmentation mode: 'ada' (default), 'noaug', 'fixed'
        p=None,         # Specify p for 'fixed' (required): <float>
        target=None,    # Override ADA target for 'ada': <float>, default = depends on aug
        augpipe=None,   # Augmentation pipeline: 'blit', 'geom', 'color', 'filter', 'noise', 'cutout', 'bg', 'bgc' (default), ..., 'bgcfnc'
        # general & perf options
        gpus=None,      # Number of GPUs: <int>, default = 1 gpu
        fp32=None,      # Disable mixed-precision training: <bool>, default = False
        nhwc=None,      # Use NHWC memory format with FP16: <bool>, default = False
        workers=None,   # Override number of DataLoader workers: <int>, default = 3
        nobench=None,   # Disable cuDNN benchmarking: <bool>, default = False
        metrics=None,   # List of metric names: [], ['fid50k_full'] (default), ...
):
    """Validate command-line options and build the training-loop kwargs.

    Checks all options for consistency, derives resolution-dependent
    hyperparameters from the chosen base config, and sets up dataset,
    network, optimizer, loss, augmentation and transfer-learning settings.

    Returns:
        (desc, args): a descriptive run-name string and a dnnlib.EasyDict
        with all keyword arguments for the training loop.

    Raises:
        UserError: on invalid or inconsistent option combinations.
    """
    args = dnnlib.EasyDict()
    args.G_kwargs = dnnlib.EasyDict(class_name='training.networks.Generator',
                                    z_dim=512,
                                    w_dim=512,
                                    mapping_kwargs=dnnlib.EasyDict(),
                                    synthesis_kwargs=dnnlib.EasyDict())
    args.D_kwargs = dnnlib.EasyDict(
        class_name='training.networks.Discriminator',
        block_kwargs=dnnlib.EasyDict(),
        mapping_kwargs=dnnlib.EasyDict(),
        epilogue_kwargs=dnnlib.EasyDict())

    # General options: gpus, snap, seed
    # ------------------------------------------

    assert (gpus >= 1
            and gpus & (gpus - 1) == 0), '--gpus must be a power of two'
    args.num_gpus = gpus

    # BUGFIX: the condition was `snap > 1`, contradicting the message and
    # wrongly rejecting --snap=1.
    assert snap >= 1, '--snap must be at least 1'
    args.image_snapshot_ticks = 1
    args.network_snapshot_ticks = snap

    if metrics is None:
        metrics = ['fid50k_full']
    assert isinstance(metrics, list)
    if not all(metric_main.is_valid_metric(metric) for metric in metrics):
        raise UserError(
            '\n'.join(['--metrics can only contain the following values:'] +
                      metric_main.list_valid_metrics()))
    args.metrics = metrics

    args.random_seed = seed

    # Dataset: data, cond, subset, mirror
    # -----------------------------------

    assert data is not None
    args.training_set_kwargs = dnnlib.EasyDict(
        class_name='training.dataset.ImageFolderDataset',
        path=data,
        use_labels=True,
        max_size=None,
        xflip=False)
    args.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True,
                                              num_workers=3,
                                              prefetch_factor=2)
    try:
        training_set = dnnlib.util.construct_class_by_name(
            **args.training_set_kwargs)  # subclass of training.dataset.Dataset
        args.training_set_kwargs.resolution = res = training_set.resolution  # be explicit about resolution
        args.training_set_kwargs.use_labels = training_set.has_labels  # be explicit about labels
        args.training_set_kwargs.max_size = len(
            training_set)  # be explicit about dataset size
        # !!! custom init res
        image_shape = training_set.image_shape
        init_res = training_set.init_res
        res_log2 = training_set.res_log2
        desc = dataname = training_set.name
        del training_set  # conserve memory
    except IOError as err:
        raise UserError(f'--data: {err}')

    # !!! custom init res
    if list(init_res) == [4, 4]:
        desc += '-%d' % res
    else:
        print(' custom init resolution', init_res)
        args.G_kwargs.init_res = args.D_kwargs.init_res = list(init_res)
        desc += '-%dx%d' % (image_shape[2], image_shape[1])

    args.savenames = [desc.replace(dataname, 'snapshot'), desc]

    if cond:
        if not args.training_set_kwargs.use_labels:
            # raise UserError('--cond=True requires labels specified in dataset.json')
            raise UserError(
                ' put images in flat subdirectories for conditional training')
        desc += '-cond'
    else:
        args.training_set_kwargs.use_labels = False

    if mirror:
        # desc += '-mirror'
        args.training_set_kwargs.xflip = True

    # Base config: cfg, gamma, kimg, batch
    # ------------------------------------

    desc += f'-{cfg}'
    if gpus > 1:
        desc += f'{gpus:d}'

    cfg_specs = {
        'auto': dict(ramp=0.05, map=8),
        'eps': dict(lrate=0.001, ema=10, ramp=0.05, map=8),
        'big': dict(mb=4, fmaps=1, lrate=0.002, gamma=10, ema=10, ramp=None,
                    map=8),  # aydao etc
    }

    assert cfg in cfg_specs
    spec = dnnlib.EasyDict(cfg_specs[cfg])
    if cfg == 'auto':
        # spec.mb = max(min(gpus * min(4096 // res, 32), 64), gpus) # keep gpu memory consumption at bay
        spec.mb = max(min(gpus * min(3072 // res, 32), 64),
                      gpus)  # for 11gb RAM
        spec.fmaps = 1 if res >= 512 else 0.5
        # BUGFIX: honor a user-supplied --lrate; apply the resolution-based
        # default only when lrate was not given (the original condition
        # silently discarded the user's value at res >= 1024).
        spec.lrate = 0.002 if res >= 1024 and lrate is None else lrate or 0.0025
        spec.gamma = 0.0002 * (res**2) / spec.mb  # heuristic formula
        spec.ema = spec.mb * 10 / 32
    elif cfg == 'eps':
        spec.mb = max(min(gpus * min(3072 // res, 32), 64), gpus)
        spec.fmaps = 1 if res >= 512 else 0.5
        spec.gamma = 0.00001 * (res**
                                2) / spec.mb  # !!! my mb 3~4 instead of 32~64
    spec.ref_gpus = gpus
    spec.mbstd = spec.mb // gpus  # min(spec.mb // gpus, 4) # other hyperparams behave more
    # predictably if mbstd group size remains fixed

    args.G_kwargs.synthesis_kwargs.channel_base = args.D_kwargs.channel_base = int(
        spec.fmaps * 32768)  # TWICE MORE than sg2 on tf !!!
    args.G_kwargs.synthesis_kwargs.channel_max = args.D_kwargs.channel_max = 512
    args.G_kwargs.mapping_kwargs.num_layers = spec.map
    args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 4  # enable mixed-precision training
    args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = 256  # clamp act to avoid float16 overflow
    args.D_kwargs.epilogue_kwargs.mbstd_group_size = spec.mbstd

    args.G_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam',
                                        lr=spec.lrate,
                                        betas=[0, 0.99],
                                        eps=1e-8)
    args.D_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam',
                                        lr=spec.lrate,
                                        betas=[0, 0.99],
                                        eps=1e-8)
    args.loss_kwargs = dnnlib.EasyDict(
        class_name='training.loss.StyleGAN2Loss', r1_gamma=spec.gamma)

    args.total_kimg = kimg
    args.batch_size = spec.mb
    args.batch_gpu = spec.mb // spec.ref_gpus
    args.ema_kimg = spec.ema
    args.ema_rampup = spec.ramp

    if gamma is not None:
        assert gamma >= 0, '--gamma must be non-negative'
        desc += f'-gamma{gamma:g}'
        args.loss_kwargs.r1_gamma = gamma

    if batch is not None:
        assert (batch >= 1 and batch % gpus
                == 0), '--batch must be at least 1 and divisible by --gpus'
        desc += f'-batch{batch}'
        args.batch_size = batch
        args.batch_gpu = batch // gpus

    # Discriminator augmentation: aug, p, target, augpipe
    # ---------------------------------------------------

    if aug != 'ada':
        desc += f'-{aug}'

    if aug == 'ada':
        args.ada_target = 0.6
    elif aug == 'noaug':
        pass
    elif aug == 'fixed':
        if p is None:
            raise UserError(f'--aug={aug} requires specifying --p')
    else:
        raise UserError(f'--aug={aug} not supported')

    if p is not None:
        assert aug == 'fixed', '--p can only be specified with --aug=fixed'
        assert 0 <= p <= 1, '--p must be between 0 and 1'
        desc += f'-p{p:g}'
        args.augment_p = p

    if target is not None:
        assert aug == 'ada', '--target can only be specified with --aug=ada'
        assert 0 <= target <= 1, '--target must be between 0 and 1'
        desc += f'-target{target:g}'
        args.ada_target = target

    # BUGFIX: the original condition was inverted — it raised when augpipe
    # was *not* given together with noaug, and a None augpipe then crashed
    # the augpipe_specs lookup below. Default to 'bgc' when unspecified and
    # reject only the explicit augpipe + noaug combination.
    if augpipe is None:
        augpipe = 'bgc'
    elif aug == 'noaug':
        raise UserError('--augpipe cannot be specified with --aug=noaug')
    desc += f'-{augpipe}'
    args.augpipe = augpipe

    augpipe_specs = {
        'blit': dict(xflip=1, rotate90=1, xint=1),
        'geom': dict(scale=1, rotate=1, aniso=1, xfrac=1),
        'color': dict(brightness=1, contrast=1, lumaflip=1, hue=1,
                      saturation=1),
        'filter': dict(imgfilter=1),
        'noise': dict(noise=1),
        'cutout': dict(cutout=1),
        'bg': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1,
                   xfrac=1),
        'bgc': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1,
                    xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1,
                    saturation=1),
        'bgcf': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1,
                     xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1,
                     saturation=1, imgfilter=1),
        'bgcfn': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1,
                      xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1,
                      saturation=1, imgfilter=1, noise=1),
        'bgcfnc': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1,
                       xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1,
                       saturation=1, imgfilter=1, noise=1, cutout=1),
        # !!! custom pipelines
        'bgf_cnc': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1,
                        aniso=1, xfrac=1, contrast=0.23, imgfilter=1,
                        noise=0.11, cutout=0.11),
        'gf_bnc': dict(xflip=.5, xint=.5, scale=1, rotate=1, aniso=1, xfrac=1,
                       rotate_max=.25, imgfilter=1, noise=.5,
                       cutout=.5),  # aug0
    }

    assert augpipe in augpipe_specs, ' unknown augpipe specs: %s' % augpipe
    if aug != 'noaug':
        args.augment_kwargs = dnnlib.EasyDict(
            class_name='training.augment.AugmentPipe',
            **augpipe_specs[augpipe])

    # Transfer learning: resume, freezed
    # ----------------------------------

    resume_specs = {
        'ffhq256':
        'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/'
        'transfer-learning-source-nets/ffhq-res256-mirror-paper256-noaug.pkl',
        'ffhq512':
        'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/'
        'transfer-learning-source-nets/ffhq-res512-mirror-stylegan2-noaug.pkl',
        'ffhq1024':
        'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/'
        'transfer-learning-source-nets/ffhq-res1024-mirror-stylegan2-noaug.pkl',
        'celebahq256':
        'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/'
        'transfer-learning-source-nets/celebahq-res256-mirror-paper256-kimg100000-ada-target0.5.pkl',
        'lsundog256':
        'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/'
        'transfer-learning-source-nets/lsundog-res256-paper256-kimg100000-noaug.pkl',
    }

    if resume is None:
        pass
    elif resume in resume_specs:
        # desc += f'-resume{resume}'
        args.resume_pkl = resume_specs[resume]  # predefined url
    else:
        # desc += '-resumecustom'
        args.resume_pkl = resume  # custom path or url

    if resume is not None:
        args.ada_kimg = 100  # make ADA react faster at the beginning
        args.ema_rampup = None  # disable EMA rampup

    if freezed is not None:
        assert freezed >= 0, '--freezed must be non-negative'
        desc += f'-freezed{freezed:d}'
        args.D_kwargs.block_kwargs.freeze_layers = freezed

    # Performance options: fp32, nhwc, nobench, workers
    # -------------------------------------------------

    if fp32:
        args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 0
        args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = None

    if nhwc:
        args.G_kwargs.synthesis_kwargs.fp16_channels_last = args.D_kwargs.block_kwargs.fp16_channels_last = True

    if nobench:
        args.cudnn_benchmark = False

    if workers is not None:
        assert workers >= 1, '--workers must be at least 1'
        args.data_loader_kwargs.num_workers = workers

    return desc, args