コード例 #1
0
ファイル: conjugate.py プロジェクト: xidulu/numpyro
 def __init__(self, concentration1, concentration0, total_count=1, validate_args=None):
     """
     Set up the distribution parameters and the internal Beta prior.

     :param concentration1: 1st concentration parameter, passed to ``Beta``
         as its first argument.
     :param concentration0: 2nd concentration parameter, passed to ``Beta``
         as its second argument.
     :param total_count: number of trials (defaults to 1).
     :param validate_args: forwarded unchanged to the parent constructor.
     """
     # The stored attributes are whatever `promote_shapes` returns
     # (presumably rank-aligned views of the inputs -- confirm against the helper).
     promoted = promote_shapes(concentration1, concentration0, total_count)
     self.concentration1, self.concentration0, self.total_count = promoted

     # The batch shape is the broadcast of all three raw parameter shapes.
     param_shapes = [jnp.shape(param) for param in (concentration1, concentration0, total_count)]
     batch_shape = lax.broadcast_shapes(*param_shapes)

     # The Beta prior is built from concentrations expanded to the full batch shape.
     self._beta = Beta(
         jnp.broadcast_to(concentration1, batch_shape),
         jnp.broadcast_to(concentration0, batch_shape),
     )
     super(BetaBinomial, self).__init__(batch_shape, validate_args=validate_args)
コード例 #2
0
class BetaBinomial(Distribution):
    r"""
    Beta-binomial compound distribution.

    The success probability (``probs`` of the
    :class:`~numpyro.distributions.Binomial` distribution) is not fixed: it is
    first drawn from a :class:`~numpyro.distributions.Beta` distribution, and
    ``total_count`` Bernoulli trials are then carried out with that probability.

    :param numpy.ndarray concentration1: 1st concentration parameter (alpha) for the
        Beta distribution.
    :param numpy.ndarray concentration0: 2nd concentration parameter (beta) for the
        Beta distribution.
    :param numpy.ndarray total_count: number of Bernoulli trials.
    """
    arg_constraints = {
        'concentration1': constraints.positive,
        'concentration0': constraints.positive,
        'total_count': constraints.nonnegative_integer,
    }
    has_enumerate_support = True
    is_discrete = True
    # Support enumeration is reused directly from BinomialProbs.
    enumerate_support = BinomialProbs.enumerate_support

    def __init__(self, concentration1, concentration0, total_count=1, validate_args=None):
        """
        Set up the distribution parameters and the internal Beta prior.

        :param concentration1: 1st concentration parameter (alpha).
        :param concentration0: 2nd concentration parameter (beta).
        :param total_count: number of trials (defaults to 1).
        :param validate_args: forwarded unchanged to the parent constructor.
        """
        # The stored attributes are whatever `promote_shapes` returns
        # (presumably rank-aligned views of the inputs -- confirm against the helper).
        promoted = promote_shapes(concentration1, concentration0, total_count)
        self.concentration1, self.concentration0, self.total_count = promoted

        # The batch shape is the broadcast of all three raw parameter shapes.
        param_shapes = [jnp.shape(param) for param in (concentration1, concentration0, total_count)]
        batch_shape = lax.broadcast_shapes(*param_shapes)

        # The Beta prior is built from concentrations expanded to the full batch shape.
        self._beta = Beta(
            jnp.broadcast_to(concentration1, batch_shape),
            jnp.broadcast_to(concentration0, batch_shape),
        )
        super(BetaBinomial, self).__init__(batch_shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        """
        Draw samples in two stages: Beta success probabilities, then binomial counts.

        :param key: PRNG key; it is split so each stage gets its own sub-key.
        :param sample_shape: passed to the Beta draw only; the binomial draw
            is called without a sample shape.
        :return: the result of ``BinomialProbs(...).sample``.
        """
        assert is_prng_key(key)
        beta_key, binomial_key = random.split(key)
        success_probs = self._beta.sample(beta_key, sample_shape)
        binomial = BinomialProbs(total_count=self.total_count, probs=success_probs)
        return binomial.sample(binomial_key)

    @validate_sample
    def log_prob(self, value):
        """
        Evaluate the log-probability of ``value``.

        :param value: number of successes.
        :return: the counting term plus the difference of the two ``betaln`` terms.
        """
        failures = self.total_count - value
        # NOTE(review): `_log_beta_1` is defined elsewhere; its negation presumably
        # supplies the log binomial coefficient -- confirm against the helper.
        counting_term = -_log_beta_1(failures + 1, value)
        updated_log_beta = betaln(value + self.concentration1, failures + self.concentration0)
        prior_log_beta = betaln(self.concentration0, self.concentration1)
        return counting_term + updated_log_beta - prior_log_beta

    @property
    def mean(self):
        """Mean: the Beta prior's mean scaled by ``total_count``."""
        expected_prob = self._beta.mean
        return expected_prob * self.total_count

    @property
    def variance(self):
        """Variance: Beta variance times ``total_count`` times (concentration sum + ``total_count``)."""
        concentration_sum = self.concentration0 + self.concentration1
        return self._beta.variance * self.total_count * (concentration_sum + self.total_count)

    @property
    def support(self):
        """Support constraint: an integer interval built from 0 and ``total_count``."""
        upper = self.total_count
        return constraints.integer_interval(0, upper)
コード例 #3
0
class BetaBinomial(Distribution):
    r"""
    Beta-binomial compound distribution.

    The success probability (``probs`` of the
    :class:`~numpyro.distributions.Binomial` distribution) is not fixed: it is
    first drawn from a :class:`~numpyro.distributions.Beta` distribution, and
    ``total_count`` Bernoulli trials are then carried out with that probability.

    :param numpy.ndarray concentration1: 1st concentration parameter (alpha) for the
        Beta distribution.
    :param numpy.ndarray concentration0: 2nd concentration parameter (beta) for the
        Beta distribution.
    :param numpy.ndarray total_count: number of Bernoulli trials.
    """
    arg_constraints = {
        'concentration1': constraints.positive,
        'concentration0': constraints.positive,
        'total_count': constraints.nonnegative_integer,
    }
    has_enumerate_support = True
    is_discrete = True

    def __init__(self, concentration1, concentration0, total_count=1, validate_args=None):
        """
        Broadcast the parameters to a common batch shape and build the Beta prior.

        :param concentration1: 1st concentration parameter (alpha).
        :param concentration0: 2nd concentration parameter (beta).
        :param total_count: number of trials (defaults to 1).
        :param validate_args: forwarded unchanged to the parent constructor.
        """
        param_shapes = (
            np.shape(concentration1),
            np.shape(concentration0),
            np.shape(total_count),
        )
        batch_shape = lax.broadcast_shapes(*param_shapes)

        # Concentrations are stored fully broadcast to the batch shape.
        self.concentration1 = np.broadcast_to(concentration1, batch_shape)
        self.concentration0 = np.broadcast_to(concentration0, batch_shape)
        # `promote_shapes` yields a one-element sequence here; unpack its single item.
        (self.total_count,) = promote_shapes(total_count, shape=batch_shape)

        self._beta = Beta(self.concentration1, self.concentration0)
        super(BetaBinomial, self).__init__(batch_shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        """
        Draw samples in two stages: Beta success probabilities, then binomial counts.

        :param key: PRNG key; it is split so each stage gets its own sub-key.
        :param sample_shape: passed to the Beta draw only; the binomial draw
            is called without a sample shape.
        :return: the result of ``Binomial(...).sample``.
        """
        beta_key, binomial_key = random.split(key)
        success_probs = self._beta.sample(beta_key, sample_shape)
        binomial = Binomial(self.total_count, success_probs)
        return binomial.sample(binomial_key)

    @validate_sample
    def log_prob(self, value):
        """
        Evaluate the log-probability of ``value``.

        :param value: number of successes.
        :return: log binomial coefficient (via ``gammaln``) plus the difference
            of the two ``betaln`` terms.
        """
        trials = self.total_count
        failures = trials - value
        # log C(n, k) written as log n! - log k! - log (n - k)! using gammaln(x + 1).
        log_binomial_coeff = gammaln(trials + 1) - gammaln(value + 1) - gammaln(failures + 1)
        updated_log_beta = betaln(value + self.concentration1, failures + self.concentration0)
        prior_log_beta = betaln(self.concentration0, self.concentration1)
        return log_binomial_coeff + updated_log_beta - prior_log_beta

    @property
    def mean(self):
        """Mean: the Beta prior's mean scaled by ``total_count``."""
        expected_prob = self._beta.mean
        return expected_prob * self.total_count

    @property
    def variance(self):
        """Variance: Beta variance times ``total_count`` times (concentration sum + ``total_count``)."""
        concentration_sum = self.concentration0 + self.concentration1
        return self._beta.variance * self.total_count * (concentration_sum + self.total_count)

    @property
    def support(self):
        """Support constraint: an integer interval built from 0 and ``total_count``."""
        upper = self.total_count
        return constraints.integer_interval(0, upper)

    def enumerate_support(self, expand=True):
        """
        List the values ``0 .. max(total_count)`` along a new leading axis.

        :param bool expand: when true, broadcast the values across the batch
            shape; otherwise every batch dimension is left at size 1.
        :return: array whose first axis indexes the enumerated values.
        :raises NotImplementedError: if ``total_count`` is not the same for
            every batch element (only detectable outside of tracing).
        """
        max_count = np.amax(self.total_count)
        # NB: the error can't be raised if inhomogeneous issue happens when tracing
        if not_jax_tracer(max_count) and np.amin(self.total_count) != max_count:
            raise NotImplementedError(
                "Inhomogeneous total count not supported"
                " by `enumerate_support`.")

        # One leading axis for the enumerated values, then a singleton per batch dimension.
        column_shape = (-1,) + (1,) * len(self.batch_shape)
        values = np.arange(max_count + 1).reshape(column_shape)
        if not expand:
            return values
        return np.broadcast_to(values, values.shape[:1] + self.batch_shape)