예제 #1
0
 def _support(self, *args, **kwargs):
     """Return the support constraint shifted by the parsed ``loc``.

     The base support comes from ``self._support_mask``. Only
     ``integer_interval`` and ``integer_greater_than`` masks are handled;
     any other mask type raises ``NotImplementedError``.
     """
     _, loc, _ = self._parse_args(*args, **kwargs)
     mask = self._support_mask
     if isinstance(mask, constraints.integer_interval):
         shifted_lower = loc + mask.lower_bound
         shifted_upper = loc + mask.upper_bound
         return constraints.integer_interval(shifted_lower, shifted_upper)
     if isinstance(mask, constraints.integer_greater_than):
         shifted_lower = loc + mask.lower_bound
         return constraints.integer_greater_than(shifted_lower)
     raise NotImplementedError
예제 #2
0
class bernoulli_gen(jax_discrete):
    """Bernoulli distribution over the integers {0, 1}, parameterized by ``p``.

    When ``self.is_logits`` is true, ``p`` holds logits (any real value) and is
    mapped to a probability with ``expit`` wherever a probability is needed.
    Otherwise ``p`` is a probability in [0, 1].
    """
    _support_mask = constraints.integer_interval(0, 1)

    @property
    def arg_constraints(self):
        """Constraint on ``p``: any real for logits, the unit interval otherwise."""
        if self.is_logits:
            return {'p': constraints.real}
        else:
            return {'p': constraints.unit_interval}

    def _as_probs(self, p):
        """Return ``p`` as a probability, applying ``expit`` when it holds logits."""
        return expit(p) if self.is_logits else p

    def _rvs(self, p):
        return random.bernoulli(self._random_state, self._as_probs(p), self._size)

    def _logpmf(self, x, p):
        if self.is_logits:
            # Stay in logit space rather than converting to a probability first.
            return -binary_cross_entropy_with_logits(p, x)
        # x * log(p) + (1 - x) * log(1 - p); xlogy/xlog1py treat 0 * log(0) as 0.
        # TODO: consider always clamp and convert probs to logits
        return xlogy(x, p) + xlog1py(1 - x, -p)

    def _pmf(self, x, p):
        return np.exp(self._logpmf(x, p))

    # The binom helpers below receive a probability in the non-logits path, so
    # logits must be converted first; previously they were passed through raw.
    def _cdf(self, x, p):
        return binom._cdf(x, 1, self._as_probs(p))

    def _sf(self, x, p):
        return binom._sf(x, 1, self._as_probs(p))

    def _ppf(self, q, p):
        return binom._ppf(q, 1, self._as_probs(p))

    def _stats(self, p):
        return binom._stats(1, self._as_probs(p))

    def _entropy(self, p):
        # TODO: use logits and binary_cross_entropy_with_logits for more
        # numerical stability.
        p = self._as_probs(p)
        return entr(p) + entr(1 - p)
예제 #3
0
 def _support(self, *args, **kwargs):
     """Return the integer support ``0 .. K - 1`` (inclusive), ``K = p.shape[-1]``.

     ``p`` is the single shape argument returned by ``self._parse_args``.
     """
     shape_args, _, _ = self._parse_args(*args, **kwargs)
     (probs,) = shape_args
     highest_index = probs.shape[-1] - 1
     return constraints.integer_interval(0, highest_index)
예제 #4
0
 def _support(self, *args, **kwargs):
     """Return the integer support ``[0, n]``.

     ``n`` is the first of the two shape arguments returned by
     ``self._parse_args``; the second is unused here.
     """
     shape_args, _, _ = self._parse_args(*args, **kwargs)
     count, _ = shape_args
     return constraints.integer_interval(0, count)
예제 #5
0
 def support(self):
     """Return the integer support ``0 .. K - 1`` (inclusive).

     ``K`` is the size of the last axis of ``self.logits``. ``integer_interval``
     bounds are inclusive, so the largest valid index is ``K - 1``.
     """
     num_categories = np.shape(self.logits)[-1]
     return constraints.integer_interval(0, num_categories - 1)
예제 #6
0
 def support(self):
     """Return the integer support ``[0, self.total_count]`` (bounds inclusive)."""
     upper = self.total_count
     return constraints.integer_interval(0, upper)
예제 #7
0
 # Each case is a (constraint, x, expected) tuple, per the names string below.
 # NOTE(review): the test body is outside this view; presumably `expected` is
 # the (elementwise) result of checking `x` against `constraint` — confirm.
 'constraint, x, expected',
 [
     (constraints.boolean, np.array([True, False]), np.array([True, True])),
     (constraints.boolean, np.array([1, 1]), np.array([True, True])),
     (constraints.boolean, np.array([-1, 1]), np.array([False, True])),
     (constraints.corr_cholesky,
      np.array([[[1, 0], [0, 1]], [[1, 0.1], [0, 1]]
                ]), np.array([True, False])),  # NB: not lower_triangular
     (constraints.corr_cholesky,
      np.array([[[1, 0], [1, 0]], [[1, 0], [0.5, 0.5]]]),
      np.array([False, False
                ])),  # NB: not positive_diagonal & not unit_norm_row
     # greater_than is expected to be strict: 1 itself is not accepted.
     (constraints.greater_than(1), 3, True),
     (constraints.greater_than(1), np.array(
         [-1, 1, 5]), np.array([False, False, True])),
     # integer_interval bounds are expected inclusive; non-integers (1.1) fail.
     (constraints.integer_interval(-3, 5), 0, True),
     (constraints.integer_interval(-3, 5), np.array([-5, -3, 0, 1.1, 5, 7]),
      np.array([False, True, True, False, True, False])),
     # interval is expected to reject both endpoints -3 and 5 in these cases.
     (constraints.interval(-3, 5), 0, True),
     (constraints.interval(-3, 5), np.array(
         [-5, -3, 0, 5, 7]), np.array([False, False, True, False, False])),
     (constraints.lower_cholesky, np.array([[1., 0.], [-2., 0.1]]), True),
     # First matrix has a negative diagonal entry; second is not lower triangular.
     (constraints.lower_cholesky,
      np.array([[[1., 0.], [-2., -0.1]], [[1., 0.1], [2., 0.2]]
                ]), np.array([False, False])),
     (constraints.nonnegative_integer, 3, True),
     (constraints.nonnegative_integer, np.array(
         [-1., 0., 5.]), np.array([False, True, True])),
     # positive is expected to be strict: 0 is not accepted.
     (constraints.positive, 3, True),
     (constraints.positive, np.array([-1, 0, 5
                                      ]), np.array([False, False, True])),
예제 #8
0
 def support(self):
     """Return the integer support ``0 .. K - 1`` (inclusive).

     ``K`` is the size of the last axis of ``self.probs``.
     """
     num_categories = jnp.shape(self.probs)[-1]
     return constraints.integer_interval(0, num_categories - 1)
예제 #9
0
 def __init__(self, low=0, high=1, validate_args=None):
     """Set up a distribution supported on the integer interval ``[low, high]``.

     :param low: lower bound of the support (default 0).
     :param high: upper bound of the support (default 1).
     :param validate_args: forwarded unchanged to the base-class initializer.

     ``self.low``/``self.high`` hold the shape-promoted bounds; the batch shape
     is the broadcast of the shapes of ``low`` and ``high``.
     """
     promoted_low, promoted_high = promote_shapes(low, high)
     self.low = promoted_low
     self.high = promoted_high
     shape = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high))
     # NOTE(review): the support is built from the un-promoted bounds; this
     # presumably relies on broadcasting inside integer_interval — confirm.
     self._support = constraints.integer_interval(low, high)
     super().__init__(shape, validate_args=validate_args)