def test_seed_stream(salt='Salt of the Earth', hardcoded_seed=None):
  """Returns a command-line-controllable `SeedStream` PRNG for unit tests.

  Seeding a unit-test PRNG has to serve three needs:

    - Most of the time the seed should be pinned to an arbitrary value, so the
      test does not flake even when its failure probability is noticeable.
    - When measuring the failure probability with `--runs_per_test`, each run
      should use a different seed.
    - When reproducing a low-probability event (e.g., debugging a crash that
      only some seeds trigger), the seed should be settable to a given value.

  To those ends, the returned `SeedStream` is seeded with
  `test_seed(hardcoded_seed)` (which see). `test_seed` respects the command
  line flags `--fixed_seed=<seed>` and `--vary_seed` (Boolean, default False).
  `--vary_seed` uses system entropy to produce unpredictable seeds, and
  `--fixed_seed` takes precedence over `--vary_seed` when both are present.

  Note that TensorFlow graph mode operations tend to read seed state from two
  sources: a "graph-level seed" and an "op-level seed". `test_util.TestCase`
  will set the former to a fixed value per test, but in general it may be
  necessary to explicitly set both to ensure reproducibility.

  Args:
    salt: Optional string wherewith to salt the returned `SeedStream`. Setting
      this guarantees independent random numbers across tests.
    hardcoded_seed: Optional Python value. The seed to use if both the
      `--vary_seed` and `--fixed_seed` flags are unset. This should usually be
      unnecessary, since a test should pass with any seed.

  Returns:
    strm: A `SeedStream` instance seeded with 17, unless otherwise specified by
      arguments or command line flags.
  """
  resolved_seed = test_seed(hardcoded_seed)
  return SeedStream(resolved_seed, salt=salt)
def _flat_sample_distributions(self, sample_shape=(), seed=None, value=None):
  """Builds each distribution in order and samples any unspecified values.

  This function additionally depends on:
    self._dist_fn_wrapped
    self._dist_fn_args
    self._always_use_specified_sample_shape

  Args:
    sample_shape: Shape passed to `sample` for distributions whose `args`
      entry is empty, or for all distributions when
      `self._always_use_specified_sample_shape` is set. Otherwise `()` is
      used.
    seed: Seed forwarded to `SeedStream`.
    value: Optional sequence with one entry per distribution. Entries that
      are `None` are sampled; the rest are converted to tensors.

  Returns:
    ds: List of the distributions that were built.
    xs: List of the corresponding values (sampled or supplied).

  Raises:
    ValueError: if `value` does not have one entry per distribution.
  """
  seed_stream = SeedStream(seed, salt='JointDistributionSequential')
  num_dists = len(self._dist_fn_wrapped)
  xs = [None] * num_dists if value is None else list(value)
  if len(xs) != num_dists:
    raise ValueError('Number of `xs`s must match number of '
                     'distributions.')

  ds = []
  for i, (dist_fn, args) in enumerate(
      zip(self._dist_fn_wrapped, self._dist_fn_args)):
    dist = dist_fn(*xs[:i])  # Chain rule of probability.
    ds.append(dist)

    if xs[i] is not None:
      xs[i] = nest.map_structure_up_to(
          dist.dtype,
          lambda x, dtype: tf.convert_to_tensor(x, dtype_hint=dtype),
          xs[i],
          dist.dtype)
      # Ensure reproducibility even when xs are (partially) set.
      seed_stream()
      continue

    # TODO(b/129364796): We should ignore args prefixed with `_`; this
    # would mean we more often identify when to use `sample_shape=()`
    # rather than `sample_shape=sample_shape`.
    if args and not self._always_use_specified_sample_shape:
      shape_for_dist = ()
    else:
      shape_for_dist = sample_shape
    xs[i] = dist.sample(shape_for_dist, seed=seed_stream())

  # Note: we could also resolve distributions up to the first non-`None` in
  # `self._model_flatten(value)`, however we omit this feature for simplicity,
  # speed, and because it has not yet been requested.
  return ds, xs
def _flat_sample_distributions(self, sample_shape=(), seed=None, value=None):
  """Executes `model`, creating both samples and distributions.

  The coroutine returned by `self._model()` is driven to completion. Each
  yielded distribution is recorded, and either the matching entry of `value`
  or a fresh sample is sent back into the coroutine.

  Args:
    sample_shape: Shape passed to `sample` for distributions wrapped in
      `self.Root`; other distributions are sampled with shape `()`.
    seed: Seed forwarded to `SeedStream`.
    value: Optional sequence of values. An entry that exists and is not `None`
      is used in place of a sample for the distribution at the same index.

  Returns:
    ds: List of the distributions yielded by the coroutine (unwrapped from
      `Root` where applicable).
    values_out: List of the values sent back to the coroutine.

  Raises:
    ValueError: if the first yielded object is not wrapped in `Root`.
  """
  ds = []
  values_out = []
  # `SeedStream` takes the seed first and the salt second. Passing them in the
  # opposite order would make the salt string act as the seed and the caller's
  # seed act as the salt.
  seed = SeedStream(seed, salt='JointDistributionCoroutine')
  gen = self._model()
  index = 0
  d = next(gen)
  if not isinstance(d, self.Root):
    raise ValueError('First distribution yielded by coroutine must '
                     'be wrapped in `Root`.')
  try:
    while True:
      actual_distribution = d.distribution if isinstance(d, self.Root) else d
      ds.append(actual_distribution)
      if (value is not None and len(value) > index and
          value[index] is not None):
        # Advance the stream anyway so later draws get the same seeds whether
        # or not this value was supplied.
        seed()
        next_value = value[index]
      else:
        next_value = actual_distribution.sample(
            sample_shape=sample_shape if isinstance(d, self.Root) else (),
            seed=seed())

      if self._validate_args:
        with tf.control_dependencies(
            self._assert_compatible_shape(
                index, sample_shape, next_value)):
          values_out.append(tf.nest.map_structure(tf.identity, next_value))
      else:
        values_out.append(next_value)

      index += 1
      d = gen.send(next_value)
  except StopIteration:
    # The coroutine has no more distributions to yield.
    pass
  return ds, values_out
def _sample_n(self, n, seed=None):
  """Draws `n` samples using log-Gamma draws as multinomial logits.

  Args:
    n: Number of samples to draw.
    seed: Seed forwarded to `SeedStream`.

  Returns:
    The result of `multinomial.draw_sample`, reshaped to
    `[n] + batch_shape + [k]`.
  """
  seed_stream = SeedStream(seed, 'dirichlet_multinomial')

  concentration = tf.convert_to_tensor(self._concentration)
  total_count = tf.convert_to_tensor(self._total_count)
  num_draws = tf.cast(total_count, dtype=tf.int32)
  num_classes = self._event_shape_tensor(concentration)[0]

  # Multiply by ones shaped like `total_count` (plus a trailing axis) so that
  # `alpha` broadcasts against both parameters.
  ones_like_count = tf.ones_like(total_count[..., tf.newaxis])
  alpha = tf.math.multiply(ones_like_count, concentration, name='alpha')

  gamma_draws = tf.random.gamma(
      shape=[n], alpha=alpha, dtype=self.dtype, seed=seed_stream())
  unnormalized_logits = tf.math.log(gamma_draws)

  counts = multinomial.draw_sample(
      1, num_classes, unnormalized_logits, num_draws, self.dtype,
      seed_stream())

  batch_shape = self._batch_shape_tensor(concentration, total_count)
  final_shape = tf.concat([[n], batch_shape, [num_classes]], 0)
  return tf.reshape(counts, final_shape)
def make_transform_hmc_kernel_fn(
    target_log_prob_fn,
    init_state,
    scalings,
    seed=None):
  """Generate a transform hmc kernel.

  NOTE(review): `unconstraining_bijectors` and `num_leapfrog_steps` are not
  parameters of this function; presumably they come from an enclosing scope.

  Args:
    target_log_prob_fn: Target log-density callable handed to the HMC kernel.
    init_state: Iterable of state parts, zipped with
      `unconstraining_bijectors`.
    scalings: Passed to `compute_hmc_step_size`.
    seed: Seed forwarded to `SeedStream`; the resulting stream is handed to
      the HMC kernel.

  Returns:
    A `TransformedTransitionKernel` wrapping a `HamiltonianMonteCarlo` kernel.
  """
  with tf.name_scope('make_transformed_hmc_kernel_fn'):
    seed_stream = SeedStream(seed, salt='make_transformed_hmc_kernel_fn')

    # TransformedTransitionKernel doesn't modify the input step size, thus we
    # need to pass the appropriate step size that are already in unconstrained
    # space
    state_std = []
    for state_part, bijector in zip(init_state, unconstraining_bijectors):
      unconstrained_part = bijector.inverse(state_part)
      state_std.append(
          tf.math.reduce_std(unconstrained_part, axis=0, keepdims=True))

    step_size = compute_hmc_step_size(scalings, state_std, num_leapfrog_steps)
    inner_kernel = hmc.HamiltonianMonteCarlo(
        target_log_prob_fn=target_log_prob_fn,
        num_leapfrog_steps=num_leapfrog_steps,
        step_size=step_size,
        seed=seed_stream)
    return transformed_kernel.TransformedTransitionKernel(
        inner_kernel, unconstraining_bijectors)
def _sample_3d(self, n, mean_direction, concentration, seed=None):
  """Specialized inversion sampler for 3D.

  Args:
    n: Number of samples to draw.
    mean_direction: Used only to compute the batch shape.
    concentration: The `kappa` of the inversion formula below.
    seed: Seed forwarded to `SeedStream`.

  Returns:
    The inverted uniform draws `u` with a trailing singleton axis appended.
  """
  seed_stream = SeedStream(seed, salt='von_mises_fisher_3d')
  batch_shape = self._batch_shape_tensor(
      mean_direction=mean_direction, concentration=concentration)
  uniform_shape = tf.concat([[n], batch_shape], axis=0)
  z = tf.random.uniform(uniform_shape, seed=seed_stream(), dtype=self.dtype)

  # TODO(bjp): Higher-order odd dim analytic CDFs are available in [1], could
  # be bisected for bounded sampling runtime (i.e. not rejection sampling).
  # [1]: Inversion sampler via: https://ieeexplore.ieee.org/document/7347705/
  # The inversion is: u = 1 + log(z + (1-z)*exp(-2*kappa)) / kappa
  # We must protect against both kappa and z being zero.
  positive_conc = concentration > 0
  safe_conc = tf.where(
      positive_conc, concentration, tf.ones_like(concentration))
  safe_z = tf.where(z > 0, z, tf.ones_like(z))

  log_terms = [
      tf.math.log(safe_z),
      tf.math.log1p(-safe_z) - 2 * safe_conc,
  ]
  safe_u = 1 + tf.reduce_logsumexp(log_terms, axis=0) / safe_conc

  # Limit of the above expression as kappa->0 is 2*z-1
  u = tf.where(concentration > 0., safe_u, 2 * z - 1)
  # Limit of the expression as z->0 is -1.
  u = tf.where(tf.equal(z, 0), -tf.ones_like(u), u)

  if not self._allow_nan_stats:
    u = tf.debugging.check_numerics(u, 'u in _sample_3d')
  return u[..., tf.newaxis]
def _fn(state_parts, seed):
  """Adds a normal perturbation to the input state.

  Args:
    state_parts: A list of `Tensor`s of any shape and real dtype representing
      the state parts of the `current_state` of the Markov chain.
    seed: `int` or None. The random seed for this `Op`. If `None`, no seed is
      applied. This argument is required; it has no default.

  Returns:
    perturbed_state_parts: A Python `list` of The `Tensor`s. Has the same
      shape and type as the `state_parts`.

  Raises:
    ValueError: if `scale` does not broadcast with `state_parts`.
  """
  with tf.compat.v1.name_scope(name, 'random_walk_normal_fn',
                               values=[state_parts, scale, seed]):
    scales = scale if mcmc_util.is_list_like(scale) else [scale]
    if len(scales) == 1:
      # Build a new sequence instead of using `*=`. When `scale` is a
      # one-element list, `scales` is that same object, and an in-place
      # repeat would permanently lengthen the enclosing `scale`.
      scales = scales * len(state_parts)
    if len(state_parts) != len(scales):
      raise ValueError('`scale` must broadcast with `state_parts`.')
    seed_stream = SeedStream(seed, salt='RandomWalkNormalFn')
    # One draw per state part, each with its own seed from the stream.
    next_state_parts = [
        tf.random.normal(mean=state_part,
                         stddev=scale_part,
                         shape=tf.shape(input=state_part),
                         dtype=state_part.dtype.base_dtype,
                         seed=seed_stream())
        for scale_part, state_part in zip(scales, state_parts)
    ]

    return next_state_parts
def _sample_n(self, n, seed=None):
  """Draws `n` samples as `|Normal| / sqrt(Chi2 / df)`, scaled and shifted.

  Args:
    n: Number of samples to draw.
    seed: Seed forwarded to `SeedStream`.

  Returns:
    The samples, multiplied by `scale` and offset by `loc`.
  """
  # The sampling method comes from the fact that if:
  #   X ~ Normal(0, 1)
  #   Z ~ Chi2(df)
  #   Y = |X| / sqrt(Z / df)
  # then:
  #   Y ~ HalfStudentT(df).
  seed_stream = SeedStream(seed, "half_student_t")

  df = tf.convert_to_tensor(self.df)
  loc = tf.convert_to_tensor(self.loc)
  scale = tf.convert_to_tensor(self.scale)
  batch_shape = self._batch_shape_tensor(df=df, loc=loc, scale=scale)
  sample_shape = tf.concat([[n], batch_shape], 0)

  normal_draws = tf.random.normal(
      sample_shape, dtype=self.dtype, seed=seed_stream())
  abs_normal_sample = tf.math.abs(normal_draws)

  # Expand `df` to the full batch shape before using it as the Gamma shape
  # parameter.
  broadcast_df = df * tf.ones(batch_shape, dtype=self.dtype)
  gamma_sample = tf.random.gamma(
      [n], 0.5 * broadcast_df, beta=0.5, dtype=self.dtype,
      seed=seed_stream())

  half_t = abs_normal_sample * tf.math.rsqrt(gamma_sample / broadcast_df)
  return half_t * scale + loc  # Abs(scale) not wanted.
def __init__(self, inner_kernel, seed=None, name=None):
  """Instantiates this object.

  Warns when the wrapped kernel reports that it is already calibrated, then
  records a seed stream and the constructor arguments.

  Args:
    inner_kernel: `TransitionKernel`-like object which has
      `collections.namedtuple` `kernel_results` and which contains a
      `target_log_prob` member and optionally a `log_acceptance_correction`
      member.
    seed: Python integer to seed the random number generator.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., "mh_kernel").
  """
  already_calibrated = inner_kernel.is_calibrated
  if already_calibrated:
    warnings.warn('Supplied `TransitionKernel` is already calibrated. '
                  'Composing `MetropolisHastings` `TransitionKernel` '
                  'may not be required.')

  self._seed_stream = SeedStream(seed, 'metropolis_hastings_one_step')
  self._parameters = {
      'inner_kernel': inner_kernel,
      'seed': seed,
      'name': name,
  }
def _sample_n(self, n, seed=None):
  """Draws `n` samples as a random sign times the norm of three Gaussians.

  Args:
    n: Number of samples to draw.
    seed: Seed forwarded to `SeedStream`.

  Returns:
    The signed norms, multiplied by `scale` and offset by `loc`.
  """
  # Generate samples using:
  # mu + sigma* sgn(U-0.5)* sqrt(X^2 + Y^2 + Z^2) U~Unif; X,Y,Z ~N(0,1)
  seed_stream = SeedStream(seed, salt='DoublesidedMaxwell')

  loc = tf.convert_to_tensor(self.loc)
  scale = tf.convert_to_tensor(self.scale)

  # Prepend `n` to the batch shape.
  batch_shape = self._batch_shape_tensor(loc=loc, scale=scale)
  sample_shape = prefer_static.pad(
      batch_shape, paddings=[[1, 0]], constant_values=n)

  # Generate one-sided Maxwell variables by using 3 Gaussian variates
  gaussian_shape = prefer_static.pad(
      sample_shape, paddings=[[0, 1]], constant_values=3)
  gaussians = tf.random.normal(
      shape=gaussian_shape, dtype=self.dtype, seed=seed_stream())
  maxwell_rvs = tf.norm(gaussians, axis=-1)

  # Generate random signs for the symmetric variates.
  random_sign = tfp_math.random_rademacher(sample_shape, seed=seed_stream())

  return random_sign * maxwell_rvs * scale + loc
def _sample_n(self, n, seed=None):
  """Draws `n` samples by inverting the piecewise-quadratic CDF.

  Args:
    n: Number of samples to draw.
    seed: Seed forwarded to `SeedStream`.

  Returns:
    Uniform draws mapped through the inverse CDF.
  """
  low = tf.convert_to_tensor(self.low)
  high = tf.convert_to_tensor(self.high)
  peak = tf.convert_to_tensor(self.peak)

  seed_stream = SeedStream(seed, salt='triangular')
  batch_shape = self._batch_shape_tensor(low=low, high=high, peak=peak)
  shape = tf.concat([[n], batch_shape], axis=0)
  uniform_draws = tf.random.uniform(
      shape=shape, dtype=self.dtype, seed=seed_stream())

  # We use Inverse CDF sampling here. Because the CDF is a quadratic function,
  # we must use sqrts here.
  interval_length = high - low

  # Note the CDF on the left side of the peak is
  # (x - low) ** 2 / ((high - low) * (peak - low)).
  # If we plug in peak for x, we get that the CDF at the peak
  # is (peak - low) / (high - low). Because of this we decide
  # which part of the piecewise CDF we should use based on the cdf samples
  # we drew.
  left_of_peak = uniform_draws < (peak - low) / interval_length

  # Inverse of (x - low) ** 2 / ((high - low) * (peak - low)).
  left_branch = low + tf.sqrt(
      uniform_draws * interval_length * (peak - low))
  # Inverse of 1 - (high - x) ** 2 / ((high - low) * (high - peak))
  right_branch = high - tf.sqrt(
      (1. - uniform_draws) * interval_length * (high - peak))

  return tf.where(left_of_peak, left_branch, right_branch)
def __init__(
    self,
    posterior,
    prior,
    penalty_weight=None,
    posterior_penalty_fn=kl_divergence_monte_carlo,
    posterior_value_fn=tfd.Distribution.sample,
    seed=None,
    dtype=tf.float32,
    name=None):
  """Base class for variational layers.

  # mean ==> penalty_weight =          1 / train_size
  # sum  ==> penalty_weight = batch_size / train_size

  Args:
    posterior: Stored on the layer; its `dtype` structure must match that of
      `prior`.
    prior: Stored on the layer; its `dtype` structure must match that of
      `posterior`.
    penalty_weight: Stored on the layer (see the note above).
    posterior_penalty_fn: Stored on the layer.
      Default value: `kl_divergence_monte_carlo`.
    posterior_value_fn: Stored on the layer.
      Default value: `tfd.Distribution.sample`.
    seed: Seed for the layer's `SeedStream`, which is salted with the layer
      name.
    dtype: Stored on the layer.
      Default value: `tf.float32`.
    name: Python `str` prepended to ops created by this object.
      Default value: `None` (i.e., `type(self).__name__`).
  """
  super(VariationalLayer, self).__init__(name=name)

  self._posterior = posterior
  self._prior = prior
  self._penalty_weight = penalty_weight
  self._posterior_penalty_fn = posterior_penalty_fn
  self._posterior_value_fn = posterior_value_fn

  # `self.name` is read only after the base initializer has run.
  layer_salt = self.name
  self._seed = SeedStream(seed, salt=layer_salt)
  self._dtype = dtype

  # Only the nesting of the two dtypes has to agree; container types may
  # differ.
  tf.nest.assert_same_structure(
      prior.dtype, posterior.dtype, check_types=False)
def sample_sequential_monte_carlo(
    prior_log_prob_fn,
    likelihood_log_prob_fn,
    current_state,
    max_num_steps=25,
    max_stage=100,
    make_kernel_fn=make_rwmh_kernel_fn,
    tuning_fn=simple_heuristic_tuning,
    make_tempered_target_log_prob_fn=default_make_tempered_target_log_prob_fn,
    ess_threshold_ratio=0.5,
    parallel_iterations=10,
    seed=None,
    name=None):
  """Runs Sequential Monte Carlo to sample from the posterior distribution.

  This function uses an MCMC transition operator (e.g., Hamiltonian Monte
  Carlo) to sample from a series of distributions that slowly interpolates
  between an initial 'prior' distribution:

    `exp(prior_log_prob_fn(x))`

  and the target 'posterior' distribution:

    `exp(prior_log_prob_fn(x) + likelihood_log_prob_fn(x))`,

  by mutating a collection of MC samples (i.e., particles). The approach is
  also known as Particle Filter in some literature. The current implementation
  is largely based on Del Moral et al [1], which adapts the tempering sequence
  adaptively (based on the effective sample size) and the scaling of the
  mutation kernel (based on the sample covariance of the particles) at each
  stage.

  Args:
    prior_log_prob_fn: Python callable that returns the log density of the
      prior distribution.
    likelihood_log_prob_fn: Python callable which takes an argument like
      `current_state` (or `*current_state` if it's a list) and returns its
      (possibly unnormalized) log-density under the likelihood distribution.
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s). The first `r` dimensions index
      independent chains,
      `r = tf.rank(likelihood_log_prob_fn(*current_state))`.
    max_num_steps: The maximum number of kernel transition steps in one
      mutation of the MC samples. Note that the actual number of steps in one
      mutation is tuned during sampling and likely lower than the
      max_num_step.
    max_stage: Integer upper bound on the number of stages used to increase
      the inverse temperature from 0 to 1. Sampling stops once this many
      stages have run or once no inverse temperature is below 1.
    make_kernel_fn: Python `callable` which returns a `TransitionKernel`-like
      object. It is called as
      `make_kernel_fn(target_log_prob_fn, current_state, scalings, seed=...)`,
      where `target_log_prob_fn` is the tempered target produced by
      `make_tempered_target_log_prob_fn`.
    tuning_fn: Python `callable` which takes the number of steps, the log
      scaling, and the log acceptance ratio from the last mutation and output
      the number of steps and log scaling for the next mutation.
    make_tempered_target_log_prob_fn: Python `callable` that takes the
      `prior_log_prob_fn`, `likelihood_log_prob_fn`, and
      `inverse_temperatures` and creates a `target_log_prob_fn` `callable`
      that pass to `make_kernel_fn`.
    ess_threshold_ratio: Target ratio for effective sample size.
    parallel_iterations: The number of iterations allowed to run in parallel.
      It must be a positive integer. See `tf.while_loop` for more details.
    seed: Python integer or TFP seedstream to seed the random number
      generator.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'sample_sequential_monte_carlo').

  Returns:
    n_stage: Number of the mutation stage SMC ran.
    final_state: `Tensor` or Python `list` of `Tensor`s representing the
      final state(s) of the Markov chain(s). The output are the posterior
      samples.
    final_kernel_results: `collections.namedtuple` of internal calculations
      used to advance the chain.

  #### References

  [1] Del Moral, Pierre, Arnaud Doucet, and Ajay Jasra. An adaptive sequential
      Monte Carlo method for approximate Bayesian computation.
      _Statistics and Computing_, 22.5(1009-1020), 2012.

  """

  with tf.name_scope(name or 'sample_sequential_monte_carlo'):
    # One stream is shared by every kernel construction and resampling step.
    seed_stream = SeedStream(seed, salt='smc_seed')

    # Work on a list of state parts internally; unwrap again before returning.
    unwrap_state_list = not tf.nest.is_nested(current_state)
    if unwrap_state_list:
      current_state = [current_state]
    current_state = [
        tf.convert_to_tensor(s, dtype_hint=tf.float32) for s in current_state
    ]

    # Initial preprocessing at Stage 0
    likelihood_log_prob = likelihood_log_prob_fn(*current_state)

    # Total size of the state dimensions that follow the likelihood's axes.
    likelihood_rank = ps.rank(likelihood_log_prob)
    dimension = ps.reduce_sum([
        ps.reduce_prod(ps.shape(x)[likelihood_rank:]) for x in current_state
    ])

    # We infer the particle shapes from the resulting likelihood:
    # [num_particles, b1, ..., bN]
    particle_shape = ps.shape(likelihood_log_prob)
    num_particles, batch_shape = particle_shape[0], particle_shape[1:]
    effective_sample_size_threshold = tf.cast(
        num_particles * ess_threshold_ratio, tf.int32)

    # TODO(b/152412213): Revisit this default parameter.
    # Default to the optimal scaling of a random walk kernel for a d-dimensional
    # normal distributed targets: 2.38 ** 2 / d.
    # For more detail see:
    # Roberts GO, Gelman A, Gilks WR. Weak convergence and optimal scaling of
    # random walk Metropolis algorithms. _The annals of applied probability_.
    # 1997;7(1):110-20.
    scale_start = (tf.constant(2.38**2, dtype=likelihood_log_prob.dtype) /
                   tf.constant(dimension, dtype=likelihood_log_prob.dtype))

    # Start at inverse temperature 0; cap the initial scaling at 1.
    inverse_temperature = tf.zeros(batch_shape, dtype=likelihood_log_prob.dtype)
    scalings = ps.ones_like(likelihood_log_prob) * ps.minimum(
        scale_start, 1.)
    kernel = make_kernel_fn(make_tempered_target_log_prob_fn(
        prior_log_prob_fn, likelihood_log_prob_fn, inverse_temperature),
                            current_state, scalings, seed=seed_stream)
    pkr = kernel.bootstrap_results(current_state)
    _, kernel_target_log_prob = gather_mh_like_result(pkr)

    particle_info = ParticleInfo(
        log_accept_prob=ps.zeros_like(likelihood_log_prob),
        log_scalings=tf.math.log(scalings),
        tempered_log_prob=kernel_target_log_prob,
        likelihood_log_prob=likelihood_log_prob,
    )

    current_pkr = SMCResults(
        num_steps=tf.convert_to_tensor(
            max_num_steps, dtype=tf.int32, name='num_steps'),
        inverse_temperature=inverse_temperature,
        log_marginal_likelihood=tf.zeros_like(inverse_temperature),
        particle_info=particle_info)

    def update_weights_temperature(inverse_temperature, likelihood_log_prob):
      """Calculate the next inverse temperature and update weights."""
      # Subtract the maximum over axis 0 before forming the weights.
      likelihood_diff = likelihood_log_prob - tf.reduce_max(
          likelihood_log_prob, axis=0)

      def _body_fn(new_beta, upper_beta, lower_beta, eff_size, log_weights):
        """One iteration of the temperature and weight update."""
        # Bisection: try the midpoint, then keep the half-interval selected by
        # comparing the effective sample size against the threshold.
        new_beta = (lower_beta + upper_beta) / 2.0
        log_weights = (new_beta - inverse_temperature) * likelihood_diff
        log_weights_norm = tf.math.log_softmax(log_weights, axis=0)
        eff_size = tf.cast(
            tf.exp(-tf.math.reduce_logsumexp(2 * log_weights_norm, axis=0)),
            tf.int32)
        upper_beta = tf.where(
            eff_size < effective_sample_size_threshold,
            new_beta, upper_beta)
        lower_beta = tf.where(
            eff_size < effective_sample_size_threshold,
            lower_beta, new_beta)
        return new_beta, upper_beta, lower_beta, eff_size, log_weights

      def _cond_fn(new_beta, upper_beta, lower_beta, eff_size, *_):  # pylint: disable=unused-argument
        # TODO(junpenglao): revisit threshold below to be dtype specific.
        threshold = 1e-6
        return (tf.math.reduce_any(upper_beta - lower_beta > threshold) &
                tf.math.reduce_any(
                    eff_size != effective_sample_size_threshold))

      # Search between the current inverse temperature (lower) and 2 (upper).
      (new_beta, upper_beta, lower_beta, eff_size, log_weights) = tf.while_loop(  # pylint: disable=unused-variable
          cond=_cond_fn,
          body=_body_fn,
          loop_vars=(
              tf.zeros_like(inverse_temperature),
              tf.fill(
                  ps.shape(inverse_temperature),
                  tf.constant(2, inverse_temperature.dtype)),
              inverse_temperature,
              tf.zeros_like(inverse_temperature, dtype=tf.int32),
              tf.zeros_like(likelihood_diff)),
          parallel_iterations=parallel_iterations)

      # Where the search landed at or above 1, use the weights for a step
      # straight to inverse temperature 1.
      log_weights = tf.where(new_beta < 1.,
                             log_weights,
                             (1. - inverse_temperature) * likelihood_diff)
      marginal_loglike_ = reduce_logmeanexp(
          (new_beta - inverse_temperature) * likelihood_log_prob, axis=0)
      new_inverse_temperature = tf.clip_by_value(new_beta, 0., 1.)

      return marginal_loglike_, new_inverse_temperature, log_weights

    def mutate(current_state, log_scalings, num_steps, inverse_temperature):
      """Mutate the state using a Transition kernel."""
      with tf.name_scope('mutate_states'):
        scalings = tf.exp(log_scalings)
        # A new kernel is built for the current inverse temperature.
        kernel = make_kernel_fn(
            make_tempered_target_log_prob_fn(
                prior_log_prob_fn,
                likelihood_log_prob_fn,
                inverse_temperature),
            current_state,
            scalings,
            seed=seed_stream)
        pkr = kernel.bootstrap_results(current_state)
        # Used below only for the shape and dtype of the accumulator.
        kernel_log_accept_ratio, _ = gather_mh_like_result(pkr)

        def mutate_onestep(i, state, pkr, log_accept_prob_sum):
          next_state, next_kernel_results = kernel.one_step(
              state, pkr)
          # NOTE(review): the accept ratio is read from the incoming `pkr`,
          # not from `next_kernel_results` - confirm this is intended.
          kernel_log_accept_ratio, _ = gather_mh_like_result(pkr)
          log_accept_prob = tf.minimum(kernel_log_accept_ratio, 0.)
          log_accept_prob_sum = log_add_exp(
              log_accept_prob_sum, log_accept_prob)
          return i + 1, next_state, next_kernel_results, log_accept_prob_sum

        (
            _,
            next_state,
            next_kernel_results,
            log_accept_prob_sum
        ) = tf.while_loop(
            cond=lambda i, *args: i < num_steps,
            body=mutate_onestep,
            loop_vars=(
                tf.zeros([], dtype=tf.int32),
                current_state,
                pkr,
                # we accumulate the acceptance probability in log space.
                tf.fill(
                    ps.shape(kernel_log_accept_ratio),
                    tf.constant(-np.inf, kernel_log_accept_ratio.dtype))),
            parallel_iterations=parallel_iterations)
        _, kernel_target_log_prob = gather_mh_like_result(
            next_kernel_results)
        # NOTE(review): the divisor is `num_steps + 1` although the loop above
        # runs while `i < num_steps` - confirm the intended normalization.
        avg_log_accept_prob_per_particle = log_accept_prob_sum - tf.math.log(
            tf.cast(num_steps + 1, log_accept_prob_sum.dtype))

        return (next_state,
                avg_log_accept_prob_per_particle,
                kernel_target_log_prob)

    # One SMC step.
    def smc_body_fn(stage, state, smc_kernel_result):
      """Run one stage of SMC with constant temperature."""
      (
          new_marginal,
          new_inv_temperature,
          log_weights
      ) = update_weights_temperature(
          smc_kernel_result.inverse_temperature,
          smc_kernel_result.particle_info.likelihood_log_prob)
      # TODO(b/152412213) Use a tf.scan to better collect debug info.
      if PRINT_DEBUG:
        tf.print(
            'Stage:', stage,
            'Beta:', new_inv_temperature,
            'n_steps:', smc_kernel_result.num_steps,
            'accept:', tf.exp(reduce_logmeanexp(
                smc_kernel_result.particle_info.log_accept_prob, axis=0)),
            'scaling:', tf.exp(reduce_logmeanexp(
                smc_kernel_result.particle_info.log_scalings, axis=0)))
      # Resample the state together with its per-particle bookkeeping.
      (resampled_state,
       resampled_particle_info), _ = resample_particle_and_info(
           (state, smc_kernel_result.particle_info), log_weights,
           seed=seed_stream)
      next_num_steps, next_log_scalings = tuning_fn(
          smc_kernel_result.num_steps,
          resampled_particle_info.log_scalings,
          resampled_particle_info.log_accept_prob)
      # Skip tuning at stage 0.
      next_num_steps = tf.where(stage == 0,
                                smc_kernel_result.num_steps,
                                next_num_steps)
      next_log_scalings = tf.where(stage == 0,
                                   resampled_particle_info.log_scalings,
                                   next_log_scalings)
      # Keep the step count between 2 and `max_num_steps`.
      next_num_steps = tf.clip_by_value(next_num_steps, 2, max_num_steps)

      next_state, log_accept_prob, tempered_log_prob = mutate(
          resampled_state, next_log_scalings, next_num_steps,
          new_inv_temperature)

      next_pkr = SMCResults(
          num_steps=next_num_steps,
          inverse_temperature=new_inv_temperature,
          log_marginal_likelihood=(
              new_marginal + smc_kernel_result.log_marginal_likelihood),
          particle_info=ParticleInfo(
              log_accept_prob=log_accept_prob,
              log_scalings=next_log_scalings,
              tempered_log_prob=tempered_log_prob,
              likelihood_log_prob=likelihood_log_prob_fn(*next_state),
          ))
      return stage + 1, next_state, next_pkr

    # Iterate stages until `max_stage` is reached or no inverse temperature
    # is below 1.
    (
        n_stage,
        final_state,
        final_kernel_results
    ) = tf.while_loop(
        cond=lambda i, state, pkr: (  # pylint: disable=g-long-lambda
            (i < max_stage) &
            tf.reduce_any(pkr.inverse_temperature < 1.)),
        body=smc_body_fn,
        loop_vars=(
            tf.zeros([], dtype=tf.int32),
            current_state,
            current_pkr),
        parallel_iterations=parallel_iterations)
    if unwrap_state_list:
      final_state = final_state[0]
    return n_stage, final_state, final_kernel_results
def __init__(self,
             target_log_prob_fn,
             step_size,
             volatility_fn=None,
             parallel_iterations=10,
             compute_acceptance=True,
             seed=None,
             name=None):
  """Initializes Langevin diffusion transition kernel.

  Args:
    target_log_prob_fn: Python callable which takes an argument like
      `current_state` (or `*current_state` if it's a list) and returns its
      (possibly unnormalized) log-density under the target distribution.
    step_size: `Tensor` or Python `list` of `Tensor`s representing the step
      size for the leapfrog integrator. Must broadcast with the shape of
      `current_state`. Larger step sizes lead to faster progress, but
      too-large step sizes make rejection exponentially more likely. When
      possible, it's often helpful to match per-variable step sizes to the
      standard deviations of the target distribution in each variable.
    volatility_fn: Python callable which takes an argument like
      `current_state` (or `*current_state` if it's a list) and returns
      volatility value at `current_state`. Should return a `Tensor` or Python
      `list` of `Tensor`s that must broadcast with the shape of
      `current_state`. When `None`, a callable that always returns `1.` is
      used.
    parallel_iterations: the number of coordinates for which the gradients of
      the volatility matrix `volatility_fn` can be computed in parallel.
    compute_acceptance: Python 'bool' indicating whether to compute the
      Metropolis log-acceptance ratio used to construct
      `MetropolisAdjustedLangevinAlgorithm` kernel. Stored as a `Tensor`.
    seed: Python integer to seed the random number generator.
      Default value: `None` (i.e., no seed).
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None`.

  Raises:
    TypeError: if `volatility_fn` is not callable.
  """
  self._seed_stream = SeedStream(seed, salt='UncalibratedLangevin')

  if volatility_fn is None:
    # Default value of `volatility_fn` is the identity function.
    volatility_fn = lambda *args: 1.
  elif not callable(volatility_fn):
    raise TypeError('`volatility_fn` must be callable (saw: {})'.format(
        type(volatility_fn)))

  self._parameters = {
      'target_log_prob_fn': target_log_prob_fn,
      'step_size': step_size,
      'volatility_fn': volatility_fn,
      'compute_acceptance': tf.convert_to_tensor(value=compute_acceptance),
      'seed': seed,
      'parallel_iterations': parallel_iterations,
      'name': name,
  }
def _sample_n(self, n, seed=None):
  """Samples by interpolating between the two affine endpoints.

  A draw from `self.distribution` is pushed through each endpoint affine
  transform, and the two results are blended with weights gathered from
  `self.grid` at indices drawn from `self.mixture_distribution`.

  Args:
    n: Number of samples to draw.
    seed: Seed forwarded to `SeedStream`.

  Returns:
    The blended samples.

  Raises:
    NotImplementedError: if there are not exactly two endpoints.
  """
  seed_stream = SeedStream(seed, salt="VectorDiffeomixture")

  base_sample_shape = concat_vectors(
      [n], self.batch_shape_tensor(), self.event_shape_tensor())
  base_samples = self.distribution.sample(
      sample_shape=base_sample_shape, seed=seed_stream())  # shape: [n, B, e]
  endpoints = [aff.forward(base_samples) for aff in self.endpoint_affine]

  # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
  # ids as a [n]-shaped vector.
  batch_size = tensorshape_util.num_elements(self.batch_shape)
  if batch_size is None:
    batch_size = tf.reduce_prod(self.batch_shape_tensor())
  mix_batch_size = tensorshape_util.num_elements(
      self.mixture_distribution.batch_shape)
  if mix_batch_size is None:
    mix_batch_size = tf.reduce_prod(
        self.mixture_distribution.batch_shape_tensor())

  ids_extra_shape = distribution_util.pick_vector(
      self.is_scalar_batch(),
      np.int32([]),
      [batch_size // mix_batch_size])
  ids = self.mixture_distribution.sample(
      sample_shape=concat_vectors([n], ids_extra_shape),
      seed=seed_stream())

  # We need to flatten batch dims in case mixture_distribution has its own
  # batch dims.
  flat_batch_shape = distribution_util.pick_vector(
      self.is_scalar_batch(),
      np.int32([]),
      np.int32([-1]))
  ids = tf.reshape(ids, shape=concat_vectors([n], flat_batch_shape))

  # Stride `components * quadrature_size` for `batch_size` number of times.
  stride = tensorshape_util.num_elements(
      tensorshape_util.with_rank(self.grid.shape[-2:], rank=2))
  if stride is None:
    stride = tf.reduce_prod(tf.shape(self.grid)[-2:])
  offset = tf.range(
      start=0, limit=batch_size * stride, delta=stride, dtype=ids.dtype)

  flat_grid = tf.reshape(self.grid, shape=[-1])
  weight = tf.gather(flat_grid, ids + offset)

  # At this point, weight flattened all batch dims into one.
  # We also need to append a singleton to broadcast with event dims.
  if tensorshape_util.is_fully_defined(self.batch_shape):
    new_shape = [-1] + tensorshape_util.as_list(self.batch_shape) + [1]
  else:
    new_shape = tf.concat(([-1], self.batch_shape_tensor(), [1]), axis=0)
  weight = tf.reshape(weight, shape=new_shape)

  if len(endpoints) != 2:
    # We actually should have already triggered this exception. However as a
    # policy we're putting this exception wherever we exploit the bimixture
    # assumption.
    raise NotImplementedError(
        "Currently only bimixtures are supported; "
        "len(scale)={} is not 2.".format(len(endpoints)))

  # Alternatively:
  # x = weight * x[0] + (1. - weight) * x[1]
  first, second = endpoints[0], endpoints[1]
  return weight * (first - second) + second
def __init__(self,
             target_log_prob_fn,
             step_size,
             volatility_fn=None,
             seed=None,
             parallel_iterations=10,
             name=None):
  """Initializes MALA transition kernel.

  Builds an `UncalibratedLangevin` kernel and wraps it in
  `MetropolisHastings`.

  Args:
    target_log_prob_fn: Python callable which takes an argument like
      `current_state` (or `*current_state` if it's a list) and returns its
      (possibly unnormalized) log-density under the target distribution.
    step_size: `Tensor` or Python `list` of `Tensor`s representing the step
      size for the leapfrog integrator. Must broadcast with the shape of
      `current_state`. Larger step sizes lead to faster progress, but
      too-large step sizes make rejection exponentially more likely. When
      possible, it's often helpful to match per-variable step sizes to the
      standard deviations of the target distribution in each variable.
    volatility_fn: Python callable which takes an argument like
      `current_state` (or `*current_state` if it's a list) and returns
      volatility value at `current_state`. Should return a `Tensor` or Python
      `list` of `Tensor`s that must broadcast with the shape of
      `current_state`. Defaults to the identity function.
    seed: Python integer to seed the random number generator. Deprecated, pass
      seed to `tfp.mcmc.sample_chain`. When `None`, no seed is forwarded to
      either kernel.
      Default value: `None` (i.e., no seed).
    parallel_iterations: the number of coordinates for which the gradients of
      the volatility matrix `volatility_fn` can be computed in parallel.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'mala_kernel').
  """
  seed_stream = SeedStream(seed, salt='langevin')
  if seed is None:
    mh_kwargs = {}
    uncal_kwargs = {}
  else:
    # Draw order matters: the Metropolis-Hastings seed comes first.
    mh_kwargs = {'seed': seed_stream()}
    uncal_kwargs = {'seed': seed_stream()}

  inner_kernel = UncalibratedLangevin(
      target_log_prob_fn=target_log_prob_fn,
      step_size=step_size,
      volatility_fn=volatility_fn,
      parallel_iterations=parallel_iterations,
      name=name,
      **uncal_kwargs)
  self._impl = metropolis_hastings.MetropolisHastings(
      inner_kernel=inner_kernel, **mh_kwargs)

  # Remove `compute_acceptance` parameter as this is not a MALA kernel
  # `__init__` parameter.
  parameters = self._impl.inner_kernel.parameters.copy()
  parameters.pop('compute_acceptance')
  self._parameters = parameters
def _flat_sample_distributions(self, sample_shape=(), seed=None, value=None):
  """Builds every component distribution in order and samples or pins it.

  Distribution `i` is created by calling `self._dist_fn_wrapped[i]` with the
  values of all earlier components (chain rule of probability). Its value is
  then either drawn with `sample` (when no value was supplied for it) or taken
  from `value` and converted to a tensor.

  Args:
    sample_shape: Sample shape used for components whose `_dist_fn_args` entry
      is empty, or for all components when
      `self._always_use_specified_sample_shape` is true; otherwise `()` is
      used.
    seed: Optional seed. `None` samples every component with `seed=None`.
      Otherwise it is split into one seed per component; if it is a stateful
      seed, a `SeedStream` is kept as a fallback for distributions that
      reject the split seeds.
    value: Optional sequence with one entry per distribution; `None` entries
      are sampled, non-`None` entries are used as-is.

  Returns:
    ds: List of the distributions that were created.
    xs: List of the corresponding sampled or supplied values.

  Raises:
    ValueError: if `value` does not have one entry per distribution.
    TypeError: re-raised from `sample` when it is not a seed-type error or no
      stateful fallback seed is available.
  """
  # This function additionally depends on:
  #   self._dist_fn_wrapped
  #   self._dist_fn_args
  #   self._always_use_specified_sample_shape
  num_dists = len(self._dist_fn_wrapped)
  if seed is not None and samplers.is_stateful_seed(seed):
    seed_stream = SeedStream(seed, salt='JointDistributionSequential')
  else:
    seed_stream = None
  if seed is not None:
    seeds = samplers.split_seed(seed, n=num_dists,
                                salt='JointDistributionSequential')
  else:
    seeds = [None] * num_dists
  ds = []
  xs = [None] * num_dists if value is None else list(value)
  if len(xs) != num_dists:
    raise ValueError('Number of `xs`s must match number of '
                     'distributions.')
  for i, (dist_fn, args) in enumerate(zip(self._dist_fn_wrapped,
                                          self._dist_fn_args)):
    ds.append(dist_fn(*xs[:i]))  # Chain rule of probability.

    # Ensure reproducibility even when xs are (partially) set: the stateful
    # stream is advanced once per component whether or not it ends up used.
    stateful_seed = None if seed_stream is None else seed_stream()

    if xs[i] is None:
      # TODO(b/129364796): We should ignore args prefixed with `_`; this
      # would mean we more often identify when to use `sample_shape=()`
      # rather than `sample_shape=sample_shape`.
      dist_sample_shape = (
          () if args and not self._always_use_specified_sample_shape
          else sample_shape)
      try:  # TODO(b/147874898): Eliminate the stateful fallback 20 Dec 2020.
        xs[i] = ds[-1].sample(dist_sample_shape, seed=seeds[i])
      except TypeError as e:
        # Only seed-type errors are eligible for the fallback, and only when
        # a stateful seed exists to fall back to.
        if ('Expected int for argument' not in str(e) and
            TENSOR_SEED_MSG_PREFIX not in str(e)) or stateful_seed is None:
          raise

        if not getattr(self, '_resolving_names', False):  # avoid recursion
          self._resolving_names = True
          try:
            resolved_names = self._flat_resolve_names()
          finally:
            # Always clear the guard; otherwise an exception while resolving
            # names would leave it set and silence every later warning.
            self._resolving_names = False
          msg = (
              'Falling back to stateful sampling for distribution #{i} '
              '(0-based) of type `{dist_cls}` with component name '
              '"{component_name}" and `dist.name` "{dist_name}". Please '
              'update to use `tf.random.stateless_*` RNGs. This fallback may '
              'be removed after 20-Dec-2020. ({exc})')
          warnings.warn(msg.format(
              i=i,
              dist_name=ds[-1].name,
              component_name=resolved_names[i],
              dist_cls=type(ds[-1]),
              exc=str(e)))
        xs[i] = ds[-1].sample(dist_sample_shape, seed=stateful_seed)

    else:
      # This signature does not allow kwarg names. Applies
      # `convert_to_tensor` on the next value.
      xs[i] = nest.map_structure_up_to(
          ds[-1].dtype,  # shallow_tree
          lambda x, dtype: tf.convert_to_tensor(x, dtype_hint=dtype),  # func
          xs[i],  # x
          ds[-1].dtype)  # dtype
  # Note: we could also resolve distributions up to the first non-`None` in
  # `self._model_flatten(value)`, however we omit this feature for simplicity,
  # speed, and because it has not yet been requested.
  return ds, xs
def _flat_sample_distributions(self, sample_shape=(), seed=None, value=None):
  """Executes `model`, creating both samples and distributions.

  Drives the generator returned by `self._model_coroutine()`: each yielded
  object (optionally wrapped in `self.Root`) is recorded as a distribution,
  given a value (either the matching non-`None` entry of `value` or a fresh
  sample), and that value is sent back into the generator until it is
  exhausted.

  Args:
    sample_shape: Sample shape used for `Root`-wrapped distributions; all
      other distributions are sampled with shape `()`.
    seed: Seed for sampling. A stateful seed also creates a `SeedStream` that
      serves as a fallback for distributions rejecting stateless seeds.
    value: Optional indexable of per-distribution values. Entries that are
      missing (index beyond `len(value)`) or `None` are sampled.

  Returns:
    ds: List of the (unwrapped) distributions, in yield order.
    values_out: List of the corresponding sampled or supplied values.

  Raises:
    ValueError: if `self._require_root` is set and the first yielded object is
      not a `self.Root`.
    TypeError: re-raised from `sample` when it is not a seed-type error or no
      stateful seed is available for the fallback.
  """
  ds = []
  values_out = []
  if samplers.is_stateful_seed(seed):
    seed_stream = SeedStream(seed, salt='JointDistributionCoroutine')
    if not self._stateful_to_stateless:
      # Stateful-only mode: drop the raw seed so no stateless seeds are
      # derived below.
      seed = None
  else:
    seed_stream = None  # We got a stateless seed for seed=.

  # TODO(b/166658748): Make _stateful_to_stateless always True (eliminate it).
  if self._stateful_to_stateless and (seed is not None or not JAX_MODE):
    seed = samplers.sanitize_seed(seed, salt='JointDistributionCoroutine')
  gen = self._model_coroutine()
  index = 0
  # NOTE(review): a coroutine that yields nothing makes this `next` raise
  # `StopIteration` outside the `try` below -- confirm that is intended.
  d = next(gen)
  if self._require_root and not isinstance(d, self.Root):
    raise ValueError('First distribution yielded by coroutine must '
                     'be wrapped in `Root`.')
  try:
    while True:
      actual_distribution = d.distribution if isinstance(d, self.Root) else d
      ds.append(actual_distribution)
      # Ensure reproducibility even when xs are (partially) set. Always split.
      # Both seed sources advance once per yielded distribution, regardless of
      # whether this distribution is actually sampled.
      stateful_sample_seed = None if seed_stream is None else seed_stream()
      if seed is None:
        stateless_sample_seed = None
      else:
        stateless_sample_seed, seed = samplers.split_seed(seed)

      if (value is not None and len(value) > index and
          value[index] is not None):

        def convert_tree_to_tensor(x, dtype_hint):
          return tf.convert_to_tensor(x, dtype_hint=dtype_hint)

        # This signature does not allow kwarg names. Applies
        # `convert_to_tensor` on the next value.
        next_value = nest.map_structure_up_to(
            ds[-1].dtype,  # shallow_tree
            convert_tree_to_tensor,  # func
            value[index],  # x
            ds[-1].dtype)  # dtype_hint
      else:
        try:
          # Prefer the stateless seed; use the stateful one only when no
          # stateless seed was derived.
          next_value = actual_distribution.sample(
              sample_shape=sample_shape if isinstance(d, self.Root) else (),
              seed=(stateful_sample_seed if stateless_sample_seed is None
                    else stateless_sample_seed))
        except TypeError as e:
          # Only seed-type errors are eligible for the stateful fallback, and
          # only when a stateful seed exists.
          if ('Expected int for argument' not in str(e) and
              TENSOR_SEED_MSG_PREFIX not in str(e)) or (
                  stateful_sample_seed is None):
            raise
          msg = (
              'Falling back to stateful sampling for distribution #{index} '
              '(0-based) of type `{dist_cls}` with component name '
              '{component_name} and `dist.name` "{dist_name}". Please '
              'update to use `tf.random.stateless_*` RNGs. This fallback may '
              'be removed after 20-Dec-2020. ({exc})')
          component_name = (joint_distribution_lib.
                            get_explicit_name_for_component(ds[-1]))
          if component_name is None:
            component_name = '[None specified]'
          else:
            component_name = '"{}"'.format(component_name)
          warnings.warn(
              msg.format(index=index,
                         component_name=component_name,
                         dist_name=ds[-1].name,
                         dist_cls=type(ds[-1]),
                         exc=str(e)))
          next_value = actual_distribution.sample(
              sample_shape=sample_shape if isinstance(d, self.Root) else (),
              seed=stateful_sample_seed)

      if self._validate_args:
        # Tie the recorded value to the shape assertions so that they run
        # whenever the value is consumed.
        with tf.control_dependencies(
            self._assert_compatible_shape(
                index, sample_shape, next_value)):
          values_out.append(
              tf.nest.map_structure(tf.identity, next_value))
      else:
        values_out.append(next_value)

      index += 1
      # Hand the value back to the model; this yields the next distribution
      # or raises `StopIteration` when the model is done.
      d = gen.send(next_value)
  except StopIteration:
    pass
  return ds, values_out
def random_von_mises(shape, concentration, dtype=tf.float32, seed=None):
  """Samples from the standardized von Mises distribution.

  The distribution is vonMises(loc=0, concentration=concentration), so the mean
  is zero.
  The location can then be changed by adding it to the samples.

  The sampling algorithm is rejection sampling with wrapped Cauchy proposal [1].
  The samples are pathwise differentiable using the approach of [2].

  Arguments:
    shape: The output sample shape.
    concentration: The concentration parameter of the von Mises distribution.
    dtype: The data type of concentration and the outputs. Must be one of
      `tf.float16`, `tf.float32` or `tf.float64` (the keys of the numerical
      cutoff table below).
    seed: (optional) The random seed.

  Returns:
    Differentiable samples of standardized von Mises.

  References:
    [1] Luc Devroye "Non-Uniform Random Variate Generation", Springer-Verlag,
    1986; Chapter 9, p. 473-476.
    http://www.nrbook.com/devroye/Devroye_files/chapter_nine.pdf
    + corrections http://www.nrbook.com/devroye/Devroye_files/errors.pdf
    [2] Michael Figurnov, Shakir Mohamed, Andriy Mnih. "Implicit
    Reparameterization Gradients", 2018.
  """
  seed = SeedStream(seed, salt='von_mises')
  concentration = tf.convert_to_tensor(
      concentration, dtype=dtype, name='concentration')

  @tf.custom_gradient
  def rejection_sample_with_gradient(concentration):
    """Performs rejection sampling for standardized von Mises.

    A nested function is required because @tf.custom_gradient does not handle
    non-tensor inputs such as dtype. Instead, they are captured by the outer
    scope.

    Arguments:
      concentration: The concentration parameter of the distribution.

    Returns:
      Differentiable samples of standardized von Mises.
    """
    r = 1. + tf.sqrt(1. + 4. * concentration**2)
    rho = (r - tf.sqrt(2. * r)) / (2. * concentration)

    s_exact = (1. + rho**2) / (2. * rho)

    # For low concentration, s becomes numerically unstable.
    # To fix that, we use an approximation. Here is the derivation.
    # First-order Taylor expansion at conc = 0 gives
    #   sqrt(1 + 4 concentration^2) ~= 1 + (2 concentration)^2 / 2.
    # Therefore, r ~= 2 + 2 concentration^2. By plugging this into rho, we have
    #   rho ~= conc + 1 / conc - sqrt(1 + 1 / concentration^2).
    # Let's expand the last term at concentration=0 up to the linear term:
    #   sqrt(1 + 1 / concentration^2) ~= 1 / concentration + concentration / 2
    # Thus, rho ~= concentration / 2. Finally,
    #   s = 1 / (2 rho) + rho / 2 ~= 1 / concentration + concentration / 4.
    # Since concentration is small, we drop the second term and simply use
    #   s ~= 1 / concentration.
    s_approximate = 1. / concentration

    # To compute the cutoff, we compute s_exact using mpmath with 30 decimal
    # digits precision and compare that to the s_exact and s_approximate
    # computed with dtype. Then, the cutoff is the largest concentration for
    # which abs(s_exact - s_exact_mpmath) > abs(s_approximate - s_exact_mpmath).
    s_concentration_cutoff_dict = {
        tf.float16: 1.8e-1,
        tf.float32: 2e-2,
        tf.float64: 1.2e-4,
    }
    s_concentration_cutoff = s_concentration_cutoff_dict[dtype]

    s = tf.where(concentration > s_concentration_cutoff,
                 s_exact, s_approximate)

    def loop_body(done, u, w):
      """Resample the non-accepted points."""
      # We resample u each time completely. Only its sign is used outside the
      # loop, which is random.
      u = tf.random.uniform(shape, minval=-1., maxval=1., dtype=dtype,
                            seed=seed())
      z = tf.cos(np.pi * u)
      # Update the non-accepted points.
      w = tf.where(done, w, (1. + s * z) / (s + z))
      y = concentration * (s - w)

      v = tf.random.uniform(shape, minval=0., maxval=1., dtype=dtype,
                            seed=seed())
      accept = (y * (2. - y) >= v) | (tf.math.log(y / v) + 1. >= y)

      return done | accept, u, w

    _, u, w = tf.while_loop(
        cond=lambda done, *_: ~tf.reduce_all(done),
        body=loop_body,
        loop_vars=(
            tf.zeros(shape, dtype=tf.bool, name='done'),
            tf.zeros(shape, dtype=dtype, name='u'),
            tf.zeros(shape, dtype=dtype, name='w'),
        ),
        # The expected number of iterations depends on concentration.
        # It monotonically increases from one iteration for concentration = 0 to
        # sqrt(2 pi / e) ~= 1.52 iterations for concentration = +inf [1].
        # We use a limit of 100 iterations to avoid infinite loops
        # for very large / nan concentration.
        maximum_iterations=100,
        # When the caller supplied a seed, run iterations strictly one at a
        # time so the random draws are consumed in a fixed order; only
        # unseeded sampling may overlap iterations. (The condition used to be
        # inverted, unlike the other seeded rejection loops in this codebase.)
        parallel_iterations=1 if seed.original_seed is not None else 10,
    )

    x = tf.sign(u) * tf.math.acos(w)

    def grad(dy):
      """The gradient of the von Mises samples w.r.t. concentration."""
      broadcast_concentration = tf.broadcast_to(
          concentration, prefer_static.shape(x))
      _, dcdf_dconcentration = value_and_gradient(
          lambda conc: von_mises_cdf(x, conc), broadcast_concentration)
      inv_prob = tf.exp(-broadcast_concentration * (tf.cos(x) - 1.)) * (
          (2. * np.pi) * tf.math.bessel_i0e(broadcast_concentration))
      # Compute the implicit reparameterization gradient [2],
      # dz/dconc = -(dF(z; conc) / dconc) / p(z; conc)
      ret = dy * (-inv_prob * dcdf_dconcentration)
      # Sum over the sample dimensions. Assume that they are always the first
      # ones.
      num_sample_dimensions = (tf.rank(broadcast_concentration) -
                               tf.rank(concentration))
      return tf.reduce_sum(ret, axis=tf.range(num_sample_dimensions))

    return x, grad

  return rejection_sample_with_gradient(concentration)
def batched_las_vegas_algorithm(
    batched_las_vegas_trial_fn, seed=None, name=None):
  """Batched Las Vegas Algorithm.

  This utility encapsulates the notion of a 'batched las_vegas_algorithm'
  (BLVA): a batch of independent (but not necessarily identical) randomized
  computations, each of which will eventually terminate after an unknown number
  of trials [(Babai, 1979)][1]. The number of trials will in general vary
  across batch points.

  The computation is parameterized by a callable representing a single trial for
  the entire batch. The utility runs the callable repeatedly, keeping track of
  which batch points have succeeded, until all have succeeded.

  Because we keep running the callable repeatedly until we've generated at least
  one good value for every batch point, we may generate multiple good values for
  many batch points. In this case, the particular good batch point returned is
  deliberately left unspecified.

  Args:
    batched_las_vegas_trial_fn: A callable that takes a Python integer PRNG seed
      and returns two values. (1) A structure of Tensors containing the results
      of the computation, all with a shape broadcastable with (2) a boolean mask
      representing whether each batch point succeeded.
    seed: Python integer or `tfp.util.SeedStream` instance, for seeding PRNG.
    name: A name to prepend to created ops.
      Default value: `'batched_las_vegas_algorithm'`.

  Returns:
    results, num_iters: A structure of Tensors representing the results of a
    successful computation for each batch point, and a scalar int32 tensor, the
    number of calls to `batched_las_vegas_trial_fn`.

  #### References

  [1]: Laszlo Babai. Monte-Carlo algorithms in graph isomorphism
       testing. Universite de Montreal, D.M.S. No. 79-10.
  """
  with tf.name_scope(name or 'batched_las_vegas_algorithm'):
    trial_seeds = SeedStream(seed, 'batched_las_vegas_algorithm')

    # The first trial runs outside the loop and provides the loop's initial
    # state, so the iteration counter starts at one.
    first_values, first_mask = batched_las_vegas_trial_fn(trial_seeds())
    first_count = tf.constant(1)

    def some_point_unfinished(unused_values, succeeded_mask, unused_count):
      """Keeps looping while any batch point still lacks a good value."""
      return tf.math.logical_not(tf.reduce_all(succeeded_mask))

    def run_one_trial(current_values, succeeded_mask, count):
      """Batched Las Vegas Algorithm body."""
      trial_values, trial_mask = batched_las_vegas_trial_fn(trial_seeds())
      # Wherever the new trial succeeded, its result replaces whatever was
      # stored before (even an earlier success).
      merged_values = tf.nest.map_structure(
          lambda new, old: tf.where(trial_mask, new, old),
          trial_values, current_values)
      merged_mask = tf.logical_or(succeeded_mask, trial_mask)
      return merged_values, merged_mask, count + 1

    # A user-provided seed forces sequential iterations; unseeded runs may
    # overlap up to ten.
    loop_parallelism = 10 if seed is None else 1
    final_values, _, total_count = tf.while_loop(
        some_point_unfinished,
        run_one_trial,
        (first_values, first_mask, first_count),
        parallel_iterations=loop_parallelism)

    return final_values, total_count
def _sample_n(self, n, seed):
  """Draws `n` samples by masking component samples with mixture indices.

  Samples all components, samples a component index per draw from
  `self.mixture_distribution`, one-hot encodes the indices, and sums the masked
  component samples over the component axis.

  Both `sample` calls are first attempted with seeds obtained from
  `samplers.split_seed`. If a distribution rejects such a seed with a
  seed-type `TypeError`, a warning is issued and the call is retried with a
  seed from a `SeedStream`.

  Args:
    n: Number of samples to draw.
    seed: Seed passed to `samplers.split_seed` and `SeedStream`.

  Returns:
    The mixture samples; per the inline shape notes, `[n, B, E]`.

  Raises:
    TypeError: if a `sample` call fails for a reason other than the seed
      type, or the fallback is needed but a `SeedStream` could not be built
      from `seed` (the original construction error is raised in that case).
  """
  components_seed, mix_seed = samplers.split_seed(
      seed, salt='MixtureSameFamily')
  try:
    seed_stream = SeedStream(seed, salt='MixtureSameFamily')
  except TypeError as e:  # Can happen for Tensor seeds.
    seed_stream = None
    # Kept so it can be raised later if the stateful fallback turns out to be
    # required.
    seed_stream_err = e
  try:
    x = self.components_distribution.sample(  # [n, B, k, E]
        n, seed=components_seed)
    if seed_stream is not None:
      seed_stream()  # Advance even if unused.
  except TypeError as e:
    # Anything other than a seed-type error is not ours to handle.
    if ('Expected int for argument' not in str(e) and
        TENSOR_SEED_MSG_PREFIX not in str(e)):
      raise
    if seed_stream is None:
      raise seed_stream_err
    msg = (
        'Falling back to stateful sampling for `components_distribution` '
        '{} of type `{}`. Please update to use `tf.random.stateless_*` '
        'RNGs. This fallback may be removed after 20-Aug-2020. {}')
    warnings.warn(
        msg.format(self.components_distribution.name,
                   type(self.components_distribution),
                   str(e)))
    x = self.components_distribution.sample(  # [n, B, k, E]
        n, seed=seed_stream())

  # Prefer the static event rank; fall back to the dynamic event shape.
  event_shape = None
  event_ndims = tensorshape_util.rank(self.event_shape)
  if event_ndims is None:
    event_shape = self.components_distribution.event_shape_tensor()
    event_ndims = ps.rank_from_shape(event_shape)
  event_ndims_static = tf.get_static_value(event_ndims)

  # The component axis sits immediately before the event axes of `x`.
  num_components = None
  if event_ndims_static is not None:
    num_components = tf.compat.dimension_value(
        x.shape[-1 - event_ndims_static])
  # We could also check if num_components can be computed statically from
  # self.mixture_distribution's logits or probs.
  if num_components is None:
    num_components = tf.shape(x)[-1 - event_ndims]

  # TODO(jvdillon): Consider using tf.gather (by way of index unrolling).
  npdt = dtype_util.as_numpy_dtype(x.dtype)

  try:
    mix_sample = self.mixture_distribution.sample(
        n, seed=mix_seed)  # [n, B] or [n]
  except TypeError as e:
    if ('Expected int for argument' not in str(e) and
        TENSOR_SEED_MSG_PREFIX not in str(e)):
      raise
    if seed_stream is None:
      raise seed_stream_err
    msg = (
        'Falling back to stateful sampling for `mixture_distribution` '
        '{} of type `{}`. Please update to use `tf.random.stateless_*` '
        'RNGs. This fallback may be removed after 20-Aug-2020. ({})')
    warnings.warn(
        msg.format(self.mixture_distribution.name,
                   type(self.mixture_distribution),
                   str(e)))
    mix_sample = self.mixture_distribution.sample(
        n, seed=seed_stream())  # [n, B] or [n]

  mask = tf.one_hot(
      indices=mix_sample,  # [n, B] or [n]
      depth=num_components,
      on_value=npdt(1),
      off_value=npdt(0))  # [n, B, k] or [n, k]

  # Pad `mask` to [n, B, k, [1]*e] or [n, [1]*b, k, [1]*e] .
  batch_ndims = ps.rank(x) - event_ndims - 1
  mask_batch_ndims = ps.rank(mask) - 1
  pad_ndims = batch_ndims - mask_batch_ndims
  mask_shape = ps.shape(mask)
  mask = tf.reshape(
      mask,
      shape=ps.concat([
          mask_shape[:-1],
          ps.ones([pad_ndims], dtype=tf.int32),
          mask_shape[-1:],
          ps.ones([event_ndims], dtype=tf.int32),
      ], axis=0))

  # Floating/complex dtypes use `multiply_no_nan`; other dtypes use a plain
  # product.
  if x.dtype in [tf.bfloat16, tf.float16, tf.float32, tf.float64,
                 tf.complex64, tf.complex128]:
    masked = tf.math.multiply_no_nan(x, mask)
  else:
    masked = x * mask
  ret = tf.reduce_sum(masked, axis=-1 - event_ndims)  # [n, B, E]

  if self._reparameterize:
    if event_shape is None:
      event_shape = self.components_distribution.event_shape_tensor()
    ret = self._reparameterize_sample(ret, event_shape=event_shape)

  return ret
def __init__(self,
             target_log_prob_fn,
             step_size,
             max_tree_depth=10,
             unrolled_leapfrog_steps=1,
             num_trajectories_per_step=1,
             use_auto_batching=True,
             stackless=False,
             backend=None,
             seed=None,
             name=None):
  """Initializes this transition kernel.

  Args:
    target_log_prob_fn: Python callable which takes an argument like
      `current_state` (or `*current_state` if it's a list) and returns its
      (possibly unnormalized) log-density under the target distribution. Due
      to limitations of the underlying auto-batching system,
      target_log_prob_fn may be invoked with junk data at some batch indexes,
      which it must process without crashing. (The results at those indexes
      are ignored).
    step_size: `Tensor` or Python `list` of `Tensor`s representing the step
      size for the leapfrog integrator. Must broadcast with the shape of
      `current_state`. Larger step sizes lead to faster progress, but
      too-large step sizes make rejection exponentially more likely. When
      possible, it's often helpful to match per-variable step sizes to the
      standard deviations of the target distribution in each variable.
    max_tree_depth: Maximum depth of the tree implicitly built by NUTS. The
      maximum number of leapfrog steps is bounded by `2**max_tree_depth-1`
      i.e. the number of nodes in a binary tree `max_tree_depth` nodes deep.
      The default setting of 10 takes up to 1023 leapfrog steps.
    unrolled_leapfrog_steps: The number of leapfrogs to unroll per tree
      expansion step. Applies a direct linear multiplier to the maximum
      trajectory length implied by max_tree_depth. Defaults to 1. This
      parameter can be useful for amortizing the auto-batching control flow
      overhead.
    num_trajectories_per_step: Python `int` giving the number of NUTS
      trajectories to run as "one" step. Setting this higher than 1 may be
      favorable for performance by giving the autobatching system the
      opportunity to batch gradients across consecutive trajectories. The
      intermediate samples are thinned: only the last sample from the run (in
      each batch member) is returned.
    use_auto_batching: Boolean. If `False`, do not invoke the auto-batching
      system; operate on batch size 1 only.
    stackless: Boolean. If `True`, invoke the stackless version of
      the auto-batching system. Only works in Eager mode.
    backend: Auto-batching backend object. Falls back to a default
      TensorFlowBackend().
    seed: Python integer to seed the random number generator.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'nuts_kernel').

  Raises:
    ValueError: if `max_tree_depth` is less than 1.
  """
  # Snapshot the constructor arguments before any other local is bound, so
  # that the snapshot holds exactly the call parameters (minus `self`).
  ctor_arguments = dict(locals())
  ctor_arguments.pop("self")
  self._parameters = ctor_arguments

  self.target_log_prob_fn = target_log_prob_fn
  self.step_size = step_size
  if max_tree_depth < 1:
    raise ValueError(
        "max_tree_depth must be >= 1 but was {}".format(max_tree_depth))
  self.max_tree_depth = max_tree_depth
  self.unrolled_leapfrog_steps = unrolled_leapfrog_steps
  self.num_trajectories_per_step = num_trajectories_per_step
  self.use_auto_batching = use_auto_batching
  self.stackless = stackless
  self.backend = backend
  self._seed_stream = SeedStream(seed, "nuts_one_step")
  self.name = name if name is not None else "nuts_kernel"

  # TODO(b/125544625): Identify why we need `use_gradient_tape=True`, i.e.,
  # what's different between `tape.gradient` and `tf.gradient`.
  self.value_and_gradients_fn = _embed_no_none_gradient_check(
      lambda *args: tfp_math.value_and_gradient(  # pylint: disable=g-long-lambda
          self.target_log_prob_fn, args, use_gradient_tape=True))

  # The trajectory builder is parameterized by edges, one fewer than the
  # depth.
  tree_edge_limit = max_tree_depth - 1
  evolve_fn, batching_context = _make_evolve_trajectory(
      self.value_and_gradients_fn, tree_edge_limit, unrolled_leapfrog_steps,
      self._seed_stream)
  self.many_steps = evolve_fn
  self.autobatch_context = batching_context
  self._block_code_cache = {}
def _sample_n(self, n, seed=None):
  """Draws `n` Zipf samples by rejection-inversion.

  Proposals are drawn from the continuous "hat" density
  h(x) \\propto x^(-power) by inverting its integral, rounded to the nearest
  integer, and accepted or re-drawn inside a `tf.while_loop` bounded by
  `self.sample_maximum_iterations`.

  Args:
    n: Number of samples to draw.
    seed: Optional seed for the `SeedStream` feeding the uniform draws. When
      it is not `None` the loop runs with `parallel_iterations=1`.

  Returns:
    Samples cast to `self.dtype`. When `self.validate_args` is set, entries
    that were still not accepted when the loop ended are replaced by the
    dtype's minimum (integer dtypes) or NaN (other dtypes).
  """
  power = tf.convert_to_tensor(self.power)
  draw_shape = tf.concat([[n], tf.shape(power)], axis=0)

  # Remember whether the caller seeded us before `seed` is wrapped.
  caller_gave_seed = seed is not None
  seed_stream = SeedStream(seed, salt='zipf')

  u_lower = self._hat_integral(0.5, power=power) + 1.
  u_upper = self._hat_integral(tf.int64.max - 0.5, power=power)

  def any_pending(pending, *unused_loop_vars):
    """Loops for as long as some point has not been accepted."""
    return tf.reduce_any(pending)

  def resample_pending(pending, k):
    """Resample the non-accepted points."""
    # The range of U is chosen so that the resulting sample K lies in
    # [0, tf.int64.max). The final sample, if accepted, is K + 1.
    u = tf.random.uniform(
        draw_shape,
        minval=u_lower,
        maxval=u_upper,
        dtype=power.dtype,
        seed=seed_stream())

    # Sample the point X from the continuous density h(x) \propto x^(-power).
    x = self._hat_integral_inverse(u, power=power)

    # Rejection-inversion requires a `hat` function, h(x) such that
    # \int_{k - .5}^{k + .5} h(x) dx >= pmf(k + 1) for points k in the
    # support. A natural hat function for us is h(x) = x^(-power).
    #
    # After sampling X from h(x), suppose it lies in the interval
    # (K - .5, K + .5) for integer K. Then the corresponding K is accepted
    # if it lies to the left of x_K, where x_K is defined by:
    #   \int_{x_k}^{K + .5} h(x) dx = H(x_K) - H(K + .5) = pmf(K + 1),
    # where H(x) = \int_x^inf h(x) dx.

    # Solving for x_K, we find that x_K = H_inverse(H(K + .5) + pmf(K + 1)).
    # Or, the acceptance condition is X <= H_inverse(H(K + .5) + pmf(K + 1)).
    # Since X = H_inverse(U), this simplifies to U <= H(K + .5) + pmf(K + 1).

    # Update the non-accepted points.
    # Since X \in (K - .5, K + .5), the sample K is chosen as floor(X + 0.5).
    k = tf.where(pending, tf.floor(x + 0.5), k)
    acceptance_bound = (
        self._hat_integral(k + .5, power=power) +
        tf.exp(self._log_prob(k + 1, power=power)))
    accepted = u <= acceptance_bound

    return [pending & (~accepted), k]

  initial_state = [
      tf.ones(draw_shape, dtype=tf.bool),  # pending
      tf.zeros(draw_shape, dtype=power.dtype),  # k
  ]
  still_pending, samples = tf.while_loop(
      cond=any_pending,
      body=resample_pending,
      loop_vars=initial_state,
      parallel_iterations=1 if caller_gave_seed else 10,
      maximum_iterations=self.sample_maximum_iterations,
  )
  # The loop produces K; the Zipf sample is K + 1.
  samples = samples + 1.

  if self.validate_args and dtype_util.is_integer(self.dtype):
    samples = distribution_util.embed_check_integer_casting_closed(
        samples, target_dtype=self.dtype, assert_positive=True)

  samples = tf.cast(samples, self.dtype)

  if self.validate_args:
    # Mark points that never got accepted with a sentinel value.
    npdt = dtype_util.as_numpy_dtype(self.dtype)
    if dtype_util.is_integer(npdt):
      sentinel = npdt(dtype_util.min(npdt))
    else:
      sentinel = npdt(np.nan)
    samples = tf.where(still_pending, sentinel, samples)

  return samples
def _sample_n(self, n, seed=None):
  """Draws `n` von Mises-Fisher samples.

  First samples the coordinate along the mode direction (using
  `self._sample_3d` when the event dimension is 3, otherwise a rejection loop
  with `Beta` proposals), then fills the remaining coordinates with a
  normalized Gaussian direction, and finally applies `self._rotate` to move
  the mode to the mean direction.

  When `self._allow_nan_stats` is false, numerical checks and runtime
  assertions are attached to the intermediate results.

  Args:
    n: Number of samples to draw.
    seed: Optional seed wrapped in a `SeedStream`.

  Returns:
    The rotated samples.
  """
  # NOTE(review): the salt is spelled 'vom_mises_fisher'; it is a runtime
  # string, so it is left untouched here.
  seed = SeedStream(seed, salt='vom_mises_fisher')
  # The sampling strategy relies on the fact that vMF variates are symmetric
  # about the mean direction. Accordingly, if we have a sampling strategy for
  # the away-from-mean angle, then we can uniformly sample the remaining
  # dimensions on the S^{dim-2} sphere, and rotate these samples from a
  # (1, 0, 0, ..., 0)-mode distribution into the target orientation.
  #
  # This is easy to imagine on the 1-sphere (S^1; in 2-D space): sample a
  # von-Mises distributed `x` value in [-1, 1], then uniformly select what
  # amounts to a "up" or "down" additional degree of freedom after unit
  # normalizing, followed by a final rotation to the desired mean direction
  # from a basis of (1, 0).
  #
  # On S^2 (in 3-D), selecting a vMF `x` identifies a circle in `yz` on the
  # unit sphere over which the distribution is uniform, in particular the
  # circle where x = \hat{x} intersects the unit sphere. We pick a point on
  # that circle, then rotate to the desired mean direction from a basis of
  # (1, 0, 0).
  event_dim = (
      tf.compat.dimension_value(self.event_shape[0]) or
      self._event_shape_tensor()[0])

  sample_batch_shape = tf.concat([[n], self._batch_shape_tensor()], axis=0)
  dim = tf.cast(event_dim - 1, self.dtype)
  if event_dim == 3:
    samples_dim0 = self._sample_3d(n, seed=seed)
  else:
    # Wood'94 provides a rejection algorithm to sample the x coordinate.
    # Wood'94 definition of b:
    # b = (-2 * kappa + tf.sqrt(4 * kappa**2 + dim**2)) / dim
    # https://stats.stackexchange.com/questions/156729 suggests:
    b = dim / (2 * self.concentration +
               tf.sqrt(4 * self.concentration**2 + dim**2))
    # TODO(bjp): Integrate any useful numerical tricks from hyperspherical VAE
    #     https://github.com/nicola-decao/s-vae-tf/
    x = (1 - b) / (1 + b)
    c = self.concentration * x + dim * tf.math.log1p(-x**2)
    beta = beta_lib.Beta(dim / 2, dim / 2)

    def cond_fn(w, should_continue):
      # Loop while at least one point has not been accepted.
      del w
      return tf.reduce_any(should_continue)

    def body_fn(w, should_continue):
      z = beta.sample(sample_shape=sample_batch_shape, seed=seed())
      # set_shape needed here because of b/139013403
      z.set_shape(w.shape)
      # Only points still being resampled receive a new proposal.
      w = tf.where(should_continue, (1 - (1 + b) * z) / (1 - (1 - b) * z), w)
      w = tf.debugging.check_numerics(w, 'w')
      unif = tf.random.uniform(
          sample_batch_shape, seed=seed(), dtype=self.dtype)
      # set_shape needed here because of b/139013403
      unif.set_shape(w.shape)
      should_continue = tf.logical_and(
          should_continue,
          self.concentration * w + dim * tf.math.log1p(-x * w) - c <
          tf.math.log(unif))
      return w, should_continue

    w = tf.zeros(sample_batch_shape, dtype=self.dtype)
    should_continue = tf.ones(sample_batch_shape, dtype=tf.bool)
    samples_dim0 = tf.while_loop(
        cond=cond_fn, body=body_fn, loop_vars=(w, should_continue))[0]
    # NOTE(review): the trailing axis is added on this branch only, which
    # assumes `_sample_3d` already returns a trailing unit axis -- confirm.
    samples_dim0 = samples_dim0[..., tf.newaxis]
  if not self._allow_nan_stats:
    # Verify samples are w/in -1, 1, with useful error output tensors (top
    # value rather than all values).
    with tf.control_dependencies([
        assert_util.assert_less_equal(
            samples_dim0,
            dtype_util.as_numpy_dtype(self.dtype)(1.01),
            data=[tf.math.top_k(tf.reshape(samples_dim0, [-1]))[0]]),
        assert_util.assert_greater_equal(
            samples_dim0,
            dtype_util.as_numpy_dtype(self.dtype)(-1.01),
            data=[-tf.math.top_k(tf.reshape(-samples_dim0, [-1]))[0]])
    ]):
      samples_dim0 = tf.identity(samples_dim0)
  # Direction of the remaining `event_dim - 1` coordinates: a Gaussian draw
  # normalized to unit length.
  samples_otherdims_shape = tf.concat([sample_batch_shape, [event_dim - 1]],
                                      axis=0)
  unit_otherdims = tf.math.l2_normalize(
      tf.random.normal(
          samples_otherdims_shape, seed=seed(), dtype=self.dtype),
      axis=-1)
  samples = tf.concat([
      samples_dim0,  # we must avoid sqrt(1 - (>1)**2)
      tf.sqrt(tf.maximum(1 - samples_dim0**2, 0.)) * unit_otherdims
  ], axis=-1)
  samples = tf.math.l2_normalize(samples, axis=-1)
  if not self._allow_nan_stats:
    samples = tf.debugging.check_numerics(samples, 'samples')

  # Runtime assert that samples are unit length.
  if not self._allow_nan_stats:
    worst, idx = tf.math.top_k(
        tf.reshape(tf.abs(1 - tf.linalg.norm(samples, axis=-1)), [-1]))
    with tf.control_dependencies([
        assert_util.assert_near(
            dtype_util.as_numpy_dtype(self.dtype)(0),
            worst,
            data=[
                worst, idx,
                tf.gather(tf.reshape(samples, [-1, event_dim]), idx)
            ],
            atol=1e-4,
            summarize=100)
    ]):
      samples = tf.identity(samples)
  # The samples generated are symmetric around a mode at (1, 0, 0, ...., 0).
  # Now, we move the mode to `self.mean_direction` using a rotation matrix.
  if not self._allow_nan_stats:
    # Assert that the basis vector rotates to the mean direction, as expected.
    basis = tf.cast(tf.concat([[1.], tf.zeros([event_dim - 1])], axis=0),
                    self.dtype)
    with tf.control_dependencies([
        assert_util.assert_less(
            tf.linalg.norm(
                self._rotate(basis) - self.mean_direction, axis=-1),
            dtype_util.as_numpy_dtype(self.dtype)(1e-5))
    ]):
      return self._rotate(samples)
  return self._rotate(samples)
def _sample_n(self, n, seed=None):
  """Draws `n` samples from the mixture.

  Two code paths exist:

  * `self._use_static_graph`: every component is sampled `n` times, the
    results are stacked, and a one-hot mask built from the categorical draws
    selects one component per sample.
  * Otherwise: the categorical draws partition the sample indices per
    component, each component is sampled only as many times as it was
    selected, and the pieces are stitched back with `tf.dynamic_stitch`.

  Args:
    n: Number of samples to draw.
    seed: Optional seed, used for `self.cat.sample` and for the `SeedStream`
      that seeds the component draws.

  Returns:
    The mixture samples; per the inline shape notes of the static path,
    `[n, B, E]`.
  """
  if self._use_static_graph:
    # This sampling approach is almost the same as the approach used by
    # `MixtureSameFamily`. The differences are due to having a list of
    # `Distribution` objects rather than a single object, and maintaining
    # random seed management that is consistent with the non-static code
    # path.
    samples = []
    cat_samples = self.cat.sample(n, seed=seed)
    stream = SeedStream(seed, salt='Mixture')

    for c in range(self.num_components):
      samples.append(self.components[c].sample(n, seed=stream()))
    # Stack the components on a new axis just before the event axes.
    stack_axis = -1 - tensorshape_util.rank(self._static_event_shape)
    x = tf.stack(samples, axis=stack_axis)  # [n, B, k, E]
    npdt = dtype_util.as_numpy_dtype(x.dtype)
    mask = tf.one_hot(
        indices=cat_samples,  # [n, B]
        depth=self._num_components,  # == k
        on_value=npdt(1),
        off_value=npdt(0))  # [n, B, k]
    mask = distribution_util.pad_mixture_dimensions(
        mask, self, self._cat,
        tensorshape_util.rank(self._static_event_shape))  # [n, B, k, [1]*e]
    return tf.reduce_sum(x * mask, axis=stack_axis)  # [n, B, E]

  n = tf.convert_to_tensor(n, name='n')
  # Use a Python int for `n` whenever its value is statically known.
  static_n = tf.get_static_value(n)
  n = int(static_n) if static_n is not None else n
  cat_samples = self.cat.sample(n, seed=seed)

  # For each of samples / batch / event, use the static shape when it is
  # fully defined and a dynamic shape tensor otherwise.
  static_samples_shape = cat_samples.shape
  if tensorshape_util.is_fully_defined(static_samples_shape):
    samples_shape = tensorshape_util.as_list(static_samples_shape)
    samples_size = tensorshape_util.num_elements(static_samples_shape)
  else:
    samples_shape = tf.shape(cat_samples)
    samples_size = tf.size(cat_samples)
  static_batch_shape = self.batch_shape
  if tensorshape_util.is_fully_defined(static_batch_shape):
    batch_shape = tensorshape_util.as_list(static_batch_shape)
    batch_size = tensorshape_util.num_elements(static_batch_shape)
  else:
    batch_shape = tf.shape(cat_samples)[1:]
    batch_size = tf.reduce_prod(batch_shape)
  static_event_shape = self.event_shape
  if tensorshape_util.is_fully_defined(static_event_shape):
    event_shape = np.array(
        tensorshape_util.as_list(static_event_shape), dtype=np.int32)
  else:
    # Resolved lazily from the first component sample in the loop below.
    event_shape = None

  # Get indices into the raw cat sampling tensor. We will
  # need these to stitch sample values back out after sampling
  # within the component partitions.
  samples_raw_indices = tf.reshape(tf.range(0, samples_size), samples_shape)

  # Partition the raw indices so that we can use
  # dynamic_stitch later to reconstruct the samples from the
  # known partitions.
  partitioned_samples_indices = tf.dynamic_partition(
      data=samples_raw_indices,
      partitions=cat_samples,
      num_partitions=self.num_components)

  # Copy the batch indices n times, as we will need to know
  # these to pull out the appropriate rows within the
  # component partitions.
  batch_raw_indices = tf.reshape(
      tf.tile(tf.range(0, batch_size), [n]), samples_shape)

  # Explanation of the dynamic partitioning below:
  #   batch indices are i.e., [0, 1, 0, 1, 0, 1]
  # Suppose partitions are:
  #     [1 1 0 0 1 1]
  # After partitioning, batch indices are cut as:
  #     [batch_indices[x] for x in 2, 3]
  #     [batch_indices[x] for x in 0, 1, 4, 5]
  # i.e.
  #     [1 1] and [0 0 0 0]
  # Now we sample n=2 from part 0 and n=4 from part 1.
  # For part 0 we want samples from batch entries 1, 1 (samples 0, 1),
  # and for part 1 we want samples from batch entries 0, 0, 0, 0
  #   (samples 0, 1, 2, 3).
  partitioned_batch_indices = tf.dynamic_partition(
      data=batch_raw_indices,
      partitions=cat_samples,
      num_partitions=self.num_components)
  samples_class = [None for _ in range(self.num_components)]

  stream = SeedStream(seed, salt='Mixture')

  for c in range(self.num_components):
    # Number of draws that selected component `c`.
    n_class = tf.size(partitioned_samples_indices[c])
    samples_class_c = self.components[c].sample(
        n_class, seed=stream())

    if event_shape is None:
      # Event shape was not statically known: read it off the sample, after
      # the leading sample axis and the batch axes.
      batch_ndims = prefer_static.rank_from_shape(batch_shape)
      event_shape = tf.shape(samples_class_c)[1 + batch_ndims:]

    # Pull out the correct batch entries from each index.
    # To do this, we may have to flatten the batch shape.

    # For sample s, batch element b of component c, we get the
    # partitioned batch indices from
    # partitioned_batch_indices[c]; and shift each element by
    # the sample index. The final lookup can be thought of as
    # a matrix gather along locations (s, b) in
    # samples_class_c where the n_class rows correspond to
    # samples within this component and the batch_size columns
    # correspond to batch elements within the component.
    #
    # Thus the lookup index is
    #   lookup[c, i] = batch_size * s[i] + b[c, i]
    # for i = 0 ... n_class[c] - 1.
    lookup_partitioned_batch_indices = (
        batch_size * tf.range(n_class) + partitioned_batch_indices[c])
    samples_class_c = tf.reshape(
        samples_class_c, tf.concat([[n_class * batch_size], event_shape], 0))
    samples_class_c = tf.gather(
        samples_class_c,
        lookup_partitioned_batch_indices,
        name='samples_class_c_gather')
    samples_class[c] = samples_class_c

  # Stitch back together the samples across the components.
  lhs_flat_ret = tf.dynamic_stitch(
      indices=partitioned_samples_indices, data=samples_class)
  # Reshape back to proper sample, batch, and event shape.
  ret = tf.reshape(
      lhs_flat_ret, tf.concat([samples_shape, event_shape], 0))
  # Re-attach whatever static shape information is available.
  tensorshape_util.set_shape(
      ret,
      tensorshape_util.concatenate(static_samples_shape, self.event_shape))
  return ret
def __init__(self,
             output_shape=(32, 32, 3),
             num_glow_blocks=3,
             num_steps_per_block=32,
             coupling_bijector_fn=None,
             exit_bijector_fn=None,
             grab_after_block=None,
             use_actnorm=True,
             seed=None,
             validate_args=False,
             name='glow'):
    """Creates the Glow bijector.

    Args:
      output_shape: A list of integers, specifying the event shape of the
        output, of the bijectors forward pass (the image). Specified as
        [H, W, C].
        Default Value: (32, 32, 3)
      num_glow_blocks: An integer, specifying how many downsampling levels to
        include in the model. `2**num_glow_blocks` must divide both H and W,
        otherwise the bijector would not be invertible.
        Default Value: 3
      num_steps_per_block: An integer specifying how many Affine Coupling and
        1x1 convolution layers to include at each level of the spatial
        hierarchy.
        Default Value: 32 (i.e. the value used in the original glow paper).
      coupling_bijector_fn: A function which takes the argument `input_shape`
        and returns a callable neural network (e.g. a keras.Sequential). The
        network should either return a tensor with the same event shape as
        `input_shape` (this will employ additive coupling), a tensor with the
        same height and width as `input_shape` but twice the number of
        channels (this will employ affine coupling), or a bijector which takes
        in a tensor with event shape `input_shape`, and returns a tensor with
        shape `input_shape`.
      exit_bijector_fn: Similar to coupling_bijector_fn, exit_bijector_fn is
        a function which takes the argument `input_shape` and `output_chan`
        and returns a callable neural network. The neural network it returns
        should take a tensor of shape `input_shape` as the input, and return
        one of three options: A tensor with `output_chan` channels, a tensor
        with `2 * output_chan` channels, or a bijector. Additional details can
        be found in the documentation for ExitBijector.
      grab_after_block: A tuple of floats, one per glow block, specifying what
        fraction of the remaining channels to remove following each glow
        block. Glow will take the integer floor of this number multiplied by
        the remaining number of channels.
        Default value: None (half of the channels are split off after every
        block except the final one, after which none are split off).
      use_actnorm: A bool deciding whether or not to use actnorm. Data-dependent
        initialization is used to initialize this layer.
        Default value: `True`
      seed: A seed to control randomness in the 1x1 convolution initialization.
        Default value: `None` (i.e., non-reproducible sampling).
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
        Default value: `False`
      name: Python `str`, name given to ops managed by this object.
        Default value: `'glow'`.

    Raises:
      ValueError: if `output_shape` is not fully defined or is not rank 3.
      ValueError: if `num_glow_blocks` is not a statically known positive
        integer.
      ValueError: if H or W is not divisible by `2**num_glow_blocks`, or is
        too small to be halved `num_glow_blocks` times.
      ValueError: if `len(grab_after_block) != num_glow_blocks`.
      ValueError: if the requested splits leave a block with no input
        channels or an exit with no output channels.
    """
    # Make sure that the input shape is fully defined.
    if not tensorshape_util.is_fully_defined(output_shape):
        raise ValueError('Shape must be fully defined.')
    if tensorshape_util.rank(output_shape) != 3:
        raise ValueError('Shape ndims must be 3 for images. Your shape is '
                         '{}'.format(tensorshape_util.rank(output_shape)))

    num_glow_blocks_ = tf.get_static_value(num_glow_blocks)
    if (num_glow_blocks_ is None or
            int(num_glow_blocks_) != num_glow_blocks_ or
            num_glow_blocks_ < 1):
        raise ValueError('Argument `num_glow_blocks` must be a statically known '
                         'positive `int` (saw: {}).'.format(num_glow_blocks))
    num_glow_blocks = int(num_glow_blocks_)

    output_shape = tensorshape_util.as_list(output_shape)
    h, w, c = output_shape
    n = num_glow_blocks
    nsteps = num_steps_per_block

    # Default Glow: Half of the channels are split off after each block,
    # and after the final block, no channels are split off.
    if grab_after_block is None:
        grab_after_block = tuple([0.5] * (n - 1) + [0.])

    def max_halvings(size):
        # Largest number of times `size` can be halved. A non-positive size
        # cannot be halved at all; guarding here also avoids converting
        # log(0) == -inf to an int, which would raise OverflowError instead
        # of the intended ValueError.
        return int(np.log(size) / np.log(2.)) if size > 0 else 0

    # Thing we know must be true: h and w are evenly divisible by 2, n times.
    # Otherwise, the squeeze bijector will not work.
    if w % 2**n != 0:
        raise ValueError('Width must be divisible by 2 at least n times. '
                         'Saw: {} % {} != 0'.format(w, 2**n))
    if h % 2**n != 0:
        raise ValueError('Height should be divisible by 2 at least n times. '
                         'Saw: {} % {} != 0'.format(h, 2**n))
    if h // 2**n < 1:
        raise ValueError('num_glow_blocks ({0}) is too large. The image height '
                         '({1}) must be divisible by 2 no more than {2} '
                         'times.'.format(num_glow_blocks, h, max_halvings(h)))
    if w // 2**n < 1:
        # The bound is derived from the width here (not the height).
        raise ValueError('num_glow_blocks ({0}) is too large. The image width '
                         '({1}) must be divisible by 2 no more than {2} '
                         'times.'.format(num_glow_blocks, w, max_halvings(w)))

    # Other things we want to be true:
    # - The number of times we take must be equal to the number of glow blocks.
    if len(grab_after_block) != num_glow_blocks:
        raise ValueError('Length of grab_after_block ({0}) must match the number '
                         'of blocks ({1}).'.format(len(grab_after_block),
                                                   num_glow_blocks))

    self._blockwise_splits = self._get_blockwise_splits(output_shape,
                                                        grab_after_block[::-1])

    # Now check on the values of blockwise splits. The offending index is
    # looked up in the list of failed predicates, not in the raw channel
    # counts.
    no_inputs_left = [bs[0] < 1 for bs in self._blockwise_splits]
    if any(no_inputs_left):
        first_offender = no_inputs_left.index(True)
        raise ValueError('At at least one exit, you are taking out all of your '
                         'channels, and therefore have no inputs to later blocks.'
                         ' Try setting grab_after_block to a lower value at index '
                         '{}.'.format(first_offender))

    if any(np.isclose(gab, 0) for gab in grab_after_block):
        # Special case: if specifically exiting no channels, then the exit is
        # just an identity bijector.
        pass
    else:
        no_outputs = [bs[1] < 1 for bs in self._blockwise_splits]
        if any(no_outputs):
            first_offender = no_outputs.index(True)
            raise ValueError(
                'At least one of your layers has < 1 output channels. '
                'This means you set grab_after_block too small. '
                'Try setting grab_after_block to a larger value at index '
                '{}.'.format(first_offender))

    # Lets start to build our bijector. We assume that the distribution is 1
    # dimensional. First, lets reshape it to an image.
    glow_chain = [
        reshape.Reshape(
            event_shape_out=[h // 2**n, w // 2**n, c * 4**n],
            event_shape_in=[h * w * c])
    ]

    seedstream = SeedStream(seed=seed, salt='random_beta')

    for i in range(n):
        # This is the shape of the current tensor
        current_shape = (h // 2**n * 2**i, w // 2**n * 2**i, c * 4**(i + 1))

        # This is the shape of the input to both the glow block and exit
        # bijector.
        this_nchan = sum(self._blockwise_splits[i][0:2])
        this_input_shape = (h // 2**n * 2**i, w // 2**n * 2**i, this_nchan)

        glow_chain.append(invert.Invert(ExitBijector(current_shape,
                                                     self._blockwise_splits[i],
                                                     exit_bijector_fn)))

        glow_block = GlowBlock(input_shape=this_input_shape,
                               num_steps=nsteps,
                               coupling_bijector_fn=coupling_bijector_fn,
                               use_actnorm=use_actnorm,
                               seedstream=seedstream)

        if self._blockwise_splits[i][2] == 0:
            # All channels are passed to the RealNVP
            glow_chain.append(glow_block)
        else:
            # Some channels are passed around the block.
            # This is done with the Blockwise bijector.
            glow_chain.append(
                blockwise.Blockwise(
                    [glow_block, identity.Identity()],
                    [sum(self._blockwise_splits[i][0:2]),
                     self._blockwise_splits[i][2]]))

        # Finally, lets expand the channels into spatial features.
        glow_chain.append(
            Expand(input_shape=[
                h // 2**n * 2**i,
                w // 2**n * 2**i,
                c * 4**n // 4**i,
            ]))

    # The chain was assembled innermost-first, so reverse it before use.
    glow_chain = glow_chain[::-1]
    # To finish off, we initialize the bijector with the chain we've built
    # This way, the rest of the model attributes are taken care of for us.
    super(Glow, self).__init__(
        bijectors=glow_chain, validate_args=validate_args, name=name)
def _sample_n(self, n, seed=None):
    """Samples `n` observation sequences from the hidden Markov model.

    A path of hidden states is drawn first: the initial state, followed by
    `num_steps - 1` transitions driven by `tf.scan`. Then, for every step,
    one observation per candidate state is drawn and the observation matching
    the hidden state is picked out with a one-hot mask.

    Args:
      n: Number of sequences to draw.
      seed: Seed used to build the `SeedStream` that feeds every `sample`
        call below. When it is not `None` the scan is run with
        `parallel_iterations=1` so that results are reproducible.

    Returns:
      observations: Tensor laid out as
        `[n] + batch_shape + [num_steps] + inner_shape`, where `inner_shape`
        is the event shape of the observation distribution.
    """
    with tf.control_dependencies(self._runtime_assertions):
        seed_stream = SeedStream(seed, salt="HiddenMarkovModel")
        state_count = self._num_states
        hmm_batch_shape = self.batch_shape_tensor()
        hmm_batch_size = tf.reduce_prod(hmm_batch_shape)

        initial_dist = self._initial_distribution
        transition_dist = self._transition_distribution
        observation_dist = self._observation_distribution

        # The component distributions may have a smaller batch than the HMM
        # itself. In that case extra samples are requested from them and the
        # result is reshaped to the HMM's batch size.
        init_copies = (
            tf.reduce_prod(self.batch_shape_tensor()) //
            tf.reduce_prod(initial_dist.batch_shape_tensor()))
        first_state = tf.reshape(
            initial_dist.sample(n * init_copies, seed=seed_stream()),
            [n, hmm_batch_size])
        # first_state :: n batch_size

        transition_copies = (
            tf.reduce_prod(self.batch_shape_tensor()) //
            tf.reduce_prod(transition_dist.batch_shape_tensor()[:-1]))

        def _advance(previous_state, _):
            """Moves every chain forward by one transition."""
            draws = transition_dist.sample(n * transition_copies,
                                           seed=seed_stream())
            # draws :: (n * transition_copies) transition_batch
            candidates = tf.reshape(draws,
                                    [n, hmm_batch_size, state_count])
            # candidates :: n batch_size num_states
            selector = tf.one_hot(previous_state, state_count,
                                  dtype=tf.int32)
            # selector :: n batch_size num_states
            return tf.reduce_sum(selector * candidates, axis=-1)

        def _run_chain():
            """Runs the remaining `num_steps - 1` transitions with tf.scan."""
            placeholder_steps = tf.zeros(self._num_steps - 1,
                                         dtype=tf.float32)
            # A user-provided seed forces sequential iteration so the draws
            # are reproducible (b/139210489); otherwise tf.scan keeps its
            # default parallel_iterations behavior.
            scan_options = {} if seed is None else {'parallel_iterations': 1}
            later_states = tf.scan(_advance, placeholder_steps,
                                   initializer=first_state, **scan_options)
            # TODO(b/115618503): add/use prepend_initializer to tf.scan
            return tf.concat([[first_state], later_states], axis=0)

        state_path = prefer_static.cond(
            self._num_steps > 1,
            _run_chain,
            lambda: first_state[tf.newaxis, ...])

        path_one_hot = tf.one_hot(state_path, state_count,
                                  dtype=observation_dist.dtype)
        # path_one_hot :: num_steps n batch_size num_states

        # As above, the observation distribution may need oversampling to
        # reach the HMM's batch size.
        observation_copies = (hmm_batch_size // tf.reduce_prod(
            observation_dist.batch_shape_tensor()[:-1]))
        candidate_observations = observation_dist.sample(
            [self._num_steps, observation_copies * n], seed=seed_stream())
        inner_shape = observation_dist.event_shape
        # candidate_observations :: num_steps (observation_copies * n)
        #                           observation_batch[:-1] num_states
        #                           inner_shape

        candidate_observations = tf.reshape(
            candidate_observations,
            tf.concat([[self._num_steps, n],
                       hmm_batch_shape,
                       [state_count],
                       inner_shape], axis=0))
        # candidate_observations :: steps n batch_size num_states inner_shape

        path_one_hot = tf.reshape(
            path_one_hot,
            tf.concat([[self._num_steps, n],
                       hmm_batch_shape,
                       [state_count],
                       tf.ones_like(inner_shape)], axis=0))
        # path_one_hot :: steps n batch_size num_states "inner_shape"

        # Summing over the state axis keeps only the observation that
        # belongs to the sampled hidden state.
        observations = tf.reduce_sum(
            path_one_hot * candidate_observations,
            axis=-1 - tf.size(inner_shape))
        # observations :: steps n batch_size inner_shape

        # returned :: n batch_shape steps inner_shape
        return distribution_util.move_dimension(
            observations, 0, 1 + tf.size(hmm_batch_shape))
def __init__(
        self,
        target_log_prob_fn,
        initial_state,
        initial_covariance=None,
        initial_covariance_scaling=2.38**2,
        covariance_scaling_reducer=0.7,
        covariance_scaling_limiter=0.01,
        covariance_burnin=100,
        target_accept_ratio=0.234,
        pu=0.95,
        fixed_variance=0.01,
        extra_getter_fn=rwm_extra_getter_fn,
        extra_setter_fn=rwm_extra_setter_fn,
        log_accept_prob_getter_fn=rwm_log_accept_prob_getter_fn,
        seed=None,
        name=None,
):
    """Initializes this transition kernel.

    Args:
      target_log_prob_fn: Python callable which takes an argument like
        `current_state` and returns its (possibly unnormalized) log-density
        under the target distribution.
      initial_state: Python `list` of `Tensor`s representing the initial
        state of each parameter.
      initial_covariance: Python `list` of `Tensor`s representing the
        initial covariance of the proposal. The `initial_covariance` and
        `initial_state` should have identical `dtype`s and batch dimensions.
        If `initial_covariance` is `None` then it is initialized to the
        identity matrix multiplied by 0.001, batched to match
        `initial_state`. The covariance matrix is tuned during the evolution
        of the MCMC chain.
        Default value: `None`.
      initial_covariance_scaling: Python floating point number representing
        the initial value of the `covariance_scaling`. The value of
        `covariance_scaling` is tuned during the evolution of the MCMC chain.
        Let d represent the number of parameters e.g. as given by the
        `initial_state`. The ratio given by the `covariance_scaling` divided
        by d is used to multiply the running covariance. The covariance
        scaling factor multiplied by the covariance matrix is used in the
        proposal at each step.
        Default value: 2.38**2.
      covariance_scaling_reducer: Python floating point number, bounded over
        the range (0.5,1.0], representing the constant factor used during the
        adaptation of the `covariance_scaling`.
        Default value: 0.7.
      covariance_scaling_limiter: Python floating point number, bounded
        between 0.0 and 1.0, which places a limit on the maximum amount the
        `covariance_scaling` value can be perturbed at each iteration of the
        MCMC chain.
        Default value: 0.01.
      covariance_burnin: Python integer number of steps to take before
        starting to compute the running covariance.
        Default value: 100.
      target_accept_ratio: Python floating point number, bounded over the
        range (0.0,1.0], representing the target acceptance probability of
        the Metropolis-Hastings algorithm.
        Default value: 0.234.
      pu: Python floating point number, bounded between 0.0 and 1.0,
        representing the bounded convergence parameter. See
        `random_walk_mvnorm_fn()` for further details.
        Default value: 0.95.
      fixed_variance: Python floating point number representing the variance
        of the fixed proposal distribution. See `random_walk_mvnorm_fn` for
        further details.
        Default value: 0.01.
      extra_getter_fn: A callable with the signature
        `(kernel_results) -> extra` where `kernel_results` are the results of
        the `inner_kernel`, and `extra` is a nested collection of `Tensor`s.
      extra_setter_fn: A callable with the signature
        `(kernel_results, args) -> new_kernel_results` where `kernel_results`
        are the results of the `inner_kernel`, `args` are a nested collection
        of `Tensor`s with the same structure as returned by the
        `extra_getter_fn`, and `new_kernel_results` are a copy of
        `kernel_results` with `args` in the `extra` field set.
      log_accept_prob_getter_fn: A callable with the signature
        `(kernel_results) -> log_accept_prob` where `kernel_results` are the
        results of the `inner_kernel`, and `log_accept_prob` is either a
        scalar, or has shape [num_chains].
      seed: Python integer to seed the random number generator.
        Default value: `None`.
      name: Python `str` name prefixed to Ops created by this function.
        Default value: `None`.

    Raises:
      ValueError: if `initial_covariance_scaling` is less than or equal
        to 0.0.
      ValueError: if `covariance_scaling_reducer` is less than or equal
        to 0.5 or greater than 1.0.
      ValueError: if `covariance_scaling_limiter` is less than 0.0 or
        greater than 1.0.
      ValueError: if `covariance_burnin` is less than 0.
      ValueError: if `target_accept_ratio` is less than or equal to 0.0 or
        greater than 1.0.
      ValueError: if `pu` is less than 0.0 or greater than 1.0.
      ValueError: if `fixed_variance` is less than 0.0.
    """
    with tf.name_scope(
            mcmc_util.make_name(name, "AdaptiveRandomWalkMetropolisHastings",
                                "__init__")) as name:
        if initial_covariance_scaling <= 0.0:
            raise ValueError(
                "`{}` must be a `float` greater than 0.0".format(
                    "initial_covariance_scaling"))
        if covariance_scaling_reducer <= 0.5 or covariance_scaling_reducer > 1.0:
            raise ValueError(
                "`{}` must be a `float` greater than 0.5 and less than or "
                "equal to 1.0.".format("covariance_scaling_reducer"))
        if covariance_scaling_limiter < 0.0 or covariance_scaling_limiter > 1.0:
            raise ValueError(
                "`{}` must be a `float` between 0.0 and 1.0.".format(
                    "covariance_scaling_limiter"))
        if covariance_burnin < 0:
            raise ValueError(
                "`{}` must be a `integer` greater or equal to 0.".format(
                    "covariance_burnin"))
        if target_accept_ratio <= 0.0 or target_accept_ratio > 1.0:
            # 0.0 itself is rejected, so the message spells out the open
            # lower bound.
            raise ValueError(
                "`{}` must be a `float` greater than 0.0 and less than or "
                "equal to 1.0.".format("target_accept_ratio"))
        if pu < 0.0 or pu > 1.0:
            raise ValueError(
                "`{}` must be a `float` between 0.0 and 1.0.".format("pu"))
        if fixed_variance < 0.0:
            # 0.0 is accepted, so the message says "or equal".
            raise ValueError(
                "`{}` must be a `float` greater than or equal to 0.0.".format(
                    "fixed_variance"))

        if mcmc_util.is_list_like(initial_state):
            initial_state_parts = list(initial_state)
        else:
            initial_state_parts = [initial_state]
        initial_state_parts = [
            tf.convert_to_tensor(s, name="initial_state")
            for s in initial_state_parts
        ]

        # Stack once and read both shape and dtype from the same tensor.
        stacked_state = tf.stack(initial_state_parts)
        shape = stacked_state.shape
        dtype = dtype_util.base_dtype(stacked_state.dtype)

        if initial_covariance is None:
            # Default proposal covariance: 0.001 * identity, one matrix per
            # stacked state part.
            initial_covariance = 0.001 * tf.eye(
                num_rows=shape[-1], dtype=dtype, batch_shape=[shape[0]])
        else:
            initial_covariance = tf.stack(initial_covariance)

        if mcmc_util.is_list_like(initial_covariance):
            initial_covariance_parts = list(initial_covariance)
        else:
            initial_covariance_parts = [initial_covariance]
        initial_covariance_parts = [
            tf.convert_to_tensor(s, name="initial_covariance")
            for s in initial_covariance_parts
        ]

        self._running_covar = stats.RunningCovariance(shape=(1, shape[-1]),
                                                      dtype=dtype,
                                                      event_ndims=1)
        self._accum_covar = self._running_covar.initialize()

        probs = tf.expand_dims(tf.ones([shape[0]], dtype=dtype) * pu, axis=1)
        self._u = Bernoulli(probs=probs, dtype=tf.dtypes.int32)
        # Only the shape of the sample is used; the values are zeroed out.
        self._initial_u = tf.zeros_like(self._u.sample(seed=seed),
                                        dtype=tf.dtypes.int32)

        name = mcmc_util.make_name(name,
                                   "AdaptiveRandomWalkMetropolisHastings", "")
        # NOTE(review): this stream is built but never read in this
        # constructor, and `seed` is not forwarded to the kernels created
        # below. Confirm whether it was meant to be stored on `self` or
        # passed on.
        seed_stream = SeedStream(seed,
                                 salt="AdaptiveRandomWalkMetropolisHastings")

        self._parameters = dict(
            target_log_prob_fn=target_log_prob_fn,
            initial_state=initial_state,
            initial_covariance=initial_covariance,
            initial_covariance_scaling=initial_covariance_scaling,
            covariance_scaling_reducer=covariance_scaling_reducer,
            covariance_scaling_limiter=covariance_scaling_limiter,
            covariance_burnin=covariance_burnin,
            target_accept_ratio=target_accept_ratio,
            pu=pu,
            fixed_variance=fixed_variance,
            extra_getter_fn=extra_getter_fn,
            extra_setter_fn=extra_setter_fn,
            log_accept_prob_getter_fn=log_accept_prob_getter_fn,
            seed=seed,
            name=name,
        )
        self._impl = metropolis_hastings.MetropolisHastings(
            inner_kernel=random_walk_metropolis.UncalibratedRandomWalk(
                target_log_prob_fn=target_log_prob_fn,
                new_state_fn=random_walk_mvnorm_fn(
                    covariance=initial_covariance_parts,
                    pu=pu,
                    fixed_variance=fixed_variance,
                    is_adaptive=self._initial_u,
                    name=name,
                ),
                name=name,
            ),
            name=name,
        )
def __init__(self,
             target_log_prob_fn,
             step_size,
             max_tree_depth=10,
             max_energy_diff=1000.,
             unrolled_leapfrog_steps=1,
             parallel_iterations=10,
             seed=None,
             name=None):
    """Initializes this transition kernel.

    Args:
      target_log_prob_fn: Python callable which takes an argument like
        `current_state` (or `*current_state` if it's a list) and returns its
        (possibly unnormalized) log-density under the target distribution.
      step_size: `Tensor` or Python `list` of `Tensor`s giving the step size
        of the leapfrog integrator. Must broadcast with the shape of
        `current_state`. Larger steps make faster progress, but steps that
        are too large make rejection exponentially more likely. Where
        possible, matching per-variable step sizes to the standard deviations
        of the target distribution in each variable often helps.
      max_tree_depth: Maximum depth of the tree implicitly built by NUTS.
        Must be statically known and at least 1. The number of leapfrog steps
        is bounded by `2**max_tree_depth`, i.e. the number of nodes in a
        binary tree `max_tree_depth` nodes deep. The default of 10 allows up
        to 1024 leapfrog steps.
      max_energy_diff: Scalar threshold on the energy difference at each
        leapfrog step; steps exceeding it are treated as divergent samples.
        Defaults to 1000.
      unrolled_leapfrog_steps: Number of leapfrog steps to unroll per tree
        expansion step. Acts as a direct linear multiplier on the maximum
        trajectory length implied by `max_tree_depth`. Defaults to 1.
      parallel_iterations: Number of iterations allowed to run in parallel.
        Must be a positive integer. See `tf.while_loop` for more details.
        When setting `seed` for deterministic output, also set
        `parallel_iterations` to 1.
      seed: Python integer to seed the random number generator.
      name: Python `str` name prefixed to Ops created by this function.
        Default value: `None` (i.e., 'NoUTurnSampler').

    Raises:
      ValueError: if `max_tree_depth` is not statically known or is less
        than 1.
    """
    with tf.name_scope(name or 'NoUTurnSampler') as name:
        # The tree depth drives Python-side precomputation below, so it has
        # to be a concrete value.
        static_depth = tf.get_static_value(max_tree_depth)
        if static_depth is None or static_depth < 1:
            raise ValueError(
                'max_tree_depth must be known statically and >= 1 but was '
                '{}'.format(static_depth))

        # Derive the read/write schedule from the tree depth.
        uturn_plan = build_tree_uturn_instruction(static_depth,
                                                  init_memory=-1)
        write_plan, read_plan = generate_efficient_write_read_instruction(
            uturn_plan)

        # Normalize `step_size` to a list of tensors.
        step_parts = step_size if tf.nest.is_nested(step_size) else [step_size]
        step_tensors = [
            tf.convert_to_tensor(part, dtype_hint=tf.float32)
            for part in step_parts
        ]

        self._max_tree_depth = static_depth
        # Only the numpy form of the instructions is kept here. The
        # TensorArray form has to be created inside the function call to stay
        # compatible with XLA, so the conversion happens later.
        self._write_instruction = write_plan
        self._read_instruction = read_plan
        self._target_log_prob_fn = target_log_prob_fn
        self._step_size = step_tensors
        self._parameters = {
            'target_log_prob_fn': target_log_prob_fn,
            'step_size': step_tensors,
            'max_tree_depth': static_depth,
            'max_energy_diff': max_energy_diff,
            'unrolled_leapfrog_steps': unrolled_leapfrog_steps,
            'parallel_iterations': parallel_iterations,
            'seed': seed,
            'name': name,
        }
        self._parallel_iterations = parallel_iterations
        self._seed_stream = SeedStream(seed, salt='nuts_one_step')
        self._unrolled_leapfrog_steps = unrolled_leapfrog_steps
        self._name = name
        self._max_energy_diff = max_energy_diff
def sample_halton_sequence(dim,
                           num_results=None,
                           sequence_indices=None,
                           dtype=tf.float32,
                           randomized=True,
                           seed=None,
                           name=None):
    r"""Returns a sample from the `dim` dimensional Halton sequence.

    Warning: The sequence elements take values only between 0 and 1. If the
    domain of the function being integrated differs from the unit cube, it
    must be transformed appropriately before Halton samples are used to
    evaluate integrals. Also remember that quasi-random numbers without
    randomization are not a replacement for pseudo-random numbers in every
    context: they are completely deterministic and typically have significant
    negative autocorrelation unless randomization is used.

    Computes members of the low discrepancy Halton sequence in dimension
    `dim`. The sequence takes values in the `dim`-dimensional unit hypercube.
    Currently, only dimensions up to 1000 are supported. The prime base for
    the k-th axis is the k-th prime starting from 2. For example, if
    `dim` = 3 the bases are [2, 3, 5] and the first element of the
    non-randomized sequence is [0.5, 0.333, 0.2]. For a more complete
    description of Halton sequences see
    [here](https://en.wikipedia.org/wiki/Halton_sequence). For low discrepancy
    sequences and their applications see
    [here](https://en.wikipedia.org/wiki/Low-discrepancy_sequence).

    If `randomized` is true, this function produces a scrambled version of the
    Halton sequence introduced by [Owen (2017)][1]. For the advantages of
    randomizing low discrepancy sequences see [here](
    https://en.wikipedia.org/wiki/Quasi-Monte_Carlo_method#Randomization_of_quasi-Monte_Carlo).

    The number of samples produced is controlled by `num_results` and
    `sequence_indices`. Exactly one of the two must be supplied. The former is
    the number of samples to produce starting from the first element. If
    `sequence_indices` is given instead, the specified elements of the
    sequence are generated. For example, `sequence_indices=tf.range(10)` is
    equivalent to specifying `num_results=10`.

    #### Examples

    ```python
    import tensorflow as tf
    import tensorflow_probability as tfp

    # Produce the first 1000 members of the Halton sequence in 3 dimensions.
    num_results = 1000
    dim = 3
    sample = tfp.mcmc.sample_halton_sequence(
      dim,
      num_results=num_results,
      seed=127)

    # Evaluate the integral of x_1 * x_2^2 * x_3^3  over the three dimensional
    # hypercube.
    powers = tf.range(1.0, limit=dim + 1)
    integral = tf.reduce_mean(tf.reduce_prod(sample ** powers, axis=-1))
    true_value = 1.0 / tf.reduce_prod(powers + 1.0)
    with tf.Session() as session:
      values = session.run((integral, true_value))

    # Produces a relative absolute error of 1.7%.
    print ("Estimated: %f, True Value: %f" % values)

    # Now skip the first 1000 samples and recompute the integral with the next
    # thousand samples. The sequence_indices argument can be used to do this.

    sequence_indices = tf.range(start=1000, limit=1000 + num_results,
                                dtype=tf.int32)
    sample_leaped = tfp.mcmc.sample_halton_sequence(
        dim,
        sequence_indices=sequence_indices,
        seed=111217)

    integral_leaped = tf.reduce_mean(tf.reduce_prod(sample_leaped ** powers,
                                                    axis=-1))
    with tf.Session() as session:
      values = session.run((integral_leaped, true_value))
    # Now produces a relative absolute error of 0.05%.
    print ("Leaped Estimated: %f, True Value: %f" % values)
    ```

    Args:
      dim: Positive Python `int` representing each sample's `event_size.` Must
        not be greater than 1000.
      num_results: (Optional) Positive scalar `Tensor` of dtype int32. The
        number of samples to generate. Either this parameter or
        `sequence_indices` must be specified but not both. If this parameter
        is None, the behaviour is determined by `sequence_indices`.
        Default value: `None`.
      sequence_indices: (Optional) `Tensor` of dtype int32 and rank 1. The
        elements of the sequence to compute, given by their position in the
        sequence. Positions index into the Halton sequence starting with 0
        and hence must be whole numbers. For example,
        sequence_indices=[0, 5, 6] will produce the first, sixth and seventh
        elements of the sequence. If this parameter is None, `num_results`
        must be specified; it gives the number of desired samples starting
        from the first sample.
        Default value: `None`.
      dtype: (Optional) The dtype of the sample. One of: `float16`, `float32`
        or `float64`.
        Default value: `tf.float32`.
      randomized: (Optional) bool indicating whether to produce a randomized
        Halton sequence. If True, applies the randomization described in
        [Owen (2017)][1].
        Default value: `True`.
      seed: (Optional) Python integer to seed the random number generator.
        Only used if `randomized` is True. If not supplied and `randomized`
        is True, no seed is set.
        Default value: `None`.
      name: (Optional) Python `str` describing ops managed by this function.
        Default value: `None` (the ops are then scoped under 'sample').

    Returns:
      halton_elements: Elements of the Halton sequence. `Tensor` of supplied
        dtype and `shape` `[num_results, dim]` if `num_results` was specified
        or shape `[s, dim]` where s is the size of `sequence_indices` if
        `sequence_indices` were specified.

    Raises:
      ValueError: if both or neither of `sequence_indices` and `num_results`
        were specified, if dimension `dim` is less than 1 or greater than
        1000, or if `dtype` is not a floating point type.

    #### References

    [1]: Art B. Owen. A randomized Halton algorithm in R. _arXiv preprint
         arXiv:1706.02808_, 2017. https://arxiv.org/abs/1706.02808
    """
    if dim < 1 or dim > _MAX_DIMENSION:
        raise ValueError(
            'Dimension must be between 1 and {}. Supplied {}'.format(
                _MAX_DIMENSION, dim))
    count_given = num_results is not None
    indices_given = sequence_indices is not None
    if count_given == indices_given:
        raise ValueError('Either `num_results` or `sequence_indices` must be'
                         ' specified but not both.')
    if not dtype.is_floating:
        raise ValueError('dtype must be of `float`-type')

    with tf.name_scope(name or 'sample'):
        # Shape layout used throughout:
        #   [sample dimension, event dimension, coefficient dimension].
        # The coefficient dimension is an intermediate axis holding the
        # place-value digits of each index written in the (prime) base of an
        # event dimension.
        if count_given:
            num_results = tf.convert_to_tensor(num_results)
        if indices_given:
            sequence_indices = tf.convert_to_tensor(sequence_indices)
        indices = _get_indices(num_results, sequence_indices, dtype)
        radixes = tf.constant(_PRIMES[0:dim], dtype=dtype, shape=[dim, 1])

        digits_per_axis = _base_expansion_size(tf.reduce_max(indices),
                                               radixes)
        max_digits = tf.reduce_max(digits_per_axis)

        # Tensors must be rectangular, so every axis gets `max_digits` slots
        # even though larger bases need fewer digits (e.g. 7 needs 3 digits
        # in base 2 but only 2 in base 3). Left alone, that would mean
        # raising large primes to needlessly large powers (29**10 when
        # expanding 1024 in 10 dimensions). To avoid this, exponents beyond
        # what an axis needs are set to 0.
        exponents = tf.tile([tf.range(max_digits)], [dim, 1])
        # True for the digit positions an axis actually uses.
        is_needed = exponents < digits_per_axis
        safe_exponents = tf.where(is_needed,
                                  exponents,
                                  tf.constant(0, exponents.dtype))
        place_values = radixes**safe_exponents

        def _radical_inverse(digits):
            """Sums digit / base**(position + 1) over the digit axis."""
            return tf.reduce_sum(digits / radixes / place_values, axis=-1)

        # Base-b expansion of the indices. For x = a0 + a1*b + a2*b^2 + ...,
        # floor-dividing x by (1, b, b^2, ...) gives
        # (a0 + s1*b, a1 + s2*b, ...) for some don't-care s_i. Since every
        # a_i < b, reducing mod b leaves exactly the digits a_i.
        digits = tf.math.floordiv(indices, place_values)
        digits = digits * tf.cast(is_needed, dtype)
        digits = digits % radixes

        if not randomized:
            return _radical_inverse(digits)

        seeds = SeedStream(seed, salt='MCMCSampleHaltonSequence')
        digits = _randomize(digits, radixes, seed=seeds())
        # Drop what the randomization put into the unused trailing slots of
        # axes with fewer than `max_digits` digits; the tail correction below
        # accounts for those separately.
        digits = digits * tf.cast(is_needed, dtype)
        scrambled = _radical_inverse(digits)

        # The randomization used in Owen (2017) does not leave 0 invariant, so
        # the trailing zeros past each axis' last digit still need
        # randomizing. This is equivalent to adding a uniform random value
        # scaled so that the first `digits_per_axis` coefficients are zero.
        tail = tf.random.uniform([dim, 1], seed=seeds(), dtype=dtype)
        tail = tail / radixes**digits_per_axis
        return scrambled + tf.reshape(tail, [-1])