def create_model(data_shape, full=False, labels=None, kwargs_in=None):
    # build StyleGAN2 network(s) for a [channels, height, width] data shape;
    # returns (G, D, Gs), where G and D are None unless full is True
    init_res, resolution, res_log2 = calc_init_res(data_shape[1:])
    kwargs_out = dnnlib.EasyDict()
    kwargs_out.num_channels = data_shape[0]
    if kwargs_in is not None:
        for k in list(kwargs_in.keys()):
            kwargs_out[k] = kwargs_in[k]
    if labels is not None:
        kwargs_out.label_size = labels
    kwargs_out.resolution = resolution
    kwargs_out.init_res = init_res
    if a.verbose is True:
        print(['%s: %s' % (kv[0], kv[1]) for kv in sorted(kwargs_out.items())])
    if full is True:
        G = tflib.Network('G', func_name='training.networks_stylegan2.G_main', **kwargs_out)
        D = tflib.Network('D', func_name='training.networks_stylegan2.D_stylegan2', **kwargs_out)
        Gs = G.clone('Gs')
    else:
        Gs = tflib.Network('Gs', func_name='training.networks_stylegan2.G_main', **kwargs_out)
        G = D = None
    return G, D, Gs
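# Example (a minimal sketch, not part of the original script): building a generator-only
# network for an RGB 1024x1024 model; passing full=True would also build fresh G and D subnets.
def _example_build_gs_only():
    data_shape = [3, 1024, 1024]                     # [channels, height, width]
    _, _, Gs = create_model(data_shape, full=False)  # G and D come back as None
    return Gs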
def run(dataset, train_dir, config, d_aug, diffaug_policy, cond, ops, jpg_data, mirror, mirror_v,
        lod_step_kimg, batch_size, resume, resume_kimg, finetune, num_gpus, ema_kimg, gamma, freezeD):

    # dataset (tfrecords) - preprocess or get
    tfr_files = file_list(os.path.dirname(dataset), 'tfr')
    tfr_files = [f for f in tfr_files if basename(dataset) in f]
    if len(tfr_files) == 0:
        tfr_file, total_samples = create_from_images(dataset, jpg=jpg_data)
    else:
        tfr_file = tfr_files[0]
    dataset_args = EasyDict(tfrecord=tfr_file, jpg_data=jpg_data)
    desc = basename(tfr_file).split('-')[0]

    # training functions
    if d_aug: # https://github.com/mit-han-lab/data-efficient-gans
        train = EasyDict(run_func_name='training.training_loop_diffaug.training_loop') # Options for training loop (Diff Augment method)
        loss_args = EasyDict(func_name='training.loss_diffaug.ns_DiffAugment_r1', policy=diffaug_policy) # Options for loss (Diff Augment method)
    else: # original nvidia
        train = EasyDict(run_func_name='training.training_loop.training_loop') # Options for training loop (original from NVidia)
        G_loss = EasyDict(func_name='training.loss.G_logistic_ns_pathreg') # Options for generator loss.
        D_loss = EasyDict(func_name='training.loss.D_logistic_r1') # Options for discriminator loss.

    # network functions
    G = EasyDict(func_name='training.networks_stylegan2.G_main') # Options for generator network.
    D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2') # Options for discriminator network.
    G_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for generator optimizer.
    D_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for discriminator optimizer.
    sched = EasyDict() # Options for TrainingSchedule.
    grid = EasyDict(size='1080p', layout='random') # Options for setup_snapshot_image_grid().
    sc = dnnlib.SubmitConfig() # Options for dnnlib.submit_run().
    tf_config = {'rnd.np_random_seed': 1000} # Options for tflib.init_tf().

    G.impl = D.impl = ops

    # resolutions
    data_res = basename(tfr_file).split('-')[-1].split('x') # get resolution from dataset filename
    data_res = list(reversed([int(x) for x in data_res])) # convert to int list
    init_res, resolution, res_log2 = calc_init_res(data_res)
    if init_res != [4,4]:
        print(' custom init resolution', init_res)
    G.init_res = D.init_res = list(init_res)

    train.setname = desc + config
    desc = '%s-%d-%s' % (desc, resolution, config)

    # training schedule
    sched.lod_training_kimg = lod_step_kimg
    sched.lod_transition_kimg = lod_step_kimg
    train.total_kimg = lod_step_kimg * res_log2 * 2 # a la ProGAN
    if finetune:
        train.total_kimg = 15000 # should start from ~10k kimg
    train.image_snapshot_ticks = 1
    train.network_snapshot_ticks = 5
    train.mirror_augment = mirror
    train.mirror_augment_v = mirror_v

    # learning rate
    if config == 'e':
        if finetune: # uptrain 1024
            sched.G_lrate_base = 0.001
        else: # train 1024
            sched.G_lrate_base = 0.001
            sched.G_lrate_dict = {0: 0.001, 1: 0.0007, 2: 0.0005, 3: 0.0003}
            sched.lrate_step = 1500 # period for stepping to next lrate, in kimg
    if config == 'f':
        # sched.G_lrate_base = 0.0003
        sched.G_lrate_base = 0.001
    sched.D_lrate_base = sched.G_lrate_base # *2 - not used anyway

    sched.minibatch_gpu_base = batch_size
    sched.minibatch_size_base = num_gpus * sched.minibatch_gpu_base
    sc.num_gpus = num_gpus

    if config == 'e':
        G.fmap_base = D.fmap_base = 8 << 10
        if d_aug:
            loss_args.gamma = 100 if gamma is None else gamma
        else:
            D_loss.gamma = 100 if gamma is None else gamma
    elif config == 'f':
        G.fmap_base = D.fmap_base = 16 << 10
    else:
        print(' Only configs E and F are implemented')
        exit()

    if cond:
        desc += '-cond'
        dataset_args.max_label_size = 'full' # conditioned on full label

    if freezeD:
        D.freezeD = True
        train.resume_with_new_nets = True

    if d_aug:
        desc += '-daug'

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(G_args=G, D_args=D, G_opt_args=G_opt, D_opt_args=D_opt)
    kwargs.update(dataset_args=dataset_args, sched_args=sched, grid_args=grid, tf_config=tf_config)
    kwargs.update(resume_pkl=resume, resume_kimg=resume_kimg, resume_with_new_nets=True)
    if ema_kimg is not None:
        kwargs.update(G_ema_kimg=ema_kimg)
    if d_aug:
        kwargs.update(loss_args=loss_args)
    else:
        kwargs.update(G_loss_args=G_loss, D_loss_args=D_loss)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = train_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
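# Sketch of a typical call to run() (an illustration, not repository defaults):
# single-GPU finetuning of config F with differentiable augmentation; the paths
# and numeric values below are placeholder assumptions.
def _example_run_finetune():
    run(dataset='data/mydataset', train_dir='train', config='f',
        d_aug=True, diffaug_policy='color,translation,cutout', cond=False,
        ops='cuda', jpg_data=False, mirror=True, mirror_v=False,
        lod_step_kimg=300, batch_size=4, resume='models/ffhq-1024.pkl', resume_kimg=0,
        finetune=True, num_gpus=1, ema_kimg=10, gamma=None, freezeD=False)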
def main():
    tflib.init_tf({'allow_soft_placement': True})
    G_in, D_in, Gs_in = load_pkl(a.source)
    print(' Loading model', a.source, Gs_in.output_shape)
    _, res_in, _ = calc_init_res(Gs_in.output_shape[1:])

    if a.res is not None or a.alpha is True:
        if a.res is None: a.res = Gs_in.output_shape[2:]
        colors = 4 if a.alpha is True else Gs_in.output_shape[1] # EXPERIMENTAL
        _, res_out, _ = calc_init_res([colors, *a.res])

        if res_in != res_out or a.alpha is True: # add or remove layers
            assert G_in is not None and D_in is not None, " !! G/D subnets not found in source model !!"
            data_shape = [colors, res_out, res_out]
            print(' Reconstructing full model with shape', data_shape)
            G_out, D_out, Gs_out = create_model(data_shape, True, 0, Gs_in.static_kwargs)
            copy_vars(Gs_in, Gs_out)
            copy_vars(G_in, G_out)
            copy_vars(D_in, D_out, D=True)
            G_in, D_in, Gs_in = G_out, D_out, Gs_out
            a.full = True

        if a.res[0] != res_out or a.res[1] != res_out: # crop or pad layers
            data_shape = [colors, *a.res]
            G_out, D_out, Gs_out = create_model(data_shape, True, 0, Gs_in.static_kwargs)
            if G_in is not None and D_in is not None:
                print(' Reconstructing full model with shape', data_shape)
                copy_and_crop_or_pad_trainables(G_in, G_out)
                copy_and_crop_or_pad_trainables(D_in, D_out)
                G_in, D_in = G_out, D_out
                a.full = True
            else:
                print(' Reconstructing Gs model with shape', data_shape)
            copy_and_crop_or_pad_trainables(Gs_in, Gs_out) # Gs is resized in both cases
            Gs_in = Gs_out

    if a.labels is not None:
        assert G_in is not None and D_in is not None, " !! G/D subnets not found in source model !!"
        print(' Reconstructing full model with labels', a.labels)
        data_shape = Gs_in.output_shape[1:]
        G_out, D_out, Gs_out = create_model(data_shape, True, a.labels, Gs_in.static_kwargs)
        if a.verbose is True: D_out.print_layers()
        if a.verbose is True: G_out.print_layers()
        copy_and_fill_trainables(G_in, G_out)
        copy_and_fill_trainables(D_in, D_out)
        copy_and_fill_trainables(Gs_in, Gs_out)
        a.full = True

    if a.labels is None and a.res is None and a.alpha is not True:
        if a.reconstruct is True:
            print(' Reconstructing model with same size /', 'full' if a.full else 'Gs')
            data_shape = Gs_in.output_shape[1:]
            G_out, D_out, Gs_out = create_model(data_shape, a.full, 0, Gs_in.static_kwargs)
            Gs_out.copy_vars_from(Gs_in)
            if a.full is True and G_in is not None and D_in is not None:
                G_out.copy_vars_from(G_in)
                D_out.copy_vars_from(D_in)
        else:
            Gs_out = Gs_in

    out_name = basename(a.source)
    if a.res is not None:    out_name += '-%dx%d' % (a.res[1], a.res[0])
    if a.alpha is True:      out_name += 'a'
    if a.labels is not None: out_name += '-c%d' % a.labels

    if a.full is True: # G_in is not None and D_in is not None
        save_pkl((G_out, D_out, Gs_out), os.path.join(a.out_dir, '%s.pkl' % out_name))
    else:
        save_pkl(Gs_out, os.path.join(a.out_dir, '%s-Gs.pkl' % out_name))
    print(' Done')
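# Sketch (an assumption, not from the original file): loading a converted snapshot back
# and sampling one image to check it, via the standard StyleGAN2 tflib Network API.
# Assumes load_pkl() returns a (G, D, Gs) tuple for full snapshots and a single Gs otherwise.
def _example_check_converted(pkl_path):
    import numpy as np
    networks = load_pkl(pkl_path)
    Gs = networks[-1] if isinstance(networks, tuple) else networks  # full or Gs-only pickle
    z = np.random.randn(1, *Gs.input_shape[1:])                     # latent batch of 1
    images = Gs.run(z, None, truncation_psi=0.7, randomize_noise=False)
    print(' sample shape', images.shape)                            # NCHW float array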