def main(args):
    # np.inf replaces the old threshold=np.nan, which newer NumPy rejects
    np.set_printoptions(threshold=np.inf, linewidth=1000, precision=3)

    # =====================================
    # Preparation
    # =====================================
    data_file = os.path.join(RAW_DATA_DIR, "ComputerVision", "dSprites",
                             "dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz")

    # It is already in the range [0, 1]
    with np.load(data_file, encoding="latin1") as f:
        x_train = f['imgs']

    x_train = np.expand_dims(x_train.astype(np.float32), axis=-1)
    num_train = len(x_train)
    print("num_train: {}".format(num_train))

    args.output_dir = os.path.join(args.output_dir, args.enc_dec_model, args.run)
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    else:
        if args.force_rm_dir:
            import shutil
            shutil.rmtree(args.output_dir, ignore_errors=True)
            print("Removed '{}'".format(args.output_dir))
        else:
            raise ValueError("Output directory '{}' already exists. "
                             "'force_rm_dir' must be set to True!".format(args.output_dir))
        os.mkdir(args.output_dir)

    save_args(os.path.join(args.output_dir, 'config.json'), args)
    # pp.pprint(args.__dict__)

    # =====================================
    # Instantiate models
    # =====================================
    if args.enc_dec_model == "1Konny":
        encoder = Encoder_1Konny(args.z_dim, stochastic=True)
        decoder = Decoder_1Konny()
        disc_z = DiscriminatorZ_1Konny(num_outputs=2)
    else:
        raise ValueError("Unsupported enc_dec_model '{}'!".format(args.enc_dec_model))

    model = FactorVAE([64, 64, 1], args.z_dim,
                      encoder=encoder, decoder=decoder,
                      discriminator_z=disc_z,
                      rec_x_mode=args.rec_x_mode,
                      use_gp0_z_tc=True,
                      gp0_z_tc_mode=args.gp0_z_tc_mode)

    loss_coeff_dict = {
        'rec_x': args.rec_x_coeff,
        'kld_loss': args.kld_loss_coeff,
        'tc_loss': args.tc_loss_coeff,
        'gp0_z_tc': args.gp0_z_tc_coeff,
        # Key renamed from 'Dz_tc_loss_coeff' to 'Dz_tc_loss' so it matches
        # the loss name used by the CelebA script below
        'Dz_tc_loss': args.Dz_tc_loss_coeff,
    }

    model.build(loss_coeff_dict)
    SimpleParamPrinter.print_all_params_list()
    # SimpleParamPrinter.print_all_params_tf_slim()

    loss = model.get_loss()
    train_params = model.get_train_params()

    opt_Dz = tf.train.AdamOptimizer(learning_rate=args.lr_Dz,
                                    beta1=args.beta1_Dz, beta2=args.beta2_Dz)
    opt_vae = tf.train.AdamOptimizer(learning_rate=args.lr_vae,
                                     beta1=args.beta1_vae, beta2=args.beta2_vae)

    with tf.control_dependencies(model.get_all_update_ops()):
        train_op_Dz = opt_Dz.minimize(loss=loss['Dz_loss'],
                                      var_list=train_params['Dz_loss'])
        train_op_D = train_op_Dz
        train_op_vae = opt_vae.minimize(loss=loss['vae_loss'],
                                        var_list=train_params['vae_loss'])

    # =====================================
    # TF Graph Handler
    asset_dir = make_dir_if_not_exist(os.path.join(args.output_dir, "asset"))
    img_gen_dir = make_dir_if_not_exist(os.path.join(asset_dir, "img_gen"))
    img_rec_dir = make_dir_if_not_exist(os.path.join(asset_dir, "img_rec"))
    img_itpl_dir = make_dir_if_not_exist(os.path.join(asset_dir, "img_itpl"))

    log_dir = make_dir_if_not_exist(os.path.join(args.output_dir, "log"))
    train_log_file = os.path.join(log_dir, "train.log")

    summary_dir = make_dir_if_not_exist(os.path.join(args.output_dir, "summary_tf"))
    model_dir = make_dir_if_not_exist(os.path.join(args.output_dir, "model_tf"))

    train_helper = SimpleTrainHelper(
        log_dir=summary_dir,
        save_dir=model_dir,
        max_to_keep=3,
        max_to_keep_best=1,
    )
    # =====================================

    # =====================================
    # Training Loop
    # =====================================
    config_proto = tf.ConfigProto(allow_soft_placement=True)
    config_proto.gpu_options.allow_growth = True
    config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9
    sess = tf.Session(config=config_proto)

    train_helper.initialize(sess, init_variables=True, create_summary_writer=True)

    Dz_fetch_keys = [
        'Dz_loss', 'Dz_tc_loss', 'Dz_loss_normal', 'Dz_loss_factor',
        'Dz_avg_prob_normal', 'Dz_avg_prob_factor', 'gp0_z_tc'
    ]
    D_fetch_keys = Dz_fetch_keys
    vae_fetch_keys = ['vae_loss', 'rec_x', 'kld_loss', 'tc_loss']

    global_step = 0
    for epoch in range(args.epochs):
        for batch_ids in iterate_data(num_train, args.batch_size, shuffle=True):
            global_step += 1

            x = x_train[batch_ids]
            z = np.random.normal(size=[len(x), args.z_dim])

            # A second, independently sampled batch for the TC discriminator
            batch_ids_2 = np.random.choice(num_train, size=len(batch_ids)).tolist()
            xa = x_train[batch_ids_2]

            for i in range(args.D_steps):
                _, Dm = sess.run(
                    [train_op_D, model.get_output(D_fetch_keys, as_dict=True)],
                    feed_dict={model.is_train: True, model.x_ph: x,
                               model.z_ph: z, model.xa_ph: xa})

            for i in range(args.vae_steps):
                _, VAEm = sess.run(
                    [train_op_vae, model.get_output(vae_fetch_keys, as_dict=True)],
                    feed_dict={model.is_train: True, model.x_ph: x,
                               model.z_ph: z, model.xa_ph: xa})

            if global_step % args.save_freq == 0:
                train_helper.save(sess, global_step)

            if global_step % args.log_freq == 0:
                log_str = "\n[FactorVAE/{}/{} (dSprites)], Epoch[{}/{}], Step {}".format(
                              args.enc_dec_model, args.run, epoch, args.epochs, global_step) + \
                          "\nvae_loss: {:.4f}, Dz_loss: {:.4f}, Dz_tc_loss: {:.4f}".format(
                              VAEm['vae_loss'], Dm['Dz_loss'], Dm['Dz_tc_loss']) + \
                          "\nrec_x: {:.4f}, kld_loss: {:.4f}, tc_loss: {:.4f}".format(
                              VAEm['rec_x'], VAEm['kld_loss'], VAEm['tc_loss']) + \
                          "\nDz_loss_normal: {:.4f}, Dz_loss_factor: {:.4f}".format(
                              Dm['Dz_loss_normal'], Dm['Dz_loss_factor']) + \
                          "\nDz_avg_prob_normal: {:.4f}, Dz_avg_prob_factor: {:.4f}".format(
                              Dm['Dz_avg_prob_normal'], Dm['Dz_avg_prob_factor']) + \
                          "\ngp0_z_tc_coeff: {:.4f}, gp0_z_tc: {:.4f}".format(
                              args.gp0_z_tc_coeff, Dm['gp0_z_tc'])

                print(log_str)
                # The 'with' block closes the file; no explicit close() needed
                with open(train_log_file, "a") as f:
                    f.write(log_str)
                    f.write("\n")

                # TensorBoard scalar summaries
                train_helper.add_summary(
                    custom_tf_scalar_summary('vae_loss', VAEm['vae_loss'], prefix='train'),
                    global_step)
                train_helper.add_summary(
                    custom_tf_scalar_summary('rec_x', VAEm['rec_x'], prefix='train'),
                    global_step)
                train_helper.add_summary(
                    custom_tf_scalar_summary('kld_loss', VAEm['kld_loss'], prefix='train'),
                    global_step)
                train_helper.add_summary(
                    custom_tf_scalar_summary('tc_loss', VAEm['tc_loss'], prefix='train'),
                    global_step)
                train_helper.add_summary(
                    custom_tf_scalar_summary('Dz_tc_loss', Dm['Dz_tc_loss'], prefix='train'),
                    global_step)
                train_helper.add_summary(
                    custom_tf_scalar_summary('Dz_loss_normal', Dm['Dz_loss_normal'], prefix='train'),
                    global_step)
                train_helper.add_summary(
                    custom_tf_scalar_summary('Dz_loss_factor', Dm['Dz_loss_factor'], prefix='train'),
                    global_step)
                train_helper.add_summary(
                    custom_tf_scalar_summary('Dz_prob_normal', Dm['Dz_avg_prob_normal'], prefix='train'),
                    global_step)
                train_helper.add_summary(
                    custom_tf_scalar_summary('Dz_prob_factor', Dm['Dz_avg_prob_factor'], prefix='train'),
                    global_step)

            if global_step % args.viz_gen_freq == 0:
                # Generate images
                # ------------------------- #
                z = np.random.normal(size=[64, args.z_dim])
                img_file = os.path.join(img_gen_dir, 'step[%d]_gen_test.png' % global_step)
                model.generate_images(img_file, sess, z,
                                      block_shape=[8, 8],
                                      batch_size=args.batch_size,
                                      dec_output_2_img_func=binary_float_to_uint8)
                # ------------------------- #

            if global_step % args.viz_rec_freq == 0:
                # Reconstruct images
                # ------------------------- #
                x = x_train[np.random.choice(num_train, size=64, replace=False)]
                img_file = os.path.join(img_rec_dir, 'step[%d]_rec_test.png' % global_step)
                model.reconstruct_images(img_file, sess, x,
                                         block_shape=[8, 8],
                                         batch_size=args.batch_size,
                                         dec_output_2_img_func=binary_float_to_uint8)
                # ------------------------- #

            if global_step % args.viz_itpl_freq == 0:
                # Interpolate images
                # ------------------------- #
                x1 = x_train[np.random.choice(num_train, size=12, replace=False)]
                x2 = x_train[np.random.choice(num_train, size=12, replace=False)]
                img_file = os.path.join(img_itpl_dir, 'step[%d]_itpl_test.png' % global_step)
                model.interpolate_images(img_file, sess, x1, x2,
                                         num_itpl_points=12, batch_on_row=True,
                                         batch_size=args.batch_size,
                                         dec_output_2_img_func=binary_float_to_uint8)
                # ------------------------- #

    # Last save
    train_helper.save(sess, global_step)
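
# The training loops in this file feed a second, independently sampled batch
# (xa) to the discriminator, which FactorVAE uses to form samples from the
# product of marginals by permuting each latent dimension across the batch.
# That op is built inside the FactorVAE graph; the helper below is only an
# illustrative NumPy sketch of the "permute dims" trick, not part of the
# actual model.
def _permute_dims_sketch(z):
    """Shuffle each latent dimension independently across the batch.

    z: np.ndarray of shape [batch_size, z_dim], samples from q(z|x).
    Returns an array whose rows approximate samples from the product of
    marginals prod_j q(z_j), which the discriminator contrasts with q(z).
    """
    z_perm = np.empty_like(z)
    for j in range(z.shape[1]):
        # An independent permutation per dimension breaks correlations
        # between dimensions while preserving each marginal.
        z_perm[:, j] = np.random.permutation(z[:, j])
    return z_perm
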
def main(args):
    # =====================================
    # Preparation
    # =====================================
    celebA_loader = TFCelebALoader(root_dir=args.celebA_root_dir)

    num_train = celebA_loader.num_train_data
    num_test = celebA_loader.num_test_data

    img_height, img_width = args.celebA_resize_size, args.celebA_resize_size
    celebA_loader.build_transformation_flow_tf(
        *celebA_loader.get_transform_fns("1Konny", resize_size=args.celebA_resize_size))

    args.output_dir = os.path.join(args.output_dir, args.enc_dec_model, args.run)
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    else:
        if args.force_rm_dir:
            import shutil
            shutil.rmtree(args.output_dir, ignore_errors=True)
            print("Removed '{}'".format(args.output_dir))
        else:
            raise ValueError("Output directory '{}' already exists. "
                             "'force_rm_dir' must be set to True!".format(args.output_dir))
        os.mkdir(args.output_dir)

    save_args(os.path.join(args.output_dir, 'config.json'), args)
    # pp.pprint(args.__dict__)

    # =====================================
    # Instantiate models
    # =====================================
    # The activation is only used for the encoder and decoder
    if args.activation == "relu":
        activation = tf.nn.relu
    elif args.activation == "leaky_relu":
        activation = tf.nn.leaky_relu
    else:
        raise ValueError("Unsupported activation '{}'!".format(args.activation))

    if args.enc_dec_model == "1Konny":
        # assert args.z_dim == 65, "For 1Konny, z_dim must be 65. Found {}!".format(args.z_dim)
        encoder = Encoder_1Konny(args.z_dim, stochastic=True, activation=activation)
        decoder = Decoder_1Konny([img_height, img_width, 3],
                                 activation=activation,
                                 output_activation=tf.nn.sigmoid)
        disc_z = DiscriminatorZ_1Konny(num_outputs=2)
    else:
        raise ValueError("Unsupported encoder/decoder model '{}'!".format(args.enc_dec_model))

    model = FactorVAE([img_height, img_width, 3], args.z_dim,
                      encoder=encoder, decoder=decoder,
                      discriminator_z=disc_z,
                      rec_x_mode=args.rec_x_mode,
                      use_gp0_z_tc=True,
                      gp0_z_tc_mode=args.gp0_z_tc_mode)

    loss_coeff_dict = {
        'rec_x': args.rec_x_coeff,
        'kld_loss': args.kld_loss_coeff,
        'tc_loss': args.tc_loss_coeff,
        'Dz_tc_loss': args.Dz_tc_loss_coeff,
        'gp0_z_tc': args.gp0_z_tc_coeff,
    }

    model.build(loss_coeff_dict)
    SimpleParamPrinter.print_all_params_list()

    loss = model.get_loss()
    train_params = model.get_train_params()

    opt_Dz = tf.train.AdamOptimizer(learning_rate=args.lr_Dz,
                                    beta1=args.beta1_Dz, beta2=args.beta2_Dz)
    opt_vae = tf.train.AdamOptimizer(learning_rate=args.lr_vae,
                                     beta1=args.beta1_vae, beta2=args.beta2_vae)

    with tf.control_dependencies(model.get_all_update_ops()):
        train_op_Dz = opt_Dz.minimize(loss=loss['Dz_loss'],
                                      var_list=train_params['Dz_loss'])
        train_op_D = train_op_Dz
        train_op_vae = opt_vae.minimize(loss=loss['vae_loss'],
                                        var_list=train_params['vae_loss'])

    # =====================================
    # TF Graph Handler
    asset_dir = make_dir_if_not_exist(os.path.join(args.output_dir, "asset"))
    img_gen_dir = make_dir_if_not_exist(os.path.join(asset_dir, "img_gen"))
    img_rec_dir = make_dir_if_not_exist(os.path.join(asset_dir, "img_rec"))
    img_itpl_dir = make_dir_if_not_exist(os.path.join(asset_dir, "img_itpl"))

    log_dir = make_dir_if_not_exist(os.path.join(args.output_dir, "log"))
    train_log_file = os.path.join(log_dir, "train.log")

    summary_dir = make_dir_if_not_exist(os.path.join(args.output_dir, "summary_tf"))
    model_dir = make_dir_if_not_exist(os.path.join(args.output_dir, "model_tf"))

    train_helper = SimpleTrainHelper(
        log_dir=summary_dir,
        save_dir=model_dir,
        max_to_keep=3,
        max_to_keep_best=1,
    )
    # =====================================

    # =====================================
    # Training Loop
    # =====================================
    config_proto = tf.ConfigProto(allow_soft_placement=True)
    config_proto.gpu_options.allow_growth = True
    config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9
    sess = tf.Session(config=config_proto)

    train_helper.initialize(sess, init_variables=True, create_summary_writer=True)

    Dz_fetch_keys = [
        'Dz_loss', 'Dz_tc_loss', 'Dz_loss_normal', 'Dz_loss_factor',
        'Dz_avg_prob_normal', 'Dz_avg_prob_factor', 'gp0_z_tc'
    ]
    D_fetch_keys = Dz_fetch_keys
    vae_fetch_keys = ['vae_loss', 'rec_x', 'kld_loss', 'tc_loss']

    train_sampler = ContinuousIndexSampler(num_train, args.batch_size, shuffle=True)

    import math
    # Float division before ceil; under Python 2, num_train / args.batch_size
    # would truncate and make the ceil a no-op
    num_batches_per_epoch = int(math.ceil(num_train / float(args.batch_size)))

    global_step = 0
    for epoch in range(args.epochs):
        for _ in range(num_batches_per_epoch):
            global_step += 1

            batch_ids = train_sampler.sample_ids()
            x = celebA_loader.sample_images_from_dataset(sess, 'train', batch_ids)
            z = np.random.randn(len(x), args.z_dim)

            # A second, independently sampled batch for the TC discriminator
            batch_ids_2 = np.random.choice(num_train, size=len(batch_ids))
            xa = celebA_loader.sample_images_from_dataset(sess, 'train', batch_ids_2)

            for i in range(args.D_steps):
                _, Dm = sess.run(
                    [train_op_D, model.get_output(D_fetch_keys, as_dict=True)],
                    feed_dict={model.is_train: True, model.x_ph: x,
                               model.z_ph: z, model.xa_ph: xa})

            for i in range(args.vae_steps):
                _, VAEm = sess.run(
                    [train_op_vae, model.get_output(vae_fetch_keys, as_dict=True)],
                    feed_dict={model.is_train: True, model.x_ph: x,
                               model.z_ph: z, model.xa_ph: xa})

            if global_step % args.save_freq == 0:
                train_helper.save(sess, global_step)

            if global_step % args.log_freq == 0:
                log_str = "\n[FactorVAE (celebA)/{}, {}]".format(args.enc_dec_model, args.run) + \
                          "\nEpoch {}/{}, Step {}, vae_loss: {:.4f}, Dz_loss: {:.4f}, Dz_tc_loss: {:.4f}".format(
                              epoch, args.epochs, global_step,
                              VAEm['vae_loss'], Dm['Dz_loss'], Dm['Dz_tc_loss']) + \
                          "\nrec_x: {:.4f}, kld_loss: {:.4f}, tc_loss: {:.4f}".format(
                              VAEm['rec_x'], VAEm['kld_loss'], VAEm['tc_loss']) + \
                          "\nDz_loss_normal: {:.4f}, Dz_loss_factor: {:.4f}".format(
                              Dm['Dz_loss_normal'], Dm['Dz_loss_factor']) + \
                          "\nDz_avg_prob_normal: {:.4f}, Dz_avg_prob_factor: {:.4f}".format(
                              Dm['Dz_avg_prob_normal'], Dm['Dz_avg_prob_factor']) + \
                          "\ngp0_z_tc_coeff: {:.4f}, gp0_z_tc: {:.4f}".format(
                              args.gp0_z_tc_coeff, Dm['gp0_z_tc'])

                print(log_str)
                # The 'with' block closes the file; no explicit close() needed
                with open(train_log_file, "a") as f:
                    f.write(log_str)
                    f.write("\n")

                # TensorBoard scalar summaries
                train_helper.add_summary(
                    custom_tf_scalar_summary('vae_loss', VAEm['vae_loss'], prefix='train'),
                    global_step)
                train_helper.add_summary(
                    custom_tf_scalar_summary('rec_x', VAEm['rec_x'], prefix='train'),
                    global_step)
                train_helper.add_summary(
                    custom_tf_scalar_summary('kld_loss', VAEm['kld_loss'], prefix='train'),
                    global_step)
                train_helper.add_summary(
                    custom_tf_scalar_summary('tc_loss', VAEm['tc_loss'], prefix='train'),
                    global_step)
                train_helper.add_summary(
                    custom_tf_scalar_summary('Dz_tc_loss', Dm['Dz_tc_loss'], prefix='train'),
                    global_step)
                train_helper.add_summary(
                    custom_tf_scalar_summary('Dz_loss_normal', Dm['Dz_loss_normal'], prefix='train'),
                    global_step)
                train_helper.add_summary(
                    custom_tf_scalar_summary('Dz_loss_factor', Dm['Dz_loss_factor'], prefix='train'),
                    global_step)
                train_helper.add_summary(
                    custom_tf_scalar_summary('Dz_prob_normal', Dm['Dz_avg_prob_normal'], prefix='train'),
                    global_step)
                train_helper.add_summary(
                    custom_tf_scalar_summary('Dz_prob_factor', Dm['Dz_avg_prob_factor'], prefix='train'),
                    global_step)

            if global_step % args.viz_gen_freq == 0:
                # Generate images
                # ------------------------- #
                z = np.random.randn(64, args.z_dim)
                img_file = os.path.join(img_gen_dir, 'step[%d]_gen_test.png' % global_step)
                model.generate_images(img_file, sess, z,
                                      block_shape=[8, 8],
                                      batch_size=args.batch_size,
                                      dec_output_2_img_func=binary_float_to_uint8)
                # ------------------------- #

            if global_step % args.viz_rec_freq == 0:
                # Reconstruct images
                # ------------------------- #
                x = celebA_loader.sample_images_from_dataset(
                    sess, 'test', np.random.choice(num_test, size=64, replace=False))
                img_file = os.path.join(img_rec_dir, 'step[%d]_rec_test.png' % global_step)
                model.reconstruct_images(img_file, sess, x,
                                         block_shape=[8, 8],
                                         batch_size=args.batch_size,
                                         dec_output_2_img_func=binary_float_to_uint8)
                # ------------------------- #

            if global_step % args.viz_itpl_freq == 0:
                # Interpolate images
                # ------------------------- #
                x1 = celebA_loader.sample_images_from_dataset(
                    sess, 'test', np.random.choice(num_test, size=12, replace=False))
                x2 = celebA_loader.sample_images_from_dataset(
                    sess, 'test', np.random.choice(num_test, size=12, replace=False))
                img_file = os.path.join(img_itpl_dir, 'step[%d]_itpl_test.png' % global_step)
                model.interpolate_images(img_file, sess, x1, x2,
                                         num_itpl_points=12, batch_on_row=True,
                                         batch_size=args.batch_size,
                                         dec_output_2_img_func=binary_float_to_uint8)
                # ------------------------- #

        # Once per epoch, not per step: keep a standalone snapshot every 100 epochs
        if epoch % 100 == 0:
            train_helper.save_separately(
                sess, model_name="model_epoch{}".format(epoch), global_step=global_step)

    # Last save
    train_helper.save(sess, global_step)
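
# ContinuousIndexSampler is a project helper; the loop above only relies on
# sample_ids() returning `batch_size` indices per call, reshuffling once the
# pool is exhausted so batches are decoupled from epoch boundaries. The class
# below is a minimal sketch of that assumed behavior, not the real
# implementation.
class _ContinuousIndexSamplerSketch(object):
    def __init__(self, num_ids, batch_size, shuffle=True):
        self.num_ids = num_ids
        self.batch_size = batch_size
        self.shuffle = shuffle
        self._pool = []

    def sample_ids(self):
        # Refill (and optionally reshuffle) whenever fewer than a full
        # batch of indices remains in the pool.
        while len(self._pool) < self.batch_size:
            ids = np.arange(self.num_ids)
            if self.shuffle:
                np.random.shuffle(ids)
            self._pool.extend(ids.tolist())
        batch = self._pool[:self.batch_size]
        self._pool = self._pool[self.batch_size:]
        return batch
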
def main(args):
    # =====================================
    # Load config
    # =====================================
    with open(os.path.join(args.output_dir, 'config.json')) as f:
        config = json.load(f)
    args.__dict__.update(config)

    # =====================================
    # Preparation
    # =====================================
    celebA_loader = TFCelebALoader(root_dir=args.celebA_root_dir)

    num_train = celebA_loader.num_train_data
    num_test = celebA_loader.num_test_data

    img_height, img_width = args.celebA_resize_size, args.celebA_resize_size
    celebA_loader.build_transformation_flow_tf(
        *celebA_loader.get_transform_fns("1Konny", resize_size=args.celebA_resize_size))

    # =====================================
    # Instantiate models
    # =====================================
    # The activation is only used for the encoder and decoder
    if args.activation == "relu":
        activation = tf.nn.relu
    elif args.activation == "leaky_relu":
        activation = tf.nn.leaky_relu
    else:
        raise ValueError("Unsupported activation '{}'!".format(args.activation))

    if args.enc_dec_model == "1Konny":
        # assert args.z_dim == 65, "For 1Konny, z_dim must be 65. Found {}!".format(args.z_dim)
        encoder = Encoder_1Konny(args.z_dim, stochastic=True, activation=activation)
        decoder = Decoder_1Konny([img_height, img_width, 3],
                                 activation=activation,
                                 output_activation=tf.nn.sigmoid)
        disc_z = DiscriminatorZ_1Konny(num_outputs=2)
    elif args.enc_dec_model == "my":
        # assert args.z_dim == 150, "For my, z_dim must be 150. Found {}!".format(args.z_dim)
        encoder = Encoder_My(args.z_dim, stochastic=True, activation=activation)
        decoder = Decoder_My([img_height, img_width, 3],
                             activation=activation,
                             output_activation=tf.nn.sigmoid)
        disc_z = DiscriminatorZ_My(num_outputs=2)
    else:
        raise ValueError("Unsupported encoder/decoder model '{}'!".format(args.enc_dec_model))

    model = FactorVAE([img_height, img_width, 3], args.z_dim,
                      encoder=encoder, decoder=decoder,
                      discriminator_z=disc_z,
                      rec_x_mode=args.rec_x_mode,
                      use_gp0_z_tc=True,
                      gp0_z_tc_mode=args.gp0_z_tc_mode)

    loss_coeff_dict = {
        'rec_x': args.rec_x_coeff,
        'kld_loss': args.kld_loss_coeff,
        'tc_loss': args.tc_loss_coeff,
        'gp0_z_tc': args.gp0_z_tc_coeff,
        # 'Dz_tc_loss': args.Dz_tc_loss_coeff,
    }

    model.build(loss_coeff_dict)
    SimpleParamPrinter.print_all_params_list()

    # =====================================
    # TF Graph Handler
    asset_dir = make_dir_if_not_exist(os.path.join(args.output_dir, "asset"))
    img_eval = remove_dir_if_exist(os.path.join(asset_dir, "img_eval"), ask_4_permission=False)
    img_eval = make_dir_if_not_exist(img_eval)

    img_x_gen = make_dir_if_not_exist(os.path.join(img_eval, "x_gen"))
    img_x_rec = make_dir_if_not_exist(os.path.join(img_eval, "x_rec"))
    img_z_rand_2_traversal = make_dir_if_not_exist(os.path.join(img_eval, "z_rand_2_traversal"))
    img_z_cond_all_traversal = make_dir_if_not_exist(os.path.join(img_eval, "z_cond_all_traversal"))
    img_z_cond_1_traversal = make_dir_if_not_exist(os.path.join(img_eval, "z_cond_1_traversal"))
    img_z_corr = make_dir_if_not_exist(os.path.join(img_eval, "z_corr"))
    img_z_dist = make_dir_if_not_exist(os.path.join(img_eval, "z_dist"))
    img_z_stat_dist = make_dir_if_not_exist(os.path.join(img_eval, "z_stat_dist"))
    # img_rec_error_dist = make_dir_if_not_exist(os.path.join(img_eval, "rec_error_dist"))

    model_dir = make_dir_if_not_exist(os.path.join(args.output_dir, "model_tf"))
    train_helper = SimpleTrainHelper(log_dir=None, save_dir=model_dir)
    # =====================================

    # =====================================
    # Evaluation
    # =====================================
    config_proto = tf.ConfigProto(allow_soft_placement=True)
    config_proto.gpu_options.allow_growth = True
    config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9
    sess = tf.Session(config=config_proto)

    # Load model
    train_helper.load(sess, load_step=args.load_step)

    # '''
    # Generation
    # ======================================= #
    z = np.random.randn(64, args.z_dim)
    img_file = os.path.join(img_x_gen, 'x_gen_test.png')
    model.generate_images(img_file, sess, z,
                          block_shape=[8, 8],
                          batch_size=args.batch_size,
                          dec_output_2_img_func=binary_float_to_uint8)
    # ======================================= #
    # '''

    # '''
    # Reconstruction
    # ======================================= #
    seed = 389
    x = celebA_loader.sample_images_from_dataset(sess, 'test', list(range(seed, seed + 64)))
    img_file = os.path.join(img_x_rec, 'x_rec_test.png')
    model.reconstruct_images(img_file, sess, x,
                             block_shape=[8, 8],
                             batch_size=-1,
                             dec_output_2_img_func=binary_float_to_uint8)
    # ======================================= #
    # '''

    # '''
    # z random traversal
    # ======================================= #
    if args.z_dim <= 5:
        print("z_dim = {}, perform random traversal!".format(args.z_dim))

        # Plot z cont with z cont
        z_zero = np.zeros([args.z_dim], dtype=np.float32)
        z_rand = np.random.randn(args.z_dim)
        z_start, z_stop = -4, 4
        num_points = 8

        for i in range(args.z_dim):
            for j in range(i + 1, args.z_dim):
                print("Plot random 2-component z traversal for components {} and {}!".format(i, j))

                img_file = os.path.join(img_z_rand_2_traversal, 'z[{},{},zero].png'.format(i, j))
                model.rand_2_latents_traverse(img_file, sess, default_z=z_zero,
                                              z_comp1=i, start1=z_start, stop1=z_stop,
                                              num_points1=num_points,
                                              z_comp2=j, start2=z_start, stop2=z_stop,
                                              num_points2=num_points,
                                              batch_size=args.batch_size,
                                              dec_output_2_img_func=binary_float_to_uint8)

                img_file = os.path.join(img_z_rand_2_traversal, 'z[{},{},rand].png'.format(i, j))
                # start2 was z_stop, which collapsed the second traversal
                # range to a single point; it should start at z_start
                model.rand_2_latents_traverse(img_file, sess, default_z=z_rand,
                                              z_comp1=i, start1=z_start, stop1=z_stop,
                                              num_points1=num_points,
                                              z_comp2=j, start2=z_start, stop2=z_stop,
                                              num_points2=num_points,
                                              batch_size=args.batch_size,
                                              dec_output_2_img_func=binary_float_to_uint8)
    # ======================================= #
    # '''

    # z conditional traversal (all features + one feature)
    # ======================================= #
    seed = 389
    num_samples = 30
    data = celebA_loader.sample_images_from_dataset(
        sess, 'train', list(range(seed, seed + num_samples)))

    z_start, z_stop = -4, 4
    num_itpl_points = 8
    for n in range(num_samples):
        print("Plot conditional all comps z traversal with train sample {}!".format(n))
        img_file = os.path.join(img_z_cond_all_traversal, 'x_train{}.png'.format(n))
        model.cond_all_latents_traverse(img_file, sess, data[n],
                                        start=z_start, stop=z_stop,
                                        num_itpl_points=num_itpl_points,
                                        batch_size=args.batch_size,
                                        dec_output_2_img_func=binary_float_to_uint8)

    z_start, z_stop = -4, 4
    num_itpl_points = 8
    for i in range(args.z_dim):
        print("Plot conditional z traversal with comp {}!".format(i))
        img_file = os.path.join(img_z_cond_1_traversal,
                                'x_train[{},{}]_z{}.png'.format(seed, seed + num_samples, i))
        model.cond_1_latent_traverse(img_file, sess, data, z_comp=i,
                                     start=z_start, stop=z_stop,
                                     num_itpl_points=num_itpl_points,
                                     batch_size=args.batch_size,
                                     dec_output_2_img_func=binary_float_to_uint8)
    # ======================================= #
    # '''

    # '''
    # z correlation matrix
    # ======================================= #
    all_z = []
    for batch_ids in iterate_data(num_train, args.batch_size, shuffle=False):
        x = celebA_loader.sample_images_from_dataset(sess, 'train', batch_ids)
        z = model.encode(sess, x)
        assert len(z.shape) == 2 and z.shape[1] == args.z_dim, \
            "z.shape: {}".format(z.shape)
        all_z.append(z)
    all_z = np.concatenate(all_z, axis=0)

    print("Start plotting!")
    plot_corrmat_with_histogram(os.path.join(img_z_corr, "corr_mat.png"), all_z)
    # '.png' appended for consistency with the other plot file templates
    plot_comp_dist(os.path.join(img_z_dist, 'z_{}.png'), all_z, x_lim=(-5, 5))
    print("Done!")
    # ======================================= #
    # '''

    # '''
    # z gaussian stddev
    # ======================================= #
    print("\nPlot z mean and stddev!")
    all_z_mean = []
    all_z_stddev = []
    for batch_ids in iterate_data(num_train, args.batch_size, shuffle=False):
        x = celebA_loader.sample_images_from_dataset(sess, 'train', batch_ids)
        z_mean, z_stddev = sess.run(model.get_output(['z_mean', 'z_stddev']),
                                    feed_dict={model.is_train: False, model.x_ph: x})
        all_z_mean.append(z_mean)
        all_z_stddev.append(z_stddev)

    all_z_mean = np.concatenate(all_z_mean, axis=0)
    all_z_stddev = np.concatenate(all_z_stddev, axis=0)

    plot_comp_dist(os.path.join(img_z_stat_dist, 'z_mean_{}.png'), all_z_mean, x_lim=(-5, 5))
    plot_comp_dist(os.path.join(img_z_stat_dist, 'z_stddev_{}.png'), all_z_stddev, x_lim=(-0.5, 3))
    # ======================================= #
    # '''
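
# binary_float_to_uint8 is passed as dec_output_2_img_func throughout this
# file. Given that the decoders use sigmoid outputs in [0, 1], a plausible
# sketch of what such a converter does is shown below; the real helper lives
# elsewhere in the project, so treat this as an assumption, not its source.
def _binary_float_to_uint8_sketch(x):
    """Map float images in [0, 1] to uint8 pixels in [0, 255]."""
    return (np.clip(x, 0.0, 1.0) * 255.0).astype(np.uint8)
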