コード例 #1
0
def convert_tf_generator(tf_G):
    """Convert a legacy TensorFlow generator network into a PyTorch Generator.

    The TF network's ``static_kwargs`` are translated into constructor kwargs
    for ``training.networks.Generator``; every parameter/buffer of the new
    module is then filled from the TF variables via ``_populate_module_params``.

    Args:
        tf_G: Unpickled TensorFlow network. Must provide ``version`` and
            ``static_kwargs``, plus whatever ``_collect_tf_params`` reads.

    Returns:
        The converted generator, in eval mode with gradients disabled.

    Raises:
        ValueError: If ``tf_G.version`` is below 4, or if ``static_kwargs``
            contains a key this converter does not recognize.
    """
    if tf_G.version < 4:
        raise ValueError('TensorFlow pickle version too low')

    # Collect kwargs. Every name looked up through kwarg() is recorded in
    # known_kwargs so that unrecognized TF kwargs can be rejected below.
    tf_kwargs = tf_G.static_kwargs
    known_kwargs = set()

    def kwarg(tf_name, default=None, none=None):
        # Return tf_kwargs[tf_name] (or `default` when absent); a resolved
        # value of None is replaced by `none`.
        known_kwargs.add(tf_name)
        val = tf_kwargs.get(tf_name, default)
        return val if val is not None else none

    # Convert kwargs.
    kwargs = dnnlib.EasyDict(
        z_dim=kwarg('latent_size', 512),
        c_dim=kwarg('label_size', 0),
        w_dim=kwarg('dlatent_size', 512),
        img_resolution=kwarg('resolution', 1024),
        img_channels=kwarg('num_channels', 3),
        mapping_kwargs=dnnlib.EasyDict(
            num_layers=kwarg('mapping_layers', 8),
            embed_features=kwarg('label_fmaps', None),
            layer_features=kwarg('mapping_fmaps', None),
            activation=kwarg('mapping_nonlinearity', 'lrelu'),
            lr_multiplier=kwarg('mapping_lrmul', 0.01),
            w_avg_beta=kwarg('w_avg_beta', 0.995, none=1),
        ),
        synthesis_kwargs=dnnlib.EasyDict(
            channel_base=kwarg('fmap_base', None),
            channel_max=kwarg('fmap_max', 512),
            num_fp16_res=kwarg('num_fp16_res', 0),
            conv_clamp=kwarg('conv_clamp', None),
            architecture=kwarg('architecture', 'skip'),
            resample_filter=kwarg('resample_kernel', [1, 3, 3, 1]),
            use_noise=kwarg('use_noise', True),
            activation=kwarg('nonlinearity', 'lrelu'),
        ),
    )

    # Check for unknown kwargs. The following TF kwargs are accepted but not
    # used by the conversion; registering them keeps them from being rejected.
    kwarg('truncation_psi')
    kwarg('truncation_cutoff')
    kwarg('style_mixing_prob')
    kwarg('structure')
    kwarg('impl')
    # Sorted so the reported kwarg is deterministic across runs.
    unknown_kwargs = sorted(set(tf_kwargs.keys()) - known_kwargs)
    if unknown_kwargs:
        raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])

    # Collect params.
    tf_params = _collect_tf_params(tf_G)
    for name, value in list(tf_params.items()):
        match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
        if match:
            # Re-key per-LOD ToRGB variables by resolution and switch the
            # synthesis network to the 'orig' architecture.
            # NOTE(review): the alias lacks the 'synthesis/' prefix used by the
            # ToRGB lookups below, and fullmatch() only matches unprefixed
            # names -- confirm against _collect_tf_params' naming scheme.
            r = kwargs.img_resolution // (2**int(match.group(1)))
            tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
            kwargs.synthesis_kwargs.architecture = 'orig'
    #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')

    # The PyTorch channel_base is twice the TF fmap_base. If fmap_base was not
    # recorded, derive channel_base from the last dimension of the top-level
    # Conv1 weight times img_resolution, or fall back to 32768.
    # NOTE(review): the derivation assumes the top-level layer's width equals
    # channel_base / img_resolution (i.e. not clamped by channel_max).
    synthesis_kwargs = kwargs.synthesis_kwargs
    if synthesis_kwargs.channel_base is None:
        res = kwargs.img_resolution
        top_level_weight = tf_params.get(f'synthesis/{res}x{res}/Conv1/weight')
        if top_level_weight is not None:
            synthesis_kwargs.channel_base = top_level_weight.shape[-1] * res
        else:
            synthesis_kwargs.channel_base = 32768
    else:
        synthesis_kwargs.channel_base *= 2

    # Convert params.
    from training import networks
    G = networks.Generator(**kwargs).eval().requires_grad_(False)

    # Each (regex, value_fn) pair maps PyTorch parameter/buffer names to TF
    # variables; the lambdas take one argument per regex capture group.
    # Weights are transposed (presumably TF [kh, kw, in, out] -> [out, in, kh,
    # kw]), mod_bias is offset by +1, and the Conv0_up/Skip kernels are also
    # flipped spatially. TF noise buffers are numbered consecutively: noise0
    # for b4.conv1, then 2*log2(r)-5 / 2*log2(r)-4 for conv0 / conv1 of block r.
    # Patterns paired with None appear to be left untouched -- see
    # _populate_module_params.
    # pylint: disable=unnecessary-lambda
    _populate_module_params(
        G,
        r'mapping\.w_avg',
        lambda: tf_params['dlatent_avg'],
        r'mapping\.embed\.weight',
        lambda: tf_params['mapping/LabelEmbed/weight'].transpose(),
        r'mapping\.embed\.bias',
        lambda: tf_params['mapping/LabelEmbed/bias'],
        r'mapping\.fc(\d+)\.weight',
        lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(),
        r'mapping\.fc(\d+)\.bias',
        lambda i: tf_params[f'mapping/Dense{i}/bias'],
        r'synthesis\.b4\.const',
        lambda: tf_params['synthesis/4x4/Const/const'][0],
        r'synthesis\.b4\.conv1\.weight',
        lambda: tf_params['synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b4\.conv1\.bias',
        lambda: tf_params['synthesis/4x4/Conv/bias'],
        r'synthesis\.b4\.conv1\.noise_const',
        lambda: tf_params['synthesis/noise0'][0, 0],
        r'synthesis\.b4\.conv1\.noise_strength',
        lambda: tf_params['synthesis/4x4/Conv/noise_strength'],
        r'synthesis\.b4\.conv1\.affine\.weight',
        lambda: tf_params['synthesis/4x4/Conv/mod_weight'].transpose(),
        r'synthesis\.b4\.conv1\.affine\.bias',
        lambda: tf_params['synthesis/4x4/Conv/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.conv0\.weight',
        lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.conv0\.bias',
        lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
        r'synthesis\.b(\d+)\.conv0\.noise_const',
        lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0],
        r'synthesis\.b(\d+)\.conv0\.noise_strength',
        lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
        r'synthesis\.b(\d+)\.conv0\.affine\.weight',
        lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.conv0\.affine\.bias',
        lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.conv1\.weight',
        lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.conv1\.bias',
        lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
        r'synthesis\.b(\d+)\.conv1\.noise_const',
        lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0],
        r'synthesis\.b(\d+)\.conv1\.noise_strength',
        lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
        r'synthesis\.b(\d+)\.conv1\.affine\.weight',
        lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.conv1\.affine\.bias',
        lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.torgb\.weight',
        lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.torgb\.bias',
        lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
        r'synthesis\.b(\d+)\.torgb\.affine\.weight',
        lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.torgb\.affine\.bias',
        lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.skip\.weight',
        lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
        r'.*\.resample_filter',
        None,
    )
    return G
コード例 #2
0
ファイル: loader.py プロジェクト: dorarad/gansformer
def convert_tf_generator(tf_G):
    """Convert a legacy TensorFlow generator (optionally with attention /
    transformer layers) into a PyTorch ``training.networks.Generator``.

    Hyperparameters are read from ``tf_G.static_kwargs`` (several accept either
    the TF or the PyTorch-style key name). Every parameter/buffer of the new
    module is then filled from the TF variables via ``_populate_module_params``.

    Args:
        tf_G: Unpickled TensorFlow network. Must provide ``version`` and
            ``static_kwargs``, plus whatever ``_collect_tf_params`` reads.

    Returns:
        The converted generator, in eval mode with gradients disabled.

    Raises:
        ValueError: If ``tf_G.version`` is below 4.
    """
    if tf_G.version < 4:
        raise ValueError("TensorFlow pickle version too low")

    # Collect kwargs
    tf_kwargs = tf_G.static_kwargs

    def kwarg(tf_names, default=None, none=None):
        # Resolve a kwarg stored under any of `tf_names` (a name or a list of
        # names). When several are present, the last one in the list wins.
        # Falls back to `default` when none is present; a resolved value of
        # None is replaced by `none`.
        if not isinstance(tf_names, list):
            tf_names = [tf_names]
        val = default
        for tf_name in tf_names:
            if tf_name in tf_kwargs:
                val = tf_kwargs[tf_name]
        return val if val is not None else none

    # Convert kwargs
    kwargs = dnnlib.EasyDict(
        z_dim=kwarg(["latent_size", "z_dim"], 512),
        c_dim=kwarg(["label_size", "c_dim"], 0),
        w_dim=kwarg(["dlatent_size", "w_dim"], 512),
        # One extra latent component when the transformer is enabled.
        k=kwarg("components_num", 1) + int(tf_kwargs.get("transformer", False)),
        img_resolution=kwarg("resolution", 1024),
        img_channels=kwarg("num_channels", 3),
        mapping_kwargs=dnnlib.EasyDict(
            num_layers=kwarg(["mapping_layersnum", "mapping_num_layers"], 8),
            layer_dim=kwarg("mapping_dim", None),
            act=kwarg("mapping_nonlinearity", "lrelu"),
            lrmul=kwarg("mapping_lrmul", 0.01),
            w_avg_beta=kwarg(["dlatent_avg_beta", "w_avg_beta"], 0.995, none=1),
            resnet=kwarg("mapping_resnet", False),
            ltnt2ltnt=kwarg("mapping_ltnt2ltnt", False),
            transformer=kwarg("transformer", False),
            num_heads=kwarg("num_heads", 1),
            attention_dropout=kwarg("attention_dropout", 0.12),
            ltnt_gate=kwarg("ltnt_gate", False),
            use_pos=kwarg("use_pos", False),
            normalize_global=False,
        ),
        synthesis_kwargs=dnnlib.EasyDict(
            crop_ratio=kwarg("crop_ratio", None),
            channel_base=kwarg(["fmap_base", "channel_base"], 16 << 10) * 2,
            channel_max=kwarg(["fmap_max", "channel_max"], 512),
            architecture=kwarg("architecture", "skip"),
            resample_kernel=kwarg("resample_kernel", [1, 3, 3, 1]),
            local_noise=kwarg("local_noise", True),
            act=kwarg("nonlinearity", "lrelu"),
            ltnt_stem=kwarg(["latent_stem", "ltnt_stem"], False),
            style=kwarg("style", True),
            transformer=kwarg("transformer", False),
            start_res=kwarg("start_res", 0),
            end_res=kwarg("end_res", 8),
            num_heads=kwarg("num_heads", 1),
            attention_dropout=kwarg("attention_dropout", 0.12),
            ltnt_gate=kwarg("ltnt_gate", False),
            img_gate=kwarg("img_gate", False),
            integration=kwarg("integration", "add"),
            norm=kwarg("norm", None),
            kmeans=kwarg("kmeans", False),
            kmeans_iters=kwarg("kmeans_iters", 1),
            iterative=kwarg("iterative", False),
            use_pos=kwarg("use_pos", False),
            pos_dim=kwarg("pos_dim", None),
            pos_type=kwarg("pos_type", "sinus"),
            pos_init=kwarg("pos_init", "uniform"),
            pos_directions_num=kwarg("pos_directions_num", 2),
        ),
    )

    # Collect params
    tf_params = _collect_tf_params(tf_G)
    for name, value in list(tf_params.items()):
        match = re.fullmatch(r"ToRGB_lod(\d+)/(.*)", name)
        if match:
            # Re-key per-LOD ToRGB variables by resolution and switch the
            # synthesis network to the "orig" architecture.
            # NOTE(review): the alias lacks the "synthesis/" prefix used by the
            # ToRGB lookups below -- confirm against _collect_tf_params.
            r = kwargs.img_resolution // (2**int(match.group(1)))
            tf_params[f"{r}x{r}/ToRGB/{match.group(2)}"] = value
            kwargs.synthesis_kwargs.architecture = "orig"
    # for name, value in tf_params.items(): print(f"{name:<50s}{list(value.shape)}")

    # Convert params
    from training import networks
    G = networks.Generator(**kwargs).eval().requires_grad_(False)

    def index(r, i):
        # TF name suffix for conv `i` of block `r`: the 4x4 block has a single
        # unsuffixed "Conv"; other blocks use "Conv0_up" and "Conv1".
        return "" if int(r) == 4 else f"{i}{['_up', ''][int(i)]}"

    def singular(s):
        # Map PyTorch `to_<s>` projection names to TF variable suffixes.
        return {"queries": "query", "keys": "key", "values": "value"}[s]

    def global_fix(s):
        # Sub-networks whose name contains "global" live under "global/" in TF.
        return "global/" if "global" in s else ""

    # Size of the all-zero label-embedding bias. Reuse the z_dim the Generator
    # was actually built with so the shapes always agree.
    z_dim = kwargs.z_dim

    # Each (regex, value_fn) pair maps PyTorch parameter/buffer names to TF
    # variables; the lambdas take one argument per regex capture group, and
    # the order of the pairs matters where patterns overlap. Patterns paired
    # with None appear to be left untouched -- see _populate_module_params.
    _populate_module_params(
        G,
        r"pos",
        lambda: tf_params["ltnt_emb/emb"],
        # Mapping network
        r"mapping\.w_avg",
        lambda: tf_params["dlatent_avg"],
        r"mapping\.embed\.weight",
        lambda: tf_params["mapping/LabelConcat/weight"].transpose(),
        r"mapping\.embed\.bias",
        lambda: np.zeros([z_dim]),
        r"mapping\.([a-z_]+)\.l(\d+)\.fc(\d+)\.weight",
        lambda s, i, j: tf_params[f"mapping/{global_fix(s)}Dense{i}_{j}/weight"].transpose(),
        r"mapping\.([a-z_]+)\.l(\d+)\.fc(\d+)\.bias",
        lambda s, i, j: tf_params[f"mapping/{global_fix(s)}Dense{i}_{j}/bias"],
        r"mapping\.([a-z_]+)\.out_layer\.weight",
        lambda s: tf_params[f"mapping/{global_fix(s)}Dense3/weight"].transpose(),
        r"mapping\.([a-z_]+)\.out_layer\.bias",
        lambda s: tf_params[f"mapping/{global_fix(s)}Dense3/bias"],
        # NOTE: the explicit "mlp" patterns below are shadowed by the generic
        # ([a-z_]+) patterns above, which resolve to the same TF variables.
        r"mapping\.mlp\.l(\d+)\.fc(\d+)\.weight",
        lambda i, j: tf_params[f"mapping/Dense{i}_{j}/weight"].transpose(),
        r"mapping\.mlp\.l(\d+)\.fc(\d+)\.bias",
        lambda i, j: tf_params[f"mapping/Dense{i}_{j}/bias"],
        r"mapping\.mlp\.out_layer\.weight",
        lambda: tf_params["mapping/Dense3/weight"].transpose(),
        r"mapping\.mlp\.out_layer\.bias",
        lambda: tf_params["mapping/Dense3/bias"],
        # Mapping ltnt2ltnt
        r"mapping\.mlp\.sa(\d+)\.to_([a-z]+)\.weight",
        lambda i, s: tf_params[f"mapping/AttLayer_{i}/weight_{singular(s)}"].transpose(),
        r"mapping\.mlp\.sa(\d+)\.to_([a-z]+)\.bias",
        lambda i, s: tf_params[f"mapping/AttLayer_{i}/bias_{singular(s)}"],
        r"mapping\.mlp\.sa(\d+)\.([a-z]+)_pos_map\.weight",
        lambda i, s: tf_params[f"mapping/AttLayer_{i}/weight_{s}_pos"].transpose(),
        r"mapping\.mlp\.sa(\d+)\.([a-z]+)_pos_map\.bias",
        lambda i, s: tf_params[f"mapping/AttLayer_{i}/bias_{s}_pos"],
        r"mapping\.mlp\.sa(\d+)\.modulation\.weight",
        lambda i: tf_params[f"mapping/AttLayer_{i}/weight_out"].transpose(),
        r"mapping\.mlp\.sa(\d+)\.modulation\.bias",
        lambda i: tf_params[f"mapping/AttLayer_{i}/bias_out"],
        r"mapping\.mlp\.sa(\d+)\.centroids",
        lambda i: tf_params[f"mapping/AttLayer_{i}/toasgn_init"],
        # Weight and bias need distinct patterns: with identical patterns the
        # first always wins and the bias mapping could never be reached.
        r"mapping\.mlp\.sa(\d+)\.queries2centroids\.weight",
        lambda i: tf_params[f"mapping/AttLayer_{i}/weight_key2"].transpose(),
        r"mapping\.mlp\.sa(\d+)\.queries2centroids\.bias",
        lambda i: tf_params[f"mapping/AttLayer_{i}/bias_key2"],
        r"mapping\.mlp\.sa(\d+)\.att_weight",
        lambda i: tf_params[f"mapping/AttLayer_{i}/iter_0/st_weights"],
        # Synthesis Network
        r"synthesis\.b4\.const",
        lambda: tf_params["synthesis/4x4/Const/const"][0],
        r"synthesis\.b(\d+)\.conv0\.weight",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv0_up/weight"][::-1, ::-1].transpose(3, 2, 0, 1),
        r"synthesis\.b(\d+)\.conv1\.weight",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv{index(r,1)}/weight"].transpose(3, 2, 0, 1),
        r"synthesis\.b(\d+)\.conv(\d+)\.biasAct\.bias",
        lambda r, i: tf_params[f"synthesis/{r}x{r}/Conv{index(r,i)}/bias"],
        # TF noise buffers are numbered consecutively: 2*log2(r)-5+i.
        r"synthesis\.b(\d+)\.conv(\d+)\.noise_const",
        lambda r, i: tf_params[f"synthesis/noise{int(np.log2(int(r)))*2-5+int(i)}"][0, 0],
        r"synthesis\.b(\d+)\.conv(\d+)\.noise_strength",
        lambda r, i: tf_params[f"synthesis/{r}x{r}/Conv{index(r,i)}/noise_strength"],
        r"synthesis\.b(\d+)\.conv(\d+)\.affine\.weight",
        lambda r, i: tf_params[f"synthesis/{r}x{r}/Conv{index(r,i)}/mod_weight"].transpose(),
        r"synthesis\.b(\d+)\.conv(\d+)\.affine\.bias",
        lambda r, i: tf_params[f"synthesis/{r}x{r}/Conv{index(r,i)}/mod_bias"] + 1,
        # Synthesis Network: Latents to Image
        r"synthesis\.b(\d+)\.conv(\d+)\.transformer\.to_([a-z]+)\.weight",
        lambda r, i, s: tf_params[f"synthesis/{r}x{r}/Conv{index(r,i)}/AttLayer_l2n/weight_{singular(s)}"].transpose(),
        r"synthesis\.b(\d+)\.conv(\d+)\.transformer\.to_([a-z]+)\.bias",
        lambda r, i, s: tf_params[f"synthesis/{r}x{r}/Conv{index(r,i)}/AttLayer_l2n/bias_{singular(s)}"],
        r"synthesis\.b(\d+)\.conv(\d+)\.transformer\.([a-z]+)_pos_map\.weight",
        lambda r, i, s: tf_params[f"synthesis/{r}x{r}/Conv{index(r,i)}/AttLayer_l2n/weight_{s}_pos"].transpose(),
        r"synthesis\.b(\d+)\.conv(\d+)\.transformer\.([a-z]+)_pos_map\.bias",
        lambda r, i, s: tf_params[f"synthesis/{r}x{r}/Conv{index(r,i)}/AttLayer_l2n/bias_{s}_pos"],
        r"synthesis\.b(\d+)\.conv(\d+)\.transformer\.modulation\.weight",
        lambda r, i: tf_params[f"synthesis/{r}x{r}/Conv{index(r,i)}/AttLayer_l2n/weight_out"].transpose(),
        r"synthesis\.b(\d+)\.conv(\d+)\.transformer\.modulation\.bias",
        lambda r, i: tf_params[f"synthesis/{r}x{r}/Conv{index(r,i)}/AttLayer_l2n/bias_out"],
        r"synthesis\.b(\d+)\.conv(\d+)\.transformer\.centroids",
        lambda r, i: tf_params[f"synthesis/{r}x{r}/Conv{index(r,i)}/AttLayer_l2n/toasgn_init"],
        r"synthesis\.b(\d+)\.conv(\d+)\.transformer\.queries2centroids\.weight",
        lambda r, i: tf_params[f"synthesis/{r}x{r}/Conv{index(r,i)}/AttLayer_l2n/weight_key2"].transpose(),
        r"synthesis\.b(\d+)\.conv(\d+)\.transformer\.queries2centroids\.bias",
        lambda r, i: tf_params[f"synthesis/{r}x{r}/Conv{index(r,i)}/AttLayer_l2n/bias_key2"],
        r"synthesis\.b(\d+)\.conv(\d+)\.transformer\.att_weight",
        lambda r, i: tf_params[f"synthesis/{r}x{r}/Conv{index(r,i)}/AttLayer_l2n/iter_0/st_weights"],
        # Synthesis Network's RGB layer
        r"synthesis\.b(\d+)\.torgb\.weight",
        lambda r: tf_params[f"synthesis/{r}x{r}/ToRGB/weight"].transpose(3, 2, 0, 1),
        r"synthesis\.b(\d+)\.torgb\.biasAct\.bias",
        lambda r: tf_params[f"synthesis/{r}x{r}/ToRGB/bias"],
        r"synthesis\.b(\d+)\.torgb\.affine\.weight",
        lambda r: tf_params[f"synthesis/{r}x{r}/ToRGB/mod_weight"].transpose(),
        r"synthesis\.b(\d+)\.torgb\.affine\.bias",
        lambda r: tf_params[f"synthesis/{r}x{r}/ToRGB/mod_bias"] + 1,
        r"synthesis\.b(\d+)\.skip\.weight",
        lambda r: tf_params[f"synthesis/{r}x{r}/Skip/weight"][::-1, ::-1].transpose(3, 2, 0, 1),
        r"synthesis\.b256\.conv_last\.weight",
        lambda: tf_params["synthesis/256x256/ToRGB/extraLayer/weight"].transpose(3, 2, 0, 1),
        r"synthesis\.b256\.conv_last\.affine\.weight",
        lambda: tf_params["synthesis/256x256/ToRGB/extraLayer/mod_weight"].transpose(),
        r"synthesis\.b256\.conv_last\.affine\.bias",
        lambda: tf_params["synthesis/256x256/ToRGB/extraLayer/mod_bias"] + 1,
        r".*\.resample_kernel",
        None,
        r".*\.grid_pos",
        None,
    )
    return G
コード例 #3
0
def convert_tf_generator(tf_G):
    """Convert a legacy TensorFlow generator network into a PyTorch Generator.

    The TF network's ``static_kwargs`` are translated into constructor kwargs
    for ``training.networks.Generator``; every parameter/buffer of the new
    module is then filled from the TF variables via ``_populate_module_params``.

    Args:
        tf_G: Unpickled TensorFlow network. Must provide ``version`` and
            ``static_kwargs``, plus whatever ``_collect_tf_params`` reads.

    Returns:
        The converted generator, in eval mode with gradients disabled.

    Raises:
        ValueError: If ``tf_G.version`` is below 4, or if ``static_kwargs``
            contains a key this converter does not recognize.
    """
    if tf_G.version < 4:
        raise ValueError("TensorFlow pickle version too low")

    # Collect kwargs. Every name looked up through kwarg() is recorded in
    # known_kwargs so that unrecognized TF kwargs can be rejected below.
    tf_kwargs = tf_G.static_kwargs
    known_kwargs = set()

    def kwarg(tf_name, default=None, none=None):
        # Return tf_kwargs[tf_name] (or `default` when absent); a resolved
        # value of None is replaced by `none`.
        known_kwargs.add(tf_name)
        val = tf_kwargs.get(tf_name, default)
        return val if val is not None else none

    # Convert kwargs. The PyTorch channel_base is twice the TF fmap_base.
    kwargs = dnnlib.EasyDict(
        z_dim=kwarg("latent_size", 512),
        c_dim=kwarg("label_size", 0),
        w_dim=kwarg("dlatent_size", 512),
        img_resolution=kwarg("resolution", 1024),
        img_channels=kwarg("num_channels", 3),
        mapping_kwargs=dnnlib.EasyDict(
            num_layers=kwarg("mapping_layers", 8),
            embed_features=kwarg("label_fmaps", None),
            layer_features=kwarg("mapping_fmaps", None),
            activation=kwarg("mapping_nonlinearity", "lrelu"),
            lr_multiplier=kwarg("mapping_lrmul", 0.01),
            w_avg_beta=kwarg("w_avg_beta", 0.995, none=1),
        ),
        synthesis_kwargs=dnnlib.EasyDict(
            channel_base=kwarg("fmap_base", 16384) * 2,
            channel_max=kwarg("fmap_max", 512),
            num_fp16_res=kwarg("num_fp16_res", 0),
            conv_clamp=kwarg("conv_clamp", None),
            architecture=kwarg("architecture", "skip"),
            resample_filter=kwarg("resample_kernel", [1, 3, 3, 1]),
            use_noise=kwarg("use_noise", True),
            activation=kwarg("nonlinearity", "lrelu"),
        ),
    )

    # Check for unknown kwargs. The following TF kwargs are accepted but not
    # used by the conversion; registering them keeps them from being rejected.
    kwarg("truncation_psi")
    kwarg("truncation_cutoff")
    kwarg("style_mixing_prob")
    kwarg("structure")
    # Sorted so the reported kwarg is deterministic across runs.
    unknown_kwargs = sorted(set(tf_kwargs.keys()) - known_kwargs)
    if unknown_kwargs:
        raise ValueError("Unknown TensorFlow kwarg", unknown_kwargs[0])

    # Collect params.
    tf_params = _collect_tf_params(tf_G)
    for name, value in list(tf_params.items()):
        match = re.fullmatch(r"ToRGB_lod(\d+)/(.*)", name)
        if match:
            # Re-key per-LOD ToRGB variables by resolution and switch the
            # synthesis network to the "orig" architecture.
            # NOTE(review): the alias lacks the "synthesis/" prefix used by the
            # ToRGB lookups below, and fullmatch() only matches unprefixed
            # names -- confirm against _collect_tf_params' naming scheme.
            r = kwargs.img_resolution // (2**int(match.group(1)))
            tf_params[f"{r}x{r}/ToRGB/{match.group(2)}"] = value
            kwargs.synthesis_kwargs.architecture = "orig"
    # for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')

    # Convert params.
    from training import networks

    G = networks.Generator(**kwargs).eval().requires_grad_(False)

    # Each (regex, value_fn) pair maps PyTorch parameter/buffer names to TF
    # variables; the lambdas take one argument per regex capture group.
    # Weights are transposed (presumably TF [kh, kw, in, out] -> [out, in, kh,
    # kw]), mod_bias is offset by +1, and the Conv0_up/Skip kernels are also
    # flipped spatially. TF noise buffers are numbered consecutively: noise0
    # for b4.conv1, then 2*log2(r)-5 / 2*log2(r)-4 for conv0 / conv1 of block r.
    # Patterns paired with None appear to be left untouched -- see
    # _populate_module_params.
    # pylint: disable=unnecessary-lambda
    _populate_module_params(
        G,
        r"mapping\.w_avg",
        lambda: tf_params["dlatent_avg"],
        r"mapping\.embed\.weight",
        lambda: tf_params["mapping/LabelEmbed/weight"].transpose(),
        r"mapping\.embed\.bias",
        lambda: tf_params["mapping/LabelEmbed/bias"],
        r"mapping\.fc(\d+)\.weight",
        lambda i: tf_params[f"mapping/Dense{i}/weight"].transpose(),
        r"mapping\.fc(\d+)\.bias",
        lambda i: tf_params[f"mapping/Dense{i}/bias"],
        r"synthesis\.b4\.const",
        lambda: tf_params["synthesis/4x4/Const/const"][0],
        r"synthesis\.b4\.conv1\.weight",
        lambda: tf_params["synthesis/4x4/Conv/weight"].transpose(3, 2, 0, 1),
        r"synthesis\.b4\.conv1\.bias",
        lambda: tf_params["synthesis/4x4/Conv/bias"],
        r"synthesis\.b4\.conv1\.noise_const",
        lambda: tf_params["synthesis/noise0"][0, 0],
        r"synthesis\.b4\.conv1\.noise_strength",
        lambda: tf_params["synthesis/4x4/Conv/noise_strength"],
        r"synthesis\.b4\.conv1\.affine\.weight",
        lambda: tf_params["synthesis/4x4/Conv/mod_weight"].transpose(),
        r"synthesis\.b4\.conv1\.affine\.bias",
        lambda: tf_params["synthesis/4x4/Conv/mod_bias"] + 1,
        r"synthesis\.b(\d+)\.conv0\.weight",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv0_up/weight"][::-1, ::-1].transpose(3, 2, 0, 1),
        r"synthesis\.b(\d+)\.conv0\.bias",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv0_up/bias"],
        r"synthesis\.b(\d+)\.conv0\.noise_const",
        lambda r: tf_params[f"synthesis/noise{int(np.log2(int(r)))*2-5}"][0, 0],
        r"synthesis\.b(\d+)\.conv0\.noise_strength",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv0_up/noise_strength"],
        r"synthesis\.b(\d+)\.conv0\.affine\.weight",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv0_up/mod_weight"].transpose(),
        r"synthesis\.b(\d+)\.conv0\.affine\.bias",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv0_up/mod_bias"] + 1,
        r"synthesis\.b(\d+)\.conv1\.weight",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv1/weight"].transpose(3, 2, 0, 1),
        r"synthesis\.b(\d+)\.conv1\.bias",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv1/bias"],
        r"synthesis\.b(\d+)\.conv1\.noise_const",
        lambda r: tf_params[f"synthesis/noise{int(np.log2(int(r)))*2-4}"][0, 0],
        r"synthesis\.b(\d+)\.conv1\.noise_strength",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv1/noise_strength"],
        r"synthesis\.b(\d+)\.conv1\.affine\.weight",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv1/mod_weight"].transpose(),
        r"synthesis\.b(\d+)\.conv1\.affine\.bias",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv1/mod_bias"] + 1,
        r"synthesis\.b(\d+)\.torgb\.weight",
        lambda r: tf_params[f"synthesis/{r}x{r}/ToRGB/weight"].transpose(3, 2, 0, 1),
        r"synthesis\.b(\d+)\.torgb\.bias",
        lambda r: tf_params[f"synthesis/{r}x{r}/ToRGB/bias"],
        r"synthesis\.b(\d+)\.torgb\.affine\.weight",
        lambda r: tf_params[f"synthesis/{r}x{r}/ToRGB/mod_weight"].transpose(),
        r"synthesis\.b(\d+)\.torgb\.affine\.bias",
        lambda r: tf_params[f"synthesis/{r}x{r}/ToRGB/mod_bias"] + 1,
        r"synthesis\.b(\d+)\.skip\.weight",
        lambda r: tf_params[f"synthesis/{r}x{r}/Skip/weight"][::-1, ::-1].transpose(3, 2, 0, 1),
        r".*\.resample_filter",
        None,
    )
    return G