예제 #1
0
    def __init__(self,
                 T,
                 data_dim,
                 energy_fn,
                 q_fn,
                 proposal=None,
                 init_alpha=1.,
                 init_step_size=0.01,
                 learn_temps=False,
                 learn_stepsize=False,
                 dtype=tf.float32,
                 name="his"):
        """Creates the HIS model's temperature and step-size variables.

        Args:
          T: Number of timesteps. One raw temperature variable is created per
            step; `self.log_alphas` ends up with T + 1 entries.
          data_dim: Dimension of the data. Used as the shape of the step-size
            variable and of the default proposal / momentum distributions.
          energy_fn: Energy function, stored as `self.energy_fn`.
          q_fn: Callable stored as `self.q`.
          proposal: Distribution over the data space. Defaults to
            `base.get_independent_normal(data_dim)` when None.
          init_alpha: Initial value of each per-step alpha. It is mapped through
            an epsilon-stabilised inverse sigmoid so that init_alpha=1 stays
            finite.
          init_step_size: Initial step size, mapped through an inverse softplus.
          learn_temps: Whether the raw temperature variables are trainable.
          learn_stepsize: Whether the raw step-size variable is trainable.
          dtype: Type of the created variables.
          name: Variable scope name.
        """
        self.timesteps = T
        self.data_dim = data_dim
        self.energy_fn = energy_fn
        self.q = q_fn

        # Inverse of log_sigmoid's sigmoid; the 1e-4 keeps init_alpha=1 finite.
        init_alpha = -np.log(1. / init_alpha - 1. + 1e-4)
        with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
            self.raw_alphas = [
                tf.get_variable(
                    name="raw_alpha_%d" % t,
                    shape=[],
                    dtype=dtype,
                    initializer=tf.constant_initializer(init_alpha),
                    trainable=learn_temps) for t in range(T)
            ]
            # -softplus(-x) == log_sigmoid(x), so each log_alpha is <= 0.
            self.log_alphas = [
                -tf.nn.softplus(-raw_alpha) for raw_alpha in self.raw_alphas
            ]
            # Prepend the negated sum so that all log-alphas sum to zero,
            # i.e. the alphas multiply to one.
            self.log_alphas = [-tf.reduce_sum(self.log_alphas)
                               ] + self.log_alphas
            # Inverse softplus; expm1 stays accurate for small step sizes.
            init_step_size = np.log(np.expm1(init_step_size))
            self.raw_step_size = tf.get_variable(
                name="raw_step_size",
                shape=data_dim,
                dtype=dtype,
                initializer=tf.constant_initializer(init_step_size),
                trainable=learn_stepsize)
            self.step_size = tf.math.softplus(self.raw_step_size)
            tf.summary.scalar("his_step_size", tf.reduce_mean(self.step_size))
            for t, log_alpha in enumerate(self.log_alphas):
                tf.summary.scalar("his_alpha/alpha_%d" % t, tf.exp(log_alpha))

        if proposal is None:
            self.proposal = base.get_independent_normal(data_dim)
        else:
            self.proposal = proposal
        self.momentum_proposal = base.get_independent_normal(data_dim)
예제 #2
0
def make_lars_graph(
        target_dist,  # pylint: disable=invalid-name
        K,
        batch_size,
        eval_batch_size,
        lr,
        mlp_layers,
        dtype=tf.float32):
    """Construct the training graph for LARS.

    Args:
      target_dist: Distribution to fit; must support sample().
      K: Passed to lars.LARS as both K and T.
      batch_size: Number of training samples drawn from target_dist per step.
      eval_batch_size: Number of evaluation samples drawn from target_dist.
      lr: Adam learning rate.
      mlp_layers: Hidden layer sizes of the LARS acceptance function.
      dtype: Type of the model tensors.

    Returns:
      A tuple (loss, train_op, global_step). `loss` is the negative mean
      training log-probability. `train_op` runs the model's post-train op
      after the Adam update.
    """
    proposal = base.get_independent_normal([2], FLAGS.proposal_variance)
    model = lars.LARS(K=K,
                      T=K,
                      data_dim=[2],
                      accept_fn_layers=mlp_layers,
                      proposal=proposal,
                      dtype=dtype)

    train_data = target_dist.sample(batch_size)
    log_p = model.log_prob(train_data)
    test_data = target_dist.sample(eval_batch_size)
    eval_log_p = model.log_prob(test_data)

    global_step = tf.train.get_or_create_global_step()
    # Build the mean once; it is reused for the loss, the summary and the
    # return value instead of creating duplicate reduce_mean ops.
    mean_log_p = tf.reduce_mean(log_p)
    loss = -mean_log_p
    opt = tf.train.AdamOptimizer(lr)
    grads = opt.compute_gradients(loss)
    apply_grads_op = opt.apply_gradients(grads, global_step=global_step)
    # The model's post-train op must see the updated parameters.
    with tf.control_dependencies([apply_grads_op]):
        train_op = model.post_train_op()

    density_image_summary(
        lambda x: tf.squeeze(model.accept_fn(x)) + model.proposal.log_prob(x),
        FLAGS.density_num_bins, "energy/lars")
    tf.summary.scalar("elbo", mean_log_p)
    tf.summary.scalar("eval_elbo", tf.reduce_mean(eval_log_p))
    return loss, train_op, global_step
예제 #3
0
 def __init__(self,
              latent_dim,
              data_dim,
              proposal=None,
              data_mean=None,
              kl_weight=1.,
              dtype=tf.float32):
     """Create HVAE and build its graph.

     Args:
       latent_dim: Size of the latent space. The default proposal is
         `base.get_independent_normal([latent_dim])`.
       data_dim: Dimension of the data.
       proposal: Distribution over the latent space. An independent normal is
         used when None.
       data_mean: Data mean (presumably used to center inputs in `_build`).
         Falls back to the Python float 0. when None.
       kl_weight: Weighting stored as `self.kl_weight`.
       dtype: Type of the tensors.
     """
     self.latent_dim, self.data_dim = latent_dim, data_dim
     self.data_mean = 0. if data_mean is None else data_mean
     self.kl_weight = kl_weight
     self.dtype = dtype
     self.proposal = (base.get_independent_normal([latent_dim])
                      if proposal is None else proposal)
     # Graph construction runs last, once every attribute above is set.
     self._build()
    def __init__(
            self,  # pylint: disable=invalid-name
            T,
            data_dim,
            logit_accept_fn,
            proposal=None,
            dtype=tf.float32,
            name="rejection_sampling"):
        """Creates a Rejection Sampling model.

        Args:
          T: The maximum number of proposals to sample in the rejection
            sampler.
          data_dim: The dimension of the data. Should be a list.
          logit_accept_fn: Acceptance function. It maps a batch of shape
            [batch_size] + data_dim to acceptance *logits*; a sigmoid turns
            those into acceptance probabilities in [0, 1].
          proposal: A distribution over the data space of this model. Must
            support sample() and log_prob(), although log_prob only needs to
            return a lower bound on the true log probability. Defaults to
            `base.get_independent_normal(data_dim)` when not supplied.
          dtype: Type of data.
          name: Name to use in scopes.
        """
        self.name = name
        self.dtype = dtype
        self.T = T  # pylint: disable=invalid-name
        self.data_dim = data_dim
        self.logit_accept_fn = logit_accept_fn
        self.proposal = (proposal if proposal is not None
                         else base.get_independent_normal(data_dim))

        with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
            # Trainable scalar logit, initialised to 0 (sigmoid(0) = 0.5);
            # presumably the logit of the acceptance normaliser Z.
            self.logit_Z = tf.get_variable(  # pylint: disable=invalid-name
                name="logit_Z",
                shape=[],
                dtype=dtype,
                initializer=tf.constant_initializer(0.),
                trainable=True)

        # exp(-log_sigmoid(logit_Z)) == 1 / sigmoid(logit_Z).
        expected_trials = tf.exp(-tf.log_sigmoid(self.logit_Z))
        tf.summary.scalar("Expected_trials", expected_trials)
예제 #5
0
    def __init__(self,
                 latent_dim,
                 data_dim,
                 decoder,
                 q,
                 proposal=None,
                 data_mean=None,
                 kl_weight=1.,
                 dtype=tf.float32):
        """Creates a VAE.

        Args:
          latent_dim: The size of the latent variable of the VAE. Stored as
            `self.latent_dim`.
          data_dim: The size of the input data.
          decoder: A callable that accepts a batch of latent samples and
            returns a distribution over the data space of the VAE. The
            distribution should support sample() and log_prob().
          q: A callable that accepts a batch of data samples and returns a
            distribution over the latent space of the VAE. The distribution
            should support sample() and log_prob().
          proposal: A distribution over the latent space of the VAE. The object
            must support sample() and log_prob(). If not provided, defaults to
            `base.get_independent_normal([latent_dim])`.
          data_mean: Mean of the data used to center the input. Defaults to a
            scalar zero tensor of `dtype`.
          kl_weight: Weighting on the KL regularizer.
          dtype: Type of the tensors.
        """
        # Keep latent_dim on the instance, consistent with the other VAE-style
        # constructors that record it.
        self.latent_dim = latent_dim
        self.data_dim = data_dim
        if data_mean is not None:
            self.data_mean = data_mean
        else:
            self.data_mean = tf.zeros((), dtype=dtype)
        self.decoder = decoder
        self.q = q
        self.kl_weight = kl_weight

        self.dtype = dtype
        if proposal is None:
            self.proposal = base.get_independent_normal([latent_dim])
        else:
            self.proposal = proposal
예제 #6
0
    def __init__(self,
                 K,
                 T,
                 data_dim,
                 accept_fn_layers,
                 proposal=None,
                 data_mean=None,
                 ema_decay=0.99,
                 dtype=tf.float32,
                 is_eval=False):
        """Creates a LARS model.

        Args:
          K: Stored as `self.k`.
          T: Stored as `self.T`.
          data_dim: Dimension of the data; used for the default proposal.
          accept_fn_layers: Hidden layer sizes of the acceptance MLP (any
            sequence). A final layer of size 1 with a log-sigmoid activation
            is appended.
          proposal: Distribution over the data space. Defaults to
            `base.get_independent_normal(data_dim)` when None.
          data_mean: Data mean; defaults to a scalar zero tensor of `dtype`.
          ema_decay: Stored as `self.ema_decay` (presumably the decay rate of
            the Z_ema moving average — confirm in the update op).
          dtype: Type of the tensors, including `Z_ema` and `Z_estimate`.
          is_eval: If True, also creates the scalar `Z_estimate` placeholder.
        """
        self.k = K
        self.T = T  # pylint: disable=invalid-name
        self.data_dim = data_dim
        self.ema_decay = ema_decay
        self.dtype = dtype
        if data_mean is not None:
            self.data_mean = data_mean
        else:
            self.data_mean = tf.zeros((), dtype=dtype)
        # list() so tuples (or other sequences) of layer sizes also work.
        self.accept_fn = functools.partial(
            base.mlp,
            layer_sizes=list(accept_fn_layers) + [1],
            final_activation=tf.math.log_sigmoid,
            name="a")
        if proposal is None:
            self.proposal = base.get_independent_normal(data_dim)
        else:
            self.proposal = proposal
        self.is_eval = is_eval
        if is_eval:
            # Match the model dtype so it is consistent with Z_ema below.
            self.Z_estimate = tf.placeholder(dtype, shape=[])  # pylint: disable=invalid-name

        # NOTE(review): this scope name is fixed and uses AUTO_REUSE. Unless
        # callers wrap construction in distinct outer scopes, every LARS
        # instance in a graph appears to share this Z_ema variable. Renaming
        # would change checkpoint variable names, so it is left as is —
        # confirm the sharing is intended.
        with tf.variable_scope("LARS_Z_ema", reuse=tf.AUTO_REUSE):
            self.Z_ema = tf.get_variable(  # pylint: disable=invalid-name
                name="LARS_Z_ema",
                shape=[],
                dtype=dtype,
                initializer=tf.constant_initializer(0.5),
                trainable=False)
예제 #7
0
    def __init__(
            self,  # pylint: disable=invalid-name
            K,
            data_dim,
            energy_fn,
            proposal=None,
            data_mean=None,
            reparameterize_proposal_samples=True,
            dtype=tf.float32):
        """Creates a NIS model.

        Args:
          K: The number of proposal samples to take.
          data_dim: The dimension of the data (a list).
          energy_fn: Energy function.
          proposal: A distribution over the data space of this model. Must
            support sample() and log_prob(), although log_prob only needs to
            return a lower bound on the true log probability. Defaults to
            `base.get_independent_normal(data_dim)` when not supplied.
          data_mean: Mean of the data used to center the input. Defaults to a
            scalar zero tensor of `dtype`.
          reparameterize_proposal_samples: Whether to allow gradients to pass
            through the proposal samples.
          dtype: Type of the tensors.
        """
        self.K = K  # pylint: disable=invalid-name
        self.data_dim = data_dim  # self.data_dim is always a list
        self.reparameterize_proposal_samples = reparameterize_proposal_samples
        self.data_mean = (data_mean if data_mean is not None
                          else tf.zeros((), dtype=dtype))
        self.energy_fn = energy_fn
        self.proposal = (base.get_independent_normal(self.data_dim)
                         if proposal is None else proposal)
예제 #8
0
def _parse_hidden_sizes(sizes):
  """Parses a comma-separated string of ints, e.g. "300, 300" -> [300, 300]."""
  return [int(x.strip()) for x in sizes.split(",")]


def make_model(proposal_type, model_type, data_dim, mean, global_step):
  """Create model graph.

  Builds the proposal named by `proposal_type`, then the model named by
  `model_type` on top of it. Hyperparameters are read from FLAGS.

  Args:
    proposal_type: One of "bernoulli_vae", "conv_bernoulli_vae",
      "gaussian_vae", "nis", "rejection_sampling", "gaussian", "his", "lars".
    model_type: One of "nis", "his", "lars", "conv_nis", "identity",
      "bernoulli_vae", "gaussian_vae", "conv_gaussian_vae",
      "conv_bernoulli_vae". "identity" returns the proposal itself.
    data_dim: Dimension of the data. It is also the proposal's dimension for
      data-space models.
    mean: Data mean. Passed as `data_mean` to the proposal/model constructors
      (most models receive None instead when FLAGS.squash is set) and to
      base.SquashedDistribution.
    global_step: Global step, passed to make_kl_weight.

  Returns:
    The model, wrapped in base.SquashedDistribution when FLAGS.squash is set.

  Raises:
    ValueError: If `model_type` or `proposal_type` is not supported.
  """
  # Models whose proposal lives in data space.
  data_space_models = ("nis", "his", "lars", "conv_nis", "identity")
  # Models whose proposal lives in the latent space. "hisvae" is currently
  # disabled (see the commented-out construction below), so it is rejected.
  latent_space_models = ("bernoulli_vae", "gaussian_vae", "conv_gaussian_vae",
                         "conv_bernoulli_vae")
  # Validate up front so an unsupported model fails with a clear message
  # instead of a NameError/UnboundLocalError after the proposal is built.
  if model_type in data_space_models:
    proposal_data_dim = data_dim
  elif model_type in latent_space_models:
    proposal_data_dim = [FLAGS.latent_dim]
  else:
    raise ValueError("Unsupported model_type: %s" % model_type)

  kl_weight = make_kl_weight(global_step, FLAGS.anneal_kl_step)
  decoder_hidden_sizes = _parse_hidden_sizes(FLAGS.decoder_hidden_sizes)
  q_hidden_sizes = _parse_hidden_sizes(FLAGS.q_hidden_sizes)
  energy_hidden_sizes = _parse_hidden_sizes(FLAGS.energy_hidden_sizes)

  # Bernoulli VAE proposals get the data mean because they are proposing
  # images. Other proposals don't because they are proposing latent states.
  if proposal_type == "bernoulli_vae":
    proposal = vae.BernoulliVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=proposal_data_dim,
        data_mean=mean,
        decoder_hidden_sizes=decoder_hidden_sizes,
        q_hidden_sizes=q_hidden_sizes,
        scale_min=FLAGS.scale_min,
        kl_weight=kl_weight,
        reparameterize_sample=FLAGS.reparameterize_proposal,
        temperature=FLAGS.gst_temperature,
        dtype=tf.float32)
  elif proposal_type == "conv_bernoulli_vae":
    proposal = vae.ConvBernoulliVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=mean,
        scale_min=FLAGS.scale_min,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif proposal_type == "gaussian_vae":
    proposal = vae.GaussianVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=proposal_data_dim,
        decoder_hidden_sizes=decoder_hidden_sizes,
        decoder_nn_scale=FLAGS.vae_decoder_nn_scale,
        q_hidden_sizes=q_hidden_sizes,
        scale_min=FLAGS.scale_min,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif proposal_type == "nis":
    proposal = nis.NIS(
        K=FLAGS.K,
        data_dim=proposal_data_dim,
        energy_hidden_sizes=energy_hidden_sizes,
        dtype=tf.float32)
  elif proposal_type == "rejection_sampling":
    proposal = rejection_sampling.RejectionSampling(
        T=FLAGS.K,
        data_dim=proposal_data_dim,
        energy_hidden_sizes=energy_hidden_sizes,
        dtype=tf.float32)
  elif proposal_type == "gaussian":
    proposal = base.get_independent_normal(proposal_data_dim)
  elif proposal_type == "his":
    proposal = his.FullyConnectedHIS(
        T=FLAGS.his_T,
        data_dim=proposal_data_dim,
        energy_hidden_sizes=energy_hidden_sizes,
        q_hidden_sizes=q_hidden_sizes,
        learn_temps=FLAGS.learn_his_temps,
        learn_stepsize=FLAGS.learn_his_stepsize,
        init_alpha=FLAGS.his_init_alpha,
        init_step_size=FLAGS.his_init_stepsize,
        dtype=tf.float32)
  elif proposal_type == "lars":
    proposal = lars.LARS(
        K=FLAGS.K,
        T=FLAGS.lars_T,
        data_dim=proposal_data_dim,
        accept_fn_layers=energy_hidden_sizes,
        proposal=None,
        data_mean=None,
        ema_decay=0.99,
        is_eval=FLAGS.mode == "eval",
        dtype=tf.float32)
  else:
    raise ValueError("Unsupported proposal_type: %s" % proposal_type)

  # model_type was validated above, so exactly one branch below matches.
  if model_type == "bernoulli_vae":
    model = vae.BernoulliVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=mean,
        decoder_hidden_sizes=decoder_hidden_sizes,
        q_hidden_sizes=q_hidden_sizes,
        scale_min=FLAGS.scale_min,
        proposal=proposal,
        kl_weight=kl_weight,
        reparameterize_sample=False,
        dtype=tf.float32)
  elif model_type == "gaussian_vae":
    model = vae.GaussianVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=None if FLAGS.squash else mean,
        decoder_hidden_sizes=decoder_hidden_sizes,
        decoder_nn_scale=FLAGS.vae_decoder_nn_scale,
        q_hidden_sizes=q_hidden_sizes,
        scale_min=FLAGS.scale_min,
        proposal=proposal,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif model_type == "conv_gaussian_vae":
    model = vae.ConvGaussianVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=None if FLAGS.squash else mean,
        scale_min=FLAGS.scale_min,
        proposal=proposal,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif model_type == "conv_bernoulli_vae":
    model = vae.ConvBernoulliVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=None if FLAGS.squash else mean,
        scale_min=FLAGS.scale_min,
        proposal=proposal,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif model_type == "nis":
    model = nis.NIS(
        K=FLAGS.K,
        data_dim=data_dim,
        data_mean=None if FLAGS.squash else mean,
        energy_hidden_sizes=energy_hidden_sizes,
        proposal=proposal,
        reparameterize_proposal_samples=FLAGS.reparameterize_proposal,
        dtype=tf.float32)
  elif model_type == "conv_nis":
    model = nis.ConvNIS(
        K=FLAGS.K,
        data_dim=data_dim,
        data_mean=None if FLAGS.squash else mean,
        proposal=proposal,
        reparameterize_proposal_samples=FLAGS.reparameterize_proposal,
        dtype=tf.float32)
  elif model_type == "his":
    model = his.FullyConnectedHIS(
        proposal=proposal,
        T=FLAGS.his_T,
        data_dim=data_dim,
        data_mean=None if FLAGS.squash else mean,
        energy_hidden_sizes=energy_hidden_sizes,
        q_hidden_sizes=q_hidden_sizes,
        learn_temps=FLAGS.learn_his_temps,
        learn_stepsize=FLAGS.learn_his_stepsize,
        init_alpha=FLAGS.his_init_alpha,
        init_step_size=FLAGS.his_init_stepsize,
        dtype=tf.float32)
  elif model_type == "lars":
    model = lars.LARS(
        K=FLAGS.K,
        T=FLAGS.lars_T,
        data_dim=data_dim,
        accept_fn_layers=energy_hidden_sizes,
        proposal=proposal,
        data_mean=None if FLAGS.squash else mean,
        ema_decay=0.99,
        is_eval=FLAGS.mode == "eval",
        dtype=tf.float32)
  elif model_type == "identity":
    model = proposal

#  Disabled "hisvae" model, kept for reference:
#  elif model_type == "hisvae":
#    model = his.HISVAE(
#        latent_dim=FLAGS.latent_dim,
#        proposal=proposal,
#        T=FLAGS.his_T,
#        data_dim=data_dim,
#        data_mean=mean,
#        energy_hidden_sizes=energy_hidden_sizes,
#        q_hidden_sizes=q_hidden_sizes,
#        decoder_hidden_sizes=decoder_hidden_sizes,
#        learn_temps=FLAGS.learn_his_temps,
#        learn_stepsize=FLAGS.learn_his_stepsize,
#        init_alpha=FLAGS.his_init_alpha,
#        init_step_size=FLAGS.his_init_stepsize,
#        squash=FLAGS.squash,
#        kl_weight=kl_weight,
#        dtype=tf.float32)

  if FLAGS.squash:
    model = base.SquashedDistribution(distribution=model, data_mean=mean)

  return model
예제 #9
0
def _save_density_plot(plot):
  """Evaluates `plot` from the checkpoint in FLAGS.logdir and saves it.

  The evaluated value is written with np.save to FLAGS.logdir/density.
  """
  with tf.train.SingularMonitoredSession(
      checkpoint_dir=FLAGS.logdir) as sess:
    plot_value = sess.run(plot)
    # np.save emits binary data, so open the output in binary mode.
    with tf.io.gfile.GFile(os.path.join(FLAGS.logdir, "density"),
                           "wb") as out:
      np.save(out, plot_value)


def main(unused_argv):
  """Produces density plots or sample summaries for a 2-D model.

  FLAGS.algo selects the behavior:
    * "density": plots the target distribution's log_prob.
    * "lars": plots accept_fn + proposal log_prob of a SimpleLARS model.
    * "nis" / "his" / "rejection_sampling": draws samples from the model and
      writes a sample-density summary.

  Raises:
    ValueError: If FLAGS.algo is not one of the values above.
  """
  g = tf.Graph()
  with g.as_default():
    energy_fn_layers = [
        int(x.strip()) for x in FLAGS.energy_fn_sizes.split(",")
    ]
    if FLAGS.algo == "density":
      target = dists.get_target_distribution(FLAGS.target)
      _save_density_plot(
          make_density_summary(target.log_prob, num_bins=FLAGS.num_bins))
    elif FLAGS.algo == "lars":
      tf.logging.info("Running LARS")
      proposal = base.get_independent_normal([2], FLAGS.proposal_variance)
      model = lars.SimpleLARS(
          K=FLAGS.K, data_dim=[2], accept_fn_layers=energy_fn_layers,
          proposal=proposal)
      _save_density_plot(
          make_density_summary(
              lambda x: tf.squeeze(model.accept_fn(x)) +
              model.proposal.log_prob(x),
              num_bins=FLAGS.num_bins))
    else:
      proposal = base.get_independent_normal([2], FLAGS.proposal_variance)
      if FLAGS.algo == "nis":
        tf.logging.info("Running NIS")
        model = nis.NIS(
            K=FLAGS.K, data_dim=[2], energy_hidden_sizes=energy_fn_layers,
            proposal=proposal)
      elif FLAGS.algo == "his":
        tf.logging.info("Running HIS")
        model = his.FullyConnectedHIS(
            T=FLAGS.his_t,
            data_dim=[2],
            energy_hidden_sizes=energy_fn_layers,
            q_hidden_sizes=energy_fn_layers,
            init_step_size=FLAGS.his_stepsize,
            learn_stepsize=FLAGS.his_learn_stepsize,
            init_alpha=FLAGS.his_alpha,
            learn_temps=FLAGS.his_learn_alpha,
            proposal=proposal)
      elif FLAGS.algo == "rejection_sampling":
        tf.logging.info("Running Rejection Sampling")
        model = rejection_sampling.RejectionSampling(
            T=FLAGS.K, data_dim=[2], energy_hidden_sizes=energy_fn_layers,
            proposal=proposal)
      else:
        # Without this, `model` would be unbound below.
        raise ValueError("Unsupported algo: %s" % FLAGS.algo)
      samples = model.sample(FLAGS.batch_size)
      with tf.train.SingularMonitoredSession(
          checkpoint_dir=FLAGS.logdir) as sess:
        make_sample_density_summary(
            sess,
            samples,
            max_samples_per_batch=FLAGS.batch_size,
            num_samples=FLAGS.num_samples,
            num_bins=FLAGS.num_bins)
예제 #10
0
def main(unused_argv):
    """Trains a model on a 2-D target distribution.

    FLAGS.algo selects the model ("lars", "nis", "rejection_sampling" or
    "his"). Training runs in a MonitoredTrainingSession that checkpoints to
    FLAGS.logdir/exp_name(). It stops when the session requests a stop or once
    the fetched global step exceeds FLAGS.max_steps.

    Raises:
      ValueError: If FLAGS.algo is not one of the supported values.
    """
    g = tf.Graph()
    with g.as_default():
        target = dists.get_target_distribution(
            FLAGS.target,
            nine_gaussians_variance=FLAGS.nine_gaussians_variance)
        energy_fn_layers = [
            int(x.strip()) for x in FLAGS.energy_fn_sizes.split(",")
        ]
        if FLAGS.algo == "lars":
            print("Running LARS")
            loss, train_op, global_step = make_lars_graph(
                target_dist=target,
                K=FLAGS.K,
                batch_size=FLAGS.batch_size,
                eval_batch_size=FLAGS.eval_batch_size,
                lr=FLAGS.learning_rate,
                mlp_layers=energy_fn_layers,
                dtype=tf.float32)
        else:
            proposal = base.get_independent_normal([2],
                                                   FLAGS.proposal_variance)
            if FLAGS.algo == "nis":
                print("Running NIS")
                model = nis.NIS(K=FLAGS.K,
                                data_dim=[2],
                                energy_hidden_sizes=energy_fn_layers,
                                proposal=proposal)
                density_image_summary(
                    lambda x:  # pylint: disable=g-long-lambda
                    (tf.squeeze(model.energy_fn(x)) + model.proposal.log_prob(
                        x)),
                    FLAGS.density_num_bins,
                    "energy/nis")
            elif FLAGS.algo == "rejection_sampling":
                print("Running Rejection Sampling")
                model = rejection_sampling.RejectionSampling(
                    T=FLAGS.K,
                    data_dim=[2],
                    energy_hidden_sizes=energy_fn_layers,
                    proposal=proposal)
                density_image_summary(
                    lambda x: tf.squeeze(  # pylint: disable=g-long-lambda
                        tf.log_sigmoid(model.logit_accept_fn(x)),
                        axis=-1) + model.proposal.log_prob(x),
                    FLAGS.density_num_bins,
                    "energy/trs")
            elif FLAGS.algo == "his":
                print("Running HIS")
                model = his.FullyConnectedHIS(
                    T=FLAGS.his_t,
                    data_dim=[2],
                    energy_hidden_sizes=energy_fn_layers,
                    q_hidden_sizes=energy_fn_layers,
                    init_step_size=FLAGS.his_stepsize,
                    learn_stepsize=FLAGS.his_learn_stepsize,
                    init_alpha=FLAGS.his_alpha,
                    learn_temps=FLAGS.his_learn_alpha,
                    proposal=proposal)
                density_image_summary(
                    lambda x: -model.hamiltonian_potential(x),
                    FLAGS.density_num_bins, "energy/his")
                sample_image_summary(model,
                                     "density/his",
                                     num_samples=100000,
                                     num_bins=50)
            else:
                # Without this, `model` would be unbound in make_train_graph.
                raise ValueError("Unsupported algo: %s" % FLAGS.algo)

            loss, train_op, global_step = make_train_graph(
                target_dist=target,
                model=model,
                batch_size=FLAGS.batch_size,
                eval_batch_size=FLAGS.eval_batch_size,
                lr=FLAGS.learning_rate)

        log_hooks = make_log_hooks(global_step, loss)
        with tf.train.MonitoredTrainingSession(
                master="",
                is_chief=True,
                hooks=log_hooks,
                checkpoint_dir=os.path.join(FLAGS.logdir, exp_name()),
                save_checkpoint_secs=120,
                save_summaries_steps=FLAGS.summarize_every,
                log_step_count_steps=FLAGS.summarize_every) as sess:
            cur_step = -1
            while True:
                # Stop once the fetched global step exceeds max_steps.
                if sess.should_stop() or cur_step > FLAGS.max_steps:
                    break
                # run a step
                _, cur_step = sess.run([train_op, global_step])