"""Copyright. For mutual information metrics."""
import os
import sys

import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "../")))

from model import config
from data_providers.data_provider import SuvsDataProvider
from plot import Visualizer
from sequential.utils import pickle_load

vis = Visualizer()
dp = SuvsDataProvider(num_validation=config.num_vad, shuffle='every_epoch')
config.is_train = False
config.batch_size = dp.valid.num_sample
test_times = 60

# Gaussian mixture prior
description1 = 'logs/hybrid_GAN_lin_res-Dz=0.01-R=1-Lat=1.5-Tv=0.01-d-gm-bc-gs-2018-02-04-metric.pkl'
[q_errors1, r_adjs1, z_adjs1], name1 = pickle_load(description1)

# Swiss roll prior
description2 = 'logs/hybrid_GAN_lin_res-Dz=0.01-R=1-Lat=1.5-Tv=0.01-d-sr-bc-gs-2018-02-04-metric.pkl'
[q_errors2, r_adjs2, z_adjs2], name2 = pickle_load(description2)

# Uniform disk prior
description3 = 'logs/hybrid_GAN_lin_res-Dz=0.01-R=1-Lat=1.5-Tv=0.01-d-ud-bc-gs-2018-02-04-metric.pkl'
[q_errors3, r_adjs3, z_adjs3], name3 = pickle_load(description3)
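# --- Hypothetical comparison of the three loaded metric sets (not part of the
# original script). The pickles above each hold [q_errors, r_adjs, z_adjs] for
# the Gaussian-mixture, Swiss-roll, and uniform-disk priors. Assuming the
# q_errors* entries flatten to 1-D per-sample error arrays, a quick sanity
# check is to overlay their distributions; this is only a sketch under that
# assumption, using plain matplotlib/seaborn rather than the project's Visualizer.
import numpy as np  # assumed available; used only for this sketch


def compare_q_errors(save_path='logs/q_error_comparison.png'):
    # Label -> flattened per-sample Q errors for each latent prior.
    priors = {
        'Gaussian mixture': np.ravel(q_errors1),
        'Swiss roll': np.ravel(q_errors2),
        'Uniform disk': np.ravel(q_errors3),
    }
    fig, ax = plt.subplots(figsize=(6, 4))
    for label, errors in priors.items():
        sns.kdeplot(errors, label=label, ax=ax)
    ax.set_xlabel('Q error')
    ax.set_ylabel('density')
    ax.legend()
    fig.savefig(save_path, dpi=150)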
def main(run_load_from_file=False):
    config = BaseConfig()
    config.folder_init()
    dp = SuvsDataProvider(num_validation=config.num_vad, shuffle='every_epoch')
    max_epoch = 500
    batch_size_l = config.batch_size
    path = os.path.join(config.logs_path, config.description + '-train.pkl')

    # training
    with tf.device(config.device):
        h = build_graph()

    sess_config = tf.ConfigProto(allow_soft_placement=True,
                                 log_device_placement=True)
    sess_config.gpu_options.allow_growth = True
    sess_config.gpu_options.per_process_gpu_memory_fraction = 0.9
    saver = tf.train.Saver(max_to_keep=2)

    with tf.Session(config=sess_config) as sess:
        # Load from checkpoint or start a new session
        if run_load_from_file:
            saver.restore(sess, tf.train.latest_checkpoint(config.ckpt_path))
            training_epoch_loss, _ = pickle_load(path)
        else:
            sess.run(tf.global_variables_initializer())
            training_epoch_loss = []  # recording loss per epoch

        process = Process()
        lr_schedule = create_lr_schedule(lr_base=2e-4,
                                         decay_rate=0.1,
                                         decay_epochs=500,
                                         truncated_epoch=2000,
                                         mode=config.lr_schedule)

        for epoch in range(max_epoch):
            process.start_epoch()

            # Learning rate for this epoch
            learning_rate = lr_schedule(epoch)

            # Recording loss per iteration
            training_iteration_loss = []
            sum_loss_rest = 0
            sum_loss_dcm = 0
            sum_loss_gen = 0
            process_iteration = Process()

            data_size = dp.train_l.num_sample
            num_batch = data_size // config.batch_size

            for i in range(num_batch + 1):
                process_iteration.start_epoch()

                # Inputs: sample from the data distribution
                batch_l = dp.train_l.next_batch(batch_size_l)
                z_prior = sampler.sampler_switch(config)

                # adversarial phase for discriminator_z
                _, Dz_err = sess.run(
                    [h.opt_dz, h.loss_dz],
                    feed_dict={
                        h.x: batch_l.x,
                        h.z_p: z_prior,
                        h.lr: learning_rate,
                    })

                # adversarial phase for the image discriminator
                z_latent = sampler.sampler_switch(config)
                _, Di_err = sess.run(
                    [h.opt_dimg, h.loss_dimg],
                    feed_dict={
                        h.x_c: batch_l.c,
                        h.z_l: z_latent,
                        h.z_e: batch_l.e,
                        h.x_s: batch_l.x,
                        h.lr: learning_rate,
                    })

                # reconstruction phase
                z_latent = sampler.sampler_switch(config)
                _, R_err, Ez_err, Gi_err, GE_err, EG_err = sess.run(
                    fetches=[
                        h.opt_r, h.loss_r, h.loss_e, h.loss_d, h.loss_l,
                        h.loss_eg
                    ],
                    feed_dict={
                        h.x: batch_l.x,
                        h.z_p: z_prior,
                        h.x_c: batch_l.c,
                        h.z_l: z_latent,
                        h.z_e: batch_l.e,
                        h.x_s: batch_l.x,
                        h.lr: learning_rate,
                    })

                # process phase
                _, P_err = sess.run(
                    [h.opt_p, h.loss_p],
                    feed_dict={
                        h.p_i: batch_l.rd,
                        h.p_ot: batch_l.q,
                        h.lr: learning_rate,
                    })

                # push process to normal
                z_latent = sampler.sampler_switch(config)
                _, GP_err = sess.run(
                    [h.opt_q, h.loss_q],
                    feed_dict={
                        h.x_c: batch_l.c,
                        h.z_l: z_latent,
                        h.z_e: batch_l.e,
                        h.p_in: batch_l.rd,
                        h.p_ot: batch_l.q,
                        h.lr: learning_rate,
                    })

                # recording the loss terms
                training_iteration_loss.append([
                    R_err, Ez_err, Gi_err, GE_err, EG_err, Dz_err, Di_err,
                    P_err, GP_err
                ])
                sum_loss_rest += R_err
                sum_loss_dcm += Dz_err + Di_err
                sum_loss_gen += Gi_err + Ez_err

                if i % 10 == 0 and False:  # per-iteration logging disabled
                    process_iteration.display_current_results(
                        i, num_batch, {
                            'reconstruction': sum_loss_rest / (i + 1),
                            'discriminator': sum_loss_dcm / (i + 1),
                            'generator': sum_loss_gen / (i + 1),
                        })

            # At the end of the epoch, summarize the losses
            average_loss_per_epoch = np.mean(
                np.array(training_iteration_loss), axis=0)

            # validation phase
            num_test = dp.valid.num_sample // config.batch_size
            testing_iteration_loss = []
            for batch in range(num_test):
                z_latent = sampler.sampler_switch(config)
                batch_v = dp.valid.next_batch(config.batch_size)
                GPt_err = sess.run(
                    h.loss_q,
                    feed_dict={
                        h.x_c: batch_v.c,
                        h.z_l: z_latent,
                        h.z_e: batch_v.e,
                        h.p_in: batch_v.rd,
                        h.p_ot: batch_v.q,
                    })
                Pt_err = sess.run(
                    h.loss_p,
                    feed_dict={
                        h.p_i: batch_v.rd,
                        h.p_ot: batch_v.q,
                    })
                testing_iteration_loss.append([GPt_err, Pt_err])

            average_test_loss = np.mean(
                np.array(testing_iteration_loss), axis=0)
            average_per_epoch = np.concatenate(
                (average_loss_per_epoch, average_test_loss), axis=0)
            training_epoch_loss.append(average_per_epoch)

            # training loss names
            training_loss_name = [
                'R_err', 'Ez_err', 'Gi_err', 'GE_err', 'EG_err', 'Dz_err',
                'Di_err', 'P_err', 'GP_err', 'GPt_err', 'Pt_err',
            ]

            if epoch % 10 == 0:
                process.format_meter(
                    epoch, max_epoch, {
                        'R_err': average_per_epoch[0],
                        'Ez_err': average_per_epoch[1],
                        'Gi_err': average_per_epoch[2],
                        'GE_err': average_per_epoch[3],
                        'EG_err': average_per_epoch[4],
                        'Dz_err': average_per_epoch[5],
                        'Di_err': average_per_epoch[6],
                        'P_err': average_per_epoch[7],
                        'GP_err': average_per_epoch[8],
                        'GPt_err': average_per_epoch[9],
                        'Pt_err': average_per_epoch[10],
                    })

            if (epoch % 1000 == 0 or epoch == max_epoch - 1) and epoch != 0:
                saver.save(sess,
                           os.path.join(config.ckpt_path, 'model_checkpoint'),
                           global_step=epoch)
                pickle_save(training_epoch_loss, training_loss_name, path)
                copy_file(path, config.history_train_path)
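# --- Sketch of the learning-rate helper called above (assumption only; the
# real create_lr_schedule lives elsewhere in the project and is not shown in
# this section). It illustrates a plausible step-decay behaviour consistent
# with the call site: lr_base=2e-4, decay_rate=0.1, decay_epochs=500,
# truncated_epoch=2000, mode=config.lr_schedule.
def create_lr_schedule_sketch(lr_base, decay_rate, decay_epochs,
                              truncated_epoch, mode='step'):
    """Return a callable epoch -> learning rate (hypothetical helper)."""
    def schedule(epoch):
        if mode == 'constant':
            return lr_base
        # cap the epoch used for decay, then apply stepwise decay
        effective_epoch = min(epoch, truncated_epoch)
        return lr_base * (decay_rate ** (effective_epoch // decay_epochs))
    return schedule

# example: create_lr_schedule_sketch(2e-4, 0.1, 500, 2000)(600) -> 2e-5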