Example #1
def log_prob(self, value):
    # Log-pmf of the Beta-Binomial distribution: the log binomial
    # coefficient log C(n, k) via log-gamma, plus the ratio of Beta
    # functions B(k + c1, n - k + c0) / B(c1, c0).
    log_factorial_n = gammaln(self.total_count + 1)
    log_factorial_k = gammaln(value + 1)
    log_factorial_nmk = gammaln(self.total_count - value + 1)
    return (log_factorial_n - log_factorial_k - log_factorial_nmk +
            betaln(value + self.concentration1,
                   self.total_count - value + self.concentration0) -
            betaln(self.concentration0, self.concentration1))
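For reference, a standalone sketch of the same Beta-Binomial log-pmf; the free-function form, parameter names, and test values here are illustrative, not from the source project:

from jax.scipy.special import betaln, gammaln

def beta_binomial_log_prob(value, total_count, concentration1, concentration0):
    # Identical arithmetic to the method above, with self.* made explicit arguments.
    log_comb = (gammaln(total_count + 1) - gammaln(value + 1)
                - gammaln(total_count - value + 1))
    return (log_comb
            + betaln(value + concentration1, total_count - value + concentration0)
            - betaln(concentration0, concentration1))

# k = 3 successes out of n = 10 trials with a Beta(2, 5) prior on the success probability
print(beta_binomial_log_prob(3.0, 10.0, 2.0, 5.0))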
Example #2
def log_prob(self, value):
    # Same Beta-Binomial log-pmf as Example #1, with the log binomial
    # coefficient folded into the private helper _log_beta_1 (see the
    # sketch below).
    return (
        -_log_beta_1(self.total_count - value + 1, value)
        + betaln(
            value + self.concentration1,
            self.total_count - value + self.concentration0,
        )
        - betaln(self.concentration0, self.concentration1)
    )
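The helper _log_beta_1 is private to the project and not shown on this page. A definition consistent with the usage above would be the following; this is an assumption reconstructed from the identity it must satisfy, not the project's actual source:

from jax.scipy.special import gammaln

def _log_beta_1(alpha, value):
    # Assumed definition: chosen so that -_log_beta_1(n - k + 1, k)
    # equals the log binomial coefficient log C(n, k), matching Example #1.
    return gammaln(value + 1) + gammaln(alpha) - gammaln(alpha + value)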
Example #3
def kl_divergence(p, q):
    # KL(Beta(a, b) || Beta(alpha, beta)), from
    # https://en.wikipedia.org/wiki/Beta_distribution#Quantities_of_information_(entropy)
    a, b = p.concentration1, p.concentration0
    alpha, beta = q.concentration1, q.concentration0
    a_diff = alpha - a
    b_diff = beta - b
    t1 = betaln(alpha, beta) - betaln(a, b)  # log of the ratio of normalizers
    t2 = a_diff * digamma(a) + b_diff * digamma(b)
    t3 = (a_diff + b_diff) * digamma(a + b)
    return t1 - t2 + t3
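A quick way to exercise this function without the full distribution classes is a stand-in carrying the two attribute names it reads; the namedtuple here is hypothetical, not the project's API:

from collections import namedtuple
from jax.scipy.special import betaln, digamma  # names kl_divergence uses

BetaParams = namedtuple("BetaParams", ["concentration1", "concentration0"])

p = BetaParams(concentration1=2.0, concentration0=3.0)
q = BetaParams(concentration1=4.0, concentration0=1.5)
print(kl_divergence(p, q))  # non-negative scalar
print(kl_divergence(p, p))  # 0.0: a distribution has zero KL divergence to itself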
Example #4
def kl_divergence(p, q):
    # KL(Kumaraswamy(a, b) || Beta(alpha, beta)),
    # from https://arxiv.org/abs/1605.06197 Formula (12)
    a, b = p.concentration1, p.concentration0
    alpha, beta = q.concentration1, q.concentration0
    b_reciprocal = jnp.reciprocal(b)
    a_b = a * b
    t1 = (alpha / a - 1) * (jnp.euler_gamma + digamma(b) + b_reciprocal)
    t2 = jnp.log(a_b) + betaln(alpha, beta) + (b_reciprocal - 1)
    # Broadcast the parameters against the series index m = 1..M before
    # summing the truncated series in the last term.
    a_ = jnp.expand_dims(a, -1)
    b_ = jnp.expand_dims(b, -1)
    a_b_ = jnp.expand_dims(a_b, -1)
    m = jnp.arange(1, p.KL_KUMARASWAMY_BETA_TAYLOR_ORDER + 1)
    t3 = (beta - 1) * b * (jnp.exp(betaln(m / a_, b_)) / (m + a_b_)).sum(-1)
    return t1 + t2 + t3
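The sum in t3 truncates the infinite series of Formula (12) after KL_KUMARASWAMY_BETA_TAYLOR_ORDER terms. A rough driver with hypothetical stand-in objects (in the project this constant is a class attribute on the Kumaraswamy distribution; here it is supplied by hand):

from collections import namedtuple
import jax.numpy as jnp
from jax.scipy.special import betaln, digamma

Kumaraswamy = namedtuple(
    "Kumaraswamy",
    ["concentration1", "concentration0", "KL_KUMARASWAMY_BETA_TAYLOR_ORDER"],
)
BetaParams = namedtuple("BetaParams", ["concentration1", "concentration0"])

p = Kumaraswamy(jnp.asarray(2.0), jnp.asarray(3.0), 10)  # keep 10 series terms
q = BetaParams(4.0, 1.5)
print(kl_divergence(p, q))  # larger truncation orders tighten the approximation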
Example #5
def log_prob(self, value):
    # Log-pmf of the Gamma-Poisson (negative binomial) mixture:
    # -betaln(c, k + 1) - log(c + k) is the log binomial coefficient
    # log C(c + k - 1, k) for concentration c and count k.
    post_value = self.concentration + value
    return (
        -betaln(self.concentration, value + 1)
        - jnp.log(post_value)
        + self.concentration * jnp.log(self.rate)
        - post_value * jnp.log1p(self.rate)
    )
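The same pmf as a standalone function, for experimenting outside the class; the function name and test values are illustrative:

import jax.numpy as jnp
from jax.scipy.special import betaln

def gamma_poisson_log_prob(value, concentration, rate):
    # A Poisson whose rate is Gamma(concentration, rate) distributed,
    # marginalized analytically into a negative binomial.
    post_value = concentration + value
    return (-betaln(concentration, value + 1) - jnp.log(post_value)
            + concentration * jnp.log(rate) - post_value * jnp.log1p(rate))

print(gamma_poisson_log_prob(4.0, 2.0, 0.5))  # log-probability of observing 4 counts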
Example #6
File: beta.py Project: 0x0is1/jax
def logpdf(x, a, b, loc=0, scale=1):
    # Log-density of Beta(a, b) rescaled to the interval [loc, loc + scale].
    x, a, b, loc, scale = _promote_args_inexact("beta.logpdf", x, a, b, loc,
                                                scale)
    one = lax._const(x, 1)
    shape_term = lax.neg(betaln(a, b))  # normalizing constant: -log B(a, b)
    y = lax.div(lax.sub(x, loc), scale)  # map x back to the unit interval
    log_linear_term = lax.add(xlogy(lax.sub(a, one), y),
                              xlog1py(lax.sub(b, one), lax.neg(y)))
    log_probs = lax.sub(lax.add(shape_term, log_linear_term), lax.log(scale))
    # Zero density (hence -inf log-density) outside [loc, loc + scale].
    return where(logical_or(lax.gt(x, lax.add(loc, scale)), lax.lt(x, loc)),
                 -inf, log_probs)
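This snippet is jax's public jax.scipy.stats.beta.logpdf; a short usage check of the support handling:

from jax.scipy.stats import beta

print(beta.logpdf(0.3, a=2.0, b=5.0))           # finite log-density inside [0, 1]
print(beta.logpdf(1.5, a=2.0, b=5.0))           # -inf: outside the support
print(beta.logpdf(1.5, a=2.0, b=5.0, loc=1.0))  # finite: support shifted to [1, 2]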