Example #1
import time

import torch

from src import dnnlib
from metrics import metric_utils  # import path assumed; defines MetricOptions

# is_valid_metric() and _metric_dict are defined elsewhere in this module.
def calc_metric(
    metric, **kwargs
):  # See metric_utils.MetricOptions for the full list of arguments.
    assert is_valid_metric(metric)
    opts = metric_utils.MetricOptions(**kwargs)

    # Calculate.
    start_time = time.time()
    results = _metric_dict[metric](opts)
    total_time = time.time() - start_time

    # Broadcast results.
    for key, value in list(results.items()):
        if opts.num_gpus > 1:
            value = torch.as_tensor(value,
                                    dtype=torch.float64,
                                    device=opts.device)
            torch.distributed.broadcast(tensor=value, src=0)
            value = float(value.cpu())
        results[key] = value

    # Decorate with metadata.
    return dnnlib.EasyDict(
        results=dnnlib.EasyDict(results),
        metric=metric,
        total_time=total_time,
        total_time_str=dnnlib.util.format_time(total_time),
        num_gpus=opts.num_gpus,
    )
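
A minimal usage sketch; the keyword arguments below (G, dataset_kwargs, num_gpus, device) are assumptions about what metric_utils.MetricOptions accepts, since that class is not shown here:

# Hypothetical call -- the exact option names depend on metric_utils.MetricOptions.
result = calc_metric(metric='fid50k',
                     G=G,                                   # trained generator (assumed)
                     dataset_kwargs=dict(path='data.zip'),  # real images (assumed)
                     num_gpus=1, device='cuda')
print(result.results, result.total_time_str)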
Example #2
import torch.nn as nn
from hydra import compose
from omegaconf import DictConfig, OmegaConf

from src import dnnlib


def instantiate_G(cfg: DictConfig) -> nn.Module:
    hydra_cfg = compose(config_name=f'../../configs/{cfg.model}.yml')

    if cfg.model in ['inr-gan', 'cips']:
        hydra_cfg.generator.coords.use_full_cache = True

    G_kwargs = dnnlib.EasyDict(class_name='training.networks.Generator',
                               z_dim=512,
                               w_dim=512,
                               mapping_kwargs=dnnlib.EasyDict(),
                               synthesis_kwargs=dnnlib.EasyDict())
    G_kwargs.synthesis_kwargs.channel_base = int(
        hydra_cfg.generator.get('fmaps', 0.5) * 32768)
    G_kwargs.synthesis_kwargs.channel_max = 512
    G_kwargs.mapping_kwargs.num_layers = hydra_cfg.generator.get(
        'mapping_net_n_layers', 2)
    if cfg.get('num_fp16_res', 0) > 0:
        G_kwargs.synthesis_kwargs.num_fp16_res = cfg.num_fp16_res
        G_kwargs.synthesis_kwargs.conv_clamp = 256
    G_kwargs.cfg = OmegaConf.to_container(hydra_cfg.generator)
    G_kwargs.c_dim = 0
    G_kwargs.img_resolution = cfg.get('resolution', 256)
    G_kwargs.img_channels = 3

    G = dnnlib.util.construct_class_by_name(
        **G_kwargs).eval().requires_grad_(False).to(cfg.device)

    return G
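
A usage sketch, assuming cfg is an OmegaConf DictConfig carrying the fields read above (model, device, and optionally resolution and num_fp16_res); the concrete values are placeholders, and Hydra must already be initialized so compose() can resolve the config path:

import torch
from omegaconf import OmegaConf

cfg = OmegaConf.create({'model': 'stylegan2', 'device': 'cuda',
                        'resolution': 256, 'num_fp16_res': 4})
G = instantiate_G(cfg)
z = torch.randn(4, 512, device=cfg.device)  # z_dim is fixed to 512 above
img = G(z, c=None)  # assumed forward signature; labels unused since c_dim == 0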
Example #3
# Method of metric_base.MetricBase (cf. Example #5); shown here without its class indentation.
def _get_cache_file_for_reals(self, extension='pkl', **kwargs):
    all_args = dnnlib.EasyDict(metric_name=self.name, mirror_augment=self._mirror_augment)
    all_args.update(self._dataset_args)
    all_args.update(kwargs)
    md5 = hashlib.md5(repr(sorted(all_args.items())).encode('utf-8'))
    dataset_name = self._dataset_args['tfrecord_dir'].replace('\\', '/').split('/')[-1]
    return os.path.join(config.cache_dir, '%s-%s-%s.%s' % (md5.hexdigest(), self.name, dataset_name, extension))
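
The cache key is an MD5 digest over the sorted argument items, so any change to the dataset arguments or extra kwargs yields a different cache file. The scheme in isolation:

import hashlib

args = dict(metric_name='fid50k', mirror_augment=True, tfrecord_dir='datasets/ffhq')
md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8'))
print(md5.hexdigest())  # stable for identical arguments, new key for any change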
Example #4
# Method of metric_base.MetricBase (cf. Example #5); shown here without its class indentation.
def _report_result(self, value, suffix='', fmt='%-10.4f'):
    self._results += [dnnlib.EasyDict(value=value, suffix=suffix, fmt=fmt)]
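
Each call appends one row to self._results; fmt controls how the value is later rendered. A minimal illustration of the formatting convention:

row = dnnlib.EasyDict(value=4.4159, suffix='', fmt='%-10.4f')
print(row.fmt % row.value + row.suffix)  # '4.4159    ' (left-justified, 4 decimals)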
Example #5
import os
import time
import hashlib
import numpy as np
import tensorflow as tf
from src import dnnlib
from src.dnnlib import tflib

import config
from training import misc
from training import dataset

#----------------------------------------------------------------------------
# Standard metrics.

fid50k = dnnlib.EasyDict(func_name='metrics.frechet_inception_distance.FID', name='fid50k', num_images=50000, minibatch_per_gpu=8)
ppl_zfull = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_zfull', num_samples=100000, epsilon=1e-4, space='z', sampling='full', minibatch_per_gpu=16)
ppl_wfull = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_wfull', num_samples=100000, epsilon=1e-4, space='w', sampling='full', minibatch_per_gpu=16)
ppl_zend = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_zend', num_samples=100000, epsilon=1e-4, space='z', sampling='end', minibatch_per_gpu=16)
ppl_wend = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_wend', num_samples=100000, epsilon=1e-4, space='w', sampling='end', minibatch_per_gpu=16)
ls = dnnlib.EasyDict(func_name='metrics.linear_separability.LS', name='ls', num_samples=200000, num_keep=100000, attrib_indices=range(40), minibatch_per_gpu=4)
dummy = dnnlib.EasyDict(func_name='metrics.metric_base.DummyMetric', name='dummy') # for debugging

#----------------------------------------------------------------------------
# Base class for metrics.

class MetricBase:
    def __init__(self, name):
        self.name = name
        self._network_pkl = None
        self._dataset_args = None
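
Each config above bundles func_name with the constructor arguments of a metric class, so instantiation can go through dnnlib's reflection helpers. A sketch, assuming the full MetricBase (truncated above) provides a run() method as in the NVIDIA reference implementation:

metric = dnnlib.util.call_func_by_name(**fid50k)  # builds FID(name='fid50k', num_images=50000, ...)
metric.run('network-snapshot.pkl', num_gpus=1)    # assumed signature from the reference MetricBase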
Example #6
import numpy as np
import tensorflow as tf  # TF1.x-style API (tf.get_variable, tf.cond, ...)

from src import dnnlib
from src.dnnlib import tflib


# G_mapping and G_synthesis are defined in the same module.
def G_style(
    latents_in,  # First input: Latent vectors (Z) [minibatch, latent_size].
    labels_in,  # Second input: Conditioning labels [minibatch, label_size].
    truncation_psi=0.7,  # Style strength multiplier for the truncation trick. None = disable.
    truncation_cutoff=8,  # Number of layers for which to apply the truncation trick. None = disable.
    truncation_psi_val=None,  # Value for truncation_psi to use during validation.
    truncation_cutoff_val=None,  # Value for truncation_cutoff to use during validation.
    dlatent_avg_beta=0.995,  # Decay for tracking the moving average of W during training. None = disable.
    style_mixing_prob=0.9,  # Probability of mixing styles during training. None = disable.
    is_training=False,  # Network is under training? Enables and disables specific features.
    is_validation=False,  # Network is under validation? Chooses which value to use for truncation_psi.
    is_template_graph=False,  # True = template graph constructed by the Network class, False = actual evaluation.
    components=dnnlib.EasyDict(),  # Container for sub-networks. Retained between calls.
    **kwargs):  # Arguments for sub-networks (G_mapping and G_synthesis).

    # Validate arguments.
    assert not is_training or not is_validation
    assert isinstance(components, dnnlib.EasyDict)
    if is_validation:
        truncation_psi = truncation_psi_val
        truncation_cutoff = truncation_cutoff_val
    if is_training or (truncation_psi is not None
                       and not tflib.is_tf_expression(truncation_psi)
                       and truncation_psi == 1):
        truncation_psi = None
    if is_training or (truncation_cutoff is not None
                       and not tflib.is_tf_expression(truncation_cutoff)
                       and truncation_cutoff <= 0):
        truncation_cutoff = None
    if not is_training or (dlatent_avg_beta is not None
                           and not tflib.is_tf_expression(dlatent_avg_beta)
                           and dlatent_avg_beta == 1):
        dlatent_avg_beta = None
    if not is_training or (style_mixing_prob is not None
                           and not tflib.is_tf_expression(style_mixing_prob)
                           and style_mixing_prob <= 0):
        style_mixing_prob = None

    # Setup components.
    if 'synthesis' not in components:
        components.synthesis = tflib.Network('G_synthesis',
                                             func_name=G_synthesis,
                                             **kwargs)
    num_layers = components.synthesis.input_shape[1]
    dlatent_size = components.synthesis.input_shape[2]
    if 'mapping' not in components:
        components.mapping = tflib.Network('G_mapping',
                                           func_name=G_mapping,
                                           dlatent_broadcast=num_layers,
                                           **kwargs)

    # Setup variables.
    lod_in = tf.get_variable('lod', initializer=np.float32(0), trainable=False)
    dlatent_avg = tf.get_variable('dlatent_avg',
                                  shape=[dlatent_size],
                                  initializer=tf.initializers.zeros(),
                                  trainable=False)

    # Evaluate mapping network.
    dlatents = components.mapping.get_output_for(latents_in, labels_in,
                                                 **kwargs)

    # Update moving average of W.
    if dlatent_avg_beta is not None:
        with tf.variable_scope('DlatentAvg'):
            batch_avg = tf.reduce_mean(dlatents[:, 0], axis=0)
            update_op = tf.assign(
                dlatent_avg,
                tflib.lerp(batch_avg, dlatent_avg, dlatent_avg_beta))
            with tf.control_dependencies([update_op]):
                dlatents = tf.identity(dlatents)

    # Perform style mixing regularization.
    if style_mixing_prob is not None:
        with tf.name_scope('StyleMix'):
            latents2 = tf.random_normal(tf.shape(latents_in))
            dlatents2 = components.mapping.get_output_for(
                latents2, labels_in, **kwargs)
            layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
            cur_layers = num_layers - tf.cast(lod_in, tf.int32) * 2
            mixing_cutoff = tf.cond(
                tf.random_uniform([], 0.0, 1.0) < style_mixing_prob,
                lambda: tf.random_uniform([], 1, cur_layers, dtype=tf.int32),
                lambda: cur_layers)
            dlatents = tf.where(
                tf.broadcast_to(layer_idx < mixing_cutoff, tf.shape(dlatents)),
                dlatents, dlatents2)

    # Apply truncation trick.
    if truncation_psi is not None and truncation_cutoff is not None:
        with tf.variable_scope('Truncation'):
            layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
            ones = np.ones(layer_idx.shape, dtype=np.float32)
            coefs = tf.where(layer_idx < truncation_cutoff,
                             truncation_psi * ones, ones)
            dlatents = tflib.lerp(dlatent_avg, dlatents, coefs)

    # Evaluate synthesis network.
    with tf.control_dependencies(
        [tf.assign(components.synthesis.find_var('lod'), lod_in)]):
        images_out = components.synthesis.get_output_for(
            dlatents, force_clean_graph=is_template_graph, **kwargs)
    return tf.identity(images_out, name='images_out')
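
G_style is normally wrapped in a tflib.Network, which turns the two default-less arguments (latents_in, labels_in) into graph inputs and forwards everything else to the sub-networks via **kwargs. A sketch; the module path and shape values are assumptions:

G = tflib.Network('G', func_name='training.networks_stylegan.G_style',  # path assumed
                  num_channels=3, resolution=1024, label_size=0)
latents = np.random.randn(4, *G.input_shape[1:])
images = G.run(latents, None, truncation_psi=0.7)  # labels may be None when label_size == 0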
Example #7
import numpy as np

from src import dnnlib


def training_schedule(
    cur_nimg,
    training_set,
    num_gpus,
    lod_initial_resolution=4,  # Image resolution used at the beginning.
    lod_training_kimg=600,  # Thousands of real images to show before doubling the resolution.
    lod_transition_kimg=600,  # Thousands of real images to show when fading in new layers.
    minibatch_base=16,  # Maximum minibatch size, divided evenly among GPUs.
    minibatch_dict={},  # Resolution-specific overrides.
    max_minibatch_per_gpu={},  # Resolution-specific maximum minibatch size per GPU.
    G_lrate_base=0.001,  # Learning rate for the generator.
    G_lrate_dict={},  # Resolution-specific overrides.
    D_lrate_base=0.001,  # Learning rate for the discriminator.
    D_lrate_dict={},  # Resolution-specific overrides.
    lrate_rampup_kimg=0,  # Duration of learning rate ramp-up.
    tick_kimg_base=160,  # Default interval of progress snapshots.
    tick_kimg_dict={
        4: 160,
        8: 140,
        16: 120,
        32: 100,
        64: 80,
        128: 60,
        256: 40,
        512: 30,
        1024: 20
    }):  # Resolution-specific overrides.

    # Initialize result dict.
    s = dnnlib.EasyDict()
    s.kimg = cur_nimg / 1000.0

    # Training phase.
    phase_dur = lod_training_kimg + lod_transition_kimg
    phase_idx = int(np.floor(s.kimg / phase_dur)) if phase_dur > 0 else 0
    phase_kimg = s.kimg - phase_idx * phase_dur

    # Level-of-detail and resolution.
    s.lod = training_set.resolution_log2
    s.lod -= np.floor(np.log2(lod_initial_resolution))
    s.lod -= phase_idx
    if lod_transition_kimg > 0:
        s.lod -= max(phase_kimg - lod_training_kimg, 0.0) / lod_transition_kimg
    s.lod = max(s.lod, 0.0)
    s.resolution = 2**(training_set.resolution_log2 - int(np.floor(s.lod)))

    # Minibatch size.
    s.minibatch = minibatch_dict.get(s.resolution, minibatch_base)
    s.minibatch -= s.minibatch % num_gpus
    if s.resolution in max_minibatch_per_gpu:
        s.minibatch = min(s.minibatch,
                          max_minibatch_per_gpu[s.resolution] * num_gpus)

    # Learning rate.
    s.G_lrate = G_lrate_dict.get(s.resolution, G_lrate_base)
    s.D_lrate = D_lrate_dict.get(s.resolution, D_lrate_base)
    if lrate_rampup_kimg > 0:
        rampup = min(s.kimg / lrate_rampup_kimg, 1.0)
        s.G_lrate *= rampup
        s.D_lrate *= rampup

    # Other parameters.
    s.tick_kimg = tick_kimg_dict.get(s.resolution, tick_kimg_base)
    return s
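
A worked example of the schedule arithmetic, using a mock training set (resolution_log2 = 10, i.e. a 1024x1024 dataset). With the defaults above, cur_nimg = 900k sits 300 kimg into the first 600 kimg fade-in, so lod = 10 - 2 - 0.5 = 7.5:

class MockTrainingSet:
    resolution_log2 = 10  # stand-in for the real dataset object

s = training_schedule(cur_nimg=900_000, training_set=MockTrainingSet(), num_gpus=1)
print(s.lod, s.resolution, s.minibatch)  # 7.5, 8, 16 with the default arguments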