def main(config):
    """Restore a trained Encoder and export latent embeddings to TSV.

    Builds a train-mode input pipeline and an Encoder graph, restores the
    encoder weights from ``config['model_dir']``, then runs
    ``config['total_step']`` batches, writing one embedding row per sample
    to ``embeddings.tsv`` and the matching label to ``labels.tsv``
    (TensorBoard projector format).

    Args:
        config: dict with keys ``'dataset_name'``, ``'batch_size'``,
            ``'dim_z'``, ``'model_dir'`` and ``'total_step'``.
    """
    print('Loading %s dataset...' % config['dataset_name'])
    # NOTE(review): tfds directory and shard count are hard-coded here —
    # confirm this is intentional for the deployment environment.
    dset = get_dataset(config['dataset_name'], '/gdata/tfds', 2)
    dataset = dset.input_fn(config['batch_size'], mode='train')
    dataset = dataset.make_initializable_iterator()
    Encoder = nn.Encoder(config['dim_z'], exceptions=['opt'], name='Encoder')
    image, label = dataset.get_next()
    # Only the sampled latent code z is needed; mu / log-sigma are discarded.
    _, _, z = Encoder(image, is_training=True)
    saver = tf.train.Saver(Encoder.restore_variables)
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run([tf.global_variables_initializer(), dataset.initializer])
        print("Restore Encoder...")
        # NOTE(review): checkpoint step 248000 is hard-coded — confirm.
        saver.restore(sess, config['model_dir'] + '/en.ckpt-248000')
        print('Generate embeddings...')
        # Context managers guarantee both TSV files are closed even if a
        # session run raises mid-loop (the original leaked the handles in
        # that case); newline='' is the documented way to open files that
        # are handed to csv.writer.
        with open(config['model_dir'] + '/embeddings.tsv', 'wt',
                  newline='') as f, \
                open(config['model_dir'] + '/labels.tsv', 'wt',
                     newline='') as g:
            f_writer = csv.writer(f, delimiter='\t')
            g_writer = csv.writer(g, delimiter='\t')
            for _ in tqdm(range(config['total_step'])):
                z_, l_ = sess.run([z, label])
                # One TSV row per sample in the batch.
                for row in z_:
                    f_writer.writerow(row)
                for row in l_:
                    g_writer.writerow([row])
def training_loop(config: Config):
    """Encode the whole dataset and dump (image, label, z) TFRecords.

    Restores a trained VAE encoder (and decoder, for the sample grids),
    renders one set of sample/reconstruction images, then streams every
    batch of the dataset through the encoder, serializing each
    (image, label, representation) triple to ``CelebA64_rep.tfrecords``.
    On dataset exhaustion the accumulated latent statistic is normalized
    and saved to ``sigma2.npy``.
    """
    timer = Timer()
    print('Task name %s' % config.task_name)
    print('Loading %s dataset...' % config.dataset_name)
    dset = get_dataset(config.dataset_name, config.tfds_dir,
                       config.gpu_nums * 2)
    dataset = train_input_fn(dset, config.batch_size)
    dataset = dataset.make_initializable_iterator()
    print("Constructing networks...")
    Encoder = vae.Encoder(config.dim_z, exceptions=['opt'], name='Encoder')
    Decoder = vae.Decoder(dset.image_shape, exceptions=['opt'],
                          name='Decoder')
    print("Building tensorflow graph...")
    image, label = dataset.get_next()
    # Only the sampled latent code z is used; mu / log-sigma are ignored.
    _, _, z = Encoder(image, is_training=True)
    # Per-batch increment of the latent variance statistic; accumulated
    # over the whole epoch below. NOTE(review): compute_sigma2 is defined
    # elsewhere — presumably a per-batch sum, given the final division by
    # (count * batch_size); confirm.
    sigma2_plus = compute_sigma2(z)
    print("Building eval module...")
    # Fixed latent draws and image placeholders used only to render one
    # grid of samples / reconstructions / interpolations before export.
    fixed_z = tf.constant(
        np.random.normal(size=[config.example_nums, config.dim_z]),
        dtype=tf.float32)
    fixed_z0 = tf.constant(
        np.random.normal(size=[config.example_nums, config.dim_z]),
        dtype=tf.float32)
    fixed_z1 = tf.constant(
        np.random.normal(size=[config.example_nums, config.dim_z]),
        dtype=tf.float32)
    fixed_x = tf.placeholder(tf.float32,
                             (config.example_nums, ) + dset.image_shape)
    fixed_x0 = tf.placeholder(tf.float32,
                              (config.example_nums, ) + dset.image_shape)
    fixed_x1 = tf.placeholder(tf.float32,
                              (config.example_nums, ) + dset.image_shape)
    input_dict = {
        'fixed_z': fixed_z,
        'fixed_z0': fixed_z0,
        'fixed_z1': fixed_z1,
        'fixed_x': fixed_x,
        'fixed_x0': fixed_x0,
        'fixed_x1': fixed_x1,
        'num_midpoints': config.num_midpoints
    }

    def sample_step():
        # Collect generated samples, reconstructions, and the raw fixed
        # images into one dict so they can all be fetched in a single run.
        out_dict = generate_sample(Decoder, input_dict)
        out_dict.update(reconstruction_sample(Encoder, Decoder, input_dict))
        out_dict.update({
            'fixed_x': fixed_x,
            'fixed_x0': fixed_x0,
            'fixed_x1': fixed_x1
        })
        return out_dict

    o_dict = sample_step()
    print("Building init module...")
    with tf.init_scope():
        init = [tf.global_variables_initializer(), dataset.initializer]
        saver_e = tf.train.Saver(Encoder.restore_variables)
        saver_d = tf.train.Saver(Decoder.restore_variables)
    print('Starting training...')
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as \
            sess:
        sess.run(init)
        if config.resume:
            print("Restore vae...")
            saver_e.restore(sess, config.restore_e_dir)
            saver_d.restore(sess, config.restore_d_dir)
        timer.update()
        print('Preparing sample utils...')
        # Each call pulls example_nums real images from the iterator; the
        # iterator therefore starts the export loop below mid-epoch.
        fixed_x_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                  config.batch_size)
        fixed_x0_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                   config.batch_size)
        fixed_x1_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                   config.batch_size)
        o_dict_ = sess.run(o_dict, {
            fixed_x: fixed_x_,
            fixed_x0: fixed_x0_,
            fixed_x1: fixed_x1_
        })
        for key in o_dict:
            # 5-D tensors are interpolation sequences (extra step axis);
            # both cases are transposed to channels-first for the grid.
            if o_dict_[key].ndim == 5:
                img = o_dict_[key].transpose([0, 1, 4, 2, 3])
            else:
                img = o_dict_[key].transpose([0, 3, 1, 2])
            save_image_grid(img, config.model_dir + '/%s.jpg' % key)
        print("Completing all work, iteration now start, consuming %s " %
              timer.runing_time_format)
        print("Start iterations...")
        sigma2 = 0.0
        count = 0
        with tf.io.TFRecordWriter(config.model_dir +
                                  '/CelebA64_rep.tfrecords') as writer:
            # Run until the (finite) input pipeline is exhausted; the
            # OutOfRangeError is the normal termination signal.
            while True:
                try:
                    image_, label_, sigma2_plus_, rep_ = sess.run(
                        [image, label, sigma2_plus, z])
                    sigma2 += sigma2_plus_
                    count += 1
                    # One serialized example per sample in the batch.
                    for n in range(image_.shape[0]):
                        tf_example = serialize_example(
                            image_[n], label_[n], rep_[n])
                        writer.write(tf_example)
                    if count % 100 == 0:
                        timer.update()
                        print('Complete %d bathes, consuming time %s' %
                              (count, timer.runing_time_format))
                except tf.errors.OutOfRangeError:
                    # Normalize the accumulated statistic to a per-sample
                    # average before saving.
                    np.save(config.model_dir + '/sigma2.npy',
                            sigma2 / (count * config.batch_size))
                    print('Done!')
                    break
def training_loop(config: Config):
    """Train a WAE-GAN with a graph-smoothness regularizer on CelebA-64.

    The encoder/decoder are trained against a reconstruction loss, an
    adversarial latent-matching penalty (WAE-GAN style discriminator on
    z vs. a standard normal prior), and a Laplacian smoothness term built
    from k-NN representations stored in the input TFRecords. A short
    moment-matching pretrain phase aligns the aggregate posterior with
    the prior before adversarial training starts. PPL / Lipschitz metrics
    are evaluated periodically against a frozen "valina" encoder.
    """
    timer = Timer()
    opts = wae_opts.config_celebA
    print('Task name %s' % config.task_name)
    print('Loading %s dataset...' % config.dataset_name)
    # Dataset yields (image, rep, label, neighbour, index) tuples where
    # rep/neighbour/index come from a precomputed k-NN embedding file.
    dataset = load_CelebA_KNN_from_record(
        config.record_dir + '/CelebA64knn5_rep.tfrecords', config.batch_size)
    dataset = dataset.make_initializable_iterator()
    # Bandwidth of the heat-kernel smoother, derived so that weight
    # laplace_a corresponds to unit squared distance.
    laplace_sigma2 = 1.0 / (-np.log(config.laplace_a))
    global_step = tf.get_variable(
        name='global_step',
        initializer=tf.constant(0),
        trainable=False,
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
    print("Constructing networks...")
    Encoder = vae.Encoder(opts, exceptions=['opt'], name='VAE_En')
    Decoder = vae.Decoder(opts, exceptions=['opt'], name='VAE_De')
    Discriminator = vae.Discriminator(opts, exceptions=['opt'])
    # Frozen pretrained encoder used as the perceptual metric for PPL;
    # restored from config.restore_s_dir below and never trained here.
    valina_encoder = vae_.Encoder(256, exceptions=['opt'], name='Encoder')

    def lip_metric(inputs):
        # Identity metric: Lipschitz PPL is measured in input space.
        return inputs

    def d_metric(inputs):
        # Perceptual metric: distances measured in the frozen encoder's
        # latent space.
        _, _, outputs = valina_encoder(inputs, True)
        return outputs

    def generator(inputs):
        outputs = Decoder(inputs, True, False)
        return outputs

    def lip_generator(inputs):
        # For the Lipschitz metric the "generator" is the encoder itself.
        _, _, outputs = Encoder(inputs, True)
        return outputs

    PPL = ppl.PPL_mnist(epsilon=0.01,
                        sampling='full',
                        generator=generator,
                        d_metric=d_metric)
    Lip_PPL = ppl.PPL_mnist(epsilon=0.01,
                            sampling='full',
                            generator=lip_generator,
                            d_metric=lip_metric)
    learning_rate = tf.train.exponential_decay(config.lr, global_step,
                                               config.decay_step,
                                               config.decay_coef,
                                               staircase=False)
    solver = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                    name='opt',
                                    beta2=config.beta2)
    # Discriminator trains at twice the generator learning rate.
    adv_solver = tf.train.AdamOptimizer(learning_rate=2 * learning_rate,
                                        name='opt',
                                        beta1=opts['adam_beta1'])
    print("Building tensorflow graph...")

    def train_step(data):
        image, rep, label, neighbour, index = data
        mu_z, log_sigma_z, z = Encoder(image, is_training=True)
        x = Decoder(z, is_training=True, flatten=False)
        with tf.variable_scope('reconstruction_loss'):
            # recon_loss = - tf.reduce_mean(tf.reduce_sum(
            #     image * tf.log(x + EPS) + (1 - image) * tf.log(1 - x + EPS), [1, 2, 3]))
            # Squared-error reconstruction with a fixed 0.05 weight.
            recon_loss = 0.05 * tf.reduce_mean(
                tf.reduce_sum(tf.square(image - x), [1, 2, 3]))
        with tf.variable_scope('smooth_loss'):
            # Heat-kernel weights between in-batch neighbours, masked to
            # actual k-NN pairs; penalizes latent distance between
            # neighbouring samples (graph Laplacian regularizer).
            mask = make_mask(neighbour, index)
            s_w = mask * smoother_weight(
                rep, 'heat', sigma2=laplace_sigma2, mask=mask)
            smooth_loss = batch_laplacian(s_w, z) * config.laplace_lambda
            # Mean weight over the *valid* pairs only (for logging).
            s_w_mean = tf.reduce_mean(
                s_w) * config.batch_size * config.batch_size / (
                    tf.reduce_sum(mask) + EPS)
        with tf.variable_scope('wae_penalty'):
            Pz = tf.random.normal(shape=[config.batch_size, config.dim_z],
                                  mean=0.0,
                                  stddev=1.0)
            logits_Pz = Discriminator(Pz, True)
            logits_Qz = Discriminator(z, True)
            # Discriminator loss: prior samples labelled 1, posterior 0.
            loss_Pz = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=logits_Pz, labels=tf.ones_like(logits_Pz)))
            loss_Qz = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=logits_Qz, labels=tf.zeros_like(logits_Qz)))
            # Non-saturating generator trick: encoder tries to make the
            # discriminator label Qz as 1.
            loss_Qz_trick = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=logits_Qz, labels=tf.ones_like(logits_Qz)))
            loss_adv = config.wae_lambda * (loss_Pz + loss_Qz)
            loss_match = config.wae_lambda * loss_Qz_trick
        # loss = kl_divergence + recon_loss + smooth_loss
        loss = loss_match + recon_loss + smooth_loss
        add_global = global_step.assign_add(1)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        # NOTE(review): add_global appears in BOTH control-dependency
        # chains below, and the training loop fetches l1..l5 and l6 in
        # two separate sess.run calls — so global_step advances twice per
        # iteration, doubling the effective decay rate. Confirm intended.
        with tf.control_dependencies([add_global] + update_ops):
            opt = solver.minimize(loss,
                                  var_list=Encoder.trainable_variables +
                                  Decoder.trainable_variables)
            with tf.control_dependencies([opt]):
                l1, l2, l3, l4, l5 = tf.identity(loss), tf.identity(recon_loss), \
                    tf.identity(loss_match), tf.identity(smooth_loss), tf.identity(s_w_mean)
        with tf.control_dependencies([add_global] + update_ops):
            d_opt = adv_solver.minimize(
                loss_adv, var_list=Discriminator.trainable_variables)
            with tf.control_dependencies([d_opt]):
                l6 = tf.identity(loss_adv)
        return l1, l2, l3, l4, l5, l6

    loss, r_loss, m_loss, s_loss, s_w, a_loss = train_step(dataset.get_next())

    def pretrain(data):
        # Moment-matching warm-up: align the aggregate posterior's mean
        # and covariance with the standard normal prior before the
        # adversarial phase.
        image, rep, label, neighbour, index = data
        mu_z, log_sigma_z, z = Encoder(image, is_training=True)
        Pz = tf.random.normal(shape=[config.batch_size, config.dim_z],
                              mean=0.0,
                              stddev=1.0)
        mean_pz = tf.reduce_mean(Pz, axis=0, keep_dims=True)
        mean_qz = tf.reduce_mean(z, axis=0, keep_dims=True)
        mean_loss = tf.reduce_mean(tf.square(mean_pz - mean_qz))
        # Unbiased sample covariances (divide by N - 1).
        cov_pz = tf.matmul(Pz - mean_pz, Pz - mean_pz,
                           transpose_a=True) / (config.batch_size - 1)
        cov_qz = tf.matmul(z - mean_qz, z - mean_qz,
                           transpose_a=True) / (config.batch_size - 1)
        cov_loss = tf.reduce_mean(tf.square(cov_pz - cov_qz))
        pretrain_loss = cov_loss + mean_loss
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            opt = solver.minimize(pretrain_loss,
                                  var_list=Encoder.trainable_variables)
            with tf.control_dependencies([opt]):
                p_loss = tf.identity(pretrain_loss)
        return p_loss

    p_loss = pretrain(dataset.get_next())
    print("Building eval module...")
    # Fixed latent draws / image placeholders for the periodic sample,
    # reconstruction, and interpolation grids.
    fixed_z = tf.constant(
        np.random.normal(size=[config.example_nums, config.dim_z]),
        dtype=tf.float32)
    fixed_z0 = tf.constant(
        np.random.normal(size=[config.example_nums, config.dim_z]),
        dtype=tf.float32)
    fixed_z1 = tf.constant(
        np.random.normal(size=[config.example_nums, config.dim_z]),
        dtype=tf.float32)
    # Image shape is hard-coded to CelebA 64x64x3 in this variant.
    fixed_x = tf.placeholder(tf.float32,
                             (config.example_nums, ) + (64, 64, 3))
    fixed_x0 = tf.placeholder(tf.float32,
                              (config.example_nums, ) + (64, 64, 3))
    fixed_x1 = tf.placeholder(tf.float32,
                              (config.example_nums, ) + (64, 64, 3))
    input_dict = {
        'fixed_z': fixed_z,
        'fixed_z0': fixed_z0,
        'fixed_z1': fixed_z1,
        'fixed_x': fixed_x,
        'fixed_x0': fixed_x0,
        'fixed_x1': fixed_x1,
        'num_midpoints': config.num_midpoints
    }

    def sample_step():
        # Gather samples, reconstructions, and the raw fixed images so
        # everything is fetched in one sess.run.
        out_dict = generate_sample(Decoder, input_dict)
        out_dict.update(reconstruction_sample(Encoder, Decoder, input_dict))
        out_dict.update({
            'fixed_x': fixed_x,
            'fixed_x0': fixed_x0,
            'fixed_x1': fixed_x1
        })
        return out_dict

    o_dict = sample_step()

    def eval_step(img1, img2):
        # PPL on prior samples, PPL on encoded real images, and the
        # encoder's Lipschitz-style smoothness between image pairs.
        z0 = tf.random.normal(shape=[config.batch_size, config.dim_z],
                              mean=0.0,
                              stddev=1.0)
        z1 = tf.random.normal(shape=[config.batch_size, config.dim_z],
                              mean=0.0,
                              stddev=1.0)
        _, _, img1_z = Encoder(img1, True)
        _, _, img2_z = Encoder(img2, True)
        ppl_sample_loss = PPL(z0, z1)
        ppl_de_loss = PPL(img1_z, img2_z)
        lip_loss = Lip_PPL(img1, img2)
        return ppl_sample_loss, ppl_de_loss, lip_loss

    # Two independent get_next calls: evaluating these ops consumes two
    # extra batches from the shared iterator per eval run.
    img_1, _, _, _, _ = dataset.get_next()
    img_2, _, _, _, _ = dataset.get_next()
    ppl_sa_loss, ppl_de_loss, lip_loss = eval_step(img_1, img_2)
    print("Building init module...")
    with tf.init_scope():
        init = [tf.global_variables_initializer(), dataset.initializer]
        saver_e = tf.train.Saver(Encoder.restore_variables, max_to_keep=10)
        saver_d = tf.train.Saver(Decoder.restore_variables, max_to_keep=10)
        saver_v = tf.train.Saver(valina_encoder.restore_variables)
    print('Starting training...')
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(init)
        # The perceptual-metric encoder is always restored (frozen).
        saver_v.restore(sess, config.restore_s_dir)
        if config.resume:
            print("Restore vae...")
            saver_e.restore(sess, config.restore_e_dir)
            saver_d.restore(sess, config.restore_d_dir)
        timer.update()
        print('Preparing sample utils...')
        fixed_x_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                  config.batch_size)
        fixed_x0_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                   config.batch_size)
        fixed_x1_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                   config.batch_size)
        print("Completing all work, iteration now start, consuming %s " %
              timer.runing_time_format)
        print("Start iterations...")
        print("Start pretraining of Encoder...")
        # Fixed 500-step moment-matching warm-up.
        for iteration in range(500):
            p_loss_ = sess.run(p_loss)
            if iteration % 50 == 0:
                print("Pretrain_step %d, p_loss %f" % (iteration, p_loss_))
        print("Pretraining of Encoder Done! p_loss %f. Now start training..."
              % p_loss_)
        loss_list = []
        r_loss_list = []
        m_loss_list = []
        s_loss_list = []
        a_loss_list = []
        ppl_sa_list = []
        ppl_re_list = []
        lip_list = []
        for iteration in range(config.total_step):
            # Generator/encoder step and discriminator step are run as
            # two separate sess.run calls (each consumes its own batch
            # and, see train_step, each bumps global_step).
            loss_, r_loss_, m_loss_, s_loss_, sw_sum_, lr_ = \
                sess.run([loss, r_loss, m_loss, s_loss, s_w, learning_rate])
            a_loss_ = sess.run(a_loss)
            if iteration % config.print_loss_per_steps == 0:
                loss_list.append(loss_)
                r_loss_list.append(r_loss_)
                m_loss_list.append(m_loss_)
                s_loss_list.append(s_loss_)
                a_loss_list.append(a_loss_)
                timer.update()
                print(
                    "step %d, loss %f, r_loss_ %f, m_loss_ %f, s_loss_ %f, sw %f, a_loss %f "
                    "learning_rate % f, consuming time %s" %
                    (iteration, loss_, r_loss_, m_loss_, s_loss_,
                     np.mean(sw_sum_), a_loss_, lr_,
                     timer.runing_time_format))
            if iteration % 1000 == 0:
                # Average PPL/Lipschitz metrics over 200 batches.
                sa_loss_ = 0.0
                de_loss_ = 0.0
                lip_loss_ = 0.0
                for _ in range(200):
                    sa_p, de_p, lip_p = sess.run(
                        [ppl_sa_loss, ppl_de_loss, lip_loss])
                    sa_loss_ += sa_p
                    de_loss_ += de_p
                    lip_loss_ += lip_p
                # NOTE(review): normalizer is batch_size * 256 although
                # 200 batches were accumulated — confirm the intended
                # sample count.
                sa_loss_ /= config.batch_size * 256
                de_loss_ /= config.batch_size * 256
                lip_loss_ /= config.batch_size * 256
                ppl_re_list.append(de_loss_)
                ppl_sa_list.append(sa_loss_)
                lip_list.append(lip_loss_)
                print("ppl_sample %f, ppl_resample %f, lipschitze %f" %
                      (sa_loss_, de_loss_, lip_loss_))
            if iteration % config.eval_per_steps == 0:
                o_dict_ = sess.run(o_dict, {
                    fixed_x: fixed_x_,
                    fixed_x0: fixed_x0_,
                    fixed_x1: fixed_x1_
                })
                for key in o_dict:
                    if not os.path.exists(config.model_dir +
                                          '/%06d' % iteration):
                        os.makedirs(config.model_dir + '/%06d' % iteration)
                    # 5-D arrays are interpolation sequences; both cases
                    # go to channels-first for the image grid.
                    if o_dict_[key].ndim == 5:
                        img = o_dict_[key].transpose([0, 1, 4, 2, 3])
                    else:
                        img = o_dict_[key].transpose([0, 3, 1, 2])
                    save_image_grid(
                        img,
                        config.model_dir + '/%06d/%s.jpg' % (iteration, key))
            if iteration % config.save_per_steps == 0:
                saver_e.save(sess,
                             save_path=config.model_dir + '/en.ckpt',
                             global_step=iteration,
                             write_meta_graph=False)
                saver_d.save(sess,
                             save_path=config.model_dir + '/de.ckpt',
                             global_step=iteration,
                             write_meta_graph=False)
                metric_dict = {
                    'r': r_loss_list,
                    'm': m_loss_list,
                    's': s_loss_list,
                    'a': a_loss_list,
                    'psa': ppl_sa_list,
                    'pre': ppl_re_list,
                    'lip': lip_list
                }
                # NOTE(review): path concatenation yields e.g.
                # "<model_dir>/000100metric.npy" (no separator between
                # the step and the filename) — confirm this is intended.
                np.save(config.model_dir + '/%06d' % iteration +
                        'metric.npy', metric_dict)
def training_loop(config: Config):
    """Train a plain autoencoder (squared-error reconstruction only).

    Builds train and eval input pipelines, trains Encoder/Decoder with a
    sigma-weighted MSE reconstruction loss under an exponentially decayed
    Adam schedule, periodically reports eval MSE, saves sample /
    reconstruction / interpolation grids, and checkpoints both networks.
    """
    timer = Timer()
    print('Task name %s' % config.task_name)
    print('Loading %s dataset...' % config.dataset_name)
    dset = get_dataset(config.dataset_name, config.tfds_dir,
                       config.gpu_nums * 2)
    dataset = dset.input_fn(config.batch_size, mode='train')
    dataset = dataset.make_initializable_iterator()
    eval_dataset = dset.input_fn(config.batch_size, mode='eval')
    eval_dataset = eval_dataset.make_initializable_iterator()
    global_step = tf.get_variable(
        name='global_step',
        initializer=tf.constant(0),
        trainable=False,
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
    print("Constructing networks...")
    Encoder = vae.Encoder(config.dim_z, exceptions=['opt'], name='Encoder')
    Decoder = vae.Decoder(dset.image_shape, exceptions=['opt'],
                          name='Decoder')
    learning_rate = tf.train.exponential_decay(config.lr, global_step,
                                               config.decay_step,
                                               config.decay_coef,
                                               staircase=False)
    solver = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                    name='opt',
                                    beta2=config.beta2)
    print("Building tensorflow graph...")

    def train_step(image):
        # One optimizer step; returns the loss evaluated after the update
        # is forced by the control dependency on `opt`.
        mu_z, log_sigma_z, z = Encoder(image, is_training=True)
        x = Decoder(z, is_training=True)
        with tf.variable_scope('reconstruction_loss'):
            recon_loss = config.sigma**2 * tf.reduce_mean(
                tf.reduce_sum(tf.square(image - x), [1, 2, 3]))
        loss = recon_loss
        add_global = global_step.assign_add(1)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies([add_global] + update_ops):
            opt = solver.minimize(loss,
                                  var_list=Encoder.trainable_variables +
                                  Decoder.trainable_variables)
            with tf.control_dependencies([opt]):
                return tf.identity(loss)

    loss = train_step(dataset.get_next()[0])
    print("Building eval module...")
    # Fixed latent draws / image placeholders for the periodic grids.
    fixed_z = tf.constant(
        np.random.normal(size=[config.example_nums, config.dim_z]),
        dtype=tf.float32)
    fixed_z0 = tf.constant(
        np.random.normal(size=[config.example_nums, config.dim_z]),
        dtype=tf.float32)
    fixed_z1 = tf.constant(
        np.random.normal(size=[config.example_nums, config.dim_z]),
        dtype=tf.float32)
    fixed_x = tf.placeholder(tf.float32,
                             (config.example_nums, ) + dset.image_shape)
    fixed_x0 = tf.placeholder(tf.float32,
                              (config.example_nums, ) + dset.image_shape)
    fixed_x1 = tf.placeholder(tf.float32,
                              (config.example_nums, ) + dset.image_shape)
    input_dict = {
        'fixed_z': fixed_z,
        'fixed_z0': fixed_z0,
        'fixed_z1': fixed_z1,
        'fixed_x': fixed_x,
        'fixed_x0': fixed_x0,
        'fixed_x1': fixed_x1,
        'num_midpoints': config.num_midpoints
    }

    def sample_step():
        # Gather samples, reconstructions, and the raw fixed images so
        # everything is fetched in one sess.run.
        out_dict = generate_sample(Decoder, input_dict)
        out_dict.update(reconstruction_sample(Encoder, Decoder, input_dict))
        out_dict.update({
            'fixed_x': fixed_x,
            'fixed_x0': fixed_x0,
            'fixed_x1': fixed_x1
        })
        return out_dict

    def eval_step(image):
        # Plain per-sample reconstruction MSE, no gradient update.
        # NOTE(review): is_training=True during evaluation keeps train-
        # mode behavior (e.g. batch-norm batch statistics) — confirm
        # this is intended.
        mu_z, log_sigma_z, z = Encoder(image, is_training=True)
        x = Decoder(z, is_training=True)
        mse = tf.reduce_mean(
            tf.reduce_sum(tf.square(image - x), axis=[1, 2, 3]))
        return mse

    # BUG FIX: the original fed `dataset.get_next()[0]` (the TRAIN
    # pipeline) here, leaving eval_dataset built and initialized but
    # never consumed; eval MSE was computed on training batches.
    mse = eval_step(eval_dataset.get_next()[0])
    o_dict = sample_step()
    print("Building init module...")
    with tf.init_scope():
        init = [
            tf.global_variables_initializer(), dataset.initializer,
            eval_dataset.initializer
        ]
        saver_e = tf.train.Saver(Encoder.restore_variables)
        saver_d = tf.train.Saver(Decoder.restore_variables)
    print('Starting training...')
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(init)
        if config.resume:
            print("Restore vae...")
            saver_e.restore(sess, config.restore_e_dir)
            saver_d.restore(sess, config.restore_d_dir)
        timer.update()
        print('Preparing sample utils...')
        fixed_x_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                  config.batch_size)
        fixed_x0_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                   config.batch_size)
        fixed_x1_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                   config.batch_size)
        print("Completing all work, iteration now start, consuming %s " %
              timer.runing_time_format)
        print("Start iterations...")
        for iteration in range(config.total_step):
            loss_, lr_ = sess.run([loss, learning_rate])
            if iteration % config.print_loss_per_steps == 0:
                mse_ = sess.run(mse)
                timer.update()
                print("step %d, loss %f, mse %f, "
                      "learning_rate % f, consuming time %s" %
                      (iteration, loss_, mse_, lr_,
                       timer.runing_time_format))
            if iteration % config.eval_per_steps == 0:
                o_dict_ = sess.run(o_dict, {
                    fixed_x: fixed_x_,
                    fixed_x0: fixed_x0_,
                    fixed_x1: fixed_x1_
                })
                for key in o_dict:
                    if not os.path.exists(config.model_dir +
                                          '/%06d' % iteration):
                        os.makedirs(config.model_dir + '/%06d' % iteration)
                    # 5-D arrays are interpolation sequences; both cases
                    # are transposed to channels-first for the grid.
                    if o_dict_[key].ndim == 5:
                        img = o_dict_[key].transpose([0, 1, 4, 2, 3])
                    else:
                        img = o_dict_[key].transpose([0, 3, 1, 2])
                    save_image_grid(
                        img,
                        config.model_dir + '/%06d/%s.jpg' % (iteration, key))
            if iteration % config.save_per_steps == 0:
                saver_e.save(sess,
                             save_path=config.model_dir + '/en.ckpt',
                             global_step=iteration,
                             write_meta_graph=False)
                saver_d.save(sess,
                             save_path=config.model_dir + '/de.ckpt',
                             global_step=iteration,
                             write_meta_graph=False)