def _z(self, x, scale=None):
  """Standardize input `x` to a unit normal.

  Args:
    x: Value(s) to standardize.
    scale: Optional override for the scale. When `None`, `self.scale` is
      used instead.

  Returns:
    `(x - self.loc) / scale`, computed under the `'standardize'` name scope.
  """
  with tf.name_scope('standardize'):
    centered = x - self.loc
    if scale is None:
      divisor = self.scale
    else:
      divisor = scale
    return centered / divisor
def _kl_dirichlet_dirichlet(d1, d2, name=None):
  """Batchwise KL divergence KL(d1 || d2) with d1 and d2 Dirichlet.

  Args:
    d1: instance of a Dirichlet distribution object.
    d2: instance of a Dirichlet distribution object.
    name: Python `str` name to use for created operations.
      Default value: `None` (i.e., `'kl_dirichlet_dirichlet'`).

  Returns:
    kl_div: Batchwise KL(d1 || d2)
  """
  with tf.name_scope(name or 'kl_dirichlet_dirichlet'):
    # Derivation of the KL between two Dirichlet distributions.
    #
    # The density is
    #
    #   Dir(x; a) = 1 / B(a) * prod_i[x[i]^(a[i] - 1)]
    #
    # with B(a) the multivariate Beta function:
    #
    #   B(a) = Gamma(a[1]) * ... * Gamma(a[n]) / Gamma(a[1] + ... + a[n])
    #
    # The KL is
    #
    #   KL(Dir(x; a), Dir(x; b)) = E_Dir(x; a){log(Dir(x; a) / Dir(x; b))}
    #
    # which requires the log density:
    #
    #   log(Dir(x; a)) = sum_i[(a[i] - 1) log(x[i])] - log B(a)
    #
    # Only log(x[i]) depends on x, so we need its expectation under
    # Dir(x; a). Two exponential-family facts give it to us:
    #   1. log(x[i]) is a sufficient statistic.
    #   2. Expected sufficient statistics (of any exp family distribution)
    #      equal the derivatives of the log normalizer with respect to the
    #      corresponding natural parameters: E{T[i](x)} = dA/d(eta[i]).
    #
    # In exponential family form,
    #
    #   Dir(x; a) = exp{eta(a) . T(x) - A(a)}
    #
    # where '.' is the dot product of vectors eta and T, and A is a scalar:
    #
    #   eta[i](a) = a[i] - 1
    #     T[i](x) = log(x[i])
    #        A(a) = log B(a)
    #
    # Fact (2) then yields
    #
    #   E_Dir(x; a)[log(x[i])]
    #       = dA(a) / da[i]
    #       = d/da[i] log B(a)
    #       = d/da[i] ((sum_j lgamma(a[j])) - lgamma(sum_j a[j]))
    #       = digamma(a[i]) - digamma(sum_j a[j])
    #
    # Putting it all together:
    #
    #   KL[Dir(x; a) || Dir(x; b)]
    #       = E_Dir(x; a){log(Dir(x; a) / Dir(x; b))}
    #       = E_Dir(x; a){sum_i[(a[i] - b[i]) log(x[i])]} - (lbeta(a) - lbeta(b))
    #       = sum_i[(a[i] - b[i]) * E_Dir(x; a){log(x[i])}] - lbeta(a) + lbeta(b)
    #       = sum_i[(a[i] - b[i]) * (digamma(a[i]) - digamma(sum_j a[j]))]
    #            - lbeta(a) + lbeta(b)
    alpha = tf.convert_to_tensor(d1.concentration)
    beta = tf.convert_to_tensor(d2.concentration)

    # digamma(sum_j a[j]); `keepdims` so it broadcasts against digamma(a).
    alpha_total = tf.reduce_sum(alpha, axis=-1, keepdims=True)
    digamma_of_total = tf.math.digamma(alpha_total)

    # E_Dir(x; a)[log(x[i])] for every component i.
    expected_log_x = tf.math.digamma(alpha) - digamma_of_total

    weighted_expectation = tf.reduce_sum(
        (alpha - beta) * expected_log_x, axis=-1)
    return (weighted_expectation
            - tf.math.lbeta(alpha)
            + tf.math.lbeta(beta))
def __init__(self,
             loc=None,
             scale=None,
             validate_args=False,
             allow_nan_stats=True,
             name='MultivariateNormalLinearOperator'):
  """Construct Multivariate Normal distribution on `R^k`.

  The `batch_shape` is the broadcast shape between `loc` and `scale`
  arguments.

  The `event_shape` is given by last dimension of the matrix implied by
  `scale`. The last dimension of `loc` (if provided) must broadcast with this.

  Recall that `covariance = scale @ scale.T`.

  Additional leading dimensions (if any) will index batches.

  Args:
    loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
      implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
      `b >= 0` and `k` is the event size.
    scale: Instance of `LinearOperator` with same `dtype` as `loc` and shape
      `[B1, ..., Bb, k, k]`.
    validate_args: Python `bool`, default `False`. Whether to validate input
      with asserts. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
    allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
      exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading to
      undefined statistics will return NaN for this statistic.
    name: The name to give Ops created by the initializer.

  Raises:
    ValueError: if `scale` is unspecified.
    TypeError: if not `scale.dtype.is_floating`
  """
  # Must stay the first statement: it snapshots exactly the constructor
  # arguments (plus `self`) before any other local is introduced.
  parameters = dict(locals())

  if scale is None:
    raise ValueError('Missing required `scale` parameter.')
  if not dtype_util.is_floating(scale.dtype):
    raise TypeError('`scale` parameter must have floating-point dtype.')

  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([loc, scale], dtype_hint=tf.float32)
    # Since expand_dims doesn't preserve constant-ness, we obtain the
    # non-dynamic value if possible.
    loc = tensor_util.convert_nonref_to_tensor(loc, dtype=dtype, name='loc')
    batch_shape, event_shape = (
        distribution_util.shapes_from_loc_and_scale(loc, scale))

    # The distribution is a scalar standard normal pushed through the affine
    # map `x -> scale @ x + loc`.
    standard_normal = normal.Normal(
        loc=tf.zeros([], dtype=dtype),
        scale=tf.ones([], dtype=dtype))
    affine_map = affine_linear_operator_bijector.AffineLinearOperator(
        shift=loc, scale=scale, validate_args=validate_args)

    super(MultivariateNormalLinearOperator, self).__init__(
        distribution=standard_normal,
        bijector=affine_map,
        batch_shape=batch_shape,
        event_shape=event_shape,
        validate_args=validate_args,
        name=name)
    # Overwrite whatever the parent recorded with this constructor's own
    # arguments.
    self._parameters = parameters
def lu_reconstruct(lower_upper, perm, validate_args=False, name=None):
  """The inverse LU decomposition, `X == lu_reconstruct(*tf.linalg.lu(X))`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_reconstruct').

  Returns:
    x: The original input to `tf.linalg.lu`, i.e., `x` as in,
      `lu_reconstruct(*tf.linalg.lu(x))`.

  #### Examples

  ```python
  import numpy as np
  from tensorflow_probability.python.internal.backend import jax as tf
  import tensorflow_probability as tfp; tfp = tfp.experimental.substrates.jax

  x = [[[3., 4], [1, 2]],
       [[7., 8], [3, 4]]]
  x_reconstructed = tfp.math.lu_reconstruct(*tf.linalg.lu(x))
  tf.assert_near(x, x_reconstructed)
  # ==> True
  ```

  """
  with tf.name_scope(name or 'lu_reconstruct'):
    lower_upper = tf.convert_to_tensor(
        lower_upper, dtype_hint=tf.float32, name='lower_upper')
    perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')

    checks = lu_reconstruct_assertions(lower_upper, perm, validate_args)
    if checks:
      # Thread both inputs through identities gated on the assertions.
      with tf.control_dependencies(checks):
        lower_upper = tf.identity(lower_upper)
        perm = tf.identity(perm)

    lu_shape = tf.shape(lower_upper)

    # L: lower triangle of the packed factor with its diagonal forced to 1.
    lower_with_diag = tf.linalg.band_part(
        lower_upper, num_lower=-1, num_upper=0)
    unit_diag = tf.ones(lu_shape[:-1], dtype=lower_upper.dtype)
    lower = tf.linalg.set_diag(lower_with_diag, unit_diag)
    # U: upper triangle, diagonal included.
    upper = tf.linalg.band_part(lower_upper, num_lower=0, num_upper=-1)

    x = tf.matmul(lower, upper)

    static_rank = tensorshape_util.rank(lower_upper.shape)
    if static_rank == 2:
      # Statically known to be a single matrix: permute its rows directly.
      x = tf.gather(x, tf.math.invert_permutation(perm))
    else:
      # We either don't know the batch rank or there are >0 batch dims.
      # Flatten all batch dims into one, permute rows per batch member via
      # (batch_index, row_index) pairs, then restore the original shape.
      flat_batch = tf.reduce_prod(lu_shape[:-2])
      dim = lu_shape[-1]
      x = tf.reshape(x, [flat_batch, dim, dim])
      perm = tf.reshape(perm, [flat_batch, dim])
      perm = tf.map_fn(tf.math.invert_permutation, perm)
      batch_ids = tf.broadcast_to(
          tf.range(flat_batch)[:, tf.newaxis], [flat_batch, dim])
      gather_indices = tf.stack([batch_ids, perm], axis=-1)
      x = tf.gather_nd(x, gather_indices)
      x = tf.reshape(x, lu_shape)

    x.set_shape(lower_upper.shape)
    return x
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Takes one step of the TransitionKernel.

  Runs one step of the inner kernel at the current inverse temperature, then
  proposes a new inverse temperature (`gamma * inverse_temperatures`, clipped
  to `[min_temp, 1e6]`) and accepts or rejects that proposal per chain.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
      `list` of `Tensor`s representing internal calculations made within the
      previous call to this function (or as returned by `bootstrap_results`).
    seed: Optional, a seed for reproducible sampling.

  Returns:
    next_state: The first state part returned by the inner kernel.
    kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
      `Tensor`s representing internal calculations made within this function.
      This includes replica states.

  Raises:
    NotImplementedError: if the inner kernel's accepted results are neither
      HMC nor random-walk Metropolis results.
  """
  with tf.name_scope(mcmc_util.make_name(self.name, 'tmc', 'one_step')):
    # Force a read in case the `inverse_temperatures` is a `tf.Variable`.
    inverse_temperatures = tf.convert_to_tensor(
        previous_kernel_results.post_tempering_inverse_temperatures,
        name='inverse_temperatures')

    steps_at_temperature = tf.convert_to_tensor(
        previous_kernel_results.steps_at_temperature,
        name='number of steps')

    target_score_for_inner_kernel = partial(
        self.target_score_fn, sigma=inverse_temperatures)
    target_log_prob_for_inner_kernel = partial(
        self.target_log_prob_fn, sigma=inverse_temperatures)

    try:
      inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
          target_log_prob_for_inner_kernel,
          target_score_for_inner_kernel,
          inverse_temperatures)
    except TypeError as e:
      # Only fall back when the failure looks like an arity mismatch, i.e. a
      # legacy `make_kernel_fn` that still expects a trailing seed.
      if 'argument' not in str(e):
        raise
      warnings.warn(
          'The `seed` argument to `ReplicaExchangeMC`s `make_kernel_fn` is '
          'deprecated. `TransitionKernel` instances now receive seeds via '
          '`one_step`.')
      inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
          target_log_prob_for_inner_kernel,
          target_score_for_inner_kernel,
          inverse_temperatures,
          self._seed_stream())

    if seed is not None:
      seed = samplers.sanitize_seed(seed)
      # The middle seed is not consumed below; the three-way split is kept so
      # the derived seeds are unchanged.
      inner_seed, _, logu_seed = samplers.split_seed(
          seed, n=3, salt='tmc_one_step')
      inner_kwargs = dict(seed=inner_seed)
    else:
      if self._seed_stream.original_seed is not None:
        warnings.warn(mcmc_util.SEED_CTOR_ARG_DEPRECATION_MSG)
      inner_kwargs = {}
      _, logu_seed = samplers.split_seed(self._seed_stream())

    if mcmc_util.is_list_like(current_state):
      # We *always* canonicalize the states in the kernel results.
      states = current_state
    else:
      states = [current_state]

    [
        new_state,
        pre_tempering_results,
    ] = inner_kernel.one_step(
        states,
        previous_kernel_results.post_tempering_results,
        **inner_kwargs)

    # Now that we have run one step, we consider maybe lowering the
    # temperature.
    # Proposed new temperature.
    proposed_inverse_temperatures = tf.clip_by_value(
        self.gamma * inverse_temperatures, self.min_temp, 1e6)
    dtype = inverse_temperatures.dtype

    # We will lower the temperature if this new proposed step is compatible
    # with a temperature swap. Both log-density differences are obtained by
    # integrating the score along the straight segment between the old and
    # new state.
    score_fn = self._parameters['target_score_fn']
    num_delta_logp_steps = self._parameters['num_delta_logp_steps']

    # Forward direction: old state -> new state, at the current temperature.
    forward_v = new_state[0] - states[0]
    forward_start = states[0]

    @jax.vmap
    def forward_integrand(t):
      return jnp.sum(
          score_fn(t * forward_v + forward_start, inverse_temperatures)
          * forward_v,
          axis=-1)

    delta_logp1 = simps(forward_integrand, 0., 1., num_delta_logp_steps)

    # Now we compute the reverse: new state -> old state, at the proposed
    # temperature.
    reverse_v = -forward_v
    reverse_start = new_state[0]

    @jax.vmap
    def reverse_integrand(t):
      return jnp.sum(
          score_fn(t * reverse_v + reverse_start,
                   proposed_inverse_temperatures)
          * reverse_v,
          axis=-1)

    delta_logp2 = simps(reverse_integrand, 0., 1., num_delta_logp_steps)

    log_accept_ratio = (delta_logp1 + delta_logp2)

    # Non-finite ratios are treated as certain rejection.
    log_accept_ratio = tf.where(
        tf.math.is_finite(log_accept_ratio),
        log_accept_ratio, tf.constant(-np.inf, dtype=dtype))

    # Produce Log[Uniform] draws that are identical at swapped indices.
    log_uniform = tf.math.log(
        samplers.uniform(shape=log_accept_ratio.shape,
                         dtype=dtype,
                         seed=logu_seed))

    is_tempering_accepted_mask = tf.less(
        log_uniform,
        log_accept_ratio,
        name='is_tempering_accepted_mask')

    is_min_steps_satisfied = tf.greater(
        steps_at_temperature,
        self.min_steps_per_temp * tf.ones_like(steps_at_temperature),
        name='is_min_steps_satisfied')

    # Only propose tempering if the chain was going to accept this point
    # anyway.
    is_tempering_accepted_mask = tf.math.logical_and(
        is_tempering_accepted_mask,
        pre_tempering_results.is_accepted)

    is_tempering_accepted_mask = tf.math.logical_and(
        is_tempering_accepted_mask,
        is_min_steps_satisfied)

    # Updating accepted inverse temperatures.
    post_tempering_inverse_temperatures = mcmc_util.choose(
        is_tempering_accepted_mask,
        proposed_inverse_temperatures, inverse_temperatures)

    # The step counter resets wherever the temperature changed.
    steps_at_temperature = mcmc_util.choose(
        is_tempering_accepted_mask,
        tf.zeros_like(steps_at_temperature), steps_at_temperature + 1)

    # Invalidating and recomputing results: the cached log-prob (and
    # gradients) were computed at the pre-tempering temperature.
    [
        new_target_log_prob,
        new_grads_target_log_prob,
    ] = mcmc_util.maybe_call_fn_and_grads(
        partial(self.target_log_prob_fn,
                sigma=post_tempering_inverse_temperatures),
        new_state)

    # Updating inner kernel results.
    post_tempering_results = pre_tempering_results._replace(
        proposed_results=tf.convert_to_tensor(np.nan, dtype=dtype),
        proposed_state=tf.convert_to_tensor(np.nan, dtype=dtype),
    )

    if isinstance(post_tempering_results.accepted_results,
                  hmc.UncalibratedHamiltonianMonteCarloKernelResults):
      post_tempering_results = post_tempering_results._replace(
          accepted_results=post_tempering_results.accepted_results._replace(
              target_log_prob=new_target_log_prob,
              grads_target_log_prob=new_grads_target_log_prob))
    elif isinstance(post_tempering_results.accepted_results,
                    random_walk_metropolis.UncalibratedRandomWalkResults):
      post_tempering_results = post_tempering_results._replace(
          accepted_results=post_tempering_results.accepted_results._replace(
              target_log_prob=new_target_log_prob))
    else:
      # TODO(b/143702650) Handle other kernels.
      raise NotImplementedError(
          'Only HMC and RWMH Kernels are handled at this time. Please file a '
          'request with the TensorFlow Probability team.')

    new_kernel_results = TemperedMCKernelResults(
        pre_tempering_results=pre_tempering_results,
        post_tempering_results=post_tempering_results,
        pre_tempering_inverse_temperatures=inverse_temperatures,
        post_tempering_inverse_temperatures=(
            post_tempering_inverse_temperatures),
        tempering_log_accept_ratio=log_accept_ratio,
        steps_at_temperature=steps_at_temperature,
        seed=samplers.zeros_seed() if seed is None else seed,
    )

    # NOTE(review): only the first state part is returned, even when
    # `current_state` was passed as a list - confirm this is intended.
    return new_state[0], new_kernel_results
def pivoted_cholesky(matrix, max_rank, diag_rtol=1e-3, name=None):
  """Computes the (partial) pivoted cholesky decomposition of `matrix`.

  The pivoted Cholesky is a low rank approximation of the Cholesky
  decomposition of `matrix`, i.e. as described in [(Harbrecht et al., 2012)][1].
  The currently-worst-approximated diagonal element is selected as the pivot
  at each iteration. This yields from a `[B1...Bn, N, N]` shaped `matrix` a
  `[B1...Bn, N, K]` shaped rank-`K` approximation `lr` such that
  `lr @ lr.T ~= matrix`. Note that, unlike the Cholesky decomposition, `lr` is
  not triangular even in a rectangular-matrix sense. However, under a
  permutation it could be made triangular (it has one more zero in each column
  as you move to the right).

  Such a matrix can be useful as a preconditioner for conjugate gradient
  optimization, i.e. as in [(Wang et al. 2019)][2], as matmuls and solves can
  be cheaply done via the Woodbury matrix identity, as implemented by
  `tf.linalg.LinearOperatorLowRankUpdate`.

  Args:
    matrix: Floating point `Tensor` batch of symmetric, positive definite
      matrices. Its rank must be known statically.
    max_rank: Scalar `int` `Tensor`, the rank at which to truncate the
      approximation. It is capped at the size of the last dimension of
      `matrix`.
    diag_rtol: Scalar floating point `Tensor` (same dtype as `matrix`). If the
      errors of all diagonal elements of `lr @ lr.T` are each lower than
      `element * diag_rtol`, iteration is permitted to terminate early.
    name: Optional name for the op.

  Returns:
    lr: Low rank pivoted Cholesky approximation of `matrix`.

  Raises:
    NotImplementedError: if the rank of `matrix` is not known statically.

  #### References

  [1]: H Harbrecht, M Peters, R Schneider. On the low-rank approximation by the
       pivoted Cholesky decomposition. _Applied numerical mathematics_,
       62(4):428-440, 2012.

  [2]: K. A. Wang et al. Exact Gaussian Processes on a Million Data Points.
       _arXiv preprint arXiv:1903.08114_, 2019. https://arxiv.org/abs/1903.08114
  """
  with tf.name_scope(name or 'pivoted_cholesky'):
    dtype = dtype_util.common_dtype([matrix, diag_rtol], dtype_hint=tf.float32)
    matrix = tf.convert_to_tensor(matrix, name='matrix', dtype=dtype)
    # The number of batch dims (used by `batch_gather` below) is derived from
    # the static rank, so it has to be available.
    if tensorshape_util.rank(matrix.shape) is None:
      raise NotImplementedError('Rank of `matrix` must be known statically')

    max_rank = tf.convert_to_tensor(
        max_rank, name='max_rank', dtype=tf.int64)
    # Never ask for more columns than the matrix has.
    max_rank = tf.minimum(max_rank,
                          prefer_static.shape(matrix, out_type=tf.int64)[-1])
    diag_rtol = tf.convert_to_tensor(
        diag_rtol, dtype=dtype, name='diag_rtol')
    matrix_diag = tf.linalg.diag_part(matrix)
    # matrix is P.D., therefore all matrix_diag > 0, so we don't need abs.
    orig_error = tf.reduce_max(matrix_diag, axis=-1)

    def cond(m, pchol, perm, matrix_diag):
      """Condition for `tf.while_loop` continuation."""
      del pchol
      del perm
      # Remaining error is measured as the L1 norm of the residual diagonal,
      # relative to the largest diagonal entry of the original matrix; the
      # worst batch member decides.
      error = tf.linalg.norm(matrix_diag, ord=1, axis=-1)
      max_err = tf.reduce_max(error / orig_error)
      # Always run the first iteration (m == 0); afterwards continue while
      # below `max_rank` and the relative error still exceeds `diag_rtol`.
      return (m < max_rank) & (tf.equal(m, 0) | (max_err > diag_rtol))

    # Every dim except the trailing two of `matrix` is a batch dim.
    batch_dims = tensorshape_util.rank(matrix.shape) - 2

    def batch_gather(params, indices, axis=-1):
      # `tf.gather` with the leading `batch_dims` dims treated as batch dims.
      return tf.gather(params, indices, axis=axis, batch_dims=batch_dims)

    def body(m, pchol, perm, matrix_diag):
      """Body of a single `tf.while_loop` iteration."""
      # Here is roughly a numpy, non-batched version of what's going to happen.
      # (See also Algorithm 1 of Harbrecht et al.)
      # 1: maxi = np.argmax(matrix_diag[perm[m:]]) + m
      # 2: maxval = matrix_diag[perm][maxi]
      # 3: perm[m], perm[maxi] = perm[maxi], perm[m]
      # 4: row = matrix[perm[m]][perm[m + 1:]]
      # 5: row -= np.sum(pchol[:m][perm[m + 1:]] * pchol[:m][perm[m]], axis=-2)
      # 6: pivot = np.sqrt(maxval); row /= pivot
      # 7: row = np.concatenate([[[pivot]], row], -1)
      # 8: matrix_diag[perm[m:]] -= row**2
      # 9: pchol[m, perm[m:]] = row

      # Find the maximal position of the (remaining) permuted diagonal.
      # Steps 1, 2 above.
      permuted_diag = batch_gather(matrix_diag, perm[..., m:])
      maxi = tf.argmax(
          permuted_diag, axis=-1, output_type=tf.int64)[..., tf.newaxis]
      maxval = batch_gather(permuted_diag, maxi)
      # `maxi` was an index into the `m:` slice; shift it back to a position
      # within the full permutation.
      maxi = maxi + m
      maxval = maxval[..., 0]
      # Update perm: Swap perm[...,m] with perm[...,maxi]. Step 3 above.
      perm = _swap_m_with_i(perm, m, maxi)
      # Step 4.
      row = batch_gather(matrix, perm[..., m:m + 1], axis=-2)
      row = batch_gather(row, perm[..., m + 1:])
      # Step 5.
      prev_rows = pchol[..., :m, :]
      prev_rows_perm_m_onward = batch_gather(prev_rows, perm[..., m + 1:])
      prev_rows_pivot_col = batch_gather(prev_rows, perm[..., m:m + 1])
      row -= tf.reduce_sum(
          prev_rows_perm_m_onward * prev_rows_pivot_col,
          axis=-2)[..., tf.newaxis, :]
      # Step 6.
      pivot = tf.sqrt(maxval)[..., tf.newaxis, tf.newaxis]
      # Step 7.
      row = tf.concat([pivot, row / pivot], axis=-1)
      # TODO(b/130899118): Pad grad fails with int64 paddings.
      # Step 8.
      # Left-pad the last dim with `m` zeros so `row` lines up with the full
      # permuted index range; no other dim is padded.
      paddings = tf.concat([
          tf.zeros([prefer_static.rank(pchol) - 1, 2], dtype=tf.int32),
          [[tf.cast(m, tf.int32), 0]]], axis=0)
      diag_update = tf.pad(row**2, paddings=paddings)[..., 0, :]
      # Map from permuted positions back to original positions.
      reverse_perm = _invert_permutation(perm)
      matrix_diag -= batch_gather(diag_update, reverse_perm)
      # Step 9.
      row = tf.pad(row, paddings=paddings)
      # TODO(bjp): Defer the reverse permutation all-at-once at the end?
      row = batch_gather(row, reverse_perm)
      pchol_shape = pchol.shape
      # Replace row `m` of `pchol` by rebuilding it from slices.
      pchol = tf.concat([pchol[..., :m, :], row, pchol[..., m + 1:, :]],
                        axis=-2)
      # The concat of dynamic slices can lose static shape; restore it.
      tensorshape_util.set_shape(pchol, pchol_shape)
      return m + 1, pchol, perm, matrix_diag

    m = np.int64(0)
    # `pchol` is built row-by-row (one row per iteration) and transposed at
    # the end.
    pchol = tf.zeros_like(matrix[..., :max_rank, :])
    matrix_shape = prefer_static.shape(matrix, out_type=tf.int64)
    # Start from the identity permutation for every batch member.
    perm = tf.broadcast_to(
        prefer_static.range(matrix_shape[-1]), matrix_shape[:-1])
    _, pchol, _, _ = tf.while_loop(
        cond=cond, body=body, loop_vars=(m, pchol, perm, matrix_diag))
    pchol = tf.linalg.matrix_transpose(pchol)
    # The last (rank) dimension is left statically unknown.
    tensorshape_util.set_shape(
        pchol, tensorshape_util.concatenate(matrix_diag.shape, [None]))
    return pchol
def lu_solve(lower_upper, perm, rhs, validate_args=False, name=None):
  """Solves systems of linear eqns `A X = RHS`, given LU factorizations.

  Note: this function does not verify the implied matrix is actually
  invertible nor is this condition checked even when `validate_args=True`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
    rhs: Matrix-shaped float `Tensor` representing targets for which to
      solve; `A X = RHS`. To handle vector cases, use:
      `lu_solve(..., rhs[..., tf.newaxis])[..., 0]`.
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness. Note: this function does not verify the
      implied matrix is actually invertible, even when `validate_args=True`.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_solve').

  Returns:
    x: The `X` in `A @ X = RHS`.

  #### Examples

  ```python
  import numpy as np
  from tensorflow_probability.python.internal.backend import jax as tf
  import tensorflow_probability as tfp; tfp = tfp.experimental.substrates.jax

  x = [[[1., 2],
        [3, 4]],
       [[7, 8],
        [3, 4]]]
  inv_x = tfp.math.lu_solve(*tf.linalg.lu(x), rhs=tf.eye(2))
  tf.assert_near(tf.matrix_inverse(x), inv_x)
  # ==> True
  ```

  """
  with tf.name_scope(name or 'lu_solve'):
    lower_upper = tf.convert_to_tensor(
        lower_upper, dtype_hint=tf.float32, name='lower_upper')
    perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')
    rhs = tf.convert_to_tensor(
        rhs, dtype_hint=lower_upper.dtype, name='rhs')

    checks = _lu_solve_assertions(lower_upper, perm, rhs, validate_args)
    if checks:
      # Gate all three inputs on the assertions.
      with tf.control_dependencies(checks):
        lower_upper = tf.identity(lower_upper)
        perm = tf.identity(perm)
        rhs = tf.identity(rhs)

    if (tensorshape_util.rank(rhs.shape) == 2 and
        tensorshape_util.rank(perm.shape) == 1):
      # Both rhs and perm have scalar batch_shape.
      permuted_rhs = tf.gather(rhs, perm, axis=-2)
    else:
      # Either rhs or perm have non-scalar batch_shape or we can't determine
      # this information statically.
      dynamic_rhs_shape = tf.shape(rhs)
      batch_shape = tf.broadcast_dynamic_shape(
          dynamic_rhs_shape[:-2], tf.shape(perm)[:-1])
      d = dynamic_rhs_shape[-2]
      m = dynamic_rhs_shape[-1]
      full_rhs_shape = tf.concat([batch_shape, [d, m]], axis=0)

      # Tile out rhs, then collapse all batch dims into a single one.
      flat_rhs = tf.broadcast_to(rhs, full_rhs_shape)
      flat_rhs = tf.reshape(flat_rhs, [-1, d, m])

      # Tile out perm and add batch indices, so each row lookup becomes a
      # (batch_index, row_index) pair for `gather_nd`.
      flat_perm = tf.broadcast_to(perm, full_rhs_shape[:-1])
      flat_perm = tf.reshape(flat_perm, [-1, d])
      flat_batch_size = tf.reduce_prod(batch_shape)
      batch_ids = tf.broadcast_to(
          tf.range(flat_batch_size)[:, tf.newaxis], [flat_batch_size, d])
      gather_indices = tf.stack([batch_ids, flat_perm], axis=-1)

      permuted_rhs = tf.gather_nd(flat_rhs, gather_indices)
      permuted_rhs = tf.reshape(permuted_rhs, full_rhs_shape)

    # L: lower triangle of the packed factor with its diagonal forced to 1.
    unit_diag = tf.ones(tf.shape(lower_upper)[:-1], dtype=lower_upper.dtype)
    lower = tf.linalg.set_diag(
        tf.linalg.band_part(lower_upper, num_lower=-1, num_upper=0),
        unit_diag)

    triangular_solve = (
        linear_operator_util.matrix_triangular_solve_with_broadcast)
    # Forward substitution with L, then back substitution with U.
    forward = triangular_solve(lower, permuted_rhs)
    return triangular_solve(
        lower_upper,  # Only upper is accessed.
        forward,
        lower=False)
def reduce_weighted_logsumexp(logx,
                              w=None,
                              axis=None,
                              keep_dims=False,
                              return_sign=False,
                              name=None):
  """Computes `log(abs(sum(weight * exp(elements across tensor dimensions))))`.

  If all weights `w` are known to be positive, it is more efficient to directly
  use `reduce_logsumexp`, i.e., `tf.reduce_logsumexp(logx + tf.log(w))` is more
  efficient than `du.reduce_weighted_logsumexp(logx, w)`.

  Reduces `input_tensor` along the dimensions given in `axis`.
  Unless `keep_dims` is true, the rank of the tensor is reduced by 1 for each
  entry in `axis`. If `keep_dims` is true, the reduced dimensions
  are retained with length 1.

  If `axis` has no entries, all dimensions are reduced, and a
  tensor with a single element is returned.

  This function is more numerically stable than log(sum(w * exp(input))). It
  avoids overflows caused by taking the exp of large inputs and underflows
  caused by taking the log of small inputs.

  For example:

  ```python
  x = tf.constant([[0., 0, 0],
                   [0, 0, 0]])

  w = tf.constant([[-1., 1, 1],
                   [1, 1, 1]])

  du.reduce_weighted_logsumexp(x, w)
  # ==> log(-1*1 + 1*1 + 1*1 + 1*1 + 1*1 + 1*1) = log(4)

  du.reduce_weighted_logsumexp(x, w, axis=0)
  # ==> [log(-1+1), log(1+1), log(1+1)]

  du.reduce_weighted_logsumexp(x, w, axis=1)
  # ==> [log(-1+1+1), log(1+1+1)]

  du.reduce_weighted_logsumexp(x, w, axis=1, keep_dims=True)
  # ==> [[log(-1+1+1)], [log(1+1+1)]]

  du.reduce_weighted_logsumexp(x, w, axis=[0, 1])
  # ==> log(-1+5)
  ```

  Args:
    logx: The tensor to reduce. Should have numeric type.
    w: The weight tensor. Should have numeric type identical to `logx`.
    axis: The dimensions to reduce. If `None` (the default), reduces all
      dimensions. Must be in the range `[-rank(input_tensor),
      rank(input_tensor))`.
    keep_dims: If true, retains reduced dimensions with length 1.
    return_sign: If `True`, returns the sign of the result.
    name: A name for the operation (optional).

  Returns:
    lswe: The `log(abs(sum(weight * exp(x))))` reduced tensor.
    sign: (Optional) The sign of `sum(weight * exp(x))`.
  """
  with tf.name_scope(name or 'reduce_weighted_logsumexp'):
    logx = tf.convert_to_tensor(logx, name='logx')

    if w is None:
      # Unweighted case: plain logsumexp, and the sum is trivially positive.
      lswe = tf.reduce_logsumexp(logx, axis=axis, keepdims=keep_dims)
      if not return_sign:
        return lswe
      return lswe, tf.ones_like(lswe)

    w = tf.convert_to_tensor(w, dtype=logx.dtype, name='w')
    log_abs_terms = logx + tf.math.log(tf.abs(w))
    shift = tf.reduce_max(log_abs_terms, axis=axis, keepdims=True)
    # If the largest element is `-inf` or `inf` then we don't bother
    # subtracting off the max. We do this because otherwise we'd get
    # `inf - inf = NaN`. That this is ok follows from the fact that we're
    # actually free to subtract any value we like, so long as we add it back
    # after taking the `log(sum(...))`.
    shift = tf.where(
        tf.math.is_inf(shift),
        tf.zeros([], shift.dtype),
        shift)
    signed_scaled_terms = tf.sign(w) * tf.exp(log_abs_terms - shift)
    signed_scaled_sum = tf.reduce_sum(
        signed_scaled_terms, axis=axis, keepdims=keep_dims)
    if not keep_dims:
      # `shift` was reduced with keepdims=True so it could broadcast above;
      # drop those dims so it matches the reduced sum.
      shift = tf.squeeze(shift, axis)
    sgn = tf.sign(signed_scaled_sum)
    # `sgn * sum` is `abs(sum)`; add the shift back in log space.
    lswe = shift + tf.math.log(sgn * signed_scaled_sum)

    if not return_sign:
      return lswe
    return lswe, sgn
def soft_threshold(x, threshold, name=None):
  """Soft Thresholding operator.

  This operator is defined by the equations

  ```none
                                { x[i] - gamma,  x[i] >  gamma
  SoftThreshold(x, gamma)[i] =  { 0,             |x[i]| <= gamma
                                { x[i] + gamma,  x[i] < -gamma
  ```

  In the context of proximal gradient methods, we have

  ```none
  SoftThreshold(x, gamma) = prox_{gamma L1}(x)
  ```

  where `prox` is the proximity operator. Thus the soft thresholding operator
  is used in proximal gradient descent for optimizing a smooth function with
  (non-smooth) L1 regularization, as outlined below.

  The proximity operator is defined as:

  ```none
  prox_r(x) = argmin{ r(z) + 0.5 ||x - z||_2**2 : z },
  ```

  where `r` is a (weakly) convex function, not necessarily differentiable.
  Because the L2 norm is strictly convex, the above argmin is unique.

  One important application of the proximity operator is as follows. Let `L`
  be a convex and differentiable function with Lipschitz-continuous gradient.
  Let `R` be a convex lower semicontinuous function which is possibly
  nondifferentiable. Let `gamma` be an arbitrary positive real. Then

  ```none
  x_star = argmin{ L(x) + R(x) : x }
  ```

  if and only if the fixed-point equation is satisfied:

  ```none
  x_star = prox_{gamma R}(x_star - gamma grad L(x_star))
  ```

  Proximal gradient descent thus typically consists of choosing an initial
  value `x^{(0)}` and repeatedly applying the update

  ```none
  x^{(k+1)} = prox_{gamma^{(k)} R}(x^{(k)} - gamma^{(k)} grad L(x^{(k)}))
  ```

  where `gamma` is allowed to vary from iteration to iteration. Specializing
  to the case where `R(x) = ||x||_1`, we minimize `L(x) + ||x||_1` by
  repeatedly applying the update

  ```
  x^{(k+1)} = SoftThreshold(x - gamma grad L(x^{(k)}), gamma)
  ```

  (This idea can also be extended to second-order approximations, although the
  multivariate case does not have a known closed form like above.)

  Args:
    x: `float` `Tensor` representing the input to the SoftThreshold function.
    threshold: nonnegative scalar, `float` `Tensor` representing the radius of
      the interval on which each coordinate of SoftThreshold takes the value
      zero. Denoted `gamma` above.
    name: Python string indicating the name of the TensorFlow operation.
      Default value: `'soft_threshold'`.

  Returns:
    softthreshold: `float` `Tensor` with the same shape and dtype as `x`,
      representing the value of the SoftThreshold function.

  #### References

  [1]: Yu, Yao-Liang. The Proximity Operator.
       https://www.cs.cmu.edu/~suvrit/teach/yaoliang_proximity.pdf

  [2]: Wikipedia Contributors. Proximal gradient methods for learning.
       _Wikipedia, The Free Encyclopedia_, 2018.
       https://en.wikipedia.org/wiki/Proximal_gradient_methods_for_learning

  """
  # https://math.stackexchange.com/questions/471339/derivation-of-soft-thresholding-operator
  with tf.name_scope(name or 'soft_threshold'):
    x = tf.convert_to_tensor(x, name='x')
    threshold = tf.convert_to_tensor(
        threshold, dtype=x.dtype, name='threshold')
    # Shrink the magnitude by `threshold`, clamp at zero, and restore the
    # sign of the input.
    direction = tf.sign(x)
    shrunk_magnitude = tf.maximum(tf.abs(x) - threshold, 0.)
    return direction * shrunk_magnitude
def __init__(self,
             df,
             scale_operator,
             input_output_cholesky=False,
             validate_args=False,
             allow_nan_stats=True,
             name=None):
  """Construct Wishart distributions.

  Args:
    df: `float` or `double` tensor, the degrees of freedom of the
      distribution(s). `df` must be greater than or equal to `k`.
    scale_operator: `float` or `double` instance of `LinearOperator`.
    input_output_cholesky: Python `bool`. If `True`, functions whose input or
      output have the semantics of samples assume inputs are in Cholesky form
      and return outputs in Cholesky form. In particular, if this flag is
      `True`, input to `log_prob` is presumed of Cholesky form and output from
      `sample`, `mean`, and `mode` are of Cholesky form. Setting this
      argument to `True` is purely a computational optimization and does not
      change the underlying distribution; for instance, `mean` returns the
      Cholesky of the mean, not the mean of Cholesky factors. The `variance`
      and `stddev` methods are unaffected by this flag.
      Default value: `False` (i.e., input/output does not have Cholesky
      semantics).
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    TypeError: if scale is not floating-type
    TypeError: if scale.dtype != df.dtype
    ValueError: if scale_operator is not square
    ValueError: if df < k, where scale operator event shape is
      `(k, k)`
  """
  # Must stay the first statement: it snapshots exactly the constructor
  # arguments (plus `self`) before any other local is introduced.
  parameters = dict(locals())
  self._input_output_cholesky = input_output_cholesky
  with tf.name_scope(name) as name:
    with tf.name_scope("init"):
      if not dtype_util.is_floating(scale_operator.dtype):
        raise TypeError(
            "scale_operator.dtype=%s is not a floating-point type" %
            scale_operator.dtype)
      # Fail fast with the intended error; no tensor is evaluated or printed
      # on this path.
      if not scale_operator.is_square:
        raise ValueError("scale_operator must be square.")

      self._scale_operator = scale_operator
      self._df = tf.convert_to_tensor(
          df, dtype=scale_operator.dtype, name="df")
      dtype_util.assert_same_float_dtype([self._df, self._scale_operator])

      # Prefer the statically known dimension; fall back to the dynamic one.
      if tf.compat.dimension_value(self._scale_operator.shape[-1]) is None:
        self._dimension = tf.cast(
            self._scale_operator.domain_dimension_tensor(),
            dtype=self._scale_operator.dtype,
            name="dimension")
      else:
        self._dimension = tf.convert_to_tensor(
            tf.compat.dimension_value(self._scale_operator.shape[-1]),
            dtype=self._scale_operator.dtype,
            name="dimension")

      df_val = tf.get_static_value(self._df)
      dim_val = tf.get_static_value(self._dimension)
      if df_val is not None and dim_val is not None:
        # Both values are known statically: validate eagerly, regardless of
        # `validate_args`.
        df_val = np.asarray(df_val)
        if not df_val.shape:
          df_val = [df_val]
        if np.any(df_val < dim_val):
          raise ValueError(
              "Degrees of freedom (df = %s) cannot be less than "
              "dimension of scale matrix (scale.dimension = %s)" %
              (df_val, dim_val))
      elif validate_args:
        # Otherwise defer the check to run time. The format arguments follow
        # the order of the placeholders: df first, dimension second.
        assertions = assert_util.assert_less_equal(
            self._dimension,
            self._df,
            message=("Degrees of freedom (df = %s) cannot be "
                     "less than dimension of scale matrix "
                     "(scale.dimension = %s)" %
                     (self._df, self._dimension)))
        self._df = distribution_util.with_dependencies(
            [assertions], self._df)
    super(_WishartLinearOperator, self).__init__(
        dtype=self._scale_operator.dtype,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
        parameters=parameters,
        name=name)
def __init__(self,
             loc,
             scale,
             low,
             high,
             validate_args=False,
             allow_nan_stats=True,
             name="TruncatedNormal"):
  """Creates a normal distribution whose support is bounded by `low`/`high`.

  `loc`, `scale`, `low` and `high` are broadcast against one another; every
  stored parameter is expanded to that common broadcast shape, which is also
  kept in `_broadcast_batch_shape`.

  Args:
    loc: Floating point tensor; mean of the underlying normal
      distribution(s). The mean of the truncated distribution differs from
      this value because the bounds modify it.
    scale: Floating point tensor; standard deviation of the underlying
      normal distribution(s).
    low: `float` `Tensor`; lower bound of the distribution's support. Must
      satisfy `low < high`.
    high: `float` `Tensor`; upper bound of the distribution's support. Must
      satisfy `low < high`.
    validate_args: Python `bool`, default `False`. When `True`, distribution
      parameters are checked at run-time.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to signal an
      undefined result. When `False`, an exception is raised if one or more
      of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.
  """
  # Must stay the first statement so that only the constructor arguments are
  # captured.
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([loc, scale, low, high], tf.float32)
    loc, scale, low, high = [
        tf.convert_to_tensor(value, name=label, dtype=dtype)
        for label, value in (("loc", loc), ("scale", scale),
                             ("low", low), ("high", high))
    ]
    dtype_util.assert_same_float_dtype([loc, scale, low, high])

    self._broadcast_batch_shape = distribution_util.get_broadcast_shape(
        loc, scale, low, high)

    # Multiplying by a tensor of ones expands each parameter to the common
    # broadcast shape.
    ones = tf.ones(shape=self._broadcast_batch_shape, dtype=scale.dtype)
    self._scale = scale * ones
    self._loc = loc * ones
    self._low = low * ones
    self._high = high * ones

    # Attach the validation ops (if requested) to `_loc`.
    validation_deps = [self._validate()] if validate_args else []
    with tf.control_dependencies(validation_deps):
      self._loc = tf.identity(self._loc)

  # Reparameterization notes: the distribution is fully reparameterized.
  # `loc` and `scale` have straight-through gradients, while the gradients
  # for the bounds use custom expressions derived from implicit gradients.
  # For the special case of a zero lower bound and a positive upper bound an
  # equivalent expression appears in Sec 9.1.1 of
  # https://arxiv.org/pdf/1806.01851.pdf; the implementation here handles
  # arbitrary bounds.
  super(TruncatedNormal, self).__init__(
      name=name,
      dtype=dtype,
      parameters=parameters,
      validate_args=validate_args,
      allow_nan_stats=allow_nan_stats,
      reparameterization_type=reparameterization.FULLY_REPARAMETERIZED)
def __init__(self,
             df,
             scale=None,
             scale_tril=None,
             input_output_cholesky=False,
             validate_args=False,
             allow_nan_stats=True,
             name="Wishart"):
  """Builds Wishart distribution(s) from a scale matrix or its Cholesky.

  Args:
    df: `float` or `double` `Tensor`. Degrees of freedom; must be greater
      than or equal to the dimension of the scale matrix.
    scale: `float` or `double` `Tensor`. The symmetric positive definite
      scale matrix of the distribution. Exactly one of `scale` and
      `scale_tril` must be passed.
    scale_tril: `float` or `double` `Tensor`. The Cholesky factorization of
      the symmetric positive definite scale matrix of the distribution.
      Exactly one of `scale` and `scale_tril` must be passed.
    input_output_cholesky: Python `bool`. If `True`, functions whose input
      or output have the semantics of samples assume inputs are in Cholesky
      form and return outputs in Cholesky form. In particular, the input to
      `log_prob` is presumed to be of Cholesky form and the outputs of
      `sample`, `mean`, and `mode` are of Cholesky form. This is purely a
      computational optimization and does not change the underlying
      distribution; for instance, `mean` returns the Cholesky of the mean,
      not the mean of Cholesky factors. The `variance` and `stddev` methods
      are unaffected by this flag.
      Default value: `False` (i.e., input/output does not have Cholesky
      semantics).
    validate_args: Python `bool`, default `False`. When `True`, distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False`, invalid inputs may silently render
      incorrect outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to signal an
      undefined result. When `False`, an exception is raised if one or more
      of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: if zero or both of `scale` and `scale_tril` are passed in.
  """
  # Must stay the first statement so that only the constructor arguments are
  # captured.
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    with tf.name_scope("init"):
      exactly_one_given = (scale is None) != (scale_tril is None)
      if not exactly_one_given:
        raise ValueError("Must pass scale or scale_tril, but not both.")

      dtype = dtype_util.common_dtype([df, scale, scale_tril], tf.float32)
      df = tf.convert_to_tensor(df, name="df", dtype=dtype)

      # After the check above exactly one argument is set, so a missing
      # `scale_tril` means `scale` was supplied.
      if scale_tril is None:
        scale = tf.convert_to_tensor(scale, name="scale", dtype=dtype)
        if validate_args:
          scale = distribution_util.assert_symmetric(scale)
        scale_tril = tf.linalg.cholesky(scale)
      else:
        scale_tril = tf.convert_to_tensor(
            scale_tril, name="scale_tril", dtype=dtype)
        if validate_args:
          positive_diag = assert_util.assert_positive(
              tf.linalg.diag_part(scale_tril),
              message="scale_tril must be positive definite")
          square = assert_util.assert_equal(
              tf.shape(scale_tril)[-1],
              tf.shape(scale_tril)[-2],
              message="scale_tril must be square")
          scale_tril = distribution_util.with_dependencies(
              [positive_diag, square], scale_tril)

    # The parent class works on a linear operator wrapping the Cholesky
    # factor.
    tril_operator = tf.linalg.LinearOperatorLowerTriangular(
        tril=scale_tril,
        is_non_singular=True,
        is_positive_definite=True,
        is_square=True)
    super(Wishart, self).__init__(
        df=df,
        scale_operator=tril_operator,
        input_output_cholesky=input_output_cholesky,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=name)
  # Replace the parameters recorded by the parent constructor with the
  # arguments this constructor received.
  self._parameters = parameters
def __init__(self,
             loc=None,
             scale_diag=None,
             scale_identity_multiplier=None,
             validate_args=False,
             allow_nan_stats=True,
             name="VectorExponentialDiag"):
  """Builds a Vector Exponential distribution supported on a subset of `R^k`.

  The `batch_shape` is the broadcast shape between the `loc` and `scale`
  arguments. The `event_shape` is given by the last dimension of the matrix
  implied by `scale`; the last dimension of `loc` (if provided) must
  broadcast with it.

  Recall that `covariance = scale @ scale.T`.

  ```none
  scale = diag(scale_diag + scale_identity_multiplier * ones(k))
  ```

  where:

  * `scale_diag.shape = [k]`, and,
  * `scale_identity_multiplier.shape = []`.

  Additional leading dimensions (if any) index batches.

  When both `scale_diag` and `scale_identity_multiplier` are `None`, `scale`
  is the Identity matrix.

  Args:
    loc: Floating-point `Tensor`. `None` means `loc` is implicitly `0`. When
      specified, may have shape `[B1, ..., Bb, k]` where `b >= 0` and `k` is
      the event size.
    scale_diag: Non-zero, floating-point `Tensor` representing a diagonal
      matrix added to `scale`. May have shape `[B1, ..., Bb, k]`, `b >= 0`,
      and characterizes `b`-batches of `k x k` diagonal matrices added to
      `scale`. When both `scale_identity_multiplier` and `scale_diag` are
      `None` then `scale` is the `Identity`.
    scale_identity_multiplier: Non-zero, floating-point `Tensor`
      representing a scaled-identity-matrix added to `scale`. May have shape
      `[B1, ..., Bb]`, `b >= 0`, and characterizes `b`-batches of scaled
      `k x k` identity matrices added to `scale`. When both
      `scale_identity_multiplier` and `scale_diag` are `None` then `scale`
      is the `Identity`.
    validate_args: Python `bool`, default `False`. When `True`, distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False`, invalid inputs may silently render
      incorrect outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to signal an
      undefined result. When `False`, an exception is raised if one or more
      of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: if at most `scale_identity_multiplier` is specified.
  """
  # Must stay the first statement so that only the constructor arguments are
  # captured.
  parameters = dict(locals())
  with tf.name_scope(name) as name:
    with tf.name_scope("init"):
      # Validation is deliberately switched off while building the diagonal
      # scale: the returned LinearOperatorDiag has an assert_non_singular
      # method that is called by the Bijector.
      diag_scale_kwargs = dict(
          loc=loc,
          scale_diag=scale_diag,
          scale_identity_multiplier=scale_identity_multiplier,
          validate_args=False,
          assert_positive=False)
      scale = distribution_util.make_diag_scale(**diag_scale_kwargs)
  super(VectorExponentialDiag, self).__init__(
      name=name,
      loc=loc,
      scale=scale,
      allow_nan_stats=allow_nan_stats,
      validate_args=validate_args)
  # Replace the parameters recorded by the parent constructor with the
  # arguments this constructor received.
  self._parameters = parameters
def _z(self, x):
  """Maps `x` to the standard (unit) logistic: `(x - loc) / scale`."""
  with tf.name_scope('standardize'):
    centered = x - self.loc
    return centered / self.scale
def __init__(self,
             loc=None,
             scale_diag=None,
             scale_identity_multiplier=None,
             skewness=None,
             tailweight=None,
             distribution=None,
             validate_args=False,
             allow_nan_stats=True,
             name="VectorSinhArcsinhDiag"):
  """Construct VectorSinhArcsinhDiag distribution on `R^k`.

  The arguments `scale_diag` and `scale_identity_multiplier` combine to
  define the diagonal `scale` referred to in this class docstring:

  ```none
  scale = diag(scale_diag + scale_identity_multiplier * ones(k))
  ```

  The `batch_shape` is the broadcast shape between `loc` and `scale`
  arguments.

  The `event_shape` is given by last dimension of the matrix implied by
  `scale`. The last dimension of `loc` (if provided) must broadcast with
  this.

  Additional leading dimensions (if any) will index batches.

  Args:
    loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
      implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
      `b >= 0` and `k` is the event size.
    scale_diag: Non-zero, floating-point `Tensor` representing a diagonal
      matrix added to `scale`. May have shape `[B1, ..., Bb, k]`, `b >= 0`,
      and characterizes `b`-batches of `k x k` diagonal matrices added to
      `scale`. When both `scale_identity_multiplier` and `scale_diag` are
      `None` then `scale` is the `Identity`.
    scale_identity_multiplier: Non-zero, floating-point `Tensor` representing
      a scale-identity-matrix added to `scale`. May have shape
      `[B1, ..., Bb]`, `b >= 0`, and characterizes `b`-batches of scale
      `k x k` identity matrices added to `scale`. When both
      `scale_identity_multiplier` and `scale_diag` are `None` then `scale`
      is the `Identity`.
    skewness: Skewness parameter. floating-point `Tensor` with shape
      broadcastable with `event_shape`. `None` is replaced by `0.`.
    tailweight: Tailweight parameter. floating-point `Tensor` with shape
      broadcastable with `event_shape`. `None` is replaced by `1.`.
    distribution: `tf.Distribution`-like instance. Distribution from which `k`
      iid samples are used as input to transformation `F`. Default is
      `tfd.Normal(loc=0., scale=1.)`.
      Must be a scalar-batch, scalar-event distribution. Typically
      `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is
      a function of non-trainable parameters. WARNING: If you backprop through
      a VectorSinhArcsinhDiag sample and `distribution` is not
      `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then
      the gradient will be incorrect!
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`,
      statistics (e.g., mean, mode, variance) use the value "`NaN`" to
      indicate the result is undefined. When `False`, an exception is raised
      if one or more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: if at most `scale_identity_multiplier` is specified.
  """
  # Snapshot of the constructor arguments; taken before any local is rebound.
  parameters = dict(locals())

  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([
        loc, scale_diag, scale_identity_multiplier, skewness, tailweight
    ], tf.float32)
    # `loc` stays `None` when it was not supplied.
    loc = loc if loc is None else tf.convert_to_tensor(
        loc, name="loc", dtype=dtype)
    tailweight = 1. if tailweight is None else tailweight
    skewness = 0. if skewness is None else skewness

    # Recall, with Z a random variable,
    #   Y := loc + C * F(Z),
    #   F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight )
    #   F_0(Z) := Sinh( Arcsinh(Z) * tailweight )
    #   C := 2 * scale / F_0(2)

    # Construct shapes and 'scale' out of the scale_* and loc kwargs.
    # scale_linop is only an intermediary to:
    # 1. get shapes from looking at loc and the two scale args.
    # 2. combine scale_diag with scale_identity_multiplier, which gives us
    #    'scale', which in turn gives us 'C'.
    scale_linop = distribution_util.make_diag_scale(
        loc=loc,
        scale_diag=scale_diag,
        scale_identity_multiplier=scale_identity_multiplier,
        validate_args=False,
        assert_positive=False,
        dtype=dtype)
    batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
        loc, scale_linop)
    # scale_linop.diag_part() is efficient since it is a diag type linop.
    scale_diag_part = scale_linop.diag_part()
    # From here on `dtype` is the dtype of the realized diagonal rather than
    # the common dtype inferred from the raw arguments.
    dtype = scale_diag_part.dtype

    if distribution is None:
      # Default base distribution: a scalar standard normal.
      distribution = normal.Normal(
          loc=tf.zeros([], dtype=dtype),
          scale=tf.ones([], dtype=dtype),
          allow_nan_stats=allow_nan_stats)
    else:
      asserts = distribution_util.maybe_check_scalar_distribution(
          distribution, dtype, validate_args)
      if asserts:
        # Tie the checks to the diagonal that is fed to the Affine bijector
        # below.
        scale_diag_part = distribution_util.with_dependencies(
            asserts, scale_diag_part)

    # Make the SAS bijector, 'F'.
    skewness = tf.convert_to_tensor(skewness, dtype=dtype, name="skewness")
    tailweight = tf.convert_to_tensor(
        tailweight, dtype=dtype, name="tailweight")
    f = sinh_arcsinh_bijector.SinhArcsinh(
        skewness=skewness, tailweight=tailweight)
    affine = affine_bijector.Affine(
        shift=loc, scale_diag=scale_diag_part, validate_args=validate_args)

    # The chain holds `affine` first and `f` second.
    bijector = chain_bijector.Chain([affine, f])

    super(VectorSinhArcsinhDiag, self).__init__(
        distribution=distribution,
        bijector=bijector,
        batch_shape=batch_shape,
        event_shape=event_shape,
        validate_args=validate_args,
        name=name)
  # Replace the parameters recorded by the parent constructor and keep the
  # converted tensors / linear operator on the instance.
  self._parameters = parameters
  self._loc = loc
  self._scale = scale_linop
  self._tailweight = tailweight
  self._skewness = skewness
def __init__(self, validate_args=False, name="softmax_centered"):
  """Instantiates the `SoftmaxCentered` bijector.

  Args:
    validate_args: Python `bool` forwarded unchanged to the parent
      constructor.
    name: Python `str` used to open the name scope; the resulting scope name
      is what the parent constructor receives.
  """
  with tf.name_scope(name) as name:
    # The parent is told the minimal forward event rank is 1.
    parent_kwargs = dict(
        forward_min_event_ndims=1,
        validate_args=validate_args,
        name=name)
    super(SoftmaxCentered, self).__init__(**parent_kwargs)
def fill_triangular_inverse(x, upper=False, name=None):
  """Creates a vector from a (batch of) triangular matrix.

  The vector is created from the lower-triangular or upper-triangular portion
  depending on the value of the parameter `upper`.

  If `x.shape` is `[b1, b2, ..., bB, n, n]` then the output shape is
  `[b1, b2, ..., bB, d]` where `d = n (n + 1) / 2`.

  Example:

  ```python
  fill_triangular_inverse(
    [[4, 0, 0],
     [6, 5, 0],
     [3, 2, 1]])

  # ==> [1, 2, 3, 4, 5, 6]

  fill_triangular_inverse(
    [[1, 2, 3],
     [0, 5, 6],
     [0, 0, 4]], upper=True)

  # ==> [1, 2, 3, 4, 5, 6]
  ```

  Args:
    x: `Tensor` representing lower (or upper) triangular elements.
    upper: Python `bool` representing whether the input matrix `x` is read as
      upper triangular (`True`) or lower triangular (`False`, default).
    name: Python `str`. The name to give this op.

  Returns:
    flat_tril: (Batch of) vector-shaped `Tensor` representing vectorized lower
      (or upper) triangular elements from `x`.
  """

  with tf.name_scope(name or 'fill_triangular_inverse'):
    x = tf.convert_to_tensor(x, name='x')
    # Static size of the last dimension, or `None` when it is not known.
    n = tf.compat.dimension_value(
        tensorshape_util.with_rank_at_least(x.shape, 2)[-1])
    if n is not None:
      # Static path: `n` and the output length `m` are NumPy int32 scalars.
      n = np.int32(n)
      m = np.int32((n * (n + 1)) // 2)
      static_final_shape = tensorshape_util.concatenate(
          x.shape[:-2], [m])
    else:
      # Dynamic path: `n` and `m` are derived from `tf.shape`, and the last
      # static output dimension is left unknown.
      n = tf.shape(x)[-1]
      m = (n * (n + 1)) // 2
      static_final_shape = tensorshape_util.concatenate(
          tensorshape_util.with_rank_at_least(x.shape, 2)[:-2],
          [None])
    ndims = prefer_static.rank(x)
    if upper:
      # First row as-is, plus every row below it.
      initial_elements = x[..., 0, :]
      triangular_portion = x[..., 1:, :]
    else:
      # Last row reversed along its final axis (the slice has rank
      # `ndims - 1`), plus every row above it.
      initial_elements = tf.reverse(x[..., -1, :], axis=[ndims - 2])
      triangular_portion = x[..., :-1, :]
    # Reverse both trailing axes of the remaining rows and add the result
    # back onto them.
    # NOTE(review): this sum presumably relies on the entries outside the
    # selected triangle being zero - confirm against callers.
    rotated_triangular_portion = tf.reverse(
        tf.reverse(triangular_portion, axis=[ndims - 1]), axis=[ndims - 2])
    consolidated_matrix = triangular_portion + rotated_triangular_portion
    # Flatten the `(n - 1) x n` block of every batch member into one vector.
    end_sequence = tf.reshape(
        consolidated_matrix,
        tf.concat([tf.shape(x)[:-2], [n * (n - 1)]], axis=0))
    # Only the first `m - n` flattened entries are appended to the `n`
    # initial elements, giving `m` outputs in total.
    y = tf.concat([initial_elements, end_sequence[..., :m - n]], axis=-1)
    tensorshape_util.set_shape(y, static_final_shape)
    return y
def __init__(self, validate_args=False, name="identity"):
  """Instantiates the `Identity` bijector.

  Args:
    validate_args: Python `bool` forwarded unchanged to the parent
      constructor.
    name: Python `str` used to open the name scope; the resulting scope name
      is what the parent constructor receives.
  """
  with tf.name_scope(name) as name:
    # The parent is told the minimal forward event rank is 0 and that the
    # Jacobian is constant.
    parent_kwargs = dict(
        forward_min_event_ndims=0,
        is_constant_jacobian=True,
        validate_args=validate_args,
        name=name)
    super(Identity, self).__init__(**parent_kwargs)
def pinv(a, rcond=None, validate_args=False, name=None):
  """Returns the Moore-Penrose pseudo-inverse of a (batch of) matrix.

  The [generalized inverse of a matrix](
  https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse) is computed
  from the singular-value decomposition (SVD) of `a`, keeping all large
  singular values.

  The pseudo-inverse of a matrix `A` is defined as: 'the matrix that 'solves'
  [the least-squares problem] `A @ x = b`,' i.e., if `x_hat` is a solution,
  then `A_pinv` is the matrix such that `x_hat = A_pinv @ b`. It can be shown
  that if `U @ Sigma @ V.T = A` is the singular value decomposition of `A`,
  then `A_pinv = V @ inv(Sigma) @ U^T`. [(Strang, 1980)][1]

  This function is analogous to [`numpy.linalg.pinv`](
  https://docs.scipy.org/doc/numpy/reference/generated/numpy.linalg.pinv.html).
  It differs only in the default value of `rcond`. In `numpy.linalg.pinv`,
  the default `rcond` is `1e-15`. Here the default is
  `10. * max(num_rows, num_cols) * np.finfo(dtype).eps`.

  Args:
    a: (Batch of) `float`-like matrix-shaped `Tensor`(s) which are to be
      pseudo-inverted.
    rcond: `Tensor` of small singular value cutoffs. Singular values that do
      not exceed (in modulus) `rcond` * largest_singular_value (again, in
      modulus) are set to zero. Must broadcast against `tf.shape(a)[:-2]`.
      Default value: `10. * max(num_rows, num_cols) * np.finfo(a.dtype).eps`.
    validate_args: When `True`, additional assertions might be embedded in
      the graph.
      Default value: `False` (i.e., no graph assertions are added).
    name: Python `str` prefixed to ops created by this function.
      Default value: 'pinv'.

  Returns:
    a_pinv: The pseudo-inverse of input `a`. Has same shape as `a` except
      rightmost two dimensions are transposed.

  Raises:
    TypeError: if input `a` does not have `float`-like `dtype`.
    ValueError: if input `a` has fewer than 2 dimensions.

  #### Examples

  ```python
  from tensorflow_probability.python.internal.backend import jax as tf
  import tensorflow_probability as tfp; tfp = tfp.experimental.substrates.jax

  a = tf.constant([[1.,  0.4,  0.5],
                   [0.4, 0.2,  0.25],
                   [0.5, 0.25, 0.35]])
  tf.matmul(tfp.math.pinv(a), a)
  # ==> array([[1., 0., 0.],
               [0., 1., 0.],
               [0., 0., 1.]], dtype=float32)

  a = tf.constant([[1.,  0.4,  0.5,  1.],
                   [0.4, 0.2,  0.25, 2.],
                   [0.5, 0.25, 0.35, 3.]])
  tf.matmul(tfp.math.pinv(a), a)
  # ==> array([[ 0.76,  0.37,  0.21, -0.02],
               [ 0.37,  0.43, -0.33,  0.02],
               [ 0.21, -0.33,  0.81,  0.01],
               [-0.02,  0.02,  0.01,  1.  ]], dtype=float32)
  ```

  #### References

  [1]: G. Strang. 'Linear Algebra and Its Applications, 2nd Ed.' Academic
       Press, Inc., 1980, pp. 139-142.
  """
  with tf.name_scope(name or 'pinv'):
    a = tf.convert_to_tensor(a, name='a')

    checks = _maybe_validate_matrix(a, validate_args)
    if checks:
      with tf.control_dependencies(checks):
        a = tf.identity(a)

    np_dtype = dtype_util.as_numpy_dtype(a.dtype)

    if rcond is None:

      def static_or_dynamic_size(axis):
        """Returns the static size of `axis` if known, else a shape op."""
        static_size = tf.compat.dimension_value(a.shape[axis])
        if static_size is None:
          return tf.shape(a)[axis]
        return static_size

      num_rows = static_or_dynamic_size(-2)
      num_cols = static_or_dynamic_size(-1)
      if all(isinstance(size, int) for size in (num_rows, num_cols)):
        # Both sizes are static: compute the maximum in Python.
        largest_dim = float(max(num_rows, num_cols))
      else:
        largest_dim = tf.cast(tf.maximum(num_rows, num_cols), np_dtype)
      rcond = 10. * largest_dim * np.finfo(np_dtype).eps

    rcond = tf.convert_to_tensor(rcond, dtype=np_dtype, name='rcond')

    # Pseudo inverse via SVD: `sigma` holds the singular values, `u` the left
    # and `v` the right singular vectors.
    # Note: if a is symmetric then u == v. (Additional performance might be
    # observed by explicitly setting `v = u` in such cases.)
    sigma, u, v = tf.linalg.svd(a, full_matrices=False, compute_uv=True)

    # Saturate small singular values to inf. This has the effect of making
    # `1. / s = 0.` while not resulting in `NaN` gradients.
    threshold = rcond * tf.reduce_max(sigma, axis=-1)
    is_significant = sigma > threshold[..., tf.newaxis]
    sigma = tf.where(is_significant, sigma, np.array(np.inf, np_dtype))

    # Although `a == tf.matmul(u, s * v, transpose_b=True)`, `u` and `v` are
    # swapped here so that `tf.matmul(pinv(A), A) = tf.eye()`, i.e., a matrix
    # inverse has 'transposed' semantics.
    scaled_v = v / sigma[..., tf.newaxis, :]
    a_pinv = tf.matmul(scaled_v, u, adjoint_b=True)

    if tensorshape_util.rank(a.shape) is not None:
      transposed_shape = a.shape[:-2].concatenate(
          [a.shape[-1], a.shape[-2]])
      a_pinv.set_shape(transposed_shape)

    return a_pinv
def __init__(self,
             df,
             loc=None,
             scale_identity_multiplier=None,
             scale_diag=None,
             scale_tril=None,
             scale_perturb_factor=None,
             scale_perturb_diag=None,
             validate_args=False,
             allow_nan_stats=True,
             name="VectorStudentT"):
  """Instantiates the vector Student's t-distributions on `R^k`.

  The `batch_shape` is the broadcast between `df.batch_shape` and
  `Affine.batch_shape` where `Affine` is constructed from `loc` and
  `scale_*` arguments.

  The `event_shape` is the event shape of `Affine.event_shape`.

  Args:
    df: Floating-point `Tensor`. The degrees of freedom of the
      distribution(s). `df` must contain only positive values. Must be
      scalar if `loc`, `scale_*` imply non-scalar batch_shape or must have the
      same `batch_shape` implied by `loc`, `scale_*`.
    loc: Floating-point `Tensor`. If this is set to `None`, no `loc` is
      applied.
    scale_identity_multiplier: floating point rank 0 `Tensor` representing a
      scaling done to the identity matrix. When
      `scale_identity_multiplier = scale_diag = scale_tril = None` then
      `scale += IdentityMatrix`. Otherwise no scaled-identity-matrix is added
      to `scale`.
    scale_diag: Floating-point `Tensor` representing the diagonal matrix.
      `scale_diag` has shape [N1, N2, ..., k], which represents a k x k
      diagonal matrix. When `None` no diagonal term is added to `scale`.
    scale_tril: Floating-point `Tensor` representing the lower triangular
      matrix. `scale_tril` has shape [N1, N2, ..., k, k], which represents a
      k x k lower triangular matrix. When `None` no `scale_tril` term is
      added to `scale`. The upper triangular elements above the diagonal are
      ignored.
    scale_perturb_factor: Floating-point `Tensor` representing factor matrix
      with last two dimensions of shape `(k, r)`. When `None`, no rank-r
      update is added to `scale`.
    scale_perturb_diag: Floating-point `Tensor` representing the diagonal
      matrix. `scale_perturb_diag` has shape [N1, N2, ..., r], which
      represents an r x r Diagonal matrix. When `None` low rank updates will
      take the form `scale_perturb_factor * scale_perturb_factor.T`.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`,
      statistics (e.g., mean, mode, variance) use the value "`NaN`" to
      indicate the result is undefined. When `False`, an exception is raised
      if one or more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.
  """
  # Snapshot of the constructor arguments; taken before `args` exists so the
  # helper list below is not part of it.
  parameters = dict(locals())
  # Only used to infer a common dtype.
  args = [
      df, loc, scale_identity_multiplier, scale_diag, scale_tril,
      scale_perturb_factor, scale_perturb_diag
  ]
  with tf.name_scope(name) as name:
    with tf.name_scope("init"):
      dtype = dtype_util.common_dtype(args, tf.float32)
      df = tf.convert_to_tensor(df, name="df", dtype=dtype)
      # The shape of the _VectorStudentT distribution is governed by the
      # relationship between df.batch_shape and affine.batch_shape. In
      # pseudocode the basic procedure is:
      #   if df.batch_shape is scalar:
      #     if affine.batch_shape is not scalar:
      #       # broadcast distribution.sample so
      #       # it has affine.batch_shape.
      #     self.batch_shape = affine.batch_shape
      #   else:
      #     if affine.batch_shape is scalar:
      #       # let affine broadcasting do its thing.
      #     self.batch_shape = df.batch_shape
      # All of the above magic is actually handled by TransformedDistribution.
      # Here we really only need to collect the affine.batch_shape and decide
      # what we're going to pass in to TransformedDistribution's
      # (override) batch_shape arg.
      affine = affine_bijector.Affine(
          shift=loc,
          scale_identity_multiplier=scale_identity_multiplier,
          scale_diag=scale_diag,
          scale_tril=scale_tril,
          scale_perturb_factor=scale_perturb_factor,
          scale_perturb_diag=scale_perturb_diag,
          validate_args=validate_args,
          dtype=dtype)
      # Base distribution: a StudentT with scalar zero `loc` and unit
      # `scale`, in the dtype chosen by the affine bijector.
      distribution = student_t.StudentT(
          df=df,
          loc=tf.zeros([], dtype=affine.dtype),
          scale=tf.ones([], dtype=affine.dtype))
      batch_shape, override_event_shape = (
          distribution_util.shapes_from_loc_and_scale(
              affine.shift, affine.scale))
      # Pass the affine batch shape only when the base distribution is
      # scalar-batch; otherwise pass an empty int32 vector.
      override_batch_shape = distribution_util.pick_vector(
          distribution.is_scalar_batch(),
          batch_shape,
          tf.constant([], dtype=tf.int32))
      super(_VectorStudentT, self).__init__(
          distribution=distribution,
          bijector=affine,
          batch_shape=override_batch_shape,
          event_shape=override_event_shape,
          validate_args=validate_args,
          name=name)
      # Replace the parameters recorded by the parent constructor with the
      # arguments this constructor received.
      self._parameters = parameters
def lu_matrix_inverse(lower_upper, perm, validate_args=False, name=None):
  """Inverts a matrix that is given through its LU decomposition.

  This op is conceptually identical to,

  ```python
  inv_X = tf.lu_matrix_inverse(*tf.linalg.lu(X))
  tf.assert_near(tf.matrix_inverse(X), inv_X)
  # ==> True
  ```

  Note: this function does not verify the implied matrix is actually
  invertible nor is this condition checked even when `validate_args=True`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness. Note: this function does not verify the
      implied matrix is actually invertible, even when `validate_args=True`.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_matrix_inverse').

  Returns:
    inv_x: The matrix_inv, i.e.,
      `tf.matrix_inverse(tfp.math.lu_reconstruct(lu, perm))`.

  #### Examples

  ```python
  import numpy as np
  from tensorflow_probability.python.internal.backend import jax as tf
  import tensorflow_probability as tfp; tfp = tfp.experimental.substrates.jax

  x = [[[3., 4], [1, 2]],
       [[7., 8], [3, 4]]]
  inv_x = tfp.math.lu_matrix_inverse(*tf.linalg.lu(x))
  tf.assert_near(tf.matrix_inverse(x), inv_x)
  # ==> True
  ```
  """
  with tf.name_scope(name or 'lu_matrix_inverse'):
    lower_upper = tf.convert_to_tensor(
        lower_upper, dtype_hint=tf.float32, name='lower_upper')
    perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')

    checks = lu_reconstruct_assertions(lower_upper, perm, validate_args)
    if checks:
      with tf.control_dependencies(checks):
        lower_upper = tf.identity(lower_upper)
        perm = tf.identity(perm)

    # Solving against a (batch of) identity matrix yields the inverse.
    lu_shape = tf.shape(lower_upper)
    identity_rhs = tf.eye(
        lu_shape[-1], batch_shape=lu_shape[:-2], dtype=lower_upper.dtype)
    # The inputs were already checked above, so the solve skips validation.
    return lu_solve(lower_upper, perm, rhs=identity_rhs, validate_args=False)
def __init__(self,
             forward_fn=None,
             inverse_fn=None,
             inverse_log_det_jacobian_fn=None,
             forward_log_det_jacobian_fn=None,
             forward_event_shape_fn=None,
             forward_event_shape_tensor_fn=None,
             inverse_event_shape_fn=None,
             inverse_event_shape_tensor_fn=None,
             is_constant_jacobian=False,
             validate_args=False,
             forward_min_event_ndims=None,
             inverse_min_event_ndims=None,
             name='inline'):
  """Builds a `Bijector` out of user-supplied callables.

  At the minimum, one of `forward_min_event_ndims` or
  `inverse_min_event_ndims` must be supplied. To be fully functional, a
  typical bijector will also require `forward_fn`, `inverse_fn` and at least
  one of `inverse_log_det_jacobian_fn` or `forward_log_det_jacobian_fn`.

  Args:
    forward_fn: Python callable implementing the forward transformation.
    inverse_fn: Python callable implementing the inverse transformation.
    inverse_log_det_jacobian_fn: Python callable implementing the
      `log o det o jacobian` of the inverse transformation.
    forward_log_det_jacobian_fn: Python callable implementing the
      `log o det o jacobian` of the forward transformation.
    forward_event_shape_fn: Python callable implementing non-identical
      static event shape changes. Default: shape is assumed unchanged.
    forward_event_shape_tensor_fn: Python callable implementing non-identical
      event shape changes. Default: shape is assumed unchanged.
    inverse_event_shape_fn: Python callable implementing non-identical
      static event shape changes. Default: shape is assumed unchanged.
    inverse_event_shape_tensor_fn: Python callable implementing non-identical
      event shape changes. Default: shape is assumed unchanged.
    is_constant_jacobian: Python `bool` indicating that the Jacobian is
      constant for all input arguments.
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness.
    forward_min_event_ndims: Python `int` indicating the minimal
      dimensionality this bijector acts on.
    inverse_min_event_ndims: Python `int` indicating the minimal
      dimensionality this bijector acts on.
    name: Python `str`, name given to ops managed by this object.

  Raises:
    TypeError: If any of the non-`None` `*_fn` arguments are not callable.
  """
  with tf.name_scope(name) as name:
    # Each triple is passed positionally to `_maybe_implement`, in this
    # order.
    transformation_fns = (
        (forward_fn, '_forward', 'forward_fn'),
        (inverse_fn, '_inverse', 'inverse_fn'),
        (inverse_log_det_jacobian_fn, '_inverse_log_det_jacobian',
         'inverse_log_det_jacobian_fn'),
        (forward_log_det_jacobian_fn, '_forward_log_det_jacobian',
         'forward_log_det_jacobian_fn'),
    )
    for fn, method_name, arg_name in transformation_fns:
      self._maybe_implement(fn, method_name, arg_name)

    # By default assume shape doesn't change.
    forward_shape = _maybe_impute_as_identity(
        forward_event_shape_fn, 'forward_event_shape_fn')
    self._forward_event_shape = forward_shape
    forward_shape_tensor = _maybe_impute_as_identity(
        forward_event_shape_tensor_fn, 'forward_event_shape_tensor_fn')
    self._forward_event_shape_tensor = forward_shape_tensor
    inverse_shape = _maybe_impute_as_identity(
        inverse_event_shape_fn, 'inverse_event_shape_fn')
    self._inverse_event_shape = inverse_shape
    inverse_shape_tensor = _maybe_impute_as_identity(
        inverse_event_shape_tensor_fn, 'inverse_event_shape_tensor_fn')
    self._inverse_event_shape_tensor = inverse_shape_tensor

    super(Inline, self).__init__(
        name=name,
        validate_args=validate_args,
        is_constant_jacobian=is_constant_jacobian,
        forward_min_event_ndims=forward_min_event_ndims,
        inverse_min_event_ndims=inverse_min_event_ndims)
def fill_triangular(x, upper=False, name=None):
  """Packs a vector into a (batch of) triangular matrix.

  The result can be lower- or upper-triangular. (Building the requested
  orientation directly is more efficient than building one and transposing.)

  Elements are laid out in a clockwise spiral; see the example below.

  If `x.shape` is `[b1, b2, ..., bB, d]` the output shape is
  `[b1, b2, ..., bB, n, n]` where `n` satisfies `d = n(n+1)/2`, i.e.,
  `n = int(np.sqrt(0.25 + 2. * d) - 0.5)`.

  Example:

  ```python
  fill_triangular([1, 2, 3, 4, 5, 6])
  # ==> [[4, 0, 0],
  #      [6, 5, 0],
  #      [3, 2, 1]]

  fill_triangular([1, 2, 3, 4, 5, 6], upper=True)
  # ==> [[1, 2, 3],
  #      [0, 5, 6],
  #      [0, 0, 4]]
  ```

  The key trick is to build the triangle by concatenating `x` with a tail of
  itself and then reshaping.

  Suppose we are filling the upper triangle of an `n`-by-`n` matrix `M` from
  a vector `x`. `M` has `n**2` entries and `x` has `n * (n+1) / 2` entries.
  For concreteness take `n = 5` (so `x` has `15` entries and `M` has `25`).
  Concatenate `x` with a copy of `x` that has its first `n = 5` elements
  removed and is then reversed:

  ```python
  x = np.arange(15) + 1
  xc = np.concatenate([x, x[5:][::-1]])
  # ==> array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 15, 14, 13,
  #            12, 11, 10, 9, 8, 7, 6])

  # (One is added to the arange result so that the zeros below the diagonal
  # of the upper-triangular matrix are distinguishable from the first entry
  # of `x`.)

  # Now lay this out as a matrix:
  y = np.reshape(xc, [5, 5])
  # ==> array([[ 1,  2,  3,  4,  5],
  #            [ 6,  7,  8,  9, 10],
  #            [11, 12, 13, 14, 15],
  #            [15, 14, 13, 12, 11],
  #            [10,  9,  8,  7,  6]])

  # Finally, zero the elements below the diagonal:
  y = np.triu(y, k=0)
  # ==> array([[ 1,  2,  3,  4,  5],
  #            [ 0,  7,  8,  9, 10],
  #            [ 0,  0, 13, 14, 15],
  #            [ 0,  0,  0, 12, 11],
  #            [ 0,  0,  0,  0,  6]])
  ```

  The resulting matrix is upper-triangular and contains every entry of `x`,
  as desired. The rest is details:

  - If `n` is even, `x` does not fill a whole number of rows (it fills
    `n / 2` rows plus half of another one), but the scheme still works.
  - For a lower-triangular result the roles swap: the copy of `x` with its
    first `n` elements dropped comes first (unreversed), followed by all of
    `x` reversed.

  For additional comparisons, a pure numpy version of this function can be
  found in `distribution_util_test.py`, function `_fill_triangular`.

  Args:
    x: `Tensor` representing lower (or upper) triangular elements.
    upper: Python `bool` representing whether output matrix should be upper
      triangular (`True`) or lower triangular (`False`, default).
    name: Python `str`. The name to give this op.

  Returns:
    tril: `Tensor` with lower (or upper) triangular elements filled from `x`.

  Raises:
    ValueError: if `x` cannot be mapped to a triangular matrix.
  """
  with tf.name_scope(name or 'fill_triangular'):
    x = tf.convert_to_tensor(x, name='x')
    static_size = tf.compat.dimension_value(
        tensorshape_util.with_rank_at_least(x.shape, 1)[-1])

    if static_size is None:
      # The last dimension is only known at run time.
      dynamic_size = tf.shape(x)[-1]
      # Same formula as the static branch (n solves m = n(n+1)/2). The cast
      # to int lops off the 0.5, so it is omitted. `n` is not validated to be
      # an integer because that has graph-execution cost; an error will be
      # thrown from the reshape below instead.
      n = tf.cast(
          tf.sqrt(0.25 + tf.cast(2 * dynamic_size, dtype=tf.float32)),
          dtype=tf.int32)
      static_final_shape = tensorshape_util.concatenate(
          tensorshape_util.with_rank_at_least(x.shape, 1)[:-1],
          [None, None])
    else:
      # Formula derived by solving for n: m = n(n+1)/2.
      static_size = np.int32(static_size)
      n_exact = np.sqrt(0.25 + 2. * static_size) - 0.5
      if n_exact != np.floor(n_exact):
        raise ValueError(
            'Input right-most shape ({}) does not '
            'correspond to a triangular matrix.'.format(static_size))
      n = np.int32(n_exact)
      static_final_shape = tensorshape_util.concatenate(
          x.shape[:-1], [n, n])

    # Try it out in numpy:
    #  n = 3
    #  x = np.arange(n * (n + 1) / 2)
    #  m = x.shape[0]
    #  n = np.int32(np.sqrt(.25 + 2 * m) - .5)
    #  x_tail = x[(m - (n**2 - m)):]
    #  np.concatenate([x_tail, x[::-1]], 0).reshape(n, n)  # lower
    #  # ==> array([[3, 4, 5],
    #               [5, 4, 3],
    #               [2, 1, 0]])
    #  np.concatenate([x, x_tail[::-1]], 0).reshape(n, n)  # upper
    #  # ==> array([[0, 1, 2],
    #               [3, 4, 5],
    #               [5, 4, 3]])
    #
    # Note that we can't simply do `x[..., -(n**2 - m):]` because this doesn't
    # correctly handle `m == n == 1`. Hence, we do nonnegative indexing.
    # Furthermore observe that:
    #   m - (n**2 - m)
    #   = n**2 / 2 + n / 2 - (n**2 - n**2 / 2 + n / 2)
    #   = 2 (n**2 / 2 + n / 2) - n**2
    #   = n**2 + n - n**2
    #   = n
    ndims = prefer_static.rank(x)
    if upper:
      pieces = [x, tf.reverse(x[..., n:], axis=[ndims - 1])]
      keep_lower, keep_upper = 0, -1
    else:
      pieces = [x[..., n:], tf.reverse(x, axis=[ndims - 1])]
      keep_lower, keep_upper = -1, 0

    # Use a Python list for the target shape whenever it is fully known
    # statically; otherwise assemble it from the run-time shape.
    if tensorshape_util.is_fully_defined(static_final_shape):
      new_shape = tensorshape_util.as_list(static_final_shape)
    else:
      new_shape = tf.concat([tf.shape(x)[:-1], [n, n]], axis=0)

    x = tf.reshape(tf.concat(pieces, axis=-1), new_shape)
    # Zero out the half of the square that is not part of the triangle.
    x = tf.linalg.band_part(
        x, num_lower=keep_lower, num_upper=keep_upper)
    tensorshape_util.set_shape(x, static_final_shape)
    return x
def __init__(self,
             perm=None,
             rightmost_transposed_ndims=None,
             validate_args=False,
             name='transpose'):
  """Instantiates the `Transpose` bijector.

  Exactly one of `perm` and `rightmost_transposed_ndims` must be given; the
  other is derived from it.

  Args:
    perm: Non-negative `int32` vector-shaped `Tensor` representing the
      permutation of the rightmost dims (for the forward transformation).
      The `0`th index represents the first of the rightmost dims and the
      largest value must be `rightmost_transposed_ndims - 1`, which
      corresponds to `tf.rank(x) - 1`.
      Default value:
      `tf.range(start=rightmost_transposed_ndims - 1, limit=-1, delta=-1)`,
      i.e., the rightmost dims in reversed order.
    rightmost_transposed_ndims: Positive `int32` scalar-shaped `Tensor`
      representing the number of rightmost dimensions to permute.
      Default value: `tf.size(perm)`.
    validate_args: Python `bool` indicating whether arguments should be
      checked for correctness.
    name: Python `str` name given to ops managed by this object.

  Raises:
    ValueError: if both or neither `perm` and `rightmost_transposed_ndims`
      are specified.
    NotImplementedError: if `rightmost_transposed_ndims` is not known prior
      to graph execution.
  """
  with tf.name_scope(name) as name:
    perm_given = perm is not None
    ndims_given = rightmost_transposed_ndims is not None
    if perm_given == ndims_given:
      raise ValueError('Must specify exactly one of '
                       '`rightmost_transposed_ndims` and `perm`.')

    if perm_given:
      # Derive the number of transposed dims from the permutation itself.
      perm = tf.convert_to_tensor(perm, dtype_hint=np.int32, name='perm')
      rightmost_transposed_ndims = tf.size(
          perm, name='rightmost_transposed_ndims')
      static_ndims = tf.get_static_value(rightmost_transposed_ndims)
      checks = _maybe_validate_perm(perm, validate_args)
      if checks:
        with tf.control_dependencies(checks):
          perm = tf.identity(perm)
    else:
      # Only the count was given: default to reversing the rightmost dims.
      rightmost_transposed_ndims = tf.convert_to_tensor(
          rightmost_transposed_ndims,
          dtype_hint=np.int32,
          name='rightmost_transposed_ndims')
      static_ndims = tf.get_static_value(rightmost_transposed_ndims)
      checks = _maybe_validate_rightmost_transposed_ndims(
          rightmost_transposed_ndims, validate_args)
      if checks:
        with tf.control_dependencies(checks):
          rightmost_transposed_ndims = tf.identity(
              rightmost_transposed_ndims)
      first_index = distribution_util.prefer_static_value(
          rightmost_transposed_ndims) - 1
      perm = tf.range(start=first_index, limit=-1, delta=-1, name='perm')

    # TODO(b/110828604): If bijector base class ever supports dynamic
    # `min_event_ndims`, then this class already works dynamically and the
    # static-value requirement below can be removed.
    if static_ndims is None:
      raise NotImplementedError(
          '`rightmost_transposed_ndims` must be '
          'known prior to graph execution.')
    static_ndims = int(static_ndims)

    self._perm = perm
    self._rightmost_transposed_ndims = rightmost_transposed_ndims
    super(Transpose, self).__init__(
        forward_min_event_ndims=static_ndims,
        is_constant_jacobian=True,
        validate_args=validate_args,
        name=name)
def bootstrap_results(self, init_state):
  """Returns an object with the same type as returned by `one_step`.

  Args:
    init_state: `Tensor` or Python `list` of `Tensor`s representing the
      initial state(s) of the Markov chain(s).

  Returns:
    kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
      `Tensor`s representing internal calculations made within this
      function. This includes replica states.

  Raises:
    NotImplementedError: if the inner kernel's accepted results are neither
      HMC nor random-walk Metropolis results.
  """
  scope_name = mcmc_util.make_name(self.name, 'tmc', 'bootstrap_results')
  with tf.name_scope(scope_name):
    init_state, _ = mcmc_util.prepare_state_parts(init_state)
    inverse_temperatures = tf.convert_to_tensor(
        self.inverse_temperatures, name='inverse_temperatures')

    # Bind the current inverse temperatures (passed as `sigma`) into the
    # target functions handed to the inner kernel.
    tempered_score_fn = partial(
        self.target_score_fn, sigma=inverse_temperatures)
    tempered_log_prob_fn = partial(
        self.target_log_prob_fn, sigma=inverse_temperatures)
    kernel_args = (
        tempered_log_prob_fn, tempered_score_fn, inverse_temperatures)

    # Seed handling complexity is due to users possibly expecting an
    # old-style stateful seed to be passed to `self.make_kernel_fn`:
    #  - First try `make_kernel_fn` without a seed; this is the future. The
    #    kernel will receive a seed later, as part of `one_step`.
    #  - If the user code rejects that (Python complains about a missing
    #    required argument), fall back to the previous behavior and warn.
    try:
      inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
          *kernel_args)
    except TypeError as e:
      # Only swallow argument-count style errors; anything else is genuine.
      if 'argument' not in str(e):
        raise
      warnings.warn(
          'The second (`seed`) argument to `ReplicaExchangeMC`s '
          '`make_kernel_fn` is deprecated. `TransitionKernel` instances now '
          'receive seeds via `bootstrap_results` and `one_step`. This '
          'fallback may become an error 2020-09-20.')
      inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
          *kernel_args, self._seed_stream())

    inner_results = inner_kernel.bootstrap_results(init_state)
    post_tempering_results = inner_results

    # Invalidate the cached target values and recompute them at the current
    # temperatures.
    [
        new_target_log_prob,
        new_grads_target_log_prob,
    ] = mcmc_util.maybe_call_fn_and_grads(
        partial(self.target_log_prob_fn, sigma=inverse_temperatures),
        init_state)

    # The proposal fields are overwritten with NaN placeholders.
    dtype = inverse_temperatures.dtype
    post_tempering_results = post_tempering_results._replace(
        proposed_results=tf.convert_to_tensor(np.nan, dtype=dtype),
        proposed_state=tf.convert_to_tensor(np.nan, dtype=dtype),
    )

    # Refresh the accepted results with the recomputed values; which fields
    # exist depends on the inner kernel type.
    accepted = post_tempering_results.accepted_results
    if isinstance(accepted,
                  hmc.UncalibratedHamiltonianMonteCarloKernelResults):
      refreshed = accepted._replace(
          target_log_prob=new_target_log_prob,
          grads_target_log_prob=new_grads_target_log_prob)
    elif isinstance(accepted,
                    random_walk_metropolis.UncalibratedRandomWalkResults):
      refreshed = accepted._replace(target_log_prob=new_target_log_prob)
    else:
      # TODO(b/143702650) Handle other kernels.
      raise NotImplementedError(
          'Only HMC and RWMH Kernels are handled at this time. Please file a '
          'request with the TensorFlow Probability team.')
    post_tempering_results = post_tempering_results._replace(
        accepted_results=refreshed)

    return TemperedMCKernelResults(
        pre_tempering_results=inner_results,
        post_tempering_results=post_tempering_results,
        pre_tempering_inverse_temperatures=inverse_temperatures,
        post_tempering_inverse_temperatures=inverse_temperatures,
        tempering_log_accept_ratio=tf.zeros_like(inverse_temperatures),
        steps_at_temperature=tf.zeros_like(
            inverse_temperatures, dtype=tf.int32),
        seed=samplers.zeros_seed(),
    )
def __init__(self,
             loc=None,
             scale_tril=None,
             validate_args=False,
             allow_nan_stats=True,
             name='MultivariateNormalTriL'):
  """Construct Multivariate Normal distribution on `R^k`.

  The `batch_shape` is the broadcast shape between `loc` and `scale`
  arguments.

  The `event_shape` is given by last dimension of the matrix implied by
  `scale`. The last dimension of `loc` (if provided) must broadcast with
  this.

  Recall that `covariance = scale @ scale.T`. A (non-batch) `scale` matrix
  is:

  ```none
  scale = scale_tril
  ```

  where `scale_tril` is lower-triangular `k x k` matrix with non-zero
  diagonal, i.e., `tf.diag_part(scale_tril) != 0`.

  Additional leading dimensions (if any) will index batches.

  Args:
    loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
      implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]`
      where `b >= 0` and `k` is the event size.
    scale_tril: Floating-point, lower-triangular `Tensor` with non-zero
      diagonal elements. `scale_tril` has shape `[B1, ..., Bb, k, k]` where
      `b >= 0` and `k` is the event size. When `None`, an identity scale
      sized from the last dimension of `loc` is used.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading
      runtime performance. When `False` invalid inputs may silently render
      incorrect outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: if neither `loc` nor `scale_tril` are specified.
  """
  # Must stay the first statement so that only the constructor arguments
  # (and `self`) are captured.
  parameters = dict(locals())

  if loc is None and scale_tril is None:
    raise ValueError(
        'Must specify one or both of `loc`, `scale_tril`.')

  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype([loc, scale_tril], tf.float32)
    loc = tensor_util.convert_nonref_to_tensor(
        loc, name='loc', dtype=dtype)
    scale_tril = tensor_util.convert_nonref_to_tensor(
        scale_tril, name='scale_tril', dtype=dtype)

    if scale_tril is not None:
      # No need to validate that scale_tril is non-singular.
      # LinearOperatorLowerTriangular has an assert_non_singular
      # method that is called by the Bijector.
      scale = tf.linalg.LinearOperatorLowerTriangular(
          scale_tril,
          is_non_singular=True,
          is_self_adjoint=False,
          is_positive_definite=False)
    else:
      # Only `loc` was given: fall back to an identity scale whose size is
      # the last dimension of `loc`.
      scale = tf.linalg.LinearOperatorIdentity(
          num_rows=distribution_util.dimension_size(loc, -1),
          dtype=loc.dtype,
          is_self_adjoint=True,
          is_positive_definite=True,
          assert_proper_shapes=validate_args)

    super(MultivariateNormalTriL, self).__init__(
        loc=loc,
        scale=scale,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=name)
    self._parameters = parameters
def __init__(self,
             temperature,
             logits=None,
             probs=None,
             validate_args=False,
             allow_nan_stats=True,
             name='RelaxedBernoulli'):
  """Construct RelaxedBernoulli distributions.

  The distribution is assembled as a `Logistic` with location
  `logits / temperature` and scale `1 / temperature`, pushed through a
  `Sigmoid` bijector.

  Args:
    temperature: An 0-D `Tensor`, representing the temperature of a set of
      RelaxedBernoulli distributions. The temperature should be positive.
    logits: An N-D `Tensor` representing the log-odds of a positive event.
      Each entry in the `Tensor` parametrizes an independent
      RelaxedBernoulli distribution where the probability of an event is
      sigmoid(logits). Only one of `logits` or `probs` should be passed in.
    probs: An N-D `Tensor` representing the probability of a positive
      event. Each entry in the `Tensor` parameterizes an independent
      Bernoulli distribution. Only one of `logits` or `probs` should be
      passed in.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading
      runtime performance. When `False` invalid inputs may silently render
      incorrect outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: If both `probs` and `logits` are passed, or if neither.
  """
  # Must stay the first statement so that only the constructor arguments
  # (and `self`) are captured.
  parameters = dict(locals())

  with tf.name_scope(name) as name:
    dtype = dtype_util.common_dtype(
        [logits, probs, temperature], tf.float32)

    self._temperature = tf.convert_to_tensor(
        temperature, name='temperature', dtype=dtype)
    if validate_args:
      # Attach the positivity assertion to the stored temperature so it is
      # evaluated whenever the temperature is used.
      positivity_check = assert_util.assert_positive(temperature)
      with tf.control_dependencies([positivity_check]):
        self._temperature = tf.identity(self._temperature)

    self._logits, self._probs = distribution_util.get_logits_and_probs(
        logits=logits,
        probs=probs,
        validate_args=validate_args,
        dtype=dtype)

    base_distribution = logistic.Logistic(
        self._logits / self._temperature,
        1. / self._temperature,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=name + '/Logistic')
    squash = sigmoid_bijector.Sigmoid(validate_args=validate_args)

    super(RelaxedBernoulli, self).__init__(
        distribution=base_distribution,
        bijector=squash,
        validate_args=validate_args,
        name=name)
  self._parameters = parameters
def __init__(self,
             loc=None,
             covariance_matrix=None,
             validate_args=False,
             allow_nan_stats=True,
             name="MultivariateNormalFullCovariance"):
  """Construct Multivariate Normal distribution on `R^k`.

  The `batch_shape` is the broadcast shape between `loc` and
  `covariance_matrix` arguments.

  The `event_shape` is given by last dimension of the matrix implied by
  `covariance_matrix`. The last dimension of `loc` (if provided) must
  broadcast with this.

  A non-batch `covariance_matrix` matrix is a `k x k` symmetric positive
  definite matrix. In other words it is (real) symmetric with all
  eigenvalues strictly positive.

  Additional leading dimensions (if any) will index batches.

  Args:
    loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
      implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]`
      where `b >= 0` and `k` is the event size.
    covariance_matrix: Floating-point, symmetric positive definite `Tensor`
      of same `dtype` as `loc`. The strict upper triangle of
      `covariance_matrix` is ignored, so if `covariance_matrix` is not
      symmetric no error will be raised (unless `validate_args is True`).
      `covariance_matrix` has shape `[B1, ..., Bb, k, k]` where `b >= 0`
      and `k` is the event size.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading
      runtime performance. When `False` invalid inputs may silently render
      incorrect outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: if neither `loc` nor `covariance_matrix` are specified.
      (Not checked in this method's own body; presumably raised by the
      parent constructor -- confirm against the base class.)
  """
  # Must stay the first statement so that only the constructor arguments
  # (and `self`) are captured.
  parameters = dict(locals())

  # Convert the covariance_matrix up to a scale_tril and call MVNTriL.
  with tf.name_scope(name) as name:
    with tf.name_scope("init"):
      dtype = dtype_util.common_dtype([loc, covariance_matrix], tf.float32)
      if loc is not None:
        loc = tf.convert_to_tensor(loc, name="loc", dtype=dtype)

      scale_tril = None
      if covariance_matrix is not None:
        covariance_matrix = tf.convert_to_tensor(
            covariance_matrix, name="covariance_matrix", dtype=dtype)
        if validate_args:
          symmetry_check = assert_util.assert_near(
              covariance_matrix,
              tf.linalg.matrix_transpose(covariance_matrix),
              message="Matrix was not symmetric")
          covariance_matrix = distribution_util.with_dependencies(
              [symmetry_check], covariance_matrix)
        # No need to validate that covariance_matrix is non-singular.
        # LinearOperatorLowerTriangular has an assert_non_singular method
        # that is called by the Bijector.
        # However, cholesky() ignores the upper triangular part, so we do
        # need to separately assert symmetric.
        scale_tril = tf.linalg.cholesky(covariance_matrix)

    super(MultivariateNormalFullCovariance, self).__init__(
        loc=loc,
        scale_tril=scale_tril,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=name)
  self._parameters = parameters
def _kl_brute_force(a, b, name=None):
  """Batched KL divergence `KL(a || b)` for multivariate Normals.

  With `X`, `Y` both multivariate Normals in `R^k` with means `mu_a`, `mu_b`
  and covariance `C_a`, `C_b` respectively,

  ```
  KL(a || b) = 0.5 * ( L - k + T + Q ),
  L := Log[Det(C_b)] - Log[Det(C_a)]
  T := trace(C_b^{-1} C_a),
  Q := (mu_b - mu_a)^T C_b^{-1} (mu_b - mu_a),
  ```

  The code works with the `scale` operators (square roots of the
  covariances), so `0.5 * L` shows up as a difference of the scales'
  log-abs-determinants, and `T` and `Q` as squared Frobenius norms.

  This `Op` computes the trace by solving `C_b^{-1} C_a`. Although efficient
  methods for solving systems with `C_b` may be available, a dense version
  of (the square root of) `C_a` is used, so performance is `O(B s k**2)`
  where `B` is the batch size, and `s` is the cost of solving `C_b x = y`
  for vectors `x` and `y`.

  Args:
    a: Instance of `MultivariateNormalLinearOperator`.
    b: Instance of `MultivariateNormalLinearOperator`.
    name: (optional) name to use for created ops. Default "kl_mvn".

  Returns:
    Batchwise `KL(a || b)`.
  """

  def squared_frobenius_norm(x):
    """Sum of squared entries over the last two axes."""
    # http://mathworld.wolfram.com/FrobeniusNorm.html
    # The gradient of KL[p,q] is not defined when p==q. The culprit is
    # tf.norm, which is why the sum of squares is computed directly rather
    # than as tf.square(tf.norm(x, ord="fro", axis=[-2, -1])).
    return tf.reduce_sum(tf.square(x), axis=[-2, -1])

  # TODO(b/35041439): See also b/35040945. Remove this function once LinOp
  # supports something like:
  #   A.inverse().solve(B).norm(order='fro', axis=[-1, -2])
  diagonal_operator_types = (
      tf.linalg.LinearOperatorIdentity,
      tf.linalg.LinearOperatorScaledIdentity,
      tf.linalg.LinearOperatorDiag,
  )

  def is_diagonal(x):
    """Whether the `LinearOperator` has only a diagonal component."""
    return isinstance(x, diagonal_operator_types)

  with tf.name_scope(name or 'kl_mvn'):
    # Calculation is based on:
    # http://stats.stackexchange.com/questions/60680/kl-divergence-between-two-multivariate-gaussians
    # and,
    # https://en.wikipedia.org/wiki/Matrix_norm#Frobenius_norm
    # i.e.,
    #   If Ca = AA', Cb = BB', then
    #   tr[inv(Cb) Ca] = tr[inv(B)' inv(B) A A']
    #                  = tr[inv(B) A A' inv(B)']
    #                  = tr[(inv(B) A) (inv(B) A)']
    #                  = sum_{ij} (inv(B) A)_{ij}**2
    #                  = ||inv(B) A||_F**2
    # where ||.||_F is the Frobenius norm and the second equality follows
    # from the cyclic permutation property.
    both_diagonal = is_diagonal(a.scale) and is_diagonal(b.scale)
    if both_diagonal:
      # Using `stddev` because it handles expansion of Identity cases.
      b_inv_a = (a.stddev() / b.stddev())[..., tf.newaxis]
    else:
      b_inv_a = b.scale.solve(a.scale.to_dense())

    half_log_det_ratio = (
        b.scale.log_abs_determinant() - a.scale.log_abs_determinant())
    negative_k = -tf.cast(a.scale.domain_dimension_tensor(), a.dtype)
    trace_term = squared_frobenius_norm(b_inv_a)
    mean_delta = (b.mean() - a.mean())[..., tf.newaxis]
    quadratic_term = squared_frobenius_norm(b.scale.solve(mean_delta))

    kl_div = half_log_det_ratio + 0.5 * (
        negative_k + trace_term + quadratic_term)
    tensorshape_util.set_shape(
        kl_div, tf.broadcast_static_shape(a.batch_shape, b.batch_shape))
    return kl_div
def log_ndtr(x, series_order=3, name="log_ndtr"):
  """Log Normal distribution function.

  For details of the Normal distribution function see `ndtr`.

  This function calculates `(log o ndtr)(x)` by either calling
  `log(ndtr(x))` or using an asymptotic series. Specifically:

  - For `x > upper_segment`, use the approximation `-ndtr(-x)` based on
    `log(1-x) ~= -x, x << 1`.
  - For `lower_segment < x <= upper_segment`, use the existing `ndtr`
    technique and take a log.
  - For `x <= lower_segment`, use the series approximation of erf to
    compute the log CDF directly.

  The segment boundaries are chosen by the precision of the input, from the
  module constants `LOGNDTR_FLOAT{32,64}_{LOWER,UPPER}` (documented values
  below; confirm against the constants):

  ```
  lower_segment = { -20,  x.dtype=float64
                  { -10,  x.dtype=float32
  upper_segment = {   8,  x.dtype=float64
                  {   5,  x.dtype=float32
  ```

  When `x < lower_segment`, the `ndtr` asymptotic series approximation is:

  ```
     ndtr(x) = scale * (1 + sum) + R_N
     scale   = exp(-0.5 x**2) / (-x sqrt(2 pi))
     sum     = Sum{(-1)^n (2n-1)!! / (x**2)^n, n=1:N}
     R_N     = O(exp(-0.5 x**2) (2N+1)!! / |x|^{2N+3})
  ```

  where `(2n-1)!! = (2n-1) (2n-3) (2n-5) ...  (3) (1)` is a
  [double-factorial](https://en.wikipedia.org/wiki/Double_factorial).

  Args:
    x: `Tensor` of type `float32`, `float64`.
    series_order: Positive Python `integer`. Maximum depth to evaluate the
      asymptotic expansion. This is the `N` above.
    name: Python string. A name for the operation (default="log_ndtr").

  Returns:
    log_ndtr: `Tensor` with `dtype=x.dtype`.

  Raises:
    TypeError: if `x.dtype` is not handled.
    TypeError: if `series_order` is a not Python `integer.`
    ValueError: if `series_order` is not in `[0, 30]`.
  """
  # Argument validation happens before any op is created.
  if not isinstance(series_order, int):
    raise TypeError("series_order must be a Python integer.")
  if series_order < 0:
    raise ValueError("series_order must be non-negative.")
  if series_order > 30:
    raise ValueError("series_order must be <= 30.")

  with tf.name_scope(name):
    x = tf.convert_to_tensor(x, name="x")

    if dtype_util.base_equal(x.dtype, tf.float64):
      lower_segment, upper_segment = (
          LOGNDTR_FLOAT64_LOWER, LOGNDTR_FLOAT64_UPPER)
    elif dtype_util.base_equal(x.dtype, tf.float32):
      lower_segment, upper_segment = (
          LOGNDTR_FLOAT32_LOWER, LOGNDTR_FLOAT32_UPPER)
    else:
      raise TypeError("x.dtype=%s is not supported." % x.dtype)

    # The basic idea here was ported from:
    #   https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html
    # We copy the main idea, with a few changes
    # * For x >> 1, and X ~ Normal(0, 1),
    #     Log[P[X < x]] = Log[1 - P[X < -x]] approx -P[X < -x],
    #     which extends the range of validity of this function.
    # * We use one fixed series_order for all of 'x', rather than adaptive.
    # * Our docstring properly reflects that this is an asymptotic series,
    #   not a Taylor series. We also provided a correct bound on the
    #   remainder.
    # * We need to use the max/min in the _log_ndtr_lower arg to avoid nan
    #   when x=0. This happens even though the branch is unchosen because
    #   when x=0 the gradient of a select involves the calculation
    #   1*dy+0*(-inf)=nan regardless of whether dy is finite. Note that the
    #   minimum is a NOP if the branch is chosen.
    in_upper_tail = x > upper_segment
    upper_tail_value = -_ndtr(-x)  # log(1-x) ~= -x, x << 1

    in_middle = x > lower_segment
    middle_value = tf.math.log(_ndtr(tf.maximum(x, lower_segment)))
    lower_tail_value = _log_ndtr_lower(
        tf.minimum(x, lower_segment), series_order)

    below_upper_value = tf.where(in_middle, middle_value, lower_tail_value)
    return tf.where(in_upper_tail, upper_tail_value, below_upper_value)