def _swap_m_with_i(vecs, m, i):
  """Swaps the entries at positions `m` and `i` along axis -1.

  Helper for `pivoted_cholesky`, where `vecs` holds a batch of permutation
  vectors. For every vector in the batch, the entry at the shared scalar index
  `m` trades places with the entry at that vector's own index `i`.

  Only positions `>= m` are rewritten (position `m` itself plus the tail after
  it), so this assumes `i >= m` for every batch member.

  Args:
    vecs: Vectors in which to perform the swap; int64 `Tensor`.
    m: Scalar int64 `Tensor`; the position that receives the entry currently
      at `i`.
    i: Batch int64 `Tensor` shaped like `vecs.shape[:-1] + [1]`; the position
      that receives the entry currently at `m`.

  Returns:
    vecs: The vectors after the swap.
  """
  vecs = tf.convert_to_tensor(vecs, dtype=tf.int64, name='vecs')
  m = tf.convert_to_tensor(m, dtype=tf.int64, name='m')
  i = tf.convert_to_tensor(i, dtype=tf.int64, name='i')

  # Position labels m+1, m+2, ..., broadcast to the shape of vecs[..., m+1:].
  vec_len = prefer_static.shape(vecs, out_type=tf.int64)[-1]
  tail_positions = tf.broadcast_to(
      tf.range(m + 1, vec_len), prefer_static.shape(vecs[..., m + 1:]))
  tail_shape = prefer_static.shape(tail_positions)

  # Within the tail, the slot labelled `i` receives the old entry at `m`;
  # every other slot keeps its current entry.
  lands_on_i = tf.equal(tail_positions, tf.broadcast_to(i, tail_shape))
  old_at_m = tf.broadcast_to(tf.gather(vecs, [m], axis=-1), tail_shape)
  old_tail = tf.broadcast_to(vecs[..., m + 1:], tail_shape)
  new_tail = tf1.where(lands_on_i, old_at_m, old_tail)

  # TODO(bjp): Could we use tensor_scatter_nd_update?
  static_vecs_shape = vecs.shape
  # Slot `m` receives the old entry at `i`, gathered per batch member.
  old_at_i = tf.gather(vecs, i, batch_dims=prefer_static.rank(vecs) - 1)
  vecs = tf.concat([vecs[..., :m], old_at_i, new_tail], axis=-1)
  tensorshape_util.set_shape(vecs, static_vecs_shape)
  return vecs
def loop_body(done, u_in, w, seed):
  """Resample the non-accepted points.

  Closure used as a loop body; `shape`, `concentration` and `s` come from the
  enclosing scope.

  Args:
    done: Boolean `Tensor`; `True` where a point has already been accepted.
    u_in: Previous `u`; only its static shape is used here.
    w: Current `w` values; entries where `done` is `True` are kept.
    seed: Stateless seed, split into three fresh seeds.

  Returns:
    `(done, u, w, next_seed)` for the next iteration.
  """
  # u is redrawn in full on every pass. Outside the loop only its sign is
  # used, and that sign is random.
  u_seed, v_seed, next_seed = samplers.split_seed(seed, n=3)
  sample_dtype = concentration.dtype
  u = samplers.uniform(
      shape, minval=-1., maxval=1., dtype=sample_dtype, seed=u_seed)
  tensorshape_util.set_shape(u, u_in.shape)
  z = tf.cos(np.pi * u)

  # Points that are already accepted keep their `w`; the rest get a proposal.
  w = tf.where(done, w, (1. + s * z) / (s + z))
  y = concentration * (s - w)

  v = samplers.uniform(
      shape, minval=0., maxval=1., dtype=sample_dtype, seed=v_seed)
  # A point is accepted if either test passes.
  passes_poly_test = y * (2. - y) >= v
  passes_log_test = tf.math.log(y / v) + 1. >= y
  newly_accepted = passes_poly_test | passes_log_test
  return done | newly_accepted, u, w, next_seed
def _broadcast_cat_event_and_params(event, params, base_dtype):
  """Broadcasts the event or distribution parameters.

  After this call `event.shape` matches `params.shape[:-1]`. Broadcasting is
  skipped only when both shapes are fully known statically and already agree.

  Args:
    event: `Tensor` of integer or floating dtype. Floating values are cast to
      `int32`.
    params: Parameter `Tensor`; its last axis is excluded from the comparison
      with `event.shape` (presumably the category axis -- confirm at callers).
    base_dtype: dtype shown in the error message for unsupported `event`
      dtypes.

  Returns:
    `(event, params)`, broadcast against each other.

  Raises:
    TypeError: If `event` has neither an integer nor a floating dtype.
  """
  if not dtype_util.is_integer(event.dtype):
    if not dtype_util.is_floating(event.dtype):
      raise TypeError('`value` should have integer `dtype` or '
                      '`self.dtype` ({})'.format(base_dtype))
    # When `validate_args=True` we've already ensured int/float casting
    # is closed.
    event = tf.cast(event, dtype=tf.int32)

  static_shapes_known = (
      tensorshape_util.rank(params.shape) is not None and
      tensorshape_util.is_fully_defined(params.shape[:-1]) and
      tensorshape_util.is_fully_defined(event.shape))
  needs_broadcast = (not static_shapes_known or
                     params.shape[:-1] != event.shape)
  if not needs_broadcast:
    return event, params

  params = params * tf.ones_like(event[..., tf.newaxis], dtype=params.dtype)
  leading_params_shape = tf.shape(params)[:-1]
  event = event * tf.ones(leading_params_shape, dtype=event.dtype)
  if tensorshape_util.rank(params.shape) is not None:
    tensorshape_util.set_shape(event, params.shape[:-1])
  return event, params
def _one_step(target_fn, step_sizes, half_next_momentum_parts, state_parts,
              target, target_grad_parts):
  """Body of integrator while loop.

  The step runs in three stages:
  1. Move each state part by `step_size * momentum`.
  2. Re-evaluate `target_fn` and its gradients at the new state.
  3. Move each momentum part by `step_size * new_gradient`.

  Args:
    target_fn: Callable evaluated (with gradients) at the new state.
    step_sizes: One step size per state part.
    half_next_momentum_parts: Current momentum parts.
    state_parts: Current state parts.
    target: Current target value; used only for its static shape.
    target_grad_parts: Current target gradients; used only for static shapes.

  Returns:
    `[next_half_next_momentum_parts, next_state_parts, next_target,
      next_target_grad_parts]`.

  Raises:
    ValueError: If any gradient of `target_fn` at the new state is `None`.
  """
  with tf.name_scope('leapfrog_integrate_one_step'):
    next_state_parts = []
    for x, eps, v in zip(state_parts, step_sizes, half_next_momentum_parts):
      next_state_parts.append(
          x + tf.cast(eps, x.dtype) * tf.cast(v, x.dtype))

    next_target, next_target_grad_parts = (
        mcmc_util.maybe_call_fn_and_grads(target_fn, next_state_parts))
    if any(grad is None for grad in next_target_grad_parts):
      raise ValueError(
          'Encountered `None` gradient.\n'
          ' state_parts: {}\n'
          ' next_state_parts: {}\n'
          ' next_target_grad_parts: {}'.format(
              state_parts, next_state_parts, next_target_grad_parts))

    # Carry the static shapes of the previous step over to the new tensors.
    tensorshape_util.set_shape(next_target, target.shape)
    for new_grad, old_grad in zip(next_target_grad_parts, target_grad_parts):
      tensorshape_util.set_shape(new_grad, old_grad.shape)

    next_half_next_momentum_parts = []
    for v, eps, g in zip(half_next_momentum_parts, step_sizes,
                         next_target_grad_parts):
      next_half_next_momentum_parts.append(
          v + tf.cast(eps, v.dtype) * tf.cast(g, v.dtype))

    return [
        next_half_next_momentum_parts,
        next_state_parts,
        next_target,
        next_target_grad_parts,
    ]
def _finish_prob_for_one_fiber(self, y, x, ildj, event_ndims,
                               override_event_shape, override_batch_shape,
                               base_is_scalar_batch, **distribution_kwargs):
  """Finish computation of prob on one element of the inverse image.

  Args:
    y: The point at which the transformed density is evaluated; used only for
      its static shape.
    x: One element of the inverse image of `y`.
    ildj: Inverse log-det-Jacobian for this element; `exp(ildj)` multiplies
      the base probability.
    event_ndims: Number of event dims of `y`. When it is a Python `int` and an
      event override is active, it is used to set the static shape of the
      result.
    override_event_shape: Passed to `_maybe_rotate_dims` and
      `_reduce_event_indices`.
    override_batch_shape: Passed to `_maybe_rotate_dims` and
      `_reduce_event_indices`.
    base_is_scalar_batch: Passed to `_maybe_rotate_dims` and
      `_reduce_event_indices`.
    **distribution_kwargs: Forwarded to `self.distribution.prob`.

  Returns:
    prob: `self.distribution.prob(x) * exp(ildj)`, reduced over the overridden
      event dims when an event override is active.
  """
  x = self._maybe_rotate_dims(x, override_event_shape, override_batch_shape,
                              base_is_scalar_batch, rotate_right=True)
  prob = self.distribution.prob(x, **distribution_kwargs)
  if self._is_maybe_event_override:
    prob = tf.reduce_prod(
        prob,
        axis=self._reduce_event_indices(
            override_event_shape, override_batch_shape, base_is_scalar_batch))
  prob = prob * tf.exp(tf.cast(ildj, prob.dtype))
  if self._is_maybe_event_override and isinstance(event_ndims, int):
    # `shape[:-0]` is `shape[:0]` (empty), which would drop every leading dim
    # of `y`. With zero event dims, all of `y`'s dims are leading dims.
    y_leading_shape = y.shape[:-event_ndims] if event_ndims > 0 else y.shape
    tensorshape_util.set_shape(
        prob,
        tf.broadcast_static_shape(y_leading_shape, self.batch_shape))
  return prob
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Proposes a new state with `new_state_fn` and evaluates its log prob.

  Args:
    current_state: A `Tensor` or a list of `Tensor`s.
    previous_kernel_results: Not read by this method.
    seed: Optional seed; sanitized and stored in the returned results.

  Returns:
    `[next_state, UncalibratedRandomWalkResults]`. `next_state` is a list
    exactly when `current_state` was list-like. The results carry a
    zero `log_acceptance_correction` (presumably because `new_state_fn` is
    assumed symmetric -- not checked here).
  """
  with tf.name_scope(mcmc_util.make_name(self.name, 'rwm', 'one_step')):
    state_is_list = mcmc_util.is_list_like(current_state)
    with tf.name_scope('initialize'):
      raw_parts = list(current_state) if state_is_list else [current_state]
      current_state_parts = [
          tf.convert_to_tensor(part, name='current_state')
          for part in raw_parts
      ]

    seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
    next_state_parts = self.new_state_fn(current_state_parts, seed)  # pylint: disable=not-callable

    # new_state_fn must not change the size of the state; set_shape fails
    # noisily if it does.
    for proposed_part, existing_part in zip(next_state_parts,
                                            current_state_parts):
      tensorshape_util.set_shape(proposed_part, existing_part.shape)

    # Compute `target_log_prob` so it's available to MetropolisHastings.
    next_target_log_prob = self.target_log_prob_fn(*next_state_parts)  # pylint: disable=not-callable

    next_state = next_state_parts if state_is_list else next_state_parts[0]
    return [
        next_state,
        UncalibratedRandomWalkResults(
            log_acceptance_correction=tf.zeros_like(next_target_log_prob),
            target_log_prob=next_target_log_prob,
            seed=seed,
        ),
    ]
def body(i, vecs, cur_sample, seed):
  """One loop iteration: draw a coordinate and update `vecs` / `cur_sample`.

  Closure over `sample_size`, `n`, `d`, `max_sample_size`, `categorical` and
  `_orthogonal_complement_e_i` from the enclosing scope. Only batch members
  with `i < sample_size` are updated; the rest pass through unchanged.

  For each active batch member, one iteration does the following:
  1. Draw a coordinate `idx` with probability proportional to the squared
     norm of that row of `vecs` (sum of squares over the last axis).
  2. Replace `vecs` by `_orthogonal_complement_e_i(vecs, idx)`.
  3. Zero every column `>= sample_size - i - 1` and every row already marked
     in `cur_sample`.
  4. Mark `idx` in `cur_sample`.

  Returns:
    `(i + 1, vecs, cur_sample, next_seed)`.
  """
  sample_seed, next_seed = samplers.split_seed(seed)
  is_active = i < sample_size

  # Squared norm at each coord across the active subspace.
  coord_prob = tf.reduce_sum(tf.square(vecs), axis=-1)
  coord_logits = tf.where(
      is_active[..., tf.newaxis], tf.math.log(coord_prob), 0.)
  idx = categorical.Categorical(logits=coord_logits).sample(seed=sample_seed)

  column_limit = sample_size[..., tf.newaxis, tf.newaxis] - i - 1
  keep_entry = (tf.range(n) < column_limit) & ~cur_sample[..., tf.newaxis]
  projected = _orthogonal_complement_e_i(
      vecs,
      i=tf.where(is_active, idx, 0),
      gram_schmidt_iters=max_sample_size - i)
  new_vecs = tf.where(keep_entry, projected, 0.)
  # Since range(n) may have unknown shape in the stmt above, we clarify.
  tensorshape_util.set_shape(new_vecs, vecs.shape)
  vecs = tf.where(is_active[..., tf.newaxis, tf.newaxis], new_vecs, vecs)

  just_sampled = (tf.equal(tf.range(d), idx[..., tf.newaxis]) &
                  is_active[..., tf.newaxis])
  cur_sample = cur_sample | just_sampled
  return i + 1, vecs, cur_sample, next_seed
def loop_tree_doubling(self, step_size, momentum_state_memory,
                       current_step_meta_info, iter_, initial_step_state,
                       initial_step_metastate):
  """Main loop for tree doubling.

  One doubling iteration does the following:
  1. Draw a random direction per batch member.
  2. Build a subtree from the matching end of the current trajectory (see
     "num_steps_at_this_depth" below).
  3. Decide whether the subtree's candidate replaces the current candidate.
  4. Update the trajectory ends, the momentum sum and the continue/divergence
     flags.

  Args:
    step_size: Step sizes, one per state part; negated where the direction
      is `False`.
    momentum_state_memory: Passed through to `self._build_sub_tree`.
    current_step_meta_info: Its `init_energy` shape defines the batch shape;
      also passed to `self._build_sub_tree`.
    iter_: Current tree depth.
    initial_step_state: Nested structure of trajectory-end states. Each leaf
      `v` holds the two ends as `v[0]` and `v[1]` (the new state re-stacks
      them along axis 0).
    initial_step_metastate: Metastate from the previous iteration (presumably
      a `TreeDoublingMetaState`, like the returned one).

  Returns:
    `(iter_ + 1, new_step_state, new_step_metastate)`.
  """
  with tf.name_scope('loop_tree_doubling'):
    batch_shape = prefer_static.shape(current_step_meta_info.init_energy)
    # Per batch member: True -> extend from end v[1] with +step_size,
    # False -> extend from end v[0] with -step_size.
    direction = tf.cast(
        tf.random.uniform(
            shape=batch_shape,
            minval=0,
            maxval=2,
            dtype=tf.int32,
            seed=self._seed_stream()),
        dtype=tf.bool)

    tree_start_states = tf.nest.map_structure(
        lambda v: tf.where(  # pylint: disable=g-long-lambda
            _rightmost_expand_to_rank(
                direction, prefer_static.rank(v[1])),
            v[1], v[0]),
        initial_step_state)

    directions_expanded = [
        _rightmost_expand_to_rank(direction, prefer_static.rank(state))
        for state in tree_start_states.state
    ]

    integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
        self.target_log_prob_fn,
        step_sizes=[
            tf.where(d, ss, -ss)
            for d, ss in zip(directions_expanded, step_size)
        ],
        num_steps=self.unrolled_leapfrog_steps)

    [
        candidate_tree_state,
        tree_final_states,
        final_not_divergence,
        continue_tree_final,
        energy_diff_tree_sum,
        momentum_subtree_cumsum,
        leapfrogs_taken
    ] = self._build_sub_tree(
        directions_expanded,
        integrator,
        current_step_meta_info,
        # num_steps_at_this_depth = 2**iter_ = 1 << iter_
        tf.bitwise.left_shift(1, iter_),
        tree_start_states,
        initial_step_metastate.continue_tree,
        initial_step_metastate.not_divergence,
        momentum_state_memory)

    last_candidate_state = initial_step_metastate.candidate_state

    energy_diff_sum = (
        energy_diff_tree_sum + initial_step_metastate.energy_diff_sum)
    if MULTINOMIAL_SAMPLE:
      # Weights are in log space: a stopped subtree contributes weight -inf.
      tree_weight = tf.where(
          continue_tree_final,
          candidate_tree_state.weight,
          tf.constant(-np.inf, dtype=candidate_tree_state.weight.dtype))
      weight_sum = log_add_exp(tree_weight, last_candidate_state.weight)
      log_accept_thresh = tree_weight - last_candidate_state.weight
    else:
      # Weights are counts: a stopped subtree contributes weight 0.
      tree_weight = tf.where(
          continue_tree_final,
          candidate_tree_state.weight,
          tf.zeros([], dtype=TREE_COUNT_DTYPE))
      weight_sum = tree_weight + last_candidate_state.weight
      log_accept_thresh = tf.math.log(
          tf.cast(tree_weight, tf.float32) /
          tf.cast(last_candidate_state.weight, tf.float32))
    # NaN can arise from (-inf) - (-inf) or 0 / 0 above; treat it as a
    # log-threshold of 0.
    log_accept_thresh = tf.where(
        tf.math.is_nan(log_accept_thresh),
        tf.zeros([], log_accept_thresh.dtype),
        log_accept_thresh)
    # log(1 - U) for U ~ Uniform[0, 1), compared against the log threshold.
    u = tf.math.log1p(-tf.random.uniform(
        shape=batch_shape,
        dtype=log_accept_thresh.dtype,
        seed=self._seed_stream()))
    is_sample_accepted = u <= log_accept_thresh

    # A subtree that stopped (continue_tree_final False) is never chosen.
    choose_new_state = is_sample_accepted & continue_tree_final

    new_candidate_state = TreeDoublingStateCandidate(
        state=[
            tf.where(  # pylint: disable=g-complex-comprehension
                _rightmost_expand_to_rank(
                    choose_new_state, prefer_static.rank(s0)),
                s0, s1)
            for s0, s1 in zip(candidate_tree_state.state,
                              last_candidate_state.state)
        ],
        target=tf.where(
            _rightmost_expand_to_rank(
                choose_new_state,
                prefer_static.rank(candidate_tree_state.target)),
            candidate_tree_state.target, last_candidate_state.target),
        target_grad_parts=[
            tf.where(  # pylint: disable=g-complex-comprehension
                _rightmost_expand_to_rank(
                    choose_new_state, prefer_static.rank(grad0)),
                grad0, grad1)
            for grad0, grad1 in zip(
                candidate_tree_state.target_grad_parts,
                last_candidate_state.target_grad_parts)
        ],
        # `energy` is expanded to the rank of `target`.
        energy=tf.where(
            _rightmost_expand_to_rank(
                choose_new_state,
                prefer_static.rank(candidate_tree_state.target)),
            candidate_tree_state.energy, last_candidate_state.energy),
        weight=weight_sum)

    # Restore static shapes lost through tf.where.
    for new_candidate_state_temp, old_candidate_state_temp in zip(
        new_candidate_state.state, last_candidate_state.state):
      tensorshape_util.set_shape(new_candidate_state_temp,
                                 old_candidate_state_temp.shape)

    for new_candidate_grad_temp, old_candidate_grad_temp in zip(
        new_candidate_state.target_grad_parts,
        last_candidate_state.target_grad_parts):
      tensorshape_util.set_shape(new_candidate_grad_temp,
                                 old_candidate_grad_temp.shape)

    # Update left right information of the trajectory, and check trajectory
    # level U turn.
    # The end that was NOT extended this iteration (opposite of the start).
    tree_otherend_states = tf.nest.map_structure(
        lambda v: tf.where(  # pylint: disable=g-long-lambda
            _rightmost_expand_to_rank(
                direction, prefer_static.rank(v[1])),
            v[0], v[1]),
        initial_step_state)

    # Re-stack the two ends along axis 0: index 0 gets the new subtree end
    # when direction is False, index 1 gets it when direction is True.
    new_step_state = tf.nest.pack_sequence_as(initial_step_state, [
        tf.stack([  # pylint: disable=g-complex-comprehension
            tf.where(
                _rightmost_expand_to_rank(
                    direction, prefer_static.rank(l)),
                r, l),
            tf.where(
                _rightmost_expand_to_rank(
                    direction, prefer_static.rank(l)),
                l, r),
        ], axis=0)
        for l, r in zip(tf.nest.flatten(tree_final_states),
                        tf.nest.flatten(tree_otherend_states))
    ])

    momentum_tree_cumsum = []
    for p0, p1 in zip(
        initial_step_metastate.momentum_sum, momentum_subtree_cumsum):
      momentum_part_temp = p0 + p1
      tensorshape_util.set_shape(momentum_part_temp, p0.shape)
      momentum_tree_cumsum.append(momentum_part_temp)

    for new_state_temp, old_state_temp in zip(
        tf.nest.flatten(new_step_state),
        tf.nest.flatten(initial_step_state)):
      tensorshape_util.set_shape(new_state_temp, old_state_temp.shape)

    # Generalized criterion uses the momentum sum; otherwise the difference
    # between the two trajectory ends' positions.
    if GENERALIZED_UTURN:
      state_diff = momentum_tree_cumsum
    else:
      state_diff = [s[1] - s[0] for s in new_step_state.state]

    no_u_turns_trajectory = has_not_u_turn(
        state_diff,
        [m[0] for m in new_step_state.momentum],
        [m[1] for m in new_step_state.momentum],
        log_prob_rank=prefer_static.rank_from_shape(batch_shape))

    new_step_metastate = TreeDoublingMetaState(
        candidate_state=new_candidate_state,
        is_accepted=choose_new_state | initial_step_metastate.is_accepted,
        momentum_sum=momentum_tree_cumsum,
        energy_diff_sum=energy_diff_sum,
        continue_tree=continue_tree_final & no_u_turns_trajectory,
        not_divergence=final_not_divergence,
        leapfrog_count=(initial_step_metastate.leapfrog_count +
                        leapfrogs_taken))

    return iter_ + 1, new_step_state, new_step_metastate
def auto_correlation(x,
                     axis=-1,
                     max_lags=None,
                     center=True,
                     normalize=True,
                     name='auto_correlation'):
  """Auto correlation along one axis.

  Given a `1-D` wide sense stationary (WSS) sequence `X`, the auto correlation
  `RXX` may be defined as (with `E` expectation and `Conj` complex conjugate)

  ```
  RXX[m] := E{ W[m] Conj(W[0]) } = E{ W[0] Conj(W[-m]) },
  W[n]   := (X[n] - MU) / S,
  MU     := E{ X[0] },
  S**2   := E{ (X[0] - MU) Conj(X[0] - MU) }.
  ```

  This function takes the viewpoint that `x` is (along one axis) a finite
  sub-sequence of a realization of (WSS) `X`, and then uses `x` to produce an
  estimate of `RXX[m]` as follows:

  After extending `x` from length `L` to `inf` by zero padding, the auto
  correlation estimate `rxx[m]` is computed for `m = 0, 1, ..., max_lags` as

  ```
  rxx[m] := (L - m)**-1 sum_n w[n + m] Conj(w[n]),
  w[n]   := (x[n] - mu) / s,
  mu     := L**-1 sum_n x[n],
  s**2   := L**-1 sum_n (x[n] - mu) Conj(x[n] - mu)
  ```

  The error in this estimate is proportional to `1 / sqrt(len(x) - m)`, so users
  often set `max_lags` small enough so that the entire output is meaningful.

  Note that since `mu` is an imperfect estimate of `E{ X[0] }`, and we divide by
  `len(x) - m` rather than `len(x) - m - 1`, our estimate of auto correlation
  contains a slight bias, which goes to zero as `len(x) - m --> infinity`.

  Args:
    x:  `float32` or `complex64` `Tensor`.
    axis:  Python `int`. The axis number along which to compute correlation.
      Other dimensions index different batch members.
    max_lags:  Positive `int` tensor.  The maximum value of `m` to consider (in
      equation above).  If `max_lags >= x.shape[axis]`, we effectively re-set
      `max_lags` to `x.shape[axis] - 1`.
    center:  Python `bool`.  If `False`, do not subtract the mean estimate `mu`
      from `x[n]` when forming `w[n]`.
    normalize:  Python `bool`.  If `False`, do not divide by the variance
      estimate `s**2` when forming `w[n]`.
    name:  `String` name to prepend to created ops.

  Returns:
    `rxx`: `Tensor` of same `dtype` as `x`.  `rxx.shape[i] = x.shape[i]` for
      `i != axis`, and `rxx.shape[axis] = max_lags + 1`.

  Raises:
    TypeError:  If `x` does not have a floating or complex dtype.
  """
  # Implementation details:
  # Extend length N / 2 1-D array x to length N by zero padding onto the end.
  # Then, set
  #   F[x]_k := sum_n x_n exp{-i 2 pi k n / N }.
  # It is not hard to see that
  #   F[x]_k Conj(F[x]_k) = F[R]_k, where
  #   R_m := sum_n x_n Conj(x_{(n - m) mod N}).
  # One can also check that R_m / (N / 2 - m) is an unbiased estimate of RXX[m].

  # Since F[x] is the DFT of x, this leads us to a zero-padding and FFT/IFFT
  # based version of estimating RXX.
  # Note that this is a special case of the Wiener-Khinchin Theorem.
  with tf.name_scope(name):
    x = tf.convert_to_tensor(x, name='x')
    dtype = x.dtype

    # Validate the dtype before any centering/padding/FFT ops are built, so an
    # unsupported input fails up front with the documented TypeError.
    if not (dtype_util.is_complex(dtype) or dtype_util.is_floating(dtype)):
      raise TypeError(
          'Argument x must have either float or complex dtype'
          ' found: {}'.format(dtype))

    # Rotate dimensions of x in order to put axis at the rightmost dim.
    # FFT op requires this.
    rank = ps.rank(x)
    if axis < 0:
      axis = rank + axis
    shift = rank - 1 - axis
    # Suppose x.shape[axis] = T, so there are T 'time' steps.
    #   ==> x_rotated.shape = B + [T],
    # where B is x_rotated's batch shape.
    x_rotated = distribution_util.rotate_transpose(x, shift)

    if center:
      x_rotated = x_rotated - tf.reduce_mean(x_rotated, axis=-1, keepdims=True)

    # x_len = N / 2 from above explanation.  The length of x along axis.
    # Get a value for x_len that works in all cases.
    x_len = ps.shape(x_rotated)[-1]

    # TODO(langmore) Investigate whether this zero padding helps or hurts.  At
    # the moment is necessary so that all FFT implementations work.
    # Zero pad to the smallest power of 2 that is >= 2 * x_len, which equals
    # 2**(ceil(Log_2(2 * x_len))).  Note: Log_2(X) = Log_e(X) / Log_e(2).
    x_len_float64 = ps.cast(x_len, np.float64)
    target_length = ps.pow(
        np.float64(2.), ps.ceil(ps.log(x_len_float64 * 2) / np.log(2.)))
    pad_length = ps.cast(target_length - x_len_float64, np.int32)

    # We should have:
    # x_rotated_pad.shape = x_rotated.shape[:-1] + [T + pad_length]
    #                     = B + [T + pad_length]
    x_rotated_pad = distribution_util.pad(
        x_rotated, axis=-1, back=True, count=pad_length)

    # The FFT needs complex input; promote real input with a zero imaginary
    # part.
    if not dtype_util.is_complex(dtype):
      x_rotated_pad = tf.complex(
          x_rotated_pad,
          dtype_util.as_numpy_dtype(dtype_util.real_dtype(dtype))(0.))

    # Autocorrelation is IFFT of power-spectral density (up to some scaling).
    fft_x_rotated_pad = tf.signal.fft(x_rotated_pad)
    spectral_density = fft_x_rotated_pad * tf.math.conj(fft_x_rotated_pad)
    # shifted_product is R[m] from above detailed explanation.
    # It is the inner product sum_n X[n] * Conj(X[n - m]).
    shifted_product = tf.signal.ifft(spectral_density)

    # Cast back to real-valued if x was real to begin with.
    shifted_product = tf.cast(shifted_product, dtype)

    # Figure out if we can deduce the final static shape, and set max_lags.
    # Use x_rotated as a reference, because it has the time dimension in the far
    # right, and was created before we performed all sorts of crazy shape
    # manipulations.
    know_static_shape = True
    if not tensorshape_util.is_fully_defined(x_rotated.shape):
      know_static_shape = False
    if max_lags is None:
      max_lags = x_len - 1
    else:
      max_lags = tf.convert_to_tensor(max_lags, name='max_lags')
      max_lags_ = tf.get_static_value(max_lags)
      if max_lags_ is None or not know_static_shape:
        know_static_shape = False
        max_lags = tf.minimum(x_len - 1, max_lags)
      else:
        max_lags = min(x_len - 1, max_lags_)

    # Chop off the padding.
    # We allow users to provide a huge max_lags, but cut it off here.
    # shifted_product_chopped.shape = x_rotated.shape[:-1] + [max_lags + 1]
    shifted_product_chopped = shifted_product[..., :max_lags + 1]

    # If possible, set shape.
    if know_static_shape:
      chopped_shape = tensorshape_util.as_list(x_rotated.shape)
      chopped_shape[-1] = min(x_len, max_lags + 1)
      tensorshape_util.set_shape(shifted_product_chopped, chopped_shape)

    # Recall R[m] is a sum of N / 2 - m nonzero terms x[n] Conj(x[n - m]).  The
    # other terms were zeros arising only due to zero padding.
    # `denominator = (N / 2 - m)` (defined below) is the proper term to
    # divide by to make this an unbiased estimate of the expectation
    # E[X[n] Conj(X[n - m])].
    x_len = ps.cast(x_len, dtype_util.real_dtype(dtype))
    max_lags = ps.cast(max_lags, dtype_util.real_dtype(dtype))
    denominator = x_len - ps.range(0., max_lags + 1.)
    denominator = ps.cast(denominator, dtype)
    shifted_product_rotated = shifted_product_chopped / denominator

    if normalize:
      # Divide by the lag-0 value so that rxx[0] == 1.
      shifted_product_rotated /= shifted_product_rotated[..., :1]

    # Transpose dimensions back to those of x.
    return distribution_util.rotate_transpose(shifted_product_rotated, -shift)
def pivoted_cholesky(matrix, max_rank, diag_rtol=1e-3, name=None):
  """Computes the (partial) pivoted cholesky decomposition of `matrix`.

  The pivoted Cholesky is a low rank approximation of the Cholesky decomposition
  of `matrix`, i.e. as described in [(Harbrecht et al., 2012)][1]. The
  currently-worst-approximated diagonal element is selected as the pivot at each
  iteration. This yields from a `[B1...Bn, N, N]` shaped `matrix` a `[B1...Bn,
  N, K]` shaped rank-`K` approximation `lr` such that `lr @ lr.T ~= matrix`.
  Note that, unlike the Cholesky decomposition, `lr` is not triangular even in
  a rectangular-matrix sense. However, under a permutation it could be made
  triangular (it has one more zero in each column as you move to the right).

  Such a matrix can be useful as a preconditioner for conjugate gradient
  optimization, i.e. as in [(Wang et al. 2019)][2], as matmuls and solves can be
  cheaply done via the Woodbury matrix identity, as implemented by
  `tf.linalg.LinearOperatorLowRankUpdate`.

  Args:
    matrix: Floating point `Tensor` batch of symmetric, positive definite
      matrices. Its rank must be known statically.
    max_rank: Scalar `int` `Tensor`, the rank at which to truncate the
      approximation. Values larger than `N` are clipped to `N`.
    diag_rtol: Scalar floating point `Tensor` (same dtype as `matrix`). After
      the first iteration, the loop stops early once, for every batch member,
      the L1 norm of the residual diagonal `diag(matrix - lr @ lr.T)` divided
      by the largest diagonal entry of the original `matrix` is at most
      `diag_rtol`.
    name: Optional name for the op.

  Returns:
    lr: Low rank pivoted Cholesky approximation of `matrix`.

  Raises:
    NotImplementedError: If the rank of `matrix` is not known statically.

  #### References

  [1]: H Harbrecht, M Peters, R Schneider. On the low-rank approximation by the
       pivoted Cholesky decomposition. _Applied numerical mathematics_,
       62(4):428-440, 2012.

  [2]: K. A. Wang et al. Exact Gaussian Processes on a Million Data Points.
       _arXiv preprint arXiv:1903.08114_, 2019. https://arxiv.org/abs/1903.08114
  """
  with tf.name_scope(name or 'pivoted_cholesky'):
    dtype = dtype_util.common_dtype([matrix, diag_rtol],
                                    dtype_hint=tf.float32)
    matrix = tf.convert_to_tensor(matrix, name='matrix', dtype=dtype)
    if tensorshape_util.rank(matrix.shape) is None:
      raise NotImplementedError('Rank of `matrix` must be known statically')

    max_rank = tf.convert_to_tensor(max_rank, name='max_rank', dtype=tf.int64)
    # Cannot exceed the matrix dimension N.
    max_rank = tf.minimum(max_rank,
                          prefer_static.shape(matrix, out_type=tf.int64)[-1])
    diag_rtol = tf.convert_to_tensor(
        diag_rtol, dtype=dtype, name='diag_rtol')
    # Residual diagonal diag(matrix - lr @ lr.T); starts as diag(matrix) and is
    # reduced by row**2 in every iteration (step 8).
    matrix_diag = tf.linalg.diag_part(matrix)
    # matrix is P.D., therefore all matrix_diag > 0, so we don't need abs.
    orig_error = tf.reduce_max(matrix_diag, axis=-1)

    def cond(m, pchol, perm, matrix_diag):
      """Condition for `tf.while_loop` continuation."""
      del pchol
      del perm
      error = tf.linalg.norm(matrix_diag, ord=1, axis=-1)
      max_err = tf.reduce_max(error / orig_error)
      # Always run at least one iteration (m == 0), never more than max_rank.
      return (m < max_rank) & (tf.equal(m, 0) | (max_err > diag_rtol))

    batch_dims = tensorshape_util.rank(matrix.shape) - 2
    def batch_gather(params, indices, axis=-1):
      return tf.gather(params, indices, axis=axis, batch_dims=batch_dims)

    def body(m, pchol, perm, matrix_diag):
      """Body of a single `tf.while_loop` iteration."""
      # Here is roughly a numpy, non-batched version of what's going to happen.
      # (See also Algorithm 1 of Harbrecht et al.)
      # pchol is laid out as [max_rank, N]; it is transposed after the loop.
      # 1: maxi = np.argmax(matrix_diag[perm[m:]]) + m
      # 2: maxval = matrix_diag[perm][maxi]
      # 3: perm[m], perm[maxi] = perm[maxi], perm[m]
      # 4: row = matrix[perm[m]][perm[m + 1:]]
      # 5: row -= np.sum(pchol[:m, perm[m + 1:]] * pchol[:m, perm[m:m + 1]],
      #                  axis=-2)
      # 6: pivot = np.sqrt(maxval); row /= pivot
      # 7: row = np.concatenate([[[pivot]], row], -1)
      # 8: matrix_diag[perm[m:]] -= row**2
      # 9: pchol[m, perm[m:]] = row

      # Find the maximal position of the (remaining) permuted diagonal.
      # Steps 1, 2 above.
      permuted_diag = batch_gather(matrix_diag, perm[..., m:])
      maxi = tf.argmax(
          permuted_diag, axis=-1, output_type=tf.int64)[..., tf.newaxis]
      maxval = batch_gather(permuted_diag, maxi)
      maxi = maxi + m
      maxval = maxval[..., 0]
      # Update perm: Swap perm[...,m] with perm[...,maxi]. Step 3 above.
      perm = _swap_m_with_i(perm, m, maxi)
      # Step 4.
      # Row perm[m] of matrix ([..., 1, N]), then its columns perm[m + 1:].
      row = batch_gather(matrix, perm[..., m:m + 1], axis=-2)
      row = batch_gather(row, perm[..., m + 1:])
      # Step 5.
      prev_rows = pchol[..., :m, :]
      prev_rows_perm_m_onward = batch_gather(prev_rows, perm[..., m + 1:])
      prev_rows_pivot_col = batch_gather(prev_rows, perm[..., m:m + 1])
      row -= tf.reduce_sum(
          prev_rows_perm_m_onward * prev_rows_pivot_col,
          axis=-2)[..., tf.newaxis, :]
      # Step 6.
      pivot = tf.sqrt(maxval)[..., tf.newaxis, tf.newaxis]
      # Step 7.
      row = tf.concat([pivot, row / pivot], axis=-1)
      # TODO(b/130899118): Pad grad fails with int64 paddings.
      # Step 8.
      # Left-pad the last axis with m zeros so `row` spans all N permuted slots.
      paddings = tf.concat([
          tf.zeros([prefer_static.rank(pchol) - 1, 2], dtype=tf.int32),
          [[tf.cast(m, tf.int32), 0]]], axis=0)
      diag_update = tf.pad(row**2, paddings=paddings)[..., 0, :]
      reverse_perm = _invert_permutation(perm)
      matrix_diag -= batch_gather(diag_update, reverse_perm)
      # Step 9.
      row = tf.pad(row, paddings=paddings)
      # TODO(bjp): Defer the reverse permutation all-at-once at the end?
      row = batch_gather(row, reverse_perm)
      pchol_shape = pchol.shape
      # Overwrite row m of pchol.
      pchol = tf.concat([pchol[..., :m, :], row, pchol[..., m + 1:, :]],
                        axis=-2)
      tensorshape_util.set_shape(pchol, pchol_shape)
      return m + 1, pchol, perm, matrix_diag

    m = np.int64(0)
    pchol = tf.zeros_like(matrix[..., :max_rank, :])
    matrix_shape = prefer_static.shape(matrix, out_type=tf.int64)
    # Start from the identity permutation, broadcast across the batch.
    perm = tf.broadcast_to(
        prefer_static.range(matrix_shape[-1]), matrix_shape[:-1])
    _, pchol, _, _ = tf.while_loop(
        cond=cond, body=body, loop_vars=(m, pchol, perm, matrix_diag))
    pchol = tf.linalg.matrix_transpose(pchol)
    tensorshape_util.set_shape(
        pchol, tensorshape_util.concatenate(matrix_diag.shape, [None]))
    return pchol
def quadrature_scheme_softmaxnormal_quantiles(
    normal_loc, normal_scale, quadrature_size,
    validate_args=False, name=None):
  """Use SoftmaxNormal quantiles to form quadrature on `K - 1` simplex.

  A `SoftmaxNormal` random variable `Y` may be generated via

  ```
  Y = SoftmaxCentered(X),
  X = Normal(normal_loc, normal_scale)
  ```

  The grid is built in three steps:
  1. Evaluate the Normal's quantiles at the `quadrature_size + 1` interior
     levels `j / (quadrature_size + 2)`, `j = 1, ..., quadrature_size + 1`.
  2. Map these quantiles through `SoftmaxCentered`.
  3. Take midpoints of consecutive mapped quantiles as the grid points.

  Every grid point is assigned the same probability `1 / quadrature_size`.

  Args:
    normal_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0.
      The location parameter of the Normal used to construct the SoftmaxNormal.
    normal_scale: `float`-like `Tensor`. Broadcastable with `normal_loc`.
      The scale parameter of the Normal used to construct the SoftmaxNormal.
    quadrature_size: Python `int` scalar representing the number of quadrature
      points.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    name: Python `str` name prefixed to Ops created by this class.

  Returns:
    grid: Shape `[b1, ..., bB, K, quadrature_size]` `Tensor` representing the
      convex combination of affine parameters for `K` components.
      `grid[..., :, n]` is the `n`-th grid point, living in the `K - 1` simplex.
    probs: Shape `[quadrature_size]` `Tensor` holding the probability
      associated with each grid point; every entry equals
      `1 / quadrature_size`.
  """
  with tf.name_scope(name or "softmax_normal_grid_and_probs"):
    normal_loc = tf.convert_to_tensor(value=normal_loc, name="normal_loc")
    dt = dtype_util.base_dtype(normal_loc.dtype)
    normal_scale = tf.convert_to_tensor(
        value=normal_scale, dtype=dt, name="normal_scale")

    normal_scale = maybe_check_quadrature_param(
        normal_scale, "normal_scale", validate_args)

    dist = normal.Normal(loc=normal_loc, scale=normal_scale)

    def _get_batch_ndims():
      """Helper to get rank(dist.batch_shape), statically if possible."""
      ndims = tensorshape_util.rank(dist.batch_shape)
      if ndims is None:
        ndims = tf.shape(input=dist.batch_shape_tensor())[0]
      return ndims
    batch_ndims = _get_batch_ndims()

    def _get_final_shape(qs):
      """Helper to build `TensorShape`: batch dims, then `[K, qs]`."""
      bs = tensorshape_util.with_rank_at_least(dist.batch_shape, 1)
      num_components = tf.compat.dimension_value(bs[-1])
      if num_components is not None:
        num_components += 1
      tail = tf.TensorShape([num_components, qs])
      return bs[:-1].concatenate(tail)

    def _compute_quantiles():
      """Helper to build quantiles."""
      # Omit {0, 1} since they might lead to Inf/NaN.
      zero = tf.zeros([], dtype=dist.dtype)
      edges = tf.linspace(zero, 1., quadrature_size + 3)[1:-1]
      # Expand edges so they broadcast across batch dims.
      edges = tf.reshape(
          edges,
          shape=tf.concat(
              [[-1], tf.ones([batch_ndims], dtype=tf.int32)], axis=0))
      quantiles = dist.quantile(edges)
      quantiles = softmax_centered_bijector.SoftmaxCentered().forward(
          quantiles)
      # Cyclically permute left by one.
      perm = tf.concat([tf.range(1, 1 + batch_ndims), [0]], axis=0)
      quantiles = tf.transpose(a=quantiles, perm=perm)
      tensorshape_util.set_shape(
          quantiles, _get_final_shape(quadrature_size + 1))
      return quantiles
    quantiles = _compute_quantiles()

    # Compute grid as quantile midpoints.
    grid = (quantiles[..., :-1] + quantiles[..., 1:]) / 2.
    # Set shape hints.
    tensorshape_util.set_shape(grid, _get_final_shape(quadrature_size))

    # By construction probs is constant, i.e., `1 / quadrature_size`. This is
    # important, because non-constant probs leads to non-reparameterizable
    # samples.
    probs = tf.fill(
        dims=[quadrature_size],
        value=1. / tf.cast(quadrature_size, dist.dtype))

    return grid, probs
def _sample_n(self, n, seed=None):
  """Draws `n` samples from the mixture.

  There are two code paths:
    * `self._use_static_graph`: draw `n` samples from every component, stack
      them, and pick each sample's component with a one-hot mask built from
      the categorical draws.
    * Otherwise: partition the categorical draws by component, draw from each
      component only as many samples as were assigned to it, and stitch the
      results back into sample order with `tf.dynamic_stitch`.

  Args:
    n: Number of samples (Python int or scalar integer `Tensor`).
    seed: Seed passed to `self.cat.sample` and used to build the `SeedStream`
      for the component draws.

  Returns:
    `Tensor` of samples shaped `[n, B, E]` (sample, batch, event), per the
    shape annotations below.
  """
  if self._use_static_graph:
    # This sampling approach is almost the same as the approach used by
    # `MixtureSameFamily`. The differences are due to having a list of
    # `Distribution` objects rather than a single object, and maintaining
    # random seed management that is consistent with the non-static code
    # path.
    samples = []
    cat_samples = self.cat.sample(n, seed=seed)
    stream = SeedStream(seed, salt='Mixture')

    for c in range(self.num_components):
      samples.append(self.components[c].sample(n, seed=stream()))
    stack_axis = -1 - tensorshape_util.rank(self._static_event_shape)
    x = tf.stack(samples, axis=stack_axis)  # [n, B, k, E]
    npdt = dtype_util.as_numpy_dtype(x.dtype)
    mask = tf.one_hot(
        indices=cat_samples,  # [n, B]
        depth=self._num_components,  # == k
        on_value=npdt(1),
        off_value=npdt(0))  # [n, B, k]
    mask = distribution_util.pad_mixture_dimensions(
        mask, self, self._cat,
        tensorshape_util.rank(self._static_event_shape))  # [n, B, k, [1]*e]
    # Summing over the component axis keeps only the selected component.
    return tf.reduce_sum(x * mask, axis=stack_axis)  # [n, B, E]

  n = tf.convert_to_tensor(n, name='n')
  static_n = tf.get_static_value(n)
  n = int(static_n) if static_n is not None else n
  cat_samples = self.cat.sample(n, seed=seed)

  # Use static shapes/sizes when fully known, otherwise dynamic ones.
  static_samples_shape = cat_samples.shape
  if tensorshape_util.is_fully_defined(static_samples_shape):
    samples_shape = tensorshape_util.as_list(static_samples_shape)
    samples_size = tensorshape_util.num_elements(static_samples_shape)
  else:
    samples_shape = tf.shape(cat_samples)
    samples_size = tf.size(cat_samples)
  static_batch_shape = self.batch_shape
  if tensorshape_util.is_fully_defined(static_batch_shape):
    batch_shape = tensorshape_util.as_list(static_batch_shape)
    batch_size = tensorshape_util.num_elements(static_batch_shape)
  else:
    batch_shape = tf.shape(cat_samples)[1:]
    batch_size = tf.reduce_prod(batch_shape)
  static_event_shape = self.event_shape
  if tensorshape_util.is_fully_defined(static_event_shape):
    event_shape = np.array(
        tensorshape_util.as_list(static_event_shape), dtype=np.int32)
  else:
    # Filled in from the first component's samples inside the loop below.
    event_shape = None

  # Get indices into the raw cat sampling tensor. We will
  # need these to stitch sample values back out after sampling
  # within the component partitions.
  samples_raw_indices = tf.reshape(tf.range(0, samples_size), samples_shape)

  # Partition the raw indices so that we can use
  # dynamic_stitch later to reconstruct the samples from the
  # known partitions.
  partitioned_samples_indices = tf.dynamic_partition(
      data=samples_raw_indices,
      partitions=cat_samples,
      num_partitions=self.num_components)

  # Copy the batch indices n times, as we will need to know
  # these to pull out the appropriate rows within the
  # component partitions.
  batch_raw_indices = tf.reshape(
      tf.tile(tf.range(0, batch_size), [n]), samples_shape)

  # Explanation of the dynamic partitioning below:
  #   batch indices are e.g. (n=3, batch_size=2) [0, 1, 0, 1, 0, 1]
  # Suppose partitions are:
  #     [1 1 0 0 1 1]
  # After partitioning, batch indices are cut as:
  #     [batch_indices[x] for x in 2, 3]
  #     [batch_indices[x] for x in 0, 1, 4, 5]
  # i.e.
  #     [0 1] and [0 1 0 1]
  # Now we sample n=2 from part 0 and n=4 from part 1.
  # For part 0 we want samples from batch entries 0, 1 (samples 0, 1),
  # and for part 1 we want samples from batch entries 0, 1, 0, 1
  # (samples 0, 1, 2, 3).
  partitioned_batch_indices = tf.dynamic_partition(
      data=batch_raw_indices,
      partitions=cat_samples,
      num_partitions=self.num_components)
  samples_class = [None for _ in range(self.num_components)]

  stream = SeedStream(seed, salt='Mixture')

  for c in range(self.num_components):
    n_class = tf.size(partitioned_samples_indices[c])
    samples_class_c = self.components[c].sample(n_class, seed=stream())

    if event_shape is None:
      batch_ndims = prefer_static.rank_from_shape(batch_shape)
      event_shape = tf.shape(samples_class_c)[1 + batch_ndims:]

    # Pull out the correct batch entries from each index.
    # To do this, we may have to flatten the batch shape.

    # For sample s, batch element b of component c, we get the
    # partitioned batch indices from
    # partitioned_batch_indices[c]; and shift each element by
    # the sample index. The final lookup can be thought of as
    # a matrix gather along locations (s, b) in
    # samples_class_c where the n_class rows correspond to
    # samples within this component and the batch_size columns
    # correspond to batch elements within the component.
    #
    # Thus the lookup index is
    #   lookup[c, i] = batch_size * s[i] + b[c, i]
    # for i = 0 ... n_class[c] - 1.
    lookup_partitioned_batch_indices = (
        batch_size * tf.range(n_class) +
        partitioned_batch_indices[c])
    # Flatten (sample, batch) into one axis so a single gather suffices.
    samples_class_c = tf.reshape(
        samples_class_c,
        tf.concat([[n_class * batch_size], event_shape], 0))
    samples_class_c = tf.gather(
        samples_class_c,
        lookup_partitioned_batch_indices,
        name='samples_class_c_gather')
    samples_class[c] = samples_class_c

  # Stitch back together the samples across the components.
  lhs_flat_ret = tf.dynamic_stitch(
      indices=partitioned_samples_indices, data=samples_class)
  # Reshape back to proper sample, batch, and event shape.
  ret = tf.reshape(
      lhs_flat_ret, tf.concat([samples_shape, event_shape], 0))
  tensorshape_util.set_shape(
      ret,
      tensorshape_util.concatenate(static_samples_shape, self.event_shape))
  return ret
def create_component():
  """Builds one randomly-parameterized `MultivariateNormalDiag` component.

  `loc` is drawn from a standard normal and `scale_diag` is `10` times a
  `tf.random.uniform` draw, both of shape `batch_and_event_shape` (a free
  variable of the enclosing scope). Their static shapes are then pinned to
  `static_batch_and_event_shape`.
  """
  # Keep the draw order (loc first, then scale) so random streams line up.
  mean = tf.random.normal(batch_and_event_shape)
  diag = 10 * tf.random.uniform(batch_and_event_shape)
  for param in (mean, diag):
    tensorshape_util.set_shape(param, static_batch_and_event_shape)
  return tfd.MultivariateNormalDiag(loc=mean, scale_diag=diag)
def _mode(self):
  """Returns the argmax of `self.logits`, cast to `self.dtype`.

  The argmax is taken along axis `self._batch_rank`, and the result's static
  shape is set to `self.batch_shape`.
  """
  most_likely = tf.cast(
      tf.argmax(input=self.logits, axis=self._batch_rank), self.dtype)
  tensorshape_util.set_shape(most_likely, self.batch_shape)
  return most_likely
def histogram(x,
              edges,
              axis=None,
              extend_lower_interval=False,
              extend_upper_interval=False,
              dtype=None,
              name=None):
  """Counts how often `x` falls into the intervals defined by `edges`.

  With `edges = [c0, ..., cK]` the intervals are `I0 = [c0, c1)`,
  `I1 = [c1, c2)`, ..., `I_{K-1} = [c_{K-1}, cK]` (the last one is closed on
  the right). Values of `x` outside `[c0, cK]` cause errors; use
  `extend_lower_interval` / `extend_upper_interval` to accept them.

  Args:
    x: Numeric `N-D` `Tensor` with `N > 0`. If `axis` is not `None`, its
      number of dimensions must be statically known. The dimensions named by
      `axis` index iid samples; the remaining dimensions index "events", each
      of which gets its own histogram.
    edges: `Tensor` of the same `dtype` as `x` whose first dimension indexes
      interval edges. Must either be `1-D` or have `edges.shape[1:]` equal to
      the dimensions of `x` excluding `axis`. If `rank(edges) > 1`, `edges[k]`
      is a shape `edges.shape[1:]` `Tensor` of edges for the corresponding
      dimensions of `x`.
    axis: Optional `0-D` or `1-D` integer `Tensor` with constant values: the
      axes of `x` that index iid samples.
      Default value: `None` (every dimension is a sample dimension).
    extend_lower_interval: Python `bool`. If `True`, the lowest interval `I0`
      becomes `(-inf, c1)`.
    extend_upper_interval: Python `bool`. If `True`, the upper interval
      `I_{K-1}` becomes `[c_{K-1}, +inf)`.
    dtype: The output type (`int32` or `int64`).
      Default value: `x.dtype`.
    name: A Python string name to prepend to created ops.
      Default value: 'histogram'

  Returns:
    counts: `Tensor` of type `dtype`. With
      `~axis = [i for i in range(rank(x)) if i not in axis]`,
      `counts.shape = [edges.shape[0] - 1] + x.shape[~axis]`. For a
      multi-index `I` into `~axis`, `counts[k][I]` is the number of times
      event(s) `I` fell into the `k`th interval of `edges`.

  Raises:
    ValueError: if `axis` is given but empty.

  #### Examples

  ```python
  # x.shape = [1000, 2]
  # x[:, 0] ~ Uniform(0, 1), x[:, 1] ~ Uniform(1, 2).
  x = tf.stack([tf.random.uniform([1000]), 1 + tf.random.uniform([1000])],
               axis=-1)

  # edges ==> bins [0, 0.5), [0.5, 1.0), [1.0, 1.5), [1.5, 2.0].
  edges = [0., 0.5, 1.0, 1.5, 2.0]

  tfp.stats.histogram(x, edges)
  ==> approximately [500, 500, 500, 500]

  tfp.stats.histogram(x, edges, axis=0)
  ==> approximately [[500, 0], [500, 0], [0, 500], [0, 500]]
  ```
  """
  with tf.name_scope(name or 'histogram'):
    in_dtype = dtype_util.common_dtype([x, edges], dtype_hint=tf.float32)
    x = tf.convert_to_tensor(x, name='x', dtype=in_dtype)
    edges = tf.convert_to_tensor(edges, name='edges', dtype=in_dtype)

    # Collapse the sample dims into one leading dim, so that afterwards
    # x.shape = [n_samples] + E, where E is [] when `axis` is None.
    if axis is not None:
      x_ndims = _get_static_ndims(
          x, expect_static=True, expect_ndims_at_least=1)
      axis = _make_static_axis_non_negative_list(axis, x_ndims)
      if not axis:
        raise ValueError(
            '`axis` cannot be empty. Found: {}'.format(axis))
      x = _move_dims_to_flat_end(x, axis, x_ndims, right_end=False)
    else:
      x = tf.reshape(x, shape=[-1])

    # `bins` has the shape of `x`; bins[i] is the interval index of sample i.
    # Without interval extension, out-of-range values come back as -1, which
    # makes the integer counting below fail.
    bins = find_bins(
        x,
        edges=edges,
        extend_lower_interval=extend_lower_interval,
        extend_upper_interval=extend_upper_interval,
        dtype=tf.int32)

    # TODO(b/124015136) Use standard tf.math.bincount once it supports `axis`.
    # Pinning both min and max length to the interval count keeps bins that
    # no sample fell into.
    num_intervals = tf.shape(edges)[0] - 1
    counts = count_integers(
        bins,
        minlength=num_intervals,
        maxlength=num_intervals,
        axis=0,
        dtype=dtype or in_dtype)

    static_num_edges = tf.compat.dimension_value(edges.shape[0])
    if static_num_edges is not None:
      tensorshape_util.set_shape(
          counts,
          tf.TensorShape([static_num_edges - 1]).concatenate(
              counts.shape[1:]))
    return counts
def pinv(a, rcond=None, validate_args=False, name=None):
  """Computes the Moore-Penrose pseudo-inverse of a matrix.

  The [generalized inverse of a matrix](
  https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse) is obtained from
  the singular-value decomposition (SVD) of `a`, keeping only the singular
  values that are large relative to the largest one.

  The pseudo-inverse `A_pinv` "solves" the least-squares problem `A @ x = b`:
  if `x_hat` is a solution, then `x_hat = A_pinv @ b`. If
  `U @ Sigma @ V.T = A` is the SVD of `A`, then
  `A_pinv = V @ inv(Sigma) @ U^T` [(Strang, 1980)][1].

  This function is analogous to [`numpy.linalg.pinv`](
  https://docs.scipy.org/doc/numpy/reference/generated/numpy.linalg.pinv.html).
  It differs only in the default value of `rcond`: numpy uses `1e-15`, while
  here the default is `10. * max(num_rows, num_cols) * np.finfo(dtype).eps`.

  Args:
    a: (Batch of) `float`-like matrix-shaped `Tensor`(s) to be
      pseudo-inverted.
    rcond: `Tensor` of small singular value cutoffs. Singular values not
      larger (in modulus) than `rcond` * largest_singular_value (again, in
      modulus) are treated as zero. Must broadcast against `tf.shape(a)[:-2]`.
      Default value: `10. * max(num_rows, num_cols) * np.finfo(a.dtype).eps`.
    validate_args: When `True`, additional assertions might be embedded in the
      graph.
      Default value: `False` (i.e., no graph assertions are added).
    name: Python `str` prefixed to ops created by this function.
      Default value: 'pinv'.

  Returns:
    a_pinv: The pseudo-inverse of input `a`. Has the same shape as `a`, except
      that the rightmost two dimensions are transposed.

  Raises:
    TypeError: if input `a` does not have `float`-like `dtype`.
    ValueError: if input `a` has fewer than 2 dimensions.

  #### Examples

  ```python
  import tensorflow as tf
  import tensorflow_probability as tfp

  a = tf.constant([[1.,  0.4,  0.5],
                   [0.4, 0.2,  0.25],
                   [0.5, 0.25, 0.35]])
  tf.matmul(tfp.math.pinv(a), a)
  # ==> array([[1., 0., 0.],
               [0., 1., 0.],
               [0., 0., 1.]], dtype=float32)

  a = tf.constant([[1.,  0.4,  0.5,  1.],
                   [0.4, 0.2,  0.25, 2.],
                   [0.5, 0.25, 0.35, 3.]])
  tf.matmul(tfp.math.pinv(a), a)
  # ==> array([[ 0.76,  0.37,  0.21, -0.02],
               [ 0.37,  0.43, -0.33,  0.02],
               [ 0.21, -0.33,  0.81,  0.01],
               [-0.02,  0.02,  0.01,  1.  ]], dtype=float32)
  ```

  #### References

  [1]: G. Strang. 'Linear Algebra and Its Applications, 2nd Ed.' Academic Press,
       Inc., 1980, pp. 139-142.
  """
  with tf.name_scope(name or 'pinv'):
    a = tf.convert_to_tensor(a, name='a')

    assertions = _maybe_validate_matrix(a, validate_args)
    if assertions:
      with tf.control_dependencies(assertions):
        a = tf.identity(a)

    dtype = dtype_util.as_numpy_dtype(a.dtype)

    if rcond is None:
      def dim_size(axis):
        # Prefer the static size; fall back to the runtime shape.
        static_size = tf.compat.dimension_value(a.shape[axis])
        return tf.shape(a)[axis] if static_size is None else static_size

      num_rows, num_cols = dim_size(-2), dim_size(-1)
      if isinstance(num_rows, int) and isinstance(num_cols, int):
        largest_dim = float(max(num_rows, num_cols))
      else:
        largest_dim = tf.cast(tf.maximum(num_rows, num_cols), dtype)
      rcond = 10. * largest_dim * np.finfo(dtype).eps

    rcond = tf.convert_to_tensor(rcond, dtype=dtype, name='rcond')

    # SVD of `a`: s = Sigma, u = U (left vectors), v = V (right vectors).
    # Note: if `a` is symmetric then u == v (setting `v = u` explicitly in
    # such cases might be faster).
    s, u, v = tf.linalg.svd(a, full_matrices=False, compute_uv=True)

    # Replace small singular values with +inf, so that `1. / s == 0.` for them
    # without producing `NaN` gradients.
    keep = s > (rcond * tf.reduce_max(s, axis=-1))[..., tf.newaxis]
    s = tf.where(keep, s, np.array(np.inf, dtype))

    # Although `a == tf.matmul(u, s * v, transpose_b=True)`, the roles of `u`
    # and `v` are swapped here so that `tf.matmul(pinv(A), A) = tf.eye()`,
    # i.e., a matrix inverse has 'transposed' semantics.
    a_pinv = tf.matmul(v / s[..., tf.newaxis, :], u, adjoint_b=True)

    if tensorshape_util.rank(a.shape) is not None:
      transposed_shape = tensorshape_util.concatenate(
          a.shape[:-2], [a.shape[-1], a.shape[-2]])
      tensorshape_util.set_shape(a_pinv, transposed_shape)

    return a_pinv
def lu_reconstruct(lower_upper, perm, validate_args=False, name=None):
  """The inverse LU decomposition, `X == lu_reconstruct(*tf.linalg.lu(X))`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_reconstruct').

  Returns:
    x: The original input to `tf.linalg.lu`, i.e., `x` as in,
      `lu_reconstruct(*tf.linalg.lu(x))`.

  #### Examples

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp

  x = [[[3., 4], [1, 2]],
       [[7., 8], [3, 4]]]
  x_reconstructed = tfp.math.lu_reconstruct(*tf.linalg.lu(x))
  tf.assert_near(x, x_reconstructed)
  # ==> True
  ```
  """
  with tf.name_scope(name or 'lu_reconstruct'):
    lower_upper = tf.convert_to_tensor(
        lower_upper, dtype_hint=tf.float32, name='lower_upper')
    perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')

    assertions = lu_reconstruct_assertions(lower_upper, perm, validate_args)
    if assertions:
      with tf.control_dependencies(assertions):
        lower_upper = tf.identity(lower_upper)
        perm = tf.identity(perm)

    shape = tf.shape(lower_upper)

    # Unpack the packed factors: L gets an explicit unit diagonal, U keeps the
    # diagonal stored in `lower_upper`.
    unit_lower = tf.linalg.set_diag(
        tf.linalg.band_part(lower_upper, num_lower=-1, num_upper=0),
        tf.ones(shape[:-1], dtype=lower_upper.dtype))
    upper = tf.linalg.band_part(lower_upper, num_lower=0, num_upper=-1)
    lu_product = tf.matmul(unit_lower, upper)

    # Apply the row permutation: row i of the result is row
    # invert_permutation(perm)[i] of L @ U.
    if tensorshape_util.rank(lower_upper.shape) == 2:
      # A single, unbatched matrix: a plain row gather suffices.
      x = tf.gather(lu_product, tf.math.invert_permutation(perm))
    else:
      # Either the batch rank is unknown or there are >0 batch dims: flatten
      # the batch and gather rows per batch member.
      batch_size = tf.reduce_prod(shape[:-2])
      d = shape[-1]
      flat_lu = tf.reshape(lu_product, [batch_size, d, d])
      flat_inv_perm = tf.map_fn(
          tf.math.invert_permutation, tf.reshape(perm, [batch_size, d]))
      batch_indices = tf.broadcast_to(
          tf.range(batch_size)[:, tf.newaxis], [batch_size, d])
      x = tf.gather_nd(
          flat_lu, tf.stack([batch_indices, flat_inv_perm], axis=-1))
      x = tf.reshape(x, shape)

    tensorshape_util.set_shape(x, lower_upper.shape)
    return x
def _forward(self, x, **kwargs):
  """Delegates to the parent `_forward`, then refines the static shape.

  When `self._maybe_changes_size` is falsy, the output is given the static
  shape of the input `x`.
  """
  y = super(Blockwise, self)._forward(x, **kwargs)
  if self._maybe_changes_size:
    return y
  tensorshape_util.set_shape(y, x.shape)
  return y
def _kl_brute_force(a, b, name=None):
  """Batched KL divergence `KL(a || b)` for multivariate Normals.

  For multivariate Normals on `R^k` with means `mu_a`, `mu_b` and covariances
  `C_a`, `C_b`:

  ```
  KL(a || b) = 0.5 * ( L - k + T + Q ),
  L := Log[Det(C_b)] - Log[Det(C_a)]
  T := trace(C_b^{-1} C_a),
  Q := (mu_b - mu_a)^T C_b^{-1} (mu_b - mu_a),
  ```

  The trace is computed by solving `C_b^{-1} C_a`. Although efficient solvers
  for `C_b` may be available, a dense version of (the square root of) `C_a` is
  used, so the cost is `O(B s k**2)`, where `B` is the batch size and `s` is
  the cost of solving `C_b x = y` for vectors `x` and `y`.

  Args:
    a: Instance of `MultivariateNormalLinearOperator`.
    b: Instance of `MultivariateNormalLinearOperator`.
    name: (optional) name to use for created ops. Default "kl_mvn".

  Returns:
    Batchwise `KL(a || b)`.
  """

  def squared_frobenius_norm(x):
    """Sum of squared entries over the two rightmost axes."""
    # http://mathworld.wolfram.com/FrobeniusNorm.html
    # `tf.square(tf.norm(x, ord="fro", axis=[-2, -1]))` is deliberately not
    # used: with it, the gradient of KL[p, q] is not defined when p == q.
    return tf.reduce_sum(tf.square(x), axis=[-2, -1])

  # TODO(b/35041439): See also b/35040945. Remove this function once LinOp
  # supports something like:
  #   A.inverse().solve(B).norm(order='fro', axis=[-1, -2])
  def is_diagonal(x):
    """Whether `LinearOperator` `x` is an identity, scaled identity or diag."""
    return isinstance(x, (tf.linalg.LinearOperatorIdentity,
                          tf.linalg.LinearOperatorScaledIdentity,
                          tf.linalg.LinearOperatorDiag))

  with tf.name_scope(name or "kl_mvn"):
    # With C_a = AA' and C_b = BB', the cyclic property of the trace gives
    #   tr[inv(C_b) C_a] = tr[(inv(B) A) (inv(B) A)'] = ||inv(B) A||_F**2,
    # where ||.||_F is the Frobenius norm. References:
    # http://stats.stackexchange.com/questions/60680/kl-divergence-between-two-multivariate-gaussians
    # https://en.wikipedia.org/wiki/Matrix_norm#Frobenius_norm
    if is_diagonal(a.scale) and is_diagonal(b.scale):
      # `stddev` also handles expansion of the Identity cases.
      b_inv_a = (a.stddev() / b.stddev())[..., tf.newaxis]
    else:
      b_inv_a = b.scale.solve(a.scale.to_dense())

    # log|det(scale)| is half of log det(covariance), so this difference is
    # already 0.5 * L.
    half_log_det_ratio = (b.scale.log_abs_determinant() -
                          a.scale.log_abs_determinant())
    k = tf.cast(a.scale.domain_dimension_tensor(), a.dtype)
    trace_term = squared_frobenius_norm(b_inv_a)
    mahalanobis_term = squared_frobenius_norm(
        b.scale.solve((b.mean() - a.mean())[..., tf.newaxis]))
    kl_div = half_log_det_ratio + 0.5 * (-k + trace_term + mahalanobis_term)
    tensorshape_util.set_shape(
        kl_div, tf.broadcast_static_shape(a.batch_shape, b.batch_shape))
    return kl_div
def _inverse(self, y, **kwargs):
  """Delegates to the parent `_inverse`, then refines the static shape.

  When `self._maybe_changes_size` is falsy, the result is given the static
  shape of the input `y`.
  """
  x = super(Blockwise, self)._inverse(y, **kwargs)
  if self._maybe_changes_size:
    return x
  tensorshape_util.set_shape(x, y.shape)
  return x
def _log_loosum_exp_impl(logx, axis, keepdims, compute_mean):
  """Implementation for `*loosum*` functions.

  Computes the leave-one-out log-sum-exp of `logx` over `axis`,
  `log_loosum_x[i] = logsumexp(logx[j] : j != i)`, along with the full
  `logsumexp` and the number of reduced elements.

  Args:
    logx: `Tensor` of log-values.
    axis: `None` (reduce over every dimension) or the dimension(s) to reduce
      over, either as Python/numpy integers or as integer `Tensor`s.
    keepdims: Python `bool`. If `False`, the reduced dimensions are removed
      from the returned `log_sum_x`. `log_loosum_x` always has `logx`'s shape.
    compute_mean: Python `bool`. If `True`, the sums are turned into log-means:
      `log_loosum_x - log(max(1, n - 1))` and `log_sum_x - log(n)`.

  Returns:
    log_loosum_x: Leave-one-out `logsumexp` (or log-mean), shaped like `logx`.
    log_sum_x: `logsumexp` (or log-mean) of `logx` over `axis`.
    n: Number of reduced elements, cast to `logx.dtype`.
  """
  with tf.name_scope('log_loosum_exp_impl'):
    logx = tf.convert_to_tensor(logx, name='logx')
    dtype = dtype_util.as_numpy_dtype(logx.dtype)

    if axis is not None:
      static_axis = np.array(axis)
      # `np.array` yields an object array when `axis` holds `Tensor`s; those
      # must stay dynamic. Compare dtypes with `==`: a `np.dtype` is never
      # *identical* to the `object` type, and the `np.object` alias was
      # removed in NumPy 1.24.
      axis = (tf.convert_to_tensor(axis, name='axis', dtype_hint=tf.int32)
              if static_axis.dtype == np.object_
              else static_axis.astype(np.int32))

    log_sum_x = tf.reduce_logsumexp(logx, axis=axis, keepdims=True)

    # Later we'll want to compute the mean from a sum so we calculate the number
    # of reduced elements, n.
    n = prefer_static.size(logx) // prefer_static.size(log_sum_x)
    n = prefer_static.cast(n, dtype)

    # log_loosum_x[i] =
    # = logsumexp(logx[j] : j != i)
    # = log( exp(logsumexp(logx)) - exp(logx[i]) )
    # = log( exp(logsumexp(logx - logx[i])) exp(logx[i]) - exp(logx[i]))
    # = logx[i] + log(exp(logsumexp(logx - logx[i])) - 1)
    # = logx[i] + log(exp(logsumexp(logx) - logx[i]) - 1)
    # = logx[i] + softplus_inverse(logsumexp(logx) - logx[i])
    d = log_sum_x - logx

    # We use `d != 0` rather than `d > 0.` because `d < 0.` should never happen;
    # if it does we want to complain loudly (which `softplus_inverse` will).
    d_ok = tf.not_equal(d, 0.)
    safe_d = tf.where(d_ok, d, 1.)
    d_ok_result = logx + softplus_inverse(safe_d)

    neg_inf = tf.constant(-np.inf, dtype=dtype)

    # When not(d_ok) and is_positive_and_largest then we manually compute the
    # log_loosum_x. (We can efficiently do this for any one point but not all,
    # hence we still need the above calculation.) This is good because when
    # this condition is met, we cannot use the above calculation; its -inf.
    # We now compute the log-leave-out-max-sum, replicate it to every
    # point and make sure to select it only when we need to.
    max_logx = tf.reduce_max(logx, axis=axis, keepdims=True)
    is_positive_and_largest = (logx > 0.) & tf.equal(logx, max_logx)
    log_lomsum_x = tf.reduce_logsumexp(
        tf.where(is_positive_and_largest, neg_inf, logx),
        axis=axis, keepdims=True)
    d_not_ok_result = tf.where(is_positive_and_largest, log_lomsum_x, neg_inf)

    log_loosum_x = tf.where(d_ok, d_ok_result, d_not_ok_result)

    # We now squeeze log_sum_x so as if we used `keepdims=False`.
    # TODO(b/136176077): These mental gymnastics could all be replaced with
    # `tf.squeeze(log_sum_x, axis)` if tf.squeeze supported Tensor valued `axis`
    # arguments.
    if not keepdims:
      # Indices of the dimensions that survive the reduction.
      if axis is None:
        kept_axes = np.array([], dtype=np.int32)
      else:
        rank = prefer_static.rank(logx)
        kept_axes = prefer_static.setdiff1d(
            prefer_static.range(rank),
            prefer_static.non_negative_axis(axis, rank))
      squeeze_shape = tf.gather(prefer_static.shape(logx), indices=kept_axes)
      log_sum_x = tf.reshape(log_sum_x, shape=squeeze_shape)
      if prefer_static.is_numpy(kept_axes):
        tensorshape_util.set_shape(log_sum_x,
                                   np.array(logx.shape)[kept_axes])

    # Set static shapes just in case we lost them.
    tensorshape_util.set_shape(n, [])
    tensorshape_util.set_shape(log_loosum_x, logx.shape)

    if not compute_mean:
      return log_loosum_x, log_sum_x, n

    # `n` may be a Tensor when shapes are dynamic, so use the static/dynamic
    # aware maximum rather than Python's `max` (which needs a concrete bool).
    log_nm1 = prefer_static.log(prefer_static.maximum(1., n - 1.))
    log_n = prefer_static.log(n)
    return log_loosum_x - log_nm1, log_sum_x - log_n, n
def _mode(self):
  """Returns the most probable category along the last axis.

  Takes the argmax of `self.categorical_log_probs()` over axis `-1` (output
  dtype `self.dtype`), and sets the static shape to the log-probs' shape
  without its last dimension.
  """
  categorical_lp = self.categorical_log_probs()
  most_likely = tf.argmax(categorical_lp, axis=-1, output_type=self.dtype)
  tensorshape_util.set_shape(most_likely, categorical_lp.shape[:-1])
  return most_likely
def fill_triangular_inverse(x, upper=False, name=None):
  """Packs a (batch of) triangular matrix back into a vector.

  This undoes `fill_triangular`: the vector is read from the lower-triangular
  or upper-triangular portion of `x`, depending on `upper`.

  If `x.shape` is `[b1, b2, ..., bB, n, n]` then the output shape is
  `[b1, b2, ..., bB, d]` where `d = n (n + 1) / 2`.

  Example:

  ```python
  fill_triangular_inverse(
    [[4, 0, 0],
     [6, 5, 0],
     [3, 2, 1]])

  # ==> [1, 2, 3, 4, 5, 6]

  fill_triangular_inverse(
    [[1, 2, 3],
     [0, 5, 6],
     [0, 0, 4]], upper=True)

  # ==> [1, 2, 3, 4, 5, 6]
  ```

  Args:
    x: `Tensor` holding the lower (or upper) triangular elements.
    upper: Python `bool`; read the upper (`True`) or lower (`False`, default)
      triangle of `x`.
    name: Python `str`. The name to give this op.

  Returns:
    flat_tril: (Batch of) vector-shaped `Tensor` holding the vectorized lower
      (or upper) triangular elements of `x`.
  """
  with tf.name_scope(name or 'fill_triangular_inverse'):
    x = tf.convert_to_tensor(x, name='x')
    matrix_shape = tensorshape_util.with_rank_at_least(x.shape, 2)
    static_n = tf.compat.dimension_value(matrix_shape[-1])
    if static_n is None:
      n = tf.shape(x)[-1]
      m = (n * (n + 1)) // 2
      static_last_dim = None
    else:
      n = np.int32(static_n)
      m = np.int32((n * (n + 1)) // 2)
      static_last_dim = m
    static_final_shape = tensorshape_util.concatenate(
        matrix_shape[:-2], [static_last_dim])

    rank = prefer_static.rank(x)
    row_axis, col_axis = rank - 2, rank - 1
    # The leading `n` output entries come from one full row of the triangle:
    # the first row when upper, otherwise the last row reversed.
    if upper:
      leading = x[..., 0, :]
      remainder = x[..., 1:, :]
    else:
      leading = tf.reverse(x[..., -1, :], axis=[row_axis])
      remainder = x[..., :-1, :]
    # Adding the remaining rows to their 180-degree rotation folds the rest of
    # the triangle so its first `m - n` flattened entries are the remaining
    # output entries.
    rotated = tf.reverse(
        tf.reverse(remainder, axis=[col_axis]), axis=[row_axis])
    folded = tf.reshape(
        remainder + rotated,
        tf.concat([tf.shape(x)[:-2], [n * (n - 1)]], axis=0))
    y = tf.concat([leading, folded[..., :m - n]], axis=-1)
    tensorshape_util.set_shape(y, static_final_shape)
    return y
def _merge_static_length(x):
  """Refines `x`'s static shape to `static_length + x.shape[1:]`.

  `static_length` is a free variable of the enclosing scope and replaces the
  leading static dimension of `x`. Returns `x` itself.
  """
  merged_shape = static_length.concatenate(x.shape[1:])
  tensorshape_util.set_shape(x, merged_shape)
  return x
def fill_triangular(x, upper=False, name=None):
  """Creates a (batch of) triangular matrix from a vector of inputs.

  Created matrix can be lower- or upper-triangular. (It is more efficient to
  create the matrix as upper or lower, rather than transpose.)

  Triangular matrix elements are filled in a clockwise spiral. See example,
  below.

  If `x.shape` is `[b1, b2, ..., bB, d]` then the output shape is
  `[b1, b2, ..., bB, n, n]` where `n` is such that `d = n(n+1)/2`, i.e.,
  `n = int(np.sqrt(0.25 + 2. * d) - 0.5)`.

  Example:

  ```python
  fill_triangular([1, 2, 3, 4, 5, 6])
  # ==> [[4, 0, 0],
  #      [6, 5, 0],
  #      [3, 2, 1]]

  fill_triangular([1, 2, 3, 4, 5, 6], upper=True)
  # ==> [[1, 2, 3],
  #      [0, 5, 6],
  #      [0, 0, 4]]
  ```

  The key trick is to create an upper triangular matrix by concatenating `x`
  and a tail of itself, then reshaping.

  Suppose that we are filling the upper triangle of an `n`-by-`n` matrix `M`
  from a vector `x`. The matrix `M` contains n**2 entries total. The vector `x`
  contains `n * (n+1) / 2` entries. For concreteness, we'll consider `n = 5`
  (so `x` has `15` entries and `M` has `25`). We'll concatenate `x` and `x` with
  the first (`n = 5`) elements removed and reversed:

  ```python
  x = np.arange(15) + 1
  xc = np.concatenate([x, x[5:][::-1]])
  # ==> array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 15, 14, 13,
  #            12, 11, 10, 9, 8, 7, 6])

  # (We add one to the arange result to disambiguate the zeros below the
  # diagonal of our upper-triangular matrix from the first entry in `x`.)

  # Now, when reshaped and laid out as a matrix:
  y = np.reshape(xc, [5, 5])
  # ==> array([[ 1,  2,  3,  4,  5],
  #            [ 6,  7,  8,  9, 10],
  #            [11, 12, 13, 14, 15],
  #            [15, 14, 13, 12, 11],
  #            [10,  9,  8,  7,  6]])

  # Finally, zero the elements below the diagonal:
  y = np.triu(y, k=0)
  # ==> array([[ 1,  2,  3,  4,  5],
  #            [ 0,  7,  8,  9, 10],
  #            [ 0,  0, 13, 14, 15],
  #            [ 0,  0,  0, 12, 11],
  #            [ 0,  0,  0,  0,  6]])
  ```

  From this example we see that the resulting matrix is upper-triangular, and
  contains all the entries of x, as desired. The rest is details:

  - If `n` is even, `x` doesn't exactly fill an even number of rows (it fills
    `n / 2` rows and half of an additional row), but the whole scheme still
    works.
  - If we want a lower triangular matrix instead of an upper triangular, the
    truncated copy of `x` (first `n` elements removed, not reversed) goes
    first, followed by the full `x` reversed.

  For additional comparisons, a pure numpy version of this function can be found
  in `distribution_util_test.py`, function `_fill_triangular`.

  Args:
    x: `Tensor` representing lower (or upper) triangular elements.
    upper: Python `bool` representing whether output matrix should be upper
      triangular (`True`) or lower triangular (`False`, default).
    name: Python `str`. The name to give this op.

  Returns:
    tril: `Tensor` with lower (or upper) triangular elements filled from `x`.

  Raises:
    ValueError: if the last dimension of `x` is statically known and is not a
      triangular number `n(n+1)/2`. (With a dynamic last dimension no check is
      made; the reshape below fails instead.)
  """
  with tf.name_scope(name or 'fill_triangular'):
    x = tf.convert_to_tensor(x, name='x')
    m = tf.compat.dimension_value(
        tensorshape_util.with_rank_at_least(x.shape, 1)[-1])
    if m is not None:
      # Formula derived by solving for n: m = n(n+1)/2.
      m = np.int32(m)
      n = np.sqrt(0.25 + 2. * m) - 0.5
      if n != np.floor(n):
        raise ValueError('Input right-most shape ({}) does not '
                         'correspond to a triangular matrix.'.format(m))
      n = np.int32(n)
      static_final_shape = tensorshape_util.concatenate(x.shape[:-1], [n, n])
    else:
      m = tf.shape(x)[-1]
      # For derivation, see above. Casting automatically lops off the 0.5, so we
      # omit it. We don't validate n is an integer because this has
      # graph-execution cost; an error will be thrown from the reshape, below.
      n = tf.cast(
          tf.sqrt(0.25 + tf.cast(2 * m, dtype=tf.float32)), dtype=tf.int32)
      static_final_shape = tensorshape_util.concatenate(
          tensorshape_util.with_rank_at_least(x.shape, 1)[:-1], [None, None])

    # Try it out in numpy:
    #  n = 3
    #  x = np.arange(n * (n + 1) / 2)
    #  m = x.shape[0]
    #  n = np.int32(np.sqrt(.25 + 2 * m) - .5)
    #  x_tail = x[(m - (n**2 - m)):]
    #  np.concatenate([x_tail, x[::-1]], 0).reshape(n, n)  # lower
    #  # ==> array([[3, 4, 5],
    #               [5, 4, 3],
    #               [2, 1, 0]])
    #  np.concatenate([x, x_tail[::-1]], 0).reshape(n, n)  # upper
    #  # ==> array([[0, 1, 2],
    #               [3, 4, 5],
    #               [5, 4, 3]])
    #
    # Note that we can't simply do `x[..., -(n**2 - m):]` because this doesn't
    # correctly handle `m == n == 1`. Hence, we do nonnegative indexing.
    # Furthermore observe that:
    #   m - (n**2 - m)
    #   = n**2 / 2 + n / 2 - (n**2 - n**2 / 2 + n / 2)
    #   = 2 (n**2 / 2 + n / 2) - n**2
    #   = n**2 + n - n**2
    #   = n
    ndims = prefer_static.rank(x)
    # Build the n**2 entries to reshape: `x` plus a reversed copy of `x`, with
    # the first `n` entries dropped from one of the two copies (see above).
    if upper:
      x_list = [x, tf.reverse(x[..., n:], axis=[ndims - 1])]
    else:
      x_list = [x[..., n:], tf.reverse(x, axis=[ndims - 1])]
    # Use a static Python list for the target shape when fully known, else
    # build it at runtime.
    new_shape = (
        tensorshape_util.as_list(static_final_shape)
        if tensorshape_util.is_fully_defined(static_final_shape) else
        tf.concat([tf.shape(x)[:-1], [n, n]], axis=0))
    x = tf.reshape(tf.concat(x_list, axis=-1), new_shape)
    # Zero out the duplicated entries on the other side of the diagonal.
    x = tf.linalg.band_part(
        x,
        num_lower=(0 if upper else -1),
        num_upper=(-1 if upper else 0))
    tensorshape_util.set_shape(x, static_final_shape)
    return x
def xform_static(x):
  """Returns a copy of `x` whose static shape is pinned to `[1]`."""
  # Operate on a `tf.identity` copy: calling set_shape on the caller's tensor
  # would push the new static shape back into earlier uses ("into the past").
  copied = tf.identity(x)
  tensorshape_util.set_shape(copied, [1])
  return copied
def _mode(self):
  """Returns the argmax over the last axis of the logits (or probs).

  Uses `self._logits` when it is set, otherwise `self._probs`. The result is
  cast to `self.dtype` and given the parameter's static shape minus its last
  dimension.
  """
  if self._logits is not None:
    params = self._logits
  else:
    params = self._probs
  most_likely = tf.argmax(params, axis=-1)
  most_likely = tf.cast(most_likely, self.dtype)
  tensorshape_util.set_shape(most_likely, params.shape[:-1])
  return most_likely
def quadrature_scheme_lognormal_quantiles(
    loc, scale, quadrature_size,
    validate_args=False, name=None):
  """Uses LogNormal quantiles to form a quadrature on the positive reals.

  Args:
    loc: `float`-like (batch of) scalar `Tensor`; the location parameter of
      the LogNormal prior.
    scale: `float`-like (batch of) scalar `Tensor`; the scale parameter of the
      LogNormal prior.
    quadrature_size: Python `int` scalar representing the number of quadrature
      points.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    name: Python `str` name prefixed to Ops created by this function.

  Returns:
    grid: (Batch of) length-`quadrature_size` vectors of quadrature points:
      the midpoints between consecutive LogNormal quantiles.
    probs: Length-`quadrature_size` vector of equal weights
      `1 / quadrature_size`, shared by every batch member.
  """
  with tf.name_scope(name or "quadrature_scheme_lognormal_quantiles"):
    # LogNormal(loc, scale) == Exp(Normal(loc, scale)).
    lognormal = transformed_distribution.TransformedDistribution(
        distribution=normal.Normal(loc=loc, scale=scale),
        bijector=exp_bijector.Exp(),
        validate_args=validate_args)
    batch_ndims = tensorshape_util.rank(lognormal.batch_shape)
    if batch_ndims is None:
      batch_ndims = tf.shape(lognormal.batch_shape_tensor())[0]

    # `quadrature_size + 3` evenly spaced levels on [0, 1]; the endpoints 0
    # and 1 are dropped since they might lead to Inf/NaN quantiles, leaving
    # `quadrature_size + 1` levels.
    levels = tf.linspace(
        tf.zeros([], dtype=lognormal.dtype), 1., quadrature_size + 3)[1:-1]
    # Levels on the leading axis, with a singleton dim per batch dim so they
    # broadcast across the batch.
    levels = tf.reshape(
        levels,
        shape=tf.concat(
            [[-1], tf.ones([batch_ndims], dtype=tf.int32)], axis=0))
    quantiles = lognormal.quantile(levels)
    # Cyclically permute axes left by one: [levels, *batch] -> [*batch, levels].
    quantiles = tf.transpose(
        a=quantiles,
        perm=tf.concat([tf.range(1, 1 + batch_ndims), [0]], axis=0))

    # Grid points are midpoints of consecutive quantiles.
    grid = (quantiles[..., :-1] + quantiles[..., 1:]) / 2.
    tensorshape_util.set_shape(
        grid,
        tensorshape_util.concatenate(lognormal.batch_shape, [quadrature_size]))

    # probs is constant (`1 / quadrature_size`) by construction. This matters
    # because non-constant probs lead to non-reparameterizable samples.
    probs = tf.fill(
        dims=[quadrature_size],
        value=1. / tf.cast(quadrature_size, lognormal.dtype))

    return grid, probs
def _log_prob(self, x):
  """Log density at `x` (this appears to be a Wishart-style density).

  Computes
  `(df - dimension - 1) * 0.5 * log det(X) - 0.5 * tr[inv(V) X]
   - self._log_normalization(df, scale)`, where `V = S S'` with `S` given by
  `self._scale` and `X = L L'`.

  Args:
    x: Input matrices. When `self.input_output_cholesky` is `True`, `x` is
      taken to be the Cholesky factor `L`; otherwise `tf.linalg.cholesky(x)`
      is computed. Assumed shape `sample_shape + batch_shape + [k, k]`, with
      missing leading dims filled with singletons — TODO confirm the intended
      contract with callers.

  Returns:
    log_prob: `Tensor` of log densities, with static shape set to the
      broadcast of `x.shape[:-2]` and `self.batch_shape` when both ranks are
      known.
  """
  if self.input_output_cholesky:
    x_sqrt = x
  else:
    # Complexity: O(nbk**3)
    x_sqrt = tf.linalg.cholesky(x)

  df = tf.convert_to_tensor(self.df)
  batch_shape = self._batch_shape_tensor(df)
  event_shape = self._event_shape_tensor()
  dimension = self._dimension()
  x_ndims = ps.rank(x_sqrt)
  # Prepend singleton axes so that `x_sqrt` has at least batch + event rank.
  num_singleton_axes_to_prepend = (
      ps.maximum(ps.size(batch_shape) + 2, x_ndims) - x_ndims)
  x_with_prepended_singletons_shape = ps.concat([
      ps.ones([num_singleton_axes_to_prepend], dtype=tf.int32),
      ps.shape(x_sqrt)
  ], 0)
  x_sqrt = tf.reshape(x_sqrt, x_with_prepended_singletons_shape)
  ndims = ps.rank(x_sqrt)
  # sample_ndims = ndims - batch_ndims - event_ndims
  sample_ndims = ndims - ps.size(batch_shape) - 2
  sample_shape = ps.shape(x_sqrt)[:sample_ndims]

  # We need to be able to pre-multiply each matrix by its corresponding
  # batch scale matrix. Since a Distribution Tensor supports multiple
  # samples per batch, this means we need to reshape the input matrix `x`
  # so that the first b dimensions are batch dimensions and the last two
  # are of shape [dimension, dimensions*number_of_samples]. Doing these
  # gymnastics allows us to do a batch_solve.
  #
  # After we're done with sqrt_solve (the batch operation) we need to undo
  # this reshaping so what we're left with is a Tensor partitionable by
  # sample, batch, event dimensions.

  # Complexity: O(nbk**2) since transpose must access every element.
  scale_sqrt_inv_x_sqrt = x_sqrt
  # Move the sample dims to the end: [S, B, k, k] -> [B, k, k, S].
  perm = ps.concat(
      [ps.range(sample_ndims, ndims),
       ps.range(0, sample_ndims)], 0)
  scale_sqrt_inv_x_sqrt = tf.transpose(a=scale_sqrt_inv_x_sqrt, perm=perm)
  last_dim_size = (
      ps.cast(dimension, dtype=tf.int32) *
      ps.reduce_prod(x_with_prepended_singletons_shape[:sample_ndims]))
  shape = ps.concat([
      x_with_prepended_singletons_shape[sample_ndims:-2],
      [ps.cast(dimension, dtype=tf.int32), last_dim_size]
  ], axis=0)
  scale_sqrt_inv_x_sqrt = tf.reshape(scale_sqrt_inv_x_sqrt, shape)

  # Complexity: O(nbM*k) where M is the complexity of the operator solving a
  # vector system. For LinearOperatorLowerTriangular, each solve is O(k**2) so
  # this step has complexity O(nbk^3).
  scale_sqrt_inv_x_sqrt = self._scale.solve(scale_sqrt_inv_x_sqrt)

  # Undo the batch-op-ready reshape/transpose from above.
  # Complexity: O(nbk**2)
  shape = ps.concat(
      [ps.shape(scale_sqrt_inv_x_sqrt)[:-2], event_shape, sample_shape],
      axis=0)
  scale_sqrt_inv_x_sqrt = tf.reshape(scale_sqrt_inv_x_sqrt, shape)
  perm = ps.concat([
      ps.range(ndims - sample_ndims, ndims),
      ps.range(0, ndims - sample_ndims)
  ], 0)
  scale_sqrt_inv_x_sqrt = tf.transpose(a=scale_sqrt_inv_x_sqrt, perm=perm)

  # Write V = SS', X = LL'. Then:
  # tr[inv(V) X] = tr[inv(S)' inv(S) L L']
  #              = tr[inv(S) L L' inv(S)']
  #              = tr[(inv(S) L) (inv(S) L)']
  #              = sum_{ik} (inv(S) L)_{ik}**2
  # The second equality follows from the cyclic permutation property.
  # Complexity: O(nbk**2)
  trace_scale_inv_x = tf.reduce_sum(
      tf.square(scale_sqrt_inv_x_sqrt), axis=[-2, -1])

  # sum(log(diag(L))) == 0.5 * log det(X) since X = L L'.
  # Complexity: O(nbk)
  half_log_det_x = tf.reduce_sum(
      tf.math.log(tf.linalg.diag_part(x_sqrt)), axis=[-1])

  # Complexity: O(nbk**2)
  log_prob = ((df - dimension - 1.) * half_log_det_x -
              0.5 * trace_scale_inv_x -
              self._log_normalization(df=df, scale=self._scale))

  # Set shape hints.
  # Try to merge what we know from the input x with what we know from the
  # parameters of this distribution.
  if tensorshape_util.rank(
      x.shape) is not None and tensorshape_util.rank(
          self.batch_shape) is not None:
    tensorshape_util.set_shape(
        log_prob, tf.broadcast_static_shape(x.shape[:-2], self.batch_shape))

  return log_prob
def target_log_prob_fn(*args):
  """Unnormalized log density of `pinned_model` in the transformed space.

  Maps `args` through `bijector.inverse`, evaluates
  `pinned_model.unnormalized_log_prob` there (pinning its static shape to
  `lp_static_shape`), and adds the inverse log-det-Jacobian with one event
  dimension per entry of `initial_transformed_position`.
  """
  untransformed = bijector.inverse(args)
  log_density = pinned_model.unnormalized_log_prob(untransformed)
  tensorshape_util.set_shape(log_density, lp_static_shape)
  event_ndims = [1 for _ in initial_transformed_position]
  jacobian_correction = bijector.inverse_log_det_jacobian(
      args, event_ndims=event_ndims)
  return log_density + jacobian_correction