def p_net(observed=None, n_z=None, is_training=True):
    logging.info('p_net builder: %r', locals())
    net = BayesianNet(observed=observed)

    # sample z ~ p(z)
    z = net.add('z',
                Normal(mean=tf.zeros([1, config.z_dim]),
                       logstd=tf.zeros([1, config.z_dim])),
                group_ndims=1,
                n_samples=n_z)

    # compute the hidden features
    with arg_scope([dense],
                   activation_fn=tf.nn.leaky_relu,
                   kernel_regularizer=l2_regularizer(config.l2_reg)):
        h_x, s1, s2 = flatten(z, 2)
        h_x = dense(h_x, 500)
        h_x = dense(h_x, 500)

    # sample x ~ p(x|z)
    x_logits = unflatten(dense(h_x, config.x_dim, name='x_logits'), s1, s2)
    x = net.add('x', Bernoulli(logits=x_logits), group_ndims=1)

    return net
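
# A usage sketch (not part of the original example): draw samples from the
# prior and decode them into Bernoulli pixel means, the same construction as
# the `x_plots` tensor in `main()` below.  `p_net` and `config` are assumed
# to be the definitions in this file; the helper name is hypothetical.
def sample_prior_images(n=100):
    net = p_net(n_z=n)
    x_means = tf.sigmoid(net['x'].distribution.logits)  # (n, 1, x_dim)
    return tf.reshape(tf.cast(255 * x_means, tf.uint8), [-1, 28, 28])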
def p_net(config, observed=None, n_z=None, is_training=True):
    net = BayesianNet(observed=observed)

    # sample z ~ p(z)
    z = net.add('z',
                Bernoulli(tf.zeros([1, config.z_dim])),
                group_ndims=1,
                n_samples=n_z)

    # compute the hidden features
    with arg_scope([dense],
                   activation_fn=tf.nn.leaky_relu,
                   kernel_regularizer=l2_regularizer(config.l2_reg)):
        # the sampled z is integral; cast to float before the dense layers
        z = tf.to_float(z)
        h_z, s1, s2 = flatten(z, 2)
        h_z = dense(h_z, 500)
        h_z = dense(h_z, 500)

    # sample x ~ p(x|z)
    x_logits = unflatten(dense(h_z, config.x_dim, name='x_logits'), s1, s2)
    x = net.add('x', Bernoulli(logits=x_logits), group_ndims=1)

    return net
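
# Because `z` above is discrete, the ELBO cannot be reparameterized, which is
# why `main()` trains with the score-function (REINFORCE) estimator.  A
# minimal sketch of a REINFORCE-style surrogate cost (names hypothetical;
# tfsnippet's actual `vi.training.reinforce` implementation may differ):
# `log_joint` and `log_q` are per-sample log p(x,z) and log q(z|x) tensors,
# and `baseline` is an optional variance-reduction baseline.
def reinforce_cost_sketch(log_joint, log_q, baseline=None):
    f = log_joint - log_q  # per-sample ELBO estimate
    advantage = tf.stop_gradient(f if baseline is None else f - baseline)
    # minimizing this cost matches -dE[f] in expectation; the
    # `advantage * log_q` term carries the score-function gradient
    return -(f + advantage * log_q)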
def p_net(observed=None, n_y=None, n_z=None, tau=None, is_training=True,
          n_samples=None):
    if n_samples is not None:
        warnings.warn('`n_samples` is deprecated, use `n_y` instead.')
        n_y = n_samples
    use_concrete = config.use_concrete_distribution and tau is not None
    logging.info('p_net builder: %r', locals())
    net = BayesianNet(observed=observed)

    # sample y
    if use_concrete:
        y = net.add('y',
                    ExpConcrete(tau, tf.zeros([1, config.n_clusters])),
                    n_samples=n_y,
                    is_reparameterized=True)
    else:
        y = net.add('y',
                    Categorical(tf.zeros([1, config.n_clusters])),
                    n_samples=n_y)

    # sample z ~ p(z|y)
    z = net.add('z',
                gaussian_mixture_prior(y, config.z_dim, config.n_clusters,
                                       use_concrete=use_concrete),
                group_ndims=1,
                n_samples=n_z,
                is_reparameterized=use_concrete)

    # compute the hidden features for x
    with arg_scope([dense],
                   activation_fn=tf.nn.leaky_relu,
                   kernel_regularizer=l2_regularizer(config.l2_reg)):
        h_x, s1, s2 = flatten(z, 2)
        h_x = dense(h_x, 500)
        h_x = dense(h_x, 500)

    # sample x ~ p(x|z)
    x_logits = unflatten(dense(h_x, config.x_dim, name='x_logits'), s1, s2)
    x = net.add('x', Bernoulli(logits=x_logits), group_ndims=1)

    return net
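
# The Concrete (Gumbel-softmax) relaxation above is controlled by the
# temperature `tau`, which is typically annealed towards a small value as
# training progresses.  A hypothetical schedule sketch (the constants are
# illustrative assumptions, not values from this repo):
def tau_schedule(global_step, tau0=1.0, tau_min=0.5, anneal_rate=1e-4):
    step = tf.to_float(global_step)
    return tf.maximum(tau_min, tau0 * tf.exp(-anneal_rate * step))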
def q_net(config, x, observed=None, n_z=None, is_training=True):
    net = BayesianNet(observed=observed)

    # compute the hidden features
    with arg_scope([dense],
                   activation_fn=tf.nn.leaky_relu,
                   kernel_regularizer=l2_regularizer(config.l2_reg)):
        h_x = tf.to_float(x)
        h_x = dense(h_x, 500)
        h_x = dense(h_x, 500)

    # sample z ~ q(z|x)
    z_logits = dense(h_x, config.z_dim, name='z_logits')
    z = net.add('z', Bernoulli(logits=z_logits), n_samples=n_z, group_ndims=1)

    return net
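
# A hand-assembled single-sample ELBO sketch (not part of the original file)
# for this `q_net` paired with the Bernoulli-latent `p_net(config, ...)`
# defined earlier, assuming both live in the same module and assuming the
# tfsnippet `Distribution.log_prob(given, group_ndims=...)` signature.  Note
# the encoder gradient of this estimator still needs REINFORCE, since the
# Bernoulli samples are not reparameterizable.
def elbo_sketch(config, x):
    qnet = q_net(config, x)
    z = qnet['z']
    pnet = p_net(config, observed={'z': z, 'x': x})
    log_px_z = pnet['x'].distribution.log_prob(x, group_ndims=1)
    log_pz = pnet['z'].distribution.log_prob(z, group_ndims=1)
    log_qz_x = qnet['z'].distribution.log_prob(z, group_ndims=1)
    return tf.reduce_mean(log_px_z + log_pz - log_qz_x)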
def p_net(config, observed=None, n_z=None, is_training=True,
          channels_last=False):
    net = BayesianNet(observed=observed)

    # sample z ~ p(z)
    z = net.add('z',
                Normal(mean=tf.zeros([1, config.z_dim]),
                       logstd=tf.zeros([1, config.z_dim])),
                group_ndims=1,
                n_samples=n_z)

    # compute the hidden features
    with arg_scope([deconv_resnet_block],
                   shortcut_kernel_size=config.shortcut_kernel_size,
                   activation_fn=tf.nn.leaky_relu,
                   kernel_regularizer=l2_regularizer(config.l2_reg),
                   channels_last=channels_last):
        h_z, s1, s2 = flatten(z, 2)
        h_z = tf.reshape(dense(h_z, 64 * 7 * 7),
                         [-1, 7, 7, 64] if channels_last else [-1, 64, 7, 7])
        h_z = deconv_resnet_block(h_z, 64)             # output: (64, 7, 7)
        h_z = deconv_resnet_block(h_z, 32, strides=2)  # output: (32, 14, 14)
        h_z = deconv_resnet_block(h_z, 32)             # output: (32, 14, 14)
        h_z = deconv_resnet_block(h_z, 16, strides=2)  # output: (16, 28, 28)
        h_z = conv2d(h_z, 1, (1, 1), padding='same',
                     name='feature_map_to_pixel',
                     channels_last=channels_last)      # output: (1, 28, 28)
        h_z = tf.reshape(h_z, [-1, config.x_dim])

    # sample x ~ p(x|z)
    x_logits = unflatten(h_z, s1, s2)
    x = net.add('x', Bernoulli(logits=x_logits), group_ndims=1)

    return net
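
# Shape-handling note (a sketch, not from the original file): `flatten(z, 2)`
# merges the leading sampling/batch axes of `z` into one so the dense / conv
# ops see a fixed-rank input, and `unflatten(h, s1, s2)` restores them.
# Roughly, for a rank-3 `z` of shape (n_z, batch, z_dim):
#
#     h = tf.reshape(z, [-1, z_dim])          # (n_z * batch, z_dim)
#     ...                                     # network body
#     h = tf.reshape(h, [n_z, batch, x_dim])  # restore the sample axis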
def main():
    # load mnist data
    (x_train, y_train), (x_test, y_test) = \
        load_mnist(shape=[config.x_dim], dtype=np.float32, normalize=True)

    # input placeholders
    input_x = tf.placeholder(
        dtype=tf.int32, shape=(None,) + x_train.shape[1:], name='input_x')
    is_training = tf.placeholder(dtype=tf.bool, shape=(), name='is_training')
    learning_rate = tf.placeholder(shape=(), dtype=tf.float32)
    learning_rate_var = AnnealingDynamicValue(config.initial_lr,
                                              config.lr_anneal_factor)
    multi_gpu = MultiGPU(disable_prebuild=False)

    # build the model
    vae = VAE(
        p_z=Bernoulli(tf.zeros([1, config.z_dim])),
        p_x_given_z=Bernoulli,
        q_z_given_x=Bernoulli,
        h_for_p_x=functools.partial(h_for_p_x, is_training=is_training),
        h_for_q_z=functools.partial(h_for_q_z, is_training=is_training),
    )

    grads = []
    losses = []
    lower_bounds = []
    test_nlls = []
    batch_size = get_batch_size(input_x)
    params = None
    optimizer = tf.train.AdamOptimizer(learning_rate)

    for dev, pre_build, [dev_input_x] in multi_gpu.data_parallel(
            batch_size, [input_x]):
        with tf.device(dev), multi_gpu.maybe_name_scope(dev):
            if pre_build:
                with arg_scope([h_for_q_z, h_for_p_x]):
                    _ = vae.chain(dev_input_x)
            else:
                # derive the loss and lower-bound for training
                train_chain = vae.chain(dev_input_x)
                dev_baseline = baseline_net(dev_input_x)
                dev_cost, dev_baseline_cost = \
                    train_chain.vi.training.reinforce(baseline=dev_baseline)
                dev_loss = regularization_loss() + \
                    tf.reduce_mean(dev_cost + dev_baseline_cost)
                dev_lower_bound = \
                    tf.reduce_mean(train_chain.vi.lower_bound.elbo())
                losses.append(dev_loss)
                lower_bounds.append(dev_lower_bound)

                # derive the nll and logits output for testing
                test_chain = vae.chain(dev_input_x, n_z=config.test_n_z)
                dev_test_nll = -tf.reduce_mean(
                    test_chain.vi.evaluation.is_loglikelihood())
                test_nlls.append(dev_test_nll)

                # derive the optimizer
                params = tf.trainable_variables()
                grads.append(
                    optimizer.compute_gradients(dev_loss, var_list=params))

    # merge multi-gpu outputs and operations
    [loss, lower_bound, test_nll] = \
        multi_gpu.average([losses, lower_bounds, test_nlls], batch_size)
    train_op = multi_gpu.apply_grads(
        grads=multi_gpu.average_grads(grads),
        optimizer=optimizer,
        control_inputs=tf.get_collection(tf.GraphKeys.UPDATE_OPS))

    # derive the plotting function
    work_dev = multi_gpu.work_devices[0]
    with tf.device(work_dev), tf.name_scope('plot_x'), \
            arg_scope([h_for_q_z, h_for_p_x],
                      channels_last=multi_gpu.channels_last(work_dev)):
        x_plots = tf.reshape(
            tf.cast(255 * tf.sigmoid(
                        vae.model(n_z=100)['x'].distribution.logits),
                    dtype=tf.uint8),
            [-1, 28, 28])

    def plot_samples(loop):
        with loop.timeit('plot_time'):
            session = get_default_session_or_error()
            images = session.run(x_plots, feed_dict={is_training: False})
            save_images_collection(
                images=images,
                filename=results.prepare_parent(
                    'plotting/{}.png'.format(loop.epoch)),
                grid_size=(10, 10))

    # prepare for training and testing data
    def input_x_sampler(x):
        sess = get_default_session_or_error()
        return sess.run([sampled_x], feed_dict={sample_input_x: x})

    with tf.device('/device:CPU:0'):
        sample_input_x = tf.placeholder(
            dtype=tf.float32, shape=(None, config.x_dim),
            name='sample_input_x')
        sampled_x = sample_from_probs(sample_input_x)

    train_flow = DataFlow.arrays(
        [x_train], config.batch_size, shuffle=True,
        skip_incomplete=True).map(input_x_sampler)
    test_flow = DataFlow.arrays(
        [x_test], config.test_batch_size).map(input_x_sampler)

    with create_session().as_default():
        # fix the testing flow, reducing the testing time
        test_flow = test_flow.to_arrays_flow(
            batch_size=config.test_batch_size)

        # train the network
        with TrainLoop(params,
                       max_epoch=config.max_epoch,
                       summary_dir=results.make_dir('train_summary'),
                       summary_graph=tf.get_default_graph(),
                       early_stopping=False) as loop:
            trainer = Trainer(
                loop, train_op, [input_x], train_flow,
                feed_dict={learning_rate: learning_rate_var,
                           is_training: True},
                metrics={'loss': loss})
            anneal_after(trainer, learning_rate_var,
                         epochs=config.lr_anneal_epoch_freq,
                         steps=config.lr_anneal_step_freq)
            evaluator = Evaluator(
                loop,
                metrics={'test_nll': test_nll, 'test_lb': lower_bound},
                inputs=[input_x],
                data_flow=test_flow,
                feed_dict={is_training: False},
                time_metric_name='test_time')
            trainer.evaluate_after_epochs(evaluator, freq=10)
            trainer.evaluate_after_epochs(
                functools.partial(plot_samples, loop), freq=10)
            trainer.log_after_epochs(freq=1)
            trainer.run()

    # write the final test_nll and test_lb
    results.commit(evaluator.last_metrics_dict)
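
# standard script entry point (assumed; not shown in the excerpt above)
if __name__ == '__main__':
    main()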