def testNoWarn(self, tensor):
    """Verifies that a dict of plain tensors produces no simple-tensor warning."""
    warnings.simplefilter('always')
    with warnings.catch_warnings(record=True) as triggered:
        util.warn_if_parameters_are_not_simple_tensors({'a': tensor})
    # Collect only the warnings emitted by the simple-tensor check; the
    # assertion passes iff that collection is empty.
    offending = [
        warning for warning in triggered
        if 'Please consult the docstring' in str(warning.message)
    ]
    self.assertFalse(offending)
def __init__(self,
             target_log_prob_fn,
             step_size,
             num_leapfrog_steps,
             state_gradients_are_stopped=False,
             seed=None,
             store_parameters_in_results=False,
             name=None):
    """Initializes this transition kernel.

    Args:
      target_log_prob_fn: Python callable which takes an argument like
        `current_state` (or `*current_state` if it's a list) and returns its
        (possibly unnormalized) log-density under the target distribution.
      step_size: `Tensor` or Python `list` of `Tensor`s representing the step
        size for the leapfrog integrator. Must broadcast with the shape of
        `current_state`. Larger step sizes lead to faster progress, but
        too-large step sizes make rejection exponentially more likely. When
        possible, it's often helpful to match per-variable step sizes to the
        standard deviations of the target distribution in each variable.
      num_leapfrog_steps: Integer number of steps to run the leapfrog
        integrator for. Total progress per HMC step is roughly proportional to
        `step_size * num_leapfrog_steps`.
      state_gradients_are_stopped: Python `bool` indicating that the proposed
        new state be run through `tf.stop_gradient`. This is particularly
        useful when combining optimization over samples from the HMC chain.
        Default value: `False` (i.e., do not apply `stop_gradient`).
      seed: Python integer to seed the random number generator. Deprecated,
        pass seed to `tfp.mcmc.sample_chain`.
      store_parameters_in_results: If `True`, then `step_size` and
        `num_leapfrog_steps` are written to and read from eponymous fields in
        the kernel results objects returned from `one_step` and
        `bootstrap_results`. This allows wrapper kernels to adjust those
        parameters on the fly.
      name: Python `str` name prefixed to Ops created by this function.
        Default value: `None` (i.e., 'hmc_kernel').

    Raises:
      NotImplementedError: If `seed` is given while executing eagerly.
    """
    # An explicit eager-mode seed is unsupported; callers must thread the
    # seed through `tfp.mcmc.sample_chain` instead.
    if seed is not None and tf.executing_eagerly():
        # TODO(b/68017812): Re-enable once TFE supports `tf.random.shuffle`
        # seed.
        raise NotImplementedError(
            'Specifying a `seed` when running eagerly is not currently '
            'supported. To run in Eager mode with a seed, pass the seed to '
            '`tfp.mcmc.sample_chain`.')
    if not store_parameters_in_results:
        # When the kernel-results objects won't carry these parameters, warn
        # about non-Tensor values that can't be updated on the fly.
        mcmc_util.warn_if_parameters_are_not_simple_tensors({
            'step_size': step_size,
            'num_leapfrog_steps': num_leapfrog_steps,
        })
    self._seed_stream = SeedStream(seed, salt='uncalibrated_hmc_one_step')
    self._parameters = {
        'target_log_prob_fn': target_log_prob_fn,
        'step_size': step_size,
        'num_leapfrog_steps': num_leapfrog_steps,
        'state_gradients_are_stopped': state_gradients_are_stopped,
        'seed': seed,
        'name': name,
        'store_parameters_in_results': store_parameters_in_results,
    }
    self._momentum_dtype = None
def __init__(self,
             target_log_prob_fn,
             step_size,
             num_leapfrog_steps,
             state_gradients_are_stopped=False,
             store_parameters_in_results=False,
             experimental_shard_axis_names=None,
             name=None):
    """Initializes this transition kernel.

    Args:
      target_log_prob_fn: Python callable which takes an argument like
        `current_state` (or `*current_state` if it's a list) and returns its
        (possibly unnormalized) log-density under the target distribution.
      step_size: `Tensor` or Python `list` of `Tensor`s representing the step
        size for the leapfrog integrator. Must broadcast with the shape of
        `current_state`. Larger step sizes lead to faster progress, but
        too-large step sizes make rejection exponentially more likely. When
        possible, it's often helpful to match per-variable step sizes to the
        standard deviations of the target distribution in each variable.
      num_leapfrog_steps: Integer number of steps to run the leapfrog
        integrator for. Total progress per HMC step is roughly proportional to
        `step_size * num_leapfrog_steps`.
      state_gradients_are_stopped: Python `bool` indicating that the proposed
        new state be run through `tf.stop_gradient`. This is particularly
        useful when combining optimization over samples from the HMC chain.
        Default value: `False` (i.e., do not apply `stop_gradient`).
      store_parameters_in_results: If `True`, then `step_size` and
        `num_leapfrog_steps` are written to and read from eponymous fields in
        the kernel results objects returned from `one_step` and
        `bootstrap_results`. This allows wrapper kernels to adjust those
        parameters on the fly.
      experimental_shard_axis_names: A structure of string names indicating
        how members of the state are sharded.
      name: Python `str` name prefixed to Ops created by this function.
        Default value: `None` (i.e., 'hmc_kernel').
    """
    if not store_parameters_in_results:
        # When the kernel-results objects won't carry these parameters, warn
        # about non-Tensor values that can't be updated on the fly.
        mcmc_util.warn_if_parameters_are_not_simple_tensors({
            'step_size': step_size,
            'num_leapfrog_steps': num_leapfrog_steps,
        })
    self._parameters = {
        'target_log_prob_fn': target_log_prob_fn,
        'step_size': step_size,
        'num_leapfrog_steps': num_leapfrog_steps,
        'state_gradients_are_stopped': state_gradients_are_stopped,
        'name': name,
        'experimental_shard_axis_names': experimental_shard_axis_names,
        'store_parameters_in_results': store_parameters_in_results,
    }
    self._momentum_dtype = None