Beispiel #1
0
def run(opt):
    """Assemble the configuration for a ProgressiveGAN training job and submit it.

    Builds the option dictionaries for the training loop, the generator and
    discriminator networks, their optimizers and losses, the training
    schedule, the snapshot image grid and the Hessian Penalty from ``opt``;
    optionally resumes from an earlier checkpoint; optionally initializes
    WandB; and finally passes everything to ``dnnlib.submit_run``.

    Args:
        opt: Options object. Attributes read here (some only on the resume or
            WandB paths): ``seed``, ``metrics``, ``compute_metrics_ticks``,
            ``compute_interp_ticks``, ``dataset``, ``resume_exp``,
            ``resume_snapshot``, ``hp_lambda``, ``hp_start_kimg``,
            ``resolution``, ``nz``, ``infogan_nz``, ``infogan_lambda``,
            ``layers_to_reg``, ``epsilon``, ``num_rademacher_samples``,
            ``warmup_kimg``, ``total_kimg``, ``num_gpus``, ``dashboard_api``
            and ``wandb_entity``.

    Side effects:
        ``opt.hp_start_kimg`` is increased in place by the resumed
        checkpoint's kimg when the checkpoint path does not contain
        ``hessian_penalty_<dataset>`` and ``opt.hp_lambda > 0``.

    Raises:
        KeyError: If ``opt.layers_to_reg`` contains an index outside 1..13.
    """
    # Description string included in result subdir name.
    desc = build_job_name(opt)

    # Options for training loop.
    train = EasyDict(run_func_name='training.training_loop.training_loop')
    # Options for generator / discriminator networks.
    G = EasyDict(func_name='training.networks_progan.G_paper')
    D = EasyDict(func_name='training.networks_progan.D_paper')
    # Options for generator / discriminator optimizers.
    G_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)
    D_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)
    # Options for generator / discriminator losses.
    G_loss = EasyDict(func_name='training.loss.G_wgan')
    D_loss = EasyDict(func_name='training.loss.D_wgan_gp')
    # Options for TrainingSchedule.
    sched = EasyDict()
    # Options for setup_snapshot_image_grid().
    grid = EasyDict(size='1080p', layout='random')
    # Options for dnnlib.submit_run().
    submit_config = dnnlib.SubmitConfig()
    # Options for tflib.init_tf().
    tf_config = {'rnd.np_random_seed': opt.seed}

    # Metrics to run during training, selected by tag from opt.metrics.
    metrics = []
    if 'FID' in opt.metrics:
        metrics.append(metric_base.fid50k)
    if 'PPL' in opt.metrics:
        metrics.append(metric_base.ppl_zend_v2)
    train.network_metric_ticks = opt.compute_metrics_ticks
    train.interp_snapshot_ticks = opt.compute_interp_ticks

    find_dataset(opt.dataset)

    # Optionally resume from checkpoint:
    if opt.resume_exp is not None:
        results_root = os.path.join(os.getcwd(), config.result_dir)
        snapshot_pkl = find_pkl(results_root, opt.resume_exp,
                                opt.resume_snapshot)
        train.resume_run_id = opt.resume_exp
        train.resume_snapshot = snapshot_pkl
        # The six characters immediately before '.pkl' are parsed as the
        # checkpoint's kimg counter.
        snapshot_stem = snapshot_pkl.split('.pkl')[0]
        train.resume_kimg = int(snapshot_stem[-6:])
        trained_with_hp = f'hessian_penalty_{opt.dataset}' in snapshot_pkl
        if not trained_with_hp and opt.hp_lambda > 0:
            print(
                'When fine-tuning a job that was originally trained without the Hessian Penalty, '
                'hp_start_kimg is relative to the kimg of the checkpoint being resumed from. '
                'Hessian Penalty will be phased-in starting at absolute '
                f'kimg={opt.hp_start_kimg + train.resume_kimg}.')
            # Shift the start point so it becomes an absolute kimg value.
            opt.hp_start_kimg += train.resume_kimg

    # Set up dataset hyper-parameters:
    tfrecord_dir = os.path.join(os.getcwd(), config.data_dir, opt.dataset)
    dataset = EasyDict(tfrecord_dir=tfrecord_dir, resolution=opt.resolution)
    train.mirror_augment = False

    # Set up network hyper-parameters:
    G.latent_size = opt.nz
    D.infogan_nz = opt.infogan_nz
    for net in (G, D):
        net.infogan_lambda = opt.infogan_lambda

    # The multi-layer Hessian Penalty retrieves intermediate activations of G
    # by tensor name. The names below are listed in layer order; layer
    # indices are 1-based positions in this tuple.
    activation_type = 'norm'
    generator_layer_names = (
        f'4x4/Dense/Post_{activation_type}',
        f'4x4/Conv/Post_{activation_type}',
        f'8x8/Conv0_up/Post_{activation_type}',
        f'8x8/Conv1/Post_{activation_type}',
        f'16x16/Conv0_up/Post_{activation_type}',
        f'16x16/Conv1/Post_{activation_type}',
        f'32x32/Conv0_up/Post_{activation_type}',
        f'32x32/Conv1/Post_{activation_type}',
        f'64x64/Conv0_up/Post_{activation_type}',
        f'64x64/Conv1/Post_{activation_type}',
        f'128x128/Conv0_up/Post_{activation_type}',
        f'128x128/Conv1/Post_{activation_type}',
        'images_out',  # final full-resolution RGB activation
    )
    progan_generator_layer_index_to_name = dict(
        enumerate(generator_layer_names, start=1))

    # Convert from layer indices to layer names (needed to compute the
    # Hessian Penalty), in ascending index order:
    layers_to_reg = []
    for layer_ix in sorted(opt.layers_to_reg):
        layers_to_reg.append(progan_generator_layer_index_to_name[layer_ix])

    # Store the Hessian Penalty parameters in their own dictionary
    # (kimg values are converted to image counts):
    HP = EasyDict(
        hp_lambda=opt.hp_lambda,
        epsilon=opt.epsilon,
        num_rademacher_samples=opt.num_rademacher_samples,
        layers_to_reg=layers_to_reg,
        warmup_nimg=opt.warmup_kimg * 1000,
        hp_start_nimg=opt.hp_start_kimg * 1000,
    )

    # How long to train for (as measured by thousands of real images
    # processed, not gradient steps):
    train.total_kimg = opt.total_kimg

    # The original experiments used 4 GPUs per job. For any other GPU count
    # the per-resolution batch sizes are scaled proportionally. Note that
    # using other batch sizes is somewhat untested, though!
    submit_config.num_gpus = opt.num_gpus
    sched.minibatch_base = 32
    four_gpu_batch_sizes = {
        4: 2048,
        8: 1024,
        16: 512,
        32: 256,
        64: 128,
        128: 96,
        256: 32,
        512: 16,
    }
    sched.minibatch_dict = {
        res: int(batch_size * opt.num_gpus / 4)
        for res, batch_size in four_gpu_batch_sizes.items()
    }

    # Set-up WandB if optionally using it instead of TensorBoard:
    if opt.dashboard_api == 'wandb':
        init_wandb(opt=opt,
                   name=desc,
                   group=opt.dataset,
                   entity=opt.wandb_entity)

    # Gather every option dictionary into the keyword arguments of the run:
    kwargs = EasyDict(train)
    kwargs.update(
        HP_args=HP,
        G_args=G,
        D_args=D,
        G_opt_args=G_opt,
        D_opt_args=D_opt,
        G_loss_args=G_loss,
        D_loss_args=D_loss,
        dataset_args=dataset,
        sched_args=sched,
        grid_args=grid,
        metric_arg_list=metrics,
        tf_config=tf_config,
    )

    # Work on a private copy of the submit config so the template object
    # created above is not modified any further.
    run_submit_config = copy.deepcopy(submit_config)
    kwargs.submit_config = run_submit_config
    run_submit_config.run_dir_root = dnnlib.submission.submit.get_template_from_path(
        config.result_dir)
    run_submit_config.run_dir_ignore += config.run_dir_ignore
    run_submit_config.run_desc = desc

    # Start the training job:
    dnnlib.submit_run(**kwargs)