Code example #1
 def codomain(self):
     if self.domain is constraints.real:
         return constraints.real
     elif isinstance(self.domain, constraints.greater_than):
         if not_jax_tracer(self.scale) and np.all(np.less(self.scale, 0)):
             return constraints.less_than(self(self.domain.lower_bound))
         # we suppose scale > 0 for any tracer
         else:
             return constraints.greater_than(self(self.domain.lower_bound))
     elif isinstance(self.domain, constraints.less_than):
         if not_jax_tracer(self.scale) and np.all(np.less(self.scale, 0)):
             return constraints.greater_than(self(self.domain.upper_bound))
         # we suppose scale > 0 for any tracer
         else:
             return constraints.less_than(self(self.domain.upper_bound))
     elif isinstance(self.domain, constraints.interval):
         if not_jax_tracer(self.scale) and np.all(np.less(self.scale, 0)):
             return constraints.interval(
                 self(self.domain.upper_bound), self(self.domain.lower_bound)
             )
         else:
             return constraints.interval(
                 self(self.domain.lower_bound), self(self.domain.upper_bound)
             )
     else:
         raise NotImplementedError
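
This bound-flipping logic matches numpyro's AffineTransform.codomain. A minimal sketch, assuming that class and a numpyro version where it accepts a domain argument (the values are illustrative):

from numpyro.distributions import constraints
from numpyro.distributions.transforms import AffineTransform

# scale < 0 reverses orientation, so the transformed bounds swap places
t = AffineTransform(loc=1.0, scale=-2.0, domain=constraints.interval(0.0, 1.0))
print(t.codomain.lower_bound, t.codomain.upper_bound)  # -1.0 1.0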
Code example #2
 def _support(self, *args, **kwargs):
     (a, b), loc, scale = self._parse_args(*args, **kwargs)
     # TODO: make constraints.less_than and support a == -np.inf
     if b == np.inf:
         return constraints.greater_than((a - loc) * scale)
     else:
         return constraints.interval((a - loc) * scale, (b - loc) * scale)
Code example #3
def test_elbo_dynamic_support():
    x_prior = dist.Uniform(0, 5)
    x_unconstrained = 2.

    def model():
        numpyro.sample('x', x_prior)

    class _AutoGuide(AutoDiagonalNormal):
        def __call__(self, *args, **kwargs):
            return substitute(
                super(_AutoGuide, self).__call__,
                {'_auto_latent': x_unconstrained})(*args, **kwargs)

    adam = optim.Adam(0.01)
    guide = _AutoGuide(model)
    svi = SVI(model, guide, adam, AutoContinuousELBO())
    svi_state = svi.init(random.PRNGKey(0))
    actual_loss = svi.evaluate(svi_state)
    assert np.isfinite(actual_loss)

    guide_log_prob = dist.Normal(
        guide._init_latent, guide._init_scale).log_prob(x_unconstrained).sum()
    transform = transforms.biject_to(constraints.interval(0, 5))
    x = transform(x_unconstrained)
    logdet = transform.log_abs_det_jacobian(x_unconstrained, x)
    model_log_prob = x_prior.log_prob(x) + logdet
    expected_loss = guide_log_prob - model_log_prob
    assert_allclose(actual_loss, expected_loss, rtol=1e-6)
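
The biject_to machinery this test relies on, in isolation (a sketch; the numbers are illustrative, not taken from the test):

from numpyro.distributions import constraints, transforms

t = transforms.biject_to(constraints.interval(0., 5.))
u = 2.                                 # unconstrained value
x = t(u)                               # mapped into (0, 5)
logdet = t.log_abs_det_jacobian(u, x)  # change-of-variables correction
assert 0. < x < 5.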
Code example #4
File: test_infer_util.py Project: ucals/numpyro
 def model():
     population = jnp.array([1000., 2000., 3000.])
     with numpyro.plate("region", 3):
         d = dist.ImproperUniform(support=constraints.interval(0, population),
                                  batch_shape=(3,),
                                  event_shape=event_shape)
         incidence = numpyro.sample("incidence", d)
         assert d.log_prob(incidence).shape == (3,)
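
As the plate example shows, interval bounds may be arrays, in which case the constraint applies elementwise. A small sketch, assuming constraints are callable as in the membership tests further below:

import jax.numpy as jnp
from numpyro.distributions import constraints

c = constraints.interval(0., jnp.array([1000., 2000., 3000.]))
print(c(jnp.array([500., 2500., 2500.])))  # [ True False  True]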
Code example #5
 def __init__(self, base_gamma, high, validate_args=None):
     assert isinstance(base_gamma, Gamma)
     batch_shape = lax.broadcast_shapes(base_gamma.batch_shape,
                                        jnp.shape(high))
     self.base_gamma = tree_map(
         lambda p: promote_shapes(p, shape=batch_shape)[0], base_gamma)
     (self.high, ) = promote_shapes(high, shape=batch_shape)
     self._support = constraints.interval(0.0, high)
     super().__init__(batch_shape, validate_args=validate_args)
Code example #6
File: transforms.py Project: hzyjerry/numpyro
 def codomain(self):
     if self.domain is constraints.real:
         return constraints.positive
     elif isinstance(self.domain, constraints.greater_than):
         return constraints.greater_than(self.__call__(self.domain.lower_bound))
     elif isinstance(self.domain, constraints.interval):
         return constraints.interval(self.__call__(self.domain.lower_bound),
                                     self.__call__(self.domain.upper_bound))
     else:
         raise NotImplementedError
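
This appears to be ExpTransform's codomain; assuming that class, the interval case follows from exp being monotone, so interval(a, b) maps to interval(exp(a), exp(b)). A minimal sketch:

from numpyro.distributions import constraints
from numpyro.distributions.transforms import ExpTransform

t = ExpTransform(domain=constraints.interval(0., 1.))
print(t.codomain.lower_bound, t.codomain.upper_bound)  # 1.0 and e ~= 2.718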
Code example #7
def test_iaf():
    # test for substitute logic for exposed methods `sample_posterior` and `get_transforms`
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1.0, dim + 1.0)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(data, labels):
        coefs = numpyro.sample("coefs",
                               dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        offset = numpyro.sample("offset", dist.Uniform(-1, 1))
        logits = offset + jnp.sum(coefs * data, axis=-1)
        return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)
    guide = AutoIAFNormal(model)
    svi = SVI(model, guide, adam, Trace_ELBO())
    svi_state = svi.init(rng_key_init, data, labels)
    params = svi.get_params(svi_state)

    x = random.normal(random.PRNGKey(0), (dim + 1, ))
    rng_key = random.PRNGKey(1)
    actual_sample = guide.sample_posterior(rng_key, params)
    actual_output = guide._unpack_latent(guide.get_transform(params)(x))

    flows = []
    for i in range(guide.num_flows):
        if i > 0:
            flows.append(transforms.PermuteTransform(
                jnp.arange(dim + 1)[::-1]))
        arn_init, arn_apply = AutoregressiveNN(
            dim + 1,
            [dim + 1, dim + 1],
            permutation=jnp.arange(dim + 1),
            skip_connections=guide._skip_connections,
            nonlinearity=guide._nonlinearity,
        )
        arn = partial(arn_apply, params["auto_arn__{}$params".format(i)])
        flows.append(InverseAutoregressiveTransform(arn))
    flows.append(guide._unpack_latent)

    transform = transforms.ComposeTransform(flows)
    _, rng_key_sample = random.split(rng_key)
    expected_sample = transform(
        dist.Normal(jnp.zeros(dim + 1), 1).sample(rng_key_sample))
    expected_output = transform(x)
    assert_allclose(actual_sample["coefs"], expected_sample["coefs"])
    assert_allclose(
        actual_sample["offset"],
        transforms.biject_to(constraints.interval(-1, 1))(
            expected_sample["offset"]),
    )
    check_eq(actual_output, expected_output)
Code example #8
File: directional.py Project: fehiepsi/numpyro
class VonMises(Distribution):
    arg_constraints = {
        "loc": constraints.real,
        "concentration": constraints.positive
    }
    reparametrized_params = ["loc"]
    support = constraints.interval(-math.pi, math.pi)

    def __init__(self, loc, concentration, validate_args=None):
        """von Mises distribution for sampling directions.

        :param loc: center of distribution
        :param concentration: concentration of distribution
        """
        self.loc, self.concentration = promote_shapes(loc, concentration)

        batch_shape = lax.broadcast_shapes(jnp.shape(concentration),
                                           jnp.shape(loc))

        super(VonMises, self).__init__(batch_shape=batch_shape,
                                       validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        """Generate sample from von Mises distribution

        :param key: random number generator key
        :param sample_shape: shape of samples
        :return: samples from von Mises
        """
        assert is_prng_key(key)
        samples = von_mises_centered(key, self.concentration,
                                     sample_shape + self.shape())
        samples = samples + self.loc  # VM(0, concentration) -> VM(loc,concentration)
        samples = (samples + jnp.pi) % (2.0 * jnp.pi) - jnp.pi

        return samples

    @validate_sample
    def log_prob(self, value):
        return (self.concentration * (jnp.cos((value - self.loc) % (2 * jnp.pi)) - 1)
                - jnp.log(2 * jnp.pi) - jnp.log(i0e(self.concentration)))

    @property
    def mean(self):
        """Computes circular mean of distribution. NOTE: same as location when mapped to support [-pi, pi]"""
        return jnp.broadcast_to((self.loc + jnp.pi) % (2.0 * jnp.pi) - jnp.pi,
                                self.batch_shape)

    @property
    def variance(self):
        """Computes circular variance of distribution"""
        return jnp.broadcast_to(
            1.0 - i1e(self.concentration) / i0e(self.concentration),
            self.batch_shape)
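
A usage sketch for the class above (assuming it is exposed as numpyro.distributions.VonMises): samples are wrapped back into the support interval(-pi, pi).

import jax.numpy as jnp
from jax import random
import numpyro.distributions as dist

d = dist.VonMises(loc=0.0, concentration=2.0)
s = d.sample(random.PRNGKey(0), sample_shape=(1000,))
assert jnp.all(jnp.abs(s) <= jnp.pi)  # wrapped into [-pi, pi)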
Code example #9
File: distributions.py Project: akihironitta/numpyro
def _get_codomain(bijector):
    if bijector.__class__.__name__ == "Sigmoid":
        return constraints.interval(bijector.low, bijector.high)
    elif bijector.__class__.__name__ == "Identity":
        return constraints.real
    elif bijector.__class__.__name__ in ["Exp", "SoftPlus"]:
        return constraints.positive
    elif bijector.__class__.__name__ == "GeneralizedPareto":
        loc, scale, concentration = bijector.loc, bijector.scale, bijector.concentration
        if not_jax_tracer(concentration) and jnp.all(concentration < 0):
            return constraints.interval(loc, loc + scale / jnp.abs(concentration))
        # XXX: here we suppose concentration > 0
        # which is not true in general, but should cover enough usage cases
        else:
            return constraints.greater_than(loc)
    elif bijector.__class__.__name__ == "SoftmaxCentered":
        return constraints.simplex
    elif bijector.__class__.__name__ == "Chain":
        return _get_codomain(bijector.bijectors[-1])
    else:
        return constraints.real
Code example #10
def test_iaf():
    # test for substitute logic for exposed methods `sample_posterior` and `get_transforms`
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = np.arange(1., dim + 1.)
    logits = np.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(data, labels):
        coefs = numpyro.sample('coefs', dist.Normal(np.zeros(dim),
                                                    np.ones(dim)))
        offset = numpyro.sample('offset', dist.Uniform(-1, 1))
        logits = offset + np.sum(coefs * data, axis=-1)
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    adam = optim.Adam(0.01)
    rng_init = random.PRNGKey(1)
    guide = AutoIAFNormal(model)
    svi = SVI(model, guide, elbo, adam)
    svi_state = svi.init(rng_init,
                         model_args=(data, labels),
                         guide_args=(data, labels))
    params = svi.get_params(svi_state)

    x = random.normal(random.PRNGKey(0), (dim + 1, ))
    rng = random.PRNGKey(1)
    actual_sample = guide.sample_posterior(rng, params)
    actual_output = guide.get_transform(params)(x)

    flows = []
    for i in range(guide.num_flows):
        if i > 0:
            flows.append(constraints.PermuteTransform(
                np.arange(dim + 1)[::-1]))
        arn_init, arn_apply = AutoregressiveNN(
            dim + 1, [dim + 1, dim + 1],
            permutation=np.arange(dim + 1),
            skip_connections=guide._skip_connections,
            nonlinearity=guide._nonlinearity)
        arn = partial(arn_apply, params['auto_arn__{}$params'.format(i)])
        flows.append(InverseAutoregressiveTransform(arn))

    transform = constraints.ComposeTransform(flows)
    rng_seed, rng_sample = random.split(rng)
    expected_sample = guide.unpack_latent(
        transform(dist.Normal(np.zeros(dim + 1), 1).sample(rng_sample)))
    expected_output = transform(x)
    assert_allclose(actual_sample['coefs'], expected_sample['coefs'])
    assert_allclose(
        actual_sample['offset'],
        constraints.biject_to(constraints.interval(-1, 1))(
            expected_sample['offset']))
    assert_allclose(actual_output, expected_output)
Code example #11
 def __init__(self, base_dist, low=0.0, high=1.0, validate_args=None):
     assert isinstance(base_dist, self.supported_types)
     assert (
         base_dist.support is constraints.real
     ), "The base distribution should be univariate and have real support."
     batch_shape = lax.broadcast_shapes(base_dist.batch_shape,
                                        jnp.shape(low), jnp.shape(high))
     self.base_dist = tree_map(
         lambda p: promote_shapes(p, shape=batch_shape)[0], base_dist)
     (self.low, ) = promote_shapes(low, shape=batch_shape)
     (self.high, ) = promote_shapes(high, shape=batch_shape)
     self._support = constraints.interval(low, high)
     super().__init__(batch_shape, validate_args=validate_args)
Code example #12
File: truncated.py Project: pyro-ppl/numpyro
class TruncatedPolyaGamma(Distribution):
    truncation_point = 2.5
    num_log_prob_terms = 7
    num_gamma_variates = 8
    assert num_log_prob_terms % 2 == 1

    arg_constraints = {}
    support = constraints.interval(0.0, truncation_point)

    def __init__(self, batch_shape=(), validate_args=None):
        super(TruncatedPolyaGamma, self).__init__(
            batch_shape, validate_args=validate_args
        )

    def sample(self, key, sample_shape=()):
        assert is_prng_key(key)
        denom = jnp.square(jnp.arange(0.5, self.num_gamma_variates))
        x = random.gamma(
            key, jnp.ones(self.batch_shape + sample_shape + (self.num_gamma_variates,))
        )
        x = jnp.sum(x / denom, axis=-1)
        return jnp.clip(x * (0.5 / jnp.pi ** 2), a_max=self.truncation_point)

    @validate_sample
    def log_prob(self, value):
        value = value[..., None]
        all_indices = jnp.arange(0, self.num_log_prob_terms)
        two_n_plus_one = 2.0 * all_indices + 1.0
        log_terms = (
            jnp.log(two_n_plus_one)
            - 1.5 * jnp.log(value)
            - 0.125 * jnp.square(two_n_plus_one) / value
        )
        even_terms = jnp.take(log_terms, all_indices[::2], axis=-1)
        odd_terms = jnp.take(log_terms, all_indices[1::2], axis=-1)
        sum_even = jnp.exp(logsumexp(even_terms, axis=-1))
        sum_odd = jnp.exp(logsumexp(odd_terms, axis=-1))
        return jnp.log(sum_even - sum_odd) - 0.5 * jnp.log(2.0 * jnp.pi)

    def tree_flatten(self):
        return (), self.batch_shape

    @classmethod
    def tree_unflatten(cls, aux_data, params):
        return cls(batch_shape=aux_data)
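
A quick sketch exercising the class above: draws are clipped into the support interval(0.0, truncation_point), and log_prob evaluates the truncated alternating series.

import jax.numpy as jnp
from jax import random

d = TruncatedPolyaGamma()
x = d.sample(random.PRNGKey(0), sample_shape=(100,))
assert jnp.all((x > 0) & (x <= d.truncation_point))
print(d.log_prob(x).shape)  # (100,)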
Code example #13
 def guide():
     alpha = param('alpha', 0.5, constraint=constraints.unit_interval)
     param('loc', 0, constraint=constraints.interval(0, alpha))
Code example #14
 def model(data):
     # NB: the constraint declared in the model has no effect
     loc = param('loc', 0., constraint=constraints.interval(0, 0.5))
     sample('obs', dist.Normal(loc, 0.1), obs=data)
Code example #15
 def guide():
     c = numpyro.param('c',
                       c_init,
                       constraint=constraints.interval(c_minval, c_maxval))
     d = numpyro.param('d', d_init, constraint=constraints.unit_interval)
     numpyro.sample('y', dist.Normal(c, d), obs=obs)
Code example #16
 def support(self):
     return constraints.interval(self.low, self.high)
Code example #17
File: test_distributions.py Project: hdocmsu/numpyro
 (constraints.boolean, np.array([1, 1]), np.array([True, True])),
 (constraints.boolean, np.array([-1, 1]), np.array([False, True])),
 (constraints.corr_cholesky, np.array([[[1, 0], [0, 1]], [[1, 0.1], [0, 1]]]),
  np.array([True, False])),  # NB: not lower_triangular
 (constraints.corr_cholesky, np.array([[[1, 0], [1, 0]], [[1, 0], [0.5, 0.5]]]),
  np.array([False, False])),  # NB: not positive_diagonal & not unit_norm_row
 (constraints.greater_than(1), 3, True),
 (constraints.greater_than(1), np.array([-1, 1, 5]), np.array([False, False, True])),
 (constraints.integer_interval(-3, 5), 0, True),
 (constraints.integer_interval(-3, 5), np.array([-5, -3, 0, 1.1, 5, 7]),
  np.array([False, True, True, False, True, False])),
 (constraints.interval(-3, 5), 0, True),
 (constraints.interval(-3, 5), np.array([-5, -3, 0, 5, 7]),
  np.array([False, False, True, False, False])),
 (constraints.lower_cholesky, np.array([[1., 0.], [-2., 0.1]]), True),
 (constraints.lower_cholesky, np.array([[[1., 0.], [-2., -0.1]], [[1., 0.1], [2., 0.2]]]),
  np.array([False, False])),
 (constraints.nonnegative_integer, 3, True),
 (constraints.nonnegative_integer, np.array([-1., 0., 5.]), np.array([False, True, True])),
 (constraints.positive, 3, True),
 (constraints.positive, np.array([-1, 0, 5]), np.array([False, False, True])),
 (constraints.positive_definite, np.array([[1., 0.3], [0.3, 1.]]), True),
 (constraints.positive_definite, np.array([[[2., 0.4], [0.3, 2.]], [[1., 0.1], [0.1, 0.]]]),
  np.array([False, False])),
Code example #18
 (constraints.boolean, np.array([1, 1]), np.array([True, True])),
 (constraints.boolean, np.array([-1, 1]), np.array([False, True])),
 (constraints.corr_cholesky, np.array([[[1, 0], [0, 1]], [[1, 0.1], [0, 1]]]),
  np.array([True, False])),  # NB: not lower_triangular
 (constraints.corr_cholesky, np.array([[[1, 0], [1, 0]], [[1, 0], [0.5, 0.5]]]),
  np.array([False, False])),  # NB: not positive_diagonal & not unit_norm_row
 (constraints.corr_matrix, np.array([[[1, 0], [0, 1]], [[1, 0.1], [0, 1]]]),
  np.array([True, False])),  # NB: not lower_triangular
 (constraints.corr_matrix, np.array([[[1, 0], [1, 0]], [[1, 0], [0.5, 0.5]]]),
  np.array([False, False])),  # NB: not unit diagonal
 (constraints.greater_than(1), 3, True),
 (constraints.greater_than(1), np.array([-1, 1, 5]), np.array([False, False, True])),
 (constraints.integer_interval(-3, 5), 0, True),
 (constraints.integer_interval(-3, 5), np.array([-5, -3, 0, 1.1, 5, 7]),
  np.array([False, True, True, False, True, False])),
 (constraints.interval(-3, 5), 0, True),
 (constraints.interval(-3, 5), np.array([-5, -3, 0, 5, 7]),
  np.array([False, False, True, False, False])),
 (constraints.lower_cholesky, np.array([[1., 0.], [-2., 0.1]]), True),
 (constraints.lower_cholesky, np.array([[[1., 0.], [-2., -0.1]], [[1., 0.1], [2., 0.2]]]),
  np.array([False, False])),
 (constraints.nonnegative_integer, 3, True),
 (constraints.nonnegative_integer, np.array([-1., 0., 5.]), np.array([False, True, True])),
 (constraints.positive, 3, True),
 (constraints.positive, np.array([-1, 0, 5]), np.array([False, False, True])),
 (constraints.positive_definite, np.array([[1., 0.3], [0.3, 1.]]), True),
 (constraints.positive_definite, np.array([[[2., 0.4], [0.3, 2.]], [[1., 0.1], [0.1, 0.]]]),
  np.array([False, False])),
 (constraints.positive_integer, 3, True),
 (constraints.positive_integer, np.array([-1., 0., 5.]), np.array([False, False, True])),
 (constraints.real, -1, True),
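
Note what the interval rows above encode: membership is strict in these versions, so the endpoints -3 and 5 themselves fail the check. A one-line sketch, assuming constraints are callable as in these tests:

import numpy as np
from numpyro.distributions import constraints

c = constraints.interval(-3, 5)
print(c(np.array([-3, 0, 5])))  # [False  True False]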
Code example #19
 def model(data):
     alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
     loc = numpyro.param('loc',
                         0.,
                         constraint=constraints.interval(0., alpha))
     numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)
Code example #20
File: directional.py Project: ahmadsalim/numpyro
class SineSkewed(Distribution):
    """The Sine Skewed distribution [1] is a distribution for breaking pointwise-symmetry on a base-distribution over
    the d-dimensional torus defined as ⨂^d S^1 where S^1 is the circle. So for example the 0-torus is a point, the
    1-torus is a circle and the 2-tours is commonly associated with the donut shape (some may object to this simile).

    The skewness parameter can be inferred using :class:`~numpyro.infer.HMC` or :class:`~numpyro.infer.NUTS`.
    For example, the following will produce a uniform prior over skewness for the 2-torus,::

        def model(...):
            ...
            skew_phi = numpyro.sample('skew_phi', Uniform(-1., 1.))
            psi_bound = 1 - jnp.abs(skew_phi)
            skew_psi = numpyro.sample('skew_psi', Uniform(-1., 1.))
            skewness = jnp.stack((skew_phi, psi_bound * skew_psi), axis=0)
            ...

    In the context of :class:`~numpyro.infer.SVI`, this distribution can freely be used as a likelihood, but using
    it as a latent variable will lead to slow inference for toruses of dimension 2 and higher, because the
    base_dist cannot be reparameterized.

    .. note:: An event in the base distribution must be on a d-torus, so the event_shape must be (d,).

    .. note:: For the skewness parameter, the sum of the absolute values of its weights for an event must be
        less than or equal to one. See eq. 2.1 in [1].

    **References:**
      1. Sine-skewed toroidal distributions and their application in protein bioinformatics
         Ameijeiras-Alonso, J., Ley, C. (2019)

    :param base_dist: base density on a d-dimensional torus.
    :param skewness: skewness of the distribution.
    """

    arg_constraints = {
        "skewness": constraints.independent(constraints.interval(-1.0, 1.0), 1)
    }

    support = constraints.independent(constraints.real, 1)

    def __init__(self, base_dist: Distribution, skewness, validate_args=None):
        batch_shape = jnp.broadcast_shapes(base_dist.batch_shape,
                                           skewness.shape[:-1])
        event_shape = skewness.shape[-1:]
        self.skewness = jnp.broadcast_to(skewness, batch_shape + event_shape)
        self.base_dist = base_dist.expand(batch_shape)
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

        if self._validate_args and base_dist.mean.device != skewness.device:
            raise ValueError(
                f"base_density: {base_dist.__class__.__name__} and SineSkewed "
                f"must be on same device.")

    def __repr__(self):
        # NB: .numel()/.size() are torch methods that do not exist on JAX
        # arrays; use jnp.size/jnp.shape instead.
        args_string = ", ".join([
            "{}: {}".format(
                p,
                getattr(self, p) if jnp.size(getattr(self, p)) == 1
                else jnp.shape(getattr(self, p)),
            ) for p in self.arg_constraints.keys()
        ])
        return (self.__class__.__name__ + "(" +
                f"base_density: {str(self.base_dist)}, " + args_string + ")")

    def sample(self, key, sample_shape=()):
        base_key, skew_key = random.split(key)
        bd = self.base_dist
        ys = bd.sample(base_key, sample_shape)
        u = random.uniform(skew_key, sample_shape + self.batch_shape)

        # Section 2.3 step 3 in [1]
        mask = u <= 0.5 + 0.5 * (self.skewness * jnp.sin(
            (ys - bd.mean) % (2 * pi))).sum(-1)
        mask = mask[..., None]
        samples = (jnp.where(mask, ys, -ys + 2 * bd.mean) + pi) % (2 * pi) - pi
        return samples

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)

        # Eq. 2.1 in [1]
        skew_prob = jnp.log(1 + (self.skewness *
                                 jnp.sin((value - self.base_dist.mean) %
                                         (2 * pi))).sum(-1))
        return self.base_dist.log_prob(value) + skew_prob
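
A hedged usage sketch for SineSkewed (assuming the recent numpyro API where it and VonMises are exported from numpyro.distributions): the base lives on the 1-torus, and per the docstring the skewness weights must satisfy sum(|w|) <= 1.

import jax.numpy as jnp
from jax import random
import numpyro.distributions as dist

base = dist.VonMises(0.0, 1.0).expand([1]).to_event(1)  # event_shape (1,)
d = dist.SineSkewed(base, skewness=jnp.array([0.3]))
y = d.sample(random.PRNGKey(0))
print(d.log_prob(y))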