def reconstruct_images(self, save_file, sess, x, block_shape,
                       show_original_images=True, batch_size=20,
                       dec_output_2_img_func=None, **kwargs):
    # Reconstruct in one pass (batch_size < 0) or in mini-batches
    if batch_size < 0:
        x1 = self.reconstruct(sess, x, **kwargs)
    else:
        x1 = []
        for batch_ids in iterate_data(len(x), batch_size, shuffle=False):
            x1.append(self.reconstruct(sess, x[batch_ids], **kwargs))
        x1 = np.concatenate(x1, axis=0)

    if dec_output_2_img_func is not None:
        x1 = dec_output_2_img_func(x1)
        x = dec_output_2_img_func(x)

    x1 = np.reshape(x1, to_list(block_shape) + self.x_shape)
    x = np.reshape(x, to_list(block_shape) + self.x_shape)

    if show_original_images:
        save_img_blocks_col_by_col(save_file, [x, x1])
    else:
        save_img_block(save_file, x1)
def generate_images(self, save_file, sess, z, block_shape, batch_size=20,
                    dec_output_2_img_func=None, **kwargs):
    # Decode in one pass (batch_size < 0) or in mini-batches
    if batch_size < 0:
        x1_gen = self.decode(sess, z, **kwargs)
    else:
        x1_gen = []
        for batch_ids in iterate_data(len(z), batch_size, shuffle=False):
            x1_gen.append(self.decode(sess, z[batch_ids], **kwargs))
        x1_gen = np.concatenate(x1_gen, axis=0)

    if dec_output_2_img_func is not None:
        x1_gen = dec_output_2_img_func(x1_gen)

    x1_gen = np.reshape(x1_gen, to_list(block_shape) + self.x_shape)
    save_img_block(save_file, x1_gen)
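# --------------------------------------------------------------------------- #
# A minimal usage sketch for the two helpers above (hedged: `model`, `sess`,
# `x_batch` of 64 images and `binary_float_to_uint8` are assumed to exist in
# the calling script; the latent size of 10 is a placeholder):
#
#     z_sample = np.random.randn(64, 10).astype(np.float32)
#     model.reconstruct_images("x_rec.png", sess, x_batch, block_shape=[8, 8],
#                              batch_size=20,
#                              dec_output_2_img_func=binary_float_to_uint8)
#     model.generate_images("x_gen.png", sess, z_sample, block_shape=[8, 8],
#                           batch_size=20,
#                           dec_output_2_img_func=binary_float_to_uint8)
# --------------------------------------------------------------------------- #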
def main(args):
    # Load config
    # ===================================== #
    with open(join(args.output_dir, 'config.json')) as f:
        config = json.load(f)
    args.__dict__.update(config)
    # ===================================== #

    # Load dataset
    # ===================================== #
    data_file = join(RAW_DATA_DIR, "ComputerVision", "dSprites",
                     "dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz")

    # It is already in the range [0, 1]
    with np.load(data_file, encoding="latin1") as f:
        x_train = f['imgs']
        # 3 shape * 6 scale * 40 rotation * 32 pos X * 32 pos Y
        # (column 0 is the constant 'color' factor, so we drop it)
        y_train = f['latents_classes'][:, 1:]

    x_train = np.expand_dims(x_train.astype(np.float32), axis=-1)
    num_train = len(x_train)
    print("num_train: {}".format(num_train))
    # ===================================== #

    # Build model
    # ===================================== #
    if args.enc_dec_model == "1Konny":
        encoder = Encoder_1Konny(args.z_dim, stochastic=True)
        decoder = Decoder_1Konny()
        disc_z = DiscriminatorZ_1Konny()
    else:
        raise ValueError("Do not support enc_dec_model='{}'!".format(
            args.enc_dec_model))

    model = AAE([64, 64, 1], args.z_dim,
                encoder=encoder, decoder=decoder,
                discriminator_z=disc_z,
                rec_x_mode=args.rec_x_mode,
                stochastic_z=args.stochastic_z,
                use_gp0_z=True, gp0_z_mode=args.gp0_z_mode)

    loss_coeff_dict = {
        'rec_x': args.rec_x_coeff,
        'G_loss_z1_gen': args.G_loss_z1_gen_coeff,
        'D_loss_z1_gen': args.D_loss_z1_gen_coeff,
        'gp0_z': args.gp0_z_coeff,
    }
    model.build(loss_coeff_dict)
    SimpleParamPrinter.print_all_params_tf_slim()
    # ===================================== #

    # Initialize session
    # ===================================== #
    config_proto = tf.ConfigProto(allow_soft_placement=True)
    config_proto.gpu_options.allow_growth = True
    config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9
    sess = tf.Session(config=config_proto)

    model_dir = make_dir_if_not_exist(join(args.output_dir, "model_tf"))
    train_helper = SimpleTrainHelper(log_dir=None, save_dir=model_dir)
    train_helper.load(sess, load_step=args.load_step)
    # ===================================== #

    # Experiments
    # ===================================== #
    save_dir = make_dir_if_not_exist(
        join(args.save_dir, "{}_{}".format(args.enc_dec_model, args.run)))

    # np.inf rather than np.nan: newer numpy rejects nan thresholds
    np.set_printoptions(threshold=np.inf, linewidth=1000, precision=5,
                        suppress=True)
    # ===================================== #

    # Compute representations
    # ===================================== #
    z_data_file = join(save_dir, "z_data.npz")

    if not exists(z_data_file):
        all_z_mean = []
        all_z_stddev = []

        print("")
        print("Compute all_z_mean, all_z_stddev!")
        count = 0
        for batch_ids in iterate_data(num_train, 10 * args.batch_size,
                                      shuffle=False):
            x = x_train[batch_ids]

            z_samples, z_mean, z_stddev = sess.run(
                model.get_output(['z1_gen', 'z_mean', 'z_stddev']),
                feed_dict={model.is_train: False, model.x_ph: x})

            all_z_mean.append(z_mean)
            all_z_stddev.append(z_stddev)

            count += len(batch_ids)
            print("\rProcessed {} samples!".format(count), end="")
        print()

        all_z_mean = np.concatenate(all_z_mean, axis=0)
        all_z_stddev = np.concatenate(all_z_stddev, axis=0)

        np.savez_compressed(z_data_file, all_z_mean=all_z_mean,
                            all_z_stddev=all_z_stddev)
    else:
        print("{} exists. Load data from file!".format(z_data_file))
        with np.load(z_data_file, "r") as f:
            all_z_mean = f['all_z_mean']
            all_z_stddev = f['all_z_stddev']
    # ===================================== #

    # 'shape' is a discrete factor; the other four dSprites factors are
    # treated as continuous
    cont_mask = [False, True, True, True, True] if args.continuous_only else None

    if args.classifier == "LASSO":
        results = compute_metrics_with_LASSO(
            latents=all_z_mean, factors=y_train,
            params={'alpha': args.LASSO_alpha, 'max_iter': args.LASSO_iters},
            cont_mask=cont_mask)

        result_file = join(
            save_dir, "results[LASSO,{},alpha={},iters={}].npz".format(
                "cont" if args.continuous_only else "all",
                args.LASSO_alpha, args.LASSO_iters))
    else:
        results = compute_metrics_with_RandomForest(
            latents=all_z_mean, factors=y_train,
            params={'n_estimators': args.RF_trees, 'max_depth': args.RF_depth})

        result_file = join(
            save_dir, "results[RF,{},trees={},depth={}].npz".format(
                "cont" if args.continuous_only else "all",
                args.RF_trees, args.RF_depth))

    np.savez_compressed(result_file, **results)
def main(args):
    # =====================================
    # Load config
    # =====================================
    with open(join(args.output_dir, 'config.json')) as f:
        config = json.load(f)
    args.__dict__.update(config)

    # =====================================
    # Dataset
    # =====================================
    data_file = join(RAW_DATA_DIR, "ComputerVision", "dSprites",
                     "dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz")

    # It is already in the range [0, 1]
    with np.load(data_file, encoding="latin1") as f:
        x_train = f['imgs']
        # 3 shape * 6 scale * 40 rotation * 32 pos X * 32 pos Y
        y_train = f['latents_classes']

    x_train = np.expand_dims(x_train.astype(np.float32), axis=-1)
    num_train = len(x_train)
    print("num_train: {}".format(num_train))
    print("y_train[:10]: {}".format(y_train[:10]))

    # =====================================
    # Instantiate model
    # =====================================
    if args.enc_dec_model == "1Konny":
        encoder = Encoder_1Konny(args.z_dim, stochastic=True)
        decoder = Decoder_1Konny()
        disc_z = DiscriminatorZ_1Konny(num_outputs=2)
    else:
        raise ValueError("Do not support enc_dec_model='{}'!".format(
            args.enc_dec_model))

    model = FactorVAE([64, 64, 1], args.z_dim,
                      encoder=encoder, decoder=decoder,
                      discriminator_z=disc_z,
                      rec_x_mode=args.rec_x_mode,
                      use_gp0_z_tc=True,
                      gp0_z_tc_mode=args.gp0_z_tc_mode)

    loss_coeff_dict = {
        'rec_x': args.rec_x_coeff,
        'kld_loss': args.kld_loss_coeff,
        'tc_loss': args.tc_loss_coeff,
        'gp0_z_tc': args.gp0_z_tc_coeff,
        'Dz_tc_loss_coeff': args.Dz_tc_loss_coeff,
    }
    model.build(loss_coeff_dict)
    SimpleParamPrinter.print_all_params_tf_slim()

    # =====================================
    # Load model
    # =====================================
    config_proto = tf.ConfigProto(allow_soft_placement=True)
    config_proto.gpu_options.allow_growth = True
    config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9
    sess = tf.Session(config=config_proto)

    model_dir = make_dir_if_not_exist(join(args.output_dir, "model_tf"))
    train_helper = SimpleTrainHelper(log_dir=None, save_dir=model_dir)

    # Load model
    train_helper.load(sess, load_step=args.load_step)

    # =====================================
    # Experiments
    save_dir = make_dir_if_not_exist(
        join(args.save_dir, "{}_{}".format(args.enc_dec_model, args.run)))
    # =====================================

    np.set_printoptions(threshold=np.inf, linewidth=1000, precision=5,
                        suppress=True)

    num_bins = args.num_bins
    bin_limits = tuple([float(s) for s in args.bin_limits.split(";")])
    data_proportion = args.data_proportion
    num_data = int(data_proportion * num_train)
    assert num_data == num_train, "For dSprites, you must use all data!"
    eps = 1e-8

    # Log file
    f = open(join(save_dir, 'log[bins={},bin_limits={},data={}].txt'.format(
        num_bins, bin_limits, data_proportion)), mode='w')

    # Print function that writes both to stdout and to the log file
    print_ = functools.partial(print_both, file=f)

    print_("num_bins: {}".format(num_bins))
    print_("bin_limits: {}".format(bin_limits))
    print_("data_proportion: {}".format(data_proportion))

    # Compute bins
    # ================================= #
    print_("")
    print_("bin_limits: {}".format(bin_limits))
    assert len(bin_limits) == 2 and bin_limits[0] < bin_limits[1], \
        "bin_limits={}".format(bin_limits)

    bins = np.linspace(bin_limits[0], bin_limits[1], num_bins + 1, endpoint=True)
    print_("bins: {}".format(bins))
    assert len(bins) == num_bins + 1

    bin_widths = [bins[b] - bins[b - 1] for b in range(1, len(bins))]
    print_("bin_widths: {}".format(bin_widths))
    assert len(bin_widths) == num_bins, \
        "len(bin_widths)={} while num_bins={}!".format(len(bin_widths), num_bins)
    assert np.all(np.greater(bin_widths, 0)), "bin_widths: {}".format(bin_widths)

    bin_centers = [(bins[b] + bins[b - 1]) * 0.5 for b in range(1, len(bins))]
    print_("bin_centers: {}".format(bin_centers))
    assert len(bin_centers) == num_bins, \
        "len(bin_centers)={} while num_bins={}!".format(len(bin_centers), num_bins)
    # ================================= #

    # Compute representations
    # ================================= #
    z_data_file = join(save_dir, "z_data[data={}].npz".format(data_proportion))

    if not exists(z_data_file):
        all_z_mean = []
        all_z_stddev = []

        print("")
        print("Compute all_z_mean and all_z_stddev!")
        count = 0
        for batch_ids in iterate_data(num_data, 10 * args.batch_size,
                                      shuffle=False):
            x = x_train[batch_ids]

            z_mean, z_stddev = sess.run(
                model.get_output(['z_mean', 'z_stddev']),
                feed_dict={model.is_train: False, model.x_ph: x})

            all_z_mean.append(z_mean)
            all_z_stddev.append(z_stddev)

            count += len(batch_ids)
            print("\rProcessed {} samples!".format(count), end="")
        print()

        all_z_mean = np.concatenate(all_z_mean, axis=0)
        all_z_stddev = np.concatenate(all_z_stddev, axis=0)

        np.savez_compressed(z_data_file, all_z_mean=all_z_mean,
                            all_z_stddev=all_z_stddev)
    else:
        print("{} exists. Load data from file!".format(z_data_file))
        with np.load(z_data_file, "r") as data:
            all_z_mean = data['all_z_mean']
            all_z_stddev = data['all_z_stddev']
    # ================================= #

    # Compute mutual information
    # ================================= #
    H_z = []
    H_z_cond_x = []
    MI_z_x = []
    norm_MI_z_x = []
    Q_z_cond_x = []

    for i in range(args.z_dim):
        print_("")
        print_("Compute I(z{}, x)!".format(i))

        # Q(s|x): the Gaussian posterior of the i-th latent, discretized
        # over the fixed bins
        all_Q_s_cond_x = []
        for batch_ids in iterate_data(len(all_z_mean), 500, shuffle=False,
                                      include_remaining=True):
            # (batch_size, num_bins)
            q_s_cond_x = normal_density(
                np.expand_dims(bin_centers, axis=0),
                mean=np.expand_dims(all_z_mean[batch_ids, i], axis=-1),
                stddev=np.expand_dims(all_z_stddev[batch_ids, i], axis=-1))

            # (batch_size,)
            max_q_s_cond_x = np.max(q_s_cond_x, axis=-1)
            # print("max_q_s_cond_x: {}".format(np.sort(max_q_s_cond_x)))

            # (batch_size, num_bins): one-hot fallback for samples whose
            # density mass falls essentially outside the bin range
            deter_s_cond_x = at_bin(all_z_mean[batch_ids, i],
                                    bins).astype(np.float32)

            # (batch_size, num_bins)
            Q_s_cond_x = q_s_cond_x * np.expand_dims(bin_widths, axis=0)
            Q_s_cond_x = Q_s_cond_x / np.maximum(
                np.sum(Q_s_cond_x, axis=1, keepdims=True), eps)
            Q_s_cond_x = np.where(
                np.expand_dims(np.less(max_q_s_cond_x, 1e-5), axis=-1),
                deter_s_cond_x, Q_s_cond_x)

            all_Q_s_cond_x.append(Q_s_cond_x)

        all_Q_s_cond_x = np.concatenate(all_Q_s_cond_x, axis=0)
        print_("sort(sum(all_Q_s_cond_x))[:100]: {}".format(
            np.sort(np.sum(all_Q_s_cond_x, axis=-1), axis=0)[:100]))
        assert np.all(all_Q_s_cond_x >= 0), \
            "'all_Q_s_cond_x' contains negative values. " \
            "sorted_all_Q_s_cond_x[:30]:\n{}!".format(
                np.sort(all_Q_s_cond_x[:30], axis=None))
        Q_z_cond_x.append(all_Q_s_cond_x)

        # H(zi|x) = -E_x[sum_s Q(s|x) log Q(s|x)]
        H_zi_cond_x = -np.mean(np.sum(
            all_Q_s_cond_x * np.log(np.maximum(all_Q_s_cond_x, eps)),
            axis=1), axis=0)

        # Q(s) = E_x[Q(s|x)]
        Q_s = np.mean(all_Q_s_cond_x, axis=0)
        print_("Q_s: {}".format(Q_s))
        print_("sum(Q_s): {}".format(sum(Q_s)))
        assert np.all(Q_s >= 0), \
            "'Q_s' contains negative values. " \
            "sorted_Q_s[:10]:\n{}!".format(np.sort(Q_s, axis=None))

        Q_s = Q_s / np.sum(Q_s, axis=0)
        print_("sum(Q_s) (normalized): {}".format(sum(Q_s)))

        H_zi = -np.sum(Q_s * np.log(np.maximum(Q_s, eps)), axis=0)
        MI_zi_x = H_zi - H_zi_cond_x
        normalized_MI_zi_x = (1.0 * MI_zi_x) / (H_zi + eps)

        print_("H_zi: {}".format(H_zi))
        print_("H_zi_cond_x: {}".format(H_zi_cond_x))
        print_("MI_zi_x: {}".format(MI_zi_x))
        print_("normalized_MI_zi_x: {}".format(normalized_MI_zi_x))

        H_z.append(H_zi)
        H_z_cond_x.append(H_zi_cond_x)
        MI_z_x.append(MI_zi_x)
        norm_MI_z_x.append(normalized_MI_zi_x)

    H_z = np.asarray(H_z, dtype=np.float32)
    H_z_cond_x = np.asarray(H_z_cond_x, dtype=np.float32)
    MI_z_x = np.asarray(MI_z_x, dtype=np.float32)
    norm_MI_z_x = np.asarray(norm_MI_z_x, dtype=np.float32)

    print_("")
    print_("H_z: {}".format(H_z))
    print_("H_z_cond_x: {}".format(H_z_cond_x))
    print_("MI_z_x: {}".format(MI_z_x))
    print_("norm_MI_z_x: {}".format(norm_MI_z_x))

    sorted_z_comps = np.argsort(MI_z_x, axis=0)[::-1]
    sorted_MI_z_x = np.take_along_axis(MI_z_x, sorted_z_comps, axis=0)
    print_("sorted_MI_z_x: {}".format(sorted_MI_z_x))
    print_("sorted_z_comps: {}".format(sorted_z_comps))

    sorted_norm_z_comps = np.argsort(norm_MI_z_x, axis=0)[::-1]
    sorted_norm_MI_z_x = np.take_along_axis(norm_MI_z_x, sorted_norm_z_comps,
                                            axis=0)
    print_("sorted_norm_MI_z_x: {}".format(sorted_norm_MI_z_x))
    print_("sorted_norm_z_comps: {}".format(sorted_norm_z_comps))

    result_file = join(save_dir, 'results[bins={},bin_limits={},data={}].npz'.format(
        num_bins, bin_limits, data_proportion))
    np.savez_compressed(result_file,
                        H_z=H_z, H_z_cond_x=H_z_cond_x,
                        MI_z_x=MI_z_x, norm_MI_z_x=norm_MI_z_x,
                        sorted_MI_z_x=sorted_MI_z_x,
                        sorted_z_comps=sorted_z_comps,
                        sorted_norm_MI_z_x=sorted_norm_MI_z_x,
                        sorted_norm_z_comps=sorted_norm_z_comps)

    Q_z_cond_x = np.asarray(Q_z_cond_x, dtype=np.float32)
    z_prob_file = join(save_dir, 'z_prob[bins={},bin_limits={},data={}].npz'.format(
        num_bins, bin_limits, data_proportion))
    np.savez_compressed(z_prob_file, Q_z_cond_x=Q_z_cond_x)
    # ================================= #

    f.close()
def main(args):
    # =====================================
    # Load config
    # =====================================
    with open(join(args.output_dir, 'config.json')) as f:
        config = json.load(f)
    args.__dict__.update(config)

    # =====================================
    # Dataset
    # =====================================
    celebA_loader = TFCelebAWithAttrLoader(root_dir=args.celebA_root_dir)

    img_height, img_width = args.celebA_resize_size, args.celebA_resize_size
    celebA_loader.build_transformation_flow_tf(
        *celebA_loader.get_transform_fns("1Konny",
                                         resize_size=args.celebA_resize_size))

    # =====================================
    # Instantiate model
    # =====================================
    if args.activation == "relu":
        activation = tf.nn.relu
    elif args.activation == "leaky_relu":
        activation = tf.nn.leaky_relu
    else:
        raise ValueError("Do not support '{}' activation!".format(
            args.activation))

    if args.enc_dec_model == "1Konny":
        # assert args.z_dim == 65, "For 1Konny, z_dim must be 65. Found {}!".format(args.z_dim)
        encoder = Encoder_1Konny(args.z_dim, stochastic=True,
                                 activation=activation)
        decoder = Decoder_1Konny([img_height, img_width, 3],
                                 activation=activation,
                                 output_activation=tf.nn.sigmoid)
        disc_z = DiscriminatorZ_1Konny(num_outputs=2)
    else:
        raise ValueError("Do not support encoder/decoder model '{}'!".format(
            args.enc_dec_model))

    model = FactorVAE([img_height, img_width, 3], args.z_dim,
                      encoder=encoder, decoder=decoder,
                      discriminator_z=disc_z,
                      rec_x_mode=args.rec_x_mode,
                      use_gp0_z_tc=True,
                      gp0_z_tc_mode=args.gp0_z_tc_mode)

    loss_coeff_dict = {
        'rec_x': args.rec_x_coeff,
        'kld_loss': args.kld_loss_coeff,
        'tc_loss': args.tc_loss_coeff,
        'gp0_z_tc': args.gp0_z_tc_coeff,
    }
    model.build(loss_coeff_dict)
    SimpleParamPrinter.print_all_params_tf_slim()

    # =====================================
    # Load model
    # =====================================
    config_proto = tf.ConfigProto(allow_soft_placement=True)
    config_proto.gpu_options.allow_growth = True
    config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9
    sess = tf.Session(config=config_proto)

    model_dir = make_dir_if_not_exist(join(args.output_dir, "model_tf"))
    train_helper = SimpleTrainHelper(log_dir=None, save_dir=model_dir)

    # Load model
    train_helper.load(sess, load_step=args.load_step)

    # =====================================
    # Experiments
    # save_dir = remove_dir_if_exist(join(args.save_dir, "FactorVAE_{}".format(args.run)), ask_4_permission=False)
    # save_dir = make_dir_if_not_exist(save_dir)
    save_dir = make_dir_if_not_exist(
        join(args.save_dir, "FactorVAE_{}".format(args.run)))
    # =====================================

    np.set_printoptions(threshold=np.inf, linewidth=1000, precision=3,
                        suppress=True)

    num_bins = args.num_bins
    bin_limits = tuple([float(s) for s in args.bin_limits.split(";")])
    data_proportion = args.data_proportion
    num_data = int(data_proportion * celebA_loader.num_train_data)
    top_k = args.top_k
    eps = 1e-8

    # Log file
    f = open(join(save_dir, 'log[bins={},bin_limits={},data={}].txt'.format(
        num_bins, bin_limits, data_proportion)), mode='w')

    # Print function
    print_ = functools.partial(print_both, file=f)

    print_("num_bins: {}".format(num_bins))
    print_("bin_limits: {}".format(bin_limits))
    print_("data_proportion: {}".format(data_proportion))
    print_("top_k: {}".format(top_k))

    # Compute bins
    # ================================= #
    print_("")
    print_("bin_limits: {}".format(bin_limits))
    assert len(bin_limits) == 2 and bin_limits[0] < bin_limits[1], \
        "bin_limits={}".format(bin_limits)

    bins = np.linspace(bin_limits[0], bin_limits[1], num_bins + 1, endpoint=True)
    print_("bins: {}".format(bins))
    assert len(bins) == num_bins + 1

    bin_widths = [bins[b] - bins[b - 1] for b in range(1, len(bins))]
    print_("bin_widths: {}".format(bin_widths))
    assert len(bin_widths) == num_bins, \
        "len(bin_widths)={} while num_bins={}!".format(len(bin_widths), num_bins)
    assert np.all(np.greater(bin_widths, 0)), "bin_widths: {}".format(bin_widths)

    bin_centers = [(bins[b] + bins[b - 1]) * 0.5 for b in range(1, len(bins))]
    print_("bin_centers: {}".format(bin_centers))
    assert len(bin_centers) == num_bins, \
        "len(bin_centers)={} while num_bins={}!".format(len(bin_centers), num_bins)
    # ================================= #

    # Load precomputed representations
    # ================================= #
    z_data_file = join(args.informativeness_metrics_dir,
                       "FactorVAE_{}".format(args.run),
                       "z_data[data={}].npz".format(data_proportion))

    with np.load(z_data_file, "r") as data:
        all_z_mean = data['all_z_mean']
        all_z_stddev = data['all_z_stddev']

    print_("")
    print_("all_z_mean.shape: {}".format(all_z_mean.shape))
    print_("all_z_stddev.shape: {}".format(all_z_stddev.shape))
    # ================================= #

    # Load the precomputed MI(zi, x) results
    # ================================= #
    mi_file = join(args.informativeness_metrics_dir,
                   "FactorVAE_{}".format(args.run),
                   'results[bins={},bin_limits={},data={}].npz'.format(
                       num_bins, bin_limits, data_proportion))

    with np.load(mi_file, "r") as data:
        sorted_MI_z_x = data['sorted_MI_z_x']
        sorted_z_ids = data['sorted_z_comps']
        H_z = data['H_z']

    if top_k > 0:
        top_MI = sorted_MI_z_x[:top_k]
        top_z_ids = sorted_z_ids[:top_k]
        bot_MI = sorted_MI_z_x[-top_k:]
        bot_z_ids = sorted_z_ids[-top_k:]

        top_bot_MI = np.concatenate([top_MI, bot_MI], axis=0)
        top_bot_z_ids = np.concatenate([top_z_ids, bot_z_ids], axis=0)

        print_("top MI: {}".format(top_MI))
        print_("top_z_ids: {}".format(top_z_ids))
        print_("bot MI: {}".format(bot_MI))
        print_("bot_z_ids: {}".format(bot_z_ids))
    else:
        top_bot_MI = sorted_MI_z_x
        top_bot_z_ids = sorted_z_ids
    # ================================= #

    H_z1z2_mean_mat = np.full([len(top_bot_z_ids), len(top_bot_z_ids)], -1,
                              dtype=np.float32)
    MI_z1z2_mean_mat = np.full([len(top_bot_z_ids), len(top_bot_z_ids)], -1,
                               dtype=np.float32)
    H_z1z2_mean = []
    MI_z1z2_mean = []
    z1z2_ids = []

    # Compute pairwise MI between latent means
    # ================================= #
    for i in range(len(top_bot_z_ids)):
        z_idx1 = top_bot_z_ids[i]
        H_s1 = H_z[z_idx1]

        for j in range(i + 1, len(top_bot_z_ids)):
            z_idx2 = top_bot_z_ids[j]
            H_s2 = H_z[z_idx2]

            print_("")
            print_("Compute MI(z{}_mean, z{}_mean)!".format(z_idx1, z_idx2))

            # Joint histogram of the two discretized latent means
            s1s2_mean_counter = np.zeros([num_bins, num_bins], dtype=np.int32)
            for batch_ids in iterate_data(len(all_z_mean), 100, shuffle=False,
                                          include_remaining=True):
                s1 = at_bin(all_z_mean[batch_ids, z_idx1], bins, one_hot=False)
                s2 = at_bin(all_z_mean[batch_ids, z_idx2], bins, one_hot=False)
                for s1_, s2_ in zip(s1, s2):
                    s1s2_mean_counter[s1_, s2_] += 1

            # I(s1, s2) = H(s1) + H(s2) - H(s1, s2)
            # ---------------------------------- #
            Q_s1s2_mean = (s1s2_mean_counter * 1.0) / \
                np.sum(s1s2_mean_counter).astype(np.float32)
            log_Q_s1s2_mean = np.log(np.maximum(Q_s1s2_mean, eps))
            H_s1s2_mean = -np.sum(Q_s1s2_mean * log_Q_s1s2_mean)
            MI_s1s2_mean = H_s1 + H_s2 - H_s1s2_mean

            print_("H_s1: {}".format(H_s1))
            print_("H_s2: {}".format(H_s2))
            print_("H_s1s2_mean: {}".format(H_s1s2_mean))
            print_("MI_s1s2_mean: {}".format(MI_s1s2_mean))

            H_z1z2_mean.append(H_s1s2_mean)
            MI_z1z2_mean.append(MI_s1s2_mean)
            z1z2_ids.append((z_idx1, z_idx2))

            H_z1z2_mean_mat[i, j] = H_s1s2_mean
            H_z1z2_mean_mat[j, i] = H_s1s2_mean
            MI_z1z2_mean_mat[i, j] = MI_s1s2_mean
            MI_z1z2_mean_mat[j, i] = MI_s1s2_mean

    H_z1z2_mean = np.asarray(H_z1z2_mean, dtype=np.float32)
    MI_z1z2_mean = np.asarray(MI_z1z2_mean, dtype=np.float32)
    z1z2_ids = np.asarray(z1z2_ids, dtype=np.int32)

    result_file = join(
        save_dir, "results[bins={},bin_limits={},data={},k={}].npz".format(
            num_bins, bin_limits, data_proportion, top_k))
    results = {
        'H_z1z2_mean': H_z1z2_mean,
        'MI_z1z2_mean': MI_z1z2_mean,
        'H_z1z2_mean_mat': H_z1z2_mean_mat,
        'MI_z1z2_mean_mat': MI_z1z2_mean_mat,
        'z1z2_ids': z1z2_ids,
    }
    np.savez_compressed(result_file, **results)
    # ================================= #

    f.close()
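# --------------------------------------------------------------------------- #
# A toy check of the joint-histogram estimator above (self-contained; names
# are local to the example). For two perfectly dependent discretized
# variables, I(s1, s2) should approach H(s1).
import numpy as np

eps = 1e-8
rng = np.random.RandomState(0)
z = rng.randn(10000)
bins = np.linspace(-4, 4, 21)
s1 = np.clip(np.digitize(z, bins) - 1, 0, len(bins) - 2)
s2 = s1  # fully dependent copy

joint = np.zeros([len(bins) - 1, len(bins) - 1])
for a, b in zip(s1, s2):
    joint[a, b] += 1
Q12 = joint / joint.sum()
Q1, Q2 = Q12.sum(axis=1), Q12.sum(axis=0)
H = lambda p: -np.sum(p * np.log(np.maximum(p, eps)))
print("I(s1,s2) = {:.3f}, H(s1) = {:.3f}".format(H(Q1) + H(Q2) - H(Q12), H(Q1)))
# --------------------------------------------------------------------------- #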
def cond_2_latents_traverse(self, save_file, sess, x,
                            z_comp1, span1, num_points_one_side1,
                            z_comp2, span2, num_points_one_side2,
                            batch_size=20, hl_color="red", hl_width=1,
                            dec_output_2_img_func=None,
                            enc_kwargs={}, dec_kwargs={}):
    """
    x: a SINGLE input image that we condition on.
    z_comp1: an integer specifying the first z component we want to plot.
    span1: the distance from the encoded value of z[z_comp1] to traverse
        on each side.
    num_points_one_side1: the number of interpolation points on each side
        (there are 2 sides, with the conditional input x in the middle).
    z_comp2, span2, num_points_one_side2: the same, for the second z component.
    """
    assert np.shape(x) == tuple(self.x_shape), "'x' must be a single instance!"

    # (1, x_dim)
    x_ = np.expand_dims(x, axis=0)

    # Compute z
    # ----------------------------- #
    # (1, z_dim)
    z = self.encode(sess, x_, **enc_kwargs)
    assert z.shape[0] == 1
    # (z_dim, )
    z = np.reshape(z, [int(np.prod(self.z_shape))])
    # ----------------------------- #

    # Compute z meshgrid
    # ----------------------------- #
    # (num_rows * num_cols, 2): grid over the two traversed components
    z12_meshgrid, center_idx = get_meshgrid(
        (z[z_comp1], z[z_comp2]), (span1, span2),
        (num_points_one_side1, num_points_one_side2),
        return_center_idx=True)

    num_rows = 2 * num_points_one_side1 + 1
    num_cols = 2 * num_points_one_side2 + 1
    assert len(z12_meshgrid) == num_rows * num_cols

    z_meshgrid = np.tile(np.expand_dims(z, axis=0), [num_rows * num_cols, 1])
    for i in range(num_rows * num_cols):
        z_meshgrid[i, z_comp1] = z12_meshgrid[i, 0]
        z_meshgrid[i, z_comp2] = z12_meshgrid[i, 1]
    z_meshgrid = np.reshape(z_meshgrid, [num_rows * num_cols] + self.z_shape)
    # ----------------------------- #

    # Reconstruct x meshgrid
    # ----------------------------- #
    if batch_size < 0:
        x_meshgrid = self.decode(sess, z_meshgrid, **dec_kwargs)
    else:
        x_meshgrid = []
        for batch_ids in iterate_data(len(z_meshgrid), batch_size,
                                      shuffle=False):
            x_meshgrid.append(
                self.decode(sess, z_meshgrid[batch_ids], **dec_kwargs))
        x_meshgrid = np.concatenate(x_meshgrid, axis=0)

    # Put the original image at the center of the grid
    x_meshgrid[center_idx] = x
    x_meshgrid = np.reshape(x_meshgrid, [num_rows, num_cols] + self.x_shape)

    if dec_output_2_img_func is not None:
        x_meshgrid = dec_output_2_img_func(x_meshgrid)

    # (row, col) of the center cell; integer division is required here,
    # since '/' would give a float index under Python 3
    center_block_idx = (center_idx // num_cols, center_idx % num_cols)
    save_img_block_highlighted(save_file, x_meshgrid, [center_block_idx],
                               hl_color=hl_color, hl_width=hl_width)
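# --------------------------------------------------------------------------- #
# A hedged usage sketch for the traversal above; `model`, `sess`, `x_train`
# and `binary_float_to_uint8` are assumed from the surrounding scripts, and
# the component indices/spans are illustrative:
#
#     model.cond_2_latents_traverse(
#         "z_traverse[3,7].png", sess, x_train[0],
#         z_comp1=3, span1=3.0, num_points_one_side1=4,
#         z_comp2=7, span2=3.0, num_points_one_side2=4,
#         batch_size=20, dec_output_2_img_func=binary_float_to_uint8)
# --------------------------------------------------------------------------- #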
def main(args):
    # =====================================
    # Load config
    # =====================================
    with open(join(args.output_dir, 'config.json')) as f:
        config = json.load(f)
    args.__dict__.update(config)

    # =====================================
    # Dataset
    # =====================================
    celebA_loader = TFCelebAWithAttrLoader(root_dir=args.celebA_root_dir)

    img_height, img_width = args.celebA_resize_size, args.celebA_resize_size
    celebA_loader.build_transformation_flow_tf(
        *celebA_loader.get_transform_fns("1Konny",
                                         resize_size=args.celebA_resize_size))

    # =====================================
    # Instantiate model
    # =====================================
    if args.activation == "relu":
        activation = tf.nn.relu
    elif args.activation == "leaky_relu":
        activation = tf.nn.leaky_relu
    else:
        raise ValueError("Do not support '{}' activation!".format(
            args.activation))

    if args.enc_dec_model == "1Konny":
        # assert args.z_dim == 65, "For 1Konny, z_dim must be 65. Found {}!".format(args.z_dim)
        encoder = Encoder_1Konny(args.z_dim, stochastic=True,
                                 activation=activation)
        decoder = Decoder_1Konny([img_height, img_width, 3],
                                 activation=activation,
                                 output_activation=tf.nn.sigmoid)
        disc_z = DiscriminatorZ_1Konny()
    else:
        raise ValueError("Do not support encoder/decoder model '{}'!".format(
            args.enc_dec_model))

    model = AAE([img_height, img_width, 3], args.z_dim,
                encoder=encoder, decoder=decoder,
                discriminator_z=disc_z,
                rec_x_mode=args.rec_x_mode,
                stochastic_z=args.stochastic_z,
                use_gp0_z=True, gp0_z_mode=args.gp0_z_mode)

    loss_coeff_dict = {
        'rec_x': args.rec_x_coeff,
        'G_loss_z1_gen': args.G_loss_z1_gen_coeff,
        'D_loss_z1_gen': args.D_loss_z1_gen_coeff,
        'gp0_z': args.gp0_z_coeff,
    }
    model.build(loss_coeff_dict)
    SimpleParamPrinter.print_all_params_tf_slim()

    # =====================================
    # Load model
    # =====================================
    config_proto = tf.ConfigProto(allow_soft_placement=True)
    config_proto.gpu_options.allow_growth = True
    config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9
    sess = tf.Session(config=config_proto)

    model_dir = make_dir_if_not_exist(join(args.output_dir, "model_tf"))
    train_helper = SimpleTrainHelper(log_dir=None, save_dir=model_dir)

    # Load model
    train_helper.load(sess, load_step=args.load_step)

    # =====================================
    # Experiments
    # save_dir = remove_dir_if_exist(join(args.save_dir, "AAE_{}".format(args.run)), ask_4_permission=False)
    # save_dir = make_dir_if_not_exist(save_dir)
    save_dir = make_dir_if_not_exist(
        join(args.save_dir, "AAE_{}".format(args.run)))
    # =====================================

    np.set_printoptions(threshold=np.inf, linewidth=1000, precision=3,
                        suppress=True)

    num_bins = args.num_bins
    bin_limits = tuple([float(s) for s in args.bin_limits.split(";")])
    data_proportion = args.data_proportion
    num_data = int(data_proportion * celebA_loader.num_train_data)
    eps = 1e-8

    # Log file
    f = open(join(save_dir, 'log[bins={},bin_limits={},data={}].txt'.format(
        num_bins, bin_limits, data_proportion)), mode='w')

    # Print function
    print_ = functools.partial(print_both, file=f)

    '''
    if attr_type == 0:
        attr_names = celebA_loader.attributes
    elif attr_type == 1:
        attr_names = ['Male', 'Black_Hair', 'Blond_Hair', 'Straight_Hair',
                      'Wavy_Hair', 'Bald', 'Oval_Face', 'Big_Nose', 'Chubby',
                      'Double_Chin', 'Goatee', 'No_Beard',
                      'Mouth_Slightly_Open', 'Smiling', 'Eyeglasses',
                      'Pale_Skin']
    else:
        raise ValueError("Only support factor_type=0 or 1!")
    '''

    print_("num_bins: {}".format(num_bins))
    print_("bin_limits: {}".format(bin_limits))
    print_("data_proportion: {}".format(data_proportion))

    # Compute bins
    # ================================= #
    print_("")
    print_("bin_limits: {}".format(bin_limits))
    assert len(bin_limits) == 2 and bin_limits[0] < bin_limits[1], \
        "bin_limits={}".format(bin_limits)

    bins = np.linspace(bin_limits[0], bin_limits[1], num_bins + 1, endpoint=True)
    print_("bins: {}".format(bins))
    assert len(bins) == num_bins + 1

    bin_widths = [bins[b] - bins[b - 1] for b in range(1, len(bins))]
    print_("bin_widths: {}".format(bin_widths))
    assert len(bin_widths) == num_bins, \
        "len(bin_widths)={} while num_bins={}!".format(len(bin_widths), num_bins)
    assert np.all(np.greater(bin_widths, 0)), "bin_widths: {}".format(bin_widths)

    bin_centers = [(bins[b] + bins[b - 1]) * 0.5 for b in range(1, len(bins))]
    print_("bin_centers: {}".format(bin_centers))
    assert len(bin_centers) == num_bins, \
        "len(bin_centers)={} while num_bins={}!".format(len(bin_centers), num_bins)
    # ================================= #

    # Compute representations
    # ================================= #
    z_data_attr_file = join(save_dir,
                            "z_data[data={}].npz".format(data_proportion))

    if not exists(z_data_attr_file):
        all_z_mean = []
        all_z_stddev = []
        all_attrs = []

        print("")
        print("Compute all_z_mean, all_z_stddev and all_attrs!")
        count = 0
        for batch_ids in iterate_data(num_data, 10 * args.batch_size,
                                      shuffle=False):
            x = celebA_loader.sample_images_from_dataset(sess, 'train',
                                                         batch_ids)
            attrs = celebA_loader.sample_attrs_from_dataset('train', batch_ids)
            assert attrs.shape[1] == celebA_loader.num_attributes

            z_mean, z_stddev = sess.run(
                model.get_output(['z_mean', 'z_stddev']),
                feed_dict={model.is_train: False, model.x_ph: x})

            all_z_mean.append(z_mean)
            all_z_stddev.append(z_stddev)
            all_attrs.append(attrs)

            count += len(batch_ids)
            print("\rProcessed {} samples!".format(count), end="")
        print()

        all_z_mean = np.concatenate(all_z_mean, axis=0)
        all_z_stddev = np.concatenate(all_z_stddev, axis=0)
        all_attrs = np.concatenate(all_attrs, axis=0)

        np.savez_compressed(z_data_attr_file, all_z_mean=all_z_mean,
                            all_z_stddev=all_z_stddev, all_attrs=all_attrs)
    else:
        print("{} exists. Load data from file!".format(z_data_attr_file))
        with np.load(z_data_attr_file, "r") as data:
            all_z_mean = data['all_z_mean']
            all_z_stddev = data['all_z_stddev']
            all_attrs = data['all_attrs']

    print_("")
    print_("all_z_mean.shape: {}".format(all_z_mean.shape))
    print_("all_z_stddev.shape: {}".format(all_z_stddev.shape))
    print_("all_attrs.shape: {}".format(all_attrs.shape))
    # ================================= #

    # Compute the probability mass function for ground truth factors
    # ================================= #
    num_attrs = all_attrs.shape[1]
    assert all_attrs.dtype == np.bool_
    all_attrs = all_attrs.astype(np.int32)

    # (num_samples, num_attrs, 2)
    # Channel 0 indicates the attribute is present, channel 1 that it is absent
    all_Q_y_cond_x = np.stack([all_attrs, 1 - all_attrs], axis=-1)
    # ================================= #

    # Compute Q(zi|x)
    # Compute I(zi, yk)
    # ================================= #
    Q_z_y = np.zeros([args.z_dim, num_attrs, num_bins, 2], dtype=np.float32)
    MI_z_y = np.zeros([args.z_dim, num_attrs], dtype=np.float32)
    H_z_y = np.zeros([args.z_dim, num_attrs], dtype=np.float32)
    H_z_4_diff_y = np.zeros([args.z_dim, num_attrs], dtype=np.float32)
    H_y_4_diff_z = np.zeros([num_attrs, args.z_dim], dtype=np.float32)

    for i in range(args.z_dim):
        print_("")
        print_("Compute all_Q_z{}_cond_x!".format(i))

        # Q(s|x)
        all_Q_s_cond_x = []
        for batch_ids in iterate_data(len(all_z_mean), 500, shuffle=False,
                                      include_remaining=True):
            # (batch_size, num_bins)
            q_s_cond_x = normal_density(
                np.expand_dims(bin_centers, axis=0),
                mean=np.expand_dims(all_z_mean[batch_ids, i], axis=-1),
                stddev=np.expand_dims(all_z_stddev[batch_ids, i], axis=-1))

            # (batch_size,)
            max_q_s_cond_x = np.max(q_s_cond_x, axis=-1)

            # (batch_size, num_bins)
            deter_s_cond_x = at_bin(all_z_mean[batch_ids, i],
                                    bins).astype(np.float32)

            # (batch_size, num_bins)
            Q_s_cond_x = q_s_cond_x * np.expand_dims(bin_widths, axis=0)
            Q_s_cond_x = Q_s_cond_x / np.maximum(
                np.sum(Q_s_cond_x, axis=1, keepdims=True), eps)
            Q_s_cond_x = np.where(
                np.expand_dims(np.less(max_q_s_cond_x, 1e-5), axis=-1),
                deter_s_cond_x, Q_s_cond_x)

            all_Q_s_cond_x.append(Q_s_cond_x)

        # (num_samples, num_bins)
        all_Q_s_cond_x = np.concatenate(all_Q_s_cond_x, axis=0)
        assert np.all(all_Q_s_cond_x >= 0), \
            "'all_Q_s_cond_x' contains negative values. " \
            "sorted_all_Q_s_cond_x[:30]:\n{}!".format(
                np.sort(all_Q_s_cond_x[:30], axis=None))
        assert len(all_Q_s_cond_x) == len(all_attrs), \
            "all_Q_s_cond_x.shape={}, all_attrs.shape={}".format(
                all_Q_s_cond_x.shape, all_attrs.shape)

        # I(z, y)
        for k in range(num_attrs):
            # Compute Q(zi, yk)
            # -------------------------------- #
            # (num_bins, 2)
            Q_zi_yk = np.matmul(np.transpose(all_Q_s_cond_x, axes=[1, 0]),
                                all_Q_y_cond_x[:, k, :])
            Q_zi_yk = Q_zi_yk / len(all_Q_y_cond_x)
            Q_zi_yk = Q_zi_yk / np.maximum(np.sum(Q_zi_yk), eps)
            assert np.all(Q_zi_yk >= 0), \
                "'Q_zi_yk' contains negative values. " \
                "sorted_Q_zi_yk[:10]:\n{}!".format(np.sort(Q_zi_yk, axis=None))

            log_Q_zi_yk = np.log(np.clip(Q_zi_yk, eps, 1 - eps))
            Q_z_y[i, k] = Q_zi_yk
            print_("sum(Q_zi_yk): {}".format(np.sum(Q_zi_yk)))
            # -------------------------------- #

            # Compute Q_z
            # -------------------------------- #
            Q_zi = np.sum(Q_zi_yk, axis=1)
            log_Q_zi = np.log(np.clip(Q_zi, eps, 1 - eps))
            print_("sum(Q_z{}): {}".format(i, np.sum(Q_zi)))
            print_("Q_z{}: {}".format(i, Q_zi))
            # -------------------------------- #

            # Compute Q_y
            # -------------------------------- #
            Q_yk = np.sum(Q_zi_yk, axis=0)
            log_Q_yk = np.log(np.clip(Q_yk, eps, 1 - eps))
            print_("sum(Q_y{}): {}".format(k, np.sum(Q_yk)))
            print_("Q_y{}: {}".format(k, Q_yk))
            # -------------------------------- #

            MI_zi_yk = Q_zi_yk * (log_Q_zi_yk
                                  - np.expand_dims(log_Q_zi, axis=-1)
                                  - np.expand_dims(log_Q_yk, axis=0))
            MI_zi_yk = np.sum(MI_zi_yk)

            H_zi_yk = -np.sum(Q_zi_yk * log_Q_zi_yk)
            H_zi = -np.sum(Q_zi * log_Q_zi)
            H_yk = -np.sum(Q_yk * log_Q_yk)

            MI_z_y[i, k] = MI_zi_yk
            H_z_y[i, k] = H_zi_yk
            H_z_4_diff_y[i, k] = H_zi
            H_y_4_diff_z[k, i] = H_yk
    # ================================= #

    print_("")
    print_("MI_z_y:\n{}".format(MI_z_y))
    print_("H_z_y:\n{}".format(H_z_y))
    print_("H_z_4_diff_y:\n{}".format(H_z_4_diff_y))
    print_("H_y_4_diff_z:\n{}".format(H_y_4_diff_z))

    # Compute metric
    # ================================= #
    # Sorted in decreasing order
    MI_ids_sorted = np.argsort(MI_z_y, axis=0)[::-1]
    MI_sorted = np.take_along_axis(MI_z_y, MI_ids_sorted, axis=0)

    MI_gap_y = np.divide(MI_sorted[0, :] - MI_sorted[1, :],
                         H_y_4_diff_z[:, 0])
    MIG = np.mean(MI_gap_y)

    print_("")
    print_("MI_sorted: {}".format(MI_sorted))
    print_("MI_ids_sorted: {}".format(MI_ids_sorted))
    print_("MI_gap_y: {}".format(MI_gap_y))
    print_("MIG: {}".format(MIG))

    results = {
        'Q_z_y': Q_z_y,
        'MI_z_y': MI_z_y,
        'H_z_y': H_z_y,
        'H_z_4_diff_y': H_z_4_diff_y,
        'H_y_4_diff_z': H_y_4_diff_z,
        'MI_sorted': MI_sorted,
        'MI_ids_sorted': MI_ids_sorted,
        'MI_gap_y': MI_gap_y,
        'MIG': MIG,
    }
    result_file = join(
        save_dir, 'results[bins={},bin_limits={},data={}].npz'.format(
            num_bins, bin_limits, data_proportion))
    np.savez_compressed(result_file, **results)
    # ================================= #

    f.close()
def main(args):
    # =====================================
    # Load config
    # =====================================
    with open(os.path.join(args.output_dir, 'config.json')) as f:
        config = json.load(f)
    args.__dict__.update(config)

    # =====================================
    # Preparation
    # =====================================
    data_file = os.path.join(RAW_DATA_DIR, "ComputerVision", "dSprites",
                             "dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz")

    # It is already in the range [0, 1]
    with np.load(data_file, encoding="latin1") as f:
        x_train = f['imgs']

    x_train = np.expand_dims(x_train.astype(np.float32), axis=-1)

    # =====================================
    # Instantiate models
    # =====================================
    if args.enc_dec_model == "1Konny":
        encoder = Encoder_1Konny(args.z_dim, stochastic=True)
        decoder = Decoder_1Konny()
        disc_z = DiscriminatorZ_1Konny()
    else:
        raise ValueError("Do not support enc_dec_model='{}'!".format(
            args.enc_dec_model))

    model = AAE([64, 64, 1], args.z_dim,
                encoder=encoder, decoder=decoder,
                discriminator_z=disc_z,
                rec_x_mode=args.rec_x_mode,
                stochastic_z=args.stochastic_z,
                use_gp0_z=True, gp0_z_mode=args.gp0_z_mode)

    loss_coeff_dict = {
        'rec_x': args.rec_x_coeff,
        'G_loss_z1_gen': args.G_loss_z1_gen_coeff,
        'D_loss_z1_gen': args.D_loss_z1_gen_coeff,
        'gp0_z': args.gp0_z_coeff,
    }
    model.build(loss_coeff_dict)
    SimpleParamPrinter.print_all_params_list()

    # =====================================
    # Output directories and TF graph handler
    asset_dir = make_dir_if_not_exist(os.path.join(args.output_dir, "asset"))
    img_eval = remove_dir_if_exist(os.path.join(asset_dir, "img_eval"),
                                   ask_4_permission=False)
    img_eval = make_dir_if_not_exist(img_eval)

    img_x_rec = make_dir_if_not_exist(os.path.join(img_eval, "x_rec"))
    img_z_rand_2_traversal = make_dir_if_not_exist(
        os.path.join(img_eval, "z_rand_2_traversal"))
    img_z_cond_all_traversal = make_dir_if_not_exist(
        os.path.join(img_eval, "z_cond_all_traversal"))
    img_z_cond_1_traversal = make_dir_if_not_exist(
        os.path.join(img_eval, "z_cond_1_traversal"))
    img_z_corr = make_dir_if_not_exist(os.path.join(img_eval, "z_corr"))
    img_z_dist = make_dir_if_not_exist(os.path.join(img_eval, "z_dist"))
    img_z_stat_dist = make_dir_if_not_exist(
        os.path.join(img_eval, "z_stat_dist"))
    img_rec_error_dist = make_dir_if_not_exist(
        os.path.join(img_eval, "rec_error_dist"))

    model_dir = make_dir_if_not_exist(os.path.join(args.output_dir, "model_tf"))
    train_helper = SimpleTrainHelper(log_dir=None, save_dir=model_dir)
    # =====================================

    # =====================================
    # Evaluation
    # =====================================
    config_proto = tf.ConfigProto(allow_soft_placement=True)
    config_proto.gpu_options.allow_growth = True
    config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9
    sess = tf.Session(config=config_proto)

    # Load model
    train_helper.load(sess, load_step=args.load_step)

    #'''
    # Reconstruction
    # ======================================= #
    seed = 389
    x = x_train[np.arange(seed, seed + 64)]
    img_file = os.path.join(img_x_rec, 'x_rec_train.png')
    model.reconstruct_images(img_file, sess, x, block_shape=[8, 8],
                             batch_size=-1,
                             dec_output_2_img_func=binary_float_to_uint8)
    # ======================================= #

    # z random/conditional traversal
    # ======================================= #
    # Plot z cont with z cont
    z_zero = np.zeros([args.z_dim], dtype=np.float32)
    z_rand = np.random.randn(args.z_dim)
    z_start, z_stop = -4, 4
    num_points = 8

    for i in range(args.z_dim):
        for j in range(i + 1, args.z_dim):
            print("Plot random 2 comps z traverse with {} and {} "
                  "components!".format(i, j))

            img_file = os.path.join(img_z_rand_2_traversal,
                                    'z[{},{},zero].png'.format(i, j))
            model.rand_2_latents_traverse(
                img_file, sess, default_z=z_zero,
                z_comp1=i, start1=z_start, stop1=z_stop,
                num_points1=num_points,
                z_comp2=j, start2=z_start, stop2=z_stop,
                num_points2=num_points,
                batch_size=args.batch_size,
                dec_output_2_img_func=binary_float_to_uint8)

            img_file = os.path.join(img_z_rand_2_traversal,
                                    'z[{},{},rand].png'.format(i, j))
            model.rand_2_latents_traverse(
                img_file, sess, default_z=z_rand,
                z_comp1=i, start1=z_start, stop1=z_stop,
                num_points1=num_points,
                z_comp2=j, start2=z_start, stop2=z_stop,
                num_points2=num_points,
                batch_size=args.batch_size,
                dec_output_2_img_func=binary_float_to_uint8)

    seed = 389
    z_start, z_stop = -4, 4
    num_itpl_points = 8
    for n in range(seed, seed + 30):
        print("Plot conditional all comps z traverse with test sample "
              "{}!".format(n))
        x = x_train[n]
        img_file = os.path.join(img_z_cond_all_traversal,
                                'x_train{}.png'.format(n))
        model.cond_all_latents_traverse(
            img_file, sess, x,
            start=z_start, stop=z_stop,
            num_itpl_points=num_itpl_points,
            batch_size=args.batch_size,
            dec_output_2_img_func=binary_float_to_uint8)

    seed = 64
    z_start, z_stop = -4, 4
    num_itpl_points = 8
    print("Plot conditional 1 comp z traverse!")
    for i in range(args.z_dim):
        x = x_train[seed:seed + 64]
        img_file = os.path.join(
            img_z_cond_1_traversal,
            'x_train[{},{}]_z{}.png'.format(seed, seed + 64, i))
        model.cond_1_latent_traverse(
            img_file, sess, x, z_comp=i,
            start=z_start, stop=z_stop,
            num_itpl_points=num_itpl_points,
            batch_size=args.batch_size,
            dec_output_2_img_func=binary_float_to_uint8)
    # ======================================= #
    # '''

    # z correlation matrix
    # ======================================= #
    data = x_train
    all_z = []
    for batch_ids in iterate_data(len(data), args.batch_size, shuffle=False):
        x = data[batch_ids]
        z = model.encode(sess, x)
        assert len(z.shape) == 2 and z.shape[1] == args.z_dim, \
            "z.shape: {}".format(z.shape)
        all_z.append(z)
    all_z = np.concatenate(all_z, axis=0)

    print("Start plotting!")
    plot_corrmat_with_histogram(os.path.join(img_z_corr, "corr_mat.png"),
                                all_z)
    plot_comp_dist(os.path.join(img_z_dist, 'z_{}'), all_z, x_lim=(-5, 5))
    print("Done!")
    # ======================================= #

    # '''
    # z gaussian stddev
    # ======================================= #
    print("\nPlot z mean and stddev!")
    data = x_train
    all_z_mean = []
    all_z_stddev = []
    for batch_ids in iterate_data(len(data), args.batch_size, shuffle=False):
        x = data[batch_ids]
        z_mean, z_stddev = sess.run(
            model.get_output(['z_mean', 'z_stddev']),
            feed_dict={model.is_train: False, model.x_ph: x})
        all_z_mean.append(z_mean)
        all_z_stddev.append(z_stddev)

    all_z_mean = np.concatenate(all_z_mean, axis=0)
    all_z_stddev = np.concatenate(all_z_stddev, axis=0)

    plot_comp_dist(os.path.join(img_z_stat_dist, 'z_mean_{}.png'),
                   all_z_mean, x_lim=(-5, 5))
    plot_comp_dist(os.path.join(img_z_stat_dist, 'z_stddev_{}.png'),
                   all_z_stddev, x_lim=(0, 3))
    # ======================================= #

    # '''
    # Decoder sensitivity
    # ======================================= #
    z_start = -3
    z_stop = 3
    for i in range(args.z_dim):
        print("\nPlot rec error distribution for z component {}!".format(i))

        all_z1 = np.array(all_z, copy=True, dtype=np.float32)
        all_z2 = np.array(all_z, copy=True, dtype=np.float32)
        all_z1[:, i] = z_start
        all_z2[:, i] = z_stop

        all_x_rec1 = []
        all_x_rec2 = []
        for batch_ids in iterate_data(len(x_train), args.batch_size,
                                      shuffle=False):
            z1 = all_z1[batch_ids]
            z2 = all_z2[batch_ids]
            x1 = model.decode(sess, z1)
            x2 = model.decode(sess, z2)
            all_x_rec1.append(x1)
            all_x_rec2.append(x2)

        all_x_rec1 = np.concatenate(all_x_rec1, axis=0)
        all_x_rec2 = np.concatenate(all_x_rec2, axis=0)

        # dSprites images are 64x64: flatten each squared-difference image
        # before summing per sample
        rec_errors = np.sum(np.reshape((all_x_rec1 - all_x_rec2) ** 2,
                                       [len(x_train), 64 * 64]), axis=1)

        plot_comp_dist(
            os.path.join(
                img_rec_error_dist,
                'rec_error[zi={},{},{}].png'.format(i, z_start, z_stop)),
            rec_errors)
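# --------------------------------------------------------------------------- #
# A small self-contained sketch of the decoder-sensitivity measure used
# above: the per-sample squared difference between decodings at the two
# extremes of one latent (synthetic stand-in arrays; names local to the
# example).
import numpy as np

x_lo = np.random.rand(10, 64, 64, 1)   # stand-in for decode(z with z_i = -3)
x_hi = np.random.rand(10, 64, 64, 1)   # stand-in for decode(z with z_i = +3)
rec_errors = np.sum(np.reshape((x_lo - x_hi) ** 2, [len(x_lo), 64 * 64]),
                    axis=1)
print(rec_errors.shape)  # (10,) -- one sensitivity score per sample
# --------------------------------------------------------------------------- #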
def main(args):
    # =====================================
    # Load config
    # =====================================
    with open(join(args.output_dir, 'config.json')) as f:
        config = json.load(f)
    args.__dict__.update(config)

    # =====================================
    # Dataset
    # =====================================
    data_file = join(RAW_DATA_DIR, "ComputerVision", "dSprites",
                     "dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz")

    # It is already in the range [0, 1]
    with np.load(data_file, encoding="latin1") as f:
        x_train = f['imgs']
        # 3 shape * 6 scale * 40 rotation * 32 pos X * 32 pos Y
        y_train = f['latents_classes']

    x_train = np.expand_dims(x_train.astype(np.float32), axis=-1)
    num_train = len(x_train)
    print("num_train: {}".format(num_train))
    print("y_train[:10]: {}".format(y_train[:10]))

    # =====================================
    # Instantiate model
    # =====================================
    if args.enc_dec_model == "1Konny":
        encoder = Encoder_1Konny(args.z_dim, stochastic=True)
        decoder = Decoder_1Konny()
        disc_z = DiscriminatorZ_1Konny()
    else:
        raise ValueError("Do not support enc_dec_model='{}'!".format(
            args.enc_dec_model))

    model = AAE([64, 64, 1], args.z_dim,
                encoder=encoder, decoder=decoder,
                discriminator_z=disc_z,
                rec_x_mode=args.rec_x_mode,
                stochastic_z=args.stochastic_z,
                use_gp0_z=True, gp0_z_mode=args.gp0_z_mode)

    loss_coeff_dict = {
        'rec_x': args.rec_x_coeff,
        'G_loss_z1_gen': args.G_loss_z1_gen_coeff,
        'D_loss_z1_gen': args.D_loss_z1_gen_coeff,
        'gp0_z': args.gp0_z_coeff,
    }
    model.build(loss_coeff_dict)
    SimpleParamPrinter.print_all_params_tf_slim()

    # =====================================
    # Load model
    # =====================================
    config_proto = tf.ConfigProto(allow_soft_placement=True)
    config_proto.gpu_options.allow_growth = True
    config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9
    sess = tf.Session(config=config_proto)

    model_dir = make_dir_if_not_exist(join(args.output_dir, "model_tf"))
    train_helper = SimpleTrainHelper(log_dir=None, save_dir=model_dir)

    # Load model
    train_helper.load(sess, load_step=args.load_step)

    # =====================================
    # Experiments
    save_dir = make_dir_if_not_exist(
        join(args.save_dir, "{}_{}".format(args.enc_dec_model, args.run)))
    # =====================================

    np.set_printoptions(threshold=np.inf, linewidth=1000, precision=5,
                        suppress=True)

    num_bins = args.num_bins
    bin_limits = tuple([float(s) for s in args.bin_limits.split(";")])
    data_proportion = args.data_proportion
    num_data = int(data_proportion * num_train)
    assert num_data == num_train, "For dSprites, you must use all data!"
    eps = 1e-8

    # Log file
    f = open(join(save_dir, 'log[bins={},bin_limits={},data={}].txt'.format(
        num_bins, bin_limits, data_proportion)), mode='w')

    # Print function
    print_ = functools.partial(print_both, file=f)

    print_("num_bins: {}".format(num_bins))
    print_("bin_limits: {}".format(bin_limits))
    print_("data_proportion: {}".format(data_proportion))

    # Compute bins
    # ================================= #
    print_("")
    print_("bin_limits: {}".format(bin_limits))
    assert len(bin_limits) == 2 and bin_limits[0] < bin_limits[1], \
        "bin_limits={}".format(bin_limits)

    bins = np.linspace(bin_limits[0], bin_limits[1], num_bins + 1, endpoint=True)
    print_("bins: {}".format(bins))
    assert len(bins) == num_bins + 1

    bin_widths = [bins[b] - bins[b - 1] for b in range(1, len(bins))]
    print_("bin_widths: {}".format(bin_widths))
    assert len(bin_widths) == num_bins, \
        "len(bin_widths)={} while num_bins={}!".format(len(bin_widths), num_bins)
    assert np.all(np.greater(bin_widths, 0)), "bin_widths: {}".format(bin_widths)

    bin_centers = [(bins[b] + bins[b - 1]) * 0.5 for b in range(1, len(bins))]
    print_("bin_centers: {}".format(bin_centers))
    assert len(bin_centers) == num_bins, \
        "len(bin_centers)={} while num_bins={}!".format(len(bin_centers), num_bins)
    # ================================= #

    # Compute representations
    # ================================= #
    z_data_file = join(save_dir, "z_data[data={}].npz".format(data_proportion))

    if not exists(z_data_file):
        all_z_mean = []
        all_z_stddev = []

        print("")
        print("Compute all_z_mean and all_z_stddev!")
        count = 0
        for batch_ids in iterate_data(num_data, 10 * args.batch_size,
                                      shuffle=False):
            x = x_train[batch_ids]

            z_mean, z_stddev = sess.run(
                model.get_output(['z_mean', 'z_stddev']),
                feed_dict={model.is_train: False, model.x_ph: x})

            all_z_mean.append(z_mean)
            all_z_stddev.append(z_stddev)

            count += len(batch_ids)
            print("\rProcessed {} samples!".format(count), end="")
        print()

        all_z_mean = np.concatenate(all_z_mean, axis=0)
        all_z_stddev = np.concatenate(all_z_stddev, axis=0)

        np.savez_compressed(z_data_file, all_z_mean=all_z_mean,
                            all_z_stddev=all_z_stddev)
    else:
        print("{} exists. Load data from file!".format(z_data_file))
        with np.load(z_data_file, "r") as data:
            all_z_mean = data['all_z_mean']
            all_z_stddev = data['all_z_stddev']
    # ================================= #

    print_("")
    all_Q_z_cond_x = []
    for i in range(args.z_dim):
        print_("\nCompute all_Q_z{}_cond_x!".format(i))

        all_Q_s_cond_x = []
        for batch_ids in iterate_data(len(all_z_mean), 500, shuffle=False,
                                      include_remaining=True):
            # (batch_size, num_bins)
            q_s_cond_x = normal_density(
                np.expand_dims(bin_centers, axis=0),
                mean=np.expand_dims(all_z_mean[batch_ids, i], axis=-1),
                stddev=np.expand_dims(all_z_stddev[batch_ids, i], axis=-1))

            # (batch_size,)
            max_q_s_cond_x = np.max(q_s_cond_x, axis=-1)

            # (batch_size, num_bins)
            deter_s_cond_x = at_bin(all_z_mean[batch_ids, i],
                                    bins).astype(np.float32)

            # (batch_size, num_bins)
            Q_s_cond_x = q_s_cond_x * np.expand_dims(bin_widths, axis=0)
            Q_s_cond_x = Q_s_cond_x / np.maximum(
                np.sum(Q_s_cond_x, axis=1, keepdims=True), eps)
            Q_s_cond_x = np.where(
                np.expand_dims(np.less(max_q_s_cond_x, 1e-5), axis=-1),
                deter_s_cond_x, Q_s_cond_x)

            all_Q_s_cond_x.append(Q_s_cond_x)

        # (num_samples, num_bins)
        all_Q_s_cond_x = np.concatenate(all_Q_s_cond_x, axis=0)
        assert np.all(all_Q_s_cond_x >= 0), \
            "'all_Q_s_cond_x' contains negative values. " \
            "sorted_all_Q_s_cond_x[:30]:\n{}!".format(
                np.sort(all_Q_s_cond_x[:30], axis=None))
        assert len(all_Q_s_cond_x) == num_train
        all_Q_z_cond_x.append(all_Q_s_cond_x)

    # (z_dim, num_samples, num_bins)
    all_Q_z_cond_x = np.asarray(all_Q_z_cond_x, dtype=np.float32)
    print_("all_Q_z_cond_x.shape: {}".format(all_Q_z_cond_x.shape))
    print_("sum(all_Q_z_cond_x)[:, :10]:\n{}".format(
        np.sum(all_Q_z_cond_x, axis=-1)[:, :10]))

    # (z_dim, num_bins)
    Q_z = np.mean(all_Q_z_cond_x, axis=1)
    log_Q_z = np.log(np.clip(Q_z, eps, 1 - eps))
    print_("Q_z.shape: {}".format(Q_z.shape))
    print_("sum(Q_z): {}".format(np.sum(Q_z, axis=-1)))

    # (z_dim, )
    H_z = -np.sum(Q_z * log_Q_z, axis=-1)

    # Factors
    gt_factors = ['shape', 'scale', 'rotation', 'pos_x', 'pos_y']
    gt_num_values = [3, 6, 40, 32, 32]

    MI_z_y = np.zeros([args.z_dim, len(gt_factors)], dtype=np.float32)
    H_z_y = np.zeros([args.z_dim, len(gt_factors)], dtype=np.float32)
    ids_sorted = np.zeros([args.z_dim, len(gt_factors)], dtype=np.int32)
    MI_z_y_sorted = np.zeros([args.z_dim, len(gt_factors)], dtype=np.float32)
    H_z_y_sorted = np.zeros([args.z_dim, len(gt_factors)], dtype=np.float32)

    H_y = []
    RMIG = []
    JEMMI = []

    for k, (factor, num_values) in enumerate(zip(gt_factors, gt_num_values)):
        print_("\n#" + "=" * 50 + "#")
        print_("The {}-th gt factor '{}' has {} values!".format(
            k, factor, num_values))
        print_("")

        # (num_samples, num_categories)
        # NOTE: We must use k+1 to account for the 'color' attribute,
        # which is always white
        all_Q_yk_cond_x = one_hot(y_train[:, k + 1],
                                  num_categories=num_values,
                                  dtype=np.float32)
        print_("all_Q_yk_cond_x.shape: {}".format(all_Q_yk_cond_x.shape))

        # (num_categories, )
        Q_yk = np.mean(all_Q_yk_cond_x, axis=0)
        log_Q_yk = np.log(np.clip(Q_yk, eps, 1 - eps))
        print_("Q_yk.shape: {}".format(Q_yk.shape))

        H_yk = -np.sum(Q_yk * log_Q_yk)
        print_("H_yk: {}".format(H_yk))
        H_y.append(H_yk)

        Q_z_yk = np.zeros([args.z_dim, num_bins, num_values], dtype=np.float32)

        # Compute I(zi, yk)
        for i in range(args.z_dim):
            print_("\n#" + "-" * 50 + "#")
            all_Q_zi_cond_x = all_Q_z_cond_x[i]
            assert len(all_Q_zi_cond_x) == len(all_Q_yk_cond_x) == num_train, \
                "all_Q_zi_cond_x.shape: {}, all_Q_yk_cond_x.shape: {}".format(
                    all_Q_zi_cond_x.shape, all_Q_yk_cond_x.shape)

            # (num_bins, num_categories)
            Q_zi_yk = np.matmul(np.transpose(all_Q_zi_cond_x, axes=[1, 0]),
                                all_Q_yk_cond_x)
            Q_zi_yk = Q_zi_yk / num_train
            print_("np.sum(Q_zi_yk): {}".format(np.sum(Q_zi_yk)))
            Q_zi_yk = Q_zi_yk / np.maximum(np.sum(Q_zi_yk), eps)
            print_("np.sum(Q_zi_yk) (normalized): {}".format(np.sum(Q_zi_yk)))
            assert np.all(Q_zi_yk >= 0), \
                "'Q_zi_yk' contains negative values. " \
                "sorted_Q_zi_yk[:10]:\n{}!".format(np.sort(Q_zi_yk, axis=None))

            # (num_bins, num_categories)
            log_Q_zi_yk = np.log(np.clip(Q_zi_yk, eps, 1 - eps))

            print_("")
            print_("Q_zi (default): {}".format(Q_z[i]))
            print_("Q_zi (sum of Q_zi_yk over yk): {}".format(
                np.sum(Q_zi_yk, axis=-1)))
            print_("")
            print_("Q_yk (default): {}".format(Q_yk))
            print_("Q_yk (sum of Q_zi_yk over zi): {}".format(
                np.sum(Q_zi_yk, axis=0)))

            MI_zi_yk = Q_zi_yk * (log_Q_zi_yk
                                  - np.expand_dims(log_Q_z[i], axis=-1)
                                  - np.expand_dims(log_Q_yk, axis=0))
            MI_zi_yk = np.sum(MI_zi_yk)
            H_zi_yk = -np.sum(Q_zi_yk * log_Q_zi_yk)

            Q_z_yk[i] = Q_zi_yk
            MI_z_y[i, k] = MI_zi_yk
            H_z_y[i, k] = H_zi_yk
            print_("#" + "-" * 50 + "#")

        # Print statistics for all z
        print_("")
        print_("MI_z_yk:\n{}".format(MI_z_y[:, k]))
        print_("H_z_yk:\n{}".format(H_z_y[:, k]))
        print_("H_z:\n{}".format(H_z))
        print_("H_yk:\n{}".format(H_yk))

        # Compute RMIG and JEMMI
        ids_yk_sorted = np.argsort(MI_z_y[:, k], axis=0)[::-1]
        MI_z_yk_sorted = np.take_along_axis(MI_z_y[:, k], ids_yk_sorted, axis=0)
        H_z_yk_sorted = np.take_along_axis(H_z_y[:, k], ids_yk_sorted, axis=0)

        RMIG_yk = np.divide(MI_z_yk_sorted[0] - MI_z_yk_sorted[1], H_yk)
        JEMMI_yk = np.divide(
            H_z_yk_sorted[0] - MI_z_yk_sorted[0] + MI_z_yk_sorted[1],
            H_yk + np.log(num_bins))

        ids_sorted[:, k] = ids_yk_sorted
        MI_z_y_sorted[:, k] = MI_z_yk_sorted
        H_z_y_sorted[:, k] = H_z_yk_sorted
        RMIG.append(RMIG_yk)
        JEMMI.append(JEMMI_yk)

        print_("")
        print_("ids_sorted: {}".format(ids_sorted))
        print_("MI_z_yk_sorted: {}".format(MI_z_yk_sorted))
        print_("RMIG_yk: {}".format(RMIG_yk))
        print_("JEMMI_yk: {}".format(JEMMI_yk))

        z_yk_prob_file = join(
            save_dir, "z_yk_prob_4_{}[bins={},bin_limits={},data={}].npz".format(
                factor, num_bins, bin_limits, data_proportion))
        np.savez_compressed(z_yk_prob_file, Q_z_yk=Q_z_yk)
        print_("#" + "=" * 50 + "#")

    results = {
        "MI_z_y": MI_z_y,
        "H_z_y": H_z_y,
        "ids_sorted": ids_sorted,
        "MI_z_y_sorted": MI_z_y_sorted,
        "H_z_y_sorted": H_z_y_sorted,
        "H_z": H_z,
        "H_y": np.asarray(H_y, dtype=np.float32),
        "RMIG": np.asarray(RMIG, dtype=np.float32),
        "JEMMI": np.asarray(JEMMI, dtype=np.float32),
    }
    result_file = join(
        save_dir, "results[bins={},bin_limits={},data={}].npz".format(
            num_bins, bin_limits, data_proportion))
    np.savez_compressed(result_file, **results)

    f.close()
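# --------------------------------------------------------------------------- #
# A compact, self-contained sketch of the RMIG and JEMMI formulas used above
# for one factor yk (synthetic values; names local to the example). RMIG
# normalizes the MI gap by H(yk); JEMMI is based on the joint entropy
# H(zi, yk) of the top latent, normalized by H(yk) + log(num_bins).
import numpy as np

num_bins = 100
MI_z_yk = np.array([0.9, 0.4, 0.1])   # I(zi; yk) for each latent i
H_z_yk = np.array([2.0, 2.4, 2.6])    # H(zi, yk) for each latent i
H_yk = 1.3

order = np.argsort(MI_z_yk)[::-1]
MI_sorted = MI_z_yk[order]
H_sorted = H_z_yk[order]

RMIG_yk = (MI_sorted[0] - MI_sorted[1]) / H_yk
JEMMI_yk = (H_sorted[0] - MI_sorted[0] + MI_sorted[1]) / \
    (H_yk + np.log(num_bins))
print("RMIG: {:.3f}, JEMMI: {:.3f}".format(RMIG_yk, JEMMI_yk))
# --------------------------------------------------------------------------- #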
def interpolate_images(self, save_file, sess, x1, x2, num_itpl_points,
                       batch_on_row=True, batch_size=20,
                       dec_output_2_img_func=None,
                       enc_kwargs={}, dec_kwargs={}):
    # Encode both endpoint batches
    if batch_size < 0:
        z1 = self.encode(sess, x1, **enc_kwargs)
        z2 = self.encode(sess, x2, **enc_kwargs)
    else:
        z1, z2 = [], []
        for batch_ids in iterate_data(len(x1), batch_size, shuffle=False):
            z1.append(self.encode(sess, x1[batch_ids], **enc_kwargs))
            z2.append(self.encode(sess, x2[batch_ids], **enc_kwargs))
        z1 = np.concatenate(z1, axis=0)
        z2 = np.concatenate(z2, axis=0)

    z1_flat = np.ravel(z1)
    z2_flat = np.ravel(z2)

    # Interior interpolation points (the endpoints are appended later)
    zs_itpl = []
    for i in range(1, num_itpl_points + 1):
        zi_flat = z1_flat + (i * 1.0 / (num_itpl_points + 1)) * (z2_flat - z1_flat)
        zs_itpl.append(zi_flat)

    # (num_itpl_points, batch_size * z_dim)
    zs_itpl = np.stack(zs_itpl, axis=0)
    # (num_itpl_points * batch_size,) + z_shape
    zs_itpl = np.reshape(zs_itpl,
                         [num_itpl_points * x1.shape[0]] + self.z_shape)

    if batch_size < 0:
        xs_itpl = self.decode(sess, zs_itpl, **dec_kwargs)
    else:
        xs_itpl = []
        for batch_ids in iterate_data(len(zs_itpl), batch_size, shuffle=False):
            xs_itpl.append(self.decode(sess, zs_itpl[batch_ids], **dec_kwargs))
        xs_itpl = np.concatenate(xs_itpl, axis=0)

    # (num_itpl_points, batch_size) + x_shape
    xs_itpl = np.reshape(xs_itpl,
                         [num_itpl_points, x1.shape[0]] + self.x_shape)
    # (num_itpl_points + 2, batch_size) + x_shape
    xs_itpl = np.concatenate(
        [np.expand_dims(x1, axis=0), xs_itpl, np.expand_dims(x2, axis=0)],
        axis=0)

    if batch_on_row:
        xs_itpl = np.transpose(xs_itpl,
                               [1, 0] + list(range(2, len(self.x_shape) + 2)))

    if dec_output_2_img_func is not None:
        xs_itpl = dec_output_2_img_func(xs_itpl)

    save_img_block(save_file, xs_itpl)
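# --------------------------------------------------------------------------- #
# A self-contained sketch of the latent interpolation schedule used above:
# num_itpl_points interior points strictly between z1 and z2 (the endpoints
# are the encoded inputs themselves, appended separately).
import numpy as np

z1, z2 = np.zeros(4), np.ones(4)
num_itpl_points = 3
zs = [z1 + (i / (num_itpl_points + 1.0)) * (z2 - z1)
      for i in range(1, num_itpl_points + 1)]
print(np.stack(zs))  # rows at fractions 1/4, 2/4, 3/4 of the way from z1 to z2
# --------------------------------------------------------------------------- #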
def plot_Z_itpl_bw_2Xs(self, save_file_prefix, sess, imgs_1, imgs_2,
                       img_names_1, img_names_2, features,
                       num_itpl_points=6, yx_types=('feature', 'itpl_point'),
                       dec_output_2_img_func=None, img_ext='png',
                       batch_size=-1):
    # img_ext
    # ---------------------------------------- #
    assert img_ext == 'png' or img_ext == 'jpg', "'img_ext' must be png or jpg!"
    # ---------------------------------------- #

    # coordinate
    # ---------------------------------------- #
    # For this kind of interpolation, the results have 3 axes:
    # (num_inputs, num_features, num_itpl_points).
    # Each (y, x) type pair selects which axes span the rows and columns of a
    # saved block image; the remaining axis indexes the block images.
    possible_coord_types = [('input', 'itpl_point'), ('feature', 'itpl_point'),
                            ('itpl_point', 'input'), ('itpl_point', 'feature'),
                            (None, 'itpl_point'), ('itpl_point', None)]

    if isinstance(yx_types, tuple):
        assert len(yx_types) == 2, \
            "'yx_types' must be a 2-tuple or a list of 2-tuples " \
            "representing the yx coordinate types!"
        yx_types = [yx_types]
    assert isinstance(yx_types, list), \
        "'yx_types' must be a 2-tuple or a list of 2-tuples " \
        "representing the yx coordinate types!"
    assert all([yx_type in possible_coord_types for yx_type in yx_types]), \
        "Only support the following coordinate types: {}".format(
            possible_coord_types)
    # ---------------------------------------- #

    # num_images
    # ---------------------------------------- #
    assert isinstance(imgs_1, np.ndarray) and imgs_1.ndim == 4, \
        "'imgs_1' must be a 4D numpy array of format " \
        "(num_images, height, width, channels)!"
    assert isinstance(imgs_2, np.ndarray) and imgs_2.ndim == 4, \
        "'imgs_2' must be a 4D numpy array of format " \
        "(num_images, height, width, channels)!"
    assert len(imgs_1) == len(imgs_2), \
        "Number of images in 'imgs_1' and 'imgs_2' must be equal!"
num_inputs = len(imgs_1) # ---------------------------------------- # # num_features # ---------------------------------------- # z_dim = int(np.prod(self.z_shape)) if features == 'all': features = [i for i in range(z_dim)] if isinstance(features, int): assert 0 <= features < z_dim, "'features' must be an integer or " \ "a list/tuple of integers in the range [0, {}]".format(z_dim - 1) features = [features] assert isinstance(features, (list, tuple)), "'features' must be an integer or " \ "a list/tuple of integers in the range [0, {}]".format(z_dim - 1) num_features = len(features) # ---------------------------------------- # # (num_images, z_dim) z1 = np.reshape(self.encode(sess, imgs_1), [num_inputs, z_dim]) z2 = np.reshape(self.encode(sess, imgs_2), [num_inputs, z_dim]) z_samples = [ ] # (num_features * num_itpl_points) of (num_images, z_dim) array for n in range(len(imgs_1)): for feature in features: # (num_itpl_points, ) itpl_points = np.linspace(z1[n, feature], z2[n, feature], num=num_itpl_points, endpoint=True) for itpl_point in itpl_points: z_copy = np.array(z1[n], dtype=z1.dtype, copy=True) z_copy[feature] = itpl_point z_samples.append(z_copy) # (num_inputs * num_features * num_itpl_points, z_dim) z_samples = np.stack(z_samples, axis=0) if batch_size < 0: z_samples = np.reshape( z_samples, [num_inputs * num_features * num_itpl_points] + self.z_shape) x_samples = self.decode(sess, z_samples) else: x_samples = [] for batch_ids in iterate_data(len(z_samples), batch_size, shuffle=False): x_samples.append( self.decode( sess, np.reshape(z_samples[batch_ids], [len(batch_ids)] + self.z_shape))) x_samples = np.concatenate(x_samples, axis=0) # (num_images, num_features, num_itpl_points) + x_shape x_samples = np.reshape(x_samples, [num_inputs, num_features, num_itpl_points] + self.x_shape) if dec_output_2_img_func is not None: x_samples = dec_output_2_img_func(x_samples) for yx_type in yx_types: if yx_type == ('feature', 'itpl_point'): x_itpl = x_samples assert img_names_1, "'inp_img_names_1' must be provided!" assert img_names_2, "'inp_img_names_2' must be provided!" save_file_postfixes = [ "-img[{}-{}].{}".format(img_names_1[i], img_names_2[i], img_ext) for i in range(len(x_itpl)) ] elif yx_type == ('itpl_point', 'feature'): x_itpl = np.transpose(x_samples, [0, 2, 1] + list(range(3, 3 + len(self.x_shape)))) assert img_names_1, "'inp_img_names_1' must be provided!" assert img_names_2, "'inp_img_names_2' must be provided!" 
save_file_postfixes = [ "-img[{}-{}].{}".format(img_names_1[i], img_names_2[i], img_ext) for i in range(len(x_itpl)) ] elif yx_type == ('input', 'itpl_point'): x_itpl = np.transpose(x_samples, [1, 0, 2] + list(range(3, 3 + len(self.x_shape)))) save_file_postfixes = [ "-feat[{}].{}".format(feature, img_ext) for feature in features ] elif yx_type == ('itpl_point', 'input'): x_itpl = np.transpose(x_samples, [1, 2, 0] + list(range(3, 3 + len(self.x_shape)))) save_file_postfixes = [ "-feat[{}].{}".format(feature, img_ext) for feature in features ] elif yx_type == (None, 'itpl_point'): # (num_images, num_features, num_itpl_points) + x_shape x_itpl = np.reshape( x_samples, [num_inputs * num_features, 1, num_itpl_points] + self.x_shape) save_file_postfixes = [ "-img[{}-{}]_feat[{}].{}".format(img_name_1, img_name_2, feature, img_ext) for img_name_1, img_name_2 in zip(img_names_1, img_name_2) for feature in features ] elif yx_type == ('itpl_point', None): # (num_images, num_features, num_itpl_points) + x_shape x_itpl = np.reshape( x_samples, [num_inputs * num_features, num_itpl_points, 1] + self.x_shape) save_file_postfixes = [ "-img[{}-{}]_feat[{}].{}".format(img_name_1, img_name_2, feature, img_ext) for img_name_1, img_name_2 in zip(img_names_1, img_name_2) for feature in features ] else: raise ValueError( "Only support the following coordinate types: {}".format( possible_coord_types)) for i in range(len(x_itpl)): save_file = save_file_prefix + save_file_postfixes[i] save_img_block(save_file, x_itpl[i])
def cond_1_latent_traverse(self, save_file, sess, x, z_comp, start=-3.0, stop=3.0, num_itpl_points=10, x_labels=None, hl_color="red", hl_width=1, subplot_adjust={}, batch_size=20, dec_output_2_img_func=None, enc_kwargs={}, dec_kwargs={}): assert num_itpl_points >= 2, "'num_points' must be >= 2!" itpl_points = [ start + (stop - start) * i * 1.0 / (num_itpl_points - 1) for i in range(0, num_itpl_points) ] assert (len(x.shape) == len(self.x_shape) + 1) and (x.shape[1:] == tuple(self.x_shape)), \ "'x' must contain batch dimension. Found x.shape={}!".format(x.shape) # Compute z # ----------------------------- # # (batch, z_shape) z = self.encode(sess, x, **enc_kwargs) # (batch, z_dim) z_dim = int(np.prod(self.z_shape)) z = np.reshape(z, [x.shape[0], z_dim]) # ----------------------------- # z_meshgrid = [] inserted_ids = [] for n in range(x.shape[0]): itpl_zn = [] idx = 0 for k in range(len(itpl_points)): z_copy = np.array(z[n], dtype=z.dtype, copy=True) z_copy[z_comp] = itpl_points[k] itpl_zn.append(z_copy) if (1 <= k ) and itpl_points[k - 1] <= z[n, z_comp] < itpl_points[k]: idx = k if itpl_points[len(itpl_points) - 1] <= z[n, z_comp]: idx = len(itpl_points) inserted_ids.append((n, idx)) itpl_zn.insert(idx, np.array(z[n], dtype=z.dtype, copy=True)) z_meshgrid.extend(itpl_zn) # Compute z meshgrid # ----------------------------- # num_rows = x.shape[0] num_cols = num_itpl_points + 1 assert len(z_meshgrid) == num_rows * num_cols z_meshgrid = np.reshape(z_meshgrid, [num_rows * num_cols] + self.z_shape) # ----------------------------- # # Reconstruct x meshgrid # ----------------------------- # if batch_size < 0: x_meshgrid = self.decode(sess, z_meshgrid, **dec_kwargs) else: x_meshgrid = [] for batch_ids in iterate_data(len(z_meshgrid), batch_size, shuffle=False): x_meshgrid.append( self.decode(sess, z_meshgrid[batch_ids], **dec_kwargs)) x_meshgrid = np.concatenate(x_meshgrid, axis=0) x_meshgrid = np.reshape(x_meshgrid, [num_rows, num_cols] + self.x_shape) for row_idx, col_idx in inserted_ids: x_meshgrid[row_idx, col_idx] = x[row_idx] if dec_output_2_img_func is not None: x_meshgrid = dec_output_2_img_func(x_meshgrid) if x_labels is not None: save_img_block_highlighted_with_ticklabels( save_file, x_meshgrid, hl_blocks=inserted_ids, hl_color=hl_color, hl_width=hl_width, x_tick_labels=None, y_tick_labels=x_labels, subplot_adjust=subplot_adjust) else: save_img_block_highlighted(save_file, x_meshgrid, hl_blocks=inserted_ids, hl_color=hl_color, hl_width=hl_width)
def cond_all_latents_traverse_v2( self, save_file, sess, x, z_comps=None, z_comp_labels=None, span=2, points_1_side=6, # substitute with original x and highlight hl_x=True, hl_color="red", hl_width=1, font_size=12, title="", title_font_scale=1.5, subplot_adjust={}, size_inches=None, batch_size=20, dec_output_2_img_func=None, enc_kwargs={}, dec_kwargs={}): assert np.shape(x) == tuple( self.x_shape), "'x' must be a single instance!" # (1, x_dim) x_ = np.expand_dims(x, axis=0) # Compute z # ----------------------------- # # (1, z_dim) z = self.encode(sess, x_, **enc_kwargs) assert z.shape[0] == 1 # (z_dim, ) z_dim = int(np.prod(self.z_shape)) z = np.reshape(z, [z_dim]) # ----------------------------- # if z_comps is None: z_comps = list(range(z_dim)) z_meshgrid = [] inserted_ids = [] s = span p = points_1_side for i, comp in enumerate(z_comps): # (2 * points_1_side + 1, ) itpl_vals = [(z[comp] - s) + 1.0 * i * s / p for i in range(p)] itpl_vals += [z[comp]] itpl_vals += [z[comp] + 1.0 * i * s / p for i in range(1, p + 1)] for val in itpl_vals: z_copy = np.array(z, dtype=z.dtype, copy=True) z_copy[comp] = val z_meshgrid.append(z_copy) inserted_ids.append((i, points_1_side)) # Compute z meshgrid # ----------------------------- # num_rows = len(z_comps) num_cols = 2 * points_1_side + 1 assert len(z_meshgrid) == num_rows * num_cols z_meshgrid = np.reshape(z_meshgrid, [num_rows * num_cols] + self.z_shape) # ----------------------------- # # Reconstruct x meshgrid # ----------------------------- # if batch_size < 0: x_meshgrid = self.decode(sess, z_meshgrid, **dec_kwargs) else: x_meshgrid = [] for batch_ids in iterate_data(len(z_meshgrid), batch_size, shuffle=False): x_meshgrid.append( self.decode(sess, z_meshgrid[batch_ids], **dec_kwargs)) x_meshgrid = np.concatenate(x_meshgrid, axis=0) x_meshgrid = np.reshape(x_meshgrid, [num_rows, num_cols] + self.x_shape) if hl_x: for row_idx, col_idx in inserted_ids: x_meshgrid[row_idx, col_idx] = x if dec_output_2_img_func is not None: x_meshgrid = dec_output_2_img_func(x_meshgrid) if z_comp_labels is not None: assert len(z_comp_labels) == len(z_comps), \ "Length of 'z_comp_labels' must be equal to the number of z components " \ "you want to draw. Found {} and {}, respectively!".format(len(z_comp_labels), len(z_comps)) if hl_x: save_img_block_highlighted_with_ticklabels( save_file, x_meshgrid, hl_blocks=inserted_ids, hl_color=hl_color, hl_width=hl_width, x_tick_labels=None, y_tick_labels=z_comp_labels, font_size=font_size, title=title, title_font_scale=title_font_scale, subplot_adjust=subplot_adjust, size_inches=size_inches) else: save_img_block_with_ticklabels( save_file, x_meshgrid, x_tick_labels=None, y_tick_labels=z_comp_labels, font_size=font_size, title=title, title_font_scale=title_font_scale, subplot_adjust=subplot_adjust, size_inches=size_inches) else: if hl_x: save_img_block_highlighted(save_file, x_meshgrid, hl_blocks=inserted_ids, hl_color=hl_color, hl_width=hl_width) else: save_img_block(save_file, x_meshgrid)
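# The v2 traversal grid above is centered on the encoded value instead of a
# fixed [start, stop] range. A sketch of the values produced for one component,
# assuming scalars 'z_val', 'span' and 'points_1_side' ('traversal_values' is a
# hypothetical name):
import numpy as np

def traversal_values(z_val, span, points_1_side):
    step = span * 1.0 / points_1_side
    left = [z_val - span + i * step for i in range(points_1_side)]
    right = [z_val + i * step for i in range(1, points_1_side + 1)]
    return np.asarray(left + [z_val] + right)  # 2 * points_1_side + 1 values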
def cond_all_latents_traverse(self, save_file, sess, x, z_comps=None, z_comp_labels=None, start=-3.0, stop=3.0, num_itpl_points=10, hl_color="red", hl_width=1, subplot_adjust={}, batch_size=20, dec_output_2_img_func=None, enc_kwargs={}, dec_kwargs={}): assert num_itpl_points >= 2, "'num_points' must be >= 2!" itpl_points = [ start + (stop - start) * i * 1.0 / (num_itpl_points - 1) for i in range(0, num_itpl_points) ] assert np.shape(x) == tuple( self.x_shape), "'x' must be a single instance!" # (1, x_dim) x_ = np.expand_dims(x, axis=0) # Compute z # ----------------------------- # # (1, z_dim) z = self.encode(sess, x_, **enc_kwargs) assert z.shape[0] == 1 # (z_dim, ) z_dim = int(np.prod(self.z_shape)) z = np.reshape(z, [z_dim]) # ----------------------------- # if z_comps is None: z_comps = list(range(z_dim)) z_meshgrid = [] inserted_ids = [] for i, comp in enumerate(z_comps): itpl_zi = [] idx = 0 for k in range(len(itpl_points)): z_copy = np.array(z, dtype=z.dtype, copy=True) z_copy[comp] = itpl_points[k] itpl_zi.append(z_copy) if (1 <= k) and itpl_points[k - 1] <= z[comp] < itpl_points[k]: idx = k if itpl_points[len(itpl_points) - 1] <= z[comp]: idx = len(itpl_points) inserted_ids.append((i, idx)) itpl_zi.insert(idx, np.array(z, dtype=z.dtype, copy=True)) z_meshgrid.extend(itpl_zi) # Compute z meshgrid # ----------------------------- # num_rows = len(z_comps) num_cols = num_itpl_points + 1 assert len(z_meshgrid) == num_rows * num_cols z_meshgrid = np.reshape(z_meshgrid, [num_rows * num_cols] + self.z_shape) # ----------------------------- # # Reconstruct x meshgrid # ----------------------------- # if batch_size < 0: x_meshgrid = self.decode(sess, z_meshgrid, **dec_kwargs) else: x_meshgrid = [] for batch_ids in iterate_data(len(z_meshgrid), batch_size, shuffle=False): x_meshgrid.append( self.decode(sess, z_meshgrid[batch_ids], **dec_kwargs)) x_meshgrid = np.concatenate(x_meshgrid, axis=0) x_meshgrid = np.reshape(x_meshgrid, [num_rows, num_cols] + self.x_shape) for row_idx, col_idx in inserted_ids: x_meshgrid[row_idx, col_idx] = x if dec_output_2_img_func is not None: x_meshgrid = dec_output_2_img_func(x_meshgrid) if z_comp_labels is not None: assert len(z_comp_labels) == len(z_comps), \ "Length of 'z_comp_labels' must be equal to the number of z components " \ "you want to draw. Found {} and {}, respectively!".format(len(z_comp_labels), len(z_comps)) save_img_block_highlighted_with_ticklabels( save_file, x_meshgrid, hl_blocks=inserted_ids, hl_color=hl_color, hl_width=hl_width, x_tick_labels=None, y_tick_labels=z_comp_labels, subplot_adjust=subplot_adjust) else: save_img_block_highlighted(save_file, x_meshgrid, hl_blocks=inserted_ids, hl_color=hl_color, hl_width=hl_width)
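# Both traversal routines above splice the original code into each row right
# after the largest grid value that does not exceed it. The hand-written index
# loop should be equivalent to a sorted-array search; a sketch, assuming
# 'itpl_points' is the ascending grid built above:
import numpy as np

def insertion_index(itpl_points, z_val):
    # Position at which the unmodified latent is inserted (and later highlighted).
    return int(np.searchsorted(np.asarray(itpl_points), z_val, side='right'))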
def main(args):
    # =====================================
    # Load config
    # =====================================
    with open(join(args.output_dir, 'config.json')) as f:
        config = json.load(f)
    args.__dict__.update(config)

    # =====================================
    # Dataset
    # =====================================
    celebA_loader = TFCelebALoader(root_dir=args.celebA_root_dir)

    img_height, img_width = args.celebA_resize_size, args.celebA_resize_size
    celebA_loader.build_transformation_flow_tf(
        *celebA_loader.get_transform_fns("1Konny", resize_size=args.celebA_resize_size))

    # =====================================
    # Instantiate model
    # =====================================
    if args.activation == "relu":
        activation = tf.nn.relu
    elif args.activation == "leaky_relu":
        activation = tf.nn.leaky_relu
    else:
        raise ValueError("Do not support '{}' activation!".format(args.activation))

    if args.enc_dec_model == "1Konny":
        assert args.z_dim == 65, "For 1Konny, z_dim must be 65. Found {}!".format(args.z_dim)
        encoder = Encoder_1Konny(args.z_dim, stochastic=True, activation=activation)
        decoder = Decoder_1Konny([img_height, img_width, 3], activation=activation,
                                 output_activation=tf.nn.sigmoid)
        disc_z = DiscriminatorZ_1Konny(num_outputs=2)
    elif args.enc_dec_model == "my":
        assert args.z_dim == 150, "For 'my', z_dim must be 150. Found {}!".format(args.z_dim)
        encoder = Encoder_My(args.z_dim, stochastic=True, activation=activation)
        decoder = Decoder_My([img_height, img_width, 3], activation=activation,
                             output_activation=tf.nn.sigmoid)
        disc_z = DiscriminatorZ_My(num_outputs=2)
    else:
        raise ValueError("Do not support encoder/decoder model '{}'!".format(args.enc_dec_model))

    model = FactorVAE([img_height, img_width, 3], args.z_dim,
                      encoder=encoder, decoder=decoder, discriminator_z=disc_z,
                      rec_x_mode=args.rec_x_mode,
                      use_gp0_z_tc=True, gp0_z_tc_mode=args.gp0_z_tc_mode)

    loss_coeff_dict = {
        'rec_x': args.rec_x_coeff,
        'kld_loss': args.kld_loss_coeff,
        'tc_loss': args.tc_loss_coeff,
        'gp0_z_tc': args.gp0_z_tc_coeff,
    }

    model.build(loss_coeff_dict)
    SimpleParamPrinter.print_all_params_tf_slim()

    # =====================================
    # Load model
    # =====================================
    config_proto = tf.ConfigProto(allow_soft_placement=True)
    config_proto.gpu_options.allow_growth = True
    config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9
    sess = tf.Session(config=config_proto)

    model_dir = make_dir_if_not_exist(join(args.output_dir, "model_tf"))
    train_helper = SimpleTrainHelper(log_dir=None, save_dir=model_dir)

    # Load model
    train_helper.load(sess, load_step=args.load_step)

    # =====================================
    # Experiments
    save_dir = remove_dir_if_exist(join(args.save_dir, "FactorVAE_{}".format(args.run)),
                                   ask_4_permission=False)
    save_dir = make_dir_if_not_exist(save_dir)
    # =====================================

    # z correlation matrix
    # ======================================= #
    for deterministic in [True, False]:
        all_z = []
        for batch_ids in iterate_data(celebA_loader.num_train_data,
                                      args.batch_size, shuffle=False):
            x = celebA_loader.sample_images_from_dataset(sess, 'train', batch_ids)

            z = model.encode(sess, x, deterministic=deterministic)
            assert len(z.shape) == 2 and z.shape[1] == args.z_dim, \
                "z.shape: {}".format(z.shape)

            all_z.append(z)
        all_z = np.concatenate(all_z, axis=0)

        # plot_corrmat(join(save_dir, "corr_mat[deter={}].png".format(deterministic)), all_z,
        #              font={'size': 14},
        #              subplot_adjust={'left': 0.04, 'right': 0.96, 'bottom': 0.02, 'top': 0.98},
        #              size_inches=(7.2, 6))

        plot_corrmat_with_histogram(
            join(save_dir, "corr_mat_hist[deter={}].png".format(deterministic)),
            all_z,
            font={'size': 14},
            subplot_adjust={'left': 0.04, 'right': 0.96, 'bottom': 0.02, 'top': 0.98},
            size_inches=(10, 3))
def main(args): # ===================================== # Load config # ===================================== with open(join(args.output_dir, 'config.json')) as f: config = json.load(f) args.__dict__.update(config) # ===================================== # Dataset # ===================================== celebA_loader = TFCelebALoader(root_dir=args.celebA_root_dir) img_height, img_width = args.celebA_resize_size, args.celebA_resize_size celebA_loader.build_transformation_flow_tf( *celebA_loader.get_transform_fns("1Konny", resize_size=args.celebA_resize_size)) num_train = celebA_loader.num_train_data # ===================================== # Instantiate model # ===================================== if args.activation == "relu": activation = tf.nn.relu elif args.activation == "leaky_relu": activation = tf.nn.leaky_relu else: raise ValueError("Do not support '{}' activation!".format( args.activation)) if args.enc_dec_model == "1Konny": # assert args.z_dim == 65, "For 1Konny, z_dim must be 65. Found {}!".format(args.z_dim) encoder = Encoder_1Konny(args.z_dim, stochastic=True, activation=activation) decoder = Decoder_1Konny([img_height, img_width, 3], activation=activation, output_activation=tf.nn.sigmoid) disc_z = DiscriminatorZ_1Konny(num_outputs=2) else: raise ValueError("Do not support encoder/decoder model '{}'!".format( args.enc_dec_model)) model = FactorVAE([img_height, img_width, 3], args.z_dim, encoder=encoder, decoder=decoder, discriminator_z=disc_z, rec_x_mode=args.rec_x_mode, use_gp0_z_tc=True, gp0_z_tc_mode=args.gp0_z_tc_mode) loss_coeff_dict = { 'rec_x': args.rec_x_coeff, 'kld_loss': args.kld_loss_coeff, 'tc_loss': args.tc_loss_coeff, 'gp0_z_tc': args.gp0_z_tc_coeff, } model.build(loss_coeff_dict) SimpleParamPrinter.print_all_params_tf_slim() # ===================================== # Load model # ===================================== config_proto = tf.ConfigProto(allow_soft_placement=True) config_proto.gpu_options.allow_growth = True config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9 sess = tf.Session(config=config_proto) model_dir = make_dir_if_not_exist(join(args.output_dir, "model_tf")) train_helper = SimpleTrainHelper(log_dir=None, save_dir=model_dir) # Load model train_helper.load(sess, load_step=args.load_step) # ===================================== # Experiments save_dir = remove_dir_if_exist(join(args.save_dir, "FactorVAE_{}".format(args.run)), ask_4_permission=False) save_dir = make_dir_if_not_exist(save_dir) # save_dir = make_dir_if_not_exist(join(args.save_dir, "FactorVAE_{}".format(args.run))) # ===================================== np.set_printoptions(threshold=np.nan, linewidth=1000, precision=3, suppress=True) f = open(join(save_dir, 'log.txt'), mode='w') print_ = functools.partial(print_both, file=f) # z gaussian stddev # ======================================= # all_z_mean = [] all_z_stddev = [] count = 0 for batch_ids in iterate_data(int(0.05 * num_train), 10 * args.batch_size, shuffle=False): x = celebA_loader.sample_images_from_dataset(sess, 'train', batch_ids) z_mean, z_stddev = sess.run(model.get_output(['z_mean', 'z_stddev']), feed_dict={ model.is_train: False, model.x_ph: x }) all_z_mean.append(z_mean) all_z_stddev.append(z_stddev) count += len(batch_ids) print("\rProcessed {} samples!".format(count), end="") print() all_z_mean = np.concatenate(all_z_mean, axis=0) all_z_stddev = np.concatenate(all_z_stddev, axis=0) # ======================================= # z_std_error = np.std(all_z_mean, axis=0, ddof=0) z_sorted_comps = 
np.argsort(z_std_error)[::-1] top10_z_comps = z_sorted_comps[:10] print_("") print_("z_std_error: {}".format(z_std_error)) print_("z_sorted_std_error: {}".format(z_std_error[z_sorted_comps])) print_("z_sorted_comps: {}".format(z_sorted_comps)) print_("top10_z_comps: {}".format(top10_z_comps)) z_stddev_mean = np.mean(all_z_stddev, axis=0) info_z_comps = [ idx for idx in range(len(z_stddev_mean)) if z_stddev_mean[idx] < 0.4 ] print_("info_z_comps: {}".format(info_z_comps)) print_("len(info_z_comps): {}".format(len(info_z_comps))) # Plotting # =========================================== # seed = 389 num_samples = 30 ids = list(range(seed, seed + num_samples)) print("\nids: {}".format(ids)) data = celebA_loader.sample_images_from_dataset(sess, 'train', ids) span = 3 points_one_side = 5 # span = 8 # points_one_side = 12 for n in range(len(ids)): print("Plot conditional all comps z traverse with train sample {}!". format(ids[n])) img_file = join(save_dir, "x_train[{}]_[span={}]_hl.png".format(ids[n], span)) # model.cond_all_latents_traverse_v2(img_file, sess, data[n], # z_comps=top10_z_comps, # z_comp_labels=None, # span=span, points_1_side=points_one_side, # hl_x=True, # batch_size=args.batch_size, # dec_output_2_img_func=binary_float_to_uint8) img_file = join( save_dir, "x_train[{}]_[span={}]_hl_labeled.png".format(ids[n], span)) model.cond_all_latents_traverse_v2( img_file, sess, data[n], z_comps=top10_z_comps, z_comp_labels=["z[{}]".format(comp) for comp in top10_z_comps], span=span, points_1_side=points_one_side, hl_x=True, subplot_adjust={ 'left': 0.09, 'right': 0.98, 'bottom': 0.02, 'top': 0.98 }, size_inches=(6, 5), batch_size=args.batch_size, dec_output_2_img_func=binary_float_to_uint8) # img_file = join(save_dir, "x_train[{}]_[span={}].png".format(ids[n], span)) # model.cond_all_latents_traverse_v2(img_file, sess, data[n], # z_comps=top10_z_comps, # z_comp_labels=None, # span=span, points_1_side=points_one_side, # hl_x=False, # batch_size=args.batch_size, # dec_output_2_img_func=binary_float_to_uint8) # # img_file = join(save_dir, "x_train[{}]_[span={}]_labeled.png".format(ids[n], span)) # model.cond_all_latents_traverse_v2(img_file, sess, data[n], # z_comps=top10_z_comps, # z_comp_labels=["z[{}]".format(comp) for comp in top10_z_comps], # span=span, points_1_side=points_one_side, # hl_x=False, # subplot_adjust={'left': 0.09, 'right': 0.98, 'bottom': 0.02, 'top': 0.98}, # size_inches=(6, 5), # batch_size=args.batch_size, # dec_output_2_img_func=binary_float_to_uint8) img_file = join( save_dir, "x_train[{}]_[span={}]_info_hl.png".format(ids[n], span)) model.cond_all_latents_traverse_v2( img_file, sess, data[n], z_comps=info_z_comps, z_comp_labels=None, span=span, points_1_side=points_one_side, hl_x=True, batch_size=args.batch_size, dec_output_2_img_func=binary_float_to_uint8) # =========================================== # f.close()
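# The selection above flags a latent as informative when its average posterior
# stddev stays well below 1 (threshold 0.4 here): dimensions the encoder stops
# using collapse back to the N(0, 1) prior, so their stddev hovers near 1. A
# one-line sketch, assuming 'all_z_stddev' of shape (num_samples, z_dim):
import numpy as np

def informative_components(all_z_stddev, threshold=0.4):
    return np.where(np.mean(all_z_stddev, axis=0) < threshold)[0]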
def main(args): # ===================================== # Load config # ===================================== with open(join(args.output_dir, 'config.json')) as f: config = json.load(f) args.__dict__.update(config) # ===================================== # Dataset # ===================================== celebA_loader = TFCelebAWithAttrLoader(root_dir=args.celebA_root_dir) img_height, img_width = args.celebA_resize_size, args.celebA_resize_size celebA_loader.build_transformation_flow_tf( *celebA_loader.get_transform_fns("1Konny", resize_size=args.celebA_resize_size)) # ===================================== # Instantiate model # ===================================== if args.activation == "relu": activation = tf.nn.relu elif args.activation == "leaky_relu": activation = tf.nn.leaky_relu else: raise ValueError("Do not support '{}' activation!".format(args.activation)) if args.enc_dec_model == "1Konny": # assert args.z_dim == 65, "For 1Konny, z_dim must be 65. Found {}!".format(args.z_dim) encoder = Encoder_1Konny(args.z_dim, stochastic=True, activation=activation) decoder = Decoder_1Konny([img_height, img_width, 3], activation=activation, output_activation=tf.nn.sigmoid) disc_z = DiscriminatorZ_1Konny() else: raise ValueError("Do not support encoder/decoder model '{}'!".format(args.enc_dec_model)) model = AAE([img_height, img_width, 3], args.z_dim, encoder=encoder, decoder=decoder, discriminator_z=disc_z, rec_x_mode=args.rec_x_mode, stochastic_z=args.stochastic_z, use_gp0_z=True, gp0_z_mode=args.gp0_z_mode) loss_coeff_dict = { 'rec_x': args.rec_x_coeff, 'G_loss_z1_gen': args.G_loss_z1_gen_coeff, 'D_loss_z1_gen': args.D_loss_z1_gen_coeff, 'gp0_z': args.gp0_z_coeff, } model.build(loss_coeff_dict) SimpleParamPrinter.print_all_params_tf_slim() # ===================================== # Load model # ===================================== config_proto = tf.ConfigProto(allow_soft_placement=True) config_proto.gpu_options.allow_growth = True config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9 sess = tf.Session(config=config_proto) model_dir = make_dir_if_not_exist(join(args.output_dir, "model_tf")) train_helper = SimpleTrainHelper(log_dir=None, save_dir=model_dir) # Load model train_helper.load(sess, load_step=args.load_step) # ===================================== # Experiments # save_dir = remove_dir_if_exist(join(args.save_dir, "AAE_{}".format(args.run)), ask_4_permission=True) # save_dir = make_dir_if_not_exist(save_dir) save_dir = make_dir_if_not_exist(join(args.save_dir, "AAE_{}".format(args.run))) # ===================================== np.set_printoptions(threshold=np.nan, linewidth=1000, precision=4, suppress=True) num_bins = args.num_bins bin_limits = tuple([float(s) for s in args.bin_limits.split(";")]) data_proportion = args.data_proportion num_data = int(data_proportion * celebA_loader.num_train_data) eps = 1e-8 # file f = open(join(save_dir, 'log[bins={},bin_limits={},data={}].txt'. 
format(num_bins, bin_limits, data_proportion)), mode='w') # print function print_ = functools.partial(print_both, file=f) print_("num_bins: {}".format(num_bins)) print_("bin_limits: {}".format(bin_limits)) print_("data_proportion: {}".format(data_proportion)) # Compute representations # ================================= # z_data_file = join(save_dir, "z_data[data={}].npz".format(data_proportion)) if not exists(z_data_file): all_z_mean = [] all_z_stddev = [] print("") print("Compute all_z_mean and all_z_stddev!") count = 0 for batch_ids in iterate_data(num_data, 10 * args.batch_size, shuffle=False): x = celebA_loader.sample_images_from_dataset(sess, 'train', batch_ids) z_mean, z_stddev = sess.run( model.get_output(['z_mean', 'z_stddev']), feed_dict={model.is_train: False, model.x_ph: x}) all_z_mean.append(z_mean) all_z_stddev.append(z_stddev) count += len(batch_ids) print("\rProcessed {} samples!".format(count), end="") print() all_z_mean = np.concatenate(all_z_mean, axis=0) all_z_stddev = np.concatenate(all_z_stddev, axis=0) np.savez_compressed(z_data_file, all_z_mean=all_z_mean, all_z_stddev=all_z_stddev) else: print("{} exists. Load data from file!".format(z_data_file)) with np.load(z_data_file, "r") as f: all_z_mean = f['all_z_mean'] all_z_stddev = f['all_z_stddev'] print_("") print_("all_z_mean.shape: {}".format(all_z_mean.shape)) print_("all_z_stddev.shape: {}".format(all_z_stddev.shape)) # ================================= # # Compute bins # ================================= # print_("") print_("bin_limits: {}".format(bin_limits)) assert len(bin_limits) == 2 and bin_limits[0] < bin_limits[1], "bin_limits={}".format(bin_limits) bins = np.linspace(bin_limits[0], bin_limits[1], num_bins + 1, endpoint=True) print_("bins: {}".format(bins)) assert len(bins) == num_bins + 1 bin_widths = [bins[b] - bins[b - 1] for b in range(1, len(bins))] print_("bin_widths: {}".format(bin_widths)) assert len(bin_widths) == num_bins, "len(bin_widths)={} while num_bins={}!".format(len(bin_widths), num_bins) assert np.all(np.greater(bin_widths, 0)), "bin_widths: {}".format(bin_widths) bin_centers = [(bins[b] + bins[b - 1]) * 0.5 for b in range(1, len(bins))] print_("bin_centers: {}".format(bin_centers)) assert len(bin_centers) == num_bins, "len(bin_centers)={} while num_bins={}!".format(len(bin_centers), num_bins) # ================================= # # Compute mutual information # ================================= # H_z = [] H_z_cond_x = [] MI_z_x = [] norm_MI_z_x = [] Q_z_cond_x = [] Q_z = [] for i in range(args.z_dim): print_("") print_("Compute I(z{}, x)!".format(i)) # Q_s_cond_x all_Q_s_cond_x = [] for batch_ids in iterate_data(len(all_z_mean), 500, shuffle=False, include_remaining=True): # (batch_size, num_bins) q_s_cond_x = normal_density(np.expand_dims(bin_centers, axis=0), mean=np.expand_dims(all_z_mean[batch_ids, i], axis=-1), stddev=np.expand_dims(all_z_stddev[batch_ids, i], axis=-1)) # (batch_size, num_bins) max_q_s_cond_x = np.max(q_s_cond_x, axis=-1) # print("\nmax_q_s_cond_x: {}".format(np.sort(max_q_s_cond_x))) # (batch_size, num_bins) deter_s_cond_x = at_bin(all_z_mean[batch_ids, i], bins).astype(np.float32) # (batch_size, num_bins) Q_s_cond_x = q_s_cond_x * np.expand_dims(bin_widths, axis=0) Q_s_cond_x = Q_s_cond_x / np.maximum(np.sum(Q_s_cond_x, axis=1, keepdims=True), eps) # print("sort(sum(Q_s_cond_x)) (before): {}".format(np.sort(np.sum(Q_s_cond_x, axis=-1)))) Q_s_cond_x = np.where(np.expand_dims(np.less(max_q_s_cond_x, 1e-5), axis=-1), deter_s_cond_x, Q_s_cond_x) # 
print("sort(sum(Q_s_cond_x)) (after): {}".format(np.sort(np.sum(Q_s_cond_x, axis=-1)))) all_Q_s_cond_x.append(Q_s_cond_x) all_Q_s_cond_x = np.concatenate(all_Q_s_cond_x, axis=0) print_("sort(sum(all_Q_s_cond_x))[:10]: {}".format( np.sort(np.sum(all_Q_s_cond_x, axis=-1), axis=0)[:100])) assert np.all(all_Q_s_cond_x >= 0), "'all_Q_s_cond_x' contains negative values. " \ "sorted_all_Q_s_cond_x[:30]:\n{}!".format(np.sort(all_Q_s_cond_x[:30], axis=None)) Q_z_cond_x.append(all_Q_s_cond_x) H_zi_cond_x = -np.mean(np.sum(all_Q_s_cond_x * np.log(np.maximum(all_Q_s_cond_x, eps)), axis=1), axis=0) # Q_s Q_s = np.mean(all_Q_s_cond_x, axis=0) print_("Q_s: {}".format(Q_s)) print_("sum(Q_s): {}".format(sum(Q_s))) assert np.all(Q_s >= 0), "'Q_s' contains negative values. " \ "sorted_Q_s[:10]:\n{}!".format(np.sort(Q_s, axis=None)) Q_s = Q_s / np.sum(Q_s, axis=0) print_("sum(Q_s) (normalized): {}".format(sum(Q_s))) Q_z.append(Q_s) H_zi = -np.sum(Q_s * np.log(np.maximum(Q_s, eps)), axis=0) MI_zi_x = H_zi - H_zi_cond_x normalized_MI_zi_x = (1.0 * MI_zi_x) / (H_zi + eps) print_("H_zi: {}".format(H_zi)) print_("H_zi_cond_x: {}".format(H_zi_cond_x)) print_("MI_zi_x: {}".format(MI_zi_x)) print_("normalized_MI_zi_x: {}".format(normalized_MI_zi_x)) H_z.append(H_zi) H_z_cond_x.append(H_zi_cond_x) MI_z_x.append(MI_zi_x) norm_MI_z_x.append(normalized_MI_zi_x) H_z = np.asarray(H_z, dtype=np.float32) H_z_cond_x = np.asarray(H_z_cond_x, dtype=np.float32) MI_z_x = np.asarray(MI_z_x, dtype=np.float32) norm_MI_z_x = np.asarray(norm_MI_z_x, dtype=np.float32) print_("") print_("H_z: {}".format(H_z)) print_("H_z_cond_x: {}".format(H_z_cond_x)) print_("MI_z_x: {}".format(MI_z_x)) print_("norm_MI_z_x: {}".format(norm_MI_z_x)) sorted_z_comps = np.argsort(MI_z_x, axis=0)[::-1] sorted_MI_z_x = np.take_along_axis(MI_z_x, sorted_z_comps, axis=0) print_("sorted_MI_z_x: {}".format(sorted_MI_z_x)) print_("sorted_z_comps: {}".format(sorted_z_comps)) sorted_norm_z_comps = np.argsort(norm_MI_z_x, axis=0)[::-1] sorted_norm_MI_z_x = np.take_along_axis(norm_MI_z_x, sorted_norm_z_comps, axis=0) print_("sorted_norm_MI_z_x: {}".format(sorted_norm_MI_z_x)) print_("sorted_norm_z_comps: {}".format(sorted_norm_z_comps)) result_file = join(save_dir, 'results[bins={},bin_limits={},data={}].npz'. format(num_bins, bin_limits, data_proportion)) np.savez_compressed(result_file, H_z=H_z, H_z_cond_x=H_z_cond_x, MI_z_x=MI_z_x, norm_MI_z_x=norm_MI_z_x, sorted_MI_z_x=sorted_MI_z_x, sorted_z_comps=sorted_z_comps, sorted_norm_MI_z_x=sorted_norm_MI_z_x, sorted_norm_z_comps=sorted_norm_z_comps) Q_z_cond_x = np.asarray(Q_z_cond_x, dtype=np.float32) Q_z = np.asarray(Q_z, dtype=np.float32) z_prob_file = join(save_dir, 'z_prob[bins={},bin_limits={},data={}].npz'. format(num_bins, bin_limits, data_proportion)) np.savez_compressed(z_prob_file, Q_z_cond_x=Q_z_cond_x, Q_z=Q_z)
def main(args): # ===================================== # Load config # ===================================== with open(join(args.output_dir, 'config.json')) as f: config = json.load(f) args.__dict__.update(config) # ===================================== # Dataset # ===================================== data_file = join(RAW_DATA_DIR, "ComputerVision", "dSprites", "dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz") # It is already in the range [0, 1] with np.load(data_file, encoding="latin1") as f: x_train = f['imgs'] # 3 shape * 6 scale * 40 rotation * 32 pos X * 32 pos Y y_train = f['latents_classes'] x_train = np.expand_dims(x_train.astype(np.float32), axis=-1) num_train = len(x_train) print("num_train: {}".format(num_train)) print("y_train[:10]: {}".format(y_train[:10])) # ===================================== # Instantiate model # ===================================== if args.enc_dec_model == "1Konny": encoder = Encoder_1Konny(args.z_dim, stochastic=True) decoder = Decoder_1Konny() disc_z = DiscriminatorZ_1Konny() else: raise ValueError("Do not support enc_dec_model='{}'!".format(args.enc_dec_model)) model = AAE([64, 64, 1], args.z_dim, encoder=encoder, decoder=decoder, discriminator_z=disc_z, rec_x_mode=args.rec_x_mode, stochastic_z=args.stochastic_z, use_gp0_z=True, gp0_z_mode=args.gp0_z_mode) loss_coeff_dict = { 'rec_x': args.rec_x_coeff, 'G_loss_z1_gen': args.G_loss_z1_gen_coeff, 'D_loss_z1_gen': args.D_loss_z1_gen_coeff, 'gp0_z': args.gp0_z_coeff, } model.build(loss_coeff_dict) SimpleParamPrinter.print_all_params_tf_slim() # ===================================== # Load model # ===================================== config_proto = tf.ConfigProto(allow_soft_placement=True) config_proto.gpu_options.allow_growth = True config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9 sess = tf.Session(config=config_proto) model_dir = make_dir_if_not_exist(join(args.output_dir, "model_tf")) train_helper = SimpleTrainHelper(log_dir=None, save_dir=model_dir) # Load model train_helper.load(sess, load_step=args.load_step) # ===================================== # Experiments save_dir = make_dir_if_not_exist(join(args.save_dir, "{}_{}".format(args.enc_dec_model, args.run))) # ===================================== np.set_printoptions(threshold=np.nan, linewidth=1000, precision=5, suppress=True) num_bins = args.num_bins bin_limits = tuple([float(s) for s in args.bin_limits.split(";")]) data_proportion = args.data_proportion num_data = int(data_proportion * num_train) assert num_data == num_train, "For dSprites, you must use all data!" top_k = args.top_k eps = 1e-8 # file f = open(join(save_dir, 'log[bins={},bin_limits={},data={}].txt'. 
format(num_bins, bin_limits, data_proportion)), mode='w') # print function print_ = functools.partial(print_both, file=f) print_("num_bins: {}".format(num_bins)) print_("bin_limits: {}".format(bin_limits)) print_("data_proportion: {}".format(data_proportion)) print_("top_k: {}".format(top_k)) # Compute bins # ================================= # print_("") print_("bin_limits: {}".format(bin_limits)) assert len(bin_limits) == 2 and bin_limits[0] < bin_limits[1], "bin_limits={}".format(bin_limits) bins = np.linspace(bin_limits[0], bin_limits[1], num_bins + 1, endpoint=True) print_("bins: {}".format(bins)) assert len(bins) == num_bins + 1 bin_widths = [bins[b] - bins[b - 1] for b in range(1, len(bins))] print_("bin_widths: {}".format(bin_widths)) assert len(bin_widths) == num_bins, "len(bin_widths)={} while num_bins={}!".format(len(bin_widths), num_bins) assert np.all(np.greater(bin_widths, 0)), "bin_widths: {}".format(bin_widths) bin_centers = [(bins[b] + bins[b - 1]) * 0.5 for b in range(1, len(bins))] print_("bin_centers: {}".format(bin_centers)) assert len(bin_centers) == num_bins, "len(bin_centers)={} while num_bins={}!".format(len(bin_centers), num_bins) # ================================= # # Compute representations # ================================= # z_data_file = join(args.informativeness_metrics_dir, "{}_{}".format(args.enc_dec_model, args.run), "z_data[data={}].npz".format(data_proportion)) with np.load(z_data_file, "r") as f: all_z_mean = f['all_z_mean'] all_z_stddev = f['all_z_stddev'] print_("") print_("all_z_mean.shape: {}".format(all_z_mean.shape)) print_("all_z_stddev.shape: {}".format(all_z_stddev.shape)) # ================================= # # Compute the mutual information # ================================= # mi_file = join(args.informativeness_metrics_dir, "{}_{}".format(args.enc_dec_model, args.run), 'results[bins={},bin_limits={},data={}].npz'. 
format(num_bins, bin_limits, data_proportion)) with np.load(mi_file, "r") as f: sorted_MI_z_x = f['sorted_MI_z_x'] sorted_z_ids = f['sorted_z_comps'] H_z = f['H_z'] if top_k > 0: top_MI = sorted_MI_z_x[:top_k] top_z_ids = sorted_z_ids[:top_k] bot_MI = sorted_MI_z_x[-top_k:] bot_z_ids = sorted_z_ids[-top_k:] top_bot_MI = np.concatenate([top_MI, bot_MI], axis=0) top_bot_z_ids = np.concatenate([top_z_ids, bot_z_ids], axis=0) print_("top MI: {}".format(top_MI)) print_("top_z_ids: {}".format(top_z_ids)) print_("bot MI: {}".format(bot_MI)) print_("bot_z_ids: {}".format(bot_z_ids)) else: top_bot_MI = sorted_MI_z_x top_bot_z_ids = sorted_z_ids # ================================= # H_z1z2_mean_mat = np.full([len(top_bot_z_ids), len(top_bot_z_ids)], -1, dtype=np.float32) MI_z1z2_mean_mat = np.full([len(top_bot_z_ids), len(top_bot_z_ids)], -1, dtype=np.float32) H_z1z2_mean = [] MI_z1z2_mean = [] z1z2_ids = [] # Compute the mutual information # ================================= # for i in range(len(top_bot_z_ids)): z_idx1 = top_bot_z_ids[i] H_s1 = H_z[z_idx1] for j in range(i + 1, len(top_bot_z_ids)): z_idx2 = top_bot_z_ids[j] H_s2 = H_z[z_idx2] print_("") print_("Compute MI(z{}_mean, z{}_mean)!".format(z_idx1, z_idx2)) s1s2_mean_counter = np.zeros([num_bins, num_bins], dtype=np.int32) for batch_ids in iterate_data(len(all_z_mean), 100, shuffle=False, include_remaining=True): s1 = at_bin(all_z_mean[batch_ids, z_idx1], bins, one_hot=False) s2 = at_bin(all_z_mean[batch_ids, z_idx2], bins, one_hot=False) for s1_, s2_ in zip(s1, s2): s1s2_mean_counter[s1_, s2_] += 1 # I(s1, s2) = Q(s1, s2) * (log Q(s1, s2) - log Q(s1) log Q(s2)) # ---------------------------------- # Q_s1s2_mean = (s1s2_mean_counter * 1.0) / np.sum(s1s2_mean_counter).astype(np.float32) log_Q_s1s2_mean = np.log(np.maximum(Q_s1s2_mean, eps)) H_s1s2_mean = -np.sum(Q_s1s2_mean * log_Q_s1s2_mean) MI_s1s2_mean = H_s1 + H_s2 - H_s1s2_mean print_("H_s1: {}".format(H_s1)) print_("H_s2: {}".format(H_s2)) print_("H_s1s2_mean: {}".format(H_s1s2_mean)) print_("MI_s1s2_mean: {}".format(MI_s1s2_mean)) H_z1z2_mean.append(H_s1s2_mean) MI_z1z2_mean.append(MI_s1s2_mean) z1z2_ids.append((z_idx1, z_idx2)) H_z1z2_mean_mat[i, j] = H_s1s2_mean H_z1z2_mean_mat[j, i] = H_s1s2_mean MI_z1z2_mean_mat[i, j] = MI_s1s2_mean MI_z1z2_mean_mat[j, i] = MI_s1s2_mean H_z1z2_mean = np.asarray(H_z1z2_mean, dtype=np.float32) MI_z1z2_mean = np.asarray(MI_z1z2_mean, dtype=np.float32) z1z2_ids = np.asarray(z1z2_ids, dtype=np.int32) result_file = join(save_dir, "results[bins={},bin_limits={},data={},k={}].npz". format(num_bins, bin_limits, data_proportion, top_k)) results = { 'H_z1z2_mean': H_z1z2_mean, 'MI_z1z2_mean': MI_z1z2_mean, 'H_z1z2_mean_mat': H_z1z2_mean_mat, 'MI_z1z2_mean_mat': MI_z1z2_mean_mat, 'z1z2_ids': z1z2_ids, } np.savez_compressed(result_file, **results) # ================================= # f.close()
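# The pairwise estimate above reduces to a joint histogram plus
# I(z1; z2) = H(z1) + H(z2) - H(z1, z2). A sketch assuming 1-D arrays of latent
# means and the shared 'bins' edges. Two assumptions differ from the script:
# plain np.histogram2d drops values outside the outer edges (the project's
# 'at_bin' helper may clip them instead), and the marginal entropies are
# recomputed from the joint table rather than reused from the saved 'H_z':
import numpy as np

def pairwise_mi(z1_mean, z2_mean, bins, eps=1e-8):
    counts, _, _ = np.histogram2d(z1_mean, z2_mean, bins=(bins, bins))
    Q12 = counts / np.sum(counts)
    H12 = -np.sum(Q12 * np.log(np.maximum(Q12, eps)))
    Q1, Q2 = np.sum(Q12, axis=1), np.sum(Q12, axis=0)
    H1 = -np.sum(Q1 * np.log(np.maximum(Q1, eps)))
    H2 = -np.sum(Q2 * np.log(np.maximum(Q2, eps)))
    return H1 + H2 - H12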
def main(args): # ===================================== # Load config # ===================================== with open(join(args.output_dir, 'config.json')) as f: config = json.load(f) args.__dict__.update(config) # ===================================== # Dataset # ===================================== data_file = join(RAW_DATA_DIR, "ComputerVision", "dSprites", "dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz") # It is already in the range [0, 1] with np.load(data_file, encoding="latin1") as f: x_train = f['imgs'] # 3 shape * 6 scale * 40 rotation * 32 pos X * 32 pos Y y_train = f['latents_classes'][:, 1:] x_train = np.expand_dims(x_train.astype(np.float32), axis=-1) num_train = len(x_train) print("num_train: {}".format(num_train)) print("y_train[:10]: {}".format(y_train[:10])) # ===================================== # Instantiate model # ===================================== if args.enc_dec_model == "1Konny": encoder = Encoder_1Konny(args.z_dim, stochastic=True) decoder = Decoder_1Konny() disc_z = DiscriminatorZ_1Konny(num_outputs=2) else: raise ValueError("Do not support enc_dec_model='{}'!".format(args.enc_dec_model)) model = FactorVAE([64, 64, 1], args.z_dim, encoder=encoder, decoder=decoder, discriminator_z=disc_z, rec_x_mode=args.rec_x_mode, use_gp0_z_tc=True, gp0_z_tc_mode=args.gp0_z_tc_mode) loss_coeff_dict = { 'rec_x': args.rec_x_coeff, 'kld_loss': args.kld_loss_coeff, 'tc_loss': args.tc_loss_coeff, 'gp0_z_tc': args.gp0_z_tc_coeff, 'Dz_tc_loss_coeff': args.Dz_tc_loss_coeff, } model.build(loss_coeff_dict) SimpleParamPrinter.print_all_params_tf_slim() # ===================================== # Load model # ===================================== config_proto = tf.ConfigProto(allow_soft_placement=True) config_proto.gpu_options.allow_growth = True config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9 sess = tf.Session(config=config_proto) model_dir = make_dir_if_not_exist(join(args.output_dir, "model_tf")) train_helper = SimpleTrainHelper(log_dir=None, save_dir=model_dir) # Load model train_helper.load(sess, load_step=args.load_step) # ===================================== # Experiments save_dir = make_dir_if_not_exist(join(args.save_dir, "{}_{}".format(args.enc_dec_model, args.run))) # ===================================== np.set_printoptions(threshold=np.nan, linewidth=1000, precision=5, suppress=True) num_bins = args.num_bins data_proportion = args.data_proportion num_data = int(data_proportion * num_train) assert num_data == num_train, "For dSprites, you must use all data!" # file f = open(join(save_dir, 'log[bins={},data={}].txt'. 
format(num_bins, data_proportion)), mode='w') # print function print_ = functools.partial(print_both, file=f) print_("num_bins: {}".format(num_bins)) print_("data_proportion: {}".format(data_proportion)) # Compute representations # ================================= # z_data_file = join(save_dir, "z_data[data={}].npz".format(data_proportion)) if not exists(z_data_file): all_z_mean = [] all_z_stddev = [] print("") print("Compute all_z_mean, all_z_stddev and all_attrs!") count = 0 for batch_ids in iterate_data(num_data, 10 * args.batch_size, shuffle=False): x = x_train[batch_ids] z_mean, z_stddev = sess.run( model.get_output(['z_mean', 'z_stddev']), feed_dict={model.is_train: False, model.x_ph: x}) all_z_mean.append(z_mean) all_z_stddev.append(z_stddev) count += len(batch_ids) print("\rProcessed {} samples!".format(count), end="") print() all_z_mean = np.concatenate(all_z_mean, axis=0) all_z_stddev = np.concatenate(all_z_stddev, axis=0) np.savez_compressed(z_data_file, all_z_mean=all_z_mean, all_z_stddev=all_z_stddev) else: print("{} exists. Load data from file!".format(z_data_file)) with np.load(z_data_file, "r") as f: all_z_mean = f['all_z_mean'] all_z_stddev = f['all_z_stddev'] print_("") print_("all_z_mean.shape: {}".format(all_z_mean.shape)) print_("all_z_stddev.shape: {}".format(all_z_stddev.shape)) # ================================= # # Transpose and compute MIG score # ================================= # assert len(all_z_mean) == len(y_train) # (num_latents, num_samples) all_z_mean = np.transpose(all_z_mean, [1, 0]) print_("") print_("all_z_mean.shape: {}".format(all_z_mean.shape)) y_train = np.transpose(y_train, [1, 0]) print_("") print_("y_train.shape: {}".format(y_train.shape)) # All # --------------------------------- # result_all = compute_mig(all_z_mean, y_train, is_discrete_z=False, is_discrete_y=True, num_bins=num_bins) # (num_latents, num_factors) MI_gap_y = result_all['MI_gap_y'] attr_ids_sorted = np.argsort(MI_gap_y, axis=0)[::-1].tolist() MI_gap_y_sorted = MI_gap_y[attr_ids_sorted].tolist() print_("") print_("MIG: {}".format(result_all['MIG'])) print_("Sorted factors:\n{}".format(list(zip(attr_ids_sorted, MI_gap_y_sorted)))) save_file = join(save_dir, "results[bins={},data={}].npz".format(num_bins, data_proportion)) np.savez_compressed(save_file, **result_all) # --------------------------------- # # ================================= # f.close()
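# 'compute_mig' is a project helper. The MIG score itself is the mean, over
# ground-truth factors, of the gap between the two latents with the highest MI,
# normalized by the factor entropy. A sketch assuming 'MI' of shape
# (num_latents, num_factors) and factor entropies 'H_y' of shape (num_factors,):
import numpy as np

def mig_score(MI, H_y, eps=1e-8):
    MI_sorted = np.sort(MI, axis=0)[::-1]  # descending along the latent axis
    gaps = (MI_sorted[0] - MI_sorted[1]) / np.maximum(H_y, eps)
    return np.mean(gaps), gaps  # overall MIG and the per-factor gaps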
def main(args): # ===================================== # Load config # ===================================== with open(os.path.join(args.output_dir, 'config.json')) as f: config = json.load(f) args.__dict__.update(config) # ===================================== # Preparation # ===================================== celebA_loader = TFCelebALoader(root_dir=args.celebA_root_dir) num_train = celebA_loader.num_train_data num_test = celebA_loader.num_test_data img_height, img_width = args.celebA_resize_size, args.celebA_resize_size celebA_loader.build_transformation_flow_tf( *celebA_loader.get_transform_fns("1Konny", resize_size=args.celebA_resize_size)) # ===================================== # Instantiate models # ===================================== # Only use activation for encoder and decoder if args.activation == "relu": activation = tf.nn.relu elif args.activation == "leaky_relu": activation = tf.nn.leaky_relu else: raise ValueError("Do not support '{}' activation!".format( args.activation)) if args.enc_dec_model == "1Konny": # assert args.z_dim == 65, "For 1Konny, z_dim must be 65. Found {}!".format(args.z_dim) encoder = Encoder_1Konny(args.z_dim, stochastic=True, activation=activation) decoder = Decoder_1Konny([img_height, img_width, 3], activation=activation, output_activation=tf.nn.sigmoid) disc_z = DiscriminatorZ_1Konny() else: raise ValueError("Do not support encoder/decoder model '{}'!".format( args.enc_dec_model)) model = AAE([img_height, img_width, 3], args.z_dim, encoder=encoder, decoder=decoder, discriminator_z=disc_z, rec_x_mode=args.rec_x_mode, stochastic_z=args.stochastic_z, use_gp0_z=True, gp0_z_mode=args.gp0_z_mode) loss_coeff_dict = { 'rec_x': args.rec_x_coeff, 'G_loss_z1_gen': args.G_loss_z1_gen_coeff, 'D_loss_z1_gen': args.D_loss_z1_gen_coeff, 'gp0_z': args.gp0_z_coeff, } model.build(loss_coeff_dict) SimpleParamPrinter.print_all_params_list() # ===================================== # TF Graph Handler asset_dir = make_dir_if_not_exist(os.path.join(args.output_dir, "asset")) img_eval = remove_dir_if_exist(os.path.join(asset_dir, "img_eval"), ask_4_permission=False) img_eval = make_dir_if_not_exist(img_eval) img_x_gen = make_dir_if_not_exist(os.path.join(img_eval, "x_gen")) img_x_rec = make_dir_if_not_exist(os.path.join(img_eval, "x_rec")) img_z_rand_2_traversal = make_dir_if_not_exist( os.path.join(img_eval, "z_rand_2_traversal")) img_z_cond_all_traversal = make_dir_if_not_exist( os.path.join(img_eval, "z_cond_all_traversal")) img_z_cond_1_traversal = make_dir_if_not_exist( os.path.join(img_eval, "z_cond_1_traversal")) img_z_corr = make_dir_if_not_exist(os.path.join(img_eval, "z_corr")) img_z_dist = make_dir_if_not_exist(os.path.join(img_eval, "z_dist")) img_z_stat_dist = make_dir_if_not_exist( os.path.join(img_eval, "z_stat_dist")) # img_rec_error_dist = make_dir_if_not_exist(os.path.join(img_eval, "rec_error_dist")) model_dir = make_dir_if_not_exist(os.path.join(args.output_dir, "model_tf")) train_helper = SimpleTrainHelper(log_dir=None, save_dir=model_dir) # ===================================== # ===================================== # Training Loop # ===================================== config_proto = tf.ConfigProto(allow_soft_placement=True) config_proto.gpu_options.allow_growth = True config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9 sess = tf.Session(config=config_proto) # Load model train_helper.load(sess, load_step=args.load_step) # ''' # Generation # ======================================= # z = np.random.randn(64, args.z_dim) img_file = 
os.path.join(img_x_gen, 'x_gen_test.png')
    model.generate_images(img_file, sess, z,
                          block_shape=[8, 8],
                          batch_size=args.batch_size,
                          dec_output_2_img_func=binary_float_to_uint8)
    # ======================================= #
    # '''

    # '''
    # Reconstruction
    # ======================================= #
    seed = 389
    x = celebA_loader.sample_images_from_dataset(sess, 'test',
                                                 list(range(seed, seed + 64)))
    img_file = os.path.join(img_x_rec, 'x_rec_test.png')
    model.reconstruct_images(img_file, sess, x,
                             block_shape=[8, 8],
                             batch_size=-1,
                             dec_output_2_img_func=binary_float_to_uint8)
    # ======================================= #
    # '''

    # '''
    # z random traversal
    # ======================================= #
    if args.z_dim <= 5:
        print("z_dim = {}, perform random traversal!".format(args.z_dim))

        # Plot z cont with z cont
        z_zero = np.zeros([args.z_dim], dtype=np.float32)
        z_rand = np.random.randn(args.z_dim)
        z_start, z_stop = -4, 4
        num_points = 8

        for i in range(args.z_dim):
            for j in range(i + 1, args.z_dim):
                print("Plot random 2 comps z traverse with {} and {} components!"
                      .format(i, j))

                img_file = os.path.join(img_z_rand_2_traversal,
                                        'z[{},{},zero].png'.format(i, j))
                model.rand_2_latents_traverse(
                    img_file, sess, default_z=z_zero,
                    z_comp1=i, start1=z_start, stop1=z_stop, num_points1=num_points,
                    z_comp2=j, start2=z_start, stop2=z_stop, num_points2=num_points,
                    batch_size=args.batch_size,
                    dec_output_2_img_func=binary_float_to_uint8)

                img_file = os.path.join(img_z_rand_2_traversal,
                                        'z[{},{},rand].png'.format(i, j))
                model.rand_2_latents_traverse(
                    img_file, sess, default_z=z_rand,
                    z_comp1=i, start1=z_start, stop1=z_stop, num_points1=num_points,
                    z_comp2=j, start2=z_start, stop2=z_stop, num_points2=num_points,
                    batch_size=args.batch_size,
                    dec_output_2_img_func=binary_float_to_uint8)
    # ======================================= #
    # '''

    # z conditional traversal (all features + one feature)
    # ======================================= #
    seed = 389
    num_samples = 30
    data = celebA_loader.sample_images_from_dataset(
        sess, 'train', list(range(seed, seed + num_samples)))

    z_start, z_stop = -4, 4
    num_itpl_points = 8
    for n in range(num_samples):
        print("Plot conditional all comps z traverse with train sample {}!".
format(n)) img_file = os.path.join(img_z_cond_all_traversal, 'x_train{}.png'.format(n)) model.cond_all_latents_traverse( img_file, sess, data[n], start=z_start, stop=z_stop, num_itpl_points=num_itpl_points, batch_size=args.batch_size, dec_output_2_img_func=binary_float_to_uint8) z_start, z_stop = -4, 4 num_itpl_points = 8 for i in range(args.z_dim): print("Plot conditional z traverse with comp {}!".format(i)) img_file = os.path.join( img_z_cond_1_traversal, 'x_train[{},{}]_z{}.png'.format(seed, seed + num_samples, i)) model.cond_1_latent_traverse( img_file, sess, data, z_comp=i, start=z_start, stop=z_stop, num_itpl_points=num_itpl_points, batch_size=args.batch_size, dec_output_2_img_func=binary_float_to_uint8) # ======================================= # # ''' # ''' # z correlation matrix # ======================================= # all_z = [] for batch_ids in iterate_data(num_train, args.batch_size, shuffle=False): x = celebA_loader.sample_images_from_dataset(sess, 'train', batch_ids) z = model.encode(sess, x) assert len( z.shape) == 2 and z.shape[1] == args.z_dim, "z.shape: {}".format( z.shape) all_z.append(z) all_z = np.concatenate(all_z, axis=0) print("Start plotting!") plot_corrmat_with_histogram(os.path.join(img_z_corr, "corr_mat.png"), all_z) plot_comp_dist(os.path.join(img_z_dist, 'z_{}'), all_z, x_lim=(-5, 5)) print("Done!") # ======================================= # # ''' # ''' # z gaussian stddev # ======================================= # print("\nPlot z mean and stddev!") all_z_mean = [] all_z_stddev = [] for batch_ids in iterate_data(num_train, args.batch_size, shuffle=False): x = celebA_loader.sample_images_from_dataset(sess, 'train', batch_ids) z_mean, z_stddev = sess.run(model.get_output(['z_mean', 'z_stddev']), feed_dict={ model.is_train: False, model.x_ph: x }) all_z_mean.append(z_mean) all_z_stddev.append(z_stddev) all_z_mean = np.concatenate(all_z_mean, axis=0) all_z_stddev = np.concatenate(all_z_stddev, axis=0) plot_comp_dist(os.path.join(img_z_stat_dist, 'z_mean_{}.png'), all_z_mean, x_lim=(-5, 5)) plot_comp_dist(os.path.join(img_z_stat_dist, 'z_stddev_{}.png'), all_z_stddev, x_lim=(-0.5, 3))
def main(args): # ===================================== # Load config # ===================================== with open(join(args.output_dir, 'config.json')) as f: config = json.load(f) args.__dict__.update(config) # ===================================== # Dataset # ===================================== data_file = join(RAW_DATA_DIR, "ComputerVision", "dSprites", "dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz") # It is already in the range [0, 1] with np.load(data_file, encoding="latin1") as f: x_train = f['imgs'] # 3 shape * 6 scale * 40 rotation * 32 pos X * 32 pos Y y_train = f['latents_classes'] x_train = np.expand_dims(x_train.astype(np.float32), axis=-1) num_train = len(x_train) print("num_train: {}".format(num_train)) # ===================================== # Instantiate model # ===================================== if args.enc_dec_model == "1Konny": encoder = Encoder_1Konny(args.z_dim, stochastic=True) decoder = Decoder_1Konny() disc_z = DiscriminatorZ_1Konny(num_outputs=2) else: raise ValueError("Do not support enc_dec_model='{}'!".format( args.enc_dec_model)) model = FactorVAE([64, 64, 1], args.z_dim, encoder=encoder, decoder=decoder, discriminator_z=disc_z, rec_x_mode=args.rec_x_mode, use_gp0_z_tc=True, gp0_z_tc_mode=args.gp0_z_tc_mode) loss_coeff_dict = { 'rec_x': args.rec_x_coeff, 'kld_loss': args.kld_loss_coeff, 'tc_loss': args.tc_loss_coeff, 'gp0_z_tc': args.gp0_z_tc_coeff, 'Dz_tc_loss_coeff': args.Dz_tc_loss_coeff, } model.build(loss_coeff_dict) SimpleParamPrinter.print_all_params_tf_slim() # ===================================== # Load model # ===================================== config_proto = tf.ConfigProto(allow_soft_placement=True) config_proto.gpu_options.allow_growth = True config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9 sess = tf.Session(config=config_proto) model_dir = make_dir_if_not_exist(join(args.output_dir, "model_tf")) train_helper = SimpleTrainHelper(log_dir=None, save_dir=model_dir) # Load model train_helper.load(sess, load_step=args.load_step) # ===================================== # Experiments save_dir = make_dir_if_not_exist( join(args.save_dir, "{}_{}".format(args.enc_dec_model, args.run))) # ===================================== np.set_printoptions(threshold=np.nan, linewidth=1000, precision=5, suppress=True) num_samples = args.num_samples print("num_samples: {}".format(num_samples)) # Compute representations # ================================= # z_data_file = join(save_dir, "z_data.npz") if not exists(z_data_file): all_z_samples = [] all_z_mean = [] all_z_stddev = [] print("") print("Compute all_z_mean, all_z_stddev and all_attrs!") count = 0 for batch_ids in iterate_data(num_train, 10 * args.batch_size, shuffle=False): x = x_train[batch_ids] z_samples, z_mean, z_stddev = sess.run(model.get_output( ['z1_gen', 'z_mean', 'z_stddev']), feed_dict={ model.is_train: False, model.x_ph: x }) all_z_samples.append(z_samples) all_z_mean.append(z_mean) all_z_stddev.append(z_stddev) count += len(batch_ids) print("\rProcessed {} samples!".format(count), end="") print() all_z_samples = np.concatenate(all_z_samples, axis=0) all_z_mean = np.concatenate(all_z_mean, axis=0) all_z_stddev = np.concatenate(all_z_stddev, axis=0) np.savez_compressed(z_data_file, all_z_samples=all_z_samples, all_z_mean=all_z_mean, all_z_stddev=all_z_stddev) else: print("{} exists. 
Load data from file!".format(z_data_file))
        with np.load(z_data_file, "r") as f:
            all_z_samples = f['all_z_samples']
            all_z_mean = f['all_z_mean']
            all_z_stddev = f['all_z_stddev']
    # ================================= #

    # Reshape into the dSprites factor grid:
    # 3 shapes * 6 scales * 40 orientations * 32 posX * 32 posY
    all_z_samples = np.reshape(all_z_samples, [3, 6, 40, 32, 32, -1])
    all_z_mean = np.reshape(all_z_mean, [3, 6, 40, 32, 32, -1])
    all_z_stddev = np.reshape(all_z_stddev, [3, 6, 40, 32, 32, -1])

    if args.gpu_support == 'cupy':
        print("Use cupy instead of numpy!")
        results = MIG_4_dSprites_cupy(all_z_samples, all_z_mean, all_z_stddev,
                                      version=1, batch_size=10,
                                      num_samples=num_samples, gpu=args.gpu_id)
    else:
        results = MIG_4_dSprites(all_z_samples, all_z_mean, all_z_stddev,
                                 num_samples=num_samples, version=1,
                                 batch_size=200)

    result_file = join(save_dir, "results[num_samples={}].npz".format(num_samples))
    np.savez_compressed(result_file, **results)
    # Note: unlike the other scripts, this one opens no log file, so there is
    # nothing left to close here ('f' above is managed by the 'with' blocks).
def main(args): np.set_printoptions(threshold=np.nan, linewidth=1000, precision=3) # ===================================== # Preparation # ===================================== data_file = os.path.join(RAW_DATA_DIR, "ComputerVision", "dSprites", "dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz") # It is already in the range [0, 1] with np.load(data_file, encoding="latin1") as f: x_train = f['imgs'] x_train = np.expand_dims(x_train.astype(np.float32), axis=-1) num_train = len(x_train) print("x_train: {}".format(num_train)) args.output_dir = os.path.join(args.output_dir, args.enc_dec_model, args.run) if not os.path.exists(args.output_dir): os.makedirs(args.output_dir) else: if args.force_rm_dir: import shutil shutil.rmtree(args.output_dir, ignore_errors=True) print("Removed '{}'".format(args.output_dir)) else: raise ValueError("Output directory '{}' existed. 'force_rm_dir' " "must be set to True!".format(args.output_dir)) os.mkdir(args.output_dir) save_args(os.path.join(args.output_dir, 'config.json'), args) # pp.pprint(args.__dict__) # ===================================== # Instantiate models # ===================================== if args.enc_dec_model == "1Konny": encoder = Encoder_1Konny(args.z_dim, stochastic=True) decoder = Decoder_1Konny() disc_z = DiscriminatorZ_1Konny(num_outputs=2) else: raise ValueError("Do not support enc_dec_model='{}'!".format( args.enc_dec_model)) model = FactorVAE([64, 64, 1], args.z_dim, encoder=encoder, decoder=decoder, discriminator_z=disc_z, rec_x_mode=args.rec_x_mode, use_gp0_z_tc=True, gp0_z_tc_mode=args.gp0_z_tc_mode) loss_coeff_dict = { 'rec_x': args.rec_x_coeff, 'kld_loss': args.kld_loss_coeff, 'tc_loss': args.tc_loss_coeff, 'gp0_z_tc': args.gp0_z_tc_coeff, 'Dz_tc_loss_coeff': args.Dz_tc_loss_coeff, } model.build(loss_coeff_dict) SimpleParamPrinter.print_all_params_list() # SimpleParamPrinter.print_all_params_tf_slim() loss = model.get_loss() train_params = model.get_train_params() opt_Dz = tf.train.AdamOptimizer(learning_rate=args.lr_Dz, beta1=args.beta1_Dz, beta2=args.beta2_Dz) opt_vae = tf.train.AdamOptimizer(learning_rate=args.lr_vae, beta1=args.beta1_vae, beta2=args.beta2_vae) with tf.control_dependencies(model.get_all_update_ops()): train_op_Dz = opt_Dz.minimize(loss=loss['Dz_loss'], var_list=train_params['Dz_loss']) train_op_D = train_op_Dz train_op_vae = opt_vae.minimize(loss=loss['vae_loss'], var_list=train_params['vae_loss']) # ===================================== # TF Graph Handler asset_dir = make_dir_if_not_exist(os.path.join(args.output_dir, "asset")) img_gen_dir = make_dir_if_not_exist(os.path.join(asset_dir, "img_gen")) img_rec_dir = make_dir_if_not_exist(os.path.join(asset_dir, "img_rec")) img_itpl_dir = make_dir_if_not_exist(os.path.join(asset_dir, "img_itpl")) log_dir = make_dir_if_not_exist(os.path.join(args.output_dir, "log")) train_log_file = os.path.join(log_dir, "train.log") summary_dir = make_dir_if_not_exist( os.path.join(args.output_dir, "summary_tf")) model_dir = make_dir_if_not_exist(os.path.join(args.output_dir, "model_tf")) train_helper = SimpleTrainHelper( log_dir=summary_dir, save_dir=model_dir, max_to_keep=3, max_to_keep_best=1, ) # ===================================== # ===================================== # Training Loop # ===================================== config_proto = tf.ConfigProto(allow_soft_placement=True) config_proto.gpu_options.allow_growth = True config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9 sess = tf.Session(config=config_proto) train_helper.initialize(sess, init_variables=True, 
    Dz_fetch_keys = ['Dz_loss', 'Dz_tc_loss',
                     'Dz_loss_normal', 'Dz_loss_factor',
                     'Dz_avg_prob_normal', 'Dz_avg_prob_factor',
                     'gp0_z_tc']
    D_fetch_keys = Dz_fetch_keys

    vae_fetch_keys = ['vae_loss', 'rec_x', 'kld_loss', 'tc_loss']

    global_step = 0
    for epoch in range(args.epochs):
        for batch_ids in iterate_data(num_train, args.batch_size, shuffle=True):
            global_step += 1

            x = x_train[batch_ids]
            z = np.random.normal(size=[len(x), args.z_dim])

            # A second, independently sampled batch for the "factorised"
            # (dimension-wise shuffled) z fed to the TC discriminator
            batch_ids_2 = np.random.choice(num_train, size=len(batch_ids)).tolist()
            xa = x_train[batch_ids_2]

            # Discriminator updates
            for i in range(args.D_steps):
                _, Dm = sess.run(
                    [train_op_D, model.get_output(D_fetch_keys, as_dict=True)],
                    feed_dict={model.is_train: True, model.x_ph: x,
                               model.z_ph: z, model.xa_ph: xa})

            # VAE updates
            for i in range(args.vae_steps):
                _, VAEm = sess.run(
                    [train_op_vae, model.get_output(vae_fetch_keys, as_dict=True)],
                    feed_dict={model.is_train: True, model.x_ph: x,
                               model.z_ph: z, model.xa_ph: xa})

            if global_step % args.save_freq == 0:
                train_helper.save(sess, global_step)

            if global_step % args.log_freq == 0:
                log_str = "\n[FactorVAE/{}/{} (dSprites)], Epoch[{}/{}], Step {}".format(
                    args.enc_dec_model, args.run, epoch, args.epochs, global_step) + \
                    "\nvae_loss: {:.4f}, Dz_loss: {:.4f}, Dz_tc_loss: {:.4f}".format(
                        VAEm['vae_loss'], Dm['Dz_loss'], Dm['Dz_tc_loss']) + \
                    "\nrec_x: {:.4f}, kld_loss: {:.4f}, tc_loss: {:.4f}".format(
                        VAEm['rec_x'], VAEm['kld_loss'], VAEm['tc_loss']) + \
                    "\nDz_loss_normal: {:.4f}, Dz_loss_factor: {:.4f}".format(
                        Dm['Dz_loss_normal'], Dm['Dz_loss_factor']) + \
                    "\nDz_avg_prob_normal: {:.4f}, Dz_avg_prob_factor: {:.4f}".format(
                        Dm['Dz_avg_prob_normal'], Dm['Dz_avg_prob_factor']) + \
                    "\ngp0_z_tc_coeff: {:.4f}, gp0_z_tc: {:.4f}".format(
                        args.gp0_z_tc_coeff, Dm['gp0_z_tc'])
                print(log_str)

                with open(train_log_file, "a") as f:
                    f.write(log_str)
                    f.write("\n")

                train_helper.add_summary(custom_tf_scalar_summary(
                    'vae_loss', VAEm['vae_loss'], prefix='train'), global_step)
                train_helper.add_summary(custom_tf_scalar_summary(
                    'rec_x', VAEm['rec_x'], prefix='train'), global_step)
                train_helper.add_summary(custom_tf_scalar_summary(
                    'kld_loss', VAEm['kld_loss'], prefix='train'), global_step)
                train_helper.add_summary(custom_tf_scalar_summary(
                    'tc_loss', VAEm['tc_loss'], prefix='train'), global_step)
                train_helper.add_summary(custom_tf_scalar_summary(
                    'Dz_tc_loss', Dm['Dz_tc_loss'], prefix='train'), global_step)
                train_helper.add_summary(custom_tf_scalar_summary(
                    'Dz_loss_normal', Dm['Dz_loss_normal'], prefix='train'), global_step)
                train_helper.add_summary(custom_tf_scalar_summary(
                    'Dz_loss_factor', Dm['Dz_loss_factor'], prefix='train'), global_step)
                train_helper.add_summary(custom_tf_scalar_summary(
                    'Dz_prob_normal', Dm['Dz_avg_prob_normal'], prefix='train'), global_step)
                train_helper.add_summary(custom_tf_scalar_summary(
                    'Dz_prob_factor', Dm['Dz_avg_prob_factor'], prefix='train'), global_step)

            if global_step % args.viz_gen_freq == 0:
                # Generate images
                # ------------------------- #
                z = np.random.normal(size=[64, args.z_dim])
                img_file = os.path.join(img_gen_dir,
                                        'step[%d]_gen_test.png' % global_step)
                model.generate_images(img_file, sess, z, block_shape=[8, 8],
                                      batch_size=args.batch_size,
                                      dec_output_2_img_func=binary_float_to_uint8)
                # ------------------------- #

            if global_step % args.viz_rec_freq == 0:
                # Reconstruct images
                # ------------------------- #
                x = x_train[np.random.choice(num_train, size=64, replace=False)]
                img_file = os.path.join(img_rec_dir,
                                        'step[%d]_rec_test.png' % global_step)
                model.reconstruct_images(img_file, sess, x, block_shape=[8, 8],
                                         batch_size=args.batch_size,
                                         dec_output_2_img_func=binary_float_to_uint8)
                # ------------------------- #

            if global_step % args.viz_itpl_freq == 0:
                # Interpolate images
                # ------------------------- #
                x1 = x_train[np.random.choice(num_train, size=12, replace=False)]
                x2 = x_train[np.random.choice(num_train, size=12, replace=False)]
                img_file = os.path.join(img_itpl_dir,
                                        'step[%d]_itpl_test.png' % global_step)
                model.interpolate_images(img_file, sess, x1, x2,
                                         num_itpl_points=12, batch_on_row=True,
                                         batch_size=args.batch_size,
                                         dec_output_2_img_func=binary_float_to_uint8)
                # ------------------------- #

    # Last save
    train_helper.save(sess, global_step)
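# --------------------------------------------------------------------- #
# Hedged sketch, not part of the original repo: the tc_loss optimised above
# relies on FactorVAE's density-ratio trick. A discriminator D(z) is
# trained to separate encoder samples q(z) from dimension-wise shuffled
# ("factorised") samples, and TC(z) ~= E_q(z)[log D(z) - log(1 - D(z))].
# The hypothetical helpers below illustrate both pieces.
import numpy as np  # (already imported at the top of this file)


def permute_dims_sketch(z, rng=np.random):
    """Shuffle each latent dimension independently across the batch."""
    return np.stack([rng.permutation(z[:, j]) for j in range(z.shape[1])],
                    axis=1)


def tc_estimate_sketch(d_probs, eps=1e-8):
    """Monte-Carlo TC estimate from D's probabilities on q(z) samples."""
    d_probs = np.clip(d_probs, eps, 1.0 - eps)
    return float(np.mean(np.log(d_probs) - np.log(1.0 - d_probs)))
# --------------------------------------------------------------------- #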
def main(args):
    # Create output directory
    # ===================================== #
    args.output_dir = os.path.join(args.output_dir, args.model_name, args.run)
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    else:
        if args.force_rm_dir:
            import shutil
            shutil.rmtree(args.output_dir, ignore_errors=True)
            print("Removed '{}'".format(args.output_dir))
        else:
            raise ValueError("Output directory '{}' existed. 'force_rm_dir' "
                             "must be set to True!".format(args.output_dir))
        os.mkdir(args.output_dir)

    save_args(os.path.join(args.output_dir, 'config.json'), args)
    # pp = pprint.PrettyPrinter(indent=4)
    # pp.pprint(args.__dict__)
    # ===================================== #

    # Specify data
    # ===================================== #
    if args.dataset == "mnist":
        x_shape = [28, 28, 1]
    elif args.dataset == "mnist_3" or args.dataset == "mnistm":
        x_shape = [28, 28, 3]
    elif args.dataset == "svhn" or args.dataset == "cifar10" or args.dataset == "cifar100":
        x_shape = [32, 32, 3]
    else:
        raise ValueError("Do not support dataset '{}'!".format(args.dataset))

    if args.dataset == "cifar100":
        num_classes = 100
    else:
        num_classes = 10

    print("x_shape: {}".format(x_shape))
    print("num_classes: {}".format(num_classes))
    # ===================================== #

    # Load data
    # ===================================== #
    print("Loading {}!".format(args.dataset))
    train_loader = SimpleDataset4SSL()
    train_loader.load_npz_data(args.train_file)
    train_loader.create_ssl_data(args.num_labeled, num_classes=num_classes,
                                 shuffle=True, seed=args.seed)
    if args.input_norm != "applied":
        train_loader.x = uint8_to_binary_float(train_loader.x)
    else:
        print("Input normalization has been applied on train data!")

    test_loader = SimpleDataset()
    test_loader.load_npz_data(args.test_file)
    if args.input_norm != "applied":
        test_loader.x = uint8_to_binary_float(test_loader.x)
    else:
        print("Input normalization has been applied on test data!")

    print("train_l/train_u/test: {}/{}/{}".format(
        train_loader.num_labeled_data, train_loader.num_unlabeled_data,
        test_loader.num_data))

    # import matplotlib.pyplot as plt
    # print("train_l.y[:10]: {}".format(train_l.y[:10]))
    # print("train_u.y[:10]: {}".format(train_u.y[:10]))
    # print("test.y[:10]: {}".format(test.y[:10]))
    # fig, axes = plt.subplots(3, 5)
    # for i in range(5):
    #     axes[0][i].imshow(train_l.x[i])
    #     axes[1][i].imshow(train_u.x[i])
    #     axes[2][i].imshow(test.x[i])
    # plt.show()

    if args.dataset == "mnist":
        train_loader.x = np.expand_dims(train_loader.x, axis=-1)
        test_loader.x = np.expand_dims(test_loader.x, axis=-1)
    elif args.dataset == "mnist_3":
        train_loader.x = np.stack(
            [train_loader.x, train_loader.x, train_loader.x], axis=-1)
        test_loader.x = np.stack(
            [test_loader.x, test_loader.x, test_loader.x], axis=-1)

    # Data Preprocessing + Augmentation
    # ------------------------------------- #
    if args.input_norm == 'none' or args.input_norm == 'applied':
        print("Do not apply any normalization!")
    elif args.input_norm == 'zca':
        print("Apply ZCA whitening on data!")
        normalizer = ZCA()
        normalizer.fit(train_loader.x, eps=1e-5)
        train_loader.x = normalizer.transform(train_loader.x)
        test_loader.x = normalizer.transform(test_loader.x)
    elif args.input_norm == 'standard':
        print("Apply Standardization on data!")
        normalizer = Standardization()
        normalizer.fit(train_loader.x)
        train_loader.x = normalizer.transform(train_loader.x)
        test_loader.x = normalizer.transform(test_loader.x)
    else:
        raise ValueError("Do not support 'input_norm'={}".format(args.input_norm))
    # ------------------------------------- #
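    # Hedged sketch of what a ZCA fit typically does (an assumption, not
    # this repo's ZCA class): with X flattened and mean-centred,
    #     C = np.cov(X, rowvar=False)
    #     U, S, _ = np.linalg.svd(C)
    #     W = U @ np.diag(1.0 / np.sqrt(S + eps)) @ U.T
    # and transform(x) = (x - mean) @ W, which whitens while staying close
    # to image space. The same fitted W is reused on the test set, as done
    # above.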
    # ===================================== #
    # Hyperparameters
    # ===================================== #
    hyper_updater = HyperParamUpdater(
        ['lr', 'ema_momentum', 'cent_u_coeff', 'cons_coeff'],
        [args.lr_max, args.ema_momentum_init,
         args.cent_u_coeff_max, args.cons_coeff_max],
        scope='moving_hyperparams')

    # ===================================== #
    # Build model
    # ===================================== #
    # IMPORTANT: Remember to test with no Gaussian noise
    print("args.gauss_noise: {}".format(args.gauss_noise))
    if args.model_name == "9310gaurav":
        main_classifier = MainClassifier_9310gaurav(
            num_classes=num_classes, use_gauss_noise=args.gauss_noise)
    else:
        raise ValueError("Do not support model_name='{}'!".format(args.model_name))

    # Input Perturber
    # ------------------------------------- #
    # The input perturber only performs 'translating_pixels' here
    # (for both CIFAR-10 and SVHN)
    input_perturber = InputPerturber(
        normalizer=None,  # We do not use a normalizer here!
        flip_horizontally=args.flip_horizontally,
        flip_vertically=False,  # We do not flip images vertically!
        translating_pixels=args.translating_pixels,
        noise_std=0.0)  # We do not add noise here!
    # ------------------------------------- #

    # Main model
    # ------------------------------------- #
    model = MeanTeacher(x_shape=x_shape, y_shape=num_classes,
                        main_classifier=main_classifier,
                        input_perturber=input_perturber,
                        cons_mode=args.cons_mode,
                        ema_momentum=hyper_updater.variables['ema_momentum'],
                        cons_4_unlabeled_only=args.cons_4_unlabeled_only,
                        weight_decay=args.weight_decay)

    loss_coeff_dict = {
        'cross_ent_l': args.cross_ent_l,
        'cond_ent_u': hyper_updater.variables['cent_u_coeff'],
        'cons': hyper_updater.variables['cons_coeff'],
    }
    model.build(loss_coeff_dict)
    SimpleParamPrinter.print_all_params_list(trainable_only=False)
    # ------------------------------------- #

    # ===================================== #
    # Build optimizer
    # ===================================== #
    losses = model.get_loss()
    train_params = model.get_train_params()

    opt_AE = tf.train.MomentumOptimizer(
        learning_rate=hyper_updater.variables['lr'],
        momentum=args.lr_momentum, use_nesterov=True)

    # Contains both the batch norm updates and the teacher param update
    update_ops = model.get_all_update_ops()
    print("update_ops: {}".format(update_ops))

    with tf.control_dependencies(update_ops):
        train_op_AE = opt_AE.minimize(loss=losses['loss'],
                                      var_list=train_params['loss'])

    # ===================================== #
    # Create directories
    # ===================================== #
    asset_dir = make_dir_if_not_exist(os.path.join(args.output_dir, "asset"))
    img_dir = make_dir_if_not_exist(os.path.join(asset_dir, "img"))

    log_dir = make_dir_if_not_exist(os.path.join(args.output_dir, "log"))
    train_log_file = os.path.join(log_dir, "train.log")
    summary_dir = make_dir_if_not_exist(os.path.join(args.output_dir, "summary_tf"))
    model_dir = make_dir_if_not_exist(os.path.join(args.output_dir, "model_tf"))

    # ===================================== #
    # Create session
    # ===================================== #
    config_proto = tf.ConfigProto(allow_soft_placement=True)
    config_proto.gpu_options.allow_growth = True
    config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9
    sess = tf.Session(config=config_proto)

    train_helper = SimpleTrainHelper(log_dir=summary_dir, save_dir=model_dir,
                                     max_to_keep=args.num_save,
                                     max_to_keep_best=args.num_save_best)
    train_helper.initialize(sess, init_variables=True,
                            create_summary_writer=True)

    # ===================================== #
    # Start training
    # ===================================== #
    # Summarizer
    # ------------------------------------- #
    fetch_keys_AE_l = ['acc_y_l', 'cross_ent_l']
    fetch_keys_AE_u = ['acc_y_u', 'cond_ent_u', 'cons']
    # We compare the MDL loss with xent+consistency to see whether the MDL
    # loss is a better indicator of generalization
    fetch_keys_AE = fetch_keys_AE_l + fetch_keys_AE_u

    train_summarizer = ScalarSummarizer([(key, 'mean') for key in fetch_keys_AE])

    fetch_keys_test = ['acc_y', 'acc_y_stu']
    eval_summarizer = ScalarSummarizer([(key, 'mean') for key in fetch_keys_test])
    # ------------------------------------- #

    # Data sampler
    # ------------------------------------- #
    # The number of labeled examples per batch varies during training
    if args.batch_size_labeled <= 0:
        sampler = ContinuousIndexSampler(train_loader.num_data,
                                         args.batch_size, shuffle=True)
        sampling_separately = False
        print("batch_size_l, batch_size_u vary but their sum={}!".format(
            args.batch_size))
    elif 0 < args.batch_size_labeled < args.batch_size:
        batch_size_l = args.batch_size_labeled
        batch_size_u = args.batch_size - args.batch_size_labeled
        print("batch_size_l/batch_size_u: {}/{}".format(batch_size_l, batch_size_u))

        # IMPORTANT: Here we must use 'train_loader.labeled_ids' and
        # 'train_loader.unlabeled_ids', NOT 'train_loader.num_labeled_data'
        # and 'train_loader.num_unlabeled_data'
        sampler_l = ContinuousIndexSampler(train_loader.labeled_ids,
                                           batch_size_l, shuffle=True)
        sampler_u = ContinuousIndexSampler(train_loader.unlabeled_ids,
                                           batch_size_u, shuffle=True)
        sampler = ContinuousIndexSamplerGroup(sampler_l, sampler_u)
        sampling_separately = True
    else:
        raise ValueError("'args.batch_size_labeled' must be in ({}, {})!".format(
            0, args.batch_size))
    # ------------------------------------- #

    # Annealer
    # ------------------------------------- #
    step_rampup_annealer = StepAnnealing(args.rampup_len_step, value_0=0, value_1=1)
    sigmoid_rampup_annealer = SigmoidRampup(args.rampup_len_step)
    sigmoid_rampdown_annealer = SigmoidRampdown(args.rampdown_len_step, args.steps)
    # ------------------------------------- #

    # Results Tracker
    # ------------------------------------- #
    tracker = BestResultsTracker([('acc_y', 'greater')],
                                 num_best=args.num_save_best)
    # ------------------------------------- #

    import math
    batches_per_epoch = int(math.ceil(train_loader.num_data / args.batch_size))

    global_step = 0
    log_time_start = time()
    for epoch in range(args.epochs):
        if global_step >= args.steps:
            break

        for batch in range(batches_per_epoch):
            if global_step >= args.steps:
                break
            global_step += 1

            # Update hyper parameters
            # ---------------------------------- #
            step_rampup = step_rampup_annealer.get_value(global_step)
            sigmoid_rampup = sigmoid_rampup_annealer.get_value(global_step)
            sigmoid_rampdown = sigmoid_rampdown_annealer.get_value(global_step)

            lr = sigmoid_rampup * sigmoid_rampdown * args.lr_max
            ema_momentum = (1.0 - step_rampup) * args.ema_momentum_init + \
                step_rampup * args.ema_momentum_final
            cent_u_coeff = sigmoid_rampup * args.cent_u_coeff_max
            cons_coeff = sigmoid_rampup * args.cons_coeff_max

            hyper_updater.update(sess, feed_dict={
                'lr': lr,
                'ema_momentum': ema_momentum,
                'cent_u_coeff': cent_u_coeff,
                'cons_coeff': cons_coeff,
            })

            hyper_vals = hyper_updater.get_value(sess)
            hyper_vals['sigmoid_rampup'] = sigmoid_rampup
            hyper_vals['sigmoid_rampdown'] = sigmoid_rampdown
            hyper_vals['step_rampup'] = step_rampup
            # ---------------------------------- #
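            # Hedged note (assuming the standard Mean Teacher schedules;
            # not confirmed from this repo's annealer classes): with
            # t = clip(step / rampup_len, 0, 1), SigmoidRampup is typically
            #     exp(-5.0 * (1.0 - t) ** 2)
            # and SigmoidRampdown mirrors it over the last rampdown_len
            # steps, so lr warms up and then decays to 0 by args.steps,
            # while the consistency/entropy weights only ramp up.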
            # Train model
            # ---------------------------------- #
            if sampling_separately:
                # print("Sample separately!")
                batch_ids_l, batch_ids_u = sampler.sample_group_of_ids()

                xl, yl, label_flag_l = train_loader.fetch_batch(batch_ids_l)
                xu, yu, label_flag_u = train_loader.fetch_batch(batch_ids_u)
                assert np.all(label_flag_l), "'label_flag_l: {}'".format(label_flag_l)
                assert not np.any(label_flag_u), "'label_flag_u: {}'".format(label_flag_u)

                x = np.concatenate([xl, xu], axis=0)
                y = np.concatenate([yl, yu], axis=0)
                label_flag = np.concatenate([label_flag_l, label_flag_u], axis=0)
            else:
                # print("Sample jointly!")
                batch_ids = sampler.sample_ids()
                x, y, label_flag = train_loader.fetch_batch(batch_ids)

            _, AEm = sess.run(
                [train_op_AE, model.get_output(fetch_keys_AE, as_dict=True)],
                feed_dict={model.is_train: True,
                           model.x_ph: x, model.y_ph: y,
                           model.label_flag_ph: label_flag})

            batch_results = AEm
            train_summarizer.accumulate(batch_results, args.batch_size)
            # ---------------------------------- #

            if global_step % args.save_freq == 0:
                train_helper.save(sess, global_step)

            if global_step % args.log_freq == 0:
                log_time_end = time()
                log_time_gap = log_time_end - log_time_start
                log_time_start = log_time_end

                summaries, results = train_summarizer.get_summaries_and_reset(
                    summary_prefix='train')
                train_helper.add_summaries(summaries, global_step)
                train_helper.add_summaries(custom_tf_scalar_summaries(
                    hyper_vals, prefix="moving_hyper"), global_step)

                log_str = "\n[MeanTeacher ({})/{}, {}], " \
                          "Epoch {}/{}, Batch {}/{} Step {} ({:.2f}s) (train)".format(
                              args.dataset, args.model_name, args.run,
                              epoch, args.epochs, batch, batches_per_epoch,
                              global_step - 1, log_time_gap) + \
                    "\n" + ", ".join(["{}: {:.4f}".format(key, results[key])
                                      for key in fetch_keys_AE_l]) + \
                    "\n" + ", ".join(["{}: {:.4f}".format(key, results[key])
                                      for key in fetch_keys_AE_u]) + \
                    "\n" + ", ".join(["{}: {:.4f}".format(key, hyper_vals[key])
                                      for key in hyper_vals])
                print(log_str)

                with open(train_log_file, "a") as f:
                    f.write(log_str)
                    f.write("\n")

            if global_step % args.eval_freq == 0:
                for batch_ids in iterate_data(test_loader.num_data, args.batch_size,
                                              shuffle=False, include_remaining=True):
                    x, y = test_loader.fetch_batch(batch_ids)
                    batch_results = sess.run(
                        model.get_output(fetch_keys_test, as_dict=True),
                        feed_dict={model.is_train: False,
                                   model.x_ph: x, model.y_ph: y})
                    eval_summarizer.accumulate(batch_results, len(batch_ids))

                summaries, results = eval_summarizer.get_summaries_and_reset(
                    summary_prefix='test')
                train_helper.add_summaries(summaries, global_step)

                log_str = "Epoch {}/{}, Batch {}/{} (test), " \
                          "acc_y: {:.4f}, acc_y_stu: {:.4f}".format(
                              epoch, args.epochs, batch, batches_per_epoch,
                              results['acc_y'], results['acc_y_stu'])
                print(log_str)

                with open(train_log_file, "a") as f:
                    f.write(log_str)
                    f.write("\n")

                is_better = tracker.check_and_update(results, global_step)
                if is_better['acc_y']:
                    train_helper.save_best(sess, global_step=global_step)

    # Last save
    train_helper.save(sess, global_step)
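# --------------------------------------------------------------------- #
# Hedged sketch, not part of the original repo: the teacher update hidden
# inside model.get_all_update_ops() is, in Mean Teacher, an exponential
# moving average of the student weights. Per step, with momentum m:
#     theta_teacher <- m * theta_teacher + (1 - m) * theta_student
# A minimal NumPy illustration (hypothetical helper):
def ema_update_sketch(teacher_params, student_params, momentum):
    """In-place EMA update of a list of NumPy weight arrays."""
    for t, s in zip(teacher_params, student_params):
        t *= momentum
        t += (1.0 - momentum) * s
# --------------------------------------------------------------------- #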
def rand_2_latents_traverse(self, save_file, sess, default_z,
                            z_comp1, start1, stop1, num_points1,
                            z_comp2, start2, stop2, num_points2,
                            batch_size=20, dec_output_2_img_func=None,
                            **kwargs):
    """
    default_z: A single latent code to serve as the default
    z_comp1: Index of the first z component to traverse
    start1, stop1: Low/high values of z_comp1
    num_points1: Number of traversal points for z_comp1
    z_comp2: Index of the second z component to traverse
    start2, stop2: Low/high values of z_comp2
    num_points2: Number of traversal points for z_comp2
    """
    assert num_points1 >= 2, "'num_points1' must be >=2. Found {}!".format(num_points1)
    assert num_points2 >= 2, "'num_points2' must be >=2. Found {}!".format(num_points2)

    z_range1 = [start1 + (stop1 - start1) * i * 1.0 / (num_points1 - 1)
                for i in range(num_points1)]
    z_range2 = [start2 + (stop2 - start2) * i * 1.0 / (num_points2 - 1)
                for i in range(num_points2)]

    num_rows = len(z_range1)
    num_cols = len(z_range2)

    assert np.shape(default_z) == tuple(self.z_shape), \
        "'default_z' must be a single instance!"
    default_z = np.reshape(default_z, [int(np.prod(self.z_shape))])

    # One copy of 'default_z' per grid cell; each cell then overwrites its
    # two traversed components
    z_meshgrid = np.tile(np.expand_dims(default_z, axis=0),
                         [num_rows * num_cols, 1])
    for m in range(num_rows):
        for n in range(num_cols):
            z_meshgrid[m * num_cols + n, z_comp1] = z_range1[m]
            z_meshgrid[m * num_cols + n, z_comp2] = z_range2[n]

    # Decode the z meshgrid into an image grid
    # ----------------------------- #
    if batch_size < 0:
        x_meshgrid = self.decode(sess, z_meshgrid, **kwargs)
    else:
        x_meshgrid = []
        for batch_ids in iterate_data(len(z_meshgrid), batch_size, shuffle=False):
            x_meshgrid.append(self.decode(sess, z_meshgrid[batch_ids], **kwargs))
        x_meshgrid = np.concatenate(x_meshgrid, axis=0)

    x_meshgrid = np.reshape(x_meshgrid, [num_rows, num_cols] + self.x_shape)
    if dec_output_2_img_func is not None:
        x_meshgrid = dec_output_2_img_func(x_meshgrid)

    save_img_block(save_file, x_meshgrid)
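# --------------------------------------------------------------------- #
# Hedged usage sketch (the 'model', 'sess', and 'z_sample' names are
# hypothetical): render an 8x10 grid that sweeps latent components 2 and 5
# of a trained model while keeping the remaining components fixed at a
# single latent code.
#
#     model.rand_2_latents_traverse(
#         "traverse_z2_z5.png", sess, default_z=z_sample,
#         z_comp1=2, start1=-3.0, stop1=3.0, num_points1=8,
#         z_comp2=5, start2=-3.0, stop2=3.0, num_points2=10,
#         batch_size=20, dec_output_2_img_func=binary_float_to_uint8)
# --------------------------------------------------------------------- #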