def main(config):
    print('Loading %s dataset...' % config['dataset_name'])
    dset = get_dataset(config['dataset_name'], '/gdata/tfds', 2)
    dataset = dset.input_fn(config['batch_size'], mode='train')
    dataset = dataset.make_initializable_iterator()
    Encoder = nn.Encoder(config['dim_z'],
                         config['e_hidden_num'],
                         exceptions=['opt'],
                         name='Encoder')
    image, label = dataset.get_next()
    _, _, z = Encoder(image, is_training=True)
    saver = tf.train.Saver(Encoder.restore_variables)
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run([tf.global_variables_initializer(), dataset.initializer])
        print("Restore Encoder...")
        saver.restore(sess, config['model_dir'] + '/en.ckpt-30000')
        print('Generate embeddings...')
        # Export latent codes and labels as tab-separated files
        # (e.g. for the TensorBoard embedding projector).
        f = open(config['model_dir'] + '/embeddings.tsv', 'wt')
        f_writer = csv.writer(f, delimiter='\t')
        g = open(config['model_dir'] + '/labels.tsv', 'wt')
        g_writer = csv.writer(g, delimiter='\t')
        for _ in tqdm(range(config['total_step'])):
            z_, l_ = sess.run([z, label])
            for row in z_:
                f_writer.writerow(row)
            for row in l_:
                g_writer.writerow([row])
        f.close()
        g.close()
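# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original script): a minimal check that
# the TSV files exported by `main` above are row-aligned. The file names and
# the embedding dimension come from the export loop; `_check_embedding_export`
# itself and its behaviour are illustrative assumptions only.
import numpy as np


def _check_embedding_export(model_dir, dim_z):
    """Load embeddings.tsv / labels.tsv and verify they line up row for row."""
    embeddings = np.loadtxt(model_dir + '/embeddings.tsv', delimiter='\t')
    labels = np.loadtxt(model_dir + '/labels.tsv', delimiter='\t')
    assert embeddings.ndim == 2 and embeddings.shape[1] == dim_z
    assert embeddings.shape[0] == labels.shape[0]
    print('Exported %d embeddings of dimension %d' % embeddings.shape)
    return embeddings, labels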
def training_loop(config: Config):
    timer = Timer()
    print('Task name %s' % config.task_name)
    print('Loading %s dataset...' % config.dataset_name)
    dataset = load_mnist_from_record(
        config.record_dir + '/Mnist20_rep.tfrecords', config.batch_size)
    dataset = dataset.make_initializable_iterator()
    # Heat-kernel bandwidth derived from the precomputed sigma2 statistic.
    laplace_sigma2 = np.load(config.record_dir +
                             '/sigma2.npy') / (-np.log(config.laplace_a))
    global_step = tf.get_variable(name='global_step',
                                  initializer=tf.constant(0),
                                  trainable=False)

    print("Constructing networks...")
    Encoder = vae.Encoder(config.dim_z,
                          config.e_hidden_num,
                          exceptions=['opt'],
                          name='Encoder')
    Decoder = vae.Decoder(config.img_shape,
                          config.d_hidden_num,
                          exceptions=['opt'],
                          name='Decoder')
    learning_rate = tf.train.exponential_decay(config.lr,
                                               global_step,
                                               config.decay_step,
                                               config.decay_coef,
                                               staircase=False)
    solver = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                    name='opt',
                                    beta2=config.beta2)

    print("Building tensorflow graph...")

    def train_step(data):
        image, rep, label = data
        mu_z, log_sigma_z, z = Encoder(image, is_training=True)
        x = Decoder(z, is_training=True, flatten=False)
        with tf.variable_scope('kl_divergence'):
            # KL(q(z|x) || N(0, I)) for a diagonal Gaussian posterior.
            kl_divergence = -tf.reduce_mean(
                tf.reduce_sum(
                    0.5 * (1 + log_sigma_z - mu_z**2 - tf.exp(log_sigma_z)),
                    1))
        with tf.variable_scope('reconstruction_loss'):
            # Bernoulli (binary cross-entropy) reconstruction term.
            recon_loss = -tf.reduce_mean(
                tf.reduce_sum(
                    image * tf.log(x + EPS) +
                    (1 - image) * tf.log(1 - x + EPS), [1, 2, 3]))
        with tf.variable_scope('smooth_loss'):
            # Graph-Laplacian smoothness penalty on the latent codes, weighted
            # by heat-kernel similarities of the precomputed representations.
            s_w = smoother_weight(rep, 'heat', sigma2=laplace_sigma2)
            smooth_loss = batch_laplacian(s_w, z) * config.laplace_lambda
        loss = kl_divergence + recon_loss + smooth_loss
        add_global = global_step.assign_add(1)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies([add_global] + update_ops):
            opt = solver.minimize(loss,
                                  var_list=Encoder.trainable_variables +
                                  Decoder.trainable_variables)
            with tf.control_dependencies([opt]):
                return tf.identity(loss), tf.identity(recon_loss), \
                    tf.identity(kl_divergence), tf.identity(smooth_loss), \
                    tf.identity(s_w)

    loss, r_loss, kl_loss, s_loss, s_w = train_step(dataset.get_next())

    print("Building eval module...")
    fixed_z = tf.constant(np.random.normal(
        size=[config.example_nums, config.dim_z]), dtype=tf.float32)
    fixed_z0 = tf.constant(np.random.normal(
        size=[config.example_nums, config.dim_z]), dtype=tf.float32)
    fixed_z1 = tf.constant(np.random.normal(
        size=[config.example_nums, config.dim_z]), dtype=tf.float32)
    fixed_x = tf.placeholder(tf.float32,
                             [config.example_nums] + config.img_shape)
    fixed_x0 = tf.placeholder(tf.float32,
                              [config.example_nums] + config.img_shape)
    fixed_x1 = tf.placeholder(tf.float32,
                              [config.example_nums] + config.img_shape)
    input_dict = {
        'fixed_z': fixed_z,
        'fixed_z0': fixed_z0,
        'fixed_z1': fixed_z1,
        'fixed_x': fixed_x,
        'fixed_x0': fixed_x0,
        'fixed_x1': fixed_x1,
        'num_midpoints': config.num_midpoints
    }

    def eval_step():
        out_dict = generate_sample(Decoder, input_dict)
        out_dict.update(reconstruction_sample(Encoder, Decoder, input_dict))
        out_dict.update({
            'fixed_x': fixed_x,
            'fixed_x0': fixed_x0,
            'fixed_x1': fixed_x1
        })
        return out_dict

    o_dict = eval_step()

    print("Building init module...")
    with tf.init_scope():
        init = [tf.global_variables_initializer(), dataset.initializer]
    saver_e = tf.train.Saver(Encoder.restore_variables)
    saver_d = tf.train.Saver(Decoder.restore_variables)

    print('Starting training...')
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(init)
        if config.resume:
            print("Restore vae...")
            saver_e.restore(sess, config.restore_e_dir)
            saver_d.restore(sess, config.restore_d_dir)
        timer.update()
        print('Preparing eval utils...')
        fixed_x_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                  config.batch_size)
        fixed_x0_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                   config.batch_size)
        fixed_x1_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                   config.batch_size)
        print("Setup complete, iterations start now, consuming %s " %
              timer.runing_time_format)
        print("Start iterations...")
        for iteration in range(config.total_step):
            loss_, r_loss_, kl_loss_, s_loss_, sw_sum_, lr_ = \
                sess.run([loss, r_loss, kl_loss, s_loss, s_w, learning_rate])
            if iteration % config.print_loss_per_steps == 0:
                timer.update()
                print(
                    "step %d, loss %f, r_loss_ %f, kl_loss_ %f, s_loss_ %f, sw_prod %f, "
                    "learning_rate % f, consuming time %s" %
                    (iteration, loss_, r_loss_, kl_loss_, s_loss_,
                     np.prod(sw_sum_)**(1 / 255**2), lr_,
                     timer.runing_time_format))
            if iteration % config.eval_per_steps == 0:
                o_dict_ = sess.run(o_dict, {
                    fixed_x: fixed_x_,
                    fixed_x0: fixed_x0_,
                    fixed_x1: fixed_x1_
                })
                for key in o_dict:
                    if not os.path.exists(config.model_dir +
                                          '/%06d' % iteration):
                        os.makedirs(config.model_dir + '/%06d' % iteration)
                    save_image_grid(
                        o_dict_[key],
                        config.model_dir + '/%06d/%s.jpg' % (iteration, key))
            if iteration % config.save_per_steps == 0:
                saver_e.save(sess,
                             save_path=config.model_dir + '/en.ckpt',
                             global_step=iteration,
                             write_meta_graph=False)
                saver_d.save(sess,
                             save_path=config.model_dir + '/de.ckpt',
                             global_step=iteration,
                             write_meta_graph=False)
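# ---------------------------------------------------------------------------
# Hedged sketch (assumption, not the repository's implementation): the helpers
# `smoother_weight` and `batch_laplacian` used in train_step above are imported
# from elsewhere in the project and are not shown here. The functions below,
# with deliberately different names, illustrate one standard way such a
# heat-kernel / graph-Laplacian penalty is computed, so the smooth_loss term
# can be read without the original source at hand.
import tensorflow as tf


def _heat_kernel_weights_sketch(rep, sigma2):
    """Pairwise heat-kernel weights w_ij = exp(-||r_i - r_j||^2 / sigma2)."""
    rep = tf.reshape(rep, [tf.shape(rep)[0], -1])
    sq_norm = tf.reduce_sum(tf.square(rep), axis=1, keepdims=True)
    sq_dist = sq_norm - 2.0 * tf.matmul(rep, rep, transpose_b=True) \
        + tf.transpose(sq_norm)
    return tf.exp(-sq_dist / sigma2)


def _batch_laplacian_sketch(weights, z):
    """Graph-smoothness penalty sum_ij w_ij * ||z_i - z_j||^2 over the batch."""
    sq_norm = tf.reduce_sum(tf.square(z), axis=1, keepdims=True)
    pairwise = sq_norm - 2.0 * tf.matmul(z, z, transpose_b=True) \
        + tf.transpose(sq_norm)
    return tf.reduce_sum(weights * pairwise)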
def training_loop(config: Config):
    timer = Timer()
    print('Task name %s' % config.task_name)
    strategy = tf.distribute.MirroredStrategy()
    print('Loading %s dataset...' % config.dataset_name)
    dset = get_dataset(config.dataset_name, config.tfds_dir,
                       config.gpu_nums * 2)
    dataset = dset.input_fn(config.batch_size, mode='train')
    dataset = strategy.experimental_distribute_dataset(dataset)
    dataset = dataset.make_initializable_iterator()
    eval_dataset = dset.input_fn(config.batch_size, mode='eval')
    eval_dataset = strategy.experimental_distribute_dataset(eval_dataset)
    eval_dataset = eval_dataset.make_initializable_iterator()

    with strategy.scope():
        global_step = tf.get_variable(
            name='global_step',
            initializer=tf.constant(0),
            trainable=False,
            aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)

        print("Constructing networks...")
        Encoder = vae.Encoder(config.dim_z,
                              config.e_hidden_num,
                              exceptions=['opt'],
                              name='Encoder')
        Decoder = vae.Decoder(config.img_shape,
                              config.d_hidden_num,
                              exceptions=['opt'],
                              name='Decoder')
        learning_rate = tf.train.exponential_decay(config.lr,
                                                   global_step,
                                                   config.decay_step,
                                                   config.decay_coef,
                                                   staircase=False)
        solver = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                        name='opt',
                                        beta2=config.beta2)

        print("Building tensorflow graph...")

        def train_step(image):
            mu_z, log_sigma_z, z = Encoder(image, is_training=True)
            x = Decoder(z, is_training=True, flatten=False)
            with tf.variable_scope('kl_divergence'):
                # Diagonal-Gaussian KL term, reduced to a scalar as in the
                # other training loops.
                kl_divergence = -tf.reduce_mean(
                    tf.reduce_sum(
                        0.5 *
                        (1 + log_sigma_z - mu_z**2 - tf.exp(log_sigma_z)), 1))
            with tf.variable_scope('reconstruction_loss'):
                recon_loss = tf.reduce_mean(
                    tf.nn.sigmoid_cross_entropy_with_logits(labels=image,
                                                            logits=x))
            loss = kl_divergence + recon_loss
            add_global = global_step.assign_add(1)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies([add_global] + update_ops):
                opt = solver.minimize(loss,
                                      var_list=Encoder.trainable_variables +
                                      Decoder.trainable_variables)
                with tf.control_dependencies([opt]):
                    return tf.identity(loss), tf.identity(recon_loss), \
                        tf.identity(kl_divergence)

        loss, r_loss, kl_loss = strategy.experimental_run_v2(
            train_step, (dataset.get_next()[0], ))
        loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, loss, axis=None)
        r_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, r_loss,
                                 axis=None)
        kl_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, kl_loss,
                                  axis=None)

        print("Building eval module...")
        fixed_z = tf.constant(np.random.normal(
            size=[config.example_nums, config.dim_z]), dtype=tf.float32)
        fixed_z0 = tf.constant(np.random.normal(
            size=[config.example_nums, config.dim_z]), dtype=tf.float32)
        fixed_z1 = tf.constant(np.random.normal(
            size=[config.example_nums, config.dim_z]), dtype=tf.float32)
        fixed_x = tf.placeholder(tf.float32,
                                 [config.example_nums] + config.img_shape)
        fixed_x0 = tf.placeholder(tf.float32,
                                  [config.example_nums] + config.img_shape)
        fixed_x1 = tf.placeholder(tf.float32,
                                  [config.example_nums] + config.img_shape)
        input_dict = {
            'fixed_z': fixed_z,
            'fixed_z0': fixed_z0,
            'fixed_z1': fixed_z1,
            'fixed_x': fixed_x,
            'fixed_x0': fixed_x0,
            'fixed_x1': fixed_x1,
            'num_midpoints': config.num_midpoints
        }

        def eval_step():
            out_dict = generate_sample(Decoder, input_dict)
            out_dict.update(
                reconstruction_sample(Encoder, Decoder, input_dict))
            return out_dict

        if config.gpu_nums == 1:
            o_dict = strategy.experimental_run_v2(eval_step, ())
        else:
            o_dict = concate_PerReplica(
                strategy.experimental_run_v2(eval_step, ()))

        print("Building init module...")
        with tf.init_scope():
            init = [
                tf.global_variables_initializer(), dataset.initializer,
                eval_dataset.initializer
            ]
        saver_e = tf.train.Saver(Encoder.restore_variables)
        saver_d = tf.train.Saver(Decoder.restore_variables)

        print('Starting training...')
        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True)) as sess:
            sess.run(init)
            if config.resume:
                print("Restore vae...")
                saver_e.restore(sess, config.restore_e_dir)
                saver_d.restore(sess, config.restore_d_dir)
            timer.update()
            print('Preparing eval utils...')
            fixed_x_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                      config.batch_size)
            fixed_x0_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                       config.batch_size)
            fixed_x1_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                       config.batch_size)
            print("Setup complete, iterations start now, consuming %s " %
                  timer.runing_time_format)
            print("Start iterations...")
            for iteration in range(config.total_step):
                loss_, r_loss_, kl_loss_, lr_ = sess.run(
                    [loss, r_loss, kl_loss, learning_rate])
                if iteration % config.print_loss_per_steps == 0:
                    timer.update()
                    print(
                        "step %d, loss %f, r_loss_ %f, kl_loss_ %f, learning_rate % f, consuming time %s"
                        % (iteration, loss_, r_loss_, kl_loss_, lr_,
                           timer.runing_time_format))
                if iteration % config.eval_per_steps == 0:
                    # Fetch the sample grids and feed all three image
                    # placeholders with their numpy counterparts.
                    o_dict_ = sess.run(o_dict, {
                        fixed_x: fixed_x_,
                        fixed_x0: fixed_x0_,
                        fixed_x1: fixed_x1_
                    })
                    for key in o_dict_:
                        save_image_grid(
                            o_dict_[key],
                            config.model_dir + '/%s%06d' % (key, iteration))
                if iteration % config.save_per_steps == 0:
                    saver_e.save(sess,
                                 save_path=config.model_dir + '/en.ckpt',
                                 global_step=iteration,
                                 write_meta_graph=False)
                    saver_d.save(sess,
                                 save_path=config.model_dir + '/de.ckpt',
                                 global_step=iteration,
                                 write_meta_graph=False)
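# ---------------------------------------------------------------------------
# Hedged sketch (assumption): `concate_PerReplica` is imported from elsewhere
# in the project. Judging from its call site above, it presumably turns the
# dict of PerReplica values returned by experimental_run_v2 under
# MirroredStrategy into ordinary tensors by concatenating the per-device
# results along the batch axis. A minimal version under that assumption (the
# extra `strategy` argument is part of this sketch, not the original API):
import tensorflow as tf


def _concat_per_replica_sketch(strategy, per_replica_dict):
    """Concatenate each PerReplica value in a dict along the batch axis."""
    out = {}
    for key, value in per_replica_dict.items():
        # experimental_local_results unpacks one tensor per replica.
        out[key] = tf.concat(strategy.experimental_local_results(value),
                             axis=0)
    return out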
def training_loop(config: Config):
    timer = Timer()
    opts = w_config.config_mnist
    print('Task name %s' % config.task_name)
    print('Loading %s dataset...' % config.dataset_name)
    dataset = load_mnist_KNN_from_record(
        config.record_dir + '/Mnist20knn5_rep.tfrecords', config.batch_size)
    dataset = dataset.make_initializable_iterator()
    # laplace_sigma2 = np.load(config.record_dir + '/knn5sigma2.npy') / (-np.log(config.laplace_a))
    laplace_sigma2 = 1.0 / (-np.log(config.laplace_a))
    global_step = tf.get_variable(name='global_step',
                                  initializer=tf.constant(0),
                                  trainable=False)

    print("Constructing networks...")
    Encoder = vae.Encoder(config.dim_z,
                          config.e_hidden_num,
                          exceptions=['opt'],
                          name='Encoder')
    Decoder = vae.Decoder(config.img_shape,
                          config.d_hidden_num,
                          exceptions=['opt'],
                          name='Decoder')
    # Pretrained vanilla VAE encoder, used only as the distance metric for PPL.
    valina_encoder = vae.Encoder(config.dim_z,
                                 config.e_hidden_num,
                                 exceptions=['opt'],
                                 name='VAE_En')

    def lip_metric(inputs):
        return inputs

    def d_metric(inputs):
        _, _, outputs = valina_encoder(inputs, True)
        return outputs

    def generator(inputs):
        outputs = Decoder(inputs, True, False)
        return outputs

    def lip_generator(inputs):
        _, _, outputs = Encoder(inputs, True)
        return outputs

    PPL = ppl.PPL_mnist(epsilon=0.01,
                        sampling='full',
                        generator=generator,
                        d_metric=d_metric)
    Lip_PPL = ppl.PPL_mnist(epsilon=0.01,
                            sampling='full',
                            generator=lip_generator,
                            d_metric=lip_metric)
    learning_rate = tf.train.exponential_decay(config.lr,
                                               global_step,
                                               config.decay_step,
                                               config.decay_coef,
                                               staircase=False)
    solver = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                    name='opt',
                                    beta1=opts['adam_beta1'])

    print("Building tensorflow graph...")

    def train_step(data):
        image, rep, label, neighbour, index = data
        mu_z, log_sigma_z, z = Encoder(image, is_training=True)
        x = Decoder(z, is_training=True, flatten=False)
        with tf.variable_scope('kl_divergence'):
            kl_divergence = -config.wae_lambda * tf.reduce_mean(
                tf.reduce_sum(
                    0.5 * (1 + log_sigma_z - mu_z**2 - tf.exp(log_sigma_z)),
                    1))
        with tf.variable_scope('reconstruction_loss'):
            recon_loss = -tf.reduce_mean(
                tf.reduce_sum(
                    image * tf.log(x + EPS) +
                    (1 - image) * tf.log(1 - x + EPS), [1, 2, 3]))
            # recon_loss = 0.05 * tf.reduce_mean(tf.reduce_sum(tf.square(image - x), [1, 2, 3]))
        with tf.variable_scope('smooth_loss'):
            # Only pairs that are k-nearest neighbours contribute to the
            # Laplacian smoothness term.
            mask = make_mask(neighbour, index)
            s_w = mask * smoother_weight(
                rep, 'heat', sigma2=laplace_sigma2, mask=mask)
            smooth_loss = batch_laplacian(s_w, z) * config.laplace_lambda
            s_w_mean = tf.reduce_mean(s_w) * config.batch_size * \
                config.batch_size / (tf.reduce_sum(mask) + EPS)
        loss = kl_divergence + recon_loss + smooth_loss
        # loss = loss_match + recon_loss
        add_global = global_step.assign_add(1)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies([add_global] + update_ops):
            opt = solver.minimize(loss,
                                  var_list=Encoder.trainable_variables +
                                  Decoder.trainable_variables)
            with tf.control_dependencies([opt]):
                l1, l2, l3, l4, l5 = tf.identity(loss), tf.identity(recon_loss), \
                    tf.identity(kl_divergence), tf.identity(smooth_loss), \
                    tf.identity(s_w_mean)
        return l1, l2, l3, l4, l5

    loss, r_loss, kl_loss, s_loss, s_w = train_step(dataset.get_next())

    # def pretrain(data):
    #     image, rep, label, neighbour, index = data
    #     mu_z, log_sigma_z, z = Encoder(image, is_training=True)
    #     Pz = tf.random.normal(shape=[config.batch_size, config.dim_z], mean=0.0, stddev=1.0)
    #     mean_pz = tf.reduce_mean(Pz, axis=0, keep_dims=True)
    #     mean_qz = tf.reduce_mean(z, axis=0, keep_dims=True)
    #     mean_loss = tf.reduce_mean(tf.square(mean_pz - mean_qz))
    #     cov_pz = tf.matmul(Pz - mean_pz, Pz - mean_pz,
    #                        transpose_a=True) / (config.batch_size - 1)
    #     cov_qz = tf.matmul(z - mean_qz, z - mean_qz,
    #                        transpose_a=True) / (config.batch_size - 1)
    #     cov_loss = tf.reduce_mean(tf.square(cov_pz - cov_qz))
    #     pretrain_loss = cov_loss + mean_loss
    #     update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    #     with tf.control_dependencies(update_ops):
    #         opt = solver.minimize(pretrain_loss, var_list=Encoder.trainable_variables)
    #     with tf.control_dependencies([opt]):
    #         p_loss = tf.identity(pretrain_loss)
    #     return p_loss
    # p_loss = pretrain(dataset.get_next())

    print("Building eval module...")
    fixed_z = tf.constant(np.random.normal(
        size=[config.example_nums, config.dim_z]), dtype=tf.float32)
    fixed_z0 = tf.constant(np.random.normal(
        size=[config.example_nums, config.dim_z]), dtype=tf.float32)
    fixed_z1 = tf.constant(np.random.normal(
        size=[config.example_nums, config.dim_z]), dtype=tf.float32)
    fixed_x = tf.placeholder(tf.float32,
                             [config.example_nums] + config.img_shape)
    fixed_x0 = tf.placeholder(tf.float32,
                              [config.example_nums] + config.img_shape)
    fixed_x1 = tf.placeholder(tf.float32,
                              [config.example_nums] + config.img_shape)
    input_dict = {
        'fixed_z': fixed_z,
        'fixed_z0': fixed_z0,
        'fixed_z1': fixed_z1,
        'fixed_x': fixed_x,
        'fixed_x0': fixed_x0,
        'fixed_x1': fixed_x1,
        'num_midpoints': config.num_midpoints
    }

    def sample_step():
        out_dict = generate_sample(Decoder, input_dict)
        out_dict.update(reconstruction_sample(Encoder, Decoder, input_dict))
        out_dict.update({
            'fixed_x': fixed_x,
            'fixed_x0': fixed_x0,
            'fixed_x1': fixed_x1
        })
        return out_dict

    o_dict = sample_step()

    def eval_step(img1, img2):
        z0 = tf.random.normal(shape=[config.batch_size, config.dim_z],
                              mean=0.0,
                              stddev=1.0)
        z1 = tf.random.normal(shape=[config.batch_size, config.dim_z],
                              mean=0.0,
                              stddev=1.0)
        _, _, img1_z = Encoder(img1, True)
        _, _, img2_z = Encoder(img2, True)
        # Path length in sample space, in re-encoded space, and the
        # Lipschitz-style path length of the encoder itself.
        ppl_sample_loss = PPL(z0, z1)
        ppl_de_loss = PPL(img1_z, img2_z)
        lip_loss = Lip_PPL(img1, img2)
        return ppl_sample_loss, ppl_de_loss, lip_loss

    img_1, _, _, _, _ = dataset.get_next()
    img_2, _, _, _, _ = dataset.get_next()
    ppl_sa_loss, ppl_de_loss, lip_loss = eval_step(img_1, img_2)

    print("Building init module...")
    with tf.init_scope():
        init = [tf.global_variables_initializer(), dataset.initializer]
    saver_e = tf.train.Saver(Encoder.restore_variables, max_to_keep=10)
    saver_d = tf.train.Saver(Decoder.restore_variables, max_to_keep=10)
    saver_v = tf.train.Saver(valina_encoder.restore_variables)

    print('Starting training...')
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(init)
        saver_v.restore(sess, config.restore_s_dir)
        if config.resume:
            print("Restore vae...")
            saver_e.restore(sess, config.restore_e_dir)
            saver_d.restore(sess, config.restore_d_dir)
        timer.update()
        print('Preparing eval utils...')
        fixed_x_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                  config.batch_size)
        fixed_x0_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                   config.batch_size)
        fixed_x1_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                   config.batch_size)
        print("Setup complete, iterations start now, consuming %s " %
              timer.runing_time_format)
        print("Start iterations...")
        # print("Start pretraining of Encoder...")
        # for iteration in range(500):
        #     p_loss_ = sess.run(p_loss)
        #     if iteration % 50 == 0:
        #         print("Pretrain_step %d, p_loss %f" % (iteration, p_loss_))
        # print("Pretraining of Encoder Done! p_loss %f. Now start training..." % p_loss_)
        loss_list = []
        r_loss_list = []
        kl_loss_list = []
        s_loss_list = []
        ppl_sa_list = []
        ppl_re_list = []
        lip_list = []
        for iteration in range(config.total_step):
            loss_, r_loss_, m_loss_, s_loss_, sw_sum_, lr_ = \
                sess.run([loss, r_loss, kl_loss, s_loss, s_w, learning_rate])
            if iteration % config.print_loss_per_steps == 0:
                loss_list.append(loss_)
                r_loss_list.append(r_loss_)
                kl_loss_list.append(m_loss_)
                s_loss_list.append(s_loss_)
                timer.update()
                print(
                    "step %d, loss %f, r_loss_ %f, kl_loss_ %f, s_loss_ %f, sw %f, "
                    "learning_rate % f, consuming time %s" %
                    (iteration, loss_, r_loss_, m_loss_, s_loss_,
                     np.mean(sw_sum_), lr_, timer.runing_time_format))
            if iteration % 1000 == 0:
                sa_loss_ = 0.0
                de_loss_ = 0.0
                lip_loss_ = 0.0
                for _ in range(200):
                    sa_p, de_p, lip_p = sess.run(
                        [ppl_sa_loss, ppl_de_loss, lip_loss])
                    sa_loss_ += sa_p
                    de_loss_ += de_p
                    lip_loss_ += lip_p
                sa_loss_ /= config.batch_size * 256
                de_loss_ /= config.batch_size * 256
                lip_loss_ /= config.batch_size * 256
                ppl_re_list.append(de_loss_)
                ppl_sa_list.append(sa_loss_)
                lip_list.append(lip_loss_)
                print("ppl_sample %f, ppl_resample %f, lipschitz %f" %
                      (sa_loss_, de_loss_, lip_loss_))
            if iteration % config.eval_per_steps == 0:
                o_dict_ = sess.run(o_dict, {
                    fixed_x: fixed_x_,
                    fixed_x0: fixed_x0_,
                    fixed_x1: fixed_x1_
                })
                for key in o_dict:
                    if not os.path.exists(config.model_dir +
                                          '/%06d' % iteration):
                        os.makedirs(config.model_dir + '/%06d' % iteration)
                    save_image_grid(
                        o_dict_[key],
                        config.model_dir + '/%06d/%s.jpg' % (iteration, key))
            if iteration % config.save_per_steps == 0:
                saver_e.save(sess,
                             save_path=config.model_dir + '/en.ckpt',
                             global_step=iteration,
                             write_meta_graph=False)
                saver_d.save(sess,
                             save_path=config.model_dir + '/de.ckpt',
                             global_step=iteration,
                             write_meta_graph=False)
                metric_dict = {
                    'r': r_loss_list,
                    'm': kl_loss_list,
                    's': s_loss_list,
                    'psa': ppl_sa_list,
                    'pre': ppl_re_list,
                    'lip': lip_list
                }
                np.save(config.model_dir + '/%06d' % iteration + 'metric.npy',
                        metric_dict)
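# ---------------------------------------------------------------------------
# Hedged sketch (assumption): `ppl.PPL_mnist` is defined elsewhere in the
# repository and its internals are not shown here. Its call sites above
# (PPL(z0, z1) on two latent endpoints, constructed with a generator, a
# distance embedding d_metric, epsilon=0.01 and sampling='full') match the
# usual perceptual-path-length construction, so the class below illustrates
# that construction for reference only; it is not the project's own code.
import tensorflow as tf


class _PPLSketch(object):
    """Sum of squared metric distances between generator outputs at nearby
    interpolation points, scaled by 1 / epsilon^2."""

    def __init__(self, epsilon, generator, d_metric, sampling='full'):
        self.epsilon = epsilon
        self.generator = generator
        self.d_metric = d_metric
        # 'full' sampling: interpolation position t drawn uniformly in [0, 1).
        self.sampling = sampling

    def __call__(self, z0, z1):
        batch = tf.shape(z0)[0]
        t = tf.random.uniform([batch, 1], 0.0, 1.0)
        za = z0 + t * (z1 - z0)
        zb = z0 + (t + self.epsilon) * (z1 - z0)
        da = tf.reshape(self.d_metric(self.generator(za)), [batch, -1])
        db = tf.reshape(self.d_metric(self.generator(zb)), [batch, -1])
        dist = tf.reduce_sum(tf.square(da - db), axis=1)
        return tf.reduce_sum(dist) / (self.epsilon ** 2)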
def training_loop(config: Config):
    timer = Timer()
    print('Task name %s' % config.task_name)
    print('Loading %s dataset...' % config.dataset_name)
    dset = get_dataset(config.dataset_name, config.tfds_dir,
                       config.gpu_nums * 2)
    dataset = train_input_fn(dset, config.batch_size)
    dataset = dataset.make_initializable_iterator()

    print("Constructing networks...")
    Encoder = vae.Encoder(config.dim_z,
                          config.e_hidden_num,
                          exceptions=['opt'],
                          name='VAE_En')
    Decoder = vae.Decoder(config.img_shape,
                          config.d_hidden_num,
                          exceptions=['opt'],
                          name='VAE_De')

    print("Building tensorflow graph...")
    image, label = dataset.get_next()
    _, _, z = Encoder(image, is_training=True)
    # Per-batch contribution to the sigma2 statistic of the latent codes.
    sigma2_plus = compute_sigma2(z)

    print("Building eval module...")
    fixed_z = tf.constant(np.random.normal(
        size=[config.example_nums, config.dim_z]), dtype=tf.float32)
    fixed_z0 = tf.constant(np.random.normal(
        size=[config.example_nums, config.dim_z]), dtype=tf.float32)
    fixed_z1 = tf.constant(np.random.normal(
        size=[config.example_nums, config.dim_z]), dtype=tf.float32)
    fixed_x = tf.placeholder(tf.float32,
                             [config.example_nums] + config.img_shape)
    fixed_x0 = tf.placeholder(tf.float32,
                              [config.example_nums] + config.img_shape)
    fixed_x1 = tf.placeholder(tf.float32,
                              [config.example_nums] + config.img_shape)
    input_dict = {
        'fixed_z': fixed_z,
        'fixed_z0': fixed_z0,
        'fixed_z1': fixed_z1,
        'fixed_x': fixed_x,
        'fixed_x0': fixed_x0,
        'fixed_x1': fixed_x1,
        'num_midpoints': config.num_midpoints
    }

    def sample_step():
        out_dict = generate_sample(Decoder, input_dict)
        out_dict.update(reconstruction_sample(Encoder, Decoder, input_dict))
        out_dict.update({
            'fixed_x': fixed_x,
            'fixed_x0': fixed_x0,
            'fixed_x1': fixed_x1
        })
        return out_dict

    o_dict = sample_step()

    print("Building init module...")
    with tf.init_scope():
        init = [tf.global_variables_initializer(), dataset.initializer]
    saver_e = tf.train.Saver(Encoder.restore_variables)
    saver_d = tf.train.Saver(Decoder.restore_variables)

    print('Starting training...')
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(init)
        if config.resume:
            print("Restore vae...")
            saver_e.restore(sess, config.restore_e_dir)
            saver_d.restore(sess, config.restore_d_dir)
        timer.update()
        print('Preparing sample utils...')
        fixed_x_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                  config.batch_size)
        fixed_x0_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                   config.batch_size)
        fixed_x1_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                   config.batch_size)
        o_dict_ = sess.run(o_dict, {
            fixed_x: fixed_x_,
            fixed_x0: fixed_x0_,
            fixed_x1: fixed_x1_
        })
        for key in o_dict_:
            save_image_grid(o_dict_[key], config.model_dir + '/%s.jpg' % key)
        print("Setup complete, iterations start now, consuming %s " %
              timer.runing_time_format)
        print("Start iterations...")
        # Sweep the training set once, writing (image, label, representation)
        # records and accumulating the sigma2 statistic.
        sigma2 = 0.0
        count = 0
        with tf.io.TFRecordWriter(config.model_dir +
                                  '/Mnist20_rep.tfrecords') as writer:
            while True:
                try:
                    image_, label_, sigma2_plus_, rep_ = sess.run(
                        [image, label, sigma2_plus, z])
                    sigma2 += sigma2_plus_
                    count += 1
                    for n in range(image_.shape[0]):
                        tf_example = serialize_example(
                            image_[n], label_[n], rep_[n])
                        writer.write(tf_example)
                    if count % 100 == 0:
                        timer.update()
                        print('Complete %d batches, consuming time %s' %
                              (count, timer.runing_time_format))
                except tf.errors.OutOfRangeError:
                    np.save(config.model_dir + '/sigma2.npy',
                            sigma2 / (count * config.batch_size))
                    print('Done!')
                    break
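# ---------------------------------------------------------------------------
# Hedged sketch (assumption): `serialize_example` is imported from a data
# utility module that is not shown here. Based on how it is called above
# (one image, one label and one latent representation per record, with the
# result written straight to a TFRecordWriter), one plausible implementation
# is the standard tf.train.Example encoding below; the feature names are
# guesses, not the project's actual schema.
import tensorflow as tf


def _serialize_example_sketch(image, label, rep):
    """Pack a single (image, label, representation) triple into a tf.train.Example."""
    feature = {
        'image': tf.train.Feature(bytes_list=tf.train.BytesList(
            value=[image.astype('float32').tobytes()])),
        'label': tf.train.Feature(int64_list=tf.train.Int64List(
            value=[int(label)])),
        'rep': tf.train.Feature(float_list=tf.train.FloatList(
            value=rep.astype('float32').reshape(-1))),
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    return example.SerializeToString()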
def training_loop(config: Config):
    timer = Timer()
    print('Task name %s' % config.task_name)
    print('Loading %s dataset...' % config.dataset_name)
    dset = get_dataset(config.dataset_name, config.tfds_dir,
                       config.gpu_nums * 2)
    dataset = dset.input_fn(config.batch_size, mode='train')
    dataset = dataset.make_initializable_iterator()
    eval_dataset = dset.input_fn(config.batch_size, mode='eval')
    eval_dataset = eval_dataset.make_initializable_iterator()
    global_step = tf.get_variable(name='global_step',
                                  initializer=tf.constant(0),
                                  trainable=False)

    print("Constructing networks...")
    # The VAE encoder architecture is reused as a feature extractor with a
    # classification head on top.
    Encoder = vae.Encoder(config.dim_z,
                          config.e_hidden_num,
                          exceptions=['opt'],
                          name='Classifier')
    last_layer = layer(config.dim_z)
    learning_rate = tf.train.exponential_decay(config.lr,
                                               global_step,
                                               config.decay_step,
                                               config.decay_coef,
                                               staircase=False)
    solver = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                    name='opt',
                                    beta2=config.beta2)

    print("Building tensorflow graph...")

    def train_step(data):
        image, label = data
        mu_z, log_sigma_z, z = Encoder(image, is_training=True)
        label = tf.one_hot(label, 10)
        y = last_layer(z, True)
        loss = tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=label)
        loss = tf.reduce_mean(loss)
        add_global = global_step.assign_add(1)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies([add_global] + update_ops):
            opt = solver.minimize(loss,
                                  var_list=Encoder.trainable_variables +
                                  last_layer.trainable_variables)
            with tf.control_dependencies([opt]):
                return tf.identity(loss)

    loss = train_step(dataset.get_next())

    def eval_step(data):
        image, label = data
        mu_z, log_sigma_z, z = Encoder(image, is_training=True)
        y = last_layer(z, True)
        y = tf.nn.softmax(y)
        y = tf.arg_max(y, 1)
        # Fraction of correctly classified examples in the batch.
        p = tf.reduce_mean(tf.cast(tf.equal(y, label), tf.float32))
        return p

    p = eval_step(eval_dataset.get_next())

    print("Building init module...")
    with tf.init_scope():
        init = [
            tf.global_variables_initializer(), dataset.initializer,
            eval_dataset.initializer
        ]
    saver_e = tf.train.Saver(Encoder.restore_variables)

    print('Starting training...')
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(init)
        if config.resume:
            print("Restore vae...")
            saver_e.restore(sess, config.restore_e_dir)
        timer.update()
        print('Preparing eval utils...')
        fixed_x_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                  config.batch_size)
        fixed_x0_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                   config.batch_size)
        fixed_x1_, _ = get_fixed_x(sess, dataset, config.example_nums,
                                   config.batch_size)
        print("Setup complete, iterations start now, consuming %s " %
              timer.runing_time_format)
        print("Start iterations...")
        for iteration in range(config.total_step):
            loss_, lr_ = sess.run([loss, learning_rate])
            if iteration % config.print_loss_per_steps == 0:
                timer.update()
                print("step %d, loss %f, learning_rate % f, consuming time %s"
                      % (iteration, loss_, lr_, timer.runing_time_format))
            if iteration % 1000 == 0:
                p_ = 0.0
                for _ in range(10):
                    p_plus = sess.run(p)
                    p_ += p_plus
                p_ /= 10
                print('eval accuracy %f' % p_)
            if iteration % config.save_per_steps == 0:
                saver_e.save(sess,
                             save_path=config.model_dir + '/en.ckpt',
                             global_step=iteration,
                             write_meta_graph=False)
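# ---------------------------------------------------------------------------
# Hedged sketch (assumption): `get_fixed_x`, called by every training loop in
# this file, comes from a utility module that is not shown here. Judging from
# its call sites -- get_fixed_x(sess, dataset, example_nums, batch_size)
# returning an (images, labels) pair -- it plausibly just drains batches from
# the initialized iterator and keeps the first `example_nums` examples. The
# helper below illustrates that reading; it is not the project's code, and
# `batch_size` is accepted only to mirror the original signature.
import numpy as np


def _get_fixed_x_sketch(sess, dataset, example_nums, batch_size):
    """Collect `example_nums` images and labels from an initialized iterator."""
    next_batch = dataset.get_next()
    images, labels = [], []
    collected = 0
    while collected < example_nums:
        image_batch, label_batch = sess.run(next_batch)[:2]
        images.append(image_batch)
        labels.append(label_batch)
        collected += image_batch.shape[0]
    images = np.concatenate(images, axis=0)[:example_nums]
    labels = np.concatenate(labels, axis=0)[:example_nums]
    return images, labels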