Example #1
0
 def codomain(self):
     """Codomain of the composed transform.

     Takes the codomain of the last part. If the composition's output event
     dimension is larger than that codomain's own ``event_dim``, the codomain is
     wrapped with ``constraints.independent`` using the difference.
     """
     out_dim = _get_compose_transform_output_event_dim(self.parts)
     last_codomain = self.parts[-1].codomain
     extra_dims = out_dim - last_codomain.event_dim
     assert extra_dims >= 0
     if extra_dims == 0:
         return last_codomain
     return constraints.independent(last_codomain, extra_dims)
Example #2
0
 def support(self):
     """Support of the distribution, derived from the last transform's codomain.

     When ``self.event_dim`` exceeds the codomain's ``event_dim``, the codomain
     is wrapped with ``independent`` using the difference.
     """
     base = self.transforms[-1].codomain
     extra_dims = self.event_dim - base.event_dim
     assert extra_dims >= 0
     if extra_dims == 0:
         return base
     return independent(base, extra_dims)
Example #3
0
 def domain(self):
     """Domain of the composed transform.

     Takes the domain of the first part. If the composition's input event
     dimension is larger than that domain's own ``event_dim``, the domain is
     wrapped with ``constraints.independent`` using the difference.
     """
     in_dim = _get_compose_transform_input_event_dim(self.parts)
     first_domain = self.parts[0].domain
     extra_dims = in_dim - first_domain.event_dim
     assert extra_dims >= 0
     if extra_dims == 0:
         return first_domain
     return constraints.independent(first_domain, extra_dims)
Example #4
0
 def codomain(self):
     """Base transform's codomain wrapped with ``constraints.independent``.

     ``self.reinterpreted_batch_ndims`` is passed as the number of dimensions
     to reinterpret.
     """
     base_codomain = self.base_transform.codomain
     return constraints.independent(base_codomain,
                                    self.reinterpreted_batch_ndims)
Example #5
0
class DirichletMultinomial(Distribution):
    r"""
    Compound distribution consisting of a Dirichlet-multinomial pair. The class
    probabilities (``probs`` for the :class:`~numpyro.distributions.Multinomial`
    distribution) are unknown and are drawn from a
    :class:`~numpyro.distributions.Dirichlet` distribution before ``total_count``
    Categorical trials are performed.

    :param numpy.ndarray concentration: concentration parameter (alpha) for the
        Dirichlet distribution; must be at least one-dimensional.
    :param numpy.ndarray total_count: number of Categorical trials.
    """
    arg_constraints = {
        'concentration': constraints.independent(constraints.positive, 1),
        'total_count': constraints.nonnegative_integer
    }
    is_discrete = True

    def __init__(self, concentration, total_count=1, validate_args=None):
        if jnp.ndim(concentration) < 1:
            raise ValueError(
                "`concentration` parameter must be at least one-dimensional.")

        conc_shape = jnp.shape(concentration)
        # Batch dims come from broadcasting the non-event part of
        # `concentration` against `total_count`; the last axis is the event.
        batch_shape = lax.broadcast_shapes(conc_shape[:-1],
                                           jnp.shape(total_count))
        full_shape = batch_shape + conc_shape[-1:]
        (self.concentration,) = promote_shapes(concentration, shape=full_shape)
        (self.total_count,) = promote_shapes(total_count, shape=batch_shape)
        self._dirichlet = Dirichlet(
            jnp.broadcast_to(self.concentration, full_shape))
        super().__init__(
            self._dirichlet.batch_shape,
            self._dirichlet.event_shape,
            validate_args=validate_args,
        )

    def sample(self, key, sample_shape=()):
        """Draw class probabilities from the Dirichlet, then counts from a multinomial."""
        assert is_prng_key(key)
        dirichlet_key, multinomial_key = random.split(key)
        probs = self._dirichlet.sample(dirichlet_key, sample_shape)
        multinomial = MultinomialProbs(total_count=self.total_count,
                                       probs=probs)
        return multinomial.sample(multinomial_key)

    @validate_sample
    def log_prob(self, value):
        """Log-probability of a count vector ``value`` (summed over the last axis)."""
        alpha = self.concentration
        total_term = _log_beta_1(alpha.sum(-1), value.sum(-1))
        per_class_term = _log_beta_1(alpha, value).sum(-1)
        return total_term - per_class_term

    @property
    def mean(self):
        """Dirichlet mean scaled by ``total_count`` (broadcast over the event axis)."""
        dirichlet_mean = self._dirichlet.mean
        n = jnp.expand_dims(self.total_count, -1)
        return dirichlet_mean * n

    @property
    def variance(self):
        """Per-class variance ``n p (1 - p) (n + A) / (1 + A)`` with ``p = alpha / A``."""
        n = jnp.expand_dims(self.total_count, -1)
        alpha = self.concentration
        total_alpha = self.concentration.sum(-1, keepdims=True)
        p = alpha / total_alpha
        return n * p * (1 - p) * (n + total_alpha) / (1 + total_alpha)

    @constraints.dependent_property(is_discrete=True, event_dim=1)
    def support(self):
        """Multinomial support parameterized by ``total_count``."""
        return constraints.multinomial(self.total_count)

    @staticmethod
    def infer_shapes(concentration, total_count=()):
        """Return ``(batch_shape, event_shape)`` for the given parameter shapes."""
        event_shape = concentration[-1:]
        batch_shape = lax.broadcast_shapes(concentration[:-1], total_count)
        return batch_shape, event_shape
Example #6
0
class SineBivariateVonMises(Distribution):
    r"""Unimodal distribution of two dependent angles on the 2-torus
    (:math:`S^1 \otimes S^1`) given by

    .. math::
        C^{-1}\exp(\kappa_1\cos(x_1-\mu_1) + \kappa_2\cos(x_2 -\mu_2) + \rho\sin(x_1 - \mu_1)\sin(x_2 - \mu_2))

    and

    .. math::
        C = (2\pi)^2 \sum_{i=0}^{\infty} {2i \choose i}
        \left(\frac{\rho^2}{4\kappa_1\kappa_2}\right)^i I_i(\kappa_1)I_i(\kappa_2),

    where :math:`I_i(\cdot)` is the modified Bessel function of the first kind, the mu's are the locations of the
    distribution, the kappa's are the concentrations and rho gives the correlation between angles :math:`x_1` and
    :math:`x_2`. This distribution is helpful for modeling coupled angles such as torsion angles in peptide chains.

    To infer parameters, use :class:`~numpyro.infer.hmc.NUTS` or :class:`~numpyro.infer.hmc.HMC` with priors that
    avoid parameterizations where the distribution becomes bimodal; see note below.

    .. note:: Sample efficiency drops as

        .. math::
            \frac{\rho}{\sqrt{\kappa_1\kappa_2}} \rightarrow 1

        because the distribution becomes increasingly bimodal. To avoid bimodality use the `weighted_correlation`
        parameter with a skew away from one (e.g., Beta(1,3)). The `weighted_correlation` should be in [0,1].

    .. note:: The correlation and weighted_correlation params are mutually exclusive.

    .. note:: In the context of :class:`~numpyro.infer.svi.SVI`, this distribution can be used as a likelihood but not
        for latent variables.

    ** References: **
        1. Probabilistic model for two dependent circular variables Singh, H., Hnizdo, V., and Demchuk, E. (2002)

    :param np.ndarray phi_loc: location of first angle
    :param np.ndarray psi_loc: location of second angle
    :param np.ndarray phi_concentration: concentration of first angle
    :param np.ndarray psi_concentration: concentration of second angle
    :param np.ndarray correlation: correlation between the two angles (any real value, including negative)
    :param np.ndarray weighted_correlation: set correlation to weighted_correlation * sqrt(phi_conc*psi_conc)
        to avoid bimodality (see note). The `weighted_correlation` should be in [0,1].
    """

    arg_constraints = {
        "phi_loc": constraints.circular,
        "psi_loc": constraints.circular,
        "phi_concentration": constraints.positive,
        "psi_concentration": constraints.positive,
        "correlation": constraints.real,
    }
    support = constraints.independent(constraints.circular, 1)
    # Upper bound on the number of rejection-sampling rounds in `_phi_marginal`.
    max_sample_iter = 1000

    def __init__(
        self,
        phi_loc,
        psi_loc,
        phi_concentration,
        psi_concentration,
        correlation=None,
        weighted_correlation=None,
        validate_args=None,
    ):
        assert (correlation is None) != (weighted_correlation is None), (
            "Exactly one of `correlation` and `weighted_correlation` must be given.")

        if weighted_correlation is not None:
            # NOTE(review): the 1e-8 offset presumably guarded the former
            # log(correlation) in `norm_const`; kept for backward compatibility.
            correlation = (weighted_correlation *
                           jnp.sqrt(phi_concentration * psi_concentration) +
                           1e-8)

        (
            self.phi_loc,
            self.psi_loc,
            self.phi_concentration,
            self.psi_concentration,
            self.correlation,
        ) = promote_shapes(phi_loc, psi_loc, phi_concentration,
                           psi_concentration, correlation)
        batch_shape = lax.broadcast_shapes(
            jnp.shape(phi_loc),
            jnp.shape(psi_loc),
            jnp.shape(phi_concentration),
            jnp.shape(psi_concentration),
            jnp.shape(correlation),
        )
        super().__init__(batch_shape, (2, ), validate_args)

    @lazy_property
    def norm_const(self):
        """Log of the normalizing constant ``C``, reshaped to the shape of ``phi_loc``.

        The series in the class docstring is truncated to its first 50 terms
        (``i = 0, ..., 49``) and summed in log space.
        """
        corr = jnp.reshape(self.correlation, (1, -1))
        conc = jnp.stack((self.phi_concentration, self.psi_concentration),
                         axis=-1).reshape(-1, 2)
        m = jnp.arange(50).reshape(-1, 1)
        num = special.gammaln(2 * m + 1.0)
        den = special.gammaln(m + 1.0)
        lbinoms = num - 2 * den  # log of binomial(2m, m)

        # The series depends on rho only through rho**2. Using log(rho**2) instead
        # of 2 * log(rho) keeps negative correlations finite (log of a negative
        # number is NaN). Clamping at the smallest normal float makes rho == 0
        # safe: the m == 0 term is then 0 * finite == 0.
        tiny = jnp.finfo(jnp.result_type(float)).tiny
        fs = (lbinoms.reshape(-1, 1) +
              m * jnp.log(jnp.maximum(corr**2, tiny)) -
              m * jnp.log(4 * jnp.prod(conc, axis=-1)))
        fs += log_I1(49, conc, terms=51).sum(-1)
        # logsumexp is numerically stabilized per column; no manual max shift needed.
        norm_const = 2 * jnp.log(jnp.array(2 * pi)) + logsumexp(fs, 0)
        return norm_const.reshape(jnp.shape(self.phi_loc))

    @validate_sample
    def log_prob(self, value):
        """Unnormalized log-density from the class docstring minus ``norm_const``.

        ``value[..., 0]`` is the first angle (phi) and ``value[..., 1]`` the second (psi).
        """
        indv = self.phi_concentration * jnp.cos(
            value[..., 0] - self.phi_loc) + self.psi_concentration * jnp.cos(
                value[..., 1] - self.psi_loc)
        corr = (self.correlation * jnp.sin(value[..., 0] - self.phi_loc) *
                jnp.sin(value[..., 1] - self.psi_loc))
        return indv + corr - self.norm_const

    def sample(self, key, sample_shape=()):
        """Draw samples of shape ``sample_shape + batch_shape + (2,)``.

        phi is drawn from its marginal by rejection sampling (see ``_phi_marginal``).
        psi is then drawn from a von Mises distribution conditioned on phi. Both
        angles are shifted by their locations and wrapped into ``[-pi, pi)``.

        ** References: **
            1. A New Unified Approach for the Simulation of a Wide Class of Directional Distributions
               John T. Kent, Asaad M. Ganeiber & Kanti V. Mardia (2018)
        """
        assert is_prng_key(key)
        phi_key, psi_key = random.split(key)

        corr = self.correlation
        conc = jnp.stack((self.phi_concentration, self.psi_concentration))

        eig = 0.5 * (conc[0] - corr**2 / conc[1])
        eig = jnp.stack((jnp.zeros_like(eig), eig))
        # Shift eigenvalues so the smallest is zero; the shift is passed on as `eigmin`.
        eigmin = jnp.where(eig[1] < 0, eig[1],
                           jnp.zeros_like(eig[1], dtype=eig.dtype))
        eig = eig - eigmin
        b0 = self._bfind(eig)

        total = _numel(sample_shape)
        phi_den = log_I1(0, conc[1]).squeeze(0)
        batch_size = _numel(self.batch_shape)
        phi_shape = (total, 2, batch_size)
        phi_state = SineBivariateVonMises._phi_marginal(
            phi_shape,
            phi_key,
            jnp.reshape(conc, (2, batch_size)),
            jnp.reshape(corr, (batch_size, )),
            jnp.reshape(eig, (2, batch_size)),
            jnp.reshape(b0, (batch_size, )),
            jnp.reshape(eigmin, (batch_size, )),
            jnp.reshape(phi_den, (batch_size, )),
        )

        # Convert the accepted unit vectors (axis 1 = (cos, sin)) into angles.
        phi = jnp.arctan2(phi_state.phi[:, 1:], phi_state.phi[:, :1])

        alpha = jnp.sqrt(conc[1]**2 + (corr * jnp.sin(phi))**2)
        beta = jnp.arctan(corr / conc[1] * jnp.sin(phi))

        psi = VonMises(beta, alpha).sample(psi_key)

        phi_psi = jnp.concatenate(
            (
                (phi + self.phi_loc + pi) % (2 * pi) - pi,
                (psi + self.psi_loc + pi) % (2 * pi) - pi,
            ),
            axis=1,
        )
        phi_psi = jnp.transpose(phi_psi, (0, 2, 1))
        return phi_psi.reshape(*sample_shape, *self.batch_shape,
                               *self.event_shape)

    @staticmethod
    def _phi_marginal(shape, rng_key, conc, corr, eig, b0, eigmin, phi_den):
        """Rejection-sample unit vectors for phi with an Angular Central Gaussian proposal.

        Runs rounds until every entry has been accepted at least once or
        ``max_sample_iter`` rounds have elapsed. Returns the final ``PhiMarginalState``.

        NOTE(review): entries never accepted within ``max_sample_iter`` rounds keep
        their initial ``jnp.empty`` value.
        """
        conc = jnp.broadcast_to(conc, shape)
        eig = jnp.broadcast_to(eig, shape)
        b0 = jnp.broadcast_to(b0, shape)
        eigmin = jnp.broadcast_to(eigmin, shape)
        phi_den = jnp.broadcast_to(phi_den, shape)

        def update_fn(curr):
            i, done, phi, key = curr
            phi_key, key = random.split(key)
            accept_key, acg_key, phi_key = random.split(phi_key, 3)

            x = jnp.sqrt(1 + 2 * eig / b0) * random.normal(acg_key, shape)
            x /= jnp.linalg.norm(
                x, axis=1,
                keepdims=True)  # Angular Central Gaussian distribution

            # Log target density (up to constants) at the proposal ...
            lf = (conc[:, :1] * (x[:, :1] - 1) + eigmin +
                  log_I1(0, jnp.sqrt(conc[:, 1:]**2 +
                                     (corr * x[:, 1:])**2)).squeeze(0) -
                  phi_den)
            assert lf.shape == shape

            # ... and the log of the inverse envelope.
            lg_inv = (1.0 - b0 / 2 +
                      jnp.log(b0 / 2 + (eig * x**2).sum(1, keepdims=True)))
            assert lg_inv.shape == lf.shape

            accepted = random.uniform(accept_key, shape) < jnp.exp(lf + lg_inv)

            phi = jnp.where(accepted, x, phi)
            return PhiMarginalState(i + 1, done | accepted, phi, key)

        def cond_fn(curr):
            return jnp.bitwise_and(
                curr.i < SineBivariateVonMises.max_sample_iter,
                jnp.logical_not(jnp.all(curr.done)),
            )

        phi_state = while_loop(
            cond_fn,
            update_fn,
            PhiMarginalState(
                i=jnp.array(0),
                done=jnp.zeros(shape, dtype=bool),
                phi=jnp.empty(shape, dtype=float),
                key=rng_key,
            ),
        )
        return PhiMarginalState(phi_state.i, phi_state.done, phi_state.phi,
                                phi_state.key)

    @property
    def mean(self):
        """Computes circular mean of distribution. Note: same as location when mapped to support [-pi, pi]"""
        mean = (jnp.stack((self.phi_loc, self.psi_loc), axis=-1) +
                jnp.pi) % (2.0 * jnp.pi) - jnp.pi
        return jnp.broadcast_to(mean, (*self.batch_shape, 2))

    def _bfind(self, eig):
        """Compute the envelope parameter ``b`` for the ACG proposal.

        Starts at ``b = eig.shape[0] / 2`` and applies one update ``b - g1 / g2``,
        where ``g2`` is the derivative of ``g1`` with respect to ``b``. Returns the
        starting value where all eigenvalues along axis 0 are zero.
        """
        b = eig.shape[0] / 2 * jnp.ones(self.batch_shape, dtype=eig.dtype)
        g1 = jnp.sum(1 / (b + 2 * eig)**2, axis=0)
        g2 = jnp.sum(-2 / (b + 2 * eig)**3, axis=0)
        return jnp.where(jnp.linalg.norm(eig, axis=0) != 0, b - g1 / g2, b)
Example #7
0
class SineSkewed(Distribution):
    r"""Sine-skewing [1] is a procedure for producing a distribution that breaks pointwise symmetry on a torus
    distribution. The new distribution is called the Sine Skewed X distribution, where X is the name of the (symmetric)
    base distribution. Torus distributions are distributions with support on products of circles
    (i.e., :math:`\otimes S^1` where :math:`S^1 = [-\pi,\pi)`).
    So, a 0-torus is a point, the 1-torus is a circle,
    and the 2-torus is commonly associated with the donut shape.

    The sine skewed X distribution is parameterized by a weight parameter for each dimension of the event of X.
    For example with a von Mises distribution over a circle (1-torus), the sine skewed von Mises distribution has one
    skew parameter. The skewness parameters can be inferred using :class:`~numpyro.infer.HMC` or
    :class:`~numpyro.infer.NUTS`. For example, the following will produce a prior over
    skewness for the 2-torus,::

        @numpyro.handlers.reparam(config={'phi_loc': CircularReparam(), 'psi_loc': CircularReparam()})
        def model(obs):
            # Sine priors
            phi_loc = numpyro.sample('phi_loc', VonMises(pi, 2.))
            psi_loc = numpyro.sample('psi_loc', VonMises(-pi / 2, 2.))
            phi_conc = numpyro.sample('phi_conc', Beta(1., 1.))
            psi_conc = numpyro.sample('psi_conc', Beta(1., 1.))
            corr_scale = numpyro.sample('corr_scale', Beta(2., 5.))

            # Skewing prior
            ball_trans = L1BallTransform()
            skewness = numpyro.sample('skew_phi', Normal(0, 0.5).expand((2,)))
            skewness = ball_trans(skewness)  # constraint sum |skewness_i| <= 1

            with numpyro.plate('obs_plate'):
                sine = SineBivariateVonMises(phi_loc=phi_loc, psi_loc=psi_loc,
                                             phi_concentration=70 * phi_conc,
                                             psi_concentration=70 * psi_conc,
                                             weighted_correlation=corr_scale)
                return numpyro.sample('phi_psi', SineSkewed(sine, skewness), obs=obs)

    To ensure the skewing does not alter the normalization constant of the (sine bivariate von Mises) base
    distribution, the skewness parameters are constrained. The constraint requires the sum of the absolute values of
    skewness to be less than or equal to one. We can use the :class:`~numpyro.distributions.transforms.L1BallTransform`
    to achieve this.

    In the context of :class:`~numpyro.infer.svi.SVI`, this distribution can freely be used as a likelihood, but using
    it for latent variables will lead to slow inference for 2- and higher-dimensional tori. This is because the
    base_dist cannot be reparameterized.

    .. note:: An event in the base distribution must be on a d-torus, so the event_shape must be `(d,)`.

    .. note:: For the skewness parameter, it must hold that the sum of the absolute value of its weights for an event
        must be less than or equal to one. See eq. 2.1 in [1].

    ** References: **
        1. Sine-skewed toroidal distributions and their application in protein bioinformatics
            Ameijeiras-Alonso, J., Ley, C. (2019)

    :param numpyro.distributions.Distribution base_dist: base density on a d-dimensional torus. Supported base
        distributions include: 1D :class:`~numpyro.distributions.VonMises`,
        :class:`~numpyro.distributions.SineBivariateVonMises`, 1D :class:`~numpyro.distributions.ProjectedNormal`,
        and :class:`~numpyro.distributions.Uniform` (-pi, pi).
    :param jax.numpy.array skewness: skewness of the distribution.
    """

    arg_constraints = {"skewness": constraints.l1_ball}

    support = constraints.independent(constraints.circular, 1)

    def __init__(self, base_dist: Distribution, skewness, validate_args=None):
        assert (
            base_dist.event_shape == skewness.shape[-1:]
        ), "Sine Skewing is only valid with a skewness parameter for each dimension of `base_dist.event_shape`."

        batch_shape = jnp.broadcast_shapes(base_dist.batch_shape,
                                           skewness.shape[:-1])
        event_shape = skewness.shape[-1:]
        self.skewness = jnp.broadcast_to(skewness, batch_shape + event_shape)
        self.base_dist = base_dist.expand(batch_shape)
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def __repr__(self):
        # Scalar parameters are shown by value, array parameters by shape.
        # `jnp.size`/`jnp.shape` are used because JAX arrays have no PyTorch-style
        # `numel()` method, and their `size` is an int attribute rather than a method.
        parts = []
        for name in self.arg_constraints:
            value = getattr(self, name)
            shown = value if jnp.size(value) == 1 else jnp.shape(value)
            parts.append("{}: {}".format(name, shown))
        args_string = ", ".join(parts)
        return (self.__class__.__name__ + "(" +
                f"base_density: {str(self.base_dist)}, " + args_string + ")")

    def sample(self, key, sample_shape=()):
        """Draw samples of shape ``sample_shape + batch_shape + event_shape``.

        Each base sample ``y`` is kept with probability
        ``0.5 + 0.5 * sum_i skewness_i * sin(y_i - mean_i)``. Otherwise it is
        reflected about the base mean (``2 * mean - y``). The result is wrapped
        into ``[-pi, pi)``.
        """
        base_key, skew_key = random.split(key)
        bd = self.base_dist
        ys = bd.sample(base_key, sample_shape)
        u = random.uniform(skew_key, sample_shape + self.batch_shape)

        # Section 2.3 step 3 in [1]
        mask = u <= 0.5 + 0.5 * (self.skewness * jnp.sin(
            (ys - bd.mean) % (2 * jnp.pi))).sum(-1)
        mask = mask[..., None]
        samples = (jnp.where(mask, ys, -ys + 2 * bd.mean) +
                   jnp.pi) % (2 * jnp.pi) - jnp.pi
        return samples

    def log_prob(self, value):
        """Base log-density plus ``log1p(sum_i skewness_i * sin(value_i - mean_i))``."""
        if self._validate_args:
            self._validate_sample(value)
        if self.base_dist._validate_args:
            self.base_dist._validate_sample(value)

        # Eq. 2.1 in [1]
        skew_prob = jnp.log1p((self.skewness * jnp.sin(
            (value - self.base_dist.mean) % (2 * jnp.pi))).sum(-1))
        return self.base_dist.log_prob(value) + skew_prob

    @property
    def mean(self):
        """Mean of the base distribution"""
        return self.base_dist.mean
Example #8
0
 def support(self):
     """Real-valued support wrapped with ``independent`` over ``self.event_dim`` dims."""
     ndims = self.event_dim
     return independent(real, ndims)
Example #9
0
 def support(self):
     """Base distribution's support wrapped with ``independent``.

     ``self.reinterpreted_batch_ndims`` is passed as the number of dimensions
     to reinterpret.
     """
     base_support = self.base_dist.support
     return independent(base_support, self.reinterpreted_batch_ndims)
Example #10
0
 def __init__(self, support, batch_shape, event_shape, validate_args=None):
     """Store ``support`` widened to cover all of ``event_shape``, then initialize the base.

     The support is wrapped with ``independent`` using
     ``len(event_shape) - support.event_dim`` extra dimensions.
     """
     extra_event_dims = len(event_shape) - support.event_dim
     self.support = independent(support, extra_event_dims)
     super().__init__(batch_shape, event_shape, validate_args=validate_args)
Example #11
0
class SineSkewed(Distribution):
    """The Sine Skewed distribution [1] is a distribution for breaking pointwise-symmetry on a base-distribution over
    the d-dimensional torus defined as ⨂^d S^1 where S^1 is the circle. So for example the 0-torus is a point, the
    1-torus is a circle and the 2-torus is commonly associated with the donut shape (some may object to this simile).

    The skewness parameter can be inferred using :class:`~numpyro.infer.HMC` or :class:`~numpyro.infer.NUTS`.
    For example, the following will produce a uniform prior over skewness for the 2-torus,::

        def model(...):
            ...
            skew_phi = numpyro.sample('skew_phi', Uniform(-1., 1.))
            psi_bound = 1 - jnp.abs(skew_phi)
            skew_psi = numpyro.sample('skew_psi', Uniform(-1., 1.))
            skewness = jnp.stack((skew_phi, psi_bound * skew_psi), axis=0)
            ...

    In the context of :class:`~numpyro.infer.svi.SVI`, this distribution can be freely used as a likelihood, but using
    it for latent variables will lead to slow inference for 2- and higher-dimensional tori. This is because the
    base_dist cannot be reparameterized.

    .. note:: An event in the base distribution must be on a d-torus, so the event_shape must be (d,).

    .. note:: For the skewness parameter, it must hold that the sum of the absolute value of its weights for an event
        must be less than or equal to one. See eq. 2.1 in [1].

    ** References: **
      1. Sine-skewed toroidal distributions and their application in protein bioinformatics
         Ameijeiras-Alonso, J., Ley, C. (2019)

    :param base_dist: base density on a d-dimensional torus.
    :param skewness: skewness of the distribution; the last axis holds one weight per event dimension.
    """

    arg_constraints = {
        "skewness": constraints.independent(constraints.interval(-1.0, 1.0), 1)
    }

    support = constraints.independent(constraints.real, 1)

    def __init__(self, base_dist: Distribution, skewness, validate_args=None):
        batch_shape = jnp.broadcast_shapes(base_dist.batch_shape,
                                           skewness.shape[:-1])
        event_shape = skewness.shape[-1:]
        self.skewness = jnp.broadcast_to(skewness, batch_shape + event_shape)
        self.base_dist = base_dist.expand(batch_shape)
        super().__init__(batch_shape, event_shape, validate_args=validate_args)
        # No device-consistency check here: a PyTorch-style `.device` comparison
        # is not portable across JAX arrays, NumPy arrays and Python scalars.

    def __repr__(self):
        # Scalar parameters are shown by value, array parameters by shape.
        # `jnp.size`/`jnp.shape` are used because JAX arrays have no PyTorch-style
        # `numel()` method, and their `size` is an int attribute rather than a method.
        parts = []
        for name in self.arg_constraints:
            value = getattr(self, name)
            shown = value if jnp.size(value) == 1 else jnp.shape(value)
            parts.append("{}: {}".format(name, shown))
        args_string = ", ".join(parts)
        return (self.__class__.__name__ + "(" +
                f"base_density: {str(self.base_dist)}, " + args_string + ")")

    def sample(self, key, sample_shape=()):
        """Draw samples of shape ``sample_shape + batch_shape + event_shape``.

        Each base sample ``y`` is kept with probability
        ``0.5 + 0.5 * sum_i skewness_i * sin(y_i - mean_i)``. Otherwise it is
        reflected about the base mean (``2 * mean - y``). The result is wrapped
        into ``[-pi, pi)``.
        """
        base_key, skew_key = random.split(key)
        bd = self.base_dist
        ys = bd.sample(base_key, sample_shape)
        u = random.uniform(skew_key, sample_shape + self.batch_shape)

        # Section 2.3 step 3 in [1]
        mask = u <= 0.5 + 0.5 * (self.skewness * jnp.sin(
            (ys - bd.mean) % (2 * pi))).sum(-1)
        mask = mask[..., None]
        samples = (jnp.where(mask, ys, -ys + 2 * bd.mean) + pi) % (2 * pi) - pi
        return samples

    def log_prob(self, value):
        """Base log-density plus ``log1p(sum_i skewness_i * sin(value_i - mean_i))``."""
        if self._validate_args:
            self._validate_sample(value)

        # Eq. 2.1 in [1]; log1p is more accurate than log(1 + x) for small skew terms.
        skew_prob = jnp.log1p((self.skewness *
                               jnp.sin((value - self.base_dist.mean) %
                                       (2 * pi))).sum(-1))
        return self.base_dist.log_prob(value) + skew_prob