def train():
    """Train the CSVAE explainer for the experiment described by ``--config``.

    Command-line arguments:
        --config / -c: path to the YAML experiment config.
        --debug / -d: for the 'shapes' dataset, train on the small debug subset
            built by ``ShapesLoader(dbg_mode=True, ...)``.

    Side effects: creates ``<log_dir>/<name>/{log,ckpt_dir,sample,test}``,
    writes TensorBoard summaries, traversal sample images and checkpoints, and
    restores the pre-trained classifier from ``config['cls_experiment']``.
    Exits the process with a non-zero status if the image-label file or the
    classifier checkpoint cannot be read.

    Raises:
        ValueError: if ``config['dataset']`` is not a supported dataset name.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-c', type=str)
    parser.add_argument('--debug', '-d', action='store_true')
    args = parser.parse_args()

    # ============= Load config =============
    config_path = args.config
    # yaml.load() without an explicit Loader is deprecated since PyYAML 5.1 and
    # a TypeError on PyYAML >= 6; safe_load is sufficient for a plain config.
    with open(config_path) as config_file:
        config = yaml.safe_load(config_file)
    print(config)

    # ============= Experiment Folder =============
    assets_dir = os.path.join(config['log_dir'], config['name'])
    log_dir = os.path.join(assets_dir, 'log')
    ckpt_dir = os.path.join(assets_dir, 'ckpt_dir')
    sample_dir = os.path.join(assets_dir, 'sample')
    test_dir = os.path.join(assets_dir, 'test')
    # Make directories if they do not exist; other OS errors are not hidden.
    for directory in (log_dir, ckpt_dir, sample_dir, test_dir):
        os.makedirs(directory, exist_ok=True)

    # ============= Experiment Parameters =============
    ckpt_dir_cls = config['cls_experiment']
    BATCH_SIZE = config['batch_size']
    EPOCHS = config['epochs']
    channels = config['num_channel']
    input_size = config['input_size']
    NUMS_CLASS_cls = config['num_class']
    NUMS_CLASS = config['num_bins']
    MU_CLUSTER = config['mu_cluster']
    VAR_CLUSTER = config['var_cluster']
    TRAVERSAL_N_SIGMA = config['traversal_n_sigma']
    # target_w = OFFSET + k * STEP_SIZE for k = 0 .. NUMS_CLASS - 1 spans
    # [MU_CLUSTER - n_sigma * VAR_CLUSTER, MU_CLUSTER + n_sigma * VAR_CLUSTER].
    STEP_SIZE = 2 * TRAVERSAL_N_SIGMA * VAR_CLUSTER / (NUMS_CLASS - 1)
    OFFSET = MU_CLUSTER - TRAVERSAL_N_SIGMA * VAR_CLUSTER
    # CSVAE parameters
    beta1 = config['beta1']
    beta2 = config['beta2']
    beta3 = config['beta3']
    beta4 = config['beta4']
    beta5 = config['beta5']
    z_dim = config['z_dim']
    w_dim = config['w_dim']
    save_summary = int(config['save_summary'])
    save_ckpt = int(config['save_ckpt'])
    ckpt_dir_continue = config['ckpt_dir_continue']

    dataset = config['dataset']
    if dataset == 'CelebA':
        pretrained_classifier = celeba_classifier
        my_data_loader = ImageLabelLoader()
        EncoderZ, EncoderW = EncoderZ_128, EncoderW_128
        DecoderX, DecoderY = DecoderX_128, DecoderY_128
    elif dataset == 'shapes':
        pretrained_classifier = shapes_classifier
        if args.debug:
            my_data_loader = ShapesLoader(dbg_mode=True, dbg_size=config['batch_size'],
                                          dbg_image_label_dict=config['image_label_dict'])
        else:
            my_data_loader = ShapesLoader()
        EncoderZ, EncoderW = EncoderZ_64, EncoderW_64
        DecoderX, DecoderY = DecoderX_64, DecoderY_64
    elif dataset in ('CelebA64', 'dermatology', 'synthderm'):
        pretrained_classifier = celeba_classifier
        my_data_loader = ImageLabelLoader(input_size=64)
        EncoderZ, EncoderW = EncoderZ_64, EncoderW_64
        DecoderX, DecoderY = DecoderX_64, DecoderY_64
    else:
        raise ValueError('Unsupported dataset: {}'.format(dataset))

    if ckpt_dir_continue == '':
        continue_train = False
    else:
        ckpt_dir_continue = os.path.join(ckpt_dir_continue, 'ckpt_dir')
        continue_train = True

    global_step = tf.Variable(0, dtype=tf.int32, trainable=False, name='global_step')

    # ============= Data =============
    try:
        categories, file_names_dict = read_data_file(config['image_label_dict'])
    except Exception:
        print("Problem in reading input data file : ", config['image_label_dict'])
        sys.exit(1)  # non-zero status: this is a failure, not a normal exit
    data = np.asarray(list(file_names_dict.keys()))
    # CSVAE does not need discretizing categories. The default 2 is recommended.
    print("The classification categories are: ")
    print(categories)
    print('The size of the training set: ', data.shape[0])
    with open(os.path.join(log_dir, 'setting.txt'), 'w') as fp:
        fp.write('config_file:' + str(config_path) + '\n')

    # ============= placeholder =============
    x_source = tf.placeholder(tf.float32, [None, input_size, input_size, channels], name='x_source')
    y_s = tf.placeholder(tf.int32, [None, NUMS_CLASS_cls], name='y_s')
    # y_s is fed one-hot labels (see the training loop), so the last column is
    # the indicator of the last class.
    y_source = y_s[:, NUMS_CLASS_cls - 1]
    train_phase = tf.placeholder(tf.bool, name='train_phase')
    y_target = tf.placeholder(tf.int32, [None, w_dim], name='y_target')  # between 0 and NUMS_CLASS

    # ============= CSVAE =============
    encoder_z = EncoderZ('encoder_z')
    encoder_w = EncoderW('encoder_w')
    decoder_x = DecoderX('decoder_x')
    decoder_y = DecoderY('decoder_y')

    # encode x to get mean, log variance, and samples from the latent subspace Z
    mu_z, logvar_z, z = encoder_z(x_source, z_dim)
    # encode x and y to get mean, log variance, and samples from the latent subspace W
    mu_w, logvar_w, w = encoder_w(x_source, y_source, w_dim)
    # pass samples of z and w to get predictions of x
    pred_x = decoder_x(tf.concat([w, z], axis=-1))
    # get predicted labels based only on the latent subspace Z
    pred_y = decoder_y(z, NUMS_CLASS_cls)

    # Grid of traversal images: for each W dimension i, decode NUMS_CLASS values
    # along that axis (other W dims 0) with the batch's z. The constant W tensors
    # have BATCH_SIZE rows, so only full batches can be fed.
    # NOTE(review): unlike target_w below, these values omit OFFSET, so they span
    # [0, 2 * n_sigma * var] rather than centring on MU_CLUSTER — confirm intent.
    fake_img_traversal = tf.zeros([0, input_size, input_size, channels])
    for i in range(w_dim):
        for j in range(NUMS_CLASS):
            np_arr = np.zeros((BATCH_SIZE, w_dim))
            np_arr[:, i] = j * STEP_SIZE
            tmp_w = tf.convert_to_tensor(np_arr, dtype=tf.float32)
            fake_img = decoder_x(tf.concat([tmp_w, z], axis=-1))
            fake_img_traversal = tf.concat([fake_img_traversal, fake_img], axis=0)
    fake_img_traversal_board = make4d_tensor(fake_img_traversal, channels, input_size,
                                             w_dim, NUMS_CLASS, BATCH_SIZE)
    fake_img_traversal_save = make3d_tensor(fake_img_traversal, channels, input_size,
                                            w_dim, NUMS_CLASS, BATCH_SIZE)

    # 2-D traversal over the first two W dimensions; column 1 only exists
    # when w_dim >= 2, so the grid is skipped otherwise.
    has_2d_traversal = w_dim >= 2
    if has_2d_traversal:
        fake_2d_img_traversal = tf.zeros([0, input_size, input_size, channels])
        for i in range(NUMS_CLASS):
            for j in range(NUMS_CLASS):
                np_arr = np.zeros((BATCH_SIZE, w_dim))
                np_arr[:, 0] = i * STEP_SIZE
                np_arr[:, 1] = j * STEP_SIZE
                tmp_w = tf.convert_to_tensor(np_arr, dtype=tf.float32)
                fake_2d_img = decoder_x(tf.concat([tmp_w, z], axis=-1))
                fake_2d_img_traversal = tf.concat([fake_2d_img_traversal, fake_2d_img], axis=0)
        fake_2d_img_traversal_board = make4d_tensor(fake_2d_img_traversal, channels, input_size,
                                                    NUMS_CLASS, NUMS_CLASS, BATCH_SIZE)
        fake_2d_img_traversal_save = make3d_tensor(fake_2d_img_traversal, channels, input_size,
                                                   NUMS_CLASS, NUMS_CLASS, BATCH_SIZE)

    # Create a single image based on y_target
    target_w = STEP_SIZE * tf.cast(y_target, dtype=tf.float32) + OFFSET
    fake_target_img = decoder_x(tf.concat([target_w, z], axis=-1))

    # ============= pre-trained classifier =============
    real_img_cls_logit_pretrained, real_img_cls_prediction = pretrained_classifier(
        x_source, NUMS_CLASS_cls, reuse=False, name='classifier')
    fake_recon_cls_logit_pretrained, fake_recon_cls_prediction = pretrained_classifier(
        pred_x, NUMS_CLASS_cls, reuse=True)
    # Classify the y_target-driven image (previously a leftover traversal tensor).
    fake_img_cls_logit_pretrained, fake_img_cls_prediction = pretrained_classifier(
        fake_target_img, NUMS_CLASS_cls, reuse=True)

    # ============= Loss =============
    # OPTIMIZATION:
    # Specified in section 4.1 of http://www.cs.toronto.edu/~zemel/documents/Conditional_Subspace_VAE_all.pdf
    # There are three components: M1, M2, N
    # 1. Optimize the first loss related to maximizing variational lower bound
    #    on the marginal log likelihood and minimizing mutual information.
    # KL divergence for label 1: W for this label should be close to mean 0, var 0.01
    kl1 = KL(mu1=mu_w, logvar1=logvar_w,
             mu2=tf.zeros_like(mu_w), logvar2=tf.ones_like(logvar_w) * np.log(0.01))
    # KL divergence for label 0: W for this label should be close to mean MU_CLUSTER, var VAR_CLUSTER
    kl0 = KL(mu1=mu_w, logvar1=logvar_w,
             mu2=tf.ones_like(mu_w) * MU_CLUSTER, logvar2=tf.ones_like(logvar_w) * np.log(VAR_CLUSTER))

    is_label_1 = tf.equal(y_source, tf.ones_like(y_source))
    loss_m1_1 = tf.reduce_sum(beta1 * tf.reduce_sum((x_source - pred_x) ** 2, axis=-1))  # corresponds to M1
    loss_m1_2 = tf.reduce_sum(beta2 * tf.where(is_label_1, kl1, kl0))  # corresponds to M1
    loss_m1_3 = tf.reduce_sum(
        beta3 * KL(mu_z, logvar_z, tf.zeros_like(mu_z), tf.zeros_like(logvar_z)))  # corresponds to M1
    loss_m2 = tf.reduce_sum(beta4 * tf.reduce_sum(pred_y * safe_log(pred_y), axis=-1))  # corresponds to M2
    loss_m1 = loss_m1_1 + loss_m1_2 + loss_m1_3
    loss1 = loss_m1 + loss_m2

    # 2. Optimize second loss related to learning the approximate posterior.
    # tf.equal is required: in TF1 graph mode `tensor == 1` is a Python identity
    # check that evaluates to False, which always selected the class-0 term.
    loss_n = tf.reduce_sum(beta5 * tf.where(is_label_1,
                                            -safe_log(pred_y[:, 1]),
                                            -safe_log(pred_y[:, 0])))  # N
    loss2 = loss_n

    # Both optimizers increment global_step, so one training batch == 2 steps.
    optimizer_1 = tf.train.AdamOptimizer(2e-4, beta1=0., beta2=0.9).minimize(
        loss1, var_list=decoder_x.var_list() + encoder_w.var_list() + encoder_z.var_list(),
        global_step=global_step)
    optimizer_2 = tf.train.AdamOptimizer(2e-4, beta1=0., beta2=0.9).minimize(
        loss2, var_list=decoder_y.var_list(), global_step=global_step)

    # combine losses for tracking
    loss = loss1 + loss2

    # ============= summary =============
    image_sums = [
        tf.summary.image('real_img', x_source),
        tf.summary.image('fake_recon_img', pred_x),
        tf.summary.image('fake_target_img', fake_target_img),
        tf.summary.image('fake_img_traversal', fake_img_traversal_board),
    ]
    if has_2d_traversal:
        image_sums.append(tf.summary.image('fake_2d_img_traversal', fake_2d_img_traversal_board))
    loss_m1_sum = tf.summary.scalar('losses/M1', loss_m1)
    loss_m1_1_sum = tf.summary.scalar('losses/M1/m1_1', loss_m1_1)
    loss_m1_2_sum = tf.summary.scalar('losses/M1/m1_2', loss_m1_2)
    loss_m1_3_sum = tf.summary.scalar('losses/M1/m1_3', loss_m1_3)
    loss_m2_sum = tf.summary.scalar('losses/M2', loss_m2)
    loss_n_sum = tf.summary.scalar('losses/N', loss_n)
    loss_sum = tf.summary.scalar('losses/total_loss', loss)

    part1_sum = tf.summary.merge([loss_m1_sum, loss_m1_1_sum, loss_m1_2_sum, loss_m1_3_sum, loss_m2_sum])
    part2_sum = tf.summary.merge([loss_n_sum, loss_sum])
    overall_sum = tf.summary.merge([loss_sum] + image_sums)

    # Tensors decoded by save_results (the 2-D grid only when it exists).
    traversal_fetches = [fake_img_traversal_save]
    if has_2d_traversal:
        traversal_fetches.append(fake_2d_img_traversal_save)

    # ============= session =============
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    writer = tf.summary.FileWriter(log_dir, sess.graph)

    # ============= Checkpoints =============
    if continue_train:
        print(" [*] before training, Load checkpoint ")
        print(" [*] Reading checkpoint...")
        ckpt = tf.train.get_checkpoint_state(ckpt_dir_continue)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            saver.restore(sess, os.path.join(ckpt_dir_continue, ckpt_name))
            print(ckpt_dir_continue, ckpt_name)
            print("Successful checkpoint upload")
        else:
            # Best effort: training continues from freshly initialised weights.
            print("Failed checkpoint load")
    else:
        print(" [!] before training, no need to Load ")

    # ============= load pre-trained classifier checkpoint =============
    class_vars = [var for var in slim.get_variables_to_restore() if 'classifier' in var.name]
    name_to_var_map_local = {var.op.name: var for var in class_vars}
    temp_saver = tf.train.Saver(var_list=name_to_var_map_local)
    ckpt = tf.train.get_checkpoint_state(ckpt_dir_cls)
    if not (ckpt and ckpt.model_checkpoint_path):
        sys.exit('No classifier checkpoint found in {}'.format(ckpt_dir_cls))
    ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
    temp_saver.restore(sess, os.path.join(ckpt_dir_cls, ckpt_name))
    print("Classifier checkpoint loaded.................")
    print(ckpt_dir_cls, ckpt_name)

    def sample_target_labels():
        """Return BATCH_SIZE x w_dim targets: one random W dim per row set to a random bin, others 0."""
        target_labels_probs = np.random.randint(0, high=NUMS_CLASS, size=BATCH_SIZE)
        target_labels_w_ind = np.random.randint(0, high=w_dim, size=BATCH_SIZE)
        return np.eye(w_dim)[target_labels_w_ind] * np.repeat(
            np.expand_dims(target_labels_probs, axis=-1), w_dim, axis=1)

    def save_results(sess, step, seed_paths):
        """Decode the traversal grid(s) for seed_paths[:BATCH_SIZE] and save them as JPEGs in sample_dir."""
        img, labels = my_data_loader.load_images_and_labels(
            seed_paths[0:BATCH_SIZE], image_dir=config['image_dir'], n_class=1,
            file_names_dict=file_names_dict, num_channel=channels, do_center_crop=True)
        labels = np.eye(NUMS_CLASS_cls)[labels.ravel().astype(int)]
        my_feed_dict = {y_target: sample_target_labels(), x_source: img,
                        train_phase: False, y_s: labels}
        samples = sess.run(traversal_fetches, feed_dict=my_feed_dict)
        save_image(samples[0], os.path.join(sample_dir, '%06d.jpg' % step))
        if has_2d_traversal:
            save_image(samples[1], os.path.join(sample_dir, '%06d_2d.jpg' % step))

    # ============= Training =============
    for e in range(1, EPOCHS + 1):
        np.random.shuffle(data)
        # Only full batches (see the traversal comment above).
        for i in range(data.shape[0] // BATCH_SIZE):
            if args.debug:
                image_paths = np.array([str(ind) for ind in my_data_loader.tmp_list])
            else:
                image_paths = data[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]
            img, labels = my_data_loader.load_images_and_labels(
                image_paths, image_dir=config['image_dir'], n_class=1,
                file_names_dict=file_names_dict, num_channel=channels, do_center_crop=True)
            labels = np.eye(NUMS_CLASS_cls)[labels.ravel().astype(int)]
            target_labels = sample_target_labels()

            my_feed_dict = {y_target: target_labels, x_source: img, train_phase: True, y_s: labels}

            _, par1_loss, par1_summary_str, overall_sum_str, counter = sess.run(
                [optimizer_1, loss1, part1_sum, overall_sum, global_step], feed_dict=my_feed_dict)
            writer.add_summary(par1_summary_str, global_step=counter)
            writer.add_summary(overall_sum_str, global_step=counter)

            _, part2_loss, part2_summary_str, overall_sum_str2, counter = sess.run(
                [optimizer_2, loss2, part2_sum, overall_sum, global_step], feed_dict=my_feed_dict)
            writer.add_summary(part2_summary_str, global_step=counter)
            writer.add_summary(overall_sum_str2, global_step=counter)

            # global_step advances twice per batch (one per optimizer).
            batch_counter = int(counter) // 2
            if batch_counter % save_summary == 0:
                save_results(sess, batch_counter, image_paths)
            if batch_counter % save_ckpt == 0:
                saver.save(sess, ckpt_dir + "/model%2d.ckpt" % batch_counter, global_step=global_step)
def test(config):
    """Run the trained classifier on the train/test id lists and save its outputs.

    The classifier checkpoint is read from ``<log_dir>/<name>``. For each subset
    the image names, predictions and true labels are saved to
    ``<log_dir>/<name>/classifier_output/{name,prediction_y,true_y}_<subset>1.npy``.
    When ``export_image_label_dict``, ``export_train`` and ``export_test`` are all
    present in ``config`` they replace ``image_label_dict``, ``train`` and ``test``.

    Returns:
        (train_names, train_prediction_y, train_true_y,
         test_names, test_prediction_y, test_true_y); the true-label arrays are
        reshaped to ``[-1, config['num_class']]``.

    Raises:
        ValueError: if ``config['dataset']`` is not a supported dataset name.
    """
    # ============= Experiment Folder =============
    output_dir = os.path.join(config['log_dir'], config['name'])
    classifier_output_path = os.path.join(output_dir, 'classifier_output')
    os.makedirs(classifier_output_path, exist_ok=True)
    past_checkpoint = output_dir

    # ============= Experiment Parameters =============
    BATCH_SIZE = config['batch_size']
    channels = config['num_channel']
    input_size = config['input_size']
    N_CLASSES = config['num_class']
    dataset = config['dataset']

    # In certain circumstances, for example when the classifier has been trained
    # on re-sampled data, we still want to use the whole dataset for the
    # generative model. That's why we produce the classifier's output on the
    # export_* lists when they are configured.
    if all(key in config for key in ('export_image_label_dict', 'export_train', 'export_test')):
        image_label_dict = config['export_image_label_dict']
        train_ids = config['export_train']
        test_ids = config['export_test']
    else:
        image_label_dict = config['image_label_dict']
        train_ids = config['train']
        test_ids = config['test']

    if dataset == 'CelebA':
        pretrained_classifier = celeba_classifier
        my_data_loader = ImageLabelLoader()
    elif dataset == 'shapes':
        pretrained_classifier = shapes_classifier
        my_data_loader = ShapesLoader()
    elif dataset in ('CelebA64', 'dermatology', 'synthderm'):
        pretrained_classifier = celeba_classifier
        my_data_loader = ImageLabelLoader(input_size=64)
    else:
        raise ValueError('Unsupported dataset: {}'.format(dataset))

    # ============= Data =============
    try:
        categories, file_names_dict = read_data_file(image_label_dict)
    except Exception:
        print("Problem in reading input data file : ", image_label_dict)
        sys.exit(1)  # non-zero status: this is a failure
    data_train = np.load(train_ids)
    data_test = np.load(test_ids)
    print("The classification categories are: ")
    print(categories)
    print('The size of the training set: ', data_train.shape[0])
    print('The size of the testing set: ', data_test.shape[0])

    # ============= placeholder =============
    with tf.name_scope('input'):
        x_ = tf.placeholder(tf.float32, [None, input_size, input_size, channels], name='x-input')
        y_ = tf.placeholder(tf.int64, [None, N_CLASSES], name='y-input')
        isTrain = tf.placeholder(tf.bool)

    # ============= Model =============
    # A single binary attribute is modelled as a 2-way classifier.
    n_label = 2 if N_CLASSES == 1 else N_CLASSES
    logit, prediction = pretrained_classifier(x_, n_label=n_label, reuse=False,
                                              name='classifier', isTrain=isTrain)

    # ============= Session =============
    sess = tf.InteractiveSession()
    saver = tf.train.Saver(var_list=list(tf.global_variables()))
    tf.global_variables_initializer().run()

    # ============= Load Checkpoint =============
    ckpt = tf.train.get_checkpoint_state(past_checkpoint + '/')
    if not (ckpt and ckpt.model_checkpoint_path):
        sys.exit('No classifier checkpoint found in {}'.format(past_checkpoint))
    print(str(ckpt.model_checkpoint_path))
    saver.restore(sess, tf.train.latest_checkpoint(past_checkpoint + '/'))

    # ============= Testing - Save the Output =============
    def get_predictions(data, subset_name):
        """Predict every sample of ``data`` (including a final partial batch) and save the results."""
        name_chunks, prediction_chunks, label_chunks = [], [], []
        for start in range(0, data.shape[0], BATCH_SIZE):
            ns = data[start:start + BATCH_SIZE]
            xs, ys = my_data_loader.load_images_and_labels(
                ns, image_dir=config['image_dir'], n_class=N_CLASSES,
                file_names_dict=file_names_dict, num_channel=channels, do_center_crop=True)
            [_pred] = sess.run([prediction], feed_dict={x_: xs, isTrain: False, y_: ys})
            name_chunks.append(np.asarray(ns))
            prediction_chunks.append(np.asarray(_pred))
            label_chunks.append(np.asarray(ys))

        # One concatenate instead of repeated np.append (which copies every time).
        if name_chunks:
            names = np.concatenate(name_chunks, axis=0)
            prediction_y = np.concatenate(prediction_chunks, axis=0)
            true_y = np.concatenate(label_chunks, axis=0)
        else:
            names = np.empty([0])
            prediction_y = np.empty([0])
            true_y = np.empty([0])

        np.save(os.path.join(classifier_output_path, 'name_{}1.npy'.format(subset_name)), names)
        np.save(os.path.join(classifier_output_path, 'prediction_y_{}1.npy'.format(subset_name)),
                prediction_y)
        np.save(os.path.join(classifier_output_path, 'true_y_{}1.npy'.format(subset_name)), true_y)
        return names, prediction_y, np.reshape(true_y, [-1, N_CLASSES])

    train_names, train_prediction_y, train_true_y = get_predictions(data_train, 'train')
    test_names, test_prediction_y, test_true_y = get_predictions(data_test, 'test')

    return train_names, train_prediction_y, train_true_y, test_names, test_prediction_y, test_true_y
def test(config, dbg_img_label_dict=None, dbg_mode=False, export_output=True, dbg_size=10,
         dbg_img_indices=None, calc_stability=True):
    """Generate CSVAE traversals for selected images and collect classifier outputs.

    Every selected image is repeated ``w_dim * num_bins`` times; in each copy one
    W dimension is set to one of ``num_bins`` target values and decoded together
    with the image's z. The pre-trained classifier is run on the real,
    reconstructed and generated images.

    Two modes, chosen by ``config['calc_substitutability']`` (missing = False):
      * regular: returns a dict of the arrays below and, if ``export_output`` is
        set, saves each one as ``<log_dir>/<name>/test/<array name>.npy``.
      * substitutability: saves one generated image per input, an image-label
        file, the ids array and an edited classifier config, then runs
        ``train_classif`` on that config and returns an empty dict.

    Args:
        config: experiment config dict.
        dbg_img_label_dict: image-label file to read instead of the configured one.
        dbg_mode: process only ``dbg_size`` samples.
        export_output: in regular mode, also save the arrays to disk.
        dbg_size: number of samples when ``dbg_mode`` is set.
        dbg_img_indices: explicit image ids to use; ``None`` (default) or ``[]``
            means no explicit selection.
        calc_stability: also decode ``metrics_stability_nx`` Gaussian-noised
            copies of each input for the stability metric.

    Returns:
        dict mapping array name to numpy array (empty in substitutability mode).

    Raises:
        ValueError: if ``config['dataset']`` is not a supported dataset name.
    """
    # None instead of a shared mutable [] default; an explicit [] still works.
    if dbg_img_indices is None:
        dbg_img_indices = []

    # ============= Experiment Folder =============
    assets_dir = os.path.join(config['log_dir'], config['name'])
    ckpt_dir = os.path.join(assets_dir, 'ckpt_dir')

    # Whether this is for saving the results for the substitutability metric or
    # the regular testing process. If only for substitutability, we skip saving
    # large arrays and additional multiple random outputs to avoid OOM.
    calc_substitutability = config.get('calc_substitutability', False)
    if calc_substitutability:
        substitutability_attr = config['substitutability_attr']
        test_dir = os.path.join(assets_dir, 'test', 'substitutability_input')
        substitutability_exported_img_label_dict = os.path.join(
            test_dir, '{}_dims_{}_clss_{}.txt'.format(substitutability_attr, config['w_dim'],
                                                      config['num_bins']))
        substitutability_label_scaler = config['num_bins'] - 1
        exported_dict = {}
        substitutability_classifier_config = config['substitutability_classifier_config']
        # safe_load: yaml.load() without a Loader is a TypeError on PyYAML >= 6.
        with open(config['classifier_config']) as cls_config_file:
            _cls_config = yaml.safe_load(cls_config_file)
        substitutability_img_subset = _cls_config['train']
        substitutability_img_label_dict = _cls_config['image_label_dict']
        # Copy of the classifier config pointing at the generated images.
        _edited_cls_config = deepcopy(_cls_config)
        _edited_cls_config['image_dir'] = os.path.join(test_dir, 'images')
        os.makedirs(_edited_cls_config['image_dir'], exist_ok=True)
        _edited_cls_config['image_label_dict'] = substitutability_exported_img_label_dict
        _edited_cls_config['train'] = os.path.join(test_dir, 'train_ids.npy')
        _edited_cls_config['test'] = ''  # skips evaluating on test
        _edited_cls_config['log_dir'] = test_dir
        _edited_cls_config['ckpt_dir_continue'] = ''
        save_config_dict(_edited_cls_config, substitutability_classifier_config)
    else:
        test_dir = os.path.join(assets_dir, 'test')

    # ============= Experiment Parameters =============
    ckpt_dir_cls = config['cls_experiment']
    if 'evaluation_batch_size' in config:
        BATCH_SIZE = config['evaluation_batch_size']
    else:
        BATCH_SIZE = config['batch_size']
    channels = config['num_channel']
    input_size = config['input_size']
    NUMS_CLASS_cls = config['num_class']
    NUMS_CLASS = config['num_bins']
    MU_CLUSTER = config['mu_cluster']
    VAR_CLUSTER = config['var_cluster']
    TRAVERSAL_N_SIGMA = config['traversal_n_sigma']
    # target_w = OFFSET + k * STEP_SIZE for k = 0 .. NUMS_CLASS - 1 spans
    # [MU_CLUSTER - n_sigma * VAR_CLUSTER, MU_CLUSTER + n_sigma * VAR_CLUSTER].
    STEP_SIZE = 2 * TRAVERSAL_N_SIGMA * VAR_CLUSTER / (NUMS_CLASS - 1)
    OFFSET = MU_CLUSTER - TRAVERSAL_N_SIGMA * VAR_CLUSTER
    metrics_stability_nx = config['metrics_stability_nx']
    metrics_stability_var = config['metrics_stability_var']
    ckpt_dir_continue = ckpt_dir

    if dbg_img_label_dict is not None:
        image_label_dict = dbg_img_label_dict
    elif calc_substitutability:
        image_label_dict = substitutability_img_label_dict
    else:
        image_label_dict = config['image_label_dict']

    # CSVAE parameters
    z_dim = config['z_dim']
    w_dim = config['w_dim']

    num_samples = dbg_size if dbg_mode else config['count_to_save']

    dataset = config['dataset']
    if dataset == 'CelebA':
        my_data_loader = ImageLabelLoader(input_size=128)
        pretrained_classifier = celeba_classifier
        EncoderZ, EncoderW = EncoderZ_128, EncoderW_128
        DecoderX, DecoderY = DecoderX_128, DecoderY_128
    elif dataset == 'shapes':
        if calc_substitutability:
            my_data_loader = ShapesLoader()
        else:
            # For efficiency, only load as many samples as we need.
            my_data_loader = ShapesLoader(dbg_mode=True, dbg_size=num_samples,
                                          dbg_image_label_dict=image_label_dict,
                                          dbg_img_indices=dbg_img_indices)
            dbg_mode = True
        pretrained_classifier = shapes_classifier
        EncoderZ, EncoderW = EncoderZ_64, EncoderW_64
        DecoderX, DecoderY = DecoderX_64, DecoderY_64
    elif dataset in ('CelebA64', 'dermatology', 'synthderm'):
        my_data_loader = ImageLabelLoader(input_size=64)
        pretrained_classifier = celeba_classifier
        EncoderZ, EncoderW = EncoderZ_64, EncoderW_64
        DecoderX, DecoderY = DecoderX_64, DecoderY_64
    else:
        raise ValueError('Unsupported dataset: {}'.format(dataset))

    # ============= Data =============
    try:
        categories, file_names_dict = read_data_file(image_label_dict)
    except Exception:
        print("Problem in reading input data file : ", image_label_dict)
        sys.exit(1)  # non-zero status: this is a failure
    if calc_substitutability:
        data = np.load(substitutability_img_subset)
        num_samples = len(data)
    elif dbg_mode and dataset == 'shapes':
        data = np.array([str(ind) for ind in my_data_loader.tmp_list])
    elif len(dbg_img_indices) > 0:
        data = np.asarray(dbg_img_indices)
    else:
        data = np.asarray(list(file_names_dict.keys()))
    print("The classification categories are: ")
    print(categories)
    print('The size of the test set: ', data.shape[0])

    # ============= placeholder =============
    x_source = tf.placeholder(tf.float32, [None, input_size, input_size, channels], name='x_source')
    y_s = tf.placeholder(tf.int32, [None, NUMS_CLASS_cls], name='y_s')
    y_source = y_s[:, NUMS_CLASS_cls - 1]
    train_phase = tf.placeholder(tf.bool, name='train_phase')
    y_target = tf.placeholder(tf.int32, [None, w_dim], name='y_target')  # between 0 and NUMS_CLASS

    generation_dim = w_dim

    # ============= CSVAE =============
    encoder_z = EncoderZ('encoder_z')
    encoder_w = EncoderW('encoder_w')
    decoder_x = DecoderX('decoder_x')
    decoder_y = DecoderY('decoder_y')

    # encode x to get mean, log variance, and samples from the latent subspace Z
    mu_z, logvar_z, z = encoder_z(x_source, z_dim)
    # encode x and y to get mean, log variance, and samples from the latent subspace W
    mu_w, logvar_w, w = encoder_w(x_source, y_source, w_dim)
    # pass samples of z and w to get predictions of x
    pred_x = decoder_x(tf.concat([w, z], axis=-1))
    # get predicted labels based only on the latent subspace Z
    pred_y = decoder_y(z, NUMS_CLASS_cls)

    # Create a single image based on y_target
    target_w = STEP_SIZE * tf.cast(y_target, dtype=tf.float32) + OFFSET
    fake_target_img = decoder_x(tf.concat([target_w, z], axis=-1))

    # ============= pre-trained classifier =============
    real_img_cls_logit_pretrained, real_img_cls_prediction = pretrained_classifier(
        x_source, NUMS_CLASS_cls, reuse=False, name='classifier')
    fake_recon_cls_logit_pretrained, fake_recon_cls_prediction = pretrained_classifier(
        pred_x, NUMS_CLASS_cls, reuse=True)
    fake_img_cls_logit_pretrained, fake_img_cls_prediction = pretrained_classifier(
        fake_target_img, NUMS_CLASS_cls, reuse=True)

    # ============= predicted probabilities =============
    # Largest target bin per row, scaled to [0, 1].
    fake_target_p_tensor = tf.reduce_max(
        tf.cast(y_target, tf.float32) * 1.0 / float(NUMS_CLASS - 1), axis=1)

    # ============= session =============
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()

    # ============= Checkpoints =============
    print(" [*] Reading checkpoint...")
    ckpt = tf.train.get_checkpoint_state(ckpt_dir_continue)
    if ckpt and ckpt.model_checkpoint_path:
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        saver.restore(sess, os.path.join(ckpt_dir_continue, ckpt_name))
        print(ckpt_dir_continue, ckpt_name)
        print("Successful checkpoint upload")
    else:
        print("Failed checkpoint load")
        sys.exit(1)

    # ============= load pre-trained classifier checkpoint =============
    class_vars = [var for var in slim.get_variables_to_restore() if 'classifier' in var.name]
    name_to_var_map_local = {var.op.name: var for var in class_vars}
    temp_saver = tf.train.Saver(var_list=name_to_var_map_local)
    ckpt = tf.train.get_checkpoint_state(ckpt_dir_cls)
    if not (ckpt and ckpt.model_checkpoint_path):
        sys.exit('No classifier checkpoint found in {}'.format(ckpt_dir_cls))
    ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
    temp_saver.restore(sess, os.path.join(ckpt_dir_cls, ckpt_name))
    print("Classifier checkpoint loaded.................")
    print(ckpt_dir_cls, ckpt_name)

    # ============= Testing =============
    def _save_output_array(name, values):
        np.save(os.path.join(test_dir, '{}.npy'.format(name)), values)

    # Select the samples first so the output arrays are sized to the number of
    # images actually processed (count_to_save may exceed the dataset size;
    # otherwise trailing rows would stay uninitialised np.empty memory).
    np.random.shuffle(data)
    data = data[0:num_samples]
    num_samples = data.shape[0]

    if not calc_substitutability:
        img_grid = [num_samples, generation_dim, NUMS_CLASS, input_size, input_size, channels]
        p_grid = [num_samples, generation_dim, NUMS_CLASS, NUMS_CLASS_cls]
        stability_img_grid = [num_samples, metrics_stability_nx, generation_dim, NUMS_CLASS,
                              input_size, input_size, channels]
        stability_p_grid = [num_samples, metrics_stability_nx, generation_dim, NUMS_CLASS,
                            NUMS_CLASS_cls]
        names = np.empty([num_samples], dtype=object)
        real_imgs = np.empty([num_samples, input_size, input_size, channels])
        fake_t_imgs = np.empty(img_grid)
        fake_s_recon_imgs = np.empty(img_grid)
        real_ps = np.empty(p_grid)
        recon_ps = np.empty(p_grid)
        fake_target_ps = np.empty([num_samples, generation_dim, NUMS_CLASS])
        fake_ps = np.empty(p_grid)
        # For stability metric.
        # NOTE(review): these stay uninitialised when calc_stability is False but
        # are still returned/exported — confirm callers ignore them in that case.
        stability_fake_t_imgs = np.empty(stability_img_grid)
        stability_fake_s_recon_imgs = np.empty(stability_img_grid)
        stability_recon_ps = np.empty(stability_p_grid)
        stability_fake_ps = np.empty(stability_p_grid)
        # Insertion order defines the save/return order.
        output_arrays = {
            'names': names,
            'real_imgs': real_imgs,
            'fake_t_imgs': fake_t_imgs,
            'fake_s_recon_imgs': fake_s_recon_imgs,
            'real_ps': real_ps,
            'recon_ps': recon_ps,
            'fake_target_ps': fake_target_ps,
            'fake_ps': fake_ps,
            'stability_fake_t_imgs': stability_fake_t_imgs,
            'stability_fake_s_recon_imgs': stability_fake_s_recon_imgs,
            'stability_recon_ps': stability_recon_ps,
            'stability_fake_ps': stability_fake_ps,
        }

    # Per-image y_target block (loop-invariant): row g * NUMS_CLASS + k sets W
    # dimension g to bin k and leaves the other dimensions at 0.
    dim_bin_arr = np.zeros((generation_dim * NUMS_CLASS, generation_dim))
    for _gen_dim in range(generation_dim):
        dim_bin_arr[_gen_dim * NUMS_CLASS:(_gen_dim + 1) * NUMS_CLASS, _gen_dim] = np.arange(NUMS_CLASS)
    n_variants = NUMS_CLASS * generation_dim

    num_batches = math.ceil(data.shape[0] / BATCH_SIZE)
    for i in range(num_batches):
        image_paths = data[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]
        # BATCH_SIZE, or fewer for the last batch when not divisible.
        num_seed_imgs = np.shape(image_paths)[0]
        img, _labels = my_data_loader.load_images_and_labels(
            image_paths, config['image_dir'], 1, file_names_dict, channels, do_center_crop=True)
        # Each image is repeated consecutively once per (W dim, bin) pair.
        img_repeat = np.repeat(img, n_variants, 0)
        labels = np.repeat(_labels, n_variants, 0).ravel()
        labels = np.eye(NUMS_CLASS_cls)[labels.astype(int)]
        target_labels = np.tile(dim_bin_arr, (num_seed_imgs, 1))  # [num_seed_imgs * w_dim * NUMS_CLASS, w_dim]

        my_feed_dict = {y_target: target_labels, x_source: img_repeat, train_phase: False, y_s: labels}
        fake_t_img, fake_s_recon_img, real_p, recon_p, fake_target_p, fake_p = sess.run(
            [fake_target_img, pred_x, real_img_cls_prediction, fake_recon_cls_prediction,
             fake_target_p_tensor, fake_img_cls_prediction],
            feed_dict=my_feed_dict)

        print('{} / {}'.format(i + 1, num_batches))

        _num_cur_samples = len(image_paths)
        cur_img_grid = (_num_cur_samples, generation_dim, NUMS_CLASS, input_size, input_size, channels)
        cur_p_grid = (_num_cur_samples, generation_dim, NUMS_CLASS, NUMS_CLASS_cls)

        if calc_substitutability:
            # Export one randomly chosen W dimension's traversal per image.
            _ind_generation_dim = np.random.randint(low=0, high=generation_dim, size=_num_cur_samples)
            reshaped_imgs = np.reshape(fake_t_img, cur_img_grid)
            sub_exported_dict = save_batch_images(
                reshaped_imgs, image_paths, _ind_generation_dim, _labels,
                substitutability_label_scaler, _edited_cls_config['image_dir'],
                has_extension=(dataset != 'shapes'))
            exported_dict.update(sub_exported_dict)
            continue

        start_ind = i * BATCH_SIZE
        end_ind = start_ind + _num_cur_samples
        names[start_ind:end_ind] = np.asarray(image_paths)

        if calc_stability:
            for j in range(metrics_stability_nx):
                noisy_img = img + np.random.normal(loc=0.0, scale=metrics_stability_var,
                                                   size=np.shape(img))
                stability_feed_dict = {
                    y_target: target_labels,
                    x_source: np.repeat(noisy_img, n_variants, 0),
                    train_phase: False,
                    y_s: labels
                }
                (_stability_fake_t_img, _stability_fake_s_recon_img,
                 _stability_recon_p, _stability_fake_p) = sess.run(
                    [fake_target_img, pred_x, fake_recon_cls_prediction, fake_img_cls_prediction],
                    feed_dict=stability_feed_dict)
                stability_fake_t_imgs[start_ind:end_ind, j] = np.reshape(_stability_fake_t_img, cur_img_grid)
                stability_fake_s_recon_imgs[start_ind:end_ind, j] = np.reshape(
                    _stability_fake_s_recon_img, cur_img_grid)
                stability_recon_ps[start_ind:end_ind, j] = np.reshape(_stability_recon_p, cur_p_grid)
                stability_fake_ps[start_ind:end_ind, j] = np.reshape(_stability_fake_p, cur_p_grid)

        real_imgs[start_ind:end_ind] = img
        fake_t_imgs[start_ind:end_ind] = np.reshape(fake_t_img, cur_img_grid)
        fake_s_recon_imgs[start_ind:end_ind] = np.reshape(fake_s_recon_img, cur_img_grid)
        real_ps[start_ind:end_ind] = np.reshape(real_p, cur_p_grid)
        recon_ps[start_ind:end_ind] = np.reshape(recon_p, cur_p_grid)
        fake_target_ps[start_ind:end_ind] = np.reshape(
            fake_target_p, (_num_cur_samples, generation_dim, NUMS_CLASS))
        fake_ps[start_ind:end_ind] = np.reshape(fake_p, cur_p_grid)

    # Release the session's resources before (possibly) building a new graph.
    sess.close()

    if calc_substitutability:
        save_dict(exported_dict, substitutability_exported_img_label_dict, substitutability_attr)
        np.save(_edited_cls_config['train'], np.asarray(list(exported_dict.keys())))
        # retrain the classifier with the new generated images
        tf.reset_default_graph()
        train_classif(substitutability_classifier_config)
        return {}

    if export_output:
        os.makedirs(test_dir, exist_ok=True)
        for arr_name, values in output_arrays.items():
            _save_output_array(arr_name, values)
    return dict(output_arrays)
def train():
    """Train the ordinal conditional GAN explainer configured by a YAML file.

    Command-line arguments (read from ``sys.argv``):
        --config / -c: path to the YAML experiment config.
        --debug / -d: for the ``shapes`` dataset, build
            ``ShapesLoader(dbg_mode=True, dbg_size=batch_size, ...)`` and
            feed its ``tmp_list`` as the batch on every iteration.

    Builds a generator ``G`` conditioned on a target bin (and, when
    ``k_dim > 1``, on one of ``k_dim`` knobs), a discriminator ``D``, an
    optional contrastive disentangler ``R`` and classifier-based losses on a
    pre-trained classifier restored from ``config['cls_experiment']``.
    Summaries go to ``<log_dir>/<name>/log``, sample grids to
    ``.../sample`` and checkpoints to ``.../ckpt_dir``.

    Raises:
        ValueError: if ``config['dataset']`` is not a supported dataset.
        FileNotFoundError: if no classifier checkpoint is found.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-c', type=str)
    parser.add_argument('--debug', '-d', action='store_true')
    args = parser.parse_args()

    # ============= Load config =============
    config_path = args.config
    # yaml.load() without an explicit Loader is unsafe and is a TypeError on
    # PyYAML >= 6; the config is plain data, so safe_load is sufficient.
    with open(config_path) as config_file:
        config = yaml.safe_load(config_file)
    print(config)

    # ============= Experiment Folder =============
    assets_dir = os.path.join(config['log_dir'], config['name'])
    log_dir = os.path.join(assets_dir, 'log')
    ckpt_dir = os.path.join(assets_dir, 'ckpt_dir')
    sample_dir = os.path.join(assets_dir, 'sample')
    test_dir = os.path.join(assets_dir, 'test')
    # exist_ok replaces a bare `except: pass`, which also hid real failures
    # such as permission errors.
    for directory in (log_dir, ckpt_dir, sample_dir, test_dir):
        os.makedirs(directory, exist_ok=True)

    # ============= Experiment Parameters =============
    ckpt_dir_cls = config['cls_experiment']
    BATCH_SIZE = config['batch_size']
    EPOCHS = config['epochs']
    channels = config['num_channel']
    input_size = config['input_size']
    NUMS_CLASS_cls = config['num_class']
    NUMS_CLASS = config['num_bins']
    target_class = config['target_class']
    lambda_GAN = config['lambda_GAN']
    lambda_cyc = config['lambda_cyc']
    lambda_cls = config['lambda_cls']
    save_summary = int(config['save_summary'])
    save_ckpt = int(config['save_ckpt'])
    ckpt_dir_continue = config['ckpt_dir_continue']
    k_dim = config['k_dim']
    lambda_r = config['lambda_r']
    disentangle = k_dim > 1
    discriminate_every_nth = config['discriminate_every_nth']
    generate_every_nth = config['generate_every_nth']

    dataset = config['dataset']
    if dataset == 'CelebA':
        pretrained_classifier = celeba_classifier
        my_data_loader = ImageLabelLoader()
        Discriminator_Ordinal = Discriminator_Ordinal_128
        Generator_Encoder_Decoder = Generator_Encoder_Decoder_128
        Discriminator_Contrastive = Discriminator_Contrastive_128
    elif dataset == 'shapes':
        pretrained_classifier = shapes_classifier
        if args.debug:
            my_data_loader = ShapesLoader(
                dbg_mode=True,
                dbg_size=config['batch_size'],
                dbg_image_label_dict=config['image_label_dict'])
        else:
            my_data_loader = ShapesLoader()
        Discriminator_Ordinal = Discriminator_Ordinal_64
        Generator_Encoder_Decoder = Generator_Encoder_Decoder_64
        Discriminator_Contrastive = Discriminator_Contrastive_64
    elif dataset in ('CelebA64', 'dermatology', 'synthderm'):
        pretrained_classifier = celeba_classifier
        my_data_loader = ImageLabelLoader(input_size=64)
        Discriminator_Ordinal = Discriminator_Ordinal_64
        Generator_Encoder_Decoder = Generator_Encoder_Decoder_64
        Discriminator_Contrastive = Discriminator_Contrastive_64
    else:
        raise ValueError('Unsupported dataset: {}'.format(dataset))

    if ckpt_dir_continue == '':
        continue_train = False
    else:
        ckpt_dir_continue = os.path.join(ckpt_dir_continue, 'ckpt_dir')
        continue_train = True

    global_step = tf.Variable(0, dtype=tf.int32, trainable=False,
                              name='global_step')

    # ============= Data =============
    try:
        categories, file_names_dict = read_data_file(
            config['image_label_dict'])
    except Exception:
        print("Problem in reading input data file : ",
              config['image_label_dict'])
        sys.exit(1)
    data = np.asarray(list(file_names_dict.keys()))
    print("The classification categories are: ")
    print(categories)
    print('The size of the training set: ', data.shape[0])
    with open(os.path.join(log_dir, 'setting.txt'), 'w') as fp:
        fp.write('config_file:' + str(config_path) + '\n')

    # ============= placeholder =============
    x_source = tf.placeholder(tf.float32,
                              [None, input_size, input_size, channels],
                              name='x_source')
    y_s = tf.placeholder(tf.int32, [None, NUMS_CLASS], name='y_s')
    y_source = y_s[:, 0]
    train_phase = tf.placeholder(tf.bool, name='train_phase')
    y_t = tf.placeholder(tf.int32, [None, NUMS_CLASS], name='y_t')
    y_target = y_t[:, 0]
    if disentangle:
        # Index of the knob used for each generated sample.
        y_regularizer = tf.placeholder(tf.int32, [None], name='y_regularizer')
        # One-hot knob target for the contrastive disentangler.
        y_r = tf.placeholder(tf.float32, [None, k_dim], name='y_r')
        y_r_0 = tf.zeros_like(y_r, name='y_r_0')

    # ============= G & D =============
    G = Generator_Encoder_Decoder(
        "generator")  # with conditional BN, SAGAN: SN here as well
    D = Discriminator_Ordinal("discriminator")  # with SN and projection

    real_source_logits = D(x_source, y_s, NUMS_CLASS, "NO_OPS")
    if disentangle:
        # Condition G on the combined (knob, bin) index in NUMS_CLASS * k_dim.
        fake_target_img, fake_target_img_embedding = G(
            x_source, y_regularizer * NUMS_CLASS + y_target,
            NUMS_CLASS * k_dim)
        fake_source_img, fake_source_img_embedding = G(
            fake_target_img, y_regularizer * NUMS_CLASS + y_source,
            NUMS_CLASS * k_dim)
        fake_source_recons_img, x_source_img_embedding = G(
            x_source, y_regularizer * NUMS_CLASS + y_source,
            NUMS_CLASS * k_dim)
    else:
        fake_target_img, fake_target_img_embedding = G(x_source, y_target,
                                                       NUMS_CLASS)
        fake_source_img, fake_source_img_embedding = G(fake_target_img,
                                                       y_source, NUMS_CLASS)
        fake_source_recons_img, x_source_img_embedding = G(
            x_source, y_source, NUMS_CLASS)
    fake_target_logits = D(fake_target_img, y_t, NUMS_CLASS, None)

    # ============= pre-trained classifier =============
    real_img_cls_logit_pretrained, real_img_cls_prediction = pretrained_classifier(
        x_source, NUMS_CLASS_cls, reuse=False, name='classifier')
    fake_img_cls_logit_pretrained, fake_img_cls_prediction = pretrained_classifier(
        fake_target_img, NUMS_CLASS_cls, reuse=True)
    real_img_recons_cls_logit_pretrained, real_img_recons_cls_prediction = pretrained_classifier(
        fake_source_img, NUMS_CLASS_cls, reuse=True)

    # ============= pre-trained classifier loss =============
    # Target bin rescaled to [0, 1]; assumes column 0 of the encoded label
    # carries the ordinal bin index -- TODO confirm against
    # convert_ordinal_to_binary.
    real_p = tf.cast(y_target, tf.float32) * 1.0 / float(NUMS_CLASS - 1)
    fake_q = fake_img_cls_prediction[:, target_class]
    # Binary cross-entropy between the requested probability and the
    # classifier's output on the generated image.
    fake_evaluation = (real_p * safe_log(fake_q)) + (
        (1 - real_p) * safe_log(1 - fake_q))
    fake_evaluation = -tf.reduce_mean(fake_evaluation)

    # Cross-entropy between the classifier output on the real image and on
    # its cycle reconstruction.
    recons_evaluation = (real_img_cls_prediction[:, target_class] * safe_log(
        real_img_recons_cls_prediction[:, target_class])) + (
            (1 - real_img_cls_prediction[:, target_class]) *
            safe_log(1 - real_img_recons_cls_prediction[:, target_class]))
    recons_evaluation = -tf.reduce_mean(recons_evaluation)

    # ============= regularizer contrastive discriminator loss =============
    if disentangle:
        R = Discriminator_Contrastive("disentangler")
        regularizer_fake_target_v_source_logits = R(
            tf.concat([x_source, fake_target_img], axis=-1), k_dim)
        regularizer_fake_source_v_target_logits = R(
            tf.concat([fake_target_img, fake_source_img], axis=-1), k_dim)
        regularizer_fake_source_v_source_logits = R(
            tf.concat([x_source, fake_source_img], axis=-1), k_dim)
        regularizer_fake_source_recon_v_source_logits = R(
            tf.concat([x_source, fake_source_recons_img], axis=-1), k_dim)

    # ============= Loss =============
    D_loss_GAN, D_acc, D_precision, D_recall = discriminator_loss(
        'hinge', real_source_logits, fake_target_logits)
    G_loss_GAN = generator_loss('hinge', fake_target_logits)
    G_loss_cyc = l1_loss(x_source, fake_source_img)
    G_loss_rec = l1_loss(x_source, fake_source_recons_img)
    D_loss = (D_loss_GAN * lambda_GAN)
    D_opt = tf.train.AdamOptimizer(2e-4, beta1=0., beta2=0.9).minimize(
        D_loss, var_list=D.var_list(), global_step=global_step)

    if disentangle:
        R_fake_target_v_source_loss, R_fake_target_v_source_acc = contrastive_regularizer_loss(
            regularizer_fake_target_v_source_logits, y_r)
        R_fake_source_v_target_loss, R_fake_source_v_target_acc = contrastive_regularizer_loss(
            regularizer_fake_source_v_target_logits, y_r)
        R_fake_source_v_source_loss, R_fake_source_v_source_acc = contrastive_regularizer_loss(
            regularizer_fake_source_v_source_logits, y_r_0)
        R_fake_source_recon_v_source_loss, R_fake_source_recon_v_source_acc = contrastive_regularizer_loss(
            regularizer_fake_source_recon_v_source_logits, y_r_0)
        R_loss = (R_fake_target_v_source_loss + R_fake_source_v_target_loss +
                  R_fake_source_v_source_loss +
                  R_fake_source_recon_v_source_loss)
        # R_opt is never run (R is trained jointly through G_opt below), but
        # it is kept: building it creates optimizer variables that are part
        # of what `saver` saves/restores, so removing it would break
        # existing checkpoints.
        R_opt = tf.train.AdamOptimizer(2e-4, beta1=0., beta2=0.9).minimize(
            R_loss * lambda_r, var_list=R.var_list(), global_step=global_step)
        G_loss = (G_loss_GAN * lambda_GAN) + (G_loss_rec * lambda_cyc) + (
            G_loss_cyc * lambda_cyc) + (fake_evaluation * lambda_cls) + (
                recons_evaluation * lambda_cls) + (R_loss * lambda_r)
        G_opt = tf.train.AdamOptimizer(2e-4, beta1=0., beta2=0.9).minimize(
            G_loss, var_list=G.var_list() + R.var_list(),
            global_step=global_step)
    else:
        G_loss = (G_loss_GAN * lambda_GAN) + (G_loss_rec * lambda_cyc) + (
            G_loss_cyc * lambda_cyc) + (fake_evaluation * lambda_cls) + (
                recons_evaluation * lambda_cls)
        G_opt = tf.train.AdamOptimizer(2e-4, beta1=0., beta2=0.9).minimize(
            G_loss, var_list=G.var_list(), global_step=global_step)

    # ============= summary =============
    real_img_sum = tf.summary.image('real_img', x_source)
    fake_img_sum = tf.summary.image('fake_target_img', fake_target_img)
    fake_source_img_sum = tf.summary.image('fake_source_img', fake_source_img)
    fake_source_recons_img_sum = tf.summary.image('fake_source_recons_img',
                                                  fake_source_recons_img)
    acc_d = tf.summary.scalar('discriminator/acc_d', D_acc)
    precision_d = tf.summary.scalar('discriminator/precision_d', D_precision)
    recall_d = tf.summary.scalar('discriminator/recall_d', D_recall)
    loss_d_sum = tf.summary.scalar('discriminator/loss_d', D_loss)
    loss_d_GAN_sum = tf.summary.scalar('discriminator/loss_d_GAN', D_loss_GAN)
    loss_g_sum = tf.summary.scalar('generator/loss_g', G_loss)
    loss_g_GAN_sum = tf.summary.scalar('generator/loss_g_GAN', G_loss_GAN)
    loss_g_cyc_sum = tf.summary.scalar('generator/G_loss_cyc', G_loss_cyc)
    G_loss_rec_sum = tf.summary.scalar('generator/G_loss_rec', G_loss_rec)
    evaluation_fake = tf.summary.scalar('generator/fake_evaluation',
                                        fake_evaluation)
    evaluation_recons = tf.summary.scalar('generator/recons_evaluation',
                                          recons_evaluation)
    g_sum = tf.summary.merge([
        loss_g_sum, loss_g_GAN_sum, loss_g_cyc_sum, real_img_sum,
        G_loss_rec_sum, fake_img_sum, fake_source_img_sum,
        fake_source_recons_img_sum, evaluation_fake, evaluation_recons
    ])
    d_sum = tf.summary.merge(
        [loss_d_sum, loss_d_GAN_sum, acc_d, precision_d, recall_d])

    # Disentangler Contrastive Regularizer losses
    if disentangle:
        loss_r_fake_target_v_source = tf.summary.scalar(
            'disentangler/loss_r_fake_target_v_source',
            R_fake_target_v_source_loss)
        loss_r_fake_source_v_target = tf.summary.scalar(
            'disentangler/loss_r_fake_source_v_target',
            R_fake_source_v_target_loss)
        loss_r_fake_source_v_source = tf.summary.scalar(
            'disentangler/loss_r_fake_source_v_source',
            R_fake_source_v_source_loss)
        loss_r_fake_source_recon_v_source = tf.summary.scalar(
            'disentangler/loss_r_fake_source_recon_v_source',
            R_fake_source_recon_v_source_loss)
        loss_r_sum = tf.summary.scalar('disentangler/loss_r', R_loss)
        acc_r_fake_target_v_source = tf.summary.scalar(
            'disentangler/acc_r_fake_target_v_source',
            R_fake_target_v_source_acc)
        acc_r_fake_source_v_target = tf.summary.scalar(
            'disentangler/acc_r_fake_source_v_target',
            R_fake_source_v_target_acc)
        acc_r_fake_source_v_source = tf.summary.scalar(
            'disentangler/acc_r_fake_source_v_source',
            R_fake_source_v_source_acc)
        acc_r_fake_source_recon_v_source = tf.summary.scalar(
            'disentangler/acc_r_fake_source_recon_v_source',
            R_fake_source_recon_v_source_acc)
        r_sum = tf.summary.merge([
            loss_r_sum, loss_r_fake_target_v_source,
            loss_r_fake_source_v_target, loss_r_fake_source_v_source,
            loss_r_fake_source_recon_v_source, acc_r_fake_target_v_source,
            acc_r_fake_source_v_target, acc_r_fake_source_v_source,
            acc_r_fake_source_recon_v_source
        ])

    # ============= session =============
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    writer = tf.summary.FileWriter(log_dir, sess.graph)

    # ============= Checkpoints =============
    if continue_train:
        print(" [*] before training, Load checkpoint ")
        print(" [*] Reading checkpoint...")
        ckpt = tf.train.get_checkpoint_state(ckpt_dir_continue)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            saver.restore(sess, os.path.join(ckpt_dir_continue, ckpt_name))
            print(ckpt_dir_continue, ckpt_name)
            print("Successful checkpoint upload")
        else:
            print("Failed checkpoint load")
    else:
        print(" [!] before training, no need to Load ")

    # ============= load pre-trained classifier checkpoint =============
    class_vars = [
        var for var in slim.get_variables_to_restore()
        if 'classifier' in var.name
    ]
    name_to_var_map_local = {var.op.name: var for var in class_vars}
    temp_saver = tf.train.Saver(var_list=name_to_var_map_local)
    ckpt = tf.train.get_checkpoint_state(ckpt_dir_cls)
    if not (ckpt and ckpt.model_checkpoint_path):
        raise FileNotFoundError(
            'No classifier checkpoint found in {}'.format(ckpt_dir_cls))
    ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
    temp_saver.restore(sess, os.path.join(ckpt_dir_cls, ckpt_name))
    print("Classifier checkpoint loaded.................")
    print(ckpt_dir_cls, ckpt_name)

    def save_results(sess, step, seed_paths):
        """Render every (knob, bin) output for up to 8 seed images and save.

        Writes ``<sample_dir>/<step:06d>.jpg`` via ``save_images`` and the
        encoded source labels to ``<sample_dir>/<step:06d>_y.npy``.
        """
        # Use at most 8 seeds, but never more than the batch actually holds;
        # a fixed 8 mismatched the label/target lengths for small batches.
        seed_paths = seed_paths[0:8]
        num_seed_imgs = len(seed_paths)
        img, labels = my_data_loader.load_images_and_labels(
            seed_paths,
            image_dir=config['image_dir'],
            n_class=1,
            file_names_dict=file_names_dict,
            num_channel=channels,
            do_center_crop=True)
        labels = np.repeat(labels, NUMS_CLASS * k_dim, 0).ravel()
        img_repeat = np.repeat(img, NUMS_CLASS * k_dim, 0)
        # Per seed: k_dim consecutive sweeps over bins 0..NUMS_CLASS-1.
        target_labels = np.tile(np.arange(NUMS_CLASS), num_seed_imgs * k_dim)
        # Compare ordinal labels BEFORE binary encoding, as the training loop
        # does; the encoded 2-D labels do not broadcast against 1-D targets.
        identity_ind = labels == target_labels
        labels = convert_ordinal_to_binary(labels, NUMS_CLASS)
        target_labels = convert_ordinal_to_binary(target_labels, NUMS_CLASS)
        if disentangle:
            target_disentangle_ind = np.tile(
                np.repeat(np.arange(k_dim), NUMS_CLASS), num_seed_imgs)
            target_disentangle_ind_one_hot = np.eye(
                k_dim)[target_disentangle_ind]
            # Samples whose target equals the source bin get no knob target.
            target_disentangle_ind_one_hot[identity_ind, :] = 0
            feed_dict = {
                y_t: target_labels,
                x_source: img_repeat,
                train_phase: False,
                y_s: labels,
                y_regularizer: target_disentangle_ind,
                y_r: target_disentangle_ind_one_hot
            }
        else:
            feed_dict = {
                y_t: target_labels,
                x_source: img_repeat,
                train_phase: False,
                y_s: labels
            }
        FAKE_IMG, fake_logits_ = sess.run(
            [fake_target_img, fake_target_logits], feed_dict=feed_dict)
        output_fake_img = np.reshape(
            FAKE_IMG, [-1, k_dim, NUMS_CLASS, input_size, input_size, channels])

        # save samples
        sample_file = os.path.join(sample_dir, '%06d.jpg' % step)
        save_images(output_fake_img,
                    sample_file,
                    num_samples=num_seed_imgs,
                    nums_class=NUMS_CLASS,
                    k_dim=k_dim,
                    image_size=input_size,
                    num_channel=channels)
        np.save(sample_file.split('.jpg')[0] + '_y.npy', labels)

    # ============= Training =============
    # Start from the (possibly restored) step so `counter` is defined even on
    # batches where neither D nor G is updated.
    counter = sess.run(global_step)
    last_checked_step = None
    for e in range(EPOCHS):
        np.random.shuffle(data)
        for i in range(data.shape[0] // BATCH_SIZE):
            if args.debug:
                image_paths = np.array(
                    [str(ind) for ind in my_data_loader.tmp_list])
            else:
                image_paths = data[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]
            img, labels = my_data_loader.load_images_and_labels(
                image_paths,
                image_dir=config['image_dir'],
                n_class=1,
                file_names_dict=file_names_dict,
                num_channel=channels,
                do_center_crop=True)
            labels = labels.ravel()
            target_labels = np.random.randint(0, high=NUMS_CLASS,
                                              size=BATCH_SIZE)
            identity_ind = labels == target_labels

            labels = convert_ordinal_to_binary(labels, NUMS_CLASS)
            target_labels = convert_ordinal_to_binary(target_labels,
                                                      NUMS_CLASS)

            if disentangle:
                target_disentangle_ind = np.random.randint(0, high=k_dim,
                                                           size=BATCH_SIZE)
                target_disentangle_ind_one_hot = np.eye(
                    k_dim)[target_disentangle_ind]
                target_disentangle_ind_one_hot[identity_ind, :] = 0
                my_feed_dict = {
                    y_t: target_labels,
                    x_source: img,
                    train_phase: True,
                    y_s: labels,
                    y_regularizer: target_disentangle_ind,
                    y_r: target_disentangle_ind_one_hot
                }
            else:
                my_feed_dict = {
                    y_t: target_labels,
                    x_source: img,
                    train_phase: True,
                    y_s: labels
                }

            if (i + 1) % discriminate_every_nth == 0:
                _, _, summary_str, counter = sess.run(
                    [D_opt, D_loss, d_sum, global_step],
                    feed_dict=my_feed_dict)
                writer.add_summary(summary_str, counter)

            if (i + 1) % generate_every_nth == 0:
                if disentangle:
                    # R's variables are in G_opt's var_list, so this step also
                    # updates the disentangler.
                    _, _, g_summary_str, _, r_summary_str, counter = sess.run(
                        [G_opt, G_loss, g_sum, R_loss, r_sum, global_step],
                        feed_dict=my_feed_dict)
                    writer.add_summary(r_summary_str, counter)
                else:
                    _, _, g_summary_str, counter = sess.run(
                        [G_opt, G_loss, g_sum, global_step],
                        feed_dict=my_feed_dict)
                writer.add_summary(g_summary_str, counter)

            _approx_num_seen_batches = int(counter / 3)
            # global_step advances once per D/G update, so consecutive batches
            # can map to the same value; act on each value only once.
            if _approx_num_seen_batches != last_checked_step:
                last_checked_step = _approx_num_seen_batches
                if _approx_num_seen_batches % save_summary == 0:
                    save_results(sess, _approx_num_seen_batches, image_paths)
                if _approx_num_seen_batches % save_ckpt == 0:
                    # NOTE: '%2d' space-pads one-digit steps; kept so names
                    # match existing checkpoints.
                    saver.save(sess,
                               ckpt_dir +
                               "/model%2d.ckpt" % _approx_num_seen_batches,
                               global_step=global_step)
    writer.close()
    sess.close()
def train(config_path, overwrite_output_dir=None):
    """Train the sigmoid cross-entropy image classifier described by a YAML config.

    Args:
        config_path: path to the classifier YAML config. Keys read here:
            ``batch_size``, ``epochs``, ``num_channel``, ``input_size``,
            ``num_class``, ``ckpt_dir_continue``, ``dataset``, ``image_dir``,
            ``image_label_dict``, ``train``, ``test``, ``log_dir``, ``name``.
        overwrite_output_dir: output directory to use instead of
            ``<log_dir>/<name>``.

    Side effects: writes TensorBoard summaries to ``<output_dir>/train``
    (and ``<output_dir>/test`` when ``config['test']`` is non-empty), a
    checkpoint ``cp1_epoch<N>.ckpt`` per epoch, and per-epoch mean losses to
    ``<output_dir>/logs/train_loss.npy`` / ``test_loss.npy``.

    Raises:
        ValueError: if ``config['dataset']`` is not a supported dataset.
    """
    # yaml.load() without an explicit Loader is unsafe and is a TypeError on
    # PyYAML >= 6; the config is plain data, so safe_load is sufficient.
    with open(config_path) as config_file:
        config = yaml.safe_load(config_file)
    print(config)

    # ============= Experiment Folder =============
    if overwrite_output_dir is not None:
        output_dir = overwrite_output_dir
    else:
        output_dir = os.path.join(config['log_dir'], config['name'])
    # Creates output_dir too; exist_ok replaces a bare `except: pass` that
    # also hid real failures such as permission errors.
    os.makedirs(os.path.join(output_dir, 'logs'), exist_ok=True)

    # ============= Experiment Parameters =============
    BATCH_SIZE = config['batch_size']
    EPOCHS = config['epochs']
    channels = config['num_channel']
    input_size = config['input_size']
    N_CLASSES = config['num_class']
    ckpt_dir_continue = config['ckpt_dir_continue']

    dataset = config['dataset']
    if dataset == 'CelebA':
        pretrained_classifier = celeba_classifier
        my_data_loader = ImageLabelLoader(input_size=128)
    elif dataset == 'shapes':
        pretrained_classifier = shapes_classifier
        if config['image_dir'] == '':
            my_data_loader = ShapesLoader()
        else:
            my_data_loader = ImageLabelLoader(input_size=64)
    elif dataset in ('CelebA64', 'dermatology', 'synthderm'):
        pretrained_classifier = celeba_classifier
        my_data_loader = ImageLabelLoader(input_size=64)
    else:
        raise ValueError('Unsupported dataset: {}'.format(dataset))

    continue_train = ckpt_dir_continue != ''
    evaluate = config['test'] != ''

    # ============= Data =============
    try:
        categories, file_names_dict = read_data_file(
            config['image_label_dict'])
    except Exception:
        print("Problem in reading input data file : ",
              config['image_label_dict'])
        sys.exit(1)
    data_train = np.load(config['train'])
    print("The classification categories are: ")
    print(categories)
    print('The size of the training set: ', data_train.shape[0])
    if evaluate:
        data_test = np.load(config['test'])
        print('The size of the testing set: ', data_test.shape[0])
    with open(os.path.join(output_dir, 'setting.txt'), 'w') as fp:
        fp.write('config_file:' + str(config_path) + '\n')

    # ============= placeholder =============
    with tf.name_scope('input'):
        x_ = tf.placeholder(tf.float32,
                            [None, input_size, input_size, channels],
                            name='x-input')
        y_ = tf.placeholder(tf.int64, [None, N_CLASSES], name='y-input')
        isTrain = tf.placeholder(tf.bool)

    # ============= Model =============
    if N_CLASSES == 1:
        # A single label column is expanded to a 2-way one-hot target.
        y = tf.reshape(y_, [-1])
        y = tf.one_hot(y, 2, on_value=1.0, off_value=0.0, axis=-1)
        logit, prediction = pretrained_classifier(x_,
                                                  n_label=2,
                                                  reuse=False,
                                                  name='classifier',
                                                  isTrain=isTrain)
    else:
        logit, prediction = pretrained_classifier(x_,
                                                  n_label=N_CLASSES,
                                                  reuse=False,
                                                  name='classifier',
                                                  isTrain=isTrain)
        y = y_
    classif_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=y,
                                                   logits=logit)
    classif_acc = calc_accuracy(prediction=prediction, labels=y)
    loss = tf.losses.get_total_loss()

    # ============= Optimization functions =============
    train_step = tf.train.AdamOptimizer(0.0001).minimize(loss)

    # ============= summary =============
    cls_loss_sum = tf.summary.scalar('classif_loss', classif_loss)
    total_loss_sum = tf.summary.scalar('total_loss', loss)
    cls_acc_sum = tf.summary.scalar('classif_acc', classif_acc)
    sum_train = tf.summary.merge([cls_loss_sum, total_loss_sum, cls_acc_sum])

    # ============= Variables =============
    # Every global variable (including optimizer state) is saved/restored.
    lst_vars = list(tf.global_variables())

    # ============= Session =============
    sess = tf.InteractiveSession()
    saver = tf.train.Saver(var_list=lst_vars)
    tf.global_variables_initializer().run()
    writer = tf.summary.FileWriter(output_dir + '/train', sess.graph)
    if evaluate:
        writer_test = tf.summary.FileWriter(output_dir + '/test', sess.graph)

    # ============= Checkpoints =============
    if continue_train:
        print("Before training, Load checkpoint ")
        print("Reading checkpoint...")
        ckpt = tf.train.get_checkpoint_state(ckpt_dir_continue)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            saver.restore(sess, os.path.join(ckpt_dir_continue, ckpt_name))
            print(ckpt_name)
            print("Successful checkpoint upload")
        else:
            print("Failed checkpoint load")
            sys.exit(1)

    def _run_epoch(split_data, is_train, summary_writer, step):
        """Run one shuffled pass over ``split_data``.

        Returns ``(shuffled_data, mean_batch_loss, next_summary_step)``.
        Only full batches are used; the mean is NaN when there are none.
        """
        perm = np.arange(split_data.shape[0])
        np.random.shuffle(perm)
        split_data = split_data[perm]
        num_batch = split_data.shape[0] // BATCH_SIZE
        fetches = ([train_step, loss, sum_train] if is_train else
                   [loss, sum_train])
        epoch_loss = 0.0
        for i in range(num_batch):
            start = i * BATCH_SIZE
            ns = split_data[start:start + BATCH_SIZE]
            xs, ys = my_data_loader.load_images_and_labels(
                ns,
                image_dir=config['image_dir'],
                n_class=N_CLASSES,
                file_names_dict=file_names_dict,
                num_channel=channels,
                do_center_crop=True)
            results = sess.run(fetches,
                               feed_dict={
                                   x_: xs,
                                   isTrain: is_train,
                                   y_: ys
                               })
            _loss, summary_str = results[-2], results[-1]
            summary_writer.add_summary(summary_str, step)
            step += 1
            epoch_loss += _loss
        # Divide by the number of batches (the old code divided by the last
        # batch index, i.e. num_batch - 1, and crashed on a single batch).
        mean_loss = epoch_loss / num_batch if num_batch else float('nan')
        return split_data, mean_loss, step

    # ============= Training =============
    train_loss = []
    test_loss = []
    itr_train = 0
    itr_test = 0
    for epoch in range(EPOCHS):
        data_train, epoch_loss, itr_train = _run_epoch(data_train, True,
                                                       writer, itr_train)
        print("Epoch: " + str(epoch) + " loss: " + str(epoch_loss) + '\n')
        train_loss.append(epoch_loss)

        if evaluate:
            data_test, epoch_loss, itr_test = _run_epoch(
                data_test, False, writer_test, itr_test)
            print("Epoch: " + str(epoch) + " Test loss: " + str(epoch_loss) +
                  '\n')
            test_loss.append(epoch_loss)
            np.save(os.path.join(output_dir, 'logs', 'test_loss.npy'),
                    np.asarray(test_loss))

        checkpoint_name = os.path.join(output_dir,
                                       'cp1_epoch' + str(epoch) + '.ckpt')
        saver.save(sess, checkpoint_name)
        np.save(os.path.join(output_dir, 'logs', 'train_loss.npy'),
                np.asarray(train_loss))

    writer.close()
    if evaluate:
        writer_test.close()
    sess.close()
def test(config, dbg_img_label_dict=None, dbg_mode=False, export_output=True, dbg_size=10, dbg_img_indices=[], calc_stability=True): # ============= Experiment Folder============= assets_dir = os.path.join(config['log_dir'], config['name']) log_dir = os.path.join(assets_dir, 'log') ckpt_dir = os.path.join(assets_dir, 'ckpt_dir') sample_dir = os.path.join(assets_dir, 'sample') # Whether this is for saving the results for substitutability metric or the regular testing process. # If only for substitutability, we skip saving large arrays and additional multiple random outputs to avoid OOM calc_substitutability = config['calc_substitutability'] if calc_substitutability: substitutability_attr = config['substitutability_attr'] test_dir = os.path.join(assets_dir, 'test', 'substitutability_input') substitutability_exported_img_label_dict = os.path.join( test_dir, '{}_dims_{}_clss_{}.txt'.format(substitutability_attr, config['k_dim'], config['num_bins'])) substitutability_label_scaler = config['num_bins'] - 1 exported_dict = {} substitutability_classifier_config = config[ 'substitutability_classifier_config'] _cls_config = yaml.load(open(config['classifier_config'])) substitutability_img_subset = _cls_config['train'] substitutability_img_label_dict = _cls_config['image_label_dict'] _edited_cls_config = deepcopy(_cls_config) _edited_cls_config['image_dir'] = os.path.join(test_dir, 'images') if not os.path.exists(_edited_cls_config['image_dir']): os.makedirs(_edited_cls_config['image_dir']) _edited_cls_config[ 'image_label_dict'] = substitutability_exported_img_label_dict _edited_cls_config['train'] = os.path.join(test_dir, 'train_ids.npy') _edited_cls_config['test'] = '' # skips evaluating on test _edited_cls_config['log_dir'] = test_dir _edited_cls_config['ckpt_dir_continue'] = '' save_config_dict(_edited_cls_config, substitutability_classifier_config) else: test_dir = os.path.join(assets_dir, 'test') # ============= Experiment Parameters ============= ckpt_dir_cls = 
config['cls_experiment'] if 'evaluation_batch_size' in config.keys(): BATCH_SIZE = config['evaluation_batch_size'] else: BATCH_SIZE = config['batch_size'] channels = config['num_channel'] input_size = config['input_size'] NUMS_CLASS_cls = config['num_class'] NUMS_CLASS = config['num_bins'] metrics_stability_nx = config['metrics_stability_nx'] metrics_stability_var = config['metrics_stability_var'] ckpt_dir_continue = ckpt_dir if dbg_img_label_dict is not None: image_label_dict = dbg_img_label_dict elif calc_substitutability: image_label_dict = substitutability_img_label_dict else: image_label_dict = config['image_label_dict'] # there are k_dim disentangled knobs at indices 0..k_dim-1 k_dim = config['k_dim'] disentangle = k_dim > 1 if dbg_mode: num_samples = dbg_size else: num_samples = config['count_to_save'] dataset = config['dataset'] if dataset == 'CelebA': my_data_loader = ImageLabelLoader(input_size=128) EMBEDDING_SIZE = embedding_size_128() pretrained_classifier = celeba_classifier Discriminator_Ordinal = Discriminator_Ordinal_128 Generator_Encoder_Decoder = Generator_Encoder_Decoder_128 elif dataset == 'shapes': if calc_substitutability: my_data_loader = ShapesLoader() else: # my_data_loader = ShapesLoader() # for efficiency, let's just load as many samples as we need my_data_loader = ShapesLoader( dbg_mode=True, dbg_size=num_samples, dbg_image_label_dict=image_label_dict, dbg_img_indices=dbg_img_indices) dbg_mode = True EMBEDDING_SIZE = embedding_size_64() pretrained_classifier = shapes_classifier Discriminator_Ordinal = Discriminator_Ordinal_64 Generator_Encoder_Decoder = Generator_Encoder_Decoder_64 elif dataset == 'CelebA64' or dataset == 'dermatology': my_data_loader = ImageLabelLoader(input_size=64) EMBEDDING_SIZE = embedding_size_64() pretrained_classifier = celeba_classifier Discriminator_Ordinal = Discriminator_Ordinal_64 Generator_Encoder_Decoder = Generator_Encoder_Decoder_64 elif dataset == 'synthderm': my_data_loader = 
ImageLabelLoader(input_size=64) EMBEDDING_SIZE = embedding_size_64() pretrained_classifier = celeba_classifier Discriminator_Ordinal = Discriminator_Ordinal_64 Generator_Encoder_Decoder = Generator_Encoder_Decoder_64 # ============= Data ============= try: categories, file_names_dict = read_data_file(image_label_dict) except: print("Problem in reading input data file : ", image_label_dict) sys.exit() if calc_substitutability: data = np.load(substitutability_img_subset) num_samples = len(data) elif dbg_mode and dataset == 'shapes': data = np.array([str(ind) for ind in my_data_loader.tmp_list]) else: if len(dbg_img_indices) > 0: data = np.asarray(dbg_img_indices) else: data = np.asarray(list(file_names_dict.keys())) print("The classification categories are: ") print(categories) print('The size of the training set: ', data.shape[0]) # ============= placeholder ============= x_source = tf.placeholder(tf.float32, [None, input_size, input_size, channels], name='x_source') y_s = tf.placeholder(tf.int32, [None, NUMS_CLASS], name='y_s') y_source = y_s[:, 0] train_phase = tf.placeholder(tf.bool, name='train_phase') y_t = tf.placeholder(tf.int32, [None, NUMS_CLASS], name='y_t') y_target = y_t[:, 0] if disentangle: y_regularizer = tf.placeholder(tf.int32, [None], name='y_regularizer') y_r = tf.placeholder(tf.float32, [None, k_dim], name='y_r') generation_dim = k_dim # ============= G & D ============= G = Generator_Encoder_Decoder( "generator") # with conditional BN, SAGAN: SN here as well D = Discriminator_Ordinal("discriminator") # with SN and projection real_source_logits = D(x_source, y_s, NUMS_CLASS, "NO_OPS") if disentangle: fake_target_img, fake_target_img_embedding = G( x_source, y_regularizer * NUMS_CLASS + y_target, NUMS_CLASS * generation_dim) fake_source_img, fake_source_img_embedding = G( fake_target_img, y_regularizer * NUMS_CLASS + y_source, NUMS_CLASS * generation_dim) fake_source_recons_img, x_source_img_embedding = G( x_source, y_regularizer * NUMS_CLASS + 
y_source, NUMS_CLASS * generation_dim) else: fake_target_img, fake_target_img_embedding = G(x_source, y_target, NUMS_CLASS) fake_source_img, fake_source_img_embedding = G(fake_target_img, y_source, NUMS_CLASS) fake_source_recons_img, x_source_img_embedding = G( x_source, y_source, NUMS_CLASS) fake_target_logits = D(fake_target_img, y_t, NUMS_CLASS, None) # ============= pre-trained classifier ============= real_img_cls_logit_pretrained, real_img_cls_prediction = pretrained_classifier( x_source, NUMS_CLASS_cls, reuse=False, name='classifier') fake_img_cls_logit_pretrained, fake_img_cls_prediction = pretrained_classifier( fake_target_img, NUMS_CLASS_cls, reuse=True) real_img_recons_cls_logit_pretrained, real_img_recons_cls_prediction = pretrained_classifier( fake_source_img, NUMS_CLASS_cls, reuse=True) fake_img_target_cls_prediction = tf.cast( y_target, tf.float32) * 1.0 / float(NUMS_CLASS - 1) # ============= session ============= sess = tf.Session() sess.run(tf.global_variables_initializer()) saver = tf.train.Saver() # ============= Checkpoints ============= print(" [*] Reading checkpoint...") ckpt = tf.train.get_checkpoint_state(ckpt_dir_continue) if ckpt and ckpt.model_checkpoint_path: ckpt_name = os.path.basename(ckpt.model_checkpoint_path) saver.restore(sess, os.path.join(ckpt_dir_continue, ckpt_name)) print(ckpt_dir_continue, ckpt_name) print("Successful checkpoint upload") else: print("Failed checkpoint load") sys.exit() # ============= load pre-trained classifier checkpoint ============= class_vars = [ var for var in slim.get_variables_to_restore() if 'classifier' in var.name ] name_to_var_map_local = {var.op.name: var for var in class_vars} temp_saver = tf.train.Saver(var_list=name_to_var_map_local) ckpt = tf.train.get_checkpoint_state(ckpt_dir_cls) ckpt_name = os.path.basename(ckpt.model_checkpoint_path) temp_saver.restore(sess, os.path.join(ckpt_dir_cls, ckpt_name)) print("Classifier checkpoint loaded.................") print(ckpt_dir_cls, ckpt_name) # 
============= Testing ============= def _save_output_array(name, values): np.save(os.path.join(test_dir, '{}.npy'.format(name)), values) if not calc_substitutability: names = np.empty([num_samples], dtype=object) real_imgs = np.empty([num_samples, input_size, input_size, channels]) fake_t_imgs = np.empty([ num_samples, generation_dim, NUMS_CLASS, input_size, input_size, channels ]) fake_t_embeds = np.empty([num_samples, generation_dim, NUMS_CLASS] + EMBEDDING_SIZE) fake_s_imgs = np.empty([ num_samples, generation_dim, NUMS_CLASS, input_size, input_size, channels ]) fake_s_embeds = np.empty([num_samples, generation_dim, NUMS_CLASS] + EMBEDDING_SIZE) fake_s_recon_imgs = np.empty([ num_samples, generation_dim, NUMS_CLASS, input_size, input_size, channels ]) s_embeds = np.empty([num_samples, generation_dim, NUMS_CLASS] + EMBEDDING_SIZE) real_ps = np.empty( [num_samples, generation_dim, NUMS_CLASS, NUMS_CLASS_cls]) recon_ps = np.empty( [num_samples, generation_dim, NUMS_CLASS, NUMS_CLASS_cls]) fake_target_ps = np.empty([num_samples, generation_dim, NUMS_CLASS]) fake_ps = np.empty( [num_samples, generation_dim, NUMS_CLASS, NUMS_CLASS_cls]) # For stability metric stability_fake_t_imgs = np.empty([ num_samples, metrics_stability_nx, generation_dim, NUMS_CLASS, input_size, input_size, channels ]) stability_fake_s_recon_imgs = np.empty([ num_samples, metrics_stability_nx, generation_dim, NUMS_CLASS, input_size, input_size, channels ]) stability_recon_ps = np.empty([ num_samples, metrics_stability_nx, generation_dim, NUMS_CLASS, NUMS_CLASS_cls ]) stability_fake_ps = np.empty([ num_samples, metrics_stability_nx, generation_dim, NUMS_CLASS, NUMS_CLASS_cls ]) arrs_to_save = [ 'names', 'real_imgs', 'fake_t_imgs', 'fake_t_embeds', 'fake_s_imgs', 'fake_s_embeds', 'fake_s_recon_imgs', 's_embeds', 'real_ps', 'recon_ps', 'fake_target_ps', 'fake_ps', 'stability_fake_t_imgs', 'stability_fake_s_recon_imgs', 'stability_recon_ps', 'stability_fake_ps' ] np.random.shuffle(data) data = 
data[0:num_samples] for i in range(math.ceil(data.shape[0] / BATCH_SIZE)): image_paths = data[i * BATCH_SIZE:(i + 1) * BATCH_SIZE] # num_seed_imgs is either BATCH_SIZE # or if the number of samples is not divisible by BATCH_SIZE a smaller value num_seed_imgs = np.shape(image_paths)[0] img, _labels = my_data_loader.load_images_and_labels( image_paths, config['image_dir'], 1, file_names_dict, channels, do_center_crop=True) labels = np.repeat(_labels, NUMS_CLASS * generation_dim, 0) labels = labels.ravel() labels = convert_ordinal_to_binary(labels, NUMS_CLASS) img_repeat = np.repeat(img, NUMS_CLASS * generation_dim, 0) target_labels = np.asarray([ np.asarray(range(NUMS_CLASS)) for j in range(num_seed_imgs * generation_dim) ]) target_labels = target_labels.ravel() identity_ind = labels == target_labels target_labels = convert_ordinal_to_binary(target_labels, NUMS_CLASS) if disentangle: target_disentangle_ind = np.asarray([ np.repeat(np.asarray(range(generation_dim)), NUMS_CLASS) for j in range(num_seed_imgs) ]) target_disentangle_ind = target_disentangle_ind.ravel() target_disentangle_ind_one_hot = np.eye( generation_dim)[target_disentangle_ind][:, 0:k_dim] target_disentangle_ind_one_hot[identity_ind, :] = 0 my_feed_dict = { y_t: target_labels, x_source: img_repeat, train_phase: False, y_s: labels, y_regularizer: target_disentangle_ind, y_r: target_disentangle_ind_one_hot } else: my_feed_dict = { y_t: target_labels, x_source: img_repeat, train_phase: False, y_s: labels } fake_t_img, fake_t_embed, fake_s_img, fake_s_embed, fake_s_recon_img, s_embed, real_p, recon_p, fake_target_p, fake_p = sess.run( [ fake_target_img, fake_target_img_embedding, fake_source_img, fake_source_img_embedding, fake_source_recons_img, x_source_img_embedding, real_img_cls_prediction, real_img_recons_cls_prediction, fake_img_target_cls_prediction, fake_img_cls_prediction ], feed_dict=my_feed_dict) print('{} / {}'.format(i + 1, math.ceil(data.shape[0] / BATCH_SIZE))) _num_cur_samples = 
len(image_paths) if calc_substitutability: _ind_generation_dim = np.random.randint(low=0, high=generation_dim, size=_num_cur_samples) reshaped_imgs = np.reshape( fake_t_img, (_num_cur_samples, generation_dim, NUMS_CLASS, input_size, input_size, channels)) sub_exported_dict = save_batch_images( reshaped_imgs, image_paths, _ind_generation_dim, _labels, substitutability_label_scaler, _edited_cls_config['image_dir'], has_extension=(dataset != 'shapes')) exported_dict.update(sub_exported_dict) else: start_ind = i * BATCH_SIZE end_ind = start_ind + _num_cur_samples names[start_ind:end_ind] = np.asarray(image_paths) if calc_stability: for j in range(metrics_stability_nx): noisy_img = img + np.random.normal( loc=0.0, scale=metrics_stability_var, size=np.shape(img)) stability_img_repeat = np.repeat( noisy_img, NUMS_CLASS * generation_dim, 0) stability_feed_dict = my_feed_dict.copy() stability_feed_dict.update( {x_source: stability_img_repeat}) _stability_fake_t_img, _stability_fake_s_recon_img, _stability_recon_p, _stability_fake_p = sess.run( [ fake_target_img, fake_source_recons_img, real_img_recons_cls_prediction, fake_img_cls_prediction ], feed_dict=stability_feed_dict) stability_fake_t_imgs[start_ind:end_ind, j] = np.reshape( _stability_fake_t_img, (_num_cur_samples, generation_dim, NUMS_CLASS, input_size, input_size, channels)) stability_fake_s_recon_imgs[ start_ind:end_ind, j] = np.reshape( _stability_fake_s_recon_img, (_num_cur_samples, generation_dim, NUMS_CLASS, input_size, input_size, channels)) stability_recon_ps[start_ind:end_ind, j] = np.reshape( _stability_recon_p, (_num_cur_samples, generation_dim, NUMS_CLASS, NUMS_CLASS_cls)) stability_fake_ps[start_ind:end_ind, j] = np.reshape( _stability_fake_p, (_num_cur_samples, generation_dim, NUMS_CLASS, NUMS_CLASS_cls)) real_imgs[start_ind:end_ind] = img fake_t_imgs[start_ind:end_ind] = np.reshape( fake_t_img, (_num_cur_samples, generation_dim, NUMS_CLASS, input_size, input_size, channels)) 
fake_s_imgs[start_ind:end_ind] = np.reshape( fake_s_img, (_num_cur_samples, generation_dim, NUMS_CLASS, input_size, input_size, channels)) fake_s_recon_imgs[start_ind:end_ind] = np.reshape( fake_s_recon_img, (_num_cur_samples, generation_dim, NUMS_CLASS, input_size, input_size, channels)) real_ps[start_ind:end_ind] = np.reshape( real_p, (_num_cur_samples, generation_dim, NUMS_CLASS, NUMS_CLASS_cls)) recon_ps[start_ind:end_ind] = np.reshape( recon_p, (_num_cur_samples, generation_dim, NUMS_CLASS, NUMS_CLASS_cls)) fake_target_ps[start_ind:end_ind] = np.reshape( fake_target_p, (_num_cur_samples, generation_dim, NUMS_CLASS)) fake_ps[start_ind:end_ind] = np.reshape( fake_p, (_num_cur_samples, generation_dim, NUMS_CLASS, NUMS_CLASS_cls)) _RESHAPE_EMBED_SIZE = [ _num_cur_samples, generation_dim, NUMS_CLASS ] + EMBEDDING_SIZE fake_t_embeds[start_ind:end_ind] = np.reshape( fake_t_embed, _RESHAPE_EMBED_SIZE) fake_s_embeds[start_ind:end_ind] = np.reshape( fake_s_embed, _RESHAPE_EMBED_SIZE) s_embeds[start_ind:end_ind] = np.reshape(s_embed, _RESHAPE_EMBED_SIZE) output_dict = {} if calc_substitutability: save_dict(exported_dict, substitutability_exported_img_label_dict, substitutability_attr) np.save(_edited_cls_config['train'], np.asarray(list(exported_dict.keys()))) # retrain the classifier with the new generated images tf.reset_default_graph() train_classif(config['substitutability_classifier_config']) else: if export_output: for arr_name in arrs_to_save: _save_output_array(arr_name, eval(arr_name)) for arr_name in arrs_to_save: output_dict.update({arr_name: eval(arr_name)}) return output_dict