Example #1
    def testIntegratorStep(self,
                           method,
                           num_tlp_calls,
                           num_tlp_calls_jax=None):

        # One-element list so the closure below can mutate the call count.
        tlp_call_counter = [0]

        def target_log_prob_fn(q):
            tlp_call_counter[0] += 1
            return -q**2, 1.

        def kinetic_energy_fn(p):
            return tf.abs(p)**3., 2.

        state, extras = method(
            integrator_step_state=fun_mcmc.IntegratorStepState(
                state=1., state_grads=None, momentum=2.),
            step_size=0.1,
            target_log_prob_fn=target_log_prob_fn,
            kinetic_energy_fn=kinetic_energy_fn)

        if num_tlp_calls_jax is not None and backend.get_backend(
        ) == backend.JAX:
            num_tlp_calls = num_tlp_calls_jax
        self.assertEqual(num_tlp_calls, tlp_call_counter[0])
        self.assertEqual(1., extras.state_extra)
        self.assertEqual(2., extras.kinetic_energy_extra)

        initial_hamiltonian = -target_log_prob_fn(1.)[0] + kinetic_energy_fn(
            2.)[0]
        fin_hamiltonian = -target_log_prob_fn(
            state.state)[0] + kinetic_energy_fn(state.momentum)[0]

        self.assertAllClose(fin_hamiltonian, initial_hamiltonian, atol=0.2)
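The assertion above relies on the fact that a well-behaved integrator step approximately conserves the Hamiltonian H(q, p) = -target_log_prob(q) + kinetic_energy(p), which for the toy functions in this test reduces to q**2 + |p|**3. A minimal NumPy sketch of the quantity being compared (the `hamiltonian` helper is purely illustrative and not part of fun_mcmc):

import numpy as np

def hamiltonian(q, p):
    # H(q, p) = -target_log_prob(q) + kinetic_energy(p)
    #         = -(-q**2) + |p|**3 for the toy functions defined in the test.
    return q**2 + np.abs(p)**3

# The test takes one integrator step from (q, p) = (1., 2.) and asserts that the
# resulting (state.state, state.momentum) has a Hamiltonian within atol=0.2 of
# hamiltonian(1., 2.).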
Example #2
def kernel(rwm_state, seed):
    if backend.get_backend() == backend.TENSORFLOW:
        rwm_seed = tfp_test_util.test_seed()
    else:
        rwm_seed, seed = util.split_seed(seed, 2)
    rwm_state, rwm_extra = fun_mcmc.random_walk_metropolis(
        rwm_state,
        target_log_prob_fn=target_log_prob_fn,
        proposal_fn=proposal_fn,
        seed=rwm_seed)
    return (rwm_state, seed), rwm_extra
Example #3
def kernel(hmc_state, seed):
    if backend.get_backend() == backend.TENSORFLOW:
        hmc_seed = tfp_test_util.test_seed()
    else:
        hmc_seed, seed = util.split_seed(seed, 2)
    hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
        hmc_state,
        step_size=step_size,
        num_integrator_steps=num_leapfrog_steps,
        target_log_prob_fn=target_log_prob_fn,
        seed=hmc_seed)
    return (hmc_state, seed), hmc_extra
Example #4
    def testBasicHMC(self):
        step_size = 0.2
        num_steps = 2000
        num_leapfrog_steps = 10
        state = tf.ones([16, 2])

        base_mean = tf.constant([2., 3.])
        base_scale = tf.constant([2., 0.5])

        def target_log_prob_fn(x):
            return -tf.reduce_sum(
                0.5 * tf.square((x - base_mean) / base_scale), -1), ()

        def kernel(hmc_state, seed):
            if backend.get_backend() == backend.TENSORFLOW:
                hmc_seed = tfp_test_util.test_seed()
            else:
                hmc_seed, seed = util.split_seed(seed, 2)
            hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
                hmc_state,
                step_size=step_size,
                num_integrator_steps=num_leapfrog_steps,
                target_log_prob_fn=target_log_prob_fn,
                seed=hmc_seed)
            return (hmc_state, seed), hmc_extra

        if backend.get_backend() == backend.TENSORFLOW:
            seed = tfp_test_util.test_seed()
        else:
            seed = self._make_seed(tfp_test_util.test_seed())

        # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs
        # for the jit to do anything.
        _, chain = tf.function(lambda state, seed: fun_mcmc.trace(  # pylint: disable=g-long-lambda
            state=(fun_mcmc.HamiltonianMonteCarloState(state), seed),
            fn=kernel,
            num_steps=num_steps,
            trace_fn=lambda state, extra: state[0].state))(state, seed)
        # Discard the warmup samples.
        chain = chain[1000:]

        sample_mean = tf.reduce_mean(chain, axis=[0, 1])
        sample_var = tf.math.reduce_variance(chain, axis=[0, 1])

        true_samples = util.random_normal(shape=[4096, 2],
                                          dtype=tf.float32,
                                          seed=seed) * base_scale + base_mean

        true_mean = tf.reduce_mean(true_samples, axis=0)
        true_var = tf.math.reduce_variance(true_samples, axis=0)

        self.assertAllClose(true_mean, sample_mean, rtol=0.1, atol=0.1)
        self.assertAllClose(true_var, sample_var, rtol=0.1, atol=0.1)
Example #5
def kernel(hmc_state, raac_state, seed):
    if backend.get_backend() == backend.TENSORFLOW:
        hmc_seed = _test_seed()
    else:
        hmc_seed, seed = util.split_seed(seed, 2)
    hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
        hmc_state,
        step_size=step_size,
        num_integrator_steps=num_leapfrog_steps,
        target_log_prob_fn=target_log_prob_fn,
        seed=hmc_seed)
    raac_state, _ = fun_mcmc.running_approximate_auto_covariance_step(
        raac_state, hmc_state.state, axis=aggregation)
    return (hmc_state, raac_state, seed), hmc_extra
Example #6
    def testRandomWalkMetropolis(self):
        num_steps = 1000
        state = tf.ones([16], dtype=tf.int32)
        target_logits = tf.constant([1., 2., 3., 4.]) + 2.
        proposal_logits = tf.constant([4., 3., 2., 1.]) + 2.

        def target_log_prob_fn(x):
            return tf.gather(target_logits, x), ()

        def proposal_fn(x, seed):
            current_logits = tf.gather(proposal_logits, x)
            proposal = util.random_categorical(proposal_logits[tf.newaxis],
                                               x.shape[0], seed)[0]
            proposed_logits = tf.gather(proposal_logits, proposal)
            return tf.cast(proposal,
                           x.dtype), ((), proposed_logits - current_logits)

        def kernel(rwm_state, seed):
            if backend.get_backend() == backend.TENSORFLOW:
                rwm_seed = tfp_test_util.test_seed()
            else:
                rwm_seed, seed = util.split_seed(seed, 2)
            rwm_state, rwm_extra = fun_mcmc.random_walk_metropolis(
                rwm_state,
                target_log_prob_fn=target_log_prob_fn,
                proposal_fn=proposal_fn,
                seed=rwm_seed)
            return (rwm_state, seed), rwm_extra

        if backend.get_backend() == backend.TENSORFLOW:
            seed = tfp_test_util.test_seed()
        else:
            seed = self._make_seed(tfp_test_util.test_seed())

        # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs
        # for the jit to do anything.
        _, chain = tf.function(lambda state, seed: fun_mcmc.trace(  # pylint: disable=g-long-lambda
            state=(fun_mcmc.random_walk_metropolis_init(
                state, target_log_prob_fn), seed),
            fn=kernel,
            num_steps=num_steps,
            trace_fn=lambda state, extra: state[0].state))(state, seed)
        # Discard the warmup samples.
        chain = chain[500:]

        sample_mean = tf.reduce_mean(tf.one_hot(chain, 4), axis=[0, 1])
        self.assertAllClose(tf.nn.softmax(target_logits),
                            sample_mean,
                            atol=0.1)
Example #7
def trace(
    state: State,
    fn: TransitionOperator,
    num_steps: IntTensor,
    trace_fn: Callable[[State, TensorNest], TensorNest],
    parallel_iterations: int = 10,
) -> Tuple[State, TensorNest]:
    """`TransitionOperator` that runs `fn` repeatedly and traces its outputs.

    Args:
      state: A nest of `Tensor`s or None.
      num_steps: Number of steps to run the function for. Must be greater than 1.
      fn: A `TransitionOperator`.
      trace_fn: Callable that takes the unpacked outputs of `fn` and returns a
        nest of `Tensor`s. These will be stacked and returned.
      parallel_iterations: Number of iterations of the while loop to run in
        parallel.

    Returns:
      state: The final state returned by `fn`.
      traces: Stacked outputs of `trace_fn`.
    """
    state = util.map_tree(
        lambda t: (t if t is None else tf.convert_to_tensor(t)), state)

    def wrapper(state):
        state, extra = util.map_tree(tf.convert_to_tensor,
                                     call_transition_operator(fn, state))
        trace_element = util.map_tree(tf.convert_to_tensor,
                                      trace_fn(state, extra))
        return state, trace_element

    # JAX tracing/pre-compilation isn't as stable as TF's, so we won't use it to
    # start.
    if (backend.get_backend() != backend.TENSORFLOW
            or any(e is None for e in util.flatten_tree(state))
            or tf.executing_eagerly()):
        state, first_trace = wrapper(state)
        trace_arrays = util.map_tree(
            lambda v: util.write_dynamic_array(  # pylint: disable=g-long-lambda
                util.make_dynamic_array(
                    v.dtype, size=num_steps, element_shape=v.shape), 0, v),
            first_trace)
        start_idx = 1
    else:
        state_spec = util.map_tree(tf.TensorSpec.from_tensor, state)
        # We need the shapes and dtypes of the outputs of `wrapper` function to
        # create the `TensorArray`s, we can get it by pre-compiling the wrapper
        # function.
        wrapper = tf.function(autograph=False)(wrapper)
        concrete_wrapper = wrapper.get_concrete_function(state_spec)
        _, trace_dtypes = concrete_wrapper.output_dtypes
        _, trace_shapes = concrete_wrapper.output_shapes
        trace_arrays = util.map_tree(
            lambda dtype, shape: tf.TensorArray(  # pylint: disable=g-long-lambda
                dtype,
                size=num_steps,
                element_shape=shape),
            trace_dtypes,
            trace_shapes)
        wrapper = lambda state: concrete_wrapper(*util.flatten_tree(state))
        start_idx = 0

    def body(i, state, trace_arrays):
        state, trace_element = wrapper(state)
        trace_arrays = util.map_tree(
            lambda a, v: util.write_dynamic_array(a, i, v), trace_arrays,
            trace_element)
        return i + 1, state, trace_arrays

    def cond(i, *_):
        return i < num_steps

    _, state, trace_arrays = tf.while_loop(
        cond=cond,
        body=body,
        loop_vars=(start_idx, state, trace_arrays),
        parallel_iterations=parallel_iterations)

    stacked_trace = util.map_tree(util.snapshot_dynamic_array, trace_arrays)

    # TensorFlow often loses the static shape information.
    if backend.get_backend() == backend.TENSORFLOW:
        static_length = tf.get_static_value(num_steps)

        def _merge_static_length(x):
            x.set_shape(tf.TensorShape(static_length).concatenate(x.shape[1:]))
            return x

        stacked_trace = util.map_tree(_merge_static_length, stacked_trace)

    return state, stacked_trace
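For reference, a minimal usage sketch of `trace` under the contract documented above: `fn` is a `TransitionOperator` returning `(new_state, extra)`, and `trace_fn` picks what to record at each step. The `counter_kernel` name is purely illustrative.

def counter_kernel(x):
    # Trivial TransitionOperator: increment the state and return empty extras.
    return x + 1., ()

# Run the kernel 5 times; the traced values are stacked along a new leading axis.
final_x, xs = trace(
    state=0.,
    fn=counter_kernel,
    num_steps=5,
    trace_fn=lambda state, extra: state)
# final_x == 5. and xs holds [1., 2., 3., 4., 5.].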
Example #8
def _wrapper(self, *args, **kwargs):
    if backend.get_backend() != backend.JAX:
        return fn(self, *args, **kwargs)
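This fragment appears to be the inner wrapper of a decorator that runs a test only when the active backend is not JAX. A minimal sketch of how such a decorator could be assembled (the `skip_on_jax` name is an assumption for illustration, not necessarily what the source calls it):

import functools

def skip_on_jax(fn):
    # Hypothetical decorator: call the wrapped test method only off the JAX
    # backend; on JAX the test body is skipped and the wrapper returns None.
    @functools.wraps(fn)
    def _wrapper(self, *args, **kwargs):
        if backend.get_backend() != backend.JAX:
            return fn(self, *args, **kwargs)
    return _wrapper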
Example #9
    def testRunningApproximateAutoCovariance(self, state_shape, event_ndims,
                                             aggregation):
        # We'll use HMC as the source of our chain.
        # While HMC is being sampled, we also compute the running autocovariance.
        step_size = 0.2
        num_steps = 1000
        num_leapfrog_steps = 10
        max_lags = 300

        state = tf.constant(np.zeros(state_shape).astype(np.float32))

        def target_log_prob_fn(x):
            lp = -0.5 * tf.square(x)
            if event_ndims is None:
                return lp, ()
            else:
                return tf.reduce_sum(lp, -1), ()

        def kernel(hmc_state, raac_state, seed):
            if backend.get_backend() == backend.TENSORFLOW:
                hmc_seed = _test_seed()
            else:
                hmc_seed, seed = util.split_seed(seed, 2)
            hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
                hmc_state,
                step_size=step_size,
                num_integrator_steps=num_leapfrog_steps,
                target_log_prob_fn=target_log_prob_fn,
                seed=hmc_seed)
            raac_state, _ = fun_mcmc.running_approximate_auto_covariance_step(
                raac_state, hmc_state.state, axis=aggregation)
            return (hmc_state, raac_state, seed), hmc_extra

        if backend.get_backend() == backend.TENSORFLOW:
            seed = _test_seed()
        else:
            seed = self._make_seed(_test_seed())

        # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs
        # for the jit to do anything.
        (_, raac_state,
         _), chain = tf.function(lambda state, seed: fun_mcmc.trace(  # pylint: disable=g-long-lambda
             state=(
                 fun_mcmc.HamiltonianMonteCarloState(state),
                 fun_mcmc.running_approximate_auto_covariance_init(
                     max_lags=max_lags,
                     state_shape=state_shape,
                     dtype=state.dtype,
                     axis=aggregation),
                 seed,
             ),
             fn=kernel,
             num_steps=num_steps,
             trace_fn=lambda state, extra: state[0].state))(state, seed)

        true_aggregation = (0, ) + (() if aggregation is None else tuple(
            [a + 1 for a in util.flatten_tree(aggregation)]))
        true_variance = np.array(
            tf.math.reduce_variance(np.array(chain), true_aggregation))
        true_autocov = np.array(
            tfp.stats.auto_correlation(np.array(chain),
                                       axis=0,
                                       max_lags=max_lags))
        if aggregation is not None:
            true_autocov = tf.reduce_mean(
                true_autocov, [a + 1 for a in util.flatten_tree(aggregation)])

        self.assertAllClose(true_variance, raac_state.auto_covariance[0], 1e-5)
        self.assertAllClose(true_autocov,
                            raac_state.auto_covariance /
                            raac_state.auto_covariance[0],
                            atol=0.1)