コード例 #1
0
 def test_flat_replace(self):
     """Exercise both replace helpers on a single `FakeResults` value.

     The assertions expect that a matching keyword updates `unique_core`,
     that `return_unused=True` yields a `(results, leftover_kwargs)` pair,
     and that a non-matching keyword without `return_unused` raises
     `ValueError`.
     """
     flat = FakeResults(0)

     # Innermost replacement.
     inner = unnest.replace_innermost(flat, unique_core=1)
     self.assertEqual(inner.unique_core, 1)
     inner_with_unused = unnest.replace_innermost(
         flat, return_unused=True, unique_core=1, foo=2)
     self.assertEqual(inner_with_unused, (FakeResults(1), {'foo': 2}))

     # Outermost replacement.
     outer = unnest.replace_outermost(flat, unique_core=2)
     self.assertEqual(outer.unique_core, 2)
     outer_with_unused = unnest.replace_outermost(
         flat, return_unused=True, unique_core=2, foo=3)
     self.assertEqual(outer_with_unused, (FakeResults(2), {'foo': 3}))

     # A keyword that matches no field must be rejected when the caller
     # did not ask for the unused keywords back.
     with self.assertRaises(ValueError):
         unnest.replace_innermost(flat, unique_core=1, foo=1)
     with self.assertRaises(ValueError):
         unnest.replace_outermost(flat, unique_core=1, foo=1)
コード例 #2
0
    def test_atypical_nesting_replace(self):
        """Exercise both replace helpers on a two-level atypical nesting.

        The fixture puts `unique_atypical_nesting` on the outer object and a
        `FakeResults` on `atypical_inner_results`; the assertions expect
        `replace_innermost` and `replace_outermost` to produce the same
        results for every case checked here.
        """
        def make(outer_value, core_value):
            return FakeAtypicalNestingResults(
                unique_atypical_nesting=outer_value,
                atypical_inner_results=FakeResults(core_value))

        nested = make(0, 1)
        replacers = (unnest.replace_innermost, unnest.replace_outermost)

        for replace in replacers:
            # One field at a time, then both together.
            self.assertEqual(
                replace(nested, unique_atypical_nesting=2), make(2, 1))
            self.assertEqual(replace(nested, unique_core=2), make(0, 2))
            self.assertEqual(
                replace(nested, unique_atypical_nesting=2, unique_core=3),
                make(2, 3))
            # With return_unused=True the leftover keyword comes back in a
            # dict next to the updated results.
            with_unused = replace(
                nested,
                return_unused=True,
                unique_atypical_nesting=2,
                unique_core=3,
                foo=4)
            self.assertEqual(with_unused, (make(2, 3), {'foo': 4}))

        for replace in replacers:
            # Without return_unused, a keyword matching no field is an error.
            with self.assertRaises(ValueError):
                replace(nested, unique_core=1, foo=1)
コード例 #3
0
def _update_field(kernel_results, field_name, value):
    """Return `kernel_results` with `field_name` replaced by `value`.

    The replacement is delegated to `unnest.replace_innermost`.

    Args:
        kernel_results: Results structure handed to
            `unnest.replace_innermost`.
        field_name: Name of the field to replace, passed as the keyword.
        value: New value for that field.

    Returns:
        Whatever `unnest.replace_innermost` returns for the given keyword.

    Raises:
        REMCFieldNotFoundError: If `unnest.replace_innermost` raises
            `ValueError`. The original error is attached as `__cause__`.
    """
    try:
        return unnest.replace_innermost(kernel_results, **{field_name: value})
    except ValueError as err:
        msg = _kernel_result_not_implemented_message_template.format(
            kernel_results, field_name)
        # Chain explicitly so the traceback shows the ValueError as the
        # direct cause rather than an incidental error during handling.
        raise REMCFieldNotFoundError(msg) from err
コード例 #4
0
 def test_deeply_nested_replace(self):
     """Exercise both replace helpers on a `_build_deeply_nested` fixture.

     The comparisons use `assertEqual`. They previously used
     `assertTrue(actual, expected)`, which treats the second argument as
     the failure message and therefore passed for any truthy `actual`
     without ever comparing it to the expected structure.
     """
     results = _build_deeply_nested(0, 1, 2, 3, 4)

     # Single-field replacements.
     self.assertEqual(unnest.replace_innermost(results, unique_nesting=5),
                      _build_deeply_nested(0, 1, 2, 5, 4))
     self.assertEqual(unnest.replace_outermost(results, unique_nesting=5),
                      _build_deeply_nested(5, 1, 2, 3, 4))
     self.assertEqual(
         unnest.replace_innermost(results, unique_atypical_nesting=5),
         _build_deeply_nested(0, 1, 5, 3, 4))
     self.assertEqual(
         unnest.replace_outermost(results, unique_atypical_nesting=5),
         _build_deeply_nested(0, 5, 2, 3, 4))
     self.assertEqual(unnest.replace_innermost(results, unique_core=5),
                      _build_deeply_nested(0, 1, 2, 3, 5))
     self.assertEqual(unnest.replace_outermost(results, unique_core=5),
                      _build_deeply_nested(0, 1, 2, 3, 5))

     # Several fields at once, with and without return_unused.
     self.assertEqual(
         unnest.replace_innermost(results,
                                  unique_nesting=5,
                                  unique_atypical_nesting=6,
                                  unique_core=7),
         _build_deeply_nested(0, 1, 6, 5, 7))
     self.assertEqual(
         unnest.replace_innermost(results,
                                  return_unused=True,
                                  unique_nesting=5,
                                  unique_atypical_nesting=6,
                                  unique_core=7,
                                  foo=8),
         (_build_deeply_nested(0, 1, 6, 5, 7), {
             'foo': 8
         }))
     self.assertEqual(
         unnest.replace_outermost(results,
                                  unique_nesting=5,
                                  unique_atypical_nesting=6,
                                  unique_core=7),
         _build_deeply_nested(5, 6, 2, 3, 7))
     self.assertEqual(
         unnest.replace_outermost(results,
                                  return_unused=True,
                                  unique_nesting=5,
                                  unique_atypical_nesting=6,
                                  unique_core=7,
                                  foo=8),
         (_build_deeply_nested(5, 6, 2, 3, 7), {
             'foo': 8
         }))

     # Without return_unused, a keyword matching no field is an error.
     self.assertRaises(
         ValueError,
         lambda: unnest.replace_innermost(  # pylint: disable=g-long-lambda
             results,
             unique_core=1,
             foo=1))
     self.assertRaises(
         ValueError,
         lambda: unnest.replace_outermost(  # pylint: disable=g-long-lambda
             results,
             unique_core=1,
             foo=1))
コード例 #5
0
def rwm_extra_setter_fn(
    kernel_results,
    num_steps,
    covariance_scaling,
    covariance,
    running_covariance,
    is_adaptive,
):
    """Set the `extra` field of `kernel_results` for adaptation.

    Packs the five adaptation values into an `AdaptiveRWMResults` and passes
    it to `unnest.replace_innermost` as the new `extra` value.
    """
    adaptive_state = AdaptiveRWMResults(
        num_steps=num_steps,
        covariance_scaling=covariance_scaling,
        covariance=covariance,
        running_covariance=running_covariance,
        is_adaptive=is_adaptive,
    )
    return unnest.replace_innermost(kernel_results, extra=adaptive_state)
コード例 #6
0
ファイル: gibbs_kernel.py プロジェクト: chrism0dwk/gemlib_tfp
def update_target_log_prob(results, target_log_prob):
    """Return `results` with `target_log_prob` set via `replace_innermost`."""
    updated = unnest.replace_innermost(
        results,
        target_log_prob=target_log_prob,
    )
    return updated
コード例 #7
0
def hmc_like_step_size_setter_fn(kernel_results, new_step_size):
    """Write `new_step_size` into the `step_size` field for adaptation."""
    updated = unnest.replace_innermost(
        kernel_results,
        step_size=new_step_size,
    )
    return updated
def hmc_like_num_leapfrog_steps_setter_fn(kernel_results,
                                          new_num_leapfrog_steps):
  """Write the new value into `num_leapfrog_steps` for adaptation."""
  updated = unnest.replace_innermost(
      kernel_results,
      num_leapfrog_steps=new_num_leapfrog_steps,
  )
  return updated