def __init__(self,
             T,
             data_dim,
             energy_fn,
             q_fn,
             proposal=None,
             init_alpha=1.,
             init_step_size=0.01,
             learn_temps=False,
             learn_stepsize=False,
             dtype=tf.float32,
             name="his"):
  """Creates the variables, summaries and proposals of a HIS model.

  Args:
    T: Number of timesteps. One scalar `raw_alpha_<t>` variable is created
      per timestep.
    data_dim: Shape of the data. Used as the shape of the step size variable
      and passed to `base.get_independent_normal` for the default proposals.
    energy_fn: Energy function, stored as `self.energy_fn`.
    q_fn: Stored as `self.q`.
    proposal: Proposal distribution. If None, defaults to
      `base.get_independent_normal(data_dim)`.
    init_alpha: Initial value of every alpha. It is mapped through an inverse
      sigmoid, so `1 / init_alpha - 1 + 1e-4` must be positive (e.g. a value
      in (0, 1]).
    init_step_size: Initial step size. It is mapped through an inverse
      softplus, so it must be positive.
    learn_temps: Whether the raw alpha variables are trainable.
    learn_stepsize: Whether the raw step size variable is trainable.
    dtype: Type of the variables created here.
    name: Name of the variable scope the variables are created in.
  """
  self.timesteps = T
  self.data_dim = data_dim
  self.energy_fn = energy_fn
  self.q = q_fn
  # Inverse sigmoid of init_alpha; the 1e-4 keeps the log finite when
  # init_alpha == 1.
  init_alpha = -np.log(1. / init_alpha - 1. + 1e-4)
  with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
    self.raw_alphas = [
        tf.get_variable(
            name="raw_alpha_%d" % t,
            shape=[],
            dtype=dtype,
            initializer=tf.constant_initializer(init_alpha),
            trainable=learn_temps) for t in range(T)
    ]
    # -softplus(-x) == log(sigmoid(x)), so every log alpha is <= 0.
    self.log_alphas = [
        -tf.nn.softplus(-raw_alpha) for raw_alpha in self.raw_alphas
    ]
    # Prepend the negated sum so that the T + 1 log alphas sum to zero.
    self.log_alphas = [-tf.reduce_sum(self.log_alphas)] + self.log_alphas
    # Inverse softplus, so that softplus(raw_step_size) == init_step_size.
    init_step_size = np.log(np.exp(init_step_size) - 1.)
    # Use the requested dtype (previously hard-coded to tf.float32) so the
    # step size agrees with the alpha variables above.
    self.raw_step_size = tf.get_variable(
        name="raw_step_size",
        shape=data_dim,
        dtype=dtype,
        initializer=tf.constant_initializer(init_step_size),
        trainable=learn_stepsize)
    self.step_size = tf.math.softplus(self.raw_step_size)
    tf.summary.scalar("his_step_size", tf.reduce_mean(self.step_size))
    for t, log_alpha in enumerate(self.log_alphas):
      tf.summary.scalar("his_alpha/alpha_%d" % t, tf.exp(log_alpha))
  if proposal is None:
    self.proposal = base.get_independent_normal(data_dim)
  else:
    self.proposal = proposal
  self.momentum_proposal = base.get_independent_normal(data_dim)
def make_lars_graph(target_dist,  # pylint: disable=invalid-name
                    K,
                    batch_size,
                    eval_batch_size,
                    lr,
                    mlp_layers,
                    dtype=tf.float32):
  """Builds the LARS training graph on 2-D data.

  Args:
    target_dist: Distribution that training and evaluation batches are
      sampled from.
    K: Passed to `lars.LARS` as both `K` and `T`.
    batch_size: Number of target samples in the training batch.
    eval_batch_size: Number of target samples in the evaluation batch.
    lr: Learning rate of the Adam optimizer.
    mlp_layers: Hidden layer sizes of the LARS accept function.
    dtype: Type handed to the LARS model.

  Returns:
    A tuple `(loss, train_op, global_step)` where `loss` is the negated mean
    training log probability.
  """
  gaussian_proposal = base.get_independent_normal([2],
                                                  FLAGS.proposal_variance)
  model = lars.LARS(
      K=K,
      T=K,
      data_dim=[2],
      accept_fn_layers=mlp_layers,
      proposal=gaussian_proposal,
      dtype=dtype)

  train_batch = target_dist.sample(batch_size)
  log_p = model.log_prob(train_batch)
  eval_batch = target_dist.sample(eval_batch_size)
  eval_log_p = model.log_prob(eval_batch)

  global_step = tf.train.get_or_create_global_step()
  optimizer = tf.train.AdamOptimizer(lr)
  objective = -tf.reduce_mean(log_p)
  grads_and_vars = optimizer.compute_gradients(objective)
  apply_grads_op = optimizer.apply_gradients(
      grads_and_vars, global_step=global_step)
  # The model's post-train op only runs after the gradients were applied.
  with tf.control_dependencies([apply_grads_op]):
    train_op = model.post_train_op()

  def unnormalized_log_density(x):
    # Accept function output plus the proposal log probability.
    return tf.squeeze(model.accept_fn(x)) + model.proposal.log_prob(x)

  density_image_summary(unnormalized_log_density, FLAGS.density_num_bins,
                        "energy/lars")
  tf.summary.scalar("elbo", tf.reduce_mean(log_p))
  tf.summary.scalar("eval_elbo", tf.reduce_mean(eval_log_p))
  return -tf.reduce_mean(log_p), train_op, global_step
def __init__(self,
             latent_dim,
             data_dim,
             proposal=None,
             data_mean=None,
             kl_weight=1.,
             dtype=tf.float32):
  """Stores the HVAE configuration and builds the model.

  Args:
    latent_dim: Size of the latent variable.
    data_dim: Size of the input data.
    proposal: Distribution over the latent space. If None, defaults to
      `base.get_independent_normal([latent_dim])`.
    data_mean: Mean of the data. If None, `0.` is used.
    kl_weight: Stored as `self.kl_weight`.
    dtype: Stored as `self.dtype`.
  """
  self.latent_dim = latent_dim
  self.data_dim = data_dim
  self.kl_weight = kl_weight
  self.dtype = dtype
  self.data_mean = 0. if data_mean is None else data_mean
  self.proposal = (
      base.get_independent_normal([latent_dim])
      if proposal is None else proposal)
  # All attributes must be set before the rest of the model is built.
  self._build()
def __init__(self,  # pylint: disable=invalid-name
             T,
             data_dim,
             logit_accept_fn,
             proposal=None,
             dtype=tf.float32,
             name="rejection_sampling"):
  """Sets up a Rejection Sampling model.

  Args:
    T: Maximum number of proposals the rejection sampler draws.
    data_dim: Dimension of the data, given as a list.
    logit_accept_fn: Accept function applied to batches shaped
      [batch_size] + data_dim. The name suggests it returns logits of the
      acceptance probability (TODO confirm against its definition).
    proposal: Distribution over the data space of this model. It has to
      support sample() and log_prob(); log_prob may return a lower bound on
      the true log probability. A Gaussian is used when omitted.
    dtype: Type of the data.
    name: Name used for the variable scope.
  """
  self.T = T  # pylint: disable=invalid-name
  self.data_dim = data_dim
  self.logit_accept_fn = logit_accept_fn
  self.name = name
  self.dtype = dtype
  self.proposal = (
      base.get_independent_normal(data_dim) if proposal is None else proposal)
  with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
    self.logit_Z = tf.get_variable(  # pylint: disable=invalid-name
        name="logit_Z",
        shape=[],
        dtype=dtype,
        initializer=tf.constant_initializer(0.),
        trainable=True)
  # exp(-log_sigmoid(x)) == 1 / sigmoid(x).
  expected_trials = tf.exp(-tf.log_sigmoid(self.logit_Z))
  tf.summary.scalar("Expected_trials", expected_trials)
def __init__(self,
             latent_dim,
             data_dim,
             decoder,
             q,
             proposal=None,
             data_mean=None,
             kl_weight=1.,
             dtype=tf.float32):
  """Sets up a VAE.

  Args:
    latent_dim: Size of the VAE's latent variable.
    data_dim: Size of the input data.
    decoder: Callable mapping a batch of latent samples to a distribution
      over the data space. That distribution should support sample() and
      log_prob().
    q: Callable mapping a batch of data samples to a distribution over the
      latent space. That distribution should support sample() and
      log_prob().
    proposal: Distribution over the latent space supporting sample() and
      log_prob(). A Gaussian is used when omitted.
    data_mean: Mean of the data, used to center the input. A scalar zero
      tensor of type `dtype` is used when omitted.
    kl_weight: Weight on the KL regularizer.
    dtype: Type of the tensors.
  """
  self.data_dim = data_dim
  if data_mean is None:
    data_mean = tf.zeros((), dtype=dtype)
  self.data_mean = data_mean
  self.decoder = decoder
  self.q = q
  self.kl_weight = kl_weight
  self.dtype = dtype
  if proposal is None:
    proposal = base.get_independent_normal([latent_dim])
  self.proposal = proposal
def __init__(self,
             K,
             T,
             data_dim,
             accept_fn_layers,
             proposal=None,
             data_mean=None,
             ema_decay=0.99,
             dtype=tf.float32,
             is_eval=False):
  """Sets up a LARS model.

  Args:
    K: Stored as `self.k`.
    T: Stored as `self.T`.
    data_dim: Dimension of the data; used for the default proposal.
    accept_fn_layers: Hidden layer sizes of the accept function MLP. A final
      layer of size 1 is appended.
    proposal: Proposal distribution. If None, defaults to
      `base.get_independent_normal(data_dim)`.
    data_mean: Mean of the data. If None, a scalar zero tensor of type
      `dtype` is used.
    ema_decay: Stored as `self.ema_decay`.
    dtype: Type of the tensors and variables created here.
    is_eval: If True, a scalar `Z_estimate` placeholder is created.
  """
  self.k = K
  self.T = T  # pylint: disable=invalid-name
  self.data_dim = data_dim
  self.ema_decay = ema_decay
  self.dtype = dtype
  if data_mean is not None:
    self.data_mean = data_mean
  else:
    self.data_mean = tf.zeros((), dtype=dtype)
  # The accept function is an MLP with a single log-sigmoid output unit.
  # list() lets callers pass the hidden sizes as a tuple as well as a list.
  self.accept_fn = functools.partial(
      base.mlp,
      layer_sizes=list(accept_fn_layers) + [1],
      final_activation=tf.math.log_sigmoid,
      name="a")
  if proposal is None:
    self.proposal = base.get_independent_normal(data_dim)
  else:
    self.proposal = proposal
  self.is_eval = is_eval
  if is_eval:
    # Use the requested dtype (previously hard-coded to tf.float32) so the
    # placeholder agrees with Z_ema below.
    self.Z_estimate = tf.placeholder(dtype, shape=[])  # pylint: disable=invalid-name
  with tf.variable_scope("LARS_Z_ema", reuse=tf.AUTO_REUSE):
    self.Z_ema = tf.get_variable(  # pylint: disable=invalid-name
        name="LARS_Z_ema",
        shape=[],
        dtype=dtype,
        initializer=tf.constant_initializer(0.5),
        trainable=False)
def __init__(self,  # pylint: disable=invalid-name
             K,
             data_dim,
             energy_fn,
             proposal=None,
             data_mean=None,
             reparameterize_proposal_samples=True,
             dtype=tf.float32):
  """Sets up a NIS model.

  Args:
    K: Number of proposal samples to draw.
    data_dim: Dimension of the data.
    energy_fn: Energy function.
    proposal: Distribution over the data space of this model. It has to
      support sample() and log_prob(); log_prob may return a lower bound on
      the true log probability. A Gaussian is used when omitted.
    data_mean: Mean of the data, used to center the input. A scalar zero
      tensor of type `dtype` is used when omitted.
    reparameterize_proposal_samples: Whether gradients may flow through the
      proposal samples.
    dtype: Type of the tensors.
  """
  self.K = K  # pylint: disable=invalid-name
  self.data_dim = data_dim  # self.data_dim is always a list
  self.reparameterize_proposal_samples = reparameterize_proposal_samples
  if data_mean is None:
    data_mean = tf.zeros((), dtype=dtype)
  self.data_mean = data_mean
  self.energy_fn = energy_fn
  if proposal is None:
    proposal = base.get_independent_normal(self.data_dim)
  self.proposal = proposal
# Model types whose proposal is defined over the data space.
_DATA_SPACE_MODEL_TYPES = ("nis", "his", "lars", "conv_nis", "identity")
# Model types whose proposal is defined over the latent space. "hisvae" is
# deliberately absent: its construction is disabled in make_model.
_LATENT_SPACE_MODEL_TYPES = ("bernoulli_vae", "gaussian_vae",
                             "conv_gaussian_vae", "conv_bernoulli_vae")
# Proposal types make_model knows how to build.
_PROPOSAL_TYPES = ("bernoulli_vae", "conv_bernoulli_vae", "gaussian_vae",
                   "nis", "rejection_sampling", "gaussian", "his", "lars")


def _parse_layer_sizes(spec):
  """Parses a comma-separated string such as "20, 20" into a list of ints."""
  return [int(size.strip()) for size in spec.split(",")]


def make_model(proposal_type, model_type, data_dim, mean, global_step):
  """Creates the model graph.

  Args:
    proposal_type: Name of the proposal to build; one of `_PROPOSAL_TYPES`.
    model_type: Name of the model to build; one of
      `_DATA_SPACE_MODEL_TYPES` or `_LATENT_SPACE_MODEL_TYPES`.
    data_dim: Dimension of the data the model is defined over.
    mean: Mean of the data.
    global_step: Global step, used to compute the KL weight.

  Returns:
    The model. When `FLAGS.squash` is set it is wrapped in a
    `base.SquashedDistribution`.

  Raises:
    ValueError: If `proposal_type` or `model_type` is not supported. These
      cases previously failed later with an UnboundLocalError.
  """
  # Validate first so that a typo fails with a clear message before any
  # graph is built.
  if model_type in _DATA_SPACE_MODEL_TYPES:
    proposal_data_dim = data_dim
  elif model_type in _LATENT_SPACE_MODEL_TYPES:
    proposal_data_dim = [FLAGS.latent_dim]
  else:
    raise ValueError("Unsupported model_type %r, expected one of %s." %
                     (model_type, ", ".join(_DATA_SPACE_MODEL_TYPES +
                                            _LATENT_SPACE_MODEL_TYPES)))
  if proposal_type not in _PROPOSAL_TYPES:
    raise ValueError("Unsupported proposal_type %r, expected one of %s." %
                     (proposal_type, ", ".join(_PROPOSAL_TYPES)))

  kl_weight = make_kl_weight(global_step, FLAGS.anneal_kl_step)
  decoder_hidden_sizes = _parse_layer_sizes(FLAGS.decoder_hidden_sizes)
  q_hidden_sizes = _parse_layer_sizes(FLAGS.q_hidden_sizes)
  energy_hidden_sizes = _parse_layer_sizes(FLAGS.energy_hidden_sizes)

  # Bernoulli VAE proposal gets that data mean because it is proposing images.
  # Other proposals don't because they are proposing latent states.
  if proposal_type == "bernoulli_vae":
    proposal = vae.BernoulliVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=proposal_data_dim,
        data_mean=mean,
        decoder_hidden_sizes=decoder_hidden_sizes,
        q_hidden_sizes=q_hidden_sizes,
        scale_min=FLAGS.scale_min,
        kl_weight=kl_weight,
        reparameterize_sample=FLAGS.reparameterize_proposal,
        temperature=FLAGS.gst_temperature,
        dtype=tf.float32)
  elif proposal_type == "conv_bernoulli_vae":
    # NOTE(review): unlike the other proposals this one is built over
    # data_dim rather than proposal_data_dim; confirm this is intended.
    proposal = vae.ConvBernoulliVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=mean,
        scale_min=FLAGS.scale_min,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif proposal_type == "gaussian_vae":
    proposal = vae.GaussianVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=proposal_data_dim,
        decoder_hidden_sizes=decoder_hidden_sizes,
        decoder_nn_scale=FLAGS.vae_decoder_nn_scale,
        q_hidden_sizes=q_hidden_sizes,
        scale_min=FLAGS.scale_min,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif proposal_type == "nis":
    proposal = nis.NIS(
        K=FLAGS.K,
        data_dim=proposal_data_dim,
        energy_hidden_sizes=energy_hidden_sizes,
        dtype=tf.float32)
  elif proposal_type == "rejection_sampling":
    proposal = rejection_sampling.RejectionSampling(
        T=FLAGS.K,
        data_dim=proposal_data_dim,
        energy_hidden_sizes=energy_hidden_sizes,
        dtype=tf.float32)
  elif proposal_type == "gaussian":
    proposal = base.get_independent_normal(proposal_data_dim)
  elif proposal_type == "his":
    proposal = his.FullyConnectedHIS(
        T=FLAGS.his_T,
        data_dim=proposal_data_dim,
        energy_hidden_sizes=energy_hidden_sizes,
        q_hidden_sizes=q_hidden_sizes,
        learn_temps=FLAGS.learn_his_temps,
        learn_stepsize=FLAGS.learn_his_stepsize,
        init_alpha=FLAGS.his_init_alpha,
        init_step_size=FLAGS.his_init_stepsize,
        dtype=tf.float32)
  else:  # proposal_type == "lars", guaranteed by the validation above.
    proposal = lars.LARS(
        K=FLAGS.K,
        T=FLAGS.lars_T,
        data_dim=proposal_data_dim,
        accept_fn_layers=energy_hidden_sizes,
        proposal=None,
        data_mean=None,
        ema_decay=0.99,
        is_eval=FLAGS.mode == "eval",
        dtype=tf.float32)

  if model_type == "bernoulli_vae":
    model = vae.BernoulliVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=mean,
        decoder_hidden_sizes=decoder_hidden_sizes,
        q_hidden_sizes=q_hidden_sizes,
        scale_min=FLAGS.scale_min,
        proposal=proposal,
        kl_weight=kl_weight,
        reparameterize_sample=False,
        dtype=tf.float32)
  elif model_type == "gaussian_vae":
    model = vae.GaussianVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=None if FLAGS.squash else mean,
        decoder_hidden_sizes=decoder_hidden_sizes,
        decoder_nn_scale=FLAGS.vae_decoder_nn_scale,
        q_hidden_sizes=q_hidden_sizes,
        scale_min=FLAGS.scale_min,
        proposal=proposal,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif model_type == "conv_gaussian_vae":
    model = vae.ConvGaussianVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=None if FLAGS.squash else mean,
        scale_min=FLAGS.scale_min,
        proposal=proposal,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif model_type == "conv_bernoulli_vae":
    model = vae.ConvBernoulliVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=None if FLAGS.squash else mean,
        scale_min=FLAGS.scale_min,
        proposal=proposal,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif model_type == "nis":
    model = nis.NIS(
        K=FLAGS.K,
        data_dim=data_dim,
        data_mean=None if FLAGS.squash else mean,
        energy_hidden_sizes=energy_hidden_sizes,
        proposal=proposal,
        reparameterize_proposal_samples=FLAGS.reparameterize_proposal,
        dtype=tf.float32)
  elif model_type == "conv_nis":
    model = nis.ConvNIS(
        K=FLAGS.K,
        data_dim=data_dim,
        data_mean=None if FLAGS.squash else mean,
        proposal=proposal,
        reparameterize_proposal_samples=FLAGS.reparameterize_proposal,
        dtype=tf.float32)
  elif model_type == "his":
    model = his.FullyConnectedHIS(
        proposal=proposal,
        T=FLAGS.his_T,
        data_dim=data_dim,
        data_mean=None if FLAGS.squash else mean,
        energy_hidden_sizes=energy_hidden_sizes,
        q_hidden_sizes=q_hidden_sizes,
        learn_temps=FLAGS.learn_his_temps,
        learn_stepsize=FLAGS.learn_his_stepsize,
        init_alpha=FLAGS.his_init_alpha,
        init_step_size=FLAGS.his_init_stepsize,
        dtype=tf.float32)
  elif model_type == "lars":
    model = lars.LARS(
        K=FLAGS.K,
        T=FLAGS.lars_T,
        data_dim=data_dim,
        accept_fn_layers=energy_hidden_sizes,
        proposal=proposal,
        data_mean=None if FLAGS.squash else mean,
        ema_decay=0.99,
        is_eval=FLAGS.mode == "eval",
        dtype=tf.float32)
  else:  # model_type == "identity", guaranteed by the validation above.
    model = proposal
  # The "hisvae" model is currently disabled:
  # elif model_type == "hisvae":
  #   model = his.HISVAE(
  #       latent_dim=FLAGS.latent_dim,
  #       proposal=proposal,
  #       T=FLAGS.his_T,
  #       data_dim=data_dim,
  #       data_mean=mean,
  #       energy_hidden_sizes=energy_hidden_sizes,
  #       q_hidden_sizes=q_hidden_sizes,
  #       decoder_hidden_sizes=decoder_hidden_sizes,
  #       learn_temps=FLAGS.learn_his_temps,
  #       learn_stepsize=FLAGS.learn_his_stepsize,
  #       init_alpha=FLAGS.his_init_alpha,
  #       init_step_size=FLAGS.his_init_stepsize,
  #       squash=FLAGS.squash,
  #       kl_weight=kl_weight,
  #       dtype=tf.float32)

  if FLAGS.squash:
    model = base.SquashedDistribution(distribution=model, data_mean=mean)

  return model
def _evaluate_and_save_density(plot):
  """Evaluates `plot` from the FLAGS.logdir checkpoint and saves the result."""
  with tf.train.SingularMonitoredSession(
      checkpoint_dir=FLAGS.logdir) as sess:
    plot = sess.run(plot)
    with tf.io.gfile.GFile(os.path.join(FLAGS.logdir, "density"),
                           "w") as out:
      np.save(out, plot)


def main(unused_argv):
  """Builds the density summary selected by FLAGS.algo and evaluates it.

  For "density" and "lars" a density plot is evaluated and saved to
  `<FLAGS.logdir>/density`. For "nis", "his" and "rejection_sampling" the
  model is sampled and the samples are passed to
  `make_sample_density_summary`.

  Args:
    unused_argv: Unused command line arguments.

  Raises:
    ValueError: If FLAGS.algo is not one of the algorithms above. This
      previously failed with an UnboundLocalError on `model`.
  """
  g = tf.Graph()
  with g.as_default():
    energy_fn_layers = [
        int(x.strip()) for x in FLAGS.energy_fn_sizes.split(",")
    ]
    if FLAGS.algo == "density":
      target = dists.get_target_distribution(FLAGS.target)
      plot = make_density_summary(target.log_prob, num_bins=FLAGS.num_bins)
      _evaluate_and_save_density(plot)
    elif FLAGS.algo == "lars":
      tf.logging.info("Running LARS")
      proposal = base.get_independent_normal([2], FLAGS.proposal_variance)
      model = lars.SimpleLARS(
          K=FLAGS.K,
          data_dim=[2],
          accept_fn_layers=energy_fn_layers,
          proposal=proposal)
      plot = make_density_summary(
          lambda x: tf.squeeze(model.accept_fn(x)) + model.proposal.log_prob(x),
          num_bins=FLAGS.num_bins)
      _evaluate_and_save_density(plot)
    else:
      proposal = base.get_independent_normal([2], FLAGS.proposal_variance)
      if FLAGS.algo == "nis":
        tf.logging.info("Running NIS")
        model = nis.NIS(
            K=FLAGS.K,
            data_dim=[2],
            energy_hidden_sizes=energy_fn_layers,
            proposal=proposal)
      elif FLAGS.algo == "his":
        tf.logging.info("Running HIS")
        model = his.FullyConnectedHIS(
            T=FLAGS.his_t,
            data_dim=[2],
            energy_hidden_sizes=energy_fn_layers,
            q_hidden_sizes=energy_fn_layers,
            init_step_size=FLAGS.his_stepsize,
            learn_stepsize=FLAGS.his_learn_stepsize,
            init_alpha=FLAGS.his_alpha,
            learn_temps=FLAGS.his_learn_alpha,
            proposal=proposal)
      elif FLAGS.algo == "rejection_sampling":
        model = rejection_sampling.RejectionSampling(
            T=FLAGS.K,
            data_dim=[2],
            energy_hidden_sizes=energy_fn_layers,
            proposal=proposal)
      else:
        raise ValueError("Unsupported algo %r." % (FLAGS.algo,))
      samples = model.sample(FLAGS.batch_size)
      with tf.train.SingularMonitoredSession(
          checkpoint_dir=FLAGS.logdir) as sess:
        make_sample_density_summary(
            sess,
            samples,
            max_samples_per_batch=FLAGS.batch_size,
            num_samples=FLAGS.num_samples,
            num_bins=FLAGS.num_bins)
def main(unused_argv):
  """Builds the training graph selected by FLAGS.algo and runs training.

  Training runs in a MonitoredTrainingSession that checkpoints to
  `<FLAGS.logdir>/<exp_name()>`. It stops when the session asks to stop or
  the fetched global step exceeds FLAGS.max_steps.

  Args:
    unused_argv: Unused command line arguments.

  Raises:
    ValueError: If FLAGS.algo is not one of "lars", "nis",
      "rejection_sampling" or "his". This previously failed with an
      UnboundLocalError on `model`.
  """
  g = tf.Graph()
  with g.as_default():
    target = dists.get_target_distribution(
        FLAGS.target, nine_gaussians_variance=FLAGS.nine_gaussians_variance)
    energy_fn_layers = [
        int(x.strip()) for x in FLAGS.energy_fn_sizes.split(",")
    ]
    if FLAGS.algo == "lars":
      print("Running LARS")
      loss, train_op, global_step = make_lars_graph(
          target_dist=target,
          K=FLAGS.K,
          batch_size=FLAGS.batch_size,
          eval_batch_size=FLAGS.eval_batch_size,
          lr=FLAGS.learning_rate,
          mlp_layers=energy_fn_layers,
          dtype=tf.float32)
    else:
      proposal = base.get_independent_normal([2], FLAGS.proposal_variance)
      if FLAGS.algo == "nis":
        print("Running NIS")
        model = nis.NIS(
            K=FLAGS.K,
            data_dim=[2],
            energy_hidden_sizes=energy_fn_layers,
            proposal=proposal)
        density_image_summary(
            lambda x:  # pylint: disable=g-long-lambda
            (tf.squeeze(model.energy_fn(x)) + model.proposal.log_prob(x)),
            FLAGS.density_num_bins,
            "energy/nis")
      elif FLAGS.algo == "rejection_sampling":
        print("Running Rejection Sampling")
        model = rejection_sampling.RejectionSampling(
            T=FLAGS.K,
            data_dim=[2],
            energy_hidden_sizes=energy_fn_layers,
            proposal=proposal)
        density_image_summary(
            lambda x: tf.squeeze(  # pylint: disable=g-long-lambda
                tf.log_sigmoid(model.logit_accept_fn(x)), axis=-1) + model.
            proposal.log_prob(x),
            FLAGS.density_num_bins,
            "energy/trs")
      elif FLAGS.algo == "his":
        print("Running HIS")
        model = his.FullyConnectedHIS(
            T=FLAGS.his_t,
            data_dim=[2],
            energy_hidden_sizes=energy_fn_layers,
            q_hidden_sizes=energy_fn_layers,
            init_step_size=FLAGS.his_stepsize,
            learn_stepsize=FLAGS.his_learn_stepsize,
            init_alpha=FLAGS.his_alpha,
            learn_temps=FLAGS.his_learn_alpha,
            proposal=proposal)
        density_image_summary(lambda x: -model.hamiltonian_potential(x),
                              FLAGS.density_num_bins, "energy/his")
        sample_image_summary(
            model, "density/his", num_samples=100000, num_bins=50)
      else:
        raise ValueError("Unsupported algo %r." % (FLAGS.algo,))

      loss, train_op, global_step = make_train_graph(
          target_dist=target,
          model=model,
          batch_size=FLAGS.batch_size,
          eval_batch_size=FLAGS.eval_batch_size,
          lr=FLAGS.learning_rate)

    log_hooks = make_log_hooks(global_step, loss)
    with tf.train.MonitoredTrainingSession(
        master="",
        is_chief=True,
        hooks=log_hooks,
        checkpoint_dir=os.path.join(FLAGS.logdir, exp_name()),
        save_checkpoint_secs=120,
        save_summaries_steps=FLAGS.summarize_every,
        log_step_count_steps=FLAGS.summarize_every) as sess:
      cur_step = -1
      # Run training steps until the session stops or max_steps is passed.
      while not (sess.should_stop() or cur_step > FLAGS.max_steps):
        _, cur_step = sess.run([train_op, global_step])