def _entropy(self):
  """Shannon (discrete) entropy, reduced over the last (class) axis.

  If only `probs` are stored, computes `-sum(p * log(p))` directly. If `logits`
  are stored, uses a max-shifted formulation for numerical stability.
  """
  if self._logits is not None:
    # Write log(p[i]) = s[i] - m - lse(s - m) with m = max(s) and x = s - m:
    #   sum_i p[i] log(p[i])
    #     = -m - lse(x) + (1 / sum_exp(x)) sum_i s[i] exp(x[i]).
    # Negating this gives the entropy:
    #   H = lse(s) - sum_i s[i] exp(x[i]) / sum_exp(x).
    logits = tf.convert_to_tensor(self._logits)
    max_logit = tf.reduce_max(logits, axis=-1, keepdims=True)
    shifted = logits - max_logit
    exp_shifted = tf.math.exp(shifted)
    log_normalizer = max_logit[..., 0] + tf.reduce_logsumexp(shifted, axis=-1)
    normalizer = tf.reduce_sum(exp_shifted, axis=-1)
    # multiply_no_nan yields 0 wherever exp_shifted == 0, so -inf logits add
    # nothing instead of producing NaN.
    weighted_logits = tf.reduce_sum(
        tf.math.multiply_no_nan(logits, exp_shifted), axis=-1)
    return log_normalizer - weighted_logits / normalizer
  # With only probabilities available there is little we can do to improve
  # numerical precision.
  probs = tf.convert_to_tensor(self._probs)
  # multiply_no_nan(log(p), p) is 0 where p == 0, avoiding 0 * -inf = NaN.
  return -tf.reduce_sum(
      tf.math.multiply_no_nan(tf.math.log(probs), probs), axis=-1)
def _inverse_log_det_jacobian(self, y, use_saved_statistics=False):
  """Log |det J| of the inverse (normalizing) transformation at `y`.

  Args:
    y: `Tensor` input to the inverse transformation.
    use_saved_statistics: Python `bool` or boolean `Tensor`. If `True`, the
      layer's moving variance is used instead of the minibatch variance. The
      moving variance is also used whenever `self._training` is `False`.

  Returns:
    Scalar `Tensor` equal to
    `sum(log|gamma|) - 0.5 * sum(log(variance + epsilon))`. Apart from the
    minibatch variance it does not depend on `y`, so it is constant across
    minibatch elements.
  """
  if not self.batchnorm.built:
    # Create variables.
    self.batchnorm.build(y.shape)

  event_dims = self.batchnorm.axis
  reduction_axes = [i for i in range(len(y.shape)) if i not in event_dims]

  # At training time the ildj uses the variance of the current minibatch;
  # otherwise (or when requested) it uses the layer's moving variance.
  # `tf.where` broadcasts the moving variance against the keepdims moments.
  use_moving_variance = tf.logical_or(
      use_saved_statistics, tf.logical_not(self._training))
  batch_variance = tf.nn.moments(x=y, axes=reduction_axes, keepdims=True)[1]
  log_variance = tf.math.log(
      tf.where(use_moving_variance,
               self.batchnorm.moving_variance,
               batch_variance) +
      self.batchnorm.epsilon)

  # TODO(b/137216713): determine whether it's unsafe for the reduce_sums below
  # to happen across all axes.
  # `gamma` and `log Var(y)` reductions over event_dims.
  gamma = self.batchnorm.gamma
  if gamma is None:
    # A layer built with `scale=False` has no gamma; the scale factor is 1.
    log_total_gamma = tf.zeros([], dtype=log_variance.dtype)
  else:
    # Log(total change in area from gamma term). The determinant depends on
    # |gamma|; without `abs`, a negative gamma would produce log(<0) = NaN.
    log_total_gamma = tf.reduce_sum(tf.math.log(tf.abs(gamma)))
  # Log(total change in area from log-variance term).
  log_total_variance = tf.reduce_sum(log_variance)
  return log_total_gamma - 0.5 * log_total_variance
def _entropy(self):
  """Differential entropy of the Dirichlet distribution.

  H = log B(alpha) + (alpha_0 - K) * digamma(alpha_0)
      - sum_j (alpha_j - 1) * digamma(alpha_j),
  where alpha_0 = sum_j alpha_j and K is the size of the last axis.
  """
  alpha = tf.convert_to_tensor(self.concentration)
  num_categories = tf.cast(tf.shape(alpha)[-1], self.dtype)
  alpha_0 = tf.reduce_sum(alpha, axis=-1)
  log_beta = tf.math.lbeta(alpha)
  total_term = (alpha_0 - num_categories) * tf.math.digamma(alpha_0)
  component_term = tf.reduce_sum(
      (alpha - 1.) * tf.math.digamma(alpha), axis=-1)
  return log_beta + total_term - component_term
def _variance(self):
  """Mixture variance via the law of total variance.

  Var(Y) = E[Var(Y | X)] + Var(E[Y | X]), where X indexes the component.
  """
  with tf.control_dependencies(self._runtime_assertions):
    mixture_axis = -1 - self._event_ndims
    # Mixture weights padded to broadcast against components: [B, k, [1]*e].
    weights = distribution_utils.pad_mixture_dimensions(
        self.mixture_distribution.probs_parameter(), self,
        self.mixture_distribution, self._event_ndims)
    expected_conditional_variance = tf.reduce_sum(
        weights * self.components_distribution.variance(),
        axis=mixture_axis)  # [B, E]
    squared_deviation = tf.math.squared_difference(
        self.components_distribution.mean(),
        self._pad_sample_dims(self._mean()))
    variance_of_conditional_mean = tf.reduce_sum(
        weights * squared_deviation, axis=mixture_axis)  # [B, E]
    return expected_conditional_variance + variance_of_conditional_mean  # [B, E]
def squared_frobenius_norm(x):
  """Squared Frobenius norm over the trailing two (matrix) dimensions.

  A small helper that keeps the KL calculation readable.
  See http://mathworld.wolfram.com/FrobeniusNorm.html.
  """
  # `tf.square(tf.norm(x, ord="fro", axis=[-2, -1]))` has the same value, but
  # its gradient is undefined when p == q in KL[p, q]. Squaring and summing
  # explicitly keeps the gradient defined.
  matrix_axes = [-2, -1]
  return tf.reduce_sum(tf.square(x), axis=matrix_axes)
def _log_prob(self, x):
  """Log-probability of `x` against the logits, via softmax cross-entropy.

  `x` and the logits are broadcast against each other when their static shapes
  are not both fully defined and equal. No gradient flows into `x` (it is
  wrapped in `tf.stop_gradient`).
  """
  logits = self._logits_parameter_no_checks()
  event_size = self._event_size(logits)
  x = tf.cast(x, logits.dtype)
  x = self._maybe_assert_valid_sample(x, dtype=logits.dtype)

  shapes_statically_equal = (
      tensorshape_util.is_fully_defined(x.shape) and
      tensorshape_util.is_fully_defined(logits.shape) and
      x.shape == logits.shape)
  if not shapes_statically_equal:
    common_shape = tf.broadcast_dynamic_shape(tf.shape(logits), tf.shape(x))
    logits = tf.broadcast_to(logits, common_shape)
    x = tf.broadcast_to(x, common_shape)

  # Everything but the event axis; the result is reshaped back to this.
  batch_shape = tf.shape(logits)[:-1]
  # The fused cross-entropy op expects rank-2 [N, event_size] inputs.
  flat_logits = tf.reshape(logits, [-1, event_size])
  flat_x = tf.reshape(x, [-1, event_size])
  cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
      labels=tf.stop_gradient(flat_x), logits=flat_logits)
  return tf.reshape(-cross_entropy, batch_shape)
def _get_entropy(samples):
  """Plug-in entropy of the empirical frequencies of unique rows of `samples`.

  Uses `num_samples` and `self` from the enclosing scope.
  """
  # TODO(b/123985779): Switch to tf.unique_with_counts_v2 when exposed
  row_counts = gen_array_ops.unique_with_counts_v2(samples, axis=[0]).count
  frequencies = tf.cast(row_counts / num_samples, dtype=self.dtype)
  # Every observed row has count >= 1, so log(frequencies) is finite.
  return tf.reduce_sum(-frequencies * tf.math.log(frequencies))
def _log_prob(self, x, **kwargs):
  """Log-probability of `x`, summed over the `self.sample_shape` dimensions.

  The dimensions of `x` are laid out as
  `[S..., batch_shape, self.sample_shape, event_shape]`, where `S...` are
  optional leading sample dimensions. If `x` has fewer dimensions than
  `batch + sample_shape + event`, it is left-padded with singleton dimensions.

  Args:
    x: `Tensor` of values at which to evaluate the log-probability.
    **kwargs: forwarded to `self.distribution.log_prob`.

  Returns:
    `Tensor` of log-probabilities with the `sample_shape` dimensions reduced.
  """
  batch_ndims = prefer_static.rank_from_shape(
      self.distribution.batch_shape_tensor,
      self.distribution.batch_shape)
  extra_sample_ndims = prefer_static.rank_from_shape(self.sample_shape)
  event_ndims = prefer_static.rank_from_shape(
      self.distribution.event_shape_tensor,
      self.distribution.event_shape)
  ndims = prefer_static.rank(x)
  # (1) Expand x's dims: left-pad the shape with 1s when it is too short.
  d = ndims - batch_ndims - extra_sample_ndims - event_ndims
  x = tf.reshape(x, shape=tf.pad(
      tf.shape(x),
      paddings=[[prefer_static.maximum(0, -d), 0]],
      constant_values=1))
  # The reshape may have increased the rank (when d < 0). The permutation
  # built below must cover every dimension of the reshaped `x`, so the rank
  # has to be recomputed here.
  ndims = prefer_static.rank(x)
  sample_ndims = prefer_static.maximum(0, d)
  # (2) Transpose x's dims so the `sample_shape` dims precede the batch dims:
  # [S..., batch, sample_shape, event] -> [S..., sample_shape, batch, event].
  sample_dims = prefer_static.range(0, sample_ndims)
  batch_dims = prefer_static.range(sample_ndims, sample_ndims + batch_ndims)
  extra_sample_dims = prefer_static.range(
      sample_ndims + batch_ndims,
      sample_ndims + batch_ndims + extra_sample_ndims)
  event_dims = prefer_static.range(
      sample_ndims + batch_ndims + extra_sample_ndims,
      ndims)
  perm = prefer_static.concat(
      [sample_dims, extra_sample_dims, batch_dims, event_dims], axis=0)
  x = tf.transpose(a=x, perm=perm)
  # (3) Compute x's log_prob.
  lp = self.distribution.log_prob(x, **kwargs)
  # (4) Reduce over the `sample_shape` dims, which now follow the S... dims.
  axis = prefer_static.range(sample_ndims, sample_ndims + extra_sample_ndims)
  return tf.reduce_sum(lp, axis=axis)
def _variance(self):
  """Marginal variances of the Dirichlet: mean_i (1 - mean_i) / (1 + alpha_0).

  Here mean = alpha / alpha_0 and alpha_0 is the sum of the concentration
  over the last axis.
  """
  alpha = tf.convert_to_tensor(self.concentration)
  alpha_0 = tf.reduce_sum(alpha, axis=-1, keepdims=True)
  inv_sqrt = tf.math.rsqrt(1. + alpha_0)
  scaled_mean = inv_sqrt * (alpha / alpha_0)
  return scaled_mean * (inv_sqrt - scaled_mean)
def matrix_rank(a, tol=None, validate_args=False, name=None):
  """Computes the matrix rank: the number of non-zero SVD singular values.

  Args:
    a: (Batch of) `float`-like matrix-shaped `Tensor`(s) whose rank is to be
      computed.
    tol: Threshold at or below which a singular value is counted as 'zero'.
      Default value: `None` (i.e., `eps * max(rows, cols) * max(singular_val)`).
    validate_args: When `True`, additional assertions might be embedded in the
      graph.
      Default value: `False` (i.e., no graph assertions are added).
    name: Python `str` prefixed to ops created by this function.
      Default value: 'matrix_rank'.

  Returns:
    matrix_rank: (Batch of) `int32` scalars representing the number of non-zero
      singular values.
  """
  with tf.name_scope(name or 'matrix_rank'):
    a = tf.convert_to_tensor(a, dtype_hint=tf.float32, name='a')
    assertions = _maybe_validate_matrix(a, validate_args)
    if assertions:
      with tf.control_dependencies(assertions):
        a = tf.identity(a)
    singular_values = tf.linalg.svd(a, compute_uv=False)
    if tol is None:
      # Use the static matrix dims when known, otherwise the dynamic ones.
      if tensorshape_util.is_fully_defined(a.shape[-2:]):
        largest_dim = np.max(a.shape[-2:].as_list())
      else:
        largest_dim = tf.reduce_max(tf.shape(a)[-2:])
      eps = np.finfo(dtype_util.as_numpy_dtype(a.dtype)).eps
      tol = (eps * tf.cast(largest_dim, a.dtype) *
             tf.reduce_max(singular_values, axis=-1, keepdims=True))
    is_nonzero = singular_values > tol
    return tf.reduce_sum(tf.cast(is_nonzero, tf.int32), axis=-1)
def log_combinations(n, counts, name='log_combinations'):
  """Log of the multinomial coefficient.

  Given `n` and `counts`, where `counts` has last dimension `k`, computes the
  log of the multinomial coefficient

  ```n! / prod_i n_i!```

  where `i` runs over all `k` classes.

  Args:
    n: Floating-point `Tensor` broadcastable with `counts`. This represents `n`
      outcomes.
    counts: Floating-point `Tensor` broadcastable with `n`. This represents
      counts in `k` classes, where `k` is the last dimension of the tensor.
    name: A name for this operation (optional).

  Returns:
    log_combinations: `Tensor` representing the log of the multinomial
      coefficient between `n` and `counts`.
  """
  # Example: counts = [1, 2] gives 3! / (1! * 2!) = 3, i.e. "3 choose 2".
  # In general this is (sum counts)! / prod(counts!), with the product taken
  # along the last (distribution) dimension of `counts`. A priori, `n` is the
  # sum of `counts`.
  with tf.name_scope(name):
    n = tf.convert_to_tensor(n, name='n')
    counts = tf.convert_to_tensor(counts, name='counts')
    # lgamma(z + 1) == log(z!), so the ratio becomes a difference of sums.
    log_n_factorial = tf.math.lgamma(n + 1)
    log_counts_factorials = tf.math.lgamma(counts + 1)
    log_denominator = tf.reduce_sum(log_counts_factorials, axis=-1)
    return log_n_factorial - log_denominator
def _log_prob(self, counts):
  """Multinomial log-probability of `counts`.

  Computes `sum_i counts_i * log(p_i) + log_combinations(total_count, counts)`.
  A class with zero probability and zero count contributes 0 rather than NaN.
  """
  with tf.control_dependencies(self._maybe_assert_valid_sample(counts)):
    log_p = (tf.math.log(self._probs)
             if self._logits is None else tf.math.log_softmax(self._logits))
    k = tf.convert_to_tensor(self.total_count)
    # multiply_no_nan(x, y) returns 0 wherever y == 0 even if x is -inf, so a
    # zero count never meets log(0) as 0 * -inf = NaN.
    log_unnorm_prob = tf.reduce_sum(
        tf.math.multiply_no_nan(log_p, counts), axis=-1)
    neg_log_normalization = tfp_math.log_combinations(k, counts)
    return log_unnorm_prob + neg_log_normalization
def _mean(self):
  """Mixture mean: component means weighted by the mixture probabilities."""
  with tf.control_dependencies(self._runtime_assertions):
    # Mixture weights padded to broadcast against components: [B, k, [1]*e].
    weights = distribution_utils.pad_mixture_dimensions(
        self.mixture_distribution.probs_parameter(), self,
        self.mixture_distribution, self._event_ndims)
    weighted_means = weights * self.components_distribution.mean()
    return tf.reduce_sum(weighted_means, axis=-1 - self._event_ndims)  # [B, E]
def backward_step(most_likely_successor, most_likely_given_successor):
  """Picks entries of `most_likely_given_successor` by `most_likely_successor`.

  Builds an int64 one-hot mask of depth `self._num_states` (from the enclosing
  scope) over the last axis and sums the masked values along it. This presumes
  `most_likely_given_successor` is int64 with last axis `num_states` — confirm
  with callers.
  """
  successor_mask = tf.one_hot(
      most_likely_successor, self._num_states, dtype=tf.int64)
  selected = most_likely_given_successor * successor_mask
  return tf.reduce_sum(input_tensor=selected, axis=-1)
def _rotate(self, samples):
  """Applies a Householder rotation to `samples`.

  With u = normalize(e_1 - mean_direction), each sample s is mapped to
  s - 2 <s, u> u. For a unit-norm `mean_direction` this reflection sends the
  first standard basis vector e_1 to `mean_direction`.
  """
  event_dim = (
      tf.compat.dimension_value(self.event_shape[0]) or
      self._event_shape_tensor()[0])
  # e_1 = [1, 0, ..., 0] as a rank-1 tensor of length `event_dim`. Building
  # both pieces with `self.dtype` avoids relying on literal dtype inference.
  basis = tf.concat(
      [tf.ones([1], dtype=self.dtype),
       tf.zeros([event_dim - 1], dtype=self.dtype)],
      axis=0)
  u = tf.math.l2_normalize(basis - self.mean_direction, axis=-1)
  return samples - 2 * tf.reduce_sum(samples * u, axis=-1, keepdims=True) * u
def _forward_log_det_jacobian(self, x):
  """Forward log |det J| at `x`: -(n + 1) * sum_i log|x_ii|.

  `n` is the size of the last axis of `x`. Only the diagonal of `x` is read;
  presumably `x` is a triangular matrix (confirm against the bijector's
  domain).
  """
  # For a discussion of this (non-obvious) result, see Note 7.2.2 (and the
  # sections leading up to it, for context) in
  # http://neutrino.aquaphoenix.com/ReactionDiffusion/SERC5chap7.pdf
  with tf.control_dependencies(self._assertions(x)):
    n = tf.cast(tf.shape(x)[-1], dtype=dtype_util.base_dtype(x.dtype))
    log_abs_diag = tf.math.log(tf.abs(tf.linalg.diag_part(x)))
    return -(n + 1) * tf.reduce_sum(log_abs_diag, axis=-1)
def _maybe_assert_valid_sample(self, x, dtype):
  """Returns `x`, guarded by validity assertions when `validate_args` is set.

  The assertions check that every entry of `x` lies in [0, 1] and that the
  last axis sums (approximately) to one.
  """
  if self.validate_args:
    one = tf.ones([], dtype=dtype)
    checks = [
        assert_util.assert_non_negative(x),
        assert_util.assert_less_equal(x, one),
        assert_util.assert_near(one, tf.reduce_sum(x, axis=[-1])),
    ]
    return distribution_util.with_dependencies(checks, x)
  return x
def body(m, pchol, perm, matrix_diag):
  """Body of a single `tf.while_loop` iteration.

  Computes row `m` of the pivoted Cholesky factor. It selects the largest
  remaining (permuted) diagonal entry as the pivot, swaps it into position `m`
  of `perm`, and writes the new row into `pchol` (rows stacked along axis -2).
  It also subtracts the row's squared entries from `matrix_diag`. Reads
  `matrix` from the enclosing scope; presumably that is the [..., n, n] matrix
  being decomposed (confirm against the caller).

  Args:
    m: index of the row being computed; rows `pchol[..., :m, :]` are already
      filled.
    pchol: partial factor; row `m` is replaced by this iteration.
    perm: current permutation, updated by swapping positions `m` and the
      chosen pivot.
    matrix_diag: remaining diagonal, updated in unpermuted order.

  Returns:
    Tuple `(m + 1, pchol, perm, matrix_diag)` with the updated state.
  """
  # Here is roughly a numpy, non-batched version of what's going to happen.
  # (See also Algorithm 1 of Harbrecht et al.)
  # 1: maxi = np.argmax(matrix_diag[perm[m:]]) + m
  # 2: maxval = matrix_diag[perm][maxi]
  # 3: perm[m], perm[maxi] = perm[maxi], perm[m]
  # 4: row = matrix[perm[m]][perm[m + 1:]]
  # 5: row -= np.sum(pchol[:m, perm[m + 1:]] * pchol[:m, perm[m:m + 1]],
  #                  axis=-2)
  # 6: pivot = np.sqrt(maxval); row /= pivot
  # 7: row = np.concatenate([[[pivot]], row], -1)
  # 8: matrix_diag[perm[m:]] -= row**2
  # 9: pchol[m, perm[m:]] = row

  # Find the maximal position of the (remaining) permuted diagonal.
  # Steps 1, 2 above.
  permuted_diag = batch_gather(matrix_diag, perm[..., m:])
  maxi = tf.argmax(
      permuted_diag, axis=-1, output_type=tf.int64)[..., tf.newaxis]
  maxval = batch_gather(permuted_diag, maxi)
  # `maxi` was relative to the slice perm[..., m:]; shift to absolute index.
  maxi = maxi + m
  maxval = maxval[..., 0]
  # Update perm: Swap perm[...,m] with perm[...,maxi]. Step 3 above.
  perm = _swap_m_with_i(perm, m, maxi)

  # Step 4: the pivot row of `matrix`, restricted to the not-yet-pivoted
  # columns (in permuted order).
  row = batch_gather(matrix, perm[..., m:m + 1], axis=-2)
  row = batch_gather(row, perm[..., m + 1:])

  # Step 5: subtract the contribution of the previously computed rows.
  prev_rows = pchol[..., :m, :]
  prev_rows_perm_m_onward = batch_gather(prev_rows, perm[..., m + 1:])
  prev_rows_pivot_col = batch_gather(prev_rows, perm[..., m:m + 1])
  row -= tf.reduce_sum(
      prev_rows_perm_m_onward * prev_rows_pivot_col,
      axis=-2)[..., tf.newaxis, :]

  # Step 6.
  pivot = tf.sqrt(maxval)[..., tf.newaxis, tf.newaxis]

  # Step 7.
  row = tf.concat([pivot, row / pivot], axis=-1)

  # TODO(b/130899118): Pad grad fails with int64 paddings.
  # Step 8: pad the row back to full width (m leading zeros on the last axis),
  # then scatter its squares back into unpermuted order.
  paddings = tf.concat([
      tf.zeros([prefer_static.rank(pchol) - 1, 2], dtype=tf.int32),
      [[tf.cast(m, tf.int32), 0]]
  ], axis=0)
  diag_update = tf.pad(row**2, paddings=paddings)[..., 0, :]
  reverse_perm = _invert_permutation(perm)
  matrix_diag -= batch_gather(diag_update, reverse_perm)

  # Step 9: write the (unpermuted) row into position m of pchol.
  row = tf.pad(row, paddings=paddings)
  # TODO(bjp): Defer the reverse permutation all-at-once at the end?
  row = batch_gather(row, reverse_perm)
  pchol_shape = pchol.shape
  pchol = tf.concat([pchol[..., :m, :], row, pchol[..., m + 1:, :]], axis=-2)
  # The concat with a dynamic `m` can lose static shape; restore it so the
  # while_loop sees a shape-invariant loop variable.
  tensorshape_util.set_shape(pchol, pchol_shape)
  return m + 1, pchol, perm, matrix_diag
def grad(dy):
  """Computes gradients with respect to the truncation bounds.

  The sampler blocks gradients wrt the bounds. This function supplies them
  with a closed-form expression, chosen for numerical stability over
  automatic differentiation through the CDF (implicit gradients).

  Args:
    dy: upstream gradients with respect to the samples.

  Returns:
    A list `[grad_l, grad_u]`: the gradients wrt the lower bound and the upper
    bound, in that order, each summed over the trailing sample dimension.
  """
  # std_samples has an extra dimension (the sample dimension), so expand
  # lower and upper to broadcast along it. See note above regarding
  # parameterized_truncated_normal: the sample dimension is the final one.
  lower_b = lower[..., tf.newaxis]
  upper_b = upper[..., tf.newaxis]
  ndtr_lower = special_math.ndtr(lower_b)
  cdf_samples = ((special_math.ndtr(std_samples) - ndtr_lower) /
                 (special_math.ndtr(upper_b) - ndtr_lower))

  # Keep the CDF strictly inside (0, 1) so neither log term below receives
  # a zero argument.
  finfo = np.finfo(dtype_util.as_numpy_dtype(self.dtype))
  cdf_samples = tf.clip_by_value(cdf_samples, finfo.tiny, 1 - finfo.eps)

  du = tf.exp(0.5 * (std_samples**2 - upper_b**2) +
              tf.math.log(cdf_samples))
  dl = tf.exp(0.5 * (std_samples**2 - lower_b**2) +
              tf.math.log1p(-cdf_samples))

  # Reduce the gradient across the samples.
  grad_l = tf.reduce_sum(dy * dl, axis=-1)
  grad_u = tf.reduce_sum(dy * du, axis=-1)
  return [grad_l, grad_u]
def _covariance(self):
  """Dirichlet covariance matrix.

  With mean = alpha / alpha_0:
    off-diagonal (i != j): -mean_i * mean_j / (1 + alpha_0)
    diagonal:               mean_i * (1 - mean_i) / (1 + alpha_0)
  """
  alpha = tf.convert_to_tensor(self.concentration)
  alpha_0 = tf.reduce_sum(alpha, axis=-1, keepdims=True)
  inv_sqrt = tf.math.rsqrt(1. + alpha_0)
  scaled_mean = inv_sqrt * (alpha / alpha_0)
  diagonal = scaled_mean * (inv_sqrt - scaled_mean)
  negative_outer = tf.matmul(-scaled_mean[..., tf.newaxis],
                             scaled_mean[..., tf.newaxis, :])
  return tf.linalg.set_diag(negative_outer, diagonal)
def _inverse_log_det_jacobian(self, y):
  """Log |det J| of the inverse mapping: the sum of y[..., 1:]."""
  # The Jacobian of the inverse mapping is lower triangular with diagonal
  #   J[1, 1] = 1 and J[i, i] = exp(y[i]) for 1 < i <= K,
  # so |det(J)| = prod_{i=2}^{K} exp(y[i]) and log|det(J)| = sum_{i=2}^{K} y[i].
  # Reference: Stan Modeling Language User's Guide and Reference Manual,
  # Version 2.17.0, section 35.2.
  trailing = y[..., 1:]
  return tf.reduce_sum(trailing, axis=-1)
def _maybe_assert_valid_sample(self, x):
  """Checks the validity of a sample.

  Returns:
    `[]` when `validate_args` is `False`. Otherwise, a list of assertion ops
    checking that `x` is positive and that its last dimension sums to one.
  """
  if self.validate_args:
    positive_check = assert_util.assert_positive(
        x, message='samples must be positive')
    sums_to_one_check = assert_util.assert_near(
        tf.ones([], dtype=self.dtype),
        tf.reduce_sum(x, axis=-1),
        message='sample last-dimension must sum to `1`')
    return [positive_check, sums_to_one_check]
  return []
def _maybe_assert_valid_sample(self, counts):
  """Returns assertion ops validating `counts`, or `[]` if not validating.

  The assertions check that `counts` has non-negative integer form and that
  its last dimension sums to `self.total_count`.
  """
  if self.validate_args:
    checks = distribution_util.assert_nonnegative_integer_form(counts)
    checks.append(
        assert_util.assert_equal(
            self.total_count,
            tf.reduce_sum(counts, axis=-1),
            message='counts must sum to `self.total_count`'))
    return checks
  return []
def _variance(self):
  """Variance of the finite discrete distribution: sum_i p_i (x_i - mean)^2."""
  probs = self._categorical.probs_parameter()
  outcomes = tf.broadcast_to(
      self.outcomes, shape=dist_util.prefer_static_shape(probs))
  if dtype_util.is_integer(outcomes.dtype):
    # Integer outcomes must be cast to the probability dtype; optionally check
    # that the cast is lossless first.
    if self._validate_args:
      outcomes = dist_util.embed_check_integer_casting_closed(
          outcomes, target_dtype=probs.dtype)
    outcomes = tf.cast(outcomes, dtype=probs.dtype)
  mean = self._mean(probs)[..., tf.newaxis]
  squared_deviation = tf.math.squared_difference(outcomes, mean)
  return tf.reduce_sum(probs * squared_deviation, axis=-1)
def _maybe_assert_valid_sample(self, counts):
  """Returns `counts`, validated when `validate_args` is `True`.

  Under validation, `counts` is checked for non-negative integer form and for
  its last dimension summing to `self.total_count`.
  """
  if self.validate_args:
    counts = distribution_util.embed_check_nonnegative_integer_form(counts)
    sum_check = assert_util.assert_equal(
        self.total_count,
        tf.reduce_sum(counts, axis=-1),
        message='counts last-dimension must sum to `self.total_count`')
    counts = distribution_util.with_dependencies([sum_check], counts)
  return counts
def _sample_one_batch_member(args):
  """Draws `num_samples` count vectors for a single batch member.

  Uses `num_samples`, `num_classes`, `seed`, and `dtype` from the enclosing
  scope.

  Args:
    args: pair whose entries are `logits` (shape `[K]`) and `num_cat_samples`
      (a scalar).

  Returns:
    `[num_samples, num_classes]` tensor of per-class counts, cast to `dtype`.
  """
  logits, num_cat_samples = args[0], args[1]  # [K], []
  # draws has shape [1, num_cat_samples = num_samples * num_trials].
  draws = tf.random.categorical(
      logits[tf.newaxis, ...], num_cat_samples, seed=seed)
  draws = tf.reshape(draws, shape=[num_samples, -1])  # [num_samples, num_trials]
  # [num_samples, num_trials, num_classes]
  one_hot_draws = tf.one_hot(draws, depth=num_classes)
  counts = tf.reduce_sum(one_hot_draws, axis=-2)  # [num_samples, num_classes]
  return tf.cast(counts, dtype=dtype)
def _event_shape_tensor(self):
  """Total number of event elements across all parts, as a length-1 `Tensor`."""
  sizes = tf.nest.map_structure(tensorshape_util.num_elements,
                                self._distribution.event_shape)
  if any(size is None for size in tf.nest.flatten(sizes)):
    # Only parts whose size is statically unknown fall back to the product of
    # their dynamic event shape.
    def _static_or_dynamic(static_size, shape_tensor):
      if static_size is None:
        return tf.reduce_prod(shape_tensor)
      return static_size

    sizes = tf.nest.map_structure(
        _static_or_dynamic, sizes, self._distribution.event_shape_tensor())
  return tf.reduce_sum(tf.nest.flatten(sizes))[tf.newaxis]
def _prob(self, event):
  """Fraction of the stored samples exactly equal to `event`.

  The result is cast to `self.dtype` when that dtype is floating.
  """
  samples = tf.convert_to_tensor(self._samples)
  num_samples = self._compute_num_samples(samples)
  event = tf.convert_to_tensor(event, name='event', dtype=self.dtype)
  event, samples = _broadcast_event_and_samples(
      event, samples, event_ndims=self._event_ndims)
  # A sample matches only if all of its event dimensions are equal.
  event_axes = tf.range(-self._event_ndims, 0)
  matches = tf.reduce_all(tf.equal(samples, event), axis=event_axes)
  num_matches = tf.reduce_sum(tf.cast(matches, dtype=tf.int32), axis=-1)
  prob = num_matches / num_samples
  if dtype_util.is_floating(self.dtype):
    return tf.cast(prob, self.dtype)
  return prob
def _covariance(self):
  """Mixture covariance via the law of total covariance.

  Cov(Y) = E[Cov(Y | X)] + Cov(E[Y | X]), where X indexes the component.

  Raises:
    NotImplementedError: if the event rank is statically known and is not 1.
  """
  static_event_ndims = tensorshape_util.rank(self.event_shape)
  if static_event_ndims is not None and static_event_ndims != 1:
    # Covariance is defined only for vector distributions.
    raise NotImplementedError("covariance is not implemented")

  with tf.control_dependencies(self._runtime_assertions):
    weights = self.mixture_distribution.probs_parameter()
    # Pad twice so the weights broadcast against [B, k, e, e] matrices.
    for _ in range(2):
      weights = distribution_utils.pad_mixture_dimensions(
          weights, self, self.mixture_distribution, self._event_ndims)
    # weights: [B, k, 1, 1]
    expected_conditional_cov = tf.reduce_sum(
        weights * self.components_distribution.covariance(),
        axis=-3)  # [B, e, e]
    cov_of_conditional_mean = tf.reduce_sum(
        weights * _outer_squared_difference(
            self.components_distribution.mean(),
            self._pad_sample_dims(self._mean())),
        axis=-3)  # [B, e, e]
    return expected_conditional_cov + cov_of_conditional_mean  # [B, e, e]
def maybe_assert_categorical_param_correctness(
    is_init, validate_args, probs, logits):
  """Return assertions for `Categorical`-type distributions.

  Args:
    is_init: Python `bool`; `True` when called at construction time, where
      static dtype and rank checks are also performed.
    validate_args: Python `bool`; whether runtime assertions are requested.
    probs: the `probs` parameter, or `None`.
    logits: the `logits` parameter, or `None`.

  Returns:
    List of assertion ops; empty when `validate_args` is `False`.

  Raises:
    TypeError: at init, if the parameter does not have a floating dtype.
    ValueError: at init, if the parameter is statically known to have rank 0.
  """
  assertions = []

  # In init, shape and dtype checks can always be built because shapes are
  # assumed not to change for Variable-backed args.
  if is_init:
    if logits is None:
      param, param_name = probs, 'probs'
    else:
      param, param_name = logits, 'logits'
    if not dtype_util.is_floating(param.dtype):
      raise TypeError(
          'Argument `{}` must having floating type.'.format(param_name))

    rank_msg = 'Argument `{}` must have rank at least 1.'.format(param_name)
    static_rank = tensorshape_util.rank(param.shape)
    if static_rank is not None:
      if static_rank < 1:
        raise ValueError(rank_msg)
    elif validate_args:
      param = tf.convert_to_tensor(param)
      # Keep the converted tensor for the checks below.
      if logits is None:
        probs, logits = param, None
      else:
        probs, logits = None, param
      assertions.append(
          assert_util.assert_rank_at_least(param, 1, message=rank_msg))

  if not validate_args:
    assert not assertions  # Should never happen.
    return []

  if logits is not None:
    if is_init != tensor_util.is_ref(logits):
      logits = tf.convert_to_tensor(logits)
    assertions.extend(
        distribution_util.assert_categorical_event_shape(logits))

  if probs is not None:
    if is_init != tensor_util.is_ref(probs):
      probs = tf.convert_to_tensor(probs)
    one = np.array(1, dtype=dtype_util.as_numpy_dtype(probs.dtype))
    assertions.extend([
        assert_util.assert_non_negative(probs),
        assert_util.assert_near(
            tf.reduce_sum(probs, axis=-1),
            one,
            message='Argument `probs` must sum to 1.'),
    ])
    assertions.extend(
        distribution_util.assert_categorical_event_shape(probs))

  return assertions