def visualize():
    """Visualize a trained VAE whose latent code is 2-D.

    Restores checkpoint ``'{SAVE_PATH}vae-epoch-{FLAGS.load}'``, plots the
    latent codes of the shuffled MNIST test split, and generates an image grid
    from a ``plot_size x plot_size`` interpolated grid of latent points.

    Raises:
        ValueError: if ``FLAGS.ncode`` is not 2. The latent grid built here
            has width 2, so other code sizes cannot be visualized this way.
    """
    FLAGS = get_args()
    # Fail fast with a clear message instead of a shape error deep inside
    # the session (the grid below is reshaped to width 2).
    if FLAGS.ncode != 2:
        raise ValueError('Visualization only for ncode = 2!')

    # Side length of the generated grid; the generate model's batch size is
    # derived from it so the two cannot drift apart.
    plot_size = 20
    n_grid = plot_size * plot_size

    valid_data = MNISTData('test',
                           data_dir=DATA_PATH,
                           shuffle=True,
                           pf=preprocess_im,
                           batch_dict_name=['im', 'label'])
    valid_data.setup(epoch_val=0, batch_size=FLAGS.bsize)

    with tf.variable_scope('VAE'):
        model = VAE(n_code=FLAGS.ncode, wd=0)
        model.create_train_model()

    # Re-enter the same scope with reuse so the generate graph uses the same
    # variables as the train graph (and hence the restored weights).
    with tf.variable_scope('VAE') as scope:
        scope.reuse_variables()
        valid_model = VAE(n_code=FLAGS.ncode, wd=0)
        valid_model.create_generate_model(b_size=n_grid)

    visualizer = Visualizer(model, save_path=SAVE_PATH)
    generator = Generator(generate_model=valid_model, save_path=SAVE_PATH)

    z = distribution.interpolate(plot_size=plot_size)
    z = np.reshape(z, (n_grid, 2))

    sessconfig = tf.ConfigProto()
    sessconfig.gpu_options.allow_growth = True
    with tf.Session(config=sessconfig) as sess:
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, '{}vae-epoch-{}'.format(SAVE_PATH, FLAGS.load))
        visualizer.viz_2Dlatent_variable(sess, valid_data)
        generator.generate_samples(sess, plot_size=plot_size, z=z)
# Example #2
# 0
def loadMNIST(data_dir='', sample_per_class=12):
    """Build the MNIST train and test dataflows with per-class sampling.

    Both splits use identical options: 10 classes, shuffling, pixel values
    divided by 255 and clamped to [0, 1], and a setup call with
    ``sample_n_class=10`` and the given ``sample_per_class``.

    Args:
        data_dir (str): directory passed to MNISTData as ``data_dir``.
        sample_per_class (int): value passed to ``setup`` as
            ``sample_per_class``.

    Returns:
        tuple: ``(train_data, valid_data)``, built from the 'train' and
        'test' splits respectively.
    """
    def normalize_im(im):
        # Divide by 255, then clamp so every value lies in [0, 1].
        scaled = im / 255.0
        return np.clip(scaled, 0.0, 1.0)

    def make_split(split_name):
        # Construct one split and run its setup with the shared options.
        dataflow = MNISTData(split_name,
                             n_class=10,
                             data_dir=data_dir,
                             batch_dict_name=['im', 'label'],
                             shuffle=True,
                             pf=normalize_im)
        dataflow.setup(epoch_val=0,
                       sample_n_class=10,
                       sample_per_class=sample_per_class)
        return dataflow

    # Tuple elements are evaluated left to right: train first, then test.
    return make_split('train'), make_split('test')
def train():
    """Train the VAE on MNIST, periodically checkpointing and visualizing.

    Options come from ``get_args()`` (``bsize``, ``ncode``, ``lr``,
    ``maxepoch``). Each epoch runs one training pass and one validation
    pass, both logged through a summary writer on ``SAVE_PATH``.

    Whenever ``epoch_id % 10 == 0``, a checkpoint
    ``'{SAVE_PATH}vae-epoch-{epoch_id}'`` is saved. If ``ncode == 2``, a
    sample grid and a latent-space plot are also written for that epoch.
    """
    FLAGS = get_args()
    # Side length of the generated sample grid; the generate model's batch
    # size must equal the number of grid cells, so derive it here once.
    plot_size = 20
    n_grid = plot_size * plot_size

    train_data = MNISTData('train',
                           data_dir=DATA_PATH,
                           shuffle=True,
                           pf=preprocess_im,
                           batch_dict_name=['im', 'label'])
    train_data.setup(epoch_val=0, batch_size=FLAGS.bsize)
    valid_data = MNISTData('test',
                           data_dir=DATA_PATH,
                           shuffle=True,
                           pf=preprocess_im,
                           batch_dict_name=['im', 'label'])
    valid_data.setup(epoch_val=0, batch_size=FLAGS.bsize)

    with tf.variable_scope('VAE'):
        model = VAE(n_code=FLAGS.ncode, wd=0)
        model.create_train_model()

    # Re-enter the same scope with reuse so the generate graph uses the same
    # variables as the train graph.
    with tf.variable_scope('VAE') as scope:
        scope.reuse_variables()
        valid_model = VAE(n_code=FLAGS.ncode, wd=0)
        valid_model.create_generate_model(b_size=n_grid)

    trainer = Trainer(model,
                      valid_model,
                      train_data,
                      init_lr=FLAGS.lr,
                      save_path=SAVE_PATH)
    # Latent-grid sampling and latent plots only make sense for a 2-D code.
    visualizer = None
    if FLAGS.ncode == 2:
        z = distribution.interpolate(plot_size=plot_size)
        z = np.reshape(z, (n_grid, 2))
        visualizer = Visualizer(model, save_path=SAVE_PATH)
    else:
        z = None
    generator = Generator(generate_model=valid_model, save_path=SAVE_PATH)

    sessconfig = tf.ConfigProto()
    sessconfig.gpu_options.allow_growth = True
    with tf.Session(config=sessconfig) as sess:
        writer = tf.summary.FileWriter(SAVE_PATH)
        try:
            saver = tf.train.Saver()
            sess.run(tf.global_variables_initializer())
            writer.add_graph(sess.graph)

            # NOTE(review): checkpoints are only written when
            # epoch_id % 10 == 0, so the final epoch is not saved unless
            # maxepoch - 1 is a multiple of 10 — confirm this is intended.
            for epoch_id in range(FLAGS.maxepoch):
                trainer.train_epoch(sess, summary_writer=writer)
                trainer.valid_epoch(sess, summary_writer=writer)
                if epoch_id % 10 == 0:
                    saver.save(sess,
                               '{}vae-epoch-{}'.format(SAVE_PATH, epoch_id))
                    if FLAGS.ncode == 2:
                        generator.generate_samples(sess,
                                                   plot_size=plot_size,
                                                   z=z,
                                                   file_id=epoch_id)
                        visualizer.viz_2Dlatent_variable(sess,
                                                         valid_data,
                                                         file_id=epoch_id)
        finally:
            # Flush buffered summaries to disk and release the event file,
            # even if training is interrupted.
            writer.close()
def read_valid_data(batch_size):
    """Build the shuffled MNIST 'test' split used for validation.

    Args:
        batch_size (int): batch size passed to ``setup``.

    Returns:
        MNISTData: the 'test' split, preprocessed with ``preprocess_im``
        and set up with ``epoch_val=0``.
    """
    test_flow = MNISTData(
        'test',
        batch_dict_name=['im', 'label'],
        data_dir=DATA_PATH,
        pf=preprocess_im,
        shuffle=True,
    )
    test_flow.setup(batch_size=batch_size, epoch_val=0)
    return test_flow
# Example #5
# 0
def load_mnist(batch_size,
               data_path,
               shuffle=True,
               n_use_label=None,
               n_use_sample=None,
               rescale_size=None):
    """Load the MNIST training split as a dataflow.

    If n_use_label or n_use_sample is not None, samples will be
    randomly picked to have a balanced number of examples.

    Each image is optionally resized, then mapped from [0, 255] to
    [-1., 1.] and clipped to that range.

    Args:
        batch_size (int): batch size
        data_path (str): directory passed to MNISTData as ``data_dir``
        shuffle (bool): forwarded to MNISTData's ``shuffle`` option
        n_use_label (int): how many labels are used for training
        n_use_sample (int): how many samples are used for training
        rescale_size (int): if not None, each image is resized to
            ``rescale_size x rescale_size`` before normalization

    Returns:
        MNISTData dataflow
    """

    def preprocess_im(im):
        """Optionally resize ``im``, then normalize it to [-1., 1.].

        Assumes pixel values in [0, 255] and, when resizing, a trailing
        channel axis of size 1 — TODO confirm against MNISTData output.
        """
        if rescale_size is not None:
            # Resize the 2-D image: drop the trailing channel axis, resize,
            # then restore it.
            im = np.squeeze(im, axis=-1)
            im = skimage.transform.resize(im, [rescale_size, rescale_size],
                                          mode='constant',
                                          preserve_range=True)
            im = np.expand_dims(im, axis=-1)
        # preserve_range keeps the original intensity scale, so the same
        # affine map applies whether or not the image was resized.
        im = im / 255. * 2. - 1.

        return np.clip(im, -1., 1.)

    data = MNISTData('train',
                     data_dir=data_path,
                     shuffle=shuffle,
                     pf=preprocess_im,
                     n_use_label=n_use_label,
                     n_use_sample=n_use_sample,
                     batch_dict_name=['im', 'label'])
    data.setup(epoch_val=0, batch_size=batch_size)
    return data
def visualize():
    """Visualize the latent space and manifold of a trained AAE (ncode = 2).

    Restores checkpoint ``'{SAVE_PATH}aae-epoch-{FLAGS.load}'``, plots the
    latent codes of the shuffled MNIST test split, and generates a
    ``plot_size x plot_size`` grid of samples over the learned manifold.

    Raises:
        ValueError: if ``FLAGS.ncode`` is not 2.
    """
    FLAGS = get_args()
    if FLAGS.ncode != 2:
        raise ValueError('Visualization only for ncode = 2!')

    # Side length of the generated manifold grid; the generate model's batch
    # size is derived from it so the two cannot drift apart.
    plot_size = 20
    n_grid = plot_size * plot_size

    # read validation set
    valid_data = MNISTData('test',
                           data_dir=DATA_PATH,
                           shuffle=True,
                           pf=preprocess_im,
                           batch_dict_name=['im', 'label'])
    valid_data.setup(epoch_val=0, batch_size=FLAGS.bsize)

    # create model for computing the latent z
    model = AAE(n_code=FLAGS.ncode, use_label=FLAGS.label, n_class=10)
    model.create_train_model()

    # create model for sampling images
    # NOTE(review): no tf.variable_scope reuse is set up here; presumably
    # AAE shares variables between its train and generate graphs itself —
    # confirm, otherwise the restored weights may not reach this graph.
    valid_model = AAE(n_code=FLAGS.ncode)
    valid_model.create_generate_model(b_size=n_grid)

    # initialize Visualizer and Generator
    visualizer = Visualizer(model, save_path=SAVE_PATH)
    generator = Generator(generate_model=valid_model,
                          save_path=SAVE_PATH,
                          distr_type=FLAGS.dist_type,
                          n_labels=10,
                          use_label=FLAGS.label)

    sessconfig = tf.ConfigProto()
    sessconfig.gpu_options.allow_growth = True
    with tf.Session(config=sessconfig) as sess:
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, '{}aae-epoch-{}'.format(SAVE_PATH, FLAGS.load))
        # visualize the learned latent space
        visualizer.viz_2Dlatent_variable(sess, valid_data)
        # visualize the learned manifold
        generator.generate_samples(sess, plot_size=plot_size, manifold=True)
def read_train_data(batch_size, n_use_label=None, n_use_sample=None):
    """Build the shuffled MNIST 'train' split as a dataflow.

    When either ``n_use_label`` or ``n_use_sample`` is given, samples are
    randomly picked to have a balanced number of examples.

    Args:
        batch_size (int): batch size passed to ``setup``.
        n_use_label (int): how many labels are used for training.
        n_use_sample (int): how many samples are used for training.

    Returns:
        MNISTData: the 'train' split, preprocessed with ``preprocess_im``
        and set up with ``epoch_val=0``.
    """
    train_flow = MNISTData(
        'train',
        batch_dict_name=['im', 'label'],
        data_dir=DATA_PATH,
        n_use_label=n_use_label,
        n_use_sample=n_use_sample,
        pf=preprocess_im,
        shuffle=True,
    )
    train_flow.setup(batch_size=batch_size, epoch_val=0)
    return train_flow