def test_atypical_nesting(self):
    """A single atypical wrapper around a flat core result.

    Each field lives at exactly one level, so innermost and outermost lookups
    must agree, and missing fields either raise or return the default.
    """
    results = FakeAtypicalNestingResults(
        unique_atypical_nesting=0,
        atypical_inner_results=FakeResults(1),
    )

    # Both the wrapper's own field and the wrapped core field are discoverable.
    self.assertTrue(unnest.has_nested(results, 'unique_atypical_nesting'))
    self.assertTrue(unnest.has_nested(results, 'unique_core'))
    # The core field is only reachable through nesting, not directly.
    self.assertFalse(hasattr(results, 'unique_core'))
    self.assertFalse(unnest.has_nested(results, 'foo'))

    expectations = (
        ('unique_atypical_nesting', results.unique_atypical_nesting),
        ('unique_core', results.atypical_inner_results.unique_core),
    )
    for field, expected in expectations:
        self.assertEqual(expected, unnest.get_innermost(results, field))
        self.assertEqual(expected, unnest.get_outermost(results, field))

    getters = (unnest.get_innermost, unnest.get_outermost)
    # Without a default, a missing field raises AttributeError...
    for getter in getters:
        self.assertRaises(AttributeError, getter, results, 'foo')
    # ...and with a default, that exact default object is returned.
    for getter in getters:
        self.assertIs(getter(results, 'foo', SINGLETON), SINGLETON)
def hmc_like_proposed_velocity_getter_fn(kernel_results):
    """Getter for `proposed_velocity` so it can be inspected.

    The innermost `final_momentum` is used as the velocity. Its leaves are
    re-packed into the structure of the innermost `proposed_state`.

    Args:
      kernel_results: (Possibly nested) kernel results object.

    Returns:
      The momentum leaves arranged with the same nest structure as the
      proposed state.
    """
    # TODO(b/169898033): This only works with the standard kinetic energy term.
    momentum = unnest.get_innermost(kernel_results, 'final_momentum')
    state = unnest.get_innermost(kernel_results, 'proposed_state')
    # The raw momentum has the wrong structure when the state is a scalar, so
    # flatten it and rebuild it in the state's shape.
    momentum_leaves = tf.nest.flatten(momentum)
    return tf.nest.pack_sequence_as(state, momentum_leaves)
def _get_field(kernel_results, field_name):
    """Fetch `field_name` from the innermost level of `kernel_results`.

    Args:
      kernel_results: (Possibly nested) kernel results object.
      field_name: Name of the attribute to look up.

    Returns:
      The value returned by `unnest.get_innermost(kernel_results, field_name)`.

    Raises:
      REMCFieldNotFoundError: If `unnest.get_innermost` raises
        `AttributeError` (the field was not found). The original
        `AttributeError` is attached as `__cause__`.
    """
    try:
        return unnest.get_innermost(kernel_results, field_name)
    except AttributeError as err:
        msg = _kernel_result_not_implemented_message_template.format(
            kernel_results, field_name)
        # Chain explicitly, so the traceback shows the lookup failure as the
        # direct cause rather than as an error raised inside the handler.
        raise REMCFieldNotFoundError(msg) from err
def test_flat(self):
    """Lookups on an un-nested results object behave like plain attributes."""
    results = FakeResults(0)

    self.assertTrue(unnest.has_nested(results, 'unique_core'))
    self.assertFalse(unnest.has_nested(results, 'foo'))

    expected_core = results.unique_core
    # With only one level, innermost and outermost resolve to the same value.
    self.assertEqual(expected_core, unnest.get_innermost(results, 'unique_core'))
    self.assertEqual(expected_core, unnest.get_outermost(results, 'unique_core'))

    getters = (unnest.get_innermost, unnest.get_outermost)
    for getter in getters:
        self.assertRaises(AttributeError, getter, results, 'foo')
    for getter in getters:
        self.assertIs(getter(results, 'foo', SINGLETON), SINGLETON)
def rwm_log_accept_prob_getter_fn(kernel_results):
    """Getter for `log_accept_prob` member of `MetropolisHastings` `TransitionKernel` so that it can be inspected.

    Reads the innermost `log_accept_ratio`, replaces every non-finite entry
    (NaN, +inf and -inf) with -inf, then caps the result at 0.

    Args:
      kernel_results: (Possibly nested) kernel results object.

    Returns:
      Tensor of log acceptance probabilities, each <= 0.
    """
    ratio = unnest.get_innermost(kernel_results, "log_accept_ratio")
    # NOTE: +inf is also non-finite, so it is mapped to -inf (a rejection)
    # here, not to 0 (a certain acceptance).
    finite_mask = tf.math.is_finite(ratio)
    rejection = tf.constant(-np.inf, dtype=ratio.dtype)
    cleaned = tf.where(finite_mask, ratio, rejection)
    # log(min(1, ratio)): a probability cannot exceed 1.
    return tf.minimum(cleaned, 0.0)
def test_deeply_nested(self):
    """Lookups through a multi-level stack of nesting wrappers.

    Expected values, per the assertions below: `unique_nesting` is 0 at its
    outermost level and 3 at its innermost; `unique_atypical_nesting` is 1
    outermost and 2 innermost; `unique_core` resolves to 4 from either side.
    """
    results = _build_deeply_nested(0, 1, 2, 3, 4)
    self.assertTrue(unnest.has_nested(results, 'unique_nesting'))
    self.assertTrue(unnest.has_nested(results, 'unique_atypical_nesting'))
    self.assertTrue(unnest.has_nested(results, 'unique_core'))
    # The core field must be reachable only through nesting. Check the results
    # object itself; a `hasattr` on the TestCase would be trivially False.
    self.assertFalse(hasattr(results, 'unique_core'))
    self.assertFalse(unnest.has_nested(results, 'foo'))

    self.assertEqual(unnest.get_innermost(results, 'unique_nesting'), 3)
    self.assertEqual(unnest.get_outermost(results, 'unique_nesting'), 0)
    self.assertEqual(
        unnest.get_innermost(results, 'unique_atypical_nesting'), 2)
    self.assertEqual(
        unnest.get_outermost(results, 'unique_atypical_nesting'), 1)
    self.assertEqual(unnest.get_innermost(results, 'unique_core'), 4)
    self.assertEqual(unnest.get_outermost(results, 'unique_core'), 4)

    self.assertRaises(AttributeError,
                      lambda: unnest.get_innermost(results, 'foo'))
    self.assertRaises(AttributeError,
                      lambda: unnest.get_outermost(results, 'foo'))
    self.assertIs(unnest.get_innermost(results, 'foo', SINGLETON), SINGLETON)
    self.assertIs(unnest.get_outermost(results, 'foo', SINGLETON), SINGLETON)
def rwm_extra_getter_fn(kernel_results):
    """Getter for `extra` member of `MetropolisHastings` `TransitionKernel` so that it can be inspected.

    Args:
      kernel_results: (Possibly nested) kernel results object.

    Returns:
      The innermost `extra` value found by `unnest.get_innermost`.
    """
    extra = unnest.get_innermost(kernel_results, "extra")
    return extra
def get_target_log_prob(results):
    """Fetches a target log prob from a results structure.

    Args:
      results: (Possibly nested) kernel results object.

    Returns:
      The innermost `target_log_prob` value.
    """
    target_log_prob = unnest.get_innermost(results, "target_log_prob")
    return target_log_prob
def hmc_like_log_accept_prob_getter_fn(kernel_results):
    """Getter for the log acceptance probability so it can be inspected.

    Reads the innermost `log_accept_ratio`, replaces every non-finite entry
    (NaN, +inf and -inf) with -inf, then caps the result at 0.

    Args:
      kernel_results: (Possibly nested) kernel results object.

    Returns:
      Tensor of log acceptance probabilities, each <= 0.
    """
    ratio = unnest.get_innermost(kernel_results, 'log_accept_ratio')
    # NOTE: +inf is non-finite too, so it becomes -inf (a rejection).
    finite_mask = tf.math.is_finite(ratio)
    rejection = tf.constant(-np.inf, dtype=ratio.dtype)
    cleaned = tf.where(finite_mask, ratio, rejection)
    # log(min(1, ratio)): a probability cannot exceed 1.
    return tf.minimum(cleaned, 0.0)
def hmc_like_step_size_getter_fn(kernel_results):
    """Getter for `step_size` so it can be inspected.

    Args:
      kernel_results: (Possibly nested) kernel results object.

    Returns:
      The innermost `step_size` value.
    """
    step_size = unnest.get_innermost(kernel_results, 'step_size')
    return step_size
def hmc_like_num_leapfrog_steps_getter_fn(kernel_results):
    """Getter for `num_leapfrog_steps` so it can be inspected.

    Args:
      kernel_results: (Possibly nested) kernel results object.

    Returns:
      The innermost `num_leapfrog_steps` value.
    """
    num_steps = unnest.get_innermost(kernel_results, 'num_leapfrog_steps')
    return num_steps
def hmc_like_proposed_state_getter_fn(kernel_results):
    """Getter for `proposed_state` so it can be inspected.

    Args:
      kernel_results: (Possibly nested) kernel results object.

    Returns:
      The innermost `proposed_state` value.
    """
    proposed = unnest.get_innermost(kernel_results, 'proposed_state')
    return proposed