def D_hinge(G, D, opt, training_set, minibatch_size, reals, labels):  # pylint: disable=unused-argument
    """Hinge loss for the discriminator.

    Samples latents, generates fake images with ``G``, scores the real and
    the fake images with ``D`` and returns
    ``max(0, 1 + D(fake)) + max(0, 1 - D(real))`` without reducing it.

    ``opt`` and ``training_set`` are not read by this loss; they are kept so
    that the discriminator losses in this file share the same leading
    parameters.
    """
    # Latent batch: leading dim is the minibatch size, the remaining dims
    # come from the generator's first input shape.
    latent_shape = [minibatch_size] + G.input_shapes[0][1:]
    z = tf.random_normal(latent_shape)
    fakes = G.get_output_for(z, labels, is_training=True)

    # Score reals first, then fakes; both results go through fp32().
    real_logits = fp32(D.get_output_for(reals, labels, is_training=True))
    fake_logits = fp32(D.get_output_for(fakes, labels, is_training=True))
    real_logits = autosummary('Loss/scores/real', real_logits)
    fake_logits = autosummary('Loss/scores/fake', fake_logits)

    fake_term = tf.maximum(0., 1. + fake_logits)
    real_term = tf.maximum(0., 1. - real_logits)
    return fake_term + real_term
def D_logistic(G, D, opt, training_set, minibatch_size, reals, labels):  # pylint: disable=unused-argument
    """Logistic (non-saturating GAN) loss for the discriminator.

    Samples latents, generates fake images with ``G``, scores the real and
    the fake images with ``D`` and returns
    ``softplus(D(fake)) + softplus(-D(real))`` without reducing it.

    ``opt`` and ``training_set`` are not read by this loss.
    """
    latent_shape = [minibatch_size] + G.input_shapes[0][1:]
    z = tf.random_normal(latent_shape)
    fakes = G.get_output_for(z, labels, is_training=True)

    real_logits = fp32(D.get_output_for(reals, labels, is_training=True))
    fake_logits = fp32(D.get_output_for(fakes, labels, is_training=True))
    real_logits = autosummary('Loss/scores/real', real_logits)
    fake_logits = autosummary('Loss/scores/fake', fake_logits)

    # softplus(x) == -log(1 - logistic(x))
    fake_term = tf.nn.softplus(fake_logits)
    # softplus(-x) == -log(logistic(x))
    # temporary pylint workaround
    real_term = tf.nn.softplus(-real_logits)  # pylint: disable=invalid-unary-operand-type
    return fake_term + real_term
def D_wgan_gp(
        G, D, opt, training_set, minibatch_size, reals, labels,  # pylint: disable=unused-argument
        wgan_lambda=10.0,  # Weight for the gradient penalty term.
        wgan_epsilon=0.001,  # Weight for the epsilon term, \epsilon_{drift}.
        wgan_target=1.0):  # Target value for gradient magnitudes.
    """WGAN-GP loss for the discriminator.

    The returned (unreduced) loss is the sum of three parts:

    * ``D(fake) - D(real)``;
    * a gradient penalty ``(|grad D(mixed)| - wgan_target)^2`` evaluated on
      random interpolations between real and fake images, weighted by
      ``wgan_lambda / wgan_target**2``;
    * a penalty ``D(real)^2`` weighted by ``wgan_epsilon``.

    ``opt`` is used only to scale the gradient-penalty objective before
    differentiation and to undo that scaling on the resulting gradients.
    ``training_set`` is not read.
    """
    latent_shape = [minibatch_size] + G.input_shapes[0][1:]
    z = tf.random_normal(latent_shape)
    fakes = G.get_output_for(z, labels, is_training=True)

    real_logits = fp32(D.get_output_for(reals, labels, is_training=True))
    fake_logits = fp32(D.get_output_for(fakes, labels, is_training=True))
    real_logits = autosummary('Loss/scores/real', real_logits)
    fake_logits = autosummary('Loss/scores/fake', fake_logits)
    loss = fake_logits - real_logits

    with tf.name_scope('GradientPenalty'):
        # One mixing coefficient per sample, broadcast over the other 3 axes.
        alpha = tf.random_uniform(
            [minibatch_size, 1, 1, 1], 0.0, 1.0, dtype=fakes.dtype)
        reals_cast = tf.cast(reals, fakes.dtype)
        mixed = tflib.lerp(reals_cast, fakes, alpha)

        mixed_logits = fp32(D.get_output_for(mixed, labels, is_training=True))
        mixed_logits = autosummary('Loss/scores/mixed', mixed_logits)

        # Differentiate the loss-scaled sum of scores w.r.t. the mixed
        # images, then undo the scaling on the gradients.
        scaled_sum = opt.apply_loss_scaling(tf.reduce_sum(mixed_logits))
        raw_grads = tf.gradients(scaled_sum, [mixed])[0]
        grads = opt.undo_loss_scaling(fp32(raw_grads))

        sum_sq = tf.reduce_sum(tf.square(grads), axis=[1, 2, 3])
        grad_norms = tf.sqrt(sum_sq)
        grad_norms = autosummary('Loss/mixed_norms', grad_norms)
        gp = tf.square(grad_norms - wgan_target)
    loss += gp * (wgan_lambda / (wgan_target**2))

    with tf.name_scope('EpsilonPenalty'):
        drift = autosummary('Loss/epsilon_penalty', tf.square(real_logits))
    loss += drift * wgan_epsilon
    return loss
def D_logistic_simplegp(G, D, opt, training_set, minibatch_size, reals, labels, r1_gamma=10.0, r2_gamma=0.0):  # pylint: disable=unused-argument
    """Logistic discriminator loss with optional R1 / R2 gradient penalties.

    The base loss is ``softplus(D(fake)) + softplus(-D(real))``.

    * If ``r1_gamma`` is non-zero, the squared gradient of ``D(real)`` with
      respect to ``reals`` (summed over axes 1..3) is added, weighted by
      ``r1_gamma / 2``.
    * If ``r2_gamma`` is non-zero, the same penalty is computed on the
      generated images and weighted by ``r2_gamma / 2``.

    ``opt`` is used only for loss scaling around the gradient computation.
    ``training_set`` is not read. The loss is returned unreduced.
    """
    latent_shape = [minibatch_size] + G.input_shapes[0][1:]
    z = tf.random_normal(latent_shape)
    fakes = G.get_output_for(z, labels, is_training=True)

    real_logits = fp32(D.get_output_for(reals, labels, is_training=True))
    fake_logits = fp32(D.get_output_for(fakes, labels, is_training=True))
    real_logits = autosummary('Loss/scores/real', real_logits)
    fake_logits = autosummary('Loss/scores/fake', fake_logits)

    # softplus(x) == -log(1 - logistic(x))
    loss = tf.nn.softplus(fake_logits)
    # softplus(-x) == -log(logistic(x))
    # temporary pylint workaround
    loss += tf.nn.softplus(-real_logits)  # pylint: disable=invalid-unary-operand-type

    if r1_gamma != 0.0:
        with tf.name_scope('R1Penalty'):
            scaled_real_sum = opt.apply_loss_scaling(tf.reduce_sum(real_logits))
            raw_real_grads = tf.gradients(scaled_real_sum, [reals])[0]
            grads_on_reals = opt.undo_loss_scaling(fp32(raw_real_grads))
            r1_term = tf.reduce_sum(tf.square(grads_on_reals), axis=[1, 2, 3])
            r1_term = autosummary('Loss/r1_penalty', r1_term)
        loss += r1_term * (r1_gamma * 0.5)

    if r2_gamma != 0.0:
        with tf.name_scope('R2Penalty'):
            scaled_fake_sum = opt.apply_loss_scaling(tf.reduce_sum(fake_logits))
            raw_fake_grads = tf.gradients(scaled_fake_sum, [fakes])[0]
            grads_on_fakes = opt.undo_loss_scaling(fp32(raw_fake_grads))
            r2_term = tf.reduce_sum(tf.square(grads_on_fakes), axis=[1, 2, 3])
            r2_term = autosummary('Loss/r2_penalty', r2_term)
        loss += r2_term * (r2_gamma * 0.5)

    return loss
def D_wgan(
        G, D, opt, training_set, minibatch_size, reals, labels,  # pylint: disable=unused-argument
        wgan_epsilon=0.001):  # Weight for the epsilon term, \epsilon_{drift}.
    """WGAN loss for the discriminator, without gradient penalty.

    Returns ``D(fake) - D(real) + wgan_epsilon * D(real)^2`` unreduced.
    ``opt`` and ``training_set`` are not read by this loss.
    """
    latent_shape = [minibatch_size] + G.input_shapes[0][1:]
    z = tf.random_normal(latent_shape)
    fakes = G.get_output_for(z, labels, is_training=True)

    real_logits = fp32(D.get_output_for(reals, labels, is_training=True))
    fake_logits = fp32(D.get_output_for(fakes, labels, is_training=True))
    real_logits = autosummary('Loss/scores/real', real_logits)
    fake_logits = autosummary('Loss/scores/fake', fake_logits)
    loss = fake_logits - real_logits

    with tf.name_scope('EpsilonPenalty'):
        drift = autosummary('Loss/epsilon_penalty', tf.square(real_logits))
    loss += drift * wgan_epsilon
    return loss
def training_loop(
        submit_config,
        G_args=None,  # Options for generator network.
        D_args=None,  # Options for discriminator network.
        G_opt_args=None,  # Options for generator optimizer.
        D_opt_args=None,  # Options for discriminator optimizer.
        G_loss_args=None,  # Options for generator loss.
        D_loss_args=None,  # Options for discriminator loss.
        dataset_args=None,  # Options for dataset.load_dataset().
        sched_args=None,  # Options for train.TrainingSchedule.
        grid_args=None,  # Options for train.setup_snapshot_image_grid().
        metric_arg_list=None,  # Options for MetricGroup.
        tf_config=None,  # Options for tflib.init_tf().
        G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
        D_repeats=1,  # How many times the discriminator is trained per G iteration.
        minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
        reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg=15000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=False,  # Enable mirror augment?
        drange_net=None,  # Dynamic range used when feeding image data to the networks. None = [-1, 1].
        image_snapshot_ticks=1,  # How often to export image snapshots?
        network_snapshot_ticks=10,  # How often to export network snapshots?
        save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
        resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
        resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0):  # Assumed wallclock time at the beginning. Affects reporting.
    """Run the full GAN training loop.

    Steps, in order:

    1. Initialise the run context and TensorFlow, and load the dataset.
    2. Load (when ``resume_run_id`` is given) or construct the ``G``, ``D``
       and ``Gs`` networks.
    3. Build the training graph: per-GPU network clones, G and D losses
       (resolved by name from ``G_loss_args`` / ``D_loss_args``), the two
       optimizers and the moving-average update of ``Gs``.
    4. Save the initial real/fake image grids into the run directory.
    5. Train until ``total_kimg`` thousand images have been processed or the
       run context asks to stop. Once per tick it reports progress, exports
       image and network snapshots, runs the metrics and writes summaries.
    6. Save the final networks and close the summary writer and run context.

    Every dict/list option defaults to ``None``, which is treated as an
    empty dict/list (``[-1, 1]`` for ``drange_net``). A fresh object is
    created per call, so no default is shared between calls or with the
    helpers it is passed to.

    ``image_snapshot_ticks`` and ``network_snapshot_ticks`` are used as
    modulo divisors and must therefore be non-zero.
    """
    # Replace the None sentinels with fresh containers. The original
    # signature used mutable literals as defaults, which are created once and
    # then shared by every call (and by every helper they are handed to).
    G_args = {} if G_args is None else G_args
    D_args = {} if D_args is None else D_args
    G_opt_args = {} if G_opt_args is None else G_opt_args
    D_opt_args = {} if D_opt_args is None else D_opt_args
    G_loss_args = {} if G_loss_args is None else G_loss_args
    D_loss_args = {} if D_loss_args is None else D_loss_args
    dataset_args = {} if dataset_args is None else dataset_args
    sched_args = {} if sched_args is None else sched_args
    grid_args = {} if grid_args is None else grid_args
    metric_arg_list = [] if metric_arg_list is None else metric_arg_list
    tf_config = {} if tf_config is None else tf_config
    drange_net = [-1, 1] if drange_net is None else drange_net

    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **dataset_args)

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...')
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **G_args)
            D = tflib.Network('D',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **D_args)
            Gs = G.clone('Gs')
    G.print_layers()
    D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        # Per-step decay of the Gs running average; 0.0 when smoothing is
        # disabled via G_smoothing_kimg <= 0.
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_in, tf.float32),
                              G_smoothing_kimg *
                              1000.0) if G_smoothing_kimg > 0.0 else 0.0

    G_opt = tflib.Optimizer(name='TrainG',
                            learning_rate=lrate_in,
                            **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD',
                            learning_rate=lrate_in,
                            **D_opt_args)
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # GPU 0 uses the networks themselves; other GPUs use clones.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in)
            ]
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(reals, lod_in, mirror_augment,
                                  training_set.dynamic_range, drange_net)
            # Both losses are built under a control dependency on the lod
            # assignments for this GPU's networks.
            with tf.name_scope('G_loss'), tf.control_dependencies(
                    lod_assign_ops):
                G_loss = gan_streamer.dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    **G_loss_args)
            with tf.name_scope('D_loss'), tf.control_dependencies(
                    lod_assign_ops):
                D_loss = gan_streamer.dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals,
                    labels=labels,
                    **D_loss_args)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            # Fall back to a constant so the progress report still works.
            peak_gpu_mem_op = tf.constant(0)

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(
        G, training_set, **grid_args)
    # Schedule evaluated at the end of training; only its minibatch size is
    # used here, for rendering the initial fake grid.
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              num_gpus=submit_config.num_gpus,
                              **sched_args)
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch //
                        submit_config.num_gpus)

    print('Setting up run dir...')
    misc.save_image_grid(grid_reals,
                         os.path.join(submit_config.run_dir, 'reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    misc.save_image_grid(grid_fakes,
                         os.path.join(submit_config.run_dir,
                                      'fakes%06d.png' % resume_kimg),
                         drange=drange_net,
                         grid_size=grid_size)
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  num_gpus=submit_config.num_gpus,
                                  **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus,
                               sched.lod)
        if reset_opt_for_new_lod:
            # Reset when lod crosses an integer boundary in either direction.
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch
                    })
                # Progress is counted per discriminator step.
                cur_nimg += sched.minibatch
            tflib.run(
                [G_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch
                })

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch),
                   gan_streamer.dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days',
                        total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch //
                                    submit_config.num_gpus)
                misc.save_image_grid(grid_fakes,
                                     os.path.join(
                                         submit_config.run_dir,
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
            # A network snapshot is also taken on the very first tick.
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(
                    submit_config.run_dir,
                    'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl,
                            run_dir=submit_config.run_dir,
                            num_gpus=submit_config.num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            gan_streamer.dnnlib.tflib.autosummary.save_summaries(
                summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod,
                       cur_epoch=cur_nimg // 1000,
                       max_epoch=total_kimg)
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    misc.save_pkl((G, D, Gs), os.path.join(submit_config.run_dir,
                                           'albums.pkl'))
    summary_log.close()
    ctx.close()