def _logpmf(self, x, n, p):
    """Binomial log-PMF of ``x`` successes out of ``n`` trials.

    ``p`` is interpreted as logits when ``self.is_logits`` is truthy and as
    success probabilities otherwise.
    """
    x, n, p = _promote_dtypes(x, n, p)
    # log C(n, x) expressed with log-gamma so real-valued / batched inputs work.
    log_binom_coeff = gammaln(n + 1) - (gammaln(x + 1) + gammaln(n - x + 1))
    if not self.is_logits:
        # x*log(p) + (n - x)*log(1 - p); xlogy / xlog1py treat 0*log(0) as 0.
        return log_binom_coeff + xlogy(x, p) + xlog1py(n - x, -p)
    # Logit parameterization. Writing n * log1p(e^logit) directly overflows for
    # large positive logits, so use the identity
    #   log1p(e^logit) = max(logit, 0) + log1p(e^-|logit|).
    # TODO: move this implementation to PyTorch if it does not introduce
    # discontinuity problems there.
    # More context: https://github.com/pytorch/pytorch/pull/15962/
    log_partition = n * jnp.clip(p, 0) + xlog1py(n, jnp.exp(-jnp.abs(p)))
    return log_binom_coeff + x * p - log_partition
def log_prob(self, value):
    """Binomial log-probability of observing ``value`` successes.

    Uses ``self.total_count`` trials with success probability ``self.probs``:
    log C(n, k) + k*log(p) + (n - k)*log(1 - p).
    """
    n, k, p = self.total_count, value, self.probs
    log_choose = gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1)
    # xlogy / xlog1py make the 0 * log(0) terms vanish at p == 0 or p == 1.
    return log_choose + xlogy(k, p) + xlog1py(n - k, -p)
def binomial_lpmf(k, n, p):
    """Log-PMF of a Binomial(n, p) distribution evaluated at ``k``.

    Credit to https://github.com/pyro-ppl/numpyro/blob/master/numpyro/distributions/discrete.py
    """
    # log of the binomial coefficient n! / (k! (n - k)!)
    log_coeff = gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1)
    successes = xlogy(k, p)          # k * log(p), 0 when k == 0
    failures = xlog1py(n - k, -p)    # (n - k) * log(1 - p), 0 when k == n
    return log_coeff + successes + failures
def logpdf(self, k):
    """Binomial log-probability of ``floor(k)`` successes.

    Despite the name this is a log-PMF over ``self.n`` trials with success
    probability ``self.p``; ``k`` is rounded down before evaluation.
    """
    k = jnp.floor(k)
    n_minus_k = self.n - k
    # Unnormalized part: k*log(p) + (n - k)*log(1 - p), with 0*log(0) := 0.
    log_kernel = xlogy(k, self.p) + xlog1py(n_minus_k, -self.p)
    # log C(n, k) via log-gamma.
    log_coeff = gammaln(self.n + 1) - (gammaln(k + 1) + gammaln(n_minus_k + 1))
    return log_kernel + log_coeff
def log_prob(self, value):
    """Binomial log-probability of ``value`` successes under a logit parameterization.

    Evaluates log C(n, k) + k*l - n*log(1 + e^l), where n = ``self.total_count``,
    k = ``value`` and l = ``self.logits``. The log-partition term is computed via
    the overflow-free identity log(1 + e^l) = max(l, 0) + log1p(e^-|l|).
    """
    n = self.total_count
    logits = self.logits
    log_factorial_n = gammaln(n + 1)
    log_factorial_k = gammaln(value + 1)
    log_factorial_nmk = gammaln(n - value + 1)
    # np.maximum(l, 0) is the same value as np.clip(l, 0), but np.clip with only
    # a lower bound raises TypeError on NumPy < 2.1 (a_max was required there).
    normalize_term = (n * np.maximum(logits, 0)
                      + xlog1py(n, np.exp(-np.abs(logits)))
                      - log_factorial_n)
    return value * logits - log_factorial_k - log_factorial_nmk - normalize_term
def logpmf(k, p, loc=0):
    """Log-PMF of the geometric distribution shifted by ``loc``.

    Returns log(p) + (k - loc - 1) * log(1 - p) where k - loc >= 1 is allowed,
    and -inf wherever k - loc <= 0.
    """
    k, p, loc = jnp._promote_args_inexact("geom.logpmf", k, p, loc)
    zero, one = lax._const(k, 0), lax._const(k, 1)
    trial = lax.sub(k, loc)
    # Number of failures before the first success is trial - 1.
    failures = lax.sub(trial, one)
    log_probs = xlog1py(failures, -p) + lax.log(p)
    outside_support = lax.le(trial, zero)
    return jnp.where(outside_support, -jnp.inf, log_probs)
def logpmf(k, p, loc=0):
    """Log-PMF of the Bernoulli distribution shifted by ``loc``.

    For x = k - loc returns x*log(p) + (1 - x)*log(1 - p), and -inf when x is
    below 0 or above 1. NOTE: non-integer x strictly between 0 and 1 is not
    rejected by this check.
    """
    k, p, loc = jnp._promote_args_inexact("bernoulli.logpmf", k, p, loc)
    lo, hi = lax._const(k, 0), lax._const(k, 1)
    outcome = lax.sub(k, loc)
    # xlogy / xlog1py give 0 (not nan) for the 0 * log(0) terms at p in {0, 1}.
    log_probs = xlogy(outcome, p) + xlog1py(lax.sub(hi, outcome), -p)
    out_of_range = jnp.logical_or(lax.lt(outcome, lo), lax.gt(outcome, hi))
    return jnp.where(out_of_range, -jnp.inf, log_probs)
def logpdf(x, a, b, loc=0, scale=1):
    """Log-density of a Beta(a, b) distribution supported on [loc, loc + scale].

    Returns -inf for ``x`` outside that interval.
    """
    x, a, b, loc, scale = _promote_args_inexact("beta.logpdf", x, a, b, loc, scale)
    one = lax._const(x, 1)
    # Map x onto the unit interval.
    y = lax.div(lax.sub(x, loc), scale)
    # (a - 1)*log(y) + (b - 1)*log(1 - y), with 0*log(0) treated as 0.
    kernel = lax.add(xlogy(lax.sub(a, one), y),
                     xlog1py(lax.sub(b, one), lax.neg(y)))
    # Normalize by the Beta function and apply the 1/scale Jacobian.
    log_probs = lax.sub(lax.add(lax.neg(betaln(a, b)), kernel), lax.log(scale))
    outside = logical_or(lax.gt(x, lax.add(loc, scale)), lax.lt(x, loc))
    return where(outside, -inf, log_probs)
def testXlog1pyShouldReturnZero(self):
    """xlog1py(0, -1) must be 0 by convention rather than 0 * log(0) = nan."""
    x, y = 0., -1.
    result = lsp_special.xlog1py(x, y)
    self.assertAllClose(result, 0., check_dtypes=False)
def logpdf(self, x):
    """Bernoulli log-probability of outcome ``x`` with success probability ``self.p``.

    Computes x*log(p) + (1 - x)*log(1 - p); xlogy / xlog1py return 0 for a
    zero-weighted term, so p == 0 or p == 1 does not yield nan.
    """
    success_prob = self.p
    log_success = xlogy(x, success_prob)
    log_failure = xlog1py(1 - x, -success_prob)
    return log_success + log_failure
def log_prob(self, value):
    """Log-probability of a Bernoulli outcome ``value`` under ``self.probs``."""
    p = self.probs
    # value*log(p), defined as 0 when value == 0.
    hit = xlogy(value, p)
    # (1 - value)*log(1 - p), defined as 0 when value == 1.
    miss = xlog1py(1 - value, -p)
    return hit + miss
def log_prob(self, value):
    """Bernoulli log-probability of ``value`` using clamped ``self.probs``.

    NOTE(review): ``clamp_probs`` is defined elsewhere; presumably it keeps the
    probabilities away from exactly 0 and 1 — confirm against its definition.
    """
    probs = clamp_probs(self.probs)
    log_p_term = xlogy(value, probs)
    log_q_term = xlog1py(1 - value, -probs)
    return log_p_term + log_q_term
def _logpmf(self, x, p):
    """Bernoulli log-PMF of ``x``.

    ``p`` is treated as logits when ``self.is_logits`` is truthy, otherwise as
    success probabilities.
    """
    if not self.is_logits:
        # TODO: consider always clamp and convert probs to logits
        return xlogy(x, p) + xlog1py(1 - x, -p)
    # Logit branch: negate the helper's BCE-with-logits value (helper defined
    # elsewhere; argument order is (logits, targets) as called here).
    return -binary_cross_entropy_with_logits(p, x)
def logpdf(self, x):
    """Bernoulli log-probability of ``x`` with success probability ``self.p``.

    TODO: check that ``x`` belongs to the support and return -inf otherwise;
    currently any ``x`` is evaluated with the closed-form expression.
    """
    p = self.p
    return xlog1py(1 - x, -p) + xlogy(x, p)