def bootstrap_results(self, init_state):
  """Returns an object with the same type as returned by `one_step`.

  Args:
    init_state: `Tensor` or Python `list` of `Tensor`s representing the
      initial state(s) of the Markov chain(s).

  Returns:
    kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
      `Tensor`s representing internal calculations made within this function.
      This includes replica states.

  Raises:
    ValueError: if `self._state_includes_replicas` is `True` and the
      statically known number of replicas in `init_state` differs from the
      statically known number of `inverse_temperatures`.
  """
  with tf.name_scope(
      mcmc_util.make_name(self.name, 'remc', 'bootstrap_results')):
    init_state, unused_is_multipart_state = mcmc_util.prepare_state_parts(
        init_state)

    inverse_temperatures = tf.convert_to_tensor(
        self.inverse_temperatures,
        name='inverse_temperatures')

    if self._state_includes_replicas:
      it_n_replica = inverse_temperatures.shape[0]
      state_n_replica = init_state[0].shape[0]
      if ((it_n_replica is not None) and (state_n_replica is not None) and
          (it_n_replica != state_n_replica)):
        # The format arguments follow the message order: the replica count
        # implied by the state comes first, then the inverse temperatures'.
        raise ValueError(
            'Number of replicas implied by initial state ({}) must equal '
            'number of replicas implied by inverse_temperatures ({}), but '
            'did not'.format(state_n_replica, it_n_replica))

    # We will now replicate each of a possible batch of initial states, one
    # for each inverse_temperature. So if init_state=[x, y] of shapes
    # [Sx, Sy] then the new shape is [(T, Sx), (T, Sy)] where (a, b) means
    # concatenation and T=shape(inverse_temperature).
    num_replica = ps.size0(inverse_temperatures)
    replica_shape = tf.convert_to_tensor([num_replica])

    if self._state_includes_replicas:
      replica_states = init_state
    else:
      replica_states = [
          tf.broadcast_to(  # pylint: disable=g-complex-comprehension
              x,
              ps.concat([replica_shape, ps.shape(x)], axis=0),
              name='replica_states')
          for x in init_state
      ]

    target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
        self.target_log_prob_fn, inverse_temperatures)

    # Seed handling complexity is due to users possibly expecting an old-style
    # stateful seed to be passed to `self.make_kernel_fn`.
    # In other words:
    # - We try `make_kernel_fn` without a seed first; this is the future. The
    #   kernel will receive a seed later, as part of `one_step`.
    # - If the user code doesn't like that (Python complains about a missing
    #   required argument), we fall back to the previous behavior and warn.
    try:
      inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
          target_log_prob_for_inner_kernel)
    except TypeError as e:
      if 'argument' not in str(e):
        raise
      warnings.warn(
          'The second (`seed`) argument to `ReplicaExchangeMC`s '
          '`make_kernel_fn` is deprecated. `TransitionKernel` instances now '
          'receive seeds via `bootstrap_results` and `one_step`. This '
          'fallback may become an error 2020-09-20.')
      inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
          target_log_prob_for_inner_kernel, self._seed_stream())

    replica_results = inner_kernel.bootstrap_results(replica_states)
    pre_swap_replica_target_log_prob = _get_field(
        replica_results, 'target_log_prob')
    replica_and_batch_shape = ps.shape(
        pre_swap_replica_target_log_prob)
    batch_shape = replica_and_batch_shape[1:]

    inverse_temperatures = mcmc_util.left_justified_broadcast_to(
        inverse_temperatures, replica_and_batch_shape)

    # Pretend we did a "null swap", which will always be accepted.
    swaps = mcmc_util.left_justified_broadcast_to(
        tf.range(num_replica), replica_and_batch_shape)
    # is_swap_accepted.shape = [n_replica, n_replica] + batch_shape.
    is_swap_accepted = distribution_util.rotate_transpose(tf.eye(
        num_replica, batch_shape=batch_shape, dtype=tf.bool), shift=2)

    # With identical "before" and "after" temperatures and an identity swap
    # function, the post-swap results are derived from the bootstrap results.
    post_swap_replica_results = _make_post_swap_replica_results(
        replica_results,
        inverse_temperatures,
        inverse_temperatures,
        is_swap_accepted[0],
        lambda x: x,
    )

    return ReplicaExchangeMCKernelResults(
        post_swap_replica_states=replica_states,
        pre_swap_replica_results=replica_results,
        post_swap_replica_results=post_swap_replica_results,
        is_swap_proposed=is_swap_accepted,
        is_swap_accepted=is_swap_accepted,
        is_swap_proposed_adjacent=_sub_diag(is_swap_accepted),
        is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
        inverse_temperatures=self.inverse_temperatures,
        swaps=swaps,
        step_count=tf.zeros(shape=(), dtype=tf.int32),
        seed=samplers.zeros_seed(),
    )
def bootstrap_results(self, init_state):
  """Builds the initial `NUTSKernelResults` for a chain starting at `init_state`.

  Args:
    init_state: `Tensor` or (nested) Python structure of `Tensor`s giving the
      starting point of the chain. A non-nested value is wrapped in a list.

  Returns:
    A `NUTSKernelResults` holding the target log-prob and its gradients at
    `init_state`, zero-filled momentum/state swap memory, one step size
    `Tensor` per state part, and zero/`False` placeholders for the per-step
    diagnostics.

  Raises:
    ValueError: if `self.step_size` has neither exactly one entry nor one
      entry per state part.
  """
  with tf.name_scope(self.name + '.bootstrap_results'):
    state_parts = init_state if tf.nest.is_nested(init_state) else [init_state]

    # All-ones momenta, one per state part. They are used to size the swap
    # memory, to evaluate the target via `process_args`, and to compute the
    # bootstrap energy.
    unit_momentum = [tf.ones_like(part) for part in state_parts]

    def _zeros_memory(tensors):
      """Allocates zero-filled storage shaped `[slots] + shape(t)` per tensor."""
      memory = []
      for t in tensors:
        shape, dtype = ps.shape(t), t.dtype
        slots = [max(self._write_instruction) + 1]
        memory.append(ps.zeros(ps.concat([slots, shape], axis=0), dtype=dtype))
      return memory

    momentum_state_memory = MomentumStateSwap(
        momentum_swap=_zeros_memory(unit_momentum),
        state_swap=_zeros_memory(state_parts))

    _, _, current_target_log_prob, current_grads_log_prob = (
        leapfrog_impl.process_args(
            self.target_log_prob_fn, unit_momentum, state_parts))

    # A single step size is repeated so there is one per state part.
    step_size = self.step_size
    if len(step_size) == 1:
      step_size = step_size * len(state_parts)
    if len(step_size) != len(state_parts):
      raise ValueError('Expected either one step size or {} (size of '
                       '`init_state`), but found {}'.format(
                           len(state_parts), len(step_size)))

    def _as_step_size_tensor(value):
      return tf.convert_to_tensor(
          value, dtype=current_target_log_prob.dtype, name='step_size')

    step_size = tf.nest.map_structure(_as_step_size_tensor, step_size)

    tlp = current_target_log_prob
    return NUTSKernelResults(
        target_log_prob=tlp,
        grads_target_log_prob=current_grads_log_prob,
        momentum_state_memory=momentum_state_memory,
        step_size=step_size,
        log_accept_ratio=tf.zeros_like(tlp, name='log_accept_ratio'),
        leapfrogs_taken=tf.zeros_like(
            tlp, dtype=TREE_COUNT_DTYPE, name='leapfrogs_taken'),
        is_accepted=tf.zeros_like(tlp, dtype=tf.bool, name='is_accepted'),
        reach_max_depth=tf.zeros_like(
            tlp, dtype=tf.bool, name='reach_max_depth'),
        has_divergence=tf.zeros_like(
            tlp, dtype=tf.bool, name='has_divergence'),
        energy=compute_hamiltonian(tlp, unit_momentum),
        # Placeholder; `one_step` stores its own seed here.
        seed=samplers.zeros_seed(),
    )
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Takes one step of the TransitionKernel.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
      `list` of `Tensor`s representing internal calculations made within the
      previous call to this function (or as returned by `bootstrap_results`).
    seed: Optional, a seed for reproducible sampling.

  Returns:
    next_state: `Tensor` or Python `list` of `Tensor`s representing the
      next state(s) of the Markov chain(s).
    kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
      `Tensor`s representing internal calculations made within this function.
      This includes replica states.
  """
  # The code below propagates one step states of shape
  #  [n_replica] + batch_shape + event_shape.
  #
  # The step is done in three parts:
  #  1) Call one_step to transition states via a tempered version of
  #     self.target_log_prob_fn (see _replica_target_log_prob).
  #  2) Permute values in states.
  #  3) Update state-dependent values, such as log_probs.
  #
  # We chose to swap states, rather than temperatures, because...
  # (i)   If swapping temperatures, you *still* have to swap log_probs to
  #       determine acceptance, as well as states (for kernel results).
  #       So it's just as difficult to swap temperatures.
  # (ii)  If swapping temperatures, you have to take care to swap any user-
  #       supplied temperature related things (like step size).
  #       A-priori, we don't know what else will need to be swapped!
  # (iii) In both cases, the kernel results need to be updated in a
  #       non-trivial manner, so we either special-case, or use bootstrap.

  with tf.name_scope(mcmc_util.make_name(self.name, 'remc', 'one_step')):
    # Force a read in case the `inverse_temperatures` is a `tf.Variable`.
    inverse_temperatures = tf.convert_to_tensor(
        previous_kernel_results.inverse_temperatures,
        name='inverse_temperatures')

    target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
        self.target_log_prob_fn, inverse_temperatures)

    # Seed handling complexity is due to users possibly expecting an old-style
    # stateful seed to be passed to `self.make_kernel_fn`, and no seed
    # expected by `kernel.one_step`.
    # In other words:
    # - We try `make_kernel_fn` without a seed first; this is the future. The
    #   kernel will receive a seed later, as part of `one_step`.
    # - If the user code doesn't like that (Python complains about a missing
    #   required argument), we warn and fall back to the previous behavior.
    try:
      inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
          target_log_prob_for_inner_kernel)
    except TypeError as e:
      if 'argument' not in str(e):
        raise
      warnings.warn(
          'The `seed` argument to `ReplicaExchangeMC`s `make_kernel_fn` is '
          'deprecated. `TransitionKernel` instances now receive seeds via '
          '`one_step`.')
      inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
          target_log_prob_for_inner_kernel, self._seed_stream())

    # Now that we've constructed the TransitionKernel instance:
    # - If we were given a seed, we sanitize it to stateless and pass along
    #   to `kernel.one_step`. If it doesn't like that, we crash and propagate
    #   the error. Rationale: The contract is stateless sampling given
    #   seed, and doing otherwise would not meet it.
    # - If not given a seed, we don't pass one along. This avoids breaking
    #   underlying kernels lacking a `seed` arg on `one_step`.
    # In both branches the swap proposal and the Log[Uniform] draws get their
    # own seeds (`swap_seed`, `logu_seed`).
    # TODO(b/159636942): Clean up after 2020-09-20.
    if seed is not None:
      seed = samplers.sanitize_seed(seed)
      inner_seed, swap_seed, logu_seed = samplers.split_seed(
          seed, n=3, salt='remc_one_step')
      inner_kwargs = dict(seed=inner_seed)
    else:
      if self._seed_stream.original_seed is not None:
        warnings.warn(mcmc_util.SEED_CTOR_ARG_DEPRECATION_MSG)
      inner_kwargs = {}
      swap_seed, logu_seed = samplers.split_seed(self._seed_stream())

    # Step 1: advance every replica with the tempered inner kernel.
    [
        pre_swap_replica_states,
        pre_swap_replica_results,
    ] = inner_kernel.one_step(
        previous_kernel_results.post_swap_replica_states,
        previous_kernel_results.post_swap_replica_results,
        **inner_kwargs)

    pre_swap_replica_target_log_prob = _get_field(
        # These are tempered log probs (have been divided by temperature).
        pre_swap_replica_results, 'target_log_prob')

    dtype = pre_swap_replica_target_log_prob.dtype
    replica_and_batch_shape = ps.shape(
        pre_swap_replica_target_log_prob)
    batch_shape = replica_and_batch_shape[1:]
    replica_and_batch_rank = ps.rank(pre_swap_replica_target_log_prob)
    num_replica = ps.size0(inverse_temperatures)

    inverse_temperatures = mcmc_util.left_justified_broadcast_to(
        inverse_temperatures, replica_and_batch_shape)

    # Step 2: propose and accept/reject swaps.
    # Now that each replica has done one_step, it is time to consider swaps.

    # swap.shape = [n_replica], and is a "once only" permutation, meaning it
    # is achievable by a sequence of pairwise permutations, where each element
    # is moved at most once.
    # E.g. if swaps = [1, 0, 2], we will consider swapping temperatures 0 and
    # 1, keeping 2 fixed.  This exact same swap is considered for *every*
    # batch member.  Of course some batch members may accept and some reject.
    try:
      swaps = tf.cast(
          self.swap_proposal_fn(  # pylint: disable=not-callable
              num_replica,
              batch_shape=batch_shape,
              seed=swap_seed,
              step_count=previous_kernel_results.step_count),
          dtype=tf.int32)
    except TypeError as e:
      # Older `swap_proposal_fn`s do not take `step_count`; retry without it.
      if 'step_count' not in str(e):
        raise
      warnings.warn(
          'The `swap_proposal_fn` given to ReplicaExchangeMC did not accept '
          'the `step_count` argument. Falling back to omitting the '
          'argument. This fallback will be removed after 24-Oct-2020.')
      swaps = tf.cast(
          self.swap_proposal_fn(  # pylint: disable=not-callable
              num_replica,
              batch_shape=batch_shape,
              seed=swap_seed),
          dtype=tf.int32)

    # `null_swaps` is the identity permutation, expanded to the rank of
    # `swaps`.
    null_swaps = mcmc_util.left_justified_expand_dims_like(
        tf.range(num_replica, dtype=swaps.dtype), swaps)
    swaps = _maybe_embed_swaps_validation(swaps, null_swaps,
                                          self.validate_args)

    # Un-temper the log probs. E.g., for replica k, at point x_k, this is
    # Log[p(x_k)], and *not* Log[p_x(x_k)] = Log[p(x_k)] * beta_k.
    untempered_pre_swap_replica_target_log_prob = (
        pre_swap_replica_target_log_prob / inverse_temperatures)

    # Since `swaps` is its own inverse permutation we automatically know the
    # swap counterpart: range(num_replica). We use this idea to compute the
    # acceptance in a vectorized manner at the cost of wasting roughly half
    # our computation. Although we could use `unique` to solve this problem,
    # we expect the cost of `unique` to be higher than the dozens of wasted
    # arithmetic calculations. Worse, it'd mean we need dynamic sized Tensors
    # (eg, using `tf.where(bool)`) and so we wouldn't be able to XLA compile.

    # Note: diffs would normally be "proposed - current" however energy is
    # flipped since `energy == -log_prob`.
    energy_diff = (untempered_pre_swap_replica_target_log_prob -
                   mcmc_util.index_remapping_gather(
                       untempered_pre_swap_replica_target_log_prob,
                       swaps, name='gather_swap_tlp'))
    swapped_inverse_temperatures = mcmc_util.index_remapping_gather(
        inverse_temperatures, swaps, name='gather_swap_temps')
    inverse_temp_diff = swapped_inverse_temperatures - inverse_temperatures

    # If replicas i and j are swapping, log_accept_ratio[i] equals
    # log_accept_ratio[j]: both factors of the product flip sign.
    log_accept_ratio = (energy_diff *
                        mcmc_util.left_justified_expand_dims_to(
                            inverse_temp_diff, replica_and_batch_rank))

    # Non-finite ratios (e.g. NaN or +inf) are mapped to -inf, i.e. reject.
    log_accept_ratio = tf.where(tf.math.is_finite(log_accept_ratio),
                                log_accept_ratio,
                                tf.constant(-np.inf, dtype=dtype))

    # Produce Log[Uniform] draws that are identical at swapped indices.
    # Both members of a swap pair read the draw at the smaller of their two
    # indices (`anchor_swaps`), so they reach the same accept/reject decision.
    log_uniform = tf.math.log(
        samplers.uniform(shape=replica_and_batch_shape,
                         dtype=dtype,
                         seed=logu_seed))
    anchor_swaps = tf.minimum(swaps, null_swaps)
    log_uniform = mcmc_util.index_remapping_gather(
        log_uniform, anchor_swaps)

    is_swap_accepted_mask = tf.less(log_uniform, log_accept_ratio,
                                    name='is_swap_accepted_mask')

    def _swap_tensor(x):
      # Where the swap was accepted take the permuted value, else keep `x`.
      return mcmc_util.choose(
          is_swap_accepted_mask,
          mcmc_util.index_remapping_gather(x, swaps),
          x)

    post_swap_replica_states = [
        _swap_tensor(s) for s in pre_swap_replica_states
    ]

    expanded_null_swaps = mcmc_util.left_justified_broadcast_to(
        null_swaps, replica_and_batch_shape)
    is_swap_proposed = _compute_swap_notmatrix(
        # Broadcast both so they have shape [num_replica] + batch_shape.
        # This (i) makes them have same shape as is_swap_accepted, and
        # (ii) keeps shape consistent if someday swaps has a batch shape.
        expanded_null_swaps,
        mcmc_util.left_justified_broadcast_to(swaps, replica_and_batch_shape))

    # To get is_swap_accepted in ordered position, we use
    # _compute_swap_notmatrix on current and next replica positions.
    post_swap_replica_position = _swap_tensor(expanded_null_swaps)

    is_swap_accepted = _compute_swap_notmatrix(
        post_swap_replica_position,
        expanded_null_swaps)

    # Step 3: update state-dependent values.
    if self._state_includes_replicas:
      post_swap_states = post_swap_replica_states
    else:
      # Only replica 0 is returned as the chain state.
      post_swap_states = [s[0] for s in post_swap_replica_states]

    post_swap_replica_results = _make_post_swap_replica_results(
        pre_swap_replica_results,
        inverse_temperatures,
        swapped_inverse_temperatures,
        is_swap_accepted_mask,
        _swap_tensor)

    if mcmc_util.is_list_like(current_state):
      # We *always* canonicalize the states in the kernel results.
      states = post_swap_states
    else:
      states = post_swap_states[0]

    post_swap_kernel_results = ReplicaExchangeMCKernelResults(
        post_swap_replica_states=post_swap_replica_states,
        pre_swap_replica_results=pre_swap_replica_results,
        post_swap_replica_results=post_swap_replica_results,
        is_swap_proposed=is_swap_proposed,
        is_swap_accepted=is_swap_accepted,
        is_swap_proposed_adjacent=_sub_diag(is_swap_proposed),
        is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
        # Store the original pkr.inverse_temperatures in case it's a
        # `tf.Variable`.
        inverse_temperatures=previous_kernel_results.inverse_temperatures,
        swaps=swaps,
        step_count=previous_kernel_results.step_count + 1,
        seed=samplers.zeros_seed() if seed is None else seed,
    )

    return states, post_swap_kernel_results
def _fixed_sample(d):
  """Draws one sample from distribution `d` using a constant seed.

  The seed comes from `samplers.zeros_seed()`, so repeated calls pass the
  same seed to `d.sample`.
  """
  constant_seed = samplers.zeros_seed()
  return d.sample(seed=constant_seed)
def bootstrap_results(self, init_state):
  """Returns an object with the same type as returned by `one_step`.

  Args:
    init_state: `Tensor` or Python `list` of `Tensor`s representing the
      initial state(s) of the Markov chain(s).

  Returns:
    kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
      `Tensor`s representing internal calculations made within this function.
      This includes replica states.

  Raises:
    ValueError: if `self._state_includes_replicas` is `True` and the
      statically known number of replicas in `init_state` differs from the
      statically known number of `inverse_temperatures`.
    TypeError: if `make_kernel_fn` fails with a `TypeError` mentioning an
      argument (e.g. it still expects a second `seed` argument).
  """
  with tf.name_scope(mcmc_util.make_name(
      self.name, 'remc', 'bootstrap_results')):
    init_state, unused_is_multipart_state = mcmc_util.prepare_state_parts(
        init_state)

    inverse_temperatures = tf.convert_to_tensor(
        self.inverse_temperatures,
        name='inverse_temperatures')

    if self._state_includes_replicas:
      it_n_replica = inverse_temperatures.shape[0]
      state_n_replica = init_state[0].shape[0]
      if ((it_n_replica is not None) and (state_n_replica is not None) and
          (it_n_replica != state_n_replica)):
        # The format arguments follow the message order: the replica count
        # implied by the state comes first, then the inverse temperatures'.
        raise ValueError(
            'Number of replicas implied by initial state ({}) must equal '
            'number of replicas implied by inverse_temperatures ({}), but '
            'did not'.format(state_n_replica, it_n_replica))

    # We will now replicate each of a possible batch of initial states, one
    # for each inverse_temperature. So if init_state=[x, y] of shapes
    # [Sx, Sy] then the new shape is [(T, Sx), (T, Sy)] where (a, b) means
    # concatenation and T=shape(inverse_temperature).
    num_replica = ps.size0(inverse_temperatures)
    replica_shape = ps.convert_to_shape_tensor([num_replica])

    if self._state_includes_replicas:
      replica_states = init_state
    else:
      replica_states = [
          tf.broadcast_to(  # pylint: disable=g-complex-comprehension
              x,
              ps.concat([replica_shape, ps.shape(x)], axis=0),
              name='replica_states')
          for x in init_state
      ]

    target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
        target_log_prob_fn=self.target_log_prob_fn,
        inverse_temperatures=inverse_temperatures,
        untempered_log_prob_fn=self.untempered_log_prob_fn,
        tempered_log_prob_fn=self.tempered_log_prob_fn,
    )

    # TODO(b/159636942): Clean up the helpful error msg after 2020-11-10.
    try:
      inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
          target_log_prob_for_inner_kernel)
    except TypeError as e:
      if 'argument' not in str(e):
        raise
      # Chain the original error so the underlying signature mismatch stays
      # visible in the traceback.
      raise TypeError(
          '`ReplicaExchangeMC`s `make_kernel_fn` no longer receives a second '
          '(`seed`) argument. `TransitionKernel` instances now receive seeds '
          'via `one_step`.') from e

    replica_results = inner_kernel.bootstrap_results(replica_states)
    pre_swap_replica_target_log_prob = _get_field(
        replica_results, 'target_log_prob')
    replica_and_batch_shape = ps.shape(
        pre_swap_replica_target_log_prob)
    batch_shape = replica_and_batch_shape[1:]

    # NOTE(review): this broadcast value is not read again in this method.
    inverse_temperatures = bu.left_justified_broadcast_to(
        inverse_temperatures, replica_and_batch_shape)

    # Pretend we did a "null swap", which will always be accepted.
    swaps = bu.left_justified_broadcast_to(
        tf.range(num_replica), replica_and_batch_shape)
    # is_swap_accepted.shape = [n_replica, n_replica] + batch_shape.
    is_swap_accepted = distribution_util.rotate_transpose(
        tf.eye(num_replica, batch_shape=batch_shape, dtype=tf.bool),
        shift=2)

    return ReplicaExchangeMCKernelResults(
        post_swap_replica_states=replica_states,
        pre_swap_replica_results=replica_results,
        post_swap_replica_results=_set_swapped_fields_to_nan(replica_results),
        is_swap_proposed=is_swap_accepted,
        is_swap_accepted=is_swap_accepted,
        is_swap_proposed_adjacent=_sub_diag(is_swap_accepted),
        is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
        inverse_temperatures=self.inverse_temperatures,
        swaps=swaps,
        step_count=tf.zeros(shape=(), dtype=tf.int32),
        seed=samplers.zeros_seed(),
    )
def bootstrap_results(self, init_state):
  """Creates initial `previous_kernel_results` using a supplied `state`.

  Args:
    init_state: `Tensor` or (nested) Python structure of `Tensor`s giving the
      starting point of the chain. A non-nested value is wrapped in a list.

  Returns:
    A `PreconditionedNUTSKernelResults` holding the target log-prob and its
    gradients at `init_state`, zero-filled velocity/state swap memory, the
    (padded) step sizes, the momentum distribution in use, and zero/`False`
    placeholders for the per-step diagnostics.

  Raises:
    ValueError: if `self.step_size` has neither exactly one entry nor one
      entry per state part.
  """
  with tf.name_scope(self.name + '.bootstrap_results'):
    if not tf.nest.is_nested(init_state):
      init_state = [init_state]
    # Padding the step_size so it is compatible with the states.
    step_size = self.step_size
    if len(step_size) == 1:
      step_size = step_size * len(init_state)
    if len(step_size) != len(init_state):
      raise ValueError(
          'Expected either one step size or {} (size of '
          '`init_state`), but found {}'.format(
              len(init_state), len(step_size)))
    state_parts, _ = mcmc_util.prepare_state_parts(
        init_state, name='current_state')
    current_target_log_prob, current_grads_log_prob = (
        mcmc_util.maybe_call_fn_and_grads(
            self.target_log_prob_fn, state_parts))
    momentum_distribution = self.momentum_distribution
    if momentum_distribution is None:
      momentum_distribution = pu.make_momentum_distribution(
          state_parts, ps.shape(current_target_log_prob))
    momentum_distribution = pu.maybe_make_list_and_batch_broadcast(
        momentum_distribution, ps.shape(current_target_log_prob))
    # `bootstrap_results` receives no seed. Draw the placeholder momentum with
    # a fixed seed so bootstrap is deterministic and works on stateless
    # (seed-required) backends; `one_step` draws fresh momentum with its own
    # seed.
    momentum_parts = momentum_distribution.sample(seed=samplers.zeros_seed())

    def _init(shape_and_dtype):
      """Allocates zero-filled storage for states and velocities."""
      # The leading dimension is `max(self._write_instruction) + 1` slots
      # (presumably one per write index -- confirm against `one_step`).
      return [  # pylint: disable=g-complex-comprehension
          ps.zeros(ps.concat([[max(self._write_instruction) + 1], s], axis=0),
                   dtype=d) for (s, d) in shape_and_dtype
      ]

    get_shapes_and_dtypes = lambda x: [  # pylint: disable=g-long-lambda
        (ps.shape(x_), x_.dtype) for x_ in x
    ]
    velocity_state_memory = VelocityStateSwap(
        velocity_swap=_init(get_shapes_and_dtypes(momentum_parts)),
        # Size state memory from the tensor-converted `state_parts` so that
        # non-`Tensor` inputs have a `.dtype` matching the log-prob evaluation.
        state_swap=_init(get_shapes_and_dtypes(state_parts)))

    return PreconditionedNUTSKernelResults(
        target_log_prob=current_target_log_prob,
        grads_target_log_prob=current_grads_log_prob,
        velocity_state_memory=velocity_state_memory,
        step_size=step_size,
        log_accept_ratio=tf.zeros_like(current_target_log_prob,
                                       name='log_accept_ratio'),
        leapfrogs_taken=tf.zeros_like(current_target_log_prob,
                                      dtype=TREE_COUNT_DTYPE,
                                      name='leapfrogs_taken'),
        is_accepted=tf.zeros_like(current_target_log_prob,
                                  dtype=tf.bool,
                                  name='is_accepted'),
        reach_max_depth=tf.zeros_like(current_target_log_prob,
                                      dtype=tf.bool,
                                      name='reach_max_depth'),
        has_divergence=tf.zeros_like(current_target_log_prob,
                                     dtype=tf.bool,
                                     name='has_divergence'),
        energy=compute_hamiltonian(current_target_log_prob, momentum_parts,
                                   momentum_distribution),
        momentum_distribution=momentum_distribution,
        # Allow room for one_step's seed.
        seed=samplers.zeros_seed(),
    )
def dummy_seed():
  """Returns a fixed constant seed, for cases needing samples without a seed.

  When `JAX_MODE` is truthy this is `samplers.zeros_seed()`; otherwise it is
  the legacy integer seed `42`.
  """
  # TODO(b/147874898): After 20 Dec 2020, drop the 42 and inline the zeros_seed.
  if not JAX_MODE:
    return 42
  return samplers.zeros_seed()
from tensorflow_probability.python.internal import callable_util
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.internal import samplers


# Lookup table from a distribution class to a function that, given an
# instance `d` of that class, builds a bijector for preconditioning.
# Several entries recurse through `make_distribution_bijector` (defined
# elsewhere; presumably it dispatches on this table -- confirm).
# pylint: disable=g-long-lambda,protected-access
preconditioning_bijector_fns = {
    # Use the distribution's own default event-space bijector.
    deterministic.Deterministic:
        (lambda d: d.experimental_default_event_space_bijector()),
    # Delegate to the wrapped (inner) distribution.
    independent.Independent:
        lambda d: make_distribution_bijector(d.distribution),
    # The transition bijector is built from the step-0 transition
    # distribution, conditioned on a prior sample drawn with a fixed seed.
    markov_chain.MarkovChain:
        lambda d: markov_chain._MarkovChainBijector(
            chain=d,
            transition_bijector=make_distribution_bijector(
                d.transition_fn(
                    0, d.initial_state_prior.sample(
                        seed=samplers.zeros_seed()))),
            bijector_fn=make_distribution_bijector),
    # Shift(loc) composed after Scale(scale), i.e. x -> loc + scale * x.
    normal.Normal:
        lambda d: tfb.Shift(d.loc)(tfb.Scale(d.scale)),
    # Wrap the inner distribution's bijector across the sample shape.
    sample.Sample:
        lambda d: sample._DefaultSampleBijector(
            distribution=d.distribution,
            sample_shape=d.sample_shape,
            sum_fn=d._sum_fn(),
            bijector=make_distribution_bijector(d.distribution)),
    # NormalCDF maps into (0, 1), which is then scaled by (high - low) and
    # shifted by low.
    uniform.Uniform:
        lambda d: (tfb.Shift(d.low)(
            tfb.Scale(d.high - d.low)(tfb.NormalCDF())))
}
# pylint: enable=g-long-lambda,protected-access
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Proposes the next state via `new_state_fn` and evaluates its log prob.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: Kernel results from the previous step (not read
      by this method).
    seed: Optional seed. If given, it is sanitized to a stateless seed and
      passed to `new_state_fn`, with no fallback on failure.

  Returns:
    A two-element `list`: the proposed state (a `list` if `current_state` was
    list-like, otherwise a single `Tensor`) and an
    `UncalibratedRandomWalkResults` with zero `log_acceptance_correction`,
    the proposal's `target_log_prob`, and the seed used (the zeros seed if
    the stateful fallback was taken).
  """
  with tf.name_scope(mcmc_util.make_name(self.name, 'rwm', 'one_step')):
    with tf.name_scope('initialize'):
      if mcmc_util.is_list_like(current_state):
        current_state_parts = list(current_state)
      else:
        current_state_parts = [current_state]
      current_state_parts = [
          tf.convert_to_tensor(s, name='current_state')
          for s in current_state_parts
      ]

    # Seed handling complexity is due to users possibly expecting an old-style
    # stateful seed to be passed to `self.new_state_fn`.
    # In other words:
    # - If we were given a seed, we sanitize it to stateless, and
    #   if the `new_state_fn` doesn't like that, we crash and propagate
    #   the error.  Rationale: The contract is stateless sampling given
    #   seed, and doing otherwise would not meet it.
    # - If we were not given a seed, we try `new_state_fn` with a stateless
    #   seed.  Rationale: This is the future.
    # - If it fails with a seed incompatibility problem (as best we can
    #   detect from here), we issue a warning and try it again with a
    #   stateful-style seed. Rationale: User code that didn't set seeds
    #   shouldn't suddenly break.
    # TODO(b/159636942): Clean up after 2020-09-20.
    if seed is not None:
      force_stateless = True
      seed = samplers.sanitize_seed(seed)
    else:
      force_stateless = False
      if self._seed_stream.original_seed is not None:
        warnings.warn(mcmc_util.SEED_CTOR_ARG_DEPRECATION_MSG)
      # Keep the raw stateful seed for the int-seed fallback below.
      stateful_seed = self._seed_stream()
      seed = samplers.sanitize_seed(stateful_seed)
    try:
      next_state_parts = self.new_state_fn(current_state_parts, seed)  # pylint: disable=not-callable
    except TypeError as e:
      if ('Expected int for argument' not in str(e) and
          TENSOR_SEED_MSG_PREFIX not in str(e)) or force_stateless:
        raise
      msg = (
          'Falling back to `int` seed for `new_state_fn` {}. Please update '
          'to use `tf.random.stateless_*` RNGs. '
          'This fallback may be removed after 10-Sep-2020. ({})')
      warnings.warn(msg.format(self.new_state_fn, str(e)))
      # No stateless seed was consumed, so the results record the zeros seed.
      seed = None
      next_state_parts = self.new_state_fn(  # pylint: disable=not-callable
          current_state_parts, stateful_seed)

    # Compute `target_log_prob` so it's available to MetropolisHastings.
    next_target_log_prob = self.target_log_prob_fn(*next_state_parts)  # pylint: disable=not-callable

    def maybe_flatten(x):
      # Return a single part when the caller passed a non-list state.
      return x if mcmc_util.is_list_like(current_state) else x[0]

    return [
        maybe_flatten(next_state_parts),
        UncalibratedRandomWalkResults(
            log_acceptance_correction=tf.zeros_like(
                next_target_log_prob),
            target_log_prob=next_target_log_prob,
            seed=samplers.zeros_seed() if seed is None else seed,
        ),
    ]
def __init__(self, parameter_prior, parameterized_initial_state_prior_fn, parameterized_transition_fn, parameterized_observation_fn, parameterized_initial_state_proposal_fn=None, parameterized_proposal_fn=None, parameter_constraining_bijector=None, name=None): """Builds an iterated filter for parameter estimation in sequential models. Iterated filtering is a parameter estimation method in which parameters are included in an augmented state space, with dynamics that introduce parameter perturbations, and a filtering algorithm such as particle filtering is run several times with perturbations of decreasing size. This class implements the IF2 algorithm of [Ionides et al., 2015][1], for which, under appropriate conditions (including a uniform prior) the final parameter distribution approaches a point mass at the maximum likelihood estimate. If a non-uniform prior is provided, the final parameter distribution will (under appropriate conditions) approach a point mass at the maximum a posteriori (MAP) value. This class augments the state space of a sequential model to include parameter perturbations, and provides utilities to run particle filtering on that augmented model. Alternately, the augmented components may be passed directly into a filtering algorithm of the user's choice. Args: parameter_prior: prior `tfd.Distribution` over parameters (may be a joint distribution). parameterized_initial_state_prior_fn: `callable` with signature `initial_state_prior = parameterized_initial_state_prior_fn(parameters)` where `parameters` has the form of a sample from `parameter_prior`, and `initial_state_prior` is a distribution over the initial state. parameterized_transition_fn: `callable` with signature `next_state_dist = parameterized_transition_fn( step, state, parameters, **kwargs)`. parameterized_observation_fn: `callable` with signature `observation_dist = parameterized_observation_fn( step, state, parameters, **kwargs)`. 
parameterized_initial_state_proposal_fn: optional `callable` with signature `initial_state_proposal = parameterized_initial_state_proposal_fn(parameters)` where `parameters` has the form of a sample from `parameter_prior`, and `initial_state_proposal` is a distribution over the initial state. parameterized_proposal_fn: optional `callable` with signature `next_state_dist = parameterized_transition_fn( step, state, parameters, **kwargs)`. Default value: `None`. parameter_constraining_bijector: optional `tfb.Bijector` instance such that `parameter_constraining_bijector.forward(x)` returns valid parameters for any real-valued `x` of the same structure and shape as `parameters`. If `None`, the default bijector of the provided `parameter_prior` will be used. Default value: `None`. name: `str` name for ops constructed by this object. Default value: `iterated_filter`. #### Example We'll walk through applying iterated filtering to a toy Susceptible-Infected-Recovered (SIR) model, a [compartmental model]( https://en.wikipedia.org/wiki/Compartmental_models_in_epidemiology#The_SIR_model) of infectious disease. Note that the model we use here is extremely simplified and is intended as a pedagogical example; it should not be interpreted to describe disease spread in the real world. We begin by specifying a prior distribution over the parameters to be inferred, thus defining the structure of the parameter space and the support of the parameters (which will imply a default constraining bijector). Here we'll use uniform priors over ranges that we expect to contain the parameters: ```python parameter_prior = tfd.JointDistributionNamed({ 'infection_rate': tfd.Uniform(low=0., high=3.), 'recovery_rate': tfd.Uniform(low=0., high=3.), }) ``` The model specification itself is identical to that used by `tfp.experimental.mcmc.infer_trajectories`, except that each component accepts an additional `parameters` keyword argument. We start by specifying a parameterized prior on initial states. 
In this case, our state includes the current number of susceptible and infected individuals (the third compartment, recovered individuals, is implicitly defined to include the remaining population). We'll also include, as auxiliary variables, the daily counts of new infections and new recoveries; these will help ensure that people shift consistently across compartments. ```python population_size = 1000 initial_state_prior_fn = lambda parameters: tfd.JointDistributionNamed({ 'new_infections': tfd.Poisson(parameters['infection_rate']), 'new_recoveries': tfd.Deterministic( tf.broadcast_to(0., tf.shape(parameters['recovery_rate']))), 'susceptible': (lambda new_infections: tfd.Deterministic(population_size - new_infections)), 'infected': (lambda new_infections: tfd.Deterministic(new_infections))}) ``` **Note**: the state prior must have the same batch shape as the passed-in parameters; equivalently, it must sample a full state for each parameter particle. If any part of the state prior does not depend on the parameters, you must manually ensure that it has the appropriate batch shape. For example, in the definition of `new_recoveries` above, applying `broadcast_to` with the shape of a parameter ensures that the batch shape is maintained. Next, we specify a transition model. This takes the state at the previous day, along with parameters, and returns a distribution over the state for the current day. 
```python def parameterized_infection_dynamics(_, previous_state, parameters): new_infections = tfd.Poisson( parameters['infection_rate'] * previous_state['infected'] * previous_state['susceptible'] / population_size) new_recoveries = tfd.Poisson( previous_state['infected'] * parameters['recovery_rate']) return tfd.JointDistributionNamed({ 'new_infections': new_infections, 'new_recoveries': new_recoveries, 'susceptible': lambda new_infections: tfd.Deterministic( tf.maximum(0., previous_state['susceptible'] - new_infections)), 'infected': lambda new_infections, new_recoveries: tfd.Deterministic( tf.maximum(0., (previous_state['infected'] + new_infections - new_recoveries)))}) ``` Finally, assume that every day we get to observe noisy counts of new infections and recoveries. ```python def parameterized_infection_observations(_, state, parameters): del parameters # Not used. return tfd.JointDistributionNamed({ 'new_infections': tfd.Poisson(state['new_infections'] + 0.1), 'new_recoveries': tfd.Poisson(state['new_recoveries'] + 0.1)}) ``` Combining these components, an `IteratedFilter` augments the state space to include parameters that may change over time. ```python iterated_filter = tfp.experimental.sequential.IteratedFilter( parameter_prior=parameter_prior, parameterized_initial_state_prior_fn=initial_state_prior_fn, parameterized_transition_fn=parameterized_infection_dynamics, parameterized_observation_fn=parameterized_infection_observations) ``` We may then run the filter to estimate parameters from a series of observations: ```python # Simulated with `infection_rate=1.2` and `recovery_rate=0.1`. 
observed_values = { 'new_infections': tf.convert_to_tensor([ 2., 7., 14., 24., 45., 93., 160., 228., 252., 158., 17., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]), 'new_recoveries': tf.convert_to_tensor([ 0., 0., 3., 4., 3., 8., 12., 31., 49., 73., 85., 65., 71., 58., 42., 65., 36., 31., 32., 27., 31., 20., 19., 19., 14., 27.]) } parameter_particles = iterated_filter.estimate_parameters( observations=observed_values, num_iterations=20, num_particles=4096, initial_perturbation_scale=1.0, cooling_schedule=( tfp.experimental.sequential.geometric_cooling_schedule( 0.001, k=20)), seed=test_util.test_seed()) print('Mean of parameter particles from final iteration: {}'.format( tf.nest.map_structure(lambda x: tf.reduce_mean(x[-1], axis=0), parameter_particles))) print('Standard deviation of parameter particles from ' 'final iteration: {}'.format( tf.nest.map_structure(lambda x: tf.math.reduce_std(x[-1], axis=0), parameter_particles))) ``` For more control, we could alternately choose to run filtering iterations on the augmented model manually, using the filter of our choice. For example, manually invoking `infer_trajectories` would allow us to inspect the parameter and state values at all timesteps, and their corresponding log-probabilities: ```python trajectories, lps = tfp.experimental.mcmc.infer_trajectories( observations=observations, initial_state_prior=iterated_filter.joint_initial_state_prior, transition_fn=functools.partial( iterated_filter.joint_transition_fn, perturbation_scale=perturbation_scale), observation_fn=iterated_filter.joint_observation_fn, proposal_fn=iterated_filter.joint_proposal_fn, initial_state_proposal=iterated_filter.joint_initial_state_proposal( initial_unconstrained_parameters), num_particles=4096) ``` #### References: [1] Edward L. Ionides, Dao Nguyen, Yves Atchade, Stilian Stoev, and Aaron A. King. Inference for dynamic and latent variable models via iterated, perturbed Bayes maps. 
_Proceedings of the National Academy of Sciences_ 112, no. 3: 719-724, 2015. https://www.pnas.org/content/pnas/112/3/719.full.pdf """ name = name or 'IteratedFilter' with tf.name_scope(name): self._parameter_prior = parameter_prior self._parameterized_initial_state_prior_fn = ( parameterized_initial_state_prior_fn) if parameter_constraining_bijector is None: parameter_constraining_bijector = ( parameter_prior.experimental_default_event_space_bijector( )) self._parameter_constraining_bijector = parameter_constraining_bijector # Augment the prior to include both parameters and states. self._joint_initial_state_prior = joint_prior_on_parameters_and_state( parameter_prior, parameterized_initial_state_prior_fn, parameter_constraining_bijector, prior_is_constrained=True) # Check that prior samples have a consistent number of particles. # TODO(davmre): remove the need for dummy shape dependencies, # and this check, by using `JointDistributionNamedAutoBatched` with # auto-vectorization enabled in `joint_prior_on_parameters_and_state`. num_particles_canary = 13 canary_seed = samplers.zeros_seed() def _get_shape_1(x): if hasattr(x, 'state'): x = x.state return tf.TensorShape(x.shape[1:2]) prior_static_sample_shapes = tf.nest.map_structure( # Sample shape [0, num_particles_canary] particles (size will be zero) # then trim off the leading 0 and (possibly) any event shape. # We expect shape [num_particles_canary] to remain. _get_shape_1, self._joint_initial_state_prior.sample( [0, num_particles_canary], seed=canary_seed)) if not all([ tensorshape_util.is_compatible_with( s[:1], [num_particles_canary]) for s in tf.nest.flatten(prior_static_sample_shapes) ]): raise ValueError( 'The specified prior does not generate consistent ' 'shapes when sampled. Please verify that all parts of ' '`initial_state_prior_fn` have batch shape matching ' 'that of the parameters. 
This may require creating ' '"dummy" dependencies on parameters; for example: ' '`tf.broadcast_to(value, tf.shape(parameter))`. (in a ' f'test sample with {num_particles_canary} particles, we expected ' 'all) values to have shape compatible with ' f'[{num_particles_canary}, ...]; ' f'saw shapes {prior_static_sample_shapes})') # Augment the transition and observation fns to cover both # parameters and states. self._joint_transition_fn = augment_transition_fn_with_parameters( parameter_prior, parameterized_transition_fn, parameter_constraining_bijector) self._joint_observation_fn = augment_observation_fn_with_parameters( parameterized_observation_fn, parameter_constraining_bijector) # If given a proposal for the initial state, augment it into a joint # proposal over parameters and states. joint_initial_state_proposal = None if parameterized_initial_state_proposal_fn: joint_initial_state_proposal = joint_prior_on_parameters_and_state( parameter_prior, parameterized_initial_state_proposal_fn, parameter_constraining_bijector) else: parameterized_initial_state_proposal_fn = ( parameterized_initial_state_prior_fn) self._joint_initial_state_proposal = joint_initial_state_proposal self._parameterized_initial_state_proposal_fn = ( parameterized_initial_state_proposal_fn) # If given a conditional proposal fn (for non-initial states), augment # it to be joint over states and parameters. self._joint_proposal_fn = None if parameterized_proposal_fn: self._joint_proposal_fn = augment_transition_fn_with_parameters( parameter_prior, parameterized_proposal_fn, parameter_constraining_bijector) self._batch_ndims = tf.nest.map_structure( ps.rank_from_shape, parameter_prior.batch_shape_tensor()) self._name = name
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Advances the chain by one Metropolis-Hastings accept/reject step.

  The inner kernel proposes a new state. The proposal is then accepted with
  probability `min(1, exp(log_accept_ratio))`, where `log_accept_ratio` is
  the difference of proposed and previously accepted `target_log_prob`,
  plus the inner kernel's `log_acceptance_correction` when it has one.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
      `list` of `Tensor`s representing internal calculations made within the
      previous call to this function (or as returned by `bootstrap_results`).
    seed: Optional, a seed for reproducible sampling.

  Returns:
    next_state: `Tensor` or Python `list` of `Tensor`s representing the
      next state(s) of the Markov chain(s).
    kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
      `Tensor`s representing internal calculations made within this function.

  Raises:
    ValueError: if the `inner_kernel` results do not contain the member
      "target_log_prob".
  """
  # TODO(b/159636942): Clean up after 2020-09-20.
  inner_kwargs = {}
  if seed is None:
    # Legacy path: no per-call seed was given, so draw the acceptance seed
    # from the constructor-provided seed stream (warning if one was set).
    if self._seed_stream.original_seed is not None:
      warnings.warn(mcmc_util.SEED_CTOR_ARG_DEPRECATION_MSG)
    acceptance_seed = samplers.sanitize_seed(self._seed_stream())
  else:
    # Sanitize once; this exact value is echoed back in `kernel_results.seed`.
    seed = samplers.sanitize_seed(seed)
    proposal_seed, acceptance_seed = samplers.split_seed(seed)
    inner_kwargs['seed'] = proposal_seed

  with tf.name_scope(mcmc_util.make_name(self.name, 'mh', 'one_step')):
    prior_results = previous_kernel_results.accepted_results

    # Ask the inner kernel for a proposal.
    proposed_state, proposed_results = self.inner_kernel.one_step(
        current_state, prior_results, **inner_kwargs)

    if not (has_target_log_prob(proposed_results) and
            has_target_log_prob(prior_results)):
      raise ValueError('"target_log_prob" must be a member of '
                       '`inner_kernel` results.')

    # log(acceptance ratio) = proposed lp - accepted lp [+ correction].
    log_prob_terms = [
        proposed_results.target_log_prob,
        -prior_results.target_log_prob,
    ]
    try:
      correction = proposed_results.log_acceptance_correction
      # An empty list/tuple means "no correction". Anything else (notably a
      # `Tensor`, which must not be coerced to `bool`) is summed in.
      if not mcmc_util.is_list_like(correction) or correction:
        log_prob_terms.append(correction)
    except AttributeError:
      warnings.warn(
          'Supplied inner `TransitionKernel` does not have a '
          '`log_acceptance_correction`. Assuming its value is `0.`')
    log_accept_ratio = mcmc_util.safe_sum(
        log_prob_terms, name='compute_log_accept_ratio')

    # Accept iff u < min(1, accept_ratio) with u ~ Uniform[0, 1), which is
    # equivalent to log(u) < log_accept_ratio. Proposals that raise the
    # (corrected) log-probability are therefore always accepted.
    proposed_lp = proposed_results.target_log_prob
    uniform_draw = samplers.uniform(
        shape=prefer_static.shape(proposed_lp),
        dtype=dtype_util.base_dtype(proposed_lp.dtype),
        seed=acceptance_seed)
    is_accepted = tf.math.log(uniform_draw) < log_accept_ratio

    next_state = mcmc_util.choose(
        is_accepted, proposed_state, current_state, name='choose_next_state')

    # Seeds are stripped from the proposal before choosing, because a seed is
    # not a per-chain value. A per-chain choice between the previously
    # accepted seed and the proposed seed would be meaningless.
    accepted_results = mcmc_util.choose(
        is_accepted,
        mcmc_util.strip_seeds(proposed_results),
        prior_results,
        name='choose_inner_results')

    kernel_results = MetropolisHastingsKernelResults(
        accepted_results=accepted_results,
        is_accepted=is_accepted,
        log_accept_ratio=log_accept_ratio,
        proposed_state=proposed_state,
        proposed_results=proposed_results,
        extra=[],
        seed=samplers.zeros_seed() if seed is None else seed,
    )
    return next_state, kernel_results
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Runs one iteration of the No-U-Turn Sampler.

  A trajectory is started from `current_state` and the tree is doubled
  inside a `tf.while_loop`. The loop runs until either `max_tree_depth`
  doublings have happened or no chain wants to continue growing its tree.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s). A non-nested `Tensor` is
      wrapped in a list internally and unwrapped again on return.
    previous_kernel_results: Kernel results from the previous call to this
      function (or from `bootstrap_results`). Its `target_log_prob`,
      `grads_target_log_prob`, `step_size` and `momentum_state_memory`
      fields are read here.
    seed: Optional, a seed for reproducible sampling.

  Returns:
    next_state: The new state, with the same nesting as `current_state`.
    kernel_results: `NUTSKernelResults` describing this step.
  """
  # TODO(b/159636942): Clean up after 2020-09-20.
  if seed is not None:
    # Sanitize up front so that `kernel_results.seed` (set below) has the
    # same sanitized form as the `samplers.zeros_seed()` fallback. Without
    # this, it would hold whatever raw value (e.g. a Python int) the caller
    # passed. Splitting an already-sanitized seed yields the same sub-seeds.
    seed = samplers.sanitize_seed(seed)
    start_trajectory_seed, loop_seed = samplers.split_seed(
        seed, salt='nuts.one_step')
  else:
    if self._seed_stream.original_seed is not None:
      warnings.warn(mcmc_util.SEED_CTOR_ARG_DEPRECATION_MSG)
    start_trajectory_seed, loop_seed = samplers.split_seed(
        self._seed_stream(), salt='nuts.one_step')

  with tf.name_scope(self.name + '.one_step'):
    unwrap_state_list = not tf.nest.is_nested(current_state)
    if unwrap_state_list:
      current_state = [current_state]

    current_target_log_prob = previous_kernel_results.target_log_prob
    [
        init_momentum,
        init_energy,
        log_slice_sample
    ] = self._start_trajectory_batched(current_state, current_target_log_prob,
                                       seed=start_trajectory_seed)

    def _copy(v):
      # Multiplies `v` by ones of shape [2] + [1] * rank(v). This broadcasts
      # in a new leading dimension of size 2 holding two copies of `v`.
      return v * ps.ones(
          ps.pad([2], paddings=[[0, ps.rank(v)]], constant_values=1),
          dtype=v.dtype)

    initial_state = TreeDoublingState(
        momentum=init_momentum,
        state=current_state,
        target=current_target_log_prob,
        target_grad_parts=previous_kernel_results.grads_target_log_prob)
    initial_step_state = tf.nest.map_structure(_copy, initial_state)

    if MULTINOMIAL_SAMPLE:
      init_weight = tf.zeros_like(init_energy)  # log(exp(H0 - H0))
    else:
      init_weight = tf.ones_like(init_energy, dtype=TREE_COUNT_DTYPE)

    candidate_state = TreeDoublingStateCandidate(
        state=current_state,
        target=current_target_log_prob,
        target_grad_parts=previous_kernel_results.grads_target_log_prob,
        energy=init_energy,
        weight=init_weight)
    initial_step_metastate = TreeDoublingMetaState(
        candidate_state=candidate_state,
        is_accepted=tf.zeros_like(init_energy, dtype=tf.bool),
        momentum_sum=init_momentum,
        energy_diff_sum=tf.zeros_like(init_energy),
        leapfrog_count=tf.zeros_like(init_energy, dtype=TREE_COUNT_DTYPE),
        continue_tree=tf.ones_like(init_energy, dtype=tf.bool),
        not_divergence=tf.ones_like(init_energy, dtype=tf.bool))

    # Convert the write/read instruction into TensorArray so that it is
    # compatible with XLA.
    write_instruction = tf.TensorArray(
        TREE_COUNT_DTYPE,
        size=len(self._write_instruction),
        clear_after_read=False).unstack(self._write_instruction)
    read_instruction = tf.TensorArray(
        tf.int32,
        size=len(self._read_instruction),
        clear_after_read=False).unstack(self._read_instruction)

    current_step_meta_info = OneStepMetaInfo(
        log_slice_sample=log_slice_sample,
        init_energy=init_energy,
        write_instruction=write_instruction,
        read_instruction=read_instruction)

    # Loop variables are (iteration, seed, step state, metastate). The lambda
    # argument is named `iter_seed` so it does not shadow the outer `seed`.
    _, _, _, new_step_metastate = tf.while_loop(
        cond=lambda iter_, iter_seed, state, metastate: (  # pylint: disable=g-long-lambda
            (iter_ < self.max_tree_depth) &
            tf.reduce_any(metastate.continue_tree)),
        body=lambda iter_, iter_seed, state, metastate: self._loop_tree_doubling(  # pylint: disable=g-long-lambda
            previous_kernel_results.step_size,
            previous_kernel_results.momentum_state_memory,
            current_step_meta_info,
            iter_,
            state,
            metastate,
            iter_seed),
        loop_vars=(
            tf.zeros([], dtype=tf.int32, name='iter'),
            loop_seed,
            initial_step_state,
            initial_step_metastate),
        parallel_iterations=self.parallel_iterations,
    )

    kernel_results = NUTSKernelResults(
        target_log_prob=new_step_metastate.candidate_state.target,
        grads_target_log_prob=(
            new_step_metastate.candidate_state.target_grad_parts),
        momentum_state_memory=previous_kernel_results.momentum_state_memory,
        step_size=previous_kernel_results.step_size,
        # Log of `energy_diff_sum` averaged over the leapfrog steps taken.
        log_accept_ratio=tf.math.log(
            new_step_metastate.energy_diff_sum /
            tf.cast(new_step_metastate.leapfrog_count,
                    dtype=new_step_metastate.energy_diff_sum.dtype)),
        leapfrogs_taken=(
            new_step_metastate.leapfrog_count * self.unrolled_leapfrog_steps),
        is_accepted=new_step_metastate.is_accepted,
        # A chain still flagged `continue_tree` after the loop exits was
        # stopped only by the `max_tree_depth` bound.
        reach_max_depth=new_step_metastate.continue_tree,
        has_divergence=~new_step_metastate.not_divergence,
        energy=new_step_metastate.candidate_state.energy,
        seed=samplers.zeros_seed() if seed is None else seed,
    )

    result_state = new_step_metastate.candidate_state.state
    if unwrap_state_list:
      result_state = result_state[0]

    return result_state, kernel_results