def _sample_n(self, n, seed=None):
    """Draws `n` samples via the Gamma-Poisson mixture representation.

    Uses the fact that if
      lam ~ Gamma(concentration=total_count, rate=(1 - probs) / probs)
    then X ~ Poisson(lam) is Negative Binomially distributed.

    Args:
      n: Python `int` or scalar integer `Tensor`, number of samples to draw.
      seed: PRNG seed; see `tfp.random.sanitize_seed` for details.

    Returns:
      Samples with shape `[n] + batch_shape` and dtype `self.dtype`.
    """
    logits = self._logits_parameter_no_checks()
    rate_seed, count_seed = samplers.split_seed(
        seed, salt='NegativeBinomial')
    # exp(-logits) == (1 - probs) / probs, the Gamma rate parameter.
    gamma_rate = tf.math.exp(-logits)
    mixing_rate = samplers.gamma(
        shape=[n],
        alpha=self.total_count,
        beta=gamma_rate,
        dtype=self.dtype,
        seed=rate_seed)
    return samplers.poisson(
        shape=[], lam=mixing_rate, dtype=self.dtype, seed=count_seed)
def _sample_n(self, n, seed=None):
    """Draws `n` samples using the Gamma-Poisson mixture representation.

    Args:
      n: Number of samples to draw (leading sample dimension).
      seed: PRNG seed passed to `samplers.split_seed`.

    Returns:
      Samples of dtype `self.dtype` with leading dimension `n`.
    """
    # Here we use the fact that if:
    # lam ~ Gamma(concentration=total_count, rate=(1-probs)/probs)
    # then X ~ Poisson(lam) is Negative Binomially distributed.
    logits = self._logits_parameter_no_checks()
    gamma_seed, poisson_seed = samplers.split_seed(
        seed, salt='NegativeBinomial')
    # TODO(b/152785714): For some reason switching to gamma_lib.random_gamma
    # makes tests time out. Note: observed similar in jax_transformation_test.
    rate = samplers.gamma(
        shape=[n],
        alpha=self.total_count,
        # exp(-logits) is the Gamma rate, i.e. (1 - probs) / probs.
        beta=tf.math.exp(-logits),
        dtype=self.dtype,
        seed=gamma_seed)
    # shape=[] because `rate` already carries the [n] + batch shape.
    return samplers.poisson(
        shape=[], lam=rate, dtype=self.dtype, seed=poisson_seed)
def _sample_n(self, n, seed=None):
    """Draws `n` samples from the quadrature-compound Poisson mixture.

    Samples quadrature-bucket indices from the mixture distribution, looks
    up the corresponding Poisson rate for each index, then draws Poisson
    samples at those rates.

    Args:
      n: Number of samples to draw (leading sample dimension).
      seed: PRNG seed passed to `samplers.split_seed`.

    Returns:
      Samples of dtype `self.dtype` with shape `[n] + batch_shape`.
    """
    # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
    # ids as a [n]-shaped vector.
    distributions = self.poisson_and_mixture_distributions()
    dist, mixture_dist = distributions
    batch_size = tensorshape_util.num_elements(self.batch_shape)
    if batch_size is None:
        # Static batch shape unknown; fall back to the dynamic batch size.
        batch_size = tf.reduce_prod(
            self._batch_shape_tensor(distributions=distributions))
    # We need to 'sample extra' from the mixture distribution if it doesn't
    # already specify a probs vector for each batch coordinate.
    # We only support this kind of reduced broadcasting, i.e., there is exactly
    # one probs vector for all batch dims or one for each.
    mixture_seed, poisson_seed = samplers.split_seed(
        seed, salt='PoissonLogNormalQuadratureCompound')
    ids = mixture_dist.sample(
        sample_shape=concat_vectors(
            [n],
            distribution_util.pick_vector(
                mixture_dist.is_scalar_batch(),
                [batch_size],
                np.int32([]))),
        seed=mixture_seed)
    # We need to flatten batch dims in case mixture_dist has its own
    # batch dims.
    ids = tf.reshape(
        ids,
        shape=concat_vectors([n],
                             distribution_util.pick_vector(
                                 self.is_scalar_batch(),
                                 np.int32([]),
                                 np.int32([-1]))))
    # Stride `quadrature_size` for `batch_size` number of times.
    # This maps per-batch quadrature indices into the flattened
    # [batch_size * quadrature_size] rate table below.
    offset = tf.range(
        start=0,
        limit=batch_size * self._quadrature_size,
        delta=self._quadrature_size,
        dtype=ids.dtype)
    ids = ids + offset
    # Gather the Poisson rate for each sampled quadrature bucket.
    rate = tf.gather(tf.reshape(dist.rate_parameter(), shape=[-1]), ids)
    rate = tf.reshape(
        rate,
        shape=concat_vectors([n], self._batch_shape_tensor(
            distributions=distributions)))
    # shape=[] because `rate` already carries the [n] + batch shape.
    return samplers.poisson(
        shape=[], lam=rate, dtype=self.dtype, seed=poisson_seed)
def _sample_n(self, n, seed=None):
    """Draws `n` iid Poisson samples at this distribution's rate.

    Args:
      n: Python `int` or scalar integer `Tensor`, number of samples to draw.
      seed: PRNG seed forwarded to `samplers.poisson`.

    Returns:
      Samples with leading dimension `n` and dtype `self.dtype`.
    """
    return samplers.poisson(
        shape=[n],
        lam=self._rate_parameter_no_checks(),
        dtype=self.dtype,
        seed=seed)