def custom_generator(data, **ex_kwargs):
    """Rebuild the loaded EMA generator as a `stylegan2_multi` Generator.

    A fresh `stylegan2_multi.Generator` is constructed from the dimensions
    stored on ``data['G_ema']`` and the source parameters/buffers are then
    copied into it.

    Args:
        data: Mapping holding the loaded networks; only ``data['G_ema']``
            is read here.
        **ex_kwargs: Extra keyword arguments merged into the synthesis
            network kwargs next to ``channel_base``.

    Returns:
        The new generator, in eval mode with gradients disabled.
    """
    from training import stylegan2_multi as networks

    G_ema = data['G_ema']

    try:
        # Pickles saved with the newer fix carry the value on the synthesis net.
        fmap_base = G_ema.synthesis.fmap_base
    except AttributeError:
        # Only a missing attribute triggers the fallback (a bare `except`
        # would also hide KeyboardInterrupt and unrelated failures).
        # Default from the original configs.
        fmap_base = 32768 if G_ema.img_resolution >= 512 else 16384

    kwargs = dnnlib.EasyDict(
        z_dim=G_ema.z_dim,
        c_dim=G_ema.c_dim,
        w_dim=G_ema.w_dim,
        img_resolution=G_ema.img_resolution,
        img_channels=G_ema.img_channels,
        init_res=G_ema.init_res,
        mapping_kwargs=dnnlib.EasyDict(
            num_layers=G_ema.mapping.num_layers,
        ),
        synthesis_kwargs=dnnlib.EasyDict(channel_base=fmap_base, **ex_kwargs),
    )

    G_out = networks.Generator(**kwargs).eval().requires_grad_(False)
    # require_all=False: presumably tolerates parameters that exist only in
    # the new network -- confirm against misc.copy_params_and_buffers.
    misc.copy_params_and_buffers(G_ema, G_out, require_all=False)
    return G_out
def convert_tf_generator(tf_G, custom=False, **ex_kwargs):
    """Convert a TensorFlow StyleGAN2 generator into a PyTorch Generator.

    The TF network's ``static_kwargs`` are translated into constructor kwargs,
    a new Generator is built, and its parameters are filled from the TF
    variables collected by ``_collect_tf_params``.

    Args:
        tf_G: Unpickled TensorFlow generator; ``version`` and
            ``static_kwargs`` are read directly, the rest goes through
            ``_collect_tf_params``.
        custom: When true, the Generator class is taken from
            ``training.stylegan2_multi`` instead of ``training.networks`` and
            ``ex_kwargs`` are merged into the synthesis kwargs.
        **ex_kwargs: Extra synthesis kwargs; only used when ``custom`` is true.

    Returns:
        The populated generator, in eval mode with gradients disabled.

    Raises:
        ValueError: If ``tf_G.version < 4`` or ``static_kwargs`` contains a
            key this converter does not know about.
    """
    if tf_G.version < 4:
        raise ValueError('TensorFlow pickle version too low')

    # Collect kwargs.
    tf_kwargs = tf_G.static_kwargs
    known_kwargs = set()

    def kwarg(tf_name, default=None, none=None):
        # Record the name as recognised, then look it up; `none` replaces a
        # value that is present (or defaulted) but None.
        known_kwargs.add(tf_name)
        val = tf_kwargs.get(tf_name, default)
        return val if val is not None else none

    # Convert kwargs.
    kwargs = dnnlib.EasyDict(
        z_dim=kwarg('latent_size', 512),
        c_dim=kwarg('label_size', 0),
        w_dim=kwarg('dlatent_size', 512),
        img_resolution=kwarg('resolution', 1024),
        img_channels=kwarg('num_channels', 3),
        mapping_kwargs=dnnlib.EasyDict(
            num_layers=kwarg('mapping_layers', 8),
            embed_features=kwarg('label_fmaps', None),
            layer_features=kwarg('mapping_fmaps', None),
            activation=kwarg('mapping_nonlinearity', 'lrelu'),
            lr_multiplier=kwarg('mapping_lrmul', 0.01),
            w_avg_beta=kwarg('w_avg_beta', 0.995, none=1),
        ),
        synthesis_kwargs=dnnlib.EasyDict(
            # The TF `fmap_base` is doubled to obtain `channel_base`.
            channel_base=kwarg('fmap_base', 16384) * 2,
            channel_max=kwarg('fmap_max', 512),
            num_fp16_res=kwarg('num_fp16_res', 0),
            conv_clamp=kwarg('conv_clamp', None),
            architecture=kwarg('architecture', 'skip'),
            resample_filter=kwarg('resample_kernel', [1, 3, 3, 1]),
            use_noise=kwarg('use_noise', True),
            activation=kwarg('nonlinearity', 'lrelu'),
        ),
        # !!! custom
        init_res=kwarg('init_res', [4, 4]),
    )

    # Check for unknown kwargs. These four are recognised but not forwarded.
    kwarg('truncation_psi')
    kwarg('truncation_cutoff')
    kwarg('style_mixing_prob')
    kwarg('structure')
    unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)

    # !!! custom
    if custom:
        kwargs.synthesis_kwargs = dnnlib.EasyDict(
            **kwargs.synthesis_kwargs, **ex_kwargs)
    if unknown_kwargs:
        raise ValueError('Unknown TensorFlow kwargs:', unknown_kwargs)

    # Collect params.
    tf_params = _collect_tf_params(tf_G)
    # Iterate over a snapshot because aliases are added to the dict in the loop.
    for name, value in list(tf_params.items()):
        match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
        if match:
            r = kwargs.img_resolution // (2**int(match.group(1)))
            # NOTE(review): the alias key has no 'synthesis/' prefix while the
            # lookups below use 'synthesis/{r}x{r}/ToRGB/...'; confirm against
            # the names produced by _collect_tf_params.
            tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
            # Fixed: the entry is `synthesis_kwargs`; the previous
            # `kwargs.synthesis.kwargs` referred to a key that does not exist.
            kwargs.synthesis_kwargs.architecture = 'orig'

    # Convert params.
    if custom:
        from training import stylegan2_multi as networks
    else:
        from training import networks
    G = networks.Generator(**kwargs).eval().requires_grad_(False)

    # Alternating (regex over PyTorch parameter names, value factory) pairs;
    # regex capture groups are passed to the factory as arguments.
    # pylint: disable=unnecessary-lambda
    _populate_module_params(
        G,
        r'mapping\.w_avg',
        lambda: tf_params['dlatent_avg'],
        r'mapping\.embed\.weight',
        lambda: tf_params['mapping/LabelEmbed/weight'].transpose(),
        r'mapping\.embed\.bias',
        lambda: tf_params['mapping/LabelEmbed/bias'],
        r'mapping\.fc(\d+)\.weight',
        lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(),
        r'mapping\.fc(\d+)\.bias',
        lambda i: tf_params[f'mapping/Dense{i}/bias'],
        r'synthesis\.b4\.const',
        lambda: tf_params['synthesis/4x4/Const/const'][0],
        r'synthesis\.b4\.conv1\.weight',
        lambda: tf_params['synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b4\.conv1\.bias',
        lambda: tf_params['synthesis/4x4/Conv/bias'],
        r'synthesis\.b4\.conv1\.noise_const',
        lambda: tf_params['synthesis/noise0'][0, 0],
        r'synthesis\.b4\.conv1\.noise_strength',
        lambda: tf_params['synthesis/4x4/Conv/noise_strength'],
        r'synthesis\.b4\.conv1\.affine\.weight',
        lambda: tf_params['synthesis/4x4/Conv/mod_weight'].transpose(),
        # Every affine (modulation) bias below is the TF mod_bias plus 1.
        r'synthesis\.b4\.conv1\.affine\.bias',
        lambda: tf_params['synthesis/4x4/Conv/mod_bias'] + 1,
        # Up-sampling conv and skip weights are flipped on their first two
        # axes before being transposed.
        r'synthesis\.b(\d+)\.conv0\.weight',
        lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.conv0\.bias',
        lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
        # Noise buffers are indexed from the resolution: 2*log2(r) - 5 for
        # conv0 and 2*log2(r) - 4 for conv1.
        r'synthesis\.b(\d+)\.conv0\.noise_const',
        lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0],
        r'synthesis\.b(\d+)\.conv0\.noise_strength',
        lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
        r'synthesis\.b(\d+)\.conv0\.affine\.weight',
        lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.conv0\.affine\.bias',
        lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.conv1\.weight',
        lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.conv1\.bias',
        lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
        r'synthesis\.b(\d+)\.conv1\.noise_const',
        lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0],
        r'synthesis\.b(\d+)\.conv1\.noise_strength',
        lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
        r'synthesis\.b(\d+)\.conv1\.affine\.weight',
        lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.conv1\.affine\.bias',
        lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.torgb\.weight',
        lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.torgb\.bias',
        lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
        r'synthesis\.b(\d+)\.torgb\.affine\.weight',
        lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.torgb\.affine\.bias',
        lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.skip\.weight',
        lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
        # No TF source is given for resample filters (factory is None).
        r'.*\.resample_filter',
        None,
    )
    return G