Esempio n. 1
0
 def _logpmf(self, x, n, p):
     """Binomial log-probability mass of ``x`` given ``n`` and ``p``.

     ``p`` is treated as logits when ``self.is_logits`` is true and as a
     probability otherwise.
     """
     x, n, p = _promote_dtypes(x, n, p)
     # log of the binomial coefficient C(n, x), expressed with log-gamma.
     log_binom_coeff = gammaln(n + 1) - (gammaln(x + 1) + gammaln(n - x + 1))
     if not self.is_logits:
         return log_binom_coeff + xlogy(x, p) + xlog1py(n - x, -p)
     # TODO: move this implementation to PyTorch if it does not get non-continuous problem
     # The direct form x * logit - n * log1p(exp(logit)) overflows in PyTorch
     # for large positive logits. For that case the identity
     #   log1p(exp(logit)) = logit + log1p(exp(-logit))
     # gives
     #   x * logit - n * logit - n * log1p(exp(-logit)).
     # Both signs are covered at once by max(logit, 0) + log1p(exp(-|logit|)).
     # More context: https://github.com/pytorch/pytorch/pull/15962/
     n_log_normalizer = n * jnp.clip(p, 0) + xlog1py(n, jnp.exp(-jnp.abs(p)))
     return log_binom_coeff + x * p - n_log_normalizer
Esempio n. 2
0
 def log_prob(self, value):
     """Binomial log-probability of ``value``.

     Reads ``self.total_count`` (used as n in the binomial coefficient) and
     ``self.probs`` (the probability paired with ``value``).
     """
     n, p = self.total_count, self.probs
     # log C(n, value) written with log-gamma instead of factorials.
     log_choose = gammaln(n + 1) - gammaln(value + 1) - gammaln(n - value + 1)
     success_term = xlogy(value, p)
     failure_term = xlog1py(n - value, -p)
     return log_choose + success_term + failure_term
Esempio n. 3
0
def binomial_lpmf(k, n, p):
    """Binomial log-pmf evaluated at ``k`` for ``n`` trials with probability ``p``."""
    # Credit to https://github.com/pyro-ppl/numpyro/blob/master/numpyro/distributions/discrete.py
    # log C(n, k) via log-gamma.
    log_n_choose_k = gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1)
    return log_n_choose_k + xlogy(k, p) + xlog1py(n - k, -p)
Esempio n. 4
0
 def logpdf(self, k):
     """Binomial log-probability at ``floor(k)`` using ``self.n`` and ``self.p``."""
     n, p = self.n, self.p
     # Fractional inputs are rounded down before evaluation.
     successes = jnp.floor(k)
     failures = n - successes
     log_kernel = xlogy(successes, p) + xlog1py(failures, -p)
     # log C(n, successes) via log-gamma.
     log_coeff = gammaln(n + 1) - (gammaln(successes + 1) + gammaln(failures + 1))
     return log_kernel + log_coeff
Esempio n. 5
0
 def log_prob(self, value):
     """Binomial log-probability of ``value`` parameterized by logits.

     Reads ``self.total_count`` (used as n) and ``self.logits``.
     """
     n, logits = self.total_count, self.logits
     lgamma_n = gammaln(n + 1)
     lgamma_k = gammaln(value + 1)
     lgamma_n_minus_k = gammaln(n - value + 1)
     # n * log(1 + exp(logits)), written as max(logits, 0) + log1p(exp(-|logits|))
     # so the exponential never sees a large positive argument.
     n_softplus = n * np.clip(logits, 0) + xlog1py(n, np.exp(-np.abs(logits)))
     log_normalizer = n_softplus - lgamma_n
     return value * logits - lgamma_k - lgamma_n_minus_k - log_normalizer
Esempio n. 6
0
def logpmf(k, p, loc=0):
    """Geometric log-pmf of ``k`` shifted by ``loc``.

    With ``x = k - loc`` the result is ``(x - 1) * log1p(-p) + log(p)``, and
    ``-inf`` wherever ``x <= 0``.
    """
    k, p, loc = jnp._promote_args_inexact("geom.logpmf", k, p, loc)
    shifted = lax.sub(k, loc)
    failures = lax.sub(shifted, lax._const(k, 1))
    in_support_logp = xlog1py(failures, -p) + lax.log(p)
    outside_support = lax.le(shifted, lax._const(k, 0))
    return jnp.where(outside_support, -jnp.inf, in_support_logp)
Esempio n. 7
0
def logpmf(k, p, loc=0):
  """Bernoulli log-pmf of ``k`` shifted by ``loc``.

  With ``x = k - loc`` the result is ``x * log(p) + (1 - x) * log1p(-p)``, and
  ``-inf`` wherever ``x < 0`` or ``x > 1``.
  """
  k, p, loc = jnp._promote_args_inexact("bernoulli.logpmf", k, p, loc)
  lower = lax._const(k, 0)
  upper = lax._const(k, 1)
  shifted = lax.sub(k, loc)
  in_range_logp = xlogy(shifted, p) + xlog1py(lax.sub(upper, shifted), -p)
  out_of_range = jnp.logical_or(lax.lt(shifted, lower), lax.gt(shifted, upper))
  return jnp.where(out_of_range, -jnp.inf, in_range_logp)
Esempio n. 8
0
File: beta.py Progetto: 0x0is1/jax
def logpdf(x, a, b, loc=0, scale=1):
    """Beta log-density of ``x`` with shapes ``a``, ``b`` on ``[loc, loc + scale]``.

    With ``y = (x - loc) / scale`` the result is
    ``-betaln(a, b) + (a - 1) * log(y) + (b - 1) * log1p(-y) - log(scale)``,
    and ``-inf`` wherever ``x`` lies outside ``[loc, loc + scale]``.
    """
    x, a, b, loc, scale = _promote_args_inexact(
        "beta.logpdf", x, a, b, loc, scale)
    one = lax._const(x, 1)
    log_norm = lax.neg(betaln(a, b))
    # Standardize x onto the unit interval.
    y = lax.div(lax.sub(x, loc), scale)
    log_y_term = xlogy(lax.sub(a, one), y)
    log_one_minus_y_term = xlog1py(lax.sub(b, one), lax.neg(y))
    log_kernel = lax.add(log_y_term, log_one_minus_y_term)
    log_density = lax.sub(lax.add(log_norm, log_kernel), lax.log(scale))
    above = lax.gt(x, lax.add(loc, scale))
    below = lax.lt(x, loc)
    return where(logical_or(above, below), -inf, log_density)
Esempio n. 9
0
 def testXlog1pyShouldReturnZero(self):
     # xlog1py(0, -1) pairs a zero multiplier with log1p(-1); the expected
     # value is exactly 0.
     result = lsp_special.xlog1py(0., -1.)
     self.assertAllClose(result, 0., check_dtypes=False)
Esempio n. 10
0
 def logpdf(self, x):
     """Bernoulli log-probability: ``x * log(p) + (1 - x) * log1p(-p)`` with ``p = self.p``."""
     p = self.p
     success_term = xlogy(x, p)
     failure_term = xlog1py(1 - x, -p)
     return success_term + failure_term
Esempio n. 11
0
 def log_prob(self, value):
     """Bernoulli log-probability of ``value`` under ``self.probs``."""
     probs = self.probs
     on_term = xlogy(value, probs)
     off_term = xlog1py(1 - value, -probs)
     return on_term + off_term
Esempio n. 12
0
 def log_prob(self, value):
     """Bernoulli log-probability of ``value`` using clamped probabilities.

     ``self.probs`` is passed through ``clamp_probs`` first (presumably to keep
     it away from exactly 0 and 1 -- confirm against that helper).
     """
     safe_probs = clamp_probs(self.probs)
     on_term = xlogy(value, safe_probs)
     off_term = xlog1py(1 - value, -safe_probs)
     return on_term + off_term
Esempio n. 13
0
 def _logpmf(self, x, p):
     if self.is_logits:
         return -binary_cross_entropy_with_logits(p, x)
     else:
         # TODO: consider always clamp and convert probs to logits
         return xlogy(x, p) + xlog1py(1 - x, -p)
Esempio n. 14
0
 def logpdf(self, x):
     """Bernoulli log-probability of ``x`` under success probability ``self.p``.

     For ``x`` in the support ``{0, 1}`` this is
     ``x * log(p) + (1 - x) * log1p(-p)``. Any other value of ``x`` lies
     outside the support and yields ``-inf``.
     """
     # Function-scope import so this method does not rely on a module-level
     # alias being present.
     import jax.numpy as jnp

     log_probs = xlogy(x, self.p) + xlog1py(1 - x, -self.p)
     in_support = jnp.logical_or(jnp.equal(x, 0), jnp.equal(x, 1))
     return jnp.where(in_support, log_probs, -jnp.inf)