def log_prob(self, value):
    """Return the Binomial log pmf of `value` given `self.total_count` trials.

    The binomial coefficient is evaluated through log-gamma so the result
    stays finite for large counts.
    """
    n = self.total_count
    # log C(n, k) = lgamma(n+1) - lgamma(k+1) - lgamma(n-k+1)
    log_n_choose_k = gammaln(n + 1) - gammaln(value + 1) - gammaln(n - value + 1)
    # xlogy / xlog1py evaluate 0 * log(0) as 0, keeping p in {0, 1} finite.
    return log_n_choose_k + xlogy(value, self.probs) + xlog1py(n - value, -self.probs)
def logpmf(self, x, n, p):
    """Multinomial log pmf over the last axis.

    `p` holds unnormalized logits when ``self.is_logits`` is set, otherwise
    probabilities.
    """
    x, n, p = _promote_dtypes(x, n, p)
    if self.is_logits:
        # Logits are normalized through logsumexp instead of being
        # converted to probabilities first.
        return (gammaln(n + 1)
                + np.sum(x * p - gammaln(x + 1), axis=-1)
                - n * logsumexp(p, axis=-1))
    return gammaln(n + 1) + np.sum(xlogy(x, p) - gammaln(x + 1), axis=-1)
def log_prob(self, value):
    """Multinomial log probability of the count vector `value`."""
    if self._validate_args:
        self._validate_sample(value)
    # Cast the integer counts to the probs dtype so gammaln gets floats.
    dtype = get_dtypes(self.probs)[0]
    value = lax.convert_element_type(value, dtype)
    total_count = lax.convert_element_type(self.total_count, dtype)
    per_category = xlogy(value, self.probs) - gammaln(value + 1)
    return gammaln(total_count + 1) + np.sum(per_category, axis=-1)
def log_prob(self, value):
    """Binomial log pmf of `value`, with optional sample validation."""
    if self._validate_args:
        self._validate_sample(value)
    n, k = self.total_count, value
    # log of the binomial coefficient C(n, k) via log-gamma.
    log_n_choose_k = gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1)
    # xlogy / xlog1py treat 0 * log(0) as 0 at the probability edges.
    return log_n_choose_k + xlogy(k, self.probs) + xlog1py(n - k, -self.probs)
def log_prob(self, value):
    """Binomial log pmf; count arguments are cast to the dtype of `self.probs`."""
    if self._validate_args:
        self._validate_sample(value)
    dtype = get_dtypes(self.probs)[0]
    value = lax.convert_element_type(value, dtype)
    total_count = lax.convert_element_type(self.total_count, dtype)
    # Binomial coefficient in log space to avoid factorial overflow.
    log_comb = (gammaln(total_count + 1)
                - gammaln(value + 1)
                - gammaln(total_count - value + 1))
    return (log_comb
            + xlogy(value, self.probs)
            + xlog1py(total_count - value, -self.probs))
def _logpmf(self, x, n, p):
    """Binomial log pmf; `p` is a logit when ``self.is_logits``, else a probability."""
    x, n, p = _promote_dtypes(x, n, p)
    log_n_choose_x = gammaln(n + 1) - (gammaln(x + 1) + gammaln(n - x + 1))
    if not self.is_logits:
        return log_n_choose_x + xlogy(x, p) + xlog1py(n - x, -p)
    # TODO: move this implementation to PyTorch if it does not get non-continuous problem
    # In PyTorch, k * logit - n * log1p(e^logit) get overflow when logit is a large
    # positive number. In that case, we can reformulate into
    # k * logit - n * log1p(e^logit) = k * logit - n * (log1p(e^-logit) + logit)
    # = k * logit - n * logit - n * log1p(e^-logit)
    # More context: https://github.com/pytorch/pytorch/pull/15962/
    return log_n_choose_x + x * p - (n * np.clip(p, 0) + xlog1py(n, np.exp(-np.abs(p))))
def logpdf(self, x, alpha):
    """Dirichlet log density of `x` under concentration vector `alpha`."""
    # log B(alpha) is the normalizer; xlogy keeps zero entries of x finite.
    unnormalized = np.sum(xlogy(alpha - 1, x), axis=-1)
    return unnormalized - _lnB(alpha)
def log_prob(self, value):
    """Bernoulli log pmf: value * log(p) + (1 - value) * log(1 - p)."""
    if self._validate_args:
        self._validate_sample(value)
    # xlogy / xlog1py define 0 * log(0) = 0 so p in {0, 1} stays finite.
    success_term = xlogy(value, self.probs)
    failure_term = xlog1py(1 - value, -self.probs)
    return success_term + failure_term
def test_xlogy_jac(x, y, grad1, grad2): assert_allclose(grad(lambda x, y: np.sum(xlogy(x, y)))(x, y), grad1) assert_allclose(grad(lambda x, y: np.sum(xlogy(x, y)), 1)(x, y), grad2)
def _logpmf(self, x, p): if self.is_logits: return -binary_cross_entropy_with_logits(p, x) else: # TODO: consider always clamp and convert probs to logits return xlogy(x, p) + xlog1py(1 - x, -p)
def _logpmf(self, x, mu):
    """Poisson log pmf: x * log(mu) - log(x!) - mu."""
    x, mu = _promote_dtypes(x, mu)
    # xlogy returns 0 when x == 0 even if mu == 0.
    log_pk = xlogy(x, mu) - gammaln(x + 1) - mu
    return log_pk
def log_prob(self, value):
    """Multinomial log pmf of `value` for `self.total_count` trials."""
    if self._validate_args:
        self._validate_sample(value)
    per_category = xlogy(value, self.probs) - gammaln(value + 1)
    return gammaln(self.total_count + 1) + np.sum(per_category, axis=-1)
def _logpmf(self, x, n, p):
    """Binomial log pmf evaluated at floor(x) (scipy convention for real x)."""
    k = np.floor(x)
    n, p = _promote_dtypes(n, p)
    # Binomial coefficient C(n, k) computed in log space via log-gamma.
    log_comb = gammaln(n + 1) - (gammaln(k + 1) + gammaln(n - k + 1))
    return log_comb + xlogy(k, p) + xlog1py(n - k, -p)
def _logpmf(self, x, n, p):
    """Multinomial log pmf over the last axis of the count array `x`."""
    n, p, x = _promote_dtypes(n, p, x)
    per_category = xlogy(x, p) - gammaln(x + 1)
    return gammaln(n + 1) + np.sum(per_category, axis=-1)
def log_prob(self, value):
    """Bernoulli log pmf without argument validation."""
    # xlogy / xlog1py make 0 * log(0) evaluate to 0 at the probability edges.
    return xlog1py(1 - value, -self.probs) + xlogy(value, self.probs)