def _get_permutations(num_results, dims, seed=None):
    """Uniform iid sample from the space of permutations.

    Draws `num_results` samples from each permutation group whose degree is
    listed in `dims`, and lays the groups side by side so that each row holds
    one permutation per entry of `dims`.

    For example, with `dims = [2, 3]` and `num_results = 2` the result has
    shape `[2, 2 + 3]`, and one row might read `[1, 0, 2, 0, 1]`: the first
    two entries permute 2 elements, the last three permute 3 elements.

    Args:
        num_results: A positive scalar `Tensor` of integral type. The number of
            draws from the discrete uniform distribution over the permutation
            groups.
        dims: A 1D `Tensor` of the same dtype as `num_results`. The degrees of
            the permutation groups from which to sample.
        seed: (Optional) Python integer to seed the random number generator.

    Returns:
        permutations: A `Tensor` of shape `[num_results, sum(dims)]` and the
            same dtype as `dims`.
    """
    sample_range = tf.range(num_results)
    stream = SeedStream(seed, salt='MCMCSampleHaltonSequence3')

    def _sample_degree(degree):
        # Each degree gets its own op-level seed, drawn in `dims` order.
        op_seed = stream()

        def _shuffle_once(_):
            return tf.random.shuffle(tf.range(degree), seed=op_seed)

        # Seeded runs are kept sequential; unseeded runs allow parallelism.
        iterations = 10 if op_seed is None else 1
        return tf.map_fn(_shuffle_once, sample_range,
                         parallel_iterations=iterations)

    per_degree = []
    for degree in tf.unstack(dims):
        per_degree.append(_sample_degree(degree))
    return tf.concat(per_degree, axis=-1)
def _start_trajectory_batched(self, state, target_log_prob):
    """Computations needed to start a trajectory.

    Draws a standard-normal momentum for every state part and evaluates the
    initial energy. Unless `MULTINOMIAL_SAMPLE` is truthy, it also draws the
    log of a uniform slice variable.

    Args:
        state: Iterable of state-part `Tensor`s.
        target_log_prob: Target log-probability, forwarded to
            `compute_hamiltonian`.

    Returns:
        A triple `(momentum, init_energy, log_slice_sample)`:
        - `momentum` is a list with one normal draw per state part, matching
          that part's shape and dtype.
        - `init_energy` is `compute_hamiltonian(target_log_prob, momentum)`.
        - `log_slice_sample` is `log1p(-u)` with `u ~ Uniform(0, 1)` shaped
          like `init_energy`, or `None` when `MULTINOMIAL_SAMPLE` is truthy.
    """
    with tf.name_scope('start_trajectory_batched'):
        seed_stream = SeedStream(self._seed_stream,
                                 salt='start_trajectory_batched')
        momentum = []
        for part in state:
            momentum.append(
                tf.random.normal(
                    shape=prefer_static.shape(part),
                    dtype=part.dtype,
                    seed=seed_stream()))
        init_energy = compute_hamiltonian(target_log_prob, momentum)

        if MULTINOMIAL_SAMPLE:
            return momentum, init_energy, None

        # Slice variable u ~ Uniform(0, p(initial state, initial momentum)),
        # handled in log space: log u = log u' + log p(...), u' ~ Uniform(0, 1).
        # log1p(-u') has the same distribution as log u'.
        unit_uniform = tf.random.uniform(
            shape=prefer_static.shape(init_energy),
            dtype=init_energy.dtype,
            seed=seed_stream())
        log_slice_sample = tf.math.log1p(-unit_uniform)
        return momentum, init_energy, log_slice_sample
def default_exchange_proposed_fn_(num_replica, seed=None):
    """Default function for `exchange_proposed_fn` of `kernel`.

    Proposes disjoint pairs of adjacent replicas to exchange. A fair coin
    decides where the pairing starts. With probability `prob_exchange` the
    pairs are returned; otherwise an empty proposal is returned.

    Args:
        num_replica: Python `int` number of replicas.
        seed: Optional seed for the internal `SeedStream`.

    Returns:
        `int32` `Tensor` of shape `[num_pairs, 2]`, each row holding the
        indices of two adjacent replicas, or shape `[0, 2]` when no exchange
        is proposed.
    """
    # NOTE(review): `prob_exchange` is a free variable bound by an enclosing
    # scope that is not visible here.
    seed_stream = SeedStream(seed, 'default_exchange_proposed_fn')
    zero_start = tf.random.uniform([], seed=seed_stream()) > 0.5

    def _pair_up(flat):
        # Group consecutive indices into rows of two.
        return tf.reshape(flat, [tf.size(input=flat) // 2, 2])

    if num_replica % 2 == 0:
        def _exchange():
            flat = tf.range(num_replica)
            if num_replica > 2:
                # Either use every replica, or drop both ends.
                offset = tf.cast(~zero_start, dtype=tf.int32)
                flat = flat[offset:num_replica - offset]
            return _pair_up(flat)
    else:
        def _exchange():
            # An odd count always leaves exactly one replica out: the first
            # or the last.
            first = tf.cast(zero_start, dtype=tf.int32)
            last = num_replica - tf.cast(~zero_start, dtype=tf.int32)
            return _pair_up(tf.range(num_replica)[first:last])

    def _no_exchange():
        return tf.reshape(tf.cast([], dtype=tf.int32), shape=[0, 2])

    return tf.cond(
        pred=tf.random.uniform([], seed=seed_stream()) < prob_exchange,
        true_fn=_exchange,
        false_fn=_no_exchange)
def __init__(self,
             target_log_prob_fn,
             step_size,
             num_leapfrog_steps,
             state_gradients_are_stopped=False,
             step_size_update_fn=None,
             seed=None,
             store_parameters_in_results=False,
             name=None):
    """Initializes this transition kernel.

    Wraps an `UncalibratedHamiltonianMonteCarlo` kernel in a
    `metropolis_hastings.MetropolisHastings` correction.

    Args:
        target_log_prob_fn: Python callable which takes an argument like
            `current_state` (or `*current_state` if it's a list) and returns
            its (possibly unnormalized) log-density under the target
            distribution.
        step_size: `Tensor` or Python `list` of `Tensor`s giving the leapfrog
            step size. Must broadcast with the shape of `current_state`.
            Larger steps make faster progress, but steps that are too large
            make rejection exponentially more likely. Matching per-variable
            step sizes to the target's per-variable standard deviations often
            helps.
        num_leapfrog_steps: Integer number of leapfrog steps. Progress per HMC
            step is roughly proportional to `step_size * num_leapfrog_steps`.
        state_gradients_are_stopped: Python `bool`; if `True`, the proposed
            new state is run through `tf.stop_gradient`. Useful when
            combining optimization over samples from the chain.
            Default value: `False`.
        step_size_update_fn: Python `callable` taking the current `step_size`
            (typically a `tf.Variable`) and `kernel_results` (typically a
            `collections.namedtuple`) and returning the updated step size.
            Default value: `None` (no automatic step-size update).
        seed: Python integer to seed the random number generator.
        store_parameters_in_results: If `True`, `step_size` and
            `num_leapfrog_steps` are written to and read from same-named
            fields of the kernel results, so wrapper kernels can adjust them
            on the fly. Cannot be combined with `step_size_update_fn`.
        name: Python `str` name prefixed to Ops created by this function.
            Default value: `None` (i.e., 'hmc_kernel').

    Raises:
        ValueError: if `step_size_update_fn` and
            `store_parameters_in_results` are both truthy.
    """
    if step_size_update_fn and store_parameters_in_results:
        raise ValueError('It is invalid to simultaneously specify '
                         '`step_size_update_fn` and set '
                         '`store_parameters_in_results` to `True`.')
    self._seed_stream = SeedStream(seed, salt='hmc')
    # The inner kernel's seed is drawn first, the Metropolis-Hastings seed
    # second.
    uncalibrated_kernel = UncalibratedHamiltonianMonteCarlo(
        target_log_prob_fn=target_log_prob_fn,
        step_size=step_size,
        num_leapfrog_steps=num_leapfrog_steps,
        state_gradients_are_stopped=state_gradients_are_stopped,
        seed=self._seed_stream(),
        name=name or 'hmc_kernel',
        store_parameters_in_results=store_parameters_in_results)
    self._impl = metropolis_hastings.MetropolisHastings(
        inner_kernel=uncalibrated_kernel,
        seed=self._seed_stream())
    # Report the inner kernel's parameters plus the two this wrapper owns.
    self._parameters = self._impl.inner_kernel.parameters.copy()
    self._parameters['step_size_update_fn'] = step_size_update_fn
    self._parameters['seed'] = seed
def __init__(self,
             target_log_prob_fn,
             step_size,
             num_leapfrog_steps,
             state_gradients_are_stopped=False,
             seed=None,
             store_parameters_in_results=False,
             name=None):
    """Initializes this transition kernel.

    Args:
        target_log_prob_fn: Python callable which takes an argument like
            `current_state` (or `*current_state` if it's a list) and returns
            its (possibly unnormalized) log-density under the target
            distribution.
        step_size: `Tensor` or Python `list` of `Tensor`s giving the leapfrog
            step size. Must broadcast with the shape of `current_state`.
            Larger steps make faster progress, but steps that are too large
            make rejection exponentially more likely. Matching per-variable
            step sizes to the target's per-variable standard deviations often
            helps.
        num_leapfrog_steps: Integer number of leapfrog steps. Progress per HMC
            step is roughly proportional to `step_size * num_leapfrog_steps`.
        state_gradients_are_stopped: Python `bool`; if `True`, the proposed
            new state is run through `tf.stop_gradient`. Useful when
            combining optimization over samples from the chain.
            Default value: `False`.
        seed: Python integer to seed the random number generator.
            Deprecated, pass seed to `tfp.mcmc.sample_chain`.
        store_parameters_in_results: If `True`, `step_size` and
            `num_leapfrog_steps` are written to and read from same-named
            fields of the kernel results, so wrapper kernels can adjust them
            on the fly.
        name: Python `str` name prefixed to Ops created by this function.
            Default value: `None` (i.e., 'hmc_kernel').

    Raises:
        NotImplementedError: if `seed` is given while executing eagerly.
    """
    if seed is not None and tf.executing_eagerly():
        # TODO(b/68017812): Re-enable once TFE supports `tf.random.shuffle` seed.
        raise NotImplementedError(
            'Specifying a `seed` when running eagerly is '
            'not currently supported. To run in Eager '
            'mode with a seed, pass the seed to '
            '`tfp.mcmc.sample_chain`.')
    if not store_parameters_in_results:
        # Parameters are not carried in the results, so let the helper warn
        # about values it considers unsuitable.
        mcmc_util.warn_if_parameters_are_not_simple_tensors(
            dict(step_size=step_size, num_leapfrog_steps=num_leapfrog_steps))
    self._seed_stream = SeedStream(seed, salt='uncalibrated_hmc_one_step')
    self._parameters = {
        'target_log_prob_fn': target_log_prob_fn,
        'step_size': step_size,
        'num_leapfrog_steps': num_leapfrog_steps,
        'state_gradients_are_stopped': state_gradients_are_stopped,
        'seed': seed,
        'name': name,
        'store_parameters_in_results': store_parameters_in_results,
    }
    self._momentum_dtype = None
def _inner(seed):
    """Draws standard normals `x` and computes `v = 1 + c * x`.

    Args:
        seed: Seed for a `SeedStream` salted '_inner'.

    Returns:
        A pair `((x, v), v > 0.)`: the draws, the transformed values, and the
        elementwise positivity mask of `v`.
    """
    # NOTE(review): `sample_shape`, `internal_dtype` and `c` are free
    # variables bound by an enclosing scope that is not visible here.
    stream = SeedStream(seed, '_inner')
    normals = tf.random.normal(sample_shape, dtype=internal_dtype,
                               seed=stream())
    # This implicitly broadcasts alpha (through `c`) up to sample shape.
    transformed = 1 + c * normals
    is_positive = transformed > 0.
    return (normals, transformed), is_positive
def randomized_computation(seed):
    """Draws one batch of proposals plus a rejection-sampling acceptance mask.

    A proposal `s` with proposal value `q` is marked good when
    `q * u <= target(s)`, where `u ~ Uniform(0, 1)` is drawn independently for
    every element of `s`.

    Args:
        seed: Seed for a `SeedStream` salted 'batched_rejection_sampler'. The
            proposal consumes the first seed and the uniform draw the second.

    Returns:
        A pair `(proposed_samples, good_samples_mask)`.
    """
    # NOTE(review): `proposal` and `target` are free variables bound by an
    # enclosing scope that is not visible here.
    seed_stream = SeedStream(seed, 'batched_rejection_sampler')
    proposed_samples, proposed_values = proposal(seed_stream())
    # Tensor conversion lets the uniform draw follow the values' dtype
    # (e.g. float64) instead of always defaulting to float32. Python scalars
    # still convert to float32, as before.
    proposed_values = tf.convert_to_tensor(proposed_values)
    # Prefer the static shape, but a partially-known `TensorShape` is not a
    # valid `shape` argument, so fall back to the dynamic shape in that case.
    uniform_shape = proposed_samples.shape
    if not uniform_shape.is_fully_defined():
        uniform_shape = tf.shape(proposed_samples)
    scaled_values = proposed_values * tf.random.uniform(
        uniform_shape,
        maxval=1.,
        dtype=proposed_values.dtype,
        seed=seed_stream())
    good_samples_mask = tf.less_equal(scaled_values,
                                      target(proposed_samples))
    return proposed_samples, good_samples_mask
def _sample_n(self, n, seed):
    """Draws `n` Wishart samples.

    Builds a random lower-triangular matrix with standard normals below the
    diagonal and square roots of Gamma draws on the diagonal. It then
    left-multiplies that matrix by `self._scale` and, unless
    `self.input_output_cholesky` is set, returns `x @ x^H`. This appears to
    follow the Bartlett decomposition. The Gamma shape parameters come from
    `_multi_gamma_sequence`, which is defined elsewhere.

    Args:
        n: Number of samples (leading dimension of the result).
        seed: Seed for a `SeedStream` salted 'Wishart'. The normal draw
            consumes the first seed and the Gamma draw the second.

    Returns:
        `Tensor` of shape `[n] + batch_shape + event_shape`: the Cholesky
        factor if `self.input_output_cholesky`, else the full matrix.
    """
    df = tf.convert_to_tensor(self.df)
    batch_shape = self._batch_shape_tensor(df)
    event_shape = self._event_shape_tensor()
    batch_ndims = tf.shape(batch_shape)[0]

    ndims = batch_ndims + 3  # sample_ndims=1, event_ndims=2
    shape = tf.concat([[n], batch_shape, event_shape], 0)
    stream = SeedStream(seed, salt='Wishart')

    # Complexity: O(nbk**2)
    x = tf.random.normal(
        shape=shape, mean=0., stddev=1., dtype=self.dtype, seed=stream())

    # Complexity: O(nbk)
    # This parameterization is equivalent to Chi2, i.e.,
    # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2)
    expanded_df = df * tf.ones(
        self._scale.batch_shape_tensor(),
        dtype=dtype_util.base_dtype(df.dtype))

    g = tf.random.gamma(
        shape=[n],
        alpha=self._multi_gamma_sequence(0.5 * expanded_df, self._dimension()),
        beta=0.5,
        dtype=self.dtype,
        seed=stream())

    # Complexity: O(nbk**2)
    x = tf.linalg.band_part(x, -1, 0)  # Tri-lower.

    # Complexity: O(nbk)
    x = tf.linalg.set_diag(x, tf.sqrt(g))

    # Make batch-op ready: move the sample axis last and fold it into the
    # columns so a single `matmul` handles all `n` samples.
    # Complexity: O(nbk**2)
    perm = tf.concat([tf.range(1, ndims), [0]], 0)
    x = tf.transpose(a=x, perm=perm)
    shape = tf.concat([batch_shape, [event_shape[0]], [event_shape[1] * n]], 0)
    x = tf.reshape(x, shape)

    # Complexity: O(nbM) where M is the complexity of the operator solving a
    # vector system. For LinearOperatorLowerTriangular, each matmul is O(k^3) so
    # this step has complexity O(nbk^3).
    x = self._scale.matmul(x)

    # Undo make batch-op ready: unfold the columns and move the sample axis
    # back to the front.
    # Complexity: O(nbk**2)
    shape = tf.concat([batch_shape, event_shape, [n]], 0)
    x = tf.reshape(x, shape)
    perm = tf.concat([[ndims - 1], tf.range(0, ndims - 1)], 0)
    x = tf.transpose(a=x, perm=perm)

    if not self.input_output_cholesky:
        # Complexity: O(nbk**3)
        x = tf.matmul(x, x, adjoint_b=True)

    return x
def _sample_n(self, n, seed):
    """Draws `n` zero-inflated samples.

    A mask is drawn from `inflated_distribution` and counts from
    `count_distribution`. Where the mask is 1 the result is 0; where it is 0
    the count is kept. Everything runs under `self._runtime_assertions`.
    """
    with tf.compat.v1.control_dependencies(self._runtime_assertions):
        stream = SeedStream(seed, salt="ZeroInflated")
        inflate_mask = self.inflated_distribution.sample(n, stream())
        counts = self.count_distribution.sample(n, stream())
        inflate_mask, counts = _broadcast_rate(inflate_mask, counts)
        keep = tf.cast(1 - inflate_mask, counts.dtype)
        return counts * keep
def randomized_computation(seed):
    """Draws one batch of proposals plus a rejection-sampling acceptance mask.

    A proposal `s` with proposal value `q` is marked good when
    `q * u <= target_fn(s)` for an independent `u ~ Uniform(0, 1)` per element.

    Args:
        seed: Seed for a `SeedStream` salted 'batched_rejection_sampler'.

    Returns:
        A pair `(proposed_samples, good_samples_mask)`.
    """
    # NOTE(review): `proposal_fn`, `target_fn` and `dtype` are free variables
    # bound by an enclosing scope that is not visible here.
    stream = SeedStream(seed, 'batched_rejection_sampler')
    samples, values = proposal_fn(stream())
    scaled_values = values * tf.random.uniform(
        prefer_static.shape(samples),
        seed=stream(),
        dtype=dtype)
    accept = tf.less_equal(scaled_values, target_fn(samples))
    return samples, accept
def _sample_n(self, n, seed=None):
    """Draws `n` samples from the mixture.

    Every component is sampled `n` times and the results are stacked on a
    new axis just left of the event dimensions. A one-hot mask built from
    `self.cat` samples then selects one component per draw.

    Seeding: `samplers.split_seed` yields one seed for `self.cat` and one per
    component. A legacy `SeedStream` is advanced once per component to stay
    in lockstep. It is only used when a component rejects its split seed with
    a seed-type `TypeError`; in that case a fallback warning is emitted.

    Args:
        n: Number of samples.
        seed: Seed accepted by `samplers.split_seed`. A Tensor seed can make
            `SeedStream` construction fail; that error is kept and re-raised
            only if the stateful fallback is needed.

    Returns:
        `Tensor` shaped `[n, B, E]` per the inline shape comments.

    Raises:
        TypeError: if a component raises a `TypeError` unrelated to seed
            types, or if it needs the stateful fallback but no `SeedStream`
            could be built from `seed`.
    """
    seeds = samplers.split_seed(seed, n=self.num_components + 1, salt='Mixture')
    try:
        seed_stream = SeedStream(seed, salt='Mixture')
    except TypeError as e:  # Can happen for Tensor seed.
        seed_stream = None
        seed_stream_err = e

    # This sampling approach is almost the same as the approach used by
    # `MixtureSameFamily`. The differences are due to having a list of
    # `Distribution` objects rather than a single object.
    samples = []
    cat_samples = self.cat.sample(n, seed=seeds[0])

    for c in range(self.num_components):
        try:
            samples.append(self.components[c].sample(n, seed=seeds[c + 1]))
            # Advance the legacy stream even on success so that any later
            # fallback draws the same seed it would have drawn for this index.
            if seed_stream is not None:
                seed_stream()
        except TypeError as e:
            # Only seed-type errors trigger the stateful fallback.
            if ('Expected int for argument' not in str(e) and
                    TENSOR_SEED_MSG_PREFIX not in str(e)):
                raise
            if seed_stream is None:
                raise seed_stream_err
            msg = (
                'Falling back to stateful sampling for `components[{}]` {} of '
                'type `{}`. Please update to use `tf.random.stateless_*` RNGs. '
                'This fallback may be removed after 20-Aug-2020. ({})')
            warnings.warn(
                msg.format(c, self.components[c].name, type(self.components[c]),
                           str(e)))
            samples.append(self.components[c].sample(n, seed=seed_stream()))
    stack_axis = -1 - tensorshape_util.rank(self._static_event_shape)
    x = tf.stack(samples, axis=stack_axis)  # [n, B, k, E]
    # TODO(b/170730865): Is all this masking stuff really called for?
    npdt = dtype_util.as_numpy_dtype(x.dtype)
    mask = tf.one_hot(
        indices=cat_samples,  # [n, B]
        depth=self._num_components,  # == k
        on_value=npdt(1),
        off_value=npdt(0))  # [n, B, k]
    mask = distribution_util.pad_mixture_dimensions(
        mask, self, self._cat,
        tensorshape_util.rank(self._static_event_shape))  # [n, B, k, [1]*e]
    # For floating dtypes, `multiply_no_nan` yields 0 where the mask is 0,
    # even when the unselected component sample is NaN or inf.
    if x.dtype.is_floating:
        masked = tf.math.multiply_no_nan(x, mask)
    else:
        masked = x * mask
    return tf.reduce_sum(masked, axis=stack_axis)  # [n, B, E]
def _flat_sample_distributions(self, sample_shape=(), seed=None, value=None):
    """Executes `model`, creating both samples and distributions.

    Drives the model coroutine. Each yielded distribution is recorded, with
    any `Root` wrapper removed. Its value is either taken from `value` or
    sampled, and that value is sent back into the coroutine. Only `Root`
    distributions are sampled with `sample_shape`; all others use shape `()`.

    Args:
        sample_shape: Sample shape used for `Root` distributions.
        seed: Seed for a `SeedStream` salted 'JointDistributionCoroutine'.
        value: Optional sequence of pre-specified values. Entries that are
            missing or `None` are sampled instead.

    Returns:
        A pair `(ds, values_out)`: the distributions in yield order and the
        values sent back for them.

    Raises:
        ValueError: if `self._require_root` is set and the first yielded
            object is not wrapped in `Root`.
    """

    def convert_tree_to_tensor(x, dtype_hint):
        return tf.convert_to_tensor(x, dtype_hint=dtype_hint)

    distributions = []
    sent_values = []
    stream = SeedStream(seed, salt='JointDistributionCoroutine')
    coroutine = self._model_coroutine()
    index = 0
    current = next(coroutine)
    if self._require_root and not isinstance(current, self.Root):
        raise ValueError('First distribution yielded by coroutine must '
                         'be wrapped in `Root`.')
    try:
        while True:
            is_root = isinstance(current, self.Root)
            dist = current.distribution if is_root else current
            distributions.append(dist)
            provided = (value is not None and len(value) > index
                        and value[index] is not None)
            if provided:
                # Advance the stream anyway so later seeds are reproducible
                # whether or not earlier values were supplied.
                stream()
                # `map_structure_up_to` does not accept kwarg names, so the
                # dtype hint is passed positionally as a second structure.
                next_value = nest.map_structure_up_to(
                    dist.dtype,  # shallow_tree
                    convert_tree_to_tensor,  # func
                    value[index],  # x
                    dist.dtype)  # dtype_hint
            else:
                next_value = dist.sample(
                    sample_shape=sample_shape if is_root else (),
                    seed=stream())

            if not self._validate_args:
                sent_values.append(next_value)
            else:
                shape_checks = self._assert_compatible_shape(
                    index, sample_shape, next_value)
                with tf.control_dependencies(shape_checks):
                    sent_values.append(
                        tf.nest.map_structure(tf.identity, next_value))
            index += 1
            current = coroutine.send(next_value)
    except StopIteration:
        pass
    return distributions, sent_values
def _sample_n(self, n, seed=None):
    """Draws `n` Horseshoe samples.

    Each sample is `normal * scale * local_shrinkage`, where `local_shrinkage`
    comes from `self._half_cauchy`. The half-Cauchy draw consumes the first
    seed of a `SeedStream` salted 'random_horseshoe' and the normal draw the
    second.
    """
    scale = tf.convert_to_tensor(self.scale)
    full_shape = tf.concat([[n], tf.shape(scale)], axis=0)
    stream = SeedStream(seed, salt='random_horseshoe')
    local_shrinkage = self._half_cauchy.sample(full_shape, seed=stream())
    total_shrinkage = scale * local_shrinkage
    standard_normal = tf.random.normal(
        shape=full_shape,
        mean=0.,
        stddev=1.,
        dtype=scale.dtype,
        seed=stream())
    return standard_normal * total_shrinkage
def _sample_n(self, n, seed):
    """Draws `n` samples by masking component samples with a one-hot choice.

    All `k` components are sampled, and a categorical index is drawn from
    `mixture_distribution`. The one-hot mask of that index is padded with
    singleton dims so it broadcasts against the component samples, then the
    component axis is summed out. If `self._reparameterize` is set, the
    result is passed through `self._reparameterize_sample`.

    Args:
        n: Number of samples.
        seed: Seed for a `SeedStream` salted 'MixtureSameFamily'. The
            components consume the first seed and the mixture draw the second.

    Returns:
        `Tensor` shaped `[n, B, E]` per the inline shape comments.
    """
    seed = SeedStream(seed, salt='MixtureSameFamily')
    x = self.components_distribution.sample(n, seed=seed())  # [n, B, k, E]

    event_shape = None
    event_ndims = tensorshape_util.rank(self.event_shape)
    if event_ndims is None:
        event_shape = self.components_distribution.event_shape_tensor()
        event_ndims = prefer_static.rank_from_shape(event_shape)
    event_ndims_static = tf.get_static_value(event_ndims)

    num_components = None
    if event_ndims_static is not None:
        num_components = tf.compat.dimension_value(
            x.shape[-1 - event_ndims_static])
    # We could also check if num_components can be computed statically from
    # self.mixture_distribution's logits or probs.
    if num_components is None:
        num_components = tf.shape(x)[-1 - event_ndims]

    # TODO(jvdillon): Consider using tf.gather (by way of index unrolling).
    npdt = dtype_util.as_numpy_dtype(x.dtype)
    mask = tf.one_hot(
        indices=self.mixture_distribution.sample(
            n, seed=seed()),  # [n, B] or [n]
        depth=num_components,
        on_value=npdt(1),
        off_value=npdt(0))  # [n, B, k] or [n, k]

    # Pad `mask` to [n, B, k, [1]*e] or [n, [1]*b, k, [1]*e] .
    batch_ndims = prefer_static.rank(x) - event_ndims - 1
    mask_batch_ndims = prefer_static.rank(mask) - 1
    pad_ndims = batch_ndims - mask_batch_ndims
    mask_shape = prefer_static.shape(mask)
    mask = tf.reshape(
        mask,
        shape=prefer_static.concat([
            mask_shape[:-1],
            prefer_static.ones([pad_ndims], dtype=tf.int32),
            mask_shape[-1:],
            prefer_static.ones([event_ndims], dtype=tf.int32),
        ], axis=0))

    ret = tf.reduce_sum(x * mask, axis=-1 - event_ndims)  # [n, B, E]
    if self._reparameterize:
        # `event_shape` is only computed above when the static rank was
        # unknown, so it may still need computing here.
        if event_shape is None:
            event_shape = self.components_distribution.event_shape_tensor()
        ret = self._reparameterize_sample(ret, event_shape=event_shape)

    return ret
def __init__(self,
             target_log_prob_fn,
             inverse_temperatures,
             make_kernel_fn,
             exchange_proposed_fn=default_exchange_proposed_fn(1.),
             seed=None,
             name=None):
    """Instantiates this object.

    Args:
        target_log_prob_fn: Python callable which takes an argument like
            `current_state` (or `*current_state` if it's a list) and returns
            its (possibly unnormalized) log-density under the target
            distribution.
        inverse_temperatures: `1D` `Tensor` of inverse temperatures, one per
            replica. Must have a statically known `shape`.
            `inverse_temperatures[0]` produces the states returned by
            samplers, and is typically == 1.
        make_kernel_fn: Python callable which takes `target_log_prob_fn` and
            `seed` args and returns a TransitionKernel instance.
        exchange_proposed_fn: Python callable which takes a number of replicas
            and returns combinations of replicas for exchange.
        seed: Python integer to seed the random number generator.
            Default value: `None` (i.e., no seed).
        name: Python `str` name prefixed to Ops created by this function.
            Default value: `None` (i.e., "remc_kernel").

    Raises:
        ValueError: `inverse_temperatures` doesn't have statically known 1D
            shape.
    """
    inverse_temperatures = tf.convert_to_tensor(
        value=inverse_temperatures, name='inverse_temperatures')

    # Static checks only; nothing here is embedded in the graph.
    static_shape = inverse_temperatures.shape
    static_shape.assert_is_fully_defined()
    static_shape.assert_has_rank(1)

    self._seed_stream = SeedStream(seed, salt=name)
    self._seeded_mcmc = seed is not None
    self._parameters = dict(
        target_log_prob_fn=target_log_prob_fn,
        inverse_temperatures=inverse_temperatures,
        num_replica=tf.compat.dimension_value(static_shape[0]),
        exchange_proposed_fn=exchange_proposed_fn,
        seed=seed,
        name=name)

    # One kernel per replica, each targeting its tempered log-density and
    # seeded from the shared stream in replica order.
    self.replica_kernels = []
    for replica in range(self.num_replica):
        tempered_log_prob_fn = _replica_log_prob_fn(
            inverse_temperatures[replica], target_log_prob_fn)
        self.replica_kernels.append(
            make_kernel_fn(
                target_log_prob_fn=tempered_log_prob_fn,
                seed=self._seed_stream()))
def _sample_n(self, n, seed=None):
    """Draws `n` Beta samples as `g1 / (g1 + g0)`.

    `g1 ~ Gamma(concentration1)` and `g0 ~ Gamma(concentration0)`, both
    broadcast to the batch shape. They consume the first and second seeds of
    a `SeedStream` salted 'beta'.
    """
    stream = SeedStream(seed, 'beta')
    concentration1 = tf.convert_to_tensor(self.concentration1)
    concentration0 = tf.convert_to_tensor(self.concentration0)
    batch_shape = self._batch_shape_tensor(concentration1, concentration0)
    alpha1 = tf.broadcast_to(concentration1, batch_shape)
    alpha0 = tf.broadcast_to(concentration0, batch_shape)
    draw1 = tf.random.gamma(
        shape=[n], alpha=alpha1, dtype=self.dtype, seed=stream())
    draw0 = tf.random.gamma(
        shape=[n], alpha=alpha0, dtype=self.dtype, seed=stream())
    return draw1 / (draw1 + draw0)
def _sample_n(self, n, seed):
    """Draws `n` zero-inflated samples.

    A mask is drawn from `inflated_distribution` and counts from
    `count_distribution`. Where the mask is 1 the result is 0; otherwise the
    count is kept. The two draws use consecutive seeds of a `SeedStream`
    salted "ZeroInflated".
    """
    stream = SeedStream(seed, salt="ZeroInflated")
    inflate_mask = self.inflated_distribution.sample(n, stream())
    counts = self.count_distribution.sample(n, stream())
    # NOTE(review): the assertion's return value is not attached as a control
    # dependency, so in TF1-style graph mode it may never run. Confirm the
    # execution mode this code targets.
    tf.assert_equal(
        tf.rank(counts) >= tf.rank(inflate_mask),
        True,
        message=f"Cannot broadcast zero inflated mask of shape {inflate_mask.shape} "
                f"to sample shape {counts.shape}")
    counts, inflate_mask = _make_broadcastable(counts, inflate_mask)
    keep = tf.cast(1 - inflate_mask, counts.dtype)
    return counts * keep
def _sample_n(self, n, seed=None):
    """Draws `n` Negative Binomial samples as a Gamma-Poisson mixture.

    If `lam ~ Gamma(concentration=total_count, rate=(1 - probs) / probs)`,
    then `X ~ Poisson(lam)` is Negative Binomially distributed. Assuming
    `probs = sigmoid(logits)`, the Gamma rate `(1 - probs) / probs` equals
    `exp(-logits)`.
    """
    logits = self._logits_parameter_no_checks()
    stream = SeedStream(seed, salt='NegativeBinomial')
    gamma_rate = tf.exp(-logits)
    poisson_rate = tf.random.gamma(
        shape=[n],
        alpha=self.total_count,
        beta=gamma_rate,
        dtype=self.dtype,
        seed=stream())
    return tf.random.poisson(
        lam=poisson_rate, shape=[], dtype=self.dtype, seed=stream())
def _sample_n(self, n, seed=None):
    """Draws `n` Beta-Binomial samples.

    First `probs ~ Beta(concentration1, concentration0)` (with
    `concentration1` broadcast to the batch shape), then
    `X ~ Binomial(total_count, probs)`. The two draws use consecutive seeds
    of a `SeedStream` salted 'beta_binomial'.
    """
    stream = SeedStream(seed, 'beta_binomial')
    (total_count,
     concentration1,
     concentration0) = self._params_list_as_tensors()
    batch_shape_tensor = self.batch_shape_tensor()
    mixing_dist = beta.Beta(
        tf.broadcast_to(concentration1, batch_shape_tensor),
        concentration0,
        validate_args=self.validate_args)
    probs = mixing_dist.sample(n, seed=stream())
    count_dist = binomial.Binomial(
        total_count, probs=probs, validate_args=self.validate_args)
    return count_dist.sample(seed=stream())
def resample(log_weights, current_state, particle_info, seed=None):
    """Resample particles based on importance weights.

    Draws, with replacement, as many particle indices as `log_weights` has
    elements, using `log_weights` as categorical logits. It then gathers every
    leaf of `current_state` and `particle_info` at those indices, keeping
    each leaf's shape.

    Args:
        log_weights: Unnormalized log importance weights. It is wrapped in a
            list to form categorical logits, so presumably 1D; confirm with
            callers.
        current_state: Nested structure of per-particle `Tensor`s.
        particle_info: Nested structure of per-particle `Tensor`s.
        seed: Seed for a `SeedStream` salted 'resample_particles'.

    Returns:
        A pair `(next_state, next_particle_info)` with the same structures
        and shapes as the inputs.
    """
    with tf.name_scope('resample_particles'):
        stream = SeedStream(seed, salt='resample_particles')
        num_draws = ps.reduce_prod(*ps.shape(log_weights))
        resampling_indexes = tf.random.categorical(
            [log_weights], num_draws, seed=stream())

        def _gather_particles(x):
            return tf.reshape(tf.gather(x, resampling_indexes), ps.shape(x))

        next_state = tf.nest.map_structure(_gather_particles, current_state)
        next_particle_info = tf.nest.map_structure(_gather_particles,
                                                   particle_info)
        return next_state, next_particle_info
def _sample_n(self, n, seed):
    """Draws one component sample per mixture draw (MixtureSameFamilySampleFix).

    The mixture index is sampled first, using the stateless split seed and
    falling back to a stateful `SeedStream` seed on seed-type `TypeError`s.
    Then, for each drawn index, the indexed slice of `components_distribution`
    is sampled once, and the results are stacked along axis 0.

    Eager-only: the indices and the component seed are pulled out with
    `.numpy()`. NOTE(review): indexing `components_distribution` by each
    element of `mix_sample` assumes `mix_sample` has shape `[n]` (no batch
    dims) — confirm.

    Args:
        n: Number of samples.
        seed: Seed accepted by `samplers.split_seed`.

    Returns:
        `Tensor` stacking the `n` component samples along axis 0.

    Raises:
        TypeError: if mixture sampling fails with a `TypeError` unrelated to
            seed types, or needs the fallback but no `SeedStream` could be
            built from `seed`.
    """
    # only for MixtureSameFamilySampleFix
    import warnings

    from tensorflow_probability.python.internal import samplers
    from tensorflow_probability.python.util.seed_stream import SeedStream
    from tensorflow_probability.python.util.seed_stream import (
        TENSOR_SEED_MSG_PREFIX,
    )

    components_seed, mix_seed = samplers.split_seed(
        seed, salt="MixtureSameFamily")
    try:
        seed_stream = SeedStream(seed, salt="MixtureSameFamily")
    except TypeError as e:  # Can happen for Tensor seeds.
        seed_stream = None
        seed_stream_err = e
    try:
        mix_sample = self.mixture_distribution.sample(
            n, seed=mix_seed)  # [n, B] or [n]
    except TypeError as e:
        # Only seed-type errors trigger the stateful fallback.
        if ("Expected int for argument" not in str(e)
                and TENSOR_SEED_MSG_PREFIX not in str(e)):
            raise
        if seed_stream is None:
            raise seed_stream_err
        msg = (
            "Falling back to stateful sampling for `mixture_distribution` "
            "{} of type `{}`. Please update to use `tf.random.stateless_*` "
            "RNGs. This fallback may be removed after 20-Aug-2020. ({})")
        warnings.warn(
            msg.format(
                self.mixture_distribution.name,
                type(self.mixture_distribution),
                str(e),
            ))
        mix_sample = self.mixture_distribution.sample(
            n, seed=seed_stream())  # [n, B] or [n]

    # One deterministic integer seed per draw: the base seed offset by the
    # draw index.
    base_seed = int(components_seed[0].numpy())
    ret = tf.stack(
        [
            self.components_distribution[i_component.numpy()].sample(
                seed=base_seed + i)
            for i, i_component in enumerate(mix_sample)
        ],
        axis=0,
    )
    return ret
def __init__(self, target_log_prob_fn, new_state_fn=None, seed=None,
             name=None):
    """Stores the configuration of this random-walk kernel.

    Args:
        target_log_prob_fn: Python callable returning the (possibly
            unnormalized) target log-density.
        new_state_fn: Python callable proposing a new state. Default value:
            `None`, which is replaced by `random_walk_normal_fn()`.
        seed: Seed for a `SeedStream` salted 'RandomWalkMetropolis'.
        name: Python `str` name, stored as given.
    """
    resolved_new_state_fn = (
        random_walk_normal_fn() if new_state_fn is None else new_state_fn)
    self._target_log_prob_fn = target_log_prob_fn
    self._seed_stream = SeedStream(seed, salt='RandomWalkMetropolis')
    self._name = name
    self._parameters = {
        'target_log_prob_fn': target_log_prob_fn,
        'new_state_fn': resolved_new_state_fn,
        'seed': seed,
        'name': name,
    }
def _randomize(coeffs, radixes, seed=None):
    """Applies the Owen (2017) randomization to the coefficients.

    Draws one random permutation per (coefficient position, radix) pair via
    `_get_permutations`, then maps every coefficient through the permutation
    for its position and radix.

    Args:
        coeffs: Integer-valued digit coefficients; the last axis indexes the
            coefficient position. NOTE(review): the offset arithmetic suggests
            shape `[num_radixes, num_coeffs]` — confirm with callers.
        radixes: Radixes, flattened to 1D `int32` here.
        seed: Seed for a `SeedStream` salted 'MCMCSampleHaltonSequence2'.

    Returns:
        The permuted coefficients, cast back to `coeffs.dtype`.
    """
    given_dtype = coeffs.dtype
    int_coeffs = tf.cast(coeffs, dtype=tf.int32)
    num_coeffs = tf.shape(int_coeffs)[-1]
    flat_radixes = tf.reshape(tf.cast(radixes, dtype=tf.int32), shape=[-1])
    stream = SeedStream(seed, salt='MCMCSampleHaltonSequence2')
    flat_perms = tf.reshape(
        _get_permutations(num_coeffs, flat_radixes, seed=stream()),
        shape=[-1])
    # Locate the permutation block for each radix (column offset within a row)
    # and each coefficient position (row offset of `radix_sum` per position).
    radix_sum = tf.reduce_sum(flat_radixes)
    radix_offsets = tf.reshape(tf.cumsum(flat_radixes, exclusive=True),
                               shape=[-1, 1])
    offsets = radix_offsets + tf.range(num_coeffs) * radix_sum
    permuted_coeffs = tf.gather(flat_perms, int_coeffs + offsets)
    return tf.cast(permuted_coeffs, dtype=given_dtype)
def __init__(self,
             target_log_prob_fn,
             inverse_temperatures,
             make_kernel_fn,
             swap_proposal_fn=default_swap_proposal_fn(1.),
             state_includes_replicas=False,
             seed=None,
             validate_args=False,
             name=None):
    """Instantiates this object.

    This constructor records its arguments as the kernel's parameters. It
    stores the replica flag and creates a `SeedStream` salted 'replica_mc'.
    No validation of `inverse_temperatures` happens here.

    Args:
        target_log_prob_fn: Python callable which takes an argument like
            `current_state` (or `*current_state` if it's a list) and returns
            its (possibly unnormalized) log-density under the target
            distribution.
        inverse_temperatures: `Tensor` of inverse temperatures to temper each
            replica. The leftmost dimension is the `num_replica`. The second
            through rightmost dimensions can give different temperatures to
            different batch members, with left-justified broadcasting.
        make_kernel_fn: Python callable which takes a `target_log_prob_fn` arg
            and returns a `tfp.mcmc.TransitionKernel` instance. Passing a
            function taking `(target_log_prob_fn, seed)` is deprecated but
            supported until 2020-09-20.
        swap_proposal_fn: Python callable which takes a number of replicas and
            returns `swaps`, a `Tensor` of shape `[num_replica] + batch_shape`.
            Axis 0 indexes a permutation of `{0,..., num_replica-1}` that
            designates which replicas to swap.
        state_includes_replicas: Boolean indicating whether the leftmost
            dimension of each state sample indexes replicas. If `True`, the
            leftmost dimension of the `current_state` kwarg to
            `tfp.mcmc.sample_chain` is interpreted as indexing replicas.
        seed: Python integer to seed the random number generator.
            Deprecated, pass seed to `tfp.mcmc.sample_chain`.
            Default value: `None` (i.e., no seed).
        validate_args: Python `bool`, default `False`. When `True`,
            distribution parameters are checked for validity, possibly at
            some runtime cost. When `False`, invalid inputs may silently
            produce incorrect outputs.
        name: Python `str` name prefixed to Ops created by this function.
            Default value: `None` (i.e., "remc_kernel").
    """
    # Snapshot every constructor argument (everything bound so far) except
    # `self` itself.
    bound_arguments = dict(locals())
    self._parameters = {
        arg_name: arg_value
        for arg_name, arg_value in bound_arguments.items()
        if arg_value is not self
    }
    self._state_includes_replicas = state_includes_replicas
    self._seed_stream = SeedStream(seed, salt='replica_mc')
def __init__(self, target_log_prob_fn, new_state_fn=None, seed=None,
             name=None):
    """Initializes this transition kernel.

    Wraps an `UncalibratedRandomWalk` in a
    `metropolis_hastings.MetropolisHastings` correction.

    Args:
        target_log_prob_fn: Python callable which takes an argument like
            `current_state` (or `*current_state` if it's a list) and returns
            its (possibly unnormalized) log-density under the target
            distribution.
        new_state_fn: Python callable which takes a list of state parts and a
            seed, and returns a same-type `list` of `Tensor`s, each a
            perturbation of the input state parts. The perturbation
            distribution is assumed to be symmetric and centered at the input
            state part. Default value: `None`, which is mapped to
            `tfp.mcmc.random_walk_normal_fn()`.
        seed: Python integer to seed the random number generator.
            Deprecated, pass seed to `tfp.mcmc.sample_chain`.
        name: Python `str` name prefixed to Ops created by this function.
            Default value: `None` (i.e., 'rwm_kernel').
    """
    if new_state_fn is None:
        new_state_fn = random_walk_normal_fn()

    seed_stream = SeedStream(seed, salt='rwm')
    if seed is None:
        # No seed: let both kernels use their own defaults.
        mh_kwargs, uncal_kwargs = {}, {}
    else:
        # Order matters: the Metropolis-Hastings seed is drawn first, then
        # the inner kernel's.
        mh_kwargs = dict(seed=seed_stream())
        uncal_kwargs = dict(seed=seed_stream())

    inner_kernel = UncalibratedRandomWalk(
        target_log_prob_fn=target_log_prob_fn,
        new_state_fn=new_state_fn,
        name=name,
        **uncal_kwargs)
    self._impl = metropolis_hastings.MetropolisHastings(
        inner_kernel=inner_kernel, **mh_kwargs)
def make_rwmh_kernel_fn(target_log_prob_fn, init_state, scalings, seed=None):
    """Generate a Random Walk MH kernel.

    The proposal scale for each state part is the standard deviation of that
    part over axis 0 of `init_state`, multiplied by `scalings`. `scalings` is
    left-justified-expanded and cast to broadcast against it.

    Args:
        target_log_prob_fn: Python callable returning the (possibly
            unnormalized) target log-density.
        init_state: Iterable of state-part `Tensor`s; statistics are taken
            over axis 0.
        scalings: Scaling factors applied to each per-part standard deviation.
        seed: Optional Python integer seed.

    Returns:
        A `random_walk_metropolis.RandomWalkMetropolis` kernel.
    """
    with tf.name_scope('make_rwmh_kernel_fn'):
        seed_stream = SeedStream(seed, salt='make_rwmh_kernel_fn')
        state_std = [
            tf.math.reduce_std(x, axis=0, keepdims=True) for x in init_state
        ]
        step_size = [
            s * ps.cast(  # pylint: disable=g-complex-comprehension
                mcmc_util.left_justified_expand_dims_like(scalings, s),
                s.dtype) for s in state_std
        ]
        return random_walk_metropolis.RandomWalkMetropolis(
            target_log_prob_fn,
            new_state_fn=random_walk_metropolis.random_walk_normal_fn(
                scale=step_size),
            # Draw a salted seed from the stream. Forwarding the stream object
            # itself would leave the salt unused and hand the kernel something
            # other than the integer (or `None`) seed it documents.
            seed=seed_stream())
def _sample_n(self, n, seed=None):
    """Draws `n` Gamma-Gamma samples.

    First `rate ~ Gamma(mixing_concentration, mixing_rate)`, drawn for the
    fully broadcast shape. Then `X ~ Gamma(concentration, rate)`. The two
    draws use consecutive seeds of a `SeedStream` salted 'gamma_gamma'.
    """
    concentration = tf.convert_to_tensor(self.concentration)
    mixing_concentration = tf.convert_to_tensor(self.mixing_concentration)
    mixing_rate = tf.convert_to_tensor(self.mixing_rate)
    stream = SeedStream(seed, 'gamma_gamma')
    # Adding zeros shaped like `concentration` broadcasts the mixing
    # concentration, so enough rates are drawn for the full batch.
    broadcast_mixing_concentration = (
        mixing_concentration + tf.zeros_like(concentration))
    rate = tf.random.gamma(
        shape=[n],
        alpha=broadcast_mixing_concentration,
        beta=mixing_rate,
        dtype=self.dtype,
        seed=stream())
    return tf.random.gamma(
        shape=[],
        alpha=concentration,
        beta=rate,
        dtype=self.dtype,
        seed=stream())
def _sample_n(self, n, seed=None):
    """Draws `n` multivariate Student's t samples.

    As with the univariate Student's t, a sample is
    `loc + z / sqrt(chi2 / df)`. Here `z` comes from a zero-mean multivariate
    normal with the given `scale`, and `chi2 ~ Chi2(df)`. The normal draw uses
    the first seed of a `SeedStream` salted 'multivariate t' and the
    chi-squared draw the second.
    """
    stream = SeedStream(seed, salt='multivariate t')
    loc = tf.broadcast_to(self.loc, self._sample_shape())
    centered_mvn = mvn_linear_operator.MultivariateNormalLinearOperator(
        loc=tf.zeros_like(loc), scale=self.scale)
    normal_samp = centered_mvn.sample(n, seed=stream())
    df = tf.broadcast_to(self.df, self.batch_shape_tensor())
    chi2_samp = chi2_lib.Chi2(df=df).sample(n, seed=stream())
    # Add a trailing axis so the per-sample factor broadcasts over the event.
    inverse_scale = tf.math.rsqrt(chi2_samp / self._df)[..., tf.newaxis]
    return self._loc + normal_samp * inverse_scale
def _sample_n(self, n, seed=None):
    """Draws `n` samples by iterating `distribution_fn`.

    Starts from a sample of `self._get_distribution0()`. It then replaces the
    whole sample with `self.distribution_fn(samples).sample(...)` for
    `num_steps` iterations. `num_steps` is `self._num_steps` when set;
    otherwise it is the number of elements in distribution0's event shape.
    The loop is a Python loop when the step count is static, else `tf.foldl`.

    Seeding: a stateless seed is tried first. If distribution0 rejects it
    with a seed-type `TypeError`, a stateful `SeedStream` seed is used
    instead, with a warning. The chosen seed is then reused for every step.

    Args:
        n: Number of samples.
        seed: Seed for `samplers.sanitize_seed` / `SeedStream`, salted
            'Autoregressive'.

    Returns:
        The samples after the final step.

    Raises:
        TypeError: if distribution0 raises a `TypeError` unrelated to seed
            types.
    """
    distribution0 = self._get_distribution0()

    if self._num_steps is not None:
        num_steps = tf.convert_to_tensor(self._num_steps)
        num_steps_static = tf.get_static_value(num_steps)
    else:
        num_steps_static = tensorshape_util.num_elements(
            distribution0.event_shape)
        if num_steps_static is None:
            num_steps = tf.reduce_prod(distribution0.event_shape_tensor())

    stateless_seed = samplers.sanitize_seed(seed, salt='Autoregressive')
    stateful_seed = None
    try:
        samples = distribution0.sample(n, seed=stateless_seed)
        is_stateful_sampler = False
    except TypeError as e:
        # Only seed-type errors trigger the stateful fallback.
        if ('Expected int for argument' not in str(e) and
                TENSOR_SEED_MSG_PREFIX not in str(e)):
            raise
        # Three placeholders for the three arguments below: name, type, error.
        msg = (
            'Falling back to stateful sampling for `distribution_fn(sample0)` '
            '{} of type `{}`. Please update to use `tf.random.stateless_*` '
            'RNGs. This fallback may be removed after 20-Aug-2020. ({})')
        warnings.warn(
            msg.format(distribution0.name, type(distribution0), str(e)))
        stateful_seed = SeedStream(seed, salt='Autoregressive')()
        samples = distribution0.sample(n, seed=stateful_seed)
        is_stateful_sampler = True

    seed = stateful_seed if is_stateful_sampler else stateless_seed

    # The same seed is passed at every step. NOTE(review): this looks
    # deliberate — with identical noise, coordinates whose conditionals have
    # already settled are reproduced unchanged on later steps. Confirm before
    # switching to per-step seeds.
    if num_steps_static is not None:
        for _ in range(num_steps_static):
            # pylint: disable=not-callable
            samples = self.distribution_fn(samples).sample(seed=seed)
    else:
        # pylint: disable=not-callable
        samples = tf.foldl(
            lambda s, _: self.distribution_fn(s).sample(seed=seed),
            elems=tf.range(0, num_steps),
            initializer=samples)
    return samples
def _flat_sample_distributions(self, sample_shape=(), seed=None, value=None):
    """Executes `model`, creating both samples and distributions.

    Drives the model generator. Each yielded distribution is recorded, with
    any `Root` wrapper removed. Its value is taken from `value` when given,
    otherwise sampled, and that value is sent back into the generator. Only
    `Root` distributions are sampled with `sample_shape`; others use `()`.

    Args:
        sample_shape: Sample shape used for `Root` distributions.
        seed: Seed for a `SeedStream` salted 'JointDistributionCoroutine'.
        value: Optional sequence of pre-specified values. Entries that are
            missing or `None` are sampled instead.

    Returns:
        A pair `(ds, values_out)`: the distributions in yield order and the
        values sent back for them.

    Raises:
        ValueError: if the first yielded object is not wrapped in `Root`.
    """
    ds = []
    values_out = []
    # `SeedStream` takes `(seed, salt)`. With the arguments swapped, the
    # constant salt string would act as the seed and the user's seed as the
    # salt, so even unseeded calls would get a fixed, non-`None` seed stream.
    seed = SeedStream(seed, salt='JointDistributionCoroutine')
    gen = self._model()
    index = 0
    d = next(gen)
    if not isinstance(d, self.Root):
        raise ValueError('First distribution yielded by coroutine must '
                         'be wrapped in `Root`.')
    try:
        while True:
            actual_distribution = d.distribution if isinstance(
                d, self.Root) else d
            ds.append(actual_distribution)
            if (value is not None and len(value) > index and
                    value[index] is not None):
                # Advance the stream anyway so later seeds do not depend on
                # which values were supplied.
                seed()
                next_value = value[index]
            else:
                next_value = actual_distribution.sample(
                    sample_shape=sample_shape if isinstance(
                        d, self.Root) else (),
                    seed=seed())
            if self._validate_args:
                with tf.control_dependencies(
                        self._assert_compatible_shape(
                            index, sample_shape, next_value)):
                    values_out.append(
                        tf.nest.map_structure(tf.identity, next_value))
            else:
                values_out.append(next_value)
            index += 1
            d = gen.send(next_value)
    except StopIteration:
        pass
    return ds, values_out