Esempio n. 1
0
def create_model(data_shape, full=False, labels=None, kwargs_in=None):
    """Create StyleGAN2 network(s) for the given data shape.

    Args:
        data_shape: Sequence whose first item is used as the channel count
            and whose remaining items are passed to calc_init_res().
        full: When exactly True, build G and D and clone Gs from G;
            otherwise build only Gs.
        labels: If not None, stored as `label_size` in the network kwargs.
        kwargs_in: Optional mapping of extra network kwargs, applied after
            `num_channels` and before `label_size`/`resolution`/`init_res`,
            so the latter three always take precedence.

    Returns:
        Tuple (G, D, Gs). G and D are None unless `full` is True.
    """
    init_res, resolution, _ = calc_init_res(data_shape[1:])

    net_kwargs = dnnlib.EasyDict()
    net_kwargs.num_channels = data_shape[0]
    if kwargs_in is not None:
        net_kwargs.update(kwargs_in)
    if labels is not None:
        net_kwargs.label_size = labels
    net_kwargs.resolution = resolution
    net_kwargs.init_res = init_res

    if a.verbose is True:
        print(['%s: %s' % (key, value)
               for key, value in sorted(net_kwargs.items())])

    generator_func = 'training.networks_stylegan2.G_main'

    # Gs-only case: a single generator network, no trainable G/D pair.
    if full is not True:
        Gs = tflib.Network('Gs', func_name=generator_func, **net_kwargs)
        return None, None, Gs

    G = tflib.Network('G', func_name=generator_func, **net_kwargs)
    D = tflib.Network('D',
                      func_name='training.networks_stylegan2.D_stylegan2',
                      **net_kwargs)
    Gs = G.clone('Gs')
    return G, D, Gs
Esempio n. 2
0
def run(dataset, train_dir, config, d_aug, diffaug_policy, cond, ops, jpg_data, mirror, mirror_v, \
        lod_step_kimg, batch_size, resume, resume_kimg, finetune, num_gpus, ema_kimg, gamma, freezeD):
    """Assemble the training options and submit a local run via dnnlib.submit_run().

    Args:
        dataset: Dataset path. Existing 'tfr' files are looked up in its parent
            directory and matched by the dataset basename; if none is found,
            one is created with create_from_images().
        train_dir: Used as the submit config's `run_dir_root`.
        config: Network config, 'e' or 'f'. Anything else prints a message
            and exits.
        d_aug: If truthy, use the DiffAugment training loop and loss;
            otherwise the original training loop and G/D losses.
        diffaug_policy: Passed as `policy` to the DiffAugment loss.
        cond: If truthy, sets `max_label_size='full'` on the dataset options.
        ops: Stored as `impl` on both the G and D options.
        jpg_data: Forwarded to create_from_images() and the dataset options.
        mirror, mirror_v: Stored as `mirror_augment` / `mirror_augment_v`.
        lod_step_kimg: Used for both LOD schedule lengths and, together with
            the resolution log2, for the total training length.
        batch_size: Per-GPU minibatch base size.
        resume: Passed as `resume_pkl`.
        resume_kimg: Passed as `resume_kimg`.
        finetune: If truthy, total length is fixed to 15000 kimg and the
            stepped learning-rate dict (config 'e') is not used.
        num_gpus: Number of GPUs; also scales the total minibatch size.
        ema_kimg: If not None, passed as `G_ema_kimg`.
        gamma: Gamma value for the loss options. Config 'e' falls back to 100
            when None; config 'f' leaves the loss default when None.
        freezeD: If truthy, sets `freezeD` on the D options.
    """
    # Validate the config before doing any work. Previously an unsupported
    # config failed on the missing `sched.G_lrate_base` attribute before this
    # message could ever be printed (and after the dataset had been converted).
    if config not in ('e', 'f'):
        print(' Only configs E and F are implemented')
        exit()

    # dataset (tfrecords) - preprocess or get
    dataset_name = basename(dataset)
    tfr_files = [
        f for f in file_list(os.path.dirname(dataset), 'tfr')
        if dataset_name in f
    ]
    if tfr_files:
        tfr_file = tfr_files[0]
    else:
        tfr_file, _ = create_from_images(dataset, jpg=jpg_data)
    dataset_args = EasyDict(tfrecord=tfr_file, jpg_data=jpg_data)

    desc = basename(tfr_file).split('-')[0]

    # training functions
    if d_aug:  # https://github.com/mit-han-lab/data-efficient-gans
        train = EasyDict(
            run_func_name='training.training_loop_diffaug.training_loop'
        )  # Options for training loop (Diff Augment method)
        loss_args = EasyDict(
            func_name='training.loss_diffaug.ns_DiffAugment_r1',
            policy=diffaug_policy)  # Options for loss (Diff Augment method)
        gamma_args = loss_args  # gamma lives on the combined loss options
    else:  # original nvidia
        train = EasyDict(run_func_name='training.training_loop.training_loop'
                         )  # Options for training loop (original from NVidia)
        G_loss = EasyDict(func_name='training.loss.G_logistic_ns_pathreg'
                          )  # Options for generator loss.
        D_loss = EasyDict(func_name='training.loss.D_logistic_r1'
                          )  # Options for discriminator loss.
        gamma_args = D_loss  # gamma lives on the discriminator loss options

    # network functions
    G = EasyDict(func_name='training.networks_stylegan2.G_main'
                 )  # Options for generator network.
    D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2'
                 )  # Options for discriminator network.
    G_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for generator optimizer.
    D_opt = EasyDict(beta1=0.0, beta2=0.99,
                     epsilon=1e-8)  # Options for discriminator optimizer.
    sched = EasyDict()  # Options for TrainingSchedule.
    grid = EasyDict(
        size='1080p',
        layout='random')  # Options for setup_snapshot_image_grid().
    sc = dnnlib.SubmitConfig()  # Options for dnnlib.submit_run().
    tf_config = {'rnd.np_random_seed': 1000}  # Options for tflib.init_tf().
    G.impl = D.impl = ops

    # resolutions: the last '-'-separated filename token is split on 'x',
    # converted to ints and reversed before being handed to calc_init_res()
    data_res = basename(tfr_file).split('-')[-1].split('x')
    data_res = list(reversed([int(x) for x in data_res]))
    init_res, resolution, res_log2 = calc_init_res(data_res)
    if init_res != [4, 4]:
        print(' custom init resolution', init_res)
    G.init_res = D.init_res = list(init_res)

    train.setname = desc + config
    desc = '%s-%d-%s' % (desc, resolution, config)

    # training schedule
    sched.lod_training_kimg = lod_step_kimg
    sched.lod_transition_kimg = lod_step_kimg
    train.total_kimg = lod_step_kimg * res_log2 * 2  # a la ProGAN
    if finetune:
        train.total_kimg = 15000  # should start from ~10k kimg
    train.image_snapshot_ticks = 1
    train.network_snapshot_ticks = 5
    train.mirror_augment = mirror
    train.mirror_augment_v = mirror_v

    # learning rate: both configs share the same base rate; only a
    # from-scratch config 'e' run additionally uses the stepped schedule
    sched.G_lrate_base = 0.001
    if config == 'e' and not finetune:
        sched.G_lrate_dict = {0: 0.001, 1: 0.0007, 2: 0.0005, 3: 0.0003}
        sched.lrate_step = 1500  # period for stepping to next lrate, in kimg
    sched.D_lrate_base = sched.G_lrate_base  # *2 - not used anyway

    sched.minibatch_gpu_base = batch_size
    sched.minibatch_size_base = num_gpus * sched.minibatch_gpu_base
    sc.num_gpus = num_gpus

    if config == 'e':
        G.fmap_base = D.fmap_base = 8 << 10
        gamma_args.gamma = 100 if gamma is None else gamma
    else:  # config == 'f' (validated at the top)
        G.fmap_base = D.fmap_base = 16 << 10
        # Honour an explicitly requested gamma instead of silently dropping
        # it; with gamma=None the loss function keeps its own default.
        if gamma is not None:
            gamma_args.gamma = gamma

    if cond:
        desc += '-cond'
        dataset_args.max_label_size = 'full'  # conditioned on full label

    if freezeD:
        D.freezeD = True
        train.resume_with_new_nets = True

    if d_aug:
        desc += '-daug'

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(G_args=G, D_args=D, G_opt_args=G_opt, D_opt_args=D_opt)
    kwargs.update(dataset_args=dataset_args,
                  sched_args=sched,
                  grid_args=grid,
                  tf_config=tf_config)
    # NOTE: resume_with_new_nets is forced on for every run here,
    # regardless of freezeD.
    kwargs.update(resume_pkl=resume,
                  resume_kimg=resume_kimg,
                  resume_with_new_nets=True)
    if ema_kimg is not None:
        kwargs.update(G_ema_kimg=ema_kimg)
    if d_aug:
        kwargs.update(loss_args=loss_args)
    else:
        kwargs.update(G_loss_args=G_loss, D_loss_args=D_loss)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = train_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
Esempio n. 3
0
def main():
    """Load a model pickle, optionally rebuild it, and save the result.

    Driven by the module-level options object `a` (presumably parsed
    command-line arguments - confirm against the rest of the file):

    * `a.res` / `a.alpha`: rebuild the networks for another resolution
      and/or with 4 output channels, copying or cropping/padding variables.
    * `a.labels`: rebuild the full model with the given label size.
    * `a.reconstruct`: when none of the above is requested, recreate the
      networks at the same size and copy the variables over.
    * `a.full`: save the (G, D, Gs) triple instead of Gs only. It is also
      switched on whenever a full model has been rebuilt.

    The output file name is derived from the source name plus suffixes for
    the resolution, alpha and label options, and is written to `a.out_dir`.
    """
    tflib.init_tf({'allow_soft_placement': True})

    G_in, D_in, Gs_in = load_pkl(a.source)
    print(' Loading model', a.source, Gs_in.output_shape)
    _, res_in, _ = calc_init_res(Gs_in.output_shape[1:])

    # Default to passing the loaded networks through unchanged. Without this,
    # paths that rebuild nothing (e.g. a.res equal to the source resolution,
    # or a.full without a.reconstruct) reached the save step with unbound
    # G_out / D_out / Gs_out and crashed with a NameError.
    G_out, D_out, Gs_out = G_in, D_in, Gs_in

    if a.res is not None or a.alpha is True:
        if a.res is None:
            a.res = Gs_in.output_shape[2:]
        colors = 4 if a.alpha is True else Gs_in.output_shape[
            1]  # EXPERIMENTAL
        _, res_out, _ = calc_init_res([colors, *a.res])

        if res_in != res_out or a.alpha is True:  # add or remove layers
            assert G_in is not None and D_in is not None, " !! G/D subnets not found in source model !!"
            data_shape = [colors, res_out, res_out]
            print(' Reconstructing full model with shape', data_shape)
            G_out, D_out, Gs_out = create_model(data_shape, True, 0,
                                                Gs_in.static_kwargs)
            copy_vars(Gs_in, Gs_out)
            copy_vars(G_in, G_out)
            copy_vars(D_in, D_out, D=True)
            # The rebuilt networks become the input for any following step.
            G_in, D_in, Gs_in = G_out, D_out, Gs_out
            a.full = True

        if a.res[0] != res_out or a.res[1] != res_out:  # crop or pad layers
            data_shape = [colors, *a.res]
            G_out, D_out, Gs_out = create_model(data_shape, True, 0,
                                                Gs_in.static_kwargs)
            if G_in is not None and D_in is not None:
                print(' Reconstructing full model with shape', data_shape)
                copy_and_crop_or_pad_trainables(G_in, G_out)
                copy_and_crop_or_pad_trainables(D_in, D_out)
                G_in, D_in = G_out, D_out
                a.full = True
            else:
                print(' Reconstructing Gs model with shape', data_shape)
            copy_and_crop_or_pad_trainables(Gs_in, Gs_out)
            Gs_in = Gs_out

    if a.labels is not None:
        assert G_in is not None and D_in is not None, " !! G/D subnets not found in source model !!"
        print(' Reconstructing full model with labels', a.labels)
        data_shape = Gs_in.output_shape[1:]
        G_out, D_out, Gs_out = create_model(data_shape, True, a.labels,
                                            Gs_in.static_kwargs)
        if a.verbose is True:
            D_out.print_layers()
        if a.verbose is True:
            G_out.print_layers()
        copy_and_fill_trainables(G_in, G_out)
        copy_and_fill_trainables(D_in, D_out)
        copy_and_fill_trainables(Gs_in, Gs_out)
        a.full = True

    # No structural change requested: optionally recreate the networks at the
    # same size; otherwise the pass-through defaults set above are saved.
    no_rebuild_requested = (a.labels is None and a.res is None
                            and a.alpha is not True)
    if no_rebuild_requested and a.reconstruct is True:
        print(' Reconstructing model with same size /',
              'full' if a.full else 'Gs')
        data_shape = Gs_in.output_shape[1:]
        G_out, D_out, Gs_out = create_model(data_shape, a.full, 0,
                                            Gs_in.static_kwargs)
        Gs_out.copy_vars_from(Gs_in)
        if a.full is True and G_in is not None and D_in is not None:
            G_out.copy_vars_from(G_in)
            D_out.copy_vars_from(D_in)

    out_name = basename(a.source)
    if a.res is not None:
        out_name += '-%dx%d' % (a.res[1], a.res[0])
    if a.alpha is True:
        out_name += 'a'
    if a.labels is not None:
        out_name += '-c%d' % a.labels

    # Only write the full triple when G and D actually exist; a Gs-only
    # source passed through unchanged has nothing but Gs to save.
    if a.full is True and G_out is not None and D_out is not None:
        save_pkl((G_out, D_out, Gs_out),
                 os.path.join(a.out_dir, '%s.pkl' % out_name))
    else:
        if a.full is True:
            print(' G/D subnets not found in source model, saving Gs only')
        save_pkl(Gs_out, os.path.join(a.out_dir, '%s-Gs.pkl' % out_name))

    print(' Done')