def test_function(params):
    DCG, disc = params
    Discriminator = DCG.DCGAND_1
    BATCH_SIZE = FLAGS.batch_size
    with tf.Graph().as_default() as graph:
        train_data_list = helpers.get_dataset_files()
        real_data = input_pipeline(train_data_list, batch_size=BATCH_SIZE)
        # Normalize from [0, 255] to [-1, 1]
        real_data = 2 * ((tf.cast(real_data, tf.float32) / 255.) - .5)
        disc_real, _ = Discriminator(real_data)
        disc_vars = lib.params_with_name("Discriminator")
        disc_saver = tf.train.Saver(disc_vars)
        ckpt_disc = tf.train.get_checkpoint_state("./saved_models/" + disc + "/")
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            print('Queue runners started.')
            # Sanity-check the input range on a few pixels
            real_im = sess.run([real_data])[0][0][0][0:5]
            print("Real Image range sample: ", real_im)
            try:
                if ckpt_disc and ckpt_disc.model_checkpoint_path:
                    print("Restoring discriminator...", disc)
                    disc_saver.restore(sess, ckpt_disc.model_checkpoint_path)
                    # no_samples is assumed to be defined at module level
                    pred_arr = np.empty([no_samples, BATCH_SIZE])
                    for i in tqdm(range(no_samples)):
                        predictions = sess.run([disc_real])
                        pred_arr[i, :] = predictions[0]
                    return pred_arr
                else:
                    print("Failed to load Discriminator")
            except KeyboardInterrupt as e:
                print("Manual interrupt occurred.")
            finally:
                coord.request_stop()
                coord.join(threads)
                print('Finished inference.')
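# A minimal usage sketch for the scorer above. The architecture module, the
# checkpoint name, and the sample count are hypothetical placeholders, not
# part of the original code.
import DCGAN_64 as DCG_module  # hypothetical module exposing DCGAND_1

no_samples = 10  # module-level global that test_function reads
scores = test_function((DCG_module, "WGAN-GP_celebA"))  # shape [no_samples, BATCH_SIZE]
print("Mean critic score on real data:", scores.mean())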
def begin_training(params):
    """
    Takes the model name, Generator and Discriminator architectures as input
    and builds the rest of the graph.
    """
    model_name, Generator, Discriminator, epochs, restore = params
    fid_stats_file = "./tmp/"
    inception_path = "./tmp/"
    TRAIN_FOR_N_EPOCHS = epochs
    MODEL_NAME = model_name + "_" + FLAGS.dataset
    SUMMARY_DIR = 'summary/' + MODEL_NAME + "/"
    SAVE_DIR = "./saved_models/" + MODEL_NAME + "/"
    OUTPUT_DIR = './outputs/' + MODEL_NAME + "/"
    helpers.refresh_dirs(SUMMARY_DIR, OUTPUT_DIR, SAVE_DIR, restore)
    with tf.Graph().as_default():
        with tf.variable_scope('input'):
            all_real_data_conv = input_pipeline(train_data_list, batch_size=BATCH_SIZE)
            # Split data over multiple GPUs:
            split_real_data_conv = tf.split(all_real_data_conv, len(DEVICES))
        global_step = tf.train.get_or_create_global_step()
        gen_cost, disc_cost, pre_real, pre_fake, gradient_penalty, real_data, \
            fake_data, disc_fake, disc_real = split_and_setup_costs(
                Generator, Discriminator, split_real_data_conv)
        gen_train_op, disc_train_op, gen_learning_rate = setup_train_ops(
            gen_cost, disc_cost, global_step)
        performance_merged, distances_merged = add_summaries(
            gen_cost, disc_cost, fake_data, real_data, gen_learning_rate,
            gradient_penalty, pre_real, pre_fake)
        saver = tf.train.Saver(max_to_keep=1)
        all_fixed_noise_samples = helpers.prepare_noise_samples(DEVICES, Generator)
        # Load the precomputed training set statistics for FID
        fid_stats_file += FLAGS.dataset + "_stats.npz"
        assert tf.gfile.Exists(fid_stats_file), \
            "Can't find training set statistics for FID (%s)" % fid_stats_file
        f = np.load(fid_stats_file)
        mu_fid, sigma_fid = f['mu'][:], f['sigma'][:]
        f.close()
        inception_path = fid.check_or_download_inception(inception_path)
        fid.create_inception_graph(inception_path)
        # Create session
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        if FLAGS.use_XLA:
            config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
        with tf.Session(config=config) as sess:
            # Restore variables if required
            ckpt = tf.train.get_checkpoint_state(SAVE_DIR)
            if restore and ckpt and ckpt.model_checkpoint_path:
                print("Restoring variables...")
                saver.restore(sess, ckpt.model_checkpoint_path)
                print('Variables restored from:\n', ckpt.model_checkpoint_path)
            else:
                # Initialise all the variables
                print("Initialising variables")
                sess.run(tf.local_variables_initializer())
                sess.run(tf.global_variables_initializer())
                print('Variables initialised.')
            # Start input enqueue threads
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            print('Queue runners started.')
            real_im = sess.run([all_real_data_conv])[0][0][0][0:5]
            print("Real Image range sample: ", real_im)
            summary_writer = tf.summary.FileWriter(SUMMARY_DIR, sess.graph)
            helpers.sample_dataset(sess, all_real_data_conv, OUTPUT_DIR)
            # Training loop
            try:
                ep_start = global_step.eval(sess) // EPOCH
                for epoch in tqdm(range(ep_start, TRAIN_FOR_N_EPOCHS), desc="Epochs passed"):
                    step = global_step.eval(sess) % EPOCH
                    for _ in tqdm(range(step, EPOCH), desc="Current epoch %i" % epoch,
                                  mininterval=0.5):
                        # Train generator
                        _, step = sess.run([gen_train_op, global_step])
                        # Train discriminator (WGAN variants take several critic
                        # steps per generator step)
                        if MODE in ('dcgan', 'lsgan'):
                            disc_iters = 1
                        else:
                            disc_iters = CRITIC_ITERS
                        for _ in range(disc_iters):
                            _disc_cost, _ = sess.run([disc_cost, disc_train_op])
                        if step % 128 == 0:
                            _, _, _, performance_summary, distances_summary = sess.run(
                                [gen_train_op, disc_cost, disc_train_op,
                                 performance_merged, distances_merged])
                            summary_writer.add_summary(performance_summary, step)
                            summary_writer.add_summary(distances_summary, step)
                        if step % 512 == 0:
                            saver.save(sess, SAVE_DIR, global_step=step)
                            helpers.generate_image(step, sess, OUTPUT_DIR,
                                                   all_fixed_noise_samples,
                                                   Generator, summary_writer)
                            fid_score, IS_mean, IS_std, kid_score = fake_batch_stats(
                                sess, fake_data)
                            pre_real_out, pre_fake_out, fake_out, real_out = sess.run(
                                [pre_real, pre_fake, disc_fake, disc_real])
                            scalar_avg_fake = np.mean(fake_out)
                            scalar_sdev_fake = np.std(fake_out)
                            scalar_avg_real = np.mean(real_out)
                            scalar_sdev_real = np.std(real_out)
                            frechet_dist = frechet_distance(pre_real_out, pre_fake_out)
                            kid_score = np.mean(kid_score)
                            # Write the evaluation metrics as custom summaries
                            inception_summary = tf.Summary()
                            inception_summary.value.add(
                                tag="distances/FD", simple_value=frechet_dist)
                            inception_summary.value.add(
                                tag="distances/FID", simple_value=fid_score)
                            inception_summary.value.add(
                                tag="distances/IS_mean", simple_value=IS_mean)
                            inception_summary.value.add(
                                tag="distances/IS_std", simple_value=IS_std)
                            inception_summary.value.add(
                                tag="distances/KID", simple_value=kid_score)
                            inception_summary.value.add(
                                tag="distances/scalar_mean_fake", simple_value=scalar_avg_fake)
                            inception_summary.value.add(
                                tag="distances/scalar_sdev_fake", simple_value=scalar_sdev_fake)
                            inception_summary.value.add(
                                tag="distances/scalar_mean_real", simple_value=scalar_avg_real)
                            inception_summary.value.add(
                                tag="distances/scalar_sdev_real", simple_value=scalar_sdev_real)
                            summary_writer.add_summary(inception_summary, step)
            except KeyboardInterrupt as e:
                print("Manual interrupt occurred.")
            except Exception as e:
                print(e)
            finally:
                coord.request_stop()
                coord.join(threads)
                print('Finished training.')
                saver.save(sess, SAVE_DIR, global_step=step)
                print("Model " + MODEL_NAME +
                      " saved in file: {} at step {}".format(SAVE_DIR, step))
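# frechet_distance is called above but not defined in this excerpt. A plausible
# NumPy/SciPy sketch (the name and signature are assumptions): the Frechet
# distance between Gaussians fitted to two activation batches,
#   d^2 = ||mu1 - mu2||^2 + Tr(C1 + C2 - 2 * sqrt(C1 @ C2)),
# i.e. the same formula FID applies to Inception activations.
import numpy as np
from scipy import linalg

def frechet_distance(act_real, act_fake, eps=1e-6):
    """Frechet distance between two [N, D] activation batches (sketch)."""
    mu1, mu2 = act_real.mean(axis=0), act_fake.mean(axis=0)
    sigma1 = np.cov(act_real, rowvar=False)
    sigma2 = np.cov(act_fake, rowvar=False)
    diff = mu1 - mu2
    # Matrix square root of the covariance product; add jitter if it is singular
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)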
# If you have downloaded and extracted
# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz,
# set this path to the directory with the extracted files; otherwise set it to
# None and the script will download the files for you later.
inception_path = './tmp'
print("check for inception model..", end=" ", flush=True)
inception_path = fid.check_or_download_inception(inception_path)  # download if necessary
print("ok")
print("create inception graph..", end=" ", flush=True)
fid.create_inception_graph(inception_path)  # load the graph into the current TF graph
print("ok")
images = tf.squeeze(input_pipeline(train_data_list, batch_size=BATCH_SIZE))
print("calculate FID stats..", end=" ", flush=True)
config = tf.ConfigProto(allow_soft_placement=True)
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
    try:
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        # batch_size is assumed equal to BATCH_SIZE in this fragment
        no_images = (50000 // batch_size) * batch_size
        temp_image_list = np.empty([no_images, FLAGS.height, FLAGS.height, 3])
        # Build a massive array of dataset images, one batch at a time
        for i in range(no_images // batch_size):
            np_images = np.squeeze(np.asarray(sess.run([images])))
            temp_image_list[i * batch_size:(i + 1) * batch_size] = np_images
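        # Sketch of how this fragment plausibly continues (an assumption, not
        # the original code): compute the dataset's Inception statistics with
        # the TTUR fid module's calculate_activation_statistics helper and save
        # them where begin_training expects to find them.
        mu, sigma = fid.calculate_activation_statistics(
            temp_image_list, sess, batch_size=batch_size)
        np.savez_compressed("./tmp/" + FLAGS.dataset + "_stats.npz", mu=mu, sigma=sigma)
        print("ok")
    finally:
        coord.request_stop()
        coord.join(threads)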
def test_function(params):
    DCG, gen, disc = params
    Generator = DCG.DCGANG_1
    Discriminator = DCG.DCGAND_1
    BATCH_SIZE = FLAGS.batch_size
    with tf.Graph().as_default() as graph:
        fake_data = Generator(BATCH_SIZE)
        disc_fake, _ = Discriminator(fake_data)
        train_data_list = helpers.get_dataset_files()
        real_data = input_pipeline(train_data_list, batch_size=BATCH_SIZE)
        # Normalize from [0, 255] to [-1, 1]
        real_data = 2 * ((tf.cast(real_data, tf.float32) / 255.) - .5)
        disc_real, _ = Discriminator(real_data)
        # Score a half-real/half-fake batch; build the op once so the graph
        # does not grow on every loop iteration.
        half_real_ph = tf.placeholder(tf.float32, [BATCH_SIZE, 64, 64, 3])
        disc_half, _ = Discriminator(half_real_ph)
        gen_vars = lib.params_with_name('Generator')
        gen_saver = tf.train.Saver(gen_vars)
        disc_vars = lib.params_with_name("Discriminator")
        disc_saver = tf.train.Saver(disc_vars)
        ckpt_gen = tf.train.get_checkpoint_state("./saved_models/" + gen + "/")
        ckpt_disc = tf.train.get_checkpoint_state("./saved_models/" + disc + "/")
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            try:
                sess.run(tf.global_variables_initializer())
                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)
                if ckpt_gen and ckpt_gen.model_checkpoint_path:
                    gen_saver.restore(sess, ckpt_gen.model_checkpoint_path)
                else:
                    print("Failed to load Generator")
                if ckpt_disc and ckpt_disc.model_checkpoint_path:
                    disc_saver.restore(sess, ckpt_disc.model_checkpoint_path)
                    # no_samples is assumed to be defined at module level
                    pred_arr = np.empty([3, no_samples, BATCH_SIZE], dtype=np.float32)
                    for i in tqdm(range(no_samples)):
                        rl, fk = sess.run([real_data, fake_data])
                        fk = fk.reshape([BATCH_SIZE, 64, 64, 3])
                        # First half real images, second half fake images
                        half_real = np.zeros([BATCH_SIZE, 64, 64, 3], dtype=np.float32)
                        half_real[:BATCH_SIZE // 2] = rl[BATCH_SIZE // 2:]
                        half_real[BATCH_SIZE // 2:] = fk[BATCH_SIZE // 2:]
                        pred_arr[0, i, :] = sess.run([disc_fake])[0]
                        pred_arr[1, i, :] = sess.run([disc_real])[0]
                        pred_arr[2, i, :] = sess.run(
                            disc_half, feed_dict={half_real_ph: half_real})
                    fake_mean = np.mean(pred_arr[0])
                    real_mean = np.mean(pred_arr[1])
                    half_mean = np.mean(pred_arr[2])
                    fake_std = np.std(pred_arr[0])
                    real_std = np.std(pred_arr[1])
                    half_std = np.std(pred_arr[2])
                    return real_mean, fake_mean, half_mean, fake_std, real_std, half_std
                else:
                    print("Failed to load Discriminator")
            except Exception as e:
                print(e)
            finally:
                coord.request_stop()
                coord.join(threads)
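# Sketch of one way the returned statistics could feed the explain() step below
# (an assumption about how the pieces connect, not original code): use each
# discriminator's mean critic score as the marginalized_means bias that
# explain() subtracts before applying expit.
checkpoint_names = ["WGAN-GP_celebA"]  # hypothetical checkpoint names
marginalized_means = []
for name in checkpoint_names:
    real_mean, fake_mean, half_mean, _, _, _ = test_function((DCG_module, name, name))
    marginalized_means.append((real_mean + fake_mean) / 2.0)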
def explain(params=None):
    DCG, disc, images_to_explain, d_index, normalize_by_mean = params
    Discriminator = DCG.DCGAND_1
    BATCH_SIZE = FLAGS.batch_size
    with tf.Graph().as_default() as graph:
        train_data_list = helpers.get_dataset_files()
        real_data = input_pipeline(train_data_list, batch_size=BATCH_SIZE)
        # Normalize from [0, 255] to [-1, 1]
        real_data = 2 * ((tf.cast(real_data, tf.float32) / 255.) - .5)
        disc_real, _ = Discriminator(real_data)
        # Build the scoring op once, so repeated classifier calls from LIME do
        # not keep adding Discriminator copies to the graph.
        images_ph = tf.placeholder(tf.float32, [256, 64, 64, 3])
        disc_on_batch, _ = Discriminator(images_ph)
        disc_vars = lib.params_with_name("Discriminator")
        disc_saver = tf.train.Saver(disc_vars)
        ckpt_disc = tf.train.get_checkpoint_state("./saved_models/" + disc + "/")
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            # print('Queue runners started.')
            if ckpt_disc and ckpt_disc.model_checkpoint_path:
                # print("Restoring discriminator...", disc)
                disc_saver.restore(sess, ckpt_disc.model_checkpoint_path)

                def disc_prediction(image):
                    # Make a fixed-size batch; transform [0, 255] input to [-1, 1]
                    if np.max(image) > 1.1 or np.min(image) > 0.0:
                        image = (image.astype(np.float32) * 2.0 / 255.0) - 1.0
                    no_ims = image.shape[0] if len(image.shape) == 4 else 1
                    images_batch = np.zeros([256, 64, 64, 3], dtype=np.float32)
                    images_batch[0:no_ims] = image
                    prediction = sess.run(disc_on_batch,
                                          feed_dict={images_ph: images_batch})
                    # Map the critic's unbounded scores to [0, 1] pseudo-probabilities
                    pred_array = np.zeros((no_ims, 2))
                    # Normalize predictions to see what happens:
                    # prediction = (prediction - np.mean(prediction)) / np.std(prediction)
                    for i, x in enumerate(prediction[:no_ims]):
                        if normalize_by_mean:
                            # marginalized_means is assumed to be module-level
                            bias = marginalized_means[d_index]
                            pred_array[i, 1] = expit(x - bias)
                        else:
                            pred_array[i, 1] = expit(x)
                        pred_array[i, 0] = 1 - pred_array[i, 1]
                    # 1 == REAL; 0 == FAKE
                    return pred_array

                explanations = []
                explainer = lime_image.LimeImageExplainer(verbose=False)
                segmenter = SegmentationAlgorithm(
                    'slic', n_segments=100, compactness=1, sigma=1)
                try:
                    if not len(images_to_explain):
                        # No images supplied: explain a sample of real data
                        images_to_explain = sess.run(real_data)[:no_samples]
                        images_to_explain = (images_to_explain + 1.0) * 255.0 / 2.0
                        images_to_explain = images_to_explain.astype(np.uint8)
                        images_to_explain = np.reshape(
                            images_to_explain, [no_samples, 64, 64, 3])
                    for image_to_explain in tqdm(images_to_explain):
                        explanation = explainer.explain_instance(
                            image_to_explain, classifier_fn=disc_prediction,
                            batch_size=256, top_labels=2, hide_color=None,
                            num_samples=no_perturbed_images,
                            segmentation_fn=segmenter)
                        explanations.append(explanation)
                except KeyboardInterrupt as e:
                    print("Manual interrupt occurred.")
                finally:
                    coord.request_stop()
                    coord.join(threads)
                make_figures(images_to_explain, explanations, DCG.get_G_dim(),
                             DCG.get_D_dim(), normalize_by_mean)
                return images_to_explain
            else:
                print("Failed to load Discriminator", disc)
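# For intuition: expit is the logistic sigmoid, expit(x) = 1 / (1 + exp(-x)).
# It squashes the critic's unbounded score into (0, 1) so LIME can treat it as
# a two-class probability; subtracting a mean bias recenters scores near 0.5.
from scipy.special import expit

print(expit(0.0))        # 0.5    -> maximally uncertain
print(expit(4.2))        # ~0.985 -> confidently "real"
print(expit(4.2 - 4.0))  # ~0.55  -> the same score after mean-bias removal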