Example #1
import copy

import dnnlib
from dnnlib import EasyDict
from metrics.metric_defaults import metric_defaults

# Recognized config IDs, reconstructed from the checks below: configs A-F plus
# the config-e generator/discriminator architecture variants.
_valid_configs = ['config-%s' % c for c in 'abcdef'] + [
    'config-e-G%s-D%s' % (g, d)
    for g in ['orig', 'skip', 'resnet']
    for d in ['orig', 'skip', 'resnet']]

def run(model, dataset, data_dir, result_dir, config_id, num_gpus, total_kimg,
        gamma, mirror_augment, metrics, max_images, resume_pkl, resume_kimg,
        resume_time, pr):
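    """Configure and submit a StyleGAN2-style training run.

    `model` names a network module under training/ (e.g. one of the
    'networks_stylegan2_*' variants referenced below), `dataset` names a
    TFRecord directory under `data_dir`, and `config_id` is one of the
    StyleGAN2 configs A-F. `gamma` and `pr` optionally override the R1 weight
    and path length regularization; the `resume_*` arguments continue a
    previous run.
    """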
    train     = EasyDict(run_func_name='training.training_loop.training_loop')  # Options for training loop.
    G         = EasyDict(func_name='training.' + model + '.G_main')             # Options for generator network.
    D         = EasyDict(func_name='training.' + model + '.D_stylegan2')        # Options for discriminator network.
    G_opt     = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)                   # Options for generator optimizer.
    D_opt     = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)                   # Options for discriminator optimizer.
    G_loss    = EasyDict(func_name='training.loss.G_logistic_ns_pathreg')       # Options for generator loss.
    D_loss    = EasyDict(func_name='training.loss.D_logistic_r1')               # Options for discriminator loss.
    sched     = EasyDict()                                                      # Options for TrainingSchedule.
    grid      = EasyDict(size='8k', layout='random')                            # Options for setup_snapshot_image_grid().
    sc        = dnnlib.SubmitConfig()                                           # Options for dnnlib.submit_run().
    tf_config = {'rnd.np_random_seed': 1000}                                    # Options for tflib.init_tf().

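    # Run-specific overrides on top of the defaults above.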
    train.data_dir = data_dir
    train.total_kimg = total_kimg
    train.mirror_augment = mirror_augment
    train.image_snapshot_ticks = train.network_snapshot_ticks = 10
    train.resume_pkl = resume_pkl
    train.resume_kimg = resume_kimg
    train.resume_time = resume_time
    sched.G_lrate_base = sched.D_lrate_base = 0.002
    sched.minibatch_size_base = 32
    sched.minibatch_gpu_base = 4
    D_loss.gamma = 10
    metrics = [metric_defaults[x] for x in metrics]
    desc = 'stylegan2'

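    # Dataset: record it in the run description and point the loader at its TFRecord directory.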
    desc += '-' + dataset
    dataset_args = EasyDict(tfrecord_dir=dataset)
    dataset_args.update(max_images=max_images)

    assert num_gpus in [1, 2, 4, 8]
    sc.num_gpus = num_gpus
    desc += '-%dgpu' % num_gpus

    assert config_id in _valid_configs
    desc += '-' + config_id

    # Configs A-E: Shrink networks to match original StyleGAN.
    if config_id != 'config-f':
        G.fmap_base = D.fmap_base = 8 << 10

    # Config E: Set gamma to 100 and override G & D architecture.
    if config_id.startswith('config-e'):
        D_loss.gamma = 100
        if 'Gorig' in config_id: G.architecture = 'orig'
        if 'Gskip' in config_id: G.architecture = 'skip'  # (default)
        if 'Gresnet' in config_id: G.architecture = 'resnet'
        if 'Dorig' in config_id: D.architecture = 'orig'
        if 'Dskip' in config_id: D.architecture = 'skip'
        if 'Dresnet' in config_id: D.architecture = 'resnet'  # (default)

    # Configs A-D: Enable progressive growing and switch to networks that support it.
    if config_id in ['config-a', 'config-b', 'config-c', 'config-d']:
        sched.lod_initial_resolution = 8
        sched.G_lrate_base = sched.D_lrate_base = 0.001
        sched.G_lrate_dict = sched.D_lrate_dict = {
            128: 0.0015,
            256: 0.002,
            512: 0.003,
            1024: 0.003
        }
        sched.minibatch_size_base = 32  # (default)
        sched.minibatch_size_dict = {8: 256, 16: 128, 32: 64, 64: 32}
        sched.minibatch_gpu_base = 4  # (default)
        sched.minibatch_gpu_dict = {8: 32, 16: 16, 32: 8, 64: 4}
        G.synthesis_func = 'G_synthesis_stylegan_revised'
        D.func_name = 'training.networks_stylegan2.D_stylegan'

    # Configs A-C: Disable path length regularization.
    if (config_id in ['config-a', 'config-b', 'config-c'] or
            model in ['networks_stylegan2_att', 'networks_stylegan2_satt', 'networks_stylegan2_base',
                      'networks_stylegan2_resample', 'networks_stylegan2_nonoise']):
        G_loss = EasyDict(func_name='training.loss.G_logistic_ns')

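    # Optional explicit override of path length regularization ('true'/'false' string from the command line).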
    if pr is not None:
        if pr == 'true':
            G_loss = EasyDict(func_name='training.loss.G_logistic_ns_pathreg')
        elif pr == 'false':
            G_loss = EasyDict(func_name='training.loss.G_logistic_ns')

    # Configs A-B: Disable lazy regularization.
    if config_id in ['config-a', 'config-b']:
        train.lazy_regularization = False

    # Config A: Switch to original StyleGAN networks.
    if config_id == 'config-a':
        G = EasyDict(func_name='training.networks_stylegan.G_style')
        D = EasyDict(func_name='training.networks_stylegan.D_basic')

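    # Optional override of the R1 regularization weight.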
    if gamma is not None:
        D_loss.gamma = gamma

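    # DCGAN variants: swap in DCGAN networks, losses, optimizer betas, and training schedule.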
    if 'dcgan' in model:
        sched.G_lrate_base = sched.D_lrate_base = 0.0002
        sched.minibatch_size_base = 128
        sched.minibatch_gpu_base = 32  # (default)
        G.func_name = 'training.dcgan.G_main'
        D.func_name = 'training.dcgan.D_stylegan2'
        train.run_func_name = 'training.dcgan_loop.training_loop'
        G_opt = EasyDict(beta1=0.5, beta2=0.99, epsilon=1e-8)  # Options for generator optimizer.
        D_opt = EasyDict(beta1=0.5, beta2=0.99, epsilon=1e-8)  # Options for discriminator optimizer.
        G_loss.func_name = 'training.loss.G_logistic_ns'
        D_loss.func_name = 'training.loss.D_logistic'

        # G_loss = EasyDict(func_name='training.loss.G_loss_dcgan')
        # D_loss = EasyDict(func_name='training.loss.D_loss_dcgan')  # Options for discriminator loss.
        if 'add' in model:
            G.noise_style = 'add'
        elif 're' in model:
            G.noise_style = 're'
        if 'church' in dataset:
            G.clip_style = 'church'
        elif 'cat' in dataset:
            G.clip_style = 'cat'
        if 'cifar' in dataset:
            G.func_name = 'training.dcgan_cifar10.G_main'
            D.func_name = 'training.dcgan_cifar10.D_stylegan2'
        if 'wgan' in model:
            G_loss.func_name = 'training.loss.G_wgan'
            D_loss.func_name = 'training.loss.D_wgan_gp'

    # if model in ['networks_stylegan2_resample']:
    #     if 'cat' in dataset:
    #         G.clip_style = 'cat'

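    # Assemble all option dicts and submit the run locally via dnnlib.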
    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(G_args=G,
                  D_args=D,
                  G_opt_args=G_opt,
                  D_opt_args=D_opt,
                  G_loss_args=G_loss,
                  D_loss_args=D_loss)
    kwargs.update(dataset_args=dataset_args,
                  sched_args=sched,
                  grid_args=grid,
                  metric_arg_list=metrics,
                  tf_config=tf_config)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = result_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
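
# Example invocation (a minimal sketch; the model, dataset, and directory
# names below are placeholders, not values taken from this repository):
#
#   run(model='networks_stylegan2', dataset='ffhq', data_dir='datasets',
#       result_dir='results', config_id='config-f', num_gpus=1,
#       total_kimg=25000, gamma=None, mirror_augment=True, metrics=['fid50k'],
#       max_images=None, resume_pkl=None, resume_kimg=0.0, resume_time=0.0,
#       pr=None)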