def _start_trajectory_batched(self, state, target_log_prob, seed):
  """Computations needed to start a trajectory."""
  with tf.name_scope('start_trajectory_batched'):
    # One seed per state part (momentum draws) plus a final one reserved for
    # the slice variable drawn below.
    seeds = samplers.split_seed(seed, n=len(state) + 1)
    momentum_seeds = distribute_lib.fold_in_axis_index(
        seeds[:-1], self.experimental_shard_axis_names)

    # Sample a fresh momentum for every state part, matching its shape/dtype.
    momentum = []
    for i, part in enumerate(state):
      momentum.append(
          samplers.normal(
              shape=ps.shape(part), dtype=part.dtype, seed=momentum_seeds[i]))

    init_energy = compute_hamiltonian(
        target_log_prob, momentum,
        shard_axis_names=self.experimental_shard_axis_names)

    # The multinomial-sampling variant needs no slice variable.
    if MULTINOMIAL_SAMPLE:
      return momentum, init_energy, None

    # Slice-sampling variant: draw u ~ Uniform(0, p(initial state, initial
    # momentum)) and keep log u. For numerical stability this is done in log
    # space: log u = log u' + log p(...) with u' ~ Uniform(0, 1), and
    # log1p(-uniform) is a stable way to compute log u'.
    log_slice_sample = tf.math.log1p(-samplers.uniform(
        shape=ps.shape(init_energy),
        dtype=init_energy.dtype,
        seed=seeds[len(state)]))
    return momentum, init_energy, log_slice_sample
def _choose_random_direction(current_state_parts, batch_rank, seed=None,
                             experimental_shard_axis_names=None):
  """Chooses a random direction in the event space."""
  seeds = list(samplers.split_seed(seed, n=len(current_state_parts)))
  seeds = distribute_lib.fold_in_axis_index(
      seeds, experimental_shard_axis_names)

  # For each state part, draw a uniformly random unit vector over the
  # flattened event dimensions, then restore the part's original shape.
  directions = []
  for state_part, part_seed in zip(current_state_parts, seeds):
    part_shape = ps.shape(state_part)
    batch_shape = part_shape[:batch_rank]
    event_size = ps.reduce_prod(part_shape[batch_rank:])
    flat_direction = random_ops.spherical_uniform(
        shape=batch_shape,
        dimension=event_size,
        dtype=state_part.dtype,
        seed=part_seed)
    directions.append(ps.reshape(flat_direction, part_shape))
  return directions
def _fn(state_parts, seed, experimental_shard_axis_names=None):
  """Adds a normal perturbation to the input state.

  Args:
    state_parts: A list of `Tensor`s of any shape and real dtype representing
      the state parts of the `current_state` of the Markov chain.
    seed: `int` or None. The random seed for this `Op`. If `None`, no seed is
      applied.
    experimental_shard_axis_names: A structure of string names indicating how
      members of the state are sharded.

  Returns:
    perturbed_state_parts: A Python `list` of The `Tensor`s. Has the same
      shape and type as the `state_parts`.

  Raises:
    ValueError: if `scale` does not broadcast with `state_parts`.
  """
  with tf.name_scope(name or 'random_walk_normal_fn'):
    scales = scale if mcmc_util.is_list_like(scale) else [scale]
    if len(scales) == 1:
      # Broadcast a single scale to every state part. Build a fresh list
      # here: the previous `scales *= len(state_parts)` mutated a
      # caller-supplied `scale` list in place.
      scales = list(scales) * len(state_parts)
    if len(state_parts) != len(scales):
      raise ValueError('`scale` must broadcast with `state_parts`.')

    # One independent seed per state part, folded with any shard axes so
    # different shards draw different noise.
    part_seeds = samplers.split_seed(seed, n=len(state_parts))
    part_seeds = distribute_lib.fold_in_axis_index(
        part_seeds, experimental_shard_axis_names)

    next_state_parts = [
        samplers.normal(  # pylint: disable=g-complex-comprehension
            mean=state_part,
            stddev=scale_part,
            shape=ps.shape(state_part),
            dtype=dtype_util.base_dtype(state_part.dtype),
            seed=seed_part)
        for scale_part, state_part, seed_part
        in zip(scales, state_parts, part_seeds)
    ]

    return next_state_parts
def _fn(state_parts, seed, experimental_shard_axis_names=None):
  """Adds a uniform perturbation to the input state.

  Args:
    state_parts: A list of `Tensor`s of any shape and real dtype representing
      the state parts of the `current_state` of the Markov chain.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
    experimental_shard_axis_names: A structure of string names indicating how
      members of the state are sharded.

  Returns:
    perturbed_state_parts: A Python `list` of The `Tensor`s. Has the same
      shape and type as the `state_parts`.

  Raises:
    ValueError: if `scale` does not broadcast with `state_parts`.
  """
  with tf.name_scope(name or 'random_walk_uniform_fn'):
    scales = scale if mcmc_util.is_list_like(scale) else [scale]
    if len(scales) == 1:
      # Broadcast a single scale to every state part. Build a fresh list
      # here: the previous `scales *= len(state_parts)` mutated a
      # caller-supplied `scale` list in place.
      scales = list(scales) * len(state_parts)
    if len(state_parts) != len(scales):
      raise ValueError('`scale` must broadcast with `state_parts`.')

    # One independent seed per state part, folded with any shard axes so
    # different shards draw different noise.
    part_seeds = list(samplers.split_seed(seed, n=len(state_parts)))
    part_seeds = distribute_lib.fold_in_axis_index(
        part_seeds, experimental_shard_axis_names)

    next_state_parts = [
        samplers.uniform(  # pylint: disable=g-complex-comprehension
            minval=state_part - scale_part,
            maxval=state_part + scale_part,
            shape=tf.shape(state_part),
            dtype=dtype_util.base_dtype(state_part.dtype),
            seed=seed_part)
        for scale_part, state_part, seed_part
        in zip(scales, state_parts, part_seeds)
    ]

    return next_state_parts
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Runs one MALA (Euler-Maruyama) step from `current_state`.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current chain state (a single part is wrapped/unwrapped transparently
      via `maybe_flatten`).
    previous_kernel_results: `UncalibratedLangevinKernelResults` from the
      previous call (or bootstrap), providing cached log-prob, gradients,
      volatility and drift.
    seed: PRNG seed; sanitized here and stored in the returned results.

  Returns:
    A two-element list: the next state (flattened back to a single `Tensor`
    if `current_state` was not list-like) and a new
    `UncalibratedLangevinKernelResults`.
  """
  with tf.name_scope(mcmc_util.make_name(self.name, 'mala', 'one_step')):
    with tf.name_scope('initialize'):
      # Prepare input arguments to be passed to `_euler_method`.
      # Unused slots (gradients) are cached values not needed for the
      # forward proposal itself.
      [
          current_state_parts,
          step_size_parts,
          current_target_log_prob,
          _,  # grads_target_log_prob
          current_volatility_parts,
          _,  # grads_volatility
          current_drift_parts,
      ] = _prepare_args(
          self.target_log_prob_fn,
          self.volatility_fn,
          current_state,
          self.step_size,
          previous_kernel_results.target_log_prob,
          previous_kernel_results.grads_target_log_prob,
          previous_kernel_results.volatility,
          previous_kernel_results.grads_volatility,
          previous_kernel_results.diffusion_drift,
          self.parallel_iterations)

      seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
      # One seed per state part; fold in shard axes so sharded chains draw
      # independent noise.
      seeds = list(
          samplers.split_seed(
              seed, n=len(current_state_parts), salt='langevin.one_step'))
      seeds = distribute_lib.fold_in_axis_index(
          seeds, self.experimental_shard_axis_names)

      # Standard-normal draws, one per state part, driving the diffusion.
      random_draw_parts = []
      for state_part, part_seed in zip(current_state_parts, seeds):
        random_draw_parts.append(
            samplers.normal(
                shape=ps.shape(state_part),
                dtype=dtype_util.base_dtype(state_part.dtype),
                seed=part_seed))

    # Number of independent chains run by the algorithm.
    independent_chain_ndims = ps.rank(current_target_log_prob)

    # Generate the next state of the algorithm using Euler-Maruyama method.
    next_state_parts = _euler_method(random_draw_parts,
                                     current_state_parts,
                                     current_drift_parts,
                                     step_size_parts,
                                     current_volatility_parts)

    # Compute helper `UncalibratedLangevinKernelResults` to be processed by
    # `_compute_log_acceptance_correction` and in the next iteration of
    # `one_step` function. Only target log-prob, volatility and drift at the
    # proposed state are needed here.
    [
        _,  # state_parts
        _,  # step_sizes
        next_target_log_prob,
        next_grads_target_log_prob,
        next_volatility_parts,
        next_grads_volatility,
        next_drift_parts,
    ] = _prepare_args(
        self.target_log_prob_fn,
        self.volatility_fn,
        next_state_parts,
        step_size_parts,
        parallel_iterations=self.parallel_iterations)

    def maybe_flatten(x):
      # Mirror the caller's structure: unwrap single-part states.
      return x if mcmc_util.is_list_like(current_state) else x[0]

    # Decide whether to compute the acceptance ratio.
    # NOTE(review): both branch values are materialized before `tf.cond`, so
    # the correction tensor is built unconditionally; the cond only selects
    # which value flows onward.
    log_acceptance_correction_compute = _compute_log_acceptance_correction(
        current_state_parts,
        next_state_parts,
        current_volatility_parts,
        next_volatility_parts,
        current_drift_parts,
        next_drift_parts,
        step_size_parts,
        independent_chain_ndims,
        experimental_shard_axis_names=self.experimental_shard_axis_names)
    log_acceptance_correction_skip = tf.zeros_like(next_target_log_prob)

    log_acceptance_correction = tf.cond(
        pred=self.compute_acceptance,
        true_fn=lambda: log_acceptance_correction_compute,
        false_fn=lambda: log_acceptance_correction_skip)

    return [
        maybe_flatten(next_state_parts),
        UncalibratedLangevinKernelResults(
            log_acceptance_correction=log_acceptance_correction,
            target_log_prob=next_target_log_prob,
            grads_target_log_prob=next_grads_target_log_prob,
            volatility=maybe_flatten(next_volatility_parts),
            grads_volatility=next_grads_volatility,
            diffusion_drift=next_drift_parts,
            seed=seed,
        ),
    ]
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Runs one HMC step: sample momentum, leapfrog-integrate, record results.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current chain state (a single part is wrapped/unwrapped transparently
      via `maybe_flatten`).
    previous_kernel_results: Kernel results from the previous call, providing
      cached target log-prob/gradients and, when
      `_store_parameters_in_results` is set, the step size and leapfrog count.
    seed: PRNG seed; sanitized here and stored in the returned results.

  Returns:
    A `(next_state, new_kernel_results)` tuple; `next_state` is flattened
    back to a single `Tensor` if `current_state` was not list-like.
  """
  with tf.name_scope(mcmc_util.make_name(self.name, 'hmc', 'one_step')):
    # Parameters may live in the kernel results (so adaptation can update
    # them) or on the kernel itself.
    if self._store_parameters_in_results:
      step_size = previous_kernel_results.step_size
      num_leapfrog_steps = previous_kernel_results.num_leapfrog_steps
    else:
      step_size = self.step_size
      num_leapfrog_steps = self.num_leapfrog_steps

    [
        current_state_parts,
        step_sizes,
        current_target_log_prob,
        current_target_log_prob_grad_parts,
    ] = _prepare_args(
        self.target_log_prob_fn,
        current_state,
        step_size,
        previous_kernel_results.target_log_prob,
        previous_kernel_results.grads_target_log_prob,
        maybe_expand=True,
        state_gradients_are_stopped=self.state_gradients_are_stopped)

    seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
    # One seed per state part; fold in shard axes so sharded chains draw
    # independent momenta.
    seeds = samplers.split_seed(seed, n=len(current_state_parts))
    seeds = distribute_lib.fold_in_axis_index(
        seeds, self.experimental_shard_axis_names)

    # Fresh Gaussian momentum for every state part.
    current_momentum_parts = []
    for part_seed, x in zip(seeds, current_state_parts):
      current_momentum_parts.append(
          samplers.normal(
              shape=ps.shape(x),
              dtype=self._momentum_dtype or dtype_util.base_dtype(x.dtype),
              seed=part_seed))

    integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
        self.target_log_prob_fn, step_sizes, num_leapfrog_steps)

    # Simulate Hamiltonian dynamics; the integrator also returns the target
    # log-prob and its gradients at the proposed state so they can be cached.
    [
        next_momentum_parts,
        next_state_parts,
        next_target_log_prob,
        next_target_log_prob_grad_parts,
    ] = integrator(current_momentum_parts,
                   current_state_parts,
                   current_target_log_prob,
                   current_target_log_prob_grad_parts)
    if self.state_gradients_are_stopped:
      next_state_parts = [tf.stop_gradient(x) for x in next_state_parts]

    def maybe_flatten(x):
      # Mirror the caller's structure: unwrap single-part states.
      return x if mcmc_util.is_list_like(current_state) else x[0]

    independent_chain_ndims = ps.rank(current_target_log_prob)

    new_kernel_results = previous_kernel_results._replace(
        log_acceptance_correction=_compute_log_acceptance_correction(
            current_momentum_parts,
            next_momentum_parts,
            independent_chain_ndims,
            shard_axis_names=self.experimental_shard_axis_names),
        target_log_prob=next_target_log_prob,
        grads_target_log_prob=next_target_log_prob_grad_parts,
        initial_momentum=current_momentum_parts,
        final_momentum=next_momentum_parts,
        seed=seed,
    )

    return maybe_flatten(next_state_parts), new_kernel_results
def _sample_n(self, n, seed, **kwargs):
  """Draws `n` samples, folding the shard axis index into the seed."""
  sanitized = samplers.sanitize_seed(seed, salt='sharded_sample')
  # Fold in the shard axis index so each shard draws distinct samples.
  sharded_seed = distribute_lib.fold_in_axis_index(
      sanitized, self.experimental_shard_axis_names)
  return self.distribution.sample(
      sample_shape=n, seed=sharded_seed, **kwargs)
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Runs the inner kernel one step with a chain-shard-specific seed."""
  sanitized = samplers.sanitize_seed(seed, salt='sharded_kernel')
  # Fold in the chain axis index so each shard advances with its own seed.
  sharded_seed = distribute_lib.fold_in_axis_index(
      sanitized, self.chain_axis_names)
  return self.inner_kernel.one_step(
      current_state, previous_kernel_results, seed=sharded_seed)