Beispiel #1
0
        8: 128,
        16: 128,
        32: 64,
        64: 32,
        128: 16,
        256: 8,
        512: 4
    }
    #desc += '-2gpu'; submit_config.num_gpus = 2; sched.minibatch_base = 8; sched.minibatch_dict = {4: 256, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8}
    #desc += '-4gpu'; submit_config.num_gpus = 4; sched.minibatch_base = 16; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16}
    #desc += '-8gpu'; submit_config.num_gpus = 8; sched.minibatch_base = 32; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32}

    # Default options.
    train.total_kimg = 25000
    sched.lod_initial_resolution = 8
    # Generator learning-rate table (presumably keyed by resolution -- confirm
    # against TrainingSchedule). The discriminator table is built from the
    # same entries.
    sched.G_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
    sched.D_lrate_dict = EasyDict(sched.G_lrate_dict)

    # WGAN-GP loss for CelebA-HQ.
    # Swaps in the WGAN / WGAN-GP loss functions and tags the run description.
    desc += '-wgangp'
    G_loss = EasyDict(func_name='training.loss.G_wgan')
    D_loss = EasyDict(func_name='training.loss.D_wgan_gp')
    # Clamp every generator learning-rate entry to at most 0.002 ...
    sched.G_lrate_dict = {
        k: min(v, 0.002)
        for k, v in sched.G_lrate_dict.items()
    }
    # ... and rebuild the discriminator table from the clamped values, replacing
    # the table assigned under "Default options".
    sched.D_lrate_dict = EasyDict(sched.G_lrate_dict)

    # Table 1.
    #desc += '-tuned-baseline'; G.use_styles = False; G.use_pixel_norm = True; G.use_instance_norm = False; G.mapping_layers = 0; G.truncation_psi = None; G.const_input_layer = False; G.style_mixing_prob = 0.0; G.use_noise = False
    #desc += '-add-mapping-and-styles'; G.const_input_layer = False; G.style_mixing_prob = 0.0; G.use_noise = False
Beispiel #2
0
        256: 8,
        512: 4
    }
    # desc += '-1gpu'; submit_config.num_gpus = 1; sched.minibatch_base = 256; sched.minibatch_dict = {4: 512, 8: 512, 16: 256, 32: 128, 64: 64, 128: 32, 256: 8, 512: 4}
    #desc += '-2gpu'; submit_config.num_gpus = 2; sched.minibatch_base = 8; sched.minibatch_dict = {4: 256, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8}
    #desc += '-4gpu'; submit_config.num_gpus = 4; sched.minibatch_base = 16; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16}
    #desc += '-8gpu'; submit_config.num_gpus = 8; sched.minibatch_base = 32; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32}

    # Default options.
    train.total_kimg = 25000
    sched.lod_initial_resolution = 8
    # The stock table (kept commented out below) only has entries for keys
    # 128..1024; the active table instead defines entries for keys 4..128.
    # sched.G_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
    sched.G_lrate_dict = {
        4: 0.01,
        8: 0.01,
        16: 0.005,
        32: 0.005,
        64: 0.003,
        128: 0.002
    }
    # The discriminator table is built from the generator's entries.
    sched.D_lrate_dict = EasyDict(sched.G_lrate_dict)

    # WGAN-GP loss for CelebA-HQ.
    #desc += '-wgangp'; G_loss = EasyDict(func_name='training.loss.G_wgan'); D_loss = EasyDict(func_name='training.loss.D_wgan_gp'); sched.G_lrate_dict = {k: min(v, 0.002) for k, v in sched.G_lrate_dict.items()}; sched.D_lrate_dict = EasyDict(sched.G_lrate_dict)

    # Table 1.
    #desc += '-tuned-baseline'; G.use_styles = False; G.use_pixel_norm = True; G.use_instance_norm = False; G.mapping_layers = 0; G.truncation_psi = None; G.const_input_layer = False; G.style_mixing_prob = 0.0; G.use_noise = False
    #desc += '-add-mapping-and-styles'; G.const_input_layer = False; G.style_mixing_prob = 0.0; G.use_noise = False
    #desc += '-remove-traditional-input'; G.style_mixing_prob = 0.0; G.use_noise = False
    #desc += '-add-noise-inputs'; G.style_mixing_prob = 0.0
    #desc += '-mixing-regularization' # default
Beispiel #3
0
    # Dataset.
    desc += '-r';     dataset = EasyDict(tfrecord_dir='smalls');              train.mirror_augment = True
    #desc += '-celebahq'; dataset = EasyDict(tfrecord_dir='celebahq');          train.mirror_augment = True
    #desc += '-bedroom';  dataset = EasyDict(tfrecord_dir='lsun-bedroom-full'); train.mirror_augment = False
    #desc += '-car';      dataset = EasyDict(tfrecord_dir='lsun-car-512x384');  train.mirror_augment = False
    #desc += '-cat';      dataset = EasyDict(tfrecord_dir='lsun-cat-full');     train.mirror_augment = False

    # Number of GPUs.
    desc += '-1gpu'; submit_config.num_gpus = 1; sched.minibatch_base = 4; sched.minibatch_dict = {4: 128, 8: 128, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8, 512: 4}
    #desc += '-2gpu'; submit_config.num_gpus = 2; sched.minibatch_base = 8; sched.minibatch_dict = {4: 256, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8}
    #desc += '-4gpu'; submit_config.num_gpus = 4; sched.minibatch_base = 16; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16}
    #desc += '-8gpu'; submit_config.num_gpus = 8; sched.minibatch_base = 32; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32}

    # Default options.
    sched.lod_initial_resolution = 8
    sched.G_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
    # The discriminator table is built from the generator's entries.
    sched.D_lrate_dict = EasyDict(sched.G_lrate_dict)
    # NOTE(review): the active dataset line above already sets
    # train.mirror_augment = True; this assignment also overrides whatever a
    # different dataset line would set (the LSUN presets use False).
    train.mirror_augment = True
    train.total_kimg = 10000
	
    resume_run_id = None     # Run ID or network pkl to resume training from, None = start from scratch.
    resume_kimg  = 0      # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time = 0.0     # Assumed wallclock time at the beginning. Affects reporting.
    
    #resume_run_id = "/home/rubenfb14/stylegan/results/00001-sgan-r-1gpu/network-snapshot-002360.pkl"     # Run ID or network pkl to resume training from, None = start from scratch.
    #resume_kimg  = 2360      # Assumed training progress at the beginning. Affects reporting and training schedule.
    #resume_time = 3.11     # Assumed wallclock time at the beginning. Affects reporting.
    # WGAN-GP loss for CelebA-HQ.
    #desc += '-wgangp'; G_loss = EasyDict(func_name='training.loss.G_wgan'); D_loss = EasyDict(func_name='training.loss.D_wgan_gp'); sched.G_lrate_dict = {k: min(v, 0.002) for k, v in sched.G_lrate_dict.items()}; sched.D_lrate_dict = EasyDict(sched.G_lrate_dict)

    # Table 1.
Beispiel #4
0
    dataset       = EasyDict()                                                             # Options for load_dataset().
    sched         = EasyDict()                                                             # Options for TrainingSchedule.
    grid          = EasyDict(size='4k', layout='random')                                   # Options for setup_snapshot_image_grid().
    metrics       = [metric_base.fid50k]                                                   # Options for MetricGroup.
    submit_config = dnnlib.SubmitConfig()                                                  # Options for dnnlib.submit_run().
    tf_config     = {'rnd.np_random_seed': 1000}                                           # Options for tflib.init_tf().

    # Dataset.
    desc += '-ffhq';     dataset = EasyDict(tfrecord_dir='stylegan');              train.mirror_augment = True
    #desc += '-celebahq'; dataset = EasyDict(tfrecord_dir='celebahq');          train.mirror_augment = True
    #desc += '-bedroom';  dataset = EasyDict(tfrecord_dir='lsun-bedroom-full'); train.mirror_augment = False
    #desc += '-car';      dataset = EasyDict(tfrecord_dir='lsun-car-512x384');  train.mirror_augment = False
    #desc += '-cat';      dataset = EasyDict(tfrecord_dir='lsun-cat-full');     train.mirror_augment = False

    # Config presets from Progressive GAN (choose one).
    desc += '-1gpu'; submit_config.num_gpus = 1; sched.minibatch_base = 4; sched.minibatch_dict = {4: 128, 8: 128, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8, 512: 4}; sched.G_lrate_dict = {1024: 0.0015}; sched.D_lrate_dict = EasyDict(sched.G_lrate_dict); train.total_kimg = 12000
    #desc += '-2gpu'; submit_config.num_gpus = 2; sched.minibatch_base = 8; sched.minibatch_dict = {4: 256, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8}; sched.G_lrate_dict = {512: 0.0015, 1024: 0.002}; sched.D_lrate_dict = EasyDict(sched.G_lrate_dict); train.total_kimg = 12000
    #desc += '-4gpu'; submit_config.num_gpus = 4; sched.minibatch_base = 16; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16}; sched.G_lrate_dict = {256: 0.0015, 512: 0.002, 1024: 0.003}; sched.D_lrate_dict = EasyDict(sched.G_lrate_dict); train.total_kimg = 12000
    #desc += '-8gpu'; submit_config.num_gpus = 8; sched.minibatch_base = 32; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32}; sched.G_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}; sched.D_lrate_dict = EasyDict(sched.G_lrate_dict); train.total_kimg = 12000

    # Tuned config for StyleGAN.
    # NOTE: this runs after the GPU preset line, so total_kimg = 25000 replaces
    # the total_kimg = 12000 assigned by that preset.
    train.total_kimg = 25000; sched.lod_initial_resolution = 8

    # WGAN-GP loss for CelebA-HQ.
    #desc += '-wgangp'; G_loss = EasyDict(func_name='training.loss.G_wgan'); D_loss = EasyDict(func_name='training.loss.D_wgan_gp'); sched.G_lrate_dict = {k: min(v, 0.002) for k, v in sched.G_lrate_dict.items()}; sched.D_lrate_dict = EasyDict(sched.G_lrate_dict)

    # Table 1.
    #desc += '-tuned-baseline'; G.use_styles = False; G.use_pixel_norm = True; G.use_instance_norm = False; G.mapping_layers = 0; G.truncation_psi = None; G.const_input_layer = False; G.style_mixing_prob = 0.0; G.use_noise = False
    #desc += '-add-mapping-and-styles'; G.const_input_layer = False; G.style_mixing_prob = 0.0; G.use_noise = False
    #desc += '-remove-traditional-input'; G.style_mixing_prob = 0.0; G.use_noise = False
    #desc += '-add-noise-inputs'; G.style_mixing_prob = 0.0
Beispiel #5
0
def run(data, train_dir, config, d_aug, diffaug_policy, cond, ops, mirror, mirror_v,
        kimg, batch_size, lrate, resume, resume_kimg, num_gpus, ema_kimg, gamma, freezeD):
    """Assemble the options for a StyleGAN2 training run and submit it locally.

    Args:
        data: Path of the image data. Its basename is matched against the
            prefix (text before the first '-') of the tfrecord files found
            in its parent directory, and is used in the run description.
        train_dir: Stored as ``submit_config.run_dir_root``.
        config: ``'e'`` or ``'f'``; selects the learning-rate settings and
            ``fmap_base`` of both networks.
        d_aug: Truthy selects the DiffAugment training loop and loss instead
            of the original NVidia loop and losses.
        diffaug_policy: Passed as ``policy`` to the DiffAugment loss options.
        cond: ``True`` builds a missing tfrecord from image folders; any
            truthy value tags the run '-cond' and sets
            ``dataset_args.max_label_size = 'full'``.
        ops: Stored as ``impl`` on both network option sets.
        mirror: Stored as ``train.mirror_augment``.
        mirror_v: Stored as ``train.mirror_augment_v``.
        kimg: Stored as ``train.total_kimg``; also selects the snapshot and
            tick intervals.
        batch_size: Per-GPU minibatch size; ``None`` means ``4096 // resolution``.
        lrate: Generator learning rate, used for config ``'f'`` only.
        resume: Forwarded as ``resume_pkl``.
        resume_kimg: Forwarded as ``resume_kimg``.
        num_gpus: Number of GPUs; scales the total minibatch size.
        ema_kimg: Forwarded as ``G_ema_kimg`` unless ``None``.
        gamma: Loss ``gamma`` for config ``'e'`` (``None`` means 100); not
            used for config ``'f'``.
        freezeD: Truthy sets ``D.freezeD = True``.

    Raises:
        SystemExit: If ``config`` is neither ``'e'`` nor ``'f'``.
    """
    # Validate the config before doing any work. Previously this check sat
    # after the learning-rate setup, which reads sched.G_lrate_base; for an
    # unsupported config that attribute was never assigned, so the read failed
    # before the message below could be printed (and only after the dataset
    # had already been created/loaded).
    if config not in ('e', 'f'):
        print(' Only configs E and F are implemented')
        raise SystemExit(1)

    # training functions
    if d_aug:  # https://github.com/mit-han-lab/data-efficient-gans
        # Options for training loop (Diff Augment method)
        train = EasyDict(run_func_name='training.training_loop_diffaug.training_loop')
        # Options for loss (Diff Augment method)
        loss_args = EasyDict(func_name='training.loss_diffaug.ns_DiffAugment_r1',
                             policy=diffaug_policy)
    else:  # original nvidia
        # Options for training loop (original from NVidia)
        train = EasyDict(run_func_name='training.training_loop.training_loop')
        # Options for generator loss.
        G_loss = EasyDict(func_name='training.loss.G_logistic_ns_pathreg')
        # Options for discriminator loss.
        D_loss = EasyDict(func_name='training.loss.D_logistic_r1')

    # network functions
    G = EasyDict(func_name='training.networks_stylegan2.G_main')  # Options for generator network.
    D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2')  # Options for discriminator network.
    G_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)  # Options for generator optimizer.
    D_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)  # Options for discriminator optimizer.
    sched = EasyDict()  # Options for TrainingSchedule.
    grid = EasyDict(size='1080p', layout='random')  # Options for setup_snapshot_image_grid().
    sc = dnnlib.SubmitConfig()  # Options for dnnlib.submit_run().
    tf_config = {'rnd.np_random_seed': 1000}  # Options for tflib.init_tf().
    G.impl = D.impl = ops

    # dataset (tfrecords) - get or create
    data_name = basename(data)
    tfr_files = [
        f for f in file_list(os.path.dirname(data), 'tfr')
        if data_name == basename(f).split('-')[0]
    ]
    # A missing or zero-byte first match means the tfrecord must be (re)built.
    if not tfr_files or os.stat(tfr_files[0]).st_size == 0:
        if cond is True:
            tfr_file, _ = create_from_image_folders(data)
        else:
            tfr_file, _ = create_from_images(data)
    else:
        tfr_file = tfr_files[0]
    dataset_args = EasyDict(tfrecord=tfr_file)

    # resolutions: open the dataset once, only to read its geometry
    with tf.Graph().as_default(), tflib.create_session().as_default():  # pylint: disable=not-context-manager
        dataset_obj = dataset.load_dataset(**dataset_args)
        resolution = dataset_obj.resolution
        init_res = dataset_obj.init_res
        res_log2 = dataset_obj.res_log2  # read but not used further in this function
        dataset_obj.close()
        dataset_obj = None

    if list(init_res) == [4, 4]:
        desc = '%s-%d' % (data_name, resolution)
    else:
        print(' custom init resolution', init_res)
        desc = basename(tfr_file)
    G.init_res = D.init_res = list(init_res)

    # savenames is derived from desc before the config / cond / daug suffixes are added
    train.savenames = [desc.replace(data_name, 'snapshot'), desc]
    desc += '-%s' % config

    # training schedule
    train.total_kimg = kimg
    train.image_snapshot_ticks = 1 * num_gpus if kimg <= 1000 else 4 * num_gpus
    train.network_snapshot_ticks = 5
    train.mirror_augment = mirror
    train.mirror_augment_v = mirror_v
    sched.tick_kimg_base = 2 if train.total_kimg < 2000 else 4

    # learning rate
    if config == 'e':
        sched.G_lrate_base = 0.001
        sched.G_lrate_dict = {0: 0.001, 1: 0.0007, 2: 0.0005, 3: 0.0003}
        sched.lrate_step = 1500  # period for stepping to next lrate, in kimg
    else:  # config == 'f'
        sched.G_lrate_base = lrate  # 0.001 for big datasets, 0.0003 for few-shot
    sched.D_lrate_base = sched.G_lrate_base  # *2 - not used anyway

    # batch size (for 16gb memory GPU)
    sched.minibatch_gpu_base = 4096 // resolution if batch_size is None else batch_size
    print(' Batch size', sched.minibatch_gpu_base)
    sched.minibatch_size_base = num_gpus * sched.minibatch_gpu_base
    sc.num_gpus = num_gpus

    # network width (and, for config E, the loss gamma)
    if config == 'e':
        G.fmap_base = D.fmap_base = 8 << 10
        gamma_value = 100 if gamma is None else gamma
        if d_aug:
            loss_args.gamma = gamma_value
        else:
            D_loss.gamma = gamma_value
    else:  # config == 'f'
        G.fmap_base = D.fmap_base = 16 << 10

    if cond:
        desc += '-cond'
        dataset_args.max_label_size = 'full'  # conditioned on full label

    if freezeD:
        D.freezeD = True
        # NOTE: resume_with_new_nets is also forced to True for every run below.
        train.resume_with_new_nets = True

    if d_aug:
        desc += '-daug'

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(G_args=G, D_args=D, G_opt_args=G_opt, D_opt_args=D_opt)
    kwargs.update(dataset_args=dataset_args,
                  sched_args=sched,
                  grid_args=grid,
                  tf_config=tf_config)
    kwargs.update(resume_pkl=resume,
                  resume_kimg=resume_kimg,
                  resume_with_new_nets=True)
    if ema_kimg is not None:
        kwargs.update(G_ema_kimg=ema_kimg)
    if d_aug:
        kwargs.update(loss_args=loss_args)
    else:
        kwargs.update(G_loss_args=G_loss, D_loss_args=D_loss)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = train_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
Beispiel #6
0
def run(dataset, train_dir, config, d_aug, diffaug_policy, cond, ops, jpg_data, mirror, mirror_v,
        lod_step_kimg, batch_size, resume, resume_kimg, finetune, num_gpus, ema_kimg, gamma, freezeD):
    """Assemble the options for a progressive StyleGAN2 training run and submit it locally.

    Args:
        dataset: Path of the image data. Tfrecord files in its parent
            directory whose path contains its basename are reused.
        train_dir: Stored as ``submit_config.run_dir_root``.
        config: ``'e'`` or ``'f'``; selects the learning-rate settings and
            ``fmap_base`` of both networks.
        d_aug: Truthy selects the DiffAugment training loop and loss instead
            of the original NVidia loop and losses.
        diffaug_policy: Passed as ``policy`` to the DiffAugment loss options.
        cond: Truthy tags the run '-cond' and sets
            ``dataset_args.max_label_size = 'full'``.
        ops: Stored as ``impl`` on both network option sets.
        jpg_data: Forwarded to ``create_from_images`` and stored in the
            dataset options.
        mirror: Stored as ``train.mirror_augment``.
        mirror_v: Stored as ``train.mirror_augment_v``.
        lod_step_kimg: Used for both ``lod_training_kimg`` and
            ``lod_transition_kimg``, and to derive ``total_kimg``.
        batch_size: Per-GPU minibatch size.
        resume: Forwarded as ``resume_pkl``.
        resume_kimg: Forwarded as ``resume_kimg``.
        finetune: Truthy fixes ``total_kimg`` at 15000 and, for config
            ``'e'``, skips the stepped learning-rate table.
        num_gpus: Number of GPUs; scales the total minibatch size.
        ema_kimg: Forwarded as ``G_ema_kimg`` unless ``None``.
        gamma: Loss ``gamma`` for config ``'e'`` (``None`` means 100); not
            used for config ``'f'``.
        freezeD: Truthy sets ``D.freezeD = True``.

    Raises:
        SystemExit: If ``config`` is neither ``'e'`` nor ``'f'``.
    """
    # Validate the config before doing any work. Previously this check sat
    # after the learning-rate setup, which reads sched.G_lrate_base; for an
    # unsupported config that attribute was never assigned, so the read failed
    # before the message below could be printed (and only after the dataset
    # had possibly been converted).
    if config not in ('e', 'f'):
        print(' Only configs E and F are implemented')
        raise SystemExit(1)

    # dataset (tfrecords) - preprocess or get
    data_name = basename(dataset)
    tfr_files = [f for f in file_list(os.path.dirname(dataset), 'tfr') if data_name in f]
    if not tfr_files:
        tfr_file, _ = create_from_images(dataset, jpg=jpg_data)
    else:
        tfr_file = tfr_files[0]
    dataset_args = EasyDict(tfrecord=tfr_file, jpg_data=jpg_data)

    tfr_name = basename(tfr_file)
    desc = tfr_name.split('-')[0]

    # training functions
    if d_aug:  # https://github.com/mit-han-lab/data-efficient-gans
        # Options for training loop (Diff Augment method)
        train = EasyDict(run_func_name='training.training_loop_diffaug.training_loop')
        # Options for loss (Diff Augment method)
        loss_args = EasyDict(func_name='training.loss_diffaug.ns_DiffAugment_r1',
                             policy=diffaug_policy)
    else:  # original nvidia
        # Options for training loop (original from NVidia)
        train = EasyDict(run_func_name='training.training_loop.training_loop')
        # Options for generator loss.
        G_loss = EasyDict(func_name='training.loss.G_logistic_ns_pathreg')
        # Options for discriminator loss.
        D_loss = EasyDict(func_name='training.loss.D_logistic_r1')

    # network functions
    G = EasyDict(func_name='training.networks_stylegan2.G_main')  # Options for generator network.
    D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2')  # Options for discriminator network.
    G_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)  # Options for generator optimizer.
    D_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)  # Options for discriminator optimizer.
    sched = EasyDict()  # Options for TrainingSchedule.
    grid = EasyDict(size='1080p', layout='random')  # Options for setup_snapshot_image_grid().
    sc = dnnlib.SubmitConfig()  # Options for dnnlib.submit_run().
    tf_config = {'rnd.np_random_seed': 1000}  # Options for tflib.init_tf().
    G.impl = D.impl = ops

    # resolutions: the last '-'-separated field of the tfrecord name is split
    # on 'x', converted to ints and reversed before being handed to calc_init_res
    data_res = [int(x) for x in tfr_name.split('-')[-1].split('x')]
    data_res.reverse()
    init_res, resolution, res_log2 = calc_init_res(data_res)
    # Compare as a list, consistent with the list(init_res) conversion below;
    # NOTE(review): calc_init_res's return type is not visible here.
    if list(init_res) != [4, 4]:
        print(' custom init resolution', init_res)
    G.init_res = D.init_res = list(init_res)

    train.setname = desc + config
    desc = '%s-%d-%s' % (desc, resolution, config)

    # training schedule
    sched.lod_training_kimg = lod_step_kimg
    sched.lod_transition_kimg = lod_step_kimg
    train.total_kimg = lod_step_kimg * res_log2 * 2  # a la ProGAN
    if finetune:
        train.total_kimg = 15000  # should start from ~10k kimg
    train.image_snapshot_ticks = 1
    train.network_snapshot_ticks = 5
    train.mirror_augment = mirror
    train.mirror_augment_v = mirror_v

    # learning rate
    if config == 'e':
        sched.G_lrate_base = 0.001
        if not finetune:  # train 1024 (when finetuning, only the base rate is set)
            sched.G_lrate_dict = {0: 0.001, 1: 0.0007, 2: 0.0005, 3: 0.0003}
            sched.lrate_step = 1500  # period for stepping to next lrate, in kimg
    else:  # config == 'f'
        # sched.G_lrate_base = 0.0003
        sched.G_lrate_base = 0.001
    sched.D_lrate_base = sched.G_lrate_base  # *2 - not used anyway

    sched.minibatch_gpu_base = batch_size
    sched.minibatch_size_base = num_gpus * sched.minibatch_gpu_base
    sc.num_gpus = num_gpus

    # network width (and, for config E, the loss gamma)
    if config == 'e':
        G.fmap_base = D.fmap_base = 8 << 10
        gamma_value = 100 if gamma is None else gamma
        if d_aug:
            loss_args.gamma = gamma_value
        else:
            D_loss.gamma = gamma_value
    else:  # config == 'f'
        G.fmap_base = D.fmap_base = 16 << 10

    if cond:
        desc += '-cond'
        dataset_args.max_label_size = 'full'  # conditioned on full label

    if freezeD:
        D.freezeD = True
        # NOTE: resume_with_new_nets is also forced to True for every run below.
        train.resume_with_new_nets = True

    if d_aug:
        desc += '-daug'

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(G_args=G, D_args=D, G_opt_args=G_opt, D_opt_args=D_opt)
    kwargs.update(dataset_args=dataset_args,
                  sched_args=sched,
                  grid_args=grid,
                  tf_config=tf_config)
    kwargs.update(resume_pkl=resume,
                  resume_kimg=resume_kimg,
                  resume_with_new_nets=True)
    if ema_kimg is not None:
        kwargs.update(G_ema_kimg=ema_kimg)
    if d_aug:
        kwargs.update(loss_args=loss_args)
    else:
        kwargs.update(G_loss_args=G_loss, D_loss_args=D_loss)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = train_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
Beispiel #7
0
def main(args):
    """Collect the StyleGAN training options from ``args`` and submit the run.

    Reads ``tfrecord_dirname``, ``resolution``, ``n_gpus``, ``total_kimg``,
    ``datadir``, ``outdir`` and the four ``resume_*`` attributes from
    ``args``, packs them together with the fixed network, optimizer and loss
    options into one keyword set, and passes that to ``dnnlib.submit_run``.
    """
    gpu_count = args.n_gpus

    # Description string included in result subdir name:
    # base tag, dataset tag with resolution, GPU-count tag.
    desc = ''.join([
        'sgan',
        '-ffhq' + str(args.resolution),
        '-{}gpu'.format(gpu_count),
    ])

    # Options for training loop.
    train = EasyDict(run_func_name='training.training_loop.training_loop')
    train.mirror_augment = True
    train.total_kimg = args.total_kimg

    # Options for generator / discriminator network.
    G = EasyDict(func_name='training.networks_stylegan.G_style')
    D = EasyDict(func_name='training.networks_stylegan.D_basic')

    # Options for generator / discriminator optimizer.
    G_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)
    D_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)

    # Options for generator / discriminator loss.
    G_loss = EasyDict(func_name='training.loss.G_logistic_nonsaturating')
    D_loss = EasyDict(func_name='training.loss.D_logistic_simplegp', r1_gamma=10.0)

    # Options for load_dataset().
    dataset = EasyDict(tfrecord_dir=args.tfrecord_dirname,
                       resolution=args.resolution)

    # Options for TrainingSchedule: minibatch sizes follow the GPU count,
    # the learning-rate tables are fixed.
    sched = EasyDict()
    sched.minibatch_base = 4 * gpu_count
    sched.minibatch_dict = MINI_BATCH_DICT[gpu_count]
    sched.lod_initial_resolution = 8
    sched.G_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
    sched.D_lrate_dict = EasyDict(sched.G_lrate_dict)

    grid = EasyDict(size='4k', layout='random')  # Options for setup_snapshot_image_grid().
    metrics = [metric_base.fid50k]  # Options for MetricGroup.
    tf_config = {'rnd.np_random_seed': 1000}  # Options for tflib.init_tf().

    # Options for dnnlib.submit_run().
    submit_config = dnnlib.SubmitConfig()
    submit_config.num_gpus = gpu_count

    # Keyword set handed to the training loop.
    kwargs = EasyDict(train)
    kwargs.update(
        data_root_dir=args.datadir,
        resume_run_id=args.resume_run_id,
        resume_snapshot=args.resume_snapshot,
        resume_kimg=args.resume_kimg,
        resume_time=args.resume_time,
    )
    kwargs.update(
        G_args=G,
        D_args=D,
        G_opt_args=G_opt,
        D_opt_args=D_opt,
        G_loss_args=G_loss,
        D_loss_args=D_loss,
        dataset_args=dataset,
        sched_args=sched,
        grid_args=grid,
        metric_arg_list=metrics,
        tf_config=tf_config,
    )

    # Submit a private copy of the submit options with the run-specific fields filled in.
    run_config = copy.deepcopy(submit_config)
    run_config.run_dir_root = args.outdir
    run_config.run_dir_ignore += config.run_dir_ignore
    run_config.run_desc = desc
    kwargs.submit_config = run_config

    dnnlib.submit_run(**kwargs)
    def initialize(self):
        """Register all command-line options on ``self.parser`` and return it.

        Besides the generic, NN-walk, color, BigGAN, StyleGAN and PGAN option
        groups, this builds the StyleGAN training keyword set that serves as
        the default of ``--train_args``. Sets ``self.initialized`` to True.

        Returns:
            The parser stored in ``self.parser``.
        """
        parser = self.parser
        parser.add_argument('--config_file', type=argparse.FileType(mode='r'), help="configuration yml file")
        parser.add_argument('--overwrite_config', action='store_true', help="overwrite config files if they exist")
        parser.add_argument('--model', default='biggan', help="pretrained model to use, e.g. biggan, stylegan")
        parser.add_argument('--transform', default="zoom", help="transform operation, e.g. zoom, shiftx, color, rotate2d")
        parser.add_argument('--num_samples', type=int, default=20000, help='number of latent z samples')
        parser.add_argument('--loss', type=str, default='l2', help='loss to use for training', choices=['l2', 'lpips'])
        parser.add_argument('--learning_rate', type=float, default=0.0001, help='learning rate for training')
        parser.add_argument('--walk_type', type=str, default='NNz', choices=['NNz', 'linear'], help='type of latent z walk')
        parser.add_argument('--models_dir', type=str, default="./models", help="output directory for saved checkpoints")
        parser.add_argument('--model_save_freq', type=int, default=400, help="saves checkpoints after this many batches")
        parser.add_argument('--name', type=str, help="experiment name, saved within models_dir")
        parser.add_argument('--suffix', type=str, help="suffix for experiment name")
        parser.add_argument('--prefix', type=str, help="prefix for experiment name")
        parser.add_argument("--gpu", default="", type=str, help='GPUs to use (leave blank for CPU only)')

        # NN walk parameters
        group = parser.add_argument_group('nn', 'parameters used to specify NN walk')
        group.add_argument('--eps', type=float, help="step size of each NN block")
        group.add_argument('--num_steps', type=int, help="number of NN blocks")

        # color transformation parameters
        group = parser.add_argument_group('color', 'parameters used for color walk')
        group.add_argument('--channel', type=int, help="which channel to modify; if unspecified, modifies all channels for linear walks, and luminance for NN walks")

        # biggan walk parameters
        group = parser.add_argument_group('biggan', 'parameters used for biggan walk')
        group.add_argument('--category', type=int, help="which category to train on; if unspecified uses all categories")

        # stylegan walk parameters
        # Official training configs for StyleGAN, targeted mainly for anime.

        # Manual switch: 1 builds the full training keyword set, 0 only marks
        # the keyword set as non-training.
        if 1:
            desc          = 'sgan'                                                                 # Description string included in result subdir name.
            train         = EasyDict(run_func_name='train.joint_train')                            # Options for training loop.
            G             = EasyDict(func_name='training.networks_stylegan.G_style')               # Options for generator network.
            D             = EasyDict(func_name='training.networks_stylegan.D_basic')               # Options for discriminator network.
            G_opt         = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)                          # Options for generator optimizer.
            D_opt         = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)                          # Options for discriminator optimizer.
            G_loss        = EasyDict(func_name='training.loss.G_logistic_nonsaturating_steer')     # Options for generator loss.
            D_loss        = EasyDict(func_name='training.loss.D_logistic_simplegp', r1_gamma=10.0) # Options for discriminator loss.
            dataset       = EasyDict()                                                             # Options for the dataset; filled in below.
            sched         = EasyDict()                                                             # Options for the training schedule; filled in below.
            grid          = EasyDict(size='4k', layout='random')                                   # Options for setup_snapshot_image_grid().
            metrics       = [metric_base.fid50k]                                                   # Options for metrics.
            submit_config = dnnlib.SubmitConfig()                                                  # Options for dnnlib.submit_run().
            tf_config     = {'rnd.np_random_seed': 1000}                                           # Options for tflib.init_tf().

            # Dataset.
            desc += '-character'
            dataset = EasyDict(tfrecord_dir='character')
            #train.mirror_augment = True

            # Number of GPUs.
            desc += '-1gpu'; submit_config.num_gpus = 1; sched.minibatch_base = 4; sched.minibatch_dict = {4: 128, 8: 128, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8, 512: 4}
            #desc += '-2gpu'; submit_config.num_gpus = 2; sched.minibatch_base = 8; sched.minibatch_dict = {4: 256, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8}
            #desc += '-4gpu'; submit_config.num_gpus = 4; sched.minibatch_base = 16; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16}
            #desc += '-8gpu'; submit_config.num_gpus = 8; sched.minibatch_base = 32; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32}

            # Default options.
            # Set total_kimg on the existing options instead of rebinding
            # `train`, which used to discard run_func_name.
            # NOTE(review): `train` is not forwarded into kwargs below -- confirm
            # whether the training-loop options are meant to be passed along.
            train.total_kimg = 25000
            sched.lod_initial_resolution = 8
            sched.G_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
            sched.D_lrate_dict = EasyDict(sched.G_lrate_dict)

            kwargs = EasyDict(is_train=True)
            kwargs.update(G_args=G, D_args=D, G_opt_args=G_opt, D_opt_args=D_opt, G_loss_args=G_loss, D_loss_args=D_loss)
            kwargs.update(dataset_args=dataset, sched_args=sched, grid_args=grid, metric_arg_list=metrics, tf_config=tf_config)
            kwargs.submit_config = copy.deepcopy(submit_config)
            kwargs.submit_config.run_dir_root = dnnlib.submission.submit.get_template_from_path(config.result_dir)
            kwargs.submit_config.run_dir_ignore += config.run_dir_ignore
            kwargs.submit_config.run_desc = desc

        else:
            kwargs = EasyDict(is_train=False)

        group = parser.add_argument_group('stylegan', 'parameters used for stylegan walk')
        group.add_argument('--dataset', default="anime", help="which dataset to use for pretrained stylegan, e.g. cars, cats, celebahq")
        group.add_argument('--latent', default="w", help="which latent space to use; z or w")
        # type=float so a value given on the command line is parsed as a number
        # like the 1.0 default, rather than being kept as a string.
        group.add_argument('--truncation_psi', type=float, default=1.0, help="truncation for NN walk in w")
        # The default is the keyword set built above, not a string.
        group.add_argument('--train_args', default=kwargs, help="kwargs for training stylegan")

        # pgan walk parameters
        group = parser.add_argument_group('pgan', 'parameters used for pgan walk')
        group.add_argument('--dset', default="celebahq", help="which dataset to use for pretrained pgan")

        self.initialized = True
        return self.parser
Beispiel #9
0
    #desc += '-1gpu'; submit_config.num_gpus = 1; sched.minibatch_base = 16; sched.minibatch_dict = {4: 16, 8: 16, 16: 16, 32: 16, 64: 16, 128: 16, 256: 16}
    desc += '-2gpu'
    submit_config.num_gpus = 2
    sched.minibatch_base = 8
    sched.minibatch_dict = {
        256: 32,
        512: 16
    }
    #desc += '-4gpu'; submit_config.num_gpus = 4; sched.minibatch_base = 16; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16}
    #desc += '-8gpu'; submit_config.num_gpus = 8; sched.minibatch_base = 32; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32}

    # Default options.
    train.total_kimg = 25000
    sched.lod_initial_resolution = 256
    # The stock table is kept commented out for reference; the active tables
    # below only define entries for keys 256, 512 and 1024.
    #sched.G_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
    # Generator and discriminator use the same values, written out as two
    # separate dict literals.
    sched.G_lrate_dict = {256: 1e-3, 512: 1e-3, 1024: 1e-4}
    sched.D_lrate_dict = {256: 1e-3, 512: 1e-3, 1024: 1e-4}

#----------------------------------------------------------------------------
# Main entry point for training.
# Calls the function indicated by 'train' using the selected options.


def main():
    """Begin assembling the keyword arguments for the training run.

    Copies the module-level ``train`` options into a fresh ``EasyDict`` and
    attaches the generator/discriminator network, optimizer and loss option
    objects under their ``*_args`` keys.
    """
    kwargs = EasyDict(train)
    # Collect the per-component option objects first, then merge them in a
    # single keyword-expanded update call.
    component_options = {
        'G_args': G,
        'D_args': D,
        'G_opt_args': G_opt,
        'D_opt_args': D_opt,
        'G_loss_args': G_loss,
        'D_loss_args': D_loss,
    }
    kwargs.update(**component_options)
Beispiel #10
0
    #desc += '-bedroom';  dataset = EasyDict(tfrecord_dir='lsun-bedroom-full');    train.mirror_augment = False
    #desc += '-car';      dataset = EasyDict(tfrecord_dir='lsun-car-512x384');     train.mirror_augment = False
    #desc += '-cat';      dataset = EasyDict(tfrecord_dir='lsun-cat-full');        train.mirror_augment = False
    desc += '-cond'; dataset.max_label_size = 'full' # conditioned on full label
    # Number of GPUs.
    #desc += '-1gpu'; submit_config.num_gpus = 1; sched.minibatch_base = 4; sched.minibatch_dict = {4: 128, 8: 128, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8, 512: 4}; sched.tick_kimg_dict = {4: 200, 8:100, 16:100, 32:60, 64:40}
    desc += '-preset-v1-1gpu'; submit_config.num_gpus = 1; D.mbstd_group_size = 16; sched.minibatch_base = 16; sched.minibatch_dict = {256: 14, 512: 6, 1024: 3}; sched.lod_training_kimg = 800; sched.lod_transition_kimg = 800; train.total_kimg = 19000
    #desc += '-2gpu'; submit_config.num_gpus = 2; sched.minibatch_base = 8; sched.minibatch_dict = {4: 256, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8}
    #desc += '-4gpu'; submit_config.num_gpus = 4; sched.minibatch_base = 16; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16}
    #desc += '-8gpu'; submit_config.num_gpus = 8; sched.minibatch_base = 32; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32}

    # Default options.
    #train.total_kimg = 5000
    #train.total_kimg = 25000
    sched.lod_initial_resolution = 8
    sched.G_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
    sched.D_lrate_dict = EasyDict(sched.G_lrate_dict)
    #desc += '-preset-v2-1gpu'; submit_config.num_gpus = 1; sched.minibatch_base = 4; sched.minibatch_dict = {4: 128, 8: 128, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8, 512: 4}; sched.G_lrate_dict = {1024: 0.0015}; sched.D_lrate_dict = EasyDict(sched.G_lrate_dict);

    # WGAN-GP loss for CelebA-HQ.
    desc += '-wgangp'; G_loss = EasyDict(func_name='training.loss.G_wgan'); D_loss = EasyDict(func_name='training.loss.D_wgan_gp'); sched.G_lrate_dict = {k: min(v, 0.003) for k, v in sched.G_lrate_dict.items()}; sched.D_lrate_dict = EasyDict(sched.G_lrate_dict)

    # Table 1.
    #desc += '-tuned-baseline'; G.use_styles = False; G.use_pixel_norm = True; G.use_instance_norm = False; G.mapping_layers = 0; G.truncation_psi = None; G.const_input_layer = False; G.style_mixing_prob = 0.0; G.use_noise = False
    #desc += '-add-mapping-and-styles'; G.const_input_layer = False; G.style_mixing_prob = 0.0; G.use_noise = False
    #desc += '-remove-traditional-input'; G.style_mixing_prob = 0.0; G.use_noise = False
    #desc += '-add-noise-inputs'; G.style_mixing_prob = 0.0
    #desc += '-mixing-regularization' # default

    # Table 2.
    #desc += '-mix0'; G.style_mixing_prob = 0.0
Beispiel #11
0
def main():
    """Parse the command line, build a training configuration and submit it.

    Command-line options:
        --algorithm      'sgan' (StyleGAN, default) or 'pgan' (Progressive
                         GAN). Any other value is rejected by argparse with a
                         usage error instead of failing later with an
                         undefined-variable error.
        --resume_run_id  Forwarded to the training loop as ``resume_run_id``.
        --resume_kimg    Forwarded to the training loop as ``resume_kimg``.

    The selected branch fills in the option objects (``train``, ``G``, ``D``,
    optimizers, losses, dataset, schedule, grid, metrics, submit config); the
    commented-out ``desc += ...`` lines are alternative presets that can be
    enabled by hand. The options are then merged into one ``EasyDict`` and
    handed to ``dnnlib.submit_run``.
    """
    parser = ArgumentParser(description='Style GAN')
    parser.add_argument('--algorithm',
                        metavar='algorithm',
                        type=str,
                        # Only these two configurations are defined below.
                        choices=['sgan', 'pgan'],
                        help='algorithm',
                        default='sgan')
    parser.add_argument('--resume_run_id',
                        metavar='resume_run_id',
                        type=str,
                        help='resume run id',
                        default=None)
    parser.add_argument('--resume_kimg',
                        metavar='resume_kimg',
                        type=float,
                        help='resume kimg',
                        default=0.0)
    args = parser.parse_args()

    # ----------------------------------------------------------------------------
    # Official training configs for StyleGAN, targeted mainly for FFHQ.

    if args.algorithm == 'sgan':
        desc = 'sgan'  # Description string included in result subdir name.
        train = EasyDict(run_func_name='training.training_loop.training_loop')  # Options for training loop.
        G = EasyDict(func_name='training.networks_stylegan.G_style')  # Options for generator network.
        D = EasyDict(func_name='training.networks_stylegan.D_basic')  # Options for discriminator network.
        G_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)  # Options for generator optimizer.
        D_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)  # Options for discriminator optimizer.
        G_loss = EasyDict(func_name='training.loss.G_logistic_nonsaturating')  # Options for generator loss.
        D_loss = EasyDict(func_name='training.loss.D_logistic_simplegp',
                          r1_gamma=10.0)  # Options for discriminator loss.
        dataset = EasyDict()  # Options for load_dataset().
        sched = EasyDict()  # Options for TrainingSchedule.
        grid = EasyDict(size='4k', layout='random')  # Options for setup_snapshot_image_grid().
        metrics = [metric_base.fid50k]  # Options for MetricGroup.
        submit_config = dnnlib.SubmitConfig()  # Options for dnnlib.submit_run().
        resume_run_id = args.resume_run_id
        resume_kimg = args.resume_kimg
        tf_config = {'rnd.np_random_seed': 1000}  # Options for tflib.init_tf().

        # Dataset.
        # desc += '-ffhq';     dataset = EasyDict(tfrecord_dir='ffhq');                 train.mirror_augment = True
        # desc += '-ffhq512';  dataset = EasyDict(tfrecord_dir='ffhq', resolution=512); train.mirror_augment = True
        # desc += '-ffhq256';  dataset = EasyDict(tfrecord_dir='ffhq', resolution=256); train.mirror_augment = True
        # desc += '-celebahq'; dataset = EasyDict(tfrecord_dir='celebahq');             train.mirror_augment = True
        # desc += '-bedroom';  dataset = EasyDict(tfrecord_dir='lsun-bedroom-full');    train.mirror_augment = False
        # desc += '-car';      dataset = EasyDict(tfrecord_dir='lsun-car-512x384');     train.mirror_augment = False
        # desc += '-cat';      dataset = EasyDict(tfrecord_dir='lsun-cat-full');        train.mirror_augment = False
        desc += '-anime-faces-64'
        dataset = EasyDict(tfrecord_dir='anime-faces-64', resolution=64)
        train.mirror_augment = True

        # Number of GPUs.
        # desc += '-1gpu'; submit_config.num_gpus = 1; sched.minibatch_base = 4; sched.minibatch_dict = {4: 128, 8: 128, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8, 512: 4}
        desc += '-2gpu'
        submit_config.num_gpus = 2
        sched.minibatch_base = 8
        sched.minibatch_dict = {
            4: 256,
            8: 256,
            16: 128,
            32: 64,
            64: 32,
            128: 16,
            256: 8
        }
        # desc += '-4gpu'; submit_config.num_gpus = 4; sched.minibatch_base = 16; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16}
        # desc += '-8gpu'; submit_config.num_gpus = 8; sched.minibatch_base = 32; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32}

        # Default options.
        train.total_kimg = 25000
        sched.lod_initial_resolution = 8
        sched.G_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
        sched.D_lrate_dict = EasyDict(sched.G_lrate_dict)

        # WGAN-GP loss for CelebA-HQ.
        # desc += '-wgangp'; G_loss = EasyDict(func_name='training.loss.G_wgan'); D_loss = EasyDict(func_name='training.loss.D_wgan_gp'); sched.G_lrate_dict = {k: min(v, 0.002) for k, v in sched.G_lrate_dict.items()}; sched.D_lrate_dict = EasyDict(sched.G_lrate_dict)

        # Table 1.
        # desc += '-tuned-baseline'; G.use_styles = False; G.use_pixel_norm = True; G.use_instance_norm = False; G.mapping_layers = 0; G.truncation_psi = None; G.const_input_layer = False; G.style_mixing_prob = 0.0; G.use_noise = False
        # desc += '-add-mapping-and-styles'; G.const_input_layer = False; G.style_mixing_prob = 0.0; G.use_noise = False
        # desc += '-remove-traditional-input'; G.style_mixing_prob = 0.0; G.use_noise = False
        # desc += '-add-noise-inputs'; G.style_mixing_prob = 0.0
        # desc += '-mixing-regularization' # default

        # Table 2.
        # desc += '-mix0'; G.style_mixing_prob = 0.0
        # desc += '-mix50'; G.style_mixing_prob = 0.5
        # desc += '-mix90'; G.style_mixing_prob = 0.9 # default
        # desc += '-mix100'; G.style_mixing_prob = 1.0

        # Table 4.
        # desc += '-traditional-0'; G.use_styles = False; G.use_pixel_norm = True; G.use_instance_norm = False; G.mapping_layers = 0; G.truncation_psi = None; G.const_input_layer = False; G.style_mixing_prob = 0.0; G.use_noise = False
        # desc += '-traditional-8'; G.use_styles = False; G.use_pixel_norm = True; G.use_instance_norm = False; G.mapping_layers = 8; G.truncation_psi = None; G.const_input_layer = False; G.style_mixing_prob = 0.0; G.use_noise = False
        # desc += '-stylebased-0'; G.mapping_layers = 0
        # desc += '-stylebased-1'; G.mapping_layers = 1
        # desc += '-stylebased-2'; G.mapping_layers = 2
        # desc += '-stylebased-8'; G.mapping_layers = 8 # default

    # ----------------------------------------------------------------------------
    # Official training configs for Progressive GAN, targeted mainly for CelebA-HQ.

    elif args.algorithm == 'pgan':
        desc = 'pgan'  # Description string included in result subdir name.
        train = EasyDict(run_func_name='training.training_loop.training_loop')  # Options for training loop.
        G = EasyDict(func_name='training.networks_progan.G_paper')  # Options for generator network.
        D = EasyDict(func_name='training.networks_progan.D_paper')  # Options for discriminator network.
        G_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)  # Options for generator optimizer.
        D_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)  # Options for discriminator optimizer.
        G_loss = EasyDict(func_name='training.loss.G_wgan')  # Options for generator loss.
        D_loss = EasyDict(func_name='training.loss.D_wgan_gp')  # Options for discriminator loss.
        dataset = EasyDict()  # Options for load_dataset().
        sched = EasyDict()  # Options for TrainingSchedule.
        grid = EasyDict(size='1080p', layout='random')  # Options for setup_snapshot_image_grid().
        metrics = [metric_base.fid50k]  # Options for MetricGroup.
        submit_config = dnnlib.SubmitConfig()  # Options for dnnlib.submit_run().
        resume_run_id = args.resume_run_id
        resume_kimg = args.resume_kimg
        tf_config = {'rnd.np_random_seed': 1000}  # Options for tflib.init_tf().

        # Dataset (choose one).
        desc += '-celebahq'
        dataset = EasyDict(tfrecord_dir='celebahq')
        train.mirror_augment = True
        # desc += '-celeba';              dataset = EasyDict(tfrecord_dir='celeba'); train.mirror_augment = True
        # desc += '-cifar10';             dataset = EasyDict(tfrecord_dir='cifar10')
        # desc += '-cifar100';            dataset = EasyDict(tfrecord_dir='cifar100')
        # desc += '-svhn';                dataset = EasyDict(tfrecord_dir='svhn')
        # desc += '-mnist';               dataset = EasyDict(tfrecord_dir='mnist')
        # desc += '-mnistrgb';            dataset = EasyDict(tfrecord_dir='mnistrgb')
        # desc += '-syn1024rgb';          dataset = EasyDict(class_name='training.dataset.SyntheticDataset', resolution=1024, num_channels=3)
        # desc += '-lsun-airplane';       dataset = EasyDict(tfrecord_dir='lsun-airplane-100k');       train.mirror_augment = True
        # desc += '-lsun-bedroom';        dataset = EasyDict(tfrecord_dir='lsun-bedroom-100k');        train.mirror_augment = True
        # desc += '-lsun-bicycle';        dataset = EasyDict(tfrecord_dir='lsun-bicycle-100k');        train.mirror_augment = True
        # desc += '-lsun-bird';           dataset = EasyDict(tfrecord_dir='lsun-bird-100k');           train.mirror_augment = True
        # desc += '-lsun-boat';           dataset = EasyDict(tfrecord_dir='lsun-boat-100k');           train.mirror_augment = True
        # desc += '-lsun-bottle';         dataset = EasyDict(tfrecord_dir='lsun-bottle-100k');         train.mirror_augment = True
        # desc += '-lsun-bridge';         dataset = EasyDict(tfrecord_dir='lsun-bridge-100k');         train.mirror_augment = True
        # desc += '-lsun-bus';            dataset = EasyDict(tfrecord_dir='lsun-bus-100k');            train.mirror_augment = True
        # desc += '-lsun-car';            dataset = EasyDict(tfrecord_dir='lsun-car-100k');            train.mirror_augment = True
        # desc += '-lsun-cat';            dataset = EasyDict(tfrecord_dir='lsun-cat-100k');            train.mirror_augment = True
        # desc += '-lsun-chair';          dataset = EasyDict(tfrecord_dir='lsun-chair-100k');          train.mirror_augment = True
        # desc += '-lsun-churchoutdoor';  dataset = EasyDict(tfrecord_dir='lsun-churchoutdoor-100k');  train.mirror_augment = True
        # desc += '-lsun-classroom';      dataset = EasyDict(tfrecord_dir='lsun-classroom-100k');      train.mirror_augment = True
        # desc += '-lsun-conferenceroom'; dataset = EasyDict(tfrecord_dir='lsun-conferenceroom-100k'); train.mirror_augment = True
        # desc += '-lsun-cow';            dataset = EasyDict(tfrecord_dir='lsun-cow-100k');            train.mirror_augment = True
        # desc += '-lsun-diningroom';     dataset = EasyDict(tfrecord_dir='lsun-diningroom-100k');     train.mirror_augment = True
        # desc += '-lsun-diningtable';    dataset = EasyDict(tfrecord_dir='lsun-diningtable-100k');    train.mirror_augment = True
        # desc += '-lsun-dog';            dataset = EasyDict(tfrecord_dir='lsun-dog-100k');            train.mirror_augment = True
        # desc += '-lsun-horse';          dataset = EasyDict(tfrecord_dir='lsun-horse-100k');          train.mirror_augment = True
        # desc += '-lsun-kitchen';        dataset = EasyDict(tfrecord_dir='lsun-kitchen-100k');        train.mirror_augment = True
        # desc += '-lsun-livingroom';     dataset = EasyDict(tfrecord_dir='lsun-livingroom-100k');     train.mirror_augment = True
        # desc += '-lsun-motorbike';      dataset = EasyDict(tfrecord_dir='lsun-motorbike-100k');      train.mirror_augment = True
        # desc += '-lsun-person';         dataset = EasyDict(tfrecord_dir='lsun-person-100k');         train.mirror_augment = True
        # desc += '-lsun-pottedplant';    dataset = EasyDict(tfrecord_dir='lsun-pottedplant-100k');    train.mirror_augment = True
        # desc += '-lsun-restaurant';     dataset = EasyDict(tfrecord_dir='lsun-restaurant-100k');     train.mirror_augment = True
        # desc += '-lsun-sheep';          dataset = EasyDict(tfrecord_dir='lsun-sheep-100k');          train.mirror_augment = True
        # desc += '-lsun-sofa';           dataset = EasyDict(tfrecord_dir='lsun-sofa-100k');           train.mirror_augment = True
        # desc += '-lsun-tower';          dataset = EasyDict(tfrecord_dir='lsun-tower-100k');          train.mirror_augment = True
        # desc += '-lsun-train';          dataset = EasyDict(tfrecord_dir='lsun-train-100k');          train.mirror_augment = True
        # desc += '-lsun-tvmonitor';      dataset = EasyDict(tfrecord_dir='lsun-tvmonitor-100k');      train.mirror_augment = True

        # Conditioning & snapshot options.
        # desc += '-cond'; dataset.max_label_size = 'full' # conditioned on full label
        # desc += '-cond1'; dataset.max_label_size = 1 # conditioned on first component of the label
        # desc += '-g4k'; grid.size = '4k'
        # desc += '-grpc'; grid.layout = 'row_per_class'

        # Config presets (choose one).
        # desc += '-preset-v1-1gpu'; submit_config.num_gpus = 1; D.mbstd_group_size = 16; sched.minibatch_base = 16; sched.minibatch_dict = {256: 14, 512: 6, 1024: 3}; sched.lod_training_kimg = 800; sched.lod_transition_kimg = 800; train.total_kimg = 19000
        desc += '-preset-v2-1gpu'
        submit_config.num_gpus = 1
        sched.minibatch_base = 4
        sched.minibatch_dict = {
            4: 128,
            8: 128,
            16: 128,
            32: 64,
            64: 32,
            128: 16,
            256: 8,
            512: 4
        }
        sched.G_lrate_dict = {
            1024: 0.0015
        }
        sched.D_lrate_dict = EasyDict(sched.G_lrate_dict)
        train.total_kimg = 12000
        # desc += '-preset-v2-2gpus'; submit_config.num_gpus = 2; sched.minibatch_base = 8; sched.minibatch_dict = {4: 256, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8}; sched.G_lrate_dict = {512: 0.0015, 1024: 0.002}; sched.D_lrate_dict = EasyDict(sched.G_lrate_dict); train.total_kimg = 12000
        # desc += '-preset-v2-4gpus'; submit_config.num_gpus = 4; sched.minibatch_base = 16; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16}; sched.G_lrate_dict = {256: 0.0015, 512: 0.002, 1024: 0.003}; sched.D_lrate_dict = EasyDict(sched.G_lrate_dict); train.total_kimg = 12000
        # desc += '-preset-v2-8gpus'; submit_config.num_gpus = 8; sched.minibatch_base = 32; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32}; sched.G_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}; sched.D_lrate_dict = EasyDict(sched.G_lrate_dict); train.total_kimg = 12000

        # Numerical precision (choose one).
        desc += '-fp32'
        sched.max_minibatch_per_gpu = {
            256: 16,
            512: 8,
            1024: 4
        }
        # desc += '-fp16'; G.dtype = 'float16'; D.dtype = 'float16'; G.pixelnorm_epsilon=1e-4; G_opt.use_loss_scaling = True; D_opt.use_loss_scaling = True; sched.max_minibatch_per_gpu = {512: 16, 1024: 8}

        # Disable individual features.
        # desc += '-nogrowing'; sched.lod_initial_resolution = 1024; sched.lod_training_kimg = 0; sched.lod_transition_kimg = 0; train.total_kimg = 10000
        # desc += '-nopixelnorm'; G.use_pixelnorm = False
        # desc += '-nowscale'; G.use_wscale = False; D.use_wscale = False
        # desc += '-noleakyrelu'; G.use_leakyrelu = False
        # desc += '-nosmoothing'; train.G_smoothing_kimg = 0.0
        # desc += '-norepeat'; train.minibatch_repeats = 1
        # desc += '-noreset'; train.reset_opt_for_new_lod = False

        # Special modes.
        # desc += '-BENCHMARK'; sched.lod_initial_resolution = 4; sched.lod_training_kimg = 3; sched.lod_transition_kimg = 3; train.total_kimg = (8*2+1)*3; sched.tick_kimg_base = 1; sched.tick_kimg_dict = {}; train.image_snapshot_ticks = 1000; train.network_snapshot_ticks = 1000
        # desc += '-BENCHMARK0'; sched.lod_initial_resolution = 1024; train.total_kimg = 10; sched.tick_kimg_base = 1; sched.tick_kimg_dict = {}; train.image_snapshot_ticks = 1000; train.network_snapshot_ticks = 1000
        # desc += '-VERBOSE'; sched.tick_kimg_base = 1; sched.tick_kimg_dict = {}; train.image_snapshot_ticks = 1; train.network_snapshot_ticks = 100
        # desc += '-GRAPH'; train.save_tf_graph = True
        # desc += '-HIST'; train.save_weight_histograms = True

    # ----------------------------------------------------------------------------
    # Main entry point for training.
    # Calls the function indicated by 'train' using the selected options.

    kwargs = EasyDict(train)
    kwargs.update(G_args=G,
                  D_args=D,
                  G_opt_args=G_opt,
                  D_opt_args=D_opt,
                  G_loss_args=G_loss,
                  D_loss_args=D_loss)
    kwargs.update(dataset_args=dataset,
                  sched_args=sched,
                  grid_args=grid,
                  metric_arg_list=metrics,
                  tf_config=tf_config)
    kwargs.update(resume_run_id=resume_run_id, resume_kimg=resume_kimg)
    # Work on a copy so the submit config built above is left untouched.
    kwargs.submit_config = copy.deepcopy(submit_config)
    kwargs.submit_config.run_dir_root = dnnlib.submission.submit.get_template_from_path(
        config.result_dir)
    kwargs.submit_config.run_dir_ignore += config.run_dir_ignore
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
        16: 64,
        32: 32,
        64: 16,
        128: 8,
        256: 4,
        512: 2
    }
    #desc += '-2gpu'; submit_config.num_gpus = 2; sched.minibatch_base = 8; sched.minibatch_dict = {4: 256, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8}
    #desc += '-4gpu'; submit_config.num_gpus = 4; sched.minibatch_base = 16; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16}
    #desc += '-8gpu'; submit_config.num_gpus = 8; sched.minibatch_base = 32; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32}

    # Default options.
    train.total_kimg = 25000
    sched.lod_initial_resolution = 8
    #sched.G_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
    sched.G_lrate_dict = {128: 0.0008, 256: 0.001, 512: 0.0015, 1024: 0.0025}
    sched.D_lrate_dict = EasyDict(sched.G_lrate_dict)

    # WGAN-GP loss for CelebA-HQ.
    #desc += '-wgangp'; G_loss = EasyDict(func_name='training.loss.G_wgan'); D_loss = EasyDict(func_name='training.loss.D_wgan_gp'); sched.G_lrate_dict = {k: min(v, 0.002) for k, v in sched.G_lrate_dict.items()}; sched.D_lrate_dict = EasyDict(sched.G_lrate_dict)

    # Table 1.
    #desc += '-tuned-baseline'; G.use_styles = False; G.use_pixel_norm = True; G.use_instance_norm = False; G.mapping_layers = 0; G.truncation_psi = None; G.const_input_layer = False; G.style_mixing_prob = 0.0; G.use_noise = False
    #desc += '-add-mapping-and-styles'; G.const_input_layer = False; G.style_mixing_prob = 0.0; G.use_noise = False
    #desc += '-remove-traditional-input'; G.style_mixing_prob = 0.0; G.use_noise = False
    #desc += '-add-noise-inputs'; G.style_mixing_prob = 0.0
    #desc += '-mixing-regularization' # default

    # Table 2.
    #desc += '-mix0'; G.style_mixing_prob = 0.0
    #desc += '-mix50'; G.style_mixing_prob = 0.5
Beispiel #13
0
        4: 256,
        8: 256,
        16: 128,
        32: 64,
        64: 32,
        128: 16,
        256: 8
    }
    #desc += '-2gpu'; submit_config.num_gpus = 2; sched.minibatch_base = 8; sched.minibatch_dict = {4: 256, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8}
    #desc += '-4gpu'; submit_config.num_gpus = 4; sched.minibatch_base = 16; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16}
    #desc += '-8gpu'; submit_config.num_gpus = 8; sched.minibatch_base = 32; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32}

    # Default options.
    train.total_kimg = 9900
    sched.lod_initial_resolution = 8
    sched.G_lrate_dict = {512: 0.0015, 1024: 0.002}
    sched.D_lrate_dict = EasyDict(sched.G_lrate_dict)

    # WGAN-GP loss for CelebA-HQ.
    #desc += '-wgangp'; G_loss = EasyDict(func_name='training.loss.G_wgan'); D_loss = EasyDict(func_name='training.loss.D_wgan_gp'); sched.G_lrate_dict = {k: min(v, 0.002) for k, v in sched.G_lrate_dict.items()}; sched.D_lrate_dict = EasyDict(sched.G_lrate_dict)

    # Table 1.
    #desc += '-tuned-baseline'; G.use_styles = False; G.use_pixel_norm = True; G.use_instance_norm = False; G.mapping_layers = 0; G.truncation_psi = None; G.const_input_layer = False; G.style_mixing_prob = 0.0; G.use_noise = False
    #desc += '-add-mapping-and-styles'; G.const_input_layer = False; G.style_mixing_prob = 0.0; G.use_noise = False
    #desc += '-remove-traditional-input'; G.style_mixing_prob = 0.0; G.use_noise = False
    #desc += '-add-noise-inputs'; G.style_mixing_prob = 0.0
    #desc += '-mixing-regularization' # default

    # Table 2.
    #desc += '-mix0'; G.style_mixing_prob = 0.0
    #desc += '-mix50'; G.style_mixing_prob = 0.5
Beispiel #14
0
def _parse_tick_schedule(schedule, default):
    """Convert the user-supplied ``schedule`` string into a tick_kimg dict.

    Accepted forms:
      * empty string or ``None``  -> ``default`` is returned unchanged;
      * a dict literal such as ``'{4: 2, 8: 2}'`` -> that dict;
      * an integer such as ``'30'`` -> the same value for every
        power-of-two resolution from 4 up to 1024.

    Anything that cannot be parsed falls back to ``default`` (best effort),
    but a message is printed so the fallback is not silent.
    """
    if not schedule:
        return default

    looks_like_dict = (len(schedule) >= 5 and schedule[0] == '{'
                       and schedule[-1] == '}' and ':' in schedule)
    if looks_like_dict:
        # SECURITY NOTE(review): eval() executes arbitrary code. This is kept
        # for backward compatibility with expression-valued schedules; only
        # pass trusted strings here, or switch to ast.literal_eval if plain
        # literals are sufficient.
        try:
            return dict(eval(schedule))
        except Exception as err:
            print('ignoring invalid schedule dict {!r}: {}'.format(
                schedule, err))
            return default

    try:
        kimg_per_tick = int(schedule)
    except (TypeError, ValueError) as err:
        print('ignoring invalid schedule {!r}: {}'.format(schedule, err))
        return default
    # Resolutions 4, 8, ..., 1024 (2**2 .. 2**10) all get the same value.
    return {2**exponent: kimg_per_tick for exponent in range(2, 11)}


def easygen_train(model_path,
                  images_path,
                  dataset_path,
                  start_kimg=7000,
                  max_kimg=25000,
                  schedule='',
                  seed=1000):
    """Build a TFRecord dataset from a folder of images and resume StyleGAN training on it.

    Args:
        model_path: Passed to the training loop as ``resume_run_id``.
        images_path: Image directory handed to
            ``dataset_tool.create_from_images``.
        dataset_path: TFRecord directory that is written by the dataset tool
            and then used as the training dataset.
        start_kimg: Passed to the training loop as ``resume_kimg``.
        max_kimg: Assigned to ``train.total_kimg``.
        schedule: Optional string overriding ``sched.tick_kimg_dict``; either
            a dict literal (``'{4: 2, 8: 2}'``) or a single integer applied to
            every resolution from 4 to 1024. Unparseable values fall back to
            the built-in schedule.
        seed: Value for the ``rnd.np_random_seed`` TensorFlow config entry.
    """
    import config
    config.data_dir = '/content/datasets'
    # The run directory root below is read from ``config.result_dir``; the
    # original assigned ``config.results_dir``, which nothing here reads.
    config.result_dir = '/content/results'
    # NOTE(review): '/contents/cache' looks like a typo for '/content/cache';
    # left unchanged pending confirmation.
    config.cache_dir = '/contents/cache'
    import copy
    import dnnlib
    from dnnlib import EasyDict
    from metrics import metric_base  # unused here; import kept for parity with the original
    # Prep dataset
    import dataset_tool
    print("prepping dataset...")
    dataset_tool.create_from_images(tfrecord_dir=dataset_path,
                                    image_dir=images_path,
                                    shuffle=False)
    # Set up training parameters
    desc = 'sgan'  # Description string included in result subdir name.
    train = EasyDict(run_func_name='training.training_loop.training_loop')  # Options for training loop.
    G = EasyDict(func_name='training.networks_stylegan.G_style')  # Options for generator network.
    D = EasyDict(func_name='training.networks_stylegan.D_basic')  # Options for discriminator network.
    G_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)  # Options for generator optimizer.
    D_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)  # Options for discriminator optimizer.
    G_loss = EasyDict(func_name='training.loss.G_logistic_nonsaturating')  # Options for generator loss.
    D_loss = EasyDict(func_name='training.loss.D_logistic_simplegp',
                      r1_gamma=10.0)  # Options for discriminator loss.
    sched = EasyDict()  # Options for TrainingSchedule.
    grid = EasyDict(size='1080p', layout='random')  # Options for setup_snapshot_image_grid().
    submit_config = dnnlib.SubmitConfig()  # Options for dnnlib.submit_run().
    tf_config = {'rnd.np_random_seed': seed}  # Options for tflib.init_tf().
    # Dataset
    desc += '-custom'
    dataset = EasyDict(tfrecord_dir=dataset_path)  # Options for load_dataset().
    train.mirror_augment = True
    # Number of GPUs.
    desc += '-1gpu'
    submit_config.num_gpus = 1
    sched.minibatch_base = 4
    sched.minibatch_dict = {
        4: 128,
        8: 128,
        16: 128,
        32: 64,
        64: 32,
        128: 16,
        256: 8,
        512: 4
    }
    # Default options.
    train.total_kimg = max_kimg
    sched.lod_initial_resolution = 8
    sched.G_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
    sched.D_lrate_dict = EasyDict(sched.G_lrate_dict)
    # Tick schedule: built-in values unless ``schedule`` overrides them.
    default_schedule = {
        4: 160,
        8: 140,
        16: 120,
        32: 100,
        64: 80,
        128: 60,
        256: 40,
        512: 30,
        1024: 20
    }
    schedule_dict = _parse_tick_schedule(schedule, default_schedule)
    print('schedule:', str(schedule_dict))
    sched.tick_kimg_dict = schedule_dict
    # resume kimg
    resume_kimg = start_kimg
    # path to model
    resume_run_id = model_path
    # tick snapshots
    image_snapshot_ticks = 1
    network_snapshot_ticks = 1
    # Submit run
    kwargs = EasyDict(train)
    kwargs.update(G_args=G,
                  D_args=D,
                  G_opt_args=G_opt,
                  D_opt_args=D_opt,
                  G_loss_args=G_loss,
                  D_loss_args=D_loss)
    kwargs.update(dataset_args=dataset,
                  sched_args=sched,
                  grid_args=grid,
                  tf_config=tf_config)
    kwargs.update(resume_kimg=resume_kimg, resume_run_id=resume_run_id)
    kwargs.update(image_snapshot_ticks=image_snapshot_ticks,
                  network_snapshot_ticks=network_snapshot_ticks)
    kwargs.submit_config = copy.deepcopy(submit_config)
    kwargs.submit_config.run_dir_root = dnnlib.submission.submit.get_template_from_path(
        config.result_dir)
    kwargs.submit_config.run_dir_ignore += config.run_dir_ignore
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
Beispiel #15
0
sched.minibatch_dict = {
    4: 256,
    8: 256,
    16: 128,
    32: 64,
    64: 32,
    128: 16,
    256: 8
}
# desc += '-4gpu'; submit_config.num_gpus = 4; sched.minibatch_base = 16; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16}
# desc += '-8gpu'; submit_config.num_gpus = 8; sched.minibatch_base = 32; sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32}

# Default options.
train.total_kimg = 88550
sched.lod_initial_resolution = 8
sched.G_lrate_dict = {128: 0.0015, 256: 0.002}
sched.D_lrate_dict = EasyDict(sched.G_lrate_dict)

#----------------------------------------------------------------------------
# Main entry point for training.


def main():
    kwargs = EasyDict(train)
    kwargs.update(G_args=G,
                  D_args=D,
                  G_opt_args=G_opt,
                  D_opt_args=D_opt,
                  G_loss_args=G_loss,
                  D_loss_args=D_loss)
    kwargs.update(dataset_args=dataset,
Beispiel #16
0
    sched.minibatch_base = 4
    sched.minibatch_dict = {
        4: 128,
        8: 128,
        16: 128,
        32: 64,
        64: 32,
        128: 16,
        256: 8,
        512: 4
    }

    # Default options.
    train.total_kimg = 20000
    sched.lod_initial_resolution = 4
    sched.G_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
    sched.D_lrate_dict = EasyDict(sched.G_lrate_dict)

    desc += '-cond'
    dataset.max_label_size = 'full'

# ----------------------------------------------------------------------------
# Main entry point for training.
# Calls the function indicated by 'train' using the selected options.


def main():
    kwargs = EasyDict(train)
    kwargs.update(G_args=G,
                  D_args=D,
                  G_opt_args=G_opt,
Beispiel #17
0
def _copy_scope_weights(sess, dst_scope, src_scope, transpose_mismatched=False):
    """Assign trainable variables under `src_scope` onto those under `dst_scope`.

    Variables are paired positionally with `zip`, so surplus variables on
    either side are silently ignored.  Note that TensorFlow treats a scope as
    a name *prefix* filter: "G_synthesis" therefore also matches
    "G_synthesis_1/...".  This relies on the primary network's variables
    being listed first (presumably creation order -- TODO confirm).

    Args:
        sess: Session used to run the assignments.
        dst_scope: Scope prefix of the variables to overwrite.
        src_scope: Scope prefix of the variables to read from.
        transpose_mismatched: When True, a source variable whose shape differs
            from its destination is transposed before assignment; when False
            a shape mismatch is left for `tf.assign` to reject.
    """
    first_mismatch = True
    assign_ops = []
    for dst_var, src_var in zip(tf.trainable_variables(dst_scope),
                                tf.trainable_variables(src_scope)):
        value = src_var
        if transpose_mismatched and dst_var.shape != src_var.shape:
            # The two architectures order the axes of some variables
            # differently.  A transpose is required here -- DO NOT USE
            # RESHAPE (it runs, but the generated images are garbage).
            # The first mismatching variable needs a different permutation
            # from all later ones.
            perm = [0, 3, 1, 2] if first_mismatch else [0, 1, 3, 2]
            first_mismatch = False
            value = tf.transpose(src_var, perm=perm)
        assign_ops.append(tf.assign(dst_var, value))
    if assign_ops:
        # One run for the whole scope; every op writes a distinct variable.
        sess.run(assign_ops)


def main():
    """Convert a taki0112 StyleGAN checkpoint into an NVlabs StyleGAN pickle.

    Usage example:
        python [this_file].py --kimg ###### --dataset [your data] --gpu_num 1
            --start_res 8 --img_size 512 --progressive True

    Steps: build the taki0112 model and restore its checkpoint, construct
    default-initialised NVlabs G / D / Gs networks, copy the weights across
    (transposing where the layouts differ), and save `(G, D, Gs)` to
    "./network-snapshot-<kimg>.pkl".

    Exits with status -1 if `args.gpu_num` is not one of 1, 2, 4 or 8.
    """
    # parse arguments
    args = parse_args()
    if args is None:
        sys.exit()

    checkpoint_dir = args.checkpoint_dir
    # format() rather than "+" so an integer --kimg does not raise TypeError.
    nvlabs_stylegan_pkl_name = "network-snapshot-{}.pkl".format(args.kimg)

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:

        # this is a hack since the taki0112 code expects a dataset folder
        # which may not exist
        dataset_dir = "./dataset/" + args.dataset
        temp_dataset_file = make_temp_dataset_file(dataset_dir)
        try:
            # build the taki0112 StyleGAN architecture (vanilla TensorFlow)
            gan = StyleGAN(sess, args)

            # The model has to be built at full resolution so that every
            # variable needed to load the checkpoint exists; the requested
            # start resolution is restored afterwards.
            original_start_res = args.start_res
            args.start_res = args.img_size
            gan.start_res = args.img_size
            try:
                gan.build_model()
            finally:
                args.start_res = original_start_res
                gan.start_res = original_start_res
        finally:
            # remove the temp file (and the directory if it is empty), even
            # when building the model failed
            delete_temp_dataset_file(args, dataset_dir, temp_dataset_file)

        # Initialize TensorFlow.
        tflib.init_tf()

        tf.global_variables_initializer().run()

        for var in tf.trainable_variables("discriminator"):
            print(var.name)

        # The saver must be created before the NVlabs networks exist so that
        # it only covers the taki0112 variables stored in the checkpoint.
        gan.saver = tf.train.Saver(max_to_keep=10)
        gan.load(checkpoint_dir)

        #
        #   Make an NVlabs StyleGAN network (default initialization)
        #

        # Network build options; change them here if needed.  These are the
        # only settings that influence the exported networks.
        G_args = EasyDict(func_name="training.networks_stylegan.G_style")
        D_args = EasyDict(func_name="training.networks_stylegan.D_basic")
        tf_config = {"rnd.np_random_seed": 1000}

        # Number of GPUs: only the supported counts are accepted.
        gpu_num = args.gpu_num
        if gpu_num not in (1, 2, 4, 8):
            print("ERROR: invalid number of gpus:", gpu_num)
            sys.exit(-1)

        # Initialize dnnlib and TensorFlow.
        tflib.init_tf(tf_config)

        # Construct networks.
        with tf.device('/gpu:0'):
            print('Constructing networks...')
            dataset_resolution = args.img_size
            dataset_channels = 3  # assumed RGB; not read from the dataset
            dataset_label_size = 0  # unconditional
            G = tflib.Network('G',
                num_channels=dataset_channels,
                resolution=dataset_resolution,
                label_size=dataset_label_size,
                **G_args)
            D = tflib.Network('D',
                num_channels=dataset_channels,
                resolution=dataset_resolution,
                label_size=dataset_label_size,
                **D_args)
            Gs = G.clone('Gs')
        G.print_layers(); D.print_layers()

        src_d = "discriminator"
        src_gs = "generator/g_synthesis"
        src_gm = "generator/g_mapping"

        # Copy over the discriminator weights
        _copy_scope_weights(sess, "D", src_d)

        # Copy over the generator's mapping and synthesis network weights
        _copy_scope_weights(sess, "G_mapping", src_gm)
        _copy_scope_weights(sess, "G_synthesis", src_gs, transpose_mismatched=True)

        # also update the running average network (not 100% sure this is necessary)
        _copy_scope_weights(sess, "G_mapping_1", src_gm)
        _copy_scope_weights(sess, "G_synthesis_1", src_gs, transpose_mismatched=True)

        # Also, assign the w_avg in the taki0112 network to the NVlabs
        # dlatent_avg.  The scope "G" is a prefix filter, so it covers both
        # G and the Gs clone; every match is updated rather than only the
        # first one, so that Gs is not left with its initial value.
        dlatent_avg_vars = [
            v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="G")
            if "dlatent_avg" in str(v)
        ]
        if not dlatent_avg_vars:
            raise IndexError("no 'dlatent_avg' variable found under scope 'G'")
        w_avg = [
            v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="generator")
            if "avg" in str(v)
        ][0]
        sess.run([tf.assign(v, w_avg) for v in dlatent_avg_vars])

        misc.save_pkl((G, D, Gs), "./"+nvlabs_stylegan_pkl_name)