Ejemplo n.º 1
0
 def tree_unflatten(cls, aux_data, params):
     """Rebuild a ``low``/``high``-bounded wrapper distribution from its pytree form.

     Two flattened layouts are accepted:

     * ``aux_data == (base_cls, base_aux)`` with
       ``params == (base_flatten, low, high)``: the bounds travel as
       pytree leaves (``params``).
     * ``aux_data == (base_cls, base_aux, low, high)`` with
       ``params == base_flatten``: the bounds travel as static aux data.

     :param aux_data: static (non-leaf) data recorded when flattening.
     :param params: pytree leaves recorded when flattening.
     :return: ``cls(base_dist, low=low, high=high)``, where ``base_dist`` is
         rebuilt from ``base_aux`` / ``base_flatten``.
     """
     if len(aux_data) == 2:
         base_flatten, low, high = params
         base_cls, base_aux = aux_data
     else:
         base_flatten = params
         base_cls, base_aux, low, high = aux_data
     # Rebuild the base with the class recorded at flatten time instead of a
     # hard-coded ``Gamma``. Otherwise a base of any other class would be
     # reconstructed as the wrong type, with its leaves bound to the wrong
     # parameter names.
     base_dist = base_cls.tree_unflatten(base_aux, base_flatten)
     return cls(base_dist, low=low, high=high)
Ejemplo n.º 2
0
class GammaPoisson(Distribution):
    r"""
    Compound distribution comprising a gamma-poisson pair, also referred to as
    a gamma-poisson mixture. The ``rate`` parameter for the
    :class:`~numpyro.distributions.Poisson` distribution is unknown and randomly
    drawn from a :class:`~numpyro.distributions.Gamma` distribution.

    :param numpy.ndarray concentration: shape parameter (alpha) of the Gamma distribution.
    :param numpy.ndarray rate: rate parameter (beta) for the Gamma distribution.
    :param validate_args: forwarded to the ``Distribution`` base class.
    """

    arg_constraints = {
        "concentration": constraints.positive,
        "rate": constraints.positive,
    }
    support = constraints.nonnegative_integer

    def __init__(self, concentration, rate=1.0, validate_args=None):
        self.concentration, self.rate = promote_shapes(concentration, rate)
        # Mixing distribution over the Poisson rate; its batch shape becomes
        # the batch shape of this distribution.
        self._gamma = Gamma(concentration, rate)
        super(GammaPoisson, self).__init__(
            self._gamma.batch_shape, validate_args=validate_args
        )

    def sample(self, key, sample_shape=()):
        """Draw counts by sampling a rate from the Gamma, then a Poisson count.

        :param key: JAX PRNG key; split into one key per sampling stage.
        :param tuple sample_shape: leading shape of the returned samples.
        :return: integer-valued counts.
        """
        assert is_prng_key(key)
        key_gamma, key_poisson = random.split(key)
        rate = self._gamma.sample(key_gamma, sample_shape)
        # ``rate`` was drawn with ``sample_shape`` already applied, so one
        # Poisson variate per rate is enough.
        return Poisson(rate).sample(key_poisson)

    @validate_sample
    def log_prob(self, value):
        r"""Log probability mass of ``value``.

        Computed in closed form as

        .. math::
            -\log B(\alpha, k + 1) - \log(\alpha + k)
            + \alpha \log \beta - (\alpha + k) \log(1 + \beta)

        with :math:`\alpha` = ``concentration``, :math:`\beta` = ``rate`` and
        :math:`k` = ``value``.
        """
        post_value = self.concentration + value
        return (
            -betaln(self.concentration, value + 1)
            - jnp.log(post_value)
            + self.concentration * jnp.log(self.rate)
            - post_value * jnp.log1p(self.rate)
        )

    @property
    def mean(self):
        """Mean ``concentration / rate`` (equal to the mean of the mixing Gamma)."""
        return self.concentration / self.rate

    @property
    def variance(self):
        """Variance ``concentration / rate**2 * (1 + rate)``."""
        return self.concentration / jnp.square(self.rate) * (1 + self.rate)

    def cdf(self, value):
        """Cumulative distribution function ``P(X <= value)``.

        Evaluated as the regularized incomplete beta function
        ``I_p(concentration, floor(value) + 1)`` with ``p = rate / (rate + 1)``.
        Because the distribution is discrete, the result is a step function:
        it is constant between consecutive integers and zero for ``value < 0``.
        """
        k = jnp.floor(value)
        # betainc needs a strictly positive second shape argument. Clamp
        # before evaluating so negative inputs do not produce NaN, which could
        # otherwise leak into gradients through the unselected jnp.where branch.
        safe_k = jnp.maximum(k, 0.0)
        probs = self.rate / (self.rate + 1.0)
        bt = betainc(self.concentration, safe_k + 1.0, probs)
        return jnp.where(value < 0, 0.0, bt)
Ejemplo n.º 3
0
 def __init__(self, concentration, rate=1., validate_args=None):
     """Initialize the gamma-poisson mixture.

     :param concentration: shape parameter (alpha) of the mixing Gamma.
     :param rate: rate parameter (beta) of the mixing Gamma.
     :param validate_args: forwarded to the base-class ``__init__``.
     """
     mixing = Gamma(concentration, rate)
     self._gamma = mixing
     # Keep the parameters exactly as held by the Gamma instance (not the raw
     # arguments), so they always match what ``self._gamma`` uses.
     self.concentration = mixing.concentration
     self.rate = mixing.rate
     batch_shape = mixing.batch_shape
     super(GammaPoisson, self).__init__(batch_shape, validate_args=validate_args)