Code example #1
File: fun_mcmc_test.py  Project: ymodak/probability
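This test exercises fun_mcmc.transform_log_prob_fn with a list-structured state: the initial state [2., 3.] is pulled back through the AffineScalar bijectors to [1., 1.], and the transformed log-prob at (1., 1.) must equal the original log-prob at the forward-mapped point (2., 3.) plus each bijector's forward log-det-Jacobian.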
  def testTransformLogProbFn(self):

    def log_prob_fn(x, y):
      return (tfp.distributions.Normal(self._constant(0.), 1.).log_prob(x) +
              tfp.distributions.Normal(self._constant(1.), 1.).log_prob(y)), ()

    bijectors = [
        tfp.bijectors.AffineScalar(scale=self._constant(2.)),
        tfp.bijectors.AffineScalar(scale=self._constant(3.))
    ]

    (transformed_log_prob_fn,
     transformed_init_state) = fun_mcmc.transform_log_prob_fn(
         log_prob_fn, bijectors,
         [self._constant(2.), self._constant(3.)])

    self.assertIsInstance(transformed_init_state, list)
    self.assertAllClose([1., 1.], transformed_init_state)
    tlp, (orig_space, _) = (
        transformed_log_prob_fn(self._constant(1.), self._constant(1.)))
    lp = log_prob_fn(self._constant(2.), self._constant(3.))[0] + sum(
        b.forward_log_det_jacobian(self._constant(1.), event_ndims=0)
        for b in bijectors)

    self.assertAllClose([2., 3.], orig_space)
    self.assertAllClose(lp, tlp)
Code example #2
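The same check as example #1, but with the bijectors and initial state passed as dicts keyed by argument name, verifying that transform_log_prob_fn threads keyword arguments through correctly. The snippet presumably assumes the usual tfd = tfp.distributions and tfb = tfp.bijectors aliases, as the full names appear in examples #1 and #9.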
  def testTransformLogProbFnKwargs(self):

    def log_prob_fn(x, y):
      return tfd.Normal(0., 1.).log_prob(x) + tfd.Normal(1., 1.).log_prob(y), ()

    bijectors = {
        'x': tfb.AffineScalar(scale=2.),
        'y': tfb.AffineScalar(scale=3.)
    }

    (transformed_log_prob_fn,
     transformed_init_state) = fun_mcmc.transform_log_prob_fn(
         log_prob_fn, bijectors, {
             'x': 2.,
             'y': 3.
         })

    self.assertIsInstance(transformed_init_state, dict)
    self.assertAllClose({'x': 1., 'y': 1.}, transformed_init_state)

    tlp, (orig_space, _) = transformed_log_prob_fn(x=1., y=1.)
    lp = log_prob_fn(
        x=2., y=3.)[0] + sum(
            b.forward_log_det_jacobian(1., event_ndims=0)
            for b in bijectors.values())

    self.assertAllClose({'x': 2., 'y': 3.}, orig_space)
    self.assertAllClose(lp, tlp)
Code example #3
File: fun_mcmc_test.py  Project: ymodak/probability
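An end-to-end HMC test: a Softplus-transformed multivariate normal is sampled by running hamiltonian_monte_carlo in unconstrained space (via transform_log_prob_fn), and the post-warmup sample mean and covariance are compared against direct draws from the target distribution. Seed plumbing differs between the TF and JAX backends, hence the _is_on_jax branches.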
  def testPreconditionedHMC(self):
    step_size = self._constant(0.2)
    num_steps = 2000
    num_leapfrog_steps = 10
    state = tf.ones([16, 2], dtype=self._dtype)

    base_mean = self._constant([1., 0])
    base_cov = self._constant([[1, 0.5], [0.5, 1]])

    bijector = tfp.bijectors.Softplus()
    base_dist = tfp.distributions.MultivariateNormalFullCovariance(
        loc=base_mean, covariance_matrix=base_cov)
    target_dist = bijector(base_dist)

    def orig_target_log_prob_fn(x):
      return target_dist.log_prob(x), ()

    target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
        orig_target_log_prob_fn, bijector, state)
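    # The initial state is mapped through the bijector's inverse into
    # unconstrained space; target_log_prob_fn now includes the forward
    # log-det-Jacobian correction.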

    # pylint: disable=g-long-lambda
    def kernel(hmc_state, seed):
      if not self._is_on_jax:
        hmc_seed = _test_seed()
      else:
        hmc_seed, seed = util.split_seed(seed, 2)
      hmc_state, _ = fun_mcmc.hamiltonian_monte_carlo(
          hmc_state,
          step_size=step_size,
          num_integrator_steps=num_leapfrog_steps,
          target_log_prob_fn=target_log_prob_fn,
          seed=hmc_seed)
      return (hmc_state, seed), hmc_state.state_extra[0]

    if not self._is_on_jax:
      seed = _test_seed()
    else:
      seed = self._make_seed(_test_seed())

    # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs
    # for the jit to do anything.
    _, chain = tf.function(lambda state, seed: fun_mcmc.trace(  # pylint: disable=g-long-lambda
        state=(fun_mcmc.hamiltonian_monte_carlo_init(state, target_log_prob_fn),
               seed),
        fn=kernel,
        num_steps=num_steps))(state, seed)
    # Discard the warmup samples.
    chain = chain[1000:]

    sample_mean = tf.reduce_mean(chain, axis=[0, 1])
    sample_cov = tfp.stats.covariance(chain, sample_axis=[0, 1])

    true_samples = target_dist.sample(4096, seed=self._make_seed(_test_seed()))

    true_mean = tf.reduce_mean(true_samples, axis=0)
    true_cov = tfp.stats.covariance(true_samples, sample_axis=0)

    self.assertAllClose(true_mean, sample_mean, rtol=0.1, atol=0.1)
    self.assertAllClose(true_cov, sample_cov, rtol=0.1, atol=0.1)
Code example #4
File: fun_mcmc_test.py  Project: ymodak/probability
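A computation helper from a step-size adaptation test: inside the HMC kernel, a surrogate loss built with make_surrogate_loss_fn drives adam_step to tune the log step size towards a 0.9 mean acceptance probability, with the adaptation rate decayed polynomially over num_adapt_steps.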
    def computation(state, seed):
      bijector = tfp.bijectors.Softplus()
      base_dist = tfp.distributions.MultivariateNormalFullCovariance(
          loc=base_mean, covariance_matrix=base_cov)
      target_dist = bijector(base_dist)

      def orig_target_log_prob_fn(x):
        return target_dist.log_prob(x), ()

      target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
          orig_target_log_prob_fn, bijector, state)

      def kernel(hmc_state, step_size_state, step, seed):
        if not self._is_on_jax:
          hmc_seed = _test_seed()
        else:
          hmc_seed, seed = util.split_seed(seed, 2)
        hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
            hmc_state,
            step_size=tf.exp(step_size_state.state),
            num_integrator_steps=num_leapfrog_steps,
            target_log_prob_fn=target_log_prob_fn,
            seed=hmc_seed)

        rate = fun_mcmc.prefab._polynomial_decay(  # pylint: disable=protected-access
            step=step,
            step_size=self._constant(0.01),
            power=0.5,
            decay_steps=num_adapt_steps,
            final_step_size=0.)
        mean_p_accept = tf.reduce_mean(
            tf.exp(tf.minimum(self._constant(0.), hmc_extra.log_accept_ratio)))

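        # Surrogate loss whose gradient steers the mean acceptance
        # probability towards the 0.9 set point; adam_step updates
        # the log step size.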
        loss_fn = fun_mcmc.make_surrogate_loss_fn(
            lambda _: (0.9 - mean_p_accept, ()))
        step_size_state, _ = fun_mcmc.adam_step(
            step_size_state, loss_fn, learning_rate=rate)

        return ((hmc_state, step_size_state, step + 1, seed),
                (hmc_state.state_extra[0], hmc_extra.log_accept_ratio))

      _, (chain, log_accept_ratio_trace) = fun_mcmc.trace(
          state=(fun_mcmc.hamiltonian_monte_carlo_init(state,
                                                       target_log_prob_fn),
                 fun_mcmc.adam_init(tf.math.log(step_size)), 0, seed),
          fn=kernel,
          num_steps=num_adapt_steps + num_steps,
      )
      true_samples = target_dist.sample(
          4096, seed=self._make_seed(_test_seed()))
      return chain, log_accept_ratio_trace, true_samples
Code example #5
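A variant of the testPreconditionedHMC test in example #3 without the JAX seed plumbing: here the seed comes from _test_seed() inside the tf.function-wrapped kernel, and the trace is extracted with an explicit trace_fn.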
  def testPreconditionedHMC(self):
    step_size = 0.2
    num_steps = 2000
    num_leapfrog_steps = 10
    state = tf.ones([16, 2])

    base_mean = [1., 0]
    base_cov = [[1, 0.5], [0.5, 1]]

    bijector = tfb.Softplus()
    base_dist = tfd.MultivariateNormalFullCovariance(
        loc=base_mean, covariance_matrix=base_cov)
    target_dist = bijector(base_dist)

    def orig_target_log_prob_fn(x):
      return target_dist.log_prob(x), ()

    target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
        orig_target_log_prob_fn, bijector, state)

    # pylint: disable=g-long-lambda
    kernel = tf.function(lambda state: fun_mcmc.hamiltonian_monte_carlo(
        state,
        step_size=step_size,
        num_integrator_steps=num_leapfrog_steps,
        target_log_prob_fn=target_log_prob_fn,
        seed=_test_seed()))

    _, chain = fun_mcmc.trace(
        state=fun_mcmc.hamiltonian_monte_carlo_init(state, target_log_prob_fn),
        fn=kernel,
        num_steps=num_steps,
        trace_fn=lambda state, extra: state.state_extra[0])
    # Discard the warmup samples.
    chain = chain[1000:]

    sample_mean = tf.reduce_mean(chain, axis=[0, 1])
    sample_cov = tfp.stats.covariance(chain, sample_axis=[0, 1])

    true_samples = target_dist.sample(4096, seed=_test_seed())

    true_mean = tf.reduce_mean(true_samples, axis=0)
    true_cov = tfp.stats.covariance(true_samples, sample_axis=0)

    self.assertAllClose(true_mean, sample_mean, rtol=0.1, atol=0.1)
    self.assertAllClose(true_cov, sample_cov, rtol=0.1, atol=0.1)
Code example #6
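A variant of the adaptation loop in example #4 that uses tf.compat.v1.train.polynomial_decay for the adaptation-rate schedule and no explicit seed threading.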
    def computation(state):
      bijector = tfb.Softplus()
      base_dist = tfd.MultivariateNormalFullCovariance(
          loc=base_mean, covariance_matrix=base_cov)
      target_dist = bijector(base_dist)

      def orig_target_log_prob_fn(x):
        return target_dist.log_prob(x), ()

      target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
          orig_target_log_prob_fn, bijector, state)

      def kernel(hmc_state, step_size_state, step):
        hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
            hmc_state,
            step_size=tf.exp(step_size_state.state),
            num_integrator_steps=num_leapfrog_steps,
            target_log_prob_fn=target_log_prob_fn)

        rate = tf.compat.v1.train.polynomial_decay(
            0.01,
            global_step=step,
            power=0.5,
            decay_steps=num_adapt_steps,
            end_learning_rate=0.)
        mean_p_accept = tf.reduce_mean(
            tf.exp(tf.minimum(0., hmc_extra.log_accept_ratio)))

        loss_fn = fun_mcmc.make_surrogate_loss_fn(
            lambda _: (0.9 - mean_p_accept, ()))
        step_size_state, _ = fun_mcmc.adam_step(
            step_size_state, loss_fn, learning_rate=rate)

        return ((hmc_state, step_size_state, step + 1),
                (hmc_state.state_extra[0], hmc_extra.log_accept_ratio))

      _, (chain, log_accept_ratio_trace) = fun_mcmc.trace(
          state=(fun_mcmc.hamiltonian_monte_carlo_init(state,
                                                       target_log_prob_fn),
                 fun_mcmc.adam_init(tf.math.log(step_size)), 0),
          fn=kernel,
          num_steps=num_adapt_steps + num_steps,
      )
      true_samples = target_dist.sample(4096, seed=_test_seed())
      return chain, log_accept_ratio_trace, true_samples
Code example #7
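A variant of example #1 using plain Python floats and the tfd/tfb aliases instead of the self._constant dtype helper.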
  def testTransformLogProbFn(self):

    def log_prob_fn(x, y):
      return tfd.Normal(0., 1.).log_prob(x) + tfd.Normal(1., 1.).log_prob(y), ()

    bijectors = [tfb.AffineScalar(scale=2.), tfb.AffineScalar(scale=3.)]

    (transformed_log_prob_fn,
     transformed_init_state) = fun_mcmc.transform_log_prob_fn(
         log_prob_fn, bijectors, [2., 3.])

    self.assertIsInstance(transformed_init_state, list)
    self.assertAllClose([1., 1.], transformed_init_state)
    tlp, (orig_space, _) = transformed_log_prob_fn(1., 1.)
    lp = log_prob_fn(2., 3.)[0] + sum(
        b.forward_log_det_jacobian(1., event_ndims=0) for b in bijectors)

    self.assertAllClose([2., 3.], orig_space)
    self.assertAllClose(lp, tlp)
Code example #8
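An alternative step-size adaptation scheme: instead of a surrogate loss and Adam, fun_mcmc.sign_adaptation moves the step size towards the 0.9 acceptance set point directly.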
        def computation(state):
            bijector = tfb.Softplus()
            base_dist = tfd.MultivariateNormalFullCovariance(
                loc=base_mean, covariance_matrix=base_cov)
            target_dist = bijector(base_dist)

            def orig_target_log_prob_fn(x):
                return target_dist.log_prob(x), ()

            target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
                orig_target_log_prob_fn, bijector, state)

            def kernel(hmc_state, step_size, step):
                hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
                    hmc_state,
                    step_size=step_size,
                    num_integrator_steps=num_leapfrog_steps,
                    target_log_prob_fn=target_log_prob_fn)

                rate = tf.compat.v1.train.polynomial_decay(
                    0.01,
                    global_step=step,
                    power=0.5,
                    decay_steps=num_adapt_steps,
                    end_learning_rate=0.)
                mean_p_accept = tf.reduce_mean(
                    tf.exp(tf.minimum(0., hmc_extra.log_accept_ratio)))
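                # Nudge step_size up or down depending on whether the
                # mean acceptance probability is above or below the 0.9
                # set point.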
                step_size = fun_mcmc.sign_adaptation(step_size,
                                                     output=mean_p_accept,
                                                     set_point=0.9,
                                                     adaptation_rate=rate)

                return (hmc_state, step_size, step + 1), hmc_extra

            _, (chain, log_accept_ratio_trace) = fun_mcmc.trace(
                (fun_mcmc.HamiltonianMonteCarloState(state), step_size, 0),
                kernel,
                num_adapt_steps + num_steps,
                trace_fn=lambda state, extra:
                (state[0].state_extra[0], extra.log_accept_ratio))
            true_samples = target_dist.sample(4096,
                                              seed=tfp_test_util.test_seed())
            return chain, log_accept_ratio_trace, true_samples
Code example #9
File: fun_mcmc_test.py  Project: ymodak/probability
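The dict/kwargs counterpart of example #1 (compare example #2), with all constants built through self._constant so the same test body can run under different dtypes and backends.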
  def testTransformLogProbFnKwargs(self):

    def log_prob_fn(x, y):
      return (tfp.distributions.Normal(self._constant(0.), 1.).log_prob(x) +
              tfp.distributions.Normal(self._constant(1.), 1.).log_prob(y)), ()

    bijectors = {
        'x': tfp.bijectors.AffineScalar(scale=self._constant(2.)),
        'y': tfp.bijectors.AffineScalar(scale=self._constant(3.))
    }

    (transformed_log_prob_fn,
     transformed_init_state) = fun_mcmc.transform_log_prob_fn(
         log_prob_fn, bijectors, {
             'x': self._constant(2.),
             'y': self._constant(3.),
         })

    self.assertIsInstance(transformed_init_state, dict)
    self.assertAllClose({
        'x': self._constant(1.),
        'y': self._constant(1.),
    }, transformed_init_state)

    tlp, (orig_space, _) = transformed_log_prob_fn(
        x=self._constant(1.), y=self._constant(1.))
    lp = log_prob_fn(
        x=self._constant(2.), y=self._constant(3.))[0] + sum(
            b.forward_log_det_jacobian(self._constant(1.), event_ndims=0)
            for b in bijectors.values())

    self.assertAllClose({
        'x': self._constant(2.),
        'y': self._constant(3.)
    }, orig_space)
    self.assertAllClose(lp, tlp)