def G_main(
        latents_in,                     # First input: Latent vectors (Z) [minibatch, latent_size].
        labels_in,                      # Second input: Conditioning labels [minibatch, label_size].
        truncation_psi=0.5,             # Style strength multiplier for the truncation trick. None = disable.
        truncation_cutoff=None,         # Number of layers for which to apply the truncation trick. None = disable.
        truncation_psi_val=None,        # Value for truncation_psi to use during validation.
        truncation_cutoff_val=None,     # Value for truncation_cutoff to use during validation.
        dlatent_avg_beta=0.995,         # Decay for tracking the moving average of W during training. None = disable.
        style_mixing_prob=0.9,          # Probability of mixing styles during training. None = disable.
        is_training=False,              # Network is under training? Enables and disables specific features.
        is_validation=False,            # Network is under validation? Chooses which value to use for truncation_psi.
        return_dlatents=False,          # Return dlatents in addition to the images?
        is_template_graph=False,        # True = template graph constructed by the Network class, False = actual evaluation.
        components=dnnlib.EasyDict(),   # Container for sub-networks. Retained between calls.
        mapping_latent_func='G_mapping_latent',     # Build func name for the latent mapping network.
        mapping_label_func='label_mapping',         # Build func name for the label mapping network.
        synthesis_func='G_synthesis_stylegan2',     # Build func name for the synthesis network.
        **kwargs):                      # Arguments for sub-networks (mapping and synthesis).

    # Validate arguments.
    assert not is_training or not is_validation
    assert isinstance(components, dnnlib.EasyDict)
    if is_validation:
        truncation_psi = truncation_psi_val
        truncation_cutoff = truncation_cutoff_val
    if is_training or (truncation_psi is not None and not tflib.is_tf_expression(truncation_psi) and truncation_psi == 1):
        truncation_psi = None
    if is_training:
        truncation_cutoff = None
    if not is_training or (dlatent_avg_beta is not None and not tflib.is_tf_expression(dlatent_avg_beta) and dlatent_avg_beta == 1):
        dlatent_avg_beta = None
    if not is_training or (style_mixing_prob is not None and not tflib.is_tf_expression(style_mixing_prob) and style_mixing_prob <= 0):
        style_mixing_prob = None

    # Setup components.
    if 'synthesis' not in components:
        components.synthesis = tflib.Network('G_synthesis', func_name=globals()[synthesis_func], **kwargs)
    num_layers = components.synthesis.input_shape[1]
    if 'mapping_latent' not in components:
        components.mapping_latent = tflib.Network('G_mapping_latent', func_name=globals()[mapping_latent_func], dlatent_broadcast=num_layers, **kwargs)
    if 'mapping_label' not in components:
        components.mapping_label = tflib.Network('label_mapping', func_name=globals()[mapping_label_func], dlatent_broadcast=num_layers, **kwargs)

    # Setup variables.
    lod_in = tf.get_variable('lod', initializer=np.float32(0), trainable=False)

    # Evaluate mapping networks and concatenate latent and label dlatents.
    dlatents = components.mapping_latent.get_output_for(latents_in, is_training=is_training, **kwargs)
    dlatents = tf.cast(dlatents, tf.float32)
    dlatents_label = components.mapping_label.get_output_for(labels_in, is_training=is_training, **kwargs)
    dlatents_label = tf.cast(dlatents_label, tf.float32)
    dlatents = tf.concat([dlatents, dlatents_label], axis=-1)

    # Perform style mixing regularization.
    if style_mixing_prob is not None:
        with tf.variable_scope('StyleMix'):
            latents2 = tf.random_normal(tf.shape(latents_in))
            dlatents2 = components.mapping_latent.get_output_for(latents2, is_training=is_training, **kwargs)
            dlatents2 = tf.cast(dlatents2, tf.float32)
            dlatents_label2 = components.mapping_label.get_output_for(labels_in, is_training=is_training, **kwargs)
            dlatents_label2 = tf.cast(dlatents_label2, tf.float32)
            dlatents2 = tf.concat([dlatents2, dlatents_label2], axis=-1)
            layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
            cur_layers = num_layers - tf.cast(lod_in, tf.int32) * 2
            mixing_cutoff = tf.cond(
                tf.random_uniform([], 0.0, 1.0) < style_mixing_prob,
                lambda: tf.random_uniform([], 1, cur_layers, dtype=tf.int32),
                lambda: cur_layers)
            dlatents = tf.where(tf.broadcast_to(layer_idx < mixing_cutoff, tf.shape(dlatents)), dlatents, dlatents2)

    # Evaluate synthesis network.
    deps = []
    if 'lod' in components.synthesis.vars:
        deps.append(tf.assign(components.synthesis.vars['lod'], lod_in))
    with tf.control_dependencies(deps):
        images_out = components.synthesis.get_output_for(dlatents, is_training=is_training, force_clean_graph=is_template_graph, **kwargs)

    # Return requested outputs.
    images_out = tf.identity(images_out, name='images_out')
    if return_dlatents:
        return images_out, dlatents
    return images_out
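# Hedged usage sketch for G_main above: build the template graph through
# tflib.Network and sample a minibatch. The module path 'training.networks'
# and the resolution are illustrative assumptions, not taken from this file.
def _example_run_g_main():
    tflib.init_tf()
    G = tflib.Network('G', func_name='training.networks.G_main',  # hypothetical module path
                      num_channels=3, resolution=256, label_size=0)
    z = np.random.randn(8, *G.input_shapes[0][1:])   # [minibatch, latent_size]
    labels = np.zeros([8] + G.input_shapes[1][1:])   # [minibatch, label_size]
    return G.run(z, labels, truncation_psi=0.7)      # NCHW images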
def training_schedule(
        cur_nimg,
        training_set,
        num_gpus,
        lod_initial_resolution=4,   # Image resolution used at the beginning.
        lod_training_kimg=600,      # Thousands of real images to show before doubling the resolution.
        lod_transition_kimg=600,    # Thousands of real images to show when fading in new layers.
        minibatch_base=16,          # Maximum minibatch size, divided evenly among GPUs.
        minibatch_dict={},          # Resolution-specific overrides.
        max_minibatch_per_gpu={},   # Resolution-specific maximum minibatch size per GPU.
        G_lrate_base=0.001,         # Learning rate for the generator.
        G_lrate_dict={},            # Resolution-specific overrides.
        D_lrate_base=0.001,         # Learning rate for the discriminator.
        D_lrate_dict={},            # Resolution-specific overrides.
        lrate_rampup_kimg=0,        # Duration of learning rate ramp-up.
        tick_kimg_base=160,         # Default interval of progress snapshots.
        tick_kimg_dict={4: 160, 8: 140, 16: 120, 32: 100, 64: 80, 128: 60, 256: 40, 512: 30, 1024: 20}):  # Resolution-specific overrides.

    # Initialize result dict.
    s = dnnlib.EasyDict()
    s.kimg = cur_nimg / 1000.0

    # Training phase.
    phase_dur = lod_training_kimg + lod_transition_kimg
    phase_idx = int(np.floor(s.kimg / phase_dur)) if phase_dur > 0 else 0
    phase_kimg = s.kimg - phase_idx * phase_dur

    # Level-of-detail and resolution.
    s.lod = training_set.resolution_log2
    s.lod -= np.floor(np.log2(lod_initial_resolution))
    s.lod -= phase_idx
    if lod_transition_kimg > 0:
        s.lod -= max(phase_kimg - lod_training_kimg, 0.0) / lod_transition_kimg
    s.lod = max(s.lod, 0.0)
    s.resolution = 2 ** (training_set.resolution_log2 - int(np.floor(s.lod)))

    # Minibatch size.
    s.minibatch = minibatch_dict.get(s.resolution, minibatch_base)
    s.minibatch -= s.minibatch % num_gpus
    if s.resolution in max_minibatch_per_gpu:
        s.minibatch = min(s.minibatch, max_minibatch_per_gpu[s.resolution] * num_gpus)

    # Learning rate.
    s.G_lrate = G_lrate_dict.get(s.resolution, G_lrate_base)
    s.D_lrate = D_lrate_dict.get(s.resolution, D_lrate_base)
    if lrate_rampup_kimg > 0:
        rampup = min(s.kimg / lrate_rampup_kimg, 1.0)
        s.G_lrate *= rampup
        s.D_lrate *= rampup

    # Other parameters.
    s.tick_kimg = tick_kimg_dict.get(s.resolution, tick_kimg_base)
    return s
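# Hedged sketch: how the progressive-growing schedule above evolves over
# training. The mock training_set carries only the one attribute the function
# reads (resolution_log2=10, i.e. a 1024x1024 dataset); values are illustrative.
def _example_schedule_progression():
    mock_set = dnnlib.EasyDict(resolution_log2=10)
    for kimg in [0, 600, 1200, 2400, 4800, 9600]:
        s = training_schedule(cur_nimg=kimg * 1000, training_set=mock_set, num_gpus=8)
        print(f'kimg={kimg:<5d} lod={s.lod:.2f} res={s.resolution} minibatch={s.minibatch}')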
def convert_tf_discriminator(tf_D):
    if tf_D.version < 4:
        raise ValueError('TensorFlow pickle version too low')

    # Collect kwargs.
    tf_kwargs = tf_D.static_kwargs
    known_kwargs = set()
    def kwarg(tf_name, default=None):
        known_kwargs.add(tf_name)
        return tf_kwargs.get(tf_name, default)

    # Convert kwargs.
    kwargs = dnnlib.EasyDict(
        c_dim=kwarg('label_size', 0),
        img_resolution=kwarg('resolution', 1024),
        img_channels=kwarg('num_channels', 3),
        architecture=kwarg('architecture', 'resnet'),
        channel_base=kwarg('fmap_base', 16384) * 2,
        channel_max=kwarg('fmap_max', 512),
        num_fp16_res=kwarg('num_fp16_res', 0),
        conv_clamp=kwarg('conv_clamp', None),
        cmap_dim=kwarg('mapping_fmaps', None),
        block_kwargs=dnnlib.EasyDict(
            activation=kwarg('nonlinearity', 'lrelu'),
            resample_filter=kwarg('resample_kernel', [1, 3, 3, 1]),
            freeze_layers=kwarg('freeze_layers', 0),
        ),
        mapping_kwargs=dnnlib.EasyDict(
            num_layers=kwarg('mapping_layers', 0),
            embed_features=kwarg('mapping_fmaps', None),
            layer_features=kwarg('mapping_fmaps', None),
            activation=kwarg('nonlinearity', 'lrelu'),
            lr_multiplier=kwarg('mapping_lrmul', 0.1),
        ),
        epilogue_kwargs=dnnlib.EasyDict(
            mbstd_group_size=kwarg('mbstd_group_size', None),
            mbstd_num_channels=kwarg('mbstd_num_features', 1),
            activation=kwarg('nonlinearity', 'lrelu'),
        ),
    )

    # Check for unknown kwargs.
    kwarg('structure')
    unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
    if len(unknown_kwargs) > 0:
        raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])

    # Collect params.
    tf_params = _collect_tf_params(tf_D)
    for name, value in list(tf_params.items()):
        match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name)
        if match:
            r = kwargs.img_resolution // (2 ** int(match.group(1)))
            tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value
            kwargs.architecture = 'orig'
    #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')

    # Convert params.
    from training import networks
    D = networks.Discriminator(**kwargs).eval().requires_grad_(False)
    # pylint: disable=unnecessary-lambda
    _populate_module_params(D,
        r'b(\d+)\.fromrgb\.weight',   lambda r:    tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1),
        r'b(\d+)\.fromrgb\.bias',     lambda r:    tf_params[f'{r}x{r}/FromRGB/bias'],
        r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1),
        r'b(\d+)\.conv(\d+)\.bias',   lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
        r'b(\d+)\.skip\.weight',      lambda r:    tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1),
        r'mapping\.embed\.weight',    lambda:      tf_params['LabelEmbed/weight'].transpose(),
        r'mapping\.embed\.bias',      lambda:      tf_params['LabelEmbed/bias'],
        r'mapping\.fc(\d+)\.weight',  lambda i:    tf_params[f'Mapping{i}/weight'].transpose(),
        r'mapping\.fc(\d+)\.bias',    lambda i:    tf_params[f'Mapping{i}/bias'],
        r'b4\.conv\.weight',          lambda:      tf_params['4x4/Conv/weight'].transpose(3, 2, 0, 1),
        r'b4\.conv\.bias',            lambda:      tf_params['4x4/Conv/bias'],
        r'b4\.fc\.weight',            lambda:      tf_params['4x4/Dense0/weight'].transpose(),
        r'b4\.fc\.bias',              lambda:      tf_params['4x4/Dense0/bias'],
        r'b4\.out\.weight',           lambda:      tf_params['Output/weight'].transpose(),
        r'b4\.out\.bias',             lambda:      tf_params['Output/bias'],
        r'.*\.resample_filter',       None,
    )
    return D
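# Hedged sketch: converting the discriminator from a legacy TF pickle and
# scoring one blank image (assumes torch is imported; tf_D is an unpickled
# TF network object as expected by the converter above).
def _example_convert_discriminator(tf_D):
    D = convert_tf_discriminator(tf_D)          # torch.nn.Module in eval mode
    img = torch.zeros([1, D.img_channels, D.img_resolution, D.img_resolution])
    logits = D(img, torch.zeros([1, D.c_dim]))  # raw realness score
    return logits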
def training_schedule(
        cur_nimg,
        training_set,
        lod_initial_resolution=None,    # Image resolution used at the beginning.
        lod_training_kimg=600,          # Thousands of real images to show before doubling the resolution.
        lod_transition_kimg=600,        # Thousands of real images to show when fading in new layers.
        minibatch_size_base=32,         # Global minibatch size.
        minibatch_size_dict={},         # Resolution-specific overrides.
        minibatch_gpu_base=4,           # Number of samples processed at a time by one GPU.
        minibatch_gpu_dict={},          # Resolution-specific overrides.
        G_lrate_base=0.002,             # Learning rate for the generator.
        G_lrate_dict={},                # Resolution-specific overrides.
        D_lrate_base=0.002,             # Learning rate for the discriminator.
        D_lrate_dict={},                # Resolution-specific overrides.
        lrate_rampup_kimg=0,            # Duration of learning rate ramp-up.
        tick_kimg_base=4,               # Default interval of progress snapshots.
        tick_kimg_dict={8: 28, 16: 24, 32: 20, 64: 16, 128: 12, 256: 8, 512: 6, 1024: 4}):  # Resolution-specific overrides.

    # Initialize result dict.
    s = dnnlib.EasyDict()
    s.kimg = cur_nimg / 1000.0

    # Training phase.
    phase_dur = lod_training_kimg + lod_transition_kimg
    phase_idx = int(np.floor(s.kimg / phase_dur)) if phase_dur > 0 else 0
    phase_kimg = s.kimg - phase_idx * phase_dur

    # Level-of-detail and resolution.
    if lod_initial_resolution is None:
        s.lod = 0.0
    else:
        s.lod = training_set.resolution_log2
        s.lod -= np.floor(np.log2(lod_initial_resolution))
        s.lod -= phase_idx
        if lod_transition_kimg > 0:
            s.lod -= max(phase_kimg - lod_training_kimg, 0.0) / lod_transition_kimg
        s.lod = max(s.lod, 0.0)
    s.resolution = 2 ** (training_set.resolution_log2 - int(np.floor(s.lod)))

    # Minibatch size.
    s.minibatch_size = minibatch_size_dict.get(s.resolution, minibatch_size_base)
    s.minibatch_gpu = minibatch_gpu_dict.get(s.resolution, minibatch_gpu_base)

    # Learning rate.
    s.G_lrate = G_lrate_dict.get(s.resolution, G_lrate_base)
    s.D_lrate = D_lrate_dict.get(s.resolution, D_lrate_base)
    if lrate_rampup_kimg > 0:
        rampup = min(s.kimg / lrate_rampup_kimg, 1.0)
        s.G_lrate *= rampup
        s.D_lrate *= rampup

    # Other parameters.
    s.tick_kimg = tick_kimg_dict.get(s.resolution, tick_kimg_base)
    return s
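# Hedged sketch: with lod_initial_resolution=None (the default above),
# progressive growing is effectively disabled and the schedule stays at the
# dataset's full resolution. The mock dataset only needs resolution_log2.
def _example_fixed_resolution_schedule():
    mock_set = dnnlib.EasyDict(resolution_log2=8)  # 256x256 dataset
    s = training_schedule(cur_nimg=0, training_set=mock_set)
    assert s.lod == 0.0 and s.resolution == 256
    return s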
def setup_training_loop_kwargs(
        # General options (not included in desc).
        gpus=None,        # Number of GPUs: <int>, default = 1 gpu
        snap=None,        # Snapshot interval: <int>, default = 50 ticks
        metrics=None,     # List of metric names: [], ['fid50k_full'] (default), ...
        seed=None,        # Random seed: <int>, default = 0
        # Dataset.
        data=None,        # Training dataset (required): <path>
        cond=None,        # Train conditional model based on dataset labels: <bool>, default = False
        subset=None,      # Train with only N images: <int>, default = all
        mirror=None,      # Augment dataset with x-flips: <bool>, default = False
        # Base config.
        cfg=None,         # Base config: 'auto' (default), 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar'
        gamma=None,       # Override R1 gamma: <float>
        kimg=None,        # Override training duration: <int>
        batch=None,       # Override batch size: <int>
        # Discriminator augmentation.
        aug=None,         # Augmentation mode: 'ada' (default), 'noaug', 'fixed'
        p=None,           # Specify p for 'fixed' (required): <float>
        target=None,      # Override ADA target for 'ada': <float>, default = depends on aug
        augpipe=None,     # Augmentation pipeline: 'blit', 'geom', 'color', 'filter', 'noise', 'cutout', 'bg', 'bgc' (default), ..., 'bgcfnc'
        # Transfer learning.
        resume=None,      # Load previous network: 'noresume' (default), 'ffhq256', 'ffhq512', 'ffhq1024', 'celebahq256', 'lsundog256', <file>, <url>
        freezed=None,     # Freeze-D: <int>, default = 0 discriminator layers
        # Performance options (not included in desc).
        fp32=None,        # Disable mixed-precision training: <bool>, default = False
        nhwc=None,        # Use NHWC memory format with FP16: <bool>, default = False
        allow_tf32=None,  # Allow PyTorch to use TF32 for matmul and convolutions: <bool>, default = False
        nobench=None,     # Disable cuDNN benchmarking: <bool>, default = False
        workers=None,     # Override number of DataLoader workers: <int>, default = 3
):
    args = dnnlib.EasyDict()

    # ------------------------------------------
    # General options: gpus, snap, metrics, seed
    # ------------------------------------------

    if gpus is None:
        gpus = 1
    assert isinstance(gpus, int)
    if not (gpus >= 1 and gpus & (gpus - 1) == 0):
        raise UserError('--gpus must be a power of two')
    args.num_gpus = gpus

    if snap is None:
        snap = 50
    assert isinstance(snap, int)
    if snap < 1:
        raise UserError('--snap must be at least 1')
    args.image_snapshot_ticks = snap
    args.network_snapshot_ticks = snap

    if metrics is None:
        metrics = ['fid50k_full']
    assert isinstance(metrics, list)
    if not all(metric_main.is_valid_metric(metric) for metric in metrics):
        raise UserError('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
    args.metrics = metrics

    if seed is None:
        seed = 0
    assert isinstance(seed, int)
    args.random_seed = seed

    # -----------------------------------
    # Dataset: data, cond, subset, mirror
    # -----------------------------------

    assert data is not None
    assert isinstance(data, str)
    args.training_set_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=data, use_labels=True, max_size=None, xflip=False)
    args.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=3, prefetch_factor=2)
    try:
        training_set = dnnlib.util.construct_class_by_name(**args.training_set_kwargs)  # subclass of training.dataset.Dataset
        args.training_set_kwargs.resolution = training_set.resolution  # be explicit about resolution
        args.training_set_kwargs.use_labels = training_set.has_labels  # be explicit about labels
        args.training_set_kwargs.max_size = len(training_set)          # be explicit about dataset size
        desc = training_set.name
        del training_set  # conserve memory
    except IOError as err:
        raise UserError(f'--data: {err}')

    if cond is None:
        cond = False
    assert isinstance(cond, bool)
    if cond:
        if not args.training_set_kwargs.use_labels:
            raise UserError('--cond=True requires labels specified in dataset.json')
        desc += '-cond'
    else:
        args.training_set_kwargs.use_labels = False

    if subset is not None:
        assert isinstance(subset, int)
        if not 1 <= subset <= args.training_set_kwargs.max_size:
            raise UserError(f'--subset must be between 1 and {args.training_set_kwargs.max_size}')
        desc += f'-subset{subset}'
        if subset < args.training_set_kwargs.max_size:
            args.training_set_kwargs.max_size = subset
            args.training_set_kwargs.random_seed = args.random_seed

    if mirror is None:
        mirror = False
    assert isinstance(mirror, bool)
    if mirror:
        desc += '-mirror'
        args.training_set_kwargs.xflip = True

    # ------------------------------------
    # Base config: cfg, gamma, kimg, batch
    # ------------------------------------

    if cfg is None:
        cfg = 'auto'
    assert isinstance(cfg, str)
    desc += f'-{cfg}'

    cfg_specs = {
        'auto':      dict(ref_gpus=-1, kimg=25000,  mb=-1, mbstd=-1, fmaps=-1,  lrate=-1,     gamma=-1,   ema=-1,  ramp=0.05, map=2),  # Populated dynamically based on resolution and GPU count.
        'stylegan2': dict(ref_gpus=8,  kimg=25000,  mb=32, mbstd=4,  fmaps=1,   lrate=0.002,  gamma=10,   ema=10,  ramp=None, map=8),  # Uses mixed-precision, unlike the original StyleGAN2.
        'paper256':  dict(ref_gpus=8,  kimg=25000,  mb=64, mbstd=8,  fmaps=0.5, lrate=0.0025, gamma=1,    ema=20,  ramp=None, map=8),
        'paper512':  dict(ref_gpus=8,  kimg=25000,  mb=64, mbstd=8,  fmaps=1,   lrate=0.0025, gamma=0.5,  ema=20,  ramp=None, map=8),
        'paper1024': dict(ref_gpus=8,  kimg=25000,  mb=32, mbstd=4,  fmaps=1,   lrate=0.002,  gamma=2,    ema=10,  ramp=None, map=8),
        'cifar':     dict(ref_gpus=2,  kimg=100000, mb=64, mbstd=32, fmaps=1,   lrate=0.0025, gamma=0.01, ema=500, ramp=0.05, map=2),
    }

    assert cfg in cfg_specs
    spec = dnnlib.EasyDict(cfg_specs[cfg])
    if cfg == 'auto':
        desc += f'{gpus:d}'
        spec.ref_gpus = gpus
        res = args.training_set_kwargs.resolution
        spec.mb = max(min(gpus * min(4096 // res, 32), 64), gpus)  # keep gpu memory consumption at bay
        spec.mbstd = min(spec.mb // gpus, 4)  # other hyperparams behave more predictably if mbstd group size remains fixed
        spec.fmaps = 1 if res >= 512 else 0.5
        spec.lrate = 0.002 if res >= 1024 else 0.0025
        spec.gamma = 0.0002 * (res ** 2) / spec.mb  # heuristic formula
        spec.ema = spec.mb * 10 / 32

    args.G_kwargs = dnnlib.EasyDict(class_name='training.networks.Generator', z_dim=512, w_dim=512, mapping_kwargs=dnnlib.EasyDict(), synthesis_kwargs=dnnlib.EasyDict())
    args.D_kwargs = dnnlib.EasyDict(class_name='training.networks.Discriminator', block_kwargs=dnnlib.EasyDict(), mapping_kwargs=dnnlib.EasyDict(), epilogue_kwargs=dnnlib.EasyDict())
    args.G_kwargs.synthesis_kwargs.channel_base = args.D_kwargs.channel_base = int(spec.fmaps * 32768)
    args.G_kwargs.synthesis_kwargs.channel_max = args.D_kwargs.channel_max = 512
    args.G_kwargs.mapping_kwargs.num_layers = 3  # NOTE: hard-coded to 3 in this fork; upstream StyleGAN2-ADA uses spec.map here.
    args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 4  # enable mixed-precision training
    args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = 256  # clamp activations to avoid float16 overflow
    args.D_kwargs.epilogue_kwargs.mbstd_group_size = spec.mbstd

    args.G_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=spec.lrate, betas=[0, 0.99], eps=1e-8)
    args.D_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=spec.lrate, betas=[0, 0.99], eps=1e-8)
    args.loss_kwargs = dnnlib.EasyDict(class_name='training.loss.StyleGAN2Loss', r1_gamma=spec.gamma)

    args.total_kimg = spec.kimg
    args.batch_size = spec.mb
    args.batch_gpu = spec.mb // spec.ref_gpus
    args.ema_kimg = spec.ema
    args.ema_rampup = spec.ramp

    if cfg == 'cifar':
        args.loss_kwargs.pl_weight = 0  # disable path length regularization
        args.loss_kwargs.style_mixing_prob = 0  # disable style mixing
        args.D_kwargs.architecture = 'orig'  # disable residual skip connections

    if gamma is not None:
        assert isinstance(gamma, float)
        if not gamma >= 0:
            raise UserError('--gamma must be non-negative')
        desc += f'-gamma{gamma:g}'
        args.loss_kwargs.r1_gamma = gamma

    if kimg is not None:
        assert isinstance(kimg, int)
        if not kimg >= 1:
            raise UserError('--kimg must be at least 1')
        desc += f'-kimg{kimg:d}'
        args.total_kimg = kimg

    if batch is not None:
        assert isinstance(batch, int)
        if not (batch >= 1 and batch % gpus == 0):
            raise UserError('--batch must be at least 1 and divisible by --gpus')
        desc += f'-batch{batch}'
        args.batch_size = batch
        args.batch_gpu = batch // gpus

    # ---------------------------------------------------
    # Discriminator augmentation: aug, p, target, augpipe
    # ---------------------------------------------------

    if aug is None:
        aug = 'ada'
    else:
        assert isinstance(aug, str)
        desc += f'-{aug}'

    if aug == 'ada':
        args.ada_target = 0.6
    elif aug == 'noaug':
        pass
    elif aug == 'fixed':
        if p is None:
            raise UserError(f'--aug={aug} requires specifying --p')
    else:
        raise UserError(f'--aug={aug} not supported')

    if p is not None:
        assert isinstance(p, float)
        if aug != 'fixed':
            raise UserError('--p can only be specified with --aug=fixed')
        if not 0 <= p <= 1:
            raise UserError('--p must be between 0 and 1')
        desc += f'-p{p:g}'
        args.augment_p = p

    if target is not None:
        assert isinstance(target, float)
        if aug != 'ada':
            raise UserError('--target can only be specified with --aug=ada')
        if not 0 <= target <= 1:
            raise UserError('--target must be between 0 and 1')
        desc += f'-target{target:g}'
        args.ada_target = target

    assert augpipe is None or isinstance(augpipe, str)
    if augpipe is None:
        augpipe = 'bgc'
    else:
        if aug == 'noaug':
            raise UserError('--augpipe cannot be specified with --aug=noaug')
        desc += f'-{augpipe}'

    augpipe_specs = {
        'blit':   dict(xflip=1, rotate90=1, xint=1),
        'geom':   dict(scale=1, rotate=1, aniso=1, xfrac=1),
        'color':  dict(brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
        'filter': dict(imgfilter=1),
        'noise':  dict(noise=1),
        'cutout': dict(cutout=1),
        'bg':     dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1),
        'bgc':    dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
        'bgcf':   dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1),
        'bgcfn':  dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1),
        'bgcfnc': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1, cutout=1),
    }

    assert augpipe in augpipe_specs
    if aug != 'noaug':
        args.augment_kwargs = dnnlib.EasyDict(class_name='training.augment.AugmentPipe', **augpipe_specs[augpipe])

    # ----------------------------------
    # Transfer learning: resume, freezed
    # ----------------------------------

    resume_specs = {
        'ffhq256':     'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res256-mirror-paper256-noaug.pkl',
        'ffhq512':     'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res512-mirror-stylegan2-noaug.pkl',
        'ffhq1024':    'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res1024-mirror-stylegan2-noaug.pkl',
        'celebahq256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/celebahq-res256-mirror-paper256-kimg100000-ada-target0.5.pkl',
        'lsundog256':  'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/lsundog-res256-paper256-kimg100000-noaug.pkl',
    }

    assert resume is None or isinstance(resume, str)
    if resume is None:
        resume = 'noresume'
    elif resume == 'noresume':
        desc += '-noresume'
    elif resume in resume_specs:
        desc += f'-resume{resume}'
        args.resume_pkl = resume_specs[resume]  # predefined url
    else:
        desc += '-resumecustom'
        args.resume_pkl = resume  # custom path or url

    if resume != 'noresume':
        args.ada_kimg = 100  # make ADA react faster at the beginning
        args.ema_rampup = None  # disable EMA rampup

    if freezed is not None:
        assert isinstance(freezed, int)
        if not freezed >= 0:
            raise UserError('--freezed must be non-negative')
        desc += f'-freezed{freezed:d}'
        args.D_kwargs.block_kwargs.freeze_layers = freezed

    # -------------------------------------------------------------
    # Performance options: fp32, nhwc, nobench, allow_tf32, workers
    # -------------------------------------------------------------

    if fp32 is None:
        fp32 = False
    assert isinstance(fp32, bool)
    if fp32:
        args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 0
        args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = None

    if nhwc is None:
        nhwc = False
    assert isinstance(nhwc, bool)
    if nhwc:
        args.G_kwargs.synthesis_kwargs.fp16_channels_last = args.D_kwargs.block_kwargs.fp16_channels_last = True

    if nobench is None:
        nobench = False
    assert isinstance(nobench, bool)
    if nobench:
        args.cudnn_benchmark = False

    if allow_tf32 is None:
        allow_tf32 = False
    assert isinstance(allow_tf32, bool)
    if allow_tf32:
        args.allow_tf32 = True

    if workers is not None:
        assert isinstance(workers, int)
        if not workers >= 1:
            raise UserError('--workers must be at least 1')
        args.data_loader_kwargs.num_workers = workers

    return desc, args
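# Hedged usage sketch (mirrors how a train.py entry point typically drives the
# helper above; the dataset path is a placeholder):
def _example_setup():
    desc, args = setup_training_loop_kwargs(gpus=2, data='./datasets/ffhq256.zip', cfg='auto', mirror=True)
    print(desc)  # e.g. 'ffhq256-mirror-auto2'
    print(args.batch_size, args.loss_kwargs.r1_gamma)
    return args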
def _report_result(self, value, suffix='', fmt='%-10.4f'):
    self._results += [dnnlib.EasyDict(value=value, suffix=suffix, fmt=fmt)]
import validation

# Submit config
# ------------------------------------------------------------------------------------------

submit_config = dnnlib.SubmitConfig()
submit_config.run_dir_root = 'results'
submit_config.run_dir_ignore += ['datasets', 'results']

desc = "autoencoder"

# Tensorflow config
# ------------------------------------------------------------------------------------------

tf_config = dnnlib.EasyDict()
tf_config["graph_options.place_pruned_graph"] = True

# Network config
# ------------------------------------------------------------------------------------------

net_config = dnnlib.EasyDict(func_name="network.autoencoder")

# Optimizer config
# ------------------------------------------------------------------------------------------

optimizer_config = dnnlib.EasyDict(beta1=0.9, beta2=0.99, epsilon=1e-8)

# Noise augmentation config
gaussian_noise_config = dnnlib.EasyDict(func_name='train.AugmentGaussian',
                                        train_stddev_rng_range=(0.0, 50.0),
def generate_images(network_pkl, seeds, truncation_psi, data_dir=None, dataset_name=None, model=None):
    G_args = EasyDict(func_name='training.' + model + '.G_main')
    dataset_args = EasyDict(tfrecord_dir=dataset_name)
    G_args.fmap_base = 8 << 10
    tflib.init_tf()
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir), verbose=True, **dataset_args)

    print('Constructing networks...')
    Gs = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1],
                       label_size=training_set.label_size, **G_args)
    print('Loading networks from "%s"...' % network_pkl)
    _, _, _Gs = pretrained_networks.load_networks(network_pkl)
    Gs.copy_vars_from(_Gs)
    noise_vars = [var for name, var in Gs.components.synthesis.vars.items() if name.startswith('noise')]

    Gs_kwargs = dnnlib.EasyDict()
    # Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    Gs_kwargs.randomize_noise = False
    if truncation_psi is not None:
        Gs_kwargs.truncation_psi = truncation_psi

    for seed_idx, seed in enumerate(seeds):
        print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
        rnd = np.random.RandomState(seed)
        z = rnd.randn(1, *Gs.input_shape[1:])  # [minibatch, component]
        tflib.set_vars({var: rnd.randn(*var.shape.as_list()) for var in noise_vars})  # [height, width]
        images, x_v, n_v, m_v = Gs.run(z, None, **Gs_kwargs)  # [minibatch, height, width, channel]
        print(images.shape, n_v.shape, x_v.shape, m_v.shape)
        misc.convert_to_pil_image(images[0], drange=[-1, 1]).save(dnnlib.make_run_dir_path('seed%04d.png' % seed))
        misc.save_image_grid(adjust_range(n_v), dnnlib.make_run_dir_path('seed%04d-nv.png' % seed), drange=[-1, 1])
        print(np.linalg.norm(x_v - m_v))
        misc.save_image_grid(adjust_range(x_v).transpose([1, 0, 2, 3]), dnnlib.make_run_dir_path('seed%04d-xv.png' % seed), drange=[-1, 1])
        misc.save_image_grid(adjust_range(m_v).transpose([1, 0, 2, 3]), dnnlib.make_run_dir_path('seed%04d-mv.png' % seed), drange=[-1, 1])
        misc.save_image_grid(adjust_range(clip(x_v, 'cat')), dnnlib.make_run_dir_path('seed%04d-xvs.png' % seed), drange=[-1, 1])
        misc.save_image_grid(adjust_range(clip(m_v, 'ss')), dnnlib.make_run_dir_path('seed%04d-mvs.png' % seed), drange=[-1, 1])
        misc.save_image_grid(adjust_range(clip(m_v, 'ffhq')), dnnlib.make_run_dir_path('seed%04d-fmvs.png' % seed), drange=[-1, 1])
def style_mixing_example(network_pkl, row_seeds, col_seeds, truncation_psi, col_styles, minibatch_size=4):
    print('Loading networks from "%s"...' % network_pkl)
    _G, _D, Gs = pretrained_networks.load_networks(network_pkl)
    w_avg = Gs.get_var("dlatent_avg")  # [component]

    Gs_syn_kwargs = dnnlib.EasyDict()
    Gs_syn_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    Gs_syn_kwargs.randomize_noise = False
    Gs_syn_kwargs.minibatch_size = minibatch_size

    print("Generating W vectors...")
    all_seeds = list(set(row_seeds + col_seeds))
    all_z = np.stack([np.random.RandomState(seed).randn(*Gs.input_shape[1:]) for seed in all_seeds])  # [minibatch, component]
    all_w = Gs.components.mapping.run(all_z, None)  # [minibatch, layer, component]
    all_w = w_avg + (all_w - w_avg) * truncation_psi  # [minibatch, layer, component]
    w_dict = {seed: w for seed, w in zip(all_seeds, list(all_w))}  # [layer, component]

    print("Generating images...")
    all_images = Gs.components.synthesis.run(all_w, **Gs_syn_kwargs)  # [minibatch, height, width, channel]
    image_dict = {(seed, seed): image for seed, image in zip(all_seeds, list(all_images))}

    print("Generating style-mixed images...")
    for row_seed in row_seeds:
        for col_seed in col_seeds:
            w = w_dict[row_seed].copy()
            w[col_styles] = w_dict[col_seed][col_styles]
            image = Gs.components.synthesis.run(w[np.newaxis], **Gs_syn_kwargs)[0]
            image_dict[(row_seed, col_seed)] = image

    print("Saving images...")
    for (row_seed, col_seed), image in image_dict.items():
        PIL.Image.fromarray(image, "RGB").save(dnnlib.make_run_dir_path("%d-%d.png" % (row_seed, col_seed)))

    print("Saving image grid...")
    _N, _C, H, W = Gs.output_shape
    canvas = PIL.Image.new("RGB", (W * (len(col_seeds) + 1), H * (len(row_seeds) + 1)), "black")
    for row_idx, row_seed in enumerate([None] + row_seeds):
        for col_idx, col_seed in enumerate([None] + col_seeds):
            if row_seed is None and col_seed is None:
                continue
            key = (row_seed, col_seed)
            if row_seed is None:
                key = (col_seed, col_seed)
            if col_seed is None:
                key = (row_seed, row_seed)
            canvas.paste(PIL.Image.fromarray(image_dict[key], "RGB"), (W * col_idx, H * row_idx))
    canvas.save(dnnlib.make_run_dir_path("grid.png"))
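# Hedged usage sketch (placeholder pickle path; col_styles selects the layer
# indices copied from the column image, as in the original run_generator script):
def _example_style_mixing():
    style_mixing_example('ffhq.pkl', row_seeds=[85, 100, 75], col_seeds=[55, 821, 1789],
                         truncation_psi=0.7, col_styles=list(range(0, 6)))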
def main():
    os.makedirs(a.out_dir, exist_ok=True)
    np.random.seed(seed=696)

    # parse filename to model parameters
    mparams = basename(a.model).split('-')
    res = int(mparams[1])
    cfg = mparams[2]

    # setup generator
    fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.func_name = 'training.stylegan2_custom.G_main'
    Gs_kwargs.verbose = False
    Gs_kwargs.resolution = res
    Gs_kwargs.size = a.size
    Gs_kwargs.scale_type = a.scale_type
    Gs_kwargs.latent_size = a.latent_size
    Gs_kwargs.impl = a.ops

    if cfg.lower() == 'f':
        Gs_kwargs.synthesis_func = 'G_synthesis_stylegan2'
    elif cfg.lower() == 'e':
        Gs_kwargs.synthesis_func = 'G_synthesis_stylegan2'
        Gs_kwargs.fmap_base = 8 << 10
    else:
        print(' old modes [A-D] not implemented')
        exit()

    # check initial model resolution
    if len(mparams) > 3:
        if 'x' in mparams[3].lower():
            init_res = [int(x) for x in mparams[3].lower().split('x')]
            Gs_kwargs.init_res = list(reversed(init_res))  # [H,W]

    # load model, check channels
    sess = tflib.init_tf({'allow_soft_placement': True})
    pkl_name = osp.splitext(a.model)[0]
    with open(pkl_name + '.pkl', 'rb') as file:
        network = pickle.load(file, encoding='latin1')
    try:
        _, _, network = network
    except:
        pass
    Gs_kwargs.num_channels = network.output_shape[1]

    # reload custom network, if needed
    if '.pkl' in a.model.lower():
        print(' .. Gs from pkl ..')
        Gs = network
    else:
        print(' .. Gs custom ..')
        Gs = tflib.Network('Gs', **Gs_kwargs)
        Gs.copy_vars_from(network)
    # Gs.print_layers()
    print(' out shape', Gs.output_shape[1:])
    if a.size is None:
        a.size = Gs.output_shape[2:]

    z_dim = Gs.input_shape[1]
    shape = (1, z_dim)

    print(' making timeline..')
    latents = latent_anima(shape, a.frames, a.fstep, cubic=a.cubic, gauss=a.gauss, verbose=True)  # [frm,1,512]
    print(' latents', latents.shape)

    # generate images from latent timeline
    frame_count = latents.shape[0]
    pbar = ProgressBar(frame_count)
    for i in range(frame_count):
        output = Gs.run(latents[i], [None], truncation_psi=a.trunc, randomize_noise=False, output_transform=fmt)
        ext = 'png' if output.shape[3] == 4 else 'jpg'
        filename = osp.join(a.out_dir, "%05d.%s" % (i, ext))
        imsave(filename, output[0])
        pbar.upd()

    # convert latents to dlatents, save them
    latents = latents.squeeze(1)  # [frm,512]
    dlatents = Gs.components.mapping.run(latents, None, latent_size=z_dim, dtype='float16')  # [frm,18,512]
    filename = '{}-{}-{}.npy'.format(basename(a.model), a.size[1], a.size[0])
    filename = osp.join(osp.dirname(a.out_dir), filename)
    np.save(filename, dlatents)
    print('saved dlatents', dlatents.shape, 'to', filename)
import time
import hashlib
import numpy as np
import tensorflow as tf

import dnnlib
import dnnlib.tflib as tflib

import config
from training import misc
from training import dataset

#----------------------------------------------------------------------------
# Standard metrics.

fid50k    = dnnlib.EasyDict(func_name='metrics.frechet_inception_distance.FID', name='fid50k', num_images=50000, minibatch_per_gpu=8)
ppl_zfull = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_zfull', num_samples=100000, epsilon=1e-4, space='z', sampling='full', minibatch_per_gpu=16)
ppl_wfull = dnnlib.EasyDict(func_name='metrics.perceptual_path_length.PPL', name='ppl_wfull', num_samples=100000, epsilon=1e-4, space='w', sampling='full', minibatch_per_gpu=16)
def convert_tf_generator(tf_G):
    if tf_G.version < 4:
        raise ValueError("TensorFlow pickle version too low")

    # Collect kwargs.
    tf_kwargs = tf_G.static_kwargs
    known_kwargs = set()
    def kwarg(tf_name, default=None, none=None):
        known_kwargs.add(tf_name)
        val = tf_kwargs.get(tf_name, default)
        return val if val is not None else none

    # Convert kwargs.
    kwargs = dnnlib.EasyDict(
        z_dim=kwarg("latent_size", 512),
        c_dim=kwarg("label_size", 0),
        w_dim=kwarg("dlatent_size", 512),
        img_resolution=kwarg("resolution", 1024),
        img_channels=kwarg("num_channels", 3),
        mapping_kwargs=dnnlib.EasyDict(
            num_layers=kwarg("mapping_layers", 8),
            embed_features=kwarg("label_fmaps", None),
            layer_features=kwarg("mapping_fmaps", None),
            activation=kwarg("mapping_nonlinearity", "lrelu"),
            lr_multiplier=kwarg("mapping_lrmul", 0.01),
            w_avg_beta=kwarg("w_avg_beta", 0.995, none=1),
        ),
        synthesis_kwargs=dnnlib.EasyDict(
            channel_base=kwarg("fmap_base", 16384) * 2,
            channel_max=kwarg("fmap_max", 512),
            num_fp16_res=kwarg("num_fp16_res", 0),
            conv_clamp=kwarg("conv_clamp", None),
            architecture=kwarg("architecture", "skip"),
            resample_filter=kwarg("resample_kernel", [1, 3, 3, 1]),
            use_noise=kwarg("use_noise", True),
            activation=kwarg("nonlinearity", "lrelu"),
        ),
    )

    # Check for unknown kwargs.
    kwarg("truncation_psi")
    kwarg("truncation_cutoff")
    kwarg("style_mixing_prob")
    kwarg("structure")
    unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
    if len(unknown_kwargs) > 0:
        raise ValueError("Unknown TensorFlow kwarg", unknown_kwargs[0])

    # Collect params.
    tf_params = _collect_tf_params(tf_G)
    for name, value in list(tf_params.items()):
        match = re.fullmatch(r"ToRGB_lod(\d+)/(.*)", name)
        if match:
            r = kwargs.img_resolution // (2 ** int(match.group(1)))
            tf_params[f"{r}x{r}/ToRGB/{match.group(2)}"] = value
            kwargs.synthesis_kwargs.architecture = "orig"  # fixed: was kwargs.synthesis.kwargs, which would raise AttributeError
    # for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')

    # Convert params.
    from training import networks
    G = networks.Generator(**kwargs).eval().requires_grad_(False)
    # pylint: disable=unnecessary-lambda
    _populate_module_params(G,
        r"mapping\.w_avg",                           lambda:   tf_params["dlatent_avg"],
        r"mapping\.embed\.weight",                   lambda:   tf_params["mapping/LabelEmbed/weight"].transpose(),
        r"mapping\.embed\.bias",                     lambda:   tf_params["mapping/LabelEmbed/bias"],
        r"mapping\.fc(\d+)\.weight",                 lambda i: tf_params[f"mapping/Dense{i}/weight"].transpose(),
        r"mapping\.fc(\d+)\.bias",                   lambda i: tf_params[f"mapping/Dense{i}/bias"],
        r"synthesis\.b4\.const",                     lambda:   tf_params["synthesis/4x4/Const/const"][0],
        r"synthesis\.b4\.conv1\.weight",             lambda:   tf_params["synthesis/4x4/Conv/weight"].transpose(3, 2, 0, 1),
        r"synthesis\.b4\.conv1\.bias",               lambda:   tf_params["synthesis/4x4/Conv/bias"],
        r"synthesis\.b4\.conv1\.noise_const",        lambda:   tf_params["synthesis/noise0"][0, 0],
        r"synthesis\.b4\.conv1\.noise_strength",     lambda:   tf_params["synthesis/4x4/Conv/noise_strength"],
        r"synthesis\.b4\.conv1\.affine\.weight",     lambda:   tf_params["synthesis/4x4/Conv/mod_weight"].transpose(),
        r"synthesis\.b4\.conv1\.affine\.bias",       lambda:   tf_params["synthesis/4x4/Conv/mod_bias"] + 1,
        r"synthesis\.b(\d+)\.conv0\.weight",         lambda r: tf_params[f"synthesis/{r}x{r}/Conv0_up/weight"][::-1, ::-1].transpose(3, 2, 0, 1),
        r"synthesis\.b(\d+)\.conv0\.bias",           lambda r: tf_params[f"synthesis/{r}x{r}/Conv0_up/bias"],
        r"synthesis\.b(\d+)\.conv0\.noise_const",    lambda r: tf_params[f"synthesis/noise{int(np.log2(int(r)))*2-5}"][0, 0],
        r"synthesis\.b(\d+)\.conv0\.noise_strength", lambda r: tf_params[f"synthesis/{r}x{r}/Conv0_up/noise_strength"],
        r"synthesis\.b(\d+)\.conv0\.affine\.weight", lambda r: tf_params[f"synthesis/{r}x{r}/Conv0_up/mod_weight"].transpose(),
        r"synthesis\.b(\d+)\.conv0\.affine\.bias",   lambda r: tf_params[f"synthesis/{r}x{r}/Conv0_up/mod_bias"] + 1,
        r"synthesis\.b(\d+)\.conv1\.weight",         lambda r: tf_params[f"synthesis/{r}x{r}/Conv1/weight"].transpose(3, 2, 0, 1),
        r"synthesis\.b(\d+)\.conv1\.bias",           lambda r: tf_params[f"synthesis/{r}x{r}/Conv1/bias"],
        r"synthesis\.b(\d+)\.conv1\.noise_const",    lambda r: tf_params[f"synthesis/noise{int(np.log2(int(r)))*2-4}"][0, 0],
        r"synthesis\.b(\d+)\.conv1\.noise_strength", lambda r: tf_params[f"synthesis/{r}x{r}/Conv1/noise_strength"],
        r"synthesis\.b(\d+)\.conv1\.affine\.weight", lambda r: tf_params[f"synthesis/{r}x{r}/Conv1/mod_weight"].transpose(),
        r"synthesis\.b(\d+)\.conv1\.affine\.bias",   lambda r: tf_params[f"synthesis/{r}x{r}/Conv1/mod_bias"] + 1,
        r"synthesis\.b(\d+)\.torgb\.weight",         lambda r: tf_params[f"synthesis/{r}x{r}/ToRGB/weight"].transpose(3, 2, 0, 1),
        r"synthesis\.b(\d+)\.torgb\.bias",           lambda r: tf_params[f"synthesis/{r}x{r}/ToRGB/bias"],
        r"synthesis\.b(\d+)\.torgb\.affine\.weight", lambda r: tf_params[f"synthesis/{r}x{r}/ToRGB/mod_weight"].transpose(),
        r"synthesis\.b(\d+)\.torgb\.affine\.bias",   lambda r: tf_params[f"synthesis/{r}x{r}/ToRGB/mod_bias"] + 1,
        r"synthesis\.b(\d+)\.skip\.weight",          lambda r: tf_params[f"synthesis/{r}x{r}/Skip/weight"][::-1, ::-1].transpose(3, 2, 0, 1),
        r".*\.resample_filter",                      None,
    )
    return G
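# Hedged sketch of how a converted generator is typically used (assumes torch
# and pickle are imported; the pickle path is a placeholder and the 3-tuple
# (G, D, Gs) layout is the usual legacy pickle convention):
def _example_convert_generator():
    with open('stylegan2-ffhq.pkl', 'rb') as f:
        _G, _D, Gs = pickle.load(f, encoding='latin1')
    G = convert_tf_generator(Gs)      # torch.nn.Module in eval mode
    z = torch.randn([1, G.z_dim])
    c = torch.zeros([1, G.c_dim])
    img = G(z, c)                     # [1, img_channels, res, res], roughly in [-1, 1]
    return img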
def load(pkl_file):
    with open(pkl_file, 'rb') as f:
        s = dnnlib.EasyDict(pickle.load(f))
    obj = FeatureStats(capture_all=s.capture_all, max_items=s.max_items)
    obj.__dict__.update(s)
    return obj
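# Hedged sketch: reloading cached feature statistics. Assumes FeatureStats
# exposes get_mean_cov() as in the stylegan2-ada-pytorch metric utilities,
# and the cache filename is a placeholder.
def _example_load_stats():
    stats = load('inception-stats.pkl')
    mu, sigma = stats.get_mean_cov()  # FID-style mean and covariance
    return mu, sigma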
import os

import dnnlib
from dnnlib.util import Logger
from ffhq_datareader import load_dataset
from experiments import compute_stylegan_realism
from experiments import compute_stylegan_truncation
from utils import init_tf

SAVE_PATH = os.path.dirname(__file__)

#----------------------------------------------------------------------------
# Configs for truncation sweep and realism score.

realism_config = dnnlib.EasyDict(minibatch_size=8, num_images=50000, num_gen_images=1000, show_n_images=64,
                                 truncation=1.0, save_images=True, save_path=SAVE_PATH, num_gpus=1, random_seed=123456)

truncation_config = dnnlib.EasyDict(minibatch_size=8, num_images=50000, truncations=[1.0, 0.7, 0.3],
                                    save_txt=True, save_path=SAVE_PATH, num_gpus=1, random_seed=1234)

#----------------------------------------------------------------------------
# Minimal CLI.
def main():
    os.makedirs(a.out_dir, exist_ok=True)
    device = torch.device('cuda')

    # setup generator
    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.verbose = a.verbose
    Gs_kwargs.size = a.size
    Gs_kwargs.scale_type = a.scale_type

    # load base or custom network
    pkl_name = osp.splitext(a.model)[0]
    if '.pkl' in a.model.lower():
        custom = False
        print(' .. Gs from pkl ..', basename(a.model))
    else:
        custom = True
        print(' .. Gs custom ..', basename(a.model))
    with dnnlib.util.open_url(pkl_name + '.pkl') as f:
        Gs = legacy.load_network_pkl(f, custom=custom, **Gs_kwargs)['G_ema'].to(device)  # type: ignore

    dlat_shape = (1, Gs.num_ws, Gs.w_dim)  # [1,18,512]

    # read saved latents
    if a.dlatents is not None and osp.isfile(a.dlatents):
        key_dlatents = load_latents(a.dlatents)
        if len(key_dlatents.shape) == 2:
            key_dlatents = np.expand_dims(key_dlatents, 0)
    elif a.dlatents is not None and osp.isdir(a.dlatents):
        # if a.dlatents.endswith('/') or a.dlatents.endswith('\\'): a.dlatents = a.dlatents[:-1]
        key_dlatents = []
        npy_list = file_list(a.dlatents, 'npy')
        for npy in npy_list:
            key_dlatent = load_latents(npy)
            if len(key_dlatent.shape) == 2:
                key_dlatent = np.expand_dims(key_dlatent, 0)
            key_dlatents.append(key_dlatent)
        key_dlatents = np.concatenate(key_dlatents)  # [frm,18,512]
    else:
        print(' No input dlatents found')
        exit()
    key_dlatents = key_dlatents[:, np.newaxis]  # [frm,1,18,512]
    print(' key dlatents', key_dlatents.shape)

    # replace higher layers with single (style) latent
    if a.style_dlat is not None:
        print(' styling with dlatent', a.style_dlat)
        style_dlatent = load_latents(a.style_dlat)
        while len(style_dlatent.shape) < 4:
            style_dlatent = np.expand_dims(style_dlatent, 0)
        # try replacing 5 by other value, less than Gs.num_ws
        key_dlatents[:, :, range(5, Gs.num_ws), :] = style_dlatent[:, :, range(5, Gs.num_ws), :]

    frames = key_dlatents.shape[0] * a.fstep
    dlatents = latent_anima(dlat_shape, frames, a.fstep, key_latents=key_dlatents, cubic=a.cubic, verbose=True)  # [frm,1,18,512]
    print(' dlatents', dlatents.shape)
    frame_count = dlatents.shape[0]
    dlatents = torch.from_numpy(dlatents).to(device)

    # distort image by tweaking initial const layer
    if a.digress > 0:
        try:
            init_res = Gs.init_res
        except:
            init_res = (4, 4)  # default initial layer size
        dconst = a.digress * latent_anima([1, Gs.z_dim, *init_res], frame_count, a.fstep, cubic=True, verbose=False)
    else:
        dconst = np.zeros([frame_count, 1, 1, 1, 1])
    dconst = torch.from_numpy(dconst).to(device)

    # generate images from latent timeline
    pbar = ProgressBar(frame_count)
    for i in range(frame_count):
        # generate multi-latent result
        if custom:
            output = Gs.synthesis(dlatents[i], None, dconst[i], noise_mode='const')
        else:
            output = Gs.synthesis(dlatents[i], noise_mode='const')
        output = (output.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()
        ext = 'png' if output.shape[3] == 4 else 'jpg'
        filename = osp.join(a.out_dir, "%06d.%s" % (i, ext))
        imsave(filename, output[0])
        pbar.upd()
def G_style(
        latents_in,                     # First input: Latent vectors (Z) [minibatch, latent_size].
        labels_in,                      # Second input: Conditioning labels [minibatch, label_size].
        truncation_psi=0.7,             # Style strength multiplier for the truncation trick. None = disable.
        truncation_cutoff=8,            # Number of layers for which to apply the truncation trick. None = disable.
        truncation_psi_val=None,        # Value for truncation_psi to use during validation.
        truncation_cutoff_val=None,     # Value for truncation_cutoff to use during validation.
        dlatent_avg_beta=0.995,         # Decay for tracking the moving average of W during training. None = disable.
        style_mixing_prob=0.9,          # Probability of mixing styles during training. None = disable.
        is_training=False,              # Network is under training? Enables and disables specific features.
        is_validation=False,            # Network is under validation? Chooses which value to use for truncation_psi.
        is_template_graph=False,        # True = template graph constructed by the Network class, False = actual evaluation.
        components=dnnlib.EasyDict(),   # Container for sub-networks. Retained between calls.
        **kwargs):                      # Arguments for sub-networks (G_mapping and G_synthesis).

    # Validate arguments.
    assert not is_training or not is_validation
    assert isinstance(components, dnnlib.EasyDict)
    if is_validation:
        truncation_psi = truncation_psi_val
        truncation_cutoff = truncation_cutoff_val
    if is_training or (truncation_psi is not None and not tflib.is_tf_expression(truncation_psi) and truncation_psi == 1):
        truncation_psi = None
    if is_training or (truncation_cutoff is not None and not tflib.is_tf_expression(truncation_cutoff) and truncation_cutoff <= 0):
        truncation_cutoff = None
    if not is_training or (dlatent_avg_beta is not None and not tflib.is_tf_expression(dlatent_avg_beta) and dlatent_avg_beta == 1):
        dlatent_avg_beta = None
    if not is_training or (style_mixing_prob is not None and not tflib.is_tf_expression(style_mixing_prob) and style_mixing_prob <= 0):
        style_mixing_prob = None

    # Setup components.
    if 'synthesis' not in components:
        components.synthesis = tflib.Network('G_synthesis', func_name=G_synthesis, **kwargs)
    num_layers = components.synthesis.input_shape[1]
    dlatent_size = components.synthesis.input_shape[2]
    if 'mapping' not in components:
        components.mapping = tflib.Network('G_mapping', func_name=G_mapping, dlatent_broadcast=num_layers, **kwargs)

    # Setup variables.
    lod_in = tf.get_variable('lod', initializer=np.float32(0), trainable=False)
    dlatent_avg = tf.get_variable('dlatent_avg', shape=[dlatent_size], initializer=tf.initializers.zeros(), trainable=False)

    # Evaluate mapping network.
    dlatents = components.mapping.get_output_for(latents_in, labels_in, **kwargs)

    # Update moving average of W.
    if dlatent_avg_beta is not None:
        with tf.variable_scope('DlatentAvg'):
            batch_avg = tf.reduce_mean(dlatents[:, 0], axis=0)
            update_op = tf.assign(dlatent_avg, tflib.lerp(batch_avg, dlatent_avg, dlatent_avg_beta))
            with tf.control_dependencies([update_op]):
                dlatents = tf.identity(dlatents)

    # Perform style mixing regularization.
    if style_mixing_prob is not None:
        with tf.name_scope('StyleMix'):
            latents2 = tf.random_normal(tf.shape(latents_in))
            dlatents2 = components.mapping.get_output_for(latents2, labels_in, **kwargs)
            layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
            cur_layers = num_layers - tf.cast(lod_in, tf.int32) * 2
            mixing_cutoff = tf.cond(
                tf.random_uniform([], 0.0, 1.0) < style_mixing_prob,
                lambda: tf.random_uniform([], 1, cur_layers, dtype=tf.int32),
                lambda: cur_layers)
            dlatents = tf.where(tf.broadcast_to(layer_idx < mixing_cutoff, tf.shape(dlatents)), dlatents, dlatents2)

    # Apply truncation trick.
    if truncation_psi is not None and truncation_cutoff is not None:
        with tf.variable_scope('Truncation'):
            layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
            ones = np.ones(layer_idx.shape, dtype=np.float32)
            coefs = tf.where(layer_idx < truncation_cutoff, truncation_psi * ones, ones)
            dlatents = tflib.lerp(dlatent_avg, dlatents, coefs)

    # Evaluate synthesis network.
    with tf.control_dependencies([tf.assign(components.synthesis.find_var('lod'), lod_in)]):
        images_out = components.synthesis.get_output_for(dlatents, force_clean_graph=is_template_graph, **kwargs)
    return tf.identity(images_out, name='images_out')
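# Hedged numpy sketch of the truncation trick applied above: styles for layers
# below the cutoff are lerped toward the running average W. The shapes
# (18 layers, 512 dims) are illustrative.
def _example_truncation(psi=0.7, cutoff=8):
    w_avg = np.zeros([512], dtype=np.float32)
    w = np.random.randn(1, 18, 512).astype(np.float32)
    layer_idx = np.arange(18)[np.newaxis, :, np.newaxis]
    coefs = np.where(layer_idx < cutoff, psi, 1.0).astype(np.float32)
    return w_avg + (w - w_avg) * coefs  # same lerp as tflib.lerp(dlatent_avg, dlatents, coefs)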
def lerp_video(
        network_pkl,          # Path to pretrained model pkl file
        seeds,                # Random seeds
        grid_w=None,          # Number of columns
        grid_h=None,          # Number of rows
        truncation_psi=1.0,   # Truncation trick
        outdir='out',         # Output dir
        slowdown=1,           # Slowdown of the video (power of 2)
        duration_sec=30.0,    # Duration of video in seconds
        smoothing_sec=3.0,
        mp4_fps=30,
        mp4_codec="libx264",
        mp4_bitrate="16M"):
    # Sanity check regarding slowdown
    message = 'slowdown must be a power of 2 (1, 2, 4, 8, ...) and greater than 0!'
    assert slowdown & (slowdown - 1) == 0 and slowdown > 0, message
    # Initialize TensorFlow and create outdir
    tflib.init_tf()
    os.makedirs(outdir, exist_ok=True)
    # Total duration of video and number of frames to generate
    num_frames = int(np.rint(duration_sec * mp4_fps))
    total_duration = duration_sec * slowdown

    print(f'Loading network from {network_pkl}...')
    with dnnlib.util.open_url(network_pkl) as fp:
        _G, _D, Gs = pickle.load(fp)

    print("Generating latent vectors...")
    # If there's more than one seed provided and the shape isn't specified
    if grid_w is None and grid_h is None and len(seeds) >= 1:
        # number of images according to the seeds provided
        num = len(seeds)
        # Get the grid width and height according to num:
        grid_w = max(int(np.ceil(np.sqrt(num))), 1)
        grid_h = max((num - 1) // grid_w + 1, 1)
        grid_size = [grid_w, grid_h]
        # [frame, image, channel, component]:
        shape = [num_frames] + Gs.input_shape[1:]
        # Get the latents:
        all_latents = np.stack([np.random.RandomState(seed).randn(*shape).astype(np.float32) for seed in seeds], axis=1)
    # If only one seed is provided and the shape is specified
    elif None not in (grid_w, grid_h) and len(seeds) == 1:
        # Otherwise, the user gives one seed and the grid width and height:
        grid_size = [grid_w, grid_h]
        # [frame, image, channel, component]:
        shape = [num_frames, np.prod(grid_size)] + Gs.input_shape[1:]
        # Get the latents with the random state:
        random_state = np.random.RandomState(seeds)
        all_latents = random_state.randn(*shape).astype(np.float32)
    else:
        print("Error: wrong combination of arguments! Please provide either one seed "
              "and the grid width and height, or a list of seeds to use.")
        sys.exit(1)

    all_latents = scipy.ndimage.gaussian_filter(all_latents, [smoothing_sec * mp4_fps] + [0] * len(Gs.input_shape), mode="wrap")
    all_latents /= np.sqrt(np.mean(np.square(all_latents)))
    # Name of the final mp4 video
    mp4 = f"{grid_w}x{grid_h}-lerp-{slowdown}xslowdown.mp4"

    # Aux function to slowdown the video by 2x
    def double_slowdown(latents, duration_sec, num_frames):
        # Make an empty latent vector with double the amount of frames
        z = np.empty(np.multiply(latents.shape, [2, 1, 1]), dtype=np.float32)
        # Populate it
        for i in range(len(latents)):
            z[2 * i] = latents[i]
        # Interpolate in the odd frames
        for i in range(1, len(z), 2):
            # For the last frame, we loop to the first one
            if i == len(z) - 1:
                z[i] = (z[0] + z[i - 1]) / 2
            else:
                z[i] = (z[i - 1] + z[i + 1]) / 2
        # We also need to double the duration_sec and num_frames
        duration_sec *= 2
        num_frames *= 2
        # Return the new latents, and the two previous quantities
        return z, duration_sec, num_frames

    while slowdown > 1:
        all_latents, duration_sec, num_frames = double_slowdown(all_latents, duration_sec, num_frames)
        slowdown //= 2

    # Define the kwargs for the Generator:
    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    Gs_kwargs.randomize_noise = False
    if truncation_psi is not None:
        Gs_kwargs.truncation_psi = truncation_psi

    # Aux function: Frame generation func for moviepy.
    def make_frame(t):
        frame_idx = int(np.clip(np.round(t * mp4_fps), 0, num_frames - 1))
        latents = all_latents[frame_idx]
        # Get the images (with labels = None)
        images = Gs.run(latents, None, **Gs_kwargs)
        # Generate the grid for this timestamp:
        grid = create_image_grid(images, grid_size)
        # grayscale => RGB
        if grid.shape[2] == 1:
            grid = grid.repeat(3, 2)
        return grid

    # Generate video using make_frame:
    print(f'Generating interpolation video of length: {total_duration} seconds...')
    videoclip = moviepy.editor.VideoClip(make_frame, duration=duration_sec)
    videoclip.write_videofile(os.path.join(outdir, mp4), fps=mp4_fps, codec=mp4_codec, bitrate=mp4_bitrate)
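# Hedged usage sketch (placeholder pickle path): render a 2x2 grid
# interpolation at 2x slowdown.
def _example_lerp_video():
    lerp_video('ffhq.pkl', seeds=[0, 1, 2, 3], truncation_psi=0.7, slowdown=2, duration_sec=10.0)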
def main():
    os.makedirs(a.out_dir, exist_ok=True)

    # setup generator
    fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.func_name = 'training.stylegan2_multi.G_main'
    Gs_kwargs.verbose = a.verbose
    Gs_kwargs.size = a.size
    Gs_kwargs.scale_type = a.scale_type
    Gs_kwargs.impl = a.ops

    # load model with arguments
    sess = tflib.init_tf({'allow_soft_placement': True})
    pkl_name = osp.splitext(a.model)[0]
    with open(pkl_name + '.pkl', 'rb') as file:
        network = pickle.load(file, encoding='latin1')
    try:
        _, _, network = network
    except:
        pass
    for k in list(network.static_kwargs.keys()):
        Gs_kwargs[k] = network.static_kwargs[k]

    # reload custom network, if needed
    if '.pkl' in a.model.lower():
        print(' .. Gs from pkl ..', basename(a.model))
        Gs = network
    else:  # reconstruct network
        print(' .. Gs custom ..', basename(a.model))
        Gs = tflib.Network('Gs', **Gs_kwargs)
        Gs.copy_vars_from(network)

    z_dim = Gs.input_shape[1]
    dz_dim = 512  # dlatent_size
    try:
        dl_dim = 2 * (int(np.floor(np.log2(Gs_kwargs.resolution))) - 1)
    except:
        print(' Resave model, no resolution kwarg found!')
        exit(1)
    dlat_shape = (1, dl_dim, dz_dim)  # [1,18,512]

    # read saved latents
    if a.dlatents is not None and osp.isfile(a.dlatents):
        key_dlatents = load_latents(a.dlatents)
        if len(key_dlatents.shape) == 2:
            key_dlatents = np.expand_dims(key_dlatents, 0)
    elif a.dlatents is not None and osp.isdir(a.dlatents):
        # if a.dlatents.endswith('/') or a.dlatents.endswith('\\'): a.dlatents = a.dlatents[:-1]
        key_dlatents = []
        npy_list = file_list(a.dlatents, 'npy')
        for npy in npy_list:
            key_dlatent = load_latents(npy)
            if len(key_dlatent.shape) == 2:
                key_dlatent = np.expand_dims(key_dlatent, 0)
            key_dlatents.append(key_dlatent)
        key_dlatents = np.concatenate(key_dlatents)  # [frm,18,512]
    else:
        print(' No input dlatents found')
        exit()
    key_dlatents = key_dlatents[:, np.newaxis]  # [frm,1,18,512]
    print(' key dlatents', key_dlatents.shape)

    # replace higher layers with single (style) latent
    if a.style_dlat is not None:
        print(' styling with dlatent', a.style_dlat)
        style_dlatent = load_latents(a.style_dlat)
        while len(style_dlatent.shape) < 4:
            style_dlatent = np.expand_dims(style_dlatent, 0)
        # try replacing 5 by other value, less than dl_dim
        key_dlatents[:, :, range(5, dl_dim), :] = style_dlatent[:, :, range(5, dl_dim), :]

    frames = key_dlatents.shape[0] * a.fstep
    dlatents = latent_anima(dlat_shape, frames, a.fstep, key_latents=key_dlatents, cubic=a.cubic, verbose=True)  # [frm,1,18,512]
    print(' dlatents', dlatents.shape)
    frame_count = dlatents.shape[0]

    # truncation trick
    dlatent_avg = Gs.get_var('dlatent_avg')  # (512,)
    tr_range = range(0, 8)
    dlatents[:, :, tr_range, :] = dlatent_avg + (dlatents[:, :, tr_range, :] - dlatent_avg) * a.trunc

    # distort image by tweaking initial const layer
    if a.digress > 0:
        try:
            latent_size = Gs.static_kwargs['latent_size']
        except:
            latent_size = 512  # default latent size
        try:
            init_res = Gs.static_kwargs['init_res']
        except:
            init_res = (4, 4)  # default initial layer size
        dconst = a.digress * latent_anima([1, latent_size, *init_res], frames, a.fstep, cubic=True, verbose=False)
    else:
        dconst = np.zeros([frame_count, 1, 1, 1, 1])

    # generate images from latent timeline
    pbar = ProgressBar(frame_count)
    for i in range(frame_count):
        # generate multi-latent result
        if Gs.num_inputs == 2:
            output = Gs.components.synthesis.run(dlatents[i], randomize_noise=False, output_transform=fmt, minibatch_size=1)
        else:
            output = Gs.components.synthesis.run(dlatents[i], [None], dconst[i], randomize_noise=False, output_transform=fmt, minibatch_size=1)
        ext = 'png' if output.shape[3] == 4 else 'jpg'
        filename = osp.join(a.out_dir, "%06d.%s" % (i, ext))
        imsave(filename, output[0])
        pbar.upd()
def run(self, network_pkl, run_dir=None, data_dir=None, dataset_args=None, mirror_augment=None,
        num_gpus=1, tf_config=None, log_results=True, num_repeats=1,
        Gs_kwargs=dict(is_validation=True), resume_with_new_nets=False, truncations=[None]):
    self._reset(network_pkl=network_pkl, run_dir=run_dir, data_dir=data_dir, dataset_args=dataset_args, mirror_augment=mirror_augment)
    with tf.Graph().as_default(), tflib.create_session(tf_config).as_default():  # pylint: disable=not-context-manager
        self._report_progress(0, 1)
        _G, _D, Gs = misc.load_pkl(self._network_pkl)

        if resume_with_new_nets:
            dataset = self._get_dataset_obj()
            G = dnnlib.tflib.Network('G', num_channels=dataset.shape[0], resolution=dataset.shape[1],
                                     label_size=dataset.label_size, func_name='training.co_mod_gan.G_main',
                                     pix2pix=dataset.pix2pix)
            Gs_new = G.clone('Gs')
            Gs_new.copy_vars_from(Gs)
            Gs = Gs_new

        for t in truncations:
            print('truncation={}'.format(t))
            self._results = []
            time_begin = time.time()
            Gs_kwargs.update(truncation_psi_val=t)
            self._evaluate(Gs, Gs_kwargs=Gs_kwargs, num_gpus=num_gpus)
            self._report_progress(1, 1)

            if num_repeats > 1:
                records = [dnnlib.EasyDict(value=[res.value], suffix=res.suffix, fmt=res.fmt) for res in self._results]
                for i in range(1, num_repeats):
                    print(self.get_result_str().strip())
                    self._results = []
                    self._report_progress(0, 1)
                    self._evaluate(Gs, Gs_kwargs=Gs_kwargs, num_gpus=num_gpus)
                    self._report_progress(1, 1)
                    for rec, res in zip(records, self._results):
                        rec.value.append(res.value)
                self._results = []
                for rec in records:
                    self._report_result(np.mean(rec.value), rec.suffix, rec.fmt)
                    self._report_result(np.std(rec.value), rec.suffix + '-std', rec.fmt)

            self._eval_time = time.time() - time_begin  # pylint: disable=attribute-defined-outside-init

            if log_results:
                if run_dir is not None:
                    log_file = os.path.join(run_dir, 'metric-%s.txt' % self.name)
                    with dnnlib.util.Logger(log_file, 'a'):
                        print(self.get_result_str().strip())
                else:
                    print(self.get_result_str().strip())
def embed(batch_size, resolution, img, network, iteration, seed=6600):
    tf.reset_default_graph()
    print('Loading networks from "%s"...' % network)
    tflib.init_tf()
    _, _, G = pretrained_networks.load_networks(network)
    img_in = tf.constant(img)
    opt = tf.train.AdamOptimizer(learning_rate=0.01, beta1=0.9, beta2=0.999, epsilon=1e-8)
    noise_vars = [var for name, var in G.components.synthesis.vars.items() if name.startswith('noise')]
    G_kwargs = dnnlib.EasyDict()
    G_kwargs.randomize_noise = False
    G_syn = G.components.synthesis

    loss_list = []
    p_loss_list = []
    m_loss_list = []
    dl_list = []
    si_list = []

    rnd = np.random.RandomState(seed)
    dlatent_avg = [var for name, var in G.vars.items() if name.startswith('dlatent_avg')][0].eval()
    dlatent_avg = np.expand_dims(np.expand_dims(dlatent_avg, 0), 1)
    dlatent_avg = dlatent_avg.repeat(12, 1)
    dlatent = tf.get_variable('dlatent', dtype=tf.float32, initializer=tf.constant(dlatent_avg), trainable=True)
    synth_img = G_syn.get_output_for(dlatent, is_training=False, **G_kwargs)
    # synth_img = (synth_img + 1.0) / 2.0

    with tf.variable_scope('mse_loss'):
        mse_loss = tf.reduce_mean(tf.square(img_in - synth_img))
    with tf.variable_scope('perceptual_loss'):
        vgg_in = tf.concat([img_in, synth_img], 0)
        tf.keras.backend.set_image_data_format('channels_first')
        vgg = tf.keras.applications.VGG16(include_top=False, input_tensor=vgg_in,
                                          input_shape=(3, 128, 128),
                                          weights='/gdata2/fengrl/metrics/vgg.h5', pooling=None)
        h1 = vgg.get_layer('block1_conv1').output
        h2 = vgg.get_layer('block1_conv2').output
        h3 = vgg.get_layer('block3_conv2').output
        h4 = vgg.get_layer('block4_conv2').output
        pcep_loss = tf.reduce_mean(tf.square(h1[0] - h1[1])) + tf.reduce_mean(tf.square(h2[0] - h2[1])) + \
                    tf.reduce_mean(tf.square(h3[0] - h3[1])) + tf.reduce_mean(tf.square(h4[0] - h4[1]))
    loss = 0.5 * mse_loss + 0.5 * pcep_loss
    with tf.control_dependencies([loss]):
        train_op = opt.minimize(loss, var_list=[dlatent])
    tflib.init_uninitialized_vars()
    # rnd = np.random.RandomState(seed)
    tflib.set_vars({var: rnd.randn(*var.shape.as_list()) for var in noise_vars})  # [height, width]

    for i in range(iteration):
        loss_, p_loss_, m_loss_, dl_, si_, _ = tflib.run([loss, pcep_loss, mse_loss, dlatent, synth_img, train_op])
        loss_list.append(loss_)
        p_loss_list.append(p_loss_)
        m_loss_list.append(m_loss_)
        dl_loss_ = np.sum(np.square(dl_ - dlatent_avg))
        dl_list.append(dl_loss_)
        if i % 500 == 0:
            si_list.append(si_)
        if i % 100 == 0:
            print('Loss %f, mse %f, ppl %f, dl %f, step %d' % (loss_, m_loss_, p_loss_, dl_loss_, i))
    return loss_list, p_loss_list, m_loss_list, dl_list, si_list
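# Usage sketch (assumption: `img` must be an NCHW float array in the generator's
# output range whose spatial size matches the network; the pickle path is a
# placeholder; the random stand-in below is only for shape-checking the graph):
#
# img = np.random.randn(1, 3, 128, 128).astype(np.float32)
# losses, p_losses, m_losses, dl, snaps = embed(1, 128, img, 'ffhq-128.pkl', iteration=1000)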
import os
import sys
import warnings
import numpy as np
import torch
import dnnlib
import traceback

from .. import custom_ops
from .. import misc

#----------------------------------------------------------------------------

activation_funcs = {
    'linear':   dnnlib.EasyDict(func=lambda x, **_: x,                                           def_alpha=0,   def_gain=1,          cuda_idx=1, ref='',  has_2nd_grad=False),
    'relu':     dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x),                 def_alpha=0,   def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
    'lrelu':    dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
    'tanh':     dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x),                               def_alpha=0,   def_gain=1,          cuda_idx=4, ref='y', has_2nd_grad=True),
    'sigmoid':  dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x),                            def_alpha=0,   def_gain=1,          cuda_idx=5, ref='y', has_2nd_grad=True),
    'elu':      dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x),                  def_alpha=0,   def_gain=1,          cuda_idx=6, ref='y', has_2nd_grad=True),
    'selu':     dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x),                 def_alpha=0,   def_gain=1,          cuda_idx=7, ref='y', has_2nd_grad=True),
    'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x),             def_alpha=0,   def_gain=1,          cuda_idx=8, ref='y', has_2nd_grad=True),
    'swish':    dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x,                        def_alpha=0,   def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),  # entry completed from the pattern above; the listing was truncated here
}
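# Minimal sketch of how this table is meant to be consumed (an assumption based
# on the spec fields, not this module's actual CUDA path): each spec pairs an
# elementwise function with a default alpha and a gain that compensates for the
# activation's effect on signal magnitude. `_bias_act_ref` below is hypothetical.
def _bias_act_ref(x, b=None, act='lrelu', alpha=None, gain=None):
    spec = activation_funcs[act]
    alpha = spec.def_alpha if alpha is None else alpha
    gain = spec.def_gain if gain is None else gain
    if b is not None:
        # broadcast a per-channel bias over the remaining dims (assumes channels at dim 1)
        x = x + b.reshape([1, -1] + [1] * (x.ndim - 2))
    return spec.func(x, alpha=alpha) * gain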
def training_loop(
    run_dir                 = '.',      # Output directory.
    training_set_kwargs     = {},       # Options for training set.
    data_loader_kwargs      = {},       # Options for torch.utils.data.DataLoader.
    G_kwargs                = {},       # Options for generator network.
    D_kwargs                = {},       # Options for discriminator network.
    G_opt_kwargs            = {},       # Options for generator optimizer.
    D_opt_kwargs            = {},       # Options for discriminator optimizer.
    augment_kwargs          = None,     # Options for augmentation pipeline. None = disable.
    loss_kwargs             = {},       # Options for loss function.
    metrics                 = [],       # Metrics to evaluate during training.
    random_seed             = 0,        # Global random seed.
    num_gpus                = 1,        # Number of GPUs participating in the training.
    rank                    = 0,        # Rank of the current process in [0, num_gpus[.
    batch_size              = 4,        # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus.
    batch_gpu               = 4,        # Number of samples processed at a time by one GPU.
    ema_kimg                = 10,       # Half-life of the exponential moving average (EMA) of generator weights.
    ema_rampup              = None,     # EMA ramp-up coefficient.
    G_reg_interval          = 4,        # How often to perform regularization for G? None = disable lazy regularization.
    D_reg_interval          = 16,       # How often to perform regularization for D? None = disable lazy regularization.
    augment_p               = 0,        # Initial value of augmentation probability.
    ada_target              = None,     # ADA target value. None = fixed p.
    ada_interval            = 4,        # How often to perform ADA adjustment?
    ada_kimg                = 500,      # ADA adjustment speed, measured in how many kimg it takes for p to increase/decrease by one unit.
    nimg                    = 0,        # Current image count, used when resuming.
    total_kimg              = 25000,    # Total length of the training, measured in thousands of real images.
    kimg_per_tick           = 4,        # Progress snapshot interval.
    image_snapshot_ticks    = 50,       # How often to save image snapshots? None = disable.
    network_snapshot_ticks  = 50,       # How often to save network snapshots? None = disable.
    resume_pkl              = None,     # Network pickle to resume training from.
    cudnn_benchmark         = True,     # Enable torch.backends.cudnn.benchmark?
    allow_tf32              = False,    # Enable torch.backends.cuda.matmul.allow_tf32 and torch.backends.cudnn.allow_tf32?
    abort_fn                = None,     # Callback function for determining whether to abort training. Must return consistent results across ranks.
    progress_fn             = None,     # Callback function for updating training progress. Called for all ranks.
):
    # Initialize.
    start_time = time.time()
    device = torch.device('cuda', rank)
    np.random.seed(random_seed * num_gpus + rank)
    torch.manual_seed(random_seed * num_gpus + rank)
    torch.backends.cudnn.benchmark = cudnn_benchmark    # Improves training speed.
    torch.backends.cuda.matmul.allow_tf32 = allow_tf32  # Allow PyTorch to internally use tf32 for matmul.
    torch.backends.cudnn.allow_tf32 = allow_tf32        # Allow PyTorch to internally use tf32 for convolutions.
    conv2d_gradfix.enabled = True                       # Improves training speed.
    grid_sample_gradfix.enabled = True                  # Avoids errors with the augmentation pipe.

    # Load training set.
    if rank == 0:
        print('Loading training set...')
    training_set = dnnlib.util.construct_class_by_name(**training_set_kwargs)  # subclass of training.dataset.Dataset
    training_set_sampler = misc.InfiniteSampler(dataset=training_set, rank=rank, num_replicas=num_gpus, seed=random_seed)
    training_set_iterator = iter(torch.utils.data.DataLoader(dataset=training_set, sampler=training_set_sampler,
                                                             batch_size=batch_size // num_gpus, **data_loader_kwargs))
    if rank == 0:
        print()
        print('Num images: ', len(training_set))
        print('Image shape:', training_set.image_shape)
        print('Label shape:', training_set.label_shape)
        print()

    # Construct networks.
    if rank == 0:
        print('Constructing networks...')
    common_kwargs = dict(c_dim=training_set.label_dim, img_resolution=training_set.resolution,
                         img_channels=training_set.num_channels)
    G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs).train().requires_grad_(False).to(device)  # subclass of torch.nn.Module
    D = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs).train().requires_grad_(False).to(device)  # subclass of torch.nn.Module
    G_ema = copy.deepcopy(G).eval()

    # Resume from existing pickle.
    if (resume_pkl is not None) and (rank == 0):
        print(f'Resuming from "{resume_pkl}"')
        with dnnlib.util.open_url(resume_pkl) as f:
            resume_data = legacy.load_network_pkl(f)
        for name, module in [('G', G), ('D', D), ('G_ema', G_ema)]:
            misc.copy_params_and_buffers(resume_data[name], module, require_all=False)

    # Print network summary tables.
    if rank == 0:
        z = torch.empty([batch_gpu, G.z_dim], device=device)
        c = torch.empty([batch_gpu, G.c_dim], device=device)
        img = misc.print_module_summary(G, [z, c])
        misc.print_module_summary(D, [img, c])

    # Setup augmentation.
    if rank == 0:
        print('Setting up augmentation...')
    augment_pipe = None
    ada_stats = None
    if (augment_kwargs is not None) and (augment_p > 0 or ada_target is not None):
        augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs).train().requires_grad_(False).to(device)  # subclass of torch.nn.Module
        augment_pipe.p.copy_(torch.as_tensor(augment_p))
        if ada_target is not None:
            ada_stats = training_stats.Collector(regex='Loss/signs/real')

    # Distribute across GPUs.
    if rank == 0:
        print(f'Distributing across {num_gpus} GPUs...')
    ddp_modules = dict()
    for name, module in [('G_mapping', G.mapping), ('G_synthesis', G.synthesis), ('D', D),
                         (None, G_ema), ('augment_pipe', augment_pipe)]:
        if (num_gpus > 1) and (module is not None) and len(list(module.parameters())) != 0:
            module.requires_grad_(True)
            module = torch.nn.parallel.DistributedDataParallel(module, device_ids=[device], broadcast_buffers=False)
            module.requires_grad_(False)
        if name is not None:
            ddp_modules[name] = module

    # Setup training phases.
    if rank == 0:
        print('Setting up training phases...')
    loss = dnnlib.util.construct_class_by_name(device=device, **ddp_modules, **loss_kwargs)  # subclass of training.loss.Loss
    phases = []
    for name, module, opt_kwargs, reg_interval in [('G', G, G_opt_kwargs, G_reg_interval),
                                                   ('D', D, D_opt_kwargs, D_reg_interval)]:
        if reg_interval is None:
            opt = dnnlib.util.construct_class_by_name(params=module.parameters(), **opt_kwargs)  # subclass of torch.optim.Optimizer
            phases += [dnnlib.EasyDict(name=name + 'both', module=module, opt=opt, interval=1)]
        else:  # Lazy regularization.
            mb_ratio = reg_interval / (reg_interval + 1)
            opt_kwargs = dnnlib.EasyDict(opt_kwargs)
            opt_kwargs.lr = opt_kwargs.lr * mb_ratio
            opt_kwargs.betas = [beta ** mb_ratio for beta in opt_kwargs.betas]
            opt = dnnlib.util.construct_class_by_name(module.parameters(), **opt_kwargs)  # subclass of torch.optim.Optimizer
            phases += [dnnlib.EasyDict(name=name + 'main', module=module, opt=opt, interval=1)]
            phases += [dnnlib.EasyDict(name=name + 'reg', module=module, opt=opt, interval=reg_interval)]
    for phase in phases:
        phase.start_event = None
        phase.end_event = None
        if rank == 0:
            phase.start_event = torch.cuda.Event(enable_timing=True)
            phase.end_event = torch.cuda.Event(enable_timing=True)

    # Export sample images.
    grid_size = None
    grid_z = None
    grid_c = None
    if rank == 0:
        print('Exporting sample images...')
        grid_size, images, labels = setup_snapshot_image_grid(training_set=training_set)
        save_image_grid(images, os.path.join(run_dir, 'reals.jpg'), drange=[0, 255], grid_size=grid_size)
        grid_z = torch.randn([labels.shape[0], G.z_dim], device=device).split(batch_gpu)
        grid_c = torch.from_numpy(labels).to(device).split(batch_gpu)
        images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy()
        save_image_grid(images, os.path.join(run_dir, 'fakes_init.jpg'), drange=[-1, 1], grid_size=grid_size)

    # Initialize logs.
    if rank == 0:
        print('Initializing logs...')
    stats_collector = training_stats.Collector(regex='.*')
    stats_metrics = dict()
    stats_jsonl = None
    stats_tfevents = None
    if rank == 0:
        stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'wt')
        try:
            import torch.utils.tensorboard as tensorboard
            stats_tfevents = tensorboard.SummaryWriter(run_dir)
        except ImportError as err:
            print('Skipping tfevents export:', err)

    # Train.
    if rank == 0:
        print(f'Training for {total_kimg} kimg...')
        print()
    cur_nimg = nimg
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    maintenance_time = tick_start_time - start_time
    batch_idx = 0
    if progress_fn is not None:
        progress_fn(0, total_kimg)
    while True:

        # Fetch training data.
        with torch.autograd.profiler.record_function('data_fetch'):
            phase_real_img, phase_real_c = next(training_set_iterator)
            phase_real_img = (phase_real_img.to(device).to(torch.float32) / 127.5 - 1).split(batch_gpu)
            phase_real_c = phase_real_c.to(device).split(batch_gpu)
            all_gen_z = torch.randn([len(phases) * batch_size, G.z_dim], device=device)
            all_gen_z = [phase_gen_z.split(batch_gpu) for phase_gen_z in all_gen_z.split(batch_size)]
            all_gen_c = [training_set.get_label(np.random.randint(len(training_set)))
                         for _ in range(len(phases) * batch_size)]
            all_gen_c = torch.from_numpy(np.stack(all_gen_c)).pin_memory().to(device)
            all_gen_c = [phase_gen_c.split(batch_gpu) for phase_gen_c in all_gen_c.split(batch_size)]

        # Execute training phases.
        for phase, phase_gen_z, phase_gen_c in zip(phases, all_gen_z, all_gen_c):
            if batch_idx % phase.interval != 0:
                continue

            # Initialize gradient accumulation.
            if phase.start_event is not None:
                phase.start_event.record(torch.cuda.current_stream(device))
            phase.opt.zero_grad(set_to_none=True)
            phase.module.requires_grad_(True)

            # Accumulate gradients over multiple rounds.
            for round_idx, (real_img, real_c, gen_z, gen_c) in enumerate(
                    zip(phase_real_img, phase_real_c, phase_gen_z, phase_gen_c)):
                sync = (round_idx == batch_size // (batch_gpu * num_gpus) - 1)
                gain = phase.interval
                loss.accumulate_gradients(phase=phase.name, real_img=real_img, real_c=real_c,
                                          gen_z=gen_z, gen_c=gen_c, sync=sync, gain=gain)

            # Update weights.
            phase.module.requires_grad_(False)
            with torch.autograd.profiler.record_function(phase.name + '_opt'):
                for param in phase.module.parameters():
                    if param.grad is not None:
                        misc.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)
                phase.opt.step()
            if phase.end_event is not None:
                phase.end_event.record(torch.cuda.current_stream(device))

        # Update G_ema.
        with torch.autograd.profiler.record_function('Gema'):
            ema_nimg = ema_kimg * 1000
            if ema_rampup is not None:
                ema_nimg = min(ema_nimg, cur_nimg * ema_rampup)
            ema_beta = 0.5 ** (batch_size / max(ema_nimg, 1e-8))
            for p_ema, p in zip(G_ema.parameters(), G.parameters()):
                p_ema.copy_(p.lerp(p_ema, ema_beta))
            for b_ema, b in zip(G_ema.buffers(), G.buffers()):
                b_ema.copy_(b)

        # Update state.
        cur_nimg += batch_size
        batch_idx += 1

        # Execute ADA heuristic.
        if (ada_stats is not None) and (batch_idx % ada_interval == 0):
            ada_stats.update()
            adjust = np.sign(ada_stats['Loss/signs/real'] - ada_target) * \
                     (batch_size * ada_interval) / (ada_kimg * 1000)
            augment_pipe.p.copy_((augment_pipe.p + adjust).max(misc.constant(0, device=device)))

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
            continue

        # Print status line, accumulating the same information in stats_collector.
        tick_end_time = time.time()
        fields = []
        fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"]
        fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<8.1f}"]
        fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"]
        fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"]
        fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"]
        fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"]
        fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"]
        fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"]
        torch.cuda.reset_peak_memory_stats()
        fields += [f"augment {training_stats.report0('Progress/augment', float(augment_pipe.p.cpu()) if augment_pipe is not None else 0):.3f}"]
        training_stats.report0('Timing/total_hours', (tick_end_time - start_time) / (60 * 60))
        training_stats.report0('Timing/total_days', (tick_end_time - start_time) / (24 * 60 * 60))
        if rank == 0:
            print(' '.join(fields))

        # Check for abort.
        if (not done) and (abort_fn is not None) and abort_fn():
            done = True
            if rank == 0:
                print()
                print('Aborting...')

        # Save image snapshot.
        if (rank == 0) and (image_snapshot_ticks is not None) and (done or cur_tick % image_snapshot_ticks == 0):
            images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy()
            save_image_grid(images, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}.jpg'),
                            drange=[-1, 1], grid_size=grid_size)

        # Save network snapshot.
        snapshot_pkl = None
        snapshot_data = None
        if (network_snapshot_ticks is not None) and (done or cur_tick % network_snapshot_ticks == 0):
            snapshot_data = dict(training_set_kwargs=dict(training_set_kwargs))
            for name, module in [('G', G), ('D', D), ('G_ema', G_ema), ('augment_pipe', augment_pipe)]:
                if module is not None:
                    if num_gpus > 1:
                        misc.check_ddp_consistency(module, ignore_regex=r'.*\.w_avg')
                    module = copy.deepcopy(module).eval().requires_grad_(False).cpu()
                snapshot_data[name] = module
                del module  # conserve memory
            snapshot_pkl = os.path.join(run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl')
            if rank == 0:
                with open(snapshot_pkl, 'wb') as f:
                    pickle.dump(snapshot_data, f)

        # Evaluate metrics.
        if (snapshot_data is not None) and (len(metrics) > 0):
            if rank == 0:
                print('Evaluating metrics...')
            for metric in metrics:
                result_dict = metric_main.calc_metric(metric=metric, G=snapshot_data['G_ema'],
                                                      dataset_kwargs=training_set_kwargs,
                                                      num_gpus=num_gpus, rank=rank, device=device)
                if rank == 0:
                    metric_main.report_metric(result_dict, run_dir=run_dir, snapshot_pkl=snapshot_pkl)
                stats_metrics.update(result_dict.results)
        del snapshot_data  # conserve memory

        # Collect statistics.
        for phase in phases:
            value = []
            if (phase.start_event is not None) and (phase.end_event is not None):
                phase.end_event.synchronize()
                value = phase.start_event.elapsed_time(phase.end_event)
            training_stats.report0('Timing/' + phase.name, value)
        stats_collector.update()
        stats_dict = stats_collector.as_dict()

        # Update logs.
        timestamp = time.time()
        if stats_jsonl is not None:
            fields = dict(stats_dict, timestamp=timestamp)
            stats_jsonl.write(json.dumps(fields) + '\n')
            stats_jsonl.flush()
        if stats_tfevents is not None:
            global_step = int(cur_nimg / 1e3)
            walltime = timestamp - start_time
            for name, value in stats_dict.items():
                stats_tfevents.add_scalar(name, value.mean, global_step=global_step, walltime=walltime)
            for name, value in stats_metrics.items():
                stats_tfevents.add_scalar(f'Metrics/{name}', value, global_step=global_step, walltime=walltime)
            stats_tfevents.flush()
        if progress_fn is not None:
            progress_fn(cur_nimg // 1000, total_kimg)

        # Update state.
        cur_tick += 1
        tick_start_nimg = cur_nimg
        tick_start_time = time.time()
        maintenance_time = tick_start_time - tick_end_time
        if done:
            break

    # Done.
    if rank == 0:
        print()
        print('Exiting...')
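# A minimal single-GPU launcher sketch for training_loop(), under assumptions:
# a CUDA device is available, the script runs from the repo root, the class
# names match what construct_class_by_name() resolves in this codebase, and
# 'datasets/example.zip' is a placeholder dataset path.
def launch_training_example():
    training_loop(
        run_dir='training-runs/example',
        training_set_kwargs=dict(class_name='training.dataset.ImageFolderDataset',
                                 path='datasets/example.zip', use_labels=False, xflip=True),
        data_loader_kwargs=dict(pin_memory=True, num_workers=3, prefetch_factor=2),
        G_kwargs=dict(class_name='training.networks.Generator', z_dim=512, w_dim=512,
                      mapping_kwargs=dict(num_layers=8), synthesis_kwargs=dict()),
        D_kwargs=dict(class_name='training.networks.Discriminator',
                      block_kwargs=dict(), mapping_kwargs=dict(),
                      epilogue_kwargs=dict(mbstd_group_size=4)),
        G_opt_kwargs=dict(class_name='torch.optim.Adam', lr=0.002, betas=[0, 0.99], eps=1e-8),
        D_opt_kwargs=dict(class_name='torch.optim.Adam', lr=0.002, betas=[0, 0.99], eps=1e-8),
        loss_kwargs=dict(class_name='training.loss.StyleGAN2Loss', r1_gamma=10),
        batch_size=4, batch_gpu=4, total_kimg=25,  # tiny run for smoke-testing only
    )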
def embed(batch_size, resolution, imgs, network, iteration, result_dir, seed=6600):
    tf.reset_default_graph()
    print('Loading networks from "%s"...' % network)
    tflib.init_tf()
    _, _, G = pretrained_networks.load_networks(network)
    img_in = tf.placeholder(tf.float32)
    opt = tf.train.AdamOptimizer(learning_rate=0.01, beta1=0.9, beta2=0.999, epsilon=1e-8)
    noise_vars = [var for name, var in G.components.synthesis.vars.items() if name.startswith('noise')]
    alpha_vars = [var for name, var in G.components.synthesis.vars.items() if name.endswith('alpha')]
    alpha_eval = [alpha.eval() for alpha in alpha_vars]
    G_kwargs = dnnlib.EasyDict()
    G_kwargs.randomize_noise = False
    G_syn = G.components.synthesis

    rnd = np.random.RandomState(seed)
    dlatent_avg = [var for name, var in G.vars.items() if name.startswith('dlatent_avg')][0].eval()
    dlatent_avg = np.expand_dims(np.expand_dims(dlatent_avg, 0), 1)
    dlatent_avg = dlatent_avg.repeat(12, 1)
    dlatent = tf.get_variable('dlatent', dtype=tf.float32, initializer=tf.constant(dlatent_avg), trainable=True)
    synth_img = G_syn.get_output_for(dlatent, is_training=False, **G_kwargs)
    # synth_img = (synth_img + 1.0) / 2.0

    with tf.variable_scope('mse_loss'):
        mse_loss = tf.reduce_mean(tf.square(img_in - synth_img))
    with tf.variable_scope('perceptual_loss'):
        vgg_in = tf.concat([img_in, synth_img], 0)
        tf.keras.backend.set_image_data_format('channels_first')
        vgg = tf.keras.applications.VGG16(include_top=False, input_tensor=vgg_in,
                                          input_shape=(3, 128, 128),
                                          weights='/gdata2/fengrl/metrics/vgg.h5', pooling=None)
        h1 = vgg.get_layer('block1_conv1').output
        h2 = vgg.get_layer('block1_conv2').output
        h3 = vgg.get_layer('block3_conv2').output
        h4 = vgg.get_layer('block4_conv2').output
        pcep_loss = tf.reduce_mean(tf.square(h1[0] - h1[1])) + tf.reduce_mean(tf.square(h2[0] - h2[1])) + \
                    tf.reduce_mean(tf.square(h3[0] - h3[1])) + tf.reduce_mean(tf.square(h4[0] - h4[1]))
    loss = 0.5 * mse_loss + 0.5 * pcep_loss
    with tf.control_dependencies([loss]):
        train_op = opt.minimize(loss, var_list=[dlatent])
    reset_opt = tf.variables_initializer(opt.variables())
    reset_dl = tf.variables_initializer([dlatent])
    tflib.init_uninitialized_vars()
    # rnd = np.random.RandomState(seed)
    tflib.set_vars({var: rnd.randn(*var.shape.as_list()) for var in noise_vars})  # [height, width]

    idx = 0
    metrics_l = []
    metrics_p = []
    metrics_m = []
    metrics_d = []
    metrics_args = [metric_defaults[x] for x in ['fid50k', 'ppl_wend']]
    metrics_fun = metric_base.MetricGroup(metrics_args)
    for temperature in [0.2, 0.5, 1.0, 1.5, 2.0, 10.0]:
        tflib.set_vars({alpha: scale_alpha(alpha_np, temperature)
                        for alpha, alpha_np in zip(alpha_vars, alpha_eval)})
        # misc.save_pkl((G, G, G), os.path.join(result_dir, 'temp%f.pkl' % temperature))
        # metrics_fun.run(os.path.join(result_dir, 'temp%f.pkl' % temperature), run_dir=result_dir,
        #                 data_dir='/gdata/fengrl/noise_test_dset/tfrecords',
        #                 dataset_args=dnnlib.EasyDict(tfrecord_dir='ffhq-128', shuffle_mb=0),
        #                 mirror_augment=True, num_gpus=1)
        for img in imgs:
            img = np.expand_dims(img, 0)
            loss_list = []
            p_loss_list = []
            m_loss_list = []
            dl_list = []
            si_list = []
            tflib.run([reset_opt, reset_dl])
            for i in range(iteration):
                loss_, p_loss_, m_loss_, dl_, si_, _ = tflib.run(
                    [loss, pcep_loss, mse_loss, dlatent, synth_img, train_op], {img_in: img})
                loss_list.append(loss_)
                p_loss_list.append(p_loss_)
                m_loss_list.append(m_loss_)
                dl_loss_ = np.sum(np.square(dl_ - dlatent_avg))
                dl_list.append(dl_loss_)
                if i % 500 == 0:
                    si_list.append(si_)
                if i % 100 == 0:
                    print('Temperature %f, idx %d, Loss %f, mse %f, ppl %f, dl %f, step %d' %
                          (temperature, idx, loss_, m_loss_, p_loss_, dl_loss_, i))
            print('Temperature %f, idx %d, loss: %f, ppl: %f, mse: %f, d: %f' %
                  (temperature, idx, loss_list[-1], p_loss_list[-1], m_loss_list[-1], dl_list[-1]))
            metrics_l.append(loss_list[-1])
            metrics_p.append(p_loss_list[-1])
            metrics_m.append(m_loss_list[-1])
            metrics_d.append(dl_list[-1])
            misc.save_image_grid(np.concatenate(si_list, 0),
                                 os.path.join(result_dir, 'temp%fsi%d.png' % (temperature, idx)),
                                 drange=[-1, 1])
            misc.save_image_grid(si_list[-1],
                                 os.path.join(result_dir, 'temp%fsifinal%d.png' % (temperature, idx)),
                                 drange=[-1, 1])
            with open(os.path.join(result_dir, 'temp%fmetric_l%d.txt' % (temperature, idx)), 'w') as f:
                for l_ in loss_list:
                    f.write(str(l_) + '\n')
            with open(os.path.join(result_dir, 'temp%fmetric_p%d.txt' % (temperature, idx)), 'w') as f:
                for l_ in p_loss_list:
                    f.write(str(l_) + '\n')
            with open(os.path.join(result_dir, 'temp%fmetric_m%d.txt' % (temperature, idx)), 'w') as f:
                for l_ in m_loss_list:
                    f.write(str(l_) + '\n')
            with open(os.path.join(result_dir, 'temp%fmetric_d%d.txt' % (temperature, idx)), 'w') as f:
                for l_ in dl_list:
                    f.write(str(l_) + '\n')
            idx += 1
        l_mean = np.mean(metrics_l)
        p_mean = np.mean(metrics_p)
        m_mean = np.mean(metrics_m)
        d_mean = np.mean(metrics_d)
        with open(os.path.join(result_dir, 'Temp%fmetric_lmpd.txt' % temperature), 'w') as f:
            for i in range(len(metrics_l)):
                f.write(str(metrics_l[i]) + ' ' + str(metrics_m[i]) + ' ' +
                        str(metrics_p[i]) + ' ' + str(metrics_d[i]) + '\n')
        print('Overall metrics: temp %f, loss_mean %f, ppl_mean %f, mse_mean %f, d_mean %f' %
              (temperature, l_mean, p_mean, m_mean, d_mean))
        with open(os.path.join(result_dir, 'mean_metrics'), 'a') as f:
            f.write('Temperature %f\n' % temperature)
            f.write('loss %f\n' % l_mean)
            f.write('mse %f\n' % m_mean)
            f.write('ppl %f\n' % p_mean)
            f.write('dl %f\n' % d_mean)
batch_size = {256: 16, 512: 9, 1024: 4}
n_sample = batch_size.get(size, 25)

g = g.to(device)

z = np.random.RandomState(0).randn(n_sample, 512).astype("float32")

with torch.no_grad():
    img_pt, _ = g(
        [torch.from_numpy(z).to(device)],
        truncation=0.5,
        truncation_latent=latent_avg.to(device),
        use_fixed_noise=True,
    )

    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.use_fixed_noise = True
    img_tf = g_ema.run(z, None, **Gs_kwargs)
    img_tf = torch.from_numpy(img_tf).to(device)

    img_diff = ((img_pt + 1) / 2).clamp(0.0, 1.0) - ((img_tf.to(device) + 1) / 2).clamp(0.0, 1.0)

    img_concat = torch.cat((img_tf, img_pt, img_diff), dim=0)
    print(img_diff.abs().max())

    utils.save_image(img_concat, name + ".png", nrow=n_sample, normalize=True, range=(-1, 1))
def calc_metrics(network_pkl, metric_names, metricdata, mirror, gpus):
    tflib.init_tf()

    # Initialize metrics.
    metrics = []
    for name in metric_names:
        if name not in metric_defaults.metric_defaults:
            raise UserError('\n'.join(['--metrics can only contain the following values:', 'none'] +
                                      list(metric_defaults.metric_defaults.keys())))
        metrics.append(dnnlib.util.construct_class_by_name(**metric_defaults.metric_defaults[name]))

    # Load network.
    if not dnnlib.util.is_url(network_pkl, allow_file_urls=True) and not os.path.isfile(network_pkl):
        raise UserError('--network must point to a file or URL')
    print(f'Loading network from "{network_pkl}"...')
    with dnnlib.util.open_url(network_pkl) as f:
        _G, _D, Gs = pickle.load(f)
    Gs.print_layers()

    # Look up training options.
    run_dir = None
    training_options = None
    if os.path.isfile(network_pkl):
        potential_run_dir = os.path.dirname(network_pkl)
        potential_json_file = os.path.join(potential_run_dir, 'training_options.json')
        if os.path.isfile(potential_json_file):
            print(f'Looking up training options from "{potential_json_file}"...')
            run_dir = potential_run_dir
            with open(potential_json_file, 'rt') as f:
                training_options = json.load(f, object_pairs_hook=dnnlib.EasyDict)
    if training_options is None:
        print('Could not look up training options; will rely on --metricdata and --mirror')

    # Choose dataset options.
    dataset_options = dnnlib.EasyDict()
    if training_options is not None:
        dataset_options.update(training_options.metric_dataset_args)
    dataset_options.resolution = Gs.output_shapes[0][-1]
    dataset_options.max_label_size = Gs.input_shapes[1][-1]
    if metricdata is not None:
        if not os.path.isdir(metricdata):
            raise UserError('--metricdata must point to a directory containing *.tfrecords')
        dataset_options.path = metricdata
    if mirror is not None:
        dataset_options.mirror_augment = mirror
    if 'path' not in dataset_options:
        raise UserError('--metricdata must be specified explicitly')

    # Print dataset options.
    print()
    print('Dataset options:')
    print(json.dumps(dataset_options, indent=2))

    # Evaluate metrics.
    for metric in metrics:
        print()
        print(f'Evaluating {metric.name}...')
        metric.configure(dataset_args=dataset_options, run_dir=run_dir)
        metric.run(network_pkl=network_pkl, num_gpus=gpus)
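# Usage sketch (assumption: paths are placeholders and the metric name must
# exist in metric_defaults.metric_defaults for this build):
#
# calc_metrics(network_pkl='network-snapshot-000100.pkl',
#              metric_names=['fid50k_full'],
#              metricdata='datasets/ffhq-tfrecords', mirror=True, gpus=1)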
def convert_tf_generator(tf_G):
    if tf_G.version < 4:
        raise ValueError('TensorFlow pickle version too low')

    # Collect kwargs.
    tf_kwargs = tf_G.static_kwargs
    known_kwargs = set()
    def kwarg(tf_name, default=None, none=None):
        known_kwargs.add(tf_name)
        val = tf_kwargs.get(tf_name, default)
        return val if val is not None else none

    # Convert kwargs.
    kwargs = dnnlib.EasyDict(
        z_dim=kwarg('latent_size', 512),
        c_dim=kwarg('label_size', 0),
        w_dim=kwarg('dlatent_size', 512),
        img_resolution=kwarg('resolution', 1024),
        img_channels=kwarg('num_channels', 3),
        mapping_kwargs=dnnlib.EasyDict(
            num_layers=kwarg('mapping_layers', 8),
            embed_features=kwarg('label_fmaps', None),
            layer_features=kwarg('mapping_fmaps', None),
            activation=kwarg('mapping_nonlinearity', 'lrelu'),
            lr_multiplier=kwarg('mapping_lrmul', 0.01),
            w_avg_beta=kwarg('w_avg_beta', 0.995, none=1),
        ),
        synthesis_kwargs=dnnlib.EasyDict(
            channel_base=kwarg('fmap_base', 16384) * 2,
            channel_max=kwarg('fmap_max', 512),
            num_fp16_res=kwarg('num_fp16_res', 0),
            conv_clamp=kwarg('conv_clamp', None),
            architecture=kwarg('architecture', 'skip'),
            resample_filter=kwarg('resample_kernel', [1, 3, 3, 1]),
            use_noise=kwarg('use_noise', True),
            activation=kwarg('nonlinearity', 'lrelu'),
        ),
    )

    # Check for unknown kwargs.
    kwarg('truncation_psi')
    kwarg('truncation_cutoff')
    kwarg('style_mixing_prob')
    kwarg('structure')
    unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
    if len(unknown_kwargs) > 0:
        raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])

    # Collect params.
    tf_params = _collect_tf_params(tf_G)
    for name, value in list(tf_params.items()):
        match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
        if match:
            r = kwargs.img_resolution // (2 ** int(match.group(1)))
            tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
            # bug fix: the kwargs dict stores synthesis options under 'synthesis_kwargs'
            kwargs.synthesis_kwargs.architecture = 'orig'
    #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')

    # Convert params.
    from training import networks
    G = networks.Generator(**kwargs).eval().requires_grad_(False)
    # pylint: disable=unnecessary-lambda
    _populate_module_params(G,
        r'mapping\.w_avg',                           lambda: tf_params['dlatent_avg'],
        r'mapping\.embed\.weight',                   lambda: tf_params['mapping/LabelEmbed/weight'].transpose(),
        r'mapping\.embed\.bias',                     lambda: tf_params['mapping/LabelEmbed/bias'],
        r'mapping\.fc(\d+)\.weight',                 lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(),
        r'mapping\.fc(\d+)\.bias',                   lambda i: tf_params[f'mapping/Dense{i}/bias'],
        r'synthesis\.b4\.const',                     lambda: tf_params['synthesis/4x4/Const/const'][0],
        r'synthesis\.b4\.conv1\.weight',             lambda: tf_params['synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b4\.conv1\.bias',               lambda: tf_params['synthesis/4x4/Conv/bias'],
        r'synthesis\.b4\.conv1\.noise_const',        lambda: tf_params['synthesis/noise0'][0, 0],
        r'synthesis\.b4\.conv1\.noise_strength',     lambda: tf_params['synthesis/4x4/Conv/noise_strength'],
        r'synthesis\.b4\.conv1\.affine\.weight',     lambda: tf_params['synthesis/4x4/Conv/mod_weight'].transpose(),
        r'synthesis\.b4\.conv1\.affine\.bias',       lambda: tf_params['synthesis/4x4/Conv/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.conv0\.weight',         lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.conv0\.bias',           lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
        r'synthesis\.b(\d+)\.conv0\.noise_const',    lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r))) * 2 - 5}'][0, 0],
        r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
        r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.conv0\.affine\.bias',   lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.conv1\.weight',         lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.conv1\.bias',           lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
        r'synthesis\.b(\d+)\.conv1\.noise_const',    lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r))) * 2 - 4}'][0, 0],
        r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
        r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.conv1\.affine\.bias',   lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.torgb\.weight',         lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.torgb\.bias',           lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
        r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.torgb\.affine\.bias',   lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.skip\.weight',          lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
        r'.*\.resample_filter',                      None,
    )
    return G
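# Usage sketch (assumption: 'stylegan2-ffhq.pkl' is a placeholder for a
# TensorFlow StyleGAN2 pickle; in this codebase legacy.load_network_pkl()
# dispatches to convert_tf_generator() when it detects a TF pickle):
#
# with dnnlib.util.open_url('stylegan2-ffhq.pkl') as f:
#     data = legacy.load_network_pkl(f)
# G = data['G_ema']  # converted torch.nn.Module, ready for inference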
def get_details(self, idx):
    d = dnnlib.EasyDict()
    d.raw_idx = int(self._raw_idx[idx])
    d.xflip = (int(self._xflip[idx]) != 0)
    d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
    return d
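# Usage sketch (assumption: `dataset` is an instance of any
# training.dataset.Dataset subclass, e.g. ImageFolderDataset; index 0 is arbitrary):
#
# d = dataset.get_details(0)
# print(d.raw_idx, d.xflip, d.raw_label)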
def setup_training_options(data_path, out_dir, resume=None, mirror=True):
    ############################################################################################################
    # EDIT THESE!                                                                                               #
    ############################################################################################################
    outdir = out_dir
    gpus = None        # Number of GPUs: <int>, default = 1 gpu
    snap = 1           # Snapshot interval: <int>, default = 50 ticks
    seed = 1000
    data = data_path   # Training dataset (required): <path>
    res = None         # Override dataset resolution: <int>, default = highest available
    mirror = mirror    # Augment dataset with x-flips: <bool>, default = False
    metrics = []       # List of metric names: [], ['fid50k_full'] (default), ...
    metricdata = None  # Metric dataset (optional): <path>
    cfg = 'stylegan2'  # Base config: 'auto' (default), 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar', 'cifarbaseline'
    gamma = None       # Override R1 gamma: <float>, default = depends on cfg
    kimg = 10000       # Override training duration: <int>, default = depends on cfg
    aug = 'ada'        # Augmentation mode: 'ada' (default), 'noaug', 'fixed', 'adarv'
    p = None           # Specify p for 'fixed' (required): <float>
    target = None      # Override ADA target for 'ada' and 'adarv': <float>, default = depends on aug
    augpipe = 'bgc'    # Augmentation pipeline: 'blit', 'geom', 'color', 'filter', 'noise', 'cutout', 'bg', 'bgc' (default), ..., 'bgcfnc'
    cmethod = None     # Comparison method: 'nocmethod' (default), 'bcr', 'zcr', 'pagan', 'wgangp', 'auxrot', 'spectralnorm', 'shallowmap', 'adropout'
    dcap = None        # Multiplier for discriminator capacity: <float>, default = 1
    resume = resume    # Load previous network: 'noresume' (default), 'ffhq256', 'ffhq512', 'ffhq1024', 'celebahq256', 'lsundog256', <file>, <url>
    freezed = None     # Freeze-D: <int>, default = 0 discriminator layers
    ############################################################################################################
    # End of Edit Section                                                                                       #
    ############################################################################################################
    tflib.init_tf({'rnd.np_random_seed': seed})

    # Initialize dicts.
    args = dnnlib.EasyDict()
    args.G_args = dnnlib.EasyDict(func_name='training.networks.G_main')
    args.D_args = dnnlib.EasyDict(func_name='training.networks.D_main')
    args.G_opt_args = dnnlib.EasyDict(beta1=0.0, beta2=0.99)
    args.D_opt_args = dnnlib.EasyDict(beta1=0.0, beta2=0.99)
    args.loss_args = dnnlib.EasyDict(func_name='training.loss.stylegan2')
    args.augment_args = dnnlib.EasyDict(class_name='training.augment.AdaptiveAugment')

    # ---------------------------
    # General options: gpus, snap
    # ---------------------------
    if gpus is None:
        gpus = 1
    assert isinstance(gpus, int)
    if not (gpus >= 1 and gpus & (gpus - 1) == 0):
        raise UserError('--gpus must be a power of two')
    args.num_gpus = gpus

    if snap is None:
        snap = 50
    assert isinstance(snap, int)
    if snap < 1:
        raise UserError('--snap must be at least 1')
    args.image_snapshot_ticks = snap
    args.network_snapshot_ticks = snap

    # -----------------------------------
    # Training dataset: data, res, mirror
    # -----------------------------------
    assert data is not None
    assert isinstance(data, str)
    data_name = os.path.basename(os.path.abspath(data))
    if not os.path.isdir(data) or len(data_name) == 0:
        raise UserError('--data must point to a directory containing *.tfrecords')
    desc = data_name

    with tf.Graph().as_default(), tflib.create_session().as_default():  # pylint: disable=not-context-manager
        args.train_dataset_args = dnnlib.EasyDict(path=data, max_label_size='full')
        dataset_obj = dataset.load_dataset(**args.train_dataset_args)  # try to load the data and see what comes out
        args.train_dataset_args.resolution = dataset_obj.shape[-1]  # be explicit about resolution
        args.train_dataset_args.max_label_size = dataset_obj.label_size  # be explicit about label size
        validation_set_available = dataset_obj.has_validation_set
        dataset_obj.close()
        dataset_obj = None

    if res is None:
        res = args.train_dataset_args.resolution
    else:
        assert isinstance(res, int)
        if not (res >= 4 and res & (res - 1) == 0):
            raise UserError('--res must be a power of two and at least 4')
        if res > args.train_dataset_args.resolution:
            raise UserError(f'--res cannot exceed maximum available resolution in the dataset ({args.train_dataset_args.resolution})')
        desc += f'-res{res:d}'
    args.train_dataset_args.resolution = res

    if mirror is None:
        mirror = False
    else:
        assert isinstance(mirror, bool)
        if mirror:
            desc += '-mirror'
    args.train_dataset_args.mirror_augment = mirror

    # ----------------------------
    # Metrics: metrics, metricdata
    # ----------------------------
    if metrics is None:
        metrics = ['fid50k_full']
    assert isinstance(metrics, list)
    assert all(isinstance(metric, str) for metric in metrics)
    args.metric_arg_list = []
    for metric in metrics:
        if metric not in metric_defaults.metric_defaults:
            raise UserError('\n'.join(['--metrics can only contain the following values:', 'none'] +
                                      list(metric_defaults.metric_defaults.keys())))
        args.metric_arg_list.append(metric_defaults.metric_defaults[metric])

    args.metric_dataset_args = dnnlib.EasyDict(args.train_dataset_args)
    if metricdata is not None:
        assert isinstance(metricdata, str)
        if not os.path.isdir(metricdata):
            raise UserError('--metricdata must point to a directory containing *.tfrecords')
        args.metric_dataset_args.path = metricdata

    # -----------------------------
    # Base config: cfg, gamma, kimg
    # -----------------------------
    if cfg is None:
        cfg = 'auto'
    assert isinstance(cfg, str)
    desc += f'-{cfg}'

    cfg_specs = {
        'auto':          dict(ref_gpus=-1, kimg=25000,  mb=-1, mbstd=-1, fmaps=-1,  lrate=-1,     gamma=-1,   ema=-1,  ramp=0.05, map=2),  # populated dynamically based on 'gpus' and 'res'
        'stylegan2':     dict(ref_gpus=8,  kimg=25000,  mb=32, mbstd=4,  fmaps=1,   lrate=0.002,  gamma=10,   ema=10,  ramp=None, map=8),  # uses mixed-precision, unlike original StyleGAN2
        'paper256':      dict(ref_gpus=8,  kimg=25000,  mb=64, mbstd=8,  fmaps=0.5, lrate=0.0025, gamma=1,    ema=20,  ramp=None, map=8),
        'paper512':      dict(ref_gpus=8,  kimg=25000,  mb=64, mbstd=8,  fmaps=1,   lrate=0.0025, gamma=0.5,  ema=20,  ramp=None, map=8),
        'paper1024':     dict(ref_gpus=8,  kimg=25000,  mb=32, mbstd=4,  fmaps=1,   lrate=0.002,  gamma=2,    ema=10,  ramp=None, map=8),
        'cifar':         dict(ref_gpus=2,  kimg=100000, mb=64, mbstd=32, fmaps=0.5, lrate=0.0025, gamma=0.01, ema=500, ramp=0.05, map=2),
        'cifarbaseline': dict(ref_gpus=2,  kimg=100000, mb=64, mbstd=32, fmaps=0.5, lrate=0.0025, gamma=0.01, ema=500, ramp=0.05, map=8),
    }

    assert cfg in cfg_specs
    spec = dnnlib.EasyDict(cfg_specs[cfg])
    if cfg == 'auto':
        desc += f'{gpus:d}'
        spec.ref_gpus = gpus
        spec.mb = max(min(gpus * min(4096 // res, 32), 64), gpus)  # keep gpu memory consumption at bay
        spec.mbstd = min(spec.mb // gpus, 4)  # other hyperparams behave more predictably if mbstd group size remains fixed
        spec.fmaps = 1 if res >= 512 else 0.5
        spec.lrate = 0.002 if res >= 1024 else 0.0025
        spec.gamma = 0.0002 * (res ** 2) / spec.mb  # heuristic formula
        spec.ema = spec.mb * 10 / 32

    args.total_kimg = spec.kimg
    args.minibatch_size = spec.mb
    args.minibatch_gpu = spec.mb // spec.ref_gpus
    args.D_args.mbstd_group_size = spec.mbstd
    args.G_args.fmap_base = args.D_args.fmap_base = int(spec.fmaps * 16384)
    args.G_args.fmap_max = args.D_args.fmap_max = 512
    args.G_opt_args.learning_rate = args.D_opt_args.learning_rate = spec.lrate
    args.loss_args.r1_gamma = spec.gamma
    args.G_smoothing_kimg = spec.ema
    args.G_smoothing_rampup = spec.ramp
    args.G_args.mapping_layers = spec.map
    args.G_args.num_fp16_res = args.D_args.num_fp16_res = 4  # enable mixed-precision training
    args.G_args.conv_clamp = args.D_args.conv_clamp = 256  # clamp activations to avoid float16 overflow

    if cfg == 'cifar':
        args.loss_args.pl_weight = 0  # disable path length regularization
        args.G_args.style_mixing_prob = None  # disable style mixing
        args.D_args.architecture = 'orig'  # disable residual skip connections

    if gamma is not None:
        assert isinstance(gamma, float)
        if not gamma >= 0:
            raise UserError('--gamma must be non-negative')
        desc += f'-gamma{gamma:g}'
        args.loss_args.r1_gamma = gamma

    if kimg is not None:
        assert isinstance(kimg, int)
        if not kimg >= 1:
            raise UserError('--kimg must be at least 1')
        desc += f'-kimg{kimg:d}'
        args.total_kimg = kimg

    # ---------------------------------------------------
    # Discriminator augmentation: aug, p, target, augpipe
    # ---------------------------------------------------
    if aug is None:
        aug = 'ada'
    else:
        assert isinstance(aug, str)
        desc += f'-{aug}'

    if aug == 'ada':
        args.augment_args.tune_heuristic = 'rt'
        args.augment_args.tune_target = 0.6
    elif aug == 'noaug':
        pass
    elif aug == 'fixed':
        if p is None:
            raise UserError(f'--aug={aug} requires specifying --p')
    elif aug == 'adarv':
        if not validation_set_available:
            raise UserError(f'--aug={aug} requires separate validation set; please see "python dataset_tool.py pack -h"')
        args.augment_args.tune_heuristic = 'rv'
        args.augment_args.tune_target = 0.5
    else:
        raise UserError(f'--aug={aug} not supported')

    if p is not None:
        assert isinstance(p, float)
        if aug != 'fixed':
            raise UserError('--p can only be specified with --aug=fixed')
        if not 0 <= p <= 1:
            raise UserError('--p must be between 0 and 1')
        desc += f'-p{p:g}'
        args.augment_args.initial_strength = p

    if target is not None:
        assert isinstance(target, float)
        if aug not in ['ada', 'adarv']:
            raise UserError('--target can only be specified with --aug=ada or --aug=adarv')
        if not 0 <= target <= 1:
            raise UserError('--target must be between 0 and 1')
        desc += f'-target{target:g}'
        args.augment_args.tune_target = target

    assert augpipe is None or isinstance(augpipe, str)
    if augpipe is None:
        augpipe = 'bgc'
    else:
        if aug == 'noaug':
            raise UserError('--augpipe cannot be specified with --aug=noaug')
        desc += f'-{augpipe}'

    augpipe_specs = {
        'blit':   dict(xflip=1, rotate90=1, xint=1),
        'geom':   dict(scale=1, rotate=1, aniso=1, xfrac=1),
        'color':  dict(brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
        'filter': dict(imgfilter=1),
        'noise':  dict(noise=1),
        'cutout': dict(cutout=1),
        'bg':     dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1),
        'bgc':    dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
        'bgcf':   dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1),
        'bgcfn':  dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1),
        'bgcfnc': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1, cutout=1),
    }

    assert augpipe in augpipe_specs
    if aug != 'noaug':
        args.augment_args.apply_func = 'training.augment.augment_pipeline'
        args.augment_args.apply_args = augpipe_specs[augpipe]

    # ---------------------------------
    # Comparison methods: cmethod, dcap
    # ---------------------------------
    assert cmethod is None or isinstance(cmethod, str)
    if cmethod is None:
        cmethod = 'nocmethod'
    else:
        desc += f'-{cmethod}'

    if cmethod == 'nocmethod':
        pass
    elif cmethod == 'bcr':
        args.loss_args.func_name = 'training.loss.cmethods'
        args.loss_args.bcr_real_weight = 10
        args.loss_args.bcr_fake_weight = 10
        args.loss_args.bcr_augment = dnnlib.EasyDict(func_name='training.augment.augment_pipeline',
                                                     xint=1, xint_max=1 / 32)
    elif cmethod == 'zcr':
        args.loss_args.func_name = 'training.loss.cmethods'
        args.loss_args.zcr_gen_weight = 0.02
        args.loss_args.zcr_dis_weight = 0.2
        args.G_args.num_fp16_res = args.D_args.num_fp16_res = 0  # disable mixed-precision training
        args.G_args.conv_clamp = args.D_args.conv_clamp = None
    elif cmethod == 'pagan':
        if aug != 'noaug':
            raise UserError(f'--cmethod={cmethod} is not compatible with discriminator augmentation; please specify --aug=noaug')
        args.D_args.use_pagan = True
        args.augment_args.tune_heuristic = 'rt'  # enable ada heuristic
        args.augment_args.pop('apply_func', None)  # disable discriminator augmentation
        args.augment_args.pop('apply_args', None)
        args.augment_args.tune_target = 0.95
    elif cmethod == 'wgangp':
        if aug != 'noaug':
            raise UserError(f'--cmethod={cmethod} is not compatible with discriminator augmentation; please specify --aug=noaug')
        if gamma is not None:
            raise UserError(f'--cmethod={cmethod} is not compatible with --gamma')
        args.loss_args = dnnlib.EasyDict(func_name='training.loss.wgangp')
        args.G_opt_args.learning_rate = args.D_opt_args.learning_rate = 0.001
        args.G_args.num_fp16_res = args.D_args.num_fp16_res = 0  # disable mixed-precision training
        args.G_args.conv_clamp = args.D_args.conv_clamp = None
        args.lazy_regularization = False
    elif cmethod == 'auxrot':
        if args.train_dataset_args.max_label_size > 0:
            raise UserError(f'--cmethod={cmethod} is not compatible with label conditioning; please specify a dataset without labels')
        args.loss_args.func_name = 'training.loss.cmethods'
        args.loss_args.auxrot_alpha = 10
        args.loss_args.auxrot_beta = 5
        args.D_args.score_max = 5  # prepare D to output 5 scalars per image instead of just 1
    elif cmethod == 'spectralnorm':
        args.D_args.use_spectral_norm = True
    elif cmethod == 'shallowmap':
        if args.G_args.mapping_layers == 2:
            raise UserError(f'--cmethod={cmethod} is a no-op for --cfg={cfg}')
        args.G_args.mapping_layers = 2
    elif cmethod == 'adropout':
        if aug != 'noaug':
            raise UserError(f'--cmethod={cmethod} is not compatible with discriminator augmentation; please specify --aug=noaug')
        args.D_args.adaptive_dropout = 1
        args.augment_args.tune_heuristic = 'rt'  # enable ada heuristic
        args.augment_args.pop('apply_func', None)  # disable discriminator augmentation
        args.augment_args.pop('apply_args', None)
        args.augment_args.tune_target = 0.6
    else:
        raise UserError(f'--cmethod={cmethod} not supported')

    if dcap is not None:
        assert isinstance(dcap, float)
        if not dcap > 0:
            raise UserError('--dcap must be positive')
        desc += f'-dcap{dcap:g}'
        args.D_args.fmap_base = max(int(args.D_args.fmap_base * dcap), 1)
        args.D_args.fmap_max = max(int(args.D_args.fmap_max * dcap), 1)

    # ----------------------------------
    # Transfer learning: resume, freezed
    # ----------------------------------
    resume_specs = {
        'ffhq256':     'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/transfer-learning-source-nets/ffhq-res256-mirror-paper256-noaug.pkl',
        'ffhq512':     'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/transfer-learning-source-nets/ffhq-res512-mirror-stylegan2-noaug.pkl',
        'ffhq1024':    'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/transfer-learning-source-nets/ffhq-res1024-mirror-stylegan2-noaug.pkl',
        'celebahq256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/transfer-learning-source-nets/celebahq-res256-mirror-paper256-kimg100000-ada-target0.5.pkl',
        'lsundog256':  'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/transfer-learning-source-nets/lsundog-res256-paper256-kimg100000-noaug.pkl',
    }

    assert resume is None or isinstance(resume, str)
    if resume is None:
        resume = 'noresume'
    elif resume == 'noresume':
        desc += '-noresume'
    elif resume in resume_specs:
        desc += f'-resume{resume}'
        args.resume_pkl = resume_specs[resume]  # predefined url
    else:
        desc += '-resumecustom'
        args.resume_pkl = resume  # custom path or url

    if resume != 'noresume':
        args.augment_args.tune_kimg = 100  # make ADA react faster at the beginning
        args.G_smoothing_rampup = None  # disable EMA rampup

    if freezed is not None:
        assert isinstance(freezed, int)
        if not freezed >= 0:
            raise UserError('--freezed must be non-negative')
        desc += f'-freezed{freezed:d}'
        args.D_args.freeze_layers = freezed

    return desc, args, outdir
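# Usage sketch (assumption: 'datasets/mydata' contains *.tfrecords prepared by
# dataset_tool.py; the desc string shown is illustrative, since it depends on
# the knobs set in the edit section above):
#
# desc, args, outdir = setup_training_options('datasets/mydata', 'results', resume=None, mirror=True)
# print(desc)  # e.g. 'mydata-mirror-stylegan2-kimg10000-ada-bgc'
# print(args.total_kimg, args.minibatch_size)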
def generate(args):
    np.random.seed()
    device = torch.device('cpu')

    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.size = args['size']
    Gs_kwargs.scale_type = args['scale_type']

    nHW = [int(s) for s in args['nXY'].split('-')][::-1]
    n_mult = nHW[0] * nHW[1]
    lmask = np.tile(np.asarray([[[[1]]]]), (1, n_mult, 1, 1))
    Gs_kwargs.countHW = nHW
    Gs_kwargs.splitfine = args['splitfine']
    lmask = torch.from_numpy(lmask).to(device)

    # load base or custom network
    pkl_name = osp.splitext(args['model'])[0]
    with dnnlib.util.open_url(pkl_name + '.pkl') as f:
        Gs = legacy.load_network_pkl(f, custom=False, **Gs_kwargs)['G_ema'].to(device)  # type: ignore

    lats = []  # list of [frm,1,512]
    for i in range(n_mult):
        lat_tmp = latent_anima((1, Gs.z_dim), args['frames'], args['fstep'],
                               cubic=args['cubic'], gauss=args['gauss'])  # [frm,1,512]
        lats.append(lat_tmp)
    latents = np.concatenate(lats, 1)  # [frm,X,512]
    latents = torch.from_numpy(latents).to(device)

    dconst = np.zeros([1, 1, 1, 1, 1])
    dconst = torch.from_numpy(dconst).to(device)

    # labels / conditions
    label_size = Gs.c_dim
    if label_size > 0:
        labels = torch.zeros((1, n_mult, label_size), device=device)  # [frm,X,lbl]
        label_ids = []
        for i in range(n_mult):
            label_ids.append(random.randint(0, label_size - 1))
        for i, l in enumerate(label_ids):
            labels[:, i, l] = 1
    else:
        labels = [None]

    # generate images from latent timeline
    latent = latents[0]  # [X,512]
    label = labels[0 % len(labels)]
    latmask = lmask[0 % len(lmask)] if lmask is not None else [None]  # [X,h,w]
    dc = dconst[0 % len(dconst)]  # [X,512,4,4]

    # generate multi-latent result
    Gs = Gs.float()
    output = Gs(latent, label, force_fp32=True, truncation_psi=args['trunc'], noise_mode='const')
    output = (output.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()

    # save image
    ext = 'png' if output.shape[3] == 4 else 'jpg'
    filename = osp.join(args['out_dir'], "%06d.%s" % (0, ext))
    return Image.fromarray(output[0], 'RGB')
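# Usage sketch (assumption: the dict keys simply mirror what generate() reads
# above; the model path and output directory are placeholders):
#
# img = generate({'size': None, 'scale_type': 'pad', 'nXY': '1-1', 'splitfine': 0.0,
#                 'model': 'models/ffhq-1024.pkl', 'frames': 25, 'fstep': 25,
#                 'cubic': False, 'gauss': False, 'trunc': 0.8, 'out_dir': '_out'})
# img.save('_out/preview.jpg')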