def entropy(self, n, p):
    """Entropy of the multinomial distribution with ``n`` trials and
    event probabilities ``p``.

    Uses the identity (same as SciPy's ``multinomial.entropy``):
        H = n * H(p) - log(n!) + sum_k E[log(X_k!)],  X_k ~ Binomial(n, p_k)

    Args:
        n: number of trials; scalar or batch of integers.
        p: event probabilities; last axis is the event dimension.

    Returns:
        Entropy value(s), broadcast over the batch shape of ``n`` and ``p``.
    """
    # Support values 1..max(n) over which the binomial expectation is taken.
    x = jnp.arange(1, jnp.max(n) + 1)
    # n * H(p) - log(n!)
    term1 = n * jnp.sum(entr(p), axis=-1) - gammaln(n + 1)
    n = n[..., jnp.newaxis]
    # Trailing singleton axes let x broadcast against the batch/event axes
    # of n and p inside binom.pmf.
    new_axes_needed = max(p.ndim, n.ndim) - x.ndim + 1
    # JAX arrays are immutable: the in-place `x.shape += (1,) * k` trick used
    # by the original SciPy code raises on a jnp array; reshape instead.
    x = x.reshape(x.shape + (1,) * new_axes_needed)
    # sum_k E[log(X_k!)]: reduce over both the support axis and the event axis.
    term2 = jnp.sum(binom.pmf(x, n, p) * gammaln(x + 1),
                    axis=(-1, -1 - new_axes_needed))
    return term1 + term2
def _entropy(self, p): # TODO: use logits and binary_cross_entropy_with_logits for more stable if self.is_logits: p = expit(p) return entr(p) + entr(1 - p)
def _entropy(self, n, p): if self.is_logits: p = expit(p) k = np.arange(n + 1) vals = self._pmf(k, n, p) return np.sum(entr(vals), axis=0)