def set_optimizer(cN, lrate_in, batch_multiplier, lazy_regularization=True, clip=None):
    args = dict(cN.opt_args)
    args["batch_multiplier"] = batch_multiplier
    args["learning_rate"] = lrate_in
    if lazy_regularization:
        mb_ratio = cN.reg_interval / (cN.reg_interval + 1)
        args["learning_rate"] *= mb_ratio
        if "beta1" in args: args["beta1"] **= mb_ratio
        if "beta2" in args: args["beta2"] **= mb_ratio
    cN.opt = tflib.Optimizer(name=f"Loss{cN.name}", clip=clip, **args)
    cN.reg_opt = tflib.Optimizer(name=f"Reg{cN.name}", share=cN.opt, clip=clip, **args)
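
# --- Illustrative sketch (not part of the original training code) -----------
# The lazy-regularization rescaling above keeps the optimizer statistics
# roughly unchanged when the regularization term only runs every
# cN.reg_interval minibatches: the main-loss optimizer covers reg_interval out
# of every (reg_interval + 1) conceptual steps, so the learning rate and the
# Adam betas are scaled by mb_ratio. The example values below are assumed
# StyleGAN2-style defaults, used only to show the arithmetic.
def lazy_reg_adjust(learning_rate, beta1, beta2, reg_interval):
    mb_ratio = reg_interval / (reg_interval + 1)
    return learning_rate * mb_ratio, beta1 ** mb_ratio, beta2 ** mb_ratio

# Example: lazy_reg_adjust(0.002, 0.0, 0.99, 4) -> (0.0016, 0.0, ~0.992)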
def train(submit_config: dnnlib.SubmitConfig, iteration_count: int, eval_interval: int, minibatch_size: int,
          learning_rate: float, ramp_down_perc: float, noise: dict, validation_config: dict,
          train_tfrecords: str, noise2noise: bool):
    noise_augmenter = dnnlib.util.call_func_by_name(**noise)
    validation_set = ValidationSet(submit_config)
    validation_set.load(**validation_config)

    # Create a run context (hides low level details, exposes simple API to manage the run)
    ctx = dnnlib.RunContext(submit_config, config)

    # Initialize TensorFlow graph and session using good default settings
    tfutil.init_tf(config.tf_config)

    dataset_iter = create_dataset(train_tfrecords, minibatch_size, noise_augmenter.add_train_noise_tf)

    # Construct the network using the Network helper class and a function defined in config.net_config
    with tf.device("/gpu:0"):
        net = tflib.Network(**config.net_config)

    # Optionally print layer information
    net.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device("/cpu:0"):
        lrate_in = tf.compat.v1.placeholder(tf.float32, name='lrate_in', shape=[])
        noisy_input, noisy_target, clean_target = dataset_iter.get_next()
        # Split [?, 3, 256, 256] across num_gpus over axis 0 (i.e. the batch)
        noisy_input_split = tf.split(noisy_input, submit_config.num_gpus)
        noisy_target_split = tf.split(noisy_target, submit_config.num_gpus)
        clean_target_split = tf.split(clean_target, submit_config.num_gpus)
        print(len(noisy_input_split), noisy_input_split)

    # Define the loss function using the Optimizer helper class, this will take care of multi GPU
    opt = tflib.Optimizer(learning_rate=lrate_in, **config.optimizer_config)

    # Precompute the radial frequency masks and spatial frequencies used by the FRC loss.
    radii = np.arange(128).reshape(128, 1)  # image size 256, binning = 3
    radial_masks = np.apply_along_axis(radial_mask, 1, radii, 128, 128, np.arange(0, 256), np.arange(0, 256), 20)
    print("radial_masks shape:", radial_masks.shape)
    radial_masks = np.expand_dims(radial_masks, 1)  # (128, 1, 256, 256)
    # radial_masks = np.squeeze(np.stack((radial_masks,) * 3, -1))  # 43, 3, 256, 256
    # radial_masks = radial_masks.transpose([0, 3, 1, 2])
    radial_masks = radial_masks.astype(np.complex64)
    radial_masks = tf.expand_dims(radial_masks, 1)
    rn = tf.compat.v1.placeholder_with_default(radial_masks, [128, None, 1, 256, 256])
    rn_split = tf.split(rn, submit_config.num_gpus, axis=1)

    freq_nyq = int(np.floor(int(256) / 2.0))
    spatial_freq = radii.astype(np.float32) / freq_nyq
    spatial_freq = spatial_freq / max(spatial_freq)

    for gpu in range(submit_config.num_gpus):
        with tf.device("/gpu:%d" % gpu):
            net_gpu = net if gpu == 0 else net.clone()

            denoised_1 = net_gpu.get_output_for(noisy_input_split[gpu])
            denoised_2 = net_gpu.get_output_for(noisy_target_split[gpu])
            print(noisy_input_split[gpu].get_shape(), rn_split[gpu].get_shape())

            if noise2noise:
                meansq_error = fourier_ring_correlation(
                    noisy_target_split[gpu], denoised_1, rn_split[gpu], spatial_freq) - fourier_ring_correlation(
                    noisy_target_split[gpu] - denoised_2, noisy_input_split[gpu] - denoised_1,
                    rn_split[gpu], spatial_freq)
            else:
                meansq_error = tf.reduce_mean(tf.square(clean_target_split[gpu] - denoised_1))

            # Create an autosummary that will average over all GPUs
            # tf.summary.histogram(name, var)
            with tf.control_dependencies([autosummary("Loss", meansq_error)]):
                opt.register_gradients(meansq_error, net_gpu.trainables)

    train_step = opt.apply_updates()

    # Create a log file for Tensorboard
    summary_log = tf.compat.v1.summary.FileWriter(submit_config.run_dir)
    summary_log.add_graph(tf.compat.v1.get_default_graph())

    print('Training...')
    time_maintenance = ctx.get_time_since_last_update()
    ctx.update(loss='run %d' % submit_config.run_id, cur_epoch=0, max_epoch=iteration_count)

    # The actual training loop
    for i in range(iteration_count):
        # Whether to stop the training or not should be asked from the context
        if ctx.should_stop():
            break

        # Dump training status
        if i % eval_interval == 0:
            time_train = ctx.get_time_since_last_update()
            time_total = ctx.get_time_since_start()
            print("DEBUG TRAIN!", noisy_input.dtype, noisy_input[0][0].dtype)

            # Evaluate 'x' to draw a batch of inputs
            [source_mb, target_mb] = tfutil.run([noisy_input, clean_target])
            denoised = net.run(source_mb)
            save_image(submit_config, denoised[0], "img_{0}_y_pred.tif".format(i))
            save_image(submit_config, target_mb[0], "img_{0}_y.tif".format(i))
            save_image(submit_config, source_mb[0], "img_{0}_x_aug.tif".format(i))

            validation_set.evaluate(net, i, noise_augmenter.add_validation_noise_np)
            print('iter %-10d time %-12s sec/eval %-7.1f sec/iter %-7.2f maintenance %-6.1f' % (
                autosummary('Timing/iter', i),
                dnnlib.util.format_time(autosummary('Timing/total_sec', time_total)),
                autosummary('Timing/sec_per_eval', time_train),
                autosummary('Timing/sec_per_iter', time_train / eval_interval),
                autosummary('Timing/maintenance_sec', time_maintenance)))

            dnnlib.tflib.autosummary.save_summaries(summary_log, i)
            ctx.update(loss='run %d' % submit_config.run_id, cur_epoch=i, max_epoch=iteration_count)
            time_maintenance = ctx.get_last_update_interval() - time_train
            save_snapshot(submit_config, net, str(i))

        lrate = compute_ramped_down_lrate(i, iteration_count, ramp_down_perc, learning_rate)
        tfutil.run([train_step], {lrate_in: lrate})

    print("Elapsed time: {0}".format(util.format_time(ctx.get_time_since_start())))
    save_snapshot(submit_config, net, 'final')

    # Summary log and context should be closed at the end
    summary_log.close()
    ctx.close()
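
# --- Illustrative sketch (not part of the original training code) -----------
# The loss graph above scores denoising quality with Fourier ring correlation
# (FRC): the normalized correlation between the Fourier transforms of two
# images, accumulated over concentric spatial-frequency rings. The NumPy
# function below restates that idea for reference only; it is not the
# project's fourier_ring_correlation() op and ignores the radial-mask binning
# details used in the graph.
import numpy as np

def frc_numpy(img_a, img_b):
    assert img_a.shape == img_b.shape and img_a.ndim == 2
    h, w = img_a.shape
    fa = np.fft.fftshift(np.fft.fft2(img_a))
    fb = np.fft.fftshift(np.fft.fft2(img_b))
    yy, xx = np.indices((h, w))
    rings = np.hypot(yy - h / 2, xx - w / 2).astype(np.int64)
    n_rings = min(h, w) // 2
    frc = np.zeros(n_rings)
    for r in range(n_rings):
        sel = rings == r
        num = np.abs(np.sum(fa[sel] * np.conj(fb[sel])))
        den = np.sqrt(np.sum(np.abs(fa[sel]) ** 2) * np.sum(np.abs(fb[sel]) ** 2))
        frc[r] = num / (den + 1e-8)
    return frc  # one correlation value per spatial-frequency ring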
def training_loop( G_args = {}, # Options for generator network. D_args = {}, # Options for discriminator network. G_opt_args = {}, # Options for generator optimizer. D_opt_args = {}, # Options for discriminator optimizer. G_loss_args = {}, # Options for generator loss. D_loss_args = {}, # Options for discriminator loss. dataset_args = {}, # Options for dataset.load_dataset(). sched_args = {}, # Options for train.TrainingSchedule. grid_args = {}, # Options for train.setup_snapshot_image_grid(). setname = None, # Model name tf_config = {}, # Options for tflib.init_tf(). G_smoothing_kimg = 10.0, # Half-life of the running average of generator weights. minibatch_repeats = 4, # Number of minibatches to run before adjusting training parameters. lazy_regularization = True, # Perform regularization as a separate training step? G_reg_interval = 4, # How often the perform regularization for G? Ignored if lazy_regularization=False. D_reg_interval = 16, # How often the perform regularization for D? Ignored if lazy_regularization=False. reset_opt_for_new_lod = True, # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced? total_kimg = 25000, # Total length of the training, measured in thousands of real images. mirror_augment = False, # Enable mirror augment? mirror_augment_v = False, # Enable mirror augment vertically? drange_net = [-1,1], # Dynamic range used when feeding image data to the networks. image_snapshot_ticks = 50, # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'. network_snapshot_ticks = 50, # How often to save network snapshots? None = only save 'networks-final.pkl'. save_tf_graph = False, # Include full TensorFlow computation graph in the tfevents file? save_weight_histograms = False, # Include weight histograms in the tfevents file? resume_pkl = 'latest', # Network pickle to resume training from, None = train from scratch. resume_kimg = 0.0, # Assumed training progress at the beginning. Affects reporting and training schedule. resume_time = 0.0, # Assumed wallclock time at the beginning. Affects reporting. restore_partial_fn = None, # Filename of network for partial restore resume_with_new_nets = False): # Construct new networks according to G_args and D_args before resuming training? # Initialize dnnlib and TensorFlow. tflib.init_tf(tf_config) num_gpus = dnnlib.submit_config.num_gpus # Load training set. training_set = dataset.load_dataset(verbose=True, **dataset_args) # custom resolution - for saved model name below resolution = training_set.resolution if training_set.init_res != [4,4]: init_res_str = '-%dx%d' % (training_set.init_res[0], training_set.init_res[1]) else: init_res_str = '' ext = 'png' if training_set.shape[0] == 4 else 'jpg' print(' model base resolution', resolution) grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(training_set, **grid_args) misc.save_image_grid(grid_reals, dnnlib.make_run_dir_path('_reals.%s'%ext), drange=training_set.dynamic_range, grid_size=grid_size) # Construct or load networks. 
with tf.device('/gpu:0'): if resume_pkl is None or resume_with_new_nets: print(' Constructing networks...') G = tflib.Network('G', num_channels=training_set.shape[0], resolution=resolution, label_size=training_set.label_size, **G_args) D = tflib.Network('D', num_channels=training_set.shape[0], resolution=resolution, label_size=training_set.label_size, **D_args) Gs = G.clone('Gs') if resume_pkl is not None: if resume_pkl == 'latest': resume_pkl, resume_kimg = misc.locate_latest_pkl(dnnlib.submit_config.run_dir_root) elif resume_pkl == 'restore_partial': print(' Restore partially...') # Initialize networks G = tflib.Network('G', num_channels=training_set.shape[0], resolution=resolution, label_size=training_set.label_size, **G_args) D = tflib.Network('D', num_channels=training_set.shape[0], resolution=resolution, label_size=training_set.label_size, **D_args) Gs = G.clone('Gs') # Load pre-trained networks assert restore_partial_fn != None G_partial, D_partial, Gs_partial = pickle.load(open(restore_partial_fn, 'rb')) # Restore (subset of) pre-trained weights (only parameters that match both name and shape) G.copy_compatible_trainables_from(G_partial) D.copy_compatible_trainables_from(D_partial) Gs.copy_compatible_trainables_from(Gs_partial) else: if resume_pkl is not None and resume_kimg == 0: resume_pkl, resume_kimg = misc.locate_latest_pkl(resume_pkl) print(' Loading networks from "%s", kimg %.3g' % (resume_pkl, resume_kimg)) rG, rD, rGs = misc.load_pkl(resume_pkl) if resume_with_new_nets: G.copy_vars_from(rG) D.copy_vars_from(rD) Gs.copy_vars_from(rGs) else: G, D, Gs = rG, rD, rGs # Print layers if needed and generate initial image snapshot # G.print_layers(); D.print_layers() sched = training_schedule(cur_nimg=total_kimg*1000, training_set=training_set, **sched_args) grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:]) grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu) misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes_init.%s'%ext), drange=drange_net, grid_size=grid_size) # Setup training inputs. print(' Building TensorFlow graph...') with tf.name_scope('Inputs'), tf.device('/cpu:0'): lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[]) lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[]) minibatch_size_in = tf.placeholder(tf.int32, name='minibatch_size_in', shape=[]) minibatch_gpu_in = tf.placeholder(tf.int32, name='minibatch_gpu_in', shape=[]) minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in * num_gpus) Gs_beta = 0.5 ** tf.div(tf.cast(minibatch_size_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0 # Setup optimizers. G_opt_args = dict(G_opt_args) D_opt_args = dict(D_opt_args) for args, reg_interval in [(G_opt_args, G_reg_interval), (D_opt_args, D_reg_interval)]: args['minibatch_multiplier'] = minibatch_multiplier args['learning_rate'] = lrate_in if lazy_regularization: mb_ratio = reg_interval / (reg_interval + 1) args['learning_rate'] *= mb_ratio if 'beta1' in args: args['beta1'] **= mb_ratio if 'beta2' in args: args['beta2'] **= mb_ratio G_opt = tflib.Optimizer(name='TrainG', **G_opt_args) D_opt = tflib.Optimizer(name='TrainD', **D_opt_args) G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args) D_reg_opt = tflib.Optimizer(name='RegD', share=D_opt, **D_opt_args) # Build training graph for each GPU. 
data_fetch_ops = [] for gpu in range(num_gpus): with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu): # Create GPU-specific shadow copies of G and D. G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow') D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow') # Fetch training data via temporary variables. with tf.name_scope('DataFetch'): sched = training_schedule(cur_nimg=int(resume_kimg*1000), training_set=training_set, **sched_args) reals_var = tf.Variable(name='reals', trainable=False, initial_value=tf.zeros([sched.minibatch_gpu] + training_set.shape)) labels_var = tf.Variable(name='labels', trainable=False, initial_value=tf.zeros([sched.minibatch_gpu, training_set.label_size])) reals_write, labels_write = training_set.get_minibatch_tf() reals_write, labels_write = process_reals(reals_write, labels_write, lod_in, mirror_augment, mirror_augment_v, training_set.dynamic_range, drange_net) reals_write = tf.concat([reals_write, reals_var[minibatch_gpu_in:]], axis=0) labels_write = tf.concat([labels_write, labels_var[minibatch_gpu_in:]], axis=0) data_fetch_ops += [tf.assign(reals_var, reals_write)] data_fetch_ops += [tf.assign(labels_var, labels_write)] reals_read = reals_var[:minibatch_gpu_in] labels_read = labels_var[:minibatch_gpu_in] # Evaluate loss functions. lod_assign_ops = [] if 'lod' in G_gpu.vars: lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)] if 'lod' in D_gpu.vars: lod_assign_ops += [tf.assign(D_gpu.vars['lod'], lod_in)] with tf.control_dependencies(lod_assign_ops): with tf.name_scope('G_loss'): G_loss, G_reg = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, **G_loss_args) with tf.name_scope('D_loss'): D_loss, D_reg = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, reals=reals_read, labels=labels_read, **D_loss_args) # Register gradients. if not lazy_regularization: if G_reg is not None: G_loss += G_reg if D_reg is not None: D_loss += D_reg else: if G_reg is not None: G_reg_opt.register_gradients(tf.reduce_mean(G_reg * G_reg_interval), G_gpu.trainables) if D_reg is not None: D_reg_opt.register_gradients(tf.reduce_mean(D_reg * D_reg_interval), D_gpu.trainables) G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables) D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables) # Setup training ops. data_fetch_op = tf.group(*data_fetch_ops) G_train_op = G_opt.apply_updates() D_train_op = D_opt.apply_updates() G_reg_op = G_reg_opt.apply_updates(allow_no_op=True) D_reg_op = D_reg_opt.apply_updates(allow_no_op=True) Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta) # Finalize graph. 
with tf.device('/gpu:0'): try: peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse() except tf.errors.NotFoundError: peak_gpu_mem_op = tf.constant(0) tflib.init_uninitialized_vars() # print('Initializing logs...') summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path()) if save_tf_graph: summary_log.add_graph(tf.get_default_graph()) if save_weight_histograms: G.setup_weight_histograms(); D.setup_weight_histograms() print(' Training for %d kimg (%d left) \n' % (total_kimg, total_kimg-resume_kimg)) dnnlib.RunContext.get().update('', cur_epoch=resume_kimg, max_epoch=total_kimg) maintenance_time = dnnlib.RunContext.get().get_last_update_interval() cur_nimg = int(resume_kimg * 1000) cur_tick = -1 tick_start_nimg = cur_nimg prev_lod = -1.0 running_mb_counter = 0 while cur_nimg < total_kimg * 1000: if dnnlib.RunContext.get().should_stop(): break # Choose training parameters and configure training ops. sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, **sched_args) assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0 training_set.configure(sched.minibatch_gpu) # , sched.lod if reset_opt_for_new_lod: if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod): G_opt.reset_optimizer_state(); D_opt.reset_optimizer_state() prev_lod = sched.lod # Run training ops. feed_dict = {lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_size_in: sched.minibatch_size, minibatch_gpu_in: sched.minibatch_gpu} for _repeat in range(minibatch_repeats): rounds = range(0, sched.minibatch_size, sched.minibatch_gpu * num_gpus) run_G_reg = (lazy_regularization and running_mb_counter % G_reg_interval == 0) run_D_reg = (lazy_regularization and running_mb_counter % D_reg_interval == 0) cur_nimg += sched.minibatch_size running_mb_counter += 1 # Fast path without gradient accumulation. if len(rounds) == 1: tflib.run([G_train_op, data_fetch_op], feed_dict) if run_G_reg: tflib.run(G_reg_op, feed_dict) tflib.run([D_train_op, Gs_update_op], feed_dict) if run_D_reg: tflib.run(D_reg_op, feed_dict) # Slow path with gradient accumulation. else: for _round in rounds: tflib.run(G_train_op, feed_dict) if run_G_reg: for _round in rounds: tflib.run(G_reg_op, feed_dict) tflib.run(Gs_update_op, feed_dict) for _round in rounds: tflib.run(data_fetch_op, feed_dict) tflib.run(D_train_op, feed_dict) if run_D_reg: for _round in rounds: tflib.run(D_reg_op, feed_dict) # Perform maintenance tasks once per tick. done = (cur_nimg >= total_kimg * 1000) if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done: cur_tick += 1 cur_time = time.time() tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0 tick_start_nimg = cur_nimg tick_time = dnnlib.RunContext.get().get_time_since_last_update() total_time = dnnlib.RunContext.get().get_time_since_start() + resume_time if sched.lod == 0: left_kimg = total_kimg - cur_nimg / 1000 left_sec = left_kimg * tick_time / tick_kimg finaltime = time.asctime(time.localtime(cur_time + left_sec)) msg_final = '%ss left till %s ' % (shortime(left_sec), finaltime[11:16]) else: msg_final = '' # Report progress. 
# print('tick %-4d kimg %-6.1f lod %-5.2f minibch %-3d:%d time %-8s min/tick %-6.3g %s sec/kimg %-7.3g gpumem %-4.1f %d lr %.2g ' % ( print('tick %-4d kimg %-6.1f time %-8s %s min/tick %-6.3g sec/kimg %-7.3g gpumem %-4.1f lr %.2g ' % ( autosummary('Progress/tick', cur_tick), autosummary('Progress/kimg', cur_nimg / 1000.0), # autosummary('Progress/lod', sched.lod), # autosummary('Progress/minibatch', sched.minibatch_size), # autosummary('Progress/minibatch_gpu', sched.minibatch_gpu), dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)), msg_final, autosummary('Timing/min_per_tick', tick_time / 60), autosummary('Timing/sec_per_kimg', tick_time / tick_kimg), # autosummary('Timing/maintenance_sec', maintenance_time), autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30), sched.G_lrate)) autosummary('Timing/total_hours', total_time / (60.0 * 60.0)) autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0)) # Save snapshots. if image_snapshot_ticks is not None and (cur_tick % image_snapshot_ticks == 0 or done): grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu) misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fake-%04d.%s' % (cur_nimg // 1000, ext)), drange=drange_net, grid_size=grid_size) if network_snapshot_ticks is not None and (cur_tick % network_snapshot_ticks == 0 or done): pkl = dnnlib.make_run_dir_path('snapshot-%d-%s%s-%04d.pkl' % (resolution, setname[-1], init_res_str, cur_nimg // 1000)) misc.save_pkl((G, D, Gs), pkl) misc.save_pkl((Gs), dnnlib.make_run_dir_path('%s-%d-%s%s-%04d.pkl' % (setname[:-1], resolution, setname[-1], init_res_str, cur_nimg // 1000))) # Update summaries and RunContext. tflib.autosummary.save_summaries(summary_log, cur_nimg) dnnlib.RunContext.get().update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg) maintenance_time = dnnlib.RunContext.get().get_last_update_interval() - tick_time # Save final snapshot. misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path('snapshot-%d-%s%s-final.pkl' % (resolution, setname[-1], init_res_str))) misc.save_pkl((Gs), dnnlib.make_run_dir_path('%s-%d-%s%s-final.pkl' % (setname[:-1], resolution, setname[-1], init_res_str))) # All done. summary_log.close() training_set.close()
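
# --- Illustrative sketch (not part of the original training code) -----------
# How the loop above decides between its fast and slow paths: the logical
# minibatch is split into "rounds" of (minibatch_gpu * num_gpus) images, and
# gradient accumulation kicks in whenever more than one round is needed.
# The default argument values are assumptions chosen only to show the arithmetic.
def _accumulation_rounds_example(minibatch_size=32, minibatch_gpu=4, num_gpus=4):
    # With these values one logical minibatch needs two rounds of 16 images,
    # so the loop takes the slow path and accumulates gradients across rounds
    # before a single apply_updates() call.
    rounds = range(0, minibatch_size, minibatch_gpu * num_gpus)
    return len(rounds)  # -> 2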
def training_loop(
        G_args={},  # Options for generator network.
        D_args={},  # Options for discriminator network.
        G_opt_args={},  # Options for generator optimizer.
        D_opt_args={},  # Options for discriminator optimizer.
        loss_args={},  # Options for loss.
        dataset_args={},  # Options for dataset.load_dataset().
        sched_args={},  # Options for train.TrainingSchedule.
        grid_args={},  # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],  # Options for metrics.
        metric_args={},  # Options for MetricGroup.
        tf_config={},  # Options for tflib.init_tf().
        ema_start_kimg=None,  # Start of the exponential moving average. Defaults to the half-life period.
        G_ema_kimg=10,  # Half-life of the exponential moving average of generator weights.
        minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
        lazy_regularization=False,  # Perform regularization as a separate training step?
        G_reg_interval=4,  # How often to perform regularization for G? Ignored if lazy_regularization=False.
        D_reg_interval=4,  # How often to perform regularization for D? Ignored if lazy_regularization=False.
        reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg=25000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=False,  # Enable mirror augment?
        drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=2,  # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'.
        network_snapshot_ticks=1,  # How often to save network snapshots? None = only save 'networks-final.pkl'.
        save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_pkl=None,  # Network pickle to resume training from, None = train from scratch.
        resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
        resume_with_new_nets=False):  # Construct new networks according to G_args and D_args before resuming training?

    if ema_start_kimg is None:
        ema_start_kimg = G_ema_kimg

    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(verbose=True, **dataset_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(training_set, **grid_args)
    misc.save_image_grid(grid_reals, dnnlib.make_run_dir_path('reals.png'),
                         drange=training_set.dynamic_range, grid_size=grid_size)

    # Construct or load networks.
    with tf.device('/gpu:0'):
        if resume_pkl is None or resume_with_new_nets:
            print('Constructing networks...')
            G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1],
                              label_size=training_set.label_size, **G_args)
            D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1],
                              label_size=training_set.label_size, **D_args)
            Gs = G.clone('Gs')
        if resume_pkl is not None:
            resume_networks = misc.load_pkl(resume_pkl)
            rG, rD, rGs = resume_networks
            if resume_with_new_nets:
                G.copy_vars_from(rG)
                D.copy_vars_from(rD)
                Gs.copy_vars_from(rGs)
            else:
                G, D, Gs = rG, rD, rGs

    # Print layers and generate initial image snapshot.
    G.print_layers()
    D.print_layers()
    sched = training_schedule(cur_nimg=total_kimg * 1000, training_set=training_set, **sched_args)
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu)
    misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes_init.png'),
                         drange=drange_net, grid_size=grid_size)

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        G_lrate_in = tf.placeholder(tf.float32, name='G_lrate_in', shape=[])
        D_lrate_in = tf.placeholder(tf.float32, name='D_lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32, name='minibatch_size_in', shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32, name='minibatch_gpu_in', shape=[])
        run_D_reg_in = tf.placeholder(tf.bool, name='run_D_reg', shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in * num_gpus)
        Gs_beta_mul_in = tf.placeholder(tf.float32, name='Gs_beta_in', shape=[])
        Gs_beta = 0.5 ** tf.div(tf.cast(minibatch_size_in, tf.float32), G_ema_kimg * 1000.0) if G_ema_kimg > 0.0 else 0.0

    # Setup optimizers.
    G_opt_args = dict(G_opt_args)
    D_opt_args = dict(D_opt_args)
    G_opt_args['learning_rate'] = G_lrate_in
    D_opt_args['learning_rate'] = D_lrate_in
    for args in [G_opt_args, D_opt_args]:
        args['minibatch_multiplier'] = minibatch_multiplier
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', **D_opt_args)

    # Build training graph for each GPU.
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):

            with tf.name_scope('DataFetch'):
                reals_read, labels_read = training_set.get_minibatch_tf()
                reals_read = process_reals(reals_read, lod_in, mirror_augment,
                                           training_set.dynamic_range, drange_net)

            # Create GPU-specific shadow copies of G and D.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')

            # Evaluate loss functions.
            lod_assign_ops = []
            if 'lod' in G_gpu.vars: lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)]
            if 'lod' in D_gpu.vars: lod_assign_ops += [tf.assign(D_gpu.vars['lod'], lod_in)]
            with tf.control_dependencies(lod_assign_ops):
                with tf.name_scope('loss'):
                    G_loss, D_loss, D_reg = dnnlib.util.call_func_by_name(
                        G=G_gpu, D=D_gpu, training_set=training_set, minibatch_size=minibatch_gpu_in,
                        reals=reals_read, real_labels=labels_read, **loss_args)

            # Register gradients.
            if not lazy_regularization:
                if D_reg is not None: D_loss += D_reg
            else:
                if D_reg is not None:
                    D_loss = tf.cond(run_D_reg_in, lambda: D_loss + D_reg * D_reg_interval, lambda: D_loss)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)

    # Setup training ops.
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta * Gs_beta_mul_in)
    with tf.control_dependencies([Gs_update_op]):
        G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list, **metric_args)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        feed_dict = {
            lod_in: sched.lod,
            G_lrate_in: sched.G_lrate,
            D_lrate_in: sched.D_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu,
            Gs_beta_mul_in: 1 if cur_nimg >= ema_start_kimg * 1000 else 0,
        }
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size, sched.minibatch_gpu * num_gpus)
            run_D_reg = (lazy_regularization and running_mb_counter % D_reg_interval == 0)
            feed_dict[run_D_reg_in] = run_D_reg
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            for _ in rounds:
                tflib.run(G_train_op, feed_dict)
                tflib.run(D_train_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start() + resume_time

            # Report progress.
            print('tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f' % (
                autosummary('Progress/tick', cur_tick),
                autosummary('Progress/kimg', cur_nimg / 1000.0),
                autosummary('Progress/lod', sched.lod),
                autosummary('Progress/minibatch', sched.minibatch_size),
                dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                autosummary('Timing/sec_per_tick', tick_time),
                autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                autosummary('Timing/maintenance_sec', maintenance_time),
                autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (cur_tick % image_snapshot_ticks == 0 or done):
                grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu)
                misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net, grid_size=grid_size)
            if network_snapshot_ticks is not None and (cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl, run_dir=dnnlib.make_run_dir_path(), num_gpus=num_gpus, tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get().get_last_update_interval() - tick_time

    # Save final snapshot.
    misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
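
# --- Illustrative sketch (not part of the original training code) -----------
# The Gs_beta expression used in the training loop above picks the per-update
# decay of the generator EMA so that the running average has a half-life of
# G_ema_kimg thousand images regardless of minibatch size. The default
# argument values below are assumptions used only to show the arithmetic.
def _gs_beta_example(G_ema_kimg=10, minibatch_size=32):
    # 0.5 ** (32 / 10000) ~= 0.99778 per 32-image update; after ~10k images
    # the old weights contribute 50%.
    return 0.5 ** (minibatch_size / (G_ema_kimg * 1000.0))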
def training_loop_vc( G_args={}, # Options for generator network. D_args={}, # Options for discriminator network. I_args={}, # Options for infogan-head/vcgan-head network. I_info_args={}, # Options for infogan-head/vcgan-head network. G_opt_args={}, # Options for generator optimizer. D_opt_args={}, # Options for discriminator optimizer. G_loss_args={}, # Options for generator loss. D_loss_args={}, # Options for discriminator loss. dataset_args={}, # Options for dataset.load_dataset(). sched_args={}, # Options for train.TrainingSchedule. grid_args={}, # Options for train.setup_snapshot_image_grid(). metric_arg_list=[], # Options for MetricGroup. tf_config={}, # Options for tflib.init_tf(). use_info_gan=False, # Whether to use info-gan. use_vc_head=False, # Whether to use vc-head. use_vc_head_with_cls=False, # Whether to use classification in discriminator. data_dir=None, # Directory to load datasets from. G_smoothing_kimg=10.0, # Half-life of the running average of generator weights. minibatch_repeats=4, # Number of minibatches to run before adjusting training parameters. lazy_regularization=True, # Perform regularization as a separate training step? G_reg_interval=4, # How often the perform regularization for G? Ignored if lazy_regularization=False. D_reg_interval=16, # How often the perform regularization for D? Ignored if lazy_regularization=False. reset_opt_for_new_lod=True, # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced? total_kimg=25000, # Total length of the training, measured in thousands of real images. mirror_augment=False, # Enable mirror augment? drange_net=[ -1, 1 ], # Dynamic range used when feeding image data to the networks. image_snapshot_ticks=50, # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'. network_snapshot_ticks=50, # How often to save network snapshots? None = only save 'networks-final.pkl'. save_tf_graph=False, # Include full TensorFlow computation graph in the tfevents file? save_weight_histograms=False, # Include weight histograms in the tfevents file? resume_pkl=None, # Network pickle to resume training from, None = train from scratch. resume_kimg=0.0, # Assumed training progress at the beginning. Affects reporting and training schedule. resume_time=0.0, # Assumed wallclock time at the beginning. Affects reporting. resume_with_new_nets=False, # Construct new networks according to G_args and D_args before resuming training? traversal_grid=False, # Used for disentangled representation learning. n_discrete=3, # Number of discrete latents in model. n_continuous=4, # Number of continuous latents in model. n_samples_per=10): # Number of samples for each line in traversal. # Initialize dnnlib and TensorFlow. tflib.init_tf(tf_config) num_gpus = dnnlib.submit_config.num_gpus # Load training set. training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir), verbose=True, **dataset_args) grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid( training_set, **grid_args) misc.save_image_grid(grid_reals, dnnlib.make_run_dir_path('reals.png'), drange=training_set.dynamic_range, grid_size=grid_size) # Construct or load networks. 
with tf.device('/gpu:0'): if resume_pkl is None or resume_with_new_nets: print('Constructing networks...') G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args) D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args) if use_info_gan or use_vc_head or use_vc_head_with_cls: I = tflib.Network('I', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **I_args) if use_vc_head_with_cls: I_info = tflib.Network('I_info', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **I_info_args) Gs = G.clone('Gs') if resume_pkl is not None: print('Loading networks from "%s"...' % resume_pkl) if use_info_gan or use_vc_head: rG, rD, rI, rGs = misc.load_pkl(resume_pkl) elif use_vc_head_with_cls: rG, rD, rI, rI_info, rGs = misc.load_pkl(resume_pkl) else: rG, rD, rGs = misc.load_pkl(resume_pkl) if resume_with_new_nets: G.copy_vars_from(rG) D.copy_vars_from(rD) if use_info_gan or use_vc_head or use_vc_head_with_cls: I.copy_vars_from(rI) if use_vc_head_with_cls: I_info.copy_vars_from(rI_info) Gs.copy_vars_from(rGs) else: G = rG D = rD if use_info_gan or use_vc_head or use_vc_head_with_cls: I = rI if use_vc_head_with_cls: I_info = rI_info Gs = rGs # Print layers and generate initial image snapshot. G.print_layers() D.print_layers() if use_info_gan or use_vc_head or use_vc_head_with_cls: I.print_layers() if use_vc_head_with_cls: I_info.print_layers() # pdb.set_trace() sched = training_schedule(cur_nimg=total_kimg * 1000, training_set=training_set, **sched_args) if traversal_grid: grid_size, grid_latents, grid_labels = get_grid_latents( n_discrete, n_continuous, n_samples_per, G, grid_labels) else: grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:]) print('grid_latents.shape:', grid_latents.shape) print('grid_labels.shape:', grid_labels.shape) # pdb.set_trace() grid_fakes, _ = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu, randomize_noise=False) misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes_init.png'), drange=drange_net, grid_size=grid_size) # Setup training inputs. print('Building TensorFlow graph...') with tf.name_scope('Inputs'), tf.device('/cpu:0'): lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[]) lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[]) minibatch_size_in = tf.placeholder(tf.int32, name='minibatch_size_in', shape=[]) minibatch_gpu_in = tf.placeholder(tf.int32, name='minibatch_gpu_in', shape=[]) minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in * num_gpus) Gs_beta = 0.5**tf.div(tf.cast(minibatch_size_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0 # Setup optimizers. 
G_opt_args = dict(G_opt_args) D_opt_args = dict(D_opt_args) for args, reg_interval in [(G_opt_args, G_reg_interval), (D_opt_args, D_reg_interval)]: args['minibatch_multiplier'] = minibatch_multiplier args['learning_rate'] = lrate_in if lazy_regularization: mb_ratio = reg_interval / (reg_interval + 1) args['learning_rate'] *= mb_ratio if 'beta1' in args: args['beta1'] **= mb_ratio if 'beta2' in args: args['beta2'] **= mb_ratio G_opt = tflib.Optimizer(name='TrainG', **G_opt_args) D_opt = tflib.Optimizer(name='TrainD', **D_opt_args) G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args) D_reg_opt = tflib.Optimizer(name='RegD', share=D_opt, **D_opt_args) # Build training graph for each GPU. data_fetch_ops = [] for gpu in range(num_gpus): with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu): # Create GPU-specific shadow copies of G and D. G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow') D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow') if use_info_gan or use_vc_head or use_vc_head_with_cls: I_gpu = I if gpu == 0 else I.clone(I.name + '_shadow') if use_vc_head_with_cls: I_info_gpu = I_info if gpu == 0 else I_info.clone( I_info.name + '_shadow') # Fetch training data via temporary variables. with tf.name_scope('DataFetch'): sched = training_schedule(cur_nimg=int(resume_kimg * 1000), training_set=training_set, **sched_args) reals_var = tf.Variable( name='reals', trainable=False, initial_value=tf.zeros([sched.minibatch_gpu] + training_set.shape)) labels_var = tf.Variable(name='labels', trainable=False, initial_value=tf.zeros([ sched.minibatch_gpu, training_set.label_size ])) reals_write, labels_write = training_set.get_minibatch_tf() reals_write, labels_write = process_reals( reals_write, labels_write, lod_in, mirror_augment, training_set.dynamic_range, drange_net) reals_write = tf.concat( [reals_write, reals_var[minibatch_gpu_in:]], axis=0) labels_write = tf.concat( [labels_write, labels_var[minibatch_gpu_in:]], axis=0) data_fetch_ops += [tf.assign(reals_var, reals_write)] data_fetch_ops += [tf.assign(labels_var, labels_write)] reals_read = reals_var[:minibatch_gpu_in] labels_read = labels_var[:minibatch_gpu_in] # Evaluate loss functions. lod_assign_ops = [] if 'lod' in G_gpu.vars: lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)] if 'lod' in D_gpu.vars: lod_assign_ops += [tf.assign(D_gpu.vars['lod'], lod_in)] with tf.control_dependencies(lod_assign_ops): with tf.name_scope('G_loss'): if use_info_gan or use_vc_head: G_loss, G_reg, I_loss, _ = dnnlib.util.call_func_by_name( G=G_gpu, D=D_gpu, I=I_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, **G_loss_args) elif use_vc_head_with_cls: G_loss, G_reg, I_loss, I_info_loss = dnnlib.util.call_func_by_name( G=G_gpu, D=D_gpu, I=I_gpu, I_info=I_info_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, **G_loss_args) else: G_loss, G_reg = dnnlib.util.call_func_by_name( G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, **G_loss_args) with tf.name_scope('D_loss'): D_loss, D_reg = dnnlib.util.call_func_by_name( G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, reals=reals_read, labels=labels_read, **D_loss_args) # Register gradients. 
if not lazy_regularization: if G_reg is not None: G_loss += G_reg if D_reg is not None: D_loss += D_reg else: if G_reg is not None: G_reg_opt.register_gradients( tf.reduce_mean(G_reg * G_reg_interval), G_gpu.trainables) if D_reg is not None: D_reg_opt.register_gradients( tf.reduce_mean(D_reg * D_reg_interval), D_gpu.trainables) # print('G_gpu.trainables:', G_gpu.trainables) # print('D_gpu.trainables:', D_gpu.trainables) # print('I_gpu.trainables:', I_gpu.trainables) if use_info_gan or use_vc_head: GI_gpu_trainables = collections.OrderedDict( list(G_gpu.trainables.items()) + list(I_gpu.trainables.items())) G_opt.register_gradients(tf.reduce_mean(G_loss + I_loss), GI_gpu_trainables) D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables) # G_opt.register_gradients(tf.reduce_mean(I_loss), # GI_gpu_trainables) # D_opt.register_gradients(tf.reduce_mean(I_loss), # D_gpu.trainables) elif use_vc_head_with_cls: GIIinfo_gpu_trainables = collections.OrderedDict( list(G_gpu.trainables.items()) + list(I_gpu.trainables.items()) + list(I_info_gpu.trainables.items())) G_opt.register_gradients( tf.reduce_mean(G_loss + I_loss + I_info_loss), GIIinfo_gpu_trainables) D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables) else: G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables) D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables) # if use_info_gan: # # INFO-GAN-HEAD loss # G_opt.register_gradients(tf.reduce_mean(I_loss), # G_gpu.trainables) # G_opt.register_gradients(tf.reduce_mean(I_loss), # I_gpu.trainables) # D_opt.register_gradients(tf.reduce_mean(I_loss), # D_gpu.trainables) # Setup training ops. data_fetch_op = tf.group(*data_fetch_ops) G_train_op = G_opt.apply_updates() D_train_op = D_opt.apply_updates() G_reg_op = G_reg_opt.apply_updates(allow_no_op=True) D_reg_op = D_reg_opt.apply_updates(allow_no_op=True) Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta) # Finalize graph. with tf.device('/gpu:0'): try: peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse() except tf.errors.NotFoundError: peak_gpu_mem_op = tf.constant(0) tflib.init_uninitialized_vars() print('Initializing logs...') summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path()) if save_tf_graph: summary_log.add_graph(tf.get_default_graph()) if save_weight_histograms: G.setup_weight_histograms() D.setup_weight_histograms() if use_info_gan or use_vc_head or use_vc_head_with_cls: I.setup_weight_histograms() if use_vc_head_with_cls: I_info.setup_weight_histograms() metrics = metric_base.MetricGroup(metric_arg_list) print('Training for %d kimg...\n' % total_kimg) dnnlib.RunContext.get().update('', cur_epoch=resume_kimg, max_epoch=total_kimg) maintenance_time = dnnlib.RunContext.get().get_last_update_interval() cur_nimg = int(resume_kimg * 1000) cur_tick = -1 tick_start_nimg = cur_nimg prev_lod = -1.0 running_mb_counter = 0 while cur_nimg < total_kimg * 1000: if dnnlib.RunContext.get().should_stop(): break # Choose training parameters and configure training ops. sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, **sched_args) assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0 training_set.configure(sched.minibatch_gpu, sched.lod) if reset_opt_for_new_lod: if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil( sched.lod) != np.ceil(prev_lod): G_opt.reset_optimizer_state() D_opt.reset_optimizer_state() prev_lod = sched.lod # Run training ops. 
feed_dict = { lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_size_in: sched.minibatch_size, minibatch_gpu_in: sched.minibatch_gpu } for _repeat in range(minibatch_repeats): rounds = range(0, sched.minibatch_size, sched.minibatch_gpu * num_gpus) run_G_reg = (lazy_regularization and running_mb_counter % G_reg_interval == 0) run_D_reg = (lazy_regularization and running_mb_counter % D_reg_interval == 0) cur_nimg += sched.minibatch_size running_mb_counter += 1 # Fast path without gradient accumulation. if len(rounds) == 1: tflib.run([G_train_op, data_fetch_op], feed_dict) if run_G_reg: tflib.run(G_reg_op, feed_dict) tflib.run([D_train_op, Gs_update_op], feed_dict) if run_D_reg: tflib.run(D_reg_op, feed_dict) # Slow path with gradient accumulation. else: for _round in rounds: tflib.run(G_train_op, feed_dict) if run_G_reg: for _round in rounds: tflib.run(G_reg_op, feed_dict) tflib.run(Gs_update_op, feed_dict) for _round in rounds: tflib.run(data_fetch_op, feed_dict) tflib.run(D_train_op, feed_dict) if run_D_reg: for _round in rounds: tflib.run(D_reg_op, feed_dict) # Perform maintenance tasks once per tick. done = (cur_nimg >= total_kimg * 1000) if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done: cur_tick += 1 tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0 tick_start_nimg = cur_nimg tick_time = dnnlib.RunContext.get().get_time_since_last_update() total_time = dnnlib.RunContext.get().get_time_since_start( ) + resume_time # Report progress. print( 'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f' % (autosummary('Progress/tick', cur_tick), autosummary('Progress/kimg', cur_nimg / 1000.0), autosummary('Progress/lod', sched.lod), autosummary('Progress/minibatch', sched.minibatch_size), dnnlib.util.format_time( autosummary('Timing/total_sec', total_time)), autosummary('Timing/sec_per_tick', tick_time), autosummary('Timing/sec_per_kimg', tick_time / tick_kimg), autosummary('Timing/maintenance_sec', maintenance_time), autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30))) autosummary('Timing/total_hours', total_time / (60.0 * 60.0)) autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0)) # Save snapshots. if image_snapshot_ticks is not None and ( cur_tick % image_snapshot_ticks == 0 or done): grid_fakes, _ = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu, randomize_noise=False) misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path( 'fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size) if network_snapshot_ticks is not None and ( cur_tick % network_snapshot_ticks == 0 or done): pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' % (cur_nimg // 1000)) if use_info_gan or use_vc_head: misc.save_pkl((G, D, I, Gs), pkl) elif use_vc_head_with_cls: misc.save_pkl((G, D, I, I_info, Gs), pkl) else: misc.save_pkl((G, D, Gs), pkl) metrics.run(pkl, run_dir=dnnlib.make_run_dir_path(), data_dir=dnnlib.convert_path(data_dir), num_gpus=num_gpus, tf_config=tf_config) # Update summaries and RunContext. metrics.update_autosummaries() tflib.autosummary.save_summaries(summary_log, cur_nimg) dnnlib.RunContext.get().update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg) maintenance_time = dnnlib.RunContext.get( ).get_last_update_interval() - tick_time # Save final snapshot. 
if use_info_gan or use_vc_head: misc.save_pkl((G, D, I, Gs), dnnlib.make_run_dir_path('network-final.pkl')) elif use_vc_head_with_cls: misc.save_pkl((G, D, I, I_info, Gs), dnnlib.make_run_dir_path('network-final.pkl')) else: misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path('network-final.pkl')) # All done. summary_log.close() training_set.close()
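
# --- Illustrative sketch (not part of the original training code) -----------
# In the VC/InfoGAN variants above, the generator and the recognition head I
# are optimized jointly: their trainable-variable dicts are concatenated into
# one OrderedDict and registered under a single combined loss. The toy names
# below are placeholders, not real tflib.Network variables.
import collections

def _merged_trainables_example():
    g_trainables = collections.OrderedDict([("G/w0", "g_var_0"), ("G/w1", "g_var_1")])
    i_trainables = collections.OrderedDict([("I/w0", "i_var_0")])
    # The merged dict is what gets passed to
    # G_opt.register_gradients(tf.reduce_mean(G_loss + I_loss), ...) above.
    return collections.OrderedDict(list(g_trainables.items()) + list(i_trainables.items()))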
def train(submit_config: dnnlib.SubmitConfig, iteration_count: int, eval_interval: int, minibatch_size: int,
          learning_rate: float, ramp_down_perc: float, noise: dict, validation_config: dict,
          train_tfrecords: str, noise2noise: bool):
    noise_augmenter = dnnlib.util.call_func_by_name(**noise)
    validation_set = ValidationSet(submit_config)
    validation_set.load(**validation_config)

    # Create a run context (hides low level details, exposes simple API to manage the run)
    # noinspection PyTypeChecker
    ctx = dnnlib.RunContext(submit_config, config)

    # Initialize TensorFlow graph and session using good default settings
    tfutil.init_tf(config.tf_config)

    dataset_iter = create_dataset(train_tfrecords, minibatch_size, noise_augmenter.add_train_noise_tf)

    # Construct the network using the Network helper class and a function defined in config.net_config
    with tf.device("/gpu:0"):
        net = tflib.Network(**config.net_config)

    # Optionally print layer information
    net.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device("/cpu:0"):
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        noisy_input, noisy_target, clean_target = dataset_iter.get_next()
        noisy_input_split = tf.split(noisy_input, submit_config.num_gpus)
        noisy_target_split = tf.split(noisy_target, submit_config.num_gpus)
        clean_target_split = tf.split(clean_target, submit_config.num_gpus)

    # Define the loss function using the Optimizer helper class, this will take care of multi GPU
    opt = tflib.Optimizer(learning_rate=lrate_in, **config.optimizer_config)

    for gpu in range(submit_config.num_gpus):
        with tf.device("/gpu:%d" % gpu):
            net_gpu = net if gpu == 0 else net.clone()

            denoised = net_gpu.get_output_for(noisy_input_split[gpu])

            if noise2noise:
                meansq_error = tf.reduce_mean(tf.square(noisy_target_split[gpu] - denoised))
            else:
                meansq_error = tf.reduce_mean(tf.square(clean_target_split[gpu] - denoised))

            # Create an autosummary that will average over all GPUs
            with tf.control_dependencies([autosummary("Loss", meansq_error)]):
                opt.register_gradients(meansq_error, net_gpu.trainables)

    train_step = opt.apply_updates()

    # Create a log file for Tensorboard
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    summary_log.add_graph(tf.get_default_graph())

    print('Training...')
    time_maintenance = ctx.get_time_since_last_update()
    ctx.update(loss='run %d' % submit_config.run_id, cur_epoch=0, max_epoch=iteration_count)

    # ***********************************
    # The actual training loop
    for i in range(iteration_count):
        # Whether to stop the training or not should be asked from the context
        if ctx.should_stop():
            break

        # Dump training status
        if i % eval_interval == 0:
            time_train = ctx.get_time_since_last_update()
            time_total = ctx.get_time_since_start()

            # Evaluate 'x' to draw a batch of inputs
            [source_mb, target_mb] = tfutil.run([noisy_input, clean_target])
            denoised = net.run(source_mb)
            save_image(submit_config, denoised[0], "img_{0}_y_pred.png".format(i))
            save_image(submit_config, target_mb[0], "img_{0}_y.png".format(i))
            save_image(submit_config, source_mb[0], "img_{0}_x_aug.png".format(i))

            validation_set.evaluate(net, i, noise_augmenter.add_validation_noise_np)
            print('iter %-10d time %-12s eta %-12s sec/eval %-7.1f sec/iter %-7.2f maintenance %-6.1f' % (
                autosummary('Timing/iter', i),
                dnnlib.util.format_time(autosummary('Timing/total_sec', time_total)),
                dnnlib.util.format_time(autosummary('Timing/total_sec', (time_train / eval_interval) * (iteration_count - i))),
                autosummary('Timing/sec_per_eval', time_train),
                autosummary('Timing/sec_per_iter', time_train / eval_interval),
                autosummary('Timing/maintenance_sec', time_maintenance)))

            dnnlib.tflib.autosummary.save_summaries(summary_log, i)
            ctx.update(loss='run %d' % submit_config.run_id, cur_epoch=i, max_epoch=iteration_count)
            time_maintenance = ctx.get_last_update_interval() - time_train

        # Training epoch
        lrate = compute_ramped_down_lrate(i, iteration_count, ramp_down_perc, learning_rate)
        tfutil.run([train_step], {lrate_in: lrate})

    # End of training
    print("Elapsed time: {0}".format(util.format_time(ctx.get_time_since_start())))
    save_snapshot(submit_config, net, 'final')

    # Summary log and context should be closed at the end
    summary_log.close()
    ctx.close()
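
# --- Illustrative sketch (not part of the original training code) -----------
# compute_ramped_down_lrate() (referenced above) keeps the learning rate
# constant for most of the run and then decays it over the final
# ramp_down_perc fraction of iterations. The cosine shape and the assumption
# that ramp_down_perc is a fraction in [0, 1) are illustrative only, not a
# claim about the project's exact schedule.
import numpy as np

def ramped_down_lrate_sketch(i, iteration_count, ramp_down_perc, learning_rate):
    ramp_down_start = iteration_count * (1.0 - ramp_down_perc)
    if i < ramp_down_start:
        return learning_rate
    t = (i - ramp_down_start) / (iteration_count - ramp_down_start)
    return learning_rate * (0.5 + 0.5 * np.cos(t * np.pi))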
def training_loop_refinement( G_args={}, # Options for generator network. D_args={}, # Options for discriminator network. G_opt_args={}, # Options for generator optimizer. D_opt_args={}, # Options for discriminator optimizer. G_loss_args={}, # Options for generator loss. D_loss_args={}, # Options for discriminator loss. dataset_args={}, # Options for dataset.load_dataset(). sched_args={}, # Options for train.TrainingSchedule. grid_args={}, # Options for train.setup_snapshot_image_grid(). metric_arg_list=[], # Options for MetricGroup. tf_config={}, # Options for tflib.init_tf(). data_dir=None, # Directory to load datasets from. G_smoothing_kimg=10.0, # Half-life of the running average of generator weights. minibatch_repeats=4, # Number of minibatches to run before adjusting training parameters. lazy_regularization=True, # Perform regularization as a separate training step? G_reg_interval=4, # How often the perform regularization for G? Ignored if lazy_regularization=False. D_reg_interval=16, # How often the perform regularization for D? Ignored if lazy_regularization=False. reset_opt_for_new_lod=True, # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced? total_kimg=25000, # Total length of the training, measured in thousands of real images. mirror_augment=False, # Enable mirror augment? drange_net=[ -1, 1 ], # Dynamic range used when feeding image data to the networks. image_snapshot_ticks=50, # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'. network_snapshot_ticks=50, # How often to save network snapshots? None = only save 'networks-final.pkl'. save_tf_graph=False, # Include full TensorFlow computation graph in the tfevents file? save_weight_histograms=True, # Include weight histograms in the tfevents file? resume_pkl=None, # Network pickle to resume training from, None = train from scratch. resume_kimg=0.0, # Assumed training progress at the beginning. Affects reporting and training schedule. resume_time=0.0, # Assumed wallclock time at the beginning. Affects reporting. resume_with_new_nets=False ): # Construct new networks according to G_args and D_args before resuming training? # Initialize dnnlib and TensorFlow. tflib.init_tf(tf_config) num_gpus = dnnlib.submit_config.num_gpus # Load training set. training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir), verbose=True, **dataset_args) grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid( training_set, **grid_args) misc.save_image_grid(grid_reals, dnnlib.make_run_dir_path('reals.png'), drange=training_set.dynamic_range, grid_size=grid_size) # Construct or load networks. with tf.device('/gpu:0'): if resume_pkl is None or resume_with_new_nets: print('Constructing networks...') G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args) Gs = G.clone('Gs') if resume_pkl is not None: print('Loading networks from "%s"...' 
% resume_pkl) _rG, _rD, rGs = misc.load_pkl(resume_pkl) del _rD, _rG if resume_with_new_nets: G.copy_vars_from(rGs) Gs.copy_vars_from(rGs) del rGs else: G = rG Gs = rGs # Set constant noise input for both G and Gs if G_args.get("randomize_noise", None) == False: noise_vars = [ var for name, var in G.components.synthesis.vars.items() if name.startswith('noise') ] rnd = np.random.RandomState(123) tflib.set_vars( {var: rnd.randn(*var.shape.as_list()) for var in noise_vars}) # [height, width] noise_vars = [ var for name, var in Gs.components.synthesis.vars.items() if name.startswith('noise') ] rnd = np.random.RandomState(123) tflib.set_vars( {var: rnd.randn(*var.shape.as_list()) for var in noise_vars}) # [height, width] # TESTS # from PIL import Image # reals, latents = training_set.get_minibatch_np(4) # reals = np.transpose(reals, [0, 2, 3, 1]) # Image.fromarray(reals[0], 'RGB').save("test_reals.png") # labels = training_set.get_random_labels_np(4) # Gs_kwargs = dnnlib.EasyDict() # Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True) # fakes = Gs.run(latents, labels, minibatch_size=4, **Gs_kwargs) # Image.fromarray(fakes[0], 'RGB').save("test_fakes_Gs_new.png") # fakes = G.run(latents, labels, minibatch_size=4, **Gs_kwargs) # Image.fromarray(fakes[0], 'RGB').save("test_fakes_G_new.png") # Print layers and generate initial image snapshot. G.print_layers() sched = training_schedule(cur_nimg=total_kimg * 1000, training_set=training_set, **sched_args) grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:]) grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu) misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes_init.png'), drange=drange_net, grid_size=grid_size) # Setup training inputs. print('Building TensorFlow graph...') with tf.name_scope('Inputs'), tf.device('/cpu:0'): lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[]) lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[]) minibatch_size_in = tf.placeholder(tf.int32, name='minibatch_size_in', shape=[]) minibatch_gpu_in = tf.placeholder(tf.int32, name='minibatch_gpu_in', shape=[]) minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in * num_gpus) Gs_beta = 0.5**tf.div(tf.cast(minibatch_size_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0 # Setup optimizers. G_opt_args = dict(G_opt_args) for args, reg_interval in [(G_opt_args, G_reg_interval)]: args['minibatch_multiplier'] = minibatch_multiplier args['learning_rate'] = lrate_in if lazy_regularization: mb_ratio = reg_interval / (reg_interval + 1) args['learning_rate'] *= mb_ratio if 'beta1' in args: args['beta1'] **= mb_ratio if 'beta2' in args: args['beta2'] **= mb_ratio G_opt = tflib.Optimizer(name='TrainG', **G_opt_args) G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args) # Freeze layers G_args.freeze_layers = list(G_args.get("freeze_layers", [])) def freeze_vars(gen, verbose=True): assert len(G_args.freeze_layers) > 0 for name in list(gen.trainables.keys()): if any(layer in name for layer in G_args.freeze_layers): del gen.trainables[name] if verbose: print(f"Freezed {name}") # Build training graph for each GPU. data_fetch_ops = [] loss_ops = [] for gpu in range(num_gpus): with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu): # Create GPU-specific shadow copies of G and D. 
G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow') if G_args.freeze_layers: freeze_vars(G_gpu, verbose=False) # Fetch training data via temporary variables. with tf.name_scope('DataFetch'): sched = training_schedule(cur_nimg=int(resume_kimg * 1000), training_set=training_set, **sched_args) reals_var = tf.Variable( name='reals', trainable=False, initial_value=tf.zeros([sched.minibatch_gpu] + training_set.shape)) labels_var = tf.Variable(name='labels', trainable=False, initial_value=tf.zeros([ sched.minibatch_gpu, training_set.label_size ])) reals_write, labels_write = training_set.get_minibatch_tf() reals_write, labels_write = process_reals( reals_write, labels_write, lod_in, mirror_augment, training_set.dynamic_range, drange_net) reals_write = tf.concat( [reals_write, reals_var[minibatch_gpu_in:]], axis=0) labels_write = tf.concat( [labels_write, labels_var[minibatch_gpu_in:]], axis=0) data_fetch_ops += [tf.assign(reals_var, reals_write)] data_fetch_ops += [tf.assign(labels_var, labels_write)] reals_read = reals_var[:minibatch_gpu_in] labels_read = labels_var[:minibatch_gpu_in] # Evaluate loss functions. lod_assign_ops = [] if 'lod' in G_gpu.vars: lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)] with tf.control_dependencies(lod_assign_ops): with tf.name_scope('G_loss'): G_loss, G_reg = dnnlib.util.call_func_by_name( G=G_gpu, D=None, opt=G_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, reals=reals_read, latents=labels_read, **G_loss_args) loss_ops.append(G_loss) # Register gradients. if not lazy_regularization: if G_reg is not None: G_loss += G_reg else: if G_reg is not None: G_reg_opt.register_gradients( tf.reduce_mean(G_reg * G_reg_interval), G_gpu.trainables) G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables) # Setup training ops. data_fetch_op = tf.group(*data_fetch_ops) loss_op = tf.reduce_mean(tf.concat(loss_ops, axis=0)) G_train_op = G_opt.apply_updates() G_reg_op = G_reg_opt.apply_updates(allow_no_op=True) Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta) # Finalize graph. with tf.device('/gpu:0'): try: peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse() except tf.errors.NotFoundError: peak_gpu_mem_op = tf.constant(0) tflib.init_uninitialized_vars() print('Initializing logs...') summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path()) if save_tf_graph: summary_log.add_graph(tf.get_default_graph()) if save_weight_histograms: G.setup_weight_histograms() metrics = metric_base.MetricGroup(metric_arg_list) print('Training for %d kimg...\n' % total_kimg) dnnlib.RunContext.get().update('', cur_epoch=resume_kimg, max_epoch=total_kimg) maintenance_time = dnnlib.RunContext.get().get_last_update_interval() cur_nimg = int(resume_kimg * 1000) cur_tick = -1 tick_start_nimg = cur_nimg prev_lod = -1.0 running_mb_counter = 0 loss_per_batch_sum = 0 while cur_nimg < total_kimg * 1000: if dnnlib.RunContext.get().should_stop(): break # Choose training parameters and configure training ops. sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, **sched_args) assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0 training_set.configure(sched.minibatch_gpu, sched.lod) if reset_opt_for_new_lod: if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil( sched.lod) != np.ceil(prev_lod): G_opt.reset_optimizer_state() prev_lod = sched.lod # Run training ops. 
feed_dict = { lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_size_in: sched.minibatch_size, minibatch_gpu_in: sched.minibatch_gpu } tflib.run(data_fetch_op, feed_dict) ### TEST # fakes = G.get_output_for(labels_read, training_set.get_random_labels_tf(minibatch_gpu_in), is_training=True) # this is without activation in ~[-1.5, 1.5] # fakes = tf.clip_by_value(fakes, drange_net[0], drange_net[1]) # reals = reals_read ### TEST for _repeat in range(minibatch_repeats): rounds = range(0, sched.minibatch_size, sched.minibatch_gpu * num_gpus) run_G_reg = (lazy_regularization and running_mb_counter % G_reg_interval == 0) cur_nimg += sched.minibatch_size running_mb_counter += 1 # Fast path without gradient accumulation. if len(rounds) == 1: loss, _ = tflib.run([loss_op, G_train_op], feed_dict) # (loss, reals, fakes), _ = tflib.run([loss_op, G_train_op], feed_dict) tflib.run([data_fetch_op], feed_dict) # print(f"loss_tf {np.mean(loss)}") # print(f"loss_np {np.mean(np.square(reals - fakes))}") # print(f"loss_abs {np.mean(np.abs(reals - fakes))}") loss_per_batch_sum += loss #### TEST #### # if cur_nimg == sched.minibatch_size or cur_nimg % 2048 == 0: # from PIL import Image # reals = np.transpose(reals, [0, 2, 3, 1]) # fakes = np.transpose(fakes, [0, 2, 3, 1]) # diff = np.abs(reals - fakes) # print(diff.min(), diff.max()) # for idx, (fake, real) in enumerate(zip(fakes, reals)): # fake -= fake.min() # fake /= fake.max() # fake *= 255 # fake = fake.astype(np.uint8) # Image.fromarray(fake, 'RGB').save(f"fake_loss_{idx}.png") # real -= real.min() # real /= real.max() # real *= 255 # real = real.astype(np.uint8) # Image.fromarray(real, 'RGB').save(f"real_loss_{idx}.png") #### if run_G_reg: tflib.run(G_reg_op, feed_dict) tflib.run([Gs_update_op], feed_dict) # Slow path with gradient accumulation. FIXME: Probably wrong else: for _round in rounds: loss, _, _ = tflib.run( [loss_op, G_train_op, data_fetch_op], feed_dict) loss_per_batch_sum += loss / len(rounds) if run_G_reg: tflib.run(G_reg_op, feed_dict) tflib.run(Gs_update_op, feed_dict) # Perform maintenance tasks once per tick. done = (cur_nimg >= total_kimg * 1000) if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done: cur_tick += 1 tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0 tick_start_nimg = cur_nimg tick_time = dnnlib.RunContext.get().get_time_since_last_update() total_time = dnnlib.RunContext.get().get_time_since_start( ) + resume_time tick_loss = loss_per_batch_sum * sched.minibatch_size / ( tick_kimg * 1000) loss_per_batch_sum = 0 # Report progress. print( 'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d loss/px %-12.8f time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f' % (autosummary('Progress/tick', cur_tick), autosummary('Progress/kimg', cur_nimg / 1000.0), autosummary('Progress/lod', sched.lod), autosummary('Progress/minibatch', sched.minibatch_size), autosummary('Progress/loss_per_px', tick_loss), dnnlib.util.format_time( autosummary('Timing/total_sec', total_time)), autosummary('Timing/sec_per_tick', tick_time), autosummary('Timing/sec_per_kimg', tick_time / tick_kimg), autosummary('Timing/maintenance_sec', maintenance_time), autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30))) autosummary('Timing/total_hours', total_time / (60.0 * 60.0)) autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0)) # Save snapshots. 
if image_snapshot_ticks is not None and ( cur_tick % image_snapshot_ticks == 0 or done): grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu) misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path( 'fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size) if network_snapshot_ticks is not None and ( cur_tick % network_snapshot_ticks == 0 or done): pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' % (cur_nimg // 1000)) misc.save_pkl((G, None, Gs), pkl) metrics.run(pkl, run_dir=dnnlib.make_run_dir_path(), data_dir=dnnlib.convert_path(data_dir), num_gpus=num_gpus, tf_config=tf_config) # Update summaries and RunContext. metrics.update_autosummaries() tflib.autosummary.save_summaries(summary_log, cur_nimg) dnnlib.RunContext.get().update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg) maintenance_time = dnnlib.RunContext.get( ).get_last_update_interval() - tick_time # Save final snapshot. misc.save_pkl((G, None, Gs), dnnlib.make_run_dir_path('network-final.pkl')) # All done. summary_log.close() training_set.close()
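# The lazy-regularization bookkeeping in training_loop_refinement() rescales the optimizer
# hyperparameters so that running the regularizer only every reg_interval minibatches keeps the
# effective update strength unchanged. A minimal, framework-free sketch of that rescaling
# (the helper name rescale_for_lazy_reg is illustrative, not part of the original code):
def rescale_for_lazy_reg(opt_args, reg_interval):
    """Return a copy of Adam-style optimizer kwargs adjusted for lazy regularization."""
    args = dict(opt_args)
    mb_ratio = reg_interval / (reg_interval + 1)   # fraction of steps that are "main" steps
    args['learning_rate'] = args.get('learning_rate', 0.002) * mb_ratio
    for beta in ('beta1', 'beta2'):
        if beta in args:
            args[beta] = args[beta] ** mb_ratio    # shorten the moment half-life accordingly
    return args

# Example: with G_reg_interval=4 the learning rate is scaled by 4/5 and beta2=0.99 becomes 0.99**0.8.
print(rescale_for_lazy_reg({'learning_rate': 0.002, 'beta1': 0.0, 'beta2': 0.99}, reg_interval=4))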
def training_loop( submit_config, Encoder_args={}, E_opt_args={}, D_opt_args={}, E_loss_args={}, D_loss_args={}, lr_args=EasyDict(), tf_config={}, dataset_args=EasyDict(), decoder_pkl=EasyDict(), drange_data=[0, 255], drange_net=[ -1, 1 ], # Dynamic range used when feeding image data to the networks. mirror_augment=False, resume_run_id=None, # Run ID or network pkl to resume training from, None = start from scratch. resume_snapshot=None, # Snapshot index to resume training from, None = autodetect. image_snapshot_ticks=1, # How often to export image snapshots? network_snapshot_ticks=10, # How often to export network snapshots? save_tf_graph=False, # Include full TensorFlow computation graph in the tfevents file? save_weight_histograms=False, # Include weight histograms in the tfevents file? max_iters=150000, E_smoothing=0.999): tflib.init_tf(tf_config) with tf.name_scope('input'): real_train = tf.placeholder(tf.float32, [ submit_config.batch_size, 3, submit_config.image_size, submit_config.image_size ], name='real_image_train') real_test = tf.placeholder(tf.float32, [ submit_config.batch_size_test, 3, submit_config.image_size, submit_config.image_size ], name='real_image_test') real_split = tf.split(real_train, num_or_size_splits=submit_config.num_gpus, axis=0) with tf.device('/gpu:0'): if resume_run_id is not None: network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot) print('Loading networks from "%s"...' % network_pkl) E, G, D, Gs, NE = misc.load_pkl(network_pkl) start = int(network_pkl.split('-')[-1].split('.') [0]) // submit_config.batch_size else: print('Constructing networks...') G, D, Gs, NE = misc.load_pkl(decoder_pkl.decoder_pkl) E = tflib.Network('E', size=submit_config.image_size, filter=64, filter_max=1024, phase=True, **Encoder_args) start = 0 Gs.print_layers() E.print_layers() D.print_layers() global_step = tf.Variable(start, trainable=False, name='learning_rate_step') learning_rate = tf.train.exponential_decay(lr_args.learning_rate, global_step, lr_args.decay_step, lr_args.decay_rate, staircase=lr_args.stair) add_global = global_step.assign_add(1) E_opt = tflib.Optimizer(name='TrainE', learning_rate=learning_rate, **E_opt_args) D_opt = tflib.Optimizer(name='TrainD', learning_rate=learning_rate, **D_opt_args) E_loss_rec = 0. E_loss_adv = 0. D_loss_real = 0. D_loss_fake = 0. D_loss_grad = 0. 
for gpu in range(submit_config.num_gpus): print('build graph on gpu %s' % str(gpu)) with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu): E_gpu = E if gpu == 0 else E.clone(E.name + '_shadow') D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow') G_gpu = Gs if gpu == 0 else Gs.clone(Gs.name + '_shadow') perceptual_model = PerceptualModel( img_size=[submit_config.image_size, submit_config.image_size], multi_layers=False) real_gpu = process_reals(real_split[gpu], mirror_augment, drange_data, drange_net) with tf.name_scope('E_loss'), tf.control_dependencies(None): E_loss, recon_loss, adv_loss = dnnlib.util.call_func_by_name( E=E_gpu, G=G_gpu, D=D_gpu, perceptual_model=perceptual_model, reals=real_gpu, **E_loss_args) E_loss_rec += recon_loss E_loss_adv += adv_loss with tf.name_scope('D_loss'), tf.control_dependencies(None): D_loss, loss_fake, loss_real, loss_gp = dnnlib.util.call_func_by_name( E=E_gpu, G=G_gpu, D=D_gpu, reals=real_gpu, **D_loss_args) D_loss_real += loss_real D_loss_fake += loss_fake D_loss_grad += loss_gp with tf.control_dependencies([add_global]): E_opt.register_gradients(E_loss, E_gpu.trainables) D_opt.register_gradients(D_loss, D_gpu.trainables) E_loss_rec /= submit_config.num_gpus E_loss_adv /= submit_config.num_gpus D_loss_real /= submit_config.num_gpus D_loss_fake /= submit_config.num_gpus D_loss_grad /= submit_config.num_gpus E_train_op = E_opt.apply_updates() D_train_op = D_opt.apply_updates() #Es_update_op = Es.setup_as_moving_average_of(E, beta=E_smoothing) print('building testing graph...') fake_X_val = test(E, Gs, real_test, submit_config) sess = tf.get_default_session() print('Getting training data...') image_batch_train = get_train_data(sess, data_dir=dataset_args.data_train, submit_config=submit_config, mode='train') image_batch_test = get_train_data(sess, data_dir=dataset_args.data_test, submit_config=submit_config, mode='test') summary_log = tf.summary.FileWriter(submit_config.run_dir) if save_tf_graph: summary_log.add_graph(tf.get_default_graph()) if save_weight_histograms: E.setup_weight_histograms() D.setup_weight_histograms() cur_nimg = start * submit_config.batch_size cur_tick = 0 tick_start_nimg = cur_nimg start_time = time.time() print('Optimization starts!!!') for it in range(start, max_iters): feed_dict = {real_train: sess.run(image_batch_train)} sess.run([E_train_op, E_loss_rec, E_loss_adv], feed_dict) sess.run([D_train_op, D_loss_real, D_loss_fake, D_loss_grad], feed_dict) cur_nimg += submit_config.batch_size if it % 100 == 0: print("Iter: %06d kimg: %-8.1f time: %-12s" % (it, cur_nimg / 1000, dnnlib.util.format_time(time.time() - start_time))) sys.stdout.flush() tflib.autosummary.save_summaries(summary_log, it) if cur_nimg >= tick_start_nimg + 65000: cur_tick += 1 tick_start_nimg = cur_nimg if cur_tick % image_snapshot_ticks == 0: batch_images_test = sess.run(image_batch_test) batch_images_test = misc.adjust_dynamic_range( batch_images_test.astype(np.float32), [0, 255], [-1., 1.]) samples2 = sess.run(fake_X_val, feed_dict={real_test: batch_images_test}) samples2 = samples2.transpose(0, 2, 3, 1) batch_images_test = batch_images_test.transpose(0, 2, 3, 1) orin_recon = np.concatenate([batch_images_test, samples2], axis=0) imwrite(immerge(orin_recon, 2, submit_config.batch_size_test), '%s/iter_%08d.png' % (submit_config.run_dir, cur_nimg)) if cur_tick % network_snapshot_ticks == 0: pkl = os.path.join(submit_config.run_dir, 'network-snapshot-%08d.pkl' % (cur_nimg)) misc.save_pkl((E, G, D, Gs, NE), pkl) misc.save_pkl((E, G, D, Gs, NE), 
os.path.join(submit_config.run_dir, 'network-final.pkl')) summary_log.close()
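# Both encoder loops convert uint8 images in [0, 255] to the network range [-1, 1] before feeding
# them (misc.adjust_dynamic_range / process_reals). A small NumPy sketch of that linear remapping,
# assuming only the two ranges involved; adjust_range is an illustrative name:
import numpy as np

def adjust_range(x, drange_in=(0, 255), drange_out=(-1, 1)):
    """Linearly map x from drange_in to drange_out."""
    scale = (drange_out[1] - drange_out[0]) / (drange_in[1] - drange_in[0])
    return (x - drange_in[0]) * scale + drange_out[0]

batch = np.random.randint(0, 256, size=(4, 3, 8, 8)).astype(np.float32)
assert adjust_range(batch).min() >= -1.0 and adjust_range(batch).max() <= 1.0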
def __init__( self, lr, walk_type, nsliders, loss_type, eps, N_f, stylegan_opts, is_train=False, submit_config=None, G_args={}, # Options for generator network. D_args={}, # Options for discriminator network. G_opt_args={}, # Options for generator optimizer. D_opt_args={}, # Options for discriminator optimizer. G_loss_args={}, # Options for generator loss. D_loss_args={}, # Options for discriminator loss. dataset_args={}, # Options for dataset.load_dataset(). metric_arg_list=[], # Options for MetricGroup. tf_config={}, # Options for tflib.init_tf(). G_smoothing_kimg=10.0, # Half-life of the running average of generator weights. D_repeats=1, # How many times the discriminator is trained per G iteration. minibatch_repeats=4, # Number of minibatches to run before adjusting training parameters. reset_opt_for_new_lod=True, # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced? total_kimg=15000, # Total length of the training, measured in thousands of real images. mirror_augment=False, # Enable mirror augment? drange_net=[-1, 1], # Dynamic range used when feeding image data to the networks. *args, **kwargs): assert (loss_type in ['l2', 'lpips']), 'unimplemented loss' assert (stylegan_opts.latent in ['z', 'w']), 'unknown latent space' self.dataset_name = stylegan_opts.dataset self.dataset_args = constants.net_info[stylegan_opts.dataset] self.latent = stylegan_opts.latent self.is_train = is_train self.walk_type = walk_type self.N_f = N_f # NN num_steps self.eps = eps # NN step_size self.Nsliders = nsliders if hasattr(stylegan_opts, 'truncation_psi'): self.psi = stylegan_opts.truncation_psi else: self.psi = 1.0 tflib.init_tf() with tf.device('/gpu:0'): with dnnlib.util.open_url(self.dataset_args['url'], cache_dir=config.cache_dir) as f: # can only unpickle where dnnlib is importable, so add to syspath G, D, Gs = pickle.load(f) # input placeholders Nsliders = nsliders dim_z = self.dim_z = Gs.input_shape[1] z = self.z = tf.placeholder(tf.float32, shape=(None, dim_z)) if is_train: # training mode training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args) # G.print_layers(); D.print_layers() # Build the computation graph and optimizers. print('Building TensorFlow graph...') with tf.name_scope('Inputs'), tf.device('/cpu:0'): lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[]) # tf.placeholder acts like a formal parameter: it declares a graph input that is fed a concrete value at execution time. lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[]) minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[]) minibatch_split = minibatch_in // submit_config.num_gpus Gs_beta = 0.5**tf.div( tf.cast(minibatch_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0 G_opt = tflib.Optimizer(name='TrainG', learning_rate=lrate_in, **G_opt_args) #D_opt = tflib.Optimizer(name='TrainD', learning_rate=lrate_in, **D_opt_args) for gpu in range(submit_config.num_gpus): with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu): G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow') #D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow') lod_assign_ops = [ tf.assign(G_gpu.find_var('lod'), lod_in), tf.assign(D.find_var('lod'), lod_in) ] #reals, labels = training_set.get_minibatch_tf() #reals = process_reals(reals, lod_in, mirror_augment, training_set.dynamic_range, drange_net) self.steer(G_gpu, gpu_scope='GPU%d/' % gpu) with tf.name_scope('G_loss'), tf.control_dependencies( lod_assign_ops): # G_loss = dnnlib.util.call_func_by_name(G=G_gpu, D=D, # steer_z=self.latent_space_new_reshape, # steer_x=self.target, # latent=self.latent, # **G_loss_args ) G_loss = dnnlib.util.call_func_by_name( G=G_gpu, D=D, opt=G_opt, training_set=training_set, minibatch_size=minibatch_split, **G_loss_args) # with tf.name_scope('D_loss'), tf.control_dependencies(lod_assign_ops): # D_loss = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_split, reals=reals, labels=labels, **D_loss_args) if loss_type == 'l2': self.joint_loss = G_loss + self.loss else: 
self.joint_loss = G_loss + self.loss_lpips G_opt.register_gradients(tf.reduce_mean(self.joint_loss), G_gpu.trainables) # D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables) train_step = tf.train.AdamOptimizer(lr).minimize( self.joint_loss, var_list=tf.trainable_variables(scope=self.scope), name='AdamOpter') G_train_op = G_opt.apply_updates() # D_train_op = D_opt.apply_updates() Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta) self.G_train_op = G_train_op self.Gs_update_op = Gs_update_op self.G = G self.D = D self.Gs = Gs self.G_opt = G_opt self.lod_in = lod_in self.lrate_in = lrate_in self.minibatch_in = minibatch_in else: self.steer(G) # set the scope to be 'walk' if loss_type == 'l2': train_step = tf.train.AdamOptimizer(lr).minimize( self.loss, var_list=tf.trainable_variables(scope=self.scope), name='AdamOpter') elif loss_type == 'lpips': train_step = tf.train.AdamOptimizer(lr).minimize( self.loss_lpips, var_list=tf.trainable_variables(scope=self.scope), name='AdamOpter') self.train_step = train_step
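# Several of these loops keep Gs as an exponential moving average of G, with the decay chosen so
# that the average has a half-life of G_smoothing_kimg thousand images (the Gs_beta expression
# above). A plain-Python sketch of that per-step decay and the corresponding blend; the helper
# names are illustrative:
def ema_beta(minibatch_size, g_smoothing_kimg=10.0):
    """Per-minibatch EMA decay giving a half-life of g_smoothing_kimg kimg."""
    if g_smoothing_kimg <= 0.0:
        return 0.0
    return 0.5 ** (minibatch_size / (g_smoothing_kimg * 1000.0))

def ema_update(avg_value, new_value, beta):
    """Blend a running average towards the current value; beta keeps the old average."""
    return beta * avg_value + (1.0 - beta) * new_value

# With minibatch_size=32 and a 10-kimg half-life, beta is about 0.9978 per step.
print(ema_beta(32))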
def training_loop( submit_config, HP_args={}, # Options for the Hessian Penalty. G_args={}, # Options for generator network. D_args={}, # Options for discriminator network. G_opt_args={}, # Options for generator optimizer. D_opt_args={}, # Options for discriminator optimizer. G_loss_args={}, # Options for generator loss. D_loss_args={}, # Options for discriminator loss. dataset_args={}, # Options for dataset.load_dataset(). sched_args={}, # Options for train.TrainingSchedule. grid_args={}, # Options for train.setup_snapshot_image_grid(). metric_arg_list=[], # Options for MetricGroup. tf_config={}, # Options for tflib.init_tf(). G_smoothing_kimg=10.0, # Half-life of the running average of generator weights. D_repeats=1, # How many times the discriminator is trained per G iteration. minibatch_repeats=4, # Number of minibatches to run before adjusting training parameters. reset_opt_for_new_lod=True, # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced? total_kimg=15000, # Total length of the training, measured in thousands of real images. mirror_augment=False, # Enable mirror augment? drange_net=[ -1, 1 ], # Dynamic range used when feeding image data to the networks. image_snapshot_ticks=1, # How often to export image snapshots? interp_snapshot_ticks=20, # How often to generate interpolation visualizations in TensorBoard? network_snapshot_ticks=5, # How often to export network snapshots? network_metric_ticks=5, # How often to evaluate network snapshots on specified metrics? save_tf_graph=False, # Include full TensorFlow computation graph in the tfevents file? save_weight_histograms=False, # Include weight histograms in the tfevents file? resume_run_id=None, # Run ID or network pkl to resume training from, None = start from scratch. resume_snapshot=None, # Snapshot index to resume training from, None = autodetect. resume_kimg=0.0, # Assumed training progress at the beginning. Affects reporting and training schedule. resume_time=0.0 ): # Assumed wallclock time at the beginning. Affects reporting. # Initialize dnnlib and TensorFlow. ctx = dnnlib.RunContext(submit_config, train) tflib.init_tf(tf_config) # Load training set. training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args) # Create a copy of dataset_args for running the metrics: metrics_dataset_args = deepcopy(dataset_args) metrics_dataset_args.shuffle_mb = 0 print('Saving interp videos every %s ticks' % interp_snapshot_ticks) print('Saving network snapshot every %s ticks' % network_snapshot_ticks) # Construct networks. with tf.device('/gpu:0'): if resume_run_id is not None: network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot) print('Loading networks from "%s"...' 
% network_pkl) G, D, Gs = misc.load_pkl(network_pkl) else: print('Constructing networks...') G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args) D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args) Gs = G.clone('Gs') # G.print_layers(); D.print_layers() print('Building TensorFlow graph...') with tf.name_scope('Inputs'), tf.device('/cpu:0'): lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[]) lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[]) minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[]) minibatch_split = minibatch_in // submit_config.num_gpus Gs_beta = 0.5**tf.div(tf.cast(minibatch_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0 # The loss weighting of the Hessian Penalty can be dynamic over training, so we need to use a placeholder: hessian_weight = tf.placeholder(tf.float32, name='hessian_weight', shape=[]) G_opt = tflib.Optimizer(name='TrainG', learning_rate=lrate_in, **G_opt_args) D_opt = tflib.Optimizer(name='TrainD', learning_rate=lrate_in, **D_opt_args) reg_ops = [ ] # Returning the values of the Hessian Penalty/ InfoGAN losses so they can be registered in TensorBoard for gpu in range(submit_config.num_gpus): with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu): G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow') D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow') lod_assign_ops = [ tf.assign(G_gpu.find_var('lod'), lod_in), tf.assign(D_gpu.find_var('lod'), lod_in) ] reals, labels = training_set.get_minibatch_tf() reals = process_reals(reals, lod_in, mirror_augment, training_set.dynamic_range, drange_net) with tf.name_scope('G_loss'), tf.control_dependencies( lod_assign_ops): G_loss, G_hessian_penalty = dnnlib.util.call_func_by_name( G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_split, hp_lambda=hessian_weight, HP_args=HP_args, gpu_ix=gpu, lod_in=lod_in, max_lod=training_set.resolution_log2, **G_loss_args) if HP_args.hp_lambda > 0: reg_ops += [G_hessian_penalty] with tf.name_scope('D_loss'), tf.control_dependencies( lod_assign_ops): D_loss, mutual_info = dnnlib.util.call_func_by_name( G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_split, reals=reals, labels=labels, gpu_ix=gpu, infogan_nz=D_args.infogan_nz, **D_loss_args) # print([name for name in D_gpu.trainables.keys()]) # gps = [weight for name, weight in G_gpu.trainables.items()][0] # dps = [weight for name, weight in D_gpu.trainables.items() if 'Q_Encoder' in name][0] # gg = autosummary('Loss/G_info_grad', tf.reduce_sum(tf.gradients(mutual_info, gps)[0]**2)) # dg = autosummary('Loss/D_info_grad', tf.reduce_sum(tf.gradients(mutual_info, dps)[0]**2)) # reg_ops.extend([dg, gg, dps, gps]) if G_args.infogan_lambda > 0 or D_args.infogan_lambda > 0: reg_ops += [mutual_info] # Note, even though we are adding mutual_info loss here, the only time the loss is non-zero # is when infogan_lambda > 0 (in Hessian Penalty experiments, we always set it to zero): G_opt.register_gradients( G_loss + G_args.infogan_lambda * mutual_info, G_gpu.trainables) D_opt.register_gradients( tf.reduce_mean(D_loss) + D_args.infogan_lambda * mutual_info, D_gpu.trainables) G_train_op = G_opt.apply_updates() D_train_op = D_opt.apply_updates() Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta) with tf.device('/gpu:0'): try: 
peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse() except tf.errors.NotFoundError: peak_gpu_mem_op = tf.constant(0) print('Setting up snapshot image grid...') grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid( G, training_set, **grid_args) sched = training_schedule(cur_nimg=total_kimg * 1000, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args) grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch // submit_config.num_gpus) print('Setting up snapshot interpolation...') nz = G.input_shapes[0][1] interp_steps = 24 # Number of frames in the visualization interp_batch_size = 8 # Number of gifs per row of visualization interp_z = vis_tools.sample_interp_zs(nz, interp_batch_size, interp_steps) interp_labels = np.zeros( [interp_steps * interp_batch_size * nz, training_set.label_size], dtype=training_set.label_dtype) print('Setting up run dir...') misc.save_image_grid(grid_reals, os.path.join(submit_config.run_dir, 'reals.png'), drange=training_set.dynamic_range, grid_size=grid_size) misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg), drange=drange_net, grid_size=grid_size) summary_log = tf.summary.FileWriter(submit_config.run_dir) summary_log.add_summary( build_image_summary(os.path.join(submit_config.run_dir, 'reals.png'), 'samples/real'), 0) summary_log.add_summary( build_image_summary( os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg), 'samples/Gs'), resume_kimg) if save_tf_graph: summary_log.add_graph(tf.get_default_graph()) if save_weight_histograms: G.setup_weight_histograms() D.setup_weight_histograms() metrics = metric_base.MetricGroup(metric_arg_list) if interp_snapshot_ticks != -1 and interp_snapshot_ticks < 9999: print('Generating initial interpolations...') vis_tools.make_and_save_interpolation_gifs( Gs, interp_z, interp_labels, minibatch_size=sched.minibatch // submit_config.num_gpus, interp_steps=interp_steps, interp_batch_size=interp_batch_size, cur_kimg=resume_kimg, summary_log=summary_log) print('Training...\n') ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg) maintenance_time = ctx.get_last_update_interval() cur_nimg = int(resume_kimg * 1000) cur_tick = 0 tick_start_nimg = cur_nimg prev_lod = -1.0 num_G_grad_steps = 0 while cur_nimg < total_kimg * 1000: if ctx.should_stop(): break # Choose training parameters and configure training ops. sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args) training_set.configure(sched.minibatch // submit_config.num_gpus, sched.lod) if reset_opt_for_new_lod: if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil( sched.lod) != np.ceil(prev_lod): G_opt.reset_optimizer_state() D_opt.reset_optimizer_state() prev_lod = sched.lod # Run training ops. for _mb_repeat in range(minibatch_repeats): for _D_repeat in range(D_repeats): tflib.run( [D_train_op, Gs_update_op], { lod_in: sched.lod, lrate_in: sched.D_lrate, minibatch_in: sched.minibatch }) cur_nimg += sched.minibatch cur_hessian_weight = get_current_hessian_penalty_loss_weight( HP_args.hp_lambda, HP_args.hp_start_nimg, cur_nimg, HP_args.warmup_nimg) tflib.run( [G_train_op] + reg_ops, { lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_in: sched.minibatch, hessian_weight: cur_hessian_weight }) num_G_grad_steps += 1 # Perform maintenance tasks once per tick. 
done = (cur_nimg >= total_kimg * 1000) if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done: cur_tick += 1 tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0 tick_start_nimg = cur_nimg tick_time = ctx.get_time_since_last_update() total_time = ctx.get_time_since_start() + resume_time # Report progress. print( 'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d hessian_weight %s time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f' % (autosummary('Progress/tick', cur_tick), autosummary('Progress/kimg', cur_nimg / 1000.0), autosummary('Progress/lod', sched.lod), autosummary('Progress/minibatch', sched.minibatch), autosummary('Progress/hessian_weight', cur_hessian_weight), dnnlib.util.format_time( autosummary('Timing/total_sec', total_time)), autosummary('Timing/sec_per_tick', tick_time), autosummary('Timing/sec_per_kimg', tick_time / tick_kimg), autosummary('Timing/maintenance_sec', maintenance_time), autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30))) autosummary('Timing/total_hours', total_time / (60.0 * 60.0)) autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0)) autosummary('Progress/G_grad_steps', num_G_grad_steps) # Save snapshots and fake image samples (for both Gs and G): if cur_tick % image_snapshot_ticks == 0 or done: iter = (cur_nimg // 1000) grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch // submit_config.num_gpus) grid_fakes_inst = G.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch // submit_config.num_gpus) fake_path = os.path.join(submit_config.run_dir, 'fakes%06d.png' % iter) ifake_path = os.path.join(submit_config.run_dir, 'ifakes%06d.png' % iter) misc.save_image_grid(grid_fakes, fake_path, drange=drange_net, grid_size=grid_size) misc.save_image_grid(grid_fakes_inst, ifake_path, drange=drange_net, grid_size=grid_size) summary_log.add_summary( build_image_summary(fake_path, 'samples/Gs'), iter) summary_log.add_summary( build_image_summary(ifake_path, 'samples/G'), iter) # Generate/Save Interpolation Visualizations: if interp_snapshot_ticks != -1 and cur_tick % interp_snapshot_ticks == 0: vis_tools.make_and_save_interpolation_gifs( Gs, interp_z, interp_labels, minibatch_size=sched.minibatch // submit_config.num_gpus, interp_steps=interp_steps, interp_batch_size=interp_batch_size, cur_kimg=cur_nimg // 1000, summary_log=summary_log) # Save snapshot and run metrics: if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1: pkl = os.path.join( submit_config.run_dir, 'network-snapshot-%06d.pkl' % (cur_nimg // 1000)) misc.save_pkl((G, D, Gs), pkl) if cur_tick % network_metric_ticks == 0 or done or cur_tick == 1: metrics.run(pkl, dataset_args=metrics_dataset_args, mirror_augment=mirror_augment, num_gpus=submit_config.num_gpus, tf_config=tf_config) # Update summaries and RunContext. metrics.update_autosummaries() tflib.autosummary.save_summaries(summary_log, cur_nimg) ctx.update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg) maintenance_time = ctx.get_last_update_interval() - tick_time # Write final results. misc.save_pkl((G, D, Gs), os.path.join(submit_config.run_dir, 'network-snapshot-%06d.pkl' % total_kimg)) summary_log.close() ctx.close()
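# The Hessian Penalty loop feeds a time-varying loss weight through the hessian_weight placeholder
# via get_current_hessian_penalty_loss_weight(). The real schedule lives elsewhere in the repo; the
# sketch below is only a plausible linear warm-up under assumed semantics of hp_start_nimg (when the
# ramp begins) and warmup_nimg (how long it lasts):
def hessian_weight_schedule(hp_lambda, hp_start_nimg, cur_nimg, warmup_nimg):
    """Ramp the penalty weight linearly from 0 to hp_lambda after hp_start_nimg images."""
    if cur_nimg < hp_start_nimg:
        return 0.0
    if warmup_nimg <= 0:
        return hp_lambda
    progress = (cur_nimg - hp_start_nimg) / float(warmup_nimg)
    return hp_lambda * min(progress, 1.0)

# Example: with hp_lambda=0.1 starting at 1M images and a 1M-image warm-up,
# the weight is 0.05 after 1.5M images and 0.1 from 2M images onward.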
def training_loop( submit_config, Encoder_args={}, D_args={}, G_args={}, E_opt_args={}, D_opt_args={}, E_loss_args=EasyDict(), D_loss_args={}, lr_args=EasyDict(), tf_config={}, dataset_args=EasyDict(), decoder_pkl=EasyDict(), drange_data=[0, 255], drange_net=[ -1, 1 ], # Dynamic range used when feeding image data to the networks. mirror_augment=False, filter=64, # Minimum number of feature maps in any layer. filter_max=512, # Maximum number of feature maps in any layer. resume_run_id=None, # Run ID or network pkl to resume training from, None = start from scratch. resume_snapshot=None, # Snapshot index to resume training from, None = autodetect. image_snapshot_ticks=1, # How often to export image snapshots? network_snapshot_ticks=10, # How often to export network snapshots? d_scale=0.1, # Decide whether to update discriminator. pretrained_D=True, # Whether to use pre trained Discriminator. max_iters=150000): tflib.init_tf(tf_config) with tf.name_scope('Input'): real_train = tf.placeholder(tf.float32, [ submit_config.batch_size, 3, submit_config.image_size, submit_config.image_size ], name='real_image_train') real_test = tf.placeholder(tf.float32, [ submit_config.batch_size_test, 3, submit_config.image_size, submit_config.image_size ], name='real_image_test') real_split = tf.split(real_train, num_or_size_splits=submit_config.num_gpus, axis=0) with tf.device('/gpu:0'): if resume_run_id is not None: network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot) print('Loading networks from "%s"...' % network_pkl) E, G, D, Gs = misc.load_pkl(network_pkl) G_style_mod = tflib.Network('G_StyleMod', resolution=submit_config.image_size, label_size=0, **G_args) start = int(network_pkl.split('-')[-1].split('.') [0]) // submit_config.batch_size print('Start: ', start) else: print('Constructing networks...') G, PreD, Gs = misc.load_pkl(decoder_pkl.decoder_pkl) num_layers = Gs.components.synthesis.input_shape[1] E = tflib.Network('E_gpu0', size=submit_config.image_size, filter=filter, filter_max=filter_max, num_layers=num_layers, is_training=True, num_gpus=submit_config.num_gpus, **Encoder_args) OriD = tflib.Network('D_ori', resolution=submit_config.image_size, label_size=0, **D_args) G_style_mod = tflib.Network('G_StyleMod', resolution=submit_config.image_size, label_size=0, **G_args) if pretrained_D: D = PreD else: D = OriD start = 0 Gs_vars_pairs = { name: tflib.run(val) for name, val in Gs.components.synthesis.vars.items() } for g_name, g_val in G_style_mod.vars.items(): tflib.set_vars({g_val: Gs_vars_pairs[g_name]}) E.print_layers() Gs.print_layers() D.print_layers() global_step0 = tf.Variable(start, trainable=False, name='learning_rate_step') learning_rate = tf.train.exponential_decay(lr_args.learning_rate, global_step0, lr_args.decay_step, lr_args.decay_rate, staircase=lr_args.stair) add_global0 = global_step0.assign_add(1) E_opt = tflib.Optimizer(name='TrainE', learning_rate=learning_rate, **E_opt_args) if d_scale > 0: D_opt = tflib.Optimizer(name='TrainD', learning_rate=learning_rate, **D_opt_args) E_loss_rec = 0. E_loss_adv = 0. D_loss_real = 0. D_loss_fake = 0. D_loss_grad = 0. 
for gpu in range(submit_config.num_gpus): print('Building Graph on GPU %s' % str(gpu)) with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu): E_gpu = E if gpu == 0 else E.clone(E.name[:-1] + str(gpu)) D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow') G_gpu = G_style_mod if gpu == 0 else G_style_mod.clone( G_style_mod.name + '_shadow') feature_model = PerceptualModel(img_size=[ E_loss_args.perceptual_img_size, E_loss_args.perceptual_img_size ], multi_layers=False) real_gpu = process_reals(real_split[gpu], mirror_augment, drange_data, drange_net) with tf.name_scope('E_loss'), tf.control_dependencies(None): E_loss, recon_loss, adv_loss = dnnlib.util.call_func_by_name( E=E_gpu, G=G_gpu, D=D_gpu, feature_model=feature_model, reals=real_gpu, **E_loss_args) E_loss_rec += recon_loss E_loss_adv += adv_loss with tf.name_scope('D_loss'), tf.control_dependencies(None): D_loss, loss_fake, loss_real, loss_gp = dnnlib.util.call_func_by_name( E=E_gpu, G=G_gpu, D=D_gpu, reals=real_gpu, **D_loss_args) D_loss_real += loss_real D_loss_fake += loss_fake D_loss_grad += loss_gp with tf.control_dependencies([add_global0]): E_opt.register_gradients(E_loss, E_gpu.trainables) if d_scale > 0: D_opt.register_gradients(D_loss, D_gpu.trainables) E_loss_rec /= submit_config.num_gpus E_loss_adv /= submit_config.num_gpus D_loss_real /= submit_config.num_gpus D_loss_fake /= submit_config.num_gpus D_loss_grad /= submit_config.num_gpus E_train_op = E_opt.apply_updates() if d_scale > 0: D_train_op = D_opt.apply_updates() print('Building testing graph...') fake_X_val = test(E, G_style_mod, real_test, submit_config) sess = tf.get_default_session() print('Getting training data...') image_batch_train = get_train_data(sess, data_dir=dataset_args.data_train, submit_config=submit_config, mode='train') image_batch_test = get_train_data(sess, data_dir=dataset_args.data_test, submit_config=submit_config, mode='test') summary_log = tf.summary.FileWriter(submit_config.run_dir) cur_nimg = start * submit_config.batch_size cur_tick = 0 tick_start_nimg = cur_nimg start_time = time.time() print('Optimization starts!!!') for it in range(start, max_iters): batch_images = sess.run(image_batch_train) feed_dict = {real_train: batch_images} _, recon_, adv_, lr = sess.run( [E_train_op, E_loss_rec, E_loss_adv, learning_rate], feed_dict) if d_scale > 0: _, d_r_, d_f_, d_g_ = sess.run( [D_train_op, D_loss_real, D_loss_fake, D_loss_grad], feed_dict) cur_nimg += submit_config.batch_size run_time = time.time() - start_time iter_time = run_time / (it - start + 1) eta_time = iter_time * (max_iters - it - 1) if it % 50 == 0: print( 'Iter: %06d/%d, lr: %-.8f recon_loss: %-6.4f adv_loss: %-6.4f run_time:%-12s eta_time:%-12s' % (it, max_iters, lr, recon_, adv_, dnnlib.util.format_time(time.time() - start_time), dnnlib.util.format_time(eta_time))) if d_scale > 0: print('d_r_loss: %-6.4f d_f_loss: %-6.4f d_reg: %-6.4f ' % (d_r_, d_f_, d_g_)) sys.stdout.flush() tflib.autosummary.save_summaries(summary_log, it) if cur_nimg >= tick_start_nimg + 65000: cur_tick += 1 tick_start_nimg = cur_nimg if cur_tick % image_snapshot_ticks == 0: batch_images_test = sess.run(image_batch_test) batch_images_test = misc.adjust_dynamic_range( batch_images_test.astype(np.float32), [0, 255], [-1., 1.]) recon = sess.run(fake_X_val, feed_dict={real_test: batch_images_test}) orin_recon = np.concatenate([batch_images_test, recon], axis=0) orin_recon = adjust_pixel_range(orin_recon) orin_recon = fuse_images(orin_recon, row=2, col=submit_config.batch_size_test) # save image 
# results during training; the first row is the original images and the second row is the reconstructed images. save_image( '%s/iter_%09d.png' % (submit_config.run_dir, cur_nimg), orin_recon) if cur_tick % network_snapshot_ticks == 0: pkl = os.path.join(submit_config.run_dir, 'network-snapshot-%09d.pkl' % (cur_nimg)) misc.save_pkl((E, G, D, Gs), pkl) misc.save_pkl((E, G, D, Gs), os.path.join(submit_config.run_dir, 'network-final.pkl')) summary_log.close()
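# Both encoder loops drive the learning rate with tf.train.exponential_decay. Its behaviour is
# simple enough to restate in plain Python; this sketch mirrors the documented formula (with
# staircase=True the exponent is floored to a whole number of decay steps):
def exponential_decay(base_lr, global_step, decay_step, decay_rate, staircase=True):
    """Decayed learning rate after global_step optimizer steps."""
    exponent = global_step / float(decay_step)
    if staircase:
        exponent = int(exponent)
    return base_lr * (decay_rate ** exponent)

# Example: base_lr=1e-4 decayed by 0.8 every 50k steps -> 8e-5 after 50k steps, 6.4e-5 after 100k.
print(exponential_decay(1e-4, 100000, 50000, 0.8))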
def training_loop( submit_config, #--------------------------------------------------------------- # Modified by Deng et al. noise_dim=32, weight_args={}, train_stage_args={}, #--------------------------------------------------------------- G_args={}, # Options for generator network. D_args={}, # Options for discriminator network. G_opt_args={}, # Options for generator optimizer. D_opt_args={}, # Options for discriminator optimizer. G_loss_args={}, # Options for generator loss. D_loss_args={}, # Options for discriminator loss. dataset_args={}, # Options for dataset.load_dataset(). sched_args={}, # Options for train.TrainingSchedule. grid_args={}, # Options for train.setup_snapshot_image_grid(). metric_arg_list=[], # Options for MetricGroup. tf_config={}, # Options for tflib.init_tf(). G_smoothing_kimg=10.0, # Half-life of the running average of generator weights. D_repeats=1, # How many times the discriminator is trained per G iteration. minibatch_repeats=4, # Number of minibatches to run before adjusting training parameters. reset_opt_for_new_lod=True, # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced? total_kimg=15000, # Total length of the training, measured in thousands of real images. mirror_augment=True, # Enable mirror augment? drange_net=[ -1, 1 ], # Dynamic range used when feeding image data to the networks. image_snapshot_ticks=1, # How often to export image snapshots? network_snapshot_ticks=10, # How often to export network snapshots? save_tf_graph=True, # Include full TensorFlow computation graph in the tfevents file? save_weight_histograms=False, # Include weight histograms in the tfevents file? resume_run_id=87, # Run ID or network pkl to resume training from, None = start from scratch. resume_snapshot=2364, # Snapshot index to resume training from, None = autodetect. resume_kimg=2364, # Assumed training progress at the beginning. Affects reporting and training schedule. resume_time=0.0, **_kwargs ): # Assumed wallclock time at the beginning. Affects reporting. # Initialize dnnlib and TensorFlow. PI = 3.1415927 ctx = dnnlib.RunContext(submit_config, train) tflib.init_tf(tf_config) # Load training set. training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args) # Create 3d face reconstruction block FaceRender = Face3D() # Construct networks. with tf.device('/gpu:0'): if resume_run_id is not None: network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot) print('Loading networks from "%s"...' % network_pkl) G, D, Gs = misc.load_pkl(network_pkl) else: print('Constructing networks...') #--------------------------------------------------------------- # Modified by Deng et al. 
G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, latent_size=254 + noise_dim, **G_args) #--------------------------------------------------------------- D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args) Gs = G.clone('Gs') G.print_layers() D.print_layers() print('Building TensorFlow graph...') with tf.name_scope('Inputs'), tf.device('/cpu:0'): lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[]) resolution = tf.placeholder(tf.float32, name='resolution', shape=[]) lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[]) minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[]) minibatch_split = minibatch_in // submit_config.num_gpus Gs_beta = 0.5**tf.div(tf.cast(minibatch_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0 G_opt = tflib.Optimizer(name='TrainG', learning_rate=lrate_in, **G_opt_args) D_opt = tflib.Optimizer(name='TrainD', learning_rate=lrate_in, **D_opt_args) for gpu in range(submit_config.num_gpus): with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % (gpu)): G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow') D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow') lod_assign_ops = [ tf.assign(G_gpu.find_var('lod'), lod_in), tf.assign(D_gpu.find_var('lod'), lod_in) ] reals, labels = training_set.get_minibatch_tf() reals = process_reals(reals, lod_in, mirror_augment, training_set.dynamic_range, drange_net) #--------------------------------------------------------------- # Modified by Deng et al. G_loss,D_loss = dnnlib.util.call_func_by_name(FaceRender=FaceRender,noise_dim=noise_dim,weight_args=weight_args,\ G_gpu=G_gpu,D_gpu=D_gpu,G_opt=G_opt,D_opt=D_opt,training_set=training_set,G_loss_args=G_loss_args,D_loss_args=D_loss_args,\ lod_assign_ops=lod_assign_ops,reals=reals,labels=labels,minibatch_split=minibatch_split,resolution=resolution,\ drange_net=drange_net,lod_in=lod_in,**train_stage_args) #--------------------------------------------------------------- G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables) D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables) G_train_op = G_opt.apply_updates() D_train_op = D_opt.apply_updates() Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta) with tf.device('/gpu:0'): try: peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse() except tf.errors.NotFoundError: peak_gpu_mem_op = tf.constant(0) #--------------------------------------------------------------- # Modified by Deng et al. 
restore_weights_and_initialize(train_stage_args) print('Setting up snapshot image grid...') sched = training_schedule(cur_nimg=total_kimg * 1000, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args) grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid( G, training_set, **grid_args) grid_latents = tf.random_normal([np.prod(grid_size), 128 + 32 + 16 + 3]) grid_INPUTcoeff = z_to_lambda_mapping(grid_latents) grid_INPUTcoeff_w_t = tf.concat( [grid_INPUTcoeff, tf.zeros([np.prod(grid_size), 3])], axis=1) with tf.name_scope('FaceRender'): grid_render_img, _, _, _ = FaceRender.Reconstruction_Block( grid_INPUTcoeff_w_t, 256, np.prod(grid_size), progressive=False) grid_render_img = tf.transpose(grid_render_img, perm=[0, 3, 1, 2]) grid_render_img = process_reals(grid_render_img, lod_in, False, training_set.dynamic_range, drange_net) grid_INPUTcoeff_, grid_renders = tflib.run( [grid_INPUTcoeff, grid_render_img], {lod_in: sched.lod}) grid_noise = np.random.randn(np.prod(grid_size), 32) grid_INPUTcoeff_w_noise = np.concatenate([grid_INPUTcoeff_, grid_noise], axis=1) grid_fakes = Gs.run(grid_INPUTcoeff_w_noise, grid_labels, is_validation=True, minibatch_size=sched.minibatch // submit_config.num_gpus) grid_fakes = np.concatenate([grid_fakes, grid_renders], axis=3) misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg), drange=drange_net, grid_size=grid_size) misc.save_image_grid(grid_reals, os.path.join(submit_config.run_dir, 'reals.png'), drange=training_set.dynamic_range, grid_size=grid_size) #--------------------------------------------------------------- summary_log = tf.summary.FileWriter(submit_config.run_dir) if save_tf_graph: summary_log.add_graph(tf.get_default_graph()) if save_weight_histograms: G.setup_weight_histograms() D.setup_weight_histograms() metrics = metric_base.MetricGroup(metric_arg_list) print('Training...\n') ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg) maintenance_time = ctx.get_last_update_interval() cur_nimg = int(resume_kimg * 1000) cur_tick = 0 tick_start_nimg = cur_nimg prev_lod = -1.0 while cur_nimg < total_kimg * 1000: if ctx.should_stop(): break # Choose training parameters and configure training ops. sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args) training_set.configure(sched.minibatch // submit_config.num_gpus, sched.lod) if reset_opt_for_new_lod: if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil( sched.lod) != np.ceil(prev_lod): G_opt.reset_optimizer_state() D_opt.reset_optimizer_state() prev_lod = sched.lod # Run training ops. for _mb_repeat in range(minibatch_repeats): for _D_repeat in range(D_repeats): tflib.run( [D_train_op, Gs_update_op], { lod_in: sched.lod, lrate_in: sched.D_lrate, minibatch_in: sched.minibatch, resolution: sched.resolution }) cur_nimg += sched.minibatch tflib.run( [G_train_op], { lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_in: sched.minibatch, resolution: sched.resolution }) # print('iter') # Perform maintenance tasks once per tick. done = (cur_nimg >= total_kimg * 1000) if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done: cur_tick += 1 tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0 tick_start_nimg = cur_nimg tick_time = ctx.get_time_since_last_update() total_time = ctx.get_time_since_start() + resume_time # Report progress. 
print( 'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f' % (autosummary('Progress/tick', cur_tick), autosummary('Progress/kimg', cur_nimg / 1000.0), autosummary('Progress/lod', sched.lod), autosummary('Progress/minibatch', sched.minibatch), dnnlib.util.format_time( autosummary('Timing/total_sec', total_time)), autosummary('Timing/sec_per_tick', tick_time), autosummary('Timing/sec_per_kimg', tick_time / tick_kimg), autosummary('Timing/maintenance_sec', maintenance_time), autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30))) autosummary('Timing/total_hours', total_time / (60.0 * 60.0)) autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0)) # Save snapshots. if cur_tick % image_snapshot_ticks == 0 or done: #--------------------------------------------------------------- # Modified by Deng et al. grid_fakes = Gs.run(grid_INPUTcoeff_w_noise, grid_labels, is_validation=True, minibatch_size=sched.minibatch // submit_config.num_gpus) grid_fakes = np.concatenate([grid_fakes, grid_renders], axis=3) misc.save_image_grid(grid_fakes, os.path.join( submit_config.run_dir, 'fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size) #--------------------------------------------------------------- if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1: pkl = os.path.join( submit_config.run_dir, 'network-snapshot-%06d.pkl' % (cur_nimg // 1000)) misc.save_pkl((G, D, Gs), pkl) metrics.run(pkl, run_dir=submit_config.run_dir, num_gpus=submit_config.num_gpus, tf_config=tf_config) # Update summaries and RunContext. metrics.update_autosummaries() tflib.autosummary.save_summaries(summary_log, cur_nimg) ctx.update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg) maintenance_time = ctx.get_last_update_interval() - tick_time # Write final results. misc.save_pkl((G, D, Gs), os.path.join(submit_config.run_dir, 'network-final.pkl')) summary_log.close() ctx.close() #----------------------------------------------------------------------------
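# All of these loops share the same tick bookkeeping: heavy maintenance (snapshots, metrics,
# summaries) runs only once sched.tick_kimg thousand images have been consumed since the last
# tick. A stripped-down sketch of that control flow; the numbers and names are illustrative:
def run_ticks(total_kimg=30, tick_kimg=10, minibatch_size=4000):
    cur_nimg, tick_start_nimg, cur_tick = 0, 0, 0
    while cur_nimg < total_kimg * 1000:
        cur_nimg += minibatch_size                       # one training iteration
        done = cur_nimg >= total_kimg * 1000
        if cur_nimg >= tick_start_nimg + tick_kimg * 1000 or done:
            cur_tick += 1
            processed = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            print('tick %d: %.1f kimg since last maintenance' % (cur_tick, processed))

run_ticks()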
def training_loop( submit_config, G_args={}, # Options for generator network. D_args={}, # Options for discriminator network. G_opt_args={}, # Options for generator optimizer. D_opt_args={}, # Options for discriminator optimizer. G_loss_args={}, # Options for generator loss. D_loss_args={}, # Options for discriminator loss. dataset_args={}, # Options for dataset.load_dataset(). sched_args={}, # Options for train.TrainingSchedule. grid_args={}, # Options for train.setup_snapshot_image_grid(). metric_arg_list=[], # Options for MetricGroup. tf_config={}, # Options for tflib.init_tf(). G_smoothing_kimg=10.0, # Half-life of the running average of generator weights. D_repeats=1, # How many times the discriminator is trained per G iteration. minibatch_repeats=4, # Number of minibatches to run before adjusting training parameters. total_kimg=15000, # Total length of the training, measured in thousands of real images. mirror_augment=False, # Enable mirror augment? drange_net=[ -1, 1 ], # Dynamic range used when feeding image data to the networks. image_snapshot_ticks=1, # How often to export image snapshots? network_snapshot_ticks=10, # How often to export network snapshots? save_tf_graph=False, # Include full TensorFlow computation graph in the tfevents file? save_weight_histograms=False, # Include weight histograms in the tfevents file? resume_run_id=None, # Run ID or network pkl to resume training from, None = start from scratch. resume_snapshot=None, # Snapshot index to resume training from, None = autodetect. resume_kimg=0.0, # Assumed training progress at the beginning. Affects reporting and training schedule. resume_time=0.0, ): # Assumed wallclock time at the beginning. Affects reporting. # Initialize dnnlib and TensorFlow. ctx = dnnlib.RunContext(submit_config, train) tflib.init_tf(tf_config) # Load training set. training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args) # Construct networks. with tf.device("/gpu:0"): if resume_run_id is not None: network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot) print('Loading networks from "%s"...' 
% network_pkl) G, D, Gs = misc.load_pkl(network_pkl) else: print("Constructing networks...") G = tflib.Network( name="G", num_inputs=2, # one for latents and one for labels num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args) D = tflib.Network( name="D", num_inputs=int(np.log2(training_set.shape[1])) - 1 + 1, # +1 for labels :) num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args) Gs = G.clone("Gs") G.print_layers() D.print_layers() print("Building TensorFlow graph...") with tf.name_scope("Inputs"), tf.device("/cpu:0"): lrate_in = tf.placeholder(tf.float32, name="lrate_in", shape=[]) minibatch_in = tf.placeholder(tf.int32, name="minibatch_in", shape=[]) minibatch_split = minibatch_in // submit_config.num_gpus Gs_beta = (0.5**tf.div(tf.cast(minibatch_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0) G_opt = tflib.Optimizer(name="TrainG", learning_rate=lrate_in, **G_opt_args) D_opt = tflib.Optimizer(name="TrainD", learning_rate=lrate_in, **D_opt_args) for gpu in range(submit_config.num_gpus): with tf.name_scope("GPU%d" % gpu), tf.device("/gpu:%d" % gpu): G_gpu = G if gpu == 0 else G.clone(G.name + "_shadow") D_gpu = D if gpu == 0 else D.clone(D.name + "_shadow") reals, labels = training_set.get_minibatch_tf() reals = process_reals( reals, mirror_augment, training_set.dynamic_range, drange_net, depth=training_set.resolution_log2 - 1, ) with tf.name_scope("G_loss"): G_loss = dnnlib.util.call_func_by_name( G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_split, **G_loss_args) with tf.name_scope("D_loss"): D_loss = dnnlib.util.call_func_by_name( G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_split, reals=reals, labels=labels, **D_loss_args) G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables) D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables) G_train_op = G_opt.apply_updates() D_train_op = D_opt.apply_updates() Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta) with tf.device("/gpu:0"): try: peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse() except tf.errors.NotFoundError: peak_gpu_mem_op = tf.constant(0) # Choose training parameters and configure training ops. 
sched = training_schedule(cur_nimg=total_kimg * 1000, training_set=training_set, **sched_args) print("Setting up snapshot image grid...") grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid( G, training_set, **grid_args) grid_fakes = Gs.run( grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_size // submit_config.num_gpus, ) print("Setting up run dir...") fake_multi_scale_dirs = [ os.path.join(submit_config.run_dir, str(2**res) + "x" + str(2**res)) for res in range(2, 2 + len(grid_fakes)) ] misc.save_image_grid( grid_reals, os.path.join(submit_config.run_dir, "reals.png"), drange=training_set.dynamic_range, grid_size=grid_size, ) misc.save_image_grids( grid_fakes, [ os.path.join(fake_multi_scale_dir, "fakes%06d.png" % resume_kimg) for fake_multi_scale_dir in fake_multi_scale_dirs ], drange=drange_net, grid_size=grid_size, ) summary_log = tf.summary.FileWriter(submit_config.run_dir) if save_tf_graph: summary_log.add_graph(tf.get_default_graph()) if save_weight_histograms: G.setup_weight_histograms() D.setup_weight_histograms() metrics = metric_base.MetricGroup(metric_arg_list) print("Training...\n") ctx.update("", cur_epoch=resume_kimg, max_epoch=total_kimg) maintenance_time = ctx.get_last_update_interval() cur_nimg = int(resume_kimg * 1000) cur_tick = 0 tick_start_nimg = cur_nimg # configure the training_set to a proper minibatch size training_set.configure(sched.minibatch_size // submit_config.num_gpus) while cur_nimg < total_kimg * 1000: if ctx.should_stop(): break # Run training ops. for _mb_repeat in range(minibatch_repeats): for _D_repeat in range(D_repeats): tflib.run( [D_train_op, Gs_update_op], { lrate_in: sched.D_lrate, minibatch_in: sched.minibatch_size }, ) cur_nimg += sched.minibatch_size tflib.run( [G_train_op], { lrate_in: sched.G_lrate, minibatch_in: sched.minibatch_size }, ) # Perform maintenance tasks once per tick. done = cur_nimg >= total_kimg * 1000 if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done: cur_tick += 1 tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0 tick_start_nimg = cur_nimg tick_time = ctx.get_time_since_last_update() total_time = ctx.get_time_since_start() + resume_time # Report progress. print( "tick %-5d kimg %-8.1f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f" % ( autosummary("Progress/tick", cur_tick), autosummary("Progress/kimg", cur_nimg / 1000.0), autosummary("Progress/minibatch", sched.minibatch_size), dnnlib.util.format_time( autosummary("Timing/total_sec", total_time)), autosummary("Timing/sec_per_tick", tick_time), autosummary("Timing/sec_per_kimg", tick_time / tick_kimg), autosummary("Timing/maintenance_sec", maintenance_time), autosummary("Resources/peak_gpu_mem_gb", peak_gpu_mem_op.eval() / 2**30), )) autosummary("Timing/total_hours", total_time / (60.0 * 60.0)) autosummary("Timing/total_days", total_time / (24.0 * 60.0 * 60.0)) # Save snapshots. 
if cur_tick % image_snapshot_ticks == 0 or done: grid_fakes = Gs.run( grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_size // submit_config.num_gpus, ) misc.save_image_grids( grid_fakes, [ os.path.join(fake_multi_scale_dir, "fakes%06d.png" % (cur_nimg // 1000)) for fake_multi_scale_dir in fake_multi_scale_dirs ], drange=drange_net, grid_size=grid_size, ) if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1: pkl = os.path.join( submit_config.run_dir, "network-snapshot-%06d.pkl" % (cur_nimg // 1000), ) misc.save_pkl((G, D, Gs), pkl) metrics.run( pkl, run_dir=submit_config.run_dir, num_gpus=submit_config.num_gpus, tf_config=tf_config, ) # Update summaries and RunContext. metrics.update_autosummaries() tflib.autosummary.save_summaries(summary_log, cur_nimg) ctx.update(cur_epoch=cur_nimg // 1000, max_epoch=total_kimg) maintenance_time = ctx.get_last_update_interval() - tick_time # Write final results. misc.save_pkl((G, D, Gs), os.path.join(submit_config.run_dir, "network-final.pkl")) summary_log.close() ctx.close()
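# Minimal standalone sketch (not part of the scripts above; ema_beta/ema_update are
# hypothetical names): how the Gs_beta half-life used for
# Gs.setup_as_moving_average_of(G, beta=Gs_beta) behaves. beta is chosen so that a
# weight's contribution to the running average halves every G_smoothing_kimg
# thousand real images, regardless of minibatch size.
import numpy as np

def ema_beta(minibatch_size, smoothing_kimg):
    # Per-update decay factor; 0.0 disables averaging, as in the code above.
    if smoothing_kimg <= 0.0:
        return 0.0
    return 0.5 ** (minibatch_size / (smoothing_kimg * 1000.0))

def ema_update(gs_weights, g_weights, beta):
    # Exponential moving average: Gs <- beta * Gs + (1 - beta) * G.
    return beta * gs_weights + (1.0 - beta) * g_weights

# Example: after roughly 10 kimg (one half-life) the average has moved halfway to G.
beta = ema_beta(minibatch_size=32, smoothing_kimg=10.0)
gs, g = np.zeros(4), np.ones(4)
for _ in range(10 * 1000 // 32):
    gs = ema_update(gs, g, beta)
print(beta, gs)   # gs is approximately 0.5 everywhere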
def training_loop( G_args = {}, # Options for generator network. D_args = {}, # Options for discriminator network. G_opt_args = {}, # Options for generator optimizer. D_opt_args = {}, # Options for discriminator optimizer. G_loss_args = {}, # Options for generator loss. D_loss_args = {}, # Options for discriminator loss. dataset_args = {}, # Options for dataset.load_dataset(). sched_args = {}, # Options for train.TrainingSchedule. grid_args = {}, # Options for train.setup_snapshot_image_grid(). metric_arg_list = [], # Options for MetricGroup. tf_config = {}, # Options for tflib.init_tf(). data_dir = None, # Directory to load datasets from. G_smoothing_kimg = 10.0, # Half-life of the running average of generator weights. minibatch_repeats = 4, # Number of minibatches to run before adjusting training parameters. lazy_regularization = True, # Perform regularization as a separate training step? G_reg_interval = 4, # How often the perform regularization for G? Ignored if lazy_regularization=False. D_reg_interval = 16, # How often the perform regularization for D? Ignored if lazy_regularization=False. reset_opt_for_new_lod = True, # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced? total_kimg = 25000, # Total length of the training, measured in thousands of real images. mirror_augment = False, # Enable mirror augment? drange_net = [0,1], # Dynamic range used when feeding image data to the networks. image_snapshot_ticks = 50, # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'. network_snapshot_ticks = 50, # How often to save network snapshots? None = only save 'networks-final.pkl'. save_tf_graph = False, # Include full TensorFlow computation graph in the tfevents file? save_weight_histograms = False, # Include weight histograms in the tfevents file? resume_pkl = None, # Network pickle to resume training from, None = train from scratch. resume_kimg = 0.0, # Assumed training progress at the beginning. Affects reporting and training schedule. resume_time = 0.0, # Assumed wallclock time at the beginning. Affects reporting. resume_with_new_nets = False): # Construct new networks according to G_args and D_args before resuming training? # Initialize dnnlib and TensorFlow. tflib.init_tf(tf_config) num_gpus = dnnlib.submit_config.num_gpus # Load training set. training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir), verbose=True, **dataset_args) grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(training_set, **grid_args) misc.save_image_grid(grid_reals, dnnlib.make_run_dir_path('reals.png'), drange=training_set.dynamic_range, grid_size=grid_size) # Construct or load networks. training_set.configure(minibatch_size=sched_args.batch_size) with tf.device('/gpu:0'): if resume_pkl is None or resume_with_new_nets: print('Constructing networks...') G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args) D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args) Gs = G.clone('Gs') start = 0 if resume_pkl is not None: print('Loading networks from "%s"...' 
% resume_pkl) rG, rD, rGs = misc.load_pkl(resume_pkl) if resume_with_new_nets: G.copy_vars_from(rG); D.copy_vars_from(rD); Gs.copy_vars_from(rGs) else: G = rG; D = rD; Gs = rGs start = int(resume_pkl.split('-')[-1].split('.')[0]) // sched_args.batch_size # Print layers and generate initial image snapshot. G.print_layers(); D.print_layers() grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:]) grid_fakes = G.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched_args.batch_size) misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes_init.png'), drange=drange_net, grid_size=grid_size) global_step = tf.Variable(start, trainable=False, name='learning_rate_step') learning_rate = tf.train.exponential_decay(sched_args.lr, global_step, sched_args.decay_step, sched_args.decay_rate, staircase=sched_args.stair) add_global = global_step.assign_add(1) D_opt = tflib.Optimizer(name='TrainD', learning_rate=learning_rate, **D_opt_args) G_opt = tflib.Optimizer(name='TrainG', learning_rate=learning_rate, **G_opt_args) for gpu in range(num_gpus): print('build graph on gpu %s' % str(gpu)) with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu): # Create GPU-specific shadow copies of G and D. G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow') D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow') with tf.name_scope('DataFetch'): reals_read, labels_read = training_set.get_minibatch_tf() reals_read, labels_read = process_reals(reals_read, labels_read, mirror_augment, training_set.dynamic_range, drange_net) with tf.name_scope('Loss'), tf.control_dependencies(None): loss, reg = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=sched_args.batch_size, reals=reals_read, labels=labels_read, **D_loss_args) with tf.control_dependencies([add_global]): G_opt.register_gradients(loss, G_gpu.trainables) D_opt.register_gradients(loss, D_gpu.trainables) G_train_op = G_opt.apply_updates() D_train_op = D_opt.apply_updates() # Finalize graph. with tf.device('/gpu:0'): try: peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse() except tf.errors.NotFoundError: peak_gpu_mem_op = tf.constant(0) tflib.init_uninitialized_vars() print('Initializing logs...') summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path()) if save_tf_graph: summary_log.add_graph(tf.get_default_graph()) if save_weight_histograms: G.setup_weight_histograms(); D.setup_weight_histograms() metrics = metric_base.MetricGroup(metric_arg_list) print('Training for %d kimg...\n' % total_kimg) dnnlib.RunContext.get().update('', cur_epoch=resume_kimg, max_epoch=total_kimg) maintenance_time = dnnlib.RunContext.get().get_last_update_interval() cur_nimg = int(resume_kimg * 1000) cur_tick = -1 tick_start_nimg = cur_nimg prev_lod = -1.0 running_mb_counter = 0 while cur_nimg < total_kimg * 1000: if dnnlib.RunContext.get().should_stop(): break loss_, _, _, lr_ = tflib.run([loss, G_train_op, D_train_op, learning_rate]) cur_nimg += sched_args.batch_size * num_gpus done = (cur_nimg >= total_kimg * 1000) if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched_args.tick_kimg * 1000 or done: cur_tick += 1 tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0 tick_start_nimg = cur_nimg tick_time = dnnlib.RunContext.get().get_time_since_last_update() total_time = dnnlib.RunContext.get().get_time_since_start() + resume_time # Report progress. 
print( 'tick %-5d kimg %-8.1f minibatch %-4d time %-12s sec/tick %-7.1f ' 'sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f loss %-8.1f lr %-2.5f' % ( autosummary('Progress/tick', cur_tick), autosummary('Progress/kimg', cur_nimg / 1000.0), autosummary('Progress/minibatch', sched_args.batch_size), dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)), autosummary('Timing/sec_per_tick', tick_time), autosummary('Timing/sec_per_kimg', tick_time / tick_kimg), autosummary('Timing/maintenance_sec', maintenance_time), autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2 ** 30), autosummary('loss', loss_), autosummary('lr', lr_))) autosummary('Timing/total_hours', total_time / (60.0 * 60.0)) autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0)) # Save snapshots. if image_snapshot_ticks is not None and (cur_tick % image_snapshot_ticks == 0 or done): grid_fakes = G.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched_args.batch_size) misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size) if network_snapshot_ticks is not None and (cur_tick % network_snapshot_ticks == 0 or done): pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' % (cur_nimg // 1000)) misc.save_pkl((G, D, Gs), pkl) metrics.run(pkl, run_dir=dnnlib.make_run_dir_path(), data_dir=dnnlib.convert_path(data_dir), num_gpus=num_gpus, tf_config=tf_config) # Update summaries and RunContext. metrics.update_autosummaries() tflib.autosummary.save_summaries(summary_log, cur_nimg) dnnlib.RunContext.get().update('%.2f' % 0.0, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg) maintenance_time = dnnlib.RunContext.get().get_last_update_interval() - tick_time # Save final snapshot. misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path('network-final.pkl')) # All done. summary_log.close() training_set.close()
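# Minimal standalone sketch (hypothetical helper, not from the repo): the
# learning-rate schedule produced above by tf.train.exponential_decay with
# global_step incremented once per iteration via add_global. With staircase=True
# the exponent is floored, so the rate drops in discrete steps every decay_step
# iterations instead of decaying continuously.
import math

def exponential_decay(base_lr, global_step, decay_step, decay_rate, staircase=False):
    exponent = global_step / decay_step
    if staircase:
        exponent = math.floor(exponent)
    return base_lr * (decay_rate ** exponent)

# Example values (assumed, not taken from sched_args):
for step in (0, 5000, 10000, 20000):
    print(step, exponential_decay(1e-4, step, decay_step=10000, decay_rate=0.5, staircase=True))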
def training_loop( submit_config, Encoder_args = {}, E_opt_args = {}, D_opt_args = {}, E_loss_args = EasyDict(), D_loss_args = {}, lr_args = EasyDict(), tf_config = {}, dataset_args = EasyDict(), decoder_pkl = EasyDict(), drange_data = [0, 255], drange_net = [-1,1], # Dynamic range used when feeding image data to the networks. mirror_augment = False, resume_run_id = config.ENCODER_PICKLE_DIR, # Run ID or network pkl to resume training from, None = start from scratch. resume_snapshot = None, # Snapshot index to resume training from, None = autodetect. image_snapshot_ticks = 1, # How often to export image snapshots? network_snapshot_ticks = 4, # How often to export network snapshots? max_iters = 150000): tflib.init_tf(tf_config) with tf.name_scope('input'): real_train = tf.placeholder(tf.float32, [submit_config.batch_size, 3, submit_config.image_size, submit_config.image_size], name='real_image_train') real_test = tf.placeholder(tf.float32, [submit_config.batch_size_test, 3, submit_config.image_size, submit_config.image_size], name='real_image_test') real_split = tf.split(real_train, num_or_size_splits=submit_config.num_gpus, axis=0) with tf.device('/gpu:0'): if resume_run_id is not None: network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot) print('Loading networks from "%s"...' % network_pkl) E, G, D, Gs = misc.load_pkl(network_pkl) start = int(network_pkl.split('-')[-1].split('.')[0]) // submit_config.batch_size print('Start: ', start) else: print('Constructing networks...') G, D, Gs = misc.load_pkl(decoder_pkl.decoder_pkl) num_layers = Gs.components.synthesis.input_shape[1] E = tflib.Network('E', size=submit_config.image_size, filter=64, filter_max=1024, num_layers=num_layers, phase=True, **Encoder_args) start = 0 E.print_layers(); Gs.print_layers(); D.print_layers() global_step0 = tf.Variable(start, trainable=False, name='learning_rate_step') learning_rate = tf.train.exponential_decay(lr_args.learning_rate, global_step0, lr_args.decay_step, lr_args.decay_rate, staircase=lr_args.stair) add_global0 = global_step0.assign_add(1) E_opt = tflib.Optimizer(name='TrainE', learning_rate=learning_rate, **E_opt_args) D_opt = tflib.Optimizer(name='TrainD', learning_rate=learning_rate, **D_opt_args) E_loss_rec = 0. E_loss_adv = 0. D_loss_real = 0. D_loss_fake = 0. D_loss_grad = 0. 
for gpu in range(submit_config.num_gpus): print('build graph on gpu %s' % str(gpu)) with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu): E_gpu = E if gpu == 0 else E.clone(E.name + '_shadow') D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow') G_gpu = Gs if gpu == 0 else Gs.clone(Gs.name + '_shadow') perceptual_model = PerceptualModel(img_size=[E_loss_args.perceptual_img_size, E_loss_args.perceptual_img_size], multi_layers=False) real_gpu = process_reals(real_split[gpu], mirror_augment, drange_data, drange_net) with tf.name_scope('E_loss'), tf.control_dependencies(None): E_loss, recon_loss, adv_loss = dnnlib.util.call_func_by_name(E=E_gpu, G=G_gpu, D=D_gpu, perceptual_model=perceptual_model, reals=real_gpu, **E_loss_args) E_loss_rec += recon_loss E_loss_adv += adv_loss with tf.name_scope('D_loss'), tf.control_dependencies(None): D_loss, loss_fake, loss_real, loss_gp = dnnlib.util.call_func_by_name(E=E_gpu, G=G_gpu, D=D_gpu, reals=real_gpu, **D_loss_args) D_loss_real += loss_real D_loss_fake += loss_fake D_loss_grad += loss_gp with tf.control_dependencies([add_global0]): E_opt.register_gradients(E_loss, E_gpu.trainables) D_opt.register_gradients(D_loss, D_gpu.trainables) E_loss_rec /= submit_config.num_gpus E_loss_adv /= submit_config.num_gpus D_loss_real /= submit_config.num_gpus D_loss_fake /= submit_config.num_gpus D_loss_grad /= submit_config.num_gpus E_train_op = E_opt.apply_updates() D_train_op = D_opt.apply_updates() print('building testing graph...') fake_X_val = test(E, Gs, real_test, submit_config) sess = tf.get_default_session() print('Getting training data...') image_batch_train = get_train_data(sess, data_dir=dataset_args.data_train, submit_config=submit_config, mode='train') image_batch_test = get_train_data(sess, data_dir=dataset_args.data_test, submit_config=submit_config, mode='test') summary_log = tf.summary.FileWriter(config.GDRIVE_PATH) cur_nimg = start * submit_config.batch_size cur_tick = 0 tick_start_nimg = cur_nimg start_time = time.time() init_pascal = tf.initialize_variables( [global_step0], name='init_pascal' ) sess.run(init_pascal) print('Optimization starts!!!') for it in range(start, max_iters): batch_images = sess.run(image_batch_train) feed_dict_1 = {real_train: batch_images} _, recon_, adv_ = sess.run([E_train_op, E_loss_rec, E_loss_adv], feed_dict_1) _, d_r_, d_f_, d_g_ = sess.run([D_train_op, D_loss_real, D_loss_fake, D_loss_grad], feed_dict_1) cur_nimg += submit_config.batch_size if it % 50 == 0: print('Iter: %06d recon_loss: %-6.4f adv_loss: %-6.4f d_r_loss: %-6.4f d_f_loss: %-6.4f d_reg: %-6.4f time:%-12s' % ( it, recon_, adv_, d_r_, d_f_, d_g_, dnnlib.util.format_time(time.time() - start_time))) sys.stdout.flush() tflib.autosummary.save_summaries(summary_log, it) if it % 500 == 0: batch_images_test = sess.run(image_batch_test) batch_images_test = misc.adjust_dynamic_range(batch_images_test.astype(np.float32), [0, 255], [-1., 1.]) samples2 = sess.run(fake_X_val, feed_dict={real_test: batch_images_test}) orin_recon = np.concatenate([batch_images_test, samples2], axis=0) orin_recon = adjust_pixel_range(orin_recon) orin_recon = fuse_images(orin_recon, row=2, col=submit_config.batch_size_test) # save image results during training, first row is original images and the second row is reconstructed images save_image('%s/iter_%08d.png' % (submit_config.run_dir, cur_nimg), orin_recon) # save image to gdrive img_path = os.path.join(config.GDRIVE_PATH, 'images', ('iter_%08d.png' % (cur_nimg))) save_image(img_path, orin_recon) if cur_nimg >= 
tick_start_nimg + 65000: cur_tick += 1 tick_start_nimg = cur_nimg if cur_tick % network_snapshot_ticks == 0: pkl = os.path.join(submit_config.run_dir, 'network-snapshot-%08d.pkl' % (cur_nimg)) misc.save_pkl((E, G, D, Gs), pkl) # save network snapshot to gdrive pkl_drive = os.path.join(config.GDRIVE_PATH, 'snapshots', 'network-snapshot-%08d.pkl' % (cur_nimg)) misc.save_pkl((E, G, D, Gs), pkl_drive) misc.save_pkl((E, G, D, Gs), os.path.join(submit_config.run_dir, 'network-final.pkl')) summary_log.close()
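# Minimal standalone sketch (hypothetical pure-NumPy stand-in for
# misc.adjust_dynamic_range as it is used above): images arrive in the dataset range
# drange_data = [0, 255] and are remapped linearly to the network range
# drange_net = [-1, 1] before being fed to the encoder/generator.
import numpy as np

def adjust_dynamic_range(data, drange_in, drange_out):
    # Affine map: scale by the ratio of the two ranges, then shift the lower bound.
    scale = (drange_out[1] - drange_out[0]) / (drange_in[1] - drange_in[0])
    bias = drange_out[0] - drange_in[0] * scale
    return data * scale + bias

print(adjust_dynamic_range(np.array([0.0, 127.5, 255.0]), [0, 255], [-1.0, 1.0]))  # [-1. 0. 1.]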
# Define variables.
seed_latent = tf.Variable(np.zeros((1, 512)), dtype=tf.float32)
target = tf.Variable(np.zeros((1, 3, 1024, 1024)), dtype=np.float32)

# Define network.
Gs = load_network()
fea_ext_name_to_var_name = {'inception_v3': 'InceptionV3'}
feature_extractor = get_feature_extractor(args)
noise_loss, fake_images_out = get_nosie_output(args, Gs, seed_latent, None, target, feature_extractor)

# Apply optimizers.
lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
noise_opt = tflib.Optimizer(name='Noise', learning_rate=lrate_in, beta1=0.0, beta2=0.99, epsilon=1e-8)
noise_opt.register_gradients(tf.reduce_mean(noise_loss), OrderedDict([('latents', seed_latent)]))
noise_update_op = noise_opt.apply_updates()

# Get TF session.
sess = tf.get_default_session()

# Initialize variables.
uninitialized_vars = [seed_latent, target]
init_op = tf.variables_initializer(uninitialized_vars)
sess.run([init_op])

if args.feature_extractor is not None and args.feature_extractor != 'D':
    extractor_variables = slim.get_variables_to_restore()
    extractor_variables = [x for x in extractor_variables
                           if fea_ext_name_to_var_name[args.feature_extractor] in x.name]
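# Toy standalone sketch (an analogy, not the script above): what optimizing
# seed_latent against noise_loss amounts to - gradient descent on the latent alone
# while the generator and feature extractor stay frozen. The linear "generator" and
# mean-squared loss here are assumptions made only to keep the example runnable
# without TensorFlow.
import numpy as np

rng = np.random.default_rng(0)
G_frozen = rng.standard_normal((64, 512)) * 0.05   # stand-in for the frozen Gs
target_feat = rng.standard_normal(64)              # stand-in for target features
latent = np.zeros(512)                             # analogue of seed_latent

lrate = 0.1
for step in range(200):
    residual = G_frozen @ latent - target_feat
    grad = 2.0 * G_frozen.T @ residual / residual.size   # d(mean sq. error)/d(latent)
    latent -= lrate * grad                               # update only the latent
    if step % 50 == 0:
        print(step, float(np.mean(residual ** 2)))       # loss decreases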
def training_auto_loop( Enc_args={}, # Options for encoder network. Dec_args={}, # Options for decoder network. opt_args={}, # Options for encoder optimizer. loss_args={}, # Options for discriminator loss. dataset_args={}, # Options for dataset.load_dataset(). sched_args={}, # Options for train.TrainingSchedule. grid_args={}, # Options for train.setup_snapshot_image_grid(). tf_config={}, # Options for tflib.init_tf(). data_dir=None, # Directory to load datasets from. minibatch_repeats=4, # Number of minibatches to run before adjusting training parameters. total_kimg=25000, # Total length of the training, measured in thousands of real images. drange_net=[ -1, 1 ], # Dynamic range used when feeding image data to the networks. image_snapshot_ticks=50, # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'. network_snapshot_ticks=50, # How often to save network snapshots? None = only save 'networks-final.pkl'. save_tf_graph=False, # Include full TensorFlow computation graph in the tfevents file? save_weight_histograms=False ): # Include weight histograms in the tfevents file? # Initialize dnnlib and TensorFlow. tflib.init_tf(tf_config) num_gpus = dnnlib.submit_config.num_gpus # Load training set. training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir), verbose=True, **dataset_args) grid_size, grid_reals, _ = misc.setup_snapshot_image_grid( training_set, **grid_args) misc.save_image_grid(grid_reals, dnnlib.make_run_dir_path('reals.png'), drange=training_set.dynamic_range, grid_size=grid_size) # Construct or load networks. with tf.device('/gpu:0'): print('Constructing networks...') Enc = tflib.Network('Encoder', resolution=training_set.shape[1], out_channels=16, **Enc_args) Dec = tflib.Network('Decoder', resolution=training_set.shape[1] // 4, in_channels=16, **Dec_args) # Print layers and generate initial image snapshot. Enc.print_layers() Dec.print_layers() sched = sched_args sched.tick_kimg = 4 grid_codes = Enc.run((grid_reals / 127.5) - 1.0, minibatch_size=sched.minibatch_gpu) grid_fakes = Dec.run(grid_codes, minibatch_size=sched.minibatch_gpu) misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes_init.png'), drange=drange_net, grid_size=grid_size) # Setup training inputs. print('Building TensorFlow graph...') with tf.name_scope('Inputs'), tf.device('/cpu:0'): lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[]) minibatch_size_in = tf.placeholder(tf.int32, name='minibatch_size_in', shape=[]) minibatch_gpu_in = tf.placeholder(tf.int32, name='minibatch_gpu_in', shape=[]) minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in * num_gpus) # Setup optimizers. opt_args = dict(opt_args) opt_args['minibatch_multiplier'] = minibatch_multiplier opt_args['learning_rate'] = lrate_in opt = tflib.Optimizer(name='TrainAuto', **opt_args) # Build training graph for each GPU. data_fetch_ops = [] for gpu in range(num_gpus): with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu): # Create GPU-specific shadow copies of Enc and Dec. Enc_gpu = Enc if gpu == 0 else Enc.clone(Enc.name + '_shadow') Dec_gpu = Dec if gpu == 0 else Dec.clone(Dec.name + '_shadow') # Fetch training data via temporary variables. 
with tf.name_scope('DataFetch'): reals_var = tf.Variable( name='reals', trainable=False, initial_value=tf.zeros([sched.minibatch_gpu] + training_set.shape)) reals_write, labels_write = training_set.get_minibatch_tf() reals_write, _ = process_reals(reals_write, labels_write, 0.0, False, training_set.dynamic_range, drange_net) reals_write = tf.concat( [reals_write, reals_var[minibatch_gpu_in:]], axis=0) data_fetch_ops += [tf.assign(reals_var, reals_write)] reals_read = reals_var[:minibatch_gpu_in] # Evaluate loss functions. with tf.name_scope('loss'): loss = dnnlib.util.call_func_by_name(Enc=Enc_gpu, Dec=Dec_gpu, opt=opt, reals=reals_read, **loss_args) # Register gradients. opt.register_gradients( tf.reduce_mean(loss), list(Enc_gpu.trainables.values()) + list(Dec_gpu.trainables.values())) # Setup training ops. data_fetch_op = tf.group(*data_fetch_ops) train_op = opt.apply_updates() # Finalize graph. with tf.device('/gpu:0'): try: peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse() except tf.errors.NotFoundError: peak_gpu_mem_op = tf.constant(0) tflib.init_uninitialized_vars() print('Initializing logs...') summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path()) if save_tf_graph: summary_log.add_graph(tf.get_default_graph()) if save_weight_histograms: Enc.setup_weight_histograms() Dec.setup_weight_histograms() print('Training for %d kimg...\n' % total_kimg) dnnlib.RunContext.get().update('', max_epoch=total_kimg) maintenance_time = dnnlib.RunContext.get().get_last_update_interval() cur_nimg = 0 cur_tick = -1 tick_start_nimg = cur_nimg while cur_nimg < total_kimg * 1000: if dnnlib.RunContext.get().should_stop(): break # Choose training parameters and configure training ops. assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0 training_set.configure(sched.minibatch_gpu) # Run training ops. feed_dict = { lrate_in: sched.lrate, minibatch_size_in: sched.minibatch_size, minibatch_gpu_in: sched.minibatch_gpu } for _repeat in range(minibatch_repeats): rounds = range(0, sched.minibatch_size, sched.minibatch_gpu * num_gpus) cur_nimg += sched.minibatch_size # Fast path without gradient accumulation. if len(rounds) == 1: tflib.run(data_fetch_op, feed_dict) tflib.run(train_op, feed_dict) # Slow path with gradient accumulation. else: for _round in rounds: tflib.run(data_fetch_op, feed_dict) tflib.run(train_op, feed_dict) # Perform maintenance tasks once per tick. done = (cur_nimg >= total_kimg * 1000) if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done: cur_tick += 1 tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0 tick_start_nimg = cur_nimg tick_time = dnnlib.RunContext.get().get_time_since_last_update() total_time = dnnlib.RunContext.get().get_time_since_start() # Report progress. print( 'tick %-5d kimg %-8.1f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f' % (autosummary('Progress/tick', cur_tick), autosummary('Progress/kimg', cur_nimg / 1000.0), autosummary('Progress/minibatch', sched.minibatch_size), dnnlib.util.format_time( autosummary('Timing/total_sec', total_time)), autosummary('Timing/sec_per_tick', tick_time), autosummary('Timing/sec_per_kimg', tick_time / tick_kimg), autosummary('Timing/maintenance_sec', maintenance_time), autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30))) autosummary('Timing/total_hours', total_time / (60.0 * 60.0)) autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0)) # Save snapshots. 
if image_snapshot_ticks is not None and ( cur_tick % image_snapshot_ticks == 0 or done): grid_codes = Enc.run((grid_reals / 127.5) - 1.0, minibatch_size=sched.minibatch_gpu) grid_fakes = Dec.run(grid_codes, minibatch_size=sched.minibatch_gpu) misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path( 'fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size) if network_snapshot_ticks is not None and ( cur_tick % network_snapshot_ticks == 0 or done): pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' % (cur_nimg // 1000)) misc.save_pkl((Enc, Dec), pkl) # Update summaries and RunContext. tflib.autosummary.save_summaries(summary_log, cur_nimg) dnnlib.RunContext.get().update('%.2f' % 0.0, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg) maintenance_time = dnnlib.RunContext.get( ).get_last_update_interval() - tick_time # Save final snapshot. misc.save_pkl((Enc, Dec), dnnlib.make_run_dir_path('network-final.pkl')) # All done. summary_log.close() training_set.close()
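# Minimal standalone sketch (hypothetical names; a toy parameter instead of real
# network weights): the effect of the fast/slow path above. When
# minibatch_size > minibatch_gpu * num_gpus, the loop runs several "rounds" per
# logical minibatch; as I read it, tflib.Optimizer uses minibatch_multiplier to
# accumulate gradients across those rounds before applying one update. Here the
# accumulation is written out explicitly.
import numpy as np

def accumulated_step(param, grad_fn, minibatch_size, minibatch_gpu, num_gpus, lrate):
    rounds = range(0, minibatch_size, minibatch_gpu * num_gpus)
    accum = np.zeros_like(param)
    for _round in rounds:          # one forward/backward per round
        accum += grad_fn(param)
    accum /= len(rounds)           # average so the step matches one large batch
    return param - lrate * accum

grad_fn = lambda p: 2.0 * p        # gradient of p**2, a stand-in loss
p = np.array([1.0])
for _ in range(5):
    p = accumulated_step(p, grad_fn, minibatch_size=32, minibatch_gpu=4, num_gpus=2, lrate=0.1)
print(p)                            # moves toward the minimum at 0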
def training_loop( classifier_args={}, # Options for generator network. classifier_opt_args={}, # Options for generator optimizer. classifier_loss_args={}, dataset_args={}, # Options for dataset.load_dataset(). sched_args={}, # Options for train.TrainingSchedule. metric_arg_list=[], # Options for MetricGroup. tf_config={}, # Options for tflib.init_tf(). data_dir=None, # Directory to load datasets from. minibatch_repeats=4, # Number of minibatches to run before adjusting training parameters. total_kimg=25000, # Total length of the training, measured in thousands of real images. mirror_augment=False, # Enable mirror augment? drange_net=[ -1, 1 ], # Dynamic range used when feeding image data to the networks. network_snapshot_ticks=5, # How often to save network snapshots? None = only save 'networks-final.pkl'. save_tf_graph=False): # Initialize dnnlib and TensorFlow. tflib.init_tf(tf_config) num_gpus = dnnlib.submit_config.num_gpus # Load training set. training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir), verbose=True, shuffle_mb=2 * 4096, **dataset_args) # Construct or load networks. with tf.device('/gpu:0'): print('Constructing networks...') classifier = tflib.Network('classifier', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **classifier_args) classifier.print_layers() # Setup training inputs. print('Building TensorFlow graph...') with tf.name_scope('Inputs'), tf.device('/cpu:0'): lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[]) minibatch_size_in = tf.placeholder(tf.int32, name='minibatch_size_in', shape=[]) minibatch_gpu_in = tf.placeholder(tf.int32, name='minibatch_gpu_in', shape=[]) minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in * num_gpus) # Setup optimizers. classifier_opt_args = dict(classifier_opt_args) classifier_opt_args['minibatch_multiplier'] = minibatch_multiplier classifier_opt_args['learning_rate'] = lrate_in classifier_opt = tflib.Optimizer(name='TrainClassifier', **classifier_opt_args) # Build training graph for each GPU. data_fetch_ops = [] for gpu in range(num_gpus): with tf.name_scope('gpu%d' % gpu), tf.device('/gpu:%d' % gpu): # Create GPU-specific shadow copies of G and D. classifier_gpu = classifier if gpu == 0 else classifier.clone( classifier.name + '_shadow') # Fetch training data via temporary variables. with tf.name_scope('DataFetch'): sched = training_schedule(cur_nimg=0, **sched_args) reals_var = tf.Variable( name='reals', trainable=False, initial_value=tf.zeros([sched.minibatch_gpu] + training_set.shape)) labels_var = tf.Variable(name='labels', trainable=False, initial_value=tf.zeros( [sched.minibatch_gpu, 127])) reals_write, labels_write = training_set.get_minibatch_tf() reals_write, labels_write = process_reals( reals_write, labels_write, mirror_augment, training_set.dynamic_range, drange_net) reals_write = tf.concat( [reals_write, reals_var[minibatch_gpu_in:]], axis=0) labels_write = tf.concat( [labels_write, labels_var[minibatch_gpu_in:]], axis=0) data_fetch_ops += [tf.assign(reals_var, reals_write)] data_fetch_ops += [tf.assign(labels_var, labels_write)] reals_read = reals_var[:minibatch_gpu_in] labels_read = labels_var[:minibatch_gpu_in] # Evaluate loss functions. 
with tf.name_scope('classifier_loss'): classifier_loss, label = dnnlib.util.call_func_by_name( classifier=classifier_gpu, images=reals_read, labels=labels_read, **classifier_loss_args) classifier_opt.register_gradients(tf.reduce_mean(classifier_loss), classifier_gpu.trainables) # Setup training ops. data_fetch_op = tf.group(*data_fetch_ops) classifier_train_op = classifier_opt.apply_updates() # Finalize graph. with tf.device('/gpu:0'): try: peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse() except tf.errors.NotFoundError: peak_gpu_mem_op = tf.constant(0) tflib.init_uninitialized_vars() print('Initializing logs...') summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path()) if save_tf_graph: summary_log.add_graph(tf.get_default_graph()) metrics = metric_base.MetricGroup(metric_arg_list) print('Training for %d kimg...\n' % total_kimg) dnnlib.RunContext.get().update('', cur_epoch=0, max_epoch=total_kimg) maintenance_time = dnnlib.RunContext.get().get_last_update_interval() cur_nimg = 0 cur_tick = -1 tick_start_nimg = cur_nimg running_mb_counter = 0 while cur_nimg < total_kimg * 1000: if dnnlib.RunContext.get().should_stop(): break # Choose training parameters and configure training ops. sched = training_schedule(cur_nimg=cur_nimg, **sched_args) assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0 training_set.configure(sched.minibatch_gpu) # Run training ops. feed_dict = { lrate_in: sched.G_lrate, minibatch_size_in: sched.minibatch_size, minibatch_gpu_in: sched.minibatch_gpu } for _repeat in range(minibatch_repeats): rounds = range(0, sched.minibatch_size, sched.minibatch_gpu * num_gpus) cur_nimg += sched.minibatch_size running_mb_counter += 1 # Fast path without gradient accumulation. if len(rounds) == 1: tflib.run([classifier_train_op, data_fetch_op], feed_dict) # Slow path with gradient accumulation. else: for _round in rounds: tflib.run(data_fetch_op, feed_dict) classifier_loss_out, label_out, _ = tflib.run( [classifier_loss, label, classifier_train_op], feed_dict) print_output = False if print_output: print('label') print(np.round(label_out, 2)) print('loss') print(np.round(classifier_loss_out, 2)) # Perform maintenance tasks once per tick. done = (cur_nimg >= total_kimg * 1000) if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done: cur_tick += 1 tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0 tick_start_nimg = cur_nimg tick_time = dnnlib.RunContext.get().get_time_since_last_update() total_time = dnnlib.RunContext.get().get_time_since_start() # Report progress. print( 'tick %-5d kimg %-8.1f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f' % (autosummary('Progress/tick', cur_tick), autosummary('Progress/kimg', cur_nimg / 1000.0), autosummary('Progress/minibatch', sched.minibatch_size), dnnlib.util.format_time( autosummary('Timing/total_sec', total_time)), autosummary('Timing/sec_per_tick', tick_time), autosummary('Timing/sec_per_kimg', tick_time / tick_kimg), autosummary('Timing/maintenance_sec', maintenance_time), autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30))) autosummary('Timing/total_hours', total_time / (60.0 * 60.0)) autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0)) # Save snapshots. 
if network_snapshot_ticks is not None and ( cur_tick % network_snapshot_ticks == 0 or done): pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' % (cur_nimg // 1000)) misc.save_pkl(classifier, pkl) metrics.run(pkl, run_dir=dnnlib.make_run_dir_path(), data_dir=dnnlib.convert_path(data_dir), num_gpus=num_gpus, tf_config=tf_config) # Update summaries and RunContext. metrics.update_autosummaries() tflib.autosummary.save_summaries(summary_log, cur_nimg) dnnlib.RunContext.get().update('%.2f' % 0, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg) maintenance_time = dnnlib.RunContext.get( ).get_last_update_interval() - tick_time # Save final snapshot. misc.save_pkl(classifier, dnnlib.make_run_dir_path('network-final.pkl')) # All done. summary_log.close() training_set.close()
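# Minimal standalone sketch (hypothetical values): the tick bookkeeping shared by
# the loops above. Progress is counted in images (cur_nimg), reported in kimg, and
# a maintenance tick (progress print, snapshots, summaries) fires on the very first
# iteration, then every tick_kimg thousand images, and once more when done.
def tick_schedule(total_kimg, tick_kimg, minibatch_size):
    cur_nimg, cur_tick, tick_start_nimg = 0, -1, 0
    ticks = []
    while cur_nimg < total_kimg * 1000:
        cur_nimg += minibatch_size
        done = cur_nimg >= total_kimg * 1000
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + tick_kimg * 1000 or done:
            cur_tick += 1
            ticks.append((cur_tick, round(cur_nimg / 1000.0, 3)))  # (tick, kimg seen)
            tick_start_nimg = cur_nimg
    return ticks

print(tick_schedule(total_kimg=20, tick_kimg=4, minibatch_size=32))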
def training_loop_hd( I_args={}, # Options for generator network. M_args={}, # Options for discriminator network. I_info_args={}, # Options for class network. I_opt_args={}, # Options for generator optimizer. I_loss_args={}, # Options for generator loss. resume_G_pkl=None, # G network pickle to help training. dataset_args={}, # Options for dataset.load_dataset(). sched_args={}, # Options for train.TrainingSchedule. grid_args={}, # Options for train.setup_snapshot_image_grid(). metric_arg_list=[], # Options for MetricGroup. tf_config={}, # Options for tflib.init_tf(). data_dir=None, # Directory to load datasets from. I_smoothing_kimg=10.0, # Half-life of the running average of generator weights. minibatch_repeats=4, # Number of minibatches to run before adjusting training parameters. lazy_regularization=True, # Perform regularization as a separate training step? I_reg_interval=4, # How often the perform regularization for G? Ignored if lazy_regularization=False. reset_opt_for_new_lod=True, # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced? total_kimg=25000, # Total length of the training, measured in thousands of real images. mirror_augment=False, # Enable mirror augment? drange_net=[ -1, 1 ], # Dynamic range used when feeding image data to the networks. image_snapshot_ticks=50, # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'. network_snapshot_ticks=50, # How often to save network snapshots? None = only save 'networks-final.pkl'. save_tf_graph=False, # Include full TensorFlow computation graph in the tfevents file? save_weight_histograms=False, # Include weight histograms in the tfevents file? resume_pkl=None, # Network pickle to resume training from, None = train from scratch. resume_kimg=0.0, # Assumed training progress at the beginning. Affects reporting and training schedule. resume_time=0.0, # Assumed wallclock time at the beginning. Affects reporting. resume_with_new_nets=False, # Construct new networks according to I_args and M_args before resuming training? traversal_grid=False, # Used for disentangled representation learning. n_discrete=0, # Number of discrete latents in model. n_continuous=10, # Number of continuous latents in model. n_samples_per=4, # Number of samples for each line in traversal. use_hd_with_cls=False, # If use info_loss. resolution_manual=1024, # Resolution of generated images. use_level_training=False, # If use level training (hierarchical optimization strategy). level_I_kimg=1000, # Number of kimg of tick for I_level training. use_std_in_m=False, # If output prior std in M net. prior_latent_size=512, # Prior latent size. use_hyperplane=False, # If use hyperplane model. pretrained_type='with_stylegan2'): # Pretrained type for G. # Initialize dnnlib and TensorFlow. tflib.init_tf(tf_config) num_gpus = dnnlib.submit_config.num_gpus # Load training set. training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir), verbose=True, **dataset_args) grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid( training_set, **grid_args) misc.save_image_grid(grid_reals, dnnlib.make_run_dir_path('reals.png'), drange=training_set.dynamic_range, grid_size=grid_size) # Construct or load networks. 
with tf.device('/gpu:0'): if resume_pkl is None or resume_with_new_nets: print('Constructing networks...') I = tflib.Network('I', num_channels=training_set.shape[0], resolution=resolution_manual, label_size=training_set.label_size, **I_args) M = tflib.Network('M', num_channels=training_set.shape[0], resolution=resolution_manual, label_size=training_set.label_size, **M_args) Is = I.clone('Is') if use_hd_with_cls: I_info = tflib.Network('I_info', num_channels=training_set.shape[0], resolution=resolution_manual, label_size=training_set.label_size, **I_info_args) if resume_pkl is not None: print('Loading networks from "%s"...' % resume_pkl) if use_hd_with_cls: rI, rM, rIs, rI_info = misc.load_pkl(resume_pkl) else: rI, rM, rIs = misc.load_pkl(resume_pkl) if resume_with_new_nets: I.copy_vars_from(rI) M.copy_vars_from(rM) Is.copy_vars_from(rIs) if use_hd_with_cls: I_info.copy_vars_from(rI_info) else: I = rI M = rM Is = rIs if use_hd_with_cls: I_info = rI_info print('Loading generator from "%s"...' % resume_G_pkl) if pretrained_type == 'with_stylegan2': rG, rD, rGs = misc.load_pkl(resume_G_pkl) G = rG D = rD Gs = rGs elif pretrained_type == 'with_cascadeVAE': rG = misc.load_pkl(resume_G_pkl) G = rG Gs = rG # Print layers and generate initial image snapshot. I.print_layers() M.print_layers() # pdb.set_trace() training_set_resolution_log2 = int(np.log2(resolution_manual)) sched = training_schedule( cur_nimg=total_kimg * 1000, training_set_resolution_log2=training_set_resolution_log2, **sched_args) if not use_hyperplane: grid_size, grid_latents, grid_labels = get_grid_latents( n_discrete, n_continuous, n_samples_per, G, grid_labels) print('grid_size:', grid_size) print('grid_latents.shape:', grid_latents.shape) print('grid_labels.shape:', grid_labels.shape) if resolution_manual >= 256: grid_size = (grid_size[0], grid_size[1] // 5) grid_latents = grid_latents[:grid_latents.shape[0] // 5] grid_labels = grid_labels[:grid_labels.shape[0] // 5] prior_traj_latents = M.run(grid_latents, is_validation=True, minibatch_size=sched.minibatch_gpu) if use_std_in_m: prior_traj_latents = prior_traj_latents[:, :prior_latent_size] else: grid_size = (n_samples_per, n_continuous) grid_labels = np.tile(grid_labels[:1], (n_continuous * n_samples_per, 1)) latent_dirs = get_latent_dirs(n_continuous) prior_traj_latents = get_prior_traj_by_dirs(latent_dirs, M, n_samples_per, prior_latent_size, grid_labels, sched) if resolution_manual >= 256: grid_size = (grid_size[0], grid_size[1] // 5) prior_traj_latents = prior_traj_latents[:prior_traj_latents. 
shape[0] // 5] grid_labels = grid_labels[:grid_labels.shape[0] // 5] print('prior_traj_latents.shape:', prior_traj_latents.shape) # pdb.set_trace() prior_traj_latents_show = np.reshape( prior_traj_latents, [-1, n_samples_per, prior_latent_size]) print_traj(prior_traj_latents_show) grid_fakes = Gs.run(prior_traj_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu, randomize_noise=True, normalize_latents=False) grid_fakes = add_outline(grid_fakes, width=1) misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes_init.png'), drange=drange_net, grid_size=grid_size) if (n_continuous == 2) and (n_discrete == 0): n_per_line = 20 ex_latent_value = 3 prior_grid_latents, prior_grid_labels = get_2d_grid_latents( low=-ex_latent_value, high=ex_latent_value, n_per_line=n_per_line, grid_labels=grid_labels) grid_showing_fakes = Gs.run(prior_grid_latents, prior_grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu, randomize_noise=True, normalize_latents=False) grid_showing_fakes = add_outline(grid_showing_fakes, width=1) misc.save_image_grid( grid_showing_fakes, dnnlib.make_run_dir_path('fakes_init_2d_prior_grid.png'), drange=drange_net, grid_size=[n_per_line, n_per_line]) img_to_draw = Image.open( dnnlib.make_run_dir_path('fakes_init_2d_prior_grid.png')) img_to_draw = img_to_draw.convert('RGB') img_to_draw = draw_traj_on_prior_grid(img_to_draw, prior_traj_latents_show, ex_latent_value, n_per_line) img_to_draw.save( dnnlib.make_run_dir_path('fakes_init_2d_prior_grid_drawn.png')) if use_level_training: ending_level = training_set_resolution_log2 - 1 else: ending_level = 1 # Setup training inputs. print('Building TensorFlow graph...') with tf.name_scope('Inputs'), tf.device('/cpu:0'): lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[]) lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[]) minibatch_size_in = tf.placeholder(tf.int32, name='minibatch_size_in', shape=[]) minibatch_gpu_in = tf.placeholder(tf.int32, name='minibatch_gpu_in', shape=[]) minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in * num_gpus) Is_beta = 0.5**tf.div(tf.cast(minibatch_size_in, tf.float32), I_smoothing_kimg * 1000.0) if I_smoothing_kimg > 0.0 else 0.0 # Setup optimizers. I_opt_args = dict(I_opt_args) for args, reg_interval in [(I_opt_args, I_reg_interval)]: args['minibatch_multiplier'] = minibatch_multiplier args['learning_rate'] = lrate_in if lazy_regularization: mb_ratio = reg_interval / (reg_interval + 1) args['learning_rate'] *= mb_ratio if 'beta1' in args: args['beta1'] **= mb_ratio if 'beta2' in args: args['beta2'] **= mb_ratio I_opts = [] I_reg_opts = [] for n_level in range(ending_level): I_opts.append(tflib.Optimizer(name='TrainI_%d' % n_level, **I_opt_args)) I_reg_opts.append( tflib.Optimizer(name='RegI_%d' % n_level, share=I_opts[-1], **I_opt_args)) # Build training graph for each GPU. for gpu in range(num_gpus): with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu): # Create GPU-specific shadow copies of I and M. I_gpu = I if gpu == 0 else I.clone(I.name + '_shadow') M_gpu = M if gpu == 0 else M.clone(M.name + '_shadow') G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow') if use_hd_with_cls: I_info_gpu = I_info if gpu == 0 else I_info.clone(I_info.name + '_shadow') # Evaluate loss functions. 
lod_assign_ops = [] I_losses = [] I_regs = [] if 'lod' in I_gpu.vars: lod_assign_ops += [tf.assign(I_gpu.vars['lod'], lod_in)] if 'lod' in M_gpu.vars: lod_assign_ops += [tf.assign(M_gpu.vars['lod'], lod_in)] for n_level in range(ending_level): with tf.control_dependencies(lod_assign_ops): with tf.name_scope('I_loss_%d' % n_level): if use_hd_with_cls: I_loss, I_reg = dnnlib.util.call_func_by_name( I=I_gpu, M=M_gpu, G=G_gpu, I_info=I_info_gpu, opt=I_opts[n_level], training_set=training_set, minibatch_size=minibatch_gpu_in, **I_loss_args) else: I_loss, I_reg = dnnlib.util.call_func_by_name( I=I_gpu, M=M_gpu, G=G_gpu, opt=I_opts[n_level], n_levels=(n_level + 1) if use_level_training else None, training_set=training_set, minibatch_size=minibatch_gpu_in, **I_loss_args) I_losses.append(I_loss) I_regs.append(I_reg) # Register gradients. if not lazy_regularization: if I_regs[n_level] is not None: I_losses[n_level] += I_regs[n_level] else: if I_regs[n_level] is not None: I_reg_opts[n_level].register_gradients( tf.reduce_mean(I_regs[n_level] * I_reg_interval), I_gpu.trainables) if use_hd_with_cls: MIIinfo_gpu_trainables = collections.OrderedDict( list(M_gpu.trainables.items()) + list(I_gpu.trainables.items()) + list(I_info_gpu.trainables.items())) I_opts[n_level].register_gradients( tf.reduce_mean(I_losses[n_level]), MIIinfo_gpu_trainables) else: MI_gpu_trainables = collections.OrderedDict( list(M_gpu.trainables.items()) + list(I_gpu.trainables.items())) I_opts[n_level].register_gradients( tf.reduce_mean(I_losses[n_level]), MI_gpu_trainables) # Setup training ops. I_train_ops = [] I_reg_ops = [] for n_level in range(ending_level): I_train_ops.append(I_opts[n_level].apply_updates()) I_reg_ops.append(I_reg_opts[n_level].apply_updates(allow_no_op=True)) Is_update_op = Is.setup_as_moving_average_of(I, beta=Is_beta) # Finalize graph. with tf.device('/gpu:0'): try: peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse() except tf.errors.NotFoundError: peak_gpu_mem_op = tf.constant(0) tflib.init_uninitialized_vars() print('Initializing logs...') summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path()) if save_tf_graph: summary_log.add_graph(tf.get_default_graph()) if save_weight_histograms: I.setup_weight_histograms() M.setup_weight_histograms() metrics = metric_base.MetricGroup(metric_arg_list) print('Training for %d kimg...\n' % total_kimg) dnnlib.RunContext.get().update('', cur_epoch=resume_kimg, max_epoch=total_kimg) maintenance_time = dnnlib.RunContext.get().get_last_update_interval() cur_nimg = int(resume_kimg * 1000) cur_tick = -1 tick_start_nimg = cur_nimg prev_lod = -1.0 running_mb_counter = 0 while cur_nimg < total_kimg * 1000: if dnnlib.RunContext.get().should_stop(): break n_level = 0 if not use_level_training else min( cur_nimg // (level_I_kimg * 1000), training_set_resolution_log2 - 2) # Choose training parameters and configure training ops. sched = training_schedule( cur_nimg=cur_nimg, training_set_resolution_log2=training_set_resolution_log2, **sched_args) assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0 training_set.configure(sched.minibatch_gpu, sched.lod) if reset_opt_for_new_lod: if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil( sched.lod) != np.ceil(prev_lod): I_opts[n_level].reset_optimizer_state() prev_lod = sched.lod # Run training ops. 
feed_dict = { lod_in: sched.lod, lrate_in: sched.I_lrate, minibatch_size_in: sched.minibatch_size, minibatch_gpu_in: sched.minibatch_gpu } for _repeat in range(minibatch_repeats): rounds = range(0, sched.minibatch_size, sched.minibatch_gpu * num_gpus) run_I_reg = (lazy_regularization and running_mb_counter % I_reg_interval == 0) cur_nimg += sched.minibatch_size running_mb_counter += 1 # Fast path without gradient accumulation. if len(rounds) == 1: tflib.run(I_train_ops[n_level], feed_dict) if run_I_reg: tflib.run(I_reg_ops[n_level], feed_dict) tflib.run([Is_update_op], feed_dict) # Slow path with gradient accumulation. else: for _round in rounds: tflib.run(I_train_ops[n_level], feed_dict) if run_I_reg: for _round in rounds: tflib.run(I_reg_ops[n_level], feed_dict) tflib.run(Is_update_op, feed_dict) # Perform maintenance tasks once per tick. done = (cur_nimg >= total_kimg * 1000) if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done: cur_tick += 1 tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0 tick_start_nimg = cur_nimg tick_time = dnnlib.RunContext.get().get_time_since_last_update() total_time = dnnlib.RunContext.get().get_time_since_start() + resume_time # Report progress. print( 'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f' % (autosummary('Progress/tick', cur_tick), autosummary('Progress/kimg', cur_nimg / 1000.0), autosummary('Progress/lod', sched.lod), autosummary('Progress/minibatch', sched.minibatch_size), dnnlib.util.format_time( autosummary('Timing/total_sec', total_time)), autosummary('Timing/sec_per_tick', tick_time), autosummary('Timing/sec_per_kimg', tick_time / tick_kimg), autosummary('Timing/maintenance_sec', maintenance_time), autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30))) autosummary('Timing/total_hours', total_time / (60.0 * 60.0)) autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0)) # Save snapshots.
if image_snapshot_ticks is not None and ( cur_tick % image_snapshot_ticks == 0 or done): if not use_hyperplane: prior_traj_latents = M.run( grid_latents, is_validation=True, minibatch_size=sched.minibatch_gpu) if use_std_in_m: prior_traj_latents = prior_traj_latents[:, : prior_latent_size] else: prior_traj_latents = get_prior_traj_by_dirs( latent_dirs, M, n_samples_per, prior_latent_size, grid_labels, sched) if resolution_manual >= 256: prior_traj_latents = prior_traj_latents[: prior_traj_latents .shape[0] // 5] prior_traj_latents_show = np.reshape( prior_traj_latents, [-1, n_samples_per, prior_latent_size]) print_traj(prior_traj_latents_show) grid_fakes = Gs.run(prior_traj_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu, randomize_noise=True, normalize_latents=False) grid_fakes = add_outline(grid_fakes, width=1) misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path( 'fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size) if (n_continuous == 2) and (n_discrete == 0): n_per_line = 20 ex_latent_value = 3 prior_grid_latents, prior_grid_labels = get_2d_grid_latents( low=-ex_latent_value, high=ex_latent_value, n_per_line=n_per_line, grid_labels=grid_labels) grid_showing_fakes = Gs.run( prior_grid_latents, prior_grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu, randomize_noise=True, normalize_latents=False) grid_showing_fakes = add_outline(grid_showing_fakes, width=1) misc.save_image_grid(grid_showing_fakes, dnnlib.make_run_dir_path( 'fakes_2d_prior_grid%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=[n_per_line, n_per_line]) img_to_draw = Image.open( dnnlib.make_run_dir_path( 'fakes_2d_prior_grid%06d.png' % (cur_nimg // 1000))) img_to_draw = img_to_draw.convert('RGB') img_to_draw = draw_traj_on_prior_grid( img_to_draw, prior_traj_latents_show, ex_latent_value, n_per_line) img_to_draw.save( dnnlib.make_run_dir_path( 'fakes_2d_prior_grid_drawn%06d.png' % (cur_nimg // 1000))) if network_snapshot_ticks is not None and ( cur_tick % network_snapshot_ticks == 0 or done): pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' % (cur_nimg // 1000)) misc.save_pkl((I, M, Is), pkl) metrics.run(pkl, run_dir=dnnlib.make_run_dir_path(), data_dir=dnnlib.convert_path(data_dir), num_gpus=num_gpus, tf_config=tf_config) # Update summaries and RunContext. metrics.update_autosummaries() tflib.autosummary.save_summaries(summary_log, cur_nimg) dnnlib.RunContext.get().update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg) maintenance_time = dnnlib.RunContext.get( ).get_last_update_interval() - tick_time # Save final snapshot. misc.save_pkl((I, M, Is), dnnlib.make_run_dir_path('network-final.pkl')) # All done. summary_log.close() training_set.close()
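# Minimal standalone sketch (hypothetical helper): the lazy-regularization
# rescaling applied to I_opt_args above. When the regularization term is only
# evaluated every I_reg_interval minibatches (run_I_reg in the loop), the main
# optimizer's learning rate and Adam betas are rescaled by
# mb_ratio = reg_interval / (reg_interval + 1) so the effective optimization
# behaviour stays close to evaluating both terms at every step.
def lazy_regularization_args(opt_args, reg_interval, lazy=True):
    args = dict(opt_args)
    if lazy:
        mb_ratio = reg_interval / (reg_interval + 1)
        args['learning_rate'] = args['learning_rate'] * mb_ratio
        if 'beta1' in args:
            args['beta1'] = args['beta1'] ** mb_ratio
        if 'beta2' in args:
            args['beta2'] = args['beta2'] ** mb_ratio
    return args

# Example values (assumed, not taken from I_opt_args):
print(lazy_regularization_args({'learning_rate': 0.002, 'beta1': 0.0, 'beta2': 0.99}, reg_interval=4))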
else: print('Constructing networks...') G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args) D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args) Gs = G.clone('Gs') G.print_layers(); D.print_layers() print('Building TensorFlow graph...') with tf.name_scope('Inputs'), tf.device('/cpu:0'): lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[]) lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[]) minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[]) minibatch_split = minibatch_in // submit_config.num_gpus Gs_beta = 0.5 ** tf.div(tf.cast(minibatch_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0 G_opt = tflib.Optimizer(name='TrainG', learning_rate=lrate_in, **G_opt_args) D_opt = tflib.Optimizer(name='TrainD', learning_rate=lrate_in, **D_opt_args) for gpu in range(submit_config.num_gpus): with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu): G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow') D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow') lod_assign_ops = [tf.assign(G_gpu.find_var('lod'), lod_in), tf.assign(D_gpu.find_var('lod'), lod_in)] reals, labels = training_set.get_minibatch_tf() reals = process_reals(reals, lod_in, mirror_augment, training_set.dynamic_range, drange_net) with tf.name_scope('G_loss'), tf.control_dependencies(lod_assign_ops): G_loss = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_split, **G_loss_args) with tf.name_scope('D_loss'), tf.control_dependencies(lod_assign_ops): D_loss = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_split, reals=reals, labels=labels, **D_loss_args) G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables) D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables) G_train_op = G_opt.apply_updates()
def training_loop( submit_config, Encoder_args = {}, E_opt_args = {}, D_opt_args = {}, E_loss_args = EasyDict(), D_loss_args = {}, lr_args = EasyDict(), tf_config = {}, dataset_args = EasyDict(), decoder_pkl = EasyDict(), inversion_pkl = EasyDict(), drange_data = [0, 255], drange_net = [-1,1], # Dynamic range used when feeding image data to the networks. mirror_augment = False, resume_run_id = config.ENCODER_PICKLE_DIR, # Run ID or network pkl to resume training from, None = start from scratch. resume_snapshot = None, # Snapshot index to resume training from, None = autodetect. image_snapshot_ticks = 1, # How often to export image snapshots? network_snapshot_ticks = 4, # How often to export network snapshots? max_iters = 150000): tflib.init_tf(tf_config) with tf.name_scope('input'): placeholder_real_portraits_train = tf.placeholder(tf.float32, [submit_config.batch_size, 3, submit_config.image_size, submit_config.image_size], name='placeholder_real_portraits_train') placeholder_real_landmarks_train = tf.placeholder(tf.float32, [submit_config.batch_size, 3, submit_config.image_size, submit_config.image_size], name='placeholder_real_landmarks_train') placeholder_real_shuffled_train = tf.placeholder(tf.float32, [submit_config.batch_size, 3, submit_config.image_size, submit_config.image_size], name='placeholder_real_shuffled_train') placeholder_landmarks_shuffled_train = tf.placeholder(tf.float32, [submit_config.batch_size, 3, submit_config.image_size, submit_config.image_size], name='placeholder_landmarks_shuffled_train') placeholder_real_portraits_test = tf.placeholder(tf.float32, [submit_config.batch_size_test, 3, submit_config.image_size, submit_config.image_size], name='placeholder_real_portraits_test') placeholder_real_landmarks_test = tf.placeholder(tf.float32, [submit_config.batch_size_test, 3, submit_config.image_size, submit_config.image_size], name='placeholder_real_landmarks_test') placeholder_real_shuffled_test = tf.placeholder(tf.float32, [submit_config.batch_size_test, 3, submit_config.image_size, submit_config.image_size], name='placeholder_real_shuffled_test') placeholder_real_landmarks_shuffled_test = tf.placeholder(tf.float32, [submit_config.batch_size_test, 3, submit_config.image_size, submit_config.image_size], name='placeholder_real_landmarks_shuffled_test') real_split_landmarks = tf.split(placeholder_real_landmarks_train, num_or_size_splits=submit_config.num_gpus, axis=0) real_split_portraits = tf.split(placeholder_real_portraits_train, num_or_size_splits=submit_config.num_gpus, axis=0) real_split_shuffled = tf.split(placeholder_real_shuffled_train, num_or_size_splits=submit_config.num_gpus, axis=0) real_split_lm_shuffled = tf.split(placeholder_landmarks_shuffled_train, num_or_size_splits=submit_config.num_gpus, axis=0) placeholder_training_flag = tf.placeholder(tf.string, name='placeholder_training_flag') with tf.device('/gpu:0'): if resume_run_id is not None: network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot) print('Loading networks from "%s"...' % network_pkl) E, G, D, Gs = misc.load_pkl(network_pkl) start = int(network_pkl.split('-')[-1].split('.')[0]) // submit_config.batch_size print('Start: ', start) else: print('Constructing networks...') G, _, Gs = misc.load_pkl(decoder_pkl.decoder_pkl) # don't use pre-trained discriminator! num_layers = Gs.components.synthesis.input_shape[1] # here we add a new discriminator! 
D = tflib.Network('D', # name of the network how we call it num_channels=3, resolution=128, label_size=0, #some needed for this build function func_name="training.networks_stylegan.D_basic") # function of that network. more was not passed in d_args! # input is not passed here (just construction - note that we do not call the actual function!). Instead, network will inspect build function and require it for the get_output_for function. print("Created new Discriminator!") E = tflib.Network('E', size=submit_config.image_size, filter=64, filter_max=1024, num_layers=num_layers, phase=True, **Encoder_args) start = 0 Inv, _, _, _ = misc.load_pkl(inversion_pkl.inversion_pkl) E.print_layers(); Gs.print_layers(); D.print_layers() global_step0 = tf.Variable(start, trainable=False, name='learning_rate_step') learning_rate = tf.train.exponential_decay(lr_args.learning_rate, global_step0, lr_args.decay_step, lr_args.decay_rate, staircase=lr_args.stair) add_global0 = global_step0.assign_add(1) E_opt = tflib.Optimizer(name='TrainE', learning_rate=learning_rate, **E_opt_args) D_opt = tflib.Optimizer(name='TrainD', learning_rate=learning_rate, **D_opt_args) E_loss_rec = 0. E_loss_adv = 0. D_loss_real = 0. D_loss_fake = 0. D_loss_grad = 0. for gpu in range(submit_config.num_gpus): print('build graph on gpu %s' % str(gpu)) with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu): E_gpu = E if gpu == 0 else E.clone(E.name + '_shadow') D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow') G_gpu = Gs if gpu == 0 else Gs.clone(Gs.name + '_shadow') Inv_gpu = Inv if gpu == 0 else Inv.clone(Inv.name + '_shadow') perceptual_model = PerceptualModel(img_size=[E_loss_args.perceptual_img_size, E_loss_args.perceptual_img_size], multi_layers=False) real_portraits_gpu = process_reals(real_split_portraits[gpu], mirror_augment, drange_data, drange_net) shuffled_portraits_gpu = process_reals(real_split_shuffled[gpu], mirror_augment, drange_data, drange_net) real_landmarks_gpu = process_reals(real_split_landmarks[gpu], mirror_augment, drange_data, drange_net) shuffled_landmarks_gpu = process_reals(real_split_lm_shuffled[gpu], mirror_augment, drange_data, drange_net) with tf.name_scope('E_loss'), tf.control_dependencies(None): E_loss, recon_loss, adv_loss = dnnlib.util.call_func_by_name(E=E_gpu, G=G_gpu, D=D_gpu, Inv=Inv_gpu, perceptual_model=perceptual_model, real_portraits=real_portraits_gpu, shuffled_portraits=shuffled_portraits_gpu, real_landmarks=real_landmarks_gpu, shuffled_landmarks=shuffled_landmarks_gpu, training_flag=placeholder_training_flag, **E_loss_args) E_loss_rec += recon_loss E_loss_adv += adv_loss with tf.name_scope('D_loss'), tf.control_dependencies(None): D_loss, loss_fake, loss_real, loss_gp = dnnlib.util.call_func_by_name(E=E_gpu, G=G_gpu, D=D_gpu, Inv=Inv_gpu, real_portraits=real_portraits_gpu, shuffled_portraits=shuffled_portraits_gpu, real_landmarks=real_landmarks_gpu, training_flag=placeholder_training_flag, **D_loss_args) # change signature in ... 
D_loss_real += loss_real D_loss_fake += loss_fake D_loss_grad += loss_gp with tf.control_dependencies([add_global0]): E_opt.register_gradients(E_loss, E_gpu.trainables) D_opt.register_gradients(D_loss, D_gpu.trainables) E_loss_rec /= submit_config.num_gpus E_loss_adv /= submit_config.num_gpus D_loss_real /= submit_config.num_gpus D_loss_fake /= submit_config.num_gpus D_loss_grad /= submit_config.num_gpus E_train_op = E_opt.apply_updates() D_train_op = D_opt.apply_updates() print('building testing graph...') fake_X_val = test(E, Gs, Inv, placeholder_real_portraits_test, placeholder_real_landmarks_test, placeholder_real_shuffled_test, submit_config) inv_X_val = test_inversion(E, Gs, Inv, placeholder_real_portraits_test, placeholder_real_landmarks_test, placeholder_real_shuffled_test, submit_config) #sampled_portraits_val = sample_random_portraits(Gs, submit_config.batch_size) #sampled_portraits_val_test = sample_random_portraits(Gs, submit_config.batch_size_test) sess = tf.get_default_session() print('Getting training data...') # x_batch is a batch of (2, ..., ..., ...) records! stack_batch_train = get_train_data(sess, data_dir=dataset_args.data_train, submit_config=submit_config, mode='train') stack_batch_test = get_train_data(sess, data_dir=dataset_args.data_test, submit_config=submit_config, mode='test') stack_batch_train_secondary = get_train_data(sess, data_dir=dataset_args.data_train, submit_config=submit_config, mode='train_secondary') stack_batch_test_secondary = get_train_data(sess, data_dir=dataset_args.data_test, submit_config=submit_config, mode='test_secondary') summary_log = tf.summary.FileWriter(config.getGdrivePath()) cur_nimg = start * submit_config.batch_size cur_tick = 0 tick_start_nimg = cur_nimg start_time = time.time() init_fix = tf.initialize_variables( [global_step0], name='init_fix' ) sess.run(init_fix) print('Optimization starts!!!') # here is the actual training loop: all iterations for it in range(start, max_iters): batch_stacks = sess.run(stack_batch_train) batch_portraits = batch_stacks[:,0,:,:,:] batch_landmarks = batch_stacks[:,1,:,:,:] batch_stacks_secondary = sess.run(stack_batch_train_secondary) batch_shuffled = batch_stacks_secondary[:,0,:,:,:] batch_lm_shuffled = batch_stacks_secondary[:,1,:,:,:] training_flag = "pose" feed_dict_1 = {placeholder_real_portraits_train: batch_portraits, placeholder_real_landmarks_train: batch_landmarks, placeholder_real_shuffled_train:batch_shuffled, placeholder_landmarks_shuffled_train:batch_lm_shuffled, placeholder_training_flag: training_flag} # here we query these encoder- and discriminator losses. as input we provide: batch_stacks = batch of images + landmarks. 
_, recon_, adv_ = sess.run([E_train_op, E_loss_rec, E_loss_adv], feed_dict_1) _, d_r_, d_f_, d_g_= sess.run([D_train_op, D_loss_real, D_loss_fake, D_loss_grad], feed_dict_1) cur_nimg += submit_config.batch_size if it % 50 == 0: print('Iter: %06d recon_loss: %-6.4f adv_loss: %-6.4f d_r_loss: %-6.4f d_f_loss: %-6.4f d_reg: %-6.4f time:%-12s' % ( it, recon_, adv_, d_r_, d_f_, d_g_, dnnlib.util.format_time(time.time() - start_time))) sys.stdout.flush() tflib.autosummary.save_summaries(summary_log, it) if it % 500 == 0: batch_stacks_test = sess.run(stack_batch_test) batch_portraits_test = batch_stacks_test[:,0,:,:,:] batch_landmarks_test = batch_stacks_test[:,1,:,:,:] batch_stacks_test_secondary = sess.run(stack_batch_test_secondary) batch_shuffled_test = batch_stacks_test_secondary[:,0,:,:,:] batch_shuffled_lm_test = batch_stacks_test_secondary[:,1,:,:,:] batch_portraits_test = misc.adjust_dynamic_range(batch_portraits_test.astype(np.float32), [0, 255], [-1., 1.]) batch_landmarks_test = misc.adjust_dynamic_range(batch_landmarks_test.astype(np.float32), [0, 255], [-1., 1.]) batch_shuffled_test = misc.adjust_dynamic_range(batch_shuffled_test.astype(np.float32), [0, 255], [-1., 1.]) batch_shuffled_lm_test = misc.adjust_dynamic_range(batch_shuffled_lm_test.astype(np.float32), [0, 255], [-1., 1.]) # first: input portrait + shuffled landmarks = manipulated image samples_manipulated = sess.run(fake_X_val, feed_dict={placeholder_real_portraits_test: batch_portraits_test, placeholder_real_landmarks_test: batch_shuffled_lm_test}) # second: manipulated image + original landmarks = cycle reconstruction samples_reconstructed = sess.run(fake_X_val, feed_dict={placeholder_real_portraits_test: samples_manipulated, placeholder_real_landmarks_test: batch_landmarks_test}) # also: show the direct reconstruction samples_direct_rec = sess.run(fake_X_val, feed_dict={placeholder_real_portraits_test: batch_portraits_test, placeholder_real_landmarks_test: batch_landmarks_test}) # show results of the inversion portraits_inverted = sess.run(inv_X_val, feed_dict={placeholder_real_portraits_test: batch_portraits_test, placeholder_real_landmarks_test: batch_landmarks_test}) # grid rows: original landmarks, original portraits, direct reconstruction, shuffled landmarks, manipulated images, cycle reconstruction / inversion debug_img = np.concatenate([ batch_landmarks_test, # original landmarks batch_portraits_test, # original portraits samples_direct_rec, # direct reconstruction batch_shuffled_lm_test, # shuffled landmarks samples_manipulated, # manipulated images samples_reconstructed, portraits_inverted # cycle reconstructed and inverted images ], axis=0) debug_img = adjust_pixel_range(debug_img) debug_img = fuse_images(debug_img, row=6, col=submit_config.batch_size_test) # save image results during training save_image('%s/iter_%08d.png' % (submit_config.run_dir, cur_nimg), debug_img) # save image to gdrive img_path = os.path.join(config.getGdrivePath(), 'images', ('iter_%08d.png' % (cur_nimg))) save_image(img_path, debug_img) if cur_nimg >= tick_start_nimg + 65000: cur_tick += 1 tick_start_nimg = cur_nimg if cur_tick % network_snapshot_ticks == 0: pkl = os.path.join(submit_config.run_dir, 'network-snapshot-%08d.pkl' % (cur_nimg)) misc.save_pkl((E, G, D, Gs), pkl) # save network snapshot to gdrive pkl_drive = os.path.join(config.getGdrivePath(), 'snapshots', 'network-snapshot-%08d.pkl' % (cur_nimg)) misc.save_pkl((E, G, D, Gs), pkl_drive) misc.save_pkl((E, G, D, Gs), os.path.join(submit_config.run_dir, 'network-final.pkl')) summary_log.close()
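# --- Illustrative sketch (not part of the original code) ---------------------
# A minimal NumPy example of two patterns used in the loop above: slicing the
# stacked (portrait, landmark) batches along axis 1, and linearly remapping
# pixel values from [0, 255] to [-1, 1] before feeding them to the graph.
# Array shapes and the helper name are assumptions for illustration only.
import numpy as np

def adjust_dynamic_range_np(data, drange_in=(0, 255), drange_out=(-1.0, 1.0)):
    # Linear remap from the input dynamic range to the output range.
    scale = (drange_out[1] - drange_out[0]) / (drange_in[1] - drange_in[0])
    return (data - drange_in[0]) * scale + drange_out[0]

batch_stacks_demo = np.random.randint(0, 256, size=(4, 2, 3, 128, 128)).astype(np.float32)
portraits_demo = batch_stacks_demo[:, 0]   # (N, 3, H, W) portrait images
landmarks_demo = batch_stacks_demo[:, 1]   # (N, 3, H, W) landmark renderings
portraits_demo = adjust_dynamic_range_np(portraits_demo)
print(portraits_demo.min(), portraits_demo.max())   # approximately -1.0 and 1.0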
def training_loop( G_args={}, # Options for generator network. D_args={}, # Options for discriminator network. G_opt_args={}, # Options for generator optimizer. D_opt_args={}, # Options for discriminator optimizer. AE_opt_args=None, # Options for autoencoder optimizer. G_loss_args={}, # Options for generator loss. D_loss_args={}, # Options for discriminator loss. AE_loss_args=None, # Options for autoencoder loss. dataset_args={}, # Options for dataset.load_dataset(). dataset_args_eval={}, # Options for dataset.load_dataset(). sched_args={}, # Options for train.TrainingSchedule. grid_args={}, # Options for train.setup_snapshot_image_grid(). metric_arg_list=[], # Options for MetricGroup. tf_config={}, # Options for tflib.init_tf(). train_data_dir=None, # Directory to load datasets from. eval_data_dir=None, # Directory to load datasets from. G_smoothing_kimg=10.0, # Half-life of the running average of generator weights. minibatch_repeats=4, # Number of minibatches to run before adjusting training parameters. lazy_regularization=True, # Perform regularization as a separate training step? G_reg_interval=4, # How often the perform regularization for G? Ignored if lazy_regularization=False. D_reg_interval=16, # How often the perform regularization for D? Ignored if lazy_regularization=False. reset_opt_for_new_lod=True, # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced? total_kimg=25000, # Total length of the training, measured in thousands of real images. mirror_augment=False, # Enable mirror augment? drange_net=[ -1, 1 ], # Dynamic range used when feeding image data to the networks. image_snapshot_ticks=50, # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'. network_snapshot_ticks=50, # How often to save network snapshots? None = only save 'networks-final.pkl'. save_tf_graph=True, # Include full TensorFlow computation graph in the tfevents file? save_weight_histograms=True, # Include weight histograms in the tfevents file? resume_pkl=None, # Network pickle to resume training from, None = train from scratch. resume_kimg=0.0, # Assumed training progress at the beginning. Affects reporting and training schedule. resume_time=0.0, # Assumed wallclock time at the beginning. Affects reporting. resume_with_new_nets=False, resume_with_own_vars=False ): # Construct new networks according to G_args and D_args before resuming training? # Initialize dnnlib and TensorFlow. tflib.init_tf(tf_config) num_gpus = dnnlib.submit_config.num_gpus # Load training set. print("Loading train set from %s..." % dataset_args.tfrecord_dir) training_set = dataset.load_dataset( data_dir=dnnlib.convert_path(train_data_dir), verbose=True, **dataset_args) print("Loading eval set from %s..." % dataset_args_eval.tfrecord_dir) eval_set = dataset.load_dataset( data_dir=dnnlib.convert_path(eval_data_dir), verbose=True, **dataset_args_eval) grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid( training_set, **grid_args) misc.save_image_grid(grid_reals, dnnlib.make_run_dir_path('reals.png'), drange=training_set.dynamic_range, grid_size=grid_size) # Freeze Discriminator if D_args['freeze']: num_layers = np.log2(training_set.resolution) - 1 layers = int(np.round(num_layers * 3. / 8.)) scope = ['Output', 'scores_out'] for layer in range(layers): scope += ['.*%d' % 2**layer] if 'train_scope' in D_args: scope[-1] += '.*%d' % D_args['train_scope'] D_args['train_scope'] = scope # Construct or load networks. 
with tf.device('/gpu:0'): if resume_pkl == '' or resume_with_new_nets or resume_with_own_vars: print('Constructing networks...') G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args) D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args) Gs = G.clone('Gs') if resume_pkl != '': print('Loading networks from "%s"...' % resume_pkl) rG, rD, rGs = misc.load_pkl(resume_pkl) if resume_with_new_nets: G.copy_vars_from(rG) D.copy_vars_from(rD) Gs.copy_vars_from(rGs) else: G = rG D = rD Gs = rGs grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:]) # SVD stuff if 'syn_svd' in G_args or 'map_svd' in G_args: # Run graph to calculate SVD grid_latents_smol = grid_latents[:1] rho = np.array([1]) grid_fakes = G.run(grid_latents_smol, grid_labels, rho, is_validation=True) grid_fakes = Gs.run(grid_latents_smol, grid_labels, rho, is_validation=True) load_d_fake = D.run(grid_reals[:1], rho, is_validation=True) with tf.device('/gpu:0'): # Create SVD-decomposed graph rG, rD, rGs = G, D, Gs G_lambda_mask = { var: np.ones(G.vars[var].shape[-1]) for var in G.vars if 'SVD/s' in var } D_lambda_mask = { 'D/' + var: np.ones(D.vars[var].shape[-1]) for var in D.vars if 'SVD/s' in var } G_reduce_dims = { var: (0, int(Gs.vars[var].shape[-1])) for var in Gs.vars if 'SVD/s' in var } G_args['lambda_mask'] = G_lambda_mask G_args['reduce_dims'] = G_reduce_dims D_args['lambda_mask'] = D_lambda_mask # Create graph with no SVD operations G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=rG.input_shapes[1][1], factorized=True, **G_args) D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=rD.input_shapes[1][1], factorized=True, **D_args) Gs = G.clone('Gs') grid_fakes = G.run(grid_latents_smol, grid_labels, rho, is_validation=True, minibatch_size=1) grid_fakes = Gs.run(grid_latents_smol, grid_labels, rho, is_validation=True, minibatch_size=1) G.copy_vars_from(rG) D.copy_vars_from(rD) Gs.copy_vars_from(rGs) # Reduce per-gpu minibatch size to fit in 16GB GPU memory if grid_reals.shape[2] >= 1024: sched_args.minibatch_gpu_base = 2 print('Batch size', sched_args.minibatch_gpu_base) # Generate initial image snapshot. G.print_layers() D.print_layers() sched = training_schedule(cur_nimg=total_kimg * 1000, training_set=training_set, **sched_args) grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:]) rho = np.array([1]) grid_fakes = Gs.run(grid_latents, grid_labels, rho, is_validation=True, minibatch_size=sched.minibatch_gpu) misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes_init.png'), drange=drange_net, grid_size=grid_size) if resume_pkl != '': load_d_real = rD.run(grid_reals[:1], rho, is_validation=True) load_d_fake = rD.run(grid_fakes[:1], rho, is_validation=True) d_fake = D.run(grid_fakes[:1], rho, is_validation=True) d_real = D.run(grid_reals[:1], rho, is_validation=True) print('Factorized fake', d_fake, 'loaded fake', load_d_fake, 'factorized real', d_real, 'loaded real', load_d_real) print('(should match)') # Setup training inputs.
print('Building TensorFlow graph...') with tf.name_scope('Inputs'), tf.device('/cpu:0'): lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[]) lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[]) minibatch_size_in = tf.placeholder(tf.int32, name='minibatch_size_in', shape=[]) minibatch_gpu_in = tf.placeholder(tf.int32, name='minibatch_gpu_in', shape=[]) minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in * num_gpus) Gs_beta = 0.5**tf.div(tf.cast(minibatch_size_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0 # Setup optimizers. G_opt_args = dict(G_opt_args) D_opt_args = dict(D_opt_args) for args, reg_interval in [(G_opt_args, G_reg_interval), (D_opt_args, D_reg_interval)]: args['minibatch_multiplier'] = minibatch_multiplier args['learning_rate'] = lrate_in if lazy_regularization: mb_ratio = reg_interval / (reg_interval + 1) args['learning_rate'] *= mb_ratio if 'beta1' in args: args['beta1'] **= mb_ratio if 'beta2' in args: args['beta2'] **= mb_ratio G_opt = tflib.Optimizer(name='TrainG', **G_opt_args) D_opt = tflib.Optimizer(name='TrainD', **D_opt_args) G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args) D_reg_opt = tflib.Optimizer(name='RegD', share=D_opt, **D_opt_args) if AE_opt_args is not None: AE_opt_args = dict(AE_opt_args) AE_opt_args['minibatch_multiplier'] = minibatch_multiplier AE_opt_args['learning_rate'] = lrate_in AE_opt = tflib.Optimizer(name='TrainAE', **AE_opt_args) # Build training graph for each GPU. data_fetch_ops = [] for gpu in range(num_gpus): with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu): # Create GPU-specific shadow copies of G and D. G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow') D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow') # Fetch training data via temporary variables. with tf.name_scope('DataFetch'): sched = training_schedule(cur_nimg=int(resume_kimg * 1000), training_set=training_set, **sched_args) reals_var = tf.Variable( name='reals', trainable=False, initial_value=tf.zeros([sched.minibatch_gpu] + training_set.shape)) labels_var = tf.Variable(name='labels', trainable=False, initial_value=tf.zeros([ sched.minibatch_gpu, training_set.label_size ])) reals_write, labels_write = training_set.get_minibatch_tf() reals_write, labels_write = process_reals( reals_write, labels_write, lod_in, mirror_augment, training_set.dynamic_range, drange_net) reals_write = tf.concat( [reals_write, reals_var[minibatch_gpu_in:]], axis=0) labels_write = tf.concat( [labels_write, labels_var[minibatch_gpu_in:]], axis=0) data_fetch_ops += [tf.assign(reals_var, reals_write)] data_fetch_ops += [tf.assign(labels_var, labels_write)] reals_read = reals_var[:minibatch_gpu_in] labels_read = labels_var[:minibatch_gpu_in] # Evaluate loss functions. lod_assign_ops = [] if 'lod' in G_gpu.vars: lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)] if 'lod' in D_gpu.vars: lod_assign_ops += [tf.assign(D_gpu.vars['lod'], lod_in)] with tf.control_dependencies(lod_assign_ops): with tf.name_scope('G_loss'): if G_loss_args['func_name'] == 'training.loss.G_l1': G_loss_args['reals'] = reals_read else: G_loss, G_reg = dnnlib.util.call_func_by_name( G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, **G_loss_args) with tf.name_scope('D_loss'): D_loss, D_reg = dnnlib.util.call_func_by_name( G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, reals=reals_read, labels=labels_read, **D_loss_args) # Register gradients. 
if not lazy_regularization: if G_reg is not None: G_loss += G_reg if D_reg is not None: D_loss += D_reg else: if G_reg is not None: G_reg_opt.register_gradients( tf.reduce_mean(G_reg * G_reg_interval), G_gpu.trainables) if D_reg is not None: D_reg_opt.register_gradients( tf.reduce_mean(D_reg * D_reg_interval), D_gpu.trainables) G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables) D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables) # Setup training ops. data_fetch_op = tf.group(*data_fetch_ops) G_train_op = G_opt.apply_updates() D_train_op = D_opt.apply_updates() G_reg_op = G_reg_opt.apply_updates(allow_no_op=True) D_reg_op = D_reg_opt.apply_updates(allow_no_op=True) Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta) # Finalize graph. with tf.device('/gpu:0'): try: peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse() except tf.errors.NotFoundError: peak_gpu_mem_op = tf.constant(0) tflib.init_uninitialized_vars() print('Initializing logs...') summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path()) if save_tf_graph: summary_log.add_graph(tf.get_default_graph()) if save_weight_histograms: G.setup_weight_histograms() D.setup_weight_histograms() metrics = metric_base.MetricGroup(metric_arg_list) print('Training for %d kimg...\n' % total_kimg) dnnlib.RunContext.get().update('', cur_epoch=resume_kimg, max_epoch=total_kimg) maintenance_time = dnnlib.RunContext.get().get_last_update_interval() cur_nimg = int(resume_kimg * 1000) cur_tick = -1 tick_start_nimg = cur_nimg prev_lod = -1.0 running_mb_counter = 0 while cur_nimg < total_kimg * 1000: if dnnlib.RunContext.get().should_stop(): break # Choose training parameters and configure training ops. sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, **sched_args) assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0 training_set.configure(sched.minibatch_gpu, sched.lod) if reset_opt_for_new_lod: if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil( sched.lod) != np.ceil(prev_lod): G_opt.reset_optimizer_state() D_opt.reset_optimizer_state() prev_lod = sched.lod # Run training ops feed_dict = { lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_size_in: sched.minibatch_size, minibatch_gpu_in: sched.minibatch_gpu } for _repeat in range(minibatch_repeats): rounds = range(0, sched.minibatch_size, sched.minibatch_gpu * num_gpus) ae_iter_mul = 10 ae_rounds = range(0, sched.minibatch_size, sched.minibatch_gpu * num_gpus * ae_iter_mul) run_G_reg = (lazy_regularization and running_mb_counter % G_reg_interval == 0) run_D_reg = (lazy_regularization and running_mb_counter % D_reg_interval == 0) cur_nimg += sched.minibatch_size running_mb_counter += 1 # Fast path without gradient accumulation. if len(rounds) == 1: tflib.run([G_train_op, data_fetch_op], feed_dict) if run_G_reg: tflib.run(G_reg_op, feed_dict) tflib.run([D_train_op, Gs_update_op], feed_dict) if run_D_reg: tflib.run(D_reg_op, feed_dict) # Slow path with gradient accumulation. else: for _round in rounds: _g_loss, _ = tflib.run([G_loss, G_train_op], feed_dict) if run_G_reg: for _round in rounds: tflib.run(G_reg_op, feed_dict) tflib.run(Gs_update_op, feed_dict) for _round in rounds: tflib.run(data_fetch_op, feed_dict) tflib.run(D_train_op, feed_dict) if run_D_reg: for _round in rounds: tflib.run(D_reg_op, feed_dict) # Perform maintenance tasks once per tick. 
done = (cur_nimg >= total_kimg * 1000) if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done: cur_tick += 1 tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0 tick_start_nimg = cur_nimg tick_time = dnnlib.RunContext.get().get_time_since_last_update() total_time = dnnlib.RunContext.get().get_time_since_start( ) + resume_time # Report progress. print( 'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f' % (autosummary('Progress/tick', cur_tick), autosummary('Progress/kimg', cur_nimg / 1000.0), autosummary('Progress/lod', sched.lod), autosummary('Progress/minibatch', sched.minibatch_size), dnnlib.util.format_time( autosummary('Timing/total_sec', total_time)), autosummary('Timing/sec_per_tick', tick_time), autosummary('Timing/sec_per_kimg', tick_time / tick_kimg), autosummary('Timing/maintenance_sec', maintenance_time), autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30))) autosummary('Timing/total_hours', total_time / (60.0 * 60.0)) autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0)) # Save snapshots. if image_snapshot_ticks is not None and ( cur_tick % image_snapshot_ticks == 0 or done): print('g loss', _g_loss) grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu) misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path( 'fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size) if network_snapshot_ticks is not None and cur_tick % network_snapshot_ticks == 0 or done: pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' % (cur_nimg // 1000)) misc.save_pkl((G, D, Gs), pkl) metrics.run(pkl, run_dir=dnnlib.make_run_dir_path(), data_dir=dnnlib.convert_path(eval_data_dir), num_gpus=num_gpus, tf_config=tf_config, rho=rho) # Update summaries and RunContext. metrics.update_autosummaries() tflib.autosummary.save_summaries(summary_log, cur_nimg) dnnlib.RunContext.get().update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg) maintenance_time = dnnlib.RunContext.get( ).get_last_update_interval() - tick_time # Save final snapshot. misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path('network-final.pkl')) # All done. summary_log.close() training_set.close() eval_set.close()
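# --- Illustrative sketch (not part of the original code) ---------------------
# A plain-Python example of the lazy-regularization bookkeeping used above:
# optimizer hyperparameters are rescaled by mb_ratio = interval / (interval + 1),
# and the regularization op only runs every reg_interval minibatches. The
# numeric values below are assumptions for illustration only.
G_reg_interval_demo, D_reg_interval_demo = 4, 16
learning_rate_demo, beta2_demo = 0.002, 0.99
for name, interval in [('G', G_reg_interval_demo), ('D', D_reg_interval_demo)]:
    mb_ratio = interval / (interval + 1)
    print(name, 'lr:', learning_rate_demo * mb_ratio, 'beta2:', beta2_demo ** mb_ratio)
for running_mb_counter in range(8):
    run_G_reg = running_mb_counter % G_reg_interval_demo == 0
    run_D_reg = running_mb_counter % D_reg_interval_demo == 0
    print('minibatch', running_mb_counter, 'G_reg:', run_G_reg, 'D_reg:', run_D_reg)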
def training_loop( submit_config, G_args = {}, # Options for the generator network. D_args = {}, # Options for the discriminator network. G_opt_args = {}, # Options for the generator optimizer. D_opt_args = {}, # Options for the discriminator optimizer. G_loss_args = {}, # Options for the generator loss. D_loss_args = {}, # Options for the discriminator loss. dataset_args = {}, # Options for the dataset. sched_args = {}, # Options for the training schedule. grid_args = {}, # Options for setup_snapshot_image_grid(). metric_arg_list = [], # Options for the metrics. tf_config = {}, # Options for tflib.init_tf(). G_smoothing_kimg = 10.0, # Half-life of the running average of generator weights. D_repeats = 1, # How many times to train the discriminator per G iteration. minibatch_repeats = 4, # Number of minibatches to run before adjusting training parameters. reset_opt_for_new_lod = True, # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced? total_kimg = 15000, # Total length of the training, measured in thousands of real images. mirror_augment = False, # Enable mirror augment? drange_net = [-1,1], # Dynamic range used when feeding image data to the networks. image_snapshot_ticks = 1, # How often to export image snapshots? network_snapshot_ticks = 10, # How often to export network snapshots? save_tf_graph = False, # Include full TensorFlow computation graph in the tfevents file? save_weight_histograms = False, # Include weight histograms in the tfevents file? resume_run_id = None, # Run ID or network pkl to resume training from, None = start from scratch. resume_snapshot = None, # Snapshot index to resume training from, None = autodetect. resume_kimg = 0.0, # Assumed training progress at the beginning. Affects reporting and training schedule. resume_time = 0.0): # Assumed wallclock time at the beginning. Affects reporting. # Initialize dnnlib and TensorFlow. ctx = dnnlib.RunContext(submit_config, train) tflib.init_tf(tf_config) # Load training set. training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args) # Construct networks. with tf.device('/gpu:0'): if resume_run_id is not None: network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot) print('Loading networks from "%s"...' % network_pkl) G, D, Gs = misc.load_pkl(network_pkl) else: print('Constructing networks...') G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args) D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args) Gs = G.clone('Gs') G.print_layers(); D.print_layers() # Build the computation graph and optimizers. print('Building TensorFlow graph...') with tf.name_scope('Inputs'), tf.device('/cpu:0'): lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[]) # tf.placeholder: a formal parameter of the graph; the concrete value is fed in at execution time. lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[]) minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[]) minibatch_split = minibatch_in // submit_config.num_gpus Gs_beta = 0.5 ** tf.div(tf.cast(minibatch_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0 G_opt = tflib.Optimizer(name='TrainG', learning_rate=lrate_in, **G_opt_args) D_opt = tflib.Optimizer(name='TrainD', learning_rate=lrate_in, **D_opt_args) for gpu in range(submit_config.num_gpus): with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu): G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow') D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow') lod_assign_ops = [tf.assign(G_gpu.find_var('lod'), lod_in), tf.assign(D_gpu.find_var('lod'), lod_in)] reals, labels = training_set.get_minibatch_tf() reals = process_reals(reals, lod_in, mirror_augment, training_set.dynamic_range, drange_net) with tf.name_scope('G_loss'), tf.control_dependencies(lod_assign_ops): G_loss = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_split, **G_loss_args) with tf.name_scope('D_loss'), tf.control_dependencies(lod_assign_ops): D_loss = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_split, reals=reals, labels=labels, **D_loss_args) G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables) G_train_op = G_opt.apply_updates() D_train_op = D_opt.apply_updates() Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta) with tf.device('/gpu:0'): try: peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse() except tf.errors.NotFoundError: peak_gpu_mem_op = tf.constant(0) # Set up the snapshot image grid. print('Setting up snapshot image grid...') grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(G, training_set, **grid_args) sched = training_schedule(cur_nimg=total_kimg*1000, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args) grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch//submit_config.num_gpus) # Set up the run directory. print('Setting up run dir...') misc.save_image_grid(grid_reals, os.path.join(submit_config.run_dir, 'reals.png'), drange=training_set.dynamic_range, grid_size=grid_size) misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg), drange=drange_net, grid_size=grid_size) summary_log = tf.summary.FileWriter(submit_config.run_dir) if save_tf_graph: summary_log.add_graph(tf.get_default_graph()) if save_weight_histograms: G.setup_weight_histograms(); D.setup_weight_histograms() metrics = metric_base.MetricGroup(metric_arg_list) # Train. print('Training...\n') ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg) maintenance_time = ctx.get_last_update_interval() cur_nimg = int(resume_kimg * 1000) cur_tick = 0 tick_start_nimg = cur_nimg prev_lod = -1.0 while cur_nimg < total_kimg * 1000: if ctx.should_stop(): break # Choose training parameters and configure training ops. sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args) training_set.configure(sched.minibatch // submit_config.num_gpus, sched.lod) if reset_opt_for_new_lod: if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod): G_opt.reset_optimizer_state(); D_opt.reset_optimizer_state() prev_lod = sched.lod # Run training ops. for _mb_repeat in range(minibatch_repeats): for _D_repeat in range(D_repeats): tflib.run([D_train_op, Gs_update_op], {lod_in: sched.lod, lrate_in: sched.D_lrate, minibatch_in: sched.minibatch}) cur_nimg += sched.minibatch tflib.run([G_train_op], {lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_in: sched.minibatch}) # Perform maintenance tasks once per tick. done = (cur_nimg >= total_kimg * 1000) if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done: cur_tick += 1 tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0 tick_start_nimg = cur_nimg tick_time = ctx.get_time_since_last_update() total_time = ctx.get_time_since_start() + resume_time # Report progress. print('tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f' % ( autosummary('Progress/tick', cur_tick), autosummary('Progress/kimg', cur_nimg / 1000.0), autosummary('Progress/lod', sched.lod), autosummary('Progress/minibatch', sched.minibatch), dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)), autosummary('Timing/sec_per_tick', tick_time), autosummary('Timing/sec_per_kimg', tick_time / tick_kimg), autosummary('Timing/maintenance_sec', maintenance_time), autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30))) autosummary('Timing/total_hours', total_time / (60.0 * 60.0)) autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0)) # Save snapshots. if cur_tick % image_snapshot_ticks == 0 or done: grid_fakes = Gs.run(grid_latents,
grid_labels, is_validation=True, minibatch_size=sched.minibatch//submit_config.num_gpus) misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size) if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1: pkl = os.path.join(submit_config.run_dir, 'network-snapshot-%06d.pkl' % (cur_nimg // 1000)) misc.save_pkl((G, D, Gs), pkl) metrics.run(pkl, run_dir=submit_config.run_dir, num_gpus=submit_config.num_gpus, tf_config=tf_config) # Update summaries and RunContext. metrics.update_autosummaries() tflib.autosummary.save_summaries(summary_log, cur_nimg) ctx.update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg) maintenance_time = ctx.get_last_update_interval() - tick_time # Save the final result. misc.save_pkl((G, D, Gs), os.path.join(submit_config.run_dir, 'network-final.pkl')) summary_log.close() ctx.close()
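# --- Illustrative sketch (not part of the original code) ---------------------
# The Gs_beta used above implements a half-life measured in images: after
# G_smoothing_kimg * 1000 real images, the moving-average generator weights
# have decayed by a factor of 0.5. Example values below are assumptions.
G_smoothing_kimg_demo = 10.0
minibatch_demo = 32
Gs_beta_demo = 0.5 ** (minibatch_demo / (G_smoothing_kimg_demo * 1000.0))
print(Gs_beta_demo)                                                       # ~0.99778 per update
print(Gs_beta_demo ** (G_smoothing_kimg_demo * 1000.0 / minibatch_demo))  # exactly 0.5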
def training_loop( run_dir='.', # Output directory. G_args={}, # Options for generator network. D_args={}, # Options for discriminator network. G_opt_args={}, # Options for generator optimizer. D_opt_args={}, # Options for discriminator optimizer. loss_args={}, # Options for loss function. train_dataset_args={}, # Options for dataset to train with. # Options for dataset to evaluate metrics against. metric_dataset_args={}, augment_args={}, # Options for adaptive augmentations. metric_arg_list=[], # Metrics to evaluate during training. num_gpus=1, # Number of GPUs to use. minibatch_size=32, # Global minibatch size. minibatch_gpu=4, # Number of samples processed at a time by one GPU. # Half-life of the exponential moving average (EMA) of generator weights. G_smoothing_kimg=10, G_smoothing_rampup=None, # EMA ramp-up coefficient. # Number of minibatches to run in the inner loop. minibatch_repeats=4, lazy_regularization=True, # Perform regularization as a separate training step? # How often the perform regularization for G? Ignored if lazy_regularization=False. G_reg_interval=4, # How often the perform regularization for D? Ignored if lazy_regularization=False. D_reg_interval=16, # Total length of the training, measured in thousands of real images. total_kimg=25000, kimg_per_tick=4, # Progress snapshot interval. # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'. image_snapshot_ticks=50, # How often to save network snapshots? None = only save 'networks-final.pkl'. network_snapshot_ticks=50, resume_pkl=None, # Network pickle to resume training from. # Callback function for determining whether to abort training. abort_fn=None, progress_fn=None, # Callback function for updating training progress. ): assert minibatch_size % (num_gpus * minibatch_gpu) == 0 start_time = time.time() print('Loading training set...') training_set = dataset.load_dataset(**train_dataset_args) print('Image shape:', np.int32(training_set.shape).tolist()) print('Label shape:', [training_set.label_size]) print() print('Constructing networks...') with tf.device('/gpu:0'): G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args) D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args) Gs = G.clone('Gs') if resume_pkl is not None: print(f'Resuming from "{resume_pkl}"') with dnnlib.util.open_url(resume_pkl) as f: rG, rD, rGs = pickle.load(f) G.copy_vars_from(rG) D.copy_vars_from(rD) Gs.copy_vars_from(rGs) G.print_layers() D.print_layers() print('Exporting sample images...') grid_size, grid_reals, grid_labels = setup_snapshot_image_grid( training_set) save_image_grid(grid_reals, os.path.join(run_dir, 'reals.png'), drange=[0, 255], grid_size=grid_size) grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:]) grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=minibatch_gpu) # save_image_grid(grid_fakes, os.path.join( # run_dir, 'fakes_init.png'), drange=[-1, 1], grid_size=grid_size) print(f'Replicating networks across {num_gpus} GPUs...') G_gpus = [G] D_gpus = [D] for gpu in range(1, num_gpus): with tf.device(f'/gpu:{gpu}'): G_gpus.append(G.clone(f'{G.name}_gpu{gpu}')) D_gpus.append(D.clone(f'{D.name}_gpu{gpu}')) print('Initializing augmentations...') aug = None if augment_args.get('class_name', None) is not None: aug = dnnlib.util.construct_class_by_name(**augment_args) 
aug.init_validation_set(D_gpus=D_gpus, training_set=training_set) print('Setting up optimizers...') G_opt_args = dict(G_opt_args) D_opt_args = dict(D_opt_args) for args, reg_interval in [(G_opt_args, G_reg_interval), (D_opt_args, D_reg_interval)]: args[ 'minibatch_multiplier'] = minibatch_size // num_gpus // minibatch_gpu if lazy_regularization: mb_ratio = reg_interval / (reg_interval + 1) args['learning_rate'] *= mb_ratio if 'beta1' in args: args['beta1'] **= mb_ratio if 'beta2' in args: args['beta2'] **= mb_ratio G_opt = tflib.Optimizer(name='TrainG', **G_opt_args) D_opt = tflib.Optimizer(name='TrainD', **D_opt_args) G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args) D_reg_opt = tflib.Optimizer(name='RegD', share=D_opt, **D_opt_args) print('Constructing training graph...') data_fetch_ops = [] training_set.configure(minibatch_gpu) for gpu, (G_gpu, D_gpu) in enumerate(zip(G_gpus, D_gpus)): with tf.name_scope(f'Train_gpu{gpu}'), tf.device(f'/gpu:{gpu}'): # Fetch training data via temporary variables. with tf.name_scope('DataFetch'): real_images_var = tf.Variable( name='images', trainable=False, initial_value=tf.zeros([minibatch_gpu] + training_set.shape)) real_labels_var = tf.Variable(name='labels', trainable=False, initial_value=tf.zeros([ minibatch_gpu, training_set.label_size ])) real_images_write, real_labels_write = training_set.get_minibatch_tf( ) real_images_write = tflib.convert_images_from_uint8( real_images_write) data_fetch_ops += [ tf.assign(real_images_var, real_images_write) ] data_fetch_ops += [ tf.assign(real_labels_var, real_labels_write) ] # Evaluate loss function and register gradients. fake_labels = training_set.get_random_labels_tf(minibatch_gpu) terms = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, aug=aug, fake_labels=fake_labels, real_images=real_images_var, real_labels=real_labels_var, **loss_args) if lazy_regularization: if terms.G_reg is not None: G_reg_opt.register_gradients( tf.reduce_mean(terms.G_reg * G_reg_interval), G_gpu.trainables) if terms.D_reg is not None: D_reg_opt.register_gradients( tf.reduce_mean(terms.D_reg * D_reg_interval), D_gpu.trainables) else: if terms.G_reg is not None: terms.G_loss += terms.G_reg if terms.D_reg is not None: terms.D_loss += terms.D_reg G_opt.register_gradients(tf.reduce_mean(terms.G_loss), G_gpu.trainables) D_opt.register_gradients(tf.reduce_mean(terms.D_loss), D_gpu.trainables) print('Finalizing training ops...') data_fetch_op = tf.group(*data_fetch_ops) G_train_op = G_opt.apply_updates() D_train_op = D_opt.apply_updates() G_reg_op = G_reg_opt.apply_updates(allow_no_op=True) D_reg_op = D_reg_opt.apply_updates(allow_no_op=True) Gs_beta_in = tf.placeholder(tf.float32, name='Gs_beta_in', shape=[]) Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta_in) tflib.init_uninitialized_vars() with tf.device('/gpu:0'): peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse() print('Initializing metrics...') summary_log = tf.summary.FileWriter(run_dir) metrics = [] for args in metric_arg_list: metric = dnnlib.util.construct_class_by_name(**args) metric.configure(dataset_args=metric_dataset_args, run_dir=run_dir) metrics.append(metric) print(f'Training for {total_kimg} kimg...') print() if progress_fn is not None: progress_fn(0, total_kimg) tick_start_time = time.time() maintenance_time = tick_start_time - start_time cur_nimg = 0 cur_tick = -1 tick_start_nimg = cur_nimg running_mb_counter = 0 done = False while not done: # Compute EMA decay parameter. 
Gs_nimg = G_smoothing_kimg * 1000.0 if G_smoothing_rampup is not None: Gs_nimg = min(Gs_nimg, cur_nimg * G_smoothing_rampup) Gs_beta = 0.5**(minibatch_size / max(Gs_nimg, 1e-8)) # Run training ops. for _repeat_idx in range(minibatch_repeats): rounds = range(0, minibatch_size, minibatch_gpu * num_gpus) run_G_reg = (lazy_regularization and running_mb_counter % G_reg_interval == 0) run_D_reg = (lazy_regularization and running_mb_counter % D_reg_interval == 0) cur_nimg += minibatch_size running_mb_counter += 1 # Fast path without gradient accumulation. if len(rounds) == 1: tflib.run([G_train_op, data_fetch_op]) if run_G_reg: tflib.run(G_reg_op) tflib.run([D_train_op, Gs_update_op], {Gs_beta_in: Gs_beta}) if run_D_reg: tflib.run(D_reg_op) # Slow path with gradient accumulation. else: for _round in rounds: tflib.run(G_train_op) if run_G_reg: tflib.run(G_reg_op) tflib.run(Gs_update_op, {Gs_beta_in: Gs_beta}) for _round in rounds: tflib.run(data_fetch_op) tflib.run(D_train_op) if run_D_reg: tflib.run(D_reg_op) # Run validation. if aug is not None: aug.run_validation(minibatch_size=minibatch_size) # Tune augmentation parameters. if aug is not None: aug.tune(minibatch_size * minibatch_repeats) # Perform maintenance tasks once per tick. done = (cur_nimg >= total_kimg * 1000) or (abort_fn is not None and abort_fn()) if done or cur_tick < 0 or cur_nimg >= tick_start_nimg + kimg_per_tick * 1000: cur_tick += 1 tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0 tick_start_nimg = cur_nimg tick_end_time = time.time() total_time = tick_end_time - start_time tick_time = tick_end_time - tick_start_time # Report progress. print(' '.join([ f"tick {autosummary('Progress/tick', cur_tick):<5d}", f"kimg {autosummary('Progress/kimg', cur_nimg / 1000.0):<8.1f}", f"time {dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)):<12s}", f"sec/tick {autosummary('Timing/sec_per_tick', tick_time):<7.1f}", f"sec/kimg {autosummary('Timing/sec_per_kimg', tick_time / tick_kimg):<7.2f}", f"maintenance {autosummary('Timing/maintenance_sec', maintenance_time):<6.1f}", f"gpumem {autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30):<5.1f}", f"augment {autosummary('Progress/augment', aug.strength if aug is not None else 0):.3f}", ])) autosummary('Timing/total_hours', total_time / (60.0 * 60.0)) autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0)) if progress_fn is not None: progress_fn(cur_nimg // 1000, total_kimg) # Save snapshots. if image_snapshot_ticks is not None and ( done or cur_tick % image_snapshot_ticks == 0): grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=minibatch_gpu) save_image_grid(grid_fakes, os.path.join( run_dir, f'fakes{cur_nimg // 1000:06d}.png'), drange=[-1, 1], grid_size=grid_size) if network_snapshot_ticks is not None and ( done or cur_tick % network_snapshot_ticks == 0): pkl = os.path.join( run_dir, f'network-snapshot-{cur_nimg // 1000:06d}.pkl') with open(pkl, 'wb') as f: pickle.dump((G, D, Gs), f) if len(metrics): print('Evaluating metrics...') for metric in metrics: metric.run(pkl, num_gpus=num_gpus) # Update summaries. for metric in metrics: metric.update_autosummaries() tflib.autosummary.save_summaries(summary_log, cur_nimg) tick_start_time = time.time() maintenance_time = tick_start_time - tick_end_time print() print('Exiting...') summary_log.close() training_set.close()
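# --- Illustrative sketch (not part of the original code) ---------------------
# The fast/slow path above depends on how many "rounds" are needed to cover the
# global minibatch with minibatch_gpu samples per GPU: a single round allows one
# fused call per op, while several rounds require gradient accumulation before a
# single apply_updates(). Example values are assumptions for illustration only.
minibatch_size_demo, minibatch_gpu_demo, num_gpus_demo = 32, 4, 2
rounds_demo = range(0, minibatch_size_demo, minibatch_gpu_demo * num_gpus_demo)
if len(rounds_demo) == 1:
    print('fast path: one fused call per training op')
else:
    print('slow path: %d gradient-accumulation rounds per update' % len(rounds_demo))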
def training_loop_vae( G_args={}, # Options for generator network. E_args={}, # Options for encoder network. D_args={}, # Options for discriminator network. G_opt_args={}, # Options for generator optimizer. D_opt_args={}, # Options for discriminator optimizer. G_loss_args={}, # Options for generator loss. D_loss_args={}, # Options for discriminator loss. dataset_args={}, # Options for dataset.load_dataset(). sched_args={}, # Options for train.TrainingSchedule. grid_args={}, # Options for train.setup_snapshot_image_grid(). metric_arg_list=[], # Options for MetricGroup. tf_config={}, # Options for tflib.init_tf(). data_dir=None, # Directory to load datasets from. minibatch_repeats=1, # Number of minibatches to run before adjusting training parameters. total_kimg=25000, # Total length of the training, measured in thousands of real images. mirror_augment=False, # Enable mirror augment? drange_net=[ -1, 1 ], # Dynamic range used when feeding image data to the networks. image_snapshot_ticks=50, # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'. network_snapshot_ticks=50, # How often to save network snapshots? None = only save 'networks-final.pkl'. save_tf_graph=False, # Include full TensorFlow computation graph in the tfevents file? save_weight_histograms=False, # Include weight histograms in the tfevents file? resume_pkl=None, # Network pickle to resume training from, None = train from scratch. resume_kimg=0.0, # Assumed training progress at the beginning. Affects reporting and training schedule. resume_time=0.0, # Assumed wallclock time at the beginning. Affects reporting. resume_with_new_nets=False, # Construct new networks according to G_args and D_args before resuming training? traversal_grid=False, # Used for disentangled representation learning. n_discrete=0, # Number of discrete latents in model. n_continuous=4, # Number of continuous latents in model. topk_dims_to_show=20, # Number of top disentant dimensions to show in a snapshot. subgroup_sizes_ls=None, subspace_sizes_ls=None, forward_eg=False, n_samples_per=10): # Number of samples for each line in traversal. # Initialize dnnlib and TensorFlow. tflib.init_tf(tf_config) num_gpus = dnnlib.submit_config.num_gpus # If use Discriminator. use_D = D_args is not None # Load training set. training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir), verbose=True, **dataset_args) grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid( training_set, **grid_args) grid_fakes = add_outline(grid_reals, width=1) misc.save_image_grid(grid_reals, dnnlib.make_run_dir_path('reals.png'), drange=training_set.dynamic_range, grid_size=grid_size) # Construct or load networks. with tf.device('/gpu:0'): if resume_pkl is None or resume_with_new_nets: print('Constructing networks...') E = tflib.Network('E', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, input_shape=[None] + training_set.shape, **E_args) G = tflib.Network( 'G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, input_shape=[None, n_discrete + G_args.latent_size] if not forward_eg else [ None, n_discrete + G_args.latent_size + sum(subgroup_sizes_ls) ], **G_args) if use_D: D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, input_shape=[None, D_args.latent_size], **D_args) if resume_pkl is not None: print('Loading networks from "%s"...' 
% resume_pkl) if use_D: rE, rG, rD = misc.load_pkl(resume_pkl) else: rE, rG = misc.load_pkl(resume_pkl) if resume_with_new_nets: E.copy_vars_from(rE) G.copy_vars_from(rG) if use_D: D.copy_vars_from(rD) else: E = rE G = rG if use_D: D = rD # Print layers and generate initial image snapshot. E.print_layers() G.print_layers() if use_D: D.print_layers() sched = training_schedule(cur_nimg=total_kimg * 1000, training_set=training_set, **sched_args) if traversal_grid: if topk_dims_to_show > 0: topk_dims = np.arange(min(topk_dims_to_show, n_continuous)) else: topk_dims = np.arange(n_continuous) grid_size, grid_latents, grid_labels = get_grid_latents( n_discrete, n_continuous, n_samples_per, G, grid_labels, topk_dims) else: grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:]) print('grid_size:', grid_size) print('grid_latents.shape:', grid_latents.shape) print('grid_labels.shape:', grid_labels.shape) grid_fakes, _, _, _, _, _, _, lie_vars = get_return_v( G.run(append_gfeats(grid_latents, G) if forward_eg else grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu, randomize_noise=True), 8) print('Lie_vars:', lie_vars[0]) grid_fakes = add_outline(grid_fakes, width=1) misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes_init.png'), drange=drange_net, grid_size=grid_size) # Setup training inputs. print('Building TensorFlow graph...') with tf.name_scope('Inputs'), tf.device('/cpu:0'): lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[]) minibatch_size_in = tf.placeholder(tf.int32, name='minibatch_size_in', shape=[]) minibatch_gpu_in = tf.placeholder(tf.int32, name='minibatch_gpu_in', shape=[]) minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in * num_gpus) # Setup optimizers. G_opt_args = dict(G_opt_args) G_opt_args['minibatch_multiplier'] = minibatch_multiplier G_opt_args['learning_rate'] = lrate_in G_opt = tflib.Optimizer(name='TrainG', **G_opt_args) if use_D: D_opt_args = dict(D_opt_args) D_opt_args['minibatch_multiplier'] = minibatch_multiplier D_opt_args['learning_rate'] = lrate_in D_opt = tflib.Optimizer(name='TrainD', **D_opt_args) # Build training graph for each GPU. data_fetch_ops = [] for gpu in range(num_gpus): with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu): # Create GPU-specific shadow copies of G and D. E_gpu = E if gpu == 0 else E.clone(E.name + '_shadow') G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow') if use_D: D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow') # Fetch training data via temporary variables. with tf.name_scope('DataFetch'): sched = training_schedule(cur_nimg=int(resume_kimg * 1000), training_set=training_set, **sched_args) reals_var = tf.Variable( name='reals', trainable=False, initial_value=tf.zeros([sched.minibatch_gpu] + training_set.shape)) labels_var = tf.Variable(name='labels', trainable=False, initial_value=tf.zeros([ sched.minibatch_gpu, training_set.label_size ])) reals_write, labels_write = training_set.get_minibatch_tf() reals_write, labels_write = process_reals( reals_write, labels_write, 0., mirror_augment, training_set.dynamic_range, drange_net) reals_write = tf.concat( [reals_write, reals_var[minibatch_gpu_in:]], axis=0) labels_write = tf.concat( [labels_write, labels_var[minibatch_gpu_in:]], axis=0) data_fetch_ops += [tf.assign(reals_var, reals_write)] data_fetch_ops += [tf.assign(labels_var, labels_write)] reals_read = reals_var[:minibatch_gpu_in] labels_read = labels_var[:minibatch_gpu_in] # Evaluate loss functions. 
if use_D: with tf.name_scope('G_loss'): G_loss = dnnlib.util.call_func_by_name( E=E_gpu, G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, reals=reals_read, labels=labels_read, **G_loss_args) with tf.name_scope('D_loss'): D_loss = dnnlib.util.call_func_by_name( E=E_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, reals=reals_read, labels=labels_read, **D_loss_args) else: with tf.name_scope('G_loss'): G_loss = dnnlib.util.call_func_by_name( E=E_gpu, G=G_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, reals=reals_read, labels=labels_read, **G_loss_args) # Register gradients. EG_gpu_trainables = collections.OrderedDict( list(E_gpu.trainables.items()) + list(G_gpu.trainables.items())) G_opt.register_gradients(tf.reduce_mean(G_loss), EG_gpu_trainables) # G_opt.register_gradients(G_loss, # EG_gpu_trainables) if use_D: D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables) # D_opt.register_gradients(D_loss, # D_gpu.trainables) # Setup training ops. data_fetch_op = tf.group(*data_fetch_ops) G_train_op = G_opt.apply_updates() if use_D: D_train_op = D_opt.apply_updates() # Finalize graph. with tf.device('/gpu:0'): try: peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse() except tf.errors.NotFoundError: peak_gpu_mem_op = tf.constant(0) tflib.init_uninitialized_vars() print('Initializing logs...') summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path()) if save_tf_graph: summary_log.add_graph(tf.get_default_graph()) if save_weight_histograms: G.setup_weight_histograms() if use_D: D.setup_weight_histograms() metrics = metric_base.MetricGroup(metric_arg_list) print('Training for %d kimg...\n' % total_kimg) dnnlib.RunContext.get().update('', cur_epoch=resume_kimg, max_epoch=total_kimg) maintenance_time = dnnlib.RunContext.get().get_last_update_interval() cur_nimg = int(resume_kimg * 1000) cur_tick = -1 tick_start_nimg = cur_nimg prev_lod = -1.0 running_mb_counter = 0 while cur_nimg < total_kimg * 1000: if dnnlib.RunContext.get().should_stop(): break # Choose training parameters and configure training ops. sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, **sched_args) assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0 training_set.configure(sched.minibatch_gpu, 0) # Run training ops. feed_dict = { lrate_in: sched.G_lrate, minibatch_size_in: sched.minibatch_size, minibatch_gpu_in: sched.minibatch_gpu } for _repeat in range(minibatch_repeats): rounds = range(0, sched.minibatch_size, sched.minibatch_gpu * num_gpus) cur_nimg += sched.minibatch_size running_mb_counter += 1 # Fast path without gradient accumulation. if len(rounds) == 1: tflib.run([G_train_op], feed_dict) tflib.run([data_fetch_op], feed_dict) if use_D: tflib.run([D_train_op], feed_dict) # Slow path with gradient accumulation. else: for _round in rounds: tflib.run(G_train_op, feed_dict) for _round in rounds: tflib.run(data_fetch_op, feed_dict) if use_D: tflib.run(D_train_op, feed_dict) # Perform maintenance tasks once per tick. done = (cur_nimg >= total_kimg * 1000) if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done: cur_tick += 1 tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0 tick_start_nimg = cur_nimg tick_time = dnnlib.RunContext.get().get_time_since_last_update() total_time = dnnlib.RunContext.get().get_time_since_start( ) + resume_time # Report progress. 
print( 'tick %-5d kimg %-8.1f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f' % (autosummary('Progress/tick', cur_tick), autosummary('Progress/kimg', cur_nimg / 1000.0), autosummary('Progress/minibatch', sched.minibatch_size), dnnlib.util.format_time( autosummary('Timing/total_sec', total_time)), autosummary('Timing/sec_per_tick', tick_time), autosummary('Timing/sec_per_kimg', tick_time / tick_kimg), autosummary('Timing/maintenance_sec', maintenance_time), autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30))) autosummary('Timing/total_hours', total_time / (60.0 * 60.0)) autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0)) # Save snapshots. if network_snapshot_ticks is not None and ( cur_tick % network_snapshot_ticks == 0 or done): pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' % (cur_nimg // 1000)) if use_D: misc.save_pkl((E, G, D), pkl) else: misc.save_pkl((E, G), pkl) met_outs = metrics.run(pkl, run_dir=dnnlib.make_run_dir_path(), data_dir=dnnlib.convert_path(data_dir), num_gpus=num_gpus, tf_config=tf_config, is_vae=True, use_D=use_D, Gs_kwargs=dict(is_validation=True)) if topk_dims_to_show > 0: if 'tpl_per_dim' in met_outs: avg_distance_per_dim = met_outs[ 'tpl_per_dim'] # shape: (n_continuous) topk_dims = np.argsort( avg_distance_per_dim )[::-1][:topk_dims_to_show] # shape: (20) else: topk_dims = np.arange( min(topk_dims_to_show, n_continuous)) else: topk_dims = np.arange(n_continuous) if image_snapshot_ticks is not None and ( cur_tick % image_snapshot_ticks == 0 or done): if traversal_grid: grid_size, grid_latents, grid_labels = get_grid_latents( n_discrete, n_continuous, n_samples_per, G, grid_labels, topk_dims) else: grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:]) grid_fakes, _, _, _, _, _, _, lie_vars = get_return_v( G.run(append_gfeats(grid_latents, G) if forward_eg else grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu, randomize_noise=True), 8) print('Lie_vars:', lie_vars[0]) grid_fakes = add_outline(grid_fakes, width=1) misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path( 'fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size) # Update summaries and RunContext. metrics.update_autosummaries() tflib.autosummary.save_summaries(summary_log, cur_nimg) dnnlib.RunContext.get().update('%.2f' % 0, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg) maintenance_time = dnnlib.RunContext.get( ).get_last_update_interval() - tick_time # Save final snapshot. if use_D: misc.save_pkl((E, G, D), dnnlib.make_run_dir_path('network-final.pkl')) else: misc.save_pkl((E, G), dnnlib.make_run_dir_path('network-final.pkl')) # All done. summary_log.close() training_set.close()
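# --- Illustrative sketch (not part of the original code) ---------------------
# A tiny NumPy example of the top-k selection above: the continuous latent
# dimensions with the largest average traversal distance are kept for the
# snapshot grid. The metric values are random stand-ins, not real outputs.
import numpy as np
avg_distance_per_dim_demo = np.random.rand(12)   # stand-in for met_outs['tpl_per_dim']
topk_dims_to_show_demo = 5
topk_dims_demo = np.argsort(avg_distance_per_dim_demo)[::-1][:topk_dims_to_show_demo]
print(topk_dims_demo)   # indices of the 5 most-travelled dimensions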
def set_network(self, Gs, dtype='float16'): if Gs is None: self._Gs = None return self._Gs = Gs.clone(randomize_noise=False, dtype=dtype, num_fp16_res=0, fused_modconv=True) # Compute dlatent stats. self._info( f'Computing W midpoint and stddev using {self.dlatent_avg_samples} samples...' ) latent_samples = np.random.RandomState(123).randn( self.dlatent_avg_samples, *self._Gs.input_shapes[0][1:]) dlatent_samples = self._Gs.components.mapping.run( latent_samples, None) # [N, L, C] dlatent_samples = dlatent_samples[:, :1, :].astype( np.float32) # [N, 1, C] self._dlatent_avg = np.mean(dlatent_samples, axis=0, keepdims=True) # [1, 1, C] self._dlatent_std = (np.sum((dlatent_samples - self._dlatent_avg)**2) / self.dlatent_avg_samples)**0.5 self._info(f'std = {self._dlatent_std:g}') # Setup noise inputs. self._info('Setting up noise inputs...') self._noise_vars = [] noise_init_ops = [] noise_normalize_ops = [] while True: n = f'G_synthesis/noise{len(self._noise_vars)}' if not n in self._Gs.vars: break v = self._Gs.vars[n] self._noise_vars.append(v) noise_init_ops.append( tf.assign(v, tf.random_normal(tf.shape(v), dtype=tf.float32))) noise_mean = tf.reduce_mean(v) noise_std = tf.reduce_mean((v - noise_mean)**2)**0.5 noise_normalize_ops.append( tf.assign(v, (v - noise_mean) / noise_std)) self._noise_init_op = tf.group(*noise_init_ops) self._noise_normalize_op = tf.group(*noise_normalize_ops) # Build image output graph. self._info('Building image output graph...') self._minibatch_size = 1 self._dlatents_var = tf.Variable( tf.zeros([self._minibatch_size] + list(self._dlatent_avg.shape[1:])), name='dlatents_var') self._dlatent_noise_in = tf.placeholder(tf.float32, [], name='noise_in') dlatents_noise = tf.random.normal( shape=self._dlatents_var.shape) * self._dlatent_noise_in self._dlatents_expr = tf.tile( self._dlatents_var + dlatents_noise, [1, self._Gs.components.synthesis.input_shape[1], 1]) self._images_float_expr = tf.cast( self._Gs.components.synthesis.get_output_for(self._dlatents_expr), tf.float32) self._images_uint8_expr = tflib.convert_images_to_uint8( self._images_float_expr, nchw_to_nhwc=True) # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images. proc_images_expr = (self._images_float_expr + 1) * (255 / 2) sh = proc_images_expr.shape.as_list() if sh[2] > 256: factor = sh[2] // 256 proc_images_expr = tf.reduce_mean(tf.reshape( proc_images_expr, [-1, sh[1], sh[2] // factor, factor, sh[2] // factor, factor]), axis=[3, 5]) # Build loss graph. self._info('Building loss graph...') self._target_images_var = tf.Variable(tf.zeros(proc_images_expr.shape), name='target_images_var') if self._lpips is None: with dnnlib.util.open_url( 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/vgg16_zhang_perceptual.pkl' ) as f: self._lpips = pickle.load(f) self._dist = self._lpips.get_output_for(proc_images_expr, self._target_images_var) self._loss = tf.reduce_sum(self._dist) # Build noise regularization graph. self._info('Building noise regularization graph...') reg_loss = 0.0 for v in self._noise_vars: sz = v.shape[2] while True: reg_loss += tf.reduce_mean( v * tf.roll(v, shift=1, axis=3))**2 + tf.reduce_mean( v * tf.roll(v, shift=1, axis=2))**2 if sz <= 8: break # Small enough already v = tf.reshape(v, [1, 1, sz // 2, 2, sz // 2, 2]) # Downscale v = tf.reduce_mean(v, axis=[3, 5]) sz = sz // 2 self._loss += reg_loss * self.regularize_noise_weight # Setup optimizer. 
    self._info('Setting up optimizer...')
    self._lrate_in = tf.placeholder(tf.float32, [], name='lrate_in')
    self._opt = tflib.Optimizer(learning_rate=self._lrate_in)
    self._opt.register_gradients(self._loss, [self._dlatents_var] + self._noise_vars)
    self._opt_step = self._opt.apply_updates()
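# --- Illustrative sketch (not part of the original source) ---
# The noise regularization graph above penalises spatial correlation between a
# noise map and its one-pixel shifts, re-applying the penalty on 2x-downscaled
# copies until the map is 8x8 or smaller. The same computation in NumPy, on a
# hypothetical noise map:
import numpy as np

def noise_reg_loss(v):
    # v: noise map of shape [1, 1, H, W] with H == W a power of two.
    loss, sz = 0.0, v.shape[2]
    while True:
        loss += np.mean(v * np.roll(v, 1, axis=3)) ** 2 \
              + np.mean(v * np.roll(v, 1, axis=2)) ** 2
        if sz <= 8:
            break  # Small enough already
        v = v.reshape(1, 1, sz // 2, 2, sz // 2, 2).mean(axis=(3, 5))  # Downscale by 2
        sz //= 2
    return loss

print(noise_reg_loss(np.random.randn(1, 1, 64, 64)))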
def training_loop_infernet( I_args={}, # Options for infogan-head/vcgan-head network. I_opt_args={}, # Options for discriminator optimizer. loss_args={}, # Options for discriminator loss. sched_args={}, # Options for train.TrainingSchedule. grid_args={}, # Options for train.setup_snapshot_image_grid(). metric_arg_list=[], # Options for MetricGroup. tf_config={}, # Options for tflib.init_tf(). minibatch_repeats=4, # Number of minibatches to run before adjusting training parameters. lazy_regularization=True, # Perform regularization as a separate training step? total_kimg=25000, # Total length of the training, measured in thousands of real images. mirror_augment=False, # Enable mirror augment? drange_net=[ -1, 1 ], # Dynamic range used when feeding image data to the networks. image_snapshot_ticks=50, # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'. network_snapshot_ticks=5, # How often to save network snapshots? None = only save 'networks-final.pkl'. save_tf_graph=False, # Include full TensorFlow computation graph in the tfevents file? save_weight_histograms=False, # Include weight histograms in the tfevents file? G_pkl=None, # The G to load. resume_pkl=None, # Network pickle to resume training from, None = train from scratch. resume_kimg=0.0, # Assumed training progress at the beginning. Affects reporting and training schedule. resume_time=0.0, # Assumed wallclock time at the beginning. Affects reporting. resume_with_new_nets=False, # Construct new networks according to G_args and D_args before resuming training? n_samples_per=10): # Number of samples for each line in traversal. # Initialize dnnlib and TensorFlow. tflib.init_tf(tf_config) num_gpus = dnnlib.submit_config.num_gpus # Construct or load networks. with tf.device('/gpu:0'): G, D, I, Gs = misc.load_pkl(G_pkl) print('Gs.output_shapes:', Gs.output_shapes) if resume_pkl is None or resume_with_new_nets: print('Constructing networks...') I = tflib.Network('I', num_channels=Gs.output_shapes[0][1], resolution=Gs.output_shapes[0][2], **I_args) if resume_pkl is not None: print('Loading networks from "%s"...' % resume_pkl) rI, rGs = misc.load_pkl(resume_pkl) if resume_with_new_nets: I.copy_vars_from(rI) Gs.copy_vars_from(rGs) else: I = rI Gs = rGs # Print layers and generate initial image snapshot. Gs.print_layers() I.print_layers() # Setup training inputs. print('Building TensorFlow graph...') with tf.name_scope('Inputs'), tf.device('/cpu:0'): lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[]) minibatch_size_in = tf.placeholder(tf.int32, name='minibatch_size_in', shape=[]) minibatch_gpu_in = tf.placeholder(tf.int32, name='minibatch_gpu_in', shape=[]) minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in * num_gpus) # Setup optimizers. I_opt_args = dict(I_opt_args) I_opt_args['minibatch_multiplier'] = minibatch_multiplier I_opt_args['learning_rate'] = lrate_in I_opt = tflib.Optimizer(name='TrainI', **I_opt_args) # Build training graph for each GPU. data_fetch_ops = [] for gpu in range(num_gpus): with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu): # Create GPU-specific shadow copies of G and D. I_gpu = I if gpu == 0 else I.clone(I.name + '_shadow') G_gpu = Gs if gpu == 0 else Gs.clone(Gs.name + '_shadow') # Evaluate loss functions. with tf.name_scope('I_loss'): loss, reg = dnnlib.util.call_func_by_name( G=G_gpu, I=I_gpu, opt=I_opt, minibatch_size=minibatch_gpu_in, **loss_args) if reg is not None: loss += reg # Register gradients. 
I_opt.register_gradients(tf.reduce_mean(loss), I_gpu.trainables) # Setup training ops. I_train_op = I_opt.apply_updates() # Finalize graph. with tf.device('/gpu:0'): try: peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse() except tf.errors.NotFoundError: peak_gpu_mem_op = tf.constant(0) tflib.init_uninitialized_vars() print('Initializing logs...') summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path()) if save_tf_graph: summary_log.add_graph(tf.get_default_graph()) if save_weight_histograms: I.setup_weight_histograms() metrics = metric_base.MetricGroup(metric_arg_list) print('Training for %d kimg...\n' % total_kimg) dnnlib.RunContext.get().update('', cur_epoch=resume_kimg, max_epoch=total_kimg) maintenance_time = dnnlib.RunContext.get().get_last_update_interval() cur_nimg = int(resume_kimg * 1000) cur_tick = -1 tick_start_nimg = cur_nimg prev_lod = -1.0 running_mb_counter = 0 while cur_nimg < total_kimg * 1000: if dnnlib.RunContext.get().should_stop(): break # Choose training parameters and configure training ops. assert sched_args.minibatch_size % (sched_args.minibatch_gpu * num_gpus) == 0 # Run training ops. feed_dict = { lrate_in: sched_args.lrate, minibatch_size_in: sched_args.minibatch_size, minibatch_gpu_in: sched_args.minibatch_gpu } for _repeat in range(minibatch_repeats): rounds = range(0, sched_args.minibatch_size, sched_args.minibatch_gpu * num_gpus) cur_nimg += sched_args.minibatch_size running_mb_counter += 1 # Fast path without gradient accumulation. if len(rounds) == 1: tflib.run([I_train_op], feed_dict) # Slow path with gradient accumulation. else: for _round in rounds: tflib.run(I_train_op, feed_dict) # Perform maintenance tasks once per tick. done = (cur_nimg >= total_kimg * 1000) if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched_args.tick_kimg * 1000 or done: cur_tick += 1 tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0 tick_start_nimg = cur_nimg tick_time = dnnlib.RunContext.get().get_time_since_last_update() total_time = dnnlib.RunContext.get().get_time_since_start( ) + resume_time # Report progress. print( 'tick %-5d kimg %-8.1f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f' % (autosummary('Progress/tick', cur_tick), autosummary('Progress/kimg', cur_nimg / 1000.0), autosummary('Progress/minibatch', sched_args.minibatch_size), dnnlib.util.format_time( autosummary('Timing/total_sec', total_time)), autosummary('Timing/sec_per_tick', tick_time), autosummary('Timing/sec_per_kimg', tick_time / tick_kimg), autosummary('Timing/maintenance_sec', maintenance_time), autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30))) autosummary('Timing/total_hours', total_time / (60.0 * 60.0)) autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0)) # Save snapshots. if network_snapshot_ticks is not None and ( cur_tick % network_snapshot_ticks == 0 or done): pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' % (cur_nimg // 1000)) misc.save_pkl((I, G), pkl) metrics.run(pkl, run_dir=dnnlib.make_run_dir_path(), num_gpus=num_gpus, tf_config=tf_config, train_infernet=True) # Update summaries and RunContext. metrics.update_autosummaries() tflib.autosummary.save_summaries(summary_log, cur_nimg) dnnlib.RunContext.get().update(cur_epoch=cur_nimg // 1000, max_epoch=total_kimg) maintenance_time = dnnlib.RunContext.get( ).get_last_update_interval() - tick_time # Save final snapshot. misc.save_pkl((I, G), dnnlib.make_run_dir_path('network-final.pkl')) # All done. summary_log.close()
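# --- Illustrative sketch (not part of the original source) ---
# Whether the loop above takes the fast path (one session call per minibatch)
# or the slow path (gradient accumulation) depends only on how the total
# minibatch splits across GPUs. With hypothetical numbers:
minibatch_size, minibatch_gpu, num_gpus = 32, 4, 2
assert minibatch_size % (minibatch_gpu * num_gpus) == 0
rounds = range(0, minibatch_size, minibatch_gpu * num_gpus)
print(len(rounds))  # 4 -> each update accumulates gradients over 4 training-op calls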
def training_loop( submit_config, G_args={}, # Options for generator network. D_args={}, # Options for discriminator network. G_opt_args={}, # Options for generator optimizer. D_opt_args={}, # Options for discriminator optimizer. G_loss_args={}, # Options for generator loss. D_loss_args={}, # Options for discriminator loss. dataset_args={}, # Options for dataset.load_dataset(). sched_args={}, # Options for train.TrainingSchedule. grid_args={}, # Options for train.setup_snapshot_image_grid(). metric_arg_list=[], # Options for MetricGroup. tf_config={}, # Options for tflib.init_tf(). G_smoothing_kimg=10.0, # Half-life of the running average of generator weights. D_repeats=1, # How many times the discriminator is trained per G iteration. minibatch_repeats=4, # Number of minibatches to run before adjusting training parameters. reset_opt_for_new_lod=True, # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced? total_kimg=15000, # Total length of the training, measured in thousands of real images. mirror_augment=False, # Enable mirror augment? drange_net=[ -1, 1 ], # Dynamic range used when feeding image data to the networks. image_snapshot_ticks=1, # How often to export image snapshots? network_snapshot_ticks=10, # How often to export network snapshots? save_tf_graph=False, # Include full TensorFlow computation graph in the tfevents file? save_weight_histograms=False, # Include weight histograms in the tfevents file? resume_run_id=None, # Run ID or network pkl to resume training from, None = start from scratch. resume_snapshot=None, # Snapshot index to resume training from, None = autodetect. resume_kimg=10000.0, # Assumed training progress at the beginning. Affects reporting and training schedule. resume_time=0.0 ): # Assumed wallclock time at the beginning. Affects reporting. # Initialize dnnlib and TensorFlow. ctx = dnnlib.RunContext(submit_config, train) tflib.init_tf(tf_config) # Load training set. training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args) # Construct networks. with tf.device('/gpu:0'): if resume_run_id is not None: network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot) print('Loading networks from "%s"...' 
% network_pkl) G, D, Gs = misc.load_pkl(network_pkl) else: #print('Constructing networks...') #G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args) #D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args) #Gs = G.clone('Gs') url = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ' with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f: G, D, Gs = pickle.load(f) print('Loading pretrained FFHQ network') G.print_layers() D.print_layers() print('Building TensorFlow graph...') with tf.name_scope('Inputs'), tf.device('/cpu:0'): lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[]) lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[]) minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[]) minibatch_split = minibatch_in // submit_config.num_gpus Gs_beta = 0.5**tf.div(tf.cast(minibatch_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0 G_opt = tflib.Optimizer(name='TrainG', learning_rate=lrate_in, **G_opt_args) D_opt = tflib.Optimizer(name='TrainD', learning_rate=lrate_in, **D_opt_args) for gpu in range(submit_config.num_gpus): with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu): G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow') D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow') lod_assign_ops = [ tf.assign(G_gpu.find_var('lod'), lod_in), tf.assign(D_gpu.find_var('lod'), lod_in) ] reals, labels = training_set.get_minibatch_tf() reals = process_reals(reals, lod_in, mirror_augment, training_set.dynamic_range, drange_net) with tf.name_scope('G_loss'), tf.control_dependencies( lod_assign_ops): G_loss = dnnlib.util.call_func_by_name( G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_split, **G_loss_args) with tf.name_scope('D_loss'), tf.control_dependencies( lod_assign_ops): D_loss = dnnlib.util.call_func_by_name( G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_split, reals=reals, labels=labels, **D_loss_args) G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables) D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables) G_train_op = G_opt.apply_updates() D_train_op = D_opt.apply_updates() Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta) with tf.device('/gpu:0'): try: peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse() except tf.errors.NotFoundError: peak_gpu_mem_op = tf.constant(0) print('Setting up snapshot image grid...') grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid( G, training_set, **grid_args) sched = training_schedule(cur_nimg=total_kimg * 1000, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args) grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch // submit_config.num_gpus) print('Setting up run dir...') misc.save_image_grid(grid_reals, os.path.join(submit_config.run_dir, 'reals.png'), drange=training_set.dynamic_range, grid_size=grid_size) misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg), drange=drange_net, grid_size=grid_size) cmd = "gsutil cp " + os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg) + " gs://stylegan_out" response = subprocess.run(cmd, shell=True) summary_log = tf.summary.FileWriter(submit_config.run_dir) if save_tf_graph: 
summary_log.add_graph(tf.get_default_graph()) if save_weight_histograms: G.setup_weight_histograms() D.setup_weight_histograms() metrics = metric_base.MetricGroup(metric_arg_list) print('Training...\n') ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg) maintenance_time = ctx.get_last_update_interval() cur_nimg = int(resume_kimg * 1000) cur_tick = 0 tick_start_nimg = cur_nimg prev_lod = -1.0 while cur_nimg < total_kimg * 1000: if ctx.should_stop(): break # Choose training parameters and configure training ops. sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args) training_set.configure(sched.minibatch // submit_config.num_gpus, sched.lod) if reset_opt_for_new_lod: if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil( sched.lod) != np.ceil(prev_lod): G_opt.reset_optimizer_state() D_opt.reset_optimizer_state() prev_lod = sched.lod # Run training ops. for _mb_repeat in range(minibatch_repeats): for _D_repeat in range(D_repeats): tflib.run( [D_train_op, Gs_update_op], { lod_in: sched.lod, lrate_in: sched.D_lrate, minibatch_in: sched.minibatch }) cur_nimg += sched.minibatch tflib.run( [G_train_op], { lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_in: sched.minibatch }) # Perform maintenance tasks once per tick. done = (cur_nimg >= total_kimg * 1000) if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done: cur_tick += 1 tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0 tick_start_nimg = cur_nimg tick_time = ctx.get_time_since_last_update() total_time = ctx.get_time_since_start() + resume_time # Report progress. print( 'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f' % (autosummary('Progress/tick', cur_tick), autosummary('Progress/kimg', cur_nimg / 1000.0), autosummary('Progress/lod', sched.lod), autosummary('Progress/minibatch', sched.minibatch), dnnlib.util.format_time( autosummary('Timing/total_sec', total_time)), autosummary('Timing/sec_per_tick', tick_time), autosummary('Timing/sec_per_kimg', tick_time / tick_kimg), autosummary('Timing/maintenance_sec', maintenance_time), autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30))) autosummary('Timing/total_hours', total_time / (60.0 * 60.0)) autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0)) # Save snapshots. if cur_tick % image_snapshot_ticks == 0 or done: grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch // submit_config.num_gpus) misc.save_image_grid(grid_fakes, os.path.join( submit_config.run_dir, 'fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size) cmd = "gsutil cp " + os.path.join( submit_config.run_dir, 'fakes%06d.png' % (cur_nimg // 1000)) + " gs://stylegan_out" response = subprocess.run(cmd, shell=True) if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1: pkl = os.path.join( submit_config.run_dir, 'network-snapshot-%06d.pkl' % (cur_nimg // 1000)) misc.save_pkl((G, D, Gs), pkl) metrics.run(pkl, run_dir=submit_config.run_dir, num_gpus=submit_config.num_gpus, tf_config=tf_config) # Update summaries and RunContext. metrics.update_autosummaries() tflib.autosummary.save_summaries(summary_log, cur_nimg) ctx.update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg) maintenance_time = ctx.get_last_update_interval() - tick_time # Write final results. 
    misc.save_pkl((G, D, Gs), os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()
    ctx.close()
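# --- Illustrative sketch (not part of the original source) ---
# The snapshot uploads above shell out via "gsutil cp ... gs://stylegan_out"
# with shell=True and ignore the result. A hypothetical helper (function name
# and bucket are assumptions) that avoids the shell and reports failures:
import subprocess

def upload_to_gcs(local_path, bucket='gs://stylegan_out'):
    result = subprocess.run(['gsutil', 'cp', local_path, bucket])
    if result.returncode != 0:
        print('Warning: gsutil upload failed with exit code %d' % result.returncode)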
def training_loop(
        G_args={},                      # Options for generator network.
        D_args={},                      # Options for discriminator network.
        G_opt_args={},                  # Options for generator optimizer.
        D_opt_args={},                  # Options for discriminator optimizer.
        G_loss_args={},                 # Options for generator loss.
        D_loss_args={},                 # Options for discriminator loss.
        dataset_args={},                # Options for dataset.load_dataset().
        sched_args={},                  # Options for train.TrainingSchedule.
        grid_args={},                   # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],             # Options for MetricGroup.
        tf_config={},                   # Options for tflib.init_tf().
        data_dir=None,                  # Directory to load datasets from.
        G_smoothing_kimg=10.0,          # Half-life of the running average of generator weights.
        minibatch_repeats=4,            # Number of minibatches to run before adjusting training parameters.
        lazy_regularization=True,       # Perform regularization as a separate training step?
        G_reg_interval=4,               # How often to perform regularization for G? Ignored if lazy_regularization=False.
        D_reg_interval=16,              # How often to perform regularization for D? Ignored if lazy_regularization=False.
        reset_opt_for_new_lod=True,     # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg=25000,               # Total length of the training, measured in thousands of real images.
        mirror_augment=False,           # Enable mirror augment?
        drange_net=[-1, 1],             # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=50,        # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'.
        network_snapshot_ticks=50,      # How often to save network snapshots? None = only save 'networks-final.pkl'.
        save_tf_graph=False,            # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,   # Include weight histograms in the tfevents file?
        resume_pkl=None,                # Network pickle to resume training from, None = train from scratch.
        resume_kimg=0.0,                # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0,                # Assumed wallclock time at the beginning. Affects reporting.
        resume_with_new_nets=False):    # Construct new networks according to G_args and D_args before resuming training?

    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir), verbose=True, **dataset_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(training_set, **grid_args)
    misc.save_image_grid(grid_reals, dnnlib.make_run_dir_path("reals.png"),
                         drange=training_set.dynamic_range, grid_size=grid_size)

    # Construct or load networks.
    with tf.device("/gpu:0"):
        if resume_pkl is None or resume_with_new_nets:
            print("Constructing networks...")
            G = tflib.Network("G", num_channels=training_set.shape[0], resolution=training_set.shape[1],
                              label_size=training_set.label_size, **G_args)
            D = tflib.Network("D", num_channels=training_set.shape[0], resolution=training_set.shape[1],
                              label_size=training_set.label_size, **D_args)
            Gs = G.clone("Gs")
        if resume_pkl is not None:
            print('Loading networks from "%s"...' % resume_pkl)
            rG, rD, rGs = misc.load_pkl(resume_pkl)
            if resume_with_new_nets:
                G.copy_vars_from(rG)
                D.copy_vars_from(rD)
                Gs.copy_vars_from(rGs)
            else:
                G = rG
                D = rD
                Gs = rGs

    # Print layers and generate initial image snapshot.
G.print_layers() D.print_layers() sched = training_schedule(cur_nimg=total_kimg * 1000, training_set=training_set, **sched_args) grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:]) grid_fakes = Gs.run( grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu, ) misc.save_image_grid( grid_fakes, dnnlib.make_run_dir_path("fakes_init.png"), drange=drange_net, grid_size=grid_size, ) # Setup training inputs. print("Building TensorFlow graph...") with tf.name_scope("Inputs"), tf.device("/cpu:0"): lod_in = tf.placeholder(tf.float32, name="lod_in", shape=[]) lrate_in = tf.placeholder(tf.float32, name="lrate_in", shape=[]) minibatch_size_in = tf.placeholder(tf.int32, name="minibatch_size_in", shape=[]) minibatch_gpu_in = tf.placeholder(tf.int32, name="minibatch_gpu_in", shape=[]) minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in * num_gpus) Gs_beta = (0.5**tf.div(tf.cast(minibatch_size_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0) # Setup optimizers. G_opt_args = dict(G_opt_args) D_opt_args = dict(D_opt_args) for args, reg_interval in [ (G_opt_args, G_reg_interval), (D_opt_args, D_reg_interval), ]: args["minibatch_multiplier"] = minibatch_multiplier args["learning_rate"] = lrate_in if lazy_regularization: mb_ratio = reg_interval / (reg_interval + 1) args["learning_rate"] *= mb_ratio if "beta1" in args: args["beta1"] **= mb_ratio if "beta2" in args: args["beta2"] **= mb_ratio G_opt = tflib.Optimizer(name="TrainG", **G_opt_args) D_opt = tflib.Optimizer(name="TrainD", **D_opt_args) G_reg_opt = tflib.Optimizer(name="RegG", share=G_opt, **G_opt_args) D_reg_opt = tflib.Optimizer(name="RegD", share=D_opt, **D_opt_args) # Build training graph for each GPU. data_fetch_ops = [] for gpu in range(num_gpus): with tf.name_scope("GPU%d" % gpu), tf.device("/gpu:%d" % gpu): # Create GPU-specific shadow copies of G and D. G_gpu = G if gpu == 0 else G.clone(G.name + "_shadow") D_gpu = D if gpu == 0 else D.clone(D.name + "_shadow") # Fetch training data via temporary variables. with tf.name_scope("DataFetch"): sched = training_schedule(cur_nimg=int(resume_kimg * 1000), training_set=training_set, **sched_args) reals_var = tf.Variable( name="reals", trainable=False, initial_value=tf.zeros([sched.minibatch_gpu] + training_set.shape), ) labels_var = tf.Variable( name="labels", trainable=False, initial_value=tf.zeros( [sched.minibatch_gpu, training_set.label_size]), ) reals_write, labels_write = training_set.get_minibatch_tf() reals_write, labels_write = process_reals( reals_write, labels_write, lod_in, mirror_augment, training_set.dynamic_range, drange_net, ) reals_write = tf.concat( [reals_write, reals_var[minibatch_gpu_in:]], axis=0) labels_write = tf.concat( [labels_write, labels_var[minibatch_gpu_in:]], axis=0) data_fetch_ops += [tf.assign(reals_var, reals_write)] data_fetch_ops += [tf.assign(labels_var, labels_write)] reals_read = reals_var[:minibatch_gpu_in] labels_read = labels_var[:minibatch_gpu_in] # Evaluate loss functions. 
lod_assign_ops = [] if "lod" in G_gpu.vars: lod_assign_ops += [tf.assign(G_gpu.vars["lod"], lod_in)] if "lod" in D_gpu.vars: lod_assign_ops += [tf.assign(D_gpu.vars["lod"], lod_in)] with tf.control_dependencies(lod_assign_ops): with tf.name_scope("G_loss"): G_loss, G_reg = dnnlib.util.call_func_by_name( G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, **G_loss_args) with tf.name_scope("D_loss"): D_loss, D_reg = dnnlib.util.call_func_by_name( G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, reals=reals_read, labels=labels_read, **D_loss_args) # Register gradients. if not lazy_regularization: if G_reg is not None: G_loss += G_reg if D_reg is not None: D_loss += D_reg else: if G_reg is not None: G_reg_opt.register_gradients( tf.reduce_mean(G_reg * G_reg_interval), G_gpu.trainables) if D_reg is not None: D_reg_opt.register_gradients( tf.reduce_mean(D_reg * D_reg_interval), D_gpu.trainables) G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables) D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables) # Setup training ops. data_fetch_op = tf.group(*data_fetch_ops) G_train_op = G_opt.apply_updates() D_train_op = D_opt.apply_updates() G_reg_op = G_reg_opt.apply_updates(allow_no_op=True) D_reg_op = D_reg_opt.apply_updates(allow_no_op=True) Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta) # Finalize graph. with tf.device("/gpu:0"): try: peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse() except tf.errors.NotFoundError: peak_gpu_mem_op = tf.constant(0) tflib.init_uninitialized_vars() print("Initializing logs...") summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path()) if save_tf_graph: summary_log.add_graph(tf.get_default_graph()) if save_weight_histograms: G.setup_weight_histograms() D.setup_weight_histograms() metrics = metric_base.MetricGroup(metric_arg_list) print("Training for %d kimg...\n" % total_kimg) dnnlib.RunContext.get().update("", cur_epoch=resume_kimg, max_epoch=total_kimg) maintenance_time = dnnlib.RunContext.get().get_last_update_interval() cur_nimg = int(resume_kimg * 1000) cur_tick = -1 tick_start_nimg = cur_nimg prev_lod = -1.0 running_mb_counter = 0 while cur_nimg < total_kimg * 1000: if dnnlib.RunContext.get().should_stop(): break # Choose training parameters and configure training ops. sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, **sched_args) assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0 training_set.configure(sched.minibatch_gpu, sched.lod) if reset_opt_for_new_lod: if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil( sched.lod) != np.ceil(prev_lod): G_opt.reset_optimizer_state() D_opt.reset_optimizer_state() prev_lod = sched.lod # Run training ops. feed_dict = { lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_size_in: sched.minibatch_size, minibatch_gpu_in: sched.minibatch_gpu, } for _repeat in range(minibatch_repeats): rounds = range(0, sched.minibatch_size, sched.minibatch_gpu * num_gpus) run_G_reg = lazy_regularization and running_mb_counter % G_reg_interval == 0 run_D_reg = lazy_regularization and running_mb_counter % D_reg_interval == 0 cur_nimg += sched.minibatch_size running_mb_counter += 1 # Fast path without gradient accumulation. if len(rounds) == 1: tflib.run([G_train_op, data_fetch_op], feed_dict) if run_G_reg: tflib.run(G_reg_op, feed_dict) tflib.run([D_train_op, Gs_update_op], feed_dict) if run_D_reg: tflib.run(D_reg_op, feed_dict) # Slow path with gradient accumulation. 
else: for _round in rounds: tflib.run(G_train_op, feed_dict) if run_G_reg: for _round in rounds: tflib.run(G_reg_op, feed_dict) tflib.run(Gs_update_op, feed_dict) for _round in rounds: tflib.run(data_fetch_op, feed_dict) tflib.run(D_train_op, feed_dict) if run_D_reg: for _round in rounds: tflib.run(D_reg_op, feed_dict) # Perform maintenance tasks once per tick. done = cur_nimg >= total_kimg * 1000 if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done: cur_tick += 1 tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0 tick_start_nimg = cur_nimg tick_time = dnnlib.RunContext.get().get_time_since_last_update() total_time = dnnlib.RunContext.get().get_time_since_start( ) + resume_time # Report progress. print( "tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f" % ( autosummary("Progress/tick", cur_tick), autosummary("Progress/kimg", cur_nimg / 1000.0), autosummary("Progress/lod", sched.lod), autosummary("Progress/minibatch", sched.minibatch_size), dnnlib.util.format_time( autosummary("Timing/total_sec", total_time)), autosummary("Timing/sec_per_tick", tick_time), autosummary("Timing/sec_per_kimg", tick_time / tick_kimg), autosummary("Timing/maintenance_sec", maintenance_time), autosummary("Resources/peak_gpu_mem_gb", peak_gpu_mem_op.eval() / 2**30), )) autosummary("Timing/total_hours", total_time / (60.0 * 60.0)) autosummary("Timing/total_days", total_time / (24.0 * 60.0 * 60.0)) # Save snapshots. if image_snapshot_ticks is not None and ( cur_tick % image_snapshot_ticks == 0 or done): grid_fakes = Gs.run( grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu, ) misc.save_image_grid( grid_fakes, dnnlib.make_run_dir_path("fakes%06d.png" % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size, ) if network_snapshot_ticks is not None and ( cur_tick % network_snapshot_ticks == 0 or done): pkl = dnnlib.make_run_dir_path("network-snapshot-%06d.pkl" % (cur_nimg // 1000)) misc.save_pkl((G, D, Gs), pkl) metrics.run( pkl, run_dir=dnnlib.make_run_dir_path(), data_dir=dnnlib.convert_path(data_dir), num_gpus=num_gpus, tf_config=tf_config, ) # Update summaries and RunContext. metrics.update_autosummaries() tflib.autosummary.save_summaries(summary_log, cur_nimg) dnnlib.RunContext.get().update("%.2f" % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg) maintenance_time = ( dnnlib.RunContext.get().get_last_update_interval() - tick_time) # Save final snapshot. misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path("network-final.pkl")) # All done. summary_log.close() training_set.close()
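# --- Illustrative sketch (not part of the original source) ---
# With lazy_regularization the regularization terms only run every
# G_reg_interval / D_reg_interval minibatches, so the "Setup optimizers." block
# above rescales the optimizer hyperparameters by mb_ratio to compensate.
# Worked through with hypothetical Adam settings:
reg_interval = 16
mb_ratio = reg_interval / (reg_interval + 1)      # ~0.941
learning_rate = 0.002 * mb_ratio                  # ~0.00188
beta1, beta2 = 0.0 ** mb_ratio, 0.99 ** mb_ratio  # ~0.0, ~0.9906
# The regularization op itself then fires only when the running minibatch
# counter hits a multiple of the interval:
running_mb_counter = 32
run_reg = running_mb_counter % reg_interval == 0  # True -> run the reg op this minibatch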