class BetaBinomial(Distribution):
    r"""
    Compound distribution consisting of a beta-binomial pair. The probability of
    success (``probs`` for the :class:`~numpyro.distributions.Binomial` distribution)
    is unknown and is randomly drawn from a :class:`~numpyro.distributions.Beta`
    distribution before running a number of Bernoulli trials given by ``total_count``.

    :param numpy.ndarray concentration1: 1st concentration parameter (alpha) for the
        Beta distribution.
    :param numpy.ndarray concentration0: 2nd concentration parameter (beta) for the
        Beta distribution.
    :param numpy.ndarray total_count: number of Bernoulli trials.
    """
    arg_constraints = {
        'concentration1': constraints.positive,
        'concentration0': constraints.positive,
        'total_count': constraints.nonnegative_integer,
    }
    has_enumerate_support = True
    is_discrete = True
    # Reuse the Binomial enumeration as a plain function bound to this class; it
    # runs against this instance's attributes (presumably ``total_count`` and
    # ``batch_shape`` — NOTE(review): confirm against BinomialProbs).
    enumerate_support = BinomialProbs.enumerate_support

    def __init__(self, concentration1, concentration0, total_count=1, validate_args=None):
        # Store the parameters with promoted ranks.
        self.concentration1, self.concentration0, self.total_count = promote_shapes(
            concentration1, concentration0, total_count
        )
        batch_shape = lax.broadcast_shapes(
            *(jnp.shape(param) for param in (concentration1, concentration0, total_count))
        )
        # The latent Beta receives fully broadcast concentrations so its batch
        # shape matches this distribution's batch shape.
        self._beta = Beta(
            jnp.broadcast_to(concentration1, batch_shape),
            jnp.broadcast_to(concentration0, batch_shape),
        )
        super(BetaBinomial, self).__init__(batch_shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        """Draw success probabilities from the Beta, then run the Binomial trials."""
        assert is_prng_key(key)
        beta_key, binomial_key = random.split(key)
        success_probs = self._beta.sample(beta_key, sample_shape)
        trials = BinomialProbs(total_count=self.total_count, probs=success_probs)
        return trials.sample(binomial_key)

    @validate_sample
    def log_prob(self, value):
        """
        Log-pmf: log C(n, k) + log B(k + alpha, n - k + beta) - log B(alpha, beta).

        The binomial-coefficient term is computed through the ``_log_beta_1``
        helper (defined elsewhere in the module).
        """
        n = self.total_count
        log_binom_coef = -_log_beta_1(n - value + 1, value)
        log_beta_num = betaln(value + self.concentration1, n - value + self.concentration0)
        log_beta_den = betaln(self.concentration0, self.concentration1)
        return log_binom_coef + log_beta_num - log_beta_den

    @property
    def mean(self):
        # n * E[p] with p ~ Beta(alpha, beta)
        return self._beta.mean * self.total_count

    @property
    def variance(self):
        # Var[p] * n * (alpha + beta + n)
        spread = self.concentration0 + self.concentration1 + self.total_count
        return self._beta.variance * self.total_count * spread

    @property
    def support(self):
        # Integers from 0 to total_count inclusive.
        return constraints.integer_interval(0, self.total_count)
class BetaBinomial(Distribution):
    r"""
    Compound distribution consisting of a beta-binomial pair. The probability of
    success (``probs`` for the :class:`~numpyro.distributions.Binomial` distribution)
    is unknown and is randomly drawn from a :class:`~numpyro.distributions.Beta`
    distribution before running a number of Bernoulli trials given by ``total_count``.

    :param numpy.ndarray concentration1: 1st concentration parameter (alpha) for the
        Beta distribution.
    :param numpy.ndarray concentration0: 2nd concentration parameter (beta) for the
        Beta distribution.
    :param numpy.ndarray total_count: number of Bernoulli trials.
    """
    arg_constraints = {
        'concentration1': constraints.positive,
        'concentration0': constraints.positive,
        'total_count': constraints.nonnegative_integer,
    }
    has_enumerate_support = True
    is_discrete = True

    def __init__(self, concentration1, concentration0, total_count=1, validate_args=None):
        batch_shape = lax.broadcast_shapes(
            *(np.shape(param) for param in (concentration1, concentration0, total_count))
        )
        # Both concentrations are stored fully broadcast to the batch shape.
        self.concentration1 = np.broadcast_to(concentration1, batch_shape)
        self.concentration0 = np.broadcast_to(concentration0, batch_shape)
        (self.total_count,) = promote_shapes(total_count, shape=batch_shape)
        self._beta = Beta(self.concentration1, self.concentration0)
        super(BetaBinomial, self).__init__(batch_shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        """Draw success probabilities from the Beta, then run the Binomial trials."""
        beta_key, binomial_key = random.split(key)
        success_probs = self._beta.sample(beta_key, sample_shape)
        return Binomial(self.total_count, success_probs).sample(binomial_key)

    @validate_sample
    def log_prob(self, value):
        """Log-pmf: log C(n, k) + log B(k + alpha, n - k + beta) - log B(alpha, beta)."""
        n = self.total_count
        # log C(n, k) = log n! - log k! - log (n - k)!
        log_coef = gammaln(n + 1) - gammaln(value + 1) - gammaln(n - value + 1)
        return (log_coef
                + betaln(value + self.concentration1, n - value + self.concentration0)
                - betaln(self.concentration0, self.concentration1))

    @property
    def mean(self):
        # n * E[p] with p ~ Beta(alpha, beta)
        return self._beta.mean * self.total_count

    @property
    def variance(self):
        # Var[p] * n * (alpha + beta + n)
        spread = self.concentration0 + self.concentration1 + self.total_count
        return self._beta.variance * self.total_count * spread

    @property
    def support(self):
        # Integers from 0 to total_count inclusive.
        return constraints.integer_interval(0, self.total_count)

    def enumerate_support(self, expand=True):
        """
        Return the values ``0 .. max(total_count)`` along a new leading axis.

        The values are reshaped to broadcast against ``batch_shape``. When
        ``expand`` is true, they are broadcast to ``(n + 1,) + batch_shape``.

        :raises NotImplementedError: if ``total_count`` is not the same across the
            batch (only detectable on concrete, non-traced values).
        """
        max_count = np.amax(self.total_count)
        # NB: the inhomogeneity check is skipped while tracing.
        if not_jax_tracer(max_count) and np.amin(self.total_count) != max_count:
            raise NotImplementedError(
                "Inhomogeneous total count not supported"
                " by `enumerate_support`.")
        support_shape = (-1,) + (1,) * len(self.batch_shape)
        values = np.arange(max_count + 1).reshape(support_shape)
        if not expand:
            return values
        return np.broadcast_to(values, values.shape[:1] + self.batch_shape)