示例#1
0
文件: cifar10.py 项目: yyht/tf-vqvae
def test(MODEL, BETA, K, D, **kwargs):
    """Print the mean ``net.nll`` of a restored VQ-VAE over two CIFAR-10 queues.

    Two single-epoch batch queues (batch size 100) are built from
    ``get_image``: the default one, reported as the "training set", and
    ``get_image(False, ...)``, reported as the "test set".  Each queue is
    drained until it raises ``OutOfRangeError``; the per-batch ``net.nll``
    results are then concatenated and their mean is printed.

    Args:
        MODEL: Checkpoint location passed to ``net.load``.
        BETA: Forwarded unchanged to the ``VQVAE`` constructor.
        K: Forwarded unchanged to the ``VQVAE`` constructor.
        D: Forwarded unchanged to the ``VQVAE`` constructor.
        **kwargs: Accepted and ignored (presumably so a wider config dict can
            be splatted into the call -- confirm against callers).
    """
    # Input pipelines: one epoch each, final short batch allowed.
    image, _ = get_image(num_epochs=1)
    images = tf.train.batch([image],
                            batch_size=100,
                            num_threads=1,
                            capacity=100,
                            allow_smaller_final_batch=True)
    valid_image, _ = get_image(False, num_epochs=1)
    valid_images = tf.train.batch([valid_image],
                                  batch_size=100,
                                  num_threads=1,
                                  capacity=100,
                                  allow_smaller_final_batch=True)

    # Model: the network reads from a placeholder so the same graph can be
    # fed from either queue.
    with tf.variable_scope('net'):
        with tf.variable_scope('params') as params:
            pass
        x = tf.placeholder(tf.float32, [None, 32, 32, 3])
        net = VQVAE(None, None, BETA, x, K, D, _cifar10_arch, params, False)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    def _report_mean_nll(sess, coord, batch_op, split_name):
        # Drain one queue; the summary is printed when the queue signals
        # exhaustion through OutOfRangeError, exactly as the loop ends.
        nlls = []
        try:
            while not coord.should_stop():
                batch = sess.run(batch_op)
                nlls.append(sess.run(net.nll, feed_dict={x: batch}))
                print('.', end='', flush=True)
        except tf.errors.OutOfRangeError:
            nlls = np.concatenate(nlls, axis=0)
            print(nlls.shape)
            print('NLL for %s set: %f bits/dims' % (split_name,
                                                    np.mean(nlls)))

    # The context manager closes the session even when evaluation fails.
    with tf.Session(config=config) as sess:
        sess.graph.finalize()
        sess.run(init_op)
        net.load(sess, MODEL)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord, sess=sess)
        try:
            _report_mean_nll(sess, coord, valid_images, 'test')
            _report_mean_nll(sess, coord, images, 'training')
        finally:
            # Always stop and join the queue-runner threads, not only on the
            # happy path.
            coord.request_stop()
            coord.join(threads)
示例#2
0
文件: cifar10.py 项目: yyht/tf-vqvae
def extract_z(MODEL, BATCH_SIZE, BETA, K, D, **kwargs):
    """Run one epoch of ``get_image`` through the VQ-VAE and save ``net.k``.

    Every batch of images is fed to the restored network; the resulting
    ``net.k`` arrays and the matching labels are concatenated along axis 0
    and written to ``ks_ys.npz`` (keys ``ks`` and ``ys``) in the directory
    that contains ``MODEL``.

    Args:
        MODEL: Checkpoint location passed to ``net.load``; its directory is
            also where ``ks_ys.npz`` is written.
        BATCH_SIZE: Batch size (and queue capacity) of the input pipeline.
        BETA: Forwarded unchanged to the ``VQVAE`` constructor.
        K: Forwarded unchanged to the ``VQVAE`` constructor.
        D: Forwarded unchanged to the ``VQVAE`` constructor.
        **kwargs: Accepted and ignored.
    """
    # Input pipeline: a single epoch, final short batch allowed.
    image, label = get_image(num_epochs=1)
    images, labels = tf.train.batch([image, label],
                                    batch_size=BATCH_SIZE,
                                    num_threads=1,
                                    capacity=BATCH_SIZE,
                                    allow_smaller_final_batch=True)

    # Model fed through a placeholder.
    with tf.variable_scope('net'):
        with tf.variable_scope('params') as params:
            pass
        x_ph = tf.placeholder(tf.float32, [None, 32, 32, 3])
        net = VQVAE(None, None, BETA, x_ph, K, D, _cifar10_arch, params, False)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    ks = []
    ys = []
    # The context manager closes the session even if extraction fails.
    with tf.Session(config=config) as sess:
        sess.graph.finalize()
        sess.run(init_op)
        net.load(sess, MODEL)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord, sess=sess)
        try:
            while not coord.should_stop():
                x, y = sess.run([images, labels])
                ks.append(sess.run(net.k, feed_dict={x_ph: x}))
                ys.append(y)
                print('.', end='', flush=True)
        except tf.errors.OutOfRangeError:
            # Raised once the one-epoch queue is exhausted.
            print('Extracting Finished')
        finally:
            # Stop the queue runners before any post-processing so a failure
            # while saving cannot leave background threads alive.
            coord.request_stop()
            coord.join(threads)

    ks = np.concatenate(ks, axis=0)
    ys = np.concatenate(ys, axis=0)
    np.savez(os.path.join(os.path.dirname(MODEL), 'ks_ys.npz'), ks=ks, ys=ys)
示例#3
0
def rec_a_frame_img_from_ze(latentPath, savePath):
    """Decode latents from ``latentPath`` into uint8 images stored in HDF5.

    The ImageNet VQ-VAE is restored from ``models/imagenet/last.ckpt``; each
    batch produced by ``vqvae_ze_dataset(latentPath)`` is fed in place of
    ``net.z_e`` and the resulting ``net.p_x_z`` is scaled by 255, cast to
    ``uint8`` and written to the ``'rec'`` dataset of ``savePath``.

    Args:
        latentPath: Argument passed to ``vqvae_ze_dataset``.
        savePath: Output HDF5 file; it is overwritten and its parent
            directory is created when missing.
    """
    # os.path.dirname handles paths without a separator (slicing at
    # rfind('/') == -1 used to drop the last character of the file name) and
    # root-level paths (which used to yield an empty directory name).
    save_dir = os.path.dirname(savePath)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    MODEL, K, D = ('models/imagenet/last.ckpt', 512, 128)
    with tf.variable_scope('net'):
        with tf.variable_scope('params') as params:
            pass
        x = tf.placeholder(tf.float32, [None, 128, 128, 3])
        net = VQVAE(None, None, 0.25, x, K, D, _imagenet_arch, params, False)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    # The context manager closes the session even if decoding fails.
    with tf.Session(config=config) as sess:
        sess.graph.finalize()
        sess.run(init_op)
        net.load(sess, MODEL)

        dataset = vqvae_ze_dataset(latentPath)
        dataloader = DataLoader(dataset,
                                batch_size=10,
                                shuffle=False,
                                num_workers=0)

        with h5py.File(savePath, 'w') as sf:
            rec_dataset = sf.create_dataset('rec',
                                            shape=(len(dataset), 128, 128, 3),
                                            dtype=np.uint8,
                                            chunks=True)
            begin_idx = 0
            for step, data in enumerate(dataloader):
                # TODO: z_e vs z_q -- feeding the latents in as z_q gave
                # worse results on the validation set.
                rec = sess.run(net.p_x_z, feed_dict={net.z_e: data})
                # Clip before the cast so values outside [0, 255] saturate
                # instead of wrapping around.
                # NOTE(review): assumes p_x_z is nominally in [0, 1] -- confirm.
                rec = np.clip(rec * 255.0, 0.0, 255.0).astype(np.uint8)
                end_idx = begin_idx + len(rec)
                rec_dataset[begin_idx:end_idx] = rec
                begin_idx = end_idx
                print(step)
示例#4
0
def extract_k_rec_from_vqvae(dt_key):
    """Store ``net.k`` and ``net.p_x_z`` for every vim-1 stimulus batch.

    The ImageNet VQ-VAE is restored from ``models/imagenet/last.ckpt`` and
    run over ``vim1_blur_stimuli_dataset`` in batches of 10.  Two HDF5 files
    named after ``dt_key`` are written into the ``extract_from_vqvae``
    directory: one with a ``'rec'`` dataset, one with a ``'k'`` dataset.

    Args:
        dt_key: ``'st'`` selects the ``'stimTrn'`` split, anything else
            selects ``'stimVal'``; the value is also embedded in both output
            file names.
    """
    ckpt_path, k_val, d_val = ('models/imagenet/last.ckpt', 512, 128)

    # Build the network on top of a placeholder input.
    with tf.variable_scope('net'):
        with tf.variable_scope('params') as params:
            pass
        images_ph = tf.placeholder(tf.float32, [None, 128, 128, 3])
        net = VQVAE(None, None, 0.25, images_ph, k_val, d_val,
                    _imagenet_arch, params, False)

    global_init = tf.global_variables_initializer()
    local_init = tf.local_variables_initializer()
    init_op = tf.group(global_init, local_init)

    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    sess = tf.Session(config=session_config)
    sess.graph.finalize()
    sess.run(init_op)
    net.load(sess, ckpt_path)

    # Pick the stimulus split that corresponds to dt_key.
    if dt_key == 'st':
        split_key = 'stimTrn'
    else:
        split_key = 'stimVal'
    dataset = vim1_blur_stimuli_dataset(
        "/data1/home/guangjie/Data/vim-1/Stimuli.hdf5", split_key, 3)
    dataloader = DataLoader(dataset,
                            batch_size=10,
                            shuffle=False,
                            num_workers=1)

    os.makedirs(
        '/data1/home/guangjie/Data/vim1/exprimentData/extract_from_vqvae',
        exist_ok=True)
    rec_path = (
        "/data1/home/guangjie/Data/vim1/exprimentData/extract_from_vqvae/rec_from_vqvae_{}.hdf5"
        .format(dt_key))
    k_path = (
        "/data1/home/guangjie/Data/vim1/exprimentData/extract_from_vqvae/k_from_vqvae_{}.hdf5"
        .format(dt_key))
    total = len(dataset)

    with h5py.File(rec_path, 'w') as recf:
        rec_dataset = recf.create_dataset('rec', shape=(total, 128, 128, 3))
        with h5py.File(k_path, 'w') as kf:
            k_dataset = kf.create_dataset('k', shape=(total, 32, 32))
            # Rows are written contiguously; offset tracks the next free row.
            offset = 0
            for step, batch in enumerate(dataloader):
                codes, recon = sess.run((net.k, net.p_x_z),
                                        feed_dict={images_ph: batch})
                stop = offset + len(recon)
                rec_dataset[offset:stop] = recon
                k_dataset[offset:stop] = codes
                offset = stop
                print(step)
示例#5
0
def extract_z(MODEL, BATCH_SIZE, BETA, K, D, **kwargs):
    """Compute ``net.k`` for the MNIST training images and save it with labels.

    Flat 784-pixel rows are reshaped to 28x28x1, bilinearly resized to 24x24
    and passed through the restored VQ-VAE in slices of ``BATCH_SIZE``.  The
    concatenated ``net.k`` outputs and the training labels are written to
    ``ks_ys.npz`` (keys ``ks`` and ``ys``) next to ``MODEL``.

    Args:
        MODEL: Checkpoint location passed to ``net.load``; its directory
            receives ``ks_ys.npz``.
        BATCH_SIZE: Number of images per ``sess.run`` call.
        BETA: Forwarded unchanged to the ``VQVAE`` constructor.
        K: Forwarded unchanged to the ``VQVAE`` constructor.
        D: Forwarded unchanged to the ``VQVAE`` constructor.
        **kwargs: Accepted and ignored.
    """
    # Dataset
    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets("datasets/mnist", one_hot=False)

    # Model input: flat pixels -> 28x28x1 image -> 24x24 bilinear resize.
    pixels_ph = tf.placeholder(tf.float32, [None, 784])
    as_images = tf.reshape(pixels_ph, [-1, 28, 28, 1])
    resized = tf.image.resize_images(as_images, (24, 24),
                                     method=tf.image.ResizeMethod.BILINEAR)

    with tf.variable_scope('net'):
        with tf.variable_scope('params') as params:
            pass
        net = VQVAE(None, None, BETA, resized, K, D, _mnist_arch, params,
                    False)

    # Initialisation op for global and local variables.
    global_init = tf.global_variables_initializer()
    local_init = tf.local_variables_initializer()
    init_op = tf.group(global_init, local_init)

    # Session setup and checkpoint restore.
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    sess = tf.Session(config=session_config)
    sess.graph.finalize()
    sess.run(init_op)
    net.load(sess, MODEL)

    train_images = mnist.train.images
    train_labels = mnist.train.labels
    batch_starts = range(0, len(train_images), BATCH_SIZE)
    code_batches = [
        sess.run(net.k,
                 feed_dict={pixels_ph: train_images[start:start + BATCH_SIZE]})
        for start in tqdm(batch_starts)
    ]
    ks = np.concatenate(code_batches, axis=0)

    out_path = os.path.join(os.path.dirname(MODEL), 'ks_ys.npz')
    np.savez(out_path, ks=ks, ys=train_labels)
示例#6
0
def extract_zq_from_vqvae(dt_key):
    """Store ``net.z_e`` for every vim-1 stimulus batch in an HDF5 file.

    The ImageNet VQ-VAE is restored from ``models/imagenet/last.ckpt`` and
    run over ``vim1_blur_stimuli_dataset`` in batches of 10; the ``net.z_e``
    output of each batch is appended to the ``'latent'`` dataset of
    ``ze_from_vqvae_<dt_key>.hdf5``.

    NOTE(review): despite the ``zq`` in the function name, the tensor that is
    fetched and stored is ``net.z_e``.
    NOTE(review): this reads ``Stimuli.mat`` whereas the sibling extractor
    reads ``Stimuli.hdf5`` -- confirm which file is intended.

    Args:
        dt_key: ``'st'`` selects the ``'stimTrn'`` split, anything else
            selects ``'stimVal'``; the value is also embedded in the output
            file name.
    """
    import os  # local import keeps this block self-contained

    MODEL, K, D = ('models/imagenet/last.ckpt', 512, 128)
    with tf.variable_scope('net'):
        with tf.variable_scope('params') as params:
            pass
        x = tf.placeholder(tf.float32, [None, 128, 128, 3])
        net = VQVAE(None, None, 0.25, x, K, D, _imagenet_arch, params, False)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    out_path = (
        "/data1/home/guangjie/Data/vim1/exprimentData/extract_from_vqvae/ze_from_vqvae_{}.hdf5"
        .format(dt_key))

    # The context manager closes the session even if extraction fails.
    with tf.Session(config=config) as sess:
        sess.graph.finalize()
        sess.run(init_op)
        net.load(sess, MODEL)

        dtKey = 'stimTrn' if dt_key == 'st' else 'stimVal'
        dataset = vim1_blur_stimuli_dataset(
            "/data1/home/guangjie/Data/vim-1/Stimuli.mat", dtKey, 3)
        dataloader = DataLoader(dataset,
                                batch_size=10,
                                shuffle=False,
                                num_workers=1)

        # Create the output directory first, as the k/rec extractor does;
        # h5py cannot open a file for writing inside a missing directory.
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        with h5py.File(out_path, 'w') as sf:
            ze_dataset = sf.create_dataset('latent',
                                           shape=(len(dataset), 32, 32, 128))
            begin_idx = 0
            for step, data in enumerate(dataloader):
                ze = sess.run(net.z_e, feed_dict={x: data})
                end_idx = begin_idx + len(ze)
                ze_dataset[begin_idx:end_idx] = ze
                begin_idx = end_idx
                print(step)
示例#7
0
文件: cifar10.py 项目: yyht/tf-vqvae
def train_prior(config, RANDOM_SEED, MODEL, TRAIN_NUM, BATCH_SIZE,
                LEARNING_RATE, DECAY_VAL, DECAY_STEPS, DECAY_STAIRCASE,
                GRAD_CLIP, K, D, BETA, NUM_LAYERS, NUM_FEATURE_MAPS,
                SUMMARY_PERIOD, SAVE_PERIOD, **kwargs):
    """Train a ``PixelCNN`` on the latents stored in ``ks_ys.npz`` (CIFAR-10).

    The latents and labels are read from ``ks_ys.npz`` next to ``MODEL``.
    A ``VQVAE`` is restored from ``MODEL`` and used to turn sampled latents
    into images for the summaries.  Checkpoints and summaries go to the
    ``pixelcnn_6`` directory next to ``MODEL``.

    Args:
        config: Object exposing ``as_matrix()``; its result is logged once as
            the ``TrainConfig`` text summary.
        RANDOM_SEED: Seed for NumPy and the TensorFlow graph.
        MODEL: VQ-VAE checkpoint location; its directory holds the inputs
            and outputs of this function.
        TRAIN_NUM: Number of training iterations.
        BATCH_SIZE: Batch size passed to ``next_batch``.
        LEARNING_RATE, DECAY_VAL, DECAY_STEPS, DECAY_STAIRCASE: Arguments of
            ``tf.train.exponential_decay``.
        GRAD_CLIP, NUM_LAYERS, NUM_FEATURE_MAPS: Forwarded to ``PixelCNN``.
        K, D: Forwarded to both ``VQVAE`` and ``PixelCNN``.
        BETA: Forwarded to ``VQVAE``.
        SUMMARY_PERIOD: Scalar summaries are written every this many steps,
            sample images every twice this many.
        SAVE_PERIOD: A checkpoint is saved every this many steps.
        **kwargs: Accepted and ignored.
    """
    np.random.seed(RANDOM_SEED)
    tf.set_random_seed(RANDOM_SEED)
    LOG_DIR = os.path.join(os.path.dirname(MODEL), 'pixelcnn_6')

    # Dataset
    class Latents():
        """Train/validation split over the arrays stored in ``ks_ys.npz``."""

        def __init__(self, path, validation_size=1):
            from tensorflow.contrib.learn.python.learn.datasets.mnist import DataSet
            from tensorflow.contrib.learn.python.learn.datasets import base

            # Read each array once and release the archive handle.
            with np.load(path) as data:
                ks = data['ks']
                ys = data['ys']

            # Original author's note: dtype=np.uint8 causes no trouble even
            # when the latents are int32.
            train = DataSet(ks[validation_size:],
                            ys[validation_size:],
                            reshape=False,
                            dtype=np.uint8,
                            one_hot=False)
            validation = DataSet(ks[:validation_size],
                                 ys[:validation_size],
                                 reshape=False,
                                 dtype=np.uint8,
                                 one_hot=False)
            self.size = ks.shape[1]
            self.data = base.Datasets(train=train,
                                      validation=validation,
                                      test=None)

    latent = Latents(os.path.join(os.path.dirname(MODEL), 'ks_ys.npz'))

    # VQ-VAE used only to generate images from sampled latents.
    with tf.variable_scope('net'):
        with tf.variable_scope('params') as params:
            pass
        _not_used = tf.placeholder(tf.float32, [None, 32, 32, 3])
        vq_net = VQVAE(None, None, BETA, _not_used, K, D, _cifar10_arch,
                       params, False)

    # Prior model that is actually trained.
    with tf.variable_scope('pixelcnn'):
        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(LEARNING_RATE,
                                                   global_step,
                                                   DECAY_STEPS,
                                                   DECAY_VAL,
                                                   staircase=DECAY_STAIRCASE)
        tf.summary.scalar('lr', learning_rate)

        net = PixelCNN(learning_rate, global_step, GRAD_CLIP, latent.size,
                       vq_net.embeds, K, D, 10, NUM_LAYERS, NUM_FEATURE_MAPS)

    with tf.variable_scope('misc'):
        # Summary operations
        tf.summary.scalar('loss', net.loss)
        summary_op = tf.summary.merge_all()

        # Initialisation op
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        # collections=[] keeps this one-off summary out of merge_all().
        config_summary = tf.summary.text('TrainConfig',
                                         tf.convert_to_tensor(
                                             config.as_matrix()),
                                         collections=[])

        sample_images = tf.placeholder(tf.float32, [None, 32, 32, 3])
        sample_summary_op = tf.summary.image('samples',
                                             sample_images,
                                             max_outputs=20)

    # A separate name keeps the `config` argument from being shadowed.
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    # The context manager closes the session even if training fails.
    with tf.Session(config=session_config) as sess:
        sess.graph.finalize()
        sess.run(init_op)
        vq_net.load(sess, MODEL)

        summary_writer = tf.summary.FileWriter(LOG_DIR, sess.graph)
        try:
            summary_writer.add_summary(config_summary.eval(session=sess))

            for _ in tqdm(range(TRAIN_NUM), dynamic_ncols=True):
                batch_xs, batch_ys = latent.data.train.next_batch(BATCH_SIZE)
                feed = {net.X: batch_xs, net.h: batch_ys}
                it, loss, _ = sess.run(
                    [global_step, net.loss, net.train_op], feed_dict=feed)

                if it % SAVE_PERIOD == 0:
                    net.save(sess, LOG_DIR, step=it)

                if it % SUMMARY_PERIOD == 0:
                    tqdm.write('[%5d] Loss: %1.3f' % (it, loss))
                    summary = sess.run(summary_op, feed_dict=feed)
                    summary_writer.add_summary(summary, it)

                if it % (SUMMARY_PERIOD * 2) == 0:
                    # One label per class index 0..9 is used for sampling.
                    sampled_zs, _ = net.sample_from_prior(
                        sess, np.arange(10), 2)
                    sampled_ims = sess.run(
                        vq_net.gen, feed_dict={vq_net.latent: sampled_zs})
                    summary_writer.add_summary(
                        sess.run(sample_summary_op,
                                 feed_dict={sample_images: sampled_ims}), it)

            net.save(sess, LOG_DIR)
        finally:
            # Flush pending events to disk whether or not training finished.
            summary_writer.close()
示例#8
0
# next_item = slices.make_one_shot_iterator().get_next() #todo

# Module-level setup: build the ImageNet VQ-VAE on a placeholder input, open
# a session and restore the weights.  K, D and MODEL are referenced here but
# are not defined in this part of the file.
with tf.variable_scope('net'):
    # The empty 'params' scope is only captured so it can be handed to VQVAE.
    with tf.variable_scope('params') as params:
        pass
    x = tf.placeholder(tf.float32, [None, 128, 128, 3])
    net = VQVAE(None, None, 0.25, x, K, D, _imagenet_arch, params, False)

# Single op that initialises both global and local variables.
init_op = tf.group(tf.global_variables_initializer(),
                   tf.local_variables_initializer())
config = tf.ConfigProto()
# Allocate GPU memory on demand instead of reserving it all up front.
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
# Freeze the graph so no further ops can be added after this point.
sess.graph.finalize()
sess.run(init_op)
# Restore the weights from MODEL after the initialisers have run.
net.load(sess, MODEL)


def draw(images):
    """Display ``images`` in a five-column matplotlib grid.

    The grid has at least two rows, so up to ten images are laid out exactly
    as a 2x5 grid; more images add further rows instead of failing on an
    out-of-range subplot index.

    Args:
        images: Iterable of arrays accepted by ``imshow`` after scaling by
            255 and casting to ``uint8``.
    """
    from matplotlib import pyplot as plt

    images = list(images)  # the row count needs the length up front
    n_cols = 5
    n_rows = max(2, -(-len(images) // n_cols))  # ceiling division

    fig = plt.figure(figsize=(20, 20))
    for n, image in enumerate(images):
        a = fig.add_subplot(n_rows, n_cols, n + 1)
        # Clip before the cast so out-of-range values saturate instead of
        # wrapping around.
        # NOTE(review): assumes pixel values are nominally in [0, 1] -- confirm.
        a.imshow(np.clip(image * 255.0, 0.0, 255.0).astype(np.uint8))
        a.axis('off')
        a.set_aspect('equal')

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()
    plt.close(fig)
示例#9
0
def train_prior(config, RANDOM_SEED, MODEL, TRAIN_NUM, BATCH_SIZE,
                LEARNING_RATE, DECAY_VAL, DECAY_STEPS, DECAY_STAIRCASE,
                GRAD_CLIP, K, D, BETA, NUM_LAYERS, NUM_FEATURE_MAPS,
                SUMMARY_PERIOD, SAVE_PERIOD, **kwargs):
    """Train a ``PixelCNN`` on ``vq_net.k`` computed on the fly (ImageNet).

    Batches from the ``train`` split under ``datasets/ILSVRC2012`` are run
    through a ``VQVAE`` restored from ``MODEL``; the resulting ``vq_net.k``
    and labels are fed to the ``PixelCNN`` training op.  Checkpoints and
    summaries go to the ``pixelcnn`` directory next to ``MODEL``.  A final
    checkpoint is saved even when training stops on an error.

    Args:
        config: Object exposing ``as_matrix()``; its result is logged once as
            the ``TrainConfig`` text summary.
        RANDOM_SEED: Seed for NumPy and the TensorFlow graph.
        MODEL: VQ-VAE checkpoint location; its directory receives the
            ``pixelcnn`` log directory.
        TRAIN_NUM: Number of training iterations.
        BATCH_SIZE: Batch size passed to ``_build_batch``.
        LEARNING_RATE, DECAY_VAL, DECAY_STEPS, DECAY_STAIRCASE: Arguments of
            ``tf.train.exponential_decay``.
        GRAD_CLIP, NUM_LAYERS, NUM_FEATURE_MAPS: Forwarded to ``PixelCNN``.
        K, D: Forwarded to both ``VQVAE`` and ``PixelCNN``.
        BETA: Forwarded to ``VQVAE``.
        SUMMARY_PERIOD: Scalar summaries are written every this many steps.
        SAVE_PERIOD: A checkpoint and a sample-image summary are written
            every this many steps.
        **kwargs: Accepted and ignored.
    """
    np.random.seed(RANDOM_SEED)
    tf.set_random_seed(RANDOM_SEED)
    LOG_DIR = os.path.join(os.path.dirname(MODEL), 'pixelcnn')

    # Dataset
    train_dataset = imagenet.get_split('train', 'datasets/ILSVRC2012')
    ims, labels = _build_batch(train_dataset, BATCH_SIZE, 4)

    # VQ-VAE that produces the latents and renders sampled latents.
    with tf.variable_scope('net'):
        with tf.variable_scope('params') as params:
            pass
        vq_net = VQVAE(None, None, BETA, ims, K, D, _imagenet_arch, params,
                       False)

    # Prior model that is actually trained.
    with tf.variable_scope('pixelcnn'):
        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(LEARNING_RATE,
                                                   global_step,
                                                   DECAY_STEPS,
                                                   DECAY_VAL,
                                                   staircase=DECAY_STAIRCASE)
        tf.summary.scalar('lr', learning_rate)

        net = PixelCNN(learning_rate, global_step, GRAD_CLIP,
                       vq_net.k.get_shape()[1], vq_net.embeds, K, D, 1000,
                       NUM_LAYERS, NUM_FEATURE_MAPS)

    with tf.variable_scope('misc'):
        # Summary operations
        tf.summary.scalar('loss', net.loss)
        summary_op = tf.summary.merge_all()

        # Initialisation op
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        # collections=[] keeps this one-off summary out of merge_all().
        config_summary = tf.summary.text('TrainConfig',
                                         tf.convert_to_tensor(
                                             config.as_matrix()),
                                         collections=[])

        sample_images = tf.placeholder(tf.float32, [None, 128, 128, 3])
        sample_summary_op = tf.summary.image('samples',
                                             sample_images,
                                             max_outputs=20)

    # A separate name keeps the `config` argument from being shadowed.
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    # The context manager closes the session even if training fails.
    with tf.Session(config=session_config) as sess:
        sess.graph.finalize()
        sess.run(init_op)
        vq_net.load(sess, MODEL)

        summary_writer = tf.summary.FileWriter(LOG_DIR, sess.graph)
        summary_writer.add_summary(config_summary.eval(session=sess))

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord, sess=sess)
        try:
            for _ in tqdm(range(TRAIN_NUM), dynamic_ncols=True):
                batch_xs, batch_ys = sess.run([vq_net.k, labels])
                feed = {net.X: batch_xs, net.h: batch_ys}
                it, loss, _ = sess.run(
                    [global_step, net.loss, net.train_op], feed_dict=feed)

                if it % SAVE_PERIOD == 0:
                    net.save(sess, LOG_DIR, step=it)
                    # Ten random labels drawn from [0, 1000) for sampling.
                    sampled_zs, _ = net.sample_from_prior(
                        sess, np.random.randint(0, 1000, size=(10, )), 2)
                    sampled_ims = sess.run(
                        vq_net.gen, feed_dict={vq_net.latent: sampled_zs})
                    summary_writer.add_summary(
                        sess.run(sample_summary_op,
                                 feed_dict={sample_images: sampled_ims}), it)

                if it % SUMMARY_PERIOD == 0:
                    tqdm.write('[%5d] Loss: %1.3f' % (it, loss))
                    summary = sess.run(summary_op, feed_dict=feed)
                    summary_writer.add_summary(summary, it)

        except Exception as e:
            # Hand the error to the coordinator; coord.join() below raises it
            # again after the threads have been stopped.
            coord.request_stop(e)
        finally:
            try:
                net.save(sess, LOG_DIR)
            finally:
                # Runs even if the final save fails, so events are flushed
                # and the queue-runner threads are always stopped.
                summary_writer.close()
                coord.request_stop()
                coord.join(threads)