Example #1
import os

import numpy as np
import tensorflow as tf
from tqdm import tqdm

# GumbelVAE, PixelCNN, and _mnist_arch are definitions local to this project.

def train_prior(config,
                RANDOM_SEED,
                MODEL,
                TRAIN_NUM,
                BATCH_SIZE,
                LEARNING_RATE,
                DECAY_VAL,
                DECAY_STEPS,
                DECAY_STAIRCASE,
                GRAD_CLIP,
                K,
                D,
                BETA,
                NUM_LAYERS,
                NUM_FEATURE_MAPS,
                SUMMARY_PERIOD,
                SAVE_PERIOD,
                **kwargs):
    np.random.seed(RANDOM_SEED)
    tf.set_random_seed(RANDOM_SEED)
    LOG_DIR = os.path.join(os.path.dirname(MODEL),'pixelcnn')
    # >>>>>>> DATASET
    class Latents():
        def __init__(self,path,validation_size=5000):
            from tensorflow.contrib.learn.python.learn.datasets.mnist import DataSet
            from tensorflow.contrib.learn.python.learn.datasets import base

            data = np.load(path)
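            # 'ks' holds the discrete latent codes extracted by the trained
            # VAE (as saved in ks_ys.npz); 'ys' holds the matching class labels.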
            train = DataSet(data['ks'][validation_size:], data['ys'][validation_size:],reshape=False,dtype=np.uint8,one_hot=False) # the dtype argument is harmless here even when the latents are int32.
            validation = DataSet(data['ks'][:validation_size], data['ys'][:validation_size],reshape=False,dtype=np.uint8,one_hot=False)
            #test = DataSet(data['test_x'],np.argmax(data['test_y'],axis=1),reshape=False,dtype=np.float32,one_hot=False)
            self.size = data['ks'].shape[1]
            self.data = base.Datasets(train=train, validation=validation, test=None)
    latent = Latents(os.path.join(os.path.dirname(MODEL),'ks_ys.npz'))
    # <<<<<<<

    # >>>>>>> MODEL for Generate Images
    with tf.variable_scope('net'):
        with tf.variable_scope('params') as params:
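            # Deliberately empty: this only reserves the 'params' variable
            # scope handle that is handed to the VAE below, so its pretrained
            # weights can be restored into it via vq_net.load().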
            pass
        _not_used = tf.placeholder(tf.float32,[None,24,24,1])
        tau_notused = 0.5
        vq_net = GumbelVAE(tau_notused,None,None,BETA,_not_used,K,D,_mnist_arch,params,'decode')
    # <<<<<<<

    # >>>>>> MODEL for Training Prior
    with tf.variable_scope('pixelcnn'):
        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(LEARNING_RATE, global_step, DECAY_STEPS, DECAY_VAL, staircase=DECAY_STAIRCASE)
        tf.summary.scalar('lr',learning_rate)

        net = PixelCNN(learning_rate,global_step,GRAD_CLIP,
                       latent.size,vq_net.embeds,K,D,
                       10,NUM_LAYERS,NUM_FEATURE_MAPS)
    # <<<<<<
    with tf.variable_scope('misc'):
        # Summary Operations
        tf.summary.scalar('loss',net.loss)
        summary_op = tf.summary.merge_all()

        # Initialize op
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        config_summary = tf.summary.text('TrainConfig', tf.convert_to_tensor(config.as_matrix()), collections=[])

        sample_images = tf.placeholder(tf.float32,[None,24,24,1])
        sample_summary_op = tf.summary.image('samples',sample_images,max_outputs=20)

    # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Run!
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess = tf.Session(config=sess_config)
    #sess.graph.finalize()
    sess.run(init_op)
    vq_net.load(sess,MODEL)

    summary_writer = tf.summary.FileWriter(LOG_DIR,sess.graph)
    summary_writer.add_summary(config_summary.eval(session=sess))

    for step in tqdm(range(TRAIN_NUM),dynamic_ncols=True):
        batch_xs, batch_ys = latent.data.train.next_batch(BATCH_SIZE)
        it,loss,_ = sess.run([global_step,net.loss,net.train_op],feed_dict={net.X:batch_xs,net.h:batch_ys})

        if( it % SAVE_PERIOD == 0 ):
            net.save(sess,LOG_DIR,step=it)

        if( it % SUMMARY_PERIOD == 0 ):
            tqdm.write('[%5d] Loss: %1.3f'%(it,loss))
            summary = sess.run(summary_op,feed_dict={net.X:batch_xs,net.h:batch_ys})
            summary_writer.add_summary(summary,it)

        if( it % (SUMMARY_PERIOD * 2) == 0 ):
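            # Sample class-conditioned latents from the PixelCNN prior (one
            # label per digit class 0-9) and decode them with the frozen VAE
            # decoder for the image summary.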
            sampled_zs,log_probs = net.sample_from_prior(sess,np.arange(10),2)
            sampled_ims = sess.run(vq_net.gen,feed_dict={vq_net.latent:sampled_zs})
            summary_writer.add_summary(
                sess.run(sample_summary_op,feed_dict={sample_images:sampled_ims}),it)

    net.save(sess,LOG_DIR)
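
Note that config is only used to emit the TrainConfig text summary via config.as_matrix(); the uppercase hyperparameters arrive as ordinary keyword arguments. A minimal, hypothetical driver might look like the sketch below. The HParams class and every value in it are illustrative assumptions, not the project's actual configuration:

class HParams(object):
    # Hypothetical stand-in for the project's config object: train_prior only
    # needs an as_matrix() method returning (name, value) string rows for the
    # TrainConfig text summary.
    def __init__(self, **kv):
        self.__dict__.update(kv)

    def as_matrix(self):
        return [[k, str(v)] for k, v in sorted(self.__dict__.items())]

# All values below are illustrative, not the project's published settings.
hparams = dict(
    RANDOM_SEED=0, MODEL='results/gumbel/model.ckpt', TRAIN_NUM=150000,
    BATCH_SIZE=32, LEARNING_RATE=0.0002, DECAY_VAL=0.5, DECAY_STEPS=100000,
    DECAY_STAIRCASE=False, GRAD_CLIP=5.0, K=10, D=10, BETA=1.0,
    NUM_LAYERS=12, NUM_FEATURE_MAPS=32, SUMMARY_PERIOD=100, SAVE_PERIOD=10000)
train_prior(HParams(**hparams), **hparams)

The function then looks for ks_ys.npz next to the checkpoint (here, under results/gumbel/) and writes its own logs to a pixelcnn/ subdirectory.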
Example #2
import os

import numpy as np
import tensorflow as tf
from tqdm import tqdm

# VQVAE, PixelCNN, _imagenet_arch, imagenet, and _build_batch are definitions
# local to this project.

def train_prior(config, RANDOM_SEED, MODEL, TRAIN_NUM, BATCH_SIZE,
                LEARNING_RATE, DECAY_VAL, DECAY_STEPS, DECAY_STAIRCASE,
                GRAD_CLIP, K, D, BETA, NUM_LAYERS, NUM_FEATURE_MAPS,
                SUMMARY_PERIOD, SAVE_PERIOD, **kwargs):
    np.random.seed(RANDOM_SEED)
    tf.set_random_seed(RANDOM_SEED)
    LOG_DIR = os.path.join(os.path.dirname(MODEL), 'pixelcnn')
    # >>>>>>> DATASET
    train_dataset = imagenet.get_split('train', 'datasets/ILSVRC2012')
    ims, labels = _build_batch(train_dataset, BATCH_SIZE, 4)
    # <<<<<<<

    # >>>>>>> MODEL for Generate Images
    with tf.variable_scope('net'):
        with tf.variable_scope('params') as params:
            pass
        vq_net = VQVAE(None, None, BETA, ims, K, D, _imagenet_arch, params,
                       False)
    # <<<<<<<

    # >>>>>> MODEL for Training Prior
    with tf.variable_scope('pixelcnn'):
        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(LEARNING_RATE,
                                                   global_step,
                                                   DECAY_STEPS,
                                                   DECAY_VAL,
                                                   staircase=DECAY_STAIRCASE)
        tf.summary.scalar('lr', learning_rate)

        net = PixelCNN(learning_rate, global_step, GRAD_CLIP,
                       vq_net.k.get_shape()[1], vq_net.embeds, K, D, 1000,
                       NUM_LAYERS, NUM_FEATURE_MAPS)
    # <<<<<<
    with tf.variable_scope('misc'):
        # Summary Operations
        tf.summary.scalar('loss', net.loss)
        summary_op = tf.summary.merge_all()

        # Initialize op
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        config_summary = tf.summary.text('TrainConfig',
                                         tf.convert_to_tensor(
                                             config.as_matrix()),
                                         collections=[])

        sample_images = tf.placeholder(tf.float32, [None, 128, 128, 3])
        sample_summary_op = tf.summary.image('samples',
                                             sample_images,
                                             max_outputs=20)

    # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Run!
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess = tf.Session(config=sess_config)
    sess.graph.finalize()
    sess.run(init_op)
    vq_net.load(sess, MODEL)

    summary_writer = tf.summary.FileWriter(LOG_DIR, sess.graph)
    summary_writer.add_summary(config_summary.eval(session=sess))

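    # TF1 queue-based input pipeline: queue runners keep the ImageNet batches
    # flowing while the Coordinator handles a clean shutdown of the threads.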
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)
    try:
        for step in tqdm(range(TRAIN_NUM), dynamic_ncols=True):
            batch_xs, batch_ys = sess.run([vq_net.k, labels])
            it, loss, _ = sess.run([global_step, net.loss, net.train_op],
                                   feed_dict={
                                       net.X: batch_xs,
                                       net.h: batch_ys
                                   })

            if (it % SAVE_PERIOD == 0):
                net.save(sess, LOG_DIR, step=it)
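                # Sample latents from the prior for 10 random ImageNet classes
                # and decode them with the frozen VQ-VAE decoder for the image
                # summary.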
                sampled_zs, log_probs = net.sample_from_prior(
                    sess, np.random.randint(0, 1000, size=(10, )), 2)
                sampled_ims = sess.run(vq_net.gen,
                                       feed_dict={vq_net.latent: sampled_zs})
                summary_writer.add_summary(
                    sess.run(sample_summary_op,
                             feed_dict={sample_images: sampled_ims}), it)

            if (it % SUMMARY_PERIOD == 0):
                tqdm.write('[%5d] Loss: %1.3f' % (it, loss))
                summary = sess.run(summary_op,
                                   feed_dict={
                                       net.X: batch_xs,
                                       net.h: batch_ys
                                   })
                summary_writer.add_summary(summary, it)

    except Exception as e:
        coord.request_stop(e)
    finally:
        net.save(sess, LOG_DIR)

        coord.request_stop()
        coord.join(threads)
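
_build_batch is not shown in this listing. Given the queue runners started above and the [None, 128, 128, 3] sample_images placeholder, it presumably wraps the slim dataset in a DatasetDataProvider and batches with TF1's shuffle queues. The sketch below is a guess at that helper under those assumptions, not the project's actual code; in particular the 128x128 resize and the /255 scaling convention are assumptions:

import tensorflow as tf
slim = tf.contrib.slim

def _build_batch(dataset, batch_size, num_threads):
    # Guessed implementation: stream (image, label) pairs off disk, resize to
    # the 128x128 resolution the decoder expects, and batch via shuffle queues.
    provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset, num_readers=num_threads, shuffle=True)
    image, label = provider.get(['image', 'label'])
    image = tf.image.resize_images(image, [128, 128])
    image = tf.cast(image, tf.float32) / 255.0  # scaling convention assumed
    images, labels = tf.train.shuffle_batch(
        [image, label], batch_size=batch_size, num_threads=num_threads,
        capacity=batch_size * 10, min_after_dequeue=batch_size * 2)
    return images, labels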