Example #1
def G_main(
    latents_in,  # First input: Latent vectors (Z) [minibatch, latent_size].
    labels_in,  # Second input: Conditioning labels [minibatch, label_size].
    truncation_psi=0.5,  # Style strength multiplier for the truncation trick. None = disable.
    truncation_cutoff=None,  # Number of layers for which to apply the truncation trick. None = disable.
    truncation_psi_val=None,  # Value for truncation_psi to use during validation.
    truncation_cutoff_val=None,  # Value for truncation_cutoff to use during validation.
    dlatent_avg_beta=0.995,  # Decay for tracking the moving average of W during training. None = disable.
    style_mixing_prob=0.9,  # Probability of mixing styles during training. None = disable.
    is_training=False,  # Network is under training? Enables and disables specific features.
    is_validation=False,  # Network is under validation? Chooses which value to use for truncation_psi.
    return_dlatents=False,  # Return dlatents in addition to the images?
    is_template_graph=False,  # True = template graph constructed by the Network class, False = actual evaluation.
    components=dnnlib.EasyDict(),  # Container for sub-networks. Retained between calls.
    mapping_latent_func='G_mapping_latent',  # Build func name for the latent mapping network.
    mapping_label_func='label_mapping',  # Build func name for the label mapping network.
    synthesis_func='G_synthesis_stylegan2',  # Build func name for the synthesis network.
    **kwargs):  # Arguments for sub-networks (mapping and synthesis).

    # Validate arguments.
    assert not is_training or not is_validation
    assert isinstance(components, dnnlib.EasyDict)
    if is_validation:
        truncation_psi = truncation_psi_val
        truncation_cutoff = truncation_cutoff_val
    if is_training or (truncation_psi is not None
                       and not tflib.is_tf_expression(truncation_psi)
                       and truncation_psi == 1):
        truncation_psi = None
    if is_training:
        truncation_cutoff = None
    if not is_training or (dlatent_avg_beta is not None
                           and not tflib.is_tf_expression(dlatent_avg_beta)
                           and dlatent_avg_beta == 1):
        dlatent_avg_beta = None
    if not is_training or (style_mixing_prob is not None
                           and not tflib.is_tf_expression(style_mixing_prob)
                           and style_mixing_prob <= 0):
        style_mixing_prob = None

    # Setup components.
    if 'synthesis' not in components:
        components.synthesis = tflib.Network(
            'G_synthesis', func_name=globals()[synthesis_func], **kwargs)
    num_layers = components.synthesis.input_shape[1]
    if 'mapping_latent' not in components:
        components.mapping_latent = tflib.Network(
            'G_mapping_latent',
            func_name=globals()[mapping_latent_func],
            dlatent_broadcast=num_layers,
            **kwargs)
    if 'mapping_label' not in components:
        components.mapping_label = tflib.Network(
            'label_mapping',
            func_name=globals()[mapping_label_func],
            dlatent_broadcast=num_layers,
            **kwargs)

    # Setup variables.
    lod_in = tf.get_variable('lod', initializer=np.float32(0), trainable=False)

    # Evaluate mapping network.
    dlatents = components.mapping_latent.get_output_for(
        latents_in, is_training=is_training, **kwargs)
    dlatents = tf.cast(dlatents, tf.float32)

    dlatents_label = components.mapping_label.get_output_for(
        labels_in, is_training=is_training, **kwargs)
    dlatents_label = tf.cast(dlatents_label, tf.float32)

    dlatents = tf.concat([dlatents, dlatents_label], axis=-1)

    # Perform style mixing regularization.
    if style_mixing_prob is not None:
        with tf.variable_scope('StyleMix'):
            latents2 = tf.random_normal(tf.shape(latents_in))
            dlatents2 = components.mapping_latent.get_output_for(
                latents2, is_training=is_training, **kwargs)
            dlatents2 = tf.cast(dlatents2, tf.float32)
            dlatents_label2 = components.mapping_label.get_output_for(
                labels_in, is_training=is_training, **kwargs)
            dlatents_label2 = tf.cast(dlatents_label2, tf.float32)
            dlatents2 = tf.concat([dlatents2, dlatents_label2], axis=-1)
            layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
            cur_layers = num_layers - tf.cast(lod_in, tf.int32) * 2
            mixing_cutoff = tf.cond(
                tf.random_uniform([], 0.0, 1.0) < style_mixing_prob,
                lambda: tf.random_uniform([], 1, cur_layers, dtype=tf.int32),
                lambda: cur_layers)
            dlatents = tf.where(
                tf.broadcast_to(layer_idx < mixing_cutoff, tf.shape(dlatents)),
                dlatents, dlatents2)

    # Evaluate synthesis network.
    deps = []
    if 'lod' in components.synthesis.vars:
        deps.append(tf.assign(components.synthesis.vars['lod'], lod_in))
    with tf.control_dependencies(deps):
        images_out = components.synthesis.get_output_for(
            dlatents,
            is_training=is_training,
            force_clean_graph=is_template_graph,
            **kwargs)

    # Return requested outputs.
    images_out = tf.identity(images_out, name='images_out')
    if return_dlatents:
        return images_out, dlatents
    return images_out
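Below is a minimal usage sketch for this build function, assuming the TF1-based StyleGAN2 codebase (dnnlib, tflib) is importable and that G_mapping_latent, label_mapping, and G_synthesis_stylegan2 are defined in the same module; the resolution and sizes are illustrative, not taken from the example above.

import numpy as np
import dnnlib.tflib as tflib

tflib.init_tf()
G = tflib.Network('G', func_name=G_main,           # build func from the example above
                  num_channels=3, resolution=256,  # forwarded to the synthesis network
                  latent_size=512, label_size=0)   # forwarded to the mapping networks
z = np.random.randn(1, *G.input_shapes[0][1:])     # [minibatch, latent_size]
labels = np.zeros([1, *G.input_shapes[1][1:]])     # [minibatch, label_size]
images = G.run(z, labels)                          # NCHW float32 images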
Example #2
def training_schedule(
    cur_nimg,
    training_set,
    num_gpus,
    lod_initial_resolution=4,  # Image resolution used at the beginning.
    lod_training_kimg=600,  # Thousands of real images to show before doubling the resolution.
    lod_transition_kimg=600,  # Thousands of real images to show when fading in new layers.
    minibatch_base=16,  # Maximum minibatch size, divided evenly among GPUs.
    minibatch_dict={},  # Resolution-specific overrides.
    max_minibatch_per_gpu={},  # Resolution-specific maximum minibatch size per GPU.
    G_lrate_base=0.001,  # Learning rate for the generator.
    G_lrate_dict={},  # Resolution-specific overrides.
    D_lrate_base=0.001,  # Learning rate for the discriminator.
    D_lrate_dict={},  # Resolution-specific overrides.
    lrate_rampup_kimg=0,  # Duration of learning rate ramp-up.
    tick_kimg_base=160,  # Default interval of progress snapshots.
    tick_kimg_dict={
        4: 160,
        8: 140,
        16: 120,
        32: 100,
        64: 80,
        128: 60,
        256: 40,
        512: 30,
        1024: 20
    }):  # Resolution-specific overrides.

    # Initialize result dict.
    s = dnnlib.EasyDict()
    s.kimg = cur_nimg / 1000.0

    # Training phase.
    phase_dur = lod_training_kimg + lod_transition_kimg
    phase_idx = int(np.floor(s.kimg / phase_dur)) if phase_dur > 0 else 0
    phase_kimg = s.kimg - phase_idx * phase_dur

    # Level-of-detail and resolution.
    s.lod = training_set.resolution_log2
    s.lod -= np.floor(np.log2(lod_initial_resolution))
    s.lod -= phase_idx
    if lod_transition_kimg > 0:
        s.lod -= max(phase_kimg - lod_training_kimg, 0.0) / lod_transition_kimg
    s.lod = max(s.lod, 0.0)
    s.resolution = 2**(training_set.resolution_log2 - int(np.floor(s.lod)))

    # Minibatch size.
    s.minibatch = minibatch_dict.get(s.resolution, minibatch_base)
    s.minibatch -= s.minibatch % num_gpus
    if s.resolution in max_minibatch_per_gpu:
        s.minibatch = min(s.minibatch,
                          max_minibatch_per_gpu[s.resolution] * num_gpus)

    # Learning rate.
    s.G_lrate = G_lrate_dict.get(s.resolution, G_lrate_base)
    s.D_lrate = D_lrate_dict.get(s.resolution, D_lrate_base)
    if lrate_rampup_kimg > 0:
        rampup = min(s.kimg / lrate_rampup_kimg, 1.0)
        s.G_lrate *= rampup
        s.D_lrate *= rampup

    # Other parameters.
    s.tick_kimg = tick_kimg_dict.get(s.resolution, tick_kimg_base)
    return s
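training_schedule is plain Python over numpy scalars, so its progressive-growing arithmetic can be checked without TensorFlow. A small sketch with a stub dataset object (resolution_log2 is the only attribute the function reads):

import numpy as np

class _StubSet:
    resolution_log2 = 10  # stands in for a 1024x1024 dataset

for kimg in [0, 600, 900, 1200]:
    s = training_schedule(cur_nimg=kimg * 1000, training_set=_StubSet(), num_gpus=1)
    print(f'kimg={kimg:5d}  lod={s.lod:.2f}  res={s.resolution}  minibatch={s.minibatch}')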
Example #3
def convert_tf_discriminator(tf_D):
    if tf_D.version < 4:
        raise ValueError('TensorFlow pickle version too low')

    # Collect kwargs.
    tf_kwargs = tf_D.static_kwargs
    known_kwargs = set()

    def kwarg(tf_name, default=None):
        known_kwargs.add(tf_name)
        return tf_kwargs.get(tf_name, default)

    # Convert kwargs.
    kwargs = dnnlib.EasyDict(
        c_dim=kwarg('label_size', 0),
        img_resolution=kwarg('resolution', 1024),
        img_channels=kwarg('num_channels', 3),
        architecture=kwarg('architecture', 'resnet'),
        channel_base=kwarg('fmap_base', 16384) * 2,
        channel_max=kwarg('fmap_max', 512),
        num_fp16_res=kwarg('num_fp16_res', 0),
        conv_clamp=kwarg('conv_clamp', None),
        cmap_dim=kwarg('mapping_fmaps', None),
        block_kwargs=dnnlib.EasyDict(
            activation=kwarg('nonlinearity', 'lrelu'),
            resample_filter=kwarg('resample_kernel', [1, 3, 3, 1]),
            freeze_layers=kwarg('freeze_layers', 0),
        ),
        mapping_kwargs=dnnlib.EasyDict(
            num_layers=kwarg('mapping_layers', 0),
            embed_features=kwarg('mapping_fmaps', None),
            layer_features=kwarg('mapping_fmaps', None),
            activation=kwarg('nonlinearity', 'lrelu'),
            lr_multiplier=kwarg('mapping_lrmul', 0.1),
        ),
        epilogue_kwargs=dnnlib.EasyDict(
            mbstd_group_size=kwarg('mbstd_group_size', None),
            mbstd_num_channels=kwarg('mbstd_num_features', 1),
            activation=kwarg('nonlinearity', 'lrelu'),
        ),
    )

    # Check for unknown kwargs.
    kwarg('structure')
    unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
    if len(unknown_kwargs) > 0:
        raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])

    # Collect params.
    tf_params = _collect_tf_params(tf_D)
    for name, value in list(tf_params.items()):
        match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name)
        if match:
            r = kwargs.img_resolution // (2**int(match.group(1)))
            tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value
            kwargs.architecture = 'orig'
    #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')

    # Convert params.
    from training import networks
    D = networks.Discriminator(**kwargs).eval().requires_grad_(False)
    # pylint: disable=unnecessary-lambda
    _populate_module_params(
        D,
        r'b(\d+)\.fromrgb\.weight',
        lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1),
        r'b(\d+)\.fromrgb\.bias',
        lambda r: tf_params[f'{r}x{r}/FromRGB/bias'],
        r'b(\d+)\.conv(\d+)\.weight',
        lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight']
        .transpose(3, 2, 0, 1),
        r'b(\d+)\.conv(\d+)\.bias',
        lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
        r'b(\d+)\.skip\.weight',
        lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1),
        r'mapping\.embed\.weight',
        lambda: tf_params[f'LabelEmbed/weight'].transpose(),
        r'mapping\.embed\.bias',
        lambda: tf_params[f'LabelEmbed/bias'],
        r'mapping\.fc(\d+)\.weight',
        lambda i: tf_params[f'Mapping{i}/weight'].transpose(),
        r'mapping\.fc(\d+)\.bias',
        lambda i: tf_params[f'Mapping{i}/bias'],
        r'b4\.conv\.weight',
        lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1),
        r'b4\.conv\.bias',
        lambda: tf_params[f'4x4/Conv/bias'],
        r'b4\.fc\.weight',
        lambda: tf_params[f'4x4/Dense0/weight'].transpose(),
        r'b4\.fc\.bias',
        lambda: tf_params[f'4x4/Dense0/bias'],
        r'b4\.out\.weight',
        lambda: tf_params[f'Output/weight'].transpose(),
        r'b4\.out\.bias',
        lambda: tf_params[f'Output/bias'],
        r'.*\.resample_filter',
        None,
    )
    return D
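In stylegan2-ada-pytorch this converter is normally reached through legacy.load_network_pkl, which unpickles the TF networks and dispatches on their type; a hedged usage sketch (module name and path are assumptions):

import torch
import dnnlib
import legacy  # module assumed to contain convert_tf_discriminator

with dnnlib.util.open_url('network-snapshot.pkl') as f:  # hypothetical TF pickle
    D = legacy.load_network_pkl(f)['D']  # calls convert_tf_discriminator for TF pickles
torch.save(D.state_dict(), 'D_converted.pt')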
Example #4
def training_schedule(
    cur_nimg,
    training_set,
    lod_initial_resolution=None,  # Image resolution used at the beginning.
    lod_training_kimg=600,  # Thousands of real images to show before doubling the resolution.
    lod_transition_kimg=600,  # Thousands of real images to show when fading in new layers.
    minibatch_size_base=32,  # Global minibatch size.
    minibatch_size_dict={},  # Resolution-specific overrides.
    minibatch_gpu_base=4,  # Number of samples processed at a time by one GPU.
    minibatch_gpu_dict={},  # Resolution-specific overrides.
    G_lrate_base=0.002,  # Learning rate for the generator.
    G_lrate_dict={},  # Resolution-specific overrides.
    D_lrate_base=0.002,  # Learning rate for the discriminator.
    D_lrate_dict={},  # Resolution-specific overrides.
    lrate_rampup_kimg=0,  # Duration of learning rate ramp-up.
    tick_kimg_base=4,  # Default interval of progress snapshots.
    tick_kimg_dict={
        8: 28,
        16: 24,
        32: 20,
        64: 16,
        128: 12,
        256: 8,
        512: 6,
        1024: 4
    }):  # Resolution-specific overrides.

    # Initialize result dict.
    s = dnnlib.EasyDict()
    s.kimg = cur_nimg / 1000.0

    # Training phase.
    phase_dur = lod_training_kimg + lod_transition_kimg
    phase_idx = int(np.floor(s.kimg / phase_dur)) if phase_dur > 0 else 0
    phase_kimg = s.kimg - phase_idx * phase_dur

    # Level-of-detail and resolution.
    if lod_initial_resolution is None:
        s.lod = 0.0
    else:
        s.lod = training_set.resolution_log2
        s.lod -= np.floor(np.log2(lod_initial_resolution))
        s.lod -= phase_idx
        if lod_transition_kimg > 0:
            s.lod -= max(phase_kimg - lod_training_kimg,
                         0.0) / lod_transition_kimg
        s.lod = max(s.lod, 0.0)
    s.resolution = 2**(training_set.resolution_log2 - int(np.floor(s.lod)))

    # Minibatch size.
    s.minibatch_size = minibatch_size_dict.get(s.resolution,
                                               minibatch_size_base)
    s.minibatch_gpu = minibatch_gpu_dict.get(s.resolution, minibatch_gpu_base)

    # Learning rate.
    s.G_lrate = G_lrate_dict.get(s.resolution, G_lrate_base)
    s.D_lrate = D_lrate_dict.get(s.resolution, D_lrate_base)
    if lrate_rampup_kimg > 0:
        rampup = min(s.kimg / lrate_rampup_kimg, 1.0)
        s.G_lrate *= rampup
        s.D_lrate *= rampup

    # Other parameters.
    s.tick_kimg = tick_kimg_dict.get(s.resolution, tick_kimg_base)
    return s
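The notable difference from the variant in Example #2 is that lod_initial_resolution=None turns progressive growing off entirely: lod stays 0 and the schedule runs at native resolution from the start. A quick check with a stub dataset:

class _StubSet:
    resolution_log2 = 8  # stands in for a 256x256 dataset

s = training_schedule(cur_nimg=0, training_set=_StubSet())
print(s.lod, s.resolution, s.minibatch_size, s.minibatch_gpu)  # 0.0 256 32 4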
Example #5
def setup_training_loop_kwargs(
    # General options (not included in desc).
    gpus=None,  # Number of GPUs: <int>, default = 1 gpu
    snap=None,  # Snapshot interval: <int>, default = 50 ticks
    metrics=None,  # List of metric names: [], ['fid50k_full'] (default), ...
    seed=None,  # Random seed: <int>, default = 0

    # Dataset.
    data=None,  # Training dataset (required): <path>
    cond=None,  # Train conditional model based on dataset labels: <bool>, default = False
    subset=None,  # Train with only N images: <int>, default = all
    mirror=None,  # Augment dataset with x-flips: <bool>, default = False

    # Base config.
    cfg=None,  # Base config: 'auto' (default), 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar'
    gamma=None,  # Override R1 gamma: <float>
    kimg=None,  # Override training duration: <int>
    batch=None,  # Override batch size: <int>

    # Discriminator augmentation.
    aug=None,  # Augmentation mode: 'ada' (default), 'noaug', 'fixed'
    p=None,  # Specify p for 'fixed' (required): <float>
    target=None,  # Override ADA target for 'ada': <float>, default = depends on aug
    augpipe=None,  # Augmentation pipeline: 'blit', 'geom', 'color', 'filter', 'noise', 'cutout', 'bg', 'bgc' (default), ..., 'bgcfnc'

    # Transfer learning.
    resume=None,  # Load previous network: 'noresume' (default), 'ffhq256', 'ffhq512', 'ffhq1024', 'celebahq256', 'lsundog256', <file>, <url>
    freezed=None,  # Freeze-D: <int>, default = 0 discriminator layers

    # Performance options (not included in desc).
    fp32=None,  # Disable mixed-precision training: <bool>, default = False
    nhwc=None,  # Use NHWC memory format with FP16: <bool>, default = False
    allow_tf32=None,  # Allow PyTorch to use TF32 for matmul and convolutions: <bool>, default = False
    nobench=None,  # Disable cuDNN benchmarking: <bool>, default = False
    workers=None,  # Override number of DataLoader workers: <int>, default = 3
):
    args = dnnlib.EasyDict()

    # ------------------------------------------
    # General options: gpus, snap, metrics, seed
    # ------------------------------------------

    if gpus is None:
        gpus = 1
    assert isinstance(gpus, int)
    if not (gpus >= 1 and gpus & (gpus - 1) == 0):
        raise UserError('--gpus must be a power of two')
    args.num_gpus = gpus

    if snap is None:
        snap = 50
    assert isinstance(snap, int)
    if snap < 1:
        raise UserError('--snap must be at least 1')
    args.image_snapshot_ticks = snap
    args.network_snapshot_ticks = snap

    if metrics is None:
        metrics = ['fid50k_full']
    assert isinstance(metrics, list)
    if not all(metric_main.is_valid_metric(metric) for metric in metrics):
        raise UserError(
            '\n'.join(['--metrics can only contain the following values:'] +
                      metric_main.list_valid_metrics()))
    args.metrics = metrics

    if seed is None:
        seed = 0
    assert isinstance(seed, int)
    args.random_seed = seed

    # -----------------------------------
    # Dataset: data, cond, subset, mirror
    # -----------------------------------

    assert data is not None
    assert isinstance(data, str)
    args.training_set_kwargs = dnnlib.EasyDict(
        class_name='training.dataset.ImageFolderDataset',
        path=data,
        use_labels=True,
        max_size=None,
        xflip=False)
    args.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True,
                                              num_workers=3,
                                              prefetch_factor=2)
    try:
        training_set = dnnlib.util.construct_class_by_name(
            **args.training_set_kwargs)  # subclass of training.dataset.Dataset
        args.training_set_kwargs.resolution = training_set.resolution  # be explicit about resolution
        args.training_set_kwargs.use_labels = training_set.has_labels  # be explicit about labels
        args.training_set_kwargs.max_size = len(
            training_set)  # be explicit about dataset size
        desc = training_set.name
        del training_set  # conserve memory
    except IOError as err:
        raise UserError(f'--data: {err}')

    if cond is None:
        cond = False
    assert isinstance(cond, bool)
    if cond:
        if not args.training_set_kwargs.use_labels:
            raise UserError(
                '--cond=True requires labels specified in dataset.json')
        desc += '-cond'
    else:
        args.training_set_kwargs.use_labels = False

    if subset is not None:
        assert isinstance(subset, int)
        if not 1 <= subset <= args.training_set_kwargs.max_size:
            raise UserError(
                f'--subset must be between 1 and {args.training_set_kwargs.max_size}'
            )
        desc += f'-subset{subset}'
        if subset < args.training_set_kwargs.max_size:
            args.training_set_kwargs.max_size = subset
            args.training_set_kwargs.random_seed = args.random_seed

    if mirror is None:
        mirror = False
    assert isinstance(mirror, bool)
    if mirror:
        desc += '-mirror'
        args.training_set_kwargs.xflip = True

    # ------------------------------------
    # Base config: cfg, gamma, kimg, batch
    # ------------------------------------

    if cfg is None:
        cfg = 'auto'
    assert isinstance(cfg, str)
    desc += f'-{cfg}'

    cfg_specs = {
        'auto':
        dict(
            ref_gpus=-1,
            kimg=25000,
            mb=-1,
            mbstd=-1,
            fmaps=-1,
            lrate=-1,
            gamma=-1,
            ema=-1,
            ramp=0.05,
            map=2),  # Populated dynamically based on resolution and GPU count.
        'stylegan2':
        dict(ref_gpus=8,
             kimg=25000,
             mb=32,
             mbstd=4,
             fmaps=1,
             lrate=0.002,
             gamma=10,
             ema=10,
             ramp=None,
             map=8),  # Uses mixed-precision, unlike the original StyleGAN2.
        'paper256':
        dict(ref_gpus=8,
             kimg=25000,
             mb=64,
             mbstd=8,
             fmaps=0.5,
             lrate=0.0025,
             gamma=1,
             ema=20,
             ramp=None,
             map=8),
        'paper512':
        dict(ref_gpus=8,
             kimg=25000,
             mb=64,
             mbstd=8,
             fmaps=1,
             lrate=0.0025,
             gamma=0.5,
             ema=20,
             ramp=None,
             map=8),
        'paper1024':
        dict(ref_gpus=8,
             kimg=25000,
             mb=32,
             mbstd=4,
             fmaps=1,
             lrate=0.002,
             gamma=2,
             ema=10,
             ramp=None,
             map=8),
        'cifar':
        dict(ref_gpus=2,
             kimg=100000,
             mb=64,
             mbstd=32,
             fmaps=1,
             lrate=0.0025,
             gamma=0.01,
             ema=500,
             ramp=0.05,
             map=2),
    }

    assert cfg in cfg_specs
    spec = dnnlib.EasyDict(cfg_specs[cfg])
    if cfg == 'auto':
        desc += f'{gpus:d}'
        spec.ref_gpus = gpus
        res = args.training_set_kwargs.resolution
        spec.mb = max(min(gpus * min(4096 // res, 32), 64),
                      gpus)  # keep gpu memory consumption at bay
        spec.mbstd = min(
            spec.mb // gpus, 4
        )  # other hyperparams behave more predictably if mbstd group size remains fixed
        spec.fmaps = 1 if res >= 512 else 0.5
        spec.lrate = 0.002 if res >= 1024 else 0.0025
        spec.gamma = 0.0002 * (res**2) / spec.mb  # heuristic formula
        spec.ema = spec.mb * 10 / 32

    args.G_kwargs = dnnlib.EasyDict(class_name='training.networks.Generator',
                                    z_dim=512,
                                    w_dim=512,
                                    mapping_kwargs=dnnlib.EasyDict(),
                                    synthesis_kwargs=dnnlib.EasyDict())
    args.D_kwargs = dnnlib.EasyDict(
        class_name='training.networks.Discriminator',
        block_kwargs=dnnlib.EasyDict(),
        mapping_kwargs=dnnlib.EasyDict(),
        epilogue_kwargs=dnnlib.EasyDict())
    args.G_kwargs.synthesis_kwargs.channel_base = args.D_kwargs.channel_base = int(
        spec.fmaps * 32768)
    args.G_kwargs.synthesis_kwargs.channel_max = args.D_kwargs.channel_max = 512
    args.G_kwargs.mapping_kwargs.num_layers = 3  # NOTE: hardcoded to 3 here (stock config uses spec.map)
    args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 4  # enable mixed-precision training
    args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = 256  # clamp activations to avoid float16 overflow
    args.D_kwargs.epilogue_kwargs.mbstd_group_size = spec.mbstd

    args.G_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam',
                                        lr=spec.lrate,
                                        betas=[0, 0.99],
                                        eps=1e-8)
    args.D_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam',
                                        lr=spec.lrate,
                                        betas=[0, 0.99],
                                        eps=1e-8)
    args.loss_kwargs = dnnlib.EasyDict(
        class_name='training.loss.StyleGAN2Loss', r1_gamma=spec.gamma)

    args.total_kimg = spec.kimg
    args.batch_size = spec.mb
    args.batch_gpu = spec.mb // spec.ref_gpus
    args.ema_kimg = spec.ema
    args.ema_rampup = spec.ramp

    if cfg == 'cifar':
        args.loss_kwargs.pl_weight = 0  # disable path length regularization
        args.loss_kwargs.style_mixing_prob = 0  # disable style mixing
        args.D_kwargs.architecture = 'orig'  # disable residual skip connections

    if gamma is not None:
        assert isinstance(gamma, float)
        if not gamma >= 0:
            raise UserError('--gamma must be non-negative')
        desc += f'-gamma{gamma:g}'
        args.loss_kwargs.r1_gamma = gamma

    if kimg is not None:
        assert isinstance(kimg, int)
        if not kimg >= 1:
            raise UserError('--kimg must be at least 1')
        desc += f'-kimg{kimg:d}'
        args.total_kimg = kimg

    if batch is not None:
        assert isinstance(batch, int)
        if not (batch >= 1 and batch % gpus == 0):
            raise UserError(
                '--batch must be at least 1 and divisible by --gpus')
        desc += f'-batch{batch}'
        args.batch_size = batch
        args.batch_gpu = batch // gpus

    # ---------------------------------------------------
    # Discriminator augmentation: aug, p, target, augpipe
    # ---------------------------------------------------

    if aug is None:
        aug = 'ada'
    else:
        assert isinstance(aug, str)
        desc += f'-{aug}'

    if aug == 'ada':
        args.ada_target = 0.6

    elif aug == 'noaug':
        pass

    elif aug == 'fixed':
        if p is None:
            raise UserError(f'--aug={aug} requires specifying --p')

    else:
        raise UserError(f'--aug={aug} not supported')

    if p is not None:
        assert isinstance(p, float)
        if aug != 'fixed':
            raise UserError('--p can only be specified with --aug=fixed')
        if not 0 <= p <= 1:
            raise UserError('--p must be between 0 and 1')
        desc += f'-p{p:g}'
        args.augment_p = p

    if target is not None:
        assert isinstance(target, float)
        if aug != 'ada':
            raise UserError('--target can only be specified with --aug=ada')
        if not 0 <= target <= 1:
            raise UserError('--target must be between 0 and 1')
        desc += f'-target{target:g}'
        args.ada_target = target

    assert augpipe is None or isinstance(augpipe, str)
    if augpipe is None:
        augpipe = 'bgc'
    else:
        if aug == 'noaug':
            raise UserError('--augpipe cannot be specified with --aug=noaug')
        desc += f'-{augpipe}'

    augpipe_specs = {
        'blit':
        dict(xflip=1, rotate90=1, xint=1),
        'geom':
        dict(scale=1, rotate=1, aniso=1, xfrac=1),
        'color':
        dict(brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
        'filter':
        dict(imgfilter=1),
        'noise':
        dict(noise=1),
        'cutout':
        dict(cutout=1),
        'bg':
        dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1),
        'bgc':
        dict(xflip=1,
             rotate90=1,
             xint=1,
             scale=1,
             rotate=1,
             aniso=1,
             xfrac=1,
             brightness=1,
             contrast=1,
             lumaflip=1,
             hue=1,
             saturation=1),
        'bgcf':
        dict(xflip=1,
             rotate90=1,
             xint=1,
             scale=1,
             rotate=1,
             aniso=1,
             xfrac=1,
             brightness=1,
             contrast=1,
             lumaflip=1,
             hue=1,
             saturation=1,
             imgfilter=1),
        'bgcfn':
        dict(xflip=1,
             rotate90=1,
             xint=1,
             scale=1,
             rotate=1,
             aniso=1,
             xfrac=1,
             brightness=1,
             contrast=1,
             lumaflip=1,
             hue=1,
             saturation=1,
             imgfilter=1,
             noise=1),
        'bgcfnc':
        dict(xflip=1,
             rotate90=1,
             xint=1,
             scale=1,
             rotate=1,
             aniso=1,
             xfrac=1,
             brightness=1,
             contrast=1,
             lumaflip=1,
             hue=1,
             saturation=1,
             imgfilter=1,
             noise=1,
             cutout=1),
    }

    assert augpipe in augpipe_specs
    if aug != 'noaug':
        args.augment_kwargs = dnnlib.EasyDict(
            class_name='training.augment.AugmentPipe',
            **augpipe_specs[augpipe])

    # ----------------------------------
    # Transfer learning: resume, freezed
    # ----------------------------------

    resume_specs = {
        'ffhq256':
        'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res256-mirror-paper256-noaug.pkl',
        'ffhq512':
        'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res512-mirror-stylegan2-noaug.pkl',
        'ffhq1024':
        'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res1024-mirror-stylegan2-noaug.pkl',
        'celebahq256':
        'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/celebahq-res256-mirror-paper256-kimg100000-ada-target0.5.pkl',
        'lsundog256':
        'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/lsundog-res256-paper256-kimg100000-noaug.pkl',
    }

    assert resume is None or isinstance(resume, str)
    if resume is None:
        resume = 'noresume'
    elif resume == 'noresume':
        desc += '-noresume'
    elif resume in resume_specs:
        desc += f'-resume{resume}'
        args.resume_pkl = resume_specs[resume]  # predefined url
    else:
        desc += '-resumecustom'
        args.resume_pkl = resume  # custom path or url

    if resume != 'noresume':
        args.ada_kimg = 100  # make ADA react faster at the beginning
        args.ema_rampup = None  # disable EMA rampup

    if freezed is not None:
        assert isinstance(freezed, int)
        if not freezed >= 0:
            raise UserError('--freezed must be non-negative')
        desc += f'-freezed{freezed:d}'
        args.D_kwargs.block_kwargs.freeze_layers = freezed

    # -------------------------------------------------
    # Performance options: fp32, nhwc, nobench, workers
    # -------------------------------------------------

    if fp32 is None:
        fp32 = False
    assert isinstance(fp32, bool)
    if fp32:
        args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 0
        args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = None

    if nhwc is None:
        nhwc = False
    assert isinstance(nhwc, bool)
    if nhwc:
        args.G_kwargs.synthesis_kwargs.fp16_channels_last = args.D_kwargs.block_kwargs.fp16_channels_last = True

    if nobench is None:
        nobench = False
    assert isinstance(nobench, bool)
    if nobench:
        args.cudnn_benchmark = False

    if allow_tf32 is None:
        allow_tf32 = False
    assert isinstance(allow_tf32, bool)
    if allow_tf32:
        args.allow_tf32 = True

    if workers is not None:
        assert isinstance(workers, int)
        if not workers >= 1:
            raise UserError('--workers must be at least 1')
        args.data_loader_kwargs.num_workers = workers

    return desc, args
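A sketch of how this helper is typically invoked from a training script; the dataset path is illustrative, and note that the call actually instantiates the dataset to probe its resolution and labels:

desc, args = setup_training_loop_kwargs(
    gpus=1,
    data='~/datasets/ffhq256.zip',  # hypothetical dataset archive
    cfg='auto',
    mirror=True,
)
print(desc)  # e.g. 'ffhq256-mirror-auto1'
print(args.batch_size, args.loss_kwargs.r1_gamma)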
Example #6
def _report_result(self, value, suffix='', fmt='%-10.4f'):
    self._results += [dnnlib.EasyDict(value=value, suffix=suffix, fmt=fmt)]
Example #7
import dnnlib
import validation

# Submit config
# ------------------------------------------------------------------------------------------

submit_config = dnnlib.SubmitConfig()
submit_config.run_dir_root = 'results'
submit_config.run_dir_ignore += ['datasets', 'results']

desc = "autoencoder"

# Tensorflow config
# ------------------------------------------------------------------------------------------

tf_config = dnnlib.EasyDict()
tf_config["graph_options.place_pruned_graph"] = True

# Network config
# ------------------------------------------------------------------------------------------

net_config = dnnlib.EasyDict(func_name="network.autoencoder")

# Optimizer config
# ------------------------------------------------------------------------------------------

optimizer_config = dnnlib.EasyDict(beta1=0.9, beta2=0.99, epsilon=1e-8)

# Noise augmentation config
gaussian_noise_config = dnnlib.EasyDict(func_name='train.AugmentGaussian',
                                        train_stddev_rng_range=(0.0, 50.0),
                                        validation_stddev=25)
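The dnnlib.EasyDict used throughout these configs is simply a dict subclass whose items are also reachable as attributes, which is why the snippets mix both access styles freely:

cfg = dnnlib.EasyDict(func_name='train.AugmentGaussian')
cfg.minibatch_size = 8               # attribute write...
assert cfg['minibatch_size'] == 8    # ...is an ordinary dict entry
cfg.update(num_gpus=1)               # dict methods still work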
Example #8
def generate_images(network_pkl,
                    seeds,
                    truncation_psi,
                    data_dir=None,
                    dataset_name=None,
                    model=None):
    G_args = EasyDict(func_name='training.' + model + '.G_main')
    dataset_args = EasyDict(tfrecord_dir=dataset_name)
    G_args.fmap_base = 8 << 10
    tflib.init_tf()
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir),
                                        verbose=True,
                                        **dataset_args)
    print('Constructing networks...')
    Gs = tflib.Network('G',
                       num_channels=training_set.shape[0],
                       resolution=training_set.shape[1],
                       label_size=training_set.label_size,
                       **G_args)
    print('Loading networks from "%s"...' % network_pkl)
    _, _, _Gs = pretrained_networks.load_networks(network_pkl)
    Gs.copy_vars_from(_Gs)
    noise_vars = [
        var for name, var in Gs.components.synthesis.vars.items()
        if name.startswith('noise')
    ]

    Gs_kwargs = dnnlib.EasyDict()
    # Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    Gs_kwargs.randomize_noise = False
    if truncation_psi is not None:
        Gs_kwargs.truncation_psi = truncation_psi

    for seed_idx, seed in enumerate(seeds):
        print('Generating image for seed %d (%d/%d) ...' %
              (seed, seed_idx, len(seeds)))
        rnd = np.random.RandomState(seed)
        z = rnd.randn(1, *Gs.input_shape[1:])  # [minibatch, component]
        tflib.set_vars(
            {var: rnd.randn(*var.shape.as_list())
             for var in noise_vars})  # [height, width]
        images, x_v, n_v, m_v = Gs.run(
            z, None, **Gs_kwargs)  # [minibatch, height, width, channel]

        print(images.shape, n_v.shape, x_v.shape, m_v.shape)
        misc.convert_to_pil_image(images[0], drange=[-1, 1]).save(
            dnnlib.make_run_dir_path('seed%04d.png' % seed))
        misc.save_image_grid(adjust_range(n_v),
                             dnnlib.make_run_dir_path('seed%04d-nv.png' %
                                                      seed),
                             drange=[-1, 1])
        print(np.linalg.norm(x_v - m_v))
        misc.save_image_grid(adjust_range(x_v).transpose([1, 0, 2, 3]),
                             dnnlib.make_run_dir_path('seed%04d-xv.png' %
                                                      seed),
                             drange=[-1, 1])
        misc.save_image_grid(adjust_range(m_v).transpose([1, 0, 2, 3]),
                             dnnlib.make_run_dir_path('seed%04d-mv.png' %
                                                      seed),
                             drange=[-1, 1])
        misc.save_image_grid(adjust_range(clip(x_v, 'cat')),
                             dnnlib.make_run_dir_path('seed%04d-xvs.png' %
                                                      seed),
                             drange=[-1, 1])
        misc.save_image_grid(adjust_range(clip(m_v, 'ss')),
                             dnnlib.make_run_dir_path('seed%04d-mvs.png' %
                                                      seed),
                             drange=[-1, 1])
        misc.save_image_grid(adjust_range(clip(m_v, 'ffhq')),
                             dnnlib.make_run_dir_path('seed%04d-fmvs.png' %
                                                      seed),
                             drange=[-1, 1])
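An invocation sketch, assuming the surrounding run-script context in which dataset, pretrained_networks, and the custom model module are importable; every path and name below is illustrative:

generate_images(network_pkl='results/network-final.pkl',  # hypothetical pickle
                seeds=[0, 1, 2],
                truncation_psi=0.5,
                data_dir='datasets',
                dataset_name='ffhq',
                model='stylegan2_custom')  # resolves to training.stylegan2_custom.G_main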
Example #9
def style_mixing_example(network_pkl,
                         row_seeds,
                         col_seeds,
                         truncation_psi,
                         col_styles,
                         minibatch_size=4):
    print('Loading networks from "%s"...' % network_pkl)
    _G, _D, Gs = pretrained_networks.load_networks(network_pkl)
    w_avg = Gs.get_var("dlatent_avg")  # [component]

    Gs_syn_kwargs = dnnlib.EasyDict()
    Gs_syn_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8,
                                          nchw_to_nhwc=True)
    Gs_syn_kwargs.randomize_noise = False
    Gs_syn_kwargs.minibatch_size = minibatch_size

    print("Generating W vectors...")
    all_seeds = list(set(row_seeds + col_seeds))
    all_z = np.stack([
        np.random.RandomState(seed).randn(*Gs.input_shape[1:])
        for seed in all_seeds
    ])  # [minibatch, component]
    all_w = Gs.components.mapping.run(all_z,
                                      None)  # [minibatch, layer, component]
    all_w = w_avg + (all_w -
                     w_avg) * truncation_psi  # [minibatch, layer, component]
    w_dict = {seed: w
              for seed, w in zip(all_seeds, list(all_w))}  # [layer, component]

    print("Generating images...")
    all_images = Gs.components.synthesis.run(
        all_w, **Gs_syn_kwargs)  # [minibatch, height, width, channel]
    image_dict = {(seed, seed): image
                  for seed, image in zip(all_seeds, list(all_images))}

    print("Generating style-mixed images...")
    for row_seed in row_seeds:
        for col_seed in col_seeds:
            w = w_dict[row_seed].copy()
            w[col_styles] = w_dict[col_seed][col_styles]
            image = Gs.components.synthesis.run(w[np.newaxis],
                                                **Gs_syn_kwargs)[0]
            image_dict[(row_seed, col_seed)] = image

    print("Saving images...")
    for (row_seed, col_seed), image in image_dict.items():
        PIL.Image.fromarray(image, "RGB").save(
            dnnlib.make_run_dir_path("%d-%d.png" % (row_seed, col_seed)))

    print("Saving image grid...")
    _N, _C, H, W = Gs.output_shape
    canvas = PIL.Image.new("RGB", (W * (len(col_seeds) + 1), H *
                                   (len(row_seeds) + 1)), "black")
    for row_idx, row_seed in enumerate([None] + row_seeds):
        for col_idx, col_seed in enumerate([None] + col_seeds):
            if row_seed is None and col_seed is None:
                continue
            key = (row_seed, col_seed)
            if row_seed is None:
                key = (col_seed, col_seed)
            if col_seed is None:
                key = (row_seed, row_seed)
            canvas.paste(PIL.Image.fromarray(image_dict[key], "RGB"),
                         (W * col_idx, H * row_idx))
    canvas.save(dnnlib.make_run_dir_path("grid.png"))
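An invocation sketch in the spirit of the original run_generator examples; the pickle path is illustrative:

style_mixing_example(network_pkl='gdrive:networks/stylegan2-ffhq-config-f.pkl',
                     row_seeds=[85, 100, 75],
                     col_seeds=[55, 821, 1789],
                     truncation_psi=1.0,
                     col_styles=list(range(7)))  # take coarse/mid styles from the columns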
Example #10
def main():
    os.makedirs(a.out_dir, exist_ok=True)
    np.random.seed(seed=696)

    # parse filename to model parameters
    mparams = basename(a.model).split('-')
    res = int(mparams[1])
    cfg = mparams[2]

    # setup generator
    fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.func_name = 'training.stylegan2_custom.G_main'
    Gs_kwargs.verbose = False
    Gs_kwargs.resolution = res
    Gs_kwargs.size = a.size
    Gs_kwargs.scale_type = a.scale_type
    Gs_kwargs.latent_size = a.latent_size
    Gs_kwargs.impl = a.ops

    if cfg.lower() == 'f':
        Gs_kwargs.synthesis_func = 'G_synthesis_stylegan2'
    elif cfg.lower() == 'e':
        Gs_kwargs.synthesis_func = 'G_synthesis_stylegan2'
        Gs_kwargs.fmap_base = 8 << 10
    else:
        print(' old modes [A-D] not implemented')
        exit()

    # check initial model resolution
    if len(mparams) > 3:
        if 'x' in mparams[3].lower():
            init_res = [int(x) for x in mparams[3].lower().split('x')]
            Gs_kwargs.init_res = list(reversed(init_res))  # [H,W]

    # load model, check channels
    sess = tflib.init_tf({'allow_soft_placement': True})
    pkl_name = osp.splitext(a.model)[0]
    with open(pkl_name + '.pkl', 'rb') as file:
        network = pickle.load(file, encoding='latin1')
    try:
        _, _, network = network
    except (TypeError, ValueError):
        pass  # pickle was a single network, not a (G, D, Gs) tuple
    Gs_kwargs.num_channels = network.output_shape[1]

    # reload custom network, if needed
    if '.pkl' in a.model.lower():
        print(' .. Gs from pkl ..')
        Gs = network
    else:
        print(' .. Gs custom ..')
        Gs = tflib.Network('Gs', **Gs_kwargs)
        Gs.copy_vars_from(network)
    # Gs.print_layers()
    print(' out shape', Gs.output_shape[1:])
    if a.size is None: a.size = Gs.output_shape[2:]

    z_dim = Gs.input_shape[1]
    shape = (1, z_dim)

    print(' making timeline..')
    latents = latent_anima(shape,
                           a.frames,
                           a.fstep,
                           cubic=a.cubic,
                           gauss=a.gauss,
                           verbose=True)  # [frm,1,512]
    print(' latents', latents.shape)

    # generate images from latent timeline
    frame_count = latents.shape[0]
    pbar = ProgressBar(frame_count)
    for i in range(frame_count):
        output = Gs.run(latents[i], [None],
                        truncation_psi=a.trunc,
                        randomize_noise=False,
                        output_transform=fmt)
        ext = 'png' if output.shape[3] == 4 else 'jpg'
        filename = osp.join(a.out_dir, "%05d.%s" % (i, ext))
        imsave(filename, output[0])
        pbar.upd()

    # convert latents to dlatents, save them
    latents = latents.squeeze(1)  # [frm,512]
    dlatents = Gs.components.mapping.run(latents,
                                         None,
                                         latent_size=z_dim,
                                         dtype='float16')  # [frm,18,512]
    filename = '{}-{}-{}.npy'.format(basename(a.model), a.size[1], a.size[0])
    filename = osp.join(osp.dirname(a.out_dir), filename)
    np.save(filename, dlatents)
    print('saved dlatents', dlatents.shape, 'to', filename)
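The saved file can be loaded back with numpy and used as key dlatents for interpolation (compare Example #16); the filename below is a hypothetical stand-in for whatever the run above produced:

dlatents = np.load('model-512-512.npy')  # [frm,18,512], as saved above
print(dlatents.shape)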
Example #11
import time
import hashlib
import numpy as np
import tensorflow as tf
import dnnlib
import dnnlib.tflib as tflib

import config
from training import misc
from training import dataset

#----------------------------------------------------------------------------
# Standard metrics.

fid50k = dnnlib.EasyDict(func_name='metrics.frechet_inception_distance.FID',
                         name='fid50k',
                         num_images=50000,
                         minibatch_per_gpu=8)
ppl_zfull = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL',
                            name='ppl_zfull',
                            num_samples=100000,
                            epsilon=1e-4,
                            space='z',
                            sampling='full',
                            minibatch_per_gpu=16)
ppl_wfull = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL',
                            name='ppl_wfull',
                            num_samples=100000,
                            epsilon=1e-4,
                            space='w',
                            sampling='full',
                            minibatch_per_gpu=16)
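Descriptors like these are dispatched by name elsewhere in the codebase: dnnlib.util.call_func_by_name resolves the func_name string and forwards the remaining entries as keyword arguments. A hedged sketch of that pattern, assuming the metric class accepts these entries as constructor kwargs:

metric_args = dict(fid50k)                 # EasyDict is a plain dict underneath
func_name = metric_args.pop('func_name')
fid = dnnlib.util.call_func_by_name(func_name=func_name, **metric_args)  # constructs the FID metric object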
Example #12
def convert_tf_generator(tf_G):
    if tf_G.version < 4:
        raise ValueError("TensorFlow pickle version too low")

    # Collect kwargs.
    tf_kwargs = tf_G.static_kwargs
    known_kwargs = set()

    def kwarg(tf_name, default=None, none=None):
        known_kwargs.add(tf_name)
        val = tf_kwargs.get(tf_name, default)
        return val if val is not None else none

    # Convert kwargs.
    kwargs = dnnlib.EasyDict(
        z_dim=kwarg("latent_size", 512),
        c_dim=kwarg("label_size", 0),
        w_dim=kwarg("dlatent_size", 512),
        img_resolution=kwarg("resolution", 1024),
        img_channels=kwarg("num_channels", 3),
        mapping_kwargs=dnnlib.EasyDict(
            num_layers=kwarg("mapping_layers", 8),
            embed_features=kwarg("label_fmaps", None),
            layer_features=kwarg("mapping_fmaps", None),
            activation=kwarg("mapping_nonlinearity", "lrelu"),
            lr_multiplier=kwarg("mapping_lrmul", 0.01),
            w_avg_beta=kwarg("w_avg_beta", 0.995, none=1),
        ),
        synthesis_kwargs=dnnlib.EasyDict(
            channel_base=kwarg("fmap_base", 16384) * 2,
            channel_max=kwarg("fmap_max", 512),
            num_fp16_res=kwarg("num_fp16_res", 0),
            conv_clamp=kwarg("conv_clamp", None),
            architecture=kwarg("architecture", "skip"),
            resample_filter=kwarg("resample_kernel", [1, 3, 3, 1]),
            use_noise=kwarg("use_noise", True),
            activation=kwarg("nonlinearity", "lrelu"),
        ),
    )

    # Check for unknown kwargs.
    kwarg("truncation_psi")
    kwarg("truncation_cutoff")
    kwarg("style_mixing_prob")
    kwarg("structure")
    unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
    if len(unknown_kwargs) > 0:
        raise ValueError("Unknown TensorFlow kwarg", unknown_kwargs[0])

    # Collect params.
    tf_params = _collect_tf_params(tf_G)
    for name, value in list(tf_params.items()):
        match = re.fullmatch(r"ToRGB_lod(\d+)/(.*)", name)
        if match:
            r = kwargs.img_resolution // (2**int(match.group(1)))
            tf_params[f"{r}x{r}/ToRGB/{match.group(2)}"] = value
            kwargs.synthesis_kwargs.architecture = "orig"
    # for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')

    # Convert params.
    from training import networks

    G = networks.Generator(**kwargs).eval().requires_grad_(False)
    # pylint: disable=unnecessary-lambda
    _populate_module_params(
        G,
        r"mapping\.w_avg",
        lambda: tf_params[f"dlatent_avg"],
        r"mapping\.embed\.weight",
        lambda: tf_params[f"mapping/LabelEmbed/weight"].transpose(),
        r"mapping\.embed\.bias",
        lambda: tf_params[f"mapping/LabelEmbed/bias"],
        r"mapping\.fc(\d+)\.weight",
        lambda i: tf_params[f"mapping/Dense{i}/weight"].transpose(),
        r"mapping\.fc(\d+)\.bias",
        lambda i: tf_params[f"mapping/Dense{i}/bias"],
        r"synthesis\.b4\.const",
        lambda: tf_params[f"synthesis/4x4/Const/const"][0],
        r"synthesis\.b4\.conv1\.weight",
        lambda: tf_params[f"synthesis/4x4/Conv/weight"].transpose(3, 2, 0, 1),
        r"synthesis\.b4\.conv1\.bias",
        lambda: tf_params[f"synthesis/4x4/Conv/bias"],
        r"synthesis\.b4\.conv1\.noise_const",
        lambda: tf_params[f"synthesis/noise0"][0, 0],
        r"synthesis\.b4\.conv1\.noise_strength",
        lambda: tf_params[f"synthesis/4x4/Conv/noise_strength"],
        r"synthesis\.b4\.conv1\.affine\.weight",
        lambda: tf_params[f"synthesis/4x4/Conv/mod_weight"].transpose(),
        r"synthesis\.b4\.conv1\.affine\.bias",
        lambda: tf_params[f"synthesis/4x4/Conv/mod_bias"] + 1,
        r"synthesis\.b(\d+)\.conv0\.weight",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv0_up/weight"
                            ][::-1, ::-1].transpose(3, 2, 0, 1),
        r"synthesis\.b(\d+)\.conv0\.bias",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv0_up/bias"],
        r"synthesis\.b(\d+)\.conv0\.noise_const",
        lambda r: tf_params[f"synthesis/noise{int(np.log2(int(r)))*2-5}"][0, 0
                                                                          ],
        r"synthesis\.b(\d+)\.conv0\.noise_strength",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv0_up/noise_strength"],
        r"synthesis\.b(\d+)\.conv0\.affine\.weight",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv0_up/mod_weight"].
        transpose(),
        r"synthesis\.b(\d+)\.conv0\.affine\.bias",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv0_up/mod_bias"] + 1,
        r"synthesis\.b(\d+)\.conv1\.weight",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv1/weight"].transpose(
            3, 2, 0, 1),
        r"synthesis\.b(\d+)\.conv1\.bias",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv1/bias"],
        r"synthesis\.b(\d+)\.conv1\.noise_const",
        lambda r: tf_params[f"synthesis/noise{int(np.log2(int(r)))*2-4}"][0, 0
                                                                          ],
        r"synthesis\.b(\d+)\.conv1\.noise_strength",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv1/noise_strength"],
        r"synthesis\.b(\d+)\.conv1\.affine\.weight",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv1/mod_weight"].transpose(),
        r"synthesis\.b(\d+)\.conv1\.affine\.bias",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv1/mod_bias"] + 1,
        r"synthesis\.b(\d+)\.torgb\.weight",
        lambda r: tf_params[f"synthesis/{r}x{r}/ToRGB/weight"].transpose(
            3, 2, 0, 1),
        r"synthesis\.b(\d+)\.torgb\.bias",
        lambda r: tf_params[f"synthesis/{r}x{r}/ToRGB/bias"],
        r"synthesis\.b(\d+)\.torgb\.affine\.weight",
        lambda r: tf_params[f"synthesis/{r}x{r}/ToRGB/mod_weight"].transpose(),
        r"synthesis\.b(\d+)\.torgb\.affine\.bias",
        lambda r: tf_params[f"synthesis/{r}x{r}/ToRGB/mod_bias"] + 1,
        r"synthesis\.b(\d+)\.skip\.weight",
        lambda r: tf_params[f"synthesis/{r}x{r}/Skip/weight"
                            ][::-1, ::-1].transpose(3, 2, 0, 1),
        r".*\.resample_filter",
        None,
    )
    return G
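A quick sanity check on the converted generator, using the stylegan2-ada-pytorch calling convention G(z, c); tf_G is assumed to have been unpickled elsewhere:

import torch

G = convert_tf_generator(tf_G)
z = torch.randn([1, G.z_dim])
c = torch.zeros([1, G.c_dim])  # empty label vector when c_dim == 0
img = G(z, c)                  # [1, img_channels, H, W], roughly in [-1, 1]
print(img.shape)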
Example #13
def load(pkl_file):
    with open(pkl_file, 'rb') as f:
        s = dnnlib.EasyDict(pickle.load(f))
    obj = FeatureStats(capture_all=s.capture_all, max_items=s.max_items)
    obj.__dict__.update(s)
    return obj
Example #14
import os

import dnnlib
from dnnlib.util import Logger
from ffhq_datareader import load_dataset
from experiments import compute_stylegan_realism
from experiments import compute_stylegan_truncation
from utils import init_tf

SAVE_PATH = os.path.dirname(__file__)

#----------------------------------------------------------------------------
# Configs for truncation sweep and realism score.

realism_config = dnnlib.EasyDict(minibatch_size=8,
                                 num_images=50000,
                                 num_gen_images=1000,
                                 show_n_images=64,
                                 truncation=1.0,
                                 save_images=True,
                                 save_path=SAVE_PATH,
                                 num_gpus=1,
                                 random_seed=123456)

truncation_config = dnnlib.EasyDict(minibatch_size=8,
                                    num_images=50000,
                                    truncations=[1.0, 0.7, 0.3],
                                    save_txt=True,
                                    save_path=SAVE_PATH,
                                    num_gpus=1,
                                    random_seed=1234)

#----------------------------------------------------------------------------
# Minimal CLI.
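The snippet ends at the CLI stub. A hedged reconstruction of what such a minimal entry point could look like; the argument lists of compute_stylegan_realism and compute_stylegan_truncation are not shown above, so the **config calls are hypothetical:

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('experiment', choices=['realism', 'truncation'])
    cli_args = parser.parse_args()
    init_tf()
    if cli_args.experiment == 'realism':
        compute_stylegan_realism(**realism_config)
    else:
        compute_stylegan_truncation(**truncation_config)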
Example #15
def _report_result(self, value, suffix="", fmt="%-10.4f"):
    self._results += [dnnlib.EasyDict(value=value, suffix=suffix, fmt=fmt)]
Example #16
def main():
    os.makedirs(a.out_dir, exist_ok=True)
    device = torch.device('cuda')

    # setup generator
    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.verbose = a.verbose
    Gs_kwargs.size = a.size
    Gs_kwargs.scale_type = a.scale_type

    # load base or custom network
    pkl_name = osp.splitext(a.model)[0]
    if '.pkl' in a.model.lower():
        custom = False
        print(' .. Gs from pkl ..', basename(a.model))
    else:
        custom = True
        print(' .. Gs custom ..', basename(a.model))
    with dnnlib.util.open_url(pkl_name + '.pkl') as f:
        Gs = legacy.load_network_pkl(f, custom=custom, **Gs_kwargs)['G_ema'].to(device)  # type: ignore

    dlat_shape = (1, Gs.num_ws, Gs.w_dim)  # [1,18,512]

    # read saved latents
    if a.dlatents is not None and osp.isfile(a.dlatents):
        key_dlatents = load_latents(a.dlatents)
        if len(key_dlatents.shape) == 2: key_dlatents = np.expand_dims(key_dlatents, 0)
    elif a.dlatents is not None and osp.isdir(a.dlatents):
        # if a.dlatents.endswith('/') or a.dlatents.endswith('\\'): a.dlatents = a.dlatents[:-1]
        key_dlatents = []
        npy_list = file_list(a.dlatents, 'npy')
        for npy in npy_list:
            key_dlatent = load_latents(npy)
            if len(key_dlatent.shape) == 2: key_dlatent = np.expand_dims(key_dlatent, 0)
            key_dlatents.append(key_dlatent)
        key_dlatents = np.concatenate(key_dlatents)  # [frm,18,512]
    else:
        print(' No input dlatents found')
        exit()
    key_dlatents = key_dlatents[:, np.newaxis]  # [frm,1,18,512]
    print(' key dlatents', key_dlatents.shape)

    # replace higher layers with single (style) latent
    if a.style_dlat is not None:
        print(' styling with dlatent', a.style_dlat)
        style_dlatent = load_latents(a.style_dlat)
        while len(style_dlatent.shape) < 4: style_dlatent = np.expand_dims(style_dlatent, 0)
        # try replacing 5 by other value, less than Gs.num_ws
        key_dlatents[:, :, range(5, Gs.num_ws), :] = style_dlatent[:, :, range(5, Gs.num_ws), :]

    frames = key_dlatents.shape[0] * a.fstep

    dlatents = latent_anima(dlat_shape, frames, a.fstep, key_latents=key_dlatents, cubic=a.cubic,
                            verbose=True)  # [frm,1,18,512]
    print(' dlatents', dlatents.shape)
    frame_count = dlatents.shape[0]
    dlatents = torch.from_numpy(dlatents).to(device)

    # distort image by tweaking initial const layer
    if a.digress > 0:
        try:
            init_res = Gs.init_res
        except AttributeError:
            init_res = (4, 4)  # default initial layer size
        dconst = a.digress * latent_anima([1, Gs.z_dim, *init_res], frame_count, a.fstep, cubic=True, verbose=False)
    else:
        dconst = np.zeros([frame_count, 1, 1, 1, 1])
    dconst = torch.from_numpy(dconst).to(device)

    # generate images from latent timeline
    pbar = ProgressBar(frame_count)
    for i in range(frame_count):

        # generate multi-latent result
        if custom:
            output = Gs.synthesis(dlatents[i], None, dconst[i], noise_mode='const')
        else:
            output = Gs.synthesis(dlatents[i], noise_mode='const')
        output = (output.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()

        ext = 'png' if output.shape[3] == 4 else 'jpg'
        filename = osp.join(a.out_dir, "%06d.%s" % (i, ext))
        imsave(filename, output[0])
        pbar.upd()
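As a possible follow-up step (not part of the original script), the numbered frames written above can be stitched into a video with ffmpeg, assuming jpg output and the frame pattern used in the loop:

import subprocess

subprocess.run(['ffmpeg', '-y', '-r', '25',
                '-i', osp.join(a.out_dir, '%06d.jpg'),
                '-crf', '18', osp.join(a.out_dir, '_out.mp4')],
               check=True)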
Example #17
def G_style(
    latents_in,  # First input: Latent vectors (Z) [minibatch, latent_size].
    labels_in,  # Second input: Conditioning labels [minibatch, label_size].
    truncation_psi=0.7,  # Style strength multiplier for the truncation trick. None = disable.
    truncation_cutoff=8,  # Number of layers for which to apply the truncation trick. None = disable.
    truncation_psi_val=None,  # Value for truncation_psi to use during validation.
    truncation_cutoff_val=None,  # Value for truncation_cutoff to use during validation.
    dlatent_avg_beta=0.995,  # Decay for tracking the moving average of W during training. None = disable.
    style_mixing_prob=0.9,  # Probability of mixing styles during training. None = disable.
    is_training=False,  # Network is under training? Enables and disables specific features.
    is_validation=False,  # Network is under validation? Chooses which value to use for truncation_psi.
    is_template_graph=False,  # True = template graph constructed by the Network class, False = actual evaluation.
    components=dnnlib.EasyDict(),  # Container for sub-networks. Retained between calls.
    **kwargs):  # Arguments for sub-networks (G_mapping and G_synthesis).

    # Validate arguments.
    assert not is_training or not is_validation
    assert isinstance(components, dnnlib.EasyDict)
    if is_validation:
        truncation_psi = truncation_psi_val
        truncation_cutoff = truncation_cutoff_val
    if is_training or (truncation_psi is not None
                       and not tflib.is_tf_expression(truncation_psi)
                       and truncation_psi == 1):
        truncation_psi = None
    if is_training or (truncation_cutoff is not None
                       and not tflib.is_tf_expression(truncation_cutoff)
                       and truncation_cutoff <= 0):
        truncation_cutoff = None
    if not is_training or (dlatent_avg_beta is not None
                           and not tflib.is_tf_expression(dlatent_avg_beta)
                           and dlatent_avg_beta == 1):
        dlatent_avg_beta = None
    if not is_training or (style_mixing_prob is not None
                           and not tflib.is_tf_expression(style_mixing_prob)
                           and style_mixing_prob <= 0):
        style_mixing_prob = None

    # Setup components.
    if 'synthesis' not in components:
        components.synthesis = tflib.Network('G_synthesis',
                                             func_name=G_synthesis,
                                             **kwargs)
    num_layers = components.synthesis.input_shape[1]
    dlatent_size = components.synthesis.input_shape[2]
    if 'mapping' not in components:
        components.mapping = tflib.Network('G_mapping',
                                           func_name=G_mapping,
                                           dlatent_broadcast=num_layers,
                                           **kwargs)

    # Setup variables.
    lod_in = tf.get_variable('lod', initializer=np.float32(0), trainable=False)
    dlatent_avg = tf.get_variable('dlatent_avg',
                                  shape=[dlatent_size],
                                  initializer=tf.initializers.zeros(),
                                  trainable=False)

    # Evaluate mapping network.
    dlatents = components.mapping.get_output_for(latents_in, labels_in,
                                                 **kwargs)

    # Update moving average of W.
    if dlatent_avg_beta is not None:
        with tf.variable_scope('DlatentAvg'):
            batch_avg = tf.reduce_mean(dlatents[:, 0], axis=0)
            update_op = tf.assign(
                dlatent_avg,
                tflib.lerp(batch_avg, dlatent_avg, dlatent_avg_beta))
            with tf.control_dependencies([update_op]):
                dlatents = tf.identity(dlatents)

    # Perform style mixing regularization.
    if style_mixing_prob is not None:
        with tf.name_scope('StyleMix'):
            latents2 = tf.random_normal(tf.shape(latents_in))
            dlatents2 = components.mapping.get_output_for(
                latents2, labels_in, **kwargs)
            layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
            cur_layers = num_layers - tf.cast(lod_in, tf.int32) * 2
            mixing_cutoff = tf.cond(
                tf.random_uniform([], 0.0, 1.0) < style_mixing_prob,
                lambda: tf.random_uniform([], 1, cur_layers, dtype=tf.int32),
                lambda: cur_layers)
            dlatents = tf.where(
                tf.broadcast_to(layer_idx < mixing_cutoff, tf.shape(dlatents)),
                dlatents, dlatents2)

    # Apply truncation trick.
    if truncation_psi is not None and truncation_cutoff is not None:
        with tf.variable_scope('Truncation'):
            layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
            ones = np.ones(layer_idx.shape, dtype=np.float32)
            coefs = tf.where(layer_idx < truncation_cutoff,
                             truncation_psi * ones, ones)
            dlatents = tflib.lerp(dlatent_avg, dlatents, coefs)

    # Evaluate synthesis network.
    with tf.control_dependencies(
        [tf.assign(components.synthesis.find_var('lod'), lod_in)]):
        images_out = components.synthesis.get_output_for(
            dlatents, force_clean_graph=is_template_graph, **kwargs)
    return tf.identity(images_out, name='images_out')
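
The truncation trick applied inside the 'Truncation' scope above reduces to a layer-wise lerp towards the average dlatent; a minimal NumPy sketch (the function and shapes are illustrative, not part of the original code):

import numpy as np

def truncate_dlatents(dlatents, dlatent_avg, psi=0.7, cutoff=8):
    # dlatents: [minibatch, num_layers, dlatent_size]; layers below `cutoff`
    # are pulled towards the moving average by a factor of `psi`.
    layer_idx = np.arange(dlatents.shape[1])[np.newaxis, :, np.newaxis]
    coefs = np.where(layer_idx < cutoff, psi, 1.0)
    return dlatent_avg + (dlatents - dlatent_avg) * coefs

w = truncate_dlatents(np.random.randn(4, 18, 512), np.zeros(512))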
Example #18
def lerp_video(
        network_pkl,  # Path to pretrained model pkl file
        seeds,  # Random seeds
        grid_w=None,  # Number of columns
        grid_h=None,  # Number of rows
        truncation_psi=1.0,  # Truncation trick
        outdir='out',  # Output dir
        slowdown=1,  # Slowdown of the video (power of 2)
        duration_sec=30.0,  # Duration of video in seconds
        smoothing_sec=3.0,
        mp4_fps=30,
        mp4_codec="libx264",
        mp4_bitrate="16M"):
    # Sanity check regarding slowdown
    message = 'slowdown must be a power of 2 (1, 2, 4, 8, ...) and greater than 0!'
    assert slowdown & (slowdown - 1) == 0 and slowdown > 0, message
    # Initialize TensorFlow and create outdir
    tflib.init_tf()
    os.makedirs(outdir, exist_ok=True)
    # Total duration of video and number of frames to generate
    num_frames = int(np.rint(duration_sec * mp4_fps))
    total_duration = duration_sec * slowdown

    print(f'Loading network from {network_pkl}...')
    with dnnlib.util.open_url(network_pkl) as fp:
        _G, _D, Gs = pickle.load(fp)

    print("Generating latent vectors...")
    # If the grid shape isn't specified, infer it from the seeds provided
    if grid_w is None and grid_h is None and len(seeds) >= 1:
        # number of images according to the seeds provided
        num = len(seeds)
        # Get the grid width and height according to num:
        grid_w = max(int(np.ceil(np.sqrt(num))), 1)
        grid_h = max((num - 1) // grid_w + 1, 1)
        grid_size = [grid_w, grid_h]
        # [frame, image, channel, component]:
        shape = [num_frames] + Gs.input_shape[1:]
        # Get the latents:
        all_latents = np.stack([
            np.random.RandomState(seed).randn(*shape).astype(np.float32)
            for seed in seeds
        ],
                               axis=1)
    # If only one seed is provided and the shape is specified
    elif None not in (grid_w, grid_h) and len(seeds) == 1:
        # Otherwise, the user gives one seed and the grid width and height:
        grid_size = [grid_w, grid_h]
        # [frame, image, channel, component]:
        shape = [num_frames, np.prod(grid_size)] + Gs.input_shape[1:]
        # Get the latents with the random state:
        random_state = np.random.RandomState(seeds)
        all_latents = random_state.randn(*shape).astype(np.float32)
    else:
        print("Error: wrong combination of arguments! Please provide \
                either one seed and the grid width and height, or a \
                list of seeds to use.")
        sys.exit(1)

    all_latents = scipy.ndimage.gaussian_filter(
        all_latents, [smoothing_sec * mp4_fps] + [0] * len(Gs.input_shape),
        mode="wrap")
    all_latents /= np.sqrt(np.mean(np.square(all_latents)))
    # Name of the final mp4 video
    mp4 = f"{grid_w}x{grid_h}-lerp-{slowdown}xslowdown.mp4"

    # Aux function to slowdown the video by 2x
    def double_slowdown(latents, duration_sec, num_frames):
        # Make an empty latent vector with double the amount of frames
        z = np.empty(np.multiply(latents.shape, [2, 1, 1]), dtype=np.float32)
        # Populate it
        for i in range(len(latents)):
            z[2 * i] = latents[i]
        # Interpolate in the odd frames
        for i in range(1, len(z), 2):
            # For the last frame, we loop to the first one
            if i == len(z) - 1:
                z[i] = (z[0] + z[i - 1]) / 2
            else:
                z[i] = (z[i - 1] + z[i + 1]) / 2
        # We also need to double the duration_sec and num_frames
        duration_sec *= 2
        num_frames *= 2
        # Return the new latents, and the two previous quantities
        return z, duration_sec, num_frames

    while slowdown > 1:
        all_latents, duration_sec, num_frames = double_slowdown(
            all_latents, duration_sec, num_frames)
        slowdown //= 2

    # Define the kwargs for the Generator:
    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8,
                                      nchw_to_nhwc=True)
    Gs_kwargs.randomize_noise = False
    if truncation_psi is not None:
        Gs_kwargs.truncation_psi = truncation_psi

    # Aux function: Frame generation func for moviepy.
    def make_frame(t):
        frame_idx = int(np.clip(np.round(t * mp4_fps), 0, num_frames - 1))
        latents = all_latents[frame_idx]
        # Get the images (with labels = None)
        images = Gs.run(latents, None, **Gs_kwargs)
        # Generate the grid for this timestamp:
        grid = create_image_grid(images, grid_size)
        # grayscale => RGB
        if grid.shape[2] == 1:
            grid = grid.repeat(3, 2)
        return grid

    # Generate video using make_frame:
    print(
        f'Generating interpolation video of length: {total_duration} seconds...'
    )
    videoclip = moviepy.editor.VideoClip(make_frame, duration=duration_sec)
    videoclip.write_videofile(os.path.join(outdir, mp4),
                              fps=mp4_fps,
                              codec=mp4_codec,
                              bitrate=mp4_bitrate)
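
The temporal smoothing step above is worth seeing on its own: a Gaussian filter along the frame axis only, with mode="wrap" so the clip loops seamlessly, followed by RMS renormalization. A minimal sketch with illustrative sizes:

import numpy as np
import scipy.ndimage

num_frames, mp4_fps, smoothing_sec = 90, 30, 1.0
latents = np.random.RandomState(0).randn(num_frames, 1, 512).astype(np.float32)
# Smooth across frames only (sigma 0 on the other axes), wrapping around.
latents = scipy.ndimage.gaussian_filter(latents, [smoothing_sec * mp4_fps, 0, 0], mode="wrap")
latents /= np.sqrt(np.mean(np.square(latents)))  # restore unit RMS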
Example #19
def main():
    os.makedirs(a.out_dir, exist_ok=True)

    # setup generator
    fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.func_name = 'training.stylegan2_multi.G_main'
    Gs_kwargs.verbose = a.verbose
    Gs_kwargs.size = a.size
    Gs_kwargs.scale_type = a.scale_type
    Gs_kwargs.impl = a.ops

    # load model with arguments
    sess = tflib.init_tf({'allow_soft_placement': True})
    pkl_name = osp.splitext(a.model)[0]
    with open(pkl_name + '.pkl', 'rb') as file:
        network = pickle.load(file, encoding='latin1')
    try:
        _, _, network = network
    except (TypeError, ValueError):  # pkl stored Gs directly
        pass
    for k in list(network.static_kwargs.keys()):
        Gs_kwargs[k] = network.static_kwargs[k]

    # reload custom network, if needed
    if '.pkl' in a.model.lower():
        print(' .. Gs from pkl ..', basename(a.model))
        Gs = network
    else:  # reconstruct network
        print(' .. Gs custom ..', basename(a.model))
        Gs = tflib.Network('Gs', **Gs_kwargs)
        Gs.copy_vars_from(network)

    z_dim = Gs.input_shape[1]
    dz_dim = 512  # dlatent_size
    try:
        dl_dim = 2 * (int(np.floor(np.log2(Gs_kwargs.resolution))) - 1)
    except (AttributeError, KeyError):
        print(' Resave model, no resolution kwarg found!')
        exit(1)
    dlat_shape = (1, dl_dim, dz_dim)  # [1,18,512]

    # read saved latents
    if a.dlatents is not None and osp.isfile(a.dlatents):
        key_dlatents = load_latents(a.dlatents)
        if len(key_dlatents.shape) == 2:
            key_dlatents = np.expand_dims(key_dlatents, 0)
    elif a.dlatents is not None and osp.isdir(a.dlatents):
        # if a.dlatents.endswith('/') or a.dlatents.endswith('\\'): a.dlatents = a.dlatents[:-1]
        key_dlatents = []
        npy_list = file_list(a.dlatents, 'npy')
        for npy in npy_list:
            key_dlatent = load_latents(npy)
            if len(key_dlatent.shape) == 2:
                key_dlatent = np.expand_dims(key_dlatent, 0)
            key_dlatents.append(key_dlatent)
        key_dlatents = np.concatenate(key_dlatents)  # [frm,18,512]
    else:
        print(' No input dlatents found')
        exit()
    key_dlatents = key_dlatents[:, np.newaxis]  # [frm,1,18,512]
    print(' key dlatents', key_dlatents.shape)

    # replace higher layers with single (style) latent
    if a.style_dlat is not None:
        print(' styling with dlatent', a.style_dlat)
        style_dlatent = load_latents(a.style_dlat)
        while len(style_dlatent.shape) < 4:
            style_dlatent = np.expand_dims(style_dlatent, 0)
        # try replacing 5 by other value, less than dl_dim
        key_dlatents[:, :, range(5, dl_dim), :] = style_dlatent[:, :, range(5, dl_dim), :]

    frames = key_dlatents.shape[0] * a.fstep

    dlatents = latent_anima(dlat_shape,
                            frames,
                            a.fstep,
                            key_latents=key_dlatents,
                            cubic=a.cubic,
                            verbose=True)  # [frm,1,18,512]
    print(' dlatents', dlatents.shape)
    frame_count = dlatents.shape[0]

    # truncation trick
    dlatent_avg = Gs.get_var('dlatent_avg')  # (512,)
    tr_range = range(0, 8)
    dlatents[:, :, tr_range, :] = dlatent_avg + (dlatents[:, :, tr_range, :] -
                                                 dlatent_avg) * a.trunc

    # distort image by tweaking initial const layer
    if a.digress > 0:
        try:
            latent_size = Gs.static_kwargs['latent_size']
        except KeyError:
            latent_size = 512  # default latent size
        try:
            init_res = Gs.static_kwargs['init_res']
        except KeyError:
            init_res = (4, 4)  # default initial layer size
        dconst = a.digress * latent_anima([1, latent_size, *init_res],
                                          frames,
                                          a.fstep,
                                          cubic=True,
                                          verbose=False)
    else:
        dconst = np.zeros([frame_count, 1, 1, 1, 1])

    # generate images from latent timeline
    pbar = ProgressBar(frame_count)
    for i in range(frame_count):

        # generate multi-latent result
        if Gs.num_inputs == 2:
            output = Gs.components.synthesis.run(dlatents[i],
                                                 randomize_noise=False,
                                                 output_transform=fmt,
                                                 minibatch_size=1)
        else:
            output = Gs.components.synthesis.run(dlatents[i], [None],
                                                 dconst[i],
                                                 randomize_noise=False,
                                                 output_transform=fmt,
                                                 minibatch_size=1)

        ext = 'png' if output.shape[3] == 4 else 'jpg'
        filename = osp.join(a.out_dir, "%06d.%s" % (i, ext))
        imsave(filename, output[0])
        pbar.upd()
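
The directory branch above stacks one dlatent per .npy file; a minimal standalone sketch (np.load and a sorted glob stand in for the script's load_latents and file_list helpers):

import glob
import os.path as osp
import numpy as np

def load_key_dlatents(dlatents_dir):
    keys = []
    for npy in sorted(glob.glob(osp.join(dlatents_dir, '*.npy'))):
        dlat = np.load(npy)
        if dlat.ndim == 2:           # a single [18, 512] dlatent
            dlat = dlat[np.newaxis]  # -> [1, 18, 512]
        keys.append(dlat)
    return np.concatenate(keys)      # [frames, 18, 512]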
Example #20
    def run(self,
            network_pkl,
            run_dir=None,
            data_dir=None,
            dataset_args=None,
            mirror_augment=None,
            num_gpus=1,
            tf_config=None,
            log_results=True,
            num_repeats=1,
            Gs_kwargs=dict(is_validation=True),
            resume_with_new_nets=False,
            truncations=[None]):
        self._reset(network_pkl=network_pkl,
                    run_dir=run_dir,
                    data_dir=data_dir,
                    dataset_args=dataset_args,
                    mirror_augment=mirror_augment)
        with tf.Graph().as_default(), tflib.create_session(
                tf_config).as_default():  # pylint: disable=not-context-manager
            self._report_progress(0, 1)
            _G, _D, Gs = misc.load_pkl(self._network_pkl)

            if resume_with_new_nets:
                dataset = self._get_dataset_obj()
                G = dnnlib.tflib.Network(
                    'G',
                    num_channels=dataset.shape[0],
                    resolution=dataset.shape[1],
                    label_size=dataset.label_size,
                    func_name='training.co_mod_gan.G_main',
                    pix2pix=dataset.pix2pix)
                Gs_new = G.clone('Gs')
                Gs_new.copy_vars_from(Gs)
                Gs = Gs_new

            for t in truncations:
                print('truncation={}'.format(t))
                self._results = []
                time_begin = time.time()

                Gs_kwargs.update(truncation_psi_val=t)
                self._evaluate(Gs, Gs_kwargs=Gs_kwargs, num_gpus=num_gpus)
                self._report_progress(1, 1)

                if num_repeats > 1:
                    records = [
                        dnnlib.EasyDict(value=[res.value],
                                        suffix=res.suffix,
                                        fmt=res.fmt) for res in self._results
                    ]
                    for i in range(1, num_repeats):
                        print(self.get_result_str().strip())
                        self._results = []
                        self._report_progress(0, 1)
                        self._evaluate(Gs,
                                       Gs_kwargs=Gs_kwargs,
                                       num_gpus=num_gpus)
                        self._report_progress(1, 1)
                        for rec, res in zip(records, self._results):
                            rec.value.append(res.value)

                    self._results = []
                    for rec in records:
                        self._report_result(np.mean(rec.value), rec.suffix,
                                            rec.fmt)
                        self._report_result(np.std(rec.value),
                                            rec.suffix + '-std', rec.fmt)

                self._eval_time = time.time() - time_begin  # pylint: disable=attribute-defined-outside-init

                if log_results:
                    if run_dir is not None:
                        log_file = os.path.join(run_dir,
                                                'metric-%s.txt' % self.name)
                        with dnnlib.util.Logger(log_file, 'a'):
                            print(self.get_result_str().strip())
                    else:
                        print(self.get_result_str().strip())
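
When num_repeats > 1, the records above boil down to reporting the mean and standard deviation over repeated evaluations; a toy sketch with made-up numbers:

import numpy as np

repeats = [2.31, 2.27, 2.40]  # e.g. one metric value per evaluation pass
print('metric %.3f' % np.mean(repeats))
print('metric-std %.3f' % np.std(repeats))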
Example #21
def embed(batch_size, resolution, img, network, iteration, seed=6600):
    tf.reset_default_graph()
    print('Loading networks from "%s"...' % network)
    tflib.init_tf()
    _, _, G = pretrained_networks.load_networks(network)
    img_in = tf.constant(img)
    opt = tf.train.AdamOptimizer(learning_rate=0.01, beta1=0.9, beta2=0.999, epsilon=1e-8)
    noise_vars = [var for name, var in G.components.synthesis.vars.items() if name.startswith('noise')]

    G_kwargs = dnnlib.EasyDict()
    G_kwargs.randomize_noise = False
    G_syn = G.components.synthesis
    loss_list = []
    p_loss_list = []
    m_loss_list = []
    dl_list = []
    si_list = []

    rnd = np.random.RandomState(seed)
    dlatent_avg = [var for name, var in G.vars.items() if name.startswith('dlatent_avg')][0].eval()
    dlatent_avg = np.expand_dims(np.expand_dims(dlatent_avg, 0), 1)
    dlatent_avg = dlatent_avg.repeat(12, 1)
    dlatent = tf.get_variable('dlatent', dtype=tf.float32, initializer=tf.constant(dlatent_avg),
                              trainable=True)
    synth_img = G_syn.get_output_for(dlatent, is_training=False, **G_kwargs)
    # synth_img = (synth_img + 1.0) / 2.0

    with tf.variable_scope('mse_loss'):
        mse_loss = tf.reduce_mean(tf.square(img_in - synth_img))
    with tf.variable_scope('perceptual_loss'):
        vgg_in = tf.concat([img_in, synth_img], 0)
        tf.keras.backend.set_image_data_format('channels_first')
        vgg = tf.keras.applications.VGG16(include_top=False, input_tensor=vgg_in, input_shape=(3, 128, 128),
                                          weights='/gdata2/fengrl/metrics/vgg.h5',
                                          pooling=None)
        h1 = vgg.get_layer('block1_conv1').output
        h2 = vgg.get_layer('block1_conv2').output
        h3 = vgg.get_layer('block3_conv2').output
        h4 = vgg.get_layer('block4_conv2').output
        pcep_loss = tf.reduce_mean(tf.square(h1[0] - h1[1])) + tf.reduce_mean(tf.square(h2[0] - h2[1])) + \
                    tf.reduce_mean(tf.square(h3[0] - h3[1])) + tf.reduce_mean(tf.square(h4[0] - h4[1]))
    loss = 0.5 * mse_loss + 0.5 * pcep_loss
    with tf.control_dependencies([loss]):
        train_op = opt.minimize(loss, var_list=[dlatent])

    tflib.init_uninitialized_vars()
    # rnd = np.random.RandomState(seed)
    tflib.set_vars({var: rnd.randn(*var.shape.as_list()) for var in noise_vars})  # [height, width]
    for i in range(iteration):
        loss_, p_loss_, m_loss_, dl_, si_, _ = tflib.run([loss, pcep_loss, mse_loss, dlatent, synth_img, train_op])
        loss_list.append(loss_)
        p_loss_list.append(p_loss_)
        m_loss_list.append(m_loss_)
        dl_loss_ = np.sum(np.square(dl_ - dlatent_avg))
        dl_list.append(dl_loss_)
        if i % 500 == 0:
            si_list.append(si_)
        if i % 100 == 0:
            print('Loss %f, mse %f, ppl %f, dl %f, step %d' % (loss_, m_loss_, p_loss_,
                                                               dl_loss_, i))
    return loss_list, p_loss_list, m_loss_list, dl_list, si_list
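
The same embedding idea carries over to PyTorch almost line for line; a minimal sketch with a stand-in generator (nothing here is the StyleGAN synthesis network; the module and shapes are placeholders):

import torch

G_syn = torch.nn.Conv2d(12, 3, 1)                      # stand-in generator
target = torch.rand(1, 3, 8, 8)                        # image to embed
dlatent = torch.zeros(1, 12, 8, 8, requires_grad=True)
opt = torch.optim.Adam([dlatent], lr=0.01, betas=(0.9, 0.999), eps=1e-8)
for step in range(100):
    opt.zero_grad()
    loss = torch.mean((G_syn(dlatent) - target) ** 2)  # MSE part of the loss above
    loss.backward()
    opt.step()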
Example #22
import os
import sys
import warnings
import numpy as np
import torch
import dnnlib
import traceback

from .. import custom_ops
from .. import misc

# ----------------------------------------------------------------------------

activation_funcs = {
    'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
    'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2),
                            cuda_idx=2, ref='y', has_2nd_grad=False),
    'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2,
                             def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
    'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y',
                            has_2nd_grad=True),
    'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y',
                               has_2nd_grad=True),
    'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y',
                           has_2nd_grad=True),
    'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7,
                            ref='y', has_2nd_grad=True),
    'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1,
                                cuda_idx=8, ref='y', has_2nd_grad=True),
    'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9,
                             ref='x', has_2nd_grad=True),
}
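
Assuming the table above is in scope, a typical lookup applies the stored func with its default alpha and scales by the default gain, roughly what a bias_act-style helper does in the reference implementation:

import torch

spec = activation_funcs['lrelu']
x = torch.randn(4, 8)
y = spec.func(x, alpha=spec.def_alpha) * spec.def_gain  # gain keeps activations near unit variance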
Example #23
def training_loop(
        run_dir='.',  # Output directory.
        training_set_kwargs={},  # Options for training set.
        data_loader_kwargs={},  # Options for torch.utils.data.DataLoader.
        G_kwargs={},  # Options for generator network.
        D_kwargs={},  # Options for discriminator network.
        G_opt_kwargs={},  # Options for generator optimizer.
        D_opt_kwargs={},  # Options for discriminator optimizer.
        augment_kwargs=None,  # Options for augmentation pipeline. None = disable.
        loss_kwargs={},  # Options for loss function.
        metrics=[],  # Metrics to evaluate during training.
        random_seed=0,  # Global random seed.
        num_gpus=1,  # Number of GPUs participating in the training.
        rank=0,  # Rank of the current process in [0, num_gpus[.
        batch_size=4,  # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus.
        batch_gpu=4,  # Number of samples processed at a time by one GPU.
        ema_kimg=10,  # Half-life of the exponential moving average (EMA) of generator weights.
        ema_rampup=None,  # EMA ramp-up coefficient.
        G_reg_interval=4,  # How often to perform regularization for G? None = disable lazy regularization.
        D_reg_interval=16,  # How often to perform regularization for D? None = disable lazy regularization.
        augment_p=0,  # Initial value of augmentation probability.
        ada_target=None,  # ADA target value. None = fixed p.
        ada_interval=4,  # How often to perform ADA adjustment?
        ada_kimg=500,  # ADA adjustment speed, measured in how many kimg it takes for p to increase/decrease by one unit.
        nimg=0,  # current image count
        total_kimg=25000,  # Total length of the training, measured in thousands of real images.
        kimg_per_tick=4,  # Progress snapshot interval.
        image_snapshot_ticks=50,  # How often to save image snapshots? None = disable.
        network_snapshot_ticks=50,  # How often to save network snapshots? None = disable.
        resume_pkl=None,  # Network pickle to resume training from.
        cudnn_benchmark=True,  # Enable torch.backends.cudnn.benchmark?
        allow_tf32=False,  # Enable torch.backends.cuda.matmul.allow_tf32 and torch.backends.cudnn.allow_tf32?
        abort_fn=None,  # Callback function for determining whether to abort training. Must return consistent results across ranks.
        progress_fn=None,  # Callback function for updating training progress. Called for all ranks.
):
    # Initialize.
    start_time = time.time()
    device = torch.device('cuda', rank)
    np.random.seed(random_seed * num_gpus + rank)
    torch.manual_seed(random_seed * num_gpus + rank)
    torch.backends.cudnn.benchmark = cudnn_benchmark  # Improves training speed.
    torch.backends.cuda.matmul.allow_tf32 = allow_tf32  # Allow PyTorch to internally use tf32 for matmul
    torch.backends.cudnn.allow_tf32 = allow_tf32  # Allow PyTorch to internally use tf32 for convolutions
    conv2d_gradfix.enabled = True  # Improves training speed.
    grid_sample_gradfix.enabled = True  # Avoids errors with the augmentation pipe.

    # Load training set.
    if rank == 0:
        print('Loading training set...')
    training_set = dnnlib.util.construct_class_by_name(
        **training_set_kwargs)  # subclass of training.dataset.Dataset
    training_set_sampler = misc.InfiniteSampler(dataset=training_set,
                                                rank=rank,
                                                num_replicas=num_gpus,
                                                seed=random_seed)
    training_set_iterator = iter(
        torch.utils.data.DataLoader(dataset=training_set,
                                    sampler=training_set_sampler,
                                    batch_size=batch_size // num_gpus,
                                    **data_loader_kwargs))
    if rank == 0:
        print()
        print('Num images: ', len(training_set))
        print('Image shape:', training_set.image_shape)
        print('Label shape:', training_set.label_shape)
        print()

    # Construct networks.
    if rank == 0:
        print('Constructing networks...')
    common_kwargs = dict(c_dim=training_set.label_dim,
                         img_resolution=training_set.resolution,
                         img_channels=training_set.num_channels)
    G = dnnlib.util.construct_class_by_name(
        **G_kwargs, **common_kwargs).train().requires_grad_(False).to(
            device)  # subclass of torch.nn.Module
    D = dnnlib.util.construct_class_by_name(
        **D_kwargs, **common_kwargs).train().requires_grad_(False).to(
            device)  # subclass of torch.nn.Module
    G_ema = copy.deepcopy(G).eval()

    # Resume from existing pickle.
    if (resume_pkl is not None) and (rank == 0):
        print(f'Resuming from "{resume_pkl}"')
        with dnnlib.util.open_url(resume_pkl) as f:
            resume_data = legacy.load_network_pkl(f)
        for name, module in [('G', G), ('D', D), ('G_ema', G_ema)]:
            misc.copy_params_and_buffers(resume_data[name],
                                         module,
                                         require_all=False)

    # Print network summary tables.
    if rank == 0:
        z = torch.empty([batch_gpu, G.z_dim], device=device)
        c = torch.empty([batch_gpu, G.c_dim], device=device)
        img = misc.print_module_summary(G, [z, c])
        misc.print_module_summary(D, [img, c])

    # Setup augmentation.
    if rank == 0:
        print('Setting up augmentation...')
    augment_pipe = None
    ada_stats = None
    if (augment_kwargs is not None) and (augment_p > 0
                                         or ada_target is not None):
        augment_pipe = dnnlib.util.construct_class_by_name(
            **augment_kwargs).train().requires_grad_(False).to(
                device)  # subclass of torch.nn.Module
        augment_pipe.p.copy_(torch.as_tensor(augment_p))
        if ada_target is not None:
            ada_stats = training_stats.Collector(regex='Loss/signs/real')

    # Distribute across GPUs.
    if rank == 0:
        print(f'Distributing across {num_gpus} GPUs...')
    ddp_modules = dict()
    for name, module in [('G_mapping', G.mapping),
                         ('G_synthesis', G.synthesis), ('D', D), (None, G_ema),
                         ('augment_pipe', augment_pipe)]:
        if (num_gpus > 1) and (module is not None) and len(
                list(module.parameters())) != 0:
            module.requires_grad_(True)
            module = torch.nn.parallel.DistributedDataParallel(
                module, device_ids=[device], broadcast_buffers=False)
            module.requires_grad_(False)
        if name is not None:
            ddp_modules[name] = module

    # Setup training phases.
    if rank == 0:
        print('Setting up training phases...')
    loss = dnnlib.util.construct_class_by_name(
        device=device, **ddp_modules,
        **loss_kwargs)  # subclass of training.loss.Loss
    phases = []
    for name, module, opt_kwargs, reg_interval in [
        ('G', G, G_opt_kwargs, G_reg_interval),
        ('D', D, D_opt_kwargs, D_reg_interval)
    ]:
        if reg_interval is None:
            opt = dnnlib.util.construct_class_by_name(
                params=module.parameters(),
                **opt_kwargs)  # subclass of torch.optim.Optimizer
            phases += [
                dnnlib.EasyDict(name=name + 'both',
                                module=module,
                                opt=opt,
                                interval=1)
            ]
        else:  # Lazy regularization.
            mb_ratio = reg_interval / (reg_interval + 1)
            opt_kwargs = dnnlib.EasyDict(opt_kwargs)
            opt_kwargs.lr = opt_kwargs.lr * mb_ratio
            opt_kwargs.betas = [beta**mb_ratio for beta in opt_kwargs.betas]
            opt = dnnlib.util.construct_class_by_name(
                module.parameters(),
                **opt_kwargs)  # subclass of torch.optim.Optimizer
            phases += [
                dnnlib.EasyDict(name=name + 'main',
                                module=module,
                                opt=opt,
                                interval=1)
            ]
            phases += [
                dnnlib.EasyDict(name=name + 'reg',
                                module=module,
                                opt=opt,
                                interval=reg_interval)
            ]
    for phase in phases:
        phase.start_event = None
        phase.end_event = None
        if rank == 0:
            phase.start_event = torch.cuda.Event(enable_timing=True)
            phase.end_event = torch.cuda.Event(enable_timing=True)

    # Export sample images.
    grid_size = None
    grid_z = None
    grid_c = None
    if rank == 0:
        print('Exporting sample images...')
        grid_size, images, labels = setup_snapshot_image_grid(
            training_set=training_set)
        save_image_grid(images,
                        os.path.join(run_dir, 'reals.jpg'),
                        drange=[0, 255],
                        grid_size=grid_size)
        grid_z = torch.randn([labels.shape[0], G.z_dim],
                             device=device).split(batch_gpu)
        grid_c = torch.from_numpy(labels).to(device).split(batch_gpu)
        images = torch.cat([
            G_ema(z=z, c=c, noise_mode='const').cpu()
            for z, c in zip(grid_z, grid_c)
        ]).numpy()
        save_image_grid(images,
                        os.path.join(run_dir, 'fakes_init.jpg'),
                        drange=[-1, 1],
                        grid_size=grid_size)

    # Initialize logs.
    if rank == 0:
        print('Initializing logs...')
    stats_collector = training_stats.Collector(regex='.*')
    stats_metrics = dict()
    stats_jsonl = None
    stats_tfevents = None
    if rank == 0:
        stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'wt')
        try:
            import torch.utils.tensorboard as tensorboard
            stats_tfevents = tensorboard.SummaryWriter(run_dir)
        except ImportError as err:
            print('Skipping tfevents export:', err)

    # Train.
    if rank == 0:
        print(f'Training for {total_kimg} kimg...')
        print()
    cur_nimg = nimg
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    maintenance_time = tick_start_time - start_time
    batch_idx = 0
    if progress_fn is not None:
        progress_fn(0, total_kimg)
    while True:

        # Fetch training data.
        with torch.autograd.profiler.record_function('data_fetch'):
            phase_real_img, phase_real_c = next(training_set_iterator)
            phase_real_img = (
                phase_real_img.to(device).to(torch.float32) / 127.5 -
                1).split(batch_gpu)
            phase_real_c = phase_real_c.to(device).split(batch_gpu)
            all_gen_z = torch.randn([len(phases) * batch_size, G.z_dim],
                                    device=device)
            all_gen_z = [
                phase_gen_z.split(batch_gpu)
                for phase_gen_z in all_gen_z.split(batch_size)
            ]
            all_gen_c = [
                training_set.get_label(np.random.randint(len(training_set)))
                for _ in range(len(phases) * batch_size)
            ]
            all_gen_c = torch.from_numpy(
                np.stack(all_gen_c)).pin_memory().to(device)
            all_gen_c = [
                phase_gen_c.split(batch_gpu)
                for phase_gen_c in all_gen_c.split(batch_size)
            ]

        # Execute training phases.
        for phase, phase_gen_z, phase_gen_c in zip(phases, all_gen_z,
                                                   all_gen_c):
            if batch_idx % phase.interval != 0:
                continue

            # Initialize gradient accumulation.
            if phase.start_event is not None:
                phase.start_event.record(torch.cuda.current_stream(device))
            phase.opt.zero_grad(set_to_none=True)
            phase.module.requires_grad_(True)

            # Accumulate gradients over multiple rounds.
            for round_idx, (real_img, real_c, gen_z, gen_c) in enumerate(
                    zip(phase_real_img, phase_real_c, phase_gen_z,
                        phase_gen_c)):
                sync = (round_idx == batch_size // (batch_gpu * num_gpus) - 1)
                gain = phase.interval
                loss.accumulate_gradients(phase=phase.name,
                                          real_img=real_img,
                                          real_c=real_c,
                                          gen_z=gen_z,
                                          gen_c=gen_c,
                                          sync=sync,
                                          gain=gain)

            # Update weights.
            phase.module.requires_grad_(False)
            with torch.autograd.profiler.record_function(phase.name + '_opt'):
                for param in phase.module.parameters():
                    if param.grad is not None:
                        misc.nan_to_num(param.grad,
                                        nan=0,
                                        posinf=1e5,
                                        neginf=-1e5,
                                        out=param.grad)
                phase.opt.step()
            if phase.end_event is not None:
                phase.end_event.record(torch.cuda.current_stream(device))

        # Update G_ema.
        with torch.autograd.profiler.record_function('Gema'):
            ema_nimg = ema_kimg * 1000
            if ema_rampup is not None:
                ema_nimg = min(ema_nimg, cur_nimg * ema_rampup)
            ema_beta = 0.5**(batch_size / max(ema_nimg, 1e-8))
            for p_ema, p in zip(G_ema.parameters(), G.parameters()):
                p_ema.copy_(p.lerp(p_ema, ema_beta))
            for b_ema, b in zip(G_ema.buffers(), G.buffers()):
                b_ema.copy_(b)

        # Update state.
        cur_nimg += batch_size
        batch_idx += 1

        # Execute ADA heuristic.
        if (ada_stats is not None) and (batch_idx % ada_interval == 0):
            ada_stats.update()
            adjust = np.sign(ada_stats['Loss/signs/real'] - ada_target) * (
                batch_size * ada_interval) / (ada_kimg * 1000)
            augment_pipe.p.copy_(
                (augment_pipe.p + adjust).max(misc.constant(0, device=device)))

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if (not done) and (cur_tick != 0) and (
                cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
            continue

        # Print status line, accumulating the same information in stats_collector.
        tick_end_time = time.time()
        fields = []
        fields += [
            f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"
        ]
        fields += [
            f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<8.1f}"
        ]
        fields += [
            f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"
        ]
        fields += [
            f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"
        ]
        fields += [
            f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"
        ]
        fields += [
            f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"
        ]
        fields += [
            f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"
        ]
        fields += [
            f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"
        ]
        torch.cuda.reset_peak_memory_stats()
        fields += [
            f"augment {training_stats.report0('Progress/augment', float(augment_pipe.p.cpu()) if augment_pipe is not None else 0):.3f}"
        ]
        training_stats.report0('Timing/total_hours',
                               (tick_end_time - start_time) / (60 * 60))
        training_stats.report0('Timing/total_days',
                               (tick_end_time - start_time) / (24 * 60 * 60))
        if rank == 0:
            print(' '.join(fields))

        # Check for abort.
        if (not done) and (abort_fn is not None) and abort_fn():
            done = True
            if rank == 0:
                print()
                print('Aborting...')

        # Save image snapshot.
        if (rank == 0) and (image_snapshot_ticks is not None) and (
                done or cur_tick % image_snapshot_ticks == 0):
            images = torch.cat([
                G_ema(z=z, c=c, noise_mode='const').cpu()
                for z, c in zip(grid_z, grid_c)
            ]).numpy()
            save_image_grid(images,
                            os.path.join(run_dir,
                                         f'fakes{cur_nimg//1000:06d}.jpg'),
                            drange=[-1, 1],
                            grid_size=grid_size)

        # Save network snapshot.
        snapshot_pkl = None
        snapshot_data = None
        if (network_snapshot_ticks
                is not None) and (done
                                  or cur_tick % network_snapshot_ticks == 0):
            snapshot_data = dict(training_set_kwargs=dict(training_set_kwargs))
            for name, module in [('G', G), ('D', D), ('G_ema', G_ema),
                                 ('augment_pipe', augment_pipe)]:
                if module is not None:
                    if num_gpus > 1:
                        misc.check_ddp_consistency(module,
                                                   ignore_regex=r'.*\.w_avg')
                    module = copy.deepcopy(module).eval().requires_grad_(
                        False).cpu()
                snapshot_data[name] = module
                del module  # conserve memory
            snapshot_pkl = os.path.join(
                run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl')
            if rank == 0:
                with open(snapshot_pkl, 'wb') as f:
                    pickle.dump(snapshot_data, f)

        # Evaluate metrics.
        if (snapshot_data is not None) and (len(metrics) > 0):
            if rank == 0:
                print('Evaluating metrics...')
            for metric in metrics:
                result_dict = metric_main.calc_metric(
                    metric=metric,
                    G=snapshot_data['G_ema'],
                    dataset_kwargs=training_set_kwargs,
                    num_gpus=num_gpus,
                    rank=rank,
                    device=device)
                if rank == 0:
                    metric_main.report_metric(result_dict,
                                              run_dir=run_dir,
                                              snapshot_pkl=snapshot_pkl)
                stats_metrics.update(result_dict.results)
        del snapshot_data  # conserve memory

        # Collect statistics.
        for phase in phases:
            value = []
            if (phase.start_event is not None) and (phase.end_event
                                                    is not None):
                phase.end_event.synchronize()
                value = phase.start_event.elapsed_time(phase.end_event)
            training_stats.report0('Timing/' + phase.name, value)
        stats_collector.update()
        stats_dict = stats_collector.as_dict()

        # Update logs.
        timestamp = time.time()
        if stats_jsonl is not None:
            fields = dict(stats_dict, timestamp=timestamp)
            stats_jsonl.write(json.dumps(fields) + '\n')
            stats_jsonl.flush()
        if stats_tfevents is not None:
            global_step = int(cur_nimg / 1e3)
            walltime = timestamp - start_time
            for name, value in stats_dict.items():
                stats_tfevents.add_scalar(name,
                                          value.mean,
                                          global_step=global_step,
                                          walltime=walltime)
            for name, value in stats_metrics.items():
                stats_tfevents.add_scalar(f'Metrics/{name}',
                                          value,
                                          global_step=global_step,
                                          walltime=walltime)
            stats_tfevents.flush()
        if progress_fn is not None:
            progress_fn(cur_nimg // 1000, total_kimg)

        # Update state.
        cur_tick += 1
        tick_start_nimg = cur_nimg
        tick_start_time = time.time()
        maintenance_time = tick_start_time - tick_end_time
        if done:
            break

    # Done.
    if rank == 0:
        print()
        print('Exiting...')
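
The G_ema update above turns the half-life ema_kimg into a per-batch lerp factor: after ema_kimg thousand images, the old weights' contribution has halved. A minimal sketch with a toy module (the no_grad guard is needed here because the toy parameters still require grad):

import copy
import torch

ema_kimg, batch_size = 10, 4
ema_beta = 0.5 ** (batch_size / max(ema_kimg * 1000, 1e-8))

G = torch.nn.Linear(4, 4)
G_ema = copy.deepcopy(G).eval()
with torch.no_grad():
    for p_ema, p in zip(G_ema.parameters(), G.parameters()):
        p_ema.copy_(p.lerp(p_ema, ema_beta))  # p_ema <- p + (p_ema - p) * ema_beta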
Example #24
def embed(batch_size,
          resolution,
          imgs,
          network,
          iteration,
          result_dir,
          seed=6600):
    tf.reset_default_graph()
    print('Loading networks from "%s"...' % network)
    tflib.init_tf()
    _, _, G = pretrained_networks.load_networks(network)
    img_in = tf.placeholder(tf.float32)
    opt = tf.train.AdamOptimizer(learning_rate=0.01,
                                 beta1=0.9,
                                 beta2=0.999,
                                 epsilon=1e-8)
    noise_vars = [
        var for name, var in G.components.synthesis.vars.items()
        if name.startswith('noise')
    ]
    alpha_vars = [
        var for name, var in G.components.synthesis.vars.items()
        if name.endswith('alpha')
    ]
    alpha_eval = [alpha.eval() for alpha in alpha_vars]

    G_kwargs = dnnlib.EasyDict()
    G_kwargs.randomize_noise = False
    G_syn = G.components.synthesis

    rnd = np.random.RandomState(seed)
    dlatent_avg = [
        var for name, var in G.vars.items() if name.startswith('dlatent_avg')
    ][0].eval()
    dlatent_avg = np.expand_dims(np.expand_dims(dlatent_avg, 0), 1)
    dlatent_avg = dlatent_avg.repeat(12, 1)
    dlatent = tf.get_variable('dlatent',
                              dtype=tf.float32,
                              initializer=tf.constant(dlatent_avg),
                              trainable=True)
    synth_img = G_syn.get_output_for(dlatent, is_training=False, **G_kwargs)
    # synth_img = (synth_img + 1.0) / 2.0

    with tf.variable_scope('mse_loss'):
        mse_loss = tf.reduce_mean(tf.square(img_in - synth_img))
    with tf.variable_scope('perceptual_loss'):
        vgg_in = tf.concat([img_in, synth_img], 0)
        tf.keras.backend.set_image_data_format('channels_first')
        vgg = tf.keras.applications.VGG16(
            include_top=False,
            input_tensor=vgg_in,
            input_shape=(3, 128, 128),
            weights='/gdata2/fengrl/metrics/vgg.h5',
            pooling=None)
        h1 = vgg.get_layer('block1_conv1').output
        h2 = vgg.get_layer('block1_conv2').output
        h3 = vgg.get_layer('block3_conv2').output
        h4 = vgg.get_layer('block4_conv2').output
        pcep_loss = tf.reduce_mean(tf.square(h1[0] - h1[1])) + tf.reduce_mean(tf.square(h2[0] - h2[1])) + \
                    tf.reduce_mean(tf.square(h3[0] - h3[1])) + tf.reduce_mean(tf.square(h4[0] - h4[1]))
    loss = 0.5 * mse_loss + 0.5 * pcep_loss
    with tf.control_dependencies([loss]):
        train_op = opt.minimize(loss, var_list=[dlatent])
    reset_opt = tf.variables_initializer(opt.variables())
    reset_dl = tf.variables_initializer([dlatent])

    tflib.init_uninitialized_vars()
    # rnd = np.random.RandomState(seed)
    tflib.set_vars(
        {var: rnd.randn(*var.shape.as_list())
         for var in noise_vars})  # [height, width]
    idx = 0
    metrics_l = []
    metrics_p = []
    metrics_m = []
    metrics_d = []

    metrics_args = [metric_defaults[x] for x in ['fid50k', 'ppl_wend']]

    metrics_fun = metric_base.MetricGroup(metrics_args)
    for temperature in [0.2, 0.5, 1.0, 1.5, 2.0, 10.0]:
        tflib.set_vars({
            alpha: scale_alpha(alpha_np, temperature)
            for alpha, alpha_np in zip(alpha_vars, alpha_eval)
        })
        # misc.save_pkl((G, G, G), os.path.join(result_dir, 'temp%f.pkl' % temperature))
        # metrics_fun.run(os.path.join(result_dir, 'temp%f.pkl' % temperature), run_dir=result_dir,
        #                 data_dir='/gdata/fengrl/noise_test_dset/tfrecords',
        #                 dataset_args=dnnlib.EasyDict(tfrecord_dir='ffhq-128', shuffle_mb=0),
        #                 mirror_augment=True, num_gpus=1)
        for img in imgs:
            img = np.expand_dims(img, 0)
            loss_list = []
            p_loss_list = []
            m_loss_list = []
            dl_list = []
            si_list = []
            tflib.run([reset_opt, reset_dl])
            for i in range(iteration):
                loss_, p_loss_, m_loss_, dl_, si_, _ = tflib.run(
                    [loss, pcep_loss, mse_loss, dlatent, synth_img, train_op],
                    {img_in: img})
                loss_list.append(loss_)
                p_loss_list.append(p_loss_)
                m_loss_list.append(m_loss_)
                dl_loss_ = np.sum(np.square(dl_ - dlatent_avg))
                dl_list.append(dl_loss_)
                if i % 500 == 0:
                    si_list.append(si_)
                if i % 100 == 0:
                    print(
                        'Temperature %f, idx %d, Loss %f, mse %f, ppl %f, dl %f, step %d'
                        % (temperature, idx, loss_, m_loss_, p_loss_, dl_loss_,
                           i))
            print('Temperature %f, idx %d, loss: %f, ppl: %f, mse: %f, d: %f' %
                  (temperature, idx, loss_list[-1], p_loss_list[-1],
                   m_loss_list[-1], dl_list[-1]))
            metrics_l.append(loss_list[-1])
            metrics_p.append(p_loss_list[-1])
            metrics_m.append(m_loss_list[-1])
            metrics_d.append(dl_list[-1])
            misc.save_image_grid(np.concatenate(si_list, 0),
                                 os.path.join(
                                     result_dir,
                                     'temp%fsi%d.png' % (temperature, idx)),
                                 drange=[-1, 1])
            misc.save_image_grid(
                si_list[-1],
                os.path.join(result_dir,
                             'temp%fsifinal%d.png' % (temperature, idx)),
                drange=[-1, 1])
            with open(
                    os.path.join(result_dir,
                                 'temp%fmetric_l%d.txt' % (temperature, idx)),
                    'w') as f:
                for l_ in loss_list:
                    f.write(str(l_) + '\n')
            with open(
                    os.path.join(result_dir,
                                 'temp%fmetric_p%d.txt' % (temperature, idx)),
                    'w') as f:
                for l_ in p_loss_list:
                    f.write(str(l_) + '\n')
            with open(
                    os.path.join(result_dir,
                                 'temp%fmetric_m%d.txt' % (temperature, idx)),
                    'w') as f:
                for l_ in m_loss_list:
                    f.write(str(l_) + '\n')
            with open(
                    os.path.join(result_dir,
                                 'temp%fmetric_d%d.txt' % (temperature, idx)),
                    'w') as f:
                for l_ in dl_list:
                    f.write(str(l_) + '\n')
            idx += 1

        l_mean = np.mean(metrics_l)
        p_mean = np.mean(metrics_p)
        m_mean = np.mean(metrics_m)
        d_mean = np.mean(metrics_d)
        with open(
                os.path.join(result_dir,
                             'Temp%fmetric_lmpd.txt' % temperature), 'w') as f:
            for i in range(len(metrics_l)):
                f.write(
                    str(metrics_l[i]) + '    ' + str(metrics_m[i]) + '    ' +
                    str(metrics_p[i]) + '    ' + str(metrics_d[i]) + '\n')
        print(
            'Overall metrics: temp %f, loss_mean %f, ppl_mean %f, mse_mean %f, d_mean %f'
            % (temperature, l_mean, p_mean, m_mean, d_mean))
        with open(os.path.join(result_dir, 'mean_metrics'), 'a') as f:
            f.write('Temperature %f\n' % temperature)
            f.write('loss %f\n' % l_mean)
            f.write('mse %f\n' % m_mean)
            f.write('ppl %f\n' % p_mean)
            f.write('dl %f\n' % d_mean)
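
Unlike Example #21, this variant feeds the target image through a tf.placeholder instead of baking it in as a constant, so one graph serves every image. A minimal TF1-style sketch of that feed-dict pattern (using tensorflow.compat.v1, with a stand-in loss):

import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

img_in = tf.placeholder(tf.float32, shape=(1, 3, 128, 128))
loss = tf.reduce_mean(tf.square(img_in))  # stand-in for the mse/perceptual loss
with tf.Session() as sess:
    print(sess.run(loss, {img_in: np.zeros((1, 3, 128, 128), np.float32)}))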
Example #25
    batch_size = {256: 16, 512: 9, 1024: 4}
    n_sample = batch_size.get(size, 25)

    g = g.to(device)

    z = np.random.RandomState(0).randn(n_sample, 512).astype("float32")

    with torch.no_grad():
        img_pt, _ = g(
            [torch.from_numpy(z).to(device)],
            truncation=0.5,
            truncation_latent=latent_avg.to(device),
            use_fixed_noise=True,
        )

    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.use_fixed_noise = True
    img_tf = g_ema.run(z, None, **Gs_kwargs)
    img_tf = torch.from_numpy(img_tf).to(device)

    img_diff = ((img_pt + 1) / 2).clamp(0.0, 1.0) - ((img_tf.to(device) + 1) / 2).clamp(
        0.0, 1.0
    )

    img_concat = torch.cat((img_tf, img_pt, img_diff), dim=0)

    print(img_diff.abs().max())

    utils.save_image(
        img_concat, name + ".png", nrow=n_sample, normalize=True, range=(-1, 1)
    )
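
The parity check above compresses to three lines: map both batches to [0, 1], subtract, and inspect the worst per-pixel deviation. A toy sketch with random stand-in images:

import torch

img_pt = torch.rand(4, 3, 64, 64) * 2 - 1          # pretend PyTorch output
img_tf = img_pt + 1e-4 * torch.randn_like(img_pt)  # pretend TF output
diff = ((img_pt + 1) / 2).clamp(0, 1) - ((img_tf + 1) / 2).clamp(0, 1)
print(diff.abs().max())  # should be ~1e-4 here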
Example #26
def calc_metrics(network_pkl, metric_names, metricdata, mirror, gpus):
    tflib.init_tf()

    # Initialize metrics.
    metrics = []
    for name in metric_names:
        if name not in metric_defaults.metric_defaults:
            raise UserError('\n'.join(
                ['--metrics can only contain the following values:', 'none'] +
                list(metric_defaults.metric_defaults.keys())))
        metrics.append(
            dnnlib.util.construct_class_by_name(
                **metric_defaults.metric_defaults[name]))

    # Load network.
    if not dnnlib.util.is_url(
            network_pkl,
            allow_file_urls=True) and not os.path.isfile(network_pkl):
        raise UserError('--network must point to a file or URL')
    print(f'Loading network from "{network_pkl}"...')
    with dnnlib.util.open_url(network_pkl) as f:
        _G, _D, Gs = pickle.load(f)
        Gs.print_layers()

    # Look up training options.
    run_dir = None
    training_options = None
    if os.path.isfile(network_pkl):
        potential_run_dir = os.path.dirname(network_pkl)
        potential_json_file = os.path.join(potential_run_dir,
                                           'training_options.json')
        if os.path.isfile(potential_json_file):
            print(
                f'Looking up training options from "{potential_json_file}"...')
            run_dir = potential_run_dir
            with open(potential_json_file, 'rt') as f:
                training_options = json.load(f,
                                             object_pairs_hook=dnnlib.EasyDict)
    if training_options is None:
        print(
            'Could not look up training options; will rely on --metricdata and --mirror'
        )

    # Choose dataset options.
    dataset_options = dnnlib.EasyDict()
    if training_options is not None:
        dataset_options.update(training_options.metric_dataset_args)
    dataset_options.resolution = Gs.output_shapes[0][-1]
    dataset_options.max_label_size = Gs.input_shapes[1][-1]
    if metricdata is not None:
        if not os.path.isdir(metricdata):
            raise UserError(
                '--metricdata must point to a directory containing *.tfrecords'
            )
        dataset_options.path = metricdata
    if mirror is not None:
        dataset_options.mirror_augment = mirror
    if 'path' not in dataset_options:
        raise UserError('--metricdata must be specified explicitly')

    # Print dataset options.
    print()
    print('Dataset options:')
    print(json.dumps(dataset_options, indent=2))

    # Evaluate metrics.
    for metric in metrics:
        print()
        print(f'Evaluating {metric.name}...')
        metric.configure(dataset_args=dataset_options, run_dir=run_dir)
        metric.run(network_pkl=network_pkl, num_gpus=gpus)
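
The training-options lookup above assumes a run directory keeps training_options.json next to the network pickle; a minimal sketch of that convention (the path is illustrative):

import json
import os

network_pkl = 'runs/00000/network-snapshot-000100.pkl'  # illustrative
json_file = os.path.join(os.path.dirname(network_pkl), 'training_options.json')
training_options = None
if os.path.isfile(json_file):
    with open(json_file, 'rt') as f:
        training_options = json.load(f)  # otherwise fall back to CLI flags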
Example #27
def convert_tf_generator(tf_G):
    if tf_G.version < 4:
        raise ValueError('TensorFlow pickle version too low')

    # Collect kwargs.
    tf_kwargs = tf_G.static_kwargs
    known_kwargs = set()

    def kwarg(tf_name, default=None, none=None):
        known_kwargs.add(tf_name)
        val = tf_kwargs.get(tf_name, default)
        return val if val is not None else none

    # Convert kwargs.
    kwargs = dnnlib.EasyDict(
        z_dim=kwarg('latent_size', 512),
        c_dim=kwarg('label_size', 0),
        w_dim=kwarg('dlatent_size', 512),
        img_resolution=kwarg('resolution', 1024),
        img_channels=kwarg('num_channels', 3),
        mapping_kwargs=dnnlib.EasyDict(
            num_layers=kwarg('mapping_layers', 8),
            embed_features=kwarg('label_fmaps', None),
            layer_features=kwarg('mapping_fmaps', None),
            activation=kwarg('mapping_nonlinearity', 'lrelu'),
            lr_multiplier=kwarg('mapping_lrmul', 0.01),
            w_avg_beta=kwarg('w_avg_beta', 0.995, none=1),
        ),
        synthesis_kwargs=dnnlib.EasyDict(
            channel_base=kwarg('fmap_base', 16384) * 2,
            channel_max=kwarg('fmap_max', 512),
            num_fp16_res=kwarg('num_fp16_res', 0),
            conv_clamp=kwarg('conv_clamp', None),
            architecture=kwarg('architecture', 'skip'),
            resample_filter=kwarg('resample_kernel', [1, 3, 3, 1]),
            use_noise=kwarg('use_noise', True),
            activation=kwarg('nonlinearity', 'lrelu'),
        ),
    )

    # Check for unknown kwargs.
    kwarg('truncation_psi')
    kwarg('truncation_cutoff')
    kwarg('style_mixing_prob')
    kwarg('structure')
    unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
    if len(unknown_kwargs) > 0:
        raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])

    # Collect params.
    tf_params = _collect_tf_params(tf_G)
    for name, value in list(tf_params.items()):
        match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
        if match:
            r = kwargs.img_resolution // (2**int(match.group(1)))
            tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
            kwargs.synthesis_kwargs.architecture = 'orig'
    #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')

    # Convert params.
    from training import networks
    G = networks.Generator(**kwargs).eval().requires_grad_(False)
    # pylint: disable=unnecessary-lambda
    _populate_module_params(
        G,
        r'mapping\.w_avg',
        lambda: tf_params[f'dlatent_avg'],
        r'mapping\.embed\.weight',
        lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(),
        r'mapping\.embed\.bias',
        lambda: tf_params[f'mapping/LabelEmbed/bias'],
        r'mapping\.fc(\d+)\.weight',
        lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(),
        r'mapping\.fc(\d+)\.bias',
        lambda i: tf_params[f'mapping/Dense{i}/bias'],
        r'synthesis\.b4\.const',
        lambda: tf_params[f'synthesis/4x4/Const/const'][0],
        r'synthesis\.b4\.conv1\.weight',
        lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b4\.conv1\.bias',
        lambda: tf_params[f'synthesis/4x4/Conv/bias'],
        r'synthesis\.b4\.conv1\.noise_const',
        lambda: tf_params[f'synthesis/noise0'][0, 0],
        r'synthesis\.b4\.conv1\.noise_strength',
        lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'],
        r'synthesis\.b4\.conv1\.affine\.weight',
        lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(),
        r'synthesis\.b4\.conv1\.affine\.bias',
        lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.conv0\.weight',
        lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.conv0\.bias',
        lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
        r'synthesis\.b(\d+)\.conv0\.noise_const',
        lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0],
        r'synthesis\.b(\d+)\.conv0\.noise_strength',
        lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
        r'synthesis\.b(\d+)\.conv0\.affine\.weight',
        lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.conv0\.affine\.bias',
        lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.conv1\.weight',
        lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.conv1\.bias',
        lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
        r'synthesis\.b(\d+)\.conv1\.noise_const',
        lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0],
        r'synthesis\.b(\d+)\.conv1\.noise_strength',
        lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
        r'synthesis\.b(\d+)\.conv1\.affine\.weight',
        lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.conv1\.affine\.bias',
        lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.torgb\.weight',
        lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.torgb\.bias',
        lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
        r'synthesis\.b(\d+)\.torgb\.affine\.weight',
        lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.torgb\.affine\.bias',
        lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.skip\.weight',
        lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
        r'.*\.resample_filter',
        None,
    )
    return G
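A hedged usage sketch for convert_tf_generator(): in stylegan2-ada-pytorch this converter is invoked from legacy.load_network_pkl() once the TensorFlow pickle has been unpickled. Calling it directly might look roughly like this, where tf_Gs stands for an already-unpickled TF network object (hypothetical setup, not part of the snippet above):

import torch

G = convert_tf_generator(tf_Gs)   # torch Generator in eval mode, grads disabled
z = torch.randn([1, G.z_dim])     # latent vector
c = torch.zeros([1, G.c_dim])     # label vector (zero-sized if unconditional)
img = G(z, c)                     # NCHW image tensor, roughly in [-1, 1]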
Example #28
def get_details(self, idx):
    d = dnnlib.EasyDict()
    d.raw_idx = int(self._raw_idx[idx])
    d.xflip = (int(self._xflip[idx]) != 0)
    d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
    return d
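This get_details() matches the per-item metadata accessor of the Dataset base class in stylegan2-ada-pytorch (training/dataset.py). Assuming such a dataset instance, an illustrative use:

d = dataset.get_details(0)   # dnnlib.EasyDict for item 0
print(d.raw_idx)             # index into the underlying raw data (before x-flip doubling)
print(d.xflip)               # True if this item is an x-flipped copy
print(d.raw_label)           # defensive copy of the raw label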
Example #29
def setup_training_options(data_path, out_dir, resume=None, mirror=True):

    ###########################################################################################################################
    #                                                 EDIT THESE!                                                             #
    ###########################################################################################################################
    outdir = out_dir
    gpus = None  # Number of GPUs: <int>, default = 1 gpu
    snap = 1  # Snapshot interval: <int>, default = 50 ticks
    seed = 1000  # Random seed: <int>, passed to tflib.init_tf() below
    data = data_path  # Training dataset (required): <path>
    res = None  # Override dataset resolution: <int>, default = highest available
    mirror = mirror  # Augment dataset with x-flips: <bool>, default = False
    metrics = []  # List of metric names: [], ['fid50k_full'] (default), ...
    metricdata = None  # Metric dataset (optional): <path>
    cfg = 'stylegan2'  # Base config: 'auto' (default), 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar', 'cifarbaseline'
    gamma = None  # Override R1 gamma: <float>, default = depends on cfg
    kimg = 10000  # Override training duration: <int>, default = depends on cfg
    aug = 'ada'  # Augmentation mode: 'ada' (default), 'noaug', 'fixed', 'adarv'
    p = None  # Specify p for 'fixed' (required): <float>
    target = None  # Override ADA target for 'ada' and 'adarv': <float>, default = depends on aug
    augpipe = 'bgc'  # Augmentation pipeline: 'blit', 'geom', 'color', 'filter', 'noise', 'cutout', 'bg', 'bgc' (default), ..., 'bgcfnc'
    cmethod = None  # Comparison method: 'nocmethod' (default), 'bcr', 'zcr', 'pagan', 'wgangp', 'auxrot', 'spectralnorm', 'shallowmap', 'adropout'
    dcap = None  # Multiplier for discriminator capacity: <float>, default = 1
    resume = resume  # Load previous network: 'noresume' (default), 'ffhq256', 'ffhq512', 'ffhq1024', 'celebahq256', 'lsundog256', <file>, <url>
    freezed = None  # Freeze-D: <int>, default = 0 discriminator layers

    ###########################################################################################################################
    #                                                 End of Edit Section                                                     #
    ###########################################################################################################################

    tflib.init_tf({'rnd.np_random_seed': seed})

    # Initialize dicts.
    args = dnnlib.EasyDict()
    args.G_args = dnnlib.EasyDict(func_name='training.networks.G_main')
    args.D_args = dnnlib.EasyDict(func_name='training.networks.D_main')
    args.G_opt_args = dnnlib.EasyDict(beta1=0.0, beta2=0.99)
    args.D_opt_args = dnnlib.EasyDict(beta1=0.0, beta2=0.99)
    args.loss_args = dnnlib.EasyDict(func_name='training.loss.stylegan2')
    args.augment_args = dnnlib.EasyDict(
        class_name='training.augment.AdaptiveAugment')

    # ---------------------------
    # General options: gpus, snap
    # ---------------------------

    if gpus is None:
        gpus = 1
    assert isinstance(gpus, int)
    if not (gpus >= 1 and gpus & (gpus - 1) == 0):
        raise UserError('--gpus must be a power of two')
    args.num_gpus = gpus

    if snap is None:
        snap = 50
    assert isinstance(snap, int)
    if snap < 1:
        raise UserError('--snap must be at least 1')
    args.image_snapshot_ticks = snap
    args.network_snapshot_ticks = snap

    # -----------------------------------
    # Training dataset: data, res, mirror
    # -----------------------------------

    assert data is not None
    assert isinstance(data, str)
    data_name = os.path.basename(os.path.abspath(data))
    if not os.path.isdir(data) or len(data_name) == 0:
        raise UserError(
            '--data must point to a directory containing *.tfrecords')
    desc = data_name

    with tf.Graph().as_default(), tflib.create_session().as_default():  # pylint: disable=not-context-manager
        args.train_dataset_args = dnnlib.EasyDict(path=data, max_label_size='full')
        dataset_obj = dataset.load_dataset(**args.train_dataset_args)  # try to load the data and see what comes out
        args.train_dataset_args.resolution = dataset_obj.shape[-1]  # be explicit about resolution
        args.train_dataset_args.max_label_size = dataset_obj.label_size  # be explicit about label size
        validation_set_available = dataset_obj.has_validation_set
        dataset_obj.close()
        dataset_obj = None

    if res is None:
        res = args.train_dataset_args.resolution
    else:
        assert isinstance(res, int)
        if not (res >= 4 and res & (res - 1) == 0):
            raise UserError('--res must be a power of two and at least 4')
        if res > args.train_dataset_args.resolution:
            raise UserError(
                f'--res cannot exceed maximum available resolution in the dataset ({args.train_dataset_args.resolution})'
            )
        desc += f'-res{res:d}'
    args.train_dataset_args.resolution = res

    if mirror is None:
        mirror = False
    else:
        assert isinstance(mirror, bool)
        if mirror:
            desc += '-mirror'
    args.train_dataset_args.mirror_augment = mirror

    # ----------------------------
    # Metrics: metrics, metricdata
    # ----------------------------

    if metrics is None:
        metrics = ['fid50k_full']
    assert isinstance(metrics, list)
    assert all(isinstance(metric, str) for metric in metrics)

    args.metric_arg_list = []
    for metric in metrics:
        if metric not in metric_defaults.metric_defaults:
            raise UserError('\n'.join(
                ['--metrics can only contain the following values:', 'none'] +
                list(metric_defaults.metric_defaults.keys())))
        args.metric_arg_list.append(metric_defaults.metric_defaults[metric])

    args.metric_dataset_args = dnnlib.EasyDict(args.train_dataset_args)
    if metricdata is not None:
        assert isinstance(metricdata, str)
        if not os.path.isdir(metricdata):
            raise UserError(
                '--metricdata must point to a directory containing *.tfrecords'
            )
        args.metric_dataset_args.path = metricdata

    # -----------------------------
    # Base config: cfg, gamma, kimg
    # -----------------------------

    if cfg is None:
        cfg = 'auto'
    assert isinstance(cfg, str)
    desc += f'-{cfg}'

    cfg_specs = {
        'auto':          dict(ref_gpus=-1, kimg=25000,  mb=-1, mbstd=-1, fmaps=-1,  lrate=-1,     gamma=-1,   ema=-1,  ramp=0.05, map=2),  # populated dynamically based on 'gpus' and 'res'
        'stylegan2':     dict(ref_gpus=8,  kimg=25000,  mb=32, mbstd=4,  fmaps=1,   lrate=0.002,  gamma=10,   ema=10,  ramp=None, map=8),  # uses mixed-precision, unlike original StyleGAN2
        'paper256':      dict(ref_gpus=8,  kimg=25000,  mb=64, mbstd=8,  fmaps=0.5, lrate=0.0025, gamma=1,    ema=20,  ramp=None, map=8),
        'paper512':      dict(ref_gpus=8,  kimg=25000,  mb=64, mbstd=8,  fmaps=1,   lrate=0.0025, gamma=0.5,  ema=20,  ramp=None, map=8),
        'paper1024':     dict(ref_gpus=8,  kimg=25000,  mb=32, mbstd=4,  fmaps=1,   lrate=0.002,  gamma=2,    ema=10,  ramp=None, map=8),
        'cifar':         dict(ref_gpus=2,  kimg=100000, mb=64, mbstd=32, fmaps=0.5, lrate=0.0025, gamma=0.01, ema=500, ramp=0.05, map=2),
        'cifarbaseline': dict(ref_gpus=2,  kimg=100000, mb=64, mbstd=32, fmaps=0.5, lrate=0.0025, gamma=0.01, ema=500, ramp=0.05, map=8),
    }

    assert cfg in cfg_specs
    spec = dnnlib.EasyDict(cfg_specs[cfg])
    if cfg == 'auto':
        desc += f'{gpus:d}'
        spec.ref_gpus = gpus
        spec.mb = max(min(gpus * min(4096 // res, 32), 64), gpus)  # keep gpu memory consumption at bay
        spec.mbstd = min(spec.mb // gpus, 4)  # other hyperparams behave more predictably if mbstd group size remains fixed
        spec.fmaps = 1 if res >= 512 else 0.5
        spec.lrate = 0.002 if res >= 1024 else 0.0025
        spec.gamma = 0.0002 * (res**2) / spec.mb  # heuristic formula
        spec.ema = spec.mb * 10 / 32

    args.total_kimg = spec.kimg
    args.minibatch_size = spec.mb
    args.minibatch_gpu = spec.mb // spec.ref_gpus
    args.D_args.mbstd_group_size = spec.mbstd
    args.G_args.fmap_base = args.D_args.fmap_base = int(spec.fmaps * 16384)
    args.G_args.fmap_max = args.D_args.fmap_max = 512
    args.G_opt_args.learning_rate = args.D_opt_args.learning_rate = spec.lrate
    args.loss_args.r1_gamma = spec.gamma
    args.G_smoothing_kimg = spec.ema
    args.G_smoothing_rampup = spec.ramp
    args.G_args.mapping_layers = spec.map
    args.G_args.num_fp16_res = args.D_args.num_fp16_res = 4  # enable mixed-precision training
    args.G_args.conv_clamp = args.D_args.conv_clamp = 256  # clamp activations to avoid float16 overflow

    if cfg == 'cifar':
        args.loss_args.pl_weight = 0  # disable path length regularization
        args.G_args.style_mixing_prob = None  # disable style mixing
        args.D_args.architecture = 'orig'  # disable residual skip connections

    if gamma is not None:
        assert isinstance(gamma, float)
        if not gamma >= 0:
            raise UserError('--gamma must be non-negative')
        desc += f'-gamma{gamma:g}'
        args.loss_args.r1_gamma = gamma

    if kimg is not None:
        assert isinstance(kimg, int)
        if not kimg >= 1:
            raise UserError('--kimg must be at least 1')
        desc += f'-kimg{kimg:d}'
        args.total_kimg = kimg

    # ---------------------------------------------------
    # Discriminator augmentation: aug, p, target, augpipe
    # ---------------------------------------------------

    if aug is None:
        aug = 'ada'
    else:
        assert isinstance(aug, str)
        desc += f'-{aug}'

    if aug == 'ada':
        args.augment_args.tune_heuristic = 'rt'
        args.augment_args.tune_target = 0.6

    elif aug == 'noaug':
        pass

    elif aug == 'fixed':
        if p is None:
            raise UserError(f'--aug={aug} requires specifying --p')

    elif aug == 'adarv':
        if not validation_set_available:
            raise UserError(
                f'--aug={aug} requires separate validation set; please see "python dataset_tool.py pack -h"'
            )
        args.augment_args.tune_heuristic = 'rv'
        args.augment_args.tune_target = 0.5

    else:
        raise UserError(f'--aug={aug} not supported')

    if p is not None:
        assert isinstance(p, float)
        if aug != 'fixed':
            raise UserError('--p can only be specified with --aug=fixed')
        if not 0 <= p <= 1:
            raise UserError('--p must be between 0 and 1')
        desc += f'-p{p:g}'
        args.augment_args.initial_strength = p

    if target is not None:
        assert isinstance(target, float)
        if aug not in ['ada', 'adarv']:
            raise UserError(
                '--target can only be specified with --aug=ada or --aug=adarv')
        if not 0 <= target <= 1:
            raise UserError('--target must be between 0 and 1')
        desc += f'-target{target:g}'
        args.augment_args.tune_target = target

    assert augpipe is None or isinstance(augpipe, str)
    if augpipe is None:
        augpipe = 'bgc'
    else:
        if aug == 'noaug':
            raise UserError('--augpipe cannot be specified with --aug=noaug')
        desc += f'-{augpipe}'

    augpipe_specs = {
        'blit':   dict(xflip=1, rotate90=1, xint=1),
        'geom':   dict(scale=1, rotate=1, aniso=1, xfrac=1),
        'color':  dict(brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
        'filter': dict(imgfilter=1),
        'noise':  dict(noise=1),
        'cutout': dict(cutout=1),
        'bg':     dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1),
        'bgc':    dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
        'bgcf':   dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1),
        'bgcfn':  dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1),
        'bgcfnc': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1, cutout=1),
    }
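    # The preset names compose one letter per enabled augmentation group:
    # b=blit, g=geom, c=color, f=imgfilter, n=noise, trailing c=cutout,
    # so 'bgcfnc' enables all six groups.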

    assert augpipe in augpipe_specs
    if aug != 'noaug':
        args.augment_args.apply_func = 'training.augment.augment_pipeline'
        args.augment_args.apply_args = augpipe_specs[augpipe]

    # ---------------------------------
    # Comparison methods: cmethod, dcap
    # ---------------------------------

    assert cmethod is None or isinstance(cmethod, str)
    if cmethod is None:
        cmethod = 'nocmethod'
    else:
        desc += f'-{cmethod}'

    if cmethod == 'nocmethod':
        pass

    elif cmethod == 'bcr':
        args.loss_args.func_name = 'training.loss.cmethods'
        args.loss_args.bcr_real_weight = 10
        args.loss_args.bcr_fake_weight = 10
        args.loss_args.bcr_augment = dnnlib.EasyDict(
            func_name='training.augment.augment_pipeline',
            xint=1,
            xint_max=1 / 32)

    elif cmethod == 'zcr':
        args.loss_args.func_name = 'training.loss.cmethods'
        args.loss_args.zcr_gen_weight = 0.02
        args.loss_args.zcr_dis_weight = 0.2
        args.G_args.num_fp16_res = args.D_args.num_fp16_res = 0  # disable mixed-precision training
        args.G_args.conv_clamp = args.D_args.conv_clamp = None

    elif cmethod == 'pagan':
        if aug != 'noaug':
            raise UserError(
                f'--cmethod={cmethod} is not compatible with discriminator augmentation; please specify --aug=noaug'
            )
        args.D_args.use_pagan = True
        args.augment_args.tune_heuristic = 'rt'  # enable ada heuristic
        args.augment_args.pop('apply_func', None)  # disable discriminator augmentation
        args.augment_args.pop('apply_args', None)
        args.augment_args.tune_target = 0.95

    elif cmethod == 'wgangp':
        if aug != 'noaug':
            raise UserError(
                f'--cmethod={cmethod} is not compatible with discriminator augmentation; please specify --aug=noaug'
            )
        if gamma is not None:
            raise UserError(
                f'--cmethod={cmethod} is not compatible with --gamma')
        args.loss_args = dnnlib.EasyDict(func_name='training.loss.wgangp')
        args.G_opt_args.learning_rate = args.D_opt_args.learning_rate = 0.001
        args.G_args.num_fp16_res = args.D_args.num_fp16_res = 0  # disable mixed-precision training
        args.G_args.conv_clamp = args.D_args.conv_clamp = None
        args.lazy_regularization = False

    elif cmethod == 'auxrot':
        if args.train_dataset_args.max_label_size > 0:
            raise UserError(
                f'--cmethod={cmethod} is not compatible with label conditioning; please specify a dataset without labels'
            )
        args.loss_args.func_name = 'training.loss.cmethods'
        args.loss_args.auxrot_alpha = 10
        args.loss_args.auxrot_beta = 5
        args.D_args.score_max = 5  # prepare D to output 5 scalars per image instead of just 1

    elif cmethod == 'spectralnorm':
        args.D_args.use_spectral_norm = True

    elif cmethod == 'shallowmap':
        if args.G_args.mapping_layers == 2:
            raise UserError(f'--cmethod={cmethod} is a no-op for --cfg={cfg}')
        args.G_args.mapping_layers = 2

    elif cmethod == 'adropout':
        if aug != 'noaug':
            raise UserError(
                f'--cmethod={cmethod} is not compatible with discriminator augmentation; please specify --aug=noaug'
            )
        args.D_args.adaptive_dropout = 1
        args.augment_args.tune_heuristic = 'rt'  # enable ada heuristic
        args.augment_args.pop('apply_func', None)  # disable discriminator augmentation
        args.augment_args.pop('apply_args', None)
        args.augment_args.tune_target = 0.6

    else:
        raise UserError(f'--cmethod={cmethod} not supported')

    if dcap is not None:
        assert isinstance(dcap, float)
        if not dcap > 0:
            raise UserError('--dcap must be positive')
        desc += f'-dcap{dcap:g}'
        args.D_args.fmap_base = max(int(args.D_args.fmap_base * dcap), 1)
        args.D_args.fmap_max = max(int(args.D_args.fmap_max * dcap), 1)

    # ----------------------------------
    # Transfer learning: resume, freezed
    # ----------------------------------

    resume_specs = {
        'ffhq256':     'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/transfer-learning-source-nets/ffhq-res256-mirror-paper256-noaug.pkl',
        'ffhq512':     'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/transfer-learning-source-nets/ffhq-res512-mirror-stylegan2-noaug.pkl',
        'ffhq1024':    'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/transfer-learning-source-nets/ffhq-res1024-mirror-stylegan2-noaug.pkl',
        'celebahq256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/transfer-learning-source-nets/celebahq-res256-mirror-paper256-kimg100000-ada-target0.5.pkl',
        'lsundog256':  'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/transfer-learning-source-nets/lsundog-res256-paper256-kimg100000-noaug.pkl',
    }

    assert resume is None or isinstance(resume, str)
    if resume is None:
        resume = 'noresume'
    elif resume == 'noresume':
        desc += '-noresume'
    elif resume in resume_specs:
        desc += f'-resume{resume}'
        args.resume_pkl = resume_specs[resume]  # predefined url
    else:
        desc += '-resumecustom'
        args.resume_pkl = resume  # custom path or url

    if resume != 'noresume':
        args.augment_args.tune_kimg = 100  # make ADA react faster at the beginning
        args.G_smoothing_rampup = None  # disable EMA rampup

    if freezed is not None:
        assert isinstance(freezed, int)
        if not freezed >= 0:
            raise UserError('--freezed must be non-negative')
        desc += f'-freezed{freezed:d}'
        args.D_args.freeze_layers = freezed

    return desc, args, outdir
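A hedged usage sketch for setup_training_options(). The returned triple is meant to be forwarded to a training launcher; run_training below is hypothetical, standing in for whatever dispatch (dnnlib submit plus training loop) the surrounding project uses:

desc, args, outdir = setup_training_options(
    data_path='datasets/mydata',   # directory containing *.tfrecords
    out_dir='results',
    resume='ffhq512',              # predefined name, or a custom pkl path/URL
    mirror=True)
print(desc)                        # run description, e.g. 'mydata-mirror-stylegan2-...'
# run_training(outdir=outdir, desc=desc, **args)   # hypothetical launcher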
Example #30
def generate(args):
    np.random.seed()
    device = torch.device('cpu')

    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.size = args['size']
    Gs_kwargs.scale_type = args['scale_type']

    nHW = [int(s) for s in args['nXY'].split('-')][::-1]
    n_mult = nHW[0] * nHW[1]
    lmask = np.tile(np.asarray([[[[1]]]]), (1, n_mult, 1, 1))
    Gs_kwargs.countHW = nHW
    Gs_kwargs.splitfine = args['splitfine']
    lmask = torch.from_numpy(lmask).to(device)

    # load base or custom network
    pkl_name = osp.splitext(args['model'])[0]
    with dnnlib.util.open_url(pkl_name + '.pkl') as f:
        Gs = legacy.load_network_pkl(f, custom=False, **Gs_kwargs)['G_ema'].to(device)  # type: ignore

    lats = []  # list of [frm,1,512]
    for i in range(n_mult):
        lat_tmp = latent_anima((1, Gs.z_dim),
                               args['frames'],
                               args['fstep'],
                               cubic=args['cubic'],
                               gauss=args['gauss'])  # [frm,1,512]
        lats.append(lat_tmp)  # list of [frm,1,512]
    latents = np.concatenate(lats, 1)  # [frm,X,512]
    latents = torch.from_numpy(latents).to(device)

    dconst = np.zeros([1, 1, 1, 1, 1])
    dconst = torch.from_numpy(dconst).to(device)

    # labels / conditions
    label_size = Gs.c_dim
    if label_size > 0:
        labels = torch.zeros((1, n_mult, label_size),
                             device=device)  # [frm,X,lbl]
        label_ids = []
        for i in range(n_mult):
            label_ids.append(random.randint(0, label_size - 1))
        for i, l in enumerate(label_ids):
            labels[:, i, l] = 1
    else:
        labels = [None]

    # take frame 0 from the latent timeline; the `0 % len(...)` indexing is a
    # leftover from a per-frame loop, and latmask/dc are unused with custom=False
    latent = latents[0]  # [X,512]
    label = labels[0 % len(labels)]
    latmask = lmask[0 % len(lmask)] if lmask is not None else [None]  # [X,h,w]
    dc = dconst[0 % len(dconst)]  # [X,512,4,4]

    # generate multi-latent result
    Gs = Gs.float()
    output = Gs(latent,
                label,
                force_fp32=True,
                truncation_psi=args['trunc'],
                noise_mode='const')
    output = (output.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()

    # save image and return it (the original snippet built `filename` without
    # writing it; saving here assumes args['out_dir'] already exists)
    ext = 'png' if output.shape[3] == 4 else 'jpg'
    filename = osp.join(args['out_dir'], "%06d.%s" % (0, ext))
    img = Image.fromarray(output[0], 'RGB')
    img.save(filename)
    return img
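A hedged sketch of the args dict generate() consumes, inferred from the keys it reads above; every value is an illustrative assumption:

args = {
    'model': 'ffhq-1024',          # pkl name, extension stripped and re-added
    'size': [1024, 1024],
    'scale_type': 'pad',
    'nXY': '1-1',                  # latent grid split, 'columns-rows'
    'splitfine': 0.0,
    'frames': 30, 'fstep': 5,      # latent timeline length / keyframe step
    'cubic': False, 'gauss': False,
    'trunc': 0.7,                  # truncation psi
    'out_dir': '_out',
}
img = generate(args)               # writes frame 0 to out_dir, returns PIL.Image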