def __init__(self, target_log_prob_fn, new_state_fn=None, name=None):
    """Initializes this transition kernel.

    Wraps an `UncalibratedRandomWalk` kernel in
    `metropolis_hastings.MetropolisHastings` and stores the result on
    `self._impl`.

    Args:
      target_log_prob_fn: Python callable which takes an argument like
        `current_state` (or `*current_state` if it's a list) and returns its
        (possibly unnormalized) log-density under the target distribution.
      new_state_fn: Python callable which takes a list of state parts and a
        seed; returns a same-type `list` of `Tensor`s, each being a
        perturbation of the input state parts. The perturbation distribution
        is assumed to be a symmetric distribution centered at the input
        state part.
        Default value: `None`, which is replaced by
        `random_walk_normal_fn()`.
      name: Python `str` name prefixed to Ops created by this function. It is
        forwarded unchanged to the wrapped kernel.
        Default value: `None` (documented as 'rwm_kernel'; that resolution
        is not performed in this constructor).

    Raises:
      ValueError: if there isn't one `scale` or a list with same length as
        `current_state`. NOTE(review): this check is not visible in the
        constructor itself; presumably it is raised by the proposal function
        when the kernel is applied -- confirm.
    """
    proposal_fn = new_state_fn
    if proposal_fn is None:
        proposal_fn = random_walk_normal_fn()

    uncalibrated_kernel = UncalibratedRandomWalk(
        target_log_prob_fn=target_log_prob_fn,
        new_state_fn=proposal_fn,
        name=name)
    self._impl = metropolis_hastings.MetropolisHastings(
        inner_kernel=uncalibrated_kernel)
def __init__(self,
             target_log_prob_fn,
             step_size,
             num_leapfrog_steps,
             momentum_distribution=None,
             state_gradients_are_stopped=False,
             step_size_update_fn=None,
             store_parameters_in_results=False,
             name=None):
    """Initializes this transition kernel.

    Wraps an `UncalibratedPreconditionedHamiltonianMonteCarlo` kernel in
    `metropolis_hastings.MetropolisHastings` and records the constructor
    parameters on `self._parameters`.

    Args:
      target_log_prob_fn: Python callable which takes an argument like
        `current_state` (or `*current_state` if it's a list) and returns its
        (possibly unnormalized) log-density under the target distribution.
      step_size: `Tensor` or Python `list` of `Tensor`s representing the step
        size for the leapfrog integrator. Must broadcast with the shape of
        `current_state`. Larger step sizes lead to faster progress, but
        too-large step sizes make rejection exponentially more likely. When
        possible, it's often helpful to match per-variable step sizes to the
        standard deviations of the target distribution in each variable.
      num_leapfrog_steps: Integer number of steps to run the leapfrog
        integrator for. Total progress per HMC step is roughly proportional
        to `step_size * num_leapfrog_steps`.
      momentum_distribution: A `tfp.distributions.Distribution` instance to
        draw momentum from. Defaults to isotropic normal distributions.
      state_gradients_are_stopped: Python `bool` indicating that the proposed
        new state be run through `tf.stop_gradient`. This is particularly
        useful when combining optimization over samples from the HMC chain.
        Default value: `False` (i.e., do not apply `stop_gradient`).
      step_size_update_fn: Python `callable` taking current `step_size`
        (typically a `tf.Variable`) and `kernel_results` (typically
        `collections.namedtuple`) and returns updated step_size (`Tensor`s).
        Default value: `None` (i.e., do not update `step_size`
        automatically).
      store_parameters_in_results: If `True`, then `step_size` and
        `num_leapfrog_steps` are written to and read from eponymous fields in
        the kernel results objects returned from `one_step` and
        `bootstrap_results`. This allows wrapper kernels to adjust those
        parameters on the fly. This is incompatible with
        `step_size_update_fn`, which must be set to `None`.
      name: Python `str` name prefixed to Ops created by this function.
        Default value: `None` (i.e., 'hmc_kernel').

    Raises:
      ValueError: if `step_size_update_fn` is truthy while
        `store_parameters_in_results` is also truthy.
    """
    if step_size_update_fn and store_parameters_in_results:
        message = ('It is invalid to simultaneously specify '
                   '`step_size_update_fn` and set '
                   '`store_parameters_in_results` to `True`.')
        raise ValueError(message)

    uncalibrated_kernel = UncalibratedPreconditionedHamiltonianMonteCarlo(
        target_log_prob_fn=target_log_prob_fn,
        step_size=step_size,
        num_leapfrog_steps=num_leapfrog_steps,
        state_gradients_are_stopped=state_gradients_are_stopped,
        momentum_distribution=momentum_distribution,
        name=name or 'hmc_kernel',
        store_parameters_in_results=store_parameters_in_results)
    self._impl = metropolis_hastings.MetropolisHastings(
        inner_kernel=uncalibrated_kernel)

    # Start from the wrapped kernel's own parameters, then adjust them so they
    # mirror this constructor's signature.
    kernel_parameters = self._impl.inner_kernel.parameters.copy()
    kernel_parameters.pop('seed', None)  # TODO(b/159636942): Remove this line.
    kernel_parameters['step_size_update_fn'] = step_size_update_fn
    self._parameters = kernel_parameters
def __init__(self,
             target_log_prob_fn,
             step_size,
             volatility_fn=None,
             parallel_iterations=10,
             experimental_shard_axis_names=None,
             name=None):
    """Initializes MALA transition kernel.

    Wraps an `UncalibratedLangevin` kernel in
    `metropolis_hastings.MetropolisHastings`, applies
    `experimental_with_shard_axes`, and records the resulting inner kernel's
    parameters (minus `compute_acceptance`) on `self._parameters`.

    Args:
      target_log_prob_fn: Python callable which takes an argument like
        `current_state` (or `*current_state` if it's a list) and returns its
        (possibly unnormalized) log-density under the target distribution.
      step_size: `Tensor` or Python `list` of `Tensor`s representing the step
        size passed to `UncalibratedLangevin`. Must broadcast with the shape
        of `current_state`. Larger step sizes lead to faster progress, but
        too-large step sizes make rejection exponentially more likely. When
        possible, it's often helpful to match per-variable step sizes to the
        standard deviations of the target distribution in each variable.
      volatility_fn: Python callable which takes an argument like
        `current_state` (or `*current_state` if it's a list) and returns
        volatility value at `current_state`. Should return a `Tensor` or
        Python `list` of `Tensor`s that must broadcast with the shape of
        `current_state`. Defaults to the identity function.
      parallel_iterations: the number of coordinates for which the gradients
        of the volatility matrix `volatility_fn` can be computed in parallel.
        Default value: 10.
      experimental_shard_axis_names: A structure of string names indicating
        how members of the state are sharded.
      name: Python `str` name prefixed to Ops created by this function. It is
        forwarded unchanged to the wrapped kernel.
        Default value: `None` (documented as 'mala_kernel'; that resolution
        is not performed in this constructor).

    Raises:
      ValueError: if there isn't one `step_size` or a list with same length
        as `current_state`.
      TypeError: if `volatility_fn` is not callable.
      NOTE(review): neither check is visible in this constructor; presumably
      they come from the wrapped kernel -- confirm.
    """
    uncalibrated_kernel = UncalibratedLangevin(
        target_log_prob_fn=target_log_prob_fn,
        step_size=step_size,
        volatility_fn=volatility_fn,
        parallel_iterations=parallel_iterations,
        name=name)
    corrected_kernel = metropolis_hastings.MetropolisHastings(
        inner_kernel=uncalibrated_kernel)
    self._impl = corrected_kernel.experimental_with_shard_axes(
        experimental_shard_axis_names)

    # Read the parameters back from the sharded kernel rather than from
    # `uncalibrated_kernel`, exactly as the stored `_impl` exposes them.
    kernel_parameters = self._impl.inner_kernel.parameters.copy()
    # `compute_acceptance` is not a MALA kernel `__init__` parameter, so it is
    # dropped; a missing key raises `KeyError`.
    kernel_parameters.pop('compute_acceptance')
    self._parameters = kernel_parameters
def __init__(self,
             target_log_prob_fn,
             step_size,
             num_leapfrog_steps,
             state_gradients_are_stopped=False,
             step_size_update_fn=None,
             seed=None,
             name=None):
    """Initializes this transition kernel.

    Wraps an `UncalibratedHamiltonianMonteCarlo` kernel in
    `metropolis_hastings.MetropolisHastings` and records the inner kernel's
    parameters, plus `step_size_update_fn`, on `self._parameters`.

    Args:
      target_log_prob_fn: Python callable which takes an argument like
        `current_state` (or `*current_state` if it's a list) and returns its
        (possibly unnormalized) log-density under the target distribution.
      step_size: `Tensor` or Python `list` of `Tensor`s representing the step
        size for the leapfrog integrator. Must broadcast with the shape of
        `current_state`. Larger step sizes lead to faster progress, but
        too-large step sizes make rejection exponentially more likely. When
        possible, it's often helpful to match per-variable step sizes to the
        standard deviations of the target distribution in each variable.
      num_leapfrog_steps: Integer number of steps to run the leapfrog
        integrator for. Total progress per HMC step is roughly proportional
        to `step_size * num_leapfrog_steps`.
      state_gradients_are_stopped: Python `bool` indicating that the proposed
        new state be run through `tf.stop_gradient`. This is particularly
        useful when combining optimization over samples from the HMC chain.
        Default value: `False` (i.e., do not apply `stop_gradient`).
      step_size_update_fn: Python `callable` taking current `step_size`
        (typically a `tf.Variable`) and `kernel_results` (typically
        `collections.namedtuple`) and returns updated step_size (`Tensor`s).
        Default value: `None` (i.e., do not update `step_size`
        automatically).
      seed: Python integer to seed the random number generator. It is passed
        to both the inner kernel and the `MetropolisHastings` wrapper.
      name: Python `str` name prefixed to Ops created by this function.
        Default value: `None` (i.e., 'hmc_kernel').

    Raises:
      ValueError: if there isn't one `step_size` or a list with same length
        as `current_state`. NOTE(review): this check is not visible in the
        constructor itself; presumably it comes from the wrapped kernel --
        confirm.
    """
    kernel_name = name if name is not None else 'hmc_kernel'
    uncalibrated_kernel = UncalibratedHamiltonianMonteCarlo(
        target_log_prob_fn=target_log_prob_fn,
        step_size=step_size,
        num_leapfrog_steps=num_leapfrog_steps,
        state_gradients_are_stopped=state_gradients_are_stopped,
        seed=seed,
        name=kernel_name)
    corrected_kernel = metropolis_hastings.MetropolisHastings(
        inner_kernel=uncalibrated_kernel, seed=seed)

    # `step_size_update_fn` is not an inner-kernel argument, so it is added to
    # the copied parameter dict by hand.
    kernel_parameters = corrected_kernel.inner_kernel.parameters.copy()
    kernel_parameters['step_size_update_fn'] = step_size_update_fn

    self._impl = corrected_kernel
    self._parameters = kernel_parameters
def __init__(self,
             target_log_prob_fn,
             step_size,
             num_leapfrog_steps,
             seed=None,
             name=None):
    """Initializes this transition kernel.

    Stores each argument on a private attribute and builds
    `self._hmc_impl`, an `UncalibratedHamiltonianMonteCarlo` kernel wrapped in
    `metropolis_hastings.MetropolisHastings`.

    Args:
      target_log_prob_fn: Python callable which takes an argument like
        `current_state` (or `*current_state` if it's a list) and returns its
        (possibly unnormalized) log-density under the target distribution.
      step_size: `Tensor` or Python `list` of `Tensor`s representing the step
        size for the leapfrog integrator. Must broadcast with the shape of
        `current_state`. Larger step sizes lead to faster progress, but
        too-large step sizes make rejection exponentially more likely. When
        possible, it's often helpful to match per-variable step sizes to the
        standard deviations of the target distribution in each variable.
      num_leapfrog_steps: Integer number of steps to run the leapfrog
        integrator for. Total progress per HMC step is roughly proportional
        to `step_size * num_leapfrog_steps`.
      seed: Python integer to seed the random number generator. It is passed
        to both the inner kernel and the `MetropolisHastings` wrapper.
      name: Python `str` name prefixed to Ops created by this function. It is
        forwarded unchanged to the wrapped kernel.
        Default value: `None` (documented as 'hmc_kernel'; that resolution is
        not performed in this constructor).

    Raises:
      ValueError: if there isn't one `step_size` or a list with same length
        as `current_state`. NOTE(review): this check is not visible in the
        constructor itself; presumably it comes from the wrapped kernel --
        confirm.
    """
    uncalibrated_kernel_kwargs = dict(
        target_log_prob_fn=target_log_prob_fn,
        step_size=step_size,
        num_leapfrog_steps=num_leapfrog_steps,
        seed=seed,
        name=name)

    self._target_log_prob_fn = target_log_prob_fn
    self._step_size = step_size
    self._num_leapfrog_steps = num_leapfrog_steps
    self._seed = seed
    self._name = name

    uncalibrated_kernel = UncalibratedHamiltonianMonteCarlo(
        **uncalibrated_kernel_kwargs)
    self._hmc_impl = metropolis_hastings.MetropolisHastings(
        inner_kernel=uncalibrated_kernel, seed=seed)
def __init__(self,
             target_log_prob_fn,
             step_size,
             volatility_fn=None,
             seed=None,
             name=None):
    """Initializes MALA transition kernel.

    Wraps an `UncalibratedLangevin` kernel in
    `metropolis_hastings.MetropolisHastings` and stores the result on
    `self._impl`.

    Args:
      target_log_prob_fn: Python callable which takes an argument like
        `current_state` (or `*current_state` if it's a list) and returns its
        (possibly unnormalized) log-density under the target distribution.
      step_size: `Tensor` or Python `list` of `Tensor`s representing the step
        size passed to `UncalibratedLangevin`. Must broadcast with the shape
        of `current_state`. Larger step sizes lead to faster progress, but
        too-large step sizes make rejection exponentially more likely. When
        possible, it's often helpful to match per-variable step sizes to the
        standard deviations of the target distribution in each variable.
      volatility_fn: Python callable which takes an argument like
        `current_state` (or `*current_state` if it's a list) and returns
        volatility value at `current_state`. Should return a `Tensor` or
        Python `list` of `Tensor`s that must broadcast with the shape of
        `current_state`. Defaults to the identity function.
      seed: Python integer to seed the random number generator. It is passed
        to both the inner kernel and the `MetropolisHastings` wrapper.
        Default value: `None` (i.e., no seed).
      name: Python `str` name prefixed to Ops created by this function. It is
        forwarded unchanged to the wrapped kernel.
        Default value: `None` (documented as 'mala_kernel'; that resolution
        is not performed in this constructor).

    Raises:
      ValueError: if there isn't one `step_size` or a list with same length
        as `current_state`.
      TypeError: if `volatility_fn` is not callable.
      NOTE(review): neither check is visible in this constructor; presumably
      they come from the wrapped kernel -- confirm.
    """
    uncalibrated_kernel = UncalibratedLangevin(
        target_log_prob_fn=target_log_prob_fn,
        step_size=step_size,
        volatility_fn=volatility_fn,
        seed=seed,
        name=name)
    self._impl = metropolis_hastings.MetropolisHastings(
        inner_kernel=uncalibrated_kernel, seed=seed)
def one_step(self, current_state, previous_kernel_results, seed=None):
    """Runs one adaptive random-walk Metropolis-Hastings iteration.

    The proposal covariance is recomputed from the adaptation state carried in
    `previous_kernel_results`, a fresh `MetropolisHastings` kernel is built
    around it (replacing `self._impl`), and that kernel takes one step.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s holding the
        current state of the chain(s).
      previous_kernel_results: kernel results from the previous step. The
        adaptation fields `num_steps`, `is_adaptive`, `covariance_scaling`,
        `running_covariance` and `covariance` are read from it through
        `self.extra_getter_fn`.
      seed: NOTE(review): accepted but never read in this method; sampling
        uses `self.seed` and no seed is passed to the inner `one_step`.

    Returns:
      A two-element Python `list` `[new_state, new_kernel_results]`, where
      `new_kernel_results` is the inner kernel's results after
      `self.extra_setter_fn` has written the updated adaptation fields.
    """
    scope_name = mcmc_util.make_name(
        self.name, "AdaptiveRandomWalkMetropolisHastings", "one_step")
    with tf.name_scope(scope_name):
        with tf.name_scope("initialize"):
            state_parts = (
                list(current_state)
                if mcmc_util.is_list_like(current_state)
                else [current_state]
            )
            state_parts = [
                tf.convert_to_tensor(part, name="current_state")
                for part in state_parts
            ]

        # Note 'covariance_scaling' and 'accum_covar' are updated every step
        # but 'covariance' is not updated until
        # 'num_steps' >= 'covariance_burnin'.
        previous_extra = self.extra_getter_fn(previous_kernel_results)
        num_steps = previous_extra.num_steps
        previous_is_adaptive = previous_extra.is_adaptive

        # for parallel processing efficiency use gather() rather than cond()?
        # Index 0 keeps the previous scaling, index 1 takes the updated one;
        # `previous_is_adaptive` supplies the index for each batch member.
        scaling_candidates = tf.stack(
            [
                previous_extra.covariance_scaling,
                self.update_covariance_scaling(
                    previous_kernel_results, num_steps),
            ],
            axis=-1,
        )
        covariance_scaling = tf.gather(
            scaling_candidates,
            previous_is_adaptive,
            batch_dims=1,
            axis=1,
        )

        accum_covar = self.running_covar.update(
            state=previous_extra.running_covariance,
            new_sample=state_parts,
        )

        # Index 0 keeps the previous covariance; index 1 (selected once
        # `num_steps >= covariance_burnin`) takes the running estimate.
        covariance_candidates = [
            previous_extra.covariance,
            self.running_covar.finalize(accum_covar, ddof=1),
        ]
        burnin_reached = tf.cast(
            num_steps >= self.covariance_burnin, dtype=tf.dtypes.int32)
        covariance = tf.gather(covariance_candidates, burnin_reached)

        scaled_covariance = tf.squeeze(
            tf.expand_dims(covariance_scaling, axis=1)
            * tf.stack([covariance]),
            axis=0,
        )
        scaled_covariance = tf.unstack(scaled_covariance)
        scaled_covariance_parts = (
            list(scaled_covariance)
            if mcmc_util.is_list_like(scaled_covariance)
            else [scaled_covariance]
        )
        scaled_covariance_parts = [
            tf.convert_to_tensor(part, name="current_scaled_covariance")
            for part in scaled_covariance_parts
        ]

        is_adaptive = self.u.sample(seed=self.seed)

        proposal_fn = random_walk_mvnorm_fn(
            covariance=scaled_covariance_parts,
            pu=self.pu,
            fixed_variance=self.fixed_variance,
            is_adaptive=is_adaptive,
            name=self.name,
        )
        uncalibrated_kernel = random_walk_metropolis.UncalibratedRandomWalk(
            target_log_prob_fn=self.target_log_prob_fn,
            new_state_fn=proposal_fn,
            name=self.name,
        )
        # The kernel stored on the instance is replaced on every step.
        self._impl = metropolis_hastings.MetropolisHastings(
            inner_kernel=uncalibrated_kernel,
            name=self.name,
        )

        new_state, new_results = self._impl.one_step(
            current_state, previous_kernel_results)
        new_results = self.extra_setter_fn(
            new_results,
            num_steps + 1,
            tf.squeeze(covariance_scaling, axis=1),
            covariance,
            accum_covar,
            is_adaptive,
        )
        return [new_state, new_results]
def __init__(
    self,
    target_log_prob_fn,
    initial_state,
    initial_covariance=None,
    initial_covariance_scaling=2.38**2,
    covariance_scaling_reducer=0.7,
    covariance_scaling_limiter=0.01,
    covariance_burnin=100,
    target_accept_ratio=0.234,
    pu=0.95,
    fixed_variance=0.01,
    extra_getter_fn=rwm_extra_getter_fn,
    extra_setter_fn=rwm_extra_setter_fn,
    log_accept_prob_getter_fn=rwm_log_accept_prob_getter_fn,
    seed=None,
    name=None,
):
    """Initializes this transition kernel.

    Validates the scalar tuning arguments, builds the running-covariance
    accumulator and the Bernoulli distribution used to flag adaptive steps,
    records the constructor arguments on `self._parameters`, and creates the
    initial `MetropolisHastings` kernel on `self._impl`.

    Args:
      target_log_prob_fn: Python callable which takes an argument like
        `current_state` and returns its (possibly unnormalized) log-density
        under the target distribution.
      initial_state: Python `list` of `Tensor`s representing the initial
        state of each parameter.
      initial_covariance: Python `list` of `Tensor`s representing the initial
        covariance of the proposal. The `initial_covariance` and
        `initial_state` should have identical `dtype`s and batch dimensions.
        If `initial_covariance` is `None` then it is initialized to the
        identity matrix multiplied by 0.001, batched to match
        `initial_state`. The covariance matrix is tuned during the evolution
        of the MCMC chain.
        Default value: `None`.
      initial_covariance_scaling: Python floating point number representing
        the initial value of the `covariance_scaling`. The value of
        `covariance_scaling` is tuned during the evolution of the MCMC chain.
        Let d represent the number of parameters e.g. as given by the
        `initial_state`. The ratio given by the `covariance_scaling` divided
        by d is used to multiply the running covariance. The covariance
        scaling factor multiplied by the covariance matrix is used in the
        proposal at each step.
        Default value: 2.38**2.
      covariance_scaling_reducer: Python floating point number, bounded over
        the range (0.5, 1.0], representing the constant factor used during
        the adaptation of the `covariance_scaling`.
        Default value: 0.7.
      covariance_scaling_limiter: Python floating point number, bounded
        between 0.0 and 1.0, which places a limit on the maximum amount the
        `covariance_scaling` value can be perturbed at each iteration of the
        MCMC chain.
        Default value: 0.01.
      covariance_burnin: Python integer number of steps to take before
        starting to compute the running covariance.
        Default value: 100.
      target_accept_ratio: Python floating point number, bounded over the
        range (0.0, 1.0], representing the target acceptance probability of
        the Metropolis-Hastings algorithm.
        Default value: 0.234.
      pu: Python floating point number, bounded between 0.0 and 1.0,
        representing the bounded convergence parameter. See
        `random_walk_mvnorm_fn()` for further details.
        Default value: 0.95.
      fixed_variance: Python floating point number, greater than or equal to
        0.0, representing the variance of the fixed proposal distribution.
        See `random_walk_mvnorm_fn` for further details.
        Default value: 0.01.
      extra_getter_fn: A callable with the signature
        `(kernel_results) -> extra` where `kernel_results` are the results of
        the `inner_kernel`, and `extra` is a nested collection of `Tensor`s.
      extra_setter_fn: A callable with the signature
        `(kernel_results, args) -> new_kernel_results` where `kernel_results`
        are the results of the `inner_kernel`, `args` are a nested collection
        of `Tensor`s with the same structure as returned by the
        `extra_getter_fn`, and `new_kernel_results` are a copy of
        `kernel_results` with `args` in the `extra` field set.
      log_accept_prob_getter_fn: A callable with the signature
        `(kernel_results) -> log_accept_prob` where `kernel_results` are the
        results of the `inner_kernel`, and `log_accept_prob` is either a
        scalar, or has shape [num_chains].
      seed: Python integer to seed the random number generator.
        Default value: `None`.
      name: Python `str` name prefixed to Ops created by this function.
        Default value: `None`.

    Raises:
      ValueError: if `initial_covariance_scaling` is less than or equal to
        0.0.
      ValueError: if `covariance_scaling_reducer` is less than or equal to
        0.5 or greater than 1.0.
      ValueError: if `covariance_scaling_limiter` is less than 0.0 or greater
        than 1.0.
      ValueError: if `covariance_burnin` is less than 0.
      ValueError: if `target_accept_ratio` is less than or equal to 0.0 or
        greater than 1.0.
      ValueError: if `pu` is less than 0.0 or greater than 1.0.
      ValueError: if `fixed_variance` is less than 0.0.
    """
    with tf.name_scope(
        mcmc_util.make_name(
            name, "AdaptiveRandomWalkMetropolisHastings", "__init__")
    ) as name:
        if initial_covariance_scaling <= 0.0:
            raise ValueError(
                "`{}` must be a `float` greater than 0.0".format(
                    "initial_covariance_scaling"))
        if (covariance_scaling_reducer <= 0.5
                or covariance_scaling_reducer > 1.0):
            raise ValueError(
                "`{}` must be a `float` greater than 0.5 and less than or "
                "equal to 1.0.".format("covariance_scaling_reducer"))
        if (covariance_scaling_limiter < 0.0
                or covariance_scaling_limiter > 1.0):
            raise ValueError(
                "`{}` must be a `float` between 0.0 and 1.0.".format(
                    "covariance_scaling_limiter"))
        if covariance_burnin < 0:
            raise ValueError(
                "`{}` must be a `integer` greater or equal to 0.".format(
                    "covariance_burnin"))
        if target_accept_ratio <= 0.0 or target_accept_ratio > 1.0:
            raise ValueError(
                "`{}` must be a `float` between 0.0 and 1.0.".format(
                    "target_accept_ratio"))
        if pu < 0.0 or pu > 1.0:
            raise ValueError(
                "`{}` must be a `float` between 0.0 and 1.0.".format("pu"))
        if fixed_variance < 0.0:
            # 0.0 passes the check above, so the message says
            # "greater than or equal to" rather than "greater than".
            raise ValueError(
                "`{}` must be a `float` greater than or equal to 0.0.".format(
                    "fixed_variance"))

        if mcmc_util.is_list_like(initial_state):
            initial_state_parts = list(initial_state)
        else:
            initial_state_parts = [initial_state]
        initial_state_parts = [
            tf.convert_to_tensor(s, name="initial_state")
            for s in initial_state_parts
        ]

        # Stack once and read both the static shape and the dtype from it.
        stacked_state = tf.stack(initial_state_parts)
        shape = stacked_state.shape
        dtype = dtype_util.base_dtype(stacked_state.dtype)

        if initial_covariance is None:
            initial_covariance = 0.001 * tf.eye(
                num_rows=shape[-1], dtype=dtype, batch_shape=[shape[0]])
        else:
            initial_covariance = tf.stack(initial_covariance)

        if mcmc_util.is_list_like(initial_covariance):
            initial_covariance_parts = list(initial_covariance)
        else:
            initial_covariance_parts = [initial_covariance]
        initial_covariance_parts = [
            tf.convert_to_tensor(s, name="initial_covariance")
            for s in initial_covariance_parts
        ]

        self._running_covar = stats.RunningCovariance(
            shape=(1, shape[-1]), dtype=dtype, event_ndims=1)
        self._accum_covar = self._running_covar.initialize()

        probs = tf.expand_dims(
            tf.ones([shape[0]], dtype=dtype) * pu, axis=1)
        self._u = Bernoulli(probs=probs, dtype=tf.dtypes.int32)
        # Only the shape of the sample is used: the initial flag is all zeros.
        self._initial_u = tf.zeros_like(
            self._u.sample(seed=seed), dtype=tf.dtypes.int32)

        # `name` was rebound to the scope name by the `with` statement above;
        # it is passed through `make_name` once more before being stored.
        name = mcmc_util.make_name(
            name, "AdaptiveRandomWalkMetropolisHastings", "")

        self._parameters = dict(
            target_log_prob_fn=target_log_prob_fn,
            initial_state=initial_state,
            initial_covariance=initial_covariance,
            initial_covariance_scaling=initial_covariance_scaling,
            covariance_scaling_reducer=covariance_scaling_reducer,
            covariance_scaling_limiter=covariance_scaling_limiter,
            covariance_burnin=covariance_burnin,
            target_accept_ratio=target_accept_ratio,
            pu=pu,
            fixed_variance=fixed_variance,
            extra_getter_fn=extra_getter_fn,
            extra_setter_fn=extra_setter_fn,
            log_accept_prob_getter_fn=log_accept_prob_getter_fn,
            seed=seed,
            name=name,
        )
        self._impl = metropolis_hastings.MetropolisHastings(
            inner_kernel=random_walk_metropolis.UncalibratedRandomWalk(
                target_log_prob_fn=target_log_prob_fn,
                new_state_fn=random_walk_mvnorm_fn(
                    covariance=initial_covariance_parts,
                    pu=pu,
                    fixed_variance=fixed_variance,
                    is_adaptive=self._initial_u,
                    name=name,
                ),
                name=name,
            ),
            name=name,
        )