Esempio n. 1
0
    def __init__(self, concentration, total_count=1, validate_args=None):
        """Wrap a Dirichlet over ``concentration`` with ``total_count`` trials per batch entry.

        :param concentration: Dirichlet concentration. Its last axis is kept as the
            event (category) axis, so at least one dimension is required.
        :param total_count: number of trials; broadcast against the leading
            (batch) axes of ``concentration``.
        :param validate_args: forwarded unchanged to the base-class constructor.
        :raises ValueError: if ``concentration`` is zero-dimensional.
        """
        if jnp.ndim(concentration) == 0:
            raise ValueError("`concentration` parameter must be at least one-dimensional.")

        # Leading axes of ``concentration`` broadcast with ``total_count``; the
        # trailing axis is carried through as-is.
        category_shape = jnp.shape(concentration)[-1:]
        batch_shape = lax.broadcast_shapes(jnp.shape(concentration)[:-1], jnp.shape(total_count))
        full_shape = batch_shape + category_shape

        [self.concentration] = promote_shapes(concentration, shape=full_shape)
        [self.total_count] = promote_shapes(total_count, shape=batch_shape)

        # Give the Dirichlet an explicitly broadcast parameter; its batch and
        # event shapes become this distribution's shapes.
        dirichlet = Dirichlet(jnp.broadcast_to(self.concentration, full_shape))
        self._dirichlet = dirichlet
        super().__init__(
            dirichlet.batch_shape, dirichlet.event_shape, validate_args=validate_args)
 def sample(self, key, sample_shape=()):
     """Draw samples from a Normal mixture with Dirichlet-distributed weights.

     Mixture weights are drawn from ``Dirichlet(self.weights)``. A component
     index is then drawn from ``Categorical`` with those weights. The matching
     ``locs``/``scales`` entries are selected along the last axis, and a
     ``Normal`` with those parameters is sampled.

     Each of the three stages uses its own subkey split from ``key``. Reusing one
     key for every draw would make the weight, component and value draws
     correlated.

     :param key: PRNG key, split with ``jax.random.split``.
     :param tuple sample_shape: leading shape prepended to the broadcast shape.
     :return: samples whose shape is ``sample_shape + batch_shape + event_shape
         + mixture_shape`` with the last axis removed (assumes ``weights``,
         ``locs`` and ``scales`` broadcast consistently — TODO confirm).
     """
     from jax import random  # local import: keeps this method self-contained

     key_weights, key_component, key_value = random.split(key, 3)
     ps = Dirichlet(self.weights).sample(key_weights, sample_shape=sample_shape)
     # Component index per draw, with a trailing singleton axis for take_along_axis.
     zs = np.expand_dims(Categorical(ps).sample(key_component), axis=-1)
     full_shape = sample_shape + self.batch_shape + self.event_shape + self.mixture_shape
     locs = np.broadcast_to(self.locs, full_shape)
     scales = np.broadcast_to(self.scales, full_shape)
     # Pick the chosen component's parameters along the last axis, then drop that axis.
     mlocs = np.squeeze(np.take_along_axis(locs, zs, axis=-1), axis=-1)
     mscales = np.squeeze(np.take_along_axis(scales, zs, axis=-1), axis=-1)
     return Normal(mlocs, mscales).sample(key_value)
Esempio n. 3
0
class DirichletMultinomial(Distribution):
    r"""
    Compound distribution formed by a Dirichlet-multinomial pair.

    The class probabilities (``probs`` of a
    :class:`~numpyro.distributions.Multinomial` distribution) are not fixed.
    They are first drawn from a :class:`~numpyro.distributions.Dirichlet`
    distribution. Then ``total_count`` Categorical trials are run with those
    probabilities.

    :param numpy.ndarray concentration: concentration parameter (alpha) of the
        Dirichlet distribution. It must be at least one-dimensional, and its
        last axis indexes the categories.
    :param numpy.ndarray total_count: number of Categorical trials.
    :param validate_args: forwarded unchanged to the base
        :class:`Distribution` constructor.
    """
    arg_constraints = dict(
        concentration=constraints.positive,
        total_count=constraints.nonnegative_integer,
    )
    is_discrete = True

    def __init__(self, concentration, total_count=1, validate_args=None):
        # The last axis holds the categories, so a scalar cannot be accepted.
        if jnp.ndim(concentration) == 0:
            raise ValueError("`concentration` parameter must be at least one-dimensional.")

        # Leading axes of ``concentration`` broadcast with ``total_count``; the
        # trailing category axis is carried through as-is.
        category_shape = jnp.shape(concentration)[-1:]
        batch_shape = lax.broadcast_shapes(jnp.shape(concentration)[:-1], jnp.shape(total_count))
        full_shape = batch_shape + category_shape

        [self.concentration] = promote_shapes(concentration, shape=full_shape)
        [self.total_count] = promote_shapes(total_count, shape=batch_shape)

        # Give the Dirichlet an explicitly broadcast parameter; its batch and
        # event shapes become this distribution's shapes.
        dirichlet = Dirichlet(jnp.broadcast_to(self.concentration, full_shape))
        self._dirichlet = dirichlet
        super().__init__(
            dirichlet.batch_shape, dirichlet.event_shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        assert is_prng_key(key)
        # Separate subkeys so the probability draw and the count draw are independent.
        dirichlet_key, multinomial_key = random.split(key)
        class_probs = self._dirichlet.sample(dirichlet_key, sample_shape)
        multinomial = MultinomialProbs(total_count=self.total_count, probs=class_probs)
        return multinomial.sample(multinomial_key)

    @validate_sample
    def log_prob(self, value):
        alpha = self.concentration
        # Term built from the totals: summed concentration and number of trials.
        totals_term = _log_beta_1(alpha.sum(-1), value.sum(-1))
        # Per-category terms, reduced over the category axis.
        category_terms = _log_beta_1(alpha, value).sum(-1)
        return totals_term - category_terms

    @property
    def mean(self):
        # Dirichlet mean scaled by the number of trials (trailing axis added to broadcast).
        trials = jnp.expand_dims(self.total_count, -1)
        return self._dirichlet.mean * trials

    @property
    def variance(self):
        trials = jnp.expand_dims(self.total_count, -1)
        conc = self.concentration
        conc_total = conc.sum(-1, keepdims=True)
        mean_probs = conc / conc_total
        # n * p * (1 - p), inflated by (n + alpha_0) / (1 + alpha_0).
        binomial_var = trials * mean_probs * (1 - mean_probs)
        inflation_numer = trials + conc_total
        return binomial_var * inflation_numer / (1 + conc_total)

    @property
    def support(self):
        # Delegated to ``constraints.multinomial`` parameterized by ``total_count``.
        return constraints.multinomial(self.total_count)