Example #1
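Trains a VQ-VAE on CIFAR-10 (TensorFlow 1.x): a queue-based input pipeline feeds a training network and a weight-sharing validation network, with TensorBoard summaries written for both.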
import numpy as np
import tensorflow as tf
from six.moves import xrange  # Python 2/3 compatibility for the training loop
from tqdm import tqdm

# VQVAE, _cifar10_arch, and get_image are assumed to come from the surrounding
# project: the model class, its CIFAR-10 encoder/decoder architecture, and the
# op that reads a single CIFAR-10 example.


def main(config, RANDOM_SEED, LOG_DIR, TRAIN_NUM, BATCH_SIZE, LEARNING_RATE,
         DECAY_VAL, DECAY_STEPS, DECAY_STAIRCASE, BETA, K, D, SAVE_PERIOD,
         SUMMARY_PERIOD, **kwargs):
    np.random.seed(RANDOM_SEED)
    tf.set_random_seed(RANDOM_SEED)

    # >>>>>>> DATASET
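    # get_image() is assumed to yield a single decoded CIFAR-10 example;
    # shuffle_batch wraps it in an asynchronous, multi-threaded input queue
    # (the queue runners are started explicitly in the training loop below).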
    image, _ = get_image()
    images = tf.train.shuffle_batch([image],
                                    batch_size=BATCH_SIZE,
                                    num_threads=4,
                                    capacity=BATCH_SIZE * 10,
                                    min_after_dequeue=BATCH_SIZE * 2)
    valid_image, _ = get_image(False)
    valid_images = tf.train.shuffle_batch([valid_image],
                                          batch_size=BATCH_SIZE,
                                          num_threads=1,
                                          capacity=BATCH_SIZE * 10,
                                          min_after_dequeue=BATCH_SIZE * 2)
    # <<<<<<<

    # >>>>>>> MODEL
    with tf.variable_scope('train'):
        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(LEARNING_RATE,
                                                   global_step,
                                                   DECAY_STEPS,
                                                   DECAY_VAL,
                                                   staircase=DECAY_STAIRCASE)
        tf.summary.scalar('lr', learning_rate)

        with tf.variable_scope('params') as params:
            pass
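        # The empty 'params' scope is captured so that the training and
        # validation networks can share one set of weights via
        # params.reuse_variables() below.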
        net = VQVAE(learning_rate, global_step, BETA, images, K, D,
                    _cifar10_arch, params, True)

    with tf.variable_scope('valid'):
        params.reuse_variables()
        valid_net = VQVAE(None, None, BETA, valid_images, K, D, _cifar10_arch,
                          params, False)

    with tf.variable_scope('misc'):
        # Summary Operations
        tf.summary.scalar('loss', net.loss)
        tf.summary.scalar('recon', net.recon)
        tf.summary.scalar('vq', net.vq)
        tf.summary.scalar('commit', BETA * net.commit)
        tf.summary.scalar('nll', tf.reduce_mean(net.nll))
        tf.summary.image('origin', images, max_outputs=4)
        tf.summary.image('recon', net.p_x_z, max_outputs=4)
        # TODO: log-likelihood

        summary_op = tf.summary.merge_all()

        # Initialize op
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        config_summary = tf.summary.text('TrainConfig',
                                         tf.convert_to_tensor(
                                             config.as_matrix()),
                                         collections=[])
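        # collections=[] keeps the config summary out of merge_all(); it is
        # written exactly once, manually, right after the session is created.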

        extended_summary_op = tf.summary.merge([
            tf.summary.scalar('valid_loss', valid_net.loss),
            tf.summary.scalar('valid_recon', valid_net.recon),
            tf.summary.scalar('valid_vq', valid_net.vq),
            tf.summary.scalar('valid_commit', BETA * valid_net.commit),
            tf.summary.scalar('valid_nll', tf.reduce_mean(valid_net.nll)),
            tf.summary.image('valid_origin', valid_images, max_outputs=4),
            tf.summary.image('valid_recon', valid_net.p_x_z, max_outputs=4),
        ])

    # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Run!
    sess_config = tf.ConfigProto()  # renamed to avoid shadowing the `config` argument
    sess_config.gpu_options.allow_growth = True
    sess = tf.Session(config=sess_config)
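    # Finalizing the graph makes any accidental op creation inside the
    # training loop fail fast; init_op can still be run since it is already
    # part of the graph.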
    sess.graph.finalize()
    sess.run(init_op)

    summary_writer = tf.summary.FileWriter(LOG_DIR, sess.graph)
    summary_writer.add_summary(config_summary.eval(session=sess))

    try:
        # Start Queueing
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord, sess=sess)
        for step in tqdm(xrange(TRAIN_NUM), dynamic_ncols=True):
            it, loss, _ = sess.run([global_step, net.loss, net.train_op])

            if (it % SAVE_PERIOD == 0):
                net.save(sess, LOG_DIR, step=it)

            if (it % SUMMARY_PERIOD == 0):
                tqdm.write('[%5d] Loss: %1.3f' % (it, loss))
                summary = sess.run(summary_op)
                summary_writer.add_summary(summary, it)

            if (it % (SUMMARY_PERIOD * 2) == 0):  # Extended summary
                summary = sess.run(extended_summary_op)
                summary_writer.add_summary(summary, it)

    except Exception as e:
        coord.request_stop(e)
    finally:
        net.save(sess, LOG_DIR)

        coord.request_stop()
        coord.join(threads)
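main() only touches config through config.as_matrix(), which tf.summary.text expects to yield a 2-D array of strings. A minimal driver sketch (hypothetical: the TrainConfig class and the hyperparameter values below are illustrative, not from the original project; BETA=0.25 follows the commitment-cost value suggested in the VQ-VAE paper):

class TrainConfig(object):
    """Minimal stand-in for the config object: main() only calls as_matrix()."""
    def __init__(self, **hparams):
        self.hparams = hparams

    def as_matrix(self):
        # One [name, value] string row per hyperparameter; tf.convert_to_tensor
        # turns this into the 2-D string tensor tf.summary.text expects.
        return [[k, str(v)] for k, v in sorted(self.hparams.items())]


if __name__ == '__main__':
    hparams = dict(RANDOM_SEED=0, LOG_DIR='./log/cifar10', TRAIN_NUM=100000,
                   BATCH_SIZE=128, LEARNING_RATE=2e-4, DECAY_VAL=0.5,
                   DECAY_STEPS=50000, DECAY_STAIRCASE=False, BETA=0.25,
                   K=512, D=64, SAVE_PERIOD=10000, SUMMARY_PERIOD=100)
    main(TrainConfig(**hparams), **hparams)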
Example #2
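The same training script adapted to MNIST: the queue pipeline is replaced by a placeholder fed via feed_dict, and the 28x28 inputs are resized to 24x24.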
import numpy as np
import tensorflow as tf
from six.moves import xrange  # Python 2/3 compatibility for the training loop
from tqdm import tqdm

# VQVAE and _mnist_arch are assumed to come from the surrounding project:
# the model class and its MNIST encoder/decoder architecture.


def main(config, RANDOM_SEED, LOG_DIR, TRAIN_NUM, BATCH_SIZE, LEARNING_RATE,
         DECAY_VAL, DECAY_STEPS, DECAY_STAIRCASE, BETA, K, D, SAVE_PERIOD,
         SUMMARY_PERIOD, **kwargs):
    np.random.seed(RANDOM_SEED)
    tf.set_random_seed(RANDOM_SEED)

    # >>>>>>> DATASET
    from tensorflow.examples.tutorials.mnist import input_data
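    # read_data_sets downloads MNIST into datasets/mnist on first use; unlike
    # Example #1, batches are pulled with next_batch and fed through a
    # placeholder rather than an input queue.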
    mnist = input_data.read_data_sets("datasets/mnist", one_hot=False)
    # <<<<<<<

    # >>>>>>> MODEL
    x = tf.placeholder(tf.float32, [None, 784])
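    # Flat 784-vectors are reshaped to 28x28 images and bilinearly resized to
    # 24x24, presumably the input size _mnist_arch expects.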
    resized = tf.image.resize_images(tf.reshape(x, [-1, 28, 28, 1]), (24, 24),
                                     method=tf.image.ResizeMethod.BILINEAR)

    with tf.variable_scope('train'):
        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(LEARNING_RATE,
                                                   global_step,
                                                   DECAY_STEPS,
                                                   DECAY_VAL,
                                                   staircase=DECAY_STAIRCASE)
        tf.summary.scalar('lr', learning_rate)

        with tf.variable_scope('params') as params:
            pass
        net = VQVAE(learning_rate, global_step, BETA, resized, K, D,
                    _mnist_arch, params, True)

    with tf.variable_scope('valid'):
        params.reuse_variables()
        valid_net = VQVAE(None, None, BETA, resized, K, D, _mnist_arch, params,
                          False)

    with tf.variable_scope('misc'):
        # Summary Operations
        tf.summary.scalar('loss', net.loss)
        tf.summary.scalar('recon', net.recon)
        tf.summary.scalar('vq', net.vq)
        tf.summary.scalar('commit', BETA * net.commit)
        tf.summary.image('origin', resized, max_outputs=4)
        tf.summary.image('recon', net.p_x_z, max_outputs=4)
        # TODO: log-likelihood

        summary_op = tf.summary.merge_all()

        # Initialize op
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        config_summary = tf.summary.text('TrainConfig',
                                         tf.convert_to_tensor(
                                             config.as_matrix()),
                                         collections=[])

        extended_summary_op = tf.summary.merge([
            tf.summary.scalar('valid_loss', valid_net.loss),
            tf.summary.scalar('valid_recon', valid_net.recon),
            tf.summary.scalar('valid_vq', valid_net.vq),
            tf.summary.scalar('valid_commit', BETA * valid_net.commit),
            tf.summary.image('valid_recon', valid_net.p_x_z, max_outputs=10),
        ])

    # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Run!
    sess_config = tf.ConfigProto()  # renamed to avoid shadowing the `config` argument
    sess_config.gpu_options.allow_growth = True
    sess = tf.Session(config=sess_config)
    sess.graph.finalize()
    sess.run(init_op)

    summary_writer = tf.summary.FileWriter(LOG_DIR, sess.graph)
    summary_writer.add_summary(config_summary.eval(session=sess))

    for step in tqdm(xrange(TRAIN_NUM), dynamic_ncols=True):
        batch_xs, _ = mnist.train.next_batch(BATCH_SIZE)
        it, loss, _ = sess.run([global_step, net.loss, net.train_op],
                               feed_dict={x: batch_xs})

        if (it % SAVE_PERIOD == 0):
            net.save(sess, LOG_DIR, step=it)

        if (it % SUMMARY_PERIOD == 0):
            tqdm.write('[%5d] Loss: %1.3f' % (it, loss))
            summary = sess.run(summary_op, feed_dict={x: batch_xs})
            summary_writer.add_summary(summary, it)

        if (it % (SUMMARY_PERIOD * 2) == 0):  # Extended summary
            batch_xs, _ = mnist.test.next_batch(BATCH_SIZE)
            summary = sess.run(extended_summary_op, feed_dict={x: batch_xs})
            summary_writer.add_summary(summary, it)

    net.save(sess, LOG_DIR)