def main():
    """Render the paper-style figures from the most recent network snapshot.

    Loads ``(G, D, Gs)`` from the pickle returned by ``misc.locate_latest_pkl()``
    and writes figures 02, 03, 04, 05 and 08 as PNG files into
    ``<config.result_dir>/figure`` (created if missing). Only ``Gs`` is used.
    """
    tflib.init_tf()
    network_pkl, _ = misc.locate_latest_pkl()
    # A context manager closes the pickle file deterministically (the original
    # leaked the handle). NOTE: pickle.load can execute arbitrary code; only load
    # snapshots you produced yourself.
    with open(network_pkl, "rb") as f:
        _G, _D, Gs = pickle.load(f)

    destination = os.path.join(config.result_dir, 'figure')
    os.makedirs(destination, exist_ok=True)

    draw_uncurated_result_figure(
        os.path.join(destination, 'figure02-uncurated-ffhq.png'), Gs,
        cx=0, cy=0, cw=512, ch=512, rows=3, lods=[0, 1, 2, 2, 3, 3], seed=5)
    draw_style_mixing_figure(
        os.path.join(destination, 'figure03-style-mixing.png'), Gs, w=512, h=512,
        src_seeds=[639, 701, 687, 615, 2268],
        dst_seeds=[888, 829, 1898, 1733, 1614, 845],
        # One layer range per destination seed (6 seeds, 6 ranges).
        style_ranges=[range(0, 4)] * 3 + [range(4, 8)] * 2 + [range(8, 16)])
    draw_noise_detail_figure(
        os.path.join(destination, 'figure04-noise-detail.png'), Gs,
        w=512, h=512, num_samples=100, seeds=[1157, 1012])
    draw_noise_components_figure(
        os.path.join(destination, 'figure05-noise-components.png'), Gs,
        w=512, h=512, seeds=[1967, 1555],
        noise_ranges=[range(0, 18), range(0, 0), range(8, 18), range(0, 8)],
        flips=[1])
    draw_truncation_trick_figure(
        os.path.join(destination, 'figure08-truncation-trick.png'), Gs,
        w=512, h=512,
        seeds=[91, 388, 389, 390, 391, 392, 393, 394, 395, 396],
        psis=[1, 0.7, 0.5, 0.25, 0, -0.25, -0.5, -1])
def main():
    """Generate 1000 random sample images from the latest network snapshot.

    Latents are drawn from an OS-entropy-seeded ``RandomState`` (so every run
    differs). Images use truncation psi 0.6 and randomized noise. Each one is
    saved as ``<config.result_dir>/example/example-<i>.png``.
    """
    tflib.init_tf()
    network_pkl, _ = misc.locate_latest_pkl()
    # Close the pickle file deterministically instead of leaking the handle.
    with open(network_pkl, "rb") as f:
        _G, _D, Gs = pickle.load(f)
    Gs.print_layers()

    # Loop invariants, hoisted: output dir, output transform and the RNG.
    out_dir = os.path.join(config.result_dir, 'example')
    os.makedirs(out_dir, exist_ok=True)
    fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    rnd = np.random.RandomState(None)

    for i in range(1000):
        latents = rnd.randn(1, Gs.input_shape[1])
        images = Gs.run(latents, None, truncation_psi=0.6, randomize_noise=True, output_transform=fmt)
        png_filename = os.path.join(out_dir, 'example-' + str(i) + '.png')
        PIL.Image.fromarray(images[0], 'RGB').save(png_filename)
def training_loop(
    G_args                  = {},       # Options for generator network.
    D_args                  = {},       # Options for discriminator network.
    G_opt_args              = {},       # Options for generator optimizer.
    D_opt_args              = {},       # Options for discriminator optimizer.
    G_loss_args             = {},       # Options for generator loss.
    D_loss_args             = {},       # Options for discriminator loss.
    dataset_args            = {},       # Options for dataset.load_dataset().
    sched_args              = {},       # Options for train.TrainingSchedule.
    grid_args               = {},       # Options for train.setup_snapshot_image_grid().
    setname                 = None,     # Model name; its last char and prefix are used in pickle names. None = generic names.
    tf_config               = {},       # Options for tflib.init_tf().
    G_smoothing_kimg        = 10.0,     # Half-life of the running average of generator weights.
    minibatch_repeats       = 4,        # Number of minibatches to run before adjusting training parameters.
    lazy_regularization     = True,     # Perform regularization as a separate training step?
    G_reg_interval          = 4,        # How often the perform regularization for G? Ignored if lazy_regularization=False.
    D_reg_interval          = 16,       # How often the perform regularization for D? Ignored if lazy_regularization=False.
    reset_opt_for_new_lod   = True,     # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg              = 25000,    # Total length of the training, measured in thousands of real images.
    mirror_augment          = False,    # Enable mirror augment?
    mirror_augment_v        = False,    # Enable mirror augment vertically?
    drange_net              = [-1,1],   # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks    = 50,       # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'.
    network_snapshot_ticks  = 50,       # How often to save network snapshots? None = only save 'networks-final.pkl'.
    save_tf_graph           = False,    # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms  = False,    # Include weight histograms in the tfevents file?
    resume_pkl              = 'latest', # Network pickle to resume training from, None = train from scratch.
    resume_kimg             = 0.0,      # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time             = 0.0,      # Assumed wallclock time at the beginning. Affects reporting.
    restore_partial_fn      = None,     # Filename of network for partial restore (used when resume_pkl='restore_partial').
    resume_with_new_nets    = False):   # Construct new networks according to G_args and D_args before resuming training?
    """Train a G/D network pair with optional lazy regularization and multi-GPU accumulation.

    ``resume_pkl`` selects how the networks are obtained:

    * ``None``: construct fresh networks from ``G_args`` / ``D_args``.
    * ``'latest'``: load the newest pickle located under
      ``dnnlib.submit_config.run_dir_root`` (``resume_kimg`` is taken from it).
    * ``'restore_partial'``: construct fresh networks and copy every trainable
      that matches in name and shape from ``restore_partial_fn``.
    * any other value: a pickle path (or a directory handed to
      ``misc.locate_latest_pkl`` when ``resume_kimg == 0``) that is loaded.

    Image grids (``fake-<kimg>.<ext>``) and pickles are written into the run
    dir via ``dnnlib.make_run_dir_path``. ``ext`` is png for 4-channel data and
    jpg otherwise.

    Raises:
        ValueError: ``resume_pkl == 'restore_partial'`` without ``restore_partial_fn``.
    """
    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(verbose=True, **dataset_args)
    # Custom resolution - for saved model name below.
    resolution = training_set.resolution
    if training_set.init_res != [4, 4]:
        init_res_str = '-%dx%d' % (training_set.init_res[0], training_set.init_res[1])
    else:
        init_res_str = ''
    ext = 'png' if training_set.shape[0] == 4 else 'jpg'
    print(' model base resolution', resolution)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(training_set, **grid_args)
    misc.save_image_grid(grid_reals, dnnlib.make_run_dir_path('_reals.%s' % ext), drange=training_set.dynamic_range, grid_size=grid_size)

    # Construct or load networks.
    with tf.device('/gpu:0'):
        G, D, Gs, resume_kimg = _tl_construct_or_load_networks(
            training_set, resolution, G_args, D_args, resume_pkl, resume_kimg,
            restore_partial_fn, resume_with_new_nets)

    # Generate initial image snapshot.
    sched = training_schedule(cur_nimg=total_kimg*1000, training_set=training_set, **sched_args)
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu)
    misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes_init.%s' % ext), drange=drange_net, grid_size=grid_size)

    # Setup training inputs.
    print(' Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32, name='minibatch_size_in', shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32, name='minibatch_gpu_in', shape=[])
        # Number of per-GPU rounds that are accumulated into one optimizer step.
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in * num_gpus)
        Gs_beta = 0.5 ** tf.div(tf.cast(minibatch_size_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0

    # Setup optimizers. Copies keep the caller's dicts untouched.
    G_opt_args = dict(G_opt_args)
    D_opt_args = dict(D_opt_args)
    for args, reg_interval in [(G_opt_args, G_reg_interval), (D_opt_args, D_reg_interval)]:
        args['minibatch_multiplier'] = minibatch_multiplier
        args['learning_rate'] = lrate_in
        if lazy_regularization:
            # Regularization runs as an extra step every reg_interval minibatches;
            # scale lr and Adam betas so the combined effect matches non-lazy training.
            mb_ratio = reg_interval / (reg_interval + 1)
            args['learning_rate'] *= mb_ratio
            if 'beta1' in args: args['beta1'] **= mb_ratio
            if 'beta2' in args: args['beta2'] **= mb_ratio
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', **D_opt_args)
    G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args)
    D_reg_opt = tflib.Optimizer(name='RegD', share=D_opt, **D_opt_args)

    # Build training graph for each GPU.
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):

            # Create GPU-specific shadow copies of G and D.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                sched = training_schedule(cur_nimg=int(resume_kimg*1000), training_set=training_set, **sched_args)
                reals_var = tf.Variable(name='reals', trainable=False, initial_value=tf.zeros([sched.minibatch_gpu] + training_set.shape))
                labels_var = tf.Variable(name='labels', trainable=False, initial_value=tf.zeros([sched.minibatch_gpu, training_set.label_size]))
                reals_write, labels_write = training_set.get_minibatch_tf()
                reals_write, labels_write = process_reals(reals_write, labels_write, lod_in, mirror_augment, mirror_augment_v, training_set.dynamic_range, drange_net)
                reals_write = tf.concat([reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                labels_write = tf.concat([labels_write, labels_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_var, reals_write)]
                data_fetch_ops += [tf.assign(labels_var, labels_write)]
                reals_read = reals_var[:minibatch_gpu_in]
                labels_read = labels_var[:minibatch_gpu_in]

            # Evaluate loss functions.
            lod_assign_ops = []
            if 'lod' in G_gpu.vars: lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)]
            if 'lod' in D_gpu.vars: lod_assign_ops += [tf.assign(D_gpu.vars['lod'], lod_in)]
            with tf.control_dependencies(lod_assign_ops):
                with tf.name_scope('G_loss'):
                    G_loss, G_reg = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, **G_loss_args)
                with tf.name_scope('D_loss'):
                    D_loss, D_reg = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, reals=reals_read, labels=labels_read, **D_loss_args)

            # Register gradients.
            if not lazy_regularization:
                if G_reg is not None: G_loss += G_reg
                if D_reg is not None: D_loss += D_reg
            else:
                if G_reg is not None: G_reg_opt.register_gradients(tf.reduce_mean(G_reg * G_reg_interval), G_gpu.trainables)
                if D_reg is not None: D_reg_opt.register_gradients(tf.reduce_mean(D_reg * D_reg_interval), D_gpu.trainables)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()
    G_reg_op = G_reg_opt.apply_updates(allow_no_op=True)
    D_reg_op = D_reg_opt.apply_updates(allow_no_op=True)
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms(); D.setup_weight_histograms()

    print(' Training for %d kimg (%d left) \n' % (total_kimg, total_kimg-resume_kimg))
    dnnlib.RunContext.get().update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1  # -1 forces a maintenance tick right after the first minibatches.
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop(): break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state(); D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        feed_dict = {lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_size_in: sched.minibatch_size, minibatch_gpu_in: sched.minibatch_gpu}
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size, sched.minibatch_gpu * num_gpus)
            run_G_reg = (lazy_regularization and running_mb_counter % G_reg_interval == 0)
            run_D_reg = (lazy_regularization and running_mb_counter % D_reg_interval == 0)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                tflib.run([G_train_op, data_fetch_op], feed_dict)
                if run_G_reg:
                    tflib.run(G_reg_op, feed_dict)
                tflib.run([D_train_op, Gs_update_op], feed_dict)
                if run_D_reg:
                    tflib.run(D_reg_op, feed_dict)

            # Slow path with gradient accumulation.
            else:
                for _round in rounds:
                    tflib.run(G_train_op, feed_dict)
                if run_G_reg:
                    for _round in rounds:
                        tflib.run(G_reg_op, feed_dict)
                tflib.run(Gs_update_op, feed_dict)
                for _round in rounds:
                    tflib.run(data_fetch_op, feed_dict)
                    tflib.run(D_train_op, feed_dict)
                if run_D_reg:
                    for _round in rounds:
                        tflib.run(D_reg_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start() + resume_time
            # ETA is only reported once the final resolution (lod 0) is reached.
            if sched.lod == 0:
                left_kimg = total_kimg - cur_nimg / 1000
                left_sec = left_kimg * tick_time / tick_kimg
                finaltime = time.asctime(time.localtime(cur_time + left_sec))
                msg_final = '%ss left till %s ' % (shortime(left_sec), finaltime[11:16])
            else:
                msg_final = ''

            # Report progress.
            print('tick %-4d kimg %-6.1f time %-8s %s min/tick %-6.3g sec/kimg %-7.3g gpumem %-4.1f lr %.2g ' % (
                autosummary('Progress/tick', cur_tick),
                autosummary('Progress/kimg', cur_nimg / 1000.0),
                dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                msg_final,
                autosummary('Timing/min_per_tick', tick_time / 60),
                autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30),
                sched.G_lrate))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (cur_tick % image_snapshot_ticks == 0 or done):
                grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu)
                misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fake-%04d.%s' % (cur_nimg // 1000, ext)), drange=drange_net, grid_size=grid_size)
            if network_snapshot_ticks is not None and (cur_tick % network_snapshot_ticks == 0 or done):
                full_name, gs_name = _tl_snapshot_pkl_names(resolution, setname, init_res_str, '%04d' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path(full_name))
                misc.save_pkl(Gs, dnnlib.make_run_dir_path(gs_name))

            # Update summaries and RunContext.
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)

    # Save final snapshot.
    full_name, gs_name = _tl_snapshot_pkl_names(resolution, setname, init_res_str, 'final')
    misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path(full_name))
    misc.save_pkl(Gs, dnnlib.make_run_dir_path(gs_name))

    # All done.
    summary_log.close()
    training_set.close()


def _tl_construct_or_load_networks(training_set, resolution, G_args, D_args, resume_pkl, resume_kimg,
                                   restore_partial_fn, resume_with_new_nets):
    """Create, partially restore, or resume the (G, D, Gs) networks.

    Must be called inside the caller's ``tf.device`` scope.
    Returns ``(G, D, Gs, resume_kimg)``; ``resume_kimg`` may be replaced by the
    value reported by ``misc.locate_latest_pkl``.
    """
    def build_networks():
        # Fresh networks shaped by the training set and the user's G/D options.
        G = tflib.Network('G', num_channels=training_set.shape[0], resolution=resolution, label_size=training_set.label_size, **G_args)
        D = tflib.Network('D', num_channels=training_set.shape[0], resolution=resolution, label_size=training_set.label_size, **D_args)
        return G, D, G.clone('Gs')

    # Train from scratch.
    if resume_pkl is None:
        print(' Constructing networks...')
        G, D, Gs = build_networks()
        return G, D, Gs, resume_kimg

    # Fresh networks, seeded with every trainable matching in name and shape.
    if resume_pkl == 'restore_partial':
        if restore_partial_fn is None:
            raise ValueError("resume_pkl='restore_partial' requires restore_partial_fn")
        print(' Restore partially...')
        G, D, Gs = build_networks()
        with open(restore_partial_fn, 'rb') as f:
            G_partial, D_partial, Gs_partial = pickle.load(f)
        G.copy_compatible_trainables_from(G_partial)
        D.copy_compatible_trainables_from(D_partial)
        Gs.copy_compatible_trainables_from(Gs_partial)
        return G, D, Gs, resume_kimg

    # Full resume. New networks (if requested) are built before unpickling, as before.
    if resume_with_new_nets:
        print(' Constructing networks...')
        G, D, Gs = build_networks()
    if resume_pkl == 'latest':
        resume_pkl, resume_kimg = misc.locate_latest_pkl(dnnlib.submit_config.run_dir_root)
    elif resume_kimg == 0:
        resume_pkl, resume_kimg = misc.locate_latest_pkl(resume_pkl)
    print(' Loading networks from "%s", kimg %.3g' % (resume_pkl, resume_kimg))
    rG, rD, rGs = misc.load_pkl(resume_pkl)
    if resume_with_new_nets:
        G.copy_vars_from(rG)
        D.copy_vars_from(rD)
        Gs.copy_vars_from(rGs)
    else:
        G, D, Gs = rG, rD, rGs
    return G, D, Gs, resume_kimg


def _tl_snapshot_pkl_names(resolution, setname, init_res_str, tag):
    """Return (full (G, D, Gs) pickle name, Gs-only pickle name) for snapshot `tag`.

    With a `setname`, the original naming scheme is kept. Without one (the
    default), generic names are used instead of crashing on ``None[-1]``.
    """
    if setname:
        return ('snapshot-%d-%s%s-%s.pkl' % (resolution, setname[-1], init_res_str, tag),
                '%s-%d-%s%s-%s.pkl' % (setname[:-1], resolution, setname[-1], init_res_str, tag))
    return ('snapshot-%d%s-%s.pkl' % (resolution, init_res_str, tag),
            'Gs-%d%s-%s.pkl' % (resolution, init_res_str, tag))
def main():
    """Render three latent-space videos from the most recent network snapshot.

    Writes into ``config.result_dir``:

    * ``random_grid_404.mp4``: a 2x2 grid of independent, temporally smoothed
      latent walks (truncation psi 0.7, fixed noise).
    * ``interpolate.mp4``: the destination image(s) on the top row and a
      smoothed source walk bottom-left. Bottom-right shows the destination
      dlatents with the selected layer(s) taken from the current source frame.
    * ``fine_503.mp4``: a fixed destination latent whose dlatent layers 6..15
      follow a smoothed source walk.

    To render a different model, change how ``network_pkl`` is chosen below.
    """
    tflib.init_tf()
    network_pkl, _ = misc.locate_latest_pkl()
    # Close the file deterministically. NOTE: only unpickle trusted snapshots.
    with open(network_pkl, "rb") as f:
        _G, _D, Gs = pickle.load(f)
    # _G / _D are instantaneous snapshots; Gs is the long-term average generator.

    _vid_random_grid(Gs, config.result_dir)
    _vid_style_mix(Gs, config.result_dir)
    _vid_fine_style(Gs, config.result_dir)


def _vid_frame_index(t, mp4_fps, num_frames):
    """Map moviepy time `t` (seconds) to a frame index clamped to [0, num_frames - 1]."""
    return int(np.clip(np.round(t * mp4_fps), 0, num_frames - 1))


def _vid_smooth_latents(random_state, shape, sigma_frames):
    """Draw float32 Gaussian latents of `shape` and smooth them along axis 0 only.

    Axis 0 is time (frames). A sigma of 0 on every other axis keeps latent
    components independent. The result is rescaled to unit RMS.
    """
    import scipy.ndimage  # `import scipy` alone does not guarantee the submodule is loaded.
    latents = random_state.randn(*shape).astype(np.float32)
    sigma = [sigma_frames] + [0] * (latents.ndim - 1)
    latents = scipy.ndimage.gaussian_filter(latents, sigma, mode='wrap')
    latents /= np.sqrt(np.mean(np.square(latents)))
    return latents


def _vid_create_image_grid(images, grid_size=None):
    """Tile an NHWC (or NHW) image batch into one HWC grid image.

    `grid_size` is ``(grid_w, grid_h)``; when None a near-square layout is used.
    """
    assert images.ndim == 3 or images.ndim == 4
    if images.ndim == 3:
        images = images[..., np.newaxis]  # NHW -> NHW1 so the unpack below works.
    num, img_h, img_w, channels = images.shape
    if grid_size is not None:
        grid_w, grid_h = tuple(grid_size)
    else:
        grid_w = max(int(np.ceil(np.sqrt(num))), 1)
        grid_h = max((num - 1) // grid_w + 1, 1)
    grid = np.zeros([grid_h * img_h, grid_w * img_w, channels], dtype=images.dtype)
    for idx in range(num):
        x = (idx % grid_w) * img_w
        y = (idx // grid_w) * img_h
        grid[y : y + img_h, x : x + img_w] = images[idx]
    return grid


def _vid_write_video(make_frame, mp4_file, duration_sec, mp4_fps, mp4_codec='libx264', mp4_bitrate='5M'):
    """Encode the frames produced by ``make_frame(t)`` into an MP4 via moviepy."""
    import moviepy.editor
    video_clip = moviepy.editor.VideoClip(make_frame, duration=duration_sec)
    video_clip.write_videofile(mp4_file, fps=mp4_fps, codec=mp4_codec, bitrate=mp4_bitrate)


def _vid_random_grid(Gs, result_dir):
    """Write random_grid_<seed>.mp4: a grid of independent smoothed latent walks."""
    import scipy.ndimage
    grid_size = [2, 2]
    image_zoom = 1
    duration_sec = 60.0
    smoothing_sec = 1.0
    mp4_fps = 20
    random_seed = 404
    num_frames = int(np.rint(duration_sec * mp4_fps))
    random_state = np.random.RandomState(random_seed)

    shape = [num_frames, np.prod(grid_size)] + Gs.input_shape[1:]  # [frame, image, component...]
    all_latents = _vid_smooth_latents(random_state, shape, smoothing_sec * mp4_fps)
    fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)

    def make_frame(t):
        latents = all_latents[_vid_frame_index(t, mp4_fps, num_frames)]
        images = Gs.run(latents, None, truncation_psi=0.7, randomize_noise=False, output_transform=fmt)
        grid = _vid_create_image_grid(images, grid_size)
        if image_zoom > 1:
            grid = scipy.ndimage.zoom(grid, [image_zoom, image_zoom, 1], order=0)
        if grid.shape[2] == 1:
            grid = grid.repeat(3, 2)  # grayscale => RGB
        return grid

    _vid_write_video(make_frame, os.path.join(result_dir, 'random_grid_%s.mp4' % random_seed), duration_sec, mp4_fps)


def _vid_style_mix(Gs, result_dir):
    """Write interpolate.mp4: source-walk layers mixed into fixed destination image(s)."""
    duration_sec = 60.0
    smoothing_sec = 1.0
    mp4_fps = 20
    num_frames = int(np.rint(duration_sec * mp4_fps))
    random_state = np.random.RandomState(500)
    w = 512
    h = 512
    dst_seeds = [700]
    # NOTE(review): style_ranges[0] is the integer 0, so only dlatent layer 0 is
    # copied from the source walk; preserved from the original - confirm intent.
    style_ranges = ([0] * 7 + [range(8, 16)]) * len(dst_seeds)
    fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    synthesis_kwargs = dict(output_transform=fmt, truncation_psi=0.7, minibatch_size=8)

    shape = [num_frames] + Gs.input_shape[1:]  # [frame, component]
    src_latents = _vid_smooth_latents(random_state, shape, smoothing_sec * mp4_fps)
    # np.stack requires a sequence; generators are deprecated/rejected by NumPy.
    dst_latents = np.stack([np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in dst_seeds])
    src_dlatents = Gs.components.mapping.run(src_latents, None)  # [seed, layer, component]
    dst_dlatents = Gs.components.mapping.run(dst_latents, None)  # [seed, layer, component]
    src_images = Gs.components.synthesis.run(src_dlatents, randomize_noise=False, **synthesis_kwargs)
    dst_images = Gs.components.synthesis.run(dst_dlatents, randomize_noise=False, **synthesis_kwargs)

    # Canvas: top row = destination images (offset by one column), bottom row =
    # source frame at column 0 followed by the mixed images. PIL boxes are (x, y).
    canvas = PIL.Image.new('RGB', (w * (len(dst_seeds) + 1), h * 2), 'white')
    for col, dst_image in enumerate(dst_images):
        canvas.paste(PIL.Image.fromarray(dst_image, 'RGB'), ((col + 1) * w, 0))

    def make_frame(t):
        frame_idx = _vid_frame_index(t, mp4_fps, num_frames)
        canvas.paste(PIL.Image.fromarray(src_images[frame_idx], 'RGB'), (0, h))
        for col in range(len(dst_images)):
            col_dlatents = np.stack([dst_dlatents[col]])
            col_dlatents[:, style_ranges[col]] = src_dlatents[frame_idx, style_ranges[col]]
            col_images = Gs.components.synthesis.run(col_dlatents, randomize_noise=False, **synthesis_kwargs)
            for row, image in enumerate(col_images):
                canvas.paste(PIL.Image.fromarray(image, 'RGB'), ((col + 1) * w, (row + 1) * h))
        return np.array(canvas)

    _vid_write_video(make_frame, os.path.join(result_dir, 'interpolate.mp4'), duration_sec, mp4_fps)


def _vid_fine_style(Gs, result_dir):
    """Write fine_<seed>.mp4: one destination latent with layers 6..15 from a walk."""
    duration_sec = 60.0
    smoothing_sec = 1.0
    mp4_fps = 20
    num_frames = int(np.rint(duration_sec * mp4_fps))
    random_seed = 503
    random_state = np.random.RandomState(random_seed)
    style_range = range(6, 16)
    fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    synthesis_kwargs = dict(output_transform=fmt, truncation_psi=0.7, minibatch_size=8)

    shape = [num_frames] + Gs.input_shape[1:]  # [frame, component]
    src_latents = _vid_smooth_latents(random_state, shape, smoothing_sec * mp4_fps)
    # Drawn from the same RandomState after the source walk, as in the original.
    dst_latents = np.stack([random_state.randn(Gs.input_shape[1])])
    src_dlatents = Gs.components.mapping.run(src_latents, None)  # [seed, layer, component]
    dst_dlatents = Gs.components.mapping.run(dst_latents, None)  # [seed, layer, component]

    def make_frame(t):
        frame_idx = _vid_frame_index(t, mp4_fps, num_frames)
        col_dlatents = np.stack([dst_dlatents[0]])
        col_dlatents[:, style_range] = src_dlatents[frame_idx, style_range]
        col_images = Gs.components.synthesis.run(col_dlatents, randomize_noise=False, **synthesis_kwargs)
        return col_images[0]

    _vid_write_video(make_frame, os.path.join(result_dir, 'fine_%s.mp4' % (random_seed)), duration_sec, mp4_fps)
def run(dataset, data_dir, result_dir, config_id, num_gpus, total_kimg, gamma, mirror_augment, metrics):
    """Assemble the training configuration for `config_id` and submit it locally.

    Args:
        dataset: tfrecord directory name; also appended to the run description.
        data_dir: Dataset root, forwarded as ``train.data_dir``.
        result_dir: Run root directory; also searched for a snapshot to resume from.
        config_id: Must be in ``_valid_configs``. Configs other than 'config-f'
            shrink the networks; 'config-e*' variants pick G/D architectures
            from substrings; 'config-a'..'config-d' enable progressive growing.
        num_gpus: One of 1, 2, 4, 8.
        total_kimg: Training length in thousands of real images.
        gamma: R1 gamma override, or None to keep the config's value.
        mirror_augment: Forwarded as ``train.mirror_augment``.
        metrics: Names looked up in ``metric_defaults``.
    """
    train = EasyDict(run_func_name='training.training_loop.training_loop')  # Options for training loop.
    G = EasyDict(func_name='training.networks_stylegan2.G_main')  # Options for generator network.
    D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2')  # Options for discriminator network.
    G_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)  # Options for generator optimizer.
    D_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)  # Options for discriminator optimizer.
    G_loss = EasyDict(func_name='training.loss.G_logistic_ns_pathreg')  # Options for generator loss.
    D_loss = EasyDict(func_name='training.loss.D_logistic_r1')  # Options for discriminator loss.
    sched = EasyDict()  # Options for TrainingSchedule.
    grid = EasyDict(size='8k', layout='random')  # Options for setup_snapshot_image_grid().
    sc = dnnlib.SubmitConfig()  # Options for dnnlib.submit_run().
    tf_config = {'rnd.np_random_seed': 1000}  # Options for tflib.init_tf().

    # Resume from the newest snapshot in result_dir if there is one. Narrowed from
    # a bare `except:` so KeyboardInterrupt / SystemExit are not swallowed.
    try:
        pkl, kimg = misc.locate_latest_pkl(result_dir)
    except Exception:
        print('Couldn\'t find valid snapshot, starting over')
    else:
        train.resume_pkl = pkl
        train.resume_kimg = kimg

    train.data_dir = data_dir
    train.total_kimg = total_kimg
    train.mirror_augment = mirror_augment
    train.image_snapshot_ticks = 1
    train.network_snapshot_ticks = 2
    sched.G_lrate_base = sched.D_lrate_base = 0.002
    sched.minibatch_size_base = 32
    sched.minibatch_gpu_base = 4
    D_loss.gamma = 10
    metrics = [metric_defaults[x] for x in metrics]
    desc = 'stylegan2'

    desc += '-' + dataset
    dataset_args = EasyDict(tfrecord_dir=dataset)

    assert num_gpus in [1, 2, 4, 8]
    sc.num_gpus = num_gpus
    desc += '-%dgpu' % num_gpus

    assert config_id in _valid_configs
    desc += '-' + config_id

    # Configs A-E: Shrink networks to match original StyleGAN.
    if config_id != 'config-f':
        G.fmap_base = D.fmap_base = 8 << 10

    # Config E: Set gamma to 100 and override G & D architecture.
    if config_id.startswith('config-e'):
        D_loss.gamma = 100
        if 'Gorig' in config_id: G.architecture = 'orig'
        if 'Gskip' in config_id: G.architecture = 'skip'  # (default)
        if 'Gresnet' in config_id: G.architecture = 'resnet'
        if 'Dorig' in config_id: D.architecture = 'orig'
        if 'Dskip' in config_id: D.architecture = 'skip'
        if 'Dresnet' in config_id: D.architecture = 'resnet'  # (default)

    # Configs A-D: Enable progressive growing and switch to networks that support it.
    if config_id in ['config-a', 'config-b', 'config-c', 'config-d']:
        sched.lod_initial_resolution = 8
        sched.G_lrate_base = sched.D_lrate_base = 0.001
        sched.G_lrate_dict = sched.D_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
        sched.minibatch_size_base = 32  # (default)
        sched.minibatch_size_dict = {8: 256, 16: 128, 32: 64, 64: 32}
        sched.minibatch_gpu_base = 4  # (default)
        sched.minibatch_gpu_dict = {8: 32, 16: 16, 32: 8, 64: 4}
        G.synthesis_func = 'G_synthesis_stylegan_revised'
        D.func_name = 'training.networks_stylegan2.D_stylegan'

    # Configs A-C: Disable path length regularization.
    if config_id in ['config-a', 'config-b', 'config-c']:
        G_loss = EasyDict(func_name='training.loss.G_logistic_ns')

    # Configs A-B: Disable lazy regularization.
    if config_id in ['config-a', 'config-b']:
        train.lazy_regularization = False

    # Config A: Switch to original StyleGAN networks.
    if config_id == 'config-a':
        G = EasyDict(func_name='training.networks_stylegan.G_style')
        D = EasyDict(func_name='training.networks_stylegan.D_basic')

    if gamma is not None:
        D_loss.gamma = gamma

    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(G_args=G, D_args=D, G_opt_args=G_opt, D_opt_args=D_opt, G_loss_args=G_loss, D_loss_args=D_loss)
    kwargs.update(dataset_args=dataset_args, sched_args=sched, grid_args=grid, metric_arg_list=metrics, tf_config=tf_config)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = result_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
def training_loop(
        submit_config,
        G_args={},  # Options for generator network.
        D_args={},  # Options for discriminator network.
        G_opt_args={},  # Options for generator optimizer.
        D_opt_args={},  # Options for discriminator optimizer.
        G_loss_args={},  # Options for generator loss.
        D_loss_args={},  # Options for discriminator loss.
        data_root_dir=None,  # Forwarded to dataset.load_dataset() as data_dir.
        dataset_args={},  # Options for dataset.load_dataset().
        sched_args={},  # Options for train.TrainingSchedule.
        grid_args={},  # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],  # Options for MetricGroup (metric evaluation is commented out below).
        tf_config={},  # Options for tflib.init_tf().
        G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
        D_repeats=1,  # How many times the discriminator is trained per G iteration.
        minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
        reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg=15000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=False,  # Enable mirror augment?
        drange_net=[ -1, 1 ],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=1,  # How often to export image snapshots? Must be an int (used with %).
        network_snapshot_ticks=10,  # How often to export network snapshots? Must be an int (used with %).
        save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
        resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
        resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0 ):  # Assumed wallclock time at the beginning. Affects reporting.
    """Progressive-growing GAN training loop (StyleGAN-style), driven by a RunContext.

    Builds one loss/gradient graph per GPU (``submit_config.num_gpus``). Each
    GPU processes ``minibatch // num_gpus`` samples. Training then alternates
    ``D_repeats`` discriminator steps (each also updating the Gs moving
    average) with one generator step.

    Resuming: ``resume_run_id == 'latest'`` loads the pickle returned by
    ``misc.locate_latest_pkl()`` and takes ``resume_kimg`` from it. Any other
    non-None value is resolved with ``misc.locate_network_pkl``.

    Writes ``reals.png``, ``fakes<kimg>.png``, ``network-snapshot-<kimg>.pkl``
    and ``network-final.pkl`` into ``submit_config.run_dir``.
    """
    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=data_root_dir, verbose=True, **dataset_args)

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            if resume_run_id == 'latest':
                # Also overrides resume_kimg with the value reported for the located pickle.
                network_pkl, resume_kimg = misc.locate_latest_pkl()
            else:
                network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...')
            G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args)
            D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args)
            Gs = G.clone('Gs')
    G.print_layers()
    D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        # Per-GPU share of the global minibatch.
        minibatch_split = minibatch_in // submit_config.num_gpus
        # Gs EMA decay derived from the half-life G_smoothing_kimg; 0.0 disables smoothing.
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0

    G_opt = tflib.Optimizer(name='TrainG', learning_rate=lrate_in, **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', learning_rate=lrate_in, **D_opt_args)
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # GPU 0 uses the master networks; other GPUs get shadow clones.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            # Feed the current level-of-detail into both networks before evaluating the losses.
            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in)
            ]
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(reals, lod_in, mirror_augment, training_set.dynamic_range, drange_net)
            with tf.name_scope('G_loss'), tf.control_dependencies(
                    lod_assign_ops):
                G_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    **G_loss_args)
            with tf.name_scope('D_loss'), tf.control_dependencies(
                    lod_assign_ops):
                D_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals,
                    labels=labels,
                    **D_loss_args)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            # Memory-stats op unavailable in this TF build: report 0.
            peak_gpu_mem_op = tf.constant(0)

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(
        G, training_set, **grid_args)
    # Schedule at the end of training, used only to pick a minibatch size for grid rendering.
    sched = training_schedule(cur_nimg=total_kimg * 1000, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args)
    grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch // submit_config.num_gpus)

    print('Setting up run dir...')
    misc.save_image_grid(grid_reals, os.path.join(submit_config.run_dir, 'reals.png'), drange=training_set.dynamic_range, grid_size=grid_size)
    misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg), drange=drange_net, grid_size=grid_size)
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    # metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus, sched.lod)
        if reset_opt_for_new_lod:
            # Reset optimizer state whenever lod crosses an integer boundary (a new layer fades in).
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch
                    })
                # Progress counts only images fed to D.
                cur_nimg += sched.minibatch
            tflib.run(
                [G_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch
                })

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch // submit_config.num_gpus)
                misc.save_image_grid(grid_fakes, os.path.join(
                    submit_config.run_dir, 'fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size)
            # Network pickles are also written on the very first tick (cur_tick == 1).
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(
                    submit_config.run_dir,
                    'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                # metrics.run(pkl, run_dir=submit_config.run_dir, num_gpus=submit_config.num_gpus, tf_config=tf_config)

            # Update summaries and RunContext.
            # metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    misc.save_pkl((G, D, Gs), os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()

    ctx.close()
def training_loop(
    run_dir='.',                # Output directory.
    G_args={},                  # Options for generator network.
    D_args={},                  # Options for discriminator network.
    G_opt_args={},              # Options for generator optimizer.
    D_opt_args={},              # Options for discriminator optimizer.
    loss_args={},               # Options for loss function.
    train_dataset_args={},      # Options for dataset to train with.
    metric_dataset_args={},     # Options for dataset to evaluate metrics against.
    augment_args={},            # Options for adaptive augmentations.
    metric_arg_list=[],         # Metrics to evaluate during training.
    num_gpus=1,                 # Number of GPUs to use.
    minibatch_size=32,          # Global minibatch size.
    minibatch_gpu=4,            # Number of samples processed at a time by one GPU.
    G_smoothing_kimg=10,        # Half-life of the exponential moving average (EMA) of generator weights.
    G_smoothing_rampup=None,    # EMA ramp-up coefficient.
    minibatch_repeats=4,        # Number of minibatches to run in the inner loop.
    lazy_regularization=True,   # Perform regularization as a separate training step?
    G_reg_interval=4,           # How often to perform regularization for G? Ignored if lazy_regularization=False.
    D_reg_interval=16,          # How often to perform regularization for D? Ignored if lazy_regularization=False.
    total_kimg=25000,           # Total length of the training, measured in thousands of real images.
    kimg_per_tick=4,            # Progress snapshot interval.
    image_snapshot_ticks=50,    # How often to save image snapshots? None = only save 'reals.jpg' and 'fakes_init.jpg'.
    network_snapshot_ticks=50,  # How often to save network snapshots? None = disable.
    resume_pkl=None,            # Network pickle to resume training from. 'latest' = newest snapshot found by misc.locate_latest_pkl.
    abort_fn=None,              # Callback function for determining whether to abort training.
    progress_fn=None,           # Callback function for updating training progress.
):
    """Train generator G and discriminator D with TensorFlow, keeping an EMA copy Gs.

    The training set is loaded from ``train_dataset_args`` and the networks are built from
    ``G_args`` / ``D_args``. They are optionally resumed from ``resume_pkl``.
    ``'latest'`` asks ``misc.locate_latest_pkl`` for the newest snapshot under
    ``run_dir/..``. If it finds none, training starts from scratch. The function then
    alternates G and D updates (plus lazy regularization when enabled) until
    ``total_kimg`` thousand images have been processed or ``abort_fn()`` returns True.

    Once per tick it prints a status line, saves summaries, and periodically writes
    ``fakes*.jpg`` image grids and ``network-snapshot-*.pkl`` pickles to ``run_dir``.
    When metrics are configured, it evaluates them on each saved snapshot.

    Returns:
        None. All outputs are written to ``run_dir``.
    """
    # The global minibatch must split evenly into per-GPU accumulation rounds.
    assert minibatch_size % (num_gpus * minibatch_gpu) == 0
    start_time = time.time()

    print('Loading training set...')
    training_set = dataset.load_dataset(**train_dataset_args)
    print('Image shape:', np.int32(training_set.shape).tolist())
    print('Label shape:', [training_set.label_size])
    print()

    print('Constructing networks...')
    resume_kimg = 0
    if resume_pkl == 'latest':
        resume_pkl, resume_kimg = misc.locate_latest_pkl(f'{run_dir}/..')
        # NOTE(review): assumes locate_latest_pkl reports "nothing found" as a None path.
        # Fall back to training from scratch instead of crashing in open_url(None).
        if resume_pkl is None:
            print('No previous network snapshot found; training from scratch.')
            resume_kimg = 0
    with tf.device('/gpu:0'):
        G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1],
                          label_size=training_set.label_size, **G_args)
        D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1],
                          label_size=training_set.label_size, **D_args)
        Gs = G.clone('Gs')
        if resume_pkl is not None:
            print(f'Resuming from "{resume_pkl}"')
            # pickle.load executes arbitrary code: only resume from trusted snapshots.
            with dnnlib.util.open_url(resume_pkl) as f:
                rG, rD, rGs = pickle.load(f)
            G.copy_vars_from(rG)
            D.copy_vars_from(rD)
            Gs.copy_vars_from(rGs)
    G.print_layers()
    D.print_layers()

    print('Exporting sample images...')
    grid_size, grid_reals, grid_labels = setup_snapshot_image_grid(training_set)
    save_image_grid(grid_reals, os.path.join(run_dir, 'reals.jpg'), drange=[0, 255], grid_size=grid_size)
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=minibatch_gpu)
    save_image_grid(grid_fakes, os.path.join(run_dir, 'fakes_init.jpg'), drange=[-1, 1], grid_size=grid_size)

    print(f'Replicating networks across {num_gpus} GPUs...')
    G_gpus = [G]
    D_gpus = [D]
    for gpu in range(1, num_gpus):
        with tf.device(f'/gpu:{gpu}'):
            G_gpus.append(G.clone(f'{G.name}_gpu{gpu}'))
            D_gpus.append(D.clone(f'{D.name}_gpu{gpu}'))

    print('Initializing augmentations...')
    aug = None
    if augment_args.get('class_name', None) is not None:
        aug = dnnlib.util.construct_class_by_name(**augment_args)
        aug.init_validation_set(D_gpus=D_gpus, training_set=training_set)

    print('Setting up optimizers...')
    # Copy so the caller's dicts are never mutated by the adjustments below.
    G_opt_args = dict(G_opt_args)
    D_opt_args = dict(D_opt_args)
    for args, reg_interval in [(G_opt_args, G_reg_interval), (D_opt_args, D_reg_interval)]:
        args['minibatch_multiplier'] = minibatch_size // num_gpus // minibatch_gpu
        if lazy_regularization:
            # Rescale lr/betas to compensate for the extra regularization steps.
            mb_ratio = reg_interval / (reg_interval + 1)
            args['learning_rate'] *= mb_ratio
            if 'beta1' in args:
                args['beta1'] **= mb_ratio
            if 'beta2' in args:
                args['beta2'] **= mb_ratio
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', **D_opt_args)
    G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args)
    D_reg_opt = tflib.Optimizer(name='RegD', share=D_opt, **D_opt_args)

    print('Constructing training graph...')
    data_fetch_ops = []
    training_set.configure(minibatch_gpu)
    for gpu, (G_gpu, D_gpu) in enumerate(zip(G_gpus, D_gpus)):
        with tf.name_scope(f'Train_gpu{gpu}'), tf.device(f'/gpu:{gpu}'):

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                real_images_var = tf.Variable(name='images', trainable=False,
                                              initial_value=tf.zeros([minibatch_gpu] + training_set.shape))
                real_labels_var = tf.Variable(name='labels', trainable=False,
                                              initial_value=tf.zeros([minibatch_gpu, training_set.label_size]))
                real_images_write, real_labels_write = training_set.get_minibatch_tf()
                real_images_write = tflib.convert_images_from_uint8(real_images_write)
                data_fetch_ops += [tf.assign(real_images_var, real_images_write)]
                data_fetch_ops += [tf.assign(real_labels_var, real_labels_write)]

            # Evaluate loss function and register gradients.
            fake_labels = training_set.get_random_labels_tf(minibatch_gpu)
            terms = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, aug=aug, fake_labels=fake_labels,
                                                  real_images=real_images_var, real_labels=real_labels_var,
                                                  **loss_args)
            if lazy_regularization:
                # Regularizers run every reg_interval steps, so scale them up accordingly.
                if terms.G_reg is not None:
                    G_reg_opt.register_gradients(tf.reduce_mean(terms.G_reg * G_reg_interval), G_gpu.trainables)
                if terms.D_reg is not None:
                    D_reg_opt.register_gradients(tf.reduce_mean(terms.D_reg * D_reg_interval), D_gpu.trainables)
            else:
                if terms.G_reg is not None:
                    terms.G_loss += terms.G_reg
                if terms.D_reg is not None:
                    terms.D_loss += terms.D_reg
            G_opt.register_gradients(tf.reduce_mean(terms.G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(terms.D_loss), D_gpu.trainables)

    print('Finalizing training ops...')
    data_fetch_op = tf.group(*data_fetch_ops)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()
    G_reg_op = G_reg_opt.apply_updates(allow_no_op=True)
    D_reg_op = D_reg_opt.apply_updates(allow_no_op=True)
    Gs_beta_in = tf.placeholder(tf.float32, name='Gs_beta_in', shape=[])
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta_in)
    Gs_epochs = tf.placeholder(tf.float32, name='Gs_epochs', shape=[])
    Gs_epochs_op = Gs.update_epochs(Gs_epochs)
    tflib.init_uninitialized_vars()
    with tf.device('/gpu:0'):
        peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()

    print('Initializing metrics...')
    summary_log = tf.summary.FileWriter(run_dir)
    metrics = []
    for args in metric_arg_list:
        metric = dnnlib.util.construct_class_by_name(**args)
        metric.configure(dataset_args=metric_dataset_args, run_dir=run_dir)
        metrics.append(metric)

    print(f'Training for {total_kimg} kimg...')
    print()
    if progress_fn is not None:
        # Report the resume point, consistent with the later cur_nimg // 1000 reports.
        progress_fn(int(resume_kimg), total_kimg)
    tick_start_time = time.time()
    maintenance_time = tick_start_time - start_time
    cur_tick = -1  # -1 forces a maintenance tick right after the first iteration.
    cur_nimg = int(resume_kimg * 1000)
    tick_start_nimg = cur_nimg
    running_mb_counter = 0

    done = False
    while not done:

        # Compute EMA decay parameter.
        Gs_nimg = G_smoothing_kimg * 1000.0
        if G_smoothing_rampup is not None:
            Gs_nimg = min(Gs_nimg, cur_nimg * G_smoothing_rampup)
        Gs_beta = 0.5 ** (minibatch_size / max(Gs_nimg, 1e-8))
        # 100 top-k "epochs" in total over total_kimg.
        epochs = float(100 * cur_nimg / (total_kimg * 1000))

        # Run training ops.
        for _repeat_idx in range(minibatch_repeats):
            rounds = range(0, minibatch_size, minibatch_gpu * num_gpus)
            run_G_reg = (lazy_regularization and running_mb_counter % G_reg_interval == 0)
            run_D_reg = (lazy_regularization and running_mb_counter % D_reg_interval == 0)
            cur_nimg += minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                tflib.run([G_train_op, data_fetch_op])
                if run_G_reg:
                    tflib.run(G_reg_op)
                tflib.run([D_train_op, Gs_update_op, Gs_epochs_op], {Gs_beta_in: Gs_beta, Gs_epochs: epochs})
                if run_D_reg:
                    tflib.run(D_reg_op)

            # Slow path with gradient accumulation.
            else:
                for _round in rounds:
                    tflib.run(G_train_op)
                    if run_G_reg:
                        tflib.run(G_reg_op)
                tflib.run([Gs_update_op, Gs_epochs_op], {Gs_beta_in: Gs_beta, Gs_epochs: epochs})
                for _round in rounds:
                    tflib.run(data_fetch_op)
                    tflib.run(D_train_op)
                    if run_D_reg:
                        tflib.run(D_reg_op)

            # Run validation.
            if aug is not None:
                aug.run_validation(minibatch_size=minibatch_size)

        # Tune augmentation parameters.
        if aug is not None:
            aug.tune(minibatch_size * minibatch_repeats)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000) or (abort_fn is not None and abort_fn())
        if done or cur_tick < 0 or cur_nimg >= tick_start_nimg + kimg_per_tick * 1000:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_end_time = time.time()
            total_time = tick_end_time - start_time
            tick_time = tick_end_time - tick_start_time

            # Report progress.
            print(' '.join([
                f"tick {autosummary('Progress/tick', cur_tick):<5d}",
                f"kimg {autosummary('Progress/kimg', cur_nimg / 1000.0):<8.1f}",
                f"time {dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)):<12s}",
                f"sec/tick {autosummary('Timing/sec_per_tick', tick_time):<7.1f}",
                f"sec/kimg {autosummary('Timing/sec_per_kimg', tick_time / tick_kimg):<7.2f}",
                f"maintenance {autosummary('Timing/maintenance_sec', maintenance_time):<6.1f}",
                f"gpumem {autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30):<5.1f}",
                f"augment {autosummary('Progress/augment', aug.strength if aug is not None else 0):.3f}",
            ]))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))
            if progress_fn is not None:
                progress_fn(cur_nimg // 1000, total_kimg)

            # Save snapshots.
            if image_snapshot_ticks is not None and (done or cur_tick % image_snapshot_ticks == 0):
                grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=minibatch_gpu)
                save_image_grid(grid_fakes, os.path.join(run_dir, f'fakes{cur_nimg // 1000:06d}.jpg'),
                                drange=[-1, 1], grid_size=grid_size)
            if network_snapshot_ticks is not None and (done or cur_tick % network_snapshot_ticks == 0):
                pkl = os.path.join(run_dir, f'network-snapshot-{cur_nimg // 1000:06d}.pkl')
                with open(pkl, 'wb') as f:
                    pickle.dump((G, D, Gs), f)
                if metrics:
                    print('Evaluating metrics...')
                    for metric in metrics:
                        metric.run(pkl, num_gpus=num_gpus)

            # Update summaries.
            for metric in metrics:
                metric.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            tick_start_time = time.time()
            maintenance_time = tick_start_time - tick_end_time

    print()
    print('Exiting...')
    summary_log.close()
    training_set.close()
def training_loop(
    run_dir='.',                # Output directory.
    training_set_kwargs={},     # Options for training set.
    data_loader_kwargs={},      # Options for torch.utils.data.DataLoader.
    G_kwargs={},                # Options for generator network.
    D_kwargs={},                # Options for discriminator network.
    G_opt_kwargs={},            # Options for generator optimizer.
    D_opt_kwargs={},            # Options for discriminator optimizer.
    augment_kwargs=None,        # Options for augmentation pipeline. None = disable.
    loss_kwargs={},             # Options for loss function.
    metrics=[],                 # Metrics to evaluate during training.
    random_seed=0,              # Global random seed.
    num_gpus=1,                 # Number of GPUs participating in the training.
    rank=0,                     # Rank of the current process in [0, num_gpus[.
    batch_size=4,               # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus.
    batch_gpu=4,                # Number of samples processed at a time by one GPU.
    ema_kimg=10,                # Half-life of the exponential moving average (EMA) of generator weights.
    ema_rampup=None,            # EMA ramp-up coefficient.
    G_reg_interval=4,           # How often to perform regularization for G? None = disable lazy regularization.
    D_reg_interval=16,          # How often to perform regularization for D? None = disable lazy regularization.
    augment_p=0,                # Initial value of augmentation probability.
    ada_target=None,            # ADA target value. None = fixed p.
    ada_interval=4,             # How often to perform ADA adjustment?
    ada_kimg=500,               # ADA adjustment speed, measured in how many kimg it takes for p to increase/decrease by one unit.
    total_kimg=25000,           # Total length of the training, measured in thousands of real images.
    kimg_per_tick=4,            # Progress snapshot interval.
    image_snapshot_ticks=50,    # How often to save image snapshots? None = disable.
    network_snapshot_ticks=50,  # How often to save network snapshots? None = disable.
    resume_pkl=None,            # Network pickle to resume training from. 'latest' = newest snapshot found via tmisc.
    cudnn_benchmark=True,       # Enable torch.backends.cudnn.benchmark?
    allow_tf32=False,           # Enable torch.backends.cuda.matmul.allow_tf32 and torch.backends.cudnn.allow_tf32?
    abort_fn=None,              # Callback function for determining whether to abort training. Must return consistent results across ranks.
    progress_fn=None,           # Callback function for updating training progress. Called for all ranks.
):
    """Run the PyTorch GAN training loop for the process running on CUDA device ``rank``.

    The function builds G, D and an EMA copy G_ema. It optionally resumes them from
    ``resume_pkl``. ``'latest'`` picks the snapshot returned by ``tmisc.locate_latest_pkl``
    for ``tmisc.get_parent_dir(run_dir)`` and recovers kimg, plus augment_p when ADA is
    active, from it. The G and D phases then run until ``total_kimg`` thousand real images
    have been shown or ``abort_fn()`` returns True.

    Once per tick it reports statistics to ``stats.jsonl`` and tfevents (rank 0 only). It
    periodically writes ``fakes*.jpg`` grids and ``network-snapshot-*.pkl`` files to
    ``run_dir``, evaluating ``metrics`` on each saved snapshot.

    Returns:
        None. All outputs are written to ``run_dir``.
    """
    # Initialize.
    start_time = time.time()
    device = torch.device('cuda', rank)
    np.random.seed(random_seed * num_gpus + rank)
    torch.manual_seed(random_seed * num_gpus + rank)
    torch.backends.cudnn.benchmark = cudnn_benchmark    # Improves training speed.
    torch.backends.cuda.matmul.allow_tf32 = allow_tf32  # Allow PyTorch to internally use tf32 for matmul
    torch.backends.cudnn.allow_tf32 = allow_tf32        # Allow PyTorch to internally use tf32 for convolutions
    conv2d_gradfix.enabled = True                       # Improves training speed.
    grid_sample_gradfix.enabled = True                  # Avoids errors with the augmentation pipe.

    # Load training set.
    if rank == 0:
        print('Loading training set...')
    training_set = dnnlib.util.construct_class_by_name(**training_set_kwargs)  # subclass of training.dataset.Dataset
    training_set_sampler = misc.InfiniteSampler(dataset=training_set, rank=rank, num_replicas=num_gpus,
                                                seed=random_seed)
    training_set_iterator = iter(torch.utils.data.DataLoader(dataset=training_set, sampler=training_set_sampler,
                                                             batch_size=batch_size // num_gpus,
                                                             **data_loader_kwargs))
    if rank == 0:
        print()
        print('Num images: ', len(training_set))
        print('Image shape:', training_set.image_shape)
        print('Label shape:', training_set.label_shape)
        print()

    # Construct networks.
    if rank == 0:
        print('Constructing networks...')
    common_kwargs = dict(c_dim=training_set.label_dim, img_resolution=training_set.resolution,
                         img_channels=training_set.num_channels)
    G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs).train().requires_grad_(False).to(device)  # subclass of torch.nn.Module
    D = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs).train().requires_grad_(False).to(device)  # subclass of torch.nn.Module
    G_ema = copy.deepcopy(G).eval()

    # Resume from existing pickle.
    # resume_kimg must always be bound: it seeds cur_nimg and the first progress_fn()
    # call below. Previously it was only assigned in the 'latest' branch, so the default
    # resume_pkl=None (or an explicit path) raised UnboundLocalError.
    resume_kimg = 0
    if resume_pkl == 'latest':
        out_dir = tmisc.get_parent_dir(run_dir)
        resume_pkl = tmisc.locate_latest_pkl(out_dir)
        # NOTE(review): assumes a None result means "no snapshot found" — confirm in tmisc.
        if resume_pkl is not None:
            resume_kimg = tmisc.parse_kimg_from_network_name(resume_pkl)
            if resume_kimg > 0 and rank == 0:
                print(f'Resuming from kimg = {resume_kimg}')
            if ada_target is not None and augment_p == 0:
                # Overwrite augment_p only if the augmentation probability is not fixed by the user.
                # Every rank parses the same log, so all ranks agree on the value.
                augment_p = tmisc.parse_augment_p_from_log(resume_pkl)
                if augment_p > 0 and rank == 0:
                    print(f'Resuming with augment_p = {augment_p}')

    if (resume_pkl is not None) and (rank == 0):
        print(f'Resuming from "{resume_pkl}"')
        # Network pickles execute code on load: only resume from trusted snapshots.
        with dnnlib.util.open_url(resume_pkl) as f:
            resume_data = legacy.load_network_pkl(f)
        for name, module in [('G', G), ('D', D), ('G_ema', G_ema)]:
            misc.copy_params_and_buffers(resume_data[name], module, require_all=False)

    # Print network summary tables.
    if rank == 0:
        z = torch.empty([batch_gpu, G.z_dim], device=device)
        c = torch.empty([batch_gpu, G.c_dim], device=device)
        img = misc.print_module_summary(G, [z, c])
        misc.print_module_summary(D, [img, c])

    # Setup augmentation.
    if rank == 0:
        print('Setting up augmentation...')
    augment_pipe = None
    ada_stats = None
    if (augment_kwargs is not None) and (augment_p > 0 or ada_target is not None):
        augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs).train().requires_grad_(False).to(device)  # subclass of torch.nn.Module
        augment_pipe.p.copy_(torch.as_tensor(augment_p))
        if ada_target is not None:
            ada_stats = training_stats.Collector(regex='Loss/signs/real')

    # Distribute across GPUs.
    if rank == 0:
        print(f'Distributing across {num_gpus} GPUs...')
    ddp_modules = dict()
    for name, module in [('G_mapping', G.mapping), ('G_synthesis', G.synthesis), ('D', D), (None, G_ema),
                         ('augment_pipe', augment_pipe)]:
        if (num_gpus > 1) and (module is not None) and len(list(module.parameters())) != 0:
            module.requires_grad_(True)
            module = torch.nn.parallel.DistributedDataParallel(module, device_ids=[device], broadcast_buffers=False)
            module.requires_grad_(False)
        if name is not None:
            ddp_modules[name] = module

    # Setup training phases.
    if rank == 0:
        print('Setting up training phases...')
    loss = dnnlib.util.construct_class_by_name(device=device, **ddp_modules, **loss_kwargs)  # subclass of training.loss.Loss
    phases = []
    for name, module, opt_kwargs, reg_interval in [('G', G, G_opt_kwargs, G_reg_interval),
                                                   ('D', D, D_opt_kwargs, D_reg_interval)]:
        if reg_interval is None:
            opt = dnnlib.util.construct_class_by_name(params=module.parameters(), **opt_kwargs)  # subclass of torch.optim.Optimizer
            phases += [dnnlib.EasyDict(name=name + 'both', module=module, opt=opt, interval=1)]
        else:  # Lazy regularization.
            # Rescale lr/betas to compensate for the extra regularization steps.
            mb_ratio = reg_interval / (reg_interval + 1)
            opt_kwargs = dnnlib.EasyDict(opt_kwargs)
            opt_kwargs.lr = opt_kwargs.lr * mb_ratio
            opt_kwargs.betas = [beta ** mb_ratio for beta in opt_kwargs.betas]
            opt = dnnlib.util.construct_class_by_name(module.parameters(), **opt_kwargs)  # subclass of torch.optim.Optimizer
            phases += [dnnlib.EasyDict(name=name + 'main', module=module, opt=opt, interval=1)]
            phases += [dnnlib.EasyDict(name=name + 'reg', module=module, opt=opt, interval=reg_interval)]
    for phase in phases:
        phase.start_event = None
        phase.end_event = None
        if rank == 0:
            phase.start_event = torch.cuda.Event(enable_timing=True)
            phase.end_event = torch.cuda.Event(enable_timing=True)

    # Export sample images.
    grid_size = None
    grid_z = None
    grid_c = None
    if rank == 0:
        print('Exporting sample images...')
        grid_size, images, labels = setup_snapshot_image_grid(training_set=training_set)
        save_image_grid(images, os.path.join(run_dir, 'reals.jpg'), drange=[0, 255], grid_size=grid_size)
        grid_z = torch.randn([labels.shape[0], G.z_dim], device=device).split(batch_gpu)
        grid_c = torch.from_numpy(labels).to(device).split(batch_gpu)
        images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy()
        save_image_grid(images, os.path.join(run_dir, 'fakes_init.jpg'), drange=[-1, 1], grid_size=grid_size)

    # Initialize logs.
    if rank == 0:
        print('Initializing logs...')
    stats_collector = training_stats.Collector(regex='.*')
    stats_metrics = dict()
    stats_jsonl = None
    stats_tfevents = None
    if rank == 0:
        stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'wt')
        try:
            import torch.utils.tensorboard as tensorboard
            stats_tfevents = tensorboard.SummaryWriter(run_dir)
        except ImportError as err:
            print('Skipping tfevents export:', err)

    # Train.
    if rank == 0:
        print(f'Training for {total_kimg} kimg...')
        print()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    maintenance_time = tick_start_time - start_time
    batch_idx = 0
    if progress_fn is not None:
        progress_fn(int(resume_kimg), total_kimg)
    while True:

        # Fetch training data.
        with torch.autograd.profiler.record_function('data_fetch'):
            phase_real_img, phase_real_c = next(training_set_iterator)
            # Map uint8-range pixels [0, 255] to [-1, 1], then split into per-GPU rounds.
            phase_real_img = (phase_real_img.to(device).to(torch.float32) / 127.5 - 1).split(batch_gpu)
            phase_real_c = phase_real_c.to(device).split(batch_gpu)
            all_gen_z = torch.randn([len(phases) * batch_size, G.z_dim], device=device)
            all_gen_z = [phase_gen_z.split(batch_gpu) for phase_gen_z in all_gen_z.split(batch_size)]
            all_gen_c = [training_set.get_label(np.random.randint(len(training_set)))
                         for _ in range(len(phases) * batch_size)]
            all_gen_c = torch.from_numpy(np.stack(all_gen_c)).pin_memory().to(device)
            all_gen_c = [phase_gen_c.split(batch_gpu) for phase_gen_c in all_gen_c.split(batch_size)]

        # Execute training phases.
        for phase, phase_gen_z, phase_gen_c in zip(phases, all_gen_z, all_gen_c):
            if batch_idx % phase.interval != 0:
                continue

            # Initialize gradient accumulation.
            if phase.start_event is not None:
                phase.start_event.record(torch.cuda.current_stream(device))
            phase.opt.zero_grad(set_to_none=True)
            phase.module.requires_grad_(True)

            # Accumulate gradients over multiple rounds; only the last round syncs across GPUs.
            for round_idx, (real_img, real_c, gen_z, gen_c) in enumerate(
                    zip(phase_real_img, phase_real_c, phase_gen_z, phase_gen_c)):
                sync = (round_idx == batch_size // (batch_gpu * num_gpus) - 1)
                gain = phase.interval
                loss.accumulate_gradients(phase=phase.name, real_img=real_img, real_c=real_c, gen_z=gen_z,
                                          gen_c=gen_c, sync=sync, gain=gain)

            # Update weights.
            phase.module.requires_grad_(False)
            with torch.autograd.profiler.record_function(phase.name + '_opt'):
                for param in phase.module.parameters():
                    if param.grad is not None:
                        misc.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)
                phase.opt.step()
            if phase.end_event is not None:
                phase.end_event.record(torch.cuda.current_stream(device))

        # Update G_ema.
        with torch.autograd.profiler.record_function('Gema'):
            ema_nimg = ema_kimg * 1000
            if ema_rampup is not None:
                ema_nimg = min(ema_nimg, cur_nimg * ema_rampup)
            ema_beta = 0.5 ** (batch_size / max(ema_nimg, 1e-8))
            for p_ema, p in zip(G_ema.parameters(), G.parameters()):
                p_ema.copy_(p.lerp(p_ema, ema_beta))
            for b_ema, b in zip(G_ema.buffers(), G.buffers()):
                b_ema.copy_(b)

        # Update state.
        cur_nimg += batch_size
        batch_idx += 1

        # Execute ADA heuristic: nudge p up or down toward ada_target, never below 0.
        if (ada_stats is not None) and (batch_idx % ada_interval == 0):
            ada_stats.update()
            adjust = np.sign(ada_stats['Loss/signs/real'] - ada_target) * (batch_size * ada_interval) / (ada_kimg * 1000)
            augment_pipe.p.copy_((augment_pipe.p + adjust).max(misc.constant(0, device=device)))

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
            continue

        # Print status line, accumulating the same information in stats_collector.
        tick_end_time = time.time()
        fields = []
        fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"]
        fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<8.1f}"]
        fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"]
        fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"]
        fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"]
        fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"]
        fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"]
        fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"]
        torch.cuda.reset_peak_memory_stats()
        fields += [f"augment {training_stats.report0('Progress/augment', float(augment_pipe.p.cpu()) if augment_pipe is not None else 0):.3f}"]
        training_stats.report0('Timing/total_hours', (tick_end_time - start_time) / (60 * 60))
        training_stats.report0('Timing/total_days', (tick_end_time - start_time) / (24 * 60 * 60))
        if rank == 0:
            print(' '.join(fields))

        # Check for abort.
        if (not done) and (abort_fn is not None) and abort_fn():
            done = True
            if rank == 0:
                print()
                print('Aborting...')

        # Save image snapshot.
        if (rank == 0) and (image_snapshot_ticks is not None) and (done or cur_tick % image_snapshot_ticks == 0):
            images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy()
            save_image_grid(images, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}.jpg'), drange=[-1, 1],
                            grid_size=grid_size)

        # Save network snapshot.
        snapshot_pkl = None
        snapshot_data = None
        if (network_snapshot_ticks is not None) and (done or cur_tick % network_snapshot_ticks == 0):
            snapshot_data = dict(training_set_kwargs=dict(training_set_kwargs))
            for name, module in [('G', G), ('D', D), ('G_ema', G_ema), ('augment_pipe', augment_pipe)]:
                if module is not None:
                    if num_gpus > 1:
                        misc.check_ddp_consistency(module, ignore_regex=r'.*\.w_avg')
                    module = copy.deepcopy(module).eval().requires_grad_(False).cpu()
                snapshot_data[name] = module
                del module  # conserve memory
            snapshot_pkl = os.path.join(run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl')
            if rank == 0:
                with open(snapshot_pkl, 'wb') as f:
                    pickle.dump(snapshot_data, f)

        # Evaluate metrics.
        if (snapshot_data is not None) and (len(metrics) > 0):
            if rank == 0:
                print('Evaluating metrics...')
            for metric in metrics:
                result_dict = metric_main.calc_metric(metric=metric, G=snapshot_data['G_ema'],
                                                      dataset_kwargs=training_set_kwargs, num_gpus=num_gpus,
                                                      rank=rank, device=device)
                if rank == 0:
                    metric_main.report_metric(result_dict, run_dir=run_dir, snapshot_pkl=snapshot_pkl)
                stats_metrics.update(result_dict.results)
        del snapshot_data  # conserve memory

        # Collect statistics.
        for phase in phases:
            value = []
            if (phase.start_event is not None) and (phase.end_event is not None):
                phase.end_event.synchronize()
                value = phase.start_event.elapsed_time(phase.end_event)
            training_stats.report0('Timing/' + phase.name, value)
        stats_collector.update()
        stats_dict = stats_collector.as_dict()

        # Update logs.
        timestamp = time.time()
        if stats_jsonl is not None:
            fields = dict(stats_dict, timestamp=timestamp)
            stats_jsonl.write(json.dumps(fields) + '\n')
            stats_jsonl.flush()
        if stats_tfevents is not None:
            global_step = int(cur_nimg / 1e3)
            walltime = timestamp - start_time
            for name, value in stats_dict.items():
                stats_tfevents.add_scalar(name, value.mean, global_step=global_step, walltime=walltime)
            for name, value in stats_metrics.items():
                stats_tfevents.add_scalar(f'Metrics/{name}', value, global_step=global_step, walltime=walltime)
            stats_tfevents.flush()
        if progress_fn is not None:
            progress_fn(cur_nimg // 1000, total_kimg)

        # Update state.
        cur_tick += 1
        tick_start_nimg = cur_nimg
        tick_start_time = time.time()
        maintenance_time = tick_start_time - tick_end_time
        if done:
            break

    # Done.
    if rank == 0:
        print()
        print('Exiting...')
    # Release the log handles opened above (rank 0 only; None elsewhere).
    if stats_jsonl is not None:
        stats_jsonl.close()
    if stats_tfevents is not None:
        stats_tfevents.close()
def training_loop(
    G_args={},  # Options for generator network.
    D_args={},  # Options for discriminator network.
    G_opt_args={},  # Options for generator optimizer (copied, not mutated).
    D_opt_args={},  # Options for discriminator optimizer (copied, not mutated).
    loss_args={},  # Options for loss; forwarded as kwargs to dnnlib.util.call_func_by_name() (presumably incl. func_name).
    dataset_args={},  # Options for dataset.load_dataset().
    sched_args={},  # Options for training_schedule().
    grid_args={},  # Options for misc.setup_snapshot_image_grid().
    metric_arg_list=[],  # Options for metrics; first argument of metric_base.MetricGroup().
    metric_args={},  # Extra keyword options for metric_base.MetricGroup().
    tf_config={},  # Options for tflib.init_tf().
    ema_start_kimg=None,  # kimg at which the exponential moving average starts. None = use G_ema_kimg.
    G_ema_kimg=10,  # Half-life of the exponential moving average of generator weights, in kimg. <= 0 gives beta = 0.0.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    lazy_regularization=False,  # Add D_reg only every D_reg_interval minibatches (scaled by D_reg_interval) instead of every minibatch?
    G_reg_interval=4,  # NOTE(review): not referenced anywhere in this function.
    D_reg_interval=4,  # How often to perform regularization for D. Ignored if lazy_regularization=False.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=25000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment? Forwarded to process_reals().
    drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=10,  # How often (in ticks) to save image snapshots? None = only save 'reals.jpg' and 'fakes_init.jpg'.
    network_snapshot_ticks=10,  # How often (in ticks) to save network snapshots and run metrics? None = only save 'network-final.pkl'.
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_pkl=None,  # Network pickle to resume training from; 'latest' = use misc.locate_latest_pkl(); None = train from scratch.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule. Overridden when resume_pkl='latest'.
    resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
    resume_with_new_nets=False  # Construct new networks according to G_args and D_args before resuming training?
):
    """Train a generator G, discriminator D and moving-average copy Gs (TF1 graph mode).

    Constructs new networks from G_args/D_args or loads them from resume_pkl.
    It then builds one loss/gradient graph per GPU and repeatedly runs the G
    and D training ops. The loop stops when total_kimg thousand real images
    have been consumed or the dnnlib RunContext asks to stop.

    Once per tick it prints and logs progress. On the configured tick
    intervals it also:
    - saves an image grid ('fakesNNNNNN.jpg');
    - saves a network snapshot ('network-snapshot-NNNNNN.pkl');
    - evaluates the metrics on that snapshot.

    When the loop exits, a final 'network-final.pkl' is written.

    Returns:
        None. All outputs go to the dnnlib run directory.
    """
    if ema_start_kimg is None:
        ema_start_kimg = G_ema_kimg

    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(verbose=True, **dataset_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        training_set, **grid_args)
    misc.save_image_grid(grid_reals,
                         dnnlib.make_run_dir_path('reals.jpg'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)

    # Construct or load networks.
    with tf.device('/gpu:0'):
        if resume_pkl == 'latest':
            # https://github.com/skyflynil/stylegan2/blob/master/training/training_loop.py
            # Replaces both resume_pkl and resume_kimg with what
            # misc.locate_latest_pkl() returns.
            # NOTE(review): resume_time is not updated here, so reported
            # total time restarts from the caller-supplied value.
            print("resuming from latest")
            resume_pkl, resume_kimg = misc.locate_latest_pkl(
                dnnlib.submit_config.run_dir_root)
        if resume_pkl is None or resume_with_new_nets:
            print('Constructing networks...')
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **G_args)
            D = tflib.Network('D',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **D_args)
            Gs = G.clone('Gs')
        if resume_pkl is not None:
            resume_networks = misc.load_pkl(resume_pkl)
            rG, rD, rGs = resume_networks
            if resume_with_new_nets:
                # Keep the freshly constructed networks and copy the pickled
                # variables into them (presumably matching by name).
                G.copy_vars_from(rG)
                D.copy_vars_from(rD)
                Gs.copy_vars_from(rGs)
            else:
                G, D, Gs = rG, rD, rGs

    # Print layers and generate initial image snapshot.
    G.print_layers()
    D.print_layers()
    # Schedule evaluated at the end of training; only its minibatch_gpu is
    # used here, as the batch size for rendering the initial grid.
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              **sched_args)
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch_gpu)
    misc.save_image_grid(grid_fakes,
                         dnnlib.make_run_dir_path('fakes_init.jpg'),
                         drange=drange_net,
                         grid_size=grid_size)

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        G_lrate_in = tf.placeholder(tf.float32, name='G_lrate_in', shape=[])
        D_lrate_in = tf.placeholder(tf.float32, name='D_lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32,
                                           name='minibatch_size_in',
                                           shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32,
                                          name='minibatch_gpu_in',
                                          shape=[])
        run_D_reg_in = tf.placeholder(tf.bool, name='run_D_reg', shape=[])
        # Number of per-GPU rounds that make up one full minibatch.
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in * num_gpus)
        # Fed 0 before ema_start_kimg and 1 afterwards (see feed_dict below).
        Gs_beta_mul_in = tf.placeholder(tf.float32, name='Gs_beta_in', shape=[])
        # Per-minibatch EMA factor. Raising it to the power
        # G_ema_kimg * 1000 / minibatch_size gives exactly 0.5, i.e. the
        # old weights' share halves every G_ema_kimg thousand images.
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_size_in, tf.float32),
                              G_ema_kimg * 1000.0) if G_ema_kimg > 0.0 else 0.0

    # Setup optimizers.
    # Copy the option dicts so the caller's (or the shared default) dicts
    # are not mutated by the assignments below.
    G_opt_args = dict(G_opt_args)
    D_opt_args = dict(D_opt_args)
    G_opt_args['learning_rate'] = G_lrate_in
    D_opt_args['learning_rate'] = D_lrate_in
    for args in [G_opt_args, D_opt_args]:
        args['minibatch_multiplier'] = minibatch_multiplier
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', **D_opt_args)

    # Build training graph for each GPU.
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            with tf.name_scope('DataFetch'):
                reals_read, labels_read = training_set.get_minibatch_tf()
                reals_read = process_reals(reals_read, lod_in, mirror_augment,
                                           training_set.dynamic_range,
                                           drange_net)

            # Create GPU-specific shadow copies of G and D.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')

            # Evaluate loss functions.
            # Any 'lod' variable in the networks is assigned from lod_in
            # before the loss ops run (control dependency).
            lod_assign_ops = []
            if 'lod' in G_gpu.vars:
                lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)]
            if 'lod' in D_gpu.vars:
                lod_assign_ops += [tf.assign(D_gpu.vars['lod'], lod_in)]
            with tf.control_dependencies(lod_assign_ops):
                with tf.name_scope('loss'):
                    G_loss, D_loss, D_reg = dnnlib.util.call_func_by_name(
                        G=G_gpu,
                        D=D_gpu,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        reals=reals_read,
                        real_labels=labels_read,
                        **loss_args)

            # Register gradients.
            if not lazy_regularization:
                if D_reg is not None:
                    D_loss += D_reg
            else:
                if D_reg is not None:
                    # Add the penalty only on steps where run_D_reg_in is
                    # fed True (every D_reg_interval minibatches). Scale it
                    # by the interval to compensate for the lower frequency.
                    # tf.cond calls both lambdas while building the graph,
                    # so they see D_loss as it was before this assignment.
                    D_loss = tf.cond(run_D_reg_in,
                                     lambda: D_loss + D_reg * D_reg_interval,
                                     lambda: D_loss)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)

    # Setup training ops.
    # Gs is maintained as a moving average of G. Gs_beta_mul_in zeroes the
    # factor until cur_nimg reaches ema_start_kimg.
    Gs_update_op = Gs.setup_as_moving_average_of(G,
                                                 beta=Gs_beta * Gs_beta_mul_in)
    # NOTE(review): only G_train_op is gated on the EMA update here, so Gs
    # is refreshed once per G step. If D_train_op were also gated, the loop
    # below (which runs the two ops in separate session calls) would update
    # Gs twice per round. Confirm this nesting against upstream.
    with tf.control_dependencies([Gs_update_op]):
        G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            # Op not available in this TF build: report 0 instead of failing.
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list, **metric_args)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1  # Negative forces a maintenance tick after the first iteration.
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0  # Minibatches run so far; drives lazy D regularization.
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu)
        if reset_opt_for_new_lod:
            # Reset whenever floor or ceil of lod changes, i.e. lod crossed
            # an integer boundary.
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        feed_dict = {
            lod_in: sched.lod,
            G_lrate_in: sched.G_lrate,
            D_lrate_in: sched.D_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu,
            # 0 zeroes the EMA factor until ema_start_kimg is reached
            # (presumably making Gs follow G directly until then).
            Gs_beta_mul_in: 1 if cur_nimg >= ema_start_kimg * 1000 else 0,
        }
        for _repeat in range(minibatch_repeats):
            # One round per (minibatch_gpu * num_gpus) slice of the minibatch.
            rounds = range(0, sched.minibatch_size,
                           sched.minibatch_gpu * num_gpus)
            run_D_reg = (lazy_regularization and
                         running_mb_counter % D_reg_interval == 0)
            feed_dict[run_D_reg_in] = run_D_reg
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Run one G step and one D step per round. With more than one
            # round, the optimizers were given minibatch_multiplier
            # (presumably so gradients accumulate across rounds before they
            # are applied).
            for _ in rounds:
                tflib.run(G_train_op, feed_dict)
                tflib.run(D_train_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start(
            ) + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch_size),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (
                    cur_tick % image_snapshot_ticks == 0 or done):
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch_gpu)
                misc.save_image_grid(grid_fakes,
                                     dnnlib.make_run_dir_path(
                                         'fakes%06d.jpg' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' %
                                               (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                # Metrics are evaluated on the snapshot just written.
                metrics.run(pkl,
                            run_dir=dnnlib.make_run_dir_path(),
                            num_gpus=num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % sched.lod,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            # Interval between RunContext updates minus the training part
            # (tick_time); presumably the time spent on this tick's
            # maintenance.
            maintenance_time = dnnlib.RunContext.get(
            ).get_last_update_interval() - tick_time

    # Save final snapshot.
    misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    # NOTE(review): reached only on normal exit; an exception in the loop
    # above skips closing the summary writer and the dataset.
    summary_log.close()
    training_set.close()