Example #1
import os
import time

import numpy as np

import dataset  # project-local data loading helpers


def load_image_data(data, n_xl, n_channels, output_batch_size):
    if data == 'mnist':
        # Load MNIST
        data_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                 'data', 'mnist.pkl.gz')
        x_train, t_train, x_valid, t_valid, _, _ = \
            dataset.load_mnist_realval(data_path)
        x_train = np.vstack([x_train, x_valid]).astype('float32')
        x_train = np.reshape(x_train, [-1, n_xl, n_xl, n_channels])

        x_train2 = x_train[:output_batch_size]
        t_train2 = t_train[:output_batch_size]
        t_train2 = np.nonzero(t_train2)[1]
        order = np.argsort(t_train2)
        sorted_x_train = x_train2[order]
    elif data == 'svhn':
        # Load SVHN data
        print('Reading svhn...')
        time_read = -time.time()
        print('Train')
        x_train = np.load('data/svhn_train1_x.npy')
        y_train = np.load('data/svhn_train1_y.npy')
        print('Test')
        x_test = np.load('data/svhn_test_x.npy')
        y_test = np.load('data/svhn_test_y.npy')
        time_read += time.time()
        print('Finished in {:.4f} seconds'.format(time_read))

        x_train2 = x_train[:output_batch_size]
        y_train2 = y_train[:output_batch_size]
        order = np.argsort(y_train2)
        sorted_x_train = x_train2[order]
    elif data == 'lfw':
        # Load LFW data
        print('Reading lfw...')
        time_read = -time.time()
        x_train = np.load('data/lfw.npy').astype(np.float32)
        print(x_train.shape)
        x_train = np.reshape(x_train, [-1, n_xl, n_xl, n_channels])
        time_read += time.time()
        print('Finished in {:.4f} seconds'.format(time_read))

        sorted_x_train = x_train[:output_batch_size]
    else:
        x_train, t_train, x_test, t_test = \
            dataset.load_cifar10('data/cifar10/cifar-10-python.tar.gz', normalize=True, one_hot=True)
        x = np.vstack((x_train, x_test))
        t = np.vstack((t_train, t_test))

        x2 = x[:output_batch_size]
        t2 = np.argmax(t[:output_batch_size], 1)
        order = np.argsort(t2)

        x_train = x
        sorted_x_train = x2[order]

    return x_train, sorted_x_train
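
A minimal usage sketch (hypothetical values; it assumes the project-local `dataset` module and the data files referenced above are available):

# Hypothetical call: 28x28 grayscale MNIST, keeping the first 100 examples
# sorted by label for visualization.
x_train, sorted_x_train = load_image_data('mnist', n_xl=28, n_channels=1,
                                           output_batch_size=100)
# x_train has shape (num_examples, 28, 28, 1);
# sorted_x_train has shape (100, 28, 28, 1).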
Example #2
import os

import tensorflow as tf
from tensorflow.contrib import layers
from six.moves import range
import numpy as np
import zhusuan as zs

import conf
import dataset

if __name__ == "__main__":
    tf.set_random_seed(1237)

    # Load MNIST
    data_path = os.path.join(conf.data_dir, 'mnist.pkl.gz')
    x_train, t_train, x_valid, t_valid, x_test, t_test = \
        dataset.load_mnist_realval(data_path)
    x_train = np.vstack([x_train, x_valid]).astype('float32')
    np.random.seed(1234)
    x_test = np.random.binomial(1, x_test, size=x_test.shape).astype('float32')
    n_x = x_train.shape[1]

    # Define model parameters
    n_z = 40

    # Define training/evaluation parameters
    lb_samples = 10
    ll_samples = 1000
    epochs = 3000
    batch_size = 100
    iters = x_train.shape[0] // batch_size
    learning_rate = 0.001
Example #3
def run_experiment(args):
    import os
    # set environment variables for tensorflow
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

    import inspect
    import shutil
    import numpy as np
    import tensorflow as tf

    from collections import OrderedDict
    import matplotlib.pyplot as plt
    plt.switch_backend('Agg')

    import utils
    import paramgraphics
    import nn
    from tensorflow.contrib.framework.python.ops import arg_scope
    # import tensorflow.contrib.layers as layers

    # ----------------------------------------------------------------
    # Arguments and Settings
    args.message = 'LBT-GAN-smnist_' + args.message
    np.random.seed(12345)
    tf.set_random_seed(args.seed)

    # copy file for reproducibility
    logger, dirname = utils.setup_logging(args)
    script_fn = inspect.getfile(inspect.currentframe())
    script_src = os.path.abspath(script_fn)
    script_dst = os.path.abspath(os.path.join(dirname, script_fn))
    shutil.copyfile(script_src, script_dst)
    logger.info("script copied from %s to %s" % (script_src, script_dst))

    # print arguments
    for k, v in sorted(vars(args).items()):
        logger.info("  %20s: %s" % (k, v))

    # get arguments
    batch_size = args.batch_size
    batch_size_est = args.batch_size_est
    gen_lr = args.gen_lr
    dis_lr = args.dis_lr
    est_lr = args.est_lr
    lambda_gan = args.lambda_gan
    beta1 = 0.5
    epsilon = 1e-8
    max_iter = args.max_iter
    viz_every = args.viz_every
    z_dim, vae_z_dim = utils.get_ints(args.z_dims)
    unrolling_steps = args.unrolling_steps
    assert unrolling_steps > 0
    n_viz = args.n_viz

    # ----------------------------------------------------------------
    # Dataset
    from dataset import load_mnist_realval, DataSet
    train_x, _, test_x, _ = load_mnist_realval(validation=False, asimage=True)
    train_x = np.concatenate([train_x, test_x], 0)
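    # Build a "stacked MNIST" training set below: each sample stacks three
    # randomly chosen digits as the three color channels of a 28x28x3 image.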

    ids = np.random.randint(0, train_x.shape[0], size=(128000, 3))
    X_training = np.zeros(shape=(ids.shape[0], 28, 28, ids.shape[1]))
    for i in range(ids.shape[0]):
        for j in range(ids.shape[1]):
            X_training[i, :, :, j] = train_x[ids[i, j], :, :, 0]

    smnist = DataSet(X_training, None)

    # data_channel = 3
    x_dim = 784 * 3
    dim_input = (28, 28)
    feature_dim = 16

    # ----------------------------------------------------------------
    # Model setup
    logger.info("Setting up model ...")

    def discriminator(x, Reuse=tf.AUTO_REUSE, is_training=True):
        def leaky_relu(x, alpha=0.2):
            return tf.maximum(alpha * x, x)

        D_feature_dim = int(feature_dim * args.d_ratio)
        with tf.variable_scope("discriminator", reuse=Reuse):

            def bn_layer(x):
                if args.d_bn is True:
                    # print("Use bn in D")
                    return tf.layers.batch_normalization(x,
                                                         training=is_training)
                else:
                    # print("No BN in D")
                    return x

            x = tf.reshape(x, [batch_size, 28, 28, 3])
            conv1 = tf.layers.conv2d(x,
                                     D_feature_dim,
                                     4,
                                     2,
                                     use_bias=True,
                                     padding='same')
            conv1 = leaky_relu(conv1)

            conv2 = tf.layers.conv2d(conv1,
                                     2 * D_feature_dim,
                                     4,
                                     2,
                                     use_bias=False,
                                     padding='same')
            conv2 = bn_layer(conv2)
            conv2 = leaky_relu(conv2)
            conv2 = tf.layers.flatten(conv2)

            fc1 = tf.layers.dense(conv2, 1024, use_bias=False)
            fc1 = bn_layer(fc1)
            fc1 = leaky_relu(fc1)

            fc2 = tf.layers.dense(fc1, 1)
            return fc2

    def generator(z, Reuse=tf.AUTO_REUSE, flatten=True, is_training=True):
        if args.g_nonlin == 'relu':
            # print("Use Relu in G")
            nonlin = tf.nn.relu
        else:
            # print("Use tanh in G")
            nonlin = tf.nn.tanh
        # nonlin = tf.nn.relu if args.g_nonlin == 'relu' else tf.nn.tanh

        # norm_prms = {'is_training': is_training, 'decay': 0.9, 'scale': False}
        with tf.variable_scope("generator", reuse=Reuse):

            # lx = layers.fully_connected(z, 1024)
            lx = tf.layers.dense(z, 1024, use_bias=False)
            lx = tf.layers.batch_normalization(lx, training=is_training)
            lx = nonlin(lx)

            lx = tf.layers.dense(lx, feature_dim * 2 * 7 * 7, use_bias=False)
            lx = tf.layers.batch_normalization(lx, training=is_training)
            lx = nonlin(lx)
            lx = tf.reshape(lx, [-1, 7, 7, feature_dim * 2])

            lx = tf.layers.conv2d_transpose(lx,
                                            feature_dim,
                                            5,
                                            2,
                                            use_bias=False,
                                            padding='same')
            lx = tf.layers.batch_normalization(lx, training=is_training)
            lx = nonlin(lx)

            lx = tf.layers.conv2d_transpose(lx, 3, 5, 2, padding='same')
            lx = tf.nn.sigmoid(lx)

            if flatten is True:
                lx = tf.layers.flatten(lx)
            return lx

    nonlin = tf.nn.relu

    def compute_est_samples(z, params=None, reuse=tf.AUTO_REUSE):
        with tf.variable_scope("estimator"):
            with arg_scope([nn.dense], params=params):
                with tf.variable_scope("decoder", reuse=reuse):
                    h_dec_1 = nn.dense(z,
                                       vae_z_dim,
                                       200 * 2,
                                       "dense1",
                                       nonlinearity=nonlin)
                    h_dec_2 = nn.dense(h_dec_1,
                                       200 * 2,
                                       500 * 2,
                                       "dense2",
                                       nonlinearity=nonlin)
                    x_mean = nn.dense(h_dec_2,
                                      500 * 2,
                                      x_dim,
                                      "dense3",
                                      nonlinearity=None)
                    return x_mean

    def compute_est_ll(x, params=None, reuse=tf.AUTO_REUSE):
        with tf.variable_scope("estimator"):
            with arg_scope([nn.dense], params=params):
                with tf.variable_scope("encoder", reuse=reuse):
                    h_enc_1 = nn.dense(x,
                                       x_dim,
                                       500 * 2,
                                       "dense1",
                                       nonlinearity=nonlin)
                    # h_enc_1 = nn.batch_norm(h_enc_1, "bn1", 129, 2)
                    h_enc_2 = nn.dense(h_enc_1,
                                       500 * 2,
                                       200 * 2,
                                       "dense2",
                                       nonlinearity=nonlin)
                    # h_enc_2 = nn.batch_norm(h_enc_2, "bn2", 128, 2)
                    z_mean = nn.dense(h_enc_2,
                                      200 * 2,
                                      vae_z_dim,
                                      "dense3",
                                      nonlinearity=None)
                    z_logvar = nn.dense(h_enc_2,
                                        200 * 2,
                                        vae_z_dim,
                                        "dense4",
                                        nonlinearity=None)
                epsilon = tf.random_normal(tf.shape(z_mean), dtype=tf.float32)
                z = z_mean + tf.exp(0.5 * z_logvar) * epsilon

                with tf.variable_scope("decoder", reuse=reuse):
                    h_dec_1 = nn.dense(z,
                                       vae_z_dim,
                                       200 * 2,
                                       "dense1",
                                       nonlinearity=nonlin)
                    # h_dec_1 = nn.batch_norm(h_dec_1, "bn1", 127, 2)
                    h_dec_2 = nn.dense(h_dec_1,
                                       200 * 2,
                                       500 * 2,
                                       "dense2",
                                       nonlinearity=nonlin)
                    # h_dec_2 = nn.batch_norm(h_dec_2, "bn2", 128, 2)
                    x_mean = nn.dense(h_dec_2,
                                      500 * 2,
                                      x_dim,
                                      "dense3",
                                      nonlinearity=None)

        elbo = tf.reduce_mean(
            tf.reduce_sum(-tf.nn.sigmoid_cross_entropy_with_logits(
                logits=x_mean, labels=x),
                          axis=1) -
            tf.reduce_sum(-0.5 * (1 + z_logvar - tf.square(z_mean) -
                                  tf.exp(z_logvar)),
                          axis=1))
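        # The expression above is the batch-averaged ELBO: the Bernoulli
        # reconstruction log-likelihood minus the Gaussian KL term.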
        return elbo, tf.nn.sigmoid(x_mean)

    def compute_est_updated_with_SGD(x, lr=0.001, params=None):
        elbo, _ = compute_est_ll(x, params=params)
        grads = tf.gradients(elbo, params.values())
        new_params = params.copy()
        for key, g in zip(params, grads):
            new_params[key] += lr * g
        return elbo, new_params

    def compute_est_updated_with_Adam(x,
                                      lr=0.001,
                                      beta_1=0.9,
                                      beta_2=0.999,
                                      epsilon=1e-7,
                                      decay=0.,
                                      params=None,
                                      adam_params=None):
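        # Manual Adam update written with TF ops so the resulting parameters
        # remain differentiable, which the unrolled estimator updates below
        # rely on.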
        elbo, _ = compute_est_ll(x, params=params)
        grads = tf.gradients(elbo, params.values())
        new_params = params.copy()
        new_adam_params = adam_params.copy()
        new_adam_params['iterations'] += 1
        lr = lr * \
            (1. / (1. + decay *
                   tf.cast(adam_params['iterations'], tf.float32)))
        t = tf.cast(new_adam_params['iterations'], tf.float32)
        lr_t = lr * (tf.sqrt(1. - tf.pow(beta_2, t)) /
                     (1. - tf.pow(beta_1, t)))
        for key, g in zip(params, grads):
            new_adam_params['m_' + key] = (
                beta_1 * adam_params['m_' + key]) + (1. - beta_1) * g
            new_adam_params['v_' + key] = tf.stop_gradient(
                (beta_2 * adam_params['v_' + key]) +
                (1. - beta_2) * tf.square(g))
            new_params[key] = params[key] + lr_t * new_adam_params[
                'm_' + key] / tf.sqrt(new_adam_params['v_' + key] + epsilon)
        return elbo, new_params, new_adam_params

    lr = tf.placeholder(tf.float32)
    data = tf.placeholder(tf.float32, shape=(batch_size, x_dim))

    # Construct generator and estimator nets
    est_params_dict = OrderedDict()
    _, _ = compute_est_ll(data, params=est_params_dict)
    gen_noise = tf.random_normal((batch_size_est, z_dim), dtype=tf.float32)
    samples_gen = generator(gen_noise)
    vae_noise = tf.random_normal((batch_size_est, vae_z_dim), dtype=tf.float32)
    samples_est = tf.nn.sigmoid(
        compute_est_samples(z=vae_noise, params=est_params_dict))
    # for key in est_params_dict:
    #    print(key, est_params_dict[key])

    adam_params_dict = OrderedDict()
    with tf.variable_scope("adam"):
        adam_params_dict['iterations'] = tf.Variable(0,
                                                     dtype=tf.int64,
                                                     name='iterations')
        for key in est_params_dict:
            adam_params_dict['m_' + key] = tf.Variable(tf.zeros_like(
                est_params_dict[key]),
                                                       name='m_' + key)
            adam_params_dict['v_' + key] = tf.Variable(tf.zeros_like(
                est_params_dict[key]),
                                                       name='v_' + key)

    gen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "generator")
    est_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "estimator")
    adam_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "adam")

    # unrolling estimator updates
    cur_params = est_params_dict
    cur_adam_params = adam_params_dict
    elbo_genx_at_steps = []
    for _ in range(unrolling_steps):
        samples_gen = generator(
            tf.random_normal((batch_size_est, z_dim), dtype=tf.float32))
        elbo_genx_step, cur_params, cur_adam_params = compute_est_updated_with_Adam(
            samples_gen,
            lr=lr,
            beta_1=beta1,
            epsilon=epsilon,
            params=cur_params,
            adam_params=cur_adam_params)
        elbo_genx_at_steps.append(elbo_genx_step)

    # estimator update
    updates = []
    for key in est_params_dict:
        updates.append(tf.assign(est_params_dict[key], cur_params[key]))
    for key in adam_params_dict:
        updates.append(tf.assign(adam_params_dict[key], cur_adam_params[key]))
    e_train_op = tf.group(*updates, name="e_train_op")
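    # Running e_train_op copies the unrolled parameter and Adam-state values
    # back into the persistent estimator/Adam variables.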

    # Optimize the generator on the unrolled ELBO loss
    unrolled_elbo_data, _ = compute_est_ll(data, params=cur_params)
    # unrolled_elbo_samp, _ = compute_est_ll(
    #     tf.stop_gradient(samples_gen), params=cur_params)

    # GAN-loss for discriminator and generator
    samples_gen_gan = generator(
        tf.random_normal((batch_size_est, z_dim), dtype=tf.float32))
    fake_D_output = discriminator(samples_gen_gan)
    real_D_output = discriminator(data)
    # print(fake_D_output, real_D_output)
    ganloss_g = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(fake_D_output), logits=fake_D_output))
    ganloss_D_fake = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.zeros_like(fake_D_output), logits=fake_D_output))
    ganloss_D_real = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(real_D_output), logits=real_D_output))

    use_e_sym = tf.placeholder(tf.float32, shape=(), name="use_E")
    if args.lbt:
        logger.info("Using lbt")
        object_g = lambda_gan * ganloss_g - use_e_sym * unrolled_elbo_data
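        # Minimizing object_g maximizes the unrolled estimator ELBO on real
        # data (gated by use_e_sym) in addition to the usual GAN generator loss.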
    else:
        logger.info("Using GAN")
        object_g = lambda_gan * ganloss_g  # - use_e_sym * unrolled_elbo_data

    # object_g = -1 * unrolled_elbo_data
    object_d = ganloss_D_fake + ganloss_D_real
    dis_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                 "discriminator")

    g_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, "generator")
    g_train_opt = tf.train.AdamOptimizer(learning_rate=gen_lr,
                                         beta1=beta1,
                                         epsilon=epsilon)
    # g_train_opt = tf.train.RMSPropOptimizer(learning_rate=gen_lr, epsilon=epsilon)
    g_grads = g_train_opt.compute_gradients(object_g, var_list=gen_vars)
    # g_grads_clipped = [(tf.clip_by_value(grad, -1., 1.), var) for grad, var in g_grads]
    g_grads_, g_vars_ = zip(*g_grads)
    g_grads_clipped_, g_grads_norm_ = tf.clip_by_global_norm(g_grads_, 5.)
    g_grads_clipped = zip(g_grads_clipped_, g_vars_)
    if args.clip_grad:
        logger.info("Clipping gradients of generator parameters.")
        with tf.control_dependencies(g_update_ops):
            g_train_op = g_train_opt.apply_gradients(g_grads_clipped)
    else:
        with tf.control_dependencies(g_update_ops):
            g_train_op = g_train_opt.apply_gradients(g_grads)
        # g_train_op = g_train_opt.apply_gradients(g_grads)

    d_train_opt = tf.train.AdamOptimizer(learning_rate=dis_lr,
                                         beta1=beta1,
                                         epsilon=epsilon)
    d_update_op = tf.get_collection(tf.GraphKeys.UPDATE_OPS, "discriminator")
    with tf.control_dependencies(d_update_op):
        d_train_op = d_train_opt.minimize(object_d, var_list=dis_vars)

    # ----------------------------------------------------------------
    # Training
    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver(max_to_keep=None)
    if args.model_path:
        saver.restore(sess, args.model_path)

    # # print variables
    # logger.info("Generator parameters:")
    # for p in gen_vars:
    #     logger.debug("%s: %s" % (p.name, sess.run(tf.shape(p))))
    # logger.info("Estimator parameters:")
    # for p in est_vars:
    #     logger.debug("%s: %s" % (p.name, sess.run(tf.shape(p))))
    # logger.info("Adam parameters:")
    # for p in adam_vars:
    #     logger.debug("%s: %s" % (p.name, sess.run(tf.shape(p))))

    elbo_vals = []
    ganloss_vals = []
    tgan_g, tgan_d_fake, tgan_d_real = 0., 0., 0.
    elbo_genx_val, elbo_data_val, gradients_norm = -np.inf, -np.inf, 0
    use_e_flag = 0.

    for i in range(max_iter + 1):

        # train estimator and generator
        # x_mini_batch0 = mnist.train.next_batch(batch_size)[0].reshape(
        #     [batch_size, 28, 28, 1])
        # x_mini_batch1 = mnist.train.next_batch(batch_size)[0].reshape(
        #     [batch_size, 28, 28, 1])
        # x_mini_batch2 = mnist.train.next_batch(batch_size)[0].reshape(
        #     [batch_size, 28, 28, 1])
        # x_mini_batch = np.concatenate(
        #     [x_mini_batch0, x_mini_batch1, x_mini_batch2],
        #     axis=-1).reshape([batch_size, 28 * 28 * 3])
        x_mini_batch = smnist.next_batch(batch_size)[0].reshape(
            [batch_size, 28 * 28 * 3])

        if i > 3000:
            use_e_flag = 1.
            for _ in range(args.n_est):
                elbo_genx_val, _ = sess.run(
                    [elbo_genx_at_steps[-1], e_train_op],
                    feed_dict={lr: 3. * est_lr})

        for _ in range(args.n_dis):
            _, tgan_g, tgan_d_real, tgan_d_fake = sess.run(
                [d_train_op, ganloss_g, ganloss_D_real, ganloss_D_fake],
                feed_dict={data: x_mini_batch})

        elbo_data_val, gradients_norm, _ = sess.run(
            [unrolled_elbo_data, g_grads_norm_, g_train_op],
            feed_dict={
                data: x_mini_batch,
                lr: est_lr,
                use_e_sym: use_e_flag
            })
        elbo_vals.append([elbo_genx_val, elbo_data_val])
        ganloss_vals.append([tgan_g, tgan_d_real, tgan_d_fake])

        # visualization
        if i % viz_every == 0:
            np_samples_gen, np_samples_est, np_data = sess.run(
                [samples_gen, samples_est, data],
                feed_dict={data: x_mini_batch})
            np_samples_est = np_samples_est.reshape([-1, 28, 28, 3]).transpose(
                [0, 3, 1, 2]).reshape([-1, 28 * 28 * 3])
            np_samples_gen = np_samples_gen.reshape([-1, 28, 28, 3]).transpose(
                [0, 3, 1, 2]).reshape([-1, 28 * 28 * 3])
            np_data = np_data.reshape([-1, 28, 28, 3]).transpose(
                [0, 3, 1, 2]).reshape([-1, 28 * 28 * 3])

            paramgraphics.mat_to_img(np_samples_gen[:n_viz],
                                     dim_input,
                                     colorImg=True,
                                     save_path=os.path.join(
                                         dirname,
                                         'sample_' + str(i) + '_gen.png'))
            paramgraphics.mat_to_img(np_data[:n_viz],
                                     dim_input,
                                     colorImg=True,
                                     save_path=os.path.join(
                                         dirname,
                                         'sample_' + str(i) + '_dat.png'))
            paramgraphics.mat_to_img(np_samples_est[:n_viz],
                                     dim_input,
                                     colorImg=True,
                                     save_path=os.path.join(
                                         dirname,
                                         'sample_' + str(i) + '_est.png'))

            fig = plt.figure(figsize=(6, 4))
            plt.plot(elbo_vals,
                     '.',
                     markersize=2,
                     markeredgecolor='none',
                     linestyle='none',
                     alpha=min(1.0, 0.01 * max_iter / (i + 1)))
            plt.ylim((-200.0, 0.0))
            legend = plt.legend(('elbo_genx', 'elbo_data'), markerscale=6)
            for lh in legend.legendHandles:
                lh._legmarker.set_alpha(1.)
            plt.grid(True)
            plt.tight_layout()
            plt.savefig(os.path.join(dirname, 'curve.png'),
                        bbox_inches='tight')
            plt.close(fig)

        # training log
        if i % viz_every == 0:
            elbo_genx_ma_val, elbo_data_ma_val = np.mean(elbo_vals[-200:],
                                                         axis=0)
            logger.info(
                "Iter %d: gradients norm = %.4f. samples LL = %.4f, data LL = %.4f."
                % (i, gradients_norm, elbo_genx_ma_val, elbo_data_ma_val))
            logger.info(
                "Iter %d: gan_g = %.4f. gan_d_real = %.4f, gan_d_fake = %.4f."
                % (i, tgan_g, tgan_d_real, tgan_d_fake))

        if i % args.model_every == 0:
            saver.save(sess, os.path.join(dirname, 'model_' + str(i)))
Example #4
import os

import numpy as np
import tensorflow as tf

import dataset  # project-local data loading helpers


def print_param(param_name):
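    # Relies on a TensorFlow session `sess` created elsewhere in the script.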
    val = sess.run(param_name)
    if 'log' in param_name:
        val = np.exp(val)
    val_normalized = val / np.sum(val)
    print('{}: {}'.format(param_name, val_normalized))


if __name__ == "__main__":
    tf.set_random_seed(666)
    np.random.seed(666)

    # Load data from MNIST
    data_dir = './data'
    data_path = os.path.join(data_dir, 'mnist.pkl.gz')
    x_train, t_train, x_val, t_val, x_test, t_test = dataset.load_mnist_realval(
        data_path)
    x_train = np.vstack([x_train, x_val]).astype('float32')
    n_x = x_train.shape[1]  # 784=28*28

    # Define model parameters
    n_h = 40  # D
    n_z = 10  # K

    # Define training/evaluation parameters
    lb_samples = 10
    epoches = 100
    batch_size = 100
    iters = x_train.shape[0] // batch_size
    learning_rate = 0.001
    save_freq = 20
    ckpt_path = "./ckpt/10x10_2"
Example #5
import os
import time

import numpy as np
import tensorflow as tf
import zhusuan as zs

import conf
import dataset
# build_gen, build_q_net and save_image_collections are defined elsewhere
# in the original script.


def main():
    # Load MNIST
    data_path = os.path.join(conf.data_dir, "mnist.pkl.gz")
    x_train, t_train, x_valid, t_valid, x_test, t_test = \
        dataset.load_mnist_realval(data_path)
    x_train = np.vstack([x_train, x_valid])
    x_test = np.random.binomial(1, x_test, size=x_test.shape)
    x_dim = x_train.shape[1]

    # Define model parameters
    z_dim = 40

    # Build the computation graph
    n_particles = tf.placeholder(tf.int32, shape=[], name="n_particles")
    x_input = tf.placeholder(tf.float32, shape=[None, x_dim], name="x")
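    # Dynamically binarize the input: each pixel becomes 1 with probability
    # equal to its grayscale intensity.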
    x = tf.cast(tf.less(tf.random_uniform(tf.shape(x_input)), x_input),
                tf.int32)
    n = tf.placeholder(tf.int32, shape=[], name="n")

    model = build_gen(x_dim, z_dim, n, n_particles)
    variational = build_q_net(x, z_dim, n_particles)

    lower_bound = zs.variational.elbo(model, {"x": x},
                                      variational=variational,
                                      axis=0)
    cost = tf.reduce_mean(lower_bound.sgvb())
    lower_bound = tf.reduce_mean(lower_bound)

    # Importance sampling estimates of marginal log likelihood
    is_log_likelihood = tf.reduce_mean(
        zs.is_loglikelihood(model, {"x": x}, proposal=variational, axis=0))

    optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
    infer_op = optimizer.minimize(cost)

    # Random generation
    x_gen = tf.reshape(model.observe()["x_mean"], [-1, 28, 28, 1])

    # Define training/evaluation parameters
    epochs = 3000
    batch_size = 128
    iters = x_train.shape[0] // batch_size
    save_freq = 10
    test_freq = 10
    test_batch_size = 400
    test_iters = x_test.shape[0] // test_batch_size
    result_path = "results/vae"
    if not os.path.exists(result_path):
        os.makedirs(result_path)

    # Run the inference
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        for epoch in range(1, epochs + 1):
            time_epoch = -time.time()
            np.random.shuffle(x_train)
            lbs = []
            for t in range(iters):
                x_batch = x_train[t * batch_size:(t + 1) * batch_size]
                _, lb = sess.run([infer_op, lower_bound],
                                 feed_dict={
                                     x_input: x_batch,
                                     n_particles: 1,
                                     n: batch_size
                                 })
                lbs.append(lb)
            time_epoch += time.time()
            print("Epoch {} ({:.1f}s): Lower bound = {}".format(
                epoch, time_epoch, np.mean(lbs)))

            if epoch % test_freq == 0:
                time_test = -time.time()
                test_lbs, test_lls = [], []
                for t in range(test_iters):
                    test_x_batch = x_test[t * test_batch_size:(t + 1) *
                                          test_batch_size]
                    test_lb = sess.run(lower_bound,
                                       feed_dict={
                                           x: test_x_batch,
                                           n_particles: 1,
                                           n: test_batch_size
                                       })
                    test_ll = sess.run(is_log_likelihood,
                                       feed_dict={
                                           x: test_x_batch,
                                           n_particles: 1000,
                                           n: test_batch_size
                                       })
                    test_lbs.append(test_lb)
                    test_lls.append(test_ll)
                time_test += time.time()
                print(">>> TEST ({:.1f}s)".format(time_test))
                print(">> Test lower bound = {}".format(np.mean(test_lbs)))
                print('>> Test log likelihood (IS) = {}'.format(
                    np.mean(test_lls)))

            if epoch % save_freq == 0:
                images = sess.run(x_gen, feed_dict={n: 100, n_particles: 1})
                name = os.path.join(result_path,
                                    "vae.epoch.{}.png".format(epoch))
                save_image_collections(images, name)
Example #6
import os
import time

import numpy as np
import tensorflow as tf
import zhusuan as zs

import dataset  # project-local data loading helpers
import utils    # provides save_image_collections
# build_gen and build_q_net are defined elsewhere in the original script.


def train_vae(args):
    # Load MNIST
    data_path = os.path.join(args.data_dir, "mnist.pkl.gz")
    x_train, y_train, x_valid, y_valid, x_test, y_test = dataset.load_mnist_realval(data_path)
    x_train = np.random.binomial(1, x_train, size=x_train.shape)
    x_dim = x_train.shape[1]
    y_dim = y_train.shape[1]

    # Define model parameters
    z_dim = args.z_dim

    # Build the computation graph
    x = tf.placeholder(tf.float32, shape=[None, x_dim], name="x")
    y = tf.placeholder(tf.float32, shape=[None, y_dim], name="y")
    n = tf.placeholder(tf.int32, shape=[], name="n")

    # Get the models
    model = build_gen(y, x_dim, z_dim, n)
    variational = build_q_net(x, y, z_dim)

    # Calculate ELBO
    lower_bound = zs.variational.elbo(model, {"x": x}, variational=variational)
    cost = tf.reduce_mean(lower_bound.sgvb())
    lower_bound = tf.reduce_mean(lower_bound)

    optimizer = tf.train.AdamOptimizer(learning_rate=args.lr)
    infer_op = optimizer.minimize(cost)

    # Random generation
    x_gen = tf.reshape(model.observe()["x_mean"], [-1, 28, 28, 1])

    # Compute class labels
    labels = []
    for c in range(10):
        l = np.zeros((100, 10))
        l[:, c] = 1
        labels.append(l)
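    # labels[c] is a batch of 100 one-hot vectors for digit class c, used
    # below for class-conditional generation.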

    epochs = args.epochs
    batch_size = args.batch_size
    iters = x_train.shape[0] // batch_size

    saver = tf.train.Saver(max_to_keep=10)
    save_model_freq = min(100, args.epochs)

    # Run the Inference
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        ckpt_file = tf.train.latest_checkpoint(args.checkpoints_path)
        begin_epoch = 1

        if ckpt_file is not None:
            print('Restoring model from {}...'.format(ckpt_file))
            begin_epoch = int(ckpt_file.split('.')[-2]) + 1
            saver.restore(sess, ckpt_file)

        for epoch in range(1, epochs+1):
            time_epoch = -time.time()
            lbs = []
            for t in range(iters):
                x_batch = x_train[t*batch_size:(t+1)*batch_size]
                y_batch = y_train[t*batch_size:(t+1)*batch_size]

                _, lb = sess.run(
                    [infer_op, lower_bound],
                    feed_dict={
                        x: x_batch,
                        y: y_batch,
                        n: batch_size
                    }
                )
                lbs.append(lb)
            
            time_epoch += time.time()
            print("Epoch {} ({:.1f}s): Lower bound = {}".format(epoch, time_epoch, np.mean(lbs)))

            if epoch % args.save_model_freq == 0:
                save_path = os.path.join(args.checkpoints_path, "vae.epoch.{}.ckpt".format(epoch))
                if not os.path.exists(os.path.dirname(save_path)):
                    os.makedirs(os.path.dirname(save_path))
                saver.save(sess, save_path)
            
            if epoch % args.save_img_freq == 0:
                for c in range(10):
                    images = sess.run(x_gen, feed_dict={y: labels[c], n: 100 })
                    name = os.path.join(args.results_path, str(epoch).zfill(3), "{}.png".format(c))
                    utils.save_image_collections(images, name)
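
A hypothetical way to drive train_vae; the attribute names follow the fields the function reads, but the original script defines its own argument parser and defaults:

import argparse

# Hypothetical defaults for illustration only.
parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', default='./data')
parser.add_argument('--checkpoints_path', default='./checkpoints/cvae')
parser.add_argument('--results_path', default='./results/cvae')
parser.add_argument('--z_dim', type=int, default=40)
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--save_model_freq', type=int, default=100)
parser.add_argument('--save_img_freq', type=int, default=10)
train_vae(parser.parse_args())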