def _sample_n(self, n, seed=None):
  seed1, seed2 = samplers.split_seed(seed, salt='beta')
  concentration1 = tf.convert_to_tensor(self.concentration1)
  concentration0 = tf.convert_to_tensor(self.concentration0)
  shape = self._batch_shape_tensor(concentration1, concentration0)
  expanded_concentration1 = tf.broadcast_to(concentration1, shape)
  expanded_concentration0 = tf.broadcast_to(concentration0, shape)
  # If G1 ~ Gamma(concentration1, 1) and G2 ~ Gamma(concentration0, 1) are
  # independent, then G1 / (G1 + G2) ~ Beta(concentration1, concentration0).
  gamma1_sample = samplers.gamma(
      shape=[n], alpha=expanded_concentration1, dtype=self.dtype, seed=seed1)
  gamma2_sample = samplers.gamma(
      shape=[n], alpha=expanded_concentration0, dtype=self.dtype, seed=seed2)
  beta_sample = gamma1_sample / (gamma1_sample + gamma2_sample)
  return beta_sample
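# Standalone sketch (not part of the library): a NumPy check of the
# gamma-ratio construction used above. If G1 ~ Gamma(c1, 1) and
# G2 ~ Gamma(c0, 1) are independent, then G1 / (G1 + G2) ~ Beta(c1, c0).
# The constants below are illustrative assumptions.
import numpy as np

rng = np.random.default_rng(0)
c1, c0, n = 2.0, 5.0, 100_000
g1 = rng.gamma(shape=c1, size=n)
g2 = rng.gamma(shape=c0, size=n)
beta_samples = g1 / (g1 + g2)
# The empirical mean should be close to the Beta mean c1 / (c1 + c0) ~= 0.286.
print(beta_samples.mean())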
def _sample_n(self, n, seed=None):
  concentration = tf.convert_to_tensor(self.concentration)
  mixing_concentration = tf.convert_to_tensor(self.mixing_concentration)
  mixing_rate = tf.convert_to_tensor(self.mixing_rate)
  seed_rate, seed_samples = samplers.split_seed(seed, salt='gamma_gamma')
  rate = samplers.gamma(
      shape=[n],
      # Be sure to draw enough rates for the fully-broadcasted gamma-gamma.
      alpha=mixing_concentration + tf.zeros_like(concentration),
      beta=mixing_rate,
      dtype=self.dtype,
      seed=seed_rate)
  return samplers.gamma(
      shape=[],
      alpha=concentration,
      beta=rate,
      dtype=self.dtype,
      seed=seed_samples)
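# Standalone sketch (not part of the library): the same two-stage scheme in
# NumPy. A gamma-gamma draw is a Gamma(concentration, rate) draw whose rate
# is itself Gamma(mixing_concentration, mixing_rate) distributed. NumPy's
# gamma is parameterized by scale, so a rate r becomes scale 1 / r. The
# constants below are illustrative assumptions.
import numpy as np

rng = np.random.default_rng(0)
concentration, mixing_concentration, mixing_rate, n = 3.0, 4.0, 2.0, 100_000
rate = rng.gamma(shape=mixing_concentration, scale=1.0 / mixing_rate, size=n)
samples = rng.gamma(shape=concentration, scale=1.0 / rate)
# E[X] = concentration * mixing_rate / (mixing_concentration - 1) = 2.0 here.
print(samples.mean())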
def _sample_n(self, n, seed=None):
  # Here we use the fact that if:
  # lam ~ Gamma(concentration=total_count, rate=(1 - probs) / probs)
  # then X ~ Poisson(lam) is Negative Binomially distributed.
  logits = self._logits_parameter_no_checks()
  gamma_seed, poisson_seed = samplers.split_seed(
      seed, salt='NegativeBinomial')
  # TODO(b/152785714): For some reason switching to gamma_lib.random_gamma
  # makes tests time out. Note: observed similar in jax_transformation_test.
  rate = samplers.gamma(
      shape=[n],
      alpha=self.total_count,
      beta=tf.math.exp(-logits),
      dtype=self.dtype,
      seed=gamma_seed)
  return samplers.poisson(
      shape=[], lam=rate, dtype=self.dtype, seed=poisson_seed)
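# Standalone sketch (not part of the library): the Gamma-Poisson mixture in
# NumPy. With lam ~ Gamma(total_count, rate=(1 - probs) / probs) and
# X | lam ~ Poisson(lam), X ~ NegativeBinomial(total_count, probs). NumPy's
# gamma takes a scale, so the rate above becomes probs / (1 - probs). The
# constants below are illustrative assumptions.
import numpy as np

rng = np.random.default_rng(0)
total_count, probs, n = 5.0, 0.3, 100_000
lam = rng.gamma(shape=total_count, scale=probs / (1.0 - probs), size=n)
x = rng.poisson(lam)
# E[X] = total_count * probs / (1 - probs) ~= 2.143 here.
print(x.mean())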
def sample_n(n, df, loc, scale, batch_shape, dtype, seed):
  """Draw n samples from a Student T distribution.

  Note that `scale` can be negative or zero.
  The sampling method comes from the fact that if:
    X ~ Normal(0, 1)
    Z ~ Chi2(df)
    Y = X / sqrt(Z / df)
  then:
    Y ~ StudentT(df)

  Args:
    n: int, number of samples
    df: Floating-point `Tensor`. The degrees of freedom of the
      distribution(s). `df` must contain only positive values.
    loc: Floating-point `Tensor`; the location(s) of the distribution(s).
    scale: Floating-point `Tensor`; the scale(s) of the distribution(s). May
      be negative or zero.
    batch_shape: `Tensor` giving the broadcast batch shape of `df`, `loc`,
      and `scale`.
    dtype: Return dtype.
    seed: Optional seed for random draw.

  Returns:
    samples: a `Tensor` with prepended dimensions `n`.
  """
  normal_seed, gamma_seed = samplers.split_seed(seed, salt='student_t')
  shape = tf.concat([[n], batch_shape], 0)

  normal_sample = samplers.normal(shape, dtype=dtype, seed=normal_seed)
  df = df * tf.ones(batch_shape, dtype=dtype)
  # Chi2(df) is Gamma(alpha=df / 2, beta=1 / 2).
  gamma_sample = samplers.gamma(
      [n], 0.5 * df, beta=0.5, dtype=dtype, seed=gamma_seed)
  samples = normal_sample * tf.math.rsqrt(gamma_sample / df)
  return samples * scale + loc
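# Standalone sketch (not part of the library): the Normal/Chi2 construction
# from the docstring above, in NumPy. The constants below are illustrative
# assumptions.
import numpy as np

rng = np.random.default_rng(0)
df, loc, scale, n = 5.0, 1.0, 2.0, 100_000
x = rng.standard_normal(n)
z = rng.chisquare(df, size=n)
y = x / np.sqrt(z / df)    # y ~ StudentT(df)
samples = y * scale + loc  # shift and scale, as in the return above
# For df > 1 the mean is loc; the sample mean should be close to 1.0.
print(samples.mean())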
def _sample_n(self, n, seed=None):
  gamma_seed, multinomial_seed = samplers.split_seed(
      seed, salt='dirichlet_multinomial')
  concentration = tf.convert_to_tensor(self._concentration)
  total_count = tf.convert_to_tensor(self._total_count)
  n_draws = tf.cast(total_count, dtype=tf.int32)
  k = self._event_shape_tensor(concentration)[0]
  alpha = tf.math.multiply(
      tf.ones_like(total_count[..., tf.newaxis]),
      concentration,
      name='alpha')
  # The log of independent Gamma(alpha_i, 1) draws gives unnormalized logits
  # whose softmax is a Dirichlet(concentration) sample; conditioning the
  # multinomial on it yields a Dirichlet-multinomial draw.
  unnormalized_logits = tf.math.log(
      samplers.gamma(
          shape=[n], alpha=alpha, dtype=self.dtype, seed=gamma_seed))
  x = multinomial.draw_sample(
      1, k, unnormalized_logits, n_draws, self.dtype, multinomial_seed)
  final_shape = tf.concat(
      [[n], self._batch_shape_tensor(concentration, total_count), [k]], 0)
  return tf.reshape(x, final_shape)
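# Standalone sketch (not part of the library): the same compound scheme in
# NumPy, one draw at a time. A Dirichlet-multinomial draw is a multinomial
# draw whose probabilities are themselves Dirichlet(concentration)
# distributed. The constants below are illustrative assumptions.
import numpy as np

rng = np.random.default_rng(0)
concentration = np.array([1.0, 2.0, 3.0])
total_count = 10
p = rng.dirichlet(concentration)          # probabilities ~ Dirichlet
counts = rng.multinomial(total_count, p)  # counts | p ~ Multinomial
print(counts)  # sums to total_count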
def _sample_n(self, n, seed=None):
  # If G ~ Gamma(concentration, rate=scale), then
  # 1 / G ~ InverseGamma(concentration, scale).
  return 1. / samplers.gamma(
      shape=[n],
      alpha=self.concentration,
      beta=self.scale,
      dtype=self.dtype,
      seed=seed)
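# Standalone sketch (not part of the library): the reciprocal relationship
# in NumPy. NumPy's gamma takes a scale parameter, so rate `scale` becomes
# 1 / scale. The constants below are illustrative assumptions.
import numpy as np

rng = np.random.default_rng(0)
concentration, scale, n = 3.0, 2.0, 100_000
samples = 1.0 / rng.gamma(shape=concentration, scale=1.0 / scale, size=n)
# E[X] = scale / (concentration - 1) = 1.0 here.
print(samples.mean())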
def _sample_n(self, n, seed=None):
  # If G_i ~ Gamma(concentration_i, 1) independently, then
  # G / sum(G) ~ Dirichlet(concentration).
  gamma_sample = samplers.gamma(
      shape=[n], alpha=self.concentration, dtype=self.dtype, seed=seed)
  return gamma_sample / tf.reduce_sum(gamma_sample, axis=-1, keepdims=True)
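# Standalone sketch (not part of the library): normalized independent gammas
# in NumPy. The alpha vector below is an illustrative assumption.
import numpy as np

rng = np.random.default_rng(0)
alpha = np.array([1.0, 2.0, 3.0])
g = rng.gamma(shape=alpha, size=(100_000, 3))
dirichlet_samples = g / g.sum(axis=-1, keepdims=True)
# Component means should approach alpha / alpha.sum() = [1/6, 1/3, 1/2].
print(dirichlet_samples.mean(axis=0))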
def _sample_n(self, n, seed=None):
  # Chi2(df) is a special case of Gamma: Gamma(alpha=df / 2, beta=1 / 2).
  return samplers.gamma(
      shape=[n],
      alpha=0.5 * self.df,
      beta=tf.convert_to_tensor(0.5, dtype=self.dtype),
      dtype=self.dtype,
      seed=seed)
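# Standalone sketch (not part of the library): checking the identity
# Chi2(df) == Gamma(alpha=df / 2, beta=1 / 2) in NumPy, where beta is a rate
# (NumPy's gamma uses scale = 1 / beta = 2). The df value is an illustrative
# assumption.
import numpy as np

rng = np.random.default_rng(0)
df, n = 4.0, 100_000
as_gamma = rng.gamma(shape=0.5 * df, scale=2.0, size=n)
as_chi2 = rng.chisquare(df, size=n)
# Both sample means should be close to df = 4.0.
print(as_gamma.mean(), as_chi2.mean())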
def slice_sampler_one_dim(target_log_prob, x_initial, step_size=0.01,
                          max_doublings=30, seed=None, name=None):
  """For a given x position in each Markov chain, returns the next x.

  Applies the one dimensional slice sampling algorithm as defined in Neal
  (2003) to an input tensor x of shape (num_chains,), where num_chains is
  the number of simultaneous Markov chains, and returns the next tensor x of
  shape (num_chains,) when these chains are evolved by the slice sampling
  algorithm.

  Args:
    target_log_prob: Callable accepting a tensor like `x_initial` and
      returning a tensor containing the log density at that point of the
      same shape.
    x_initial: A tensor of any shape. The initial positions of the chains.
      This function assumes that all the dimensions of `x_initial` are batch
      dimensions (i.e. the event shape is `[]`).
    step_size: A tensor of shape and dtype compatible with `x_initial`. The
      min interval size in the doubling algorithm.
    max_doublings: Scalar tensor of dtype `tf.int32`. The maximum number of
      doublings to try to find the slice bounds.
    seed: (Optional) positive int, or Tensor seed pair. The random seed.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'slice_sampler_one_dim').

  Returns:
    retval: A tensor of the same shape and dtype as `x_initial`. The next
      state of the Markov chain.
    next_target_log_prob: The target log density evaluated at `retval`.
    bounds_satisfied: A tensor of bool dtype and shape batch dimensions.
    upper_bounds: Tensor of the same shape and dtype as `x_initial`. The
      upper bounds for the slice found.
    lower_bounds: Tensor of the same shape and dtype as `x_initial`. The
      lower bounds for the slice found.
  """
  gamma_seed, bounds_seed, sample_seed = samplers.split_seed(
      seed, n=3, salt='ssu.slice_sampler_one_dim')
  with tf.name_scope(name or 'slice_sampler_one_dim'):
    # Obtain the input dtype of the array.
    dtype = dtype_util.common_dtype([x_initial, step_size],
                                    dtype_hint=tf.float32)
    x_initial = tf.convert_to_tensor(x_initial, dtype=dtype)
    step_size = tf.convert_to_tensor(step_size, dtype=dtype)
    # Select the height of the slice. Tensor of shape x_initial.shape. The
    # Gamma(1, 1) draw is a standard exponential, so this is equivalent to
    # log_prob(x) + log(u) with u ~ Uniform(0, 1).
    log_slice_heights = target_log_prob(x_initial) - samplers.gamma(
        ps.shape(x_initial), alpha=1, dtype=dtype, seed=gamma_seed)
    # Given the above x and slice heights, compute the bounds of the slice
    # for each chain.
    upper_bounds, lower_bounds, bounds_satisfied = slice_bounds_by_doubling(
        x_initial, target_log_prob, log_slice_heights, max_doublings,
        step_size, seed=bounds_seed)
    retval = _sample_with_shrinkage(
        x_initial,
        target_log_prob=target_log_prob,
        log_slice_heights=log_slice_heights,
        step_size=step_size,
        lower_bounds=lower_bounds,
        upper_bounds=upper_bounds,
        seed=sample_seed)
    return (retval, target_log_prob(retval), bounds_satisfied,
            upper_bounds, lower_bounds)
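# Standalone sketch (not part of the library): a minimal single-chain slice
# sampler in NumPy, using simple stepping-out instead of the doubling scheme
# above, to show the core idea: draw a height below log_prob(x) by
# subtracting an Exponential(1) variate (i.e. Gamma(1, 1), equivalent to
# adding log(u) for u ~ Uniform(0, 1)), bracket the slice, then shrink the
# bracket until a point under the curve is accepted. All names here are
# hypothetical.
import numpy as np

def slice_sample_step(log_prob, x, step_size=1.0, max_steps=30, rng=None):
  rng = rng or np.random.default_rng()
  # Slice height: log_prob(x) minus a standard exponential draw.
  log_height = log_prob(x) - rng.exponential()
  # Step out: widen [lo, hi] until both ends fall below the slice height.
  lo = x - step_size * rng.uniform()
  hi = lo + step_size
  for _ in range(max_steps):
    if log_prob(lo) < log_height:
      break
    lo -= step_size
  for _ in range(max_steps):
    if log_prob(hi) < log_height:
      break
    hi += step_size
  # Shrink: sample uniformly from [lo, hi], narrowing toward x on rejection.
  while True:
    x_new = rng.uniform(lo, hi)
    if log_prob(x_new) >= log_height:
      return x_new
    if x_new < x:
      lo = x_new
    else:
      hi = x_new

# Example: sample from a standard normal density (up to a constant).
rng = np.random.default_rng(0)
x, draws = 0.0, []
for _ in range(5000):
  x = slice_sample_step(lambda t: -0.5 * t * t, x, rng=rng)
  draws.append(x)
print(np.mean(draws), np.std(draws))  # roughly 0 and 1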
def _sample_n(self, n, seed):
  df = tf.convert_to_tensor(self.df)
  batch_shape = self._batch_shape_tensor(df)
  event_shape = self._event_shape_tensor()
  batch_ndims = tf.shape(batch_shape)[0]

  ndims = batch_ndims + 3  # sample_ndims=1, event_ndims=2
  shape = tf.concat([[n], batch_shape, event_shape], 0)
  normal_seed, gamma_seed = samplers.split_seed(seed, salt='Wishart')

  # Complexity: O(nbk**2)
  x = samplers.normal(shape=shape, mean=0., stddev=1., dtype=self.dtype,
                      seed=normal_seed)

  # Complexity: O(nbk)
  # This parameterization is equivalent to Chi2, i.e.,
  # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2)
  expanded_df = df * tf.ones(
      self._scale.batch_shape_tensor(),
      dtype=dtype_util.base_dtype(df.dtype))
  g = samplers.gamma(
      shape=[n],
      alpha=self._multi_gamma_sequence(0.5 * expanded_df, self._dimension()),
      beta=0.5,
      dtype=self.dtype,
      seed=gamma_seed)

  # Complexity: O(nbk**2)
  x = tf.linalg.band_part(x, -1, 0)  # Tri-lower.

  # Complexity: O(nbk)
  x = tf.linalg.set_diag(x, tf.sqrt(g))

  # Make batch-op ready.
  # Complexity: O(nbk**2)
  perm = tf.concat([tf.range(1, ndims), [0]], 0)
  x = tf.transpose(a=x, perm=perm)
  shape = tf.concat([batch_shape, [event_shape[0]], [event_shape[1] * n]], 0)
  x = tf.reshape(x, shape)

  # Complexity: O(nbM) where M is the complexity of the operator solving a
  # vector system. For LinearOperatorLowerTriangular, each matmul is O(k^3)
  # so this step has complexity O(nbk^3).
  x = self._scale.matmul(x)

  # Undo make batch-op ready.
  # Complexity: O(nbk**2)
  shape = tf.concat([batch_shape, event_shape, [n]], 0)
  x = tf.reshape(x, shape)
  perm = tf.concat([[ndims - 1], tf.range(0, ndims - 1)], 0)
  x = tf.transpose(a=x, perm=perm)

  if not self.input_output_cholesky:
    # Complexity: O(nbk**3)
    x = tf.matmul(x, x, adjoint_b=True)

  return x
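# Standalone sketch (not part of the library): the Bartlett decomposition
# used above, in NumPy for a single draw. Build a lower-triangular A with
# standard normals below the diagonal and sqrt(Chi2(df - i)) entries on the
# diagonal (Chi2(df - i) == Gamma((df - i) / 2, rate 1/2), i.e. NumPy scale
# 2); then with L the Cholesky factor of the scale matrix,
# (L @ A) @ (L @ A).T ~ Wishart(df, scale). The constants below are
# illustrative assumptions.
import numpy as np

rng = np.random.default_rng(0)
k, df = 3, 5.0
scale = np.array([[2.0, 0.5, 0.0],
                  [0.5, 1.0, 0.3],
                  [0.0, 0.3, 1.5]])
chol_scale = np.linalg.cholesky(scale)
a = np.tril(rng.standard_normal((k, k)), k=-1)  # strictly lower triangle
diag = np.sqrt(rng.gamma(shape=0.5 * (df - np.arange(k)), scale=2.0))
a += np.diag(diag)
factor = chol_scale @ a
wishart_draw = factor @ factor.T
print(wishart_draw)  # over many draws, E[W] approaches df * scale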