Example #1
0
 def test_atypical_nesting(self):
     """Atypically nested results expose both wrapper and inner core fields."""
     results = FakeAtypicalNestingResults(
         unique_atypical_nesting=0,
         atypical_inner_results=FakeResults(1))

     # Both the wrapper's own field and the wrapped core field are reachable.
     for present_name in ('unique_atypical_nesting', 'unique_core'):
         self.assertTrue(unnest.has_nested(results, present_name))
     # The core field is not a direct attribute of the wrapper.
     self.assertFalse(hasattr(results, 'unique_core'))
     self.assertFalse(unnest.has_nested(results, 'foo'))

     getters = (unnest.get_innermost, unnest.get_outermost)
     expectations = (
         ('unique_atypical_nesting', results.unique_atypical_nesting),
         ('unique_core', results.atypical_inner_results.unique_core),
     )
     # Each field occurs at a single depth here, so both lookup directions
     # must find the same value.
     for field_name, expected in expectations:
         for getter in getters:
             self.assertEqual(expected, getter(results, field_name))

     # Unknown fields raise without a default and return it otherwise.
     for getter in getters:
         self.assertRaises(AttributeError, getter, results, 'foo')
     for getter in getters:
         self.assertIs(getter(results, 'foo', SINGLETON), SINGLETON)
def hmc_like_proposed_velocity_getter_fn(kernel_results):
  """Return the proposed velocity, packed like the proposed state.

  Reads the innermost `final_momentum` and `proposed_state` fields of
  `kernel_results`, then re-packs the flattened momentum leaves into the nest
  structure of `proposed_state`.
  """
  # TODO(b/169898033): This only works with the standard kinetic energy term.
  momentum = unnest.get_innermost(kernel_results, 'final_momentum')
  state_template = unnest.get_innermost(kernel_results, 'proposed_state')
  # The momentum's nest structure can be wrong when the state is a scalar, so
  # borrow the structure of the proposed state instead.
  momentum_leaves = tf.nest.flatten(momentum)
  return tf.nest.pack_sequence_as(state_template, momentum_leaves)
Example #3
0
def _get_field(kernel_results, field_name):
    """Return the innermost `field_name` found in `kernel_results`.

    Args:
      kernel_results: (Possibly nested) kernel results structure to search.
      field_name: Name of the field to look up.

    Returns:
      The value returned by `unnest.get_innermost` for `field_name`.

    Raises:
      REMCFieldNotFoundError: if `unnest.get_innermost` raises
        `AttributeError`, i.e. the field could not be found.
    """
    try:
        return unnest.get_innermost(kernel_results, field_name)
    except AttributeError as err:
        msg = _kernel_result_not_implemented_message_template.format(
            kernel_results, field_name)
        # Chain explicitly so the traceback shows the failed lookup as the
        # direct cause, not as an error raised "during handling" of another.
        raise REMCFieldNotFoundError(msg) from err
Example #4
0
 def test_flat(self):
     """On un-nested results, innermost and outermost lookups agree."""
     results = FakeResults(0)
     self.assertTrue(unnest.has_nested(results, 'unique_core'))
     self.assertFalse(unnest.has_nested(results, 'foo'))

     getters = (unnest.get_innermost, unnest.get_outermost)
     core_value = results.unique_core
     for getter in getters:
         self.assertEqual(core_value, getter(results, 'unique_core'))

     # Unknown fields raise without a default and return it otherwise.
     for getter in getters:
         self.assertRaises(AttributeError, getter, results, 'foo')
     for getter in getters:
         self.assertIs(getter(results, 'foo', SINGLETON), SINGLETON)
Example #5
0
def rwm_log_accept_prob_getter_fn(kernel_results):
    """Getter for the log acceptance probability of `MetropolisHastings` results.

    Reads the innermost `log_accept_ratio` from `kernel_results`, replaces
    non-finite entries with -inf, and caps the result at 0.
    """
    log_ratio = unnest.get_innermost(kernel_results, "log_accept_ratio")
    keep = tf.math.is_finite(log_ratio)
    fallback = tf.constant(-np.inf, dtype=log_ratio.dtype)
    cleaned = tf.where(keep, log_ratio, fallback)
    # min(log r, 0) == log(min(r, 1)), the usual MH acceptance probability.
    return tf.minimum(cleaned, 0.0)
Example #6
0
 def test_deeply_nested(self):
     """Lookups through a mixed typical/atypical nesting chain.

     `unique_nesting` and `unique_atypical_nesting` resolve to different
     values for `get_innermost` and `get_outermost`, while `unique_core`
     resolves to the same value either way.
     """
     results = _build_deeply_nested(0, 1, 2, 3, 4)
     self.assertTrue(unnest.has_nested(results, 'unique_nesting'))
     self.assertTrue(unnest.has_nested(results, 'unique_atypical_nesting'))
     self.assertTrue(unnest.has_nested(results, 'unique_core'))
     # `unique_core` is reachable only through nesting, not as a direct
     # attribute of the outermost results. This must inspect `results`; a check
     # against `self` (the TestCase) would pass trivially.
     self.assertFalse(hasattr(results, 'unique_core'))
     self.assertFalse(unnest.has_nested(results, 'foo'))
     self.assertEqual(unnest.get_innermost(results, 'unique_nesting'), 3)
     self.assertEqual(unnest.get_outermost(results, 'unique_nesting'), 0)
     self.assertEqual(
         unnest.get_innermost(results, 'unique_atypical_nesting'), 2)
     self.assertEqual(
         unnest.get_outermost(results, 'unique_atypical_nesting'), 1)
     self.assertEqual(unnest.get_innermost(results, 'unique_core'), 4)
     self.assertEqual(unnest.get_outermost(results, 'unique_core'), 4)
     self.assertRaises(AttributeError,
                       lambda: unnest.get_innermost(results, 'foo'))
     self.assertRaises(AttributeError,
                       lambda: unnest.get_outermost(results, 'foo'))
     self.assertIs(unnest.get_innermost(results, 'foo', SINGLETON),
                   SINGLETON)
     self.assertIs(unnest.get_outermost(results, 'foo', SINGLETON),
                   SINGLETON)
Example #7
0
def rwm_extra_getter_fn(kernel_results):
    """Return the innermost `extra` field of `kernel_results`.

    Exposes the `extra` member of a `MetropolisHastings` `TransitionKernel`'s
    results so that it can be inspected.
    """
    extra = unnest.get_innermost(kernel_results, "extra")
    return extra
Example #8
0
def get_target_log_prob(results):
    """Return the innermost `target_log_prob` field of `results`.

    Args:
      results: (Possibly nested) results structure to search.
    """
    target_log_prob = unnest.get_innermost(results, "target_log_prob")
    return target_log_prob
Example #9
0
def hmc_like_log_accept_prob_getter_fn(kernel_results):
    """Getter for the log acceptance probability of HMC-like kernel results.

    Reads the innermost `log_accept_ratio` from `kernel_results`. Entries that
    are not finite (NaN or +/-inf) become -inf, and the result is capped at 0.
    """
    log_ratio = unnest.get_innermost(kernel_results, 'log_accept_ratio')
    finite = tf.math.is_finite(log_ratio)
    minus_inf = tf.constant(-np.inf, dtype=log_ratio.dtype)
    return tf.minimum(tf.where(finite, log_ratio, minus_inf), 0.)
Example #10
0
def hmc_like_step_size_getter_fn(kernel_results):
    """Return the innermost `step_size` field of `kernel_results`."""
    step_size = unnest.get_innermost(kernel_results, 'step_size')
    return step_size
def hmc_like_num_leapfrog_steps_getter_fn(kernel_results):
  """Return the innermost `num_leapfrog_steps` field of `kernel_results`."""
  num_steps = unnest.get_innermost(kernel_results, 'num_leapfrog_steps')
  return num_steps
def hmc_like_proposed_state_getter_fn(kernel_results):
  """Return the innermost `proposed_state` field of `kernel_results`."""
  proposal = unnest.get_innermost(kernel_results, 'proposed_state')
  return proposal