class RejectionStandardGamma(Rejector):
    """
    Naive Marsaglia & Tsang rejection sampler for standard Gamma distibution.
    This assumes `concentration >= 1` and does not boost `concentration` or augment shape.
    """
    def __init__(self, concentration):
        if concentration.data.min() < 1:
            raise NotImplementedError('concentration < 1 is not supported')
        self.concentration = concentration
        self._standard_gamma = Gamma(
            concentration,
            concentration.new([1.]).squeeze().expand_as(concentration))
        # The following are Marsaglia & Tsang's variable names.
        self._d = self.concentration - 1.0 / 3.0
        self._c = 1.0 / torch.sqrt(9.0 * self._d)
        # Compute log scale using Gamma.log_prob().
        x = self._d.detach()  # just an arbitrary x.
        log_scale = self.propose_log_prob(x) + self.log_prob_accept(
            x) - self.log_prob(x)
        super().__init__(self.propose,
                         self.log_prob_accept,
                         log_scale,
                         batch_shape=concentration.shape,
                         event_shape=())

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(RejectionStandardGamma, _instance)
        batch_shape = torch.Size(batch_shape)
        new.concentration = self.concentration.expand(batch_shape)
        new._standard_gamma = self._standard_gamma.expand(batch_shape)
        new._d = self._d.expand(batch_shape)
        new._c = self._c.expand(batch_shape)
        # Compute log scale using Gamma.log_prob().
        x = new._d.detach()  # just an arbitrary x.
        log_scale = new.propose_log_prob(x) + new.log_prob_accept(
            x) - new.log_prob(x)
        super(RejectionStandardGamma, new).__init__(new.propose,
                                                    new.log_prob_accept,
                                                    log_scale,
                                                    batch_shape=batch_shape,
                                                    event_shape=())
        new._validate_args = self._validate_args
        return new

    @weakmethod
    def propose(self, sample_shape=torch.Size()):
        # Marsaglia & Tsang's x == Naesseth's epsilon`
        x = torch.randn(sample_shape + self.concentration.shape,
                        dtype=self.concentration.dtype,
                        device=self.concentration.device)
        y = 1.0 + self._c * x
        v = y * y * y
        return (self._d * v).clamp_(1e-30, 1e30)

    def propose_log_prob(self, value):
        v = value / self._d
        result = -self._d.log()
        y = v.pow(1 / 3)
        result -= torch.log(3 * y**2)
        x = (y - 1) / self._c
        result -= self._c.log()
        result += Normal(torch.zeros_like(self.concentration),
                         torch.ones_like(self.concentration)).log_prob(x)
        return result

    @weakmethod
    def log_prob_accept(self, value):
        v = value / self._d
        y = torch.pow(v, 1.0 / 3.0)
        x = (y - 1.0) / self._c
        log_prob_accept = 0.5 * x * x + self._d * (1.0 - v + torch.log(v))
        log_prob_accept[y <= 0] = -float('inf')
        return log_prob_accept

    def log_prob(self, x):
        return self._standard_gamma.log_prob(x)
Exemple #2
0
class GammaPoisson(TorchDistribution):
    r"""
    Compound distribution comprising of a gamma-poisson pair, also referred to as
    a gamma-poisson mixture. The ``rate`` parameter for the
    :class:`~pyro.distributions.Poisson` distribution is unknown and randomly
    drawn from a :class:`~pyro.distributions.Gamma` distribution.

    .. note:: This can be treated as an alternate parametrization of the
        :class:`~pyro.distributions.NegativeBinomial` (``total_count``, ``probs``)
        distribution, with `concentration = total_count` and `rate = (1 - probs) / probs`.

    :param float or torch.Tensor concentration: shape parameter (alpha) of the Gamma
        distribution.
    :param float or torch.Tensor rate: rate parameter (beta) for the Gamma
        distribution.
    """

    arg_constraints = {
        'concentration': constraints.positive,
        'rate': constraints.positive
    }
    support = Poisson.support

    def __init__(self, concentration, rate, validate_args=None):
        concentration, rate = broadcast_all(concentration, rate)
        self._gamma = Gamma(concentration, rate)
        super(GammaPoisson, self).__init__(self._gamma._batch_shape,
                                           validate_args=validate_args)

    @property
    def concentration(self):
        return self._gamma.concentration

    @property
    def rate(self):
        return self._gamma.rate

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(GammaPoisson, _instance)
        batch_shape = torch.Size(batch_shape)
        new._gamma = self._gamma.expand(batch_shape)
        super(GammaPoisson, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def sample(self, sample_shape=()):
        rate = self._gamma.sample(sample_shape)
        return Poisson(rate).sample()

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        post_value = self.concentration + value
        return -_log_beta(self.concentration, value + 1) - post_value.log() + \
            self.concentration * self.rate.log() - post_value * (1 + self.rate).log()

    @property
    def mean(self):
        return self.concentration / self.rate

    @property
    def variance(self):
        return self.concentration / self.rate.pow(2) * (1 + self.rate)