def create_initial_pkl(
        G_args={},             # Options for generator network.
        D_args={},             # Options for discriminator network.
        tf_config={},          # Options for tflib.init_tf().
        config_id="config-f",  # config-f is the only one tested ...
        num_channels=3,        # number of channels, e.g. 3 for RGB
        resolution_h=1024,     # height dimension of real/fake images
        resolution_w=1024,     # width dimension of real/fake images
        label_size=0,          # number of labels for a conditional model
):
    """Construct fresh G, D and Gs networks and pickle them in the working directory.

    The networks are built on '/gpu:0' from ``G_args`` / ``D_args``. ``Gs`` is a
    clone of ``G``. The tuple ``(G, D, Gs)`` is written to
    ``network-initial-<config_id>-<w>x<h>-<label_size>.pkl``.

    Note that only ``resolution_h`` is forwarded to the networks as their
    ``resolution``; ``resolution_w`` is used for the output file name only.
    """
    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)

    # Keyword arguments shared by the generator and the discriminator.
    shared_kwargs = dict(
        num_channels=num_channels,
        resolution=resolution_h,
        label_size=label_size,
    )

    # Construct networks.
    with tf.device('/gpu:0'):
        print('Constructing networks...')
        G = tflib.Network('G', **shared_kwargs, **G_args)
        D = tflib.Network('D', **shared_kwargs, **D_args)
        Gs = G.clone('Gs')

    # Print layers, then write the initial snapshot.
    G.print_layers()
    D.print_layers()

    name_parts = (config_id, resolution_w, resolution_h, label_size)
    pkl = 'network-initial-%s-%sx%s-%s.pkl' % name_parts
    misc.save_pkl((G, D, Gs), pkl)
    print("Saving", pkl)
def create_model(data_shape, full=False, kwargs_in=None):
    """Build StyleGAN2 network(s) for data of the given shape.

    Args:
        data_shape: indexable whose first entry is the channel count and whose
            remaining entries are passed to ``calc_init_res``.
        full: when exactly ``True``, build G, D and a ``Gs`` clone of G;
            otherwise build only ``Gs``.
        kwargs_in: optional mapping of extra network keyword arguments. Its
            entries override ``num_channels`` / ``label_size`` but not
            ``resolution`` / ``init_res``, which are assigned afterwards.

    Returns:
        ``(G, D, Gs, res_log2)``; ``G`` and ``D`` are ``None`` unless ``full``.
    """
    init_res, resolution, res_log2 = calc_init_res(data_shape[1:])

    net_kwargs = dnnlib.EasyDict()
    net_kwargs.num_channels = data_shape[0]
    net_kwargs.label_size = 0
    if kwargs_in is not None:
        for key in kwargs_in.keys():
            net_kwargs[key] = kwargs_in[key]
    # Assigned last so that caller-supplied values cannot replace them.
    net_kwargs.resolution = resolution
    net_kwargs.init_res = init_res

    if a.verbose is True:
        print(['%s: %s' % (key, value) for key, value in sorted(net_kwargs.items())])

    # Generator-only mode.
    if full is not True:
        Gs = tflib.Network('Gs', func_name='training.networks_stylegan2.G_main', **net_kwargs)
        return None, None, Gs, res_log2

    # Full mode: generator, discriminator and the averaged-generator clone.
    G = tflib.Network('G', func_name='training.networks_stylegan2.G_main', **net_kwargs)
    D = tflib.Network('D', func_name='training.networks_stylegan2.D_stylegan2', **net_kwargs)
    Gs = G.clone('Gs')
    return G, D, Gs, res_log2
def G_main_spatial_biased_dsp(
        latents_in,  # First input: Latent vectors (Z) [minibatch, latent_size].
        labels_in,  # Second input: Conditioning labels [minibatch, label_size].
        is_training=False,  # Network is under training? Enables and disables specific features.
        is_validation=False,  # Network is under validation? Chooses which value to use for truncation_psi.
        return_dlatents=False,  # Return dlatents in addition to the images?
        is_template_graph=False,  # True = template graph constructed by the Network class, False = actual evaluation.
        components=dnnlib.EasyDict(),  # Container for sub-networks. Retained between calls.
        mapping_func='G_mapping_spatial_biased_dsp',  # Build func name for the mapping network.
        synthesis_func='G_synthesis_spatial_biased_dsp',  # Build func name for the synthesis network.
        **kwargs):  # Arguments for sub-networks (mapping and synthesis).
    """Generator entry point: run the mapping network, then the synthesis network.

    The sub-networks are created on first use and cached in ``components``;
    their build functions are looked up by name in this module's globals.
    Returns ``images_out``, or ``(images_out, dlatents)`` when
    ``return_dlatents`` is set.
    """
    # Training and validation modes are mutually exclusive.
    assert not (is_training and is_validation)

    # Lazily create the sub-networks (synthesis first, then mapping).
    if 'synthesis' not in components:
        synthesis_build = globals()[synthesis_func]
        components.synthesis = tflib.Network(
            'G_spatial_biased_synthesis_dsp', func_name=synthesis_build, **kwargs)
    if 'mapping' not in components:
        mapping_build = globals()[mapping_func]
        components.mapping = tflib.Network(
            'G_spatial_biased_mapping_dsp', func_name=mapping_build,
            dlatent_broadcast=None, **kwargs)

    # Level-of-detail variable owned by this network.
    lod_in = tf.get_variable('lod', initializer=np.float32(0), trainable=False)

    # Z -> W.
    mapped = components.mapping.get_output_for(
        latents_in, labels_in, is_training=is_training, **kwargs)
    dlatents = tf.cast(mapped, tf.float32)

    # Copy our lod into the synthesis network (if it has one) before evaluating it.
    lod_sync_ops = []
    if 'lod' in components.synthesis.vars:
        lod_sync_ops.append(tf.assign(components.synthesis.vars['lod'], lod_in))
    with tf.control_dependencies(lod_sync_ops):
        images_out = components.synthesis.get_output_for(
            dlatents,
            is_training=is_training,
            force_clean_graph=is_template_graph,
            **kwargs)

    # Name the output tensor and hand back what was asked for.
    images_out = tf.identity(images_out, name='images_out')
    return (images_out, dlatents) if return_dlatents else images_out
def load_model(
        url='https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ',  # karras2019stylegan-ffhq-1024x1024.pkl
        session=None,
        cache_dir='cache'):
    """Download a pickled (G, D, Gs) triple and rebuild it with the local build functions.

    Each unpickled network is re-created as a ``tflib.Network`` with the same
    name and static kwargs (``G_style`` for G and Gs, ``D_basic`` for D), and
    the variables are copied across.

    ``session`` falls back to the default TensorFlow session when falsy.

    NOTE: the file is read with ``pickle.load``, which can execute arbitrary
    code; only pass URLs whose contents are trusted.

    Returns:
        ``(G, D, Gs)`` rebuilt networks.
    """
    def _rebuild(pickled_net, build_func):
        # Fresh network with identical configuration, then copy the weights over.
        rebuilt = tflib.Network(pickled_net.name, build_func, **pickled_net.static_kwargs)
        rebuilt.copy_vars_from(pickled_net)
        return rebuilt

    if not session:
        session = tf.get_default_session()

    with session.as_default():
        with dnnlib.util.open_url(url, cache_dir=cache_dir) as f:
            _G, _D, _Gs = pickle.load(f)
        G = _rebuild(_G, G_style)
        D = _rebuild(_D, D_basic)
        Gs = _rebuild(_Gs, G_style)
    return G, D, Gs
def load_perceptual(
        url='https://drive.google.com/uc?id=1N2-m9qszOeVC9Tq77WxsLnuWwOedQiD2',  # vgg16_zhang_perceptual.pkl
        session=None,
        cache_dir='cache'):
    """Download the pickled perceptual network and rebuild it with ``lpips_network``.

    The pickle is loaded first; the replacement ``tflib.Network`` is then
    created inside ``session`` (the default TensorFlow session when ``session``
    is falsy) and receives a copy of the pickled variables.

    NOTE: the file is read with ``pickle.load``, which can execute arbitrary
    code; only pass URLs whose contents are trusted.

    Returns:
        The rebuilt perceptual network.
    """
    if not session:
        session = tf.get_default_session()

    with dnnlib.util.open_url(url, cache_dir=cache_dir) as f:
        pickled_net = pickle.load(f)

    with session.as_default():
        P = tflib.Network(pickled_net.name, lpips_network, **pickled_net.static_kwargs)
        P.copy_vars_from(pickled_net)
    return P
def create_model(data_shape, full=False):
    """Build StyleGAN2 network(s) for data of the given shape.

    Args:
        data_shape: indexable whose first entry is the channel count and whose
            remaining entries are passed to ``calc_init_res``.
        full: when exactly ``True``, build G, D and a ``Gs`` clone of G;
            otherwise build only ``Gs``.

    Returns:
        ``(G, D, Gs, res_log2)``; ``G`` and ``D`` are ``None`` unless ``full``.
    """
    init_res, resolution, res_log2 = calc_init_res(data_shape[1:])

    # The same keyword arguments are used for every network built here.
    net_kwargs = dnnlib.EasyDict(
        resolution=resolution,
        init_res=init_res,
        num_channels=data_shape[0],
        label_size=0,
    )

    generator_func = 'training.networks_stylegan2.G_main'

    # Generator-only mode.
    if full is not True:
        Gs = tflib.Network('Gs', func_name=generator_func, **net_kwargs)
        return None, None, Gs, res_log2

    # Full mode: generator, discriminator and the averaged-generator clone.
    G = tflib.Network('G', func_name=generator_func, **net_kwargs)
    D = tflib.Network('D', func_name='training.networks_stylegan2.D_stylegan2', **net_kwargs)
    Gs = G.clone('Gs')
    return G, D, Gs, res_log2
def Decoder_main(
        latents_in,  # First input: Latent vectors (Z) [minibatch, latent_size].
        labels_in,  # Second input: Conditioning labels [minibatch, label_size].
        is_training=False,  # Network is under training? Enables and disables specific features.
        return_dlatents=False,  # Return dlatents in addition to the images?
        is_template_graph=False,  # True = template graph constructed by the Network class, False = actual evaluation.
        components=dnnlib.EasyDict(),  # Container for sub-networks. Retained between calls.
        mapping_func='Decoder_mapping',  # Build func name for the mapping network.
        synthesis_func='Decoder_synthesis',  # Build func name for the synthesis network.
        **kwargs):
    """Decoder entry point: run the mapping network, then the synthesis network.

    The sub-networks are created on first use and cached in ``components``;
    their build functions are looked up by name in this module's globals.
    Returns ``images_out``, or ``(images_out, dlatents)`` when
    ``return_dlatents`` is set.
    """
    # Lazily create the sub-networks (synthesis first, then mapping).
    if 'synthesis' not in components:
        synthesis_build = globals()[synthesis_func]
        components.synthesis = tflib.Network(
            'G_synthesis', func_name=synthesis_build, **kwargs)
    if 'mapping' not in components:
        mapping_build = globals()[mapping_func]
        components.mapping = tflib.Network(
            'G_mapping', func_name=mapping_build, **kwargs)

    # Z -> W.
    mapped = components.mapping.get_output_for(
        latents_in, labels_in, is_training=is_training, **kwargs)
    dlatents = tf.cast(mapped, tf.float32)

    # W -> images.
    synthesized = components.synthesis.get_output_for(
        dlatents,
        is_training=is_training,
        force_clean_graph=is_template_graph,
        **kwargs)

    # Name the output tensor and hand back what was asked for.
    images_out = tf.identity(synthesized, name='images_out')
    return (images_out, dlatents) if return_dlatents else images_out
def get_Gs(opt):
    """Yield ``(Gs, nz)`` for every model listed in ``opt.models``, one at a time.

    For each ``(exp_number, snapshot_kimg)`` pair the checkpoint is resolved as:
    ``find_model(exp_number)`` if that returns something truthy; otherwise
    ``exp_number`` itself when it ends with '.pkl'; otherwise whatever
    ``find_pkl`` locates under ``<cwd>/<config.result_dir>``.

    The pickled generator is copied into a freshly constructed ProGAN
    ``G_paper`` network (3 channels, resolution 128, no labels) whose latent
    size ``nz`` is read from the pickled network's first input shape.
    """
    for exp_number, snapshot_kimg in zip(opt.models, opt.snapshot_kimgs):
        # Resolve the checkpoint path.
        resume_pkl = find_model(exp_number)
        if not resume_pkl:
            if exp_number.endswith('.pkl'):
                resume_pkl = exp_number
            else:
                # Look for a pkl in the results directory.
                results_dir = os.path.join(os.getcwd(), config.result_dir)
                resume_pkl = find_pkl(results_dir, int(exp_number), snapshot_kimg)

        # Load the pickled generator and mirror it into a new network.
        tflib.init_tf()
        _, _, pickled_Gs = load_pkl(resume_pkl)
        nz = pickled_Gs.input_shapes[0][1]
        network_kwargs = dict(
            name='Gs',
            func_name='training.networks_progan.G_paper',
            latent_size=nz,
            num_channels=3,
            resolution=128,
            label_size=0,
        )
        Gs = tflib.Network(**network_kwargs)
        Gs.copy_vars_from(pickled_Gs)

        print(f'Visualizing pkl: {resume_pkl} with seed={opt.seed}')
        has_small_z = nz < 12
        if has_small_z and not opt.interpolate_pre_norm:
            print(f'Model {exp_number} uses a small z vector (nz={nz}); you might want to add '
                  f'--interpolate_pre_norm to your command.')
        yield Gs, nz
def embed(batch_size, resolution, imgs, network, iteration, result_dir, seed=6600):
    """Optimize a dlatent (and a scalar ``T``) per image so the generator reproduces it.

    For every entry of ``imgs`` the variables ``dlatent`` and ``T`` are reset
    and then updated for ``iteration`` Adam steps. Per-image loss curves and
    image grids, plus summary files over all images, are written into
    ``result_dir``.

    Args:
        batch_size: not referenced in the body.
        resolution: not referenced in the body (the network is built with
            ``resolution=128``).
        imgs: iterable of images; each gets a leading axis added before being
            fed to ``img_in``.
        network: passed to ``pretrained_networks.load_networks``.
        iteration: number of optimization steps per image.
        result_dir: directory receiving the output files.
        seed: seed of the ``RandomState`` used to fill the noise variables.

    NOTE(review): ``loss`` combines MSE and perceptual terms, but the gradients
    are taken from ``mse_loss`` only -- confirm this is intended.
    NOTE(review): ``si_list`` is only appended when ``i % 500 == 0``, so the
    'sifinal' image is the last sampled snapshot, and ``si_list`` is empty when
    ``iteration`` is 0.
    """
    tf.reset_default_graph()
    G_args = dnnlib.EasyDict(func_name='training.networks_stylegan2_alpha.G_main')
    G_args.fmap_base = 8 << 10
    print('Loading networks from "%s"...' % network)
    tflib.init_tf()
    # Build a fresh generator and copy the pretrained Gs weights into it.
    G = tflib.Network('G', num_channels=3, resolution=128, **G_args)
    _, _, Gs = pretrained_networks.load_networks(network)
    G.copy_vars_from(Gs)
    img_in = tf.placeholder(tf.float32)
    # Separate Adam optimizers (different learning rates) for dlatent and for T.
    opt = tf.train.AdamOptimizer(learning_rate=0.01, beta1=0.9, beta2=0.999, epsilon=1e-8)
    opt_T = tf.train.AdamOptimizer(learning_rate=0.002, beta1=0.9, beta2=0.999, epsilon=1e-8)
    # Synthesis variables selected purely by name prefix / suffix.
    noise_vars = [var for name, var in G.components.synthesis.vars.items() if name.startswith('noise')]
    alpha_vars = [var for name, var in G.components.synthesis.vars.items() if name.endswith('alpha')]
    alpha_evals = [alpha.eval() for alpha in alpha_vars]
    G_kwargs = dnnlib.EasyDict()
    G_kwargs.randomize_noise = False
    G_syn = G.components.synthesis
    rnd = np.random.RandomState(seed)
    # Start from the generator's dlatent_avg: add two leading axes, then repeat 12 times along axis 1.
    dlatent_avg = [var for name, var in G.vars.items() if name.startswith('dlatent_avg')][0].eval()
    dlatent_avg = np.expand_dims(np.expand_dims(dlatent_avg, 0), 1)
    dlatent_avg = dlatent_avg.repeat(12, 1)
    dlatent = tf.get_variable('dlatent', dtype=tf.float32, initializer=tf.constant(dlatent_avg), trainable=True)
    T = tf.get_variable('T', dtype=tf.float32, initializer=tf.constant(0.95))
    # Evaluated alpha values are rescaled by T via scale_alpha_exp (defined elsewhere) and fed to the synthesis network.
    alpha_pre = [scale_alpha_exp(alpha_eval, T) for alpha_eval in alpha_evals]
    synth_img = G_syn.get_output_for(dlatent, is_training=False, alpha_pre=alpha_pre, **G_kwargs)
    # synth_img = (synth_img + 1.0) / 2.0
    with tf.variable_scope('mse_loss'):
        mse_loss = tf.reduce_mean(tf.square(img_in - synth_img))
    with tf.variable_scope('perceptual_loss'):
        # Target and synthesized image are stacked so one VGG16 pass covers both;
        # index 0 of each feature map is the target, index 1 the synthesized image.
        vgg_in = tf.concat([img_in, synth_img], 0)
        tf.keras.backend.set_image_data_format('channels_first')
        # NOTE(review): the weights path is hard-coded to a machine-specific location.
        vgg = tf.keras.applications.VGG16(include_top=False, input_tensor=vgg_in, input_shape=(3, 128, 128), weights='/gdata2/fengrl/metrics/vgg.h5', pooling=None)
        h1 = vgg.get_layer('block1_conv1').output
        h2 = vgg.get_layer('block1_conv2').output
        h3 = vgg.get_layer('block3_conv2').output
        h4 = vgg.get_layer('block4_conv2').output
        pcep_loss = tf.reduce_mean(tf.square(h1[0] - h1[1])) + tf.reduce_mean(tf.square(h2[0] - h2[1])) + \
                    tf.reduce_mean(tf.square(h3[0] - h3[1])) + tf.reduce_mean(tf.square(h4[0] - h4[1]))
    loss = 0.5 * mse_loss + 0.5 * pcep_loss
    with tf.control_dependencies([loss]):
        # Gradients come from mse_loss alone; `loss` is only a control dependency here.
        grads = tf.gradients(mse_loss, [dlatent, T])
        train_op1 = opt.apply_gradients(zip([grads[0]], [dlatent]))
        train_op2 = opt_T.apply_gradients(zip([grads[1]], [T]))
        train_op = tf.group(train_op1, train_op2)
    # Ops that restore the optimizer state and the optimized variables before each image.
    reset_opt = tf.variables_initializer(opt.variables()+opt_T.variables())
    reset_dl = tf.variables_initializer([dlatent, T])
    tflib.init_uninitialized_vars()
    # rnd = np.random.RandomState(seed)
    # Fill every noise variable once with seeded Gaussian noise.
    tflib.set_vars({var: rnd.randn(*var.shape.as_list()) for var in noise_vars})  # [height, width]
    idx = 0
    # Final per-image values: total loss, perceptual loss, MSE, squared distance to dlatent_avg, and T.
    metrics_l = []
    metrics_p = []
    metrics_m = []
    metrics_d = []
    T_list = []
    for img in imgs:
        img = np.expand_dims(img, 0)
        loss_list = []
        p_loss_list = []
        m_loss_list = []
        dl_list = []
        si_list = []
        # tflib.set_vars({alpha: alpha_np for alpha, alpha_np in zip(alpha_vars, alpha_evals)})
        tflib.run([reset_opt, reset_dl])
        for i in range(iteration):
            loss_, p_loss_, m_loss_, dl_, si_, t_, _ = tflib.run([loss, pcep_loss, mse_loss, dlatent, synth_img, T, train_op], {img_in: img})
            loss_list.append(loss_)
            p_loss_list.append(p_loss_)
            m_loss_list.append(m_loss_)
            # Squared L2 distance between the current dlatent and its starting point.
            dl_loss_ = np.sum(np.square(dl_-dlatent_avg))
            dl_list.append(dl_loss_)
            if i % 500 == 0:
                si_list.append(si_)
            if i % 100 == 0:
                # The 'ppl' field of this message carries the perceptual loss value.
                print('idx %d, Loss %f, mse %f, ppl %f, dl %f, t %f, step %d' % (idx, loss_, m_loss_, p_loss_, dl_loss_, t_, i))
        print('T: %f, loss: %f, ppl: %f, mse: %f, d: %f' % (t_, loss_list[-1], p_loss_list[-1], m_loss_list[-1], dl_list[-1]))
        metrics_l.append(loss_list[-1])
        metrics_p.append(p_loss_list[-1])
        metrics_m.append(m_loss_list[-1])
        metrics_d.append(dl_list[-1])
        T_list.append(t_)
        # Grid of the sampled snapshots, then the last sampled snapshot on its own.
        misc.save_image_grid(np.concatenate(si_list, 0), os.path.join(result_dir, 'si%d.png' % idx), drange=[-1, 1])
        misc.save_image_grid(si_list[-1], os.path.join(result_dir, 'sifinal%d.png' % idx), drange=[-1, 1])
        # One text file per metric, one value per optimization step.
        with open(os.path.join(result_dir, 'metric_l%d.txt' % idx), 'w') as f:
            for l_ in loss_list:
                f.write(str(l_) + '\n')
        with open(os.path.join(result_dir, 'metric_p%d.txt' % idx), 'w') as f:
            for l_ in p_loss_list:
                f.write(str(l_) + '\n')
        with open(os.path.join(result_dir, 'metric_m%d.txt' % idx), 'w') as f:
            for l_ in m_loss_list:
                f.write(str(l_) + '\n')
        with open(os.path.join(result_dir, 'metric_d%d.txt' % idx), 'w') as f:
            for l_ in dl_list:
                f.write(str(l_) + '\n')
        idx += 1

    # Aggregate the final per-image metrics.
    l_mean = np.mean(metrics_l)
    p_mean = np.mean(metrics_p)
    m_mean = np.mean(metrics_m)
    d_mean = np.mean(metrics_d)
    # First line: the evaluated alpha values; then one "T loss mse perceptual dl" row per image.
    with open(os.path.join(result_dir, 'metric_lmpd.txt'), 'w') as f:
        f.write(str(alpha_evals)+'\n')
        for i in range(len(metrics_l)):
            f.write(str(T_list[i])+' '+str(metrics_l[i])+' '+str(metrics_m[i])+' '+str(metrics_p[i])+' '+str(metrics_d[i])+'\n')

    print('Overall metrics: loss_mean %f, ppl_mean %f, mse_mean %f, d_mean %f' % (l_mean, p_mean, m_mean, d_mean))
    with open(os.path.join(result_dir, 'mean_metrics.txt'), 'w') as f:
        f.write('loss %f\n' % l_mean)
        f.write('mse %f\n' % m_mean)
        f.write('ppl %f\n' % p_mean)
        f.write('dl %f\n' % d_mean)
def G_main(
        latents_in,  # First input: Latent vectors (Z) [minibatch, latent_size].
        labels_in,  # Second input: Conditioning labels [minibatch, label_size].
        latmask,  # mask for split-frame latents blending
        dconst,  # initial (const) layer displacement
        truncation_psi=0.5,  # Style strength multiplier for the truncation trick. None = disable.
        truncation_cutoff=None,  # Number of layers for which to apply the truncation trick. None = disable.
        truncation_psi_val=None,  # Value for truncation_psi to use during validation.
        truncation_cutoff_val=None,  # Value for truncation_cutoff to use during validation.
        dlatent_avg_beta=0.995,  # Decay for tracking the moving average of W during training. None = disable.
        style_mixing_prob=0.9,  # Probability of mixing styles during training. None = disable.
        is_training=False,  # Network is under training? Enables and disables specific features.
        is_validation=False,  # Network is under validation? Chooses which value to use for truncation_psi.
        return_dlatents=False,  # Return dlatents in addition to the images?
        is_template_graph=False,  # True = template graph constructed by the Network class, False = actual evaluation.
        components=dnnlib.EasyDict(),  # Container for sub-networks. Retained between calls.
        mapping_func='G_mapping',  # Build func name for the mapping network.
        synthesis_func='G_synthesis_stylegan2',  # Build func name for the synthesis network.
        **kwargs):  # Arguments for sub-networks (mapping and synthesis).
    """Main generator: mapping network, W-space regularizers, then synthesis network.

    Steps, in graph order:
      1. Create (or reuse from ``components``) the synthesis and mapping
         sub-networks; build functions are looked up by name in module globals.
      2. Map ``latents_in`` / ``labels_in`` to dlatents.
      3. During training, update the ``dlatent_avg`` moving average.
      4. During training, apply style mixing: layers at or beyond a randomly
         chosen cutoff take their dlatents from a second random latent.
      5. Outside training, apply the truncation trick towards ``dlatent_avg``.
      6. Evaluate the synthesis network with ``latmask`` and ``dconst``.

    Returns ``images_out``, or ``(images_out, dlatents)`` when
    ``return_dlatents`` is set.

    Fix relative to the previous revision: the ``tf.where`` that actually mixes
    ``dlatents`` with ``dlatents2`` had been swallowed by a triple-quoted
    string, so ``mixing_cutoff`` was computed but never used and style mixing
    had no effect. The mixing step is applied again.
    """
    # Validate arguments.
    assert not is_training or not is_validation
    assert isinstance(components, dnnlib.EasyDict)
    if is_validation:
        truncation_psi = truncation_psi_val
        truncation_cutoff = truncation_cutoff_val
    # Truncation is skipped while training and when psi is the constant 1 (a no-op).
    if is_training or (truncation_psi is not None and not tflib.is_tf_expression(truncation_psi) and truncation_psi == 1):
        truncation_psi = None
    if is_training:
        truncation_cutoff = None
    # The W moving average and style mixing are training-only features.
    if not is_training or (dlatent_avg_beta is not None and not tflib.is_tf_expression(dlatent_avg_beta) and dlatent_avg_beta == 1):
        dlatent_avg_beta = None
    if not is_training or (style_mixing_prob is not None and not tflib.is_tf_expression(style_mixing_prob) and style_mixing_prob <= 0):
        style_mixing_prob = None

    # Setup components.
    if 'synthesis' not in components:
        components.synthesis = tflib.Network(
            'G_synthesis', func_name=globals()[synthesis_func], **kwargs)
    num_layers = components.synthesis.input_shape[1]
    dlatent_size = components.synthesis.input_shape[2]
    if 'mapping' not in components:
        components.mapping = tflib.Network(
            'G_mapping', func_name=globals()[mapping_func],
            dlatent_broadcast=num_layers, **kwargs)

    # Setup variables.
    lod_in = tf.get_variable('lod', initializer=np.float32(0), trainable=False)
    dlatent_avg = tf.get_variable(
        'dlatent_avg', shape=[dlatent_size],
        initializer=tf.initializers.zeros(), trainable=False)

    # Evaluate mapping network.
    dlatents = components.mapping.get_output_for(
        latents_in, labels_in, is_training=is_training, **kwargs)
    dlatents = tf.cast(dlatents, tf.float32)

    # Update moving average of W.
    if dlatent_avg_beta is not None:
        with tf.variable_scope('DlatentAvg'):
            batch_avg = tf.reduce_mean(dlatents[:, 0], axis=0)
            update_op = tf.assign(
                dlatent_avg, tflib.lerp(batch_avg, dlatent_avg, dlatent_avg_beta))
            # Route dlatents through the update so it runs whenever they are consumed.
            with tf.control_dependencies([update_op]):
                dlatents = tf.identity(dlatents)

    # Perform style mixing regularization.
    if style_mixing_prob is not None:
        with tf.variable_scope('StyleMix'):
            latents2 = tf.random_normal(tf.shape(latents_in))
            dlatents2 = components.mapping.get_output_for(
                latents2, labels_in, is_training=is_training, **kwargs)
            dlatents2 = tf.cast(dlatents2, tf.float32)
            layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
            cur_layers = num_layers - tf.cast(lod_in, tf.int32) * 2
            # One cutoff per batch: random with probability style_mixing_prob,
            # otherwise cur_layers (i.e. no mixing among the active layers).
            # A per-sample cutoff (DiffAugment-style, via tf.where_v2 over a
            # [minibatch] vector) would be the alternative; it is not used here.
            mixing_cutoff = tf.cond(
                tf.random_uniform([], 0.0, 1.0) < style_mixing_prob,
                lambda: tf.random_uniform([], 1, cur_layers, dtype=tf.int32),
                lambda: cur_layers)
            # Layers below the cutoff keep dlatents, the rest take dlatents2.
            dlatents = tf.where(
                tf.broadcast_to(layer_idx < mixing_cutoff, tf.shape(dlatents)),
                dlatents, dlatents2)

    # Apply truncation trick.
    if truncation_psi is not None:
        with tf.variable_scope('Truncation'):
            layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
            layer_psi = np.ones(layer_idx.shape, dtype=np.float32)
            if truncation_cutoff is None:
                layer_psi *= truncation_psi
            else:
                layer_psi = tf.where(
                    layer_idx < truncation_cutoff,
                    layer_psi * truncation_psi, layer_psi)
            dlatents = tflib.lerp(dlatent_avg, dlatents, layer_psi)

    # Evaluate synthesis network.
    deps = []
    if 'lod' in components.synthesis.vars:
        # Keep the synthesis network's lod in sync with ours before it runs.
        deps.append(tf.assign(components.synthesis.vars['lod'], lod_in))
    with tf.control_dependencies(deps):
        images_out = components.synthesis.get_output_for(
            dlatents, latmask, dconst,
            is_training=is_training,
            force_clean_graph=is_template_graph,
            **kwargs)

    # Return requested outputs.
    images_out = tf.identity(images_out, name='images_out')
    if return_dlatents:
        return images_out, dlatents
    return images_out
def training_loop(
        submit_config,
        #---------------------------------------------------------------
        # Modified by Deng et al.
        noise_dim=32,
        weight_args={},
        train_stage_args={},
        #---------------------------------------------------------------
        G_args={},  # Options for generator network.
        D_args={},  # Options for discriminator network.
        G_opt_args={},  # Options for generator optimizer.
        D_opt_args={},  # Options for discriminator optimizer.
        G_loss_args={},  # Options for generator loss.
        D_loss_args={},  # Options for discriminator loss.
        dataset_args={},  # Options for dataset.load_dataset().
        sched_args={},  # Options for train.TrainingSchedule.
        grid_args={},  # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],  # Options for MetricGroup.
        tf_config={},  # Options for tflib.init_tf().
        G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
        D_repeats=1,  # How many times the discriminator is trained per G iteration.
        minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
        reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg=15000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=True,  # Enable mirror augment?
        drange_net=[
            -1, 1
        ],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=1,  # How often to export image snapshots?
        network_snapshot_ticks=10,  # How often to export network snapshots?
        save_tf_graph=True,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
        resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
        resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
        **_kwargs
):
    """Main training loop (variant modified by Deng et al.).

    Loads the training set, builds or resumes G / D / Gs, builds the per-GPU
    loss graph through ``dnnlib.util.call_func_by_name`` (configured by
    ``train_stage_args``), then alternates discriminator and generator updates
    until ``total_kimg`` thousand images have been processed or the run context
    asks to stop. Image grids, network pickles, metrics and summaries are
    written to ``submit_config.run_dir``.

    Extra arguments of this variant:
        noise_dim: size of the noise part of the generator input; the
            generator is built with ``latent_size=254 + noise_dim``.
        weight_args: forwarded to the loss-building function.
        train_stage_args: keyword arguments for the loss-building call and for
            ``restore_weights_and_initialize``.
        _kwargs: accepted and ignored.

    NOTE(review): the snapshot grid uses literal sizes (128 + 32 + 16 + 3
    latent dimensions, 32 noise dimensions, render size 256) rather than
    ``noise_dim`` -- confirm they stay consistent when ``noise_dim`` changes.
    """
    # Initialize dnnlib and TensorFlow.
    PI = 3.1415927  # Not referenced again in this function.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **dataset_args)

    # Create 3d face reconstruction block
    FaceRender = Face3D()

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...')
            #---------------------------------------------------------------
            # Modified by Deng et al.
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              latent_size=254 + noise_dim,
                              **G_args)
            #---------------------------------------------------------------
            D = tflib.Network('D',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **D_args)
            Gs = G.clone('Gs')
    G.print_layers()
    D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        # Scalar placeholders fed on every training step.
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        resolution = tf.placeholder(tf.float32, name='resolution', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        # Decay for the Gs moving average: 0.5 ** (minibatch / (G_smoothing_kimg * 1000)).
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0

    G_opt = tflib.Optimizer(name='TrainG', learning_rate=lrate_in, **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', learning_rate=lrate_in, **D_opt_args)
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % (gpu)):
            # GPU 0 uses the original networks; every other GPU gets a '_shadow' clone.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in)
            ]
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(reals, lod_in, mirror_augment,
                                  training_set.dynamic_range, drange_net)
            #---------------------------------------------------------------
            # Modified by Deng et al.
            # The function to call is not named here; presumably train_stage_args
            # supplies it (e.g. a func_name entry) -- TODO confirm.
            G_loss,D_loss = dnnlib.util.call_func_by_name(FaceRender=FaceRender,noise_dim=noise_dim,weight_args=weight_args,\
                G_gpu=G_gpu,D_gpu=D_gpu,G_opt=G_opt,D_opt=D_opt,training_set=training_set,G_loss_args=G_loss_args,D_loss_args=D_loss_args,\
                lod_assign_ops=lod_assign_ops,reals=reals,labels=labels,minibatch_split=minibatch_split,resolution=resolution,\
                drange_net=drange_net,lod_in=lod_in,**train_stage_args)
            #---------------------------------------------------------------
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device('/gpu:0'):
        # Fall back to a constant 0 when the memory-stats op is not available.
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)

    #---------------------------------------------------------------
    # Modified by Deng et al.
    restore_weights_and_initialize(train_stage_args)

    print('Setting up snapshot image grid...')
    # Schedule evaluated at the end of training (cur_nimg = total_kimg * 1000) for the initial grid.
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              num_gpus=submit_config.num_gpus,
                              **sched_args)

    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        G, training_set, **grid_args)
    # Random latents are mapped to coefficients by z_to_lambda_mapping (defined elsewhere).
    grid_latents = tf.random_normal([np.prod(grid_size), 128 + 32 + 16 + 3])
    grid_INPUTcoeff = z_to_lambda_mapping(grid_latents)

    # Three zero columns are appended before rendering.
    grid_INPUTcoeff_w_t = tf.concat(
        [grid_INPUTcoeff, tf.zeros([np.prod(grid_size), 3])], axis=1)
    with tf.name_scope('FaceRender'):
        grid_render_img, _, _, _ = FaceRender.Reconstruction_Block(
            grid_INPUTcoeff_w_t, 256, np.prod(grid_size), progressive=False)
        # Presumably NHWC -> NCHW to match the network image layout -- TODO confirm.
        grid_render_img = tf.transpose(grid_render_img, perm=[0, 3, 1, 2])
        grid_render_img = process_reals(grid_render_img, lod_in, False,
                                        training_set.dynamic_range,
                                        drange_net)

    grid_INPUTcoeff_, grid_renders = tflib.run(
        [grid_INPUTcoeff, grid_render_img], {lod_in: sched.lod})
    # Fixed generator input for all later snapshots: coefficients followed by 32 noise values.
    grid_noise = np.random.randn(np.prod(grid_size), 32)
    grid_INPUTcoeff_w_noise = np.concatenate([grid_INPUTcoeff_, grid_noise],
                                             axis=1)

    grid_fakes = Gs.run(grid_INPUTcoeff_w_noise,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch // submit_config.num_gpus)
    # Generated images and renders are joined along axis 3 into one snapshot image.
    grid_fakes = np.concatenate([grid_fakes, grid_renders], axis=3)
    misc.save_image_grid(grid_fakes,
                         os.path.join(submit_config.run_dir,
                                      'fakes%06d.png' % resume_kimg),
                         drange=drange_net,
                         grid_size=grid_size)
    misc.save_image_grid(grid_reals,
                         os.path.join(submit_config.run_dir, 'reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    #---------------------------------------------------------------

    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  num_gpus=submit_config.num_gpus,
                                  **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus,
                               sched.lod)
        if reset_opt_for_new_lod:
            # Reset when lod's floor or ceiling changed since the previous iteration.
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch,
                        resolution: sched.resolution
                    })
                # Progress is counted per discriminator step.
                cur_nimg += sched.minibatch
            tflib.run(
                [G_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch,
                    resolution: sched.resolution
                })
            # print('iter')

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days',
                        total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                #---------------------------------------------------------------
                # Modified by Deng et al.
                # Same fixed input and renders as the initial grid.
                grid_fakes = Gs.run(grid_INPUTcoeff_w_noise,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch //
                                    submit_config.num_gpus)
                grid_fakes = np.concatenate([grid_fakes, grid_renders],
                                            axis=3)
                misc.save_image_grid(grid_fakes,
                                     os.path.join(
                                         submit_config.run_dir,
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
                #---------------------------------------------------------------
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(
                    submit_config.run_dir,
                    'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl,
                            run_dir=submit_config.run_dir,
                            num_gpus=submit_config.num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod,
                       cur_epoch=cur_nimg // 1000,
                       max_epoch=total_kimg)
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    misc.save_pkl((G, D, Gs),
                  os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()
    ctx.close()
def training_loop(
        G_args={},  # Options for generator network.
        D_args={},  # Options for discriminator network.
        G_opt_args={},  # Options for generator optimizer.
        D_opt_args={},  # Options for discriminator optimizer.
        AE_opt_args=None,  # Options for autoencoder optimizer.
        G_loss_args={},  # Options for generator loss.
        D_loss_args={},  # Options for discriminator loss.
        AE_loss_args=None,  # Options for autoencoder loss (not read in this function).
        dataset_args={},  # Options for dataset.load_dataset() (training set).
        dataset_args_eval={},  # Options for dataset.load_dataset() (evaluation set).
        sched_args={},  # Options for train.TrainingSchedule.
        grid_args={},  # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],  # Options for MetricGroup.
        tf_config={},  # Options for tflib.init_tf().
        train_data_dir=None,  # Directory to load the training dataset from.
        eval_data_dir=None,  # Directory to load the evaluation dataset from.
        G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
        minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
        lazy_regularization=True,  # Perform regularization as a separate training step?
        G_reg_interval=4,  # How often to perform regularization for G? Ignored if lazy_regularization=False.
        D_reg_interval=16,  # How often to perform regularization for D? Ignored if lazy_regularization=False.
        reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg=25000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=False,  # Enable mirror augment?
        drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=50,  # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'.
        network_snapshot_ticks=50,  # How often to save network snapshots? None = only save 'networks-final.pkl'.
        save_tf_graph=True,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=True,  # Include weight histograms in the tfevents file?
        resume_pkl=None,  # Network pickle to resume training from, None or '' = train from scratch.
        resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
        resume_with_new_nets=False,  # Construct new networks according to G_args and D_args before resuming training?
        resume_with_own_vars=False):  # Also construct networks from G_args/D_args before loading the pickle.
    """Run the full G/D training procedure, including the optional SVD re-factorization.

    The function, in order: initializes TensorFlow, loads the training and
    evaluation datasets, writes 'reals.png', optionally derives a
    discriminator ``train_scope`` when ``D_args['freeze']`` is truthy,
    constructs and/or loads (G, D, Gs), rebuilds the networks with
    ``factorized=True`` when ``'syn_svd'`` or ``'map_svd'`` is present in
    ``G_args``, writes 'fakes_init.png', builds the per-GPU training graph,
    and then runs the tick-based training loop with periodic image/network
    snapshots and metric evaluation. A final 'network-final.pkl' is written
    before the datasets and the summary writer are closed.

    Side effects on arguments: ``D_args``, ``G_args``, ``G_loss_args`` and
    ``sched_args`` are mutated in place (see the inline comments), so callers
    should not rely on them being unchanged afterwards.
    """
    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Both None (the documented default) and '' mean "train from scratch".
    # The previous identity checks against '' relied on string interning and
    # sent None straight into misc.load_pkl().
    resuming = resume_pkl is not None and resume_pkl != ''

    # Load training and evaluation sets.
    # NOTE(review): attribute access on dataset_args implies an EasyDict-like
    # object is expected here, not a plain dict -- confirm against callers.
    print("Loading train set from %s..." % dataset_args.tfrecord_dir)
    training_set = dataset.load_dataset(
        data_dir=dnnlib.convert_path(train_data_dir),
        verbose=True,
        **dataset_args)
    print("Loading eval set from %s..." % dataset_args_eval.tfrecord_dir)
    eval_set = dataset.load_dataset(
        data_dir=dnnlib.convert_path(eval_data_dir),
        verbose=True,
        **dataset_args_eval)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        training_set, **grid_args)
    misc.save_image_grid(grid_reals,
                         dnnlib.make_run_dir_path('reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)

    # Freeze Discriminator: build a list of scope patterns and store it back
    # into D_args['train_scope'] (mutates the caller's dict).
    # .get() keeps a missing 'freeze' key from raising KeyError.
    if D_args.get('freeze'):
        num_layers = np.log2(training_set.resolution) - 1
        layers = int(np.round(num_layers * 3. / 8.))
        scope = ['Output', 'scores_out']
        for layer in range(layers):
            scope += ['.*%d' % 2**layer]
            if 'train_scope' in D_args:
                scope[-1] += '.*%d' % D_args['train_scope']
        D_args['train_scope'] = scope

    # Construct or load networks.
    with tf.device('/gpu:0'):
        if not resuming or resume_with_new_nets or resume_with_own_vars:
            print('Constructing networks...')
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **G_args)
            D = tflib.Network('D',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **D_args)
            Gs = G.clone('Gs')
        if resuming:
            print('Loading networks from "%s"...' % resume_pkl)
            rG, rD, rGs = misc.load_pkl(resume_pkl)
            if resume_with_new_nets:
                G.copy_vars_from(rG)
                D.copy_vars_from(rD)
                Gs.copy_vars_from(rGs)
            else:
                G = rG
                D = rD
                Gs = rGs

    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])

    # SVD stuff
    if 'syn_svd' in G_args or 'map_svd' in G_args:
        # Run graph to calculate SVD
        grid_latents_smol = grid_latents[:1]
        rho = np.array([1])
        grid_fakes = G.run(grid_latents_smol, grid_labels, rho,
                           is_validation=True)
        grid_fakes = Gs.run(grid_latents_smol, grid_labels, rho,
                            is_validation=True)
        # The returned value is not used before being overwritten below.
        load_d_fake = D.run(grid_reals[:1], rho, is_validation=True)

        with tf.device('/gpu:0'):
            # Create SVD-decomposed graph
            rG, rD, rGs = G, D, Gs
            G_lambda_mask = {
                var: np.ones(G.vars[var].shape[-1])
                for var in G.vars if 'SVD/s' in var
            }
            D_lambda_mask = {
                'D/' + var: np.ones(D.vars[var].shape[-1])
                for var in D.vars if 'SVD/s' in var
            }
            G_reduce_dims = {
                var: (0, int(Gs.vars[var].shape[-1]))
                for var in Gs.vars if 'SVD/s' in var
            }
            # These writes mutate the caller's G_args / D_args.
            G_args['lambda_mask'] = G_lambda_mask
            G_args['reduce_dims'] = G_reduce_dims
            D_args['lambda_mask'] = D_lambda_mask

            # Create graph with no SVD operations
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=rG.input_shapes[1][1],
                              factorized=True,
                              **G_args)
            D = tflib.Network('D',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=rD.input_shapes[1][1],
                              factorized=True,
                              **D_args)
            Gs = G.clone('Gs')

            grid_fakes = G.run(grid_latents_smol, grid_labels, rho,
                               is_validation=True, minibatch_size=1)
            grid_fakes = Gs.run(grid_latents_smol, grid_labels, rho,
                                is_validation=True, minibatch_size=1)

            G.copy_vars_from(rG)
            D.copy_vars_from(rD)
            Gs.copy_vars_from(rGs)

    # Reduce per-gpu minibatch size to fit in 16GB GPU memory
    # (mutates the caller's sched_args).
    if grid_reals.shape[2] >= 1024:
        sched_args.minibatch_gpu_base = 2
    print('Batch size', sched_args.minibatch_gpu_base)

    # Generate initial image snapshot.
    G.print_layers()
    D.print_layers()
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              **sched_args)
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    rho = np.array([1])
    grid_fakes = Gs.run(grid_latents, grid_labels, rho,
                        is_validation=True,
                        minibatch_size=sched.minibatch_gpu)
    misc.save_image_grid(grid_fakes,
                         dnnlib.make_run_dir_path('fakes_init.png'),
                         drange=drange_net,
                         grid_size=grid_size)
    if resuming:
        # Compare discriminator outputs of the loaded vs. current networks.
        load_d_real = rD.run(grid_reals[:1], rho, is_validation=True)
        load_d_fake = rD.run(grid_fakes[:1], rho, is_validation=True)
        d_fake = D.run(grid_fakes[:1], rho, is_validation=True)
        d_real = D.run(grid_reals[:1], rho, is_validation=True)
        print('Factorized fake', d_fake, 'loaded fake', load_d_fake,
              'factorized real', d_real, 'loaded real', load_d_real)
        print('(should match)')

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32,
                                           name='minibatch_size_in',
                                           shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32,
                                          name='minibatch_gpu_in',
                                          shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in *
                                                     num_gpus)
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_size_in, tf.float32),
                              G_smoothing_kimg *
                              1000.0) if G_smoothing_kimg > 0.0 else 0.0

    # Setup optimizers. The option dicts are copied so the caller's
    # G_opt_args / D_opt_args are left untouched.
    G_opt_args = dict(G_opt_args)
    D_opt_args = dict(D_opt_args)
    for args, reg_interval in [(G_opt_args, G_reg_interval),
                               (D_opt_args, D_reg_interval)]:
        args['minibatch_multiplier'] = minibatch_multiplier
        args['learning_rate'] = lrate_in
        if lazy_regularization:
            # Rescale learning rate and Adam betas by reg_interval / (reg_interval + 1).
            mb_ratio = reg_interval / (reg_interval + 1)
            args['learning_rate'] *= mb_ratio
            if 'beta1' in args:
                args['beta1'] **= mb_ratio
            if 'beta2' in args:
                args['beta2'] **= mb_ratio
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', **D_opt_args)
    G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args)
    D_reg_opt = tflib.Optimizer(name='RegD', share=D_opt, **D_opt_args)
    if AE_opt_args is not None:
        # The autoencoder optimizer is constructed but no gradients are
        # registered with it anywhere in this function.
        AE_opt_args = dict(AE_opt_args)
        AE_opt_args['minibatch_multiplier'] = minibatch_multiplier
        AE_opt_args['learning_rate'] = lrate_in
        AE_opt = tflib.Optimizer(name='TrainAE', **AE_opt_args)

    # Build training graph for each GPU.
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):

            # Create GPU-specific shadow copies of G and D.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                sched = training_schedule(cur_nimg=int(resume_kimg * 1000),
                                          training_set=training_set,
                                          **sched_args)
                reals_var = tf.Variable(
                    name='reals',
                    trainable=False,
                    initial_value=tf.zeros([sched.minibatch_gpu] +
                                           training_set.shape))
                labels_var = tf.Variable(
                    name='labels',
                    trainable=False,
                    initial_value=tf.zeros(
                        [sched.minibatch_gpu, training_set.label_size]))
                reals_write, labels_write = training_set.get_minibatch_tf()
                reals_write, labels_write = process_reals(
                    reals_write, labels_write, lod_in, mirror_augment,
                    training_set.dynamic_range, drange_net)
                # Pad with the tail of the existing buffer so the assigned
                # tensor keeps the variable's full leading dimension.
                reals_write = tf.concat(
                    [reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                labels_write = tf.concat(
                    [labels_write, labels_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_var, reals_write)]
                data_fetch_ops += [tf.assign(labels_var, labels_write)]
                reals_read = reals_var[:minibatch_gpu_in]
                labels_read = labels_var[:minibatch_gpu_in]

            # Evaluate loss functions.
            lod_assign_ops = []
            if 'lod' in G_gpu.vars:
                lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)]
            if 'lod' in D_gpu.vars:
                lod_assign_ops += [tf.assign(D_gpu.vars['lod'], lod_in)]
            with tf.control_dependencies(lod_assign_ops):
                with tf.name_scope('G_loss'):
                    # NOTE(review): the G_l1 branch only stores the reals in
                    # G_loss_args (mutating the caller's dict) and never binds
                    # G_loss / G_reg, so the gradient registration below would
                    # fail for that loss. Left as found -- confirm the intent.
                    if G_loss_args['func_name'] == 'training.loss.G_l1':
                        G_loss_args['reals'] = reals_read
                    else:
                        G_loss, G_reg = dnnlib.util.call_func_by_name(
                            G=G_gpu,
                            D=D_gpu,
                            opt=G_opt,
                            training_set=training_set,
                            minibatch_size=minibatch_gpu_in,
                            **G_loss_args)
                with tf.name_scope('D_loss'):
                    D_loss, D_reg = dnnlib.util.call_func_by_name(
                        G=G_gpu,
                        D=D_gpu,
                        opt=D_opt,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        reals=reals_read,
                        labels=labels_read,
                        **D_loss_args)

            # Register gradients.
            if not lazy_regularization:
                if G_reg is not None:
                    G_loss += G_reg
                if D_reg is not None:
                    D_loss += D_reg
            else:
                if G_reg is not None:
                    G_reg_opt.register_gradients(
                        tf.reduce_mean(G_reg * G_reg_interval),
                        G_gpu.trainables)
                if D_reg is not None:
                    D_reg_opt.register_gradients(
                        tf.reduce_mean(D_reg * D_reg_interval),
                        D_gpu.trainables)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()
    G_reg_op = G_reg_opt.apply_updates(allow_no_op=True)
    D_reg_op = D_reg_opt.apply_updates(allow_no_op=True)
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('',
                                   cur_epoch=resume_kimg,
                                   max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0
    # Only the gradient-accumulation path fetches the generator loss; start
    # from None so the snapshot report cannot hit an unbound name.
    _g_loss = None
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        feed_dict = {
            lod_in: sched.lod,
            lrate_in: sched.G_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu
        }
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size,
                           sched.minibatch_gpu * num_gpus)
            run_G_reg = (lazy_regularization
                         and running_mb_counter % G_reg_interval == 0)
            run_D_reg = (lazy_regularization
                         and running_mb_counter % D_reg_interval == 0)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                tflib.run([G_train_op, data_fetch_op], feed_dict)
                if run_G_reg:
                    tflib.run(G_reg_op, feed_dict)
                tflib.run([D_train_op, Gs_update_op], feed_dict)
                if run_D_reg:
                    tflib.run(D_reg_op, feed_dict)

            # Slow path with gradient accumulation.
            else:
                for _round in rounds:
                    _g_loss, _ = tflib.run([G_loss, G_train_op], feed_dict)
                if run_G_reg:
                    for _round in rounds:
                        tflib.run(G_reg_op, feed_dict)
                tflib.run(Gs_update_op, feed_dict)
                for _round in rounds:
                    tflib.run(data_fetch_op, feed_dict)
                    tflib.run(D_train_op, feed_dict)
                if run_D_reg:
                    for _round in rounds:
                        tflib.run(D_reg_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start(
            ) + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch_size),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days',
                        total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (
                    cur_tick % image_snapshot_ticks == 0 or done):
                if _g_loss is not None:
                    print('g loss', _g_loss)
                # NOTE(review): every other Gs.run() call in this function
                # passes `rho` as a third input; this one does not. Left as
                # found -- confirm whether `rho` is required here.
                grid_fakes = Gs.run(grid_latents, grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch_gpu)
                misc.save_image_grid(grid_fakes,
                                     dnnlib.make_run_dir_path(
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
            # Parenthesized so that network_snapshot_ticks=None really means
            # "only the final pickle", matching the parameter comment.
            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' %
                                               (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl,
                            run_dir=dnnlib.make_run_dir_path(),
                            data_dir=dnnlib.convert_path(eval_data_dir),
                            num_gpus=num_gpus,
                            tf_config=tf_config,
                            rho=rho)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % sched.lod,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get(
            ).get_last_update_interval() - tick_time

    # Save final snapshot.
    misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
    eval_set.close()
def training_loop(
        classifier_args={},  # Options for the classifier network.
        classifier_opt_args={},  # Options for the classifier optimizer.
        classifier_loss_args={},  # Options for the classifier loss.
        dataset_args={},  # Options for dataset.load_dataset().
        sched_args={},  # Options for train.TrainingSchedule.
        metric_arg_list=[],  # Options for MetricGroup.
        tf_config={},  # Options for tflib.init_tf().
        data_dir=None,  # Directory to load datasets from.
        minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
        total_kimg=25000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=False,  # Enable mirror augment?
        drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
        network_snapshot_ticks=5,  # How often to save network snapshots? None = only save 'networks-final.pkl'.
        save_tf_graph=False):  # Include full TensorFlow computation graph in the tfevents file?
    """Train a single 'classifier' network and periodically snapshot it.

    Builds one tower per GPU that reads a minibatch through temporary
    variables, evaluates the loss named by ``classifier_loss_args`` and
    registers its gradients with one shared optimizer. The outer loop then
    advances in ticks: it reports progress through ``autosummary``, pickles
    the classifier, runs the metric group on the pickle, and finally writes
    'network-final.pkl'.
    """
    # Bring up dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus
    total_nimg = total_kimg * 1000

    # Training data.
    training_set = dataset.load_dataset(
        data_dir=dnnlib.convert_path(data_dir),
        verbose=True,
        shuffle_mb=2 * 4096,
        **dataset_args)

    # The network is always built from scratch; there is no resume path here.
    with tf.device('/gpu:0'):
        print('Constructing networks...')
        classifier = tflib.Network(
            'classifier',
            num_channels=training_set.shape[0],
            resolution=training_set.shape[1],
            label_size=training_set.label_size,
            **classifier_args)
    classifier.print_layers()

    # Placeholders fed on every training step.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(
            tf.int32, name='minibatch_size_in', shape=[])
        minibatch_gpu_in = tf.placeholder(
            tf.int32, name='minibatch_gpu_in', shape=[])
        minibatch_multiplier = minibatch_size_in // (
            minibatch_gpu_in * num_gpus)

    # Optimizer; work on a copy so the caller's dict is not modified.
    opt_kwargs = dict(classifier_opt_args)
    opt_kwargs['minibatch_multiplier'] = minibatch_multiplier
    opt_kwargs['learning_rate'] = lrate_in
    classifier_opt = tflib.Optimizer(name='TrainClassifier', **opt_kwargs)

    # One tower per GPU.
    data_fetch_ops = []
    for gpu_idx in range(num_gpus):
        with tf.name_scope('gpu%d' % gpu_idx), tf.device('/gpu:%d' % gpu_idx):
            # GPU 0 uses the network itself, the others a shadow clone.
            if gpu_idx == 0:
                tower_net = classifier
            else:
                tower_net = classifier.clone(classifier.name + '_shadow')

            # Stage the minibatch in non-trainable buffer variables.
            with tf.name_scope('DataFetch'):
                sched = training_schedule(cur_nimg=0, **sched_args)
                reals_buf = tf.Variable(
                    name='reals',
                    trainable=False,
                    initial_value=tf.zeros(
                        [sched.minibatch_gpu] + training_set.shape))
                # NOTE(review): the label width is hard-coded to 127 rather
                # than read from training_set.label_size -- confirm.
                labels_buf = tf.Variable(
                    name='labels',
                    trainable=False,
                    initial_value=tf.zeros([sched.minibatch_gpu, 127]))
                fresh_reals, fresh_labels = training_set.get_minibatch_tf()
                fresh_reals, fresh_labels = process_reals(
                    fresh_reals, fresh_labels, mirror_augment,
                    training_set.dynamic_range, drange_net)
                # Append the tail of each buffer so the assigned tensor keeps
                # the buffer's full leading dimension.
                fresh_reals = tf.concat(
                    [fresh_reals, reals_buf[minibatch_gpu_in:]], axis=0)
                fresh_labels = tf.concat(
                    [fresh_labels, labels_buf[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_buf, fresh_reals)]
                data_fetch_ops += [tf.assign(labels_buf, fresh_labels)]
                reals_read = reals_buf[:minibatch_gpu_in]
                labels_read = labels_buf[:minibatch_gpu_in]

            # Loss and gradients for this tower.
            with tf.name_scope('classifier_loss'):
                classifier_loss, label = dnnlib.util.call_func_by_name(
                    classifier=tower_net,
                    images=reals_read,
                    labels=labels_read,
                    **classifier_loss_args)
            classifier_opt.register_gradients(
                tf.reduce_mean(classifier_loss), tower_net.trainables)

    # Training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    classifier_train_op = classifier_opt.apply_updates()

    # Peak-memory probe; falls back to a constant when the op is missing.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('', cur_epoch=0, max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = 0
    cur_tick = -1
    tick_start_nimg = cur_nimg
    running_mb_counter = 0
    while cur_nimg < total_nimg:
        if dnnlib.RunContext.get().should_stop():
            break

        # Pick the schedule for the current position in training.
        sched = training_schedule(cur_nimg=cur_nimg, **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu)

        # Feed values shared by every step of this schedule.
        feeds = {
            lrate_in: sched.G_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu,
        }
        for _repeat in range(minibatch_repeats):
            accum_rounds = range(0, sched.minibatch_size,
                                 sched.minibatch_gpu * num_gpus)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            if len(accum_rounds) != 1:
                # Gradient accumulation: several fetch+train rounds.
                for _round in accum_rounds:
                    tflib.run(data_fetch_op, feeds)
                    classifier_loss_out, label_out, _ = tflib.run(
                        [classifier_loss, label, classifier_train_op], feeds)
                    # Manual debugging switch; disabled.
                    dump_values = False
                    if dump_values:
                        print('label')
                        print(np.round(label_out, 2))
                        print('loss')
                        print(np.round(classifier_loss_out, 2))
            else:
                # Whole minibatch fits in a single round.
                tflib.run([classifier_train_op, data_fetch_op], feeds)

        # Once-per-tick maintenance.
        finished = (cur_nimg >= total_nimg)
        tick_due = (cur_tick < 0
                    or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000
                    or finished)
        if not tick_due:
            continue

        cur_tick += 1
        tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
        tick_start_nimg = cur_nimg
        tick_time = dnnlib.RunContext.get().get_time_since_last_update()
        total_time = dnnlib.RunContext.get().get_time_since_start()

        # Progress line; each value is also recorded as an autosummary.
        progress_values = (
            autosummary('Progress/tick', cur_tick),
            autosummary('Progress/kimg', cur_nimg / 1000.0),
            autosummary('Progress/minibatch', sched.minibatch_size),
            dnnlib.util.format_time(
                autosummary('Timing/total_sec', total_time)),
            autosummary('Timing/sec_per_tick', tick_time),
            autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
            autosummary('Timing/maintenance_sec', maintenance_time),
            autosummary('Resources/peak_gpu_mem_gb',
                        peak_gpu_mem_op.eval() / 2**30),
        )
        print(
            'tick %-5d kimg %-8.1f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f'
            % progress_values)
        autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
        autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

        # Network snapshot followed by metric evaluation on it.
        if network_snapshot_ticks is not None and (
                cur_tick % network_snapshot_ticks == 0 or finished):
            snapshot_path = dnnlib.make_run_dir_path(
                'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
            misc.save_pkl(classifier, snapshot_path)
            metrics.run(snapshot_path,
                        run_dir=dnnlib.make_run_dir_path(),
                        data_dir=dnnlib.convert_path(data_dir),
                        num_gpus=num_gpus,
                        tf_config=tf_config)

        # Flush summaries and refresh the run context.
        metrics.update_autosummaries()
        tflib.autosummary.save_summaries(summary_log, cur_nimg)
        dnnlib.RunContext.get().update('%.2f' % 0,
                                       cur_epoch=cur_nimg // 1000,
                                       max_epoch=total_kimg)
        maintenance_time = (
            dnnlib.RunContext.get().get_last_update_interval() - tick_time)

    # Final snapshot.
    misc.save_pkl(classifier, dnnlib.make_run_dir_path('network-final.pkl'))

    # Release resources.
    summary_log.close()
    training_set.close()
def training_loop(
        G_args={},  # Options for generator network.
        D_args={},  # Options for discriminator network.
        G_opt_args={},  # Options for generator optimizer.
        D_opt_args={},  # Options for discriminator optimizer.
        loss_args={},  # Options for loss.
        dataset_args={},  # Options for dataset.load_dataset().
        sched_args={},  # Options for train.TrainingSchedule.
        grid_args={},  # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],  # Options for metrics.
        metric_args={},  # Options for MetricGroup.
        tf_config={},  # Options for tflib.init_tf().
        ema_start_kimg=None,  # Start of the exponential moving average. Default to the half-life period.
        G_ema_kimg=10,  # Half-life of the exponential moving average of generator weights.
        minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
        lazy_regularization=False,  # Perform regularization as a separate training step?
        G_reg_interval=4,  # How often to perform regularization for G? Ignored if lazy_regularization=False.
        D_reg_interval=4,  # How often to perform regularization for D? Ignored if lazy_regularization=False.
        reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg=25000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=False,  # Enable mirror augment?
        drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=2,  # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'.
        network_snapshot_ticks=1,  # How often to save network snapshots? None = only save 'networks-final.pkl'.
        save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_pkl=None,  # Network pickle to resume training from, None = train from scratch.
        resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
        resume_with_new_nets=False):  # Construct new networks according to G_args and D_args before resuming training?
    """Train G and D with a single joint loss function and a gated Gs average.

    One loss callable (named by ``loss_args``) returns the generator loss, the
    discriminator loss and an optional discriminator regularizer. The moving
    average ``Gs`` is updated as a control dependency of the generator step;
    its beta is multiplied by a fed value that is 0 until ``cur_nimg`` reaches
    ``ema_start_kimg * 1000`` and 1 afterwards. Each tick reports progress,
    optionally writes an image grid and a network pickle, and runs metrics.
    """
    # The averaging start defaults to the averaging half-life.
    if ema_start_kimg is None:
        ema_start_kimg = G_ema_kimg

    # Bring up dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus
    total_nimg = total_kimg * 1000

    # Training data plus the reference grid of real images.
    training_set = dataset.load_dataset(verbose=True, **dataset_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        training_set, **grid_args)
    misc.save_image_grid(grid_reals,
                         dnnlib.make_run_dir_path('reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)

    # Build fresh networks and/or load them from a pickle.
    with tf.device('/gpu:0'):
        if resume_pkl is None or resume_with_new_nets:
            print('Constructing networks...')
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **G_args)
            D = tflib.Network('D',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **D_args)
            Gs = G.clone('Gs')
        if resume_pkl is not None:
            resume_networks = misc.load_pkl(resume_pkl)
            rG, rD, rGs = resume_networks
            if not resume_with_new_nets:
                # Continue with the loaded network objects themselves.
                G, D, Gs = rG, rD, rGs
            else:
                # Keep the fresh networks, take over the loaded variables.
                G.copy_vars_from(rG)
                D.copy_vars_from(rD)
                Gs.copy_vars_from(rGs)

    # Layer listing and the initial grid of generated images.
    G.print_layers()
    D.print_layers()
    sched = training_schedule(cur_nimg=total_nimg,
                              training_set=training_set,
                              **sched_args)
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    grid_fakes = Gs.run(grid_latents, grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch_gpu)
    misc.save_image_grid(grid_fakes,
                         dnnlib.make_run_dir_path('fakes_init.png'),
                         drange=drange_net,
                         grid_size=grid_size)

    # Placeholders fed on every training step.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        G_lrate_in = tf.placeholder(tf.float32, name='G_lrate_in', shape=[])
        D_lrate_in = tf.placeholder(tf.float32, name='D_lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(
            tf.int32, name='minibatch_size_in', shape=[])
        minibatch_gpu_in = tf.placeholder(
            tf.int32, name='minibatch_gpu_in', shape=[])
        run_D_reg_in = tf.placeholder(tf.bool, name='run_D_reg', shape=[])
        minibatch_multiplier = minibatch_size_in // (
            minibatch_gpu_in * num_gpus)
        Gs_beta_mul_in = tf.placeholder(
            tf.float32, name='Gs_beta_in', shape=[])
        if G_ema_kimg > 0.0:
            Gs_beta = 0.5**tf.div(tf.cast(minibatch_size_in, tf.float32),
                                  G_ema_kimg * 1000.0)
        else:
            Gs_beta = 0.0

    # Optimizers; the option dicts are copied before being extended.
    g_opt_kwargs = dict(G_opt_args)
    d_opt_kwargs = dict(D_opt_args)
    g_opt_kwargs['learning_rate'] = G_lrate_in
    d_opt_kwargs['learning_rate'] = D_lrate_in
    for opt_kwargs in (g_opt_kwargs, d_opt_kwargs):
        opt_kwargs['minibatch_multiplier'] = minibatch_multiplier
    G_opt = tflib.Optimizer(name='TrainG', **g_opt_kwargs)
    D_opt = tflib.Optimizer(name='TrainD', **d_opt_kwargs)

    # One tower per GPU.
    for gpu_idx in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu_idx), tf.device('/gpu:%d' % gpu_idx):
            # Minibatch straight from the dataset iterator.
            with tf.name_scope('DataFetch'):
                reals_read, labels_read = training_set.get_minibatch_tf()
                reals_read = process_reals(reals_read, lod_in, mirror_augment,
                                           training_set.dynamic_range,
                                           drange_net)

            # GPU 0 uses the networks themselves, the others shadow clones.
            G_gpu = G if gpu_idx == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu_idx == 0 else D.clone(D.name + '_shadow')

            # Push the fed lod into whichever networks carry a 'lod' variable
            # before the loss is evaluated.
            lod_assign_ops = []
            for net in (G_gpu, D_gpu):
                if 'lod' in net.vars:
                    lod_assign_ops += [tf.assign(net.vars['lod'], lod_in)]
            with tf.control_dependencies(lod_assign_ops):
                with tf.name_scope('loss'):
                    G_loss, D_loss, D_reg = dnnlib.util.call_func_by_name(
                        G=G_gpu,
                        D=D_gpu,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        reals=reals_read,
                        real_labels=labels_read,
                        **loss_args)

            # Fold the regularizer into the D loss: always when not lazy,
            # otherwise only on steps where run_D_reg_in is fed True.
            if D_reg is not None:
                if lazy_regularization:
                    D_loss = tf.cond(run_D_reg_in,
                                     lambda: D_loss + D_reg * D_reg_interval,
                                     lambda: D_loss)
                else:
                    D_loss += D_reg
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)

    # Training ops; the Gs update is a control dependency of the G step.
    Gs_update_op = Gs.setup_as_moving_average_of(
        G, beta=Gs_beta * Gs_beta_mul_in)
    with tf.control_dependencies([Gs_update_op]):
        G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    # Peak-memory probe; falls back to a constant when the op is missing.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list, **metric_args)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('',
                                   cur_epoch=resume_kimg,
                                   max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0
    while cur_nimg < total_nimg:
        if dnnlib.RunContext.get().should_stop():
            break

        # Pick the schedule for the current position in training.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu)
        if reset_opt_for_new_lod:
            lod_bracket_changed = (
                np.floor(sched.lod) != np.floor(prev_lod)
                or np.ceil(sched.lod) != np.ceil(prev_lod))
            if lod_bracket_changed:
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Feed values shared by every step of this schedule.
        feeds = {
            lod_in: sched.lod,
            G_lrate_in: sched.G_lrate,
            D_lrate_in: sched.D_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu,
            Gs_beta_mul_in: 1 if cur_nimg >= ema_start_kimg * 1000 else 0,
        }
        for _repeat in range(minibatch_repeats):
            accum_rounds = range(0, sched.minibatch_size,
                                 sched.minibatch_gpu * num_gpus)
            run_D_reg = (lazy_regularization
                         and running_mb_counter % D_reg_interval == 0)
            feeds[run_D_reg_in] = run_D_reg
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # One G step followed by one D step per accumulation round.
            for _ in accum_rounds:
                tflib.run(G_train_op, feeds)
                tflib.run(D_train_op, feeds)

        # Once-per-tick maintenance.
        finished = (cur_nimg >= total_nimg)
        tick_due = (cur_tick < 0
                    or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000
                    or finished)
        if not tick_due:
            continue

        cur_tick += 1
        tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
        tick_start_nimg = cur_nimg
        tick_time = dnnlib.RunContext.get().get_time_since_last_update()
        total_time = (dnnlib.RunContext.get().get_time_since_start()
                      + resume_time)

        # Progress line; each value is also recorded as an autosummary.
        progress_values = (
            autosummary('Progress/tick', cur_tick),
            autosummary('Progress/kimg', cur_nimg / 1000.0),
            autosummary('Progress/lod', sched.lod),
            autosummary('Progress/minibatch', sched.minibatch_size),
            dnnlib.util.format_time(
                autosummary('Timing/total_sec', total_time)),
            autosummary('Timing/sec_per_tick', tick_time),
            autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
            autosummary('Timing/maintenance_sec', maintenance_time),
            autosummary('Resources/peak_gpu_mem_gb',
                        peak_gpu_mem_op.eval() / 2**30),
        )
        print(
            'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f'
            % progress_values)
        autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
        autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

        # Image grid from the averaged generator.
        if image_snapshot_ticks is not None and (
                cur_tick % image_snapshot_ticks == 0 or finished):
            grid_fakes = Gs.run(grid_latents, grid_labels,
                                is_validation=True,
                                minibatch_size=sched.minibatch_gpu)
            misc.save_image_grid(grid_fakes,
                                 dnnlib.make_run_dir_path(
                                     'fakes%06d.png' % (cur_nimg // 1000)),
                                 drange=drange_net,
                                 grid_size=grid_size)

        # Network pickle followed by metric evaluation on it.
        if network_snapshot_ticks is not None and (
                cur_tick % network_snapshot_ticks == 0 or finished):
            snapshot_path = dnnlib.make_run_dir_path(
                'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
            misc.save_pkl((G, D, Gs), snapshot_path)
            metrics.run(snapshot_path,
                        run_dir=dnnlib.make_run_dir_path(),
                        num_gpus=num_gpus,
                        tf_config=tf_config)

        # Flush summaries and refresh the run context.
        metrics.update_autosummaries()
        tflib.autosummary.save_summaries(summary_log, cur_nimg)
        dnnlib.RunContext.get().update('%.2f' % sched.lod,
                                       cur_epoch=cur_nimg // 1000,
                                       max_epoch=total_kimg)
        maintenance_time = (
            dnnlib.RunContext.get().get_last_update_interval() - tick_time)

    # Final snapshot.
    misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path('network-final.pkl'))

    # Release resources.
    summary_log.close()
    training_set.close()
# Script: load one trained generator snapshot into two differently-defined
# 'G_style' networks (a "cutoff" variant and the stock one), render the same
# latent with both, and set up a 3x3 matplotlib figure.
run_id = 15
snapshot = 15326
G_args = {}
synthesis_kwargs = dict(
    output_transform=dict(func=tflib.convert_images_to_uint8,
                          nchw_to_nhwc=True),
    minibatch_size=8)

tflib.init_tf()

# baseline model
# network_pkl = '../../results/00015-sgan-ffhq256-1gpu-baseline/network-snapshot-014526.pkl'
network_pkl = '../../results/00046-sgan-ffhq256-2gpu-adain-pixel-norm-continue/network-snapshot-012126.pkl'
# no noise model
# network_pkl = 'results/00022-sgan-ffhq256-2gpu/network-snapshot-005726.pkl'

# Only the averaged generator (Gs) from the pickle is used below.
_G, _D, Gs = misc.load_pkl(network_pkl)

# Same weights, two network definitions.
G = tflib.Network('G',
                  func_name='training.networks_stylegan_cutoff.G_style',
                  num_channels=3,
                  resolution=256,
                  label_size=0,
                  structure='linear',
                  **G_args)
G.copy_vars_from(Gs)
G_original = tflib.Network('G',
                           func_name='training.networks_stylegan.G_style',
                           num_channels=3,
                           resolution=256,
                           label_size=0,
                           structure='linear',
                           **G_args)
G_original.copy_vars_from(Gs)

# One latent vector per seed. np.stack needs a real sequence; passing a bare
# generator is deprecated in NumPy and rejected by newer releases.
latents = np.stack(
    [np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in [8]])
images = G.run(latents, None, use_instance_norm=False, **synthesis_kwargs)
images_original = G_original.run(latents, None,
                                 use_instance_norm=False,
                                 **synthesis_kwargs)
print(images.shape)

fig, axs = plt.subplots(3, 3)
im = images[0]
counter = 20
for i in range(3):
    # NOTE(review): the loop body is not present in the source; `pass` keeps
    # the script syntactically valid until the plotting code is restored.
    pass
def G_style(
        latents_in,  # First input: Latent vectors (Z) [minibatch, latent_size].
        labels_in,  # Second input: Conditioning labels [minibatch, label_size].
        truncation_psi=0.7,  # Style strength multiplier for the truncation trick. None = disable.
        truncation_cutoff=8,  # Number of layers for which to apply the truncation trick. None = disable.
        truncation_psi_val=None,  # Value for truncation_psi to use during validation.
        truncation_cutoff_val=None,  # Value for truncation_cutoff to use during validation.
        dlatent_avg_beta=0.995,  # Decay for tracking the moving average of W during training. None = disable.
        style_mixing_prob=0.9,  # Probability of mixing styles during training. None = disable.
        is_training=False,  # Network is under training? Enables and disables specific features.
        is_validation=False,  # Network is under validation? Chooses which value to use for truncation_psi.
        is_template_graph=False,  # True = template graph constructed by the Network class, False = actual evaluation.
        components=dnnlib.EasyDict(),  # Container for sub-networks. Retained between calls.
        **kwargs):  # Arguments for sub-networks (G_mapping and G_synthesis).
    """Build the generator graph: mapping network followed by synthesis network.

    The mapping output (dlatents) is optionally post-processed in three steps
    before synthesis: the moving average of the dlatents is updated, a second
    set of dlatents is mixed in per layer, and the dlatents are interpolated
    towards the moving average (truncation).  Each step is skipped when its
    controlling option resolves to None.

    Returns the synthesis output wrapped in an identity op named 'images_out'.
    """

    def _is_static(value):
        # A concrete Python value: neither disabled (None) nor a TF expression.
        return value is not None and not tflib.is_tf_expression(value)

    # Validate arguments.
    assert not (is_training and is_validation)
    assert isinstance(components, dnnlib.EasyDict)

    # Validation runs use their own truncation settings.
    if is_validation:
        truncation_psi, truncation_cutoff = truncation_psi_val, truncation_cutoff_val

    # Truncation is disabled while training, or when its value makes it a no-op.
    if is_training or (_is_static(truncation_psi) and truncation_psi == 1):
        truncation_psi = None
    if is_training or (_is_static(truncation_cutoff) and truncation_cutoff <= 0):
        truncation_cutoff = None

    # Moving-average tracking and style mixing only apply while training.
    if not is_training or (_is_static(dlatent_avg_beta) and dlatent_avg_beta == 1):
        dlatent_avg_beta = None
    if not is_training or (_is_static(style_mixing_prob) and style_mixing_prob <= 0):
        style_mixing_prob = None

    # Setup components.
    if 'synthesis' not in components:
        components.synthesis = tflib.Network(
            'G_synthesis', func_name=G_synthesis, **kwargs)
    synthesis_input_shape = components.synthesis.input_shape
    num_layers = synthesis_input_shape[1]
    dlatent_size = synthesis_input_shape[2]
    if 'mapping' not in components:
        components.mapping = tflib.Network(
            'G_mapping',
            func_name=G_mapping,
            dlatent_broadcast=num_layers,
            **kwargs)

    # Setup variables.
    lod_in = tf.get_variable(
        'lod', initializer=np.float32(0), trainable=False)
    dlatent_avg = tf.get_variable(
        'dlatent_avg',
        shape=[dlatent_size],
        initializer=tf.initializers.zeros(),
        trainable=False)

    # Evaluate mapping network.
    dlatents = components.mapping.get_output_for(
        latents_in, labels_in, **kwargs)

    # Update moving average of W.
    if dlatent_avg_beta is not None:
        with tf.variable_scope('DlatentAvg'):
            batch_mean = tf.reduce_mean(dlatents[:, 0], axis=0)
            blended = tflib.lerp(batch_mean, dlatent_avg, dlatent_avg_beta)
            avg_update = tf.assign(dlatent_avg, blended)
            # Route dlatents through the update so it runs with every evaluation.
            with tf.control_dependencies([avg_update]):
                dlatents = tf.identity(dlatents)

    # Perform style mixing regularization.
    if style_mixing_prob is not None:
        with tf.name_scope('StyleMix'):
            second_latents = tf.random_normal(tf.shape(latents_in))
            second_dlatents = components.mapping.get_output_for(
                second_latents, labels_in, **kwargs)
            layer_ids = np.arange(num_layers)[np.newaxis, :, np.newaxis]
            active_layers = num_layers - tf.cast(lod_in, tf.int32) * 2
            do_mix = tf.random_uniform([], 0.0, 1.0) < style_mixing_prob
            mixing_cutoff = tf.cond(
                do_mix,
                lambda: tf.random_uniform([], 1, active_layers, dtype=tf.int32),
                lambda: active_layers)
            # Layers below the cutoff keep the first dlatents, the rest use the second.
            keep_first = tf.broadcast_to(
                layer_ids < mixing_cutoff, tf.shape(dlatents))
            dlatents = tf.where(keep_first, dlatents, second_dlatents)

    # Apply truncation trick.
    if truncation_psi is not None and truncation_cutoff is not None:
        with tf.variable_scope('Truncation'):
            layer_ids = np.arange(num_layers)[np.newaxis, :, np.newaxis]
            unit_coefs = np.ones(layer_ids.shape, dtype=np.float32)
            coefs = tf.where(
                layer_ids < truncation_cutoff,
                truncation_psi * unit_coefs,
                unit_coefs)
            dlatents = tflib.lerp(dlatent_avg, dlatents, coefs)

    # Evaluate synthesis network, copying this network's lod into it first.
    lod_sync = tf.assign(components.synthesis.find_var('lod'), lod_in)
    with tf.control_dependencies([lod_sync]):
        images_out = components.synthesis.get_output_for(
            dlatents, force_clean_graph=is_template_graph, **kwargs)
    return tf.identity(images_out, name='images_out')
def D_stylegan2(
        images_in,  # First input: Images [minibatch, channel, height, width].
        labels_in,  # Second input: Labels [minibatch, label_size].
        num_channels=3,  # Number of input color channels. Overridden based on dataset.
        resolution=1024,  # Input resolution. Overridden based on dataset.
        label_size=0,  # Dimensionality of the labels, 0 if no labels. Overridden based on dataset.
        fmap_base=16 << 10,  # Overall multiplier for the number of feature maps.
        fmap_decay=1.0,  # log2 feature map reduction when doubling the resolution.
        fmap_min=1,  # Minimum number of feature maps in any layer.
        fmap_max=512,  # Maximum number of feature maps in any layer.
        architecture='resnet',  # Architecture: 'orig', 'skip', 'resnet'.
        nonlinearity='lrelu',  # Activation function: 'relu', 'lrelu', etc.
        mbstd_group_size=4,  # Group size for the minibatch standard deviation layer, 0 = disable.
        mbstd_num_features=1,  # Number of features for the minibatch standard deviation layer.
        dtype='float32',  # Data type to use for activations and outputs.
        resample_kernel=[1, 3, 3, 1],  # Low-pass filter to apply when resampling activations. None = no filtering.
        mapping_label_func='D_mapping_label',  # Name of the module-level build func for the label mapping network.
        components=dnnlib.EasyDict(),  # Container for sub-networks. Retained between calls.
        dlabel_size=128,  # Passed to the label mapping network as dlabel_size.
        **_kwargs):  # Ignore unrecognized keyword args.
    """Build a discriminator whose convolutions are modulated by a label embedding.

    The labels are passed through a `mapping_label` sub-network (created once
    and stored in `components`); its output `dlabel` is fed to every
    `modulated_conv2d_layer`.  Returns the score tensor, wrapped in an
    identity op named 'scores_out'.
    """
    resolution_log2 = int(np.log2(resolution))
    assert resolution == 2**resolution_log2 and resolution >= 4
    assert architecture in ['orig', 'skip', 'resnet']
    use_skip = architecture == 'skip'
    use_resnet = architecture == 'resnet'
    act = nonlinearity

    def nf(stage):
        # Feature-map count for a stage, clamped to [fmap_min, fmap_max].
        unclamped = int(fmap_base / (2.0**(stage * fmap_decay)))
        return np.clip(unclamped, fmap_min, fmap_max)

    # Fix input shapes and dtypes.
    images_in.set_shape([None, num_channels, resolution, resolution])
    labels_in.set_shape([None, label_size])
    images_in = tf.cast(images_in, dtype)
    labels_in = tf.cast(labels_in, dtype)

    # Label embedding used to modulate the convolutions.
    if 'mapping_label' not in components:
        components.mapping_label = tflib.Network(
            'D_mapping_label',
            func_name=globals()[mapping_label_func],
            label_size=label_size,
            dlabel_size=dlabel_size)
    dlabel = tf.cast(components.mapping_label.get_output_for(labels_in), dtype)

    # Building blocks for main layers.
    def mod_conv(inputs, fmaps, kernel, **conv_kwargs):
        # Label-modulated convolution followed by bias + activation.
        conv_out = modulated_conv2d_layer(
            inputs, dlabel, fmaps=fmaps, kernel=kernel, **conv_kwargs)
        return apply_bias_act(conv_out, act=act)

    def fromrgb(x, y, res):  # res = 2..resolution_log2
        with tf.variable_scope('FromRGB'):
            rgb_features = mod_conv(y, fmaps=nf(res - 1), kernel=1)
            if x is None:
                return rgb_features
            return x + rgb_features

    def block(x, res):  # res = 2..resolution_log2
        shortcut = x
        with tf.variable_scope('Conv0'):
            x = mod_conv(x, fmaps=nf(res - 1), kernel=3)
        with tf.variable_scope('Conv1_down'):
            x = mod_conv(x,
                         fmaps=nf(res - 2),
                         kernel=3,
                         down=True,
                         resample_kernel=resample_kernel)
        if use_resnet:
            with tf.variable_scope('Skip'):
                shortcut = conv2d_layer(shortcut,
                                        fmaps=nf(res - 2),
                                        kernel=1,
                                        down=True,
                                        resample_kernel=resample_kernel)
                x = (x + shortcut) * (1 / np.sqrt(2))
        return x

    def downsample(y):
        with tf.variable_scope('Downsample'):
            return downsample_2d(y, k=resample_kernel)

    # Main layers, from the full resolution down to 8x8.
    x = None
    y = images_in
    for res in range(resolution_log2, 2, -1):
        side = 2**res
        with tf.variable_scope('%dx%d' % (side, side)):
            if use_skip or res == resolution_log2:
                x = fromrgb(x, y, res)
            x = block(x, res)
            if use_skip:
                y = downsample(y)

    # Final layers.
    with tf.variable_scope('4x4'):
        if use_skip:
            x = fromrgb(x, y, 2)
        if mbstd_group_size > 1:
            with tf.variable_scope('MinibatchStddev'):
                x = minibatch_stddev_layer(x, mbstd_group_size,
                                           mbstd_num_features)
        with tf.variable_scope('Conv'):
            x = mod_conv(x, fmaps=nf(1), kernel=3)
        with tf.variable_scope('Dense0'):
            x = apply_bias_act(dense_layer(x, fmaps=nf(0)), act=act)

    # Output layer with label conditioning from "Which Training Methods for GANs do actually Converge?"
    with tf.variable_scope('Output'):
        label_dim = labels_in.shape[1]
        x = apply_bias_act(dense_layer(x, fmaps=max(label_dim, 1)))
        if label_dim > 0:
            # Keep only the score belonging to each sample's label.
            x = tf.reduce_sum(x * labels_in, axis=1, keepdims=True)
    scores_out = x

    # Output.
    assert scores_out.dtype == tf.as_dtype(dtype)
    return tf.identity(scores_out, name='scores_out')
def training_loop_refinement(
        G_args={},  # Options for generator network.
        D_args={},  # Options for discriminator network.
        G_opt_args={},  # Options for generator optimizer.
        D_opt_args={},  # Options for discriminator optimizer.
        G_loss_args={},  # Options for generator loss.
        D_loss_args={},  # Options for discriminator loss.
        dataset_args={},  # Options for dataset.load_dataset().
        sched_args={},  # Options for train.TrainingSchedule.
        grid_args={},  # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],  # Options for MetricGroup.
        tf_config={},  # Options for tflib.init_tf().
        data_dir=None,  # Directory to load datasets from.
        G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
        minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
        lazy_regularization=True,  # Perform regularization as a separate training step?
        G_reg_interval=4,  # How often the perform regularization for G? Ignored if lazy_regularization=False.
        D_reg_interval=16,  # How often the perform regularization for D? Ignored if lazy_regularization=False.
        reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg=25000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=False,  # Enable mirror augment?
        drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=50,  # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'.
        network_snapshot_ticks=50,  # How often to save network snapshots? None = only save 'networks-final.pkl'.
        save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=True,  # Include weight histograms in the tfevents file?
        resume_pkl=None,  # Network pickle to resume training from, None = train from scratch.
        resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
        resume_with_new_nets=False):  # Construct new networks according to G_args and D_args before resuming training?
    """Train the generator alone against the loss configured in `G_loss_args`.

    No discriminator is built: the loss function is called with `D=None`, and
    the dataset's "labels" tensor is handed to it as `latents`.  The
    D-related arguments (`D_args`, `D_opt_args`, `D_loss_args`,
    `D_reg_interval`) are accepted for signature compatibility but are not
    used.  Snapshots are pickled as `(G, None, Gs)`.

    Layers whose trainable name contains any entry of
    `G_args["freeze_layers"]` are removed from the trainables before
    gradients are registered.
    """
    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(
        data_dir=dnnlib.convert_path(data_dir), verbose=True, **dataset_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        training_set, **grid_args)
    misc.save_image_grid(grid_reals,
                         dnnlib.make_run_dir_path('reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)

    # Construct or load networks.
    with tf.device('/gpu:0'):
        if resume_pkl is None or resume_with_new_nets:
            print('Constructing networks...')
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **G_args)
            Gs = G.clone('Gs')
        if resume_pkl is not None:
            print('Loading networks from "%s"...' % resume_pkl)
            # Keep the loaded G alive: it is needed when resuming with the
            # pickled networks as-is.  Only the discriminator is discarded.
            rG, _rD, rGs = misc.load_pkl(resume_pkl)
            del _rD
            if resume_with_new_nets:
                # Both fresh networks start from the averaged generator.
                G.copy_vars_from(rGs)
                Gs.copy_vars_from(rGs)
                del rG, rGs
            else:
                G = rG
                Gs = rGs

    # Set constant noise input for both G and Gs.
    # `== False` (not `is False`) is kept so that a falsy 0 also matches.
    if G_args.get("randomize_noise", None) == False:  # noqa: E712
        for net in (G, Gs):
            noise_vars = [
                var for name, var in net.components.synthesis.vars.items()
                if name.startswith('noise')
            ]
            # Same seed for both networks so they receive identical noise.
            rnd = np.random.RandomState(123)
            tflib.set_vars({
                var: rnd.randn(*var.shape.as_list())
                for var in noise_vars
            })  # [height, width]

    # Print layers and generate initial image snapshot.
    G.print_layers()
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              **sched_args)
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch_gpu)
    misc.save_image_grid(grid_fakes,
                         dnnlib.make_run_dir_path('fakes_init.png'),
                         drange=drange_net,
                         grid_size=grid_size)

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32,
                                           name='minibatch_size_in',
                                           shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32,
                                          name='minibatch_gpu_in',
                                          shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in *
                                                     num_gpus)
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_size_in, tf.float32),
                              G_smoothing_kimg *
                              1000.0) if G_smoothing_kimg > 0.0 else 0.0

    # Setup optimizers (work on a copy so the caller's dict is untouched).
    G_opt_args = dict(G_opt_args)
    G_opt_args['minibatch_multiplier'] = minibatch_multiplier
    G_opt_args['learning_rate'] = lrate_in
    if lazy_regularization:
        mb_ratio = G_reg_interval / (G_reg_interval + 1)
        G_opt_args['learning_rate'] *= mb_ratio
        if 'beta1' in G_opt_args:
            G_opt_args['beta1'] **= mb_ratio
        if 'beta2' in G_opt_args:
            G_opt_args['beta2'] **= mb_ratio
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args)

    # Freeze layers.  Kept in a local so that a plain-dict G_args (including
    # the default) works and the caller's options are not modified.
    freeze_layers = list(G_args.get("freeze_layers", []))

    def freeze_vars(gen, verbose=True):
        # Drop every trainable whose name contains one of the frozen layer names.
        assert len(freeze_layers) > 0
        for name in list(gen.trainables.keys()):
            if any(layer in name for layer in freeze_layers):
                del gen.trainables[name]
                if verbose:
                    print(f"Freezed {name}")

    # Build training graph for each GPU.
    data_fetch_ops = []
    loss_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):

            # Create GPU-specific shadow copy of G.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            if freeze_layers:
                freeze_vars(G_gpu, verbose=False)

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                sched = training_schedule(cur_nimg=int(resume_kimg * 1000),
                                          training_set=training_set,
                                          **sched_args)
                reals_var = tf.Variable(
                    name='reals',
                    trainable=False,
                    initial_value=tf.zeros([sched.minibatch_gpu] +
                                           training_set.shape))
                labels_var = tf.Variable(
                    name='labels',
                    trainable=False,
                    initial_value=tf.zeros(
                        [sched.minibatch_gpu, training_set.label_size]))
                reals_write, labels_write = training_set.get_minibatch_tf()
                reals_write, labels_write = process_reals(
                    reals_write, labels_write, lod_in, mirror_augment,
                    training_set.dynamic_range, drange_net)
                reals_write = tf.concat(
                    [reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                labels_write = tf.concat(
                    [labels_write, labels_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_var, reals_write)]
                data_fetch_ops += [tf.assign(labels_var, labels_write)]
                reals_read = reals_var[:minibatch_gpu_in]
                labels_read = labels_var[:minibatch_gpu_in]

            # Evaluate loss functions.
            lod_assign_ops = []
            if 'lod' in G_gpu.vars:
                lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)]
            with tf.control_dependencies(lod_assign_ops):
                with tf.name_scope('G_loss'):
                    # The dataset "labels" are passed to the loss as latents.
                    G_loss, G_reg = dnnlib.util.call_func_by_name(
                        G=G_gpu,
                        D=None,
                        opt=G_opt,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        reals=reals_read,
                        latents=labels_read,
                        **G_loss_args)
                    loss_ops.append(G_loss)

            # Register gradients.
            if not lazy_regularization:
                if G_reg is not None:
                    G_loss += G_reg
            else:
                if G_reg is not None:
                    G_reg_opt.register_gradients(
                        tf.reduce_mean(G_reg * G_reg_interval),
                        G_gpu.trainables)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    loss_op = tf.reduce_mean(tf.concat(loss_ops, axis=0))
    G_train_op = G_opt.apply_updates()
    G_reg_op = G_reg_opt.apply_updates(allow_no_op=True)
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('',
                                   cur_epoch=resume_kimg,
                                   max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0
    loss_per_batch_sum = 0
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        feed_dict = {
            lod_in: sched.lod,
            lrate_in: sched.G_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu
        }
        tflib.run(data_fetch_op, feed_dict)
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size,
                           sched.minibatch_gpu * num_gpus)
            run_G_reg = (lazy_regularization
                         and running_mb_counter % G_reg_interval == 0)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                loss, _ = tflib.run([loss_op, G_train_op], feed_dict)
                # Fetch the data for the next step.
                tflib.run([data_fetch_op], feed_dict)
                loss_per_batch_sum += loss
                if run_G_reg:
                    tflib.run(G_reg_op, feed_dict)
                tflib.run([Gs_update_op], feed_dict)

            # Slow path with gradient accumulation. FIXME: Probably wrong
            else:
                for _round in rounds:
                    loss, _, _ = tflib.run(
                        [loss_op, G_train_op, data_fetch_op], feed_dict)
                    loss_per_batch_sum += loss / len(rounds)
                if run_G_reg:
                    tflib.run(G_reg_op, feed_dict)
                tflib.run(Gs_update_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start(
            ) + resume_time
            # Average of the accumulated per-minibatch losses over this tick.
            tick_loss = loss_per_batch_sum * sched.minibatch_size / (
                tick_kimg * 1000)
            loss_per_batch_sum = 0

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d loss/px %-12.8f time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch_size),
                   autosummary('Progress/loss_per_px', tick_loss),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days',
                        total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (
                    cur_tick % image_snapshot_ticks == 0 or done):
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch_gpu)
                misc.save_image_grid(grid_fakes,
                                     dnnlib.make_run_dir_path(
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' %
                                               (cur_nimg // 1000))
                misc.save_pkl((G, None, Gs), pkl)
                metrics.run(pkl,
                            run_dir=dnnlib.make_run_dir_path(),
                            data_dir=dnnlib.convert_path(data_dir),
                            num_gpus=num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % sched.lod,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get(
            ).get_last_update_interval() - tick_time

    # Save final snapshot.
    misc.save_pkl((G, None, Gs), dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
import cv2 G_args = {} synthesis_kwargs = dict(output_transform=dict( func=tflib.convert_images_to_uint8, nchw_to_nhwc=True), minibatch_size=8) tflib.init_tf() # baseline model baseline_network_pkl = '../results/00015-sgan-ffhq256-1gpu-baseline/network-snapshot-014526.pkl' _G, _D, Gs = misc.load_pkl(baseline_network_pkl) G_baseline = tflib.Network( 'G', func_name='training.networks_stylegan_cutoff.G_style', num_channels=3, resolution=256, label_size=0, structure='linear', **G_args) G_baseline.copy_vars_from(Gs) without_progan_network_pkl = '../results/00001-sgan-ffhq256-2gpu-remove-progan/network-snapshot-014800.pkl' _G, _D, Gs = misc.load_pkl(without_progan_network_pkl) G_without_noise = tflib.Network( 'G', func_name='training.networks_stylegan_cutoff.G_style', num_channels=3, resolution=256, label_size=0, structure='linear',
def load_from(name, cfg):
    """Copy the weights of a pickled TensorFlow generator into a new ``Model``.

    Args:
        name: Path of the pickle file; its third element is used as the
            source network.
        cfg: Configuration object whose ``MODEL`` section supplies the
            ``Model`` constructor arguments.

    Returns:
        ``(model, Gs_)``: the populated ``Model`` and a ``tflib.Network``
        built from ``G_style`` that carries the same variables.
    """
    dnnlib.tflib.init_tf()
    with open(name, 'rb') as pkl_file:
        networks = pickle.load(pkl_file)
    Gs = networks[2]

    Gs_ = tflib.Network(
        'G',
        func_name='stylegan.training.networks_stylegan.G_style',
        num_channels=3,
        resolution=1024)
    Gs_.copy_vars_from(Gs)

    model_cfg = cfg.MODEL
    model = Model(
        startf=model_cfg.START_CHANNEL_COUNT,
        layer_count=model_cfg.LAYER_COUNT,
        maxf=model_cfg.MAX_CHANNEL_COUNT,
        latent_size=model_cfg.LATENT_SPACE_SIZE,
        mapping_layers=model_cfg.MAPPING_LAYERS,
        truncation_psi=0.7,  # cfg.MODEL.TRUNCATIOM_PSI,
        truncation_cutoff=model_cfg.TRUNCATIOM_CUTOFF,
        channels=3)

    def tensor(x, transpose=None):
        # Evaluate the TF variable called `x`, optionally permuting its axes.
        values = Gs.vars[x].eval()
        if transpose:
            values = np.transpose(values, transpose)
        return torch.tensor(values)

    # Mapping network: dense layers.
    for idx in range(cfg.MODEL.MAPPING_LAYERS):
        fc = getattr(model.mapping, "block_%d" % (idx + 1)).fc
        fc.weight[:] = tensor('G_mapping/Dense%d/weight' % idx, (1, 0)) * fc.std
        fc.bias[:] = tensor('G_mapping/Dense%d/bias' % idx) * fc.lrmul

    model.dlatent_avg.buff[:] = tensor('dlatent_avg')
    model.generator.const[:] = tensor('G_synthesis/4x4/Const/const')

    # Synthesis network: one decode block plus one to-RGB layer per resolution.
    generator = model.generator
    for i in range(generator.layer_count):
        lod = generator.layer_count - i - 1
        side = 2**(2 + i)
        res_name = '%dx%d' % (side, side)
        block = generator.decode_block[i]

        prefix = 'G_synthesis/%s' % res_name
        if block.has_first_conv:
            prefix_1 = '%s/Conv0_up' % prefix
            prefix_2 = '%s/Conv1' % prefix
        else:
            prefix_1 = '%s/Const' % prefix
            prefix_2 = '%s/Conv' % prefix

        block.noise_weight_1[0, :, 0, 0] = tensor('%s/Noise/weight' % prefix_1)
        block.noise_weight_2[0, :, 0, 0] = tensor('%s/Noise/weight' % prefix_2)

        if block.has_first_conv:
            # The axis permutation depends on whether the block is fused-scale.
            conv_1_axes = (2, 3, 0, 1) if block.fused_scale else (3, 2, 0, 1)
            block.conv_1.weight[:] = tensor(
                '%s/weight' % prefix_1, conv_1_axes) * block.conv_1.std

        block.conv_2.weight[:] = tensor(
            '%s/weight' % prefix_2, (3, 2, 0, 1)) * block.conv_2.std
        block.bias_1[0, :, 0, 0] = tensor('%s/bias' % prefix_1)
        block.bias_2[0, :, 0, 0] = tensor('%s/bias' % prefix_2)

        style_1, style_2 = block.style_1, block.style_2
        style_1.weight[:] = tensor(
            '%s/StyleMod/weight' % prefix_1, (1, 0)) * style_1.std
        style_1.bias[:] = tensor('%s/StyleMod/bias' % prefix_1)
        style_2.weight[:] = tensor(
            '%s/StyleMod/weight' % prefix_2, (1, 0)) * style_2.std
        style_2.bias[:] = tensor('%s/StyleMod/bias' % prefix_2)

        rgb = generator.to_rgb[i].to_rgb
        rgb.weight[:] = tensor(
            'G_synthesis/ToRGB_lod%d/weight' % (lod), (3, 2, 0, 1)) * rgb.std
        rgb.bias[:] = tensor('G_synthesis/ToRGB_lod%d/bias' % (lod))

    return model, Gs_
def main():
    """Render an image sequence by interpolating between saved dlatents.

    Reads its options from the module-level namespace `a`: loads the
    generator pickle `a.model`, reads key dlatents from the file or directory
    `a.dlatents`, optionally overrides the higher layers with a style latent,
    interpolates `a.fstep` frames per key, applies truncation with `a.trunc`
    to the first 8 layers and writes one image per frame into `a.out_dir`.
    Exits the process when the model has no `resolution` kwarg or when no
    dlatents are found.
    """
    os.makedirs(a.out_dir, exist_ok=True)

    # setup generator
    fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.func_name = 'training.stylegan2_multi.G_main'
    Gs_kwargs.verbose = a.verbose
    Gs_kwargs.size = a.size
    Gs_kwargs.scale_type = a.scale_type
    Gs_kwargs.impl = a.ops

    # load model with arguments
    tflib.init_tf({'allow_soft_placement': True})
    pkl_name = osp.splitext(a.model)[0]
    # NOTE(security): unpickling executes arbitrary code; only load trusted models.
    with open(pkl_name + '.pkl', 'rb') as file:
        network = pickle.load(file, encoding='latin1')
    try:
        # Full snapshots are (G, D, Gs) triples; keep only Gs.
        _, _, network = network
    except (TypeError, ValueError):
        pass  # not a 3-item sequence: the pickle holds the bare network
    for k in list(network.static_kwargs.keys()):
        Gs_kwargs[k] = network.static_kwargs[k]

    # reload custom network, if needed
    if '.pkl' in a.model.lower():
        print(' .. Gs from pkl ..', basename(a.model))
        Gs = network
    else:  # reconstruct network
        print(' .. Gs custom ..', basename(a.model))
        Gs = tflib.Network('Gs', **Gs_kwargs)
        Gs.copy_vars_from(network)

    dz_dim = 512  # dlatent_size
    try:
        dl_dim = 2 * (int(np.floor(np.log2(Gs_kwargs.resolution))) - 1)
    except (AttributeError, KeyError):
        print(' Resave model, no resolution kwarg found!')
        exit(1)
    dlat_shape = (1, dl_dim, dz_dim)  # [1,18,512]

    # read saved latents
    if a.dlatents is not None and osp.isfile(a.dlatents):
        key_dlatents = load_latents(a.dlatents)
        if len(key_dlatents.shape) == 2:
            key_dlatents = np.expand_dims(key_dlatents, 0)
    elif a.dlatents is not None and osp.isdir(a.dlatents):
        key_dlatents = []
        for npy in file_list(a.dlatents, 'npy'):
            key_dlatent = load_latents(npy)
            if len(key_dlatent.shape) == 2:
                key_dlatent = np.expand_dims(key_dlatent, 0)
            key_dlatents.append(key_dlatent)
        key_dlatents = np.concatenate(key_dlatents)  # [frm,18,512]
    else:
        print(' No input dlatents found')
        exit()
    key_dlatents = key_dlatents[:, np.newaxis]  # [frm,1,18,512]
    print(' key dlatents', key_dlatents.shape)

    # replace higher layers with single (style) latent
    if a.style_npy_file is not None:
        print(' styling with latent', a.style_npy_file)
        style_dlatent = load_latents(a.style_npy_file)
        while len(style_dlatent.shape) < 4:
            style_dlatent = np.expand_dims(style_dlatent, 0)
        # try replacing 5 by other value, less than dl_dim
        style_layers = range(5, dl_dim)
        key_dlatents[:, :, style_layers, :] = style_dlatent[:, :, style_layers, :]

    frames = key_dlatents.shape[0] * a.fstep

    dlatents = latent_anima(dlat_shape,
                            frames,
                            a.fstep,
                            key_latents=key_dlatents,
                            cubic=a.cubic,
                            verbose=True)  # [frm,1,18,512]
    print(' dlatents', dlatents.shape)
    frame_count = dlatents.shape[0]

    # truncation trick
    dlatent_avg = Gs.get_var('dlatent_avg')  # (512,)
    tr_range = range(0, 8)
    dlatents[:, :, tr_range, :] = dlatent_avg + (
        dlatents[:, :, tr_range, :] - dlatent_avg) * a.trunc

    # distort image by tweaking initial const layer
    if a.digress > 0:
        latent_size = Gs.static_kwargs.get('latent_size', 512)  # default latent size
        init_res = Gs.static_kwargs.get('init_res', (4, 4))  # default initial layer size
        dconst = a.digress * latent_anima([1, latent_size, *init_res],
                                          frames,
                                          a.fstep,
                                          cubic=True,
                                          verbose=False)
    else:
        dconst = np.zeros([frame_count, 1, 1, 1, 1])

    # generate images from latent timeline
    pbar = ProgressBar(frame_count)
    for i in range(frame_count):
        # generate multi-latent result
        if Gs.num_inputs == 2:
            output = Gs.components.synthesis.run(dlatents[i],
                                                 randomize_noise=False,
                                                 output_transform=fmt,
                                                 minibatch_size=1)
        else:
            output = Gs.components.synthesis.run(dlatents[i], [None],
                                                 dconst[i],
                                                 randomize_noise=False,
                                                 output_transform=fmt,
                                                 minibatch_size=1)

        # 4-channel output is written as png, anything else as jpg.
        ext = 'png' if output.shape[3] == 4 else 'jpg'
        filename = osp.join(a.out_dir, "%06d.%s" % (i, ext))
        imsave(filename, output[0])
        pbar.upd()
def training_loop_vc(
        G_args={},  # Options for generator network.
        D_args={},  # Options for discriminator network.
        I_args={},  # Options for the 'I' (infogan-head/vcgan-head) network.
        I_info_args={},  # Options for the 'I_info' network (built only with use_vc_head_with_cls).
        G_opt_args={},  # Options for generator optimizer.
        D_opt_args={},  # Options for discriminator optimizer.
        G_loss_args={},  # Options for generator loss.
        D_loss_args={},  # Options for discriminator loss.
        dataset_args={},  # Options for dataset.load_dataset().
        sched_args={},  # Options for train.TrainingSchedule.
        grid_args={},  # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],  # Options for MetricGroup.
        tf_config={},  # Options for tflib.init_tf().
        use_info_gan=False,  # Whether to use info-gan.
        use_vc_head=False,  # Whether to use vc-head.
        use_vc_head_with_cls=False,  # Whether to use classification in discriminator.
        data_dir=None,  # Directory to load datasets from.
        G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
        minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
        lazy_regularization=True,  # Perform regularization as a separate training step?
        G_reg_interval=4,  # How often to perform regularization for G? Ignored if lazy_regularization=False.
        D_reg_interval=16,  # How often to perform regularization for D? Ignored if lazy_regularization=False.
        reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg=25000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=False,  # Enable mirror augment?
        drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=50,  # How often to save image snapshots? None = only save 'reals.png' and 'fakes_init.png'.
        network_snapshot_ticks=50,  # How often to save network snapshots? None = only save 'network-final.pkl'.
        save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_pkl=None,  # Network pickle to resume training from, None = train from scratch.
        resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
        resume_with_new_nets=False,  # Construct new networks according to G_args and D_args before resuming training?
        traversal_grid=False,  # Used for disentangled representation learning.
        n_discrete=3,  # Number of discrete latents in model.
        n_continuous=4,  # Number of continuous latents in model.
        n_samples_per=10):  # Number of samples for each line in traversal.
    """Train G/D, optionally with an extra 'I' head and an 'I_info' head.

    Overview of what the body does, in order:

    * initializes TensorFlow, loads the training set and saves 'reals.png';
    * constructs fresh networks and/or loads them from ``resume_pkl``
      (G, D, Gs, plus I when any of the three ``use_*`` flags is set and
      I_info when ``use_vc_head_with_cls`` is set);
    * saves an initial fake grid to 'fakes_init.png';
    * builds the per-GPU training graph, registering gradients with the
      G/D optimizers (and the separate regularization optimizers when
      ``lazy_regularization`` is on);
    * runs the training loop until ``cur_nimg`` reaches
      ``total_kimg * 1000`` or the RunContext requests a stop, writing
      progress, image snapshots, network snapshots and metrics once per tick;
    * saves 'network-final.pkl' and closes the summary writer and dataset.

    The pickle layout depends on the flags: ``(G, D, I, Gs)`` for
    ``use_info_gan``/``use_vc_head``, ``(G, D, I, I_info, Gs)`` for
    ``use_vc_head_with_cls``, otherwise ``(G, D, Gs)``.

    NOTE(review): in the load, loss, gradient and save branches
    ``use_info_gan or use_vc_head`` is tested before
    ``use_vc_head_with_cls``; the flags look intended to be mutually
    exclusive -- confirm with callers.

    Returns nothing.
    """
    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir),
                                        verbose=True,
                                        **dataset_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        training_set, **grid_args)
    misc.save_image_grid(grid_reals,
                         dnnlib.make_run_dir_path('reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)

    # Construct or load networks.
    with tf.device('/gpu:0'):
        if resume_pkl is None or resume_with_new_nets:
            print('Constructing networks...')
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **G_args)
            D = tflib.Network('D',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **D_args)
            if use_info_gan or use_vc_head or use_vc_head_with_cls:
                I = tflib.Network('I',
                                  num_channels=training_set.shape[0],
                                  resolution=training_set.shape[1],
                                  label_size=training_set.label_size,
                                  **I_args)
            if use_vc_head_with_cls:
                I_info = tflib.Network('I_info',
                                       num_channels=training_set.shape[0],
                                       resolution=training_set.shape[1],
                                       label_size=training_set.label_size,
                                       **I_info_args)
            Gs = G.clone('Gs')
        if resume_pkl is not None:
            print('Loading networks from "%s"...' % resume_pkl)
            # The number of networks stored in the pickle depends on the flags.
            if use_info_gan or use_vc_head:
                rG, rD, rI, rGs = misc.load_pkl(resume_pkl)
            elif use_vc_head_with_cls:
                rG, rD, rI, rI_info, rGs = misc.load_pkl(resume_pkl)
            else:
                rG, rD, rGs = misc.load_pkl(resume_pkl)
            if resume_with_new_nets:
                # Copy the loaded weights into the freshly constructed networks.
                G.copy_vars_from(rG)
                D.copy_vars_from(rD)
                if use_info_gan or use_vc_head or use_vc_head_with_cls:
                    I.copy_vars_from(rI)
                if use_vc_head_with_cls:
                    I_info.copy_vars_from(rI_info)
                Gs.copy_vars_from(rGs)
            else:
                # Use the loaded networks as-is.
                G = rG
                D = rD
                if use_info_gan or use_vc_head or use_vc_head_with_cls:
                    I = rI
                if use_vc_head_with_cls:
                    I_info = rI_info
                Gs = rGs

    # Print layers and generate initial image snapshot.
    G.print_layers()
    D.print_layers()
    if use_info_gan or use_vc_head or use_vc_head_with_cls:
        I.print_layers()
    if use_vc_head_with_cls:
        I_info.print_layers()
    # Schedule evaluated at the end of training; only sched.minibatch_gpu is
    # used from it below, for the initial Gs.run().
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              **sched_args)
    if traversal_grid:
        # Replaces grid_size and grid_labels from the snapshot grid above.
        grid_size, grid_latents, grid_labels = get_grid_latents(
            n_discrete, n_continuous, n_samples_per, G, grid_labels)
    else:
        grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    print('grid_latents.shape:', grid_latents.shape)
    print('grid_labels.shape:', grid_labels.shape)
    # Gs.run() is unpacked as two outputs here; only the first is used.
    grid_fakes, _ = Gs.run(grid_latents,
                           grid_labels,
                           is_validation=True,
                           minibatch_size=sched.minibatch_gpu,
                           randomize_noise=False)
    misc.save_image_grid(grid_fakes,
                         dnnlib.make_run_dir_path('fakes_init.png'),
                         drange=drange_net,
                         grid_size=grid_size)

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32,
                                           name='minibatch_size_in',
                                           shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32,
                                          name='minibatch_gpu_in',
                                          shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in * num_gpus)
        # Moving-average decay for Gs; 0.0 when smoothing is disabled.
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_size_in, tf.float32),
                              G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0

    # Setup optimizers.
    # Copy the option dicts so the caller's (and the default) dicts are not mutated.
    G_opt_args = dict(G_opt_args)
    D_opt_args = dict(D_opt_args)
    for args, reg_interval in [(G_opt_args, G_reg_interval),
                               (D_opt_args, D_reg_interval)]:
        args['minibatch_multiplier'] = minibatch_multiplier
        args['learning_rate'] = lrate_in
        if lazy_regularization:
            # Rescale learning rate and Adam betas by reg_interval / (reg_interval + 1).
            mb_ratio = reg_interval / (reg_interval + 1)
            args['learning_rate'] *= mb_ratio
            if 'beta1' in args:
                args['beta1'] **= mb_ratio
            if 'beta2' in args:
                args['beta2'] **= mb_ratio
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', **D_opt_args)
    G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args)
    D_reg_opt = tflib.Optimizer(name='RegD', share=D_opt, **D_opt_args)

    # Build training graph for each GPU.
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):

            # Create GPU-specific shadow copies of the networks.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            if use_info_gan or use_vc_head or use_vc_head_with_cls:
                I_gpu = I if gpu == 0 else I.clone(I.name + '_shadow')
            if use_vc_head_with_cls:
                I_info_gpu = I_info if gpu == 0 else I_info.clone(
                    I_info.name + '_shadow')

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                # Schedule at the resume point fixes the size of the staging variables.
                sched = training_schedule(cur_nimg=int(resume_kimg * 1000),
                                          training_set=training_set,
                                          **sched_args)
                reals_var = tf.Variable(
                    name='reals',
                    trainable=False,
                    initial_value=tf.zeros([sched.minibatch_gpu] +
                                           training_set.shape))
                labels_var = tf.Variable(name='labels',
                                         trainable=False,
                                         initial_value=tf.zeros([
                                             sched.minibatch_gpu,
                                             training_set.label_size
                                         ]))
                reals_write, labels_write = training_set.get_minibatch_tf()
                reals_write, labels_write = process_reals(
                    reals_write, labels_write, lod_in, mirror_augment,
                    training_set.dynamic_range, drange_net)
                # Pad with the tail of the existing variable so the assigned
                # value always has the variable's full first dimension.
                reals_write = tf.concat(
                    [reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                labels_write = tf.concat(
                    [labels_write, labels_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_var, reals_write)]
                data_fetch_ops += [tf.assign(labels_var, labels_write)]
                reals_read = reals_var[:minibatch_gpu_in]
                labels_read = labels_var[:minibatch_gpu_in]

            # Evaluate loss functions.
            lod_assign_ops = []
            if 'lod' in G_gpu.vars:
                lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)]
            if 'lod' in D_gpu.vars:
                lod_assign_ops += [tf.assign(D_gpu.vars['lod'], lod_in)]
            with tf.control_dependencies(lod_assign_ops):
                with tf.name_scope('G_loss'):
                    if use_info_gan or use_vc_head:
                        # Fourth returned value is ignored in this mode.
                        G_loss, G_reg, I_loss, _ = dnnlib.util.call_func_by_name(
                            G=G_gpu,
                            D=D_gpu,
                            I=I_gpu,
                            opt=G_opt,
                            training_set=training_set,
                            minibatch_size=minibatch_gpu_in,
                            **G_loss_args)
                    elif use_vc_head_with_cls:
                        G_loss, G_reg, I_loss, I_info_loss = dnnlib.util.call_func_by_name(
                            G=G_gpu,
                            D=D_gpu,
                            I=I_gpu,
                            I_info=I_info_gpu,
                            opt=G_opt,
                            training_set=training_set,
                            minibatch_size=minibatch_gpu_in,
                            **G_loss_args)
                    else:
                        G_loss, G_reg = dnnlib.util.call_func_by_name(
                            G=G_gpu,
                            D=D_gpu,
                            opt=G_opt,
                            training_set=training_set,
                            minibatch_size=minibatch_gpu_in,
                            **G_loss_args)
                with tf.name_scope('D_loss'):
                    D_loss, D_reg = dnnlib.util.call_func_by_name(
                        G=G_gpu,
                        D=D_gpu,
                        opt=D_opt,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        reals=reals_read,
                        labels=labels_read,
                        **D_loss_args)

            # Register gradients.
            if not lazy_regularization:
                # Fold the regularization terms into the main losses.
                if G_reg is not None:
                    G_loss += G_reg
                if D_reg is not None:
                    D_loss += D_reg
            else:
                # Regularization runs as its own step, scaled by its interval.
                if G_reg is not None:
                    G_reg_opt.register_gradients(
                        tf.reduce_mean(G_reg * G_reg_interval),
                        G_gpu.trainables)
                if D_reg is not None:
                    D_reg_opt.register_gradients(
                        tf.reduce_mean(D_reg * D_reg_interval),
                        D_gpu.trainables)

            if use_info_gan or use_vc_head:
                # G and I are trained jointly by the generator optimizer.
                GI_gpu_trainables = collections.OrderedDict(
                    list(G_gpu.trainables.items()) +
                    list(I_gpu.trainables.items()))
                G_opt.register_gradients(tf.reduce_mean(G_loss + I_loss),
                                         GI_gpu_trainables)
                D_opt.register_gradients(tf.reduce_mean(D_loss),
                                         D_gpu.trainables)
            elif use_vc_head_with_cls:
                # G, I and I_info are trained jointly by the generator optimizer.
                GIIinfo_gpu_trainables = collections.OrderedDict(
                    list(G_gpu.trainables.items()) +
                    list(I_gpu.trainables.items()) +
                    list(I_info_gpu.trainables.items()))
                G_opt.register_gradients(
                    tf.reduce_mean(G_loss + I_loss + I_info_loss),
                    GIIinfo_gpu_trainables)
                D_opt.register_gradients(tf.reduce_mean(D_loss),
                                         D_gpu.trainables)
            else:
                G_opt.register_gradients(tf.reduce_mean(G_loss),
                                         G_gpu.trainables)
                D_opt.register_gradients(tf.reduce_mean(D_loss),
                                         D_gpu.trainables)

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()
    G_reg_op = G_reg_opt.apply_updates(allow_no_op=True)
    D_reg_op = D_reg_opt.apply_updates(allow_no_op=True)
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            # Fall back to a constant so the progress report still works.
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
        if use_info_gan or use_vc_head or use_vc_head_with_cls:
            I.setup_weight_histograms()
        if use_vc_head_with_cls:
            I_info.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('',
                                   cur_epoch=resume_kimg,
                                   max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1  # negative so the first loop iteration always reports a tick
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu, sched.lod)
        if reset_opt_for_new_lod:
            # Reset when the lod crosses an integer boundary.
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        # NOTE(review): lrate_in is shared by the G and D optimizers and is fed
        # sched.G_lrate here, so D also trains with the G learning rate.
        feed_dict = {
            lod_in: sched.lod,
            lrate_in: sched.G_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu
        }
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size,
                           sched.minibatch_gpu * num_gpus)
            run_G_reg = (lazy_regularization
                         and running_mb_counter % G_reg_interval == 0)
            run_D_reg = (lazy_regularization
                         and running_mb_counter % D_reg_interval == 0)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                tflib.run([G_train_op, data_fetch_op], feed_dict)
                if run_G_reg:
                    tflib.run(G_reg_op, feed_dict)
                tflib.run([D_train_op, Gs_update_op], feed_dict)
                if run_D_reg:
                    tflib.run(D_reg_op, feed_dict)

            # Slow path with gradient accumulation.
            else:
                for _round in rounds:
                    tflib.run(G_train_op, feed_dict)
                if run_G_reg:
                    for _round in rounds:
                        tflib.run(G_reg_op, feed_dict)
                tflib.run(Gs_update_op, feed_dict)
                for _round in rounds:
                    tflib.run(data_fetch_op, feed_dict)
                    tflib.run(D_train_op, feed_dict)
                if run_D_reg:
                    for _round in rounds:
                        tflib.run(D_reg_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start(
            ) + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch_size),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days',
                        total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (
                    cur_tick % image_snapshot_ticks == 0 or done):
                grid_fakes, _ = Gs.run(grid_latents,
                                       grid_labels,
                                       is_validation=True,
                                       minibatch_size=sched.minibatch_gpu,
                                       randomize_noise=False)
                misc.save_image_grid(grid_fakes,
                                     dnnlib.make_run_dir_path(
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' %
                                               (cur_nimg // 1000))
                # Tuple layout mirrors the one expected when resuming above.
                if use_info_gan or use_vc_head:
                    misc.save_pkl((G, D, I, Gs), pkl)
                elif use_vc_head_with_cls:
                    misc.save_pkl((G, D, I, I_info, Gs), pkl)
                else:
                    misc.save_pkl((G, D, Gs), pkl)
                # Metrics are evaluated on the snapshot that was just written.
                metrics.run(pkl,
                            run_dir=dnnlib.make_run_dir_path(),
                            data_dir=dnnlib.convert_path(data_dir),
                            num_gpus=num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % sched.lod,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get(
            ).get_last_update_interval() - tick_time

    # Save final snapshot.
    if use_info_gan or use_vc_head:
        misc.save_pkl((G, D, I, Gs),
                      dnnlib.make_run_dir_path('network-final.pkl'))
    elif use_vc_head_with_cls:
        misc.save_pkl((G, D, I, I_info, Gs),
                      dnnlib.make_run_dir_path('network-final.pkl'))
    else:
        misc.save_pkl((G, D, Gs),
                      dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
def training_loop(
    run_dir='.',  # Output directory.
    G_args={},  # Options for generator network.
    D_args={},  # Options for discriminator network.
    G_opt_args={},  # Options for generator optimizer.
    D_opt_args={},  # Options for discriminator optimizer.
    loss_args={},  # Options for loss function.
    train_dataset_args={},  # Options for dataset to train with.
    # Options for dataset to evaluate metrics against.
    metric_dataset_args={},
    augment_args={},  # Options for adaptive augmentations.
    metric_arg_list=[],  # Metrics to evaluate during training.
    num_gpus=1,  # Number of GPUs to use.
    minibatch_size=32,  # Global minibatch size.
    minibatch_gpu=4,  # Number of samples processed at a time by one GPU.
    # Half-life of the exponential moving average (EMA) of generator weights.
    G_smoothing_kimg=10,
    G_smoothing_rampup=None,  # EMA ramp-up coefficient.
    # Number of minibatches to run in the inner loop.
    minibatch_repeats=4,
    lazy_regularization=True,  # Perform regularization as a separate training step?
    # How often to perform regularization for G? Ignored if lazy_regularization=False.
    G_reg_interval=4,
    # How often to perform regularization for D? Ignored if lazy_regularization=False.
    D_reg_interval=16,
    # Total length of the training, measured in thousands of real images.
    total_kimg=25000,
    kimg_per_tick=4,  # Progress snapshot interval.
    # How often to save image snapshots? None = no 'fakes*.png' snapshots.
    image_snapshot_ticks=50,
    # How often to save network snapshots? None = no network snapshots.
    network_snapshot_ticks=50,
    resume_pkl=None,  # Network pickle to resume training from.
    # Callback function for determining whether to abort training.
    abort_fn=None,
    progress_fn=None,  # Callback function for updating training progress.
):
    """Train G and D, optionally with an augmentation object, writing to run_dir.

    Overview of what the body does, in order:

    * loads the training set and constructs G, D and Gs on GPU 0, copying
      weights from ``resume_pkl`` (a pickled ``(G, D, Gs)`` triple) if given;
    * saves 'reals.png' and replicates G and D onto the remaining GPUs;
    * constructs the augmentation object when ``augment_args`` has a
      non-None 'class_name';
    * builds optimizers and the per-GPU training graph;
    * runs the training loop until ``total_kimg`` thousand images have been
      processed or ``abort_fn()`` returns true, reporting progress, saving
      image/network snapshots and running metrics once per tick;
    * closes the summary writer and the training set.

    ``progress_fn``, when given, is called as ``progress_fn(cur_kimg,
    total_kimg)``. The image counter always starts at 0, also when
    resuming from ``resume_pkl``.

    Raises AssertionError if ``minibatch_size`` is not a multiple of
    ``num_gpus * minibatch_gpu``. Returns nothing.
    """
    assert minibatch_size % (num_gpus * minibatch_gpu) == 0
    start_time = time.time()

    print('Loading training set...')
    training_set = dataset.load_dataset(**train_dataset_args)
    print('Image shape:', np.int32(training_set.shape).tolist())
    print('Label shape:', [training_set.label_size])
    print()

    print('Constructing networks...')
    with tf.device('/gpu:0'):
        G = tflib.Network('G',
                          num_channels=training_set.shape[0],
                          resolution=training_set.shape[1],
                          label_size=training_set.label_size,
                          **G_args)
        D = tflib.Network('D',
                          num_channels=training_set.shape[0],
                          resolution=training_set.shape[1],
                          label_size=training_set.label_size,
                          **D_args)
        Gs = G.clone('Gs')
        if resume_pkl is not None:
            print(f'Resuming from "{resume_pkl}"')
            # NOTE(review): pickle.load executes arbitrary code; only resume
            # from trusted pickles.
            with dnnlib.util.open_url(resume_pkl) as f:
                rG, rD, rGs = pickle.load(f)
            G.copy_vars_from(rG)
            D.copy_vars_from(rD)
            Gs.copy_vars_from(rGs)
    G.print_layers()
    D.print_layers()

    print('Exporting sample images...')
    grid_size, grid_reals, grid_labels = setup_snapshot_image_grid(
        training_set)
    save_image_grid(grid_reals,
                    os.path.join(run_dir, 'reals.png'),
                    drange=[0, 255],
                    grid_size=grid_size)
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    # The initial fakes are generated, but saving them is currently disabled.
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=minibatch_gpu)
    # save_image_grid(grid_fakes, os.path.join(
    #     run_dir, 'fakes_init.png'), drange=[-1, 1], grid_size=grid_size)

    print(f'Replicating networks across {num_gpus} GPUs...')
    G_gpus = [G]
    D_gpus = [D]
    for gpu in range(1, num_gpus):
        with tf.device(f'/gpu:{gpu}'):
            G_gpus.append(G.clone(f'{G.name}_gpu{gpu}'))
            D_gpus.append(D.clone(f'{D.name}_gpu{gpu}'))

    print('Initializing augmentations...')
    aug = None
    if augment_args.get('class_name', None) is not None:
        aug = dnnlib.util.construct_class_by_name(**augment_args)
        aug.init_validation_set(D_gpus=D_gpus, training_set=training_set)

    print('Setting up optimizers...')
    # Copy the option dicts so the caller's (and the default) dicts are not mutated.
    G_opt_args = dict(G_opt_args)
    D_opt_args = dict(D_opt_args)
    for args, reg_interval in [(G_opt_args, G_reg_interval),
                               (D_opt_args, D_reg_interval)]:
        args[
            'minibatch_multiplier'] = minibatch_size // num_gpus // minibatch_gpu
        if lazy_regularization:
            # Rescale learning rate and Adam betas by reg_interval / (reg_interval + 1).
            # Requires 'learning_rate' to be present in the optimizer options.
            mb_ratio = reg_interval / (reg_interval + 1)
            args['learning_rate'] *= mb_ratio
            if 'beta1' in args:
                args['beta1'] **= mb_ratio
            if 'beta2' in args:
                args['beta2'] **= mb_ratio
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', **D_opt_args)
    G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args)
    D_reg_opt = tflib.Optimizer(name='RegD', share=D_opt, **D_opt_args)

    print('Constructing training graph...')
    data_fetch_ops = []
    training_set.configure(minibatch_gpu)
    for gpu, (G_gpu, D_gpu) in enumerate(zip(G_gpus, D_gpus)):
        with tf.name_scope(f'Train_gpu{gpu}'), tf.device(f'/gpu:{gpu}'):

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                real_images_var = tf.Variable(
                    name='images',
                    trainable=False,
                    initial_value=tf.zeros([minibatch_gpu] +
                                           training_set.shape))
                real_labels_var = tf.Variable(name='labels',
                                              trainable=False,
                                              initial_value=tf.zeros([
                                                  minibatch_gpu,
                                                  training_set.label_size
                                              ]))
                real_images_write, real_labels_write = training_set.get_minibatch_tf(
                )
                real_images_write = tflib.convert_images_from_uint8(
                    real_images_write)
                data_fetch_ops += [
                    tf.assign(real_images_var, real_images_write)
                ]
                data_fetch_ops += [
                    tf.assign(real_labels_var, real_labels_write)
                ]

            # Evaluate loss function and register gradients.
            fake_labels = training_set.get_random_labels_tf(minibatch_gpu)
            terms = dnnlib.util.call_func_by_name(G=G_gpu,
                                                  D=D_gpu,
                                                  aug=aug,
                                                  fake_labels=fake_labels,
                                                  real_images=real_images_var,
                                                  real_labels=real_labels_var,
                                                  **loss_args)
            if lazy_regularization:
                # Regularization runs as its own step, scaled by its interval.
                if terms.G_reg is not None:
                    G_reg_opt.register_gradients(
                        tf.reduce_mean(terms.G_reg * G_reg_interval),
                        G_gpu.trainables)
                if terms.D_reg is not None:
                    D_reg_opt.register_gradients(
                        tf.reduce_mean(terms.D_reg * D_reg_interval),
                        D_gpu.trainables)
            else:
                # Fold the regularization terms into the main losses.
                if terms.G_reg is not None:
                    terms.G_loss += terms.G_reg
                if terms.D_reg is not None:
                    terms.D_loss += terms.D_reg
            G_opt.register_gradients(tf.reduce_mean(terms.G_loss),
                                     G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(terms.D_loss),
                                     D_gpu.trainables)

    print('Finalizing training ops...')
    data_fetch_op = tf.group(*data_fetch_ops)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()
    G_reg_op = G_reg_opt.apply_updates(allow_no_op=True)
    D_reg_op = D_reg_opt.apply_updates(allow_no_op=True)
    # The EMA decay is fed per step because it can change with cur_nimg (ramp-up).
    Gs_beta_in = tf.placeholder(tf.float32, name='Gs_beta_in', shape=[])
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta_in)
    tflib.init_uninitialized_vars()
    with tf.device('/gpu:0'):
        peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()

    print('Initializing metrics...')
    summary_log = tf.summary.FileWriter(run_dir)
    metrics = []
    for args in metric_arg_list:
        metric = dnnlib.util.construct_class_by_name(**args)
        metric.configure(dataset_args=metric_dataset_args, run_dir=run_dir)
        metrics.append(metric)

    print(f'Training for {total_kimg} kimg...')
    print()
    if progress_fn is not None:
        progress_fn(0, total_kimg)
    tick_start_time = time.time()
    maintenance_time = tick_start_time - start_time
    cur_nimg = 0
    cur_tick = -1  # negative so the first loop iteration always reports a tick
    tick_start_nimg = cur_nimg
    running_mb_counter = 0

    done = False
    while not done:

        # Compute EMA decay parameter.
        Gs_nimg = G_smoothing_kimg * 1000.0
        if G_smoothing_rampup is not None:
            Gs_nimg = min(Gs_nimg, cur_nimg * G_smoothing_rampup)
        # max(..., 1e-8) avoids division by zero when Gs_nimg is 0.
        Gs_beta = 0.5**(minibatch_size / max(Gs_nimg, 1e-8))

        # Run training ops.
        for _repeat_idx in range(minibatch_repeats):
            rounds = range(0, minibatch_size, minibatch_gpu * num_gpus)
            run_G_reg = (lazy_regularization
                         and running_mb_counter % G_reg_interval == 0)
            run_D_reg = (lazy_regularization
                         and running_mb_counter % D_reg_interval == 0)
            cur_nimg += minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                tflib.run([G_train_op, data_fetch_op])
                if run_G_reg:
                    tflib.run(G_reg_op)
                tflib.run([D_train_op, Gs_update_op], {Gs_beta_in: Gs_beta})
                if run_D_reg:
                    tflib.run(D_reg_op)

            # Slow path with gradient accumulation.
            else:
                for _round in rounds:
                    tflib.run(G_train_op)
                    if run_G_reg:
                        tflib.run(G_reg_op)
                tflib.run(Gs_update_op, {Gs_beta_in: Gs_beta})
                for _round in rounds:
                    tflib.run(data_fetch_op)
                    tflib.run(D_train_op)
                    if run_D_reg:
                        tflib.run(D_reg_op)

            # Run validation.
            if aug is not None:
                aug.run_validation(minibatch_size=minibatch_size)

        # Tune augmentation parameters.
        if aug is not None:
            aug.tune(minibatch_size * minibatch_repeats)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000) or (abort_fn is not None
                                                   and abort_fn())
        if done or cur_tick < 0 or cur_nimg >= tick_start_nimg + kimg_per_tick * 1000:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_end_time = time.time()
            total_time = tick_end_time - start_time
            tick_time = tick_end_time - tick_start_time

            # Report progress.
            print(' '.join([
                f"tick {autosummary('Progress/tick', cur_tick):<5d}",
                f"kimg {autosummary('Progress/kimg', cur_nimg / 1000.0):<8.1f}",
                f"time {dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)):<12s}",
                f"sec/tick {autosummary('Timing/sec_per_tick', tick_time):<7.1f}",
                f"sec/kimg {autosummary('Timing/sec_per_kimg', tick_time / tick_kimg):<7.2f}",
                f"maintenance {autosummary('Timing/maintenance_sec', maintenance_time):<6.1f}",
                f"gpumem {autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30):<5.1f}",
                f"augment {autosummary('Progress/augment', aug.strength if aug is not None else 0):.3f}",
            ]))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days',
                        total_time / (24.0 * 60.0 * 60.0))
            if progress_fn is not None:
                progress_fn(cur_nimg // 1000, total_kimg)

            # Save snapshots.
            if image_snapshot_ticks is not None and (
                    done or cur_tick % image_snapshot_ticks == 0):
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=minibatch_gpu)
                save_image_grid(grid_fakes,
                                os.path.join(
                                    run_dir,
                                    f'fakes{cur_nimg // 1000:06d}.png'),
                                drange=[-1, 1],
                                grid_size=grid_size)
            if network_snapshot_ticks is not None and (
                    done or cur_tick % network_snapshot_ticks == 0):
                pkl = os.path.join(
                    run_dir, f'network-snapshot-{cur_nimg // 1000:06d}.pkl')
                with open(pkl, 'wb') as f:
                    pickle.dump((G, D, Gs), f)
                # Metrics are evaluated on the snapshot that was just written.
                if len(metrics):
                    print('Evaluating metrics...')
                    for metric in metrics:
                        metric.run(pkl, num_gpus=num_gpus)

            # Update summaries.
            for metric in metrics:
                metric.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            tick_start_time = time.time()
            # Time spent on reporting/snapshots since the tick's training ended.
            maintenance_time = tick_start_time - tick_end_time

    print()
    print('Exiting...')
    summary_log.close()
    training_set.close()
def main():
    """Generate an image sequence from a random latent timeline.

    Reads its options from the module-level ``a`` namespace (``a.model``,
    ``a.latmask``, ``a.nXY``, ``a.frames``, ``a.fstep``, ``a.trunc``,
    ``a.digress``, ``a.labels``, ``a.save_lat``, ``a.out_dir``, ...). Steps:

    1. Choose the latent blending mode: split the frame into ``nXY`` cells
       (``n_mult`` latents, placeholder mask) or blend two latents with an
       external mask image / directory of mask frames.
    2. Load the generator pickle and either use it directly (model name
       contains ``.pkl``) or rebuild a custom ``Gs`` network from its kwargs.
    3. Build ``n_mult`` latent animations, optional const-layer distortions
       and optional one-hot labels, then run ``Gs`` once per frame and save
       each result into ``a.out_dir``.
    4. If ``a.save_lat`` is True, map the latents to dlatents and save them
       as ``.npy`` next to the output directory.

    Exits the process with status 1 when the blending mask cannot be found
    or the mask directory contains no images.
    """
    os.makedirs(a.out_dir, exist_ok=True)
    np.random.seed(seed=696)

    # setup generator
    fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.func_name = 'training.stylegan2_multi.G_main'
    Gs_kwargs.verbose = a.verbose
    Gs_kwargs.size = a.size
    Gs_kwargs.scale_type = a.scale_type
    Gs_kwargs.impl = a.ops

    # mask/blend latents with external latmask or by splitting the frame
    if a.latmask is None:
        nHW = [int(s) for s in a.nXY.split('-')][::-1]
        assert len(nHW) == 2, ' Wrong count nXY: %d (must be 2)' % len(nHW)
        n_mult = nHW[0] * nHW[1]
        if a.verbose is True and n_mult > 1:
            print(' Latent blending w/split frame %d x %d' % (nHW[1], nHW[0]))
        # placeholder mask of None entries, one per latent: [1,n_mult,1,1]
        lmask = np.tile(np.asarray([[[[None]]]]), (1, n_mult, 1, 1))
        Gs_kwargs.countHW = nHW
        Gs_kwargs.splitfine = a.splitfine
    else:
        if a.verbose is True:
            print(' Latent blending with mask', a.latmask)
        n_mult = 2
        if os.path.isfile(a.latmask):  # single file
            lmask = np.asarray([[img_read(a.latmask)[:, :, 0] / 255.]])  # [1,1,h,w]
        elif os.path.isdir(a.latmask):  # directory with frame sequence
            mask_files = img_list(a.latmask)
            if len(mask_files) == 0:
                print(' !! Blending mask not found:', a.latmask)
                exit(1)
            # one mask per frame on axis 0, so that the concatenation below
            # produces [frm,2,h,w] rather than [1,2*frm,h,w]
            lmask = np.asarray([[img_read(f)[:, :, 0] / 255.] for f in mask_files])  # [frm,1,h,w]
        else:
            print(' !! Blending mask not found:', a.latmask)
            exit(1)
        lmask = np.concatenate((lmask, 1 - lmask), 1)  # [frm,2,h,w]
        Gs_kwargs.latmask_res = lmask.shape[2:]

    # load model with arguments
    tflib.init_tf({'allow_soft_placement': True})
    pkl_name = osp.splitext(a.model)[0]
    # NOTE(review): pickle.load executes arbitrary code; only load trusted models.
    with open(pkl_name + '.pkl', 'rb') as file:
        network = pickle.load(file, encoding='latin1')
    # The pickle holds either a (G, D, Gs) triple or a single network.
    try:
        _, _, network = network
    except (TypeError, ValueError):
        pass
    for k in list(network.static_kwargs.keys()):
        Gs_kwargs[k] = network.static_kwargs[k]

    # reload custom network, if needed
    if '.pkl' in a.model.lower():
        print(' .. Gs from pkl ..', basename(a.model))
        Gs = network
    else:  # reconstruct network
        print(' .. Gs custom ..', basename(a.model))
        Gs = tflib.Network('Gs', **Gs_kwargs)
        Gs.copy_vars_from(network)
    if a.verbose is True:
        print('kwargs:', ['%s: %s' % (kv[0], kv[1]) for kv in sorted(Gs.static_kwargs.items())])

    if a.verbose is True:
        print(' out shape', Gs.output_shape[1:])
    if a.size is None:
        a.size = Gs.output_shape[2:]

    if a.verbose is True:
        print(' making timeline..')
    lats = []  # list of [frm,1,512]
    for _ in range(n_mult):
        lats.append(latent_anima((1, Gs.input_shape[1]), a.frames, a.fstep,
                                 cubic=a.cubic, gauss=a.gauss, verbose=False))  # [frm,1,512]
    latents = np.concatenate(lats, 1)  # [frm,X,512]
    print(' latents', latents.shape)
    frame_count = latents.shape[0]

    # distort image by tweaking initial const layer
    if a.digress > 0:
        try:
            latent_size = Gs.static_kwargs['latent_size']
        except KeyError:
            latent_size = 512  # default latent size
        try:
            init_res = Gs.static_kwargs['init_res']
        except KeyError:
            init_res = (4, 4)  # default initial layer size
        dconst = []
        for _ in range(n_mult):
            dconst.append(a.digress * latent_anima([1, latent_size, *init_res], a.frames,
                                                   a.fstep, cubic=True, verbose=False))
        dconst = np.concatenate(dconst, 1)
    else:
        dconst = np.zeros([frame_count, 1, 1, 1, 1])

    # labels / conditions
    try:
        label_size = Gs_kwargs.label_size
    except (AttributeError, KeyError):
        label_size = 0
    if label_size > 0:
        labels = np.zeros((frame_count, n_mult, label_size))  # [frm,X,lbl]
        if a.labels is None:
            label_ids = [random.randint(0, label_size - 1) for _ in range(n_mult)]
        else:
            label_ids = [int(x) for x in a.labels.split('-')]
            label_ids = label_ids[:n_mult]  # drop surplus labels; missing ones stay all-zero
        for i, l in enumerate(label_ids):
            labels[:, i, l] = 1
    else:
        labels = [None]

    # generate images from latent timeline
    pbar = ProgressBar(frame_count)
    for i in range(frame_count):
        latent = latents[i]  # [X,512]
        label = labels[i % len(labels)]
        latmask = lmask[i % len(lmask)]  # [X,h,w]; mask sequence loops if shorter
        dc = dconst[i % len(dconst)]  # [X,512,4,4]

        # generate multi-latent result
        if Gs.num_inputs == 2:
            output = Gs.run(latent, label, truncation_psi=a.trunc, randomize_noise=False,
                            output_transform=fmt)
        else:
            output = Gs.run(latent, label, latmask, dc, truncation_psi=a.trunc,
                            randomize_noise=False, output_transform=fmt)

        # save image (4-channel output as png, anything else as jpg)
        ext = 'png' if output.shape[3] == 4 else 'jpg'
        filename = osp.join(a.out_dir, "%06d.%s" % (i, ext))
        imsave(filename, output[0])
        pbar.upd()

    # convert latents to dlatents, save them
    if a.save_lat is True:
        # NOTE(review): squeeze(1) only succeeds when n_mult == 1, and `label`
        # is the value left over from the last rendered frame -- confirm that
        # this option is meant for the single-latent case only.
        latents = latents.squeeze(1)  # [frm,512]
        dlatents = Gs.components.mapping.run(latents, label, dtype='float16')  # [frm,18,512]
        filename = '{}-{}-{}.npy'.format(basename(a.model), a.size[1], a.size[0])
        filename = osp.join(osp.dirname(a.out_dir), filename)
        np.save(filename, dlatents)
        print('saved dlatents', dlatents.shape, 'to', filename)
# Defining Input Placeholders image_path = tf.placeholder(tf.string) audio_path = tf.placeholder(tf.string) # Loading High-Resolution Image, Downsampled Low-Resolution Image, Preprocessed Audio, Nearest-Neighbor Interpolation of Inout Low-Resolution Image high_res_image, low_res_image, audio, low_res_image_nearest = load_test_sample( image_path, audio_path) # Constructing All Encoders with tf.device("/GPU:0"): tflib.init_tf() _, _, G = pickle.load(open(FLAGS.STYLEGAN_CHECKPOINT, "rb")) Gs = tflib.Network(name=G.name, func_name="networks_stylegan.G_style", **G.static_kwargs) with tf.variable_scope(LR_ENCODER_SCOPE, reuse=tf.AUTO_REUSE): encoded_input = LowResEncoder( input=low_res_image, num_channels=3, resolution=8, batch_size=FLAGS.BATCH_SIZE, num_scales=3, n_filters=128, output_feature_size=512, ) with tf.variable_scope(AUDIO_ENCODER_SCOPE, reuse=tf.AUTO_REUSE): audio_encoded_input = SpectrogramEncoder(
def train(submit_config: dnnlib.SubmitConfig, iteration_count: int,
          eval_interval: int, minibatch_size: int, learning_rate: float,
          ramp_down_perc: float, noise: dict, validation_config: dict,
          train_tfrecords: str, noise2noise: bool):
    """Train a denoising network and save a final snapshot.

    Builds one loss tower per GPU, then runs `iteration_count` optimizer
    steps. Every `eval_interval` iterations it saves example images, runs
    the validation set, prints timing information and writes summaries.

    Args:
        submit_config: Run configuration; `num_gpus`, `run_dir` and `run_id`
            are read here, and the object is passed on to the run context,
            the validation set and the image/snapshot helpers.
        iteration_count: Number of training iterations.
        eval_interval: Status/validation is dumped when `i % eval_interval == 0`.
        minibatch_size: Forwarded to `create_dataset`.
        learning_rate: Base learning rate fed to `compute_ramped_down_lrate`.
        ramp_down_perc: Forwarded to `compute_ramped_down_lrate`.
        noise: Keyword arguments for `dnnlib.util.call_func_by_name`; the
            result must expose `add_train_noise_tf` and
            `add_validation_noise_np`.
        validation_config: Keyword arguments for `ValidationSet.load`.
        train_tfrecords: Forwarded to `create_dataset`.
        noise2noise: If true, the loss target is the noisy target; otherwise
            the clean target.

    Note:
        Relies on the module-level `config` object for `tf_config`,
        `net_config` and `optimizer_config`.
    """
    noise_augmenter = dnnlib.util.call_func_by_name(**noise)
    validation_set = ValidationSet(submit_config)
    validation_set.load(**validation_config)

    # Create a run context (hides low level details, exposes simple API to manage the run)
    # noinspection PyTypeChecker
    ctx = dnnlib.RunContext(submit_config, config)

    # Initialize TensorFlow graph and session using good default settings
    tfutil.init_tf(config.tf_config)

    dataset_iter = create_dataset(train_tfrecords, minibatch_size,
                                  noise_augmenter.add_train_noise_tf)

    # Construct the network using the Network helper class and a function defined in config.net_config
    with tf.device("/gpu:0"):
        net = tflib.Network(**config.net_config)

    # Optionally print layer information
    net.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device("/cpu:0"):
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])

        noisy_input, noisy_target, clean_target = dataset_iter.get_next()
        noisy_input_split = tf.split(noisy_input, submit_config.num_gpus)
        noisy_target_split = tf.split(noisy_target, submit_config.num_gpus)
        clean_target_split = tf.split(clean_target, submit_config.num_gpus)

    # Define the loss function using the Optimizer helper class, this will take care of multi GPU
    opt = tflib.Optimizer(learning_rate=lrate_in, **config.optimizer_config)

    for gpu in range(submit_config.num_gpus):
        with tf.device("/gpu:%d" % gpu):
            net_gpu = net if gpu == 0 else net.clone()

            denoised = net_gpu.get_output_for(noisy_input_split[gpu])

            if noise2noise:
                meansq_error = tf.reduce_mean(
                    tf.square(noisy_target_split[gpu] - denoised))
            else:
                meansq_error = tf.reduce_mean(
                    tf.square(clean_target_split[gpu] - denoised))

            # Create an autosummary that will average over all GPUs
            with tf.control_dependencies([autosummary("Loss", meansq_error)]):
                opt.register_gradients(meansq_error, net_gpu.trainables)

    train_step = opt.apply_updates()

    # Create a log file for Tensorboard
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    summary_log.add_graph(tf.get_default_graph())

    print('Training...')
    time_maintenance = ctx.get_time_since_last_update()
    ctx.update(loss='run %d' % submit_config.run_id, cur_epoch=0,
               max_epoch=iteration_count)

    # ***********************************
    # The actual training loop
    for i in range(iteration_count):
        # Whether to stop the training or not should be asked from the context
        if ctx.should_stop():
            break

        # Dump training status
        if i % eval_interval == 0:
            time_train = ctx.get_time_since_last_update()
            time_total = ctx.get_time_since_start()

            # Evaluate 'x' to draw a batch of inputs
            [source_mb, target_mb] = tfutil.run([noisy_input, clean_target])
            denoised = net.run(source_mb)
            save_image(submit_config, denoised[0],
                       "img_{0}_y_pred.png".format(i))
            save_image(submit_config, target_mb[0],
                       "img_{0}_y.png".format(i))
            save_image(submit_config, source_mb[0],
                       "img_{0}_x_aug.png".format(i))

            validation_set.evaluate(net, i,
                                    noise_augmenter.add_validation_noise_np)

            # Remaining time estimate. It is logged under its own tag: it
            # previously reused 'Timing/total_sec', so the elapsed time and
            # the ETA were recorded into the same summary.
            eta_sec = (time_train / eval_interval) * (iteration_count - i)

            print(
                'iter %-10d time %-12s eta %-12s sec/eval %-7.1f sec/iter %-7.2f maintenance %-6.1f'
                % (autosummary('Timing/iter', i),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', time_total)),
                   dnnlib.util.format_time(
                       autosummary('Timing/eta_sec', eta_sec)),
                   autosummary('Timing/sec_per_eval', time_train),
                   autosummary('Timing/sec_per_iter',
                               time_train / eval_interval),
                   autosummary('Timing/maintenance_sec', time_maintenance)))

            dnnlib.tflib.autosummary.save_summaries(summary_log, i)
            ctx.update(loss='run %d' % submit_config.run_id, cur_epoch=i,
                       max_epoch=iteration_count)
            time_maintenance = ctx.get_last_update_interval() - time_train

        # Training epoch
        lrate = compute_ramped_down_lrate(i, iteration_count, ramp_down_perc,
                                          learning_rate)
        tfutil.run([train_step], {lrate_in: lrate})

    # End of training
    print("Elapsed time: {0}".format(
        util.format_time(ctx.get_time_since_start())))
    save_snapshot(submit_config, net, 'final')

    # Summary log and context should be closed at the end
    summary_log.close()
    ctx.close()
def training_loop(
        submit_config,
        G_args={},  # Options for generator network.
        D_args={},  # Options for discriminator network.
        G_opt_args={},  # Options for generator optimizer.
        D_opt_args={},  # Options for discriminator optimizer.
        G_loss_args={},  # Options for generator loss.
        D_loss_args={},  # Options for discriminator loss.
        dataset_args={},  # Options for dataset.load_dataset().
        sched_args={},  # Options for train.TrainingSchedule.
        grid_args={},  # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],  # Options for MetricGroup.
        tf_config={},  # Options for tflib.init_tf().
        G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
        D_repeats=1,  # How many times the discriminator is trained per G iteration.
        minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
        reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg=7000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=False,  # Enable mirror augment?
        drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=1,  # How often to export image snapshots?
        network_snapshot_ticks=10,  # How often to export network snapshots?
        save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
        resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
        resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0):  # Assumed wallclock time at the beginning. Affects reporting.
    """Run the GAN training loop and write snapshots into the run directory.

    Networks are either constructed from `G_args` / `D_args` or loaded,
    depending on `resume_run_id`:

    * ``None``: construct new networks.
    * ``'latest'``: download and unpickle the pre-trained FFHQ networks from
      a fixed URL (the original "locate latest pkl" logic is disabled).
    * ``'restore_partial'``: construct new networks, then copy compatible
      trainables from the pickle named by `restore_partial_fn`.
    * anything else: resolve it with `misc.locate_network_pkl` and load it.

    The default `{}` / `[]` arguments are only read here (unpacked or passed
    on), never mutated in this function.

    NOTE(review): `restore_partial_fn` is not a parameter of this function;
    it is presumably a module-level name - confirm it is defined before
    using the 'restore_partial' mode.
    """
    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **dataset_args)

    # Construct networks.
    with tf.device('/gpu:0'):
        # Load pre-trained
        if resume_run_id is not None:
            if resume_run_id == 'latest':
                URL_FFHQ = 'https://s3-us-west-2.amazonaws.com/nanonets/blogs/karras2019stylegan-ffhq-1024x1024.pkl'
                tflib.init_tf()
                # SECURITY: unpickling runs arbitrary code from the payload;
                # only keep this if the URL is a trusted source.
                with dnnlib.util.open_url(URL_FFHQ,
                                          cache_dir=config.cache_dir) as f:
                    G, D, Gs = pickle.load(f)
                # Disabled alternative kept from the original (it was an
                # inert string literal at this point):
                #   network_pkl, resume_kimg = misc.locate_latest_pkl()
                #   print('Loading networks from "%s"...' % network_pkl)
                #   G, D, Gs = misc.load_pkl(network_pkl)
            elif resume_run_id == 'restore_partial':
                print('Restore partially...')
                # Initialize networks
                G = tflib.Network('G',
                                  num_channels=training_set.shape[0],
                                  resolution=training_set.shape[1],
                                  label_size=training_set.label_size,
                                  **G_args)
                D = tflib.Network('D',
                                  num_channels=training_set.shape[0],
                                  resolution=training_set.shape[1],
                                  label_size=training_set.label_size,
                                  **D_args)
                Gs = G.clone('Gs')

                # Load pre-trained networks. The file is opened in a `with`
                # block so the handle is closed once unpickling finishes.
                assert restore_partial_fn is not None
                with open(restore_partial_fn, 'rb') as partial_file:
                    G_partial, D_partial, Gs_partial = pickle.load(
                        partial_file)

                # Restore (subset of) pre-trained weights
                # (only parameters that match both name and shape)
                G.copy_compatible_trainables_from(G_partial)
                D.copy_compatible_trainables_from(D_partial)
                Gs.copy_compatible_trainables_from(Gs_partial)
            else:
                network_pkl = misc.locate_network_pkl(resume_run_id,
                                                      resume_snapshot)
                print('Loading networks from "%s"...' % network_pkl)
                G, D, Gs = misc.load_pkl(network_pkl)
        # Start from scratch
        else:
            print('Constructing networks...')
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **G_args)
            D = tflib.Network('D',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **D_args)
            Gs = G.clone('Gs')
    G.print_layers()
    D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_in, tf.float32),
                              G_smoothing_kimg *
                              1000.0) if G_smoothing_kimg > 0.0 else 0.0

    G_opt = tflib.Optimizer(name='TrainG',
                            learning_rate=lrate_in,
                            **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD',
                            learning_rate=lrate_in,
                            **D_opt_args)
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # GPU 0 uses the master networks; other GPUs get shadow clones.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in)
            ]
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(reals, lod_in, mirror_augment,
                                  training_set.dynamic_range, drange_net)
            with tf.name_scope('G_loss'), tf.control_dependencies(
                    lod_assign_ops):
                G_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    **G_loss_args)
            with tf.name_scope('D_loss'), tf.control_dependencies(
                    lod_assign_ops):
                D_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals,
                    labels=labels,
                    **D_loss_args)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(
        G, training_set, **grid_args)
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              num_gpus=submit_config.num_gpus,
                              **sched_args)
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch //
                        submit_config.num_gpus)

    print('Setting up run dir...')
    misc.save_image_grid(grid_reals,
                         os.path.join(submit_config.run_dir, 'reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    misc.save_image_grid(grid_fakes,
                         os.path.join(submit_config.run_dir,
                                      'fakes%06d.png' % resume_kimg),
                         drange=drange_net,
                         grid_size=grid_size)
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  num_gpus=submit_config.num_gpus,
                                  **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus,
                               sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch
                    })
                cur_nimg += sched.minibatch
            tflib.run(
                [G_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch
                })

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days',
                        total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch //
                                    submit_config.num_gpus)
                misc.save_image_grid(grid_fakes,
                                     os.path.join(
                                         submit_config.run_dir,
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(
                    submit_config.run_dir,
                    'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl,
                            run_dir=submit_config.run_dir,
                            num_gpus=submit_config.num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod,
                       cur_epoch=cur_nimg // 1000,
                       max_epoch=total_kimg)
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    misc.save_pkl((G, D, Gs),
                  os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()

    ctx.close()
def training_loop(
        submit_config,
        Encoder_args={},
        E_opt_args={},
        D_opt_args={},
        E_loss_args={},
        D_loss_args={},
        lr_args=EasyDict(),
        tf_config={},
        dataset_args=EasyDict(),
        decoder_pkl=EasyDict(),
        drange_data=[0, 255],
        drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
        mirror_augment=False,
        resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
        resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
        image_snapshot_ticks=1,  # How often to export image snapshots?
        network_snapshot_ticks=10,  # How often to export network snapshots?
        save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        max_iters=150000,
        E_smoothing=0.999):
    """Train the encoder network `E` (and discriminator `D`).

    When `resume_run_id` is given, all five networks `(E, G, D, Gs, NE)` are
    loaded from the located snapshot and the start iteration is derived from
    the number embedded in the snapshot file name. Otherwise `(G, D, Gs, NE)`
    are loaded from `decoder_pkl.decoder_pkl` and a new `E` is constructed.

    Each iteration runs one encoder step and one discriminator step on the
    same training batch. A "tick" is counted every 65000 images; image and
    network snapshots are written every `image_snapshot_ticks` /
    `network_snapshot_ticks` ticks, and a final snapshot at the end.

    Args:
        submit_config: Run configuration; reads `batch_size`,
            `batch_size_test`, `image_size`, `num_gpus` and `run_dir`.
        Encoder_args: Extra keyword arguments for the `E` network.
        E_opt_args, D_opt_args: Keyword arguments for the two optimizers.
        E_loss_args, D_loss_args: Keyword arguments for
            `dnnlib.util.call_func_by_name` when building the losses.
        lr_args: Must provide `learning_rate`, `decay_step`, `decay_rate`
            and `stair` for the exponential decay schedule.
        dataset_args: Must provide `data_train` and `data_test`.
        decoder_pkl: Must provide `decoder_pkl` when not resuming.
        drange_data, drange_net: Dynamic range of the stored images and of
            the network inputs; used for both the training batches and the
            snapshot preview batches.
        max_iters: Iteration index at which training stops.
        E_smoothing: Unused in this function (the moving-average update is
            commented out).
    """
    tflib.init_tf(tf_config)

    with tf.name_scope('input'):
        real_train = tf.placeholder(tf.float32, [
            submit_config.batch_size, 3, submit_config.image_size,
            submit_config.image_size
        ], name='real_image_train')
        real_test = tf.placeholder(tf.float32, [
            submit_config.batch_size_test, 3, submit_config.image_size,
            submit_config.image_size
        ], name='real_image_test')
        real_split = tf.split(real_train,
                              num_or_size_splits=submit_config.num_gpus,
                              axis=0)

    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            E, G, D, Gs, NE = misc.load_pkl(network_pkl)
            # Snapshot names end in the image count; convert it back to an
            # iteration index.
            start = int(network_pkl.split('-')[-1].split('.')
                        [0]) // submit_config.batch_size
        else:
            print('Constructing networks...')
            G, D, Gs, NE = misc.load_pkl(decoder_pkl.decoder_pkl)
            E = tflib.Network('E',
                              size=submit_config.image_size,
                              filter=64,
                              filter_max=1024,
                              phase=True,
                              **Encoder_args)
            start = 0

    Gs.print_layers()
    E.print_layers()
    D.print_layers()

    global_step = tf.Variable(start,
                              trainable=False,
                              name='learning_rate_step')
    learning_rate = tf.train.exponential_decay(lr_args.learning_rate,
                                               global_step,
                                               lr_args.decay_step,
                                               lr_args.decay_rate,
                                               staircase=lr_args.stair)
    add_global = global_step.assign_add(1)
    E_opt = tflib.Optimizer(name='TrainE',
                            learning_rate=learning_rate,
                            **E_opt_args)
    D_opt = tflib.Optimizer(name='TrainD',
                            learning_rate=learning_rate,
                            **D_opt_args)

    # Per-GPU loss terms are summed here and averaged after the loop.
    E_loss_rec = 0.
    E_loss_adv = 0.
    D_loss_real = 0.
    D_loss_fake = 0.
    D_loss_grad = 0.
    for gpu in range(submit_config.num_gpus):
        print('build graph on gpu %s' % str(gpu))
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            E_gpu = E if gpu == 0 else E.clone(E.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            G_gpu = Gs if gpu == 0 else Gs.clone(Gs.name + '_shadow')
            perceptual_model = PerceptualModel(
                img_size=[submit_config.image_size, submit_config.image_size],
                multi_layers=False)
            real_gpu = process_reals(real_split[gpu], mirror_augment,
                                     drange_data, drange_net)
            with tf.name_scope('E_loss'), tf.control_dependencies(None):
                E_loss, recon_loss, adv_loss = dnnlib.util.call_func_by_name(
                    E=E_gpu,
                    G=G_gpu,
                    D=D_gpu,
                    perceptual_model=perceptual_model,
                    reals=real_gpu,
                    **E_loss_args)
                E_loss_rec += recon_loss
                E_loss_adv += adv_loss
            with tf.name_scope('D_loss'), tf.control_dependencies(None):
                D_loss, loss_fake, loss_real, loss_gp = dnnlib.util.call_func_by_name(
                    E=E_gpu, G=G_gpu, D=D_gpu, reals=real_gpu, **D_loss_args)
                D_loss_real += loss_real
                D_loss_fake += loss_fake
                D_loss_grad += loss_gp
            with tf.control_dependencies([add_global]):
                E_opt.register_gradients(E_loss, E_gpu.trainables)
                D_opt.register_gradients(D_loss, D_gpu.trainables)

    E_loss_rec /= submit_config.num_gpus
    E_loss_adv /= submit_config.num_gpus
    D_loss_real /= submit_config.num_gpus
    D_loss_fake /= submit_config.num_gpus
    D_loss_grad /= submit_config.num_gpus

    E_train_op = E_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    #Es_update_op = Es.setup_as_moving_average_of(E, beta=E_smoothing)

    print('building testing graph...')
    fake_X_val = test(E, Gs, real_test, submit_config)

    sess = tf.get_default_session()

    print('Getting training data...')
    image_batch_train = get_train_data(sess,
                                       data_dir=dataset_args.data_train,
                                       submit_config=submit_config,
                                       mode='train')
    image_batch_test = get_train_data(sess,
                                      data_dir=dataset_args.data_test,
                                      submit_config=submit_config,
                                      mode='test')

    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        E.setup_weight_histograms()
        D.setup_weight_histograms()

    cur_nimg = start * submit_config.batch_size
    cur_tick = 0
    tick_start_nimg = cur_nimg
    start_time = time.time()

    print('Optimization starts!!!')
    for it in range(start, max_iters):
        feed_dict = {real_train: sess.run(image_batch_train)}
        sess.run([E_train_op, E_loss_rec, E_loss_adv], feed_dict)
        sess.run([D_train_op, D_loss_real, D_loss_fake, D_loss_grad],
                 feed_dict)

        cur_nimg += submit_config.batch_size

        if it % 100 == 0:
            print("Iter: %06d kimg: %-8.1f time: %-12s" %
                  (it, cur_nimg / 1000,
                   dnnlib.util.format_time(time.time() - start_time)))
            sys.stdout.flush()
            tflib.autosummary.save_summaries(summary_log, it)

        if cur_nimg >= tick_start_nimg + 65000:
            cur_tick += 1
            tick_start_nimg = cur_nimg

            if cur_tick % image_snapshot_ticks == 0:
                batch_images_test = sess.run(image_batch_test)
                # Use the same data/network ranges as the training path
                # (process_reals above) instead of hard-coded constants, so
                # non-default ranges are honoured for previews too.
                batch_images_test = misc.adjust_dynamic_range(
                    batch_images_test.astype(np.float32), drange_data,
                    drange_net)

                samples2 = sess.run(fake_X_val,
                                    feed_dict={real_test: batch_images_test})
                # Move the channel axis last before building the preview.
                samples2 = samples2.transpose(0, 2, 3, 1)
                batch_images_test = batch_images_test.transpose(0, 2, 3, 1)
                orin_recon = np.concatenate([batch_images_test, samples2],
                                            axis=0)
                imwrite(immerge(orin_recon, 2, submit_config.batch_size_test),
                        '%s/iter_%08d.png' % (submit_config.run_dir, cur_nimg))

            if cur_tick % network_snapshot_ticks == 0:
                pkl = os.path.join(submit_config.run_dir,
                                   'network-snapshot-%08d.pkl' % (cur_nimg))
                misc.save_pkl((E, G, D, Gs, NE), pkl)

    misc.save_pkl((E, G, D, Gs, NE),
                  os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()
def training_loop(
    G_args                  = {},       # Options for generator network.
    D_args                  = {},       # Options for discriminator network.
    G_opt_args              = {},       # Options for generator optimizer.
    D_opt_args              = {},       # Options for discriminator optimizer.
    G_loss_args             = {},       # Options for generator loss.
    D_loss_args             = {},       # Options for discriminator loss.
    dataset_args            = {},       # Options for dataset.load_dataset().
    sched_args              = {},       # Options for train.TrainingSchedule.
    grid_args               = {},       # Options for train.setup_snapshot_image_grid().
    setname                 = None,     # Model name
    tf_config               = {},       # Options for tflib.init_tf().
    G_smoothing_kimg        = 10.0,     # Half-life of the running average of generator weights.
    minibatch_repeats       = 4,        # Number of minibatches to run before adjusting training parameters.
    lazy_regularization     = True,     # Perform regularization as a separate training step?
    G_reg_interval          = 4,        # How often the perform regularization for G? Ignored if lazy_regularization=False.
    D_reg_interval          = 16,       # How often the perform regularization for D? Ignored if lazy_regularization=False.
    reset_opt_for_new_lod   = True,     # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg              = 25000,    # Total length of the training, measured in thousands of real images.
    mirror_augment          = False,    # Enable mirror augment?
    mirror_augment_v        = False,    # Enable mirror augment vertically?
    drange_net              = [-1,1],   # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks    = 50,       # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'.
    network_snapshot_ticks  = 50,       # How often to save network snapshots? None = only save 'networks-final.pkl'.
    save_tf_graph           = False,    # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms  = False,    # Include weight histograms in the tfevents file?
    resume_pkl              = 'latest', # Network pickle to resume training from, None = train from scratch.
    resume_kimg             = 0.0,      # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time             = 0.0,      # Assumed wallclock time at the beginning. Affects reporting.
    restore_partial_fn      = None,     # Filename of network for partial restore
    resume_with_new_nets    = False):   # Construct new networks according to G_args and D_args before resuming training?
    """Run the GAN training loop with optional lazy regularization.

    Builds (or loads) G, D and Gs, sets up one training tower per GPU, then
    trains until `total_kimg` thousand images have been processed. Image and
    network snapshots are written through `dnnlib.make_run_dir_path`; a final
    pair of snapshots is always written after the loop.

    Raises:
        ValueError: If `setname` is None. Snapshot file names are built from
            `setname[-1]` and `setname[:-1]`, so a missing name used to fail
            with a TypeError only at the first/final snapshot save, after
            training work had already been done.

    NOTE(review): the exact type of `setname` (string vs. sequence) is not
    visible here; it only needs to support `[-1]` and `[:-1]`.
    """
    # Fail fast: every snapshot file name below indexes into `setname`.
    if setname is None:
        raise ValueError('setname must be provided: it is used to build snapshot file names')

    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(verbose=True, **dataset_args)
    # custom resolution - for saved model name below
    resolution = training_set.resolution
    if training_set.init_res != [4,4]:
        init_res_str = '-%dx%d' % (training_set.init_res[0], training_set.init_res[1])
    else:
        init_res_str = ''
    # Four-channel data is previewed as png, everything else as jpg.
    ext = 'png' if training_set.shape[0] == 4 else 'jpg'
    print(' model base resolution', resolution)

    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(training_set, **grid_args)
    misc.save_image_grid(grid_reals, dnnlib.make_run_dir_path('_reals.%s'%ext), drange=training_set.dynamic_range, grid_size=grid_size)

    # Construct or load networks.
    with tf.device('/gpu:0'):
        if resume_pkl is None or resume_with_new_nets:
            print(' Constructing networks...')
            G = tflib.Network('G', num_channels=training_set.shape[0], resolution=resolution, label_size=training_set.label_size, **G_args)
            D = tflib.Network('D', num_channels=training_set.shape[0], resolution=resolution, label_size=training_set.label_size, **D_args)
            Gs = G.clone('Gs')
        if resume_pkl is not None:
            if resume_pkl == 'latest':
                # NOTE(review): as laid out here, this branch only resolves the
                # pickle path and kimg; nothing is loaded afterwards. Confirm
                # against the original indentation of this file.
                resume_pkl, resume_kimg = misc.locate_latest_pkl(dnnlib.submit_config.run_dir_root)
            elif resume_pkl == 'restore_partial':
                print(' Restore partially...')
                # Initialize networks
                G = tflib.Network('G', num_channels=training_set.shape[0], resolution=resolution, label_size=training_set.label_size, **G_args)
                D = tflib.Network('D', num_channels=training_set.shape[0], resolution=resolution, label_size=training_set.label_size, **D_args)
                Gs = G.clone('Gs')
                # Load pre-trained networks. The file is opened in a `with`
                # block so the handle is closed once unpickling finishes.
                assert restore_partial_fn is not None
                with open(restore_partial_fn, 'rb') as partial_file:
                    G_partial, D_partial, Gs_partial = pickle.load(partial_file)
                # Restore (subset of) pre-trained weights (only parameters that match both name and shape)
                G.copy_compatible_trainables_from(G_partial)
                D.copy_compatible_trainables_from(D_partial)
                Gs.copy_compatible_trainables_from(Gs_partial)
            else:
                if resume_pkl is not None and resume_kimg == 0:
                    resume_pkl, resume_kimg = misc.locate_latest_pkl(resume_pkl)
                print(' Loading networks from "%s", kimg %.3g' % (resume_pkl, resume_kimg))
                rG, rD, rGs = misc.load_pkl(resume_pkl)
                if resume_with_new_nets:
                    G.copy_vars_from(rG)
                    D.copy_vars_from(rD)
                    Gs.copy_vars_from(rGs)
                else:
                    G, D, Gs = rG, rD, rGs

    # Print layers if needed and generate initial image snapshot
    # G.print_layers(); D.print_layers()
    sched = training_schedule(cur_nimg=total_kimg*1000, training_set=training_set, **sched_args)
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu)
    misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes_init.%s'%ext), drange=drange_net, grid_size=grid_size)

    # Setup training inputs.
    print(' Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32, name='minibatch_size_in', shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32, name='minibatch_gpu_in', shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in * num_gpus)
        Gs_beta = 0.5 ** tf.div(tf.cast(minibatch_size_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0

    # Setup optimizers.
    # Work on copies so the caller's option dicts (and the shared defaults) are not modified.
    G_opt_args = dict(G_opt_args)
    D_opt_args = dict(D_opt_args)
    for args, reg_interval in [(G_opt_args, G_reg_interval), (D_opt_args, D_reg_interval)]:
        args['minibatch_multiplier'] = minibatch_multiplier
        args['learning_rate'] = lrate_in
        if lazy_regularization:
            mb_ratio = reg_interval / (reg_interval + 1)
            args['learning_rate'] *= mb_ratio
            if 'beta1' in args:
                args['beta1'] **= mb_ratio
            if 'beta2' in args:
                args['beta2'] **= mb_ratio
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', **D_opt_args)
    G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args)
    D_reg_opt = tflib.Optimizer(name='RegD', share=D_opt, **D_opt_args)

    # Build training graph for each GPU.
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):

            # Create GPU-specific shadow copies of G and D.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                sched = training_schedule(cur_nimg=int(resume_kimg*1000), training_set=training_set, **sched_args)
                reals_var = tf.Variable(name='reals', trainable=False, initial_value=tf.zeros([sched.minibatch_gpu] + training_set.shape))
                labels_var = tf.Variable(name='labels', trainable=False, initial_value=tf.zeros([sched.minibatch_gpu, training_set.label_size]))
                reals_write, labels_write = training_set.get_minibatch_tf()
                reals_write, labels_write = process_reals(reals_write, labels_write, lod_in, mirror_augment, mirror_augment_v, training_set.dynamic_range, drange_net)
                reals_write = tf.concat([reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                labels_write = tf.concat([labels_write, labels_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_var, reals_write)]
                data_fetch_ops += [tf.assign(labels_var, labels_write)]
                reals_read = reals_var[:minibatch_gpu_in]
                labels_read = labels_var[:minibatch_gpu_in]

            # Evaluate loss functions.
            lod_assign_ops = []
            if 'lod' in G_gpu.vars:
                lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)]
            if 'lod' in D_gpu.vars:
                lod_assign_ops += [tf.assign(D_gpu.vars['lod'], lod_in)]
            with tf.control_dependencies(lod_assign_ops):
                with tf.name_scope('G_loss'):
                    G_loss, G_reg = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, **G_loss_args)
                with tf.name_scope('D_loss'):
                    D_loss, D_reg = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, reals=reals_read, labels=labels_read, **D_loss_args)

            # Register gradients.
            if not lazy_regularization:
                if G_reg is not None:
                    G_loss += G_reg
                if D_reg is not None:
                    D_loss += D_reg
            else:
                if G_reg is not None:
                    G_reg_opt.register_gradients(tf.reduce_mean(G_reg * G_reg_interval), G_gpu.trainables)
                if D_reg is not None:
                    D_reg_opt.register_gradients(tf.reduce_mean(D_reg * D_reg_interval), D_gpu.trainables)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()
    G_reg_op = G_reg_opt.apply_updates(allow_no_op=True)
    D_reg_op = D_reg_opt.apply_updates(allow_no_op=True)
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    # print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms(); D.setup_weight_histograms()

    print(' Training for %d kimg (%d left) \n' % (total_kimg, total_kimg-resume_kimg))
    dnnlib.RunContext.get().update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1  # starts below zero so the first loop pass always runs the tick maintenance
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu) # , sched.lod
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state(); D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        feed_dict = {lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_size_in: sched.minibatch_size, minibatch_gpu_in: sched.minibatch_gpu}
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size, sched.minibatch_gpu * num_gpus)
            run_G_reg = (lazy_regularization and running_mb_counter % G_reg_interval == 0)
            run_D_reg = (lazy_regularization and running_mb_counter % D_reg_interval == 0)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                tflib.run([G_train_op, data_fetch_op], feed_dict)
                if run_G_reg:
                    tflib.run(G_reg_op, feed_dict)
                tflib.run([D_train_op, Gs_update_op], feed_dict)
                if run_D_reg:
                    tflib.run(D_reg_op, feed_dict)

            # Slow path with gradient accumulation.
            else:
                for _round in rounds:
                    tflib.run(G_train_op, feed_dict)
                if run_G_reg:
                    for _round in rounds:
                        tflib.run(G_reg_op, feed_dict)
                tflib.run(Gs_update_op, feed_dict)
                for _round in rounds:
                    tflib.run(data_fetch_op, feed_dict)
                    tflib.run(D_train_op, feed_dict)
                if run_D_reg:
                    for _round in rounds:
                        tflib.run(D_reg_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start() + resume_time
            # The remaining-time estimate is only produced when sched.lod == 0.
            if sched.lod == 0:
                left_kimg = total_kimg - cur_nimg / 1000
                left_sec = left_kimg * tick_time / tick_kimg
                finaltime = time.asctime(time.localtime(cur_time + left_sec))
                msg_final = '%ss left till %s ' % (shortime(left_sec), finaltime[11:16])
            else:
                msg_final = ''

            # Report progress.
            # print('tick %-4d kimg %-6.1f lod %-5.2f minibch %-3d:%d time %-8s min/tick %-6.3g %s sec/kimg %-7.3g gpumem %-4.1f %d lr %.2g ' % (
            print('tick %-4d kimg %-6.1f time %-8s %s min/tick %-6.3g sec/kimg %-7.3g gpumem %-4.1f lr %.2g ' % (
                autosummary('Progress/tick', cur_tick),
                autosummary('Progress/kimg', cur_nimg / 1000.0),
                # autosummary('Progress/lod', sched.lod),
                # autosummary('Progress/minibatch', sched.minibatch_size),
                # autosummary('Progress/minibatch_gpu', sched.minibatch_gpu),
                dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                msg_final,
                autosummary('Timing/min_per_tick', tick_time / 60),
                autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                # autosummary('Timing/maintenance_sec', maintenance_time),
                autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30),
                sched.G_lrate))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (cur_tick % image_snapshot_ticks == 0 or done):
                grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu)
                misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fake-%04d.%s' % (cur_nimg // 1000, ext)), drange=drange_net, grid_size=grid_size)
            if network_snapshot_ticks is not None and (cur_tick % network_snapshot_ticks == 0 or done):
                # Two files per snapshot: the full (G, D, Gs) triple and Gs alone.
                pkl = dnnlib.make_run_dir_path('snapshot-%d-%s%s-%04d.pkl' % (resolution, setname[-1], init_res_str, cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                misc.save_pkl((Gs), dnnlib.make_run_dir_path('%s-%d-%s%s-%04d.pkl' % (setname[:-1], resolution, setname[-1], init_res_str, cur_nimg // 1000)))

            # Update summaries and RunContext.
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get().get_last_update_interval() - tick_time

    # Save final snapshot.
    misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path('snapshot-%d-%s%s-final.pkl' % (resolution, setname[-1], init_res_str)))
    misc.save_pkl((Gs), dnnlib.make_run_dir_path('%s-%d-%s%s-final.pkl' % (setname[:-1], resolution, setname[-1], init_res_str)))

    # All done.
    summary_log.close()
    training_set.close()
# Initialize dnnlib and TensorFlow. ctx = dnnlib.RunContext(submit_config, train) tflib.init_tf(tf_config) # Load training set. training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args) # Construct networks. with tf.device('/gpu:0'): if resume_run_id is not None: network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot) print('Loading networks from "%s"...' % network_pkl) G, D, Gs = misc.load_pkl(network_pkl) else: print('Constructing networks...') G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args) D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args) Gs = G.clone('Gs') G.print_layers(); D.print_layers() print('Building TensorFlow graph...') with tf.name_scope('Inputs'), tf.device('/cpu:0'): lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[]) lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[]) minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[]) minibatch_split = minibatch_in // submit_config.num_gpus Gs_beta = 0.5 ** tf.div(tf.cast(minibatch_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0 G_opt = tflib.Optimizer(name='TrainG', learning_rate=lrate_in, **G_opt_args) D_opt = tflib.Optimizer(name='TrainD', learning_rate=lrate_in, **D_opt_args) for gpu in range(submit_config.num_gpus):