Beispiel #1
 def codomain(self):
     """Return the constraint satisfied by this transform's outputs.

     The bounds of ``self.domain`` are pushed through the transform itself
     (``self(bound)``).  When ``self.scale`` is a concrete (non-traced) value
     that is negative everywhere, a lower bound becomes an upper bound and
     vice versa; for traced values the scale is assumed to be positive.

     Raises:
         NotImplementedError: if ``self.domain`` is neither
             ``constraints.real`` nor a ``greater_than`` / ``less_than`` /
             ``interval`` constraint.
     """
     if self.domain is constraints.real:
         return constraints.real
     elif isinstance(self.domain, constraints.greater_than):
         if not_jax_tracer(self.scale) and np.all(np.less(self.scale, 0)):
             return constraints.less_than(self(self.domain.lower_bound))
         # we suppose scale > 0 for any tracer
         else:
             return constraints.greater_than(self(self.domain.lower_bound))
     elif isinstance(self.domain, constraints.less_than):
         if not_jax_tracer(self.scale) and np.all(np.less(self.scale, 0)):
             return constraints.greater_than(self(self.domain.upper_bound))
         # we suppose scale > 0 for any tracer
         else:
             return constraints.less_than(self(self.domain.upper_bound))
     elif isinstance(self.domain, constraints.interval):
         if not_jax_tracer(self.scale) and np.all(np.less(self.scale, 0)):
             # Negative scale: the image of the upper bound is the new lower bound.
             return constraints.interval(
                 self(self.domain.upper_bound), self(self.domain.lower_bound)
             )
         # we suppose scale > 0 for any tracer
         else:
             return constraints.interval(
                 self(self.domain.lower_bound), self(self.domain.upper_bound)
             )
     else:
         raise NotImplementedError
Beispiel #2
 def _support(self, *args, **kwargs):
     """Build the support constraint from the parsed bounds, loc and scale.

     ``self._parse_args`` supplies ``(a, b), loc, scale``.  An infinite upper
     bound ``b`` gives a ``greater_than`` constraint; otherwise an
     ``interval`` constraint is returned.
     """
     (a, b), loc, scale = self._parse_args(*args, **kwargs)
     # NOTE(review): bounds are computed as (bound - loc) * scale; confirm this
     # matches how loc and scale are applied by the rest of the class.
     # TODO: make constraints.less_than and support a == -np.inf
     if b == np.inf:
         return constraints.greater_than((a - loc) * scale)
     else:
         return constraints.interval((a - loc) * scale, (b - loc) * scale)
Beispiel #3
 def __init__(self, base_gamma, low, validate_args=None):
     """Wrap ``base_gamma`` with a lower bound ``low``.

     The batch shape is the broadcast of ``base_gamma.batch_shape`` with the
     shape of ``low``; the base distribution's leaves and ``low`` are both
     promoted to that shape.  The support is ``greater_than(low)``.

     Args:
         base_gamma: a ``Gamma`` instance (checked with ``assert``).
         low: lower bound of the support.
         validate_args: forwarded unchanged to the parent constructor.
     """
     # Function-scope import: this block does not otherwise reference jnp.
     import jax.numpy as jnp

     assert isinstance(base_gamma, Gamma)
     batch_shape = lax.broadcast_shapes(base_gamma.batch_shape,
                                        jnp.shape(low))
     self.base_gamma = tree_map(
         lambda p: promote_shapes(p, shape=batch_shape)[0], base_gamma)
     (self.low, ) = promote_shapes(low, shape=batch_shape)
     # The support is built from the caller's ``low``, not the promoted copy.
     self._support = constraints.greater_than(low)
     super().__init__(batch_shape, validate_args=validate_args)
Beispiel #4
 def codomain(self):
     """Return the constraint satisfied by this transform's outputs.

     A real domain maps to ``constraints.positive``.  For ``greater_than``
     and ``interval`` domains the bounds are pushed through the transform
     itself.

     Raises:
         NotImplementedError: for any other kind of domain.
     """
     if self.domain is constraints.real:
         return constraints.positive
     elif isinstance(self.domain, constraints.greater_than):
         return constraints.greater_than(self.__call__(self.domain.lower_bound))
     elif isinstance(self.domain, constraints.interval):
         return constraints.interval(self.__call__(self.domain.lower_bound),
                                     self.__call__(self.domain.upper_bound))
     else:
         raise NotImplementedError
Beispiel #5
 def __init__(self, base_dist, low=0.0, validate_args=None):
     """Wrap ``base_dist`` with a lower bound ``low``.

     The batch shape is the broadcast of ``base_dist.batch_shape`` with the
     shape of ``low``; the base distribution's leaves and ``low`` are both
     promoted to that shape.  The support is ``greater_than(low)``.

     Args:
         base_dist: an instance of one of ``self.supported_types`` whose
             support is ``constraints.real`` (both checked with ``assert``).
         low: lower bound of the support; defaults to ``0.0``.
         validate_args: forwarded unchanged to the parent constructor.
     """
     assert isinstance(base_dist, self.supported_types)
     assert (
         base_dist.support is constraints.real
     ), "The base distribution should be univariate and have real support."
     batch_shape = lax.broadcast_shapes(base_dist.batch_shape, jnp.shape(low))
     self.base_dist = tree_map(
         lambda p: promote_shapes(p, shape=batch_shape)[0], base_dist
     )
     (self.low,) = promote_shapes(low, shape=batch_shape)
     # The support is built from the caller's ``low``, not the promoted copy.
     self._support = constraints.greater_than(low)
     super().__init__(batch_shape, validate_args=validate_args)
Beispiel #6
class pareto_gen(jax_continuous):
    """Pareto distribution with shape parameter ``b``.

    The support is ``greater_than(1)`` and the CDF is ``1 - x ** (-b)``.
    """

    arg_constraints = {'b': constraints.positive}
    _support_mask = constraints.greater_than(1)

    def _rvs(self, b):
        """Draw samples of shape ``self._size`` using ``self._random_state``."""
        return random.pareto(self._random_state, b, shape=self._size)

    def _cdf(self, x, b):
        """Cumulative distribution function: ``1 - x ** (-b)``."""
        return 1 - x ** (-b)

    def _ppf(self, q, b):
        """Quantile function, the inverse of ``_cdf``."""
        # ``np.power`` instead of ``np.pow``: the latter is missing from
        # NumPy releases before 2.0.
        return np.power(1 - q, -1.0 / b)

    def _sf(self, x, b):
        """Survival function: ``x ** (-b)``."""
        return x ** (-b)

    def _stats(self, b, moments='mv'):
        """Return ``(mean, variance, skewness, excess kurtosis)``.

        Only the statistics whose letter ('m', 'v', 's', 'k') appears in
        ``moments`` are computed; the others are returned as ``None``.  Each
        statistic needs ``b`` above a threshold (1, 2, 3, 4 respectively);
        below it the mean and variance are ``inf`` and the skewness and
        kurtosis are ``nan``.  Results have the same shape as ``b``.
        """
        mu, mu2, g1, g2 = None, None, None, None
        # In every branch ``bt`` equals ``b`` where the statistic is defined
        # and a harmless placeholder elsewhere, so the closed-form expression
        # is evaluated element-wise on a full-size array (no division by
        # zero, no sqrt of a negative) and the outer ``np.where`` discards the
        # placeholder results.  Selecting with ``np.extract`` instead would
        # yield an array whose length no longer matches ``mask``.
        if 'm' in moments:
            mask = b > 1
            bt = np.where(mask, b, 2.0)
            mu = np.where(mask, bt / (bt - 1.0), np.inf)
        if 'v' in moments:
            mask = b > 2
            bt = np.where(mask, b, 3.0)
            mu2 = np.where(mask, bt / (bt - 2.0) / (bt - 1.0) ** 2, np.inf)
        if 's' in moments:
            mask = b > 3
            bt = np.where(mask, b, 4.0)
            vals = 2 * (bt + 1.0) * np.sqrt(bt - 2.0) / ((bt - 3.0) * np.sqrt(bt))
            g1 = np.where(mask, vals, np.nan)
        if 'k' in moments:
            mask = b > 4
            bt = np.where(mask, b, 5.0)
            vals = (6.0 * np.polyval([1.0, 1.0, -6, -2], bt)
                    / np.polyval([1.0, -7.0, 12.0, 0.0], bt))
            g2 = np.where(mask, vals, np.nan)
        return mu, mu2, g1, g2

    def _entropy(self, c):
        """Entropy: ``1 + 1 / c - log(c)``."""
        # NOTE(review): the argument is named ``c`` here but ``b`` in the other
        # methods; presumably the same shape parameter. Confirm.
        return 1 + 1.0 / c - np.log(c)
def _get_codomain(bijector):
    """Map a bijector to the constraint describing its output space.

    Dispatch is on the bijector's class name.  A ``Chain`` is resolved through
    the last entry of ``bijector.bijectors``; any class name not listed below
    falls back to ``constraints.real``.
    """
    kind = bijector.__class__.__name__
    if kind == "Sigmoid":
        return constraints.interval(bijector.low, bijector.high)
    elif kind == "Identity":
        return constraints.real
    elif kind in ["Exp", "SoftPlus"]:
        return constraints.positive
    elif kind == "GeneralizedPareto":
        loc, scale, concentration = bijector.loc, bijector.scale, bijector.concentration
        if not_jax_tracer(concentration) and jnp.all(concentration < 0):
            return constraints.interval(loc, loc + scale / jnp.abs(concentration))
        # XXX: here we suppose concentration > 0
        # which is not true in general, but should cover enough usage cases
        else:
            return constraints.greater_than(loc)
    elif kind == "SoftmaxCentered":
        return constraints.simplex
    elif kind == "Chain":
        return _get_codomain(bijector.bijectors[-1])
    else:
        return constraints.real
Beispiel #8
 def support(self):
     """Return a ``greater_than`` constraint whose lower bound is ``self.scale``."""
     lower_bound = self.scale
     return constraints.greater_than(lower_bound)
Beispiel #9

    # NOTE(review): excerpt of a parametrized case table. The call that opens
    # it and the test it belongs to are not visible here, and several cases
    # below have lost their opening line (the one naming the constraint).
    # Each case is a (constraint, x, expected) triple, per the name string.
    'constraint, x, expected',
        (constraints.boolean, np.array([True, False]), np.array([True, True])),
        (constraints.boolean, np.array([1, 1]), np.array([True, True])),
        (constraints.boolean, np.array([-1, 1]), np.array([False, True])),
         # NOTE(review): the opening of this case (its constraint) is missing.
         np.array([[[1, 0], [0, 1]], [[1, 0.1], [0, 1]]
                   ]), np.array([True, False])),  # NB: not lower_triangular
         # NOTE(review): the opening of this case (its constraint) is missing.
         np.array([[[1, 0], [1, 0]], [[1, 0], [0.5, 0.5]]]),
         np.array([False, False
                   ])),  # NB: not positive_diagonal & not unit_norm_row
        (constraints.greater_than(1), 3, True),
        (constraints.greater_than(1), np.array(
            [-1, 1, 5]), np.array([False, False, True])),
        (constraints.integer_interval(-3, 5), 0, True),
        (constraints.integer_interval(-3, 5), np.array([-5, -3, 0, 1.1, 5, 7]),
         np.array([False, True, True, False, True, False])),
        (constraints.interval(-3, 5), 0, True),
        (constraints.interval(-3, 5), np.array(
            [-5, -3, 0, 5, 7]), np.array([False, False, True, False, False])),
        (constraints.lower_cholesky, np.array([[1., 0.], [-2., 0.1]]), True),
         # NOTE(review): the opening of this case (its constraint) is missing.
         np.array([[[1., 0.], [-2., -0.1]], [[1., 0.1], [2., 0.2]]
                   ]), np.array([False, False])),
        (constraints.nonnegative_integer, 3, True),
        (constraints.nonnegative_integer, np.array(
            [-1., 0., 5.]), np.array([False, True, True])),
Beispiel #10
 def model():
     """Register params "a" and "b", then sample site "x" from ``Normal(a, b)``.

     "a" is constrained to be greater than ``a_minval`` and "b" to be
     positive; ``a_init``, ``b_init`` and ``a_minval`` come from the
     enclosing scope.
     """
     lower_bounded = constraints.greater_than(a_minval)
     param_a = numpyro.param("a", a_init, constraint=lower_bounded)
     param_b = numpyro.param("b", b_init, constraint=constraints.positive)
     numpyro.sample("x", dist.Normal(param_a, param_b))
Beispiel #11
 def model():
     """Register params 'a' and 'b', then observe ``obs`` at site 'x'.

     'a' is constrained to be greater than ``a_minval`` and 'b' to be
     positive; site 'x' is ``Normal(a, b)`` conditioned on ``obs``.  All free
     names (``a_init``, ``a_minval``, ``b_init``, ``obs``) come from the
     enclosing scope.
     """
     a = numpyro.param('a', a_init,
                       constraint=constraints.greater_than(a_minval))
     b = numpyro.param('b', b_init, constraint=constraints.positive)
     numpyro.sample('x', dist.Normal(a, b), obs=obs)

# Each case is a (constraint, x, expected) triple, matching the argument names
# passed to parametrize.
# NOTE(review): the list is cut off in this excerpt; its closing bracket and
# the decorated test function are not visible.
@pytest.mark.parametrize('constraint, x, expected', [
    (constraints.boolean, np.array([True, False]), np.array([True, True])),
    (constraints.boolean, np.array([1, 1]), np.array([True, True])),
    (constraints.boolean, np.array([-1, 1]), np.array([False, True])),
    (constraints.corr_cholesky, np.array([[[1, 0], [0, 1]], [[1, 0.1], [0, 1]]]),
     np.array([True, False])),  # NB: not lower_triangular
    (constraints.corr_cholesky, np.array([[[1, 0], [1, 0]], [[1, 0], [0.5, 0.5]]]),
     np.array([False, False])),  # NB: not positive_diagonal & not unit_norm_row
    (constraints.corr_matrix, np.array([[[1, 0], [0, 1]], [[1, 0.1], [0, 1]]]),
     np.array([True, False])),  # NB: second matrix is not symmetric
    (constraints.corr_matrix, np.array([[[1, 0], [1, 0]], [[1, 0], [0.5, 0.5]]]),
     np.array([False, False])),  # NB: not unit diagonal
    (constraints.greater_than(1), 3, True),
    (constraints.greater_than(1), np.array([-1, 1, 5]), np.array([False, False, True])),
    (constraints.integer_interval(-3, 5), 0, True),
    (constraints.integer_interval(-3, 5), np.array([-5, -3, 0, 1.1, 5, 7]),
     np.array([False, True, True, False, True, False])),
    (constraints.interval(-3, 5), 0, True),
    (constraints.interval(-3, 5), np.array([-5, -3, 0, 5, 7]),
     np.array([False, False, True, False, False])),
    (constraints.lower_cholesky, np.array([[1., 0.], [-2., 0.1]]), True),
    (constraints.lower_cholesky, np.array([[[1., 0.], [-2., -0.1]], [[1., 0.1], [2., 0.2]]]),
     np.array([False, False])),
    (constraints.nonnegative_integer, 3, True),
    (constraints.nonnegative_integer, np.array([-1., 0., 5.]), np.array([False, True, True])),
    (constraints.positive, 3, True),
    (constraints.positive, np.array([-1, 0, 5]), np.array([False, False, True])),
    (constraints.positive_definite, np.array([[1., 0.3], [0.3, 1.]]), True),
Beispiel #13
 def tree_unflatten(cls, aux_data, params):
     """Rebuild an instance from ``params``, restoring its support if given.

     ``cls`` is called with ``params`` unpacked positionally.  When
     ``aux_data`` is not ``None`` the instance's ``_support`` is set to
     ``greater_than(aux_data)``.
     """
     instance = cls(*params)
     if aux_data is None:
         return instance
     instance._support = constraints.greater_than(aux_data)
     return instance