Ejemplo n.º 1
0
def _get_field(kernel_results, field_name):
    """Fetch `field_name` from `kernel_results` via `mcmc_util.get_field`.

    A `TypeError` raised by the lookup is replaced with a
    `REMCFieldNotFoundError` whose message is produced by formatting
    `_kernel_result_not_implemented_message_template` with the kernel
    results and the requested field name.
    """
    try:
        field_value = mcmc_util.get_field(kernel_results, field_name)
    except TypeError:
        raise REMCFieldNotFoundError(
            _kernel_result_not_implemented_message_template.format(
                kernel_results, field_name))
    return field_value
Ejemplo n.º 2
0
 def testValidKernelResults(self, kernel_results):
   """Checks `util.update_field` on a present and on an absent field name.

   Updating 'some_field' must be readable back through `util.get_field`;
   updating 'some_other_field' is expected to raise a `TypeError` whose
   message matches 'set some_other_field'.
   """
   updated_kernel_results = util.update_field(
       kernel_results, 'some_field', 'moose')
   self.assertEqual(
       util.get_field(
           updated_kernel_results, 'some_field'), 'moose')
   # `assertRaisesRegexp` is a deprecated alias that was removed from
   # `unittest.TestCase` in Python 3.12; use the supported spelling.
   with self.assertRaisesRegex(TypeError, 'set some_other_field'):
     util.update_field(kernel_results, 'some_other_field', 'antelope')
Ejemplo n.º 3
0
 def trace_fn(state, results):  # pylint: disable=unused-argument
     """Return `[log_accept_ratio, post_swap_replica_states]` from `results`.

     `state` is ignored; the log accept ratio is read from
     `results.post_swap_replica_results` with `mcmc_util.get_field`.
     """
     replica_results = results.post_swap_replica_results
     log_accept_ratio = mcmc_util.get_field(replica_results,
                                            'log_accept_ratio')
     replica_states = results.post_swap_replica_states
     return [log_accept_ratio, replica_states]
Ejemplo n.º 4
0
    def testNormal(self,
                   tfp_transition_kernel,
                   inverse_temperatures,
                   state_includes_replicas,
                   store_parameters_in_results,
                   asserts,
                   prob_swap=1.0,
                   dtype=np.float32):
        """Sampling from standard normal with REMC.

        Builds a `tfp.mcmc.ReplicaExchangeMC` kernel over a standard Normal
        target, runs `tfp.mcmc.sample_chain`, and checks the shapes of the
        returned states. When `asserts` is true it additionally checks
        per-replica sample statistics, acceptance ratios, swap bookkeeping
        and the parameters stored in the kernel results.

        Args:
          tfp_transition_kernel: Callable used to build the inner kernel. It
            is called with `target_log_prob_fn`, `step_size`,
            `store_parameters_in_results` and `num_leapfrog_steps` keywords.
          inverse_temperatures: Sequence of inverse temperatures; converted
            with `dtype(...)`. Its length is the number of replicas.
          state_includes_replicas: Forwarded to `ReplicaExchangeMC`. When
            true, the initial state holds one sample per replica and the
            returned states are expected to have shape
            `(num_results, num_replica)`.
          store_parameters_in_results: Forwarded to the inner kernel. When
            true (and `asserts` is true), the stored `step_size` and
            `num_leapfrog_steps` are verified.
          asserts: When true, 2000 results are drawn, `remc.one_step` is
            wrapped in `tf.function`, and the statistical assertions run.
            When false, only 17 results are drawn and the method returns
            after the shape and state-consistency checks.
          prob_swap: Value passed to `tfp.mcmc.default_swap_proposal_fn`.
          dtype: Callable used to convert the target parameters and the
            inverse temperatures (defaults to `np.float32`).
        """

        target = tfd.Normal(dtype(0.), dtype(1.))
        inverse_temperatures = dtype(inverse_temperatures)
        num_replica = len(inverse_temperatures)

        # One step size per replica, scaled by 1 / sqrt(inverse temperature).
        step_size = 0.51234 / np.sqrt(inverse_temperatures)
        num_leapfrog_steps = 3

        def make_kernel_fn(target_log_prob_fn):
            return tfp_transition_kernel(
                target_log_prob_fn=target_log_prob_fn,
                step_size=step_size,
                store_parameters_in_results=store_parameters_in_results,
                num_leapfrog_steps=num_leapfrog_steps)

        remc = tfp.mcmc.ReplicaExchangeMC(
            target_log_prob_fn=target.log_prob,
            inverse_temperatures=inverse_temperatures,
            state_includes_replicas=state_includes_replicas,
            make_kernel_fn=make_kernel_fn,
            swap_proposal_fn=tfp.mcmc.default_swap_proposal_fn(prob_swap))

        # Short run unless the statistical assertions are requested.
        num_results = 17
        if asserts:
            num_results = 2000
            remc.one_step = tf.function(remc.one_step, autograph=False)

        if state_includes_replicas:
            current_state = target.sample(num_replica, seed=_set_seed())
        else:
            current_state = target.sample(seed=_set_seed())

        # Trace the complete kernel results so they can be inspected below.
        states, kernel_results = tfp.mcmc.sample_chain(
            num_results=num_results,
            current_state=current_state,
            kernel=remc,
            num_burnin_steps=50,
            trace_fn=lambda _, results: results,
            seed=_set_seed())

        if state_includes_replicas:
            self.assertAllEqual((num_results, num_replica), states.shape)
        else:
            self.assertAllEqual((num_results, ), states.shape)

        states_, kr_, replica_ess_ = self.evaluate([
            states,
            kernel_results,
            # Get the first (and only) state part for all replicas.
            effective_sample_size(kernel_results.post_swap_replica_states[0]),
        ])

        logging.vlog(
            2, '---- execution:{}  mean:{}  stddev:{}'.format(
                'eager' if tf.executing_eagerly() else 'graph', states_.mean(),
                states_.std()))

        # Some shortened names.
        replica_log_accept_ratio = mcmc_util.get_field(
            kr_.post_swap_replica_results, 'log_accept_ratio')
        replica_states_ = kr_.post_swap_replica_states[0]  # Get rid of "parts"

        # Target state is at index 0.
        if state_includes_replicas:
            self.assertAllClose(states_, replica_states_)
        else:
            self.assertAllClose(states_, replica_states_[:, 0])

        # Check that *each* replica has correct marginal.
        def _check_sample_stats(replica_idx):
            x = replica_states_[:, replica_idx]
            ess = replica_ess_[replica_idx]

            err_msg = 'replica_idx={}'.format(replica_idx)

            # Tolerances shrink with the effective sample size.
            mean_atol = 6 * 1.0 / np.sqrt(ess)
            self.assertAllClose(x.mean(), 0.0, atol=mean_atol, msg=err_msg)

            # For a tempered Normal, Variance = T.
            expected_var = 1 / inverse_temperatures[replica_idx]
            var_atol = 6 * expected_var * np.sqrt(2) / np.sqrt(ess)
            self.assertAllClose(np.var(x),
                                expected_var,
                                atol=var_atol,
                                msg=err_msg)

        if not asserts:
            return

        for replica_idx in range(num_replica):
            _check_sample_stats(replica_idx)

        # Test log_accept_ratio and replica_log_accept_ratio.
        self.assertAllEqual((num_results, num_replica),
                            replica_log_accept_ratio.shape)
        # Clip log ratios at 0 so each acceptance probability is at most 1.
        replica_mean_accept_ratio = np.mean(np.exp(
            np.minimum(0, replica_log_accept_ratio)),
                                            axis=0)
        for accept_ratio in replica_mean_accept_ratio:
            # Every single replica should have a decent P[Accept]
            self.assertBetween(accept_ratio, 0.2, 0.99)

        # Check swap probabilities for adjacent swaps.
        self.assertAllEqual((num_results, num_replica - 1),
                            kr_.is_swap_accepted_adjacent.shape)
        if num_replica > 1 and prob_swap > 0:
            # Only form the ratio when swaps can actually be proposed;
            # otherwise the denominator is zero and the division would
            # yield NaN (and a NumPy RuntimeWarning) for an unused value.
            conditional_swap_prob = (
                np.sum(kr_.is_swap_accepted_adjacent, axis=0) /
                np.sum(kr_.is_swap_proposed_adjacent, axis=0))
            # If temperatures are reasonable, this should be the case.
            # Ideally conditional_swap_prob is near 30%, but we're not tuning here
            self.assertGreater(np.min(conditional_swap_prob), 0.01)
            self.assertLess(np.max(conditional_swap_prob), 0.99)

        # Check swap probabilities for all swaps.
        def _check_swap_matrix(matrix):
            self.assertAllEqual((num_results, num_replica, num_replica),
                                matrix.shape)
            # Matrix is stochastic (since you either get swapped with another
            # replica, or yourself), and symmetric, since we do once-only swaps.
            self.assertAllEqual(np.ones((num_results, num_replica)),
                                matrix.sum(axis=-1))
            self.assertAllEqual(matrix, np.transpose(matrix, (0, 2, 1)))
            # By default, all swaps are between adjacent replicas.
            for i in range(num_replica):
                for j in range(i + 2, num_replica):
                    self.assertEqual(0.0, np.max(np.abs(matrix[..., i, j])))

        _check_swap_matrix(kr_.is_swap_proposed)
        _check_swap_matrix(kr_.is_swap_accepted)

        # Check inverse_temperatures never change.
        self.assertAllEqual(
            np.repeat([inverse_temperatures], axis=0, repeats=num_results),
            kr_.inverse_temperatures)

        if store_parameters_in_results:
            # Check that store_parameters_in_results=True worked for HMC.
            # The step size comparison is skipped when the replica results
            # are SimpleStepSizeAdaptationResults.
            if not isinstance(
                    kr_.post_swap_replica_results,
                    tfp.mcmc.simple_step_size_adaptation.
                    SimpleStepSizeAdaptationResults):
                self.assertAllEqual(
                    np.repeat([step_size], axis=0, repeats=num_results),
                    mcmc_util.get_field(kr_.post_swap_replica_results,
                                        'step_size'))

            self.assertAllEqual(
                np.repeat([num_leapfrog_steps], axis=0, repeats=num_results),
                mcmc_util.get_field(kr_.post_swap_replica_results,
                                    'num_leapfrog_steps'))
Ejemplo n.º 5
0
 def testIncompleteKernelResults(self):
   """Checks `util.get_field` raises `TypeError` for an absent field name.

   A `FakeKernelResults` built with only `some_field` is queried for
   'some_other_field'; the error message must match
   'extract some_other_field'.
   """
   kernel_results = FakeKernelResults(some_field='zebra')
   # `assertRaisesRegexp` is a deprecated alias that was removed from
   # `unittest.TestCase` in Python 3.12; use the supported spelling.
   with self.assertRaisesRegex(TypeError, 'extract some_other_field'):
     util.get_field(kernel_results, 'some_other_field')
Ejemplo n.º 6
0
 def testValidKernelResults(self, kernel_results):
   """Checks `util.get_field` on a present and on an absent field name.

   'some_field' is expected to read back as 'yak'; requesting
   'some_other_field' is expected to raise a `TypeError` whose message
   matches 'extract some_other_field'.
   """
   self.assertEqual(util.get_field(kernel_results, 'some_field'), 'yak')
   # `assertRaisesRegexp` is a deprecated alias that was removed from
   # `unittest.TestCase` in Python 3.12; use the supported spelling.
   with self.assertRaisesRegex(TypeError, 'extract some_other_field'):
     util.get_field(kernel_results, 'some_other_field')