예제 #1
0
  def testMclachlanIntegratorStepReversible(self):
    """A forward McLachlan step undone by the reverse step returns to q=1.

    Takes one `_fwd_mclachlan_optimal_4th_order_step` from position 1. with a
    seeded random scalar momentum, negates the resulting momentum, applies
    `_rev_mclachlan_optimal_4th_order_step`, and requires the final position
    to match the starting position within 1e-6.
    """

    def target_log_prob_fn(q):
      return -q**2, []

    def kinetic_energy_fn(p):
      return p**2., []

    # Arguments shared by the forward and the reverse step.
    step_kwargs = dict(
        step_size=0.1,
        target_log_prob_fn=target_log_prob_fn,
        kinetic_energy_fn=kinetic_energy_fn,
    )

    seed = self._make_seed(_test_seed())

    q0 = 1.
    _, _, q0_grads = fun_mcmc.call_potential_fn_with_grads(
        target_log_prob_fn, q0)
    start = fun_mcmc.IntegratorStepState(
        state=q0,
        state_grads=q0_grads,
        momentum=util.random_normal([], tf.float32, seed),
    )

    forward, _ = _fwd_mclachlan_optimal_4th_order_step(
        integrator_step_state=start, **step_kwargs)
    # Negate the momentum so the reverse step should retrace the forward one.
    flipped = forward._replace(momentum=-forward.momentum)
    backward, _ = _rev_mclachlan_optimal_4th_order_step(
        integrator_step_state=flipped, **step_kwargs)

    self.assertAllClose(q0, backward.state, atol=1e-6)
예제 #2
0
  def testIntegratorStep(self, method, num_tlp_calls, num_tlp_calls_jax=None):
    """One step of `method`: call count, extras passthrough, energy drift.

    Starting from q=1., p=2., checks three things. First, `target_log_prob_fn`
    was called the expected number of times; the counter also sees any calls
    made by `call_potential_fn_with_grads` while building the initial state.
    Second, the step's extras carry the values 1. and 2. returned as extras by
    the two potential functions. Third, the Hamiltonian
    -log_prob(q) + K(p) changes by at most 0.2.

    Args:
      method: The integrator step function under test.
      num_tlp_calls: Expected number of `target_log_prob_fn` calls.
      num_tlp_calls_jax: If not None, the expected count used instead when
        running on JAX.
    """
    calls = [0]

    def target_log_prob_fn(q):
      calls[0] += 1
      return -q**2, 1.

    def kinetic_energy_fn(p):
      return tf.abs(p)**3., 2.

    def hamiltonian(q, p):
      return -target_log_prob_fn(q)[0] + kinetic_energy_fn(p)[0]

    q0, p0 = 1., 2.
    _, _, q0_grads = fun_mcmc.call_potential_fn_with_grads(
        target_log_prob_fn, q0)

    end, extras = method(
        integrator_step_state=fun_mcmc.IntegratorStepState(
            state=q0, state_grads=q0_grads, momentum=p0),
        step_size=0.1,
        target_log_prob_fn=target_log_prob_fn,
        kinetic_energy_fn=kinetic_energy_fn)

    use_jax_count = num_tlp_calls_jax is not None and self._is_on_jax
    expected_calls = num_tlp_calls_jax if use_jax_count else num_tlp_calls
    self.assertEqual(expected_calls, calls[0])
    self.assertEqual(1., extras.state_extra)
    self.assertEqual(2., extras.kinetic_energy_extra)

    # These evaluations happen after the call-count assertion, so the extra
    # `target_log_prob_fn` calls they make are not checked.
    initial_energy = hamiltonian(q0, p0)
    final_energy = hamiltonian(end.state, end.momentum)
    self.assertAllClose(final_energy, initial_energy, atol=0.2)
예제 #3
0
    def testIntegratorStepReversible(self, method):
        """Stepping forward, then again with negated momentum, returns to 1.

        Args:
            method: The integrator step function under test.
        """
        def target_log_prob_fn(q):
            return -q**2, []

        def kinetic_energy_fn(p):
            return p**2., []

        def take_step(step_state):
            # Same step size and potentials for both directions.
            return method(integrator_step_state=step_state,
                          step_size=0.1,
                          target_log_prob_fn=target_log_prob_fn,
                          kinetic_energy_fn=kinetic_energy_fn)

        seed = self._make_seed(tfp_test_util.test_seed())
        start = fun_mcmc.IntegratorStepState(
            state=1.,
            state_grads=None,
            momentum=util.random_normal([], tf.float32, seed))

        forward, _ = take_step(start)
        backward, _ = take_step(
            forward._replace(momentum=-forward.momentum))

        self.assertAllClose(1., backward.state, atol=1e-6)
예제 #4
0
    def testIntegratorStep(self, method, num_tlp_calls):
        """One step of `method`: call count, extras passthrough, energy drift.

        Starting from q=1., p=2. with no precomputed gradients, checks that the
        step called `target_log_prob_fn` `num_tlp_calls` times. It also checks
        that the step's extras carry the values 1. and 2. returned by the two
        potential functions, and that the Hamiltonian -log_prob(q) + K(p)
        changes by at most 0.2.

        Args:
            method: The integrator step function under test.
            num_tlp_calls: Expected number of `target_log_prob_fn` calls made
                by the step.
        """
        n_calls = [0]

        def target_log_prob_fn(q):
            n_calls[0] += 1
            return -q**2, 1.

        def kinetic_energy_fn(p):
            return tf.abs(p)**3., 2.

        def energy(q, p):
            # Hamiltonian: negative log density plus kinetic energy.
            return -target_log_prob_fn(q)[0] + kinetic_energy_fn(p)[0]

        q0, p0 = 1., 2.
        start = fun_mcmc.IntegratorStepState(
            state=q0, state_grads=None, momentum=p0)
        end, extras = method(integrator_step_state=start,
                             step_size=0.1,
                             target_log_prob_fn=target_log_prob_fn,
                             kinetic_energy_fn=kinetic_energy_fn)

        self.assertEqual(num_tlp_calls, n_calls[0])
        self.assertEqual(1., extras.state_extra)
        self.assertEqual(2., extras.kinetic_energy_extra)

        # Evaluated after the call-count check, so these calls are not counted.
        energy_before = energy(q0, p0)
        energy_after = energy(end.state, end.momentum)
        self.assertAllClose(energy_after, energy_before, atol=0.2)
예제 #5
0
  def testIntegratorStepReversible(self, method):
    """Forward step, momentum flip, second step: must land back on the start.

    The position and step size are built with `self._constant`. The random
    momentum is drawn with `self._dtype`.

    Args:
      method: The integrator step function under test.
    """

    def target_log_prob_fn(q):
      return -q**2, []

    def kinetic_energy_fn(p):
      return p**2., []

    def step(step_state):
      # A fresh step-size constant is created for each direction.
      return method(
          integrator_step_state=step_state,
          step_size=self._constant(0.1),
          target_log_prob_fn=target_log_prob_fn,
          kinetic_energy_fn=kinetic_energy_fn)

    seed = self._make_seed(_test_seed())

    q0 = self._constant(1.)
    _, _, q0_grads = fun_mcmc.call_potential_fn_with_grads(
        target_log_prob_fn, q0)
    p0 = util.random_normal([], self._dtype, seed)

    forward, _ = step(
        fun_mcmc.IntegratorStepState(
            state=q0, state_grads=q0_grads, momentum=p0))
    backward, _ = step(forward._replace(momentum=-forward.momentum))

    self.assertAllClose(q0, backward.state, atol=1e-6)
예제 #6
0
    def testMclachlanIntegratorStepReversible(self):
        """A forward McLachlan step undone by the reverse step returns to q=1.

        Takes one `_fwd_mclachlan_optimal_4th_order_step` from position 1.
        with a random scalar momentum, negates the resulting momentum, applies
        `_rev_mclachlan_optimal_4th_order_step`, and requires the final
        position to match the starting position within 1e-6.
        """
        def target_log_prob_fn(q):
            return -q**2, []

        def kinetic_energy_fn(p):
            return p**2., []

        # Fixed op-level seed: an unseeded draw changes on every run, which
        # makes the test nondeterministic and any failure impossible to replay.
        momentum = tf.random.normal([], seed=1234)

        state_fwd, _ = _fwd_mclachlan_optimal_4th_order_step(
            integrator_step_state=fun_mcmc.IntegratorStepState(
                state=1., state_grads=None, momentum=momentum),
            step_size=0.1,
            target_log_prob_fn=target_log_prob_fn,
            kinetic_energy_fn=kinetic_energy_fn)

        # Negate the momentum so the reverse step should retrace the forward one.
        state_rev, _ = _rev_mclachlan_optimal_4th_order_step(
            integrator_step_state=state_fwd._replace(
                momentum=-state_fwd.momentum),
            step_size=0.1,
            target_log_prob_fn=target_log_prob_fn,
            kinetic_energy_fn=kinetic_energy_fn)

        self.assertAllClose(1., state_rev.state, atol=1e-6)