def test_flat_replace(self):
  """Checks replace_innermost / replace_outermost on an un-nested result."""
  results = FakeResults(0)
  # (replace function, new `unique_core` value, extra `foo` value)
  cases = (
      (unnest.replace_innermost, 1, 2),
      (unnest.replace_outermost, 2, 3),
  )
  for replace, core_value, foo_value in cases:
    # A plain replacement yields a results object with the new field value.
    replaced = replace(results, unique_core=core_value)
    self.assertEqual(replaced.unique_core, core_value)
    # With `return_unused=True`, unmatched kwargs are handed back in a dict.
    self.assertEqual(
        replace(results, return_unused=True, unique_core=core_value,
                foo=foo_value),
        (FakeResults(core_value), {'foo': foo_value}))
  # Without `return_unused`, an unmatched kwarg raises ValueError.
  for replace in (unnest.replace_innermost, unnest.replace_outermost):
    with self.assertRaises(ValueError):
      replace(results, unique_core=1, foo=1)
def test_atypical_nesting_replace(self):
  """Checks replacement on a result nested via `atypical_inner_results`."""

  def build(a, b):
    return FakeAtypicalNestingResults(
        unique_atypical_nesting=a, atypical_inner_results=FakeResults(b))

  results = build(0, 1)
  # Each field occurs exactly once in `build(...)`, so both directions are
  # expected to produce the same structure.
  for replace in (unnest.replace_innermost, unnest.replace_outermost):
    self.assertEqual(replace(results, unique_atypical_nesting=2), build(2, 1))
    self.assertEqual(replace(results, unique_core=2), build(0, 2))
    self.assertEqual(
        replace(results, unique_atypical_nesting=2, unique_core=3),
        build(2, 3))
    self.assertEqual(
        replace(results, return_unused=True, unique_atypical_nesting=2,
                unique_core=3, foo=4),
        (build(2, 3), {'foo': 4}))
  # Without `return_unused`, an unmatched kwarg raises ValueError.
  for replace in (unnest.replace_innermost, unnest.replace_outermost):
    with self.assertRaises(ValueError):
      replace(results, unique_core=1, foo=1)
def _update_field(kernel_results, field_name, value):
  """Sets the innermost field `field_name` of `kernel_results` to `value`.

  Args:
    kernel_results: A (possibly nested) kernel results structure.
    field_name: Name of the field to replace.
    value: New value for that field.

  Returns:
    The structure returned by `unnest.replace_innermost`.

  Raises:
    REMCFieldNotFoundError: If `unnest.replace_innermost` raises `ValueError`
      (presumably because no level of `kernel_results` has a field named
      `field_name`). The original `ValueError` is chained as the cause.
  """
  try:
    return unnest.replace_innermost(kernel_results, **{field_name: value})
  except ValueError as err:
    msg = _kernel_result_not_implemented_message_template.format(
        kernel_results, field_name)
    # Chain explicitly so the traceback shows the ValueError as the direct
    # cause rather than as an unrelated error raised during handling.
    raise REMCFieldNotFoundError(msg) from err
def test_deeply_nested_replace(self):
  """Checks innermost vs. outermost replacement on a deeply nested result.

  The expectations below assume `_build_deeply_nested(a, b, c, d, e)` places
  `a`/`d` as the outer/inner `unique_nesting` values, `b`/`c` as the
  outer/inner `unique_atypical_nesting` values, and `e` as `unique_core`.
  """
  results = _build_deeply_nested(0, 1, 2, 3, 4)
  # assertEqual, not assertTrue: assertTrue's second argument is only the
  # failure message, so `assertTrue(actual, expected)` compares nothing.
  self.assertEqual(
      unnest.replace_innermost(results, unique_nesting=5),
      _build_deeply_nested(0, 1, 2, 5, 4))
  self.assertEqual(
      unnest.replace_outermost(results, unique_nesting=5),
      _build_deeply_nested(5, 1, 2, 3, 4))
  self.assertEqual(
      unnest.replace_innermost(results, unique_atypical_nesting=5),
      _build_deeply_nested(0, 1, 5, 3, 4))
  self.assertEqual(
      unnest.replace_outermost(results, unique_atypical_nesting=5),
      _build_deeply_nested(0, 5, 2, 3, 4))
  # `unique_core` is expected to resolve to the same field in both directions.
  self.assertEqual(
      unnest.replace_innermost(results, unique_core=5),
      _build_deeply_nested(0, 1, 2, 3, 5))
  self.assertEqual(
      unnest.replace_outermost(results, unique_core=5),
      _build_deeply_nested(0, 1, 2, 3, 5))
  self.assertEqual(
      unnest.replace_innermost(
          results, unique_nesting=5, unique_atypical_nesting=6,
          unique_core=7),
      _build_deeply_nested(0, 1, 6, 5, 7))
  self.assertEqual(
      unnest.replace_innermost(
          results, return_unused=True, unique_nesting=5,
          unique_atypical_nesting=6, unique_core=7, foo=8),
      (_build_deeply_nested(0, 1, 6, 5, 7), {'foo': 8}))
  self.assertEqual(
      unnest.replace_outermost(
          results, unique_nesting=5, unique_atypical_nesting=6,
          unique_core=7),
      _build_deeply_nested(5, 6, 2, 3, 7))
  self.assertEqual(
      unnest.replace_outermost(
          results, return_unused=True, unique_nesting=5,
          unique_atypical_nesting=6, unique_core=7, foo=8),
      (_build_deeply_nested(5, 6, 2, 3, 7), {'foo': 8}))
  # Without `return_unused`, an unmatched kwarg raises ValueError.
  with self.assertRaises(ValueError):
    unnest.replace_innermost(results, unique_core=1, foo=1)
  with self.assertRaises(ValueError):
    unnest.replace_outermost(results, unique_core=1, foo=1)
def rwm_extra_setter_fn(
    kernel_results,
    num_steps,
    covariance_scaling,
    covariance,
    running_covariance,
    is_adaptive,
):
  """Replaces the innermost `extra` field with fresh `AdaptiveRWMResults`.

  Serves as the setter for the `extra` member of `MetropolisHastings`
  `TransitionKernel` results so that it can be adapted.

  Args:
    kernel_results: A (possibly nested) kernel results structure.
    num_steps: Forwarded to `AdaptiveRWMResults.num_steps`.
    covariance_scaling: Forwarded to `AdaptiveRWMResults.covariance_scaling`.
    covariance: Forwarded to `AdaptiveRWMResults.covariance`.
    running_covariance: Forwarded to `AdaptiveRWMResults.running_covariance`.
    is_adaptive: Forwarded to `AdaptiveRWMResults.is_adaptive`.

  Returns:
    The structure returned by `unnest.replace_innermost`.
  """
  new_extra = AdaptiveRWMResults(
      num_steps=num_steps,
      covariance_scaling=covariance_scaling,
      covariance=covariance,
      running_covariance=running_covariance,
      is_adaptive=is_adaptive,
  )
  return unnest.replace_innermost(kernel_results, extra=new_extra)
def update_target_log_prob(results, target_log_prob):
  """Stores `target_log_prob` in the innermost matching field of `results`.

  Args:
    results: A (possibly nested) kernel results structure.
    target_log_prob: Value for the innermost `target_log_prob` field.

  Returns:
    The structure returned by `unnest.replace_innermost`.
  """
  replacements = {'target_log_prob': target_log_prob}
  return unnest.replace_innermost(results, **replacements)
def hmc_like_step_size_setter_fn(kernel_results, new_step_size):
  """Adaptation setter: writes `new_step_size` to the innermost `step_size`.

  Args:
    kernel_results: A (possibly nested) kernel results structure.
    new_step_size: Value for the innermost `step_size` field.

  Returns:
    The structure returned by `unnest.replace_innermost`.
  """
  updated_results = unnest.replace_innermost(
      kernel_results, step_size=new_step_size)
  return updated_results
def hmc_like_num_leapfrog_steps_setter_fn(kernel_results,
                                          new_num_leapfrog_steps):
  """Adaptation setter for the innermost `num_leapfrog_steps` field.

  Args:
    kernel_results: A (possibly nested) kernel results structure.
    new_num_leapfrog_steps: Value for the innermost `num_leapfrog_steps`.

  Returns:
    The structure returned by `unnest.replace_innermost`.
  """
  replacements = {'num_leapfrog_steps': new_num_leapfrog_steps}
  return unnest.replace_innermost(kernel_results, **replacements)