Exemplo n.º 1
0
  def testPreconditionedHMC(self):
    """Checks HMC moments on a Softplus-transformed multivariate normal.

    Runs 2000 HMC steps from a `[16, 2]` initial state using the log-prob
    returned by `fun_mcmc.transform_log_prob_fn`, drops the first 1000 steps,
    and compares the mean and covariance of the traced `state_extra[0]` values
    against moments estimated from samples drawn directly from the target
    distribution.
    """
    step_size = self._constant(0.2)
    num_steps = 2000
    num_leapfrog_steps = 10
    state = tf.ones([16, 2], dtype=self._dtype)

    base_mean = self._constant([1., 0])
    base_cov = self._constant([[1, 0.5], [0.5, 1]])

    bijector = tfp.bijectors.Softplus()
    base_dist = tfp.distributions.MultivariateNormalFullCovariance(
        loc=base_mean, covariance_matrix=base_cov)
    target_dist = bijector(base_dist)

    def orig_target_log_prob_fn(x):
      return target_dist.log_prob(x), ()

    target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
        orig_target_log_prob_fn, bijector, state)

    # pylint: disable=g-long-lambda
    def kernel(hmc_state, seed):
      # On JAX the seed is carried in the loop state and split each step;
      # otherwise `_test_seed()` is called on every kernel invocation.
      if not self._is_on_jax:
        hmc_seed = _test_seed()
      else:
        hmc_seed, seed = util.split_seed(seed, 2)
      hmc_state, _ = fun_mcmc.hamiltonian_monte_carlo(
          hmc_state,
          step_size=step_size,
          num_integrator_steps=num_leapfrog_steps,
          target_log_prob_fn=target_log_prob_fn,
          seed=hmc_seed)
      return (hmc_state, seed), hmc_state.state_extra[0]

    if not self._is_on_jax:
      seed = _test_seed()
    else:
      seed = self._make_seed(_test_seed())

    # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs
    # for the jit to do anything.
    _, chain = tf.function(lambda state, seed: fun_mcmc.trace(  # pylint: disable=g-long-lambda
        state=(fun_mcmc.hamiltonian_monte_carlo_init(state, target_log_prob_fn),
               seed),
        fn=kernel,
        num_steps=num_steps))(state, seed)
    # Discard the warmup samples.
    chain = chain[1000:]

    sample_mean = tf.reduce_mean(chain, axis=[0, 1])
    sample_cov = tfp.stats.covariance(chain, sample_axis=[0, 1])

    true_samples = target_dist.sample(4096, seed=self._make_seed(_test_seed()))

    true_mean = tf.reduce_mean(true_samples, axis=0)
    # The reference covariance must come from the independent draws; computing
    # it from `chain` would compare the chain's covariance against itself and
    # make the assertion below vacuous.
    true_cov = tfp.stats.covariance(true_samples, sample_axis=0)

    self.assertAllClose(true_mean, sample_mean, rtol=0.1, atol=0.1)
    self.assertAllClose(true_cov, sample_cov, rtol=0.1, atol=0.1)
Exemplo n.º 2
0
  def testBasicHMC(self):
    """Compares HMC sample moments with those of a diagonal Gaussian target.

    The target log-density is an unnormalized independent normal with mean
    `[2, 3]` and scale `[2, 0.5]`. The chain runs for 2000 steps from a
    `[16, 2]` state of ones; the first 1000 steps are dropped, and the mean and
    variance of the rest must match moments of directly drawn normal samples
    under `rtol=atol=0.1`.
    """
    total_steps = 2000
    warmup_steps = 1000
    integrator_steps = 10
    hmc_step_size = 0.2
    initial_state = tf.ones([16, 2])

    target_mean = tf.constant([2., 3.])
    target_scale = tf.constant([2., 0.5])

    def target_log_prob_fn(x):
      standardized = (x - target_mean) / target_scale
      return -tf.reduce_sum(0.5 * tf.square(standardized), -1), ()

    def kernel(hmc_state, seed):
      # On JAX the seed travels in the loop state and is split each step;
      # otherwise `_test_seed()` is called on every kernel invocation.
      if self._is_on_jax:
        hmc_seed, seed = util.split_seed(seed, 2)
      else:
        hmc_seed = _test_seed()
      hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
          hmc_state,
          step_size=hmc_step_size,
          num_integrator_steps=integrator_steps,
          target_log_prob_fn=target_log_prob_fn,
          seed=hmc_seed)
      return (hmc_state, seed), hmc_extra

    seed = self._make_seed(_test_seed()) if self._is_on_jax else _test_seed()

    def trace_position(state, extra):  # pylint: disable=unused-argument
      # `state` is the `(hmc_state, seed)` pair returned by `kernel`.
      return state[0].state

    # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs
    # for the jit to do anything.
    @tf.function
    def run_chain(state, seed):
      initial_hmc_state = fun_mcmc.hamiltonian_monte_carlo_init(
          state, target_log_prob_fn)
      return fun_mcmc.trace(
          state=(initial_hmc_state, seed),
          fn=kernel,
          num_steps=total_steps,
          trace_fn=trace_position)

    _, chain = run_chain(initial_state, seed)
    # Discard the warmup samples.
    chain = chain[warmup_steps:]

    sample_mean = tf.reduce_mean(chain, axis=[0, 1])
    sample_var = tf.math.reduce_variance(chain, axis=[0, 1])

    # Reference draws from the same Gaussian, built by scaling and shifting
    # standard normal samples.
    true_samples = util.random_normal(
        shape=[4096, 2], dtype=tf.float32,
        seed=seed) * target_scale + target_mean

    true_mean = tf.reduce_mean(true_samples, axis=0)
    true_var = tf.math.reduce_variance(true_samples, axis=0)

    self.assertAllClose(true_mean, sample_mean, rtol=0.1, atol=0.1)
    self.assertAllClose(true_var, sample_var, rtol=0.1, atol=0.1)
Exemplo n.º 3
0
    def computation(state, seed):
      """Runs HMC with an Adam-adapted step size on a Softplus-transformed MVN.

      The log step size is held in an Adam state and updated every step with a
      surrogate loss built from `0.9 - mean_p_accept`, using a polynomially
      decaying learning rate.

      Args:
        state: Initial chain state, passed to `fun_mcmc.transform_log_prob_fn`.
        seed: Seed threaded through the loop state (split per step on JAX).

      Returns:
        chain: Traced `hmc_state.state_extra[0]` values for every step.
        log_accept_ratio_trace: Traced `log_accept_ratio` for every step.
        true_samples: 4096 samples drawn directly from the target distribution.
      """
      softplus = tfp.bijectors.Softplus()
      target_dist = softplus(
          tfp.distributions.MultivariateNormalFullCovariance(
              loc=base_mean, covariance_matrix=base_cov))

      def orig_target_log_prob_fn(x):
        return target_dist.log_prob(x), ()

      target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
          orig_target_log_prob_fn, softplus, state)

      def kernel(hmc_state, step_size_state, step, seed):
        if self._is_on_jax:
          hmc_seed, seed = util.split_seed(seed, 2)
        else:
          hmc_seed = _test_seed()

        # The optimizer state stores the log step size, hence the `exp`.
        hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
            hmc_state,
            step_size=tf.exp(step_size_state.state),
            num_integrator_steps=num_leapfrog_steps,
            target_log_prob_fn=target_log_prob_fn,
            seed=hmc_seed)

        learning_rate = fun_mcmc.prefab._polynomial_decay(  # pylint: disable=protected-access
            step=step,
            step_size=self._constant(0.01),
            power=0.5,
            decay_steps=num_adapt_steps,
            final_step_size=0.)
        accept_prob = tf.exp(
            tf.minimum(self._constant(0.), hmc_extra.log_accept_ratio))
        mean_p_accept = tf.reduce_mean(accept_prob)

        def acceptance_gap(_):
          return 0.9 - mean_p_accept, ()

        loss_fn = fun_mcmc.make_surrogate_loss_fn(acceptance_gap)
        step_size_state, _ = fun_mcmc.adam_step(
            step_size_state, loss_fn, learning_rate=learning_rate)

        next_state = (hmc_state, step_size_state, step + 1, seed)
        traced = (hmc_state.state_extra[0], hmc_extra.log_accept_ratio)
        return next_state, traced

      initial_loop_state = (
          fun_mcmc.hamiltonian_monte_carlo_init(state, target_log_prob_fn),
          fun_mcmc.adam_init(tf.math.log(step_size)),
          0,
          seed,
      )
      _, (chain, log_accept_ratio_trace) = fun_mcmc.trace(
          state=initial_loop_state,
          fn=kernel,
          num_steps=num_adapt_steps + num_steps,
      )
      true_samples = target_dist.sample(
          4096, seed=self._make_seed(_test_seed()))
      return chain, log_accept_ratio_trace, true_samples
Exemplo n.º 4
0
    def trace():
      """Runs four HMC steps from a `[1]`-shaped zero state, tracing nothing."""

      def kernel(state):
        return fun_mcmc.hamiltonian_monte_carlo(
            state,
            step_size=0.1,
            num_integrator_steps=3,
            target_log_prob_fn=target_log_prob_fn,
            seed=_test_seed())

      initial_state = fun_mcmc.hamiltonian_monte_carlo_init(
          tf.zeros([1]), target_log_prob_fn)
      # The trace function returns an empty tuple, so no per-step values are
      # collected, and the result of `fun_mcmc.trace` is discarded.
      fun_mcmc.trace(
          state=initial_state,
          fn=kernel,
          num_steps=4,
          trace_fn=lambda *args: ())
Exemplo n.º 5
0
        def trace():
            """Runs four HMC steps from a `[1]`-shaped zero state, tracing nothing."""

            def kernel(state):
                return fun_mcmc.hamiltonian_monte_carlo(
                    state,
                    step_size=0.1,
                    num_integrator_steps=3,
                    target_log_prob_fn=target_log_prob_fn,
                    seed=self._make_seed(tfp_test_util.test_seed()))

            initial_state = fun_mcmc.hamiltonian_monte_carlo_init(
                state=tf.zeros([1]), target_log_prob_fn=target_log_prob_fn)
            # The trace function returns an empty tuple, so no per-step values
            # are collected, and the result of `fun_mcmc.trace` is discarded.
            fun_mcmc.trace(
                state=initial_state,
                fn=kernel,
                num_steps=4,
                trace_fn=lambda *args: ())
Exemplo n.º 6
0
  def testPreconditionedHMC(self):
    """Checks HMC moments on a Softplus-transformed multivariate normal.

    Runs 2000 HMC steps from a `[16, 2]` initial state using the log-prob
    returned by `fun_mcmc.transform_log_prob_fn`, drops the first 1000 steps,
    and compares the mean and covariance of the traced `state_extra[0]` values
    against moments estimated from samples drawn directly from the target
    distribution.
    """
    step_size = 0.2
    num_steps = 2000
    num_leapfrog_steps = 10
    state = tf.ones([16, 2])

    base_mean = [1., 0]
    base_cov = [[1, 0.5], [0.5, 1]]

    bijector = tfb.Softplus()
    base_dist = tfd.MultivariateNormalFullCovariance(
        loc=base_mean, covariance_matrix=base_cov)
    target_dist = bijector(base_dist)

    def orig_target_log_prob_fn(x):
      return target_dist.log_prob(x), ()

    target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
        orig_target_log_prob_fn, bijector, state)

    # pylint: disable=g-long-lambda
    kernel = tf.function(lambda state: fun_mcmc.hamiltonian_monte_carlo(
        state,
        step_size=step_size,
        num_integrator_steps=num_leapfrog_steps,
        target_log_prob_fn=target_log_prob_fn,
        seed=_test_seed()))

    _, chain = fun_mcmc.trace(
        state=fun_mcmc.hamiltonian_monte_carlo_init(state, target_log_prob_fn),
        fn=kernel,
        num_steps=num_steps,
        trace_fn=lambda state, extra: state.state_extra[0])
    # Discard the warmup samples.
    chain = chain[1000:]

    sample_mean = tf.reduce_mean(chain, axis=[0, 1])
    sample_cov = tfp.stats.covariance(chain, sample_axis=[0, 1])

    true_samples = target_dist.sample(4096, seed=_test_seed())

    true_mean = tf.reduce_mean(true_samples, axis=0)
    # The reference covariance must come from the independent draws; computing
    # it from `chain` would compare the chain's covariance against itself and
    # make the assertion below vacuous.
    true_cov = tfp.stats.covariance(true_samples, sample_axis=0)

    self.assertAllClose(true_mean, sample_mean, rtol=0.1, atol=0.1)
    self.assertAllClose(true_cov, sample_cov, rtol=0.1, atol=0.1)
Exemplo n.º 7
0
    def computation(state):
      """Runs HMC with an Adam-adapted step size on a Softplus-transformed MVN.

      The log step size is held in an Adam state and updated every step with a
      surrogate loss built from `0.9 - mean_p_accept`, using a polynomially
      decaying learning rate.

      Args:
        state: Initial chain state, passed to `fun_mcmc.transform_log_prob_fn`.

      Returns:
        chain: Traced `hmc_state.state_extra[0]` values for every step.
        log_accept_ratio_trace: Traced `log_accept_ratio` for every step.
        true_samples: 4096 samples drawn directly from the target distribution.
      """
      softplus = tfb.Softplus()
      target_dist = softplus(
          tfd.MultivariateNormalFullCovariance(
              loc=base_mean, covariance_matrix=base_cov))

      def orig_target_log_prob_fn(x):
        return target_dist.log_prob(x), ()

      target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
          orig_target_log_prob_fn, softplus, state)

      def kernel(hmc_state, step_size_state, step):
        # The optimizer state stores the log step size, hence the `exp`.
        hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
            hmc_state,
            step_size=tf.exp(step_size_state.state),
            num_integrator_steps=num_leapfrog_steps,
            target_log_prob_fn=target_log_prob_fn)

        learning_rate = tf.compat.v1.train.polynomial_decay(
            0.01,
            global_step=step,
            power=0.5,
            decay_steps=num_adapt_steps,
            end_learning_rate=0.)
        accept_prob = tf.exp(tf.minimum(0., hmc_extra.log_accept_ratio))
        mean_p_accept = tf.reduce_mean(accept_prob)

        def acceptance_gap(_):
          return 0.9 - mean_p_accept, ()

        loss_fn = fun_mcmc.make_surrogate_loss_fn(acceptance_gap)
        step_size_state, _ = fun_mcmc.adam_step(
            step_size_state, loss_fn, learning_rate=learning_rate)

        next_state = (hmc_state, step_size_state, step + 1)
        traced = (hmc_state.state_extra[0], hmc_extra.log_accept_ratio)
        return next_state, traced

      initial_loop_state = (
          fun_mcmc.hamiltonian_monte_carlo_init(state, target_log_prob_fn),
          fun_mcmc.adam_init(tf.math.log(step_size)),
          0,
      )
      _, (chain, log_accept_ratio_trace) = fun_mcmc.trace(
          state=initial_loop_state,
          fn=kernel,
          num_steps=num_adapt_steps + num_steps,
      )
      true_samples = target_dist.sample(4096, seed=_test_seed())
      return chain, log_accept_ratio_trace, true_samples
Exemplo n.º 8
0
  def testRunningApproximateAutoCovariance(self, state_shape, event_ndims,
                                           aggregation):
    """Checks the running auto-covariance estimate against batch statistics.

    An HMC chain is sampled for 1000 steps while
    `running_approximate_auto_covariance_step` is updated with each new state.
    The lag-0 entry of the final estimate is compared with the chain variance,
    and the estimate normalized by its lag-0 entry is compared with
    `tfp.stats.auto_correlation` of the full chain.

    Args:
      state_shape: Shape of the (all-zeros) initial chain state.
      event_ndims: If `None`, the log-prob is returned elementwise; otherwise
        it is summed over the last axis.
      aggregation: Axis (or tree of axes) passed as `axis` to the running
        auto-covariance functions; `None` disables aggregation.
    """
    max_lags = 300
    num_leapfrog_steps = 10
    num_steps = 1000
    step_size = 0.2

    state = tf.constant(np.zeros(state_shape).astype(np.float32))

    def target_log_prob_fn(x):
      lp = -0.5 * tf.square(x)
      if event_ndims is not None:
        lp = tf.reduce_sum(lp, -1)
      return lp, ()

    def kernel(hmc_state, raac_state, seed):
      if self._is_on_jax:
        hmc_seed, seed = util.split_seed(seed, 2)
      else:
        hmc_seed = _test_seed()
      hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
          hmc_state,
          step_size=step_size,
          num_integrator_steps=num_leapfrog_steps,
          target_log_prob_fn=target_log_prob_fn,
          seed=hmc_seed)
      # Feed the freshly sampled state into the running estimator.
      raac_state, _ = fun_mcmc.running_approximate_auto_covariance_step(
          raac_state, hmc_state.state, axis=aggregation)
      return (hmc_state, raac_state, seed), hmc_extra

    seed = self._make_seed(_test_seed()) if self._is_on_jax else _test_seed()

    def trace_position(state, extra):  # pylint: disable=unused-argument
      # `state` is the `(hmc_state, raac_state, seed)` triple from `kernel`.
      return state[0].state

    # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs
    # for the jit to do anything.
    @tf.function
    def run_chain(state, seed):
      initial_loop_state = (
          fun_mcmc.hamiltonian_monte_carlo_init(state, target_log_prob_fn),
          fun_mcmc.running_approximate_auto_covariance_init(
              max_lags=max_lags,
              state_shape=state_shape,
              dtype=state.dtype,
              axis=aggregation),
          seed,
      )
      return fun_mcmc.trace(
          state=initial_loop_state,
          fn=kernel,
          num_steps=num_steps,
          trace_fn=trace_position)

    (_, raac_state, _), chain = run_chain(state, seed)

    chain_values = np.array(chain)

    # The chain has a leading step axis, so the aggregation axes are shifted by
    # one and the step axis itself is always reduced for the variance.
    if aggregation is None:
      true_aggregation = (0,)
    else:
      true_aggregation = (0,) + tuple(
          a + 1 for a in util.flatten_tree(aggregation))

    true_variance = np.array(
        tf.math.reduce_variance(chain_values, true_aggregation))
    true_autocov = np.array(
        tfp.stats.auto_correlation(chain_values, axis=0, max_lags=max_lags))
    if aggregation is not None:
      true_autocov = tf.reduce_mean(
          true_autocov, [a + 1 for a in util.flatten_tree(aggregation)])

    self.assertAllClose(true_variance, raac_state.auto_covariance[0], 1e-5)
    normalized_autocov = (
        raac_state.auto_covariance / raac_state.auto_covariance[0])
    self.assertAllClose(true_autocov, normalized_autocov, atol=0.1)
Exemplo n.º 9
0
def train_p(q, u, x_pos, step_size, opt_p):
    """Run one contrastive-divergence update of the energy function `u`.

    Negative samples are first drawn from `q`. When `FLAGS.use_mcmc` is set
    they are used to initialize an HMC chain whose target log-density is
    `-u(x)`, and the chain's final samples become the negatives for the loss;
    otherwise the samples from `q` are used directly.

    Args:
      q: `ModelQ`.
      u: A callable representing the energy function.
      x_pos: A batch of positive examples.
      step_size: Step size to use for HMC.
      opt_p: A `tf.optimizer.Optimizer`.

    Returns:
      x_neg_q: Negative samples sampled from `q`.
      x_neg_p: Negative samples used to train `p`, possibly generated via HMC.
      p_accept: Acceptance rate of HMC (`0.0` when MCMC is disabled).
      step_size: The new step size, possibly adapted to adjust the acceptance
        rate (`0.0` when MCMC is disabled).
      pos_e: Mean energy of the positive samples across the batch.
      pos_e_updated: Mean energy of the positive samples across the batch,
        after the parameter update.
      neg_e_q: Mean energy of `x_neg_q` across the batch.
      neg_e_p: Mean energy of `x_neg_p` across the batch.
      neg_e_p_updated: Mean energy of `x_neg_p` across the batch, after the
        parameter update.
    """

    def create_momentum_sample_fn(state):
        # The returned sampler accepts a seed argument but does not use it.
        def sample_fn(seed):
            return tf.random.normal(
                tf.shape(state), stddev=FLAGS.mcmc_momentum_stddev)

        return sample_fn

    _, x_neg_q, _ = q.sample_with_log_prob(
        FLAGS.batch_size, temp=FLAGS.q_temperature)
    neg_e_q = tf.reduce_mean(u(x_neg_q))

    def p_log_prob(x):
        return -u(x)

    if not FLAGS.use_mcmc:
        x_neg_p = x_neg_q
        p_accept = 0.0
        step_size = 0.0
    else:

        def log_prob_non_transformed(x):
            # The sample itself is returned as the extra so it can be read
            # back from `state_extra` after the chain finishes.
            return p_log_prob(x), (x, )

        # TODO(siege): Why aren't we actually using NeuTra?
        # def log_prob_transformed(z):
        #   x, logdet = q.reverse(z)
        #   p_log_p = p_log_prob(x)

        #   return p_log_p + logdet, (x,)

        def kernel(hmc_state, step_size, step):
            """Single HMC transition with optional step-size adaptation."""
            hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
                hmc_state,
                step_size=step_size,
                num_integrator_steps=FLAGS.mcmc_leapfrog_steps,
                momentum_sample_fn=create_momentum_sample_fn(hmc_state.state),
                target_log_prob_fn=log_prob_non_transformed)

            accept_prob = tf.exp(tf.minimum(0., hmc_extra.log_accept_ratio))
            mean_p_accept = tf.reduce_mean(accept_prob)

            if FLAGS.mcmc_adapt_step_size:
                step_size = fun_mcmc.sign_adaptation(
                    step_size, output=mean_p_accept, set_point=0.9)

            return (hmc_state, step_size, step + 1), hmc_extra

        initial_loop_state = (
            fun_mcmc.hamiltonian_monte_carlo_init(x_neg_q,
                                                  log_prob_non_transformed),
            step_size,
            0,
        )
        final_loop_state, is_accepted = fun_mcmc.trace(
            state=initial_loop_state,
            fn=kernel,
            num_steps=FLAGS.mcmc_num_steps,
            trace_fn=lambda _, hmc_extra: hmc_extra.is_accepted)

        x_neg_p = final_loop_state[0].state_extra[0]
        step_size = final_loop_state[1]

        p_accept = tf.reduce_mean(tf.cast(is_accepted, tf.float32))

    with tf.GradientTape() as tape:
        tape.watch(u.trainable_variables)
        pos_e = tf.reduce_mean(u(x_pos))
        neg_e_p = tf.reduce_mean(u(x_neg_p))
        # The extra term penalizes the squared mean positive energy.
        center_penalty = tf.square(pos_e) * FLAGS.p_center_regularizer
        loss = (pos_e - neg_e_p) + center_penalty

    variables = u.trainable_variables
    grads = tape.gradient(loss, variables)
    opt_p.apply_gradients(list(zip(grads, variables)))

    # Re-evaluate the energies with the updated parameters.
    pos_e_updated = tf.reduce_mean(u(x_pos))
    neg_e_p_updated = tf.reduce_mean(u(x_neg_p))

    return (x_neg_q, x_neg_p, p_accept, step_size, pos_e, pos_e_updated,
            neg_e_q, neg_e_p, neg_e_p_updated)