def testMetropolisHastingsStep(self):
  """Checks MH accept/reject decisions for fixed and random energy changes.

  Deterministic cases:
    * `energy_change=-inf` accepts the proposal.
    * `energy_change=+inf` and `energy_change=nan` keep the current state.
    * With `current_state=None` the returned state equals the proposal, even
      when `is_accepted` is False.
    * An explicit `log_uniform` pins the uniform draw. With
      `energy_change=-log(0.5)`, a draw of -10 accepts and a draw of 0
      rejects.

  Stochastic case: 1000 independent proposals with acceptance ratio 0.5
  should give a mean accepted value close to 0.5.

  NOTE(review): another method in this file has the same name. If both live
  in one class, only the later definition runs. Confirm which is intended.
  """
  seed = self._make_seed(tfp_test_util.test_seed())

  # Each entry: (keyword arguments for the step, expected returned state,
  # expected `is_accepted` flag). Cases run in order with the shared seed.
  cases = [
      (dict(current_state=0., proposed_state=1., energy_change=-np.inf),
       1., True),
      (dict(current_state=0., proposed_state=1., energy_change=np.inf),
       0., False),
      (dict(current_state=0., proposed_state=1., energy_change=np.nan),
       0., False),
      (dict(current_state=None, proposed_state=1., energy_change=np.nan),
       1., False),
      (dict(current_state=None, proposed_state=1., log_uniform=-10.,
            energy_change=-np.log(0.5)),
       1., True),
      (dict(current_state=None, proposed_state=1., log_uniform=0.,
            energy_change=-np.log(0.5)),
       1., False),
  ]
  for step_kwargs, want_state, want_is_accepted in cases:
    accepted, mh_extra = fun_mcmc.metropolis_hastings_step(
        seed=seed, **step_kwargs)
    self.assertAllEqual(want_state, accepted)
    self.assertAllEqual(want_is_accepted, mh_extra.is_accepted)

  # Batched proposals with acceptance probability exp(log(0.5)) = 0.5 each.
  batch = 1000
  accepted, _ = fun_mcmc.metropolis_hastings_step(
      current_state=tf.zeros(batch),
      proposed_state=tf.ones(batch),
      energy_change=-tf.math.log(0.5 * tf.ones(batch)),
      seed=seed)
  self.assertAllClose(0.5, tf.reduce_mean(accepted), rtol=0.1)
def testMetropolisHastingsStepStructure(self):
  """Checks that an always-accepted step returns the whole proposed structure.

  The states are a namedtuple holding nested lists and tuples. With
  `energy_change=-inf` the proposal must be accepted, so every leaf of the
  returned state should equal the corresponding leaf of `proposed`.

  NOTE(review): another method in this file has the same name. If both live
  in one class, only the later definition runs. Confirm which is intended.
  """
  struct_type = collections.namedtuple('Struct', 'a, b')
  current = struct_type([1, 2], (3, [4, [0, 0]]))
  proposed = struct_type([5, 6], (7, [8, [0, 0]]))
  # `metropolis_hastings_step` returns a pair `(accepted_state, extra)`, and
  # the acceptance flag lives on `extra.is_accepted`. Unpacking the result
  # into three names would raise a ValueError.
  accepted, mh_extra = fun_mcmc.metropolis_hastings_step(
      current_state=current,
      proposed_state=proposed,
      energy_change=-np.inf,
      # Seeded like the other MH tests so the uniform draw is reproducible.
      seed=self._make_seed(tfp_test_util.test_seed()))
  self.assertAllEqual(True, mh_extra.is_accepted)
  self.assertAllEqual(tf.nest.flatten(proposed), tf.nest.flatten(accepted))
def testMetropolisHastingsStep(self):
  """Checks MH accept/reject decisions in the test dtype (`self._dtype`).

  Deterministic cases:
    * `energy_change=-inf` accepts the proposal.
    * `energy_change=+inf` rejects it.
    * `energy_change=nan` rejects it, both as a Python float and as a constant
      in the test dtype.
    * An explicit `log_uniform` pins the uniform draw. With
      `energy_change=-log(0.5)`, a draw of -1 accepts and a draw of 0 rejects.

  Stochastic case: 1000 independent proposals with acceptance ratio 0.5
  should give a mean accepted value close to 0.5.
  """
  seed = self._make_seed(_test_seed())
  zero = self._constant(0.)
  one = self._constant(1.)

  # -inf energy change: always accepted.
  accepted, mh_extra = fun_mcmc.metropolis_hastings_step(
      current_state=zero, proposed_state=one, energy_change=-np.inf,
      seed=seed)
  self.assertAllEqual(one, accepted)
  self.assertAllEqual(True, mh_extra.is_accepted)

  # +inf energy change: always rejected.
  accepted, mh_extra = fun_mcmc.metropolis_hastings_step(
      current_state=zero, proposed_state=one, energy_change=np.inf,
      seed=seed)
  self.assertAllEqual(zero, accepted)
  self.assertAllEqual(False, mh_extra.is_accepted)

  # NaN energy change given as a Python float: rejected.
  accepted, mh_extra = fun_mcmc.metropolis_hastings_step(
      current_state=zero, proposed_state=one, energy_change=np.nan,
      seed=seed)
  self.assertAllEqual(zero, accepted)
  self.assertAllEqual(False, mh_extra.is_accepted)

  # NaN energy change as a constant in the test dtype (built the same way as
  # the energy changes below): also rejected.
  accepted, mh_extra = fun_mcmc.metropolis_hastings_step(
      current_state=zero, proposed_state=one,
      energy_change=self._constant(np.nan), seed=seed)
  self.assertAllEqual(zero, accepted)
  self.assertAllEqual(False, mh_extra.is_accepted)

  # Pinned uniform draw below -energy_change = log(0.5): accepted.
  accepted, mh_extra = fun_mcmc.metropolis_hastings_step(
      current_state=zero, proposed_state=one, log_uniform=-one,
      energy_change=self._constant(-np.log(0.5)), seed=seed)
  self.assertAllEqual(one, accepted)
  self.assertAllEqual(True, mh_extra.is_accepted)

  # Pinned uniform draw above -energy_change = log(0.5): rejected.
  accepted, mh_extra = fun_mcmc.metropolis_hastings_step(
      current_state=zero, proposed_state=one, log_uniform=zero,
      energy_change=self._constant(-np.log(0.5)), seed=seed)
  self.assertAllEqual(zero, accepted)
  self.assertAllEqual(False, mh_extra.is_accepted)

  # Batched proposals with acceptance probability 0.5 each.
  accepted, _ = fun_mcmc.metropolis_hastings_step(
      current_state=tf.zeros(1000, dtype=self._dtype),
      proposed_state=tf.ones(1000, dtype=self._dtype),
      energy_change=-tf.math.log(0.5 * tf.ones(1000, dtype=self._dtype)),
      seed=seed)
  self.assertAllClose(0.5, tf.reduce_mean(accepted), rtol=0.1)
def testMetropolisHastingsStepStructure(self):
  """An always-accepted MH step on a nested namedtuple returns the proposal.

  Both states are namedtuples whose fields hold nested lists and tuples.
  Because `energy_change=-inf` forces acceptance, the flattened leaves of the
  returned state must equal the flattened leaves of the proposed state.
  """
  make_struct = collections.namedtuple('Struct', 'a, b')
  old_state = make_struct(a=[1, 2], b=(3, [4, [0, 0]]))
  new_state = make_struct(a=[5, 6], b=(7, [8, [0, 0]]))
  step_seed = self._make_seed(_test_seed())

  result_state, extra = fun_mcmc.metropolis_hastings_step(
      current_state=old_state,
      proposed_state=new_state,
      energy_change=-np.inf,
      seed=step_seed)
  self.assertAllEqual(True, extra.is_accepted)

  expected_leaves = util.flatten_tree(new_state)
  actual_leaves = util.flatten_tree(result_state)
  self.assertAllEqual(expected_leaves, actual_leaves)