def run(dataset, data_dir, result_dir, config_id, num_gpus, total_kimg, gamma,
        mirror_augment, metrics):
    train = EasyDict(
        run_func_name='training.training_loop.rotation.v5_int_reg.training_loop'
    )
    G = EasyDict(func_name='training.networks.rotation.v5_int_reg.G_main')
    D = EasyDict(func_name='training.networks.rotation.v5_int_reg.D_stylegan2')
    G_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)
    D_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)
    G_loss = EasyDict(
        func_name='training.loss.rotation.v5_int_reg.G_logistic_ns_pathreg')
    D_loss = EasyDict(
        func_name='training.loss.rotation.v5_int_reg.D_logistic_r1')
    sched = EasyDict()
    grid = EasyDict(size='1080p', layout='random')
    sc = dnnlib.SubmitConfig()
    tf_config = {'rnd.np_random_seed': 1000}

    train.data_dir = data_dir
    train.total_kimg = total_kimg
    train.mirror_augment = mirror_augment
    train.image_snapshot_ticks = train.network_snapshot_ticks = 10
    sched.G_lrate_base = sched.D_lrate_base = 0.002
    sched.minibatch_size_base = 32
    sched.minibatch_gpu_base = 4
    D_loss.gamma = 10
    metrics = [metric_defaults[x] for x in metrics]
    desc = 'rotation-v5-int-reg_256'

    G_loss.int_reg_clip = 5.0
    G_loss.rotation_step_size = 0.08 / 2

    dataset_args = EasyDict(tfrecord_dir=dataset)

    assert num_gpus in [1, 2, 4, 8]
    sc.num_gpus = num_gpus

    assert config_id in _valid_configs

    if gamma is not None:
        D_loss.gamma = gamma

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(G_args=G,
                  D_args=D,
                  G_opt_args=G_opt,
                  D_opt_args=D_opt,
                  G_loss_args=G_loss,
                  D_loss_args=D_loss)
    kwargs.update(dataset_args=dataset_args,
                  sched_args=sched,
                  grid_args=grid,
                  metric_arg_list=metrics,
                  tf_config=tf_config)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = result_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
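These run() helpers are normally wired to a command-line front end rather than called directly. Below is a minimal sketch of such a driver; the parser is illustrative (not the original script), and it assumes 'config-f' appears in _valid_configs and 'fid50k' is a key of metric_defaults.

# Hypothetical CLI wrapper for the run() helper above (illustrative only).
import argparse

def main():
    parser = argparse.ArgumentParser(description='Train rotation-regularized StyleGAN2.')
    parser.add_argument('--dataset', required=True, help='Name of the TFRecords directory.')
    parser.add_argument('--data-dir', required=True, help='Root directory holding the dataset.')
    parser.add_argument('--result-dir', default='results', help='Where run directories are created.')
    parser.add_argument('--config-id', default='config-f')
    parser.add_argument('--num-gpus', type=int, default=1, choices=[1, 2, 4, 8])
    parser.add_argument('--total-kimg', type=int, default=25000)
    parser.add_argument('--gamma', type=float, default=None, help='Override R1 gamma (default 10).')
    parser.add_argument('--mirror-augment', action='store_true')
    parser.add_argument('--metrics', nargs='*', default=['fid50k'])
    args = parser.parse_args()
    run(**vars(args))  # Argument names map 1:1 onto run()'s signature.

if __name__ == '__main__':
    main()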
Example #2
def run(dataset, data_dir, result_dir, config_id, num_gpus, total_kimg, gamma, mirror_augment, metrics):
    train     = EasyDict(run_func_name='training.training_loop.conditional.v5_baseline.training_loop') # Options for training loop.
    G         = EasyDict(func_name='training.networks.conditional.baseline.G_main')       # Options for generator network.
    D         = EasyDict(func_name='training.networks.conditional.baseline.D_stylegan2')  # Options for discriminator network.
    G_opt     = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)                  # Options for generator optimizer.
    D_opt     = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)                  # Options for discriminator optimizer.
    G_loss    = EasyDict(func_name='training.loss.conditional.label_dropout.G_logistic_ns_pathreg')      # Options for generator loss.
    D_loss    = EasyDict(func_name='training.loss.conditional.label_dropout.D_logistic_r1')              # Options for discriminator loss.
    sched     = EasyDict()                                                     # Options for TrainingSchedule.
    grid      = EasyDict(size='1080p', layout='random')                           # Options for setup_snapshot_image_grid().
    sc        = dnnlib.SubmitConfig()                                          # Options for dnnlib.submit_run().
    tf_config = {'rnd.np_random_seed': 1000}                                   # Options for tflib.init_tf().

    # train.resume_pkl = './../results/00326-conditional_label_dropout_25/network-snapshot-000887.pkl'
    # train.resume_kimg = 887.0

    train.data_dir = data_dir
    train.total_kimg = total_kimg
    train.mirror_augment = mirror_augment
    train.image_snapshot_ticks = train.network_snapshot_ticks = 10
    sched.G_lrate_base = sched.D_lrate_base = 0.002
    sched.minibatch_size_base = 32
    sched.minibatch_gpu_base = 4
    D_loss.gamma = 10
    metrics = [metric_defaults[x] for x in metrics]
    desc = 'conditional_label_dropout_25'

    G_loss.label_dropout_prob = 0.5
    D_loss.label_dropout_prob = 0.5

    dataset_args = EasyDict(tfrecord_dir=dataset)

    assert num_gpus in [1, 2, 4, 8]
    sc.num_gpus = num_gpus

    if gamma is not None:
        D_loss.gamma = gamma

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(G_args=G, D_args=D, G_opt_args=G_opt, D_opt_args=D_opt, G_loss_args=G_loss, D_loss_args=D_loss)
    kwargs.update(dataset_args=dataset_args, sched_args=sched, grid_args=grid, metric_arg_list=metrics, tf_config=tf_config)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = result_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
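Relative to the previous example, the only new knobs are G_loss.label_dropout_prob and D_loss.label_dropout_prob. The actual mechanics live in training.loss.conditional.label_dropout (not shown); what follows is a minimal sketch of what per-sample label dropout typically looks like in TF 1.x, with all names assumed rather than taken from that module.

import tensorflow as tf

def apply_label_dropout(labels, label_dropout_prob):
    # Zero out the entire label vector for a random subset of samples, so the
    # networks see both conditioned and unconditioned minibatches.
    keep = tf.cast(
        tf.random.uniform([tf.shape(labels)[0], 1]) >= label_dropout_prob,
        labels.dtype)
    return labels * keep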
Example #3
def run(dataset,
        data_dir,
        result_dir,
        num_gpus,
        total_kimg,
        gamma,
        mirror_augment,
        metrics,
        resume_pkl,
        I_fmap_base=8,
        G_fmap_base=8,
        D_fmap_base=9,
        fmap_decay=0.15,
        C_lambda=1,
        n_samples_per=10,
        module_list=None,
        model_type='gan_like',
        epsilon_loss=3,
        random_eps=False,
        latent_type='uniform',
        batch_size=32,
        batch_per_gpu=16,
        return_atts=False,
        random_seed=1000,
        module_I_list=None,
        module_D_list=None,
        fmap_min=16,
        fmap_max=512,
        G_nf_scale=4,
        norm_ord=2,
        topk_dims_to_show=20,
        learning_rate=0.002,
        avg_mv_for_I=False,
        use_cascade=False,
        cascade_alt_freq_k=1,
        post_trans_wh=16,
        post_trans_cnn_dim=128,
        dff=512,
        trans_rate=0.1,
        construct_feat_by_concat=False,
        ncut_maxval=5,
        post_trans_mat=16,
        group_recons_lambda=0,
        trans_dim=512,
        network_snapshot_ticks=10):
    train = EasyDict(
        run_func_name='training.training_loop_tsfm.training_loop_tsfm'
    )  # Options for training loop.

    if module_list is not None:
        module_list = _str_to_list(module_list)
        key_ls, size_ls, count_dlatent_size = split_module_names(module_list)

    if model_type == 'info_gan_like':  # Independent branch version InfoGAN
        G = EasyDict(
            func_name='training.tsfm_G_nets.G_main_tsfm',
            synthesis_func='G_synthesis_modular_tsfm',
            fmap_min=fmap_min,
            fmap_max=fmap_max,
            fmap_decay=fmap_decay,
            latent_size=count_dlatent_size,
            dlatent_size=count_dlatent_size,
            module_list=module_list,
            use_noise=True,
            return_atts=return_atts,
            G_nf_scale=G_nf_scale,
            trans_dim=trans_dim,
            post_trans_wh=post_trans_wh,
            post_trans_cnn_dim=post_trans_cnn_dim,
            dff=dff,
            trans_rate=trans_rate,
            construct_feat_by_concat=construct_feat_by_concat,
            ncut_maxval=ncut_maxval,
            post_trans_mat=post_trans_mat,
        )  # Options for generator network.
        I = EasyDict(func_name='training.tsfm_I_nets.head_infogan2',
                     dlatent_size=count_dlatent_size,
                     fmap_min=fmap_min,
                     fmap_max=fmap_max)
        D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max)  # Options for discriminator network.
        desc = 'info_gan_like'
    elif model_type == 'ps_sc_like':  # COMA-FAIN
        G = EasyDict(
            func_name='training.tsfm_G_nets.G_main_tsfm',
            synthesis_func='G_synthesis_modular_tsfm',
            fmap_min=fmap_min,
            fmap_max=fmap_max,
            fmap_decay=fmap_decay,
            latent_size=count_dlatent_size,
            dlatent_size=count_dlatent_size,
            module_list=module_list,
            use_noise=True,
            return_atts=return_atts,
            G_nf_scale=G_nf_scale,
            trans_dim=trans_dim,
            post_trans_wh=post_trans_wh,
            post_trans_cnn_dim=post_trans_cnn_dim,
            dff=dff,
            trans_rate=trans_rate,
            construct_feat_by_concat=construct_feat_by_concat,
            ncut_maxval=ncut_maxval,
            post_trans_mat=post_trans_mat,
        )  # Options for generator network.
        I = EasyDict(func_name='training.tsfm_I_nets.head_ps_sc',
                     dlatent_size=count_dlatent_size,
                     fmap_min=fmap_min,
                     fmap_max=fmap_max)
        D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max)  # Options for discriminator network.
        desc = 'ps_sc_like'
    elif model_type == 'gan_like':  # Just modular GAN.
        G = EasyDict(
            func_name='training.tsfm_G_nets.G_main_tsfm',
            synthesis_func='G_synthesis_modular_tsfm',
            fmap_min=fmap_min,
            fmap_max=fmap_max,
            fmap_decay=fmap_decay,
            latent_size=count_dlatent_size,
            dlatent_size=count_dlatent_size,
            module_list=module_list,
            use_noise=True,
            return_atts=return_atts,
            G_nf_scale=G_nf_scale,
            trans_dim=trans_dim,
            post_trans_wh=post_trans_wh,
            post_trans_cnn_dim=post_trans_cnn_dim,
            dff=dff,
            trans_rate=trans_rate,
            construct_feat_by_concat=construct_feat_by_concat,
            ncut_maxval=ncut_maxval,
            post_trans_mat=post_trans_mat,
        )  # Options for generator network.
        I = EasyDict()
        D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max)  # Options for discriminator network.
        desc = 'gan_like'
    else:
        raise ValueError('Unsupported model type: ' + model_type)

    G_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for generator optimizer.
    D_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for discriminator optimizer.
    if model_type == 'info_gan_like':  # InfoGAN
        G_loss = EasyDict(
            func_name='training.loss_tsfm.G_logistic_ns_info_gan',
            C_lambda=C_lambda,
            latent_type=latent_type,
            norm_ord=norm_ord,
            group_recons_lambda=group_recons_lambda
        )  # Options for generator loss.
        D_loss = EasyDict(
            func_name='training.loss_tsfm.D_logistic_r1_shared',
            latent_type=latent_type)  # Options for discriminator loss.
    elif model_type == 'ps_sc_like':  # PS-SC
        G_loss = EasyDict(
            func_name='training.loss_tsfm.G_logistic_ns_ps_sc',
            C_lambda=C_lambda,
            group_recons_lambda=group_recons_lambda,
            epsilon=epsilon_loss,
            random_eps=random_eps,
            latent_type=latent_type,
            use_cascade=use_cascade)  # Options for generator loss.
        D_loss = EasyDict(
            func_name='training.loss_tsfm.D_logistic_r1_shared',
            latent_type=latent_type)  # Options for discriminator loss.
    elif model_type == 'gan_like':  # Just GAN
        G_loss = EasyDict(func_name='training.loss_tsfm.G_logistic_ns',
                          latent_type=latent_type,
                          group_recons_lambda=group_recons_lambda
                          )  # Options for generator loss.
        D_loss = EasyDict(
            func_name='training.loss_tsfm.D_logistic_r1_shared',
            latent_type=latent_type)  # Options for discriminator loss.

    sched = EasyDict()  # Options for TrainingSchedule.
    grid = EasyDict(
        size='1080p',
        layout='random')  # Options for setup_snapshot_image_grid().
    sc = dnnlib.SubmitConfig()  # Options for dnnlib.submit_run().
    tf_config = {
        'rnd.np_random_seed': random_seed
    }  # Options for tflib.init_tf().

    train.data_dir = data_dir
    train.total_kimg = total_kimg
    train.mirror_augment = mirror_augment
    train.image_snapshot_ticks = train.network_snapshot_ticks = 10
    sched.G_lrate_base = sched.D_lrate_base = learning_rate
    sched.minibatch_size_base = batch_size
    sched.minibatch_gpu_base = batch_per_gpu
    metrics = [metric_defaults[x] for x in metrics]

    desc += '-' + dataset
    dataset_args = EasyDict(tfrecord_dir=dataset, max_label_size='full')

    assert num_gpus in [1, 2, 4, 8]
    sc.num_gpus = num_gpus
    desc += '-%dgpu' % num_gpus

    I.fmap_base = 2 << I_fmap_base
    G.fmap_base = 2 << G_fmap_base
    D.fmap_base = 2 << D_fmap_base

    if gamma is not None:
        D_loss.gamma = gamma

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(
        G_args=G,
        D_args=D,
        I_args=I,
        G_opt_args=G_opt,
        D_opt_args=D_opt,
        G_loss_args=G_loss,
        D_loss_args=D_loss,
        use_info_gan=(
            model_type == 'info_gan_like'),  # Independent branch version
        use_ps_head=(model_type == 'ps_sc_like'),
        avg_mv_for_I=avg_mv_for_I,
        traversal_grid=True,
        return_atts=return_atts)
    n_continuous = 0
    if module_list is not None:
        for i, key in enumerate(key_ls):
            m_name = key.split('-')[0]
            if m_name in LATENT_MODULES and m_name != 'D_global':
                n_continuous += size_ls[i]

    kwargs.update(dataset_args=dataset_args,
                  sched_args=sched,
                  grid_args=grid,
                  metric_arg_list=metrics,
                  tf_config=tf_config,
                  resume_pkl=resume_pkl,
                  n_continuous=n_continuous,
                  n_samples_per=n_samples_per,
                  topk_dims_to_show=topk_dims_to_show,
                  cascade_alt_freq_k=cascade_alt_freq_k,
                  network_snapshot_ticks=network_snapshot_ticks)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = result_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
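One subtlety worth calling out: I_fmap_base, G_fmap_base and D_fmap_base are shift exponents rather than raw channel budgets, since fmap_base is computed as 2 << x (i.e. 2**(x+1)). A quick sanity check of the defaults:

# fmap_base values implied by the default shift exponents above.
for name, exp in [('I', 8), ('G', 8), ('D', 9)]:
    print(name, 2 << exp)  # I 512, G 512, D 1024
print(8 << 10)  # 8192, the 'shrink to original StyleGAN' value used in other examples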
Example #4
def run(dataset,
        data_dir,
        result_dir,
        config_id,
        num_gpus,
        total_kimg,
        gamma,
        mirror_augment,
        metrics,
        resume_run_id=None):
    train = EasyDict(run_func_name='training.training_loop.training_loop'
                     )  # Options for training loop.
    G = EasyDict(func_name='training.networks_stylegan2.G_main'
                 )  # Options for generator network.
    D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2'
                 )  # Options for discriminator network.
    G_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for generator optimizer.
    D_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for discriminator optimizer.
    G_loss = EasyDict(func_name='training.loss.G_logistic_ns_pathreg'
                      )  # Options for generator loss.
    D_loss = EasyDict(func_name='training.loss.D_logistic_r1'
                      )  # Options for discriminator loss.
    sched = EasyDict()  # Options for TrainingSchedule.
    grid = EasyDict(
        size='8k', layout='random')  # Options for setup_snapshot_image_grid().
    sc = dnnlib.SubmitConfig()  # Options for dnnlib.submit_run().
    tf_config = {'rnd.np_random_seed': 1000}  # Options for tflib.init_tf().

    if resume_run_id is not None:
        # Resume from the ID of the results directory given
        ids = sorted(get_valid_runids(result_dir))

        if resume_run_id == 'recent':
            resume_run_id = ids[-1][0]
        else:
            try:
                resume_run_id = int(resume_run_id)
            except ValueError:
                raise RuntimeError(
                    '--resume argument is invalid (must be number, or "recent"): {}'
                    .format(resume_run_id))

        try:
            rundir_name = next(x[1] for x in ids if x[0] == resume_run_id)
        except StopIteration:
            raise RuntimeError(
                'Could not find results directory with run ID {} (options: {})'
                .format(resume_run_id, [x[0] for x in ids]))

        # Find kimg & pkl file
        rundir = os.path.join(result_dir, rundir_name)
        pkls = [
            name for name in os.listdir(rundir)
            if name.startswith('network-snapshot-') and name.endswith('.pkl')
        ]
        kimgs = sorted([(int(
            pkl.replace('network-snapshot-', '').replace('.pkl', '')), pkl)
                        for pkl in pkls],
                       key=lambda x: x[0])
        if len(kimgs) == 0:
            raise RuntimeError(
                'No network-snapshot-[0-9].pkl files found in {}'.format(
                    rundir))
        max_kimg = kimgs[-1][0]
        pkl_name = kimgs[-1][1]

        # Get wall clock time
        logfilepath = os.path.join(rundir, 'log.txt')
        with open(logfilepath, 'r') as f:
            logfile = f.read()
        total_seconds = 0
        total_seconds_formatted = '0s'  # Fallback if no matching 'kimg' line is found.
        for line in logfile.splitlines():
            if 'kimg {}'.format(max_kimg) in line:
                if 'time ' not in line:
                    raise RuntimeError(
                        'Invalid log file: {}'.format(logfilepath))
                line = line.split('time ')[1]
                if 'sec/tick' not in line:
                    raise RuntimeError(
                        'Invalid log file: {}'.format(logfilepath))
                line = line.split('sec/tick')[0].strip()
                # Parse d h m s, etc.
                total_seconds_formatted = line
                total_seconds = 0
                if 'd' in line:
                    arr = line.split('d')
                    days = int(arr[0].strip())
                    total_seconds += days * 24 * 60 * 60
                    line = arr[1]
                if 'h' in line:
                    arr = line.split('h')
                    hours = int(arr[0].strip())
                    total_seconds += hours * 60 * 60
                    line = arr[1]
                if 'm' in line:
                    arr = line.split('m')
                    mins = int(arr[0].strip())
                    total_seconds += mins * 60
                    line = arr[1]
                if 's' in line:
                    arr = line.split('s')
                    secs = int(arr[0].strip())
                    total_seconds += secs
                    line = arr[1]
                break

        # Set args for training
        train.resume_pkl = os.path.join(rundir, pkl_name)
        train.resume_kimg = max_kimg
        train.resume_time = total_seconds
        print('Resuming from run {}: kimg {}, time {}'.format(
            rundir_name, max_kimg, total_seconds_formatted))

    train.data_dir = data_dir
    train.total_kimg = total_kimg
    train.mirror_augment = mirror_augment
    train.image_snapshot_ticks = train.network_snapshot_ticks = 1
    sched.G_lrate_base = sched.D_lrate_base = 0.002
    sched.minibatch_size_base = 32
    sched.minibatch_gpu_base = 4
    D_loss.gamma = 10
    metrics = [metric_defaults[x] for x in metrics]
    desc = 'stylegan2'

    desc += '-' + dataset
    dataset_args = EasyDict(tfrecord_dir=dataset)

    assert num_gpus in [1, 2, 4, 8]
    sc.num_gpus = num_gpus
    desc += '-%dgpu' % num_gpus

    assert config_id in _valid_configs
    desc += '-' + config_id

    # Configs A-E: Shrink networks to match original StyleGAN.
    if config_id != 'config-f':
        G.fmap_base = D.fmap_base = 8 << 10

    # Config E: Set gamma to 100 and override G & D architecture.
    if config_id.startswith('config-e'):
        D_loss.gamma = 100
        if 'Gorig' in config_id: G.architecture = 'orig'
        if 'Gskip' in config_id: G.architecture = 'skip'  # (default)
        if 'Gresnet' in config_id: G.architecture = 'resnet'
        if 'Dorig' in config_id: D.architecture = 'orig'
        if 'Dskip' in config_id: D.architecture = 'skip'
        if 'Dresnet' in config_id: D.architecture = 'resnet'  # (default)

    # Configs A-D: Enable progressive growing and switch to networks that support it.
    if config_id in ['config-a', 'config-b', 'config-c', 'config-d']:
        sched.lod_initial_resolution = 8
        sched.G_lrate_base = sched.D_lrate_base = 0.001
        sched.G_lrate_dict = sched.D_lrate_dict = {
            128: 0.0015,
            256: 0.002,
            512: 0.003,
            1024: 0.003
        }
        sched.minibatch_size_base = 32  # (default)
        sched.minibatch_size_dict = {8: 256, 16: 128, 32: 64, 64: 32}
        sched.minibatch_gpu_base = 4  # (default)
        sched.minibatch_gpu_dict = {8: 32, 16: 16, 32: 8, 64: 4}
        G.synthesis_func = 'G_synthesis_stylegan_revised'
        D.func_name = 'training.networks_stylegan2.D_stylegan'

    # Configs A-C: Disable path length regularization.
    if config_id in ['config-a', 'config-b', 'config-c']:
        G_loss = EasyDict(func_name='training.loss.G_logistic_ns')

    # Configs A-B: Disable lazy regularization.
    if config_id in ['config-a', 'config-b']:
        train.lazy_regularization = False

    # Config A: Switch to original StyleGAN networks.
    if config_id == 'config-a':
        G = EasyDict(func_name='training.networks_stylegan.G_style')
        D = EasyDict(func_name='training.networks_stylegan.D_basic')

    if gamma is not None:
        D_loss.gamma = gamma

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(G_args=G,
                  D_args=D,
                  G_opt_args=G_opt,
                  D_opt_args=D_opt,
                  G_loss_args=G_loss,
                  D_loss_args=D_loss)
    kwargs.update(dataset_args=dataset_args,
                  sched_args=sched,
                  grid_args=grid,
                  metric_arg_list=metrics,
                  tf_config=tf_config)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = result_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
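get_valid_runids() is referenced in the resume logic above but never defined here. Judging from its use (sorted (run_id, dirname) pairs, with directories named like '00326-conditional_label_dropout_25'), a compatible sketch could be the following; the real helper may filter differently.

import os
import re

def get_valid_runids(result_dir):
    # Collect (run_id, dirname) pairs for result directories named like '00042-desc'.
    runids = []
    for name in os.listdir(result_dir):
        m = re.match(r'^(\d+)-', name)
        if m and os.path.isdir(os.path.join(result_dir, name)):
            runids.append((int(m.group(1)), name))
    return runids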
Example #5
def run(dataset, train_dir, config, d_aug, diffaug_policy, cond, ops, jpg_data, mirror, mirror_v, \
        lod_step_kimg, batch_size, resume, resume_kimg, finetune, num_gpus, ema_kimg, gamma, freezeD):

    # dataset (tfrecords) - preprocess or get
    tfr_files = file_list(os.path.dirname(dataset), 'tfr')
    tfr_files = [f for f in tfr_files if basename(dataset) in f]
    if len(tfr_files) == 0:
        tfr_file, total_samples = create_from_images(dataset, jpg=jpg_data)
    else:
        tfr_file = tfr_files[0]
    dataset_args = EasyDict(tfrecord=tfr_file, jpg_data=jpg_data)

    desc = basename(tfr_file).split('-')[0]

    # training functions
    if d_aug:  # https://github.com/mit-han-lab/data-efficient-gans
        train = EasyDict(
            run_func_name='training.training_loop_diffaug.training_loop'
        )  # Options for training loop (Diff Augment method)
        loss_args = EasyDict(
            func_name='training.loss_diffaug.ns_DiffAugment_r1',
            policy=diffaug_policy)  # Options for loss (Diff Augment method)
    else:  # original nvidia
        train = EasyDict(run_func_name='training.training_loop.training_loop'
                         )  # Options for training loop (original from NVidia)
        G_loss = EasyDict(func_name='training.loss.G_logistic_ns_pathreg'
                          )  # Options for generator loss.
        D_loss = EasyDict(func_name='training.loss.D_logistic_r1'
                          )  # Options for discriminator loss.

    # network functions
    G = EasyDict(func_name='training.networks_stylegan2.G_main'
                 )  # Options for generator network.
    D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2'
                 )  # Options for discriminator network.
    G_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for generator optimizer.
    D_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for discriminator optimizer.
    sched = EasyDict()  # Options for TrainingSchedule.
    grid = EasyDict(
        size='1080p',
        layout='random')  # Options for setup_snapshot_image_grid().
    sc = dnnlib.SubmitConfig()  # Options for dnnlib.submit_run().
    tf_config = {'rnd.np_random_seed': 1000}  # Options for tflib.init_tf().
    G.impl = D.impl = ops

    # resolutions
    data_res = basename(tfr_file).split('-')[-1].split('x')  # resolution from dataset filename
    data_res = list(reversed([int(x) for x in data_res]))    # convert to int list
    init_res, resolution, res_log2 = calc_init_res(data_res)
    if init_res != [4, 4]:
        print(' custom init resolution', init_res)
    G.init_res = D.init_res = list(init_res)

    train.setname = desc + config
    desc = '%s-%d-%s' % (desc, resolution, config)

    # training schedule
    sched.lod_training_kimg = lod_step_kimg
    sched.lod_transition_kimg = lod_step_kimg
    train.total_kimg = lod_step_kimg * res_log2 * 2  # a la ProGAN
    if finetune:
        train.total_kimg = 15000  # should start from ~10k kimg
    train.image_snapshot_ticks = 1
    train.network_snapshot_ticks = 5
    train.mirror_augment = mirror
    train.mirror_augment_v = mirror_v

    # learning rate
    if config == 'e':
        if finetune:  # uptrain 1024
            sched.G_lrate_base = 0.001
        else:  # train 1024
            sched.G_lrate_base = 0.001
            sched.G_lrate_dict = {0: 0.001, 1: 0.0007, 2: 0.0005, 3: 0.0003}
            sched.lrate_step = 1500  # period for stepping to next lrate, in kimg
    if config == 'f':
        # sched.G_lrate_base = 0.0003
        sched.G_lrate_base = 0.001
    sched.D_lrate_base = sched.G_lrate_base  # *2 - not used anyway

    sched.minibatch_gpu_base = batch_size
    sched.minibatch_size_base = num_gpus * sched.minibatch_gpu_base
    sc.num_gpus = num_gpus

    if config == 'e':
        G.fmap_base = D.fmap_base = 8 << 10
        if d_aug: loss_args.gamma = 100 if gamma is None else gamma
        else: D_loss.gamma = 100 if gamma is None else gamma
    elif config == 'f':
        G.fmap_base = D.fmap_base = 16 << 10
    else:
        raise ValueError('Only configs E and F are implemented, got: ' + config)

    if cond:
        desc += '-cond'
        dataset_args.max_label_size = 'full'  # conditioned on full label

    if freezeD:
        D.freezeD = True
        train.resume_with_new_nets = True

    if d_aug:
        desc += '-daug'

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(G_args=G, D_args=D, G_opt_args=G_opt, D_opt_args=D_opt)
    kwargs.update(dataset_args=dataset_args,
                  sched_args=sched,
                  grid_args=grid,
                  tf_config=tf_config)
    kwargs.update(resume_pkl=resume,
                  resume_kimg=resume_kimg,
                  resume_with_new_nets=True)
    if ema_kimg is not None:
        kwargs.update(G_ema_kimg=ema_kimg)
    if d_aug:
        kwargs.update(loss_args=loss_args)
    else:
        kwargs.update(G_loss_args=G_loss, D_loss_args=D_loss)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = train_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
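calc_init_res() (used above to derive init_res, resolution and res_log2 from the dataset shape) is also external. The sketch below is a plausible reconstruction consistent with its call site, which must handle non-square and non-power-of-two data; the actual helper may round differently.

import numpy as np

def calc_init_res(shape):
    # shape is [height, width], as parsed from the tfrecord filename above.
    res_log2 = int(np.ceil(np.log2(max(shape))))         # next power-of-two exponent
    resolution = 2 ** res_log2                           # full training resolution
    scale = 2 ** (res_log2 - 2)                          # downscale factor to the initial grid
    init_res = [int(np.ceil(s / scale)) for s in shape]  # [4, 4] for square power-of-two data
    return init_res, resolution, res_log2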
Example #6
def run(g_loss, g_loss_kwargs, d_loss, d_loss_kwargs, dataset_train,
        dataset_eval, data_dir, result_dir, config_id, num_gpus, total_kimg,
        gamma, mirror_augment, metrics, resume_pkl, resume_kimg,
        resume_pkl_dir, max_images, lrate_base, img_ticks, net_ticks,
        skip_images):

    # Extra loss kwargs arrive as JSON strings; an empty string means no overrides.
    g_loss_kwargs = json.loads(g_loss_kwargs) if g_loss_kwargs else {}
    d_loss_kwargs = json.loads(d_loss_kwargs) if d_loss_kwargs else {}

    train = EasyDict(run_func_name='training.training_loop.training_loop'
                     )  # Options for training loop.
    G = EasyDict(func_name='training.networks_stylegan2.G_main'
                 )  # Options for generator network.
    D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2'
                 )  # Options for discriminator network.
    G_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for generator optimizer.
    D_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for discriminator optimizer.
    G_loss = EasyDict(
        func_name='training.loss.' + g_loss,
        **g_loss_kwargs)  # Options for generator loss (e.g. 'G_logistic_ns_gsreg').
    D_loss = EasyDict(func_name='training.loss.' + d_loss,
                      **d_loss_kwargs)  # Options for discriminator loss.
    sched = EasyDict()  # Options for TrainingSchedule.
    grid = EasyDict(
        size='1080p',
        layout='random')  # Options for setup_snapshot_image_grid().
    sc = dnnlib.SubmitConfig()  # Options for dnnlib.submit_run().
    tf_config = {'rnd.np_random_seed': 1000}  # Options for tflib.init_tf().

    train.total_kimg = total_kimg
    train.mirror_augment = mirror_augment
    train.image_snapshot_ticks = img_ticks
    train.network_snapshot_ticks = net_ticks
    G.scale_func = 'training.networks_stylegan2.apply_identity'
    D.scale_func = None
    sched.G_lrate_base = sched.D_lrate_base = lrate_base  #0.002
    # TODO: Changed this to 16 to match DiffAug
    sched.minibatch_size_base = 16
    sched.minibatch_gpu_base = 4
    D_loss.gamma = 10
    metrics = [metric_defaults[x] for x in metrics]
    desc = 'stylegan2'
    sched.tick_kimg_base = 1
    # Resolution-specific overrides, e.g. {8: 28, 16: 24, 32: 20, 64: 16, 128: 12, 256: 8, 512: 6, 1024: 4}.
    sched.tick_kimg_dict = {}

    desc += '-' + dataset_train.split('/')[-1]
    # Get dataset paths
    t_path = dataset_train.split('/')
    e_path = dataset_eval.split('/')
    if len(t_path) > 1:
        dataset_train = t_path[-1]
        train.train_data_dir = os.path.join(data_dir, '/'.join(t_path[:-1]))
    if len(e_path) > 1:
        dataset_eval = e_path[-1]
        train.eval_data_dir = os.path.join(data_dir, '/'.join(e_path[:-1]))
    dataset_args = EasyDict(tfrecord_dir=dataset_train)
    # Limit number of training images during train (not eval)
    dataset_args['max_images'] = max_images
    if max_images: desc += '-%dimg' % max_images
    dataset_args['skip_images'] = skip_images
    dataset_args_eval = EasyDict(tfrecord_dir=dataset_eval)
    desc += '-' + dataset_eval

    assert num_gpus in [1, 2, 4, 8]
    sc.num_gpus = num_gpus

    assert config_id in _valid_configs
    desc += '-' + config_id

    if mirror_augment: desc += '-aug'

    # Infer pretrain checkpoint from target dataset
    if not resume_pkl:
        if any(ds in dataset_train.lower()
               for ds in ['obama', 'celeba', 'rem', 'portrait']):
            resume_pkl = 'ffhq-config-f.pkl'
        if any(ds in dataset_train.lower()
               for ds in ['gogh', 'temple', 'tower', 'medici', 'bridge']):
            resume_pkl = 'church-config-f.pkl'
        if any(ds in dataset_train.lower() for ds in ['bus']):
            resume_pkl = 'car-config-f.pkl'
    resume_pkl = os.path.join(resume_pkl_dir, resume_pkl)
    train.resume_pkl = resume_pkl
    train.resume_kimg = resume_kimg

    train.resume_with_new_nets = True  # Recreate with new parameters
    # Adaptive parameters
    if 'ada' in config_id:
        G['train_scope'] = D['train_scope'] = '.*adapt'  # Freeze old parameters
        if 'ss' in config_id:
            G['adapt_func'] = D['adapt_func'] = \
                'training.networks_stylegan2.apply_adaptive_scale_shift'
        if 'sv' in config_id or 'pc' in config_id:
            G['map_svd'] = G['syn_svd'] = D['svd'] = True
            # Flatten over spatial dimension
            if 'flat' in config_id:
                G['spatial'] = D['spatial'] = True
            # Do PCA by centering before SVD
            if 'pc' in config_id:
                G['svd_center'] = D['svd_center'] = True
            G['svd_config'] = D['svd_config'] = 'S'
            if 'U' in config_id:
                G['svd_config'] += 'U'
                D['svd_config'] += 'U'
            if 'V' in config_id:
                G['svd_config'] += 'V'
                D['svd_config'] += 'V'
    # FreezeD: freeze discriminator layers when 'fd' appears in the config ID.
    D['freeze'] = 'fd' in config_id
    # DiffAug
    if 'da' in config_id:
        G_loss = EasyDict(func_name='training.loss.G_ns_diffaug')
        D_loss = EasyDict(func_name='training.loss.D_ns_diffaug_r1')

    if gamma is not None:
        D_loss.gamma = gamma

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)

    kwargs.update(
        G_args=G,
        D_args=D,
        G_opt_args=G_opt,
        D_opt_args=D_opt,
        G_loss_args=G_loss,
        D_loss_args=D_loss,
    )
    kwargs.update(dataset_args=dataset_args,
                  dataset_args_eval=dataset_args_eval,
                  sched_args=sched,
                  grid_args=grid,
                  metric_arg_list=metrics,
                  tf_config=tf_config)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = result_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
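For reference, an invocation of this variant might look as follows. The loss names match the StyleGAN2 defaults mentioned in the comments; the paths, dataset names and JSON kwargs are placeholders, and config_id must be one of the (undisclosed) _valid_configs.

# Hypothetical call; every literal here is a placeholder.
run(g_loss='G_logistic_ns_pathreg',
    g_loss_kwargs='{"pl_weight": 2.0}',  # parsed by json.loads() above
    d_loss='D_logistic_r1',
    d_loss_kwargs='',                    # empty string -> {}
    dataset_train='few-shot/obama',      # 'obama' triggers the ffhq-config-f.pkl heuristic
    dataset_eval='few-shot/obama',
    data_dir='datasets', result_dir='results', config_id='config-f',
    num_gpus=1, total_kimg=300, gamma=None, mirror_augment=False,
    metrics=['fid50k'], resume_pkl='', resume_kimg=0,
    resume_pkl_dir='pretrained', max_images=None, lrate_base=0.002,
    img_ticks=1, net_ticks=1, skip_images=0)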
Example #7
def run(dataset,
        data_dir,
        result_dir,
        config_id,
        num_gpus,
        total_kimg,
        gamma,
        mirror_augment,
        metrics,
        resume_pkl,
        I_fmap_base=8,
        G_fmap_base=8,
        D_fmap_base=9,
        fmap_decay=0.15,
        D_lambda=0,
        C_lambda=1,
        n_samples_per=10,
        module_list=None,
        model_type='vc_gan2',
        epsilon_loss=3,
        random_eps=False,
        latent_type='uniform',
        delta_type='onedim',
        connect_mode='concat',
        batch_size=32,
        batch_per_gpu=16,
        return_atts=False,
        random_seed=1000,
        module_I_list=None,
        module_D_list=None,
        fmap_min=16,
        fmap_max=512,
        G_nf_scale=4,
        I_nf_scale=4,
        D_nf_scale=4,
        return_I_atts=False,
        dlatent_size=24,
        arch='resnet',
        topk_dims_to_show=20):
    # print('module_list:', module_list)
    train = EasyDict(
        run_func_name='training.training_loop_vc2.training_loop_vc2'
    )  # Options for training loop.

    D_global_size = 0
    if module_list is not None:
        module_list = _str_to_list(module_list)
        key_ls, size_ls, count_dlatent_size = split_module_names(module_list)
        for i, key in enumerate(key_ls):
            if key.startswith('D_global') or key.startswith('D_nocond_global'):
                D_global_size += size_ls[i]
    else:
        count_dlatent_size = dlatent_size

    if module_I_list is not None:
        D_global_I_size = 0
        module_I_list = _str_to_list(module_I_list)
        key_I_ls, size_I_ls, count_dlatent_I_size = split_module_names(
            module_I_list)
        for i, key in enumerate(key_I_ls):
            if key.startswith('D_global') or key.startswith('D_nocond_global'):
                D_global_I_size += size_I_ls[i]

    if module_D_list is not None:
        D_global_D_size = 0
        module_D_list = _str_to_list(module_D_list)
        key_D_ls, size_D_ls, count_dlatent_D_size = split_module_names(
            module_D_list)
        for i, key in enumerate(key_D_ls):
            if key.startswith('D_global') or key.startswith('D_nocond_global'):
                D_global_D_size += size_D_ls[i]

    if model_type == 'vc2_gan':  # Standard VP-GAN
        G = EasyDict(func_name='training.vc_networks2.G_main_vc2',
                     synthesis_func='G_synthesis_modular_vc2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     fmap_decay=fmap_decay,
                     latent_size=count_dlatent_size,
                     dlatent_size=count_dlatent_size,
                     D_global_size=D_global_size,
                     module_list=module_list,
                     use_noise=True,
                     return_atts=return_atts,
                     G_nf_scale=G_nf_scale)  # Options for generator network.
        I = EasyDict(func_name='training.vc_networks2.vc2_head',
                     dlatent_size=count_dlatent_size,
                     D_global_size=D_global_size,
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     connect_mode=connect_mode)
        D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max)  # Options for discriminator network.
        I_info = EasyDict()
        desc = 'vc2_gan'
    elif model_type == 'vpex_gan':  # VP-GAN extended.
        G = EasyDict(func_name='training.vc_networks2.G_main_vc2',
                     synthesis_func='G_synthesis_modular_vc2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     fmap_decay=fmap_decay,
                     latent_size=count_dlatent_size,
                     dlatent_size=count_dlatent_size,
                     D_global_size=D_global_size,
                     module_list=module_list,
                     use_noise=True,
                     return_atts=return_atts,
                     G_nf_scale=G_nf_scale)  # Options for generator network.
        I = EasyDict(func_name='training.vpex_networks.vpex_net',
                     dlatent_size=count_dlatent_size,
                     D_global_size=D_global_size,
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     connect_mode=connect_mode,
                     return_atts=return_I_atts)
        D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max)  # Options for discriminator network.
        I_info = EasyDict()
        desc = 'vpex_gan'
    elif model_type == 'vc2_gan_own_I':  # Standard VP-GAN with own I
        G = EasyDict(func_name='training.vc_networks2.G_main_vc2',
                     synthesis_func='G_synthesis_modular_vc2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     fmap_decay=fmap_decay,
                     latent_size=count_dlatent_size,
                     dlatent_size=count_dlatent_size,
                     D_global_size=D_global_size,
                     module_list=module_list,
                     use_noise=True,
                     return_atts=return_atts,
                     G_nf_scale=G_nf_scale)  # Options for generator network.
        I = EasyDict(func_name='training.vc_networks2.I_modular_vc2',
                     dlatent_size=count_dlatent_I_size,
                     D_global_size=D_global_I_size,
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     connect_mode=connect_mode,
                     module_I_list=module_I_list,
                     I_nf_scale=I_nf_scale)
        D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max)  # Options for discriminator network.
        I_info = EasyDict()
        desc = 'vc2_gan_own_I'
    elif model_type == 'vc2_gan_own_ID':  # Standard VP-GAN with own ID
        G = EasyDict(func_name='training.vc_networks2.G_main_vc2',
                     synthesis_func='G_synthesis_modular_vc2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     fmap_decay=fmap_decay,
                     latent_size=count_dlatent_size,
                     dlatent_size=count_dlatent_size,
                     D_global_size=D_global_size,
                     module_list=module_list,
                     use_noise=True,
                     return_atts=return_atts,
                     G_nf_scale=G_nf_scale)  # Options for generator network.
        I = EasyDict(func_name='training.vc_networks2.I_modular_vc2',
                     dlatent_size=count_dlatent_I_size,
                     D_global_size=D_global_I_size,
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     connect_mode=connect_mode,
                     module_I_list=module_I_list,
                     I_nf_scale=I_nf_scale)
        D = EasyDict(func_name='training.vc_networks2.D_modular_vc2',
                     dlatent_size=count_dlatent_D_size,
                     D_global_size=D_global_D_size,
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     connect_mode=connect_mode,
                     module_D_list=module_D_list,
                     D_nf_scale=D_nf_scale)
        I_info = EasyDict()
        desc = 'vc2_gan_ownID'
    else:
        raise ValueError('Unsupported model type: ' + model_type)

    G_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for generator optimizer.
    D_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for discriminator optimizer.
    if model_type == 'vc2_gan':  # Standard VP-GAN
        G_loss = EasyDict(func_name='training.loss_vc2.G_logistic_ns_vc2',
                          D_global_size=D_global_size,
                          C_lambda=C_lambda,
                          epsilon=epsilon_loss,
                          random_eps=random_eps,
                          latent_type=latent_type,
                          delta_type=delta_type)  # Options for generator loss.
        D_loss = EasyDict(
            func_name='training.loss_vc2.D_logistic_r1_vc2',
            D_global_size=D_global_size,
            latent_type=latent_type)  # Options for discriminator loss.
    elif model_type == 'vpex_gan':  # VP-GAN extended.
        G_loss = EasyDict(func_name='training.loss_vpex.G_logistic_ns_vpex',
                          D_global_size=D_global_size,
                          C_lambda=C_lambda,
                          epsilon=epsilon_loss,
                          random_eps=random_eps,
                          latent_type=latent_type,
                          delta_type=delta_type)  # Options for generator loss.
        D_loss = EasyDict(
            func_name='training.loss_vc2.D_logistic_r1_vc2',
            D_global_size=D_global_size,
            latent_type=latent_type)  # Options for discriminator loss.
    elif model_type == 'vc2_gan_own_I' or model_type == 'vc2_gan_own_ID':  # Standard VP-GAN with own I or D
        G_loss = EasyDict(func_name='training.loss_vc2.G_logistic_ns_vc2',
                          D_global_size=D_global_size,
                          C_lambda=C_lambda,
                          epsilon=epsilon_loss,
                          random_eps=random_eps,
                          latent_type=latent_type,
                          delta_type=delta_type,
                          own_I=True)  # Options for generator loss.
        D_loss = EasyDict(
            func_name='training.loss_vc2.D_logistic_r1_vc2',
            D_global_size=D_global_size,
            latent_type=latent_type)  # Options for discriminator loss.

    sched = EasyDict()  # Options for TrainingSchedule.
    grid = EasyDict(
        size='1080p',
        layout='random')  # Options for setup_snapshot_image_grid().
    sc = dnnlib.SubmitConfig()  # Options for dnnlib.submit_run().
    tf_config = {'rnd.np_random_seed': random_seed}  # Options for tflib.init_tf().

    train.data_dir = data_dir
    train.total_kimg = total_kimg
    train.mirror_augment = mirror_augment
    train.image_snapshot_ticks = train.network_snapshot_ticks = 10
    sched.G_lrate_base = sched.D_lrate_base = 0.002
    sched.minibatch_size_base = batch_size
    sched.minibatch_gpu_base = batch_per_gpu
    D_loss.gamma = 10
    metrics = [metric_defaults[x] for x in metrics]

    desc += '-' + dataset
    dataset_args = EasyDict(tfrecord_dir=dataset, max_label_size='full')

    assert num_gpus in [1, 2, 4, 8]
    sc.num_gpus = num_gpus
    desc += '-%dgpu' % num_gpus
    desc += '-' + config_id

    # Configs A-E: Shrink networks to match original StyleGAN.
    I.fmap_base = 2 << I_fmap_base
    G.fmap_base = 2 << G_fmap_base
    D.fmap_base = 2 << D_fmap_base

    # Config E: Set gamma to 100 and override G & D architecture.
    # D_loss.gamma = 100

    if gamma is not None:
        D_loss.gamma = gamma

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(
        G_args=G,
        D_args=D,
        I_args=I,
        I_info_args=I_info,
        G_opt_args=G_opt,
        D_opt_args=D_opt,
        G_loss_args=G_loss,
        D_loss_args=D_loss,
        use_info_gan=(model_type == 'vc2_info_gan2'),  # Independent branch version
        use_vc_head=(model_type in ('vc2_gan', 'vpex_gan', 'vc2_gan_own_I',
                                    'vc2_gan_own_ID', 'vc2_gan_byvae')),
        traversal_grid=True,
        return_atts=return_atts,
        return_I_atts=return_I_atts)
    n_continuous = 0
    if module_list is not None:
        for i, key in enumerate(key_ls):
            m_name = key.split('-')[0]
            if m_name in LATENT_MODULES and m_name != 'D_global':
                n_continuous += size_ls[i]
    else:
        n_continuous = dlatent_size

    kwargs.update(dataset_args=dataset_args,
                  sched_args=sched,
                  grid_args=grid,
                  metric_arg_list=metrics,
                  tf_config=tf_config,
                  resume_pkl=resume_pkl,
                  n_discrete=D_global_size,
                  n_continuous=n_continuous,
                  n_samples_per=n_samples_per,
                  topk_dims_to_show=topk_dims_to_show)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = result_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
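Several of these examples lean on _str_to_list() and split_module_names(), which are never shown. Their call sites imply a comma-separated 'Name-size' encoding whose sizes sum to the latent dimensionality; the sketch below is one compatible reading (note that Example #9 unpacks a fourth return value, so the repositories evidently ship differing versions).

def _str_to_list(s):
    # '[Const-512,C_global-10]' -> ['Const-512', 'C_global-10'] (assumed format).
    return [x.strip() for x in s.strip('[]').split(',') if x.strip()]

def split_module_names(module_list):
    # Each 'Name-size' entry contributes `size` latent dims under its full key.
    key_ls, size_ls = [], []
    for item in module_list:
        _, _, size = item.rpartition('-')
        key_ls.append(item)
        size_ls.append(int(size))
    return key_ls, size_ls, sum(size_ls)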
Example #8
def run(dataset, resolution, result_dir, DiffAugment, num_gpus, batch_size,
        total_kimg, ema_kimg, num_samples, gamma, fmap_base, fmap_max,
        latent_size, mirror_augment, impl, metrics, resume, resume_kimg,
        num_repeats, eval):
    train = EasyDict(run_func_name='training.training_loop.training_loop'
                     )  # Options for training loop.
    G = EasyDict(func_name='training.networks_stylegan2.G_main'
                 )  # Options for generator network.
    D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2'
                 )  # Options for discriminator network.
    G_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for generator optimizer.
    D_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for discriminator optimizer.
    loss_args = EasyDict(
        func_name='training.loss.ns_r1_DiffAugment')  # Options for loss.
    sched = EasyDict()  # Options for TrainingSchedule.
    grid = EasyDict(
        size='4k', layout='random')  # Options for setup_snapshot_image_grid().
    sc = dnnlib.SubmitConfig()  # Options for dnnlib.submit_run().
    tf_config = {'rnd.np_random_seed': 1000}  # Options for tflib.init_tf().

    train.total_kimg = total_kimg
    train.mirror_augment = mirror_augment
    metrics = [metric_defaults[x] for x in metrics]
    metric_args = EasyDict(num_repeats=num_repeats)

    desc = 'DiffAugment-stylegan2' if DiffAugment else 'stylegan2'
    dataset_args = EasyDict(tfrecord_dir=dataset, resolution=resolution)
    desc += '-' + os.path.basename(dataset)
    if resolution is not None:
        desc += '-{}'.format(resolution)

    if num_samples is not None:
        dataset_args.num_samples = num_samples
        desc += '-{}samples'.format(num_samples)

    if batch_size is not None:
        desc += '-batch{}'.format(batch_size)
    else:
        batch_size = 32
    assert batch_size % num_gpus == 0
    sc.num_gpus = num_gpus
    desc += '-%dgpu' % num_gpus
    sched.minibatch_size_base = batch_size
    sched.minibatch_gpu_base = batch_size // num_gpus

    G.impl = D.impl = impl
    if fmap_base is not None:
        G.fmap_base = D.fmap_base = fmap_base
        desc += '-fmap{}'.format(fmap_base)
    if fmap_max is not None:
        G.fmap_max = D.fmap_max = fmap_max
        desc += '-fmax{}'.format(fmap_max)
    if latent_size is not None:
        G.latent_size = G.mapping_fmaps = G.dlatent_size = latent_size
        desc += '-latent{}'.format(latent_size)

    if gamma is not None:
        loss_args.gamma = gamma
        desc += '-gamma{}'.format(gamma)
    if DiffAugment:
        loss_args.policy = DiffAugment
        desc += '-' + DiffAugment.replace(',', '-')

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(G_args=G,
                  D_args=D,
                  G_opt_args=G_opt,
                  D_opt_args=D_opt,
                  loss_args=loss_args)
    kwargs.update(dataset_args=dataset_args,
                  sched_args=sched,
                  grid_args=grid,
                  metric_arg_list=metrics,
                  tf_config=tf_config)
    kwargs.update(resume_pkl=resume,
                  resume_kimg=resume_kimg,
                  resume_with_new_nets=True)
    kwargs.update(metric_args=metric_args)
    if ema_kimg is not None:
        kwargs.update(G_ema_kimg=ema_kimg)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = result_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
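The DiffAugment argument is a comma-separated policy string that is passed straight through to loss_args.policy (and echoed into desc with commas replaced by dashes). Going by the mit-han-lab/data-efficient-gans repository this code follows, the standard policies are 'color', 'translation' and 'cutout'; a typical call might be the following, with all other values placeholders.

# Hypothetical few-shot invocation; metric names must be keys of metric_defaults.
run(dataset='datasets/100-shot-obama', resolution=None, result_dir='results',
    DiffAugment='color,translation,cutout',  # -> desc suffix '-color-translation-cutout'
    num_gpus=1, batch_size=None, total_kimg=300, ema_kimg=None,
    num_samples=None, gamma=None, fmap_base=None, fmap_max=None,
    latent_size=None, mirror_augment=True, impl='cuda',
    metrics=['fid50k'], resume=None, resume_kimg=0,
    num_repeats=1, eval=False)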
Example #9
def run(
        dataset,
        data_dir,
        result_dir,
        config_id,
        num_gpus,
        total_kimg,
        gamma,
        mirror_augment,
        metrics,
        resume_pkl,
        D_global_size=3,
        C_global_size=0,  # Global C_latents.
        sb_C_global_size=4,
        C_local_hfeat_size=0,  # Local heatmap*features learned C_latents.
        C_local_heat_size=0,  # Local heatmap learned C_latents.
        n_samples_per=10,
        module_list=None,
        single_const=True,
        model_type='spatial_biased'):
    # print('module_list:', module_list)
    train = EasyDict(
        run_func_name='training.training_loop_dsp.training_loop_dsp'
    )  # Options for training loop.
    if model_type == 'spatial_biased':
        G = EasyDict(
            func_name=
            'training.spatial_biased_networks.G_main_spatial_biased_dsp',
            mapping_fmaps=128,
            fmap_max=128,
            latent_size=D_global_size + sb_C_global_size,
            dlatent_size=D_global_size + sb_C_global_size,
            D_global_size=D_global_size,
            sb_C_global_size=sb_C_global_size
        )  # Options for generator network.
        desc = 'spatial_biased_net'
    elif model_type == 'sb_general':
        G = EasyDict(
            func_name=
            'training.spatial_biased_networks.G_main_spatial_biased_dsp',
            synthesis_func='G_synthesis_sb_general_dsp',
            mapping_fmaps=128,
            fmap_max=128,
            latent_size=D_global_size + C_global_size + sb_C_global_size +
            C_local_hfeat_size + C_local_heat_size,
            dlatent_size=D_global_size + C_global_size + sb_C_global_size +
            C_local_hfeat_size + C_local_heat_size,
            D_global_size=D_global_size,
            C_global_size=C_global_size,
            sb_C_global_size=sb_C_global_size,
            C_local_hfeat_size=C_local_hfeat_size,
            C_local_heat_size=C_local_heat_size,
            use_noise=False)  # Options for generator network.
        desc = 'sb_general_net'
    elif model_type == 'sb_modular':
        module_list = _str_to_list(module_list)
        key_ls, size_ls, count_dlatent_size, _ = split_module_names(
            module_list)
        for i, key in enumerate(key_ls):
            if key.startswith('D_global'):
                D_global_size = size_ls[i]
                break
        print('D_global_size:', D_global_size)
        G = EasyDict(
            func_name=
            'training.spatial_biased_networks.G_main_spatial_biased_dsp',
            synthesis_func='G_synthesis_sb_modular',
            mapping_fmaps=128,
            fmap_max=128,
            latent_size=count_dlatent_size,
            dlatent_size=count_dlatent_size,
            D_global_size=D_global_size,
            module_list=module_list,
            single_const=single_const,
            use_noise=False)  # Options for generator network.
        desc = 'sb_modular_net'
    elif model_type == 'sb_singlelayer_modi':
        G = EasyDict(func_name='training.simple_networks.G_main_simple_dsp',
                     synthesis_func='G_synthesis_sb_singlelayer_modi_dsp',
                     mapping_fmaps=128,
                     fmap_max=128,
                     latent_size=D_global_size + sb_C_global_size,
                     dlatent_size=D_global_size + sb_C_global_size,
                     D_global_size=D_global_size,
                     sb_C_global_size=sb_C_global_size
                     )  # Options for generator network.
        desc = 'sb_singlelayer_net'
    elif model_type == 'stylegan2':
        G = EasyDict(
            func_name=
            'training.spatial_biased_networks.G_main_spatial_biased_dsp',
            dlatent_avg_beta=None,
            mapping_fmaps=128,
            fmap_max=128,
            latent_size=12,
            D_global_size=D_global_size,
            sb_C_global_size=sb_C_global_size
        )  # Options for generator network.
        desc = 'stylegan2_net'
    elif model_type == 'simple':
        G = EasyDict(func_name='training.simple_networks.G_main_simple_dsp',
                     latent_size=D_global_size + sb_C_global_size,
                     dlatent_size=D_global_size + sb_C_global_size,
                     D_global_size=D_global_size,
                     sb_C_global_size=sb_C_global_size
                     )  # Options for generator network.
    else:
        raise ValueError('Unsupported model type: ' + model_type)

    if model_type == 'simple':
        D = EasyDict(func_name='training.simple_networks.D_simple_dsp'
                     )  # Options for discriminator network.
    else:
        D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2',
                     fmap_max=128)  # Options for discriminator network.
        # D         = EasyDict(func_name='training.spatial_biased_networks.D_with_discrete_dsp', fmap_max=128)  # Options for discriminator network.
    G_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for generator optimizer.
    D_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for discriminator optimizer.
    G_loss = EasyDict(
        func_name='training.loss.G_logistic_ns_dsp',
        D_global_size=D_global_size)  # Options for generator loss.
    D_loss = EasyDict(
        func_name='training.loss.D_logistic_r1_dsp',
        D_global_size=D_global_size)  # Options for discriminator loss.
    sched = EasyDict()  # Options for TrainingSchedule.
    grid = EasyDict(
        size='1080p',
        layout='random')  # Options for setup_snapshot_image_grid().
    sc = dnnlib.SubmitConfig()  # Options for dnnlib.submit_run().
    tf_config = {'rnd.np_random_seed': 1000}  # Options for tflib.init_tf().

    train.data_dir = data_dir
    train.total_kimg = total_kimg
    train.mirror_augment = mirror_augment
    train.image_snapshot_ticks = train.network_snapshot_ticks = 10
    sched.G_lrate_base = sched.D_lrate_base = 0.002
    sched.minibatch_size_base = 32
    sched.minibatch_gpu_base = 4
    D_loss.gamma = 10
    metrics = [metric_defaults[x] for x in metrics]

    desc += '-' + dataset
    dataset_args = EasyDict(tfrecord_dir=dataset, max_label_size='full')

    assert num_gpus in [1, 2, 4, 8]
    sc.num_gpus = num_gpus
    desc += '-%dgpu' % num_gpus

    assert config_id in _valid_configs
    desc += '-' + config_id

    # Configs A-E: Shrink networks to match original StyleGAN.
    if config_id != 'config-f':
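        # 8 << 10 == 8192; config-f keeps the larger 16 << 10 == 16384 default.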
        G.fmap_base = D.fmap_base = 8 << 10

    # Config E: Set gamma to 100 and override G & D architecture.
    if config_id.startswith('config-e'):
        D_loss.gamma = 100
        if 'Gorig' in config_id: G.architecture = 'orig'
        if 'Gskip' in config_id: G.architecture = 'skip'  # (default)
        if 'Gresnet' in config_id: G.architecture = 'resnet'
        if 'Dorig' in config_id: D.architecture = 'orig'
        if 'Dskip' in config_id: D.architecture = 'skip'
        if 'Dresnet' in config_id: D.architecture = 'resnet'  # (default)

    if gamma is not None:
        D_loss.gamma = gamma

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(G_args=G,
                  D_args=D,
                  G_opt_args=G_opt,
                  D_opt_args=D_opt,
                  G_loss_args=G_loss,
                  D_loss_args=D_loss,
                  traversal_grid=True)
    if model_type == 'sb_modular':
        n_continuous = 0
        for i, key in enumerate(key_ls):
            m_name = key.split('-')[0]
            if (m_name in LATENT_MODULES) and (m_name != 'D_global'):
                n_continuous += size_ls[i]
    else:
        n_continuous = C_global_size + sb_C_global_size + \
            C_local_hfeat_size + C_local_heat_size
    kwargs.update(dataset_args=dataset_args,
                  sched_args=sched,
                  grid_args=grid,
                  metric_arg_list=metrics,
                  tf_config=tf_config,
                  resume_pkl=resume_pkl,
                  n_discrete=D_global_size,
                  n_continuous=n_continuous,
                  n_samples_per=n_samples_per)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = result_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
Example #10
def run(dataset, data_dir, result_dir, num_gpus, total_kimg, mirror_augment,
        metrics, resume, resume_with_new_nets, disable_style_mod,
        disable_cond_mod):

    train = EasyDict(run_func_name='training.training_loop.training_loop'
                     )  # Options for training loop.
    G = EasyDict(func_name='training.co_mod_gan.G_main'
                 )  # Options for generator network.
    D = EasyDict(func_name='training.co_mod_gan.D_co_mod_gan'
                 )  # Options for discriminator network.
    G_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for generator optimizer.
    D_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for discriminator optimizer.
    G_loss = EasyDict(func_name='training.loss.G_masked_logistic_ns_l1'
                      )  # Options for generator loss.
    D_loss = EasyDict(func_name='training.loss.D_masked_logistic_r1'
                      )  # Options for discriminator loss.
    sched = EasyDict()  # Options for TrainingSchedule.
    grid = EasyDict(
        size='8k', layout='random')  # Options for setup_snapshot_image_grid().
    sc = dnnlib.SubmitConfig()  # Options for dnnlib.submit_run().
    tf_config = {'rnd.np_random_seed': 1000}  # Options for tflib.init_tf().

    train.data_dir = data_dir
    train.total_kimg = total_kimg
    train.mirror_augment = mirror_augment
    train.image_snapshot_ticks = train.network_snapshot_ticks = 10
    sched.G_lrate_base = sched.D_lrate_base = 0.002
    sched.minibatch_size_base = 32
    sched.minibatch_gpu_base = 4
    D_loss.gamma = 10
    metrics = [metric_defaults[x] for x in metrics]
    desc = 'co-mod-gan'

    desc += '-' + os.path.basename(dataset)
    dataset_args = EasyDict(tfrecord_dir=dataset)

    assert num_gpus in [1, 2, 4, 8]
    sc.num_gpus = num_gpus
    desc += '-%dgpu' % num_gpus

    if resume is not None:
        resume_kimg = int(
            os.path.basename(resume).replace('.pkl', '').split('-')[-1])
    else:
        resume_kimg = 0

    if disable_style_mod:
        G.style_mod = False

    if disable_cond_mod:
        G.cond_mod = False

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(G_args=G,
                  D_args=D,
                  G_opt_args=G_opt,
                  D_opt_args=D_opt,
                  G_loss_args=G_loss,
                  D_loss_args=D_loss)
    kwargs.update(dataset_args=dataset_args,
                  sched_args=sched,
                  grid_args=grid,
                  metric_arg_list=metrics,
                  tf_config=tf_config)
    kwargs.update(resume_pkl=resume,
                  resume_kimg=resume_kimg,
                  resume_with_new_nets=resume_with_new_nets)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = result_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
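
# Hedged usage sketch (not part of the original source): all argument values
# below are illustrative placeholders. Note how resume_kimg is derived above
# from the snapshot filename, e.g. 'network-snapshot-001024.pkl' -> 1024.
# run(dataset='my_tfrecords', data_dir='datasets', result_dir='results',
#     num_gpus=8, total_kimg=25000, mirror_augment=True, metrics=['fid50k'],
#     resume='results/00000-co-mod-gan/network-snapshot-001024.pkl',
#     resume_with_new_nets=False, disable_style_mod=False, disable_cond_mod=False)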
def run(dataset,
        data_dir,
        result_dir,
        config_id,
        num_gpus,
        total_kimg,
        gamma,
        mirror_augment,
        metrics,
        resume_pkl,
        fmap_decay=0.15,
        D_lambda=1,
        C_lambda=1,
        MI_lambda=1,
        cls_alpha=0,
        n_samples_per=10,
        module_list=None,
        single_const=True,
        model_type='spatial_biased',
        phi_blurry=0.5,
        latent_type='uniform'):

    train = EasyDict(
        run_func_name='training.training_loop_vid.training_loop_vid'
    )  # Options for training loop.

    D_global_size = 0

    module_list = _str_to_list(module_list)
    key_ls, size_ls, count_dlatent_size, _ = split_module_names(module_list)
    for i, key in enumerate(key_ls):
        if key.startswith('D_global'):
            D_global_size += size_ls[i]
            break
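    # Only the first 'D_global'-prefixed module is counted; the loop breaks at
    # the first match.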
    print('D_global_size:', D_global_size)
    print('key_ls:', key_ls)
    print('size_ls:', size_ls)
    print('count_dlatent_size:', count_dlatent_size)

    if model_type == 'vid_model':
        G = EasyDict(func_name='training.vid_networks.G_main_vid',
                     synthesis_func='G_synthesis_vid_modular',
                     fmap_min=16,
                     fmap_max=512,
                     fmap_decay=fmap_decay,
                     latent_size=count_dlatent_size,
                     dlatent_size=count_dlatent_size,
                     D_global_size=D_global_size,
                     module_list=module_list,
                     single_const=single_const,
                     use_noise=True)  # Options for generator network.
        I = EasyDict(func_name='training.vid_networks.vid_head',
                     dlatent_size=count_dlatent_size,
                     D_global_size=D_global_size,
                     fmap_max=512)
        D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2',
                     fmap_max=512)  # Options for discriminator network.
        I_info = EasyDict()
        desc = model_type
    elif model_type == 'vid_with_cls':
        G = EasyDict(func_name='training.vid_networks.G_main_vid',
                     synthesis_func='G_synthesis_vid_modular',
                     fmap_min=16,
                     fmap_max=512,
                     fmap_decay=fmap_decay,
                     latent_size=count_dlatent_size,
                     dlatent_size=count_dlatent_size,
                     D_global_size=D_global_size,
                     module_list=module_list,
                     single_const=single_const,
                     use_noise=True)  # Options for generator network.
        I = EasyDict(func_name='training.vid_networks.vid_head',
                     dlatent_size=count_dlatent_size,
                     D_global_size=D_global_size,
                     fmap_max=512)
        I_info = EasyDict(
            func_name='training.info_gan_networks.info_gan_head_cls',
            dlatent_size=count_dlatent_size,
            D_global_size=D_global_size,
            fmap_decay=fmap_decay,
            fmap_min=16,
            fmap_max=512)
        D = EasyDict(
            func_name='training.info_gan_networks.D_info_gan_stylegan2',
            fmap_max=512)  # Options for discriminator network.
        desc = model_type
    elif model_type == 'vid_naive_cluster_model':
        G = EasyDict(func_name='training.vid_networks.G_main_vid',
                     synthesis_func='G_synthesis_vid_modular',
                     fmap_min=16,
                     fmap_max=512,
                     fmap_decay=fmap_decay,
                     latent_size=count_dlatent_size,
                     dlatent_size=count_dlatent_size,
                     D_global_size=D_global_size,
                     module_list=module_list,
                     single_const=single_const,
                     use_noise=True)  # Options for generator network.
        I = EasyDict(func_name='training.vid_networks.vid_naive_cluster_head',
                     dlatent_size=count_dlatent_size,
                     D_global_size=D_global_size,
                     fmap_max=512)  # Options for estimator network.
        D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2',
                     fmap_max=512)  # Options for discriminator network.
        I_info = EasyDict()
        desc = model_type
    elif model_type == 'vid_blurry_model':
        G = EasyDict(func_name='training.vid_networks.G_main_vid',
                     synthesis_func='G_synthesis_vid_modular',
                     fmap_min=16,
                     fmap_max=512,
                     fmap_decay=fmap_decay,
                     latent_size=count_dlatent_size,
                     dlatent_size=count_dlatent_size,
                     D_global_size=D_global_size,
                     module_list=module_list,
                     single_const=single_const,
                     use_noise=True)  # Options for generator network.
        I = EasyDict(func_name='training.vid_networks.vid_naive_cluster_head',
                     dlatent_size=count_dlatent_size,
                     D_global_size=D_global_size,
                     fmap_max=512)  # Options for estimator network.
        D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2',
                     fmap_max=512)  # Options for discriminator network.
        I_info = EasyDict()
        desc = model_type
    else:
        raise ValueError('Unsupported model type: ' + model_type)

    G_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for generator optimizer.
    D_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for discriminator optimizer.
    I_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for estimator optimizer.
    if model_type == 'vid_model':
        G_loss = EasyDict(
            func_name='training.loss_vid.G_logistic_ns_vid',
            D_global_size=D_global_size,
            C_lambda=C_lambda,
            latent_type=latent_type)  # Options for generator loss.
        D_loss = EasyDict(
            func_name='training.loss_vid.D_logistic_r1_vid',
            D_global_size=D_global_size,
            latent_type=latent_type)  # Options for discriminator loss.
        I_loss = EasyDict(func_name='training.loss_vid.I_vid',
                          D_global_size=D_global_size,
                          latent_type=latent_type,
                          C_lambda=C_lambda,
                          MI_lambda=MI_lambda)  # Options for estimator loss.
    elif model_type == 'vid_with_cls':
        G_loss = EasyDict(
            func_name='training.loss_vid.G_logistic_ns_vid',
            D_global_size=D_global_size,
            C_lambda=C_lambda,
            cls_alpha=cls_alpha,
            latent_type=latent_type)  # Options for generator loss.
        D_loss = EasyDict(
            func_name='training.loss_vid.D_logistic_r1_info_gan_vid',
            D_global_size=D_global_size,
            latent_type=latent_type)  # Options for discriminator loss.
        I_loss = EasyDict(func_name='training.loss_vid.I_vid',
                          D_global_size=D_global_size,
                          latent_type=latent_type,
                          C_lambda=C_lambda,
                          MI_lambda=MI_lambda)  # Options for estimator loss.
    elif model_type == 'vid_naive_cluster_model':
        G_loss = EasyDict(
            func_name='training.loss_vid.G_logistic_ns_vid_naive_cluster',
            D_global_size=D_global_size,
            C_lambda=C_lambda,
            latent_type=latent_type)  # Options for generator loss.
        D_loss = EasyDict(
            func_name='training.loss_vid.D_logistic_r1_vid',
            D_global_size=D_global_size,
            latent_type=latent_type)  # Options for discriminator loss.
        I_loss = EasyDict()  # Options for estimator loss.
        I_opt = EasyDict()
    elif model_type == 'vid_blurry_model':
        G_loss = EasyDict(
            func_name='training.loss_vid.G_logistic_ns_vid_naive_cluster',
            D_global_size=D_global_size,
            C_lambda=C_lambda,
            latent_type=latent_type)  # Options for generator loss.
        D_loss = EasyDict(
            func_name='training.loss_vid.D_logistic_r1_vid',
            D_global_size=D_global_size,
            latent_type=latent_type)  # Options for discriminator loss.
        I_loss = EasyDict(func_name='training.loss_vid.I_vid_blurry',
                          D_global_size=D_global_size,
                          latent_type=latent_type,
                          C_lambda=C_lambda,
                          MI_lambda=MI_lambda,
                          phi=phi_blurry)  # Options for estimator loss.
    else:
        raise ValueError('Unsupported loss type: ' + model_type)

    sched = EasyDict()  # Options for TrainingSchedule.
    grid = EasyDict(
        size='1080p',
        layout='random')  # Options for setup_snapshot_image_grid().
    sc = dnnlib.SubmitConfig()  # Options for dnnlib.submit_run().
    tf_config = {'rnd.np_random_seed': 1000}  # Options for tflib.init_tf().

    train.data_dir = data_dir
    train.total_kimg = total_kimg
    train.mirror_augment = mirror_augment
    train.image_snapshot_ticks = train.network_snapshot_ticks = 10
    sched.G_lrate_base = sched.D_lrate_base = sched.I_lrate_base = 0.002
    sched.minibatch_size_base = 16
    sched.minibatch_gpu_base = 8
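    # Global batch 16 at 8 images per GPU: a single GPU accumulates gradients
    # over two rounds; two GPUs cover the batch in one round.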
    D_loss.gamma = 10
    metrics = [metric_defaults[x] for x in metrics]

    desc += '-' + dataset
    dataset_args = EasyDict(tfrecord_dir=dataset, max_label_size='full')

    assert num_gpus in [1, 2, 4, 8]
    sc.num_gpus = num_gpus
    desc += '-%dgpu' % num_gpus

    assert config_id in _valid_configs
    desc += '-' + config_id

    # Configs A-E: Shrink networks to match original StyleGAN.
    if config_id != 'config-f':
        # I.fmap_base = G.fmap_base = D.fmap_base = 8 << 10
        I.fmap_base = G.fmap_base = D.fmap_base = 2 << 8

    # Config E: Set gamma to 100 and override G & D architecture.
    if config_id.startswith('config-e'):
        D_loss.gamma = 100
        if 'Gorig' in config_id: G.architecture = 'orig'
        if 'Gskip' in config_id: G.architecture = 'skip'  # (default)
        if 'Gresnet' in config_id: G.architecture = 'resnet'
        if 'Dorig' in config_id: D.architecture = 'orig'
        if 'Dskip' in config_id: D.architecture = 'skip'
        if 'Dresnet' in config_id: D.architecture = 'resnet'  # (default)

    if gamma is not None:
        D_loss.gamma = gamma

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(
        G_args=G,
        D_args=D,
        I_args=I,
        I_info_args=I_info,
        G_opt_args=G_opt,
        D_opt_args=D_opt,
        I_opt_args=I_opt,
        G_loss_args=G_loss,
        D_loss_args=D_loss,
        I_loss_args=I_loss,
        use_vid_head=(model_type == 'vid_model'),
        use_vid_head_with_cls=(model_type == 'vid_with_cls'),
        use_vid_naive_cluster=(model_type == 'vid_naive_cluster_model'),
        use_vid_blurry=(model_type == 'vid_blurry_model'),
        traversal_grid=True)
    n_continuous = 0
    for i, key in enumerate(key_ls):
        m_name = key.split('-')[0]
        if (m_name in LATENT_MODULES) and (m_name != 'D_global'):
            n_continuous += size_ls[i]

    kwargs.update(dataset_args=dataset_args,
                  sched_args=sched,
                  grid_args=grid,
                  metric_arg_list=metrics,
                  tf_config=tf_config,
                  resume_pkl=resume_pkl,
                  n_discrete=D_global_size,
                  n_continuous=n_continuous,
                  n_samples_per=n_samples_per,
                  C_lambda=C_lambda,
                  MI_lambda=MI_lambda)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = result_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
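
# _str_to_list and split_module_names are defined elsewhere in this codebase
# and are not shown in this excerpt. A minimal sketch of the assumed parsing,
# based only on how the results are used above (keys like 'D_global-...' whose
# prefix is taken via key.split('-')[0]); the real helpers may differ:
def _str_to_list_sketch(s):
    # '[D_global-3,C_global-10]' -> ['D_global-3', 'C_global-10']
    return [x.strip() for x in s.strip('[]').split(',') if x.strip()]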
Example #12
def run(dataset, data_dir, result_dir, config_id, num_gpus, total_kimg, gamma,
        mirror_augment, metrics):
    train = EasyDict(run_func_name='training.training_loop.training_loop'
                     )  # Options for training loop.
    G = EasyDict(func_name='training.networks_stylegan2.G_main'
                 )  # Options for generator network.
    D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2'
                 )  # Options for discriminator network.
    G_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for generator optimizer.
    D_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for discriminator optimizer.
    G_loss = EasyDict(func_name='training.loss.G_logistic_ns_pathreg'
                      )  # Options for generator loss.
    D_loss = EasyDict(func_name='training.loss.D_logistic_r1'
                      )  # Options for discriminator loss.
    sched = EasyDict()  # Options for TrainingSchedule.
    grid = EasyDict(
        size='8k', layout='random')  # Options for setup_snapshot_image_grid().
    sc = dnnlib.SubmitConfig()  # Options for dnnlib.submit_run().
    tf_config = {'rnd.np_random_seed': 1000}  # Options for tflib.init_tf().

    train.data_dir = data_dir
    train.total_kimg = total_kimg
    train.mirror_augment = mirror_augment
    train.image_snapshot_ticks = train.network_snapshot_ticks = 10
    sched.G_lrate_base = sched.D_lrate_base = 0.002
    sched.minibatch_size_base = 32
    sched.minibatch_gpu_base = 4
    D_loss.gamma = 10
    metrics = [metric_defaults[x] for x in metrics]
    desc = 'stylegan2'

    desc += '-' + dataset
    dataset_args = EasyDict(tfrecord_dir=dataset)

    assert num_gpus in [1, 2, 4, 8]
    sc.num_gpus = num_gpus
    desc += '-%dgpu' % num_gpus

    assert config_id in _valid_configs
    desc += '-' + config_id

    # Configs A-E: Shrink networks to match original StyleGAN.
    if config_id not in ['config-f', 'config-l']:
        G.fmap_base = D.fmap_base = 8 << 10

    # Config L: Generator training only
    if config_id == 'config-l':
        # Use labels as latent vector input
        dataset_args.max_label_size = "full"
        # Deactivate methods specific for GAN training
        G.truncation_psi = None
        G.randomize_noise = False
        G.style_mixing_prob = None
        G.dlatent_avg_beta = None
        G.conditional_labels = False
        # Refinement training
        G_loss.func_name = 'training.loss.G_reconstruction'
        train.run_func_name = 'training.training_loop.training_loop_refinement'
        # G.freeze_layers = ["mapping", "noise"]#, "4x4", "8x8", "16x16", "32x32"]
        # Network for refinement
        train.resume_pkl = "nets/stylegan2-ffhq-config-f.pkl"  # TODO init net
        train.resume_with_new_nets = True
        # Maintenance tasks
        sched.tick_kimg_base = 1  # 1 tick = 1000 images (metric update)
        sched.tick_kimg_dict = {}
        train.image_snapshot_ticks = 5  # Save every 5000 images
        train.network_snapshot_ticks = 10  # Save every 10000 images
        # Training parameters
        sched.G_lrate_base = 1e-4
        train.G_smoothing_kimg = 0.0
        sched.minibatch_size_base = sched.minibatch_gpu_base * num_gpus  # 4 per GPU
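        # e.g. with the minibatch_gpu_base of 4 set above, a 2-GPU run gets a
        # global batch of 8.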

    # Config E: Set gamma to 100 and override G & D architecture.
    if config_id.startswith('config-e'):
        D_loss.gamma = 100
        if 'Gorig' in config_id: G.architecture = 'orig'
        if 'Gskip' in config_id: G.architecture = 'skip'  # (default)
        if 'Gresnet' in config_id: G.architecture = 'resnet'
        if 'Dorig' in config_id: D.architecture = 'orig'
        if 'Dskip' in config_id: D.architecture = 'skip'
        if 'Dresnet' in config_id: D.architecture = 'resnet'  # (default)

    # Configs A-D: Enable progressive growing and switch to networks that support it.
    if config_id in ['config-a', 'config-b', 'config-c', 'config-d']:
        sched.lod_initial_resolution = 8
        sched.G_lrate_base = sched.D_lrate_base = 0.001
        sched.G_lrate_dict = sched.D_lrate_dict = {
            128: 0.0015,
            256: 0.002,
            512: 0.003,
            1024: 0.003
        }
        sched.minibatch_size_base = 32  # (default)
        sched.minibatch_size_dict = {8: 256, 16: 128, 32: 64, 64: 32}
        sched.minibatch_gpu_base = 4  # (default)
        sched.minibatch_gpu_dict = {8: 32, 16: 16, 32: 8, 64: 4}
        G.synthesis_func = 'G_synthesis_stylegan_revised'
        D.func_name = 'training.networks_stylegan2.D_stylegan'

    # Configs A-C: Disable path length regularization.
    if config_id in ['config-a', 'config-b', 'config-c']:
        G_loss = EasyDict(func_name='training.loss.G_logistic_ns')

    # Configs A-B: Disable lazy regularization.
    if config_id in ['config-a', 'config-b']:
        train.lazy_regularization = False

    # Config A: Switch to original StyleGAN networks.
    if config_id == 'config-a':
        G = EasyDict(func_name='training.networks_stylegan.G_style')
        D = EasyDict(func_name='training.networks_stylegan.D_basic')

    if gamma is not None:
        D_loss.gamma = gamma

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(G_args=G,
                  D_args=D,
                  G_opt_args=G_opt,
                  D_opt_args=D_opt,
                  G_loss_args=G_loss,
                  D_loss_args=D_loss)
    kwargs.update(dataset_args=dataset_args,
                  sched_args=sched,
                  grid_args=grid,
                  metric_arg_list=metrics,
                  tf_config=tf_config)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = result_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
Example #13
def run(
    dataset,
    data_dir,
    result_dir,
    config_id,
    num_gpus,
    total_kimg,
    gamma,
    mirror_augment,
    metrics,
    resume_pkl=None,
    resume_kimg=None,
):
    train = EasyDict(
        run_func_name="training.training_loop.training_loop",
        # training resume options:
        resume_pkl=resume_pkl,  # Network pickle to resume training from, None = train from scratch.
        resume_kimg=resume_kimg,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    )  # Options for training loop.
    G = EasyDict(func_name="training.networks_stylegan2.G_main"
                 )  # Options for generator network.
    D = EasyDict(func_name="training.networks_stylegan2.D_stylegan2"
                 )  # Options for discriminator network.
    G_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for generator optimizer.
    D_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for discriminator optimizer.
    G_loss = EasyDict(func_name="training.loss.G_logistic_ns_pathreg"
                      )  # Options for generator loss.
    D_loss = EasyDict(func_name="training.loss.D_logistic_r1"
                      )  # Options for discriminator loss.
    sched = EasyDict()  # Options for TrainingSchedule.
    grid = EasyDict(
        size="8k", layout="random")  # Options for setup_snapshot_image_grid().
    sc = dnnlib.SubmitConfig()  # Options for dnnlib.submit_run().
    tf_config = {"rnd.np_random_seed": 1000}  # Options for tflib.init_tf().

    train.data_dir = data_dir
    train.total_kimg = total_kimg
    train.mirror_augment = mirror_augment
    train.image_snapshot_ticks = train.network_snapshot_ticks = 10
    sched.G_lrate_base = sched.D_lrate_base = 0.002
    sched.minibatch_size_base = 32
    sched.minibatch_gpu_base = 4
    D_loss.gamma = 10
    metrics = [metric_defaults[x] for x in metrics]
    desc = "stylegan2"

    desc += "-" + dataset
    dataset_args = EasyDict(tfrecord_dir=dataset)

    assert num_gpus in [1, 2, 4, 8]
    sc.num_gpus = num_gpus
    desc += "-%dgpu" % num_gpus

    assert config_id in _valid_configs
    desc += "-" + config_id

    # Configs A-E: Shrink networks to match original StyleGAN.
    if config_id != "config-f":
        G.fmap_base = D.fmap_base = 8 << 10

    # Config E: Set gamma to 100 and override G & D architecture.
    if config_id.startswith("config-e"):
        D_loss.gamma = 100
        if "Gorig" in config_id:
            G.architecture = "orig"
        if "Gskip" in config_id:
            G.architecture = "skip"  # (default)
        if "Gresnet" in config_id:
            G.architecture = "resnet"
        if "Dorig" in config_id:
            D.architecture = "orig"
        if "Dskip" in config_id:
            D.architecture = "skip"
        if "Dresnet" in config_id:
            D.architecture = "resnet"  # (default)

    # Configs A-D: Enable progressive growing and switch to networks that support it.
    if config_id in ["config-a", "config-b", "config-c", "config-d"]:
        sched.lod_initial_resolution = 8
        sched.G_lrate_base = sched.D_lrate_base = 0.001
        sched.G_lrate_dict = sched.D_lrate_dict = {
            128: 0.0015,
            256: 0.002,
            512: 0.003,
            1024: 0.003,
        }
        sched.minibatch_size_base = 32  # (default)
        sched.minibatch_size_dict = {8: 256, 16: 128, 32: 64, 64: 32}
        sched.minibatch_gpu_base = 4  # (default)
        sched.minibatch_gpu_dict = {8: 32, 16: 16, 32: 8, 64: 4}
        G.synthesis_func = "G_synthesis_stylegan_revised"
        D.func_name = "training.networks_stylegan2.D_stylegan"

    # Configs A-C: Disable path length regularization.
    if config_id in ["config-a", "config-b", "config-c"]:
        G_loss = EasyDict(func_name="training.loss.G_logistic_ns")

    # Configs A-B: Disable lazy regularization.
    if config_id in ["config-a", "config-b"]:
        train.lazy_regularization = False

    # Config A: Switch to original StyleGAN networks.
    if config_id == "config-a":
        G = EasyDict(func_name="training.networks_stylegan.G_style")
        D = EasyDict(func_name="training.networks_stylegan.D_basic")

    if gamma is not None:
        D_loss.gamma = gamma

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(
        G_args=G,
        D_args=D,
        G_opt_args=G_opt,
        D_opt_args=D_opt,
        G_loss_args=G_loss,
        D_loss_args=D_loss,
    )
    kwargs.update(
        dataset_args=dataset_args,
        sched_args=sched,
        grid_args=grid,
        metric_arg_list=metrics,
        tf_config=tf_config,
    )
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = result_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
Example #14
def run(dataset, data_dir, result_dir, config_id, num_gpus, total_kimg, gamma,
        mirror_augment, metrics, dlatent_size, lr, batch_size, decay_step,
        decay_rate, stair, tick_kimg):
    train = EasyDict(run_func_name='training.vae_training_loop.training_loop'
                     )  # Options for training loop.
    G = EasyDict(func_name='training.vae_dcgan.Decoder_main'
                 )  # Options for generator network.
    D = EasyDict(func_name='training.vae_dcgan.Encoder'
                 )  # Options for discriminator network.
    G_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for generator optimizer.
    D_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for discriminator optimizer.
    G_loss = EasyDict(func_name='training.loss.G_logistic_ns_pathreg'
                      )  # Options for generator loss.
    D_loss = EasyDict(
        func_name='training.loss.vae_loss')  # Options for discriminator loss.
    sched = EasyDict()  # Options for TrainingSchedule.
    grid = EasyDict(
        size='8k', layout='random')  # Options for setup_snapshot_image_grid().
    sc = dnnlib.SubmitConfig()  # Options for dnnlib.submit_run().
    tf_config = {'rnd.np_random_seed': 1000}  # Options for tflib.init_tf().

    train.data_dir = data_dir
    train.total_kimg = total_kimg
    train.mirror_augment = mirror_augment
    train.image_snapshot_ticks = train.network_snapshot_ticks = 10

    sched.batch_size = batch_size
    sched.lr = lr
    sched.decay_step = decay_step
    sched.decay_rate = decay_rate
    sched.stair = stair
    sched.tick_kimg = tick_kimg
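    # These step-decay fields (batch_size, lr, decay_step, decay_rate, stair,
    # tick_kimg) replace the G/D lrate schedule used by the GAN examples and
    # are presumably consumed by training.vae_training_loop.training_loop.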

    D_loss.gamma = 10
    metrics = [metric_defaults[x] for x in metrics]
    desc = 'vae_dcgan'

    G.dlatent_size = dlatent_size
    D.dlatent_size = dlatent_size
    G.num_units = D.num_units = 1024
    G.act = D.act = 'relu'

    desc += '-' + dataset
    dataset_args = EasyDict(tfrecord_dir=dataset)

    assert num_gpus in [1, 2, 4, 8]
    sc.num_gpus = num_gpus
    desc += '-%dgpu' % num_gpus

    assert config_id in _valid_configs
    desc += '-' + config_id

    if gamma is not None:
        D_loss.gamma = gamma

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(G_args=G,
                  D_args=D,
                  G_opt_args=G_opt,
                  D_opt_args=D_opt,
                  G_loss_args=G_loss,
                  D_loss_args=D_loss)
    kwargs.update(dataset_args=dataset_args,
                  sched_args=sched,
                  grid_args=grid,
                  metric_arg_list=metrics,
                  tf_config=tf_config)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = result_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
Example #15
def create_model(config_id='config-f',
                 gamma=None,
                 height=512,
                 width=512,
                 cond=None,
                 label_size=0):
    train = EasyDict(run_func_name='training.diagnostic.create_initial_pkl'
                     )  # Options for training loop.
    G = EasyDict(func_name='training.networks_stylegan2.G_main'
                 )  # Options for generator network.
    D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2'
                 )  # Options for discriminator network.
    D_loss = EasyDict(func_name='training.loss.D_logistic_r1'
                      )  # Options for discriminator loss.
    sched = EasyDict()  # Options for TrainingSchedule.
    sc = dnnlib.SubmitConfig()  # Options for dnnlib.submit_run().
    tf_config = {'rnd.np_random_seed': 1000}  # Options for tflib.init_tf().

    sched.minibatch_size_base = 192
    sched.minibatch_gpu_base = 3
    D_loss.gamma = 10
    desc = 'stylegan2'

    dataset_args = EasyDict()  # (tfrecord_dir=dataset)

    if cond:
        desc += '-cond'
        dataset_args.max_label_size = 'full'  # conditioned on full label

    desc += '-' + config_id

    # Configs A-E: Shrink networks to match original StyleGAN.
    if config_id != 'config-f':
        G.fmap_base = D.fmap_base = 8 << 10

    # Config E: Set gamma to 100 and override G & D architecture.
    if config_id.startswith('config-e'):
        D_loss.gamma = 100
        if 'Gorig' in config_id: G.architecture = 'orig'
        if 'Gskip' in config_id: G.architecture = 'skip'  # (default)
        if 'Gresnet' in config_id: G.architecture = 'resnet'
        if 'Dorig' in config_id: D.architecture = 'orig'
        if 'Dskip' in config_id: D.architecture = 'skip'
        if 'Dresnet' in config_id: D.architecture = 'resnet'  # (default)

    # Configs A-D: Enable progressive growing and switch to networks that support it.
    if config_id in ['config-a', 'config-b', 'config-c', 'config-d']:
        sched.lod_initial_resolution = 8
        sched.G_lrate_base = sched.D_lrate_base = 0.001
        sched.G_lrate_dict = sched.D_lrate_dict = {
            128: 0.0015,
            256: 0.002,
            512: 0.003,
            1024: 0.003
        }
        sched.minibatch_size_base = 32  # (default)
        sched.minibatch_size_dict = {8: 256, 16: 128, 32: 64, 64: 32}
        sched.minibatch_gpu_base = 4  # (default)
        sched.minibatch_gpu_dict = {8: 32, 16: 16, 32: 8, 64: 4}
        G.synthesis_func = 'G_synthesis_stylegan_revised'
        D.func_name = 'training.networks_stylegan2.D_stylegan'

    # Configs A-C: Disable path length regularization.
    if config_id in ['config-a', 'config-b', 'config-c']:
        G_loss = EasyDict(func_name='training.loss.G_logistic_ns')

    # Configs A-B: Disable lazy regularization.
    if config_id in ['config-a', 'config-b']:
        train.lazy_regularization = False

    # Config A: Switch to original StyleGAN networks.
    if config_id == 'config-a':
        G = EasyDict(func_name='training.networks_stylegan.G_style')
        D = EasyDict(func_name='training.networks_stylegan.D_basic')

    if gamma is not None:
        D_loss.gamma = gamma

    G.update(resolution_h=height)
    G.update(resolution_w=width)
    D.update(resolution_h=height)
    D.update(resolution_w=width)

    sc.submit_target = dnnlib.SubmitTarget.DIAGNOSTIC
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(G_args=G,
                  D_args=D,
                  tf_config=tf_config,
                  config_id=config_id,
                  resolution_h=height,
                  resolution_w=width,
                  label_size=label_size)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_diagnostic(**kwargs)
    return f'network-initial-config-f-{height}x{width}-{label_size}.pkl'
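# Note that the returned filename hardcodes 'config-f' regardless of the
# config_id argument. Illustrative call (placeholder values):
#   create_model(config_id='config-f', height=512, width=512, cond=True, label_size=10)
#   -> 'network-initial-config-f-512x512-10.pkl'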
Example #16
def run(dataset,
        data_dir,
        result_dir,
        config_id,
        num_gpus,
        total_kimg,
        gamma,
        mirror_augment,
        metrics,
        resume_G_pkl=None,
        n_batch=2,
        n_batch_per_gpu=1,
        D_global_size=0,
        C_global_size=10,
        model_type='hd_dis_model',
        latent_type='uniform',
        resume_pkl=None,
        n_samples_per=4,
        D_lambda=0,
        C_lambda=1,
        epsilon_in_loss=3,
        random_eps=True,
        M_lrmul=0.1,
        resolution_manual=1024,
        pretrained_type='with_stylegan2',
        traj_lambda=None,
        level_I_kimg=1000,
        use_level_training=False,
        resume_kimg=0,
        use_std_in_m=False,
        prior_latent_size=512,
        stylegan2_dlatent_size=512,
        stylegan2_mapping_fmaps=512,
        M_mapping_fmaps=512,
        hyperplane_lambda=1,
        hyperdir_lambda=1):
    train = EasyDict(
        run_func_name='training.training_loop_hdwG.training_loop_hdwG')
    G = EasyDict(func_name='training.hd_networks_stylegan2.G_main',
                 latent_size=prior_latent_size,
                 dlatent_size=stylegan2_dlatent_size,
                 mapping_fmaps=stylegan2_mapping_fmaps,
                 mapping_lrmul=M_lrmul,
                 style_mixing_prob=None,
                 dlatent_avg_beta=None,
                 truncation_psi=None,
                 normalize_latents=False,
                 structure='fixed')
    D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2')
    if model_type == 'hd_hyperplane':
        M = EasyDict(func_name='training.hd_networks.net_M_hyperplane',
                     C_global_size=C_global_size,
                     D_global_size=D_global_size,
                     latent_size=prior_latent_size,
                     mapping_lrmul=M_lrmul,
                     use_std_in_m=use_std_in_m)
        I = EasyDict(func_name='training.hd_networks.net_I',
                     C_global_size=C_global_size,
                     D_global_size=D_global_size)
    elif model_type == 'vc_gan_preprior':
        M = EasyDict(func_name='training.hd_networks.net_M_vc',
                     C_global_size=C_global_size,
                     D_global_size=D_global_size,
                     latent_size=prior_latent_size,
                     mapping_lrmul=M_lrmul,
                     use_std_in_m=use_std_in_m)
        I = EasyDict(func_name='training.hd_networks.net_I',
                     C_global_size=C_global_size,
                     D_global_size=D_global_size)
    elif model_type == 'vc_gan':
        M = EasyDict(func_name='training.hd_networks.net_M_empty',
                     C_global_size=C_global_size,
                     D_global_size=D_global_size,
                     latent_size=prior_latent_size,
                     mapping_lrmul=M_lrmul,
                     use_std_in_m=use_std_in_m)
        I = EasyDict(func_name='training.hd_networks.net_I',
                     C_global_size=C_global_size,
                     D_global_size=D_global_size)
        G.mapping_func = 'G_mapping_hd_dis_to_dlatent'
    else:
        M = EasyDict(func_name='training.hd_networks.net_M',
                     C_global_size=C_global_size,
                     D_global_size=D_global_size,
                     latent_size=prior_latent_size,
                     mapping_fmaps=M_mapping_fmaps,
                     mapping_lrmul=M_lrmul,
                     use_std_in_m=use_std_in_m)
        I = EasyDict(func_name='training.hd_networks.net_I',
                     C_global_size=C_global_size,
                     D_global_size=D_global_size)
    if model_type == 'hd_dis_model_with_cls':
        I_info = EasyDict(func_name='training.hd_networks.net_I_info',
                          C_global_size=C_global_size,
                          D_global_size=D_global_size)
    else:
        I_info = EasyDict()
    I_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)
    D_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)
    if model_type == 'vc_gan':
        I_loss = EasyDict(func_name='training.loss_hdwG.IandG_vc_loss',
                          latent_type=latent_type,
                          D_global_size=D_global_size,
                          C_global_size=C_global_size,
                          D_lambda=D_lambda,
                          C_lambda=C_lambda,
                          epsilon=epsilon_in_loss,
                          random_eps=random_eps,
                          traj_lambda=traj_lambda,
                          resolution_manual=resolution_manual,
                          use_std_in_m=use_std_in_m,
                          model_type=model_type,
                          hyperplane_lambda=hyperplane_lambda,
                          prior_latent_size=prior_latent_size,
                          hyperdir_lambda=hyperdir_lambda)
    else:
        I_loss = EasyDict(
            func_name='training.loss_hdwG.IandMandG_hyperplane_loss',
            latent_type=latent_type,
            D_global_size=D_global_size,
            C_global_size=C_global_size,
            D_lambda=D_lambda,
            C_lambda=C_lambda,
            epsilon=epsilon_in_loss,
            random_eps=random_eps,
            traj_lambda=traj_lambda,
            resolution_manual=resolution_manual,
            use_std_in_m=use_std_in_m,
            model_type=model_type,
            hyperplane_lambda=hyperplane_lambda,
            prior_latent_size=prior_latent_size,
            hyperdir_lambda=hyperdir_lambda)
    D_loss = EasyDict(func_name='training.loss.D_logistic_r1')
    sched = EasyDict()
    grid = EasyDict(size='1080p', layout='random')
    sc = dnnlib.SubmitConfig()
    tf_config = {'rnd.np_random_seed': 1000}

    train.data_dir = data_dir
    train.total_kimg = total_kimg
    train.mirror_augment = mirror_augment
    train.image_snapshot_ticks = train.network_snapshot_ticks = 10
    sched.G_lrate_base = sched.D_lrate_base = 0.002
    sched.minibatch_size_base = n_batch
    sched.minibatch_gpu_base = n_batch_per_gpu
    D_loss.gamma = 10
    metrics = [metric_defaults[x] for x in metrics]
    desc = 'hdwG_disentanglement'

    desc += '-' + dataset
    dataset_args = EasyDict(tfrecord_dir=dataset)

    assert num_gpus in [1, 2, 4, 8]
    sc.num_gpus = num_gpus
    desc += '-%dgpu' % num_gpus

    assert config_id in _valid_configs
    desc += '-' + config_id

    # Configs A-E: Shrink networks to match original StyleGAN.
    if config_id != 'config-f':
        # G.fmap_base = D.fmap_base = 8 << 10
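        # Shift arithmetic: 2 << 8 == 512, 2 << 10 == 2048, 8 << 10 == 8192.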
        if resolution_manual <= 256:
            I.fmap_base = 2 << 8
            G.fmap_base = 2 << 10
            D.fmap_base = 2 << 8
        else:
            I.fmap_base = 8 << 10
            G.fmap_base = D.fmap_base = 8 << 10

    # Config E: Set gamma to 100 and override G & D architecture.
    if config_id.startswith('config-e'):
        D_loss.gamma = 100
        if 'Gorig' in config_id: G.architecture = 'orig'
        if 'Gskip' in config_id: G.architecture = 'skip'  # (default)
        if 'Gresnet' in config_id: G.architecture = 'resnet'
        if 'Dorig' in config_id: D.architecture = 'orig'
        if 'Dskip' in config_id: D.architecture = 'skip'
        if 'Dresnet' in config_id: D.architecture = 'resnet'  # (default)

    # Configs A-D: Enable progressive growing and switch to networks that support it.
    if config_id in ['config-a', 'config-b', 'config-c', 'config-d']:
        # sched.lod_initial_resolution = 8
        sched.G_lrate_base = sched.D_lrate_base = 0.002
        # sched.G_lrate_dict = sched.D_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
        sched.minibatch_size_base = n_batch  # (default)
        # sched.minibatch_size_dict = {8: 256, 16: 128, 32: 64, 64: 32}
        sched.minibatch_gpu_base = n_batch_per_gpu  # (default)
        # sched.minibatch_gpu_dict = {8: 32, 16: 16, 32: 8, 64: 4}
        # G.synthesis_func = 'hd_networks_stylegan2.G_synthesis_stylegan_revised'
        G.synthesis_func = 'G_synthesis_stylegan_revised_hd'
        # D.func_name = 'training.networks_stylegan2.D_stylegan'

    # Configs A-B: Disable lazy regularization.
    if config_id in ['config-a', 'config-b']:
        train.lazy_regularization = False

    # Config A: Switch to original StyleGAN networks.
    if config_id == 'config-a':
        G = EasyDict(func_name='training.networks_stylegan.G_style')
        D = EasyDict(func_name='training.networks_stylegan.D_basic')

    if gamma is not None:
        D_loss.gamma = gamma

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(G_args=G,
                  D_args=D,
                  I_args=I,
                  M_args=M,
                  I_opt_args=I_opt,
                  D_opt_args=D_opt,
                  I_loss_args=I_loss,
                  D_loss_args=D_loss,
                  resume_pkl=resume_pkl,
                  resume_G_pkl=resume_G_pkl)
    kwargs.update(dataset_args=dataset_args,
                  sched_args=sched,
                  grid_args=grid,
                  use_hd_with_cls=(model_type == 'hd_dis_model_with_cls'),
                  use_hyperplane=(model_type == 'hd_hyperplane'),
                  metric_arg_list=metrics,
                  tf_config=tf_config,
                  n_discrete=D_global_size,
                  n_continuous=C_global_size,
                  n_samples_per=n_samples_per,
                  resolution_manual=resolution_manual,
                  pretrained_type=pretrained_type,
                  level_I_kimg=level_I_kimg,
                  use_level_training=use_level_training,
                  resume_kimg=resume_kimg,
                  use_std_in_m=use_std_in_m,
                  prior_latent_size=prior_latent_size,
                  latent_type=latent_type)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = result_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
Example #17
def run(dataset, data_dir, result_dir, config_id, num_gpus, total_kimg, gamma, mirror_augment, metrics):
    train     = EasyDict(run_func_name='training.training_loop.training_loop_mirror_v6_remove_half_fl_fr.training_loop')
    G         = EasyDict(func_name='training.networks.networks_stylegan2.G_main')
    D         = EasyDict(func_name='training.networks.networks_stylegan2_discriminator_new_rotation.D_stylegan2_new_rotaion')  # Options for discriminator network.
    G_opt     = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)
    D_opt     = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)
    G_loss    = EasyDict(func_name='training.loss.loss_G_new_rotation_squared_euclidean_10_interpolate_50_percent_uniform_dist_int_penalty.G_logistic_ns_pathreg')
    D_loss    = EasyDict(func_name='training.loss.loss_D_logistic_r1_new_rotation_euclidean_square.D_logistic_r1_new_rotation')
    sched     = EasyDict()
    grid      = EasyDict(size='1080p', layout='random')
    sc        = dnnlib.SubmitConfig()
    tf_config = {'rnd.np_random_seed': 1000}

    train.data_dir = data_dir
    train.total_kimg = total_kimg
    train.mirror_augment = mirror_augment
    train.image_snapshot_ticks = train.network_snapshot_ticks = 10
    sched.G_lrate_base = sched.D_lrate_base = 0.002
    sched.minibatch_size_base = 32
    sched.minibatch_gpu_base = 4

    # train.resume_pkl = './results/00200-stylegan2-car_labels_v7_oversample_filter-2gpu-config-f-squared_euclidean_10_interpolate_50_percent_int_reg-256/network-snapshot-000887.pkl'
    # train.resume_kimg = 887.2

    D_loss.gamma = 10
    metrics = [metric_defaults[x] for x in metrics]
    desc = 'stylegan2'
    G.style_mixing_prob = None

    desc += '-' + dataset
    dataset_args = EasyDict(tfrecord_dir=dataset)

    assert num_gpus in [1, 2, 4, 8]
    sc.num_gpus = num_gpus
    desc += '-%dgpu' % num_gpus

    assert config_id in _valid_configs
    desc += '-' + config_id
    desc += '-squared_euclidean_10_interpolate_50_percent_int_reg_remove_half_fl_fr_no_noise_square'
    desc += '-256'

    # Configs A-E: Shrink networks to match original StyleGAN.
    if config_id != 'config-f':
        G.fmap_base = D.fmap_base = 8 << 10

    # Config E: Set gamma to 100 and override G & D architecture.
    if config_id.startswith('config-e'):
        D_loss.gamma = 100
        if 'Gorig'   in config_id: G.architecture = 'orig'
        if 'Gskip'   in config_id: G.architecture = 'skip' # (default)
        if 'Gresnet' in config_id: G.architecture = 'resnet'
        if 'Dorig'   in config_id: D.architecture = 'orig'
        if 'Dskip'   in config_id: D.architecture = 'skip'
        if 'Dresnet' in config_id: D.architecture = 'resnet' # (default)

    # Configs A-D: Enable progressive growing and switch to networks that support it.
    if config_id in ['config-a', 'config-b', 'config-c', 'config-d']:
        sched.lod_initial_resolution = 8
        sched.G_lrate_base = sched.D_lrate_base = 0.001
        sched.G_lrate_dict = sched.D_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
        sched.minibatch_size_base = 32 # (default)
        sched.minibatch_size_dict = {8: 256, 16: 128, 32: 64, 64: 32}
        sched.minibatch_gpu_base = 4 # (default)
        sched.minibatch_gpu_dict = {8: 32, 16: 16, 32: 8, 64: 4}
        G.synthesis_func = 'G_synthesis_stylegan_revised'
        D.func_name = 'training.networks_stylegan2.D_stylegan'

    # Configs A-C: Disable path length regularization.
    if config_id in ['config-a', 'config-b', 'config-c']:
        G_loss = EasyDict(func_name='training.loss.G_logistic_ns')

    # Configs A-B: Disable lazy regularization.
    if config_id in ['config-a', 'config-b']:
        train.lazy_regularization = False

    # Config A: Switch to original StyleGAN networks.
    if config_id == 'config-a':
        G = EasyDict(func_name='training.networks_stylegan.G_style')
        D = EasyDict(func_name='training.networks_stylegan.D_basic')

    if gamma is not None:
        D_loss.gamma = gamma

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(G_args=G, D_args=D, G_opt_args=G_opt, D_opt_args=D_opt, G_loss_args=G_loss, D_loss_args=D_loss)
    kwargs.update(dataset_args=dataset_args, sched_args=sched, grid_args=grid, metric_arg_list=metrics, tf_config=tf_config)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = result_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
Example #18
def run(data, train_dir, config, d_aug, diffaug_policy, cond, ops, mirror, mirror_v, \
        kimg, batch_size, lrate, resume, resume_kimg, num_gpus, ema_kimg, gamma, freezeD):

    # training functions
    if d_aug:  # https://github.com/mit-han-lab/data-efficient-gans
        train = EasyDict(
            run_func_name='training.training_loop_diffaug.training_loop'
        )  # Options for training loop (Diff Augment method)
        loss_args = EasyDict(
            func_name='training.loss_diffaug.ns_DiffAugment_r1',
            policy=diffaug_policy)  # Options for loss (Diff Augment method)
    else:  # original nvidia
        train = EasyDict(run_func_name='training.training_loop.training_loop'
                         )  # Options for training loop (original from NVidia)
        G_loss = EasyDict(func_name='training.loss.G_logistic_ns_pathreg'
                          )  # Options for generator loss.
        D_loss = EasyDict(func_name='training.loss.D_logistic_r1'
                          )  # Options for discriminator loss.

    # network functions
    G = EasyDict(func_name='training.networks_stylegan2.G_main'
                 )  # Options for generator network.
    D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2'
                 )  # Options for discriminator network.
    G_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for generator optimizer.
    D_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for discriminator optimizer.
    sched = EasyDict()  # Options for TrainingSchedule.
    grid = EasyDict(
        size='1080p',
        layout='random')  # Options for setup_snapshot_image_grid().
    sc = dnnlib.SubmitConfig()  # Options for dnnlib.submit_run().
    tf_config = {'rnd.np_random_seed': 1000}  # Options for tflib.init_tf().
    G.impl = D.impl = ops

    # dataset (tfrecords) - get or create
    tfr_files = file_list(os.path.dirname(data), 'tfr')
    tfr_files = [
        f for f in tfr_files if basename(data) == basename(f).split('-')[0]
    ]
    if len(tfr_files) == 0 or os.stat(tfr_files[0]).st_size == 0:
        tfr_file, total_samples = create_from_image_folders(
            data) if cond is True else create_from_images(data)
    else:
        tfr_file = tfr_files[0]
    dataset_args = EasyDict(tfrecord=tfr_file)

    # resolutions
    with tf.Graph().as_default(), tflib.create_session().as_default():  # pylint: disable=not-context-manager
        dataset_obj = dataset.load_dataset(
            **dataset_args)  # loading the data to see what comes out
        resolution = dataset_obj.resolution
        init_res = dataset_obj.init_res
        res_log2 = dataset_obj.res_log2
        dataset_obj.close()
        dataset_obj = None

    if list(init_res) == [4, 4]:
        desc = '%s-%d' % (basename(data), resolution)
    else:
        print(' custom init resolution', init_res)
        desc = basename(tfr_file)
    G.init_res = D.init_res = list(init_res)

    train.savenames = [desc.replace(basename(data), 'snapshot'), desc]
    desc += '-%s' % config

    # training schedule
    train.total_kimg = kimg
    train.image_snapshot_ticks = 1 * num_gpus if kimg <= 1000 else 4 * num_gpus
    train.network_snapshot_ticks = 5
    train.mirror_augment = mirror
    train.mirror_augment_v = mirror_v
    sched.tick_kimg_base = 2 if train.total_kimg < 2000 else 4

    # learning rate
    if config == 'e':
        sched.G_lrate_base = 0.001
        sched.G_lrate_dict = {0: 0.001, 1: 0.0007, 2: 0.0005, 3: 0.0003}
        sched.lrate_step = 1500  # period for stepping to next lrate, in kimg
    if config == 'f':
        sched.G_lrate_base = lrate  # 0.001 for big datasets, 0.0003 for few-shot
    sched.D_lrate_base = sched.G_lrate_base  # *2 - not used anyway

    # batch size (for 16gb memory GPU)
    sched.minibatch_gpu_base = 4096 // resolution if batch_size is None else batch_size
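    # Worked example: at 1024x1024 resolution this gives 4096 // 1024 = 4
    # images per GPU; at 256x256, 4096 // 256 = 16.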
    print(' Batch size', sched.minibatch_gpu_base)
    sched.minibatch_size_base = num_gpus * sched.minibatch_gpu_base
    sc.num_gpus = num_gpus

    if config == 'e':
        G.fmap_base = D.fmap_base = 8 << 10
        if d_aug: loss_args.gamma = 100 if gamma is None else gamma
        else: D_loss.gamma = 100 if gamma is None else gamma
    elif config == 'f':
        G.fmap_base = D.fmap_base = 16 << 10
    else:
        print(' Only configs E and F are implemented')
        exit()

    if cond:
        desc += '-cond'
        dataset_args.max_label_size = 'full'  # conditioned on full label

    if freezeD:
        D.freezeD = True
        train.resume_with_new_nets = True

    if d_aug:
        desc += '-daug'

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(G_args=G, D_args=D, G_opt_args=G_opt, D_opt_args=D_opt)
    kwargs.update(dataset_args=dataset_args,
                  sched_args=sched,
                  grid_args=grid,
                  tf_config=tf_config)
    kwargs.update(resume_pkl=resume,
                  resume_kimg=resume_kimg,
                  resume_with_new_nets=True)
    if ema_kimg is not None:
        kwargs.update(G_ema_kimg=ema_kimg)
    if d_aug:
        kwargs.update(loss_args=loss_args)
    else:
        kwargs.update(G_loss_args=G_loss, D_loss_args=D_loss)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = train_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
Example #19
def run(dataset, data_dir, result_dir, config_id, num_gpus, total_kimg, gamma,
        mirror_augment, metrics):
    train = EasyDict(run_func_name='training.training_loop.training_loop'
                     )  # Options for training loop.
    G = EasyDict(func_name='training.networks_stylegan2.G_main'
                 )  # Options for generator network.
    D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2'
                 )  # Options for discriminator network.
    G_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for generator optimizer.
    D_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for discriminator optimizer.
    G_loss = EasyDict(func_name='training.loss.G_logistic_ns_pathreg'
                      )  # Options for generator loss.
    D_loss = EasyDict(func_name='training.loss.D_logistic_r1'
                      )  # Options for discriminator loss.
    sched = EasyDict()  # Options for TrainingSchedule.
    grid = EasyDict(
        size='8k', layout='random')  # Options for setup_snapshot_image_grid().
    sc = dnnlib.SubmitConfig()  # Options for dnnlib.submit_run().
    tf_config = {'rnd.np_random_seed': 1000}  # Options for tflib.init_tf().

    train.data_dir = data_dir
    train.total_kimg = total_kimg
    train.mirror_augment = mirror_augment
    train.image_snapshot_ticks = train.network_snapshot_ticks = 2
    sched.G_lrate_base = sched.D_lrate_base = 0.002
    sched.minibatch_size_base = 32
    sched.minibatch_gpu_base = 4
    D_loss.gamma = 10
    metrics = [metric_defaults[x] for x in metrics]
    desc = 'stylegan2'

    desc += '-' + dataset
    dataset_args = EasyDict(tfrecord_dir=dataset)

    assert num_gpus in [1, 2, 4, 8]
    sc.num_gpus = num_gpus
    desc += '-%dgpu' % num_gpus

    assert config_id in _valid_configs
    desc += '-' + config_id

    # Configs A-E: Shrink networks to match original StyleGAN.
    if config_id != 'config-f':
        G.fmap_base = D.fmap_base = 8 << 10

    # Config E: Set gamma to 100 and override G & D architecture.
    if config_id.startswith('config-e'):
        D_loss.gamma = 100
        if 'Gorig' in config_id: G.architecture = 'orig'
        if 'Gskip' in config_id: G.architecture = 'skip'  # (default)
        if 'Gresnet' in config_id: G.architecture = 'resnet'
        if 'Dorig' in config_id: D.architecture = 'orig'
        if 'Dskip' in config_id: D.architecture = 'skip'
        if 'Dresnet' in config_id: D.architecture = 'resnet'  # (default)
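        # e.g. a config_id of 'config-e-Gorig-Dresnet' (assuming such an id
        # is in _valid_configs) would switch G to the original architecture
        # while keeping the default residual D.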

    # Configs A-D: Enable progressive growing and switch to networks that support it.
    if config_id in ['config-a', 'config-b', 'config-c', 'config-d']:
        sched.lod_initial_resolution = 8
        sched.G_lrate_base = sched.D_lrate_base = 0.001
        sched.G_lrate_dict = sched.D_lrate_dict = {
            128: 0.0015,
            256: 0.002,
            512: 0.003,
            1024: 0.003
        }
        sched.minibatch_size_base = 32  # (default)
        sched.minibatch_size_dict = {8: 256, 16: 128, 32: 64, 64: 32}
        sched.minibatch_gpu_base = 4  # (default)
        sched.minibatch_gpu_dict = {8: 32, 16: 16, 32: 8, 64: 4}
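        # The *_dict schedules above are keyed by output resolution during
        # progressive growing: e.g. while growing through 16x16, the run
        # would use the base lrate of 0.001 (no 16 entry in G_lrate_dict),
        # a global minibatch of 128, and 16 images per GPU.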
        G.synthesis_func = 'G_synthesis_stylegan_revised'
        D.func_name = 'training.networks_stylegan2.D_stylegan'

    # Configs A-C: Disable path length regularization.
    if config_id in ['config-a', 'config-b', 'config-c']:
        G_loss = EasyDict(func_name='training.loss.G_logistic_ns')

    # Configs A-B: Disable lazy regularization.
    if config_id in ['config-a', 'config-b']:
        train.lazy_regularization = False

    # Config A: Switch to original StyleGAN networks.
    if config_id == 'config-a':
        G = EasyDict(func_name='training.networks_stylegan.G_style')
        D = EasyDict(func_name='training.networks_stylegan.D_basic')

    if gamma is not None:
        D_loss.gamma = gamma

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(G_args=G,
                  D_args=D,
                  G_opt_args=G_opt,
                  D_opt_args=D_opt,
                  G_loss_args=G_loss,
                  D_loss_args=D_loss)
    kwargs.update(dataset_args=dataset_args,
                  sched_args=sched,
                  grid_args=grid,
                  metric_arg_list=metrics,
                  tf_config=tf_config)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = result_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
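# A hypothetical invocation of the run() above (paths are placeholders;
# 'fid50k' is assumed to be a key in metric_defaults and 'config-f' an
# entry in _valid_configs):
#
#   run(dataset='ffhq', data_dir='~/datasets', result_dir='results',
#       config_id='config-f', num_gpus=8, total_kimg=25000, gamma=None,
#       mirror_augment=True, metrics=['fid50k'])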
def run(dataset,
        data_dir,
        result_dir,
        config_id,
        num_gpus,
        total_kimg,
        gamma,
        mirror_augment,
        metrics,
        resume_pkl,
        I_fmap_base=8,
        G_fmap_base=8,
        D_fmap_base=9,
        fmap_decay=0.15,
        D_lambda=1,
        C_lambda=1,
        cls_alpha=0,
        n_samples_per=10,
        module_list=None,
        model_type='vc_gan2',
        epsilon_loss=3,
        random_eps=False,
        latent_type='uniform',
        delta_type='onedim',
        connect_mode='concat',
        batch_size=32,
        batch_per_gpu=16,
        return_atts=False,
        random_seed=1000,
        module_I_list=None,
        module_D_list=None,
        fmap_min=16,
        fmap_max=512,
        G_nf_scale=4,
        I_nf_scale=4,
        D_nf_scale=4,
        outlier_detector=False,
        gen_atts_in_D=False,
        no_atts_in_D=False,
        att_lambda=0,
        dlatent_size=24,
        arch='resnet',
        opt_reset_ls=None,
        norm_ord=2,
        n_dim_strict=0,
        drop_extra_torgb=False,
        latent_split_ls_for_std_gen=[5, 5, 5, 5],  # note: mutable default; safe here since it is only read
        loose_rate=0.2,
        topk_dims_to_show=20,
        n_neg_samples=1,
        temperature=1.,
        learning_rate=0.002,
        avg_mv_for_I=False,
        use_cascade=False,
        cascade_alt_freq_k=1,
        regW_lambda=1,
        network_snapshot_ticks=10):
    train = EasyDict(
        run_func_name='training.training_loop_vc2.training_loop_vc2'
    )  # Options for training loop.
    if opt_reset_ls is not None:
        opt_reset_ls = _str_to_list_of_int(opt_reset_ls)

    D_global_size = 0
    if module_list is not None:
        module_list = _str_to_list(module_list)
        key_ls, size_ls, count_dlatent_size = split_module_names(module_list)
        for i, key in enumerate(key_ls):
            if key.startswith(('D_global', 'D_nocond_global')):
                D_global_size += size_ls[i]
    else:
        count_dlatent_size = dlatent_size
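    # Hypothetical example of the module_list convention assumed above:
    # a string like '[C_global-24, D_global-3]' would be parsed by
    # _str_to_list() into ['C_global-24', 'D_global-3'], after which
    # split_module_names() would yield the keys, their sizes, and
    # count_dlatent_size = 27; the 'D_global' entry then contributes
    # D_global_size = 3 in the loop above.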

    if module_I_list is not None:
        D_global_I_size = 0
        module_I_list = _str_to_list(module_I_list)
        key_I_ls, size_I_ls, count_dlatent_I_size = split_module_names(
            module_I_list)
        for i, key in enumerate(key_I_ls):
            if key.startswith(('D_global', 'D_nocond_global')):
                D_global_I_size += size_I_ls[i]
    if module_D_list is not None:
        D_global_D_size = 0
        module_D_list = _str_to_list(module_D_list)
        key_D_ls, size_D_ls, count_dlatent_D_size = split_module_names(
            module_D_list)
        for i, key in enumerate(key_D_ls):
            if key.startswith(('D_global', 'D_nocond_global')):
                D_global_D_size += size_D_ls[i]

    if model_type == 'vc2_info_gan':  # G1 and G2 version InfoGAN
        G = EasyDict(func_name='training.vc_networks2.G_main_vc2',
                     synthesis_func='G_synthesis_modular_vc2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     fmap_decay=fmap_decay,
                     latent_size=count_dlatent_size,
                     dlatent_size=count_dlatent_size,
                     D_global_size=D_global_size,
                     module_list=module_list,
                     use_noise=True,
                     return_atts=return_atts,
                     G_nf_scale=G_nf_scale)  # Options for generator network.
        D = EasyDict(func_name='training.vc_networks2.D_info_modular_vc2',
                     dlatent_size=count_dlatent_D_size,
                     D_global_size=D_global_D_size,
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     connect_mode=connect_mode,
                     module_D_list=module_D_list,
                     gen_atts_in_D=gen_atts_in_D,
                     no_atts_in_D=no_atts_in_D,
                     D_nf_scale=D_nf_scale)
        I = EasyDict()
        I_info = EasyDict()
        desc = 'vc2_info_gan_net'
    elif model_type == 'vc2_info_gan2':  # Independent branch version InfoGAN
        G = EasyDict(func_name='training.vc_networks2.G_main_vc2',
                     synthesis_func='G_synthesis_modular_vc2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     fmap_decay=fmap_decay,
                     latent_size=count_dlatent_size,
                     dlatent_size=count_dlatent_size,
                     D_global_size=D_global_size,
                     module_list=module_list,
                     use_noise=True,
                     return_atts=return_atts,
                     G_nf_scale=G_nf_scale)  # Options for generator network.
        I = EasyDict(func_name='training.vc_networks2.vc2_head_infogan2',
                     dlatent_size=count_dlatent_size,
                     D_global_size=D_global_size,
                     fmap_min=fmap_min,
                     fmap_max=fmap_max)
        D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max)  # Options for discriminator network.
        I_info = EasyDict()
        desc = 'vc2_info_gan2_net'
    elif model_type == 'vc2_gan':  # Standard VP-GAN
        G = EasyDict(func_name='training.vc_networks2.G_main_vc2',
                     synthesis_func='G_synthesis_modular_vc2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     fmap_decay=fmap_decay,
                     latent_size=count_dlatent_size,
                     dlatent_size=count_dlatent_size,
                     D_global_size=D_global_size,
                     module_list=module_list,
                     use_noise=True,
                     return_atts=return_atts,
                     G_nf_scale=G_nf_scale)  # Options for generator network.
        I = EasyDict(func_name='training.vc_networks2.vc2_head',
                     dlatent_size=count_dlatent_size,
                     D_global_size=D_global_size,
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     connect_mode=connect_mode)
        D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max)  # Options for discriminator network.
        I_info = EasyDict()
        desc = 'vc2_gan'
    elif model_type == 'vc2_gan_byvae':  # COMA-FAIN
        G = EasyDict(
            func_name='training.vc_networks2.G_main_vc2',
            synthesis_func='G_synthesis_modular_vc2',
            fmap_min=fmap_min,
            fmap_max=fmap_max,
            fmap_decay=fmap_decay,
            latent_size=count_dlatent_size,
            dlatent_size=count_dlatent_size,
            D_global_size=D_global_size,
            module_list=module_list,
            use_noise=True,
            return_atts=return_atts,
            G_nf_scale=G_nf_scale,
            architecture=arch,
            drop_extra_torgb=drop_extra_torgb,
            latent_split_ls_for_std_gen=latent_split_ls_for_std_gen,
        )  # Options for generator network.
        I = EasyDict(func_name='training.vc_networks2.vc2_head_byvae',
                     dlatent_size=count_dlatent_size,
                     D_global_size=D_global_size,
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     connect_mode=connect_mode)
        D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max)  # Options for discriminator network.
        I_info = EasyDict()
        desc = 'vc2_gan_byvae'
    elif model_type == 'vc2_gan_byvae_simple':  # COMA-FAIN-simple
        G = EasyDict(
            func_name='training.vc_networks2.G_main_vc2',
            synthesis_func='G_synthesis_simple_vc2',
            fmap_min=fmap_min,
            fmap_max=fmap_max,
            fmap_decay=fmap_decay,
            latent_size=count_dlatent_size,
            dlatent_size=count_dlatent_size,
            D_global_size=D_global_size,
            module_list=module_list,
            use_noise=True,
            return_atts=return_atts,
            G_nf_scale=G_nf_scale,
            architecture=arch,
            drop_extra_torgb=drop_extra_torgb,
            latent_split_ls_for_std_gen=latent_split_ls_for_std_gen,
        )  # Options for generator network.
        I = EasyDict(func_name='training.vc_networks2.I_byvae_simple',
                     dlatent_size=count_dlatent_size,
                     D_global_size=D_global_size,
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     connect_mode=connect_mode)
        D = EasyDict(func_name='training.vc_networks2.D_stylegan2_simple',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max)  # Options for discriminator network.
        I_info = EasyDict()
        desc = 'vc2_gan_byvae_simple'
    elif model_type == 'vc2_gan_style2_noI':  # Plain StyleGAN2-style GAN, no I network
        G = EasyDict(func_name='training.vc_networks2.G_main_vc2',
                     synthesis_func='G_synthesis_stylegan2_vc2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     fmap_decay=fmap_decay,
                     latent_size=dlatent_size,
                     architecture=arch,
                     dlatent_size=count_dlatent_size,
                     use_noise=True,
                     return_atts=return_atts,
                     G_nf_scale=G_nf_scale)  # Options for generator network.
        I = EasyDict()
        D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max)  # Options for discriminator network.
        I_info = EasyDict()
        desc = 'vc2_gan_style2_noI'
    elif model_type == 'vc2_gan_own_I':  # Standard VP-GAN with own I
        G = EasyDict(func_name='training.vc_networks2.G_main_vc2',
                     synthesis_func='G_synthesis_modular_vc2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     fmap_decay=fmap_decay,
                     latent_size=count_dlatent_size,
                     dlatent_size=count_dlatent_size,
                     D_global_size=D_global_size,
                     module_list=module_list,
                     use_noise=True,
                     return_atts=return_atts,
                     G_nf_scale=G_nf_scale)  # Options for generator network.
        I = EasyDict(func_name='training.vc_networks2.I_modular_vc2',
                     dlatent_size=count_dlatent_I_size,
                     D_global_size=D_global_I_size,
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     connect_mode=connect_mode,
                     module_I_list=module_I_list,
                     I_nf_scale=I_nf_scale)
        D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max)  # Options for discriminator network.
        I_info = EasyDict()
        desc = 'vc2_gan_own_I'
    elif model_type == 'vc2_gan_own_ID':  # Standard VP-GAN with own ID
        G = EasyDict(func_name='training.vc_networks2.G_main_vc2',
                     synthesis_func='G_synthesis_modular_vc2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     fmap_decay=fmap_decay,
                     latent_size=count_dlatent_size,
                     dlatent_size=count_dlatent_size,
                     D_global_size=D_global_size,
                     module_list=module_list,
                     use_noise=True,
                     return_atts=return_atts,
                     G_nf_scale=G_nf_scale)  # Options for generator network.
        I = EasyDict(func_name='training.vc_networks2.I_modular_vc2',
                     dlatent_size=count_dlatent_I_size,
                     D_global_size=D_global_I_size,
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     connect_mode=connect_mode,
                     module_I_list=module_I_list,
                     I_nf_scale=I_nf_scale)
        D = EasyDict(func_name='training.vc_networks2.D_modular_vc2',
                     dlatent_size=count_dlatent_D_size,
                     D_global_size=D_global_D_size,
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     connect_mode=connect_mode,
                     module_D_list=module_D_list,
                     D_nf_scale=D_nf_scale)
        I_info = EasyDict()
        desc = 'vc2_gan_ownID'
    elif model_type in ('vc2_gan_noI', 'vc2_traversal_contrastive',
                        'gan_regW'):  # Plain modular GAN, traversal-contrastive, or regW variant
        G = EasyDict(func_name='training.vc_networks2.G_main_vc2',
                     synthesis_func='G_synthesis_modular_vc2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max,
                     fmap_decay=fmap_decay,
                     latent_size=count_dlatent_size,
                     dlatent_size=count_dlatent_size,
                     D_global_size=D_global_size,
                     module_list=module_list,
                     use_noise=True,
                     return_atts=return_atts,
                     G_nf_scale=G_nf_scale)  # Options for generator network.
        I = EasyDict()
        D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2',
                     fmap_min=fmap_min,
                     fmap_max=fmap_max)  # Options for discriminator network.
        I_info = EasyDict()
        desc = model_type
    else:
        raise ValueError('Unsupported model type: ' + model_type)

    G_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for generator optimizer.
    D_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for discriminator optimizer.
    if model_type == 'vc2_info_gan':  # G1 and G2 version InfoGAN
        G_loss = EasyDict(
            func_name='training.loss_vc2.G_logistic_ns_vc2_info_gan',
            D_global_size=D_global_size,
            C_lambda=C_lambda,
            epsilon=epsilon_loss,
            random_eps=random_eps,
            latent_type=latent_type,
            delta_type=delta_type,
            outlier_detector=outlier_detector,
            gen_atts_in_D=gen_atts_in_D,
            att_lambda=att_lambda)  # Options for generator loss.
        D_loss = EasyDict(
            func_name='training.loss_vc2.D_logistic_r1_vc2_info_gan',
            D_global_size=D_global_size,
            latent_type=latent_type)  # Options for discriminator loss.
    elif model_type == 'vc2_info_gan2':  # Independent branch version InfoGAN
        G_loss = EasyDict(
            func_name='training.loss_vc2.G_logistic_ns_vc2_info_gan2',
            D_global_size=D_global_size,
            C_lambda=C_lambda,
            latent_type=latent_type,
            norm_ord=norm_ord,
            n_dim_strict=n_dim_strict,
            loose_rate=loose_rate)  # Options for generator loss.
        D_loss = EasyDict(
            func_name='training.loss_vc2.D_logistic_r1_vc2_info_gan2',
            D_global_size=D_global_size,
            latent_type=latent_type)  # Options for discriminator loss.
    elif model_type == 'vc2_gan':  # Standard VP-GAN
        G_loss = EasyDict(func_name='training.loss_vc2.G_logistic_ns_vc2',
                          D_global_size=D_global_size,
                          C_lambda=C_lambda,
                          epsilon=epsilon_loss,
                          random_eps=random_eps,
                          latent_type=latent_type,
                          delta_type=delta_type)  # Options for generator loss.
        D_loss = EasyDict(
            func_name='training.loss_vc2.D_logistic_r1_vc2',
            D_global_size=D_global_size,
            latent_type=latent_type)  # Options for discriminator loss.
    elif model_type in ('vc2_gan_byvae', 'vc2_gan_byvae_simple'):  # COMA-FAIN
        G_loss = EasyDict(
            func_name='training.loss_vc2.G_logistic_byvae_ns_vc2',
            D_global_size=D_global_size,
            C_lambda=C_lambda,
            epsilon=epsilon_loss,
            random_eps=random_eps,
            latent_type=latent_type,
            use_cascade=use_cascade,
            delta_type=delta_type)  # Options for generator loss.
        D_loss = EasyDict(
            func_name='training.loss_vc2.D_logistic_r1_vc2',
            D_global_size=D_global_size,
            latent_type=latent_type)  # Options for discriminator loss.
    elif model_type in ('vc2_gan_own_I', 'vc2_gan_own_ID'):  # Standard VP-GAN with its own I or ID
        G_loss = EasyDict(func_name='training.loss_vc2.G_logistic_ns_vc2',
                          D_global_size=D_global_size,
                          C_lambda=C_lambda,
                          epsilon=epsilon_loss,
                          random_eps=random_eps,
                          latent_type=latent_type,
                          delta_type=delta_type,
                          own_I=True)  # Options for generator loss.
        D_loss = EasyDict(
            func_name='training.loss_vc2.D_logistic_r1_vc2',
            D_global_size=D_global_size,
            latent_type=latent_type)  # Options for discriminator loss.
    elif model_type in ('vc2_gan_noI', 'vc2_gan_style2_noI'):  # Just GANs (modular or StyleGAN2-style)
        G_loss = EasyDict(
            func_name='training.loss_vc2.G_logistic_ns',
            latent_type=latent_type)  # Options for generator loss.
        D_loss = EasyDict(
            func_name='training.loss_vc2.D_logistic_r1_vc2',
            D_global_size=D_global_size,
            latent_type=latent_type)  # Options for discriminator loss.
    elif model_type == 'vc2_traversal_contrastive':  # With perceptual distance as guide.
        G_loss = EasyDict(
            func_name='training.loss_vc2.G_logistic_ns_vc2_traversal_contrastive',
            D_global_size=D_global_size,
            C_lambda=C_lambda,
            n_neg_samples=n_neg_samples,
            temperature=temperature,
            epsilon=epsilon_loss,
            random_eps=random_eps,
            latent_type=latent_type,
            delta_type=delta_type)  # Options for generator loss.
        D_loss = EasyDict(
            func_name='training.loss_vc2.D_logistic_r1_vc2',
            D_global_size=D_global_size,
            latent_type=latent_type)  # Options for discriminator loss.
    elif model_type == 'gan_regW':
        G_loss = EasyDict(
            func_name='training.loss_vc2.G_logistic_ns_regW',
            latent_type=latent_type,
            regW_lambda=regW_lambda)  # Options for generator loss.
        D_loss = EasyDict(
            func_name='training.loss_vc2.D_logistic_r1_vc2',
            D_global_size=D_global_size,
            latent_type=latent_type)  # Options for discriminator loss.
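    # Note: apart from the two InfoGAN variants, every model type above
    # shares the same R1 discriminator loss (D_logistic_r1_vc2); the
    # variants differ mainly in their generator-side loss terms.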

    sched = EasyDict()  # Options for TrainingSchedule.
    grid = EasyDict(
        size='1080p',
        layout='random')  # Options for setup_snapshot_image_grid().
    sc = dnnlib.SubmitConfig()  # Options for dnnlib.submit_run().
    tf_config = {'rnd.np_random_seed': random_seed}  # Options for tflib.init_tf().

    train.data_dir = data_dir
    train.total_kimg = total_kimg
    train.mirror_augment = mirror_augment
    train.image_snapshot_ticks = train.network_snapshot_ticks = 10
    sched.G_lrate_base = sched.D_lrate_base = learning_rate
    sched.minibatch_size_base = batch_size
    sched.minibatch_gpu_base = batch_per_gpu
    D_loss.gamma = 10
    metrics = [metric_defaults[x] for x in metrics]

    desc += '-' + dataset
    dataset_args = EasyDict(tfrecord_dir=dataset, max_label_size='full')

    assert num_gpus in [1, 2, 4, 8]
    sc.num_gpus = num_gpus
    desc += '-%dgpu' % num_gpus
    desc += '-' + config_id

    # Network capacity: fmap_base is derived as 2 << {I,G,D}_fmap_base.
    I.fmap_base = 2 << I_fmap_base
    G.fmap_base = 2 << G_fmap_base
    D.fmap_base = 2 << D_fmap_base
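    # With the defaults I_fmap_base=8, G_fmap_base=8 and D_fmap_base=9,
    # this yields I.fmap_base = G.fmap_base = 2 << 8 = 512 and
    # D.fmap_base = 2 << 9 = 1024.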


    if gamma is not None:
        D_loss.gamma = gamma

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(
        G_args=G,
        D_args=D,
        I_args=I,
        I_info_args=I_info,
        G_opt_args=G_opt,
        D_opt_args=D_opt,
        G_loss_args=G_loss,
        D_loss_args=D_loss,
        use_info_gan=(
            model_type == 'vc2_info_gan2'),  # Independent branch version
        use_vc_head=(model_type in ('vc2_gan', 'vc2_gan_own_I',
                                    'vc2_gan_own_ID', 'vc2_gan_byvae',
                                    'vc2_gan_byvae_simple')),
        use_vc2_info_gan=(model_type == 'vc2_info_gan'),  # G1 and G2 version
        use_perdis=(model_type == 'vc2_traversal_contrastive'),
        avg_mv_for_I=avg_mv_for_I,
        traversal_grid=True,
        return_atts=return_atts)
    n_continuous = 0
    if module_list is not None:
        for i, key in enumerate(key_ls):
            m_name = key.split('-')[0]
            if m_name in LATENT_MODULES and m_name != 'D_global':
                n_continuous += size_ls[i]
    else:
        n_continuous = dlatent_size
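    # Continuing the hypothetical '[C_global-24, D_global-3]' example:
    # assuming 'C_global' is in LATENT_MODULES, it contributes
    # n_continuous = 24, while 'D_global' is excluded here and instead
    # counts toward n_discrete via D_global_size = 3.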

    kwargs.update(dataset_args=dataset_args,
                  sched_args=sched,
                  grid_args=grid,
                  metric_arg_list=metrics,
                  tf_config=tf_config,
                  resume_pkl=resume_pkl,
                  n_discrete=D_global_size,
                  n_continuous=n_continuous,
                  n_samples_per=n_samples_per,
                  topk_dims_to_show=topk_dims_to_show,
                  cascade_alt_freq_k=cascade_alt_freq_k,
                  network_snapshot_ticks=network_snapshot_ticks)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = result_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
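# A hypothetical invocation of the run() above (dataset, paths and the
# module_list string are placeholders; 'fid50k' is assumed to be a key in
# metric_defaults, and the remaining keyword arguments use their defaults):
#
#   run(dataset='dsprites', data_dir='~/datasets', result_dir='results',
#       config_id='config-f', num_gpus=1, total_kimg=15000, gamma=None,
#       mirror_augment=False, metrics=['fid50k'], resume_pkl=None,
#       model_type='vc2_gan', module_list='[C_global-24]')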