def training_loop(config: Config):
    timer = Timer()
    print("Start task {}".format(config.task_name))
    strategy = tf.distribute.MirroredStrategy()
    print('Loading Imagenet2012 dataset...')
    # dataset = load_from_h5(root=config.h5root, batch_size=config.batch_size)
    dataset, fixed_img = build_np_dataset(root=config.h5root,
                                          batch_size=config.batch_size,
                                          gpu_nums=config.gpu_nums,
                                          load_in_mem=config.load_in_mem)
    dataset = strategy.experimental_distribute_dataset(dataset)
    dataset = dataset.make_initializable_iterator()
    with strategy.scope():
        global_step = tf.get_variable(
            name='global_step',
            initializer=tf.constant(0),
            trainable=False,
            aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
        print("Constructing networks...")
        fixed_x = tf.placeholder(tf.float32, [None, 128, 128, 3])
        Generator = resnet_biggan.Generator(
            image_shape=[128, 128, 3],
            embed_y=False,
            embed_z=False,
            batch_norm_fn=arch_ops.self_modulated_batch_norm,
            spectral_norm=True)
        Discriminator = resnet_biggan.Discriminator(spectral_norm=True,
                                                    project_y=False)
        # Although Z_embed is defined outside Generator, it lives in the
        # Generator's variable scope and is treated as part of the Generator.
        Z_embed = dense(120, False, name='embed_z', scope=Generator.name)
        D_embed = dense(120, True, name='embed_d', scope='Embed_D')
        learning_rate = tf.train.exponential_decay(config.lr, global_step,
                                                   60000, 0.8, staircase=False)
        Embed_solver = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                              name='d_opt', beta1=0.0,
                                              beta2=config.beta2)
        print("Building tensorflow graph...")

        def train_step(image):
            z = tf.random.normal(
                [config.batch_size // config.gpu_nums, config.dim_z],
                stddev=1.0, name='sample_z')
            w = Z_embed(z)
            fake = Generator(w, y=None, is_training=True)
            fake_out, fake_logits, fake_h = Discriminator(x=fake, y=None,
                                                          is_training=True)
            real_out, real_logits, real_h = Discriminator(x=image, y=None,
                                                          is_training=True)
            fake_w = D_embed(fake_h)
            real_w = D_embed(real_h)
            # x is the reconstruction of image.
            x = Generator(real_w, None, True)
            _, real_logits_, real_h_ = Discriminator(x, None, True)
            d_loss = tf.reduce_mean(tf.nn.relu(1.0 - real_logits_))
            with tf.variable_scope('recon_loss'):
                recon_loss_pixel = tf.reduce_mean(tf.square(real_h - real_h_))
                sample_loss = tf.reduce_mean(
                    tf.square(w - fake_w)) * config.s_loss_scale
            final_loss = (d_loss + sample_loss * config.alpha
                          + recon_loss_pixel * config.beta)
            add_global = global_step.assign_add(1)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies([add_global] + update_ops):
                Embed_opt = Embed_solver.minimize(
                    final_loss, var_list=D_embed.trainable_variables)
                with tf.control_dependencies([Embed_opt]):
                    return (tf.identity(final_loss), tf.identity(d_loss),
                            tf.identity(recon_loss_pixel),
                            tf.identity(sample_loss))

        final_loss, d_loss, r_loss, s_loss = compute_loss(
            train_step, dataset.get_next(), strategy)
        print("Building eval module...")
        with tf.init_scope():
            _, _, fixed_h = Discriminator(fixed_x, None, True)
            fixed_w = D_embed(fixed_h)
            fixed_sample = Generator(z=fixed_w, y=None, is_training=True)
    print('Building init module...')
    with tf.init_scope():
        init = [tf.global_variables_initializer(), dataset.initializer]
        restore_g = [v for v in tf.global_variables()
                     if 'opt' not in v.name
                     and 'beta1_power' not in v.name
                     and 'beta2_power' not in v.name
                     and 'generator' in v.name]
        restore_d = [v for v in tf.global_variables()
                     if 'opt' not in v.name
                     and 'beta1_power' not in v.name
                     and 'beta2_power' not in v.name
                     and 'discriminator' in v.name]
        saver_g = tf.train.Saver(restore_g, restore_sequentially=True)
        saver_d = tf.train.Saver(restore_d, restore_sequentially=True)
        saver_embed = tf.train.Saver(var_list=D_embed.trainable_variables)
    print("Start training...")
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(init)
        print("Restore generator...")
        saver_g.restore(sess, config.restore_g_dir)
        saver_d.restore(sess, config.restore_d_dir)
        save_image_grid(fixed_img, filename=config.model_dir + '/reals.png')
        timer.update()
        print("Setup complete, iterations start; setup took %s"
              % timer.runing_time_format)
        print("Start iterations...")
        for iteration in range(config.total_step):
            final_loss_, d_loss_, r_loss_, s_loss_, lr_ = sess.run(
                [final_loss, d_loss, r_loss, s_loss, learning_rate])
            if iteration % config.print_loss_per_steps == 0:
                timer.update()
                print("step %d, final_loss %f, d_loss %f, r_loss %f, "
                      "s_loss %f, learning_rate %f, consuming time %s"
                      % (iteration, final_loss_, d_loss_, r_loss_, s_loss_,
                         lr_, timer.runing_time_format))
            if iteration % config.eval_per_steps == 0:
                timer.update()
                fixed_ = sess.run(fixed_sample, {fixed_x: fixed_img})
                save_image_grid(fixed_,
                                filename=config.model_dir
                                + '/fakes%06d.png' % iteration)
            if iteration % config.save_per_steps == 0:
                saver_embed.save(sess,
                                 save_path=config.model_dir + '/embed.ckpt',
                                 global_step=iteration,
                                 write_meta_graph=False)
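
# `compute_loss` above comes from elsewhere in this repo and is not defined in
# this file. The sketch below is a hypothetical reconstruction, assuming it
# wraps `strategy.experimental_run_v2` and mean-reduces every loss the replica
# step function returns (the same pattern the classifier loop below writes
# inline); the real helper may differ.
def compute_loss(step_fn, inputs, strategy):
    """Run `step_fn` on each replica and mean-reduce its returned losses."""
    per_replica = strategy.experimental_run_v2(step_fn, args=(inputs,))
    if isinstance(per_replica, (tuple, list)):
        return tuple(
            strategy.reduce(tf.distribute.ReduceOp.MEAN, value, axis=None)
            for value in per_replica)
    return strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica, axis=None)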
def training_loop(config: Config):
    timer = Timer()
    print("Start task {}".format(config.task_name))
    strategy = tf.distribute.MirroredStrategy()
    data_iter_params = {"batch_size": config.batch_size, "seed": config.seed}
    with strategy.scope():
        # ema = tf.train.ExponentialMovingAverage(0.999)
        global_step = tf.get_variable(
            name='global_step',
            initializer=tf.constant(0),
            trainable=False,
            aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
        dataset = get_dataset(
            name=config.dataset,
            seed=config.seed).train_input_fn(params=data_iter_params)
        dataset = strategy.experimental_distribute_dataset(dataset)
        data_iter = dataset.make_initializable_iterator()
        print("Constructing networks...")
        InvMap = invert.InvMap(latent_size=config.dim_z)
        Generator = resnet_biggan.Generator(
            image_shape=[128, 128, 3],
            embed_y=False,
            embed_z=True,
            batch_norm_fn=arch_ops.self_modulated_batch_norm,
            spectral_norm=True)
        Discriminator = resnet_biggan.Discriminator(spectral_norm=True,
                                                    project_y=False)
        I_opt = tf.train.AdamOptimizer(learning_rate=0.0005, name='i_opt',
                                       beta1=0.0, beta2=0.999)
        G_opt = tf.train.AdamOptimizer(learning_rate=0.00001, name='g_opt',
                                       beta1=0.0, beta2=0.999)
        D_opt = tf.train.AdamOptimizer(learning_rate=0.00005, name='d_opt',
                                       beta1=0.0, beta2=0.999)
        train_z = tf.random.normal(
            [config.batch_size // config.gpu_nums, config.dim_z],
            stddev=1.0, name='train_z')
        # eval_z = tf.random.uniform(
        #     [config.batch_size // config.gpu_nums, config.dim_z],
        #     minval=-1.0, maxval=1.0, name='eval_z')
        # eval_z = tf.placeholder(tf.float32, name='eval_z')
        fixed_sample_z = tf.placeholder(tf.float32, name='fixed_sample_z')
        print("Building tensorflow graph...")

        def train_step(training_who="G", step=None, z=None, data=None):
            img, labels = data
            w = InvMap(z)
            samples = Generator(z=w, y=None, is_training=True)
            d_real, d_real_logits, _ = Discriminator(x=img, y=None,
                                                     is_training=True)
            d_fake, d_fake_logits, _ = Discriminator(x=samples, y=None,
                                                     is_training=True)
            d_loss, _, _, g_loss = loss_lib.get_losses(
                d_real=d_real, d_fake=d_fake,
                d_real_logits=d_real_logits, d_fake_logits=d_fake_logits)
            inception_score = tfmetric.call_metric(
                run_dir_root=config.run_dir, name="is", images=samples)
            fid = tfmetric.call_metric(
                run_dir_root=config.run_dir, name="fid",
                reals=img, fakes=samples)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                if training_who == "G":
                    train_op = tf.group(
                        G_opt.minimize(g_loss,
                                       var_list=Generator.trainable_variables,
                                       global_step=step),
                        I_opt.minimize(g_loss,
                                       var_list=InvMap.trainable_variables,
                                       global_step=step))
                    # decay = config.ema_decay * tf.cast(
                    #     tf.greater_equal(step, config.ema_start_step), tf.float32)
                    # with tf.variable_scope('', reuse=tf.AUTO_REUSE):
                    #     ema = tf.train.ExponentialMovingAverage(decay=decay)
                    # with tf.control_dependencies([train_op]):
                    #     train_op = ema.apply(Generator.trainable_variables
                    #                          + InvMap.trainable_variables)
                    with tf.control_dependencies([train_op]):
                        return tf.identity(g_loss), inception_score, fid
                else:
                    train_op = D_opt.minimize(
                        d_loss, var_list=Discriminator.trainable_variables,
                        global_step=step)
                    with tf.control_dependencies([train_op]):
                        return tf.identity(d_loss), inception_score, fid

        # def eval_step(z, data=None):
        #     img, _ = data
        #     # with tf.variable_scope('', reuse=tf.AUTO_REUSE):
        #     #     ema = tf.train.ExponentialMovingAverage(decay=0.999)
        #     #     ema.apply(Generator.trainable_variables + InvMap.trainable_variables)
        #     #
        #     # def ema_getter(getter, name, *args, **kwargs):
        #     #     var = getter(name, *args, **kwargs)
        #     #     ema_var = ema.average(var)
        #     #     if ema_var is None:
        #     #         var_names_without_ema = {"u_var", "accu_mean", "accu_variance",
        #     #                                  "accu_counter", "update_accus"}
        #     #         if name.split("/")[-1] not in var_names_without_ema:
        #     #             logging.warning("Could not find EMA variable for %s.", name)
        #     #         return var
        #     #     return ema_var
        #     # with tf.variable_scope("", values=[z, img], reuse=tf.AUTO_REUSE,
        #     #                        custom_getter=ema_getter):
        #     w = InvMap(z)
        #     sampled = Generator(z=w, y=None, is_training=False)
        #     inception_score = tfmetric.call_metric(run_dir_root=config.run_dir,
        #                                            name="is", images=sampled)
        #     fid = tfmetric.call_metric(run_dir_root=config.run_dir, name="fid",
        #                                reals=img, fakes=sampled)
        #     return inception_score, fid, sampled

        g_loss, d_loss, IS, FID = compute_loss(train_step, strategy,
                                               global_step, train_z, data_iter)
        print("Building eval module...")
        with tf.init_scope():
            # IS, FID, eval_sample = compute_eval(eval_step, strategy, eval_z, data_iter)
            fixed_sample_w = InvMap(fixed_sample_z)
            eval_sample = Generator(z=fixed_sample_w, y=None, is_training=False)
    print('Building init module...')
    with tf.init_scope():
        init = [tf.global_variables_initializer(), data_iter.initializer]
        restore_g = [v for v in tf.global_variables()
                     if 'opt' not in v.name
                     and 'beta1_power' not in v.name
                     and 'beta2_power' not in v.name
                     and 'generator' in v.name]
        restore_d = [v for v in tf.global_variables()
                     if 'opt' not in v.name
                     and 'beta1_power' not in v.name
                     and 'beta2_power' not in v.name
                     and 'discriminator' in v.name]
        saver_g = tf.train.Saver(restore_g, restore_sequentially=True)
        saver_d = tf.train.Saver(restore_d, restore_sequentially=True)
    print("Start training...")
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(init)
        fixed_z = np.random.uniform(
            low=-1.0, high=1.0,
            size=[config.batch_size * 2 // config.gpu_nums, config.dim_z])
        print("Restore generator and discriminator...")
        saver_g.restore(sess, '/ghome/fengrl/gen_ckpt/gen-0')
        saver_d.restore(sess, '/ghome/fengrl/disc_ckpt/disc-0')
        print("Start iterations...")
        for iteration in range(config.total_step):
            # config.disc_iter discriminator updates per generator update.
            for D_repeat in range(config.disc_iter):
                D_loss = sess.run(d_loss)
            G_loss = sess.run(g_loss)
            if iteration % config.print_loss_per_steps == 0:
                print("step %d, G_loss %f, D_loss %f"
                      % (iteration, G_loss, D_loss))
            if iteration % config.eval_per_steps == 0:
                timer.update()
                fixed_sample = sess.run(eval_sample, {fixed_sample_z: fixed_z})
                save_image_grid(fixed_sample,
                                filename=config.model_dir
                                + '/fakes%06d.png' % iteration)
                is_eval, fid_eval = sess.run([IS, FID])
                print("Time %s, fid %f, inception_score %f, G_loss %f, "
                      "D_loss %f, step %d"
                      % (timer.runing_time, fid_eval, is_eval,
                         G_loss, D_loss, iteration))
            if iteration % config.save_per_steps == 0:
                saver_g.save(sess, save_path=config.model_dir + '/gen.ckpt',
                             global_step=iteration, write_meta_graph=False)
                saver_d.save(sess, save_path=config.model_dir + '/disc.ckpt',
                             global_step=iteration, write_meta_graph=False)
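
# This GAN loop calls `compute_loss` with a different signature than the embed
# loop above (strategy, global_step, a latent batch, and the data iterator).
# A hypothetical sketch of that variant, assuming it builds one per-replica
# graph for each player and mean-reduces the losses and metrics; the real
# helper in the repo may differ.
def compute_loss(step_fn, strategy, step, z, data_iter):
    def run(training_who):
        outputs = strategy.experimental_run_v2(
            step_fn,
            kwargs={"training_who": training_who, "step": step, "z": z,
                    "data": data_iter.get_next()})
        return [strategy.reduce(tf.distribute.ReduceOp.MEAN, out, axis=None)
                for out in outputs]

    g_loss, inception_score, fid = run("G")  # Generator + InvMap branch
    d_loss, _, _ = run("D")                  # Discriminator branch
    return g_loss, d_loss, inception_score, fid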
def training_loop(config: Config):
    timer = Timer()
    print("Start task {}".format(config.task_name))
    strategy = tf.distribute.MirroredStrategy()
    print('Loading Imagenet2012 dataset...')
    # dataset = load_from_h5(root=config.h5root, batch_size=config.batch_size)
    dataset, _, fixed_img = datasets.build_data_input_pipeline_from_hdf5(
        root=config.h5root, batch_size=config.batch_size,
        gpu_nums=config.gpu_nums, load_in_mem=config.load_in_mem,
        labeled_per_class=config.labeled_per_class,
        save_index_dir=config.model_dir)
    dataset = strategy.experimental_distribute_dataset(dataset)
    dataset = dataset.make_initializable_iterator()
    eval_dset = datasets.build_eval_dset(config.eval_h5root,
                                         batch_size=config.batch_size,
                                         gpu_nums=config.gpu_nums)
    eval_dset = strategy.experimental_distribute_dataset(eval_dset)
    eval_dset = eval_dset.make_initializable_iterator()
    with strategy.scope():
        global_step = tf.get_variable(
            name='global_step',
            initializer=tf.constant(0),
            trainable=False,
            aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
        print("Constructing networks...")
        Discriminator = resnet_biggan.Discriminator(spectral_norm=True,
                                                    project_y=False)
        Dense = tf.layers.Dense(1000, name='Final_dense')
        learning_rate = tf.train.exponential_decay(config.lr, global_step,
                                                   60000, 0.8, staircase=False)
        Dense_solver = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                              name='e_opt', beta2=config.beta2)
        print("Building tensorflow graph...")

        def train_step(image, label):
            # Train only the linear head on top of discriminator features.
            _, _, w = Discriminator(image, None, True)
            w = Dense(w)
            label = tf.one_hot(label, 1000)
            loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=label,
                                                              logits=w)
            loss = tf.reduce_mean(loss)
            add_global = global_step.assign_add(1)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies([add_global] + update_ops):
                Dense_opt = Dense_solver.minimize(
                    loss, var_list=Dense.trainable_variables)
                with tf.control_dependencies([Dense_opt]):
                    return tf.identity(loss)

        loss_run = strategy.experimental_run_v2(train_step, dataset.get_next())
        loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, loss_run,
                               axis=None)
        print("Building eval module...")

        def eval_step(image, label):
            _, _, w = Discriminator(image, None, is_training=True)
            w = Dense(w)
            p = tf.math.argmax(w, 1)
            p = tf.cast(p, tf.int32)
            precise = tf.reduce_mean(tf.cast(tf.equal(p, label), tf.float32))
            return precise

        precise = strategy.experimental_run_v2(eval_step, eval_dset.get_next())
        precise = strategy.reduce(tf.distribute.ReduceOp.MEAN, precise,
                                  axis=None)
    print('Building init module...')
    with tf.init_scope():
        init = [tf.global_variables_initializer(), dataset.initializer,
                eval_dset.initializer]
        restore_d = [v for v in tf.global_variables()
                     if 'opt' not in v.name
                     and 'beta1_power' not in v.name
                     and 'beta2_power' not in v.name
                     and 'discriminator' in v.name]
        saver_d = tf.train.Saver(restore_d, restore_sequentially=True)
        saver_dense = tf.train.Saver(Dense.trainable_variables,
                                     restore_sequentially=True)
    print("Start training...")
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(init)
        print("Restore discriminator...")
        saver_d.restore(sess, config.restore_d_dir)
        timer.update()
        print("Setup complete, iterations start; setup took %s"
              % timer.runing_time_format)
        print("Start iterations...")
        for iteration in range(config.total_step):
            loss_, lr_ = sess.run([loss, learning_rate])
            if iteration % config.print_loss_per_steps == 0:
                timer.update()
                print("step %d, loss %f, learning_rate %f, consuming time %s"
                      % (iteration, loss_, lr_, timer.runing_time_format))
            if iteration % config.eval_per_steps == 0:
                timer.update()
                print('Starting eval...')
                precise_ = 0.0
                eval_iters = 50000 // config.batch_size
                for _ in range(2 * eval_iters):
                    precise_ += sess.run(precise)
                precise_ = precise_ / (2 * eval_iters)
                timer.update()
                print('Eval consuming time %s' % timer.duration_format)
                print('step %d, precision %f in eval dataset of length %d'
                      % (iteration, precise_, eval_iters * config.batch_size))
            if iteration % config.save_per_steps == 0:
                saver_dense.save(sess,
                                 save_path=config.model_dir + '/dense.ckpt',
                                 global_step=iteration,
                                 write_meta_graph=False)
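
# `Timer` is imported from elsewhere in the repo. A minimal sketch consistent
# with the attributes these loops use (update(), runing_time,
# runing_time_format, duration_format); the real class may track more. The
# misspelled `runing_*` names are kept because the loops reference them.
import time

class Timer:
    def __init__(self):
        self._start = time.time()
        self._last = self._start
        self.runing_time = 0.0  # seconds since construction
        self.duration = 0.0     # seconds since the previous update()

    def update(self):
        now = time.time()
        self.duration = now - self._last
        self.runing_time = now - self._start
        self._last = now

    @staticmethod
    def _fmt(seconds):
        m, s = divmod(int(seconds), 60)
        h, m = divmod(m, 60)
        return '%02d:%02d:%02d' % (h, m, s)

    @property
    def runing_time_format(self):
        return self._fmt(self.runing_time)

    @property
    def duration_format(self):
        return self._fmt(self.duration)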
def training_loop(config: Config):
    timer = Timer()
    print("Start task {}".format(config.task_name))
    strategy = tf.distribute.MirroredStrategy()
    print('Loading Imagenet2012 dataset...')
    # dataset = load_from_h5(root=config.h5root, batch_size=config.batch_size)
    dataset = build_np_dataset(root=config.h5root,
                               batch_size=config.batch_size,
                               gpu_nums=config.gpu_nums)
    dataset = dataset.make_initializable_iterator()
    with strategy.scope():
        global_step = tf.get_variable(
            name='global_step',
            initializer=tf.constant(0),
            trainable=False,
            aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
        # dataset = get_dataset(name=config.dataset,
        #                       seed=config.seed).train_input_fn(params=data_iter_params)
        # dataset = strategy.experimental_distribute_dataset(dataset)
        # data_iter = dataset.make_initializable_iterator()
        print("Constructing networks...")
        # img = tf.placeholder(tf.float32, [None, 128, 128, 3])
        fixed_x = tf.placeholder(tf.float32, [None, 128, 128, 3])
        img = dataset.get_next()
        Encoder = ImagenetModel(resnet_size=50, num_classes=120,
                                name='Encoder')
        VGG_alter = ImagenetModel(resnet_size=50, num_classes=120,
                                  name='vgg_alter')
        Generator = resnet_biggan.Generator(
            image_shape=[128, 128, 3],
            embed_y=False,
            embed_z=False,
            batch_norm_fn=arch_ops.self_modulated_batch_norm,
            spectral_norm=True)
        Discriminator = resnet_biggan.Discriminator(spectral_norm=True,
                                                    project_y=False)
        learning_rate = tf.train.exponential_decay(0.0001, global_step,
                                                   150000 / config.gpu_nums,
                                                   0.8, staircase=False)
        E_solver = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                          name='e_opt', beta2=config.beta2)
        # D_solver = tf.train.AdamOptimizer(learning_rate=learning_rate * 5,
        #                                   name='d_opt', beta1=config.beta1)
        print("Building tensorflow graph...")

        def train_step(image):
            w = Encoder(image, training=True)
            x = Generator(w, y=None, is_training=True)
            # _, real_logits, _ = Discriminator(img, y=None, is_training=True)
            _, fake_logits, _ = Discriminator(x, y=None, is_training=True)
            # real_logits = fp32(real_logits)
            fake_logits = fp32(fake_logits)
            with tf.variable_scope('recon_loss'):
                recon_loss_pixel = tf.reduce_mean(tf.square(x - image))
            adv_loss = (tf.reduce_mean(tf.nn.softplus(-fake_logits))
                        * config.g_loss_scale)
            vgg_real = VGG_alter(image, training=True)
            vgg_fake = VGG_alter(x, training=True)
            feature_scale = tf.cast(tf.reduce_prod(vgg_real.shape[1:]),
                                    dtype=tf.float32)
            vgg_loss = (config.r_loss_scale * tf.nn.l2_loss(vgg_fake - vgg_real)
                        / (config.batch_size * feature_scale))
            e_loss = recon_loss_pixel + adv_loss + vgg_loss
            # with tf.variable_scope('d_loss'):
            #     d_loss_real = tf.reduce_mean(tf.nn.relu(1.0 - real_logits))
            #     d_loss_fake = tf.reduce_mean(tf.nn.relu(1.0 + fake_logits))
            #     d_loss = d_loss_real + d_loss_fake
            add_global = global_step.assign_add(1)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies([add_global] + update_ops):
                E_opt = E_solver.minimize(
                    e_loss, var_list=Encoder.trainable_variables)
                with tf.control_dependencies([E_opt]):
                    # Return every loss term so the outer loop can fetch them;
                    # originally they were train_step locals and were not
                    # actually reachable from the session.run call below.
                    return (tf.identity(e_loss), tf.identity(adv_loss),
                            tf.identity(recon_loss_pixel),
                            tf.identity(vgg_loss))

        e_loss, adv_loss, recon_loss_pixel, vgg_loss = compute_loss(
            train_step, dataset.get_next(), strategy)
        print("Building eval module...")
        with tf.init_scope():
            fixed_w = Encoder(fixed_x, training=False)
            fixed_sample = Generator(z=fixed_w, y=None, is_training=False)
    print('Building init module...')
    with tf.init_scope():
        init = [tf.global_variables_initializer(), dataset.initializer]
        restore_g = [v for v in tf.global_variables()
                     if 'opt' not in v.name
                     and 'beta1_power' not in v.name
                     and 'beta2_power' not in v.name
                     and 'generator' in v.name]
        restore_d = [v for v in tf.global_variables()
                     if 'opt' not in v.name
                     and 'beta1_power' not in v.name
                     and 'beta2_power' not in v.name
                     and 'discriminator' in v.name]
        saver_g = tf.train.Saver(restore_g, restore_sequentially=True)
        saver_d = tf.train.Saver(restore_d, restore_sequentially=True)
        saver_v = tf.train.Saver(VGG_alter.trainable_variables,
                                 restore_sequentially=True)
        saver_e = tf.train.Saver(Encoder.trainable_variables,
                                 restore_sequentially=True)
    print("Start training...")
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(init)
        print("Restore generator and discriminator...")
        saver_g.restore(sess, config.restore_g_dir)
        saver_d.restore(sess, config.restore_d_dir)
        saver_v.restore(sess, config.restore_v_dir)
        timer.update()
        fixed_img = sess.run(dataset.get_next())
        save_image_grid(fixed_img, filename=config.model_dir + '/reals.png')
        print("Setup complete, iterations start; setup took %s"
              % timer.runing_time_format)
        print("Start iterations...")
        for iteration in range(config.total_step):
            e_loss_, adv_loss_, recon_loss_pixel_, vgg_loss_, lr_ = sess.run(
                [e_loss, adv_loss, recon_loss_pixel, vgg_loss, learning_rate])
            if iteration % config.print_loss_per_steps == 0:
                timer.update()
                print("step %d, e_loss %f, adv_loss %f, recon_loss_pixel %f, "
                      "vgg_loss %f, learning_rate %f, consuming time %s"
                      % (iteration, e_loss_, adv_loss_, recon_loss_pixel_,
                         vgg_loss_, lr_, timer.runing_time_format))
            if iteration % config.eval_per_steps == 0:
                timer.update()
                fixed_ = sess.run(fixed_sample, {fixed_x: fixed_img})
                save_image_grid(fixed_,
                                filename=config.model_dir
                                + '/fakes%06d.png' % iteration)
            if iteration % config.save_per_steps == 0:
                saver_e.save(sess,
                             save_path=config.model_dir + '/en.ckpt',
                             global_step=iteration,
                             write_meta_graph=False)
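
# `fp32` and `save_image_grid` are repo utilities not defined in this file.
# Minimal sketches under assumptions: `fp32` is a plain float32 cast, and
# `save_image_grid` tiles an NHWC batch with values in [-1, 1] into a single
# PNG; the real helpers may normalize or pad differently.
import math
import numpy as np
from PIL import Image

def fp32(x):
    # Cast logits back to float32 (relevant if parts of the net run in fp16).
    return tf.cast(x, tf.float32)

def save_image_grid(images, filename):
    """Tile a [N, H, W, C] float batch into a roughly square grid PNG."""
    n, h, w, c = images.shape
    cols = int(math.ceil(math.sqrt(n)))
    rows = int(math.ceil(n / float(cols)))
    grid = np.zeros((rows * h, cols * w, c), dtype=np.float32)
    for idx in range(n):
        r, col = divmod(idx, cols)
        grid[r * h:(r + 1) * h, col * w:(col + 1) * w] = images[idx]
    grid = np.clip((grid + 1.0) * 127.5, 0, 255).astype(np.uint8)
    if c == 1:
        grid = grid[:, :, 0]
    Image.fromarray(grid).save(filename)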