def _batch_shape_tensor(self, loc=None, scale=None, concentration=None):
  return functools.reduce(
      prefer_static.broadcast_shape,
      (prefer_static.shape(self.loc if loc is None else loc),
       prefer_static.shape(self.scale if scale is None else scale),
       prefer_static.shape(
           self.concentration if concentration is None else concentration)))

def _batch_shape_tensor(self, logits_or_probs=None, total_count=None):
  if logits_or_probs is None:
    # Fall back to whichever parameterization was provided at construction.
    logits_or_probs = self._probs if self._logits is None else self._logits
  total_count = self._total_count if total_count is None else total_count
  return prefer_static.broadcast_shape(
      prefer_static.shape(logits_or_probs),
      prefer_static.shape(total_count))

def _swap_m_with_i(vecs, m, i):
  """Swaps `m` and `i` on axis -1. (Helper for pivoted_cholesky.)

  Given a batch of int64 vectors `vecs`, scalar index `m`, and compatibly
  shaped per-vector indices `i`, this function swaps elements `m` and `i` in
  each vector. For the use-case below, these are permutation vectors.

  Args:
    vecs: Vectors on which we perform the swap, int64 `Tensor`.
    m: Scalar int64 `Tensor`, the index into which the `i`th element is going.
    i: Batch int64 `Tensor`, shaped like vecs.shape[:-1] + [1]; the index into
      which the `m`th element is going.

  Returns:
    vecs: The updated vectors.
  """
  vecs = tf.convert_to_tensor(vecs, dtype=tf.int64, name='vecs')
  m = tf.convert_to_tensor(m, dtype=tf.int64, name='m')
  i = tf.convert_to_tensor(i, dtype=tf.int64, name='i')
  trailing_elts = tf.broadcast_to(
      tf.range(m + 1, prefer_static.shape(vecs, out_type=tf.int64)[-1]),
      prefer_static.shape(vecs[..., m + 1:]))
  trailing_elts = tf.where(
      tf.equal(trailing_elts, i),
      tf.gather(vecs, [m], axis=-1),
      vecs[..., m + 1:])
  # TODO(bjp): Could we use tensor_scatter_nd_update?
  vecs_shape = vecs.shape
  vecs = tf.concat([
      vecs[..., :m],
      tf.gather(vecs, i, batch_dims=int(prefer_static.rank(vecs)) - 1),
      trailing_elts
  ], axis=-1)
  tensorshape_util.set_shape(vecs, vecs_shape)
  return vecs

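# For intuition, a minimal non-batched NumPy sketch of the swap performed by
# `_swap_m_with_i` (illustrative only; the TF helper above additionally
# broadcasts over batch dims and avoids in-place updates). The helper name
# here is hypothetical.
import numpy as np

def _swap_m_with_i_numpy_sketch(vec, m, i):
  # Swap elements `m` and `i` of a single permutation vector.
  vec = vec.copy()
  vec[m], vec[i] = vec[i], vec[m]
  return vec

np.testing.assert_array_equal(
    _swap_m_with_i_numpy_sketch(np.arange(5), 1, 3),
    np.array([0, 3, 2, 1, 4]))
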
def _batch_shape_tensor(self, temperature=None, logits=None):
  param = logits
  if param is None:
    param = self._logits if self._logits is not None else self._probs
  if temperature is None:
    temperature = self.temperature
  return prefer_static.broadcast_shape(
      prefer_static.shape(temperature),
      prefer_static.shape(param)[:-1])

def cholesky_concat(chol, cols, name=None):
  """Concatenates `chol @ chol.T` with additional rows and columns.

  This operation is conceptually identical to:
  ```python
  def cholesky_concat_slow(chol, cols):  # cols shaped (n + m) x m = z x m
    mat = tf.matmul(chol, chol, adjoint_b=True)  # batch of n x n
    # Concat columns.
    mat = tf.concat([mat, cols[..., :tf.shape(mat)[-2], :]], axis=-1)  # n x z
    # Concat rows.
    mat = tf.concat([mat, tf.linalg.matrix_transpose(cols)], axis=-2)  # z x z
    return tf.linalg.cholesky(mat)
  ```
  but whereas `cholesky_concat_slow` would cost `O(z**3)` work,
  `cholesky_concat` only costs `O(z**2 + m**3)` work.

  The resulting (implicit) matrix must be symmetric and positive definite.
  Thus, the bottom right `m x m` must be self-adjoint, and we do not require a
  separate `rows` argument (which can be inferred from `conj(cols.T)`).

  Args:
    chol: Cholesky decomposition of `mat = chol @ chol.T`.
    cols: The new columns whose first `n` rows we would like concatenated to
      the right of `mat = chol @ chol.T`, and whose conjugate transpose we
      would like concatenated to the bottom of `concat(mat, cols[:n,:])`. A
      `Tensor` with final dims `(n+m, m)`. The first `n` rows are the top
      right rectangle (their conjugate transpose forms the bottom left), and
      the bottom `m x m` is self-adjoint.
    name: Optional name for this op.

  Returns:
    chol_concat: The Cholesky decomposition of:

    ```
    [ [ mat  cols[:n, :] ]
      [   conj(cols.T)   ] ]
    ```
  """
  with tf.name_scope(name or 'cholesky_extend'):
    dtype = dtype_util.common_dtype([chol, cols], dtype_hint=tf.float32)
    chol = tf.convert_to_tensor(chol, name='chol', dtype=dtype)
    cols = tf.convert_to_tensor(cols, name='cols', dtype=dtype)
    n = prefer_static.shape(chol)[-1]
    mat_nm, mat_mm = cols[..., :n, :], cols[..., n:, :]
    solved_nm = linear_operator_util.matrix_triangular_solve_with_broadcast(
        chol, mat_nm)
    lower_right_mm = tf.linalg.cholesky(
        mat_mm - tf.matmul(solved_nm, solved_nm, adjoint_a=True))
    lower_left_mn = tf.math.conj(tf.linalg.matrix_transpose(solved_nm))
    out_batch = prefer_static.shape(solved_nm)[:-2]
    chol = tf.broadcast_to(
        chol,
        tf.concat([out_batch, prefer_static.shape(chol)[-2:]], axis=0))
    top_right_zeros_nm = tf.zeros_like(solved_nm)
    return tf.concat([
        tf.concat([chol, top_right_zeros_nm], axis=-1),
        tf.concat([lower_left_mn, lower_right_mm], axis=-1)
    ], axis=-2)

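# A hedged usage sketch, assuming this function is exposed as
# `tfp.math.cholesky_concat` (as in TensorFlow Probability): grow a 2x2
# Cholesky factor to 3x3 without re-factorizing the full matrix, then compare
# against a direct factorization.
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

mat = np.array([[4.0, 2.0, 1.0],
                [2.0, 3.0, 0.5],
                [1.0, 0.5, 2.0]])             # symmetric positive definite
chol2 = tf.linalg.cholesky(mat[:2, :2])       # factor of top-left 2x2 block
cols = tf.constant(mat[:, 2:])                # new column; (n + m, m) = (3, 1)
chol3 = tfp.math.cholesky_concat(chol2, cols)
np.testing.assert_allclose(chol3.numpy(), np.linalg.cholesky(mat), atol=1e-10)
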
def _reshape_part(part, event_shape):
  # Note: this is a nested helper; `self` is captured from the enclosing
  # method's scope.
  part = tf.cast(part, self.dtype)
  static_rank = tf.get_static_value(ps.rank_from_shape(event_shape))
  if static_rank == 1:
    return part
  # Collapse the trailing event dims into a single dim, keeping batch dims.
  new_shape = ps.concat([
      ps.shape(part)[:ps.size(ps.shape(part)) - ps.size(event_shape)],
      [-1]
  ], axis=-1)
  return tf.reshape(part, ps.cast(new_shape, tf.int32))

def _batch_shape_tensor(self, distributions=None):
  if distributions is None:
    distributions = self.poisson_and_mixture_distributions()
  dist, mixture_dist = distributions
  return tf.broadcast_dynamic_shape(
      dist.batch_shape_tensor(),
      prefer_static.shape(mixture_dist.logits))[:-1]

def _cdf(self, x):
  df = tf.convert_to_tensor(self.df)
  # Take Abs(scale) to make subsequent where work correctly.
  y = (x - self.loc) / tf.abs(self.scale)
  x_t = df / (y**2. + df)
  neg_cdf = 0.5 * tf.math.betainc(
      0.5 * tf.broadcast_to(df, prefer_static.shape(x_t)), 0.5, x_t)
  return tf.where(y < 0., neg_cdf, 1. - neg_cdf)

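# A SciPy-based sanity sketch (assumes SciPy is available; illustrative only)
# of the Student-t CDF identity used above: with y = (x - loc) / |scale| and
# t = df / (y**2 + df), the CDF is 0.5 * I_t(df/2, 1/2) for y < 0 and
# 1 - 0.5 * I_t(df/2, 1/2) otherwise, where I is the regularized incomplete
# beta function.
import numpy as np
from scipy import special, stats

df, loc, scale = 5.0, 1.0, 2.0
x = np.linspace(-6.0, 6.0, 7)
y = (x - loc) / abs(scale)
t = df / (y**2 + df)
neg_cdf = 0.5 * special.betainc(0.5 * df, 0.5, t)
cdf = np.where(y < 0.0, neg_cdf, 1.0 - neg_cdf)
np.testing.assert_allclose(cdf, stats.t.cdf(x, df, loc=loc, scale=scale))
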
def _inverse(self, y):
  n = prefer_static.shape(y)[-1]
  batch_shape = prefer_static.shape(y)[:-2]

  # Extract the reciprocal of the row norms from the diagonal.
  diag = tf.linalg.diag_part(y)[..., tf.newaxis]

  # Set the diagonal to 0s.
  y = tf.linalg.set_diag(
      y, tf.zeros(tf.concat([batch_shape, [n]], axis=-1), dtype=y.dtype))

  # Multiply with the norm (or divide by its reciprocal) to recover the
  # unconstrained reals in the (strictly) lower triangular part.
  x = y / diag

  # Remove the first row and last column before inverting the FillTriangular
  # transformation.
  return fill_triangular.FillTriangular().inverse(x[..., 1:, :-1])

def _forward(self, x):
  x = tf.convert_to_tensor(x, name='x')
  batch_shape = prefer_static.shape(x)[:-1]

  # Pad zeros on the top row and right column.
  y = fill_triangular.FillTriangular().forward(x)
  rank = prefer_static.rank(y)
  paddings = tf.concat([
      tf.zeros(shape=(rank - 2, 2), dtype=tf.int32),
      tf.constant([[1, 0], [0, 1]], dtype=tf.int32)
  ], axis=0)
  y = tf.pad(y, paddings)

  # Set diagonal to 1s.
  n = prefer_static.shape(y)[-1]
  diag = tf.ones(tf.concat([batch_shape, [n]], axis=-1), dtype=x.dtype)
  y = tf.linalg.set_diag(y, diag)

  # Normalize each row to have Euclidean (L2) norm 1.
  y /= tf.norm(y, axis=-1)[..., tf.newaxis]
  return y

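# A hedged usage sketch, assuming the `_forward`/`_inverse` pair above belongs
# to the `CorrelationCholesky` bijector (`tfp.bijectors.CorrelationCholesky`):
# n*(n-1)/2 unconstrained reals map to the Cholesky factor of an n x n
# correlation matrix, i.e. a lower-triangular matrix with unit-norm rows.
import numpy as np
import tensorflow_probability as tfp

b = tfp.bijectors.CorrelationCholesky()
x = np.array([0.5, -1.0, 2.0])       # 3 reals -> 3 x 3 factor
y = b.forward(x).numpy()
corr = y @ y.T
np.testing.assert_allclose(np.diag(corr), np.ones(3))  # unit diagonal
np.testing.assert_allclose(b.inverse(y).numpy(), x)    # round trip
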
def _forward_log_det_jacobian(self, x):
  # This code is similar to tf.math.log_softmax but different because we have
  # an implicit zero column to handle. I.e., instead of:
  #   reduce_sum(logits - reduce_sum(exp(logits), dim))
  # we must do:
  #   log_normalization = log(1 + reduce_sum(exp(logits)))
  #   -log_normalization + reduce_sum(logits - log_normalization)
  n = prefer_static.shape(x)[-1]
  log_normalization = tf.math.softplus(
      tf.reduce_logsumexp(x, axis=-1, keepdims=True))
  return tf.squeeze(
      (-log_normalization +
       tf.reduce_sum(x - log_normalization, axis=-1, keepdims=True)),
      axis=-1) + 0.5 * tf.math.log(tf.cast(n + 1, dtype=x.dtype))

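# A small NumPy sketch verifying the algebra in the comment above: with
# z = log(1 + sum(exp(x))), the returned quantity (before the constant
# 0.5 * log(n + 1) embedding term) is sum(x) - (n + 1) * z, which equals
# sum(log(y)) over all n + 1 softmax-centered coordinates y.
import numpy as np

x = np.array([0.3, -1.2, 0.7])
z = np.log1p(np.exp(x).sum())
fldj_core = -z + np.sum(x - z)                        # as computed above
y = np.exp(np.append(x, 0.0)) / (1.0 + np.exp(x).sum())
np.testing.assert_allclose(fldj_core, np.log(y).sum())
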
def _sample_n(self, n, seed=None):
  logits = self._logits_parameter_no_checks()
  sample_shape = prefer_static.concat([[n], prefer_static.shape(logits)], 0)
  event_size = self._event_size(logits)
  if tensorshape_util.rank(logits.shape) == 2:
    logits_2d = logits
  else:
    logits_2d = tf.reshape(logits, [-1, event_size])
  samples = tf.random.categorical(logits_2d, n, seed=seed)
  samples = tf.transpose(a=samples)
  samples = tf.one_hot(samples, event_size, dtype=self.dtype)
  ret = tf.reshape(samples, sample_shape)
  return ret

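# A minimal standalone sketch of the sampling strategy above: flatten batch
# dims so logits are 2-D, draw categorical indices, one-hot encode, then move
# the sample dimension to the front. The values are illustrative only.
import tensorflow as tf

logits = tf.math.log([[0.1, 0.6, 0.3],
                      [0.8, 0.1, 0.1]])           # batch of 2, 3 classes
n = 5
idx = tf.random.categorical(logits, n, seed=17)   # shape [2, 5]
one_hot = tf.one_hot(tf.transpose(idx), depth=3)  # shape [5, 2, 3]
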
def _assert_compatible_shape(self, index, sample_shape, samples):
  requested_shape, _ = self._expand_sample_shape_to_vector(
      tf.convert_to_tensor(sample_shape, dtype=tf.int32),
      name='requested_shape')
  actual_shape = prefer_static.shape(samples)
  actual_rank = prefer_static.rank_from_shape(actual_shape)
  requested_rank = prefer_static.rank_from_shape(requested_shape)

  # We test for two properties we expect of yielded distributions:
  # (1) The rank of the tensor of generated samples must be at least
  #     as large as the rank requested.
  # (2) The requested shape must be a prefix of the shape of the
  #     generated tensor of samples.
  #
  # We attempt to perform test (1) statically first.
  # We don't need to do this explicitly for test (2) because
  # `assert_equal` evaluates statically if it can.
  static_actual_rank = tf.get_static_value(actual_rank)
  static_requested_rank = tf.get_static_value(requested_rank)

  assertion_message = ('Samples yielded by distribution #{} are not '
                       'consistent with `sample_shape` passed to '
                       '`JointDistributionCoroutine` '
                       'distribution.'.format(index))

  # TODO(b/138738650): Remove this static check.
  if (static_actual_rank is not None and
      static_requested_rank is not None):
    # We're able to statically check the rank.
    if static_actual_rank < static_requested_rank:
      raise ValueError(assertion_message)
    else:
      control_dependencies = []
  else:
    # We're not able to statically check the rank.
    control_dependencies = [
        assert_util.assert_greater_equal(
            actual_rank, requested_rank, message=assertion_message)
    ]

  with tf.control_dependencies(control_dependencies):
    trimmed_actual_shape = actual_shape[:requested_rank]

  control_dependencies = [
      assert_util.assert_equal(
          requested_shape, trimmed_actual_shape, message=assertion_message)
  ]
  return control_dependencies

def _inverse(self, y):
  ndims = prefer_static.rank(y)
  shifted_y = tf.pad(
      tf.slice(
          y,
          tf.zeros(ndims, dtype=tf.int32),
          prefer_static.shape(y) -
          tf.one_hot(ndims + self.axis, ndims, dtype=tf.int32)
      ),  # Remove the last entry of y in the chosen dimension.
      paddings=tf.one_hot(
          tf.one_hot(ndims + self.axis, ndims, on_value=0, off_value=-1),
          2,
          dtype=tf.int32
      ))  # Insert zeros at the beginning of the chosen dimension.
  return y - shifted_y

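# A NumPy sketch of the inversion idea above, assuming the forward map is a
# cumulative sum along the chosen axis: subtracting a shifted, zero-padded
# copy of `y` is a first difference, which undoes the cumsum.
import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])
y = np.cumsum(x)                             # forward
shifted_y = np.concatenate([[0.0], y[:-1]])  # drop last entry, pad zero front
np.testing.assert_allclose(y - shifted_y, x)
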
def _inverse_log_det_jacobian(self, y):
  # The inverse log det jacobian (ILDJ) of the entire mapping is the sum of
  # the ILDJs of each row's mapping.
  #
  # To compute the ILDJ for each row's mapping, consider the forward mapping
  # `f_k` restricted to the `k`th (1-indexed) row. It maps unconstrained reals
  # in `R^{k-1}` to unit vectors in `R^k`. `f_k : R^{k-1} -> R^k` is given by:
  #
  #   f(x_1, x_2, ..., x_{k-1}) = (x_1/s, x_2/s, ..., x_{k-1}/s, 1/s)
  #
  # where `s = norm(x_1, x_2, ..., x_{k-1}, 1)`.
  #
  # The change in infinitesimal `k-1`-dimensional volume (or surface area) is
  # given by sqrt(|det J^T J|), where J is the `k x (k-1)` Jacobian matrix.
  #
  # Claim: sqrt(|det(J^T J)|) = s^{-k}.
  #
  # Proof: We compute the entries of the Jacobian matrix J:
  #
  #   J_{i, j} = -x_j / s^3            if i == k
  #   J_{i, j} = (s^2 - x_i^2) / s^3   if i == j and i < k
  #   J_{i, j} = -(x_i * x_j) / s^3    if i != j and i < k
  #
  # By spherical symmetry, the volume element depends only on `s`; w.l.o.g.
  # we can assume that `x_1 = r` and `x_2, ..., x_n = 0`; where
  # `r^2 + 1 = s^2`.
  #
  # We can write `J^T = [A|B]` where `A` is a diagonal matrix of rank `k-1`
  # with diagonal `(1/s^3, 1/s, 1/s, ..., 1/s)`; and `B` is a column vector
  # of size `k-1`, with entries (-r/s^3, 0, 0, ..., 0). Hence,
  #
  #   det(J^T J) = det(diag((r^2 + 1) / s^6, 1/s^2, ..., 1/s^2))
  #              = s^{-2k}.
  #
  # Or, sqrt(|det(J^T J)|) = s^{-k}.
  #
  # Hence, the forward log det jacobian (FLDJ) for the `k`th row is given by
  # `-k * log(s)`. The ILDJ is equal to negative FLDJ at the pre-image, or,
  # `k * log(s)`; where `s` is the reciprocal of the `k`th diagonal entry.
  n = prefer_static.shape(y)[-1]
  return -tf.reduce_sum(
      tf.range(1, n + 1, dtype=y.dtype) *
      tf.math.log(tf.linalg.diag_part(y)),
      axis=-1)

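# A hedged numerical spot-check (illustrative only) of the claim above,
# sqrt(|det(J^T J)|) = s^{-k}, for the per-row map
# f(x) = (x_1/s, ..., x_{k-1}/s, 1/s) with s = sqrt(1 + |x|^2), using a
# central finite-difference Jacobian.
import numpy as np

def _f(x):
  s = np.sqrt(1.0 + np.sum(x**2))
  return np.append(x, 1.0) / s

x = np.array([0.4, -1.3, 0.9, 2.1])  # k - 1 = 4, so k = 5
k, eps = x.size + 1, 1e-6
basis = np.eye(x.size)
jac = np.stack(
    [(_f(x + eps * basis[j]) - _f(x - eps * basis[j])) / (2.0 * eps)
     for j in range(x.size)], axis=1)  # shape (k, k - 1)
s = np.sqrt(1.0 + np.sum(x**2))
np.testing.assert_allclose(
    np.sqrt(np.linalg.det(jac.T @ jac)), s**(-k), rtol=1e-5)
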
def maybe_check_wont_broadcast(flat_xs, validate_args):
  """Verifies that `flat_xs` don't broadcast."""
  flat_xs = tuple(flat_xs)  # So we can receive generators.
  if not validate_args:
    # Note: we don't try static validation because it is theoretically
    # possible that a user wants to take advantage of broadcasting.
    # Only when `validate_args` is `True` do we enforce the validation.
    return flat_xs
  msg = 'Broadcasting probably indicates an error in model specification.'
  s = tuple(prefer_static.shape(x) for x in flat_xs)
  if all(prefer_static.is_numpy(s_) for s_ in s):
    if not all(np.all(a == b) for a, b in zip(s[1:], s[:-1])):
      raise ValueError(msg)
    return flat_xs
  assertions = [
      assert_util.assert_equal(a, b, message=msg)
      for a, b in zip(s[1:], s[:-1])
  ]
  with tf.control_dependencies(assertions):
    return tuple(tf.identity(x) for x in flat_xs)

def _inverse_log_det_jacobian(self, y):
  # Let B be the forward map defined by the bijector. Consider the map
  # F : R^n -> R^n where the image of B in R^{n+1} is restricted to the first
  # n coordinates.
  #
  # Claim: det{ dF(X)/dX } = prod(Y) where Y = B(X).
  # Proof: WLOG, in vector notation:
  #   X = log(Y[:-1]) - log(Y[-1])
  # where,
  #   Y[-1] = 1 - sum(Y[:-1]).
  # We have:
  #   det{dF} = 1 / det{ dX/dF(X) }                                     (1)
  #           = 1 / det{ diag(1 / Y[:-1]) + 1 / Y[-1] }
  #           = 1 / det{ inv{ diag(Y[:-1]) - Y[:-1]' Y[:-1] } }
  #           = det{ diag(Y[:-1]) - Y[:-1]' Y[:-1] }
  #           = (1 + Y[:-1]' inv{diag(Y[:-1])} Y[:-1]) det{diag(Y[:-1])} (2)
  #           = Y[-1] prod(Y[:-1])
  #           = prod(Y)
  #
  # Let P be the image of R^n under F. Define the lift G, from P to R^{n+1},
  # which appends the last coordinate, Y[-1] := 1 - \sum_k Y_k. G is linear,
  # so its Jacobian is constant.
  #
  # The differential of G, DG, is eye(n) with a row of -1s appended to the
  # bottom. To compute the Jacobian sqrt{det{(DG)^T(DG)}}, one can see that
  # (DG)^T(DG) = A + eye(n), where A is the n x n matrix of 1s. This has
  # eigenvalues (n + 1, 1, ..., 1), so the determinant is (n + 1). Hence, the
  # Jacobian of G is sqrt{n + 1} everywhere.
  #
  # Putting it all together, the forward bijective map B can be written as
  # B(X) = G(F(X)) and has Jacobian sqrt{n + 1} * prod(F(X)).
  #
  # (1) - https://en.wikipedia.org/wiki/Sherman%E2%80%93Morrison_formula
  #       or by noting that det{ dX/dY } = 1 / det{ dY/dX } from Bijector
  #       docstring "Tip".
  # (2) - https://en.wikipedia.org/wiki/Matrix_determinant_lemma
  n_plus_one = prefer_static.shape(y)[-1]
  return -tf.reduce_sum(tf.math.log(y), axis=-1) - 0.5 * tf.math.log(
      tf.cast(n_plus_one, dtype=y.dtype))

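# A hedged numerical spot-check (illustrative only) of the claim
# det{ dF(X)/dX } = prod(Y): F maps x in R^n to the first n coordinates of
# the softmax-centered image, and the determinant of its finite-difference
# Jacobian should match the product of all n + 1 simplex coordinates.
import numpy as np

def _F(x):
  e = np.exp(np.append(x, 0.0))
  return (e / e.sum())[:-1]

x = np.array([0.2, -0.7, 1.1])
eps = 1e-6
basis = np.eye(x.size)
jac = np.stack(
    [(_F(x + eps * basis[j]) - _F(x - eps * basis[j])) / (2.0 * eps)
     for j in range(x.size)], axis=1)
e = np.exp(np.append(x, 0.0))
y = e / e.sum()  # all n + 1 coordinates
np.testing.assert_allclose(np.linalg.det(jac), np.prod(y), rtol=1e-5)
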
def _event_shape_tensor(self, logits=None):
  param = logits
  if param is None:
    param = self._logits if self._logits is not None else self._probs
  return prefer_static.shape(param)[-1:]

def _batch_shape_tensor(self, concentration=None, rate=None):
  return prefer_static.broadcast_shape(
      prefer_static.shape(
          self.concentration if concentration is None else concentration),
      prefer_static.shape(self.rate if rate is None else rate))

def _batch_shape_tensor(self, concentration1=None, concentration0=None):
  return prefer_static.broadcast_shape(
      prefer_static.shape(
          self.concentration1 if concentration1 is None else concentration1),
      prefer_static.shape(
          self.concentration0 if concentration0 is None else concentration0))

def _event_shape_tensor(self):
  param = self._logits if self._logits is not None else self._probs
  # NOTE: If the last dimension of `param.shape` is statically-known, but
  # the `param.shape` is not statically-known, then we will *not* return a
  # statically-known event size here. This could be fixed.
  return prefer_static.shape(param)[-1:]

def _batch_shape_tensor(self):
  param = self._logits if self._logits is not None else self._probs
  return prefer_static.shape(param)[:-1]

def _batch_shape_tensor(self, loc=None, scale=None):
  return prefer_static.broadcast_shape(
      prefer_static.shape(self.loc if loc is None else loc),
      prefer_static.shape(self.scale if scale is None else scale))

def _batch_shape_tensor(self):
  return prefer_static.shape(self.concentration)

def pivoted_cholesky(matrix, max_rank, diag_rtol=1e-3, name=None):
  """Computes the (partial) pivoted cholesky decomposition of `matrix`.

  The pivoted Cholesky is a low rank approximation of the Cholesky
  decomposition of `matrix`, i.e. as described in [(Harbrecht et al.,
  2012)][1]. The currently-worst-approximated diagonal element is selected as
  the pivot at each iteration. This yields from a `[B1...Bn, N, N]` shaped
  `matrix` a `[B1...Bn, N, K]` shaped rank-`K` approximation `lr` such that
  `lr @ lr.T ~= matrix`. Note that, unlike the Cholesky decomposition, `lr` is
  not triangular even in a rectangular-matrix sense. However, under a
  permutation it could be made triangular (it has one more zero in each column
  as you move to the right).

  Such a matrix can be useful as a preconditioner for conjugate gradient
  optimization, i.e. as in [(Wang et al. 2019)][2], as matmuls and solves can
  be cheaply done via the Woodbury matrix identity, as implemented by
  `tf.linalg.LinearOperatorLowRankUpdate`.

  Args:
    matrix: Floating point `Tensor` batch of symmetric, positive definite
      matrices.
    max_rank: Scalar `int` `Tensor`, the rank at which to truncate the
      approximation.
    diag_rtol: Scalar floating point `Tensor` (same dtype as `matrix`). If the
      errors of all diagonal elements of `lr @ lr.T` are each lower than
      `element * diag_rtol`, iteration is permitted to terminate early.
    name: Optional name for the op.

  Returns:
    lr: Low rank pivoted Cholesky approximation of `matrix`.

  #### References

  [1]: H Harbrecht, M Peters, R Schneider. On the low-rank approximation by
       the pivoted Cholesky decomposition. _Applied numerical mathematics_,
       62(4):428-440, 2012.

  [2]: K. A. Wang et al. Exact Gaussian Processes on a Million Data Points.
       _arXiv preprint arXiv:1903.08114_, 2019.
       https://arxiv.org/abs/1903.08114
  """
  with tf.name_scope(name or 'pivoted_cholesky'):
    dtype = dtype_util.common_dtype([matrix, diag_rtol],
                                    dtype_hint=tf.float32)
    matrix = tf.convert_to_tensor(matrix, name='matrix', dtype=dtype)
    if tensorshape_util.rank(matrix.shape) is None:
      raise NotImplementedError('Rank of `matrix` must be known statically.')

    max_rank = tf.convert_to_tensor(max_rank, name='max_rank', dtype=tf.int64)
    max_rank = tf.minimum(
        max_rank, prefer_static.shape(matrix, out_type=tf.int64)[-1])
    diag_rtol = tf.convert_to_tensor(diag_rtol, dtype=dtype, name='diag_rtol')
    matrix_diag = tf.linalg.diag_part(matrix)
    # matrix is P.D., therefore all matrix_diag > 0, so we don't need abs.
    orig_error = tf.reduce_max(matrix_diag, axis=-1)

    def cond(m, pchol, perm, matrix_diag):
      """Condition for `tf.while_loop` continuation."""
      del pchol
      del perm
      error = tf.linalg.norm(matrix_diag, ord=1, axis=-1)
      max_err = tf.reduce_max(error / orig_error)
      return (m < max_rank) & (tf.equal(m, 0) | (max_err > diag_rtol))

    batch_dims = tensorshape_util.rank(matrix.shape) - 2

    def batch_gather(params, indices, axis=-1):
      return tf.gather(params, indices, axis=axis, batch_dims=batch_dims)

    def body(m, pchol, perm, matrix_diag):
      """Body of a single `tf.while_loop` iteration."""
      # Here is roughly a numpy, non-batched version of what's going to
      # happen. (See also Algorithm 1 of Harbrecht et al.)
      #   1: maxi = np.argmax(matrix_diag[perm[m:]]) + m
      #   2: maxval = matrix_diag[perm][maxi]
      #   3: perm[m], perm[maxi] = perm[maxi], perm[m]
      #   4: row = matrix[perm[m]][perm[m + 1:]]
      #   5: row -= np.sum(pchol[:m][perm[m + 1:]] * pchol[:m][perm[m]],
      #                    axis=-2)
      #   6: pivot = np.sqrt(maxval); row /= pivot
      #   7: row = np.concatenate([[[pivot]], row], -1)
      #   8: matrix_diag[perm[m:]] -= row**2
      #   9: pchol[m, perm[m:]] = row

      # Find the maximal position of the (remaining) permuted diagonal.
      # Steps 1, 2 above.
      permuted_diag = batch_gather(matrix_diag, perm[..., m:])
      maxi = tf.argmax(
          permuted_diag, axis=-1, output_type=tf.int64)[..., tf.newaxis]
      maxval = batch_gather(permuted_diag, maxi)
      maxi = maxi + m
      maxval = maxval[..., 0]
      # Update perm: Swap perm[..., m] with perm[..., maxi]. Step 3 above.
      perm = _swap_m_with_i(perm, m, maxi)
      # Step 4.
      row = batch_gather(matrix, perm[..., m:m + 1], axis=-2)
      row = batch_gather(row, perm[..., m + 1:])
      # Step 5.
      prev_rows = pchol[..., :m, :]
      prev_rows_perm_m_onward = batch_gather(prev_rows, perm[..., m + 1:])
      prev_rows_pivot_col = batch_gather(prev_rows, perm[..., m:m + 1])
      row -= tf.reduce_sum(
          prev_rows_perm_m_onward * prev_rows_pivot_col,
          axis=-2)[..., tf.newaxis, :]
      # Step 6.
      pivot = tf.sqrt(maxval)[..., tf.newaxis, tf.newaxis]
      # Step 7.
      row = tf.concat([pivot, row / pivot], axis=-1)
      # TODO(b/130899118): Pad grad fails with int64 paddings.
      # Step 8.
      paddings = tf.concat([
          tf.zeros([prefer_static.rank(pchol) - 1, 2], dtype=tf.int32),
          [[tf.cast(m, tf.int32), 0]]
      ], axis=0)
      diag_update = tf.pad(row**2, paddings=paddings)[..., 0, :]
      reverse_perm = _invert_permutation(perm)
      matrix_diag -= batch_gather(diag_update, reverse_perm)
      # Step 9.
      row = tf.pad(row, paddings=paddings)
      # TODO(bjp): Defer the reverse permutation all-at-once at the end?
      row = batch_gather(row, reverse_perm)
      pchol_shape = pchol.shape
      pchol = tf.concat([pchol[..., :m, :], row, pchol[..., m + 1:, :]],
                        axis=-2)
      tensorshape_util.set_shape(pchol, pchol_shape)
      return m + 1, pchol, perm, matrix_diag

    m = np.int64(0)
    pchol = tf.zeros_like(matrix[..., :max_rank, :])
    matrix_shape = prefer_static.shape(matrix, out_type=tf.int64)
    perm = tf.broadcast_to(
        prefer_static.range(matrix_shape[-1]), matrix_shape[:-1])
    _, pchol, _, _ = tf.while_loop(
        cond=cond, body=body, loop_vars=(m, pchol, perm, matrix_diag))
    pchol = tf.linalg.matrix_transpose(pchol)
    tensorshape_util.set_shape(
        pchol, tensorshape_util.concatenate(matrix_diag.shape, [None]))
    return pchol

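# A hedged usage sketch, assuming this function is exposed as
# `tfp.math.pivoted_cholesky` (as in TensorFlow Probability): a rank-2
# approximation of a nearly rank-2 positive definite matrix should leave only
# a residual on the order of the added jitter.
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

rng = np.random.default_rng(42)
a = rng.normal(size=(4, 2))
matrix = a @ a.T + 1e-3 * np.eye(4)  # nearly rank-2, positive definite
lr = tfp.math.pivoted_cholesky(matrix, max_rank=2)
residual = matrix - tf.matmul(lr, lr, adjoint_b=True).numpy()
assert np.max(np.abs(residual)) < 1e-2  # roughly the jitter scale
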
def _invert_permutation(perm):
  # TODO(b/130217510): Remove this function.
  return tf.cast(
      tf.math.top_k(perm, k=prefer_static.shape(perm)[-1],
                    sorted=True).indices[..., ::-1],
      perm.dtype)

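# A NumPy sketch of the `top_k` trick above: the indices of a descending sort
# of a permutation, reversed, give its inverse permutation (the same result
# as `np.argsort`).
import numpy as np

perm = np.array([2, 0, 3, 1])
inv = np.argsort(-perm)[::-1]  # descending-sort indices, then reverse
np.testing.assert_array_equal(inv, np.argsort(perm))
np.testing.assert_array_equal(perm[inv], np.arange(4))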