def make_model(proposal_type, model_type, data_dim, mean, global_step):
  """Create model graph.

  Builds the proposal selected by `proposal_type`, then the model selected by
  `model_type` on top of it. If FLAGS.squash is set, the resulting model is
  wrapped in `base.SquashedDistribution` together with the data mean.

  Args:
    proposal_type: One of "bernoulli_vae", "conv_bernoulli_vae",
      "gaussian_vae", "nis", "rejection_sampling", "gaussian", "his", "lars".
    model_type: One of "bernoulli_vae", "gaussian_vae", "conv_gaussian_vae",
      "conv_bernoulli_vae", "nis", "conv_nis", "his", "lars", "identity".
      "identity" returns the proposal itself as the model.
    data_dim: Data dimensions passed to the model. Models in the data space
      also pass it to the proposal; VAE-style models give the proposal
      [FLAGS.latent_dim] instead.
    mean: Data mean, passed as `data_mean` to the constructors below (to most
      models only when FLAGS.squash is unset), and to
      `base.SquashedDistribution` when FLAGS.squash is set.
    global_step: Global step handed to `make_kl_weight` for KL annealing.

  Returns:
    The constructed model.

  Raises:
    ValueError: If `proposal_type` or `model_type` is not supported.
  """
  kl_weight = make_kl_weight(global_step, FLAGS.anneal_kl_step)
  decoder_hidden_sizes = [
      int(x.strip()) for x in FLAGS.decoder_hidden_sizes.split(",")
  ]
  q_hidden_sizes = [int(x.strip()) for x in FLAGS.q_hidden_sizes.split(",")]
  energy_hidden_sizes = [
      int(x.strip()) for x in FLAGS.energy_hidden_sizes.split(",")
  ]

  # Models defined directly over the data get a proposal over the data;
  # VAE-style models get a proposal over their latent space.
  if model_type in ["nis", "his", "lars", "conv_nis", "identity"]:
    proposal_data_dim = data_dim
  elif model_type in [
      "bernoulli_vae", "gaussian_vae", "hisvae", "conv_gaussian_vae",
      "conv_bernoulli_vae"
  ]:
    proposal_data_dim = [FLAGS.latent_dim]
  else:
    raise ValueError("Unsupported model_type: %s" % model_type)

  # Bernoulli VAE proposals get the data mean because they propose images.
  # Other proposals don't because they propose latent states.
  if proposal_type == "bernoulli_vae":
    proposal = vae.BernoulliVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=proposal_data_dim,
        data_mean=mean,
        decoder_hidden_sizes=decoder_hidden_sizes,
        q_hidden_sizes=q_hidden_sizes,
        scale_min=FLAGS.scale_min,
        kl_weight=kl_weight,
        reparameterize_sample=FLAGS.reparameterize_proposal,
        temperature=FLAGS.gst_temperature,
        dtype=tf.float32)
  elif proposal_type == "conv_bernoulli_vae":
    # NOTE(review): uses `data_dim`, not `proposal_data_dim`; presumably this
    # proposal is only meant for data-space models -- confirm.
    proposal = vae.ConvBernoulliVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=mean,
        scale_min=FLAGS.scale_min,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif proposal_type == "gaussian_vae":
    proposal = vae.GaussianVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=proposal_data_dim,
        decoder_hidden_sizes=decoder_hidden_sizes,
        decoder_nn_scale=FLAGS.vae_decoder_nn_scale,
        q_hidden_sizes=q_hidden_sizes,
        scale_min=FLAGS.scale_min,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif proposal_type == "nis":
    proposal = nis.NIS(
        K=FLAGS.K,
        data_dim=proposal_data_dim,
        energy_hidden_sizes=energy_hidden_sizes,
        dtype=tf.float32)
  elif proposal_type == "rejection_sampling":
    proposal = rejection_sampling.RejectionSampling(
        T=FLAGS.K,
        data_dim=proposal_data_dim,
        energy_hidden_sizes=energy_hidden_sizes,
        dtype=tf.float32)
  elif proposal_type == "gaussian":
    proposal = base.get_independent_normal(proposal_data_dim)
  elif proposal_type == "his":
    proposal = his.FullyConnectedHIS(
        T=FLAGS.his_T,
        data_dim=proposal_data_dim,
        energy_hidden_sizes=energy_hidden_sizes,
        q_hidden_sizes=q_hidden_sizes,
        learn_temps=FLAGS.learn_his_temps,
        learn_stepsize=FLAGS.learn_his_stepsize,
        init_alpha=FLAGS.his_init_alpha,
        init_step_size=FLAGS.his_init_stepsize,
        dtype=tf.float32)
  elif proposal_type == "lars":
    proposal = lars.LARS(
        K=FLAGS.K,
        T=FLAGS.lars_T,
        data_dim=proposal_data_dim,
        accept_fn_layers=energy_hidden_sizes,
        proposal=None,
        data_mean=None,
        ema_decay=0.99,
        is_eval=FLAGS.mode == "eval",
        dtype=tf.float32)
  else:
    raise ValueError("Unsupported proposal_type: %s" % proposal_type)

  if model_type == "bernoulli_vae":
    model = vae.BernoulliVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=mean,
        decoder_hidden_sizes=decoder_hidden_sizes,
        q_hidden_sizes=q_hidden_sizes,
        scale_min=FLAGS.scale_min,
        proposal=proposal,
        kl_weight=kl_weight,
        reparameterize_sample=False,
        dtype=tf.float32)
  elif model_type == "gaussian_vae":
    model = vae.GaussianVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=None if FLAGS.squash else mean,
        decoder_hidden_sizes=decoder_hidden_sizes,
        decoder_nn_scale=FLAGS.vae_decoder_nn_scale,
        q_hidden_sizes=q_hidden_sizes,
        scale_min=FLAGS.scale_min,
        proposal=proposal,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif model_type == "conv_gaussian_vae":
    model = vae.ConvGaussianVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=None if FLAGS.squash else mean,
        scale_min=FLAGS.scale_min,
        proposal=proposal,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif model_type == "conv_bernoulli_vae":
    model = vae.ConvBernoulliVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=None if FLAGS.squash else mean,
        scale_min=FLAGS.scale_min,
        proposal=proposal,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif model_type == "nis":
    model = nis.NIS(
        K=FLAGS.K,
        data_dim=data_dim,
        data_mean=None if FLAGS.squash else mean,
        energy_hidden_sizes=energy_hidden_sizes,
        proposal=proposal,
        reparameterize_proposal_samples=FLAGS.reparameterize_proposal,
        dtype=tf.float32)
  elif model_type == "conv_nis":
    model = nis.ConvNIS(
        K=FLAGS.K,
        data_dim=data_dim,
        data_mean=None if FLAGS.squash else mean,
        proposal=proposal,
        reparameterize_proposal_samples=FLAGS.reparameterize_proposal,
        dtype=tf.float32)
  elif model_type == "his":
    model = his.FullyConnectedHIS(
        proposal=proposal,
        T=FLAGS.his_T,
        data_dim=data_dim,
        data_mean=None if FLAGS.squash else mean,
        energy_hidden_sizes=energy_hidden_sizes,
        q_hidden_sizes=q_hidden_sizes,
        learn_temps=FLAGS.learn_his_temps,
        learn_stepsize=FLAGS.learn_his_stepsize,
        init_alpha=FLAGS.his_init_alpha,
        init_step_size=FLAGS.his_init_stepsize,
        dtype=tf.float32)
  elif model_type == "lars":
    model = lars.LARS(
        K=FLAGS.K,
        T=FLAGS.lars_T,
        data_dim=data_dim,
        accept_fn_layers=energy_hidden_sizes,
        proposal=proposal,
        data_mean=None if FLAGS.squash else mean,
        ema_decay=0.99,
        is_eval=FLAGS.mode == "eval",
        dtype=tf.float32)
  elif model_type == "identity":
    model = proposal
  else:
    # "hisvae" (his.HISVAE) is accepted above for proposal sizing but its
    # model construction is currently disabled, so it lands here too.
    raise ValueError("Unsupported model_type: %s" % model_type)

  if FLAGS.squash:
    model = base.SquashedDistribution(distribution=model, data_mean=mean)
  return model
def main(unused_argv):
  """Train a sampler on a 2-D toy target distribution.

  FLAGS.algo selects the method: "lars" builds its graph via
  `make_lars_graph`; "nis", "rejection_sampling" and "his" construct a model,
  register density summaries and train via `make_train_graph`. Training runs
  in a MonitoredTrainingSession that checkpoints to
  `os.path.join(FLAGS.logdir, exp_name())` every 120 seconds, until the
  session asks to stop or the global step exceeds FLAGS.max_steps.

  Raises:
    ValueError: If FLAGS.algo is not one of the supported algorithms.
  """
  g = tf.Graph()
  with g.as_default():
    target = dists.get_target_distribution(
        FLAGS.target, nine_gaussians_variance=FLAGS.nine_gaussians_variance)
    energy_fn_layers = [
        int(x.strip()) for x in FLAGS.energy_fn_sizes.split(",")
    ]
    if FLAGS.algo == "lars":
      print("Running LARS")
      loss, train_op, global_step = make_lars_graph(
          target_dist=target,
          K=FLAGS.K,
          batch_size=FLAGS.batch_size,
          eval_batch_size=FLAGS.eval_batch_size,
          lr=FLAGS.learning_rate,
          mlp_layers=energy_fn_layers,
          dtype=tf.float32)
    else:
      if FLAGS.algo == "nis":
        print("Running NIS")
        model = nis.NIS(
            K=FLAGS.K, data_dim=2, energy_hidden_sizes=energy_fn_layers)
        # Unnormalized log density: learned energy plus proposal log-prob.
        density_image_summary(
            lambda x: (tf.squeeze(model.energy_fn(x)) +
                       model.proposal.log_prob(x)),
            FLAGS.density_num_bins, "energy/nis")
      elif FLAGS.algo == "rejection_sampling":
        print("Running Rejection Sampling")
        logit_accept_fn = tf.keras.Sequential([
            tf.keras.layers.Dense(layer_size, activation="tanh")
            for layer_size in energy_fn_layers
        ] + [tf.keras.layers.Dense(1, activation=None)])
        model = rejection_sampling.RejectionSampling(
            T=FLAGS.K, data_dim=[2], logit_accept_fn=logit_accept_fn)
        # Log acceptance probability plus proposal log-prob.
        density_image_summary(
            lambda x: (tf.squeeze(
                tf.log_sigmoid(model.logit_accept_fn(x)), axis=-1) +
                       model.proposal.log_prob(x)),
            FLAGS.density_num_bins, "energy/trs")
      elif FLAGS.algo == "his":
        print("Running HIS")
        model = his.FullyConnectedHIS(
            T=FLAGS.his_t,
            data_dim=2,
            energy_hidden_sizes=energy_fn_layers,
            q_hidden_sizes=energy_fn_layers,
            init_step_size=FLAGS.his_stepsize,
            learn_stepsize=FLAGS.his_learn_stepsize,
            init_alpha=FLAGS.his_alpha,
            learn_temps=FLAGS.his_learn_alpha)
        density_image_summary(lambda x: -model.hamiltonian_potential(x),
                              FLAGS.density_num_bins, "energy/his")
        # NOTE(review): placement inferred from the "density/his" tag; confirm
        # whether this sample summary should apply to every non-LARS algo.
        sample_image_summary(
            model, "density/his", num_samples=100000, num_bins=50)
      else:
        raise ValueError("Unsupported algo: %s" % FLAGS.algo)
      loss, train_op, global_step = make_train_graph(
          target_dist=target,
          model=model,
          batch_size=FLAGS.batch_size,
          eval_batch_size=FLAGS.eval_batch_size,
          lr=FLAGS.learning_rate)

    log_hooks = make_log_hooks(global_step, loss)
    with tf.train.MonitoredTrainingSession(
        master="",
        is_chief=True,
        hooks=log_hooks,
        checkpoint_dir=os.path.join(FLAGS.logdir, exp_name()),
        save_checkpoint_secs=120,
        save_summaries_steps=FLAGS.summarize_every,
        log_step_count_steps=FLAGS.summarize_every) as sess:
      cur_step = -1
      while True:
        if sess.should_stop() or cur_step > FLAGS.max_steps:
          break
        # Run one training step and fetch the updated global step.
        _, cur_step = sess.run([train_op, global_step])
def _save_density_plot(plot, logdir):
  """Evaluate `plot` in a session on `logdir` and save the result.

  Args:
    plot: Fetchable (e.g. a tensor) produced by `make_density_summary`.
    logdir: Passed as `checkpoint_dir` to the session; the evaluated value is
      written with `np.save` to the file `<logdir>/density`.
  """
  with tf.train.SingularMonitoredSession(checkpoint_dir=logdir) as sess:
    plot_value = sess.run(plot)
  # np.save emits binary .npy data, so the file must be opened in binary mode.
  with tf.io.gfile.GFile(os.path.join(logdir, "density"), "wb") as out:
    np.save(out, plot_value)


def main(unused_argv):
  """Evaluate a trained 2-D toy sampler (or the target) from FLAGS.logdir.

  FLAGS.algo selects what is evaluated:
    * "density": the target's log density, saved to `<logdir>/density`.
    * "lars": the SimpleLARS unnormalized log density, saved the same way.
    * "nis", "his", "rejection_sampling": samples from the model are passed
      to `make_sample_density_summary`.

  Raises:
    ValueError: If FLAGS.algo is not one of the supported values.
  """
  g = tf.Graph()
  with g.as_default():
    energy_fn_layers = [
        int(x.strip()) for x in FLAGS.energy_fn_sizes.split(",")
    ]
    if FLAGS.algo == "density":
      target = dists.get_target_distribution(FLAGS.target)
      plot = make_density_summary(target.log_prob, num_bins=FLAGS.num_bins)
      _save_density_plot(plot, FLAGS.logdir)
    elif FLAGS.algo == "lars":
      tf.logging.info("Running LARS")
      proposal = base.get_independent_normal([2], FLAGS.proposal_variance)
      model = lars.SimpleLARS(
          K=FLAGS.K,
          data_dim=[2],
          accept_fn_layers=energy_fn_layers,
          proposal=proposal)
      # Unnormalized log density: acceptance term plus proposal log-prob.
      plot = make_density_summary(
          lambda x: tf.squeeze(model.accept_fn(x)) + model.proposal.log_prob(x),
          num_bins=FLAGS.num_bins)
      _save_density_plot(plot, FLAGS.logdir)
    else:
      proposal = base.get_independent_normal([2], FLAGS.proposal_variance)
      if FLAGS.algo == "nis":
        tf.logging.info("Running NIS")
        model = nis.NIS(
            K=FLAGS.K,
            data_dim=[2],
            energy_hidden_sizes=energy_fn_layers,
            proposal=proposal)
      elif FLAGS.algo == "his":
        tf.logging.info("Running HIS")
        model = his.FullyConnectedHIS(
            T=FLAGS.his_t,
            data_dim=[2],
            energy_hidden_sizes=energy_fn_layers,
            q_hidden_sizes=energy_fn_layers,
            init_step_size=FLAGS.his_stepsize,
            learn_stepsize=FLAGS.his_learn_stepsize,
            init_alpha=FLAGS.his_alpha,
            learn_temps=FLAGS.his_learn_alpha,
            proposal=proposal)
      elif FLAGS.algo == "rejection_sampling":
        tf.logging.info("Running Rejection Sampling")
        model = rejection_sampling.RejectionSampling(
            T=FLAGS.K,
            data_dim=[2],
            energy_hidden_sizes=energy_fn_layers,
            proposal=proposal)
      else:
        raise ValueError("Unsupported algo: %s" % FLAGS.algo)
      samples = model.sample(FLAGS.batch_size)
      with tf.train.SingularMonitoredSession(
          checkpoint_dir=FLAGS.logdir) as sess:
        make_sample_density_summary(
            sess,
            samples,
            max_samples_per_batch=FLAGS.batch_size,
            num_samples=FLAGS.num_samples,
            num_bins=FLAGS.num_bins)
def make_model(proposal_type, model_type, data_dim, mean, global_step):
  """Create model graph.

  Builds the proposal selected by `proposal_type`, then the model selected by
  `model_type` on top of it. VAE models receive the proposal as `prior`;
  NIS/HIS models receive it as `proposal`. If FLAGS.squash is set, the model
  is wrapped in `base.SquashedDistribution` together with the data mean.

  Args:
    proposal_type: One of "bernoulli_vae", "gaussian_vae", "nis",
      "rejection_sampling", "gaussian", "his".
    model_type: One of "bernoulli_vae", "gaussian_vae", "conv_gaussian_vae",
      "conv_bernoulli_vae", "nis", "his".
    data_dim: Data dimensions passed to the model. "nis"/"his" models also
      pass it to the proposal; VAE models give the proposal FLAGS.latent_dim.
    mean: Data mean, passed as `data_mean` to the constructors below (to most
      models only when FLAGS.squash is unset), and to
      `base.SquashedDistribution` when FLAGS.squash is set.
    global_step: Global step handed to `make_kl_weight` for KL annealing.

  Returns:
    The constructed model.

  Raises:
    ValueError: If `proposal_type` or `model_type` is not supported.
  """
  kl_weight = make_kl_weight(global_step, FLAGS.anneal_kl_step)
  decoder_hidden_sizes = [
      int(x.strip()) for x in FLAGS.decoder_hidden_sizes.split(",")
  ]
  q_hidden_sizes = [int(x.strip()) for x in FLAGS.q_hidden_sizes.split(",")]
  energy_hidden_sizes = [
      int(x.strip()) for x in FLAGS.energy_hidden_sizes.split(",")
  ]

  # NIS/HIS models propose in data space; VAE models propose latent states.
  if model_type in ["nis", "his"]:
    proposal_data_dim = data_dim
  elif model_type in [
      "bernoulli_vae", "gaussian_vae", "hisvae", "conv_gaussian_vae",
      "conv_bernoulli_vae"
  ]:
    proposal_data_dim = FLAGS.latent_dim
  else:
    raise ValueError("Unsupported model_type: %s" % model_type)

  # Bernoulli VAE proposals get the data mean because they propose images.
  # Other proposals don't because they propose latent states.
  if proposal_type == "bernoulli_vae":
    proposal = vae.BernoulliVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=proposal_data_dim,
        data_mean=mean,
        decoder_hidden_sizes=decoder_hidden_sizes,
        q_hidden_sizes=q_hidden_sizes,
        scale_min=FLAGS.scale_min,
        kl_weight=kl_weight,
        reparameterize_sample=FLAGS.reparameterize_proposal,
        temperature=FLAGS.gst_temperature,
        dtype=tf.float32)
  elif proposal_type == "gaussian_vae":
    proposal = vae.GaussianVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=proposal_data_dim,
        decoder_hidden_sizes=decoder_hidden_sizes,
        decoder_nn_scale=FLAGS.vae_decoder_nn_scale,
        q_hidden_sizes=q_hidden_sizes,
        scale_min=FLAGS.scale_min,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif proposal_type == "nis":
    proposal = nis.NIS(
        K=FLAGS.K,
        data_dim=proposal_data_dim,
        energy_hidden_sizes=energy_hidden_sizes,
        dtype=tf.float32)
  elif proposal_type == "rejection_sampling":
    logit_accept_fn = functools.partial(
        base.mlp,
        layer_sizes=energy_hidden_sizes + [1],
        final_activation=None,
        name="rejection_sampling/energy_fn_mlp")
    # NOTE(review): wrapping as [proposal_data_dim] assumes it is a scalar
    # size; for "nis"/"his" models it is `data_dim` -- confirm that is an int.
    proposal = rejection_sampling.RejectionSampling(
        T=FLAGS.K,
        data_dim=[proposal_data_dim],
        logit_accept_fn=logit_accept_fn,
        dtype=tf.float32)
  elif proposal_type == "gaussian":
    # Standard normal with diagonal covariance over the proposal space (same
    # scalar-size assumption as the rejection-sampling branch above).
    proposal = tfd.MultivariateNormalDiag(
        loc=tf.zeros([proposal_data_dim], dtype=tf.float32),
        scale_diag=tf.ones([proposal_data_dim], dtype=tf.float32))
  elif proposal_type == "his":
    proposal = his.FullyConnectedHIS(
        T=FLAGS.his_T,
        data_dim=proposal_data_dim,
        energy_hidden_sizes=energy_hidden_sizes,
        q_hidden_sizes=q_hidden_sizes,
        learn_temps=FLAGS.learn_his_temps,
        learn_stepsize=FLAGS.learn_his_stepsize,
        init_alpha=FLAGS.his_init_alpha,
        init_step_size=FLAGS.his_init_stepsize,
        dtype=tf.float32)
  else:
    raise ValueError("Unsupported proposal_type: %s" % proposal_type)

  if model_type == "bernoulli_vae":
    model = vae.BernoulliVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=mean,
        decoder_hidden_sizes=decoder_hidden_sizes,
        q_hidden_sizes=q_hidden_sizes,
        scale_min=FLAGS.scale_min,
        prior=proposal,
        kl_weight=kl_weight,
        reparameterize_sample=False,
        dtype=tf.float32)
  elif model_type == "gaussian_vae":
    model = vae.GaussianVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=None if FLAGS.squash else mean,
        decoder_hidden_sizes=decoder_hidden_sizes,
        decoder_nn_scale=FLAGS.vae_decoder_nn_scale,
        q_hidden_sizes=q_hidden_sizes,
        scale_min=FLAGS.scale_min,
        prior=proposal,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif model_type == "conv_gaussian_vae":
    model = vae.ConvGaussianVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=None if FLAGS.squash else mean,
        base_depth=FLAGS.base_depth,
        scale_min=FLAGS.scale_min,
        prior=proposal,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif model_type == "conv_bernoulli_vae":
    model = vae.ConvBernoulliVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=None if FLAGS.squash else mean,
        scale_min=FLAGS.scale_min,
        prior=proposal,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif model_type == "nis":
    model = nis.NIS(
        K=FLAGS.K,
        data_dim=data_dim,
        data_mean=None if FLAGS.squash else mean,
        energy_hidden_sizes=energy_hidden_sizes,
        proposal=proposal,
        reparameterize_proposal_samples=FLAGS.reparameterize_proposal,
        dtype=tf.float32)
  elif model_type == "his":
    model = his.FullyConnectedHIS(
        proposal=proposal,
        T=FLAGS.his_T,
        data_dim=data_dim,
        data_mean=None if FLAGS.squash else mean,
        energy_hidden_sizes=energy_hidden_sizes,
        q_hidden_sizes=q_hidden_sizes,
        learn_temps=FLAGS.learn_his_temps,
        learn_stepsize=FLAGS.learn_his_stepsize,
        init_alpha=FLAGS.his_init_alpha,
        init_step_size=FLAGS.his_init_stepsize,
        dtype=tf.float32)
  else:
    # "hisvae" (his.HISVAE) is accepted above for proposal sizing but its
    # model construction is currently disabled, so it lands here too.
    raise ValueError("Unsupported model_type: %s" % model_type)

  if FLAGS.squash:
    model = base.SquashedDistribution(distribution=model, data_mean=mean)
  return model