def project_image(proj, src_file, dst_dir, tmp_dir, video=False):
    """Project one image file with ``proj`` and save the result image and latents.

    The source image is copied into a scratch dataset under
    ``<tmp_dir>/dataset`` (any previous contents are removed), converted to
    tfrecords, loaded back as a single-image minibatch and rescaled from
    [0, 255] to [-1, 1] before the projector is started.

    Args:
        proj: Projector object; this function uses ``start``, ``step``,
            ``get_cur_step``, ``num_steps``, ``get_images`` and
            ``get_dlatents``.
        src_file: Path of the image to project.
        dst_dir: Directory that receives ``<stem>.png`` and ``<stem>.npy``,
            where ``<stem>`` is the source file name without its extension.
        tmp_dir: Scratch directory for the temporary dataset and video frames.
        video: When true, the image grid is saved after every step to
            ``<tmp_dir>/video/<step>.png``.
    """
    # Rebuild the single-image dataset from scratch on every call.
    data_dir = '%s/dataset' % tmp_dir
    if os.path.exists(data_dir):
        shutil.rmtree(data_dir)
    image_dir = '%s/images' % data_dir
    tfrecord_dir = '%s/tfrecords' % data_dir
    os.makedirs(image_dir, exist_ok=True)
    shutil.copy(src_file, image_dir + '/')
    dataset_tool.create_from_images_raw(tfrecord_dir, image_dir, shuffle=0)
    dataset_obj = dataset.load_dataset(data_dir=data_dir,
                                       tfrecord_dir='tfrecords',
                                       max_label_size=0,
                                       repeat=False,
                                       shuffle_mb=0)

    print('Projecting image "%s"...' % os.path.basename(src_file))
    images, _labels = dataset_obj.get_minibatch_np(1)
    images = misc.adjust_dynamic_range(images, [0, 255], [-1, 1])
    proj.start(images)

    if video:
        video_dir = '%s/video' % tmp_dir
        os.makedirs(video_dir, exist_ok=True)

    while proj.get_cur_step() < proj.num_steps:
        print('\r%d / %d ... ' % (proj.get_cur_step(), proj.num_steps),
              end='', flush=True)
        proj.step()
        if video:
            filename = '%s/%08d.png' % (video_dir, proj.get_cur_step())
            misc.save_image_grid(proj.get_images(), filename, drange=[-1, 1])
    # Blank out the progress line.
    print('\r%-30s\r' % '', end='', flush=True)

    os.makedirs(dst_dir, exist_ok=True)
    # splitext instead of [:-4]: extensions are not always three characters
    # long (e.g. ".jpeg"), and some file names have no extension at all.
    stem = os.path.splitext(os.path.basename(src_file))[0]
    filename = os.path.join(dst_dir, stem + '.png')
    misc.save_image_grid(proj.get_images(), filename, drange=[-1, 1])
    filename = os.path.join(dst_dir, stem + '.npy')
    np.save(filename, proj.get_dlatents()[0])
def project_image(proj, targets, work_dir, resolution, num_snapshots):
    """Run the projector on ``targets``, recording a video and snapshots.

    All outputs share the prefix ``<work_dir>/<basename(work_dir)>``:
    ``.avi`` (one frame per step, MJPG at 24 fps), ``.jpg`` (the targets),
    ``-<step>.jpg`` (snapshots) and ``-<step>.npy`` (final dlatents).

    Args:
        proj: Projector object; this function uses ``start``, ``step``,
            ``get_cur_step``, ``num_steps``, ``get_images`` and
            ``get_dlatents``.
        targets: Target images, saved with dynamic range [-1, 1].
        work_dir: Output directory; its base name is also the file prefix.
        resolution: Frame size handed to ``cv2.VideoWriter``.
        num_snapshots: Number of snapshot steps, spaced evenly backwards
            from the final step.
    """
    filename = osp.join(work_dir, basename(work_dir))
    video_out = cv2.VideoWriter(filename + '.avi',
                                cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'),
                                24, resolution)
    # Release the writer even if projection fails, so the handle is not
    # leaked and whatever frames were written get flushed.
    try:
        snapshot_steps = set(proj.num_steps - np.linspace(
            0, proj.num_steps, num_snapshots, endpoint=False, dtype=int))
        misc.save_image_grid(targets, filename + '.jpg', drange=[-1, 1])
        proj.start(targets)
        pbar = ProgressBar(proj.num_steps)
        while proj.get_cur_step() < proj.num_steps:
            proj.step()
            write_video_frame(proj, video_out)
            if proj.get_cur_step() in snapshot_steps:
                misc.save_image_grid(
                    proj.get_images(),
                    filename + '-%04d.jpg' % proj.get_cur_step(),
                    drange=[-1, 1])
            pbar.upd()
        dlats = proj.get_dlatents()
        np.save(filename + '-%04d.npy' % proj.get_cur_step(), dlats)
    finally:
        video_out.release()
def project_image(proj, targets, labels, png_prefix, num_snapshots, save_npy,
                  npy_file_prefix):
    """Project ``targets``, saving loss-tagged snapshots and optionally an npy.

    Snapshot images are named ``<png_prefix>step<step>_<loss>.png``, where the
    loss is read from ``proj._latest_loss_value``. When ``save_npy`` is truthy
    and ``npy_file_prefix`` is not None, ``proj.save_npy(npy_file_prefix)`` is
    called once the loop has finished. ``labels`` is accepted but not used.
    """
    write_npy = npy_file_prefix is not None and save_npy

    # Steps to snapshot: evenly spaced, counted back from the final step.
    offsets = np.linspace(0, proj.num_steps, num_snapshots,
                          endpoint=False, dtype=int)
    snapshot_steps = set(proj.num_steps - offsets)

    misc.save_image_grid(targets, png_prefix + 'target.png', drange=[-1, 1])
    proj.start(targets)

    if write_npy:
        print("WILL SAVE npy_file to: " + npy_file_prefix + '_<LOSS>.npy')

    while proj.get_cur_step() < proj.num_steps:
        proj.step()
        if proj.get_cur_step() not in snapshot_steps:
            continue
        suffix = 'step%04d_%0.2f.png' % (proj.get_cur_step(),
                                         proj._latest_loss_value)
        misc.save_image_grid(proj.get_images(), png_prefix + suffix,
                             drange=[-1, 1])

    # Blank out the console status line.
    print('\r%-30s\r' % '', end='', flush=True)

    if write_npy:
        proj.save_npy(npy_file_prefix)
def project_image(proj, targets, png_prefix, num_snapshots,
                  onlyfirstandlast=False):
    """Project ``targets`` and save snapshot images along the way.

    Args:
        proj: Projector object; this function uses ``start``, ``step``,
            ``get_cur_step``, ``num_steps`` and ``get_images``.
        targets: Target images, saved as ``<png_prefix>target.png``.
        png_prefix: Prefix for every output file name.
        num_snapshots: Number of snapshot steps, spaced evenly backwards
            from the final step. Ignored when ``onlyfirstandlast`` is true.
        onlyfirstandlast: When true, snapshot only step 1 and the final step.
    """
    if onlyfirstandlast:
        # The last step is taken from the projector rather than a fixed 300,
        # so the final result is captured for any num_steps setting.
        snapshot_steps = {1, proj.num_steps}
    else:
        snapshot_steps = set(proj.num_steps - np.linspace(
            0, proj.num_steps, num_snapshots, endpoint=False, dtype=int))

    misc.save_image_grid(targets, png_prefix + 'target.png', drange=[-1, 1])
    proj.start(targets)
    while proj.get_cur_step() < proj.num_steps:
        print('\r%d / %d ... ' % (proj.get_cur_step(), proj.num_steps),
              end='', flush=True)
        proj.step()
        if proj.get_cur_step() in snapshot_steps:
            misc.save_image_grid(
                proj.get_images(),
                png_prefix + 'step%04d.png' % proj.get_cur_step(),
                drange=[-1, 1])
    # Blank out the progress line.
    print('\r%-30s\r' % '', end='', flush=True)
def save_atts(atts, filename, grid_size, drange, grid_fakes, n_samples_per):
    """Save per-latent attention maps as one single-channel image grid.

    For latent index ``i`` the map ``atts[:, i]`` is min-max normalised over
    the whole batch, and rows ``[i * n_samples_per, (i + 1) * n_samples_per)``
    of it are copied into the same rows of the canvas.

    Args:
        atts: Attention maps; per the original author's note the layout is
            [b, n_latents, 1, res, res].
        filename: Output path passed to ``misc.save_image_grid``.
        grid_size: Grid layout passed to ``misc.save_image_grid``.
        drange: Dynamic range passed to ``misc.save_image_grid``.
        grid_fakes: Only its shape is used: axes 0, 2 and 3 size the canvas.
        n_samples_per: Number of consecutive canvas rows filled per latent.
    """
    canvas = np.zeros(
        [grid_fakes.shape[0], 1, grid_fakes.shape[2], grid_fakes.shape[3]])
    for i in range(atts.shape[1]):
        att_sp = atts[:, i]  # [b, 1, x_h, x_w]
        lo = att_sp.min()
        span = att_sp.max() - lo
        if span == 0:
            # A constant map has no contrast; dividing by zero would put NaN
            # into the canvas, so leave these rows at their zero fill.
            continue
        att_sp = (att_sp - lo) / span
        start = i * n_samples_per
        stop = start + n_samples_per
        canvas[start:stop] = att_sp[start:stop]
    misc.save_image_grid(canvas, filename, drange=drange, grid_size=grid_size)
def project_image(proj, targets, png_prefix, num_snapshots):
    """Project ``targets``, saving snapshots and dumping dlatents to stdout.

    At each snapshot step the image grid is saved as
    ``<png_prefix>step<step>.png``, a ``###step <step>`` line is printed, and
    every innermost row of ``proj.get_dlatents()`` is printed on its own line
    prefixed with ``###``.

    Args:
        proj: Projector object; this function uses ``start``, ``step``,
            ``get_cur_step``, ``num_steps``, ``get_images`` and
            ``get_dlatents``.
        targets: Target images, saved as ``<png_prefix>target.png``.
        png_prefix: Prefix for every output file name.
        num_snapshots: Number of snapshot steps, spaced evenly backwards
            from the final step.
    """
    snapshot_steps = set(proj.num_steps - np.linspace(
        0, proj.num_steps, num_snapshots, endpoint=False, dtype=int))
    misc.save_image_grid(targets, png_prefix + 'target.png', drange=[-1, 1])
    proj.start(targets)
    while proj.get_cur_step() < proj.num_steps:
        print('\r%d / %d ... ' % (proj.get_cur_step(), proj.num_steps),
              end='', flush=True)
        proj.step()
        if proj.get_cur_step() in snapshot_steps:
            misc.save_image_grid(
                proj.get_images(),
                png_prefix + 'step%04d.png' % proj.get_cur_step(),
                drange=[-1, 1])
            # NOTE(review): the latent dump is assumed to belong to the
            # snapshot branch (once per snapshot, not once per step) -
            # confirm against the intended log volume.
            print('###step', proj.get_cur_step())
            for layer_latents in proj.get_dlatents():
                for row in layer_latents:
                    # Each value keeps a leading space, matching the text the
                    # previous incremental formatting produced, but the row is
                    # built in one pass and without shadowing builtin ``str``.
                    row_text = ''.join(' {}'.format(value) for value in row)
                    print('###', row_text)
    # Blank out the progress line.
    print('\r%-30s\r' % '', end='', flush=True)
def project_image(proj, targets, png_prefix, num_snapshots):
    """Project ``targets`` and save corrupted/clean snapshot pairs.

    At every snapshot step two grids are written: ``proj.get_images()`` as
    ``<png_prefix>corrupted-step<step>.png`` and ``proj.get_clean_images()``
    as ``<png_prefix>clean-step<step>.png``. The targets themselves go to
    ``<png_prefix>target.png``.
    """
    value_range = [-1, 1]

    # Snapshot steps are spaced evenly, counted back from the final step.
    back_offsets = np.linspace(0, proj.num_steps, num_snapshots,
                               endpoint=False, dtype=int)
    snapshot_steps = set(proj.num_steps - back_offsets)

    print('target.shape', targets.shape)
    misc.save_image_grid(targets, png_prefix + 'target.png',
                         drange=value_range)
    proj.start(targets)

    while proj.get_cur_step() < proj.num_steps:
        status = '\r%d / %d ... ' % (proj.get_cur_step(), proj.num_steps)
        print(status, end='', flush=True)
        proj.step()
        if proj.get_cur_step() not in snapshot_steps:
            continue

        corrupted_path = png_prefix + 'corrupted-step%04d.png' % proj.get_cur_step()
        misc.save_image_grid(proj.get_images(), corrupted_path,
                             drange=value_range)

        clean_path = png_prefix + 'clean-step%04d.png' % proj.get_cur_step()
        print('Saving checkpoint %s' % clean_path)
        misc.save_image_grid(proj.get_clean_images(), clean_path,
                             drange=value_range)

    # Blank out the progress line.
    print('\r%-30s\r' % '', end='', flush=True)
e = err(x, x1) # init = tf.global_variables_initializer # sess = tf.Session() # sess.run(init()) # sess.run(e) # sess.run(abserr(x,x1)) tflib.run(e) tflib.run(abserr(x, x1)) # img, img_re = sess.run([x, x1]) img, img_re = tflib.run([x, x1]) misc.save_image_grid(img, dnnlib.make_run_dir_path('/gdata2/fengrl/img.png'), drange=[-1, 1], grid_size=[8, 1]) misc.save_image_grid(img_re, dnnlib.make_run_dir_path('/gdata2/fengrl/img_re.png'), drange=[-1, 1], grid_size=[8, 1]) v = [v for v in tf.global_variables() if 'weight' in v.name] w = v[4] ww = v[4] * 5.0 f_shape = [16, 16] d = tf.rsqrt(tf.reduce_sum(tf.square(w), axis=[0, 1, 2]) + 1e-8) dd = tf.rsqrt(tf.reduce_sum(tf.square(ww), axis=[0, 1, 2]) + 1e-8) w *= d[np.newaxis, np.newaxis, np.newaxis, :]
def training_loop_refinement(
    G_args={},  # Options for generator network.
    D_args={},  # Options for discriminator network.
    G_opt_args={},  # Options for generator optimizer.
    D_opt_args={},  # Options for discriminator optimizer.
    G_loss_args={},  # Options for generator loss.
    D_loss_args={},  # Options for discriminator loss.
    dataset_args={},  # Options for dataset.load_dataset().
    sched_args={},  # Options for train.TrainingSchedule.
    grid_args={},  # Options for train.setup_snapshot_image_grid().
    metric_arg_list=[],  # Options for MetricGroup.
    tf_config={},  # Options for tflib.init_tf().
    data_dir=None,  # Directory to load datasets from.
    G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    lazy_regularization=True,  # Perform regularization as a separate training step?
    G_reg_interval=4,  # How often the perform regularization for G? Ignored if lazy_regularization=False.
    D_reg_interval=16,  # How often the perform regularization for D? Ignored if lazy_regularization=False.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=25000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=50,  # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'.
    network_snapshot_ticks=50,  # How often to save network snapshots? None = only save 'networks-final.pkl'.
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=True,  # Include weight histograms in the tfevents file?
    resume_pkl=None,  # Network pickle to resume training from, None = train from scratch.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
    resume_with_new_nets=False  # Construct new networks according to G_args and D_args before resuming training?
):
    """Train (refine) the generator alone against a per-batch loss.

    No discriminator is built: the generator loss function is called with
    ``D=None`` and receives the dataset images as ``reals`` and the dataset
    labels as ``latents``. The ``D_*`` parameters are accepted but unused in
    this function. Trainable variables whose name contains any entry of
    ``G_args['freeze_layers']`` are removed from the set the optimizers update.

    Snapshots, metrics and summaries are written to the dnnlib run directory;
    the final networks are saved as ``network-final.pkl`` in the form
    ``(G, None, Gs)``.
    """
    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir),
                                        verbose=True,
                                        **dataset_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        training_set, **grid_args)
    misc.save_image_grid(grid_reals,
                         dnnlib.make_run_dir_path('reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)

    # Construct or load networks.
    with tf.device('/gpu:0'):
        if resume_pkl is None or resume_with_new_nets:
            print('Constructing networks...')
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **G_args)
            Gs = G.clone('Gs')
        if resume_pkl is not None:
            print('Loading networks from "%s"...' % resume_pkl)
            # The loaded generator must stay bound to a live name: the
            # non-"new nets" branch below resumes directly from it.
            rG, _rD, rGs = misc.load_pkl(resume_pkl)
            del _rD
            if resume_with_new_nets:
                # Both fresh networks start from the averaged generator.
                G.copy_vars_from(rGs)
                Gs.copy_vars_from(rGs)
                del rG, rGs
            else:
                G = rG
                Gs = rGs

    # Set constant noise input for both G and Gs; the same seed is used for
    # each so they receive identical noise values.
    if G_args.get('randomize_noise', None) == False:  # noqa: E712 - keeps the original's loose comparison
        for net in (G, Gs):
            noise_vars = [
                var for name, var in net.components.synthesis.vars.items()
                if name.startswith('noise')
            ]
            rnd = np.random.RandomState(123)
            tflib.set_vars(
                {var: rnd.randn(*var.shape.as_list())
                 for var in noise_vars})  # [height, width]

    # Print layers and generate initial image snapshot.
    G.print_layers()
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              **sched_args)
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch_gpu)
    misc.save_image_grid(grid_fakes,
                         dnnlib.make_run_dir_path('fakes_init.png'),
                         drange=drange_net,
                         grid_size=grid_size)

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32,
                                           name='minibatch_size_in',
                                           shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32,
                                          name='minibatch_gpu_in',
                                          shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in *
                                                     num_gpus)
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_size_in, tf.float32),
                              G_smoothing_kimg *
                              1000.0) if G_smoothing_kimg > 0.0 else 0.0

    # Setup optimizers. Work on a copy so the caller's dict is not modified.
    G_opt_args = dict(G_opt_args)
    for args, reg_interval in [(G_opt_args, G_reg_interval)]:
        args['minibatch_multiplier'] = minibatch_multiplier
        args['learning_rate'] = lrate_in
        if lazy_regularization:
            mb_ratio = reg_interval / (reg_interval + 1)
            args['learning_rate'] *= mb_ratio
            if 'beta1' in args:
                args['beta1'] **= mb_ratio
            if 'beta2' in args:
                args['beta2'] **= mb_ratio
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args)

    # Freeze layers. Held in a local list: attribute assignment on G_args
    # fails for a plain dict (the default) and would mutate the caller's
    # argument.
    freeze_layers = list(G_args.get('freeze_layers', []))

    def freeze_vars(gen, verbose=True):
        """Drop from gen.trainables every variable matching a frozen layer."""
        assert len(freeze_layers) > 0
        for name in list(gen.trainables.keys()):
            if any(layer in name for layer in freeze_layers):
                del gen.trainables[name]
                if verbose:
                    print(f"Freezed {name}")

    # Build training graph for each GPU.
    data_fetch_ops = []
    loss_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):

            # Create GPU-specific shadow copy of G.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            if freeze_layers:
                freeze_vars(G_gpu, verbose=False)

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                sched = training_schedule(cur_nimg=int(resume_kimg * 1000),
                                          training_set=training_set,
                                          **sched_args)
                reals_var = tf.Variable(
                    name='reals',
                    trainable=False,
                    initial_value=tf.zeros([sched.minibatch_gpu] +
                                           training_set.shape))
                labels_var = tf.Variable(
                    name='labels',
                    trainable=False,
                    initial_value=tf.zeros(
                        [sched.minibatch_gpu, training_set.label_size]))
                reals_write, labels_write = training_set.get_minibatch_tf()
                reals_write, labels_write = process_reals(
                    reals_write, labels_write, lod_in, mirror_augment,
                    training_set.dynamic_range, drange_net)
                # Pad with the old tail so the assigned tensors keep the
                # variables' full first dimension.
                reals_write = tf.concat(
                    [reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                labels_write = tf.concat(
                    [labels_write, labels_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_var, reals_write)]
                data_fetch_ops += [tf.assign(labels_var, labels_write)]
                reals_read = reals_var[:minibatch_gpu_in]
                labels_read = labels_var[:minibatch_gpu_in]

            # Evaluate loss functions.
            lod_assign_ops = []
            if 'lod' in G_gpu.vars:
                lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)]
            with tf.control_dependencies(lod_assign_ops):
                with tf.name_scope('G_loss'):
                    G_loss, G_reg = dnnlib.util.call_func_by_name(
                        G=G_gpu,
                        D=None,
                        opt=G_opt,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        reals=reals_read,
                        latents=labels_read,
                        **G_loss_args)
                    loss_ops.append(G_loss)

            # Register gradients.
            if not lazy_regularization:
                if G_reg is not None:
                    G_loss += G_reg
            else:
                if G_reg is not None:
                    G_reg_opt.register_gradients(
                        tf.reduce_mean(G_reg * G_reg_interval),
                        G_gpu.trainables)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    loss_op = tf.reduce_mean(tf.concat(loss_ops, axis=0))
    G_train_op = G_opt.apply_updates()
    G_reg_op = G_reg_opt.apply_updates(allow_no_op=True)
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('',
                                   cur_epoch=resume_kimg,
                                   max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0
    loss_per_batch_sum = 0
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        feed_dict = {
            lod_in: sched.lod,
            lrate_in: sched.G_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu
        }
        # Prime the data variables before the first training step.
        tflib.run(data_fetch_op, feed_dict)
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size,
                           sched.minibatch_gpu * num_gpus)
            run_G_reg = (lazy_regularization
                         and running_mb_counter % G_reg_interval == 0)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                loss, _ = tflib.run([loss_op, G_train_op], feed_dict)
                # Fetch the next batch only after the training step has used
                # the current one.
                tflib.run([data_fetch_op], feed_dict)
                loss_per_batch_sum += loss
                if run_G_reg:
                    tflib.run(G_reg_op, feed_dict)
                tflib.run([Gs_update_op], feed_dict)

            # Slow path with gradient accumulation. FIXME: Probably wrong
            # NOTE(review): here the data fetch runs in the same session call
            # as the training op, and the regularization step is run once per
            # round (assumed nesting) - confirm both against the intended
            # accumulation scheme.
            else:
                for _round in rounds:
                    loss, _, _ = tflib.run(
                        [loss_op, G_train_op, data_fetch_op], feed_dict)
                    loss_per_batch_sum += loss / len(rounds)
                    if run_G_reg:
                        tflib.run(G_reg_op, feed_dict)
                tflib.run(Gs_update_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start(
            ) + resume_time
            # Average of the per-batch losses accumulated during this tick.
            tick_loss = loss_per_batch_sum * sched.minibatch_size / (
                tick_kimg * 1000)
            loss_per_batch_sum = 0

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d loss/px %-12.8f time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch_size),
                   autosummary('Progress/loss_per_px', tick_loss),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days',
                        total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (
                    cur_tick % image_snapshot_ticks == 0 or done):
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch_gpu)
                misc.save_image_grid(grid_fakes,
                                     dnnlib.make_run_dir_path(
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' %
                                               (cur_nimg // 1000))
                # The discriminator slot is None: none is trained here.
                misc.save_pkl((G, None, Gs), pkl)
                metrics.run(pkl,
                            run_dir=dnnlib.make_run_dir_path(),
                            data_dir=dnnlib.convert_path(data_dir),
                            num_gpus=num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % sched.lod,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get(
            ).get_last_update_interval() - tick_time

    # Save final snapshot.
    misc.save_pkl((G, None, Gs), dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
def generate_images(network_pkl,
                    seeds,
                    truncation_psi,
                    data_dir=None,
                    dataset_name=None,
                    model=None):
    """Generate one image per seed and dump the generator's auxiliary outputs.

    A fresh 'G' network is built from ``training.<model>.G_main`` with its
    channel count, resolution and label size taken from the dataset, and the
    weights of the pickled Gs are copied into it. ``Gs.run`` is expected to
    return four arrays ``(images, x_v, n_v, m_v)``; each seed writes
    ``seed<seed>.png`` plus the ``-nv``, ``-xv``, ``-mv``, ``-xvs``, ``-mvs``
    and ``-fmvs`` grids into the run directory.
    """
    generator_args = EasyDict(func_name='training.' + model + '.G_main')
    generator_args.fmap_base = 8 << 10
    dataset_args = EasyDict(tfrecord_dir=dataset_name)

    tflib.init_tf()
    training_set = dataset.load_dataset(
        data_dir=dnnlib.convert_path(data_dir), verbose=True, **dataset_args)

    print('Constructing networks...')
    Gs = tflib.Network('G',
                       num_channels=training_set.shape[0],
                       resolution=training_set.shape[1],
                       label_size=training_set.label_size,
                       **generator_args)
    print('Loading networks from "%s"...' % network_pkl)
    _, _, pretrained_Gs = pretrained_networks.load_networks(network_pkl)
    Gs.copy_vars_from(pretrained_Gs)

    noise_vars = []
    for var_name, var in Gs.components.synthesis.vars.items():
        if var_name.startswith('noise'):
            noise_vars.append(var)

    # No output_transform is set, so images are used as the network returns
    # them and converted with drange=[-1, 1] below.
    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.randomize_noise = False
    if truncation_psi is not None:
        Gs_kwargs.truncation_psi = truncation_psi

    def save_grid(array, suffix, seed):
        # Write one auxiliary grid as seed<seed>-<suffix>.png.
        path = dnnlib.make_run_dir_path('seed%04d-%s.png' % (seed, suffix))
        misc.save_image_grid(array, path, drange=[-1, 1])

    total = len(seeds)
    for seed_idx, seed in enumerate(seeds):
        print('Generating image for seed %d (%d/%d) ...' %
              (seed, seed_idx, total))
        # One RandomState per seed drives the latent first, then the noise.
        rnd = np.random.RandomState(seed)
        z = rnd.randn(1, *Gs.input_shape[1:])  # [minibatch, component]
        noise_values = {
            var: rnd.randn(*var.shape.as_list())
            for var in noise_vars
        }  # [height, width]
        tflib.set_vars(noise_values)

        images, x_v, n_v, m_v = Gs.run(z, None, **Gs_kwargs)
        print(images.shape, n_v.shape, x_v.shape, m_v.shape)

        image_path = dnnlib.make_run_dir_path('seed%04d.png' % seed)
        misc.convert_to_pil_image(images[0], drange=[-1, 1]).save(image_path)

        save_grid(adjust_range(n_v), 'nv', seed)
        print(np.linalg.norm(x_v - m_v))
        save_grid(adjust_range(x_v).transpose([1, 0, 2, 3]), 'xv', seed)
        save_grid(adjust_range(m_v).transpose([1, 0, 2, 3]), 'mv', seed)
        save_grid(adjust_range(clip(x_v, 'cat')), 'xvs', seed)
        save_grid(adjust_range(clip(m_v, 'ss')), 'mvs', seed)
        save_grid(adjust_range(clip(m_v, 'ffhq')), 'fmvs', seed)
def training_loop( G_args = {}, # Options for generator network. D_args = {}, # Options for discriminator network. G_opt_args = {}, # Options for generator optimizer. D_opt_args = {}, # Options for discriminator optimizer. G_loss_args = {}, # Options for generator loss. D_loss_args = {}, # Options for discriminator loss. dataset_args = {}, # Options for dataset.load_dataset(). sched_args = {}, # Options for train.TrainingSchedule. grid_args = {}, # Options for train.setup_snapshot_image_grid(). setname = None, # Model name tf_config = {}, # Options for tflib.init_tf(). G_smoothing_kimg = 10.0, # Half-life of the running average of generator weights. minibatch_repeats = 4, # Number of minibatches to run before adjusting training parameters. lazy_regularization = True, # Perform regularization as a separate training step? G_reg_interval = 4, # How often the perform regularization for G? Ignored if lazy_regularization=False. D_reg_interval = 16, # How often the perform regularization for D? Ignored if lazy_regularization=False. reset_opt_for_new_lod = True, # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced? total_kimg = 25000, # Total length of the training, measured in thousands of real images. mirror_augment = False, # Enable mirror augment? mirror_augment_v = False, # Enable mirror augment vertically? drange_net = [-1,1], # Dynamic range used when feeding image data to the networks. image_snapshot_ticks = 50, # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'. network_snapshot_ticks = 50, # How often to save network snapshots? None = only save 'networks-final.pkl'. save_tf_graph = False, # Include full TensorFlow computation graph in the tfevents file? save_weight_histograms = False, # Include weight histograms in the tfevents file? resume_pkl = 'latest', # Network pickle to resume training from, None = train from scratch. resume_kimg = 0.0, # Assumed training progress at the beginning. 
Affects reporting and training schedule. resume_time = 0.0, # Assumed wallclock time at the beginning. Affects reporting. restore_partial_fn = None, # Filename of network for partial restore resume_with_new_nets = False): # Construct new networks according to G_args and D_args before resuming training? # Initialize dnnlib and TensorFlow. tflib.init_tf(tf_config) num_gpus = dnnlib.submit_config.num_gpus # Load training set. training_set = dataset.load_dataset(verbose=True, **dataset_args) # custom resolution - for saved model name below resolution = training_set.resolution if training_set.init_res != [4,4]: init_res_str = '-%dx%d' % (training_set.init_res[0], training_set.init_res[1]) else: init_res_str = '' ext = 'png' if training_set.shape[0] == 4 else 'jpg' print(' model base resolution', resolution) grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(training_set, **grid_args) misc.save_image_grid(grid_reals, dnnlib.make_run_dir_path('_reals.%s'%ext), drange=training_set.dynamic_range, grid_size=grid_size) # Construct or load networks. 
with tf.device('/gpu:0'): if resume_pkl is None or resume_with_new_nets: print(' Constructing networks...') G = tflib.Network('G', num_channels=training_set.shape[0], resolution=resolution, label_size=training_set.label_size, **G_args) D = tflib.Network('D', num_channels=training_set.shape[0], resolution=resolution, label_size=training_set.label_size, **D_args) Gs = G.clone('Gs') if resume_pkl is not None: if resume_pkl == 'latest': resume_pkl, resume_kimg = misc.locate_latest_pkl(dnnlib.submit_config.run_dir_root) elif resume_pkl == 'restore_partial': print(' Restore partially...') # Initialize networks G = tflib.Network('G', num_channels=training_set.shape[0], resolution=resolution, label_size=training_set.label_size, **G_args) D = tflib.Network('D', num_channels=training_set.shape[0], resolution=resolution, label_size=training_set.label_size, **D_args) Gs = G.clone('Gs') # Load pre-trained networks assert restore_partial_fn != None G_partial, D_partial, Gs_partial = pickle.load(open(restore_partial_fn, 'rb')) # Restore (subset of) pre-trained weights (only parameters that match both name and shape) G.copy_compatible_trainables_from(G_partial) D.copy_compatible_trainables_from(D_partial) Gs.copy_compatible_trainables_from(Gs_partial) else: if resume_pkl is not None and resume_kimg == 0: resume_pkl, resume_kimg = misc.locate_latest_pkl(resume_pkl) print(' Loading networks from "%s", kimg %.3g' % (resume_pkl, resume_kimg)) rG, rD, rGs = misc.load_pkl(resume_pkl) if resume_with_new_nets: G.copy_vars_from(rG) D.copy_vars_from(rD) Gs.copy_vars_from(rGs) else: G, D, Gs = rG, rD, rGs # Print layers if needed and generate initial image snapshot # G.print_layers(); D.print_layers() sched = training_schedule(cur_nimg=total_kimg*1000, training_set=training_set, **sched_args) grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:]) grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu) 
misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes_init.%s'%ext), drange=drange_net, grid_size=grid_size) # Setup training inputs. print(' Building TensorFlow graph...') with tf.name_scope('Inputs'), tf.device('/cpu:0'): lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[]) lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[]) minibatch_size_in = tf.placeholder(tf.int32, name='minibatch_size_in', shape=[]) minibatch_gpu_in = tf.placeholder(tf.int32, name='minibatch_gpu_in', shape=[]) minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in * num_gpus) Gs_beta = 0.5 ** tf.div(tf.cast(minibatch_size_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0 # Setup optimizers. G_opt_args = dict(G_opt_args) D_opt_args = dict(D_opt_args) for args, reg_interval in [(G_opt_args, G_reg_interval), (D_opt_args, D_reg_interval)]: args['minibatch_multiplier'] = minibatch_multiplier args['learning_rate'] = lrate_in if lazy_regularization: mb_ratio = reg_interval / (reg_interval + 1) args['learning_rate'] *= mb_ratio if 'beta1' in args: args['beta1'] **= mb_ratio if 'beta2' in args: args['beta2'] **= mb_ratio G_opt = tflib.Optimizer(name='TrainG', **G_opt_args) D_opt = tflib.Optimizer(name='TrainD', **D_opt_args) G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args) D_reg_opt = tflib.Optimizer(name='RegD', share=D_opt, **D_opt_args) # Build training graph for each GPU. data_fetch_ops = [] for gpu in range(num_gpus): with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu): # Create GPU-specific shadow copies of G and D. G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow') D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow') # Fetch training data via temporary variables. 
with tf.name_scope('DataFetch'): sched = training_schedule(cur_nimg=int(resume_kimg*1000), training_set=training_set, **sched_args) reals_var = tf.Variable(name='reals', trainable=False, initial_value=tf.zeros([sched.minibatch_gpu] + training_set.shape)) labels_var = tf.Variable(name='labels', trainable=False, initial_value=tf.zeros([sched.minibatch_gpu, training_set.label_size])) reals_write, labels_write = training_set.get_minibatch_tf() reals_write, labels_write = process_reals(reals_write, labels_write, lod_in, mirror_augment, mirror_augment_v, training_set.dynamic_range, drange_net) reals_write = tf.concat([reals_write, reals_var[minibatch_gpu_in:]], axis=0) labels_write = tf.concat([labels_write, labels_var[minibatch_gpu_in:]], axis=0) data_fetch_ops += [tf.assign(reals_var, reals_write)] data_fetch_ops += [tf.assign(labels_var, labels_write)] reals_read = reals_var[:minibatch_gpu_in] labels_read = labels_var[:minibatch_gpu_in] # Evaluate loss functions. lod_assign_ops = [] if 'lod' in G_gpu.vars: lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)] if 'lod' in D_gpu.vars: lod_assign_ops += [tf.assign(D_gpu.vars['lod'], lod_in)] with tf.control_dependencies(lod_assign_ops): with tf.name_scope('G_loss'): G_loss, G_reg = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, **G_loss_args) with tf.name_scope('D_loss'): D_loss, D_reg = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, reals=reals_read, labels=labels_read, **D_loss_args) # Register gradients. 
if not lazy_regularization: if G_reg is not None: G_loss += G_reg if D_reg is not None: D_loss += D_reg else: if G_reg is not None: G_reg_opt.register_gradients(tf.reduce_mean(G_reg * G_reg_interval), G_gpu.trainables) if D_reg is not None: D_reg_opt.register_gradients(tf.reduce_mean(D_reg * D_reg_interval), D_gpu.trainables) G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables) D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables) # Setup training ops. data_fetch_op = tf.group(*data_fetch_ops) G_train_op = G_opt.apply_updates() D_train_op = D_opt.apply_updates() G_reg_op = G_reg_opt.apply_updates(allow_no_op=True) D_reg_op = D_reg_opt.apply_updates(allow_no_op=True) Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta) # Finalize graph. with tf.device('/gpu:0'): try: peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse() except tf.errors.NotFoundError: peak_gpu_mem_op = tf.constant(0) tflib.init_uninitialized_vars() # print('Initializing logs...') summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path()) if save_tf_graph: summary_log.add_graph(tf.get_default_graph()) if save_weight_histograms: G.setup_weight_histograms(); D.setup_weight_histograms() print(' Training for %d kimg (%d left) \n' % (total_kimg, total_kimg-resume_kimg)) dnnlib.RunContext.get().update('', cur_epoch=resume_kimg, max_epoch=total_kimg) maintenance_time = dnnlib.RunContext.get().get_last_update_interval() cur_nimg = int(resume_kimg * 1000) cur_tick = -1 tick_start_nimg = cur_nimg prev_lod = -1.0 running_mb_counter = 0 while cur_nimg < total_kimg * 1000: if dnnlib.RunContext.get().should_stop(): break # Choose training parameters and configure training ops. 
sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, **sched_args) assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0 training_set.configure(sched.minibatch_gpu) # , sched.lod if reset_opt_for_new_lod: if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod): G_opt.reset_optimizer_state(); D_opt.reset_optimizer_state() prev_lod = sched.lod # Run training ops. feed_dict = {lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_size_in: sched.minibatch_size, minibatch_gpu_in: sched.minibatch_gpu} for _repeat in range(minibatch_repeats): rounds = range(0, sched.minibatch_size, sched.minibatch_gpu * num_gpus) run_G_reg = (lazy_regularization and running_mb_counter % G_reg_interval == 0) run_D_reg = (lazy_regularization and running_mb_counter % D_reg_interval == 0) cur_nimg += sched.minibatch_size running_mb_counter += 1 # Fast path without gradient accumulation. if len(rounds) == 1: tflib.run([G_train_op, data_fetch_op], feed_dict) if run_G_reg: tflib.run(G_reg_op, feed_dict) tflib.run([D_train_op, Gs_update_op], feed_dict) if run_D_reg: tflib.run(D_reg_op, feed_dict) # Slow path with gradient accumulation. else: for _round in rounds: tflib.run(G_train_op, feed_dict) if run_G_reg: for _round in rounds: tflib.run(G_reg_op, feed_dict) tflib.run(Gs_update_op, feed_dict) for _round in rounds: tflib.run(data_fetch_op, feed_dict) tflib.run(D_train_op, feed_dict) if run_D_reg: for _round in rounds: tflib.run(D_reg_op, feed_dict) # Perform maintenance tasks once per tick. 
done = (cur_nimg >= total_kimg * 1000) if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done: cur_tick += 1 cur_time = time.time() tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0 tick_start_nimg = cur_nimg tick_time = dnnlib.RunContext.get().get_time_since_last_update() total_time = dnnlib.RunContext.get().get_time_since_start() + resume_time if sched.lod == 0: left_kimg = total_kimg - cur_nimg / 1000 left_sec = left_kimg * tick_time / tick_kimg finaltime = time.asctime(time.localtime(cur_time + left_sec)) msg_final = '%ss left till %s ' % (shortime(left_sec), finaltime[11:16]) else: msg_final = '' # Report progress. # print('tick %-4d kimg %-6.1f lod %-5.2f minibch %-3d:%d time %-8s min/tick %-6.3g %s sec/kimg %-7.3g gpumem %-4.1f %d lr %.2g ' % ( print('tick %-4d kimg %-6.1f time %-8s %s min/tick %-6.3g sec/kimg %-7.3g gpumem %-4.1f lr %.2g ' % ( autosummary('Progress/tick', cur_tick), autosummary('Progress/kimg', cur_nimg / 1000.0), # autosummary('Progress/lod', sched.lod), # autosummary('Progress/minibatch', sched.minibatch_size), # autosummary('Progress/minibatch_gpu', sched.minibatch_gpu), dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)), msg_final, autosummary('Timing/min_per_tick', tick_time / 60), autosummary('Timing/sec_per_kimg', tick_time / tick_kimg), # autosummary('Timing/maintenance_sec', maintenance_time), autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30), sched.G_lrate)) autosummary('Timing/total_hours', total_time / (60.0 * 60.0)) autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0)) # Save snapshots. 
if image_snapshot_ticks is not None and (cur_tick % image_snapshot_ticks == 0 or done): grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu) misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fake-%04d.%s' % (cur_nimg // 1000, ext)), drange=drange_net, grid_size=grid_size) if network_snapshot_ticks is not None and (cur_tick % network_snapshot_ticks == 0 or done): pkl = dnnlib.make_run_dir_path('snapshot-%d-%s%s-%04d.pkl' % (resolution, setname[-1], init_res_str, cur_nimg // 1000)) misc.save_pkl((G, D, Gs), pkl) misc.save_pkl((Gs), dnnlib.make_run_dir_path('%s-%d-%s%s-%04d.pkl' % (setname[:-1], resolution, setname[-1], init_res_str, cur_nimg // 1000))) # Update summaries and RunContext. tflib.autosummary.save_summaries(summary_log, cur_nimg) dnnlib.RunContext.get().update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg) maintenance_time = dnnlib.RunContext.get().get_last_update_interval() - tick_time # Save final snapshot. misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path('snapshot-%d-%s%s-final.pkl' % (resolution, setname[-1], init_res_str))) misc.save_pkl((Gs), dnnlib.make_run_dir_path('%s-%d-%s%s-final.pkl' % (setname[:-1], resolution, setname[-1], init_res_str))) # All done. summary_log.close() training_set.close()
def generate_interpolation_between(submit_config, network_pkl, truncation_psi, latents_file_start, latents_file_end, num_steps, minibatch_size=4):
    """Render a strip of images interpolating linearly between two latents.

    The start and end latents are read from JSON files; only the first latent
    stored in each file is used. ``num_steps`` equally spaced points are
    produced: the start latent itself is not rendered, and the last point is
    the end latent (up to floating-point accumulation). Each point is passed
    through the mapping network, truncated towards the average dlatent and
    synthesized. The images are saved as ``interpolation.png`` in the run
    directory, laid out on a ``(num_steps, 1)`` grid.

    Args:
        submit_config: Not referenced by this function; kept so the call
            signature stays unchanged.
        network_pkl: Network pickle handed to ``pretrained_networks.load_networks``.
        truncation_psi: Scale applied to ``(w - w_avg)`` before synthesis.
        latents_file_start: Path of the JSON file holding the start latents.
        latents_file_end: Path of the JSON file holding the end latents.
        num_steps: Number of interpolation steps (and output images).
        minibatch_size: Minibatch size forwarded to the synthesis network.

    Raises:
        ValueError: If ``num_steps`` is smaller than 1.
    """
    if num_steps < 1:
        raise ValueError('num_steps must be at least 1, got %r' % (num_steps,))

    print('starting process of generating interpolation between ' + latents_file_start + ' and ' + latents_file_end)
    tflib.init_tf({'rnd.np_random_seed': 1000})

    grid_size = (num_steps, 1)

    # Context managers make sure the files are closed even if JSON parsing fails.
    with open(latents_file_start, 'r') as f:
        start_original_latents = np.array(json.load(f))
    print('loaded start latents from ' + latents_file_start)

    with open(latents_file_end, 'r') as f:
        end_original_latents = np.array(json.load(f))
    print('loaded end latents from ' + latents_file_end)

    print('Loading networks from "%s"...' % network_pkl)
    _G, _D, Gs = pretrained_networks.load_networks(network_pkl)
    w_avg = Gs.get_var('dlatent_avg')  # [component]

    Gs_syn_kwargs = dnnlib.EasyDict()
    Gs_syn_kwargs.randomize_noise = False
    Gs_syn_kwargs.minibatch_size = minibatch_size

    # Work in floating point: if the JSON happens to contain only integers,
    # an integer array would silently truncate every fractional increment.
    start_ltnts = np.asarray(start_original_latents[0], dtype=np.float64)
    end_ltnts = np.asarray(end_original_latents[0], dtype=np.float64)
    each_step_increment = (end_ltnts - start_ltnts) / num_steps
    print(str(list(each_step_increment)))

    print('Generating W vectors...')
    all_latents = []
    ltnt = start_ltnts
    for _ in range(grid_size[0] * grid_size[1]):
        # Out-of-place add creates a fresh array per step, so the entries
        # collected in all_latents never alias each other.
        ltnt = ltnt + each_step_increment
        all_latents.append(ltnt)

    all_z = np.stack(all_latents)
    all_w = Gs.components.mapping.run(all_z, None)  # [minibatch, layer, component]
    all_w = w_avg + (all_w - w_avg) * truncation_psi  # [minibatch, layer, component]

    print('Generating images...')
    all_images = Gs.components.synthesis.run(all_w, **Gs_syn_kwargs)  # [minibatch, height, width, channel]

    misc.save_image_grid(all_images, dnnlib.make_run_dir_path('interpolation.png'), drange=[-1, 1], grid_size=grid_size)
def training_loop(
    submit_config,
    G_args = {}, # Options for generator network.
    D_args = {}, # Options for discriminator network.
    G_opt_args = {}, # Options for generator optimizer.
    D_opt_args = {}, # Options for discriminator optimizer.
    G_loss_args = {}, # Options for generator loss.
    D_loss_args = {}, # Options for discriminator loss.
    dataset_args = {}, # Options for the dataset.
    sched_args = {}, # Options for the training schedule.
    grid_args = {}, # Options for setup_snapshot_image_grid().
    metric_arg_list = [], # Options for the metrics.
    tf_config = {}, # Options for tflib.init_tf().
    G_smoothing_kimg = 10.0, # Half-life of the running average of generator weights.
    D_repeats = 1, # How many times the discriminator is trained per G iteration.
    minibatch_repeats = 4, # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod = True, # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg = 15000, # Total length of the training, measured in thousands of real images.
    mirror_augment = False, # Enable mirror augmentation?
    drange_net = [-1,1], # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks = 1, # How often to export image snapshots?
    network_snapshot_ticks = 10, # How often to export network snapshots?
    save_tf_graph = False, # Include the full TensorFlow computation graph in the tfevents file?
    save_weight_histograms = False, # Include weight histograms in the tfevents file?
    resume_run_id = None, # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot = None, # Snapshot index to resume training from, None = autodetect.
    resume_kimg = 0.0, # Assumed training progress at the beginning. Affects reporting and the training schedule.
    resume_time = 0.0): # Assumed wallclock time at the beginning. Affects reporting.
    """Train the G/D network pair and periodically export snapshots.

    The function loads the training set, constructs the networks (or loads
    them from a previous run when ``resume_run_id`` is given), builds the
    per-GPU loss and gradient graph, and then alternates discriminator and
    generator updates until ``total_kimg`` thousand images have been shown or
    the run context asks to stop. Once per tick it prints progress, and at
    the configured tick intervals it saves an image grid and a network
    pickle and runs the metrics on that pickle. The final networks are
    written to ``network-final.pkl`` in ``submit_config.run_dir``.
    """

    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load the training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args)

    # Construct the networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...')
            G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args)
            D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args)
            Gs = G.clone('Gs')
    G.print_layers(); D.print_layers()

    # Build the computation graph and the optimizers.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[]) # tf.placeholder: think of it as a formal parameter used while defining the graph; the concrete value is supplied at run time.
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        Gs_beta = 0.5 ** tf.div(tf.cast(minibatch_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0

    G_opt = tflib.Optimizer(name='TrainG', learning_rate=lrate_in, **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', learning_rate=lrate_in, **D_opt_args)
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # GPU 0 uses the original networks; every other GPU gets a '_shadow' clone.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            lod_assign_ops = [tf.assign(G_gpu.find_var('lod'), lod_in), tf.assign(D_gpu.find_var('lod'), lod_in)]
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(reals, lod_in, mirror_augment, training_set.dynamic_range, drange_net)
            # The control dependencies make the 'lod' assignments run before the losses are evaluated.
            with tf.name_scope('G_loss'), tf.control_dependencies(lod_assign_ops):
                G_loss = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_split, **G_loss_args)
            with tf.name_scope('D_loss'), tf.control_dependencies(lod_assign_ops):
                D_loss = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_split, reals=reals, labels=labels, **D_loss_args)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            # Fall back to a constant 0 when the memory-stats op is unavailable.
            peak_gpu_mem_op = tf.constant(0)

    # Set up the snapshot image grid.
    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(G, training_set, **grid_args)
    sched = training_schedule(cur_nimg=total_kimg*1000, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args)
    grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch//submit_config.num_gpus)

    # Set up the run directory.
    print('Setting up run dir...')
    misc.save_image_grid(grid_reals, os.path.join(submit_config.run_dir, 'reals.png'), drange=training_set.dynamic_range, grid_size=grid_size)
    misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg), drange=drange_net, grid_size=grid_size)
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms(); D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    # Train.
    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop(): break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus, sched.lod)
        if reset_opt_for_new_lod:
            # Reset when lod crosses an integer boundary (its floor or ceil changed).
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state(); D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run([D_train_op, Gs_update_op], {lod_in: sched.lod, lrate_in: sched.D_lrate, minibatch_in: sched.minibatch})
                cur_nimg += sched.minibatch
            tflib.run([G_train_op], {lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_in: sched.minibatch})

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print('tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f' % (
                autosummary('Progress/tick', cur_tick),
                autosummary('Progress/kimg', cur_nimg / 1000.0),
                autosummary('Progress/lod', sched.lod),
                autosummary('Progress/minibatch', sched.minibatch),
                dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                autosummary('Timing/sec_per_tick', tick_time),
                autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                autosummary('Timing/maintenance_sec', maintenance_time),
                autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch//submit_config.num_gpus)
                misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size)
            # 'cur_tick == 1' additionally forces a network snapshot (and metric run) on the first tick.
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(submit_config.run_dir, 'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl, run_dir=submit_config.run_dir, num_gpus=submit_config.num_gpus, tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Save the final results.
    misc.save_pkl((G, D, Gs), os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()

    ctx.close()
def recon_real_one_img(network_pkl, img, mask_dir, num_snapshots, recontype, superres_factor, num_steps):
    """Corrupt one real image with a forward model and reconstruct it.

    The image is loaded with OpenCV, converted to a 1xCxHxW array in the
    [-1, 1] range and saved as ``<name>-true.png`` in the run directory. It
    is then corrupted with the "true" forward model, and a
    ``projector.Projector`` built on the (possibly different) ``forward``
    model is run on the corrupted image to obtain a reconstruction. For
    ``recontype == 'inpaint'`` the reconstruction is additionally merged
    with the corrupted image using ``forward.mask`` and saved as
    ``<name>-merged.png``.

    Args:
        network_pkl: Network pickle handed to ``pretrained_networks.load_networks``.
        img: Path of the input image.
        mask_dir: Forwarded to ``constructForwardModel``.
        num_snapshots: Forwarded to ``project_image``.
        recontype: Reconstruction type; ``'inpaint'`` enables the merge step.
        superres_factor: Forwarded to ``constructForwardModel``.
        num_steps: Number of steps given to ``projector.Projector``.

    Returns:
        Tuple ``(imgCorrupted, imgRecon, imgMerged, imgTrue)``; ``imgMerged``
        is ``None`` unless ``recontype == 'inpaint'``.

    Raises:
        IOError: If the image file cannot be read by OpenCV.
    """
    if not runSerial:
        # Redirect this process's stdout to a per-PID log file. The handle is
        # deliberately left open: it *is* stdout for the rest of the process.
        os.makedirs('logs', exist_ok=True)
        sys.stdout = open('logs/' + str(os.getpid()) + ".out", "w")
        #sys.stderr = open(str(os.getpid()) + ".err", "w")

    print('Loading networks from "%s"...' % network_pkl)
    _G, _D, Gs = pretrained_networks.load_networks(network_pkl)

    imgSize, nrChannelsNet, fakeImg = getImgSize(Gs)

    imgShort = img.split('/')[-1]  # file name without the directory part
    forward, forwardTrue = constructForwardModel(recontype, imgSize, nrChannelsNet, mask_dir, imgShort, superres_factor)

    print('Loading image "%s"...' % img)
    # Without IMREAD_UNCHANGED, cv2 automatically converts to 3 channels.
    image = cv2.imread(img, cv2.IMREAD_UNCHANGED)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise IOError('Could not read image "%s"' % img)

    ndim = len(image.shape)
    print('Input image has shape:', image.shape)
    if ndim == 2:
        # Grayscale: add explicit batch and channel axes.
        nrChannels = 1
        image = image.reshape(1, image.shape[0], image.shape[1], 1)
    else:
        nrChannels = image.shape[-1]
        if nrChannels == 3:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # convert from BGR to RGB
        image = image.reshape(1, image.shape[0], image.shape[1], image.shape[2])

    #fakeImg = forward(fakeImg) # sample of f(G(z)) for checking dimensions and channels
    # NOTE(review): image is NHWC here, so shape[1] is the row count and
    # shape[2] the column count; the keyword names look swapped. Left as-is
    # because checkDimensions is not visible from here -- confirm.
    checkDimensions(fakeImg, nrChannels, inputWidth=image.shape[1], inputHeight=image.shape[2])

    image = np.transpose(image, (0, 3, 1, 2))  # convert from NHWC to NCHW
    print('reshaped shape:', image.shape)
    image = misc.adjust_dynamic_range(image, [0, 255], [-1, 1])

    # Save the true image. splitext strips the extension whatever its length
    # (slicing off the last four characters mangled e.g. '.jpeg' names).
    png_prefix = dnnlib.make_run_dir_path(os.path.splitext(imgShort)[0] + '-')
    misc.save_image_grid(image, png_prefix + 'true.png', drange=[-1, 1])
    imgTrue = image  # this is the true image

    # Generate the corrupted image with the true forward corruption model.
    imgCorrupted = forwardTrue(tf.convert_to_tensor(image))
    imgCorrupted = imgCorrupted.eval()

    proj = projector.Projector(forward, num_steps)
    proj.set_network(Gs)
    project_image(proj, targets=imgCorrupted, png_prefix=png_prefix, num_snapshots=num_snapshots)
    imgRecon = proj.get_clean_images()

    if recontype == 'inpaint':
        # For inpainting, also merge the reconstruction with the target
        # image: where the mask is true take the reconstruction, elsewhere
        # keep the corrupted image.
        imgMerged = tf.where(forward.mask, imgRecon, imgCorrupted).eval()
        misc.save_image_grid(imgMerged, png_prefix + 'merged.png', drange=[-1, 1])
    else:
        imgMerged = None

    return imgCorrupted, imgRecon, imgMerged, imgTrue
def training_loop_hd(
        I_args={},  # Options for the I network (passed to tflib.Network('I', ...)).
        M_args={},  # Options for the M network (passed to tflib.Network('M', ...)).
        I_info_args={},  # Options for the I_info network (only built if use_hd_with_cls).
        I_opt_args={},  # Options for the I optimizers.
        I_loss_args={},  # Options for the I loss.
        resume_G_pkl=None,  # G network pickle to help training.
        dataset_args={},  # Options for dataset.load_dataset().
        sched_args={},  # Options for train.TrainingSchedule.
        grid_args={},  # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],  # Options for MetricGroup.
        tf_config={},  # Options for tflib.init_tf().
        data_dir=None,  # Directory to load datasets from.
        I_smoothing_kimg=10.0,  # Half-life of the running average of the I weights (kept in Is).
        minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
        lazy_regularization=True,  # Perform regularization as a separate training step?
        I_reg_interval=4,  # How often to perform regularization for I? Ignored if lazy_regularization=False.
        reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg=25000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=False,  # Enable mirror augment?
        drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=50,  # How often to save image snapshots? None = only save 'reals.png' and 'fakes_init.png'.
        network_snapshot_ticks=50,  # How often to save network snapshots? None = only save 'network-final.pkl'.
        save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_pkl=None,  # Network pickle to resume training from, None = train from scratch.
        resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
        resume_with_new_nets=False,  # Construct new networks according to I_args and M_args before resuming training?
        traversal_grid=False,  # Used for disentangled representation learning.
        n_discrete=0,  # Number of discrete latents in model.
        n_continuous=10,  # Number of continuous latents in model.
        n_samples_per=4,  # Number of samples for each line in traversal.
        use_hd_with_cls=False,  # If use info_loss.
        resolution_manual=1024,  # Resolution of generated images.
        use_level_training=False,  # If use level training (hierarchical optimization strategy).
        level_I_kimg=1000,  # Number of kimg of tick for I_level training.
        use_std_in_m=False,  # If output prior std in M net.
        prior_latent_size=512,  # Prior latent size.
        use_hyperplane=False,  # If use hyperplane model.
        pretrained_type='with_stylegan2'):  # Pretrained type for G.
    """Train the I and M networks against a pretrained generator G.

    The function builds (or resumes) the ``I`` and ``M`` networks, loads a
    pretrained generator from ``resume_G_pkl``, writes initial traversal
    snapshots, constructs one optimizer pair per training level and the
    per-GPU loss graph, and then runs the training loop until ``total_kimg``
    thousand images have been processed or the run context asks to stop.
    Once per tick it reports progress; at the configured tick intervals it
    saves image grids and a ``(I, M, Is)`` network pickle and runs the
    metrics on it. The final networks go to ``network-final.pkl``.

    Raises:
        ValueError: If ``pretrained_type`` is neither ``'with_stylegan2'``
            nor ``'with_cascadeVAE'``.
    """
    # Fail fast on an unknown generator type: otherwise G/Gs are never
    # assigned and the function dies much later with a confusing NameError.
    if pretrained_type not in ('with_stylegan2', 'with_cascadeVAE'):
        raise ValueError('Unsupported pretrained_type: %r' % (pretrained_type,))

    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir), verbose=True, **dataset_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(training_set, **grid_args)
    misc.save_image_grid(grid_reals, dnnlib.make_run_dir_path('reals.png'), drange=training_set.dynamic_range, grid_size=grid_size)

    # Construct or load networks.
    with tf.device('/gpu:0'):
        if resume_pkl is None or resume_with_new_nets:
            print('Constructing networks...')
            I = tflib.Network('I', num_channels=training_set.shape[0], resolution=resolution_manual, label_size=training_set.label_size, **I_args)
            M = tflib.Network('M', num_channels=training_set.shape[0], resolution=resolution_manual, label_size=training_set.label_size, **M_args)
            Is = I.clone('Is')
            if use_hd_with_cls:
                I_info = tflib.Network('I_info', num_channels=training_set.shape[0], resolution=resolution_manual, label_size=training_set.label_size, **I_info_args)
        if resume_pkl is not None:
            print('Loading networks from "%s"...' % resume_pkl)
            # NOTE(review): the snapshots written further down contain only
            # (I, M, Is), so this 4-way unpack looks like it cannot succeed on
            # a pickle produced by this very function -- confirm.
            if use_hd_with_cls:
                rI, rM, rIs, rI_info = misc.load_pkl(resume_pkl)
            else:
                rI, rM, rIs = misc.load_pkl(resume_pkl)
            if resume_with_new_nets:
                I.copy_vars_from(rI)
                M.copy_vars_from(rM)
                Is.copy_vars_from(rIs)
                if use_hd_with_cls:
                    I_info.copy_vars_from(rI_info)
            else:
                I = rI
                M = rM
                Is = rIs
                if use_hd_with_cls:
                    I_info = rI_info

        print('Loading generator from "%s"...' % resume_G_pkl)
        if pretrained_type == 'with_stylegan2':
            rG, rD, rGs = misc.load_pkl(resume_G_pkl)
            G = rG
            D = rD
            Gs = rGs
        elif pretrained_type == 'with_cascadeVAE':
            # This pickle holds a single network, used both as G and as Gs.
            rG = misc.load_pkl(resume_G_pkl)
            G = rG
            Gs = rG

    # Print layers and generate initial image snapshot.
    I.print_layers()
    M.print_layers()
    # pdb.set_trace()
    training_set_resolution_log2 = int(np.log2(resolution_manual))
    sched = training_schedule(cur_nimg=total_kimg * 1000, training_set_resolution_log2=training_set_resolution_log2, **sched_args)

    if not use_hyperplane:
        grid_size, grid_latents, grid_labels = get_grid_latents(n_discrete, n_continuous, n_samples_per, G, grid_labels)
        print('grid_size:', grid_size)
        print('grid_latents.shape:', grid_latents.shape)
        print('grid_labels.shape:', grid_labels.shape)
        if resolution_manual >= 256:
            # Keep only the first fifth of the grid at high resolutions.
            grid_size = (grid_size[0], grid_size[1] // 5)
            grid_latents = grid_latents[:grid_latents.shape[0] // 5]
            grid_labels = grid_labels[:grid_labels.shape[0] // 5]
        prior_traj_latents = M.run(grid_latents, is_validation=True, minibatch_size=sched.minibatch_gpu)
        if use_std_in_m:
            # Keep only the first prior_latent_size columns of M's output.
            prior_traj_latents = prior_traj_latents[:, :prior_latent_size]
    else:
        grid_size = (n_samples_per, n_continuous)
        grid_labels = np.tile(grid_labels[:1], (n_continuous * n_samples_per, 1))
        latent_dirs = get_latent_dirs(n_continuous)
        prior_traj_latents = get_prior_traj_by_dirs(latent_dirs, M, n_samples_per, prior_latent_size, grid_labels, sched)
        if resolution_manual >= 256:
            grid_size = (grid_size[0], grid_size[1] // 5)
            prior_traj_latents = prior_traj_latents[:prior_traj_latents.shape[0] // 5]
            grid_labels = grid_labels[:grid_labels.shape[0] // 5]

    print('prior_traj_latents.shape:', prior_traj_latents.shape)
    # pdb.set_trace()
    prior_traj_latents_show = np.reshape(prior_traj_latents, [-1, n_samples_per, prior_latent_size])
    print_traj(prior_traj_latents_show)
    grid_fakes = Gs.run(prior_traj_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu, randomize_noise=True, normalize_latents=False)
    grid_fakes = add_outline(grid_fakes, width=1)
    misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes_init.png'), drange=drange_net, grid_size=grid_size)

    if (n_continuous == 2) and (n_discrete == 0):
        # Two continuous latents only: also render a 2-D prior grid and draw
        # the trajectories on top of it.
        n_per_line = 20
        ex_latent_value = 3
        prior_grid_latents, prior_grid_labels = get_2d_grid_latents(low=-ex_latent_value, high=ex_latent_value, n_per_line=n_per_line, grid_labels=grid_labels)
        grid_showing_fakes = Gs.run(prior_grid_latents, prior_grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu, randomize_noise=True, normalize_latents=False)
        grid_showing_fakes = add_outline(grid_showing_fakes, width=1)
        misc.save_image_grid(grid_showing_fakes, dnnlib.make_run_dir_path('fakes_init_2d_prior_grid.png'), drange=drange_net, grid_size=[n_per_line, n_per_line])
        img_to_draw = Image.open(dnnlib.make_run_dir_path('fakes_init_2d_prior_grid.png'))
        img_to_draw = img_to_draw.convert('RGB')
        img_to_draw = draw_traj_on_prior_grid(img_to_draw, prior_traj_latents_show, ex_latent_value, n_per_line)
        img_to_draw.save(dnnlib.make_run_dir_path('fakes_init_2d_prior_grid_drawn.png'))

    # Number of training levels; each level gets its own optimizer pair.
    if use_level_training:
        ending_level = training_set_resolution_log2 - 1
    else:
        ending_level = 1

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32, name='minibatch_size_in', shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32, name='minibatch_gpu_in', shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in * num_gpus)
        Is_beta = 0.5**tf.div(tf.cast(minibatch_size_in, tf.float32), I_smoothing_kimg * 1000.0) if I_smoothing_kimg > 0.0 else 0.0

    # Setup optimizers.
    I_opt_args = dict(I_opt_args)  # copy, so the caller's dict is not modified
    for args, reg_interval in [(I_opt_args, I_reg_interval)]:
        args['minibatch_multiplier'] = minibatch_multiplier
        args['learning_rate'] = lrate_in
        if lazy_regularization:
            mb_ratio = reg_interval / (reg_interval + 1)
            args['learning_rate'] *= mb_ratio
            if 'beta1' in args: args['beta1'] **= mb_ratio
            if 'beta2' in args: args['beta2'] **= mb_ratio
    I_opts = []
    I_reg_opts = []
    for n_level in range(ending_level):
        I_opts.append(tflib.Optimizer(name='TrainI_%d' % n_level, **I_opt_args))
        I_reg_opts.append(tflib.Optimizer(name='RegI_%d' % n_level, share=I_opts[-1], **I_opt_args))

    # Build training graph for each GPU.
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):

            # Create GPU-specific shadow copies of I, M and G.
            I_gpu = I if gpu == 0 else I.clone(I.name + '_shadow')
            M_gpu = M if gpu == 0 else M.clone(M.name + '_shadow')
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            if use_hd_with_cls:
                I_info_gpu = I_info if gpu == 0 else I_info.clone(I_info.name + '_shadow')

            # Evaluate loss functions.
            lod_assign_ops = []
            I_losses = []
            I_regs = []
            if 'lod' in I_gpu.vars: lod_assign_ops += [tf.assign(I_gpu.vars['lod'], lod_in)]
            if 'lod' in M_gpu.vars: lod_assign_ops += [tf.assign(M_gpu.vars['lod'], lod_in)]
            for n_level in range(ending_level):
                with tf.control_dependencies(lod_assign_ops):
                    with tf.name_scope('I_loss_%d' % n_level):
                        if use_hd_with_cls:
                            I_loss, I_reg = dnnlib.util.call_func_by_name(I=I_gpu, M=M_gpu, G=G_gpu, I_info=I_info_gpu, opt=I_opts[n_level], training_set=training_set, minibatch_size=minibatch_gpu_in, **I_loss_args)
                        else:
                            I_loss, I_reg = dnnlib.util.call_func_by_name(I=I_gpu, M=M_gpu, G=G_gpu, opt=I_opts[n_level], n_levels=(n_level + 1) if use_level_training else None, training_set=training_set, minibatch_size=minibatch_gpu_in, **I_loss_args)
                        I_losses.append(I_loss)
                        I_regs.append(I_reg)

                # Register gradients.
                if not lazy_regularization:
                    if I_regs[n_level] is not None: I_losses[n_level] += I_regs[n_level]
                else:
                    if I_regs[n_level] is not None: I_reg_opts[n_level].register_gradients(tf.reduce_mean(I_regs[n_level] * I_reg_interval), I_gpu.trainables)

                # The main loss updates M and I together (plus I_info if used).
                if use_hd_with_cls:
                    MIIinfo_gpu_trainables = collections.OrderedDict(
                        list(M_gpu.trainables.items()) +
                        list(I_gpu.trainables.items()) +
                        list(I_info_gpu.trainables.items()))
                    I_opts[n_level].register_gradients(tf.reduce_mean(I_losses[n_level]), MIIinfo_gpu_trainables)
                else:
                    MI_gpu_trainables = collections.OrderedDict(
                        list(M_gpu.trainables.items()) +
                        list(I_gpu.trainables.items()))
                    I_opts[n_level].register_gradients(tf.reduce_mean(I_losses[n_level]), MI_gpu_trainables)

    # Setup training ops.
    I_train_ops = []
    I_reg_ops = []
    for n_level in range(ending_level):
        I_train_ops.append(I_opts[n_level].apply_updates())
        I_reg_ops.append(I_reg_opts[n_level].apply_updates(allow_no_op=True))
    Is_update_op = Is.setup_as_moving_average_of(I, beta=Is_beta)

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            # Fall back to a constant 0 when the memory-stats op is unavailable.
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        I.setup_weight_histograms()
        M.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop(): break

        # Current training level: advances every level_I_kimg kimg and is
        # capped at the index of the last optimizer pair.
        n_level = 0 if not use_level_training else min(cur_nimg // (level_I_kimg * 1000), training_set_resolution_log2 - 2)

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg, training_set_resolution_log2=training_set_resolution_log2, **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu, sched.lod)
        if reset_opt_for_new_lod:
            # Reset when lod crosses an integer boundary (its floor or ceil changed).
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod):
                I_opts[n_level].reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        feed_dict = {
            lod_in: sched.lod,
            lrate_in: sched.I_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu
        }
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size, sched.minibatch_gpu * num_gpus)
            run_I_reg = (lazy_regularization and running_mb_counter % I_reg_interval == 0)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                tflib.run(I_train_ops[n_level], feed_dict)
                if run_I_reg:
                    tflib.run(I_reg_ops[n_level], feed_dict)
                tflib.run([Is_update_op], feed_dict)

            # Slow path with gradient accumulation.
            else:
                for _round in rounds:
                    tflib.run(I_train_ops[n_level], feed_dict)
                if run_I_reg:
                    for _round in rounds:
                        tflib.run(I_reg_ops[n_level], feed_dict)
                tflib.run(Is_update_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        # NOTE(review): the threshold multiplies tick_kimg by 100, whereas
        # tick_kimg below is derived with / 1000.0; as written a tick fires
        # ten times more often than sched.tick_kimg suggests. Left unchanged
        # in case it is intentional -- confirm.
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 100 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start() + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch_size),
                   dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (cur_tick % image_snapshot_ticks == 0 or done):
                if not use_hyperplane:
                    prior_traj_latents = M.run(grid_latents, is_validation=True, minibatch_size=sched.minibatch_gpu)
                    if use_std_in_m:
                        prior_traj_latents = prior_traj_latents[:, :prior_latent_size]
                else:
                    prior_traj_latents = get_prior_traj_by_dirs(latent_dirs, M, n_samples_per, prior_latent_size, grid_labels, sched)
                    if resolution_manual >= 256:
                        prior_traj_latents = prior_traj_latents[:prior_traj_latents.shape[0] // 5]
                prior_traj_latents_show = np.reshape(prior_traj_latents, [-1, n_samples_per, prior_latent_size])
                print_traj(prior_traj_latents_show)
                grid_fakes = Gs.run(prior_traj_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu, randomize_noise=True, normalize_latents=False)
                grid_fakes = add_outline(grid_fakes, width=1)
                misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size)
                if (n_continuous == 2) and (n_discrete == 0):
                    n_per_line = 20
                    ex_latent_value = 3
                    prior_grid_latents, prior_grid_labels = get_2d_grid_latents(low=-ex_latent_value, high=ex_latent_value, n_per_line=n_per_line, grid_labels=grid_labels)
                    grid_showing_fakes = Gs.run(prior_grid_latents, prior_grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu, randomize_noise=True, normalize_latents=False)
                    grid_showing_fakes = add_outline(grid_showing_fakes, width=1)
                    misc.save_image_grid(grid_showing_fakes, dnnlib.make_run_dir_path('fakes_2d_prior_grid%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=[n_per_line, n_per_line])
                    img_to_draw = Image.open(dnnlib.make_run_dir_path('fakes_2d_prior_grid%06d.png' % (cur_nimg // 1000)))
                    img_to_draw = img_to_draw.convert('RGB')
                    img_to_draw = draw_traj_on_prior_grid(img_to_draw, prior_traj_latents_show, ex_latent_value, n_per_line)
                    img_to_draw.save(dnnlib.make_run_dir_path('fakes_2d_prior_grid_drawn%06d.png' % (cur_nimg // 1000)))
            if network_snapshot_ticks is not None and (cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((I, M, Is), pkl)
                metrics.run(pkl, run_dir=dnnlib.make_run_dir_path(), data_dir=dnnlib.convert_path(data_dir), num_gpus=num_gpus, tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get().get_last_update_interval() - tick_time

    # Save final snapshot.
    misc.save_pkl((I, M, Is), dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
def training_loop(
        submit_config,
        G_args={},  # Options for generator network.
        D_args={},  # Options for discriminator network.
        G_opt_args={},  # Options for generator optimizer (not used by this variant).
        D_opt_args={},  # Options for discriminator optimizer (not used by this variant).
        G_loss_args={},  # Options for generator loss.
        D_loss_args={},  # Options for discriminator loss.
        dataset_args={},  # Options for dataset.load_dataset().
        sched_args={},  # Options for train.TrainingSchedule.
        grid_args={},  # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],  # Options for MetricGroup.
        tf_config={},  # Options for tflib.init_tf().
        G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
        D_repeats=1,  # How many times the discriminator is trained per G iteration.
        minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
        reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg=15000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=False,  # Enable mirror augment?
        drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=10,  # How often to export image snapshots? None = never.
        network_snapshot_ticks=10,  # How often to export network snapshots? None = never.
        save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_run_id=None,  # Accepted for interface compatibility; resuming is disabled in this variant.
        resume_snapshot=None,  # Accepted for interface compatibility; resuming is disabled in this variant.
        resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0):  # Assumed wallclock time at the beginning. Affects reporting.
    """Train G and D with Horovod, one process per worker.

    Builds fresh `G`, `D` and `Gs` networks, wraps two Adam optimizers in
    `hvd.DistributedOptimizer`, broadcasts the initial variables from rank 0
    and then alternates D and G updates until `total_kimg` thousand images
    have been shown or the run context asks to stop.

    All file outputs (image grids, network pickles, TensorBoard summaries) are
    produced by rank 0 only, so that workers sharing a run directory do not
    write the same paths concurrently.

    The mutable default arguments are kept for interface compatibility; this
    function only reads them and never mutates them.
    """
    is_chief = hvd.rank() == 0
    total_nimg = total_kimg * 1000

    def is_snapshot_tick(tick, interval, finished):
        # `interval=None` disables periodic snapshots instead of raising
        # TypeError on `tick % None`.
        return interval is not None and (tick % interval == 0 or finished)

    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    # TODO(ajay): move init to after graph creation?
    tflib.init_tf(tf_config)

    # Load training set; each Horovod rank reads its own shard.
    print('ajay - config data dir', config.data_dir)
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        num_hosts=hvd.size(),
                                        index=hvd.rank(),
                                        **dataset_args)

    # Construct networks.
    print('Constructing networks...')
    G = tflib.Network('G',
                      num_channels=training_set.shape[0],
                      resolution=training_set.shape[1],
                      label_size=training_set.label_size,
                      **G_args)
    D = tflib.Network('D',
                      num_channels=training_set.shape[0],
                      resolution=training_set.shape[1],
                      label_size=training_set.label_size,
                      **D_args)
    Gs = G.clone('Gs')
    G.print_layers()
    D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_in, tf.float32),
                              G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0

    # Adam hyper-parameters are hard-coded here; G_opt_args / D_opt_args are ignored.
    G_opt = tf.train.AdamOptimizer(learning_rate=lrate_in, beta1=0.0, beta2=0.99, epsilon=1e-8)
    G_opt = hvd.DistributedOptimizer(G_opt)
    D_opt = tf.train.AdamOptimizer(learning_rate=lrate_in, beta1=0.0, beta2=0.99, epsilon=1e-8)
    D_opt = hvd.DistributedOptimizer(D_opt)

    # Single model replica per process; no per-GPU shadow copies.
    G_gpu = G
    D_gpu = D
    lod_assign_ops = [
        tf.assign(G_gpu.find_var('lod'), lod_in),
        tf.assign(D_gpu.find_var('lod'), lod_in)
    ]
    # TODO(ajay): check that a unique minibatch per rank is guaranteed, i.e. sharding is done right.
    reals, labels = training_set.get_minibatch_tf()
    reals = process_reals(reals, lod_in, mirror_augment, training_set.dynamic_range, drange_net)
    with tf.name_scope('G_loss'), tf.control_dependencies(lod_assign_ops):
        G_loss = dnnlib.util.call_func_by_name(G=G_gpu,
                                               D=D_gpu,
                                               opt=G_opt,
                                               training_set=training_set,
                                               minibatch_size=minibatch_split,
                                               **G_loss_args)
    with tf.name_scope('D_loss'), tf.control_dependencies(lod_assign_ops):
        D_loss = dnnlib.util.call_func_by_name(G=G_gpu,
                                               D=D_gpu,
                                               opt=D_opt,
                                               training_set=training_set,
                                               minibatch_size=minibatch_split,
                                               reals=reals,
                                               labels=labels,
                                               **D_loss_args)
    G_grads = G_opt.compute_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
    D_grads = D_opt.compute_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_gradients(G_grads)
    D_train_op = D_opt.apply_gradients(D_grads)

    # Horovod: initialize everything, then make all ranks start from rank 0's values.
    # `tf.initialize_all_variables` is deprecated; this is its drop-in replacement.
    # NOTE(review): this re-runs the initializers of the already-built network
    # variables as well, so Gs may no longer equal G afterwards - confirm intended.
    init_op = tf.global_variables_initializer()
    bcast_op = hvd.broadcast_global_variables(0)
    tf.get_default_session().run([init_op])
    tflib.run([bcast_op])

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(
        G, training_set, **grid_args)
    # TODO(ajay): num_gpus needs to change to hvd.size() when going multi-node.
    sched = training_schedule(cur_nimg=total_nimg,
                              training_set=training_set,
                              num_gpus=submit_config.num_gpus,
                              **sched_args)
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch // submit_config.num_gpus)

    summary_log = None  # Only rank 0 owns a summary writer.
    if is_chief:
        print('Setting up run dir...')
        misc.save_image_grid(grid_reals,
                             os.path.join(submit_config.run_dir, 'reals.png'),
                             drange=training_set.dynamic_range,
                             grid_size=grid_size)
        misc.save_image_grid(grid_fakes,
                             os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg),
                             drange=drange_net,
                             grid_size=grid_size)
        summary_log = tf.summary.FileWriter(submit_config.run_dir)
        if save_tf_graph:
            summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    while cur_nimg < total_nimg:
        if ctx.should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  num_gpus=submit_config.num_gpus,
                                  **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus, sched.lod)
        if reset_opt_for_new_lod:
            lod_changed = (np.floor(sched.lod) != np.floor(prev_lod)
                           or np.ceil(sched.lod) != np.ceil(prev_lod))
            if lod_changed:
                # Plain tf.train optimizers have no reset_optimizer_state();
                # re-run the initializers of their slot variables instead.
                tflib.assert_tf_initialized()
                tflib.run([var.initializer for var in G_opt.variables()])
                tflib.run([var.initializer for var in D_opt.variables()])
        prev_lod = sched.lod

        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run([D_train_op, Gs_update_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.D_lrate,
                    minibatch_in: sched.minibatch
                })
                cur_nimg += sched.minibatch
            tflib.run([G_train_op], {
                lod_in: sched.lod,
                lrate_in: sched.G_lrate,
                minibatch_in: sched.minibatch
            })

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_nimg)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress (rank 0 only).
            if is_chief:
                print(
                    'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s '
                    'sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f ' %
                    (autosummary('Progress/tick', cur_tick),
                     autosummary('Progress/kimg', cur_nimg / 1000.0),
                     autosummary('Progress/lod', sched.lod),
                     autosummary('Progress/minibatch', sched.minibatch),
                     dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                     autosummary('Timing/sec_per_tick', tick_time),
                     autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                     autosummary('Timing/maintenance_sec', maintenance_time)))
                autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
                autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots (rank 0 only, so workers never race on the same files).
            if is_chief and is_snapshot_tick(cur_tick, image_snapshot_ticks, done):
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch // submit_config.num_gpus)
                misc.save_image_grid(grid_fakes,
                                     os.path.join(submit_config.run_dir,
                                                  'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
            # The first tick always gets a network snapshot when snapshots are enabled.
            save_network = (network_snapshot_ticks is not None
                            and (cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1))
            if is_chief and save_network:
                pkl = os.path.join(submit_config.run_dir,
                                   'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                # Metrics are evaluated on a single GPU.
                metrics.run(pkl, run_dir=submit_config.run_dir, num_gpus=1, tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            if is_chief:
                tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    if is_chief:
        misc.save_pkl((G, D, Gs), os.path.join(submit_config.run_dir, 'network-final.pkl'))
        summary_log.close()

    ctx.close()
def training_auto_loop(
        Enc_args={},  # Options for encoder network.
        Dec_args={},  # Options for decoder network.
        opt_args={},  # Options for encoder optimizer.
        loss_args={},  # Options for discriminator loss.
        dataset_args={},  # Options for dataset.load_dataset().
        sched_args={},  # Options for train.TrainingSchedule.
        grid_args={},  # Options for train.setup_snapshot_image_grid().
        tf_config={},  # Options for tflib.init_tf().
        data_dir=None,  # Directory to load datasets from.
        minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
        total_kimg=25000,  # Total length of the training, measured in thousands of real images.
        drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=50,  # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'.
        network_snapshot_ticks=50,  # How often to save network snapshots? None = only save 'networks-final.pkl'.
        save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False):  # Include weight histograms in the tfevents file?
    """Train an 'Encoder' / 'Decoder' network pair with a single shared optimizer.

    The loss named in `loss_args` is evaluated per GPU on real images only and
    its gradients are applied to both networks. `sched_args` is used directly
    as the schedule object (its `tick_kimg` is overwritten with 4), so it must
    provide `minibatch_size`, `minibatch_gpu` and `lrate` as attributes.
    Reconstructions of the snapshot grid and network pickles are written to the
    run directory at the configured tick intervals.
    """
    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus
    total_nimg = total_kimg * 1000
    run_context = dnnlib.RunContext.get

    # Load training set and export the grid of real images.
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir),
                                        verbose=True,
                                        **dataset_args)
    grid_size, grid_reals, _ = misc.setup_snapshot_image_grid(training_set, **grid_args)
    misc.save_image_grid(grid_reals,
                         dnnlib.make_run_dir_path('reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)

    # Construct networks.
    with tf.device('/gpu:0'):
        print('Constructing networks...')
        Enc = tflib.Network('Encoder',
                            resolution=training_set.shape[1],
                            out_channels=16,
                            **Enc_args)
        Dec = tflib.Network('Decoder',
                            resolution=training_set.shape[1] // 4,
                            in_channels=16,
                            **Dec_args)

    # Print layers and generate initial image snapshot.
    Enc.print_layers()
    Dec.print_layers()
    sched = sched_args
    sched.tick_kimg = 4

    def save_reconstructions(file_name):
        # Encode the real grid, decode it again and store the result as an image grid.
        codes = Enc.run((grid_reals / 127.5) - 1.0, minibatch_size=sched.minibatch_gpu)
        decoded = Dec.run(codes, minibatch_size=sched.minibatch_gpu)
        misc.save_image_grid(decoded,
                             dnnlib.make_run_dir_path(file_name),
                             drange=drange_net,
                             grid_size=grid_size)

    save_reconstructions('fakes_init.png')

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32, name='minibatch_size_in', shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32, name='minibatch_gpu_in', shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in * num_gpus)

    # Setup optimizers.
    opt_args = dict(opt_args)
    opt_args['minibatch_multiplier'] = minibatch_multiplier
    opt_args['learning_rate'] = lrate_in
    opt = tflib.Optimizer(name='TrainAuto', **opt_args)

    # Build training graph for each GPU.
    data_fetch_ops = []
    for gpu_idx in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu_idx), tf.device('/gpu:%d' % gpu_idx):
            # GPU 0 trains the originals; every other GPU gets shadow copies.
            if gpu_idx == 0:
                Enc_gpu, Dec_gpu = Enc, Dec
            else:
                Enc_gpu = Enc.clone(Enc.name + '_shadow')
                Dec_gpu = Dec.clone(Dec.name + '_shadow')

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                reals_var = tf.Variable(
                    name='reals',
                    trainable=False,
                    initial_value=tf.zeros([sched.minibatch_gpu] + training_set.shape))
                reals_write, labels_write = training_set.get_minibatch_tf()
                reals_write, _ = process_reals(reals_write, labels_write, 0.0, False,
                                               training_set.dynamic_range, drange_net)
                reals_write = tf.concat([reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops.append(tf.assign(reals_var, reals_write))
                reals_read = reals_var[:minibatch_gpu_in]

            # Evaluate loss functions.
            with tf.name_scope('loss'):
                loss = dnnlib.util.call_func_by_name(Enc=Enc_gpu,
                                                     Dec=Dec_gpu,
                                                     opt=opt,
                                                     reals=reals_read,
                                                     **loss_args)

            # Register gradients for both networks on the shared optimizer.
            trainable_vars = [*Enc_gpu.trainables.values(), *Dec_gpu.trainables.values()]
            opt.register_gradients(tf.reduce_mean(loss), trainable_vars)

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    train_op = opt.apply_updates()

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        Enc.setup_weight_histograms()
        Dec.setup_weight_histograms()

    print('Training for %d kimg...\n' % total_kimg)
    run_context().update('', max_epoch=total_kimg)
    maintenance_time = run_context().get_last_update_interval()
    cur_nimg = 0
    cur_tick = -1
    tick_start_nimg = cur_nimg
    while cur_nimg < total_nimg:
        if run_context().should_stop():
            break

        # Choose training parameters and configure training ops.
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu)

        # Run training ops. One fetch + update per accumulation round; a single
        # round is the case without gradient accumulation.
        feeds = {
            lrate_in: sched.lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu
        }
        for _repeat in range(minibatch_repeats):
            cur_nimg += sched.minibatch_size
            for _round in range(0, sched.minibatch_size, sched.minibatch_gpu * num_gpus):
                tflib.run(data_fetch_op, feeds)
                tflib.run(train_op, feeds)

        # Perform maintenance tasks once per tick.
        finished = (cur_nimg >= total_nimg)
        tick_due = cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000
        if not (tick_due or finished):
            continue

        cur_tick += 1
        tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
        tick_start_nimg = cur_nimg
        tick_time = run_context().get_time_since_last_update()
        total_time = run_context().get_time_since_start()

        # Report progress.
        print(
            'tick %-5d kimg %-8.1f minibatch %-4d time %-12s '
            'sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f' %
            (autosummary('Progress/tick', cur_tick),
             autosummary('Progress/kimg', cur_nimg / 1000.0),
             autosummary('Progress/minibatch', sched.minibatch_size),
             dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
             autosummary('Timing/sec_per_tick', tick_time),
             autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
             autosummary('Timing/maintenance_sec', maintenance_time),
             autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30)))
        autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
        autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

        # Save snapshots.
        if image_snapshot_ticks is not None and (cur_tick % image_snapshot_ticks == 0
                                                 or finished):
            save_reconstructions('fakes%06d.png' % (cur_nimg // 1000))
        if network_snapshot_ticks is not None and (cur_tick % network_snapshot_ticks == 0
                                                   or finished):
            snapshot_path = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' %
                                                     (cur_nimg // 1000))
            misc.save_pkl((Enc, Dec), snapshot_path)

        # Update summaries and RunContext.
        tflib.autosummary.save_summaries(summary_log, cur_nimg)
        run_context().update('%.2f' % 0.0, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
        maintenance_time = run_context().get_last_update_interval() - tick_time

    # Save final snapshot.
    misc.save_pkl((Enc, Dec), dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
# NOTE(review): this is a fragment from the middle of a training loop; the
# enclosing `def` (and the definitions of D_opt, G, Gs, Gs_beta, ctx, ...) are
# not present in this part of the file - confirm where it belongs.

# Build the discriminator update op and the op that tracks G into Gs.
D_train_op = D_opt.apply_updates()
Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
with tf.device('/gpu:0'):
    # Fall back to a constant 0 when the memory-stats op is not available.
    try:
        peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
    except tf.errors.NotFoundError:
        peak_gpu_mem_op = tf.constant(0)

print('Setting up snapshot image grid...')
grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(G, training_set, **grid_args)
# The schedule is evaluated at the end of training (cur_nimg = total_kimg*1000)
# to pick the minibatch size used for rendering the initial grid.
sched = training_schedule(cur_nimg=total_kimg*1000, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args)
grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch//submit_config.num_gpus)

print('Setting up run dir...')
misc.save_image_grid(grid_reals, os.path.join(submit_config.run_dir, 'reals.png'), drange=training_set.dynamic_range, grid_size=grid_size)
misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg), drange=drange_net, grid_size=grid_size)
summary_log = tf.summary.FileWriter(submit_config.run_dir)
if save_tf_graph:
    summary_log.add_graph(tf.get_default_graph())
if save_weight_histograms:
    G.setup_weight_histograms(); D.setup_weight_histograms()
metrics = metric_base.MetricGroup(metric_arg_list)

print('Training...\n')
ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
maintenance_time = ctx.get_last_update_interval()
# Loop bookkeeping: progress counter in images, tick counter, and the previous
# level of detail (-1.0 = none seen yet).
cur_nimg = int(resume_kimg * 1000)
cur_tick = 0
tick_start_nimg = cur_nimg
prev_lod = -1.0
def training_loop(
        G_args={},  # Options for generator network.
        D_args={},  # Options for discriminator network.
        G_opt_args={},  # Options for generator optimizer.
        D_opt_args={},  # Options for discriminator optimizer.
        AE_opt_args=None,  # Options for autoencoder optimizer.
        G_loss_args={},  # Options for generator loss.
        D_loss_args={},  # Options for discriminator loss.
        AE_loss_args=None,  # Options for autoencoder loss (currently unused).
        dataset_args={},  # Options for dataset.load_dataset() (training set).
        dataset_args_eval={},  # Options for dataset.load_dataset() (evaluation set).
        sched_args={},  # Options for train.TrainingSchedule.
        grid_args={},  # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],  # Options for MetricGroup.
        tf_config={},  # Options for tflib.init_tf().
        train_data_dir=None,  # Directory to load the training dataset from.
        eval_data_dir=None,  # Directory to load the evaluation dataset from.
        G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
        minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
        lazy_regularization=True,  # Perform regularization as a separate training step?
        G_reg_interval=4,  # How often the perform regularization for G? Ignored if lazy_regularization=False.
        D_reg_interval=16,  # How often the perform regularization for D? Ignored if lazy_regularization=False.
        reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg=25000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=False,  # Enable mirror augment?
        drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=50,  # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'.
        network_snapshot_ticks=50,  # How often to save network snapshots? None = only save 'networks-final.pkl'.
        save_tf_graph=True,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=True,  # Include weight histograms in the tfevents file?
        resume_pkl=None,  # Network pickle to resume training from, None or '' = train from scratch.
        resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
        resume_with_new_nets=False,  # Construct new networks according to G_args and D_args before resuming training?
        resume_with_own_vars=False):  # Also construct new networks before loading the pickle.
    """Train G and D, optionally freezing part of D and factorizing via SVD.

    Networks are built from `G_args` / `D_args` and/or loaded from
    `resume_pkl`. When `D_args['freeze']` is truthy, a `train_scope` list is
    derived for D. When `G_args` contains 'syn_svd' or 'map_svd', the networks
    are run once and then rebuilt with `factorized=True`, with variables copied
    over from the first build. Training then follows the lazy-regularization
    schedule with optional gradient accumulation.

    `G_args`, `D_args` and `G_loss_args` are copied before being modified, so
    the caller's dictionaries (and the shared mutable defaults) are left
    untouched. `dataset_args`, `dataset_args_eval` and `sched_args` are accessed
    by attribute and therefore must be attribute-style dicts.
    """
    # Work on private copies: the code below adds keys to these dicts.
    G_args = dict(G_args)
    D_args = dict(D_args)
    # Truthiness treats both None (the documented default) and '' as "no pickle".
    # The previous `resume_pkl is ''` identity tests sent None to load_pkl().
    have_resume_pkl = bool(resume_pkl)
    total_nimg = total_kimg * 1000

    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training and evaluation sets.
    print("Loading train set from %s..." % dataset_args.tfrecord_dir)
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(train_data_dir),
                                        verbose=True,
                                        **dataset_args)
    print("Loading eval set from %s..." % dataset_args_eval.tfrecord_dir)
    eval_set = dataset.load_dataset(data_dir=dnnlib.convert_path(eval_data_dir),
                                    verbose=True,
                                    **dataset_args_eval)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        training_set, **grid_args)
    misc.save_image_grid(grid_reals,
                         dnnlib.make_run_dir_path('reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)

    # Freeze Discriminator: restrict training to the output layers plus the
    # lowest-resolution blocks. A missing 'freeze' key means "do not freeze".
    if D_args.get('freeze'):
        num_layers = np.log2(training_set.resolution) - 1
        layers = int(np.round(num_layers * 3. / 8.))
        scope = ['Output', 'scores_out']
        for layer in range(layers):
            scope += ['.*%d' % 2**layer]
            # NOTE(review): the original indentation was lost; the suffix is
            # assumed to apply to every per-resolution entry - confirm.
            if 'train_scope' in D_args:
                scope[-1] += '.*%d' % D_args['train_scope']
        D_args['train_scope'] = scope

    # Construct or load networks.
    with tf.device('/gpu:0'):
        if not have_resume_pkl or resume_with_new_nets or resume_with_own_vars:
            print('Constructing networks...')
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **G_args)
            D = tflib.Network('D',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **D_args)
            Gs = G.clone('Gs')
        if have_resume_pkl:
            print('Loading networks from "%s"...' % resume_pkl)
            # NOTE: load_pkl unpickles arbitrary objects; only use trusted files.
            rG, rD, rGs = misc.load_pkl(resume_pkl)
            if resume_with_new_nets:
                G.copy_vars_from(rG)
                D.copy_vars_from(rD)
                Gs.copy_vars_from(rGs)
            else:
                G = rG
                D = rD
                Gs = rGs

    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])

    # SVD stuff
    if 'syn_svd' in G_args or 'map_svd' in G_args:
        # Run the graphs once so the SVD is calculated; the outputs are not needed.
        grid_latents_smol = grid_latents[:1]
        rho = np.array([1])
        G.run(grid_latents_smol, grid_labels, rho, is_validation=True)
        Gs.run(grid_latents_smol, grid_labels, rho, is_validation=True)
        D.run(grid_reals[:1], rho, is_validation=True)

        with tf.device('/gpu:0'):
            # Keep the SVD-decomposed networks as the source of variables.
            rG, rD, rGs = G, D, Gs
            G_lambda_mask = {
                var: np.ones(G.vars[var].shape[-1])
                for var in G.vars if 'SVD/s' in var
            }
            D_lambda_mask = {
                'D/' + var: np.ones(D.vars[var].shape[-1])
                for var in D.vars if 'SVD/s' in var
            }
            G_reduce_dims = {
                var: (0, int(Gs.vars[var].shape[-1]))
                for var in Gs.vars if 'SVD/s' in var
            }
            G_args['lambda_mask'] = G_lambda_mask
            G_args['reduce_dims'] = G_reduce_dims
            D_args['lambda_mask'] = D_lambda_mask

            # Create graph with no SVD operations
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=rG.input_shapes[1][1],
                              factorized=True,
                              **G_args)
            D = tflib.Network('D',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=rD.input_shapes[1][1],
                              factorized=True,
                              **D_args)
            Gs = G.clone('Gs')

            G.run(grid_latents_smol, grid_labels, rho, is_validation=True, minibatch_size=1)
            Gs.run(grid_latents_smol, grid_labels, rho, is_validation=True, minibatch_size=1)

            G.copy_vars_from(rG)
            D.copy_vars_from(rD)
            Gs.copy_vars_from(rGs)

    # Reduce per-gpu minibatch size to fit in 16GB GPU memory
    if grid_reals.shape[2] >= 1024:
        sched_args.minibatch_gpu_base = 2
    # NOTE(review): original nesting of this print is ambiguous; it is emitted
    # unconditionally and tolerates a missing key.
    print('Batch size', sched_args.get('minibatch_gpu_base'))

    # Generate initial image snapshot.
    G.print_layers()
    D.print_layers()
    sched = training_schedule(cur_nimg=total_nimg, training_set=training_set, **sched_args)
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    rho = np.array([1])
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        rho,
                        is_validation=True,
                        minibatch_size=sched.minibatch_gpu)
    misc.save_image_grid(grid_fakes,
                         dnnlib.make_run_dir_path('fakes_init.png'),
                         drange=drange_net,
                         grid_size=grid_size)
    if have_resume_pkl:
        # Sanity check: the loaded and the (possibly factorized) D should agree.
        load_d_real = rD.run(grid_reals[:1], rho, is_validation=True)
        load_d_fake = rD.run(grid_fakes[:1], rho, is_validation=True)
        d_fake = D.run(grid_fakes[:1], rho, is_validation=True)
        d_real = D.run(grid_reals[:1], rho, is_validation=True)
        print('Factorized fake', d_fake, 'loaded fake', load_d_fake, 'factorized real',
              d_real, 'loaded real', load_d_real)
        print('(should match)')

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32, name='minibatch_size_in', shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32, name='minibatch_gpu_in', shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in * num_gpus)
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_size_in, tf.float32),
                              G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0

    # Setup optimizers.
    G_opt_args = dict(G_opt_args)
    D_opt_args = dict(D_opt_args)
    for args, reg_interval in [(G_opt_args, G_reg_interval), (D_opt_args, D_reg_interval)]:
        args['minibatch_multiplier'] = minibatch_multiplier
        args['learning_rate'] = lrate_in
        if lazy_regularization:
            # Compensate for the regularization steps being run separately.
            mb_ratio = reg_interval / (reg_interval + 1)
            args['learning_rate'] *= mb_ratio
            if 'beta1' in args:
                args['beta1'] **= mb_ratio
            if 'beta2' in args:
                args['beta2'] **= mb_ratio
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', **D_opt_args)
    G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args)
    D_reg_opt = tflib.Optimizer(name='RegD', share=D_opt, **D_opt_args)
    if AE_opt_args is not None:
        AE_opt_args = dict(AE_opt_args)
        AE_opt_args['minibatch_multiplier'] = minibatch_multiplier
        AE_opt_args['learning_rate'] = lrate_in
        # NOTE(review): constructed but never given gradients below.
        AE_opt = tflib.Optimizer(name='TrainAE', **AE_opt_args)

    # Build training graph for each GPU.
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):

            # Create GPU-specific shadow copies of G and D.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                sched = training_schedule(cur_nimg=int(resume_kimg * 1000),
                                          training_set=training_set,
                                          **sched_args)
                reals_var = tf.Variable(
                    name='reals',
                    trainable=False,
                    initial_value=tf.zeros([sched.minibatch_gpu] + training_set.shape))
                labels_var = tf.Variable(
                    name='labels',
                    trainable=False,
                    initial_value=tf.zeros([sched.minibatch_gpu, training_set.label_size]))
                reals_write, labels_write = training_set.get_minibatch_tf()
                reals_write, labels_write = process_reals(reals_write, labels_write, lod_in,
                                                          mirror_augment,
                                                          training_set.dynamic_range,
                                                          drange_net)
                reals_write = tf.concat([reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                labels_write = tf.concat([labels_write, labels_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_var, reals_write)]
                data_fetch_ops += [tf.assign(labels_var, labels_write)]
                reals_read = reals_var[:minibatch_gpu_in]
                labels_read = labels_var[:minibatch_gpu_in]

            # Evaluate loss functions.
            lod_assign_ops = []
            if 'lod' in G_gpu.vars:
                lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)]
            if 'lod' in D_gpu.vars:
                lod_assign_ops += [tf.assign(D_gpu.vars['lod'], lod_in)]
            with tf.control_dependencies(lod_assign_ops):
                with tf.name_scope('G_loss'):
                    # The L1 loss additionally needs this GPU's real images.
                    # The loss is evaluated in every case; previously the L1
                    # branch skipped the call and left G_loss / G_reg undefined.
                    g_loss_kwargs = dict(G_loss_args)
                    if g_loss_kwargs.get('func_name') == 'training.loss.G_l1':
                        g_loss_kwargs['reals'] = reals_read
                    G_loss, G_reg = dnnlib.util.call_func_by_name(
                        G=G_gpu,
                        D=D_gpu,
                        opt=G_opt,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        **g_loss_kwargs)
                with tf.name_scope('D_loss'):
                    D_loss, D_reg = dnnlib.util.call_func_by_name(
                        G=G_gpu,
                        D=D_gpu,
                        opt=D_opt,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        reals=reals_read,
                        labels=labels_read,
                        **D_loss_args)

            # Register gradients.
            if not lazy_regularization:
                if G_reg is not None:
                    G_loss += G_reg
                if D_reg is not None:
                    D_loss += D_reg
            else:
                if G_reg is not None:
                    G_reg_opt.register_gradients(tf.reduce_mean(G_reg * G_reg_interval),
                                                 G_gpu.trainables)
                if D_reg is not None:
                    D_reg_opt.register_gradients(tf.reduce_mean(D_reg * D_reg_interval),
                                                 D_gpu.trainables)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()
    G_reg_op = G_reg_opt.apply_updates(allow_no_op=True)
    D_reg_op = D_reg_opt.apply_updates(allow_no_op=True)
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0
    _g_loss = None  # Last fetched generator loss (last GPU's loss tensor).
    while cur_nimg < total_nimg:
        if dnnlib.RunContext.get().should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops
        feed_dict = {
            lod_in: sched.lod,
            lrate_in: sched.G_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu
        }
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size, sched.minibatch_gpu * num_gpus)
            run_G_reg = (lazy_regularization and running_mb_counter % G_reg_interval == 0)
            run_D_reg = (lazy_regularization and running_mb_counter % D_reg_interval == 0)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                # G_loss is fetched here too, so the snapshot report below has
                # a value on this path as well.
                _g_loss, _, _ = tflib.run([G_loss, G_train_op, data_fetch_op], feed_dict)
                if run_G_reg:
                    tflib.run(G_reg_op, feed_dict)
                tflib.run([D_train_op, Gs_update_op], feed_dict)
                if run_D_reg:
                    tflib.run(D_reg_op, feed_dict)

            # Slow path with gradient accumulation.
            else:
                for _round in rounds:
                    _g_loss, _ = tflib.run([G_loss, G_train_op], feed_dict)
                if run_G_reg:
                    for _round in rounds:
                        tflib.run(G_reg_op, feed_dict)
                tflib.run(Gs_update_op, feed_dict)
                for _round in rounds:
                    tflib.run(data_fetch_op, feed_dict)
                    tflib.run(D_train_op, feed_dict)
                if run_D_reg:
                    for _round in rounds:
                        tflib.run(D_reg_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_nimg)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start() + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s '
                'sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f' %
                (autosummary('Progress/tick', cur_tick),
                 autosummary('Progress/kimg', cur_nimg / 1000.0),
                 autosummary('Progress/lod', sched.lod),
                 autosummary('Progress/minibatch', sched.minibatch_size),
                 dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                 autosummary('Timing/sec_per_tick', tick_time),
                 autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                 autosummary('Timing/maintenance_sec', maintenance_time),
                 autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (cur_tick % image_snapshot_ticks == 0
                                                     or done):
                if _g_loss is not None:
                    print('g loss', _g_loss)
                # Pass rho exactly like the initial-snapshot call above; Gs is
                # the same network and is fed three inputs everywhere else.
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    rho,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch_gpu)
                misc.save_image_grid(grid_fakes,
                                     dnnlib.make_run_dir_path('fakes%06d.png' %
                                                              (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
            # Parenthesized so that None really disables periodic snapshots,
            # as documented and as the sibling training loops do.
            if network_snapshot_ticks is not None and (cur_tick % network_snapshot_ticks == 0
                                                       or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' %
                                               (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl,
                            run_dir=dnnlib.make_run_dir_path(),
                            data_dir=dnnlib.convert_path(eval_data_dir),
                            num_gpus=num_gpus,
                            tf_config=tf_config,
                            rho=rho)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % sched.lod,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get().get_last_update_interval() - tick_time

    # Save final snapshot.
    misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
    eval_set.close()
def training_loop_vc(
        G_args={},  # Options for generator network.
        D_args={},  # Options for discriminator network.
        I_args={},  # Options for the 'I' (infogan-head/vcgan-head) network.
        I_info_args={},  # Options for the 'I_info' network (only built when use_vc_head_with_cls=True).
        G_opt_args={},  # Options for generator optimizer.
        D_opt_args={},  # Options for discriminator optimizer.
        G_loss_args={},  # Options for generator loss.
        D_loss_args={},  # Options for discriminator loss.
        dataset_args={},  # Options for dataset.load_dataset().
        sched_args={},  # Options for train.TrainingSchedule.
        grid_args={},  # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],  # Options for MetricGroup.
        tf_config={},  # Options for tflib.init_tf().
        use_info_gan=False,  # Whether to use info-gan.
        use_vc_head=False,  # Whether to use vc-head.
        use_vc_head_with_cls=False,  # Whether to use classification in discriminator.
        data_dir=None,  # Directory to load datasets from.
        G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
        minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
        lazy_regularization=True,  # Perform regularization as a separate training step?
        G_reg_interval=4,  # How often to perform regularization for G? Ignored if lazy_regularization=False.
        D_reg_interval=16,  # How often to perform regularization for D? Ignored if lazy_regularization=False.
        reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg=25000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=False,  # Enable mirror augment?
        drange_net=[
            -1, 1
        ],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=50,  # How often to save image snapshots? None = only save 'reals.png' and 'fakes_init.png'.
        network_snapshot_ticks=50,  # How often to save network snapshots? None = only save 'network-final.pkl'.
        save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_pkl=None,  # Network pickle to resume training from, None = train from scratch.
        resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
        resume_with_new_nets=False,  # Construct new networks according to G_args and D_args before resuming training?
        traversal_grid=False,  # Used for disentangled representation learning.
        n_discrete=3,  # Number of discrete latents in model.
        n_continuous=4,  # Number of continuous latents in model.
        n_samples_per=10):  # Number of samples for each line in traversal.
    """Run the G/D training loop, optionally with auxiliary 'I' / 'I_info' heads.

    The function constructs (or loads from ``resume_pkl``) the networks,
    builds one training sub-graph per GPU, and then alternates generator and
    discriminator updates until ``total_kimg * 1000`` images have been
    processed or the dnnlib RunContext asks to stop. Once per tick it prints
    a progress line and, depending on ``image_snapshot_ticks`` and
    ``network_snapshot_ticks``, writes an image grid, a network pickle and
    metric results into the run directory. A final ``network-final.pkl`` is
    always written. Returns None.

    Pickle layout (both when loading ``resume_pkl`` and when saving):
      * ``(G, D, I, Gs)``          if ``use_info_gan or use_vc_head``
      * ``(G, D, I, I_info, Gs)``  else if ``use_vc_head_with_cls``
      * ``(G, D, Gs)``             otherwise

    NOTE(review): if ``use_vc_head_with_cls`` is combined with
    ``use_info_gan`` or ``use_vc_head``, ``I_info`` is still constructed, but
    the 4-tuple branches are taken everywhere else: its loss is not
    registered, it is not saved, and resuming would reference an unbound
    ``rI_info``. Presumably the flags are meant to be mutually exclusive --
    confirm with callers.
    """
    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir),
                                        verbose=True,
                                        **dataset_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        training_set, **grid_args)
    misc.save_image_grid(grid_reals,
                         dnnlib.make_run_dir_path('reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)

    # Construct or load networks.
    with tf.device('/gpu:0'):
        if resume_pkl is None or resume_with_new_nets:
            print('Constructing networks...')
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **G_args)
            D = tflib.Network('D',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **D_args)
            if use_info_gan or use_vc_head or use_vc_head_with_cls:
                I = tflib.Network('I',
                                  num_channels=training_set.shape[0],
                                  resolution=training_set.shape[1],
                                  label_size=training_set.label_size,
                                  **I_args)
            if use_vc_head_with_cls:
                I_info = tflib.Network('I_info',
                                       num_channels=training_set.shape[0],
                                       resolution=training_set.shape[1],
                                       label_size=training_set.label_size,
                                       **I_info_args)
            Gs = G.clone('Gs')
        if resume_pkl is not None:
            print('Loading networks from "%s"...' % resume_pkl)
            # The tuple arity expected from the pickle depends on the head
            # flags; see the pickle layout in the docstring.
            if use_info_gan or use_vc_head:
                rG, rD, rI, rGs = misc.load_pkl(resume_pkl)
            elif use_vc_head_with_cls:
                rG, rD, rI, rI_info, rGs = misc.load_pkl(resume_pkl)
            else:
                rG, rD, rGs = misc.load_pkl(resume_pkl)
            if resume_with_new_nets:
                # Freshly constructed networks receive the loaded weights.
                G.copy_vars_from(rG)
                D.copy_vars_from(rD)
                if use_info_gan or use_vc_head or use_vc_head_with_cls:
                    I.copy_vars_from(rI)
                if use_vc_head_with_cls:
                    I_info.copy_vars_from(rI_info)
                Gs.copy_vars_from(rGs)
            else:
                # Use the loaded network objects directly.
                G = rG
                D = rD
                if use_info_gan or use_vc_head or use_vc_head_with_cls:
                    I = rI
                if use_vc_head_with_cls:
                    I_info = rI_info
                Gs = rGs

    # Print layers and generate initial image snapshot.
    G.print_layers()
    D.print_layers()
    if use_info_gan or use_vc_head or use_vc_head_with_cls:
        I.print_layers()
    if use_vc_head_with_cls:
        I_info.print_layers()
    # pdb.set_trace()
    # Schedule evaluated at the end of training; only minibatch_gpu is read
    # from it here, for the initial snapshot.
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              **sched_args)
    if traversal_grid:
        # get_grid_latents replaces grid_size and grid_labels as well.
        grid_size, grid_latents, grid_labels = get_grid_latents(
            n_discrete, n_continuous, n_samples_per, G, grid_labels)
    else:
        grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    print('grid_latents.shape:', grid_latents.shape)
    print('grid_labels.shape:', grid_labels.shape)
    # pdb.set_trace()
    # Gs.run is unpacked into two values here; only the first is used.
    grid_fakes, _ = Gs.run(grid_latents,
                           grid_labels,
                           is_validation=True,
                           minibatch_size=sched.minibatch_gpu,
                           randomize_noise=False)
    misc.save_image_grid(grid_fakes,
                         dnnlib.make_run_dir_path('fakes_init.png'),
                         drange=drange_net,
                         grid_size=grid_size)

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32,
                                           name='minibatch_size_in',
                                           shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32,
                                          name='minibatch_gpu_in',
                                          shape=[])
        # Number of gradient-accumulation rounds per optimizer update.
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in *
                                                     num_gpus)
        # Decay factor for the moving average Gs; 0.0 disables averaging.
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_size_in, tf.float32),
                              G_smoothing_kimg *
                              1000.0) if G_smoothing_kimg > 0.0 else 0.0

    # Setup optimizers.
    # Copy the argument dicts so the caller's (and the default) dicts are not
    # mutated by the assignments below.
    G_opt_args = dict(G_opt_args)
    D_opt_args = dict(D_opt_args)
    for args, reg_interval in [(G_opt_args, G_reg_interval),
                               (D_opt_args, D_reg_interval)]:
        args['minibatch_multiplier'] = minibatch_multiplier
        args['learning_rate'] = lrate_in
        if lazy_regularization:
            # Rescale learning rate and Adam betas by
            # reg_interval / (reg_interval + 1).
            mb_ratio = reg_interval / (reg_interval + 1)
            args['learning_rate'] *= mb_ratio
            if 'beta1' in args:
                args['beta1'] **= mb_ratio
            if 'beta2' in args:
                args['beta2'] **= mb_ratio
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', **D_opt_args)
    G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args)
    D_reg_opt = tflib.Optimizer(name='RegD', share=D_opt, **D_opt_args)

    # Build training graph for each GPU.
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):

            # Create GPU-specific shadow copies of G and D.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            if use_info_gan or use_vc_head or use_vc_head_with_cls:
                I_gpu = I if gpu == 0 else I.clone(I.name + '_shadow')
            if use_vc_head_with_cls:
                I_info_gpu = I_info if gpu == 0 else I_info.clone(
                    I_info.name + '_shadow')

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                # Schedule at the resume point fixes the size of the
                # per-GPU staging variables.
                sched = training_schedule(cur_nimg=int(resume_kimg * 1000),
                                          training_set=training_set,
                                          **sched_args)
                reals_var = tf.Variable(
                    name='reals',
                    trainable=False,
                    initial_value=tf.zeros([sched.minibatch_gpu] +
                                           training_set.shape))
                labels_var = tf.Variable(name='labels',
                                         trainable=False,
                                         initial_value=tf.zeros([
                                             sched.minibatch_gpu,
                                             training_set.label_size
                                         ]))
                reals_write, labels_write = training_set.get_minibatch_tf()
                reals_write, labels_write = process_reals(
                    reals_write, labels_write, lod_in, mirror_augment,
                    training_set.dynamic_range, drange_net)
                # Pad the fetched batch with the tail of the staging variable
                # so the assigned value has the variable's full first-axis
                # length.
                reals_write = tf.concat(
                    [reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                labels_write = tf.concat(
                    [labels_write, labels_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_var, reals_write)]
                data_fetch_ops += [tf.assign(labels_var, labels_write)]
                reals_read = reals_var[:minibatch_gpu_in]
                labels_read = labels_var[:minibatch_gpu_in]

            # Evaluate loss functions.
            lod_assign_ops = []
            if 'lod' in G_gpu.vars:
                lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)]
            if 'lod' in D_gpu.vars:
                lod_assign_ops += [tf.assign(D_gpu.vars['lod'], lod_in)]
            with tf.control_dependencies(lod_assign_ops):
                with tf.name_scope('G_loss'):
                    # The loss function named in G_loss_args is expected to
                    # return 4 values with a head, 2 values without.
                    if use_info_gan or use_vc_head:
                        G_loss, G_reg, I_loss, _ = dnnlib.util.call_func_by_name(
                            G=G_gpu,
                            D=D_gpu,
                            I=I_gpu,
                            opt=G_opt,
                            training_set=training_set,
                            minibatch_size=minibatch_gpu_in,
                            **G_loss_args)
                    elif use_vc_head_with_cls:
                        G_loss, G_reg, I_loss, I_info_loss = dnnlib.util.call_func_by_name(
                            G=G_gpu,
                            D=D_gpu,
                            I=I_gpu,
                            I_info=I_info_gpu,
                            opt=G_opt,
                            training_set=training_set,
                            minibatch_size=minibatch_gpu_in,
                            **G_loss_args)
                    else:
                        G_loss, G_reg = dnnlib.util.call_func_by_name(
                            G=G_gpu,
                            D=D_gpu,
                            opt=G_opt,
                            training_set=training_set,
                            minibatch_size=minibatch_gpu_in,
                            **G_loss_args)
                with tf.name_scope('D_loss'):
                    D_loss, D_reg = dnnlib.util.call_func_by_name(
                        G=G_gpu,
                        D=D_gpu,
                        opt=D_opt,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        reals=reals_read,
                        labels=labels_read,
                        **D_loss_args)

            # Register gradients.
            if not lazy_regularization:
                # Fold the regularizers into the main losses.
                if G_reg is not None:
                    G_loss += G_reg
                if D_reg is not None:
                    D_loss += D_reg
            else:
                # Regularizers get their own optimizers, scaled by the
                # interval at which they are run.
                if G_reg is not None:
                    G_reg_opt.register_gradients(
                        tf.reduce_mean(G_reg * G_reg_interval),
                        G_gpu.trainables)
                if D_reg is not None:
                    D_reg_opt.register_gradients(
                        tf.reduce_mean(D_reg * D_reg_interval),
                        D_gpu.trainables)
            # print('G_gpu.trainables:', G_gpu.trainables)
            # print('D_gpu.trainables:', D_gpu.trainables)
            # print('I_gpu.trainables:', I_gpu.trainables)
            if use_info_gan or use_vc_head:
                # The head is trained jointly with G through G_opt.
                GI_gpu_trainables = collections.OrderedDict(
                    list(G_gpu.trainables.items()) +
                    list(I_gpu.trainables.items()))
                G_opt.register_gradients(tf.reduce_mean(G_loss + I_loss),
                                         GI_gpu_trainables)
                D_opt.register_gradients(tf.reduce_mean(D_loss),
                                         D_gpu.trainables)
                # G_opt.register_gradients(tf.reduce_mean(I_loss),
                # GI_gpu_trainables)
                # D_opt.register_gradients(tf.reduce_mean(I_loss),
                # D_gpu.trainables)
            elif use_vc_head_with_cls:
                GIIinfo_gpu_trainables = collections.OrderedDict(
                    list(G_gpu.trainables.items()) +
                    list(I_gpu.trainables.items()) +
                    list(I_info_gpu.trainables.items()))
                G_opt.register_gradients(
                    tf.reduce_mean(G_loss + I_loss + I_info_loss),
                    GIIinfo_gpu_trainables)
                D_opt.register_gradients(tf.reduce_mean(D_loss),
                                         D_gpu.trainables)
            else:
                G_opt.register_gradients(tf.reduce_mean(G_loss),
                                         G_gpu.trainables)
                D_opt.register_gradients(tf.reduce_mean(D_loss),
                                         D_gpu.trainables)

            # if use_info_gan:
            # # INFO-GAN-HEAD loss
            # G_opt.register_gradients(tf.reduce_mean(I_loss),
            # G_gpu.trainables)
            # G_opt.register_gradients(tf.reduce_mean(I_loss),
            # I_gpu.trainables)
            # D_opt.register_gradients(tf.reduce_mean(I_loss),
            # D_gpu.trainables)

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()
    G_reg_op = G_reg_opt.apply_updates(allow_no_op=True)
    D_reg_op = D_reg_opt.apply_updates(allow_no_op=True)
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            # Fall back to a constant so progress reporting still works.
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
        if use_info_gan or use_vc_head or use_vc_head_with_cls:
            I.setup_weight_histograms()
        if use_vc_head_with_cls:
            I_info.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('',
                                   cur_epoch=resume_kimg,
                                   max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1  # -1 forces a maintenance pass after the first iteration.
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu, sched.lod)
        if reset_opt_for_new_lod:
            # Reset when lod crosses an integer boundary.
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        # Note: sched.G_lrate is fed to the single lrate_in placeholder that
        # both the G and D optimizers use.
        feed_dict = {
            lod_in: sched.lod,
            lrate_in: sched.G_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu
        }
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size,
                           sched.minibatch_gpu * num_gpus)
            run_G_reg = (lazy_regularization
                         and running_mb_counter % G_reg_interval == 0)
            run_D_reg = (lazy_regularization
                         and running_mb_counter % D_reg_interval == 0)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                tflib.run([G_train_op, data_fetch_op], feed_dict)
                if run_G_reg:
                    tflib.run(G_reg_op, feed_dict)
                tflib.run([D_train_op, Gs_update_op], feed_dict)
                if run_D_reg:
                    tflib.run(D_reg_op, feed_dict)

            # Slow path with gradient accumulation.
            else:
                for _round in rounds:
                    tflib.run(G_train_op, feed_dict)
                if run_G_reg:
                    for _round in rounds:
                        tflib.run(G_reg_op, feed_dict)
                tflib.run(Gs_update_op, feed_dict)
                for _round in rounds:
                    tflib.run(data_fetch_op, feed_dict)
                    tflib.run(D_train_op, feed_dict)
                if run_D_reg:
                    for _round in rounds:
                        tflib.run(D_reg_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start(
            ) + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch_size),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days',
                        total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (
                    cur_tick % image_snapshot_ticks == 0 or done):
                grid_fakes, _ = Gs.run(grid_latents,
                                       grid_labels,
                                       is_validation=True,
                                       minibatch_size=sched.minibatch_gpu,
                                       randomize_noise=False)
                misc.save_image_grid(grid_fakes,
                                     dnnlib.make_run_dir_path(
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' %
                                               (cur_nimg // 1000))
                # Tuple layout must match what the resume path unpacks.
                if use_info_gan or use_vc_head:
                    misc.save_pkl((G, D, I, Gs), pkl)
                elif use_vc_head_with_cls:
                    misc.save_pkl((G, D, I, I_info, Gs), pkl)
                else:
                    misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl,
                            run_dir=dnnlib.make_run_dir_path(),
                            data_dir=dnnlib.convert_path(data_dir),
                            num_gpus=num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % sched.lod,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            # Time spent in this maintenance pass, reported on the next tick.
            maintenance_time = dnnlib.RunContext.get(
            ).get_last_update_interval() - tick_time

    # Save final snapshot.
    if use_info_gan or use_vc_head:
        misc.save_pkl((G, D, I, Gs),
                      dnnlib.make_run_dir_path('network-final.pkl'))
    elif use_vc_head_with_cls:
        misc.save_pkl((G, D, I, I_info, Gs),
                      dnnlib.make_run_dir_path('network-final.pkl'))
    else:
        misc.save_pkl((G, D, Gs),
                      dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
def training_loop_vae(
        G_args={},  # Options for generator network.
        E_args={},  # Options for encoder network.
        D_args={},  # Options for discriminator network. None = train without a discriminator.
        G_opt_args={},  # Options for generator optimizer.
        D_opt_args={},  # Options for discriminator optimizer.
        G_loss_args={},  # Options for generator loss.
        D_loss_args={},  # Options for discriminator loss.
        dataset_args={},  # Options for dataset.load_dataset().
        sched_args={},  # Options for train.TrainingSchedule.
        grid_args={},  # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],  # Options for MetricGroup.
        tf_config={},  # Options for tflib.init_tf().
        data_dir=None,  # Directory to load datasets from.
        minibatch_repeats=1,  # Number of minibatches to run before adjusting training parameters.
        total_kimg=25000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=False,  # Enable mirror augment?
        drange_net=[
            -1, 1
        ],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=50,  # How often to save image snapshots? None = only save 'reals.png' and 'fakes_init.png'.
        network_snapshot_ticks=50,  # How often to save network snapshots? None = only save 'network-final.pkl'.
        save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_pkl=None,  # Network pickle to resume training from, None = train from scratch.
        resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
        resume_with_new_nets=False,  # Construct new networks according to G_args and D_args before resuming training?
        traversal_grid=False,  # Used for disentangled representation learning.
        n_discrete=0,  # Number of discrete latents in model.
        n_continuous=4,  # Number of continuous latents in model.
        topk_dims_to_show=20,  # Number of top-ranked latent dimensions to show in a snapshot; <= 0 shows all n_continuous.
        subgroup_sizes_ls=None,  # Summed into G's input width when forward_eg=True.
        subspace_sizes_ls=None,  # Not read inside this function.
        forward_eg=False,  # If True, G's input is widened and latents go through append_gfeats() before G.run().
        n_samples_per=10):  # Number of samples for each line in traversal.
    """Run the encoder/generator (VAE-style) training loop, optionally with D.

    Builds or loads the networks ``E``, ``G`` and (when ``D_args`` is not
    None) ``D``, constructs one training sub-graph per GPU, then repeatedly
    runs the G (encoder+generator) update and, if enabled, the D update until
    ``total_kimg * 1000`` images have been processed or the dnnlib RunContext
    asks to stop. Once per tick it prints a progress line and, depending on
    the snapshot intervals, saves a network pickle, runs the metrics and
    writes an image grid. ``network-final.pkl`` is always written at the end.
    Returns None.

    Pickle layout (load and save): ``(E, G, D)`` with a discriminator,
    ``(E, G)`` without.

    ``G_args`` and ``D_args`` are accessed by attribute (``.latent_size``)
    when networks are constructed, so they must be attribute-style mappings
    there; the plain ``{}`` defaults would not work on that path.
    """
    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # A discriminator is used unless D_args is explicitly None.
    use_D = D_args is not None

    # Load training set.
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir),
                                        verbose=True,
                                        **dataset_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        training_set, **grid_args)
    # NOTE(review): the outlined copy is stored in grid_fakes, but the grid
    # saved on the next statement is grid_reals, and grid_fakes is overwritten
    # before it is read. Presumably the outline was meant for 'reals.png' --
    # confirm.
    grid_fakes = add_outline(grid_reals, width=1)
    misc.save_image_grid(grid_reals,
                         dnnlib.make_run_dir_path('reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)

    # Construct or load networks.
    with tf.device('/gpu:0'):
        if resume_pkl is None or resume_with_new_nets:
            print('Constructing networks...')
            E = tflib.Network('E',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              input_shape=[None] + training_set.shape,
                              **E_args)
            # With forward_eg the generator input is widened by
            # sum(subgroup_sizes_ls).
            G = tflib.Network(
                'G',
                num_channels=training_set.shape[0],
                resolution=training_set.shape[1],
                label_size=training_set.label_size,
                input_shape=[None, n_discrete + G_args.latent_size]
                if not forward_eg else [
                    None, n_discrete + G_args.latent_size +
                    sum(subgroup_sizes_ls)
                ],
                **G_args)
            if use_D:
                D = tflib.Network('D',
                                  num_channels=training_set.shape[0],
                                  resolution=training_set.shape[1],
                                  label_size=training_set.label_size,
                                  input_shape=[None, D_args.latent_size],
                                  **D_args)
        if resume_pkl is not None:
            print('Loading networks from "%s"...' % resume_pkl)
            if use_D:
                rE, rG, rD = misc.load_pkl(resume_pkl)
            else:
                rE, rG = misc.load_pkl(resume_pkl)
            if resume_with_new_nets:
                # Freshly constructed networks receive the loaded weights.
                E.copy_vars_from(rE)
                G.copy_vars_from(rG)
                if use_D:
                    D.copy_vars_from(rD)
            else:
                # Use the loaded network objects directly.
                E = rE
                G = rG
                if use_D:
                    D = rD

    # Print layers and generate initial image snapshot.
    E.print_layers()
    G.print_layers()
    if use_D:
        D.print_layers()
    # Schedule evaluated at the end of training; only minibatch_gpu is read
    # from it here, for the initial snapshot.
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              **sched_args)
    if traversal_grid:
        # Before any metric has run, simply show the first dimensions.
        if topk_dims_to_show > 0:
            topk_dims = np.arange(min(topk_dims_to_show, n_continuous))
        else:
            topk_dims = np.arange(n_continuous)
        grid_size, grid_latents, grid_labels = get_grid_latents(
            n_discrete, n_continuous, n_samples_per, G, grid_labels,
            topk_dims)
    else:
        grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    print('grid_size:', grid_size)
    print('grid_latents.shape:', grid_latents.shape)
    print('grid_labels.shape:', grid_labels.shape)
    # Only the first and the eighth of the eight values from get_return_v are
    # used (images and lie_vars).
    grid_fakes, _, _, _, _, _, _, lie_vars = get_return_v(
        G.run(append_gfeats(grid_latents, G) if forward_eg else grid_latents,
              grid_labels,
              is_validation=True,
              minibatch_size=sched.minibatch_gpu,
              randomize_noise=True), 8)
    print('Lie_vars:', lie_vars[0])
    grid_fakes = add_outline(grid_fakes, width=1)
    misc.save_image_grid(grid_fakes,
                         dnnlib.make_run_dir_path('fakes_init.png'),
                         drange=drange_net,
                         grid_size=grid_size)

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32,
                                           name='minibatch_size_in',
                                           shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32,
                                          name='minibatch_gpu_in',
                                          shape=[])
        # Number of gradient-accumulation rounds per optimizer update.
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in *
                                                     num_gpus)

    # Setup optimizers.
    # Copy the argument dicts so the caller's (and the default) dicts are not
    # mutated by the assignments below.
    G_opt_args = dict(G_opt_args)
    G_opt_args['minibatch_multiplier'] = minibatch_multiplier
    G_opt_args['learning_rate'] = lrate_in
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    if use_D:
        D_opt_args = dict(D_opt_args)
        D_opt_args['minibatch_multiplier'] = minibatch_multiplier
        D_opt_args['learning_rate'] = lrate_in
        D_opt = tflib.Optimizer(name='TrainD', **D_opt_args)

    # Build training graph for each GPU.
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):

            # Create GPU-specific shadow copies of E, G and D.
            E_gpu = E if gpu == 0 else E.clone(E.name + '_shadow')
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            if use_D:
                D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                # Schedule at the resume point fixes the size of the
                # per-GPU staging variables.
                sched = training_schedule(cur_nimg=int(resume_kimg * 1000),
                                          training_set=training_set,
                                          **sched_args)
                reals_var = tf.Variable(
                    name='reals',
                    trainable=False,
                    initial_value=tf.zeros([sched.minibatch_gpu] +
                                           training_set.shape))
                labels_var = tf.Variable(name='labels',
                                         trainable=False,
                                         initial_value=tf.zeros([
                                             sched.minibatch_gpu,
                                             training_set.label_size
                                         ]))
                reals_write, labels_write = training_set.get_minibatch_tf()
                # lod is fixed to 0. in this loop.
                reals_write, labels_write = process_reals(
                    reals_write, labels_write, 0., mirror_augment,
                    training_set.dynamic_range, drange_net)
                # Pad the fetched batch with the tail of the staging variable
                # so the assigned value has the variable's full first-axis
                # length.
                reals_write = tf.concat(
                    [reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                labels_write = tf.concat(
                    [labels_write, labels_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_var, reals_write)]
                data_fetch_ops += [tf.assign(labels_var, labels_write)]
                reals_read = reals_var[:minibatch_gpu_in]
                labels_read = labels_var[:minibatch_gpu_in]

            # Evaluate loss functions.
            # Unlike a plain GAN generator loss, G_loss here receives the
            # real images and labels.
            if use_D:
                with tf.name_scope('G_loss'):
                    G_loss = dnnlib.util.call_func_by_name(
                        E=E_gpu,
                        G=G_gpu,
                        D=D_gpu,
                        opt=G_opt,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        reals=reals_read,
                        labels=labels_read,
                        **G_loss_args)
                with tf.name_scope('D_loss'):
                    D_loss = dnnlib.util.call_func_by_name(
                        E=E_gpu,
                        D=D_gpu,
                        opt=D_opt,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        reals=reals_read,
                        labels=labels_read,
                        **D_loss_args)
            else:
                with tf.name_scope('G_loss'):
                    G_loss = dnnlib.util.call_func_by_name(
                        E=E_gpu,
                        G=G_gpu,
                        opt=G_opt,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        reals=reals_read,
                        labels=labels_read,
                        **G_loss_args)

            # Register gradients.
            # Encoder and generator are optimized jointly through G_opt.
            EG_gpu_trainables = collections.OrderedDict(
                list(E_gpu.trainables.items()) +
                list(G_gpu.trainables.items()))
            G_opt.register_gradients(tf.reduce_mean(G_loss),
                                     EG_gpu_trainables)
            # G_opt.register_gradients(G_loss,
            # EG_gpu_trainables)
            if use_D:
                D_opt.register_gradients(tf.reduce_mean(D_loss),
                                         D_gpu.trainables)
                # D_opt.register_gradients(D_loss,
                # D_gpu.trainables)

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    G_train_op = G_opt.apply_updates()
    if use_D:
        D_train_op = D_opt.apply_updates()

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            # Fall back to a constant so progress reporting still works.
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        # NOTE(review): E is not included here although it is trained
        # together with G -- confirm whether that is intentional.
        G.setup_weight_histograms()
        if use_D:
            D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('',
                                   cur_epoch=resume_kimg,
                                   max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1  # -1 forces a maintenance pass after the first iteration.
    tick_start_nimg = cur_nimg
    # prev_lod is never read and running_mb_counter is only incremented in
    # this function.
    prev_lod = -1.0
    running_mb_counter = 0
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu, 0)

        # Run training ops.
        # sched.G_lrate feeds the single lrate_in placeholder shared by the
        # G and D optimizers.
        feed_dict = {
            lrate_in: sched.G_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu
        }
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size,
                           sched.minibatch_gpu * num_gpus)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # NOTE(review): in both paths G_train_op runs before
            # data_fetch_op, yet G_loss reads reals_read. The very first G
            # step therefore appears to see the zero-initialized staging
            # variables, and in the slow path all accumulation rounds of G
            # appear to reuse the same staged batch. Confirm whether the
            # fetch should precede the G step.

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                tflib.run([G_train_op], feed_dict)
                tflib.run([data_fetch_op], feed_dict)
                if use_D:
                    tflib.run([D_train_op], feed_dict)

            # Slow path with gradient accumulation.
            else:
                for _round in rounds:
                    tflib.run(G_train_op, feed_dict)
                for _round in rounds:
                    tflib.run(data_fetch_op, feed_dict)
                    if use_D:
                        tflib.run(D_train_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start(
            ) + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/minibatch', sched.minibatch_size),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days',
                        total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' %
                                               (cur_nimg // 1000))
                # Tuple layout must match what the resume path unpacks.
                if use_D:
                    misc.save_pkl((E, G, D), pkl)
                else:
                    misc.save_pkl((E, G), pkl)
                # assumes metrics.run returns a mapping of metric outputs
                # (it is queried with `in` and indexed below) -- TODO confirm
                met_outs = metrics.run(pkl,
                                       run_dir=dnnlib.make_run_dir_path(),
                                       data_dir=dnnlib.convert_path(data_dir),
                                       num_gpus=num_gpus,
                                       tf_config=tf_config,
                                       is_vae=True,
                                       use_D=use_D,
                                       Gs_kwargs=dict(is_validation=True))
                # Re-rank the dimensions shown in traversal snapshots using
                # the per-dimension metric output, when it is available.
                if topk_dims_to_show > 0:
                    if 'tpl_per_dim' in met_outs:
                        avg_distance_per_dim = met_outs[
                            'tpl_per_dim']  # shape: (n_continuous)
                        # Indices sorted by descending value, truncated to
                        # topk_dims_to_show entries.
                        topk_dims = np.argsort(
                            avg_distance_per_dim
                        )[::-1][:topk_dims_to_show]  # shape: (topk_dims_to_show)
                    else:
                        topk_dims = np.arange(
                            min(topk_dims_to_show, n_continuous))
                else:
                    topk_dims = np.arange(n_continuous)

            if image_snapshot_ticks is not None and (
                    cur_tick % image_snapshot_ticks == 0 or done):
                if traversal_grid:
                    # Uses the most recent topk_dims (initial ordering until
                    # a network snapshot has updated it).
                    grid_size, grid_latents, grid_labels = get_grid_latents(
                        n_discrete, n_continuous, n_samples_per, G,
                        grid_labels, topk_dims)
                else:
                    # Fresh random latents are drawn for every snapshot.
                    grid_latents = np.random.randn(np.prod(grid_size),
                                                   *G.input_shape[1:])

                grid_fakes, _, _, _, _, _, _, lie_vars = get_return_v(
                    G.run(append_gfeats(grid_latents, G)
                          if forward_eg else grid_latents,
                          grid_labels,
                          is_validation=True,
                          minibatch_size=sched.minibatch_gpu,
                          randomize_noise=True), 8)
                print('Lie_vars:', lie_vars[0])
                grid_fakes = add_outline(grid_fakes, width=1)
                misc.save_image_grid(grid_fakes,
                                     dnnlib.make_run_dir_path(
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            # A constant 0 is reported in place of a lod value.
            dnnlib.RunContext.get().update('%.2f' % 0,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            # Time spent in this maintenance pass, reported on the next tick.
            maintenance_time = dnnlib.RunContext.get(
            ).get_last_update_interval() - tick_time

    # Save final snapshot.
    if use_D:
        misc.save_pkl((E, G, D), dnnlib.make_run_dir_path('network-final.pkl'))
    else:
        misc.save_pkl((E, G), dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
def _embed_write_series(path, values):
    """Write ``str(value)`` for each value to ``path``, one per line."""
    with open(path, 'w') as f:
        f.writelines(str(value) + '\n' for value in values)


def embed(batch_size,
          resolution,
          imgs,
          network,
          iteration,
          result_dir,
          seed=6600):
    """Embed images into the generator's dlatent space at several temperatures.

    For every temperature in a fixed list, the synthesis network's ``alpha``
    variables are replaced by ``scale_alpha(original_alpha, temperature)``.
    Then, for every image, a ``dlatent`` variable (initialised from the
    generator's ``dlatent_avg``) is optimised with Adam for ``iteration``
    steps against ``0.5 * mse + 0.5 * perceptual`` loss, where the perceptual
    term compares four VGG16 feature maps of the target and synthesized image.

    Outputs written to ``result_dir``:
      * ``temp<T>si<idx>.png`` / ``temp<T>sifinal<idx>.png``: snapshot grids;
      * ``temp<T>metric_{l,p,m,d}<idx>.txt``: per-step total / perceptual /
        mse / dlatent-distance curves for one image;
      * ``Temp<T>metric_lmpd.txt``: final ``loss mse ppl dl`` per image for
        that temperature;
      * ``mean_metrics``: per-temperature means (opened in append mode).

    Args:
        batch_size: Unused; kept for interface compatibility.
        resolution: Unused; kept for interface compatibility.
        imgs: Iterable of target images. Each one gets a leading batch axis
            and is fed to the graph. (assumes each image is (3, 128, 128) in
            [-1, 1], matching the fixed VGG input shape and the drange used
            when saving -- TODO confirm)
        network: Identifier passed to ``pretrained_networks.load_networks``.
        iteration: Number of optimisation steps per image; must be >= 1
            because the last recorded step is read after the loop.
        result_dir: Directory that receives all output files.
        seed: Seed for the RandomState used to set the noise variables.

    Returns:
        None.
    """
    tf.reset_default_graph()
    print('Loading networks from "%s"...' % network)
    tflib.init_tf()
    _, _, G = pretrained_networks.load_networks(network)
    img_in = tf.placeholder(tf.float32)
    opt = tf.train.AdamOptimizer(learning_rate=0.01,
                                 beta1=0.9,
                                 beta2=0.999,
                                 epsilon=1e-8)

    synthesis_vars = G.components.synthesis.vars
    noise_vars = [
        var for name, var in synthesis_vars.items()
        if name.startswith('noise')
    ]
    alpha_vars = [
        var for name, var in synthesis_vars.items() if name.endswith('alpha')
    ]
    # Snapshot the original alpha values; every temperature rescales from
    # these, not from the previously scaled values.
    alpha_eval = [alpha.eval() for alpha in alpha_vars]

    G_kwargs = dnnlib.EasyDict()
    G_kwargs.randomize_noise = False
    G_syn = G.components.synthesis
    rnd = np.random.RandomState(seed)

    dlatent_avg = [
        var for name, var in G.vars.items() if name.startswith('dlatent_avg')
    ][0].eval()
    dlatent_avg = np.expand_dims(np.expand_dims(dlatent_avg, 0), 1)
    # The repeat count 12 is hard-coded; presumably the number of dlatent
    # layers of the 128x128 generator -- TODO confirm.
    dlatent_avg = dlatent_avg.repeat(12, 1)
    dlatent = tf.get_variable('dlatent',
                              dtype=tf.float32,
                              initializer=tf.constant(dlatent_avg),
                              trainable=True)
    synth_img = G_syn.get_output_for(dlatent, is_training=False, **G_kwargs)
    # synth_img = (synth_img + 1.0) / 2.0

    with tf.variable_scope('mse_loss'):
        mse_loss = tf.reduce_mean(tf.square(img_in - synth_img))
    with tf.variable_scope('perceptual_loss'):
        # Target and synthesized image are stacked so that index 0 / 1 of
        # each feature map are the target / synthesized features.
        vgg_in = tf.concat([img_in, synth_img], 0)
        tf.keras.backend.set_image_data_format('channels_first')
        # The weights file is an absolute path and must exist on this host.
        vgg = tf.keras.applications.VGG16(
            include_top=False,
            input_tensor=vgg_in,
            input_shape=(3, 128, 128),
            weights='/gdata2/fengrl/metrics/vgg.h5',
            pooling=None)
        h1 = vgg.get_layer('block1_conv1').output
        h2 = vgg.get_layer('block1_conv2').output
        h3 = vgg.get_layer('block3_conv2').output
        h4 = vgg.get_layer('block4_conv2').output
        pcep_loss = tf.reduce_mean(tf.square(h1[0] - h1[1])) + tf.reduce_mean(tf.square(h2[0] - h2[1])) + \
                    tf.reduce_mean(tf.square(h3[0] - h3[1])) + tf.reduce_mean(tf.square(h4[0] - h4[1]))
    loss = 0.5 * mse_loss + 0.5 * pcep_loss
    with tf.control_dependencies([loss]):
        train_op = opt.minimize(loss, var_list=[dlatent])
    # Initializers used to restart the optimisation for every image.
    reset_opt = tf.variables_initializer(opt.variables())
    reset_dl = tf.variables_initializer([dlatent])

    tflib.init_uninitialized_vars()
    # rnd = np.random.RandomState(seed)
    tflib.set_vars(
        {var: rnd.randn(*var.shape.as_list())
         for var in noise_vars})  # [height, width]

    # idx keeps counting across temperatures (it is part of the per-image
    # file names), so it is intentionally NOT reset inside the loop.
    idx = 0
    # Only referenced by the disabled evaluation code inside the loop below.
    metrics_args = [metric_defaults[x] for x in ['fid50k', 'ppl_wend']]
    metrics_fun = metric_base.MetricGroup(metrics_args)

    for temperature in [0.2, 0.5, 1.0, 1.5, 2.0, 10.0]:
        # Per-temperature accumulators. They must be re-created for every
        # temperature: otherwise the means and the 'Temp%fmetric_lmpd.txt'
        # file reported for one temperature would also contain the results
        # of all previously processed temperatures.
        metrics_l = []
        metrics_p = []
        metrics_m = []
        metrics_d = []

        tflib.set_vars({
            alpha: scale_alpha(alpha_np, temperature)
            for alpha, alpha_np in zip(alpha_vars, alpha_eval)
        })
        # misc.save_pkl((G, G, G), os.path.join(result_dir, 'temp%f.pkl' % temperature))
        # metrics_fun.run(os.path.join(result_dir, 'temp%f.pkl' % temperature), run_dir=result_dir,
        #                 data_dir='/gdata/fengrl/noise_test_dset/tfrecords',
        #                 dataset_args=dnnlib.EasyDict(tfrecord_dir='ffhq-128', shuffle_mb=0),
        #                 mirror_augment=True, num_gpus=1)

        for img in imgs:
            img = np.expand_dims(img, 0)
            loss_list = []
            p_loss_list = []
            m_loss_list = []
            dl_list = []
            si_list = []
            # Restart from dlatent_avg with fresh optimizer state.
            tflib.run([reset_opt, reset_dl])
            for i in range(iteration):
                loss_, p_loss_, m_loss_, dl_, si_, _ = tflib.run(
                    [loss, pcep_loss, mse_loss, dlatent, synth_img, train_op],
                    {img_in: img})
                loss_list.append(loss_)
                p_loss_list.append(p_loss_)
                m_loss_list.append(m_loss_)
                # Squared distance of the current dlatent from the average.
                dl_loss_ = np.sum(np.square(dl_ - dlatent_avg))
                dl_list.append(dl_loss_)
                if i % 500 == 0:
                    si_list.append(si_)
                if i % 100 == 0:
                    print(
                        'Temperature %f, idx %d, Loss %f, mse %f, ppl %f, dl %f, step %d'
                        % (temperature, idx, loss_, m_loss_, p_loss_,
                           dl_loss_, i))
            print('Temperature %f, idx %d, loss: %f, ppl: %f, mse: %f, d: %f' %
                  (temperature, idx, loss_list[-1], p_loss_list[-1],
                   m_loss_list[-1], dl_list[-1]))
            metrics_l.append(loss_list[-1])
            metrics_p.append(p_loss_list[-1])
            metrics_m.append(m_loss_list[-1])
            metrics_d.append(dl_list[-1])

            misc.save_image_grid(np.concatenate(si_list, 0),
                                 os.path.join(
                                     result_dir,
                                     'temp%fsi%d.png' % (temperature, idx)),
                                 drange=[-1, 1])
            # NOTE(review): the 'final' image is the last snapshot taken at a
            # multiple of 500 steps, not the image of the last step.
            misc.save_image_grid(
                si_list[-1],
                os.path.join(result_dir,
                             'temp%fsifinal%d.png' % (temperature, idx)),
                drange=[-1, 1])

            # Per-step curves for this image.
            for tag, series in (('l', loss_list), ('p', p_loss_list),
                                ('m', m_loss_list), ('d', dl_list)):
                _embed_write_series(
                    os.path.join(
                        result_dir,
                        'temp%fmetric_%s%d.txt' % (temperature, tag, idx)),
                    series)
            idx += 1

        l_mean = np.mean(metrics_l)
        p_mean = np.mean(metrics_p)
        m_mean = np.mean(metrics_m)
        d_mean = np.mean(metrics_d)
        with open(
                os.path.join(result_dir,
                             'Temp%fmetric_lmpd.txt' % temperature),
                'w') as f:
            for l_, m_, p_, d_ in zip(metrics_l, metrics_m, metrics_p,
                                      metrics_d):
                f.write(
                    str(l_) + '    ' + str(m_) + '    ' + str(p_) + '    ' +
                    str(d_) + '\n')

        print(
            'Overall metrics: temp %f, loss_mean %f, ppl_mean %f, mse_mean %f, d_mean %f'
            % (temperature, l_mean, p_mean, m_mean, d_mean))
        with open(os.path.join(result_dir, 'mean_metrics'), 'a') as f:
            f.write('Temperature %f\n' % temperature)
            f.write('loss %f\n' % l_mean)
            f.write('mse %f\n' % m_mean)
            f.write('ppl %f\n' % p_mean)
            f.write('dl %f\n' % d_mean)
def embed(batch_size, resolution, imgs, network, iteration, result_dir, seed=6600):
    """Fit a dlatent and a scalar temperature `T` to every image in `imgs`.

    A fresh 128x128 'G' network is built, filled with the weights of the `Gs`
    loaded from `network`, and its synthesis component is driven by a trainable
    `dlatent` variable (initialised from the generator's `dlatent_avg`, repeated
    12 times along axis 1) together with a trainable scalar `T` (initialised to
    0.95) that rescales the stored alpha values through `scale_alpha_exp`.

    For each image both variables and both Adam optimizers are reset and
    `iteration` steps are run. Only the gradient of the MSE term is applied;
    the combined loss and the VGG feature loss are evaluated for logging.

    Outputs written into `result_dir` (which must already exist):
      * si<idx>.png / sifinal<idx>.png  - snapshots taken every 500 steps / last snapshot
      * metric_{l,p,m,d}<idx>.txt       - per-step loss curves for each image
      * metric_lmpd.txt                 - alpha values plus final T and metrics per image
      * mean_metrics.txt                - means of the final metrics over all images

    `batch_size` and `resolution` are accepted but not referenced in the body.
    The values labelled "ppl" in the logs are the VGG feature loss (`pcep_loss`).
    """
    tf.reset_default_graph()
    gen_cfg = dnnlib.EasyDict(func_name='training.networks_stylegan2_alpha.G_main')
    gen_cfg.fmap_base = 8 << 10
    print('Loading networks from "%s"...' % network)
    tflib.init_tf()
    G = tflib.Network('G', num_channels=3, resolution=128, **gen_cfg)
    _, _, pretrained_Gs = pretrained_networks.load_networks(network)
    G.copy_vars_from(pretrained_Gs)

    img_in = tf.placeholder(tf.float32)
    opt = tf.train.AdamOptimizer(learning_rate=0.01, beta1=0.9, beta2=0.999, epsilon=1e-8)
    opt_T = tf.train.AdamOptimizer(learning_rate=0.002, beta1=0.9, beta2=0.999, epsilon=1e-8)

    # Pick out the per-layer noise buffers and the alpha variables by name.
    noise_vars = [v for key, v in G.components.synthesis.vars.items() if key.startswith('noise')]
    alpha_vars = [v for key, v in G.components.synthesis.vars.items() if key.endswith('alpha')]
    alpha_evals = [a.eval() for a in alpha_vars]

    G_kwargs = dnnlib.EasyDict()
    G_kwargs.randomize_noise = False
    G_syn = G.components.synthesis
    rnd = np.random.RandomState(seed)

    # Broadcast the average dlatent to 12 rows along axis 1.
    avg_var = [v for key, v in G.vars.items() if key.startswith('dlatent_avg')][0]
    dlatent_avg = np.repeat(avg_var.eval()[np.newaxis, np.newaxis], 12, axis=1)

    dlatent = tf.get_variable('dlatent', dtype=tf.float32,
                              initializer=tf.constant(dlatent_avg), trainable=True)
    T = tf.get_variable('T', dtype=tf.float32, initializer=tf.constant(0.95))
    alpha_pre = [scale_alpha_exp(a_np, T) for a_np in alpha_evals]
    synth_img = G_syn.get_output_for(dlatent, is_training=False, alpha_pre=alpha_pre, **G_kwargs)

    with tf.variable_scope('mse_loss'):
        mse_loss = tf.reduce_mean(tf.square(img_in - synth_img))
    with tf.variable_scope('perceptual_loss'):
        # Target and synthesised image go through VGG16 as a batch of two.
        vgg_in = tf.concat([img_in, synth_img], 0)
        tf.keras.backend.set_image_data_format('channels_first')
        vgg = tf.keras.applications.VGG16(include_top=False, input_tensor=vgg_in,
                                          input_shape=(3, 128, 128),
                                          weights='/gdata2/fengrl/metrics/vgg.h5',
                                          pooling=None)
        feats = [vgg.get_layer(layer_name).output
                 for layer_name in ('block1_conv1', 'block1_conv2', 'block3_conv2', 'block4_conv2')]
        layer_terms = [tf.reduce_mean(tf.square(h[0] - h[1])) for h in feats]
        pcep_loss = layer_terms[0]
        for term in layer_terms[1:]:
            pcep_loss = pcep_loss + term
    loss = 0.5 * mse_loss + 0.5 * pcep_loss

    with tf.control_dependencies([loss]):
        # Gradients are taken w.r.t. the MSE term only.
        grads = tf.gradients(mse_loss, [dlatent, T])
        train_op1 = opt.apply_gradients(zip([grads[0]], [dlatent]))
        train_op2 = opt_T.apply_gradients(zip([grads[1]], [T]))
        train_op = tf.group(train_op1, train_op2)
    reset_opt = tf.variables_initializer(opt.variables() + opt_T.variables())
    reset_dl = tf.variables_initializer([dlatent, T])

    tflib.init_uninitialized_vars()
    tflib.set_vars({v: rnd.randn(*v.shape.as_list()) for v in noise_vars})  # [height, width]

    def _write_series(name_fmt, index, values):
        # One value per line; `name_fmt` receives the image index.
        with open(os.path.join(result_dir, name_fmt % index), 'w') as fh:
            fh.writelines(str(v) + '\n' for v in values)

    metrics_l, metrics_p, metrics_m, metrics_d = [], [], [], []
    T_list = []
    fetches = [loss, pcep_loss, mse_loss, dlatent, synth_img, T, train_op]

    for idx, raw_img in enumerate(imgs):
        feed = {img_in: np.expand_dims(raw_img, 0)}
        loss_list, p_loss_list, m_loss_list, dl_list, si_list = [], [], [], [], []

        # Start every image from the same latent, temperature and optimizer state.
        tflib.run([reset_opt, reset_dl])
        for i in range(iteration):
            loss_, p_loss_, m_loss_, dl_, si_, t_, _ = tflib.run(fetches, feed)
            dl_loss_ = np.square(dl_ - dlatent_avg).sum()
            loss_list.append(loss_)
            p_loss_list.append(p_loss_)
            m_loss_list.append(m_loss_)
            dl_list.append(dl_loss_)
            if i % 500 == 0:
                si_list.append(si_)
            if i % 100 == 0:
                print('idx %d, Loss %f, mse %f, ppl %f, dl %f, t %f, step %d' %
                      (idx, loss_, m_loss_, p_loss_, dl_loss_, t_, i))

        print('T: %f, loss: %f, ppl: %f, mse: %f, d: %f' %
              (t_, loss_list[-1], p_loss_list[-1], m_loss_list[-1], dl_list[-1]))
        metrics_l.append(loss_list[-1])
        metrics_p.append(p_loss_list[-1])
        metrics_m.append(m_loss_list[-1])
        metrics_d.append(dl_list[-1])
        T_list.append(t_)

        misc.save_image_grid(np.concatenate(si_list, 0),
                             os.path.join(result_dir, 'si%d.png' % idx), drange=[-1, 1])
        misc.save_image_grid(si_list[-1],
                             os.path.join(result_dir, 'sifinal%d.png' % idx), drange=[-1, 1])
        _write_series('metric_l%d.txt', idx, loss_list)
        _write_series('metric_p%d.txt', idx, p_loss_list)
        _write_series('metric_m%d.txt', idx, m_loss_list)
        _write_series('metric_d%d.txt', idx, dl_list)

    l_mean = np.mean(metrics_l)
    p_mean = np.mean(metrics_p)
    m_mean = np.mean(metrics_m)
    d_mean = np.mean(metrics_d)

    with open(os.path.join(result_dir, 'metric_lmpd.txt'), 'w') as fh:
        fh.write(str(alpha_evals) + '\n')
        for row in zip(T_list, metrics_l, metrics_m, metrics_p, metrics_d):
            fh.write(' '.join(str(v) for v in row) + '\n')

    print('Overall metrics: loss_mean %f, ppl_mean %f, mse_mean %f, d_mean %f' %
          (l_mean, p_mean, m_mean, d_mean))

    with open(os.path.join(result_dir, 'mean_metrics.txt'), 'w') as fh:
        fh.write('loss %f\n' % l_mean)
        fh.write('mse %f\n' % m_mean)
        fh.write('ppl %f\n' % p_mean)
        fh.write('dl %f\n' % d_mean)
def generate_traversals(network_pkl, seeds, tpl_metric, n_samples_per,
                        topk_dims_to_show, return_atts=False, bound=2):
    """Render latent-traversal grids from the `Gs` network stored in `network_pkl`.

    Args:
        network_pkl: Pickle passed to `misc.load_pkl`; unpacked through
            `get_return_v(..., 4)` as `_G, _D, I, Gs`. Only `Gs` is used.
        seeds: Iterable of integers. One grid is produced per entry; the value is
            only used to name the output files (no RNG is seeded in this function).
        tpl_metric: Key into `metric_defaults` naming the metric used to rank
            latent dimensions.
        n_samples_per: Forwarded to `get_grid_latents` and `save_atts`.
        topk_dims_to_show: If > 0, the metric is evaluated and the dimensions with
            the largest 'tpl_per_dim' values are kept (or simply the first
            `min(topk_dims_to_show, n_continuous)` dimensions when the metric does
            not report 'tpl_per_dim'). Otherwise all dimensions are shown.
        return_atts: When true, the second output of `Gs.run` is sliced to the
            selected dimensions and written with `save_atts`.
        bound: Accepted but not referenced in the body.

    Side effects:
        Writes 'travs_seed%04d.png' (and 'atts_seed%04d.png' when `return_atts`)
        into the current run directory.
    """
    tflib.init_tf()
    print('Loading networks from "%s"...' % network_pkl)
    # _G, _D, Gs = pretrained_networks.load_networks(network_pkl)
    _G, _D, I, Gs = get_return_v(misc.load_pkl(network_pkl), 4)
    # noise_vars = [var for name, var in Gs.components.synthesis.vars.items() if name.startswith('noise')]

    Gs_kwargs = dnnlib.EasyDict()
    # Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    Gs_kwargs.randomize_noise = True
    if return_atts:
        # Store a real bool; a trailing comma here would store the tuple (True,).
        Gs_kwargs.return_atts = True
    # NOTE(review): Gs_kwargs is never forwarded to Gs.run below - confirm whether
    # `return_atts` is meant to reach the network call.

    n_continuous = Gs.input_shape[1]
    grid_labels = np.zeros([1, 0], dtype=np.float32)

    # Eval tpl
    if topk_dims_to_show > 0:
        metric_args = metric_defaults[tpl_metric]
        metric = dnnlib.util.call_func_by_name(**metric_args)
        met_outs = metric._evaluate(Gs, {}, 1)
        if 'tpl_per_dim' in met_outs:
            avg_distance_per_dim = met_outs['tpl_per_dim']  # shape: (n_continuous)
            # Largest scores first, truncated to the requested count.
            topk_dims = np.argsort(avg_distance_per_dim)[::-1][:topk_dims_to_show]
        else:
            topk_dims = np.arange(min(topk_dims_to_show, n_continuous))
    else:
        topk_dims = np.arange(n_continuous)

    for seed in seeds:
        # grid_labels is fed back in and replaced by the helper's output each round.
        grid_size, grid_latents, grid_labels = get_grid_latents(
            0, n_continuous, n_samples_per, Gs, grid_labels, topk_dims)
        # images, _ = Gs.run(z, None, **Gs_kwargs)  # [minibatch, height, width, channel]
        grid_fakes, atts = get_return_v(
            Gs.run(grid_latents, grid_labels, is_validation=True,
                   minibatch_size=2, randomize_noise=True), 2)
        if return_atts:
            atts = atts[:, topk_dims]
            save_atts(atts,
                      filename=dnnlib.make_run_dir_path('atts_seed%04d.png' % seed),
                      grid_size=grid_size,
                      drange=[0, 1],
                      grid_fakes=grid_fakes,
                      n_samples_per=n_samples_per)
        grid_fakes = add_outline(grid_fakes, width=1)
        misc.save_image_grid(grid_fakes,
                             dnnlib.make_run_dir_path('travs_seed%04d.png' % seed),
                             drange=[-1., 1.],
                             grid_size=grid_size)
def generate_images(network_pkl, network_G_pkl, n_imgs, model_type, n_discrete,
                    n_continuous, use_std_in_m=None, latent_type='uniform',
                    n_samples_per=10):
    """Generate traversal grids and GIFs by chaining network `M` into a pretrained `Gs`.

    `network_pkl` is unpickled into four objects when `model_type` is
    'hd_dis_model_with_cls' and into three otherwise; only the second one (`M`)
    is used. `network_G_pkl` supplies `_G` (passed to `get_grid_latents`) and
    `Gs` (used for synthesis).

    For each of the `n_imgs` rounds the function writes 'img_%04d.png' (the full
    grid) and 'latents_trav_%04d.gif' (one frame per sample index, with the
    `n_continuous` rows concatenated along the last axis) into the run directory.

    When `use_std_in_m` is not None only the first half of the columns of `M`'s
    output is forwarded to `Gs`.
    """
    print('Loading networks from "%s"...' % network_pkl)
    tflib.init_tf()
    if model_type == 'hd_dis_model_with_cls':
        _I, M, _Is, _I_info = misc.load_pkl(network_pkl)
    else:
        _I, M, _Is = misc.load_pkl(network_pkl)

    # Load pretrained GAN
    _G, _D, Gs = misc.load_pkl(network_G_pkl)

    # NOTE(review): this kwargs bundle is built but not passed to Gs.run below.
    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    Gs_kwargs.randomize_noise = False

    for idx in range(n_imgs):
        print('Generating image %d/%d ...' % (idx, n_imgs))

        # Zero-width label matrix with one row per grid cell.
        if n_discrete == 0:
            n_rows = n_continuous * n_samples_per
        else:
            n_rows = n_discrete * n_continuous * n_samples_per
        grid_labels = np.zeros([n_rows, 0], dtype=np.float32)

        grid_size, grid_latents, grid_labels = get_grid_latents(
            n_discrete, n_continuous, n_samples_per, _G, grid_labels,
            latent_type=latent_type)

        prior_traj_latents = M.run(grid_latents, is_validation=True, minibatch_size=4)
        if use_std_in_m is not None:
            half = prior_traj_latents.shape[1] // 2
            prior_traj_latents = prior_traj_latents[:, :half]

        grid_fakes = Gs.run(prior_traj_latents, grid_labels, is_validation=True,
                            minibatch_size=4, randomize_noise=False)
        print(grid_fakes.shape)
        misc.save_image_grid(grid_fakes,
                             dnnlib.make_run_dir_path('img_%04d.png' % idx),
                             drange=[-1, 1],
                             grid_size=grid_size)

        # Regroup as [dimension, sample, <three trailing image axes>].
        trailing = [grid_fakes.shape[1], grid_fakes.shape[2], grid_fakes.shape[3]]
        grid_fakes = np.reshape(grid_fakes, [n_continuous, n_samples_per] + trailing)

        # One GIF frame per sample index: all dimensions side by side on axis 2.
        # NOTE(review): convert_to_pil_image is called without a drange while the
        # grid above is saved with drange=[-1, 1] - confirm the helper's default.
        frames = [
            misc.convert_to_pil_image(
                np.concatenate(tuple(grid_fakes[j, i] for j in range(n_continuous)), axis=2))
            for i in range(n_samples_per)
        ]
        frames[0].save(dnnlib.make_run_dir_path('latents_trav_%04d.gif' % idx),
                       format='GIF',
                       append_images=frames[1:],
                       save_all=True,
                       duration=100,
                       loop=0)
def embed(batch_size, resolution, imgs, network, iteration, result_dir, seed=6600):
    """Project every image in `imgs` into the dlatent space of a pretrained generator.

    The third network returned by `pretrained_networks.load_networks(network)` is
    used as the generator. A trainable `dlatent` (initialised from `dlatent_avg`,
    repeated 12 times along axis 1) is optimised with Adam for `iteration` steps
    per image. The learning rate follows an exponential decay (0.01, halved every
    3500 steps) driven by the `step` variable, which is advanced on every
    training step and reset together with the latent for each new image.

    Only the MSE term is minimised; the combined loss and the VGG feature loss
    are evaluated for logging. Values labelled "ppl" in the logs are the VGG
    feature loss (`pcep_loss`).

    Args:
        batch_size: Accepted but not referenced in the body.
        resolution: Accepted but not referenced in the body (the VGG input shape
            is fixed at 3x128x128).
        imgs: Iterable of images; each gets a leading axis of size 1 before being fed.
        network: Identifier passed to `pretrained_networks.load_networks`.
        iteration: Number of optimisation steps per image; must be >= 1.
        result_dir: Output directory; created if it does not exist.
        seed: Seed for the NumPy RNG that fills the synthesis noise variables.

    Raises:
        ValueError: If `iteration` is smaller than 1.

    Outputs written into `result_dir`:
      * si<idx>.png / sifinal<idx>.png  - snapshots taken every 500 steps / last snapshot
      * metric_{l,p,m,d}<idx>.txt       - per-step loss curves for each image
      * metric_lmpd.txt                 - final loss / mse / feature loss / latent distance per image
      * mean_metrics.txt                - means of those final values
    """
    # Fail fast: with zero steps the bookkeeping below has no last value or
    # snapshot to report and would crash only after the graph has been built.
    if iteration < 1:
        raise ValueError('iteration must be >= 1, got %r' % (iteration,))
    os.makedirs(result_dir, exist_ok=True)

    tf.reset_default_graph()
    print('Loading networks from "%s"...' % network)
    tflib.init_tf()
    _, _, G = pretrained_networks.load_networks(network)

    img_in = tf.placeholder(tf.float32)
    step = tf.get_variable('step', dtype=tf.float32, initializer=tf.constant(0.0))
    lr = tf.train.exponential_decay(0.01, global_step=step, decay_steps=3500, decay_rate=0.5)
    # Feed the decayed schedule to Adam; previously the optimizer was built with
    # a constant 0.01 and `lr` was computed but never used.
    opt = tf.train.AdamOptimizer(learning_rate=lr, beta1=0.9, beta2=0.999, epsilon=1e-8)

    noise_vars = [var for name, var in G.components.synthesis.vars.items()
                  if name.startswith('noise')]
    G_kwargs = dnnlib.EasyDict()
    G_kwargs.randomize_noise = False
    G_syn = G.components.synthesis
    rnd = np.random.RandomState(seed)

    dlatent_avg = [var for name, var in G.vars.items()
                   if name.startswith('dlatent_avg')][0].eval()
    dlatent_avg = np.expand_dims(np.expand_dims(dlatent_avg, 0), 1)
    dlatent_avg = dlatent_avg.repeat(12, 1)
    dlatent = tf.get_variable('dlatent', dtype=tf.float32,
                              initializer=tf.constant(dlatent_avg), trainable=True)
    synth_img = G_syn.get_output_for(dlatent, is_training=False, **G_kwargs)
    # synth_img = (synth_img + 1.0) / 2.0

    with tf.variable_scope('mse_loss'):
        mse_loss = tf.reduce_mean(tf.square(img_in - synth_img))
    with tf.variable_scope('perceptual_loss'):
        # Target and synthesised image go through VGG16 as a batch of two.
        vgg_in = tf.concat([img_in, synth_img], 0)
        tf.keras.backend.set_image_data_format('channels_first')
        vgg = tf.keras.applications.VGG16(include_top=False,
                                          input_tensor=vgg_in,
                                          input_shape=(3, 128, 128),
                                          weights='/gdata2/fengrl/metrics/vgg.h5',
                                          pooling=None)
        h1 = vgg.get_layer('block1_conv1').output
        h2 = vgg.get_layer('block1_conv2').output
        h3 = vgg.get_layer('block3_conv2').output
        h4 = vgg.get_layer('block4_conv2').output
        pcep_loss = tf.reduce_mean(tf.square(h1[0] - h1[1])) + tf.reduce_mean(tf.square(h2[0] - h2[1])) + \
                    tf.reduce_mean(tf.square(h3[0] - h3[1])) + tf.reduce_mean(tf.square(h4[0] - h4[1]))
    loss = 0.5 * mse_loss + 0.5 * pcep_loss

    # Each train step also evaluates `loss` and advances the decay counter.
    with tf.control_dependencies([loss, tf.assign_add(step, 1.0)]):
        train_op = opt.minimize(mse_loss, var_list=[dlatent])
    reset_opt = tf.variables_initializer(opt.variables())
    reset_dl = tf.variables_initializer([dlatent, step])

    tflib.init_uninitialized_vars()
    # rnd = np.random.RandomState(seed)
    tflib.set_vars({var: rnd.randn(*var.shape.as_list()) for var in noise_vars})  # [height, width]

    def _write_series(name_fmt, index, values):
        # One value per line; `name_fmt` receives the image index.
        with open(os.path.join(result_dir, name_fmt % index), 'w') as f:
            f.writelines(str(v) + '\n' for v in values)

    metrics_l = []
    metrics_p = []
    metrics_m = []
    metrics_d = []
    for idx, img in enumerate(imgs):
        img = np.expand_dims(img, 0)
        loss_list = []
        p_loss_list = []
        m_loss_list = []
        dl_list = []
        si_list = []

        # Restart latent, step counter and Adam state for every image.
        tflib.run([reset_opt, reset_dl])
        print(lr.eval())
        for i in range(iteration):
            loss_, p_loss_, m_loss_, dl_, si_, _ = tflib.run(
                [loss, pcep_loss, mse_loss, dlatent, synth_img, train_op],
                {img_in: img})
            loss_list.append(loss_)
            p_loss_list.append(p_loss_)
            m_loss_list.append(m_loss_)
            dl_loss_ = np.sum(np.square(dl_ - dlatent_avg))
            dl_list.append(dl_loss_)
            if i % 500 == 0:
                si_list.append(si_)
            if i % 100 == 0:
                print('idx %d, Loss %f, mse %f, ppl %f, dl %f, step %d' %
                      (idx, loss_, m_loss_, p_loss_, dl_loss_, i))

        print('loss: %f, ppl: %f, mse: %f, d: %f' %
              (loss_list[-1], p_loss_list[-1], m_loss_list[-1], dl_list[-1]))
        metrics_l.append(loss_list[-1])
        metrics_p.append(p_loss_list[-1])
        metrics_m.append(m_loss_list[-1])
        metrics_d.append(dl_list[-1])

        misc.save_image_grid(np.concatenate(si_list, 0),
                             os.path.join(result_dir, 'si%d.png' % idx),
                             drange=[-1, 1])
        misc.save_image_grid(si_list[-1],
                             os.path.join(result_dir, 'sifinal%d.png' % idx),
                             drange=[-1, 1])
        _write_series('metric_l%d.txt', idx, loss_list)
        _write_series('metric_p%d.txt', idx, p_loss_list)
        _write_series('metric_m%d.txt', idx, m_loss_list)
        _write_series('metric_d%d.txt', idx, dl_list)

    l_mean = np.mean(metrics_l)
    p_mean = np.mean(metrics_p)
    m_mean = np.mean(metrics_m)
    d_mean = np.mean(metrics_d)

    with open(os.path.join(result_dir, 'metric_lmpd.txt'), 'w') as f:
        for row in zip(metrics_l, metrics_m, metrics_p, metrics_d):
            f.write(' '.join(str(v) for v in row) + '\n')

    print('Overall metrics: loss_mean %f, ppl_mean %f, mse_mean %f, d_mean %f' %
          (l_mean, p_mean, m_mean, d_mean))

    with open(os.path.join(result_dir, 'mean_metrics.txt'), 'w') as f:
        f.write('loss %f\n' % l_mean)
        f.write('mse %f\n' % m_mean)
        f.write('ppl %f\n' % p_mean)
        f.write('dl %f\n' % d_mean)
def training_loop(
    G_args={},  # Options for generator network.
    D_args={},  # Options for discriminator network.
    G_opt_args={},  # Options for generator optimizer.
    D_opt_args={},  # Options for discriminator optimizer.
    loss_args={},  # Options for loss.
    dataset_args={},  # Options for dataset.load_dataset().
    sched_args={},  # Options for train.TrainingSchedule.
    grid_args={},  # Options for train.setup_snapshot_image_grid().
    metric_arg_list=[],  # Options for metrics.
    metric_args={},  # Options for MetricGroup.
    tf_config={},  # Options for tflib.init_tf().
    ema_start_kimg=None,  # Start of the exponential moving average. Default to the half-life period.
    G_ema_kimg=10,  # Half-life of the exponential moving average of generator weights.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    lazy_regularization=False,  # Perform regularization as a separate training step?
    G_reg_interval=4,  # How often the perform regularization for G? Ignored if lazy_regularization=False.
    D_reg_interval=4,  # How often the perform regularization for D? Ignored if lazy_regularization=False.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=25000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=2,  # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'.
    network_snapshot_ticks=1,  # How often to save network snapshots? None = only save 'networks-final.pkl'.
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_pkl=None,  # Network pickle to resume training from, None = train from scratch.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
    resume_with_new_nets=False):  # Construct new networks according to G_args and D_args before resuming training?
    """Build the G/D training graph and run it until `total_kimg` thousand images are processed.

    Flow, as implemented below:
      1. Initialise TensorFlow, load the dataset and save a grid of real images.
      2. Construct G, D and Gs and/or load them from `resume_pkl`.
      3. Build one loss/gradient tower per GPU, registered with two
         `tflib.Optimizer` instances ('TrainG', 'TrainD').
      4. Loop: query `training_schedule`, run `G_train_op` then `D_train_op` for
         each accumulation round, and once per tick print progress, save image
         and network snapshots, run metrics and write summaries.
      5. Save 'network-final.pkl' and close the summary writer and dataset.

    Notes:
      * The dict/list defaults are shared objects; this function does not mutate
        them (`G_opt_args` / `D_opt_args` are copied before being modified).
      * `G_reg_interval` is accepted but not referenced in the body.
      * The Gs update is attached as a control dependency of `G_train_op` only.
      * The initial fake grid is written as 'fakes_init.png' (underscore).
    """
    if ema_start_kimg is None:
        ema_start_kimg = G_ema_kimg

    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(verbose=True, **dataset_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        training_set, **grid_args)
    misc.save_image_grid(grid_reals,
                         dnnlib.make_run_dir_path('reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)

    # Construct or load networks.
    with tf.device('/gpu:0'):
        if resume_pkl is None or resume_with_new_nets:
            print('Constructing networks...')
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **G_args)
            D = tflib.Network('D',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **D_args)
            Gs = G.clone('Gs')
        if resume_pkl is not None:
            resume_networks = misc.load_pkl(resume_pkl)
            rG, rD, rGs = resume_networks
            if resume_with_new_nets:
                # Keep the freshly built architectures, take the weights from the pickle.
                G.copy_vars_from(rG)
                D.copy_vars_from(rD)
                Gs.copy_vars_from(rGs)
            else:
                G, D, Gs = rG, rD, rGs

    # Print layers and generate initial image snapshot.
    G.print_layers()
    D.print_layers()
    # The schedule is evaluated at the end of training here; only its
    # minibatch_gpu value is used, as the inference batch size.
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              **sched_args)
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch_gpu)
    misc.save_image_grid(grid_fakes,
                         dnnlib.make_run_dir_path('fakes_init.png'),
                         drange=drange_net,
                         grid_size=grid_size)

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        G_lrate_in = tf.placeholder(tf.float32, name='G_lrate_in', shape=[])
        D_lrate_in = tf.placeholder(tf.float32, name='D_lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32, name='minibatch_size_in', shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32, name='minibatch_gpu_in', shape=[])
        run_D_reg_in = tf.placeholder(tf.bool, name='run_D_reg', shape=[])
        # Number of accumulation rounds needed to cover one full minibatch.
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in * num_gpus)
        # Fed as 0 before ema_start_kimg and 1 afterwards (see feed_dict below).
        Gs_beta_mul_in = tf.placeholder(tf.float32, name='Gs_beta_in', shape=[])
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_size_in, tf.float32), G_ema_kimg * 1000.0) if G_ema_kimg > 0.0 else 0.0

    # Setup optimizers.
    # Copy first so the caller's dicts (and the shared defaults) stay untouched.
    G_opt_args = dict(G_opt_args)
    D_opt_args = dict(D_opt_args)
    G_opt_args['learning_rate'] = G_lrate_in
    D_opt_args['learning_rate'] = D_lrate_in
    for args in [G_opt_args, D_opt_args]:
        args['minibatch_multiplier'] = minibatch_multiplier
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', **D_opt_args)

    # Build training graph for each GPU.
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            with tf.name_scope('DataFetch'):
                reals_read, labels_read = training_set.get_minibatch_tf()
                reals_read = process_reals(reals_read, lod_in, mirror_augment,
                                           training_set.dynamic_range, drange_net)

            # Create GPU-specific shadow copies of G and D.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')

            # Evaluate loss functions.
            # Networks that expose a 'lod' variable get it assigned before the loss runs.
            lod_assign_ops = []
            if 'lod' in G_gpu.vars:
                lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)]
            if 'lod' in D_gpu.vars:
                lod_assign_ops += [tf.assign(D_gpu.vars['lod'], lod_in)]
            with tf.control_dependencies(lod_assign_ops):
                with tf.name_scope('loss'):
                    G_loss, D_loss, D_reg = dnnlib.util.call_func_by_name(
                        G=G_gpu,
                        D=D_gpu,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        reals=reals_read,
                        real_labels=labels_read,
                        **loss_args)

            # Register gradients.
            if not lazy_regularization:
                if D_reg is not None:
                    D_loss += D_reg
            else:
                if D_reg is not None:
                    # Add the regulariser only on steps where run_D_reg_in is fed
                    # True, scaled by the interval.
                    D_loss = tf.cond(run_D_reg_in,
                                     lambda: D_loss + D_reg * D_reg_interval,
                                     lambda: D_loss)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)

    # Setup training ops.
    # beta is Gs_beta once the EMA has started and 0 before that.
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta * Gs_beta_mul_in)
    with tf.control_dependencies([Gs_update_op]):
        G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            # Fall back to a constant so the progress line can still evaluate it.
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list, **metric_args)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1  # -1 forces a maintenance tick after the first batch of repeats.
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu)
        if reset_opt_for_new_lod:
            # Reset when the lod crosses an integer boundary in either direction.
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        feed_dict = {
            lod_in: sched.lod,
            G_lrate_in: sched.G_lrate,
            D_lrate_in: sched.D_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu,
            Gs_beta_mul_in: 1 if cur_nimg >= ema_start_kimg * 1000 else 0,
        }
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size, sched.minibatch_gpu * num_gpus)
            run_D_reg = (lazy_regularization and running_mb_counter % D_reg_interval == 0)
            feed_dict[run_D_reg_in] = run_D_reg
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            for _ in rounds:
                tflib.run(G_train_op, feed_dict)
                tflib.run(D_train_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start() + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch_size),
                   dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (cur_tick % image_snapshot_ticks == 0 or done):
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch_gpu)
                misc.save_image_grid(grid_fakes,
                                     dnnlib.make_run_dir_path('fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
            if network_snapshot_ticks is not None and (cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                # Metrics are evaluated on the snapshot that was just written.
                metrics.run(pkl,
                            run_dir=dnnlib.make_run_dir_path(),
                            num_gpus=num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % sched.lod,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            # Time spent on the maintenance work above, excluding training time.
            maintenance_time = dnnlib.RunContext.get().get_last_update_interval() - tick_time

    # Save final snapshot.
    misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
def joint_train(
    submit_config,
    opt,
    metric_arg_list,
    sched_args = {},  # Options for the training schedule.
    grid_args = {},  # Options for setup_snapshot_image_grid().
    dataset_args = {},  # Options for the dataset.
    total_kimg = 15000,  # Total length of the training, measured in thousands of real images.
    drange_net = [-1,1],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks = 1,  # How often to export image snapshots?
    network_snapshot_ticks = 10,  # How often to export network snapshots?
    D_repeats = 1,  # How many times the discriminator is trained per G iteration.
    minibatch_repeats = 4,  # Number of minibatches to run before adjusting training parameters.
    mirror_augment = False,  # Enable mirror augment?
    reset_opt_for_new_lod = True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    save_tf_graph = False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms = False,  # Include weight histograms in the tfevents file?
    resume_run_id = None,  # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot = None,  # Snapshot index to resume training from, None = autodetect.
    resume_kimg = 0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time = 0.0,  # Assumed wallclock time at the beginning. Affects reporting.
    *args,
    **kwargs
    ):
    """Train the graph object selected by `opt.model` / `opt.transform` on a tick schedule.

    The model class is looked up with `graphs.find_model_using_name`, built with
    `submit_config`, `dataset_args`, the kwargs derived from `opt` and any extra
    `**kwargs`, and initialised. Each training step runs `g.joint_loss`,
    `g.train_step`, `g.Gs_update_op` and `g.G_train_op` in a single `g.sess.run`.

    Side effects: writes reals.png, fakes%06d.png, network-snapshot-%06d.pkl and
    network-final.pkl under `submit_config.run_dir`; writes model_<kimg>.ckpt,
    loss_values.npy and loss_values.png under './<opt.output_dir>/'; writes
    tfevents summaries.

    Notes:
      * `D_repeats`, `mirror_augment`, `resume_run_id`, `resume_snapshot` and
        `*args` are accepted but not referenced in the body.
      * NOTE(review): `tf_config`, `train` and `config` are used below but are
        not parameters or locals of this function - presumably module-level
        names; confirm `tf_config` in particular, since a caller-supplied
        `tf_config=` keyword would land in `**kwargs` instead.
    """
    output_dir = opt.output_dir
    graph_kwargs = util.set_graph_kwargs(opt)
    # Model-specific helper modules are resolved dynamically from opt.model.
    graph_util = importlib.import_module('graphs.' + opt.model + '.graph_util')
    constants = importlib.import_module('graphs.' + opt.model + '.constants')
    model = graphs.find_model_using_name(opt.model, opt.transform)
    g = model(submit_config=submit_config, dataset_args=dataset_args, **graph_kwargs, **kwargs)
    g.initialize_graph()

    # create training samples
    #num_samples = opt.num_samples
    # if opt.model == 'biggan' and opt.biggan.category is not None:
    #     graph_inputs = graph_util.graph_input(g, num_samples, seed=0, category=opt.biggan.category)
    # else:
    #     graph_inputs = graph_util.graph_input(g, num_samples, seed=0)

    w_snapshot_ticks = opt.model_save_freq
    ctx = dnnlib.RunContext(submit_config, train)
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args)
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            # Fall back to a constant so the progress line can still evaluate it.
            peak_gpu_mem_op = tf.constant(0)

    # Set up the snapshot image grid.
    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(g.G, training_set, **grid_args)
    # The schedule is evaluated at the end of training here; only its minibatch
    # value is used, as the inference batch size.
    sched = training_loop.training_schedule(cur_nimg=total_kimg*1000, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args)
    grid_fakes = g.Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch//submit_config.num_gpus)

    # Set up the run directory.
    print('Setting up run dir...')
    misc.save_image_grid(grid_reals, os.path.join(submit_config.run_dir, 'reals.png'), drange=training_set.dynamic_range, grid_size=grid_size)
    misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg), drange=drange_net, grid_size=grid_size)
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        g.G.setup_weight_histograms(); g.D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    # Train.
    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    loss_values = []  # One entry per training step; saved and plotted at the end.
    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop(): break

        # Choose training parameters and configure training ops.
        sched = training_loop.training_schedule(cur_nimg=cur_nimg, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus, sched.lod)
        if reset_opt_for_new_lod:
            # Reset when the lod crosses an integer boundary in either direction.
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod):
                g.G_opt.reset_optimizer_state(); # D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            alpha_for_graph, alpha_for_target = g.get_train_alpha(constants.BATCH_SIZE)
            # Normalise the single-alpha case to one-element lists.
            if not isinstance(alpha_for_graph, list):
                alpha_for_graph = [alpha_for_graph]
                alpha_for_target = [alpha_for_target]
            for ag, at in zip(alpha_for_graph, alpha_for_target):
                feed_dict_out = graph_util.graph_input(g, constants.BATCH_SIZE, seed=0)
                # Run the untransformed outputs first, then derive target and
                # mask from them for the target alpha.
                out_zs = g.sess.run(g.outputs_orig, feed_dict_out)
                target_fn, mask_out = g.get_target_np(out_zs, at)
                # Same dict object as feed_dict_out; the entries below are added in place.
                feed_dict = feed_dict_out
                feed_dict[g.alpha] = ag
                feed_dict[g.target] = target_fn
                feed_dict[g.mask] = mask_out
                feed_dict[g.lod_in] = sched.lod
                feed_dict[g.lrate_in] = sched.D_lrate
                feed_dict[g.minibatch_in] = sched.minibatch
                curr_loss, _, Gs_op, G_op = g.sess.run([g.joint_loss, g.train_step, g.Gs_update_op, g.G_train_op], feed_dict=feed_dict)
                loss_values.append(curr_loss)
                cur_nimg += sched.minibatch
            #tflib.run([g.Gs_update_op], {lod_in: sched.lod, lrate_in: sched.D_lrate, minibatch_in: sched.minibatch})
            #tflib.run([g.G_train_op], {lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_in: sched.minibatch})

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print('tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f' % (
                autosummary('Progress/tick', cur_tick),
                autosummary('Progress/kimg', cur_nimg / 1000.0),
                autosummary('Progress/lod', sched.lod),
                autosummary('Progress/minibatch', sched.minibatch),
                dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                autosummary('Timing/sec_per_tick', tick_time),
                autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                autosummary('Timing/maintenance_sec', maintenance_time),
                autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = g.Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch//submit_config.num_gpus)
                misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size)
            # A network snapshot is also forced on the very first tick.
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(submit_config.run_dir, 'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((g.G, g.D, g.Gs), pkl)
                # NOTE(review): tf_config is not defined in this function - confirm its source.
                metrics.run(pkl, run_dir=submit_config.run_dir, num_gpus=submit_config.num_gpus, tf_config=tf_config)
            # Checkpoint of the session variables, on its own cadence (opt.model_save_freq).
            if cur_tick % w_snapshot_ticks == 0 or done:
                g.saver.save(g.sess, './{}/model_{}.ckpt'.format(
                    output_dir, (cur_nimg // 1000)),
                    write_meta_graph=False, write_state=False)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            # Time spent on the maintenance work above, excluding training time.
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Save the final result.
    misc.save_pkl((g.G, g.D, g.Gs), os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()
    ctx.close()

    # Persist the loss curve both as raw values and as a plot.
    loss_values = np.array(loss_values)
    np.save('./{}/loss_values.npy'.format(output_dir), loss_values)
    f, ax = plt.subplots(figsize=(10, 4))
    ax.plot(loss_values)
    f.savefig('./{}/loss_values.png'.format(output_dir))
def _upload_to_bucket(local_path, bucket_url='gs://stylegan_out'):
    """Copy *local_path* to *bucket_url* with ``gsutil cp`` on a best-effort basis.

    The command is passed to ``subprocess.run`` as an argument list rather than
    a shell string, so paths containing spaces or shell metacharacters are
    handled verbatim. The exit status is not checked, and a failure to launch
    ``gsutil`` at all is reported with a warning instead of aborting training.

    Returns the ``subprocess.CompletedProcess``, or ``None`` if the command
    could not be started.
    """
    try:
        return subprocess.run(['gsutil', 'cp', local_path, bucket_url])
    except OSError as err:
        print('Warning: could not run gsutil to upload "%s": %s' % (local_path, err))
        return None


def training_loop(
    submit_config,
    G_args={},                      # Options for generator network.
    D_args={},                      # Options for discriminator network.
    G_opt_args={},                  # Options for generator optimizer.
    D_opt_args={},                  # Options for discriminator optimizer.
    G_loss_args={},                 # Options for generator loss.
    D_loss_args={},                 # Options for discriminator loss.
    dataset_args={},                # Options for dataset.load_dataset().
    sched_args={},                  # Options for train.TrainingSchedule.
    grid_args={},                   # Options for train.setup_snapshot_image_grid().
    metric_arg_list=[],             # Options for MetricGroup.
    tf_config={},                   # Options for tflib.init_tf().
    G_smoothing_kimg=10.0,          # Half-life of the running average of generator weights.
    D_repeats=1,                    # How many times the discriminator is trained per G iteration.
    minibatch_repeats=4,            # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod=True,     # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=15000,               # Total length of the training, measured in thousands of real images.
    mirror_augment=False,           # Enable mirror augment?
    drange_net=[-1, 1],             # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=1,         # How often to export image snapshots?
    network_snapshot_ticks=10,      # How often to export network snapshots?
    save_tf_graph=False,            # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,   # Include weight histograms in the tfevents file?
    resume_run_id=None,             # Run ID or network pkl to resume training from, None = start from the pretrained pickle.
    resume_snapshot=None,           # Snapshot index to resume training from, None = autodetect.
    resume_kimg=10000.0,            # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0):               # Assumed wallclock time at the beginning. Affects reporting.
    """Run the GAN training loop, starting from an existing set of networks.

    The networks (G, D, Gs) are either loaded from the run identified by
    ``resume_run_id`` / ``resume_snapshot`` or, when ``resume_run_id`` is None,
    unpickled from a hard-coded download URL. Networks are never constructed
    from scratch in this function, so ``G_args`` and ``D_args`` are accepted
    for interface compatibility but not used.

    Each tick the loop prints a progress line, periodically writes image grids
    (also copied to a GCS bucket via ``gsutil``) and network pickles into
    ``submit_config.run_dir``, and evaluates the metrics in ``metric_arg_list``
    on every saved network pickle. ``network-final.pkl`` is written when the
    loop ends, whether by reaching ``total_kimg`` or by ``ctx.should_stop()``.

    NOTE: the dict/list defaults are shared between calls; they are only read,
    unpacked or passed through here, never mutated in this function.
    """
    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args)

    # Obtain networks: resume from a previous run, or start from the pretrained pickle.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            url = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ'
            with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f:
                # NOTE(review): pickle.load executes arbitrary code from the
                # downloaded file; this relies on the URL above being trusted.
                G, D, Gs = pickle.load(f)
            print('Loading pretrained FFHQ network')
    G.print_layers()
    D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        # Decay factor for the running average of generator weights; 0.0 disables smoothing.
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0

    G_opt = tflib.Optimizer(name='TrainG', learning_rate=lrate_in, **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', learning_rate=lrate_in, **D_opt_args)
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # GPU 0 uses the networks themselves; every other GPU gets a clone.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in),
            ]
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(reals, lod_in, mirror_augment, training_set.dynamic_range, drange_net)
            # Both losses are built under a control dependency on the lod assignments.
            with tf.name_scope('G_loss'), tf.control_dependencies(lod_assign_ops):
                G_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set,
                    minibatch_size=minibatch_split, **G_loss_args)
            with tf.name_scope('D_loss'), tf.control_dependencies(lod_assign_ops):
                D_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set,
                    minibatch_size=minibatch_split, reals=reals, labels=labels, **D_loss_args)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            # Fall back to a constant so the progress report still works.
            peak_gpu_mem_op = tf.constant(0)

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(G, training_set, **grid_args)
    # The schedule at the end of training determines the minibatch size used for the initial grid.
    sched = training_schedule(cur_nimg=total_kimg * 1000, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args)
    grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch // submit_config.num_gpus)

    print('Setting up run dir...')
    misc.save_image_grid(grid_reals, os.path.join(submit_config.run_dir, 'reals.png'), drange=training_set.dynamic_range, grid_size=grid_size)
    fakes_path = os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg)
    misc.save_image_grid(grid_fakes, fakes_path, drange=drange_net, grid_size=grid_size)
    _upload_to_bucket(fakes_path)
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus, sched.lod)
        if reset_opt_for_new_lod:
            # Reset whenever lod crosses an integer boundary in either direction.
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run([D_train_op, Gs_update_op], {lod_in: sched.lod, lrate_in: sched.D_lrate, minibatch_in: sched.minibatch})
                # Progress is counted in images seen by the discriminator.
                cur_nimg += sched.minibatch
            tflib.run([G_train_op], {lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_in: sched.minibatch})

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print('tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f' % (
                autosummary('Progress/tick', cur_tick),
                autosummary('Progress/kimg', cur_nimg / 1000.0),
                autosummary('Progress/lod', sched.lod),
                autosummary('Progress/minibatch', sched.minibatch),
                dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                autosummary('Timing/sec_per_tick', tick_time),
                autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                autosummary('Timing/maintenance_sec', maintenance_time),
                autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch // submit_config.num_gpus)
                fakes_path = os.path.join(submit_config.run_dir, 'fakes%06d.png' % (cur_nimg // 1000))
                misc.save_image_grid(grid_fakes, fakes_path, drange=drange_net, grid_size=grid_size)
                _upload_to_bucket(fakes_path)
            # The first tick always produces a network snapshot in addition to the periodic ones.
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(submit_config.run_dir, 'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl, run_dir=submit_config.run_dir, num_gpus=submit_config.num_gpus, tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    misc.save_pkl((G, D, Gs), os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()

    ctx.close()
def main():
    """Write example image grids that resample different parts of the latents.

    Loads the networks stored in 'SNI.pkl', draws one random latent vector per
    face, and saves five PNG grids to the current directory. Each grid shows
    the faces after one of the ``misc.randomize_*`` helpers has been applied:
    global codes, shared global codes, all local codes, and local codes
    restricted by two different 8x8 masks (saved as "mouth" and "hair").
    """
    # Initialize TensorFlow.
    tflib.init_tf()

    # The pickle holds (G, D, Gs). Only Gs, the long-term average of the
    # generator, is used; the instantaneous G / D snapshots are ignored.
    _G, _D, Gs = misc.load_pkl('SNI.pkl')

    # Print network details.
    Gs.print_layers()

    number_unique_faces = 2  # One row of the output figure per face.
    variations_per_face = 5  # One column of the output figure per variation.
    layout = (variations_per_face, number_unique_faces)

    # Pick latent vectors (unseeded, so every run gives different faces).
    latents = np.random.randn(number_unique_faces, Gs.input_shape[1])

    def export_grid(codes, png_name):
        # Run the generator on the given codes and save the result as one grid image.
        images = Gs.run(codes, None, truncation_psi=0.7)
        misc.save_image_grid(images, png_name, drange=[-1, 1], grid_size=layout)

    def region_mask(rows, cols):
        # 8x8x1 mask that is 1 inside the given row/column slices and 0 elsewhere.
        selected = np.zeros(shape=(8, 8, 1))
        selected[rows, cols] = 1
        return selected

    export_grid(
        misc.randomize_global_codes(latents, variations_per_face),
        'example_fakes_global.png')
    export_grid(
        misc.randomize_global_shared_codes(latents, variations_per_face),
        'example_fakes_shared.png')
    export_grid(
        misc.randomize_all_local_codes(latents, variations_per_face),
        'example_fakes_alllocal.png')

    mouth_mask = region_mask(slice(4, 8), slice(2, 6))
    export_grid(
        misc.randomize_specific_local_codes(latents, mouth_mask, variations_per_face),
        'example_fakes_mouth.png')

    hair_mask = region_mask(slice(0, 3), slice(1, 7))
    export_grid(
        misc.randomize_specific_local_codes(latents, hair_mask, variations_per_face),
        'example_fakes_hair.png')