Example #1
def setup_training_loop_kwargs(
        # General options (not included in desc).
        gpus=None,  # Number of GPUs: <int>, default = 1 gpu
        snap=None,  # Snapshot interval: <int>, default = 50 ticks
        metrics=None,  # List of metric names: [], ['fid50k_full'] (default), ...
        seed=None,  # Random seed: <int>, default = 0

        # Dataset.
        data=None,  # Training dataset (required): <path>
        cond=None,  # Train conditional model based on dataset labels: <bool>, default = False
        subset=None,  # Train with only N images: <int>, default = all
        mirror=None,  # Augment dataset with x-flips: <bool>, default = False

        # Base config.
        cfg=None,  # Base config: 'auto' (default), 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar'
        gamma=None,  # Override R1 gamma: <float>
        kimg=None,  # Override training duration: <int>
        batch=None,  # Override batch size: <int>

        # Discriminator augmentation.
        aug=None,  # Augmentation mode: 'ada' (default), 'noaug', 'fixed'
        p=None,  # Specify p for 'fixed' (required): <float>
        target=None,  # Override ADA target for 'ada': <float>, default = depends on aug
        augpipe=None,  # Augmentation pipeline: 'blit', 'geom', 'color', 'filter', 'noise', 'cutout', 'bg', 'bgc' (default), ..., 'bgcfnc'

        # Transfer learning.
        resume=None,  # Load previous network: 'noresume' (default), 'ffhq256', 'ffhq512', 'ffhq1024', 'celebahq256', 'lsundog256', <file>, <url>
        freezed=None,  # Freeze-D: <int>, default = 0 discriminator layers

        # Performance options (not included in desc).
        fp32=None,  # Disable mixed-precision training: <bool>, default = False
        nhwc=None,  # Use NHWC memory format with FP16: <bool>, default = False
        allow_tf32=None,  # Allow PyTorch to use TF32 for matmul and convolutions: <bool>, default = False
        nobench=None,  # Disable cuDNN benchmarking: <bool>, default = False
        workers=None,  # Override number of DataLoader workers: <int>, default = 3
):
    args = dnnlib.EasyDict()

    # ------------------------------------------
    # General options: gpus, snap, metrics, seed
    # ------------------------------------------

    if gpus is None:
        gpus = 1
    assert isinstance(gpus, int)
    if not (gpus >= 1 and gpus & (gpus - 1) == 0):
        raise UserError('--gpus must be a power of two')
    args.num_gpus = gpus

    if snap is None:
        snap = 50
    assert isinstance(snap, int)
    if snap < 1:
        raise UserError('--snap must be at least 1')
    args.image_snapshot_ticks = snap
    args.network_snapshot_ticks = snap
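    # (a "tick" is the training loop's progress unit; upstream training_loop.py uses 4 kimg per tick by default)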

    if metrics is None:
        metrics = ['fid50k_full']
    assert isinstance(metrics, list)
    if not all(metric_main.is_valid_metric(metric) for metric in metrics):
        raise UserError(
            '\n'.join(['--metrics can only contain the following values:'] +
                      metric_main.list_valid_metrics()))
    args.metrics = metrics

    if seed is None:
        seed = 0
    assert isinstance(seed, int)
    args.random_seed = seed

    # -----------------------------------
    # Dataset: data, cond, subset, mirror
    # -----------------------------------

    assert data is not None
    assert isinstance(data, str)
    args.training_set_kwargs = dnnlib.EasyDict(
        class_name='training.dataset.ImageFolderDataset',
        path=data,
        use_labels=True,
        max_size=None,
        xflip=False)
    args.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True,
                                              num_workers=3,
                                              prefetch_factor=2)
    try:
        training_set = dnnlib.util.construct_class_by_name(**args.training_set_kwargs)  # subclass of training.dataset.Dataset
        args.training_set_kwargs.resolution = training_set.resolution  # be explicit about resolution
        args.training_set_kwargs.use_labels = training_set.has_labels  # be explicit about labels
        args.training_set_kwargs.max_size = len(training_set)  # be explicit about dataset size
        desc = training_set.name
        del training_set  # conserve memory
    except IOError as err:
        raise UserError(f'--data: {err}')

    if cond is None:
        cond = False
    assert isinstance(cond, bool)
    if cond:
        if not args.training_set_kwargs.use_labels:
            raise UserError(
                '--cond=True requires labels specified in dataset.json')
        desc += '-cond'
    else:
        args.training_set_kwargs.use_labels = False

    if subset is not None:
        assert isinstance(subset, int)
        if not 1 <= subset <= args.training_set_kwargs.max_size:
            raise UserError(
                f'--subset must be between 1 and {args.training_set_kwargs.max_size}'
            )
        desc += f'-subset{subset}'
        if subset < args.training_set_kwargs.max_size:
            args.training_set_kwargs.max_size = subset
            args.training_set_kwargs.random_seed = args.random_seed

    if mirror is None:
        mirror = False
    assert isinstance(mirror, bool)
    if mirror:
        desc += '-mirror'
        args.training_set_kwargs.xflip = True

    # ------------------------------------
    # Base config: cfg, gamma, kimg, batch
    # ------------------------------------

    if cfg is None:
        cfg = 'auto'
    assert isinstance(cfg, str)
    desc += f'-{cfg}'

    cfg_specs = {
        'auto':      dict(ref_gpus=-1, kimg=25000,  mb=-1, mbstd=-1, fmaps=-1,  lrate=-1,     gamma=-1,   ema=-1,  ramp=0.05, map=2),  # Populated dynamically based on resolution and GPU count.
        'stylegan2': dict(ref_gpus=8,  kimg=25000,  mb=32, mbstd=4,  fmaps=1,   lrate=0.002,  gamma=10,   ema=10,  ramp=None, map=8),  # Uses mixed-precision, unlike the original StyleGAN2.
        'paper256':  dict(ref_gpus=8,  kimg=25000,  mb=64, mbstd=8,  fmaps=0.5, lrate=0.0025, gamma=1,    ema=20,  ramp=None, map=8),
        'paper512':  dict(ref_gpus=8,  kimg=25000,  mb=64, mbstd=8,  fmaps=1,   lrate=0.0025, gamma=0.5,  ema=20,  ramp=None, map=8),
        'paper1024': dict(ref_gpus=8,  kimg=25000,  mb=32, mbstd=4,  fmaps=1,   lrate=0.002,  gamma=2,    ema=10,  ramp=None, map=8),
        'cifar':     dict(ref_gpus=2,  kimg=100000, mb=64, mbstd=32, fmaps=1,   lrate=0.0025, gamma=0.01, ema=500, ramp=0.05, map=2),
    }
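    # spec fields: ref_gpus = GPU count the config was tuned for, mb = total minibatch size,
    # mbstd = minibatch-stddev group size, fmaps = channel-base multiplier, lrate = Adam learning rate,
    # gamma = R1 weight, ema = G EMA half-life in kimg, ramp = EMA rampup ratio, map = mapping-net depth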

    assert cfg in cfg_specs
    spec = dnnlib.EasyDict(cfg_specs[cfg])
    if cfg == 'auto':
        desc += f'{gpus:d}'
        spec.ref_gpus = gpus
        res = args.training_set_kwargs.resolution
        spec.mb = max(min(gpus * min(4096 // res, 32), 64), gpus)  # keep gpu memory consumption at bay
        spec.mbstd = min(spec.mb // gpus, 4)  # other hyperparams behave more predictably if mbstd group size remains fixed
        spec.fmaps = 1 if res >= 512 else 0.5
        spec.lrate = 0.002 if res >= 1024 else 0.0025
        spec.gamma = 0.0002 * (res ** 2) / spec.mb  # heuristic formula
        spec.ema = spec.mb * 10 / 32
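        # e.g. res=256 on 1 GPU: mb = max(min(1 * min(4096 // 256, 32), 64), 1) = 16,
        # mbstd = min(16, 4) = 4, gamma = 0.0002 * 256**2 / 16 = 0.8192, ema = 16 * 10 / 32 = 5.0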

    args.G_kwargs = dnnlib.EasyDict(class_name='training.networks.Generator',
                                    z_dim=512,
                                    w_dim=512,
                                    mapping_kwargs=dnnlib.EasyDict(),
                                    synthesis_kwargs=dnnlib.EasyDict())
    args.D_kwargs = dnnlib.EasyDict(
        class_name='training.networks.Discriminator',
        block_kwargs=dnnlib.EasyDict(),
        mapping_kwargs=dnnlib.EasyDict(),
        epilogue_kwargs=dnnlib.EasyDict())
    args.G_kwargs.synthesis_kwargs.channel_base = args.D_kwargs.channel_base = int(spec.fmaps * 32768)
    args.G_kwargs.synthesis_kwargs.channel_max = args.D_kwargs.channel_max = 512
    args.G_kwargs.mapping_kwargs.num_layers = 3  # edited: hard-coded to 3 mapping layers instead of spec.map
    args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 4  # enable mixed-precision training
    args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = 256  # clamp activations to avoid float16 overflow
    args.D_kwargs.epilogue_kwargs.mbstd_group_size = spec.mbstd

    args.G_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam',
                                        lr=spec.lrate,
                                        betas=[0, 0.99],
                                        eps=1e-8)
    args.D_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam',
                                        lr=spec.lrate,
                                        betas=[0, 0.99],
                                        eps=1e-8)
    args.loss_kwargs = dnnlib.EasyDict(
        class_name='training.loss.StyleGAN2Loss', r1_gamma=spec.gamma)

    args.total_kimg = spec.kimg
    args.batch_size = spec.mb
    args.batch_gpu = spec.mb // spec.ref_gpus
    args.ema_kimg = spec.ema
    args.ema_rampup = spec.ramp

    if cfg == 'cifar':
        args.loss_kwargs.pl_weight = 0  # disable path length regularization
        args.loss_kwargs.style_mixing_prob = 0  # disable style mixing
        args.D_kwargs.architecture = 'orig'  # disable residual skip connections

    if gamma is not None:
        assert isinstance(gamma, float)
        if not gamma >= 0:
            raise UserError('--gamma must be non-negative')
        desc += f'-gamma{gamma:g}'
        args.loss_kwargs.r1_gamma = gamma

    if kimg is not None:
        assert isinstance(kimg, int)
        if not kimg >= 1:
            raise UserError('--kimg must be at least 1')
        desc += f'-kimg{kimg:d}'
        args.total_kimg = kimg

    if batch is not None:
        assert isinstance(batch, int)
        if not (batch >= 1 and batch % gpus == 0):
            raise UserError(
                '--batch must be at least 1 and divisible by --gpus')
        desc += f'-batch{batch}'
        args.batch_size = batch
        args.batch_gpu = batch // gpus
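    # at this point desc reads e.g. 'ffhq-mirror-auto1-gamma10-batch16', depending on the options chosen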

    # ---------------------------------------------------
    # Discriminator augmentation: aug, p, target, augpipe
    # ---------------------------------------------------

    if aug is None:
        aug = 'ada'
    else:
        assert isinstance(aug, str)
        desc += f'-{aug}'

    if aug == 'ada':
        args.ada_target = 0.6

    elif aug == 'noaug':
        pass

    elif aug == 'fixed':
        if p is None:
            raise UserError(f'--aug={aug} requires specifying --p')

    else:
        raise UserError(f'--aug={aug} not supported')

    if p is not None:
        assert isinstance(p, float)
        if aug != 'fixed':
            raise UserError('--p can only be specified with --aug=fixed')
        if not 0 <= p <= 1:
            raise UserError('--p must be between 0 and 1')
        desc += f'-p{p:g}'
        args.augment_p = p

    if target is not None:
        assert isinstance(target, float)
        if aug != 'ada':
            raise UserError('--target can only be specified with --aug=ada')
        if not 0 <= target <= 1:
            raise UserError('--target must be between 0 and 1')
        desc += f'-target{target:g}'
        args.ada_target = target

    assert augpipe is None or isinstance(augpipe, str)
    if augpipe is None:
        augpipe = 'bgc'
    else:
        if aug == 'noaug':
            raise UserError('--augpipe cannot be specified with --aug=noaug')
        desc += f'-{augpipe}'
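    # pipeline names compose one letter per transform group: b=blit, g=geom, c=color,
    # f=filter, n=noise, plus a final c=cutout (so 'bgcfnc' enables all six groups)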

    augpipe_specs = {
        'blit':   dict(xflip=1, rotate90=1, xint=1),
        'geom':   dict(scale=1, rotate=1, aniso=1, xfrac=1),
        'color':  dict(brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
        'filter': dict(imgfilter=1),
        'noise':  dict(noise=1),
        'cutout': dict(cutout=1),
        'bg':     dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1),
        'bgc':    dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
        'bgcf':   dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1),
        'bgcfn':  dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1),
        'bgcfnc': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1, cutout=1),
    }

    assert augpipe in augpipe_specs
    if aug != 'noaug':
        args.augment_kwargs = dnnlib.EasyDict(
            class_name='training.augment.AugmentPipe',
            **augpipe_specs[augpipe])

    # ----------------------------------
    # Transfer learning: resume, freezed
    # ----------------------------------

    resume_specs = {
        'ffhq256':
        'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res256-mirror-paper256-noaug.pkl',
        'ffhq512':
        'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res512-mirror-stylegan2-noaug.pkl',
        'ffhq1024':
        'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res1024-mirror-stylegan2-noaug.pkl',
        'celebahq256':
        'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/celebahq-res256-mirror-paper256-kimg100000-ada-target0.5.pkl',
        'lsundog256':
        'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/lsundog-res256-paper256-kimg100000-noaug.pkl',
    }

    assert resume is None or isinstance(resume, str)
    if resume is None:
        resume = 'noresume'
    elif resume == 'noresume':
        desc += '-noresume'
    elif resume in resume_specs:
        desc += f'-resume{resume}'
        args.resume_pkl = resume_specs[resume]  # predefined url
    else:
        desc += '-resumecustom'
        args.resume_pkl = resume  # custom path or url

    if resume != 'noresume':
        args.ada_kimg = 100  # make ADA react faster at the beginning
        args.ema_rampup = None  # disable EMA rampup

    if freezed is not None:
        assert isinstance(freezed, int)
        if not freezed >= 0:
            raise UserError('--freezed must be non-negative')
        desc += f'-freezed{freezed:d}'
        args.D_kwargs.block_kwargs.freeze_layers = freezed

    # --------------------------------------------------------------
    # Performance options: fp32, nhwc, allow_tf32, nobench, workers
    # --------------------------------------------------------------

    if fp32 is None:
        fp32 = False
    assert isinstance(fp32, bool)
    if fp32:
        args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 0
        args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = None

    if nhwc is None:
        nhwc = False
    assert isinstance(nhwc, bool)
    if nhwc:
        args.G_kwargs.synthesis_kwargs.fp16_channels_last = args.D_kwargs.block_kwargs.fp16_channels_last = True

    if nobench is None:
        nobench = False
    assert isinstance(nobench, bool)
    if nobench:
        args.cudnn_benchmark = False

    if allow_tf32 is None:
        allow_tf32 = False
    assert isinstance(allow_tf32, bool)
    if allow_tf32:
        args.allow_tf32 = True
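        # consumed later by the training loop; upstream training_loop.py sets
        # torch.backends.cuda.matmul.allow_tf32 and torch.backends.cudnn.allow_tf32 from this flag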

    if workers is not None:
        assert isinstance(workers, int)
        if not workers >= 1:
            raise UserError('--workers must be at least 1')
        args.data_loader_kwargs.num_workers = workers

    return desc, args
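A minimal usage sketch (assuming the function above plus dnnlib and UserError from stylegan2-ada-pytorch; the output-directory handling mirrors the reference train.py, but the paths and option values here are illustrative):

import os
import json

def launch_training(outdir, **config_kwargs):
    # Resolve CLI-style options into a run description and training-loop kwargs.
    run_desc, args = setup_training_loop_kwargs(**config_kwargs)
    args.run_dir = os.path.join(outdir, run_desc)
    os.makedirs(args.run_dir, exist_ok=True)
    # Persist the resolved options next to the run, as the reference train.py does.
    with open(os.path.join(args.run_dir, 'training_options.json'), 'w') as f:
        json.dump(args, f, indent=2)
    # ...then hand `args` to training.training_loop.training_loop(**args).

launch_training(outdir='./results', data='./datasets/ffhq256.zip', gpus=1, mirror=True)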
Example #2
def setup_training_loop_kwargs(
        # data
        data=None,  # Training dataset (required): <path>
        resume=None,  # Load previous network: 'noresume' (default), 'ffhq256', 'ffhq512', 'ffhq1024', 'celebahq256', 'lsundog256', <file>, <url>
        mirror=None,  # Augment dataset with x-flips: <bool>, default = False
        cond=None,  # Train conditional model based on dataset labels: <bool>, default = False

        # training
        cfg=None,  # Base config: 'auto' (default), 'eps', 'big'
        batch=None,  # Override batch size: <int>
        lrate=None,  # Override learning rate: <float>
        kimg=None,  # Override training duration: <int>
        snap=None,  # Snapshot interval: <int>, default = 5 ticks
        gamma=None,  # Override R1 gamma: <float>
        freezed=None,  # Freeze-D: <int>, default = 0 discriminator layers
        seed=None,  # Random seed: <int>, default = 0

        # d augment
        aug=None,  # Augmentation mode: 'ada' (default), 'noaug', 'fixed'
        p=None,  # Specify p for 'fixed' (required): <float>
        target=None,  # Override ADA target for 'ada': <float>, default = depends on aug
        augpipe=None,  # Augmentation pipeline: 'blit', 'geom', 'color', 'filter', 'noise', 'cutout', 'bg', 'bgc' (default), ..., 'bgcfnc'

        # general & perf options
        gpus=None,  # Number of GPUs: <int>, default = 1 gpu
        fp32=None,  # Disable mixed-precision training: <bool>, default = False
        nhwc=None,  # Use NHWC memory format with FP16: <bool>, default = False
        workers=None,  # Override number of DataLoader workers: <int>, default = 3
        nobench=None,  # Disable cuDNN benchmarking: <bool>, default = False
        metrics=None,  # List of metric names: [], ['fid50k_full'] (default), ...
):
    args = dnnlib.EasyDict()
    args.G_kwargs = dnnlib.EasyDict(class_name='training.networks.Generator',
                                    z_dim=512,
                                    w_dim=512,
                                    mapping_kwargs=dnnlib.EasyDict(),
                                    synthesis_kwargs=dnnlib.EasyDict())
    args.D_kwargs = dnnlib.EasyDict(
        class_name='training.networks.Discriminator',
        block_kwargs=dnnlib.EasyDict(),
        mapping_kwargs=dnnlib.EasyDict(),
        epilogue_kwargs=dnnlib.EasyDict())

    # General options: gpus, snap, seed
    # ------------------------------------------

    assert gpus >= 1 and gpus & (gpus - 1) == 0, '--gpus must be a power of two'
    args.num_gpus = gpus

    assert snap >= 1, '--snap must be at least 1'
    args.image_snapshot_ticks = 1
    args.network_snapshot_ticks = snap

    if metrics is None:
        metrics = ['fid50k_full']
    assert isinstance(metrics, list)
    if not all(metric_main.is_valid_metric(metric) for metric in metrics):
        raise UserError(
            '\n'.join(['--metrics can only contain the following values:'] +
                      metric_main.list_valid_metrics()))
    args.metrics = metrics

    args.random_seed = seed

    # Dataset: data, cond, subset, mirror
    # -----------------------------------

    assert data is not None
    args.training_set_kwargs = dnnlib.EasyDict(
        class_name='training.dataset.ImageFolderDataset',
        path=data,
        use_labels=True,
        max_size=None,
        xflip=False)
    args.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True,
                                              num_workers=3,
                                              prefetch_factor=2)
    try:
        training_set = dnnlib.util.construct_class_by_name(**args.training_set_kwargs)  # subclass of training.dataset.Dataset
        args.training_set_kwargs.resolution = res = training_set.resolution  # be explicit about resolution
        args.training_set_kwargs.use_labels = training_set.has_labels  # be explicit about labels
        args.training_set_kwargs.max_size = len(training_set)  # be explicit about dataset size
        # !!! custom init res
        image_shape = training_set.image_shape
        init_res = training_set.init_res
        res_log2 = training_set.res_log2
        desc = dataname = training_set.name
        del training_set  # conserve memory
    except IOError as err:
        raise UserError(f'--data: {err}')

    # !!! custom init res
    if list(init_res) == [4, 4]:
        desc += '-%d' % res
    else:
        print(' custom init resolution', init_res)
        args.G_kwargs.init_res = args.D_kwargs.init_res = list(init_res)
        desc += '-%dx%d' % (image_shape[2], image_shape[1])

    args.savenames = [desc.replace(dataname, 'snapshot'), desc]

    if cond:
        if not args.training_set_kwargs.use_labels:
            # raise UserError('--cond=True requires labels specified in dataset.json')
            raise UserError(
                ' put images in flat subdirectories for conditional training')
        desc += '-cond'
    else:
        args.training_set_kwargs.use_labels = False

    if mirror:
        # desc += '-mirror'
        args.training_set_kwargs.xflip = True

    # Base config: cfg, gamma, kimg, batch
    # ------------------------------------

    desc += f'-{cfg}'
    if gpus > 1:
        desc += f'{gpus:d}'

    cfg_specs = {
        'auto': dict(ramp=0.05, map=8),
        'eps':  dict(lrate=0.001, ema=10, ramp=0.05, map=8),
        'big':  dict(mb=4, fmaps=1, lrate=0.002, gamma=10, ema=10, ramp=None, map=8),  # aydao etc
    }

    assert cfg in cfg_specs
    spec = dnnlib.EasyDict(cfg_specs[cfg])
    if cfg == 'auto':
        # spec.mb = max(min(gpus * min(4096 // res, 32), 64), gpus) # keep gpu memory consumption at bay
        spec.mb = max(min(gpus * min(3072 // res, 32), 64), gpus)  # for 11gb RAM
        spec.fmaps = 1 if res >= 512 else 0.5
        spec.lrate = lrate if lrate is not None else (0.002 if res >= 1024 else 0.0025)  # honor an explicit --lrate override
        spec.gamma = 0.0002 * (res ** 2) / spec.mb  # heuristic formula
        spec.ema = spec.mb * 10 / 32
    elif cfg == 'eps':
        spec.mb = max(min(gpus * min(3072 // res, 32), 64), gpus)
        spec.fmaps = 1 if res >= 512 else 0.5
        spec.gamma = 0.00001 * (res ** 2) / spec.mb  # !!! my mb 3~4 instead of 32~64
    spec.ref_gpus = gpus
    spec.mbstd = spec.mb // gpus  # min(spec.mb // gpus, 4) # other hyperparams behave more predictably if mbstd group size remains fixed
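    # e.g. res=1024 on 1 GPU with cfg='auto': mb = max(min(1 * min(3072 // 1024, 32), 64), 1) = 3,
    # mbstd = 3 // 1 = 3, gamma = 0.0002 * 1024**2 / 3 ≈ 69.9 (with cfg='eps': 0.00001 * 1024**2 / 3 ≈ 3.5)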

    args.G_kwargs.synthesis_kwargs.channel_base = args.D_kwargs.channel_base = int(spec.fmaps * 32768)  # twice the channel base of the TF StyleGAN2 !!!
    args.G_kwargs.synthesis_kwargs.channel_max = args.D_kwargs.channel_max = 512
    args.G_kwargs.mapping_kwargs.num_layers = spec.map
    args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 4  # enable mixed-precision training
    args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = 256  # clamp activations to avoid float16 overflow
    args.D_kwargs.epilogue_kwargs.mbstd_group_size = spec.mbstd

    args.G_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam',
                                        lr=spec.lrate,
                                        betas=[0, 0.99],
                                        eps=1e-8)
    args.D_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam',
                                        lr=spec.lrate,
                                        betas=[0, 0.99],
                                        eps=1e-8)
    args.loss_kwargs = dnnlib.EasyDict(
        class_name='training.loss.StyleGAN2Loss', r1_gamma=spec.gamma)

    args.total_kimg = kimg
    args.batch_size = spec.mb
    args.batch_gpu = spec.mb // spec.ref_gpus
    args.ema_kimg = spec.ema
    args.ema_rampup = spec.ramp

    if gamma is not None:
        assert gamma >= 0, '--gamma must be non-negative'
        desc += f'-gamma{gamma:g}'
        args.loss_kwargs.r1_gamma = gamma

    if batch is not None:
        assert batch >= 1 and batch % gpus == 0, '--batch must be at least 1 and divisible by --gpus'
        desc += f'-batch{batch}'
        args.batch_size = batch
        args.batch_gpu = batch // gpus

    # Discriminator augmentation: aug, p, target, augpipe
    # ---------------------------------------------------

    if aug != 'ada':
        desc += f'-{aug}'

    if aug == 'ada':
        args.ada_target = 0.6
    elif aug == 'noaug':
        pass
    elif aug == 'fixed':
        if p is None:
            raise UserError(f'--aug={aug} requires specifying --p')
    else:
        raise UserError(f'--aug={aug} not supported')

    if p is not None:
        assert aug == 'fixed', '--p can only be specified with --aug=fixed'
        assert 0 <= p <= 1, '--p must be between 0 and 1'
        desc += f'-p{p:g}'
        args.augment_p = p

    if target is not None:
        assert aug == 'ada', '--target can only be specified with --aug=ada'
        assert 0 <= target <= 1, '--target must be between 0 and 1'
        desc += f'-target{target:g}'
        args.ada_target = target

    if augpipe is None:
        augpipe = 'bgc'
    elif aug == 'noaug':
        raise UserError('--augpipe cannot be specified with --aug=noaug')
    desc += f'-{augpipe}'
    args.augpipe = augpipe

    augpipe_specs = {
        'blit':   dict(xflip=1, rotate90=1, xint=1),
        'geom':   dict(scale=1, rotate=1, aniso=1, xfrac=1),
        'color':  dict(brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
        'filter': dict(imgfilter=1),
        'noise':  dict(noise=1),
        'cutout': dict(cutout=1),
        'bg':     dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1),
        'bgc':    dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
        'bgcf':   dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1),
        'bgcfn':  dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1),
        'bgcfnc': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1, cutout=1),
        # !!!
        'bgf_cnc': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, contrast=0.23, imgfilter=1, noise=0.11, cutout=0.11),
        'gf_bnc':  dict(xflip=.5, xint=.5, scale=1, rotate=1, aniso=1, xfrac=1, rotate_max=.25, imgfilter=1, noise=.5, cutout=.5),  # aug0
    }

    assert augpipe in augpipe_specs, 'unknown augpipe spec: %s' % augpipe
    if aug != 'noaug':
        args.augment_kwargs = dnnlib.EasyDict(
            class_name='training.augment.AugmentPipe',
            **augpipe_specs[augpipe])

    # Transfer learning: resume, freezed
    # ----------------------------------

    resume_specs = {
        'ffhq256':
        'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/'
        'transfer-learning-source-nets/ffhq-res256-mirror-paper256-noaug.pkl',
        'ffhq512':
        'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/'
        'transfer-learning-source-nets/ffhq-res512-mirror-stylegan2-noaug.pkl',
        'ffhq1024':
        'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/'
        'transfer-learning-source-nets/ffhq-res1024-mirror-stylegan2-noaug.pkl',
        'celebahq256':
        'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/'
        'transfer-learning-source-nets/celebahq-res256-mirror-paper256-kimg100000-ada-target0.5.pkl',
        'lsundog256':
        'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/'
        'transfer-learning-source-nets/lsundog-res256-paper256-kimg100000-noaug.pkl',
    }

    if resume is None:
        pass
    elif resume in resume_specs:
        # desc += f'-resume{resume}'
        args.resume_pkl = resume_specs[resume]  # predefined url
    else:
        # desc += '-resumecustom'
        args.resume_pkl = resume  # custom path or url

    if resume is not None:
        args.ada_kimg = 100  # make ADA react faster at the beginning
        args.ema_rampup = None  # disable EMA rampup

    if freezed is not None:
        assert freezed >= 0, '--freezed must be non-negative'
        desc += f'-freezed{freezed:d}'
        args.D_kwargs.block_kwargs.freeze_layers = freezed

    # Performance options: fp32, nhwc, nobench, workers
    # -------------------------------------------------

    if fp32:
        args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 0
        args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = None

    if nhwc:
        args.G_kwargs.synthesis_kwargs.fp16_channels_last = args.D_kwargs.block_kwargs.fp16_channels_last = True

    if nobench:
        args.cudnn_benchmark = False

    if workers is not None:
        assert workers >= 1, '--workers must be at least 1'
        args.data_loader_kwargs.num_workers = workers

    return desc, args
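Compared with Example #1, this variant leaves default handling to the CLI layer, so gpus, snap, seed, cfg, kimg, aug, and augpipe must arrive as concrete values. A hedged call sketch (all option values are illustrative):

run_desc, args = setup_training_loop_kwargs(
    data='./datasets/mydata', gpus=1, snap=5, seed=0, cfg='auto',
    kimg=5000, aug='ada', augpipe='bgc', mirror=True)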