import copy
import json
import os

import dnnlib
from dnnlib import EasyDict
from metrics.metric_defaults import metric_defaults

# `_valid_configs` (the list of supported config-* identifiers) is assumed to be
# defined at module level elsewhere in this file.


def run(g_loss, g_loss_kwargs, d_loss, d_loss_kwargs, dataset_train,
        dataset_eval, data_dir, result_dir, config_id, num_gpus, total_kimg,
        gamma, mirror_augment, metrics, resume_pkl, resume_kimg,
        resume_pkl_dir, max_images, lrate_base, img_ticks, net_ticks,
        skip_images):
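    """Assemble a StyleGAN2 training run from CLI-style arguments and submit it.

    Builds the option EasyDicts consumed by training.training_loop.training_loop(),
    applies the adaptation settings selected by config_id, and hands everything
    to dnnlib.submit_run().
    """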

    # Loss kwargs arrive as JSON strings; parse them into dicts.
    g_loss_kwargs = json.loads(g_loss_kwargs) if g_loss_kwargs else {}
    d_loss_kwargs = json.loads(d_loss_kwargs) if d_loss_kwargs else {}
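    # Illustrative example (function and kwarg names assumed, not taken from the
    # surrounding repo): g_loss='G_logistic_ns_pathreg' with
    # g_loss_kwargs='{"pl_weight": 2.0}' yields
    # G_loss = EasyDict(func_name='training.loss.G_logistic_ns_pathreg', pl_weight=2.0).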

    train = EasyDict(run_func_name='training.training_loop.training_loop')  # Options for training loop.
    G = EasyDict(func_name='training.networks_stylegan2.G_main')  # Options for generator network.
    D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2')  # Options for discriminator network.
    G_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)  # Options for generator optimizer.
    D_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)  # Options for discriminator optimizer.
    G_loss = EasyDict(func_name='training.loss.' + g_loss, **g_loss_kwargs)  # Options for generator loss.
    D_loss = EasyDict(func_name='training.loss.' + d_loss, **d_loss_kwargs)  # Options for discriminator loss.
    sched = EasyDict()  # Options for TrainingSchedule.
    grid = EasyDict(size='1080p', layout='random')  # Options for setup_snapshot_image_grid().
    sc = dnnlib.SubmitConfig()  # Options for dnnlib.submit_run().
    tf_config = {'rnd.np_random_seed': 1000}  # Options for tflib.init_tf().
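    # Each EasyDict above is forwarded verbatim to the training loop via the
    # kwargs.update(...) calls at the bottom of this function.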

    train.total_kimg = total_kimg
    train.mirror_augment = mirror_augment
    train.image_snapshot_ticks = img_ticks
    train.network_snapshot_ticks = net_ticks
    G.scale_func = 'training.networks_stylegan2.apply_identity'
    D.scale_func = None
    sched.G_lrate_base = sched.D_lrate_base = lrate_base  # StyleGAN2 default: 0.002.
    # NOTE: minibatch set to 16 to match the DiffAugment setup.
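    # 16 images per optimizer step, processed 4 at a time on each GPU; the
    # training loop accumulates gradients over multiple passes when
    # num_gpus * minibatch_gpu_base < minibatch_size_base.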
    sched.minibatch_size_base = 16
    sched.minibatch_gpu_base = 4
    D_loss.gamma = 10
    metrics = [metric_defaults[x] for x in metrics]
    desc = 'stylegan2'
    sched.tick_kimg_base = 1
    sched.tick_kimg_dict = {}  # Resolution-specific overrides, e.g. {8: 28, 16: 24, 32: 20, 64: 16, 128: 12, 256: 8, 512: 6, 1024: 4}.
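    # One tick = tick_kimg thousand real images shown; image and network
    # snapshots are written every img_ticks and net_ticks ticks respectively.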

    desc += '-' + dataset_train.split('/')[-1]
    # Get dataset paths
    t_path = dataset_train.split('/')
    e_path = dataset_eval.split('/')
    if len(t_path) > 1:
        dataset_train = t_path[-1]
        train.train_data_dir = os.path.join(data_dir, '/'.join(t_path[:-1]))
    if len(e_path) > 1:
        dataset_eval = e_path[-1]
        train.eval_data_dir = os.path.join(data_dir, '/'.join(e_path[:-1]))
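    # Worked example (paths assumed for illustration): dataset_train='ffhq/obama'
    # with data_dir='datasets' yields train.train_data_dir='datasets/ffhq' and
    # leaves dataset_train='obama' for the tfrecord_dir below.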
    dataset_args = EasyDict(tfrecord_dir=dataset_train)
    # Optionally limit or skip training images (the eval set is left untouched).
    dataset_args['max_images'] = max_images
    if max_images: desc += '-%dimg' % max_images
    dataset_args['skip_images'] = skip_images
    dataset_args_eval = EasyDict(tfrecord_dir=dataset_eval)
    desc += '-' + dataset_eval

    assert num_gpus in [1, 2, 4, 8]
    sc.num_gpus = num_gpus

    assert config_id in _valid_configs
    desc += '-' + config_id

    if mirror_augment: desc += '-aug'

    # Infer a pretrained checkpoint from the target dataset name.
    if not resume_pkl:
        if any(ds in dataset_train.lower()
               for ds in ['obama', 'celeba', 'rem', 'portrait']):
            resume_pkl = 'ffhq-config-f.pkl'
        elif any(ds in dataset_train.lower()
                 for ds in ['gogh', 'temple', 'tower', 'medici', 'bridge']):
            resume_pkl = 'church-config-f.pkl'
        elif 'bus' in dataset_train.lower():
            resume_pkl = 'car-config-f.pkl'
    if resume_pkl:  # os.path.join() would fail if no checkpoint was resolved.
        resume_pkl = os.path.join(resume_pkl_dir, resume_pkl)
    train.resume_pkl = resume_pkl
    train.resume_kimg = resume_kimg
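    # Example (file names as inferred above): dataset_train='100-shot-obama' with
    # resume_pkl='' resolves to os.path.join(resume_pkl_dir, 'ffhq-config-f.pkl'),
    # and training resumes from that checkpoint at kimg=resume_kimg.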

    train.resume_with_new_nets = True  # Recreate with new parameters
    # Adaptive parameters
    if 'ada' in config_id:
        # Freeze the pretrained weights; only variables matching '.*adapt' train.
        G['train_scope'] = D['train_scope'] = '.*adapt'
        if 'ss' in config_id:
            G['adapt_func'] = D['adapt_func'] = 'training.networks_stylegan2.apply_adaptive_scale_shift'
        if 'sv' in config_id or 'pc' in config_id:
            G['map_svd'] = G['syn_svd'] = D['svd'] = True
            # Flatten over the spatial dimensions before the SVD.
            if 'flat' in config_id:
                G['spatial'] = D['spatial'] = True
            # PCA variant: center the weights before taking the SVD.
            if 'pc' in config_id:
                G['svd_center'] = D['svd_center'] = True
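            # svd_config selects which SVD factors are trainable: 'S' adapts the
            # singular values, and 'U'/'V' (appended below) presumably make the
            # left/right singular vectors trainable as well. (Interpretation
            # assumed from the flag names, not confirmed against networks code.)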
            G['svd_config'] = D['svd_config'] = 'S'
            if 'U' in config_id:
                G['svd_config'] += 'U'
                D['svd_config'] += 'U'
            if 'V' in config_id:
                G['svd_config'] += 'V'
                D['svd_config'] += 'V'
    # FreezeD: freeze (part of) the pretrained discriminator during fine-tuning.
    D['freeze'] = 'fd' in config_id
    # DiffAug: swap in the differentiable-augmentation losses (this overrides
    # the g_loss/d_loss arguments parsed above).
    if 'da' in config_id:
        G_loss = EasyDict(func_name='training.loss.G_ns_diffaug')
        D_loss = EasyDict(func_name='training.loss.D_ns_diffaug_r1')

    if gamma is not None:
        D_loss.gamma = gamma

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)

    kwargs.update(
        G_args=G,
        D_args=D,
        G_opt_args=G_opt,
        D_opt_args=D_opt,
        G_loss_args=G_loss,
        D_loss_args=D_loss,
    )
    kwargs.update(dataset_args=dataset_args,
                  dataset_args_eval=dataset_args_eval,
                  sched_args=sched,
                  grid_args=grid,
                  metric_arg_list=metrics,
                  tf_config=tf_config)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = result_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
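
# Minimal sketch of an invocation (all argument values assumed for illustration;
# in the repo this function is normally driven by an argparse CLI wrapper):
#
# run(g_loss='G_logistic_ns', g_loss_kwargs='', d_loss='D_logistic_r1',
#     d_loss_kwargs='', dataset_train='ffhq/obama', dataset_eval='obama',
#     data_dir='datasets', result_dir='results', config_id='config-sv',
#     num_gpus=1, total_kimg=150, gamma=None, mirror_augment=True,
#     metrics=['fid50k'], resume_pkl='', resume_kimg=0,
#     resume_pkl_dir='pickles', max_images=None, lrate_base=0.002,
#     img_ticks=1, net_ticks=10, skip_images=None)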