# Example #1 (orig. "Beispiel #1")
# 0
def run(dataset,
        data_dir,
        result_dir,
        config_id,
        num_gpus,
        total_kimg,
        gamma,
        mirror_augment,
        metrics,
        resume_G_pkl=None,
        n_batch=2,
        n_batch_per_gpu=1,
        D_global_size=0,
        C_global_size=10,
        model_type='hd_dis_model',
        latent_type='uniform',
        resume_pkl=None,
        n_samples_per=4,
        D_lambda=0,
        C_lambda=1,
        epsilon_in_loss=3,
        random_eps=True,
        M_lrmul=0.1,
        resolution_manual=1024,
        pretrained_type='with_stylegan2',
        traj_lambda=None,
        level_I_kimg=1000,
        use_level_training=False,
        resume_kimg=0,
        use_std_in_m=False,
        prior_latent_size=512,
        stylegan2_dlatent_size=512,
        stylegan2_mapping_fmaps=512,
        M_mapping_fmaps=512,
        hyperplane_lambda=1,
        hyperdir_lambda=1):
    """Build the hdwG-disentanglement training configuration and submit it.

    Assembles EasyDict option bundles for the generator (G), discriminator
    (D), navigator/mapping net (M) and inference net (I) according to
    ``model_type`` and ``config_id``, then launches
    ``training.training_loop_hdwG.training_loop_hdwG`` via
    ``dnnlib.submit_run``.

    NOTE(review): most keyword arguments are forwarded verbatim into the
    network/loss option dicts; their exact semantics live in the referenced
    ``training.*`` modules, which are not visible in this file.
    """
    # Training-loop entry point.
    train = EasyDict(
        run_func_name='training.training_loop_hdwG.training_loop_hdwG')
    # Generator: StyleGAN2 backbone with style mixing, dlatent averaging and
    # truncation all disabled (the M network supplies the latents instead).
    G = EasyDict(func_name='training.hd_networks_stylegan2.G_main',
                 latent_size=prior_latent_size,
                 dlatent_size=stylegan2_dlatent_size,
                 mapping_fmaps=stylegan2_mapping_fmaps,
                 mapping_lrmul=M_lrmul,
                 style_mixing_prob=None,
                 dlatent_avg_beta=None,
                 truncation_psi=None,
                 normalize_latents=False,
                 structure='fixed')
    D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2')
    # Select the M (latent navigator) and I (inference) nets per model type.
    if model_type == 'hd_hyperplane':
        M = EasyDict(func_name='training.hd_networks.net_M_hyperplane',
                     C_global_size=C_global_size,
                     D_global_size=D_global_size,
                     latent_size=prior_latent_size,
                     mapping_lrmul=M_lrmul,
                     use_std_in_m=use_std_in_m)
        I = EasyDict(func_name='training.hd_networks.net_I',
                     C_global_size=C_global_size,
                     D_global_size=D_global_size)
    elif model_type == 'vc_gan_preprior':
        M = EasyDict(func_name='training.hd_networks.net_M_vc',
                     C_global_size=C_global_size,
                     D_global_size=D_global_size,
                     latent_size=prior_latent_size,
                     mapping_lrmul=M_lrmul,
                     use_std_in_m=use_std_in_m)
        I = EasyDict(func_name='training.hd_networks.net_I',
                     C_global_size=C_global_size,
                     D_global_size=D_global_size)
    elif model_type == 'vc_gan':
        # 'vc_gan' uses an empty (pass-through) M net and reroutes G's
        # mapping function so the disentangled code feeds dlatents directly.
        M = EasyDict(func_name='training.hd_networks.net_M_empty',
                     C_global_size=C_global_size,
                     D_global_size=D_global_size,
                     latent_size=prior_latent_size,
                     mapping_lrmul=M_lrmul,
                     use_std_in_m=use_std_in_m)
        I = EasyDict(func_name='training.hd_networks.net_I',
                     C_global_size=C_global_size,
                     D_global_size=D_global_size)
        G.mapping_func = 'G_mapping_hd_dis_to_dlatent'
    else:
        # Default ('hd_dis_model' and anything else): full M net with its
        # own mapping fmaps.
        M = EasyDict(func_name='training.hd_networks.net_M',
                     C_global_size=C_global_size,
                     D_global_size=D_global_size,
                     latent_size=prior_latent_size,
                     mapping_fmaps=M_mapping_fmaps,
                     mapping_lrmul=M_lrmul,
                     use_std_in_m=use_std_in_m)
        I = EasyDict(func_name='training.hd_networks.net_I',
                     C_global_size=C_global_size,
                     D_global_size=D_global_size)
    # Optional auxiliary classifier head, only for the *_with_cls variant.
    if model_type == 'hd_dis_model_with_cls':
        I_info = EasyDict(func_name='training.hd_networks.net_I_info',
                          C_global_size=C_global_size,
                          D_global_size=D_global_size)
    else:
        I_info = EasyDict()
    # Adam optimizer settings for I and D.
    I_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)
    D_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)
    # Loss selection: 'vc_gan' gets the VC loss, everything else the
    # hyperplane loss. Both receive the same keyword bundle.
    if model_type == 'vc_gan':
        I_loss = EasyDict(func_name='training.loss_hdwG.IandG_vc_loss',
                          latent_type=latent_type,
                          D_global_size=D_global_size,
                          C_global_size=C_global_size,
                          D_lambda=D_lambda,
                          C_lambda=C_lambda,
                          epsilon=epsilon_in_loss,
                          random_eps=random_eps,
                          traj_lambda=traj_lambda,
                          resolution_manual=resolution_manual,
                          use_std_in_m=use_std_in_m,
                          model_type=model_type,
                          hyperplane_lambda=hyperplane_lambda,
                          prior_latent_size=prior_latent_size,
                          hyperdir_lambda=hyperdir_lambda)
    else:
        I_loss = EasyDict(
            func_name='training.loss_hdwG.IandMandG_hyperplane_loss',
            latent_type=latent_type,
            D_global_size=D_global_size,
            C_global_size=C_global_size,
            D_lambda=D_lambda,
            C_lambda=C_lambda,
            epsilon=epsilon_in_loss,
            random_eps=random_eps,
            traj_lambda=traj_lambda,
            resolution_manual=resolution_manual,
            use_std_in_m=use_std_in_m,
            model_type=model_type,
            hyperplane_lambda=hyperplane_lambda,
            prior_latent_size=prior_latent_size,
            hyperdir_lambda=hyperdir_lambda)
    D_loss = EasyDict(func_name='training.loss.D_logistic_r1')
    sched = EasyDict()
    grid = EasyDict(size='1080p', layout='random')
    sc = dnnlib.SubmitConfig()
    tf_config = {'rnd.np_random_seed': 1000}

    train.data_dir = data_dir
    train.total_kimg = total_kimg
    train.mirror_augment = mirror_augment
    train.image_snapshot_ticks = train.network_snapshot_ticks = 10
    sched.G_lrate_base = sched.D_lrate_base = 0.002
    sched.minibatch_size_base = n_batch
    sched.minibatch_gpu_base = n_batch_per_gpu
    # Default R1 gamma; may be overridden by config-e below and by the
    # explicit `gamma` argument at the very end.
    D_loss.gamma = 10
    # Resolve metric names to their default argument bundles.
    metrics = [metric_defaults[x] for x in metrics]
    desc = 'hdwG_disentanglement'

    desc += '-' + dataset
    dataset_args = EasyDict(tfrecord_dir=dataset)

    assert num_gpus in [1, 2, 4, 8]
    sc.num_gpus = num_gpus
    desc += '-%dgpu' % num_gpus

    assert config_id in _valid_configs
    desc += '-' + config_id

    # Configs A-E: Shrink networks to match original StyleGAN.
    if config_id != 'config-f':
        # G.fmap_base = D.fmap_base = 8 << 10
        # Smaller fmap budgets for low resolutions, full-size otherwise.
        if resolution_manual <= 256:
            I.fmap_base = 2 << 8
            G.fmap_base = 2 << 10
            D.fmap_base = 2 << 8
        else:
            I.fmap_base = 8 << 10
            G.fmap_base = D.fmap_base = 8 << 10

    # Config E: Set gamma to 100 and override G & D architecture.
    if config_id.startswith('config-e'):
        D_loss.gamma = 100
        if 'Gorig' in config_id: G.architecture = 'orig'
        if 'Gskip' in config_id: G.architecture = 'skip'  # (default)
        if 'Gresnet' in config_id: G.architecture = 'resnet'
        if 'Dorig' in config_id: D.architecture = 'orig'
        if 'Dskip' in config_id: D.architecture = 'skip'
        if 'Dresnet' in config_id: D.architecture = 'resnet'  # (default)

    # Configs A-D: Enable progressive growing and switch to networks that support it.
    if config_id in ['config-a', 'config-b', 'config-c', 'config-d']:
        # sched.lod_initial_resolution = 8
        sched.G_lrate_base = sched.D_lrate_base = 0.002
        # sched.G_lrate_dict = sched.D_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
        sched.minibatch_size_base = n_batch  # (default)
        # sched.minibatch_size_dict = {8: 256, 16: 128, 32: 64, 64: 32}
        sched.minibatch_gpu_base = n_batch_per_gpu  # (default)
        # sched.minibatch_gpu_dict = {8: 32, 16: 16, 32: 8, 64: 4}
        # G.synthesis_func = 'hd_networks_stylegan2.G_synthesis_stylegan_revised'
        G.synthesis_func = 'G_synthesis_stylegan_revised_hd'
        # D.func_name = 'training.networks_stylegan2.D_stylegan'

    # Configs A-B: Disable lazy regularization.
    if config_id in ['config-a', 'config-b']:
        train.lazy_regularization = False

    # Config A: Switch to original StyleGAN networks.
    if config_id == 'config-a':
        G = EasyDict(func_name='training.networks_stylegan.G_style')
        D = EasyDict(func_name='training.networks_stylegan.D_basic')

    # Explicit gamma argument wins over every config default.
    if gamma is not None:
        D_loss.gamma = gamma

    # Package everything and hand off to the dnnlib submission machinery.
    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(G_args=G,
                  D_args=D,
                  I_args=I,
                  M_args=M,
                  I_opt_args=I_opt,
                  D_opt_args=D_opt,
                  I_loss_args=I_loss,
                  D_loss_args=D_loss,
                  resume_pkl=resume_pkl,
                  resume_G_pkl=resume_G_pkl)
    kwargs.update(dataset_args=dataset_args,
                  sched_args=sched,
                  grid_args=grid,
                  use_hd_with_cls=(model_type == 'hd_dis_model_with_cls'),
                  use_hyperplane=(model_type == 'hd_hyperplane'),
                  metric_arg_list=metrics,
                  tf_config=tf_config,
                  n_discrete=D_global_size,
                  n_continuous=C_global_size,
                  n_samples_per=n_samples_per,
                  resolution_manual=resolution_manual,
                  pretrained_type=pretrained_type,
                  level_I_kimg=level_I_kimg,
                  use_level_training=use_level_training,
                  resume_kimg=resume_kimg,
                  use_std_in_m=use_std_in_m,
                  prior_latent_size=prior_latent_size,
                  latent_type=latent_type)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = result_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
# Example #2 (orig. "Beispiel #2")
# 0
def run(dataset,
        data_dir,
        result_dir,
        num_gpus,
        total_kimg,
        gamma,
        mirror_augment,
        metrics,
        resume_pkl,
        I_fmap_base=8,
        G_fmap_base=8,
        D_fmap_base=9,
        fmap_decay=0.15,
        C_lambda=1,
        n_samples_per=10,
        module_list=None,
        model_type='gan_like',
        epsilon_loss=3,
        random_eps=False,
        latent_type='uniform',
        batch_size=32,
        batch_per_gpu=16,
        return_atts=False,
        random_seed=1000,
        module_I_list=None,
        module_D_list=None,
        fmap_min=16,
        fmap_max=512,
        G_nf_scale=4,
        norm_ord=2,
        topk_dims_to_show=20,
        learning_rate=0.002,
        avg_mv_for_I=False,
        use_cascade=False,
        cascade_alt_freq_k=1,
        post_trans_wh=16,
        post_trans_cnn_dim=128,
        dff=512,
        trans_rate=0.1,
        construct_feat_by_concat=False,
        ncut_maxval=5,
        post_trans_mat=16,
        group_recons_lambda=0,
        trans_dim=512,
        network_snapshot_ticks=10):
    """Build the tsfm (modular transformer-G) training config and submit it.

    Supports three model types: 'info_gan_like' (independent-branch
    InfoGAN), 'ps_sc_like' (PS-SC) and 'gan_like' (plain modular GAN).
    The generator is assembled from ``module_list``; G/I/D and their
    losses are packaged as EasyDicts and handed to ``dnnlib.submit_run``
    to launch ``training.training_loop_tsfm.training_loop_tsfm``.

    Raises:
        ValueError: if ``module_list`` is missing or ``model_type`` is
            not one of the supported values.

    NOTE(review): ``module_I_list`` / ``module_D_list`` are accepted but
    unused here (kept for signature compatibility with sibling configs).
    """
    train = EasyDict(
        run_func_name='training.training_loop_tsfm.training_loop_tsfm'
    )  # Options for training loop.

    # BUG FIX: the original silently skipped this parse when module_list was
    # None and later crashed with NameError on `count_dlatent_size`, which
    # every supported model type needs. Fail fast with a clear error instead.
    if module_list is None:
        raise ValueError(
            'module_list must be provided: the generator latent size is '
            'derived from it for every supported model_type.')
    module_list = _str_to_list(module_list)
    key_ls, size_ls, count_dlatent_size = split_module_names(module_list)

    # Generator options are identical across all three model types, so build
    # them once instead of duplicating the dict per branch.
    G = EasyDict(
        func_name='training.tsfm_G_nets.G_main_tsfm',
        synthesis_func='G_synthesis_modular_tsfm',
        fmap_min=fmap_min,
        fmap_max=fmap_max,
        fmap_decay=fmap_decay,
        latent_size=count_dlatent_size,
        dlatent_size=count_dlatent_size,
        module_list=module_list,
        use_noise=True,
        return_atts=return_atts,
        G_nf_scale=G_nf_scale,
        trans_dim=trans_dim,
        post_trans_wh=post_trans_wh,
        post_trans_cnn_dim=post_trans_cnn_dim,
        dff=dff,
        trans_rate=trans_rate,
        construct_feat_by_concat=construct_feat_by_concat,
        ncut_maxval=ncut_maxval,
        post_trans_mat=post_trans_mat,
    )  # Options for generator network.
    # Discriminator options are likewise shared by every model type.
    D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2',
                 fmap_min=fmap_min,
                 fmap_max=fmap_max)  # Options for discriminator network.

    # Only the inference head I (and the run description) differ per type.
    if model_type == 'info_gan_like':  # Independent branch version InfoGAN
        I = EasyDict(func_name='training.tsfm_I_nets.head_infogan2',
                     dlatent_size=count_dlatent_size,
                     fmap_min=fmap_min,
                     fmap_max=fmap_max)
        desc = 'info_gan_like'
    elif model_type == 'ps_sc_like':  # COMA-FAIN
        I = EasyDict(func_name='training.tsfm_I_nets.head_ps_sc',
                     dlatent_size=count_dlatent_size,
                     fmap_min=fmap_min,
                     fmap_max=fmap_max)
        desc = 'ps_sc_like'
    elif model_type == 'gan_like':  # Just modular GAN: no inference head.
        I = EasyDict()
        desc = 'gan_like'
    else:
        # (typo 'tyle' in the original message fixed)
        raise ValueError('Not supported model type: ' + model_type)

    G_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for generator optimizer.
    D_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for discriminator optimizer.

    # The discriminator loss is the same for all model types.
    D_loss = EasyDict(
        func_name='training.loss_tsfm.D_logistic_r1_shared',
        latent_type=latent_type)  # Options for discriminator loss.
    # model_type was validated above, so exactly one branch applies.
    if model_type == 'info_gan_like':  # InfoGAN
        G_loss = EasyDict(
            func_name='training.loss_tsfm.G_logistic_ns_info_gan',
            C_lambda=C_lambda,
            latent_type=latent_type,
            norm_ord=norm_ord,
            group_recons_lambda=group_recons_lambda
        )  # Options for generator loss.
    elif model_type == 'ps_sc_like':  # PS-SC
        G_loss = EasyDict(
            func_name='training.loss_tsfm.G_logistic_ns_ps_sc',
            C_lambda=C_lambda,
            group_recons_lambda=group_recons_lambda,
            epsilon=epsilon_loss,
            random_eps=random_eps,
            latent_type=latent_type,
            use_cascade=use_cascade)  # Options for generator loss.
    else:  # 'gan_like': plain non-saturating GAN loss.
        G_loss = EasyDict(func_name='training.loss_tsfm.G_logistic_ns',
                          latent_type=latent_type,
                          group_recons_lambda=group_recons_lambda
                          )  # Options for generator loss.

    sched = EasyDict()  # Options for TrainingSchedule.
    grid = EasyDict(
        size='1080p',
        layout='random')  # Options for setup_snapshot_image_grid().
    sc = dnnlib.SubmitConfig()  # Options for dnnlib.submit_run().
    tf_config = {
        'rnd.np_random_seed': random_seed
    }  # Options for tflib.init_tf().

    train.data_dir = data_dir
    train.total_kimg = total_kimg
    train.mirror_augment = mirror_augment
    train.image_snapshot_ticks = train.network_snapshot_ticks = 10
    sched.G_lrate_base = sched.D_lrate_base = learning_rate
    sched.minibatch_size_base = batch_size
    sched.minibatch_gpu_base = batch_per_gpu
    # Resolve metric names to their default argument bundles.
    metrics = [metric_defaults[x] for x in metrics]

    desc += '-' + dataset
    dataset_args = EasyDict(tfrecord_dir=dataset, max_label_size='full')

    assert num_gpus in [1, 2, 4, 8]
    sc.num_gpus = num_gpus
    desc += '-%dgpu' % num_gpus

    # fmap budgets are supplied as exponents (fmap_base = 2 << exponent).
    I.fmap_base = 2 << I_fmap_base
    G.fmap_base = 2 << G_fmap_base
    D.fmap_base = 2 << D_fmap_base

    if gamma is not None:
        D_loss.gamma = gamma

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(
        G_args=G,
        D_args=D,
        I_args=I,
        G_opt_args=G_opt,
        D_opt_args=D_opt,
        G_loss_args=G_loss,
        D_loss_args=D_loss,
        use_info_gan=(
            model_type == 'info_gan_like'),  # Independent branch version
        use_ps_head=(model_type == 'ps_sc_like'),
        avg_mv_for_I=avg_mv_for_I,
        traversal_grid=True,
        return_atts=return_atts)

    # Continuous latent count: every latent module except D_global.
    # (module_list is guaranteed non-None here, so no guard is needed.)
    n_continuous = 0
    for i, key in enumerate(key_ls):
        m_name = key.split('-')[0]
        if (m_name in LATENT_MODULES) and (not m_name == 'D_global'):
            n_continuous += size_ls[i]

    kwargs.update(dataset_args=dataset_args,
                  sched_args=sched,
                  grid_args=grid,
                  metric_arg_list=metrics,
                  tf_config=tf_config,
                  resume_pkl=resume_pkl,
                  n_continuous=n_continuous,
                  n_samples_per=n_samples_per,
                  topk_dims_to_show=topk_dims_to_show,
                  cascade_alt_freq_k=cascade_alt_freq_k,
                  network_snapshot_ticks=network_snapshot_ticks)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = result_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
def run(result_dir,
        num_gpus,
        total_kimg,
        mirror_augment,
        metrics,
        resume_pkl,
        G_pkl,
        I_fmap_base=8,
        fmap_decay=0.15,
        n_samples_per=10,
        module_list=None,
        latent_type='uniform',
        batch_size=32,
        batch_per_gpu=16,
        random_seed=1000,
        fmap_min=16,
        fmap_max=512,
        dlatent_size=10,
        I_nf_scale=4,
        arch='resnet'):
    """Configure and launch training of a standalone inference network."""
    print('module_list:', module_list)

    # Entry point executed by the dnnlib run machinery.
    train = EasyDict(
        run_func_name='training.training_loop_infernet.training_loop_infernet'
    )

    # Inference-network architecture, built from the parsed module list.
    I = EasyDict(func_name='training.vc_networks2.infer_modular',
                 dlatent_size=dlatent_size,
                 fmap_min=fmap_min,
                 fmap_max=fmap_max,
                 module_list=_str_to_list(module_list),
                 I_nf_scale=I_nf_scale)

    # Adam settings for the I optimizer and the inference loss options.
    I_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)
    loss = EasyDict(func_name='training.loss_inference.I_loss',
                    latent_type=latent_type,
                    dlatent_size=dlatent_size)

    sched = EasyDict()                            # TrainingSchedule options.
    sc = dnnlib.SubmitConfig()                    # dnnlib.submit_run options.
    tf_config = {'rnd.np_random_seed': random_seed}  # tflib.init_tf options.

    # Fixed schedule values plus the caller-supplied batch sizes.
    train.total_kimg = total_kimg
    sched.lrate = 0.002
    sched.tick_kimg = 1
    sched.minibatch_size = batch_size
    sched.minibatch_gpu = batch_per_gpu
    metrics = [metric_defaults[name] for name in metrics]

    assert num_gpus in [1, 2, 4, 8]
    sc.num_gpus = num_gpus
    desc = 'inference_net' + ('-%dgpu' % num_gpus)

    # fmap budget is supplied as an exponent (fmap_base = 2 << exponent).
    I.fmap_base = 2 << I_fmap_base

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True

    # Bundle everything and submit the run.
    kwargs = EasyDict(train)
    kwargs.update(I_args=I, I_opt_args=I_opt, loss_args=loss)
    kwargs.update(sched_args=sched,
                  metric_arg_list=metrics,
                  tf_config=tf_config,
                  resume_pkl=resume_pkl,
                  G_pkl=G_pkl,
                  n_samples_per=n_samples_per)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = result_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
# Example #4 (orig. "Beispiel #4")
# 0
def generate_images(network_pkl,
                    seeds,
                    truncation_psi,
                    data_dir=None,
                    dataset_name=None,
                    model=None):
    """Sample images for each seed and save grids of intermediate tensors.

    Rebuilds the generator graph for ``model``, copies weights from
    ``network_pkl``, then for every seed saves the image plus several
    auxiliary visualisations (n_v, x_v, m_v and their clipped variants).
    """
    G_args = EasyDict(func_name='training.' + model + '.G_main',
                      fmap_base=8 << 10)
    dataset_args = EasyDict(tfrecord_dir=dataset_name)
    tflib.init_tf()
    training_set = dataset.load_dataset(
        data_dir=dnnlib.convert_path(data_dir), verbose=True, **dataset_args)

    print('Constructing networks...')
    Gs = tflib.Network('G',
                       num_channels=training_set.shape[0],
                       resolution=training_set.shape[1],
                       label_size=training_set.label_size,
                       **G_args)
    print(f'Loading networks from "{network_pkl}"...')
    _, _, _Gs = pretrained_networks.load_networks(network_pkl)
    Gs.copy_vars_from(_Gs)
    # All per-layer noise inputs of the synthesis network.
    noise_vars = [v for n, v in Gs.components.synthesis.vars.items()
                  if n.startswith('noise')]

    Gs_kwargs = dnnlib.EasyDict(randomize_noise=False)
    if truncation_psi is not None:
        Gs_kwargs.truncation_psi = truncation_psi

    for seed_idx, seed in enumerate(seeds):
        print(f'Generating image for seed {seed} ({seed_idx}/{len(seeds)}) ...')
        rnd = np.random.RandomState(seed)
        z = rnd.randn(1, *Gs.input_shape[1:])  # [minibatch, component]
        # Reseed every noise layer deterministically from the same RNG.
        tflib.set_vars({v: rnd.randn(*v.shape.as_list())
                        for v in noise_vars})  # [height, width]
        images, x_v, n_v, m_v = Gs.run(
            z, None, **Gs_kwargs)  # [minibatch, height, width, channel]

        print(images.shape, n_v.shape, x_v.shape, m_v.shape)
        misc.convert_to_pil_image(images[0], drange=[-1, 1]).save(
            dnnlib.make_run_dir_path('seed%04d.png' % seed))
        misc.save_image_grid(adjust_range(n_v),
                             dnnlib.make_run_dir_path('seed%04d-nv.png' % seed),
                             drange=[-1, 1])
        print(np.linalg.norm(x_v - m_v))
        # Remaining grids all follow one save pattern: (tensor, file tag).
        tagged_grids = [
            (adjust_range(x_v).transpose([1, 0, 2, 3]), 'xv'),
            (adjust_range(m_v).transpose([1, 0, 2, 3]), 'mv'),
            (adjust_range(clip(x_v, 'cat')), 'xvs'),
            (adjust_range(clip(m_v, 'ss')), 'mvs'),
            (adjust_range(clip(m_v, 'ffhq')), 'fmvs'),
        ]
        for grid, tag in tagged_grids:
            misc.save_image_grid(
                grid,
                dnnlib.make_run_dir_path('seed%04d-%s.png' % (seed, tag)),
                drange=[-1, 1])
def run(dataset,
        data_dir,
        result_dir,
        config_id,
        num_gpus,
        total_kimg,
        gamma,
        mirror_augment,
        metrics,
        resume_pkl,
        I_fmap_base=8,
        G_fmap_base=8,
        D_fmap_base=9,
        fmap_decay=0.15,
        D_lambda=0,
        C_lambda=1,
        n_samples_per=10,
        module_list=None,
        model_type='vc_gan2',
        epsilon_loss=3,
        random_eps=False,
        latent_type='uniform',
        delta_type='onedim',
        connect_mode='concat',
        batch_size=32,
        batch_per_gpu=16,
        return_atts=False,
        random_seed=1000,
        module_I_list=None,
        module_D_list=None,
        fmap_min=16,
        fmap_max=512,
        G_nf_scale=4,
        I_nf_scale=4,
        D_nf_scale=4,
        return_I_atts=False,
        dlatent_size=24,
        arch='resnet',
        topk_dims_to_show=20):
    """Assemble option dicts for a VP-GAN variant and submit a training run.

    Builds ``EasyDict`` argument sets for the generator (G), the recognizer
    head (I), the discriminator (D), their optimizers and losses according
    to ``model_type``, then dispatches everything through
    ``dnnlib.submit_run``.

    Supported ``model_type`` values: ``'vc2_gan'``, ``'vpex_gan'``,
    ``'vc2_gan_own_I'``, ``'vc2_gan_own_ID'``.

    Raises:
        ValueError: if ``model_type`` is not one of the supported variants.
    """
    train = EasyDict(
        run_func_name='training.training_loop_vc2.training_loop_vc2'
    )  # Options for training loop.

    # Total size of the discrete latent portion: sum the sizes of all
    # 'D_global*' / 'D_nocond_global*' entries in the module list.
    D_global_size = 0
    if module_list is not None:
        module_list = _str_to_list(module_list)
        key_ls, size_ls, count_dlatent_size = split_module_names(module_list)
        for i, key in enumerate(key_ls):
            if key.startswith('D_global') or key.startswith('D_nocond_global'):
                D_global_size += size_ls[i]
    else:
        count_dlatent_size = dlatent_size

    # Same bookkeeping for the recognizer's own module list (only set when
    # module_I_list is given; 'own_I'/'own_ID' model types rely on it).
    if module_I_list is not None:
        D_global_I_size = 0
        module_I_list = _str_to_list(module_I_list)
        key_I_ls, size_I_ls, count_dlatent_I_size = split_module_names(
            module_I_list)
        for i, key in enumerate(key_I_ls):
            if key.startswith('D_global') or key.startswith('D_nocond_global'):
                D_global_I_size += size_I_ls[i]

    # Same bookkeeping for the discriminator's own module list (used by the
    # 'own_ID' model type).
    if module_D_list is not None:
        D_global_D_size = 0
        module_D_list = _str_to_list(module_D_list)
        key_D_ls, size_D_ls, count_dlatent_D_size = split_module_names(
            module_D_list)
        for i, key in enumerate(key_D_ls):
            if key.startswith('D_global') or key.startswith('D_nocond_global'):
                D_global_D_size += size_D_ls[i]

    if model_type == 'vc2_gan':  # Standard VP-GAN
        G = EasyDict(func_name='training.vc_networks2.G_main_vc2',
                     synthesis_func='G_synthesis_modular_vc2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     fmap_decay=fmap_decay,
                     latent_size=count_dlatent_size,
                     dlatent_size=count_dlatent_size,
                     D_global_size=D_global_size,
                     module_list=module_list,
                     use_noise=True,
                     return_atts=return_atts,
                     G_nf_scale=G_nf_scale)  # Options for generator network.
        I = EasyDict(func_name='training.vc_networks2.vc2_head',
                     dlatent_size=count_dlatent_size,
                     D_global_size=D_global_size,
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     connect_mode=connect_mode)
        D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max)  # Options for discriminator network.
        I_info = EasyDict()
        desc = 'vc2_gan'
    elif model_type == 'vpex_gan':  # VP-GAN extended.
        G = EasyDict(func_name='training.vc_networks2.G_main_vc2',
                     synthesis_func='G_synthesis_modular_vc2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     fmap_decay=fmap_decay,
                     latent_size=count_dlatent_size,
                     dlatent_size=count_dlatent_size,
                     D_global_size=D_global_size,
                     module_list=module_list,
                     use_noise=True,
                     return_atts=return_atts,
                     G_nf_scale=G_nf_scale)  # Options for generator network.
        I = EasyDict(func_name='training.vpex_networks.vpex_net',
                     dlatent_size=count_dlatent_size,
                     D_global_size=D_global_size,
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     connect_mode=connect_mode,
                     return_atts=return_I_atts)
        D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max)  # Options for discriminator network.
        I_info = EasyDict()
        desc = 'vpex_gan'
    elif model_type == 'vc2_gan_own_I':  # Standard VP-GAN with own I
        G = EasyDict(func_name='training.vc_networks2.G_main_vc2',
                     synthesis_func='G_synthesis_modular_vc2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     fmap_decay=fmap_decay,
                     latent_size=count_dlatent_size,
                     dlatent_size=count_dlatent_size,
                     D_global_size=D_global_size,
                     module_list=module_list,
                     use_noise=True,
                     return_atts=return_atts,
                     G_nf_scale=G_nf_scale)  # Options for generator network.
        # NOTE(review): requires module_I_list to be provided; otherwise
        # count_dlatent_I_size/D_global_I_size are unbound.
        I = EasyDict(func_name='training.vc_networks2.I_modular_vc2',
                     dlatent_size=count_dlatent_I_size,
                     D_global_size=D_global_I_size,
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     connect_mode=connect_mode,
                     module_I_list=module_I_list,
                     I_nf_scale=I_nf_scale)
        D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max)  # Options for discriminator network.
        I_info = EasyDict()
        desc = 'vc2_gan_own_I'
    elif model_type == 'vc2_gan_own_ID':  # Standard VP-GAN with own ID
        G = EasyDict(func_name='training.vc_networks2.G_main_vc2',
                     synthesis_func='G_synthesis_modular_vc2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     fmap_decay=fmap_decay,
                     latent_size=count_dlatent_size,
                     dlatent_size=count_dlatent_size,
                     D_global_size=D_global_size,
                     module_list=module_list,
                     use_noise=True,
                     return_atts=return_atts,
                     G_nf_scale=G_nf_scale)  # Options for generator network.
        # NOTE(review): requires both module_I_list and module_D_list.
        I = EasyDict(func_name='training.vc_networks2.I_modular_vc2',
                     dlatent_size=count_dlatent_I_size,
                     D_global_size=D_global_I_size,
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     connect_mode=connect_mode,
                     module_I_list=module_I_list,
                     I_nf_scale=I_nf_scale)
        D = EasyDict(func_name='training.vc_networks2.D_modular_vc2',
                     dlatent_size=count_dlatent_D_size,
                     D_global_size=D_global_D_size,
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     connect_mode=connect_mode,
                     module_D_list=module_D_list,
                     D_nf_scale=D_nf_scale)
        I_info = EasyDict()
        desc = 'vc2_gan_ownID'
    else:
        raise ValueError('Not supported model type: ' + model_type)

    G_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for generator optimizer.
    D_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for discriminator optimizer.
    if model_type == 'vc2_gan':  # Standard VP-GAN
        G_loss = EasyDict(func_name='training.loss_vc2.G_logistic_ns_vc2',
                          D_global_size=D_global_size,
                          C_lambda=C_lambda,
                          epsilon=epsilon_loss,
                          random_eps=random_eps,
                          latent_type=latent_type,
                          delta_type=delta_type)  # Options for generator loss.
        D_loss = EasyDict(
            func_name='training.loss_vc2.D_logistic_r1_vc2',
            D_global_size=D_global_size,
            latent_type=latent_type)  # Options for discriminator loss.
    elif model_type == 'vpex_gan':  # VP-GAN extended.
        G_loss = EasyDict(func_name='training.loss_vpex.G_logistic_ns_vpex',
                          D_global_size=D_global_size,
                          C_lambda=C_lambda,
                          epsilon=epsilon_loss,
                          random_eps=random_eps,
                          latent_type=latent_type,
                          delta_type=delta_type)  # Options for generator loss.
        D_loss = EasyDict(
            func_name='training.loss_vc2.D_logistic_r1_vc2',
            D_global_size=D_global_size,
            latent_type=latent_type)  # Options for discriminator loss.
    elif model_type == 'vc2_gan_own_I' or model_type == 'vc2_gan_own_ID':  # Standard VP-GAN with own I or D
        G_loss = EasyDict(func_name='training.loss_vc2.G_logistic_ns_vc2',
                          D_global_size=D_global_size,
                          C_lambda=C_lambda,
                          epsilon=epsilon_loss,
                          random_eps=random_eps,
                          latent_type=latent_type,
                          delta_type=delta_type,
                          own_I=True)  # Options for generator loss.
        D_loss = EasyDict(
            func_name='training.loss_vc2.D_logistic_r1_vc2',
            D_global_size=D_global_size,
            latent_type=latent_type)  # Options for discriminator loss.

    sched = EasyDict()  # Options for TrainingSchedule.
    grid = EasyDict(
        size='1080p',
        layout='random')  # Options for setup_snapshot_image_grid().
    sc = dnnlib.SubmitConfig()  # Options for dnnlib.submit_run().
    tf_config = {
        'rnd.np_random_seed': random_seed
    }  # Options for tflib.init_tf().

    train.data_dir = data_dir
    train.total_kimg = total_kimg
    train.mirror_augment = mirror_augment
    train.image_snapshot_ticks = train.network_snapshot_ticks = 10
    sched.G_lrate_base = sched.D_lrate_base = 0.002
    sched.minibatch_size_base = batch_size
    sched.minibatch_gpu_base = batch_per_gpu
    D_loss.gamma = 10  # Default R1 gamma; overridden below when gamma is given.
    metrics = [metric_defaults[x] for x in metrics]

    desc += '-' + dataset
    dataset_args = EasyDict(tfrecord_dir=dataset, max_label_size='full')

    assert num_gpus in [1, 2, 4, 8]
    sc.num_gpus = num_gpus
    desc += '-%dgpu' % num_gpus
    desc += '-' + config_id

    # Configs A-E: Shrink networks to match original StyleGAN.
    I.fmap_base = 2 << I_fmap_base
    G.fmap_base = 2 << G_fmap_base
    D.fmap_base = 2 << D_fmap_base

    if gamma is not None:
        D_loss.gamma = gamma

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(
        G_args=G,
        D_args=D,
        I_args=I,
        I_info_args=I_info,
        G_opt_args=G_opt,
        D_opt_args=D_opt,
        G_loss_args=G_loss,
        D_loss_args=D_loss,
        use_info_gan=(
            model_type == 'vc2_info_gan2'),  # Independent branch version
        use_vc_head=(model_type == 'vc2_gan' or model_type == 'vpex_gan'
                     or model_type == 'vc2_gan_own_I'
                     or model_type == 'vc2_gan_own_ID'
                     or model_type == 'vc2_gan_byvae'),
        traversal_grid=True,
        return_atts=return_atts,
        return_I_atts=return_I_atts)

    # Number of continuous latent dims: every latent module that is not a
    # discrete 'D_global' contributes its size.
    n_continuous = 0
    if module_list is not None:
        for i, key in enumerate(key_ls):
            m_name = key.split('-')[0]
            if (m_name in LATENT_MODULES) and (not m_name == 'D_global'):
                n_continuous += size_ls[i]
    else:
        n_continuous = dlatent_size

    kwargs.update(dataset_args=dataset_args,
                  sched_args=sched,
                  grid_args=grid,
                  metric_arg_list=metrics,
                  tf_config=tf_config,
                  resume_pkl=resume_pkl,
                  n_discrete=D_global_size,
                  n_continuous=n_continuous,
                  n_samples_per=n_samples_per,
                  topk_dims_to_show=topk_dims_to_show)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = result_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
def run(dataset,
        data_dir,
        result_dir,
        config_id,
        num_gpus,
        total_kimg,
        gamma,
        mirror_augment,
        metrics,
        resume_pkl,
        I_fmap_base=8,
        G_fmap_base=8,
        D_fmap_base=9,
        fmap_decay=0.15,
        D_lambda=1,
        C_lambda=1,
        cls_alpha=0,
        n_samples_per=10,
        module_list=None,
        model_type='vc_gan2',
        epsilon_loss=3,
        random_eps=False,
        latent_type='uniform',
        delta_type='onedim',
        connect_mode='concat',
        batch_size=32,
        batch_per_gpu=16,
        return_atts=False,
        random_seed=1000,
        module_I_list=None,
        module_D_list=None,
        fmap_min=16,
        fmap_max=512,
        G_nf_scale=4,
        I_nf_scale=4,
        D_nf_scale=4,
        outlier_detector=False,
        gen_atts_in_D=False,
        no_atts_in_D=False,
        att_lambda=0,
        dlatent_size=24,
        arch='resnet',
        opt_reset_ls=None,
        norm_ord=2,
        n_dim_strict=0,
        drop_extra_torgb=False,
        latent_split_ls_for_std_gen=[5, 5, 5, 5],
        loose_rate=0.2,
        topk_dims_to_show=20,
        n_neg_samples=1,
        temperature=1.,
        learning_rate=0.002,
        avg_mv_for_I=False,
        use_cascade=False,
        cascade_alt_freq_k=1,
        regW_lambda=1,
        network_snapshot_ticks=10):
    """Assemble option dicts for a VP-GAN/InfoGAN variant and submit a run.

    Builds ``EasyDict`` argument sets for the generator (G), recognizer
    head (I), discriminator (D), their optimizers and losses according to
    ``model_type``, then dispatches everything through ``dnnlib.submit_run``.

    Supported ``model_type`` values: ``'vc2_info_gan'``, ``'vc2_info_gan2'``,
    ``'vc2_gan'``, ``'vc2_gan_byvae'``, ``'vc2_gan_byvae_simple'``,
    ``'vc2_gan_style2_noI'``, ``'vc2_gan_own_I'``, ``'vc2_gan_own_ID'``,
    ``'vc2_gan_noI'``, ``'vc2_traversal_contrastive'``, ``'gan_regW'``.

    Raises:
        ValueError: if ``model_type`` is not one of the supported variants.
    """
    train = EasyDict(
        run_func_name='training.training_loop_vc2.training_loop_vc2'
    )  # Options for training loop.
    if opt_reset_ls is not None:
        opt_reset_ls = _str_to_list_of_int(opt_reset_ls)

    # Total size of the discrete latent portion: sum the sizes of all
    # 'D_global*' / 'D_nocond_global*' entries in the module list.
    D_global_size = 0
    if module_list is not None:
        module_list = _str_to_list(module_list)
        key_ls, size_ls, count_dlatent_size = split_module_names(module_list)
        for i, key in enumerate(key_ls):
            if key.startswith('D_global') or key.startswith('D_nocond_global'):
                D_global_size += size_ls[i]
    else:
        count_dlatent_size = dlatent_size

    # Same bookkeeping for the recognizer's own module list (only set when
    # module_I_list is given; 'own_I'/'own_ID' model types rely on it).
    if module_I_list is not None:
        D_global_I_size = 0
        module_I_list = _str_to_list(module_I_list)
        key_I_ls, size_I_ls, count_dlatent_I_size = split_module_names(
            module_I_list)
        for i, key in enumerate(key_I_ls):
            if key.startswith('D_global') or key.startswith('D_nocond_global'):
                D_global_I_size += size_I_ls[i]

    # Same bookkeeping for the discriminator's own module list (used by
    # 'vc2_info_gan' and 'vc2_gan_own_ID').
    if module_D_list is not None:
        D_global_D_size = 0
        module_D_list = _str_to_list(module_D_list)
        key_D_ls, size_D_ls, count_dlatent_D_size = split_module_names(
            module_D_list)
        for i, key in enumerate(key_D_ls):
            if key.startswith('D_global') or key.startswith('D_nocond_global'):
                D_global_D_size += size_D_ls[i]

    if model_type == 'vc2_info_gan':  # G1 and G2 version InfoGAN
        G = EasyDict(func_name='training.vc_networks2.G_main_vc2',
                     synthesis_func='G_synthesis_modular_vc2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     fmap_decay=fmap_decay,
                     latent_size=count_dlatent_size,
                     dlatent_size=count_dlatent_size,
                     D_global_size=D_global_size,
                     module_list=module_list,
                     use_noise=True,
                     return_atts=return_atts,
                     G_nf_scale=G_nf_scale)  # Options for generator network.
        D = EasyDict(func_name='training.vc_networks2.D_info_modular_vc2',
                     dlatent_size=count_dlatent_D_size,
                     D_global_size=D_global_D_size,
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     connect_mode=connect_mode,
                     module_D_list=module_D_list,
                     gen_atts_in_D=gen_atts_in_D,
                     no_atts_in_D=no_atts_in_D,
                     D_nf_scale=D_nf_scale)
        I = EasyDict()
        I_info = EasyDict()
        desc = 'vc2_info_gan_net'
    elif model_type == 'vc2_info_gan2':  # Independent branch version InfoGAN
        G = EasyDict(func_name='training.vc_networks2.G_main_vc2',
                     synthesis_func='G_synthesis_modular_vc2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     fmap_decay=fmap_decay,
                     latent_size=count_dlatent_size,
                     dlatent_size=count_dlatent_size,
                     D_global_size=D_global_size,
                     module_list=module_list,
                     use_noise=True,
                     return_atts=return_atts,
                     G_nf_scale=G_nf_scale)  # Options for generator network.
        I = EasyDict(func_name='training.vc_networks2.vc2_head_infogan2',
                     dlatent_size=count_dlatent_size,
                     D_global_size=D_global_size,
                     fmap_min=fmap_min,
                     fmap_max=fmap_max)
        D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max)  # Options for discriminator network.
        I_info = EasyDict()
        desc = 'vc2_info_gan2_net'
    elif model_type == 'vc2_gan':  # Standard VP-GAN
        G = EasyDict(func_name='training.vc_networks2.G_main_vc2',
                     synthesis_func='G_synthesis_modular_vc2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     fmap_decay=fmap_decay,
                     latent_size=count_dlatent_size,
                     dlatent_size=count_dlatent_size,
                     D_global_size=D_global_size,
                     module_list=module_list,
                     use_noise=True,
                     return_atts=return_atts,
                     G_nf_scale=G_nf_scale)  # Options for generator network.
        I = EasyDict(func_name='training.vc_networks2.vc2_head',
                     dlatent_size=count_dlatent_size,
                     D_global_size=D_global_size,
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     connect_mode=connect_mode)
        D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max)  # Options for discriminator network.
        I_info = EasyDict()
        desc = 'vc2_gan'
    elif model_type == 'vc2_gan_byvae':  # COMA-FAIN
        G = EasyDict(
            func_name='training.vc_networks2.G_main_vc2',
            synthesis_func='G_synthesis_modular_vc2',
            fmap_min=fmap_min,
            fmap_max=fmap_max,
            fmap_decay=fmap_decay,
            latent_size=count_dlatent_size,
            dlatent_size=count_dlatent_size,
            D_global_size=D_global_size,
            module_list=module_list,
            use_noise=True,
            return_atts=return_atts,
            G_nf_scale=G_nf_scale,
            architecture=arch,
            drop_extra_torgb=drop_extra_torgb,
            latent_split_ls_for_std_gen=latent_split_ls_for_std_gen,
        )  # Options for generator network.
        I = EasyDict(func_name='training.vc_networks2.vc2_head_byvae',
                     dlatent_size=count_dlatent_size,
                     D_global_size=D_global_size,
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     connect_mode=connect_mode)
        D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max)  # Options for discriminator network.
        I_info = EasyDict()
        desc = 'vc2_gan_byvae'
    elif model_type == 'vc2_gan_byvae_simple':  # COMA-FAIN-simple
        G = EasyDict(
            func_name='training.vc_networks2.G_main_vc2',
            synthesis_func='G_synthesis_simple_vc2',
            fmap_min=fmap_min,
            fmap_max=fmap_max,
            fmap_decay=fmap_decay,
            latent_size=count_dlatent_size,
            dlatent_size=count_dlatent_size,
            D_global_size=D_global_size,
            module_list=module_list,
            use_noise=True,
            return_atts=return_atts,
            G_nf_scale=G_nf_scale,
            architecture=arch,
            drop_extra_torgb=drop_extra_torgb,
            latent_split_ls_for_std_gen=latent_split_ls_for_std_gen,
        )  # Options for generator network.
        I = EasyDict(func_name='training.vc_networks2.I_byvae_simple',
                     dlatent_size=count_dlatent_size,
                     D_global_size=D_global_size,
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     connect_mode=connect_mode)
        D = EasyDict(func_name='training.vc_networks2.D_stylegan2_simple',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max)  # Options for discriminator network.
        I_info = EasyDict()
        desc = 'vc2_gan_byvae_simple'
    elif model_type == 'vc2_gan_style2_noI':  # Just Style2-style GAN
        G = EasyDict(func_name='training.vc_networks2.G_main_vc2',
                     synthesis_func='G_synthesis_stylegan2_vc2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     fmap_decay=fmap_decay,
                     latent_size=dlatent_size,
                     architecture=arch,
                     dlatent_size=count_dlatent_size,
                     use_noise=True,
                     return_atts=return_atts,
                     G_nf_scale=G_nf_scale)  # Options for generator network.
        I = EasyDict()
        D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max)  # Options for discriminator network.
        I_info = EasyDict()
        desc = 'vc2_gan_style2_noI'
    elif model_type == 'vc2_gan_own_I':  # Standard VP-GAN with own I
        G = EasyDict(func_name='training.vc_networks2.G_main_vc2',
                     synthesis_func='G_synthesis_modular_vc2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     fmap_decay=fmap_decay,
                     latent_size=count_dlatent_size,
                     dlatent_size=count_dlatent_size,
                     D_global_size=D_global_size,
                     module_list=module_list,
                     use_noise=True,
                     return_atts=return_atts,
                     G_nf_scale=G_nf_scale)  # Options for generator network.
        # NOTE(review): requires module_I_list to be provided; otherwise
        # count_dlatent_I_size/D_global_I_size are unbound.
        I = EasyDict(func_name='training.vc_networks2.I_modular_vc2',
                     dlatent_size=count_dlatent_I_size,
                     D_global_size=D_global_I_size,
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     connect_mode=connect_mode,
                     module_I_list=module_I_list,
                     I_nf_scale=I_nf_scale)
        D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max)  # Options for discriminator network.
        I_info = EasyDict()
        desc = 'vc2_gan_own_I'
    elif model_type == 'vc2_gan_own_ID':  # Standard VP-GAN with own ID
        G = EasyDict(func_name='training.vc_networks2.G_main_vc2',
                     synthesis_func='G_synthesis_modular_vc2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     fmap_decay=fmap_decay,
                     latent_size=count_dlatent_size,
                     dlatent_size=count_dlatent_size,
                     D_global_size=D_global_size,
                     module_list=module_list,
                     use_noise=True,
                     return_atts=return_atts,
                     G_nf_scale=G_nf_scale)  # Options for generator network.
        # NOTE(review): requires both module_I_list and module_D_list.
        I = EasyDict(func_name='training.vc_networks2.I_modular_vc2',
                     dlatent_size=count_dlatent_I_size,
                     D_global_size=D_global_I_size,
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     connect_mode=connect_mode,
                     module_I_list=module_I_list,
                     I_nf_scale=I_nf_scale)
        D = EasyDict(func_name='training.vc_networks2.D_modular_vc2',
                     dlatent_size=count_dlatent_D_size,
                     D_global_size=D_global_D_size,
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     connect_mode=connect_mode,
                     module_D_list=module_D_list,
                     D_nf_scale=D_nf_scale)
        I_info = EasyDict()
        desc = 'vc2_gan_ownID'
    elif model_type == 'vc2_gan_noI' or model_type == 'vc2_traversal_contrastive' or \
        model_type == 'gan_regW': # Just modular GAN or traversal contrastive or regW
        G = EasyDict(func_name='training.vc_networks2.G_main_vc2',
                     synthesis_func='G_synthesis_modular_vc2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     fmap_decay=fmap_decay,
                     latent_size=count_dlatent_size,
                     dlatent_size=count_dlatent_size,
                     D_global_size=D_global_size,
                     module_list=module_list,
                     use_noise=True,
                     return_atts=return_atts,
                     G_nf_scale=G_nf_scale)  # Options for generator network.
        I = EasyDict()
        D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max)  # Options for discriminator network.
        I_info = EasyDict()
        desc = model_type
    else:
        raise ValueError('Not supported model type: ' + model_type)

    G_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for generator optimizer.
    D_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for discriminator optimizer.
    if model_type == 'vc2_info_gan':  # G1 and G2 version InfoGAN
        G_loss = EasyDict(
            func_name='training.loss_vc2.G_logistic_ns_vc2_info_gan',
            D_global_size=D_global_size,
            C_lambda=C_lambda,
            epsilon=epsilon_loss,
            random_eps=random_eps,
            latent_type=latent_type,
            delta_type=delta_type,
            outlier_detector=outlier_detector,
            gen_atts_in_D=gen_atts_in_D,
            att_lambda=att_lambda)  # Options for generator loss.
        D_loss = EasyDict(
            func_name='training.loss_vc2.D_logistic_r1_vc2_info_gan',
            D_global_size=D_global_size,
            latent_type=latent_type)  # Options for discriminator loss.
    elif model_type == 'vc2_info_gan2':  # Independent branch version InfoGAN
        G_loss = EasyDict(
            func_name='training.loss_vc2.G_logistic_ns_vc2_info_gan2',
            D_global_size=D_global_size,
            C_lambda=C_lambda,
            latent_type=latent_type,
            norm_ord=norm_ord,
            n_dim_strict=n_dim_strict,
            loose_rate=loose_rate)  # Options for generator loss.
        D_loss = EasyDict(
            func_name='training.loss_vc2.D_logistic_r1_vc2_info_gan2',
            D_global_size=D_global_size,
            latent_type=latent_type)  # Options for discriminator loss.
    elif model_type == 'vc2_gan':  # Standard VP-GAN
        G_loss = EasyDict(func_name='training.loss_vc2.G_logistic_ns_vc2',
                          D_global_size=D_global_size,
                          C_lambda=C_lambda,
                          epsilon=epsilon_loss,
                          random_eps=random_eps,
                          latent_type=latent_type,
                          delta_type=delta_type)  # Options for generator loss.
        D_loss = EasyDict(
            func_name='training.loss_vc2.D_logistic_r1_vc2',
            D_global_size=D_global_size,
            latent_type=latent_type)  # Options for discriminator loss.
    elif model_type == 'vc2_gan_byvae' or model_type == 'vc2_gan_byvae_simple':  # COMA-FAIN
        G_loss = EasyDict(
            func_name='training.loss_vc2.G_logistic_byvae_ns_vc2',
            D_global_size=D_global_size,
            C_lambda=C_lambda,
            epsilon=epsilon_loss,
            random_eps=random_eps,
            latent_type=latent_type,
            use_cascade=use_cascade,
            delta_type=delta_type)  # Options for generator loss.
        D_loss = EasyDict(
            func_name='training.loss_vc2.D_logistic_r1_vc2',
            D_global_size=D_global_size,
            latent_type=latent_type)  # Options for discriminator loss.
    elif model_type == 'vc2_gan_own_I' or model_type == 'vc2_gan_own_ID':  # Standard VP-GAN with own I or D
        G_loss = EasyDict(func_name='training.loss_vc2.G_logistic_ns_vc2',
                          D_global_size=D_global_size,
                          C_lambda=C_lambda,
                          epsilon=epsilon_loss,
                          random_eps=random_eps,
                          latent_type=latent_type,
                          delta_type=delta_type,
                          own_I=True)  # Options for generator loss.
        D_loss = EasyDict(
            func_name='training.loss_vc2.D_logistic_r1_vc2',
            D_global_size=D_global_size,
            latent_type=latent_type)  # Options for discriminator loss.
    elif model_type == 'vc2_gan_noI' or model_type == 'vc2_gan_style2_noI':  # Just GANs (modular or StyleGAN2-style)
        G_loss = EasyDict(
            func_name='training.loss_vc2.G_logistic_ns',
            latent_type=latent_type)  # Options for generator loss.
        D_loss = EasyDict(
            func_name='training.loss_vc2.D_logistic_r1_vc2',
            D_global_size=D_global_size,
            latent_type=latent_type)  # Options for discriminator loss.
    elif model_type == 'vc2_traversal_contrastive':  # With perceptual distance as guide.
        G_loss = EasyDict(
            func_name=
            'training.loss_vc2.G_logistic_ns_vc2_traversal_contrastive',
            D_global_size=D_global_size,
            C_lambda=C_lambda,
            n_neg_samples=n_neg_samples,
            temperature=temperature,
            epsilon=epsilon_loss,
            random_eps=random_eps,
            latent_type=latent_type,
            delta_type=delta_type)  # Options for generator loss.
        D_loss = EasyDict(
            func_name='training.loss_vc2.D_logistic_r1_vc2',
            D_global_size=D_global_size,
            latent_type=latent_type)  # Options for discriminator loss.
    elif model_type == 'gan_regW':
        G_loss = EasyDict(
            func_name='training.loss_vc2.G_logistic_ns_regW',
            latent_type=latent_type,
            regW_lambda=regW_lambda)  # Options for generator loss.
        D_loss = EasyDict(
            func_name='training.loss_vc2.D_logistic_r1_vc2',
            D_global_size=D_global_size,
            latent_type=latent_type)  # Options for discriminator loss.

    sched = EasyDict()  # Options for TrainingSchedule.
    grid = EasyDict(
        size='1080p',
        layout='random')  # Options for setup_snapshot_image_grid().
    sc = dnnlib.SubmitConfig()  # Options for dnnlib.submit_run().
    tf_config = {
        'rnd.np_random_seed': random_seed
    }  # Options for tflib.init_tf().

    train.data_dir = data_dir
    train.total_kimg = total_kimg
    train.mirror_augment = mirror_augment
    train.image_snapshot_ticks = train.network_snapshot_ticks = 10
    sched.G_lrate_base = sched.D_lrate_base = learning_rate
    sched.minibatch_size_base = batch_size
    sched.minibatch_gpu_base = batch_per_gpu
    D_loss.gamma = 10  # Default R1 gamma; overridden below when gamma is given.
    metrics = [metric_defaults[x] for x in metrics]

    desc += '-' + dataset
    dataset_args = EasyDict(tfrecord_dir=dataset, max_label_size='full')

    assert num_gpus in [1, 2, 4, 8]
    sc.num_gpus = num_gpus
    desc += '-%dgpu' % num_gpus
    desc += '-' + config_id

    # Configs A-E: Shrink networks to match original StyleGAN.
    I.fmap_base = 2 << I_fmap_base
    G.fmap_base = 2 << G_fmap_base
    D.fmap_base = 2 << D_fmap_base

    if gamma is not None:
        D_loss.gamma = gamma

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(
        G_args=G,
        D_args=D,
        I_args=I,
        I_info_args=I_info,
        G_opt_args=G_opt,
        D_opt_args=D_opt,
        G_loss_args=G_loss,
        D_loss_args=D_loss,
        use_info_gan=(
            model_type == 'vc2_info_gan2'),  # Independent branch version
        use_vc_head=(model_type == 'vc2_gan' or model_type == 'vc2_gan_own_I'
                     or model_type == 'vc2_gan_own_ID'
                     or model_type == 'vc2_gan_byvae'
                     or model_type == 'vc2_gan_byvae_simple'),
        use_vc2_info_gan=(model_type == 'vc2_info_gan'),  # G1 and G2 version
        use_perdis=(model_type == 'vc2_traversal_contrastive'),
        avg_mv_for_I=avg_mv_for_I,
        traversal_grid=True,
        return_atts=return_atts)

    # Number of continuous latent dims: every latent module that is not a
    # discrete 'D_global' contributes its size.
    n_continuous = 0
    if module_list is not None:
        for i, key in enumerate(key_ls):
            m_name = key.split('-')[0]
            if (m_name in LATENT_MODULES) and (not m_name == 'D_global'):
                n_continuous += size_ls[i]
    else:
        n_continuous = dlatent_size

    kwargs.update(dataset_args=dataset_args,
                  sched_args=sched,
                  grid_args=grid,
                  metric_arg_list=metrics,
                  tf_config=tf_config,
                  resume_pkl=resume_pkl,
                  n_discrete=D_global_size,
                  n_continuous=n_continuous,
                  n_samples_per=n_samples_per,
                  topk_dims_to_show=topk_dims_to_show,
                  cascade_alt_freq_k=cascade_alt_freq_k,
                  network_snapshot_ticks=network_snapshot_ticks)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = result_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
def run(dataset,
        data_dir,
        result_dir,
        config_id,
        num_gpus,
        total_kimg,
        mirror_augment,
        metrics,
        resume_G_pkl,
        n_batch=2,
        n_batch_per_gpu=1,
        D_global_size=0,
        C_global_size=10,
        model_type='hd_dis_model',
        latent_type='uniform',
        resume_pkl=None,
        n_samples_per=4,
        D_lambda=0,
        C_lambda=1,
        epsilon_in_loss=3,
        random_eps=True,
        M_lrmul=0.1,
        resolution_manual=1024,
        pretrained_type='with_stylegan2',
        traj_lambda=None,
        level_I_kimg=1000,
        use_level_training=False,
        resume_kimg=0,
        use_std_in_m=False,
        prior_latent_size=512,
        M_mapping_fmaps=512,
        hyperplane_lambda=1,
        hyperdir_lambda=1):
    """Configure and submit an HD-disentanglement training run.

    Builds EasyDict configs for the dismapper network M, the recognizor
    network I, their optimizer and loss, plus schedule/grid/submit options,
    then launches ``training.training_loop_hd.training_loop_hd`` via
    ``dnnlib.submit_run``.

    Args:
        dataset: TFRecord dataset name (also appended to the run description).
        data_dir: Root directory containing the dataset.
        result_dir: Directory under which the run directory is created.
        config_id: One of ``_valid_configs``; non-'config-f' shrinks I's fmaps.
        num_gpus: Number of GPUs; must be 1, 2, 4 or 8.
        total_kimg: Total training length in thousands of images.
        mirror_augment: Enable horizontal-flip augmentation.
        metrics: Metric names; resolved through ``metric_defaults``.
        resume_G_pkl: Pickle of the pretrained generator to load.
        model_type: Selects network/loss variants; 'hd_hyperplane' switches M
            and the loss to the hyperplane versions.
        (Remaining keyword arguments are forwarded into the network, loss and
        training-loop configs under the same names.)
    """
    use_hyperplane = (model_type == 'hd_hyperplane')

    # Options for the training loop with a pretrained HD generator.
    train = EasyDict(
        run_func_name='training.training_loop_hd.training_loop_hd')

    # Options for the dismapper network M. The hyperplane variant does not
    # take mapping_fmaps; all other fields are shared.
    M = EasyDict(
        func_name=('training.hd_networks.net_M_hyperplane' if use_hyperplane
                   else 'training.hd_networks.net_M'),
        C_global_size=C_global_size,
        D_global_size=D_global_size,
        latent_size=prior_latent_size,
        mapping_lrmul=M_lrmul,
        use_std_in_m=use_std_in_m)
    if not use_hyperplane:
        M.mapping_fmaps = M_mapping_fmaps

    # Options for the recognizor network I (identical for both variants).
    I = EasyDict(func_name='training.hd_networks.net_I',
                 C_global_size=C_global_size,
                 D_global_size=D_global_size)

    # NOTE(review): I_info is built but never passed to the training loop
    # (no I_info_args in kwargs below, unlike the sibling run() that uses
    # training_loop_hdwG) — confirm whether I_info_args should be forwarded.
    if model_type == 'hd_dis_model_with_cls':
        I_info = EasyDict(func_name='training.hd_networks.net_I_info',
                          C_global_size=C_global_size,
                          D_global_size=D_global_size)
    else:
        I_info = EasyDict()

    # Options for the I/M optimizer.
    I_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)

    # Options for the I-and-M loss. Both variants share every field except
    # the loss function itself and hyperdir_lambda (hyperplane only).
    I_loss = EasyDict(
        func_name=('training.loss_hd.IandM_hyperplane_loss' if use_hyperplane
                   else 'training.loss_hd.IandM_loss'),
        latent_type=latent_type,
        D_global_size=D_global_size,
        C_global_size=C_global_size,
        D_lambda=D_lambda,
        C_lambda=C_lambda,
        epsilon=epsilon_in_loss,
        random_eps=random_eps,
        traj_lambda=traj_lambda,
        resolution_manual=resolution_manual,
        use_std_in_m=use_std_in_m,
        model_type=model_type,
        hyperplane_lambda=hyperplane_lambda,
        prior_latent_size=prior_latent_size)
    if use_hyperplane:
        I_loss.hyperdir_lambda = hyperdir_lambda

    sched = EasyDict()  # Options for TrainingSchedule.
    grid = EasyDict(
        size='1080p',
        layout='random')  # Options for setup_snapshot_image_grid().
    sc = dnnlib.SubmitConfig()  # Options for dnnlib.submit_run().
    tf_config = {'rnd.np_random_seed': 1000}  # Options for tflib.init_tf().

    train.data_dir = data_dir
    train.total_kimg = total_kimg
    train.mirror_augment = mirror_augment
    train.image_snapshot_ticks = train.network_snapshot_ticks = 10
    sched.I_lrate_base = 0.002
    sched.minibatch_size_base = n_batch
    sched.minibatch_gpu_base = n_batch_per_gpu
    # Resolve metric names to their default configs; unknown names raise.
    metrics = [metric_defaults[x] for x in metrics]
    desc = 'hd_disentanglement'

    desc += '-' + dataset
    dataset_args = EasyDict(tfrecord_dir=dataset)

    assert num_gpus in [1, 2, 4, 8]
    sc.num_gpus = num_gpus
    desc += '-%dgpu' % num_gpus

    assert config_id in _valid_configs
    desc += '-' + config_id

    # Configs A-E: Shrink networks to match original StyleGAN.
    if config_id != 'config-f':
        I.fmap_base = 8 << 10

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(I_args=I,
                  M_args=M,
                  I_opt_args=I_opt,
                  I_loss_args=I_loss,
                  resume_G_pkl=resume_G_pkl)
    kwargs.update(dataset_args=dataset_args,
                  sched_args=sched,
                  grid_args=grid,
                  use_hd_with_cls=(model_type == 'hd_dis_model_with_cls'),
                  use_hyperplane=use_hyperplane,
                  metric_arg_list=metrics,
                  tf_config=tf_config,
                  resume_pkl=resume_pkl,
                  n_discrete=D_global_size,
                  n_continuous=C_global_size,
                  n_samples_per=n_samples_per,
                  resolution_manual=resolution_manual,
                  pretrained_type=pretrained_type,
                  level_I_kimg=level_I_kimg,
                  use_level_training=use_level_training,
                  resume_kimg=resume_kimg,
                  use_std_in_m=use_std_in_m,
                  prior_latent_size=prior_latent_size)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = result_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)