Example 1
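This test checks that the 4th-order McLachlan integrator step is time-reversible: after one forward step, applying the reversed step variant with negated momentum should recover the initial state.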
  def testMclachlanIntegratorStepReversible(self):

    def target_log_prob_fn(q):
      return -q**2, []

    def kinetic_energy_fn(p):
      return p**2., []

    seed = self._make_seed(_test_seed())

    state = 1.
    _, _, state_grads = fun_mcmc.call_potential_fn_with_grads(
        target_log_prob_fn,
        state,
    )

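    # Take one forward step with a random initial momentum.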
    state_fwd, _ = _fwd_mclachlan_optimal_4th_order_step(
        integrator_step_state=fun_mcmc.IntegratorStepState(
            state=state,
            state_grads=state_grads,
            momentum=util.random_normal([], tf.float32, seed)),
        step_size=0.1,
        target_log_prob_fn=target_log_prob_fn,
        kinetic_energy_fn=kinetic_energy_fn)

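    # Negating the momentum and stepping with the reversed integrator
    # should undo the forward step.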
    state_rev, _ = _rev_mclachlan_optimal_4th_order_step(
        integrator_step_state=state_fwd._replace(momentum=-state_fwd.momentum),
        step_size=0.1,
        target_log_prob_fn=target_log_prob_fn,
        kinetic_energy_fn=kinetic_energy_fn)

    self.assertAllClose(state, state_rev.state, atol=1e-6)
Example 2
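This test runs a single integrator step and verifies the expected number of target_log_prob_fn calls, the `extra` outputs from both potential functions, and that the Hamiltonian is approximately conserved.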
  def testIntegratorStep(self, method, num_tlp_calls, num_tlp_calls_jax=None):

    tlp_call_counter = [0]

    def target_log_prob_fn(q):
      tlp_call_counter[0] += 1
      return -q**2, 1.

    def kinetic_energy_fn(p):
      return tf.abs(p)**3., 2.

    state = 1.
    _, _, state_grads = fun_mcmc.call_potential_fn_with_grads(
        target_log_prob_fn,
        state,
    )

    state, extras = method(
        integrator_step_state=fun_mcmc.IntegratorStepState(
            state=state, state_grads=state_grads, momentum=2.),
        step_size=0.1,
        target_log_prob_fn=target_log_prob_fn,
        kinetic_energy_fn=kinetic_energy_fn)

    if num_tlp_calls_jax is not None and self._is_on_jax:
      num_tlp_calls = num_tlp_calls_jax
    self.assertEqual(num_tlp_calls, tlp_call_counter[0])
    self.assertEqual(1., extras.state_extra)
    self.assertEqual(2., extras.kinetic_energy_extra)

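    # A single small step should leave the Hamiltonian approximately
    # unchanged.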
    initial_hamiltonian = -target_log_prob_fn(1.)[0] + kinetic_energy_fn(2.)[0]
    fin_hamiltonian = -target_log_prob_fn(state.state)[0] + kinetic_energy_fn(
        state.momentum)[0]

    self.assertAllClose(fin_hamiltonian, initial_hamiltonian, atol=0.2)
Example 3
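The same reversibility check, parameterized over integrator step methods: stepping forward and then stepping again with negated momentum should return to the starting state.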
  def testIntegratorStepReversible(self, method):

    def target_log_prob_fn(q):
      return -q**2, []

    def kinetic_energy_fn(p):
      return p**2., []

    seed = self._make_seed(_test_seed())

    state = self._constant(1.)
    _, _, state_grads = fun_mcmc.call_potential_fn_with_grads(
        target_log_prob_fn,
        state,
    )

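    # Step forward once, then step again with negated momentum; a
    # reversible integrator recovers the initial state.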
    state_fwd, _ = method(
        integrator_step_state=fun_mcmc.IntegratorStepState(
            state=state,
            state_grads=state_grads,
            momentum=util.random_normal([], self._dtype, seed)),
        step_size=self._constant(0.1),
        target_log_prob_fn=target_log_prob_fn,
        kinetic_energy_fn=kinetic_energy_fn)

    state_rev, _ = method(
        integrator_step_state=state_fwd._replace(momentum=-state_fwd.momentum),
        step_size=self._constant(0.1),
        target_log_prob_fn=target_log_prob_fn,
        kinetic_energy_fn=kinetic_energy_fn)

    self.assertAllClose(state, state_rev.state, atol=1e-6)
Example 4
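Here `make_surrogate_loss_fn` is used as a decorator: the wrapped function returns (gradients, extra), while the reported loss is pinned to the given `loss_value`.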
  def testSurrogateLossFnDecorator(self):
    @fun_mcmc.make_surrogate_loss_fn(loss_value=1.)
    def loss_fn(_):
      return 3., 2.

    ret, extra, grads = fun_mcmc.call_potential_fn_with_grads(loss_fn, 0.)
    self.assertAllClose(1., ret)
    self.assertAllClose(2., extra)
    self.assertAllClose(3., grads)
Example 5
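This test exercises `hamiltonian_integrator` with a ragged (per-batch-element) number of leapfrog steps and checks it against the non-ragged equivalents.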
  def testRaggedIntegrator(self):

    def target_log_prob_fn(q):
      return -q**2, q

    def kinetic_energy_fn(p):
      return tf.abs(p)**3., p

    integrator_fn = lambda state, num_steps: fun_mcmc.hamiltonian_integrator(  # pylint: disable=g-long-lambda
        state,
        num_steps=num_steps,
        integrator_step_fn=lambda state: fun_mcmc.leapfrog_step(  # pylint: disable=g-long-lambda
            state,
            step_size=0.1,
            target_log_prob_fn=target_log_prob_fn,
            kinetic_energy_fn=kinetic_energy_fn),
        kinetic_energy_fn=kinetic_energy_fn,
        integrator_trace_fn=lambda state, extra: (state, extra))

    state = tf.zeros([2])
    momentum = tf.ones([2])
    target_log_prob, _, state_grads = fun_mcmc.call_potential_fn_with_grads(
        target_log_prob_fn, state)

    start_state = fun_mcmc.IntegratorState(
        target_log_prob=target_log_prob,
        momentum=momentum,
        state=state,
        state_grads=state_grads,
        state_extra=state,
    )

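    # Integrate for 1 step, 2 steps, and a ragged batch of [1, 2] steps.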
    state_1 = integrator_fn(start_state, 1)
    state_2 = integrator_fn(start_state, 2)
    state_1_2 = integrator_fn(start_state, [1, 2])

    # Make sure integrators actually integrated to different points.
    self.assertFalse(np.all(state_1[0].state == state_2[0].state))

    # Ragged integration should be consistent with the non-ragged equivalent.
    def get_batch(state, idx):
      # For the integrator trace, we'll grab the final value.
      return util.map_tree(
          lambda x: x[idx] if len(x.shape) == 1 else x[-1, idx], state)

    self.assertAllClose(get_batch(state_1, 0), get_batch(state_1_2, 0))
    self.assertAllClose(get_batch(state_2, 0), get_batch(state_1_2, 1))

    # Ragged traces should be equal up to the number of steps for the batch
    # element.
    def get_slice(state, num, idx):
      return util.map_tree(lambda x: x[:num, idx],
                           state[1].integrator_trace)

    self.assertAllClose(get_slice(state_1, 1, 0),
                        get_slice(state_1_2, 1, 0))
    self.assertAllClose(get_slice(state_2, 2, 0),
                        get_slice(state_1_2, 2, 1))
Example 6
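This test builds a surrogate loss from a gradient function: `call_potential_fn_with_grads` reports the default loss of 0, while the gradients and extras come straight from `grad_fn`.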
  def testSurrogateLossFn(self, state):
    def grad_fn(*args, **kwargs):
      # This is uglier than user code due to the parameterized test...
      new_state = util.unflatten_tree(state, util.flatten_tree((args, kwargs)))
      return util.map_tree(lambda x: x + 1., new_state), new_state
    loss_fn = fun_mcmc.make_surrogate_loss_fn(grad_fn)

    # Mutate the state to make sure we didn't capture anything.
    state = util.map_tree(lambda x: x + 1., state)
    ret, extra, grads = fun_mcmc.call_potential_fn_with_grads(loss_fn, state)
    # The default is 0.
    self.assertAllClose(0., ret)
    # The gradients of the surrogate loss are state + 1.
    self.assertAllClose(util.map_tree(lambda x: x + 1., state), grads)
    self.assertAllClose(state, extra)