Example #1
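These snippets appear to come from a fun_mcmc test suite and omit their shared setup. A rough preamble they seem to assume is sketched below; the fun_mcmc/backend/util import paths are an assumption (the library has lived under TensorFlow Probability's experimental/discussion area and later as the standalone fun_mc package), so adjust them to your version.

# Presumed common imports for the examples on this page (paths are assumptions).
import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors
tfd = tfp.distributions

# The exact location of fun_mcmc and its backend/util helpers varies by version.
from discussion import fun_mcmc
from discussion.fun_mcmc import backend
from discussion.fun_mcmc import util
from tensorflow_probability.python.internal import test_util as tfp_test_util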
    def testBasicHMC(self):
        step_size = 0.2
        num_steps = 2000
        num_leapfrog_steps = 10
        state = tf.ones([16, 2])

        base_mean = tf.constant([2., 3.])
        base_scale = tf.constant([2., 0.5])

        def target_log_prob_fn(x):
            return -tf.reduce_sum(
                0.5 * tf.square((x - base_mean) / base_scale), -1), ()

        def kernel(hmc_state, seed):
            if backend.get_backend() == backend.TENSORFLOW:
                hmc_seed = tfp_test_util.test_seed()
            else:
                hmc_seed, seed = util.split_seed(seed, 2)
            hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
                hmc_state,
                step_size=step_size,
                num_integrator_steps=num_leapfrog_steps,
                target_log_prob_fn=target_log_prob_fn,
                seed=hmc_seed)
            return (hmc_state, seed), hmc_extra

        if backend.get_backend() == backend.TENSORFLOW:
            seed = tfp_test_util.test_seed()
        else:
            seed = self._make_seed(tfp_test_util.test_seed())

        # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs
        # for the jit to do anything.
        _, chain = tf.function(lambda state, seed: fun_mcmc.trace(  # pylint: disable=g-long-lambda
            state=(fun_mcmc.HamiltonianMonteCarloState(state), seed),
            fn=kernel,
            num_steps=num_steps,
            trace_fn=lambda state, extra: state[0].state))(state, seed)
        # Discard the warmup samples.
        chain = chain[1000:]

        sample_mean = tf.reduce_mean(chain, axis=[0, 1])
        sample_var = tf.math.reduce_variance(chain, axis=[0, 1])

        true_samples = util.random_normal(shape=[4096, 2],
                                          dtype=tf.float32,
                                          seed=seed) * base_scale + base_mean

        true_mean = tf.reduce_mean(true_samples, axis=0)
        true_var = tf.math.reduce_variance(true_samples, axis=0)

        self.assertAllClose(true_mean, sample_mean, rtol=0.1, atol=0.1)
        self.assertAllClose(true_var, sample_var, rtol=0.1, atol=0.1)
Example #2
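This snippet calls a target_log_prob_fn that is defined elsewhere in the enclosing test. A minimal stand-in following the convention used by the other examples here (returning a (log_prob, extra) pair) might look like this; the standard-normal target is an assumption, not the original test's:

def target_log_prob_fn(x):
    # Unnormalized standard-normal log-density; the empty tuple is the `extra`
    # value that fun_mcmc-style target_log_prob_fns return alongside it.
    return -0.5 * tf.reduce_sum(tf.square(x), -1), ()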
        def trace():
            kernel = lambda state: fun_mcmc.hamiltonian_monte_carlo(
                state,
                step_size=0.1,
                num_integrator_steps=3,
                target_log_prob_fn=target_log_prob_fn,
                seed=tfp_test_util.test_seed())

            fun_mcmc.trace(state=fun_mcmc.HamiltonianMonteCarloState(
                tf.zeros([1])),
                           fn=kernel,
                           num_steps=4,
                           trace_fn=lambda *args: ())
Example #3
    def testPreconditionedHMC(self):
        step_size = 0.2
        num_steps = 2000
        num_leapfrog_steps = 10
        state = tf.ones([16, 2])

        base_mean = [1., 0]
        base_cov = [[1, 0.5], [0.5, 1]]

        bijector = tfb.Softplus()
        base_dist = tfd.MultivariateNormalFullCovariance(
            loc=base_mean, covariance_matrix=base_cov)
        target_dist = bijector(base_dist)

        def orig_target_log_prob_fn(x):
            return target_dist.log_prob(x), ()

        target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
            orig_target_log_prob_fn, bijector, state)

        # pylint: disable=g-long-lambda
        kernel = tf.function(lambda state: fun_mcmc.hamiltonian_monte_carlo(
            state,
            step_size=step_size,
            num_integrator_steps=num_leapfrog_steps,
            target_log_prob_fn=target_log_prob_fn,
            seed=tfp_test_util.test_seed()))

        _, chain = fun_mcmc.trace(
            state=fun_mcmc.HamiltonianMonteCarloState(state),
            fn=kernel,
            num_steps=num_steps,
            trace_fn=lambda state, extra: state.state_extra[0])
        # Discard the warmup samples.
        chain = chain[1000:]

        sample_mean = tf.reduce_mean(chain, axis=[0, 1])
        sample_cov = tfp.stats.covariance(chain, sample_axis=[0, 1])

        true_samples = target_dist.sample(4096, seed=tfp_test_util.test_seed())

        true_mean = tf.reduce_mean(true_samples, axis=0)
        true_cov = tfp.stats.covariance(true_samples, sample_axis=0)

        self.assertAllClose(true_mean, sample_mean, rtol=0.1, atol=0.1)
        self.assertAllClose(true_cov, sample_cov, rtol=0.1, atol=0.1)
Example #4
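This example and Example #5 close over names defined by the enclosing test (base_mean, base_cov, step_size, num_steps, num_adapt_steps, num_leapfrog_steps, and a _test_seed helper). Hypothetical placeholder values, mirroring Example #3's setup where possible, are sketched below purely for illustration:

# Hypothetical values for the closure variables used in Examples #4 and #5.
base_mean = [1., 0]
base_cov = [[1, 0.5], [0.5, 1]]
num_steps = 2000
num_adapt_steps = 200                   # assumed number of adaptation (warmup) steps
num_leapfrog_steps = 10
step_size = tf.constant(0.2)            # Example #4 adapts log(step_size) with Adam
_test_seed = tfp_test_util.test_seed    # assumed alias for the TFP test seed helper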
        def computation(state):
            bijector = tfb.Softplus()
            base_dist = tfd.MultivariateNormalFullCovariance(
                loc=base_mean, covariance_matrix=base_cov)
            target_dist = bijector(base_dist)

            def orig_target_log_prob_fn(x):
                return target_dist.log_prob(x), ()

            target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
                orig_target_log_prob_fn, bijector, state)

            def kernel(hmc_state, step_size_state, step):
                hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
                    hmc_state,
                    step_size=tf.exp(step_size_state.state),
                    num_integrator_steps=num_leapfrog_steps,
                    target_log_prob_fn=target_log_prob_fn)

                rate = tf.compat.v1.train.polynomial_decay(
                    0.01,
                    global_step=step,
                    power=0.5,
                    decay_steps=num_adapt_steps,
                    end_learning_rate=0.)
                mean_p_accept = tf.reduce_mean(
                    tf.exp(tf.minimum(0., hmc_extra.log_accept_ratio)))

                loss_fn = fun_mcmc.make_surrogate_loss_fn(
                    lambda _: (0.9 - mean_p_accept, ()))
                step_size_state, _ = fun_mcmc.adam_step(step_size_state,
                                                        loss_fn,
                                                        learning_rate=rate)

                return ((hmc_state, step_size_state, step + 1),
                        (hmc_state.state_extra[0], hmc_extra.log_accept_ratio))

            _, (chain, log_accept_ratio_trace) = fun_mcmc.trace(
                state=(fun_mcmc.HamiltonianMonteCarloState(state),
                       fun_mcmc.adam_init(tf.math.log(step_size)), 0),
                fn=kernel,
                num_steps=num_adapt_steps + num_steps,
            )
            true_samples = target_dist.sample(4096, seed=_test_seed())
            return chain, log_accept_ratio_trace, true_samples
Example #5
        def computation(state):
            bijector = tfb.Softplus()
            base_dist = tfd.MultivariateNormalFullCovariance(
                loc=base_mean, covariance_matrix=base_cov)
            target_dist = bijector(base_dist)

            def orig_target_log_prob_fn(x):
                return target_dist.log_prob(x), ()

            target_log_prob_fn, state = fun_mcmc.transform_log_prob_fn(
                orig_target_log_prob_fn, bijector, state)

            def kernel(hmc_state, step_size, step):
                hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
                    hmc_state,
                    step_size=step_size,
                    num_integrator_steps=num_leapfrog_steps,
                    target_log_prob_fn=target_log_prob_fn)

                rate = tf.compat.v1.train.polynomial_decay(
                    0.01,
                    global_step=step,
                    power=0.5,
                    decay_steps=num_adapt_steps,
                    end_learning_rate=0.)
                mean_p_accept = tf.reduce_mean(
                    tf.exp(tf.minimum(0., hmc_extra.log_accept_ratio)))
                step_size = fun_mcmc.sign_adaptation(step_size,
                                                     output=mean_p_accept,
                                                     set_point=0.9,
                                                     adaptation_rate=rate)

                return (hmc_state, step_size, step + 1), hmc_extra

            _, (chain, log_accept_ratio_trace) = fun_mcmc.trace(
                (fun_mcmc.HamiltonianMonteCarloState(state), step_size, 0),
                kernel,
                num_adapt_steps + num_steps,
                trace_fn=lambda state, extra:
                (state[0].state_extra[0], extra.log_accept_ratio))
            true_samples = target_dist.sample(4096,
                                              seed=tfp_test_util.test_seed())
            return chain, log_accept_ratio_trace, true_samples
Example #6
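This test is parameterized; state_shape, event_ndims, and aggregation come from the test's parameter list, which is not shown here. One hypothetical combination consistent with how the body uses them (an assumption, not the original parameterization):

# Hypothetical parameters for Example #6: a [16, 2]-shaped state of 2-d events,
# with the running autocovariance additionally aggregated over the batch axis.
state_shape = [16, 2]   # shape of the HMC state
event_ndims = 1         # reduce the log-prob over the last (event) axis
aggregation = 0         # state axis over which the autocovariance aggregates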
    def testRunningApproximateAutoCovariance(self, state_shape, event_ndims,
                                             aggregation):
        # We'll use HMC as the source of our chain.
        # While HMC is being sampled, we also compute the running autocovariance.
        step_size = 0.2
        num_steps = 1000
        num_leapfrog_steps = 10
        max_lags = 300

        state = tf.constant(np.zeros(state_shape).astype(np.float32))

        def target_log_prob_fn(x):
            lp = -0.5 * tf.square(x)
            if event_ndims is None:
                return lp, ()
            else:
                return tf.reduce_sum(lp, -1), ()

        def kernel(hmc_state, raac_state, seed):
            if backend.get_backend() == backend.TENSORFLOW:
                hmc_seed = _test_seed()
            else:
                hmc_seed, seed = util.split_seed(seed, 2)
            hmc_state, hmc_extra = fun_mcmc.hamiltonian_monte_carlo(
                hmc_state,
                step_size=step_size,
                num_integrator_steps=num_leapfrog_steps,
                target_log_prob_fn=target_log_prob_fn,
                seed=hmc_seed)
            raac_state, _ = fun_mcmc.running_approximate_auto_covariance_step(
                raac_state, hmc_state.state, axis=aggregation)
            return (hmc_state, raac_state, seed), hmc_extra

        if backend.get_backend() == backend.TENSORFLOW:
            seed = _test_seed()
        else:
            seed = self._make_seed(_test_seed())

        # Subtle: Unlike TF, JAX needs a data dependency from the inputs to outputs
        # for the jit to do anything.
        (_, raac_state,
         _), chain = tf.function(lambda state, seed: fun_mcmc.trace(  # pylint: disable=g-long-lambda
             state=(
                 fun_mcmc.HamiltonianMonteCarloState(state),
                 fun_mcmc.running_approximate_auto_covariance_init(
                     max_lags=max_lags,
                     state_shape=state_shape,
                     dtype=state.dtype,
                     axis=aggregation),
                 seed,
             ),
             fn=kernel,
             num_steps=num_steps,
             trace_fn=lambda state, extra: state[0].state))(state, seed)

        true_aggregation = (0, ) + (() if aggregation is None else tuple(
            [a + 1 for a in util.flatten_tree(aggregation)]))
        true_variance = np.array(
            tf.math.reduce_variance(np.array(chain), true_aggregation))
        true_autocov = np.array(
            tfp.stats.auto_correlation(np.array(chain),
                                       axis=0,
                                       max_lags=max_lags))
        if aggregation is not None:
            true_autocov = tf.reduce_mean(
                true_autocov, [a + 1 for a in util.flatten_tree(aggregation)])

        self.assertAllClose(true_variance, raac_state.auto_covariance[0], 1e-5)
        self.assertAllClose(true_autocov,
                            raac_state.auto_covariance /
                            raac_state.auto_covariance[0],
                            atol=0.1)