示例#1
0
    def entropy(self, n, p):
        """Return the entropy of a multinomial distribution.

        Uses the decomposition

            H = n * sum_i entr(p_i) - log(n!)
                + sum_i sum_{x=1}^{max(n)} Binom(x; n, p_i) * log(x!)

        where ``entr(q) = -q * log(q)``. The inner sum starts at 1 because
        ``log(0!) = 0``.

        Args:
            n: Number of trials; a scalar or an array whose shape broadcasts
                against the leading (batch) axes of ``p``.
            p: Category probabilities, categories along the last axis.

        Returns:
            The entropy, with the broadcast batch shape of ``n`` and ``p``.
        """
        # Accept Python scalars / sequences as well as arrays; the indexing
        # and ``.ndim`` accesses below require array inputs.
        n = jnp.asarray(n)
        p = jnp.asarray(p)

        # Support points where log(x!) is non-zero. jnp.arange needs a
        # concrete bound, so n cannot be a traced value under jit.
        x = jnp.arange(1, jnp.max(n) + 1)

        term1 = n * jnp.sum(entr(p), axis=-1) - gammaln(n + 1)

        # Give n a category axis, then append trailing singleton axes to x so
        # that x broadcasts as (support, *batch, categories).
        n = n[..., jnp.newaxis]
        new_axes_needed = max(p.ndim, n.ndim) - x.ndim + 1
        # JAX arrays are immutable (``x.shape += ...`` raises), so reshape
        # into a new array instead of mutating the shape in place.
        x = jnp.reshape(x, x.shape + (1,) * new_axes_needed)

        # Reduce over the category axis (-1) and the support axis (leading).
        term2 = jnp.sum(binom.pmf(x, n, p) * gammaln(x + 1),
                        axis=(-1, -1 - new_axes_needed))

        return term1 + term2
示例#2
0
 def _entropy(self, p):
     # TODO: use logits and binary_cross_entropy_with_logits for more stable
     if self.is_logits:
         p = expit(p)
     return entr(p) + entr(1 - p)
示例#3
0
 def _entropy(self, n, p):
     if self.is_logits:
         p = expit(p)
     k = np.arange(n + 1)
     vals = self._pmf(k, n, p)
     return np.sum(entr(vals), axis=0)