Example #1
 def codomain(self):
     if self.domain is constraints.real:
         return constraints.real
     elif isinstance(self.domain, constraints.greater_than):
         if not_jax_tracer(self.scale) and np.all(np.less(self.scale, 0)):
             return constraints.less_than(self(self.domain.lower_bound))
         # we suppose scale > 0 for any tracer
         else:
             return constraints.greater_than(self(self.domain.lower_bound))
     elif isinstance(self.domain, constraints.less_than):
         if not_jax_tracer(self.scale) and np.all(np.less(self.scale, 0)):
             return constraints.greater_than(self(self.domain.upper_bound))
         # we suppose scale > 0 for any tracer
         else:
             return constraints.less_than(self(self.domain.upper_bound))
     elif isinstance(self.domain, constraints.interval):
         if not_jax_tracer(self.scale) and np.all(np.less(self.scale, 0)):
             return constraints.interval(
                 self(self.domain.upper_bound), self(self.domain.lower_bound)
             )
         else:
             return constraints.interval(
                 self(self.domain.lower_bound), self(self.domain.upper_bound)
             )
     else:
         raise NotImplementedError
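This codomain logic flips interval bounds whenever the affine scale is negative. A minimal sketch of that behavior, assuming a NumPyro version whose AffineTransform takes a domain argument (as in the method above):

import numpyro.distributions.constraints as constraints
from numpyro.distributions.transforms import AffineTransform

# y = 1 - 2x maps x > 0 to y < 1: a negative scale turns greater_than into less_than.
t = AffineTransform(loc=1.0, scale=-2.0, domain=constraints.positive)
print(t.codomain.upper_bound)  # 1.0, i.e. t(domain.lower_bound)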
Example #2
 def _support(self, *args, **kwargs):
     (a, b), loc, scale = self._parse_args(*args, **kwargs)
     # TODO: make constraints.less_than and support a == -np.inf
     if b == np.inf:
         return constraints.greater_than((a - loc) * scale)
     else:
         return constraints.interval((a - loc) * scale, (b - loc) * scale)
Example #3
def test_elbo_dynamic_support():
    x_prior = dist.Uniform(0, 5)
    x_unconstrained = 2.

    def model():
        numpyro.sample('x', x_prior)

    class _AutoGuide(AutoDiagonalNormal):
        def __call__(self, *args, **kwargs):
            return substitute(
                super(_AutoGuide, self).__call__,
                {'_auto_latent': x_unconstrained})(*args, **kwargs)

    adam = optim.Adam(0.01)
    guide = _AutoGuide(model)
    svi = SVI(model, guide, adam, AutoContinuousELBO())
    svi_state = svi.init(random.PRNGKey(0))
    actual_loss = svi.evaluate(svi_state)
    assert np.isfinite(actual_loss)

    guide_log_prob = dist.Normal(
        guide._init_latent, guide._init_scale).log_prob(x_unconstrained).sum()
    transform = transforms.biject_to(constraints.interval(0, 5))
    x = transform(x_unconstrained)
    logdet = transform.log_abs_det_jacobian(x_unconstrained, x)
    model_log_prob = x_prior.log_prob(x) + logdet
    expected_loss = guide_log_prob - model_log_prob
    assert_allclose(actual_loss, expected_loss, rtol=1e-6)
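For orientation, the interval bijection used in this test maps the whole real line into (0, 5); a quick sketch of the round trip and the Jacobian term that enters the ELBO, assuming the current transforms API:

import jax.numpy as jnp
from numpyro.distributions import constraints, transforms

t = transforms.biject_to(constraints.interval(0.0, 5.0))
u = 2.0                                # unconstrained value
x = t(u)                               # lands strictly inside (0, 5)
assert 0.0 < x < 5.0
assert jnp.allclose(t.inv(x), u)       # the transform is bijective
logdet = t.log_abs_det_jacobian(u, x)  # correction added to the model log-density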
Example #4
 def model():
     population = jnp.array([1000., 2000., 3000.])
     with numpyro.plate("region", 3):
         d = dist.ImproperUniform(support=constraints.interval(0, population),
                                  batch_shape=(3,),
                                  event_shape=event_shape)
         incidence = numpyro.sample("incidence", d)
         assert d.log_prob(incidence).shape == (3,)
Example #5
 def __init__(self, base_gamma, high, validate_args=None):
     assert isinstance(base_gamma, Gamma)
     batch_shape = lax.broadcast_shapes(base_gamma.batch_shape,
                                        jnp.shape(high))
     self.base_gamma = tree_map(
         lambda p: promote_shapes(p, shape=batch_shape)[0], base_gamma)
     (self.high, ) = promote_shapes(high, shape=batch_shape)
     self._support = constraints.interval(0.0, high)
     super().__init__(batch_shape, validate_args=validate_args)
Example #6
 def codomain(self):
     if self.domain is constraints.real:
         return constraints.positive
     elif isinstance(self.domain, constraints.greater_than):
         return constraints.greater_than(self.__call__(self.domain.lower_bound))
     elif isinstance(self.domain, constraints.interval):
         return constraints.interval(self.__call__(self.domain.lower_bound),
                                     self.__call__(self.domain.upper_bound))
     else:
         raise NotImplementedError
Example #7
def test_iaf():
    # test for substitute logic for exposed methods `sample_posterior` and `get_transform`
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1.0, dim + 1.0)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(data, labels):
        coefs = numpyro.sample("coefs",
                               dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        offset = numpyro.sample("offset", dist.Uniform(-1, 1))
        logits = offset + jnp.sum(coefs * data, axis=-1)
        return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = AutoIAFNormal(model)
    svi = SVI(model, guide, adam, Trace_ELBO())
    svi_state = svi.init(rng_key_init, data, labels)
    params = svi.get_params(svi_state)

    x = random.normal(random.PRNGKey(0), (dim + 1, ))
    rng_key = random.PRNGKey(1)
    actual_sample = guide.sample_posterior(rng_key, params)
    actual_output = guide._unpack_latent(guide.get_transform(params)(x))

    flows = []
    for i in range(guide.num_flows):
        if i > 0:
            flows.append(transforms.PermuteTransform(
                jnp.arange(dim + 1)[::-1]))
        arn_init, arn_apply = AutoregressiveNN(
            dim + 1,
            [dim + 1, dim + 1],
            permutation=jnp.arange(dim + 1),
            skip_connections=guide._skip_connections,
            nonlinearity=guide._nonlinearity,
        )
        arn = partial(arn_apply, params["auto_arn__{}$params".format(i)])
        flows.append(InverseAutoregressiveTransform(arn))
    flows.append(transforms.UnpackTransform(guide._unpack_latent))

    transform = transforms.ComposeTransform(flows)
    _, rng_key_sample = random.split(rng_key)
    expected_sample = transform(
        dist.Normal(jnp.zeros(dim + 1), 1).sample(rng_key_sample))
    expected_output = transform(x)
    assert_allclose(actual_sample["coefs"], expected_sample["coefs"])
    assert_allclose(
        actual_sample["offset"],
        transforms.biject_to(constraints.interval(-1, 1))(
            expected_sample["offset"]),
    )
    check_eq(actual_output, expected_output)
Example #8
class VonMises(Distribution):
    arg_constraints = {
        "loc": constraints.real,
        "concentration": constraints.positive
    }
    reparametrized_params = ["loc"]
    support = constraints.interval(-math.pi, math.pi)

    def __init__(self, loc, concentration, validate_args=None):
        """von Mises distribution for sampling directions.

        :param loc: center of distribution
        :param concentration: concentration of distribution
        """
        self.loc, self.concentration = promote_shapes(loc, concentration)

        batch_shape = lax.broadcast_shapes(jnp.shape(concentration),
                                           jnp.shape(loc))

        super(VonMises, self).__init__(batch_shape=batch_shape,
                                       validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        """Generate sample from von Mises distribution

        :param key: random number generator key
        :param sample_shape: shape of samples
        :return: samples from von Mises
        """
        assert is_prng_key(key)
        samples = von_mises_centered(key, self.concentration,
                                     sample_shape + self.shape())
        samples = samples + self.loc  # VM(0, concentration) -> VM(loc,concentration)
        samples = (samples + jnp.pi) % (2.0 * jnp.pi) - jnp.pi

        return samples

    @validate_sample
    def log_prob(self, value):
        return (-(jnp.log(2 * jnp.pi) + jnp.log(i0e(self.concentration)))
                + self.concentration * (jnp.cos((value - self.loc) % (2 * jnp.pi)) - 1))

    @property
    def mean(self):
        """Computes circular mean of distribution. NOTE: same as location when mapped to support [-pi, pi]"""
        return jnp.broadcast_to((self.loc + jnp.pi) % (2.0 * jnp.pi) - jnp.pi,
                                self.batch_shape)

    @property
    def variance(self):
        """Computes circular variance of distribution"""
        return jnp.broadcast_to(
            1.0 - i1e(self.concentration) / i0e(self.concentration),
            self.batch_shape)
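A short usage sketch against the same class as shipped in numpyro.distributions (hedged: sampler internals may differ across versions); samples always land in the declared interval(-pi, pi) support:

import jax.numpy as jnp
from jax import random
import numpyro.distributions as dist

vm = dist.VonMises(loc=0.5, concentration=2.0)
s = vm.sample(random.PRNGKey(0), (1000,))
assert ((s >= -jnp.pi) & (s <= jnp.pi)).all()
print(vm.mean, vm.variance)  # circular mean near loc; circular variance in [0, 1]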
Example #9
def _get_codomain(bijector):
    if bijector.__class__.__name__ == "Sigmoid":
        return constraints.interval(bijector.low, bijector.high)
    elif bijector.__class__.__name__ == "Identity":
        return constraints.real
    elif bijector.__class__.__name__ in ["Exp", "SoftPlus"]:
        return constraints.positive
    elif bijector.__class__.__name__ == "GeneralizedPareto":
        loc, scale, concentration = bijector.loc, bijector.scale, bijector.concentration
        if not_jax_tracer(concentration) and jnp.all(concentration < 0):
            return constraints.interval(loc, loc + scale / jnp.abs(concentration))
        # XXX: here we suppose concentration > 0
        # which is not true in general, but should cover enough usage cases
        else:
            return constraints.greater_than(loc)
    elif bijector.__class__.__name__ == "SoftmaxCentered":
        return constraints.simplex
    elif bijector.__class__.__name__ == "Chain":
        return _get_codomain(bijector.bijectors[-1])
    else:
        return constraints.real
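A hedged sketch of the dispatch, assuming tensorflow_probability's JAX substrate is installed; the class names matched above come from tfp.bijectors:

from tensorflow_probability.substrates import jax as tfp

print(_get_codomain(tfp.bijectors.Identity()))                  # constraints.real
print(_get_codomain(tfp.bijectors.Exp()))                       # constraints.positive
print(_get_codomain(tfp.bijectors.Sigmoid(low=0.0, high=1.0)))  # interval(0, 1)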
Example #10
def test_iaf():
    # test for substitute logic for exposed methods `sample_posterior` and `get_transform`
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = np.arange(1., dim + 1.)
    logits = np.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(data, labels):
        coefs = numpyro.sample('coefs', dist.Normal(np.zeros(dim),
                                                    np.ones(dim)))
        offset = numpyro.sample('offset', dist.Uniform(-1, 1))
        logits = offset + np.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    adam = optim.Adam(0.01)
    rng_init = random.PRNGKey(1)
    guide = AutoIAFNormal(model)
    svi = SVI(model, guide, elbo, adam)
    svi_state = svi.init(rng_init,
                         model_args=(data, labels),
                         guide_args=(data, labels))
    params = svi.get_params(svi_state)

    x = random.normal(random.PRNGKey(0), (dim + 1, ))
    rng = random.PRNGKey(1)
    actual_sample = guide.sample_posterior(rng, params)
    actual_output = guide.get_transform(params)(x)

    flows = []
    for i in range(guide.num_flows):
        if i > 0:
            flows.append(constraints.PermuteTransform(
                np.arange(dim + 1)[::-1]))
        arn_init, arn_apply = AutoregressiveNN(
            dim + 1, [dim + 1, dim + 1],
            permutation=np.arange(dim + 1),
            skip_connections=guide._skip_connections,
            nonlinearity=guide._nonlinearity)
        arn = partial(arn_apply, params['auto_arn__{}$params'.format(i)])
        flows.append(InverseAutoregressiveTransform(arn))

    transform = constraints.ComposeTransform(flows)
    rng_seed, rng_sample = random.split(rng)
    expected_sample = guide.unpack_latent(
        transform(dist.Normal(np.zeros(dim + 1), 1).sample(rng_sample)))
    expected_output = transform(x)
    assert_allclose(actual_sample['coefs'], expected_sample['coefs'])
    assert_allclose(
        actual_sample['offset'],
        constraints.biject_to(constraints.interval(-1, 1))(
            expected_sample['offset']))
    assert_allclose(actual_output, expected_output)
Example #11
 def __init__(self, base_dist, low=0.0, high=1.0, validate_args=None):
     assert isinstance(base_dist, self.supported_types)
     assert (
         base_dist.support is constraints.real
     ), "The base distribution should be univariate and have real support."
     batch_shape = lax.broadcast_shapes(base_dist.batch_shape,
                                        jnp.shape(low), jnp.shape(high))
     self.base_dist = tree_map(
         lambda p: promote_shapes(p, shape=batch_shape)[0], base_dist)
     (self.low, ) = promote_shapes(low, shape=batch_shape)
     (self.high, ) = promote_shapes(high, shape=batch_shape)
     self._support = constraints.interval(low, high)
     super().__init__(batch_shape, validate_args=validate_args)
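This constructor pattern underlies NumPyro's two-sided truncation; a hedged usage sketch through the public TruncatedNormal wrapper (assuming a NumPyro version that accepts low/high keywords):

from jax import random
import numpyro.distributions as dist

d = dist.TruncatedNormal(loc=0.0, scale=1.0, low=-1.0, high=2.0)
x = d.sample(random.PRNGKey(0), (10,))
assert ((x >= -1.0) & (x <= 2.0)).all()
print(d.support)  # an interval(-1.0, 2.0) constraint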
Example #12
class TruncatedPolyaGamma(Distribution):
    truncation_point = 2.5
    num_log_prob_terms = 7
    num_gamma_variates = 8
    assert num_log_prob_terms % 2 == 1

    arg_constraints = {}
    support = constraints.interval(0.0, truncation_point)

    def __init__(self, batch_shape=(), validate_args=None):
        super(TruncatedPolyaGamma, self).__init__(
            batch_shape, validate_args=validate_args
        )

    def sample(self, key, sample_shape=()):
        assert is_prng_key(key)
        denom = jnp.square(jnp.arange(0.5, self.num_gamma_variates))
        x = random.gamma(
            key, jnp.ones(self.batch_shape + sample_shape + (self.num_gamma_variates,))
        )
        x = jnp.sum(x / denom, axis=-1)
        return jnp.clip(x * (0.5 / jnp.pi ** 2), a_max=self.truncation_point)

    @validate_sample
    def log_prob(self, value):
        value = value[..., None]
        all_indices = jnp.arange(0, self.num_log_prob_terms)
        two_n_plus_one = 2.0 * all_indices + 1.0
        log_terms = (
            jnp.log(two_n_plus_one)
            - 1.5 * jnp.log(value)
            - 0.125 * jnp.square(two_n_plus_one) / value
        )
        even_terms = jnp.take(log_terms, all_indices[::2], axis=-1)
        odd_terms = jnp.take(log_terms, all_indices[1::2], axis=-1)
        sum_even = jnp.exp(logsumexp(even_terms, axis=-1))
        sum_odd = jnp.exp(logsumexp(odd_terms, axis=-1))
        return jnp.log(sum_even - sum_odd) - 0.5 * jnp.log(2.0 * jnp.pi)

    def tree_flatten(self):
        return (), self.batch_shape

    @classmethod
    def tree_unflatten(cls, aux_data, params):
        return cls(batch_shape=aux_data)
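Continuing from the class above, a brief usage sketch: samples are a truncated sum of scaled Gamma variates, clipped into the declared interval(0, truncation_point) support.

from jax import random

d = TruncatedPolyaGamma(batch_shape=(2,))
x = d.sample(random.PRNGKey(0), sample_shape=(4,))
assert (x > 0).all() and (x <= d.truncation_point).all()
print(d.log_prob(x).shape)  # same shape as x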
Example #13
 def guide():
     alpha = param('alpha', 0.5, constraint=constraints.unit_interval)
     param('loc', 0, constraint=constraints.interval(0, alpha))
Example #14
 def model(data):
     # NB: the model's constraints will have no effect
     loc = param('loc', 0., constraint=constraints.interval(0, 0.5))
     sample('obs', dist.Normal(loc, 0.1), obs=data)
Example #15
 def guide():
     c = numpyro.param('c',
                       c_init,
                       constraint=constraints.interval(c_minval, c_maxval))
     d = numpyro.param('d', d_init, constraint=constraints.unit_interval)
     numpyro.sample('y', dist.Normal(c, d), obs=obs)
Example #16
 def support(self):
     return constraints.interval(self.low, self.high)
Example #17
 (constraints.boolean, np.array([1, 1]), np.array([True, True])),
 (constraints.boolean, np.array([-1, 1]), np.array([False, True])),
 (constraints.corr_cholesky, np.array([[[1, 0], [0, 1]], [[1, 0.1], [0, 1]]]),
  np.array([True, False])),  # NB: not lower_triangular
 (constraints.corr_cholesky, np.array([[[1, 0], [1, 0]], [[1, 0], [0.5, 0.5]]]),
  np.array([False, False])),  # NB: not positive_diagonal & not unit_norm_row
 (constraints.greater_than(1), 3, True),
 (constraints.greater_than(1), np.array([-1, 1, 5]), np.array([False, False, True])),
 (constraints.integer_interval(-3, 5), 0, True),
 (constraints.integer_interval(-3, 5), np.array([-5, -3, 0, 1.1, 5, 7]),
  np.array([False, True, True, False, True, False])),
 (constraints.interval(-3, 5), 0, True),
 (constraints.interval(-3, 5), np.array([-5, -3, 0, 5, 7]),
  np.array([False, False, True, False, False])),
 (constraints.lower_cholesky, np.array([[1., 0.], [-2., 0.1]]), True),
 (constraints.lower_cholesky, np.array([[[1., 0.], [-2., -0.1]], [[1., 0.1], [2., 0.2]]]),
  np.array([False, False])),
 (constraints.nonnegative_integer, 3, True),
 (constraints.nonnegative_integer, np.array([-1., 0., 5.]), np.array([False, True, True])),
 (constraints.positive, 3, True),
 (constraints.positive, np.array([-1, 0, 5]), np.array([False, False, True])),
 (constraints.positive_definite, np.array([[1., 0.3], [0.3, 1.]]), True),
 (constraints.positive_definite, np.array([[[2., 0.4], [0.3, 2.]], [[1., 0.1], [0.1, 0.]]]),
  np.array([False, False])),
Example #18
 (constraints.boolean, np.array([1, 1]), np.array([True, True])),
 (constraints.boolean, np.array([-1, 1]), np.array([False, True])),
 (constraints.corr_cholesky, np.array([[[1, 0], [0, 1]], [[1, 0.1], [0, 1]]]),
  np.array([True, False])),  # NB: not lower_triangular
 (constraints.corr_cholesky, np.array([[[1, 0], [1, 0]], [[1, 0], [0.5, 0.5]]]),
  np.array([False, False])),  # NB: not positive_diagonal & not unit_norm_row
 (constraints.corr_matrix, np.array([[[1, 0], [0, 1]], [[1, 0.1], [0, 1]]]),
  np.array([True, False])),  # NB: not lower_triangular
 (constraints.corr_matrix, np.array([[[1, 0], [1, 0]], [[1, 0], [0.5, 0.5]]]),
  np.array([False, False])),  # NB: not unit diagonal
 (constraints.greater_than(1), 3, True),
 (constraints.greater_than(1), np.array([-1, 1, 5]), np.array([False, False, True])),
 (constraints.integer_interval(-3, 5), 0, True),
 (constraints.integer_interval(-3, 5), np.array([-5, -3, 0, 1.1, 5, 7]),
  np.array([False, True, True, False, True, False])),
 (constraints.interval(-3, 5), 0, True),
 (constraints.interval(-3, 5), np.array([-5, -3, 0, 5, 7]),
  np.array([False, False, True, False, False])),
 (constraints.lower_cholesky, np.array([[1., 0.], [-2., 0.1]]), True),
 (constraints.lower_cholesky, np.array([[[1., 0.], [-2., -0.1]], [[1., 0.1], [2., 0.2]]]),
  np.array([False, False])),
 (constraints.nonnegative_integer, 3, True),
 (constraints.nonnegative_integer, np.array([-1., 0., 5.]), np.array([False, True, True])),
 (constraints.positive, 3, True),
 (constraints.positive, np.array([-1, 0, 5]), np.array([False, False, True])),
 (constraints.positive_definite, np.array([[1., 0.3], [0.3, 1.]]), True),
 (constraints.positive_definite, np.array([[[2., 0.4], [0.3, 2.]], [[1., 0.1], [0.1, 0.]]]),
  np.array([False, False])),
 (constraints.positive_integer, 3, True),
 (constraints.positive_integer, np.array([-1., 0., 5.]), np.array([False, False, True])),
 (constraints.real, -1, True),
Example #19
 def model(data):
     alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
     loc = numpyro.param('loc',
                         0.,
                         constraint=constraints.interval(0., alpha))
     numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)
Example #20
class SineSkewed(Distribution):
    """The Sine Skewed distribution [1] is a distribution for breaking pointwise-symmetry on a base-distribution over
    the d-dimensional torus defined as ⨂^d S^1 where S^1 is the circle. So for example the 0-torus is a point, the
    1-torus is a circle and the 2-tours is commonly associated with the donut shape (some may object to this simile).

    The skewness parameter can be inferred using :class:`~pyro.infer.HMC` or :class:`~pyro.infer.NUTS`.
    For example, the following will produce a uniform prior over skewness for the 2-torus,::

        def model(...):
            ...
            skew_phi = pyro.sample(f'skew_phi', Uniform(-1., 1.))
            psi_bound = 1 - skew_phi.abs()
            skew_psi = pyro.sample(f'skew_psi', Uniform(-1, 1.))
            skewness = torch.stack((skew_phi, psi_bound * skew_psi), dim=0)
            ...

    In the context of :class:`~pyro.infer.SVI`, this distribution can freely be used as a likelihood, but using it as a
    latent variable will lead to slow inference for the 2-torus and higher-order tori, because the base_dist
    cannot be reparameterized.

    .. note:: An event in the base distribution must be on a d-torus, so the event_shape must be (d,).

    .. note:: For the skewness parameter, the sum of the absolute values of its weights for an event must be less
        than or equal to one. See eq. 2.1 in [1].

    **References:**
      1. Sine-skewed toroidal distributions and their application in protein bioinformatics
         Ameijeiras-Alonso, J., Ley, C. (2019)

    :param base_dist: base density on a d-dimensional torus.
    :param skewness: skewness of the distribution.
    """

    arg_constraints = {
        "skewness": constraints.independent(constraints.interval(-1.0, 1.0), 1)
    }

    support = constraints.independent(constraints.real, 1)

    def __init__(self, base_dist: Distribution, skewness, validate_args=None):
        batch_shape = jnp.broadcast_shapes(base_dist.batch_shape,
                                           skewness.shape[:-1])
        event_shape = skewness.shape[-1:]
        self.skewness = jnp.broadcast_to(skewness, batch_shape + event_shape)
        self.base_dist = base_dist.expand(batch_shape)
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

        if self._validate_args and base_dist.mean.device != skewness.device:
            raise ValueError(
                f"base_density: {base_dist.__class__.__name__} and SineSkewed "
                f"must be on same device.")

    def __repr__(self):
        args_string = ", ".join([
            "{}: {}".format(
                p,
                getattr(self, p) if jnp.size(getattr(self, p)) == 1 else
                jnp.shape(getattr(self, p)),
            ) for p in self.arg_constraints.keys()
        ])
        return (self.__class__.__name__ + "(" +
                f"base_density: {str(self.base_dist)}, " + args_string + ")")

    def sample(self, key, sample_shape=()):
        base_key, skew_key = random.split(key)
        bd = self.base_dist
        ys = bd.sample(base_key, sample_shape)
        u = random.uniform(skew_key, sample_shape + self.batch_shape)

        # Section 2.3 step 3 in [1]
        mask = u <= 0.5 + 0.5 * (self.skewness * jnp.sin(
            (ys - bd.mean) % (2 * pi))).sum(-1)
        mask = mask[..., None]
        samples = (jnp.where(mask, ys, -ys + 2 * bd.mean) + pi) % (2 * pi) - pi
        return samples

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)

        # Eq. 2.1 in [1]
        skew_prob = jnp.log(1 + (self.skewness *
                                 jnp.sin((value - self.base_dist.mean) %
                                         (2 * pi))).sum(-1))
        return self.base_dist.log_prob(value) + skew_prob
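A usage sketch, assuming this class is available as dist.SineSkewed: the base density lives on the 2-torus (two independent von Mises marginals), and the skewness weights obey sum(|w|) <= 1 from eq. 2.1 in [1].

import jax.numpy as jnp
from jax import random
import numpyro.distributions as dist

base = dist.VonMises(jnp.zeros(2), jnp.ones(2)).to_event(1)      # event_shape (2,)
skewed = dist.SineSkewed(base, skewness=jnp.array([0.3, -0.4]))  # |0.3| + |-0.4| <= 1
draws = skewed.sample(random.PRNGKey(0), (5,))
print(draws.shape, skewed.log_prob(draws).shape)  # (5, 2) (5,)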