Example #1
0
def _select_train_components(dataset, config, debug):
    """Pick the classifier, data loader and CSVAE network classes for ``dataset``.

    Args:
        dataset: value of ``config['dataset']``.
        config: experiment configuration; ``batch_size`` and
            ``image_label_dict`` are read only for the debug shapes loader.
        debug: whether ``--debug`` was passed on the command line.

    Returns:
        ``(pretrained_classifier, data_loader, EncoderZ, EncoderW, DecoderX,
        DecoderY)``. 'CelebA' uses the 128-pixel networks; every other
        supported dataset uses the 64-pixel ones.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    if dataset == 'CelebA':
        return (celeba_classifier, ImageLabelLoader(),
                EncoderZ_128, EncoderW_128, DecoderX_128, DecoderY_128)
    if dataset == 'shapes':
        if debug:
            data_loader = ShapesLoader(dbg_mode=True, dbg_size=config['batch_size'],
                                       dbg_image_label_dict=config['image_label_dict'])
        else:
            data_loader = ShapesLoader()
        return (shapes_classifier, data_loader,
                EncoderZ_64, EncoderW_64, DecoderX_64, DecoderY_64)
    if dataset in ('CelebA64', 'dermatology', 'synthderm'):
        return (celeba_classifier, ImageLabelLoader(input_size=64),
                EncoderZ_64, EncoderW_64, DecoderX_64, DecoderY_64)
    raise ValueError('Unsupported dataset: {!r}'.format(dataset))


def train():
    """Train a CSVAE explainer on top of a pre-trained classifier.

    Command-line arguments:
        --config / -c: path to the YAML experiment configuration (required).
        --debug / -d: for 'shapes', build a small debug loader; in the training
            loop, every step uses the loader's ``tmp_list`` images instead of
            slicing the shuffled dataset.

    Side effects:
        Creates ``<log_dir>/<name>/{log,ckpt_dir,sample,test}`` and writes
        ``log/setting.txt``. Writes TensorBoard summaries to ``log``, traversal
        sample JPEGs to ``sample`` every ``save_summary`` batches, and
        checkpoints to ``ckpt_dir`` every ``save_ckpt`` batches. Optionally
        resumes from ``<ckpt_dir_continue>/ckpt_dir``. Always restores the
        classifier variables from ``cls_experiment``.

    Raises:
        ValueError: if ``config['dataset']`` is not supported.
        FileNotFoundError: if no classifier checkpoint exists in ``cls_experiment``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-c', type=str, required=True)
    parser.add_argument('--debug', '-d', action='store_true')
    args = parser.parse_args()

    # ============= Load config =============
    config_path = args.config
    # The config is plain data. yaml.load() without an explicit Loader is unsafe
    # on untrusted input and raises TypeError on PyYAML >= 6.
    with open(config_path) as config_file:
        config = yaml.safe_load(config_file)
    print(config)

    # ============= Experiment Folder=============
    assets_dir = os.path.join(config['log_dir'], config['name'])
    log_dir = os.path.join(assets_dir, 'log')
    ckpt_dir = os.path.join(assets_dir, 'ckpt_dir')
    sample_dir = os.path.join(assets_dir, 'sample')
    test_dir = os.path.join(assets_dir, 'test')
    # Create the directories if missing. Genuine OS errors (e.g. permissions)
    # are no longer silently swallowed.
    for directory in (log_dir, ckpt_dir, sample_dir, test_dir):
        os.makedirs(directory, exist_ok=True)

    # ============= Experiment Parameters =============
    ckpt_dir_cls = config['cls_experiment']
    BATCH_SIZE = config['batch_size']
    EPOCHS = config['epochs']
    channels = config['num_channel']
    input_size = config['input_size']
    NUMS_CLASS_cls = config['num_class']
    NUMS_CLASS = config['num_bins']
    MU_CLUSTER = config['mu_cluster']
    VAR_CLUSTER = config['var_cluster']
    TRAVERSAL_N_SIGMA = config['traversal_n_sigma']
    # NUMS_CLASS bins evenly span [OFFSET, OFFSET + 2 * n_sigma * var_cluster],
    # i.e. MU_CLUSTER +/- TRAVERSAL_N_SIGMA * VAR_CLUSTER (see target_w below).
    STEP_SIZE = 2*TRAVERSAL_N_SIGMA * VAR_CLUSTER/(NUMS_CLASS - 1)
    OFFSET = MU_CLUSTER - TRAVERSAL_N_SIGMA*VAR_CLUSTER
    target_class = config['target_class']

    # CSVAE parameters
    beta1 = config['beta1']
    beta2 = config['beta2']
    beta3 = config['beta3']
    beta4 = config['beta4']
    beta5 = config['beta5']
    z_dim = config['z_dim']
    w_dim = config['w_dim']

    save_summary = int(config['save_summary'])
    save_ckpt = int(config['save_ckpt'])
    ckpt_dir_continue = config['ckpt_dir_continue']

    dataset = config['dataset']
    (pretrained_classifier, my_data_loader,
     EncoderZ, EncoderW, DecoderX, DecoderY) = _select_train_components(dataset, config, args.debug)

    if ckpt_dir_continue == '':
        continue_train = False
    else:
        ckpt_dir_continue = os.path.join(ckpt_dir_continue, 'ckpt_dir')
        continue_train = True

    global_step = tf.Variable(0, dtype=tf.int32, trainable=False, name='global_step')

    # ============= Data =============
    try:
        categories, file_names_dict = read_data_file(config['image_label_dict'])
    except Exception:
        print("Problem in reading input data file : ", config['image_label_dict'])
        sys.exit(1)  # non-zero status: this is a failure, not a normal exit
    data = np.asarray(list(file_names_dict.keys()))

    # CSVAE does not need discretizing categories. The default 2 is recommended.
    print("The classification categories are: ")
    print(categories)
    print('The size of the training set: ', data.shape[0])
    with open(os.path.join(log_dir, 'setting.txt'), 'w') as fp:
        fp.write('config_file:' + str(config_path) + '\n')

    # ============= placeholder =============
    x_source = tf.placeholder(tf.float32, [None, input_size, input_size, channels], name='x_source')
    y_s = tf.placeholder(tf.int32, [None, NUMS_CLASS_cls], name='y_s')
    # The CSVAE is conditioned on the last one-hot column, i.e. label == NUMS_CLASS_cls - 1.
    y_source = y_s[:, NUMS_CLASS_cls-1]
    # Fed in every session run; no op built in this function consumes it.
    train_phase = tf.placeholder(tf.bool, name='train_phase')

    y_target = tf.placeholder(tf.int32, [None, w_dim], name='y_target')  # between 0 and NUMS_CLASS

    # ============= CSVAE =============

    encoder_z = EncoderZ('encoder_z')
    encoder_w = EncoderW('encoder_w')
    decoder_x = DecoderX('decoder_x')
    decoder_y = DecoderY('decoder_y')

    # encode x to get mean, log variance, and samples from the latent subspace Z
    mu_z, logvar_z, z = encoder_z(x_source, z_dim)
    # encode x and y to get mean, log variance, and samples from the latent subspace W
    mu_w, logvar_w, w = encoder_w(x_source, y_source, w_dim)

    # pass samples of z and w to get predictions of x
    pred_x = decoder_x(tf.concat([w, z], axis=-1))
    # get predicted labels based only on the latent subspace Z
    pred_y = decoder_y(z, NUMS_CLASS_cls)

    # Create and save a grid of images: for each W dimension, decode every bin value.
    # NOTE(review): unlike target_w below, traversal values are not shifted by
    # OFFSET -- confirm this is intended.
    fake_img_traversal = tf.zeros([0, input_size, input_size, channels])
    for i in range(w_dim):
        for j in range(NUMS_CLASS):
            val = j * STEP_SIZE
            np_arr = np.zeros((BATCH_SIZE, w_dim))
            np_arr[:, i] = val
            tmp_w = tf.convert_to_tensor(np_arr, dtype=tf.float32)
            fake_img = decoder_x(tf.concat([tmp_w, z], axis=-1))
            fake_img_traversal = tf.concat([fake_img_traversal, fake_img], axis=0)
    fake_img_traversal_board = make4d_tensor(fake_img_traversal, channels, input_size, w_dim, NUMS_CLASS, BATCH_SIZE)
    fake_img_traversal_save = make3d_tensor(fake_img_traversal, channels, input_size, w_dim, NUMS_CLASS, BATCH_SIZE)

    # Create and save a 2d traversal over the first two W dimensions. It needs
    # w_dim >= 2: with w_dim == 1, np_arr[:, 1] would raise IndexError.
    has_2d_traversal = w_dim >= 2
    if has_2d_traversal:
        fake_2d_img_traversal = tf.zeros([0, input_size, input_size, channels])
        for i in range(NUMS_CLASS):
            for j in range(NUMS_CLASS):
                np_arr = np.zeros((BATCH_SIZE, w_dim))
                np_arr[:, 0] = i * STEP_SIZE
                np_arr[:, 1] = j * STEP_SIZE
                tmp_w = tf.convert_to_tensor(np_arr, dtype=tf.float32)
                fake_2d_img = decoder_x(tf.concat([tmp_w, z], axis=-1))
                fake_2d_img_traversal = tf.concat([fake_2d_img_traversal, fake_2d_img], axis=0)
        fake_2d_img_traversal_board = make4d_tensor(fake_2d_img_traversal, channels, input_size, NUMS_CLASS, NUMS_CLASS, BATCH_SIZE)
        fake_2d_img_traversal_save = make3d_tensor(fake_2d_img_traversal, channels, input_size, NUMS_CLASS, NUMS_CLASS, BATCH_SIZE)

    # Create a single image based on y_target (bin index -> W value in cluster range)
    target_w = STEP_SIZE * tf.cast(y_target, dtype=tf.float32) + OFFSET
    fake_target_img = decoder_x(tf.concat([target_w, z], axis=-1))

    # ============= pre-trained classifier =============
    # The first call creates the 'classifier' variables that are restored from
    # ckpt_dir_cls below; later calls reuse them.
    real_img_cls_logit_pretrained, real_img_cls_prediction = pretrained_classifier(x_source, NUMS_CLASS_cls,
                                                                                   reuse=False, name='classifier')
    fake_recon_cls_logit_pretrained, fake_recon_cls_prediction = pretrained_classifier(pred_x, NUMS_CLASS_cls,
                                                                                       reuse=True)
    # Classify the y_target-driven image (previously the last traversal image
    # leaked from the loop above).
    fake_img_cls_logit_pretrained, fake_img_cls_prediction = pretrained_classifier(fake_target_img, NUMS_CLASS_cls,
                                                                                   reuse=True)

    # ============= predicted probabilities =============
    fake_target_p_tensor = tf.reduce_max(tf.cast(y_target, tf.float32) * 1.0 / float(NUMS_CLASS - 1), axis=1)

    # ============= Loss =============
    # OPTIMIZATION:

    # Specified in section 4.1 of http://www.cs.toronto.edu/~zemel/documents/Conditional_Subspace_VAE_all.pdf
    # There are three components: M1, M2, N

    # 1.Optimize the first loss related to maximizing variational lower bound
    #   on the marginal log likelihood and minimizing mutual information

    # define two KL divergences:
    # KL divergence for label 1
    #    We want the latent subspace W for this label to be close to mean 0, var 0.01
    kl1 = KL(mu1=mu_w, logvar1=logvar_w,
             mu2=tf.zeros_like(mu_w), logvar2=tf.ones_like(logvar_w) * np.log(0.01))
    # KL divergence for label 0
    #    We want the latent subspace W for this label to be close to mean MU_CLUSTER, var VAR_CLUSTER
    kl0 = KL(mu1=mu_w, logvar1=logvar_w, mu2=tf.ones_like(mu_w) * MU_CLUSTER, logvar2=tf.ones_like(logvar_w) * np.log(VAR_CLUSTER))

    loss_m1_1 = tf.reduce_sum(beta1 * tf.reduce_sum((x_source - pred_x) ** 2, axis=-1))  # corresponds to M1
    loss_m1_2 = tf.reduce_sum(
        beta2 * tf.where(tf.equal(y_source, tf.ones_like(y_source)), kl1, kl0))  # corresponds to M1
    loss_m1_3 = tf.reduce_sum(
        beta3 * KL(mu_z, logvar_z, tf.zeros_like(mu_z), tf.zeros_like(logvar_z)))  # corresponds to M1
    loss_m2 = tf.reduce_sum(beta4 * tf.reduce_sum(pred_y * safe_log(pred_y), axis=-1))  # corresponds to M2

    loss_m1 = loss_m1_1 + loss_m1_2 + loss_m1_3
    loss1 = loss_m1 + loss_m2

    # 2. Optimize second loss related to learning the approximate posterior.
    # Use tf.equal (as in loss_m1_2): under TF1 graph semantics `tensor == 1` is a
    # Python identity check that is always False, which made this loss ignore labels.
    loss_n = tf.reduce_sum(beta5 * tf.where(tf.equal(y_source, tf.ones_like(y_source)),
                                            -safe_log(pred_y[:, 1]), -safe_log(pred_y[:, 0])))  # N

    loss2 = loss_n

    # Both optimizers increment global_step, so it advances by 2 per batch.
    optimizer_1 = tf.train.AdamOptimizer(2e-4, beta1=0., beta2=0.9).minimize(loss1, var_list=decoder_x.var_list() +
                                                                                             encoder_w.var_list() +
                                                                                             encoder_z.var_list(),
                                                                             global_step=global_step)
    optimizer_2 = tf.train.AdamOptimizer(2e-4, beta1=0., beta2=0.9).minimize(loss2, var_list=decoder_y.var_list(),
                                                                             global_step=global_step)

    # combine losses for tracking
    loss = loss1 + loss2

    # ============= summary =============
    real_img_sum = tf.summary.image('real_img', x_source)
    fake_recon_img_sum = tf.summary.image('fake_recon_img', pred_x)
    fake_img_sum = tf.summary.image('fake_target_img', fake_target_img)
    fake_img_traversal_sum = tf.summary.image('fake_img_traversal', fake_img_traversal_board)
    image_sums = [real_img_sum, fake_recon_img_sum, fake_img_sum, fake_img_traversal_sum]
    if has_2d_traversal:
        image_sums.append(tf.summary.image('fake_2d_img_traversal', fake_2d_img_traversal_board))

    loss_m1_sum = tf.summary.scalar('losses/M1', loss_m1)
    loss_m1_1_sum = tf.summary.scalar('losses/M1/m1_1', loss_m1_1)
    loss_m1_2_sum = tf.summary.scalar('losses/M1/m1_2', loss_m1_2)
    loss_m1_3_sum = tf.summary.scalar('losses/M1/m1_3', loss_m1_3)
    loss_m2_sum = tf.summary.scalar('losses/M2', loss_m2)
    loss_n_sum = tf.summary.scalar('losses/N', loss_n)
    loss_sum = tf.summary.scalar('losses/total_loss', loss)

    part1_sum = tf.summary.merge(
        [loss_m1_sum, loss_m1_1_sum, loss_m1_2_sum, loss_m1_3_sum, loss_m2_sum])
    part2_sum = tf.summary.merge(
        [loss_n_sum, loss_sum, ])
    overall_sum = tf.summary.merge([loss_sum] + image_sums)

    # ============= session =============
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()

    writer = tf.summary.FileWriter(log_dir, sess.graph)

    # ============= Checkpoints =============
    if continue_train:
        print(" [*] before training, Load checkpoint ")
        print(" [*] Reading checkpoint...")

        ckpt = tf.train.get_checkpoint_state(ckpt_dir_continue)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            saver.restore(sess, os.path.join(ckpt_dir_continue, ckpt_name))
            print(ckpt_dir_continue, ckpt_name)
            print("Successful checkpoint upload")
        else:
            print("Failed checkpoint load")
    else:
        print(" [!] before training, no need to Load ")

    # ============= load pre-trained classifier checkpoint =============
    class_vars = [var for var in slim.get_variables_to_restore() if 'classifier' in var.name]
    name_to_var_map_local = {var.op.name: var for var in class_vars}
    temp_saver = tf.train.Saver(var_list=name_to_var_map_local)
    ckpt = tf.train.get_checkpoint_state(ckpt_dir_cls)
    if not (ckpt and ckpt.model_checkpoint_path):
        raise FileNotFoundError('No classifier checkpoint found in {}'.format(ckpt_dir_cls))
    ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
    temp_saver.restore(sess, os.path.join(ckpt_dir_cls, ckpt_name))
    print("Classifier checkpoint loaded.................")
    print(ckpt_dir_cls, ckpt_name)

    # ============= Training helpers =============
    def load_batch(image_paths):
        """Load the images for ``image_paths`` and one-hot encode their labels."""
        img, labels = my_data_loader.load_images_and_labels(image_paths, image_dir=config['image_dir'], n_class=1,
                                                            file_names_dict=file_names_dict,
                                                            num_channel=channels, do_center_crop=True)
        labels = np.eye(NUMS_CLASS_cls)[labels.ravel().astype(int)]
        return img, labels

    def sample_target_labels():
        """Per row, set one random W dimension to a random bin index in [0, NUMS_CLASS); zeros elsewhere."""
        target_labels_probs = np.random.randint(0, high=NUMS_CLASS, size=BATCH_SIZE)
        target_labels_w_ind = np.random.randint(0, high=w_dim, size=BATCH_SIZE)
        return np.eye(w_dim)[target_labels_w_ind] * np.repeat(
            np.expand_dims(target_labels_probs, axis=-1), w_dim, axis=1)

    def save_results(step, image_paths):
        """Decode the traversal grid(s) for the first BATCH_SIZE images and save them as JPEGs."""
        img, labels = load_batch(image_paths[0:BATCH_SIZE])
        target_labels = sample_target_labels()
        feed = {y_target: target_labels, x_source: img, train_phase: False, y_s: labels}

        fetches = [fake_img_traversal_save]
        if has_2d_traversal:
            fetches.append(fake_2d_img_traversal_save)
        samples = sess.run(fetches, feed_dict=feed)

        # save samples
        save_image(samples[0], os.path.join(sample_dir, '%06d.jpg' % step))
        if has_2d_traversal:
            save_image(samples[1], os.path.join(sample_dir, '%06d_2d.jpg' % step))

    # ============= Training =============
    for _ in range(EPOCHS):
        np.random.shuffle(data)
        for i in range(data.shape[0] // BATCH_SIZE):
            if args.debug:
                image_paths = np.array([str(ind) for ind in my_data_loader.tmp_list])
            else:
                image_paths = data[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]
            img, labels = load_batch(image_paths)
            target_labels = sample_target_labels()

            my_feed_dict = {y_target: target_labels, x_source: img, train_phase: True, y_s: labels}

            _, _, par1_summary_str, overall_sum_str, counter = sess.run(
                [optimizer_1, loss1, part1_sum, overall_sum, global_step], feed_dict=my_feed_dict)
            writer.add_summary(par1_summary_str, global_step=counter)
            writer.add_summary(overall_sum_str, global_step=counter)

            _, _, part2_summary_str, overall_sum_str2, counter = sess.run(
                [optimizer_2, loss2, part2_sum, overall_sum, global_step], feed_dict=my_feed_dict)
            writer.add_summary(part2_summary_str, global_step=counter)
            writer.add_summary(overall_sum_str2, global_step=counter)

            # global_step advanced twice (once per optimizer) for this batch.
            batch_counter = int(counter) // 2
            if batch_counter % save_summary == 0:
                save_results(batch_counter, image_paths)

            if batch_counter % save_ckpt == 0:
                saver.save(sess, ckpt_dir + "/model%2d.ckpt" % batch_counter, global_step=global_step)
Example #2
0
def test(config):
    """Run the trained classifier over the train and test splits and save its outputs.

    Args:
        config: experiment configuration dict. Keys read here: 'log_dir', 'name',
            'batch_size', 'num_channel', 'input_size', 'num_class', 'dataset',
            'image_dir', and either all of 'export_image_label_dict' /
            'export_train' / 'export_test' or 'image_label_dict' / 'train' /
            'test'. An empty 'test' path yields an empty test split.

    Returns:
        ``(train_names, train_prediction_y, train_true_y,
        test_names, test_prediction_y, test_true_y)``. The ``*_true_y`` arrays
        are reshaped to ``(-1, num_class)``.

    Side effects:
        Writes ``name_<split>1.npy``, ``prediction_y_<split>1.npy`` and
        ``true_y_<split>1.npy`` under ``<log_dir>/<name>/classifier_output``.
        Exits the process with status 1 if the label file cannot be read or
        no checkpoint is found in ``<log_dir>/<name>``.

    Raises:
        ValueError: if ``config['dataset']`` is not a supported dataset.
    """
    # ============= Experiment Folder=============
    output_dir = os.path.join(config['log_dir'], config['name'])
    classifier_output_path = os.path.join(output_dir, 'classifier_output')
    os.makedirs(classifier_output_path, exist_ok=True)
    # The classifier checkpoint is read from the experiment folder itself.
    past_checkpoint = output_dir
    # ============= Experiment Parameters =============
    BATCH_SIZE = config['batch_size']
    channels = config['num_channel']
    input_size = config['input_size']
    N_CLASSES = config['num_class']
    dataset = config['dataset']
    # in certain circumstances, for example for when classifier has been trained
    # on re-sampled data, we want to still use the whole dataset for the generative model.
    # That's why we produce classifier's output on the test_image_label_dict
    export_keys = ('export_image_label_dict', 'export_train', 'export_test')
    if all(key in config for key in export_keys):
        image_label_dict = config['export_image_label_dict']
        train_ids = config['export_train']
        test_ids = config['export_test']
    else:
        image_label_dict = config['image_label_dict']
        train_ids = config['train']
        test_ids = config['test']
    if dataset == 'CelebA':
        pretrained_classifier = celeba_classifier
        my_data_loader = ImageLabelLoader()
    elif dataset == 'shapes':
        pretrained_classifier = shapes_classifier
        my_data_loader = ShapesLoader()
    elif dataset in ('CelebA64', 'dermatology', 'synthderm'):
        pretrained_classifier = celeba_classifier
        my_data_loader = ImageLabelLoader(input_size=64)
    else:
        raise ValueError('Unsupported dataset: {!r}'.format(dataset))
    # ============= Data =============
    try:
        categories, file_names_dict = read_data_file(image_label_dict)
    except Exception:
        print("Problem in reading input data file : ", image_label_dict)
        sys.exit(1)  # non-zero status: this is a failure, not a normal exit
    data_train = np.load(train_ids)
    # An empty path (configs may set 'test' to '') means "no test split";
    # np.load('') would raise.
    data_test = np.load(test_ids) if test_ids else np.empty([0])
    print("The classification categories are: ")
    print(categories)
    print('The size of the training set: ', data_train.shape[0])
    print('The size of the testing set: ', data_test.shape[0])

    # ============= placeholder =============
    with tf.name_scope('input'):
        x_ = tf.placeholder(tf.float32,
                            [None, input_size, input_size, channels],
                            name='x-input')
        y_ = tf.placeholder(tf.int64, [None, N_CLASSES], name='y-input')
        isTrain = tf.placeholder(tf.bool)
    # ============= Model =============
    # With a single label column the classifier is built with a 2-way output head.
    # Only predictions are needed here, so no loss ops are built.
    n_label = 2 if N_CLASSES == 1 else N_CLASSES
    _, prediction = pretrained_classifier(x_,
                                          n_label=n_label,
                                          reuse=False,
                                          name='classifier',
                                          isTrain=isTrain)
    # ============= Variables =============
    # Restore every global variable in the graph -- not only weights and biases,
    # but any other global (e.g. non-trainable) state the classifier creates.
    lst_vars = list(tf.global_variables())
    # ============= Session =============
    sess = tf.InteractiveSession()
    saver = tf.train.Saver(var_list=lst_vars)
    tf.global_variables_initializer().run()
    # ============= Load Checkpoint =============
    ckpt = tf.train.get_checkpoint_state(past_checkpoint + '/')
    if not (ckpt and ckpt.model_checkpoint_path):
        print("No classifier checkpoint found in : ", past_checkpoint)
        sys.exit(1)
    print(str(ckpt.model_checkpoint_path))
    saver.restore(sess, tf.train.latest_checkpoint(past_checkpoint + '/'))
    # ============= Testing - Save the Output =============

    def get_predictions(data, subset_name):
        """Predict ``data`` in batches, save names/predictions/labels, and return them.

        NOTE: only full batches are evaluated; the trailing
        ``len(data) % BATCH_SIZE`` items are skipped.
        """
        name_parts, prediction_parts, true_parts = [], [], []

        num_batch = data.shape[0] // BATCH_SIZE
        for i in range(num_batch):
            ns = data[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]
            xs, ys = my_data_loader.load_images_and_labels(
                ns,
                image_dir=config['image_dir'],
                n_class=N_CLASSES,
                file_names_dict=file_names_dict,
                num_channel=channels,
                do_center_crop=True)
            _pred = sess.run(prediction,
                             feed_dict={
                                 x_: xs,
                                 isTrain: False,
                                 y_: ys
                             })
            name_parts.append(np.asarray(ns))
            prediction_parts.append(np.asarray(_pred))
            true_parts.append(np.asarray(ys))

        # Concatenate once instead of np.append per batch (which is quadratic).
        if name_parts:
            names = np.concatenate(name_parts, axis=0)
            prediction_y = np.concatenate(prediction_parts, axis=0)
            true_y = np.concatenate(true_parts, axis=0)
        else:
            names = np.empty([0])
            prediction_y = np.empty([0])
            true_y = np.empty([0])

        np.save(classifier_output_path + '/name_{}1.npy'.format(subset_name),
                names)
        np.save(
            classifier_output_path +
            '/prediction_y_{}1.npy'.format(subset_name), prediction_y)
        np.save(classifier_output_path + '/true_y_{}1.npy'.format(subset_name),
                true_y)
        return names, prediction_y, np.reshape(true_y, [-1, N_CLASSES])

    train_names, train_prediction_y, train_true_y = get_predictions(
        data_train, 'train')
    test_names, test_prediction_y, test_true_y = get_predictions(
        data_test, 'test')

    return train_names, train_prediction_y, train_true_y, test_names, test_prediction_y, test_true_y
Example #3
0
def test(config,
         dbg_img_label_dict=None,
         dbg_mode=False,
         export_output=True,
         dbg_size=10,
         dbg_img_indices=[],
         calc_stability=True):

    # ============= Experiment Folder=============
    assets_dir = os.path.join(config['log_dir'], config['name'])
    log_dir = os.path.join(assets_dir, 'log')
    ckpt_dir = os.path.join(assets_dir, 'ckpt_dir')
    sample_dir = os.path.join(assets_dir, 'sample')

    # Whether this is for saving the results for substitutability metric or the regular testing process.
    # If only for substitutability, we skip saving large arrays and additional multiple random outputs to avoid OOM
    calc_substitutability = config['calc_substitutability']

    if calc_substitutability:
        substitutability_attr = config['substitutability_attr']

        test_dir = os.path.join(assets_dir, 'test', 'substitutability_input')
        substitutability_exported_img_label_dict = os.path.join(
            test_dir, '{}_dims_{}_clss_{}.txt'.format(substitutability_attr,
                                                      config['w_dim'],
                                                      config['num_bins']))
        substitutability_label_scaler = config['num_bins'] - 1
        exported_dict = {}

        substitutability_classifier_config = config[
            'substitutability_classifier_config']
        _cls_config = yaml.load(open(config['classifier_config']))
        substitutability_img_subset = _cls_config['train']
        substitutability_img_label_dict = _cls_config['image_label_dict']
        _edited_cls_config = deepcopy(_cls_config)
        _edited_cls_config['image_dir'] = os.path.join(test_dir, 'images')
        if not os.path.exists(_edited_cls_config['image_dir']):
            os.makedirs(_edited_cls_config['image_dir'])
        _edited_cls_config[
            'image_label_dict'] = substitutability_exported_img_label_dict
        _edited_cls_config['train'] = os.path.join(test_dir, 'train_ids.npy')
        _edited_cls_config['test'] = ''  # skips evaluating on test
        _edited_cls_config['log_dir'] = test_dir
        _edited_cls_config['ckpt_dir_continue'] = ''
        save_config_dict(_edited_cls_config,
                         substitutability_classifier_config)
    else:
        test_dir = os.path.join(assets_dir, 'test')

    # ============= Experiment Parameters =============

    ckpt_dir_cls = config['cls_experiment']
    if 'evaluation_batch_size' in config.keys():
        BATCH_SIZE = config['evaluation_batch_size']
    else:
        BATCH_SIZE = config['batch_size']
    channels = config['num_channel']
    input_size = config['input_size']
    NUMS_CLASS_cls = config['num_class']
    NUMS_CLASS = config['num_bins']
    MU_CLUSTER = config['mu_cluster']
    VAR_CLUSTER = config['var_cluster']
    TRAVERSAL_N_SIGMA = config['traversal_n_sigma']
    STEP_SIZE = 2 * TRAVERSAL_N_SIGMA * VAR_CLUSTER / (NUMS_CLASS - 1)
    OFFSET = MU_CLUSTER - TRAVERSAL_N_SIGMA * VAR_CLUSTER

    metrics_stability_nx = config['metrics_stability_nx']
    metrics_stability_var = config['metrics_stability_var']
    target_class = config['target_class']
    ckpt_dir_continue = ckpt_dir
    if dbg_img_label_dict is not None:
        image_label_dict = dbg_img_label_dict
    elif calc_substitutability:
        image_label_dict = substitutability_img_label_dict
    else:
        image_label_dict = config['image_label_dict']

    # CSVAE parameters
    beta1 = config['beta1']
    beta2 = config['beta2']
    beta3 = config['beta3']
    beta4 = config['beta4']
    beta5 = config['beta5']
    z_dim = config['z_dim']
    w_dim = config['w_dim']

    if dbg_mode:
        num_samples = dbg_size
    else:
        num_samples = config['count_to_save']

    dataset = config['dataset']

    if dataset == 'CelebA':
        my_data_loader = ImageLabelLoader(input_size=128)
        pretrained_classifier = celeba_classifier
        EncoderZ = EncoderZ_128
        EncoderW = EncoderW_128
        DecoderX = DecoderX_128
        DecoderY = DecoderY_128
    elif dataset == 'shapes':
        if calc_substitutability:
            my_data_loader = ShapesLoader()
        else:
            # my_data_loader = ShapesLoader()
            # for efficiency, let's just load as many samples as we need
            my_data_loader = ShapesLoader(
                dbg_mode=True,
                dbg_size=num_samples,
                dbg_image_label_dict=image_label_dict,
                dbg_img_indices=dbg_img_indices)
            dbg_mode = True
        pretrained_classifier = shapes_classifier
        EncoderZ = EncoderZ_64
        EncoderW = EncoderW_64
        DecoderX = DecoderX_64
        DecoderY = DecoderY_64
    elif dataset == 'CelebA64' or dataset == 'dermatology':
        my_data_loader = ImageLabelLoader(input_size=64)
        pretrained_classifier = celeba_classifier
        EncoderZ = EncoderZ_64
        EncoderW = EncoderW_64
        DecoderX = DecoderX_64
        DecoderY = DecoderY_64
    elif dataset == 'synthderm':
        my_data_loader = ImageLabelLoader(input_size=64)
        pretrained_classifier = celeba_classifier
        EncoderZ = EncoderZ_64
        EncoderW = EncoderW_64
        DecoderX = DecoderX_64
        DecoderY = DecoderY_64

    # ============= Data =============
    try:
        categories, file_names_dict = read_data_file(image_label_dict)
    except:
        print("Problem in reading input data file : ", image_label_dict)
        sys.exit()
    if calc_substitutability:
        data = np.load(substitutability_img_subset)
        num_samples = len(data)
    elif dbg_mode and dataset == 'shapes':
        data = np.array([str(ind) for ind in my_data_loader.tmp_list])
    else:
        if len(dbg_img_indices) > 0:
            data = np.asarray(dbg_img_indices)
        else:
            data = np.asarray(list(file_names_dict.keys()))
    print("The classification categories are: ")
    print(categories)
    print('The size of the test set: ', data.shape[0])

    # ============= placeholder =============
    x_source = tf.placeholder(tf.float32,
                              [None, input_size, input_size, channels],
                              name='x_source')
    y_s = tf.placeholder(tf.int32, [None, NUMS_CLASS_cls], name='y_s')
    y_source = y_s[:, NUMS_CLASS_cls - 1]
    train_phase = tf.placeholder(tf.bool, name='train_phase')

    y_target = tf.placeholder(tf.int32, [None, w_dim],
                              name='y_target')  # between 0 and NUMS_CLASS

    generation_dim = w_dim

    # ============= CSVAE =============
    encoder_z = EncoderZ('encoder_z')
    encoder_w = EncoderW('encoder_w')
    decoder_x = DecoderX('decoder_x')
    decoder_y = DecoderY('decoder_y')

    # encode x to get mean, log variance, and samples from the latent subspace Z
    mu_z, logvar_z, z = encoder_z(x_source, z_dim)
    # encode x and y to get mean, log variance, and samples from the latent subspace W
    mu_w, logvar_w, w = encoder_w(x_source, y_source, w_dim)

    # pass samples of z and w to get predictions of x
    pred_x = decoder_x(tf.concat([w, z], axis=-1))
    # get predicted labels based only on the latent subspace Z
    pred_y = decoder_y(z, NUMS_CLASS_cls)

    # Create a single image based on y_target
    target_w = STEP_SIZE * tf.cast(y_target, dtype=tf.float32) + OFFSET
    fake_target_img = decoder_x(tf.concat([target_w, z], axis=-1))

    # ============= pre-trained classifier =============
    real_img_cls_logit_pretrained, real_img_cls_prediction = pretrained_classifier(
        x_source, NUMS_CLASS_cls, reuse=False, name='classifier')
    fake_recon_cls_logit_pretrained, fake_recon_cls_prediction = pretrained_classifier(
        pred_x, NUMS_CLASS_cls, reuse=True)
    fake_img_cls_logit_pretrained, fake_img_cls_prediction = pretrained_classifier(
        fake_target_img, NUMS_CLASS_cls, reuse=True)

    # ============= predicted probabilities =============
    fake_target_p_tensor = tf.reduce_max(tf.cast(y_target, tf.float32) * 1.0 /
                                         float(NUMS_CLASS - 1),
                                         axis=1)

    # ============= session =============
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()

    # ============= Checkpoints =============
    print(" [*] Reading checkpoint...")

    ckpt = tf.train.get_checkpoint_state(ckpt_dir_continue)
    if ckpt and ckpt.model_checkpoint_path:
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        saver.restore(sess, os.path.join(ckpt_dir_continue, ckpt_name))
        print(ckpt_dir_continue, ckpt_name)
        print("Successful checkpoint upload")
    else:
        print("Failed checkpoint load")
        sys.exit()

    # ============= load pre-trained classifier checkpoint =============
    class_vars = [
        var for var in slim.get_variables_to_restore()
        if 'classifier' in var.name
    ]
    name_to_var_map_local = {var.op.name: var for var in class_vars}
    temp_saver = tf.train.Saver(var_list=name_to_var_map_local)
    ckpt = tf.train.get_checkpoint_state(ckpt_dir_cls)
    ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
    temp_saver.restore(sess, os.path.join(ckpt_dir_cls, ckpt_name))
    print("Classifier checkpoint loaded.................")
    print(ckpt_dir_cls, ckpt_name)

    # ============= Testing =============
    def _save_output_array(name, values):
        np.save(os.path.join(test_dir, '{}.npy'.format(name)), values)

    if not calc_substitutability:
        names = np.empty([num_samples], dtype=object)
        real_imgs = np.empty([num_samples, input_size, input_size, channels])
        fake_t_imgs = np.empty([
            num_samples, generation_dim, NUMS_CLASS, input_size, input_size,
            channels
        ])
        fake_s_recon_imgs = np.empty([
            num_samples, generation_dim, NUMS_CLASS, input_size, input_size,
            channels
        ])
        real_ps = np.empty(
            [num_samples, generation_dim, NUMS_CLASS, NUMS_CLASS_cls])
        recon_ps = np.empty(
            [num_samples, generation_dim, NUMS_CLASS, NUMS_CLASS_cls])
        fake_target_ps = np.empty([num_samples, generation_dim, NUMS_CLASS])
        fake_ps = np.empty(
            [num_samples, generation_dim, NUMS_CLASS, NUMS_CLASS_cls])

        # For stability metric
        stability_fake_t_imgs = np.empty([
            num_samples, metrics_stability_nx, generation_dim, NUMS_CLASS,
            input_size, input_size, channels
        ])
        stability_fake_s_recon_imgs = np.empty([
            num_samples, metrics_stability_nx, generation_dim, NUMS_CLASS,
            input_size, input_size, channels
        ])
        stability_recon_ps = np.empty([
            num_samples, metrics_stability_nx, generation_dim, NUMS_CLASS,
            NUMS_CLASS_cls
        ])
        stability_fake_ps = np.empty([
            num_samples, metrics_stability_nx, generation_dim, NUMS_CLASS,
            NUMS_CLASS_cls
        ])

        arrs_to_save = [
            'names', 'real_imgs', 'fake_t_imgs', 'fake_s_recon_imgs',
            'real_ps', 'recon_ps', 'fake_target_ps', 'fake_ps',
            'stability_fake_t_imgs', 'stability_fake_s_recon_imgs',
            'stability_recon_ps', 'stability_fake_ps'
        ]

    np.random.shuffle(data)

    data = data[0:num_samples]
    for i in range(math.ceil(data.shape[0] / BATCH_SIZE)):
        image_paths = data[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]
        # num_seed_imgs is either BATCH_SIZE
        # or if the number of samples is not divisible by BATCH_SIZE a smaller value
        num_seed_imgs = np.shape(image_paths)[0]
        img, _labels = my_data_loader.load_images_and_labels(
            image_paths,
            config['image_dir'],
            1,
            file_names_dict,
            channels,
            do_center_crop=True)
        img_repeat = np.repeat(img, NUMS_CLASS * generation_dim, 0)

        labels = np.repeat(_labels, NUMS_CLASS * generation_dim, 0)
        labels = labels.ravel()
        labels = np.eye(NUMS_CLASS_cls)[labels.astype(int)]

        _dim_bin_arr = np.zeros((generation_dim * NUMS_CLASS, generation_dim))
        for _gen_dim in range(generation_dim):
            _start = _gen_dim * NUMS_CLASS
            _end = (_gen_dim + 1) * NUMS_CLASS
            _dim_bin_arr_sub = np.zeros((NUMS_CLASS, generation_dim))
            _dim_bin_arr_sub[:, _gen_dim] = np.asarray(range(NUMS_CLASS))
            _dim_bin_arr[_start:_end, :] = _dim_bin_arr_sub
        target_labels = np.tile(
            _dim_bin_arr,
            (num_seed_imgs, 1))  # [num_seed_imgs * w_dim * NUMS_CLASS, w_dim]
        # target_labels = np.tile(
        #     np.repeat(np.expand_dims(np.asarray(range(NUMS_CLASS)), axis=1), generation_dim, axis=1),
        #     (num_seed_imgs*generation_dim, 1))  # [num_seed_imgs * w_dim * NUMS_CLASS, w_dim]

        my_feed_dict = {
            y_target: target_labels,
            x_source: img_repeat,
            train_phase: False,
            y_s: labels
        }

        fake_t_img, fake_s_recon_img, real_p, recon_p, fake_target_p, fake_p = sess.run(
            [
                fake_target_img, pred_x, real_img_cls_prediction,
                fake_recon_cls_prediction, fake_target_p_tensor,
                fake_img_cls_prediction
            ],
            feed_dict=my_feed_dict)

        print('{} / {}'.format(i + 1, math.ceil(data.shape[0] / BATCH_SIZE)))

        _num_cur_samples = len(image_paths)

        if calc_substitutability:
            _ind_generation_dim = np.random.randint(low=0,
                                                    high=generation_dim,
                                                    size=_num_cur_samples)
            reshaped_imgs = np.reshape(
                fake_t_img, (_num_cur_samples, generation_dim, NUMS_CLASS,
                             input_size, input_size, channels))
            sub_exported_dict = save_batch_images(
                reshaped_imgs,
                image_paths,
                _ind_generation_dim,
                _labels,
                substitutability_label_scaler,
                _edited_cls_config['image_dir'],
                has_extension=(dataset != 'shapes'))
            exported_dict.update(sub_exported_dict)
        else:
            start_ind = i * BATCH_SIZE
            end_ind = start_ind + _num_cur_samples
            names[start_ind:end_ind] = np.asarray(image_paths)

            if calc_stability:
                for j in range(metrics_stability_nx):
                    noisy_img = img + np.random.normal(
                        loc=0.0,
                        scale=metrics_stability_var,
                        size=np.shape(img))
                    stability_img_repeat = np.repeat(
                        noisy_img, NUMS_CLASS * generation_dim, 0)
                    stability_feed_dict = {
                        y_target: target_labels,
                        x_source: stability_img_repeat,
                        train_phase: False,
                        y_s: labels
                    }
                    _stability_fake_t_img, _stability_fake_s_recon_img, _stability_recon_p, _stability_fake_p = sess.run(
                        [
                            fake_target_img, pred_x, fake_recon_cls_prediction,
                            fake_img_cls_prediction
                        ],
                        feed_dict=stability_feed_dict)

                    stability_fake_t_imgs[start_ind:end_ind, j] = np.reshape(
                        _stability_fake_t_img,
                        (_num_cur_samples, generation_dim, NUMS_CLASS,
                         input_size, input_size, channels))
                    stability_fake_s_recon_imgs[
                        start_ind:end_ind, j] = np.reshape(
                            _stability_fake_s_recon_img,
                            (_num_cur_samples, generation_dim, NUMS_CLASS,
                             input_size, input_size, channels))
                    stability_recon_ps[start_ind:end_ind, j] = np.reshape(
                        _stability_recon_p, (_num_cur_samples, generation_dim,
                                             NUMS_CLASS, NUMS_CLASS_cls))
                    stability_fake_ps[start_ind:end_ind, j] = np.reshape(
                        _stability_fake_p, (_num_cur_samples, generation_dim,
                                            NUMS_CLASS, NUMS_CLASS_cls))

            real_imgs[start_ind:end_ind] = img
            fake_t_imgs[start_ind:end_ind] = np.reshape(
                fake_t_img, (_num_cur_samples, generation_dim, NUMS_CLASS,
                             input_size, input_size, channels))
            fake_s_recon_imgs[start_ind:end_ind] = np.reshape(
                fake_s_recon_img,
                (_num_cur_samples, generation_dim, NUMS_CLASS, input_size,
                 input_size, channels))
            real_ps[start_ind:end_ind] = np.reshape(
                real_p,
                (_num_cur_samples, generation_dim, NUMS_CLASS, NUMS_CLASS_cls))
            recon_ps[start_ind:end_ind] = np.reshape(
                recon_p,
                (_num_cur_samples, generation_dim, NUMS_CLASS, NUMS_CLASS_cls))
            fake_target_ps[start_ind:end_ind] = np.reshape(
                fake_target_p, (_num_cur_samples, generation_dim, NUMS_CLASS))
            fake_ps[start_ind:end_ind] = np.reshape(
                fake_p,
                (_num_cur_samples, generation_dim, NUMS_CLASS, NUMS_CLASS_cls))

    output_dict = {}
    if calc_substitutability:
        save_dict(exported_dict, substitutability_exported_img_label_dict,
                  substitutability_attr)
        np.save(_edited_cls_config['train'],
                np.asarray(list(exported_dict.keys())))

        # retrain the classifier with the new generated images
        tf.reset_default_graph()
        train_classif(config['substitutability_classifier_config'])
    else:
        if export_output:
            for arr_name in arrs_to_save:
                _save_output_array(arr_name, eval(arr_name))

        for arr_name in arrs_to_save:
            output_dict.update({arr_name: eval(arr_name)})

    return output_dict
def train():
    """Train the conditional (ordinal) counterfactual GAN described by a YAML config.

    Command-line arguments (parsed from ``sys.argv``):
        --config / -c: path to the YAML experiment configuration.
        --debug / -d: for the ``shapes`` dataset, build ``ShapesLoader`` in
            debug mode and feed ``my_data_loader.tmp_list`` as the batch on
            every step.

    The TF1 graph contains:
        * a generator ``G``,
        * an ordinal discriminator ``D``,
        * a contrastive disentangler ``R`` (only when ``k_dim > 1``),
        * a pre-trained classifier whose weights are restored from
          ``config['cls_experiment']`` and which is not in any optimizer's
          ``var_list``.

    An earlier GAN checkpoint is restored from
    ``<config['ckpt_dir_continue']>/ckpt_dir`` when that key is non-empty.
    Training then runs for ``config['epochs']`` epochs. Summaries go to
    ``<log_dir>/<name>/log``, sample grids to ``.../sample`` and checkpoints
    to ``.../ckpt_dir``.

    Raises:
        ValueError: if ``config['dataset']`` is not a supported dataset.

    Exits the process with status 1 if the image/label file or the
    classifier checkpoint cannot be read.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-c', type=str)
    parser.add_argument('--debug', '-d', action='store_true')
    args = parser.parse_args()

    # ============= Load config =============
    config_path = args.config
    # safe_load: the config is plain YAML. PyYAML >= 6 rejects yaml.load()
    # without an explicit Loader, and the old call also leaked the handle.
    with open(config_path) as config_file:
        config = yaml.safe_load(config_file)
    print(config)

    # ============= Experiment Folder =============
    assets_dir = os.path.join(config['log_dir'], config['name'])
    log_dir = os.path.join(assets_dir, 'log')
    ckpt_dir = os.path.join(assets_dir, 'ckpt_dir')
    sample_dir = os.path.join(assets_dir, 'sample')
    test_dir = os.path.join(assets_dir, 'test')
    # exist_ok only tolerates "already exists". The previous bare
    # `except: pass` also hid real failures such as permission errors.
    for directory in (log_dir, ckpt_dir, sample_dir, test_dir):
        os.makedirs(directory, exist_ok=True)

    # ============= Experiment Parameters =============
    ckpt_dir_cls = config['cls_experiment']
    BATCH_SIZE = config['batch_size']
    EPOCHS = config['epochs']
    channels = config['num_channel']
    input_size = config['input_size']
    NUMS_CLASS_cls = config['num_class']  # output classes of the pre-trained classifier
    NUMS_CLASS = config['num_bins']  # number of ordinal bins G is conditioned on
    target_class = config['target_class']
    lambda_GAN = config['lambda_GAN']
    lambda_cyc = config['lambda_cyc']
    lambda_cls = config['lambda_cls']
    save_summary = int(config['save_summary'])
    save_ckpt = int(config['save_ckpt'])
    ckpt_dir_continue = config['ckpt_dir_continue']
    k_dim = config['k_dim']
    lambda_r = config['lambda_r']
    # With more than one latent direction, the contrastive disentangler R is built.
    disentangle = k_dim > 1
    discriminate_every_nth = config['discriminate_every_nth']
    generate_every_nth = config['generate_every_nth']
    dataset = config['dataset']
    if dataset == 'CelebA':
        pretrained_classifier = celeba_classifier
        my_data_loader = ImageLabelLoader()
        Discriminator_Ordinal = Discriminator_Ordinal_128
        Generator_Encoder_Decoder = Generator_Encoder_Decoder_128
        Discriminator_Contrastive = Discriminator_Contrastive_128
    elif dataset == 'shapes':
        pretrained_classifier = shapes_classifier
        if args.debug:
            my_data_loader = ShapesLoader(
                dbg_mode=True,
                dbg_size=config['batch_size'],
                dbg_image_label_dict=config['image_label_dict'])
        else:
            my_data_loader = ShapesLoader()
        Discriminator_Ordinal = Discriminator_Ordinal_64
        Generator_Encoder_Decoder = Generator_Encoder_Decoder_64
        Discriminator_Contrastive = Discriminator_Contrastive_64
    elif dataset in ('CelebA64', 'dermatology', 'synthderm'):
        pretrained_classifier = celeba_classifier
        my_data_loader = ImageLabelLoader(input_size=64)
        Discriminator_Ordinal = Discriminator_Ordinal_64
        Generator_Encoder_Decoder = Generator_Encoder_Decoder_64
        Discriminator_Contrastive = Discriminator_Contrastive_64
    else:
        # Without this branch an unknown dataset surfaced much later as an
        # UnboundLocalError on `pretrained_classifier`.
        raise ValueError('Unsupported dataset: {!r}'.format(dataset))

    continue_train = ckpt_dir_continue != ''
    if continue_train:
        ckpt_dir_continue = os.path.join(ckpt_dir_continue, 'ckpt_dir')

    global_step = tf.Variable(0,
                              dtype=tf.int32,
                              trainable=False,
                              name='global_step')

    # ============= Data =============
    try:
        categories, file_names_dict = read_data_file(
            config['image_label_dict'])
    except Exception:
        print("Problem in reading input data file : ",
              config['image_label_dict'])
        sys.exit(1)
    data = np.asarray(list(file_names_dict.keys()))
    print("The classification categories are: ")
    print(categories)
    print('The size of the training set: ', data.shape[0])
    with open(os.path.join(log_dir, 'setting.txt'), 'w') as fp:
        fp.write('config_file:' + str(config_path) + '\n')

    # ============= placeholder =============
    x_source = tf.placeholder(tf.float32,
                              [None, input_size, input_size, channels],
                              name='x_source')
    # Labels are fed as produced by convert_ordinal_to_binary. Column 0 is used
    # as the conditioning value. NOTE(review): presumably this is the ordinal
    # bin index (real_p below divides it by NUMS_CLASS - 1); confirm.
    y_s = tf.placeholder(tf.int32, [None, NUMS_CLASS], name='y_s')
    y_source = y_s[:, 0]
    train_phase = tf.placeholder(tf.bool, name='train_phase')

    y_t = tf.placeholder(tf.int32, [None, NUMS_CLASS], name='y_t')
    y_target = y_t[:, 0]

    if disentangle:
        # Index of the latent direction (0..k_dim-1) used for each sample.
        y_regularizer = tf.placeholder(tf.int32, [None], name='y_regularizer')
        # Contrastive target for R. Pairs that should show no change use zeros.
        y_r = tf.placeholder(tf.float32, [None, k_dim], name='y_r')
        y_r_0 = tf.zeros_like(y_r, name='y_r_0')

    # ============= G & D =============
    G = Generator_Encoder_Decoder(
        "generator")  # with conditional BN, SAGAN: SN here as well
    D = Discriminator_Ordinal("discriminator")  # with SN and projection

    real_source_logits = D(x_source, y_s, NUMS_CLASS, "NO_OPS")
    if disentangle:
        # G's condition combines direction and bin into one index in
        # [0, NUMS_CLASS * k_dim).
        fake_target_img, fake_target_img_embedding = G(
            x_source, y_regularizer * NUMS_CLASS + y_target,
            NUMS_CLASS * k_dim)
        fake_source_img, fake_source_img_embedding = G(
            fake_target_img, y_regularizer * NUMS_CLASS + y_source,
            NUMS_CLASS * k_dim)
        fake_source_recons_img, x_source_img_embedding = G(
            x_source, y_regularizer * NUMS_CLASS + y_source,
            NUMS_CLASS * k_dim)
    else:
        fake_target_img, fake_target_img_embedding = G(x_source, y_target,
                                                       NUMS_CLASS)
        fake_source_img, fake_source_img_embedding = G(fake_target_img,
                                                       y_source, NUMS_CLASS)
        fake_source_recons_img, x_source_img_embedding = G(
            x_source, y_source, NUMS_CLASS)
    fake_target_logits = D(fake_target_img, y_t, NUMS_CLASS, None)

    # ============= pre-trained classifier =============
    real_img_cls_logit_pretrained, real_img_cls_prediction = pretrained_classifier(
        x_source, NUMS_CLASS_cls, reuse=False, name='classifier')
    fake_img_cls_logit_pretrained, fake_img_cls_prediction = pretrained_classifier(
        fake_target_img, NUMS_CLASS_cls, reuse=True)
    real_img_recons_cls_logit_pretrained, real_img_recons_cls_prediction = pretrained_classifier(
        fake_source_img, NUMS_CLASS_cls, reuse=True)

    # ============= pre-trained classifier loss =============
    # Cross-entropy between the desired probability for the target bin,
    # y_target / (NUMS_CLASS - 1), and the classifier's probability for
    # target_class on the generated image.
    real_p = tf.cast(y_target, tf.float32) * 1.0 / float(NUMS_CLASS - 1)
    fake_q = fake_img_cls_prediction[:, target_class]
    fake_evaluation = (real_p * safe_log(fake_q)) + (
        (1 - real_p) * safe_log(1 - fake_q))
    fake_evaluation = -tf.reduce_mean(fake_evaluation)

    # The cycle-reconstructed image should get the same classifier probability
    # as the real input.
    recons_evaluation = (real_img_cls_prediction[:, target_class] * safe_log(
        real_img_recons_cls_prediction[:, target_class])) + (
            (1 - real_img_cls_prediction[:, target_class]) *
            safe_log(1 - real_img_recons_cls_prediction[:, target_class]))
    recons_evaluation = -tf.reduce_mean(recons_evaluation)

    # ============= regularizer contrastive discriminator loss =============
    if disentangle:
        R = Discriminator_Contrastive("disentangler")

        # R sees channel-concatenated image pairs and predicts which of the
        # k_dim directions separates them.
        regularizer_fake_target_v_source_logits = R(
            tf.concat([x_source, fake_target_img], axis=-1), k_dim)
        regularizer_fake_source_v_target_logits = R(
            tf.concat([fake_target_img, fake_source_img], axis=-1), k_dim)
        regularizer_fake_source_v_source_logits = R(
            tf.concat([x_source, fake_source_img], axis=-1), k_dim)
        regularizer_fake_source_recon_v_source_logits = R(
            tf.concat([x_source, fake_source_recons_img], axis=-1), k_dim)

    # ============= Loss =============
    D_loss_GAN, D_acc, D_precision, D_recall = discriminator_loss(
        'hinge', real_source_logits, fake_target_logits)
    G_loss_GAN = generator_loss('hinge', fake_target_logits)
    G_loss_cyc = l1_loss(x_source, fake_source_img)
    G_loss_rec = l1_loss(
        x_source, fake_source_recons_img
    )  #+l2_loss(x_source_img_embedding, fake_source_img_embedding)
    D_loss = (D_loss_GAN * lambda_GAN)
    D_opt = tf.train.AdamOptimizer(2e-4, beta1=0.,
                                   beta2=0.9).minimize(D_loss,
                                                       var_list=D.var_list(),
                                                       global_step=global_step)

    if disentangle:
        R_fake_target_v_source_loss, R_fake_target_v_source_acc = contrastive_regularizer_loss(
            regularizer_fake_target_v_source_logits, y_r)
        R_fake_source_v_target_loss, R_fake_source_v_target_acc = contrastive_regularizer_loss(
            regularizer_fake_source_v_target_logits, y_r)
        R_fake_source_v_source_loss, R_fake_source_v_source_acc = contrastive_regularizer_loss(
            regularizer_fake_source_v_source_logits, y_r_0)
        R_fake_source_recon_v_source_loss, R_fake_source_recon_v_source_acc = contrastive_regularizer_loss(
            regularizer_fake_source_recon_v_source_logits, y_r_0)
        R_loss = R_fake_target_v_source_loss + R_fake_source_v_target_loss + R_fake_source_v_source_loss + R_fake_source_recon_v_source_loss
        # R_opt is built but not run by the training loop (see the
        # commented-out call there). R is trained jointly through G_opt.
        R_opt = tf.train.AdamOptimizer(2e-4, beta1=0., beta2=0.9).minimize(
            R_loss * lambda_r, var_list=R.var_list(), global_step=global_step)
        G_loss = (G_loss_GAN * lambda_GAN) + (G_loss_rec * lambda_cyc) + (
            G_loss_cyc * lambda_cyc) + (fake_evaluation * lambda_cls) + (
                recons_evaluation * lambda_cls) + (R_loss * lambda_r)
        G_opt = tf.train.AdamOptimizer(2e-4, beta1=0., beta2=0.9).minimize(
            G_loss,
            var_list=G.var_list() + R.var_list(),
            global_step=global_step)
    else:
        G_loss = (G_loss_GAN * lambda_GAN) + (G_loss_rec * lambda_cyc) + (
            G_loss_cyc * lambda_cyc) + (fake_evaluation * lambda_cls) + (
                recons_evaluation * lambda_cls)
        G_opt = tf.train.AdamOptimizer(2e-4, beta1=0., beta2=0.9).minimize(
            G_loss, var_list=G.var_list(), global_step=global_step)

    # ============= summary =============
    real_img_sum = tf.summary.image('real_img', x_source)
    fake_img_sum = tf.summary.image('fake_target_img', fake_target_img)
    fake_source_img_sum = tf.summary.image('fake_source_img', fake_source_img)
    fake_source_recons_img_sum = tf.summary.image('fake_source_recons_img',
                                                  fake_source_recons_img)

    acc_d = tf.summary.scalar('discriminator/acc_d', D_acc)
    precision_d = tf.summary.scalar('discriminator/precision_d', D_precision)
    recall_d = tf.summary.scalar('discriminator/recall_d', D_recall)
    loss_d_sum = tf.summary.scalar('discriminator/loss_d', D_loss)
    loss_d_GAN_sum = tf.summary.scalar('discriminator/loss_d_GAN', D_loss_GAN)

    loss_g_sum = tf.summary.scalar('generator/loss_g', G_loss)
    loss_g_GAN_sum = tf.summary.scalar('generator/loss_g_GAN', G_loss_GAN)
    loss_g_cyc_sum = tf.summary.scalar('generator/G_loss_cyc', G_loss_cyc)
    G_loss_rec_sum = tf.summary.scalar('generator/G_loss_rec', G_loss_rec)

    evaluation_fake = tf.summary.scalar('generator/fake_evaluation',
                                        fake_evaluation)
    evaluation_recons = tf.summary.scalar('generator/recons_evaluation',
                                          recons_evaluation)
    g_sum = tf.summary.merge([
        loss_g_sum, loss_g_GAN_sum, loss_g_cyc_sum, real_img_sum,
        G_loss_rec_sum, fake_img_sum, fake_source_img_sum,
        fake_source_recons_img_sum, evaluation_fake, evaluation_recons
    ])
    d_sum = tf.summary.merge(
        [loss_d_sum, loss_d_GAN_sum, acc_d, precision_d, recall_d])
    # Disentangler Contrastive Regularizer losses
    if disentangle:
        loss_r_fake_target_v_source = tf.summary.scalar(
            'disentangler/loss_r_fake_target_v_source',
            R_fake_target_v_source_loss)
        loss_r_fake_source_v_target = tf.summary.scalar(
            'disentangler/loss_r_fake_source_v_target',
            R_fake_source_v_target_loss)
        loss_r_fake_source_v_source = tf.summary.scalar(
            'disentangler/loss_r_fake_source_v_source',
            R_fake_source_v_source_loss)
        loss_r_fake_source_recon_v_source = tf.summary.scalar(
            'disentangler/loss_r_fake_source_recon_v_source',
            R_fake_source_recon_v_source_loss)
        loss_r_sum = tf.summary.scalar('disentangler/loss_r', R_loss)

        acc_r_fake_target_v_source = tf.summary.scalar(
            'disentangler/acc_r_fake_target_v_source',
            R_fake_target_v_source_acc)
        acc_r_fake_source_v_target = tf.summary.scalar(
            'disentangler/acc_r_fake_source_v_target',
            R_fake_source_v_target_acc)
        acc_r_fake_source_v_source = tf.summary.scalar(
            'disentangler/acc_r_fake_source_v_source',
            R_fake_source_v_source_acc)
        acc_r_fake_source_recon_v_source = tf.summary.scalar(
            'disentangler/acc_r_fake_source_recon_v_source',
            R_fake_source_recon_v_source_acc)
        r_sum = tf.summary.merge([
            loss_r_sum, loss_r_fake_target_v_source,
            loss_r_fake_source_v_target, loss_r_fake_source_v_source,
            loss_r_fake_source_recon_v_source, acc_r_fake_target_v_source,
            acc_r_fake_source_v_target, acc_r_fake_source_v_source,
            acc_r_fake_source_recon_v_source
        ])

    # ============= session =============
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()

    writer = tf.summary.FileWriter(log_dir, sess.graph)

    # ============= Checkpoints =============
    if continue_train:
        print(" [*] before training, Load checkpoint ")
        print(" [*] Reading checkpoint...")

        ckpt = tf.train.get_checkpoint_state(ckpt_dir_continue)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            saver.restore(sess, os.path.join(ckpt_dir_continue, ckpt_name))
            print(ckpt_dir_continue, ckpt_name)
            print("Successful checkpoint upload")
        else:
            # Best effort: training continues from freshly initialised weights.
            print("Failed checkpoint load")
    else:
        print(" [!] before training, no need to Load ")

    # ============= load pre-trained classifier checkpoint =============
    class_vars = [
        var for var in slim.get_variables_to_restore()
        if 'classifier' in var.name
    ]
    name_to_var_map_local = {var.op.name: var for var in class_vars}
    temp_saver = tf.train.Saver(var_list=name_to_var_map_local)
    ckpt = tf.train.get_checkpoint_state(ckpt_dir_cls)
    if not (ckpt and ckpt.model_checkpoint_path):
        # Without this check a missing checkpoint crashed on `None.model_checkpoint_path`.
        print("Failed to load classifier checkpoint from: ", ckpt_dir_cls)
        sys.exit(1)
    ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
    temp_saver.restore(sess, os.path.join(ckpt_dir_cls, ckpt_name))
    print("Classifier checkpoint loaded.................")
    print(ckpt_dir_cls, ckpt_name)

    def save_results(step, seed_paths, max_seed_imgs=8):
        """Write a grid of generated images for the first images of a batch.

        For every seed image, one image is generated per (direction, target
        bin) pair, i.e. ``k_dim * NUMS_CLASS`` images per seed. The grid is
        written to ``sample_dir/<step>.jpg`` and the encoded source labels to
        ``<step>_y.npy``.
        """
        seed_imgs, seed_labels = my_data_loader.load_images_and_labels(
            seed_paths[0:max_seed_imgs],
            image_dir=config['image_dir'],
            n_class=1,
            file_names_dict=file_names_dict,
            num_channel=channels,
            do_center_crop=True)
        # Count the images actually loaded: the batch may hold fewer than
        # max_seed_imgs, and the target grid must match the image count.
        num_seed_imgs = np.shape(seed_imgs)[0]
        grid_size = NUMS_CLASS * k_dim
        img_repeat = np.repeat(seed_imgs, grid_size, 0)
        source_ordinals = np.repeat(seed_labels, grid_size, 0).ravel()
        # Each seed sweeps every bin 0..NUMS_CLASS-1 once per direction.
        target_ordinals = np.tile(np.arange(NUMS_CLASS), num_seed_imgs * k_dim)
        # Compare the *ordinal* values before the binary encoding below, as the
        # training loop does. Comparing the 2-D encoded labels with 1-D
        # targets does not broadcast.
        identity_ind = source_ordinals == target_ordinals
        labels = convert_ordinal_to_binary(source_ordinals, NUMS_CLASS)
        target_labels = convert_ordinal_to_binary(target_ordinals, NUMS_CLASS)

        feed_dict = {
            y_t: target_labels,
            x_source: img_repeat,
            train_phase: False,
            y_s: labels
        }
        if disentangle:
            # Direction index per generated image, matching the grid layout above.
            target_disentangle_ind = np.tile(
                np.repeat(np.arange(k_dim), NUMS_CLASS), num_seed_imgs)
            target_disentangle_ind_one_hot = np.eye(
                k_dim)[target_disentangle_ind]
            # Rows whose target bin equals the source bin get an all-zero target.
            target_disentangle_ind_one_hot[identity_ind, :] = 0
            feed_dict[y_regularizer] = target_disentangle_ind
            feed_dict[y_r] = target_disentangle_ind_one_hot

        fake_imgs = sess.run(fake_target_img, feed_dict=feed_dict)
        output_fake_img = np.reshape(
            fake_imgs,
            [-1, k_dim, NUMS_CLASS, input_size, input_size, channels])

        # save samples
        sample_file = os.path.join(sample_dir, '%06d.jpg' % step)
        save_images(output_fake_img,
                    sample_file,
                    num_samples=num_seed_imgs,
                    nums_class=NUMS_CLASS,
                    k_dim=k_dim,
                    image_size=input_size,
                    num_channel=channels)
        np.save(os.path.splitext(sample_file)[0] + '_y.npy', labels)

    # ============= Training =============
    # Seed `counter` from the (possibly restored) global step. Before, it was
    # only assigned when an optimizer ran. If neither D nor G ran on the first
    # batch, the save logic below hit an unbound local.
    counter = sess.run(global_step)
    last_handled_step = None
    for e in range(EPOCHS):
        np.random.shuffle(data)
        for i in range(data.shape[0] // BATCH_SIZE):
            if args.debug:
                image_paths = np.array(
                    [str(ind) for ind in my_data_loader.tmp_list])
            else:
                image_paths = data[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]
            img, labels = my_data_loader.load_images_and_labels(
                image_paths,
                image_dir=config['image_dir'],
                n_class=1,
                file_names_dict=file_names_dict,
                num_channel=channels,
                do_center_crop=True)

            labels = labels.ravel()
            num_imgs = labels.shape[0]
            # Random target bin per image.
            target_labels = np.random.randint(0,
                                              high=NUMS_CLASS,
                                              size=num_imgs)

            # Images whose target bin equals their source bin (computed on
            # ordinal values, before encoding).
            identity_ind = labels == target_labels

            labels = convert_ordinal_to_binary(labels, NUMS_CLASS)
            target_labels = convert_ordinal_to_binary(target_labels,
                                                      NUMS_CLASS)

            my_feed_dict = {
                y_t: target_labels,
                x_source: img,
                train_phase: True,
                y_s: labels
            }
            if disentangle:
                target_disentangle_ind = np.random.randint(0,
                                                           high=k_dim,
                                                           size=num_imgs)
                target_disentangle_ind_one_hot = np.eye(
                    k_dim)[target_disentangle_ind]
                target_disentangle_ind_one_hot[identity_ind, :] = 0
                my_feed_dict[y_regularizer] = target_disentangle_ind
                my_feed_dict[y_r] = target_disentangle_ind_one_hot

            if (i + 1) % discriminate_every_nth == 0:
                _, d_loss, summary_str, counter = sess.run(
                    [D_opt, D_loss, d_sum, global_step],
                    feed_dict=my_feed_dict)
                writer.add_summary(summary_str, counter)

            if (i + 1) % generate_every_nth == 0:
                if disentangle:
                    _, g_loss, g_summary_str, r_loss, r_summary_str, counter = sess.run(
                        [G_opt, G_loss, g_sum, R_loss, r_sum, global_step],
                        feed_dict=my_feed_dict)
                    # _, r_loss, r_summary_str = sess.run([R_opt, R_loss, r_sum], feed_dict=my_feed_dict)
                    writer.add_summary(r_summary_str, counter)
                else:
                    _, g_loss, g_summary_str, counter = sess.run(
                        [G_opt, G_loss, g_sum, global_step],
                        feed_dict=my_feed_dict)
                writer.add_summary(g_summary_str, counter)

            # NOTE(review): dividing by 3 assumes three global_step increments
            # per batch, but only D_opt and G_opt increment it in this loop
            # (R_opt is never run). Confirm the intended scaling.
            approx_num_seen_batches = int(counter) // 3
            # The approximate step often does not advance between batches.
            # Only act the first time a value is seen, to avoid writing
            # duplicate samples and checkpoints.
            if approx_num_seen_batches != last_handled_step:
                last_handled_step = approx_num_seen_batches
                if approx_num_seen_batches % save_summary == 0:
                    save_results(approx_num_seen_batches, image_paths)

                if approx_num_seen_batches % save_ckpt == 0:
                    saver.save(sess,
                               ckpt_dir +
                               "/model%2d.ckpt" % approx_num_seen_batches,
                               global_step=global_step)

    # Flush pending summaries and release the session's resources.
    writer.close()
    sess.close()
def train(config_path, overwrite_output_dir=None):
    """Train (or continue training) the image classifier described by a config.

    The classifier function and data loader are selected from
    ``config['dataset']``. Every epoch shuffles the ids in ``config['train']``,
    trains on each full ``batch_size`` batch, optionally evaluates on
    ``config['test']`` (skipped when it is ``''``), saves a checkpoint
    ``<output_dir>/cp1_epoch<epoch>.ckpt`` and writes the per-epoch mean
    losses to ``<output_dir>/logs/{train,test}_loss.npy``.

    Args:
        config_path: path to the YAML experiment configuration.
        overwrite_output_dir: if not None, used instead of
            ``os.path.join(config['log_dir'], config['name'])`` as output dir.

    Raises:
        ValueError: if ``config['dataset']`` is not a supported dataset.
        SystemExit: (status 1) if the image/label file or the continuation
            checkpoint cannot be read.
    """
    # yaml.load() without an explicit Loader is unsafe and raises TypeError on
    # PyYAML >= 6; safe_load is sufficient for plain configuration files.
    with open(config_path) as config_file:
        config = yaml.safe_load(config_file)
    print(config)

    # ============= Experiment Folder =============
    if overwrite_output_dir is not None:
        output_dir = overwrite_output_dir
    else:
        output_dir = os.path.join(config['log_dir'], config['name'])
    # Also creates output_dir itself. exist_ok replaces a bare try/except that
    # silently swallowed real failures such as permission errors.
    os.makedirs(os.path.join(output_dir, 'logs'), exist_ok=True)

    # ============= Experiment Parameters =============
    BATCH_SIZE = config['batch_size']
    EPOCHS = config['epochs']
    channels = config['num_channel']
    input_size = config['input_size']
    N_CLASSES = config['num_class']
    ckpt_dir_continue = config['ckpt_dir_continue']
    dataset = config['dataset']
    if dataset == 'CelebA':
        pretrained_classifier = celeba_classifier
        my_data_loader = ImageLabelLoader(input_size=128)
    elif dataset == 'shapes':
        pretrained_classifier = shapes_classifier
        if config['image_dir'] == '':
            my_data_loader = ShapesLoader()
        else:
            my_data_loader = ImageLabelLoader(input_size=64)
    elif dataset in ('CelebA64', 'dermatology', 'synthderm'):
        pretrained_classifier = celeba_classifier
        my_data_loader = ImageLabelLoader(input_size=64)
    else:
        # Previously fell through and failed later with an unbound name.
        raise ValueError('Unsupported dataset: {}'.format(dataset))
    continue_train = ckpt_dir_continue != ''
    evaluate = config['test'] != ''

    # ============= Data =============
    try:
        categories, file_names_dict = read_data_file(
            config['image_label_dict'])
    except Exception:
        print("Problem in reading input data file : ",
              config['image_label_dict'])
        sys.exit(1)
    data_train = np.load(config['train'])
    print("The classification categories are: ")
    print(categories)
    print('The size of the training set: ', data_train.shape[0])
    if evaluate:
        data_test = np.load(config['test'])
        print('The size of the testing set: ', data_test.shape[0])
    with open(os.path.join(output_dir, 'setting.txt'), 'w') as fp:
        fp.write('config_file:' + str(config_path) + '\n')

    # ============= placeholder =============
    with tf.name_scope('input'):
        x_ = tf.placeholder(tf.float32,
                            [None, input_size, input_size, channels],
                            name='x-input')
        y_ = tf.placeholder(tf.int64, [None, N_CLASSES], name='y-input')
        isTrain = tf.placeholder(tf.bool)

    # ============= Model =============
    if N_CLASSES == 1:
        # A single binary label is expanded to a 2-way one-hot target.
        y = tf.reshape(y_, [-1])
        y = tf.one_hot(y, 2, on_value=1.0, off_value=0.0, axis=-1)
        logit, prediction = pretrained_classifier(x_,
                                                  n_label=2,
                                                  reuse=False,
                                                  name='classifier',
                                                  isTrain=isTrain)
    else:
        logit, prediction = pretrained_classifier(x_,
                                                  n_label=N_CLASSES,
                                                  reuse=False,
                                                  name='classifier',
                                                  isTrain=isTrain)
        y = y_
    classif_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=y,
                                                   logits=logit)
    classif_acc = calc_accuracy(prediction=prediction, labels=y)
    loss = tf.losses.get_total_loss()

    # ============= Optimization functions =============
    train_step = tf.train.AdamOptimizer(0.0001).minimize(loss)

    # ============= summary =============
    cls_loss_summary = tf.summary.scalar('classif_loss', classif_loss)
    total_loss_summary = tf.summary.scalar('total_loss', loss)
    cls_acc_summary = tf.summary.scalar('classif_acc', classif_acc)
    sum_train = tf.summary.merge(
        [cls_loss_summary, total_loss_summary, cls_acc_summary])

    # ============= Variables =============
    # Every global variable is saved/restored: the model weights and biases,
    # but also non-model variables created by the optimizer.
    lst_vars = list(tf.global_variables())

    # ============= Session =============
    sess = tf.InteractiveSession()
    saver = tf.train.Saver(var_list=lst_vars)
    tf.global_variables_initializer().run()
    writer = tf.summary.FileWriter(output_dir + '/train', sess.graph)
    if evaluate:
        writer_test = tf.summary.FileWriter(output_dir + '/test', sess.graph)

    # ============= Checkpoints =============
    if continue_train:
        print("Before training, Load checkpoint ")
        print("Reading checkpoint...")
        ckpt = tf.train.get_checkpoint_state(ckpt_dir_continue)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            saver.restore(sess, os.path.join(ckpt_dir_continue, ckpt_name))
            print(ckpt_name)
            print("Successful checkpoint upload")
        else:
            print("Failed checkpoint load")
            sys.exit(1)

    # ============= Training =============
    def _run_epoch(data, is_train, summary_writer, step):
        """Run one pass over every full BATCH_SIZE batch of ``data``.

        Applies ``train_step`` only when ``is_train`` is True. Returns the mean
        per-batch loss (0.0 when ``data`` has fewer than BATCH_SIZE rows) and
        the advanced summary step counter. A trailing partial batch is skipped.
        """
        fetches = [loss, sum_train]
        if is_train:
            fetches = [train_step] + fetches
        num_batch = data.shape[0] // BATCH_SIZE
        epoch_loss = 0.0
        for b in range(num_batch):
            batch = data[b * BATCH_SIZE:(b + 1) * BATCH_SIZE]
            xs, ys = my_data_loader.load_images_and_labels(
                batch,
                image_dir=config['image_dir'],
                n_class=N_CLASSES,
                file_names_dict=file_names_dict,
                num_channel=channels,
                do_center_crop=True)
            batch_loss, summary_str = sess.run(fetches,
                                               feed_dict={
                                                   x_: xs,
                                                   isTrain: is_train,
                                                   y_: ys
                                               })[-2:]
            summary_writer.add_summary(summary_str, step)
            step += 1
            epoch_loss += batch_loss
        # The old code divided by the last batch index (num_batch - 1), which
        # overstated the mean and crashed with one or zero batches.
        return epoch_loss / max(num_batch, 1), step

    train_loss = []
    test_loss = []
    itr_train = 0
    itr_test = 0
    for epoch in range(EPOCHS):
        data_train = data_train[np.random.permutation(data_train.shape[0])]
        epoch_loss, itr_train = _run_epoch(data_train, True, writer,
                                           itr_train)
        print("Epoch: " + str(epoch) + " loss: " + str(epoch_loss) + '\n')
        train_loss.append(epoch_loss)

        if evaluate:
            data_test = data_test[np.random.permutation(data_test.shape[0])]
            epoch_loss, itr_test = _run_epoch(data_test, False, writer_test,
                                              itr_test)
            print("Epoch: " + str(epoch) + " Test loss: " + str(epoch_loss) +
                  '\n')
            test_loss.append(epoch_loss)
            np.save(os.path.join(output_dir, 'logs', 'test_loss.npy'),
                    np.asarray(test_loss))

        checkpoint_name = os.path.join(output_dir,
                                       'cp1_epoch' + str(epoch) + '.ckpt')
        saver.save(sess, checkpoint_name)
        np.save(os.path.join(output_dir, 'logs', 'train_loss.npy'),
                np.asarray(train_loss))

    # Flush pending summaries and release the session's resources.
    writer.close()
    if evaluate:
        writer_test.close()
    sess.close()
def test(config,
         dbg_img_label_dict=None,
         dbg_mode=False,
         export_output=True,
         dbg_size=10,
         dbg_img_indices=None,
         calc_stability=True):
    """Run the trained generator over sampled images and collect its outputs.

    Restores the generator/discriminator checkpoint from
    ``<log_dir>/<name>/ckpt_dir`` and the pretrained classifier checkpoint from
    ``config['cls_experiment']``. For each sampled source image, the graph is
    run once for every (knob in ``range(k_dim)``, target bin in
    ``range(num_bins)``) pair.

    Two modes, chosen by ``config['calc_substitutability']``:
      * truthy: for each source image the generated bins of one randomly
        chosen knob are written to disk via ``save_batch_images``, a label
        dictionary and id list are exported, and a classifier is retrained on
        them with ``train_classif``. An empty dict is returned.
      * falsy: images, embeddings and classifier predictions are gathered into
        arrays, saved as ``<test_dir>/<name>.npy`` if ``export_output`` and
        returned as a dict keyed by array name.

    Args:
        config: experiment configuration dict.
        dbg_img_label_dict: optional path overriding the image/label file.
        dbg_mode: if True, only ``dbg_size`` images are used.
        export_output: if True, save each output array under ``test_dir``.
        dbg_size: number of images used in debug mode.
        dbg_img_indices: optional explicit image keys/indices to process.
        calc_stability: if True, also run ``metrics_stability_nx`` Gaussian-
            noised copies of every input through the generator.

    Returns:
        dict of array name -> numpy array (empty in substitutability mode).

    Raises:
        ValueError: if ``config['dataset']`` is not a supported dataset.
        SystemExit: (status 1) if the data file or a checkpoint cannot be read.
    """
    # None default instead of a shared mutable [] (it is handed to ShapesLoader).
    if dbg_img_indices is None:
        dbg_img_indices = []

    # ============= Experiment Folder=============
    assets_dir = os.path.join(config['log_dir'], config['name'])
    ckpt_dir = os.path.join(assets_dir, 'ckpt_dir')

    # Whether this is for saving the results for substitutability metric or the regular testing process.
    # If only for substitutability, we skip saving large arrays and additional multiple random outputs to avoid OOM
    calc_substitutability = config['calc_substitutability']

    if calc_substitutability:
        substitutability_attr = config['substitutability_attr']

        test_dir = os.path.join(assets_dir, 'test', 'substitutability_input')
        substitutability_exported_img_label_dict = os.path.join(
            test_dir, '{}_dims_{}_clss_{}.txt'.format(substitutability_attr,
                                                      config['k_dim'],
                                                      config['num_bins']))
        substitutability_label_scaler = config['num_bins'] - 1
        exported_dict = {}

        substitutability_classifier_config = config[
            'substitutability_classifier_config']
        # yaml.load() without a Loader is unsafe and fails on PyYAML >= 6.
        with open(config['classifier_config']) as cls_config_file:
            _cls_config = yaml.safe_load(cls_config_file)
        substitutability_img_subset = _cls_config['train']
        substitutability_img_label_dict = _cls_config['image_label_dict']
        # Derive a classifier config pointing at the images exported below,
        # with no test-set evaluation and no warm start.
        _edited_cls_config = deepcopy(_cls_config)
        _edited_cls_config['image_dir'] = os.path.join(test_dir, 'images')
        os.makedirs(_edited_cls_config['image_dir'], exist_ok=True)
        _edited_cls_config[
            'image_label_dict'] = substitutability_exported_img_label_dict
        _edited_cls_config['train'] = os.path.join(test_dir, 'train_ids.npy')
        _edited_cls_config['test'] = ''  # skips evaluating on test
        _edited_cls_config['log_dir'] = test_dir
        _edited_cls_config['ckpt_dir_continue'] = ''
        save_config_dict(_edited_cls_config,
                         substitutability_classifier_config)
    else:
        test_dir = os.path.join(assets_dir, 'test')

    # ============= Experiment Parameters =============
    ckpt_dir_cls = config['cls_experiment']
    if 'evaluation_batch_size' in config:
        BATCH_SIZE = config['evaluation_batch_size']
    else:
        BATCH_SIZE = config['batch_size']
    channels = config['num_channel']
    input_size = config['input_size']
    NUMS_CLASS_cls = config['num_class']
    NUMS_CLASS = config['num_bins']

    metrics_stability_nx = config['metrics_stability_nx']
    metrics_stability_var = config['metrics_stability_var']

    ckpt_dir_continue = ckpt_dir
    if dbg_img_label_dict is not None:
        image_label_dict = dbg_img_label_dict
    elif calc_substitutability:
        image_label_dict = substitutability_img_label_dict
    else:
        image_label_dict = config['image_label_dict']
    # there are k_dim disentangled knobs at indices 0..k_dim-1
    k_dim = config['k_dim']
    disentangle = k_dim > 1

    num_samples = dbg_size if dbg_mode else config['count_to_save']

    dataset = config['dataset']
    if dataset == 'CelebA':
        my_data_loader = ImageLabelLoader(input_size=128)
        EMBEDDING_SIZE = embedding_size_128()
        pretrained_classifier = celeba_classifier
        Discriminator_Ordinal = Discriminator_Ordinal_128
        Generator_Encoder_Decoder = Generator_Encoder_Decoder_128
    elif dataset == 'shapes':
        if calc_substitutability:
            my_data_loader = ShapesLoader()
        else:
            # for efficiency, only load as many samples as we need
            my_data_loader = ShapesLoader(
                dbg_mode=True,
                dbg_size=num_samples,
                dbg_image_label_dict=image_label_dict,
                dbg_img_indices=dbg_img_indices)
            dbg_mode = True
        EMBEDDING_SIZE = embedding_size_64()
        pretrained_classifier = shapes_classifier
        Discriminator_Ordinal = Discriminator_Ordinal_64
        Generator_Encoder_Decoder = Generator_Encoder_Decoder_64
    elif dataset in ('CelebA64', 'dermatology', 'synthderm'):
        my_data_loader = ImageLabelLoader(input_size=64)
        EMBEDDING_SIZE = embedding_size_64()
        pretrained_classifier = celeba_classifier
        Discriminator_Ordinal = Discriminator_Ordinal_64
        Generator_Encoder_Decoder = Generator_Encoder_Decoder_64
    else:
        # Previously fell through and failed later with an unbound name.
        raise ValueError('Unsupported dataset: {}'.format(dataset))

    # ============= Data =============
    try:
        categories, file_names_dict = read_data_file(image_label_dict)
    except Exception:
        print("Problem in reading input data file : ", image_label_dict)
        sys.exit(1)
    if calc_substitutability:
        data = np.load(substitutability_img_subset)
        num_samples = len(data)
    elif dbg_mode and dataset == 'shapes':
        data = np.array([str(ind) for ind in my_data_loader.tmp_list])
    elif len(dbg_img_indices) > 0:
        data = np.asarray(dbg_img_indices)
    else:
        data = np.asarray(list(file_names_dict.keys()))
    print("The classification categories are: ")
    print(categories)
    print('The size of the training set: ', data.shape[0])
    # The output buffers below are np.empty-allocated with num_samples rows;
    # clamp so that no uninitialised rows are exported/returned when fewer
    # images than requested are available.
    num_samples = min(num_samples, data.shape[0])

    # ============= placeholder =============
    x_source = tf.placeholder(tf.float32,
                              [None, input_size, input_size, channels],
                              name='x_source')
    y_s = tf.placeholder(tf.int32, [None, NUMS_CLASS], name='y_s')
    y_source = y_s[:, 0]
    train_phase = tf.placeholder(tf.bool, name='train_phase')

    y_t = tf.placeholder(tf.int32, [None, NUMS_CLASS], name='y_t')
    y_target = y_t[:, 0]

    if disentangle:
        y_regularizer = tf.placeholder(tf.int32, [None], name='y_regularizer')
        y_r = tf.placeholder(tf.float32, [None, k_dim], name='y_r')

    generation_dim = k_dim

    # ============= G & D =============
    G = Generator_Encoder_Decoder(
        "generator")  # with conditional BN, SAGAN: SN here as well
    D = Discriminator_Ordinal("discriminator")  # with SN and projection

    real_source_logits = D(x_source, y_s, NUMS_CLASS, "NO_OPS")
    if disentangle:
        # The generator condition index is knob * NUMS_CLASS + bin.
        fake_target_img, fake_target_img_embedding = G(
            x_source, y_regularizer * NUMS_CLASS + y_target,
            NUMS_CLASS * generation_dim)
        fake_source_img, fake_source_img_embedding = G(
            fake_target_img, y_regularizer * NUMS_CLASS + y_source,
            NUMS_CLASS * generation_dim)
        fake_source_recons_img, x_source_img_embedding = G(
            x_source, y_regularizer * NUMS_CLASS + y_source,
            NUMS_CLASS * generation_dim)
    else:
        fake_target_img, fake_target_img_embedding = G(x_source, y_target,
                                                       NUMS_CLASS)
        fake_source_img, fake_source_img_embedding = G(fake_target_img,
                                                       y_source, NUMS_CLASS)
        fake_source_recons_img, x_source_img_embedding = G(
            x_source, y_source, NUMS_CLASS)
    fake_target_logits = D(fake_target_img, y_t, NUMS_CLASS, None)

    # ============= pre-trained classifier =============
    real_img_cls_logit_pretrained, real_img_cls_prediction = pretrained_classifier(
        x_source, NUMS_CLASS_cls, reuse=False, name='classifier')
    fake_img_cls_logit_pretrained, fake_img_cls_prediction = pretrained_classifier(
        fake_target_img, NUMS_CLASS_cls, reuse=True)
    real_img_recons_cls_logit_pretrained, real_img_recons_cls_prediction = pretrained_classifier(
        fake_source_img, NUMS_CLASS_cls, reuse=True)
    # Target bin rescaled to [0, 1].
    fake_img_target_cls_prediction = tf.cast(
        y_target, tf.float32) * 1.0 / float(NUMS_CLASS - 1)

    # ============= session =============
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()

    # ============= Checkpoints =============
    print(" [*] Reading checkpoint...")

    ckpt = tf.train.get_checkpoint_state(ckpt_dir_continue)
    if ckpt and ckpt.model_checkpoint_path:
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        saver.restore(sess, os.path.join(ckpt_dir_continue, ckpt_name))
        print(ckpt_dir_continue, ckpt_name)
        print("Successful checkpoint upload")
    else:
        print("Failed checkpoint load")
        sys.exit(1)

    # ============= load pre-trained classifier checkpoint =============
    class_vars = [
        var for var in slim.get_variables_to_restore()
        if 'classifier' in var.name
    ]
    name_to_var_map_local = {var.op.name: var for var in class_vars}
    temp_saver = tf.train.Saver(var_list=name_to_var_map_local)
    ckpt = tf.train.get_checkpoint_state(ckpt_dir_cls)
    if not (ckpt and ckpt.model_checkpoint_path):
        # Previously crashed with AttributeError on a missing checkpoint.
        print("Failed classifier checkpoint load: ", ckpt_dir_cls)
        sys.exit(1)
    ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
    temp_saver.restore(sess, os.path.join(ckpt_dir_cls, ckpt_name))
    print("Classifier checkpoint loaded.................")
    print(ckpt_dir_cls, ckpt_name)

    # ============= Testing =============
    def _save_output_array(name, values):
        np.save(os.path.join(test_dir, '{}.npy'.format(name)), values)

    if not calc_substitutability:
        out_shape = [num_samples, generation_dim, NUMS_CLASS]
        img_out_shape = out_shape + [input_size, input_size, channels]
        names = np.empty([num_samples], dtype=object)
        real_imgs = np.empty([num_samples, input_size, input_size, channels])
        fake_t_imgs = np.empty(img_out_shape)
        fake_t_embeds = np.empty(out_shape + EMBEDDING_SIZE)
        fake_s_imgs = np.empty(img_out_shape)
        fake_s_embeds = np.empty(out_shape + EMBEDDING_SIZE)
        fake_s_recon_imgs = np.empty(img_out_shape)
        s_embeds = np.empty(out_shape + EMBEDDING_SIZE)
        real_ps = np.empty(out_shape + [NUMS_CLASS_cls])
        recon_ps = np.empty(out_shape + [NUMS_CLASS_cls])
        fake_target_ps = np.empty(out_shape)
        fake_ps = np.empty(out_shape + [NUMS_CLASS_cls])

        # For stability metric
        # NOTE(review): when calc_stability is False these arrays are still
        # allocated and exported/returned, but never filled (np.empty).
        stab_shape = [num_samples, metrics_stability_nx, generation_dim,
                      NUMS_CLASS]
        stability_fake_t_imgs = np.empty(
            stab_shape + [input_size, input_size, channels])
        stability_fake_s_recon_imgs = np.empty(
            stab_shape + [input_size, input_size, channels])
        stability_recon_ps = np.empty(stab_shape + [NUMS_CLASS_cls])
        stability_fake_ps = np.empty(stab_shape + [NUMS_CLASS_cls])

    np.random.shuffle(data)

    data = data[0:num_samples]
    num_batches = math.ceil(data.shape[0] / BATCH_SIZE)
    for i in range(num_batches):
        image_paths = data[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]
        # num_seed_imgs is either BATCH_SIZE
        # or if the number of samples is not divisible by BATCH_SIZE a smaller value
        num_seed_imgs = np.shape(image_paths)[0]
        img, _labels = my_data_loader.load_images_and_labels(
            image_paths,
            config['image_dir'],
            1,
            file_names_dict,
            channels,
            do_center_crop=True)
        # Each seed image is repeated once per (knob, target bin) pair.
        ordinal_labels = np.repeat(_labels, NUMS_CLASS * generation_dim,
                                   0).ravel()
        img_repeat = np.repeat(img, NUMS_CLASS * generation_dim, 0)
        ordinal_targets = np.tile(np.arange(NUMS_CLASS),
                                  num_seed_imgs * generation_dim)
        # Compare while both are still 1-D ordinals: after conversion, labels
        # has NUMS_CLASS columns (it feeds y_s), and comparing it with the
        # 1-D targets was a shape mismatch (ValueError on NumPy >= 1.25).
        identity_ind = ordinal_labels == ordinal_targets
        labels = convert_ordinal_to_binary(ordinal_labels, NUMS_CLASS)
        target_labels = convert_ordinal_to_binary(ordinal_targets, NUMS_CLASS)

        my_feed_dict = {
            y_t: target_labels,
            x_source: img_repeat,
            train_phase: False,
            y_s: labels
        }
        if disentangle:
            target_disentangle_ind = np.tile(
                np.repeat(np.arange(generation_dim), NUMS_CLASS),
                num_seed_imgs)
            target_disentangle_ind_one_hot = np.eye(
                generation_dim)[target_disentangle_ind][:, 0:k_dim]
            # Zero the knob one-hot where the target bin equals the source bin.
            target_disentangle_ind_one_hot[identity_ind, :] = 0
            my_feed_dict[y_regularizer] = target_disentangle_ind
            my_feed_dict[y_r] = target_disentangle_ind_one_hot

        fake_t_img, fake_t_embed, fake_s_img, fake_s_embed, fake_s_recon_img, s_embed, real_p, recon_p, fake_target_p, fake_p = sess.run(
            [
                fake_target_img, fake_target_img_embedding, fake_source_img,
                fake_source_img_embedding, fake_source_recons_img,
                x_source_img_embedding, real_img_cls_prediction,
                real_img_recons_cls_prediction, fake_img_target_cls_prediction,
                fake_img_cls_prediction
            ],
            feed_dict=my_feed_dict)

        print('{} / {}'.format(i + 1, num_batches))

        batch_shape = [num_seed_imgs, generation_dim, NUMS_CLASS]
        batch_img_shape = batch_shape + [input_size, input_size, channels]
        batch_p_shape = batch_shape + [NUMS_CLASS_cls]

        if calc_substitutability:
            _ind_generation_dim = np.random.randint(low=0,
                                                    high=generation_dim,
                                                    size=num_seed_imgs)
            reshaped_imgs = np.reshape(fake_t_img, batch_img_shape)
            sub_exported_dict = save_batch_images(
                reshaped_imgs,
                image_paths,
                _ind_generation_dim,
                _labels,
                substitutability_label_scaler,
                _edited_cls_config['image_dir'],
                has_extension=(dataset != 'shapes'))
            exported_dict.update(sub_exported_dict)
        else:
            start_ind = i * BATCH_SIZE
            end_ind = start_ind + num_seed_imgs
            names[start_ind:end_ind] = np.asarray(image_paths)

            if calc_stability:
                for j in range(metrics_stability_nx):
                    noisy_img = img + np.random.normal(
                        loc=0.0,
                        scale=metrics_stability_var,
                        size=np.shape(img))
                    stability_feed_dict = my_feed_dict.copy()
                    stability_feed_dict[x_source] = np.repeat(
                        noisy_img, NUMS_CLASS * generation_dim, 0)
                    _stability_fake_t_img, _stability_fake_s_recon_img, _stability_recon_p, _stability_fake_p = sess.run(
                        [
                            fake_target_img, fake_source_recons_img,
                            real_img_recons_cls_prediction,
                            fake_img_cls_prediction
                        ],
                        feed_dict=stability_feed_dict)

                    stability_fake_t_imgs[start_ind:end_ind, j] = np.reshape(
                        _stability_fake_t_img, batch_img_shape)
                    stability_fake_s_recon_imgs[
                        start_ind:end_ind, j] = np.reshape(
                            _stability_fake_s_recon_img, batch_img_shape)
                    stability_recon_ps[start_ind:end_ind, j] = np.reshape(
                        _stability_recon_p, batch_p_shape)
                    stability_fake_ps[start_ind:end_ind, j] = np.reshape(
                        _stability_fake_p, batch_p_shape)

            real_imgs[start_ind:end_ind] = img
            fake_t_imgs[start_ind:end_ind] = np.reshape(
                fake_t_img, batch_img_shape)
            fake_s_imgs[start_ind:end_ind] = np.reshape(
                fake_s_img, batch_img_shape)
            fake_s_recon_imgs[start_ind:end_ind] = np.reshape(
                fake_s_recon_img, batch_img_shape)
            real_ps[start_ind:end_ind] = np.reshape(real_p, batch_p_shape)
            recon_ps[start_ind:end_ind] = np.reshape(recon_p, batch_p_shape)
            fake_target_ps[start_ind:end_ind] = np.reshape(
                fake_target_p, batch_shape)
            fake_ps[start_ind:end_ind] = np.reshape(fake_p, batch_p_shape)

            _RESHAPE_EMBED_SIZE = batch_shape + EMBEDDING_SIZE
            fake_t_embeds[start_ind:end_ind] = np.reshape(
                fake_t_embed, _RESHAPE_EMBED_SIZE)
            fake_s_embeds[start_ind:end_ind] = np.reshape(
                fake_s_embed, _RESHAPE_EMBED_SIZE)
            s_embeds[start_ind:end_ind] = np.reshape(s_embed,
                                                     _RESHAPE_EMBED_SIZE)

    # All graph evaluation is done; release the session before the graph is
    # (possibly) reset below.
    sess.close()

    output_dict = {}

    if calc_substitutability:
        save_dict(exported_dict, substitutability_exported_img_label_dict,
                  substitutability_attr)
        np.save(_edited_cls_config['train'],
                np.asarray(list(exported_dict.keys())))

        # retrain the classifier with the new generated images
        tf.reset_default_graph()
        train_classif(substitutability_classifier_config)
    else:
        # Explicit mapping replaces the previous eval()-based lookup of locals.
        output_dict = {
            'names': names,
            'real_imgs': real_imgs,
            'fake_t_imgs': fake_t_imgs,
            'fake_t_embeds': fake_t_embeds,
            'fake_s_imgs': fake_s_imgs,
            'fake_s_embeds': fake_s_embeds,
            'fake_s_recon_imgs': fake_s_recon_imgs,
            's_embeds': s_embeds,
            'real_ps': real_ps,
            'recon_ps': recon_ps,
            'fake_target_ps': fake_target_ps,
            'fake_ps': fake_ps,
            'stability_fake_t_imgs': stability_fake_t_imgs,
            'stability_fake_s_recon_imgs': stability_fake_s_recon_imgs,
            'stability_recon_ps': stability_recon_ps,
            'stability_fake_ps': stability_fake_ps,
        }
        if export_output:
            for arr_name, arr in output_dict.items():
                _save_output_array(arr_name, arr)

    return output_dict