def _start_trajectory_batched(self, state, target_log_prob):
  """Computations needed to start a trajectory."""
  with tf.name_scope('start_trajectory_batched'):
    seed_stream = SeedStream(self._seed_stream,
                             salt='start_trajectory_batched')
    momentum = [
        tf.random.normal(  # pylint: disable=g-complex-comprehension
            shape=prefer_static.shape(x),
            dtype=x.dtype,
            seed=seed_stream()) for x in state
    ]
    init_energy = compute_hamiltonian(target_log_prob, momentum)
    if MULTINOMIAL_SAMPLE:
      return momentum, init_energy, None

    # Draw a slice variable u ~ Uniform(0, p(initial state, initial
    # momentum)) and compute log u. For numerical stability, we perform this
    # in log space where log u = log(u' * p(...)) = log u' + log p(...) and
    # u' ~ Uniform(0, 1).
    log_slice_sample = tf.math.log1p(-tf.random.uniform(
        shape=prefer_static.shape(init_energy),
        dtype=init_energy.dtype,
        seed=seed_stream()))
    return momentum, init_energy, log_slice_sample
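
# A minimal NumPy sketch (illustrative, not part of the kernel) of the
# log-space slice variable above: instead of drawing u ~ Uniform(0, p)
# directly, draw u' ~ Uniform(0, 1) and work with log u = log u' + log p.
# log1p(-z) with z ~ Uniform(0, 1) is a numerically stable log u', since
# 1 - z is itself Uniform(0, 1).
import numpy as np

rng = np.random.default_rng(0)
log_p = -3.7                      # stand-in for log p(initial state, momentum)
log_u_prime = np.log1p(-rng.uniform(size=10_000))
log_u = log_u_prime + log_p       # log of u ~ Uniform(0, p)
assert np.all(log_u <= log_p)     # u <= p holds without ever exponentiating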
def default_exchange_proposed_fn_(num_replica, seed=None):
  """Default function for `exchange_proposed_fn` of `kernel`."""
  seed_stream = SeedStream(seed, 'default_exchange_proposed_fn')

  zero_start = tf.random_uniform([], seed=seed_stream()) > 0.5
  if num_replica % 2 == 0:

    def _exchange():
      flat_exchange = tf.range(num_replica)
      if num_replica > 2:
        start = tf.to_int32(~zero_start)
        end = num_replica - start
        flat_exchange = flat_exchange[start:end]
      return tf.reshape(flat_exchange, [tf.size(flat_exchange) // 2, 2])
  else:

    def _exchange():
      start = tf.to_int32(zero_start)
      end = num_replica - tf.to_int32(~zero_start)
      flat_exchange = tf.range(num_replica)[start:end]
      return tf.reshape(flat_exchange, [tf.size(flat_exchange) // 2, 2])

  def _null_exchange():
    return tf.reshape(tf.to_int32([]), shape=[0, 2])

  # `prob_exchange` is captured from the enclosing
  # `default_exchange_proposed_fn(prob_exchange)` factory that returns this
  # function.
  return tf.cond(
      tf.random_uniform([], seed=seed_stream()) < prob_exchange,
      _exchange, _null_exchange)
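
# A pure-NumPy sketch of the pairing logic above (`proposed_pairs` is an
# illustrative name, not part of the kernel): depending on the coin flip
# `zero_start`, adjacent replicas are paired starting from an even or an
# odd index.
import numpy as np

def proposed_pairs(num_replica, zero_start):
  flat = np.arange(num_replica)
  if num_replica % 2 == 0:
    if num_replica > 2:
      start = int(not zero_start)
      flat = flat[start:num_replica - start]
  else:
    start = int(zero_start)
    end = num_replica - int(not zero_start)
    flat = flat[start:end]
  return flat.reshape(-1, 2)

print(proposed_pairs(5, zero_start=True))   # [[1 2] [3 4]]
print(proposed_pairs(5, zero_start=False))  # [[0 1] [2 3]]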
def _sample_n(self, n, seed=None):
  shape = tf.concat([[n], self._batch_shape_tensor()], axis=0)
  seed = SeedStream(seed, salt="random_horseshoe")
  local_shrinkage = self._half_cauchy.sample(shape, seed=seed())
  shrinkage = self.scale * local_shrinkage
  sampled = tf.random.normal(
      shape=shape, mean=0., stddev=1., dtype=self.scale.dtype, seed=seed())
  return sampled * shrinkage
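
# A minimal NumPy sketch of the same scale-mixture construction: a Horseshoe
# draw is a standard normal scaled by `scale` times a half-Cauchy local
# shrinkage parameter.
import numpy as np

rng = np.random.default_rng(0)
scale = 1.0
local_shrinkage = np.abs(rng.standard_cauchy(size=10_000))  # HalfCauchy(0, 1)
samples = rng.standard_normal(10_000) * scale * local_shrinkage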
def __init__(self,
             target_log_prob_fn,
             inverse_temperatures,
             make_kernel_fn,
             exchange_proposed_fn=default_exchange_proposed_fn(1.),
             seed=None,
             name=None,
             **kwargs):
  """Instantiates this object.

  Args:
    target_log_prob_fn: Python callable which takes an argument like
      `current_state` (or `*current_state` if it's a list) and returns its
      (possibly unnormalized) log-density under the target distribution.
    inverse_temperatures: `1-D` `Tensor` of inverse temperatures used to
      perform sampling with each replica. Must have statically known
      `shape`. `inverse_temperatures[0]` produces the states returned by
      samplers, and is typically == 1.
    make_kernel_fn: Python callable which takes `target_log_prob_fn` and
      `seed` args and returns a `TransitionKernel` instance.
    exchange_proposed_fn: Python callable which takes a number of replicas
      and returns combinations of replicas for exchange.
    seed: Python integer to seed the random number generator.
      Default value: `None` (i.e., no seed).
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., "remc_kernel").
    **kwargs: Arguments for `make_kernel_fn`.

  Raises:
    ValueError: `inverse_temperatures` doesn't have statically known 1-D
      shape.
  """
  inverse_temperatures = tf.convert_to_tensor(
      inverse_temperatures, name='inverse_temperatures')

  # Note these are static checks, and don't need to be embedded in the graph.
  inverse_temperatures.shape.assert_is_fully_defined()
  inverse_temperatures.shape.assert_has_rank(1)

  self._seed_stream = SeedStream(seed, salt=name)
  self._seeded_mcmc = seed is not None
  self._parameters = dict(
      target_log_prob_fn=target_log_prob_fn,
      inverse_temperatures=inverse_temperatures,
      num_replica=inverse_temperatures.shape[0].value,
      exchange_proposed_fn=exchange_proposed_fn,
      seed=seed,
      name=name)
  self.replica_kernels = []
  for i in range(self.num_replica):
    self.replica_kernels.append(
        make_kernel_fn(
            target_log_prob_fn=_replica_log_prob_fn(
                inverse_temperatures[i], target_log_prob_fn),
            seed=self._seed_stream()))
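
# A hypothetical usage sketch. `ReplicaExchangeMC` stands in for the
# enclosing class, and the HamiltonianMonteCarlo constructor arguments
# (including `seed`) follow TF1-era TFP; both are assumptions, not taken
# from this file.
import tensorflow_probability as tfp

def make_kernel_fn(target_log_prob_fn, seed):
  # One inner kernel per replica; each sees a tempered log-prob.
  return tfp.mcmc.HamiltonianMonteCarlo(
      target_log_prob_fn=target_log_prob_fn,
      step_size=0.5,
      num_leapfrog_steps=3,
      seed=seed)

remc = ReplicaExchangeMC(
    target_log_prob_fn=lambda x: -x**2 / 2.,  # standard normal target
    inverse_temperatures=[1., 0.5, 0.25],     # replica 0 is the cold chain
    make_kernel_fn=make_kernel_fn,
    seed=42)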
def _sample_n(self, n, seed=None): seed = SeedStream(seed, "gamma_gamma") rate = tf.random_gamma(shape=[n], alpha=self.mixing_concentration, beta=self.mixing_rate, dtype=self.dtype, seed=seed()) return tf.random_gamma(shape=[], alpha=self.concentration, beta=rate, dtype=self.dtype, seed=seed())
def _sample_3d(self, n, seed=None):
  """Specialized inversion sampler for 3D."""
  seed = SeedStream(seed, salt='von_mises_fisher_3d')
  u_shape = tf.concat([[n], self._batch_shape_tensor()], axis=0)
  z = tf.random_uniform(u_shape, seed=seed(), dtype=self.dtype)
  # TODO(bjp): Higher-order odd dim analytic CDFs are available in [1], could
  # be bisected for bounded sampling runtime (i.e. not rejection sampling).
  # [1]: Inversion sampler via: https://ieeexplore.ieee.org/document/7347705/
  # The inversion is: u = 1 + log(z + (1 - z) * exp(-2 * kappa)) / kappa
  # We must protect against both kappa and z being zero.
  safe_conc = tf.where(self.concentration > 0,
                       self.concentration,
                       tf.ones_like(self.concentration))
  safe_z = tf.where(z > 0, z, tf.ones_like(z))
  safe_u = 1 + tf.reduce_logsumexp(
      [tf.log(safe_z), tf.log1p(-safe_z) - 2 * safe_conc],
      axis=0) / safe_conc
  # Limit of the above expression as kappa -> 0 is 2 * z - 1.
  u = tf.where(self.concentration > tf.zeros_like(safe_u),
               safe_u, 2 * z - 1)
  # Limit of the expression as z -> 0 is -1.
  u = tf.where(tf.equal(z, 0), -tf.ones_like(u), u)
  if not self._allow_nan_stats:
    u = tf.check_numerics(u, 'u in _sample_3d')
  return u[..., tf.newaxis]
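
# A quick NumPy check of the inversion formula above,
# u = 1 + log(z + (1 - z) * exp(-2 * kappa)) / kappa, computed via logsumexp
# for stability, against its kappa -> 0 limit 2 * z - 1.
import numpy as np
from scipy.special import logsumexp

z, kappa = 0.3, 1e-8
u = 1 + logsumexp([np.log(z), np.log1p(-z) - 2 * kappa]) / kappa
print(u, 2 * z - 1)  # both approximately -0.4 for tiny kappa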
def random_von_mises(shape, concentration, dtype=tf.float32, seed=None):
  """Samples from the standardized von Mises distribution.

  The distribution is vonMises(loc=0, concentration=concentration), so the
  mean is zero. The location can then be changed by adding it to the samples.

  The sampling algorithm is rejection sampling with wrapped Cauchy proposal
  [1]. The samples are pathwise differentiable using the approach of [2].

  Arguments:
    shape: The output sample shape.
    concentration: The concentration parameter of the von Mises distribution.
    dtype: The data type of concentration and the outputs.
    seed: (optional) The random seed.

  Returns:
    Differentiable samples of standardized von Mises.

  References:
    [1] Luc Devroye "Non-Uniform Random Variate Generation", Springer-Verlag,
    1986; Chapter 9, p. 473-476.
    http://www.nrbook.com/devroye/Devroye_files/chapter_nine.pdf
    + corrections http://www.nrbook.com/devroye/Devroye_files/errors.pdf
    [2] Michael Figurnov, Shakir Mohamed, Andriy Mnih. "Implicit
    Reparameterization Gradients", 2018.
  """
  seed = SeedStream(seed, salt="von_mises")
  concentration = tf.convert_to_tensor(
      concentration, dtype=dtype, name="concentration")

  @tf.custom_gradient
  def rejection_sample_with_gradient(concentration):
    """Performs rejection sampling for standardized von Mises.

    A nested function is required because @tf.custom_gradient does not handle
    non-tensor inputs such as dtype. Instead, they are captured by the outer
    scope.

    Arguments:
      concentration: The concentration parameter of the distribution.

    Returns:
      Differentiable samples of standardized von Mises.
    """
    r = 1. + tf.sqrt(1. + 4. * concentration**2)
    rho = (r - tf.sqrt(2. * r)) / (2. * concentration)

    s_exact = (1. + rho**2) / (2. * rho)

    # For low concentration, s becomes numerically unstable.
    # To fix that, we use an approximation. Here is the derivation.
    # First-order Taylor expansion at conc = 0 gives
    #   sqrt(1 + 4 concentration^2) ~= 1 + (2 concentration)^2 / 2.
    # Therefore, r ~= 2 + 2 concentration. By plugging this into rho, we have
    #   rho ~= concentration + 1 / concentration
    #          - sqrt(1 + 1 / concentration^2).
    # Let's expand the last term at concentration=0 up to the linear term:
    #   sqrt(1 + 1 / concentration^2) ~= 1 / concentration + concentration / 2
    # Thus, rho ~= concentration / 2. Finally,
    #   s = 1 / (2 rho) + rho / 2 ~= 1 / concentration + concentration / 4.
    # Since concentration is small, we drop the second term and simply use
    #   s ~= 1 / concentration.
    s_approximate = 1. / concentration

    # To compute the cutoff, we compute s_exact using mpmath with 30 decimal
    # digits precision and compare that to the s_exact and s_approximate
    # computed with dtype. Then, the cutoff is the largest concentration for
    # which abs(s_exact - s_exact_mpmath) > abs(s_approximate -
    # s_exact_mpmath).
    s_concentration_cutoff_dict = {
        tf.float16: 1.8e-1,
        tf.float32: 2e-2,
        tf.float64: 1.2e-4,
    }
    s_concentration_cutoff = s_concentration_cutoff_dict[dtype]

    s = tf.where(concentration > s_concentration_cutoff, s_exact,
                 s_approximate)

    def loop_body(should_continue, u, w):
      """Resample the non-accepted points."""
      # We resample u each time completely. Only its sign is used outside the
      # loop, which is random.
      u = tf.random_uniform(
          shape, minval=-1., maxval=1., dtype=dtype, seed=seed())
      z = tf.cos(np.pi * u)
      # Update the non-accepted points.
      w = tf.where(should_continue, (1. + s * z) / (s + z), w)
      y = concentration * (s - w)

      v = tf.random_uniform(
          shape, minval=0., maxval=1., dtype=dtype, seed=seed())
      accept = (y * (2. - y) >= v) | (tf.log(y / v) + 1. >= y)

      should_continue = should_continue & (~accept)
      return should_continue, u, w

    _, u, w = tf.while_loop(
        cond=lambda should_continue, *ignore: tf.reduce_any(should_continue),
        body=loop_body,
        loop_vars=(
            tf.ones(shape, dtype=tf.bool),  # should_continue
            tf.zeros(shape, dtype=dtype),  # u
            tf.zeros(shape, dtype=dtype)),  # w
        # The expected number of iterations depends on concentration.
        # It monotonically increases from one iteration for concentration = 0
        # to sqrt(2 pi / e) ~= 1.52 iterations for concentration = +inf [1].
        # We use a limit of 100 iterations to avoid infinite loops
        # for very large / nan concentration.
        maximum_iterations=100,
    )

    x = tf.sign(u) * tf.math.acos(w)

    def grad(dy):
      """The gradient of the von Mises samples w.r.t. concentration."""
      broadcast_concentration = concentration + tf.zeros_like(x)
      cdf_func = lambda conc: von_mises_cdf(x, conc)
      _, dcdf_dconcentration = _compute_value_and_grad(
          cdf_func, broadcast_concentration)
      inv_prob = tf.exp(-broadcast_concentration * (tf.cos(x) - 1.)) * (
          (2. * np.pi) * tf.math.bessel_i0e(broadcast_concentration))
      # Compute the implicit reparameterization gradient [2],
      #   dz/dconc = -(dF(z; conc) / dconc) / p(z; conc)
      ret = dy * (-inv_prob * dcdf_dconcentration)
      # Sum over the sample dimensions. Assume that they are always the first
      # ones.
      num_sample_dimensions = (tf.rank(broadcast_concentration) -
                               tf.rank(concentration))
      return tf.reduce_sum(ret, axis=tf.range(num_sample_dimensions))

    return x, grad

  return rejection_sample_with_gradient(concentration)
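
# A hypothetical usage sketch, assuming the module's own imports and
# TF1-style graph mode (to match the `tf.random_uniform` / `tf.log` calls
# above): draw pathwise-differentiable samples and differentiate through
# them with respect to the concentration.
concentration = tf.constant(2.)
samples = random_von_mises([1000], concentration)  # standardized: loc = 0
shifted = samples + np.pi / 4.                     # relocate the mean
grad_conc = tf.gradients(tf.reduce_mean(shifted), concentration)[0]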
def _sample_n(self, n, seed=None):
  shape = tf.concat([[n], self.batch_shape_tensor()], axis=0)

  has_seed = seed is not None
  seed = SeedStream(seed, salt="zipf")

  minval_u = self._hat_integral(0.5) + 1.
  maxval_u = self._hat_integral(tf.int64.max - 0.5)

  def loop_body(should_continue, k):
    """Resample the non-accepted points."""
    # The range of U is chosen so that the resulting sample K lies in
    # [0, tf.int64.max). The final sample, if accepted, is K + 1.
    u = tf.random.uniform(
        shape,
        minval=minval_u,
        maxval=maxval_u,
        dtype=self.power.dtype,
        seed=seed())

    # Sample the point X from the continuous density h(x) \propto x^(-power).
    x = self._hat_integral_inverse(u)

    # Rejection-inversion requires a `hat` function, h(x) such that
    # \int_{k - .5}^{k + .5} h(x) dx >= pmf(k + 1) for points k in the
    # support. A natural hat function for us is h(x) = x^(-power).
    #
    # After sampling X from h(x), suppose it lies in the interval
    # (K - .5, K + .5) for integer K. Then the corresponding K is accepted
    # if it lies to the left of x_K, where x_K is defined by:
    #   \int_{x_K}^{K + .5} h(x) dx = H(x_K) - H(K + .5) = pmf(K + 1),
    # where H(x) = \int_x^inf h(x) dx.
    # Solving for x_K, we find that x_K = H_inverse(H(K + .5) + pmf(K + 1)).
    # Or, the acceptance condition is X <= H_inverse(H(K + .5) + pmf(K + 1)).
    # Since X = H_inverse(U), this simplifies to U <= H(K + .5) + pmf(K + 1).

    # Update the non-accepted points.
    # Since X \in (K - .5, K + .5), the sample K is chosen as floor(X + 0.5).
    k = tf.where(should_continue, tf.floor(x + 0.5), k)
    accept = (u <= self._hat_integral(k + .5) + tf.exp(self._log_prob(k + 1)))

    return [should_continue & (~accept), k]

  should_continue, samples = tf.while_loop(
      cond=lambda should_continue, *ignore: tf.reduce_any(
          input_tensor=should_continue),
      body=loop_body,
      loop_vars=[
          tf.ones(shape, dtype=tf.bool),  # should_continue
          tf.zeros(shape, dtype=self.power.dtype),  # k
      ],
      parallel_iterations=1 if has_seed else 10,
      maximum_iterations=self.sample_maximum_iterations,
  )
  samples = samples + 1.

  if self.validate_args and dtype_util.is_integer(self.dtype):
    samples = distribution_util.embed_check_integer_casting_closed(
        samples, target_dtype=self.dtype, assert_positive=True)

  samples = tf.cast(samples, self.dtype)

  if self.validate_args:
    npdt = dtype_util.as_numpy_dtype(self.dtype)
    v = npdt(dtype_util.min(npdt) if dtype_util.is_integer(npdt) else np.nan)
    mask = tf.fill(shape, value=v)
    samples = tf.where(should_continue, mask, samples)

  return samples
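
# A NumPy sketch of rejection-inversion matching the comments above, with
# illustrative stand-ins for the class helpers: hat h(x) = (1 + x)^(-a),
# H(x) = \int_x^inf h(t) dt = (1 + x)^(1-a) / (a - 1), and the normalized
# Zipf pmf(k) = k^(-a) / zeta(a). This is a simplification of
# `_hat_integral` / `_hat_integral_inverse` / `_log_prob`, not their source.
import numpy as np
from scipy.special import zeta

a = 2.5                                   # power, must be > 1
H = lambda x: (1. + x)**(1. - a) / (a - 1.)
H_inv = lambda u: ((a - 1.) * u)**(1. / (1. - a)) - 1.
pmf = lambda k: k**(-a) / zeta(a, 1.)

rng = np.random.default_rng(0)
samples = []
while len(samples) < 1000:
  u = rng.uniform(0., H(-0.5))        # X = H_inv(U) lands in (-0.5, inf)
  x = H_inv(u)
  k = np.floor(x + 0.5)               # X in (K - .5, K + .5)  =>  K
  if u <= H(k + 0.5) + pmf(k + 1.):   # accept  =>  sample is K + 1
    samples.append(k + 1.)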
def __init__(self,
             target_log_prob_fn,
             step_size,
             max_tree_depth=6,
             max_energy_diff=1000.,
             unrolled_leapfrog_steps=1,
             seed=None,
             name=None):
  """Initializes this transition kernel.

  Args:
    target_log_prob_fn: Python callable which takes an argument like
      `current_state` (or `*current_state` if it's a list) and returns its
      (possibly unnormalized) log-density under the target distribution.
      Currently only supports a `target_log_prob_fn` that takes a single
      argument (i.e. the state or free parameters of your model), with the
      input being a 2-D tensor of shape `batch_size x state_part_size`.
    step_size: `Tensor` or Python `list` of `Tensor`s representing the step
      size for the leapfrog integrator. Must broadcast with the shape of
      `current_state`. Larger step sizes lead to faster progress, but
      too-large step sizes make rejection exponentially more likely. When
      possible, it's often helpful to match per-variable step sizes to the
      standard deviations of the target distribution in each variable.
    max_tree_depth: Maximum depth of the tree implicitly built by NUTS. The
      maximum number of leapfrog steps is bounded by `2**max_tree_depth`,
      i.e. the number of nodes in a binary tree `max_tree_depth` nodes deep.
      The default setting of 6 takes up to 64 leapfrog steps.
    max_energy_diff: Scalar threshold on the energy difference at each
      leapfrog step; divergent samples are defined as leapfrog steps that
      exceed this threshold. Defaults to 1000.
    unrolled_leapfrog_steps: The number of leapfrog steps to unroll per tree
      expansion step. Applies a direct linear multiplier to the maximum
      trajectory length implied by `max_tree_depth`. Defaults to 1.
    seed: Python integer to seed the random number generator.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'nuts_kernel').
  """
  with tf.name_scope(name or 'NoUTurnSamplerUnrolled') as name:
    # Process `max_tree_depth` argument.
    max_tree_depth = tf.get_static_value(max_tree_depth)
    if max_tree_depth is None or max_tree_depth < 1:
      raise ValueError(
          'max_tree_depth must be known statically and >= 1 but was '
          '{}'.format(max_tree_depth))
    self._max_tree_depth = max_tree_depth

    # Compute parameters derived from `max_tree_depth`.
    instruction_array = build_tree_uturn_instruction(
        max_tree_depth, init_memory=-1)
    [write_instruction, read_instruction] = (
        generate_efficient_write_read_instruction(instruction_array))
    if USE_RAGGED_TENSOR:
      self._write_instruction = tf.constant(write_instruction)
      self._read_instruction = tf.ragged.constant(read_instruction)
    else:
      f = lambda int_iter: write_instruction[int_iter]
      self._write_instruction = {
          x: functools.partial(f, x) for x in range(len(write_instruction))
      }
      self._read_instruction = read_instruction

    # Process all other arguments.
    self._target_log_prob_fn = target_log_prob_fn
    if not tf.nest.is_nested(step_size):
      step_size = [step_size]
    step_size = [
        tf.convert_to_tensor(s, dtype_hint=tf.float32) for s in step_size
    ]
    self._step_size = step_size

    self._parameters = dict(
        target_log_prob_fn=target_log_prob_fn,
        step_size=step_size,
        max_tree_depth=max_tree_depth,
        max_energy_diff=max_energy_diff,
        unrolled_leapfrog_steps=unrolled_leapfrog_steps,
        seed=seed,
        name=name,
    )
    self._seed_stream = SeedStream(seed, salt='nuts_one_step')
    self._unrolled_leapfrog_steps = unrolled_leapfrog_steps
    self._name = name
    self._max_energy_diff = max_energy_diff
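
# A hypothetical usage sketch; the class name `NoUTurnSamplerUnrolled` is
# inferred from the name scope above, and the target setup is illustrative.
# Per the docstring, `target_log_prob_fn` takes a single 2-D state of shape
# [batch_size, state_part_size].
import tensorflow_probability as tfp
tfd = tfp.distributions

target = tfd.MultivariateNormalDiag(loc=tf.zeros(3))
kernel = NoUTurnSamplerUnrolled(
    target_log_prob_fn=target.log_prob,  # takes one [batch, 3] tensor
    step_size=0.3,
    max_tree_depth=6,
    seed=17)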
def _sample_n(self, n, seed=None):
  seed = SeedStream(seed, salt='von_mises_fisher')
  # The sampling strategy relies on the fact that vMF variates are symmetric
  # about the mean direction. Accordingly, if we have a sampling strategy for
  # the away-from-mean angle, then we can uniformly sample the remaining
  # dimensions on the S^{dim-2} sphere for free, and rotate these samples
  # from a (1, 0, 0, ..., 0)-mode distribution into the target orientation.
  #
  # This is easy to imagine on the 1-sphere (S^1; in 2-D space): sample a
  # von-Mises distributed `x` value in [-1, 1], then uniformly select what
  # amounts to a "up" or "down" additional degree of freedom after unit
  # normalizing, followed by a final rotation to the desired mean direction
  # from a basis of (1, 0).
  #
  # On S^2 (in 3-D), selecting a vMF `x` identifies a circle in `yz` on the
  # unit sphere over which the distribution is uniform, in particular the
  # circle where x = \hat{x} intersects the unit sphere. We pick a point on
  # that circle, then rotate to the desired mean direction from a basis of
  # (1, 0, 0).
  event_dim = self.event_shape[0].value or self._event_shape_tensor()[0]

  sample_batch_shape = tf.concat([[n], self._batch_shape_tensor()], axis=0)
  dim = tf.cast(event_dim - 1, self.dtype)
  if event_dim == 3:
    samples_dim0 = self._sample_3d(n, seed=seed)
  else:
    # Wood'94 provides a rejection algorithm to sample the x coordinate.
    # Wood'94 definition of b:
    #   b = (-2 * kappa + tf.sqrt(4 * kappa**2 + dim**2)) / dim
    # https://stats.stackexchange.com/questions/156729 suggests:
    b = dim / (2 * self.concentration +
               tf.sqrt(4 * self.concentration**2 + dim**2))
    # TODO(bjp): Integrate any useful numerical tricks from hyperspherical
    # VAE: https://github.com/nicola-decao/s-vae-tf/
    x = (1 - b) / (1 + b)
    c = self.concentration * x + dim * tf.log1p(-x**2)
    beta = tf.distributions.Beta(dim / 2, dim / 2)

    def cond_fn(w, should_continue):
      del w
      return tf.reduce_any(should_continue)

    def body_fn(w, should_continue):
      z = beta.sample(sample_shape=sample_batch_shape, seed=seed())
      w = tf.where(should_continue,
                   (1 - (1 + b) * z) / (1 - (1 - b) * z),
                   w)
      w = tf.check_numerics(w, 'w')
      should_continue = tf.logical_and(
          should_continue,
          self.concentration * w + dim * tf.log1p(-x * w) - c <
          tf.log(tf.random_uniform(sample_batch_shape, seed=seed(),
                                   dtype=self.dtype)))
      return w, should_continue

    w = tf.zeros(sample_batch_shape, dtype=self.dtype)
    should_continue = tf.ones(sample_batch_shape, dtype=tf.bool)
    samples_dim0 = tf.while_loop(cond_fn, body_fn, (w, should_continue))[0]
    samples_dim0 = samples_dim0[..., tf.newaxis]
  if not self._allow_nan_stats:
    # Verify samples are within [-1, 1], with useful error output tensors
    # (top value rather than all values).
    with tf.control_dependencies([
        tf.assert_less_equal(
            samples_dim0, self.dtype.as_numpy_dtype(1.01),
            data=[tf.nn.top_k(tf.reshape(samples_dim0, [-1]))[0]]),
        tf.assert_greater_equal(
            samples_dim0, self.dtype.as_numpy_dtype(-1.01),
            data=[-tf.nn.top_k(tf.reshape(-samples_dim0, [-1]))[0]])]):
      samples_dim0 = tf.identity(samples_dim0)
  samples_otherdims_shape = tf.concat(
      [sample_batch_shape, [event_dim - 1]], axis=0)
  unit_otherdims = tf.nn.l2_normalize(
      tf.random_normal(samples_otherdims_shape, seed=seed(),
                       dtype=self.dtype),
      axis=-1)
  samples = tf.concat([
      samples_dim0,  # we must avoid sqrt(1 - (>1)**2)
      tf.sqrt(tf.maximum(1 - samples_dim0**2, 0.)) * unit_otherdims
  ], axis=-1)
  samples = tf.nn.l2_normalize(samples, axis=-1)
  if not self._allow_nan_stats:
    samples = tf.check_numerics(samples, 'samples')

  # Runtime assert that samples are unit length.
  if not self._allow_nan_stats:
    worst, idx = tf.nn.top_k(
        tf.reshape(tf.abs(1 - tf.linalg.norm(samples, axis=-1)), [-1]))
    with tf.control_dependencies([
        tf.assert_near(
            self.dtype.as_numpy_dtype(0), worst,
            data=[worst, idx,
                  tf.gather(tf.reshape(samples, [-1, event_dim]), idx)],
            atol=1e-4, summarize=100)]):
      samples = tf.identity(samples)
  # The samples generated are symmetric around a mode at (1, 0, 0, ..., 0).
  # Now, we move the mode to `self.mean_direction` using a rotation matrix.
  if not self._allow_nan_stats:
    # Assert that the basis vector rotates to the mean direction, as
    # expected.
    basis = tf.cast(tf.concat([[1.], tf.zeros([event_dim - 1])], axis=0),
                    self.dtype)
    with tf.control_dependencies([
        tf.assert_less(
            tf.linalg.norm(self._rotate(basis) - self.mean_direction,
                           axis=-1),
            self.dtype.as_numpy_dtype(1e-5))
    ]):
      return self._rotate(samples)
  return self._rotate(samples)
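
# A NumPy sketch of the assembly step above: given samples of the first
# coordinate `x` in [-1, 1], the remaining coordinates are an independently
# drawn uniform direction on S^{dim-2}, scaled by sqrt(1 - x^2), so the
# result is unit length with its mode at (1, 0, ..., 0) before rotation.
import numpy as np

rng = np.random.default_rng(0)
event_dim = 5
x = rng.uniform(-1., 1., size=(1000, 1))          # stand-in for samples_dim0
other = rng.standard_normal((1000, event_dim - 1))
other /= np.linalg.norm(other, axis=-1, keepdims=True)
samples = np.concatenate(
    [x, np.sqrt(np.maximum(1 - x**2, 0.)) * other], axis=-1)
assert np.allclose(np.linalg.norm(samples, axis=-1), 1.)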