Example #1
0
def main():
    """Generate samples from two VAEs and measure their specificity.

    Both models encode ``num_of_train`` training patches; the per-dimension
    mean and standard deviation of each model's latent codes parameterise a
    Gaussian from which ``num_of_generate`` latent vectors are drawn and
    decoded.  For every draw the model-1 output (``spe1``), the model-2
    output (``spe2``) and their combination (``spe_all``) are written as
    9x9x9 .mhd images and listed in ``<subdir>/list.txt``.  The specificity
    of a draw is the smallest ``utils.L1norm`` between the combined output
    and any of the ``num_of_test`` ground-truth patches; the per-draw values
    go to ``specificity.csv`` and their running mean to ``Specificity.png``.

    When ``const_bool`` is set, model 2 is driven through
    ``generate_sample2`` with model 1's output and its result is used as the
    combined volume; ``spe2`` is then the difference to model 1's output.

    Raises:
        ValueError: if ``num_of_test`` is not positive (no patch to compare).
    """

    # tf flag
    flags = tf.flags
    flags.DEFINE_string("train_data_txt",
                        "E:/git/beta-VAE/input/CT/shift/train.txt",
                        "train data txt")
    flags.DEFINE_string("ground_truth_txt",
                        "E:/git/beta-VAE/input/CT/shift/test.txt", "i1")
    flags.DEFINE_string(
        "model1", 'D:/vae_result/n1/z6/beta_1/model/model_{}'.format(997500),
        "i2")
    flags.DEFINE_string(
        "model2",
        'D:/vae_result/n1+n2/all/sig/beta_1/model/model_{}'.format(197500),
        "i3")
    flags.DEFINE_string("outdir", "D:/vae_result/n1+n2/all/sig/beta_1/spe/",
                        "i4")
    flags.DEFINE_float("beta", 1, "hyperparameter beta")
    flags.DEFINE_integer("num_of_generate", 5000, "number of generate data")
    flags.DEFINE_integer("num_of_test", 600, "number of test data")
    flags.DEFINE_integer("num_of_train", 1804, "number of train data")
    flags.DEFINE_integer("batch_size", 1, "batch size")
    flags.DEFINE_integer("latent_dim", 6, "latent dim")
    flags.DEFINE_list("image_size", [9 * 9 * 9], "image size")
    flags.DEFINE_boolean("const_bool", False,
                         "if there is sigmoid in front of last output")
    FLAGS = flags.FLAGS

    # Latent size of the second model: used both to build it and to sample.
    latent_dim2 = 8
    patch_side = 9
    subdirs = ('spe1', 'spe2', 'spe_all')

    # check folder -- create every sub-directory even when outdir already
    # exists, otherwise the open() calls for the list files below fail.
    for sub in subdirs:
        os.makedirs(os.path.join(FLAGS.outdir, sub), exist_ok=True)

    def _next_batch(file_list):
        """Build a batched TFRecord pipeline and return its get_next() op."""
        dataset = tf.data.TFRecordDataset(file_list)
        dataset = dataset.map(
            lambda x: _parse_function(x, image_size=FLAGS.image_size),
            num_parallel_calls=os.cpu_count())
        dataset = dataset.batch(FLAGS.batch_size)
        return dataset.make_one_shot_iterator().get_next()

    # load train / test data
    train_data = _next_batch(io.load_list(FLAGS.train_data_txt))
    test_data = _next_batch(io.load_list(FLAGS.ground_truth_txt))

    # initializer
    init_op = tf.group(tf.initializers.global_variables(),
                       tf.initializers.local_variables())

    with tf.Session(config=utils.config) as sess:

        # set network
        kwargs = {
            'sess': sess,
            'outdir': FLAGS.outdir,
            'beta': FLAGS.beta,
            'latent_dim': FLAGS.latent_dim,
            'batch_size': FLAGS.batch_size,
            'image_size': FLAGS.image_size,
            'encoder': encoder_mlp,
            'decoder': decoder_mlp,
            'is_res': False
        }
        VAE = Variational_Autoencoder(**kwargs)
        kwargs_2 = {
            'sess': sess,
            'outdir': FLAGS.outdir,
            'beta': FLAGS.beta,
            'latent_dim': latent_dim2,
            'batch_size': FLAGS.batch_size,
            'image_size': FLAGS.image_size,
            'encoder': encoder_mlp2,
            'decoder': decoder_mlp_tanh,
            'is_res': False,
            'is_constraints': FLAGS.const_bool
        }
        VAE_2 = Variational_Autoencoder(**kwargs_2)

        sess.run(init_op)

        # testing
        VAE.restore_model(FLAGS.model1)
        VAE_2.restore_model(FLAGS.model2)

        # Encode the training set with both models to estimate the latent
        # distribution each one is sampled from.
        latent_space = []
        latent_space2 = []
        for _ in range(FLAGS.num_of_train):
            train_data_batch = sess.run(train_data)
            latent_space.append(VAE.plot_latent(train_data_batch).flatten())
            latent_space2.append(
                VAE_2.plot_latent(train_data_batch).flatten())

        # np.random.normal takes a standard deviation as its scale, so use
        # std (= sqrt of the variance) rather than the variance itself.
        mu = np.mean(latent_space, axis=0)
        sigma = np.std(latent_space, axis=0)
        mu2 = np.mean(latent_space2, axis=0)
        sigma2 = np.std(latent_space2, axis=0)

        # Ground-truth patches: first element of each test batch.
        ori = [sess.run(test_data)[0, :] for _ in range(FLAGS.num_of_test)]
        if not ori:
            raise ValueError('num_of_test must be positive')

        specificity = []
        spe_mean = []
        spe_sum = 0.0
        list_paths = [
            os.path.join(FLAGS.outdir, sub, 'list.txt') for sub in subdirs
        ]
        with open(list_paths[0], 'w') as file_spe1, \
                open(list_paths[1], 'w') as file_spe2, \
                open(list_paths[2], 'w') as file_spe_all:
            list_files = (file_spe1, file_spe2, file_spe_all)

            for j in tqdm(range(FLAGS.num_of_generate), ascii=True):
                sample_z = np.random.normal(mu, sigma, (1, FLAGS.latent_dim))
                sample_z2 = np.random.normal(mu2, sigma2, (1, latent_dim2))
                gen_batch = VAE.generate_sample(sample_z)
                if FLAGS.const_bool:
                    # generate_sample2 receives model 1's batched output and
                    # its result is used as the combined volume.
                    gen_all = VAE_2.generate_sample2(sample_z2,
                                                     gen_batch)[0, :]
                    gen1 = gen_batch[0, :]
                    gen2 = gen_all - gen1
                else:
                    gen1 = gen_batch[0, :]
                    gen2 = VAE_2.generate_sample(sample_z2)[0, :]
                    gen_all = gen1 + gen2

                # calculation: distance to the closest ground-truth patch
                case_min_specificity = min(
                    utils.L1norm(ori_single, gen_all) for ori_single in ori)
                specificity.append([case_min_specificity])
                spe_sum += case_min_specificity
                spe_mean.append(spe_sum / (j + 1))

                # EUDT: write spe1 / spe2 / spe_all volumes and list them
                for volume, sub, list_file in zip((gen1, gen2, gen_all),
                                                  subdirs, list_files):
                    image = sitk.GetImageFromArray(
                        np.reshape(volume, [patch_side] * 3))
                    image.SetSpacing([0.885, 0.885, 1])
                    image.SetOrigin([0, 0, 0])
                    path = '{}.mhd'.format(
                        os.path.join(FLAGS.outdir, sub,
                                     '{}_{}'.format(sub, j + 1)))
                    io.write_mhd_and_raw(image, path)
                    list_file.write(path + "\n")

    print('specificity = %f' % np.mean(specificity))
    np.savetxt(os.path.join(FLAGS.outdir, 'specificity.csv'),
               specificity,
               delimiter=",")

    # spe graph (running mean of the specificity)
    plt.plot(spe_mean)
    plt.grid()
    plt.savefig(os.path.join(FLAGS.outdir, "Specificity.png"))
Example #2
0
def main():
    """Sample a 2-D VAE from N(0, 1) and score samples by Jaccard index.

    ``num_of_generate`` latent vectors are drawn from a standard normal and
    decoded.  Each decoded 2-D image is written to ``EUDT/`` and its
    thresholded label (``1`` where the output is not > 0, else ``0``) to
    ``label/`` under ``--outdir``.  The specificity of a sample is its best
    ``utils.jaccard`` score against the ground-truth labels; the per-sample
    values and their mean are written to ``specificity.csv``.
    """
    parser = argparse.ArgumentParser(
        description='py, test_data_txt, ground_truth_txt, outdir')

    parser.add_argument('--ground_truth_txt', '-i1', default='')

    parser.add_argument('--model', '-i2', default='./model_{}'.format(50000))

    parser.add_argument('--outdir', '-i3', default='')

    args = parser.parse_args()

    # check folder.  os.makedirs('') raises for the default outdir '', so
    # create the image sub-directories directly; this also creates outdir
    # when needed and tolerates directories that already exist.
    eudt_dir = os.path.join(args.outdir, 'EUDT')
    label_dir = os.path.join(args.outdir, 'label')
    for directory in (eudt_dir, label_dir):
        os.makedirs(directory, exist_ok=True)

    # tf flag
    flags = tf.flags
    flags.DEFINE_float("beta", 0.1, "hyperparameter beta")
    flags.DEFINE_integer("num_of_generate", 1000, "number of generate data")
    flags.DEFINE_integer("batch_size", 1, "batch size")
    flags.DEFINE_integer("latent_dim", 2, "latent dim")
    flags.DEFINE_list("image_size", [512, 512, 1], "image size")
    FLAGS = flags.FLAGS

    # load ground truth
    ground_truth = io.load_matrix_data(args.ground_truth_txt, 'int32')
    print(ground_truth.shape)

    # initializer
    init_op = tf.group(tf.initializers.global_variables(),
                       tf.initializers.local_variables())

    with tf.Session(config=utils.config) as sess:

        # set network
        kwargs = {
            'sess': sess,
            'outdir': args.outdir,
            'beta': FLAGS.beta,
            'latent_dim': FLAGS.latent_dim,
            'batch_size': FLAGS.batch_size,
            'image_size': FLAGS.image_size,
            'encoder': cnn_encoder,
            'decoder': cnn_decoder
        }
        VAE = Variational_Autoencoder(**kwargs)

        sess.run(init_op)

        # testing
        VAE.restore_model(args.model)
        specificity = []
        for i in tqdm(range(FLAGS.num_of_generate), ascii=True):
            sample_z = np.random.normal(0, 1.0, (1, FLAGS.latent_dim))
            generate_data = VAE.generate_sample(sample_z)[0, :, :, 0]

            # EUDT
            eudt_image = sitk.GetImageFromArray(generate_data)
            eudt_image.SetSpacing([1, 1])
            eudt_image.SetOrigin([0, 0])

            # label
            label = np.where(generate_data > 0, 0, 1)
            label_image = sitk.GetImageFromArray(label)
            label_image.SetSpacing([1, 1])
            label_image.SetOrigin([0, 0])

            # calculate ji: best match over all ground-truth labels
            case_max_ji = 0.
            for gt in ground_truth:
                case_max_ji = max(case_max_ji, utils.jaccard(label, gt))
            specificity.append([case_max_ji])

            # output image
            io.write_mhd_and_raw(
                eudt_image, '{}.mhd'.format(os.path.join(eudt_dir,
                                                         str(i + 1))))
            io.write_mhd_and_raw(
                label_image,
                '{}.mhd'.format(os.path.join(label_dir, str(i + 1))))

    print('specificity = %f' % np.mean(specificity))

    # output csv file
    with open(os.path.join(args.outdir, 'specificity.csv'), 'w',
              newline='') as file:
        writer = csv.writer(file)
        writer.writerows(specificity)
        writer.writerow(['specificity:', np.mean(specificity)])
Example #3
0
def main():
    """Plot the latent space of a trained VAE and a decoded 2-D manifold.

    Every test patch is encoded.  The codes are scattered on their first two
    dimensions to ``latent_space.png`` and ``back_projection.png``; the
    latter highlights the first five codes.  For ``latent_dim == 3`` the
    codes are also passed to ``utils.matplotlib_plt``.

    A 13x13 grid spanning -3..3 standard deviations (centred on 0) of the
    first two latent dimensions, all other dimensions fixed at 0, is then
    decoded into 9x9x9 patches.  Each cell's axial centre slice is saved to
    ``morphing/``, and the axial/coronal/sagittal centre slices are tiled
    into ``digit_axial.png``, ``digit_coronal.png`` and ``digit_sagital.png``.
    """

    # tf flag
    flags = tf.flags
    # flags.DEFINE_string("test_data_txt", "./input/CT/patch/test.txt", "i1")
    flags.DEFINE_string("test_data_txt", "./input/axis2/noise/test.txt", "i1")
    # flags.DEFINE_string("model", './output/CT/patch/model2/z24/alpha_1e-5/beta_0.1/fine/model/model_{}'.format(244000), "i2")
    # flags.DEFINE_string("outdir", "./output/CT/patch/model2/z24/alpha_1e-5/beta_0.1/fine/latent/", "i3")
    flags.DEFINE_string(
        "model",
        './output/axis2/noise/model2/z24/alpha_1e-5/model/model_{}'.format(
            9072000), "i2")
    flags.DEFINE_string("outdir",
                        "./output/axis2/noise/model2/z24/alpha_1e-5/latent/",
                        "i3")
    flags.DEFINE_float("beta", 1, "hyperparameter beta")
    # flags.DEFINE_integer("num_of_test", 607, "number of test data")
    flags.DEFINE_integer("num_of_test", 3000, "number of test data")
    flags.DEFINE_integer("batch_size", 1, "batch size")
    flags.DEFINE_integer("latent_dim", 24, "latent dim")
    flags.DEFINE_list("image_size", [9 * 9 * 9], "image size")
    FLAGS = flags.FLAGS

    patch_side = 9
    patch_center = patch_side // 2
    planes = ('axial', 'coronal', 'sagital')

    # check folder -- create morphing/ even when outdir already exists,
    # otherwise the per-cell savefig below fails.
    morphing_dir = os.path.join(FLAGS.outdir, 'morphing')
    os.makedirs(morphing_dir, exist_ok=True)

    # read list
    test_data_list = io.load_list(FLAGS.test_data_txt)

    # test step: ceil(num_of_test / batch_size)
    test_step = FLAGS.num_of_test // FLAGS.batch_size
    if FLAGS.num_of_test % FLAGS.batch_size != 0:
        test_step += 1

    # load test data
    test_set = tf.data.TFRecordDataset(test_data_list)
    test_set = test_set.map(
        lambda x: _parse_function(x, image_size=FLAGS.image_size),
        num_parallel_calls=os.cpu_count())
    test_set = test_set.batch(FLAGS.batch_size)
    test_iter = test_set.make_one_shot_iterator()
    test_data = test_iter.get_next()

    # initializer
    init_op = tf.group(tf.initializers.global_variables(),
                       tf.initializers.local_variables())

    with tf.Session(config=utils.config) as sess:

        # set network
        kwargs = {
            'sess': sess,
            'outdir': FLAGS.outdir,
            'beta': FLAGS.beta,
            'latent_dim': FLAGS.latent_dim,
            'batch_size': FLAGS.batch_size,
            'image_size': FLAGS.image_size,
            'encoder': encoder_mlp,
            'decoder': decoder_mlp
        }
        VAE = Variational_Autoencoder(**kwargs)

        sess.run(init_op)

        # testing: encode every test batch
        VAE.restore_model(FLAGS.model)
        latent_space = []
        for _ in tqdm(range(test_step), ascii=True):
            test_data_batch = sess.run(test_data)
            latent_space.append(VAE.plot_latent(test_data_batch).flatten())
        latent_space = np.asarray(latent_space)
        # Only the spread is needed: the manifold grid is centred on 0.
        sigma = np.std(latent_space, axis=0)

        plt.figure(figsize=(8, 6))
        plt.scatter(latent_space[:, 0], latent_space[:, 1])
        plt.xlabel('dim_1')
        plt.ylabel('dim_2')
        plt.title('latent distribution')
        plt.savefig(os.path.join(FLAGS.outdir, "latent_space.png"))

        if FLAGS.latent_dim == 3:
            os.makedirs(os.path.join(FLAGS.outdir, "3D"), exist_ok=True)
            utils.matplotlib_plt(latent_space, FLAGS.outdir)

        plt.figure(figsize=(8, 6))
        plt.scatter(latent_space[:, 0], latent_space[:, 1])
        plt.scatter(latent_space[:5, 0], latent_space[:5, 1], color='orange')
        plt.title('latent distribution')
        plt.xlabel('dim_1')
        plt.ylabel('dim_2')
        plt.savefig(os.path.join(FLAGS.outdir, "back_projection.png"))

        #### display a 2D manifold of decoded patches
        plt.figure()
        n = 13
        digit_size = patch_side
        figures = {
            plane: np.zeros((digit_size * n, digit_size * n))
            for plane in planes
        }
        # linearly spaced coordinates over the first two latent dimensions
        grid_x = np.linspace(-3 * sigma[0], 3 * sigma[0], n)
        grid_y = np.linspace(-3 * sigma[1], 3 * sigma[1], n)[::-1]

        for i, yi in enumerate(grid_y):
            for j, xi in enumerate(grid_x):
                # Vary the first two latent dimensions; the rest stay at 0.
                # Works for any latent_dim >= 2.
                z_sample = np.zeros((1, FLAGS.latent_dim))
                z_sample[0, :2] = (xi, yi)

                generate_data = VAE.generate_sample(z_sample)[0].reshape(
                    digit_size, digit_size, digit_size)
                slices = {
                    'axial': generate_data[patch_center, :, :],
                    'coronal': generate_data[:, patch_center, :],
                    'sagital': generate_data[:, :, patch_center],
                }

                # Clear the figure so images do not pile up on one axes
                # (each savefig would otherwise re-render every earlier cell).
                plt.clf()
                plt.imshow(slices['axial'],
                           cmap='Greys_r',
                           vmin=0,
                           vmax=1,
                           interpolation='none')
                plt.savefig(
                    os.path.join(morphing_dir, '{}@{}fig.png'.format(i, j)))

                rows = slice(i * digit_size, (i + 1) * digit_size)
                cols = slice(j * digit_size, (j + 1) * digit_size)
                for plane, digit in slices.items():
                    figures[plane][rows, cols] = digit

        # set graph: one tick at the centre of each of the n cells, so the
        # tick count always matches the n labels.
        pixel_range = digit_size // 2 + digit_size * np.arange(n)
        sample_range_x = np.round(grid_x, 1)
        sample_range_y = np.round(grid_y, 1)

        for plane in planes:
            plt.figure(figsize=(10, 10))
            plt.xticks(pixel_range, sample_range_x)
            plt.yticks(pixel_range, sample_range_y)
            plt.xlabel("z[0]")
            plt.ylabel("z[1]")
            plt.imshow(figures[plane],
                       cmap='Greys_r',
                       vmin=0,
                       vmax=1,
                       interpolation='none')
            plt.savefig(
                os.path.join(FLAGS.outdir, "digit_{}.png".format(plane)))
Example #4
0
def main():
    """Decode a 5x5 grid of latent points of a 2-D VAE into label volumes.

    The grid covers -2..2 standard deviations around a hard-coded latent
    mean in each of the two latent dimensions.  Each decoded volume is
    thresholded (``0`` where the output is > 0.5, else ``1``) and written as
    ``2_dim/<i>_<j>.mhd`` under ``outdir``.
    """
    # tf flag
    flags = tf.flags
    flags.DEFINE_string(
        "model",
        'G:/experiment_result/liver/VAE/set_4/down_64/RBF/alpha_0.1/4/beta_10/model/model_{}'
        .format(1350), "model")
    flags.DEFINE_string(
        "outdir",
        'G:/experiment_result/liver/VAE/set_4/down_64/RBF/alpha_0.1/4/beta_10/random',
        "outdir")
    flags.DEFINE_string("gpu_index", "0", "GPU-index")
    flags.DEFINE_float("beta", 1.0, "hyperparameter beta")
    flags.DEFINE_integer("batch_size", 1, "batch size")
    flags.DEFINE_integer("latent_dim", 2, "latent dim")
    flags.DEFINE_list("image_size", [56, 72, 88, 1], "image size")
    FLAGS = flags.FLAGS

    # Latent statistics hard-coded for this particular model (one value per
    # latent dimension).  NOTE(review): presumably the mean / variance of the
    # training-set encodings -- confirm how they were obtained.
    latent_mean = np.asarray([0.37555057, 0.8882291])
    latent_std = np.sqrt(np.asarray([32.121346, 24.540127]))
    grid_steps = range(-2, 3)

    # check folder, including the sub-directory the labels are written to
    # (created here in case io.write_mhd_and_raw does not create it).
    out_2dim = os.path.join(FLAGS.outdir, '2_dim')
    os.makedirs(out_2dim, exist_ok=True)

    # initializer
    init_op = tf.group(tf.initializers.global_variables(),
                       tf.initializers.local_variables())

    with tf.Session(config=utils.config(index=FLAGS.gpu_index)) as sess:

        # set network
        kwargs = {
            'sess': sess,
            'outdir': FLAGS.outdir,
            'beta': FLAGS.beta,
            'latent_dim': FLAGS.latent_dim,
            'batch_size': FLAGS.batch_size,
            'image_size': FLAGS.image_size,
            'encoder': encoder_resblock_bn,
            'decoder': decoder_resblock_bn,
            'downsampling': down_sampling,
            'upsampling': up_sampling,
            'is_training': False,
            'is_down': False
        }
        VAE = Variational_Autoencoder(**kwargs)

        sess.run(init_op)

        # testing
        VAE.restore_model(FLAGS.model)

        # 2 dim vis
        for j in grid_steps:
            for i in grid_steps:
                sample_z = latent_mean + latent_std * np.asarray([[i, j]])
                generate_data = VAE.generate_sample(sample_z)[0, :, :, :, 0]
                generate_data = generate_data.astype(np.float32)

                # label
                label = np.where(generate_data > 0.5, 0, 1).astype(np.int16)
                label_image = sitk.GetImageFromArray(label)
                label_image.SetSpacing([1, 1, 1])
                label_image.SetOrigin([0, 0, 0])

                io.write_mhd_and_raw(
                    label_image, '{}.mhd'.format(
                        os.path.join(out_2dim, '{}_{}'.format(i, j))))
Example #5
0
def main():
    """Sample a VAE around stored latent statistics and report specificity.

    ``mean_<model_index>.txt`` and ``var_<model_index>.txt`` in ``indir``
    give the per-dimension mean and variance of the Gaussian from which
    ``num_of_generate`` latent codes are drawn (NumPy seed 1).  Each decoded
    volume is thresholded (``0`` where the output is > 0.5, else ``1``) and
    scored by its best ``utils.jaccard`` against the ground-truth labels.
    The per-sample scores and their mean go to
    ``specificity_<model_index>.csv``.  With ``--save_image`` the decoded
    volume and its label are also written to ``EUDT/`` and ``label/``.
    """
    # tf flag
    flags = tf.flags
    flags.DEFINE_string("ground_truth_txt", 'F:/data_info/VAE_liver/set_5/PCA/alpha_0.1/fold_1/test_label.txt', "ground truth txt")
    flags.DEFINE_string("indir", 'G:/experiment_result/liver/VAE/set_5/down/64/alpha_0.1/fold_1/VAE/axis_4/beta_6', "input dir")
    flags.DEFINE_string("outdir", 'G:/experiment_result/liver/VAE/set_5/down/64/alpha_0.1/fold_1/VAE/axis_4/beta_6/random', "outdir")
    flags.DEFINE_integer("model_index", 3450 ,"index of model")
    flags.DEFINE_string("gpu_index", "0", "GPU-index")
    flags.DEFINE_float("beta", 1.0, "hyperparameter beta")
    flags.DEFINE_integer("num_of_generate", 1000, "number of generate data")
    flags.DEFINE_integer("batch_size", 1, "batch size")
    flags.DEFINE_integer("latent_dim", 4, "latent dim")
    flags.DEFINE_list("image_size", [56, 72, 88, 1], "image size")
    flags.DEFINE_boolean("save_image", False,
                         "also write the generated EUDT and label images")
    FLAGS = flags.FLAGS

    np.random.seed(1)

    # check folder
    os.makedirs(FLAGS.outdir, exist_ok=True)
    eudt_dir = os.path.join(FLAGS.outdir, 'EUDT')
    label_dir = os.path.join(FLAGS.outdir, 'label')
    if FLAGS.save_image:
        for directory in (eudt_dir, label_dir):
            os.makedirs(directory, exist_ok=True)

    # load ground truth
    ground_truth = io.load_matrix_data(FLAGS.ground_truth_txt, 'int32')
    print(ground_truth.shape)

    # initializer
    init_op = tf.group(tf.initializers.global_variables(),
                       tf.initializers.local_variables())

    with tf.Session(config = utils.config(index=FLAGS.gpu_index)) as sess:

        # set network
        kwargs = {
            'sess': sess,
            'outdir': FLAGS.outdir,
            'beta': FLAGS.beta,
            'latent_dim': FLAGS.latent_dim,
            'batch_size': FLAGS.batch_size,
            'image_size': FLAGS.image_size,
            'encoder': encoder_resblock_bn,
            'decoder': decoder_resblock_bn,
            'downsampling': down_sampling,
            'upsampling': up_sampling,
            'is_training': False,
            'is_down': False
        }
        VAE = Variational_Autoencoder(**kwargs)

        sess.run(init_op)

        # testing
        VAE.restore_model(os.path.join(FLAGS.indir, 'model', 'model_{}'.format(FLAGS.model_index)))
        mean = np.asarray(np.loadtxt(os.path.join(FLAGS.indir, 'mean_{}.txt'.format(FLAGS.model_index))))
        # var_*.txt holds variances; the noise is scaled by their square
        # root.  Computed once instead of on every iteration.
        std = np.sqrt(np.asarray(np.loadtxt(os.path.join(FLAGS.indir, 'var_{}.txt'.format(FLAGS.model_index)))))
        specificity = []

        for i in tqdm(range(FLAGS.num_of_generate), ascii=True):
            sample_z = mean + std * np.random.normal(0, 1.0, (1, FLAGS.latent_dim))
            generate_data = VAE.generate_sample(sample_z)[0, :, :, :, 0]

            # label
            label = np.where(generate_data > 0.5, 0, 1).astype(np.int8)

            # calculate ji: best match over all ground-truth labels
            case_max_ji = 0.
            for gt in ground_truth:
                case_max_ji = max(case_max_ji, utils.jaccard(label, gt))
            specificity.append([case_max_ji])

            # output image (optional; building the images is skipped
            # entirely when they are not written)
            if FLAGS.save_image:
                eudt_image = sitk.GetImageFromArray(generate_data)
                eudt_image.SetSpacing([1, 1, 1])
                eudt_image.SetOrigin([0, 0, 0])
                label_image = sitk.GetImageFromArray(label)
                label_image.SetSpacing([1, 1, 1])
                label_image.SetOrigin([0, 0, 0])
                io.write_mhd_and_raw(eudt_image, '{}.mhd'.format(os.path.join(eudt_dir, str(i + 1))))
                io.write_mhd_and_raw(label_image, '{}.mhd'.format(os.path.join(label_dir, str(i + 1))))

    print('specificity = %f' % np.mean(specificity))

    # output csv file
    with open(os.path.join(FLAGS.outdir, 'specificity_{}.csv'.format(FLAGS.model_index)), 'w', newline='') as file:
        writer = csv.writer(file)
        writer.writerows(specificity)
        writer.writerow(['specificity:', np.mean(specificity)])