def training_loop(
        G_args={},  # Options for generator network.
        D_args={},  # Options for discriminator network.
        G_opt_args={},  # Options for generator optimizer.
        D_opt_args={},  # Options for discriminator optimizer.
        loss_args={},  # Options for loss.
        dataset_args={},  # Options for dataset.load_dataset().
        sched_args={},  # Options for train.TrainingSchedule.
        grid_args={},  # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],  # Options for metrics.
        metric_args={},  # Options for MetricGroup.
        tf_config={},  # Options for tflib.init_tf().
        ema_start_kimg=None,  # Start of the exponential moving average. Defaults to the half-life period.
        G_ema_kimg=10,  # Half-life of the exponential moving average of generator weights.
        minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
        lazy_regularization=False,  # Apply the D regularizer only on every D_reg_interval-th minibatch?
        G_reg_interval=4,  # How often to perform regularization for G? Not read by this function.
        D_reg_interval=4,  # How often to perform regularization for D? Ignored if lazy_regularization=False.
        reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg=25000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=False,  # Enable mirror augment?
        drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=2,  # How often to save image snapshots? None = only save 'reals.png' and 'fakes_init.png'.
        network_snapshot_ticks=1,  # How often to save network snapshots? None = only save 'network-final.pkl'.
        save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_pkl=None,  # Network pickle to resume training from, None = train from scratch.
        resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
        resume_with_new_nets=False):  # Construct new networks according to G_args and D_args before resuming training?
    """Train G and D with a single combined loss function and an EMA copy Gs.

    The function loads the dataset, builds or restores the (G, D, Gs)
    networks, assembles one training sub-graph per GPU, and then alternates
    G and D updates until `total_kimg` thousand images have been processed
    or the run context asks to stop.  Progress, image grids, network pickles
    and metrics are emitted once per tick.  Nothing is returned; the final
    networks are written to 'network-final.pkl' in the run directory.
    """
    # Without an explicit start point the EMA kicks in after one half-life.
    if ema_start_kimg is None:
        ema_start_kimg = G_ema_kimg

    # Bring up TensorFlow and look up how many GPUs this run was given.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Dataset and the fixed grid of real images used as a visual reference.
    training_set = dataset.load_dataset(verbose=True, **dataset_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        training_set, **grid_args)
    misc.save_image_grid(grid_reals,
                         dnnlib.make_run_dir_path('reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)

    # Build fresh networks and/or restore them from the resume pickle.
    with tf.device('/gpu:0'):
        if resume_pkl is None or resume_with_new_nets:
            print('Constructing networks...')
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **G_args)
            D = tflib.Network('D',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **D_args)
            Gs = G.clone('Gs')
        if resume_pkl is not None:
            rG, rD, rGs = misc.load_pkl(resume_pkl)
            if not resume_with_new_nets:
                # Continue with the pickled networks themselves.
                G, D, Gs = rG, rD, rGs
            else:
                # Keep the freshly built networks, import the pickled weights.
                G.copy_vars_from(rG)
                D.copy_vars_from(rD)
                Gs.copy_vars_from(rGs)

    # Layer summaries and the initial grid of generated images.
    G.print_layers()
    D.print_layers()
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              **sched_args)
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch_gpu)
    misc.save_image_grid(grid_fakes,
                         dnnlib.make_run_dir_path('fakes_init.png'),
                         drange=drange_net,
                         grid_size=grid_size)

    # Scalar placeholders that are re-fed on every training step.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        G_lrate_in = tf.placeholder(tf.float32, name='G_lrate_in', shape=[])
        D_lrate_in = tf.placeholder(tf.float32, name='D_lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32,
                                           name='minibatch_size_in',
                                           shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32,
                                          name='minibatch_gpu_in',
                                          shape=[])
        run_D_reg_in = tf.placeholder(tf.bool, name='run_D_reg', shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in *
                                                     num_gpus)
        Gs_beta_mul_in = tf.placeholder(tf.float32,
                                        name='Gs_beta_in',
                                        shape=[])
        if G_ema_kimg > 0.0:
            Gs_beta = 0.5**tf.div(tf.cast(minibatch_size_in, tf.float32),
                                  G_ema_kimg * 1000.0)
        else:
            Gs_beta = 0.0

    # Optimizers work on private copies of the caller's option dicts.
    G_opt_args = dict(G_opt_args,
                      learning_rate=G_lrate_in,
                      minibatch_multiplier=minibatch_multiplier)
    D_opt_args = dict(D_opt_args,
                      learning_rate=D_lrate_in,
                      minibatch_multiplier=minibatch_multiplier)
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', **D_opt_args)

    # One training sub-graph per GPU.
    for gpu_idx in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu_idx), tf.device('/gpu:%d' % gpu_idx):
            with tf.name_scope('DataFetch'):
                reals_read, labels_read = training_set.get_minibatch_tf()
                reals_read = process_reals(reals_read, lod_in, mirror_augment,
                                           training_set.dynamic_range,
                                           drange_net)

            # GPU 0 uses the master networks, the others get shadow clones.
            G_gpu = G if gpu_idx == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu_idx == 0 else D.clone(D.name + '_shadow')

            # Push the current LOD into every network that exposes one
            # before the losses are evaluated.
            lod_assign_ops = [
                tf.assign(net.vars['lod'], lod_in)
                for net in (G_gpu, D_gpu)
                if 'lod' in net.vars
            ]
            with tf.control_dependencies(lod_assign_ops):
                with tf.name_scope('loss'):
                    G_loss, D_loss, D_reg = dnnlib.util.call_func_by_name(
                        G=G_gpu,
                        D=D_gpu,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        reals=reals_read,
                        real_labels=labels_read,
                        **loss_args)

            # Fold the regularizer into the D loss: always when not lazy,
            # otherwise only on steps where run_D_reg_in is fed True (scaled
            # by the interval).
            if D_reg is not None:
                if lazy_regularization:
                    D_loss = tf.cond(
                        run_D_reg_in,
                        lambda: D_loss + D_reg * D_reg_interval,
                        lambda: D_loss)
                else:
                    D_loss += D_reg
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)

    # Training ops; the G update is made to wait for the Gs EMA update.
    Gs_update_op = Gs.setup_as_moving_average_of(G,
                                                 beta=Gs_beta * Gs_beta_mul_in)
    with tf.control_dependencies([Gs_update_op]):
        G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    # Peak-memory probe, with a constant fallback when the op is unavailable.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list, **metric_args)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('',
                                   cur_epoch=resume_kimg,
                                   max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop():
            break

        # Look up the schedule for the current position in training.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu)
        if reset_opt_for_new_lod:
            # The integer bracket around the LOD changes exactly when a
            # transition starts or finishes.
            same_floor = np.floor(sched.lod) == np.floor(prev_lod)
            same_ceil = np.ceil(sched.lod) == np.ceil(prev_lod)
            if not (same_floor and same_ceil):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Values fed to the placeholders for this batch of repeats.
        feed_dict = {
            lod_in: sched.lod,
            G_lrate_in: sched.G_lrate,
            D_lrate_in: sched.D_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu,
            # Multiplier is 0 until ema_start_kimg has been reached.
            Gs_beta_mul_in: 1 if cur_nimg >= ema_start_kimg * 1000 else 0,
        }
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size,
                           sched.minibatch_gpu * num_gpus)
            run_D_reg = (lazy_regularization
                         and running_mb_counter % D_reg_interval == 0)
            feed_dict[run_D_reg_in] = run_D_reg
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # One G step followed by one D step per accumulation round.
            for _ in rounds:
                tflib.run(G_train_op, feed_dict)
                tflib.run(D_train_op, feed_dict)

        # Maintenance happens on the first pass, once per tick, and at the end.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = (dnnlib.RunContext.get().get_time_since_start() +
                          resume_time)

            # Record the progress values, then print them as one line.
            progress_fields = (
                autosummary('Progress/tick', cur_tick),
                autosummary('Progress/kimg', cur_nimg / 1000.0),
                autosummary('Progress/lod', sched.lod),
                autosummary('Progress/minibatch', sched.minibatch_size),
                dnnlib.util.format_time(
                    autosummary('Timing/total_sec', total_time)),
                autosummary('Timing/sec_per_tick', tick_time),
                autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                autosummary('Timing/maintenance_sec', maintenance_time),
                autosummary('Resources/peak_gpu_mem_gb',
                            peak_gpu_mem_op.eval() / 2**30),
            )
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f'
                % progress_fields)
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Periodic image grid and network pickle (plus metrics).
            if image_snapshot_ticks is not None and (
                    cur_tick % image_snapshot_ticks == 0 or done):
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch_gpu)
                misc.save_image_grid(grid_fakes,
                                     dnnlib.make_run_dir_path(
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' %
                                               (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl,
                            run_dir=dnnlib.make_run_dir_path(),
                            num_gpus=num_gpus,
                            tf_config=tf_config)

            # Flush summaries and report progress to the run context.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % sched.lod,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            maintenance_time = (
                dnnlib.RunContext.get().get_last_update_interval() - tick_time)

    # Final network pickle.
    misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path('network-final.pkl'))

    # Release the log writer and the dataset.
    summary_log.close()
    training_set.close()
def training_loop(
        G_args={},  # Options for generator network.
        D_args={},  # Options for discriminator network.
        G_opt_args={},  # Options for generator optimizer.
        D_opt_args={},  # Options for discriminator optimizer.
        AE_opt_args=None,  # Options for autoencoder optimizer.
        G_loss_args={},  # Options for generator loss.
        D_loss_args={},  # Options for discriminator loss.
        AE_loss_args=None,  # Options for autoencoder loss. Not read by this function.
        dataset_args={},  # Options for dataset.load_dataset() (training set).
        dataset_args_eval={},  # Options for dataset.load_dataset() (evaluation set).
        sched_args={},  # Options for train.TrainingSchedule.
        grid_args={},  # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],  # Options for MetricGroup.
        tf_config={},  # Options for tflib.init_tf().
        train_data_dir=None,  # Directory to load the training dataset from.
        eval_data_dir=None,  # Directory to load the evaluation dataset from.
        G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
        minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
        lazy_regularization=True,  # Perform regularization as a separate training step?
        G_reg_interval=4,  # How often to perform regularization for G? Ignored if lazy_regularization=False.
        D_reg_interval=16,  # How often to perform regularization for D? Ignored if lazy_regularization=False.
        reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg=25000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=False,  # Enable mirror augment?
        drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=50,  # How often to save image snapshots? None = only save 'reals.png' and 'fakes_init.png'.
        network_snapshot_ticks=50,  # How often to save network snapshots? None = only save 'network-final.pkl'.
        save_tf_graph=True,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=True,  # Include weight histograms in the tfevents file?
        resume_pkl=None,  # Network pickle to resume training from, None or '' = train from scratch.
        resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
        resume_with_new_nets=False,  # Construct new networks according to G_args and D_args before resuming training?
        resume_with_own_vars=False):  # Also construct new networks when resuming; they are replaced by the pickled ones unless resume_with_new_nets is set.
    """Train G and D (optionally SVD-factorized) and keep an EMA copy Gs.

    Loads a training and an evaluation dataset, builds or restores the
    (G, D, Gs) networks, optionally rebuilds them with `factorized=True`
    when G_args requests an SVD variant, assembles one training sub-graph
    per GPU, and runs the training loop until `total_kimg` thousand images
    have been processed or the run context asks to stop.  Image grids,
    network pickles and metrics are written at tick boundaries.  Nothing is
    returned; the final networks go to 'network-final.pkl' in the run
    directory.

    The option dicts passed by the caller are not modified: the function
    works on shallow copies of G_args, D_args, sched_args and G_loss_args.
    """
    # Shallow copies so that the entries added below (train_scope,
    # lambda_mask, reduce_dims, minibatch_gpu_base) neither leak into the
    # caller's dicts nor into the shared mutable defaults.
    G_args = dict(G_args)
    D_args = dict(D_args)
    sched_args = dict(sched_args)

    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training and evaluation sets.  `.get` keeps the log line working
    # for plain dicts as well as attribute-style dicts.
    print("Loading train set from %s..." % dataset_args.get('tfrecord_dir'))
    training_set = dataset.load_dataset(
        data_dir=dnnlib.convert_path(train_data_dir),
        verbose=True,
        **dataset_args)
    print("Loading eval set from %s..." %
          dataset_args_eval.get('tfrecord_dir'))
    # NOTE(review): eval_set is only opened and closed here; evaluation data
    # reaches the metrics through `data_dir=eval_data_dir` below.
    eval_set = dataset.load_dataset(
        data_dir=dnnlib.convert_path(eval_data_dir),
        verbose=True,
        **dataset_args_eval)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        training_set, **grid_args)
    misc.save_image_grid(grid_reals,
                         dnnlib.make_run_dir_path('reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)

    # Freeze Discriminator: replace D_args['train_scope'] by a list of scope
    # patterns built from the dataset resolution.
    if D_args.get('freeze'):
        num_layers = np.log2(training_set.resolution) - 1
        layers = int(np.round(num_layers * 3. / 8.))
        scope = ['Output', 'scores_out']
        for layer in range(layers):
            scope += ['.*%d' % 2**layer]
            if 'train_scope' in D_args:
                scope[-1] += '.*%d' % D_args['train_scope']
        D_args['train_scope'] = scope

    # Construct or load networks.  Both None and '' mean "no resume pickle";
    # comparing with `is ''` would send the None default into load_pkl().
    with tf.device('/gpu:0'):
        if not resume_pkl or resume_with_new_nets or resume_with_own_vars:
            print('Constructing networks...')
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **G_args)
            D = tflib.Network('D',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **D_args)
            Gs = G.clone('Gs')
        if resume_pkl:
            print('Loading networks from "%s"...' % resume_pkl)
            rG, rD, rGs = misc.load_pkl(resume_pkl)
            if resume_with_new_nets:
                G.copy_vars_from(rG)
                D.copy_vars_from(rD)
                Gs.copy_vars_from(rGs)
            else:
                G = rG
                D = rD
                Gs = rGs

    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])

    # SVD stuff
    if 'syn_svd' in G_args or 'map_svd' in G_args:
        # Run graph to calculate SVD
        grid_latents_smol = grid_latents[:1]
        rho = np.array([1])
        G.run(grid_latents_smol, grid_labels, rho, is_validation=True)
        Gs.run(grid_latents_smol, grid_labels, rho, is_validation=True)
        D.run(grid_reals[:1], rho, is_validation=True)

        with tf.device('/gpu:0'):
            # Create SVD-decomposed graph.  The current networks become the
            # reference ("r") networks whose variables are copied over below.
            rG, rD, rGs = G, D, Gs
            G_lambda_mask = {
                var: np.ones(G.vars[var].shape[-1])
                for var in G.vars if 'SVD/s' in var
            }
            D_lambda_mask = {
                'D/' + var: np.ones(D.vars[var].shape[-1])
                for var in D.vars if 'SVD/s' in var
            }
            G_reduce_dims = {
                var: (0, int(Gs.vars[var].shape[-1]))
                for var in Gs.vars if 'SVD/s' in var
            }
            G_args['lambda_mask'] = G_lambda_mask
            G_args['reduce_dims'] = G_reduce_dims
            D_args['lambda_mask'] = D_lambda_mask

            # Create graph with no SVD operations
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=rG.input_shapes[1][1],
                              factorized=True,
                              **G_args)
            D = tflib.Network('D',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=rD.input_shapes[1][1],
                              factorized=True,
                              **D_args)
            Gs = G.clone('Gs')

            G.run(grid_latents_smol,
                  grid_labels,
                  rho,
                  is_validation=True,
                  minibatch_size=1)
            Gs.run(grid_latents_smol,
                   grid_labels,
                   rho,
                   is_validation=True,
                   minibatch_size=1)

            G.copy_vars_from(rG)
            D.copy_vars_from(rD)
            Gs.copy_vars_from(rGs)

    # Reduce per-gpu minibatch size to fit in 16GB GPU memory
    if grid_reals.shape[2] >= 1024:
        sched_args['minibatch_gpu_base'] = 2
        print('Batch size', sched_args['minibatch_gpu_base'])

    # Generate initial image snapshot.
    G.print_layers()
    D.print_layers()
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              **sched_args)
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    rho = np.array([1])
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        rho,
                        is_validation=True,
                        minibatch_size=sched.minibatch_gpu)
    misc.save_image_grid(grid_fakes,
                         dnnlib.make_run_dir_path('fakes_init.png'),
                         drange=drange_net,
                         grid_size=grid_size)
    if resume_pkl:
        # Sanity check: the networks in use should score like the loaded ones.
        load_d_real = rD.run(grid_reals[:1], rho, is_validation=True)
        load_d_fake = rD.run(grid_fakes[:1], rho, is_validation=True)
        d_fake = D.run(grid_fakes[:1], rho, is_validation=True)
        d_real = D.run(grid_reals[:1], rho, is_validation=True)
        print('Factorized fake', d_fake, 'loaded fake', load_d_fake,
              'factorized real', d_real, 'loaded real', load_d_real)
        print('(should match)')

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32,
                                           name='minibatch_size_in',
                                           shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32,
                                          name='minibatch_gpu_in',
                                          shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in *
                                                     num_gpus)
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_size_in, tf.float32),
                              G_smoothing_kimg *
                              1000.0) if G_smoothing_kimg > 0.0 else 0.0

    # Setup optimizers.
    G_opt_args = dict(G_opt_args)
    D_opt_args = dict(D_opt_args)
    for args, reg_interval in [(G_opt_args, G_reg_interval),
                               (D_opt_args, D_reg_interval)]:
        args['minibatch_multiplier'] = minibatch_multiplier
        args['learning_rate'] = lrate_in
        if lazy_regularization:
            # Rescale learning rate and Adam betas by interval/(interval+1).
            mb_ratio = reg_interval / (reg_interval + 1)
            args['learning_rate'] *= mb_ratio
            if 'beta1' in args:
                args['beta1'] **= mb_ratio
            if 'beta2' in args:
                args['beta2'] **= mb_ratio
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', **D_opt_args)
    G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args)
    D_reg_opt = tflib.Optimizer(name='RegD', share=D_opt, **D_opt_args)
    if AE_opt_args is not None:
        AE_opt_args = dict(AE_opt_args)
        AE_opt_args['minibatch_multiplier'] = minibatch_multiplier
        AE_opt_args['learning_rate'] = lrate_in
        # NOTE(review): AE_opt is constructed but no gradients are ever
        # registered with it in this function.
        AE_opt = tflib.Optimizer(name='TrainAE', **AE_opt_args)

    # Build training graph for each GPU.
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):

            # Create GPU-specific shadow copies of G and D.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                sched = training_schedule(cur_nimg=int(resume_kimg * 1000),
                                          training_set=training_set,
                                          **sched_args)
                reals_var = tf.Variable(
                    name='reals',
                    trainable=False,
                    initial_value=tf.zeros([sched.minibatch_gpu] +
                                           training_set.shape))
                labels_var = tf.Variable(name='labels',
                                         trainable=False,
                                         initial_value=tf.zeros([
                                             sched.minibatch_gpu,
                                             training_set.label_size
                                         ]))
                reals_write, labels_write = training_set.get_minibatch_tf()
                reals_write, labels_write = process_reals(
                    reals_write, labels_write, lod_in, mirror_augment,
                    training_set.dynamic_range, drange_net)
                # Pad the fresh minibatch with the tail of the variable so
                # the assigned value keeps the variable's shape.
                reals_write = tf.concat(
                    [reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                labels_write = tf.concat(
                    [labels_write, labels_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_var, reals_write)]
                data_fetch_ops += [tf.assign(labels_var, labels_write)]
                reals_read = reals_var[:minibatch_gpu_in]
                labels_read = labels_var[:minibatch_gpu_in]

            # Evaluate loss functions.
            lod_assign_ops = []
            if 'lod' in G_gpu.vars:
                lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)]
            if 'lod' in D_gpu.vars:
                lod_assign_ops += [tf.assign(D_gpu.vars['lod'], lod_in)]
            with tf.control_dependencies(lod_assign_ops):
                with tf.name_scope('G_loss'):
                    # The L1 generator loss additionally receives the real
                    # images; every configuration must actually evaluate the
                    # loss so that G_loss/G_reg are defined below.
                    # NOTE(review): assumes training.loss.G_l1 accepts a
                    # `reals` keyword and returns (loss, reg) - confirm.
                    G_loss_kwargs = dict(G_loss_args)
                    if G_loss_kwargs.get('func_name') == 'training.loss.G_l1':
                        G_loss_kwargs['reals'] = reals_read
                    G_loss, G_reg = dnnlib.util.call_func_by_name(
                        G=G_gpu,
                        D=D_gpu,
                        opt=G_opt,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        **G_loss_kwargs)
                with tf.name_scope('D_loss'):
                    D_loss, D_reg = dnnlib.util.call_func_by_name(
                        G=G_gpu,
                        D=D_gpu,
                        opt=D_opt,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        reals=reals_read,
                        labels=labels_read,
                        **D_loss_args)

            # Register gradients.
            if not lazy_regularization:
                if G_reg is not None:
                    G_loss += G_reg
                if D_reg is not None:
                    D_loss += D_reg
            else:
                if G_reg is not None:
                    G_reg_opt.register_gradients(
                        tf.reduce_mean(G_reg * G_reg_interval),
                        G_gpu.trainables)
                if D_reg is not None:
                    D_reg_opt.register_gradients(
                        tf.reduce_mean(D_reg * D_reg_interval),
                        D_gpu.trainables)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()
    G_reg_op = G_reg_opt.apply_updates(allow_no_op=True)
    D_reg_op = D_reg_opt.apply_updates(allow_no_op=True)
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('',
                                   cur_epoch=resume_kimg,
                                   max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0
    # Most recent generator loss value, for the debug print at snapshot time.
    _g_loss = None
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.  G and D share the single lrate_in placeholder,
        # which is fed with the generator learning rate.
        feed_dict = {
            lod_in: sched.lod,
            lrate_in: sched.G_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu
        }
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size,
                           sched.minibatch_gpu * num_gpus)
            run_G_reg = (lazy_regularization
                         and running_mb_counter % G_reg_interval == 0)
            run_D_reg = (lazy_regularization
                         and running_mb_counter % D_reg_interval == 0)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                # G_loss is fetched alongside the update so the snapshot
                # print has a value on this path as well.
                _g_loss, _, _ = tflib.run(
                    [G_loss, G_train_op, data_fetch_op], feed_dict)
                if run_G_reg:
                    tflib.run(G_reg_op, feed_dict)
                tflib.run([D_train_op, Gs_update_op], feed_dict)
                if run_D_reg:
                    tflib.run(D_reg_op, feed_dict)

            # Slow path with gradient accumulation.
            else:
                for _round in rounds:
                    _g_loss, _ = tflib.run([G_loss, G_train_op], feed_dict)
                if run_G_reg:
                    for _round in rounds:
                        tflib.run(G_reg_op, feed_dict)
                tflib.run(Gs_update_op, feed_dict)
                for _round in rounds:
                    tflib.run(data_fetch_op, feed_dict)
                    tflib.run(D_train_op, feed_dict)
                if run_D_reg:
                    for _round in rounds:
                        tflib.run(D_reg_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start(
            ) + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch_size),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (
                    cur_tick % image_snapshot_ticks == 0 or done):
                if _g_loss is not None:
                    print('g loss', _g_loss)
                # Pass rho exactly as in the initial snapshot so the number
                # of inputs given to Gs stays the same.
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    rho,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch_gpu)
                misc.save_image_grid(grid_fakes,
                                     dnnlib.make_run_dir_path(
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
            # Parenthesized so that network_snapshot_ticks=None disables
            # periodic snapshots on the final tick too.
            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' %
                                               (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl,
                            run_dir=dnnlib.make_run_dir_path(),
                            data_dir=dnnlib.convert_path(eval_data_dir),
                            num_gpus=num_gpus,
                            tf_config=tf_config,
                            rho=rho)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % sched.lod,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get(
            ).get_last_update_interval() - tick_time

    # Save final snapshot.
    misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
    eval_set.close()
def training_loop_vc( G_args={}, # Options for generator network. D_args={}, # Options for discriminator network. I_args={}, # Options for infogan-head/vcgan-head network. I_info_args={}, # Options for infogan-head/vcgan-head network. G_opt_args={}, # Options for generator optimizer. D_opt_args={}, # Options for discriminator optimizer. G_loss_args={}, # Options for generator loss. D_loss_args={}, # Options for discriminator loss. dataset_args={}, # Options for dataset.load_dataset(). sched_args={}, # Options for train.TrainingSchedule. grid_args={}, # Options for train.setup_snapshot_image_grid(). metric_arg_list=[], # Options for MetricGroup. tf_config={}, # Options for tflib.init_tf(). use_info_gan=False, # Whether to use info-gan. use_vc_head=False, # Whether to use vc-head. use_vc_head_with_cls=False, # Whether to use classification in discriminator. data_dir=None, # Directory to load datasets from. G_smoothing_kimg=10.0, # Half-life of the running average of generator weights. minibatch_repeats=4, # Number of minibatches to run before adjusting training parameters. lazy_regularization=True, # Perform regularization as a separate training step? G_reg_interval=4, # How often the perform regularization for G? Ignored if lazy_regularization=False. D_reg_interval=16, # How often the perform regularization for D? Ignored if lazy_regularization=False. reset_opt_for_new_lod=True, # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced? total_kimg=25000, # Total length of the training, measured in thousands of real images. mirror_augment=False, # Enable mirror augment? drange_net=[ -1, 1 ], # Dynamic range used when feeding image data to the networks. image_snapshot_ticks=50, # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'. network_snapshot_ticks=50, # How often to save network snapshots? None = only save 'networks-final.pkl'. 
save_tf_graph=False, # Include full TensorFlow computation graph in the tfevents file? save_weight_histograms=False, # Include weight histograms in the tfevents file? resume_pkl=None, # Network pickle to resume training from, None = train from scratch. resume_kimg=0.0, # Assumed training progress at the beginning. Affects reporting and training schedule. resume_time=0.0, # Assumed wallclock time at the beginning. Affects reporting. resume_with_new_nets=False, # Construct new networks according to G_args and D_args before resuming training? traversal_grid=False, # Used for disentangled representation learning. n_discrete=3, # Number of discrete latents in model. n_continuous=4, # Number of continuous latents in model. n_samples_per=10): # Number of samples for each line in traversal. # Initialize dnnlib and TensorFlow. tflib.init_tf(tf_config) num_gpus = dnnlib.submit_config.num_gpus # Load training set. training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir), verbose=True, **dataset_args) grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid( training_set, **grid_args) misc.save_image_grid(grid_reals, dnnlib.make_run_dir_path('reals.png'), drange=training_set.dynamic_range, grid_size=grid_size) # Construct or load networks. 
with tf.device('/gpu:0'): if resume_pkl is None or resume_with_new_nets: print('Constructing networks...') G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args) D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args) if use_info_gan or use_vc_head or use_vc_head_with_cls: I = tflib.Network('I', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **I_args) if use_vc_head_with_cls: I_info = tflib.Network('I_info', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **I_info_args) Gs = G.clone('Gs') if resume_pkl is not None: print('Loading networks from "%s"...' % resume_pkl) if use_info_gan or use_vc_head: rG, rD, rI, rGs = misc.load_pkl(resume_pkl) elif use_vc_head_with_cls: rG, rD, rI, rI_info, rGs = misc.load_pkl(resume_pkl) else: rG, rD, rGs = misc.load_pkl(resume_pkl) if resume_with_new_nets: G.copy_vars_from(rG) D.copy_vars_from(rD) if use_info_gan or use_vc_head or use_vc_head_with_cls: I.copy_vars_from(rI) if use_vc_head_with_cls: I_info.copy_vars_from(rI_info) Gs.copy_vars_from(rGs) else: G = rG D = rD if use_info_gan or use_vc_head or use_vc_head_with_cls: I = rI if use_vc_head_with_cls: I_info = rI_info Gs = rGs # Print layers and generate initial image snapshot. 
G.print_layers() D.print_layers() if use_info_gan or use_vc_head or use_vc_head_with_cls: I.print_layers() if use_vc_head_with_cls: I_info.print_layers() # pdb.set_trace() sched = training_schedule(cur_nimg=total_kimg * 1000, training_set=training_set, **sched_args) if traversal_grid: grid_size, grid_latents, grid_labels = get_grid_latents( n_discrete, n_continuous, n_samples_per, G, grid_labels) else: grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:]) print('grid_latents.shape:', grid_latents.shape) print('grid_labels.shape:', grid_labels.shape) # pdb.set_trace() grid_fakes, _ = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu, randomize_noise=False) misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes_init.png'), drange=drange_net, grid_size=grid_size) # Setup training inputs. print('Building TensorFlow graph...') with tf.name_scope('Inputs'), tf.device('/cpu:0'): lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[]) lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[]) minibatch_size_in = tf.placeholder(tf.int32, name='minibatch_size_in', shape=[]) minibatch_gpu_in = tf.placeholder(tf.int32, name='minibatch_gpu_in', shape=[]) minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in * num_gpus) Gs_beta = 0.5**tf.div(tf.cast(minibatch_size_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0 # Setup optimizers. 
G_opt_args = dict(G_opt_args) D_opt_args = dict(D_opt_args) for args, reg_interval in [(G_opt_args, G_reg_interval), (D_opt_args, D_reg_interval)]: args['minibatch_multiplier'] = minibatch_multiplier args['learning_rate'] = lrate_in if lazy_regularization: mb_ratio = reg_interval / (reg_interval + 1) args['learning_rate'] *= mb_ratio if 'beta1' in args: args['beta1'] **= mb_ratio if 'beta2' in args: args['beta2'] **= mb_ratio G_opt = tflib.Optimizer(name='TrainG', **G_opt_args) D_opt = tflib.Optimizer(name='TrainD', **D_opt_args) G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args) D_reg_opt = tflib.Optimizer(name='RegD', share=D_opt, **D_opt_args) # Build training graph for each GPU. data_fetch_ops = [] for gpu in range(num_gpus): with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu): # Create GPU-specific shadow copies of G and D. G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow') D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow') if use_info_gan or use_vc_head or use_vc_head_with_cls: I_gpu = I if gpu == 0 else I.clone(I.name + '_shadow') if use_vc_head_with_cls: I_info_gpu = I_info if gpu == 0 else I_info.clone( I_info.name + '_shadow') # Fetch training data via temporary variables. 
with tf.name_scope('DataFetch'): sched = training_schedule(cur_nimg=int(resume_kimg * 1000), training_set=training_set, **sched_args) reals_var = tf.Variable( name='reals', trainable=False, initial_value=tf.zeros([sched.minibatch_gpu] + training_set.shape)) labels_var = tf.Variable(name='labels', trainable=False, initial_value=tf.zeros([ sched.minibatch_gpu, training_set.label_size ])) reals_write, labels_write = training_set.get_minibatch_tf() reals_write, labels_write = process_reals( reals_write, labels_write, lod_in, mirror_augment, training_set.dynamic_range, drange_net) reals_write = tf.concat( [reals_write, reals_var[minibatch_gpu_in:]], axis=0) labels_write = tf.concat( [labels_write, labels_var[minibatch_gpu_in:]], axis=0) data_fetch_ops += [tf.assign(reals_var, reals_write)] data_fetch_ops += [tf.assign(labels_var, labels_write)] reals_read = reals_var[:minibatch_gpu_in] labels_read = labels_var[:minibatch_gpu_in] # Evaluate loss functions. lod_assign_ops = [] if 'lod' in G_gpu.vars: lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)] if 'lod' in D_gpu.vars: lod_assign_ops += [tf.assign(D_gpu.vars['lod'], lod_in)] with tf.control_dependencies(lod_assign_ops): with tf.name_scope('G_loss'): if use_info_gan or use_vc_head: G_loss, G_reg, I_loss, _ = dnnlib.util.call_func_by_name( G=G_gpu, D=D_gpu, I=I_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, **G_loss_args) elif use_vc_head_with_cls: G_loss, G_reg, I_loss, I_info_loss = dnnlib.util.call_func_by_name( G=G_gpu, D=D_gpu, I=I_gpu, I_info=I_info_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, **G_loss_args) else: G_loss, G_reg = dnnlib.util.call_func_by_name( G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, **G_loss_args) with tf.name_scope('D_loss'): D_loss, D_reg = dnnlib.util.call_func_by_name( G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, reals=reals_read, 
labels=labels_read, **D_loss_args) # Register gradients. if not lazy_regularization: if G_reg is not None: G_loss += G_reg if D_reg is not None: D_loss += D_reg else: if G_reg is not None: G_reg_opt.register_gradients( tf.reduce_mean(G_reg * G_reg_interval), G_gpu.trainables) if D_reg is not None: D_reg_opt.register_gradients( tf.reduce_mean(D_reg * D_reg_interval), D_gpu.trainables) # print('G_gpu.trainables:', G_gpu.trainables) # print('D_gpu.trainables:', D_gpu.trainables) # print('I_gpu.trainables:', I_gpu.trainables) if use_info_gan or use_vc_head: GI_gpu_trainables = collections.OrderedDict( list(G_gpu.trainables.items()) + list(I_gpu.trainables.items())) G_opt.register_gradients(tf.reduce_mean(G_loss + I_loss), GI_gpu_trainables) D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables) # G_opt.register_gradients(tf.reduce_mean(I_loss), # GI_gpu_trainables) # D_opt.register_gradients(tf.reduce_mean(I_loss), # D_gpu.trainables) elif use_vc_head_with_cls: GIIinfo_gpu_trainables = collections.OrderedDict( list(G_gpu.trainables.items()) + list(I_gpu.trainables.items()) + list(I_info_gpu.trainables.items())) G_opt.register_gradients( tf.reduce_mean(G_loss + I_loss + I_info_loss), GIIinfo_gpu_trainables) D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables) else: G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables) D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables) # if use_info_gan: # # INFO-GAN-HEAD loss # G_opt.register_gradients(tf.reduce_mean(I_loss), # G_gpu.trainables) # G_opt.register_gradients(tf.reduce_mean(I_loss), # I_gpu.trainables) # D_opt.register_gradients(tf.reduce_mean(I_loss), # D_gpu.trainables) # Setup training ops. 
data_fetch_op = tf.group(*data_fetch_ops) G_train_op = G_opt.apply_updates() D_train_op = D_opt.apply_updates() G_reg_op = G_reg_opt.apply_updates(allow_no_op=True) D_reg_op = D_reg_opt.apply_updates(allow_no_op=True) Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta) # Finalize graph. with tf.device('/gpu:0'): try: peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse() except tf.errors.NotFoundError: peak_gpu_mem_op = tf.constant(0) tflib.init_uninitialized_vars() print('Initializing logs...') summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path()) if save_tf_graph: summary_log.add_graph(tf.get_default_graph()) if save_weight_histograms: G.setup_weight_histograms() D.setup_weight_histograms() if use_info_gan or use_vc_head or use_vc_head_with_cls: I.setup_weight_histograms() if use_vc_head_with_cls: I_info.setup_weight_histograms() metrics = metric_base.MetricGroup(metric_arg_list) print('Training for %d kimg...\n' % total_kimg) dnnlib.RunContext.get().update('', cur_epoch=resume_kimg, max_epoch=total_kimg) maintenance_time = dnnlib.RunContext.get().get_last_update_interval() cur_nimg = int(resume_kimg * 1000) cur_tick = -1 tick_start_nimg = cur_nimg prev_lod = -1.0 running_mb_counter = 0 while cur_nimg < total_kimg * 1000: if dnnlib.RunContext.get().should_stop(): break # Choose training parameters and configure training ops. sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, **sched_args) assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0 training_set.configure(sched.minibatch_gpu, sched.lod) if reset_opt_for_new_lod: if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil( sched.lod) != np.ceil(prev_lod): G_opt.reset_optimizer_state() D_opt.reset_optimizer_state() prev_lod = sched.lod # Run training ops. 
feed_dict = { lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_size_in: sched.minibatch_size, minibatch_gpu_in: sched.minibatch_gpu } for _repeat in range(minibatch_repeats): rounds = range(0, sched.minibatch_size, sched.minibatch_gpu * num_gpus) run_G_reg = (lazy_regularization and running_mb_counter % G_reg_interval == 0) run_D_reg = (lazy_regularization and running_mb_counter % D_reg_interval == 0) cur_nimg += sched.minibatch_size running_mb_counter += 1 # Fast path without gradient accumulation. if len(rounds) == 1: tflib.run([G_train_op, data_fetch_op], feed_dict) if run_G_reg: tflib.run(G_reg_op, feed_dict) tflib.run([D_train_op, Gs_update_op], feed_dict) if run_D_reg: tflib.run(D_reg_op, feed_dict) # Slow path with gradient accumulation. else: for _round in rounds: tflib.run(G_train_op, feed_dict) if run_G_reg: for _round in rounds: tflib.run(G_reg_op, feed_dict) tflib.run(Gs_update_op, feed_dict) for _round in rounds: tflib.run(data_fetch_op, feed_dict) tflib.run(D_train_op, feed_dict) if run_D_reg: for _round in rounds: tflib.run(D_reg_op, feed_dict) # Perform maintenance tasks once per tick. done = (cur_nimg >= total_kimg * 1000) if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done: cur_tick += 1 tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0 tick_start_nimg = cur_nimg tick_time = dnnlib.RunContext.get().get_time_since_last_update() total_time = dnnlib.RunContext.get().get_time_since_start( ) + resume_time # Report progress. 
print( 'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f' % (autosummary('Progress/tick', cur_tick), autosummary('Progress/kimg', cur_nimg / 1000.0), autosummary('Progress/lod', sched.lod), autosummary('Progress/minibatch', sched.minibatch_size), dnnlib.util.format_time( autosummary('Timing/total_sec', total_time)), autosummary('Timing/sec_per_tick', tick_time), autosummary('Timing/sec_per_kimg', tick_time / tick_kimg), autosummary('Timing/maintenance_sec', maintenance_time), autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30))) autosummary('Timing/total_hours', total_time / (60.0 * 60.0)) autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0)) # Save snapshots. if image_snapshot_ticks is not None and ( cur_tick % image_snapshot_ticks == 0 or done): grid_fakes, _ = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu, randomize_noise=False) misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path( 'fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size) if network_snapshot_ticks is not None and ( cur_tick % network_snapshot_ticks == 0 or done): pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' % (cur_nimg // 1000)) if use_info_gan or use_vc_head: misc.save_pkl((G, D, I, Gs), pkl) elif use_vc_head_with_cls: misc.save_pkl((G, D, I, I_info, Gs), pkl) else: misc.save_pkl((G, D, Gs), pkl) metrics.run(pkl, run_dir=dnnlib.make_run_dir_path(), data_dir=dnnlib.convert_path(data_dir), num_gpus=num_gpus, tf_config=tf_config) # Update summaries and RunContext. metrics.update_autosummaries() tflib.autosummary.save_summaries(summary_log, cur_nimg) dnnlib.RunContext.get().update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg) maintenance_time = dnnlib.RunContext.get( ).get_last_update_interval() - tick_time # Save final snapshot. 
if use_info_gan or use_vc_head: misc.save_pkl((G, D, I, Gs), dnnlib.make_run_dir_path('network-final.pkl')) elif use_vc_head_with_cls: misc.save_pkl((G, D, I, I_info, Gs), dnnlib.make_run_dir_path('network-final.pkl')) else: misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path('network-final.pkl')) # All done. summary_log.close() training_set.close()
def training_loop_refinement(
        G_args={},  # Options for generator network.
        D_args={},  # Options for discriminator network. (Unused here.)
        G_opt_args={},  # Options for generator optimizer.
        D_opt_args={},  # Options for discriminator optimizer. (Unused here.)
        G_loss_args={},  # Options for generator loss.
        D_loss_args={},  # Options for discriminator loss. (Unused here.)
        dataset_args={},  # Options for dataset.load_dataset().
        sched_args={},  # Options for train.TrainingSchedule.
        grid_args={},  # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],  # Options for MetricGroup.
        tf_config={},  # Options for tflib.init_tf().
        data_dir=None,  # Directory to load datasets from.
        G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
        minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
        lazy_regularization=True,  # Perform regularization as a separate training step?
        G_reg_interval=4,  # How often the perform regularization for G? Ignored if lazy_regularization=False.
        D_reg_interval=16,  # How often the perform regularization for D? (Unused here.)
        reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg=25000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=False,  # Enable mirror augment?
        drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=50,  # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'.
        network_snapshot_ticks=50,  # How often to save network snapshots? None = only save 'networks-final.pkl'.
        save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=True,  # Include weight histograms in the tfevents file?
        resume_pkl=None,  # Network pickle to resume training from, None = train from scratch.
        resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
        resume_with_new_nets=False):  # Construct new networks according to G_args and D_args before resuming training?
    """Train only the generator (no discriminator) and track its loss.

    The loss function named in ``G_loss_args`` is called with ``D=None``, the
    fetched ``reals`` minibatch, and the dataset's label minibatch passed as
    ``latents``. ``Gs`` is maintained as a moving average of ``G``. Snapshots
    are written as ``(G, None, Gs)`` tuples.

    ``D_args``, ``D_opt_args``, ``D_loss_args`` and ``D_reg_interval`` are
    accepted for signature compatibility but are not read by this function.

    ``G_args`` only needs to support ``.get()`` and ``**`` unpacking; it is
    not modified. Layers listed under ``G_args['freeze_layers']`` are removed
    from the trainables of every per-GPU generator (substring match on the
    variable name).
    """
    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(
        data_dir=dnnlib.convert_path(data_dir), verbose=True, **dataset_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        training_set, **grid_args)
    misc.save_image_grid(grid_reals,
                         dnnlib.make_run_dir_path('reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)

    # Construct or load networks.
    with tf.device('/gpu:0'):
        if resume_pkl is None or resume_with_new_nets:
            print('Constructing networks...')
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **G_args)
            Gs = G.clone('Gs')
        if resume_pkl is not None:
            print('Loading networks from "%s"...' % resume_pkl)
            rG, _rD, rGs = misc.load_pkl(resume_pkl)
            del _rD  # The discriminator is never used by this loop.
            if resume_with_new_nets:
                # Both fresh networks start from the averaged generator.
                G.copy_vars_from(rGs)
                Gs.copy_vars_from(rGs)
                del rG, rGs
            else:
                # Keep a reference to the loaded G: the previous code deleted
                # it and then read an unbound `rG`, raising NameError.
                # NOTE(review): the new-nets branch seeds G from rGs; confirm
                # whether plain resume should do the same instead of using rG.
                G = rG
                Gs = rGs

    # Set constant noise input for both G and Gs. Each network gets its own
    # RandomState(123), so both receive the same sequence of draws.
    if G_args.get("randomize_noise", None) == False:  # noqa: E712
        for net in (G, Gs):
            noise_vars = [
                var for name, var in net.components.synthesis.vars.items()
                if name.startswith('noise')
            ]
            rnd = np.random.RandomState(123)
            tflib.set_vars(
                {var: rnd.randn(*var.shape.as_list())
                 for var in noise_vars})  # [height, width]

    # Print layers and generate initial image snapshot.
    G.print_layers()
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              **sched_args)
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch_gpu)
    misc.save_image_grid(grid_fakes,
                         dnnlib.make_run_dir_path('fakes_init.png'),
                         drange=drange_net,
                         grid_size=grid_size)

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32,
                                           name='minibatch_size_in',
                                           shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32,
                                          name='minibatch_gpu_in',
                                          shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in *
                                                     num_gpus)
        if G_smoothing_kimg > 0.0:
            Gs_beta = 0.5**tf.div(tf.cast(minibatch_size_in, tf.float32),
                                  G_smoothing_kimg * 1000.0)
        else:
            Gs_beta = 0.0

    # Setup optimizers. Work on a copy so the caller's dict is not modified.
    G_opt_args = dict(G_opt_args)
    G_opt_args['minibatch_multiplier'] = minibatch_multiplier
    G_opt_args['learning_rate'] = lrate_in
    if lazy_regularization:
        # Rescale the learning rate and Adam betas by the ratio
        # interval / (interval + 1) when regularization runs separately.
        mb_ratio = G_reg_interval / (G_reg_interval + 1)
        G_opt_args['learning_rate'] *= mb_ratio
        if 'beta1' in G_opt_args:
            G_opt_args['beta1'] **= mb_ratio
        if 'beta2' in G_opt_args:
            G_opt_args['beta2'] **= mb_ratio
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args)

    # Freeze layers. Kept in a local: assigning an attribute on G_args fails
    # for a plain dict (including the default) and would mutate the argument.
    freeze_layers = list(G_args.get("freeze_layers", []))

    def freeze_vars(gen, verbose=True):
        """Drop every trainable of `gen` whose name contains a frozen layer."""
        for name in list(gen.trainables.keys()):
            if any(layer in name for layer in freeze_layers):
                del gen.trainables[name]
                if verbose:
                    print(f"Freezed {name}")

    # Build training graph for each GPU.
    data_fetch_ops = []
    loss_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):

            # Create GPU-specific shadow copy of G.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            if freeze_layers:
                freeze_vars(G_gpu, verbose=False)

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                sched = training_schedule(cur_nimg=int(resume_kimg * 1000),
                                          training_set=training_set,
                                          **sched_args)
                reals_var = tf.Variable(
                    name='reals',
                    trainable=False,
                    initial_value=tf.zeros([sched.minibatch_gpu] +
                                           training_set.shape))
                labels_var = tf.Variable(
                    name='labels',
                    trainable=False,
                    initial_value=tf.zeros(
                        [sched.minibatch_gpu, training_set.label_size]))
                reals_write, labels_write = training_set.get_minibatch_tf()
                reals_write, labels_write = process_reals(
                    reals_write, labels_write, lod_in, mirror_augment,
                    training_set.dynamic_range, drange_net)
                # Pad with the tail of the existing variable so the assigned
                # value keeps the variable's full first dimension.
                reals_write = tf.concat(
                    [reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                labels_write = tf.concat(
                    [labels_write, labels_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_var, reals_write)]
                data_fetch_ops += [tf.assign(labels_var, labels_write)]
                reals_read = reals_var[:minibatch_gpu_in]
                labels_read = labels_var[:minibatch_gpu_in]

            # Evaluate loss functions.
            lod_assign_ops = []
            if 'lod' in G_gpu.vars:
                lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)]
            with tf.control_dependencies(lod_assign_ops):
                with tf.name_scope('G_loss'):
                    G_loss, G_reg = dnnlib.util.call_func_by_name(
                        G=G_gpu,
                        D=None,
                        opt=G_opt,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        reals=reals_read,
                        latents=labels_read,
                        **G_loss_args)
                    # Recorded before any regularizer is folded in below.
                    loss_ops.append(G_loss)

            # Register gradients.
            if G_reg is not None:
                if lazy_regularization:
                    G_reg_opt.register_gradients(
                        tf.reduce_mean(G_reg * G_reg_interval),
                        G_gpu.trainables)
                else:
                    G_loss += G_reg
            G_opt.register_gradients(tf.reduce_mean(G_loss),
                                     G_gpu.trainables)

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    loss_op = tf.reduce_mean(tf.concat(loss_ops, axis=0))
    G_train_op = G_opt.apply_updates()
    G_reg_op = G_reg_opt.apply_updates(allow_no_op=True)
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('',
                                   cur_epoch=resume_kimg,
                                   max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0
    loss_per_batch_sum = 0
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu, sched.lod)
        if reset_opt_for_new_lod:
            lod_floor_changed = np.floor(sched.lod) != np.floor(prev_lod)
            lod_ceil_changed = np.ceil(sched.lod) != np.ceil(prev_lod)
            if lod_floor_changed or lod_ceil_changed:
                G_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        feed_dict = {
            lod_in: sched.lod,
            lrate_in: sched.G_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu
        }
        # Fill the data variables before the first loss evaluation.
        tflib.run(data_fetch_op, feed_dict)
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size,
                           sched.minibatch_gpu * num_gpus)
            run_G_reg = (lazy_regularization
                         and running_mb_counter % G_reg_interval == 0)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                loss, _ = tflib.run([loss_op, G_train_op], feed_dict)
                tflib.run([data_fetch_op], feed_dict)
                loss_per_batch_sum += loss
                if run_G_reg:
                    tflib.run(G_reg_op, feed_dict)
                tflib.run([Gs_update_op], feed_dict)

            # Slow path with gradient accumulation. FIXME: Probably wrong
            # NOTE(review): unlike the fast path, the train step and the data
            # fetch share one session call, and G_reg_op runs once rather
            # than once per round. Behaviour is kept as-is; please confirm.
            else:
                for _round in rounds:
                    loss, _, _ = tflib.run(
                        [loss_op, G_train_op, data_fetch_op], feed_dict)
                    loss_per_batch_sum += loss / len(rounds)
                if run_G_reg:
                    tflib.run(G_reg_op, feed_dict)
                tflib.run(Gs_update_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if (cur_tick < 0
                or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000
                or done):
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = (dnnlib.RunContext.get().get_time_since_start() +
                          resume_time)
            # Average of the accumulated per-minibatch losses over the tick.
            tick_loss = loss_per_batch_sum * sched.minibatch_size / (
                tick_kimg * 1000)
            loss_per_batch_sum = 0

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d '
                'loss/px %-12.8f time %-12s sec/tick %-7.1f '
                'sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f' %
                (autosummary('Progress/tick', cur_tick),
                 autosummary('Progress/kimg', cur_nimg / 1000.0),
                 autosummary('Progress/lod', sched.lod),
                 autosummary('Progress/minibatch', sched.minibatch_size),
                 autosummary('Progress/loss_per_px', tick_loss),
                 dnnlib.util.format_time(
                     autosummary('Timing/total_sec', total_time)),
                 autosummary('Timing/sec_per_tick', tick_time),
                 autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                 autosummary('Timing/maintenance_sec', maintenance_time),
                 autosummary('Resources/peak_gpu_mem_gb',
                             peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days',
                        total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (
                    cur_tick % image_snapshot_ticks == 0 or done):
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch_gpu)
                misc.save_image_grid(grid_fakes,
                                     dnnlib.make_run_dir_path(
                                         'fakes%06d.png' %
                                         (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' %
                                               (cur_nimg // 1000))
                # None stands in for the discriminator slot of the tuple.
                misc.save_pkl((G, None, Gs), pkl)
                metrics.run(pkl,
                            run_dir=dnnlib.make_run_dir_path(),
                            data_dir=dnnlib.convert_path(data_dir),
                            num_gpus=num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % sched.lod,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            maintenance_time = (
                dnnlib.RunContext.get().get_last_update_interval() -
                tick_time)

    # Save final snapshot.
    misc.save_pkl((G, None, Gs),
                  dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
with tf.name_scope('D_loss'), tf.control_dependencies(lod_assign_ops): D_loss = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_split, reals=reals, labels=labels, **D_loss_args) G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables) D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables) G_train_op = G_opt.apply_updates() D_train_op = D_opt.apply_updates() Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta) with tf.device('/gpu:0'): try: peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse() except tf.errors.NotFoundError: peak_gpu_mem_op = tf.constant(0) print('Setting up snapshot image grid...') grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(G, training_set, **grid_args) sched = training_schedule(cur_nimg=total_kimg*1000, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args) grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch//submit_config.num_gpus) print('Setting up run dir...') misc.save_image_grid(grid_reals, os.path.join(submit_config.run_dir, 'reals.png'), drange=training_set.dynamic_range, grid_size=grid_size) misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg), drange=drange_net, grid_size=grid_size) summary_log = tf.summary.FileWriter(submit_config.run_dir) if save_tf_graph: summary_log.add_graph(tf.get_default_graph()) if save_weight_histograms: G.setup_weight_histograms(); D.setup_weight_histograms() metrics = metric_base.MetricGroup(metric_arg_list) print('Training...\n') ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
def training_loop(
    G_args                  = {},       # Options for generator network.
    D_args                  = {},       # Options for discriminator network.
    G_opt_args              = {},       # Options for generator optimizer.
    D_opt_args              = {},       # Options for discriminator optimizer.
    G_loss_args             = {},       # Options for generator loss. (Unused here.)
    D_loss_args             = {},       # Options for the loss function that is called.
    dataset_args            = {},       # Options for dataset.load_dataset().
    sched_args              = {},       # Options for train.TrainingSchedule.
    grid_args               = {},       # Options for train.setup_snapshot_image_grid().
    metric_arg_list         = [],       # Options for MetricGroup.
    tf_config               = {},       # Options for tflib.init_tf().
    data_dir                = None,     # Directory to load datasets from.
    G_smoothing_kimg        = 10.0,     # Half-life of the running average of generator weights. (Unused here.)
    minibatch_repeats       = 4,        # Number of minibatches to run before adjusting training parameters. (Unused here.)
    lazy_regularization     = True,     # Perform regularization as a separate training step? (Unused here.)
    G_reg_interval          = 4,        # How often the perform regularization for G? (Unused here.)
    D_reg_interval          = 16,       # How often the perform regularization for D? (Unused here.)
    reset_opt_for_new_lod   = True,     # Reset optimizer internal state when new layers are introduced? (Unused here.)
    total_kimg              = 25000,    # Total length of the training, measured in thousands of real images.
    mirror_augment          = False,    # Enable mirror augment?
    drange_net              = [0,1],    # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks    = 50,       # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'.
    network_snapshot_ticks  = 50,       # How often to save network snapshots? None = only save 'networks-final.pkl'.
    save_tf_graph           = False,    # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms  = False,    # Include weight histograms in the tfevents file?
    resume_pkl              = None,     # Network pickle to resume training from, None = train from scratch.
    resume_kimg             = 0.0,      # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time             = 0.0,      # Assumed wallclock time at the beginning. Affects reporting.
    resume_with_new_nets    = False):   # Construct new networks according to G_args and D_args before resuming training?
    """Train G and D on one shared loss with an exponentially decayed LR.

    A single loss tensor is obtained from the function named in
    ``D_loss_args`` and its gradients are registered with both the G and the
    D optimizer. The learning rate comes from ``tf.train.exponential_decay``
    driven by a step variable that is incremented as a control dependency of
    gradient registration.

    ``sched_args`` is read through attribute access and must provide
    ``batch_size``, ``lr``, ``decay_step``, ``decay_rate``, ``stair`` and
    ``tick_kimg`` (a plain dict, including the default, will not work).

    When resuming, the initial step is derived from the number embedded in
    the pickle's file name (``...-<number>.pkl``) divided by the batch size.
    If no number can be parsed (e.g. ``network-final.pkl``), ``resume_kimg``
    is used in its place.

    NOTE(review): ``Gs`` is cloned/loaded and saved, but this function builds
    no op that updates it; image snapshots are rendered with ``G``.
    """
    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir), verbose=True, **dataset_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(training_set, **grid_args)
    misc.save_image_grid(grid_reals, dnnlib.make_run_dir_path('reals.png'),
                         drange=training_set.dynamic_range, grid_size=grid_size)

    # Construct or load networks.
    training_set.configure(minibatch_size=sched_args.batch_size)
    with tf.device('/gpu:0'):
        if resume_pkl is None or resume_with_new_nets:
            print('Constructing networks...')
            G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1],
                              label_size=training_set.label_size, **G_args)
            D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1],
                              label_size=training_set.label_size, **D_args)
            Gs = G.clone('Gs')
        start = 0
        if resume_pkl is not None:
            print('Loading networks from "%s"...' % resume_pkl)
            rG, rD, rGs = misc.load_pkl(resume_pkl)
            if resume_with_new_nets:
                G.copy_vars_from(rG)
                D.copy_vars_from(rD)
                Gs.copy_vars_from(rGs)
            else:
                G = rG
                D = rD
                Gs = rGs
            # Recover the decay step from the snapshot file name. Names
            # without a trailing number (such as the 'network-final.pkl'
            # written below) used to raise ValueError here.
            try:
                snapshot_kimg = int(resume_pkl.split('-')[-1].split('.')[0])
            except ValueError:
                snapshot_kimg = int(resume_kimg)
                print('Could not parse a step from "%s"; using resume_kimg=%d instead.'
                      % (resume_pkl, snapshot_kimg))
            # NOTE(review): snapshot names encode cur_nimg // 1000, while one
            # step consumes batch_size * num_gpus images; confirm the units.
            start = snapshot_kimg // sched_args.batch_size

    # Print layers and generate initial image snapshot.
    G.print_layers()
    D.print_layers()
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    grid_fakes = G.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched_args.batch_size)
    misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes_init.png'),
                         drange=drange_net, grid_size=grid_size)

    # Learning-rate schedule and optimizers.
    global_step = tf.Variable(start, trainable=False, name='learning_rate_step')
    learning_rate = tf.train.exponential_decay(sched_args.lr, global_step, sched_args.decay_step,
                                               sched_args.decay_rate, staircase=sched_args.stair)
    add_global = global_step.assign_add(1)
    D_opt = tflib.Optimizer(name='TrainD', learning_rate=learning_rate, **D_opt_args)
    G_opt = tflib.Optimizer(name='TrainG', learning_rate=learning_rate, **G_opt_args)

    # Build training graph for each GPU.
    for gpu in range(num_gpus):
        print('build graph on gpu %s' % str(gpu))
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):

            # Create GPU-specific shadow copies of G and D.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')

            with tf.name_scope('DataFetch'):
                reals_read, labels_read = training_set.get_minibatch_tf()
                reals_read, labels_read = process_reals(reals_read, labels_read, mirror_augment,
                                                        training_set.dynamic_range, drange_net)

            with tf.name_scope('Loss'), tf.control_dependencies(None):
                # The second return value is not used by this loop.
                loss, _reg = dnnlib.util.call_func_by_name(
                    G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set,
                    minibatch_size=sched_args.batch_size,
                    reals=reals_read, labels=labels_read, **D_loss_args)

            # Tie the step increment to gradient registration.
            with tf.control_dependencies([add_global]):
                G_opt.register_gradients(loss, G_gpu.trainables)
                D_opt.register_gradients(loss, D_gpu.trainables)

    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1
    tick_start_nimg = cur_nimg
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop():
            break

        # One step. `loss` is the tensor from the last GPU built in the loop
        # above, so the reported value reflects only that GPU.
        loss_, _, _, lr_ = tflib.run([loss, G_train_op, D_train_op, learning_rate])
        cur_nimg += sched_args.batch_size * num_gpus

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched_args.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start() + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f minibatch %-4d time %-12s sec/tick %-7.1f '
                'sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f loss %-8.1f lr %-2.5f' % (
                    autosummary('Progress/tick', cur_tick),
                    autosummary('Progress/kimg', cur_nimg / 1000.0),
                    autosummary('Progress/minibatch', sched_args.batch_size),
                    dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                    autosummary('Timing/sec_per_tick', tick_time),
                    autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                    autosummary('Timing/maintenance_sec', maintenance_time),
                    autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2 ** 30),
                    autosummary('loss', loss_),
                    autosummary('lr', lr_)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (cur_tick % image_snapshot_ticks == 0 or done):
                grid_fakes = G.run(grid_latents, grid_labels, is_validation=True,
                                   minibatch_size=sched_args.batch_size)
                misc.save_image_grid(grid_fakes,
                                     dnnlib.make_run_dir_path('fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net, grid_size=grid_size)
            if network_snapshot_ticks is not None and (cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl, run_dir=dnnlib.make_run_dir_path(),
                            data_dir=dnnlib.convert_path(data_dir),
                            num_gpus=num_gpus, tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % 0.0, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get().get_last_update_interval() - tick_time

    # Save final snapshot.
    misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
def joint_train(
    submit_config,
    opt,
    metric_arg_list,
    sched_args              = {},       # Options for the training schedule.
    grid_args               = {},       # Options for setup_snapshot_image_grid().
    dataset_args            = {},       # Options for the dataset.
    total_kimg              = 15000,    # Total length of the training, measured in thousands of real images.
    drange_net              = [-1,1],   # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks    = 1,        # How often to export image snapshots?
    network_snapshot_ticks  = 10,       # How often to export network snapshots?
    D_repeats               = 1,        # How many times to train D per G iteration (unused in this loop).
    minibatch_repeats       = 4,        # Number of minibatches to run before adjusting training parameters.
    mirror_augment          = False,    # Enable mirror augment? (unused in this loop)
    reset_opt_for_new_lod   = True,     # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    save_tf_graph           = False,    # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms  = False,    # Include weight histograms in the tfevents file?
    resume_run_id           = None,     # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot         = None,     # Snapshot index to resume training from, None = autodetect.
    resume_kimg             = 0.0,      # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time             = 0.0,      # Assumed wallclock time at the beginning. Affects reporting.
    *args,
    **kwargs
    ):
    """Jointly train the model selected by ``opt.model`` / ``opt.transform``.

    The model object ``g`` is built from ``graphs.find_model_using_name`` and
    receives ``submit_config``, ``dataset_args``, the graph kwargs derived from
    ``opt`` and every extra keyword argument.  Each step runs ``g.joint_loss``,
    ``g.train_step``, ``g.Gs_update_op`` and ``g.G_train_op`` in ``g.sess``.
    Once per tick the loop reports progress, writes image / network snapshots
    into ``submit_config.run_dir`` and model checkpoints into
    ``opt.output_dir``.  At the end it stores the final networks, the recorded
    loss values (``loss_values.npy``) and a loss plot (``loss_values.png``).

    ``tf_config`` for the metric runs is read from ``kwargs`` (default ``{}``).
    """
    output_dir = opt.output_dir
    graph_kwargs = util.set_graph_kwargs(opt)
    graph_util = importlib.import_module('graphs.' + opt.model + '.graph_util')
    constants = importlib.import_module('graphs.' + opt.model + '.constants')
    model = graphs.find_model_using_name(opt.model, opt.transform)
    g = model(submit_config=submit_config, dataset_args=dataset_args, **graph_kwargs, **kwargs)
    g.initialize_graph()

    # The original body referenced an undefined `tf_config` when running the
    # metrics; take it from the extra keyword arguments instead.  `kwargs` is
    # left untouched, so the model still receives exactly what it did before.
    tf_config = kwargs.get('tf_config', {})

    w_snapshot_ticks = opt.model_save_freq

    ctx = dnnlib.RunContext(submit_config, train)
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args)

    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)

    # Set up the snapshot image grid.
    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(g.G, training_set, **grid_args)
    sched = training_loop.training_schedule(cur_nimg=total_kimg*1000, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args)
    grid_fakes = g.Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch//submit_config.num_gpus)

    # Set up the run directory.
    print('Setting up run dir...')
    misc.save_image_grid(grid_reals, os.path.join(submit_config.run_dir, 'reals.png'), drange=training_set.dynamic_range, grid_size=grid_size)
    misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg), drange=drange_net, grid_size=grid_size)
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        g.G.setup_weight_histograms()
        g.D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    # Train.
    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    loss_values = []
    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_loop.training_schedule(cur_nimg=cur_nimg, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod):
                # Only the generator optimizer is reset here.
                g.G_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            alpha_for_graph, alpha_for_target = g.get_train_alpha(constants.BATCH_SIZE)
            if not isinstance(alpha_for_graph, list):
                alpha_for_graph = [alpha_for_graph]
                alpha_for_target = [alpha_for_target]
            for ag, at in zip(alpha_for_graph, alpha_for_target):
                # NOTE(review): the input is drawn with a fixed seed=0 on every
                # step -- confirm that graph_input is meant to be re-seeded
                # identically each time.
                feed_dict = graph_util.graph_input(g, constants.BATCH_SIZE, seed=0)
                out_zs = g.sess.run(g.outputs_orig, feed_dict)
                target_fn, mask_out = g.get_target_np(out_zs, at)
                feed_dict[g.alpha] = ag
                feed_dict[g.target] = target_fn
                feed_dict[g.mask] = mask_out
                feed_dict[g.lod_in] = sched.lod
                feed_dict[g.lrate_in] = sched.D_lrate
                feed_dict[g.minibatch_in] = sched.minibatch
                curr_loss, _, _, _ = g.sess.run(
                    [g.joint_loss, g.train_step, g.Gs_update_op, g.G_train_op],
                    feed_dict=feed_dict)
                loss_values.append(curr_loss)
                cur_nimg += sched.minibatch

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print('tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f' % (
                autosummary('Progress/tick', cur_tick),
                autosummary('Progress/kimg', cur_nimg / 1000.0),
                autosummary('Progress/lod', sched.lod),
                autosummary('Progress/minibatch', sched.minibatch),
                dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                autosummary('Timing/sec_per_tick', tick_time),
                autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                autosummary('Timing/maintenance_sec', maintenance_time),
                autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = g.Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch//submit_config.num_gpus)
                misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size)
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(submit_config.run_dir, 'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((g.G, g.D, g.Gs), pkl)
                metrics.run(pkl, run_dir=submit_config.run_dir, num_gpus=submit_config.num_gpus, tf_config=tf_config)
            if cur_tick % w_snapshot_ticks == 0 or done:
                g.saver.save(g.sess, './{}/model_{}.ckpt'.format(
                    output_dir, (cur_nimg // 1000)),
                    write_meta_graph=False, write_state=False)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Save final results.
    misc.save_pkl((g.G, g.D, g.Gs), os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()
    ctx.close()

    # Persist the loss curve both as raw values and as a plot.
    loss_values = np.array(loss_values)
    np.save('./{}/loss_values.npy'.format(output_dir), loss_values)
    f, ax = plt.subplots(figsize=(10, 4))
    ax.plot(loss_values)
    f.savefig('./{}/loss_values.png'.format(output_dir))
    plt.close(f)  # Release the figure so it does not linger in pyplot's registry.
def training_loop_hd(
        I_args={},  # Options for generator network.
        M_args={},  # Options for discriminator network.
        I_info_args={},  # Options for class network.
        I_opt_args={},  # Options for generator optimizer.
        I_loss_args={},  # Options for generator loss.
        resume_G_pkl=None,  # G network pickle to help training.
        dataset_args={},  # Options for dataset.load_dataset().
        sched_args={},  # Options for train.TrainingSchedule.
        grid_args={},  # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],  # Options for MetricGroup.
        tf_config={},  # Options for tflib.init_tf().
        data_dir=None,  # Directory to load datasets from.
        I_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
        minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
        lazy_regularization=True,  # Perform regularization as a separate training step?
        I_reg_interval=4,  # How often the perform regularization for G? Ignored if lazy_regularization=False.
        reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg=25000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=False,  # Enable mirror augment?
        drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=50,  # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'.
        network_snapshot_ticks=50,  # How often to save network snapshots? None = only save 'networks-final.pkl'.
        save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_pkl=None,  # Network pickle to resume training from, None = train from scratch.
        resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
        resume_with_new_nets=False,  # Construct new networks according to I_args and M_args before resuming training?
        traversal_grid=False,  # Used for disentangled representation learning.
        n_discrete=0,  # Number of discrete latents in model.
        n_continuous=10,  # Number of continuous latents in model.
        n_samples_per=4,  # Number of samples for each line in traversal.
        use_hd_with_cls=False,  # If use info_loss.
        resolution_manual=1024,  # Resolution of generated images.
        use_level_training=False,  # If use level training (hierarchical optimization strategy).
        level_I_kimg=1000,  # Number of kimg of tick for I_level training.
        use_std_in_m=False,  # If output prior std in M net.
        prior_latent_size=512,  # Prior latent size.
        use_hyperplane=False,  # If use hyperplane model.
        pretrained_type='with_stylegan2'):  # Pretrained type for G.
    """Train the I and M networks against a pretrained generator.

    The generator is loaded from ``resume_G_pkl``; ``pretrained_type`` selects
    how that pickle is unpacked ('with_stylegan2' -> (G, D, Gs) tuple,
    'with_cascadeVAE' -> a single network).  One optimizer pair (train / reg)
    is built per training level; with ``use_level_training`` the active level
    advances every ``level_I_kimg`` thousand images.  Once per tick the loop
    reports progress, renders traversal snapshots and saves ``(I, M, Is)``.

    Raises:
        ValueError: if ``pretrained_type`` is not one of the supported values.
    """
    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir),
                                        verbose=True,
                                        **dataset_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        training_set, **grid_args)
    misc.save_image_grid(grid_reals,
                         dnnlib.make_run_dir_path('reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)

    # Construct or load networks.
    with tf.device('/gpu:0'):
        if resume_pkl is None or resume_with_new_nets:
            print('Constructing networks...')
            I = tflib.Network('I',
                              num_channels=training_set.shape[0],
                              resolution=resolution_manual,
                              label_size=training_set.label_size,
                              **I_args)
            M = tflib.Network('M',
                              num_channels=training_set.shape[0],
                              resolution=resolution_manual,
                              label_size=training_set.label_size,
                              **M_args)
            Is = I.clone('Is')
            if use_hd_with_cls:
                I_info = tflib.Network('I_info',
                                       num_channels=training_set.shape[0],
                                       resolution=resolution_manual,
                                       label_size=training_set.label_size,
                                       **I_info_args)
        if resume_pkl is not None:
            print('Loading networks from "%s"...' % resume_pkl)
            # NOTE(review): snapshots written by this loop contain (I, M, Is)
            # only, so a pickle with four networks must come from elsewhere --
            # confirm before resuming with use_hd_with_cls=True.
            if use_hd_with_cls:
                rI, rM, rIs, rI_info = misc.load_pkl(resume_pkl)
            else:
                rI, rM, rIs = misc.load_pkl(resume_pkl)
            if resume_with_new_nets:
                I.copy_vars_from(rI)
                M.copy_vars_from(rM)
                Is.copy_vars_from(rIs)
                if use_hd_with_cls:
                    I_info.copy_vars_from(rI_info)
            else:
                I = rI
                M = rM
                Is = rIs
                if use_hd_with_cls:
                    I_info = rI_info

        print('Loading generator from "%s"...' % resume_G_pkl)
        if pretrained_type == 'with_stylegan2':
            G, _, Gs = misc.load_pkl(resume_G_pkl)
        elif pretrained_type == 'with_cascadeVAE':
            G = misc.load_pkl(resume_G_pkl)
            Gs = G
        else:
            # Fail with a clear message instead of a later NameError on G / Gs.
            raise ValueError('Unknown pretrained_type: %r' % (pretrained_type,))

    # Print layers and generate initial image snapshot.
    I.print_layers()
    M.print_layers()
    training_set_resolution_log2 = int(np.log2(resolution_manual))
    sched = training_schedule(
        cur_nimg=total_kimg * 1000,
        training_set_resolution_log2=training_set_resolution_log2,
        **sched_args)

    if not use_hyperplane:
        grid_size, grid_latents, grid_labels = get_grid_latents(
            n_discrete, n_continuous, n_samples_per, G, grid_labels)
        print('grid_size:', grid_size)
        print('grid_latents.shape:', grid_latents.shape)
        print('grid_labels.shape:', grid_labels.shape)
        if resolution_manual >= 256:
            # Keep only the first fifth of the grid for large resolutions.
            grid_size = (grid_size[0], grid_size[1] // 5)
            grid_latents = grid_latents[:grid_latents.shape[0] // 5]
            grid_labels = grid_labels[:grid_labels.shape[0] // 5]
        prior_traj_latents = M.run(grid_latents,
                                   is_validation=True,
                                   minibatch_size=sched.minibatch_gpu)
        if use_std_in_m:
            prior_traj_latents = prior_traj_latents[:, :prior_latent_size]
    else:
        grid_size = (n_samples_per, n_continuous)
        grid_labels = np.tile(grid_labels[:1],
                              (n_continuous * n_samples_per, 1))
        latent_dirs = get_latent_dirs(n_continuous)
        prior_traj_latents = get_prior_traj_by_dirs(latent_dirs, M,
                                                    n_samples_per,
                                                    prior_latent_size,
                                                    grid_labels, sched)
        if resolution_manual >= 256:
            grid_size = (grid_size[0], grid_size[1] // 5)
            prior_traj_latents = prior_traj_latents[:prior_traj_latents.shape[0] // 5]
            grid_labels = grid_labels[:grid_labels.shape[0] // 5]
    print('prior_traj_latents.shape:', prior_traj_latents.shape)
    prior_traj_latents_show = np.reshape(
        prior_traj_latents, [-1, n_samples_per, prior_latent_size])
    print_traj(prior_traj_latents_show)
    grid_fakes = Gs.run(prior_traj_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch_gpu,
                        randomize_noise=True,
                        normalize_latents=False)
    grid_fakes = add_outline(grid_fakes, width=1)
    misc.save_image_grid(grid_fakes,
                         dnnlib.make_run_dir_path('fakes_init.png'),
                         drange=drange_net,
                         grid_size=grid_size)
    if (n_continuous == 2) and (n_discrete == 0):
        _save_2d_prior_grid_hd(
            Gs, grid_labels, prior_traj_latents_show, sched.minibatch_gpu,
            drange_net,
            dnnlib.make_run_dir_path('fakes_init_2d_prior_grid.png'),
            dnnlib.make_run_dir_path('fakes_init_2d_prior_grid_drawn.png'))

    if use_level_training:
        ending_level = training_set_resolution_log2 - 1
    else:
        ending_level = 1

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32,
                                           name='minibatch_size_in',
                                           shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32,
                                          name='minibatch_gpu_in',
                                          shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in * num_gpus)
        Is_beta = 0.5**tf.div(tf.cast(minibatch_size_in, tf.float32),
                              I_smoothing_kimg * 1000.0) if I_smoothing_kimg > 0.0 else 0.0

    # Setup optimizers.  Work on a copy so the caller's dict is not modified.
    I_opt_args = dict(I_opt_args)
    I_opt_args['minibatch_multiplier'] = minibatch_multiplier
    I_opt_args['learning_rate'] = lrate_in
    if lazy_regularization:
        # Rescale the hyper-parameters by reg_interval / (reg_interval + 1).
        mb_ratio = I_reg_interval / (I_reg_interval + 1)
        I_opt_args['learning_rate'] *= mb_ratio
        if 'beta1' in I_opt_args:
            I_opt_args['beta1'] **= mb_ratio
        if 'beta2' in I_opt_args:
            I_opt_args['beta2'] **= mb_ratio
    I_opts = []
    I_reg_opts = []
    for n_level in range(ending_level):
        I_opts.append(tflib.Optimizer(name='TrainI_%d' % n_level, **I_opt_args))
        I_reg_opts.append(
            tflib.Optimizer(name='RegI_%d' % n_level,
                            share=I_opts[-1],
                            **I_opt_args))

    # Build training graph for each GPU.
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):

            # Create GPU-specific shadow copies of I, M and G.
            I_gpu = I if gpu == 0 else I.clone(I.name + '_shadow')
            M_gpu = M if gpu == 0 else M.clone(M.name + '_shadow')
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            if use_hd_with_cls:
                I_info_gpu = I_info if gpu == 0 else I_info.clone(I_info.name + '_shadow')

            # Evaluate loss functions.
            lod_assign_ops = []
            I_losses = []
            I_regs = []
            if 'lod' in I_gpu.vars:
                lod_assign_ops += [tf.assign(I_gpu.vars['lod'], lod_in)]
            if 'lod' in M_gpu.vars:
                lod_assign_ops += [tf.assign(M_gpu.vars['lod'], lod_in)]
            for n_level in range(ending_level):
                with tf.control_dependencies(lod_assign_ops):
                    with tf.name_scope('I_loss_%d' % n_level):
                        if use_hd_with_cls:
                            I_loss, I_reg = dnnlib.util.call_func_by_name(
                                I=I_gpu,
                                M=M_gpu,
                                G=G_gpu,
                                I_info=I_info_gpu,
                                opt=I_opts[n_level],
                                training_set=training_set,
                                minibatch_size=minibatch_gpu_in,
                                **I_loss_args)
                        else:
                            I_loss, I_reg = dnnlib.util.call_func_by_name(
                                I=I_gpu,
                                M=M_gpu,
                                G=G_gpu,
                                opt=I_opts[n_level],
                                n_levels=(n_level + 1) if use_level_training else None,
                                training_set=training_set,
                                minibatch_size=minibatch_gpu_in,
                                **I_loss_args)
                        I_losses.append(I_loss)
                        I_regs.append(I_reg)

                # Register gradients.
                if not lazy_regularization:
                    if I_regs[n_level] is not None:
                        I_losses[n_level] += I_regs[n_level]
                else:
                    if I_regs[n_level] is not None:
                        I_reg_opts[n_level].register_gradients(
                            tf.reduce_mean(I_regs[n_level] * I_reg_interval),
                            I_gpu.trainables)

                if use_hd_with_cls:
                    MIIinfo_gpu_trainables = collections.OrderedDict(
                        list(M_gpu.trainables.items()) +
                        list(I_gpu.trainables.items()) +
                        list(I_info_gpu.trainables.items()))
                    I_opts[n_level].register_gradients(
                        tf.reduce_mean(I_losses[n_level]),
                        MIIinfo_gpu_trainables)
                else:
                    MI_gpu_trainables = collections.OrderedDict(
                        list(M_gpu.trainables.items()) +
                        list(I_gpu.trainables.items()))
                    I_opts[n_level].register_gradients(
                        tf.reduce_mean(I_losses[n_level]), MI_gpu_trainables)

    # Setup training ops.
    I_train_ops = []
    I_reg_ops = []
    for n_level in range(ending_level):
        I_train_ops.append(I_opts[n_level].apply_updates())
        I_reg_ops.append(I_reg_opts[n_level].apply_updates(allow_no_op=True))
    Is_update_op = Is.setup_as_moving_average_of(I, beta=Is_beta)

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        I.setup_weight_histograms()
        M.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('',
                                   cur_epoch=resume_kimg,
                                   max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop():
            break

        # Pick the active training level; int() keeps it usable as a list
        # index even when level_I_kimg is given as a float.
        if use_level_training:
            n_level = int(min(cur_nimg // (level_I_kimg * 1000),
                              training_set_resolution_log2 - 2))
        else:
            n_level = 0

        # Choose training parameters and configure training ops.
        sched = training_schedule(
            cur_nimg=cur_nimg,
            training_set_resolution_log2=training_set_resolution_log2,
            **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                I_opts[n_level].reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        feed_dict = {
            lod_in: sched.lod,
            lrate_in: sched.I_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu
        }
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size,
                           sched.minibatch_gpu * num_gpus)
            run_I_reg = (lazy_regularization
                         and running_mb_counter % I_reg_interval == 0)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                tflib.run(I_train_ops[n_level], feed_dict)
                if run_I_reg:
                    tflib.run(I_reg_ops[n_level], feed_dict)
                tflib.run([Is_update_op], feed_dict)

            # Slow path with gradient accumulation.
            else:
                for _round in rounds:
                    tflib.run(I_train_ops[n_level], feed_dict)
                if run_I_reg:
                    for _round in rounds:
                        tflib.run(I_reg_ops[n_level], feed_dict)
                tflib.run(Is_update_op, feed_dict)

        # Perform maintenance tasks once per tick.  tick_kimg is expressed in
        # thousands of images, hence the factor of 1000 (previously 100, which
        # made ticks fire ten times more often than configured).
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start() + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch_size),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days',
                        total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (
                    cur_tick % image_snapshot_ticks == 0 or done):
                if not use_hyperplane:
                    prior_traj_latents = M.run(
                        grid_latents,
                        is_validation=True,
                        minibatch_size=sched.minibatch_gpu)
                    if use_std_in_m:
                        prior_traj_latents = prior_traj_latents[:, :prior_latent_size]
                else:
                    prior_traj_latents = get_prior_traj_by_dirs(
                        latent_dirs, M, n_samples_per, prior_latent_size,
                        grid_labels, sched)
                    if resolution_manual >= 256:
                        prior_traj_latents = prior_traj_latents[:prior_traj_latents.shape[0] // 5]
                prior_traj_latents_show = np.reshape(
                    prior_traj_latents,
                    [-1, n_samples_per, prior_latent_size])
                print_traj(prior_traj_latents_show)
                grid_fakes = Gs.run(prior_traj_latents,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch_gpu,
                                    randomize_noise=True,
                                    normalize_latents=False)
                grid_fakes = add_outline(grid_fakes, width=1)
                misc.save_image_grid(grid_fakes,
                                     dnnlib.make_run_dir_path(
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
                if (n_continuous == 2) and (n_discrete == 0):
                    _save_2d_prior_grid_hd(
                        Gs, grid_labels, prior_traj_latents_show,
                        sched.minibatch_gpu, drange_net,
                        dnnlib.make_run_dir_path(
                            'fakes_2d_prior_grid%06d.png' % (cur_nimg // 1000)),
                        dnnlib.make_run_dir_path(
                            'fakes_2d_prior_grid_drawn%06d.png' % (cur_nimg // 1000)))

            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' %
                                               (cur_nimg // 1000))
                misc.save_pkl((I, M, Is), pkl)
                metrics.run(pkl,
                            run_dir=dnnlib.make_run_dir_path(),
                            data_dir=dnnlib.convert_path(data_dir),
                            num_gpus=num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % sched.lod,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get().get_last_update_interval() - tick_time

    # Save final snapshot.
    misc.save_pkl((I, M, Is), dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()


def _save_2d_prior_grid_hd(Gs, grid_labels, prior_traj_latents_show,
                           minibatch_gpu, drange_net, grid_png, drawn_png,
                           n_per_line=20, ex_latent_value=3):
    """Render the 2-D prior grid to ``grid_png`` and the trajectory overlay to ``drawn_png``."""
    prior_grid_latents, prior_grid_labels = get_2d_grid_latents(
        low=-ex_latent_value,
        high=ex_latent_value,
        n_per_line=n_per_line,
        grid_labels=grid_labels)
    grid_showing_fakes = Gs.run(prior_grid_latents,
                                prior_grid_labels,
                                is_validation=True,
                                minibatch_size=minibatch_gpu,
                                randomize_noise=True,
                                normalize_latents=False)
    grid_showing_fakes = add_outline(grid_showing_fakes, width=1)
    misc.save_image_grid(grid_showing_fakes,
                         grid_png,
                         drange=drange_net,
                         grid_size=[n_per_line, n_per_line])
    # Re-open the saved grid; the context manager closes the file handle once
    # the RGB copy has been made.
    with Image.open(grid_png) as grid_img:
        img_to_draw = grid_img.convert('RGB')
    img_to_draw = draw_traj_on_prior_grid(img_to_draw,
                                          prior_traj_latents_show,
                                          ex_latent_value, n_per_line)
    img_to_draw.save(drawn_png)
def training_loop(
    submit_config,
    G_args                  = {},       # Options for generator network.
    D_args                  = {},       # Options for discriminator network.
    G_opt_args              = {},       # Options for generator optimizer.
    D_opt_args              = {},       # Options for discriminator optimizer.
    G_loss_args             = {},       # Options for generator loss.
    D_loss_args             = {},       # Options for discriminator loss.
    dataset_args            = {},       # Options for the dataset.
    sched_args              = {},       # Options for the training schedule.
    grid_args               = {},       # Options for setup_snapshot_image_grid().
    metric_arg_list         = [],       # Options for the metrics.
    tf_config               = {},       # Options for tflib.init_tf().
    G_smoothing_kimg        = 10.0,     # Half-life of the running average of generator weights.
    D_repeats               = 1,        # How many times the discriminator is trained per G iteration.
    minibatch_repeats       = 4,        # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod   = True,     # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg              = 15000,    # Total length of the training, measured in thousands of real images.
    mirror_augment          = False,    # Enable mirror augment?
    drange_net              = [-1,1],   # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks    = 1,        # How often to export image snapshots?
    network_snapshot_ticks  = 10,       # How often to export network snapshots?
    save_tf_graph           = False,    # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms  = False,    # Include weight histograms in the tfevents file?
    resume_run_id           = None,     # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot         = None,     # Snapshot index to resume training from, None = autodetect.
    resume_kimg             = 0.0,      # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time             = 0.0):     # Assumed wallclock time at the beginning. Affects reporting.
    """Run the G/D training loop.

    Builds (or loads) G, D and Gs, constructs the per-GPU training graph, and
    then alternates discriminator and generator updates.  Once per tick it
    reports progress and writes image / network snapshots into
    ``submit_config.run_dir``; the final networks go to 'network-final.pkl'.
    """
    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)
    num_gpus = submit_config.num_gpus

    # Load the training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args)

    # Construct or load the networks.
    with tf.device('/gpu:0'):
        if resume_run_id is None:
            print('Constructing networks...')
            G = tflib.Network(
                'G',
                num_channels=training_set.shape[0],
                resolution=training_set.shape[1],
                label_size=training_set.label_size,
                **G_args)
            D = tflib.Network(
                'D',
                num_channels=training_set.shape[0],
                resolution=training_set.shape[1],
                label_size=training_set.label_size,
                **D_args)
            Gs = G.clone('Gs')
        else:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
    G.print_layers()
    D.print_layers()

    # Build the computation graph and the optimizers.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        # Placeholders are fed with concrete values on every training step.
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // num_gpus
        if G_smoothing_kimg > 0.0:
            Gs_beta = 0.5 ** tf.div(tf.cast(minibatch_in, tf.float32), G_smoothing_kimg * 1000.0)
        else:
            Gs_beta = 0.0

    G_opt = tflib.Optimizer(name='TrainG', learning_rate=lrate_in, **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', learning_rate=lrate_in, **D_opt_args)
    for gpu_idx in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu_idx), tf.device('/gpu:%d' % gpu_idx):
            # GPU 0 uses the master networks; the others get shadow clones.
            if gpu_idx == 0:
                G_gpu = G
            else:
                G_gpu = G.clone(G.name + '_shadow')
            if gpu_idx == 0:
                D_gpu = D
            else:
                D_gpu = D.clone(D.name + '_shadow')
            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in),
            ]
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(reals, lod_in, mirror_augment, training_set.dynamic_range, drange_net)
            with tf.name_scope('G_loss'), tf.control_dependencies(lod_assign_ops):
                G_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set,
                    minibatch_size=minibatch_split, **G_loss_args)
            with tf.name_scope('D_loss'), tf.control_dependencies(lod_assign_ops):
                D_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set,
                    minibatch_size=minibatch_split, reals=reals, labels=labels,
                    **D_loss_args)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)

    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)

    # Set up the snapshot image grid.
    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(G, training_set, **grid_args)
    sched = training_schedule(cur_nimg=total_kimg*1000, training_set=training_set, num_gpus=num_gpus, **sched_args)
    grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch//num_gpus)

    # Set up the run directory.
    print('Setting up run dir...')
    run_dir = submit_config.run_dir
    reals_png = os.path.join(run_dir, 'reals.png')
    misc.save_image_grid(grid_reals, reals_png, drange=training_set.dynamic_range, grid_size=grid_size)
    initial_fakes_png = os.path.join(run_dir, 'fakes%06d.png' % resume_kimg)
    misc.save_image_grid(grid_fakes, initial_fakes_png, drange=drange_net, grid_size=grid_size)
    summary_log = tf.summary.FileWriter(run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    # Train.
    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    total_nimg = total_kimg * 1000
    while cur_nimg < total_nimg:
        if ctx.should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, num_gpus=num_gpus, **sched_args)
        training_set.configure(sched.minibatch // num_gpus, sched.lod)
        if reset_opt_for_new_lod:
            crossed_floor = np.floor(sched.lod) != np.floor(prev_lod)
            crossed_ceil = np.ceil(sched.lod) != np.ceil(prev_lod)
            if crossed_floor or crossed_ceil:
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        d_feed = {lod_in: sched.lod, lrate_in: sched.D_lrate, minibatch_in: sched.minibatch}
        g_feed = {lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_in: sched.minibatch}
        for _ in range(minibatch_repeats):
            for _ in range(D_repeats):
                tflib.run([D_train_op, Gs_update_op], d_feed)
                cur_nimg += sched.minibatch
            tflib.run([G_train_op], g_feed)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        tick_due = cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000
        if not (tick_due or done):
            continue

        cur_tick += 1
        tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
        tick_start_nimg = cur_nimg
        tick_time = ctx.get_time_since_last_update()
        total_time = ctx.get_time_since_start() + resume_time

        # Report progress.
        progress_fields = (
            autosummary('Progress/tick', cur_tick),
            autosummary('Progress/kimg', cur_nimg / 1000.0),
            autosummary('Progress/lod', sched.lod),
            autosummary('Progress/minibatch', sched.minibatch),
            dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
            autosummary('Timing/sec_per_tick', tick_time),
            autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
            autosummary('Timing/maintenance_sec', maintenance_time),
            autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30),
        )
        print('tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f' % progress_fields)
        autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
        autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

        # Save snapshots.
        snapshot_kimg = cur_nimg // 1000
        if cur_tick % image_snapshot_ticks == 0 or done:
            grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch//num_gpus)
            fakes_png = os.path.join(run_dir, 'fakes%06d.png' % snapshot_kimg)
            misc.save_image_grid(grid_fakes, fakes_png, drange=drange_net, grid_size=grid_size)
        if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
            pkl = os.path.join(run_dir, 'network-snapshot-%06d.pkl' % snapshot_kimg)
            misc.save_pkl((G, D, Gs), pkl)
            metrics.run(pkl, run_dir=run_dir, num_gpus=num_gpus, tf_config=tf_config)

        # Update summaries and RunContext.
        metrics.update_autosummaries()
        tflib.autosummary.save_summaries(summary_log, cur_nimg)
        ctx.update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
        maintenance_time = ctx.get_last_update_interval() - tick_time

    # Save final results.
    misc.save_pkl((G, D, Gs), os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()
    ctx.close()
def training_auto_loop(
        Enc_args={},                    # Options for encoder network.
        Dec_args={},                    # Options for decoder network.
        opt_args={},                    # Options for the optimizer shared by encoder and decoder.
        loss_args={},                   # Options for the loss function.
        dataset_args={},                # Options for dataset.load_dataset().
        sched_args={},                  # Training schedule; must support attribute access (see docstring).
        grid_args={},                   # Options for train.setup_snapshot_image_grid().
        tf_config={},                   # Options for tflib.init_tf().
        data_dir=None,                  # Directory to load datasets from.
        minibatch_repeats=4,            # Number of minibatches to run before adjusting training parameters.
        total_kimg=25000,               # Total length of the training, measured in thousands of real images.
        drange_net=[-1, 1],             # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=50,        # How often to save image snapshots? None = only save 'reals.png' and 'fakes_init.png'.
        network_snapshot_ticks=50,      # How often to save network snapshots? None = only save 'network-final.pkl'.
        save_tf_graph=False,            # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False):  # Include weight histograms in the tfevents file?
    """Train an encoder/decoder pair and periodically save snapshots.

    Builds an 'Encoder' and a 'Decoder' network, registers the gradients of
    the loss named in ``loss_args`` for both networks on a single optimizer,
    and runs the training ops until ``total_kimg`` thousand images have been
    processed or the run context asks to stop.

    ``sched_args`` is used directly as the schedule object: the function reads
    ``minibatch_gpu``, ``minibatch_size`` and ``lrate`` from it as attributes
    and forces ``tick_kimg`` to 4.  It therefore has to be an
    attribute-accessible object (presumably a ``dnnlib.EasyDict`` -- confirm
    against callers); the plain-dict default does not support this.

    Files written into the run directory: 'reals.png', 'fakes_init.png',
    'fakes%06d.png', 'network-snapshot-%06d.pkl' and 'network-final.pkl'.
    """
    import copy

    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(
        data_dir=dnnlib.convert_path(data_dir), verbose=True, **dataset_args)
    grid_size, grid_reals, _ = misc.setup_snapshot_image_grid(
        training_set, **grid_args)
    misc.save_image_grid(grid_reals,
                         dnnlib.make_run_dir_path('reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)

    # Construct networks.
    with tf.device('/gpu:0'):
        print('Constructing networks...')
        Enc = tflib.Network('Encoder',
                            resolution=training_set.shape[1],
                            out_channels=16,
                            **Enc_args)
        Dec = tflib.Network('Decoder',
                            resolution=training_set.shape[1] // 4,
                            in_channels=16,
                            **Dec_args)

    # Print layers and generate initial image snapshot.
    Enc.print_layers()
    Dec.print_layers()

    # Work on a shallow copy so that overriding tick_kimg does not leak into
    # the caller's schedule object (or into the shared default argument).
    sched = copy.copy(sched_args)
    sched.tick_kimg = 4

    grid_codes = Enc.run((grid_reals / 127.5) - 1.0,
                         minibatch_size=sched.minibatch_gpu)
    grid_fakes = Dec.run(grid_codes, minibatch_size=sched.minibatch_gpu)
    misc.save_image_grid(grid_fakes,
                         dnnlib.make_run_dir_path('fakes_init.png'),
                         drange=drange_net,
                         grid_size=grid_size)

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32, name='minibatch_size_in', shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32, name='minibatch_gpu_in', shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in * num_gpus)

    # Setup optimizers.  opt_args is copied before the graph tensors are
    # injected, so the caller's dict is left untouched.
    opt_args = dict(opt_args)
    opt_args['minibatch_multiplier'] = minibatch_multiplier
    opt_args['learning_rate'] = lrate_in
    opt = tflib.Optimizer(name='TrainAuto', **opt_args)

    # Build training graph for each GPU.
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):

            # Create GPU-specific shadow copies of Enc and Dec.
            Enc_gpu = Enc if gpu == 0 else Enc.clone(Enc.name + '_shadow')
            Dec_gpu = Dec if gpu == 0 else Dec.clone(Dec.name + '_shadow')

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                reals_var = tf.Variable(
                    name='reals',
                    trainable=False,
                    initial_value=tf.zeros([sched.minibatch_gpu] + training_set.shape))
                reals_write, labels_write = training_set.get_minibatch_tf()
                reals_write, _ = process_reals(reals_write, labels_write, 0.0, False,
                                               training_set.dynamic_range, drange_net)
                # Keep the tail of the buffer so the assigned value always has
                # the variable's full leading dimension.
                reals_write = tf.concat(
                    [reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_var, reals_write)]
                reals_read = reals_var[:minibatch_gpu_in]

            # Evaluate loss functions.
            with tf.name_scope('loss'):
                loss = dnnlib.util.call_func_by_name(Enc=Enc_gpu,
                                                     Dec=Dec_gpu,
                                                     opt=opt,
                                                     reals=reals_read,
                                                     **loss_args)

            # Register gradients for encoder and decoder on the same optimizer.
            opt.register_gradients(
                tf.reduce_mean(loss),
                list(Enc_gpu.trainables.values()) + list(Dec_gpu.trainables.values()))

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    train_op = opt.apply_updates()

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        Enc.setup_weight_histograms()
        Dec.setup_weight_histograms()

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('', max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = 0
    cur_tick = -1  # -1 forces a maintenance tick after the first iteration.
    tick_start_nimg = cur_nimg
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop():
            break

        # Choose training parameters and configure training ops.
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu)

        # Run training ops.
        feed_dict = {
            lrate_in: sched.lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu,
        }
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size, sched.minibatch_gpu * num_gpus)
            cur_nimg += sched.minibatch_size

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                tflib.run(data_fetch_op, feed_dict)
                tflib.run(train_op, feed_dict)

            # Slow path with gradient accumulation.
            # NOTE(review): train_op is run once per round here; confirm this
            # nesting against the upstream file.
            else:
                for _round in rounds:
                    tflib.run(data_fetch_op, feed_dict)
                    tflib.run(train_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start()

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/minibatch', sched.minibatch_size),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (
                    cur_tick % image_snapshot_ticks == 0 or done):
                grid_codes = Enc.run((grid_reals / 127.5) - 1.0,
                                     minibatch_size=sched.minibatch_gpu)
                grid_fakes = Dec.run(grid_codes,
                                     minibatch_size=sched.minibatch_gpu)
                misc.save_image_grid(
                    grid_fakes,
                    dnnlib.make_run_dir_path('fakes%06d.png' % (cur_nimg // 1000)),
                    drange=drange_net,
                    grid_size=grid_size)
            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path(
                    'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((Enc, Dec), pkl)

            # Update summaries and RunContext.
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % 0.0,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            maintenance_time = (
                dnnlib.RunContext.get().get_last_update_interval() - tick_time)

    # Save final snapshot.
    misc.save_pkl((Enc, Dec), dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
def training_loop(
        submit_config,
        G_args={},                      # Options for generator network.
        D_args={},                      # Options for discriminator network.
        G_opt_args={},                  # Options for generator optimizer (unused here: Adam settings are hard-coded).
        D_opt_args={},                  # Options for discriminator optimizer (unused here: Adam settings are hard-coded).
        G_loss_args={},                 # Options for generator loss.
        D_loss_args={},                 # Options for discriminator loss.
        dataset_args={},                # Options for dataset.load_dataset().
        sched_args={},                  # Options for train.TrainingSchedule.
        grid_args={},                   # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],             # Options for MetricGroup.
        tf_config={},                   # Options for tflib.init_tf().
        G_smoothing_kimg=10.0,          # Half-life of the running average of generator weights.
        D_repeats=1,                    # How many times the discriminator is trained per G iteration.
        minibatch_repeats=4,            # Number of minibatches to run before adjusting training parameters.
        reset_opt_for_new_lod=True,     # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg=15000,               # Total length of the training, measured in thousands of real images.
        mirror_augment=False,           # Enable mirror augment?
        drange_net=[-1, 1],             # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=10,        # How often to export image snapshots?
        network_snapshot_ticks=10,      # How often to export network snapshots?
        save_tf_graph=False,            # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,   # Include weight histograms in the tfevents file?
        resume_run_id=None,             # Accepted for interface compatibility; not used by this variant.
        resume_snapshot=None,           # Accepted for interface compatibility; not used by this variant.
        resume_kimg=0.0,                # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0):               # Assumed wallclock time at the beginning. Affects reporting.
    """Horovod variant of the GAN training loop.

    Each process builds a single copy of G and D, wraps two Adam optimizers
    in ``hvd.DistributedOptimizer``, initializes all variables and then
    broadcasts them from rank 0.  The dataset is loaded with
    ``num_hosts=hvd.size()`` and ``index=hvd.rank()``.

    Networks are always constructed from scratch: ``resume_run_id`` and
    ``resume_snapshot`` are never read, and ``G_opt_args`` / ``D_opt_args``
    are ignored because the Adam hyper-parameters are hard-coded.

    Progress printing, summary writing, the initial 'reals.png' / fakes grid
    and 'network-final.pkl' are restricted to rank 0.
    """
    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    # TODO(ajay): move init to after graph creation?
    tflib.init_tf(tf_config)

    # Load training set, sharded across Horovod processes.
    print('ajay - config data dir', config.data_dir)
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        num_hosts=hvd.size(),
                                        index=hvd.rank(),
                                        **dataset_args)

    # Construct networks (no resume path in this variant).
    print('Constructing networks...')
    G = tflib.Network('G',
                      num_channels=training_set.shape[0],
                      resolution=training_set.shape[1],
                      label_size=training_set.label_size,
                      **G_args)
    D = tflib.Network('D',
                      num_channels=training_set.shape[0],
                      resolution=training_set.shape[1],
                      label_size=training_set.label_size,
                      **D_args)
    Gs = G.clone('Gs')
    G.print_layers()
    D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_in, tf.float32),
                              G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0

    G_opt = tf.train.AdamOptimizer(learning_rate=lrate_in, beta1=0.0, beta2=0.99, epsilon=1e-8)
    G_opt = hvd.DistributedOptimizer(G_opt)
    D_opt = tf.train.AdamOptimizer(learning_rate=lrate_in, beta1=0.0, beta2=0.99, epsilon=1e-8)
    D_opt = hvd.DistributedOptimizer(D_opt)

    # One network copy per process; no per-GPU shadow clones.
    G_gpu = G
    D_gpu = D
    lod_assign_ops = [
        tf.assign(G_gpu.find_var('lod'), lod_in),
        tf.assign(D_gpu.find_var('lod'), lod_in),
    ]
    # TODO(ajay): check that a unique minibatch is guaranteed, i.e. sharding is done right.
    reals, labels = training_set.get_minibatch_tf()
    reals = process_reals(reals, lod_in, mirror_augment,
                          training_set.dynamic_range, drange_net)
    with tf.name_scope('G_loss'), tf.control_dependencies(lod_assign_ops):
        G_loss = dnnlib.util.call_func_by_name(G=G_gpu,
                                               D=D_gpu,
                                               opt=G_opt,
                                               training_set=training_set,
                                               minibatch_size=minibatch_split,
                                               **G_loss_args)
    with tf.name_scope('D_loss'), tf.control_dependencies(lod_assign_ops):
        D_loss = dnnlib.util.call_func_by_name(G=G_gpu,
                                               D=D_gpu,
                                               opt=D_opt,
                                               training_set=training_set,
                                               minibatch_size=minibatch_split,
                                               reals=reals,
                                               labels=labels,
                                               **D_loss_args)
    G_grads = G_opt.compute_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
    D_grads = D_opt.compute_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_gradients(G_grads)
    D_train_op = D_opt.apply_gradients(D_grads)

    # Horovod: initialize every variable, then broadcast rank 0's values.
    # tf.global_variables_initializer() is the non-deprecated equivalent of
    # tf.initialize_all_variables().
    init_op = tf.global_variables_initializer()
    bcast_op = hvd.broadcast_global_variables(0)
    tf.get_default_session().run([init_op])
    tflib.run([bcast_op])

    # Created after the init/broadcast ops were built, as in the original order.
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(
        G, training_set, **grid_args)
    # TODO(ajay): num_gpus needs to change to hvd size when going multi-node.
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              num_gpus=submit_config.num_gpus,
                              **sched_args)
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch // submit_config.num_gpus)

    if hvd.rank() == 0:
        print('Setting up run dir...')
        misc.save_image_grid(grid_reals,
                             os.path.join(submit_config.run_dir, 'reals.png'),
                             drange=training_set.dynamic_range,
                             grid_size=grid_size)
        misc.save_image_grid(grid_fakes,
                             os.path.join(submit_config.run_dir,
                                          'fakes%06d.png' % resume_kimg),
                             drange=drange_net,
                             grid_size=grid_size)

    # NOTE(review): the FileWriter is created on every rank but only written
    # on rank 0 -- confirm whether creation was meant to be gated as well.
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    while cur_nimg < (total_kimg * 1000):
        if ctx.should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  num_gpus=submit_config.num_gpus,
                                  **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus, sched.lod)

        # The plain TF optimizers have no reset_optimizer_state(), so the
        # optimizer variables are re-run through their initializers instead.
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                tflib.assert_tf_initialized()
                G_opt_reset_op = [var.initializer for var in G_opt.variables()]
                D_opt_reset_op = [var.initializer for var in D_opt.variables()]
                tflib.run(G_opt_reset_op)
                tflib.run(D_opt_reset_op)
        prev_lod = sched.lod

        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch
                    })
                cur_nimg += sched.minibatch
            tflib.run(
                [G_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch
                })

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= (total_kimg * 1000))
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress (rank 0 only; GPU memory is not reported here).
            if hvd.rank() == 0:
                print(
                    'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f '
                    % (autosummary('Progress/tick', cur_tick),
                       autosummary('Progress/kimg', cur_nimg / 1000.0),
                       autosummary('Progress/lod', sched.lod),
                       autosummary('Progress/minibatch', sched.minibatch),
                       dnnlib.util.format_time(
                           autosummary('Timing/total_sec', total_time)),
                       autosummary('Timing/sec_per_tick', tick_time),
                       autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                       autosummary('Timing/maintenance_sec', maintenance_time)))
                autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
                autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            # NOTE(review): these writes are not gated on hvd.rank() == 0 --
            # confirm whether every rank is meant to write to run_dir.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch // submit_config.num_gpus)
                misc.save_image_grid(grid_fakes,
                                     os.path.join(submit_config.run_dir,
                                                  'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(submit_config.run_dir,
                                   'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                # Metrics are evaluated with a single GPU in this variant.
                metrics.run(pkl,
                            run_dir=submit_config.run_dir,
                            num_gpus=1,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            if hvd.rank() == 0:
                tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod,
                       cur_epoch=cur_nimg // 1000,
                       max_epoch=total_kimg)
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    if hvd.rank() == 0:
        misc.save_pkl((G, D, Gs),
                      os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()
    ctx.close()
def training_loop(
        submit_config,
        HP_args={},                     # Options for the Hessian Penalty.
        G_args={},                      # Options for generator network.
        D_args={},                      # Options for discriminator network.
        G_opt_args={},                  # Options for generator optimizer.
        D_opt_args={},                  # Options for discriminator optimizer.
        G_loss_args={},                 # Options for generator loss.
        D_loss_args={},                 # Options for discriminator loss.
        dataset_args={},                # Options for dataset.load_dataset().
        sched_args={},                  # Options for train.TrainingSchedule.
        grid_args={},                   # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],             # Options for MetricGroup.
        tf_config={},                   # Options for tflib.init_tf().
        G_smoothing_kimg=10.0,          # Half-life of the running average of generator weights.
        D_repeats=1,                    # How many times the discriminator is trained per G iteration.
        minibatch_repeats=4,            # Number of minibatches to run before adjusting training parameters.
        reset_opt_for_new_lod=True,     # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg=15000,               # Total length of the training, measured in thousands of real images.
        mirror_augment=False,           # Enable mirror augment?
        drange_net=[-1, 1],             # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=1,         # How often to export image snapshots?
        interp_snapshot_ticks=20,       # How often to generate interpolation visualizations in TensorBoard? -1 disables them.
        network_snapshot_ticks=5,       # How often to export network snapshots?
        network_metric_ticks=5,         # How often to evaluate network snapshots on specified metrics?
        save_tf_graph=False,            # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,   # Include weight histograms in the tfevents file?
        resume_run_id=None,             # Run ID or network pkl to resume training from, None = start from scratch.
        resume_snapshot=None,           # Snapshot index to resume training from, None = autodetect.
        resume_kimg=0.0,                # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0):               # Assumed wallclock time at the beginning. Affects reporting.
    """GAN training loop with a Hessian Penalty term and an InfoGAN term.

    The generator loss function returns ``(G_loss, G_hessian_penalty)`` and
    the discriminator loss function returns ``(D_loss, mutual_info)``.  The
    Hessian Penalty weight is fed per step through a placeholder whose value
    comes from ``get_current_hessian_penalty_loss_weight``.

    ``HP_args``, ``G_args``, ``D_args`` and ``dataset_args`` are read with
    attribute access (``HP_args.hp_lambda``, ``D_args.infogan_nz``,
    ``G_args.infogan_lambda``, ``metrics_dataset_args.shuffle_mb = 0``), so
    callers have to pass attribute-accessible mappings; the plain-dict
    defaults do not support this.

    In addition to image and network snapshots, image summaries and
    interpolation visualizations are written through ``summary_log``.  The
    final networks are saved as 'network-snapshot-%06d.pkl' % total_kimg.
    """
    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **dataset_args)

    # Create a copy of dataset_args for running the metrics:
    metrics_dataset_args = deepcopy(dataset_args)
    metrics_dataset_args.shuffle_mb = 0

    print('Saving interp videos every %s ticks' % interp_snapshot_ticks)
    print('Saving network snapshot every %s ticks' % network_snapshot_ticks)

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...')
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **G_args)
            D = tflib.Network('D',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **D_args)
            Gs = G.clone('Gs')

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_in, tf.float32),
                              G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0
        # The loss weighting of the Hessian Penalty can be dynamic over
        # training, so we need to use a placeholder:
        hessian_weight = tf.placeholder(tf.float32, name='hessian_weight', shape=[])

    G_opt = tflib.Optimizer(name='TrainG', learning_rate=lrate_in, **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', learning_rate=lrate_in, **D_opt_args)

    # Tensors run alongside G_train_op so that the Hessian Penalty / InfoGAN
    # loss values can be registered in TensorBoard.
    reg_ops = []
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in),
            ]
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(reals, lod_in, mirror_augment,
                                  training_set.dynamic_range, drange_net)
            with tf.name_scope('G_loss'), tf.control_dependencies(lod_assign_ops):
                G_loss, G_hessian_penalty = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    hp_lambda=hessian_weight,
                    HP_args=HP_args,
                    gpu_ix=gpu,
                    lod_in=lod_in,
                    max_lod=training_set.resolution_log2,
                    **G_loss_args)
            if HP_args.hp_lambda > 0:
                reg_ops += [G_hessian_penalty]
            with tf.name_scope('D_loss'), tf.control_dependencies(lod_assign_ops):
                D_loss, mutual_info = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals,
                    labels=labels,
                    gpu_ix=gpu,
                    infogan_nz=D_args.infogan_nz,
                    **D_loss_args)
            if G_args.infogan_lambda > 0 or D_args.infogan_lambda > 0:
                reg_ops += [mutual_info]

            # Note, even though we are adding mutual_info loss here, the only
            # time the loss is non-zero is when infogan_lambda > 0 (in Hessian
            # Penalty experiments, we always set it to zero):
            G_opt.register_gradients(
                G_loss + G_args.infogan_lambda * mutual_info,
                G_gpu.trainables)
            D_opt.register_gradients(
                tf.reduce_mean(D_loss) + D_args.infogan_lambda * mutual_info,
                D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(
        G, training_set, **grid_args)
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              num_gpus=submit_config.num_gpus,
                              **sched_args)
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch // submit_config.num_gpus)

    print('Setting up snapshot interpolation...')
    nz = G.input_shapes[0][1]
    interp_steps = 24  # Number of frames in the visualization
    interp_batch_size = 8  # Number of gifs per row of visualization
    interp_z = vis_tools.sample_interp_zs(nz, interp_batch_size, interp_steps)
    interp_labels = np.zeros(
        [interp_steps * interp_batch_size * nz, training_set.label_size],
        dtype=training_set.label_dtype)

    print('Setting up run dir...')
    misc.save_image_grid(grid_reals,
                         os.path.join(submit_config.run_dir, 'reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    misc.save_image_grid(grid_fakes,
                         os.path.join(submit_config.run_dir,
                                      'fakes%06d.png' % resume_kimg),
                         drange=drange_net,
                         grid_size=grid_size)
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    summary_log.add_summary(
        build_image_summary(os.path.join(submit_config.run_dir, 'reals.png'),
                            'samples/real'), 0)
    summary_log.add_summary(
        build_image_summary(
            os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg),
            'samples/Gs'), resume_kimg)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    if interp_snapshot_ticks != -1 and interp_snapshot_ticks < 9999:
        print('Generating initial interpolations...')
        vis_tools.make_and_save_interpolation_gifs(
            Gs,
            interp_z,
            interp_labels,
            minibatch_size=sched.minibatch // submit_config.num_gpus,
            interp_steps=interp_steps,
            interp_batch_size=interp_batch_size,
            cur_kimg=resume_kimg,
            summary_log=summary_log)

    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    num_G_grad_steps = 0
    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  num_gpus=submit_config.num_gpus,
                                  **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch
                    })
                cur_nimg += sched.minibatch
            cur_hessian_weight = get_current_hessian_penalty_loss_weight(
                HP_args.hp_lambda, HP_args.hp_start_nimg, cur_nimg,
                HP_args.warmup_nimg)
            tflib.run(
                [G_train_op] + reg_ops, {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch,
                    hessian_weight: cur_hessian_weight
                })
            num_G_grad_steps += 1

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d hessian_weight %s time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch),
                   autosummary('Progress/hessian_weight', cur_hessian_weight),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))
            autosummary('Progress/G_grad_steps', num_G_grad_steps)

            # Save snapshots and fake image samples (for both Gs and G):
            if cur_tick % image_snapshot_ticks == 0 or done:
                # Named snapshot_kimg rather than `iter` to avoid shadowing
                # the builtin.
                snapshot_kimg = cur_nimg // 1000
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch // submit_config.num_gpus)
                grid_fakes_inst = G.run(grid_latents,
                                        grid_labels,
                                        is_validation=True,
                                        minibatch_size=sched.minibatch // submit_config.num_gpus)
                fake_path = os.path.join(submit_config.run_dir,
                                         'fakes%06d.png' % snapshot_kimg)
                ifake_path = os.path.join(submit_config.run_dir,
                                          'ifakes%06d.png' % snapshot_kimg)
                misc.save_image_grid(grid_fakes,
                                     fake_path,
                                     drange=drange_net,
                                     grid_size=grid_size)
                misc.save_image_grid(grid_fakes_inst,
                                     ifake_path,
                                     drange=drange_net,
                                     grid_size=grid_size)
                summary_log.add_summary(
                    build_image_summary(fake_path, 'samples/Gs'), snapshot_kimg)
                summary_log.add_summary(
                    build_image_summary(ifake_path, 'samples/G'), snapshot_kimg)

            # Generate/Save Interpolation Visualizations:
            if interp_snapshot_ticks != -1 and cur_tick % interp_snapshot_ticks == 0:
                vis_tools.make_and_save_interpolation_gifs(
                    Gs,
                    interp_z,
                    interp_labels,
                    minibatch_size=sched.minibatch // submit_config.num_gpus,
                    interp_steps=interp_steps,
                    interp_batch_size=interp_batch_size,
                    cur_kimg=cur_nimg // 1000,
                    summary_log=summary_log)

            # Save snapshot and run metrics:
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(
                    submit_config.run_dir,
                    'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
            # NOTE(review): metrics use the most recently saved pkl.  When
            # network_metric_ticks is not a multiple of network_snapshot_ticks
            # this can be an older snapshot -- confirm this is intended.
            if cur_tick % network_metric_ticks == 0 or done or cur_tick == 1:
                metrics.run(pkl,
                            dataset_args=metrics_dataset_args,
                            mirror_augment=mirror_augment,
                            num_gpus=submit_config.num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod,
                       cur_epoch=cur_nimg // 1000,
                       max_epoch=total_kimg)
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    misc.save_pkl((G, D, Gs),
                  os.path.join(submit_config.run_dir,
                               'network-snapshot-%06d.pkl' % total_kimg))
    summary_log.close()
    ctx.close()
def training_loop(
        submit_config,
        #---------------------------------------------------------------
        # Modified by Deng et al.
        noise_dim=32,  # Width of the noise vector appended to the coefficient vector fed to G.
        weight_args={},  # Forwarded unchanged to the training-stage loss function.
        train_stage_args={},  # Keyword arguments for dnnlib.util.call_func_by_name() and restore_weights_and_initialize().
        #---------------------------------------------------------------
        G_args={},  # Options for generator network.
        D_args={},  # Options for discriminator network.
        G_opt_args={},  # Options for generator optimizer.
        D_opt_args={},  # Options for discriminator optimizer.
        G_loss_args={},  # Options for generator loss.
        D_loss_args={},  # Options for discriminator loss.
        dataset_args={},  # Options for dataset.load_dataset().
        sched_args={},  # Options for train.TrainingSchedule.
        grid_args={},  # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],  # Options for MetricGroup.
        tf_config={},  # Options for tflib.init_tf().
        G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
        D_repeats=1,  # How many times the discriminator is trained per G iteration.
        minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
        reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg=15000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=True,  # Enable mirror augment?
        drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=1,  # How often to export image snapshots?
        network_snapshot_ticks=10,  # How often to export network snapshots?
        save_tf_graph=True,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_run_id=87,  # Run ID or network pkl to resume training from, None = start from scratch.
        resume_snapshot=2364,  # Snapshot index to resume training from, None = autodetect.
        resume_kimg=2364,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
        **_kwargs):
    """Train G/D with a face-render conditioned loss and export periodic snapshots.

    The function loads (or constructs) the G, D and Gs networks, builds one
    loss/gradient sub-graph per GPU, and then alternates discriminator and
    generator updates until ``total_kimg`` thousand images have been shown or
    the run context asks to stop. Once per tick it prints progress, writes an
    image grid (generated images concatenated with rendered faces), optionally
    pickles ``(G, D, Gs)`` and runs the metric group on that pickle.

    The dict/list defaults are shared between calls; this function only reads
    them (they are splatted with ``**`` or passed through), so do not mutate
    them from callers either.

    Extra keyword arguments are accepted via ``**_kwargs`` and ignored.
    """
    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **dataset_args)

    # Create 3d face reconstruction block
    FaceRender = Face3D()

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...')
            #---------------------------------------------------------------
            # Modified by Deng et al.
            # G's latent input is a 254-wide vector followed by noise_dim noise values.
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              latent_size=254 + noise_dim,
                              **G_args)
            #---------------------------------------------------------------
            D = tflib.Network('D',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **D_args)
            Gs = G.clone('Gs')
    G.print_layers()
    D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        resolution = tf.placeholder(tf.float32, name='resolution', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        # Per-step decay of the Gs moving average; 0.0 disables averaging.
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_in, tf.float32),
                              G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0

    G_opt = tflib.Optimizer(name='TrainG', learning_rate=lrate_in, **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', learning_rate=lrate_in, **D_opt_args)
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % (gpu)):
            # GPU 0 trains the master networks; other GPUs use shadow clones.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in)
            ]
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(reals, lod_in, mirror_augment,
                                  training_set.dynamic_range, drange_net)
            #---------------------------------------------------------------
            # Modified by Deng et al.
            # The loss function itself is selected by train_stage_args
            # (presumably via its 'func_name' entry -- confirm against caller).
            G_loss, D_loss = dnnlib.util.call_func_by_name(
                FaceRender=FaceRender,
                noise_dim=noise_dim,
                weight_args=weight_args,
                G_gpu=G_gpu,
                D_gpu=D_gpu,
                G_opt=G_opt,
                D_opt=D_opt,
                training_set=training_set,
                G_loss_args=G_loss_args,
                D_loss_args=D_loss_args,
                lod_assign_ops=lod_assign_ops,
                reals=reals,
                labels=labels,
                minibatch_split=minibatch_split,
                resolution=resolution,
                drange_net=drange_net,
                lod_in=lod_in,
                **train_stage_args)
            #---------------------------------------------------------------
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            # Memory statistics op unavailable; report zero instead.
            peak_gpu_mem_op = tf.constant(0)

    #---------------------------------------------------------------
    # Modified by Deng et al.
    restore_weights_and_initialize(train_stage_args)

    print('Setting up snapshot image grid...')
    # The schedule at the end of training fixes the grid's lod / minibatch.
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              num_gpus=submit_config.num_gpus,
                              **sched_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        G, training_set, **grid_args)
    num_grid_images = np.prod(grid_size)

    grid_latents = tf.random_normal([num_grid_images, 128 + 32 + 16 + 3])
    grid_INPUTcoeff = z_to_lambda_mapping(grid_latents)
    # Three zero columns are appended before rendering
    # (presumably a zero translation -- TODO confirm against Face3D).
    grid_INPUTcoeff_w_t = tf.concat(
        [grid_INPUTcoeff, tf.zeros([num_grid_images, 3])], axis=1)
    with tf.name_scope('FaceRender'):
        grid_render_img, _, _, _ = FaceRender.Reconstruction_Block(
            grid_INPUTcoeff_w_t, 256, num_grid_images, progressive=False)
        # Move the last axis to position 1 to match the network image layout.
        grid_render_img = tf.transpose(grid_render_img, perm=[0, 3, 1, 2])
        grid_render_img = process_reals(grid_render_img, lod_in, False,
                                        training_set.dynamic_range, drange_net)

    # Evaluate coefficients and renders once; they are reused for every snapshot.
    grid_INPUTcoeff_, grid_renders = tflib.run(
        [grid_INPUTcoeff, grid_render_img], {lod_in: sched.lod})
    # Fix: draw noise_dim columns (was hard-coded to 32) so the grid input
    # stays consistent with latent_size=254 + noise_dim used to build G.
    grid_noise = np.random.randn(num_grid_images, noise_dim)
    grid_INPUTcoeff_w_noise = np.concatenate([grid_INPUTcoeff_, grid_noise], axis=1)

    grid_fakes = Gs.run(grid_INPUTcoeff_w_noise,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch // submit_config.num_gpus)
    # Put each generated image side by side with its rendered face.
    grid_fakes = np.concatenate([grid_fakes, grid_renders], axis=3)
    misc.save_image_grid(grid_fakes,
                         os.path.join(submit_config.run_dir,
                                      'fakes%06d.png' % resume_kimg),
                         drange=drange_net,
                         grid_size=grid_size)
    misc.save_image_grid(grid_reals,
                         os.path.join(submit_config.run_dir, 'reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    #---------------------------------------------------------------

    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  num_gpus=submit_config.num_gpus,
                                  **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus, sched.lod)
        if reset_opt_for_new_lod:
            # Reset whenever the lod crosses an integer boundary.
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch,
                        resolution: sched.resolution
                    })
                # Progress is counted in images shown to the discriminator.
                cur_nimg += sched.minibatch
            tflib.run(
                [G_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch,
                    resolution: sched.resolution
                })

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                #---------------------------------------------------------------
                # Modified by Deng et al.
                grid_fakes = Gs.run(grid_INPUTcoeff_w_noise,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch // submit_config.num_gpus)
                grid_fakes = np.concatenate([grid_fakes, grid_renders], axis=3)
                misc.save_image_grid(grid_fakes,
                                     os.path.join(
                                         submit_config.run_dir,
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
                #---------------------------------------------------------------
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(
                    submit_config.run_dir,
                    'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl,
                            run_dir=submit_config.run_dir,
                            num_gpus=submit_config.num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod,
                       cur_epoch=cur_nimg // 1000,
                       max_epoch=total_kimg)
            # Time spent in the tick bookkeeping above (excludes training time).
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    misc.save_pkl((G, D, Gs),
                  os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()
    ctx.close()

#----------------------------------------------------------------------------
def training_loop(
        submit_config,
        G_args={},  # Options for generator network.
        D_args={},  # Options for discriminator network.
        G_opt_args={},  # Options for generator optimizer.
        D_opt_args={},  # Options for discriminator optimizer.
        G_loss_args={},  # Options for generator loss.
        D_loss_args={},  # Options for discriminator loss.
        dataset_args={},  # Options for dataset.load_dataset().
        sched_args={},  # Options for train.TrainingSchedule.
        grid_args={},  # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],  # Options for MetricGroup.
        tf_config={},  # Options for tflib.init_tf().
        G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
        D_repeats=1,  # How many times the discriminator is trained per G iteration.
        minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
        total_kimg=15000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=False,  # Enable mirror augment?
        drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=1,  # How often to export image snapshots?
        network_snapshot_ticks=10,  # How often to export network snapshots?
        save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
        resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
        resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
):
    """Run GAN training with a fixed schedule and multi-scale image snapshots.

    Networks are loaded from a previous run or built from scratch, a
    loss/gradient graph is created per GPU, and D/G updates alternate until
    ``total_kimg`` thousand images have been processed or the run context
    requests a stop. The training schedule is evaluated once, before the
    loop, and reused unchanged for every step. Each tick prints progress;
    image grids (one file per generator output scale) and network pickles
    are written at the configured tick intervals, and a final pickle is
    written when the loop ends.
    """
    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)
    gpu_count = submit_config.num_gpus
    run_dir = submit_config.run_dir

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **dataset_args)

    # Either resume the three networks from a pickle or create them.
    with tf.device("/gpu:0"):
        if resume_run_id is None:
            print("Constructing networks...")
            channels = training_set.shape[0]
            full_res = training_set.shape[1]
            G = tflib.Network(
                name="G",
                num_inputs=2,  # latents + labels
                num_channels=channels,
                resolution=full_res,
                label_size=training_set.label_size,
                **G_args)
            D = tflib.Network(
                name="D",
                # one input per image scale, plus one for the labels
                num_inputs=int(np.log2(full_res)) - 1 + 1,
                num_channels=channels,
                resolution=full_res,
                label_size=training_set.label_size,
                **D_args)
            Gs = G.clone("Gs")
        else:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
    G.print_layers()
    D.print_layers()

    print("Building TensorFlow graph...")
    with tf.name_scope("Inputs"), tf.device("/cpu:0"):
        lrate_in = tf.placeholder(tf.float32, name="lrate_in", shape=[])
        minibatch_in = tf.placeholder(tf.int32, name="minibatch_in", shape=[])
        minibatch_split = minibatch_in // gpu_count
        if G_smoothing_kimg > 0.0:
            Gs_beta = 0.5**tf.div(tf.cast(minibatch_in, tf.float32),
                                  G_smoothing_kimg * 1000.0)
        else:
            Gs_beta = 0.0

    G_opt = tflib.Optimizer(name="TrainG", learning_rate=lrate_in, **G_opt_args)
    D_opt = tflib.Optimizer(name="TrainD", learning_rate=lrate_in, **D_opt_args)
    for gpu_idx in range(gpu_count):
        with tf.name_scope("GPU%d" % gpu_idx), tf.device("/gpu:%d" % gpu_idx):
            # The first GPU owns the master copies; the rest get shadow clones.
            if gpu_idx == 0:
                gen_replica = G
                disc_replica = D
            else:
                gen_replica = G.clone(G.name + "_shadow")
                disc_replica = D.clone(D.name + "_shadow")

            real_images, real_labels = training_set.get_minibatch_tf()
            real_images = process_reals(
                real_images,
                mirror_augment,
                training_set.dynamic_range,
                drange_net,
                depth=training_set.resolution_log2 - 1,
            )
            with tf.name_scope("G_loss"):
                gen_loss = dnnlib.util.call_func_by_name(
                    G=gen_replica,
                    D=disc_replica,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    **G_loss_args)
            with tf.name_scope("D_loss"):
                disc_loss = dnnlib.util.call_func_by_name(
                    G=gen_replica,
                    D=disc_replica,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=real_images,
                    labels=real_labels,
                    **D_loss_args)
            G_opt.register_gradients(tf.reduce_mean(gen_loss),
                                     gen_replica.trainables)
            D_opt.register_gradients(tf.reduce_mean(disc_loss),
                                     disc_replica.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device("/gpu:0"):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)

    # The schedule is evaluated once and stays fixed for the whole run.
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              **sched_args)
    per_gpu_batch = sched.minibatch_size // gpu_count

    print("Setting up snapshot image grid...")
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(
        G, training_set, **grid_args)
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=per_gpu_batch)

    print("Setting up run dir...")
    # One sub-directory name per generator output, starting at 4x4 and doubling.
    fake_multi_scale_dirs = []
    for res_log2 in range(2, 2 + len(grid_fakes)):
        side = str(2**res_log2)
        fake_multi_scale_dirs.append(os.path.join(run_dir, side + "x" + side))

    def _fake_grid_paths(kimg):
        # Output file for every scale directory at the given progress value.
        return [
            os.path.join(scale_dir, "fakes%06d.png" % kimg)
            for scale_dir in fake_multi_scale_dirs
        ]

    misc.save_image_grid(grid_reals,
                         os.path.join(run_dir, "reals.png"),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    misc.save_image_grids(grid_fakes,
                          _fake_grid_paths(resume_kimg),
                          drange=drange_net,
                          grid_size=grid_size)

    summary_log = tf.summary.FileWriter(run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print("Training...\n")
    ctx.update("", cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg

    # configure the training_set to a proper minibatch size
    training_set.configure(per_gpu_batch)

    # Feed dicts never change because the schedule is fixed.
    disc_feed = {lrate_in: sched.D_lrate, minibatch_in: sched.minibatch_size}
    gen_feed = {lrate_in: sched.G_lrate, minibatch_in: sched.minibatch_size}
    target_nimg = total_kimg * 1000

    while cur_nimg < target_nimg and not ctx.should_stop():
        # Run training ops.
        for _ in range(minibatch_repeats):
            for _ in range(D_repeats):
                tflib.run([D_train_op, Gs_update_op], disc_feed)
                cur_nimg += sched.minibatch_size
            tflib.run([G_train_op], gen_feed)

        # Maintenance happens only when a tick has elapsed or training is over.
        finished = cur_nimg >= target_nimg
        tick_elapsed = cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000
        if not (tick_elapsed or finished):
            continue

        cur_tick += 1
        tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
        tick_start_nimg = cur_nimg
        tick_time = ctx.get_time_since_last_update()
        total_time = ctx.get_time_since_start() + resume_time

        # Report progress.
        progress_fields = (
            autosummary("Progress/tick", cur_tick),
            autosummary("Progress/kimg", cur_nimg / 1000.0),
            autosummary("Progress/minibatch", sched.minibatch_size),
            dnnlib.util.format_time(autosummary("Timing/total_sec", total_time)),
            autosummary("Timing/sec_per_tick", tick_time),
            autosummary("Timing/sec_per_kimg", tick_time / tick_kimg),
            autosummary("Timing/maintenance_sec", maintenance_time),
            autosummary("Resources/peak_gpu_mem_gb",
                        peak_gpu_mem_op.eval() / 2**30),
        )
        print(
            "tick %-5d kimg %-8.1f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f"
            % progress_fields)
        autosummary("Timing/total_hours", total_time / (60.0 * 60.0))
        autosummary("Timing/total_days", total_time / (24.0 * 60.0 * 60.0))

        cur_kimg = cur_nimg // 1000

        # Save snapshots.
        if cur_tick % image_snapshot_ticks == 0 or finished:
            grid_fakes = Gs.run(grid_latents,
                                grid_labels,
                                is_validation=True,
                                minibatch_size=per_gpu_batch)
            misc.save_image_grids(grid_fakes,
                                  _fake_grid_paths(cur_kimg),
                                  drange=drange_net,
                                  grid_size=grid_size)
        if cur_tick % network_snapshot_ticks == 0 or finished or cur_tick == 1:
            snapshot_pkl = os.path.join(run_dir,
                                        "network-snapshot-%06d.pkl" % cur_kimg)
            misc.save_pkl((G, D, Gs), snapshot_pkl)
            metrics.run(snapshot_pkl,
                        run_dir=run_dir,
                        num_gpus=gpu_count,
                        tf_config=tf_config)

        # Update summaries and RunContext.
        metrics.update_autosummaries()
        tflib.autosummary.save_summaries(summary_log, cur_nimg)
        ctx.update(cur_epoch=cur_kimg, max_epoch=total_kimg)
        maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    misc.save_pkl((G, D, Gs), os.path.join(run_dir, "network-final.pkl"))
    summary_log.close()
    ctx.close()
def training_loop(
    G_args                  = {},       # Options for generator network.
    D_args                  = {},       # Options for discriminator network.
    G_opt_args              = {},       # Options for generator optimizer.
    D_opt_args              = {},       # Options for discriminator optimizer.
    G_loss_args             = {},       # Options for generator loss.
    D_loss_args             = {},       # Options for discriminator loss.
    dataset_args            = {},       # Options for dataset.load_dataset().
    sched_args              = {},       # Options for train.TrainingSchedule.
    grid_args               = {},       # Options for train.setup_snapshot_image_grid().
    setname                 = None,     # Model name; its last element and the remainder are used in snapshot file names. Required.
    tf_config               = {},       # Options for tflib.init_tf().
    G_smoothing_kimg        = 10.0,     # Half-life of the running average of generator weights.
    minibatch_repeats       = 4,        # Number of minibatches to run before adjusting training parameters.
    lazy_regularization     = True,     # Perform regularization as a separate training step?
    G_reg_interval          = 4,        # How often to perform regularization for G? Ignored if lazy_regularization=False.
    D_reg_interval          = 16,       # How often to perform regularization for D? Ignored if lazy_regularization=False.
    reset_opt_for_new_lod   = True,     # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg              = 25000,    # Total length of the training, measured in thousands of real images.
    mirror_augment          = False,    # Enable mirror augment?
    mirror_augment_v        = False,    # Enable mirror augment vertically?
    drange_net              = [-1,1],   # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks    = 50,       # How often to save image snapshots? None = never inside the loop.
    network_snapshot_ticks  = 50,       # How often to save network snapshots? None = only save the final snapshot.
    save_tf_graph           = False,    # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms  = False,    # Include weight histograms in the tfevents file?
    resume_pkl              = 'latest', # Network pickle to resume training from, None = train from scratch.
    resume_kimg             = 0.0,      # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time             = 0.0,      # Assumed wallclock time at the beginning. Affects reporting.
    restore_partial_fn      = None,     # Filename of network for partial restore (used when resume_pkl == 'restore_partial').
    resume_with_new_nets    = False):   # Construct new networks according to G_args and D_args before resuming training?
    """Train G/D with optional lazy regularization and gradient accumulation.

    Loads the dataset, constructs or restores ``G``, ``D`` and ``Gs``, builds
    the per-GPU training graph and runs the training loop until ``total_kimg``
    thousand images have been processed or the run context requests a stop.
    Image grids and network pickles are written to the run directory at the
    configured tick intervals, and a final snapshot pair (full ``(G, D, Gs)``
    and ``Gs`` alone) is written when the loop ends.

    Raises:
        ValueError: if ``setname`` is None or empty. Every snapshot file name
            is built from ``setname[-1]`` / ``setname[:-1]``, so without it the
            run would only fail at the first snapshot, after the graph has
            already been built.
    """
    # Fail fast: setname is subscripted for every snapshot file name below.
    if setname is None or len(setname) == 0:
        raise ValueError('training_loop() requires a non-empty `setname` to name snapshot files')

    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(verbose=True, **dataset_args)
    # custom resolution - for saved model name below
    resolution = training_set.resolution
    if training_set.init_res != [4,4]:
        init_res_str = '-%dx%d' % (training_set.init_res[0], training_set.init_res[1])
    else:
        init_res_str = ''
    # 4-channel data is saved as png, everything else as jpg.
    ext = 'png' if training_set.shape[0] == 4 else 'jpg'
    print(' model base resolution', resolution)

    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(training_set, **grid_args)
    misc.save_image_grid(grid_reals, dnnlib.make_run_dir_path('_reals.%s'%ext), drange=training_set.dynamic_range, grid_size=grid_size)

    # Construct or load networks.
    with tf.device('/gpu:0'):
        if resume_pkl is None or resume_with_new_nets:
            print(' Constructing networks...')
            G = tflib.Network('G', num_channels=training_set.shape[0], resolution=resolution, label_size=training_set.label_size, **G_args)
            D = tflib.Network('D', num_channels=training_set.shape[0], resolution=resolution, label_size=training_set.label_size, **D_args)
            Gs = G.clone('Gs')
        if resume_pkl is not None:
            if resume_pkl == 'latest':
                # NOTE(review): this branch only locates the newest pickle and
                # updates resume_pkl / resume_kimg; nothing visible here loads
                # it. Unless resume_with_new_nets built the networks above,
                # G / D / Gs stay undefined -- confirm the intended nesting.
                resume_pkl, resume_kimg = misc.locate_latest_pkl(dnnlib.submit_config.run_dir_root)
            elif resume_pkl == 'restore_partial':
                print(' Restore partially...')
                # Initialize networks
                G = tflib.Network('G', num_channels=training_set.shape[0], resolution=resolution, label_size=training_set.label_size, **G_args)
                D = tflib.Network('D', num_channels=training_set.shape[0], resolution=resolution, label_size=training_set.label_size, **D_args)
                Gs = G.clone('Gs')
                # Load pre-trained networks
                assert restore_partial_fn is not None
                # pickle executes code while loading: only use trusted files here.
                with open(restore_partial_fn, 'rb') as partial_file:
                    G_partial, D_partial, Gs_partial = pickle.load(partial_file)
                # Restore (subset of) pre-trained weights (only parameters that match both name and shape)
                G.copy_compatible_trainables_from(G_partial)
                D.copy_compatible_trainables_from(D_partial)
                Gs.copy_compatible_trainables_from(Gs_partial)
            else:
                # No explicit progress given: resume_pkl is resolved to its latest pickle and kimg.
                if resume_pkl is not None and resume_kimg == 0:
                    resume_pkl, resume_kimg = misc.locate_latest_pkl(resume_pkl)
                print(' Loading networks from "%s", kimg %.3g' % (resume_pkl, resume_kimg))
                rG, rD, rGs = misc.load_pkl(resume_pkl)
                if resume_with_new_nets:
                    G.copy_vars_from(rG)
                    D.copy_vars_from(rD)
                    Gs.copy_vars_from(rGs)
                else:
                    G, D, Gs = rG, rD, rGs

    # Generate initial image snapshot, using the end-of-training schedule.
    sched = training_schedule(cur_nimg=total_kimg*1000, training_set=training_set, **sched_args)
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu)
    misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes_init.%s'%ext), drange=drange_net, grid_size=grid_size)

    # Setup training inputs.
    print(' Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in               = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in             = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in    = tf.placeholder(tf.int32, name='minibatch_size_in', shape=[])
        minibatch_gpu_in     = tf.placeholder(tf.int32, name='minibatch_gpu_in', shape=[])
        # Number of accumulation rounds needed to reach the full minibatch.
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in * num_gpus)
        Gs_beta              = 0.5 ** tf.div(tf.cast(minibatch_size_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0

    # Setup optimizers. Copy the option dicts so the caller's (and the shared
    # default) dicts are not modified by the assignments below.
    G_opt_args = dict(G_opt_args)
    D_opt_args = dict(D_opt_args)
    for args, reg_interval in [(G_opt_args, G_reg_interval), (D_opt_args, D_reg_interval)]:
        args['minibatch_multiplier'] = minibatch_multiplier
        args['learning_rate'] = lrate_in
        if lazy_regularization:
            # Rescale learning rate and Adam betas by interval / (interval + 1).
            mb_ratio = reg_interval / (reg_interval + 1)
            args['learning_rate'] *= mb_ratio
            if 'beta1' in args: args['beta1'] **= mb_ratio
            if 'beta2' in args: args['beta2'] **= mb_ratio
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', **D_opt_args)
    G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args)
    D_reg_opt = tflib.Optimizer(name='RegD', share=D_opt, **D_opt_args)

    # Build training graph for each GPU.
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):

            # Create GPU-specific shadow copies of G and D.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                # The schedule at the resume point sizes the staging variables.
                sched = training_schedule(cur_nimg=int(resume_kimg*1000), training_set=training_set, **sched_args)
                reals_var = tf.Variable(name='reals', trainable=False, initial_value=tf.zeros([sched.minibatch_gpu] + training_set.shape))
                labels_var = tf.Variable(name='labels', trainable=False, initial_value=tf.zeros([sched.minibatch_gpu, training_set.label_size]))
                reals_write, labels_write = training_set.get_minibatch_tf()
                reals_write, labels_write = process_reals(reals_write, labels_write, lod_in, mirror_augment, mirror_augment_v, training_set.dynamic_range, drange_net)
                # Keep the tail of the variable so its shape stays constant
                # when the current per-GPU minibatch is smaller.
                reals_write = tf.concat([reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                labels_write = tf.concat([labels_write, labels_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_var, reals_write)]
                data_fetch_ops += [tf.assign(labels_var, labels_write)]
                reals_read = reals_var[:minibatch_gpu_in]
                labels_read = labels_var[:minibatch_gpu_in]

            # Evaluate loss functions.
            lod_assign_ops = []
            if 'lod' in G_gpu.vars: lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)]
            if 'lod' in D_gpu.vars: lod_assign_ops += [tf.assign(D_gpu.vars['lod'], lod_in)]
            with tf.control_dependencies(lod_assign_ops):
                with tf.name_scope('G_loss'):
                    G_loss, G_reg = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, **G_loss_args)
                with tf.name_scope('D_loss'):
                    D_loss, D_reg = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, reals=reals_read, labels=labels_read, **D_loss_args)

            # Register gradients.
            if not lazy_regularization:
                if G_reg is not None: G_loss += G_reg
                if D_reg is not None: D_loss += D_reg
            else:
                # Regularizers run on their own optimizers, scaled by the interval.
                if G_reg is not None: G_reg_opt.register_gradients(tf.reduce_mean(G_reg * G_reg_interval), G_gpu.trainables)
                if D_reg is not None: D_reg_opt.register_gradients(tf.reduce_mean(D_reg * D_reg_interval), D_gpu.trainables)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()
    G_reg_op = G_reg_opt.apply_updates(allow_no_op=True)
    D_reg_op = D_reg_opt.apply_updates(allow_no_op=True)
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms(); D.setup_weight_histograms()

    print(' Training for %d kimg (%d left) \n' % (total_kimg, total_kimg-resume_kimg))
    dnnlib.RunContext.get().update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1  # -1 forces a maintenance pass after the first iteration.
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop(): break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu) # , sched.lod
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state(); D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops. Note that sched.G_lrate feeds every op in this loop.
        feed_dict = {lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_size_in: sched.minibatch_size, minibatch_gpu_in: sched.minibatch_gpu}
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size, sched.minibatch_gpu * num_gpus)
            run_G_reg = (lazy_regularization and running_mb_counter % G_reg_interval == 0)
            run_D_reg = (lazy_regularization and running_mb_counter % D_reg_interval == 0)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                tflib.run([G_train_op, data_fetch_op], feed_dict)
                if run_G_reg:
                    tflib.run(G_reg_op, feed_dict)
                tflib.run([D_train_op, Gs_update_op], feed_dict)
                if run_D_reg:
                    tflib.run(D_reg_op, feed_dict)

            # Slow path with gradient accumulation.
            else:
                for _round in rounds:
                    tflib.run(G_train_op, feed_dict)
                if run_G_reg:
                    for _round in rounds:
                        tflib.run(G_reg_op, feed_dict)
                tflib.run(Gs_update_op, feed_dict)
                for _round in rounds:
                    tflib.run(data_fetch_op, feed_dict)
                    tflib.run(D_train_op, feed_dict)
                if run_D_reg:
                    for _round in rounds:
                        tflib.run(D_reg_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start() + resume_time
            if sched.lod == 0:
                # Estimate the finish time from the speed of the last tick.
                left_kimg = total_kimg - cur_nimg / 1000
                left_sec = left_kimg * tick_time / tick_kimg
                finaltime = time.asctime(time.localtime(cur_time + left_sec))
                # finaltime[11:16] is the HH:MM part of the asctime() string.
                msg_final = '%ss left till %s ' % (shortime(left_sec), finaltime[11:16])
            else:
                msg_final = ''

            # Report progress.
            print('tick %-4d kimg %-6.1f time %-8s  %s min/tick %-6.3g sec/kimg %-7.3g gpumem %-4.1f lr %.2g ' % (
                autosummary('Progress/tick', cur_tick),
                autosummary('Progress/kimg', cur_nimg / 1000.0),
                dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                msg_final,
                autosummary('Timing/min_per_tick', tick_time / 60),
                autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30),
                sched.G_lrate))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (cur_tick % image_snapshot_ticks == 0 or done):
                grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu)
                misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fake-%04d.%s' % (cur_nimg // 1000, ext)), drange=drange_net, grid_size=grid_size)
            if network_snapshot_ticks is not None and (cur_tick % network_snapshot_ticks == 0 or done):
                # Full (G, D, Gs) snapshot plus a separate pickle holding Gs alone.
                pkl = dnnlib.make_run_dir_path('snapshot-%d-%s%s-%04d.pkl' % (resolution, setname[-1], init_res_str, cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                misc.save_pkl((Gs), dnnlib.make_run_dir_path('%s-%d-%s%s-%04d.pkl' % (setname[:-1], resolution, setname[-1], init_res_str, cur_nimg // 1000)))

            # Update summaries and RunContext.
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get().get_last_update_interval() - tick_time

    # Save final snapshot.
    misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path('snapshot-%d-%s%s-final.pkl' % (resolution, setname[-1], init_res_str)))
    misc.save_pkl((Gs), dnnlib.make_run_dir_path('%s-%d-%s%s-final.pkl' % (setname[:-1], resolution, setname[-1], init_res_str)))

    # All done.
    summary_log.close()
    training_set.close()
def training_loop_vae(
        G_args={},  # Options for generator network.
        E_args={},  # Options for encoder network.
        D_args={},  # Options for discriminator network. None = train without a discriminator.
        G_opt_args={},  # Options for generator optimizer.
        D_opt_args={},  # Options for discriminator optimizer.
        G_loss_args={},  # Options for generator loss.
        D_loss_args={},  # Options for discriminator loss.
        dataset_args={},  # Options for dataset.load_dataset().
        sched_args={},  # Options for train.TrainingSchedule.
        grid_args={},  # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],  # Options for MetricGroup.
        tf_config={},  # Options for tflib.init_tf().
        data_dir=None,  # Directory to load datasets from.
        minibatch_repeats=1,  # Number of minibatches to run before adjusting training parameters.
        total_kimg=25000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=False,  # Enable mirror augment?
        drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=50,  # How often to save image snapshots? None = only save 'reals.png' and 'fakes_init.png'.
        network_snapshot_ticks=50,  # How often to save network snapshots? None = only save 'network-final.pkl'.
        save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_pkl=None,  # Network pickle to resume training from, None = train from scratch.
        resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
        resume_with_new_nets=False,  # Construct new networks according to G_args and D_args before resuming training?
        traversal_grid=False,  # Used for disentangled representation learning.
        n_discrete=0,  # Number of discrete latents in model.
        n_continuous=4,  # Number of continuous latents in model.
        topk_dims_to_show=20,  # Number of top disentangled dimensions to show in a snapshot. <= 0 shows all n_continuous.
        subgroup_sizes_ls=None,  # Summed into the generator input width when forward_eg is set.
        subspace_sizes_ls=None,  # Not referenced anywhere in this function.
        forward_eg=False,  # Widen G's input by sum(subgroup_sizes_ls) and feed append_gfeats(latents, G).
        n_samples_per=10):  # Number of samples for each line in traversal.
    """Train an encoder E and generator G, optionally with a discriminator D.

    E and G are optimized jointly by one optimizer ('TrainG') on G_loss; when
    D_args is not None a second optimizer ('TrainD') trains D on D_loss. Once
    per tick the loop prints progress, and on the configured ticks it pickles
    (E, G[, D]), runs the metric group on that pickle, and saves a grid of
    generated images. A final pickle is written to 'network-final.pkl'.

    NOTE(review): G_args.latent_size and D_args.latent_size are read with
    attribute access, so the default plain-dict values would raise
    AttributeError; callers presumably pass EasyDict-like objects -- confirm.
    """
    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # The discriminator is optional: passing D_args=None disables every D-related step below.
    use_D = D_args is not None

    # Load training set.
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir),
                                        verbose=True,
                                        **dataset_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        training_set, **grid_args)
    # NOTE(review): the outlined copy is bound to grid_fakes and overwritten
    # before it is ever read, while grid_reals is what gets saved. Unless
    # add_outline modifies its argument in place, this call has no visible
    # effect -- confirm whether `grid_reals = add_outline(...)` was intended.
    grid_fakes = add_outline(grid_reals, width=1)
    misc.save_image_grid(grid_reals,
                         dnnlib.make_run_dir_path('reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)

    # Construct or load networks.
    with tf.device('/gpu:0'):
        if resume_pkl is None or resume_with_new_nets:
            print('Constructing networks...')
            E = tflib.Network('E',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              input_shape=[None] + training_set.shape,
                              **E_args)
            # G's input width is the discrete + continuous latent size, plus
            # sum(subgroup_sizes_ls) extra entries when forward_eg is set.
            G = tflib.Network(
                'G',
                num_channels=training_set.shape[0],
                resolution=training_set.shape[1],
                label_size=training_set.label_size,
                input_shape=[None, n_discrete + G_args.latent_size]
                if not forward_eg else [
                    None,
                    n_discrete + G_args.latent_size + sum(subgroup_sizes_ls)
                ],
                **G_args)
            if use_D:
                D = tflib.Network('D',
                                  num_channels=training_set.shape[0],
                                  resolution=training_set.shape[1],
                                  label_size=training_set.label_size,
                                  input_shape=[None, D_args.latent_size],
                                  **D_args)
        if resume_pkl is not None:
            print('Loading networks from "%s"...' % resume_pkl)
            # The pickle is expected to hold (E, G, D) with a discriminator, (E, G) without.
            if use_D:
                rE, rG, rD = misc.load_pkl(resume_pkl)
            else:
                rE, rG = misc.load_pkl(resume_pkl)
            if resume_with_new_nets:
                # Keep the freshly constructed networks and only copy variables into them.
                E.copy_vars_from(rE)
                G.copy_vars_from(rG)
                if use_D:
                    D.copy_vars_from(rD)
            else:
                E = rE
                G = rG
                if use_D:
                    D = rD

    # Print layers and generate initial image snapshot.
    E.print_layers()
    G.print_layers()
    if use_D:
        D.print_layers()
    # Schedule evaluated at the end of training; only minibatch_gpu is used from it here.
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              **sched_args)
    if traversal_grid:
        # No metric results exist yet, so start with the first dimensions in index order.
        if topk_dims_to_show > 0:
            topk_dims = np.arange(min(topk_dims_to_show, n_continuous))
        else:
            topk_dims = np.arange(n_continuous)
        grid_size, grid_latents, grid_labels = get_grid_latents(
            n_discrete, n_continuous, n_samples_per, G, grid_labels, topk_dims)
    else:
        grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    print('grid_size:', grid_size)
    print('grid_latents.shape:', grid_latents.shape)
    print('grid_labels.shape:', grid_labels.shape)
    # get_return_v(..., 8) is unpacked into eight values; only the first (images)
    # and the last (lie_vars) are used.
    grid_fakes, _, _, _, _, _, _, lie_vars = get_return_v(
        G.run(append_gfeats(grid_latents, G) if forward_eg else grid_latents,
              grid_labels,
              is_validation=True,
              minibatch_size=sched.minibatch_gpu,
              randomize_noise=True), 8)
    print('Lie_vars:', lie_vars[0])
    grid_fakes = add_outline(grid_fakes, width=1)
    misc.save_image_grid(grid_fakes,
                         dnnlib.make_run_dir_path('fakes_init.png'),
                         drange=drange_net,
                         grid_size=grid_size)

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32,
                                           name='minibatch_size_in',
                                           shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32,
                                          name='minibatch_gpu_in',
                                          shape=[])
        # Number of per-GPU passes that make up one full minibatch.
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in *
                                                     num_gpus)

    # Setup optimizers.
    # Copy the caller's dicts before adding keys so the arguments are not mutated.
    G_opt_args = dict(G_opt_args)
    G_opt_args['minibatch_multiplier'] = minibatch_multiplier
    G_opt_args['learning_rate'] = lrate_in
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    if use_D:
        D_opt_args = dict(D_opt_args)
        D_opt_args['minibatch_multiplier'] = minibatch_multiplier
        D_opt_args['learning_rate'] = lrate_in
        D_opt = tflib.Optimizer(name='TrainD', **D_opt_args)

    # Build training graph for each GPU.
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):

            # Create GPU-specific shadow copies of E, G and D.
            E_gpu = E if gpu == 0 else E.clone(E.name + '_shadow')
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            if use_D:
                D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                # The buffers are sized from the schedule at the resume point.
                sched = training_schedule(cur_nimg=int(resume_kimg * 1000),
                                          training_set=training_set,
                                          **sched_args)
                reals_var = tf.Variable(
                    name='reals',
                    trainable=False,
                    initial_value=tf.zeros([sched.minibatch_gpu] +
                                           training_set.shape))
                labels_var = tf.Variable(name='labels',
                                         trainable=False,
                                         initial_value=tf.zeros([
                                             sched.minibatch_gpu,
                                             training_set.label_size
                                         ]))
                reals_write, labels_write = training_set.get_minibatch_tf()
                # The lod argument is fixed at 0. here (no lod placeholder exists in this loop).
                reals_write, labels_write = process_reals(
                    reals_write, labels_write, 0., mirror_augment,
                    training_set.dynamic_range, drange_net)
                # Pad the fresh batch with the tail of the old buffer contents
                # so the assigned value keeps the variable's shape.
                reals_write = tf.concat(
                    [reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                labels_write = tf.concat(
                    [labels_write, labels_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_var, reals_write)]
                data_fetch_ops += [tf.assign(labels_var, labels_write)]
                reals_read = reals_var[:minibatch_gpu_in]
                labels_read = labels_var[:minibatch_gpu_in]

            # Evaluate loss functions.
            # The loss functions are resolved by name from G_loss_args / D_loss_args.
            if use_D:
                with tf.name_scope('G_loss'):
                    G_loss = dnnlib.util.call_func_by_name(
                        E=E_gpu,
                        G=G_gpu,
                        D=D_gpu,
                        opt=G_opt,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        reals=reals_read,
                        labels=labels_read,
                        **G_loss_args)
                with tf.name_scope('D_loss'):
                    D_loss = dnnlib.util.call_func_by_name(
                        E=E_gpu,
                        D=D_gpu,
                        opt=D_opt,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        reals=reals_read,
                        labels=labels_read,
                        **D_loss_args)
            else:
                with tf.name_scope('G_loss'):
                    G_loss = dnnlib.util.call_func_by_name(
                        E=E_gpu,
                        G=G_gpu,
                        opt=G_opt,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        reals=reals_read,
                        labels=labels_read,
                        **G_loss_args)

            # Register gradients.
            # E and G share one optimizer: G_loss is differentiated w.r.t. both networks' trainables.
            EG_gpu_trainables = collections.OrderedDict(
                list(E_gpu.trainables.items()) +
                list(G_gpu.trainables.items()))
            G_opt.register_gradients(tf.reduce_mean(G_loss),
                                     EG_gpu_trainables)
            # G_opt.register_gradients(G_loss,
            #                          EG_gpu_trainables)
            if use_D:
                D_opt.register_gradients(tf.reduce_mean(D_loss),
                                         D_gpu.trainables)
                # D_opt.register_gradients(D_loss,
                #                          D_gpu.trainables)

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    G_train_op = G_opt.apply_updates()
    if use_D:
        D_train_op = D_opt.apply_updates()

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            # Fall back to a constant 0 so progress reporting still works.
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        if use_D:
            D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('',
                                   cur_epoch=resume_kimg,
                                   max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1  # -1 forces the maintenance block to run after the first iteration.
    tick_start_nimg = cur_nimg
    prev_lod = -1.0  # Assigned but never read in this function.
    running_mb_counter = 0  # Incremented per minibatch but never read in this function.
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu, 0)

        # Run training ops.
        # The same learning rate (sched.G_lrate) is fed to both optimizers.
        feed_dict = {
            lrate_in: sched.G_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu
        }
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size,
                           sched.minibatch_gpu * num_gpus)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # NOTE(review): G_loss reads reals_read, yet G_train_op runs before
            # data_fetch_op. The very first G step therefore sees the zero
            # initial value of reals_var, and in the slow path every G round
            # reuses the batch fetched during the previous D rounds -- confirm
            # this ordering is intended.

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                tflib.run([G_train_op], feed_dict)
                tflib.run([data_fetch_op], feed_dict)
                if use_D:
                    tflib.run([D_train_op], feed_dict)

            # Slow path with gradient accumulation.
            else:
                for _round in rounds:
                    tflib.run(G_train_op, feed_dict)
                for _round in rounds:
                    tflib.run(data_fetch_op, feed_dict)
                    if use_D:
                        tflib.run(D_train_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start(
            ) + resume_time

            # Report progress.
            # Each value is passed through autosummary and the result is what gets printed.
            print(
                'tick %-5d kimg %-8.1f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/minibatch', sched.minibatch_size),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days',
                        total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' %
                                               (cur_nimg // 1000))
                if use_D:
                    misc.save_pkl((E, G, D), pkl)
                else:
                    misc.save_pkl((E, G), pkl)
                met_outs = metrics.run(pkl,
                                       run_dir=dnnlib.make_run_dir_path(),
                                       data_dir=dnnlib.convert_path(data_dir),
                                       num_gpus=num_gpus,
                                       tf_config=tf_config,
                                       is_vae=True,
                                       use_D=use_D,
                                       Gs_kwargs=dict(is_validation=True))
                # Refresh which latent dimensions the traversal grid shows:
                # rank by the 'tpl_per_dim' metric output when it is available,
                # otherwise fall back to index order.
                if topk_dims_to_show > 0:
                    if 'tpl_per_dim' in met_outs:
                        avg_distance_per_dim = met_outs[
                            'tpl_per_dim']  # presumably one value per continuous latent -- TODO confirm
                        topk_dims = np.argsort(
                            avg_distance_per_dim
                        )[::-1][:topk_dims_to_show]  # largest values first, at most topk_dims_to_show indices
                    else:
                        topk_dims = np.arange(
                            min(topk_dims_to_show, n_continuous))
                else:
                    topk_dims = np.arange(n_continuous)

            if image_snapshot_ticks is not None and (
                    cur_tick % image_snapshot_ticks == 0 or done):
                # topk_dims here is whatever the most recent network snapshot
                # (or the initial setup) computed.
                if traversal_grid:
                    grid_size, grid_latents, grid_labels = get_grid_latents(
                        n_discrete, n_continuous, n_samples_per, G,
                        grid_labels, topk_dims)
                else:
                    # Fresh random latents are drawn for every snapshot.
                    grid_latents = np.random.randn(np.prod(grid_size),
                                                   *G.input_shape[1:])

                grid_fakes, _, _, _, _, _, _, lie_vars = get_return_v(
                    G.run(append_gfeats(grid_latents, G)
                          if forward_eg else grid_latents,
                          grid_labels,
                          is_validation=True,
                          minibatch_size=sched.minibatch_gpu,
                          randomize_noise=True), 8)
                print('Lie_vars:', lie_vars[0])
                grid_fakes = add_outline(grid_fakes, width=1)
                misc.save_image_grid(grid_fakes,
                                     dnnlib.make_run_dir_path(
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % 0,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            # Time spent in this maintenance block, reported on the next tick.
            maintenance_time = dnnlib.RunContext.get(
            ).get_last_update_interval() - tick_time

    # Save final snapshot.
    if use_D:
        misc.save_pkl((E, G, D), dnnlib.make_run_dir_path('network-final.pkl'))
    else:
        misc.save_pkl((E, G), dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
def training_loop(
        submit_config,
        G_args={},  # Options for generator network.
        D_args={},  # Options for discriminator network.
        G_opt_args={},  # Options for generator optimizer.
        D_opt_args={},  # Options for discriminator optimizer.
        G_loss_args={},  # Options for generator loss.
        D_loss_args={},  # Options for discriminator loss.
        dataset_args={},  # Options for dataset.load_dataset().
        sched_args={},  # Options for train.TrainingSchedule.
        grid_args={},  # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],  # Options for MetricGroup.
        tf_config={},  # Options for tflib.init_tf().
        G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
        D_repeats=1,  # How many times the discriminator is trained per G iteration.
        minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
        reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg=15000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=False,  # Enable mirror augment?
        drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=1,  # How often to export image snapshots?
        network_snapshot_ticks=10,  # How often to export network snapshots?
        save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from the pretrained network.
        resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
        resume_kimg=10000.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0):  # Assumed wallclock time at the beginning. Affects reporting.
    """Continue GAN training from an existing (G, D, Gs) network triple.

    The networks are loaded either from the run given by resume_run_id or,
    when that is None, from a hard-coded pretrained pickle URL. G_args and
    D_args are accepted for interface compatibility but are not used, because
    no network is constructed from scratch here.

    Once per tick the loop prints progress. On the configured ticks it writes
    a 'fakes%06d.png' grid into submit_config.run_dir and copies it to the
    'gs://stylegan_out' bucket with gsutil, and it pickles (G, D, Gs) and runs
    the metric group on that pickle. 'network-final.pkl' is written at the end.
    """

    def _export_fakes(images, kimg):
        # Save a grid of generated images and copy it to the GCS bucket.
        png_path = os.path.join(submit_config.run_dir,
                                'fakes%06d.png' % kimg)
        misc.save_image_grid(images,
                             png_path,
                             drange=drange_net,
                             grid_size=grid_size)
        # Pass the arguments as a list instead of a concatenated shell string:
        # a run directory containing spaces or shell metacharacters would
        # otherwise be split or interpreted by the shell.
        try:
            subprocess.run(['gsutil', 'cp', png_path, 'gs://stylegan_out'])
        except OSError as err:
            # The upload is best-effort: with shell=True a missing gsutil only
            # produced a non-zero exit status, so do not abort training here.
            print('Warning: could not run gsutil to upload "%s": %s' %
                  (png_path, err))

    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **dataset_args)

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            # Training from scratch is disabled; the original construction
            # code is kept for reference.
            #print('Constructing networks...')
            #G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args)
            #D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args)
            #Gs = G.clone('Gs')
            url = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ'
            # SECURITY NOTE: unpickling executes arbitrary code from the
            # downloaded file. Only keep this if the URL content is trusted.
            with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f:
                G, D, Gs = pickle.load(f)
            print('Loading pretrained FFHQ network')
    G.print_layers()
    D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        # Per-minibatch decay so that Gs has a half-life of G_smoothing_kimg
        # thousand images; 0.0 makes Gs track G exactly.
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_in, tf.float32),
                              G_smoothing_kimg *
                              1000.0) if G_smoothing_kimg > 0.0 else 0.0

    G_opt = tflib.Optimizer(name='TrainG',
                            learning_rate=lrate_in,
                            **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD',
                            learning_rate=lrate_in,
                            **D_opt_args)
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # GPU 0 uses the real networks; other GPUs get shadow clones.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in)
            ]
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(reals, lod_in, mirror_augment,
                                  training_set.dynamic_range, drange_net)
            # The control dependencies make sure both networks' 'lod'
            # variables are assigned before either loss is evaluated.
            with tf.name_scope('G_loss'), tf.control_dependencies(
                    lod_assign_ops):
                G_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    **G_loss_args)
            with tf.name_scope('D_loss'), tf.control_dependencies(
                    lod_assign_ops):
                D_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals,
                    labels=labels,
                    **D_loss_args)
            G_opt.register_gradients(tf.reduce_mean(G_loss),
                                     G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss),
                                     D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)

    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            # Fall back to a constant 0 so progress reporting still works.
            peak_gpu_mem_op = tf.constant(0)

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(
        G, training_set, **grid_args)
    # Schedule evaluated at the end of training; only the minibatch size is used from it here.
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              num_gpus=submit_config.num_gpus,
                              **sched_args)
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch //
                        submit_config.num_gpus)

    print('Setting up run dir...')
    misc.save_image_grid(grid_reals,
                         os.path.join(submit_config.run_dir, 'reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    _export_fakes(grid_fakes, resume_kimg)
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  num_gpus=submit_config.num_gpus,
                                  **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus,
                               sched.lod)
        if reset_opt_for_new_lod:
            # Reset optimizer state whenever lod crosses an integer boundary.
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch
                    })
                # Progress is counted in images shown to the discriminator.
                cur_nimg += sched.minibatch
            tflib.run(
                [G_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch
                })

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days',
                        total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch //
                                    submit_config.num_gpus)
                _export_fakes(grid_fakes, cur_nimg // 1000)
            # A network snapshot is also forced on the very first tick.
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(
                    submit_config.run_dir,
                    'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl,
                            run_dir=submit_config.run_dir,
                            num_gpus=submit_config.num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod,
                       cur_epoch=cur_nimg // 1000,
                       max_epoch=total_kimg)
            # Time spent in this maintenance block, reported on the next tick.
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    misc.save_pkl((G, D, Gs),
                  os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()

    ctx.close()
def training_loop(
    G_args={},  # Options for generator network.
    D_args={},  # Options for discriminator network.
    G_opt_args={},  # Options for generator optimizer.
    D_opt_args={},  # Options for discriminator optimizer.
    G_loss_args={},  # Options for generator loss.
    D_loss_args={},  # Options for discriminator loss.
    dataset_args={},  # Options for dataset.load_dataset().
    sched_args={},  # Options for train.TrainingSchedule.
    grid_args={},  # Options for train.setup_snapshot_image_grid().
    metric_arg_list=[],  # Options for MetricGroup.
    tf_config={},  # Options for tflib.init_tf().
    data_dir=None,  # Directory to load datasets from.
    G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    lazy_regularization=True,  # Perform regularization as a separate training step?
    G_reg_interval=4,  # How often to perform regularization for G? Ignored if lazy_regularization=False.
    D_reg_interval=16,  # How often to perform regularization for D? Ignored if lazy_regularization=False.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=25000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=50,  # How often to save image snapshots? None = only save 'reals.png' and 'fakes_init.png'.
    network_snapshot_ticks=50,  # How often to save network snapshots? None = only save 'network-final.pkl'.
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_pkl=None,  # Network pickle to resume training from, None = train from scratch.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
    resume_with_new_nets=False,  # Construct new networks according to G_args and D_args before resuming training?
):
    """Train a generator G and discriminator D, keeping a moving average Gs of G.

    Networks are built from G_args / D_args or loaded from resume_pkl. The
    losses named in G_loss_args / D_loss_args each return a (loss, reg) pair.
    With lazy_regularization the reg terms get their own optimizers and run
    only every G_reg_interval / D_reg_interval minibatches; otherwise they are
    added to the main losses. Once per tick the loop prints progress, and on
    the configured ticks it saves a fake-image grid and a pickle of
    (G, D, Gs) on which the metric group is run. 'network-final.pkl' is
    written at the end.
    """
    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir),
                                        verbose=True,
                                        **dataset_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        training_set, **grid_args)
    misc.save_image_grid(
        grid_reals,
        dnnlib.make_run_dir_path("reals.png"),
        drange=training_set.dynamic_range,
        grid_size=grid_size,
    )

    # Construct or load networks.
    with tf.device("/gpu:0"):
        if resume_pkl is None or resume_with_new_nets:
            print("Constructing networks...")
            G = tflib.Network("G",
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **G_args)
            D = tflib.Network("D",
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **D_args)
            Gs = G.clone("Gs")
        if resume_pkl is not None:
            print('Loading networks from "%s"...' % resume_pkl)
            rG, rD, rGs = misc.load_pkl(resume_pkl)
            if resume_with_new_nets:
                # Keep the freshly constructed networks and only copy variables into them.
                G.copy_vars_from(rG)
                D.copy_vars_from(rD)
                Gs.copy_vars_from(rGs)
            else:
                G = rG
                D = rD
                Gs = rGs

    # Print layers and generate initial image snapshot.
    G.print_layers()
    D.print_layers()
    # Schedule evaluated at the end of training; only minibatch_gpu is used from it here.
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              **sched_args)
    # The same grid latents are reused for every image snapshot below.
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    grid_fakes = Gs.run(
        grid_latents,
        grid_labels,
        is_validation=True,
        minibatch_size=sched.minibatch_gpu,
    )
    misc.save_image_grid(
        grid_fakes,
        dnnlib.make_run_dir_path("fakes_init.png"),
        drange=drange_net,
        grid_size=grid_size,
    )

    # Setup training inputs.
    print("Building TensorFlow graph...")
    with tf.name_scope("Inputs"), tf.device("/cpu:0"):
        lod_in = tf.placeholder(tf.float32, name="lod_in", shape=[])
        lrate_in = tf.placeholder(tf.float32, name="lrate_in", shape=[])
        minibatch_size_in = tf.placeholder(tf.int32,
                                           name="minibatch_size_in",
                                           shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32,
                                          name="minibatch_gpu_in",
                                          shape=[])
        # Number of per-GPU passes that make up one full minibatch.
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in *
                                                     num_gpus)
        # Per-minibatch decay so that Gs has a half-life of G_smoothing_kimg
        # thousand images; 0.0 makes Gs track G exactly.
        Gs_beta = (0.5**tf.div(tf.cast(minibatch_size_in, tf.float32),
                               G_smoothing_kimg *
                               1000.0) if G_smoothing_kimg > 0.0 else 0.0)

    # Setup optimizers.
    # Copy the caller's dicts before modifying them so the arguments are not mutated.
    G_opt_args = dict(G_opt_args)
    D_opt_args = dict(D_opt_args)
    for args, reg_interval in [
        (G_opt_args, G_reg_interval),
        (D_opt_args, D_reg_interval),
    ]:
        args["minibatch_multiplier"] = minibatch_multiplier
        args["learning_rate"] = lrate_in
        if lazy_regularization:
            # Scale the learning rate and raise the betas to the power of
            # reg_interval / (reg_interval + 1) when regularization runs as
            # separate steps.
            mb_ratio = reg_interval / (reg_interval + 1)
            args["learning_rate"] *= mb_ratio
            if "beta1" in args:
                args["beta1"] **= mb_ratio
            if "beta2" in args:
                args["beta2"] **= mb_ratio
    G_opt = tflib.Optimizer(name="TrainG", **G_opt_args)
    D_opt = tflib.Optimizer(name="TrainD", **D_opt_args)
    # The regularization optimizers are created with share= pointing at the main ones.
    G_reg_opt = tflib.Optimizer(name="RegG", share=G_opt, **G_opt_args)
    D_reg_opt = tflib.Optimizer(name="RegD", share=D_opt, **D_opt_args)

    # Build training graph for each GPU.
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope("GPU%d" % gpu), tf.device("/gpu:%d" % gpu):

            # Create GPU-specific shadow copies of G and D.
            G_gpu = G if gpu == 0 else G.clone(G.name + "_shadow")
            D_gpu = D if gpu == 0 else D.clone(D.name + "_shadow")

            # Fetch training data via temporary variables.
            with tf.name_scope("DataFetch"):
                # The buffers are sized from the schedule at the resume point.
                sched = training_schedule(cur_nimg=int(resume_kimg * 1000),
                                          training_set=training_set,
                                          **sched_args)
                reals_var = tf.Variable(
                    name="reals",
                    trainable=False,
                    initial_value=tf.zeros([sched.minibatch_gpu] +
                                           training_set.shape),
                )
                labels_var = tf.Variable(
                    name="labels",
                    trainable=False,
                    initial_value=tf.zeros(
                        [sched.minibatch_gpu, training_set.label_size]),
                )
                reals_write, labels_write = training_set.get_minibatch_tf()
                reals_write, labels_write = process_reals(
                    reals_write,
                    labels_write,
                    lod_in,
                    mirror_augment,
                    training_set.dynamic_range,
                    drange_net,
                )
                # Pad the fresh batch with the tail of the old buffer contents
                # so the assigned value keeps the variable's shape.
                reals_write = tf.concat(
                    [reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                labels_write = tf.concat(
                    [labels_write, labels_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_var, reals_write)]
                data_fetch_ops += [tf.assign(labels_var, labels_write)]
                reals_read = reals_var[:minibatch_gpu_in]
                labels_read = labels_var[:minibatch_gpu_in]

            # Evaluate loss functions.
            # Networks without a "lod" variable simply contribute no assign op.
            lod_assign_ops = []
            if "lod" in G_gpu.vars:
                lod_assign_ops += [tf.assign(G_gpu.vars["lod"], lod_in)]
            if "lod" in D_gpu.vars:
                lod_assign_ops += [tf.assign(D_gpu.vars["lod"], lod_in)]
            with tf.control_dependencies(lod_assign_ops):
                with tf.name_scope("G_loss"):
                    # Only the discriminator loss receives the real images and labels.
                    G_loss, G_reg = dnnlib.util.call_func_by_name(
                        G=G_gpu,
                        D=D_gpu,
                        opt=G_opt,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        **G_loss_args)
                with tf.name_scope("D_loss"):
                    D_loss, D_reg = dnnlib.util.call_func_by_name(
                        G=G_gpu,
                        D=D_gpu,
                        opt=D_opt,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        reals=reals_read,
                        labels=labels_read,
                        **D_loss_args)

            # Register gradients.
            if not lazy_regularization:
                # Regularization is folded into the main loss on every step.
                if G_reg is not None:
                    G_loss += G_reg
                if D_reg is not None:
                    D_loss += D_reg
            else:
                # Regularization gets its own optimizer; the term is scaled by
                # the interval because it runs only once per interval.
                if G_reg is not None:
                    G_reg_opt.register_gradients(
                        tf.reduce_mean(G_reg * G_reg_interval),
                        G_gpu.trainables)
                if D_reg is not None:
                    D_reg_opt.register_gradients(
                        tf.reduce_mean(D_reg * D_reg_interval),
                        D_gpu.trainables)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()
    # allow_no_op covers the case where no regularization gradients were registered.
    G_reg_op = G_reg_opt.apply_updates(allow_no_op=True)
    D_reg_op = D_reg_opt.apply_updates(allow_no_op=True)
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)

    # Finalize graph.
    with tf.device("/gpu:0"):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            # Fall back to a constant 0 so progress reporting still works.
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print("Initializing logs...")
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print("Training for %d kimg...\n" % total_kimg)
    dnnlib.RunContext.get().update("",
                                   cur_epoch=resume_kimg,
                                   max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1  # -1 forces the maintenance block to run after the first iteration.
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0  # Counts minibatches to schedule the lazy regularization steps.
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu, sched.lod)
        if reset_opt_for_new_lod:
            # Reset optimizer state whenever lod crosses an integer boundary.
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        # The same learning rate (sched.G_lrate) is fed to every optimizer.
        feed_dict = {
            lod_in: sched.lod,
            lrate_in: sched.G_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu,
        }
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size,
                           sched.minibatch_gpu * num_gpus)
            run_G_reg = lazy_regularization and running_mb_counter % G_reg_interval == 0
            run_D_reg = lazy_regularization and running_mb_counter % D_reg_interval == 0
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                tflib.run([G_train_op, data_fetch_op], feed_dict)
                if run_G_reg:
                    tflib.run(G_reg_op, feed_dict)
                tflib.run([D_train_op, Gs_update_op], feed_dict)
                if run_D_reg:
                    tflib.run(D_reg_op, feed_dict)

            # Slow path with gradient accumulation.
            else:
                for _round in rounds:
                    tflib.run(G_train_op, feed_dict)
                if run_G_reg:
                    for _round in rounds:
                        tflib.run(G_reg_op, feed_dict)
                tflib.run(Gs_update_op, feed_dict)
                # Fresh data is fetched before each discriminator round.
                for _round in rounds:
                    tflib.run(data_fetch_op, feed_dict)
                    tflib.run(D_train_op, feed_dict)
                if run_D_reg:
                    for _round in rounds:
                        tflib.run(D_reg_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = cur_nimg >= total_kimg * 1000
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start(
            ) + resume_time

            # Report progress.
            # Each value is passed through autosummary and the result is what gets printed.
            print(
                "tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f"
                % (
                    autosummary("Progress/tick", cur_tick),
                    autosummary("Progress/kimg", cur_nimg / 1000.0),
                    autosummary("Progress/lod", sched.lod),
                    autosummary("Progress/minibatch", sched.minibatch_size),
                    dnnlib.util.format_time(
                        autosummary("Timing/total_sec", total_time)),
                    autosummary("Timing/sec_per_tick", tick_time),
                    autosummary("Timing/sec_per_kimg", tick_time / tick_kimg),
                    autosummary("Timing/maintenance_sec", maintenance_time),
                    autosummary("Resources/peak_gpu_mem_gb",
                                peak_gpu_mem_op.eval() / 2**30),
                ))
            autosummary("Timing/total_hours", total_time / (60.0 * 60.0))
            autosummary("Timing/total_days",
                        total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (
                    cur_tick % image_snapshot_ticks == 0 or done):
                grid_fakes = Gs.run(
                    grid_latents,
                    grid_labels,
                    is_validation=True,
                    minibatch_size=sched.minibatch_gpu,
                )
                misc.save_image_grid(
                    grid_fakes,
                    dnnlib.make_run_dir_path("fakes%06d.png" %
                                             (cur_nimg // 1000)),
                    drange=drange_net,
                    grid_size=grid_size,
                )
            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path("network-snapshot-%06d.pkl" %
                                               (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                # Metrics are evaluated on the snapshot that was just written.
                metrics.run(
                    pkl,
                    run_dir=dnnlib.make_run_dir_path(),
                    data_dir=dnnlib.convert_path(data_dir),
                    num_gpus=num_gpus,
                    tf_config=tf_config,
                )

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update("%.2f" % sched.lod,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            # Time spent in this maintenance block, reported on the next tick.
            maintenance_time = (
                dnnlib.RunContext.get().get_last_update_interval() - tick_time)

    # Save final snapshot.
    misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path("network-final.pkl"))

    # All done.
    summary_log.close()
    training_set.close()