Example #1
import dnnlib
from torch_utils import misc


def custom_generator(data, **ex_kwargs):
    # Rebuild the checkpointed G_ema as a stylegan2_multi Generator so that
    # extra synthesis kwargs (ex_kwargs) can be injected at construction time.
    from training import stylegan2_multi as networks
    try:  # fmap_base is stored in the pickle by newer checkpoints
        fmap_base = data['G_ema'].synthesis.fmap_base
    except AttributeError:  # fall back to the defaults of the original configs
        fmap_base = 32768 if data['G_ema'].img_resolution >= 512 else 16384
    kwargs = dnnlib.EasyDict(
        z_dim=data['G_ema'].z_dim,
        c_dim=data['G_ema'].c_dim,
        w_dim=data['G_ema'].w_dim,
        img_resolution=data['G_ema'].img_resolution,
        img_channels=data['G_ema'].img_channels,
        init_res=data['G_ema'].init_res,
        mapping_kwargs=dnnlib.EasyDict(
            num_layers=data['G_ema'].mapping.num_layers),
        synthesis_kwargs=dnnlib.EasyDict(channel_base=fmap_base, **ex_kwargs),
    )
    G_out = networks.Generator(**kwargs).eval().requires_grad_(False)
    # Copy every matching parameter/buffer from the checkpointed G_ema.
    misc.copy_params_and_buffers(data['G_ema'], G_out, require_all=False)
    return G_out
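A minimal usage sketch for the example above, assuming a checkpoint saved by the stylegan2-ada-pytorch codebase and loaded with its legacy.load_network_pkl helper (the pickle path is illustrative):

import torch
import dnnlib
import legacy

# Hypothetical checkpoint path; load_network_pkl returns a dict with a 'G_ema' entry.
with dnnlib.util.open_url('network-snapshot.pkl') as f:
    data = legacy.load_network_pkl(f)

G = custom_generator(data)              # rebuilt with the stylegan2_multi architecture
z = torch.randn([1, G.z_dim])           # random latent batch
c = torch.zeros([1, G.c_dim])           # label vector (all-zero for unconditional models)
img = G(z, c)                           # NCHW image tensor, roughly in [-1, 1]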
Example #2
import re

import numpy as np

import dnnlib


def convert_tf_generator(tf_G, custom=False, **ex_kwargs):
    # Convert a TensorFlow StyleGAN2 generator (tf_G) into a PyTorch Generator.
    # The helpers _collect_tf_params and _populate_module_params come from the
    # surrounding legacy.py module.
    if tf_G.version < 4:
        raise ValueError('TensorFlow pickle version too low')

    # Collect kwargs.
    tf_kwargs = tf_G.static_kwargs
    known_kwargs = set()

    def kwarg(tf_name, default=None, none=None):
        # Look up a TF static kwarg, recording which names have been consumed.
        known_kwargs.add(tf_name)
        val = tf_kwargs.get(tf_name, default)
        return val if val is not None else none

    # Convert kwargs.
    kwargs = dnnlib.EasyDict(
        z_dim=kwarg('latent_size', 512),
        c_dim=kwarg('label_size', 0),
        w_dim=kwarg('dlatent_size', 512),
        img_resolution=kwarg('resolution', 1024),
        img_channels=kwarg('num_channels', 3),
        mapping_kwargs=dnnlib.EasyDict(
            num_layers=kwarg('mapping_layers', 8),
            embed_features=kwarg('label_fmaps', None),
            layer_features=kwarg('mapping_fmaps', None),
            activation=kwarg('mapping_nonlinearity', 'lrelu'),
            lr_multiplier=kwarg('mapping_lrmul', 0.01),
            w_avg_beta=kwarg('w_avg_beta', 0.995, none=1),
        ),
        synthesis_kwargs=dnnlib.EasyDict(
            channel_base=kwarg('fmap_base', 16384) * 2,
            channel_max=kwarg('fmap_max', 512),
            num_fp16_res=kwarg('num_fp16_res', 0),
            conv_clamp=kwarg('conv_clamp', None),
            architecture=kwarg('architecture', 'skip'),
            resample_filter=kwarg('resample_kernel', [1, 3, 3, 1]),
            use_noise=kwarg('use_noise', True),
            activation=kwarg('nonlinearity', 'lrelu'),
        ),
        # !!! custom
        init_res=kwarg('init_res', [4, 4]),
    )

    # Check for unknown kwargs (these calls only mark the names as known).
    kwarg('truncation_psi')
    kwarg('truncation_cutoff')
    kwarg('style_mixing_prob')
    kwarg('structure')
    unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
    # !!! custom
    if custom:
        kwargs.synthesis_kwargs = dnnlib.EasyDict(**kwargs.synthesis_kwargs,
                                                  **ex_kwargs)
    if len(unknown_kwargs) > 0:
        raise ValueError(f'Unknown TensorFlow kwargs: {unknown_kwargs}')
    # if ex_kwargs.get('verbose'): print(kwargs.synthesis_kwargs)

    # Collect params.
    tf_params = _collect_tf_params(tf_G)
    for name, value in list(tf_params.items()):
        match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
        if match:
            # Remap progressive-growing ToRGB_lod layers to per-resolution names.
            r = kwargs.img_resolution // (2**int(match.group(1)))
            tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
            kwargs.synthesis_kwargs.architecture = 'orig'
    # for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')

    # Convert params.
    if custom:
        from training import stylegan2_multi as networks
    else:
        from training import networks
    G = networks.Generator(**kwargs).eval().requires_grad_(False)
    # pylint: disable=unnecessary-lambda
    # Each regex below matches a PyTorch parameter name; the paired lambda
    # returns the corresponding TF tensor (transposed/flipped where needed).
    _populate_module_params(
        G,
        r'mapping\.w_avg',
        lambda: tf_params[f'dlatent_avg'],
        r'mapping\.embed\.weight',
        lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(),
        r'mapping\.embed\.bias',
        lambda: tf_params[f'mapping/LabelEmbed/bias'],
        r'mapping\.fc(\d+)\.weight',
        lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(),
        r'mapping\.fc(\d+)\.bias',
        lambda i: tf_params[f'mapping/Dense{i}/bias'],
        r'synthesis\.b4\.const',
        lambda: tf_params[f'synthesis/4x4/Const/const'][0],
        r'synthesis\.b4\.conv1\.weight',
        lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b4\.conv1\.bias',
        lambda: tf_params[f'synthesis/4x4/Conv/bias'],
        r'synthesis\.b4\.conv1\.noise_const',
        lambda: tf_params[f'synthesis/noise0'][0, 0],
        r'synthesis\.b4\.conv1\.noise_strength',
        lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'],
        r'synthesis\.b4\.conv1\.affine\.weight',
        lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(),
        r'synthesis\.b4\.conv1\.affine\.bias',
        lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.conv0\.weight',
        lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.conv0\.bias',
        lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
        r'synthesis\.b(\d+)\.conv0\.noise_const',
        lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r))) * 2 - 5}'][0, 0],
        r'synthesis\.b(\d+)\.conv0\.noise_strength',
        lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
        r'synthesis\.b(\d+)\.conv0\.affine\.weight',
        lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.conv0\.affine\.bias',
        lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.conv1\.weight',
        lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.conv1\.bias',
        lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
        r'synthesis\.b(\d+)\.conv1\.noise_const',
        lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r))) * 2 - 4}'][0, 0],
        r'synthesis\.b(\d+)\.conv1\.noise_strength',
        lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
        r'synthesis\.b(\d+)\.conv1\.affine\.weight',
        lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.conv1\.affine\.bias',
        lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.torgb\.weight',
        lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.torgb\.bias',
        lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
        r'synthesis\.b(\d+)\.torgb\.affine\.weight',
        lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.torgb\.affine\.bias',
        lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.skip\.weight',
        lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
        r'.*\.resample_filter',
        None,
    )
    return G
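In legacy.py this converter is driven by load_network_pkl, which first reads the TensorFlow pickle with a custom unpickler; a rough sketch of how the converted generator might then be used (device and truncation value are illustrative):

import torch

# Rough sketch; assumes tf_G was already unpickled from a TF StyleGAN2 checkpoint
# (legacy.load_network_pkl handles that step with its custom unpickler).
G = convert_tf_generator(tf_G).to('cuda')
z = torch.randn([1, G.z_dim], device='cuda')
c = torch.zeros([1, G.c_dim], device='cuda')
img = G(z, c, truncation_psi=0.7)       # NCHW image tensor, roughly in [-1, 1]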