Exemplo n.º 1
0
    def computation(state, seed):
      """Adapts the HMC step size with Adam on a Softplus-transformed Gaussian.

      Args:
        state: Initial chain state. It is passed to
          `fun_mcmc.transform_log_prob_fn` together with the Softplus bijector
          (presumably expressed in the constrained, positive space — TODO
          confirm against the caller).
        seed: Seed threaded through the kernel's loop state. It is only split
          and consumed when `self._is_on_jax` is true.

      Returns:
        A tuple `(chain, log_accept_ratio_trace, true_samples)`:
        - `chain` and `log_accept_ratio_trace` are the traced values of the
          kernel's second output, `hmc_state.state_extra[0]` and
          `hmc_extra.log_accept_ratio`. They cover all
          `num_adapt_steps + num_steps` steps.
        - `true_samples` holds 4096 draws from the target distribution.
      """
      # Target: a full-covariance Gaussian pushed through Softplus, so its
      # support is the positive orthant.
      bijector = tfp.bijectors.Softplus()
      base_dist = tfp.distributions.MultivariateNormalFullCovariance(
          loc=base_mean, covariance_matrix=base_cov)
      target_dist = bijector(base_dist)

      def orig_target_log_prob_fn(x):
        # Returns a (log_prob, extra) pair; this target has no extra outputs.
        return target_dist.log_prob(x), ()

      # Reparameterize through the bijector. Both the log-prob function and the
      # initial state are replaced by their transformed counterparts
      # (presumably in the bijector's unconstrained input space — see the
      # fun_mcmc docs).
      target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
          orig_target_log_prob_fn, bijector, state)

      def kernel(hmc_state, step_size_state, step, seed):
        """One HMC step followed by one Adam update of the log step size."""
        # Only JAX needs explicit per-step seed splitting. On TF the same test
        # seed is passed on every step and `seed` is carried through unchanged.
        if not self._is_on_jax:
          hmc_seed = _test_seed()
        else:
          hmc_seed, seed = util.split_seed(seed, 2)
        # `step_size_state` holds log(step size) (see the `adam_init` call
        # below), hence the `tf.exp`.
        hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
            hmc_state,
            step_size=tf.exp(step_size_state.state),
            num_integrator_steps=num_leapfrog_steps,
            target_log_prob_fn=target_log_prob_fn,
            seed=hmc_seed)

        # Adam learning-rate schedule. Judging by its arguments, it decays from
        # 0.01 towards 0. over `num_adapt_steps` with power 0.5.
        rate = fun_mcmc.prefab._polynomial_decay(  # pylint: disable=protected-access
            step=step,
            step_size=self._constant(0.01),
            power=0.5,
            decay_steps=num_adapt_steps,
            final_step_size=0.)
        # Mean acceptance probability, min(1, exp(log_accept_ratio)).
        mean_p_accept = tf.reduce_mean(
            tf.exp(tf.minimum(self._constant(0.), hmc_extra.log_accept_ratio)))

        # Steer the mean acceptance probability towards 0.9. The lambda ignores
        # its argument. `make_surrogate_loss_fn` presumably treats the returned
        # value as a surrogate gradient for the step-size state — confirm in the
        # fun_mcmc docs.
        loss_fn = fun_mcmc.make_surrogate_loss_fn(
            lambda _: (0.9 - mean_p_accept, ()))
        step_size_state, _ = fun_mcmc.adam_step(
            step_size_state, loss_fn, learning_rate=rate)

        # (next loop state with incremented step counter, per-step trace values)
        return ((hmc_state, step_size_state, step + 1, seed),
                (hmc_state.state_extra[0], hmc_extra.log_accept_ratio))

      # Adaptation and sampling run back to back in one loop. The Adam update
      # runs on every step; the decay schedule presumably brings its rate to 0
      # once `num_adapt_steps` have elapsed.
      _, (chain, log_accept_ratio_trace) = fun_mcmc.trace(
          state=(fun_mcmc.hamiltonian_monte_carlo_init(state,
                                                       target_log_prob_fn),
                 fun_mcmc.adam_init(tf.math.log(step_size)), 0, seed),
          fn=kernel,
          num_steps=num_adapt_steps + num_steps,
      )
      # Exact draws from the target (presumably compared against `chain` by
      # the caller).
      true_samples = target_dist.sample(
          4096, seed=self._make_seed(_test_seed()))
      return chain, log_accept_ratio_trace, true_samples
Exemplo n.º 2
0
  def testAdam(self):
    """Adam should drive a separable quadratic to its minimum at (1, 2)."""

    def quadratic_loss(x, y):
      # Zero exactly at x == 1, y == 2; the extra output is an empty list.
      return tf.square(x - 1.) + tf.square(y - 2.), []

    def adam_kernel(adam_state):
      return fun_mcmc.adam_step(adam_state, quadratic_loss, learning_rate=0.01)

    def record(state, extra):
      # Trace the optimized parameters and the loss after every step.
      return [state.state, extra.loss]

    initial_state = fun_mcmc.adam_init([tf.zeros([]), tf.zeros([])])
    _, [params_trace, loss_trace] = fun_mcmc.trace(
        state=initial_state,
        fn=adam_kernel,
        num_steps=1000,
        trace_fn=record)
    x_trace, y_trace = params_trace

    self.assertAllClose(1., x_trace[-1], atol=1e-3)
    self.assertAllClose(2., y_trace[-1], atol=1e-3)
    self.assertAllClose(0., loss_trace[-1], atol=1e-3)
Exemplo n.º 3
0
    def computation(state):
      """Adapts the HMC step size with Adam on a Softplus-transformed Gaussian.

      Args:
        state: Initial chain state. It is passed to
          `fun_mcmc.transform_log_prob_fn` together with the Softplus bijector.

      Returns:
        A tuple `(chain, log_accept_ratio_trace, true_samples)`:
        - `chain` and `log_accept_ratio_trace` are the traced values of the
          kernel's second output, `hmc_state.state_extra[0]` and
          `hmc_extra.log_accept_ratio`. They cover all
          `num_adapt_steps + num_steps` steps.
        - `true_samples` holds 4096 draws from the target distribution.
      """
      # Target: a full-covariance Gaussian pushed through Softplus, so its
      # support is the positive orthant.
      bijector = tfb.Softplus()
      base_dist = tfd.MultivariateNormalFullCovariance(
          loc=base_mean, covariance_matrix=base_cov)
      target_dist = bijector(base_dist)

      def orig_target_log_prob_fn(x):
        # Returns a (log_prob, extra) pair; this target has no extra outputs.
        return target_dist.log_prob(x), ()

      # Reparameterize through the bijector. Both the log-prob function and the
      # initial state are replaced by their transformed counterparts.
      target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
          orig_target_log_prob_fn, bijector, state)

      def kernel(hmc_state, step_size_state, step):
        """One HMC step followed by one Adam update of the log step size."""
        # `step_size_state` holds log(step size) (see the `adam_init` call
        # below), hence the `tf.exp`.
        hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
            hmc_state,
            step_size=tf.exp(step_size_state.state),
            num_integrator_steps=num_leapfrog_steps,
            target_log_prob_fn=target_log_prob_fn)

        # Adam learning rate: 0.01 * (1 - min(step, T) / T) ** 0.5, where
        # T = num_adapt_steps. This is polynomial decay to 0 over T steps.
        # It is the same float32 formula that
        # `tf.compat.v1.train.polynomial_decay(0.01, step, T, 0., power=0.5)`
        # evaluates. That API returns a *callable* instead of a Tensor under
        # eager execution, which would then be handed as-is to `adam_step`, so
        # the formula is written out here.
        decay_steps = tf.cast(num_adapt_steps, tf.float32)
        decay_frac = tf.minimum(tf.cast(step, tf.float32),
                                decay_steps) / decay_steps
        rate = 0.01 * tf.pow(1. - decay_frac, 0.5)
        # Mean acceptance probability, min(1, exp(log_accept_ratio)).
        mean_p_accept = tf.reduce_mean(
            tf.exp(tf.minimum(0., hmc_extra.log_accept_ratio)))

        # Steer the mean acceptance probability towards 0.9 via a surrogate
        # loss on the step-size state.
        loss_fn = fun_mcmc.make_surrogate_loss_fn(
            lambda _: (0.9 - mean_p_accept, ()))
        step_size_state, _ = fun_mcmc.adam_step(
            step_size_state, loss_fn, learning_rate=rate)

        # (next loop state with incremented step counter, per-step trace values)
        return ((hmc_state, step_size_state, step + 1),
                (hmc_state.state_extra[0], hmc_extra.log_accept_ratio))

      # Adaptation and sampling run back to back in one loop. After
      # `num_adapt_steps` the learning rate above is exactly 0.
      _, (chain, log_accept_ratio_trace) = fun_mcmc.trace(
          state=(fun_mcmc.hamiltonian_monte_carlo_init(state,
                                                       target_log_prob_fn),
                 fun_mcmc.adam_init(tf.math.log(step_size)), 0),
          fn=kernel,
          num_steps=num_adapt_steps + num_steps,
      )
      # Exact draws from the target (presumably compared against `chain` by
      # the caller).
      true_samples = target_dist.sample(4096, seed=_test_seed())
      return chain, log_accept_ratio_trace, true_samples