def run_pickle(submit_config, metric_args, network_pkl, dataset_args, mirror_augment):
    """Evaluate a single metric on one network pickle.

    Args:
        submit_config: Submission config; used to create the run context and
            read ``num_gpus``.
        metric_args: Mapping expanded into ``dnnlib.util.call_func_by_name`` to
            build the metric object; its ``name`` attribute is printed.
        network_pkl: Network pickle handed to ``metric.run``.
        dataset_args: Forwarded to ``metric.run`` as ``dataset_args``.
        mirror_augment: Forwarded to ``metric.run`` as ``mirror_augment``.
    """
    run_ctx = dnnlib.RunContext(submit_config)
    tflib.init_tf()

    banner = 'Evaluating %s metric on network_pkl "%s"...' % (metric_args.name, network_pkl)
    print(banner)

    evaluator = dnnlib.util.call_func_by_name(**metric_args)
    print()
    evaluator.run(
        network_pkl,
        dataset_args=dataset_args,
        mirror_augment=mirror_augment,
        num_gpus=submit_config.num_gpus,
    )
    print()

    run_ctx.close()
def run_snapshot(submit_config, metric_args, run_id, snapshot):
    """Evaluate a single metric on one snapshot of an existing run.

    Args:
        submit_config: Submission config; used to create the run context and
            read ``num_gpus``.
        metric_args: Mapping expanded into ``dnnlib.util.call_func_by_name`` to
            build the metric object; its ``name`` attribute is printed.
        run_id: Identifier resolved to a run directory by
            ``misc.locate_run_dir``.
        snapshot: Snapshot selector passed to ``misc.locate_network_pkl``.
    """
    run_ctx = dnnlib.RunContext(submit_config)
    tflib.init_tf()

    banner = 'Evaluating %s metric on run_id %s, snapshot %s...' % (metric_args.name, run_id, snapshot)
    print(banner)

    # Resolve the run directory first; the pickle is looked up relative to it.
    run_dir = misc.locate_run_dir(run_id)
    network_pkl = misc.locate_network_pkl(run_dir, snapshot)

    evaluator = dnnlib.util.call_func_by_name(**metric_args)
    print()
    evaluator.run(network_pkl, run_dir=run_dir, num_gpus=submit_config.num_gpus)
    print()

    run_ctx.close()
def test_d(submit_config, resume_run_id, dataset_args, tf_config={}, resume_snapshot=None):
    """Print discriminator outputs on real and generated images.

    Loads ``(G, D, Gs)`` from a network pickle, builds a graph that feeds one
    real minibatch and one image synthesized by ``Gs`` through ``D``, and for
    a fixed number of iterations prints both discriminator outputs and saves
    the processed real batch with ``misc.save_mri_image``.

    Args:
        submit_config: Submission config; ``run_dir`` receives the saved files.
        resume_run_id: Passed to ``misc.locate_network_pkl`` to find the pickle.
        dataset_args: Extra keyword arguments for ``dataset.load_dataset``.
        tf_config: Passed to ``tflib.init_tf``. The shared default dict is not
            modified in this function.
        resume_snapshot: Snapshot selector for ``misc.locate_network_pkl``.
    """
    num_samples = 15  # number of real/fake pairs that are scored

    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
    print('Loading networks from "%s"...' % network_pkl)
    G, D, Gs = misc.load_pkl(network_pkl)

    # Generator branch: latent placeholder -> mapping -> synthesis. No labels are used.
    latents_1 = tf.placeholder(tf.float32)
    labels_1 = None
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args)
    w_1 = Gs.components.mapping.get_output_for(latents_1, labels_1, is_validation=True)
    fake_image_1_op = Gs.components.synthesis.get_output_for(
        w_1, is_validation=True, randomize_noise=False)

    # Real branch: the dataset labels are discarded because D is called with labels_1 (None).
    reals, _labels = training_set.get_minibatch_tf()
    lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
    reals = process_reals(reals, lod_in, False, training_set.dynamic_range, [-1, 1])

    d_pred_real = D.get_output_for(reals, labels_1)
    d_pred_fake = D.get_output_for(fake_image_1_op, labels_1)

    training_set.configure(1, 0)

    for i in range(num_samples):
        latents_1_val = np.random.randn(1, *G.input_shape[1:])
        d_pred_real_, d_pred_fake_, real_image = tflib.run(
            [d_pred_real, d_pred_fake, reals],
            feed_dict={latents_1: latents_1_val, lod_in: 0})
        print(d_pred_real_, d_pred_fake_)
        misc.save_mri_image(
            real_image,
            os.path.join(submit_config.run_dir, 'real_{}.nii.gz'.format(i)),
            drange=[-1, 1])

    # Close the run context, as the other entry points in this file do.
    ctx.close()
def run_all_snapshots(submit_config, metric_args, run_id):
    """Evaluate a single metric on every network snapshot of a run.

    Args:
        submit_config: Submission config; used to create the run context and
            read ``num_gpus``.
        metric_args: Mapping expanded into ``dnnlib.util.call_func_by_name`` to
            build the metric object; its ``name`` attribute is printed.
        run_id: Identifier resolved to a run directory by
            ``misc.locate_run_dir``.
    """
    run_ctx = dnnlib.RunContext(submit_config)
    tflib.init_tf()

    banner = 'Evaluating %s metric on all snapshots of run_id %s...' % (metric_args.name, run_id)
    print(banner)

    run_dir = misc.locate_run_dir(run_id)
    snapshot_pkls = misc.list_network_pkls(run_dir)
    evaluator = dnnlib.util.call_func_by_name(**metric_args)
    print()

    position = 0
    for snapshot_pkl in snapshot_pkls:
        # Report progress as (current index, total count) before each evaluation.
        run_ctx.update('', position, len(snapshot_pkls))
        evaluator.run(snapshot_pkl, run_dir=run_dir, num_gpus=submit_config.num_gpus)
        position += 1

    print()
    run_ctx.close()
def validate(submit_config: dnnlib.SubmitConfig, noise: dict, dataset: dict, network_snapshot: str):
    """Evaluate a stored network snapshot on the validation set.

    Args:
        submit_config: Submission config for the validation set and run context.
        noise: Keyword arguments for ``dnnlib.util.call_func_by_name``; the
            resulting object supplies ``add_validation_noise_np``.
        dataset: Keyword arguments for ``ValidationSet.load``.
        network_snapshot: Snapshot reference given to ``util.load_snapshot``.
    """
    augmenter = dnnlib.util.call_func_by_name(**noise)

    val_set = ValidationSet(submit_config)
    val_set.load(**dataset)

    run_ctx = dnnlib.RunContext(submit_config, config)
    tfutil.init_tf(config.tf_config)

    # Load and evaluate the network on the first GPU.
    with tf.device("/gpu:0"):
        network = util.load_snapshot(network_snapshot)
        val_set.evaluate(network, 0, augmenter.add_validation_noise_np)

    run_ctx.close()
def mixing(submit_config, resume_run_id, tf_config = {}, resume_snapshot=None):
    """Save style-mixing images for two random latents.

    Loads ``(G, D, Gs)`` from a network pickle, maps two random latents to
    ``w_1`` and ``w_2`` and saves the two unmixed images. For each crossover
    index ``i`` in ``range(15)`` it then saves an image synthesized from
    ``w_1[:, :i]`` joined with ``w_2[:, i:]`` (``fake_mix_12_<i>.png``),
    followed by the same series with the roles swapped
    (``fake_mix_21_<i>.png``). All files go to ``submit_config.run_dir``.

    Args:
        submit_config: Submission config; ``run_dir`` receives the images.
        resume_run_id: Passed to ``misc.locate_network_pkl`` to find the pickle.
        tf_config: Passed to ``tflib.init_tf``. The shared default dict is not
            modified in this function.
        resume_snapshot: Snapshot selector for ``misc.locate_network_pkl``.
    """
    num_crossovers = 15  # crossover indices 0..14 along axis 1 of w

    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
    print('Loading networks from "%s"...' % network_pkl)
    G, _D, Gs = misc.load_pkl(network_pkl)

    latents_1_val = np.random.randn(1, *G.input_shape[1:])
    latents_2_val = np.random.randn(1, *G.input_shape[1:])

    latents_1 = tf.placeholder(tf.float32)
    labels_1 = tf.constant([[0, 0, 0, 0, 1, 0]])
    latents_2 = tf.placeholder(tf.float32)
    labels_2 = tf.constant([[0, 0, 0, 0, 1, 0]])

    w_1 = Gs.components.mapping.get_output_for(latents_1, labels_1, is_validation=True)
    w_2 = Gs.components.mapping.get_output_for(latents_2, labels_2, is_validation=True)

    fake_image_1_op = Gs.components.synthesis.get_output_for(w_1, is_validation=True, randomize_noise=False)
    fake_image_2_op = Gs.components.synthesis.get_output_for(w_2, is_validation=True, randomize_noise=False)

    # Every run feeds both latents, so the feed dict is built once.
    feed = {latents_1: latents_1_val, latents_2: latents_2_val}

    fake_image_1 = tflib.run(fake_image_1_op, feed_dict=feed)
    fake_image_2 = tflib.run(fake_image_2_op, feed_dict=feed)
    misc.save_image(fake_image_1[0], os.path.join(submit_config.run_dir, 'fake_image_1.png'), drange=[-1, 1])
    misc.save_image(fake_image_2[0], os.path.join(submit_config.run_dir, 'fake_image_2.png'), drange=[-1, 1])

    # First series takes the leading part from w_1, second series from w_2.
    for tag, w_head, w_tail in (('12', w_1, w_2), ('21', w_2, w_1)):
        for i in range(num_crossovers):
            w_mix = tf.concat([w_head[:, :i], w_tail[:, i:]], axis=1)
            fake_mix_op = Gs.components.synthesis.get_output_for(w_mix, is_validation=True, randomize_noise=False)
            fake_mix_image = tflib.run(fake_mix_op, feed_dict=feed)
            misc.save_image(
                fake_mix_image[0],
                os.path.join(submit_config.run_dir, 'fake_mix_{}_{}.png'.format(tag, i)),
                drange=[-1, 1])

    # Close the run context, as the other entry points in this file do.
    ctx.close()
def validate(submit_config: dnnlib.SubmitConfig, noise: dict, dataset: dict, network_snapshot: str):
    """Evaluate a stored network snapshot on the validation set.

    Args:
        submit_config: Submission config for the validation set and run context.
        noise: Keyword arguments for ``dnnlib.util.call_func_by_name``; the
            resulting object supplies ``add_validation_noise_np``.
        dataset: Keyword arguments for ``ValidationSet.load``.
        network_snapshot: Snapshot reference given to ``load_snapshot``.
    """
    augmenter = dnnlib.util.call_func_by_name(**noise)

    val_set = ValidationSet(submit_config)
    val_set.load(**dataset)

    # The run context manages the run's bookkeeping; it is closed at the end.
    run_ctx = dnnlib.RunContext(submit_config, config)

    # TensorFlow graph/session are initialised from the project-level config.
    tfutil.init_tf(config.tf_config)

    # Restore the network on the first GPU and evaluate it there.
    with tf.device("/gpu:0"):
        network = load_snapshot(network_snapshot)
        val_set.evaluate(network, 0, augmenter.add_validation_noise_np)

    run_ctx.close()
def train(submit_config: dnnlib.SubmitConfig, iteration_count: int, eval_interval: int,
          minibatch_size: int, learning_rate: float, ramp_down_perc: float, noise: dict,
          validation_config: dict, train_tfrecords: str, noise2noise: bool):
    """Train the denoising network and periodically evaluate it.

    Builds the network from ``config.net_config``, splits each input batch
    across ``submit_config.num_gpus`` devices, minimises a mean-squared error
    and runs ``iteration_count`` optimisation steps. Every ``eval_interval``
    iterations it saves example images, evaluates the validation set, prints a
    status line and writes summaries. A final snapshot is saved at the end.

    Args:
        submit_config: Submission config (``num_gpus``, ``run_dir``, ``run_id``).
        iteration_count: Number of training iterations.
        eval_interval: Number of iterations between status dumps/evaluations.
        minibatch_size: Passed to ``create_dataset``.
        learning_rate: Base learning rate for the optimizer schedule.
        ramp_down_perc: Passed to ``compute_ramped_down_lrate``.
        noise: Keyword arguments for ``dnnlib.util.call_func_by_name``; the
            resulting object supplies ``add_train_noise_tf`` and
            ``add_validation_noise_np``.
        validation_config: Keyword arguments for ``ValidationSet.load``.
        train_tfrecords: Passed to ``create_dataset``.
        noise2noise: If true the loss target is the noisy target, otherwise the
            clean target.
    """
    noise_augmenter = dnnlib.util.call_func_by_name(**noise)
    validation_set = ValidationSet(submit_config)
    validation_set.load(**validation_config)

    # Create a run context (hides low level details, exposes simple API to manage the run)
    # noinspection PyTypeChecker
    ctx = dnnlib.RunContext(submit_config, config)

    # Initialize TensorFlow graph and session using the project-level settings
    tfutil.init_tf(config.tf_config)

    dataset_iter = create_dataset(train_tfrecords, minibatch_size, noise_augmenter.add_train_noise_tf)

    # Construct the network using the Network helper class and config.net_config
    with tf.device("/gpu:0"):
        net = tflib.Network(**config.net_config)

    # Optionally print layer information
    net.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device("/cpu:0"):
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])

        noisy_input, noisy_target, clean_target = dataset_iter.get_next()
        noisy_input_split = tf.split(noisy_input, submit_config.num_gpus)
        noisy_target_split = tf.split(noisy_target, submit_config.num_gpus)
        clean_target_split = tf.split(clean_target, submit_config.num_gpus)

    # Gradients from every GPU are registered on one shared optimizer
    opt = tflib.Optimizer(learning_rate=lrate_in, **config.optimizer_config)

    for gpu in range(submit_config.num_gpus):
        with tf.device("/gpu:%d" % gpu):
            net_gpu = net if gpu == 0 else net.clone()

            denoised = net_gpu.get_output_for(noisy_input_split[gpu])

            target_split = noisy_target_split if noise2noise else clean_target_split
            meansq_error = tf.reduce_mean(tf.square(target_split[gpu] - denoised))

            # Tie the "Loss" autosummary to gradient registration on this GPU
            with tf.control_dependencies([autosummary("Loss", meansq_error)]):
                opt.register_gradients(meansq_error, net_gpu.trainables)

    train_step = opt.apply_updates()

    # Create a log file for Tensorboard
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    summary_log.add_graph(tf.get_default_graph())

    print('Training...')
    time_maintenance = ctx.get_time_since_last_update()
    ctx.update(loss='run %d' % submit_config.run_id, cur_epoch=0, max_epoch=iteration_count)

    # ***********************************
    # The actual training loop
    for i in range(iteration_count):
        # Whether to stop the training or not should be asked from the context
        if ctx.should_stop():
            break

        # Dump training status
        if i % eval_interval == 0:
            time_train = ctx.get_time_since_last_update()
            time_total = ctx.get_time_since_start()

            # Draw one batch of inputs and save input / target / prediction examples
            [source_mb, target_mb] = tfutil.run([noisy_input, clean_target])
            denoised = net.run(source_mb)
            save_image(submit_config, denoised[0], "img_{0}_y_pred.png".format(i))
            save_image(submit_config, target_mb[0], "img_{0}_y.png".format(i))
            save_image(submit_config, source_mb[0], "img_{0}_x_aug.png".format(i))

            validation_set.evaluate(net, i, noise_augmenter.add_validation_noise_np)

            sec_per_iter = time_train / eval_interval
            eta_sec = sec_per_iter * (iteration_count - i)

            # The ETA gets its own tag. It was previously reported under
            # 'Timing/total_sec', sharing that tag with the real total time.
            print('iter %-10d time %-12s eta %-12s sec/eval %-7.1f sec/iter %-7.2f maintenance %-6.1f' % (
                autosummary('Timing/iter', i),
                dnnlib.util.format_time(autosummary('Timing/total_sec', time_total)),
                dnnlib.util.format_time(autosummary('Timing/eta_sec', eta_sec)),
                autosummary('Timing/sec_per_eval', time_train),
                autosummary('Timing/sec_per_iter', sec_per_iter),
                autosummary('Timing/maintenance_sec', time_maintenance)))

            dnnlib.tflib.autosummary.save_summaries(summary_log, i)
            ctx.update(loss='run %d' % submit_config.run_id, cur_epoch=i, max_epoch=iteration_count)
            time_maintenance = ctx.get_last_update_interval() - time_train

        # One optimisation step with the scheduled learning rate
        lrate = compute_ramped_down_lrate(i, iteration_count, ramp_down_perc, learning_rate)
        tfutil.run([train_step], {lrate_in: lrate})

    # End of training
    print("Elapsed time: {0}".format(util.format_time(ctx.get_time_since_start())))
    save_snapshot(submit_config, net, 'final')

    # Summary log and context should be closed at the end
    summary_log.close()
    ctx.close()
total_kimg = 15000, # Total length of the training, measured in thousands of real images. mirror_augment = False, # Enable mirror augment? drange_net = [-1,1], # Dynamic range used when feeding image data to the networks. image_snapshot_ticks = 1, # How often to export image snapshots? network_snapshot_ticks = 1, # How often to export network snapshots? save_tf_graph = False, # Include full TensorFlow computation graph in the tfevents file? save_weight_histograms = False, # Include weight histograms in the tfevents file? resume_run_id = None, # Run ID or network pkl to resume training from, None = start from scratch. #resume_run_id = 'results/00011-sgan-custom_datasets-1gpu/network-snapshot-005645.pkl', resume_snapshot = None, # Snapshot index to resume training from, None = autodetect. resume_kimg = 0.0 # Assumed training progress at the beginning. Affects reporting and training schedule. #resume_kimg = 5645, resume_time = 0.0): # Assumed wallclock time at the beginning. Affects reporting. # Initialize dnnlib and TensorFlow. ctx = dnnlib.RunContext(submit_config, train) tflib.init_tf(tf_config) # Load training set. training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args) # Construct networks. with tf.device('/gpu:0'): if resume_run_id is not None: network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot) print('Loading networks from "%s"...' % network_pkl) G, D, Gs = misc.load_pkl(network_pkl) else: print('Constructing networks...') G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args) D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args)
def joint_train(
    submit_config,
    opt,
    metric_arg_list,
    sched_args = {},                # Options for the training schedule.
    grid_args = {},                 # Options for setup_snapshot_image_grid().
    dataset_args = {},              # Options for the dataset.
    total_kimg = 15000,             # Total length of the training, in thousands of real images.
    drange_net = [-1,1],            # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks = 1,       # How often to export image snapshots?
    network_snapshot_ticks = 10,    # How often to export network snapshots?
    D_repeats = 1,                  # Not referenced in this function body.
    minibatch_repeats = 4,          # Number of minibatches to run before adjusting training parameters.
    mirror_augment = False,         # Not referenced in this function body.
    reset_opt_for_new_lod = True,   # Reset optimizer internal state when the level of detail changes?
    save_tf_graph = False,          # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms = False, # Include weight histograms in the tfevents file?
    resume_run_id = None,           # Not referenced in this function body.
    resume_snapshot = None,         # Not referenced in this function body.
    resume_kimg = 0.0,              # Assumed training progress at the beginning.
    resume_time = 0.0,              # Assumed wallclock time at the beginning. Affects reporting.
    *args, **kwargs
    ):
    """Jointly train the graph model selected by ``opt`` and log progress.

    Builds the model chosen by ``opt.model`` / ``opt.transform``, then repeats
    training steps until ``total_kimg * 1000`` images have been counted. Once
    per tick it reports progress and, depending on the snapshot intervals,
    saves image grids, network pickles (and runs the metrics on them) and
    TensorFlow checkpoints. At the end it writes ``network-final.pkl`` and the
    recorded loss values (``loss_values.npy`` / ``loss_values.png``).

    Args:
        submit_config: Submission config (``num_gpus``, ``run_dir``).
        opt: Options object; ``output_dir``, ``model``, ``transform`` and
            ``model_save_freq`` are read here and the object is passed to
            ``util.set_graph_kwargs``.
        metric_arg_list: Passed to ``metric_base.MetricGroup``.
        *args: Accepted but not used.
        **kwargs: Forwarded unchanged to the model constructor.

    The remaining keyword parameters are described next to their defaults.
    """
    output_dir = opt.output_dir
    graph_kwargs = util.set_graph_kwargs(opt)
    graph_util = importlib.import_module('graphs.' + opt.model + '.graph_util')
    constants = importlib.import_module('graphs.' + opt.model + '.constants')
    model = graphs.find_model_using_name(opt.model, opt.transform)
    g = model(submit_config=submit_config, dataset_args=dataset_args, **graph_kwargs, **kwargs)
    g.initialize_graph()

    # `tf_config` is handed to metrics.run() below but is not a parameter of
    # this function. Keep using a module-level `tf_config` if one exists;
    # otherwise fall back to a value supplied through **kwargs (or an empty
    # config) instead of raising NameError at the first network snapshot.
    try:
        metric_tf_config = tf_config  # noqa: F821
    except NameError:
        metric_tf_config = kwargs.get('tf_config', {})

    w_snapshot_ticks = opt.model_save_freq
    ctx = dnnlib.RunContext(submit_config, train)
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args)

    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)

    # Set up the snapshot image grid.
    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(g.G, training_set, **grid_args)
    sched = training_loop.training_schedule(cur_nimg=total_kimg*1000, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args)
    grid_fakes = g.Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch//submit_config.num_gpus)

    # Set up the run directory.
    print('Setting up run dir...')
    misc.save_image_grid(grid_reals, os.path.join(submit_config.run_dir, 'reals.png'), drange=training_set.dynamic_range, grid_size=grid_size)
    misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg), drange=drange_net, grid_size=grid_size)
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        g.G.setup_weight_histograms()
        g.D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    # Train.
    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    loss_values = []
    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop():
            break

        # Choose training parameters and configure the training set.
        sched = training_loop.training_schedule(cur_nimg=cur_nimg, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod):
                # Only the generator optimizer is reset here.
                g.G_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            alpha_for_graph, alpha_for_target = g.get_train_alpha(constants.BATCH_SIZE)
            if not isinstance(alpha_for_graph, list):
                alpha_for_graph = [alpha_for_graph]
                alpha_for_target = [alpha_for_target]
            for ag, at in zip(alpha_for_graph, alpha_for_target):
                # The same dict first produces the unmodified outputs and is
                # then extended with the training inputs.
                feed_dict = graph_util.graph_input(g, constants.BATCH_SIZE, seed=0)
                out_zs = g.sess.run(g.outputs_orig, feed_dict)
                target_fn, mask_out = g.get_target_np(out_zs, at)
                feed_dict[g.alpha] = ag
                feed_dict[g.target] = target_fn
                feed_dict[g.mask] = mask_out
                feed_dict[g.lod_in] = sched.lod
                feed_dict[g.lrate_in] = sched.D_lrate
                feed_dict[g.minibatch_in] = sched.minibatch
                step_results = g.sess.run([g.joint_loss, g.train_step, g.Gs_update_op, g.G_train_op], feed_dict=feed_dict)
                curr_loss = step_results[0]
                loss_values.append(curr_loss)
                # NOTE(review): the collapsed source does not show whether this
                # increment belongs to the per-alpha loop or to the enclosing
                # minibatch-repeat loop; the two only differ when
                # get_train_alpha() returns lists. Confirm against the
                # original layout.
                cur_nimg += sched.minibatch

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print('tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f' % (
                autosummary('Progress/tick', cur_tick),
                autosummary('Progress/kimg', cur_nimg / 1000.0),
                autosummary('Progress/lod', sched.lod),
                autosummary('Progress/minibatch', sched.minibatch),
                dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                autosummary('Timing/sec_per_tick', tick_time),
                autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                autosummary('Timing/maintenance_sec', maintenance_time),
                autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = g.Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch//submit_config.num_gpus)
                misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size)
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(submit_config.run_dir, 'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((g.G, g.D, g.Gs), pkl)
                metrics.run(pkl, run_dir=submit_config.run_dir, num_gpus=submit_config.num_gpus, tf_config=metric_tf_config)
            if cur_tick % w_snapshot_ticks == 0 or done:
                g.saver.save(
                    g.sess,
                    './{}/model_{}.ckpt'.format(output_dir, (cur_nimg // 1000)),
                    write_meta_graph=False,
                    write_state=False)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    misc.save_pkl((g.G, g.D, g.Gs), os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()
    ctx.close()

    # Persist the recorded loss values as an array and as a plot.
    loss_values = np.array(loss_values)
    np.save('./{}/loss_values.npy'.format(output_dir), loss_values)
    f, ax = plt.subplots(figsize=(10, 4))
    ax.plot(loss_values)
    f.savefig('./{}/loss_values.png'.format(output_dir))
def training_loop(
        submit_config,
        G_args={},  # Options for generator network.
        D_args={},  # Options for discriminator network.
        G_opt_args={},  # Options for generator optimizer.
        D_opt_args={},  # Options for discriminator optimizer.
        G_loss_args={},  # Options for generator loss.
        D_loss_args={},  # Options for discriminator loss.
        dataset_args={},  # Options for dataset.load_dataset().
        sched_args={},  # Options for train.TrainingSchedule.
        grid_args={},  # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],  # Options for MetricGroup.
        tf_config={},  # Options for tflib.init_tf().
        G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
        D_repeats=1,  # How many times the discriminator is trained per G iteration.
        minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
        total_kimg=15000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=False,  # Enable mirror augment?
        drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=1,  # How often to export image snapshots?
        network_snapshot_ticks=10,  # How often to export network snapshots?
        save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
        resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
        resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
):
    """Train G and D with a fixed schedule and multi-scale image snapshots.

    Loads or constructs the networks, builds the per-GPU loss graph, then
    alternates discriminator and generator updates until
    ``total_kimg * 1000`` images have been counted. The schedule is computed
    once before the loop. Once per tick it reports progress and, depending on
    the snapshot intervals, exports fake-image grids (one file per entry of
    the ``Gs.run`` result, each in its own ``<size>x<size>`` sub-directory)
    and network pickles, on which the metrics are run. The final networks
    are written to ``network-final.pkl``.
    """
    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args)

    # Construct networks: either restore all three from a pickle or build them.
    with tf.device("/gpu:0"):
        if resume_run_id is None:
            print("Constructing networks...")
            G = tflib.Network(
                name="G",
                num_inputs=2,  # one for latents and one for labels
                num_channels=training_set.shape[0],
                resolution=training_set.shape[1],
                label_size=training_set.label_size,
                **G_args)
            D = tflib.Network(
                name="D",
                num_inputs=int(np.log2(training_set.shape[1])) - 1 + 1,  # +1 for labels
                num_channels=training_set.shape[0],
                resolution=training_set.shape[1],
                label_size=training_set.label_size,
                **D_args)
            Gs = G.clone("Gs")
        else:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
    G.print_layers()
    D.print_layers()

    print("Building TensorFlow graph...")
    with tf.name_scope("Inputs"), tf.device("/cpu:0"):
        lrate_in = tf.placeholder(tf.float32, name="lrate_in", shape=[])
        minibatch_in = tf.placeholder(tf.int32, name="minibatch_in", shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        if G_smoothing_kimg > 0.0:
            Gs_beta = 0.5**tf.div(tf.cast(minibatch_in, tf.float32), G_smoothing_kimg * 1000.0)
        else:
            Gs_beta = 0.0

    G_opt = tflib.Optimizer(name="TrainG", learning_rate=lrate_in, **G_opt_args)
    D_opt = tflib.Optimizer(name="TrainD", learning_rate=lrate_in, **D_opt_args)
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope("GPU%d" % gpu), tf.device("/gpu:%d" % gpu):
            # GPU 0 uses the networks themselves; other GPUs use clones.
            if gpu == 0:
                G_gpu, D_gpu = G, D
            else:
                G_gpu = G.clone(G.name + "_shadow")
                D_gpu = D.clone(D.name + "_shadow")

            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(
                reals,
                mirror_augment,
                training_set.dynamic_range,
                drange_net,
                depth=training_set.resolution_log2 - 1,
            )

            with tf.name_scope("G_loss"):
                G_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu, D=D_gpu, opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    **G_loss_args)
            with tf.name_scope("D_loss"):
                D_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu, D=D_gpu, opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals, labels=labels,
                    **D_loss_args)

            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)

    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)

    with tf.device("/gpu:0"):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)

    # Training parameters are chosen once, for the end-of-training image count.
    sched = training_schedule(cur_nimg=total_kimg * 1000, training_set=training_set, **sched_args)

    print("Setting up snapshot image grid...")
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(
        G, training_set, **grid_args)
    grid_fakes = Gs.run(
        grid_latents,
        grid_labels,
        is_validation=True,
        minibatch_size=sched.minibatch_size // submit_config.num_gpus,
    )

    print("Setting up run dir...")
    # One output directory per entry of grid_fakes, named "<size>x<size>" from 4x4 upwards.
    fake_multi_scale_dirs = []
    for res in range(2, 2 + len(grid_fakes)):
        scale_name = str(2**res) + "x" + str(2**res)
        fake_multi_scale_dirs.append(os.path.join(submit_config.run_dir, scale_name))

    misc.save_image_grid(
        grid_reals,
        os.path.join(submit_config.run_dir, "reals.png"),
        drange=training_set.dynamic_range,
        grid_size=grid_size,
    )
    initial_fake_paths = [
        os.path.join(scale_dir, "fakes%06d.png" % resume_kimg)
        for scale_dir in fake_multi_scale_dirs
    ]
    misc.save_image_grids(grid_fakes, initial_fake_paths, drange=drange_net, grid_size=grid_size)

    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print("Training...\n")
    ctx.update("", cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg

    # configure the training_set to a proper minibatch size
    training_set.configure(sched.minibatch_size // submit_config.num_gpus)

    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop():
            break

        # Run training ops: D_repeats discriminator steps, then one generator step.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run(
                    [D_train_op, Gs_update_op],
                    {lrate_in: sched.D_lrate, minibatch_in: sched.minibatch_size},
                )
                cur_nimg += sched.minibatch_size
            tflib.run(
                [G_train_op],
                {lrate_in: sched.G_lrate, minibatch_in: sched.minibatch_size},
            )

        # Perform maintenance tasks once per tick.
        done = cur_nimg >= total_kimg * 1000
        tick_elapsed = cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000
        if not (tick_elapsed or done):
            continue

        cur_tick += 1
        tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
        tick_start_nimg = cur_nimg
        tick_time = ctx.get_time_since_last_update()
        total_time = ctx.get_time_since_start() + resume_time

        # Report progress.
        progress_fields = (
            autosummary("Progress/tick", cur_tick),
            autosummary("Progress/kimg", cur_nimg / 1000.0),
            autosummary("Progress/minibatch", sched.minibatch_size),
            dnnlib.util.format_time(autosummary("Timing/total_sec", total_time)),
            autosummary("Timing/sec_per_tick", tick_time),
            autosummary("Timing/sec_per_kimg", tick_time / tick_kimg),
            autosummary("Timing/maintenance_sec", maintenance_time),
            autosummary("Resources/peak_gpu_mem_gb", peak_gpu_mem_op.eval() / 2**30),
        )
        print(
            "tick %-5d kimg %-8.1f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f"
            % progress_fields)
        autosummary("Timing/total_hours", total_time / (60.0 * 60.0))
        autosummary("Timing/total_days", total_time / (24.0 * 60.0 * 60.0))

        # Save snapshots.
        if cur_tick % image_snapshot_ticks == 0 or done:
            grid_fakes = Gs.run(
                grid_latents,
                grid_labels,
                is_validation=True,
                minibatch_size=sched.minibatch_size // submit_config.num_gpus,
            )
            fake_paths = [
                os.path.join(scale_dir, "fakes%06d.png" % (cur_nimg // 1000))
                for scale_dir in fake_multi_scale_dirs
            ]
            misc.save_image_grids(grid_fakes, fake_paths, drange=drange_net, grid_size=grid_size)

        if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
            pkl = os.path.join(
                submit_config.run_dir,
                "network-snapshot-%06d.pkl" % (cur_nimg // 1000),
            )
            misc.save_pkl((G, D, Gs), pkl)
            metrics.run(
                pkl,
                run_dir=submit_config.run_dir,
                num_gpus=submit_config.num_gpus,
                tf_config=tf_config,
            )

        # Update summaries and RunContext.
        metrics.update_autosummaries()
        tflib.autosummary.save_summaries(summary_log, cur_nimg)
        ctx.update(cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
        maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    misc.save_pkl((G, D, Gs), os.path.join(submit_config.run_dir, "network-final.pkl"))
    summary_log.close()
    ctx.close()
def training_loop(
    submit_config,
    G_args                  = {},       # Options for generator network.
    D_args                  = {},       # Options for discriminator network.
    G_opt_args              = {},       # Options for generator optimizer.
    D_opt_args              = {},       # Options for discriminator optimizer.
    G_loss_args             = {},       # Options for generator loss.
    D_loss_args             = {},       # Options for discriminator loss.
    dataset_args            = {},       # Options for the dataset.
    sched_args              = {},       # Options for the training schedule.
    grid_args               = {},       # Options for setup_snapshot_image_grid().
    metric_arg_list         = [],       # Options for the metrics.
    tf_config               = {},       # Options for tflib.init_tf().
    G_smoothing_kimg        = 10.0,     # Half-life of the running average of generator weights.
    D_repeats               = 1,        # How many times the discriminator is trained per G iteration.
    minibatch_repeats       = 4,        # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod   = True,     # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg              = 15000,    # Total length of the training, measured in thousands of real images.
    mirror_augment          = False,    # Enable mirror augment?
    drange_net              = [-1,1],   # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks    = 1,        # How often to export image snapshots?
    network_snapshot_ticks  = 10,       # How often to export network snapshots?
    save_tf_graph           = False,    # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms  = False,    # Include weight histograms in the tfevents file?
    resume_run_id           = None,     # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot         = None,     # Snapshot index to resume training from, None = autodetect.
    resume_kimg             = 0.0,      # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time             = 0.0):     # Assumed wallclock time at the beginning. Affects reporting.
    """Train G and D with a schedule that is re-evaluated every iteration.

    Loads or constructs the networks, builds the per-GPU loss graph (with the
    ``lod`` variables of both networks assigned from a placeholder), then
    alternates discriminator and generator updates until
    ``total_kimg * 1000`` images have been counted. Once per tick it reports
    progress and, depending on the snapshot intervals, exports a fake-image
    grid and a network pickle on which the metrics are run. The final
    networks are written to ``network-final.pkl``.
    """
    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args)

    # Construct networks: either restore all three from a pickle or build them.
    with tf.device('/gpu:0'):
        if resume_run_id is None:
            print('Constructing networks...')
            G = tflib.Network(
                'G',
                num_channels=training_set.shape[0],
                resolution=training_set.shape[1],
                label_size=training_set.label_size,
                **G_args)
            D = tflib.Network(
                'D',
                num_channels=training_set.shape[0],
                resolution=training_set.shape[1],
                label_size=training_set.label_size,
                **D_args)
            Gs = G.clone('Gs')
        else:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
    G.print_layers()
    D.print_layers()

    # Build the computation graph and the optimizers.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        # Placeholders are declared here and receive values when the ops run.
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        if G_smoothing_kimg > 0.0:
            Gs_beta = 0.5 ** tf.div(tf.cast(minibatch_in, tf.float32), G_smoothing_kimg * 1000.0)
        else:
            Gs_beta = 0.0

    G_opt = tflib.Optimizer(name='TrainG', learning_rate=lrate_in, **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', learning_rate=lrate_in, **D_opt_args)
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # GPU 0 uses the networks themselves; other GPUs use clones.
            if gpu == 0:
                G_gpu, D_gpu = G, D
            else:
                G_gpu = G.clone(G.name + '_shadow')
                D_gpu = D.clone(D.name + '_shadow')

            # Both losses are built under a control dependency on these assignments.
            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in),
            ]
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(reals, lod_in, mirror_augment, training_set.dynamic_range, drange_net)

            with tf.name_scope('G_loss'), tf.control_dependencies(lod_assign_ops):
                G_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu, D=D_gpu, opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    **G_loss_args)
            with tf.name_scope('D_loss'), tf.control_dependencies(lod_assign_ops):
                D_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu, D=D_gpu, opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals, labels=labels,
                    **D_loss_args)

            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)

    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)

    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)

    # Set up the snapshot image grid.
    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(G, training_set, **grid_args)
    sched = training_schedule(cur_nimg=total_kimg*1000, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args)
    grid_fakes = Gs.run(
        grid_latents, grid_labels,
        is_validation=True,
        minibatch_size=sched.minibatch//submit_config.num_gpus)

    # Set up the run directory.
    print('Setting up run dir...')
    misc.save_image_grid(
        grid_reals,
        os.path.join(submit_config.run_dir, 'reals.png'),
        drange=training_set.dynamic_range, grid_size=grid_size)
    misc.save_image_grid(
        grid_fakes,
        os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg),
        drange=drange_net, grid_size=grid_size)
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    # Train.
    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus, sched.lod)
        if reset_opt_for_new_lod and (
                np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod)):
            G_opt.reset_optimizer_state()
            D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops: D_repeats discriminator steps, then one generator step.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run(
                    [D_train_op, Gs_update_op],
                    {lod_in: sched.lod, lrate_in: sched.D_lrate, minibatch_in: sched.minibatch})
                cur_nimg += sched.minibatch
            tflib.run(
                [G_train_op],
                {lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_in: sched.minibatch})

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        tick_elapsed = cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000
        if not (tick_elapsed or done):
            continue

        cur_tick += 1
        tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
        tick_start_nimg = cur_nimg
        tick_time = ctx.get_time_since_last_update()
        total_time = ctx.get_time_since_start() + resume_time

        # Report progress.
        progress_fields = (
            autosummary('Progress/tick', cur_tick),
            autosummary('Progress/kimg', cur_nimg / 1000.0),
            autosummary('Progress/lod', sched.lod),
            autosummary('Progress/minibatch', sched.minibatch),
            dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
            autosummary('Timing/sec_per_tick', tick_time),
            autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
            autosummary('Timing/maintenance_sec', maintenance_time),
            autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30),
        )
        print('tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f' % progress_fields)
        autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
        autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

        # Save snapshots.
        if cur_tick % image_snapshot_ticks == 0 or done:
            grid_fakes = Gs.run(
                grid_latents, grid_labels,
                is_validation=True,
                minibatch_size=sched.minibatch//submit_config.num_gpus)
            misc.save_image_grid(
                grid_fakes,
                os.path.join(submit_config.run_dir, 'fakes%06d.png' % (cur_nimg // 1000)),
                drange=drange_net, grid_size=grid_size)
        if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
            pkl = os.path.join(submit_config.run_dir, 'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
            misc.save_pkl((G, D, Gs), pkl)
            metrics.run(pkl, run_dir=submit_config.run_dir, num_gpus=submit_config.num_gpus, tf_config=tf_config)

        # Update summaries and RunContext.
        metrics.update_autosummaries()
        tflib.autosummary.save_summaries(summary_log, cur_nimg)
        ctx.update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
        maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    misc.save_pkl((G, D, Gs), os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()
    ctx.close()
def training_loop(
        submit_config,
        G_args={},  # Options for generator network.
        D_args={},  # Options for discriminator network.
        G_opt_args={},  # Options for generator optimizer (not read by this variant).
        D_opt_args={},  # Options for discriminator optimizer (not read by this variant).
        G_loss_args={},  # Options for generator loss.
        D_loss_args={},  # Options for discriminator loss.
        dataset_args={},  # Options for dataset.load_dataset().
        sched_args={},  # Options for train.TrainingSchedule.
        grid_args={},  # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],  # Options for MetricGroup.
        tf_config={},  # Options for tflib.init_tf().
        G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
        D_repeats=1,  # How many times the discriminator is trained per G iteration.
        minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
        reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg=15000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=False,  # Enable mirror augment?
        drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=10,  # How often to export image snapshots?
        network_snapshot_ticks=10,  # How often to export network snapshots?
        save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_run_id=None,  # Run ID or network pkl to resume training from (not read by this variant).
        resume_snapshot=None,  # Snapshot index to resume training from (not read by this variant).
        resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0):  # Assumed wallclock time at the beginning. Affects reporting.
    """Run the GAN training loop as one Horovod worker process.

    Each process builds a single-replica graph (there is no per-GPU tower
    loop), wraps two hard-coded Adam optimizers in
    ``hvd.DistributedOptimizer`` and, right after initialising all global
    variables, runs ``hvd.broadcast_global_variables(0)``.  Image grids,
    network snapshots, TensorBoard summaries and the final pickle are written
    by rank 0 only.

    Notes:
        * ``G_opt_args``, ``D_opt_args``, ``resume_run_id`` and
          ``resume_snapshot`` are accepted but never read: the networks are
          always constructed from scratch and both optimizers are
          ``AdamOptimizer(beta1=0.0, beta2=0.99, epsilon=1e-8)``.
        * Horovod is expected to be initialised by the caller; this function
          only queries ``hvd.size()`` / ``hvd.rank()``.
        * NOTE(review): this function was recovered from a source whose
          indentation had been lost.  The extent of the rank-0 guards
          (run-dir setup, tick reporting, snapshots) was reconstructed so
          that every write into ``submit_config.run_dir`` happens on rank 0
          only -- confirm against the intended behaviour.

    Returns:
        None.
    """
    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    # TODO(ajay): consider moving TF init to after graph creation.
    tflib.init_tf(tf_config)

    # Evaluate the rank check once; it gates every run-dir write below.
    is_chief = hvd.rank() == 0

    # Load training set; the worker count and this worker's rank are passed
    # to the loader (presumably so every worker reads its own shard -- TODO confirm).
    print('ajay - config data dir', config.data_dir)
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        num_hosts=hvd.size(),
                                        index=hvd.rank(),
                                        **dataset_args)

    # Construct networks (always from scratch; resuming is not wired up here).
    print('Constructing networks...')
    G = tflib.Network('G',
                      num_channels=training_set.shape[0],
                      resolution=training_set.shape[1],
                      label_size=training_set.label_size,
                      **G_args)
    D = tflib.Network('D',
                      num_channels=training_set.shape[0],
                      resolution=training_set.shape[1],
                      label_size=training_set.label_size,
                      **D_args)
    Gs = G.clone('Gs')
    G.print_layers()
    D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_in, tf.float32),
                              G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0

    G_opt = tf.train.AdamOptimizer(learning_rate=lrate_in, beta1=0.0, beta2=0.99, epsilon=1e-8)
    G_opt = hvd.DistributedOptimizer(G_opt)
    D_opt = tf.train.AdamOptimizer(learning_rate=lrate_in, beta1=0.0, beta2=0.99, epsilon=1e-8)
    D_opt = hvd.DistributedOptimizer(D_opt)

    # Single replica per process: the "per-GPU" networks are the networks themselves.
    G_gpu = G
    D_gpu = D
    lod_assign_ops = [
        tf.assign(G_gpu.find_var('lod'), lod_in),
        tf.assign(D_gpu.find_var('lod'), lod_in)
    ]
    # TODO(ajay): check that a unique minibatch per worker is guaranteed, i.e. that sharding is done right.
    reals, labels = training_set.get_minibatch_tf()
    reals = process_reals(reals, lod_in, mirror_augment, training_set.dynamic_range, drange_net)
    with tf.name_scope('G_loss'), tf.control_dependencies(lod_assign_ops):
        G_loss = dnnlib.util.call_func_by_name(G=G_gpu,
                                               D=D_gpu,
                                               opt=G_opt,
                                               training_set=training_set,
                                               minibatch_size=minibatch_split,
                                               **G_loss_args)
    with tf.name_scope('D_loss'), tf.control_dependencies(lod_assign_ops):
        D_loss = dnnlib.util.call_func_by_name(G=G_gpu,
                                               D=D_gpu,
                                               opt=D_opt,
                                               training_set=training_set,
                                               minibatch_size=minibatch_split,
                                               reals=reals,
                                               labels=labels,
                                               **D_loss_args)
    G_grads = G_opt.compute_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
    D_grads = D_opt.compute_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_gradients(G_grads)
    D_train_op = D_opt.apply_gradients(D_grads)

    # Horovod: initialise every global variable (including the optimizer
    # slots created above), then run the broadcast from rank 0.
    # tf.global_variables_initializer() is the non-deprecated spelling of
    # tf.initialize_all_variables().
    init_op = tf.global_variables_initializer()
    bcast_op = hvd.broadcast_global_variables(0)
    tf.get_default_session().run([init_op])
    tflib.run([bcast_op])

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(
        G, training_set, **grid_args)
    # TODO(ajay): num_gpus needs to change to the Horovod size when going multi-node.
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              num_gpus=submit_config.num_gpus,
                              **sched_args)
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch // submit_config.num_gpus)

    if is_chief:
        print('Setting up run dir...')
        misc.save_image_grid(grid_reals,
                             os.path.join(submit_config.run_dir, 'reals.png'),
                             drange=training_set.dynamic_range,
                             grid_size=grid_size)
        misc.save_image_grid(grid_fakes,
                             os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg),
                             drange=drange_net,
                             grid_size=grid_size)
        # The summary writer exists on rank 0 only; every later use is guarded by is_chief.
        summary_log = tf.summary.FileWriter(submit_config.run_dir)
        if save_tf_graph:
            summary_log.add_graph(tf.get_default_graph())
        if save_weight_histograms:
            G.setup_weight_histograms()
            D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    while cur_nimg < (total_kimg * 1000):
        if ctx.should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  num_gpus=submit_config.num_gpus,
                                  **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus, sched.lod)

        # The wrapped tf.train optimizers have no reset_optimizer_state(), so
        # their state is reset by re-running the initializers of the
        # optimizer variables whenever the integer part of lod changes.
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod):
                tflib.assert_tf_initialized()
                G_opt_reset_op = [var.initializer for var in G_opt.variables()]
                D_opt_reset_op = [var.initializer for var in D_opt.variables()]
                tflib.run(G_opt_reset_op)
                tflib.run(D_opt_reset_op)
        prev_lod = sched.lod

        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch
                    })
                # Progress is counted in full minibatches (not divided by num_gpus).
                cur_nimg += sched.minibatch
            tflib.run(
                [G_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch
                })

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= (total_kimg * 1000))
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            if is_chief:
                # Report progress (peak GPU memory is not reported in this variant).
                print(
                    'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f '
                    % (autosummary('Progress/tick', cur_tick),
                       autosummary('Progress/kimg', cur_nimg / 1000.0),
                       autosummary('Progress/lod', sched.lod),
                       autosummary('Progress/minibatch', sched.minibatch),
                       dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                       autosummary('Timing/sec_per_tick', tick_time),
                       autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                       autosummary('Timing/maintenance_sec', maintenance_time)))
                autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
                autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

                # Save snapshots.
                if cur_tick % image_snapshot_ticks == 0 or done:
                    grid_fakes = Gs.run(grid_latents,
                                        grid_labels,
                                        is_validation=True,
                                        minibatch_size=sched.minibatch // submit_config.num_gpus)
                    misc.save_image_grid(grid_fakes,
                                         os.path.join(submit_config.run_dir,
                                                      'fakes%06d.png' % (cur_nimg // 1000)),
                                         drange=drange_net,
                                         grid_size=grid_size)
                if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                    pkl = os.path.join(submit_config.run_dir,
                                       'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                    misc.save_pkl((G, D, Gs), pkl)
                    # Metrics are evaluated with num_gpus=1 regardless of the training setup.
                    metrics.run(pkl,
                                run_dir=submit_config.run_dir,
                                num_gpus=1,
                                tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            if is_chief:
                tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    if is_chief:
        misc.save_pkl((G, D, Gs), os.path.join(submit_config.run_dir, 'network-final.pkl'))
        summary_log.close()
    ctx.close()
def training_loop(
        submit_config,
        HP_args={},  # Options for the Hessian Penalty.
        G_args={},  # Options for generator network.
        D_args={},  # Options for discriminator network.
        G_opt_args={},  # Options for generator optimizer.
        D_opt_args={},  # Options for discriminator optimizer.
        G_loss_args={},  # Options for generator loss.
        D_loss_args={},  # Options for discriminator loss.
        dataset_args={},  # Options for dataset.load_dataset().
        sched_args={},  # Options for train.TrainingSchedule.
        grid_args={},  # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],  # Options for MetricGroup.
        tf_config={},  # Options for tflib.init_tf().
        G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
        D_repeats=1,  # How many times the discriminator is trained per G iteration.
        minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
        reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg=15000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=False,  # Enable mirror augment?
        drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=1,  # How often to export image snapshots?
        interp_snapshot_ticks=20,  # How often to generate interpolation visualizations in TensorBoard?
        network_snapshot_ticks=5,  # How often to export network snapshots?
        network_metric_ticks=5,  # How often to evaluate network snapshots on specified metrics?
        save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
        resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
        resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0):  # Assumed wallclock time at the beginning. Affects reporting.
    """Train G/D with a Hessian Penalty term and an optional InfoGAN term.

    Builds one tower per GPU in ``submit_config.num_gpus``, feeds the current
    Hessian Penalty weight to the generator loss through the
    ``hessian_weight`` placeholder, and adds ``infogan_lambda * mutual_info``
    to both the generator and the discriminator objectives.  Once per tick it
    reports progress, writes fake-image grids for both ``Gs`` and ``G``,
    optionally renders interpolation GIFs, saves network snapshots and runs
    the configured metrics.

    Notes:
        * ``HP_args``, ``G_args``, ``D_args`` and ``dataset_args`` are
          accessed by attribute (``HP_args.hp_lambda``,
          ``G_args.infogan_lambda``, ``D_args.infogan_nz``,
          ``metrics_dataset_args.shuffle_mb = 0``), so the plain-``dict``
          defaults are placeholders only; callers must pass attribute-style
          dictionaries (presumably ``dnnlib.EasyDict`` -- confirm).
        * Metrics are evaluated on the snapshot written in the same tick, so
          they only run on ticks that satisfy both ``network_snapshot_ticks``
          and ``network_metric_ticks`` (or on the first / last tick).
        * The final pickle is named ``network-snapshot-<total_kimg>.pkl``.
        * NOTE(review): this function was recovered from a source whose
          indentation had been lost; block nesting was reconstructed from
          syntax and should be confirmed.

    Returns:
        None.
    """
    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args)

    # Metrics get their own copy of dataset_args with shuffle_mb set to 0.
    metrics_dataset_args = deepcopy(dataset_args)
    metrics_dataset_args.shuffle_mb = 0

    print('Saving interp videos every %s ticks' % interp_snapshot_ticks)
    print('Saving network snapshot every %s ticks' % network_snapshot_ticks)

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...')
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **G_args)
            D = tflib.Network('D',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **D_args)
            Gs = G.clone('Gs')

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_in, tf.float32),
                              G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0
        # The loss weighting of the Hessian Penalty can be dynamic over training, so we need to use a placeholder:
        hessian_weight = tf.placeholder(tf.float32, name='hessian_weight', shape=[])

    G_opt = tflib.Optimizer(name='TrainG', learning_rate=lrate_in, **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', learning_rate=lrate_in, **D_opt_args)

    # Hessian Penalty / InfoGAN loss tensors, fetched together with the
    # generator step so their values can be registered in TensorBoard.
    reg_ops = []
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in)
            ]
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(reals, lod_in, mirror_augment, training_set.dynamic_range,
                                  drange_net)
            with tf.name_scope('G_loss'), tf.control_dependencies(lod_assign_ops):
                G_loss, G_hessian_penalty = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    hp_lambda=hessian_weight,
                    HP_args=HP_args,
                    gpu_ix=gpu,
                    lod_in=lod_in,
                    max_lod=training_set.resolution_log2,
                    **G_loss_args)
            if HP_args.hp_lambda > 0:
                reg_ops += [G_hessian_penalty]
            with tf.name_scope('D_loss'), tf.control_dependencies(lod_assign_ops):
                D_loss, mutual_info = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals,
                    labels=labels,
                    gpu_ix=gpu,
                    infogan_nz=D_args.infogan_nz,
                    **D_loss_args)
            if G_args.infogan_lambda > 0 or D_args.infogan_lambda > 0:
                reg_ops += [mutual_info]
            # Note, even though we are adding mutual_info loss here, the only time the loss is non-zero
            # is when infogan_lambda > 0 (in Hessian Penalty experiments, we always set it to zero):
            G_opt.register_gradients(G_loss + G_args.infogan_lambda * mutual_info,
                                     G_gpu.trainables)
            D_opt.register_gradients(
                tf.reduce_mean(D_loss) + D_args.infogan_lambda * mutual_info, D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            # Fall back to a constant 0 when the memory-stats op is unavailable.
            peak_gpu_mem_op = tf.constant(0)

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(
        G, training_set, **grid_args)
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              num_gpus=submit_config.num_gpus,
                              **sched_args)
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch // submit_config.num_gpus)

    print('Setting up snapshot interpolation...')
    nz = G.input_shapes[0][1]
    interp_steps = 24  # Number of frames in the visualization
    interp_batch_size = 8  # Number of gifs per row of visualization
    interp_z = vis_tools.sample_interp_zs(nz, interp_batch_size, interp_steps)
    interp_labels = np.zeros(
        [interp_steps * interp_batch_size * nz, training_set.label_size],
        dtype=training_set.label_dtype)

    print('Setting up run dir...')
    reals_path = os.path.join(submit_config.run_dir, 'reals.png')
    initial_fakes_path = os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg)
    misc.save_image_grid(grid_reals,
                         reals_path,
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    misc.save_image_grid(grid_fakes, initial_fakes_path, drange=drange_net, grid_size=grid_size)
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    summary_log.add_summary(build_image_summary(reals_path, 'samples/real'), 0)
    summary_log.add_summary(build_image_summary(initial_fakes_path, 'samples/Gs'), resume_kimg)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    # -1 disables interpolation GIFs entirely; values >= 9999 only skip this initial render.
    if interp_snapshot_ticks != -1 and interp_snapshot_ticks < 9999:
        print('Generating initial interpolations...')
        vis_tools.make_and_save_interpolation_gifs(
            Gs,
            interp_z,
            interp_labels,
            minibatch_size=sched.minibatch // submit_config.num_gpus,
            interp_steps=interp_steps,
            interp_batch_size=interp_batch_size,
            cur_kimg=resume_kimg,
            summary_log=summary_log)

    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    num_G_grad_steps = 0
    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  num_gpus=submit_config.num_gpus,
                                  **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch
                    })
                cur_nimg += sched.minibatch
            # The Hessian Penalty weight is recomputed from the image counter
            # before every generator step and fed through the placeholder.
            cur_hessian_weight = get_current_hessian_penalty_loss_weight(
                HP_args.hp_lambda, HP_args.hp_start_nimg, cur_nimg, HP_args.warmup_nimg)
            tflib.run(
                [G_train_op] + reg_ops, {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch,
                    hessian_weight: cur_hessian_weight
                })
            num_G_grad_steps += 1

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d hessian_weight %s time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch),
                   autosummary('Progress/hessian_weight', cur_hessian_weight),
                   dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))
            autosummary('Progress/G_grad_steps', num_G_grad_steps)

            # Save snapshots and fake image samples (for both Gs and G):
            if cur_tick % image_snapshot_ticks == 0 or done:
                # Named snapshot_kimg (not `iter`) so the builtin is not shadowed.
                snapshot_kimg = cur_nimg // 1000
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch // submit_config.num_gpus)
                grid_fakes_inst = G.run(grid_latents,
                                        grid_labels,
                                        is_validation=True,
                                        minibatch_size=sched.minibatch // submit_config.num_gpus)
                fake_path = os.path.join(submit_config.run_dir, 'fakes%06d.png' % snapshot_kimg)
                ifake_path = os.path.join(submit_config.run_dir, 'ifakes%06d.png' % snapshot_kimg)
                misc.save_image_grid(grid_fakes,
                                     fake_path,
                                     drange=drange_net,
                                     grid_size=grid_size)
                misc.save_image_grid(grid_fakes_inst,
                                     ifake_path,
                                     drange=drange_net,
                                     grid_size=grid_size)
                summary_log.add_summary(build_image_summary(fake_path, 'samples/Gs'),
                                        snapshot_kimg)
                summary_log.add_summary(build_image_summary(ifake_path, 'samples/G'),
                                        snapshot_kimg)

            # Generate/Save Interpolation Visualizations:
            if interp_snapshot_ticks != -1 and cur_tick % interp_snapshot_ticks == 0:
                vis_tools.make_and_save_interpolation_gifs(
                    Gs,
                    interp_z,
                    interp_labels,
                    minibatch_size=sched.minibatch // submit_config.num_gpus,
                    interp_steps=interp_steps,
                    interp_batch_size=interp_batch_size,
                    cur_kimg=cur_nimg // 1000,
                    summary_log=summary_log)

            # Save snapshot and run metrics:
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(submit_config.run_dir,
                                   'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                # Nested on purpose: metrics evaluate the pickle written just
                # above, so `pkl` is always bound and never stale here.
                if cur_tick % network_metric_ticks == 0 or done or cur_tick == 1:
                    metrics.run(pkl,
                                dataset_args=metrics_dataset_args,
                                mirror_augment=mirror_augment,
                                num_gpus=submit_config.num_gpus,
                                tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    misc.save_pkl((G, D, Gs),
                  os.path.join(submit_config.run_dir, 'network-snapshot-%06d.pkl' % total_kimg))
    summary_log.close()
    ctx.close()
def training_loop( submit_config, #--------------------------------------------------------------- # Modified by Deng et al. noise_dim=32, weight_args={}, train_stage_args={}, #--------------------------------------------------------------- G_args={}, # Options for generator network. D_args={}, # Options for discriminator network. G_opt_args={}, # Options for generator optimizer. D_opt_args={}, # Options for discriminator optimizer. G_loss_args={}, # Options for generator loss. D_loss_args={}, # Options for discriminator loss. dataset_args={}, # Options for dataset.load_dataset(). sched_args={}, # Options for train.TrainingSchedule. grid_args={}, # Options for train.setup_snapshot_image_grid(). metric_arg_list=[], # Options for MetricGroup. tf_config={}, # Options for tflib.init_tf(). G_smoothing_kimg=10.0, # Half-life of the running average of generator weights. D_repeats=1, # How many times the discriminator is trained per G iteration. minibatch_repeats=4, # Number of minibatches to run before adjusting training parameters. reset_opt_for_new_lod=True, # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced? total_kimg=15000, # Total length of the training, measured in thousands of real images. mirror_augment=True, # Enable mirror augment? drange_net=[ -1, 1 ], # Dynamic range used when feeding image data to the networks. image_snapshot_ticks=1, # How often to export image snapshots? network_snapshot_ticks=10, # How often to export network snapshots? save_tf_graph=True, # Include full TensorFlow computation graph in the tfevents file? save_weight_histograms=False, # Include weight histograms in the tfevents file? resume_run_id=87, # Run ID or network pkl to resume training from, None = start from scratch. resume_snapshot=2364, # Snapshot index to resume training from, None = autodetect. resume_kimg=2364, # Assumed training progress at the beginning. Affects reporting and training schedule. 
resume_time=0.0, **_kwargs ): # Assumed wallclock time at the beginning. Affects reporting. # Initialize dnnlib and TensorFlow. PI = 3.1415927 ctx = dnnlib.RunContext(submit_config, train) tflib.init_tf(tf_config) # Load training set. training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args) # Create 3d face reconstruction block FaceRender = Face3D() # Construct networks. with tf.device('/gpu:0'): if resume_run_id is not None: network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot) print('Loading networks from "%s"...' % network_pkl) G, D, Gs = misc.load_pkl(network_pkl) else: print('Constructing networks...') #--------------------------------------------------------------- # Modified by Deng et al. G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, latent_size=254 + noise_dim, **G_args) #--------------------------------------------------------------- D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args) Gs = G.clone('Gs') G.print_layers() D.print_layers() print('Building TensorFlow graph...') with tf.name_scope('Inputs'), tf.device('/cpu:0'): lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[]) resolution = tf.placeholder(tf.float32, name='resolution', shape=[]) lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[]) minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[]) minibatch_split = minibatch_in // submit_config.num_gpus Gs_beta = 0.5**tf.div(tf.cast(minibatch_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0 G_opt = tflib.Optimizer(name='TrainG', learning_rate=lrate_in, **G_opt_args) D_opt = tflib.Optimizer(name='TrainD', learning_rate=lrate_in, **D_opt_args) for gpu in range(submit_config.num_gpus): with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % (gpu)): G_gpu = G if gpu == 0 
else G.clone(G.name + '_shadow') D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow') lod_assign_ops = [ tf.assign(G_gpu.find_var('lod'), lod_in), tf.assign(D_gpu.find_var('lod'), lod_in) ] reals, labels = training_set.get_minibatch_tf() reals = process_reals(reals, lod_in, mirror_augment, training_set.dynamic_range, drange_net) #--------------------------------------------------------------- # Modified by Deng et al. G_loss,D_loss = dnnlib.util.call_func_by_name(FaceRender=FaceRender,noise_dim=noise_dim,weight_args=weight_args,\ G_gpu=G_gpu,D_gpu=D_gpu,G_opt=G_opt,D_opt=D_opt,training_set=training_set,G_loss_args=G_loss_args,D_loss_args=D_loss_args,\ lod_assign_ops=lod_assign_ops,reals=reals,labels=labels,minibatch_split=minibatch_split,resolution=resolution,\ drange_net=drange_net,lod_in=lod_in,**train_stage_args) #--------------------------------------------------------------- G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables) D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables) G_train_op = G_opt.apply_updates() D_train_op = D_opt.apply_updates() Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta) with tf.device('/gpu:0'): try: peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse() except tf.errors.NotFoundError: peak_gpu_mem_op = tf.constant(0) #--------------------------------------------------------------- # Modified by Deng et al. 
restore_weights_and_initialize(train_stage_args) print('Setting up snapshot image grid...') sched = training_schedule(cur_nimg=total_kimg * 1000, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args) grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid( G, training_set, **grid_args) grid_latents = tf.random_normal([np.prod(grid_size), 128 + 32 + 16 + 3]) grid_INPUTcoeff = z_to_lambda_mapping(grid_latents) grid_INPUTcoeff_w_t = tf.concat( [grid_INPUTcoeff, tf.zeros([np.prod(grid_size), 3])], axis=1) with tf.name_scope('FaceRender'): grid_render_img, _, _, _ = FaceRender.Reconstruction_Block( grid_INPUTcoeff_w_t, 256, np.prod(grid_size), progressive=False) grid_render_img = tf.transpose(grid_render_img, perm=[0, 3, 1, 2]) grid_render_img = process_reals(grid_render_img, lod_in, False, training_set.dynamic_range, drange_net) grid_INPUTcoeff_, grid_renders = tflib.run( [grid_INPUTcoeff, grid_render_img], {lod_in: sched.lod}) grid_noise = np.random.randn(np.prod(grid_size), 32) grid_INPUTcoeff_w_noise = np.concatenate([grid_INPUTcoeff_, grid_noise], axis=1) grid_fakes = Gs.run(grid_INPUTcoeff_w_noise, grid_labels, is_validation=True, minibatch_size=sched.minibatch // submit_config.num_gpus) grid_fakes = np.concatenate([grid_fakes, grid_renders], axis=3) misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg), drange=drange_net, grid_size=grid_size) misc.save_image_grid(grid_reals, os.path.join(submit_config.run_dir, 'reals.png'), drange=training_set.dynamic_range, grid_size=grid_size) #--------------------------------------------------------------- summary_log = tf.summary.FileWriter(submit_config.run_dir) if save_tf_graph: summary_log.add_graph(tf.get_default_graph()) if save_weight_histograms: G.setup_weight_histograms() D.setup_weight_histograms() metrics = metric_base.MetricGroup(metric_arg_list) print('Training...\n') ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg) 
maintenance_time = ctx.get_last_update_interval() cur_nimg = int(resume_kimg * 1000) cur_tick = 0 tick_start_nimg = cur_nimg prev_lod = -1.0 while cur_nimg < total_kimg * 1000: if ctx.should_stop(): break # Choose training parameters and configure training ops. sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args) training_set.configure(sched.minibatch // submit_config.num_gpus, sched.lod) if reset_opt_for_new_lod: if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil( sched.lod) != np.ceil(prev_lod): G_opt.reset_optimizer_state() D_opt.reset_optimizer_state() prev_lod = sched.lod # Run training ops. for _mb_repeat in range(minibatch_repeats): for _D_repeat in range(D_repeats): tflib.run( [D_train_op, Gs_update_op], { lod_in: sched.lod, lrate_in: sched.D_lrate, minibatch_in: sched.minibatch, resolution: sched.resolution }) cur_nimg += sched.minibatch tflib.run( [G_train_op], { lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_in: sched.minibatch, resolution: sched.resolution }) # print('iter') # Perform maintenance tasks once per tick. done = (cur_nimg >= total_kimg * 1000) if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done: cur_tick += 1 tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0 tick_start_nimg = cur_nimg tick_time = ctx.get_time_since_last_update() total_time = ctx.get_time_since_start() + resume_time # Report progress. 
print( 'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f' % (autosummary('Progress/tick', cur_tick), autosummary('Progress/kimg', cur_nimg / 1000.0), autosummary('Progress/lod', sched.lod), autosummary('Progress/minibatch', sched.minibatch), dnnlib.util.format_time( autosummary('Timing/total_sec', total_time)), autosummary('Timing/sec_per_tick', tick_time), autosummary('Timing/sec_per_kimg', tick_time / tick_kimg), autosummary('Timing/maintenance_sec', maintenance_time), autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30))) autosummary('Timing/total_hours', total_time / (60.0 * 60.0)) autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0)) # Save snapshots. if cur_tick % image_snapshot_ticks == 0 or done: #--------------------------------------------------------------- # Modified by Deng et al. grid_fakes = Gs.run(grid_INPUTcoeff_w_noise, grid_labels, is_validation=True, minibatch_size=sched.minibatch // submit_config.num_gpus) grid_fakes = np.concatenate([grid_fakes, grid_renders], axis=3) misc.save_image_grid(grid_fakes, os.path.join( submit_config.run_dir, 'fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size) #--------------------------------------------------------------- if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1: pkl = os.path.join( submit_config.run_dir, 'network-snapshot-%06d.pkl' % (cur_nimg // 1000)) misc.save_pkl((G, D, Gs), pkl) metrics.run(pkl, run_dir=submit_config.run_dir, num_gpus=submit_config.num_gpus, tf_config=tf_config) # Update summaries and RunContext. metrics.update_autosummaries() tflib.autosummary.save_summaries(summary_log, cur_nimg) ctx.update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg) maintenance_time = ctx.get_last_update_interval() - tick_time # Write final results. 
misc.save_pkl((G, D, Gs), os.path.join(submit_config.run_dir, 'network-final.pkl')) summary_log.close() ctx.close() #----------------------------------------------------------------------------
def _copy_to_gcs(local_path, destination='gs://stylegan_out'):
    """Best-effort copy of ``local_path`` to ``destination`` using ``gsutil cp``.

    The command is passed as an argument list (no shell), so paths containing
    spaces or shell metacharacters are handled verbatim. Failures never raise:
    a missing ``gsutil`` binary, an OS-level launch error, or a non-zero exit
    status only produce a warning, so training is not interrupted.

    Returns the ``subprocess.CompletedProcess`` or ``None`` if gsutil could not
    be started.
    """
    import shutil  # Local import: only this helper needs it.

    gsutil = shutil.which('gsutil')
    if gsutil is None:
        print('Warning: gsutil not found on PATH, skipping upload of "%s".' % local_path)
        return None
    try:
        result = subprocess.run([gsutil, 'cp', local_path, destination])
    except OSError as err:
        print('Warning: could not run gsutil for "%s": %s' % (local_path, err))
        return None
    if result.returncode != 0:
        print('Warning: gsutil exited with status %d for "%s".' % (result.returncode, local_path))
    return result


def training_loop(
        submit_config,
        G_args={},  # Options for generator network.
        D_args={},  # Options for discriminator network.
        G_opt_args={},  # Options for generator optimizer.
        D_opt_args={},  # Options for discriminator optimizer.
        G_loss_args={},  # Options for generator loss.
        D_loss_args={},  # Options for discriminator loss.
        dataset_args={},  # Options for dataset.load_dataset().
        sched_args={},  # Options for train.TrainingSchedule.
        grid_args={},  # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],  # Options for MetricGroup.
        tf_config={},  # Options for tflib.init_tf().
        G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
        D_repeats=1,  # How many times the discriminator is trained per G iteration.
        minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
        reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg=15000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=False,  # Enable mirror augment?
        drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=1,  # How often to export image snapshots?
        network_snapshot_ticks=10,  # How often to export network snapshots?
        save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
        resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
        resume_kimg=10000.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0):  # Assumed wallclock time at the beginning. Affects reporting.
    """Train G and D, periodically exporting image grids, network pickles and summaries.

    Networks are loaded from ``resume_run_id``/``resume_snapshot`` when a run
    id is given; otherwise a pretrained (G, D, Gs) triple is downloaded and
    unpickled. ``G_args`` and ``D_args`` are accepted for interface
    compatibility but are not used, because networks are never constructed
    from scratch here.

    Side effects visible in this function: writes ``reals.png``,
    ``fakes%06d.png``, ``network-snapshot-%06d.pkl`` and ``network-final.pkl``
    plus TensorBoard events into ``submit_config.run_dir``, and copies each
    ``fakes`` image to a GCS bucket on a best-effort basis.

    NOTE: the dict/list defaults are shared objects; this function only reads
    or forwards them, so keep it that way (do not mutate them in place).
    """
    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **dataset_args)

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            url = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ'
            # NOTE(security): pickle.load executes arbitrary code contained in
            # the downloaded file; this is only acceptable while the URL above
            # is trusted.
            with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f:
                G, D, Gs = pickle.load(f)
            print('Loading pretrained FFHQ network')
    G.print_layers()
    D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_in, tf.float32),
                              G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0

    G_opt = tflib.Optimizer(name='TrainG', learning_rate=lrate_in, **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', learning_rate=lrate_in, **D_opt_args)
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # GPU 0 uses the master networks; other GPUs get shadow clones.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in),
            ]
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(reals, lod_in, mirror_augment,
                                  training_set.dynamic_range, drange_net)
            # Both losses are built under a dependency on the lod assignments
            # so the networks see the current lod before being evaluated.
            with tf.name_scope('G_loss'), tf.control_dependencies(lod_assign_ops):
                G_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    **G_loss_args)
            with tf.name_scope('D_loss'), tf.control_dependencies(lod_assign_ops):
                D_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals,
                    labels=labels,
                    **D_loss_args)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            # Memory statistics op unavailable: report 0 instead of failing.
            peak_gpu_mem_op = tf.constant(0)

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(
        G, training_set, **grid_args)
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              num_gpus=submit_config.num_gpus,
                              **sched_args)
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch // submit_config.num_gpus)

    print('Setting up run dir...')
    misc.save_image_grid(grid_reals,
                         os.path.join(submit_config.run_dir, 'reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    fakes_path = os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg)
    misc.save_image_grid(grid_fakes, fakes_path, drange=drange_net, grid_size=grid_size)
    _copy_to_gcs(fakes_path)

    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  num_gpus=submit_config.num_gpus,
                                  **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus, sched.lod)
        if reset_opt_for_new_lod:
            # A change in floor or ceil of lod means the schedule crossed an
            # integer lod boundary since the previous iteration.
            if (np.floor(sched.lod) != np.floor(prev_lod)
                    or np.ceil(sched.lod) != np.ceil(prev_lod)):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch,
                    })
                cur_nimg += sched.minibatch
            tflib.run(
                [G_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch,
                })

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s '
                'sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch // submit_config.num_gpus)
                fakes_path = os.path.join(submit_config.run_dir,
                                          'fakes%06d.png' % (cur_nimg // 1000))
                misc.save_image_grid(grid_fakes,
                                     fakes_path,
                                     drange=drange_net,
                                     grid_size=grid_size)
                _copy_to_gcs(fakes_path)
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(submit_config.run_dir,
                                   'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl,
                            run_dir=submit_config.run_dir,
                            num_gpus=submit_config.num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod,
                       cur_epoch=cur_nimg // 1000,
                       max_epoch=total_kimg)
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    misc.save_pkl((G, D, Gs),
                  os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()
    ctx.close()
def train(submit_config: dnnlib.SubmitConfig, iteration_count: int,
          eval_interval: int, minibatch_size: int, learning_rate: float,
          ramp_down_perc: float, noise: dict, validation_config: dict,
          train_tfrecords: str, noise2noise: bool):
    """Train the denoising network and periodically evaluate and snapshot it.

    Args:
        submit_config: dnnlib submit configuration (run dir, run id, num_gpus).
        iteration_count: Number of training iterations to run.
        eval_interval: Every this many iterations the loop saves example
            images, evaluates the validation set, logs timing summaries and
            saves a snapshot.
        minibatch_size: Forwarded to ``create_dataset``.
        learning_rate: Forwarded to ``compute_ramped_down_lrate``.
        ramp_down_perc: Forwarded to ``compute_ramped_down_lrate``.
        noise: Keyword arguments for ``dnnlib.util.call_func_by_name`` that
            build the noise augmenter.
        validation_config: Keyword arguments for ``ValidationSet.load``.
        train_tfrecords: Training data location forwarded to
            ``create_dataset``.
        noise2noise: If true, the loss is the difference of two
            ``fourier_ring_correlation`` terms computed from the noisy pair;
            otherwise it is the mean squared error between the clean target
            and the denoised input.
    """
    noise_augmenter = dnnlib.util.call_func_by_name(**noise)
    validation_set = ValidationSet(submit_config)
    validation_set.load(**validation_config)

    # Create a run context (hides low level details, exposes simple API to manage the run)
    ctx = dnnlib.RunContext(submit_config, config)

    # Initialize TensorFlow graph and session using good default settings
    tfutil.init_tf(config.tf_config)

    dataset_iter = create_dataset(train_tfrecords, minibatch_size,
                                  noise_augmenter.add_train_noise_tf)

    # Construct the network using the Network helper class and a function defined in config.net_config
    with tf.device("/gpu:0"):
        net = tflib.Network(**config.net_config)

    # Optionally print layer information
    net.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device("/cpu:0"):
        lrate_in = tf.compat.v1.placeholder(tf.float32, name='lrate_in', shape=[])

        noisy_input, noisy_target, clean_target = dataset_iter.get_next()
        # Split [?, 3, 256, 256] across num_gpus over axis 0 (i.e. the batch)
        noisy_input_split = tf.split(noisy_input, submit_config.num_gpus)
        noisy_target_split = tf.split(noisy_target, submit_config.num_gpus)
        print(len(noisy_input_split), noisy_input_split)
        clean_target_split = tf.split(clean_target, submit_config.num_gpus)

    # Define the loss function using the Optimizer helper class, this will take care of multi GPU
    opt = tflib.Optimizer(learning_rate=lrate_in, **config.optimizer_config)

    # Ring masks for the Fourier ring correlation loss: one mask per radius
    # 0..127. NOTE(review): the meaning of the trailing arguments (128, 128,
    # the two 0..255 ranges, 20) is defined by radial_mask, which is not
    # visible here -- confirm against its definition.
    radii = np.arange(128).reshape(128, 1)
    radial_masks = np.apply_along_axis(radial_mask, 1, radii, 128, 128,
                                       np.arange(0, 256), np.arange(0, 256), 20)
    print("RN SHAPE!!!!!!!!!!:", radial_masks.shape)
    radial_masks = np.expand_dims(radial_masks, 1)
    radial_masks = radial_masks.astype(np.complex64)
    # A second singleton axis is inserted so the tensor matches the 5-D
    # placeholder shape below; axis 1 is the one split across GPUs.
    radial_masks = tf.expand_dims(radial_masks, 1)
    rn = tf.compat.v1.placeholder_with_default(radial_masks,
                                               [128, None, 1, 256, 256])
    rn_split = tf.split(rn, submit_config.num_gpus, axis=1)

    # Radii normalised by the Nyquist frequency (half of 256), then rescaled
    # so the largest value is exactly 1.
    freq_nyq = int(np.floor(int(256) / 2.0))
    spatial_freq = radii.astype(np.float32) / freq_nyq
    spatial_freq = spatial_freq / spatial_freq.max()

    for gpu in range(submit_config.num_gpus):
        with tf.device("/gpu:%d" % gpu):
            net_gpu = net if gpu == 0 else net.clone()

            denoised_1 = net_gpu.get_output_for(noisy_input_split[gpu])
            print(noisy_input_split[gpu].get_shape(), rn_split[gpu].get_shape())

            if noise2noise:
                # The denoised target is only needed by this loss.
                denoised_2 = net_gpu.get_output_for(noisy_target_split[gpu])
                loss = fourier_ring_correlation(
                    noisy_target_split[gpu], denoised_1, rn_split[gpu],
                    spatial_freq) - fourier_ring_correlation(
                        noisy_target_split[gpu] - denoised_2,
                        noisy_input_split[gpu] - denoised_1, rn_split[gpu],
                        spatial_freq)
            else:
                # Uses denoised_1: the previously referenced name `denoised`
                # is not bound while the graph is being built.
                loss = tf.reduce_mean(
                    tf.square(clean_target_split[gpu] - denoised_1))

            # Create an autosummary that will average over all GPUs
            with tf.control_dependencies([autosummary("Loss", loss)]):
                opt.register_gradients(loss, net_gpu.trainables)

    train_step = opt.apply_updates()

    # Create a log file for Tensorboard
    summary_log = tf.compat.v1.summary.FileWriter(submit_config.run_dir)
    summary_log.add_graph(tf.compat.v1.get_default_graph())

    print('Training...')
    time_maintenance = ctx.get_time_since_last_update()
    ctx.update(loss='run %d' % submit_config.run_id,
               cur_epoch=0,
               max_epoch=iteration_count)

    # Slicing a tensor keeps its dtype, so read it once here instead of
    # creating new slice ops in the graph on every evaluation.
    input_dtype = noisy_input.dtype

    # The actual training loop
    for i in range(iteration_count):
        # Whether to stop the training or not should be asked from the context
        if ctx.should_stop():
            break

        # Dump training status
        if i % eval_interval == 0:
            time_train = ctx.get_time_since_last_update()
            time_total = ctx.get_time_since_start()
            print("DEBUG TRAIN!", input_dtype, input_dtype)

            # Evaluate 'x' to draw a batch of inputs
            [source_mb, target_mb] = tfutil.run([noisy_input, clean_target])
            denoised = net.run(source_mb)
            save_image(submit_config, denoised[0], "img_{0}_y_pred.tif".format(i))
            save_image(submit_config, target_mb[0], "img_{0}_y.tif".format(i))
            save_image(submit_config, source_mb[0], "img_{0}_x_aug.tif".format(i))

            validation_set.evaluate(net, i, noise_augmenter.add_validation_noise_np)

            print(
                'iter %-10d time %-12s sec/eval %-7.1f sec/iter %-7.2f maintenance %-6.1f'
                % (autosummary('Timing/iter', i),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', time_total)),
                   autosummary('Timing/sec_per_eval', time_train),
                   autosummary('Timing/sec_per_iter', time_train / eval_interval),
                   autosummary('Timing/maintenance_sec', time_maintenance)))

            dnnlib.tflib.autosummary.save_summaries(summary_log, i)
            ctx.update(loss='run %d' % submit_config.run_id,
                       cur_epoch=i,
                       max_epoch=iteration_count)
            time_maintenance = ctx.get_last_update_interval() - time_train

            # NOTE(review): assumed to run once per evaluation rather than on
            # every iteration -- confirm the intended nesting.
            save_snapshot(submit_config, net, str(i))

        lrate = compute_ramped_down_lrate(i, iteration_count, ramp_down_perc,
                                          learning_rate)
        tfutil.run([train_step], {lrate_in: lrate})

    print("Elapsed time: {0}".format(
        util.format_time(ctx.get_time_since_start())))
    save_snapshot(submit_config, net, 'final')

    # Summary log and context should be closed at the end
    summary_log.close()
    ctx.close()