Beispiel #1
0
def main(args):
    """Encode every sample of a CT dataset with a VQ-VAE and store the codes in LMDB.

    The model is restored from ``args.checkpoint_path`` when given, otherwise a
    freshly constructed ``VQVAE()`` is used.  One LMDB sub-database is opened
    per bottleneck block (named ``b"0"``, ``b"1"``, ...).  The encoding of
    sample ``i`` for block ``j`` is stored under key ``str(i)`` in
    sub-database ``j`` as a pickled NumPy array.  The root database also gets
    the metadata keys ``num_dbs``, ``length`` and ``num_embeddings`` (pickled
    array).

    Args:
        args: parsed CLI namespace; reads ``checkpoint_path``,
            ``dataset_path``, ``output_path`` and ``output_name``.
    """
    if args.checkpoint_path is not None:
        model = VQVAE.load_from_checkpoint(str(args.checkpoint_path))
    else:
        model = VQVAE()

    datamodule = CTDataModule(
        path=args.dataset_path,
        batch_size=1,
        train_frac=1,  # presumably: put every sample in the train split
        num_workers=5,
        rescale_input=(256, 256, 128),
    )
    datamodule.setup()
    dataloader = datamodule.train_dataloader()

    db = lmdb.open(
        get_output_abspath(args.checkpoint_path, args.output_path, args.output_name),
        map_size=int(1e12),  # maximum size the environment may grow to
        max_dbs=model.n_bottleneck_blocks,  # one named sub-db per bottleneck block
    )
    # Always release the LMDB environment, even if encoding fails part-way.
    try:
        sub_dbs = [db.open_db(str(i).encode()) for i in range(model.n_bottleneck_blocks)]
        with db.begin(write=True) as txn:
            # Root db metadata
            txn.put(b"num_dbs", str(model.n_bottleneck_blocks).encode())
            txn.put(b"length", str(len(dataloader)).encode())
            txn.put(b"num_embeddings", pickle.dumps(np.asarray(model.num_embeddings)))

            samples = extract_samples(model, dataloader)
            for i, sample_encodings in tqdm(enumerate(samples), total=len(dataloader)):
                # One encoding per bottleneck block -> one entry per sub-db.
                for sub_db, encoding in zip(sub_dbs, sample_encodings):
                    txn.put(str(i).encode(), pickle.dumps(encoding.cpu().numpy()), db=sub_db)
    finally:
        db.close()
Beispiel #2
0
def main(args):
    """Train a VQ-VAE on MNIST, saving summaries, sample images and the model.

    Outputs go to ``<args.save_dir>/<model_type>`` (checkpoint),
    ``<args.log_dir>/<model_type>`` (TensorBoard summaries) and
    ``<args.img_dir>/<model_type>`` (raw/reconstructed test images every
    ``args.save_epoch`` epochs).

    Args:
        args: namespace providing the directories above plus ``model_type``,
            ``train_dir``, ``batch_size``, ``nb_epoch``, ``print_step``,
            ``anneal``, ``anneal_start`` and ``save_epoch``.
    """
    save_dir = os.path.join(args.save_dir, args.model_type)
    img_dir = os.path.join(args.img_dir, args.model_type)
    log_dir = os.path.join(args.log_dir, args.model_type)

    for directory in (save_dir, log_dir, img_dir):
        os.makedirs(directory, exist_ok=True)

    mnist = utils.read_data_sets(args.train_dir)
    summary_writer = tf.summary.FileWriter(log_dir)
    config_proto = utils.get_config_proto()

    sess = tf.Session(config=config_proto)
    model = VQVAE(args, sess, name="vqvae")

    total_batch = mnist.train.num_examples // args.batch_size

    for epoch in range(1, args.nb_epoch + 1):
        print("Epoch %d start with learning rate %f" % (
            epoch, model.learning_rate.eval(sess)))
        print("- " * 50)
        step_start_time = time.time()
        for i in range(1, total_batch + 1):
            x_batch, _ = mnist.train.next_batch(args.batch_size)

            # model.train already returns the global step, so no separate
            # sess.run(model.global_step) is needed.
            _, loss, rec_loss, vq, commit, global_step, summaries = model.train(
                x_batch)
            summary_writer.add_summary(summaries, global_step)

            if i % args.print_step == 0:
                print("epoch %d, step %d, loss %f, rec_loss %f, vq_loss %f, "
                      "commit_loss %f, time %.2fs"
                      % (epoch, global_step, loss, rec_loss, vq, commit,
                         time.time() - step_start_time))
                step_start_time = time.time()

        if epoch % 50 == 0:
            print("- " * 5)

        if args.anneal and epoch >= args.anneal_start:
            sess.run(model.lr_decay_op)

        if epoch % args.save_epoch == 0:
            # Save 100 test images and their reconstructions as 10x10 grids.
            x_batch, _ = mnist.test.next_batch(100)
            x_recon = model.reconstruct(x_batch)
            utils.save_images(x_batch.reshape(-1, 28, 28, 1), [10, 10],
                              os.path.join(img_dir, "rawImage%s.jpg" % epoch))
            utils.save_images(
                x_recon, [10, 10],
                os.path.join(img_dir, "reconstruct%s.jpg" % epoch))

    model.saver.save(sess, os.path.join(save_dir, "model.ckpt"))
    # Flush any buffered summaries to disk.
    summary_writer.close()
    print("Model stored....")
Beispiel #3
0
def _run_nll_pass(sess, coord, nll_op, x_ph, batch_op):
    """Evaluate ``nll_op`` on every batch produced by the queue op ``batch_op``.

    Returns the per-batch results concatenated along axis 0 once the input
    raises ``OutOfRangeError`` (end of its ``num_epochs=1`` input).  Returns
    ``None`` if the input produced no batch, or if the coordinator requested a
    stop first (the pass is then incomplete and should not be reported).
    """
    nlls = []
    try:
        while not coord.should_stop():
            nlls.append(sess.run(nll_op, feed_dict={x_ph: sess.run(batch_op)}))
            print('.', end='', flush=True)
    except tf.errors.OutOfRangeError:
        print()
        return np.concatenate(nlls, axis=0) if nlls else None
    return None


def test(MODEL, BETA, K, D, **kwargs):
    """Print the mean NLL of a trained VQ-VAE on the test and training inputs.

    Args:
        MODEL: checkpoint path passed to ``net.load``.
        BETA, K, D: VQ-VAE hyper-parameters forwarded to ``VQVAE``.
        **kwargs: ignored; allows passing a larger config dict.
    """
    # >>>>>>> DATASET
    image, _ = get_image(num_epochs=1)
    images = tf.train.batch([image],
                            batch_size=100,
                            num_threads=1,
                            capacity=100,
                            allow_smaller_final_batch=True)
    valid_image, _ = get_image(False, num_epochs=1)  # presumably False = held-out split
    valid_images = tf.train.batch([valid_image],
                                  batch_size=100,
                                  num_threads=1,
                                  capacity=100,
                                  allow_smaller_final_batch=True)
    # <<<<<<<

    # >>>>>>> MODEL
    with tf.variable_scope('net'):
        with tf.variable_scope('params') as params:
            pass
        x = tf.placeholder(tf.float32, [None, 32, 32, 3])
        net = VQVAE(None, None, BETA, x, K, D, _cifar10_arch, params, False)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Run!
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.graph.finalize()  # any accidental op creation below now raises
    sess.run(init_op)
    net.load(sess, MODEL)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)
    # Always stop and join the queue-runner threads, even if evaluation fails.
    try:
        for split_name, batch_op in (('test', valid_images), ('training', images)):
            nlls = _run_nll_pass(sess, coord, net.nll, x, batch_op)
            if nlls is None:
                continue
            print(nlls.shape)
            print('NLL for %s set: %f bits/dims' % (split_name, np.mean(nlls)))
    finally:
        coord.request_stop()
        coord.join(threads)
Beispiel #4
0
def extract_z(MODEL, BATCH_SIZE, BETA, K, D, **kwargs):
    """Encode one pass over the ``get_image`` input into discrete VQ-VAE codes.

    The codes ``ks`` (from ``net.k``) and labels ``ys`` are saved to
    ``ks_ys.npz`` in the directory containing ``MODEL``.

    Args:
        MODEL: checkpoint path passed to ``net.load``.
        BATCH_SIZE: images per ``sess.run`` call.
        BETA, K, D: VQ-VAE hyper-parameters forwarded to ``VQVAE``.
        **kwargs: ignored; allows passing a larger config dict.

    Raises:
        ValueError: if the input produced no batches.
    """
    # >>>>>>> DATASET
    image, label = get_image(num_epochs=1)
    images, labels = tf.train.batch([image, label],
                                    batch_size=BATCH_SIZE,
                                    num_threads=1,
                                    capacity=BATCH_SIZE,
                                    allow_smaller_final_batch=True)
    # <<<<<<<

    # >>>>>>> MODEL
    with tf.variable_scope('net'):
        with tf.variable_scope('params') as params:
            pass
        x_ph = tf.placeholder(tf.float32, [None, 32, 32, 3])
        net = VQVAE(None, None, BETA, x_ph, K, D, _cifar10_arch, params, False)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Run!
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.graph.finalize()
    sess.run(init_op)
    net.load(sess, MODEL)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)
    ks = []
    ys = []
    # Stop/join the queue runners even if extraction fails; join re-raises
    # any error a runner thread reported.
    try:
        while not coord.should_stop():
            x, y = sess.run([images, labels])
            ks.append(sess.run(net.k, feed_dict={x_ph: x}))
            ys.append(y)
            print('.', end='', flush=True)
    except tf.errors.OutOfRangeError:
        print('Extracting Finished')
    finally:
        coord.request_stop()
        coord.join(threads)

    if not ks:
        raise ValueError('No batches were extracted from the input; nothing to save.')

    ks = np.concatenate(ks, axis=0)
    ys = np.concatenate(ys, axis=0)
    np.savez(os.path.join(os.path.dirname(MODEL), 'ks_ys.npz'), ks=ks, ys=ys)
Beispiel #5
0
def main(args: Namespace):
    """Reconstruct one CT volume with a trained VQ-VAE and write it as NRRD.

    Args:
        args: namespace providing ``dataset_path``, ``rescale_input``,
            ``ckpt_path`` and ``out_path``.
    """
    pl.seed_everything(42)

    scale_val = 1000

    print("- Loading dataloader")
    datamodule = CTDataModule(path=args.dataset_path,
                              train_frac=1,
                              batch_size=1,
                              num_workers=0,
                              rescale_input=args.rescale_input)
    datamodule.setup()
    train_loader = datamodule.train_dataloader()

    print("- Loading single CT sample")
    single_sample, _ = next(iter(train_loader))
    single_sample = single_sample.cuda()

    print("- Loading model weights")
    model = VQVAE.load_from_checkpoint(str(args.ckpt_path)).cuda()
    # Inference only: load_from_checkpoint does not switch off training-mode
    # behaviour (dropout, batch-norm statistic updates, ...).
    model.eval()

    print("- Performing forward pass")
    with torch.no_grad(), torch.cuda.amp.autocast():
        res, *_ = model(single_sample)
        res = torch.nn.functional.elu(res)

    res = res.squeeze().detach().cpu().numpy()
    # Undo the input scaling (presumably x = (HU + 1000) / 1000; confirm
    # against CTDataModule).
    res = res * scale_val - scale_val
    # ``np.int`` (an alias of builtin int) was removed in NumPy 1.24.
    res = np.rint(res).astype(int)

    print("- Writing to nrrd")
    nrrd.write(str(args.out_path), res, header={'spacings': (0.976, 0.976, 3)})

    print("- Done")
Beispiel #6
0
def parse_arguments():
    """Build the training command line and parse ``sys.argv``.

    The parser combines Lightning ``Trainer`` flags, the VQ-VAE model flags
    and the script's own options.  It then overrides several Trainer
    defaults (all GPUs, DDP, mixed precision, effectively unbounded epochs).
    """
    trainer_overrides = dict(
        gpus="-1",
        accelerator='ddp',
        benchmark=True,
        num_sanity_val_steps=0,
        precision=16,
        log_every_n_steps=50,
        val_check_interval=0.5,
        flush_logs_every_n_steps=100,
        weights_summary='full',
        max_epochs=int(1e5),
    )

    parser = VQVAE.add_model_specific_args(
        pl.Trainer.add_argparse_args(ArgumentParser()))

    parser.add_argument('--rescale-input', type=int, nargs='+')
    parser.add_argument("--batch-size", type=int)
    parser.add_argument("dataset_path", type=Path)

    parser.set_defaults(**trainer_overrides)
    return parser.parse_args()
Beispiel #7
0
def extract_k_rec_from_vqvae(dt_key):
    """Run the ImageNet VQ-VAE over vim-1 stimuli; save codes and reconstructions.

    Args:
        dt_key: ``'st'`` selects the ``stimTrn`` stimuli, any other value
            selects ``stimVal``.  It is also the suffix of both output files.

    Two HDF5 files are written to the ``extract_from_vqvae`` output directory:
    ``rec_from_vqvae_<dt_key>.hdf5`` holds dataset ``rec`` with shape
    (N, 128, 128, 3), from ``net.p_x_z``.  ``k_from_vqvae_<dt_key>.hdf5``
    holds dataset ``k`` with shape (N, 32, 32), from ``net.k``.
    """
    MODEL, K, D = ('models/imagenet/last.ckpt', 512, 128)
    out_dir = '/data1/home/guangjie/Data/vim1/exprimentData/extract_from_vqvae'

    with tf.variable_scope('net'):
        with tf.variable_scope('params') as params:
            pass
        x = tf.placeholder(tf.float32, [None, 128, 128, 3])
        net = VQVAE(None, None, 0.25, x, K, D, _imagenet_arch, params, False)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    sess = tf.Session(config=session_config)
    sess.graph.finalize()
    sess.run(init_op)
    net.load(sess, MODEL)

    stimulus_key = 'stimTrn' if dt_key == 'st' else 'stimVal'
    dataset = vim1_blur_stimuli_dataset(
        "/data1/home/guangjie/Data/vim-1/Stimuli.hdf5", stimulus_key, 3)
    loader = DataLoader(dataset, batch_size=10, shuffle=False, num_workers=1)

    os.makedirs(out_dir, exist_ok=True)
    rec_path = os.path.join(out_dir, 'rec_from_vqvae_{}.hdf5'.format(dt_key))
    k_path = os.path.join(out_dir, 'k_from_vqvae_{}.hdf5'.format(dt_key))
    with h5py.File(rec_path, 'w') as rec_file, h5py.File(k_path, 'w') as k_file:
        rec_out = rec_file.create_dataset('rec', shape=(len(dataset), 128, 128, 3))
        k_out = k_file.create_dataset('k', shape=(len(dataset), 32, 32))
        offset = 0
        for batch_no, batch in enumerate(loader):
            codes, recon = sess.run((net.k, net.p_x_z), feed_dict={x: batch})
            stop = offset + len(recon)
            rec_out[offset:stop] = recon
            k_out[offset:stop] = codes
            offset = stop
            print(batch_no)
Beispiel #8
0
def rec_a_frame_img_from_ze(latentPath, savePath):
    """Decode stored VQ-VAE encoder outputs (z_e) into uint8 images.

    Args:
        latentPath: path handed to ``vqvae_ze_dataset`` (presumably an HDF5
            file of z_e latents — confirm against that class).
        savePath: output HDF5 file.  Its dataset ``rec`` has shape
            (N, 128, 128, 3) and dtype uint8; the parent directory is created
            if needed.
    """
    # os.path.dirname returns '' for a bare file name; there is nothing to
    # create in that case (slicing at rfind('/') would yield savePath[:-1]).
    save_dir = os.path.dirname(savePath)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    MODEL, K, D = ('models/imagenet/last.ckpt', 512, 128)
    with tf.variable_scope('net'):
        with tf.variable_scope('params') as params:
            pass
        x = tf.placeholder(tf.float32, [None, 128, 128, 3])
        net = VQVAE(None, None, 0.25, x, K, D, _imagenet_arch, params, False)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.graph.finalize()
    sess.run(init_op)
    net.load(sess, MODEL)

    dataset = vqvae_ze_dataset(latentPath)
    dataloader = DataLoader(dataset,
                            batch_size=10,
                            shuffle=False,
                            num_workers=0)

    with h5py.File(savePath, 'w') as sf:
        rec_dataset = sf.create_dataset('rec',
                                        shape=(len(dataset), 128, 128, 3),
                                        dtype=np.uint8,
                                        chunks=True)
        begin_idx = 0
        for step, data in enumerate(dataloader):
            # TODO: z_e is fed here; feeding z_q directly gave even worse
            # results on the validation set.
            rec = sess.run(net.p_x_z, feed_dict={net.z_e: data})
            # Clip first: values outside [0, 1] would otherwise overflow
            # (wrap around) when cast to uint8.
            rec = (np.clip(rec, 0.0, 1.0) * 255.0).astype(np.uint8)
            end_idx = begin_idx + len(rec)
            rec_dataset[begin_idx:end_idx] = rec
            begin_idx = end_idx
            print(step)
Beispiel #9
0
def extract_z(MODEL, BATCH_SIZE, BETA, K, D, **kwargs):
    """Encode the MNIST training images into discrete VQ-VAE codes.

    Each flattened 784-pixel image is reshaped to 28x28x1 and bilinearly
    resized to 24x24 before encoding.  The codes (``ks``) and the training
    labels (``ys``) are saved to ``ks_ys.npz`` next to ``MODEL``.

    Args:
        MODEL: checkpoint path passed to ``net.load``.
        BATCH_SIZE: images per ``sess.run`` call.
        BETA, K, D: VQ-VAE hyper-parameters forwarded to ``VQVAE``.
        **kwargs: ignored; allows passing a larger config dict.
    """
    from tensorflow.examples.tutorials.mnist import input_data

    mnist = input_data.read_data_sets("datasets/mnist", one_hot=False)

    flat_images = tf.placeholder(tf.float32, [None, 784])
    as_grid = tf.reshape(flat_images, [-1, 28, 28, 1])
    resized = tf.image.resize_images(as_grid, (24, 24),
                                     method=tf.image.ResizeMethod.BILINEAR)

    with tf.variable_scope('net'):
        with tf.variable_scope('params') as params:
            pass
        net = VQVAE(None, None, BETA, resized, K, D, _mnist_arch, params,
                    False)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    sess = tf.Session(config=session_config)
    sess.graph.finalize()
    sess.run(init_op)
    net.load(sess, MODEL)

    all_images, all_labels = mnist.train.images, mnist.train.labels
    codes = [
        sess.run(net.k,
                 feed_dict={flat_images: all_images[start:start + BATCH_SIZE]})
        for start in tqdm(range(0, len(all_images), BATCH_SIZE))
    ]
    stacked_codes = np.concatenate(codes, axis=0)

    out_file = os.path.join(os.path.dirname(MODEL), 'ks_ys.npz')
    np.savez(out_file, ks=stacked_codes, ys=all_labels)
Beispiel #10
0
def extract_zq_from_vqvae(dt_key):
    """Run the ImageNet VQ-VAE encoder over vim-1 stimuli and save the latents.

    NOTE(review): despite the function name, this stores the continuous
    encoder output ``net.z_e`` (not a quantized z_q).  The output is dataset
    ``latent`` with shape (N, 32, 32, 128) in ``ze_from_vqvae_<dt_key>.hdf5``.

    Args:
        dt_key: ``'st'`` selects the ``stimTrn`` stimuli, any other value
            selects ``stimVal``.  It is also the output file-name suffix.
    """
    MODEL, K, D = ('models/imagenet/last.ckpt', 512, 128)
    with tf.variable_scope('net'):
        with tf.variable_scope('params') as params:
            pass
        x = tf.placeholder(tf.float32, [None, 128, 128, 3])
        net = VQVAE(None, None, 0.25, x, K, D, _imagenet_arch, params, False)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.graph.finalize()
    sess.run(init_op)
    net.load(sess, MODEL)

    dtKey = 'stimTrn' if dt_key == 'st' else 'stimVal'
    # NOTE(review): the input path ends in .mat; confirm
    # vim1_blur_stimuli_dataset reads this format.
    dataset = vim1_blur_stimuli_dataset(
        "/data1/home/guangjie/Data/vim-1/Stimuli.mat", dtKey, 3)
    dataloader = DataLoader(dataset,
                            batch_size=10,
                            shuffle=False,
                            num_workers=1)

    out_dir = "/data1/home/guangjie/Data/vim1/exprimentData/extract_from_vqvae"
    # h5py.File(..., 'w') does not create missing parent directories.
    os.makedirs(out_dir, exist_ok=True)
    with h5py.File(
            os.path.join(out_dir, "ze_from_vqvae_{}.hdf5".format(dt_key)),
            'w') as sf:
        ze_dataset = sf.create_dataset('latent',
                                       shape=(len(dataset), 32, 32, 128))
        begin_idx = 0
        for step, data in enumerate(dataloader):
            ze = sess.run(net.z_e, feed_dict={x: data})
            end_idx = begin_idx + len(ze)
            ze_dataset[begin_idx:end_idx] = ze
            begin_idx = end_idx
            print(step)
Beispiel #11
0
def main(args):
    """Train a VQ-VAE on CT volumes with PyTorch Lightning.

    Seeds all RNGs with 42, then fits the model.  A checkpoint callback keeps
    the single best checkpoint by ``val_recon_loss_mean`` plus the most
    recent one.

    Args:
        args: namespace with ``dataset_path``, ``batch_size``,
            ``rescale_input``, the model options and Lightning Trainer flags.
    """
    torch.cuda.empty_cache()
    pl.trainer.seed_everything(seed=42)

    datamodule = CTDataModule(
        path=args.dataset_path,
        batch_size=args.batch_size,
        num_workers=5,
        rescale_input=args.rescale_input,
    )
    model = VQVAE(args)

    callbacks = [
        pl.callbacks.ModelCheckpoint(monitor='val_recon_loss_mean',
                                     save_top_k=1,
                                     save_last=True),
    ]
    trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks)
    trainer.fit(model, datamodule=datamodule)
Beispiel #12
0
def main():
    """Train a VQ-VAE (plain or Gumbel-softmax quantizer) on CIFAR-10.

    Parses model, data and Lightning Trainer options from the command line.
    The run keeps the checkpoint with the lowest ``val_recon_error``.  At the
    start of every training epoch, ``pl_module.quantizer.temperature`` is set
    by a linear schedule: 1.0 at epoch 0 down to 0.1 at epoch 30, constant
    afterwards.
    """
    pl.seed_everything(1337)

    parser = ArgumentParser()
    # model related
    parser.add_argument("--vq_flavor", type=str, default='vqvae',
                        choices=['vqvae', 'gumbel'])
    # data related
    parser.add_argument("--data_dir", type=str,
                        default='/apcv/users/akarpathy/cifar10')
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--num_workers", type=int, default=8)
    parser = pl.Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    data = CIFAR10Data(args)
    model = VQVAE(args)

    def linear_temperature(epoch, first_epoch=0, last_epoch=30,
                           start_temp=1.0, end_temp=0.1):
        """Interpolate linearly from start_temp to end_temp, clamped at both ends."""
        progress = max(0, min(1, (epoch - first_epoch) / (last_epoch - first_epoch)))
        # probably should be exponential instead
        return progress * end_temp + (1 - progress) * start_temp

    class DecayTemperature(pl.Callback):
        def on_train_epoch_start(self, trainer, pl_module):
            epoch = trainer.current_epoch
            temperature = linear_temperature(epoch)
            print("epoch %d setting temperature of model's quantizer to %f" %
                  (epoch, temperature))
            pl_module.quantizer.temperature = temperature

    best_checkpoint = ModelCheckpoint(monitor='val_recon_error', mode='min')
    trainer = pl.Trainer.from_argparse_args(
        args, callbacks=[best_checkpoint, DecayTemperature()])
    trainer.fit(model, data)
Beispiel #13
0
def main(args: Namespace):
    """Decode stored pairs of code maps with a VQ-VAE and write them as NRRD.

    ``args.db_path`` is loaded with ``torch.load``.  For each entry of
    ``db[0]``, its ``'condition'`` selects the paired entry of ``db[1]``.  The
    ``'data'`` codes of both entries are embedded by the encoder's quantizers,
    decoded, rescaled and written to
    ``<out_path>_<success|failure>_<key1>_<key0>.nrrd``.

    Args:
        args: namespace providing ``ckpt_path``, ``db_path`` and ``out_path``.
    """
    scale_val = 1000

    print("- Loading model weights")
    model = VQVAE.load_from_checkpoint(str(args.ckpt_path)).cuda()
    # Inference only: switch off training-mode behaviour (dropout,
    # batch-norm statistic updates, ...).
    model.eval()

    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
    db = torch.load(args.db_path)

    # Decoding needs no gradients; avoid building an autograd graph.
    with torch.no_grad():
        for embedding_0_key, embedding_0 in db[0].items():
            embedding_1_key = embedding_0['condition']
            embedding_1 = db[1][embedding_1_key]

            # issue where the pixelcnn samples 0's
            success = 'failure' if torch.all(
                embedding_0['data'][-1] == 0) else 'success'

            # Move the channel axis to position 1 after embedding the codes.
            embeddings = [
                quantizer.embed_code(embedding['data'].cuda().unsqueeze(
                    dim=0)).permute(0, 4, 1, 2, 3)
                for embedding, quantizer in zip((
                    embedding_0, embedding_1), model.encoder.quantize)
            ]

            print("- Performing forward pass")
            with torch.cuda.amp.autocast():
                res = model.decode(embeddings)
                res = torch.nn.functional.elu(res)

            res = res.squeeze().detach().cpu().numpy()
            # Undo the input scaling (presumably x = (HU + 1000) / 1000;
            # confirm against the data module).
            res = res * scale_val - scale_val
            # ``np.int`` (an alias of builtin int) was removed in NumPy 1.24.
            res = np.rint(res).astype(int)

            print("- Writing to nrrd")
            nrrd.write(
                str(args.out_path) +
                f'_{success}_{str(embedding_1_key)}_{str(embedding_0_key)}.nrrd',
                res,
                header={'spacings': (0.976, 0.976, 3)})

            print("- Done")
def main(args: Namespace):
    """Compute per-batch SSIM between CT volumes and their VQ-VAE reconstructions.

    Both the validation and training loaders are scored.  Each split uses its
    own ``SSIM3DSlices`` instance.  The script then stops in the debugger so
    ``val_ssims`` / ``train_ssims`` can be inspected manually.

    Args:
        args: namespace providing ``dataset_path`` and ``ckpt_path``.
    """
    # Same seed as used in train.py, so that train/val splits are also the same
    pl.trainer.seed_everything(seed=42)

    print("- Loading datamodule")
    datamodule = CTDataModule(path=args.dataset_path,
                              batch_size=5,
                              num_workers=5)  # mypy: ignore
    datamodule.setup()

    train_dl = datamodule.train_dataloader()
    val_dl = datamodule.val_dataloader()

    print("- Loading model weights")
    model = VQVAE.load_from_checkpoint(str(args.ckpt_path)).cuda()
    # Inference only: switch off training-mode behaviour (dropout,
    # batch-norm statistic updates, ...).
    model.eval()

    data_min, data_max = -0.24, 4
    data_range = data_max - data_min

    train_ssim = SSIM3DSlices(data_range=data_range)
    val_ssim = SSIM3DSlices(data_range=data_range)

    def batch_ssim(batch, ssim_f):
        """Reconstruct ``batch`` with the model and score it with ``ssim_f``."""
        batch = batch.cuda()
        out, *_ = model(batch)
        out = F.elu(out)
        # Use the metric passed in; previously val_ssim was always used,
        # so training batches were scored (and accumulated) by the val metric.
        return ssim_f(out.float(), batch)

    with torch.no_grad(), torch.cuda.amp.autocast():
        val_ssims = torch.Tensor(
            [batch_ssim(batch, ssim_f=val_ssim) for batch, _ in tqdm(val_dl)])
        # NOTE(review): interactive pause before the training pass — possibly
        # a debugging leftover; confirm whether it is still wanted.
        breakpoint()
        train_ssims = torch.Tensor([
            batch_ssim(batch, ssim_f=train_ssim) for batch, _ in tqdm(train_dl)
        ])

    # breakpoint for manual decision what to do with train_ssims/val_ssims
    # TODO: find some better solution to described above
    breakpoint()
Beispiel #15
0
def train_prior(config, RANDOM_SEED, MODEL, TRAIN_NUM, BATCH_SIZE,
                LEARNING_RATE, DECAY_VAL, DECAY_STEPS, DECAY_STAIRCASE,
                GRAD_CLIP, K, D, BETA, NUM_LAYERS, NUM_FEATURE_MAPS,
                SUMMARY_PERIOD, SAVE_PERIOD, **kwargs):
    """Train a label-conditioned PixelCNN prior over a frozen VQ-VAE's codes.

    ImageNet batches are encoded by the VQ-VAE restored from ``MODEL``.  The
    code maps ``vq_net.k`` and their labels are fed to the PixelCNN.
    Checkpoints, scalar summaries and sampled images are written to
    ``<dirname(MODEL)>/pixelcnn``.

    Args:
        config: experiment configuration; only ``config.as_matrix()`` is used,
            to log the settings as a text summary.
        RANDOM_SEED: seed for NumPy and TensorFlow.
        MODEL: VQ-VAE checkpoint path.
        TRAIN_NUM: number of training iterations.
        SUMMARY_PERIOD, SAVE_PERIOD: logging / checkpoint-and-sample periods
            in global steps.
        Remaining arguments are forwarded to the learning-rate schedule,
        ``VQVAE`` and ``PixelCNN``; ``**kwargs`` is ignored.
    """
    np.random.seed(RANDOM_SEED)
    tf.set_random_seed(RANDOM_SEED)
    LOG_DIR = os.path.join(os.path.dirname(MODEL), 'pixelcnn')
    # >>>>>>> DATASET
    train_dataset = imagenet.get_split('train', 'datasets/ILSVRC2012')
    ims, labels = _build_batch(train_dataset, BATCH_SIZE, 4)
    # <<<<<<<

    # >>>>>>> MODEL for Generate Images
    with tf.variable_scope('net'):
        with tf.variable_scope('params') as params:
            pass
        vq_net = VQVAE(None, None, BETA, ims, K, D, _imagenet_arch, params,
                       False)
    # <<<<<<<

    # >>>>>> MODEL for Training Prior
    with tf.variable_scope('pixelcnn'):
        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(LEARNING_RATE,
                                                   global_step,
                                                   DECAY_STEPS,
                                                   DECAY_VAL,
                                                   staircase=DECAY_STAIRCASE)
        tf.summary.scalar('lr', learning_rate)

        net = PixelCNN(learning_rate, global_step, GRAD_CLIP,
                       vq_net.k.get_shape()[1], vq_net.embeds, K, D, 1000,
                       NUM_LAYERS, NUM_FEATURE_MAPS)
    # <<<<<<
    with tf.variable_scope('misc'):
        # Summary Operations
        tf.summary.scalar('loss', net.loss)
        summary_op = tf.summary.merge_all()

        # Initialize op
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        config_summary = tf.summary.text('TrainConfig',
                                         tf.convert_to_tensor(
                                             config.as_matrix()),
                                         collections=[])

        sample_images = tf.placeholder(tf.float32, [None, 128, 128, 3])
        sample_summary_op = tf.summary.image('samples',
                                             sample_images,
                                             max_outputs=20)

    # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Run!
    # Separate name so the ``config`` argument is not shadowed.
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess = tf.Session(config=sess_config)
    sess.graph.finalize()
    sess.run(init_op)
    vq_net.load(sess, MODEL)

    summary_writer = tf.summary.FileWriter(LOG_DIR, sess.graph)
    summary_writer.add_summary(config_summary.eval(session=sess))

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)
    try:
        # ``range`` instead of the Python-2-only ``xrange``.
        for step in tqdm(range(TRAIN_NUM), dynamic_ncols=True):
            batch_xs, batch_ys = sess.run([vq_net.k, labels])
            it, loss, _ = sess.run([global_step, net.loss, net.train_op],
                                   feed_dict={
                                       net.X: batch_xs,
                                       net.h: batch_ys
                                   })

            if it % SAVE_PERIOD == 0:
                net.save(sess, LOG_DIR, step=it)
                # Sample 10 code maps for random labels and decode them.
                sampled_zs, log_probs = net.sample_from_prior(
                    sess, np.random.randint(0, 1000, size=(10, )), 2)
                sampled_ims = sess.run(vq_net.gen,
                                       feed_dict={vq_net.latent: sampled_zs})
                summary_writer.add_summary(
                    sess.run(sample_summary_op,
                             feed_dict={sample_images: sampled_ims}), it)

            if it % SUMMARY_PERIOD == 0:
                tqdm.write('[%5d] Loss: %1.3f' % (it, loss))
                summary = sess.run(summary_op,
                                   feed_dict={
                                       net.X: batch_xs,
                                       net.h: batch_ys
                                   })
                summary_writer.add_summary(summary, it)

    except Exception as e:
        # Recorded on the coordinator; coord.join below re-raises it.
        coord.request_stop(e)
    finally:
        # Stop the input threads even if the final save fails.
        try:
            net.save(sess, LOG_DIR)
        finally:
            coord.request_stop()
            coord.join(threads)
Beispiel #16
0
def _fresh_metric_means():
    """Return zero-initialised running means for the tracked per-epoch metrics."""
    return dict.fromkeys(("logp", "vq_loss", "elbo", "bpd", "perplexity"), 0)


def _update_running_means(means, count, **values):
    """Fold one batch's scalar metrics into the running means in ``means``.

    ``count`` is the 1-based index of the batch.  After the update, each entry
    is the mean of the first ``count`` values seen for that metric.
    """
    for name, value in values.items():
        means[name] += (value - means[name]) / count


def _log_epoch_metrics(writer, split, means, kl, epoch):
    """Write one split's epoch-averaged metrics to TensorBoard as ``<metric>/<split>``."""
    writer.add_scalar("logp/{}".format(split), means["logp"], epoch)
    writer.add_scalar("kl/{}".format(split), kl, epoch)
    writer.add_scalar("vqloss/{}".format(split), means["vq_loss"], epoch)
    writer.add_scalar("elbo/{}".format(split), means["elbo"], epoch)
    writer.add_scalar("bpd/{}".format(split), means["bpd"], epoch)
    writer.add_scalar("perplexity/{}".format(split), means["perplexity"], epoch)


def train_vqvae(args):
    """Train a VQ-VAE on CIFAR-10 and log per-epoch train/test metrics.

    A checkpoint is written via ``save_checkpoint`` every 25000 optimisation
    steps into a directory named after the model configuration.  When
    ``args.resume`` is set, model/optimizer state and the step counter are
    restored first.  Metrics go to TensorBoard under ``runs/<model_name>``.

    Args:
        args: namespace providing ``model``, ``channels``, ``latent_dim``,
            ``num_embeddings``, ``embedding_dim``, ``learning_rate``,
            ``resume``, ``batch_size``, ``num_workers`` and
            ``num_training_steps``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = VQVAE(args.channels, args.latent_dim, args.num_embeddings,
                  args.embedding_dim)
    model.to(device)

    model_name = "{}_C_{}_N_{}_M_{}_D_{}".format(args.model, args.channels,
                                                 args.latent_dim,
                                                 args.num_embeddings,
                                                 args.embedding_dim)

    checkpoint_dir = Path(model_name)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

    if args.resume is not None:
        print("Resume checkpoint from: {}:".format(args.resume))
        checkpoint = torch.load(args.resume,
                                map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        global_step = checkpoint["step"]
    else:
        global_step = 0

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(shift)])
    training_dataset = datasets.CIFAR10("./CIFAR10",
                                        train=True,
                                        download=True,
                                        transform=transform)

    test_dataset = datasets.CIFAR10("./CIFAR10",
                                    train=False,
                                    download=True,
                                    transform=transform)

    training_dataloader = DataLoader(training_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers,
                                     pin_memory=True)

    test_dataloader = DataLoader(test_dataset,
                                 batch_size=64,
                                 shuffle=True,
                                 drop_last=True,
                                 num_workers=args.num_workers,
                                 pin_memory=True)

    num_epochs = args.num_training_steps // len(training_dataloader) + 1
    start_epoch = global_step // len(training_dataloader) + 1

    N = 3 * 32 * 32  # sub-pixels per CIFAR-10 image
    # Constant KL term: log(num_embeddings) per latent position, for
    # latent_dim x 8 x 8 positions (presumably a uniform prior over codes).
    KL = args.latent_dim * 8 * 8 * np.log(args.num_embeddings)

    writer = SummaryWriter(log_dir=Path("runs") / model_name)
    # Close the writer so buffered events are flushed even on failure.
    try:
        for epoch in range(start_epoch, num_epochs + 1):
            model.train()
            train_means = _fresh_metric_means()
            for i, (images, _) in enumerate(tqdm(training_dataloader), 1):
                images = images.to(device)

                dist, vq_loss, perplexity = model(images)
                # Back to integer pixel values for the categorical likelihood
                # (assumes ``shift`` subtracted 0.5 from [0, 1] inputs).
                targets = ((images + 0.5) * 255).long()
                logp = dist.log_prob(targets).sum((1, 2, 3)).mean()
                loss = -logp / N + vq_loss
                elbo = (KL - logp) / N
                bpd = elbo / np.log(2)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                global_step += 1

                if global_step % 25000 == 0:
                    save_checkpoint(model, optimizer, global_step, checkpoint_dir)

                _update_running_means(train_means, i,
                                      logp=logp.item(),
                                      vq_loss=vq_loss.item(),
                                      elbo=elbo.item(),
                                      bpd=bpd.item(),
                                      perplexity=perplexity.item())

            _log_epoch_metrics(writer, "train", train_means, KL, epoch)

            model.eval()
            test_means = _fresh_metric_means()
            for i, (images, _) in enumerate(test_dataloader, 1):
                images = images.to(device)

                with torch.no_grad():
                    dist, vq_loss, perplexity = model(images)

                targets = ((images + 0.5) * 255).long()
                logp = dist.log_prob(targets).sum((1, 2, 3)).mean()
                elbo = (KL - logp) / N
                bpd = elbo / np.log(2)

                _update_running_means(test_means, i,
                                      logp=logp.item(),
                                      vq_loss=vq_loss.item(),
                                      elbo=elbo.item(),
                                      bpd=bpd.item(),
                                      perplexity=perplexity.item())

            _log_epoch_metrics(writer, "test", test_means, KL, epoch)

            # Most likely pixel values for the last test batch.
            samples = torch.argmax(dist.logits, dim=-1)
            grid = utils.make_grid(samples.float() / 255)
            writer.add_image("reconstructions", grid, epoch)

            print(
                "epoch:{}, logp:{:.3E}, vq loss:{:.3E}, elbo:{:.3f}, bpd:{:.3f}, perplexity:{:.3f}"
                .format(epoch, test_means["logp"], test_means["vq_loss"],
                        test_means["elbo"], test_means["bpd"],
                        test_means["perplexity"]))
    finally:
        writer.close()
Beispiel #17
0
def main(args):
    """Train the VQ-VAE on CelebA, saving reconstructions and a checkpoint.

    Args:
        args: parsed command-line namespace. Attributes read here:
            ``save_dir``, ``img_dir``, ``log_dir``, ``model_type``,
            ``random_seed``, ``batch_size``, ``nb_epoch``, ``print_step``,
            ``anneal``, ``anneal_start`` and ``save_epoch`` (``VQVAE`` may
            read more).

    Raises:
        ValueError: if no ``*.jpg`` files are found under
            ``data/img_align_celeba``.

    Side effects: creates the save/log/image directories, writes TensorBoard
    events to the log directory, image grids to the image directory and
    ``model.ckpt`` to the save directory.
    """
    save_dir = os.path.join(args.save_dir, args.model_type)
    img_dir = os.path.join(args.img_dir, args.model_type)
    log_dir = os.path.join(args.log_dir, args.model_type)

    # os.makedirs(exist_ok=True) is Python-3 only; keep the explicit check so
    # this still runs under Python 2.
    for directory in (save_dir, log_dir, img_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    summary_writer = tf.summary.FileWriter(log_dir)
    config_proto = utils.get_config_proto()

    sess = tf.Session(config=config_proto)
    model = VQVAE(args, sess, name="vqvae")

    # glob returns paths in arbitrary, filesystem-dependent order; sort them so
    # the seeded split below is identical on every machine.
    img_paths = sorted(glob.glob('data/img_align_celeba/*.jpg'))
    if not img_paths:
        raise ValueError("no *.jpg images found in data/img_align_celeba")
    # Only the 90% training split is consumed below.
    train_paths, _test_paths = train_test_split(img_paths,
                                                test_size=0.1,
                                                random_state=args.random_seed)
    celeba = utils.DiskImageData(sess,
                                 train_paths,
                                 args.batch_size,
                                 shape=[218, 178, 3])
    total_batch = celeba.num_examples // args.batch_size

    for epoch in range(1, args.nb_epoch + 1):
        print("Epoch %d start with learning rate %f" % (
            epoch, model.learning_rate.eval(sess)))
        print("- " * 50)
        step_start_time = time.time()
        for i in range(1, total_batch + 1):
            x_batch = celeba.next_batch()

            # model.train's result already carries global_step, so no separate
            # sess.run(model.global_step) is issued per step.
            _, loss, rec_loss, vq, commit, global_step, summaries = model.train(
                x_batch)
            summary_writer.add_summary(summaries, global_step)

            if i % args.print_step == 0:
                print("epoch %d, step %d, loss %f, rec_loss %f, vq_loss %f, commit_loss %f, time %.2fs"
                      % (epoch, global_step, loss, rec_loss, vq, commit,
                         time.time() - step_start_time))
                step_start_time = time.time()

        if args.anneal and epoch >= args.anneal_start:
            sess.run(model.lr_decay_op)

        if epoch % args.save_epoch == 0:
            x_batch = celeba.next_batch()
            x_recon = model.reconstruct(x_batch)
            # NOTE(review): the 10x10 grid presumably assumes batch_size == 100.
            utils.save_images(x_batch, [10, 10],
                              os.path.join(img_dir, "rawImage%s.jpg" % epoch))
            utils.save_images(
                x_recon, [10, 10],
                os.path.join(img_dir, "reconstruct%s.jpg" % epoch))

    model.saver.save(sess, os.path.join(save_dir, "model.ckpt"))
    # FileWriter buffers events; close it so the last summaries reach disk.
    summary_writer.close()
    sess.close()
    print("Model stored....")
Beispiel #18
0
# with h5py.File("/data1/home/guangjie/Data/vim-2-gallant/orig/Stimuli.mat", 'r') as f:
# st0 = f['st'][:500:50, :, :, :].transpose((0, 3, 2, 1)) / 255.0  # shape = (108000,3,128,128)
# st1 = f['st'][500:1000:50, :, :, :].transpose((0, 3, 2, 1)) / 255.0
# st0 = st0[np.newaxis, :]
# st = f['st'][0]
# extend_st = st[np.newaxis, :]

# slices = tf.data.Dataset.from_tensor_slices(extend_st)
# next_item = slices.make_one_shot_iterator().get_next() #todo

# Build a VQ-VAE inference graph over 128x128 RGB inputs under the 'net'
# variable scope, then create a session, initialise it and load MODEL.
# NOTE(review): the 'net'/'params' scope names presumably have to match the
# names used when MODEL was trained — confirm.
with tf.variable_scope('net'):
    # Empty scope captured only so it can be handed to VQVAE as its
    # parameter scope.
    with tf.variable_scope('params') as params:
        pass
    x = tf.placeholder(tf.float32, [None, 128, 128, 3])
    # Positional order matches the other VQVAE calls in this file:
    # (learning_rate, global_step, beta, input, K, D, arch, params, flag);
    # the trailing False is presumably an is-training flag — confirm.
    net = VQVAE(None, None, 0.25, x, K, D, _imagenet_arch, params, False)

init_op = tf.group(tf.global_variables_initializer(),
                   tf.local_variables_initializer())
config = tf.ConfigProto()
# Let TensorFlow grow GPU memory on demand instead of reserving it all.
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
# Make the graph read-only: any op created after this point raises.
sess.graph.finalize()
sess.run(init_op)
# Presumably restores the trained weights from the checkpoint path MODEL,
# overriding the fresh initialisation above — confirm in VQVAE.load.
net.load(sess, MODEL)


def draw(images):
    from matplotlib import pyplot as plt
    fig = plt.figure(figsize=(20, 20))
    for n, image in enumerate(images):
Beispiel #19
0
from torch import nn, optim
import torch
from tqdm import tqdm

from model import VQVAE

# CLI options and the image transform come from the project-local config module.
args = config.get_args()
transform = config.get_transform()

# ImageFolder dataset rooted at args.path; the training loop below discards
# the labels (`for i, (img, _) in ...`).
dataset = datasets.ImageFolder(args.path, transform=transform)
# num_workers=0: batches are loaded in the main process.
# NOTE(review): the epoch loop below rebinds `loader = tqdm(loader)` every
# epoch, so from the second epoch on tqdm wraps an earlier tqdm wrapper.
loader = DataLoader(dataset,
                    batch_size=args.batch,
                    shuffle=True,
                    num_workers=0)

model = VQVAE()
model = model.cuda()

# Mean-squared-error criterion (presumably the reconstruction loss — confirm
# in the truncated training loop).
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=args.lr)

from torch.autograd import Variable

for epoch in range(args.epoch):

    loader = tqdm(loader)

    for i, (img, _) in enumerate(loader):
        img = img.cuda()

        #generate the attention regions for the images
Beispiel #20
0
def main(config, RANDOM_SEED, LOG_DIR, TRAIN_NUM, BATCH_SIZE, LEARNING_RATE,
         DECAY_VAL, DECAY_STEPS, DECAY_STAIRCASE, BETA, K, D, SAVE_PERIOD,
         SUMMARY_PERIOD, **kwargs):
    """Train a VQ-VAE on MNIST digits resized to 24x24, logging to TensorBoard.

    Args:
        config: training configuration object; ``config.as_matrix()`` is
            written once to TensorBoard as a text summary.
        RANDOM_SEED: seed for NumPy and the TensorFlow graph-level RNG.
        LOG_DIR: directory for checkpoints and TensorBoard events.
        TRAIN_NUM: number of optimisation steps to run.
        BATCH_SIZE: mini-batch size for both training and validation batches.
        LEARNING_RATE, DECAY_VAL, DECAY_STEPS, DECAY_STAIRCASE: arguments of
            ``tf.train.exponential_decay`` for the learning-rate schedule.
        BETA: forwarded to ``VQVAE`` and used to scale the logged ``commit``
            term (presumably the commitment-loss weight).
        K, D: forwarded to ``VQVAE`` (presumably codebook size and embedding
            dimension — confirm).
        SAVE_PERIOD: checkpoint every this many global steps.
        SUMMARY_PERIOD: write training summaries every this many global
            steps; validation summaries every ``2 * SUMMARY_PERIOD``.
        **kwargs: ignored, so callers can pass a larger config mapping.
    """
    np.random.seed(RANDOM_SEED)
    tf.set_random_seed(RANDOM_SEED)

    # >>>>>>> DATASET
    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets("datasets/mnist", one_hot=False)
    # <<<<<<<

    # >>>>>>> MODEL
    # Flat 784-pixel vectors are reshaped to 28x28x1 and bilinearly resized to
    # 24x24 before they reach the network.
    x = tf.placeholder(tf.float32, [None, 784])
    resized = tf.image.resize_images(tf.reshape(x, [-1, 28, 28, 1]), (24, 24),
                                     method=tf.image.ResizeMethod.BILINEAR)

    with tf.variable_scope('train'):
        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(LEARNING_RATE,
                                                   global_step,
                                                   DECAY_STEPS,
                                                   DECAY_VAL,
                                                   staircase=DECAY_STAIRCASE)
        tf.summary.scalar('lr', learning_rate)

        # Empty scope captured so the validation net can reuse its variables.
        with tf.variable_scope('params') as params:
            pass
        net = VQVAE(learning_rate, global_step, BETA, resized, K, D,
                    _mnist_arch, params, True)

    with tf.variable_scope('valid'):
        # Share the training net's weights instead of creating new ones.
        params.reuse_variables()
        valid_net = VQVAE(None, None, BETA, resized, K, D, _mnist_arch, params,
                          False)

    with tf.variable_scope('misc'):
        # Summary Operations
        tf.summary.scalar('loss', net.loss)
        tf.summary.scalar('recon', net.recon)
        tf.summary.scalar('vq', net.vq)
        tf.summary.scalar('commit', BETA * net.commit)
        tf.summary.image('origin', resized, max_outputs=4)
        tf.summary.image('recon', net.p_x_z, max_outputs=4)
        # TODO: log-likelihood

        summary_op = tf.summary.merge_all()

        # Initialize op
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        # collections=[] keeps the config text out of the default summary
        # collection; it is written once, explicitly, below.
        config_summary = tf.summary.text('TrainConfig',
                                         tf.convert_to_tensor(
                                             config.as_matrix()),
                                         collections=[])

        extended_summary_op = tf.summary.merge([
            tf.summary.scalar('valid_loss', valid_net.loss),
            tf.summary.scalar('valid_recon', valid_net.recon),
            tf.summary.scalar('valid_vq', valid_net.vq),
            tf.summary.scalar('valid_commit', BETA * valid_net.commit),
            tf.summary.image('valid_recon', valid_net.p_x_z, max_outputs=10),
        ])

    # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Run!
    # Named sess_config so it does not shadow the ``config`` argument.
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess = tf.Session(config=sess_config)
    sess.graph.finalize()
    sess.run(init_op)

    summary_writer = tf.summary.FileWriter(LOG_DIR, sess.graph)
    summary_writer.add_summary(config_summary.eval(session=sess))

    # range (not the Python-2-only xrange) keeps this valid on Python 3.
    for _step in tqdm(range(TRAIN_NUM), dynamic_ncols=True):
        batch_xs, _ = mnist.train.next_batch(BATCH_SIZE)
        it, loss, _ = sess.run([global_step, net.loss, net.train_op],
                               feed_dict={x: batch_xs})

        if it % SAVE_PERIOD == 0:
            net.save(sess, LOG_DIR, step=it)

        if it % SUMMARY_PERIOD == 0:
            tqdm.write('[%5d] Loss: %1.3f' % (it, loss))
            summary = sess.run(summary_op, feed_dict={x: batch_xs})
            summary_writer.add_summary(summary, it)

        if it % (SUMMARY_PERIOD * 2) == 0:  # Extended (validation) summary
            batch_xs, _ = mnist.test.next_batch(BATCH_SIZE)
            summary = sess.run(extended_summary_op, feed_dict={x: batch_xs})
            summary_writer.add_summary(summary, it)

    net.save(sess, LOG_DIR)
    # FileWriter buffers events; close it so the final summaries reach disk.
    summary_writer.close()
Beispiel #21
0
# Decoder: project to a 7x7x32 map, then two stride-2 'same' transposed
# convolutions upsample 7x7 -> 14x14 -> 28x28 with a single output channel.
decoder = tfk.Sequential([
    tfkl.Flatten(),  # remove extra dim
    tfkl.Dense(units=7 * 7 * 32, activation='relu'),
    tfkl.Reshape(target_shape=(7, 7, 32)),
    tfkl.Conv2DTranspose(filters=16,
                         kernel_size=3,
                         strides=2,
                         padding='same',
                         activation='relu'),
    tfkl.Conv2DTranspose(filters=1, kernel_size=3, strides=2,
                         padding='same'),  # no activation
])

# Define model
model = VQVAE(encoder, decoder, codebook_size=32)
model.compile(optimizer='adam', loss='mse')

# Callbacks
# Named run_timestamp rather than ``time``: a module-level ``time`` string
# would shadow the stdlib ``time`` module for any later ``time.time()`` call
# in this module.
run_timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
log_dir = os.path.join('.', 'logs', 'vqvae', run_timestamp)
tensorboard_clbk = tfk.callbacks.TensorBoard(log_dir=log_dir)
plot_clbk = PlotReconstructionCallback(logdir=log_dir, test_ds=test_ds, nex=4)
callbacks = [tensorboard_clbk, plot_clbk]

# Fit
model.fit(train_ds,
          validation_data=test_ds,
          epochs=EPOCHS,
          callbacks=callbacks)
Beispiel #22
0
def train_prior(config, RANDOM_SEED, MODEL, TRAIN_NUM, BATCH_SIZE,
                LEARNING_RATE, DECAY_VAL, DECAY_STEPS, DECAY_STAIRCASE,
                GRAD_CLIP, K, D, BETA, NUM_LAYERS, NUM_FEATURE_MAPS,
                SUMMARY_PERIOD, SAVE_PERIOD, **kwargs):
    """Train a PixelCNN prior over latent codes of a trained CIFAR-10 VQ-VAE.

    Latent codes ``ks`` and conditioning values ``ys`` (fed to ``net.h``) are
    read from ``ks_ys.npz`` in the directory of the VQ-VAE checkpoint
    ``MODEL``. Checkpoints and TensorBoard events go to
    ``<dirname(MODEL)>/pixelcnn_6``. Every ``2 * SUMMARY_PERIOD`` steps,
    latents sampled from the prior are decoded by the VQ-VAE and logged as
    images.

    Args:
        config: training configuration; ``config.as_matrix()`` is logged as
            a TensorBoard text summary.
        RANDOM_SEED: seed for NumPy and the TensorFlow graph-level RNG.
        MODEL: checkpoint path of the trained VQ-VAE (restored with
            ``vq_net.load``).
        TRAIN_NUM: number of optimisation steps.
        BATCH_SIZE: mini-batch size of latent codes.
        LEARNING_RATE, DECAY_VAL, DECAY_STEPS, DECAY_STAIRCASE: arguments of
            ``tf.train.exponential_decay``.
        GRAD_CLIP: forwarded to ``PixelCNN`` (presumably a gradient-clipping
            threshold — confirm).
        K, D, BETA: VQ-VAE hyper-parameters (presumably they must match the
            checkpoint in ``MODEL``).
        NUM_LAYERS, NUM_FEATURE_MAPS: PixelCNN architecture sizes.
        SUMMARY_PERIOD: write scalar summaries every this many global steps.
        SAVE_PERIOD: checkpoint the prior every this many global steps.
        **kwargs: ignored.
    """
    np.random.seed(RANDOM_SEED)
    tf.set_random_seed(RANDOM_SEED)
    LOG_DIR = os.path.join(os.path.dirname(MODEL), 'pixelcnn_6')

    # >>>>>>> DATASET
    class Latents(object):
        """Train/validation split of pre-extracted latent codes.

        The first ``validation_size`` rows become the validation set and the
        remaining rows the training set.
        """

        def __init__(self, path, validation_size=1):
            from tensorflow.contrib.learn.python.learn.datasets.mnist import DataSet
            from tensorflow.contrib.learn.python.learn.datasets import base

            # Every NpzFile lookup re-reads its zip member, so fetch each array
            # once; the context manager closes the archive's file handle.
            with np.load(path) as data:
                ks = data['ks']
                ys = data['ys']

            # dtype won't bother even in the case when latent is int32 type.
            train = DataSet(ks[validation_size:],
                            ys[validation_size:],
                            reshape=False,
                            dtype=np.uint8,
                            one_hot=False)
            validation = DataSet(ks[:validation_size],
                                 ys[:validation_size],
                                 reshape=False,
                                 dtype=np.uint8,
                                 one_hot=False)
            # Length of ks along axis 1; passed to PixelCNN as the latent size.
            self.size = ks.shape[1]
            self.data = base.Datasets(train=train,
                                      validation=validation,
                                      test=None)

    latent = Latents(os.path.join(os.path.dirname(MODEL), 'ks_ys.npz'))
    # <<<<<<<

    # >>>>>>> MODEL for Generate Images
    with tf.variable_scope('net'):
        with tf.variable_scope('params') as params:
            pass
        # The VQ-VAE is only used for its embeddings and decoder here, so its
        # input placeholder is never fed.
        _not_used = tf.placeholder(tf.float32, [None, 32, 32, 3])
        vq_net = VQVAE(None, None, BETA, _not_used, K, D, _cifar10_arch,
                       params, False)
    # <<<<<<<

    # >>>>>> MODEL for Training Prior
    with tf.variable_scope('pixelcnn'):
        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(LEARNING_RATE,
                                                   global_step,
                                                   DECAY_STEPS,
                                                   DECAY_VAL,
                                                   staircase=DECAY_STAIRCASE)
        tf.summary.scalar('lr', learning_rate)

        net = PixelCNN(learning_rate, global_step, GRAD_CLIP, latent.size,
                       vq_net.embeds, K, D, 10, NUM_LAYERS, NUM_FEATURE_MAPS)
    # <<<<<<
    with tf.variable_scope('misc'):
        # Summary Operations
        tf.summary.scalar('loss', net.loss)
        summary_op = tf.summary.merge_all()

        # Initialize op
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        config_summary = tf.summary.text('TrainConfig',
                                         tf.convert_to_tensor(
                                             config.as_matrix()),
                                         collections=[])

        sample_images = tf.placeholder(tf.float32, [None, 32, 32, 3])
        sample_summary_op = tf.summary.image('samples',
                                             sample_images,
                                             max_outputs=20)

    # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Run!
    # Named sess_config so it does not shadow the ``config`` argument.
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess = tf.Session(config=sess_config)
    sess.graph.finalize()
    sess.run(init_op)
    vq_net.load(sess, MODEL)

    summary_writer = tf.summary.FileWriter(LOG_DIR, sess.graph)
    summary_writer.add_summary(config_summary.eval(session=sess))

    # range (not the Python-2-only xrange) keeps this valid on Python 3.
    for _step in tqdm(range(TRAIN_NUM), dynamic_ncols=True):
        batch_xs, batch_ys = latent.data.train.next_batch(BATCH_SIZE)
        feed = {net.X: batch_xs, net.h: batch_ys}
        it, loss, _ = sess.run([global_step, net.loss, net.train_op],
                               feed_dict=feed)

        if it % SAVE_PERIOD == 0:
            net.save(sess, LOG_DIR, step=it)

        if it % SUMMARY_PERIOD == 0:
            tqdm.write('[%5d] Loss: %1.3f' % (it, loss))
            summary = sess.run(summary_op, feed_dict=feed)
            summary_writer.add_summary(summary, it)

        if it % (SUMMARY_PERIOD * 2) == 0:
            # Sample latents for conditions 0..9, decode them with the VQ-VAE
            # and log the resulting images.
            sampled_zs, _log_probs = net.sample_from_prior(
                sess, np.arange(10), 2)
            sampled_ims = sess.run(vq_net.gen,
                                   feed_dict={vq_net.latent: sampled_zs})
            summary_writer.add_summary(
                sess.run(sample_summary_op,
                         feed_dict={sample_images: sampled_ims}), it)

    net.save(sess, LOG_DIR)
    # FileWriter buffers events; close it so the final summaries reach disk.
    summary_writer.close()
Beispiel #23
0
    raise NotImplementedError("encoder %s not implemented" % args.encoder)
decoder = WavenetDecoder(parameters['wavenet_parameters'])
# Constructor arguments for VQVAE. The waveform `wav` and `speaker` are
# embedded in the graph as tf.constant tensors rather than fed at run time.
model_args = {
    'x': tf.constant(wav),
    'speaker': tf.constant(speaker, dtype=np.float32),
    'encoder': encoder,
    'decoder': decoder,
    'k': parameters['k'],
    'beta': parameters['beta'],
    'verbose': parameters['verbose'],
    'use_vq': parameters['use_vq'],
    'speaker_embedding': parameters['speaker_embedding'],
    'num_speakers': num_speakers
}

model = VQVAE(model_args)
# NOTE(review): presumably builds the generation (sampling) graph — confirm.
model.build_generator()
wavenet = model.decoder.wavenet

# NOTE(review): model.ema is presumably a tf.train.ExponentialMovingAverage;
# variables_to_restore() then maps checkpointed moving-average values onto
# the live variables, so the averaged weights are loaded — confirm.
variables = model.ema.variables_to_restore()
saver = tf.train.Saver(variables)
saver.restore(sess, args.restore_path)

encoding = sess.run(model.encoding)

# Output directory: the part of restore_path before the first '/weights'.
# NOTE(review): if restore_path has no '/weights', split() returns the whole
# checkpoint path and files are written beneath it — confirm layout.
save_path = args.restore_path.split('/weights')[0]

if parameters['use_vq']:
    embedding = sess.run(model.embedding)
    # `gs` is defined outside this excerpt (presumably the global step).
    np.save(save_path + '/embedding_%d.npy' % gs, embedding)
if parameters['speaker_embedding'] > 0:
Beispiel #24
0
# VQVAE constructor arguments; inputs and speaker labels come from the
# dataset pipeline.
model_args = dict(
    x=dataset.x,
    speaker=dataset.y,
    encoder=encoder,
    decoder=decoder,
    k=parameters['k'],
    beta=parameters['beta'],
    verbose=parameters['verbose'],
    use_vq=parameters['use_vq'],
    speaker_embedding=parameters['speaker_embedding'],
    num_speakers=dataset.num_speakers,
)

# Step keys are converted with int() (presumably they arrive as strings from
# the parameter file).
schedule = dict((int(step), rate)
                for step, rate in parameters['learning_rate_schedule'].items())

model = VQVAE(model_args)
model.build(learning_rate_schedule=schedule)

sess = tf.Session()
saver = tf.train.Saver()

if args.restore_path is None:
    # Fresh run: initialise every variable.
    sess.run(tf.global_variables_initializer())
else:
    saver.restore(sess, args.restore_path)

gs, lr = sess.run([model.global_step, model.lr])
print('[restore] last global step: %d, learning rate: %.5f' % (gs, lr))

save_path = args.save_path
Beispiel #25
0
def main(config, RANDOM_SEED, LOG_DIR, TRAIN_NUM, BATCH_SIZE, LEARNING_RATE,
         DECAY_VAL, DECAY_STEPS, DECAY_STAIRCASE, BETA, K, D, SAVE_PERIOD,
         SUMMARY_PERIOD, **kwargs):
    """Train a VQ-VAE (CIFAR-10 architecture) on queue-fed image batches.

    Training and validation images come from ``get_image()`` and
    ``get_image(False)`` (presumably the train and held-out splits), batched
    with ``tf.train.shuffle_batch``.

    Args:
        config: training configuration; ``config.as_matrix()`` is logged as
            a TensorBoard text summary.
        RANDOM_SEED: seed for NumPy and the TensorFlow graph-level RNG.
        LOG_DIR: directory for checkpoints and TensorBoard events.
        TRAIN_NUM: number of optimisation steps.
        BATCH_SIZE: mini-batch size for training and validation queues.
        LEARNING_RATE, DECAY_VAL, DECAY_STEPS, DECAY_STAIRCASE: arguments of
            ``tf.train.exponential_decay``.
        BETA: forwarded to ``VQVAE`` and used to scale the logged ``commit``
            term (presumably the commitment-loss weight).
        K, D: forwarded to ``VQVAE`` (presumably codebook size and embedding
            dimension — confirm).
        SAVE_PERIOD: checkpoint every this many global steps.
        SUMMARY_PERIOD: write training summaries every this many global
            steps; validation summaries every ``2 * SUMMARY_PERIOD``.
        **kwargs: ignored.

    A final checkpoint is written even if training fails. An exception raised
    in the loop is handed to the queue coordinator and re-raised by
    ``coord.join``.
    """
    np.random.seed(RANDOM_SEED)
    tf.set_random_seed(RANDOM_SEED)

    # >>>>>>> DATASET
    image, _ = get_image()
    images = tf.train.shuffle_batch([image],
                                    batch_size=BATCH_SIZE,
                                    num_threads=4,
                                    capacity=BATCH_SIZE * 10,
                                    min_after_dequeue=BATCH_SIZE * 2)
    valid_image, _ = get_image(False)
    valid_images = tf.train.shuffle_batch([valid_image],
                                          batch_size=BATCH_SIZE,
                                          num_threads=1,
                                          capacity=BATCH_SIZE * 10,
                                          min_after_dequeue=BATCH_SIZE * 2)
    # <<<<<<<

    # >>>>>>> MODEL
    with tf.variable_scope('train'):
        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(LEARNING_RATE,
                                                   global_step,
                                                   DECAY_STEPS,
                                                   DECAY_VAL,
                                                   staircase=DECAY_STAIRCASE)
        tf.summary.scalar('lr', learning_rate)

        # Empty scope captured so the validation net can reuse its variables.
        with tf.variable_scope('params') as params:
            pass
        net = VQVAE(learning_rate, global_step, BETA, images, K, D,
                    _cifar10_arch, params, True)

    with tf.variable_scope('valid'):
        # Share the training net's weights instead of creating new ones.
        params.reuse_variables()
        valid_net = VQVAE(None, None, BETA, valid_images, K, D, _cifar10_arch,
                          params, False)

    with tf.variable_scope('misc'):
        # Summary Operations
        tf.summary.scalar('loss', net.loss)
        tf.summary.scalar('recon', net.recon)
        tf.summary.scalar('vq', net.vq)
        tf.summary.scalar('commit', BETA * net.commit)
        tf.summary.scalar('nll', tf.reduce_mean(net.nll))
        tf.summary.image('origin', images, max_outputs=4)
        tf.summary.image('recon', net.p_x_z, max_outputs=4)
        # TODO: log-likelihood

        summary_op = tf.summary.merge_all()

        # Initialize op
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        config_summary = tf.summary.text('TrainConfig',
                                         tf.convert_to_tensor(
                                             config.as_matrix()),
                                         collections=[])

        extended_summary_op = tf.summary.merge([
            tf.summary.scalar('valid_loss', valid_net.loss),
            tf.summary.scalar('valid_recon', valid_net.recon),
            tf.summary.scalar('valid_vq', valid_net.vq),
            tf.summary.scalar('valid_commit', BETA * valid_net.commit),
            tf.summary.scalar('valid_nll', tf.reduce_mean(valid_net.nll)),
            tf.summary.image('valid_origin', valid_images, max_outputs=4),
            tf.summary.image('valid_recon', valid_net.p_x_z, max_outputs=4),
        ])

    # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Run!
    # Named sess_config so it does not shadow the ``config`` argument.
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess = tf.Session(config=sess_config)
    sess.graph.finalize()
    sess.run(init_op)

    summary_writer = tf.summary.FileWriter(LOG_DIR, sess.graph)
    summary_writer.add_summary(config_summary.eval(session=sess))

    # Bind the coordinator and thread list before ``try`` so the cleanup in
    # ``finally`` never hits an unbound name if queue start-up fails.
    coord = tf.train.Coordinator()
    threads = []
    try:
        # Start Queueing
        threads = tf.train.start_queue_runners(coord=coord, sess=sess)
        # range (not the Python-2-only xrange) keeps this valid on Python 3.
        for _step in tqdm(range(TRAIN_NUM), dynamic_ncols=True):
            it, loss, _ = sess.run([global_step, net.loss, net.train_op])

            if it % SAVE_PERIOD == 0:
                net.save(sess, LOG_DIR, step=it)

            if it % SUMMARY_PERIOD == 0:
                tqdm.write('[%5d] Loss: %1.3f' % (it, loss))
                summary = sess.run(summary_op)
                summary_writer.add_summary(summary, it)

            if it % (SUMMARY_PERIOD * 2) == 0:  # Extended (validation) summary
                summary = sess.run(extended_summary_op)
                summary_writer.add_summary(summary, it)

    except Exception as e:
        # Report to the coordinator; coord.join() below re-raises it after
        # the final checkpoint has been written.
        coord.request_stop(e)
    finally:
        net.save(sess, LOG_DIR)
        # FileWriter buffers events; close it so the last summaries reach disk.
        summary_writer.close()

        coord.request_stop()
        coord.join(threads)
Beispiel #26
0
# Seed the GPU RNG and choose DataLoader options that only help when batches
# are copied to the GPU.
if args.cuda:
    torch.cuda.manual_seed(args.seed)
    kwargs = {'num_workers': 1, 'pin_memory': True}
else:
    kwargs = {}

# MNIST training split (downloaded into ../data on first use), as tensors.
train_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(root='../data',
                           train=True,
                           download=True,
                           transform=transforms.ToTensor()),
    batch_size=args.batch_size,
    shuffle=True,
    **kwargs)

# MNIST test split.
# NOTE(review): shuffling the test loader is unnecessary for evaluation.
test_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(root='../data',
                           train=False,
                           transform=transforms.ToTensor()),
    batch_size=args.batch_size,
    shuffle=True,
    **kwargs)

model = VQVAE(args.input_dim, args.emb_dim, args.emb_num, args.batch_size)
if args.cuda:
    model.cuda()

optimizer = optim.Adam(model.parameters(), lr=1e-3)


def train(epoch):
    """run one epoch of model to train with data loader"""
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = Variable(data).view(-1, 784)
        if args.cuda:
            data = data.cuda()
        # run forward