def _sample_n(self, n, seed=None):
  """Draws `n` samples from the triangular distribution by inverse-CDF.

  The CDF is piecewise quadratic, so each branch is inverted with a square
  root; which branch applies is decided by comparing the uniform draw with
  the CDF value at the peak.

  Args:
    n: Number of samples to draw.
    seed: PRNG seed; sanitized with salt 'triangular'.

  Returns:
    Tensor of samples with shape `[n] + batch_shape`.
  """
  low, high, peak = (
      tf.convert_to_tensor(self.low),
      tf.convert_to_tensor(self.high),
      tf.convert_to_tensor(self.peak))

  seed = samplers.sanitize_seed(seed, salt='triangular')
  batch_shape = self._batch_shape_tensor(low=low, high=high, peak=peak)
  sample_shape = ps.concat([[n], batch_shape], axis=0)
  u = samplers.uniform(shape=sample_shape, dtype=self.dtype, seed=seed)

  width = high - low
  left_width = peak - low
  right_width = high - peak

  # Plugging x = peak into the left-branch CDF
  # (x - low)**2 / ((high - low) * (peak - low)) gives
  # (peak - low) / (high - low): draws below it invert the left branch.
  cdf_at_peak = left_width / width
  # Inverse of (x - low)**2 / ((high - low) * (peak - low)).
  left_branch = low + tf.sqrt(u * width * left_width)
  # Inverse of 1 - (high - x)**2 / ((high - low) * (high - peak)).
  right_branch = high - tf.sqrt((1. - u) * width * right_width)
  return tf.where(u < cdf_at_peak, left_branch, right_branch)
def _random_gamma_no_gradient(
    shape, concentration, rate, log_rate, seed, log_space):
  """Draws gamma samples, selecting a CPU-specialized implementation.

  Dispatch goes through `implementation_selection.implementation_selecting`
  with `_random_gamma_cpu` as the CPU function (specialized to
  `stateless_gamma`) and `_random_gamma_noncpu` as the default.

  Args:
    shape: Sample shape.
    concentration: Concentration of the gamma distribution.
    rate: Rate parameter of the gamma distribution.
    log_rate: Log-rate parameter of the gamma distribution.
    seed: int or Tensor seed; sanitized before use.
    log_space: If `True`, draw log-of-gamma samples.

  Returns:
    samples: Samples from the gamma distributions.
  """
  sanitized_seed = samplers.sanitize_seed(seed)
  select_impl = implementation_selection.implementation_selecting
  gamma_sampler = select_impl(
      fn_name='gamma',
      default_fn=_random_gamma_noncpu,
      cpu_fn=_random_gamma_cpu)
  return gamma_sampler(
      shape=shape,
      concentration=concentration,
      rate=rate,
      log_rate=log_rate,
      seed=sanitized_seed,
      log_space=log_space)
def random_von_mises(shape, concentration, dtype=tf.float32, seed=None):
  """Samples from the standardized von Mises distribution.

  Draws from vonMises(loc=0, concentration=concentration), i.e. with zero
  mean; shift the samples by a location to obtain other means. Sampling is
  rejection sampling with a wrapped Cauchy proposal [1], and the samples are
  pathwise differentiable via implicit reparameterization [2].

  Args:
    shape: The output sample shape.
    concentration: The concentration parameter of the von Mises distribution.
    dtype: The data type of concentration and the outputs.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.

  Returns:
    Differentiable samples of standardized von Mises.

  References:
    [1] Luc Devroye "Non-Uniform Random Variate Generation", Springer-Verlag,
        1986; Chapter 9, p. 473-476.
        http://www.nrbook.com/devroye/Devroye_files/chapter_nine.pdf
        + corrections http://www.nrbook.com/devroye/Devroye_files/errors.pdf
    [2] Michael Figurnov, Shakir Mohamed, Andriy Mnih.
        "Implicit Reparameterization Gradients", 2018.
  """
  sample_shape = ps.convert_to_shape_tensor(
      shape, dtype_hint=tf.int32, name='shape')
  von_mises_seed = samplers.sanitize_seed(seed, salt='von_mises')
  concentration_tensor = tf.convert_to_tensor(
      concentration, dtype=dtype, name='concentration')
  return _von_mises_sample_with_gradient(
      sample_shape, concentration_tensor, von_mises_seed)
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Proposes a new state with `new_state_fn` (uncalibrated random walk).

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s with the current
      chain state.
    previous_kernel_results: Unused by this method.
    seed: PRNG seed; sanitized and stored in the results for diagnostics.

  Returns:
    A two-element list: the proposed state (a single `Tensor` unless
    `current_state` is list-like) and `UncalibratedRandomWalkResults` with a
    zero `log_acceptance_correction`.
  """
  state_is_list = mcmc_util.is_list_like(current_state)
  with tf.name_scope(mcmc_util.make_name(self.name, 'rwm', 'one_step')):
    with tf.name_scope('initialize'):
      raw_parts = list(current_state) if state_is_list else [current_state]
      current_state_parts = [
          tf.convert_to_tensor(part, name='current_state')
          for part in raw_parts
      ]

    seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
    next_state_parts = self.new_state_fn(current_state_parts, seed)  # pylint: disable=not-callable

    # `new_state_fn` must not change the state's size; pinning the static
    # shape makes a mismatch fail noisily.
    for proposed, original in zip(next_state_parts, current_state_parts):
      tensorshape_util.set_shape(proposed, original.shape)

    # Computed here so MetropolisHastings can read it from the results.
    next_target_log_prob = self.target_log_prob_fn(*next_state_parts)  # pylint: disable=not-callable

    proposed_state = next_state_parts if state_is_list else next_state_parts[0]
    return [
        proposed_state,
        UncalibratedRandomWalkResults(
            log_acceptance_correction=tf.zeros_like(next_target_log_prob),
            target_log_prob=next_target_log_prob,
            seed=seed,
        ),
    ]
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Advances the chain by one (uncalibrated) Langevin step.

  The proposal is generated with the Euler-Maruyama method. Per-part normal
  draws use seeds folded in along `self.experimental_shard_axis_names`.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s with the current
      chain state.
    previous_kernel_results: Results from the previous step; its
      `target_log_prob`, `grads_target_log_prob`, `volatility`,
      `grads_volatility` and `diffusion_drift` fields are read.
    seed: PRNG seed; sanitized and stored in the results for diagnostics.

  Returns:
    A two-element list: the proposed state (a single `Tensor` unless
    `current_state` is list-like) and `UncalibratedLangevinKernelResults`.
  """
  prev = previous_kernel_results
  state_is_list = mcmc_util.is_list_like(current_state)
  with tf.name_scope(mcmc_util.make_name(self.name, 'mala', 'one_step')):
    with tf.name_scope('initialize'):
      # Normalize inputs into per-part lists for `_euler_method`.
      (current_state_parts,
       step_size_parts,
       current_target_log_prob,
       _,  # grads_target_log_prob
       current_volatility_parts,
       _,  # grads_volatility
       current_drift_parts) = _prepare_args(
           self.target_log_prob_fn,
           self.volatility_fn,
           current_state,
           self.step_size,
           prev.target_log_prob,
           prev.grads_target_log_prob,
           prev.volatility,
           prev.grads_volatility,
           prev.diffusion_drift,
           self.parallel_iterations)

    seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
    part_seeds = distribute_lib.fold_in_axis_index(
        list(samplers.split_seed(
            seed, n=len(current_state_parts), salt='langevin.one_step')),
        self.experimental_shard_axis_names)

    # One standard-normal draw per state part, matching its shape and dtype.
    random_draw_parts = [
        samplers.normal(
            shape=ps.shape(part),
            dtype=dtype_util.base_dtype(part.dtype),
            seed=part_seed)
        for part, part_seed in zip(current_state_parts, part_seeds)
    ]

    # Number of independent chains run by the algorithm.
    independent_chain_ndims = ps.rank(current_target_log_prob)

    # Euler-Maruyama proposal.
    next_state_parts = _euler_method(
        random_draw_parts, current_state_parts, current_drift_parts,
        step_size_parts, current_volatility_parts)

    # Quantities at the proposed state, used by the acceptance correction
    # and by the next call to `one_step`.
    (_,  # state_parts
     _,  # step_sizes
     next_target_log_prob,
     next_grads_target_log_prob,
     next_volatility_parts,
     next_grads_volatility,
     next_drift_parts) = _prepare_args(
         self.target_log_prob_fn,
         self.volatility_fn,
         next_state_parts,
         step_size_parts,
         parallel_iterations=self.parallel_iterations)

    def _flatten_if_single(parts):
      return parts if state_is_list else parts[0]

    # Both candidate corrections are built; `tf.cond` picks one according to
    # `self.compute_acceptance`.
    correction_if_computed = _compute_log_acceptance_correction(
        current_state_parts, next_state_parts,
        current_volatility_parts, next_volatility_parts,
        current_drift_parts, next_drift_parts,
        step_size_parts, independent_chain_ndims,
        experimental_shard_axis_names=self.experimental_shard_axis_names)
    correction_if_skipped = tf.zeros_like(next_target_log_prob)
    log_acceptance_correction = tf.cond(
        pred=self.compute_acceptance,
        true_fn=lambda: correction_if_computed,
        false_fn=lambda: correction_if_skipped)

    return [
        _flatten_if_single(next_state_parts),
        UncalibratedLangevinKernelResults(
            log_acceptance_correction=log_acceptance_correction,
            target_log_prob=next_target_log_prob,
            grads_target_log_prob=next_grads_target_log_prob,
            volatility=_flatten_if_single(next_volatility_parts),
            grads_volatility=next_grads_volatility,
            diffusion_drift=next_drift_parts,
            seed=seed,
        ),
    ]
def _sample_n(self, n, seed, **kwargs):
  """Samples `n` draws using a seed offset by this replica's id.

  The seed is sanitized with salt 'sharded_independent_sample', then
  `self.replica_id` is added to it before delegating to the parent
  `_sample_n`, so replicas with different ids use different seeds.

  NOTE(review): adding the replica id to both words of a stateless seed is a
  weak decorrelation scheme; a fold-in style derivation may be preferable.
  Confirm before changing, since it alters the sampled values.
  """
  base_seed = samplers.sanitize_seed(seed, salt='sharded_independent_sample')
  replica_seed = base_seed + self.replica_id
  parent = super(ShardedIndependent, self)
  return parent._sample_n(n, replica_seed, **kwargs)
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Runs one uncalibrated HMC proposal along a leapfrog trajectory.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s with the current
      chain state.
    previous_kernel_results: Results from the previous step. When
      `self._store_parameters_in_results` is set, step size and number of
      leapfrog steps are read from here instead of from the kernel.
    seed: PRNG seed used to sample the initial momentum; stored in the
      results.

  Returns:
    The proposed state (a single `Tensor` unless `current_state` is
    list-like) and `previous_kernel_results` updated with the new target
    log-prob, gradients, momenta and log acceptance correction.
  """
  prev = previous_kernel_results
  with tf.name_scope(mcmc_util.make_name(self.name, 'hmc', 'one_step')):
    # Step size and trajectory length come either from the kernel results
    # (when parameters are stored there) or from the kernel itself.
    params_source = prev if self._store_parameters_in_results else self
    step_size = params_source.step_size
    num_leapfrog_steps = params_source.num_leapfrog_steps

    (current_state_parts,
     step_sizes,
     momentum_distribution,
     current_target_log_prob,
     current_target_log_prob_grad_parts) = _prepare_args(
         self.target_log_prob_fn,
         current_state,
         step_size,
         self.momentum_distribution,
         prev.target_log_prob,
         prev.grads_target_log_prob,
         maybe_expand=True,
         state_gradients_are_stopped=self.state_gradients_are_stopped)

    seed = samplers.sanitize_seed(seed)
    current_momentum_parts = momentum_distribution.sample(seed=seed)

    # Prefer the unnormalized log-density when the distribution exposes one.
    momentum_log_prob = getattr(
        momentum_distribution, '_log_prob_unnormalized',
        momentum_distribution.log_prob)

    def kinetic_energy_fn(*args):
      return -momentum_log_prob(*args)

    # Without a user-supplied momentum distribution, pass `None` and let the
    # integrator handle that case itself.
    leapfrog_kinetic_energy_fn = (
        None if self.momentum_distribution is None else kinetic_energy_fn)

    integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
        self.target_log_prob_fn, step_sizes, num_leapfrog_steps)
    (next_momentum_parts,
     next_state_parts,
     next_target_log_prob,
     next_target_log_prob_grad_parts) = integrator(
         current_momentum_parts,
         current_state_parts,
         target=current_target_log_prob,
         target_grad_parts=current_target_log_prob_grad_parts,
         kinetic_energy_fn=leapfrog_kinetic_energy_fn)

    if self.state_gradients_are_stopped:
      next_state_parts = list(map(tf.stop_gradient, next_state_parts))

    log_acceptance_correction = _compute_log_acceptance_correction(
        kinetic_energy_fn, current_momentum_parts, next_momentum_parts)
    new_kernel_results = prev._replace(
        log_acceptance_correction=log_acceptance_correction,
        target_log_prob=next_target_log_prob,
        grads_target_log_prob=next_target_log_prob_grad_parts,
        initial_momentum=current_momentum_parts,
        final_momentum=next_momentum_parts,
        seed=seed,
    )

    if mcmc_util.is_list_like(current_state):
      next_state = next_state_parts
    else:
      next_state = next_state_parts[0]
    return next_state, new_kernel_results
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Runs one No-U-Turn iteration via iterative tree doubling.

  Args:
    current_state: `Tensor` or nested structure of `Tensor`s with the current
      chain state. A non-nested `Tensor` is wrapped in a list internally and
      unwrapped again on return.
    previous_kernel_results: Results from the previous step; its
      `momentum_distribution`, `target_log_prob`, `grads_target_log_prob` and
      `step_size` fields are read.
    seed: PRNG seed. It is sanitized, split into a trajectory-start seed and a
      while-loop seed, and stored in the returned results.

  Returns:
    result_state: The candidate state selected by tree doubling, in the same
      (wrapped or unwrapped) form as `current_state`.
    kernel_results: `PreconditionedNUTSKernelResults` for this step.
  """
  seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
  start_trajectory_seed, loop_seed = samplers.split_seed(seed)

  with tf.name_scope(self.name + '.one_step'):
    # Work with a list of state parts internally.
    unwrap_state_list = not tf.nest.is_nested(current_state)
    if unwrap_state_list:
      current_state = [current_state]
    momentum_distribution = previous_kernel_results.momentum_distribution

    current_target_log_prob = previous_kernel_results.target_log_prob
    # Initial momentum, energy and slice variable for the new trajectory.
    [init_momentum,
     init_energy,
     log_slice_sample] = self._start_trajectory_batched(
         current_state,
         current_target_log_prob,
         momentum_distribution=momentum_distribution,
         seed=start_trajectory_seed)

    def _copy(v):
      # Multiplies `v` by ones of shape [2, 1, ..., 1] (rank(v) trailing
      # ones), producing two stacked copies along a new leading axis.
      # Presumably one per trajectory end -- TODO confirm.
      return v * ps.ones(
          ps.pad(
              [2], paddings=[[0, ps.rank(v)]], constant_values=1),
          dtype=v.dtype)

    # `m + 0` creates fresh tensors so any cached values/gradients attached to
    # the momentum tensors are not reused.
    _, init_velocity = mcmc_util.maybe_call_fn_and_grads(
        get_kinetic_energy_fn(momentum_distribution),
        [m + 0 for m in init_momentum])  # Breaks cache.
    initial_state = TreeDoublingState(
        momentum=init_momentum,
        velocity=init_velocity,
        state=current_state,
        target=current_target_log_prob,
        target_grad_parts=previous_kernel_results.grads_target_log_prob)
    initial_step_state = tf.nest.map_structure(_copy, initial_state)

    # Initial candidate weight: log-weight 0 for multinomial sampling,
    # otherwise a count of 1.
    if MULTINOMIAL_SAMPLE:
      init_weight = tf.zeros_like(init_energy)  # log(exp(H0 - H0))
    else:
      init_weight = tf.ones_like(init_energy, dtype=TREE_COUNT_DTYPE)

    # The current state is the initial candidate.
    candidate_state = TreeDoublingStateCandidate(
        state=current_state,
        target=current_target_log_prob,
        target_grad_parts=previous_kernel_results.grads_target_log_prob,
        energy=init_energy,
        weight=init_weight)
    initial_step_metastate = TreeDoublingMetaState(
        candidate_state=candidate_state,
        is_accepted=tf.zeros_like(init_energy, dtype=tf.bool),
        momentum_sum=init_momentum,
        energy_diff_sum=tf.zeros_like(init_energy),
        leapfrog_count=tf.zeros_like(init_energy, dtype=TREE_COUNT_DTYPE),
        continue_tree=tf.ones_like(init_energy, dtype=tf.bool),
        not_divergence=tf.ones_like(init_energy, dtype=tf.bool))

    # Convert the write/read instruction into TensorArray so that it is
    # compatible with XLA.
    write_instruction = tf.TensorArray(
        TREE_COUNT_DTYPE,
        size=len(self._write_instruction),
        clear_after_read=False).unstack(self._write_instruction)
    read_instruction = tf.TensorArray(
        tf.int32,
        size=len(self._read_instruction),
        clear_after_read=False).unstack(self._read_instruction)

    current_step_meta_info = OneStepMetaInfo(
        log_slice_sample=log_slice_sample,
        init_energy=init_energy,
        write_instruction=write_instruction,
        read_instruction=read_instruction
        )

    velocity_state_memory = VelocityStateSwap(
        velocity_swap=self.init_velocity_state_memory(init_momentum),
        state_swap=self.init_velocity_state_memory(current_state))

    step_size = _prepare_step_size(
        previous_kernel_results.step_size,
        current_target_log_prob.dtype,
        len(current_state))
    # Keep calling `_loop_tree_doubling` until `max_tree_depth` iterations
    # have run or no batch member still has `continue_tree` set.
    _, _, _, new_step_metastate = tf.while_loop(
        cond=lambda iter_, seed, state, metastate: (  # pylint: disable=g-long-lambda
            (iter_ < self.max_tree_depth) &
            tf.reduce_any(metastate.continue_tree)),
        body=lambda iter_, seed, state, metastate: self._loop_tree_doubling(  # pylint: disable=g-long-lambda
            step_size,
            velocity_state_memory,
            current_step_meta_info,
            iter_,
            state,
            metastate,
            momentum_distribution,
            seed),
        loop_vars=(
            tf.zeros([], dtype=tf.int32, name='iter'),
            loop_seed,
            initial_step_state,
            initial_step_metastate),
        parallel_iterations=self.parallel_iterations,
    )

    kernel_results = PreconditionedNUTSKernelResults(
        target_log_prob=new_step_metastate.candidate_state.target,
        grads_target_log_prob=(
            new_step_metastate.candidate_state.target_grad_parts),
        step_size=previous_kernel_results.step_size,
        # Log of `energy_diff_sum` averaged over the leapfrog steps taken.
        log_accept_ratio=tf.math.log(
            new_step_metastate.energy_diff_sum /
            tf.cast(new_step_metastate.leapfrog_count,
                    dtype=new_step_metastate.energy_diff_sum.dtype)),
        leapfrogs_taken=(
            new_step_metastate.leapfrog_count * self.unrolled_leapfrog_steps
        ),
        is_accepted=new_step_metastate.is_accepted,
        # `continue_tree` still set after the loop means the depth limit,
        # not a U-turn, ended the trajectory.
        reach_max_depth=new_step_metastate.continue_tree,
        has_divergence=~new_step_metastate.not_divergence,
        energy=new_step_metastate.candidate_state.energy,
        momentum_distribution=momentum_distribution,
        seed=seed,
    )

    result_state = new_step_metastate.candidate_state.state
    if unwrap_state_list:
      result_state = result_state[0]

    return result_state, kernel_results
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Advances the chain(s) by one Elliptical Slice Sampling step.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s). The first `r` dimensions index
      independent chains,
      `r = tf.rank(log_likelihood_fn(*normal_sampler_fn()))`.
    previous_kernel_results: `collections.namedtuple` containing `Tensor`s
      representing values from previous calls to this function (or from the
      `bootstrap_results` function.)
    seed: Optional seed, for reproducible sampling.

  Returns:
    next_state: Tensor or Python list of `Tensor`s representing the state(s)
      of the Markov chain(s) after taking exactly one step. Has same type and
      shape as `current_state`.
    kernel_results: `collections.namedtuple` of internal calculations used to
      advance the chain.

  Raises:
    TypeError: if `not log_likelihood.dtype.is_floating`.
  """
  state_is_list = mcmc_util.is_list_like(current_state)
  with tf.name_scope(
      mcmc_util.make_name(self.name, 'elliptical_slice', 'one_step')):
    with tf.name_scope('initialize'):
      init_state_parts, init_log_likelihood = _prepare_args(
          self.log_likelihood_fn,
          current_state,
          previous_kernel_results.log_likelihood)

    # The unsalted seed is the one stored in the kernel results.
    seed = samplers.sanitize_seed(seed)
    normal_seed, u_seed, angle_seed, loop_seed = samplers.split_seed(
        seed, n=4, salt='elliptical_slice_sampler')

    # Auxiliary normal draw that defines the ellipse through the current state.
    drawn = self.normal_sampler_fn(normal_seed)  # pylint: disable=not-callable
    if mcmc_util.is_list_like(drawn):
      normal_samples = list(drawn)
    else:
      normal_samples = [drawn]

    ll_shape = tf.shape(init_log_likelihood)
    ll_dtype = init_log_likelihood.dtype.base_dtype

    # Slice height: current log-likelihood plus log(u), u ~ Uniform.
    log_u = tf.math.log(
        samplers.uniform(shape=ll_shape, seed=u_seed, dtype=ll_dtype))
    threshold = init_log_likelihood + log_u

    # Initial angle in [0, 2*pi) and the bracket [angle - 2*pi, angle].
    starting_angle = samplers.uniform(
        shape=ll_shape,
        minval=0.,
        maxval=2 * np.pi,
        name='angle',
        seed=angle_seed,
        dtype=ll_dtype,
    )
    starting_angle_min = starting_angle - 2 * np.pi
    starting_angle_max = starting_angle

    starting_state_parts = _rotate_on_ellipse(
        init_state_parts, normal_samples, starting_angle)
    starting_log_likelihood = self.log_likelihood_fn(*starting_state_parts)  # pylint: disable=not-callable

    def _any_below_threshold(unused_seed, unused_angle, unused_angle_min,
                             unused_angle_max, unused_state_parts,
                             log_likelihood):
      # Keep looping while at least one chain has not reached its slice.
      return tf.reduce_any(log_likelihood < threshold)

    def _shrink_and_resample(step_seed, angle, angle_min, angle_max,
                             state_parts, log_likelihood):
      """Shrinks the angle bracket, redraws the angle, rotates init_state."""
      draw_seed, following_seed = samplers.split_seed(step_seed)
      searching = log_likelihood < threshold
      # Only chains that have not beaten the threshold shrink their bracket:
      # a negative angle raises the lower end, otherwise the upper end drops.
      angle_min = tf.where((angle < 0) & searching, angle, angle_min)
      angle_max = tf.where((angle >= 0) & searching, angle, angle_max)
      candidate_angle = samplers.uniform(
          shape=tf.shape(log_likelihood),
          minval=angle_min,
          maxval=angle_max,
          seed=draw_seed,
          dtype=angle.dtype.base_dtype)
      angle = tf.where(searching, candidate_angle, angle)
      rotated_parts = _rotate_on_ellipse(
          init_state_parts, normal_samples, angle)

      # Per-chain mask right-padded via `_right_pad_with_ones` to the rank of
      # the state parts (presumably so it broadcasts over event dims).
      searching_mask = _right_pad_with_ones(
          searching, tf.rank(rotated_parts[0]))
      updated_parts = [
          tf.where(searching_mask, rotated, previous)
          for rotated, previous in zip(rotated_parts, state_parts)
      ]
      return (
          following_seed,
          angle,
          angle_min,
          angle_max,
          updated_parts,
          self.log_likelihood_fn(*updated_parts)  # pylint: disable=not-callable
      )

    (_, next_angle, _, _, next_state_parts,
     next_log_likelihood) = tf.while_loop(
         cond=_any_below_threshold,
         body=_shrink_and_resample,
         loop_vars=[
             loop_seed,
             starting_angle,
             starting_angle_min,
             starting_angle_max,
             starting_state_parts,
             starting_log_likelihood,
         ])

    next_state = next_state_parts if state_is_list else next_state_parts[0]
    return [
        next_state,
        EllipticalSliceSamplerKernelResults(
            log_likelihood=next_log_likelihood,
            angle=next_angle,
            normal_samples=normal_samples,
            seed=seed,
        ),
    ]
def _windowed_adaptive_impl(n_draws,
                            joint_dist,
                            *,
                            kind,
                            n_chains,
                            proposal_kernel_kwargs,
                            num_adaptation_steps,
                            current_state,
                            dual_averaging_kwargs,
                            trace_fn,
                            return_final_kernel_results,
                            discard_tuning,
                            seed,
                            chain_axis_names,
                            **pins):
  """Runs windowed sampling using either HMC or NUTS as internal sampler.

  Args:
    n_draws: Number of post-tuning draws. When `discard_tuning` is False the
      tuning draws are kept too, so `n_draws + num_adaptation_steps` draws are
      returned.
    joint_dist: Joint distribution passed to `_setup_mcmc`.
    kind: Sampler kind forwarded to `_do_sampling`.
    n_chains: Number of chains, or a shape-like; a Python int is wrapped in a
      list.
    proposal_kernel_kwargs: Keyword arguments for the proposal kernel. A copy
      is taken; `target_log_prob_fn`, `step_size` and `momentum_distribution`
      are set on the copy. If `step_size` is missing or `None`, an initial
      step size is computed (only allowed for a scalar batch shape). `None`
      is treated as an empty dict.
    num_adaptation_steps: Number of tuning steps; also used as the dual
      averaging `num_adaptation_steps`.
    current_state: Initial position, passed to `_setup_mcmc`.
    dual_averaging_kwargs: Keyword arguments for dual averaging step-size
      adaptation. A copy is taken; `num_adaptation_steps` is overwritten (with
      a warning if it was supplied), and `reduce_fn` and
      `experimental_reduce_chain_axis_names` get defaults. `None` is treated
      as an empty dict.
    trace_fn: Trace function, or `None` to trace nothing and return only the
      states.
    return_final_kernel_results: If True, also return final kernel results.
    discard_tuning: If True, tuning steps run as burn-in and are not returned.
    seed: PRNG seed, split between setup and sampling.
    chain_axis_names: Named axes the chains are distributed over.
    **pins: Values forwarded to `_setup_mcmc`.

  Returns:
    `sample.CheckpointableStatesAndTrace` if `return_final_kernel_results`;
    otherwise the constrained draws alone when `trace_fn` is `None`, else
    `sample.StatesAndTrace`.

  Raises:
    ValueError: if no `step_size` is given and the target has a non-scalar
      batch shape.
  """
  # Work on copies: the dicts are modified below and must not leak changes
  # back into the caller's objects.
  proposal_kernel_kwargs = dict(proposal_kernel_kwargs or {})
  dual_averaging_kwargs = dict(dual_averaging_kwargs or {})

  if trace_fn is None:
    trace_fn = lambda *args: ()
    no_trace = True
  else:
    no_trace = False

  if isinstance(n_chains, int):
    n_chains = [n_chains]

  if (tf.executing_eagerly() or
      not control_flow_util.GraphOrParentsInXlaContext(
          tf1.get_default_graph())):
    # Only convert to a Tensor outside XLA: under XLA a Tensor step count
    # breaks the static TensorArray allocation sizes required for the
    # `trace_fn` results, so it must stay a Python int there.
    num_adaptation_steps = ps.convert_to_shape_tensor(num_adaptation_steps)

  if 'num_adaptation_steps' in dual_averaging_kwargs:
    warnings.warn(
        'Dual averaging adaptation will use the value specified in'
        ' the `num_adaptation_steps` argument for its construction,'
        ' hence there is no need to specify it in the'
        ' `dual_averaging_kwargs` argument.')

  # TODO(b/180011931): if num_adaptation_steps is small, this throws an error.
  dual_averaging_kwargs['num_adaptation_steps'] = num_adaptation_steps
  dual_averaging_kwargs.setdefault(
      'reduce_fn',
      functools.partial(
          generic_math.reduce_log_harmonic_mean_exp,
          # There is only one log_accept_prob per chain, and we reduce across
          # all chains, so typically the all_gather will be gathering scalars,
          # which should be relatively efficient.
          experimental_allow_all_gather=True))
  # By default, reduce over named axes for step size adaptation.
  dual_averaging_kwargs.setdefault('experimental_reduce_chain_axis_names',
                                   chain_axis_names)

  setup_seed, sample_seed = samplers.split_seed(
      samplers.sanitize_seed(seed), n=2)
  (target_log_prob_fn, initial_transformed_position, bijector,
   step_broadcast, batch_shape, shard_axis_names) = _setup_mcmc(
       joint_dist,
       n_chains=n_chains,
       init_position=current_state,
       seed=setup_seed,
       **pins)

  if proposal_kernel_kwargs.get('step_size') is None:
    if batch_shape.shape != (0,):  # Scalar batch has a 0-vector shape.
      raise ValueError(
          'Batch target density must specify init_step_size. Got '
          f'batch shape {batch_shape} from joint {joint_dist}.')
    init_step_size = _get_step_size(initial_transformed_position,
                                    target_log_prob_fn)
  else:
    init_step_size = step_broadcast(proposal_kernel_kwargs['step_size'])

  proposal_kernel_kwargs.update({
      'target_log_prob_fn': target_log_prob_fn,
      'step_size': init_step_size,
      'momentum_distribution': _init_momentum(
          initial_transformed_position,
          batch_shape=ps.concat([n_chains, batch_shape], axis=0),
          shard_axis_names=shard_axis_names),
  })

  # Running variance starts at zero samples, zero mean and unit variance.
  initial_running_variance = [
      sample_stats.RunningVariance.from_stats(  # pylint: disable=g-complex-comprehension
          num_samples=tf.zeros([], part.dtype),
          mean=tf.zeros_like(part),
          variance=tf.ones_like(part))
      for part in initial_transformed_position
  ]
  # TODO(phandu): Consider splitting out warmup and post warmup phases
  # to avoid executing adaptation code during the post warmup phase.
  ret = _do_sampling(
      kind=kind,
      proposal_kernel_kwargs=proposal_kernel_kwargs,
      dual_averaging_kwargs=dual_averaging_kwargs,
      num_draws=n_draws if discard_tuning else n_draws + num_adaptation_steps,
      num_burnin_steps=num_adaptation_steps if discard_tuning else 0,
      initial_position=initial_transformed_position,
      initial_running_variance=initial_running_variance,
      bijector=bijector,
      trace_fn=trace_fn,
      return_final_kernel_results=return_final_kernel_results,
      chain_axis_names=chain_axis_names,
      shard_axis_names=shard_axis_names,
      seed=sample_seed)

  if return_final_kernel_results:
    draws, trace, fkr = ret
    return sample.CheckpointableStatesAndTrace(
        all_states=bijector.inverse(draws),
        trace=trace,
        final_kernel_results=fkr)

  draws, trace = ret
  if no_trace:
    return bijector.inverse(draws)
  return sample.StatesAndTrace(
      all_states=bijector.inverse(draws), trace=trace)
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Advances the chain by one (uncalibrated) Langevin step.

  The proposal is generated with the Euler-Maruyama method. When `seed` is
  `None`, a seed is taken from the deprecated `self._seed_stream`, with a
  deprecation warning if that stream was built from an explicit seed.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s with the current
      chain state.
    previous_kernel_results: Results from the previous step; its
      `target_log_prob`, `grads_target_log_prob`, `volatility`,
      `grads_volatility` and `diffusion_drift` fields are read.
    seed: Optional PRNG seed; stored in the results.

  Returns:
    A two-element list: the proposed state (a single `Tensor` unless
    `current_state` is list-like) and `UncalibratedLangevinKernelResults`.
  """
  prev = previous_kernel_results
  state_is_list = mcmc_util.is_list_like(current_state)
  with tf.name_scope(mcmc_util.make_name(self.name, 'mala', 'one_step')):
    with tf.name_scope('initialize'):
      # Normalize inputs into per-part lists for `_euler_method`.
      (current_state_parts,
       step_size_parts,
       current_target_log_prob,
       _,  # grads_target_log_prob
       current_volatility_parts,
       _,  # grads_volatility
       current_drift_parts) = _prepare_args(
           self.target_log_prob_fn,
           self.volatility_fn,
           current_state,
           self.step_size,
           prev.target_log_prob,
           prev.grads_target_log_prob,
           prev.volatility,
           prev.grads_volatility,
           prev.diffusion_drift,
           self.parallel_iterations)

    # TODO(b/159636942): Clean up after 2020-09-20.
    if seed is None:
      # Fall back to the deprecated constructor-supplied seed stream.
      if self._seed_stream.original_seed is not None:
        warnings.warn(mcmc_util.SEED_CTOR_ARG_DEPRECATION_MSG)
      seed = self._seed_stream()
    seed = samplers.sanitize_seed(seed)

    part_seeds = samplers.split_seed(
        seed, n=len(current_state_parts), salt='langevin.one_step')
    # One standard-normal draw per state part, matching its shape and dtype.
    random_draw_parts = [
        samplers.normal(
            shape=tf.shape(part),
            dtype=dtype_util.base_dtype(part.dtype),
            seed=part_seed)
        for part, part_seed in zip(current_state_parts, part_seeds)
    ]

    # Number of independent chains run by the algorithm.
    independent_chain_ndims = prefer_static.rank(current_target_log_prob)

    # Euler-Maruyama proposal.
    next_state_parts = _euler_method(
        random_draw_parts, current_state_parts, current_drift_parts,
        step_size_parts, current_volatility_parts)

    # Quantities at the proposed state, used by the acceptance correction
    # and by the next call to `one_step`.
    (_,  # state_parts
     _,  # step_sizes
     next_target_log_prob,
     next_grads_target_log_prob,
     next_volatility_parts,
     next_grads_volatility,
     next_drift_parts) = _prepare_args(
         self.target_log_prob_fn,
         self.volatility_fn,
         next_state_parts,
         step_size_parts,
         parallel_iterations=self.parallel_iterations)

    def _flatten_if_single(parts):
      return parts if state_is_list else parts[0]

    # Both candidate corrections are built; `tf.cond` picks one according to
    # `self.compute_acceptance`.
    correction_if_computed = _compute_log_acceptance_correction(
        current_state_parts, next_state_parts,
        current_volatility_parts, next_volatility_parts,
        current_drift_parts, next_drift_parts,
        step_size_parts, independent_chain_ndims)
    correction_if_skipped = tf.zeros_like(next_target_log_prob)
    log_acceptance_correction = tf.cond(
        pred=self.compute_acceptance,
        true_fn=lambda: correction_if_computed,
        false_fn=lambda: correction_if_skipped)

    return [
        _flatten_if_single(next_state_parts),
        UncalibratedLangevinKernelResults(
            log_acceptance_correction=log_acceptance_correction,
            target_log_prob=next_target_log_prob,
            grads_target_log_prob=next_grads_target_log_prob,
            volatility=_flatten_if_single(next_volatility_parts),
            grads_volatility=next_grads_volatility,
            diffusion_drift=next_drift_parts,
            seed=seed,
        ),
    ]
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Performs one Slice Sampler iteration.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s). The first `r` dimensions index
      independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`.
    previous_kernel_results: `collections.namedtuple` containing `Tensor`s
      representing values from previous calls to this function (or from the
      `bootstrap_results` function.)
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.

  Returns:
    next_state: Tensor or Python list of `Tensor`s representing the state(s)
      of the Markov chain(s) after taking exactly one step. Has same type and
      shape as `current_state`.
    kernel_results: `collections.namedtuple` of internal calculations used to
      advance the chain.

  Raises:
    ValueError: if there isn't one `step_size` or a list with same length as
      `current_state`.
    TypeError: if `not target_log_prob.dtype.is_floating`.
  """
  seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
  state_is_list = mcmc_util.is_list_like(current_state)
  with tf.name_scope(mcmc_util.make_name(self.name, 'slice', 'one_step')):
    with tf.name_scope('initialize'):
      (current_state_parts,
       step_sizes,
       current_target_log_prob) = _prepare_args(
           self.target_log_prob_fn,
           current_state,
           self.step_size,
           previous_kernel_results.target_log_prob,
           maybe_expand=True)
      max_doublings = ps.convert_to_shape_tensor(
          value=self.max_doublings, dtype=tf.int32, name='max_doublings')

    independent_chain_ndims = ps.rank(current_target_log_prob)

    (next_state_parts,
     next_target_log_prob,
     bounds_satisfied,
     direction,
     upper_bounds,
     lower_bounds) = _sample_next(
         self.target_log_prob_fn,
         current_state_parts,
         step_sizes,
         max_doublings,
         current_target_log_prob,
         independent_chain_ndims,
         seed=seed,
         experimental_shard_axis_names=self.experimental_shard_axis_names)

    next_state = next_state_parts if state_is_list else next_state_parts[0]
    return [
        next_state,
        SliceSamplerKernelResults(
            target_log_prob=next_target_log_prob,
            bounds_satisfied=bounds_satisfied,
            direction=direction,
            upper_bounds=upper_bounds,
            lower_bounds=lower_bounds,
            seed=seed,
        ),
    ]
def _flat_sample_distributions(self, sample_shape=(), seed=None, value=None):
  """Executes `model`, creating both samples and distributions.

  Drives the model coroutine: every yielded distribution is recorded, and a
  value (either the caller-supplied one from `value` or a fresh sample) is
  sent back into the coroutine.

  Args:
    sample_shape: Sample shape applied to distributions wrapped in `Root`;
      non-root distributions are sampled with shape `()`.
    seed: Stateful (per `samplers.is_stateful_seed`) or stateless seed.
    value: Optional sequence of per-component values; a non-`None` entry at
      position `index` is used instead of sampling that component.

  Returns:
    ds: List of the distributions yielded by the coroutine (unwrapped from
      `Root`).
    values_out: List of the corresponding sampled or supplied values.

  Raises:
    ValueError: if `self._require_root` and the first yielded distribution is
      not wrapped in `Root`.
  """
  ds = []
  values_out = []
  if samplers.is_stateful_seed(seed):
    # Stateful seeds feed a SeedStream; the seed itself is dropped unless
    # stateful-to-stateless conversion is enabled.
    seed_stream = SeedStream(seed, salt='JointDistributionCoroutine')
    if not self._stateful_to_stateless:
      seed = None
  else:
    seed_stream = None  # We got a stateless seed for seed=.

  # TODO(b/166658748): Make _stateful_to_stateless always True (eliminate it).
  if self._stateful_to_stateless and (seed is not None or not JAX_MODE):
    seed = samplers.sanitize_seed(seed, salt='JointDistributionCoroutine')
  gen = self._model_coroutine()
  index = 0
  d = next(gen)
  if self._require_root and not isinstance(d, self.Root):
    raise ValueError('First distribution yielded by coroutine must '
                     'be wrapped in `Root`.')
  try:
    while True:
      actual_distribution = d.distribution if isinstance(d, self.Root) else d
      ds.append(actual_distribution)
      # Ensure reproducibility even when xs are (partially) set. Always split.
      stateful_sample_seed = None if seed_stream is None else seed_stream()
      if seed is None:
        stateless_sample_seed = None
      else:
        stateless_sample_seed, seed = samplers.split_seed(seed)

      if (value is not None and len(value) > index and
          value[index] is not None):

        def convert_tree_to_tensor(x, dtype_hint):
          return tf.convert_to_tensor(x, dtype_hint=dtype_hint)

        # This signature does not allow kwarg names. Applies
        # `convert_to_tensor` on the next value.
        next_value = nest.map_structure_up_to(
            ds[-1].dtype,  # shallow_tree
            convert_tree_to_tensor,  # func
            value[index],  # x
            ds[-1].dtype)  # dtype_hint
      else:
        try:
          # Prefer the stateless seed; use the stateful one only when no
          # stateless seed is available.
          next_value = actual_distribution.sample(
              sample_shape=sample_shape if isinstance(d, self.Root) else (),
              seed=(stateful_sample_seed if stateless_sample_seed is None
                    else stateless_sample_seed))
        except TypeError as e:
          # Only errors whose message matches a seed-type complaint are
          # recoverable, and only when a stateful seed exists to retry with.
          if ('Expected int for argument' not in str(e) and
              TENSOR_SEED_MSG_PREFIX not in str(e)) or (
                  stateful_sample_seed is None):
            raise
          msg = (
              'Falling back to stateful sampling for distribution #{index} '
              '(0-based) of type `{dist_cls}` with component name '
              '{component_name} and `dist.name` "{dist_name}". Please '
              'update to use `tf.random.stateless_*` RNGs. This fallback may '
              'be removed after 20-Dec-2020. ({exc})')
          component_name = (
              joint_distribution_lib.get_explicit_name_for_component(ds[-1]))
          if component_name is None:
            component_name = '[None specified]'
          else:
            component_name = '"{}"'.format(component_name)
          warnings.warn(msg.format(
              index=index,
              component_name=component_name,
              dist_name=ds[-1].name,
              dist_cls=type(ds[-1]),
              exc=str(e)))
          next_value = actual_distribution.sample(
              sample_shape=sample_shape if isinstance(d, self.Root) else (),
              seed=stateful_sample_seed)

      if self._validate_args:
        # Route the value through `tf.identity` so it depends on the shape
        # assertions.
        with tf.control_dependencies(
            self._assert_compatible_shape(index, sample_shape, next_value)):
          values_out.append(tf.nest.map_structure(tf.identity, next_value))
      else:
        values_out.append(next_value)

      index += 1
      d = gen.send(next_value)
  except StopIteration:
    # The coroutine is exhausted: all components have been visited.
    pass
  return ds, values_out
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Takes one Metropolis-Hastings step around the inner kernel.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
      `list` of `Tensor`s representing internal calculations made within the
      previous call to this function (or as returned by `bootstrap_results`).
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details. The inner
      kernel only receives a (derived) seed when one was passed here.

  Returns:
    next_state: `Tensor` or Python `list` of `Tensor`s representing the
      next state(s) of the Markov chain(s).
    kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
      `Tensor`s representing internal calculations made within this function.

  Raises:
    ValueError: if `inner_kernel` results doesn't contain the member
      "target_log_prob".
  """
  caller_gave_seed = seed is not None
  seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
  proposal_seed, acceptance_seed = samplers.split_seed(seed)

  with tf.name_scope(mcmc_util.make_name(self.name, 'mh', 'one_step')):
    prev_accepted = previous_kernel_results.accepted_results

    # Take one inner step.
    inner_kwargs = {'seed': proposal_seed} if caller_gave_seed else {}
    proposed_state, proposed_results = self.inner_kernel.one_step(
        current_state, prev_accepted, **inner_kwargs)
    if mcmc_util.is_list_like(current_state):
      proposed_state = tf.nest.pack_sequence_as(current_state, proposed_state)

    if not (has_target_log_prob(proposed_results) and
            has_target_log_prob(prev_accepted)):
      raise ValueError('"target_log_prob" must be a member of '
                       '`inner_kernel` results.')

    # log(acceptance ratio) = proposed target log-prob - previous one
    #                         + the inner kernel's log_acceptance_correction.
    log_prob_terms = [
        proposed_results.target_log_prob,
        -prev_accepted.target_log_prob,
    ]
    try:
      correction = proposed_results.log_acceptance_correction
      # An empty list/tuple correction contributes nothing.
      if not mcmc_util.is_list_like(correction) or correction:
        log_prob_terms.append(correction)
    except AttributeError:
      warnings.warn(
          'Supplied inner `TransitionKernel` does not have a '
          '`log_acceptance_correction`. Assuming its value is `0.`')
    log_accept_ratio = mcmc_util.safe_sum(
        log_prob_terms, name='compute_log_accept_ratio')

    # Accept iff log(u) < log_accept_ratio with u ~ Uniform[0, 1), i.e. with
    # probability min(1, accept_ratio).
    proposed_log_prob = proposed_results.target_log_prob
    uniform_draw = samplers.uniform(
        shape=prefer_static.shape(proposed_log_prob),
        dtype=dtype_util.base_dtype(proposed_log_prob.dtype),
        seed=acceptance_seed)
    is_accepted = tf.math.log(uniform_draw) < log_accept_ratio

    next_state = mcmc_util.choose(
        is_accepted, proposed_state, current_state, name='choose_next_state')

    # Seeds are stripped from the proposal before choosing: unlike the other
    # fields they are not per-chain values, so a per-chain choice between the
    # previously accepted seed and the proposed one is impossible.
    accepted_results = mcmc_util.choose(
        is_accepted,
        mcmc_util.strip_seeds(proposed_results),
        prev_accepted,
        name='choose_inner_results')

    kernel_results = MetropolisHastingsKernelResults(
        accepted_results=accepted_results,
        is_accepted=is_accepted,
        log_accept_ratio=log_accept_ratio,
        proposed_state=proposed_state,
        proposed_results=proposed_results,
        extra=[],
        seed=seed,
    )
    return next_state, kernel_results
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Takes one SNAPER-HMC step and refreshes the carried `ema_*` statistics.

  A fresh inner kernel is built from the `ema_mean`, `ema_variance` and
  `ema_principal_component` stored in `previous_kernel_results`. Its momentum
  distribution is pushed into the innermost results, and the kernel is
  stepped on the flattened state. The `ema_*` statistics are then updated
  with the resulting state.

  Args:
    current_state: (Possibly nested) structure of `Tensor`s holding the
      current chain state.
    previous_kernel_results: Kernel results from the previous call (or from
      `bootstrap_results`).
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.

  Returns:
    state: The next state, packed into the structure of `current_state`.
    kernel_results: `previous_kernel_results` with the inner results, the
      `ema_*` statistics, their point counts and the seed replaced.
  """
  pkr = previous_kernel_results
  with tf.name_scope(
      mcmc_util.make_name(
          self.name, 'snaper_hamiltonian_monte_carlo', 'one_step')):
    inner_results = pkr.inner_results
    innermost_target_log_prob = unnest.get_innermost(pkr, 'target_log_prob')
    batch_shape = ps.shape(innermost_target_log_prob)
    # Every axis of the innermost target log-prob, i.e. all batch axes.
    reduce_axes = ps.range(0, ps.size(batch_shape))
    step = inner_results.step
    state_ema_points = pkr.state_ema_points

    kernel = self._make_kernel(
        batch_shape=batch_shape,
        step=step,
        state_ema_points=state_ema_points,
        state=current_state,
        mean=pkr.ema_mean,
        variance=pkr.ema_variance,
        principal_component=pkr.ema_principal_component,
    )

    # The newly built kernel carries its own momentum distribution; write it
    # into the innermost results so the inner step is consistent with it.
    inner_results = unnest.replace_innermost(
        inner_results,
        momentum_distribution=(
            kernel.inner_kernel.parameters['momentum_distribution']),
    )

    seed = samplers.sanitize_seed(seed)
    flat_state, inner_results = kernel.one_step(
        tf.nest.flatten(current_state), inner_results, seed=seed)
    state = tf.nest.pack_sequence_as(current_state, flat_state)

    # Both EMA updates see the same axes, new state and step counter.
    ema_kwargs = dict(reduce_axes=reduce_axes, state=state, step=step)
    state_ema_points, ema_mean, ema_variance = self._update_state_ema(
        state_ema_points=state_ema_points,
        ema_mean=pkr.ema_mean,
        ema_variance=pkr.ema_variance,
        **ema_kwargs)
    pc_ema_points, ema_principal_component = (
        self._update_principal_component_ema(
            principal_component_ema_points=pkr.principal_component_ema_points,
            ema_principal_component=pkr.ema_principal_component,
            **ema_kwargs))

    new_results = pkr._replace(
        inner_results=inner_results,
        ema_mean=ema_mean,
        ema_variance=ema_variance,
        state_ema_points=state_ema_points,
        ema_principal_component=ema_principal_component,
        principal_component_ema_points=pc_ema_points,
        seed=seed,
    )
    return state, new_results
def test_sanitize_tensor_or_tensorlike(self):
  """A seed and its evaluated (numpy) value must sanitize identically."""
  tensor_seed = test_util.test_seed(sampler_type='stateless')
  # Sanitize the evaluated value first, then the original seed object.
  from_evaluated = samplers.sanitize_seed(seed=self.evaluate(tensor_seed))
  from_original = samplers.sanitize_seed(tensor_seed)
  self.assertAllEqual(from_evaluated, from_original)
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Proposes a new state via `new_state_fn` (random-walk proposal).

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s for the current
      chain state.
    previous_kernel_results: Unused by this method's body as shown.
    seed: Optional PRNG seed; see `tfp.random.sanitize_seed`. When given, it
      is always passed to `new_state_fn` in stateless form.

  Returns:
    A two-element list `[next_state, UncalibratedRandomWalkResults]`. The
    results hold a zero `log_acceptance_correction`, the proposal's
    `target_log_prob`, and the seed actually used (`zeros_seed()` if the
    legacy stateful fallback was taken).
  """
  with tf.name_scope(mcmc_util.make_name(self.name, 'rwm', 'one_step')):
    with tf.name_scope('initialize'):
      # Normalize the state to a list of tensors.
      if mcmc_util.is_list_like(current_state):
        current_state_parts = list(current_state)
      else:
        current_state_parts = [current_state]
      current_state_parts = [
          tf.convert_to_tensor(s, name='current_state')
          for s in current_state_parts
      ]

    # Seed handling complexity is due to users possibly expecting an old-style
    # stateful seed to be passed to `self.new_state_fn`.
    # In other words:
    # - If we were given a seed, we sanitize it to stateless, and
    #   if the `new_state_fn` doesn't like that, we crash and propagate
    #   the error.  Rationale: The contract is stateless sampling given
    #   seed, and doing otherwise would not meet it.
    # - If we were not given a seed, we try `new_state_fn` with a stateless
    #   seed.  Rationale: This is the future.
    # - If it fails with a seed incompatibility problem (as best we can
    #   detect from here), we issue a warning and try it again with a
    #   stateful-style seed. Rationale: User code that didn't set seeds
    #   shouldn't suddenly break.
    # TODO(b/159636942): Clean up after 2020-09-20.
    if seed is not None:
      force_stateless = True
      seed = samplers.sanitize_seed(seed)
    else:
      force_stateless = False
      if self._seed_stream.original_seed is not None:
        warnings.warn(mcmc_util.SEED_CTOR_ARG_DEPRECATION_MSG)
      # `stateful_seed` is only bound on this path; the fallback below can
      # only be reached when `force_stateless` is False, i.e. from here.
      stateful_seed = self._seed_stream()
      seed = samplers.sanitize_seed(stateful_seed)
    try:
      next_state_parts = self.new_state_fn(current_state_parts, seed)  # pylint: disable=not-callable
    except TypeError as e:
      if ('Expected int for argument' not in str(e) and
          TENSOR_SEED_MSG_PREFIX not in str(e)) or force_stateless:
        raise
      msg = (
          'Falling back to `int` seed for `new_state_fn` {}. Please update '
          'to use `tf.random.stateless_*` RNGs. '
          'This fallback may be removed after 10-Sep-2020. ({})')
      warnings.warn(msg.format(self.new_state_fn, str(e)))
      # Mark that no stateless seed was used; recorded as zeros_seed() below.
      seed = None
      next_state_parts = self.new_state_fn(  # pylint: disable=not-callable
          current_state_parts, stateful_seed)

    # Compute `target_log_prob` so it's available to MetropolisHastings.
    next_target_log_prob = self.target_log_prob_fn(*next_state_parts)  # pylint: disable=not-callable

    # Return a bare tensor if the caller passed a bare tensor.
    def maybe_flatten(x):
      return x if mcmc_util.is_list_like(current_state) else x[0]

    return [
        maybe_flatten(next_state_parts),
        UncalibratedRandomWalkResults(
            # Symmetric proposal assumed: no MH correction term.
            log_acceptance_correction=tf.zeros_like(next_target_log_prob),
            target_log_prob=next_target_log_prob,
            seed=samplers.zeros_seed() if seed is None else seed,
        ),
    ]
def _windowed_adaptive_impl(n_draws,
                            joint_dist,
                            *,
                            kind,
                            n_chains,
                            proposal_kernel_kwargs,
                            num_adaptation_steps,
                            current_state,
                            dual_averaging_kwargs,
                            trace_fn,
                            return_final_kernel_results,
                            discard_tuning,
                            seed,
                            **pins):
  """Runs windowed sampling using either HMC or NUTS as internal sampler.

  Adaptation proceeds as: one fast window, four slow windows of doubling
  size (`slow_window_size * 2**idx`), one final fast window, and then
  `n_draws` of sampling via `_do_sampling`.

  Note: `proposal_kernel_kwargs` is updated in place (`.update`) with the
  target log-prob fn, adapted step size and momentum distribution; a caller
  reusing that dict will observe these changes.

  Returns:
    Depending on `discard_tuning`, `return_final_kernel_results` and whether
    `trace_fn` is None: the draws mapped back through `bijector.inverse`,
    optionally wrapped in `sample.StatesAndTrace` or
    `sample.CheckpointableStatesAndTrace`. When `discard_tuning` is False,
    the adaptation-window draws and traces are concatenated in front of the
    sampling-phase ones along axis 0.
  """
  if trace_fn is None:
    trace_fn = lambda *args: ()
    no_trace = True
  else:
    no_trace = False

  if (tf.executing_eagerly() or
      not control_flow_util.GraphOrParentsInXlaContext(
          tf1.get_default_graph())):
    # A Tensor num_draws argument breaks XLA, which requires static TensorArray
    # trace_fn result allocation sizes.
    num_adaptation_steps = tf.convert_to_tensor(num_adaptation_steps)

  # `seed` is re-bound to the remainder, which is later split for the windows.
  setup_seed, init_seed, seed = samplers.split_seed(
      samplers.sanitize_seed(seed), n=3)
  (target_log_prob_fn, initial_transformed_position, bijector,
   step_broadcast, batch_shape) = _setup_mcmc(
       joint_dist,
       n_chains=n_chains,
       init_position=current_state,
       seed=setup_seed,
       **pins)

  if proposal_kernel_kwargs.get('step_size') is None:
    if batch_shape.shape != (0,):  # Scalar batch has a 0-vector shape.
      raise ValueError('Batch target density must specify init_step_size. Got '
                       f'batch shape {batch_shape} from joint {joint_dist}.')

    init_step_size = _get_step_size(initial_transformed_position,
                                    target_log_prob_fn)

  else:
    init_step_size = step_broadcast(proposal_kernel_kwargs['step_size'])

  proposal_kernel_kwargs.update({
      'target_log_prob_fn': target_log_prob_fn,
      'step_size': init_step_size,
      'momentum_distribution': _init_momentum(
          initial_transformed_position,
          batch_shape=ps.concat([[n_chains], batch_shape], axis=0))})

  first_window_size, slow_window_size, last_window_size = _get_window_sizes(
      num_adaptation_steps)

  all_traces = []
  # Using tf.function here and on _slow_window_closure caches tracing
  # of _fast_window and _slow_window, respectively, within a single
  # call to windowed sampling.  Why not annotate _fast_window and
  # _slow_window directly?  Two reasons:
  # - Caching across calls to windowed sampling is probably futile,
  #   because the trace function and bijector will be different Python
  #   objects, preventing cache hits.
  # - The cache of a global tf.function sticks around for the lifetime
  #   of the Python process, potentially leaking memory.
  @tf.function(autograph=False)
  def _fast_window_closure(proposal_kernel_kwargs, window_size,
                           initial_position, seed):
    return _fast_window(
        kind=kind,
        proposal_kernel_kwargs=proposal_kernel_kwargs,
        dual_averaging_kwargs=dual_averaging_kwargs,
        num_draws=window_size,
        initial_position=initial_position,
        bijector=bijector,
        trace_fn=trace_fn,
        seed=seed)

  # First fast window: adapts the step size from the initial position.
  draws, trace, step_size, running_variances = _fast_window_closure(
      proposal_kernel_kwargs=proposal_kernel_kwargs,
      window_size=first_window_size,
      initial_position=initial_transformed_position,
      seed=init_seed)
  proposal_kernel_kwargs.update({'step_size': step_size})

  all_draws = [[d] for d in draws]
  all_traces.append(trace)
  # Four seeds for the slow windows; the last split is kept for the rest.
  *slow_seeds, seed = samplers.split_seed(seed, n=5)

  @tf.function(autograph=False)
  def _slow_window_closure(proposal_kernel_kwargs, window_size,
                           initial_position, running_variances, seed):
    return _slow_window(
        kind=kind,
        proposal_kernel_kwargs=proposal_kernel_kwargs,
        dual_averaging_kwargs=dual_averaging_kwargs,
        num_draws=window_size,
        initial_position=initial_position,
        initial_running_variance=running_variances,
        bijector=bijector,
        trace_fn=trace_fn,
        seed=seed)

  for idx, slow_seed in enumerate(slow_seeds):
    # Slow windows double in length each time.
    window_size = slow_window_size * (2**idx)

    # TODO(b/180011931): if num_adaptation_steps is small, this throws an error.
    (draws, trace, step_size, running_variances, momentum_distribution
     ) = _slow_window_closure(
         proposal_kernel_kwargs=proposal_kernel_kwargs,
         window_size=window_size,
         # Each window starts from the last draw of the previous one.
         initial_position=[d[-1] for d in draws],
         running_variances=running_variances,
         seed=slow_seed)
    for all_d, d in zip(all_draws, draws):
      all_d.append(d)
    all_traces.append(trace)
    proposal_kernel_kwargs.update(
        {'step_size': step_size,
         'momentum_distribution': momentum_distribution})

  # Final fast window re-tunes the step size for the last momentum estimate.
  fast_seed, sample_seed = samplers.split_seed(seed)
  draws, trace, step_size, _ = _fast_window_closure(
      proposal_kernel_kwargs=proposal_kernel_kwargs,
      window_size=last_window_size,
      initial_position=[d[-1] for d in draws],
      seed=fast_seed)
  proposal_kernel_kwargs.update({'step_size': step_size})

  for all_d, d in zip(all_draws, draws):
    all_d.append(d)
  all_traces.append(trace)

  ret = _do_sampling(
      kind=kind,
      proposal_kernel_kwargs=proposal_kernel_kwargs,
      num_draws=n_draws,
      initial_position=[d[-1] for d in draws],
      bijector=bijector,
      trace_fn=trace_fn,
      return_final_kernel_results=return_final_kernel_results,
      seed=sample_seed)

  if discard_tuning:
    # Only the sampling-phase draws/trace are returned.
    if return_final_kernel_results:
      draws, trace, fkr = ret
      return sample.CheckpointableStatesAndTrace(
          all_states=bijector.inverse(draws),
          trace=trace,
          final_kernel_results=fkr)
    else:
      draws, trace = ret
      if no_trace:
        return bijector.inverse(draws)
      else:
        return sample.StatesAndTrace(
            all_states=bijector.inverse(draws), trace=trace)
  else:
    # Tuning draws/traces are concatenated ahead of the sampling phase.
    if return_final_kernel_results:
      draws, trace, fkr = ret
      for all_d, d in zip(all_draws, draws):
        all_d.append(d)
      all_traces.append(trace)
      return sample.CheckpointableStatesAndTrace(
          all_states=bijector.inverse(
              [tf.concat(d, axis=0) for d in all_draws]),
          trace=tf.nest.map_structure(lambda *s: tf.concat(s, axis=0),
                                      *all_traces, expand_composites=True),
          final_kernel_results=fkr)
    else:
      draws, trace = ret
      for all_d, d in zip(all_draws, draws):
        all_d.append(d)
      all_states = bijector.inverse([tf.concat(d, axis=0) for d in all_draws])
      if no_trace:
        return all_states
      else:
        all_traces.append(trace)
        return sample.StatesAndTrace(
            all_states=all_states,
            trace=tf.nest.map_structure(lambda *s: tf.concat(s, axis=0),
                                        *all_traces, expand_composites=True))
def minimize(loss_fn,
             num_steps,
             optimizer,
             convergence_criterion=None,
             batch_convergence_reduce_fn=tf.reduce_all,
             trainable_variables=None,
             trace_fn=_trace_loss,
             return_full_length_trace=True,
             jit_compile=False,
             seed=None,
             name='minimize'):
  """Minimize a loss function using a provided optimizer.

  Args:
    loss_fn: Python callable with signature `loss = loss_fn()`, where `loss`
      is a `Tensor` loss to be minimized. This may optionally take a `seed`
      keyword argument, used to specify a per-iteration seed for stochastic
      loss functions (a stateless `Tensor` seed will be passed; see
      `tfp.random.sanitize_seed`).
    num_steps: Python `int` maximum number of steps to run the optimizer.
    optimizer: Optimizer instance to use. This may be a TF1-style
      `tf.train.Optimizer`, TF2-style `tf.optimizers.Optimizer`, or any Python
      object that implements `optimizer.apply_gradients(grads_and_vars)`.
    convergence_criterion: Optional instance of
      `tfp.optimizer.convergence_criteria.ConvergenceCriterion`
      representing a criterion for detecting convergence. If `None`,
      the optimization will run for `num_steps` steps, otherwise, it will run
      for at *most* `num_steps` steps, as determined by the provided criterion.
      Default value: `None`.
    batch_convergence_reduce_fn: Python `callable` of signature
      `has_converged = batch_convergence_reduce_fn(batch_has_converged)`
      whose input is a `Tensor` of boolean values of the same shape as the
      `loss` returned by `loss_fn`, and output is a scalar
      boolean `Tensor`. This determines the behavior of batched
      optimization loops when `loss_fn`'s return value is non-scalar.
      For example, `tf.reduce_all` will stop the optimization
      once all members of the batch have converged, `tf.reduce_any` once
      *any* member has converged,
      `lambda x: tf.reduce_mean(tf.cast(x, tf.float32)) > 0.5` once more than
      half have converged, etc.
      Default value: `tf.reduce_all`.
    trainable_variables: list of `tf.Variable` instances to optimize with
      respect to. If `None`, defaults to the set of all variables accessed
      during the execution of `loss_fn()`.
      Default value: `None`.
    trace_fn: Python callable with signature `traced_values = trace_fn(
      traceable_quantities)`, where the argument is an instance of
      `tfp.math.MinimizeTraceableQuantities` and the returned `traced_values`
      may be a `Tensor` or nested structure of `Tensor`s. The traced values are
      stacked across steps and returned.
      The default `trace_fn` simply returns the loss. In general, trace
      functions may also examine the gradients, values of parameters,
      the state propagated by the specified `convergence_criterion`, if any (if
      no convergence criterion is specified, this will be `None`),
      as well as any other quantities captured in the closure of `trace_fn`,
      for example, statistics of a variational distribution.
      Default value: `lambda traceable_quantities: traceable_quantities.loss`.
    return_full_length_trace: Python `bool` indicating whether to return a trace
      of the full length `num_steps`, even if a convergence criterion stopped
      the optimization early, by tiling the value(s) traced at the final
      optimization step. This enables use in contexts such as XLA that require
      shapes to be known statically.
      Default value: `True`.
    jit_compile: If True, compiles the minimization loop using
      XLA. XLA performs compiler optimizations, such as fusion, and attempts to
      emit more efficient code. This may drastically improve the performance.
      See the docs for `tf.function`. (In JAX, this will apply `jax.jit`).
      Default value: `False`.
    seed: PRNG seed for stochastic losses; see `tfp.random.sanitize_seed`.
      Default value: `None`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: 'minimize'.

  Returns:
    trace: `Tensor` or nested structure of `Tensor`s, according to the
      return type of `trace_fn`. Each `Tensor` has an added leading dimension
      stacking the trajectory of the traced values over the course of the
      optimization. The size of this dimension is equal to `num_steps` if
      a convergence criterion was not specified and/or
      `return_full_length_trace=True`, and otherwise it is equal
      to the number of optimization steps taken.

  ### Examples

  To minimize the scalar function `(x - 5)**2`:

  ```python
  x = tf.Variable(0.)
  loss_fn = lambda: (x - 5.)**2
  losses = tfp.math.minimize(loss_fn,
                             num_steps=100,
                             optimizer=tf.optimizers.Adam(learning_rate=0.1))

  # In TF2/eager mode, the optimization runs immediately.
  print("optimized value is {} with loss {}".format(x, losses[-1]))
  ```

  In graph mode (e.g., inside of `tf.function` wrapping), retrieving any Tensor
  that depends on the minimization op will trigger the optimization:

  ```python
  with tf.control_dependencies([losses]):
    optimized_x = tf.identity(x)  # Use a dummy op to attach the dependency.
  ```

  We can attempt to automatically detect convergence and stop the optimization
  by passing an instance of
  `tfp.optimizer.convergence_criteria.ConvergenceCriterion`. For example, to
  stop the optimization once a moving average of the per-step decrease in loss
  drops below `0.01`:

  ```python
  losses = tfp.math.minimize(
    loss_fn, num_steps=1000, optimizer=tf.optimizers.Adam(learning_rate=0.1),
    convergence_criterion=(
      tfp.optimizer.convergence_criteria.LossNotDecreasing(atol=0.01)))
  ```

  Here `num_steps=1000` defines an upper bound: the optimization will be
  stopped after 1000 steps even if no convergence is detected.

  In some cases, we may want to track additional context inside the
  optimization. We can do this by defining a custom `trace_fn`. Note that
  the `trace_fn` is passed the loss and gradients, as well as any auxiliary
  state maintained by the convergence criterion (if any), for example, moving
  averages of the loss or gradients, but it may also report the
  values of trainable parameters or other derived quantities by capturing them
  in its closure. For example, we can capture `x` and track its value over the
  optimization:

  ```python
  # `x` is the tf.Variable instance defined above.
  trace_fn = lambda traceable_quantities: {
    'loss': traceable_quantities.loss, 'x': x}
  trace = tfp.math.minimize(loss_fn, num_steps=100,
                            optimizer=tf.optimizers.Adam(0.1),
                            trace_fn=trace_fn)
  print(trace['loss'].shape,   # => [100]
        trace['x'].shape)      # => [100]
  ```

  When optimizing a batch of losses, some batch members will converge before
  others. The optimization will continue until the condition defined by the
  `batch_convergence_reduce_fn` becomes `True`. During these additional steps,
  converged elements will continue to be updated and may become unconverged.
  The convergence status of batch members can be diagnosed by tracing
  `has_converged`:

  ```python
  batch_size = 10
  x = tf.Variable([0.] * batch_size)
  trace_fn = lambda traceable_quantities: {
    'loss': traceable_quantities.loss,
    'has_converged': traceable_quantities.has_converged}
  trace = tfp.math.minimize(loss_fn, num_steps=100,
                            optimizer=tf.optimizers.Adam(0.1),
                            trace_fn=trace_fn,
                            convergence_criterion=(
      tfp.optimizer.convergence_criteria.LossNotDecreasing(atol=0.01)))

  for i in range(batch_size):
    print('Batch element {} final state is {}converged.'
          ' It first converged at step {}.'.format(
          i, '' if trace['has_converged'][-1, i] else 'not ',
          np.argmax(trace['has_converged'][:, i])))
  ```

  """

  if jit_compile:
    # Run the entire minimization inside a jit-compiled function. This is
    # typically faster than jit-compiling the individual steps.
    # `locals()` here contains exactly this function's arguments, so the
    # recursive call below re-enters `minimize` with the same arguments but
    # `jit_compile=False`.
    parameters = dict(locals())
    parameters['jit_compile'] = False
    @tf.function(autograph=False, jit_compile=True)
    def run_jitted_minimize():
      return minimize(**parameters)
    return run_jitted_minimize()

  # Loop condition helper. Its positional parameters mirror the loop vars
  # built below; the optional trailing ones are absent when no convergence
  # criterion is used (None-valued loop vars are filtered out).
  def convergence_detected(step, seed, trace_arrays,
                           has_converged=None,
                           convergence_criterion_state=None):
    del step
    del seed
    del trace_arrays
    del convergence_criterion_state
    return (has_converged is not None  # Convergence criterion in use.
            and batch_convergence_reduce_fn(has_converged))

  # Main optimization routine.
  with tf.name_scope(name) as name:
    seed = samplers.sanitize_seed(seed, salt='minimize')

    # Take an initial training step to obtain the initial loss and values, which
    # will define the shape(s) of the `TensorArray`(s) that we create to hold
    # the results, and are used to initialize the convergence criterion.
    # This will trigger tf.function tracing of `optimizer_step_fn`, which is
    # then reused inside the training loop (i.e., it is only traced once).
    optimizer_step_fn = _make_optimizer_step_fn(
        loss_fn=loss_fn, optimizer=optimizer,
        trainable_variables=trainable_variables)
    initial_loss, initial_grads, initial_parameters = optimizer_step_fn(
        seed=seed)
    has_converged = None
    initial_convergence_criterion_state = None
    if convergence_criterion is not None:
      has_converged = tf.zeros(tf.shape(initial_loss), dtype=tf.bool)
      initial_convergence_criterion_state = convergence_criterion.bootstrap(
          initial_loss, initial_grads, initial_parameters)
    initial_traced_values = trace_fn(
        MinimizeTraceableQuantities(
            loss=initial_loss, gradients=initial_grads,
            parameters=initial_parameters, step=0,
            has_converged=has_converged,
            convergence_criterion_state=initial_convergence_criterion_state))
    trace_arrays = _initialize_arrays(
        initial_values=initial_traced_values,
        num_steps=num_steps,
        truncate_at_convergence=(convergence_criterion is not None and
                                 not return_full_length_trace))

    # Run the optimization loop.
    with tf.control_dependencies([initial_loss]):
      # The step counter starts at 1 (the initial step above is step 0), and
      # the loop runs at most `num_steps - 1` more iterations.
      potential_loop_vars = (1, seed, trace_arrays, has_converged,
                             initial_convergence_criterion_state)
      results = tf.while_loop(
          cond=lambda *args: tf.logical_not(convergence_detected(*args)),  # pylint: disable=no-value-for-parameter
          body=_make_training_loop_body(
              optimizer_step_fn=optimizer_step_fn,
              convergence_criterion=convergence_criterion,
              trace_fn=trace_fn),
          loop_vars=[x for x in potential_loop_vars if x is not None],
          parallel_iterations=1,
          maximum_iterations=num_steps - 1)
      indices, _, trace_arrays = results[:3]  # Guaranteed to be present.

      if convergence_criterion is not None and return_full_length_trace:
        # Fill out the trace by tiling the last written values.
        # NOTE(review): assumes the loop's step counter is one past the last
        # written index when the loop exits — confirm against
        # `_make_training_loop_body`.
        last_written_idx = tf.reduce_max(indices) - 1
        trace_arrays = tf.nest.map_structure(
            lambda ta: _tile_last_written_value(ta, last_written_idx),
            trace_arrays)

    return tf.nest.map_structure(lambda array: array.stack(), trace_arrays)
def _sample_n(self, n, seed=None):
  """Draws `n` truncated-normal samples with shape `[n] + batch_shape`.

  Eagerly, or in graphs outside an XLA context, this delegates to
  `tf.random.stateless_parameterized_truncated_normal`. Otherwise it samples a
  standard truncated normal with a custom gradient for the truncation bounds,
  then shifts and scales by `loc` and `scale`.
  """
  loc, scale, low, high = self._loc_scale_low_high()
  batch_shape = self._batch_shape_tensor(
      loc=loc, scale=scale, low=low, high=high)
  sample_and_batch_shape = ps.concat([[n], batch_shape], 0)
  # TODO(b/162522020): Use this behavior unconditionally.
  if (tf.executing_eagerly() or
      not control_flow_util.GraphOrParentsInXlaContext(
          tf1.get_default_graph())):
    return tf.random.stateless_parameterized_truncated_normal(
        shape=sample_and_batch_shape,
        means=loc,
        stddevs=scale,
        minvals=low,
        maxvals=high,
        seed=samplers.sanitize_seed(seed))

  flat_batch_and_sample_shape = tf.stack([tf.reduce_prod(batch_shape), n])

  # In order to be reparameterizable, we sample from a truncated normal with
  # zero mean and unit variance (using the standardized truncation bounds) and
  # then shift by `loc` and scale by `scale` at the end.

  @tf.custom_gradient
  def _std_samples_with_gradients(lower, upper):
    """Standard truncated Normal with gradient support for low, high."""
    # Note: Unlike the convention in TFP, parameterized_truncated_normal
    # returns a tensor with the final dimension being the sample dimension.
    # NOTE(review): `seed` is passed here unsanitized to a stateful op;
    # confirm callers on this (XLA) path supply an int-compatible seed.
    std_samples = random_ops.parameterized_truncated_normal(
        shape=flat_batch_and_sample_shape,
        means=0.0,
        stddevs=1.0,
        minvals=lower,
        maxvals=upper,
        dtype=self.dtype,
        seed=seed)

    def grad(dy):
      """Computes a derivative for the min and max parameters.

      This function implements the derivative wrt the truncation bounds, which
      get blocked by the sampler. We use a custom expression for numerical
      stability instead of automatic differentiation on CDF for implicit
      gradients.

      Args:
        dy: output gradients

      Returns:
        A list `[grad_l, grad_u]`: the gradients with respect to `lower` and
        `upper` (in that order), each summed over the sample dimension.
      """
      # std_samples has an extra dimension (the sample dimension), expand
      # lower and upper so they broadcast along this dimension.
      # See note above regarding parameterized_truncated_normal, the sample
      # dimension is the final dimension.
      lower_broadcast = lower[..., tf.newaxis]
      upper_broadcast = upper[..., tf.newaxis]

      # c = truncated-normal CDF of each sample within [lower, upper].
      cdf_samples = ((special_math.ndtr(std_samples) -
                      special_math.ndtr(lower_broadcast)) /
                     (special_math.ndtr(upper_broadcast)
                      - special_math.ndtr(lower_broadcast)))

      # tiny, eps are tolerance parameters to ensure we stay away from giving
      # a zero arg to the log CDF expression.

      tiny = np.finfo(dtype_util.as_numpy_dtype(self.dtype)).tiny
      eps = np.finfo(dtype_util.as_numpy_dtype(self.dtype)).eps
      cdf_samples = tf.clip_by_value(cdf_samples, tiny, 1 - eps)

      # Implicit differentiation of c(x; lower, upper) = const gives
      #   dx/dupper = c * phi(upper) / phi(x)
      #             = exp(0.5 * (x**2 - upper**2) + log(c)),
      #   dx/dlower = (1 - c) * phi(lower) / phi(x)
      #             = exp(0.5 * (x**2 - lower**2) + log1p(-c)),
      # evaluated in log space for stability.
      du = tf.exp(0.5 * (std_samples**2 - upper_broadcast**2) +
                  tf.math.log(cdf_samples))
      dl = tf.exp(0.5 * (std_samples**2 - lower_broadcast**2) +
                  tf.math.log1p(-cdf_samples))

      # Reduce the gradient across the samples
      grad_u = tf.reduce_sum(dy * du, axis=-1)
      grad_l = tf.reduce_sum(dy * dl, axis=-1)
      return [grad_l, grad_u]

    return std_samples, grad

  std_low, std_high = self._standardized_low_and_high(
      low=low, high=high, loc=loc, scale=scale)
  # Broadcast the bounds against each other so they flatten to equal lengths.
  low_high_shp = tf.broadcast_dynamic_shape(
      tf.shape(std_low), tf.shape(std_high))
  std_low = tf.broadcast_to(std_low, low_high_shp)
  std_high = tf.broadcast_to(std_high, low_high_shp)

  std_samples = _std_samples_with_gradients(
      tf.reshape(std_low, [-1]), tf.reshape(std_high, [-1]))

  # The returned shape is [flat_batch x n]
  std_samples = tf.transpose(std_samples, perm=[1, 0])

  std_samples = tf.reshape(std_samples, sample_and_batch_shape)
  return std_samples * scale[tf.newaxis] + loc[tf.newaxis]
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Takes one step of the TransitionKernel.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
      `list` of `Tensor`s representing internal calculations made within the
      previous call to this function (or as returned by `bootstrap_results`).
    seed: Optional, a seed for reproducible sampling.

  Returns:
    next_state: `Tensor` or Python `list` of `Tensor`s representing the
      next state(s) of the Markov chain(s).
    kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
      `Tensor`s representing internal calculations made within this function.
      This includes replica states.
  """

  # The code below propagates one step states of shape
  #  [n_replica] + batch_shape + event_shape.
  #
  # The step is done in three parts:
  #  1) Call one_step to transition states via a tempered version of
  #     self.target_log_prob_fn (see _replica_target_log_prob).
  #  2) Permute values in states
  #  3) Update state-dependent values, such as log_probs.
  #
  # We chose to swap states, rather than temperatures, because...
  # (i)  If swapping temperatures, you *still* have to swap log_probs to
  #      determine acceptance, as well as states (for kernel results).
  #      So it's just as difficult to swap temperatures.
  # (ii) If swapping temperatures, you have to take care to swap any user-
  #      supplied temperature related things (like step size).
  #      A-priori, we don't know what else will need to be swapped!
  # (iii)In both cases, the kernel results need to be updated in a non-trivial
  #      manner....so we either special-case, or use bootstrap.

  with tf.name_scope(mcmc_util.make_name(self.name, 'remc', 'one_step')):
    # Force a read in case the `inverse_temperatures` is a `tf.Variable`.
    inverse_temperatures = tf.convert_to_tensor(
        previous_kernel_results.inverse_temperatures,
        name='inverse_temperatures')

    target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
        target_log_prob_fn=self.target_log_prob_fn,
        inverse_temperatures=inverse_temperatures,
        untempered_log_prob_fn=self.untempered_log_prob_fn,
        tempered_log_prob_fn=self.tempered_log_prob_fn,
    )
    # TODO(b/159636942): Clean up the helpful error msg after 2020-11-10.
    try:
      inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
          target_log_prob_for_inner_kernel)
    except TypeError as e:
      # Heuristic: a TypeError mentioning 'argument' is assumed to come from
      # a `make_kernel_fn` still expecting the removed `seed` argument.
      if 'argument' not in str(e):
        raise
      raise TypeError(
          '`ReplicaExchangeMC`s `make_kernel_fn` no longer receives a `seed` '
          'argument. `TransitionKernel` instances now receive seeds via '
          '`one_step`.')

    seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
    # Independent seeds for the inner step, the swap proposal, and the
    # accept/reject uniform draws.
    inner_seed, swap_seed, logu_seed = samplers.split_seed(seed, n=3)
    # Step the inner TransitionKernel.
    [
        pre_swap_replica_states,
        pre_swap_replica_results,
    ] = inner_kernel.one_step(
        previous_kernel_results.post_swap_replica_states,
        previous_kernel_results.post_swap_replica_results,
        seed=inner_seed)

    pre_swap_replica_target_log_prob = _get_field(
        # These are tempered log probs (have been divided by temperature).
        pre_swap_replica_results, 'target_log_prob')

    dtype = pre_swap_replica_target_log_prob.dtype
    replica_and_batch_shape = ps.shape(
        pre_swap_replica_target_log_prob)
    batch_shape = replica_and_batch_shape[1:]
    replica_and_batch_rank = ps.rank(pre_swap_replica_target_log_prob)
    num_replica = ps.size0(inverse_temperatures)

    inverse_temperatures = bu.left_justified_broadcast_to(
        inverse_temperatures, replica_and_batch_shape)

    # Now that each replica has done one_step, it is time to consider swaps.

    # swap.shape = [n_replica], and is a "once only" permutation, meaning it
    # is achievable by a sequence of pairwise permutations, where each element
    # is moved at most once.
    # E.g. if swaps = [1, 0, 2], we will consider swapping temperatures 0 and
    # 1, keeping 2 fixed.  This exact same swap is considered for *every*
    # batch member.  Of course some batch members may accept and some reject.
    try:
      swaps = tf.cast(
          self.swap_proposal_fn(  # pylint: disable=not-callable
              num_replica,
              batch_shape=batch_shape,
              seed=swap_seed,
              step_count=previous_kernel_results.step_count),
          dtype=tf.int32)
    except TypeError as e:
      if 'step_count' not in str(e):
        raise
      warnings.warn(
          'The `swap_proposal_fn` given to ReplicaExchangeMC did not accept '
          'the `step_count` argument. Falling back to omitting the '
          'argument. This fallback will be removed after 24-Oct-2020.'
      )
      swaps = tf.cast(
          self.swap_proposal_fn(  # pylint: disable=not-callable
              num_replica,
              batch_shape=batch_shape,
              seed=swap_seed),
          dtype=tf.int32)

    # The identity permutation [0, 1, ..., num_replica - 1], shaped like swaps.
    null_swaps = bu.left_justified_expand_dims_like(
        tf.range(num_replica, dtype=swaps.dtype), swaps)
    swaps = _maybe_embed_swaps_validation(swaps, null_swaps,
                                          self.validate_args)

    # Un-temper the log probs for use in the swap acceptance ratio.
    if self.tempered_log_prob_fn is None:
      # Efficient way of re-evaluating target_log_prob_fn on the
      # pre_swap_replica_states.
      untempered_negative_energy_ignoring_ulp = (
          # Since untempered_log_prob_fn is None, we may assume
          # inverse_temperatures > 0 (else the target is improper).
          pre_swap_replica_target_log_prob / inverse_temperatures)
    else:
      # The untempered_log_prob_fn does not factor into the acceptance ratio.
      # Proof: Suppose the tempered target is
      #   p_k(x) = f(x)^{beta_k} g(x),
      # So f(x) is tempered, and g(x) is not.  Then, the acceptance ratio for
      # a 1 <--> 2 swap is...
      #   (p_1(x_2) p_2(x_1)) / (p_1(x_1) p_2(x_2))
      # which depends only on f(x), since terms involving g(x) cancel.
      untempered_negative_energy_ignoring_ulp = self.tempered_log_prob_fn(
          *pre_swap_replica_states)

    # Since `swaps` is its own inverse permutation we automatically know the
    # swap counterpart: range(num_replica). We use this idea to compute the
    # acceptance in a vectorized manner at the cost of wasting roughly half
    # our computation. Although we could use `unique` to solve this problem,
    # we expect the cost of `unique` to be higher than the dozens of wasted
    # arithmetic calculations. Worse, it'd mean we need dynamic sized Tensors
    # (eg, using `tf.where(bool)`) and so we wouldn't be able to XLA compile.

    # Note: diffs would normally be "proposed - current" however energy is
    # flipped since `energy == -log_prob`.
    # Note: The untempered_log_prob_fn (if provided) is not included in
    # untempered_negative_energy_ignoring_ulp, and hence does not factor
    # into energy_diff. Why? Because, it cancels out in the acceptance ratio.
    energy_diff = (
        untempered_negative_energy_ignoring_ulp -
        mcmc_util.index_remapping_gather(
            untempered_negative_energy_ignoring_ulp,
            swaps, name='gather_swap_tlp'))
    swapped_inverse_temperatures = mcmc_util.index_remapping_gather(
        inverse_temperatures, swaps, name='gather_swap_temps')
    inverse_temp_diff = swapped_inverse_temperatures - inverse_temperatures

    # If replicas i and j are swapping, log_accept_ratio[i] and
    # log_accept_ratio[j] are equal.
    log_accept_ratio = (
        energy_diff * bu.left_justified_expand_dims_to(
            inverse_temp_diff, replica_and_batch_rank))

    # Non-finite ratios (NaN/inf) become -inf, i.e. the swap is rejected.
    log_accept_ratio = tf.where(
        tf.math.is_finite(log_accept_ratio),
        log_accept_ratio, tf.constant(-np.inf, dtype=dtype))

    # Produce log[Uniform] draws that are identical at swapped indices.
    log_uniform = tf.math.log(
        samplers.uniform(shape=replica_and_batch_shape,
                         dtype=dtype,
                         seed=logu_seed))
    # Both members of a swap pair read the draw at the smaller index.
    anchor_swaps = tf.minimum(swaps, null_swaps)
    log_uniform = mcmc_util.index_remapping_gather(log_uniform, anchor_swaps)

    is_swap_accepted_mask = tf.less(
        log_uniform,
        log_accept_ratio,
        name='is_swap_accepted_mask')

    # Applies the permutation `swaps` wherever the swap was accepted.
    def _swap_tensor(x):
      return mcmc_util.choose(
          is_swap_accepted_mask,
          mcmc_util.index_remapping_gather(x, swaps),
          x)

    post_swap_replica_states = [
        _swap_tensor(s) for s in pre_swap_replica_states]

    expanded_null_swaps = bu.left_justified_broadcast_to(
        null_swaps, replica_and_batch_shape)
    is_swap_proposed = _compute_swap_notmatrix(
        # Broadcast both so they have shape [num_replica] + batch_shape.
        # This (i) makes them have same shape as is_swap_accepted, and
        # (ii) keeps shape consistent if someday swaps has a batch shape.
        expanded_null_swaps,
        bu.left_justified_broadcast_to(swaps, replica_and_batch_shape))

    # To get is_swap_accepted in ordered position, we use
    # _compute_swap_notmatrix on current and next replica positions.
    post_swap_replica_position = _swap_tensor(expanded_null_swaps)

    is_swap_accepted = _compute_swap_notmatrix(
        post_swap_replica_position,
        expanded_null_swaps)

    if self._state_includes_replicas:
      post_swap_states = post_swap_replica_states
    else:
      # Only replica 0's state is exposed to the caller.
      post_swap_states = [s[0] for s in post_swap_replica_states]

    post_swap_replica_results = _set_swapped_fields_to_nan(
        _swap_log_prob_and_maybe_grads(
            pre_swap_replica_results,
            post_swap_replica_states,
            inner_kernel))

    if mcmc_util.is_list_like(current_state):
      # We *always* canonicalize the states in the kernel results.
      states = post_swap_states
    else:
      states = post_swap_states[0]

    post_swap_kernel_results = ReplicaExchangeMCKernelResults(
        post_swap_replica_states=post_swap_replica_states,
        pre_swap_replica_results=pre_swap_replica_results,
        post_swap_replica_results=post_swap_replica_results,
        is_swap_proposed=is_swap_proposed,
        is_swap_accepted=is_swap_accepted,
        is_swap_proposed_adjacent=_sub_diag(is_swap_proposed),
        is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
        # Store the original pkr.inverse_temperatures in case it's a
        # `tf.Variable`.
        inverse_temperatures=previous_kernel_results.inverse_temperatures,
        swaps=swaps,
        step_count=previous_kernel_results.step_count + 1,
        seed=seed,
        potential_energy=-untempered_negative_energy_ignoring_ulp,
    )

    return states, post_swap_kernel_results
def _sample_n(self, n, seed, **kwargs):
  """Draws `n` samples from the wrapped distribution with a per-replica seed.

  The incoming seed is sanitized (salted with 'sharded_sample') and then
  folded with this object's `replica_id`, so that replicas with distinct ids
  receive distinct seeds.

  Args:
    n: Sample shape forwarded to `self.distribution.sample`.
    seed: PRNG seed; see `samplers.sanitize_seed`.
    **kwargs: Extra keyword arguments forwarded to
      `self.distribution.sample`.

  Returns:
    Samples from `self.distribution`.
  """
  salted_seed = samplers.sanitize_seed(seed, salt='sharded_sample')
  replica_index = tf.cast(self.replica_id, tf.int32)
  replica_seed = samplers.fold_in(salted_seed, replica_index)
  return self.distribution.sample(
      sample_shape=n, seed=replica_seed, **kwargs)
def sample_chain(
    num_results,
    current_state,
    previous_kernel_results=None,
    kernel=None,
    num_burnin_steps=0,
    num_steps_between_results=0,
    trace_fn=lambda current_state, kernel_results: kernel_results,
    return_final_kernel_results=False,
    parallel_iterations=10,
    seed=None,
    name=None,
):
  """Implements Markov chain Monte Carlo via repeated `TransitionKernel` steps.

  This function samples from a Markov chain at `current_state` whose
  stationary distribution is governed by the supplied `TransitionKernel`
  instance (`kernel`).

  The `current_state` can be represented as a single `Tensor` or a `list` of
  `Tensors` which collectively represent the current state.

  This function can sample from multiple chains, in parallel. Whether or not
  there are multiple chains is dictated by how the `kernel` treats its inputs.
  Typically, the shape of the independent chains is the shape of the result
  of the `target_log_prob_fn` used by the `kernel` when applied to the given
  `current_state`.

  Since MCMC states are correlated, it is sometimes desirable to produce
  additional intermediate states, and then discard them, ending up with a set
  of states with decreased autocorrelation. See [Owen (2017)][1]. Such
  'thinning' is made possible by setting `num_steps_between_results > 0`. The
  chain then takes `num_steps_between_results` extra steps between the steps
  that make it into the results. The extra steps are never materialized, and
  thus do not increase memory requirements.

  In addition to returning the chain state, this function supports tracing of
  auxiliary variables used by the kernel. The traced values are selected by
  specifying `trace_fn`. By default, all kernel results are traced but in the
  future the default will be changed to no results being traced, so plan
  accordingly. See below for some examples of this feature.

  Args:
    num_results: Integer number of Markov chain draws.
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: A `Tensor` or a nested collection of `Tensor`s
      representing internal calculations made within the previous call to this
      function (or as returned by `bootstrap_results`). If `None`, it is
      computed with `kernel.bootstrap_results(current_state)`.
    kernel: An instance of `tfp.mcmc.TransitionKernel` which implements one
      step of the Markov chain. Although the signature defaults it to `None`,
      a kernel is required: `kernel.is_calibrated` is read unconditionally.
    num_burnin_steps: Integer number of chain steps to take before starting to
      collect results.
      Default value: 0 (i.e., no burn-in).
    num_steps_between_results: Integer number of chain steps between
      collecting a result. Only one out of every
      `num_steps_between_results + 1` steps is included in the returned
      results. The number of returned chain states is still equal to
      `num_results`.
      Default value: 0 (i.e., no thinning).
    trace_fn: A callable that takes in the current chain state and the
      previous kernel results and returns a `Tensor` or a nested collection
      of `Tensor`s that is then traced along with the chain state. Pass `None`
      to disable tracing. Relying on the default value emits a deprecation
      warning.
    return_final_kernel_results: If `True`, then the final kernel results are
      returned alongside the chain state and the trace specified by the
      `trace_fn`.
    parallel_iterations: The number of iterations allowed to run in parallel.
      It must be a positive integer. See `tf.while_loop` for more details.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details. When `None`,
      `kernel.one_step` is called without a `seed` argument.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'mcmc_sample_chain').

  Returns:
    checkpointable_states_and_trace: if `return_final_kernel_results` is
      `True`. The return value is an instance of
      `CheckpointableStatesAndTrace`.
    all_states: if `return_final_kernel_results` is `False` and `trace_fn` is
      `None`. The return value is a `Tensor` or Python list of `Tensor`s
      representing the state(s) of the Markov chain(s) at each result step.
      Has same shape as input `current_state` but with a prepended
      `num_results`-size dimension.
    states_and_trace: if `return_final_kernel_results` is `False` and
      `trace_fn` is not `None`. The return value is an instance of
      `StatesAndTrace`.

  #### Examples

  ##### Sample from a diagonal-variance Gaussian.

  I.e.,

  ```none
  for i=1..n:
    x[i] ~ MultivariateNormal(loc=0, scale=diag(true_stddev))  # likelihood
  ```

  ```python
  import tensorflow as tf
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  dims = 10
  true_stddev = tf.sqrt(tf.linspace(1., 3., dims))
  likelihood = tfd.MultivariateNormalDiag(loc=0., scale_diag=true_stddev)

  states = tfp.mcmc.sample_chain(
      num_results=1000,
      num_burnin_steps=500,
      current_state=tf.zeros(dims),
      kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=likelihood.log_prob,
        step_size=0.5,
        num_leapfrog_steps=2),
      trace_fn=None)

  sample_mean = tf.reduce_mean(states, axis=0)
  # ==> approx all zeros

  sample_stddev = tf.sqrt(tf.reduce_mean(
      tf.math.squared_difference(states, sample_mean),
      axis=0))
  # ==> approx equal true_stddev
  ```

  ##### Sampling from factor-analysis posteriors with known factors.

  I.e.,

  ```none
  # prior
  w ~ MultivariateNormal(loc=0, scale=eye(d))
  for i=1..n:
    # likelihood
    x[i] ~ Normal(loc=w^T F[i], scale=1)
  ```

  where `F` denotes factors.

  ```python
  import tensorflow as tf
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  # Specify model.
  def make_prior(dims):
    return tfd.MultivariateNormalDiag(
        loc=tf.zeros(dims))

  def make_likelihood(weights, factors):
    return tfd.MultivariateNormalDiag(
        loc=tf.matmul(weights, factors, adjoint_b=True))

  def joint_log_prob(num_weights, factors, x, w):
    return (make_prior(num_weights).log_prob(w) +
            make_likelihood(w, factors).log_prob(x))

  def unnormalized_log_posterior(w):
    # Posterior is proportional to: `p(W, X=x | factors)`.
    return joint_log_prob(num_weights, factors, x, w)

  # Setup data.
  num_weights = 10 # == d
  num_factors = 40 # == n
  num_chains = 100

  weights = make_prior(num_weights).sample(1)
  factors = tf.random.normal([num_factors, num_weights])
  x = make_likelihood(weights, factors).sample()

  # Sample from Hamiltonian Monte Carlo Markov Chain.

  # Get `num_results` samples from `num_chains` independent chains.
  chains_states, kernels_results = tfp.mcmc.sample_chain(
      num_results=1000,
      num_burnin_steps=500,
      current_state=tf.zeros([num_chains, num_weights], name='init_weights'),
      kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=unnormalized_log_posterior,
        step_size=0.1,
        num_leapfrog_steps=2))

  # Compute sample stats.
  sample_mean = tf.reduce_mean(chains_states, axis=[0, 1])
  # ==> approx equal to weights

  sample_var = tf.reduce_mean(
      tf.math.squared_difference(chains_states, sample_mean),
      axis=[0, 1])
  # ==> less than 1
  ```

  ##### Custom tracing functions.

  ```python
  import tensorflow as tf
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  likelihood = tfd.Normal(loc=0., scale=1.)

  def sample_chain(trace_fn):
    return tfp.mcmc.sample_chain(
      num_results=1000,
      num_burnin_steps=500,
      current_state=0.,
      kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=likelihood.log_prob,
        step_size=0.5,
        num_leapfrog_steps=2),
      trace_fn=trace_fn)

  def trace_log_accept_ratio(states, previous_kernel_results):
    return previous_kernel_results.log_accept_ratio

  def trace_everything(states, previous_kernel_results):
    return previous_kernel_results

  _, log_accept_ratio = sample_chain(trace_fn=trace_log_accept_ratio)
  _, kernel_results = sample_chain(trace_fn=trace_everything)

  acceptance_prob = tf.math.exp(tf.minimum(log_accept_ratio, 0.))
  # Equivalent to, but more efficient than:
  acceptance_prob = tf.math.exp(tf.minimum(
      kernel_results.log_accept_ratio, 0.))
  ```

  #### References

  [1]: Art B. Owen. Statistically efficient thinning of a Markov chain sampler.
       _Technical Report_, 2017.
       http://statweb.stanford.edu/~owen/reports/bestthinning.pdf
  """
  # Remember whether the caller supplied a seed. If not, `kernel.one_step` is
  # invoked without a `seed` kwarg (see `_seeded_one_step` below).
  is_seeded = seed is not None
  seed = samplers.sanitize_seed(seed, salt='mcmc.sample_chain')

  if not kernel.is_calibrated:
    warnings.warn('supplied `TransitionKernel` is not calibrated. Markov '
                  'chain may not converge to intended target distribution.')
  with tf.name_scope(name or 'mcmc_sample_chain'):
    num_results = ps.convert_to_shape_tensor(
        num_results, dtype=tf.int32, name='num_results')
    num_burnin_steps = ps.convert_to_shape_tensor(
        num_burnin_steps, dtype=tf.int32, name='num_burnin_steps')
    num_steps_between_results = ps.convert_to_shape_tensor(
        num_steps_between_results,
        dtype=tf.int32,
        name='num_steps_between_results')
    current_state = tf.nest.map_structure(
        lambda x: tf.convert_to_tensor(x, name='current_state'),
        current_state)
    if previous_kernel_results is None:
      previous_kernel_results = kernel.bootstrap_results(current_state)

    if trace_fn is None:
      # It simplifies the logic to use a dummy function here.
      trace_fn = lambda *args: ()
      no_trace = True
    else:
      no_trace = False
    # Index 4 of `__defaults__` is the default value of `trace_fn` (the
    # defaults start at `previous_kernel_results`). NOTE: this index must be
    # updated if parameters with defaults are added or reordered before
    # `trace_fn`.
    if trace_fn is sample_chain.__defaults__[4]:
      warnings.warn('Tracing all kernel results by default is deprecated. Set '
                    'the `trace_fn` argument to None (the future default '
                    'value) or an explicit callback that traces the values '
                    'you are interested in.')

    def _seeded_one_step(seed, *state_and_results):
      # Loop body: when seeded, split off a fresh seed for this step and carry
      # the remainder forward; the carried seed is the first loop variable.
      step_seed, passalong_seed = (
          samplers.split_seed(seed) if is_seeded else (None, seed))
      one_step_kwargs = dict(seed=step_seed) if is_seeded else {}
      return [passalong_seed] + list(
          kernel.one_step(*state_and_results, **one_step_kwargs))

    def _trace_scan_fn(seed_state_and_results, num_steps):
      # Runs `num_steps` kernel steps between two traced results.
      seed, next_state, current_kernel_results = loop_util.smart_for_loop(
          loop_num_iter=num_steps,
          body_fn=_seeded_one_step,
          initial_loop_vars=list(seed_state_and_results),
          parallel_iterations=parallel_iterations)
      return seed, next_state, current_kernel_results

    # `elems` gives the number of kernel steps per traced result: the first
    # entry (index 0) is `1 + num_burnin_steps`, every other entry is
    # `1 + num_steps_between_results`, for `num_results` entries in total.
    (_, _, final_kernel_results), (all_states, trace) = loop_util.trace_scan(
        loop_fn=_trace_scan_fn,
        initial_state=(seed, current_state, previous_kernel_results),
        elems=tf.one_hot(
            indices=0,
            depth=num_results,
            on_value=1 + num_burnin_steps,
            off_value=1 + num_steps_between_results,
            dtype=tf.int32),
        # pylint: disable=g-long-lambda
        trace_fn=lambda seed_state_and_results: (
            seed_state_and_results[1],
            trace_fn(*seed_state_and_results[1:])),
        # pylint: enable=g-long-lambda
        parallel_iterations=parallel_iterations)

    if return_final_kernel_results:
      return CheckpointableStatesAndTrace(
          all_states=all_states,
          trace=trace,
          final_kernel_results=final_kernel_results)
    else:
      if no_trace:
        return all_states
      else:
        return StatesAndTrace(all_states=all_states, trace=trace)
def step_kernel(
    num_steps,
    current_state,
    previous_kernel_results=None,
    kernel=None,
    return_final_kernel_results=False,
    parallel_iterations=10,
    seed=None,
    name=None,
):
  """Applies `kernel.one_step` to `current_state` `num_steps` times.

  A minimal driver for executing `TransitionKernel`s. For burn-in, thinning
  and tracing, use `sample_chain` instead.

  Args:
    num_steps: Integer number of Markov chain steps.
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: A `Tensor` or a nested collection of `Tensor`s.
      Warm-start for the auxiliary state needed by the given `kernel`. If not
      supplied, `step_kernel` cold-starts with `kernel.bootstrap_results`.
    kernel: An instance of `tfp.mcmc.TransitionKernel` which implements one
      step of the Markov chain. Required in practice: `kernel.is_calibrated`
      is read unconditionally.
    return_final_kernel_results: If `True`, the final kernel results are
      returned alongside the chain state reached after `num_steps` steps.
      This can be useful to inspect the final auxiliary state, or for a later
      warm restart.
    parallel_iterations: The number of iterations allowed to run in parallel.
      It must be a positive integer. See `tf.while_loop` for more details.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details. When `None`,
      `kernel.one_step` is called without a `seed` argument.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'mcmc_step_kernel').

  Returns:
    next_state: Markov chain state after `num_steps` steps are taken, of
      identical type as `current_state`.
    final_kernel_results: kernel results, as supplied by `kernel.one_step`
      after `num_steps` steps are taken. Only returned if
      `return_final_kernel_results` is `True`.
  """
  caller_supplied_seed = seed is not None
  loop_seed = samplers.sanitize_seed(
      seed, salt='experimental.mcmc.step_kernel')

  if not kernel.is_calibrated:
    warnings.warn('supplied `TransitionKernel` is not calibrated. Markov '
                  'chain may not converge to intended target distribution.')

  with tf.name_scope(name or 'mcmc_step_kernel'):
    num_steps = tf.convert_to_tensor(
        num_steps, dtype=tf.int32, name='num_steps')

    def _as_state_tensor(part):
      return tf.convert_to_tensor(part, name='current_state')

    current_state = tf.nest.map_structure(_as_state_tensor, current_state)
    if previous_kernel_results is None:
      previous_kernel_results = kernel.bootstrap_results(current_state)

    def _body(carried_seed, *state_and_results):
      # The seed is carried as the first loop variable. When the caller gave
      # a seed, split off one sub-seed per step; otherwise step unseeded.
      if caller_supplied_seed:
        this_step_seed, carried_seed = samplers.split_seed(carried_seed)
        stepped = kernel.one_step(*state_and_results, seed=this_step_seed)
      else:
        stepped = kernel.one_step(*state_and_results)
      return [carried_seed, *stepped]

    _, final_state, last_kernel_results = mcmc_util.smart_for_loop(
        loop_num_iter=num_steps,
        body_fn=_body,
        initial_loop_vars=[loop_seed, current_state, previous_kernel_results],
        parallel_iterations=parallel_iterations)

    # Plain tuples suffice here; no named result types as in `sample_chain`.
    if not return_final_kernel_results:
      return final_state
    return final_state, last_kernel_results
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Runs one HMC proposal: draw fresh momentum, then integrate with leapfrog.

  No accept/reject decision is made here; the returned state is whatever the
  leapfrog integrator reached.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s holding the current
      chain state.
    previous_kernel_results: kernel results from the previous step (or from
      `bootstrap_results`).
    seed: PRNG seed; sanitized with `samplers.sanitize_seed` and stored in the
      returned kernel results.

  Returns:
    next_state: integrator output with the same structure as `current_state`
      (unwrapped to a single part when `current_state` is not list-like).
    kernel_results: `previous_kernel_results` with the log-acceptance
      correction, target log-prob, gradients, initial/final momenta and seed
      replaced.
  """
  with tf.name_scope(mcmc_util.make_name(self.name, 'hmc', 'one_step')):
    # Step size and leapfrog count come either from the kernel results or
    # from the kernel itself, depending on `_store_parameters_in_results`.
    param_source = (previous_kernel_results
                    if self._store_parameters_in_results else self)
    step_size = param_source.step_size
    num_leapfrog_steps = param_source.num_leapfrog_steps

    (state_parts, step_sizes, target_log_prob,
     target_log_prob_grad_parts) = _prepare_args(
         self.target_log_prob_fn,
         current_state,
         step_size,
         previous_kernel_results.target_log_prob,
         previous_kernel_results.grads_target_log_prob,
         maybe_expand=True,
         state_gradients_are_stopped=self.state_gradients_are_stopped)

    seed = samplers.sanitize_seed(seed)  # Kept in the results for diagnostics.
    part_seeds = distribute_lib.fold_in_axis_index(
        samplers.split_seed(seed, n=len(state_parts)),
        self.experimental_shard_axis_names)

    # One standard-normal momentum draw per state part.
    momentum_parts = [
        samplers.normal(
            shape=ps.shape(part),
            dtype=self._momentum_dtype or dtype_util.base_dtype(part.dtype),
            seed=part_seed)
        for part_seed, part in zip(part_seeds, state_parts)
    ]

    leapfrog = leapfrog_impl.SimpleLeapfrogIntegrator(
        self.target_log_prob_fn, step_sizes, num_leapfrog_steps)
    (final_momentum_parts, next_state_parts, next_target_log_prob,
     next_grad_parts) = leapfrog(momentum_parts, state_parts,
                                 target_log_prob, target_log_prob_grad_parts)

    if self.state_gradients_are_stopped:
      next_state_parts = [tf.stop_gradient(part) for part in next_state_parts]

    independent_chain_ndims = ps.rank(target_log_prob)
    log_acceptance_correction = _compute_log_acceptance_correction(
        momentum_parts,
        final_momentum_parts,
        independent_chain_ndims,
        shard_axis_names=self.experimental_shard_axis_names)
    new_kernel_results = previous_kernel_results._replace(
        log_acceptance_correction=log_acceptance_correction,
        target_log_prob=next_target_log_prob,
        grads_target_log_prob=next_grad_parts,
        initial_momentum=momentum_parts,
        final_momentum=final_momentum_parts,
        seed=seed,
    )

    if mcmc_util.is_list_like(current_state):
      return next_state_parts, new_kernel_results
    return next_state_parts[0], new_kernel_results
def _windowed_adaptive_impl(n_draws,
                            joint_dist,
                            *,
                            kind,
                            n_chains,
                            proposal_kernel_kwargs,
                            num_adaptation_steps,
                            dual_averaging_kwargs,
                            trace_fn,
                            return_final_kernel_results,
                            discard_tuning,
                            seed,
                            **pins):
  """Runs windowed sampling using either HMC or NUTS as internal sampler.

  The adaptation schedule is:
    * one fast window,
    * four slow windows whose sizes double each time,
    * one final fast window.
  After that, `n_draws` further draws are taken with the adapted step size
  and momentum distribution.

  Args:
    n_draws: Number of post-adaptation draws (passed to `_do_sampling` as
      `num_draws`).
    joint_dist: Joint distribution to sample; passed, together with `**pins`,
      to `_setup_mcmc`.
    kind: Selects the inner sampler; forwarded to the window helpers and to
      `_do_sampling`.
    n_chains: Number of chains; also the leading dimension of the initial
      per-chain step size.
    proposal_kernel_kwargs: Keyword arguments for the proposal kernel. The
      caller's dict is not modified: a copy is made, and
      `target_log_prob_fn`, `step_size` and `momentum_distribution` are
      filled in on the copy.
    num_adaptation_steps: Total number of tuning steps, split into windows
      by `_get_window_sizes`.
    dual_averaging_kwargs: Forwarded to `_fast_window` / `_slow_window`.
    trace_fn: Callable forwarded to the sampling helpers, or `None` to disable
      tracing.
    return_final_kernel_results: If `True`, also return the final kernel
      results of the post-adaptation run.
    discard_tuning: If `True`, only the post-adaptation draws (and traces)
      are returned. Otherwise the tuning draws/traces are concatenated in
      front of them along axis 0.
    seed: PRNG seed; sanitized and split across setup, each window and the
      final sampling run.
    **pins: Extra keyword arguments forwarded to `_setup_mcmc` (presumably
      values to pin in `joint_dist` -- confirm against `_setup_mcmc`).

  Returns:
    The states are mapped back through `bijector.inverse`. The return value is
      * `sample.CheckpointableStatesAndTrace` if
        `return_final_kernel_results`;
      * otherwise the states alone if `trace_fn` is `None`;
      * otherwise `sample.StatesAndTrace`.
  """
  if trace_fn is None:
    trace_fn = lambda *args: ()
    no_trace = True
  else:
    no_trace = False

  # Work on a copy so the `update` calls below never leak into the caller's
  # dict.
  proposal_kernel_kwargs = dict(proposal_kernel_kwargs)

  num_adaptation_steps = tf.convert_to_tensor(num_adaptation_steps)

  setup_seed, init_seed, seed = samplers.split_seed(
      samplers.sanitize_seed(seed), n=3)
  target_log_prob_fn, initial_transformed_position, bijector = _setup_mcmc(
      joint_dist, n_chains=n_chains, seed=setup_seed, **pins)

  first_window_size, slow_window_size, last_window_size = _get_window_sizes(
      num_adaptation_steps)
  # If we (over) optimistically assume good scaling, this will be near the
  # optimal step size, see Langmore, Ian, Michael Dikovsky, Scott Geraedts,
  # Peter Norgaard, and Rob Von Behren. 2019. "A Condition Number for
  # Hamiltonian Monte Carlo." arXiv [stat.CO]. arXiv.
  # http://arxiv.org/abs/1905.09813.
  # The cast targets the position's dtype (rather than a hard-coded float32)
  # so that float64 models get a float64 step size.
  init_step_size = tf.cast(
      ps.shape(initial_transformed_position)[-1],
      initial_transformed_position.dtype) ** -0.25

  all_draws = []
  all_traces = []
  proposal_kernel_kwargs.update({
      'target_log_prob_fn': target_log_prob_fn,
      'step_size': tf.fill([n_chains, 1], init_step_size),
      'momentum_distribution': _init_momentum(initial_transformed_position),
  })

  # Initial fast window. Its step size is fed forward to the next window, and
  # its running variance seeds the first slow window.
  draws, trace, step_size, running_variance = _fast_window(
      kind=kind,
      proposal_kernel_kwargs=proposal_kernel_kwargs,
      dual_averaging_kwargs=dual_averaging_kwargs,
      num_draws=first_window_size,
      initial_position=initial_transformed_position,
      bijector=bijector,
      trace_fn=trace_fn,
      seed=init_seed)
  proposal_kernel_kwargs.update({'step_size': step_size})
  all_draws.append(draws)
  all_traces.append(trace)

  # Slow windows: sizes double each time. Each window updates both the step
  # size and the momentum distribution, and starts from the previous window's
  # last draw.
  *slow_seeds, seed = samplers.split_seed(seed, n=5)
  for idx, slow_seed in enumerate(slow_seeds):
    window_size = slow_window_size * (2**idx)
    # TODO(b/180011931): if num_adaptation_steps is small, this throws an error.
    (draws, trace, step_size, running_variance,
     momentum_distribution) = _slow_window(
         kind=kind,
         proposal_kernel_kwargs=proposal_kernel_kwargs,
         dual_averaging_kwargs=dual_averaging_kwargs,
         num_draws=window_size,
         initial_position=draws[-1],
         initial_running_variance=running_variance,
         bijector=bijector,
         trace_fn=trace_fn,
         seed=slow_seed)
    all_draws.append(draws)
    all_traces.append(trace)
    proposal_kernel_kwargs.update({
        'step_size': step_size,
        'momentum_distribution': momentum_distribution,
    })

  # Final fast window: only the step size is updated again. Its running
  # variance is not used afterwards.
  fast_seed, sample_seed = samplers.split_seed(seed)
  draws, trace, step_size, _ = _fast_window(
      kind=kind,
      proposal_kernel_kwargs=proposal_kernel_kwargs,
      dual_averaging_kwargs=dual_averaging_kwargs,
      num_draws=last_window_size,
      initial_position=draws[-1],
      bijector=bijector,
      trace_fn=trace_fn,
      seed=fast_seed)
  proposal_kernel_kwargs.update({'step_size': step_size})
  all_draws.append(draws)
  all_traces.append(trace)

  ret = _do_sampling(
      kind=kind,
      proposal_kernel_kwargs=proposal_kernel_kwargs,
      num_draws=n_draws,
      initial_position=draws[-1],
      bijector=bijector,
      trace_fn=trace_fn,
      return_final_kernel_results=return_final_kernel_results,
      seed=sample_seed)

  if return_final_kernel_results:
    draws, trace, final_kernel_results = ret
  else:
    draws, trace = ret
    final_kernel_results = None

  if discard_tuning:
    states = bijector.inverse(draws)
  else:
    all_draws.append(draws)
    all_traces.append(trace)
    states = bijector.inverse(tf.concat(all_draws, axis=0))
    # With tracing disabled every trace is `()`, so this maps over nothing.
    trace = tf.nest.map_structure(lambda *s: tf.concat(s, axis=0),
                                  *all_traces,
                                  expand_composites=True)

  if return_final_kernel_results:
    return sample.CheckpointableStatesAndTrace(
        all_states=states,
        trace=trace,
        final_kernel_results=final_kernel_results)
  if no_trace:
    return states
  return sample.StatesAndTrace(all_states=states, trace=trace)
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Runs one NUTS transition by iterative tree doubling.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s holding the current
      chain state. Other nested structures (e.g. dicts) are rejected.
    previous_kernel_results: kernel results from the previous step. The fields
      read here are `target_log_prob`, `grads_target_log_prob`, `step_size`
      and `momentum_state_memory`.
    seed: PRNG seed; sanitized and stored in the returned kernel results.

  Returns:
    result_state: the selected candidate state, re-packed into the structure
      of `current_state`.
    kernel_results: a `NUTSKernelResults` describing this transition.

  Raises:
    TypeError: if `current_state` is nested but not list-like, or its
      flattened length differs from its top-level length.
  """
  seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
  start_trajectory_seed, loop_seed = samplers.split_seed(seed)

  with tf.name_scope(self.name + '.one_step'):
    state_structure = current_state
    current_state = tf.nest.flatten(current_state)
    # Only flat, list-like state structures are supported; anything whose
    # flattening changes its length is deeper than one level.
    if (tf.nest.is_nested(state_structure)
        and (not mcmc_util.is_list_like(state_structure)
             or len(current_state) != len(state_structure))):
      # TODO(b/170865194): Support dictionaries and other non-list-like state.
      raise TypeError('NUTS does not currently support nested or '
                      'non-list-like state structures (saw: {}).'.format(
                          state_structure))

    current_target_log_prob = previous_kernel_results.target_log_prob
    [
        init_momentum,
        init_energy,
        log_slice_sample
    ] = self._start_trajectory_batched(current_state, current_target_log_prob,
                                       seed=start_trajectory_seed)

    def _copy(v):
      # Broadcasts `v` against ones of shape [2, 1, ..., 1], producing two
      # copies stacked along a new leading axis of size 2 (presumably one per
      # trajectory end -- confirm in `_loop_tree_doubling`).
      return v * ps.ones(
          ps.pad(
              [2], paddings=[[0, ps.rank(v)]], constant_values=1),
          dtype=v.dtype)

    initial_state = TreeDoublingState(
        momentum=init_momentum,
        state=current_state,
        target=current_target_log_prob,
        target_grad_parts=previous_kernel_results.grads_target_log_prob)
    initial_step_state = tf.nest.map_structure(_copy, initial_state)

    # Initial candidate weight: 0 in the multinomial (log-space) case,
    # otherwise an integer count of 1.
    if MULTINOMIAL_SAMPLE:
      init_weight = tf.zeros_like(init_energy)  # log(exp(H0 - H0))
    else:
      init_weight = tf.ones_like(init_energy, dtype=TREE_COUNT_DTYPE)

    candidate_state = TreeDoublingStateCandidate(
        state=current_state,
        target=current_target_log_prob,
        target_grad_parts=previous_kernel_results.grads_target_log_prob,
        energy=init_energy,
        weight=init_weight)
    initial_step_metastate = TreeDoublingMetaState(
        candidate_state=candidate_state,
        is_accepted=tf.zeros_like(init_energy, dtype=tf.bool),
        momentum_sum=init_momentum,
        energy_diff_sum=tf.zeros_like(init_energy),
        leapfrog_count=tf.zeros_like(init_energy, dtype=TREE_COUNT_DTYPE),
        continue_tree=tf.ones_like(init_energy, dtype=tf.bool),
        not_divergence=tf.ones_like(init_energy, dtype=tf.bool))

    # Convert the write/read instruction into TensorArray so that it is
    # compatible with XLA.
    write_instruction = tf.TensorArray(
        TREE_COUNT_DTYPE,
        size=len(self._write_instruction),
        clear_after_read=False).unstack(self._write_instruction)
    read_instruction = tf.TensorArray(
        tf.int32,
        size=len(self._read_instruction),
        clear_after_read=False).unstack(self._read_instruction)

    current_step_meta_info = OneStepMetaInfo(
        log_slice_sample=log_slice_sample,
        init_energy=init_energy,
        write_instruction=write_instruction,
        read_instruction=read_instruction
        )

    # Keep doubling the tree while the depth limit has not been reached and
    # at least one chain still wants to continue.
    _, _, _, new_step_metastate = tf.while_loop(
        cond=lambda iter_, seed, state, metastate: (  # pylint: disable=g-long-lambda
            (iter_ < self.max_tree_depth) &
            tf.reduce_any(metastate.continue_tree)),
        body=lambda iter_, seed, state, metastate: self._loop_tree_doubling(  # pylint: disable=g-long-lambda
            previous_kernel_results.step_size,
            previous_kernel_results.momentum_state_memory,
            current_step_meta_info,
            iter_,
            state,
            metastate,
            seed),
        loop_vars=(
            tf.zeros([], dtype=tf.int32, name='iter'),
            loop_seed,
            initial_step_state,
            initial_step_metastate),
        parallel_iterations=self.parallel_iterations,
    )

    kernel_results = NUTSKernelResults(
        target_log_prob=new_step_metastate.candidate_state.target,
        grads_target_log_prob=(
            new_step_metastate.candidate_state.target_grad_parts),
        momentum_state_memory=previous_kernel_results.momentum_state_memory,
        step_size=previous_kernel_results.step_size,
        # Log of `energy_diff_sum` averaged over the leapfrog steps taken
        # (`energy_diff_sum` is presumably a sum of per-step acceptance
        # probabilities -- confirm in `_loop_tree_doubling`).
        log_accept_ratio=tf.math.log(
            new_step_metastate.energy_diff_sum /
            tf.cast(new_step_metastate.leapfrog_count,
                    dtype=new_step_metastate.energy_diff_sum.dtype)),
        leapfrogs_taken=(
            new_step_metastate.leapfrog_count * self.unrolled_leapfrog_steps
        ),
        is_accepted=new_step_metastate.is_accepted,
        # A chain whose `continue_tree` is still True after the loop exits was
        # stopped by the `max_tree_depth` limit.
        reach_max_depth=new_step_metastate.continue_tree,
        has_divergence=~new_step_metastate.not_divergence,
        energy=new_step_metastate.candidate_state.energy,
        seed=seed,
    )

    result_state = tf.nest.pack_sequence_as(
        state_structure, new_step_metastate.candidate_state.state)
    return result_state, kernel_results
def _sample_n(self, n, seed, **kwargs):
  """Draws `n` samples using a seed specialized to this shard's axis index.

  The seed is sanitized (salted with 'sharded_sample') and then folded with
  the axis index for `self.experimental_shard_axis_names` via
  `distribute_lib.fold_in_axis_index`.

  Args:
    n: Sample shape forwarded to `self.distribution.sample`.
    seed: PRNG seed; see `samplers.sanitize_seed`.
    **kwargs: Extra keyword arguments forwarded to
      `self.distribution.sample`.

  Returns:
    Samples from `self.distribution`.
  """
  salted_seed = samplers.sanitize_seed(seed, salt='sharded_sample')
  shard_seed = distribute_lib.fold_in_axis_index(
      salted_seed, self.experimental_shard_axis_names)
  return self.distribution.sample(sample_shape=n, seed=shard_seed, **kwargs)
def one_step(self, state, kernel_results, seed=None):
  """Takes one Sequential Monte Carlo inference step.

  Args:
    state: instance of `tfp.experimental.mcmc.WeightedParticles` representing
      the current particles with (log) weights. The `log_weights` must be
      a float `Tensor` of shape `[num_particles, b1, ..., bN]`. The
      `particles` may be any structure of `Tensor`s, each of which
      must have shape `concat([log_weights.shape, event_shape])` for some
      `event_shape`, which may vary across components.
    kernel_results: instance of
      `tfp.experimental.mcmc.SequentialMonteCarloResults` representing results
      from a previous step.
    seed: Optional seed for reproducible sampling.

  Returns:
    state: instance of `tfp.experimental.mcmc.WeightedParticles` representing
      new particles with (log) weights.
    kernel_results: instance of
      `tfp.experimental.mcmc.SequentialMonteCarloResults`.
  """
  with tf.name_scope(self.name):
    with tf.name_scope('one_step'):
      seed = samplers.sanitize_seed(seed)
      proposal_seed, resample_seed = samplers.split_seed(seed)

      state = WeightedParticles(*state)  # Canonicalize.
      num_particles = ps.size0(state.log_weights)

      # Propose new particles and update weights for this step, unless it's
      # the initial step, in which case, use the user-provided initial
      # particles and weights.
      proposed_state = self.propose_and_update_log_weights_fn(
          # Propose state[t] from state[t - 1].
          ps.maximum(0, kernel_results.steps - 1),
          state,
          seed=proposal_seed)
      is_initial_step = ps.equal(kernel_results.steps, 0)
      # TODO(davmre): this `where` assumes the state size didn't change.
      state = tf.nest.map_structure(
          lambda a, b: tf.where(is_initial_step, a, b), state, proposed_state)

      normalized_log_weights = tf.nn.log_softmax(state.log_weights, axis=0)
      # The normalizing constant that `log_softmax` removes is the
      # log-sum-exp of the weights over the particle axis. Computing it
      # directly, rather than as `log_weights[0] - normalized_log_weights[0]`,
      # avoids a NaN (`-inf - -inf`) whenever the first particle has zero
      # weight.
      incremental_log_marginal_likelihood = tf.reduce_logsumexp(
          state.log_weights, axis=0)

      do_resample = self.resample_criterion_fn(state)
      # Some batch elements may require resampling and others not, so
      # we first do the resampling for all elements, then select whether to
      # use the resampled values for each batch element according to
      # `do_resample`. If there were no batching, we might prefer to use
      # `tf.cond` to avoid the resampling computation on steps where it's not
      # needed---but we're ultimately interested in adaptive resampling
      # for statistical (not computational) purposes, so this isn't a
      # dealbreaker.
      resampled_particles, resample_indices = weighted_resampling.resample(
          state.particles, state.log_weights, self.resample_fn,
          seed=resample_seed)
      # After resampling, every particle carries weight 1 / num_particles.
      uniform_weights = tf.fill(
          ps.shape(state.log_weights),
          value=-tf.math.log(tf.cast(num_particles, state.log_weights.dtype)))
      (resampled_particles,
       resample_indices,
       log_weights) = tf.nest.map_structure(
           lambda r, p: ps.where(do_resample, r, p),
           (resampled_particles, resample_indices, uniform_weights),
           (state.particles, _dummy_indices_like(resample_indices),
            normalized_log_weights))

      return (WeightedParticles(particles=resampled_particles,
                                log_weights=log_weights),
              SequentialMonteCarloResults(
                  steps=kernel_results.steps + 1,
                  parent_indices=resample_indices,
                  incremental_log_marginal_likelihood=(
                      incremental_log_marginal_likelihood),
                  accumulated_log_marginal_likelihood=(
                      kernel_results.accumulated_log_marginal_likelihood +
                      incremental_log_marginal_likelihood),
                  seed=seed))
def test_sanitize_none(self):
  """Two unseeded `sanitize_seed` calls must not return equal seeds."""
  first, second = (samplers.sanitize_seed(seed=None) for _ in range(2))
  self.assertNotAllEqual(first, second)