예제 #1
0
 def testNoWarn(self, tensor):
     """Asserts that checking `tensor` emits no docstring-pointing warning.

     Calls `util.warn_if_parameters_are_not_simple_tensors` with the mapping
     `{'a': tensor}` while recording warnings, then fails if any recorded
     warning message contains the text 'Please consult the docstring'.

     Args:
       tensor: The value passed to the checker under the key `'a'`.
     """
     with warnings.catch_warnings(record=True) as triggered:
         # Install the 'always' filter *inside* the context manager: the filter
         # list is saved on entry and restored on exit, so the change no longer
         # leaks into tests that run after this one.
         warnings.simplefilter('always')
         util.warn_if_parameters_are_not_simple_tensors({'a': tensor})
     self.assertFalse(
         any('Please consult the docstring' in str(warning.message)
             for warning in triggered))
예제 #2
0
    def __init__(self,
                 target_log_prob_fn,
                 step_size,
                 num_leapfrog_steps,
                 state_gradients_are_stopped=False,
                 seed=None,
                 store_parameters_in_results=False,
                 name=None):
        """Sets up the transition kernel and records its constructor arguments.

        Args:
          target_log_prob_fn: Python callable accepting an argument shaped like
            `current_state` (or `*current_state` when the state is a list) and
            returning the (possibly unnormalized) log-density of the target
            distribution.
          step_size: `Tensor` or Python `list` of `Tensor`s giving the leapfrog
            integrator step size. Must broadcast with the shape of
            `current_state`. Bigger steps make faster progress, but steps that
            are too big make rejection exponentially more likely. Where
            possible, matching per-variable step sizes to the standard
            deviations of the target distribution in each variable often helps.
          num_leapfrog_steps: Integer count of leapfrog integrator steps. Total
            progress per HMC step is roughly proportional to
            `step_size * num_leapfrog_steps`.
          state_gradients_are_stopped: Python `bool`; when set, the proposed new
            state is run through `tf.stop_gradient`. Particularly useful when
            combining optimization over samples from the HMC chain.
            Default value: `False` (i.e., do not apply `stop_gradient`).
          seed: Python integer seeding the random number generator. Deprecated;
            pass the seed to `tfp.mcmc.sample_chain` instead.
          store_parameters_in_results: If `True`, `step_size` and
            `num_leapfrog_steps` are written to and read from eponymous fields
            of the kernel results returned by `one_step` and
            `bootstrap_results`, which lets wrapper kernels adjust them on the
            fly. When `False`, both values are passed to
            `mcmc_util.warn_if_parameters_are_not_simple_tensors`.
          name: Python `str` name prefixed to Ops created by this function.
            Default value: `None` (i.e., 'hmc_kernel').

        Raises:
          NotImplementedError: If `seed` is not `None` while TensorFlow is
            executing eagerly.
        """
        seed_was_given = seed is not None
        if seed_was_given and tf.executing_eagerly():
            # TODO(b/68017812): Re-enable once TFE supports `tf.random.shuffle` seed.
            raise NotImplementedError(
                'Specifying a `seed` when running eagerly is not currently '
                'supported. To run in Eager mode with a seed, pass the seed to '
                '`tfp.mcmc.sample_chain`.')

        # Only the parameters that wrapper kernels may want to adjust are checked.
        adjustable_parameters = {
            'step_size': step_size,
            'num_leapfrog_steps': num_leapfrog_steps,
        }
        if not store_parameters_in_results:
            mcmc_util.warn_if_parameters_are_not_simple_tensors(
                adjustable_parameters)

        self._seed_stream = SeedStream(seed, salt='uncalibrated_hmc_one_step')
        self._parameters = {
            'target_log_prob_fn': target_log_prob_fn,
            'step_size': step_size,
            'num_leapfrog_steps': num_leapfrog_steps,
            'state_gradients_are_stopped': state_gradients_are_stopped,
            'seed': seed,
            'name': name,
            'store_parameters_in_results': store_parameters_in_results,
        }
        self._momentum_dtype = None
예제 #3
0
    def __init__(self,
                 target_log_prob_fn,
                 step_size,
                 num_leapfrog_steps,
                 state_gradients_are_stopped=False,
                 store_parameters_in_results=False,
                 experimental_shard_axis_names=None,
                 name=None):
        """Sets up the transition kernel and records its constructor arguments.

        Args:
          target_log_prob_fn: Python callable accepting an argument shaped like
            `current_state` (or `*current_state` when the state is a list) and
            returning the (possibly unnormalized) log-density of the target
            distribution.
          step_size: `Tensor` or Python `list` of `Tensor`s giving the leapfrog
            integrator step size. Must broadcast with the shape of
            `current_state`. Bigger steps make faster progress, but steps that
            are too big make rejection exponentially more likely. Where
            possible, matching per-variable step sizes to the standard
            deviations of the target distribution in each variable often helps.
          num_leapfrog_steps: Integer count of leapfrog integrator steps. Total
            progress per HMC step is roughly proportional to
            `step_size * num_leapfrog_steps`.
          state_gradients_are_stopped: Python `bool`; when set, the proposed new
            state is run through `tf.stop_gradient`. Particularly useful when
            combining optimization over samples from the HMC chain.
            Default value: `False` (i.e., do not apply `stop_gradient`).
          store_parameters_in_results: If `True`, `step_size` and
            `num_leapfrog_steps` are written to and read from eponymous fields
            of the kernel results returned by `one_step` and
            `bootstrap_results`, which lets wrapper kernels adjust them on the
            fly. When `False`, both values are passed to
            `mcmc_util.warn_if_parameters_are_not_simple_tensors`.
          experimental_shard_axis_names: A structure of string names indicating
            how members of the state are sharded.
          name: Python `str` name prefixed to Ops created by this function.
            Default value: `None` (i.e., 'hmc_kernel').
        """
        # Only the parameters that wrapper kernels may want to adjust are checked.
        adjustable_parameters = {
            'step_size': step_size,
            'num_leapfrog_steps': num_leapfrog_steps,
        }
        if not store_parameters_in_results:
            mcmc_util.warn_if_parameters_are_not_simple_tensors(
                adjustable_parameters)

        self._parameters = {
            'target_log_prob_fn': target_log_prob_fn,
            'step_size': step_size,
            'num_leapfrog_steps': num_leapfrog_steps,
            'state_gradients_are_stopped': state_gradients_are_stopped,
            'name': name,
            'experimental_shard_axis_names': experimental_shard_axis_names,
            'store_parameters_in_results': store_parameters_in_results,
        }
        self._momentum_dtype = None