def make_lars_graph(target_dist,  # pylint: disable=invalid-name
                    K,
                    batch_size,
                    eval_batch_size,
                    lr,
                    mlp_layers,
                    dtype=tf.float32):
  """Build the LARS training graph for a 2-dimensional target.

  A `lars.LARS` model with `data_dim=[2]` is created on top of an independent
  normal proposal, fit by maximizing the mean log-probability of samples drawn
  from `target_dist`.

  Args:
    target_dist: Object exposing `sample(n)`; used to draw the training and
      evaluation batches.
    K: Passed to `lars.LARS` as both `K` and `T`.
    batch_size: Number of samples drawn for the training batch.
    eval_batch_size: Number of samples drawn for the evaluation batch.
    lr: Learning rate given to `tf.train.AdamOptimizer`.
    mlp_layers: Forwarded to `lars.LARS` as `accept_fn_layers`.
    dtype: Forwarded to `lars.LARS` as `dtype`.

  Returns:
    A tuple `(loss, train_op, global_step)` where `loss` is the negated mean
    of the model log-probability on the training batch, `train_op` is the
    result of `model.post_train_op()` created under a control dependency on
    the gradient update, and `global_step` is the global step tensor.
  """
  base_proposal = base.get_independent_normal([2], FLAGS.proposal_variance)
  model = lars.LARS(
      K=K,
      T=K,
      data_dim=[2],
      accept_fn_layers=mlp_layers,
      proposal=base_proposal,
      dtype=dtype)

  train_batch = target_dist.sample(batch_size)
  log_p = model.log_prob(train_batch)
  eval_batch = target_dist.sample(eval_batch_size)
  eval_log_p = model.log_prob(eval_batch)

  global_step = tf.train.get_or_create_global_step()
  optimizer = tf.train.AdamOptimizer(lr)
  grads_and_vars = optimizer.compute_gradients(-tf.reduce_mean(log_p))
  update_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)
  # The model's post-training op must only run once the update has been
  # applied.
  with tf.control_dependencies([update_op]):
    train_op = model.post_train_op()

  def _energy(x):
    # Squeezed acceptance-function output plus the proposal log-density.
    return tf.squeeze(model.accept_fn(x)) + model.proposal.log_prob(x)

  density_image_summary(_energy, FLAGS.density_num_bins, "energy/lars")
  tf.summary.scalar("elbo", tf.reduce_mean(log_p))
  tf.summary.scalar("eval_elbo", tf.reduce_mean(eval_log_p))

  loss = -tf.reduce_mean(log_p)
  return loss, train_op, global_step
def _parse_hidden_sizes(spec):
  """Parse a comma-separated string of integers into a list of ints."""
  return [int(size.strip()) for size in spec.split(",")]


def make_model(proposal_type, model_type, data_dim, mean, global_step):
  """Create model graph.

  Builds a proposal distribution selected by `proposal_type`, then builds the
  model selected by `model_type` on top of that proposal. Most hyperparameters
  are read from `FLAGS`.

  Args:
    proposal_type: One of "bernoulli_vae", "conv_bernoulli_vae",
      "gaussian_vae", "nis", "rejection_sampling", "gaussian", "his", "lars".
    model_type: One of "bernoulli_vae", "gaussian_vae", "conv_gaussian_vae",
      "conv_bernoulli_vae", "nis", "conv_nis", "his", "lars", "identity".
      With "identity" the proposal itself is used as the model.
    data_dim: Data dimensions passed to the model (and to the proposal when
      the model is not a VAE).
    mean: Data mean. Passed to the Bernoulli VAE proposals and the Bernoulli
      VAE model unconditionally; passed to the other models only when
      `FLAGS.squash` is false, otherwise it is given to the
      `base.SquashedDistribution` wrapper instead.
    global_step: Global step tensor, used to build the KL weight via
      `make_kl_weight`.

  Returns:
    The constructed model. When `FLAGS.squash` is set, the model is wrapped in
    `base.SquashedDistribution`.

  Raises:
    ValueError: If `proposal_type` or `model_type` is not supported. (These
      cases previously surfaced as an UnboundLocalError.)
  """
  kl_weight = make_kl_weight(global_step, FLAGS.anneal_kl_step)
  decoder_hidden_sizes = _parse_hidden_sizes(FLAGS.decoder_hidden_sizes)
  q_hidden_sizes = _parse_hidden_sizes(FLAGS.q_hidden_sizes)
  energy_hidden_sizes = _parse_hidden_sizes(FLAGS.energy_hidden_sizes)

  # For VAE-style models the proposal has dimension [FLAGS.latent_dim]; for
  # the remaining models it has the same dimension as the data.
  if model_type in ("nis", "his", "lars", "conv_nis", "identity"):
    proposal_data_dim = data_dim
  elif model_type in ("bernoulli_vae", "gaussian_vae", "hisvae",
                      "conv_gaussian_vae", "conv_bernoulli_vae"):
    proposal_data_dim = [FLAGS.latent_dim]
  else:
    raise ValueError("Unknown model_type: {!r}".format(model_type))

  # Bernoulli VAE proposals get the data mean because they are proposing
  # images. Other proposals don't because they are proposing latent states.
  if proposal_type == "bernoulli_vae":
    proposal = vae.BernoulliVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=proposal_data_dim,
        data_mean=mean,
        decoder_hidden_sizes=decoder_hidden_sizes,
        q_hidden_sizes=q_hidden_sizes,
        scale_min=FLAGS.scale_min,
        kl_weight=kl_weight,
        reparameterize_sample=FLAGS.reparameterize_proposal,
        temperature=FLAGS.gst_temperature,
        dtype=tf.float32)
  elif proposal_type == "conv_bernoulli_vae":
    # NOTE: this branch uses `data_dim`, not `proposal_data_dim`.
    proposal = vae.ConvBernoulliVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=mean,
        scale_min=FLAGS.scale_min,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif proposal_type == "gaussian_vae":
    proposal = vae.GaussianVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=proposal_data_dim,
        decoder_hidden_sizes=decoder_hidden_sizes,
        decoder_nn_scale=FLAGS.vae_decoder_nn_scale,
        q_hidden_sizes=q_hidden_sizes,
        scale_min=FLAGS.scale_min,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif proposal_type == "nis":
    proposal = nis.NIS(
        K=FLAGS.K,
        data_dim=proposal_data_dim,
        energy_hidden_sizes=energy_hidden_sizes,
        dtype=tf.float32)
  elif proposal_type == "rejection_sampling":
    proposal = rejection_sampling.RejectionSampling(
        T=FLAGS.K,
        data_dim=proposal_data_dim,
        energy_hidden_sizes=energy_hidden_sizes,
        dtype=tf.float32)
  elif proposal_type == "gaussian":
    proposal = base.get_independent_normal(proposal_data_dim)
  elif proposal_type == "his":
    proposal = his.FullyConnectedHIS(
        T=FLAGS.his_T,
        data_dim=proposal_data_dim,
        energy_hidden_sizes=energy_hidden_sizes,
        q_hidden_sizes=q_hidden_sizes,
        learn_temps=FLAGS.learn_his_temps,
        learn_stepsize=FLAGS.learn_his_stepsize,
        init_alpha=FLAGS.his_init_alpha,
        init_step_size=FLAGS.his_init_stepsize,
        dtype=tf.float32)
  elif proposal_type == "lars":
    proposal = lars.LARS(
        K=FLAGS.K,
        T=FLAGS.lars_T,
        data_dim=proposal_data_dim,
        accept_fn_layers=energy_hidden_sizes,
        proposal=None,
        data_mean=None,
        ema_decay=0.99,
        is_eval=FLAGS.mode == "eval",
        dtype=tf.float32)
  else:
    raise ValueError("Unknown proposal_type: {!r}".format(proposal_type))

  # When squashing, the data mean is handled by the SquashedDistribution
  # wrapper at the end, so the wrapped model is built without it.
  model_data_mean = None if FLAGS.squash else mean

  if model_type == "bernoulli_vae":
    # Unlike the other models, this one always receives the data mean.
    model = vae.BernoulliVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=mean,
        decoder_hidden_sizes=decoder_hidden_sizes,
        q_hidden_sizes=q_hidden_sizes,
        scale_min=FLAGS.scale_min,
        proposal=proposal,
        kl_weight=kl_weight,
        reparameterize_sample=False,
        dtype=tf.float32)
  elif model_type == "gaussian_vae":
    model = vae.GaussianVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=model_data_mean,
        decoder_hidden_sizes=decoder_hidden_sizes,
        decoder_nn_scale=FLAGS.vae_decoder_nn_scale,
        q_hidden_sizes=q_hidden_sizes,
        scale_min=FLAGS.scale_min,
        proposal=proposal,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif model_type == "conv_gaussian_vae":
    model = vae.ConvGaussianVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=model_data_mean,
        scale_min=FLAGS.scale_min,
        proposal=proposal,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif model_type == "conv_bernoulli_vae":
    model = vae.ConvBernoulliVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=model_data_mean,
        scale_min=FLAGS.scale_min,
        proposal=proposal,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif model_type == "nis":
    model = nis.NIS(
        K=FLAGS.K,
        data_dim=data_dim,
        data_mean=model_data_mean,
        energy_hidden_sizes=energy_hidden_sizes,
        proposal=proposal,
        reparameterize_proposal_samples=FLAGS.reparameterize_proposal,
        dtype=tf.float32)
  elif model_type == "conv_nis":
    model = nis.ConvNIS(
        K=FLAGS.K,
        data_dim=data_dim,
        data_mean=model_data_mean,
        proposal=proposal,
        reparameterize_proposal_samples=FLAGS.reparameterize_proposal,
        dtype=tf.float32)
  elif model_type == "his":
    model = his.FullyConnectedHIS(
        proposal=proposal,
        T=FLAGS.his_T,
        data_dim=data_dim,
        data_mean=model_data_mean,
        energy_hidden_sizes=energy_hidden_sizes,
        q_hidden_sizes=q_hidden_sizes,
        learn_temps=FLAGS.learn_his_temps,
        learn_stepsize=FLAGS.learn_his_stepsize,
        init_alpha=FLAGS.his_init_alpha,
        init_step_size=FLAGS.his_init_stepsize,
        dtype=tf.float32)
  elif model_type == "lars":
    model = lars.LARS(
        K=FLAGS.K,
        T=FLAGS.lars_T,
        data_dim=data_dim,
        accept_fn_layers=energy_hidden_sizes,
        proposal=proposal,
        data_mean=model_data_mean,
        ema_decay=0.99,
        is_eval=FLAGS.mode == "eval",
        dtype=tf.float32)
  elif model_type == "identity":
    model = proposal
  # NOTE: "hisvae" is accepted above when choosing the proposal dimension, but
  # its model construction is disabled, so it falls through to the error
  # below. The disabled construction is kept for reference:
  # elif model_type == "hisvae":
  #   model = his.HISVAE(
  #       latent_dim=FLAGS.latent_dim,
  #       proposal=proposal,
  #       T=FLAGS.his_T,
  #       data_dim=data_dim,
  #       data_mean=mean,
  #       energy_hidden_sizes=energy_hidden_sizes,
  #       q_hidden_sizes=q_hidden_sizes,
  #       decoder_hidden_sizes=decoder_hidden_sizes,
  #       learn_temps=FLAGS.learn_his_temps,
  #       learn_stepsize=FLAGS.learn_his_stepsize,
  #       init_alpha=FLAGS.his_init_alpha,
  #       init_step_size=FLAGS.his_init_stepsize,
  #       squash=FLAGS.squash,
  #       kl_weight=kl_weight,
  #       dtype=tf.float32)
  else:
    raise ValueError(
        "model_type {!r} is not supported.".format(model_type))

  if FLAGS.squash:
    model = base.SquashedDistribution(distribution=model, data_mean=mean)

  return model