예제 #1
0
 def log_prob(self, value):
     """Evaluate the binomial-form log-probability of ``value``.

     With ``n = self.total_count``, ``k = value`` and ``p = self.probs`` the
     result is::

         gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1)
             + xlogy(k, p) + xlog1py(n - k, -p)

     i.e. a log binomial coefficient plus the success and failure terms.
     """
     n = self.total_count
     p = self.probs
     # Log of the binomial coefficient, expressed through gammaln.
     log_binom_coeff = gammaln(n + 1) - gammaln(value + 1) - gammaln(n - value + 1)
     success_term = xlogy(value, p)
     failure_term = xlog1py(n - value, -p)
     return log_binom_coeff + success_term + failure_term
예제 #2
0
 def logpmf(self, x, n, p):
     """Evaluate a multinomial-form log-pmf for ``x`` with total ``n``.

     ``p`` is used as logits when ``self.is_logits`` is truthy and as
     probabilities otherwise. Reductions run over the last axis.
     """
     x, n, p = _promote_dtypes(x, n, p)
     log_n_factorial = gammaln(n + 1)
     if not self.is_logits:
         per_category = xlogy(x, p) - gammaln(x + 1)
         return log_n_factorial + np.sum(per_category, axis=-1)
     # Logits branch: subtract n * logsumexp(p) after summing x * p.
     per_category = x * p - gammaln(x + 1)
     return log_n_factorial + np.sum(per_category, axis=-1) \
         - n * logsumexp(p, axis=-1)
예제 #3
0
 def log_prob(self, value):
     """Evaluate a multinomial-form log-probability of ``value``.

     ``value`` and ``self.total_count`` are first converted to the dtype
     reported by ``get_dtypes(self.probs)[0]``; the per-category terms are
     then summed over the last axis.
     """
     if self._validate_args:
         self._validate_sample(value)
     target_dtype = get_dtypes(self.probs)[0]
     counts = lax.convert_element_type(value, target_dtype)
     n = lax.convert_element_type(self.total_count, target_dtype)
     log_n_factorial = gammaln(n + 1)
     per_category = xlogy(counts, self.probs) - gammaln(counts + 1)
     return log_n_factorial + np.sum(per_category, axis=-1)
예제 #4
0
파일: discrete.py 프로젝트: hdocmsu/numpyro
 def log_prob(self, value):
     """Evaluate the binomial-form log-probability of ``value``.

     When ``self._validate_args`` is truthy, ``self._validate_sample(value)``
     is called first. The result combines a log binomial coefficient (via
     ``gammaln``) with ``xlogy(value, probs)`` and
     ``xlog1py(total_count - value, -probs)``.
     """
     if self._validate_args:
         self._validate_sample(value)
     n, p = self.total_count, self.probs
     log_coeff = gammaln(n + 1) - gammaln(value + 1) - gammaln(n - value + 1)
     return log_coeff + xlogy(value, p) + xlog1py(n - value, -p)
예제 #5
0
 def log_prob(self, value):
     """Evaluate the binomial-form log-probability of ``value``.

     ``value`` and ``self.total_count`` are converted to the dtype reported
     by ``get_dtypes(self.probs)[0]`` before the log binomial coefficient
     and the ``xlogy`` / ``xlog1py`` terms are combined.
     """
     if self._validate_args:
         self._validate_sample(value)
     target_dtype = get_dtypes(self.probs)[0]
     k = lax.convert_element_type(value, target_dtype)
     n = lax.convert_element_type(self.total_count, target_dtype)
     # Log binomial coefficient expressed through gammaln.
     log_coeff = gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1)
     success_term = xlogy(k, self.probs)
     failure_term = xlog1py(n - k, -self.probs)
     return log_coeff + success_term + failure_term
예제 #6
0
 def _logpmf(self, x, n, p):
     """Evaluate a binomial-form log-pmf of ``x`` out of ``n``.

     ``p`` is used as a logit when ``self.is_logits`` is truthy and as a
     probability otherwise.
     """
     x, n, p = _promote_dtypes(x, n, p)
     log_binom = gammaln(n + 1) - (gammaln(x + 1) + gammaln(n - x + 1))
     if not self.is_logits:
         return log_binom + xlogy(x, p) + xlog1py(n - x, -p)
     # TODO: move this implementation to PyTorch if it does not run into a
     # non-continuity problem.
     # In PyTorch, k * logit - n * log1p(e^logit) overflows when logit is a
     # large positive number, so the expression is rewritten as
     #   k * logit - n * log1p(e^logit)
     #     = k * logit - n * (log1p(e^-logit) + logit)
     #     = k * logit - n * logit - n * log1p(e^-logit)
     # More context: https://github.com/pytorch/pytorch/pull/15962/
     normalizer = n * np.clip(p, 0) + xlog1py(n, np.exp(-np.abs(p)))
     return log_binom + x * p - normalizer
예제 #7
0
 def logpdf(self, x, alpha):
     """Return ``-_lnB(alpha) + sum(xlogy(alpha - 1, x))`` over the last axis.

     NOTE(review): this has the shape of a Dirichlet log-density with
     ``_lnB`` as the log normalizer -- confirm against ``_lnB``'s definition.
     """
     log_normalizer = _lnB(alpha)
     weighted_logs = xlogy(alpha - 1, x)
     return -log_normalizer + np.sum(weighted_logs, axis=-1)
예제 #8
0
 def log_prob(self, value):
     """Evaluate ``xlogy(value, p) + xlog1py(1 - value, -p)`` for ``p = self.probs``.

     When ``self._validate_args`` is truthy, ``self._validate_sample(value)``
     is called first.
     """
     if self._validate_args:
         self._validate_sample(value)
     p = self.probs
     on_term = xlogy(value, p)
     off_term = xlog1py(1 - value, -p)
     return on_term + off_term
def test_xlogy_jac(x, y, grad1, grad2):
    assert_allclose(grad(lambda x, y: np.sum(xlogy(x, y)))(x, y), grad1)
    assert_allclose(grad(lambda x, y: np.sum(xlogy(x, y)), 1)(x, y), grad2)
예제 #10
0
 def _logpmf(self, x, p):
     """Evaluate a Bernoulli-form log-pmf of ``x``.

     With ``self.is_logits`` truthy the result is the negated
     ``binary_cross_entropy_with_logits(p, x)``; otherwise ``p`` is used as
     a probability in ``xlogy(x, p) + xlog1py(1 - x, -p)``.
     """
     if not self.is_logits:
         # TODO: consider always clamp and convert probs to logits
         on_term = xlogy(x, p)
         off_term = xlog1py(1 - x, -p)
         return on_term + off_term
     return -binary_cross_entropy_with_logits(p, x)
예제 #11
0
 def _logpmf(self, x, mu):
     """Return ``xlogy(x, mu) - gammaln(x + 1) - mu`` after dtype promotion.

     This is the Poisson log-pmf form ``x*log(mu) - log(x!) - mu``.
     """
     counts, rate = _promote_dtypes(x, mu)
     return xlogy(counts, rate) - gammaln(counts + 1) - rate
예제 #12
0
파일: discrete.py 프로젝트: hdocmsu/numpyro
 def log_prob(self, value):
     """Evaluate a multinomial-form log-probability of ``value``.

     When ``self._validate_args`` is truthy, ``self._validate_sample(value)``
     is called first. The per-category terms
     ``xlogy(value, probs) - gammaln(value + 1)`` are summed over the last
     axis and added to ``gammaln(total_count + 1)``.
     """
     if self._validate_args:
         self._validate_sample(value)
     log_n_factorial = gammaln(self.total_count + 1)
     per_category = xlogy(value, self.probs) - gammaln(value + 1)
     return log_n_factorial + np.sum(per_category, axis=-1)
예제 #13
0
 def _logpmf(self, x, n, p):
     """Evaluate a binomial-form log-pmf at ``floor(x)`` out of ``n``.

     ``x`` is floored before use; ``n`` and ``p`` are passed through
     ``_promote_dtypes``.
     """
     successes = np.floor(x)
     n, p = _promote_dtypes(n, p)
     failures = n - successes
     log_binom = gammaln(n + 1) - (gammaln(successes + 1) + gammaln(failures + 1))
     return log_binom + xlogy(successes, p) + xlog1py(failures, -p)
예제 #14
0
 def _logpmf(self, x, n, p):
     """Evaluate a multinomial-form log-pmf of ``x`` with total ``n``.

     Inputs go through ``_promote_dtypes`` first; the per-category terms are
     summed over the last axis.
     """
     n, p, x = _promote_dtypes(n, p, x)
     log_n_factorial = gammaln(n + 1)
     per_category = xlogy(x, p) - gammaln(x + 1)
     return log_n_factorial + np.sum(per_category, axis=-1)
예제 #15
0
 def log_prob(self, value):
     """Evaluate ``xlogy(value, p) + xlog1py(1 - value, -p)`` for ``p = self.probs``."""
     p = self.probs
     on_term = xlogy(value, p)
     off_term = xlog1py(1 - value, -p)
     return on_term + off_term