def _sample_n(self, n, seed=None):
   """Draws `n` samples by compounding a Gamma draw with a Poisson draw.

   Relies on the identity that if
     lam ~ Gamma(concentration=total_count, rate=(1 - probs) / probs)
   then X ~ Poisson(lam) follows a Negative Binomial distribution.

   Args:
     n: Number of draws; handed to the Gamma sampler as `shape=[n]`.
     seed: Seed material, split into one seed for the Gamma draw and one for
       the Poisson draw.

   Returns:
     The output of `samplers.poisson` evaluated at the sampled rates, with
     dtype `self.dtype`.
   """
   unchecked_logits = self._logits_parameter_no_checks()
   seed_for_gamma, seed_for_poisson = samplers.split_seed(
       seed, salt='NegativeBinomial')
   # exp(-logits) plays the role of the Gamma rate in the identity above.
   poisson_rate = samplers.gamma(
       shape=[n],
       alpha=self.total_count,
       beta=tf.math.exp(-unchecked_logits),
       dtype=self.dtype,
       seed=seed_for_gamma)
   # The Gamma draw was requested with shape=[n], so no additional sample
   # shape is requested from the Poisson sampler.
   return samplers.poisson(
       shape=[], lam=poisson_rate, dtype=self.dtype, seed=seed_for_poisson)
 def _sample_n(self, n, seed=None):
   """Samples `n` Negative Binomial values as a Gamma-mixed Poisson.

   The construction uses the fact that if
     lam ~ Gamma(concentration=total_count, rate=(1 - probs) / probs)
   then X ~ Poisson(lam) is Negative Binomially distributed.

   Args:
     n: Number of draws; forwarded to the Gamma sampler as `shape=[n]`.
     seed: Seed material; split so the Gamma and Poisson stages each get
       their own seed.

   Returns:
     Whatever `samplers.poisson` produces for the sampled rates, in
     `self.dtype`.
   """
   raw_logits = self._logits_parameter_no_checks()
   stage_seeds = samplers.split_seed(seed, salt='NegativeBinomial')
   gamma_stage_seed, poisson_stage_seed = stage_seeds
   # TODO(b/152785714): For some reason switching to gamma_lib.random_gamma
   # makes tests time out. Note: observed similar in jax_transformation_test.
   concentration = self.total_count
   # exp(-logits) is the Gamma rate from the identity in the docstring.
   gamma_rate = tf.math.exp(-raw_logits)
   mixed_rate = samplers.gamma(
       shape=[n],
       alpha=concentration,
       beta=gamma_rate,
       dtype=self.dtype,
       seed=gamma_stage_seed)
   # `mixed_rate` was drawn with shape=[n]; the Poisson stage adds no further
   # sample shape of its own.
   return samplers.poisson(
       shape=[],
       lam=mixed_rate,
       dtype=self.dtype,
       seed=poisson_stage_seed)
Exemplo n.º 3
0
  def _sample_n(self, n, seed=None):
    """Draws `n` samples by picking a quadrature component, then a Poisson.

    A component id is sampled from the mixture distribution for every draw
    (and batch member), the matching Poisson rate is looked up from the
    flattened rate table, and a Poisson sample is drawn at that rate.

    Args:
      n: Number of draws; used as the leading dimension of the sampled ids
        and of the gathered rates.
      seed: Seed material; split into one seed for the mixture draw and one
        for the Poisson draw.

    Returns:
      The output of `samplers.poisson` evaluated at the gathered rates, with
      dtype `self.dtype`.
    """
    components = self.poisson_and_mixture_distributions()
    poisson_dist, mixture = components

    # Prefer the statically known number of batch members; otherwise compute
    # it from the batch-shape tensor.
    num_batch_members = tensorshape_util.num_elements(self.batch_shape)
    if num_batch_members is None:
      dynamic_batch_shape = self._batch_shape_tensor(distributions=components)
      num_batch_members = tf.reduce_prod(dynamic_batch_shape)

    seed_for_ids, seed_for_counts = samplers.split_seed(
        seed, salt='PoissonLogNormalQuadratureCompound')

    # Ids come out as an [n, batch_size]-shaped matrix, unless batch_shape=[]
    # in which case they form an [n]-shaped vector.
    # The mixture must be 'sampled extra' when it does not already specify a
    # probs vector for each batch coordinate. Only this kind of reduced
    # broadcasting is supported, i.e., there is exactly one probs vector for
    # all batch dims or one for each.
    extra_sample_dims = distribution_util.pick_vector(
        mixture.is_scalar_batch(), [num_batch_members], np.int32([]))
    component_ids = mixture.sample(
        sample_shape=concat_vectors([n], extra_sample_dims),
        seed=seed_for_ids)

    # Flatten batch dims in case the mixture has its own batch dims.
    flattened_tail = distribution_util.pick_vector(
        self.is_scalar_batch(), np.int32([]), np.int32([-1]))
    component_ids = tf.reshape(
        component_ids, shape=concat_vectors([n], flattened_tail))

    # Stride `quadrature_size` for `batch_size` number of times, so that each
    # batch member's ids index into its own slice of the flattened rates.
    batch_strides = tf.range(
        start=0,
        limit=num_batch_members * self._quadrature_size,
        delta=self._quadrature_size,
        dtype=component_ids.dtype)
    flat_rate_index = component_ids + batch_strides

    flat_rates = tf.reshape(poisson_dist.rate_parameter(), shape=[-1])
    selected_rates = tf.gather(flat_rates, flat_rate_index)
    # Restore the [n] + batch_shape layout before sampling.
    full_shape = concat_vectors(
        [n], self._batch_shape_tensor(distributions=components))
    selected_rates = tf.reshape(selected_rates, shape=full_shape)

    return samplers.poisson(
        shape=[], lam=selected_rates, dtype=self.dtype, seed=seed_for_counts)
Exemplo n.º 4
0
 def _sample_n(self, n, seed=None):
     """Draws `n` Poisson samples at this distribution's rate.

     Args:
       n: Number of draws; passed to the sampler as `shape=[n]`.
       seed: Seed material forwarded unchanged to `samplers.poisson`.

     Returns:
       The output of `samplers.poisson`, with dtype `self.dtype`.
     """
     unchecked_rate = self._rate_parameter_no_checks()
     draw_shape = [n]
     return samplers.poisson(
         shape=draw_shape, lam=unchecked_rate, dtype=self.dtype, seed=seed)