def main(args):
    """Plot dSprites latent traversals ordered by each latent's SEP score.

    Rebuilds the FactorVAE described by ``<args.output_dir>/config.json``,
    restores its checkpoint, reads the ``MI_zi_x`` and ``SEP_zi`` arrays saved
    under ``args.SEPIN_dir`` and, for one image of each dSprites shape, saves a
    conditional traversal of all latents sorted by decreasing SEP.

    Args:
        args: command-line namespace. It is updated in place with the contents
            of ``config.json`` before any other field is read.

    Raises:
        ValueError: if ``args.enc_dec_model`` is not ``"1Konny"``.
    """
    import sys

    # Load config
    # =========================================== #
    with open(join(args.output_dir, 'config.json')) as f:
        config = json.load(f)
    args.__dict__.update(config)

    # Load dataset
    # =========================================== #
    data_file = join(RAW_DATA_DIR, "ComputerVision", "dSprites",
                     "dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz")

    # It is already in the range [0, 1]
    with np.load(data_file, encoding="latin1") as f:
        x_train = f['imgs']

    # Axes follow the file name (sh3 sc6 or40 x32 y32): shape, scale,
    # orientation, pos_x, pos_y, then the 64x64x1 image.
    x_train = np.reshape(x_train, [3, 6, 40, 32, 32, 64, 64, 1])

    # Build model
    # =========================================== #
    if args.enc_dec_model == "1Konny":
        encoder = Encoder_1Konny(args.z_dim, stochastic=True)
        decoder = Decoder_1Konny()
        disc_z = DiscriminatorZ_1Konny(num_outputs=2)
    else:
        raise ValueError("Do not support enc_dec_model='{}'!".format(
            args.enc_dec_model))

    model = FactorVAE([64, 64, 1], args.z_dim,
                      encoder=encoder, decoder=decoder,
                      discriminator_z=disc_z,
                      rec_x_mode=args.rec_x_mode,
                      use_gp0_z_tc=True, gp0_z_tc_mode=args.gp0_z_tc_mode)

    loss_coeff_dict = {
        'rec_x': args.rec_x_coeff,
        'kld_loss': args.kld_loss_coeff,
        'tc_loss': args.tc_loss_coeff,
        'gp0_z_tc': args.gp0_z_tc_coeff,
        'Dz_tc_loss_coeff': args.Dz_tc_loss_coeff,
    }

    model.build(loss_coeff_dict)
    SimpleParamPrinter.print_all_params_tf_slim()

    # Initialize session
    # =========================================== #
    config_proto = tf.ConfigProto(allow_soft_placement=True)
    config_proto.gpu_options.allow_growth = True
    config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9
    sess = tf.Session(config=config_proto)

    model_dir = make_dir_if_not_exist(join(args.output_dir, "model_tf"))
    train_helper = SimpleTrainHelper(log_dir=None, save_dir=model_dir)
    train_helper.load(sess, load_step=args.load_step)

    save_dir = make_dir_if_not_exist(
        join(args.save_dir, "{}_{}".format(args.enc_dec_model, args.run)))

    # Load result file
    # =========================================== #
    result_file = join(args.SEPIN_dir, "{}_{}".format(args.enc_dec_model, args.run),
                       "results[num_samples={}].npz".format(args.num_samples))

    # Indexing an NpzFile reads the member array into memory, so the file can
    # be closed as soon as the needed arrays have been pulled out.
    with np.load(result_file, "r") as results:
        print("results.keys: {}".format(list(results.keys())))

        # (num_latents,)
        MI_zi_x = results['MI_zi_x']
        SEP_zi = results['SEP_zi']

    # NumPy >= 1.16 rejects threshold=np.nan with a ValueError; sys.maxsize is
    # the supported way to disable array summarization.
    np.set_printoptions(threshold=sys.maxsize, linewidth=1000, precision=3, suppress=True)

    # Plotting
    # =========================================== #
    # One image per shape, at fixed scale (3), orientation (20) and
    # position (16, 16) indices.
    data = [x_train[shape, 3, 20, 16, 16] for shape in range(3)]

    # Latent indices ordered from highest to lowest SEP.
    ids_sorted = np.argsort(SEP_zi, axis=0)[::-1]

    print("")
    print("MI_zi_x: {}".format(MI_zi_x))
    print("SEP_zi: {}".format(SEP_zi))
    print("ids_sorted: {}".format(ids_sorted))

    span = 3
    points_one_side = 5

    for n, x in enumerate(data):
        img_file = join(
            save_dir, "sep_x{}_num_samples={}.png".format(n, args.num_samples))

        model.cond_all_latents_traverse_v2(
            img_file, sess, x,
            z_comps=ids_sorted,
            z_comp_labels=[
                "z[{}] (SEP={:.4f}, INFO={:.4f})".format(idx, SEP_zi[idx], MI_zi_x[idx])
                for idx in ids_sorted
            ],
            span=span, points_1_side=points_one_side,
            hl_x=True,
            font_size=9,
            title_font_scale=1.5,
            subplot_adjust={'left': 0.55, 'right': 0.99, 'bottom': 0.01, 'top': 0.99},
            size_inches=(4.0, 1.7),
            batch_size=args.batch_size,
            dec_output_2_img_func=binary_float_to_uint8)

    sess.close()
def main(args):
    """Plot CelebA latent traversals for the latents ranked by informativeness.

    Rebuilds the FactorVAE described by ``<args.output_dir>/config.json``,
    restores its checkpoint, loads the informativeness results for this run
    from ``args.informativeness_metrics_dir`` and, for a fixed training image,
    saves traversals of the ``args.top_k`` latents under both the raw and the
    normalized MI ranking.

    Args:
        args: command-line namespace. It is updated in place with the contents
            of ``config.json``; ``bin_limits`` is expected to be a
            ``';'``-separated string of floats.

    Raises:
        ValueError: if ``args.activation`` or ``args.enc_dec_model`` is not
            supported.
    """
    import sys

    # =====================================
    # Load config
    # =====================================
    with open(join(args.output_dir, 'config.json')) as f:
        config = json.load(f)
    args.__dict__.update(config)

    # =====================================
    # Dataset
    # =====================================
    celebA_loader = TFCelebALoader(root_dir=args.celebA_root_dir)
    img_height, img_width = args.celebA_resize_size, args.celebA_resize_size
    celebA_loader.build_transformation_flow_tf(
        *celebA_loader.get_transform_fns("1Konny", resize_size=args.celebA_resize_size))

    # =====================================
    # Instantiate model
    # =====================================
    if args.activation == "relu":
        activation = tf.nn.relu
    elif args.activation == "leaky_relu":
        activation = tf.nn.leaky_relu
    else:
        raise ValueError("Do not support '{}' activation!".format(args.activation))

    if args.enc_dec_model == "1Konny":
        encoder = Encoder_1Konny(args.z_dim, stochastic=True, activation=activation)
        decoder = Decoder_1Konny([img_height, img_width, 3], activation=activation,
                                 output_activation=tf.nn.sigmoid)
        disc_z = DiscriminatorZ_1Konny(num_outputs=2)
    else:
        raise ValueError("Do not support encoder/decoder model '{}'!".format(args.enc_dec_model))

    model = FactorVAE([img_height, img_width, 3], args.z_dim,
                      encoder=encoder, decoder=decoder,
                      discriminator_z=disc_z,
                      rec_x_mode=args.rec_x_mode,
                      use_gp0_z_tc=True, gp0_z_tc_mode=args.gp0_z_tc_mode)

    loss_coeff_dict = {
        'rec_x': args.rec_x_coeff,
        'kld_loss': args.kld_loss_coeff,
        'tc_loss': args.tc_loss_coeff,
        'gp0_z_tc': args.gp0_z_tc_coeff,
    }

    model.build(loss_coeff_dict)
    SimpleParamPrinter.print_all_params_tf_slim()

    # =====================================
    # Load model
    # =====================================
    config_proto = tf.ConfigProto(allow_soft_placement=True)
    config_proto.gpu_options.allow_growth = True
    config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9
    sess = tf.Session(config=config_proto)

    model_dir = make_dir_if_not_exist(join(args.output_dir, "model_tf"))
    train_helper = SimpleTrainHelper(log_dir=None, save_dir=model_dir)
    train_helper.load(sess, load_step=args.load_step)

    # =====================================
    # Experiments
    # =====================================
    save_dir = make_dir_if_not_exist(join(args.save_dir, "FactorVAE_{}".format(args.run)))

    # NumPy >= 1.16 rejects threshold=np.nan with a ValueError; sys.maxsize is
    # the supported way to disable array summarization.
    np.set_printoptions(threshold=sys.maxsize, linewidth=1000, precision=3, suppress=True)

    num_bins = args.num_bins
    data_proportion = args.data_proportion
    bin_limits = tuple(float(s) for s in args.bin_limits.split(";"))
    top_k = args.top_k

    log_path = join(save_dir, 'log[bins={},bin_limits={},data={}].txt'.format(
        num_bins, bin_limits, data_proportion))

    # The context manager guarantees the log is flushed and closed even if a
    # later step (e.g. plotting) fails.
    with open(log_path, mode='w') as f:
        # print_both is a project helper; presumably it writes to stdout and to `file`.
        print_ = functools.partial(print_both, file=f)

        result_file = join(args.informativeness_metrics_dir, "FactorVAE_{}".format(args.run),
                           'results[bins={},bin_limits={},data={}].npz'.format(
                               num_bins, bin_limits, data_proportion))

        # Indexing an NpzFile reads the member array into memory, so the file
        # can be closed right after the arrays are pulled out.
        with np.load(result_file, "r") as results:
            sorted_MI = results["sorted_MI_z_x"]
            sorted_z_ids = results["sorted_z_comps"]
            sorted_norm_MI = results["sorted_norm_MI_z_x"]
            sorted_norm_z_ids = results["sorted_norm_z_comps"]

        print_("")
        print_("num_bins: {}".format(num_bins))
        print_("bin_limits: {}".format(bin_limits))
        print_("data_proportion: {}".format(data_proportion))
        print_("top_k: {}".format(top_k))

        # Plotting
        # =========================================== #
        # Fixed training sample (seed=389 with num_samples=30 is another
        # setting that has been used).
        seed = 398
        num_samples = 1

        ids = list(range(seed, seed + num_samples))
        print_("\nids: {}".format(ids))

        data = celebA_loader.sample_images_from_dataset(sess, 'train', ids)

        span = 3
        points_one_side = 5

        print_("sorted_MI: {}".format(sorted_MI))
        print_("sorted_z_ids: {}".format(sorted_z_ids))
        print_("sorted_norm_MI: {}".format(sorted_norm_MI))
        print_("sorted_norm_z_ids: {}".format(sorted_norm_z_ids))

        # (file-name suffix, latent ids, MI values) for the raw and the
        # normalized ranking; each one produces its own figure per sample.
        rankings = [
            ("", sorted_z_ids[:top_k], sorted_MI[:top_k]),
            (",norm", sorted_norm_z_ids[:top_k], sorted_norm_MI[:top_k]),
        ]

        print("Matplotlib font size: {}".format(matplotlib.rcParams['font.size']))

        # The figure layout depends only on top_k: a wide figure at the current
        # rc font size for exactly 10 latents, a tall narrow one otherwise.
        if top_k == 10:
            layout = {
                'font_size': matplotlib.rcParams['font.size'],
                'subplot_adjust': {'left': 0.15, 'right': 0.99, 'bottom': 0.01, 'top': 0.99},
                'size_inches': (6.3, 4.9),
            }
        else:
            layout = {
                'font_size': 5,
                'subplot_adjust': {'left': 0.19, 'right': 0.99, 'bottom': 0.01, 'top': 0.99},
                'size_inches': (2.98, 9.85),
            }

        for n, sample_id in enumerate(ids):
            print("Plot conditional all comps z traverse with train sample {}!".format(sample_id))

            for suffix, z_ids, mi_values in rankings:
                img_file = join(save_dir, "x_train[{}][bins={},bin_limits={},data={}{}].png".format(
                    sample_id, num_bins, bin_limits, data_proportion, suffix))

                model.cond_all_latents_traverse_v2(
                    img_file, sess, data[n],
                    z_comps=z_ids,
                    z_comp_labels=["z[{}] ({:.2f})".format(comp, mi)
                                   for comp, mi in zip(z_ids, mi_values)],
                    span=span, points_1_side=points_one_side,
                    hl_x=True,
                    batch_size=args.batch_size,
                    dec_output_2_img_func=binary_float_to_uint8,
                    **layout)
        # =========================================== #

    sess.close()
def main(args):
    """Report RMIG/JEMMIG for each dSprites factor and plot its top latents.

    Rebuilds the FactorVAE described by ``<args.output_dir>/config.json``,
    restores its checkpoint, reads the metric arrays saved under
    ``args.JEMMIG_sampling_dir`` and prints RMIG and JEMMIG per ground-truth
    factor. For one image of each dSprites shape and each factor, it then
    saves a traversal of the 3 top-ranked latents for that factor.

    Args:
        args: command-line namespace. It is updated in place with the contents
            of ``config.json`` before any other field is read.

    Raises:
        ValueError: if ``args.enc_dec_model`` is not ``"1Konny"``.
    """
    import sys

    # =====================================
    # Load config
    # =====================================
    with open(join(args.output_dir, 'config.json')) as f:
        config = json.load(f)
    args.__dict__.update(config)

    # =====================================
    # Dataset
    # =====================================
    data_file = join(RAW_DATA_DIR, "ComputerVision", "dSprites",
                     "dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz")

    # It is already in the range [0, 1]
    with np.load(data_file, encoding="latin1") as f:
        x_train = f['imgs']

    # Axes follow the file name (sh3 sc6 or40 x32 y32): shape, scale,
    # orientation, pos_x, pos_y, then the 64x64x1 image.
    x_train = np.reshape(x_train, [3, 6, 40, 32, 32, 64, 64, 1])

    # =====================================
    # Instantiate model
    # =====================================
    if args.enc_dec_model == "1Konny":
        encoder = Encoder_1Konny(args.z_dim, stochastic=True)
        decoder = Decoder_1Konny()
        disc_z = DiscriminatorZ_1Konny(num_outputs=2)
    else:
        raise ValueError("Do not support enc_dec_model='{}'!".format(
            args.enc_dec_model))

    model = FactorVAE([64, 64, 1], args.z_dim,
                      encoder=encoder, decoder=decoder,
                      discriminator_z=disc_z,
                      rec_x_mode=args.rec_x_mode,
                      use_gp0_z_tc=True, gp0_z_tc_mode=args.gp0_z_tc_mode)

    loss_coeff_dict = {
        'rec_x': args.rec_x_coeff,
        'kld_loss': args.kld_loss_coeff,
        'tc_loss': args.tc_loss_coeff,
        'gp0_z_tc': args.gp0_z_tc_coeff,
        'Dz_tc_loss_coeff': args.Dz_tc_loss_coeff,
    }

    model.build(loss_coeff_dict)
    SimpleParamPrinter.print_all_params_tf_slim()

    # =====================================
    # Load model
    # =====================================
    config_proto = tf.ConfigProto(allow_soft_placement=True)
    config_proto.gpu_options.allow_growth = True
    config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9
    sess = tf.Session(config=config_proto)

    model_dir = make_dir_if_not_exist(join(args.output_dir, "model_tf"))
    train_helper = SimpleTrainHelper(log_dir=None, save_dir=model_dir)
    train_helper.load(sess, load_step=args.load_step)

    # =====================================
    # Experiments
    # =====================================
    save_dir = make_dir_if_not_exist(
        join(args.save_dir, "{}_{}".format(args.enc_dec_model, args.run)))

    # NumPy >= 1.16 rejects threshold=np.nan with a ValueError; sys.maxsize is
    # the supported way to disable array summarization.
    np.set_printoptions(threshold=sys.maxsize, linewidth=1000, precision=3, suppress=True)

    result_file = join(args.JEMMIG_sampling_dir, "{}_{}".format(args.enc_dec_model, args.run),
                       "results[num_samples={}].npz".format(args.num_samples))

    # Indexing an NpzFile reads the member array into memory, so the file can
    # be closed right after the arrays are pulled out.
    with np.load(result_file, "r") as results:
        print("results.keys: {}".format(list(results.keys())))

        # 2-D arrays: axis 0 is the latent's rank for a factor (row 0 = top),
        # axis 1 indexes the ground-truth factors.
        ids_sorted = results['id_sorted']
        MI_zi_yk_sorted = results['MI_zi_yk_sorted']
        H_zi_yk_sorted = results['H_zi_yk_sorted']
        # 1-D arrays indexed like gt_factors.
        H_yk = results['H_yk']
        RMIG_yk = results['RMIG_yk']
        RMIG_norm_yk = results['RMIG_norm_yk']
        JEMMIG_yk = results['JEMMIG_yk']

    # Plotting
    # =========================================== #
    # One image per shape, at fixed scale (3), orientation (20) and
    # position (16, 16) indices.
    data = [x_train[shape, 3, 20, 16, 16] for shape in range(3)]
    gt_factors = ['Shape', 'Scale', 'Rotation', 'Pos_x', 'Pos_y']

    print("MI_zi_yk_sorted:\n{}".format(MI_zi_yk_sorted))

    print("\nShow RMIG!")
    for k, factor in enumerate(gt_factors):
        print("{}, RMIG: {:.4f}, RMIG (norm): {:.4f}, H: {:.4f}, I1: {:.4f}, I2: {:.4f}".format(
            factor, RMIG_yk[k], RMIG_norm_yk[k], H_yk[k],
            MI_zi_yk_sorted[0, k], MI_zi_yk_sorted[1, k]))

    print("\nShow JEMMIG!")
    for k, factor in enumerate(gt_factors):
        print("{}, JEMMIG: {:.4f}, H1: {:.4f}, H1-I1: {:.4f}, I2: {:.4f}, top2 ids: z{}, z{}".format(
            factor, JEMMIG_yk[k], H_zi_yk_sorted[0, k],
            H_zi_yk_sorted[0, k] - MI_zi_yk_sorted[0, k],
            MI_zi_yk_sorted[1, k], ids_sorted[0, k], ids_sorted[1, k]))

    span = 3
    points_one_side = 5

    for n, x in enumerate(data):
        for k, factor in enumerate(gt_factors):
            print("x={}, y={}!".format(n, factor))

            # Fixed: the file name previously ended in a stray "]" ("...={}].png").
            img_file = join(
                save_dir, "{}-x{}_num_samples={}.png".format(factor, n, args.num_samples))

            ids_top3 = ids_sorted[:3, k]
            MI_top3 = MI_zi_yk_sorted[:3, k]

            model.cond_all_latents_traverse_v2(
                img_file, sess, x,
                z_comps=ids_top3,
                z_comp_labels=["z[{}] ({:.4f})".format(comp, mi)
                               for comp, mi in zip(ids_top3, MI_top3)],
                span=span, points_1_side=points_one_side,
                hl_x=True,
                font_size=9,
                title="{} (RMIG={:.4f}, JEMMIG={:.4f}, H={:.4f})".format(
                    factor, RMIG_yk[k], JEMMIG_yk[k], H_yk[k]),
                title_font_scale=1.5,
                subplot_adjust={'left': 0.16, 'right': 0.99, 'bottom': 0.01, 'top': 0.88},
                size_inches=(6.2, 1.7),
                batch_size=args.batch_size,
                dec_output_2_img_func=binary_float_to_uint8)

    sess.close()
def main(args):
    """Report RMIG and JEMMIG for every CelebA attribute and plot them.

    Rebuilds the FactorVAE described by ``<args.output_dir>/config.json``,
    restores its checkpoint and reads the interpretability results saved
    under ``args.interpretability_metrics_dir``. For every attribute it logs
    RMIG and a JEMMIG computed from the stored ``H_z_y`` / ``MI_z_y`` arrays,
    saves a JEMMIG/RMIG bar chart to ``JEMMIG_tc_beta.pdf`` and saves
    traversals of the 5 highest-MI latents for a fixed training image.

    Args:
        args: command-line namespace. It is updated in place with the contents
            of ``config.json``; ``bin_limits`` is expected to be a
            ``';'``-separated string of floats.

    Raises:
        ValueError: if ``args.activation`` or ``args.enc_dec_model`` is not
            supported, or if the array sizes in the result file disagree with
            the number of attributes.
    """
    import sys

    import matplotlib
    import matplotlib.pyplot as plt
    from matplotlib.backends.backend_pdf import PdfPages

    # =====================================
    # Load config
    # =====================================
    with open(join(args.output_dir, 'config.json')) as f:
        config = json.load(f)
    args.__dict__.update(config)

    # =====================================
    # Dataset
    # =====================================
    celebA_loader = TFCelebAWithAttrLoader(root_dir=args.celebA_root_dir)
    img_height, img_width = args.celebA_resize_size, args.celebA_resize_size
    celebA_loader.build_transformation_flow_tf(
        *celebA_loader.get_transform_fns("1Konny", resize_size=args.celebA_resize_size))

    # =====================================
    # Instantiate model
    # =====================================
    if args.activation == "relu":
        activation = tf.nn.relu
    elif args.activation == "leaky_relu":
        activation = tf.nn.leaky_relu
    else:
        raise ValueError("Do not support '{}' activation!".format(
            args.activation))

    if args.enc_dec_model == "1Konny":
        encoder = Encoder_1Konny(args.z_dim, stochastic=True, activation=activation)
        decoder = Decoder_1Konny([img_height, img_width, 3], activation=activation,
                                 output_activation=tf.nn.sigmoid)
        disc_z = DiscriminatorZ_1Konny(num_outputs=2)
    else:
        raise ValueError("Do not support encoder/decoder model '{}'!".format(
            args.enc_dec_model))

    model = FactorVAE([img_height, img_width, 3], args.z_dim,
                      encoder=encoder, decoder=decoder,
                      discriminator_z=disc_z,
                      rec_x_mode=args.rec_x_mode,
                      use_gp0_z_tc=True, gp0_z_tc_mode=args.gp0_z_tc_mode)

    loss_coeff_dict = {
        'rec_x': args.rec_x_coeff,
        'kld_loss': args.kld_loss_coeff,
        'tc_loss': args.tc_loss_coeff,
        'gp0_z_tc': args.gp0_z_tc_coeff,
    }

    model.build(loss_coeff_dict)
    SimpleParamPrinter.print_all_params_tf_slim()

    # =====================================
    # Load model
    # =====================================
    config_proto = tf.ConfigProto(allow_soft_placement=True)
    config_proto.gpu_options.allow_growth = True
    config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9
    sess = tf.Session(config=config_proto)

    model_dir = make_dir_if_not_exist(join(args.output_dir, "model_tf"))
    train_helper = SimpleTrainHelper(log_dir=None, save_dir=model_dir)
    train_helper.load(sess, load_step=args.load_step)

    # =====================================
    # Experiments
    # =====================================
    save_dir = make_dir_if_not_exist(
        join(args.save_dir, "FactorVAE_{}".format(args.run)))

    # NumPy >= 1.16 rejects threshold=np.nan with a ValueError; sys.maxsize is
    # the supported way to disable array summarization.
    np.set_printoptions(threshold=sys.maxsize, linewidth=1000, precision=3, suppress=True)

    num_bins = args.num_bins
    bin_limits = tuple(float(s) for s in args.bin_limits.split(";"))
    data_proportion = args.data_proportion

    log_path = join(save_dir, 'log[bins={},bin_limits={},data={}].txt'.format(
        num_bins, bin_limits, data_proportion))
    result_file = join(
        args.interpretability_metrics_dir, "FactorVAE_{}".format(args.run),
        "results[bins={},bin_limits={},data={}].npz".format(
            num_bins, bin_limits, data_proportion))

    # Fixed training sample (seed=389 with num_samples=30 is another setting
    # that has been used).
    seed = 398
    num_samples = 1
    ids = list(range(seed, seed + num_samples))

    span = 3
    points_one_side = 5

    # The log is only written by the metric report, so it is scoped to it; the
    # context manager guarantees it is flushed and closed.
    with open(log_path, mode='w') as f:
        # print_both is a project helper; presumably it writes to stdout and to `file`.
        print_ = functools.partial(print_both, file=f)

        print_("")
        print_("num_bins: {}".format(num_bins))
        print_("bin_limits: {}".format(bin_limits))
        print_("data_proportion: {}".format(data_proportion))

        print_("\nids: {}".format(ids))
        data = celebA_loader.sample_images_from_dataset(sess, 'train', ids)

        attr_names = celebA_loader.attributes
        print_("#attrs: {}".format(len(attr_names)))
        print_("attr_names: {}".format(attr_names))

        # Indexing an NpzFile reads the member array into memory, so the file
        # can be closed right after the arrays are pulled out.
        with np.load(result_file, "r") as results:
            print_("results.keys: {}".format(list(results.keys())))

            # (z_dim, num_attrs)
            MI_ids_sorted = results['MI_ids_sorted']
            MI_sorted = results['MI_sorted']
            MI_gap_y = results['MI_gap_y']
            H_y = results['H_y_4_diff_z'][:, 0]
            H_z_y = results['H_z_y']
            MI_z_y = results['MI_z_y']

        # Validate the external result file explicitly (assert is stripped under -O).
        if not (MI_ids_sorted.shape[1] == len(attr_names) == len(MI_gap_y) == len(H_y)):
            raise ValueError(
                "MI_ids_sorted.shape: {}, len(attr_names): {}, len(MI_gap_y): {}, len(H_y): {}".format(
                    MI_ids_sorted.shape, len(attr_names), len(MI_gap_y), len(H_y)))

        # Per-attribute rows go through print_ so the log holds the numbers,
        # not just the section headers.
        print_("\nShow RMIG!")
        for name, rmig, h in zip(attr_names, MI_gap_y, H_y):
            print_("{}: RMIG: {:.4f}, RMIG (unnorm): {:.4f}, H: {:.4f}".format(
                name, rmig, rmig * h, h))

        print_("\nShow JEMMIG!")
        # Rank latents per attribute by MI (row 0 = highest) and reorder the
        # entropies the same way.
        ids_sorted_by_MI = np.argsort(MI_z_y, axis=0)[::-1]
        MI_z_y_sorted = np.take_along_axis(MI_z_y, ids_sorted_by_MI, axis=0)
        H_z_y_sorted = np.take_along_axis(H_z_y, ids_sorted_by_MI, axis=0)

        # With z1 / z2 the latents of highest / second-highest MI for an
        # attribute: JEMMIG_unnorm = H_z_y[z1] - MI_z_y[z1] + MI_z_y[z2];
        # the normalized score divides by log(num_bins) + H_y.
        H_diff = H_z_y_sorted[0, :] - MI_z_y_sorted[0, :]
        JEMMIG_unnorm = H_diff + MI_z_y_sorted[1, :]
        JEMMIG_norm = JEMMIG_unnorm / (np.log(num_bins) + H_y)

        for i, name in enumerate(attr_names):
            print_(("{}: JEMMIG: {:.4f}, JEMMIG (unnorm): {:.4f}, H_diff: {:.4f}, "
                    "I2: {:.4f}, top 2 latents: z{}, z{}").format(
                name, JEMMIG_norm[i], JEMMIG_unnorm[i], H_diff[i],
                MI_z_y_sorted[1, i], ids_sorted_by_MI[0, i], ids_sorted_by_MI[1, i]))

    # Plot JEMMIG/RMIG of all attributes
    # ========================================= #
    # 'normal' is not a font family matplotlib knows (it only produced
    # "font family not found" warnings); use the generic default family.
    matplotlib.rc('font', family='sans-serif', size=12)

    width = 0.3
    positions = np.arange(len(attr_names))
    plt.bar(positions, JEMMIG_norm, width=width, align='center',
            edgecolor='black', label='JEMMIG')
    plt.bar(positions + width, MI_gap_y, width=width, align='center',
            edgecolor='black', label='RMIG')
    # Centre each tick between its pair of bars.
    plt.xticks(positions + 0.5 * width, attr_names, rotation=90)
    plt.xlabel("attribute")
    plt.ylabel("JEMMIG/RMIG")
    # Without a legend the two bar series cannot be told apart.
    plt.legend()
    plt.gcf().set_size_inches(12, 3)

    save_file = join(save_dir, "JEMMIG_tc_beta.pdf")
    with PdfPages(save_file) as pdf_file:
        # bbox_inches='tight' keeps the rotated attribute names from being cropped.
        plt.savefig(pdf_file, format='pdf', bbox_inches='tight')

    plt.show()
    plt.close()
    # ========================================= #

    # Traversals of the top-ranked latents per attribute
    # ========================================= #
    num_top = 5
    for n, sample_id in enumerate(ids):
        for k, attr_name in enumerate(attr_names):
            top_ids = MI_ids_sorted[:num_top, k]
            top_MI = MI_sorted[:num_top, k]

            print("Plot top {} latents for factor '{}'!".format(num_top, attr_name))

            img_file = join(
                save_dir, "train{}_attr={}_bins={}_data={}.png".format(
                    sample_id, attr_name, num_bins, data_proportion))

            model.cond_all_latents_traverse_v2(
                img_file, sess, data[n],
                z_comps=top_ids,
                z_comp_labels=["z[{}] ({:.4f})".format(comp, mi)
                               for comp, mi in zip(top_ids, top_MI)],
                span=span, points_1_side=points_one_side,
                hl_x=True,
                font_size=9,
                title="{} (MI gap={:.4f}, H={:.4f})".format(
                    attr_name, MI_gap_y[k], H_y[k]),
                title_font_scale=1.5,
                subplot_adjust={'left': 0.16, 'right': 0.99, 'bottom': 0.005, 'top': 0.93},
                size_inches=(6.5, 2.8),
                batch_size=args.batch_size,
                dec_output_2_img_func=binary_float_to_uint8)

    sess.close()
def main(args):
    """Rank CelebA latents by the spread of their posterior means and plot traversals.

    Rebuilds the FactorVAE described by ``<args.output_dir>/config.json``,
    restores its checkpoint and encodes ``int(0.05 * num_train)`` training
    images. The indices come from ``iterate_data(..., shuffle=False)``, which
    presumably yields them in order; confirm. Latents are ranked by the
    standard deviation, across samples, of their posterior means. Latents
    whose average posterior stddev is below 0.4 are reported as informative.
    Traversals of the top-10 ranked latents and of the informative latents
    are then saved for a fixed set of training images.

    Args:
        args: command-line namespace. It is updated in place with the contents
            of ``config.json`` before any other field is read.

    Raises:
        ValueError: if ``args.activation`` or ``args.enc_dec_model`` is not
            supported.
    """
    import sys

    # =====================================
    # Load config
    # =====================================
    with open(join(args.output_dir, 'config.json')) as f:
        config = json.load(f)
    args.__dict__.update(config)

    # =====================================
    # Dataset
    # =====================================
    celebA_loader = TFCelebALoader(root_dir=args.celebA_root_dir)
    img_height, img_width = args.celebA_resize_size, args.celebA_resize_size
    celebA_loader.build_transformation_flow_tf(
        *celebA_loader.get_transform_fns("1Konny", resize_size=args.celebA_resize_size))
    num_train = celebA_loader.num_train_data

    # =====================================
    # Instantiate model
    # =====================================
    if args.activation == "relu":
        activation = tf.nn.relu
    elif args.activation == "leaky_relu":
        activation = tf.nn.leaky_relu
    else:
        raise ValueError("Do not support '{}' activation!".format(
            args.activation))

    if args.enc_dec_model == "1Konny":
        encoder = Encoder_1Konny(args.z_dim, stochastic=True, activation=activation)
        decoder = Decoder_1Konny([img_height, img_width, 3], activation=activation,
                                 output_activation=tf.nn.sigmoid)
        disc_z = DiscriminatorZ_1Konny(num_outputs=2)
    else:
        raise ValueError("Do not support encoder/decoder model '{}'!".format(
            args.enc_dec_model))

    model = FactorVAE([img_height, img_width, 3], args.z_dim,
                      encoder=encoder, decoder=decoder,
                      discriminator_z=disc_z,
                      rec_x_mode=args.rec_x_mode,
                      use_gp0_z_tc=True, gp0_z_tc_mode=args.gp0_z_tc_mode)

    loss_coeff_dict = {
        'rec_x': args.rec_x_coeff,
        'kld_loss': args.kld_loss_coeff,
        'tc_loss': args.tc_loss_coeff,
        'gp0_z_tc': args.gp0_z_tc_coeff,
    }

    model.build(loss_coeff_dict)
    SimpleParamPrinter.print_all_params_tf_slim()

    # =====================================
    # Load model
    # =====================================
    config_proto = tf.ConfigProto(allow_soft_placement=True)
    config_proto.gpu_options.allow_growth = True
    config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9
    sess = tf.Session(config=config_proto)

    model_dir = make_dir_if_not_exist(join(args.output_dir, "model_tf"))
    train_helper = SimpleTrainHelper(log_dir=None, save_dir=model_dir)
    train_helper.load(sess, load_step=args.load_step)

    # =====================================
    # Experiments
    # =====================================
    # NOTE(review): judging by its name, remove_dir_if_exist(..., ask_4_permission=False)
    # wipes earlier outputs for this run without prompting — confirm intended.
    save_dir = remove_dir_if_exist(join(args.save_dir, "FactorVAE_{}".format(args.run)),
                                   ask_4_permission=False)
    save_dir = make_dir_if_not_exist(save_dir)

    # NumPy >= 1.16 rejects threshold=np.nan with a ValueError; sys.maxsize is
    # the supported way to disable array summarization.
    np.set_printoptions(threshold=sys.maxsize, linewidth=1000, precision=3, suppress=True)

    # Fraction of the training set encoded to estimate per-latent statistics.
    data_fraction = 0.05
    # Latents whose mean posterior stddev is below this value are reported as informative.
    info_stddev_threshold = 0.4

    # The log is only written by the statistics report, so it is scoped to it;
    # the context manager guarantees it is flushed and closed.
    with open(join(save_dir, 'log.txt'), mode='w') as f:
        # print_both is a project helper; presumably it writes to stdout and to `file`.
        print_ = functools.partial(print_both, file=f)

        # z gaussian stddev
        # ======================================= #
        z_mean_batches = []
        z_stddev_batches = []
        count = 0
        for batch_ids in iterate_data(int(data_fraction * num_train), 10 * args.batch_size,
                                      shuffle=False):
            x = celebA_loader.sample_images_from_dataset(sess, 'train', batch_ids)

            z_mean, z_stddev = sess.run(model.get_output(['z_mean', 'z_stddev']),
                                        feed_dict={model.is_train: False, model.x_ph: x})

            z_mean_batches.append(z_mean)
            z_stddev_batches.append(z_stddev)

            count += len(batch_ids)
            print("\rProcessed {} samples!".format(count), end="")
        print()

        all_z_mean = np.concatenate(z_mean_batches, axis=0)
        all_z_stddev = np.concatenate(z_stddev_batches, axis=0)
        # ======================================= #

        # How much each latent's posterior mean varies across the encoded
        # samples; latents are ranked by this spread, largest first.
        z_std_error = np.std(all_z_mean, axis=0, ddof=0)
        z_sorted_comps = np.argsort(z_std_error)[::-1]
        top10_z_comps = z_sorted_comps[:10]

        print_("")
        print_("z_std_error: {}".format(z_std_error))
        print_("z_sorted_std_error: {}".format(z_std_error[z_sorted_comps]))
        print_("z_sorted_comps: {}".format(z_sorted_comps))
        print_("top10_z_comps: {}".format(top10_z_comps))

        z_stddev_mean = np.mean(all_z_stddev, axis=0)
        info_z_comps = [idx for idx, stddev in enumerate(z_stddev_mean)
                        if stddev < info_stddev_threshold]
        print_("info_z_comps: {}".format(info_z_comps))
        print_("len(info_z_comps): {}".format(len(info_z_comps)))

    # Plotting
    # =========================================== #
    seed = 389
    num_samples = 30
    ids = list(range(seed, seed + num_samples))
    print("\nids: {}".format(ids))

    data = celebA_loader.sample_images_from_dataset(sess, 'train', ids)

    span = 3
    points_one_side = 5

    for n, sample_id in enumerate(ids):
        print("Plot conditional all comps z traverse with train sample {}!".format(sample_id))

        # Top-10 latents by posterior-mean spread, labelled with their indices.
        img_file = join(
            save_dir, "x_train[{}]_[span={}]_hl_labeled.png".format(sample_id, span))
        model.cond_all_latents_traverse_v2(
            img_file, sess, data[n],
            z_comps=top10_z_comps,
            z_comp_labels=["z[{}]".format(comp) for comp in top10_z_comps],
            span=span, points_1_side=points_one_side,
            hl_x=True,
            subplot_adjust={'left': 0.09, 'right': 0.98, 'bottom': 0.02, 'top': 0.98},
            size_inches=(6, 5),
            batch_size=args.batch_size,
            dec_output_2_img_func=binary_float_to_uint8)

        # All informative latents, unlabelled.
        img_file = join(
            save_dir, "x_train[{}]_[span={}]_info_hl.png".format(sample_id, span))
        model.cond_all_latents_traverse_v2(
            img_file, sess, data[n],
            z_comps=info_z_comps,
            z_comp_labels=None,
            span=span, points_1_side=points_one_side,
            hl_x=True,
            batch_size=args.batch_size,
            dec_output_2_img_func=binary_float_to_uint8)
    # =========================================== #

    sess.close()