def main():
    """Restore the latest checkpoint, evaluate the test split and save the results.

    The graph is built in test mode, ``t.p_o`` and ``t.p_i`` are evaluated on
    ``dp.test`` with one sampled latent code, and both outputs (together with
    the reference values ``dp.test.q`` / ``dp.test.rd``) are passed through
    ``anti_norm``.  The four resulting arrays are pickled under
    ``config.logs_path``, the pickle is copied to
    ``config.history_test_path``, and each column is plotted with
    ``vis.cplot``.
    """
    with tf.device(config.device):
        t = build_graph(is_test=True)

    session_options = tf.ConfigProto(allow_soft_placement=True,
                                     log_device_placement=True)
    with tf.Session(config=session_options) as sess:
        logger.info(config.ckpt_path)
        saver = tf.train.Saver()
        saver.restore(sess, tf.train.latest_checkpoint(config.ckpt_path))
        logger.info("Loading model completely")

        # One latent sample is shared by both evaluations below.
        z_latent = sampler_switch(config)
        test_feed = {
            t.z_e: dp.test.e,
            t.x_c: dp.test.c,
            t.z_l: z_latent,
            t.p_in: dp.test.rd,
        }
        d_q = sess.run(t.p_o, feed_dict=test_feed)
        r_p = sess.run(t.p_i, feed_dict=test_feed)

        # Invert the scaling applied to the outputs.
        scale = dp.out
        qm, qr, rdm, rdr = scale.qm, scale.qr, scale.rdm, scale.rdr
        actual_Q = anti_norm(dp.test.q, qm, qr)
        result_Q = anti_norm(d_q, qm, qr)
        actual_r = anti_norm(dp.test.rd, rdm, rdr)
        result_r = anti_norm(r_p, rdm, rdr)

        # Persist the results and keep a copy in the history folder.
        ensemble = {
            'actual_Q': actual_Q,
            'result_Q': result_Q,
            'actual_r': actual_r,
            'result_r': result_r,
        }
        file_name = config.description + '-test.pkl'
        path = os.path.join(config.logs_path, file_name)
        pickle_save(ensemble, 'test_result', path)
        copy_file(path, config.history_test_path)

        # Plot reference vs. predicted values: two Q columns, then six R columns.
        for col, title in enumerate(['Q1', 'Q2']):
            vis.cplot(actual_Q[:, col], result_Q[:, col],
                      [title, 'origin', 'modify'], config.t_p)
        for col in range(6):
            title = 'R{}'.format(col + 1)
            vis.cplot(actual_r[:, col], result_r[:, col],
                      [title, 'origin', 'modify'], config.t_p)
def main():
    """Draw ``test_times`` latent samples and save the metrics on the test split.

    After restoring the latest checkpoint, ``dp.test.rd`` is encoded once via
    ``t.z_img``.  For every draw, ``t.dq`` and ``t.x_lat`` are evaluated with
    a fresh latent sample and the ``t.x_lat`` output is fed back through
    ``t.z_img``.  The stacked ``t.dq`` outputs are turned into squared
    differences against ``dp.test.e``; the other two collections are
    flattened to ``config.ndim_x`` / ``config.ndim_z`` columns.  Everything is
    written with ``pickle_save`` to ``<logs_path>/<description>-metric_plus.pkl``.
    """
    with tf.device(config.device):
        t = build_graph(is_test=True)

    session_options = tf.ConfigProto(allow_soft_placement=True,
                                     log_device_placement=True)
    with tf.Session(config=session_options) as sess:
        saver = tf.train.Saver()
        saver.restore(sess, tf.train.latest_checkpoint(config.ckpt_path))

        q_errors, r_adjs, z_adjs = [], [], []
        z_true = sess.run(t.z_img, feed_dict={t.x: dp.test.rd})

        for _ in range(test_times):
            z_latent = sampler_switch(config)

            # Inputs shared by the t.dq and t.x_lat evaluations of this draw.
            latent_feed = {
                t.x_c: dp.test.c,
                t.z_l: z_latent,
                t.z_e: dp.test.e,
            }
            dq_feed = dict(latent_feed)
            dq_feed.update({
                t.p_in: dp.test.rd,
                t.p_t: dp.test.q,
            })

            q_error = sess.run(t.dq, feed_dict=dq_feed)
            r_adj = sess.run(t.x_lat, feed_dict=latent_feed)
            # Feed the t.x_lat output back in as t.x and fetch t.z_img.
            z_adj = sess.run(t.z_img, feed_dict={t.x: r_adj})

            q_errors.append(q_error)
            r_adjs.append(r_adj)
            z_adjs.append(z_adj)

        # Squared difference to dp.test.e, broadcast over the draw axis.
        reference = np.expand_dims(dp.test.e, axis=0)
        q_errors = (np.array(q_errors) - reference)**2
        r_adjs = np.array(r_adjs).reshape(-1, config.ndim_x)
        z_adjs = np.array(z_adjs).reshape(-1, config.ndim_z)

        # NOTE(review): four arrays are saved but only three names are given;
        # confirm how pickle_save uses the name list.
        out_path = '{}/{}-metric_plus.pkl'.format(config.logs_path,
                                                  config.description)
        pickle_save([q_errors, r_adjs, z_adjs, z_true],
                    ["productions", "adjustment", "latent_variables"],
                    out_path)
def main(db='gs'):
    """Collect sampling metrics on the test split plus the training-set encoding.

    Args:
        db: Value assigned to ``config.distribution_sampler`` before the
            graph is built.

    The function resets the default graph twice and overwrites
    ``config.batch_size`` and ``config.distribution_sampler`` on the shared
    ``config`` object.  The first graph produces the per-draw metrics
    (``t.dq``, ``t.x_lat``, and ``t.z_img`` of that ``t.x_lat`` output) over
    ``test_times`` latent samples; the second graph, rebuilt with a different
    batch size, encodes ``dp.train.rd`` via ``t.z_img``.  All five arrays are
    written to ``<logs_path>/<description>-metric_plus3.pkl``.
    """
    session_options = tf.ConfigProto(allow_soft_placement=True,
                                     log_device_placement=True)

    # ---- Phase 1: metrics on the test split -------------------------------
    tf.reset_default_graph()
    # NOTE(review): batch size comes from dp.valid while the feeds below use
    # dp.test -- confirm the two splits have the same number of samples.
    config.batch_size = dp.valid.num_sample
    config.distribution_sampler = db
    with tf.device(config.device):
        t = build_graph(is_test=True)

    with tf.Session(config=session_options) as sess:
        saver = tf.train.Saver()
        saver.restore(sess, tf.train.latest_checkpoint(config.ckpt_path))

        q_errors, r_adjs, z_adjs = [], [], []
        z_true = sess.run(t.z_img, feed_dict={t.x: dp.test.rd})

        for _ in range(test_times):
            z_latent = sampler_switch(config)

            # Inputs shared by the t.dq and t.x_lat evaluations of this draw.
            latent_feed = {
                t.x_c: dp.test.c,
                t.z_l: z_latent,
                t.z_e: dp.test.e,
            }
            dq_feed = dict(latent_feed)
            dq_feed.update({
                t.p_in: dp.test.rd,
                t.p_t: dp.test.q,
            })

            q_error = sess.run(t.dq, feed_dict=dq_feed)
            r_adj = sess.run(t.x_lat, feed_dict=latent_feed)
            z_adj = sess.run(t.z_img, feed_dict={t.x: r_adj})

            q_errors.append(q_error)
            r_adjs.append(r_adj)
            z_adjs.append(z_adj)

        # Squared difference to dp.test.e, broadcast over the draw axis.
        reference = np.expand_dims(dp.test.e, axis=0)
        q_errors = (np.array(q_errors) - reference)**2
        r_adjs = np.array(r_adjs).reshape(-1, config.ndim_x)
        z_adjs = np.array(z_adjs).reshape(-1, config.ndim_z)

    # ---- Phase 2: rebuild the graph with a new batch size -------------------
    tf.reset_default_graph()
    # NOTE(review): batch size comes from dp.train_l while the feed below uses
    # dp.train.rd -- confirm these refer to equally sized data.
    config.batch_size = dp.train_l.num_sample
    with tf.device(config.device):
        t = build_graph(is_test=True)

    with tf.Session(config=session_options) as sess:
        saver = tf.train.Saver()
        saver.restore(sess, tf.train.latest_checkpoint(config.ckpt_path))
        z_train = sess.run(t.z_img, feed_dict={t.x: dp.train.rd})

        # NOTE(review): five arrays are saved but only three names are given;
        # confirm how pickle_save uses the name list.
        out_path = '{}/{}-metric_plus3.pkl'.format(config.logs_path,
                                                   config.get_description())
        pickle_save([q_errors, r_adjs, z_adjs, z_true, z_train],
                    ["productions", "adjustment", "latent_variables"],
                    out_path)
        print('{}/{}-metric_plus3.pkl have been saved'.format(
            config.logs_path, config.get_description()))
def main(run_load_from_file=False):
    """Train the model, record per-epoch losses and save a checkpoint.

    Args:
        run_load_from_file: When true, variables are restored from the latest
            checkpoint in ``config.ckpt_path`` and the previously pickled
            loss history is reloaded and appended to; otherwise variables are
            freshly initialised and the history starts empty.

    Each training iteration runs five optimiser steps in order
    (``opt_dz``, ``opt_dimg``, ``opt_r``, ``opt_p``, ``opt_q``) and records
    nine loss values.  After every epoch two validation losses
    (``loss_q``, ``loss_p``) are averaged over ``dp.valid`` and appended to
    the training averages, giving the eleven columns named in
    ``training_loss_name``.
    """
    config = BaseConfig()
    config.folder_init()
    dp = SuvsDataProvider(num_validation=config.num_vad, shuffle='every_epoch')
    max_epoch = 500
    batch_size_l = config.batch_size
    path = os.path.join(config.logs_path, config.description + '-train.pkl')

    # Column names of one row of the loss history: nine training losses
    # followed by two validation losses.  Defined once instead of per epoch.
    training_loss_name = [
        'R_err',
        'Ez_err',
        'Gi_err',
        'GE_err',
        'EG_err',
        'Dz_err',
        'Di_err',
        'P_err',
        'GP_err',
        'GPt_err',
        'Pt_err',
    ]
    num_validation_losses = 2

    # Per-iteration progress output is switched off (the original guarded it
    # with ``and False``); flip this flag to re-enable it.
    show_iteration_progress = False

    # training
    with tf.device(config.device):
        h = build_graph()

    sess_config = tf.ConfigProto(allow_soft_placement=True,
                                 log_device_placement=True)
    sess_config.gpu_options.allow_growth = True
    sess_config.gpu_options.per_process_gpu_memory_fraction = 0.9
    saver = tf.train.Saver(max_to_keep=2)

    with tf.Session(config=sess_config) as sess:
        # Load from checkpoint or start a new session.
        if run_load_from_file:
            saver.restore(sess, tf.train.latest_checkpoint(config.ckpt_path))
            training_epoch_loss, _ = pickle_load(path)
        else:
            sess.run(tf.global_variables_initializer())
            training_epoch_loss = []

        # Recording loss per epoch
        process = Process()
        lr_schedule = create_lr_schedule(lr_base=2e-4,
                                         decay_rate=0.1,
                                         decay_epochs=500,
                                         truncated_epoch=2000,
                                         mode=config.lr_schedule)

        # NOTE(review): the epoch counter restarts at 0 even when resuming
        # from a checkpoint -- confirm this is intended for the LR schedule.
        for epoch in range(max_epoch):
            process.start_epoch()

            # Learning rate generator
            learning_rate = lr_schedule(epoch)

            # Recording loss per iteration
            training_iteration_loss = []
            sum_loss_rest = 0
            sum_loss_dcm = 0
            sum_loss_gen = 0
            process_iteration = Process()
            data_size = dp.train_l.num_sample
            num_batch = data_size // config.batch_size

            # NOTE(review): one more batch than num_batch is drawn each
            # epoch; presumably to cover the remainder -- confirm.
            for i in range(num_batch + 1):
                process_iteration.start_epoch()

                # Inputs: sample from the data distribution and the prior.
                batch_l = dp.train_l.next_batch(batch_size_l)
                z_prior = sampler.sampler_switch(config)

                # adversarial phase for discriminator_z
                _, Dz_err = sess.run(
                    [h.opt_dz, h.loss_dz],
                    feed_dict={
                        h.x: batch_l.x,
                        h.z_p: z_prior,
                        h.lr: learning_rate,
                    })

                z_latent = sampler.sampler_switch(config)
                _, Di_err = sess.run(
                    [h.opt_dimg, h.loss_dimg],
                    feed_dict={
                        h.x_c: batch_l.c,
                        h.z_l: z_latent,
                        h.z_e: batch_l.e,
                        h.x_s: batch_l.x,
                        h.lr: learning_rate,
                    })

                # reconstruction phase
                z_latent = sampler.sampler_switch(config)
                _, R_err, Ez_err, Gi_err, GE_err, EG_err = sess.run(
                    fetches=[
                        h.opt_r, h.loss_r, h.loss_e, h.loss_d, h.loss_l,
                        h.loss_eg
                    ],
                    feed_dict={
                        h.x: batch_l.x,
                        h.z_p: z_prior,
                        h.x_c: batch_l.c,
                        h.z_l: z_latent,
                        h.z_e: batch_l.e,
                        h.x_s: batch_l.x,
                        h.lr: learning_rate,
                    })

                # process phase
                _, P_err = sess.run(
                    [h.opt_p, h.loss_p],
                    feed_dict={
                        h.p_i: batch_l.rd,
                        h.p_ot: batch_l.q,
                        h.lr: learning_rate,
                    })

                # push process to normal
                z_latent = sampler.sampler_switch(config)
                _, GP_err = sess.run(
                    [h.opt_q, h.loss_q],
                    feed_dict={
                        h.x_c: batch_l.c,
                        h.z_l: z_latent,
                        h.z_e: batch_l.e,
                        h.p_in: batch_l.rd,
                        h.p_ot: batch_l.q,
                        h.lr: learning_rate,
                    })

                # recording loss function
                training_iteration_loss.append([
                    R_err, Ez_err, Gi_err, GE_err, EG_err, Dz_err, Di_err,
                    P_err, GP_err
                ])
                sum_loss_rest += R_err
                sum_loss_dcm += Dz_err + Di_err
                sum_loss_gen += Gi_err + Ez_err

                if show_iteration_progress and i % 10 == 0:
                    process_iteration.display_current_results(
                        i, num_batch, {
                            'reconstruction': sum_loss_rest / (i + 1),
                            'discriminator': sum_loss_dcm / (i + 1),
                            'generator': sum_loss_gen / (i + 1),
                        })

            # At the end of the epoch, average the per-iteration losses.
            average_loss_per_epoch = np.mean(
                np.array(training_iteration_loss), axis=0)

            # validation phase
            num_test = dp.valid.num_sample // config.batch_size
            testing_iteration_loss = []
            for batch in range(num_test):
                z_latent = sampler.sampler_switch(config)
                batch_v = dp.valid.next_batch(config.batch_size)
                GPt_err = sess.run(
                    h.loss_q,
                    feed_dict={
                        h.x_c: batch_v.c,
                        h.z_l: z_latent,
                        h.z_e: batch_v.e,
                        h.p_in: batch_v.rd,
                        h.p_ot: batch_v.q,
                    })
                Pt_err = sess.run(
                    h.loss_p,
                    feed_dict={
                        h.p_i: batch_v.rd,
                        h.p_ot: batch_v.q,
                    })
                testing_iteration_loss.append([GPt_err, Pt_err])

            if testing_iteration_loss:
                average_test_loss = np.mean(
                    np.array(testing_iteration_loss), axis=0)
            else:
                # With fewer validation samples than one batch there is
                # nothing to average; np.mean would give a 0-d NaN that
                # np.concatenate rejects, so use NaN placeholders instead.
                average_test_loss = np.full(num_validation_losses, np.nan)

            average_per_epoch = np.concatenate(
                (average_loss_per_epoch, average_test_loss), axis=0)
            training_epoch_loss.append(average_per_epoch)

            if epoch % 10 == 0:
                # Pair each column name with its averaged value.
                process.format_meter(
                    epoch, max_epoch,
                    dict(zip(training_loss_name, average_per_epoch)))

            if (epoch % 1000 == 0 or epoch == max_epoch - 1) and epoch != 0:
                saver.save(sess,
                           os.path.join(config.ckpt_path, 'model_checkpoint'),
                           global_step=epoch)
                pickle_save(training_epoch_loss, training_loss_name, path)
                copy_file(path, config.history_train_path)