Example #1
def run(dataset,
        data_dir,
        result_dir,
        num_gpus,
        total_kimg,
        mirror_augment,
        metrics,
        resume_pkl,
        model_type='vc_gan2',
        latent_type='uniform',
        batch_size=32,
        batch_per_gpu=16,
        random_seed=1000,
        G_fmap_base=8,
        module_G_list=None,
        G_nf_scale=4,
        E_fmap_base=8,
        module_E_list=None,
        E_nf_scale=4,
        D_fmap_base=9,
        module_D_list=None,
        D_nf_scale=4,
        I_fmap_base=9,
        module_I_list=None,
        I_nf_scale=4,
        fmap_decay=0.15,
        fmap_min=16,
        fmap_max=512,
        n_samples_per=10,
        topk_dims_to_show=20,
        hy_beta=1,
        hy_gamma=0,
        hy_dcp=40,
        hy_ncut=1,
        hy_rec=20,
        hy_hes=20,
        hy_lin=20,
        hy_mat=80,
        hy_gmat=0,
        hy_oth=80,
        hy_det=0,
        hessian_type='no_act_points',
        n_act_points=10,
        lie_alg_init_type='oth',
        lie_alg_init_scale=0.1,
        G_lrate_base=0.002,
        D_lrate_base=None,
        lambda_d_factor=10.,
        lambda_od=1.,
        group_loss_type='_rec_mat_',
        group_feats_size=400,
        temp=0.67,
        n_discrete=0,
        drange_net=[-1, 1],
        recons_type='bernoulli_loss',
        use_group_decomp=False,
        snapshot_ticks=10):
    train = EasyDict(
        run_func_name='training.training_loop_vae.training_loop_vae'
    )  # Options for training loop.

    if module_G_list is not None:
        module_G_list = _str_to_list(module_G_list)
        key_G_ls, size_G_ls, count_dlatent_G_size = split_module_names(
            module_G_list)
    if module_E_list is not None:
        module_E_list = _str_to_list(module_E_list)
        key_E_ls, size_E_ls, count_dlatent_E_size = split_module_names(
            module_E_list)
    if module_D_list is not None:
        module_D_list = _str_to_list(module_D_list)
        key_D_ls, size_D_ls, count_dlatent_D_size = split_module_names(
            module_D_list)
    if module_I_list is not None:
        module_I_list = _str_to_list(module_I_list)
        key_I_ls, size_I_ls, count_dlatent_I_size = split_module_names(
            module_I_list)

    D = D_opt = D_loss = None
    E = EasyDict(func_name='training.vae_networks.E_main_modular',
                 fmap_min=fmap_min,
                 fmap_max=fmap_max,
                 fmap_decay=fmap_decay,
                 latent_size=count_dlatent_E_size,
                 group_feats_size=group_feats_size,
                 module_E_list=module_E_list,
                 nf_scale=E_nf_scale,
                 n_discrete=n_discrete,
                 fmap_base=2 << E_fmap_base)  # Options for encoder network.
    G = EasyDict(func_name='training.vae_networks.G_main_modular',
                 fmap_min=fmap_min,
                 fmap_max=fmap_max,
                 fmap_decay=fmap_decay,
                 latent_size=count_dlatent_G_size,
                 group_feats_size=group_feats_size,
                 module_G_list=module_G_list,
                 nf_scale=G_nf_scale,
                 n_discrete=n_discrete,
                 recons_type=recons_type,
                 n_act_points=n_act_points,
                 lie_alg_init_type=lie_alg_init_type,
                 lie_alg_init_scale=lie_alg_init_scale,
                 fmap_base=2 << G_fmap_base)  # Options for generator network.
    I = EasyDict(func_name='training.vae_I_networks.I_main_modular',
                 fmap_min=fmap_min,
                 fmap_max=fmap_max,
                 fmap_decay=fmap_decay,
                 latent_size=count_dlatent_I_size,
                 module_I_list=module_I_list,
                 nf_scale=I_nf_scale,
                 fmap_base=2 << I_fmap_base)  # Options for I network.
    G_opt = EasyDict(beta1=0.9, beta2=0.999,
                     epsilon=1e-8)  # Options for generator optimizer.
    if model_type == 'factor_vae' or model_type == 'factor_sindis_vae':  # Factor-VAE
        D = EasyDict(
            func_name='training.vae_networks.D_factor_vae_modular',
            fmap_min=fmap_min,
            fmap_max=fmap_max,
            fmap_decay=fmap_decay,
            latent_size=count_dlatent_D_size,
            module_D_list=module_D_list,
            nf_scale=D_nf_scale,
            fmap_base=2 << D_fmap_base)  # Options for discriminator network.
        D_opt = EasyDict(beta1=0.5, beta2=0.9,
                         epsilon=1e-8)  # Options for discriminator optimizer.
    desc = model_type + '_modular'

    if model_type == 'beta_vae':  # Beta-VAE
        G_loss = EasyDict(
            func_name='training.loss_vae.beta_vae',
            latent_type=latent_type,
            hy_beta=hy_beta,
            recons_type=recons_type)  # Options for generator loss.
    elif model_type == 'betatc_vae':  # BetaTC-VAE
        G_loss = EasyDict(
            func_name='training.loss_vae.betatc_vae',
            latent_type=latent_type,
            hy_beta=hy_beta,
            recons_type=recons_type)  # Options for generator loss.
    elif model_type == 'lie_vae':  # LieVAE
        G_loss = EasyDict(
            func_name='training.loss_vae_lie.lie_vae',
            latent_type=latent_type,
            hy_rec=hy_rec,
            hy_dcp=hy_dcp,
            hy_hes=hy_hes,
            hy_lin=hy_lin,
            hy_ncut=hy_ncut,
            recons_type=recons_type)  # Options for generator loss.
    elif model_type == 'lie_vae_with_split':  # LieVAE with split loss
        G_loss = EasyDict(
            func_name='training.loss_vae_lie.lie_vae_with_split',
            latent_type=latent_type,
            hy_rec=hy_rec,
            hy_dcp=hy_dcp,
            hy_hes=hy_hes,
            hy_lin=hy_lin,
            hy_ncut=hy_ncut,
            recons_type=recons_type)  # Options for generator loss.
    elif model_type == 'group_vae_v2':  # GroupVAE-v2
        G_loss = EasyDict(
            func_name='training.loss_vae_group_v2.group_act_vae',
            latent_type=latent_type,
            hy_beta=hy_beta,
            hy_rec=hy_rec,
            hy_gmat=hy_gmat,
            hy_dcp=hy_dcp,
            hy_hes=hy_hes,
            hy_lin=hy_lin,
            hy_ncut=hy_ncut,
            hessian_type=hessian_type,
            recons_type=recons_type)  # Options for generator loss.
    elif model_type == 'group_vae_spl_v2':  # GroupVAE-v2 with split
        G_loss = EasyDict(
            func_name='training.loss_vae_group_v2.group_act_spl_vae',
            latent_type=latent_type,
            hy_beta=hy_beta,
            hy_rec=hy_rec,
            hy_gmat=hy_gmat,
            hy_dcp=hy_dcp,
            hy_hes=hy_hes,
            hy_lin=hy_lin,
            hy_ncut=hy_ncut,
            hessian_type=hessian_type,
            recons_type=recons_type)  # Options for generator loss.
    elif model_type == 'group_vae':  # Group-VAE
        G_loss = EasyDict(
            func_name='training.loss_vae.group_vae',
            latent_type=latent_type,
            hy_beta=hy_beta,
            hy_dcp=hy_dcp,
            hy_ncut=hy_ncut,
            hy_rec=hy_rec,
            hy_mat=hy_mat,
            hy_gmat=hy_gmat,
            hy_oth=hy_oth,
            hy_det=hy_det,
            use_group_decomp=use_group_decomp,
            group_loss_type=group_loss_type,
            recons_type=recons_type)  # Options for generator loss.
    elif model_type == 'group_vae_wc':  # Group-VAE-with_Cat
        G_loss = EasyDict(
            func_name='training.loss_vae.group_vae_wc',
            latent_type=latent_type,
            hy_beta=hy_beta,
            hy_gamma=hy_gamma,
            temp=temp,
            use_group_decomp=use_group_decomp,
            group_loss_type=group_loss_type,
            recons_type=recons_type)  # Options for generator loss.
    elif model_type == 'dip_vae_i' or model_type == 'dip_vae_ii':  # DIP-VAE
        G_loss = EasyDict(
            func_name='training.loss_vae.dip_vae',
            lambda_d_factor=lambda_d_factor,
            lambda_od=lambda_od,
            latent_type=latent_type,
            dip_type=model_type,
            recons_type=recons_type)  # Options for generator loss.
    elif model_type == 'factor_vae':  # Factor-VAE
        G_loss = EasyDict(
            func_name='training.loss_vae.factor_vae_G',
            latent_type=latent_type,
            hy_gamma=hy_gamma,
            recons_type=recons_type)  # Options for generator loss.
        D_loss = EasyDict(
            func_name='training.loss_vae.factor_vae_D',
            latent_type=latent_type)  # Options for discriminator loss.
    elif model_type == 'factor_sindis_vae':  # Factor-VAE
        G_loss = EasyDict(
            func_name='training.loss_vae.factor_vae_sindis_G',
            latent_type=latent_type,
            hy_gamma=hy_gamma,
            recons_type=recons_type)  # Options for generator loss.
        D_loss = EasyDict(
            func_name='training.loss_vae.factor_vae_sindis_D',
            latent_type=latent_type)  # Options for discriminator loss.

    sched = EasyDict()  # Options for TrainingSchedule.
    grid = EasyDict(
        size='1080p',
        layout='random')  # Options for setup_snapshot_image_grid().
    sc = dnnlib.SubmitConfig()  # Options for dnnlib.submit_run().
    tf_config = {
        'rnd.np_random_seed': random_seed,
        'allow_soft_placement': True
    }  # Options for tflib.init_tf().

    train.data_dir = data_dir
    train.total_kimg = total_kimg
    train.mirror_augment = mirror_augment
    train.image_snapshot_ticks = train.network_snapshot_ticks = snapshot_ticks
    sched.G_lrate_base = G_lrate_base
    sched.D_lrate_base = D_lrate_base
    sched.minibatch_size_base = batch_size
    sched.minibatch_gpu_base = batch_per_gpu
    metrics = [metric_defaults[x] for x in metrics]

    desc += '-' + dataset
    dataset_args = EasyDict(tfrecord_dir=dataset, max_label_size='full')

    assert num_gpus in [1, 2, 4, 8]
    sc.num_gpus = num_gpus
    desc += '-%dgpu' % num_gpus

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(G_args=G,
                  E_args=E,
                  D_args=D,
                  G_opt_args=G_opt,
                  D_opt_args=D_opt,
                  G_loss_args=G_loss,
                  D_loss_args=D_loss,
                  traversal_grid=True)
    kwargs.update(dataset_args=dataset_args,
                  sched_args=sched,
                  grid_args=grid,
                  n_continuous=count_dlatent_G_size,
                  n_discrete=n_discrete,
                  drange_net=drange_net,
                  metric_arg_list=metrics,
                  tf_config=tf_config,
                  resume_pkl=resume_pkl,
                  n_samples_per=n_samples_per,
                  topk_dims_to_show=topk_dims_to_show)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = result_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
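A quick check, not part of the original snippet, of how the `X_fmap_base` defaults above translate into actual feature-map bases through the `2 << X_fmap_base` shift:

# Illustration only: the fmap_base values implied by the defaults above.
for net, fmap_base_exp in [('G', 8), ('E', 8), ('D', 9), ('I', 9)]:
    print(net, 2 << fmap_base_exp)  # G 512, E 512, D 1024, I 1024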
Example #2
def run(dataset, data_dir, result_dir, config_id, num_gpus, total_kimg, gamma,
        mirror_augment, mirror_augment_v, metrics, min_h, min_w, res_log2, lr,
        use_attention, resume_with_new_nets, glr, dlr, use_raw, resume_pkl,
        minibatch_gpu_base, network_snapshot_ticks):
    train = EasyDict(run_func_name='training.training_loop.training_loop'
                     )  # Options for training loop.
    G = EasyDict(func_name='training.networks_stylegan2.G_main'
                 )  # Options for generator network.
    D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2'
                 )  # Options for discriminator network.
    G_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for generator optimizer.
    D_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for discriminator optimizer.
    G_loss = EasyDict(func_name='training.loss.G_logistic_ns_pathreg'
                      )  # Options for generator loss.
    D_loss = EasyDict(func_name='training.loss.D_logistic_r1'
                      )  # Options for discriminator loss.
    sched = EasyDict()  # Options for TrainingSchedule.
    grid = EasyDict(
        size='8k', layout='random')  # Options for setup_snapshot_image_grid().
    sc = dnnlib.SubmitConfig()  # Options for dnnlib.submit_run().
    tf_config = {'rnd.np_random_seed': 1000}  # Options for tflib.init_tf().

    train.data_dir = data_dir
    train.total_kimg = total_kimg
    train.mirror_augment = mirror_augment
    train.mirror_augment_v = mirror_augment_v
    train.resume_with_new_nets = resume_with_new_nets
    train.image_snapshot_ticks = 1
    train.network_snapshot_ticks = network_snapshot_ticks
    sched.G_lrate_base = sched.D_lrate_base = lr
    train.resume_pkl = resume_pkl

    if glr:
        sched.G_lrate_base = glr
    if dlr:
        sched.D_lrate_base = dlr

    sched.minibatch_size_base = 32
    sched.minibatch_gpu_base = minibatch_gpu_base
    D_loss.gamma = 10
    metrics = [metric_defaults[x] for x in metrics]
    desc = 'stylegan2'

    desc += '-' + dataset
    dataset_args = EasyDict(tfrecord_dir=dataset)
    dataset_args.use_raw = use_raw
    G.min_h = D.min_h = dataset_args.min_h = min_h
    G.min_w = D.min_w = dataset_args.min_w = min_w
    G.res_log2 = D.res_log2 = dataset_args.res_log2 = res_log2

    if use_attention:
        desc += '-attention'
        G.use_attention = True
        D.use_attention = True

    assert num_gpus in [1, 2, 4, 8]
    sc.num_gpus = num_gpus
    desc += '-%dgpu' % num_gpus

    assert config_id in _valid_configs
    desc += '-' + config_id

    # Configs A-E: Shrink networks to match original StyleGAN.
    if config_id != 'config-f':
        G.fmap_base = D.fmap_base = 8 << 10

    # Config E: Set gamma to 100 and override G & D architecture.
    if config_id.startswith('config-e'):
        D_loss.gamma = 100
        if 'Gorig' in config_id: G.architecture = 'orig'
        if 'Gskip' in config_id: G.architecture = 'skip'  # (default)
        if 'Gresnet' in config_id: G.architecture = 'resnet'
        if 'Dorig' in config_id: D.architecture = 'orig'
        if 'Dskip' in config_id: D.architecture = 'skip'
        if 'Dresnet' in config_id: D.architecture = 'resnet'  # (default)

    # Configs A-D: Enable progressive growing and switch to networks that support it.
    if config_id in ['config-a', 'config-b', 'config-c', 'config-d']:
        sched.lod_initial_resolution = 8
        sched.G_lrate_base = sched.D_lrate_base = 0.001
        sched.G_lrate_dict = sched.D_lrate_dict = {
            128: 0.0015,
            256: 0.002,
            512: 0.003,
            1024: 0.003
        }
        sched.minibatch_size_base = 32  # (default)
        sched.minibatch_size_dict = {8: 256, 16: 128, 32: 64, 64: 32}
        sched.minibatch_gpu_base = 4  # (default)
        sched.minibatch_gpu_dict = {8: 32, 16: 16, 32: 8, 64: 4}
        G.synthesis_func = 'G_synthesis_stylegan_revised'
        D.func_name = 'training.networks_stylegan2.D_stylegan'

    # Configs A-C: Disable path length regularization.
    if config_id in ['config-a', 'config-b', 'config-c']:
        G_loss = EasyDict(func_name='training.loss.G_logistic_ns')

    # Configs A-B: Disable lazy regularization.
    if config_id in ['config-a', 'config-b']:
        train.lazy_regularization = False

    # Config A: Switch to original StyleGAN networks.
    if config_id == 'config-a':
        G = EasyDict(func_name='training.networks_stylegan.G_style')
        D = EasyDict(func_name='training.networks_stylegan.D_basic')

    if gamma is not None:
        D_loss.gamma = gamma

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(G_args=G,
                  D_args=D,
                  G_opt_args=G_opt,
                  D_opt_args=D_opt,
                  G_loss_args=G_loss,
                  D_loss_args=D_loss)
    kwargs.update(dataset_args=dataset_args,
                  sched_args=sched,
                  grid_args=grid,
                  metric_arg_list=metrics,
                  tf_config=tf_config)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = result_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
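Two small illustrations of values computed above, using a hypothetical dataset name: the run description string and the shrunken feature-map base applied to configs A-E.

# Illustration only, with hypothetical inputs.
dataset, num_gpus, config_id = 'ffhq', 8, 'config-f'
desc = 'stylegan2' + '-' + dataset + '-%dgpu' % num_gpus + '-' + config_id
print(desc)     # stylegan2-ffhq-8gpu-config-f
print(8 << 10)  # 8192, the fmap_base used when config_id != 'config-f'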
Example #3
def run(dataset, train_dir, config, d_aug, diffaug_policy, cond, ops, jpg_data, mirror, mirror_v, \
        lod_step_kimg, batch_size, resume, resume_kimg, finetune, num_gpus, ema_kimg, gamma, freezeD):

    # dataset (tfrecords) - preprocess or get
    tfr_files = file_list(os.path.dirname(dataset), 'tfr')
    tfr_files = [f for f in tfr_files if basename(dataset) in f]
    if len(tfr_files) == 0:
        tfr_file, total_samples = create_from_images(dataset, jpg=jpg_data)
    else:
        tfr_file = tfr_files[0]
    dataset_args = EasyDict(tfrecord=tfr_file, jpg_data=jpg_data)

    desc = basename(tfr_file).split('-')[0]

    # training functions
    if d_aug:  # https://github.com/mit-han-lab/data-efficient-gans
        train = EasyDict(
            run_func_name='training.training_loop_diffaug.training_loop'
        )  # Options for training loop (Diff Augment method)
        loss_args = EasyDict(
            func_name='training.loss_diffaug.ns_DiffAugment_r1',
            policy=diffaug_policy)  # Options for loss (Diff Augment method)
    else:  # original nvidia
        train = EasyDict(run_func_name='training.training_loop.training_loop'
                         )  # Options for training loop (original from NVidia)
        G_loss = EasyDict(func_name='training.loss.G_logistic_ns_pathreg'
                          )  # Options for generator loss.
        D_loss = EasyDict(func_name='training.loss.D_logistic_r1'
                          )  # Options for discriminator loss.

    # network functions
    G = EasyDict(func_name='training.networks_stylegan2.G_main'
                 )  # Options for generator network.
    D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2'
                 )  # Options for discriminator network.
    G_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for generator optimizer.
    D_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for discriminator optimizer.
    sched = EasyDict()  # Options for TrainingSchedule.
    grid = EasyDict(
        size='1080p',
        layout='random')  # Options for setup_snapshot_image_grid().
    sc = dnnlib.SubmitConfig()  # Options for dnnlib.submit_run().
    tf_config = {'rnd.np_random_seed': 1000}  # Options for tflib.init_tf().
    G.impl = D.impl = ops

    # resolutions
    data_res = basename(tfr_file).split('-')[-1].split(
        'x')  # get resolution from dataset filename
    data_res = list(reversed([int(x)
                              for x in data_res]))  # convert to int list
    init_res, resolution, res_log2 = calc_init_res(data_res)
    if init_res != [4, 4]:
        print(' custom init resolution', init_res)
    G.init_res = D.init_res = list(init_res)

    train.setname = desc + config
    desc = '%s-%d-%s' % (desc, resolution, config)

    # training schedule
    sched.lod_training_kimg = lod_step_kimg
    sched.lod_transition_kimg = lod_step_kimg
    train.total_kimg = lod_step_kimg * res_log2 * 2  # a la ProGAN
    if finetune:
        train.total_kimg = 15000  # should start from ~10k kimg
    train.image_snapshot_ticks = 1
    train.network_snapshot_ticks = 5
    train.mirror_augment = mirror
    train.mirror_augment_v = mirror_v

    # learning rate
    if config == 'e':
        if finetune:  # uptrain 1024
            sched.G_lrate_base = 0.001
        else:  # train 1024
            sched.G_lrate_base = 0.001
            sched.G_lrate_dict = {0: 0.001, 1: 0.0007, 2: 0.0005, 3: 0.0003}
            sched.lrate_step = 1500  # period for stepping to next lrate, in kimg
    if config == 'f':
        # sched.G_lrate_base = 0.0003
        sched.G_lrate_base = 0.001
    sched.D_lrate_base = sched.G_lrate_base  # *2 - not used anyway

    sched.minibatch_gpu_base = batch_size
    sched.minibatch_size_base = num_gpus * sched.minibatch_gpu_base
    sc.num_gpus = num_gpus

    if config == 'e':
        G.fmap_base = D.fmap_base = 8 << 10
        if d_aug: loss_args.gamma = 100 if gamma is None else gamma
        else: D_loss.gamma = 100 if gamma is None else gamma
    elif config == 'f':
        G.fmap_base = D.fmap_base = 16 << 10
    else:
        print(' Only configs E and F are implemented')
        exit()

    if cond:
        desc += '-cond'
        dataset_args.max_label_size = 'full'  # conditioned on full label

    if freezeD:
        D.freezeD = True
        train.resume_with_new_nets = True

    if d_aug:
        desc += '-daug'

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(G_args=G, D_args=D, G_opt_args=G_opt, D_opt_args=D_opt)
    kwargs.update(dataset_args=dataset_args,
                  sched_args=sched,
                  grid_args=grid,
                  tf_config=tf_config)
    kwargs.update(resume_pkl=resume,
                  resume_kimg=resume_kimg,
                  resume_with_new_nets=True)
    if ema_kimg is not None:
        kwargs.update(G_ema_kimg=ema_kimg)
    if d_aug:
        kwargs.update(loss_args=loss_args)
    else:
        kwargs.update(G_loss_args=G_loss, D_loss_args=D_loss)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = train_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
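The resolution handling above assumes the tfrecord filename ends in '-<W>x<H>'; a hypothetical filename makes the string manipulation concrete (the repo's basename() helper and calc_init_res() are assumptions not reproduced here).

import os
# Illustration only: hypothetical filename following the '<name>-<W>x<H>' convention.
tfr_file = 'data/mydataset-1024x768.tfr'
stem = os.path.splitext(os.path.basename(tfr_file))[0]  # the repo's basename() presumably strips the extension too
data_res = stem.split('-')[-1].split('x')               # ['1024', '768']
data_res = list(reversed([int(x) for x in data_res]))   # [768, 1024]
print(data_res)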
Example #4
def run(data, train_dir, config, d_aug, diffaug_policy, cond, ops, mirror, mirror_v, \
        kimg, batch_size, lrate, resume, resume_kimg, num_gpus, ema_kimg, gamma, freezeD):

    # training functions
    if d_aug:  # https://github.com/mit-han-lab/data-efficient-gans
        train = EasyDict(
            run_func_name='training.training_loop_diffaug.training_loop'
        )  # Options for training loop (Diff Augment method)
        loss_args = EasyDict(
            func_name='training.loss_diffaug.ns_DiffAugment_r1',
            policy=diffaug_policy)  # Options for loss (Diff Augment method)
    else:  # original nvidia
        train = EasyDict(run_func_name='training.training_loop.training_loop'
                         )  # Options for training loop (original from NVidia)
        G_loss = EasyDict(func_name='training.loss.G_logistic_ns_pathreg'
                          )  # Options for generator loss.
        D_loss = EasyDict(func_name='training.loss.D_logistic_r1'
                          )  # Options for discriminator loss.

    # network functions
    G = EasyDict(func_name='training.networks_stylegan2.G_main'
                 )  # Options for generator network.
    D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2'
                 )  # Options for discriminator network.
    G_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for generator optimizer.
    D_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for discriminator optimizer.
    sched = EasyDict()  # Options for TrainingSchedule.
    grid = EasyDict(
        size='1080p',
        layout='random')  # Options for setup_snapshot_image_grid().
    sc = dnnlib.SubmitConfig()  # Options for dnnlib.submit_run().
    tf_config = {'rnd.np_random_seed': 1000}  # Options for tflib.init_tf().
    G.impl = D.impl = ops

    # dataset (tfrecords) - get or create
    tfr_files = file_list(os.path.dirname(data), 'tfr')
    tfr_files = [
        f for f in tfr_files if basename(data) == basename(f).split('-')[0]
    ]
    if len(tfr_files) == 0 or os.stat(tfr_files[0]).st_size == 0:
        tfr_file, total_samples = create_from_image_folders(
            data) if cond is True else create_from_images(data)
    else:
        tfr_file = tfr_files[0]
    dataset_args = EasyDict(tfrecord=tfr_file)

    # resolutions
    with tf.Graph().as_default(), tflib.create_session().as_default():  # pylint: disable=not-context-manager
        dataset_obj = dataset.load_dataset(
            **dataset_args)  # loading the data to see what comes out
        resolution = dataset_obj.resolution
        init_res = dataset_obj.init_res
        res_log2 = dataset_obj.res_log2
        dataset_obj.close()
        dataset_obj = None

    if list(init_res) == [4, 4]:
        desc = '%s-%d' % (basename(data), resolution)
    else:
        print(' custom init resolution', init_res)
        desc = basename(tfr_file)
    G.init_res = D.init_res = list(init_res)

    train.savenames = [desc.replace(basename(data), 'snapshot'), desc]
    desc += '-%s' % config

    # training schedule
    train.total_kimg = kimg
    train.image_snapshot_ticks = 1 * num_gpus if kimg <= 1000 else 4 * num_gpus
    train.network_snapshot_ticks = 5
    train.mirror_augment = mirror
    train.mirror_augment_v = mirror_v
    sched.tick_kimg_base = 2 if train.total_kimg < 2000 else 4

    # learning rate
    if config == 'e':
        sched.G_lrate_base = 0.001
        sched.G_lrate_dict = {0: 0.001, 1: 0.0007, 2: 0.0005, 3: 0.0003}
        sched.lrate_step = 1500  # period for stepping to next lrate, in kimg
    if config == 'f':
        sched.G_lrate_base = lrate  # 0.001 for big datasets, 0.0003 for few-shot
    sched.D_lrate_base = sched.G_lrate_base  # *2 - not used anyway

    # batch size (for 16gb memory GPU)
    sched.minibatch_gpu_base = 4096 // resolution if batch_size is None else batch_size
    print(' Batch size', sched.minibatch_gpu_base)
    sched.minibatch_size_base = num_gpus * sched.minibatch_gpu_base
    sc.num_gpus = num_gpus

    if config == 'e':
        G.fmap_base = D.fmap_base = 8 << 10
        if d_aug: loss_args.gamma = 100 if gamma is None else gamma
        else: D_loss.gamma = 100 if gamma is None else gamma
    elif config == 'f':
        G.fmap_base = D.fmap_base = 16 << 10
    else:
        print(' Only configs E and F are implemented')
        exit()

    if cond:
        desc += '-cond'
        dataset_args.max_label_size = 'full'  # conditioned on full label

    if freezeD:
        D.freezeD = True
        train.resume_with_new_nets = True

    if d_aug:
        desc += '-daug'

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(G_args=G, D_args=D, G_opt_args=G_opt, D_opt_args=D_opt)
    kwargs.update(dataset_args=dataset_args,
                  sched_args=sched,
                  grid_args=grid,
                  tf_config=tf_config)
    kwargs.update(resume_pkl=resume,
                  resume_kimg=resume_kimg,
                  resume_with_new_nets=True)
    if ema_kimg is not None:
        kwargs.update(G_ema_kimg=ema_kimg)
    if d_aug:
        kwargs.update(loss_args=loss_args)
    else:
        kwargs.update(G_loss_args=G_loss, D_loss_args=D_loss)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = train_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
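One concrete reading, not from the original, of the per-GPU batch heuristic used above when batch_size is None:

# Illustration only: per-GPU batch implied by `4096 // resolution`
# (the 16 GB GPU assumption comes from the comment in the snippet above).
for resolution in (256, 512, 1024):
    print(resolution, 4096 // resolution)  # 256 -> 16, 512 -> 8, 1024 -> 4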
Example #5
def run(dataset, data_dir, result_dir, config_id, num_gpus, total_kimg, gamma,
        mirror_augment, metrics):
    train = EasyDict(run_func_name='training.training_loop.training_loop'
                     )  # Options for training loop.
    G = EasyDict(func_name='training.networks_stylegan2.G_main'
                 )  # Options for generator network.
    D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2'
                 )  # Options for discriminator network.
    G_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for generator optimizer.
    D_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for discriminator optimizer.
    G_loss = EasyDict(func_name='training.loss.G_logistic_ns_pathreg'
                      )  # Options for generator loss.
    D_loss = EasyDict(func_name='training.loss.D_logistic_r1'
                      )  # Options for discriminator loss.
    sched = EasyDict()  # Options for TrainingSchedule.
    grid = EasyDict(
        size='8k', layout='random')  # Options for setup_snapshot_image_grid().
    sc = dnnlib.SubmitConfig()  # Options for dnnlib.submit_run().
    tf_config = {'rnd.np_random_seed': 1000}  # Options for tflib.init_tf().

    train.data_dir = data_dir
    train.total_kimg = total_kimg
    train.mirror_augment = mirror_augment
    train.image_snapshot_ticks = train.network_snapshot_ticks = 10
    sched.G_lrate_base = float(
        os.environ['G_LR']) if 'G_LR' in os.environ else 0.002
    sched.D_lrate_base = float(
        os.environ['D_LR']) if 'D_LR' in os.environ else 0.002
    sched.G_lrate_base *= float(
        os.environ['G_LR_MULT']) if 'G_LR_MULT' in os.environ else 1.0
    sched.D_lrate_base *= float(
        os.environ['D_LR_MULT']) if 'D_LR_MULT' in os.environ else 1.0
    G_opt.beta2 = float(
        os.environ['G_BETA2']) if 'G_BETA2' in os.environ else 0.99
    D_opt.beta2 = float(
        os.environ['D_BETA2']) if 'D_BETA2' in os.environ else 0.99
    print('G_lrate: %f' % sched.G_lrate_base)
    print('D_lrate: %f' % sched.D_lrate_base)
    print('G_beta2: %f' % G_opt.beta2)
    print('D_beta2: %f' % D_opt.beta2)
    sched.minibatch_size_base = int(
        os.environ['BATCH_SIZE']) if 'BATCH_SIZE' in os.environ else num_gpus
    sched.minibatch_gpu_base = int(
        os.environ['BATCH_PER']) if 'BATCH_PER' in os.environ else 1
    D_loss.gamma = 10
    metrics = [metric_defaults[x] for x in metrics]
    desc = 'stylegan2'

    desc += '-' + dataset
    resolution = int(
        os.environ['RESOLUTION']) if 'RESOLUTION' in os.environ else 64
    dataset_args = EasyDict(tfrecord_dir=dataset, resolution=resolution)

    assert num_gpus in [
        1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192
    ]
    sc.num_gpus = num_gpus
    desc += '-%dgpu' % num_gpus

    assert config_id in _valid_configs
    desc += '-' + config_id

    # Configs A-E: Shrink networks to match original StyleGAN.
    if config_id != 'config-f':
        G.fmap_base = D.fmap_base = 8 << 10

    if 'FMAP_BASE' in os.environ:
        G.fmap_base = D.fmap_base = int(os.environ['FMAP_BASE']) << 10
    else:
        G.fmap_base = D.fmap_base = 16 << 10  # default

    print('G_fmap_base: %d' % G.fmap_base)
    print('D_fmap_base: %d' % D.fmap_base)

    # Config E: Set gamma to 100 and override G & D architecture.
    if config_id.startswith('config-e'):
        D_loss.gamma = 100
        if 'Gorig' in config_id: G.architecture = 'orig'
        if 'Gskip' in config_id: G.architecture = 'skip'  # (default)
        if 'Gresnet' in config_id: G.architecture = 'resnet'
        if 'Dorig' in config_id: D.architecture = 'orig'
        if 'Dskip' in config_id: D.architecture = 'skip'
        if 'Dresnet' in config_id: D.architecture = 'resnet'  # (default)

    # Configs A-D: Enable progressive growing and switch to networks that support it.
    if config_id in ['config-a', 'config-b', 'config-c', 'config-d']:
        sched.lod_initial_resolution = 8
        sched.G_lrate_base = sched.D_lrate_base = 0.001
        sched.G_lrate_dict = sched.D_lrate_dict = {
            128: 0.0015,
            256: 0.002,
            512: 0.003,
            1024: 0.003
        }
        sched.minibatch_size_base = 32  # (default)
        sched.minibatch_size_dict = {8: 256, 16: 128, 32: 64, 64: 32}
        sched.minibatch_gpu_base = 4  # (default)
        sched.minibatch_gpu_dict = {8: 32, 16: 16, 32: 8, 64: 4}
        G.synthesis_func = 'G_synthesis_stylegan_revised'
        D.func_name = 'training.networks_stylegan2.D_stylegan'

    # Configs A-C: Disable path length regularization.
    if config_id in ['config-a', 'config-b', 'config-c']:
        G_loss = EasyDict(func_name='training.loss.G_logistic_ns')

    # Configs A-B: Disable lazy regularization.
    if config_id in ['config-a', 'config-b']:
        train.lazy_regularization = False

    # Config A: Switch to original StyleGAN networks.
    if config_id == 'config-a':
        G = EasyDict(func_name='training.networks_stylegan.G_style')
        D = EasyDict(func_name='training.networks_stylegan.D_basic')

    if gamma is not None:
        D_loss.gamma = gamma

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(G_args=G,
                  D_args=D,
                  G_opt_args=G_opt,
                  D_opt_args=D_opt,
                  G_loss_args=G_loss,
                  D_loss_args=D_loss)
    kwargs.update(dataset_args=dataset_args,
                  sched_args=sched,
                  grid_args=grid,
                  metric_arg_list=metrics,
                  tf_config=tf_config)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = result_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
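The variant above pulls its overrides from environment variables with explicit `'KEY' in os.environ` checks; a slightly more compact sketch of the same idea using os.environ.get (only a subset of the keys, and the BATCH_SIZE default of num_gpus is stubbed with 1 here):

import os
# Sketch only: equivalent lookups via os.environ.get; values arrive as strings,
# so the float()/int() conversions are still needed.
G_lrate_base = float(os.environ.get('G_LR', 0.002))
D_lrate_base = float(os.environ.get('D_LR', 0.002))
minibatch_size_base = int(os.environ.get('BATCH_SIZE', 1))  # original default: num_gpus
minibatch_gpu_base = int(os.environ.get('BATCH_PER', 1))
print(G_lrate_base, D_lrate_base, minibatch_size_base, minibatch_gpu_base)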
Example #6
def run(dataset,
        data_dir,
        result_dir,
        num_gpus,
        total_kimg,
        mirror_augment,
        metrics,
        resume_pkl,
        model_type='vc_gan2',
        latent_type='uniform',
        batch_size=32,
        batch_per_gpu=16,
        random_seed=1000,
        G_fmap_base=8,
        module_G_list=None,
        G_nf_scale=4,
        E_fmap_base=8,
        module_E_list=None,
        E_nf_scale=4,
        D_fmap_base=9,
        module_D_list=None,
        D_nf_scale=4,
        fmap_decay=0.15,
        fmap_min=16,
        fmap_max=512,
        n_samples_per=10,
        arch='resnet',
        topk_dims_to_show=20,
        hy_beta=1,
        hy_gamma=0,
        hy_1p=0,
        lie_alg_init_type='oth',
        lie_alg_init_scale=0.1,
        G_lrate_base=0.002,
        D_lrate_base=None,
        group_feats_size=400,
        temp=0.67,
        n_discrete=0,
        epsilon=1,
        drange_net=[-1, 1],
        recons_type='bernoulli_loss',
        R_view_scale=1,
        group_feat_type='concat',
        use_sphere_points=False,
        use_learnable_sphere_points=False,
        n_sphere_points=100,
        mapping_after_exp=False,
        snapshot_ticks=10):
    train = EasyDict(
        run_func_name='training.training_loop_gan.training_loop_gan'
    )  # Options for training loop.

    if module_G_list is not None:
        module_G_list = _str_to_list(module_G_list)
        key_G_ls, size_G_ls, count_dlatent_G_size = split_module_names(
            module_G_list)
    if module_E_list is not None:
        module_E_list = _str_to_list(module_E_list)
        key_E_ls, size_E_ls, count_dlatent_E_size = split_module_names(
            module_E_list)
    if module_D_list is not None:
        module_D_list = _str_to_list(module_D_list)
        key_D_ls, size_D_ls, count_dlatent_D_size = split_module_names(
            module_D_list)

    E = EasyDict(func_name='training.gan_networks.E_main_modular',
                 fmap_min=fmap_min,
                 fmap_max=fmap_max,
                 fmap_decay=fmap_decay,
                 latent_size=count_dlatent_E_size,
                 group_feats_size=group_feats_size,
                 module_E_list=module_E_list,
                 nf_scale=E_nf_scale,
                 n_discrete=n_discrete,
                 fmap_base=2 << E_fmap_base)  # Options for encoder network.
    D = EasyDict(
        func_name='training.gan_networks.D_main_modular',
        fmap_min=fmap_min,
        fmap_max=fmap_max,
        fmap_decay=fmap_decay,
        latent_size=count_dlatent_D_size,
        group_feats_size=group_feats_size,
        module_D_list=module_D_list,
        nf_scale=D_nf_scale,
        n_discrete=n_discrete,
        fmap_base=2 << D_fmap_base)  # Options for discriminator network.
    G = EasyDict(func_name='training.gan_networks.G_main_modular',
                 fmap_min=fmap_min,
                 fmap_max=fmap_max,
                 fmap_decay=fmap_decay,
                 latent_size=count_dlatent_G_size,
                 group_feats_size=group_feats_size,
                 module_G_list=module_G_list,
                 nf_scale=G_nf_scale,
                 n_discrete=n_discrete,
                 recons_type=recons_type,
                 lie_alg_init_type=lie_alg_init_type,
                 lie_alg_init_scale=lie_alg_init_scale,
                 R_view_scale=R_view_scale,
                 group_feat_type=group_feat_type,
                 mapping_after_exp=mapping_after_exp,
                 use_sphere_points=use_sphere_points,
                 use_learnable_sphere_points=use_learnable_sphere_points,
                 n_sphere_points=n_sphere_points,
                 fmap_base=2 << G_fmap_base)  # Options for generator network.
    G_opt = EasyDict(beta1=0.9, beta2=0.999,
                     epsilon=1e-8)  # Options for generator optimizer.
    D_opt = EasyDict(beta1=0.9, beta2=0.999,
                     epsilon=1e-8)  # Options for discriminator optimizer.
    desc = model_type + '_modular'

    if model_type == 'so_gan':
        G_loss = EasyDict(
            func_name='training.loss_gan_so.so_gan',
            hy_1p=hy_1p,
            hy_beta=hy_beta,
            latent_type=latent_type,
            recons_type=recons_type)  # Options for generator loss.
        D_loss = EasyDict(
            func_name='training.loss_gan.gan_D',
            latent_type=latent_type)  # Options for discriminator loss.
    else:
        raise ValueError('Unknown model_type: %s' % model_type)

    sched = EasyDict()  # Options for TrainingSchedule.
    grid = EasyDict(
        size='1080p',
        layout='random')  # Options for setup_snapshot_image_grid().
    sc = dnnlib.SubmitConfig()  # Options for dnnlib.submit_run().
    tf_config = {
        'rnd.np_random_seed': random_seed,
        'allow_soft_placement': True
    }  # Options for tflib.init_tf().

    train.data_dir = data_dir
    train.total_kimg = total_kimg
    train.mirror_augment = mirror_augment
    train.image_snapshot_ticks = train.network_snapshot_ticks = snapshot_ticks
    sched.G_lrate_base = G_lrate_base
    sched.D_lrate_base = D_lrate_base
    sched.minibatch_size_base = batch_size
    sched.minibatch_gpu_base = batch_per_gpu
    metrics = [metric_defaults[x] for x in metrics]

    desc += '-' + dataset
    dataset_args = EasyDict(tfrecord_dir=dataset, max_label_size='full')

    assert num_gpus in [1, 2, 4, 8]
    sc.num_gpus = num_gpus
    desc += '-%dgpu' % num_gpus

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(G_args=G,
                  E_args=E,
                  D_args=D,
                  G_opt_args=G_opt,
                  D_opt_args=D_opt,
                  G_loss_args=G_loss,
                  D_loss_args=D_loss,
                  traversal_grid=True)
    kwargs.update(dataset_args=dataset_args,
                  sched_args=sched,
                  grid_args=grid,
                  n_continuous=count_dlatent_G_size,
                  n_discrete=n_discrete,
                  drange_net=drange_net,
                  metric_arg_list=metrics,
                  tf_config=tf_config,
                  resume_pkl=resume_pkl,
                  n_samples_per=n_samples_per,
                  topk_dims_to_show=topk_dims_to_show)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = result_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
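As in the other examples, metrics is a list of metric names mapped through metric_defaults; a minimal sketch of that lookup, assuming the usual StyleGAN2-style mapping from name to config dict (the entry below is a hypothetical stand-in, not the real table):

# Sketch only: metric_defaults behaves as a name -> config mapping.
metric_defaults_stub = {'fid50k': {'name': 'fid50k', 'num_images': 50000}}
metrics = ['fid50k']
metric_arg_list = [metric_defaults_stub[x] for x in metrics]
print(metric_arg_list)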
Example #7
def run(dataset, data_dir, result_dir, config_id, num_gpus, total_kimg, gamma,
        mirror_augment, metrics):
    train = EasyDict(run_func_name='training.training_loop.training_loop'
                     )  # Options for training loop.
    G = EasyDict(func_name='training.networks_stylegan2.G_main'
                 )  # Options for generator network.
    D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2'
                 )  # Options for discriminator network.
    G_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for generator optimizer.
    D_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for discriminator optimizer.
    G_loss = EasyDict(func_name='training.loss.G_logistic_ns_pathreg'
                      )  # Options for generator loss.
    D_loss = EasyDict(func_name='training.loss.D_logistic_r1'
                      )  # Options for discriminator loss.
    sched = EasyDict()  # Options for TrainingSchedule.
    grid = EasyDict(
        size='8k', layout='random')  # Options for setup_snapshot_image_grid().
    sc = dnnlib.SubmitConfig()  # Options for dnnlib.submit_run().
    tf_config = {'rnd.np_random_seed': 1000}  # Options for tflib.init_tf().

    train.data_dir = data_dir
    train.total_kimg = total_kimg
    train.mirror_augment = mirror_augment
    train.image_snapshot_ticks = train.network_snapshot_ticks = 10
    sched.G_lrate_base = sched.D_lrate_base = 0.002
    sched.minibatch_size_base = 32
    sched.minibatch_gpu_base = 4
    D_loss.gamma = 10
    metrics = [metric_defaults[x] for x in metrics]
    desc = 'stylegan2'

    desc += '-' + dataset
    dataset_args = EasyDict(tfrecord_dir=dataset)

    assert num_gpus in [1, 2, 4, 8]
    sc.num_gpus = num_gpus
    desc += '-%dgpu' % num_gpus

    assert config_id in _valid_configs
    desc += '-' + config_id

    # Configs A-E: Shrink networks to match original StyleGAN.
    if config_id not in ['config-f', 'config-l']:
        G.fmap_base = D.fmap_base = 8 << 10

    # Config L: Generator training only
    if config_id == 'config-l':
        # Use labels as latent vector input
        dataset_args.max_label_size = "full"
        # Deactivate methods specific for GAN training
        G.truncation_psi = None
        G.randomize_noise = False
        G.style_mixing_prob = None
        G.dlatent_avg_beta = None
        G.conditional_labels = False
        # Refinement training
        G_loss.func_name = 'training.loss.G_reconstruction'
        train.run_func_name = 'training.training_loop.training_loop_refinement'
        # G.freeze_layers = ["mapping", "noise"]#, "4x4", "8x8", "16x16", "32x32"]
        # Network for refinement
        train.resume_pkl = "nets/stylegan2-ffhq-config-f.pkl"  # TODO init net
        train.resume_with_new_nets = True
        # Maintenance tasks
        sched.tick_kimg_base = 1  # 1 tick = 1000 images (metric update)
        sched.tick_kimg_dict = {}
        train.image_snapshot_ticks = 5  # Save every 5000 images
        train.network_snapshot_ticks = 10  # Save every 10000 images
        # Training parameters
        sched.G_lrate_base = 1e-4
        train.G_smoothing_kimg = 0.0
        sched.minibatch_size_base = sched.minibatch_gpu_base * num_gpus  # 4 per GPU

    # Config E: Set gamma to 100 and override G & D architecture.
    if config_id.startswith('config-e'):
        D_loss.gamma = 100
        if 'Gorig' in config_id: G.architecture = 'orig'
        if 'Gskip' in config_id: G.architecture = 'skip'  # (default)
        if 'Gresnet' in config_id: G.architecture = 'resnet'
        if 'Dorig' in config_id: D.architecture = 'orig'
        if 'Dskip' in config_id: D.architecture = 'skip'
        if 'Dresnet' in config_id: D.architecture = 'resnet'  # (default)

    # Configs A-D: Enable progressive growing and switch to networks that support it.
    if config_id in ['config-a', 'config-b', 'config-c', 'config-d']:
        sched.lod_initial_resolution = 8
        sched.G_lrate_base = sched.D_lrate_base = 0.001
        sched.G_lrate_dict = sched.D_lrate_dict = {
            128: 0.0015,
            256: 0.002,
            512: 0.003,
            1024: 0.003
        }
        sched.minibatch_size_base = 32  # (default)
        sched.minibatch_size_dict = {8: 256, 16: 128, 32: 64, 64: 32}
        sched.minibatch_gpu_base = 4  # (default)
        sched.minibatch_gpu_dict = {8: 32, 16: 16, 32: 8, 64: 4}
        G.synthesis_func = 'G_synthesis_stylegan_revised'
        D.func_name = 'training.networks_stylegan2.D_stylegan'

    # Configs A-C: Disable path length regularization.
    if config_id in ['config-a', 'config-b', 'config-c']:
        G_loss = EasyDict(func_name='training.loss.G_logistic_ns')

    # Configs A-B: Disable lazy regularization.
    if config_id in ['config-a', 'config-b']:
        train.lazy_regularization = False

    # Config A: Switch to original StyleGAN networks.
    if config_id == 'config-a':
        G = EasyDict(func_name='training.networks_stylegan.G_style')
        D = EasyDict(func_name='training.networks_stylegan.D_basic')

    if gamma is not None:
        D_loss.gamma = gamma

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(G_args=G,
                  D_args=D,
                  G_opt_args=G_opt,
                  D_opt_args=D_opt,
                  G_loss_args=G_loss,
                  D_loss_args=D_loss)
    kwargs.update(dataset_args=dataset_args,
                  sched_args=sched,
                  grid_args=grid,
                  metric_arg_list=metrics,
                  tf_config=tf_config)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = result_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
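For the config-l branch above, the snapshot cadence follows from tick_kimg_base; assuming the standard StyleGAN tick semantics (tick_kimg_base thousand images per tick), the intervals work out as:

# Illustration only: snapshot intervals implied by the config-l schedule above.
tick_kimg_base = 1           # thousands of images per tick
image_snapshot_ticks = 5     # -> every 5,000 images
network_snapshot_ticks = 10  # -> every 10,000 images
print(image_snapshot_ticks * tick_kimg_base * 1000,
      network_snapshot_ticks * tick_kimg_base * 1000)  # 5000 10000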