def _entropy_scalar_with_lax(
    total_count: int, probs: Array, log_of_probs: Array
) -> Union[jnp.float32, jnp.float64]:
  """Like `_entropy_scalar`, but uses a lax while loop."""
  dtype = probs.dtype
  log_n_factorial = lax.lgamma(jnp.asarray(total_count + 1, dtype=dtype))

  def cond_func(args):
    xi, _ = args
    return jnp.less_equal(xi, total_count)

  def body_func(args):
    xi, accumulated_sum = args
    xi_float = jnp.asarray(xi, dtype=dtype)
    log_xi_factorial = lax.lgamma(xi_float + 1.)
    log_comb_n_xi = (log_n_factorial - log_xi_factorial
                     - lax.lgamma(total_count - xi_float + 1.))
    comb_n_xi = jnp.round(jnp.exp(log_comb_n_xi))
    likelihood1 = math.power_no_nan(probs, xi)
    likelihood2 = math.power_no_nan(1. - probs, total_count - xi)
    likelihood = likelihood1 * likelihood2
    comb_term = comb_n_xi * log_xi_factorial * likelihood  # [K]
    chex.assert_shape(comb_term, (probs.shape[-1],))
    return xi + 1, accumulated_sum + comb_term

  comb_term = jnp.sum(
      lax.while_loop(cond_func, body_func, (0, jnp.zeros_like(probs)))[1],
      axis=-1)
  n_probs_factor = jnp.sum(
      total_count * math.multiply_no_nan(log_of_probs, probs), axis=-1)
  return -log_n_factorial - n_probs_factor + comb_term
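# A minimal standalone sketch (not library code) of the `lax.while_loop`
# accumulation pattern used above: the carry holds the loop index `xi` and a
# running per-category sum, and the loop runs while `xi <= total_count`. The
# per-step term below is a stand-in for the real `comb_term`.
import jax.numpy as jnp
from jax import lax


def _accumulate_sketch(total_count, probs):
  def cond_fn(carry):
    xi, _ = carry
    return xi <= total_count

  def body_fn(carry):
    xi, acc = carry
    return xi + 1, acc + xi * probs  # Placeholder for the per-xi term.

  _, acc = lax.while_loop(cond_fn, body_fn, (0, jnp.zeros_like(probs)))
  return acc  # Shape [K]: per-category sum over xi = 0, ..., total_count.


# E.g. _accumulate_sketch(3, jnp.array([0.2, 0.8])) == (0+1+2+3) * probs.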
def log_prob(self, value: Array) -> Array:
  """See `Distribution.log_prob`."""
  total_permutations = lax.lgamma(self._total_count + 1.)
  counts_factorial = lax.lgamma(value + 1.)
  redundant_permutations = jnp.sum(counts_factorial, axis=-1)
  log_combinations = total_permutations - redundant_permutations
  return log_combinations + jnp.sum(
      math.multiply_no_nan(self.log_of_probs, value), axis=-1)
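# Hedged numeric check (standalone sketch, values assumed): for counts [2, 1]
# from Multinomial(total_count=3, probs=[0.2, 0.8]) the log probability is
# log(3!/(2!1!) * 0.2**2 * 0.8). Plain `jnp.log(probs) * counts` stands in
# for the `multiply_no_nan` factor, which only matters for zero-probability
# categories paired with zero counts.
import jax.numpy as jnp
from jax import lax

counts = jnp.array([2., 1.])
probs = jnp.array([0.2, 0.8])
total_count = 3.
log_combinations = lax.lgamma(total_count + 1.) - jnp.sum(
    lax.lgamma(counts + 1.), axis=-1)
log_prob_value = log_combinations + jnp.sum(jnp.log(probs) * counts, axis=-1)
assert jnp.allclose(log_prob_value, jnp.log(3. * 0.2**2 * 0.8))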
def cdf(self, value: Array) -> Array:
  """See `Distribution.cdf`."""
  # For value < 0 the output should be zero because support = {0, ..., K-1}.
  should_be_zero = value < 0
  # Will use value as an index below, so clip it to {0, ..., K-1}.
  value = jnp.clip(value, 0, self.num_categories - 1)
  value_one_hot = jax.nn.one_hot(value, self.num_categories)
  cdf = jnp.sum(math.multiply_no_nan(
      jnp.cumsum(self.probs, axis=-1), value_one_hot), axis=-1)
  return jnp.where(should_be_zero, jnp.array(0.), cdf)
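# Standalone sketch of the cumsum/one-hot selection above (assumed values,
# not library code): with probs [0.1, 0.2, 0.7], cdf(1) picks the second
# entry of the cumulative sum, i.e. 0.1 + 0.2 = 0.3.
import jax
import jax.numpy as jnp

probs = jnp.array([0.1, 0.2, 0.7])
value = jnp.array(1)
one_hot = jax.nn.one_hot(value, probs.shape[-1])
cdf_value = jnp.sum(jnp.cumsum(probs, axis=-1) * one_hot, axis=-1)
assert jnp.allclose(cdf_value, 0.3)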
def entropy(self) -> Array:
  """See `Distribution.entropy`."""
  if self._logits is None:
    return -jnp.sum(
        math.multiply_no_nan(jnp.log(self._probs), self._probs), axis=-1)
  # The following result can be derived as follows. Write log(p[i]) as:
  # s[i]-m-lse(s-m) where m=max(s), then you have:
  #   sum_i exp(s[i]-m-lse(s-m)) (s[i] - m - lse(s-m))
  #   = -m - lse(s-m) + sum_i s[i] exp(s[i]-m-lse(s-m))
  #   = -m - lse(s-m) + (1/exp(lse(s-m))) sum_i s[i] exp(s[i]-m)
  #   = -m - lse(s-m) + (1/sumexp(s-m)) sum_i s[i] exp(s[i]-m)
  # Write x[i]=s[i]-m then you have:
  #   = -m - lse(x) + (1/sum_exp(x)) sum_i s[i] exp(x[i])
  # Negating all of this result is the Shannon (discrete) entropy.
  m = jnp.max(self._logits, axis=-1, keepdims=True)
  x = self._logits - m
  sum_exp_x = jnp.sum(jnp.exp(x), axis=-1)
  lse_logits = jnp.squeeze(m, axis=-1) + jnp.log(sum_exp_x)
  return lse_logits - jnp.sum(
      math.multiply_no_nan(self._logits, jnp.exp(x)), axis=-1) / sum_exp_x
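# Hedged check (standalone sketch, values assumed): the logsumexp-based
# expression above should match the naive -sum(p * log p) computed from
# softmax, without forming normalised log-probabilities explicitly.
import jax
import jax.numpy as jnp

logits = jnp.array([2.0, -1.0, 0.5])
m = jnp.max(logits, axis=-1, keepdims=True)
x = logits - m
sum_exp_x = jnp.sum(jnp.exp(x), axis=-1)
lse_logits = jnp.squeeze(m, axis=-1) + jnp.log(sum_exp_x)
entropy_value = lse_logits - jnp.sum(logits * jnp.exp(x), axis=-1) / sum_exp_x

p = jax.nn.softmax(logits)
assert jnp.allclose(entropy_value, -jnp.sum(p * jnp.log(p)))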
def _kl_divergence_bernoulli_bernoulli(
    dist1: Union[Bernoulli, tfd.Bernoulli],
    dist2: Union[Bernoulli, tfd.Bernoulli],
    *unused_args,
    **unused_kwargs,
) -> Array:
  """KL divergence `KL(dist1 || dist2)` between two Bernoulli distributions.

  Args:
    dist1: instance of a Bernoulli distribution.
    dist2: instance of a Bernoulli distribution.

  Returns:
    Batchwise `KL(dist1 || dist2)`.
  """
  one_minus_p1, p1, log_one_minus_p1, log_p1 = _probs_and_log_probs(dist1)
  _, _, log_one_minus_p2, log_p2 = _probs_and_log_probs(dist2)
  # KL[a || b] = Pa * Log[Pa / Pb] + (1 - Pa) * Log[(1 - Pa) / (1 - Pb)]
  # Multiply each factor individually to avoid Inf - Inf.
  return (
      math.multiply_no_nan(log_p1, p1) - math.multiply_no_nan(log_p2, p1)
      + math.multiply_no_nan(log_one_minus_p1, one_minus_p1)
      - math.multiply_no_nan(log_one_minus_p2, one_minus_p1))
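# Standalone numeric sketch (values assumed): KL(Bern(0.25) || Bern(0.75))
# via the expanded per-term formula above equals 0.25*log(1/3) + 0.75*log(3)
# = 0.5*log(3). The `multiply_no_nan` factors in the library additionally
# keep the p1 in {0, 1} edge cases finite; plain products suffice here.
import jax.numpy as jnp

p1, p2 = 0.25, 0.75
kl = (p1 * jnp.log(p1) - p1 * jnp.log(p2)
      + (1. - p1) * jnp.log(1. - p1) - (1. - p1) * jnp.log(1. - p2))
assert jnp.allclose(kl, 0.5 * jnp.log(3.))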
def _entropy_scalar(
    total_count: int, probs: Array, log_of_probs: Array
) -> Union[jnp.float32, jnp.float64]:
  """Calculates the entropy for a Multinomial with integer `total_count`."""
  # Constant factors in the entropy.
  xi = jnp.arange(total_count + 1, dtype=probs.dtype)
  log_xi_factorial = lax.lgamma(xi + 1)
  log_n_minus_xi_factorial = jnp.flip(log_xi_factorial, axis=-1)
  log_n_factorial = log_xi_factorial[..., -1]
  log_comb_n_xi = (log_n_factorial[..., None] - log_xi_factorial
                   - log_n_minus_xi_factorial)
  comb_n_xi = jnp.round(jnp.exp(log_comb_n_xi))
  chex.assert_shape(comb_n_xi, (total_count + 1,))
  likelihood1 = math.power_no_nan(probs[..., None], xi)
  likelihood2 = math.power_no_nan(1. - probs[..., None], total_count - xi)
  chex.assert_shape(likelihood1, (probs.shape[-1], total_count + 1,))
  chex.assert_shape(likelihood2, (probs.shape[-1], total_count + 1,))
  likelihood = jnp.sum(likelihood1 * likelihood2, axis=-2)
  chex.assert_shape(likelihood, (total_count + 1,))
  comb_term = jnp.sum(comb_n_xi * log_xi_factorial * likelihood, axis=-1)
  chex.assert_shape(comb_term, ())
  # Probs factors in the entropy.
  n_probs_factor = jnp.sum(
      total_count * math.multiply_no_nan(log_of_probs, probs), axis=-1)
  return -log_n_factorial - n_probs_factor + comb_term
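# Hedged sanity check (standalone sketch, values assumed): for total_count=2
# and probs=[0.3, 0.7] the closed form above reduces to
#   -log(2!) - 2 * sum_k p_k log p_k + sum_k p_k**2 * log(2!)
# since log(0!) = log(1!) = 0, and it matches the brute-force entropy over
# the three outcomes (2, 0), (1, 1) and (0, 2).
import jax.numpy as jnp

probs = jnp.array([0.3, 0.7])
closed_form = (-jnp.log(2.) - 2. * jnp.sum(probs * jnp.log(probs))
               + jnp.sum(probs**2) * jnp.log(2.))

outcome_probs = jnp.array([0.3**2, 2 * 0.3 * 0.7, 0.7**2])
brute_force = -jnp.sum(outcome_probs * jnp.log(outcome_probs))
assert jnp.allclose(closed_form, brute_force)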
def log_prob(self, value: Array) -> Array:
  """See `Distribution.log_prob`."""
  value_one_hot = jax.nn.one_hot(value, self.num_categories)
  return jnp.sum(math.multiply_no_nan(self.logits, value_one_hot), axis=-1)
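# Standalone sketch (assumed values, and assuming `self.logits` holds
# normalised log-probabilities): the one-hot inner product above simply
# selects log p[value].
import jax
import jax.numpy as jnp

log_probs = jax.nn.log_softmax(jnp.array([2.0, -1.0, 0.5]))
value = jnp.array(2)
one_hot = jax.nn.one_hot(value, 3)
assert jnp.allclose(jnp.sum(log_probs * one_hot, axis=-1), log_probs[2])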
def test_multiply_no_nan(self):
  zero = jnp.zeros(())
  nan = zero / zero
  self.assertTrue(jnp.isnan(math.multiply_no_nan(zero, nan)))
  self.assertFalse(jnp.isnan(math.multiply_no_nan(nan, zero)))
def entropy(self) -> Array:
  """See `Distribution.entropy`."""
  (probs0, probs1, log_probs0, log_probs1) = _probs_and_log_probs(self)
  return -1. * (math.multiply_no_nan(log_probs0, probs0) +
                math.multiply_no_nan(log_probs1, probs1))
def prob(self, value: Array) -> Array:
  """See `Distribution.prob`."""
  probs1 = self.probs
  probs0 = 1 - probs1
  return (math.multiply_no_nan(probs0, 1 - value) +
          math.multiply_no_nan(probs1, value))
def log_prob(self, value: Array) -> Array:
  """See `Distribution.log_prob`."""
  log_probs0, log_probs1 = self._log_probs_parameter()
  return (math.multiply_no_nan(log_probs0, 1 - value) +
          math.multiply_no_nan(log_probs1, value))
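# Standalone sketch (assumed values) of why the `multiply_no_nan` factors
# matter here: for a deterministic Bernoulli with p = 1 evaluated at
# value = 1, log(1 - p) is -inf and a plain product with (1 - value) = 0
# yields NaN, whereas zeroing the product when its second factor is 0 (the
# behaviour exercised in `test_multiply_no_nan`) gives the expected
# log_prob of 0.
import jax.numpy as jnp

log_p0, log_p1 = jnp.log(0.), jnp.log(1.)  # p = 1.
value = 1.
naive = log_p0 * (1. - value) + log_p1 * value  # NaN, since -inf * 0 = nan.
guarded = (jnp.where(1. - value == 0., 0., log_p0 * (1. - value))
           + log_p1 * value)
assert jnp.isnan(naive) and guarded == 0.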