def main():
    """Estimate the specificity of two VAEs whose decoded outputs are combined.

    Two ``Variational_Autoencoder`` models are built: ``encoder_mlp`` /
    ``decoder_mlp`` with ``FLAGS.latent_dim`` latents, and ``encoder_mlp2`` /
    ``decoder_mlp_tanh`` with 8 latents. They are restored from
    ``FLAGS.model1`` and ``FLAGS.model2`` respectively.

    1. ``num_of_train`` training batches are passed through ``plot_latent`` of
       both models. The per-dimension mean and variance of the flattened
       results define a diagonal normal distribution for each latent space.
    2. ``num_of_test`` ground-truth patches are read (the first sample of each
       test batch).
    3. ``num_of_generate`` latent vectors are drawn from those normals and
       decoded.
       - Without ``const_bool``, the two decoded patches are summed.
       - With ``const_bool``, ``VAE_2.generate_sample2`` returns the combined
         patch, and the second component is recovered by subtraction.
    4. For every combined patch, the smallest ``utils.L1norm`` distance to any
       ground-truth patch is recorded. The mean of these minima is printed as
       the specificity.

    Outputs under ``FLAGS.outdir``:
      * ``spe1/``, ``spe2/``, ``spe_all/`` -- each generated patch written via
        ``io.write_mhd_and_raw`` as ``.mhd``, plus a ``list.txt`` of the paths.
      * ``specificity.csv`` -- one minimum distance per generated sample.
      * ``Specificity.png`` -- running mean of the specificity.
    """
    # tf flag
    flags = tf.flags
    flags.DEFINE_string("train_data_txt", "E:/git/beta-VAE/input/CT/shift/train.txt", "train data txt")
    flags.DEFINE_string("ground_truth_txt", "E:/git/beta-VAE/input/CT/shift/test.txt", "i1")
    flags.DEFINE_string(
        "model1", 'D:/vae_result/n1/z6/beta_1/model/model_{}'.format(997500), "i2")
    flags.DEFINE_string(
        "model2", 'D:/vae_result/n1+n2/all/sig/beta_1/model/model_{}'.format(197500), "i3")
    flags.DEFINE_string("outdir", "D:/vae_result/n1+n2/all/sig/beta_1/spe/", "i4")
    flags.DEFINE_float("beta", 1, "hyperparameter beta")
    flags.DEFINE_integer("num_of_generate", 5000, "number of generate data")
    flags.DEFINE_integer("num_of_test", 600, "number of test data")
    flags.DEFINE_integer("num_of_train", 1804, "number of train data")
    flags.DEFINE_integer("batch_size", 1, "batch size")
    flags.DEFINE_integer("latent_dim", 6, "latent dim")
    flags.DEFINE_list("image_size", [9 * 9 * 9], "image size")
    flags.DEFINE_boolean("const_bool", False,
                         "if there is sigmoid in front of last output")
    FLAGS = flags.FLAGS

    # Latent size of the second model (not exposed as a flag).
    latent_dim2 = 8
    # Generated vectors are reshaped to patch_side**3 cubes (image_size is 9*9*9).
    patch_side = 9
    volume_shape = [patch_side, patch_side, patch_side]
    sub_dirs = ('spe1', 'spe2', 'spe_all')

    # check folder -- create every sub-directory even when outdir already
    # exists; otherwise opening the list.txt files below fails.
    for sub_dir in sub_dirs:
        os.makedirs(FLAGS.outdir + sub_dir + '/', exist_ok=True)

    # read list
    test_data_list = io.load_list(FLAGS.ground_truth_txt)
    train_data_list = io.load_list(FLAGS.train_data_txt)

    def _next_batch_op(data_list):
        """Build a parsed, batched one-shot TFRecord pipeline; return get_next()."""
        dataset = tf.data.TFRecordDataset(data_list)
        dataset = dataset.map(
            lambda x: _parse_function(x, image_size=FLAGS.image_size),
            num_parallel_calls=os.cpu_count())
        dataset = dataset.batch(FLAGS.batch_size)
        return dataset.make_one_shot_iterator().get_next()

    def _to_image(volume):
        """Wrap a 3-D array as a SimpleITK image with fixed spacing and origin."""
        image = sitk.GetImageFromArray(volume)
        image.SetSpacing([0.885, 0.885, 1])
        image.SetOrigin([0, 0, 0])
        return image

    # load train / test data
    train_data = _next_batch_op(train_data_list)
    test_data = _next_batch_op(test_data_list)

    # initializer
    init_op = tf.group(tf.initializers.global_variables(),
                       tf.initializers.local_variables())

    with tf.Session(config=utils.config) as sess:
        # set network
        kwargs = {
            'sess': sess,
            'outdir': FLAGS.outdir,
            'beta': FLAGS.beta,
            'latent_dim': FLAGS.latent_dim,
            'batch_size': FLAGS.batch_size,
            'image_size': FLAGS.image_size,
            'encoder': encoder_mlp,
            'decoder': decoder_mlp,
            'is_res': False
        }
        VAE = Variational_Autoencoder(**kwargs)
        kwargs_2 = {
            'sess': sess,
            'outdir': FLAGS.outdir,
            'beta': FLAGS.beta,
            'latent_dim': latent_dim2,
            'batch_size': FLAGS.batch_size,
            'image_size': FLAGS.image_size,
            'encoder': encoder_mlp2,
            'decoder': decoder_mlp_tanh,
            'is_res': False,
            'is_constraints': FLAGS.const_bool
        }
        VAE_2 = Variational_Autoencoder(**kwargs_2)

        sess.run(init_op)

        # testing
        VAE.restore_model(FLAGS.model1)
        VAE_2.restore_model(FLAGS.model2)

        # Latent codes of the training set under both models.
        latent_space = []
        latent_space2 = []
        for _ in range(FLAGS.num_of_train):
            train_data_batch = sess.run(train_data)
            latent_space.append(VAE.plot_latent(train_data_batch).flatten())
            latent_space2.append(VAE_2.plot_latent(train_data_batch).flatten())

        mu = np.mean(latent_space, axis=0)
        mu2 = np.mean(latent_space2, axis=0)
        # np.random.normal takes a standard deviation as its scale, so the
        # per-dimension variance must be square-rooted before sampling.
        sigma = np.sqrt(np.var(latent_space, axis=0))
        sigma2 = np.sqrt(np.var(latent_space2, axis=0))

        # Ground-truth patches: the first sample of each test batch.
        ori = [sess.run(test_data)[0, :] for _ in range(FLAGS.num_of_test)]

        specificity = []
        spe_mean = []
        specificity_sum = 0.0
        with open(FLAGS.outdir + 'spe1/list.txt', 'w') as file_spe1, \
                open(FLAGS.outdir + 'spe2/list.txt', 'w') as file_spe2, \
                open(FLAGS.outdir + 'spe_all/list.txt', 'w') as file_spe_all:
            list_files = {
                'spe1': file_spe1,
                'spe2': file_spe2,
                'spe_all': file_spe_all,
            }
            for j in tqdm(range(FLAGS.num_of_generate), ascii=True):
                sample_z = np.random.normal(mu, sigma, (1, FLAGS.latent_dim))
                sample_z2 = np.random.normal(mu2, sigma2, (1, latent_dim2))
                decoded1 = VAE.generate_sample(sample_z)
                if FLAGS.const_bool:
                    # The second model receives the first decoder's raw
                    # output and returns the combined patch.
                    generate_data_single_all = VAE_2.generate_sample2(
                        sample_z2, decoded1)[0, :]
                    generate_data_single = decoded1[0, :]
                    generate_data_single2 = (generate_data_single_all
                                             - generate_data_single)
                else:
                    generate_data_single = decoded1[0, :]
                    generate_data_single2 = VAE_2.generate_sample(sample_z2)[0, :]
                    generate_data_single_all = (generate_data_single
                                                + generate_data_single2)

                gen = np.reshape(generate_data_single, volume_shape)
                gen2 = np.reshape(generate_data_single2, volume_shape)
                gen_all = np.reshape(generate_data_single_all, volume_shape)

                # calculation: distance to the closest ground-truth patch.
                # The true minimum is taken (no upper cap); 1.0 is used only
                # when there are no ground-truth patches at all.
                case_min_specificity = min(
                    (utils.L1norm(o, generate_data_single_all) for o in ori),
                    default=1.0)
                specificity.append([case_min_specificity])
                # Running mean kept incrementally rather than re-averaging
                # the whole history on every iteration.
                specificity_sum += case_min_specificity
                spe_mean.append(specificity_sum / len(specificity))

                # Write each volume and record its path in the matching list.
                for sub_dir, volume in zip(sub_dirs, (gen, gen2, gen_all)):
                    path = '{}.mhd'.format(os.path.join(
                        FLAGS.outdir, sub_dir, '{}_{}'.format(sub_dir, j + 1)))
                    io.write_mhd_and_raw(_to_image(volume), path)
                    list_files[sub_dir].write(path + "\n")

    print('specificity = %f' % np.mean(specificity))
    np.savetxt(os.path.join(FLAGS.outdir, 'specificity.csv'),
               specificity, delimiter=",")

    # spe graph
    plt.plot(spe_mean)
    plt.grid()
    plt.savefig(FLAGS.outdir + "Specificity.png")
def main():
    """Encode a TFRecord data set with a trained VAE and summarise the codes.

    The checkpoint ``<dir>/model/model_<model_index>`` is restored. Every
    batch from the GZIP-compressed TFRecords listed in ``test_data_txt`` is
    passed through ``plot_latent``, and the flattened results are collected.
    The function then:

    * saves a scatter plot of the first two latent dimensions as
      ``latent_distribution_<model_index>.PNG`` (a 1-D latent space is
      plotted against a constant zero axis);
    * prints the mean, unbiased variance (``ddof=1``), covariance matrix,
      skewness and kurtosis of the codes;
    * writes the mean and variance to ``mean_<model_index>.txt`` and
      ``var_<model_index>.txt``.

    All outputs go to ``FLAGS.dir``, which is created if missing.
    """
    # tf flag
    flags = tf.flags
    flags.DEFINE_string("test_data_txt", 'F:/data_info/VAE_liver/set_5/TFrecord/fold_1/train.txt', "test data txt")
    flags.DEFINE_string("dir", 'G:/experiment_result/liver/VAE/set_5/down/64/alpha_0.1/fold_1/VAE/axis_4/beta_6', "input dir")
    flags.DEFINE_integer("model_index", 3450, "index of model")
    flags.DEFINE_string("gpu_index", "0", "GPU-index")
    flags.DEFINE_float("beta", 1.0, "hyperparameter beta")
    flags.DEFINE_integer("num_of_test", 4681, "number of test data")
    flags.DEFINE_integer("batch_size", 1, "batch size")
    flags.DEFINE_integer("latent_dim", 4, "latent dim")
    flags.DEFINE_list("image_size", [56, 72, 88, 1], "image size")
    FLAGS = flags.FLAGS

    # check folder
    os.makedirs(FLAGS.dir, exist_ok=True)

    # read list
    test_data_list = io.load_list(FLAGS.test_data_txt)

    # test step: number of batches, counting a final partial batch
    test_step = FLAGS.num_of_test // FLAGS.batch_size
    if FLAGS.num_of_test % FLAGS.batch_size != 0:
        test_step += 1

    # load test data
    test_set = tf.data.TFRecordDataset(test_data_list, compression_type='GZIP')
    test_set = test_set.map(
        lambda x: utils._parse_function(x, image_size=FLAGS.image_size),
        num_parallel_calls=os.cpu_count())
    test_set = test_set.batch(FLAGS.batch_size)
    test_iter = test_set.make_one_shot_iterator()
    test_data = test_iter.get_next()

    # initializer
    init_op = tf.group(tf.initializers.global_variables(),
                       tf.initializers.local_variables())

    with tf.Session(config=utils.config(index=FLAGS.gpu_index)) as sess:
        # set network
        kwargs = {
            'sess': sess,
            'outdir': FLAGS.dir,
            'beta': FLAGS.beta,
            'latent_dim': FLAGS.latent_dim,
            'batch_size': FLAGS.batch_size,
            'image_size': FLAGS.image_size,
            'encoder': encoder_resblock_bn,
            'decoder': decoder_resblock_bn,
            'downsampling': down_sampling,
            'upsampling': up_sampling,
            'is_training': False,
            'is_down': False
        }
        VAE = Variational_Autoencoder(**kwargs)

        sess.run(init_op)

        # testing
        VAE.restore_model(os.path.join(FLAGS.dir, 'model',
                                       'model_{}'.format(FLAGS.model_index)))

        # NOTE(review): flatten() merges everything plot_latent returns for a
        # batch into one row; presumably that is one code per row only when
        # batch_size == 1 -- confirm before using larger batches.
        latent_space = []
        for _ in tqdm(range(test_step), ascii=True):
            latent_space.append(VAE.plot_latent(sess.run(test_data)).flatten())
        latent_space = np.asarray(latent_space)

    # A 1-D latent space is drawn against a constant zero axis. The padding is
    # applied to the plot only, so it does not distort the statistics below.
    if FLAGS.latent_dim == 1:
        plot_y = np.zeros(len(latent_space))
    else:
        plot_y = latent_space[:, 1]
    plt.figure(figsize=(8, 6))
    plt.scatter(latent_space[:, 0], plot_y, alpha=0.2)
    plt.title('latent distribution')
    plt.xlabel('dim_1')
    plt.ylabel('dim_2')
    plt.savefig(os.path.join(
        FLAGS.dir, 'latent_distribution_{}.PNG'.format(FLAGS.model_index)))
    plt.close()

    mean = np.average(latent_space, axis=0)
    var = np.var(latent_space, axis=0, ddof=1)
    print(mean)
    print(var)
    print(np.cov(latent_space.transpose()))
    print('skew, kurtosis')
    print(skew(latent_space, axis=0))
    print(kurtosis(latent_space, axis=0))

    # output mean and var
    np.savetxt(os.path.join(FLAGS.dir, 'mean_{}.txt'.format(FLAGS.model_index)), mean)
    np.savetxt(os.path.join(FLAGS.dir, 'var_{}.txt'.format(FLAGS.model_index)), var)
def main():
    """Visualise the latent space of a trained MLP VAE on 9x9x9 patches.

    ``num_of_test`` test samples are passed through ``plot_latent``. The
    function then produces the following files under ``FLAGS.outdir``:

    * ``latent_space.png`` and ``back_projection.png`` -- scatter plots of
      the first two latent dimensions (the second highlights the first five
      samples). For a 3-D latent space, ``utils.matplotlib_plt`` is also
      called.
    * A 13x13 sweep of the first two latent dimensions over +-3 standard
      deviations, with every other dimension held at 0. Each decoded patch's
      central axial slice is saved to ``morphing/<i>@<j>fig.png``, and the
      central axial / coronal / sagittal slices are tiled into
      ``digit_axial.png``, ``digit_coronal.png`` and ``digit_sagital.png``.
    """
    # tf flag
    flags = tf.flags
    flags.DEFINE_string("test_data_txt", "./input/axis2/noise/test.txt", "i1")
    flags.DEFINE_string(
        "model",
        './output/axis2/noise/model2/z24/alpha_1e-5/model/model_{}'.format(9072000),
        "i2")
    flags.DEFINE_string("outdir",
                        "./output/axis2/noise/model2/z24/alpha_1e-5/latent/",
                        "i3")
    flags.DEFINE_float("beta", 1, "hyperparameter beta")
    flags.DEFINE_integer("num_of_test", 3000, "number of test data")
    flags.DEFINE_integer("batch_size", 1, "batch size")
    flags.DEFINE_integer("latent_dim", 24, "latent dim")
    flags.DEFINE_list("image_size", [9 * 9 * 9], "image size")
    FLAGS = flags.FLAGS

    # check folder -- 'morphing/' is needed even when outdir already exists.
    os.makedirs(FLAGS.outdir + 'morphing/', exist_ok=True)

    # read list
    test_data_list = io.load_list(FLAGS.test_data_txt)

    # test step: number of batches, counting a final partial batch
    test_step = FLAGS.num_of_test // FLAGS.batch_size
    if FLAGS.num_of_test % FLAGS.batch_size != 0:
        test_step += 1

    # load test data
    test_set = tf.data.TFRecordDataset(test_data_list)
    test_set = test_set.map(
        lambda x: _parse_function(x, image_size=FLAGS.image_size),
        num_parallel_calls=os.cpu_count())
    test_set = test_set.batch(FLAGS.batch_size)
    test_iter = test_set.make_one_shot_iterator()
    test_data = test_iter.get_next()

    # initializer
    init_op = tf.group(tf.initializers.global_variables(),
                       tf.initializers.local_variables())

    with tf.Session(config=utils.config) as sess:
        # set network
        kwargs = {
            'sess': sess,
            'outdir': FLAGS.outdir,
            'beta': FLAGS.beta,
            'latent_dim': FLAGS.latent_dim,
            'batch_size': FLAGS.batch_size,
            'image_size': FLAGS.image_size,
            'encoder': encoder_mlp,
            'decoder': decoder_mlp
        }
        VAE = Variational_Autoencoder(**kwargs)

        sess.run(init_op)

        patch_side = 9
        patch_center = patch_side // 2

        # testing
        VAE.restore_model(FLAGS.model)

        latent_space = []
        for _ in tqdm(range(test_step), ascii=True):
            latent_space.append(VAE.plot_latent(sess.run(test_data)).flatten())
        latent_space = np.asarray(latent_space)

        sigma = np.sqrt(np.var(latent_space, axis=0))

        plt.figure(figsize=(8, 6))
        plt.scatter(latent_space[:, 0], latent_space[:, 1])
        plt.xlabel('dim_1')
        plt.ylabel('dim_2')
        plt.title('latent distribution')
        plt.savefig(FLAGS.outdir + "latent_space.png")

        if FLAGS.latent_dim == 3:
            os.makedirs(FLAGS.outdir + "3D/", exist_ok=True)
            utils.matplotlib_plt(latent_space, FLAGS.outdir)

        plt.figure(figsize=(8, 6))
        plt.scatter(latent_space[:, 0], latent_space[:, 1])
        plt.scatter(latent_space[:5, 0], latent_space[:5, 1], color='orange')
        plt.title('latent distribution')
        plt.xlabel('dim_1')
        plt.ylabel('dim_2')
        plt.savefig(FLAGS.outdir + "back_projection.png")

        # display a 2D manifold of decoded patches
        plt.figure()
        n = 13
        digit_size = patch_side
        figure1 = np.zeros((digit_size * n, digit_size * n))
        figure2 = np.zeros((digit_size * n, digit_size * n))
        figure3 = np.zeros((digit_size * n, digit_size * n))
        # Linearly spaced coordinates over +-3 sigma of the first two dims;
        # y is reversed so larger values appear at the top of the image.
        grid_x = np.linspace(-3 * sigma[0], 3 * sigma[0], n)
        grid_y = np.linspace(-3 * sigma[1], 3 * sigma[1], n)[::-1]

        for i, yi in enumerate(grid_y):
            for j, xi in enumerate(grid_x):
                # Vary the first two latent dims and keep the rest at 0; this
                # works for any latent_dim >= 2, not just a fixed list.
                z_sample = np.zeros((1, FLAGS.latent_dim))
                z_sample[0, 0] = xi
                z_sample[0, 1] = yi
                x_decoded = VAE.generate_sample(z_sample)
                generate_data = x_decoded[0].reshape(digit_size, digit_size,
                                                     digit_size)
                digit_axial = generate_data[patch_center, :, :]
                digit_coronal = generate_data[:, patch_center, :]
                digit_sagital = generate_data[:, :, patch_center]

                # Clear the shared axes first so successive images do not
                # pile up on one figure.
                plt.cla()
                plt.imshow(digit_axial, cmap='Greys_r', vmin=0, vmax=1,
                           interpolation='none')
                plt.savefig(FLAGS.outdir + 'morphing/' + str(i) + '@' + str(j)
                            + 'fig.png')

                rows = slice(i * digit_size, (i + 1) * digit_size)
                cols = slice(j * digit_size, (j + 1) * digit_size)
                figure1[rows, cols] = digit_axial
                figure2[rows, cols] = digit_coronal
                figure3[rows, cols] = digit_sagital

        # set graph: one tick at the centre of each of the n tiles (the
        # previous arange produced n + 1 ticks for n labels).
        start_range = digit_size // 2
        pixel_range = start_range + digit_size * np.arange(n)
        sample_range_x = np.round(grid_x, 1)
        sample_range_y = np.round(grid_y, 1)

        for figure, filename in ((figure1, "digit_axial.png"),
                                 (figure2, "digit_coronal.png"),
                                 (figure3, "digit_sagital.png")):
            plt.figure(figsize=(10, 10))
            plt.xticks(pixel_range, sample_range_x)
            plt.yticks(pixel_range, sample_range_y)
            plt.xlabel("z[0]")
            plt.ylabel("z[1]")
            plt.imshow(figure, cmap='Greys_r', vmin=0, vmax=1,
                       interpolation='none')
            plt.savefig(FLAGS.outdir + filename)
def main():
    """Scatter-plot the latent codes of a CNN VAE on a test set.

    Command-line arguments:
      --test_data_txt / -i1  text file listing the TFRecord files to read.
      --model / -i2          checkpoint passed to ``restore_model``.
      --outdir / -i3         output directory, created if it does not exist
                             (an empty value means the current directory).

    ``num_of_test`` samples are passed through ``plot_latent``. Their first
    two flattened dimensions are shown interactively with ``plt.show()``.
    """
    parser = argparse.ArgumentParser(
        description='py, test_data_txt, model, outdir')
    parser.add_argument('--test_data_txt', '-i1', default='')
    parser.add_argument('--model', '-i2', default='./model_{}'.format(50000))
    parser.add_argument('--outdir', '-i3', default='')
    args = parser.parse_args()

    # check folder -- skip the empty default: os.path.exists('') is False and
    # os.makedirs('') would raise instead of using the current directory.
    if args.outdir:
        os.makedirs(args.outdir, exist_ok=True)

    # tf flag
    flags = tf.flags
    flags.DEFINE_float("beta", 0.1, "hyperparameter beta")
    flags.DEFINE_integer("num_of_test", 100, "number of test data")
    flags.DEFINE_integer("batch_size", 1, "batch size")
    flags.DEFINE_integer("latent_dim", 2, "latent dim")
    flags.DEFINE_list("image_size", [512, 512, 1], "image size")
    FLAGS = flags.FLAGS

    # read list
    test_data_list = io.load_list(args.test_data_txt)

    # test step: number of batches, counting a final partial batch
    test_step = FLAGS.num_of_test // FLAGS.batch_size
    if FLAGS.num_of_test % FLAGS.batch_size != 0:
        test_step += 1

    # load test data
    test_set = tf.data.TFRecordDataset(test_data_list)
    test_set = test_set.map(
        lambda x: _parse_function(x, image_size=FLAGS.image_size),
        num_parallel_calls=os.cpu_count())
    test_set = test_set.batch(FLAGS.batch_size)
    test_iter = test_set.make_one_shot_iterator()
    test_data = test_iter.get_next()

    # initializer
    init_op = tf.group(tf.initializers.global_variables(),
                       tf.initializers.local_variables())

    with tf.Session(config=utils.config) as sess:
        # set network
        kwargs = {
            'sess': sess,
            'outdir': args.outdir,
            'beta': FLAGS.beta,
            'latent_dim': FLAGS.latent_dim,
            'batch_size': FLAGS.batch_size,
            'image_size': FLAGS.image_size,
            'encoder': cnn_encoder,
            'decoder': cnn_decoder
        }
        VAE = Variational_Autoencoder(**kwargs)

        sess.run(init_op)

        # testing
        VAE.restore_model(args.model)

        latent_space = []
        for _ in tqdm(range(test_step), ascii=True):
            latent_space.append(VAE.plot_latent(sess.run(test_data)).flatten())
        latent_space = np.asarray(latent_space)

        plt.figure(figsize=(8, 6))
        plt.scatter(latent_space[:, 0], latent_space[:, 1])
        plt.title('latent distribution')
        plt.xlabel('dim_1')
        plt.ylabel('dim_2')
        plt.show()