예제 #1
0
 def codomain(self):
     """Return the constraint describing the image of ``self.domain``.

     Supported domains are ``constraints.real``, ``constraints.greater_than``,
     ``constraints.less_than`` and ``constraints.interval``. Finite bounds of
     the domain are mapped through the transform itself (``self(bound)``).

     When ``self.scale`` is a concrete (non-tracer) value whose entries are
     all negative, the transform reverses order, so a lower bound becomes an
     upper bound and vice versa. A traced scale cannot be inspected and is
     assumed positive; a concrete scale with mixed signs is likewise treated
     as order-preserving.

     Raises:
         NotImplementedError: if ``self.domain`` is none of the supported
             constraint kinds.
     """
     domain = self.domain
     if domain is constraints.real:
         return constraints.real

     def reverses_order():
         # Only a concrete scale can be inspected; any JAX tracer is
         # assumed to be positive (order-preserving).
         return not_jax_tracer(self.scale) and np.all(np.less(self.scale, 0))

     if isinstance(domain, constraints.greater_than):
         if reverses_order():
             return constraints.less_than(self(domain.lower_bound))
         return constraints.greater_than(self(domain.lower_bound))

     if isinstance(domain, constraints.less_than):
         if reverses_order():
             return constraints.greater_than(self(domain.upper_bound))
         return constraints.less_than(self(domain.upper_bound))

     if isinstance(domain, constraints.interval):
         # A decreasing map sends the upper bound to the new lower bound.
         if reverses_order():
             first, second = domain.upper_bound, domain.lower_bound
         else:
             first, second = domain.lower_bound, domain.upper_bound
         return constraints.interval(self(first), self(second))

     raise NotImplementedError
예제 #2
0
 def __init__(self, base_dist, high=0.0, validate_args=None):
     """Restrict ``base_dist`` to the region below ``high``.

     Args:
         base_dist: the distribution being restricted. It must be an
             instance of one of ``self.supported_types`` and its support must
             be exactly ``constraints.real``.
         high: upper bound of the resulting support
             (``constraints.less_than(high)``). Its shape is broadcast with
             ``base_dist.batch_shape`` to form the batch shape.
         validate_args: forwarded unchanged to the parent ``__init__``.

     Raises:
         AssertionError: if ``base_dist`` has an unsupported type or a
             non-real support (only while assertions are enabled).
     """
     assert isinstance(base_dist, self.supported_types)
     assert (
         base_dist.support is constraints.real
     ), "The base distribution should be univariate and have real support."
     batch = lax.broadcast_shapes(base_dist.batch_shape, jnp.shape(high))

     def _to_batch(leaf):
         # Promote each leaf of base_dist to the shared batch shape.
         return promote_shapes(leaf, shape=batch)[0]

     self.base_dist = tree_map(_to_batch, base_dist)
     [promoted_high] = promote_shapes(high, shape=batch)
     self.high = promoted_high
     # NOTE(review): the support is built from the raw ``high`` argument, not
     # the promoted ``self.high``. This may be deliberate (code elsewhere may
     # inspect the bound's original type/shape) — confirm before changing it.
     self._support = constraints.less_than(high)
     super().__init__(batch, validate_args=validate_args)