def log_prob(self, value):
    """Return the beta-binomial log probability mass at ``value``.

    Evaluates::

        log C(n, value)
        + log B(value + concentration1, n - value + concentration0)
        - log B(concentration0, concentration1)

    with ``n = self.total_count``. The binomial coefficient is expanded with
    log-gamma functions.

    :param value: number of successes at which to evaluate the log-pmf.
    :return: the log probability mass.
    """
    n = self.total_count
    # log C(n, k) = lgamma(n + 1) - lgamma(k + 1) - lgamma(n - k + 1)
    log_choose = gammaln(n + 1) - gammaln(value + 1) - gammaln(n - value + 1)
    log_b_posterior = betaln(value + self.concentration1, n - value + self.concentration0)
    log_b_prior = betaln(self.concentration0, self.concentration1)
    return log_choose + log_b_posterior - log_b_prior
def log_prob(self, value):
    """Return the beta-binomial log probability mass at ``value``.

    Evaluates::

        -_log_beta_1(n - value + 1, value)
        + log B(value + concentration1, n - value + concentration0)
        - log B(concentration0, concentration1)

    with ``n = self.total_count``.

    :param value: number of successes at which to evaluate the log-pmf.
    :return: the log probability mass.
    """
    failures = self.total_count - value
    # NOTE(review): for this to be the beta-binomial pmf, -_log_beta_1(n - k + 1, k)
    # must equal log C(n, k) (presumably log(k) + betaln(n - k + 1, k), with
    # special handling at k == 0) -- confirm against the helper's definition.
    log_choose = -_log_beta_1(failures + 1, value)
    log_b_posterior = betaln(value + self.concentration1, failures + self.concentration0)
    log_b_prior = betaln(self.concentration0, self.concentration1)
    return log_choose + log_b_posterior - log_b_prior
def kl_divergence(p, q):
    """Closed-form KL(p || q) between two Beta distributions.

    From https://en.wikipedia.org/wiki/Beta_distribution#Quantities_of_information_(entropy)::

        log B(alpha, beta) - log B(a, b)
        - (alpha - a) * digamma(a) - (beta - b) * digamma(b)
        + (alpha - a + beta - b) * digamma(a + b)

    where ``(a, b)`` are p's ``(concentration1, concentration0)`` and
    ``(alpha, beta)`` are q's.
    """
    p_c1, p_c0 = p.concentration1, p.concentration0
    q_c1, q_c0 = q.concentration1, q.concentration0
    d1 = q_c1 - p_c1
    d0 = q_c0 - p_c0
    log_norm_ratio = betaln(q_c1, q_c0) - betaln(p_c1, p_c0)
    marginal_terms = d1 * digamma(p_c1) + d0 * digamma(p_c0)
    joint_term = (d1 + d0) * digamma(p_c1 + p_c0)
    return log_norm_ratio - marginal_terms + joint_term
def kl_divergence(p, q):
    """KL(p || q) of a Kumaraswamy ``p`` from a Beta ``q``.

    Implements formula (12) of https://arxiv.org/abs/1605.06197. The infinite
    series is truncated to ``p.KL_KUMARASWAMY_BETA_TAYLOR_ORDER`` terms.
    ``(a, b)`` are p's ``(concentration1, concentration0)`` and
    ``(alpha, beta)`` are q's.
    """
    a, b = p.concentration1, p.concentration0
    alpha, beta = q.concentration1, q.concentration0
    inv_b = jnp.reciprocal(b)
    ab = a * b
    # (alpha / a - 1) * (euler_gamma + digamma(b) + 1 / b)
    digamma_term = (alpha / a - 1) * (jnp.euler_gamma + digamma(b) + inv_b)
    # log(ab) + log B(alpha, beta) - (b - 1) / b, the last term written as 1/b - 1
    log_term = jnp.log(ab) + betaln(alpha, beta) + (inv_b - 1)
    # Trailing axis so the series index m broadcasts against p's parameters.
    a_col, b_col, ab_col = (jnp.expand_dims(t, -1) for t in (a, b, ab))
    m = jnp.arange(1, p.KL_KUMARASWAMY_BETA_TAYLOR_ORDER + 1)
    series = (jnp.exp(betaln(m / a_col, b_col)) / (m + ab_col)).sum(-1)
    series_term = (beta - 1) * b * series
    return digamma_term + log_term + series_term
def log_prob(self, value):
    """Return the gamma-Poisson log probability mass at ``value``.

    Evaluates::

        lgamma(c + k) - lgamma(c) - lgamma(k + 1)
        + c * log(r) - (c + k) * log1p(r)

    with ``c = self.concentration``, ``r = self.rate`` and ``k = value``.
    The gamma-function ratio is written as ``-betaln(c, k + 1) - log(c + k)``.
    """
    c = self.concentration
    total = c + value
    log_norm = -betaln(c, value + 1) - jnp.log(total)
    return log_norm + c * jnp.log(self.rate) - total * jnp.log1p(self.rate)
def logpdf(x, a, b, loc=0, scale=1):
    """Log-density of a Beta(a, b) distribution shifted by ``loc`` and scaled by ``scale``.

    Follows the ``scipy.stats.beta.logpdf`` conventions:

    - Points with ``x < loc`` or ``x > loc + scale`` return ``-inf``. The
      support boundaries themselves are evaluated with the density formula.
    - Invalid parameters (``a <= 0``, ``b <= 0`` or ``scale <= 0``) return NaN.

    :param x: point(s) at which to evaluate the log-density.
    :param a: first shape parameter (must be > 0).
    :param b: second shape parameter (must be > 0).
    :param loc: location shift of the support (default 0).
    :param scale: width of the support (must be > 0, default 1).
    :return: the elementwise log-density.
    """
    x, a, b, loc, scale = _promote_args_inexact("beta.logpdf", x, a, b, loc, scale)
    one = lax._const(x, 1)
    zero = lax._const(x, 0)
    shape_term = lax.neg(betaln(a, b))
    y = lax.div(lax.sub(x, loc), scale)
    # xlogy / xlog1py give 0 when the coefficient is 0, so a == 1 or b == 1
    # stays finite at the support boundaries.
    log_linear_term = lax.add(xlogy(lax.sub(a, one), y),
                              xlog1py(lax.sub(b, one), lax.neg(y)))
    log_probs = lax.sub(lax.add(shape_term, log_linear_term), lax.log(scale))
    outside_support = logical_or(lax.gt(x, lax.add(loc, scale)), lax.lt(x, loc))
    result = where(outside_support, -inf, log_probs)
    # Without this mask, non-positive a, b or scale produce finite garbage or
    # +/-inf instead of signalling invalid parameters as scipy does.
    invalid_params = logical_or(logical_or(lax.le(a, zero), lax.le(b, zero)),
                                lax.le(scale, zero))
    return where(invalid_params, float("nan"), result)