Example #1
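Reconstructs 2-D test images with a trained VAE: it reads a list of TFRecord files, runs each image through the model, thresholds the reconstruction and the original into binary labels, computes the Jaccard index between them, writes the EUDT and label images as MHD files, and stores the per-case values plus their mean ("generalization") in a CSV file.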
def main():
    parser = argparse.ArgumentParser(
        description='py, test_data_txt, model, outdir')

    parser.add_argument('--test_data_txt', '-i1', default='')

    parser.add_argument('--model', '-i2', default='./model_{}'.format(50000))

    parser.add_argument('--outdir', '-i3', default='')

    args = parser.parse_args()

    # check folder
    if not (os.path.exists(args.outdir)):
        os.makedirs(args.outdir)

    # tf flag
    flags = tf.flags
    flags.DEFINE_float("beta", 0.1, "hyperparameter beta")
    flags.DEFINE_integer("num_of_test", 100, "number of test data")
    flags.DEFINE_integer("batch_size", 1, "batch size")
    flags.DEFINE_integer("latent_dim", 2, "latent dim")
    flags.DEFINE_list("image_size", [512, 512, 1], "image size")
    FLAGS = flags.FLAGS

    # read list
    test_data_list = io.load_list(args.test_data_txt)

    # test step
    test_step = FLAGS.num_of_test // FLAGS.batch_size
    if FLAGS.num_of_test % FLAGS.batch_size != 0:
        test_step += 1

    # load test data
    test_set = tf.data.TFRecordDataset(test_data_list)
    test_set = test_set.map(
        lambda x: _parse_function(x, image_size=FLAGS.image_size),
        num_parallel_calls=os.cpu_count())
    test_set = test_set.batch(FLAGS.batch_size)
    test_iter = test_set.make_one_shot_iterator()
    test_data = test_iter.get_next()

    # initializer
    init_op = tf.group(tf.initializers.global_variables(),
                       tf.initializers.local_variables())

    with tf.Session(config=utils.config) as sess:

        # set network
        kwargs = {
            'sess': sess,
            'outdir': args.outdir,
            'beta': FLAGS.beta,
            'latent_dim': FLAGS.latent_dim,
            'batch_size': FLAGS.batch_size,
            'image_size': FLAGS.image_size,
            'encoder': cnn_encoder,
            'decoder': cnn_decoder
        }
        VAE = Variational_Autoencoder(**kwargs)

        sess.run(init_op)

        # testing
        VAE.restore_model(args.model)
        tbar = tqdm(range(test_step), ascii=True)
        preds = []
        ori = []
        for k in tbar:
            test_data_batch = sess.run(test_data)
            ori_single = test_data_batch
            preds_single = VAE.reconstruction_image(ori_single)
            preds_single = preds_single[0, :, :, 0]
            ori_single = ori_single[0, :, :, 0]

            preds.append(preds_single)
            ori.append(ori_single)

        # label
        ji = []
        for j in range(len(preds)):

            # EUDT
            eudt_image = sitk.GetImageFromArray(preds[j])
            eudt_image.SetSpacing([1, 1])
            eudt_image.SetOrigin([0, 0])

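            # binarize the reconstructed image: values > 0 become 0, all others become 1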
            label = np.where(preds[j] > 0, 0, 1)
            label_image = sitk.GetImageFromArray(label)
            label_image.SetSpacing([1, 1])
            label_image.SetOrigin([0, 0])

            ori_label = np.where(ori[j] > 0, 0, 1)
            ori_label_image = sitk.GetImageFromArray(ori_label)
            ori_label_image.SetSpacing([1, 1])
            ori_label_image.SetOrigin([0, 0])

            # calculate the Jaccard index (JI)
            ji.append([utils.jaccard(label, ori_label)])  # one-element row so writerows() emits one value per line

            # output image
            io.write_mhd_and_raw(
                eudt_image, '{}.mhd'.format(
                    os.path.join(args.outdir, 'EUDT', 'recon_{}'.format(j))))
            io.write_mhd_and_raw(
                label_image, '{}.mhd'.format(
                    os.path.join(args.outdir, 'label', 'recon_{}'.format(j))))

    generalization = np.mean(ji)
    print('generalization = %f' % generalization)

    # output csv file
    with open(os.path.join(args.outdir, 'generalization.csv'), 'w',
              newline='') as file:
        writer = csv.writer(file)
        writer.writerows(ji)
        writer.writerow(['generalization= ', generalization])
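
The utils.jaccard helper used above is defined elsewhere in the project and is not shown on this page. As a rough orientation only, a minimal NumPy sketch of a Jaccard index (intersection over union) for two binary masks of equal shape could look like the following; the function name and the empty-mask convention are assumptions, not the project's actual implementation:

import numpy as np

def jaccard(im1, im2):
    # Jaccard index (IoU) of two binary masks of equal shape
    im1 = np.asarray(im1, dtype=bool)
    im2 = np.asarray(im2, dtype=bool)
    union = np.logical_or(im1, im2).sum()
    if union == 0:
        return 1.0  # convention: two empty masks count as identical
    return np.logical_and(im1, im2).sum() / union
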
Example #2
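Trains a two-stage VAE on flattened 9×9×9 (729-element) vectors: a first VAE, restored from the checkpoint given by the model flag, reconstructs each training batch; a second residual VAE is then updated on the batch together with that reconstruction; losses, checkpoints, and validation summaries are written to TensorBoard at the configured intervals. When is_n1_opt is set, the second VAE is also restored from model2 and the first VAE continues to be updated during training.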
def main():

    # tf flag
    flags = tf.flags
    flags.DEFINE_string("train_data_txt", "./train.txt", "train data txt")
    flags.DEFINE_string("val_data_txt", "./val.txt", "validation data txt")
    flags.DEFINE_string("outdir", "./output/", "outdir")
    flags.DEFINE_float("beta", 1, "hyperparameter beta")
    flags.DEFINE_integer("num_of_val", 600, "number of validation data")
    flags.DEFINE_integer("batch_size", 30, "batch size")
    flags.DEFINE_integer("num_iteration", 500001, "number of iteration")
    flags.DEFINE_integer("save_loss_step", 200, "step of save loss")
    flags.DEFINE_integer("save_model_step", 500,
                         "step of save model and validation")
    flags.DEFINE_integer("shuffle_buffer_size", 1000, "buffer size of shuffle")
    flags.DEFINE_integer("latent_dim", 6, "latent dim")
    flags.DEFINE_list("image_size", [9 * 9 * 9], "image size")
    flags.DEFINE_string("model", './model/model_{}', "pre training model1")
    flags.DEFINE_string("model2", './model/model_{}', "pre training model2")
    flags.DEFINE_boolean("is_n1_opt", True, "n1_opt")
    FLAGS = flags.FLAGS

    # check folder
    if not (os.path.exists(os.path.join(FLAGS.outdir, 'tensorboard',
                                        'train'))):
        os.makedirs(os.path.join(FLAGS.outdir, 'tensorboard', 'train'))
    if not (os.path.exists(os.path.join(FLAGS.outdir, 'tensorboard', 'val'))):
        os.makedirs(os.path.join(FLAGS.outdir, 'tensorboard', 'val'))
    if not (os.path.exists(os.path.join(FLAGS.outdir, 'tensorboard', 'rec'))):
        os.makedirs(os.path.join(FLAGS.outdir, 'tensorboard', 'rec'))
    if not (os.path.exists(os.path.join(FLAGS.outdir, 'tensorboard', 'kl'))):
        os.makedirs(os.path.join(FLAGS.outdir, 'tensorboard', 'kl'))
    if not (os.path.exists(os.path.join(FLAGS.outdir, 'model'))):
        os.makedirs(os.path.join(FLAGS.outdir, 'model'))

    # read list
    train_data_list = io.load_list(FLAGS.train_data_txt)
    val_data_list = io.load_list(FLAGS.val_data_txt)

    # shuffle list
    random.shuffle(train_data_list)
    # val step
    val_step = FLAGS.num_of_val // FLAGS.batch_size
    if FLAGS.num_of_val % FLAGS.batch_size != 0:
        val_step += 1

    # load train data and validation data
    train_set = tf.data.Dataset.list_files(train_data_list)
    train_set = train_set.apply(
        tf.contrib.data.parallel_interleave(tf.data.TFRecordDataset,
                                            cycle_length=6))
    # train_set = tf.data.TFRecordDataset(train_data_list)
    train_set = train_set.map(
        lambda x: _parse_function(x, image_size=FLAGS.image_size),
        num_parallel_calls=os.cpu_count())
    train_set = train_set.shuffle(buffer_size=FLAGS.shuffle_buffer_size)
    train_set = train_set.repeat()
    train_set = train_set.batch(FLAGS.batch_size)
    train_iter = train_set.make_one_shot_iterator()
    train_data = train_iter.get_next()

    val_set = tf.data.Dataset.list_files(val_data_list)
    val_set = val_set.apply(
        tf.contrib.data.parallel_interleave(tf.data.TFRecordDataset,
                                            cycle_length=os.cpu_count()))
    # val_set = tf.data.TFRecordDataset(val_data_list)
    val_set = val_set.map(
        lambda x: _parse_function(x, image_size=FLAGS.image_size),
        num_parallel_calls=os.cpu_count())
    val_set = val_set.repeat()
    val_set = val_set.batch(FLAGS.batch_size)
    val_iter = val_set.make_one_shot_iterator()
    val_data = val_iter.get_next()

    # initializer
    init_op = tf.group(tf.initializers.global_variables(),
                       tf.initializers.local_variables())

    with tf.Session(config=utils.config) as sess:
        # with tf.Session() as sess:
        # set network
        kwargs = {
            'sess': sess,
            'outdir': FLAGS.outdir,
            'beta': FLAGS.beta,
            'latent_dim': FLAGS.latent_dim,
            'batch_size': FLAGS.batch_size,
            'image_size': FLAGS.image_size,
            'encoder': encoder_mlp,
            'decoder': decoder_mlp,
            'is_res': False
        }
        VAE = Variational_Autoencoder(**kwargs)

        kwargs_2 = {
            'sess': sess,
            'outdir': FLAGS.outdir,
            'beta': FLAGS.beta,
            'latent_dim': 8,
            'batch_size': FLAGS.batch_size,
            'image_size': FLAGS.image_size,
            'encoder': encoder_mlp2,
            'decoder': decoder_mlp_tanh,
            'is_res': True,
            'is_constraints': False,
            # 'keep_prob': 0.5
        }

        VAE_2 = Variational_Autoencoder(**kwargs_2)
        # print parameters
        utils.cal_parameter()

        # prepare tensorboard
        writer_train = tf.summary.FileWriter(
            os.path.join(FLAGS.outdir, 'tensorboard', 'train'), sess.graph)
        writer_val = tf.summary.FileWriter(
            os.path.join(FLAGS.outdir, 'tensorboard', 'val'))
        writer_rec = tf.summary.FileWriter(
            os.path.join(FLAGS.outdir, 'tensorboard', 'rec'))
        writer_kl = tf.summary.FileWriter(
            os.path.join(FLAGS.outdir, 'tensorboard', 'kl'))

        value_loss = tf.Variable(0.0)
        tf.summary.scalar("loss", value_loss)
        merge_op = tf.summary.merge_all()

        # initialize
        sess.run(init_op)

        # use pre-trained model
        # ckpt_state = tf.train.get_checkpoint_state(FLAGS.model)
        #
        # if ckpt_state:
        #     restore_model = ckpt_state.model_checkpoint_path
        #     # VAE.restore_model(FLAGS.model+'model_{}'.format(FLAGS.itr))
        VAE.restore_model(FLAGS.model)
        if FLAGS.is_n1_opt:
            VAE_2.restore_model(FLAGS.model2)

        # training
        tbar = tqdm(range(FLAGS.num_iteration), ascii=True)
        for i in tbar:
            train_data_batch = sess.run(train_data)
            if FLAGS.is_n1_opt:
                VAE.update(train_data_batch)

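            # reconstruct the batch with the first VAE and feed both the batch and its reconstruction to the second VAE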
            output1 = VAE.reconstruction_image(train_data_batch)

            train_loss, rec_loss, kl_loss = VAE_2.update2(
                train_data_batch, output1)

            if i % FLAGS.save_loss_step == 0:
                s = "Loss: {:.4f}, rec_loss: {:.4f}, kl_loss: {:.4f}".format(
                    train_loss, rec_loss, kl_loss)
                tbar.set_description(s)
                summary_train_loss = sess.run(merge_op,
                                              {value_loss: train_loss})
                writer_train.add_summary(summary_train_loss, i)

                summary_rec_loss = sess.run(merge_op, {value_loss: rec_loss})
                summary_kl_loss = sess.run(merge_op, {value_loss: kl_loss})
                writer_rec.add_summary(summary_rec_loss, i)
                writer_kl.add_summary(summary_kl_loss, i)

            if i % FLAGS.save_model_step == 0:
                # save model
                VAE.save_model(i)
                VAE_2.save_model2(i)

                # validation
                val_loss = 0.
                for j in range(val_step):
                    val_data_batch = sess.run(val_data)
                    val_data_batch_output1 = VAE.reconstruction_image(
                        val_data_batch)

                    val_loss += VAE_2.validation2(val_data_batch,
                                                  val_data_batch_output1)
                val_loss /= val_step

                summary_val = sess.run(merge_op, {value_loss: val_loss})
                writer_val.add_summary(summary_val, i)
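
The _parse_function passed to map() in these examples is also project code that does not appear on this page. A rough TensorFlow 1.x sketch of such a TFRecord parser, assuming each serialized example stores the image as a raw float32 byte string under a hypothetical 'img_raw' key, might look like this:

import tensorflow as tf  # TensorFlow 1.x API, matching the examples above

def _parse_function(example_proto, image_size):
    # decode one serialized tf.train.Example into a float image tensor
    features = {'img_raw': tf.FixedLenFeature([], tf.string)}
    parsed = tf.parse_single_example(example_proto, features)
    image = tf.decode_raw(parsed['img_raw'], tf.float32)
    image = tf.reshape(image, image_size)
    return image
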
Example #3
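Reconstructs 3-D liver volumes (56×72×88) with a trained residual-block VAE, binarizes the reconstructions and originals at 0.5, computes the per-case Jaccard index, writes the label volumes as MHD files, and saves the mean Jaccard index ("generalization") to a CSV named after the restored model index.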
def main():

    # tf flag
    flags = tf.flags
    flags.DEFINE_string(
        "test_data_txt",
        'F:/data_info/VAE_liver/set_5/TFrecord/fold_1/test.txt',
        "test data txt")
    flags.DEFINE_string(
        "indir",
        'G:/experiment_result/liver/VAE/set_5/down/64/alpha_0.1/fold_1/VAE/axis_5/beta_7',
        "input dir")
    flags.DEFINE_string(
        "outdir",
        'G:/experiment_result/liver/VAE/set_5/down/64/alpha_0.1/fold_1/VAE/axis_5/beta_7/rec',
        "outdir")
    flags.DEFINE_integer("model_index", 3300, "index of model")
    flags.DEFINE_string("gpu_index", "0", "GPU-index")
    flags.DEFINE_float("beta", 1.0, "hyperparameter beta")
    flags.DEFINE_integer("num_of_test", 75, "number of test data")
    flags.DEFINE_integer("batch_size", 1, "batch size")
    flags.DEFINE_integer("latent_dim", 5, "latent dim")
    flags.DEFINE_list("image_size", [56, 72, 88, 1], "image size")
    FLAGS = flags.FLAGS

    # check folder
    if not (os.path.exists(FLAGS.outdir)):
        os.makedirs(FLAGS.outdir)

    # read list
    test_data_list = io.load_list(FLAGS.test_data_txt)

    # test step
    test_step = FLAGS.num_of_test // FLAGS.batch_size
    if FLAGS.num_of_test % FLAGS.batch_size != 0:
        test_step += 1

    # load test data
    test_set = tf.data.TFRecordDataset(test_data_list, compression_type='GZIP')
    test_set = test_set.map(
        lambda x: utils._parse_function(x, image_size=FLAGS.image_size),
        num_parallel_calls=os.cpu_count())
    test_set = test_set.batch(FLAGS.batch_size)
    test_iter = test_set.make_one_shot_iterator()
    test_data = test_iter.get_next()

    # initializer
    init_op = tf.group(tf.initializers.global_variables(),
                       tf.initializers.local_variables())

    with tf.Session(config=utils.config(index=FLAGS.gpu_index)) as sess:
        # set network
        kwargs = {
            'sess': sess,
            'outdir': FLAGS.outdir,
            'beta': FLAGS.beta,
            'latent_dim': FLAGS.latent_dim,
            'batch_size': FLAGS.batch_size,
            'image_size': FLAGS.image_size,
            'encoder': encoder_resblock_bn,
            'decoder': decoder_resblock_bn,
            'downsampling': down_sampling,
            'upsampling': up_sampling,
            'is_training': False,
            'is_down': False
        }
        VAE = Variational_Autoencoder(**kwargs)

        sess.run(init_op)

        # testing
        VAE.restore_model(
            os.path.join(FLAGS.indir, 'model',
                         'model_{}'.format(FLAGS.model_index)))
        tbar = tqdm(range(test_step), ascii=True)
        preds = []
        ori = []
        ji = []
        for k in tbar:
            test_data_batch = sess.run(test_data)
            ori_single = test_data_batch
            preds_single = VAE.reconstruction_image(ori_single)
            preds_single = preds_single[0, :, :, :, 0]
            ori_single = ori_single[0, :, :, :, 0]

            preds.append(preds_single)
            ori.append(ori_single)

        # label: binarize the reconstructions and the originals at 0.5
        for j in range(len(preds)):

            # EUDT
            eudt_image = sitk.GetImageFromArray(preds[j])
            eudt_image.SetSpacing([1, 1, 1])
            eudt_image.SetOrigin([0, 0, 0])

            label = np.where(preds[j] > 0.5, 0, 1)
            # label = np.where(preds[j] > 0.5, 1, 0.5)
            label = label.astype(np.int16)
            label_image = sitk.GetImageFromArray(label)
            label_image.SetSpacing([1, 1, 1])
            label_image.SetOrigin([0, 0, 0])

            ori_label = np.where(ori[j] > 0.5, 0, 1)
            ori_label_image = sitk.GetImageFromArray(ori_label)
            ori_label_image.SetSpacing([1, 1, 1])
            ori_label_image.SetOrigin([0, 0, 0])

            # calculate ji
            ji.append([utils.jaccard(label, ori_label)])

            # output image
            io.write_mhd_and_raw(
                label_image, '{}.mhd'.format(
                    os.path.join(FLAGS.outdir, 'label',
                                 'recon_{}'.format(j))))

        generalization = np.mean(ji)
        print('generalization = %f' % generalization)

        # output csv file
        with open(os.path.join(
                FLAGS.outdir,
                'generalization_{}.csv'.format(FLAGS.model_index)),
                  'w',
                  newline='') as file:
            writer = csv.writer(file)
            writer.writerows(ji)
            writer.writerow(['generalization= ', generalization])
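
Finally, io.write_mhd_and_raw is another project helper not shown here. Using plain SimpleITK, an equivalent MetaImage writer (.mhd header plus .raw data) could be sketched as follows; the directory-creation step is an assumption based on how the examples build nested output paths, not a confirmed detail of the project's helper:

import os
import SimpleITK as sitk

def write_mhd_and_raw(image, path):
    # create the output folder if needed, then let SimpleITK's MetaImage
    # writer emit the .mhd header and its companion .raw data file
    out_dir = os.path.dirname(path)
    if out_dir and not os.path.exists(out_dir):
        os.makedirs(out_dir)
    sitk.WriteImage(image, path)
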