Example #1
    def _entropy_scalar_with_lax(
            total_count: int, probs: Array,
            log_of_probs: Array) -> Union[jnp.float32, jnp.float64]:
        """Like `_entropy_scalar`, but uses a lax while loop."""

        dtype = probs.dtype
        log_n_factorial = lax.lgamma(jnp.asarray(total_count + 1, dtype=dtype))

        def cond_func(args):
            xi, _ = args
            return jnp.less_equal(xi, total_count)

        def body_func(args):
            xi, accumulated_sum = args
            xi_float = jnp.asarray(xi, dtype=dtype)
            log_xi_factorial = lax.lgamma(xi_float + 1.)
            log_comb_n_xi = (log_n_factorial - log_xi_factorial -
                             lax.lgamma(total_count - xi_float + 1.))
            comb_n_xi = jnp.round(jnp.exp(log_comb_n_xi))
            likelihood1 = math.power_no_nan(probs, xi)
            likelihood2 = math.power_no_nan(1. - probs, total_count - xi)
            likelihood = likelihood1 * likelihood2
            comb_term = comb_n_xi * log_xi_factorial * likelihood  # [K]
            chex.assert_shape(comb_term, (probs.shape[-1], ))
            return xi + 1, accumulated_sum + comb_term

        comb_term = jnp.sum(lax.while_loop(cond_func, body_func,
                                           (0, jnp.zeros_like(probs)))[1],
                            axis=-1)

        n_probs_factor = jnp.sum(total_count *
                                 math.multiply_no_nan(log_of_probs, probs),
                                 axis=-1)

        return -log_n_factorial - n_probs_factor + comb_term
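
This helper backs `Multinomial.entropy` in distrax. A minimal usage sketch through the public API (assumes the `distrax` package is installed; the helper itself is internal):

import jax.numpy as jnp
import distrax

dist = distrax.Multinomial(total_count=5, probs=jnp.array([0.2, 0.3, 0.5]))
print(dist.entropy())  # Scalar entropy; computed by helpers like the one above.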
Example #2
 def log_prob(self, value: Array) -> Array:
     """See `Distribution.log_prob`."""
     total_permutations = lax.lgamma(self._total_count + 1.)
     counts_factorial = lax.lgamma(value + 1.)
     redundant_permutations = jnp.sum(counts_factorial, axis=-1)
     log_combinations = total_permutations - redundant_permutations
     return log_combinations + jnp.sum(
         math.multiply_no_nan(self.log_of_probs, value), axis=-1)
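
Usage sketch (assumes the public distrax constructors; `value` holds per-category counts that should sum to `total_count`). Because the last line uses `multiply_no_nan`, a zero-probability category with a zero count contributes 0 instead of 0 * (-inf) = NaN:

import jax.numpy as jnp
import distrax

dist = distrax.Multinomial(total_count=4, probs=jnp.array([0.5, 0.5, 0.0]))
counts = jnp.array([1., 3., 0.])  # Counts sum to total_count.
print(dist.log_prob(counts))      # Finite, despite the zero-probability category.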
Example #3
 def cdf(self, value: Array) -> Array:
   """See `Distribution.cdf`."""
   # For value < 0 the output should be zero because support = {0, ..., K-1}.
   should_be_zero = value < 0
   # Will use value as an index below, so clip it to {0, ..., K-1}.
   value = jnp.clip(value, 0, self.num_categories - 1)
   value_one_hot = jax.nn.one_hot(value, self.num_categories)
   cdf = jnp.sum(math.multiply_no_nan(
       jnp.cumsum(self.probs, axis=-1), value_one_hot), axis=-1)
   return jnp.where(should_be_zero, jnp.array(0.), cdf)
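
Usage sketch (assumes public distrax). Values below 0 yield 0, and values at or above the top category are clipped and yield 1:

import jax.numpy as jnp
import distrax

dist = distrax.Categorical(probs=jnp.array([0.1, 0.2, 0.7]))
print(dist.cdf(jnp.array([-1, 0, 2])))  # ~ [0.0, 0.1, 1.0]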
Example #4
 def entropy(self) -> Array:
   """See `Distribution.entropy`."""
   if self._logits is None:
     return -jnp.sum(
         math.multiply_no_nan(jnp.log(self._probs), self._probs), axis=-1)
    # The following result can be derived as follows. Write log(p[i]) as
    # s[i] - m - lse(s - m), where m = max(s). Then sum_i p[i] log(p[i]) is:
    #   sum_i exp(s[i] - m - lse(s - m)) * (s[i] - m - lse(s - m))
    #   = -m - lse(s - m) + sum_i s[i] exp(s[i] - m - lse(s - m))
    #   = -m - lse(s - m) + (1 / exp(lse(s - m))) * sum_i s[i] exp(s[i] - m)
    #   = -m - lse(s - m) + (1 / sum_exp(s - m)) * sum_i s[i] exp(s[i] - m)
    # Writing x[i] = s[i] - m, this becomes:
    #   = -m - lse(x) + (1 / sum_exp(x)) * sum_i s[i] exp(x[i])
    # Negating this gives the Shannon (discrete) entropy.
   m = jnp.max(self._logits, axis=-1, keepdims=True)
   x = self._logits - m
   sum_exp_x = jnp.sum(jnp.exp(x), axis=-1)
   lse_logits = jnp.squeeze(m, axis=-1) + jnp.log(sum_exp_x)
   return lse_logits - jnp.sum(
       math.multiply_no_nan(self._logits, jnp.exp(x)), axis=-1) / sum_exp_x
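
The logits branch is algebraically the same as -sum_i p[i] log(p[i]) with p = softmax(logits); a quick numerical cross-check sketch (assumes public distrax; `jax.nn.log_softmax` is standard JAX):

import jax
import jax.numpy as jnp
import distrax

logits = jnp.array([2.0, 0.5, -1.0])
log_p = jax.nn.log_softmax(logits)
direct = -jnp.sum(jnp.exp(log_p) * log_p)  # Textbook entropy.
print(distrax.Categorical(logits=logits).entropy(), direct)  # Agree up to float error.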
Example #5
def _kl_divergence_bernoulli_bernoulli(
    dist1: Union[Bernoulli, tfd.Bernoulli],
    dist2: Union[Bernoulli, tfd.Bernoulli],
    *unused_args,
    **unused_kwargs,
) -> Array:
    """KL divergence `KL(dist1 || dist2)` between two Bernoulli distributions.

  Args:
    dist1: instance of a Bernoulli distribution.
    dist2: instance of a Bernoulli distribution.

  Returns:
    Batchwise `KL(dist1 || dist2)`.
  """
    one_minus_p1, p1, log_one_minus_p1, log_p1 = _probs_and_log_probs(dist1)
    _, _, log_one_minus_p2, log_p2 = _probs_and_log_probs(dist2)
    # KL[a || b] = Pa * Log[Pa / Pb] + (1 - Pa) * Log[(1 - Pa) / (1 - Pb)]
    # Multiply each factor individually to avoid Inf - Inf
    return (math.multiply_no_nan(log_p1, p1) -
            math.multiply_no_nan(log_p2, p1) +
            math.multiply_no_nan(log_one_minus_p1, one_minus_p1) -
            math.multiply_no_nan(log_one_minus_p2, one_minus_p1))
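
This function is the registered KL implementation for Bernoulli pairs; in practice it is reached through `dist1.kl_divergence(dist2)`. Because every term goes through `multiply_no_nan`, a deterministic `dist1` stays finite as long as `dist2` puts mass on the realized outcome. A sketch (assumes public distrax):

import distrax

dist1 = distrax.Bernoulli(probs=0.0)  # Deterministic: always 0.
dist2 = distrax.Bernoulli(probs=0.5)
print(dist1.kl_divergence(dist2))     # -log(1 - 0.5) ~ 0.693, no NaN.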
Example #6
    def _entropy_scalar(
            total_count: int, probs: Array,
            log_of_probs: Array) -> Union[jnp.float32, jnp.float64]:
        """Calculates the entropy for a Multinomial with integer `total_count`."""
        # Constant factors in the entropy.
        xi = jnp.arange(total_count + 1, dtype=probs.dtype)
        log_xi_factorial = lax.lgamma(xi + 1)
        log_n_minus_xi_factorial = jnp.flip(log_xi_factorial, axis=-1)
        log_n_factorial = log_xi_factorial[..., -1]
        log_comb_n_xi = (log_n_factorial[..., None] - log_xi_factorial -
                         log_n_minus_xi_factorial)
        comb_n_xi = jnp.round(jnp.exp(log_comb_n_xi))
        chex.assert_shape(comb_n_xi, (total_count + 1, ))

        likelihood1 = math.power_no_nan(probs[..., None], xi)
        likelihood2 = math.power_no_nan(1. - probs[..., None],
                                        total_count - xi)
        chex.assert_shape(likelihood1, (
            probs.shape[-1],
            total_count + 1,
        ))
        chex.assert_shape(likelihood2, (
            probs.shape[-1],
            total_count + 1,
        ))
        likelihood = jnp.sum(likelihood1 * likelihood2, axis=-2)
        chex.assert_shape(likelihood, (total_count + 1, ))
        comb_term = jnp.sum(comb_n_xi * log_xi_factorial * likelihood, axis=-1)
        chex.assert_shape(comb_term, ())

        # Probs factors in the entropy.
        n_probs_factor = jnp.sum(total_count *
                                 math.multiply_no_nan(log_of_probs, probs),
                                 axis=-1)

        return -log_n_factorial - n_probs_factor + comb_term
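
`power_no_nan` is what keeps the 0**0 corner cases defined here: when probs[k] == 0 and xi == 0 (or 1 - probs[k] == 0 and total_count - xi == 0), the likelihood factor must be exactly 1 regardless of the base. A plausible reference definition, consistent with that contract (an assumption about the helper, not verified source):

import jax.numpy as jnp

def power_no_nan(x, y):
    # x ** y, except the result is forced to 1 wherever y == 0,
    # whatever the value of x (0, inf, or NaN).
    return jnp.where(y == 0, 1.0, jnp.power(x, y))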
Example #7
 def log_prob(self, value: Array) -> Array:
   """See `Distribution.log_prob`."""
   value_one_hot = jax.nn.one_hot(value, self.num_categories)
   return jnp.sum(math.multiply_no_nan(self.logits, value_one_hot), axis=-1)
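
This is only a valid log-probability because the `logits` property here holds normalized log-probs; under that assumption, the one-hot contraction equals `log_softmax(raw_logits)[value]`. Cross-check sketch (assumes public distrax):

import jax
import jax.numpy as jnp
import distrax

logits = jnp.array([2.0, 0.5, -1.0])
dist = distrax.Categorical(logits=logits)
print(dist.log_prob(1), jax.nn.log_softmax(logits)[1])  # Should match.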
Example #8
 def test_multiply_no_nan(self):
     zero = jnp.zeros(())
     nan = zero / zero
     self.assertTrue(jnp.isnan(math.multiply_no_nan(zero, nan)))
     self.assertFalse(jnp.isnan(math.multiply_no_nan(nan, zero)))
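
The test pins down the asymmetry: only the second argument acts as the zero mask, matching TF's `tf.math.multiply_no_nan` convention. A reference definition consistent with these expectations (an assumption, not verified source):

import jax.numpy as jnp

def multiply_no_nan(x, y):
    # x * y, except the result is forced to 0 wherever y == 0, even if
    # x is NaN or infinite. Nothing guards x, so (0, nan) is still NaN.
    return jnp.where(y == 0, 0.0, x * y)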
Example #9
 def entropy(self) -> Array:
     """See `Distribution.entropy`."""
     (probs0, probs1, log_probs0, log_probs1) = _probs_and_log_probs(self)
     return -1. * (math.multiply_no_nan(log_probs0, probs0) +
                   math.multiply_no_nan(log_probs1, probs1))
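
The `multiply_no_nan` guard makes the deterministic edges exact: at probs of 0 or 1, the 0 * log(0) term collapses to 0 and the entropy is exactly zero. Sketch (assumes public distrax):

import distrax

print(distrax.Bernoulli(probs=0.5).entropy())  # ~ 0.693 (log 2).
print(distrax.Bernoulli(probs=1.0).entropy())  # Zero, not NaN.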
Example #10
 def prob(self, value: Array) -> Array:
     """See `Distribution.prob`."""
     probs1 = self.probs
     probs0 = 1 - probs1
     return (math.multiply_no_nan(probs0, 1 - value) +
             math.multiply_no_nan(probs1, value))
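
For value in {0, 1} this reduces to selecting probs0 or probs1. Usage sketch (assumes public distrax):

import jax.numpy as jnp
import distrax

dist = distrax.Bernoulli(probs=0.3)
print(dist.prob(jnp.array([0, 1])))  # ~ [0.7, 0.3]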
Example #11
 def log_prob(self, value: Array) -> Array:
     """See `Distribution.log_prob`."""
     log_probs0, log_probs1 = self._log_probs_parameter()
     return (math.multiply_no_nan(log_probs0, 1 - value) +
             math.multiply_no_nan(log_probs1, value))
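
Same selector in log space, where the guard genuinely matters: a deterministic Bernoulli has one log-probability equal to -inf, and `multiply_no_nan` zeroes it out whenever its weight is 0. Edge-case sketch (assumes public distrax):

import distrax

dist = distrax.Bernoulli(probs=0.0)  # log_probs1 is -inf.
print(dist.log_prob(0.))  # 0.0: multiply_no_nan(-inf, 0) contributes nothing.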