def codomain(self):
    """Return the constraint describing the output space of the composition.

    The codomain of the last part is used directly when its event dimension
    already equals the event dimension of the whole composition; otherwise
    it is wrapped in ``constraints.independent`` to absorb the missing
    trailing dimensions.
    """
    composed_dim = _get_compose_transform_output_event_dim(self.parts)
    tail_codomain = self.parts[-1].codomain
    tail_dim = tail_codomain.event_dim
    # The composition can never expose fewer event dims than its last part.
    assert composed_dim >= tail_dim
    if composed_dim != tail_dim:
        return constraints.independent(tail_codomain, composed_dim - tail_dim)
    return tail_codomain
def support(self):
    """Return the support induced by the last transform in ``self.transforms``.

    The last transform's codomain is returned unchanged when its event
    dimension equals ``self.event_dim``; otherwise the remaining dimensions
    are reinterpreted as event dimensions via ``independent``.
    """
    final_codomain = self.transforms[-1].codomain
    missing_dims = self.event_dim - final_codomain.event_dim
    # The distribution's event_dim must cover the codomain's event_dim.
    assert missing_dims >= 0
    if missing_dims == 0:
        return final_codomain
    return independent(final_codomain, missing_dims)
def domain(self):
    """Return the constraint describing the input space of the composition.

    The domain of the first part is used directly when its event dimension
    already equals the input event dimension of the whole composition;
    otherwise it is wrapped in ``constraints.independent`` to absorb the
    missing trailing dimensions.
    """
    composed_dim = _get_compose_transform_input_event_dim(self.parts)
    head_domain = self.parts[0].domain
    head_dim = head_domain.event_dim
    # The composition can never consume fewer event dims than its first part.
    assert composed_dim >= head_dim
    if composed_dim != head_dim:
        return constraints.independent(head_domain, composed_dim - head_dim)
    return head_domain
def codomain(self):
    """Return the base transform's codomain with extra event dimensions.

    ``self.reinterpreted_batch_ndims`` dimensions are reinterpreted as event
    dimensions through ``constraints.independent``.
    """
    base_codomain = self.base_transform.codomain
    extra_event_dims = self.reinterpreted_batch_ndims
    return constraints.independent(base_codomain, extra_event_dims)
class DirichletMultinomial(Distribution):
    r"""
    Dirichlet-multinomial compound distribution.

    The class probabilities (``probs`` of the
    :class:`~numpyro.distributions.Multinomial` distribution) are unknown and
    are drawn from a :class:`~numpyro.distributions.Dirichlet` distribution
    before ``total_count`` Categorical trials are performed.

    :param numpy.ndarray concentration: concentration parameter (alpha) for the
        Dirichlet distribution.
    :param numpy.ndarray total_count: number of Categorical trials.
    """

    arg_constraints = {
        "concentration": constraints.independent(constraints.positive, 1),
        "total_count": constraints.nonnegative_integer,
    }
    is_discrete = True

    def __init__(self, concentration, total_count=1, validate_args=None):
        if jnp.ndim(concentration) < 1:
            raise ValueError(
                "`concentration` parameter must be at least one-dimensional."
            )
        raw_shape = jnp.shape(concentration)
        # Batch dims come from broadcasting the leading dims of
        # `concentration` against `total_count`; the last axis indexes classes.
        batch_shape = lax.broadcast_shapes(raw_shape[:-1], jnp.shape(total_count))
        full_shape = batch_shape + raw_shape[-1:]
        (self.concentration,) = promote_shapes(concentration, shape=full_shape)
        (self.total_count,) = promote_shapes(total_count, shape=batch_shape)
        self._dirichlet = Dirichlet(jnp.broadcast_to(self.concentration, full_shape))
        super().__init__(
            self._dirichlet.batch_shape,
            self._dirichlet.event_shape,
            validate_args=validate_args,
        )

    def sample(self, key, sample_shape=()):
        """Draw class probabilities from the Dirichlet, then multinomial counts."""
        assert is_prng_key(key)
        key_dirichlet, key_multinom = random.split(key)
        drawn_probs = self._dirichlet.sample(key_dirichlet, sample_shape)
        multinomial = MultinomialProbs(total_count=self.total_count, probs=drawn_probs)
        return multinomial.sample(key_multinom)

    @validate_sample
    def log_prob(self, value):
        """Evaluate the log-probability of the count vector ``value``."""
        alpha = self.concentration
        total_term = _log_beta_1(alpha.sum(-1), value.sum(-1))
        per_class_term = _log_beta_1(alpha, value).sum(-1)
        return total_term - per_class_term

    @property
    def mean(self):
        """Dirichlet mean scaled by the number of trials."""
        return self._dirichlet.mean * jnp.expand_dims(self.total_count, -1)

    @property
    def variance(self):
        """Per-class variance of the counts."""
        n = jnp.expand_dims(self.total_count, -1)
        alpha_sum = self.concentration.sum(-1, keepdims=True)
        ratio = self.concentration / alpha_sum
        return n * ratio * (1 - ratio) * (n + alpha_sum) / (1 + alpha_sum)

    @constraints.dependent_property(is_discrete=True, event_dim=1)
    def support(self):
        """Multinomial support whose upper bound is ``total_count``."""
        return constraints.multinomial(self.total_count)

    @staticmethod
    def infer_shapes(concentration, total_count=()):
        """Infer ``(batch_shape, event_shape)`` from the parameter shapes."""
        return (
            lax.broadcast_shapes(concentration[:-1], total_count),
            concentration[-1:],
        )
class SineBivariateVonMises(Distribution):
    r"""Unimodal distribution of two dependent angles on the 2-torus
    (:math:`S^1 \otimes S^1`) given by

    .. math::
        C^{-1}\exp(\kappa_1\cos(x_1-\mu_1) +
        \kappa_2\cos(x_2 -\mu_2) +
        \rho\sin(x_1 - \mu_1)\sin(x_2 - \mu_2))

    and

    .. math::
        C = (2\pi)^2 \sum_{i=0} {2i \choose i}
        \left(\frac{\rho^2}{4\kappa_1\kappa_2}\right)^i I_i(\kappa_1)I_i(\kappa_2),

    where :math:`I_i(\cdot)` is the modified bessel function of first kind, mu's are the locations of the distribution,
    kappa's are the concentration and rho gives the correlation between angles :math:`x_1` and :math:`x_2`.
    This distribution is helpful for modeling coupled angles such as torsion angles in peptide chains.

    To infer parameters, use :class:`~numpyro.infer.hmc.NUTS` or :class:`~numpyro.infer.hmc.HMC` with priors that
    avoid parameterizations where the distribution becomes bimodal; see note below.

    .. note:: Sample efficiency drops as

        .. math::
            \frac{\rho}{\kappa_1\kappa_2} \rightarrow 1

        because the distribution becomes increasingly bimodal. To avoid bimodality use the `weighted_correlation`
        parameter with a skew away from one (e.g., Beta(1,3)). The `weighted_correlation` should be in [0,1].

    .. note:: The correlation and weighted_correlation params are mutually exclusive.

    .. note:: In the context of :class:`~numpyro.infer.svi.SVI`, this distribution can be used as a likelihood but not
        for latent variables.

    ** References: **
        1. Probabilistic model for two dependent circular variables Singh, H., Hnizdo, V., and Demchuck, E. (2002)

    :param np.ndarray phi_loc: location of first angle
    :param np.ndarray psi_loc: location of second angle
    :param np.ndarray phi_concentration: concentration of first angle
    :param np.ndarray psi_concentration: concentration of second angle
    :param np.ndarray correlation: correlation between the two angles
    :param np.ndarray weighted_correlation: set correlation to weighted_corr * sqrt(phi_conc*psi_conc)
        to avoid bimodality (see note). The `weighted_correlation` should be in [0,1].
    """

    arg_constraints = {
        "phi_loc": constraints.circular,
        "psi_loc": constraints.circular,
        "phi_concentration": constraints.positive,
        "psi_concentration": constraints.positive,
        "correlation": constraints.real,
    }
    support = constraints.independent(constraints.circular, 1)
    # Upper bound on the number of rejection-sampling rounds in `_phi_marginal`.
    max_sample_iter = 1000

    def __init__(
        self,
        phi_loc,
        psi_loc,
        phi_concentration,
        psi_concentration,
        correlation=None,
        weighted_correlation=None,
        validate_args=None,
    ):
        """Store the parameters and derive the batch shape.

        Exactly one of ``correlation`` and ``weighted_correlation`` must be
        given; an ``AssertionError`` is raised otherwise. The event shape is
        always ``(2,)``.
        """
        # Exactly one of the two correlation parameterizations may be used.
        assert (correlation is None) != (weighted_correlation is None)

        if weighted_correlation is not None:
            # Scale by the geometric mean of the concentrations; the 1e-8
            # offset is presumably there to keep the correlation strictly
            # non-zero -- TODO confirm.
            correlation = (weighted_correlation *
                           jnp.sqrt(phi_concentration * psi_concentration) +
                           1e-8)

        (
            self.phi_loc,
            self.psi_loc,
            self.phi_concentration,
            self.psi_concentration,
            self.correlation,
        ) = promote_shapes(phi_loc, psi_loc, phi_concentration,
                           psi_concentration, correlation)
        batch_shape = lax.broadcast_shapes(
            jnp.shape(phi_loc),
            jnp.shape(psi_loc),
            jnp.shape(phi_concentration),
            jnp.shape(psi_concentration),
            jnp.shape(correlation),
        )
        super().__init__(batch_shape, (2, ), validate_args)

    @lazy_property
    def norm_const(self):
        """Log normalizing constant, subtracted from the score in ``log_prob``.

        Evaluated in log space with ``logsumexp`` over the 50 series indices
        ``m = 0..49``. The result is reshaped to the shape of ``self.phi_loc``.
        """
        # NOTE(review): the 1e-8 offset presumably guards jnp.log(corr)
        # against a zero correlation -- confirm.
        corr = jnp.reshape(self.correlation, (1, -1)) + 1e-8
        conc = jnp.stack((self.phi_concentration, self.psi_concentration),
                         axis=-1).reshape(-1, 2)
        m = jnp.arange(50).reshape(-1, 1)
        # log of the central binomial coefficient (2m choose m).
        num = special.gammaln(2 * m + 1.0)
        den = special.gammaln(m + 1.0)
        lbinoms = num - 2 * den

        fs = (lbinoms.reshape(-1, 1) + 2 * m * jnp.log(corr) -
              m * jnp.log(4 * jnp.prod(conc, axis=-1)))
        fs += log_I1(49, conc, terms=51).sum(-1)
        # Subtract the global maximum before logsumexp and add it back after.
        mfs = fs.max()
        norm_const = 2 * jnp.log(jnp.array(2 * pi)) + mfs + logsumexp(
            fs - mfs, 0)
        return norm_const.reshape(jnp.shape(self.phi_loc))

    @validate_sample
    def log_prob(self, value):
        """Return the log density at ``value``.

        ``value[..., 0]`` is read as the first angle (phi) and
        ``value[..., 1]`` as the second angle (psi).
        """
        # Independent cosine terms for each angle.
        indv = self.phi_concentration * jnp.cos(
            value[..., 0] - self.phi_loc) + self.psi_concentration * jnp.cos(
                value[..., 1] - self.psi_loc)
        # Sine coupling term between the two angles.
        corr = (self.correlation * jnp.sin(value[..., 0] - self.phi_loc) *
                jnp.sin(value[..., 1] - self.psi_loc))
        return indv + corr - self.norm_const

    def sample(self, key, sample_shape=()):
        """Draw samples of shape ``sample_shape + batch_shape + event_shape``.

        The first angle is drawn with the rejection sampler in
        ``_phi_marginal``; the second angle is then drawn from a
        :class:`VonMises` whose parameters depend on the first angle. Both
        angles are shifted by their locations and wrapped with
        ``(x + pi) % (2 * pi) - pi``.

        ** References: **
            1. A New Unified Approach for the Simulation of a Wide Class of Directional Distributions
               John T. Kent, Asaad M. Ganeiber & Kanti V. Mardia (2018)
        """
        assert is_prng_key(key)
        phi_key, psi_key = random.split(key)

        corr = self.correlation
        conc = jnp.stack((self.phi_concentration, self.psi_concentration))

        eig = 0.5 * (conc[0] - corr**2 / conc[1])
        eig = jnp.stack((jnp.zeros_like(eig), eig))
        # Shift so that both entries of `eig` are non-negative.
        eigmin = jnp.where(eig[1] < 0, eig[1],
                           jnp.zeros_like(eig[1], dtype=eig.dtype))
        eig = eig - eigmin
        b0 = self._bfind(eig)

        total = _numel(sample_shape)
        phi_den = log_I1(0, conc[1]).squeeze(0)
        batch_size = _numel(self.batch_shape)
        # The sampler works on flattened sample and batch dimensions.
        phi_shape = (total, 2, batch_size)
        phi_state = SineBivariateVonMises._phi_marginal(
            phi_shape,
            phi_key,
            jnp.reshape(conc, (2, batch_size)),
            jnp.reshape(corr, (batch_size, )),
            jnp.reshape(eig, (2, batch_size)),
            jnp.reshape(b0, (batch_size, )),
            jnp.reshape(eigmin, (batch_size, )),
            jnp.reshape(phi_den, (batch_size, )),
        )

        # Convert the 2-d vector held in axis 1 into an angle.
        phi = jnp.arctan2(phi_state.phi[:, 1:], phi_state.phi[:, :1])

        alpha = jnp.sqrt(conc[1]**2 + (corr * jnp.sin(phi))**2)
        beta = jnp.arctan(corr / conc[1] * jnp.sin(phi))

        psi = VonMises(beta, alpha).sample(psi_key)

        phi_psi = jnp.concatenate(
            (
                (phi + self.phi_loc + pi) % (2 * pi) - pi,
                (psi + self.psi_loc + pi) % (2 * pi) - pi,
            ),
            axis=1,
        )
        phi_psi = jnp.transpose(phi_psi, (0, 2, 1))
        return phi_psi.reshape(*sample_shape, *self.batch_shape,
                               *self.event_shape)

    @staticmethod
    def _phi_marginal(shape, rng_key, conc, corr, eig, b0, eigmin, phi_den):
        """Rejection-sample 2-d vectors representing the first angle.

        Runs a ``while_loop`` until every entry of ``done`` is set or
        ``max_sample_iter`` iterations have elapsed, and returns the final
        ``PhiMarginalState``.
        """
        conc = jnp.broadcast_to(conc, shape)
        eig = jnp.broadcast_to(eig, shape)
        b0 = jnp.broadcast_to(b0, shape)
        eigmin = jnp.broadcast_to(eigmin, shape)
        phi_den = jnp.broadcast_to(phi_den, shape)

        def update_fn(curr):
            # One proposal/accept round over the whole `shape`.
            i, done, phi, key = curr
            phi_key, key = random.split(key)
            accept_key, acg_key, phi_key = random.split(phi_key, 3)

            x = jnp.sqrt(1 + 2 * eig / b0) * random.normal(acg_key, shape)
            x /= jnp.linalg.norm(
                x, axis=1, keepdims=True)  # Angular Central Gaussian distribution

            lf = (conc[:, :1] * (x[:, :1] - 1) + eigmin + log_I1(
                0, jnp.sqrt(conc[:, 1:]**2 +
                            (corr * x[:, 1:])**2)).squeeze(0) - phi_den)
            assert lf.shape == shape

            lg_inv = (1.0 - b0 / 2 + jnp.log(b0 / 2 +
                                             (eig * x**2).sum(1, keepdims=True)))
            assert lg_inv.shape == lf.shape

            accepted = random.uniform(accept_key, shape) < jnp.exp(lf + lg_inv)

            # Accepted proposals overwrite the stored value, including at
            # positions already marked done in an earlier round.
            phi = jnp.where(accepted, x, phi)
            return PhiMarginalState(i + 1, done | accepted, phi, key)

        def cond_fn(curr):
            # Continue while under the iteration cap and something is pending.
            return jnp.bitwise_and(
                curr.i < SineBivariateVonMises.max_sample_iter,
                jnp.logical_not(jnp.all(curr.done)),
            )

        phi_state = while_loop(
            cond_fn,
            update_fn,
            PhiMarginalState(
                i=jnp.array(0),
                done=jnp.zeros(shape, dtype=bool),
                phi=jnp.empty(shape, dtype=float),
                key=rng_key,
            ),
        )
        return PhiMarginalState(phi_state.i, phi_state.done, phi_state.phi,
                                phi_state.key)

    @property
    def mean(self):
        """Computes circular mean of distribution. Note: same as location when mapped to support [-pi, pi]"""
        mean = (jnp.stack((self.phi_loc, self.psi_loc), axis=-1) +
                jnp.pi) % (2.0 * jnp.pi) - jnp.pi
        return jnp.broadcast_to(mean, (*self.batch_shape, 2))

    def _bfind(self, eig):
        """Compute the ``b0`` value used by the sampler's proposal in ``sample``.

        Starts from ``b = eig.shape[0] / 2`` and applies a single update
        ``b - g1 / g2``. Where the norm of ``eig`` along axis 0 is zero the
        starting value is returned unchanged.
        """
        b = eig.shape[0] / 2 * jnp.ones(self.batch_shape, dtype=eig.dtype)
        # g2 is the derivative of g1 with respect to b, so the update below
        # has the form of one Newton step -- NOTE(review): confirm intent.
        g1 = jnp.sum(1 / (b + 2 * eig)**2, axis=0)
        g2 = jnp.sum(-2 / (b + 2 * eig)**3, axis=0)
        return jnp.where(jnp.linalg.norm(eig, axis=0) != 0, b - g1 / g2, b)
class SineSkewed(Distribution):
    r"""Sine-skewing [1] is a procedure for producing a distribution that breaks pointwise symmetry on a torus
    distribution. The new distribution is called the Sine Skewed X distribution, where X is the name of the (symmetric)
    base distribution. Torus distributions are distributions with support on products of circles
    (i.e., :math:`\otimes S^1` where :math:`S^1 = [-pi,pi)`).
    So, a 0-torus is a point, the 1-torus is a circle,
    and the 2-torus is commonly associated with the donut shape.

    The sine skewed X distribution is parameterized by a weight parameter for each dimension of the event of X.
    For example with a von Mises distribution over a circle (1-torus), the sine skewed von Mises distribution has one
    skew parameter. The skewness parameters can be inferred using :class:`~numpyro.infer.HMC` or
    :class:`~numpyro.infer.NUTS`. For example, the following will produce a prior over
    skewness for the 2-torus,::

        @numpyro.handlers.reparam(config={'phi_loc': CircularReparam(), 'psi_loc': CircularReparam()})
        def model(obs):
            # Sine priors
            phi_loc = numpyro.sample('phi_loc', VonMises(pi, 2.))
            psi_loc = numpyro.sample('psi_loc', VonMises(-pi / 2, 2.))
            phi_conc = numpyro.sample('phi_conc', Beta(1., 1.))
            psi_conc = numpyro.sample('psi_conc', Beta(1., 1.))
            corr_scale = numpyro.sample('corr_scale', Beta(2., 5.))

            # Skewing prior
            ball_trans = L1BallTransform()
            skewness = numpyro.sample('skew_phi', Normal(0, 0.5).expand((2,)))
            skewness = ball_trans(skewness)  # constraint sum |skewness_i| <= 1

            with numpyro.plate('obs_plate'):
                sine = SineBivariateVonMises(phi_loc=phi_loc, psi_loc=psi_loc,
                                             phi_concentration=70 * phi_conc,
                                             psi_concentration=70 * psi_conc,
                                             weighted_correlation=corr_scale)
                return numpyro.sample('phi_psi', SineSkewed(sine, skewness), obs=obs)

    To ensure the skewing does not alter the normalization constant of the (sine bivariate von Mises) base
    distribution the skewness parameters are constrained. The constraint requires the sum of the absolute values of
    skewness to be less than or equal to one. We can use the
    :class:`~numpyro.distributions.transforms.L1BallTransform` to achieve this.

    In the context of :class:`~numpyro.infer.SVI`, this distribution can freely be used as a likelihood, but using
    it for latent variables will lead to slow inference for 2 and higher dim toruses. This is because the base_dist
    cannot be reparameterized.

    .. note:: An event in the base distribution must be on a d-torus, so the event_shape must be `(d,)`.

    .. note:: For the skewness parameter, it must hold that the sum of the absolute value of its weights for an event
        must be less than or equal to one. See eq. 2.1 in [1].

    ** References: **
        1. Sine-skewed toroidal distributions and their application in protein bioinformatics
           Ameijeiras-Alonso, J., Ley, C. (2019)

    :param numpyro.distributions.Distribution base_dist: base density on a d-dimensional torus. Supported base
        distributions include: 1D :class:`~numpyro.distributions.VonMises`,
        :class:`~numpyro.distributions.SineBivariateVonMises`, 1D :class:`~numpyro.distributions.ProjectedNormal`,
        and :class:`~numpyro.distributions.Uniform` (-pi, pi).
    :param jax.numpy.array skewness: skewness of the distribution.
    """

    arg_constraints = {"skewness": constraints.l1_ball}
    support = constraints.independent(constraints.circular, 1)

    def __init__(self, base_dist: Distribution, skewness, validate_args=None):
        """Broadcast ``skewness`` against ``base_dist`` and store both.

        An ``AssertionError`` is raised when the last axis of ``skewness``
        does not match ``base_dist.event_shape``.
        """
        assert (
            base_dist.event_shape == skewness.shape[-1:]
        ), "Sine Skewing is only valid with a skewness parameter for each dimension of `base_dist.event_shape`."

        batch_shape = jnp.broadcast_shapes(base_dist.batch_shape, skewness.shape[:-1])
        event_shape = skewness.shape[-1:]
        self.skewness = jnp.broadcast_to(skewness, batch_shape + event_shape)
        self.base_dist = base_dist.expand(batch_shape)
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def __repr__(self):
        """Return a summary string: scalars by value, larger arrays by shape."""

        def describe(name):
            value = getattr(self, name)
            # JAX arrays have no torch-style `.numel()` / callable `.size()`;
            # `jnp.size` / `jnp.shape` also accept Python scalars.
            shown = value if jnp.size(value) == 1 else jnp.shape(value)
            return "{}: {}".format(name, shown)

        args_string = ", ".join(describe(p) for p in self.arg_constraints)
        return (
            self.__class__.__name__
            + "("
            + f"base_density: {str(self.base_dist)}, "
            + args_string
            + ")"
        )

    def sample(self, key, sample_shape=()):
        """Sample from the base distribution and reflect rejected draws.

        A base draw ``ys`` is kept when a uniform variate falls below
        ``0.5 + 0.5 * sum(skewness * sin(ys - mean))``; otherwise it is
        reflected about the base mean. The result is wrapped with
        ``(x + pi) % (2 * pi) - pi``.
        """
        base_key, skew_key = random.split(key)
        bd = self.base_dist
        ys = bd.sample(base_key, sample_shape)
        u = random.uniform(skew_key, sample_shape + self.batch_shape)

        # Section 2.3 step 3 in [1]
        mask = u <= 0.5 + 0.5 * (
            self.skewness * jnp.sin((ys - bd.mean) % (2 * jnp.pi))
        ).sum(-1)
        mask = mask[..., None]
        samples = (jnp.where(mask, ys, -ys + 2 * bd.mean) + jnp.pi) % (
            2 * jnp.pi
        ) - jnp.pi
        return samples

    def log_prob(self, value):
        """Return the base log density plus the log of the skewing factor."""
        if self._validate_args:
            self._validate_sample(value)

        if self.base_dist._validate_args:
            self.base_dist._validate_sample(value)

        # Eq. 2.1 in [1]
        skew_prob = jnp.log1p(
            (
                self.skewness * jnp.sin((value - self.base_dist.mean) % (2 * jnp.pi))
            ).sum(-1)
        )
        return self.base_dist.log_prob(value) + skew_prob

    @property
    def mean(self):
        """Mean of the base distribution"""
        return self.base_dist.mean
def support(self):
    """Return the ``real`` constraint with ``self.event_dim`` event dimensions."""
    event_dims = self.event_dim
    return independent(real, event_dims)
def support(self):
    """Return the base distribution's support with extra event dimensions.

    ``self.reinterpreted_batch_ndims`` dimensions are reinterpreted as event
    dimensions through ``independent``.
    """
    base_support = self.base_dist.support
    extra_event_dims = self.reinterpreted_batch_ndims
    return independent(base_support, extra_event_dims)
def __init__(self, support, batch_shape, event_shape, validate_args=None):
    """Store the support constraint and initialise the parent with the shapes.

    ``support`` is wrapped in ``independent`` so that its event dimension
    matches ``len(event_shape)``.
    """
    # Event dims required by `event_shape` beyond those `support` already has.
    extra_event_dims = len(event_shape) - support.event_dim
    self.support = independent(support, extra_event_dims)
    super().__init__(batch_shape, event_shape, validate_args=validate_args)
class SineSkewed(Distribution):
    """The Sine Skewed distribution [1] is a distribution for breaking pointwise-symmetry on a base-distribution over
    the d-dimensional torus defined as ⨂^d S^1 where S^1 is the circle. So for example the 0-torus is a point, the
    1-torus is a circle and the 2-torus is commonly associated with the donut shape (some may object to this simile).

    The skewness parameter can be inferred using :class:`~numpyro.infer.HMC` or :class:`~numpyro.infer.NUTS`.
    For example, the following will produce a uniform prior over skewness for the 2-torus,::

        def model(...):
            ...
            skew_phi = numpyro.sample('skew_phi', Uniform(-1., 1.))
            psi_bound = 1 - jnp.abs(skew_phi)
            skew_psi = numpyro.sample('skew_psi', Uniform(-1., 1.))
            skewness = jnp.stack((skew_phi, psi_bound * skew_psi), axis=0)
            ...

    In the context of :class:`~numpyro.infer.SVI`, this distribution can be freely used as a likelihood, but use as
    a latent variable will lead to slow inference for 2 and higher order toruses. This is because the base_dist
    cannot be reparameterized.

    .. note:: An event in the base distribution must be on a d-torus, so the event_shape must be (d,).

    .. note:: For the skewness parameter, it must hold that the sum of the absolute value of its weights for an event
        must be less than or equal to one. See eq. 2.1 in [1].

    ** References: **
        1. Sine-skewed toroidal distributions and their application in protein bioinformatics
           Ameijeiras-Alonso, J., Ley, C. (2019)

    :param base_dist: base density on a d-dimensional torus.
    :param skewness: skewness of the distribution.
    """

    arg_constraints = {
        "skewness": constraints.independent(constraints.interval(-1.0, 1.0), 1)
    }

    support = constraints.independent(constraints.real, 1)

    def __init__(self, base_dist: Distribution, skewness, validate_args=None):
        """Broadcast ``skewness`` against ``base_dist`` and store both.

        :raises ValueError: when validation is enabled and the base mean and
            ``skewness`` report different ``device`` attributes.
        """
        batch_shape = jnp.broadcast_shapes(base_dist.batch_shape, skewness.shape[:-1])
        event_shape = skewness.shape[-1:]
        self.skewness = jnp.broadcast_to(skewness, batch_shape + event_shape)
        self.base_dist = base_dist.expand(batch_shape)
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

        # NOTE(review): this assumes `.device` is a comparable attribute on
        # both arrays -- confirm against the JAX version in use.
        if self._validate_args and base_dist.mean.device != skewness.device:
            raise ValueError(
                f"base_density: {base_dist.__class__.__name__} and SineSkewed "
                f"must be on same device."
            )

    def __repr__(self):
        """Return a summary string: scalars by value, larger arrays by shape."""

        def describe(name):
            value = getattr(self, name)
            # JAX arrays have no torch-style `.numel()` / callable `.size()`;
            # `jnp.size` / `jnp.shape` also accept Python scalars.
            shown = value if jnp.size(value) == 1 else jnp.shape(value)
            return "{}: {}".format(name, shown)

        args_string = ", ".join(describe(p) for p in self.arg_constraints)
        return (
            self.__class__.__name__
            + "("
            + f"base_density: {str(self.base_dist)}, "
            + args_string
            + ")"
        )

    def sample(self, key, sample_shape=()):
        """Sample from the base distribution and reflect rejected draws.

        A base draw ``ys`` is kept when a uniform variate falls below
        ``0.5 + 0.5 * sum(skewness * sin(ys - mean))``; otherwise it is
        reflected about the base mean. The result is wrapped with
        ``(x + pi) % (2 * pi) - pi``.
        """
        base_key, skew_key = random.split(key)
        bd = self.base_dist
        ys = bd.sample(base_key, sample_shape)
        u = random.uniform(skew_key, sample_shape + self.batch_shape)

        # Section 2.3 step 3 in [1]
        mask = u <= 0.5 + 0.5 * (
            self.skewness * jnp.sin((ys - bd.mean) % (2 * pi))
        ).sum(-1)
        mask = mask[..., None]
        samples = (jnp.where(mask, ys, -ys + 2 * bd.mean) + pi) % (2 * pi) - pi
        return samples

    def log_prob(self, value):
        """Return the base log density plus the log of the skewing factor."""
        if self._validate_args:
            self._validate_sample(value)

        # Eq. 2.1 in [1]; log1p keeps precision when the skew term is small.
        skew_prob = jnp.log1p(
            (self.skewness * jnp.sin((value - self.base_dist.mean) % (2 * pi))).sum(-1)
        )
        return self.base_dist.log_prob(value) + skew_prob