def test_function(params):
    """Run a restored discriminator over real batches and collect its outputs.

    Relies on the module-level names ``FLAGS``, ``helpers``, ``lib``,
    ``input_pipeline`` and ``no_samples``.

    Args:
        params: two-item sequence ``(DCG, disc)``. ``DCG.DCGAND_1`` is used
            as the discriminator builder; ``disc`` names the sub-directory
            of ``./saved_models/`` that is searched for a checkpoint.

    Returns:
        A ``numpy`` array of shape ``[no_samples, BATCH_SIZE]`` with one row
        of discriminator outputs per evaluated batch, or ``None`` when no
        checkpoint is found or the run is interrupted from the keyboard.
    """
    DCG, disc = params
    Discriminator = DCG.DCGAND_1
    BATCH_SIZE = FLAGS.batch_size
    with tf.Graph().as_default():
        train_data_list = helpers.get_dataset_files()
        real_data = input_pipeline(train_data_list, batch_size=BATCH_SIZE)
        # Normalize -1 to 1
        real_data = 2 * ((tf.cast(real_data, tf.float32) / 255.) - .5)
        disc_real, _ = Discriminator(real_data)
        disc_vars = lib.params_with_name("Discriminator")
        disc_saver = tf.train.Saver(disc_vars)
        ckpt_disc = tf.train.get_checkpoint_state(
            "./saved_models/" + disc + "/")
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            print('Queue runners started.')
            # Everything that touches the input queue runs inside the try so
            # the runner threads are always stopped, even if a fetch fails.
            try:
                real_im = sess.run([real_data])[0][0][0][0:5]
                print("Real Image range sample: ", real_im)
                if ckpt_disc and ckpt_disc.model_checkpoint_path:
                    print("Restoring discriminator...", disc)
                    disc_saver.restore(sess, ckpt_disc.model_checkpoint_path)
                    pred_arr = np.empty([no_samples, BATCH_SIZE])
                    # Exactly one batch per row; the previous loop ran one
                    # extra iteration and relied on index -1 wrapping around.
                    for i in tqdm(range(no_samples)):
                        pred_arr[i, :] = sess.run(disc_real)
                    return pred_arr
                else:
                    print("Failed to load Discriminator")
            except KeyboardInterrupt:
                print("Manual interrupt occurred.")
            finally:
                coord.request_stop()
                coord.join(threads)
                print('Finished inference.')
def _scalar_summary(tagged_values):
    """Return a ``tf.Summary`` with one ``simple_value`` per (tag, value) pair."""
    summary = tf.Summary()
    for tag, value in tagged_values:
        summary.value.add(tag=tag, simple_value=value)
    return summary


def begin_training(params):
    """
    Takes model name, Generator and Discriminator architectures as input,
    builds the rest of the graph and runs the training loop.

    Relies on the module-level names ``FLAGS``, ``train_data_list``,
    ``BATCH_SIZE``, ``DEVICES``, ``EPOCH``, ``MODE`` and ``CRITIC_ITERS``.

    Args:
        params: five-item sequence
            ``(model_name, Generator, Discriminator, epochs, restore)``.
            ``model_name`` is combined with ``FLAGS.dataset`` to name the
            summary, checkpoint and output directories. ``epochs`` is the
            epoch index at which training stops. When ``restore`` is truthy
            and a checkpoint exists in the save directory, variables are
            restored from it instead of being freshly initialised.

    Returns:
        None. A checkpoint is written to the save directory when the loop
        ends, whether it finished, failed or was interrupted.

    Raises:
        AssertionError: if the FID statistics file
            ``./tmp/<dataset>_stats.npz`` does not exist.
    """
    model_name, Generator, Discriminator, epochs, restore = params
    fid_stats_file = "./tmp/"
    inception_path = "./tmp/"
    TRAIN_FOR_N_EPOCHS = epochs
    MODEL_NAME = model_name + "_" + FLAGS.dataset
    SUMMARY_DIR = 'summary/' + MODEL_NAME + "/"
    SAVE_DIR = "./saved_models/" + MODEL_NAME + "/"
    OUTPUT_DIR = './outputs/' + MODEL_NAME + "/"
    helpers.refresh_dirs(SUMMARY_DIR, OUTPUT_DIR, SAVE_DIR, restore)
    with tf.Graph().as_default():
        with tf.variable_scope('input'):
            all_real_data_conv = input_pipeline(
                train_data_list, batch_size=BATCH_SIZE)
            # Split data over multiple GPUs:
            split_real_data_conv = tf.split(all_real_data_conv, len(DEVICES))
        global_step = tf.train.get_or_create_global_step()

        (gen_cost, disc_cost, pre_real, pre_fake, gradient_penalty,
         real_data, fake_data, disc_fake, disc_real) = split_and_setup_costs(
            Generator, Discriminator, split_real_data_conv)

        gen_train_op, disc_train_op, gen_learning_rate = setup_train_ops(
            gen_cost, disc_cost, global_step)

        performance_merged, distances_merged = add_summaries(
            gen_cost, disc_cost, fake_data, real_data,
            gen_learning_rate, gradient_penalty, pre_real, pre_fake)

        saver = tf.train.Saver(max_to_keep=1)
        all_fixed_noise_samples = helpers.prepare_noise_samples(
            DEVICES, Generator)

        fid_stats_file += FLAGS.dataset + "_stats.npz"
        assert tf.gfile.Exists(
            fid_stats_file), "Can't find training set statistics for FID (%s)" % fid_stats_file
        with np.load(fid_stats_file) as stats:
            # Not consumed further down in this function; still read so that
            # a statistics file lacking either array raises here.
            mu_fid, sigma_fid = stats['mu'][:], stats['sigma'][:]
        inception_path = fid.check_or_download_inception(inception_path)
        fid.create_inception_graph(inception_path)

        # Create session
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        if FLAGS.use_XLA:
            config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
        with tf.Session(config=config) as sess:
            # Restore variables if required
            ckpt = tf.train.get_checkpoint_state(SAVE_DIR)
            if restore and ckpt and ckpt.model_checkpoint_path:
                print("Restoring variables...")
                saver.restore(sess, ckpt.model_checkpoint_path)
                print('Variables restored from:\n', ckpt.model_checkpoint_path)
            else:
                # Initialise all the variables
                print("Initialising variables")
                sess.run(tf.local_variables_initializer())
                sess.run(tf.global_variables_initializer())
                print('Variables initialised.')
            # Start input enqueue threads
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            print('Queue runners started.')
            real_im = sess.run([all_real_data_conv])[0][0][0][0:5]
            print("Real Image range sample: ", real_im)

            summary_writer = tf.summary.FileWriter(SUMMARY_DIR, sess.graph)
            helpers.sample_dataset(sess, all_real_data_conv, OUTPUT_DIR)

            # Bind `step` before the loop: the final save in `finally` needs
            # it even when no training iteration runs (e.g. a restored model
            # that already reached `epochs`) or the first iteration fails.
            step = global_step.eval(sess)
            # Training loop
            try:
                # The critic/discriminator update count does not change
                # during training, so decide it once.
                if (MODE == 'dcgan') or (MODE == 'lsgan'):
                    disc_iters = 1
                else:
                    disc_iters = CRITIC_ITERS
                ep_start = step // EPOCH
                for epoch in tqdm(range(ep_start, TRAIN_FOR_N_EPOCHS), desc="Epochs passed"):
                    # Position inside the current epoch; kept apart from
                    # `step` so `step` always holds a real global step.
                    epoch_offset = (global_step.eval(sess)) % EPOCH
                    for _ in tqdm(range(epoch_offset, EPOCH), desc="Current epoch %i" % epoch, mininterval=0.5):
                        # train gen
                        _, step = sess.run([gen_train_op, global_step])
                        # Train discriminator
                        for _ in range(disc_iters):
                            sess.run([disc_cost, disc_train_op])
                        if step % (128) == 0:
                            # NOTE: this fetch also runs both train ops once
                            # more, alongside the merged summaries.
                            _, _, _, performance_summary, distances_summary = sess.run(
                                [gen_train_op, disc_cost, disc_train_op, performance_merged, distances_merged])
                            summary_writer.add_summary(
                                performance_summary, step)
                            summary_writer.add_summary(
                                distances_summary, step)

                        if step % (512) == 0:
                            saver.save(sess, SAVE_DIR, global_step=step)
                            helpers.generate_image(step, sess, OUTPUT_DIR,
                                                   all_fixed_noise_samples, Generator, summary_writer)
                            fid_score, IS_mean, IS_std, kid_score = fake_batch_stats(
                                sess, fake_data)
                            pre_real_out, pre_fake_out, fake_out, real_out = sess.run(
                                [pre_real, pre_fake, disc_fake, disc_real])
                            scalar_avg_fake = np.mean(fake_out)
                            scalar_sdev_fake = np.std(fake_out)
                            scalar_avg_real = np.mean(real_out)
                            scalar_sdev_real = np.std(real_out)

                            frechet_dist = frechet_distance(
                                pre_real_out, pre_fake_out)
                            kid_score = np.mean(kid_score)
                            inception_summary = _scalar_summary([
                                ("distances/FD", frechet_dist),
                                ("distances/FID", fid_score),
                                ("distances/IS_mean", IS_mean),
                                ("distances/IS_std", IS_std),
                                ("distances/KID", kid_score),
                                ("distances/scalar_mean_fake", scalar_avg_fake),
                                ("distances/scalar_sdev_fake", scalar_sdev_fake),
                                ("distances/scalar_mean_real", scalar_avg_real),
                                ("distances/scalar_sdev_real", scalar_sdev_real),
                            ])
                            summary_writer.add_summary(inception_summary, step)
            except KeyboardInterrupt:
                print("Manual interrupt occurred.")
            except Exception as e:
                # Best effort: report the failure and still fall through to
                # the final checkpoint below.
                print(e)
            finally:
                coord.request_stop()
                coord.join(threads)
                print('Finished training.')
                saver.save(sess, SAVE_DIR, global_step=step)
                print("Model " + MODEL_NAME +
                      " saved in file: {} at step {}".format(SAVE_DIR, step))
# If you have downloaded and extracted
#   http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
# set this path to the directory where the extracted files are, otherwise
# just set it to None and the script will later download the files for you.
inception_path = './tmp'
print("check for inception model..", end=" ", flush=True)
inception_path = fid.check_or_download_inception(
    inception_path)  # download inception if necessary
print("ok")

print("create inception graph..", end=" ", flush=True)
# Load the inception graph into the current TF graph.
fid.create_inception_graph(inception_path)
print("ok")

# Batched dataset tensor with size-1 dimensions squeezed out; the session
# created further down fetches it repeatedly.
images = tf.squeeze(((input_pipeline(train_data_list, batch_size=BATCH_SIZE))))
# Session configuration: soft device placement, GPU memory grown on demand.
print("calculte FID stats..", end=" ", flush=True)
config = tf.ConfigProto(allow_soft_placement=True)
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
    try:
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        no_images = (50000 // batch_size) * batch_size
        temp_image_list = np.empty([no_images, FLAGS.height, FLAGS.height, 3])
        # Build a massive array of dataset images
        for i in range(no_images // batch_size):
            np_images = np.squeeze(np.asarray(sess.run([images])))
            temp_image_list[i * batch_size:(i + 1) *
def test_function(params):
    """Compare discriminator outputs on fake, real and half-and-half batches.

    Relies on the module-level names ``FLAGS``, ``helpers``, ``lib``,
    ``input_pipeline`` and ``no_samples``.

    Args:
        params: three-item sequence ``(DCG, gen, disc)``. ``DCG.DCGANG_1``
            and ``DCG.DCGAND_1`` are used as the generator and discriminator
            builders; ``gen`` and ``disc`` name the sub-directories of
            ``./saved_models/`` searched for the respective checkpoints.

    Returns:
        The tuple ``(real_mean, fake_mean, half_mean, fake_std, real_std,
        half_std)`` computed over ``no_samples`` batches (note that the
        means and the standard deviations are ordered differently), or
        ``None`` when the discriminator checkpoint is missing or an
        exception was caught and printed.
    """
    DCG, gen, disc = params
    Generator = DCG.DCGANG_1
    Discriminator = DCG.DCGAND_1
    BATCH_SIZE = FLAGS.batch_size
    half = BATCH_SIZE // 2
    with tf.Graph().as_default():
        fake_data = Generator(BATCH_SIZE)
        disc_fake, _ = Discriminator(fake_data)

        train_data_list = helpers.get_dataset_files()
        real_data = input_pipeline(train_data_list, batch_size=BATCH_SIZE)
        # Normalize -1 to 1
        real_data = 2 * ((tf.cast(real_data, tf.float32) / 255.) - .5)
        disc_real, _ = Discriminator(real_data)

        # Build the discriminator for the mixed batch once and feed it through
        # a placeholder; calling Discriminator(...) inside the sampling loop
        # would add a fresh set of ops to the graph on every iteration.
        mixed_input = tf.placeholder(
            tf.float32, shape=[BATCH_SIZE, 64, 64, 3])
        disc_mixed, _ = Discriminator(mixed_input)

        gen_vars = lib.params_with_name('Generator')
        gen_saver = tf.train.Saver(gen_vars)
        disc_vars = lib.params_with_name("Discriminator")
        disc_saver = tf.train.Saver(disc_vars)
        ckpt_gen = tf.train.get_checkpoint_state(
            "./saved_models/" + gen + "/")
        ckpt_disc = tf.train.get_checkpoint_state(
            "./saved_models/" + disc + "/")
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            # Created before the try so the cleanup in `finally` can never
            # hit an unbound name, whatever fails first.
            coord = tf.train.Coordinator()
            threads = []
            try:
                sess.run(tf.global_variables_initializer())
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)
                if ckpt_gen and ckpt_gen.model_checkpoint_path:
                    gen_saver.restore(sess, ckpt_gen.model_checkpoint_path)
                else:
                    # Evaluation continues with the freshly initialised
                    # generator in this case.
                    print("Failed to load Generator")
                if ckpt_disc and ckpt_disc.model_checkpoint_path:
                    disc_saver.restore(
                        sess, ckpt_disc.model_checkpoint_path)
                    # Axis 0: 0 = fake batch, 1 = real batch, 2 = mixed batch.
                    pred_arr = np.empty(
                        [3, no_samples, BATCH_SIZE], dtype=np.float32)
                    for i in tqdm(range(no_samples)):
                        rl, fk = sess.run([real_data, fake_data])
                        fk = fk.reshape([BATCH_SIZE, 64, 64, 3])
                        # First `half` rows come from the tail of the real
                        # batch, the remaining rows from the fake batch.
                        mixed = np.zeros(
                            [BATCH_SIZE, 64, 64, 3], dtype=np.float32)
                        mixed[:half] = rl[BATCH_SIZE - half:]
                        mixed[half:] = fk[half:]
                        pred_arr[0, i, :] = sess.run(disc_fake)
                        pred_arr[1, i, :] = sess.run(disc_real)
                        pred_arr[2, i, :] = sess.run(
                            disc_mixed, feed_dict={mixed_input: mixed})
                    fake_mean = np.mean(pred_arr[0])
                    real_mean = np.mean(pred_arr[1])
                    half_mean = np.mean(pred_arr[2])
                    fake_std = np.std(pred_arr[0])
                    real_std = np.std(pred_arr[1])
                    half_std = np.std(pred_arr[2])
                    return real_mean, fake_mean, half_mean, fake_std, real_std, half_std
                else:
                    print("Failed to load Discriminator")
            except Exception as e:
                print(e)
            finally:
                coord.request_stop()
                coord.join(threads)
def explain(params=None):
    """Produce LIME explanations for a restored discriminator and plot them.

    Relies on the module-level names ``FLAGS``, ``helpers``, ``lib``,
    ``input_pipeline``, ``no_samples``, ``no_perturbed_images``,
    ``marginalized_means``, ``expit``, ``lime_image``,
    ``SegmentationAlgorithm`` and ``make_figures``.

    Args:
        params: five-item sequence ``(DCG, disc, images_to_explain, d_index,
            normalize_by_mean)``. ``DCG.DCGAND_1`` is used as the
            discriminator builder and ``disc`` names the sub-directory of
            ``./saved_models/`` searched for its checkpoint. When
            ``images_to_explain`` is empty or ``None``, the first
            ``no_samples`` images of one dataset batch are used instead.
            When ``normalize_by_mean`` is truthy,
            ``marginalized_means[d_index]`` is subtracted from each
            discriminator output before it is squashed with ``expit``.

    Returns:
        The images that were explained, or ``None`` when the discriminator
        checkpoint could not be found.
    """
    DCG, disc, images_to_explain, d_index, normalize_by_mean = params
    Discriminator = DCG.DCGAND_1
    BATCH_SIZE = FLAGS.batch_size
    # Number of images handed to the discriminator per LIME classifier call;
    # must equal the batch_size passed to explain_instance below.
    lime_batch = 256
    with tf.Graph().as_default():
        train_data_list = helpers.get_dataset_files()
        real_data = input_pipeline(train_data_list, batch_size=BATCH_SIZE)
        # Normalize -1 to 1
        real_data = 2 * ((tf.cast(real_data, tf.float32) / 255.) - .5)
        Discriminator(real_data)

        # Build the discriminator for LIME's perturbed images once and feed
        # it through a placeholder; calling Discriminator(...) inside the
        # classifier function would grow the graph on every call.
        lime_input = tf.placeholder(
            tf.float32, shape=[lime_batch, 64, 64, 3])
        lime_output, _ = Discriminator(lime_input)

        disc_vars = lib.params_with_name("Discriminator")
        disc_saver = tf.train.Saver(disc_vars)
        ckpt_disc = tf.train.get_checkpoint_state(
            "./saved_models/" + disc + "/")
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            # The outer try/finally stops the queue runners on every path,
            # including the missing-checkpoint early return.
            try:
                if not (ckpt_disc and ckpt_disc.model_checkpoint_path):
                    print("Failed to load Discriminator", disc)
                    return None
                disc_saver.restore(sess, ckpt_disc.model_checkpoint_path)

                def disc_prediction(image):
                    """Return an ``(n, 2)`` array of [fake, real] scores."""
                    # Transform to -1 to 1:
                    if np.max(image) > 1.1 or np.min(image) > 0.0:
                        image = (image.astype(np.float32) * 2.0 / 255.0) - 1.0
                    if len(image.shape) == 4:
                        no_ims = image.shape[0]
                    else:
                        no_ims = 1
                    # Pad with zeros up to the fixed batch size of the graph.
                    images_batch = np.zeros(
                        [lime_batch, 64, 64, 3], dtype=np.float32)
                    images_batch[0:no_ims] = image
                    prediction = sess.run(
                        lime_output, feed_dict={lime_input: images_batch})
                    # Squash the unbounded discriminator output into (0, 1)
                    # with expit and expose it as two class scores.
                    pred_array = np.zeros((no_ims, 2))
                    for i, x in enumerate(prediction[:no_ims]):
                        if normalize_by_mean:
                            x = x - marginalized_means[d_index]
                        # 1 == REAL; 0 == FAKE
                        pred_array[i, 1] = expit(x)
                        pred_array[i, 0] = 1 - pred_array[i, 1]
                    return pred_array

                explanations = []
                explainer = lime_image.LimeImageExplainer(verbose=False)
                segmenter = SegmentationAlgorithm(
                    'slic', n_segments=100, compactness=1, sigma=1)
                try:
                    if images_to_explain is None or not len(images_to_explain):
                        # Fall back to dataset images, converted back from
                        # [-1, 1] to uint8 pixel values.
                        images_to_explain = sess.run(real_data)[:no_samples]
                        images_to_explain = (images_to_explain + 1.0) * 255.0 / 2.0
                        images_to_explain = images_to_explain.astype(np.uint8)
                        images_to_explain = np.reshape(
                            images_to_explain, [no_samples, 64, 64, 3])
                    for image_to_explain in tqdm(images_to_explain):
                        explanation = explainer.explain_instance(
                            image_to_explain,
                            classifier_fn=disc_prediction,
                            batch_size=lime_batch,
                            top_labels=2, hide_color=None,
                            num_samples=no_perturbed_images,
                            segmentation_fn=segmenter)
                        explanations.append(explanation)
                except KeyboardInterrupt:
                    # Keep whatever explanations were finished so far.
                    print("Manual interrupt occurred.")
            finally:
                coord.request_stop()
                coord.join(threads)
            make_figures(images_to_explain, explanations,
                         DCG.get_G_dim(), DCG.get_D_dim(), normalize_by_mean)
            return images_to_explain