Beispiel #1
0
def run(dataset, data_dir, result_dir, num_gpus, total_kimg, mirror_augment,
        metrics, resume, resume_with_new_nets, disable_style_mod,
        disable_cond_mod):
    """Assemble the co-mod-gan training options and submit a local run.

    Builds one option container per consumer (training loop, networks,
    optimizers, losses, schedule, snapshot grid), bundles them into the
    keyword arguments of ``dnnlib.submit_run`` and launches the run.

    Args:
        dataset: Dataset path; forwarded as ``tfrecord_dir`` and its
            basename becomes part of the run description.
        data_dir: Stored as ``data_dir`` in the training-loop options.
        result_dir: Used as ``run_dir_root`` of the submit config.
        num_gpus: Number of GPUs; must be one of 1, 2, 4 or 8.
        total_kimg: Stored as ``total_kimg`` in the training-loop options.
        mirror_augment: Stored as ``mirror_augment`` in the training-loop
            options.
        metrics: Iterable of keys looked up in ``metric_defaults``.
        resume: Snapshot path to resume from, or ``None``. When given, the
            starting kimg is parsed from the text after the last ``-`` of
            the basename (with ``.pkl`` removed).
        resume_with_new_nets: Forwarded unchanged to the training loop.
        disable_style_mod: If truthy, sets ``style_mod=False`` on the
            generator options.
        disable_cond_mod: If truthy, sets ``cond_mod=False`` on the
            generator options.

    Raises:
        AssertionError: If ``num_gpus`` is not 1, 2, 4 or 8.
        KeyError: If an entry of ``metrics`` is missing from
            ``metric_defaults``.
        ValueError: If the trailing token of the ``resume`` file name is
            not an integer.
    """
    # --- Option containers (created in a fixed order) ----------------------
    # Options for training loop.
    train_opts = EasyDict(
        run_func_name='training.training_loop.training_loop')
    # Options for generator / discriminator networks.
    gen_net = EasyDict(func_name='training.co_mod_gan.G_main')
    disc_net = EasyDict(func_name='training.co_mod_gan.D_co_mod_gan')
    # Options for generator / discriminator optimizers.
    gen_optim = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)
    disc_optim = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)
    # Options for generator / discriminator losses.
    gen_loss = EasyDict(func_name='training.loss.G_masked_logistic_ns_l1')
    disc_loss = EasyDict(func_name='training.loss.D_masked_logistic_r1')
    # Options for TrainingSchedule.
    schedule = EasyDict()
    # Options for setup_snapshot_image_grid().
    snapshot_grid = EasyDict(size='8k', layout='random')
    # Options for dnnlib.submit_run().
    submit_cfg = dnnlib.SubmitConfig()
    # Options for tflib.init_tf().
    tf_options = {'rnd.np_random_seed': 1000}

    # --- Training loop, schedule and loss settings -------------------------
    train_opts.data_dir = data_dir
    train_opts.total_kimg = total_kimg
    train_opts.mirror_augment = mirror_augment
    train_opts.image_snapshot_ticks = 10
    train_opts.network_snapshot_ticks = 10

    schedule.G_lrate_base = 0.002
    schedule.D_lrate_base = 0.002
    schedule.minibatch_size_base = 32
    schedule.minibatch_gpu_base = 4

    disc_loss.gamma = 10

    # Translate metric names into their default argument sets.
    metric_args = [metric_defaults[metric_name] for metric_name in metrics]

    # --- Run description and dataset ---------------------------------------
    run_desc = 'co-mod-gan-' + os.path.basename(dataset)
    dataset_opts = EasyDict(tfrecord_dir=dataset)

    assert num_gpus in (1, 2, 4, 8)
    submit_cfg.num_gpus = num_gpus
    run_desc = '%s-%dgpu' % (run_desc, num_gpus)

    # --- Resume point ------------------------------------------------------
    if resume is None:
        resume_kimg = 0
    else:
        # The kimg counter is the last dash-separated token of the file name.
        snapshot_stem = os.path.basename(resume).replace('.pkl', '')
        resume_kimg = int(snapshot_stem.rsplit('-', 1)[-1])

    # --- Optional generator switches ---------------------------------------
    if disable_style_mod:
        gen_net.style_mod = False
    if disable_cond_mod:
        gen_net.cond_mod = False

    # --- Submission --------------------------------------------------------
    submit_cfg.submit_target = dnnlib.SubmitTarget.LOCAL
    submit_cfg.local.do_not_copy_source_files = True

    run_kwargs = EasyDict(train_opts)
    run_kwargs.update(
        G_args=gen_net,
        D_args=disc_net,
        G_opt_args=gen_optim,
        D_opt_args=disc_optim,
        G_loss_args=gen_loss,
        D_loss_args=disc_loss,
        dataset_args=dataset_opts,
        sched_args=schedule,
        grid_args=snapshot_grid,
        metric_arg_list=metric_args,
        tf_config=tf_options,
        resume_pkl=resume,
        resume_kimg=resume_kimg,
        resume_with_new_nets=resume_with_new_nets,
    )

    # Submit a private copy of the config so the per-run fields are set on
    # the object that is actually handed to dnnlib.
    run_kwargs.submit_config = copy.deepcopy(submit_cfg)
    final_cfg = run_kwargs.submit_config
    final_cfg.run_dir_root = result_dir
    final_cfg.run_desc = run_desc

    dnnlib.submit_run(**run_kwargs)