def run(network_pkl, metrics, dataset, data_dir, mirror_augment, include_I=False, is_vae=False, mapping_nodup=False, avg_mv_for_I=False):
    """Evaluate a set of registered metrics on a single network pickle.

    Args:
        network_pkl: Path or URL of the network pickle; resolved through
            ``pretrained_networks.get_path_or_url`` before evaluation.
        metrics: Iterable of metric names, each looked up in ``metric_defaults``.
        dataset: Value passed as ``tfrecord_dir`` in the dataset arguments.
        data_dir: Forwarded to ``MetricGroup.run`` as ``data_dir``.
        mirror_augment: Forwarded to ``MetricGroup.run``.
        include_I, is_vae, mapping_nodup, avg_mv_for_I: Extra flags forwarded
            unchanged to ``MetricGroup.run``.
    """
    metric_names = ','.join(metrics)
    print('Evaluating metrics "%s" for "%s"...' % (metric_names, network_pkl))
    tflib.init_tf()
    resolved_pkl = pretrained_networks.get_path_or_url(network_pkl)
    eval_dataset_args = dnnlib.EasyDict(tfrecord_dir=dataset, shuffle_mb=0)
    gpu_count = dnnlib.submit_config.num_gpus
    selected_metrics = [metric_defaults[name] for name in metrics]
    group = metric_base.MetricGroup(selected_metrics)
    run_kwargs = dict(
        data_dir=data_dir,
        dataset_args=eval_dataset_args,
        mirror_augment=mirror_augment,
        num_gpus=gpu_count,
        include_I=include_I,
        is_vae=is_vae,
        mapping_nodup=mapping_nodup,
        avg_mv_for_I=avg_mv_for_I,
    )
    group.run(resolved_pkl, **run_kwargs)
def run_eval(dataset, resolution, num_gpus, metrics, resume, num_repeats, **kwargs):
    """Prepare the dataset and evaluate ``metrics`` on the ``resume`` pickle.

    ``dataset`` and ``resolution`` are handed to ``dataset_tool.create_dataset``;
    its return value becomes the ``tfrecord_dir`` for evaluation. Each metric is
    repeated ``num_repeats`` times by the ``MetricGroup``. Extra keyword
    arguments are accepted and ignored.
    """
    tfrecord_dir = dataset_tool.create_dataset(dataset, resolution)
    joined_names = ','.join(metrics)
    print('Evaluating metrics "%s" for "%s"...' % (joined_names, resume))
    tflib.init_tf()
    eval_dataset_args = dnnlib.EasyDict(
        tfrecord_dir=tfrecord_dir,
        shuffle_mb=0,
        from_tfrecords=True,
    )
    group = metric_base.MetricGroup(
        [metric_defaults[name] for name in metrics],
        num_repeats=num_repeats,
    )
    group.run(resume, dataset_args=eval_dataset_args, num_gpus=num_gpus)
def run_eval(dataset, resolution, result_dir, DiffAugment, num_gpus, batch_size, total_kimg, ema_kimg, num_samples, gamma, fmap_base, fmap_max, latent_size, mirror_augment, impl, metrics, resume, resume_kimg, num_repeats, eval):
    """Prepare the dataset and evaluate ``metrics`` on the ``resume`` pickle.

    Only ``dataset``, ``resolution``, ``num_gpus``, ``metrics``, ``resume`` and
    ``num_repeats`` are read here; the remaining parameters are accepted so the
    full training command-line signature can be passed through, but this
    function does not use them.
    """
    tfrecord_dir = dataset_tool.create_dataset(dataset, resolution)
    joined_names = ','.join(metrics)
    print('Evaluating metrics "%s" for "%s"...' % (joined_names, resume))
    tflib.init_tf()
    eval_dataset_args = dnnlib.EasyDict(
        tfrecord_dir=tfrecord_dir,
        shuffle_mb=0,
        from_tfrecords=True,
    )
    group = metric_base.MetricGroup(
        [metric_defaults[name] for name in metrics],
        num_repeats=num_repeats,
    )
    group.run(resume, dataset_args=eval_dataset_args, num_gpus=num_gpus)
def run(network_pkl, metrics, dataset, data_dir, mirror_augment, rho_steps):
    """Evaluate metrics on one network pickle over a sweep of ``rho`` values.

    When ``rho_steps > 1`` the metrics are run for ``rho_steps`` evenly spaced
    values in ``[0, 1]`` (``np.linspace``); otherwise only ``rho = 1`` is used.
    Each ``rho`` is printed before its evaluation.
    """
    metric_names = ','.join(metrics)
    print('Evaluating metrics "%s" for "%s"...' % (metric_names, network_pkl))
    tflib.init_tf()
    pkl_path = pretrained_networks.get_path_or_url(network_pkl)
    eval_dataset_args = dnnlib.EasyDict(tfrecord_dir=dataset, shuffle_mb=0)
    gpu_count = dnnlib.submit_config.num_gpus
    group = metric_base.MetricGroup([metric_defaults[name] for name in metrics])
    rho_values = np.linspace(0, 1, rho_steps) if rho_steps > 1 else [1]
    for rho_value in rho_values:
        print(rho_value)
        group.run(
            pkl_path,
            data_dir=data_dir,
            dataset_args=eval_dataset_args,
            mirror_augment=mirror_augment,
            num_gpus=gpu_count,
            rho=rho_value,
        )
def run(network_pkl, metrics, dataset, data_dir, mirror_augment):
    """Evaluate metrics on one network pickle with ``max_label_size="full"``.

    The dataset arguments passed to ``MetricGroup.run`` include
    ``max_label_size="full"`` in addition to ``tfrecord_dir`` and
    ``shuffle_mb=0``.
    """
    metric_names = ','.join(metrics)
    print('Evaluating metrics "%s" for "%s"...' % (metric_names, network_pkl))
    tflib.init_tf()
    pkl_location = pretrained_networks.get_path_or_url(network_pkl)
    eval_dataset_args = dnnlib.EasyDict(
        tfrecord_dir=dataset,
        shuffle_mb=0,
        max_label_size="full",
    )
    gpu_count = dnnlib.submit_config.num_gpus
    group = metric_base.MetricGroup([metric_defaults[name] for name in metrics])
    group.run(
        pkl_location,
        data_dir=data_dir,
        dataset_args=eval_dataset_args,
        mirror_augment=mirror_augment,
        num_gpus=gpu_count,
    )
def run(network_pkl, metrics, dataset, data_dir, mirror_augment, paths):
    """Evaluate metrics on one network pickle with an explicit TF config.

    A fixed ``tf_config`` (numpy seed 1000, soft placement, full GPU memory
    fraction) and the caller's ``paths`` are forwarded to ``MetricGroup.run``.
    """
    metric_names = ",".join(metrics)
    print("Evaluating metrics %s for %s..." % (metric_names, network_pkl))
    tflib.init_tf()
    pkl_location = pretrained_networks.get_path_or_url(network_pkl)
    eval_dataset_args = dnnlib.EasyDict(tfrecord_dir=dataset, shuffle_mb=0)
    gpu_count = dnnlib.submit_config.num_gpus
    group = metric_base.MetricGroup([metric_defaults[name] for name in metrics])
    eval_tf_config = dict()
    eval_tf_config["rnd.np_random_seed"] = 1000
    eval_tf_config["allow_soft_placement"] = True
    eval_tf_config["gpu_options.per_process_gpu_memory_fraction"] = 1.0
    group.run(
        pkl_location,
        data_dir=data_dir,
        dataset_args=eval_dataset_args,
        tf_config=eval_tf_config,
        mirror_augment=mirror_augment,
        num_gpus=gpu_count,
        paths=paths,
    )
def run(network_pkls, metrics, dataset, data_dir, mirror_augment, num_repeats, truncation, resume_with_new_nets):
    """Evaluate metrics for each pickle in a comma-separated list.

    Args:
        network_pkls: Comma-separated network pickle paths/URLs.
        truncation: Comma-separated floats, or ``None`` to evaluate with a
            single ``None`` truncation entry.
        num_repeats, resume_with_new_nets: Forwarded to ``MetricGroup.run``.

    A fresh ``MetricGroup`` is built for every pickle.
    """
    tflib.init_tf()
    eval_dataset_args = dnnlib.EasyDict(tfrecord_dir=dataset, shuffle_mb=0)
    gpu_count = dnnlib.submit_config.num_gpus
    if truncation is None:
        truncation_values = [None]
    else:
        truncation_values = list(map(float, truncation.split(',')))
    for pkl_entry in network_pkls.split(','):
        print('Evaluating metrics "%s" for "%s"...' % (','.join(metrics), pkl_entry))
        group = metric_base.MetricGroup([metric_defaults[name] for name in metrics])
        group.run(
            pkl_entry,
            data_dir=data_dir,
            dataset_args=eval_dataset_args,
            mirror_augment=mirror_augment,
            num_gpus=gpu_count,
            num_repeats=num_repeats,
            resume_with_new_nets=resume_with_new_nets,
            truncations=truncation_values,
        )
def run(network_pkl, metrics, dataset, data_dir, mirror_augment):
    """Evaluate metrics for every snapshot pickle in a directory.

    ``network_pkl`` is treated as a directory; every entry whose name starts
    with ``'network'`` and ends with ``'.pkl'`` is evaluated in sorted name
    order, and autosummaries are updated after each one.
    """
    print('Evaluating metrics "%s" for "%s"...' % (','.join(metrics), network_pkl))
    tflib.init_tf()
    snapshot_names = sorted(
        entry for entry in os.listdir(network_pkl)
        if entry.startswith('network') and entry.endswith('.pkl')
    )
    for entry in snapshot_names:
        resolved = pretrained_networks.get_path_or_url(os.path.join(network_pkl, entry))
        print('Process pkl %s' % entry)
        eval_dataset_args = dnnlib.EasyDict(tfrecord_dir=dataset, shuffle_mb=0)
        gpu_count = dnnlib.submit_config.num_gpus
        group = metric_base.MetricGroup([metric_defaults[name] for name in metrics])
        group.run(
            resolved,
            data_dir=data_dir,
            dataset_args=eval_dataset_args,
            mirror_augment=mirror_augment,
            num_gpus=gpu_count,
        )
        group.update_autosummaries()
def training_loop(
        classifier_args={},         # Options for the classifier network.
        classifier_opt_args={},     # Options for the classifier optimizer.
        classifier_loss_args={},    # Options for the classifier loss function.
        dataset_args={},            # Options for dataset.load_dataset().
        sched_args={},              # Options for train.TrainingSchedule.
        metric_arg_list=[],         # Options for MetricGroup.
        tf_config={},               # Options for tflib.init_tf().
        data_dir=None,              # Directory to load datasets from.
        minibatch_repeats=4,        # Number of minibatches to run before adjusting training parameters.
        total_kimg=25000,           # Total length of the training, measured in thousands of real images.
        mirror_augment=False,       # Enable mirror augment?
        drange_net=[-1, 1],         # Dynamic range used when feeding image data to the networks.
        network_snapshot_ticks=5,   # How often to save network snapshots? None = only save 'network-final.pkl'.
        save_tf_graph=False):       # Include full TensorFlow computation graph in the tfevents file?
    """Train a standalone classifier network on a TFRecord dataset.

    Builds one classifier (plus shadow clones for extra GPUs), feeds it
    minibatches through per-GPU staging variables, and trains it with
    ``tflib.Optimizer`` until ``total_kimg`` thousand images have been seen.
    Once per tick it prints progress, writes autosummaries and, every
    ``network_snapshot_ticks`` ticks (and at the end), pickles the classifier
    and runs the configured metrics on that pickle. A final
    ``network-final.pkl`` is written when training stops.
    """
    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir), verbose=True, shuffle_mb=2 * 4096, **dataset_args)

    # Construct networks.
    with tf.device('/gpu:0'):
        print('Constructing networks...')
        classifier = tflib.Network('classifier', num_channels=training_set.shape[0], resolution=training_set.shape[1],
                                   label_size=training_set.label_size, **classifier_args)
    classifier.print_layers()

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32, name='minibatch_size_in', shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32, name='minibatch_gpu_in', shape=[])
        # Number of per-GPU rounds accumulated into one optimizer step.
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in * num_gpus)

    # Setup optimizers (copy the options so the caller's dict is not mutated).
    classifier_opt_args = dict(classifier_opt_args)
    classifier_opt_args['minibatch_multiplier'] = minibatch_multiplier
    classifier_opt_args['learning_rate'] = lrate_in
    classifier_opt = tflib.Optimizer(name='TrainClassifier', **classifier_opt_args)

    # Build training graph for each GPU.
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope('gpu%d' % gpu), tf.device('/gpu:%d' % gpu):

            # Create GPU-specific shadow copies of the classifier.
            classifier_gpu = classifier if gpu == 0 else classifier.clone(classifier.name + '_shadow')

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                sched = training_schedule(cur_nimg=0, **sched_args)
                reals_var = tf.Variable(name='reals', trainable=False,
                                        initial_value=tf.zeros([sched.minibatch_gpu] + training_set.shape))
                # NOTE(review): the label width is hard-coded to 127 while the
                # network is built with training_set.label_size; confirm that
                # process_reals() always yields 127-wide labels.
                labels_var = tf.Variable(name='labels', trainable=False,
                                         initial_value=tf.zeros([sched.minibatch_gpu, 127]))
                reals_write, labels_write = training_set.get_minibatch_tf()
                reals_write, labels_write = process_reals(reals_write, labels_write, mirror_augment,
                                                          training_set.dynamic_range, drange_net)
                # Write the new minibatch in front, keeping the tail of the buffer.
                reals_write = tf.concat([reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                labels_write = tf.concat([labels_write, labels_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_var, reals_write)]
                data_fetch_ops += [tf.assign(labels_var, labels_write)]
                reals_read = reals_var[:minibatch_gpu_in]
                labels_read = labels_var[:minibatch_gpu_in]

            # Evaluate loss functions.
            with tf.name_scope('classifier_loss'):
                classifier_loss, label = dnnlib.util.call_func_by_name(
                    classifier=classifier_gpu, images=reals_read, labels=labels_read, **classifier_loss_args)

            classifier_opt.register_gradients(tf.reduce_mean(classifier_loss), classifier_gpu.trainables)

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    classifier_train_op = classifier_opt.apply_updates()

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('', cur_epoch=0, max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = 0
    cur_tick = -1  # -1 forces a maintenance tick after the first iteration.
    tick_start_nimg = cur_nimg
    running_mb_counter = 0
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg, **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu)

        # Run training ops.
        feed_dict = {lrate_in: sched.G_lrate, minibatch_size_in: sched.minibatch_size, minibatch_gpu_in: sched.minibatch_gpu}
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size, sched.minibatch_gpu * num_gpus)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                tflib.run([classifier_train_op, data_fetch_op], feed_dict)

            # Slow path with gradient accumulation: fetch and step once per round.
            else:
                for _round in rounds:
                    tflib.run(data_fetch_op, feed_dict)
                    # Loss/label are fetched alongside the train op but not used.
                    tflib.run([classifier_loss, label, classifier_train_op], feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start()

            # Report progress.
            print('tick %-5d kimg %-8.1f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f' % (
                autosummary('Progress/tick', cur_tick),
                autosummary('Progress/kimg', cur_nimg / 1000.0),
                autosummary('Progress/minibatch', sched.minibatch_size),
                dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                autosummary('Timing/sec_per_tick', tick_time),
                autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                autosummary('Timing/maintenance_sec', maintenance_time),
                autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if network_snapshot_ticks is not None and (cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl(classifier, pkl)
                metrics.run(pkl, run_dir=dnnlib.make_run_dir_path(), data_dir=dnnlib.convert_path(data_dir),
                            num_gpus=num_gpus, tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % 0, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get().get_last_update_interval() - tick_time

    # Save final snapshot.
    misc.save_pkl(classifier, dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
def training_loop(
        # Configurations
        cG={}, cD={},                   # Generator and Discriminator command-line arguments
        dataset_args={},                # dataset.load_dataset() options
        sched_args={},                  # train.TrainingSchedule options
        vis_args={},                    # vis.eval options
        grid_args={},                   # train.setup_snapshot_img_grid() options
        metric_arg_list=[],             # MetricGroup Options
        tf_config={},                   # tflib.init_tf() options
        eval=False,                     # Evaluation mode
        train=False,                    # Training mode
        # Data
        data_dir=None,                  # Directory to load datasets from
        total_kimg=25000,               # Total length of the training, measured in thousands of real images
        mirror_augment=False,           # Enable mirror augmentation?
        drange_net=[-1, 1],             # Dynamic range used when feeding image data to the networks
        ratio=1.0,                      # Image height/width ratio in the dataset
        # Optimization
        minibatch_repeats=4,            # Number of minibatches to run before adjusting training parameters
        lazy_regularization=True,       # Perform regularization as a separate training step?
        smoothing_kimg=10.0,            # Half-life of the running average of generator weights
        clip=None,                      # Clip gradients threshold
        # Resumption
        resume_pkl=None,                # Network pickle to resume training from, None = train from scratch.
        resume_kimg=0.0,                # Assumed training progress at the beginning
                                        # Affects reporting and training schedule
        resume_time=0.0,                # Assumed wallclock time at the beginning, affects reporting
        recompile=False,                # Recompile network from source code (otherwise loads from snapshot)
        # Logging
        summarize=True,                 # Create TensorBoard summaries
        save_tf_graph=False,            # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,   # Include weight histograms in the tfevents file?
        img_snapshot_ticks=3,           # How often to save image snapshots? None = disable
        network_snapshot_ticks=3,       # How often to save network snapshots? None = only save networks-final.pkl
        last_snapshots=10,              # Maximal number of prior snapshots to save
        eval_images_num=50000,          # Sample size for the metrics
        printname="",                   # Experiment name for logging
        # Architecture
        merge=False):                   # Generate several images and then merge them
    """Evaluate and/or train a GAN made of G, D and the moving-average copy Gs.

    ``cG``/``cD`` are attribute-style config objects: their ``.args`` and
    ``.loss_args`` are read, and this function also stores per-network state
    on them (``.name``, ``.gpu``, ``.loss``, ``.reg``, ``.train_op``,
    ``.reg_op``, ``.lossvals``, ``.lossvals_agg``, ...).

    If ``eval`` is true, the current networks are pickled to
    ``network-eval-snapshot-*.pkl``, the metrics are run on that pickle and
    ``visualize.eval`` is called. If ``train`` is false the dataset is then
    closed and the process stops via ``sys.exit()`` (raises ``SystemExit``).

    Otherwise the multi-GPU training graph is built and training runs until
    ``total_kimg`` thousand images have been processed. Once per tick the loop
    prints progress and, on their respective tick intervals, saves image
    snapshots, pickles (G, D, Gs), runs the metrics and prunes old pickles.
    """
    # Initialize dnnlib and TensorFlow
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus
    cG.name, cD.name = "g", "d"

    # Load dataset, configure training scheduler and metrics object
    dataset = data.load_dataset(data_dir=dnnlib.convert_path(data_dir), verbose=True, **dataset_args)
    sched = training_schedule(sched_args, cur_nimg=total_kimg * 1000, dataset=dataset)
    metrics = metric_base.MetricGroup(metric_arg_list)

    # Construct or load networks
    with tf.device("/gpu:0"):
        no_op = tf.no_op()
        G, D, Gs = None, None, None
        if resume_pkl is None or recompile:
            misc.log("Constructing networks...", "white")
            G = tflib.Network("G", num_channels=dataset.shape[0], resolution=dataset.shape[1],
                              label_size=dataset.label_size, **cG.args)
            D = tflib.Network("D", num_channels=dataset.shape[0], resolution=dataset.shape[1],
                              label_size=dataset.label_size, **cD.args)
            Gs = G.clone("Gs")
        if resume_pkl is not None:
            G, D, Gs = load_nets(resume_pkl, G, D, Gs, recompile)

    G.print_layers()
    D.print_layers()

    # Train/Evaluate/Visualize
    # Labels are optional but not essential
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_img_grid(dataset, **grid_args)
    misc.save_img_grid(grid_reals, dnnlib.make_run_dir_path("reals.png"), drange=dataset.dynamic_range, grid_size=grid_size)
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])

    if eval:
        # Save a snapshot of the current network to evaluate
        pkl = dnnlib.make_run_dir_path("network-eval-snapshot-%06d.pkl" % resume_kimg)
        misc.save_pkl((G, D, Gs), pkl, remove=False)

        # Quantitative evaluation
        metrics.run(pkl, num_imgs=eval_images_num, run_dir=dnnlib.make_run_dir_path(),
                    data_dir=dnnlib.convert_path(data_dir), num_gpus=num_gpus, ratio=ratio,
                    tf_config=tf_config, mirror_augment=mirror_augment)

        # Qualitative evaluation
        visualize.eval(G, dataset, batch_size=sched.minibatch_gpu,
                       drange_net=drange_net, ratio=ratio, **vis_args)

    if not train:
        # sys.exit() instead of the site-module exit(), which is meant for the
        # interactive interpreter and is absent under `python -S`.
        import sys
        dataset.close()
        sys.exit()

    # Setup training inputs
    misc.log("Building TensorFlow graph...", "white")
    with tf.name_scope("Inputs"), tf.device("/cpu:0"):
        lrate_in_g = tf.placeholder(tf.float32, name="lrate_in_g", shape=[])
        lrate_in_d = tf.placeholder(tf.float32, name="lrate_in_d", shape=[])
        step = tf.placeholder(tf.int32, name="step", shape=[])
        minibatch_size_in = tf.placeholder(tf.int32, name="minibatch_size_in", shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32, name="minibatch_gpu_in", shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in * num_gpus)
        # Per-step decay for the Gs moving average, derived from the half-life.
        beta = 0.5 ** tf.div(tf.cast(minibatch_size_in, tf.float32), smoothing_kimg * 1000.0) if smoothing_kimg > 0.0 else 0.0

    # Set optimizers
    for cN, lr in [(cG, lrate_in_g), (cD, lrate_in_d)]:
        set_optimizer(cN, lr, minibatch_multiplier, lazy_regularization, clip)

    # Build training graph for each GPU
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope("GPU%d" % gpu), tf.device("/gpu:%d" % gpu):

            # Create GPU-specific shadow copies of G and D
            for cN, N in [(cG, G), (cD, D)]:
                cN.gpu = N if gpu == 0 else N.clone(N.name + "_shadow")
            # NOTE(review): Gs_gpu is never read; kept because clone() adds
            # variables to the graph and removing it would change the graph.
            Gs_gpu = Gs if gpu == 0 else Gs.clone(Gs.name + "_shadow")

            # Fetch training data via temporary variables
            with tf.name_scope("DataFetch"):
                reals, labels = dataset.get_minibatch_tf()
                reals = process_reals(reals, dataset.dynamic_range, drange_net, mirror_augment)
                reals, reals_fetch = read_data(reals, "reals",
                                               [sched.minibatch_gpu] + dataset.shape, minibatch_gpu_in)
                labels, labels_fetch = read_data(labels, "labels",
                                                 [sched.minibatch_gpu, dataset.label_size], minibatch_gpu_in)
                data_fetch_ops += [reals_fetch, labels_fetch]

            # Evaluate loss functions
            with tf.name_scope("G_loss"):
                cG.loss, cG.reg = dnnlib.util.call_func_by_name(G=cG.gpu, D=cD.gpu, dataset=dataset,
                                                                reals=reals, minibatch_size=minibatch_gpu_in, **cG.loss_args)

            with tf.name_scope("D_loss"):
                cD.loss, cD.reg = dnnlib.util.call_func_by_name(G=cG.gpu, D=cD.gpu, dataset=dataset,
                                                                reals=reals, labels=labels, minibatch_size=minibatch_gpu_in, **cD.loss_args)

            for cN in [cG, cD]:
                set_optimizer_ops(cN, lazy_regularization, no_op)

    # Setup training ops
    data_fetch_op = tf.group(*data_fetch_ops)
    for cN in [cG, cD]:
        cN.train_op = cN.opt.apply_updates()
        cN.reg_op = cN.reg_opt.apply_updates(allow_no_op=True)
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=beta)

    # Finalize graph
    with tf.device("/gpu:0"):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    # Tensorboard summaries
    if summarize:
        misc.log("Initializing logs...", "white")
        summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
        if save_tf_graph:
            summary_log.add_graph(tf.get_default_graph())
        if save_weight_histograms:
            G.setup_weight_histograms()
            D.setup_weight_histograms()

    # Initialize training
    misc.log("Training for %d kimg..." % total_kimg, "white")
    dnnlib.RunContext.get().update("", cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()

    cur_tick, running_mb_counter = -1, 0
    cur_nimg = int(resume_kimg * 1000)
    tick_start_nimg = cur_nimg
    for cN in [cG, cD]:
        cN.lossvals_agg = {k: None for k in ["loss", "reg", "norm", "reg_norm"]}
        cN.opt.reset_optimizer_state()

    # Training loop
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop():
            break

        # Choose training parameters and configure training ops
        sched = training_schedule(sched_args, cur_nimg=cur_nimg, dataset=dataset)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        dataset.configure(sched.minibatch_gpu)

        # Run training ops
        feed_dict = {
            lrate_in_g: sched.G_lrate,
            lrate_in_d: sched.D_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu,
            step: sched.kimg,
        }

        # Several iterations before updating training parameters
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size, sched.minibatch_gpu * num_gpus)
            for cN in [cG, cD]:
                cN.run_reg = lazy_regularization and (running_mb_counter % cN.reg_interval == 0)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            for cN in [cG, cD]:
                cN.lossvals = {k: None for k in ["loss", "reg", "norm", "reg_norm"]}

            # Gradient accumulation
            for _round in rounds:
                cG.lossvals.update(tflib.run([cG.train_op, cG.ops], feed_dict)[1])
                if cG.run_reg:
                    _, cG.lossvals["reg_norm"] = tflib.run([cG.reg_op, cG.reg_norm], feed_dict)

                tflib.run(data_fetch_op, feed_dict)

                cD.lossvals.update(tflib.run([cD.train_op, cD.ops], feed_dict)[1])
                if cD.run_reg:
                    _, cD.lossvals["reg_norm"] = tflib.run([cD.reg_op, cD.reg_norm], feed_dict)

            tflib.run([Gs_update_op], feed_dict)

            # Track loss statistics
            for cN in [cG, cD]:
                for k in cN.lossvals_agg:
                    cN.lossvals_agg[k] = emaAvg(cN.lossvals_agg[k], cN.lossvals[k])

        # Perform maintenance tasks once per tick
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start() + resume_time

            # Report progress
            print(("tick %s kimg %s loss/reg: G (%s %s) D (%s %s) grad norms: G (%s %s) D (%s %s) " +
                   "time %s sec/kimg %s maxGPU %sGB %s") % (
                misc.bold("%-5d" % autosummary("Progress/tick", cur_tick)),
                misc.bcolored("{:>8.1f}".format(autosummary("Progress/kimg", cur_nimg / 1000.0)), "red"),
                misc.bcolored("{:>6.3f}".format(cG.lossvals_agg["loss"] or 0), "blue"),
                misc.bold("{:>6.3f}".format(cG.lossvals_agg["reg"] or 0)),
                misc.bcolored("{:>6.3f}".format(cD.lossvals_agg["loss"] or 0), "blue"),
                misc.bold("{:>6.3f}".format(cD.lossvals_agg["reg"] or 0)),
                misc.cond_bcolored(cG.lossvals_agg["norm"], 20.0, "red"),
                misc.cond_bcolored(cG.lossvals_agg["reg_norm"], 20.0, "red"),
                misc.cond_bcolored(cD.lossvals_agg["norm"], 20.0, "red"),
                misc.cond_bcolored(cD.lossvals_agg["reg_norm"], 20.0, "red"),
                misc.bold("%-10s" % dnnlib.util.format_time(autosummary("Timing/total_sec", total_time))),
                "{:>7.2f}".format(autosummary("Timing/sec_per_kimg", tick_time / tick_kimg)),
                "{:>4.1f}".format(autosummary("Resources/peak_gpu_mem_gb", peak_gpu_mem_op.eval() / 2**30)),
                printname))

            autosummary("Timing/total_hours", total_time / (60.0 * 60.0))
            autosummary("Timing/total_days", total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots
            if img_snapshot_ticks is not None and (cur_tick % img_snapshot_ticks == 0 or done):
                visualize.eval(G, dataset, batch_size=sched.minibatch_gpu, training=True,
                               step=cur_nimg // 1000, grid_size=grid_size, latents=grid_latents,
                               labels=grid_labels, drange_net=drange_net, ratio=ratio, **vis_args)

            if network_snapshot_ticks is not None and (cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path("network-snapshot-%06d.pkl" % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl, remove=False)

                metrics.run(pkl, num_imgs=eval_images_num, run_dir=dnnlib.make_run_dir_path(),
                            data_dir=dnnlib.convert_path(data_dir), num_gpus=num_gpus, ratio=ratio,
                            tf_config=tf_config, mirror_augment=mirror_augment)

                if last_snapshots > 0:
                    # NOTE(review): "network*.pkl" also matches network-eval-snapshot-*.pkl
                    # (which sorts first and is pruned first); confirm this is intended.
                    misc.rm(sorted(glob.glob(dnnlib.make_run_dir_path("network*.pkl")))[:-last_snapshots])

            # Update summaries and RunContext
            if summarize:
                metrics.update_autosummaries()
                tflib.autosummary.save_summaries(summary_log, cur_nimg)

            dnnlib.RunContext.get().update(None, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get().get_last_update_interval() - tick_time

    # Save final snapshot
    misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path("network-final.pkl"), remove=False)

    # All done
    if summarize:
        summary_log.close()
    dataset.close()
def joint_train(
        submit_config,
        opt,
        metric_arg_list,
        sched_args={},                  # Training schedule options.
        grid_args={},                   # Options for setup_snapshot_image_grid().
        dataset_args={},                # Dataset options.
        total_kimg=15000,               # Total length of the training, measured in thousands of real images.
        drange_net=[-1, 1],             # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=1,         # How often to export image snapshots?
        network_snapshot_ticks=10,      # How often to export network snapshots?
        D_repeats=1,                    # How many times to train the discriminator per G iteration (unused here).
        minibatch_repeats=4,            # Number of minibatches to run before adjusting training parameters.
        mirror_augment=False,           # Enable mirror augmentation? (unused here)
        reset_opt_for_new_lod=True,     # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        save_tf_graph=False,            # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,   # Include weight histograms in the tfevents file?
        resume_run_id=None,             # Run ID or network pkl to resume training from, None = start from scratch (unused here).
        resume_snapshot=None,           # Snapshot index to resume from, None = autodetect (unused here).
        resume_kimg=0.0,                # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0,                # Assumed wallclock time at the beginning. Affects reporting.
        *args,
        **kwargs):
    """Jointly train the latent-walk graph selected by ``opt``.

    The graph class is looked up from ``opt.model``/``opt.transform`` and built
    with ``dataset_args``, the graph kwargs derived from ``opt`` and any extra
    ``**kwargs``. Each step samples graph inputs, computes targets with
    ``g.get_target_np`` and runs ``g.joint_loss``, ``g.train_step``,
    ``g.Gs_update_op`` and ``g.G_train_op`` in one session call.

    Once per tick the loop reports progress and saves:
      - an image grid every ``image_snapshot_ticks`` ticks;
      - a (G, D, Gs) pickle on the first tick and every
        ``network_snapshot_ticks`` ticks, and runs the metrics on it;
      - a TF checkpoint every ``opt.model_save_freq`` ticks.
    After training it writes ``network-final.pkl`` to the run dir, and the
    per-step losses (``loss_values.npy``) and their plot (``loss_values.png``)
    to ``opt.output_dir``.

    ``train`` (passed to ``dnnlib.RunContext``), ``config`` and ``tf_config``
    are not parameters or locals here; they are resolved from module scope.
    """
    output_dir = opt.output_dir
    graph_kwargs = util.set_graph_kwargs(opt)
    graph_util = importlib.import_module('graphs.' + opt.model + '.graph_util')
    constants = importlib.import_module('graphs.' + opt.model + '.constants')
    model = graphs.find_model_using_name(opt.model, opt.transform)
    g = model(submit_config=submit_config, dataset_args=dataset_args, **graph_kwargs, **kwargs)
    g.initialize_graph()

    w_snapshot_ticks = opt.model_save_freq
    ctx = dnnlib.RunContext(submit_config, train)
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args)

    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)

    # Set up the snapshot image grid.
    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(g.G, training_set, **grid_args)
    sched = training_loop.training_schedule(cur_nimg=total_kimg*1000, training_set=training_set,
                                            num_gpus=submit_config.num_gpus, **sched_args)
    grid_fakes = g.Gs.run(grid_latents, grid_labels, is_validation=True,
                          minibatch_size=sched.minibatch//submit_config.num_gpus)

    # Set up the run directory.
    print('Setting up run dir...')
    misc.save_image_grid(grid_reals, os.path.join(submit_config.run_dir, 'reals.png'),
                         drange=training_set.dynamic_range, grid_size=grid_size)
    misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg),
                         drange=drange_net, grid_size=grid_size)
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        g.G.setup_weight_histograms()
        g.D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    # Training.
    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    loss_values = []
    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_loop.training_schedule(cur_nimg=cur_nimg, training_set=training_set,
                                                num_gpus=submit_config.num_gpus, **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus, sched.lod)
        if reset_opt_for_new_lod:
            # Reset G's optimizer state whenever the integer part of lod changes.
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod):
                g.G_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training.
        for _mb_repeat in range(minibatch_repeats):
            alpha_for_graph, alpha_for_target = g.get_train_alpha(constants.BATCH_SIZE)
            if not isinstance(alpha_for_graph, list):
                alpha_for_graph = [alpha_for_graph]
                alpha_for_target = [alpha_for_target]
            for ag, at in zip(alpha_for_graph, alpha_for_target):
                feed_dict_out = graph_util.graph_input(g, constants.BATCH_SIZE, seed=0)
                out_zs = g.sess.run(g.outputs_orig, feed_dict_out)
                target_fn, mask_out = g.get_target_np(out_zs, at)
                feed_dict = feed_dict_out
                feed_dict[g.alpha] = ag
                feed_dict[g.target] = target_fn
                feed_dict[g.mask] = mask_out
                feed_dict[g.lod_in] = sched.lod
                feed_dict[g.lrate_in] = sched.D_lrate
                feed_dict[g.minibatch_in] = sched.minibatch
                curr_loss, _, Gs_op, G_op = g.sess.run(
                    [g.joint_loss, g.train_step, g.Gs_update_op, g.G_train_op], feed_dict=feed_dict)
                loss_values.append(curr_loss)
            cur_nimg += sched.minibatch

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print('tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f' % (
                autosummary('Progress/tick', cur_tick),
                autosummary('Progress/kimg', cur_nimg / 1000.0),
                autosummary('Progress/lod', sched.lod),
                autosummary('Progress/minibatch', sched.minibatch),
                dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                autosummary('Timing/sec_per_tick', tick_time),
                autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                autosummary('Timing/maintenance_sec', maintenance_time),
                autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = g.Gs.run(grid_latents, grid_labels, is_validation=True,
                                      minibatch_size=sched.minibatch//submit_config.num_gpus)
                misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net, grid_size=grid_size)
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(submit_config.run_dir, 'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((g.G, g.D, g.Gs), pkl)
                metrics.run(pkl, run_dir=submit_config.run_dir, num_gpus=submit_config.num_gpus, tf_config=tf_config)
            if cur_tick % w_snapshot_ticks == 0 or done:
                # os.path.join keeps an absolute output_dir absolute; the old
                # './{}/...' form silently redirected it under the CWD.
                g.saver.save(g.sess, os.path.join(output_dir, 'model_{}.ckpt'.format(cur_nimg // 1000)),
                             write_meta_graph=False, write_state=False)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Save final results.
    misc.save_pkl((g.G, g.D, g.Gs), os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()
    ctx.close()

    loss_values = np.array(loss_values)
    np.save(os.path.join(output_dir, 'loss_values.npy'), loss_values)
    f, ax = plt.subplots(figsize=(10, 4))
    ax.plot(loss_values)
    f.savefig(os.path.join(output_dir, 'loss_values.png'))
    # Release the figure; pyplot keeps every open figure alive until closed.
    plt.close(f)
def training_loop_infernet(
        I_args={},  # Options for the inference network I (infogan-head/vcgan-head), passed to tflib.Network.
        I_opt_args={},  # Options for the I optimizer (tflib.Optimizer); copied before use, not mutated.
        loss_args={},  # Options for the I loss, passed to dnnlib.util.call_func_by_name.
        sched_args={},  # Fixed schedule, read as attributes: lrate, minibatch_size, minibatch_gpu, tick_kimg.
        grid_args={},  # Options for train.setup_snapshot_image_grid(). NOTE(review): not used in this function.
        metric_arg_list=[],  # Options for MetricGroup.
        tf_config={},  # Options for tflib.init_tf().
        minibatch_repeats=4,  # Number of minibatches run per outer-loop iteration, i.e. between tick checks.
        lazy_regularization=True,  # NOTE(review): not used in this function.
        total_kimg=25000,  # Total length of the training, measured in thousands of minibatch samples.
        mirror_augment=False,  # NOTE(review): not used in this function.
        drange_net=[-1, 1],  # NOTE(review): not used in this function.
        image_snapshot_ticks=50,  # NOTE(review): not used in this function; no image grids are written.
        network_snapshot_ticks=5,  # How often (in ticks) to save network snapshots? None = only save 'network-final.pkl'.
        save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms of I in the tfevents file?
        G_pkl=None,  # Pickle with the pre-trained networks, unpacked as (G, D, I, Gs).
        resume_pkl=None,  # Pickle unpacked as (I, Gs) to resume from, None = train I from scratch.
        resume_kimg=0.0,  # Starting value of the image counter; affects reporting and the total_kimg stop condition.
        resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
        resume_with_new_nets=False,  # Build a fresh I from I_args and copy the resumed weights into it (and into Gs)?
        n_samples_per=10):  # Number of samples for each line in traversal. NOTE(review): not used in this function.
    """Train an inference network ``I`` against a pre-trained generator ``Gs``.

    The generator is loaded from ``G_pkl``.  Only ``I``'s trainables are
    registered with the optimizer, so ``Gs`` is an input to the loss but is not
    optimized here.  The loss function named in ``loss_args`` is called per GPU
    with ``G``, ``I``, ``opt`` and ``minibatch_size`` and must return a
    ``(loss, reg)`` pair; ``reg`` is added to ``loss`` when not ``None``.

    No schedule function is applied: learning rate and minibatch sizes come
    straight from ``sched_args``.  When ``minibatch_size`` exceeds
    ``minibatch_gpu * num_gpus``, the optimizer accumulates gradients over
    several runs of ``I_train_op`` (``minibatch_multiplier``).

    Network snapshots are written as ``(I, G)`` pickles into the run directory
    every ``network_snapshot_ticks`` ticks and as ``network-final.pkl`` at the
    end.  Each periodic snapshot is evaluated by the metric group with
    ``train_infernet=True``.

    NOTE(review): ``I`` is trained against ``Gs`` and ``resume_pkl`` is
    unpacked as ``(rI, rGs)``, but snapshots store ``G``.  Resuming from a
    snapshot therefore loads ``G`` into ``Gs``.  Confirm whether ``(I, Gs)``
    was intended.
    """
    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Construct or load networks.
    with tf.device('/gpu:0'):
        # The I bundled in G_pkl is always replaced below (constructed or resumed); D is not used.
        G, D, I, Gs = misc.load_pkl(G_pkl)
        print('Gs.output_shapes:', Gs.output_shapes)
        if resume_pkl is None or resume_with_new_nets:
            print('Constructing networks...')
            # I consumes images produced by Gs, so its channel/resolution are taken from Gs' first output.
            I = tflib.Network('I',
                              num_channels=Gs.output_shapes[0][1],
                              resolution=Gs.output_shapes[0][2],
                              **I_args)
        if resume_pkl is not None:
            print('Loading networks from "%s"...' % resume_pkl)
            rI, rGs = misc.load_pkl(resume_pkl)
            if resume_with_new_nets:
                I.copy_vars_from(rI)
                Gs.copy_vars_from(rGs)
            else:
                I = rI
                Gs = rGs

    # Print layers.
    Gs.print_layers()
    I.print_layers()

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32, name='minibatch_size_in', shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32, name='minibatch_gpu_in', shape=[])
        # Number of I_train_op runs whose gradients are accumulated before one update is applied.
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in * num_gpus)

    # Setup optimizers.
    I_opt_args = dict(I_opt_args)
    I_opt_args['minibatch_multiplier'] = minibatch_multiplier
    I_opt_args['learning_rate'] = lrate_in
    I_opt = tflib.Optimizer(name='TrainI', **I_opt_args)

    # Build training graph for each GPU.
    data_fetch_ops = []  # NOTE(review): never populated or read in this function.
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # Create GPU-specific shadow copies of I and Gs.
            I_gpu = I if gpu == 0 else I.clone(I.name + '_shadow')
            G_gpu = Gs if gpu == 0 else Gs.clone(Gs.name + '_shadow')

            # Evaluate loss functions.
            with tf.name_scope('I_loss'):
                loss, reg = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    I=I_gpu,
                    opt=I_opt,
                    minibatch_size=minibatch_gpu_in,
                    **loss_args)
                if reg is not None:
                    loss += reg

            # Register gradients (only I is optimized).
            I_opt.register_gradients(tf.reduce_mean(loss), I_gpu.trainables)

    # Setup training ops.
    I_train_op = I_opt.apply_updates()

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            # Fall back to reporting 0 GB when the memory-stats op is unavailable.
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        I.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1  # -1 forces a maintenance tick after the first iteration.
    tick_start_nimg = cur_nimg
    prev_lod = -1.0  # NOTE(review): unused.
    running_mb_counter = 0  # NOTE(review): incremented but never read.
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop():
            break

        # Validate the (fixed) training parameters.
        assert sched_args.minibatch_size % (sched_args.minibatch_gpu * num_gpus) == 0

        # Run training ops.
        feed_dict = {
            lrate_in: sched_args.lrate,
            minibatch_size_in: sched_args.minibatch_size,
            minibatch_gpu_in: sched_args.minibatch_gpu
        }
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched_args.minibatch_size, sched_args.minibatch_gpu * num_gpus)
            cur_nimg += sched_args.minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                tflib.run([I_train_op], feed_dict)

            # Slow path with gradient accumulation: one run per accumulation round.
            else:
                for _round in rounds:
                    tflib.run(I_train_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched_args.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start() + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/minibatch', sched_args.minibatch_size),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((I, G), pkl)
                metrics.run(pkl,
                            run_dir=dnnlib.make_run_dir_path(),
                            num_gpus=num_gpus,
                            tf_config=tf_config,
                            train_infernet=True)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update(cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get().get_last_update_interval() - tick_time

    # Save final snapshot.
    misc.save_pkl((I, G), dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
def training_loop_vae(
        G_args={},  # Options for generator network.
        E_args={},  # Options for encoder network.
        D_args={},  # Options for discriminator network; None = train without a discriminator.
        G_opt_args={},  # Options for the shared encoder+generator optimizer.
        D_opt_args={},  # Options for discriminator optimizer.
        G_loss_args={},  # Options for the encoder+generator loss.
        D_loss_args={},  # Options for discriminator loss.
        dataset_args={},  # Options for dataset.load_dataset().
        sched_args={},  # Options for train.TrainingSchedule.
        grid_args={},  # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],  # Options for MetricGroup.
        tf_config={},  # Options for tflib.init_tf().
        data_dir=None,  # Directory to load datasets from.
        minibatch_repeats=1,  # Number of minibatches to run before adjusting training parameters.
        total_kimg=25000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=False,  # Enable mirror augment?
        drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=50,  # How often to save image snapshots? None = only save 'reals.png' and 'fakes_init.png'.
        network_snapshot_ticks=50,  # How often to save network snapshots? None = only save 'network-final.pkl'.
        save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_pkl=None,  # Network pickle to resume training from, None = train from scratch.
        resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
        resume_with_new_nets=False,  # Construct new networks according to E/G/D_args before resuming training?
        traversal_grid=False,  # Used for disentangled representation learning.
        n_discrete=0,  # Number of discrete latents in model.
        n_continuous=4,  # Number of continuous latents in model.
        topk_dims_to_show=20,  # Number of top disentangled dimensions to show in a snapshot.
        subgroup_sizes_ls=None,  # Summed into G's input width when forward_eg is set.
        subspace_sizes_ls=None,  # NOTE(review): not used in this function.
        forward_eg=False,  # Feed latents through append_gfeats before G.run (and widen G's input).
        n_samples_per=10):  # Number of samples for each line in traversal.
    """Train an encoder ``E`` and generator ``G`` (VAE-style), optionally with a discriminator ``D``.

    ``E`` and ``G`` share one optimizer: their trainables are merged and
    receive gradients of the ``G_loss_args`` loss, which is computed from
    staged real images.  When ``D_args`` is not ``None``, ``D`` is built and
    trained by its own optimizer on the ``D_loss_args`` loss.  Real images and
    labels are staged per GPU in non-trainable ``reals``/``labels`` variables
    that ``data_fetch_op`` refreshes.

    Each training round runs, in order: one E+G step, one data fetch, and (if
    ``D`` is used) one D step.  Under gradient accumulation
    (``sched.minibatch_size > sched.minibatch_gpu * num_gpus``), this sequence
    is repeated for every accumulation round.  Each accumulated E+G gradient
    is therefore computed on a freshly fetched minibatch rather than on the
    same staged reals.

    Snapshots are ``(E, G)`` or ``(E, G, D)`` pickles in the run directory.
    With ``traversal_grid`` and ``topk_dims_to_show > 0``, image grids show
    the dimensions with the largest ``'tpl_per_dim'`` metric value when the
    metric group reports it.
    """
    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Train with a discriminator only when D_args is given.
    use_D = D_args is not None

    # Load training set.
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir),
                                        verbose=True,
                                        **dataset_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        training_set, **grid_args)
    # NOTE(review): this result is overwritten before it is ever read.
    grid_fakes = add_outline(grid_reals, width=1)
    misc.save_image_grid(grid_reals,
                         dnnlib.make_run_dir_path('reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)

    # Construct or load networks.
    with tf.device('/gpu:0'):
        if resume_pkl is None or resume_with_new_nets:
            print('Constructing networks...')
            E = tflib.Network('E',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              input_shape=[None] + training_set.shape,
                              **E_args)
            G = tflib.Network(
                'G',
                num_channels=training_set.shape[0],
                resolution=training_set.shape[1],
                label_size=training_set.label_size,
                input_shape=[None, n_discrete + G_args.latent_size] if not forward_eg else [
                    None, n_discrete + G_args.latent_size + sum(subgroup_sizes_ls)
                ],
                **G_args)
            if use_D:
                D = tflib.Network('D',
                                  num_channels=training_set.shape[0],
                                  resolution=training_set.shape[1],
                                  label_size=training_set.label_size,
                                  input_shape=[None, D_args.latent_size],
                                  **D_args)
        if resume_pkl is not None:
            print('Loading networks from "%s"...' % resume_pkl)
            if use_D:
                rE, rG, rD = misc.load_pkl(resume_pkl)
            else:
                rE, rG = misc.load_pkl(resume_pkl)
            if resume_with_new_nets:
                E.copy_vars_from(rE)
                G.copy_vars_from(rG)
                if use_D:
                    D.copy_vars_from(rD)
            else:
                E = rE
                G = rG
                if use_D:
                    D = rD

    # Print layers and generate initial image snapshot.
    E.print_layers()
    G.print_layers()
    if use_D:
        D.print_layers()
    # Schedule at the end of training; only its minibatch_gpu is used, to size G.run batches.
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              **sched_args)
    if traversal_grid:
        if topk_dims_to_show > 0:
            topk_dims = np.arange(min(topk_dims_to_show, n_continuous))
        else:
            topk_dims = np.arange(n_continuous)
        grid_size, grid_latents, grid_labels = get_grid_latents(
            n_discrete, n_continuous, n_samples_per, G, grid_labels, topk_dims)
    else:
        grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    print('grid_size:', grid_size)
    print('grid_latents.shape:', grid_latents.shape)
    print('grid_labels.shape:', grid_labels.shape)
    grid_fakes, _, _, _, _, _, _, lie_vars = get_return_v(
        G.run(append_gfeats(grid_latents, G) if forward_eg else grid_latents,
              grid_labels,
              is_validation=True,
              minibatch_size=sched.minibatch_gpu,
              randomize_noise=True), 8)
    print('Lie_vars:', lie_vars[0])
    grid_fakes = add_outline(grid_fakes, width=1)
    misc.save_image_grid(grid_fakes,
                         dnnlib.make_run_dir_path('fakes_init.png'),
                         drange=drange_net,
                         grid_size=grid_size)

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32, name='minibatch_size_in', shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32, name='minibatch_gpu_in', shape=[])
        # Number of train-op runs whose gradients are accumulated before one update is applied.
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in * num_gpus)

    # Setup optimizers.
    G_opt_args = dict(G_opt_args)
    G_opt_args['minibatch_multiplier'] = minibatch_multiplier
    G_opt_args['learning_rate'] = lrate_in
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    if use_D:
        D_opt_args = dict(D_opt_args)
        D_opt_args['minibatch_multiplier'] = minibatch_multiplier
        D_opt_args['learning_rate'] = lrate_in
        D_opt = tflib.Optimizer(name='TrainD', **D_opt_args)

    # Build training graph for each GPU.
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # Create GPU-specific shadow copies of E, G and D.
            E_gpu = E if gpu == 0 else E.clone(E.name + '_shadow')
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            if use_D:
                D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                sched = training_schedule(cur_nimg=int(resume_kimg * 1000),
                                          training_set=training_set,
                                          **sched_args)
                reals_var = tf.Variable(
                    name='reals',
                    trainable=False,
                    initial_value=tf.zeros([sched.minibatch_gpu] + training_set.shape))
                labels_var = tf.Variable(name='labels',
                                         trainable=False,
                                         initial_value=tf.zeros([
                                             sched.minibatch_gpu,
                                             training_set.label_size
                                         ]))
                reals_write, labels_write = training_set.get_minibatch_tf()
                reals_write, labels_write = process_reals(
                    reals_write, labels_write, 0., mirror_augment,
                    training_set.dynamic_range, drange_net)
                # Overwrite only the first minibatch_gpu_in rows; keep the rest of the buffer.
                reals_write = tf.concat(
                    [reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                labels_write = tf.concat(
                    [labels_write, labels_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_var, reals_write)]
                data_fetch_ops += [tf.assign(labels_var, labels_write)]
                reals_read = reals_var[:minibatch_gpu_in]
                labels_read = labels_var[:minibatch_gpu_in]

            # Evaluate loss functions.
            if use_D:
                with tf.name_scope('G_loss'):
                    G_loss = dnnlib.util.call_func_by_name(
                        E=E_gpu,
                        G=G_gpu,
                        D=D_gpu,
                        opt=G_opt,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        reals=reals_read,
                        labels=labels_read,
                        **G_loss_args)
                with tf.name_scope('D_loss'):
                    D_loss = dnnlib.util.call_func_by_name(
                        E=E_gpu,
                        D=D_gpu,
                        opt=D_opt,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        reals=reals_read,
                        labels=labels_read,
                        **D_loss_args)
            else:
                with tf.name_scope('G_loss'):
                    G_loss = dnnlib.util.call_func_by_name(
                        E=E_gpu,
                        G=G_gpu,
                        opt=G_opt,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        reals=reals_read,
                        labels=labels_read,
                        **G_loss_args)

            # Register gradients: E and G are optimized jointly by G_opt.
            EG_gpu_trainables = collections.OrderedDict(
                list(E_gpu.trainables.items()) + list(G_gpu.trainables.items()))
            G_opt.register_gradients(tf.reduce_mean(G_loss), EG_gpu_trainables)
            if use_D:
                D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    G_train_op = G_opt.apply_updates()
    if use_D:
        D_train_op = D_opt.apply_updates()

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            # Fall back to reporting 0 GB when the memory-stats op is unavailable.
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        if use_D:
            D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1  # -1 forces a maintenance tick after the first iteration.
    tick_start_nimg = cur_nimg
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu, 0)

        # Run training ops.
        feed_dict = {
            lrate_in: sched.G_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu
        }
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size, sched.minibatch_gpu * num_gpus)
            cur_nimg += sched.minibatch_size

            # One iteration per accumulation round.  With a single round this
            # is the plain fast path.  With several rounds, fetching between
            # E+G steps gives every accumulated gradient its own minibatch,
            # because the E+G loss reads the staged reals.
            for _round in rounds:
                tflib.run(G_train_op, feed_dict)
                tflib.run(data_fetch_op, feed_dict)
                if use_D:
                    tflib.run(D_train_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start() + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/minibatch', sched.minibatch_size),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                if use_D:
                    misc.save_pkl((E, G, D), pkl)
                else:
                    misc.save_pkl((E, G), pkl)
                met_outs = metrics.run(pkl,
                                       run_dir=dnnlib.make_run_dir_path(),
                                       data_dir=dnnlib.convert_path(data_dir),
                                       num_gpus=num_gpus,
                                       tf_config=tf_config,
                                       is_vae=True,
                                       use_D=use_D,
                                       Gs_kwargs=dict(is_validation=True))
                if topk_dims_to_show > 0:
                    # Guard against a metric group that returns no results.
                    if met_outs is not None and 'tpl_per_dim' in met_outs:
                        # One score per continuous dimension; show the highest-scoring ones first.
                        avg_distance_per_dim = met_outs['tpl_per_dim']
                        topk_dims = np.argsort(avg_distance_per_dim)[::-1][:topk_dims_to_show]
                    else:
                        topk_dims = np.arange(min(topk_dims_to_show, n_continuous))
                else:
                    topk_dims = np.arange(n_continuous)
            if image_snapshot_ticks is not None and (
                    cur_tick % image_snapshot_ticks == 0 or done):
                if traversal_grid:
                    grid_size, grid_latents, grid_labels = get_grid_latents(
                        n_discrete, n_continuous, n_samples_per, G, grid_labels, topk_dims)
                else:
                    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
                grid_fakes, _, _, _, _, _, _, lie_vars = get_return_v(
                    G.run(append_gfeats(grid_latents, G) if forward_eg else grid_latents,
                          grid_labels,
                          is_validation=True,
                          minibatch_size=sched.minibatch_gpu,
                          randomize_noise=True), 8)
                print('Lie_vars:', lie_vars[0])
                grid_fakes = add_outline(grid_fakes, width=1)
                misc.save_image_grid(grid_fakes,
                                     dnnlib.make_run_dir_path('fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % 0,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get().get_last_update_interval() - tick_time

    # Save final snapshot.
    if use_D:
        misc.save_pkl((E, G, D), dnnlib.make_run_dir_path('network-final.pkl'))
    else:
        misc.save_pkl((E, G), dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
def training_loop(
        G_args={},  # Options for generator network.
        D_args={},  # Options for discriminator network.
        G_opt_args={},  # Options for generator optimizer.
        D_opt_args={},  # Options for discriminator optimizer.
        G_loss_args={},  # NOTE(review): not used; both networks train on the D_loss_args loss.
        D_loss_args={},  # Options for the shared loss; the function must return a (loss, reg) pair.
        dataset_args={},  # Options for dataset.load_dataset().
        sched_args={},  # Read as attributes: batch_size, lr, decay_step, decay_rate, stair, tick_kimg.
        grid_args={},  # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],  # Options for MetricGroup.
        tf_config={},  # Options for tflib.init_tf().
        data_dir=None,  # Directory to load datasets from.
        G_smoothing_kimg=10.0,  # NOTE(review): not used; Gs is kept as a plain copy of G (no running average).
        minibatch_repeats=4,  # NOTE(review): not used in this function.
        lazy_regularization=True,  # NOTE(review): not used in this function.
        G_reg_interval=4,  # NOTE(review): not used in this function.
        D_reg_interval=16,  # NOTE(review): not used in this function.
        reset_opt_for_new_lod=True,  # NOTE(review): not used in this function.
        total_kimg=25000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=False,  # Enable mirror augment?
        drange_net=[0, 1],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=50,  # How often to save image snapshots? None = only save 'reals.png' and 'fakes_init.png'.
        network_snapshot_ticks=50,  # How often to save network snapshots? None = only save 'network-final.pkl'.
        save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_pkl=None,  # Network pickle (G, D, Gs) to resume training from, None = train from scratch.
        resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
        resume_with_new_nets=False):  # Construct new networks according to G_args and D_args before resuming training?
    """Train ``G`` and ``D`` jointly on one shared loss with an exponentially decaying learning rate.

    Both optimizers receive gradients of the ``loss`` returned by the function
    named in ``D_loss_args``; its ``reg`` output is ignored.  The learning rate
    is ``tf.train.exponential_decay`` driven by ``global_step``.  The step
    increment ``add_global`` is attached as a control dependency while the
    gradients are registered, so one training step advances the counter by
    one.  Each step consumes ``sched_args.batch_size * num_gpus`` images.

    Outputs:
      * image grids from ``G`` every ``image_snapshot_ticks`` ticks;
      * ``(G, D, Gs)`` pickles every ``network_snapshot_ticks`` ticks, each
        evaluated by the metric group;
      * ``network-final.pkl`` at the end.

    ``Gs`` has no update op of its own.  It is therefore refreshed from ``G``
    right before every pickle is written; without this, saved snapshots (and
    any metric reading ``Gs``) would carry its initial or resumed weights.

    When resuming, the step counter is restored from the kimg value encoded in
    the snapshot file name (``network-snapshot-%06d.pkl``).  If the name
    carries no number, e.g. ``network-final.pkl``, ``resume_kimg`` is used.
    The value is converted to optimizer steps, so the learning-rate schedule
    continues where it stopped.
    """
    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir), verbose=True, **dataset_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(training_set, **grid_args)
    misc.save_image_grid(grid_reals, dnnlib.make_run_dir_path('reals.png'), drange=training_set.dynamic_range, grid_size=grid_size)

    # Construct or load networks.
    training_set.configure(minibatch_size=sched_args.batch_size)
    images_per_step = sched_args.batch_size * num_gpus  # cur_nimg advances by this much per training step.
    start = 0  # Initial value of the learning-rate step counter.
    with tf.device('/gpu:0'):
        if resume_pkl is None or resume_with_new_nets:
            print('Constructing networks...')
            G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1],
                              label_size=training_set.label_size, **G_args)
            D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1],
                              label_size=training_set.label_size, **D_args)
            Gs = G.clone('Gs')
        if resume_pkl is not None:
            print('Loading networks from "%s"...' % resume_pkl)
            rG, rD, rGs = misc.load_pkl(resume_pkl)
            if resume_with_new_nets:
                G.copy_vars_from(rG)
                D.copy_vars_from(rD)
                Gs.copy_vars_from(rGs)
            else:
                G = rG
                D = rD
                Gs = rGs
            # Snapshot names encode progress in kimg, while global_step counts
            # optimizer steps of images_per_step images each.
            try:
                resumed_kimg = int(resume_pkl.split('-')[-1].split('.')[0])
            except ValueError:  # e.g. 'network-final.pkl' carries no progress counter.
                resumed_kimg = resume_kimg
            start = int(resumed_kimg * 1000) // images_per_step

    # Print layers and generate initial image snapshot.
    G.print_layers()
    D.print_layers()
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    grid_fakes = G.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched_args.batch_size)
    misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes_init.png'), drange=drange_net, grid_size=grid_size)

    # Learning-rate schedule and optimizers.
    global_step = tf.Variable(start, trainable=False, name='learning_rate_step')
    learning_rate = tf.train.exponential_decay(sched_args.lr, global_step, sched_args.decay_step,
                                               sched_args.decay_rate, staircase=sched_args.stair)
    add_global = global_step.assign_add(1)
    D_opt = tflib.Optimizer(name='TrainD', learning_rate=learning_rate, **D_opt_args)
    G_opt = tflib.Optimizer(name='TrainG', learning_rate=learning_rate, **G_opt_args)

    # Build training graph for each GPU.
    for gpu in range(num_gpus):
        print('build graph on gpu %s' % str(gpu))
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # Create GPU-specific shadow copies of G and D.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            with tf.name_scope('DataFetch'):
                reals_read, labels_read = training_set.get_minibatch_tf()
                reals_read, labels_read = process_reals(reals_read, labels_read, mirror_augment,
                                                        training_set.dynamic_range, drange_net)
            with tf.name_scope('Loss'), tf.control_dependencies(None):
                # NOTE(review): reg is discarded here; confirm the loss already includes any regularizer.
                loss, reg = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set,
                                                          minibatch_size=sched_args.batch_size, reals=reals_read,
                                                          labels=labels_read, **D_loss_args)
            # Tie the step increment to gradient computation so it runs with the train ops.
            with tf.control_dependencies([add_global]):
                G_opt.register_gradients(loss, G_gpu.trainables)
                D_opt.register_gradients(loss, D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            # Fall back to reporting 0 GB when the memory-stats op is unavailable.
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1  # -1 forces a maintenance tick after the first step.
    tick_start_nimg = cur_nimg
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop():
            break
        # `loss` is the tensor from the last GPU built above; it is only used for logging.
        loss_, _, _, lr_ = tflib.run([loss, G_train_op, D_train_op, learning_rate])
        cur_nimg += images_per_step
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched_args.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start() + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f minibatch %-4d time %-12s sec/tick %-7.1f '
                'sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f loss %-8.1f lr %-2.5f' % (
                    autosummary('Progress/tick', cur_tick),
                    autosummary('Progress/kimg', cur_nimg / 1000.0),
                    autosummary('Progress/minibatch', sched_args.batch_size),
                    dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                    autosummary('Timing/sec_per_tick', tick_time),
                    autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                    autosummary('Timing/maintenance_sec', maintenance_time),
                    autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2 ** 30),
                    autosummary('loss', loss_),
                    autosummary('lr', lr_)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (cur_tick % image_snapshot_ticks == 0 or done):
                grid_fakes = G.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched_args.batch_size)
                misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net, grid_size=grid_size)
            if network_snapshot_ticks is not None and (cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                Gs.copy_vars_from(G)  # Gs is never trained directly; keep the saved copy current.
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl, run_dir=dnnlib.make_run_dir_path(), data_dir=dnnlib.convert_path(data_dir),
                            num_gpus=num_gpus, tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % 0.0, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get().get_last_update_interval() - tick_time

    # Save final snapshot.
    Gs.copy_vars_from(G)
    misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
def training_loop(
        submit_config,
        G_args={},                  # Options for generator network.
        D_args={},                  # Options for discriminator network.
        G_opt_args={},              # Options for generator optimizer.
        D_opt_args={},              # Options for discriminator optimizer.
        G_loss_args={},             # Options for generator loss.
        D_loss_args={},             # Options for discriminator loss.
        dataset_args={},            # Options for dataset.load_dataset().
        sched_args={},              # Options for train.TrainingSchedule.
        grid_args={},               # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],         # Options for MetricGroup.
        tf_config={},               # Options for tflib.init_tf().
        G_smoothing_kimg=10.0,      # Half-life of the running average of generator weights.
        D_repeats=1,                # How many times the discriminator is trained per G iteration.
        minibatch_repeats=4,        # Number of minibatches to run before adjusting training parameters.
        total_kimg=15000,           # Total length of the training, measured in thousands of real images.
        mirror_augment=False,       # Enable mirror augment?
        drange_net=[-1, 1],         # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=1,     # How often to export image snapshots? None = no periodic image snapshots.
        network_snapshot_ticks=10,  # How often to export network snapshots? None = only 'network-final.pkl'.
        save_tf_graph=False,        # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_run_id=None,         # Run ID or network pkl to resume training from, None = start from scratch.
        resume_snapshot=None,       # Snapshot index to resume training from, None = autodetect.
        resume_kimg=0.0,            # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0):           # Assumed wallclock time at the beginning. Affects reporting.
    """Train G/D (plus the moving-average generator Gs) with a StyleGAN1-style loop.

    Networks are either loaded from ``resume_run_id``/``resume_snapshot`` or
    freshly constructed. Each GPU gets its own loss graph. Every minibatch
    repeat runs ``D_repeats`` discriminator steps, each also updating Gs,
    followed by one generator step.

    The training schedule is evaluated once, at ``total_kimg * 1000`` images,
    and held fixed for the whole run. The dataset's minibatch size is
    configured once from that schedule.

    Once per tick the loop:
      * logs progress,
      * writes one image grid per output scale (subdirectories ``<r>x<r>`` of
        the run dir),
      * writes ``network-snapshot-*.pkl`` and runs the metrics on it.

    A network snapshot is also forced on the first tick. A final
    ``network-final.pkl`` is always written.
    """
    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args)

    # Construct networks.
    with tf.device("/gpu:0"):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print("Constructing networks...")
            G = tflib.Network(
                name="G",
                num_inputs=2,  # one for latents and one for labels
                num_channels=training_set.shape[0],
                resolution=training_set.shape[1],
                label_size=training_set.label_size,
                **G_args)
            D = tflib.Network(
                name="D",
                num_inputs=int(np.log2(training_set.shape[1])) - 1 + 1,  # +1 for labels :)
                num_channels=training_set.shape[0],
                resolution=training_set.shape[1],
                label_size=training_set.label_size,
                **D_args)
            Gs = G.clone("Gs")
    G.print_layers()
    D.print_layers()

    print("Building TensorFlow graph...")
    with tf.name_scope("Inputs"), tf.device("/cpu:0"):
        lrate_in = tf.placeholder(tf.float32, name="lrate_in", shape=[])
        minibatch_in = tf.placeholder(tf.int32, name="minibatch_in", shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        Gs_beta = (0.5**tf.div(tf.cast(minibatch_in, tf.float32), G_smoothing_kimg * 1000.0)
                   if G_smoothing_kimg > 0.0 else 0.0)

    G_opt = tflib.Optimizer(name="TrainG", learning_rate=lrate_in, **G_opt_args)
    D_opt = tflib.Optimizer(name="TrainD", learning_rate=lrate_in, **D_opt_args)
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope("GPU%d" % gpu), tf.device("/gpu:%d" % gpu):
            # GPU 0 uses the master networks; other GPUs use shadow clones.
            G_gpu = G if gpu == 0 else G.clone(G.name + "_shadow")
            D_gpu = D if gpu == 0 else D.clone(D.name + "_shadow")
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(
                reals,
                mirror_augment,
                training_set.dynamic_range,
                drange_net,
                depth=training_set.resolution_log2 - 1,
            )
            with tf.name_scope("G_loss"):
                G_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set,
                    minibatch_size=minibatch_split, **G_loss_args)
            with tf.name_scope("D_loss"):
                D_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set,
                    minibatch_size=minibatch_split, reals=reals, labels=labels,
                    **D_loss_args)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device("/gpu:0"):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)

    # Choose training parameters and configure training ops.
    # NOTE: evaluated once at the *end* of training and never recomputed below.
    sched = training_schedule(cur_nimg=total_kimg * 1000, training_set=training_set, **sched_args)
    minibatch_gpu = sched.minibatch_size // submit_config.num_gpus

    print("Setting up snapshot image grid...")
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(
        G, training_set, **grid_args)
    grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=minibatch_gpu)

    print("Setting up run dir...")
    # One output directory per generator output scale, starting at 4x4.
    # Assumes grid_fakes is a per-scale list ordered from the lowest resolution upward (not verified here).
    fake_multi_scale_dirs = [
        os.path.join(submit_config.run_dir, str(2**res) + "x" + str(2**res))
        for res in range(2, 2 + len(grid_fakes))
    ]
    misc.save_image_grid(
        grid_reals,
        os.path.join(submit_config.run_dir, "reals.png"),
        drange=training_set.dynamic_range,
        grid_size=grid_size,
    )
    misc.save_image_grids(
        grid_fakes,
        [os.path.join(scale_dir, "fakes%06d.png" % resume_kimg) for scale_dir in fake_multi_scale_dirs],
        drange=drange_net,
        grid_size=grid_size,
    )
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print("Training...\n")
    ctx.update("", cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg

    # configure the training_set to a proper minibatch size
    training_set.configure(minibatch_gpu)

    # The schedule is fixed, so the feed dicts are loop invariants.
    D_feed = {lrate_in: sched.D_lrate, minibatch_in: sched.minibatch_size}
    G_feed = {lrate_in: sched.G_lrate, minibatch_in: sched.minibatch_size}

    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop():
            break

        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run([D_train_op, Gs_update_op], D_feed)
                cur_nimg += sched.minibatch_size
            tflib.run([G_train_op], G_feed)

        # Perform maintenance tasks once per tick.
        done = cur_nimg >= total_kimg * 1000
        if not (cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done):
            continue
        cur_tick += 1
        tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
        tick_start_nimg = cur_nimg
        tick_time = ctx.get_time_since_last_update()
        total_time = ctx.get_time_since_start() + resume_time

        # Report progress.
        print(
            "tick %-5d kimg %-8.1f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f"
            % (
                autosummary("Progress/tick", cur_tick),
                autosummary("Progress/kimg", cur_nimg / 1000.0),
                autosummary("Progress/minibatch", sched.minibatch_size),
                dnnlib.util.format_time(autosummary("Timing/total_sec", total_time)),
                autosummary("Timing/sec_per_tick", tick_time),
                autosummary("Timing/sec_per_kimg", tick_time / tick_kimg),
                autosummary("Timing/maintenance_sec", maintenance_time),
                autosummary("Resources/peak_gpu_mem_gb", peak_gpu_mem_op.eval() / 2**30),
            ))
        autosummary("Timing/total_hours", total_time / (60.0 * 60.0))
        autosummary("Timing/total_days", total_time / (24.0 * 60.0 * 60.0))

        # Save snapshots. A snapshot interval of None disables periodic snapshots
        # (previously this raised TypeError on `cur_tick % None`).
        if image_snapshot_ticks is not None and (cur_tick % image_snapshot_ticks == 0 or done):
            grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=minibatch_gpu)
            misc.save_image_grids(
                grid_fakes,
                [os.path.join(scale_dir, "fakes%06d.png" % (cur_nimg // 1000)) for scale_dir in fake_multi_scale_dirs],
                drange=drange_net,
                grid_size=grid_size,
            )
        if network_snapshot_ticks is not None and (
                cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1):
            pkl = os.path.join(submit_config.run_dir, "network-snapshot-%06d.pkl" % (cur_nimg // 1000))
            misc.save_pkl((G, D, Gs), pkl)
            metrics.run(pkl, run_dir=submit_config.run_dir, num_gpus=submit_config.num_gpus, tf_config=tf_config)

        # Update summaries and RunContext.
        metrics.update_autosummaries()
        tflib.autosummary.save_summaries(summary_log, cur_nimg)
        ctx.update(cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
        maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    misc.save_pkl((G, D, Gs), os.path.join(submit_config.run_dir, "network-final.pkl"))
    summary_log.close()
    ctx.close()
def training_loop_refinement(
        G_args={},                  # Options for generator network.
        D_args={},                  # Options for discriminator network (unused in this loop).
        G_opt_args={},              # Options for generator optimizer.
        D_opt_args={},              # Options for discriminator optimizer (unused in this loop).
        G_loss_args={},             # Options for generator loss.
        D_loss_args={},             # Options for discriminator loss (unused in this loop).
        dataset_args={},            # Options for dataset.load_dataset().
        sched_args={},              # Options for train.TrainingSchedule.
        grid_args={},               # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],         # Options for MetricGroup.
        tf_config={},               # Options for tflib.init_tf().
        data_dir=None,              # Directory to load datasets from.
        G_smoothing_kimg=10.0,      # Half-life of the running average of generator weights.
        minibatch_repeats=4,        # Number of minibatches to run before adjusting training parameters.
        lazy_regularization=True,   # Perform regularization as a separate training step?
        G_reg_interval=4,           # How often the perform regularization for G? Ignored if lazy_regularization=False.
        D_reg_interval=16,          # How often the perform regularization for D? Ignored if lazy_regularization=False.
        reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg=25000,           # Total length of the training, measured in thousands of real images.
        mirror_augment=False,       # Enable mirror augment?
        drange_net=[-1, 1],         # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=50,    # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'.
        network_snapshot_ticks=50,  # How often to save network snapshots? None = only save 'networks-final.pkl'.
        save_tf_graph=False,        # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=True,  # Include weight histograms in the tfevents file?
        resume_pkl=None,            # Network pickle to resume training from, None = train from scratch.
        resume_kimg=0.0,            # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0,            # Assumed wallclock time at the beginning. Affects reporting.
        resume_with_new_nets=False):  # Construct new networks according to G_args and D_args before resuming training?
    """Train a generator alone (no discriminator) in a StyleGAN2-style loop.

    The loss function is called with ``D=None``, the current real images as
    ``reals`` and the dataset *labels* as ``latents``. The loss must return
    ``(loss, reg)``.

    Optional regularization is handled lazily via a second optimizer that
    shares state with the main one. G is tracked by the moving-average copy
    Gs.

    Layers whose trainable names contain any substring listed in
    ``G_args['freeze_layers']`` are removed from every per-GPU G's trainables,
    so they receive no updates.

    Snapshots and the final pickle are saved as ``(G, None, Gs)``.
    ``D_args``, ``D_opt_args``, ``D_loss_args`` and ``D_reg_interval`` are
    accepted for interface compatibility but are not used.
    """
    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir), verbose=True, **dataset_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(training_set, **grid_args)
    misc.save_image_grid(grid_reals, dnnlib.make_run_dir_path('reals.png'),
                         drange=training_set.dynamic_range, grid_size=grid_size)

    # Construct or load networks.
    with tf.device('/gpu:0'):
        if resume_pkl is None or resume_with_new_nets:
            print('Constructing networks...')
            G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1],
                              label_size=training_set.label_size, **G_args)
            Gs = G.clone('Gs')
        if resume_pkl is not None:
            print('Loading networks from "%s"...' % resume_pkl)
            # Previously the generator was unpacked into `_rG` and deleted, so the
            # plain-resume branch below (`G = rG`) raised NameError.
            rG, _rD, rGs = misc.load_pkl(resume_pkl)
            del _rD  # The discriminator slot is not used by this loop.
            if resume_with_new_nets:
                # Initialise both the trained and the averaged generator from the averaged weights.
                G.copy_vars_from(rGs)
                Gs.copy_vars_from(rGs)
            else:
                G = rG
                Gs = rGs
            del rG, rGs

    # Set constant noise input for both G and Gs (same seed, hence identical noise).
    # Explicit comparison: only an explicit False (or 0) triggers this; a missing key does not.
    if G_args.get("randomize_noise", None) == False:  # noqa: E712
        for net in (G, Gs):
            noise_vars = [var for name, var in net.components.synthesis.vars.items() if name.startswith('noise')]
            rnd = np.random.RandomState(123)
            tflib.set_vars({var: rnd.randn(*var.shape.as_list()) for var in noise_vars})  # [height, width]

    # Print layers and generate initial image snapshot.
    G.print_layers()
    sched = training_schedule(cur_nimg=total_kimg * 1000, training_set=training_set, **sched_args)
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu)
    misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes_init.png'), drange=drange_net, grid_size=grid_size)

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32, name='minibatch_size_in', shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32, name='minibatch_gpu_in', shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in * num_gpus)
        Gs_beta = (0.5**tf.div(tf.cast(minibatch_size_in, tf.float32), G_smoothing_kimg * 1000.0)
                   if G_smoothing_kimg > 0.0 else 0.0)

    # Setup optimizers (copy so the caller's dict is not modified).
    G_opt_args = dict(G_opt_args)
    G_opt_args['minibatch_multiplier'] = minibatch_multiplier
    G_opt_args['learning_rate'] = lrate_in
    if lazy_regularization:
        # Compensate for the extra regularization steps taken every G_reg_interval minibatches.
        mb_ratio = G_reg_interval / (G_reg_interval + 1)
        G_opt_args['learning_rate'] *= mb_ratio
        if 'beta1' in G_opt_args:
            G_opt_args['beta1'] **= mb_ratio
        if 'beta2' in G_opt_args:
            G_opt_args['beta2'] **= mb_ratio
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args)

    # Freeze layers. Kept in a local variable: assigning an attribute on G_args
    # crashed for a plain-dict G_args and mutated the caller's config otherwise.
    freeze_layers = list(G_args.get("freeze_layers", []))

    def freeze_vars(gen, verbose=True):
        """Drop every trainable of `gen` whose name contains a frozen-layer substring."""
        assert len(freeze_layers) > 0
        for name in list(gen.trainables.keys()):
            if any(layer in name for layer in freeze_layers):
                del gen.trainables[name]
                if verbose:
                    print(f"Freezed {name}")

    # Build training graph for each GPU.
    data_fetch_ops = []
    loss_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # Create GPU-specific shadow copies of G.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            if freeze_layers:
                freeze_vars(G_gpu, verbose=False)

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                sched = training_schedule(cur_nimg=int(resume_kimg * 1000), training_set=training_set, **sched_args)
                reals_var = tf.Variable(name='reals', trainable=False,
                                        initial_value=tf.zeros([sched.minibatch_gpu] + training_set.shape))
                labels_var = tf.Variable(name='labels', trainable=False,
                                         initial_value=tf.zeros([sched.minibatch_gpu, training_set.label_size]))
                reals_write, labels_write = training_set.get_minibatch_tf()
                reals_write, labels_write = process_reals(reals_write, labels_write, lod_in, mirror_augment,
                                                          training_set.dynamic_range, drange_net)
                reals_write = tf.concat([reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                labels_write = tf.concat([labels_write, labels_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_var, reals_write)]
                data_fetch_ops += [tf.assign(labels_var, labels_write)]
                reals_read = reals_var[:minibatch_gpu_in]
                labels_read = labels_var[:minibatch_gpu_in]

            # Evaluate loss functions.
            lod_assign_ops = []
            if 'lod' in G_gpu.vars:
                lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)]
            with tf.control_dependencies(lod_assign_ops):
                with tf.name_scope('G_loss'):
                    G_loss, G_reg = dnnlib.util.call_func_by_name(
                        G=G_gpu, D=None, opt=G_opt, training_set=training_set,
                        minibatch_size=minibatch_gpu_in, reals=reals_read,
                        latents=labels_read, **G_loss_args)
                # Reported loss excludes the regularization term.
                loss_ops.append(G_loss)

            # Register gradients.
            if not lazy_regularization:
                if G_reg is not None:
                    G_loss += G_reg
            elif G_reg is not None:
                G_reg_opt.register_gradients(tf.reduce_mean(G_reg * G_reg_interval), G_gpu.trainables)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    loss_op = tf.reduce_mean(tf.concat(loss_ops, axis=0))
    G_train_op = G_opt.apply_updates()
    G_reg_op = G_reg_opt.apply_updates(allow_no_op=True)
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0
    loss_per_batch_sum = 0
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        feed_dict = {
            lod_in: sched.lod,
            lrate_in: sched.G_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu,
        }
        tflib.run(data_fetch_op, feed_dict)
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size, sched.minibatch_gpu * num_gpus)
            run_G_reg = (lazy_regularization and running_mb_counter % G_reg_interval == 0)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                loss, _ = tflib.run([loss_op, G_train_op], feed_dict)
                tflib.run([data_fetch_op], feed_dict)
                loss_per_batch_sum += loss
                if run_G_reg:
                    tflib.run(G_reg_op, feed_dict)
                tflib.run([Gs_update_op], feed_dict)

            # Slow path with gradient accumulation.
            else:
                for _round in rounds:
                    # Train on the currently fetched data *before* overwriting it. Running
                    # both in one Session.run gave no ordering guarantee (data race).
                    loss, _ = tflib.run([loss_op, G_train_op], feed_dict)
                    tflib.run(data_fetch_op, feed_dict)
                    loss_per_batch_sum += loss / len(rounds)
                if run_G_reg:
                    # The optimizer only applies updates after `minibatch_multiplier`
                    # accumulation steps, so the reg op must also run once per round.
                    for _round in rounds:
                        tflib.run(G_reg_op, feed_dict)
                tflib.run(Gs_update_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start() + resume_time
            # Sum of per-minibatch mean losses, rescaled to a per-image average over the tick.
            tick_loss = loss_per_batch_sum * sched.minibatch_size / (tick_kimg * 1000)
            loss_per_batch_sum = 0

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d loss/px %-12.8f time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch_size),
                   autosummary('Progress/loss_per_px', tick_loss),
                   dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (cur_tick % image_snapshot_ticks == 0 or done):
                grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu)
                misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net, grid_size=grid_size)
            if network_snapshot_ticks is not None and (cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, None, Gs), pkl)
                metrics.run(pkl, run_dir=dnnlib.make_run_dir_path(), data_dir=dnnlib.convert_path(data_dir),
                            num_gpus=num_gpus, tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get().get_last_update_interval() - tick_time

    # Save final snapshot.
    misc.save_pkl((G, None, Gs), dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
def training_loop_vc2( G_args={}, # Options for generator network. D_args={}, # Options for discriminator network. I_args={}, # Options for infogan-head/vcgan-head network. I_info_args={}, # Options for infogan-head/vcgan-head network. G_opt_args={}, # Options for generator optimizer. D_opt_args={}, # Options for discriminator optimizer. # G2_opt_args={}, # Options for generator2 optimizer. G_loss_args={}, # Options for generator loss. D_loss_args={}, # Options for discriminator loss. dataset_args={}, # Options for dataset.load_dataset(). sched_args={}, # Options for train.TrainingSchedule. grid_args={}, # Options for train.setup_snapshot_image_grid(). metric_arg_list=[], # Options for MetricGroup. tf_config={}, # Options for tflib.init_tf(). use_info_gan=False, # Whether to use info-gan. use_vc_head=False, # Whether to use vc-head. use_vc2_info_gan=False, # Whether to use vc2 infogan. use_perdis=False, # Whether use perceptual distance network. data_dir=None, # Directory to load datasets from. G_smoothing_kimg=10.0, # Half-life of the running average of generator weights. minibatch_repeats=4, # Number of minibatches to run before adjusting training parameters. lazy_regularization=True, # Perform regularization as a separate training step? G_reg_interval=4, # How often the perform regularization for G? Ignored if lazy_regularization=False. D_reg_interval=16, # How often the perform regularization for D? Ignored if lazy_regularization=False. # G2_reg_interval=4, # How often the perform regularization for G? Ignored if lazy_regularization=False. reset_opt_for_new_lod=True, # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced? total_kimg=25000, # Total length of the training, measured in thousands of real images. mirror_augment=False, # Enable mirror augment? drange_net=[ -1, 1 ], # Dynamic range used when feeding image data to the networks. image_snapshot_ticks=50, # How often to save image snapshots? 
None = only save 'reals.png' and 'fakes-init.png'. network_snapshot_ticks=50, # How often to save network snapshots? None = only save 'networks-final.pkl'. save_tf_graph=False, # Include full TensorFlow computation graph in the tfevents file? save_weight_histograms=False, # Include weight histograms in the tfevents file? resume_pkl=None, # Network pickle to resume training from, None = train from scratch. resume_kimg=0.0, # Assumed training progress at the beginning. Affects reporting and training schedule. resume_time=0.0, # Assumed wallclock time at the beginning. Affects reporting. resume_with_new_nets=False, # Construct new networks according to G_args and D_args before resuming training? traversal_grid=False, # Used for disentangled representation learning. n_discrete=3, # Number of discrete latents in model. n_continuous=4, # Number of continuous latents in model. return_atts=False, # If return attention maps. return_I_atts=False, # If return I_attention maps of vpex. avg_mv_for_I=False, # If use average moving for I. opt_reset_ls=None, # Reset lr list for gradual latents. topk_dims_to_show=20, # Number of top disentant dimensions to show in a snapshot. cascade_alt_freq_k=1, # Frequency in k for cascade_dim altering. n_samples_per=10): # Number of samples for each line in traversal. # Initialize dnnlib and TensorFlow. tflib.init_tf(tf_config) num_gpus = dnnlib.submit_config.num_gpus # If include I include_I = use_info_gan or use_vc_head # Load training set. training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir), verbose=True, **dataset_args) grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid( training_set, **grid_args) grid_fakes = add_outline(grid_reals, width=1) misc.save_image_grid(grid_reals, dnnlib.make_run_dir_path('reals.png'), drange=training_set.dynamic_range, grid_size=grid_size) # Construct or load networks. 
with tf.device('/gpu:0'): if resume_pkl is None or resume_with_new_nets: print('Constructing networks...') print('G_args:', G_args) G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args) D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args) if include_I: I = tflib.Network('I', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **I_args) if avg_mv_for_I: Is = I.clone('Is') elif use_perdis: DM = misc.load_pkl( 'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/vgg16_zhang_perceptual.pkl' ) Gs = G.clone('Gs') if resume_pkl is not None: print('Loading networks from "%s"...' % resume_pkl) if include_I: if avg_mv_for_I: rG, rD, rI, rGs, rIs = misc.load_pkl(resume_pkl) else: rG, rD, rI, rGs = misc.load_pkl(resume_pkl) else: rG, rD, rGs = misc.load_pkl(resume_pkl) if resume_with_new_nets: G.copy_vars_from(rG) D.copy_vars_from(rD) if include_I: I.copy_vars_from(rI) if avg_mv_for_I: Is.copy_vars_from(rIs) Gs.copy_vars_from(rGs) else: G = rG D = rD if include_I: I = rI if avg_mv_for_I: Is = rIs Gs = rGs # Print layers and generate initial image snapshot. 
G.print_layers() D.print_layers() if include_I: I.print_layers() if use_perdis: DM.print_layers() # pdb.set_trace() sched = training_schedule(cur_nimg=total_kimg * 1000, training_set=training_set, **sched_args) if traversal_grid: if topk_dims_to_show > 0: topk_dims = np.arange(min(topk_dims_to_show, n_continuous)) else: topk_dims = np.arange(n_continuous) print('topk_dims_to_show:', topk_dims_to_show) grid_size, grid_latents, grid_labels = get_grid_latents( n_discrete, n_continuous, n_samples_per, G, grid_labels, topk_dims) else: grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:]) print('grid_size:', grid_size) print('grid_latents.shape:', grid_latents.shape) print('grid_labels.shape:', grid_labels.shape) # pdb.set_trace() if return_atts: grid_fakes, atts = get_return_v( Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu, randomize_noise=True, return_atts=True, resolution=training_set.shape[1]), 2) # atts: [b, n_latents, 1, res, res] atts = atts[:, topk_dims] save_atts(atts, filename=dnnlib.make_run_dir_path('fakes_atts_init.png'), grid_size=grid_size, drange=[0, 1], grid_fakes=grid_fakes, n_samples_per=n_samples_per) else: grid_fakes = get_return_v( Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu, randomize_noise=True), 1) grid_fakes = add_outline(grid_fakes, width=1) misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes_init.png'), drange=drange_net, grid_size=grid_size) if include_I and return_I_atts: if avg_mv_for_I: I_tmp = Is else: I_tmp = I _, atts = get_return_v( I_tmp.run(grid_fakes, grid_fakes, grid_latents, is_validation=True, minibatch_size=sched.minibatch_gpu, return_atts=True, resolution=training_set.shape[1]), 2) save_atts(atts, filename=dnnlib.make_run_dir_path('fakes_I_atts_init.png'), grid_size=grid_size, drange=[0, 1], grid_fakes=grid_fakes, n_samples_per=n_samples_per) # Setup training inputs. 
print('Building TensorFlow graph...') with tf.name_scope('Inputs'), tf.device('/cpu:0'): lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[]) lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[]) minibatch_size_in = tf.placeholder(tf.int32, name='minibatch_size_in', shape=[]) minibatch_gpu_in = tf.placeholder(tf.int32, name='minibatch_gpu_in', shape=[]) minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in * num_gpus) Gs_beta = 0.5**tf.div(tf.cast(minibatch_size_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0 cascade_dim = tf.placeholder(tf.int32, name='cascade_dim', shape=[]) # Setup optimizers. G_opt_args = dict(G_opt_args) D_opt_args = dict(D_opt_args) for args, reg_interval in [(G_opt_args, G_reg_interval), (D_opt_args, D_reg_interval)]: args['minibatch_multiplier'] = minibatch_multiplier args['learning_rate'] = lrate_in if lazy_regularization: mb_ratio = reg_interval / (reg_interval + 1) args['learning_rate'] *= mb_ratio if 'beta1' in args: args['beta1'] **= mb_ratio if 'beta2' in args: args['beta2'] **= mb_ratio G_opt = tflib.Optimizer(name='TrainG', **G_opt_args) D_opt = tflib.Optimizer(name='TrainD', **D_opt_args) G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args) D_reg_opt = tflib.Optimizer(name='RegD', share=D_opt, **D_opt_args) if use_vc2_info_gan: G2_opt = tflib.Optimizer(name='TrainG2', share=G_opt, **G_opt_args) # Build training graph for each GPU. data_fetch_ops = [] for gpu in range(num_gpus): with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu): # Create GPU-specific shadow copies of G and D. G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow') D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow') if include_I: I_gpu = I if gpu == 0 else I.clone(I.name + '_shadow') if use_perdis: DM_gpu = DM if gpu == 0 else DM.clone(DM.name + '_shadow') else: DM_gpu = None # Fetch training data via temporary variables. 
with tf.name_scope('DataFetch'): sched = training_schedule(cur_nimg=int(resume_kimg * 1000), training_set=training_set, **sched_args) reals_var = tf.Variable( name='reals', trainable=False, initial_value=tf.zeros([sched.minibatch_gpu] + training_set.shape)) labels_var = tf.Variable(name='labels', trainable=False, initial_value=tf.zeros([ sched.minibatch_gpu, training_set.label_size ])) reals_write, labels_write = training_set.get_minibatch_tf() reals_write, labels_write = process_reals( reals_write, labels_write, lod_in, mirror_augment, training_set.dynamic_range, drange_net) reals_write = tf.concat( [reals_write, reals_var[minibatch_gpu_in:]], axis=0) labels_write = tf.concat( [labels_write, labels_var[minibatch_gpu_in:]], axis=0) data_fetch_ops += [tf.assign(reals_var, reals_write)] data_fetch_ops += [tf.assign(labels_var, labels_write)] reals_read = reals_var[:minibatch_gpu_in] labels_read = labels_var[:minibatch_gpu_in] # Evaluate loss functions. lod_assign_ops = [] if 'lod' in G_gpu.vars: lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)] if 'lod' in D_gpu.vars: lod_assign_ops += [tf.assign(D_gpu.vars['lod'], lod_in)] with tf.control_dependencies(lod_assign_ops): with tf.name_scope('G_loss'): if include_I: G_loss, G_reg = dnnlib.util.call_func_by_name( G=G_gpu, D=D_gpu, I=I_gpu, DM=DM_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, cascade_dim=cascade_dim, **G_loss_args) else: G_loss, G_reg = dnnlib.util.call_func_by_name( G=G_gpu, D=D_gpu, DM=DM_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, **G_loss_args) with tf.name_scope('D_loss'): D_loss, D_reg = dnnlib.util.call_func_by_name( G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, reals=reals_read, labels=labels_read, **D_loss_args) if use_vc2_info_gan: with tf.name_scope('G2_loss'): G2_loss, _ = dnnlib.util.call_func_by_name( G=G_gpu, D=D_gpu, opt=G2_opt, training_set=training_set, 
minibatch_size=minibatch_gpu_in, is_G2_loss=True, **G_loss_args) # Register gradients. if not lazy_regularization: if G_reg is not None: G_loss += G_reg if D_reg is not None: D_loss += D_reg else: if G_reg is not None: G_reg_opt.register_gradients( tf.reduce_mean(G_reg * G_reg_interval), G_gpu.trainables) if D_reg is not None: D_reg_opt.register_gradients( tf.reduce_mean(D_reg * D_reg_interval), D_gpu.trainables) if include_I: GI_gpu_trainables = collections.OrderedDict( list(G_gpu.trainables.items()) + list(I_gpu.trainables.items())) G_opt.register_gradients(tf.reduce_mean(G_loss), GI_gpu_trainables) D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables) elif use_vc2_info_gan: GD_gpu_trainables = collections.OrderedDict( list(G_gpu.trainables.items()) + list(D_gpu.trainables.items())) G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables) G2_opt.register_gradients(tf.reduce_mean(G2_loss), GD_gpu_trainables) D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables) else: G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables) D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables) # Setup training ops. data_fetch_op = tf.group(*data_fetch_ops) G_train_op = G_opt.apply_updates() D_train_op = D_opt.apply_updates() G_reg_op = G_reg_opt.apply_updates(allow_no_op=True) D_reg_op = D_reg_opt.apply_updates(allow_no_op=True) Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta) if avg_mv_for_I: Is_update_op = Is.setup_as_moving_average_of(I, beta=Gs_beta) if use_vc2_info_gan: G2_train_op = G2_opt.apply_updates() # Finalize graph. 
with tf.device('/gpu:0'): try: peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse() except tf.errors.NotFoundError: peak_gpu_mem_op = tf.constant(0) tflib.init_uninitialized_vars() print('Initializing logs...') summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path()) if save_tf_graph: summary_log.add_graph(tf.get_default_graph()) if save_weight_histograms: G.setup_weight_histograms() D.setup_weight_histograms() if include_I: I.setup_weight_histograms() metrics = metric_base.MetricGroup(metric_arg_list) print('Training for %d kimg...\n' % total_kimg) dnnlib.RunContext.get().update('', cur_epoch=resume_kimg, max_epoch=total_kimg) maintenance_time = dnnlib.RunContext.get().get_last_update_interval() cur_nimg = int(resume_kimg * 1000) cur_tick = -1 tick_start_nimg = cur_nimg prev_lod = -1.0 running_mb_counter = 0 while cur_nimg < total_kimg * 1000: if dnnlib.RunContext.get().should_stop(): break # Choose training parameters and configure training ops. sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, **sched_args) assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0 training_set.configure(sched.minibatch_gpu, sched.lod) if reset_opt_for_new_lod: if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil( sched.lod) != np.ceil(prev_lod): G_opt.reset_optimizer_state() D_opt.reset_optimizer_state() # if opt_reset_ls is not None: # if cur_nimg in opt_reset_ls: # G_opt.reset_optimizer_state() # D_opt.reset_optimizer_state() prev_lod = sched.lod # Calculate which cascade_dim is to use. cur_nimg_k = cur_nimg // int(cascade_alt_freq_k * 1000) sched_cascade_dim = cur_nimg_k % n_continuous # Run training ops. 
feed_dict = { lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_size_in: sched.minibatch_size, minibatch_gpu_in: sched.minibatch_gpu, cascade_dim: sched_cascade_dim } for _repeat in range(minibatch_repeats): rounds = range(0, sched.minibatch_size, sched.minibatch_gpu * num_gpus) run_G_reg = (lazy_regularization and running_mb_counter % G_reg_interval == 0) run_D_reg = (lazy_regularization and running_mb_counter % D_reg_interval == 0) cur_nimg += sched.minibatch_size running_mb_counter += 1 # Fast path without gradient accumulation. if len(rounds) == 1: tflib.run([G_train_op, data_fetch_op], feed_dict) if run_G_reg: tflib.run(G_reg_op, feed_dict) if avg_mv_for_I: tflib.run([D_train_op, Gs_update_op, Is_update_op], feed_dict) else: tflib.run([D_train_op, Gs_update_op], feed_dict) if run_D_reg: tflib.run(D_reg_op, feed_dict) if use_vc2_info_gan: tflib.run(G2_train_op, feed_dict) # Slow path with gradient accumulation. else: for _round in rounds: tflib.run(G_train_op, feed_dict) if run_G_reg: for _round in rounds: tflib.run(G_reg_op, feed_dict) if avg_mv_for_I: tflib.run([Gs_update_op, Is_update_op], feed_dict) else: tflib.run(Gs_update_op, feed_dict) for _round in rounds: tflib.run(data_fetch_op, feed_dict) tflib.run(D_train_op, feed_dict) if run_D_reg: for _round in rounds: tflib.run(D_reg_op, feed_dict) if use_vc2_info_gan: for _round in rounds: tflib.run(G2_train_op, feed_dict) # Perform maintenance tasks once per tick. done = (cur_nimg >= total_kimg * 1000) if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done: cur_tick += 1 tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0 tick_start_nimg = cur_nimg tick_time = dnnlib.RunContext.get().get_time_since_last_update() total_time = dnnlib.RunContext.get().get_time_since_start( ) + resume_time # Report progress. 
print( 'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f' % (autosummary('Progress/tick', cur_tick), autosummary('Progress/kimg', cur_nimg / 1000.0), autosummary('Progress/lod', sched.lod), autosummary('Progress/minibatch', sched.minibatch_size), dnnlib.util.format_time( autosummary('Timing/total_sec', total_time)), autosummary('Timing/sec_per_tick', tick_time), autosummary('Timing/sec_per_kimg', tick_time / tick_kimg), autosummary('Timing/maintenance_sec', maintenance_time), autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30))) autosummary('Timing/total_hours', total_time / (60.0 * 60.0)) autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0)) # Save snapshots. if network_snapshot_ticks is not None and ( cur_tick % network_snapshot_ticks == 0 or done): pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' % (cur_nimg // 1000)) if include_I: if avg_mv_for_I: misc.save_pkl((G, D, I, Gs, Is), pkl) else: misc.save_pkl((G, D, I, Gs), pkl) else: misc.save_pkl((G, D, Gs), pkl) met_outs = metrics.run(pkl, run_dir=dnnlib.make_run_dir_path(), data_dir=dnnlib.convert_path(data_dir), num_gpus=num_gpus, tf_config=tf_config, include_I=include_I, avg_mv_for_I=avg_mv_for_I, Gs_kwargs=dict(is_validation=True, return_atts=False), mapping_nodup=True) if topk_dims_to_show > 0: if 'tpl_per_dim' in met_outs: avg_distance_per_dim = met_outs[ 'tpl_per_dim'] # shape: (n_continuous) topk_dims = np.argsort( avg_distance_per_dim )[::-1][:topk_dims_to_show] # shape: (20) else: topk_dims = np.arange( min(topk_dims_to_show, n_continuous)) else: topk_dims = np.arange(n_continuous) if image_snapshot_ticks is not None and ( cur_tick % image_snapshot_ticks == 0 or done): if traversal_grid: grid_size, grid_latents, grid_labels = get_grid_latents( n_discrete, n_continuous, n_samples_per, G, grid_labels, topk_dims) else: grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:]) 
if return_atts: grid_fakes, atts = get_return_v( Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu, randomize_noise=True, return_atts=True, resolution=training_set.shape[1]), 2) # atts: [b, n_latents, 1, res, res] atts = atts[:, topk_dims] save_atts(atts, filename=dnnlib.make_run_dir_path( 'fakes_atts%06d.png' % (cur_nimg // 1000)), grid_size=grid_size, drange=[0, 1], grid_fakes=grid_fakes, n_samples_per=n_samples_per) else: grid_fakes = get_return_v( Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu, randomize_noise=True), 1) grid_fakes = add_outline(grid_fakes, width=1) misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path( 'fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size) if include_I and return_I_atts: if avg_mv_for_I: I_tmp = Is else: I_tmp = I _, atts = get_return_v( I_tmp.run(grid_fakes, grid_fakes, grid_latents, is_validation=True, minibatch_size=sched.minibatch_gpu, return_atts=True, resolution=training_set.shape[1]), 2) atts = atts[:, topk_dims] save_atts(atts, filename=dnnlib.make_run_dir_path( 'fakes_I_atts%06d.png' % (cur_nimg // 1000)), grid_size=grid_size, drange=[0, 1], grid_fakes=grid_fakes, n_samples_per=n_samples_per) # Update summaries and RunContext. metrics.update_autosummaries() tflib.autosummary.save_summaries(summary_log, cur_nimg) dnnlib.RunContext.get().update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg) maintenance_time = dnnlib.RunContext.get( ).get_last_update_interval() - tick_time # Save final snapshot. if include_I: if avg_mv_for_I: misc.save_pkl((G, D, I, Gs, Is), dnnlib.make_run_dir_path('network-final.pkl')) else: misc.save_pkl((G, D, I, Gs), dnnlib.make_run_dir_path('network-final.pkl')) else: misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path('network-final.pkl')) # All done. summary_log.close() training_set.close()
def _write_values(path, values):
    """Write ``values`` to ``path`` (overwriting it), one ``str(value)`` per line."""
    with open(path, 'w') as f:
        f.writelines(str(value) + '\n' for value in values)


def embed(batch_size, resolution, imgs, network, iteration, result_dir, seed=6600):
    """Embed images into a pretrained generator's dlatent space at several temperatures.

    For every temperature in a fixed sweep, the synthesis network's ``*alpha``
    variables are set to ``scale_alpha(original_alpha, temperature)``. Each image
    in ``imgs`` is then reconstructed by optimising a single dlatent tensor with
    Adam against ``0.5 * pixel MSE + 0.5 * VGG16 feature loss``. Per-step loss
    curves, reconstructions and per-temperature summary statistics are written
    to ``result_dir``.

    Args:
        batch_size: Unused; kept for interface compatibility.
        resolution: Unused; the VGG16 input shape is hard-coded to (3, 128, 128).
        imgs: Iterable of single images. Each one gets a leading batch axis
            before being fed to the graph. Because VGG16 runs channels-first
            with input_shape (3, 128, 128), images are expected as CHW.
            NOTE(review): the value range is presumably [-1, 1], since
            reconstructions are saved with drange=[-1, 1]; TODO confirm.
        network: Network path/URL passed to ``pretrained_networks.load_networks``.
        iteration: Number of optimisation steps per image; must be >= 1.
        result_dir: Existing directory that receives all output files.
        seed: Seed of the RNG that fills the synthesis noise variables.

    Raises:
        ValueError: If ``iteration`` < 1 (no step would run, so there is nothing
            to report or save).
    """
    if iteration < 1:
        raise ValueError('iteration must be >= 1, got %r' % (iteration,))

    tf.reset_default_graph()
    print('Loading networks from "%s"...' % network)
    tflib.init_tf()
    _, _, G = pretrained_networks.load_networks(network)
    img_in = tf.placeholder(tf.float32)
    opt = tf.train.AdamOptimizer(learning_rate=0.01, beta1=0.9, beta2=0.999, epsilon=1e-8)

    synth_vars = G.components.synthesis.vars
    noise_vars = [var for name, var in synth_vars.items() if name.startswith('noise')]
    alpha_vars = [var for name, var in synth_vars.items() if name.endswith('alpha')]
    # Keep the trained alpha values so every temperature rescales the originals,
    # not the values left behind by the previous temperature.
    alpha_eval = [alpha.eval() for alpha in alpha_vars]

    G_kwargs = dnnlib.EasyDict()
    G_kwargs.randomize_noise = False
    G_syn = G.components.synthesis

    rnd = np.random.RandomState(seed)
    # Start the optimised dlatent at the generator's `dlatent_avg`, tiled 12 times along axis 1.
    # NOTE(review): 12 presumably matches the synthesis network's number of dlatent
    # inputs for this (128x128) model -- TODO confirm / derive from G_syn.
    dlatent_avg = [var for name, var in G.vars.items() if name.startswith('dlatent_avg')][0].eval()
    dlatent_avg = np.expand_dims(np.expand_dims(dlatent_avg, 0), 1)
    dlatent_avg = dlatent_avg.repeat(12, 1)
    dlatent = tf.get_variable('dlatent', dtype=tf.float32, initializer=tf.constant(dlatent_avg),
                              trainable=True)
    synth_img = G_syn.get_output_for(dlatent, is_training=False, **G_kwargs)

    with tf.variable_scope('mse_loss'):
        mse_loss = tf.reduce_mean(tf.square(img_in - synth_img))
    with tf.variable_scope('perceptual_loss'):
        # Target and reconstruction are run through VGG16 as one batch:
        # index 0 is the target image, index 1 the synthesized image.
        vgg_in = tf.concat([img_in, synth_img], 0)
        tf.keras.backend.set_image_data_format('channels_first')
        vgg = tf.keras.applications.VGG16(include_top=False, input_tensor=vgg_in,
                                          input_shape=(3, 128, 128),
                                          weights='/gdata2/fengrl/metrics/vgg.h5', pooling=None)
        h1 = vgg.get_layer('block1_conv1').output
        h2 = vgg.get_layer('block1_conv2').output
        h3 = vgg.get_layer('block3_conv2').output
        h4 = vgg.get_layer('block4_conv2').output
        pcep_loss = tf.reduce_mean(tf.square(h1[0] - h1[1])) + tf.reduce_mean(tf.square(h2[0] - h2[1])) + \
            tf.reduce_mean(tf.square(h3[0] - h3[1])) + tf.reduce_mean(tf.square(h4[0] - h4[1]))
    loss = 0.5 * mse_loss + 0.5 * pcep_loss
    with tf.control_dependencies([loss]):
        train_op = opt.minimize(loss, var_list=[dlatent])
    # Initializers used to restart Adam's state and the dlatent before every image.
    reset_opt = tf.variables_initializer(opt.variables())
    reset_dl = tf.variables_initializer([dlatent])

    tflib.init_uninitialized_vars()
    # Fill the noise inputs once; they are not optimised and never re-set afterwards.
    tflib.set_vars({var: rnd.randn(*var.shape.as_list()) for var in noise_vars})

    idx = 0  # Running image counter used in output file names (continues across temperatures).
    for temperature in [0.2, 0.5, 1.0, 1.5, 2.0, 10.0]:
        tflib.set_vars({
            alpha: scale_alpha(alpha_np, temperature)
            for alpha, alpha_np in zip(alpha_vars, alpha_eval)
        })
        # Final-step statistics of the images embedded at THIS temperature only.
        # (They used to be shared across temperatures, which mixed earlier
        # temperatures into every later summary.)
        metrics_l, metrics_p, metrics_m, metrics_d = [], [], [], []

        for img in imgs:
            img = np.expand_dims(img, 0)
            loss_list = []
            p_loss_list = []
            m_loss_list = []
            dl_list = []
            si_list = []
            tflib.run([reset_opt, reset_dl])
            for step in range(iteration):
                loss_, p_loss_, m_loss_, dl_, si_, _ = tflib.run(
                    [loss, pcep_loss, mse_loss, dlatent, synth_img, train_op], {img_in: img})
                loss_list.append(loss_)
                p_loss_list.append(p_loss_)
                m_loss_list.append(m_loss_)
                # Squared L2 distance of the current dlatent from the starting average.
                dl_loss_ = np.sum(np.square(dl_ - dlatent_avg))
                dl_list.append(dl_loss_)
                if step % 500 == 0:
                    si_list.append(si_)
                if step % 100 == 0:
                    print('Temperature %f, idx %d, Loss %f, mse %f, ppl %f, dl %f, step %d' %
                          (temperature, idx, loss_, m_loss_, p_loss_, dl_loss_, step))
            print('Temperature %f, idx %d, loss: %f, ppl: %f, mse: %f, d: %f' %
                  (temperature, idx, loss_list[-1], p_loss_list[-1], m_loss_list[-1], dl_list[-1]))
            metrics_l.append(loss_list[-1])
            metrics_p.append(p_loss_list[-1])
            metrics_m.append(m_loss_list[-1])
            metrics_d.append(dl_list[-1])

            # Reconstructions captured every 500 steps, stacked into one grid.
            misc.save_image_grid(np.concatenate(si_list, 0),
                                 os.path.join(result_dir, 'temp%fsi%d.png' % (temperature, idx)),
                                 drange=[-1, 1])
            # Reconstruction from the very last optimisation step (si_ of the final
            # iteration; si_list[-1] may be up to 499 steps older).
            misc.save_image_grid(si_,
                                 os.path.join(result_dir, 'temp%fsifinal%d.png' % (temperature, idx)),
                                 drange=[-1, 1])
            # Per-step curves: l = total loss, p = perceptual, m = MSE, d = dlatent distance.
            for tag, values in (('l', loss_list), ('p', p_loss_list),
                                ('m', m_loss_list), ('d', dl_list)):
                _write_values(
                    os.path.join(result_dir, 'temp%fmetric_%s%d.txt' % (temperature, tag, idx)),
                    values)
            idx += 1

        l_mean = np.mean(metrics_l)
        p_mean = np.mean(metrics_p)
        m_mean = np.mean(metrics_m)
        d_mean = np.mean(metrics_d)
        # One row per image: total loss, MSE, perceptual loss, dlatent distance.
        with open(os.path.join(result_dir, 'Temp%fmetric_lmpd.txt' % temperature), 'w') as f:
            for row in zip(metrics_l, metrics_m, metrics_p, metrics_d):
                f.write(' '.join(str(value) for value in row) + '\n')
        print('Overall metrics: temp %f, loss_mean %f, ppl_mean %f, mse_mean %f, d_mean %f' %
              (temperature, l_mean, p_mean, m_mean, d_mean))
        with open(os.path.join(result_dir, 'mean_metrics'), 'a') as f:
            f.write('Temperature %f\n' % temperature)
            f.write('loss %f\n' % l_mean)
            f.write('mse %f\n' % m_mean)
            f.write('ppl %f\n' % p_mean)
            f.write('dl %f\n' % d_mean)
def training_loop(
        submit_config,
        HP_args={},                     # Options for the Hessian Penalty.
        G_args={},                      # Options for generator network.
        D_args={},                      # Options for discriminator network.
        G_opt_args={},                  # Options for generator optimizer.
        D_opt_args={},                  # Options for discriminator optimizer.
        G_loss_args={},                 # Options for generator loss.
        D_loss_args={},                 # Options for discriminator loss.
        dataset_args={},                # Options for dataset.load_dataset().
        sched_args={},                  # Options for train.TrainingSchedule.
        grid_args={},                   # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],             # Options for MetricGroup.
        tf_config={},                   # Options for tflib.init_tf().
        G_smoothing_kimg=10.0,          # Half-life of the running average of generator weights.
        D_repeats=1,                    # How many times the discriminator is trained per G iteration.
        minibatch_repeats=4,            # Number of minibatches to run before adjusting training parameters.
        reset_opt_for_new_lod=True,     # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg=15000,               # Total length of the training, measured in thousands of real images.
        mirror_augment=False,           # Enable mirror augment?
        drange_net=[-1, 1],             # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=1,         # How often to export image snapshots?
        interp_snapshot_ticks=20,       # How often to generate interpolation visualizations in TensorBoard?
        network_snapshot_ticks=5,       # How often to export network snapshots?
        network_metric_ticks=5,         # How often to evaluate network snapshots on specified metrics?
        save_tf_graph=False,            # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,   # Include weight histograms in the tfevents file?
        resume_run_id=None,             # Run ID or network pkl to resume training from, None = start from scratch.
        resume_snapshot=None,           # Snapshot index to resume training from, None = autodetect.
        resume_kimg=0.0,                # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0):               # Assumed wallclock time at the beginning. Affects reporting.
    """Train G/D with an optional Hessian Penalty and InfoGAN mutual-information term.

    Builds one G and one D loss per GPU. The G loss receives the Hessian Penalty
    weight through the ``hessian_weight`` placeholder. The InfoGAN term
    ``mutual_info`` is added to both losses, scaled by ``G_args.infogan_lambda``
    and ``D_args.infogan_lambda``. For each minibatch repeat, the loop then runs
    ``D_repeats`` D updates, each also updating the running-average generator
    ``Gs``, followed by one G update. Once per tick it reports progress and
    writes Gs/G sample grids, interpolation GIFs, network snapshots and metric
    results under ``submit_config.run_dir``.

    The option objects are read with attribute access, so callers must pass
    ``dnnlib.EasyDict``-like values; the ``{}`` defaults are only placeholders.
    Required attributes:

    - ``HP_args``: ``hp_lambda``, ``hp_start_nimg``, ``warmup_nimg``.
    - ``G_args``: ``infogan_lambda``.
    - ``D_args``: ``infogan_lambda`` and ``infogan_nz``.
    - ``dataset_args``: must allow setting ``shuffle_mb`` as an attribute.

    The meaning of each remaining argument is given by the comment in the signature.
    """
    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)
    num_gpus = submit_config.num_gpus
    run_dir = submit_config.run_dir

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args)
    # Metrics get their own copy of the dataset options with minibatch shuffling disabled.
    metrics_dataset_args = deepcopy(dataset_args)
    metrics_dataset_args.shuffle_mb = 0
    print('Saving interp videos every %s ticks' % interp_snapshot_ticks)
    print('Saving network snapshot every %s ticks' % network_snapshot_ticks)

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...')
            G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1],
                              label_size=training_set.label_size, **G_args)
            D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1],
                              label_size=training_set.label_size, **D_args)
            Gs = G.clone('Gs')

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // num_gpus
        Gs_beta = 0.5 ** tf.div(tf.cast(minibatch_in, tf.float32),
                                G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0
        # The Hessian Penalty weight can change during training, so it is fed through a placeholder.
        hessian_weight = tf.placeholder(tf.float32, name='hessian_weight', shape=[])

    G_opt = tflib.Optimizer(name='TrainG', learning_rate=lrate_in, **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', learning_rate=lrate_in, **D_opt_args)
    # Hessian Penalty / InfoGAN loss tensors fetched with every G step so that their
    # values can be registered in TensorBoard.
    reg_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            lod_assign_ops = [tf.assign(G_gpu.find_var('lod'), lod_in),
                              tf.assign(D_gpu.find_var('lod'), lod_in)]
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(reals, lod_in, mirror_augment, training_set.dynamic_range, drange_net)
            with tf.name_scope('G_loss'), tf.control_dependencies(lod_assign_ops):
                G_loss, G_hessian_penalty = dnnlib.util.call_func_by_name(
                    G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set,
                    minibatch_size=minibatch_split, hp_lambda=hessian_weight, HP_args=HP_args,
                    gpu_ix=gpu, lod_in=lod_in, max_lod=training_set.resolution_log2,
                    **G_loss_args)
                if HP_args.hp_lambda > 0:
                    reg_ops += [G_hessian_penalty]
            with tf.name_scope('D_loss'), tf.control_dependencies(lod_assign_ops):
                D_loss, mutual_info = dnnlib.util.call_func_by_name(
                    G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set,
                    minibatch_size=minibatch_split, reals=reals, labels=labels, gpu_ix=gpu,
                    infogan_nz=D_args.infogan_nz, **D_loss_args)
                if G_args.infogan_lambda > 0 or D_args.infogan_lambda > 0:
                    reg_ops += [mutual_info]
            # mutual_info is always added, but only contributes when the corresponding
            # infogan_lambda is non-zero (the Hessian Penalty experiments set it to zero).
            G_opt.register_gradients(G_loss + G_args.infogan_lambda * mutual_info, G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss) + D_args.infogan_lambda * mutual_info,
                                     D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(
        G, training_set, **grid_args)
    # Snapshot inference before training uses the end-of-training schedule's minibatch size.
    sched = training_schedule(cur_nimg=total_kimg * 1000, training_set=training_set,
                              num_gpus=num_gpus, **sched_args)
    minibatch_gpu = sched.minibatch // num_gpus
    grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=minibatch_gpu)

    print('Setting up snapshot interpolation...')
    nz = G.input_shapes[0][1]
    interp_steps = 24  # Number of frames in the visualization
    interp_batch_size = 8  # Number of gifs per row of visualization
    interp_z = vis_tools.sample_interp_zs(nz, interp_batch_size, interp_steps)
    interp_labels = np.zeros([interp_steps * interp_batch_size * nz, training_set.label_size],
                             dtype=training_set.label_dtype)

    print('Setting up run dir...')
    reals_path = os.path.join(run_dir, 'reals.png')
    init_fakes_path = os.path.join(run_dir, 'fakes%06d.png' % resume_kimg)
    misc.save_image_grid(grid_reals, reals_path, drange=training_set.dynamic_range, grid_size=grid_size)
    misc.save_image_grid(grid_fakes, init_fakes_path, drange=drange_net, grid_size=grid_size)
    summary_log = tf.summary.FileWriter(run_dir)
    summary_log.add_summary(build_image_summary(reals_path, 'samples/real'), 0)
    summary_log.add_summary(build_image_summary(init_fakes_path, 'samples/Gs'), resume_kimg)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    if interp_snapshot_ticks != -1 and interp_snapshot_ticks < 9999:
        print('Generating initial interpolations...')
        vis_tools.make_and_save_interpolation_gifs(
            Gs, interp_z, interp_labels, minibatch_size=minibatch_gpu,
            interp_steps=interp_steps, interp_batch_size=interp_batch_size,
            cur_kimg=resume_kimg, summary_log=summary_log)

    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    num_G_grad_steps = 0
    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set,
                                  num_gpus=num_gpus, **sched_args)
        minibatch_gpu = sched.minibatch // num_gpus
        training_set.configure(minibatch_gpu, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops: D_repeats D steps (each also updating Gs), then one G step.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run([D_train_op, Gs_update_op],
                          {lod_in: sched.lod, lrate_in: sched.D_lrate, minibatch_in: sched.minibatch})
                cur_nimg += sched.minibatch
            cur_hessian_weight = get_current_hessian_penalty_loss_weight(
                HP_args.hp_lambda, HP_args.hp_start_nimg, cur_nimg, HP_args.warmup_nimg)
            tflib.run([G_train_op] + reg_ops,
                      {lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_in: sched.minibatch,
                       hessian_weight: cur_hessian_weight})
            num_G_grad_steps += 1

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d hessian_weight %s time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch),
                   autosummary('Progress/hessian_weight', cur_hessian_weight),
                   dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))
            autosummary('Progress/G_grad_steps', num_G_grad_steps)

            # Save fake image samples for both the averaged (Gs) and instantaneous (G) generator.
            if cur_tick % image_snapshot_ticks == 0 or done:
                snapshot_kimg = cur_nimg // 1000  # (was `iter`, which shadowed the builtin)
                grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True,
                                    minibatch_size=minibatch_gpu)
                grid_fakes_inst = G.run(grid_latents, grid_labels, is_validation=True,
                                        minibatch_size=minibatch_gpu)
                fake_path = os.path.join(run_dir, 'fakes%06d.png' % snapshot_kimg)
                ifake_path = os.path.join(run_dir, 'ifakes%06d.png' % snapshot_kimg)
                misc.save_image_grid(grid_fakes, fake_path, drange=drange_net, grid_size=grid_size)
                misc.save_image_grid(grid_fakes_inst, ifake_path, drange=drange_net, grid_size=grid_size)
                summary_log.add_summary(build_image_summary(fake_path, 'samples/Gs'), snapshot_kimg)
                summary_log.add_summary(build_image_summary(ifake_path, 'samples/G'), snapshot_kimg)

            # Generate/save interpolation visualizations.
            if interp_snapshot_ticks != -1 and cur_tick % interp_snapshot_ticks == 0:
                vis_tools.make_and_save_interpolation_gifs(
                    Gs, interp_z, interp_labels, minibatch_size=minibatch_gpu,
                    interp_steps=interp_steps, interp_batch_size=interp_batch_size,
                    cur_kimg=cur_nimg // 1000, summary_log=summary_log)

            # Save a snapshot and (on metric ticks) evaluate it.
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(run_dir, 'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                if cur_tick % network_metric_ticks == 0 or done or cur_tick == 1:
                    metrics.run(pkl, dataset_args=metrics_dataset_args, mirror_augment=mirror_augment,
                                num_gpus=num_gpus, tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    misc.save_pkl((G, D, Gs), os.path.join(run_dir, 'network-snapshot-%06d.pkl' % total_kimg))
    summary_log.close()
    ctx.close()
def training_loop(
    submit_config,
    #---------------------------------------------------------------
    # Modified by Deng et al.
    noise_dim=32,
    weight_args={},
    train_stage_args={},
    #---------------------------------------------------------------
    G_args={},  # Options for generator network.
    D_args={},  # Options for discriminator network.
    G_opt_args={},  # Options for generator optimizer.
    D_opt_args={},  # Options for discriminator optimizer.
    G_loss_args={},  # Options for generator loss.
    D_loss_args={},  # Options for discriminator loss.
    dataset_args={},  # Options for dataset.load_dataset().
    sched_args={},  # Options for train.TrainingSchedule.
    grid_args={},  # Options for train.setup_snapshot_image_grid().
    metric_arg_list=[],  # Options for MetricGroup.
    tf_config={},  # Options for tflib.init_tf().
    G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
    D_repeats=1,  # How many times the discriminator is trained per G iteration.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=15000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=True,  # Enable mirror augment?
    drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=1,  # How often to export image snapshots?
    network_snapshot_ticks=10,  # How often to export network snapshots?
    save_tf_graph=True,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_run_id=87,  # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot=2364,  # Snapshot index to resume training from, None = autodetect.
    resume_kimg=2364,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
    **_kwargs):  # Extra keyword arguments are accepted and ignored.
    """Progressive-GAN training loop, modified by Deng et al. to render faces via `Face3D`.

    The generator input is a coefficient vector (output of `z_to_lambda_mapping`)
    concatenated with `noise_dim` Gaussian noise values; see the snapshot-grid
    setup below. The G and D losses are built by a single
    `dnnlib.util.call_func_by_name(...)` call that receives the renderer, both
    networks, both optimizers and `**train_stage_args`.

    NOTE(review): the defaults `resume_run_id=87`, `resume_snapshot=2364` and
    `resume_kimg=2364` mean a call without these arguments RESUMES from a
    specific earlier run instead of training from scratch; pass
    `resume_run_id=None` to start fresh.

    Args:
        submit_config: dnnlib submit config; `num_gpus` and `run_dir` are read.
        noise_dim: Number of Gaussian noise values appended to the coefficients
            in the generator input (G is built with latent size 254 + noise_dim).
        weight_args: Forwarded to the loss-building function.
        train_stage_args: Forwarded (unpacked) to the loss-building function and
            to `restore_weights_and_initialize`; presumably carries `func_name`
            for `call_func_by_name` -- TODO confirm.
        (remaining arguments are documented inline in the signature)
    """
    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args)
    # Create 3d face reconstruction block.
    FaceRender = Face3D()

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...')
            #---------------------------------------------------------------
            # Modified by Deng et al.
            # 254 is presumably the size of the coefficient vector produced by
            # z_to_lambda_mapping -- TODO confirm.
            G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1],
                              label_size=training_set.label_size, latent_size=254 + noise_dim, **G_args)
            #---------------------------------------------------------------
            D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1],
                              label_size=training_set.label_size, **D_args)
            Gs = G.clone('Gs')
    G.print_layers()
    D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        # Fed from sched.resolution each step and passed on to the loss function.
        resolution = tf.placeholder(tf.float32, name='resolution', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0

    G_opt = tflib.Optimizer(name='TrainG', learning_rate=lrate_in, **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', learning_rate=lrate_in, **D_opt_args)
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % (gpu)):
            # GPU 0 uses the real networks; other GPUs use shadow copies.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            lod_assign_ops = [tf.assign(G_gpu.find_var('lod'), lod_in), tf.assign(D_gpu.find_var('lod'), lod_in)]
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(reals, lod_in, mirror_augment, training_set.dynamic_range, drange_net)
            #---------------------------------------------------------------
            # Modified by Deng et al.: both losses come from one call.
            G_loss, D_loss = dnnlib.util.call_func_by_name(
                FaceRender=FaceRender, noise_dim=noise_dim, weight_args=weight_args,
                G_gpu=G_gpu, D_gpu=D_gpu, G_opt=G_opt, D_opt=D_opt, training_set=training_set,
                G_loss_args=G_loss_args, D_loss_args=D_loss_args,
                lod_assign_ops=lod_assign_ops, reals=reals, labels=labels,
                minibatch_split=minibatch_split, resolution=resolution,
                drange_net=drange_net, lod_in=lod_in, **train_stage_args)
            #---------------------------------------------------------------
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)

    #---------------------------------------------------------------
    # Modified by Deng et al.
    restore_weights_and_initialize(train_stage_args)

    print('Setting up snapshot image grid...')
    sched = training_schedule(cur_nimg=total_kimg * 1000, training_set=training_set,
                              num_gpus=submit_config.num_gpus, **sched_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(G, training_set, **grid_args)
    # Random latent for z_to_lambda_mapping; the 128/32/16/3 split presumably
    # mirrors the mapping's sub-codes -- TODO confirm.
    grid_latents = tf.random_normal([np.prod(grid_size), 128 + 32 + 16 + 3])
    grid_INPUTcoeff = z_to_lambda_mapping(grid_latents)
    # Append 3 zero entries before rendering (presumably a zero translation -- TODO confirm).
    grid_INPUTcoeff_w_t = tf.concat([grid_INPUTcoeff, tf.zeros([np.prod(grid_size), 3])], axis=1)
    with tf.name_scope('FaceRender'):
        grid_render_img, _, _, _ = FaceRender.Reconstruction_Block(
            grid_INPUTcoeff_w_t, 256, np.prod(grid_size), progressive=False)
    grid_render_img = tf.transpose(grid_render_img, perm=[0, 3, 1, 2])
    grid_render_img = process_reals(grid_render_img, lod_in, False, training_set.dynamic_range, drange_net)
    # Evaluate once: the same coefficients/renders are reused for every snapshot.
    grid_INPUTcoeff_, grid_renders = tflib.run([grid_INPUTcoeff, grid_render_img], {lod_in: sched.lod})
    # Noise width must match the `noise_dim` part of G's latent (254 + noise_dim).
    grid_noise = np.random.randn(np.prod(grid_size), noise_dim)
    grid_INPUTcoeff_w_noise = np.concatenate([grid_INPUTcoeff_, grid_noise], axis=1)
    grid_fakes = Gs.run(grid_INPUTcoeff_w_noise, grid_labels, is_validation=True,
                        minibatch_size=sched.minibatch // submit_config.num_gpus)
    # Place each fake next to its render along the width axis.
    grid_fakes = np.concatenate([grid_fakes, grid_renders], axis=3)
    misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg),
                         drange=drange_net, grid_size=grid_size)
    misc.save_image_grid(grid_reals, os.path.join(submit_config.run_dir, 'reals.png'),
                         drange=training_set.dynamic_range, grid_size=grid_size)
    #---------------------------------------------------------------

    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set,
                                  num_gpus=submit_config.num_gpus, **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus, sched.lod)
        if reset_opt_for_new_lod:
            # Reset whenever the integer part of lod changes (new layer faded in/out).
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run([D_train_op, Gs_update_op], {lod_in: sched.lod, lrate_in: sched.D_lrate,
                                                       minibatch_in: sched.minibatch, resolution: sched.resolution})
                cur_nimg += sched.minibatch
            tflib.run([G_train_op], {lod_in: sched.lod, lrate_in: sched.G_lrate,
                                     minibatch_in: sched.minibatch, resolution: sched.resolution})

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print('tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f' % (
                autosummary('Progress/tick', cur_tick),
                autosummary('Progress/kimg', cur_nimg / 1000.0),
                autosummary('Progress/lod', sched.lod),
                autosummary('Progress/minibatch', sched.minibatch),
                dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                autosummary('Timing/sec_per_tick', tick_time),
                autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                autosummary('Timing/maintenance_sec', maintenance_time),
                autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                #---------------------------------------------------------------
                # Modified by Deng et al.
                grid_fakes = Gs.run(grid_INPUTcoeff_w_noise, grid_labels, is_validation=True,
                                    minibatch_size=sched.minibatch // submit_config.num_gpus)
                grid_fakes = np.concatenate([grid_fakes, grid_renders], axis=3)
                misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net, grid_size=grid_size)
                #---------------------------------------------------------------
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(submit_config.run_dir, 'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl, run_dir=submit_config.run_dir, num_gpus=submit_config.num_gpus, tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    misc.save_pkl((G, D, Gs), os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()

    ctx.close()

#----------------------------------------------------------------------------
peak_gpu_mem_op = tf.constant(0) print('Setting up snapshot image grid...') grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(G, training_set, **grid_args) sched = training_schedule(cur_nimg=total_kimg*1000, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args) grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch//submit_config.num_gpus) print('Setting up run dir...') misc.save_image_grid(grid_reals, os.path.join(submit_config.run_dir, 'reals.png'), drange=training_set.dynamic_range, grid_size=grid_size) misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg), drange=drange_net, grid_size=grid_size) summary_log = tf.summary.FileWriter(submit_config.run_dir) if save_tf_graph: summary_log.add_graph(tf.get_default_graph()) if save_weight_histograms: G.setup_weight_histograms(); D.setup_weight_histograms() metrics = metric_base.MetricGroup(metric_arg_list) print('Training...\n') ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg) maintenance_time = ctx.get_last_update_interval() cur_nimg = int(resume_kimg * 1000) cur_tick = 0 tick_start_nimg = cur_nimg prev_lod = -1.0 while cur_nimg < total_kimg * 1000: if ctx.should_stop(): break # Choose training parameters and configure training ops. sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args) training_set.configure(sched.minibatch // submit_config.num_gpus, sched.lod) if reset_opt_for_new_lod:
def training_loop(
    submit_config,
    G_args={},  # Options for generator network.
    D_args={},  # Options for discriminator network.
    G_opt_args={},  # Options for generator optimizer. NOTE: currently unused (Adam is hard-coded below).
    D_opt_args={},  # Options for discriminator optimizer. NOTE: currently unused (Adam is hard-coded below).
    G_loss_args={},  # Options for generator loss.
    D_loss_args={},  # Options for discriminator loss.
    dataset_args={},  # Options for dataset.load_dataset().
    sched_args={},  # Options for train.TrainingSchedule.
    grid_args={},  # Options for train.setup_snapshot_image_grid().
    metric_arg_list=[],  # Options for MetricGroup.
    tf_config={},  # Options for tflib.init_tf().
    G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
    D_repeats=1,  # How many times the discriminator is trained per G iteration.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=15000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=10,  # How often to export image snapshots?
    network_snapshot_ticks=10,  # How often to export network snapshots?
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_run_id=None,  # NOTE: currently ignored, networks are always built from scratch.
    resume_snapshot=None,  # NOTE: currently ignored.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0):  # Assumed wallclock time at the beginning. Affects reporting.
    """Horovod data-parallel variant of the progressive-GAN training loop.

    Each Horovod process builds one copy of G/D/Gs and wraps two Adam optimizers in
    `hvd.DistributedOptimizer`. Every process initializes its variables once; rank 0's
    values are then broadcast so all processes start from identical weights. Only rank 0
    (the "chief") writes images, pickles, summaries and runs metrics, so processes never
    race on the same files in `submit_config.run_dir`.

    NOTE(review): G_opt/D_opt are hvd-wrapped `tf.train` optimizers, not
    `tflib.Optimizer`. Confirm the configured loss functions do not rely on
    tflib.Optimizer-specific methods.
    """
    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    # ajay - move init to after graph creation?
    tflib.init_tf(tf_config)
    is_chief = (hvd.rank() == 0)

    # Load training set (each Horovod process is handed its own index / host count).
    print('ajay - config data dir', config.data_dir)
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True,
                                        num_hosts=hvd.size(), index=hvd.rank(), **dataset_args)

    # Construct networks. Resuming from resume_run_id/resume_snapshot is not implemented here.
    print('Constructing networks...')
    G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1],
                      label_size=training_set.label_size, **G_args)
    D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1],
                      label_size=training_set.label_size, **D_args)
    Gs = G.clone('Gs')
    G.print_layers()
    D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0

    # Gradients are averaged across processes by the Horovod wrapper.
    G_opt = hvd.DistributedOptimizer(tf.train.AdamOptimizer(learning_rate=lrate_in, beta1=0.0, beta2=0.99, epsilon=1e-8))
    D_opt = hvd.DistributedOptimizer(tf.train.AdamOptimizer(learning_rate=lrate_in, beta1=0.0, beta2=0.99, epsilon=1e-8))
    G_gpu = G
    D_gpu = D
    lod_assign_ops = [tf.assign(G_gpu.find_var('lod'), lod_in), tf.assign(D_gpu.find_var('lod'), lod_in)]
    # ajay - check if unique minibatch is guaranteed i.e sharding is done right!
    reals, labels = training_set.get_minibatch_tf()
    reals = process_reals(reals, lod_in, mirror_augment, training_set.dynamic_range, drange_net)
    with tf.name_scope('G_loss'), tf.control_dependencies(lod_assign_ops):
        G_loss = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set,
                                               minibatch_size=minibatch_split, **G_loss_args)
    with tf.name_scope('D_loss'), tf.control_dependencies(lod_assign_ops):
        D_loss = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set,
                                               minibatch_size=minibatch_split, reals=reals, labels=labels,
                                               **D_loss_args)
    G_grads = G_opt.compute_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
    D_grads = D_opt.compute_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_gradients(G_grads)
    D_train_op = D_opt.apply_gradients(D_grads)

    # Horovod: initialize ONLY the variables that are not yet initialized (optimizer
    # slots etc.). Re-running every initializer would re-randomize Gs independently of
    # G (undoing G.clone('Gs')) and reset any variable already set after creation.
    # Then broadcast rank 0's values so every process starts identically.
    tflib.init_uninitialized_vars()
    bcast_op = hvd.broadcast_global_variables(0)
    tflib.run([bcast_op])

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(G, training_set, **grid_args)
    # todo: ajay - note num_gpus need to change to hvd size when going multi-node
    sched = training_schedule(cur_nimg=total_kimg * 1000, training_set=training_set,
                              num_gpus=submit_config.num_gpus, **sched_args)
    grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True,
                        minibatch_size=sched.minibatch // submit_config.num_gpus)

    # Only the chief writes into run_dir; summary_log therefore exists only on rank 0.
    if is_chief:
        print('Setting up run dir...')
        misc.save_image_grid(grid_reals, os.path.join(submit_config.run_dir, 'reals.png'),
                             drange=training_set.dynamic_range, grid_size=grid_size)
        misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg),
                             drange=drange_net, grid_size=grid_size)
        summary_log = tf.summary.FileWriter(submit_config.run_dir)
        if save_tf_graph:
            summary_log.add_graph(tf.get_default_graph())
        if save_weight_histograms:
            G.setup_weight_histograms()
            D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    while cur_nimg < (total_kimg * 1000):
        if ctx.should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set,
                                  num_gpus=submit_config.num_gpus, **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod):
                # tf.train optimizers have no reset method: re-run the initializers of
                # their variables (Adam moments, beta powers) instead.
                tflib.assert_tf_initialized()
                tflib.run([var.initializer for var in G_opt.variables()])
                tflib.run([var.initializer for var in D_opt.variables()])
        prev_lod = sched.lod

        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run([D_train_op, Gs_update_op],
                          {lod_in: sched.lod, lrate_in: sched.D_lrate, minibatch_in: sched.minibatch})
                cur_nimg += sched.minibatch
            tflib.run([G_train_op],
                      {lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_in: sched.minibatch})

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= (total_kimg * 1000))
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            if is_chief:
                # Report progress.
                print('tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f ' % (
                    autosummary('Progress/tick', cur_tick),
                    autosummary('Progress/kimg', cur_nimg / 1000.0),
                    autosummary('Progress/lod', sched.lod),
                    autosummary('Progress/minibatch', sched.minibatch),
                    dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                    autosummary('Timing/sec_per_tick', tick_time),
                    autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                    autosummary('Timing/maintenance_sec', maintenance_time)))
                autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
                autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

                # Save snapshots (chief only: avoids concurrent writes to the same pkl/png).
                if cur_tick % image_snapshot_ticks == 0 or done:
                    grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True,
                                        minibatch_size=sched.minibatch // submit_config.num_gpus)
                    misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % (cur_nimg // 1000)),
                                         drange=drange_net, grid_size=grid_size)
                if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                    pkl = os.path.join(submit_config.run_dir, 'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                    misc.save_pkl((G, D, Gs), pkl)
                    # ajay - note modifying to 1 for eval
                    metrics.run(pkl, run_dir=submit_config.run_dir, num_gpus=1, tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            if is_chief:
                tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    if is_chief:
        misc.save_pkl((G, D, Gs), os.path.join(submit_config.run_dir, 'network-final.pkl'))
        summary_log.close()
        ctx.close()
def training_loop(
    G_args                  = {},       # Options for generator network.
    D_args                  = {},       # Options for discriminator network.
    G_opt_args              = {},       # Options for generator optimizer.
    D_opt_args              = {},       # Options for discriminator optimizer.
    G_loss_args             = {},       # Options for generator loss.
    D_loss_args             = {},       # Options for discriminator loss.
    dataset_args            = {},       # Options for dataset.load_dataset().
    sched_args              = {},       # Options for train.TrainingSchedule.
    grid_args               = {},       # Options for train.setup_snapshot_image_grid().
    metric_arg_list         = [],       # Options for MetricGroup.
    tf_config               = {},       # Options for tflib.init_tf().
    data_dir                = None,     # Directory to load datasets from.
    G_smoothing_kimg        = 10.0,     # Half-life of the running average of generator weights.
    minibatch_repeats       = 4,        # Number of minibatches to run before adjusting training parameters.
    lazy_regularization     = True,     # Perform regularization as a separate training step?
    G_reg_interval          = 4,        # How often the perform regularization for G? Ignored if lazy_regularization=False.
    D_reg_interval          = 16,       # How often the perform regularization for D? Ignored if lazy_regularization=False.
    reset_opt_for_new_lod   = True,     # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg              = 25000,    # Total length of the training, measured in thousands of real images.
    mirror_augment          = False,    # Enable mirror augment?
    drange_net              = [-1,1],   # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks    = 50,       # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'.
    network_snapshot_ticks  = 50,       # How often to save network snapshots? None = only save 'networks-final.pkl'.
    save_tf_graph           = False,    # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms  = False,    # Include weight histograms in the tfevents file?
    resume_pkl              = None,     # Network pickle to resume training from, None = train from scratch.
    resume_kimg             = 0.0,      # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time             = 0.0,      # Assumed wallclock time at the beginning. Affects reporting.
    resume_with_new_nets    = False):   # Construct new networks according to G_args and D_args before resuming training?
    """Multi-GPU GAN training loop with lazy regularization and gradient accumulation.

    The graph is built once per GPU (shadow copies of G/D on GPUs > 0). Each
    training step feeds `sched.G_lrate` to the G train/reg ops and `sched.D_lrate`
    to the D train/reg ops. When `sched.minibatch_size` exceeds
    `minibatch_gpu * num_gpus`, gradients are accumulated over several rounds
    (slow path). Snapshots, metrics and summaries are written once per tick into
    the dnnlib run directory.
    """
    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir), verbose=True, **dataset_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(training_set, **grid_args)
    misc.save_image_grid(grid_reals, dnnlib.make_run_dir_path('reals.png'), drange=training_set.dynamic_range, grid_size=grid_size)

    # Construct or load networks.
    with tfex.device('/gpu:0'):
        if resume_pkl is None or resume_with_new_nets:
            print('Constructing networks...')
            G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args)
            D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args)
            Gs = G.clone('Gs')
        if resume_pkl is not None:
            print('Loading networks from "%s"...' % resume_pkl)
            rG, rD, rGs = misc.load_pkl(resume_pkl)
            if resume_with_new_nets:
                # Keep the freshly built architectures, copy matching weights in.
                G.copy_vars_from(rG)
                D.copy_vars_from(rD)
                Gs.copy_vars_from(rGs)
            else:
                G = rG
                D = rD
                Gs = rGs

    # Print layers and generate initial image snapshot.
    G.print_layers()
    D.print_layers()
    sched = training_schedule(cur_nimg=total_kimg*1000, training_set=training_set, **sched_args)
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu)
    misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes_init.png'), drange=drange_net, grid_size=grid_size)

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tfex.device('/cpu:0'):
        lod_in               = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in             = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in    = tf.placeholder(tf.int32, name='minibatch_size_in', shape=[])
        minibatch_gpu_in     = tf.placeholder(tf.int32, name='minibatch_gpu_in', shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in * num_gpus)
        Gs_beta              = 0.5 ** tf.div(tf.cast(minibatch_size_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0

    # Setup optimizers (copies, so the caller's dicts are not mutated).
    G_opt_args = dict(G_opt_args)
    D_opt_args = dict(D_opt_args)
    for args, reg_interval in [(G_opt_args, G_reg_interval), (D_opt_args, D_reg_interval)]:
        args['minibatch_multiplier'] = minibatch_multiplier
        args['learning_rate'] = lrate_in
        if lazy_regularization:
            # Compensate for the extra regularization steps taken every reg_interval minibatches.
            mb_ratio = reg_interval / (reg_interval + 1)
            args['learning_rate'] *= mb_ratio
            if 'beta1' in args: args['beta1'] **= mb_ratio
            if 'beta2' in args: args['beta2'] **= mb_ratio
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', **D_opt_args)
    G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args)
    D_reg_opt = tflib.Optimizer(name='RegD', share=D_opt, **D_opt_args)

    # Build training graph for each GPU.
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tfex.device('/gpu:%d' % gpu):

            # Create GPU-specific shadow copies of G and D.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                sched = training_schedule(cur_nimg=int(resume_kimg*1000), training_set=training_set, **sched_args)
                reals_var = tf.Variable(name='reals', trainable=False, initial_value=tf.zeros([sched.minibatch_gpu] + training_set.shape))
                labels_var = tf.Variable(name='labels', trainable=False, initial_value=tf.zeros([sched.minibatch_gpu, training_set.label_size]))
                reals_write, labels_write = training_set.get_minibatch_tf()
                reals_write, labels_write = process_reals(reals_write, labels_write, lod_in, mirror_augment, training_set.dynamic_range, drange_net)
                reals_write = tf.concat([reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                labels_write = tf.concat([labels_write, labels_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_var, reals_write)]
                data_fetch_ops += [tf.assign(labels_var, labels_write)]
                reals_read = reals_var[:minibatch_gpu_in]
                labels_read = labels_var[:minibatch_gpu_in]

            # Evaluate loss functions.
            lod_assign_ops = []
            if 'lod' in G_gpu.vars: lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)]
            if 'lod' in D_gpu.vars: lod_assign_ops += [tf.assign(D_gpu.vars['lod'], lod_in)]
            with tf.control_dependencies(lod_assign_ops):
                with tf.name_scope('G_loss'):
                    G_loss, G_reg = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, **G_loss_args)
                with tf.name_scope('D_loss'):
                    D_loss, D_reg = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, reals=reals_read, labels=labels_read, **D_loss_args)

            # Register gradients.
            if not lazy_regularization:
                if G_reg is not None: G_loss += G_reg
                if D_reg is not None: D_loss += D_reg
            else:
                if G_reg is not None: G_reg_opt.register_gradients(tf.reduce_mean(G_reg * G_reg_interval), G_gpu.trainables)
                if D_reg is not None: D_reg_opt.register_gradients(tf.reduce_mean(D_reg * D_reg_interval), D_gpu.trainables)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()
    G_reg_op = G_reg_opt.apply_updates(allow_no_op=True)
    D_reg_op = D_reg_opt.apply_updates(allow_no_op=True)
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)

    # Finalize graph.
    with tfex.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1  # -1 forces a maintenance tick right after the first training step.
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops. G ops are fed the G learning rate and D ops the D
        # learning rate (previously D silently trained with sched.G_lrate).
        G_feed_dict = {lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_size_in: sched.minibatch_size, minibatch_gpu_in: sched.minibatch_gpu}
        D_feed_dict = {lod_in: sched.lod, lrate_in: sched.D_lrate, minibatch_size_in: sched.minibatch_size, minibatch_gpu_in: sched.minibatch_gpu}
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size, sched.minibatch_gpu * num_gpus)
            run_G_reg = (lazy_regularization and running_mb_counter % G_reg_interval == 0)
            run_D_reg = (lazy_regularization and running_mb_counter % D_reg_interval == 0)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                tflib.run([G_train_op, data_fetch_op], G_feed_dict)
                if run_G_reg:
                    tflib.run(G_reg_op, G_feed_dict)
                tflib.run([D_train_op, Gs_update_op], D_feed_dict)
                if run_D_reg:
                    tflib.run(D_reg_op, D_feed_dict)

            # Slow path with gradient accumulation.
            else:
                for _round in rounds:
                    tflib.run(G_train_op, G_feed_dict)
                if run_G_reg:
                    for _round in rounds:
                        tflib.run(G_reg_op, G_feed_dict)
                tflib.run(Gs_update_op, G_feed_dict)
                for _round in rounds:
                    tflib.run(data_fetch_op, D_feed_dict)
                    tflib.run(D_train_op, D_feed_dict)
                if run_D_reg:
                    for _round in rounds:
                        tflib.run(D_reg_op, D_feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start() + resume_time

            # Report progress.
            print('tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f' % (
                autosummary('Progress/tick', cur_tick),
                autosummary('Progress/kimg', cur_nimg / 1000.0),
                autosummary('Progress/lod', sched.lod),
                autosummary('Progress/minibatch', sched.minibatch_size),
                dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                autosummary('Timing/sec_per_tick', tick_time),
                autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                autosummary('Timing/maintenance_sec', maintenance_time),
                autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (cur_tick % image_snapshot_ticks == 0 or done):
                grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu)
                misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size)
            if network_snapshot_ticks is not None and (cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl, run_dir=dnnlib.make_run_dir_path(), data_dir=dnnlib.convert_path(data_dir), num_gpus=num_gpus, tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get().get_last_update_interval() - tick_time

    # Save final snapshot.
    misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
def _save_hd_traversal_images(Gs, prior_traj_latents, grid_labels, grid_size,
                              minibatch_gpu, drange_net, n_samples_per,
                              prior_latent_size, n_continuous, n_discrete,
                              fakes_png, prior_grid_png, prior_grid_drawn_png):
    """Render and save the traversal image grid(s) for one training_loop_hd snapshot.

    Args:
        Gs: Network used for rendering via ``Gs.run``.
        prior_traj_latents: Latents to render. They are reshaped to
            ``[-1, n_samples_per, prior_latent_size]`` for printing/drawing.
        grid_labels: Labels passed to ``Gs.run`` alongside the latents.
        grid_size: Grid layout passed to ``misc.save_image_grid``.
        minibatch_gpu: Minibatch size for ``Gs.run``.
        drange_net: Dynamic range of the rendered images.
        n_samples_per: Samples per trajectory (second reshape axis).
        prior_latent_size: Latent width (third reshape axis).
        n_continuous, n_discrete: Latent counts. The extra 2D prior-grid images
            are only produced when they are exactly 2 and 0.
        fakes_png: Run-dir file name for the trajectory image grid.
        prior_grid_png: Run-dir file name for the 20x20 prior grid.
        prior_grid_drawn_png: Run-dir file name for the prior grid with the
            trajectories drawn on top.
    """
    prior_traj_latents_show = np.reshape(
        prior_traj_latents, [-1, n_samples_per, prior_latent_size])
    print_traj(prior_traj_latents_show)
    grid_fakes = Gs.run(prior_traj_latents, grid_labels, is_validation=True,
                        minibatch_size=minibatch_gpu, randomize_noise=True,
                        normalize_latents=False)
    grid_fakes = add_outline(grid_fakes, width=1)
    misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path(fakes_png),
                         drange=drange_net, grid_size=grid_size)

    if (n_continuous == 2) and (n_discrete == 0):
        # 2D latent space: also render a regular grid over the prior
        # (low=-3, high=3 as passed to get_2d_grid_latents) and overlay the
        # trajectories on it.
        n_per_line = 20
        ex_latent_value = 3
        prior_grid_latents, prior_grid_labels = get_2d_grid_latents(
            low=-ex_latent_value, high=ex_latent_value,
            n_per_line=n_per_line, grid_labels=grid_labels)
        grid_showing_fakes = Gs.run(prior_grid_latents, prior_grid_labels,
                                    is_validation=True,
                                    minibatch_size=minibatch_gpu,
                                    randomize_noise=True,
                                    normalize_latents=False)
        grid_showing_fakes = add_outline(grid_showing_fakes, width=1)
        misc.save_image_grid(grid_showing_fakes,
                             dnnlib.make_run_dir_path(prior_grid_png),
                             drange=drange_net,
                             grid_size=[n_per_line, n_per_line])
        img_to_draw = Image.open(dnnlib.make_run_dir_path(prior_grid_png))
        img_to_draw = img_to_draw.convert('RGB')
        img_to_draw = draw_traj_on_prior_grid(img_to_draw,
                                              prior_traj_latents_show,
                                              ex_latent_value, n_per_line)
        img_to_draw.save(dnnlib.make_run_dir_path(prior_grid_drawn_png))


def training_loop_hd(
        I_args={},                  # Options for network I (tflib.Network 'I').
        M_args={},                  # Options for network M (tflib.Network 'M').
        I_info_args={},             # Options for network I_info (only when use_hd_with_cls).
        I_opt_args={},              # Options for the I/M optimizers (tflib.Optimizer).
        I_loss_args={},             # Options for the loss function (dnnlib.util.call_func_by_name).
        resume_G_pkl=None,          # Pickle of the pretrained generator (required).
        dataset_args={},            # Options for dataset.load_dataset().
        sched_args={},              # Options for train.TrainingSchedule.
        grid_args={},               # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],         # Options for MetricGroup.
        tf_config={},               # Options for tflib.init_tf().
        data_dir=None,              # Directory to load datasets from.
        I_smoothing_kimg=10.0,      # Half-life of the running average (Is) of I's weights.
        minibatch_repeats=4,        # Number of minibatches to run before adjusting training parameters.
        lazy_regularization=True,   # Perform regularization as a separate training step?
        I_reg_interval=4,           # How often to perform regularization for I? Ignored if lazy_regularization=False.
        reset_opt_for_new_lod=True, # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg=25000,           # Total length of the training, measured in thousands of real images.
        mirror_augment=False,       # Not referenced by this function; kept for config compatibility.
        drange_net=[-1, 1],         # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=50,    # How often to save image snapshots? None = only save 'reals.png' and 'fakes_init.png'.
        network_snapshot_ticks=50,  # How often to save network snapshots? None = only save 'network-final.pkl'.
        save_tf_graph=False,        # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_pkl=None,            # Network pickle (I, M, Is[, I_info]) to resume from, None = train from scratch.
        resume_kimg=0.0,            # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0,            # Assumed wallclock time at the beginning. Affects reporting.
        resume_with_new_nets=False, # Construct new networks according to I_args and M_args before resuming training?
        traversal_grid=False,       # Not referenced by this function; kept for config compatibility.
        n_discrete=0,               # Number of discrete latents in model.
        n_continuous=10,            # Number of continuous latents in model.
        n_samples_per=4,            # Number of samples for each line in traversal.
        use_hd_with_cls=False,      # Also train the I_info network (passed to the loss as I_info).
        resolution_manual=1024,     # Resolution of generated images.
        use_level_training=False,   # Use level training (one optimizer per level, switched every level_I_kimg).
        level_I_kimg=1000,          # Number of kimg per level when use_level_training.
        use_std_in_m=False,         # If M also outputs a prior std; only the first prior_latent_size columns are used as latents.
        prior_latent_size=512,      # Prior latent size.
        use_hyperplane=False,       # Build traversal trajectories from get_latent_dirs()/M instead of get_grid_latents().
        pretrained_type='with_stylegan2'):  # 'with_stylegan2' (pickle of G, D, Gs) or 'with_cascadeVAE' (single network).
    """Train networks I and M (and optionally I_info) against a pretrained generator.

    The generator loaded from ``resume_G_pkl`` is used inside the loss, but its
    variables are never registered with an optimizer here. Only M, I and (when
    ``use_hd_with_cls``) I_info receive gradient updates. ``Is`` tracks an
    exponential moving average of I whose half-life is ``I_smoothing_kimg``
    thousand images.

    Side effects (all inside the run dir): 'reals.png', traversal images
    ('fakes_init*.png', 'fakes%06d.png', ...), network snapshots
    'network-snapshot-%06d.pkl' and a final 'network-final.pkl', each holding
    the tuple (I, M, Is). Also writes tfevents summaries.

    Raises:
        ValueError: if ``pretrained_type`` is not a supported value.
    """
    # Validate up front: an unknown value used to leave G/Gs unbound and fail
    # much later with a NameError during graph construction.
    if pretrained_type not in ('with_stylegan2', 'with_cascadeVAE'):
        raise ValueError('Unsupported pretrained_type: %r' % (pretrained_type,))

    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir),
                                        verbose=True, **dataset_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        training_set, **grid_args)
    misc.save_image_grid(grid_reals, dnnlib.make_run_dir_path('reals.png'),
                         drange=training_set.dynamic_range, grid_size=grid_size)

    # Construct or load networks.
    with tf.device('/gpu:0'):
        if resume_pkl is None or resume_with_new_nets:
            print('Constructing networks...')
            I = tflib.Network('I', num_channels=training_set.shape[0],
                              resolution=resolution_manual,
                              label_size=training_set.label_size, **I_args)
            M = tflib.Network('M', num_channels=training_set.shape[0],
                              resolution=resolution_manual,
                              label_size=training_set.label_size, **M_args)
            Is = I.clone('Is')
            if use_hd_with_cls:
                I_info = tflib.Network('I_info',
                                       num_channels=training_set.shape[0],
                                       resolution=resolution_manual,
                                       label_size=training_set.label_size,
                                       **I_info_args)
        if resume_pkl is not None:
            print('Loading networks from "%s"...' % resume_pkl)
            if use_hd_with_cls:
                rI, rM, rIs, rI_info = misc.load_pkl(resume_pkl)
            else:
                rI, rM, rIs = misc.load_pkl(resume_pkl)
            if resume_with_new_nets:
                I.copy_vars_from(rI)
                M.copy_vars_from(rM)
                Is.copy_vars_from(rIs)
                if use_hd_with_cls:
                    I_info.copy_vars_from(rI_info)
            else:
                I, M, Is = rI, rM, rIs
                if use_hd_with_cls:
                    I_info = rI_info

        print('Loading generator from "%s"...' % resume_G_pkl)
        if pretrained_type == 'with_stylegan2':
            # The discriminator in the pickle is not needed here.
            G, _, Gs = misc.load_pkl(resume_G_pkl)
        else:  # 'with_cascadeVAE': a single network serves as both G and Gs.
            G = misc.load_pkl(resume_G_pkl)
            Gs = G

    # Print layers and generate initial image snapshot.
    I.print_layers()
    M.print_layers()
    training_set_resolution_log2 = int(np.log2(resolution_manual))
    sched = training_schedule(
        cur_nimg=total_kimg * 1000,
        training_set_resolution_log2=training_set_resolution_log2,
        **sched_args)

    if not use_hyperplane:
        grid_size, grid_latents, grid_labels = get_grid_latents(
            n_discrete, n_continuous, n_samples_per, G, grid_labels)
        print('grid_size:', grid_size)
        print('grid_latents.shape:', grid_latents.shape)
        print('grid_labels.shape:', grid_labels.shape)
        if resolution_manual >= 256:
            # Keep only the first fifth of the grid at high resolution.
            grid_size = (grid_size[0], grid_size[1] // 5)
            grid_latents = grid_latents[:grid_latents.shape[0] // 5]
            grid_labels = grid_labels[:grid_labels.shape[0] // 5]
        prior_traj_latents = M.run(grid_latents, is_validation=True,
                                   minibatch_size=sched.minibatch_gpu)
        if use_std_in_m:
            prior_traj_latents = prior_traj_latents[:, :prior_latent_size]
    else:
        grid_size = (n_samples_per, n_continuous)
        grid_labels = np.tile(grid_labels[:1], (n_continuous * n_samples_per, 1))
        latent_dirs = get_latent_dirs(n_continuous)
        prior_traj_latents = get_prior_traj_by_dirs(
            latent_dirs, M, n_samples_per, prior_latent_size, grid_labels, sched)
        if resolution_manual >= 256:
            grid_size = (grid_size[0], grid_size[1] // 5)
            prior_traj_latents = prior_traj_latents[:prior_traj_latents.shape[0] // 5]
            grid_labels = grid_labels[:grid_labels.shape[0] // 5]
    print('prior_traj_latents.shape:', prior_traj_latents.shape)
    _save_hd_traversal_images(
        Gs, prior_traj_latents, grid_labels, grid_size, sched.minibatch_gpu,
        drange_net, n_samples_per, prior_latent_size, n_continuous, n_discrete,
        'fakes_init.png', 'fakes_init_2d_prior_grid.png',
        'fakes_init_2d_prior_grid_drawn.png')

    # One optimizer per level with level training, otherwise a single one.
    if use_level_training:
        ending_level = training_set_resolution_log2 - 1
    else:
        ending_level = 1

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32, name='minibatch_size_in', shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32, name='minibatch_gpu_in', shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in * num_gpus)
        Is_beta = (0.5 ** tf.div(tf.cast(minibatch_size_in, tf.float32),
                                 I_smoothing_kimg * 1000.0)
                   if I_smoothing_kimg > 0.0 else 0.0)

    # Setup optimizers.
    I_opt_args = dict(I_opt_args)  # copy: never mutate the caller's dict
    I_opt_args['minibatch_multiplier'] = minibatch_multiplier
    I_opt_args['learning_rate'] = lrate_in
    if lazy_regularization:
        # Regularization runs only every I_reg_interval minibatches, so the
        # learning rate and Adam betas are rescaled (StyleGAN2-style lazy reg).
        mb_ratio = I_reg_interval / (I_reg_interval + 1)
        I_opt_args['learning_rate'] *= mb_ratio
        if 'beta1' in I_opt_args:
            I_opt_args['beta1'] **= mb_ratio
        if 'beta2' in I_opt_args:
            I_opt_args['beta2'] **= mb_ratio
    I_opts = []
    I_reg_opts = []
    for n_level in range(ending_level):
        I_opts.append(tflib.Optimizer(name='TrainI_%d' % n_level, **I_opt_args))
        I_reg_opts.append(tflib.Optimizer(name='RegI_%d' % n_level,
                                          share=I_opts[-1], **I_opt_args))

    # Build training graph for each GPU.
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # Create GPU-specific shadow copies of the networks.
            I_gpu = I if gpu == 0 else I.clone(I.name + '_shadow')
            M_gpu = M if gpu == 0 else M.clone(M.name + '_shadow')
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            if use_hd_with_cls:
                I_info_gpu = I_info if gpu == 0 else I_info.clone(I_info.name + '_shadow')

            # Variables updated by the main loss: M and I (plus I_info). G is
            # deliberately not included.
            trainable_items = list(M_gpu.trainables.items()) + list(I_gpu.trainables.items())
            if use_hd_with_cls:
                trainable_items += list(I_info_gpu.trainables.items())
            trainables = collections.OrderedDict(trainable_items)

            # Evaluate loss functions.
            lod_assign_ops = []
            if 'lod' in I_gpu.vars:
                lod_assign_ops += [tf.assign(I_gpu.vars['lod'], lod_in)]
            if 'lod' in M_gpu.vars:
                lod_assign_ops += [tf.assign(M_gpu.vars['lod'], lod_in)]
            for n_level in range(ending_level):
                with tf.control_dependencies(lod_assign_ops):
                    with tf.name_scope('I_loss_%d' % n_level):
                        if use_hd_with_cls:
                            I_loss, I_reg = dnnlib.util.call_func_by_name(
                                I=I_gpu, M=M_gpu, G=G_gpu, I_info=I_info_gpu,
                                opt=I_opts[n_level], training_set=training_set,
                                minibatch_size=minibatch_gpu_in, **I_loss_args)
                        else:
                            I_loss, I_reg = dnnlib.util.call_func_by_name(
                                I=I_gpu, M=M_gpu, G=G_gpu, opt=I_opts[n_level],
                                n_levels=(n_level + 1) if use_level_training else None,
                                training_set=training_set,
                                minibatch_size=minibatch_gpu_in, **I_loss_args)

                # Register gradients.
                if I_reg is not None:
                    if lazy_regularization:
                        I_reg_opts[n_level].register_gradients(
                            tf.reduce_mean(I_reg * I_reg_interval), I_gpu.trainables)
                    else:
                        I_loss += I_reg
                I_opts[n_level].register_gradients(tf.reduce_mean(I_loss), trainables)

    # Setup training ops.
    I_train_ops = [opt.apply_updates() for opt in I_opts]
    I_reg_ops = [opt.apply_updates(allow_no_op=True) for opt in I_reg_opts]
    Is_update_op = Is.setup_as_moving_average_of(I, beta=Is_beta)

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        I.setup_weight_histograms()
        M.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1  # first loop iteration always runs a maintenance tick
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop():
            break
        # Active level: advances every level_I_kimg, capped at the last optimizer.
        if use_level_training:
            n_level = min(cur_nimg // (level_I_kimg * 1000),
                          training_set_resolution_log2 - 2)
        else:
            n_level = 0

        # Choose training parameters and configure training ops.
        sched = training_schedule(
            cur_nimg=cur_nimg,
            training_set_resolution_log2=training_set_resolution_log2,
            **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu, sched.lod)
        if reset_opt_for_new_lod:
            if (np.floor(sched.lod) != np.floor(prev_lod)
                    or np.ceil(sched.lod) != np.ceil(prev_lod)):
                I_opts[n_level].reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        feed_dict = {
            lod_in: sched.lod,
            lrate_in: sched.I_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu,
        }
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size, sched.minibatch_gpu * num_gpus)
            run_I_reg = (lazy_regularization
                         and running_mb_counter % I_reg_interval == 0)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            if len(rounds) == 1:
                # Fast path without gradient accumulation.
                tflib.run(I_train_ops[n_level], feed_dict)
                if run_I_reg:
                    tflib.run(I_reg_ops[n_level], feed_dict)
                tflib.run([Is_update_op], feed_dict)
            else:
                # Slow path with gradient accumulation.
                for _round in rounds:
                    tflib.run(I_train_ops[n_level], feed_dict)
                if run_I_reg:
                    for _round in rounds:
                        tflib.run(I_reg_ops[n_level], feed_dict)
                tflib.run(Is_update_op, feed_dict)

        # Perform maintenance tasks once per tick. tick_kimg is in thousands of
        # images, so the threshold is tick_kimg * 1000 (was * 100, which made
        # ticks, snapshots and metrics fire 10x too often).
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start() + resume_time

            # Report progress.
            print('tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f' % (
                autosummary('Progress/tick', cur_tick),
                autosummary('Progress/kimg', cur_nimg / 1000.0),
                autosummary('Progress/lod', sched.lod),
                autosummary('Progress/minibatch', sched.minibatch_size),
                dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                autosummary('Timing/sec_per_tick', tick_time),
                autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                autosummary('Timing/maintenance_sec', maintenance_time),
                autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (cur_tick % image_snapshot_ticks == 0 or done):
                if not use_hyperplane:
                    # grid_latents was already truncated at setup time.
                    prior_traj_latents = M.run(grid_latents, is_validation=True,
                                               minibatch_size=sched.minibatch_gpu)
                    if use_std_in_m:
                        prior_traj_latents = prior_traj_latents[:, :prior_latent_size]
                else:
                    prior_traj_latents = get_prior_traj_by_dirs(
                        latent_dirs, M, n_samples_per, prior_latent_size,
                        grid_labels, sched)
                    if resolution_manual >= 256:
                        prior_traj_latents = prior_traj_latents[:prior_traj_latents.shape[0] // 5]
                snapshot_kimg = cur_nimg // 1000
                _save_hd_traversal_images(
                    Gs, prior_traj_latents, grid_labels, grid_size,
                    sched.minibatch_gpu, drange_net, n_samples_per,
                    prior_latent_size, n_continuous, n_discrete,
                    'fakes%06d.png' % snapshot_kimg,
                    'fakes_2d_prior_grid%06d.png' % snapshot_kimg,
                    'fakes_2d_prior_grid_drawn%06d.png' % snapshot_kimg)
            if network_snapshot_ticks is not None and (cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                # NOTE(review): I_info is not saved, but resuming with
                # use_hd_with_cls expects a 4-tuple (I, M, Is, I_info). The
                # format is left unchanged because metrics.run() loads this pkl.
                misc.save_pkl((I, M, Is), pkl)
                metrics.run(pkl, run_dir=dnnlib.make_run_dir_path(),
                            data_dir=dnnlib.convert_path(data_dir),
                            num_gpus=num_gpus, tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % sched.lod,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get().get_last_update_interval() - tick_time

    # Save final snapshot.
    misc.save_pkl((I, M, Is), dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
def _copy_to_gcs_bucket(local_path, bucket='gs://stylegan_out'):
    """Best-effort upload of *local_path* to *bucket* with the ``gsutil`` CLI.

    The argv list is passed without a shell, so paths containing spaces or
    shell metacharacters are handled correctly. The exit status is ignored
    (as before). If ``gsutil`` cannot be started, a warning is printed and
    training continues.
    """
    try:
        subprocess.run(['gsutil', 'cp', local_path, bucket], check=False)
    except OSError as err:
        print('Warning: could not upload "%s" to %s: %s' % (local_path, bucket, err))


def training_loop(
        submit_config,
        G_args={},                  # Options for generator network (unused: networks are resumed or downloaded).
        D_args={},                  # Options for discriminator network (unused: networks are resumed or downloaded).
        G_opt_args={},              # Options for generator optimizer.
        D_opt_args={},              # Options for discriminator optimizer.
        G_loss_args={},             # Options for generator loss.
        D_loss_args={},             # Options for discriminator loss.
        dataset_args={},            # Options for dataset.load_dataset().
        sched_args={},              # Options for train.TrainingSchedule.
        grid_args={},               # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],         # Options for MetricGroup.
        tf_config={},               # Options for tflib.init_tf().
        G_smoothing_kimg=10.0,      # Half-life of the running average of generator weights.
        D_repeats=1,                # How many times the discriminator is trained per G iteration.
        minibatch_repeats=4,        # Number of minibatches to run before adjusting training parameters.
        reset_opt_for_new_lod=True, # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg=15000,           # Total length of the training, measured in thousands of real images.
        mirror_augment=False,       # Enable mirror augment?
        drange_net=[-1, 1],         # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=1,     # How often to export image snapshots?
        network_snapshot_ticks=10,  # How often to export network snapshots?
        save_tf_graph=False,        # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_run_id=None,         # Run ID or network pkl to resume training from, None = load the pretrained FFHQ pickle.
        resume_snapshot=None,       # Snapshot index to resume training from, None = autodetect.
        resume_kimg=10000.0,        # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0):           # Assumed wallclock time at the beginning. Affects reporting.
    """Continue training a StyleGAN G/D pair, starting from a resumed run or a pretrained pickle.

    If ``resume_run_id`` is None, (G, D, Gs) is downloaded from a fixed Google
    Drive URL (printed as the pretrained FFHQ network). G_args/D_args are not
    used because networks are never constructed from scratch here.

    Image snapshots are written to ``submit_config.run_dir`` and also copied to
    ``gs://stylegan_out`` with ``gsutil`` (best effort). Network snapshots
    'network-snapshot-%06d.pkl' and the final 'network-final.pkl' hold
    (G, D, Gs).
    """
    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args)

    # Load networks: either from a previous run or from the pretrained pickle.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            url = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ'
            with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f:
                # SECURITY NOTE: pickle.load runs arbitrary code contained in the
                # downloaded file; only point this at a trusted URL.
                G, D, Gs = pickle.load(f)
            print('Loading pretrained FFHQ network')
    G.print_layers()
    D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        Gs_beta = (0.5 ** tf.div(tf.cast(minibatch_in, tf.float32), G_smoothing_kimg * 1000.0)
                   if G_smoothing_kimg > 0.0 else 0.0)

    G_opt = tflib.Optimizer(name='TrainG', learning_rate=lrate_in, **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', learning_rate=lrate_in, **D_opt_args)
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            lod_assign_ops = [tf.assign(G_gpu.find_var('lod'), lod_in),
                              tf.assign(D_gpu.find_var('lod'), lod_in)]
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(reals, lod_in, mirror_augment,
                                  training_set.dynamic_range, drange_net)
            with tf.name_scope('G_loss'), tf.control_dependencies(lod_assign_ops):
                G_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set,
                    minibatch_size=minibatch_split, **G_loss_args)
            with tf.name_scope('D_loss'), tf.control_dependencies(lod_assign_ops):
                D_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set,
                    minibatch_size=minibatch_split, reals=reals, labels=labels,
                    **D_loss_args)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(
        G, training_set, **grid_args)
    sched = training_schedule(cur_nimg=total_kimg * 1000, training_set=training_set,
                              num_gpus=submit_config.num_gpus, **sched_args)
    grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True,
                        minibatch_size=sched.minibatch // submit_config.num_gpus)

    print('Setting up run dir...')
    misc.save_image_grid(grid_reals, os.path.join(submit_config.run_dir, 'reals.png'),
                         drange=training_set.dynamic_range, grid_size=grid_size)
    init_fakes_path = os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg)
    misc.save_image_grid(grid_fakes, init_fakes_path, drange=drange_net, grid_size=grid_size)
    _copy_to_gcs_bucket(init_fakes_path)
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set,
                                  num_gpus=submit_config.num_gpus, **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus, sched.lod)
        if reset_opt_for_new_lod:
            if (np.floor(sched.lod) != np.floor(prev_lod)
                    or np.ceil(sched.lod) != np.ceil(prev_lod)):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops: D_repeats discriminator steps, then one generator step.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run([D_train_op, Gs_update_op],
                          {lod_in: sched.lod, lrate_in: sched.D_lrate,
                           minibatch_in: sched.minibatch})
                cur_nimg += sched.minibatch
            tflib.run([G_train_op],
                      {lod_in: sched.lod, lrate_in: sched.G_lrate,
                       minibatch_in: sched.minibatch})

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print('tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f' % (
                autosummary('Progress/tick', cur_tick),
                autosummary('Progress/kimg', cur_nimg / 1000.0),
                autosummary('Progress/lod', sched.lod),
                autosummary('Progress/minibatch', sched.minibatch),
                dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                autosummary('Timing/sec_per_tick', tick_time),
                autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                autosummary('Timing/maintenance_sec', maintenance_time),
                autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True,
                                    minibatch_size=sched.minibatch // submit_config.num_gpus)
                fakes_path = os.path.join(submit_config.run_dir,
                                          'fakes%06d.png' % (cur_nimg // 1000))
                misc.save_image_grid(grid_fakes, fakes_path, drange=drange_net,
                                     grid_size=grid_size)
                _copy_to_gcs_bucket(fakes_path)
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(submit_config.run_dir,
                                   'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl, run_dir=submit_config.run_dir,
                            num_gpus=submit_config.num_gpus, tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    misc.save_pkl((G, D, Gs), os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()

    ctx.close()
def training_loop(
    submit_config,
    G_args                  = {},       # Options for the generator network.
    D_args                  = {},       # Options for the discriminator network.
    G_opt_args              = {},       # Options for the generator optimizer.
    D_opt_args              = {},       # Options for the discriminator optimizer.
    G_loss_args             = {},       # Options for the generator loss.
    D_loss_args             = {},       # Options for the discriminator loss.
    dataset_args            = {},       # Options for the dataset (forwarded to dataset.load_dataset()).
    sched_args              = {},       # Options for the training schedule (forwarded to training_schedule()).
    grid_args               = {},       # Options for misc.setup_snapshot_image_grid().
    metric_arg_list         = [],       # Options for the metrics (forwarded to metric_base.MetricGroup).
    tf_config               = {},       # Options for tflib.init_tf().
    G_smoothing_kimg        = 10.0,     # Half-life of the running average of generator weights, in thousands of images.
    D_repeats               = 1,        # How many discriminator steps to run per generator step.
    minibatch_repeats       = 4,        # Number of minibatches to run before re-evaluating the training schedule.
    reset_opt_for_new_lod   = True,     # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg              = 15000,    # Total training length, measured in thousands of real images.
    mirror_augment          = False,    # Enable mirror augmentation?
    drange_net              = [-1,1],   # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks    = 1,        # How often (in ticks) to export image snapshots.
    network_snapshot_ticks  = 10,       # How often (in ticks) to export network snapshots.
    save_tf_graph           = False,    # Include the full TensorFlow computation graph in the tfevents file?
    save_weight_histograms  = False,    # Include weight histograms in the tfevents file?
    resume_run_id           = None,     # Run ID or network pkl to resume training from; None = train from scratch.
    resume_snapshot         = None,     # Index of the snapshot to resume from; None = autodetect.
    resume_kimg             = 0.0,      # Training progress (kimg) assumed at start. Affects reporting and the training schedule.
    resume_time             = 0.0):     # Wall-clock time (seconds) assumed at start. Affects reporting.
    """Build the G/D training graph and run the training loop to completion.

    Networks are either loaded as a ``(G, D, Gs)`` tuple from the pickle
    located by ``misc.locate_network_pkl`` (when ``resume_run_id`` is not
    None), or constructed fresh from ``G_args``/``D_args`` with ``Gs`` as a
    clone of ``G``. Loss and gradient ops are built once per GPU and
    registered on two ``tflib.Optimizer`` instances. ``Gs`` is set up as a
    moving average of ``G``.

    The loop runs until ``cur_nimg`` (advanced by one minibatch per
    discriminator step) reaches ``total_kimg * 1000``, or until
    ``ctx.should_stop()`` returns True. Once per tick it:

    - reports progress via ``autosummary``;
    - saves a fake-image grid every ``image_snapshot_ticks`` ticks;
    - saves a network snapshot, and runs the configured metrics on it,
      every ``network_snapshot_ticks`` ticks, on the first tick, and when
      training is done.

    Finally it writes ``network-final.pkl``. All files (image grids,
    snapshots, tfevents summaries) go to ``submit_config.run_dir``.

    Each keyword argument is described by the comment beside it in the
    signature.

    Returns:
        None.
    """
    # Initialize dnnlib and TensorFlow.
    # NOTE(review): `train` and `config` are module-level names not visible in
    # this chunk -- presumably the project's train/config modules; verify.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load the training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args)

    # Construct (or load) the networks on the first GPU.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...')
            # The dataset shape is read as [num_channels, resolution, ...].
            G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args)
            D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args)
            # Gs starts as a copy of G and is later updated as its moving average.
            Gs = G.clone('Gs')
    G.print_layers(); D.print_layers()

    # Build the computation graph and the optimizers.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        # tf.placeholder acts like a formal parameter: it declares an input
        # whose concrete value is supplied at run time (see the feed dicts
        # passed to tflib.run in the training loop).
        lod_in          = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in        = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in    = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        # Per-GPU share of the global minibatch.
        minibatch_split = minibatch_in // submit_config.num_gpus
        # Per-step moving-average coefficient for Gs: 0.5 ** (minibatch / (G_smoothing_kimg * 1000)),
        # so the weighting halves every G_smoothing_kimg thousand images.
        # When G_smoothing_kimg <= 0 the coefficient is 0.0 (presumably Gs then
        # simply tracks G -- depends on setup_as_moving_average_of; confirm).
        Gs_beta         = 0.5 ** tf.div(tf.cast(minibatch_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0

    G_opt = tflib.Optimizer(name='TrainG', learning_rate=lrate_in, **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', learning_rate=lrate_in, **D_opt_args)
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # GPU 0 uses the original networks; other GPUs use "_shadow" clones.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            # Ops that write the fed lod value into each network's 'lod' variable.
            lod_assign_ops = [tf.assign(G_gpu.find_var('lod'), lod_in), tf.assign(D_gpu.find_var('lod'), lod_in)]
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(reals, lod_in, mirror_augment, training_set.dynamic_range, drange_net)
            # The loss ops depend on lod_assign_ops, so the lod variables are
            # updated before either loss is evaluated.
            # G_loss_args/D_loss_args are presumed to name the loss function
            # via call_func_by_name (e.g. a func_name key) -- verify.
            with tf.name_scope('G_loss'), tf.control_dependencies(lod_assign_ops):
                G_loss = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_split, **G_loss_args)
            with tf.name_scope('D_loss'), tf.control_dependencies(lod_assign_ops):
                D_loss = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_split, reals=reals, labels=labels, **D_loss_args)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    # One update op per optimizer, built from the gradients registered on every GPU.
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device('/gpu:0'):
        # Peak-GPU-memory reporting; if the MaxBytesInUse op is unavailable
        # (NotFoundError), fall back to a constant 0.
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)

    # Set up the snapshot image grid.
    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(G, training_set, **grid_args)
    # The initial grid render uses the minibatch size the schedule gives at
    # the end of training (cur_nimg=total_kimg*1000).
    sched = training_schedule(cur_nimg=total_kimg*1000, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args)
    grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch//submit_config.num_gpus)

    # Set up the run directory: real/fake image grids, summary writer, metrics.
    print('Setting up run dir...')
    misc.save_image_grid(grid_reals, os.path.join(submit_config.run_dir, 'reals.png'), drange=training_set.dynamic_range, grid_size=grid_size)
    misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg), drange=drange_net, grid_size=grid_size)
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms(); D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    # Train.
    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    # -1.0 makes the first iteration's lod (presumably >= 0) count as a new
    # lod, so the optimizer-reset check fires on the first iteration.
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop(): break

        # Choose training parameters and configure the training ops.
        sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus, sched.lod)
        if reset_opt_for_new_lod:
            # Reset optimizer state whenever lod moves into a different
            # integer interval (its floor or ceil changes).
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state(); D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run the training ops.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                # Each D step also updates Gs and counts one minibatch of images.
                tflib.run([D_train_op, Gs_update_op], {lod_in: sched.lod, lrate_in: sched.D_lrate, minibatch_in: sched.minibatch})
                cur_nimg += sched.minibatch
            tflib.run([G_train_op], {lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_in: sched.minibatch})

        # Perform maintenance tasks once per tick. A tick ends once at least
        # sched.tick_kimg thousand images have passed since the last tick, or
        # when training is done.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress (also recorded as autosummary values; peak GPU
            # memory is converted from bytes to GiB).
            print('tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f' % (
                autosummary('Progress/tick', cur_tick),
                autosummary('Progress/kimg', cur_nimg / 1000.0),
                autosummary('Progress/lod', sched.lod),
                autosummary('Progress/minibatch', sched.minibatch),
                dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                autosummary('Timing/sec_per_tick', tick_time),
                autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                autosummary('Timing/maintenance_sec', maintenance_time),
                autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots. Network snapshots are also saved on the first
            # tick and at the end, and each one is evaluated with the metrics.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch//submit_config.num_gpus)
                misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size)
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(submit_config.run_dir, 'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl, run_dir=submit_config.run_dir, num_gpus=submit_config.num_gpus, tf_config=tf_config)

            # Update summaries and the RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            # Intended as the time spent on maintenance this tick (relies on
            # RunContext interval semantics -- verify).
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write the final results.
    misc.save_pkl((G, D, Gs), os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()
    ctx.close()