Example #1
0
def main():
    """Train a variational autoencoder and log its losses to TensorBoard.

    Declares the run configuration as ``tf.flags``, builds a shuffled,
    repeating, batched input pipeline from the GZIP-compressed TFRecord files
    listed in ``FLAGS.train_data_txt`` and runs ``FLAGS.num_iteration`` update
    steps.  Every ``FLAGS.save_loss_step`` iterations the losses are written
    as summaries under ``<outdir>/tensorboard``; every
    ``FLAGS.save_model_step`` iterations ``VAE.save_model`` is called.
    """
    # tf flag
    flags = tf.flags
    flags.DEFINE_string(
        "train_data_txt",
        'F:/data_info/VAE_liver/set_5/TFrecord/fold_1/train.txt',
        "train data txt")
    flags.DEFINE_string(
        "outdir",
        'G:/experiment_result/liver/VAE/set_5/down/64/alpha_0.1/fold_1/beta_1',
        "outdir")
    flags.DEFINE_string("gpu_index", "0", "GPU-index")
    flags.DEFINE_float("beta", 1, "hyperparameter beta")
    flags.DEFINE_integer("batch_size", 12, "batch size")
    flags.DEFINE_integer("num_iteration", 12001, "number of iteration")
    flags.DEFINE_integer("save_loss_step", 150, "step of save loss")
    flags.DEFINE_integer("save_model_step", 150,
                         "step of save model and validation")
    flags.DEFINE_integer("shuffle_buffer_size", 200, "buffer size of shuffle")
    flags.DEFINE_integer("latent_dim", 4, "latent dim")
    flags.DEFINE_list("image_size", [56, 72, 88, 1], "image size")
    FLAGS = flags.FLAGS

    # check folder; exist_ok avoids the race between exists() and makedirs()
    os.makedirs(os.path.join(FLAGS.outdir, 'tensorboard'), exist_ok=True)
    os.makedirs(os.path.join(FLAGS.outdir, 'model'), exist_ok=True)

    # read list
    train_data_list = io.load_list(FLAGS.train_data_txt)
    # shuffle list
    random.shuffle(train_data_list)

    # load train data
    train_set = tf.data.Dataset.list_files(train_data_list)
    train_set = train_set.apply(
        tf.contrib.data.parallel_interleave(
            lambda x: tf.data.TFRecordDataset(x, compression_type='GZIP'),
            cycle_length=os.cpu_count()))
    train_set = train_set.map(
        lambda x: utils._parse_function(x, image_size=FLAGS.image_size),
        num_parallel_calls=os.cpu_count())
    train_set = train_set.shuffle(buffer_size=FLAGS.shuffle_buffer_size)
    train_set = train_set.repeat()
    train_set = train_set.batch(FLAGS.batch_size)
    train_iter = train_set.make_one_shot_iterator()
    train_data = train_iter.get_next()

    # initializer
    init_op = tf.group(tf.initializers.global_variables(),
                       tf.initializers.local_variables())

    with tf.Session(config=utils.config(index=FLAGS.gpu_index)) as sess:
        # set network
        kwargs = {
            'sess': sess,
            'outdir': FLAGS.outdir,
            'beta': FLAGS.beta,
            'latent_dim': FLAGS.latent_dim,
            'batch_size': FLAGS.batch_size,
            'image_size': FLAGS.image_size,
            'encoder': encoder_resblock_bn,
            'decoder': decoder_resblock_bn,
            'downsampling': down_sampling,
            'upsampling': up_sampling,
            'learning_rate': 1e-4,
            'is_training': True,
            'is_down': False
        }
        VAE = Variational_Autoencoder(**kwargs)

        # print parameters
        utils.cal_parameter()

        # prepare tensorboard
        writer_train = tf.summary.FileWriter(
            os.path.join(FLAGS.outdir, 'tensorboard', 'train'), sess.graph)
        writer_rec = tf.summary.FileWriter(
            os.path.join(FLAGS.outdir, 'tensorboard', 'train_rec'))
        writer_kl = tf.summary.FileWriter(
            os.path.join(FLAGS.outdir, 'tensorboard', 'train_kl'))

        try:
            # One scalar summary op is reused for every loss: the variable's
            # value is overridden through feed_dict when merge_op is run.
            value_loss = tf.Variable(0.0)
            tf.summary.scalar("loss", value_loss)
            merge_op = tf.summary.merge_all()

            # initialize
            sess.run(init_op)

            # training
            tbar = tqdm(range(FLAGS.num_iteration), ascii=True)
            epoch_train_loss = []
            epoch_kl_loss = []
            epoch_rec_loss = []
            for i in tbar:
                train_data_batch = sess.run(train_data)

                train_loss, rec_loss, kl_loss = VAE.update(train_data_batch)
                epoch_train_loss.append(train_loss)
                epoch_kl_loss.append(kl_loss)
                epoch_rec_loss.append(rec_loss)

                # `==`, not `is`: identity on an int literal only works by
                # accident of CPython's small-integer cache.
                if i % FLAGS.save_loss_step == 0:
                    # The progress bar shows the mean since the last report;
                    # the summaries record the most recent batch's losses.
                    s = "Loss: {:.4f}, kl_loss: {:.4f}, rec_loss: {:.4f}"\
                        .format(np.mean(epoch_train_loss),
                                np.mean(epoch_kl_loss),
                                np.mean(epoch_rec_loss))
                    tbar.set_description(s)

                    summary_train_loss = sess.run(merge_op,
                                                  {value_loss: train_loss})
                    summary_rec_loss = sess.run(merge_op,
                                                {value_loss: rec_loss})
                    summary_kl_loss = sess.run(merge_op,
                                               {value_loss: kl_loss})
                    writer_train.add_summary(summary_train_loss, i)
                    writer_rec.add_summary(summary_rec_loss, i)
                    writer_kl.add_summary(summary_kl_loss, i)

                    epoch_train_loss.clear()
                    epoch_kl_loss.clear()
                    epoch_rec_loss.clear()

                if i % FLAGS.save_model_step == 0:
                    # save model
                    VAE.save_model(i)
        finally:
            # Flush pending events and release the event files even when
            # training is interrupted.
            for writer in (writer_train, writer_rec, writer_kl):
                writer.close()
Example #2
0
def main():
    """Train a second VAE on top of the reconstructions of a first VAE.

    Declares the run configuration as ``tf.flags``, builds training and
    validation input pipelines from the TFRecord files listed in
    ``FLAGS.train_data_txt`` / ``FLAGS.val_data_txt``, restores the
    pre-trained model(s) and runs ``FLAGS.num_iteration`` steps.  On each
    step the first network reconstructs the batch and the second network is
    updated with both the batch and that reconstruction.  Losses are written
    every ``FLAGS.save_loss_step`` iterations; both models are saved and a
    validation loss is recorded every ``FLAGS.save_model_step`` iterations.
    """
    # tf flag
    flags = tf.flags
    flags.DEFINE_string("train_data_txt", "./train.txt", "train data txt")
    flags.DEFINE_string("val_data_txt", "./val.txt", "validation data txt")
    flags.DEFINE_string("outdir", "./output/", "outdir")
    flags.DEFINE_float("beta", 1, "hyperparameter beta")
    flags.DEFINE_integer("num_of_val", 600, "number of validation data")
    flags.DEFINE_integer("batch_size", 30, "batch size")
    flags.DEFINE_integer("num_iteration", 500001, "number of iteration")
    flags.DEFINE_integer("save_loss_step", 200, "step of save loss")
    flags.DEFINE_integer("save_model_step", 500,
                         "step of save model and validation")
    flags.DEFINE_integer("shuffle_buffer_size", 1000, "buffer size of shuffle")
    flags.DEFINE_integer("latent_dim", 6, "latent dim")
    flags.DEFINE_list("image_size", [9 * 9 * 9], "image size")
    # NOTE(review): both defaults contain an unformatted '{}' placeholder;
    # confirm that callers pass a concrete checkpoint path.
    flags.DEFINE_string("model", './model/model_{}', "pre training model1")
    flags.DEFINE_string("model2", './model/model_{}', "pre training model2")
    flags.DEFINE_boolean("is_n1_opt", True, "n1_opt")
    FLAGS = flags.FLAGS

    # check folder; exist_ok avoids the race between exists() and makedirs()
    for sub_dir in ('train', 'val', 'rec', 'kl'):
        os.makedirs(os.path.join(FLAGS.outdir, 'tensorboard', sub_dir),
                    exist_ok=True)
    os.makedirs(os.path.join(FLAGS.outdir, 'model'), exist_ok=True)

    # read list
    train_data_list = io.load_list(FLAGS.train_data_txt)
    val_data_list = io.load_list(FLAGS.val_data_txt)

    # shuffle list
    random.shuffle(train_data_list)
    # val step: number of batches needed to cover num_of_val (rounded up)
    val_step = FLAGS.num_of_val // FLAGS.batch_size
    if FLAGS.num_of_val % FLAGS.batch_size != 0:
        val_step += 1

    # load train data and validation data
    train_set = tf.data.Dataset.list_files(train_data_list)
    train_set = train_set.apply(
        tf.contrib.data.parallel_interleave(tf.data.TFRecordDataset,
                                            cycle_length=6))
    train_set = train_set.map(
        lambda x: _parse_function(x, image_size=FLAGS.image_size),
        num_parallel_calls=os.cpu_count())
    train_set = train_set.shuffle(buffer_size=FLAGS.shuffle_buffer_size)
    train_set = train_set.repeat()
    train_set = train_set.batch(FLAGS.batch_size)
    train_iter = train_set.make_one_shot_iterator()
    train_data = train_iter.get_next()

    val_set = tf.data.Dataset.list_files(val_data_list)
    val_set = val_set.apply(
        tf.contrib.data.parallel_interleave(tf.data.TFRecordDataset,
                                            cycle_length=os.cpu_count()))
    val_set = val_set.map(
        lambda x: _parse_function(x, image_size=FLAGS.image_size),
        num_parallel_calls=os.cpu_count())
    val_set = val_set.repeat()
    val_set = val_set.batch(FLAGS.batch_size)
    val_iter = val_set.make_one_shot_iterator()
    val_data = val_iter.get_next()

    # initializer
    init_op = tf.group(tf.initializers.global_variables(),
                       tf.initializers.local_variables())

    with tf.Session(config=utils.config) as sess:
        # set network
        kwargs = {
            'sess': sess,
            'outdir': FLAGS.outdir,
            'beta': FLAGS.beta,
            'latent_dim': FLAGS.latent_dim,
            'batch_size': FLAGS.batch_size,
            'image_size': FLAGS.image_size,
            'encoder': encoder_mlp,
            'decoder': decoder_mlp,
            'is_res': False
        }
        VAE = Variational_Autoencoder(**kwargs)

        # The second network uses a fixed latent size of 8 rather than
        # FLAGS.latent_dim.
        kwargs_2 = {
            'sess': sess,
            'outdir': FLAGS.outdir,
            'beta': FLAGS.beta,
            'latent_dim': 8,
            'batch_size': FLAGS.batch_size,
            'image_size': FLAGS.image_size,
            'encoder': encoder_mlp2,
            'decoder': decoder_mlp_tanh,
            'is_res': True,
            'is_constraints': False,
        }
        VAE_2 = Variational_Autoencoder(**kwargs_2)

        # print parameters
        utils.cal_parameter()

        # prepare tensorboard
        writer_train = tf.summary.FileWriter(
            os.path.join(FLAGS.outdir, 'tensorboard', 'train'), sess.graph)
        writer_val = tf.summary.FileWriter(
            os.path.join(FLAGS.outdir, 'tensorboard', 'val'))
        writer_rec = tf.summary.FileWriter(
            os.path.join(FLAGS.outdir, 'tensorboard', 'rec'))
        writer_kl = tf.summary.FileWriter(
            os.path.join(FLAGS.outdir, 'tensorboard', 'kl'))

        try:
            # One scalar summary op is reused for every loss: the variable's
            # value is overridden through feed_dict when merge_op is run.
            value_loss = tf.Variable(0.0)
            tf.summary.scalar("loss", value_loss)
            merge_op = tf.summary.merge_all()

            # initialize
            sess.run(init_op)

            # use pre trained model
            VAE.restore_model(FLAGS.model)
            if FLAGS.is_n1_opt:
                VAE_2.restore_model(FLAGS.model2)

            # training
            tbar = tqdm(range(FLAGS.num_iteration), ascii=True)
            for i in tbar:
                train_data_batch = sess.run(train_data)
                # The first network is only optimised when is_n1_opt is set.
                if FLAGS.is_n1_opt:
                    VAE.update(train_data_batch)

                output1 = VAE.reconstruction_image(train_data_batch)

                train_loss, rec_loss, kl_loss = VAE_2.update2(
                    train_data_batch, output1)

                # `==`, not `is`: identity on an int literal only works by
                # accident of CPython's small-integer cache.
                if i % FLAGS.save_loss_step == 0:
                    s = "Loss: {:.4f}, rec_loss: {:.4f}, kl_loss: {:.4f}".format(
                        train_loss, rec_loss, kl_loss)
                    tbar.set_description(s)
                    summary_train_loss = sess.run(merge_op,
                                                  {value_loss: train_loss})
                    writer_train.add_summary(summary_train_loss, i)

                    summary_rec_loss = sess.run(merge_op,
                                                {value_loss: rec_loss})
                    summary_kl_loss = sess.run(merge_op,
                                               {value_loss: kl_loss})
                    writer_rec.add_summary(summary_rec_loss, i)
                    writer_kl.add_summary(summary_kl_loss, i)

                if i % FLAGS.save_model_step == 0:
                    # save model
                    VAE.save_model(i)
                    VAE_2.save_model2(i)

                    # validation: average the loss over val_step batches
                    val_loss = 0.
                    for _ in range(val_step):
                        val_data_batch = sess.run(val_data)
                        val_data_batch_output1 = VAE.reconstruction_image(
                            val_data_batch)

                        val_loss += VAE_2.validation2(val_data_batch,
                                                      val_data_batch_output1)
                    val_loss /= val_step

                    summary_val = sess.run(merge_op, {value_loss: val_loss})
                    writer_val.add_summary(summary_val, i)
        finally:
            # Flush pending events and release the event files even when
            # training is interrupted.
            for writer in (writer_train, writer_val, writer_rec, writer_kl):
                writer.close()
Example #3
0
def main():
    """Train a convolutional VAE with periodic validation.

    Paths come from ``argparse`` (``--train_data_txt``, ``--val_data_txt``,
    ``--outdir``); hyperparameters are declared as ``tf.flags``.  Builds
    training and validation pipelines from the listed TFRecord files and runs
    ``FLAGS.num_iteration`` update steps.  Losses are written as summaries
    every ``FLAGS.save_loss_step`` iterations; the model is saved and a
    validation loss is recorded every ``FLAGS.save_model_step`` iterations.
    """
    parser = argparse.ArgumentParser(description='py, train_data_txt, val_data_txt, outdir')

    parser.add_argument('--train_data_txt', '-i1', default='', help='train data txt')

    parser.add_argument('--val_data_txt', '-i2', default='', help='validation data txt')

    parser.add_argument('--outdir', '-i3', default='./beta_0.1', help='outdir')

    args = parser.parse_args()

    # check folder; exist_ok avoids the race between exists() and makedirs()
    for sub_dir in ('train', 'val', 'rec', 'kl'):
        os.makedirs(os.path.join(args.outdir, 'tensorboard', sub_dir),
                    exist_ok=True)
    os.makedirs(os.path.join(args.outdir, 'model'), exist_ok=True)

    # tf flag
    flags = tf.flags
    flags.DEFINE_float("beta", 0.1, "hyperparameter beta")
    flags.DEFINE_integer("num_of_val", 1000, "number of validation data")
    flags.DEFINE_integer("batch_size", 30, "batch size")
    flags.DEFINE_integer("num_iteration", 50001, "number of iteration")
    flags.DEFINE_integer("save_loss_step", 50, "step of save loss")
    flags.DEFINE_integer("save_model_step", 500, "step of save model and validation")
    flags.DEFINE_integer("shuffle_buffer_size", 10000, "buffer size of shuffle")
    flags.DEFINE_integer("latent_dim", 2, "latent dim")
    flags.DEFINE_list("image_size", [512, 512, 1], "image size")
    FLAGS = flags.FLAGS

    # read list
    train_data_list = io.load_list(args.train_data_txt)
    val_data_list = io.load_list(args.val_data_txt)

    # shuffle list
    random.shuffle(train_data_list)
    # val step: number of batches needed to cover num_of_val (rounded up)
    val_step = FLAGS.num_of_val // FLAGS.batch_size
    if FLAGS.num_of_val % FLAGS.batch_size != 0:
        val_step += 1

    # load train data and validation data
    train_set = tf.data.TFRecordDataset(train_data_list)
    train_set = train_set.map(lambda x: _parse_function(x, image_size=FLAGS.image_size),
                              num_parallel_calls=os.cpu_count())
    train_set = train_set.shuffle(buffer_size=FLAGS.shuffle_buffer_size)
    train_set = train_set.repeat()
    train_set = train_set.batch(FLAGS.batch_size)
    train_iter = train_set.make_one_shot_iterator()
    train_data = train_iter.get_next()

    val_set = tf.data.TFRecordDataset(val_data_list)
    val_set = val_set.map(lambda x: _parse_function(x, image_size=FLAGS.image_size),
                          num_parallel_calls=os.cpu_count())
    val_set = val_set.repeat()
    val_set = val_set.batch(FLAGS.batch_size)
    val_iter = val_set.make_one_shot_iterator()
    val_data = val_iter.get_next()

    # initializer
    init_op = tf.group(tf.initializers.global_variables(),
                       tf.initializers.local_variables())

    with tf.Session(config=utils.config) as sess:
        # set network
        kwargs = {
            'sess': sess,
            'outdir': args.outdir,
            'beta': FLAGS.beta,
            'latent_dim': FLAGS.latent_dim,
            'batch_size': FLAGS.batch_size,
            'image_size': FLAGS.image_size,
            'encoder': cnn_encoder,
            'decoder': cnn_decoder
        }
        VAE = Variational_Autoencoder(**kwargs)

        # print parameters
        utils.cal_parameter()

        # prepare tensorboard
        writer_train = tf.summary.FileWriter(os.path.join(args.outdir, 'tensorboard', 'train'), sess.graph)
        writer_val = tf.summary.FileWriter(os.path.join(args.outdir, 'tensorboard', 'val'))
        writer_rec = tf.summary.FileWriter(os.path.join(args.outdir, 'tensorboard', 'rec'))
        writer_kl = tf.summary.FileWriter(os.path.join(args.outdir, 'tensorboard', 'kl'))

        try:
            # One scalar summary op is reused for every loss: the variable's
            # value is overridden through feed_dict when merge_op is run.
            value_loss = tf.Variable(0.0)
            tf.summary.scalar("loss", value_loss)
            merge_op = tf.summary.merge_all()

            # initialize
            sess.run(init_op)

            # training
            tbar = tqdm(range(FLAGS.num_iteration), ascii=True)
            for i in tbar:
                train_data_batch = sess.run(train_data)
                train_loss, rec_loss, kl_loss = VAE.update(train_data_batch)

                # `==`, not `is`: identity on an int literal only works by
                # accident of CPython's small-integer cache.
                if i % FLAGS.save_loss_step == 0:
                    s = "Loss: {:.4f}, rec_loss: {:.4f}, kl_loss: {:.4f}".format(train_loss, rec_loss, kl_loss)
                    tbar.set_description(s)
                    summary_train_loss = sess.run(merge_op, {value_loss: train_loss})
                    writer_train.add_summary(summary_train_loss, i)

                    summary_rec_loss = sess.run(merge_op, {value_loss: rec_loss})
                    summary_kl_loss = sess.run(merge_op, {value_loss: kl_loss})
                    writer_rec.add_summary(summary_rec_loss, i)
                    writer_kl.add_summary(summary_kl_loss, i)

                if i % FLAGS.save_model_step == 0:
                    # save model
                    VAE.save_model(i)

                    # validation: average the loss over val_step batches
                    val_loss = 0.
                    for _ in range(val_step):
                        val_data_batch = sess.run(val_data)
                        val_loss += VAE.validation(val_data_batch)
                    val_loss /= val_step

                    summary_val = sess.run(merge_op, {value_loss: val_loss})
                    writer_val.add_summary(summary_val, i)
        finally:
            # Flush pending events and release the event files even when
            # training is interrupted.
            for writer in (writer_train, writer_val, writer_rec, writer_kl):
                writer.close()