# Assumed module-level imports (defined elsewhere in this repository):
# os, numpy as np, tensorflow as tf, tqdm, xrange (six.moves on Python 3),
# plus the model/arch definitions GumbelVAE, VQVAE, PixelCNN, _mnist_arch,
# _imagenet_arch, imagenet, and _build_batch.

def train_prior(config,
                RANDOM_SEED,
                MODEL,
                TRAIN_NUM,
                BATCH_SIZE,
                LEARNING_RATE, DECAY_VAL, DECAY_STEPS, DECAY_STAIRCASE, GRAD_CLIP,
                K, D, BETA,
                NUM_LAYERS, NUM_FEATURE_MAPS,
                SUMMARY_PERIOD, SAVE_PERIOD,
                **kwargs):
    """Train a class-conditional PixelCNN prior over the discrete MNIST
    latents dumped by a trained Gumbel-VAE (codes and labels are loaded
    from the ks_ys.npz file next to MODEL)."""
    np.random.seed(RANDOM_SEED)
    tf.set_random_seed(RANDOM_SEED)
    LOG_DIR = os.path.join(os.path.dirname(MODEL),'pixelcnn')

    # >>>>>>> DATASET
    class Latents():
        def __init__(self,path,validation_size=5000):
            from tensorflow.contrib.learn.python.learn.datasets.mnist import DataSet
            from tensorflow.contrib.learn.python.learn.datasets import base
            data = np.load(path)
            # dtype=np.uint8 keeps DataSet from rescaling the data; it is
            # harmless even when the stored latents are int32.
            train = DataSet(data['ks'][validation_size:], data['ys'][validation_size:],
                            reshape=False,dtype=np.uint8,one_hot=False)
            validation = DataSet(data['ks'][:validation_size], data['ys'][:validation_size],
                                 reshape=False,dtype=np.uint8,one_hot=False)
            self.size = data['ks'].shape[1]
            self.data = base.Datasets(train=train, validation=validation, test=None)
    latent = Latents(os.path.join(os.path.dirname(MODEL),'ks_ys.npz'))
    # <<<<<<<

    # >>>>>>> MODEL for Generate Images
    with tf.variable_scope('net'):
        with tf.variable_scope('params') as params:
            pass
        # Only the decoder is used here; the input placeholder and the
        # temperature are never fed.
        _not_used = tf.placeholder(tf.float32,[None,24,24,1])
        tau_notused = 0.5
        vq_net = GumbelVAE(tau_notused,None,None,BETA,_not_used,K,D,_mnist_arch,params,'decode')
    # <<<<<<<

    # >>>>>> MODEL for Training Prior
    with tf.variable_scope('pixelcnn'):
        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(
            LEARNING_RATE, global_step, DECAY_STEPS, DECAY_VAL, staircase=DECAY_STAIRCASE)
        tf.summary.scalar('lr',learning_rate)
        net = PixelCNN(learning_rate,global_step,GRAD_CLIP,
                       latent.size,vq_net.embeds,K,D,
                       10,NUM_LAYERS,NUM_FEATURE_MAPS)  # 10 MNIST classes
    # <<<<<<

    with tf.variable_scope('misc'):
        # Summary Operations
        tf.summary.scalar('loss',net.loss)
        summary_op = tf.summary.merge_all()

        # Initialize op
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        config_summary = tf.summary.text('TrainConfig',
                                         tf.convert_to_tensor(config.as_matrix()),
                                         collections=[])

        sample_images = tf.placeholder(tf.float32,[None,24,24,1])
        sample_summary_op = tf.summary.image('samples',sample_images,max_outputs=20)

    # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Run!
    config = tf.ConfigProto()  # note: rebinds `config` (the hyperparameter table)
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(init_op)
    vq_net.load(sess,MODEL)

    summary_writer = tf.summary.FileWriter(LOG_DIR,sess.graph)
    summary_writer.add_summary(config_summary.eval(session=sess))

    for step in tqdm(xrange(TRAIN_NUM),dynamic_ncols=True):
        batch_xs, batch_ys = latent.data.train.next_batch(BATCH_SIZE)
        it,loss,_ = sess.run([global_step,net.loss,net.train_op],
                             feed_dict={net.X:batch_xs,net.h:batch_ys})

        if( it % SAVE_PERIOD == 0 ):
            net.save(sess,LOG_DIR,step=it)
        if( it % SUMMARY_PERIOD == 0 ):
            tqdm.write('[%5d] Loss: %1.3f'%(it,loss))
            summary = sess.run(summary_op,feed_dict={net.X:batch_xs,net.h:batch_ys})
            summary_writer.add_summary(summary,it)
        if( it % (SUMMARY_PERIOD * 2) == 0 ):
            # Sample two latent grids per class label, decode them, and log.
            sampled_zs,log_probs = net.sample_from_prior(sess,np.arange(10),2)
            sampled_ims = sess.run(vq_net.gen,feed_dict={vq_net.latent:sampled_zs})
            summary_writer.add_summary(
                sess.run(sample_summary_op,feed_dict={sample_images:sampled_ims}),it)

    net.save(sess,LOG_DIR)
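# ------------------------------------------------------------------
# Hypothetical sketch, not part of the original code: the Latents class
# above only requires that ks_ys.npz contain an integer array 'ks' of
# shape (N, latent_size) and a label array 'ys' of shape (N,). A dump
# helper along these lines would produce such a file; `encode_fn` is an
# assumed callable mapping one image to its discrete code vector.
def _dump_latents(encode_fn, images, labels, path):
    import numpy as np  # redundant with the module imports; kept so the sketch stands alone
    ks = np.stack([encode_fn(im) for im in images]).astype(np.int32)  # (N, latent_size)
    np.savez(path, ks=ks, ys=np.asarray(labels))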
def train_prior(config,
                RANDOM_SEED,
                MODEL,
                TRAIN_NUM,
                BATCH_SIZE,
                LEARNING_RATE, DECAY_VAL, DECAY_STEPS, DECAY_STAIRCASE, GRAD_CLIP,
                K, D, BETA,
                NUM_LAYERS, NUM_FEATURE_MAPS,
                SUMMARY_PERIOD, SAVE_PERIOD,
                **kwargs):
    """Train a class-conditional PixelCNN prior over the discrete latents of
    a trained VQ-VAE on ImageNet; latents are encoded on the fly from the
    input queue rather than loaded from disk."""
    np.random.seed(RANDOM_SEED)
    tf.set_random_seed(RANDOM_SEED)
    LOG_DIR = os.path.join(os.path.dirname(MODEL), 'pixelcnn')

    # >>>>>>> DATASET
    train_dataset = imagenet.get_split('train', 'datasets/ILSVRC2012')
    ims, labels = _build_batch(train_dataset, BATCH_SIZE, 4)
    # <<<<<<<

    # >>>>>>> MODEL for Generate Images
    with tf.variable_scope('net'):
        with tf.variable_scope('params') as params:
            pass
        vq_net = VQVAE(None, None, BETA, ims, K, D, _imagenet_arch, params, False)
    # <<<<<<<

    # >>>>>> MODEL for Training Prior
    with tf.variable_scope('pixelcnn'):
        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(
            LEARNING_RATE, global_step, DECAY_STEPS, DECAY_VAL, staircase=DECAY_STAIRCASE)
        tf.summary.scalar('lr', learning_rate)
        net = PixelCNN(learning_rate, global_step, GRAD_CLIP,
                       vq_net.k.get_shape()[1], vq_net.embeds, K, D,
                       1000, NUM_LAYERS, NUM_FEATURE_MAPS)  # 1000 ImageNet classes
    # <<<<<<

    with tf.variable_scope('misc'):
        # Summary Operations
        tf.summary.scalar('loss', net.loss)
        summary_op = tf.summary.merge_all()

        # Initialize op
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        config_summary = tf.summary.text('TrainConfig',
                                         tf.convert_to_tensor(config.as_matrix()),
                                         collections=[])

        sample_images = tf.placeholder(tf.float32, [None, 128, 128, 3])
        sample_summary_op = tf.summary.image('samples', sample_images, max_outputs=20)

    # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Run!
    config = tf.ConfigProto()  # note: rebinds `config` (the hyperparameter table)
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.graph.finalize()  # catch accidental graph growth inside the loop
    sess.run(init_op)
    vq_net.load(sess, MODEL)

    summary_writer = tf.summary.FileWriter(LOG_DIR, sess.graph)
    summary_writer.add_summary(config_summary.eval(session=sess))

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)
    try:
        for step in tqdm(xrange(TRAIN_NUM), dynamic_ncols=True):
            # Encode a fresh image batch to discrete codes, then train on them.
            batch_xs, batch_ys = sess.run([vq_net.k, labels])
            it, loss, _ = sess.run([global_step, net.loss, net.train_op],
                                   feed_dict={net.X: batch_xs, net.h: batch_ys})

            if (it % SAVE_PERIOD == 0):
                net.save(sess, LOG_DIR, step=it)
                # Sample two latent grids for each of ten random class labels,
                # decode them, and log the images.
                sampled_zs, log_probs = net.sample_from_prior(
                    sess, np.random.randint(0, 1000, size=(10,)), 2)
                sampled_ims = sess.run(vq_net.gen,
                                       feed_dict={vq_net.latent: sampled_zs})
                summary_writer.add_summary(
                    sess.run(sample_summary_op,
                             feed_dict={sample_images: sampled_ims}), it)
            if (it % SUMMARY_PERIOD == 0):
                tqdm.write('[%5d] Loss: %1.3f' % (it, loss))
                summary = sess.run(summary_op,
                                   feed_dict={net.X: batch_xs, net.h: batch_ys})
                summary_writer.add_summary(summary, it)
    except Exception as e:
        coord.request_stop(e)
    finally:
        net.save(sess, LOG_DIR)
        coord.request_stop()
        coord.join(threads)
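# ------------------------------------------------------------------
# Hypothetical invocation sketch, not part of the original code. The only
# constraint the functions above place on `config` is that it expose an
# as_matrix() method yielding something tf.convert_to_tensor can turn into
# a string tensor for tf.summary.text; an old-style pandas DataFrame of
# stringified (name, value) rows satisfies that. All hyperparameter values
# below are made-up placeholders, not the repository's settings.
if __name__ == '__main__':
    import pandas as pd
    params = dict(
        RANDOM_SEED=0,
        MODEL='models/imagenet/last.ckpt',  # hypothetical checkpoint path
        TRAIN_NUM=300000, BATCH_SIZE=32,
        LEARNING_RATE=1e-4, DECAY_VAL=0.5, DECAY_STEPS=100000, DECAY_STAIRCASE=False,
        GRAD_CLIP=1.0,
        K=512, D=128, BETA=0.25,
        NUM_LAYERS=12, NUM_FEATURE_MAPS=64,
        SUMMARY_PERIOD=100, SAVE_PERIOD=10000,
    )
    # Stringify values so the resulting matrix converts to a tf.string tensor.
    config = pd.DataFrame([(k, str(v)) for k, v in params.items()],
                          columns=['name', 'value'])
    train_prior(config, **params)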