def _log_prob(self, x):
  logits = self._logits_parameter_no_checks()
  event_size = self._event_size(logits)

  x = tf.cast(x, logits.dtype)
  x = self._maybe_assert_valid_sample(x, dtype=logits.dtype)

  # Broadcast logits or x if need be.
  if (not tensorshape_util.is_fully_defined(x.shape) or
      not tensorshape_util.is_fully_defined(logits.shape) or
      x.shape != logits.shape):
    broadcast_shape = tf.broadcast_dynamic_shape(tf.shape(logits), tf.shape(x))
    logits = tf.broadcast_to(logits, broadcast_shape)
    x = tf.broadcast_to(x, broadcast_shape)

  logits_shape = tf.shape(tf.reduce_sum(logits, axis=-1))
  logits_2d = tf.reshape(logits, [-1, event_size])
  x_2d = tf.reshape(x, [-1, event_size])
  ret = -tf.nn.softmax_cross_entropy_with_logits(
      labels=tf.stop_gradient(x_2d), logits=logits_2d)

  # Reshape back to user-supplied batch and sample dims prior to 2D reshape.
  ret = tf.reshape(ret, logits_shape)
  return ret
def _cdf(self, x):
  with tf.control_dependencies(self._maybe_assert_valid_sample(x)):
    concentration1 = tf.convert_to_tensor(self.concentration1)
    concentration0 = tf.convert_to_tensor(self.concentration0)
    shape = self._batch_shape_tensor(concentration1, concentration0)
    concentration1 = tf.broadcast_to(concentration1, shape)
    concentration0 = tf.broadcast_to(concentration0, shape)
    return tf.math.betainc(concentration1, concentration0, x)
def _cdf(self, x):
  x = self._maybe_assert_valid_sample(x)
  logits = self._logits_parameter_no_checks()
  total_count = tf.convert_to_tensor(self.total_count)
  shape = self._batch_shape_tensor(logits_or_probs=logits,
                                   total_count=total_count)
  return tf.math.betainc(
      tf.broadcast_to(total_count, shape),
      tf.broadcast_to(1. + x, shape),
      tf.broadcast_to(tf.sigmoid(-logits), shape))
def _variance(self):
  with tf.control_dependencies(self._runtime_assertions):
    probs = self._marginal_hidden_probs()
    # probs :: num_steps batch_shape num_states
    means = self._observation_distribution.mean()
    # means :: observation_batch_shape[:-1] num_states
    #          observation_event_shape
    means_shape = tf.concat(
        [self.batch_shape_tensor(),
         [self._num_states],
         self._observation_distribution.event_shape_tensor()],
        axis=0)
    means = tf.broadcast_to(means, means_shape)
    # means :: batch_shape num_states observation_event_shape

    observation_event_shape = (
        self._observation_distribution.event_shape_tensor())
    batch_size = tf.reduce_prod(self.batch_shape_tensor())
    flat_probs_shape = [self._num_steps, batch_size, self._num_states]
    flat_means_shape = [
        batch_size, 1, self._num_states,
        tf.reduce_prod(observation_event_shape)
    ]

    flat_probs = tf.reshape(probs, flat_probs_shape)
    # flat_probs :: num_steps batch_size num_states
    flat_means = tf.reshape(means, flat_means_shape)
    # flat_means :: batch_size 1 num_states observation_event_size
    flat_mean = tf.einsum("ijk,jmkl->jiml", flat_probs, flat_means)
    # flat_mean :: batch_size num_steps 1 observation_event_size

    variances = self._observation_distribution.variance()
    variances = tf.broadcast_to(variances, means_shape)
    # variances :: batch_shape num_states observation_event_shape
    flat_variances = tf.reshape(variances, flat_means_shape)
    # flat_variances :: batch_size 1 num_states observation_event_size

    # For a mixture of n distributions with mixture probabilities
    # p[i], and where the individual distributions have means and
    # variances given by mean[i] and var[i], the variance of
    # the mixture is given by:
    #
    #   var = sum i=1..n p[i] * ((mean[i] - mean)**2 + var[i])
    flat_variance = tf.einsum("ijk,jikl->jil", flat_probs,
                              (flat_means - flat_mean)**2 + flat_variances)
    # flat_variance :: batch_size num_steps observation_event_size

    unflat_mean_shape = tf.concat(
        [self.batch_shape_tensor(),
         [self._num_steps],
         observation_event_shape],
        axis=0)

    # returns :: batch_shape num_steps observation_event_shape
    return tf.reshape(flat_variance, unflat_mean_shape)
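# A worked numeric check of the mixture-variance identity used in `_variance`
# above (illustrative numbers only, not taken from the library): for a
# two-component mixture with p = [0.5, 0.5], component means [0., 2.] and
# component variances [1., 1.], the mixture mean is 1. and
#
#   var = 0.5 * ((0. - 1.)**2 + 1.) + 0.5 * ((2. - 1.)**2 + 1.) = 2.
#
# which is exactly what the einsum over (flat_means - flat_mean)**2 +
# flat_variances computes, one observation dimension at a time.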
def _log_prob(self, value):
  with tf.control_dependencies(self._runtime_assertions):
    # The argument `value` is a tensor of sequences of observations.
    # `observation_batch_shape` is the shape of that tensor with the
    # sequence part removed.
    # `observation_batch_shape` is then broadcast to the full batch shape
    # to give the `batch_shape` that defines the shape of the result.
    observation_tensor_shape = tf.shape(value)
    observation_batch_shape = observation_tensor_shape[
        :-1 - self._underlying_event_rank]
    # value :: observation_batch_shape num_steps observation_event_shape
    batch_shape = tf.broadcast_dynamic_shape(observation_batch_shape,
                                             self.batch_shape_tensor())
    log_init = tf.broadcast_to(
        self._log_init,
        tf.concat([batch_shape, [self._num_states]], axis=0))
    # log_init :: batch_shape num_states
    log_transition = self._log_trans

    # `observation_event_shape` is the shape of each sequence of observations
    # emitted by the model.
    observation_event_shape = observation_tensor_shape[
        -1 - self._underlying_event_rank:]
    working_obs = tf.broadcast_to(
        value, tf.concat([batch_shape, observation_event_shape], axis=0))
    # working_obs :: batch_shape observation_event_shape
    r = self._underlying_event_rank

    # Move index into sequence of observations to front so we can apply
    # tf.foldl.
    working_obs = distribution_util.move_dimension(
        working_obs, -1 - r, 0)[..., tf.newaxis]
    # working_obs :: num_steps batch_shape underlying_event_shape
    observation_probs = (
        self._observation_distribution.log_prob(working_obs))

    def forward_step(log_prev_step, log_prob_observation):
      return _log_vector_matrix(log_prev_step,
                                log_transition) + log_prob_observation

    fwd_prob = tf.foldl(forward_step, observation_probs,
                        initializer=log_init)
    # fwd_prob :: batch_shape num_states

    log_prob = tf.reduce_logsumexp(fwd_prob, axis=-1)
    # log_prob :: batch_shape

    return log_prob
def _swap_m_with_i(vecs, m, i):
  """Swaps `m` and `i` on axis -1. (Helper for pivoted_cholesky.)

  Given a batch of int64 vectors `vecs`, scalar index `m`, and compatibly
  shaped per-vector indices `i`, this function swaps elements `m` and `i` in
  each vector. For the use-case below, these are permutation vectors.

  Args:
    vecs: Vectors on which we perform the swap, int64 `Tensor`.
    m: Scalar int64 `Tensor`, the index into which the `i`th element is going.
    i: Batch int64 `Tensor`, shaped like vecs.shape[:-1] + [1]; the index into
      which the `m`th element is going.

  Returns:
    vecs: The updated vectors.
  """
  vecs = tf.convert_to_tensor(vecs, dtype=tf.int64, name='vecs')
  m = tf.convert_to_tensor(m, dtype=tf.int64, name='m')
  i = tf.convert_to_tensor(i, dtype=tf.int64, name='i')
  trailing_elts = tf.broadcast_to(
      tf.range(m + 1, prefer_static.shape(vecs, out_type=tf.int64)[-1]),
      prefer_static.shape(vecs[..., m + 1:]))
  trailing_elts = tf.where(
      tf.equal(trailing_elts, i),
      tf.gather(vecs, [m], axis=-1),
      vecs[..., m + 1:])
  # TODO(bjp): Could we use tensor_scatter_nd_update?
  vecs_shape = vecs.shape
  vecs = tf.concat([
      vecs[..., :m],
      tf.gather(vecs, i, batch_dims=int(prefer_static.rank(vecs)) - 1),
      trailing_elts
  ], axis=-1)
  tensorshape_util.set_shape(vecs, vecs_shape)
  return vecs
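# Minimal usage sketch for `_swap_m_with_i` (illustrative values, not taken
# from the library's tests): with a batch of two permutation vectors, m = 1,
# and per-vector indices i = [[3], [2]], position 1 is swapped with positions
# 3 and 2 respectively.
#
#   vecs = tf.constant([[0, 1, 2, 3],
#                       [3, 2, 1, 0]], dtype=tf.int64)
#   i = tf.constant([[3], [2]], dtype=tf.int64)
#   _swap_m_with_i(vecs, m=1, i=i)
#   # ==> [[0, 3, 2, 1],
#   #      [3, 1, 2, 0]]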
def _marginal_hidden_probs(self):
  """Compute marginal pdf for each individual observable."""

  initial_log_probs = tf.broadcast_to(
      self._log_init,
      tf.concat([self.batch_shape_tensor(), [self._num_states]], axis=0))
  # initial_log_probs :: batch_shape num_states

  def _scan_multiple_steps():
    """Perform `scan` operation when `num_steps` > 1."""
    transition_log_probs = self._log_trans

    def forward_step(log_probs, _):
      return _log_vector_matrix(log_probs, transition_log_probs)

    dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32)

    forward_log_probs = tf.scan(forward_step, dummy_index,
                                initializer=initial_log_probs,
                                name="forward_log_probs")

    return tf.concat([[initial_log_probs], forward_log_probs], axis=0)

  forward_log_probs = prefer_static.cond(
      self._num_steps > 1,
      _scan_multiple_steps,
      lambda: initial_log_probs[tf.newaxis, ...])

  return tf.exp(forward_log_probs)
def _mean(self, **kwargs):
  if not self.bijector.is_constant_jacobian:
    raise NotImplementedError("mean is not implemented for non-affine "
                              "bijectors")
  distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)
  x = self.distribution.mean(**distribution_kwargs)

  if self._is_maybe_batch_override or self._is_maybe_event_override:
    # A batch (respectively event) shape override is only allowed if the batch
    # (event) shape of the base distribution is [], so concatenating all the
    # shapes does the right thing.
    new_shape = prefer_static.concat([
        prefer_static.ones_like(self._override_batch_shape),
        self.distribution.batch_shape_tensor(),
        prefer_static.ones_like(self._override_event_shape),
        self.distribution.event_shape_tensor(),
    ], 0)
    x = tf.reshape(x, new_shape)
    new_shape = prefer_static.concat(
        [self.batch_shape_tensor(), self.event_shape_tensor()], 0)
    x = tf.broadcast_to(x, new_shape)

  y = self.bijector.forward(x, **bijector_kwargs)

  sample_shape = tf.convert_to_tensor([], dtype=tf.int32, name="sample_shape")
  y = self._set_sample_static_shape(y, sample_shape)
  return y
def _sample_n(self, n, seed=None):
  seed = SeedStream(seed, "beta")
  concentration1 = tf.convert_to_tensor(self.concentration1)
  concentration0 = tf.convert_to_tensor(self.concentration0)
  shape = self._batch_shape_tensor(concentration1, concentration0)
  expanded_concentration1 = tf.broadcast_to(concentration1, shape)
  expanded_concentration0 = tf.broadcast_to(concentration0, shape)
  gamma1_sample = tf.random.gamma(shape=[n],
                                  alpha=expanded_concentration1,
                                  dtype=self.dtype,
                                  seed=seed())
  gamma2_sample = tf.random.gamma(shape=[n],
                                  alpha=expanded_concentration0,
                                  dtype=self.dtype,
                                  seed=seed())
  beta_sample = gamma1_sample / (gamma1_sample + gamma2_sample)
  return beta_sample
def _cdf(self, x):
  df = tf.convert_to_tensor(self.df)
  # Take Abs(scale) to make subsequent where work correctly.
  y = (x - self.loc) / tf.abs(self.scale)
  x_t = df / (y**2. + df)
  neg_cdf = 0.5 * tf.math.betainc(
      0.5 * tf.broadcast_to(df, prefer_static.shape(x_t)), 0.5, x_t)
  return tf.where(y < 0., neg_cdf, 1. - neg_cdf)
def cholesky_concat(chol, cols, name=None):
  """Concatenates `chol @ chol.T` with additional rows and columns.

  This operation is conceptually identical to:
  ```python
  def cholesky_concat_slow(chol, cols):  # cols shaped (n + m) x m = z x m
    mat = tf.matmul(chol, chol, adjoint_b=True)  # batch of n x n
    # Concat columns.
    mat = tf.concat([mat, cols[..., :tf.shape(mat)[-2], :]], axis=-1)  # n x z
    # Concat rows.
    mat = tf.concat([mat, tf.linalg.matrix_transpose(cols)], axis=-2)  # z x z
    return tf.linalg.cholesky(mat)
  ```
  but whereas `cholesky_concat_slow` would cost `O(z**3)` work,
  `cholesky_concat` only costs `O(z**2 + m**3)` work.

  The resulting (implicit) matrix must be symmetric and positive definite.
  Thus, the bottom right `m x m` must be self-adjoint, and we do not require a
  separate `rows` argument (which can be inferred from `conj(cols.T)`).

  Args:
    chol: Cholesky decomposition of `mat = chol @ chol.T`.
    cols: The new columns whose first `n` rows we would like concatenated to
      the right of `mat = chol @ chol.T`, and whose conjugate transpose we
      would like concatenated to the bottom of `concat(mat, cols[:n,:])`. A
      `Tensor` with final dims `(n+m, m)`. The first `n` rows are the top
      right rectangle (their conjugate transpose forms the bottom left), and
      the bottom `m x m` is self-adjoint.
    name: Optional name for this op.

  Returns:
    chol_concat: The Cholesky decomposition of:

      ```
      [ [ mat  cols[:n, :] ]
        [   conj(cols.T)   ] ]
      ```
  """
  with tf.name_scope(name or 'cholesky_extend'):
    dtype = dtype_util.common_dtype([chol, cols], dtype_hint=tf.float32)
    chol = tf.convert_to_tensor(chol, name='chol', dtype=dtype)
    cols = tf.convert_to_tensor(cols, name='cols', dtype=dtype)
    n = prefer_static.shape(chol)[-1]
    mat_nm, mat_mm = cols[..., :n, :], cols[..., n:, :]
    solved_nm = linear_operator_util.matrix_triangular_solve_with_broadcast(
        chol, mat_nm)
    lower_right_mm = tf.linalg.cholesky(
        mat_mm - tf.matmul(solved_nm, solved_nm, adjoint_a=True))
    lower_left_mn = tf.math.conj(tf.linalg.matrix_transpose(solved_nm))
    out_batch = prefer_static.shape(solved_nm)[:-2]
    chol = tf.broadcast_to(
        chol,
        tf.concat([out_batch, prefer_static.shape(chol)[-2:]], axis=0))
    top_right_zeros_nm = tf.zeros_like(solved_nm)
    return tf.concat([tf.concat([chol, top_right_zeros_nm], axis=-1),
                      tf.concat([lower_left_mn, lower_right_mm], axis=-1)],
                     axis=-2)
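# Usage sketch for `cholesky_concat` (the matrices below are made-up
# illustrative values, assuming this module's functions are in scope): extend
# the Cholesky factor of a 2 x 2 PSD matrix by one extra row/column given in
# `cols`.
#
#   mat = tf.constant([[4., 2.], [2., 3.]])
#   chol = tf.linalg.cholesky(mat)              # 2 x 2 lower triangular
#   cols = tf.constant([[1.], [0.5], [2.]])     # (n + m) x m = 3 x 1
#   chol3 = cholesky_concat(chol, cols)         # 3 x 3 lower triangular
#   # tf.matmul(chol3, chol3, adjoint_b=True) recovers
#   # [[4., 2., 1.], [2., 3., 0.5], [1., 0.5, 2.]].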
def _sample_n(self, n, seed=None):
  # Like with the univariate Student's t, sampling can be implemented as a
  # ratio of samples from a multivariate gaussian with the appropriate
  # covariance matrix and a sample from the chi-squared distribution.
  seed = SeedStream(seed, salt='multivariate t')

  loc = tf.broadcast_to(self.loc, self._sample_shape())
  mvn = mvn_linear_operator.MultivariateNormalLinearOperator(
      loc=tf.zeros_like(loc), scale=self.scale)
  normal_samp = mvn.sample(n, seed=seed())

  df = tf.broadcast_to(self.df, self.batch_shape_tensor())
  chi2 = chi2_lib.Chi2(df=df)
  chi2_samp = chi2.sample(n, seed=seed())

  return (self._loc +
          normal_samp * tf.math.rsqrt(chi2_samp / self._df)[..., tf.newaxis])
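# In formula form (a restatement of the construction above, not additional
# library behaviour): with Z ~ Normal(0, scale @ scale^T) and V ~ Chi2(df),
# independent,
#
#   X = loc + Z * sqrt(df / V)
#
# has a multivariate Student's t distribution with the given loc, scale and
# df, which is what the return expression computes via
# tf.math.rsqrt(chi2_samp / df).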
def _sample_n(self, n, seed=None):
  del seed  # unused
  loc = tf.convert_to_tensor(self.loc)
  return tf.broadcast_to(
      loc,
      tf.concat([[n],
                 self._batch_shape_tensor(loc=loc),
                 self._event_shape_tensor(loc=loc)],
                axis=0))
def _variance(self):
  probs = self._categorical.probs_parameter()
  outcomes = tf.broadcast_to(
      self.outcomes, shape=dist_util.prefer_static_shape(probs))
  if dtype_util.is_integer(outcomes.dtype):
    if self._validate_args:
      outcomes = dist_util.embed_check_integer_casting_closed(
          outcomes, target_dtype=probs.dtype)
    outcomes = tf.cast(outcomes, dtype=probs.dtype)
  square_d = tf.math.squared_difference(
      outcomes, self._mean(probs)[..., tf.newaxis])
  return tf.reduce_sum(probs * square_d, axis=-1)
def _covariance(self):
  if distribution_util.is_diagonal_scale(self.scale):
    mvn_cov = tf.linalg.diag(tf.square(self.scale.diag_part()))
  else:
    mvn_cov = self.scale.matmul(self.scale.to_dense(), adjoint_arg=True)

  cov_shape = tf.concat(
      [self._sample_shape(), self._event_shape_tensor()], -1)
  mvn_cov = tf.broadcast_to(mvn_cov, cov_shape)
  return self._std_var_helper(mvn_cov, 'covariance', 2, lambda x: x)
def _variance(self):
  if distribution_util.is_diagonal_scale(self.scale):
    mvn_var = tf.square(self.scale.diag_part())
  elif (isinstance(self.scale, tf.linalg.LinearOperatorLowRankUpdate) and
        self.scale.is_self_adjoint):
    mvn_var = tf.linalg.diag_part(self.scale.matmul(self.scale.to_dense()))
  else:
    mvn_var = tf.linalg.diag_part(
        self.scale.matmul(self.scale.to_dense(), adjoint_arg=True))

  mvn_var = tf.broadcast_to(mvn_var, self._sample_shape())
  return self._std_var_helper(mvn_var, 'variance', 1, lambda x: x)
def _mean(self):
  mean = tf.broadcast_to(self.loc, self._sample_shape())

  if self.allow_nan_stats:
    return tf.where(
        self.df[..., tf.newaxis] > 1.,
        mean,
        dtype_util.as_numpy_dtype(self.dtype)(np.nan))
  else:
    with tf.control_dependencies([
        assert_util.assert_less(
            tf.cast(1., self.dtype),
            self.df,
            message='Mean not defined for components of df <= 1.'),
    ]):
      return tf.identity(mean)
def _stddev(self):
  if distribution_util.is_diagonal_scale(self.scale):
    mvn_std = tf.abs(self.scale.diag_part())
  elif (isinstance(self.scale, tf.linalg.LinearOperatorLowRankUpdate) and
        self.scale.is_self_adjoint):
    mvn_std = tf.sqrt(
        tf.linalg.diag_part(self.scale.matmul(self.scale.to_dense())))
  else:
    mvn_std = tf.sqrt(
        tf.linalg.diag_part(
            self.scale.matmul(self.scale.to_dense(), adjoint_arg=True)))

  mvn_std = tf.broadcast_to(mvn_std, self._sample_shape())
  return self._std_var_helper(mvn_std, 'standard deviation', 1, tf.sqrt)
def _entropy(self):
  df = tf.broadcast_to(self.df, self.batch_shape_tensor())
  num_dims = tf.cast(self.event_shape_tensor()[0], self.dtype)

  def _lbeta(concentration0, concentration1):
    return (tf.math.lgamma(concentration1) + tf.math.lgamma(concentration0) -
            tf.math.lgamma(concentration0 + concentration1))

  shape_factor = self._scale.log_abs_determinant()
  beta_factor = (num_dims / 2. * (tf.math.log(df) + np.log(np.pi)) -
                 tf.math.lgamma(num_dims / 2.) +
                 _lbeta(num_dims / 2., df / 2.))
  digamma_factor = (num_dims + df) / 2. * (
      tf.math.digamma((num_dims + df) / 2.) - tf.math.digamma(df / 2.))
  return shape_factor + beta_factor + digamma_factor
def _fn(self, **kwargs):
  """Implements summary statistic, e.g., mean, stddev, mode."""
  x = getattr(self.distribution, attr)(**kwargs)
  shape = prefer_static.concat([
      self.distribution.batch_shape_tensor(),
      prefer_static.ones(prefer_static.rank_from_shape(self.sample_shape),
                         dtype=self.sample_shape.dtype),
      self.distribution.event_shape_tensor(),
  ], axis=0)
  x = tf.reshape(x, shape=shape)
  shape = prefer_static.concat([
      self.distribution.batch_shape_tensor(),
      self.sample_shape,
      self.distribution.event_shape_tensor(),
  ], axis=0)
  return tf.broadcast_to(x, shape)
def _mean(self):
  with tf.control_dependencies(self._runtime_assertions):
    probs = self._marginal_hidden_probs()
    # probs :: num_steps batch_shape num_states
    means = self._observation_distribution.mean()
    # means :: observation_batch_shape[:-1] num_states
    #          observation_event_shape
    means_shape = tf.concat(
        [self.batch_shape_tensor(),
         [self._num_states],
         self._observation_distribution.event_shape_tensor()],
        axis=0)
    means = tf.broadcast_to(means, means_shape)
    # means :: batch_shape num_states observation_event_shape

    observation_event_shape = (
        self._observation_distribution.event_shape_tensor())
    batch_size = tf.reduce_prod(self.batch_shape_tensor())
    flat_probs_shape = [self._num_steps, batch_size, self._num_states]
    flat_means_shape = [
        batch_size, self._num_states,
        tf.reduce_prod(observation_event_shape)
    ]

    flat_probs = tf.reshape(probs, flat_probs_shape)
    # flat_probs :: num_steps batch_size num_states
    flat_means = tf.reshape(means, flat_means_shape)
    # flat_means :: batch_size num_states observation_event_size
    flat_mean = tf.einsum("ijk,jkl->jil", flat_probs, flat_means)
    # flat_mean :: batch_size num_steps observation_event_size

    unflat_mean_shape = tf.concat(
        [self.batch_shape_tensor(),
         [self._num_steps],
         observation_event_shape],
        axis=0)
    # returns :: batch_shape num_steps observation_event_shape
    return tf.reshape(flat_mean, unflat_mean_shape)
def lu_reconstruct(lower_upper, perm, validate_args=False, name=None):
  """The inverse LU decomposition, `X == lu_reconstruct(*tf.linalg.lu(X))`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_reconstruct').

  Returns:
    x: The original input to `tf.linalg.lu`, i.e., `x` as in,
      `lu_reconstruct(*tf.linalg.lu(x))`.

  #### Examples

  ```python
  import numpy as np
  from tensorflow_probability.python.internal.backend import jax as tf
  import tensorflow_probability as tfp; tfp = tfp.experimental.substrates.jax

  x = [[[3., 4], [1, 2]],
       [[7., 8], [3, 4]]]
  x_reconstructed = tfp.math.lu_reconstruct(*tf.linalg.lu(x))
  tf.assert_near(x, x_reconstructed)
  # ==> True
  ```
  """
  with tf.name_scope(name or 'lu_reconstruct'):
    lower_upper = tf.convert_to_tensor(lower_upper, dtype_hint=tf.float32,
                                       name='lower_upper')
    perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')

    assertions = lu_reconstruct_assertions(lower_upper, perm, validate_args)
    if assertions:
      with tf.control_dependencies(assertions):
        lower_upper = tf.identity(lower_upper)
        perm = tf.identity(perm)

    shape = tf.shape(lower_upper)

    lower = tf.linalg.set_diag(
        tf.linalg.band_part(lower_upper, num_lower=-1, num_upper=0),
        tf.ones(shape[:-1], dtype=lower_upper.dtype))
    upper = tf.linalg.band_part(lower_upper, num_lower=0, num_upper=-1)
    x = tf.matmul(lower, upper)

    if (tensorshape_util.rank(lower_upper.shape) is None or
        tensorshape_util.rank(lower_upper.shape) != 2):
      # We either don't know the batch rank or there are >0 batch dims.
      batch_size = tf.reduce_prod(shape[:-2])
      d = shape[-1]
      x = tf.reshape(x, [batch_size, d, d])
      perm = tf.reshape(perm, [batch_size, d])
      perm = tf.map_fn(tf.math.invert_permutation, perm)
      batch_indices = tf.broadcast_to(
          tf.range(batch_size)[:, tf.newaxis], [batch_size, d])
      x = tf.gather_nd(x, tf.stack([batch_indices, perm], axis=-1))
      x = tf.reshape(x, shape)
    else:
      x = tf.gather(x, tf.math.invert_permutation(perm))

    x.set_shape(lower_upper.shape)
    return x
def _mode(self):
  loc = tf.convert_to_tensor(self.loc)
  return tf.broadcast_to(loc, self._batch_shape_tensor(loc=loc))
def _stddev(self):
  scale = tf.convert_to_tensor(self.scale)
  return tf.broadcast_to(scale * np.pi / np.sqrt(3),
                         self._batch_shape_tensor(scale=scale))
def _entropy(self):
  scale = tf.convert_to_tensor(self.scale)
  return tf.broadcast_to(np.log(2.) + 1 + tf.math.log(scale),
                         self._batch_shape_tensor(scale=scale))
def _mode(self):
  return tf.broadcast_to(self.loc, self._sample_shape())
def _stddev(self):
  scale = tf.convert_to_tensor(self.scale)
  return tf.broadcast_to(np.sqrt(2.) * scale,
                         self._batch_shape_tensor(scale=scale))
def _mode(self):
  scale = tf.convert_to_tensor(self.scale)
  return tf.broadcast_to(scale, self._batch_shape_tensor(scale=scale))
def pivoted_cholesky(matrix, max_rank, diag_rtol=1e-3, name=None):
  """Computes the (partial) pivoted cholesky decomposition of `matrix`.

  The pivoted Cholesky is a low rank approximation of the Cholesky
  decomposition of `matrix`, i.e. as described in
  [(Harbrecht et al., 2012)][1]. The currently-worst-approximated diagonal
  element is selected as the pivot at each iteration. This yields from a
  `[B1...Bn, N, N]` shaped `matrix` a `[B1...Bn, N, K]` shaped rank-`K`
  approximation `lr` such that `lr @ lr.T ~= matrix`. Note that, unlike the
  Cholesky decomposition, `lr` is not triangular even in a rectangular-matrix
  sense. However, under a permutation it could be made triangular (it has one
  more zero in each column as you move to the right).

  Such a matrix can be useful as a preconditioner for conjugate gradient
  optimization, i.e. as in [(Wang et al. 2019)][2], as matmuls and solves can
  be cheaply done via the Woodbury matrix identity, as implemented by
  `tf.linalg.LinearOperatorLowRankUpdate`.

  Args:
    matrix: Floating point `Tensor` batch of symmetric, positive definite
      matrices.
    max_rank: Scalar `int` `Tensor`, the rank at which to truncate the
      approximation.
    diag_rtol: Scalar floating point `Tensor` (same dtype as `matrix`). If the
      errors of all diagonal elements of `lr @ lr.T` are each lower than
      `element * diag_rtol`, iteration is permitted to terminate early.
    name: Optional name for the op.

  Returns:
    lr: Low rank pivoted Cholesky approximation of `matrix`.

  #### References

  [1]: H Harbrecht, M Peters, R Schneider. On the low-rank approximation by
       the pivoted Cholesky decomposition. _Applied numerical mathematics_,
       62(4):428-440, 2012.

  [2]: K. A. Wang et al. Exact Gaussian Processes on a Million Data Points.
       _arXiv preprint arXiv:1903.08114_, 2019.
       https://arxiv.org/abs/1903.08114
  """
  with tf.name_scope(name or 'pivoted_cholesky'):
    dtype = dtype_util.common_dtype([matrix, diag_rtol],
                                    dtype_hint=tf.float32)
    matrix = tf.convert_to_tensor(matrix, name='matrix', dtype=dtype)
    if tensorshape_util.rank(matrix.shape) is None:
      raise NotImplementedError('Rank of `matrix` must be known statically')

    max_rank = tf.convert_to_tensor(max_rank, name='max_rank', dtype=tf.int64)
    max_rank = tf.minimum(
        max_rank, prefer_static.shape(matrix, out_type=tf.int64)[-1])
    diag_rtol = tf.convert_to_tensor(diag_rtol, dtype=dtype, name='diag_rtol')
    matrix_diag = tf.linalg.diag_part(matrix)
    # matrix is P.D., therefore all matrix_diag > 0, so we don't need abs.
    orig_error = tf.reduce_max(matrix_diag, axis=-1)

    def cond(m, pchol, perm, matrix_diag):
      """Condition for `tf.while_loop` continuation."""
      del pchol
      del perm
      error = tf.linalg.norm(matrix_diag, ord=1, axis=-1)
      max_err = tf.reduce_max(error / orig_error)
      return (m < max_rank) & (tf.equal(m, 0) | (max_err > diag_rtol))

    batch_dims = tensorshape_util.rank(matrix.shape) - 2
    def batch_gather(params, indices, axis=-1):
      return tf.gather(params, indices, axis=axis, batch_dims=batch_dims)

    def body(m, pchol, perm, matrix_diag):
      """Body of a single `tf.while_loop` iteration."""
      # Here is roughly a numpy, non-batched version of what's going to
      # happen. (See also Algorithm 1 of Harbrecht et al.)
      # 1: maxi = np.argmax(matrix_diag[perm[m:]]) + m
      # 2: maxval = matrix_diag[perm][maxi]
      # 3: perm[m], perm[maxi] = perm[maxi], perm[m]
      # 4: row = matrix[perm[m]][perm[m + 1:]]
      # 5: row -= np.sum(pchol[:m][perm[m + 1:]] * pchol[:m][perm[m]], axis=-2)
      # 6: pivot = np.sqrt(maxval); row /= pivot
      # 7: row = np.concatenate([[[pivot]], row], -1)
      # 8: matrix_diag[perm[m:]] -= row**2
      # 9: pchol[m, perm[m:]] = row

      # Find the maximal position of the (remaining) permuted diagonal.
      # Steps 1, 2 above.
      permuted_diag = batch_gather(matrix_diag, perm[..., m:])
      maxi = tf.argmax(permuted_diag, axis=-1,
                       output_type=tf.int64)[..., tf.newaxis]
      maxval = batch_gather(permuted_diag, maxi)
      maxi = maxi + m
      maxval = maxval[..., 0]
      # Update perm: Swap perm[...,m] with perm[...,maxi]. Step 3 above.
      perm = _swap_m_with_i(perm, m, maxi)
      # Step 4.
      row = batch_gather(matrix, perm[..., m:m + 1], axis=-2)
      row = batch_gather(row, perm[..., m + 1:])
      # Step 5.
      prev_rows = pchol[..., :m, :]
      prev_rows_perm_m_onward = batch_gather(prev_rows, perm[..., m + 1:])
      prev_rows_pivot_col = batch_gather(prev_rows, perm[..., m:m + 1])
      row -= tf.reduce_sum(prev_rows_perm_m_onward * prev_rows_pivot_col,
                           axis=-2)[..., tf.newaxis, :]
      # Step 6.
      pivot = tf.sqrt(maxval)[..., tf.newaxis, tf.newaxis]
      # Step 7.
      row = tf.concat([pivot, row / pivot], axis=-1)
      # TODO(b/130899118): Pad grad fails with int64 paddings.
      # Step 8.
      paddings = tf.concat([
          tf.zeros([prefer_static.rank(pchol) - 1, 2], dtype=tf.int32),
          [[tf.cast(m, tf.int32), 0]]
      ], axis=0)
      diag_update = tf.pad(row**2, paddings=paddings)[..., 0, :]
      reverse_perm = _invert_permutation(perm)
      matrix_diag -= batch_gather(diag_update, reverse_perm)
      # Step 9.
      row = tf.pad(row, paddings=paddings)
      # TODO(bjp): Defer the reverse permutation all-at-once at the end?
      row = batch_gather(row, reverse_perm)
      pchol_shape = pchol.shape
      pchol = tf.concat([pchol[..., :m, :], row, pchol[..., m + 1:, :]],
                        axis=-2)
      tensorshape_util.set_shape(pchol, pchol_shape)
      return m + 1, pchol, perm, matrix_diag

    m = np.int64(0)
    pchol = tf.zeros_like(matrix[..., :max_rank, :])
    matrix_shape = prefer_static.shape(matrix, out_type=tf.int64)
    perm = tf.broadcast_to(
        prefer_static.range(matrix_shape[-1]), matrix_shape[:-1])
    _, pchol, _, _ = tf.while_loop(
        cond=cond, body=body, loop_vars=(m, pchol, perm, matrix_diag))
    pchol = tf.linalg.matrix_transpose(pchol)
    tensorshape_util.set_shape(
        pchol, tensorshape_util.concatenate(matrix_diag.shape, [None]))
    return pchol
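# Usage sketch for `pivoted_cholesky` (illustrative values only): a rank-2
# approximation of a small positive definite matrix.
#
#   matrix = tf.constant([[4., 2., 1.],
#                         [2., 3., 0.5],
#                         [1., 0.5, 2.]])
#   lr = pivoted_cholesky(matrix, max_rank=2, diag_rtol=1e-3)  # shape [3, 2]
#   approx = tf.matmul(lr, lr, adjoint_b=True)  # rank-2 approximation of matrix
#   # `lr` can be fed to tf.linalg.LinearOperatorLowRankUpdate to build a cheap
#   # preconditioner, as noted in the docstring above.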
def lu_solve(lower_upper, perm, rhs, validate_args=False, name=None):
  """Solves systems of linear eqns `A X = RHS`, given LU factorizations.

  Note: this function does not verify the implied matrix is actually invertible
  nor is this condition checked even when `validate_args=True`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
    rhs: Matrix-shaped float `Tensor` representing targets for which to solve;
      `A X = RHS`. To handle vector cases, use:
      `lu_solve(..., rhs[..., tf.newaxis])[..., 0]`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness. Note: this function does not verify the implied matrix
      is actually invertible, even when `validate_args=True`.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_solve').

  Returns:
    x: The `X` in `A @ X = RHS`.

  #### Examples

  ```python
  import numpy as np
  from tensorflow_probability.python.internal.backend import jax as tf
  import tensorflow_probability as tfp; tfp = tfp.experimental.substrates.jax

  x = [[[1., 2],
        [3, 4]],
       [[7, 8],
        [3, 4]]]
  inv_x = tfp.math.lu_solve(*tf.linalg.lu(x), rhs=tf.eye(2))
  tf.assert_near(tf.matrix_inverse(x), inv_x)
  # ==> True
  ```
  """
  with tf.name_scope(name or 'lu_solve'):
    lower_upper = tf.convert_to_tensor(lower_upper, dtype_hint=tf.float32,
                                       name='lower_upper')
    perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')
    rhs = tf.convert_to_tensor(rhs, dtype_hint=lower_upper.dtype, name='rhs')

    assertions = _lu_solve_assertions(lower_upper, perm, rhs, validate_args)
    if assertions:
      with tf.control_dependencies(assertions):
        lower_upper = tf.identity(lower_upper)
        perm = tf.identity(perm)
        rhs = tf.identity(rhs)

    if (tensorshape_util.rank(rhs.shape) == 2 and
        tensorshape_util.rank(perm.shape) == 1):
      # Both rhs and perm have scalar batch_shape.
      permuted_rhs = tf.gather(rhs, perm, axis=-2)
    else:
      # Either rhs or perm have non-scalar batch_shape or we can't determine
      # this information statically.
      rhs_shape = tf.shape(rhs)
      broadcast_batch_shape = tf.broadcast_dynamic_shape(
          rhs_shape[:-2], tf.shape(perm)[:-1])
      d, m = rhs_shape[-2], rhs_shape[-1]
      rhs_broadcast_shape = tf.concat([broadcast_batch_shape, [d, m]], axis=0)

      # Tile out rhs.
      broadcast_rhs = tf.broadcast_to(rhs, rhs_broadcast_shape)
      broadcast_rhs = tf.reshape(broadcast_rhs, [-1, d, m])

      # Tile out perm and add batch indices.
      broadcast_perm = tf.broadcast_to(perm, rhs_broadcast_shape[:-1])
      broadcast_perm = tf.reshape(broadcast_perm, [-1, d])
      broadcast_batch_size = tf.reduce_prod(broadcast_batch_shape)
      broadcast_batch_indices = tf.broadcast_to(
          tf.range(broadcast_batch_size)[:, tf.newaxis],
          [broadcast_batch_size, d])
      broadcast_perm = tf.stack([broadcast_batch_indices, broadcast_perm],
                                axis=-1)

      permuted_rhs = tf.gather_nd(broadcast_rhs, broadcast_perm)
      permuted_rhs = tf.reshape(permuted_rhs, rhs_broadcast_shape)

    lower = tf.linalg.set_diag(
        tf.linalg.band_part(lower_upper, num_lower=-1, num_upper=0),
        tf.ones(tf.shape(lower_upper)[:-1], dtype=lower_upper.dtype))
    return linear_operator_util.matrix_triangular_solve_with_broadcast(
        lower_upper,  # Only upper is accessed.
        linear_operator_util.matrix_triangular_solve_with_broadcast(
            lower, permuted_rhs),
        lower=False)