def _cdf(self, counts):
  counts = self._maybe_assert_valid_sample(counts)
  probs = self._probs_parameter_no_checks()
  if not (tensorshape_util.is_fully_defined(counts.shape) and
          tensorshape_util.is_fully_defined(probs.shape) and
          tensorshape_util.is_compatible_with(counts.shape, probs.shape)):
    # If both shapes are well defined and equal, we skip broadcasting.
    probs = probs + tf.zeros_like(counts)
    counts = counts + tf.zeros_like(probs)
  return _bdtr(k=counts, n=tf.convert_to_tensor(self.total_count), p=probs)
def _mean(self):
  # Derivation: https://sachinruk.github.io/blog/von-Mises-Fisher/
  event_dim = tf.compat.dimension_value(self.event_shape[0])
  if event_dim is None:
    raise ValueError('event shape must be statically known for _bessel_ive')
  safe_conc = tf.where(self.concentration > 0, self.concentration,
                       tf.ones_like(self.concentration))
  safe_mean = self.mean_direction * (
      _bessel_ive(event_dim / 2, safe_conc) /
      _bessel_ive(event_dim / 2 - 1, safe_conc))[..., tf.newaxis]
  return tf.where(
      self.concentration[..., tf.newaxis] > tf.zeros_like(safe_mean),
      safe_mean, tf.zeros_like(safe_mean))
def _survival_function(self, y):
  low = self._low
  high = self._high

  # Recall the promise:
  # survival_function(y) := P[Y > y]
  #                       = 0, if y >= high,
  #                       = 1, if y < low,
  #                       = P[X > y], otherwise.

  # P[Y > j] = P[ceiling(Y) > j] since mass is only at integers, not in
  # between.
  j = tf.math.ceil(y)

  # P[X > j], used when low < X < high.
  result_so_far = self.distribution.survival_function(j)

  # Re-define values at the cutoffs.
  if low is not None:
    result_so_far = tf.where(j < low, tf.ones_like(result_so_far),
                             result_so_far)
  if high is not None:
    result_so_far = tf.where(j >= high, tf.zeros_like(result_so_far),
                             result_so_far)

  return result_so_far
def _log_cdf(self, y):
  low = self._low
  high = self._high

  # Recall the promise:
  # cdf(y) := P[Y <= y]
  #         = 1, if y >= high,
  #         = 0, if y < low,
  #         = P[X <= y], otherwise.

  # P[Y <= j] = P[floor(Y) <= j] since mass is only at integers, not in
  # between.
  j = tf.floor(y)

  result_so_far = self.distribution.log_cdf(j)

  # Re-define values at the cutoffs.
  if low is not None:
    result_so_far = tf.where(
        j < low,
        dtype_util.as_numpy_dtype(self.dtype)(-np.inf),
        result_so_far)
  if high is not None:
    result_so_far = tf.where(j >= high, tf.zeros_like(result_so_far),
                             result_so_far)

  return result_so_far
def _log_normalization(self, concentration=None, name='log_normalization'):
  """Returns the log normalization of an LKJ distribution.

  Args:
    concentration: `float` or `double` `Tensor`. The positive concentration
      parameter of the LKJ distributions.
    name: Python `str` name prefixed to Ops created by this function.

  Returns:
    log_z: A Tensor of the same shape and dtype as `concentration`, containing
      the corresponding log normalizers.
  """
  # The formula is from D. Lewandowski et al [1], p. 1999, from the
  # proof that eqs 16 and 17 are equivalent.
  with tf.name_scope(name or 'log_normalization_lkj'):
    concentration = (tf.convert_to_tensor(
        self.concentration if concentration is None else concentration))
    logpi = np.log(np.pi)
    ans = tf.zeros_like(concentration)
    for k in range(1, self.dimension):
      ans = ans + logpi * (k / 2.)
      ans = ans + tf.math.lgamma(concentration +
                                 (self.dimension - 1 - k) / 2.)
      ans = ans - tf.math.lgamma(concentration + (self.dimension - 1) / 2.)
    return ans
def _cdf(self, y):
  low = self._low
  high = self._high

  # Recall the promise:
  # cdf(y) := P[Y <= y]
  #         = 1, if y >= high,
  #         = 0, if y < low,
  #         = P[X <= y], otherwise.

  # P[Y <= j] = P[floor(Y) <= j] since mass is only at integers, not in
  # between.
  j = tf.floor(y)

  # P[X <= j], used when low < X < high.
  result_so_far = self.distribution.cdf(j)

  # Re-define values at the cutoffs.
  if low is not None:
    result_so_far = tf.where(j < low, tf.zeros_like(result_so_far),
                             result_so_far)
  if high is not None:
    result_so_far = tf.where(j >= high, tf.ones_like(result_so_far),
                             result_so_far)

  return result_so_far
def _log_ndtr_asymptotic_series(x, series_order):
  """Calculates the asymptotic series used in log_ndtr."""
  npdt = dtype_util.as_numpy_dtype(x.dtype)
  if series_order <= 0:
    return npdt(1)
  x_2 = tf.square(x)
  even_sum = tf.zeros_like(x)
  odd_sum = tf.zeros_like(x)
  x_2n = x_2  # Start with x^{2*1} = x^{2*n} with n = 1.
  for n in range(1, series_order + 1):
    y = npdt(_double_factorial(2 * n - 1)) / x_2n
    if n % 2:
      odd_sum += y
    else:
      even_sum += y
    x_2n *= x_2
  return 1. + even_sum - odd_sum
def cholesky_concat(chol, cols, name=None):
  """Concatenates `chol @ chol.T` with additional rows and columns.

  This operation is conceptually identical to:
  ```python
  def cholesky_concat_slow(chol, cols):  # cols shaped (n + m) x m = z x m
    mat = tf.matmul(chol, chol, adjoint_b=True)  # batch of n x n
    # Concat columns.
    mat = tf.concat([mat, cols[..., :tf.shape(mat)[-2], :]], axis=-1)  # n x z
    # Concat rows.
    mat = tf.concat([mat, tf.linalg.matrix_transpose(cols)], axis=-2)  # z x z
    return tf.linalg.cholesky(mat)
  ```
  but whereas `cholesky_concat_slow` would cost `O(z**3)` work,
  `cholesky_concat` only costs `O(z**2 + m**3)` work.

  The resulting (implicit) matrix must be symmetric and positive definite.
  Thus, the bottom right `m x m` must be self-adjoint, and we do not require a
  separate `rows` argument (which can be inferred from `conj(cols.T)`).

  Args:
    chol: Cholesky decomposition of `mat = chol @ chol.T`.
    cols: The new columns whose first `n` rows we would like concatenated to
      the right of `mat = chol @ chol.T`, and whose conjugate transpose we
      would like concatenated to the bottom of `concat(mat, cols[:n,:])`. A
      `Tensor` with final dims `(n+m, m)`. The first `n` rows are the top
      right rectangle (their conjugate transpose forms the bottom left), and
      the bottom `m x m` is self-adjoint.
    name: Optional name for this op.

  Returns:
    chol_concat: The Cholesky decomposition of:

      ```
      [ [ mat  cols[:n, :] ]
        [   conj(cols.T)   ] ]
      ```
  """
  with tf.name_scope(name or 'cholesky_extend'):
    dtype = dtype_util.common_dtype([chol, cols], dtype_hint=tf.float32)
    chol = tf.convert_to_tensor(chol, name='chol', dtype=dtype)
    cols = tf.convert_to_tensor(cols, name='cols', dtype=dtype)
    n = prefer_static.shape(chol)[-1]
    mat_nm, mat_mm = cols[..., :n, :], cols[..., n:, :]
    solved_nm = linear_operator_util.matrix_triangular_solve_with_broadcast(
        chol, mat_nm)
    lower_right_mm = tf.linalg.cholesky(
        mat_mm - tf.matmul(solved_nm, solved_nm, adjoint_a=True))
    lower_left_mn = tf.math.conj(tf.linalg.matrix_transpose(solved_nm))
    out_batch = prefer_static.shape(solved_nm)[:-2]
    chol = tf.broadcast_to(
        chol,
        tf.concat([out_batch, prefer_static.shape(chol)[-2:]], axis=0))
    top_right_zeros_nm = tf.zeros_like(solved_nm)
    return tf.concat([tf.concat([chol, top_right_zeros_nm], axis=-1),
                      tf.concat([lower_left_mn, lower_right_mm], axis=-1)],
                     axis=-2)
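# A minimal usage sketch for `cholesky_concat` (illustrative only, with
# hypothetical example values; it assumes this module's dependencies, e.g.
# `tf` and `prefer_static`, are importable): extend the Cholesky factor of the
# top-left 2x2 block of an SPD matrix to the full 3x3 factor and compare with
# factorizing the full matrix from scratch.
import numpy as np
import tensorflow as tf

full = np.array([[4., 2., 1.],
                 [2., 3., 0.5],
                 [1., 0.5, 2.]])                 # symmetric positive definite
chol_2x2 = tf.linalg.cholesky(full[:2, :2])      # factor of the top-left block
cols = tf.constant(full[:, 2:])                  # shape (n + m, m) = (3, 1)
extended = cholesky_concat(chol_2x2, cols)       # expected: the 3x3 factor
np.testing.assert_allclose(extended.numpy(), np.linalg.cholesky(full),
                           rtol=1e-5)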
def _cdf(self, x):
  # CDF is the probability that the Poisson variable is less or equal to x.
  # For fractional x, the CDF is equal to the CDF at n = floor(x).
  # For negative x, the CDF is zero, but tf.igammac gives NaNs, so we impute
  # the values and handle this case explicitly.
  safe_x = tf.maximum(x if self.interpolate_nondiscrete else tf.floor(x), 0.)
  cdf = tf.math.igammac(1. + safe_x, self._rate_parameter_no_checks())
  return tf.where(x < 0., tf.zeros_like(cdf), cdf)
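# Sanity-check sketch for the identity used above (an illustrative snippet,
# not part of the class, with hypothetical example values): the Poisson CDF at
# integer k equals the regularized upper incomplete gamma function Q(k + 1,
# rate), which is what `tf.math.igammac` computes.
import numpy as np
from scipy import special, stats

rate, k = 3.5, 2
np.testing.assert_allclose(stats.poisson.cdf(k, rate),
                           special.gammaincc(k + 1, rate))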
def _cdf(self, x):
  with tf.control_dependencies(self._maybe_assert_valid_sample(x)):
    probs = self._probs_parameter_no_checks()
    if not self.validate_args:
      # Whether or not x is integer-form, the following is well-defined.
      # However, scipy takes the floor, so we do too.
      x = tf.floor(x)
    return tf.where(x < 0., tf.zeros_like(x),
                    -tf.math.expm1((1. + x) * tf.math.log1p(-probs)))
def _log_prob(self, x):
  with tf.control_dependencies(self._maybe_assert_valid_sample(x)):
    probs = self._probs_parameter_no_checks()
    if not self.validate_args:
      # For consistency with cdf, we take the floor.
      x = tf.floor(x)
    safe_domain = tf.where(tf.equal(x, 0.), tf.zeros_like(probs), probs)
    return x * tf.math.log1p(-safe_domain) + tf.math.log(probs)
def _prob(self, x):
  low = tf.convert_to_tensor(self.low)
  high = tf.convert_to_tensor(self.high)
  return tf.where(
      tf.math.is_nan(x),
      x,
      tf.where(
          # This > is only sound for continuous uniform.
          (x < low) | (x > high),
          tf.zeros_like(x),
          tf.ones_like(x) / self._range(low=low, high=high)))
def _assertions(self, x):
  if not self.validate_args:
    return []
  shape = tf.shape(x)
  is_matrix = assert_util.assert_rank_at_least(
      x, 2, message="Input must have rank at least 2.")
  is_square = assert_util.assert_equal(
      shape[-2], shape[-1], message="Input must be a square matrix.")
  above_diagonal = tf.linalg.band_part(
      tf.linalg.set_diag(x, tf.zeros(shape[:-1], dtype=x.dtype)), 0, -1)
  is_lower_triangular = assert_util.assert_equal(
      above_diagonal, tf.zeros_like(above_diagonal),
      message="Input must be lower triangular.")
  # A lower triangular matrix is nonsingular iff all its diagonal entries are
  # nonzero.
  diag_part = tf.linalg.diag_part(x)
  is_nonsingular = assert_util.assert_none_equal(
      diag_part, tf.zeros_like(diag_part),
      message="Input must have all diagonal entries nonzero.")
  return [is_matrix, is_square, is_lower_triangular, is_nonsingular]
def _cdf(self, x):
  # CDF(x) at positive integer x is the probability that the Zipf variable is
  # less than or equal to x; given by the formula:
  #     CDF(x) = 1 - (zeta(power, x + 1) / Z)
  # For fractional x, the CDF is equal to the CDF at n = floor(x).
  # For x < 1, the CDF is zero.

  # If interpolate_nondiscrete is True, we return a continuous relaxation
  # which agrees with the CDF at integer points.
  power = tf.convert_to_tensor(self.power)
  x = tf.cast(x, power.dtype)
  safe_x = tf.maximum(x if self.interpolate_nondiscrete else tf.floor(x), 0.)
  cdf = 1. - (tf.math.zeta(power, safe_x + 1.) / tf.math.zeta(power, 1.))
  return tf.where(x < 1., tf.zeros_like(cdf), cdf)
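# Illustrative cross-check of the formula in the comment above (not part of
# the class, hypothetical example values): CDF(k) = 1 - zeta(power, k + 1) /
# zeta(power, 1) agrees with scipy's Zipf CDF at integer points.
import numpy as np
from scipy import special, stats

power, k = 2.5, 4
np.testing.assert_allclose(
    1. - special.zeta(power, k + 1.) / special.zeta(power, 1.),
    stats.zipf.cdf(k, power))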
def _sample_n(self, n, seed=None):
  concentration = tf.convert_to_tensor(self.concentration)
  mixing_concentration = tf.convert_to_tensor(self.mixing_concentration)
  mixing_rate = tf.convert_to_tensor(self.mixing_rate)
  seed = SeedStream(seed, 'gamma_gamma')
  rate = tf.random.gamma(
      shape=[n],
      # Be sure to draw enough rates for the fully-broadcasted gamma-gamma.
      alpha=mixing_concentration + tf.zeros_like(concentration),
      beta=mixing_rate,
      dtype=self.dtype,
      seed=seed())
  return tf.random.gamma(
      shape=[],
      alpha=concentration,
      beta=rate,
      dtype=self.dtype,
      seed=seed())
def _sample_n(self, n, seed=None):
  # Like with the univariate Student's t, sampling can be implemented as a
  # ratio of samples from a multivariate Gaussian with the appropriate
  # covariance matrix and a sample from the chi-squared distribution.
  seed = SeedStream(seed, salt='multivariate t')

  loc = tf.broadcast_to(self.loc, self._sample_shape())
  mvn = mvn_linear_operator.MultivariateNormalLinearOperator(
      loc=tf.zeros_like(loc), scale=self.scale)
  normal_samp = mvn.sample(n, seed=seed())

  df = tf.broadcast_to(self.df, self.batch_shape_tensor())
  chi2 = chi2_lib.Chi2(df=df)
  chi2_samp = chi2.sample(n, seed=seed())

  return (self._loc +
          normal_samp * tf.math.rsqrt(chi2_samp / self._df)[..., tf.newaxis])
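# Illustrative numpy sketch of the construction described in the comment above
# (not the class implementation; hypothetical df and sample size): a Student-t
# draw is a standard normal draw scaled by rsqrt(chi2 / df).
import numpy as np

rng = np.random.RandomState(0)
df, n_draws = 5., 100000
t_samples = rng.randn(n_draws) / np.sqrt(rng.chisquare(df, n_draws) / df)
print(t_samples.var(), df / (df - 2.))  # empirical vs analytic variance, roughly equal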
def _sample_n(self, n, seed=None):
  # Need to create logits corresponding to [p, 1 - p].
  # Note that for this distribution, logits corresponds to
  # inverse sigmoid(p) while in multivariate distributions,
  # such as multinomial this corresponds to log(p).
  # Because of this, when we construct the logits for the multinomial
  # sampler, we'll have to be careful.
  # log(p) = log(sigmoid(logits)) = logits - softplus(logits)
  # log(1 - p) = log(1 - sigmoid(logits)) = -softplus(logits)
  # Because softmax is invariant to a constant shift in all inputs,
  # we can offset the logits by softplus(logits) so that we can use
  # [logits, 0.] as our input.
  orig_logits = self._logits_parameter_no_checks()
  logits = tf.stack([orig_logits, tf.zeros_like(orig_logits)], axis=-1)
  return multinomial.draw_sample(
      num_samples=n,
      num_classes=2,
      logits=logits,
      num_trials=tf.cast(self.total_count, dtype=tf.int32),
      dtype=self.dtype,
      seed=seed)[..., 0]
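# Small numeric sketch of the logits bookkeeping described in the comments
# above (illustrative only, hypothetical scalar logit): softmax is
# shift-invariant, so softmax([logits, 0]) recovers [p, 1 - p] exactly as
# softmax([log(p), log(1 - p)]) would.
import numpy as np

logits = 0.7
p = 1. / (1. + np.exp(-logits))                 # sigmoid(logits)
shifted = np.array([logits, 0.])                # [log p, log(1-p)] + softplus(logits)
probs = np.exp(shifted) / np.exp(shifted).sum()
np.testing.assert_allclose(probs, [p, 1. - p])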
def _sample_3d(self, n, seed=None):
  """Specialized inversion sampler for 3D."""
  seed = SeedStream(seed, salt='von_mises_fisher_3d')
  u_shape = tf.concat([[n], self._batch_shape_tensor()], axis=0)
  z = tf.random.uniform(u_shape, seed=seed(), dtype=self.dtype)
  # TODO(bjp): Higher-order odd dim analytic CDFs are available in [1], could
  # be bisected for bounded sampling runtime (i.e. not rejection sampling).
  # [1]: Inversion sampler via: https://ieeexplore.ieee.org/document/7347705/
  # The inversion is: u = 1 + log(z + (1 - z) * exp(-2 * kappa)) / kappa
  # We must protect against both kappa and z being zero.
  safe_conc = tf.where(self.concentration > 0, self.concentration,
                       tf.ones_like(self.concentration))
  safe_z = tf.where(z > 0, z, tf.ones_like(z))
  safe_u = 1 + tf.reduce_logsumexp(
      [tf.math.log(safe_z), tf.math.log1p(-safe_z) - 2 * safe_conc],
      axis=0) / safe_conc
  # Limit of the above expression as kappa->0 is 2*z-1.
  u = tf.where(self.concentration > tf.zeros_like(safe_u), safe_u, 2 * z - 1)
  # Limit of the expression as z->0 is -1.
  u = tf.where(tf.equal(z, 0), -tf.ones_like(u), u)
  if not self._allow_nan_stats:
    u = tf.debugging.check_numerics(u, 'u in _sample_3d')
  return u[..., tf.newaxis]
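# Small numeric check (illustrative only, hypothetical kappa and z) that the
# log-sum-exp form used above matches the direct inversion formula
# u = 1 + log(z + (1 - z) * exp(-2 * kappa)) / kappa.
import numpy as np
from scipy.special import logsumexp

kappa, z = 2.3, 0.42
direct = 1. + np.log(z + (1. - z) * np.exp(-2. * kappa)) / kappa
stable = 1. + logsumexp([np.log(z), np.log1p(-z) - 2. * kappa]) / kappa
np.testing.assert_allclose(direct, stable)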
def _cdf(self, x):
  low = tf.convert_to_tensor(self.low)
  high = tf.convert_to_tensor(self.high)
  peak = tf.convert_to_tensor(self.peak)
  interval_length = high - low
  # Due to the PDF being not smooth at the peak, we have to treat each side
  # somewhat differently. The PDF is two line segments, and thus we get
  # quadratics here for the CDF.
  result_inside_interval = tf.where(
      (x >= low) & (x <= peak),
      # (x - low) ** 2 / ((high - low) * (peak - low))
      tf.math.squared_difference(x, low) / (interval_length * (peak - low)),
      # 1 - (high - x) ** 2 / ((high - low) * (high - peak))
      1. - tf.math.squared_difference(high, x) /
      (interval_length * (high - peak)))

  # We now add that the left tail is 0 and the right tail is 1.
  result_if_not_big = tf.where(x < low, tf.zeros_like(x),
                               result_inside_interval)
  return tf.where(x >= high, tf.ones_like(x), result_if_not_big)
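# Cross-check sketch for the piecewise-quadratic CDF above (illustrative, not
# part of the class; hypothetical example values): scipy parameterizes the
# triangular distribution by c = (peak - low) / (high - low), loc = low,
# scale = high - low.
import numpy as np
from scipy import stats

low, peak, high, x = 1., 2., 4., 1.7
manual = (x - low) ** 2 / ((high - low) * (peak - low))   # branch for low <= x <= peak
scipy_cdf = stats.triang.cdf(x, c=(peak - low) / (high - low),
                             loc=low, scale=high - low)
np.testing.assert_allclose(manual, scipy_cdf)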
def _prob(self, x):
  low = tf.convert_to_tensor(self.low)
  high = tf.convert_to_tensor(self.high)
  peak = tf.convert_to_tensor(self.peak)

  if self.validate_args:
    with tf.control_dependencies([
        assert_util.assert_greater_equal(x, low),
        assert_util.assert_less_equal(x, high)
    ]):
      x = tf.identity(x)

  interval_length = high - low
  # This is the pdf when low <= x <= high. It looks like a triangle, so we
  # have to treat each line segment separately.
  result_inside_interval = tf.where(
      (x >= low) & (x <= peak),
      # Line segment from (low, 0) to (peak, 2 / (high - low)).
      2. * (x - low) / (interval_length * (peak - low)),
      # Line segment from (peak, 2 / (high - low)) to (high, 0).
      2. * (high - x) / (interval_length * (high - peak)))

  return tf.where((x < low) | (x > high),
                  tf.zeros_like(x),
                  result_inside_interval)
def _covariance(self):
  # Derivation: https://sachinruk.github.io/blog/von-Mises-Fisher/
  event_dim = tf.compat.dimension_value(self.event_shape[0])
  if event_dim is None:
    raise ValueError('event shape must be statically known for _bessel_ive')
  # TODO(bjp): Enable this; numerically unstable.
  if event_dim > 2:
    raise ValueError('vMF covariance is numerically unstable for dim > 2')
  concentration = self.concentration[..., tf.newaxis]
  safe_conc = tf.where(concentration > 0, concentration,
                       tf.ones_like(concentration))
  h = (_bessel_ive(event_dim / 2, safe_conc) /
       _bessel_ive(event_dim / 2 - 1, safe_conc))
  intermediate = (
      tf.matmul(self.mean_direction[..., :, tf.newaxis],
                self.mean_direction[..., tf.newaxis, :]) *
      (1 - event_dim * h / safe_conc - h**2)[..., tf.newaxis])
  cov = tf.linalg.set_diag(
      intermediate,
      tf.linalg.diag_part(intermediate) + (h / safe_conc))
  return tf.where(
      concentration[..., tf.newaxis] > tf.zeros_like(cov),
      cov,
      tf.linalg.eye(event_dim,
                    batch_shape=self.batch_shape_tensor()) / event_dim)
def pivoted_cholesky(matrix, max_rank, diag_rtol=1e-3, name=None):
  """Computes the (partial) pivoted cholesky decomposition of `matrix`.

  The pivoted Cholesky is a low rank approximation of the Cholesky
  decomposition of `matrix`, i.e. as described in
  [(Harbrecht et al., 2012)][1]. The currently-worst-approximated diagonal
  element is selected as the pivot at each iteration. This yields from a
  `[B1...Bn, N, N]` shaped `matrix` a `[B1...Bn, N, K]` shaped rank-`K`
  approximation `lr` such that `lr @ lr.T ~= matrix`. Note that, unlike the
  Cholesky decomposition, `lr` is not triangular even in a rectangular-matrix
  sense. However, under a permutation it could be made triangular (it has one
  more zero in each column as you move to the right).

  Such a matrix can be useful as a preconditioner for conjugate gradient
  optimization, i.e. as in [(Wang et al. 2019)][2], as matmuls and solves can
  be cheaply done via the Woodbury matrix identity, as implemented by
  `tf.linalg.LinearOperatorLowRankUpdate`.

  Args:
    matrix: Floating point `Tensor` batch of symmetric, positive definite
      matrices.
    max_rank: Scalar `int` `Tensor`, the rank at which to truncate the
      approximation.
    diag_rtol: Scalar floating point `Tensor` (same dtype as `matrix`). If the
      errors of all diagonal elements of `lr @ lr.T` are each lower than
      `element * diag_rtol`, iteration is permitted to terminate early.
    name: Optional name for the op.

  Returns:
    lr: Low rank pivoted Cholesky approximation of `matrix`.

  #### References

  [1]: H Harbrecht, M Peters, R Schneider. On the low-rank approximation by
       the pivoted Cholesky decomposition. _Applied numerical mathematics_,
       62(4):428-440, 2012.

  [2]: K. A. Wang et al. Exact Gaussian Processes on a Million Data Points.
       _arXiv preprint arXiv:1903.08114_, 2019.
       https://arxiv.org/abs/1903.08114
  """
  with tf.name_scope(name or 'pivoted_cholesky'):
    dtype = dtype_util.common_dtype([matrix, diag_rtol], dtype_hint=tf.float32)
    matrix = tf.convert_to_tensor(matrix, name='matrix', dtype=dtype)
    if tensorshape_util.rank(matrix.shape) is None:
      raise NotImplementedError('Rank of `matrix` must be known statically')

    max_rank = tf.convert_to_tensor(max_rank, name='max_rank', dtype=tf.int64)
    max_rank = tf.minimum(
        max_rank, prefer_static.shape(matrix, out_type=tf.int64)[-1])
    diag_rtol = tf.convert_to_tensor(diag_rtol, dtype=dtype, name='diag_rtol')
    matrix_diag = tf.linalg.diag_part(matrix)
    # matrix is P.D., therefore all matrix_diag > 0, so we don't need abs.
    orig_error = tf.reduce_max(matrix_diag, axis=-1)

    def cond(m, pchol, perm, matrix_diag):
      """Condition for `tf.while_loop` continuation."""
      del pchol
      del perm
      error = tf.linalg.norm(matrix_diag, ord=1, axis=-1)
      max_err = tf.reduce_max(error / orig_error)
      return (m < max_rank) & (tf.equal(m, 0) | (max_err > diag_rtol))

    batch_dims = tensorshape_util.rank(matrix.shape) - 2

    def batch_gather(params, indices, axis=-1):
      return tf.gather(params, indices, axis=axis, batch_dims=batch_dims)

    def body(m, pchol, perm, matrix_diag):
      """Body of a single `tf.while_loop` iteration."""
      # Here is roughly a numpy, non-batched version of what's going to
      # happen. (See also Algorithm 1 of Harbrecht et al.)
      # 1: maxi = np.argmax(matrix_diag[perm[m:]]) + m
      # 2: maxval = matrix_diag[perm][maxi]
      # 3: perm[m], perm[maxi] = perm[maxi], perm[m]
      # 4: row = matrix[perm[m]][perm[m + 1:]]
      # 5: row -= np.sum(pchol[:m][perm[m + 1:]] * pchol[:m][perm[m]], axis=-2)
      # 6: pivot = np.sqrt(maxval); row /= pivot
      # 7: row = np.concatenate([[[pivot]], row], -1)
      # 8: matrix_diag[perm[m:]] -= row**2
      # 9: pchol[m, perm[m:]] = row

      # Find the maximal position of the (remaining) permuted diagonal.
      # Steps 1, 2 above.
      permuted_diag = batch_gather(matrix_diag, perm[..., m:])
      maxi = tf.argmax(
          permuted_diag, axis=-1, output_type=tf.int64)[..., tf.newaxis]
      maxval = batch_gather(permuted_diag, maxi)
      maxi = maxi + m
      maxval = maxval[..., 0]
      # Update perm: Swap perm[..., m] with perm[..., maxi]. Step 3 above.
      perm = _swap_m_with_i(perm, m, maxi)
      # Step 4.
      row = batch_gather(matrix, perm[..., m:m + 1], axis=-2)
      row = batch_gather(row, perm[..., m + 1:])
      # Step 5.
      prev_rows = pchol[..., :m, :]
      prev_rows_perm_m_onward = batch_gather(prev_rows, perm[..., m + 1:])
      prev_rows_pivot_col = batch_gather(prev_rows, perm[..., m:m + 1])
      row -= tf.reduce_sum(
          prev_rows_perm_m_onward * prev_rows_pivot_col,
          axis=-2)[..., tf.newaxis, :]
      # Step 6.
      pivot = tf.sqrt(maxval)[..., tf.newaxis, tf.newaxis]
      # Step 7.
      row = tf.concat([pivot, row / pivot], axis=-1)
      # TODO(b/130899118): Pad grad fails with int64 paddings.
      # Step 8.
      paddings = tf.concat([
          tf.zeros([prefer_static.rank(pchol) - 1, 2], dtype=tf.int32),
          [[tf.cast(m, tf.int32), 0]]], axis=0)
      diag_update = tf.pad(row**2, paddings=paddings)[..., 0, :]
      reverse_perm = _invert_permutation(perm)
      matrix_diag -= batch_gather(diag_update, reverse_perm)
      # Step 9.
      row = tf.pad(row, paddings=paddings)
      # TODO(bjp): Defer the reverse permutation all-at-once at the end?
      row = batch_gather(row, reverse_perm)
      pchol_shape = pchol.shape
      pchol = tf.concat([pchol[..., :m, :], row, pchol[..., m + 1:, :]],
                        axis=-2)
      tensorshape_util.set_shape(pchol, pchol_shape)
      return m + 1, pchol, perm, matrix_diag

    m = np.int64(0)
    pchol = tf.zeros_like(matrix[..., :max_rank, :])
    matrix_shape = prefer_static.shape(matrix, out_type=tf.int64)
    perm = tf.broadcast_to(
        prefer_static.range(matrix_shape[-1]), matrix_shape[:-1])
    _, pchol, _, _ = tf.while_loop(
        cond=cond, body=body, loop_vars=(m, pchol, perm, matrix_diag))
    pchol = tf.linalg.matrix_transpose(pchol)
    tensorshape_util.set_shape(
        pchol, tensorshape_util.concatenate(matrix_diag.shape, [None]))
    return pchol
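# A minimal usage sketch for `pivoted_cholesky` (illustrative, with
# hypothetical example values; it assumes this module's dependencies are
# importable): approximate a nearly rank-2 SPD matrix with a rank-2 factor and
# inspect the diagonal reconstruction error, which is what `diag_rtol`
# controls.
import numpy as np
import tensorflow as tf

rng = np.random.RandomState(0)
a = rng.randn(5, 2)
matrix = a @ a.T + 1e-3 * np.eye(5)          # effectively rank-2 SPD matrix
lr = pivoted_cholesky(matrix, max_rank=2)    # shape [5, 2]
recon = tf.matmul(lr, lr, adjoint_b=True)
# Expected to be small relative to the diagonal of `matrix`.
print(np.max(np.abs(np.diag(matrix) - np.diag(recon.numpy()))))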
def _mean(self):
  # Shape is broadcasted with + tf.zeros_like().
  return self.loc + tf.zeros_like(self.concentration)
def _variance(self):
  return tf.zeros_like(self.loc)
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Takes one step of the TransitionKernel.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
      `list` of `Tensor`s representing internal calculations made within the
      previous call to this function (or as returned by `bootstrap_results`).
    seed: Optional, a seed for reproducible sampling.

  Returns:
    next_state: `Tensor` or Python `list` of `Tensor`s representing the
      next state(s) of the Markov chain(s).
    kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
      `Tensor`s representing internal calculations made within this function.
      This includes replica states.
  """
  with tf.name_scope(mcmc_util.make_name(self.name, 'tmc', 'one_step')):
    # Force a read in case the `inverse_temperatures` is a `tf.Variable`.
    inverse_temperatures = tf.convert_to_tensor(
        previous_kernel_results.post_tempering_inverse_temperatures,
        name='inverse_temperatures')

    steps_at_temperature = tf.convert_to_tensor(
        previous_kernel_results.steps_at_temperature,
        name='number of steps')

    target_score_for_inner_kernel = partial(
        self.target_score_fn, sigma=inverse_temperatures)
    target_log_prob_for_inner_kernel = partial(
        self.target_log_prob_fn, sigma=inverse_temperatures)

    try:
      inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
          target_log_prob_for_inner_kernel,
          target_score_for_inner_kernel,
          inverse_temperatures)
    except TypeError as e:
      if 'argument' not in str(e):
        raise
      warnings.warn(
          'The `seed` argument to `ReplicaExchangeMC`s `make_kernel_fn` is '
          'deprecated. `TransitionKernel` instances now receive seeds via '
          '`one_step`.')
      inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
          target_log_prob_for_inner_kernel,
          target_score_for_inner_kernel,
          inverse_temperatures,
          self._seed_stream())

    if seed is not None:
      seed = samplers.sanitize_seed(seed)
      inner_seed, swap_seed, logu_seed = samplers.split_seed(
          seed, n=3, salt='tmc_one_step')
      inner_kwargs = dict(seed=inner_seed)
    else:
      if self._seed_stream.original_seed is not None:
        warnings.warn(mcmc_util.SEED_CTOR_ARG_DEPRECATION_MSG)
      inner_kwargs = {}
      swap_seed, logu_seed = samplers.split_seed(self._seed_stream())

    if mcmc_util.is_list_like(current_state):
      # We *always* canonicalize the states in the kernel results.
      states = current_state
    else:
      states = [current_state]

    [
        new_state,
        pre_tempering_results,
    ] = inner_kernel.one_step(
        states, previous_kernel_results.post_tempering_results,
        **inner_kwargs)

    # Now that we have run one step, we consider maybe lowering the
    # temperature. Proposed new temperature:
    proposed_inverse_temperatures = tf.clip_by_value(
        self.gamma * inverse_temperatures, self.min_temp, 1e6)
    dtype = inverse_temperatures.dtype

    # We will lower the temperature if this new proposed step is compatible
    # with a temperature swap.
    v = new_state[0] - states[0]
    cs = states[0]

    @jax.vmap
    def integrand(t):
      return jnp.sum(
          self._parameters['target_score_fn'](t * v + cs,
                                              inverse_temperatures) * v,
          axis=-1)

    delta_logp1 = simps(integrand, 0., 1.,
                        self._parameters['num_delta_logp_steps'])

    # Now we compute the reverse.
    v = -v
    cs = new_state[0]

    @jax.vmap
    def integrand(t):  # pylint: disable=function-redefined
      return jnp.sum(
          self._parameters['target_score_fn'](t * v + cs,
                                              proposed_inverse_temperatures)
          * v,
          axis=-1)

    delta_logp2 = simps(integrand, 0., 1.,
                        self._parameters['num_delta_logp_steps'])

    log_accept_ratio = delta_logp1 + delta_logp2
    log_accept_ratio = tf.where(
        tf.math.is_finite(log_accept_ratio),
        log_accept_ratio,
        tf.constant(-np.inf, dtype=dtype))

    # Produce Log[Uniform] draws that are identical at swapped indices.
    log_uniform = tf.math.log(
        samplers.uniform(
            shape=log_accept_ratio.shape, dtype=dtype, seed=logu_seed))

    is_tempering_accepted_mask = tf.less(
        log_uniform, log_accept_ratio, name='is_tempering_accepted_mask')

    is_min_steps_satisfied = tf.greater(
        steps_at_temperature,
        self.min_steps_per_temp * tf.ones_like(steps_at_temperature),
        name='is_min_steps_satisfied')

    # Only propose tempering if the chain was going to accept this point
    # anyway.
    is_tempering_accepted_mask = tf.math.logical_and(
        is_tempering_accepted_mask, pre_tempering_results.is_accepted)

    is_tempering_accepted_mask = tf.math.logical_and(
        is_tempering_accepted_mask, is_min_steps_satisfied)

    # Updating accepted inverse temperatures.
    post_tempering_inverse_temperatures = mcmc_util.choose(
        is_tempering_accepted_mask,
        proposed_inverse_temperatures, inverse_temperatures)

    steps_at_temperature = mcmc_util.choose(
        is_tempering_accepted_mask,
        tf.zeros_like(steps_at_temperature), steps_at_temperature + 1)

    # Invalidating and recomputing results.
    [
        new_target_log_prob,
        new_grads_target_log_prob,
    ] = mcmc_util.maybe_call_fn_and_grads(
        partial(self.target_log_prob_fn,
                sigma=post_tempering_inverse_temperatures),
        new_state)

    # Updating inner kernel results.
    post_tempering_results = pre_tempering_results._replace(
        proposed_results=tf.convert_to_tensor(np.nan, dtype=dtype),
        proposed_state=tf.convert_to_tensor(np.nan, dtype=dtype),
    )

    if isinstance(post_tempering_results.accepted_results,
                  hmc.UncalibratedHamiltonianMonteCarloKernelResults):
      post_tempering_results = post_tempering_results._replace(
          accepted_results=post_tempering_results.accepted_results._replace(
              target_log_prob=new_target_log_prob,
              grads_target_log_prob=new_grads_target_log_prob))
    elif isinstance(post_tempering_results.accepted_results,
                    random_walk_metropolis.UncalibratedRandomWalkResults):
      post_tempering_results = post_tempering_results._replace(
          accepted_results=post_tempering_results.accepted_results._replace(
              target_log_prob=new_target_log_prob))
    else:
      # TODO(b/143702650) Handle other kernels.
      raise NotImplementedError(
          'Only HMC and RWMH Kernels are handled at this time. Please file a '
          'request with the TensorFlow Probability team.')

    new_kernel_results = TemperedMCKernelResults(
        pre_tempering_results=pre_tempering_results,
        post_tempering_results=post_tempering_results,
        pre_tempering_inverse_temperatures=inverse_temperatures,
        post_tempering_inverse_temperatures=(
            post_tempering_inverse_temperatures),
        tempering_log_accept_ratio=log_accept_ratio,
        steps_at_temperature=steps_at_temperature,
        seed=samplers.zeros_seed() if seed is None else seed,
    )

    return new_state[0], new_kernel_results
def bootstrap_results(self, init_state):
  """Returns an object with the same type as returned by `one_step`.

  Args:
    init_state: `Tensor` or Python `list` of `Tensor`s representing the
      initial state(s) of the Markov chain(s).

  Returns:
    kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
      `Tensor`s representing internal calculations made within this function.
      This includes replica states.
  """
  with tf.name_scope(
      mcmc_util.make_name(self.name, 'tmc', 'bootstrap_results')):
    init_state, unused_is_multipart_state = mcmc_util.prepare_state_parts(
        init_state)

    inverse_temperatures = tf.convert_to_tensor(
        self.inverse_temperatures, name='inverse_temperatures')

    target_score_for_inner_kernel = partial(
        self.target_score_fn, sigma=inverse_temperatures)
    target_log_prob_for_inner_kernel = partial(
        self.target_log_prob_fn, sigma=inverse_temperatures)

    # Seed handling complexity is due to users possibly expecting an
    # old-style stateful seed to be passed to `self.make_kernel_fn`.
    # In other words:
    # - We try `make_kernel_fn` without a seed first; this is the future. The
    #   kernel will receive a seed later, as part of `one_step`.
    # - If the user code doesn't like that (Python complains about a missing
    #   required argument), we fall back to the previous behavior and warn.
    try:
      inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
          target_log_prob_for_inner_kernel,
          target_score_for_inner_kernel,
          inverse_temperatures)
    except TypeError as e:
      if 'argument' not in str(e):
        raise
      warnings.warn(
          'The second (`seed`) argument to `ReplicaExchangeMC`s '
          '`make_kernel_fn` is deprecated. `TransitionKernel` instances now '
          'receive seeds via `bootstrap_results` and `one_step`. This '
          'fallback may become an error 2020-09-20.')
      inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
          target_log_prob_for_inner_kernel,
          target_score_for_inner_kernel,
          inverse_temperatures,
          self._seed_stream())

    inner_results = inner_kernel.bootstrap_results(init_state)
    post_tempering_results = inner_results

    # Invalidating and recomputing results.
    [
        new_target_log_prob,
        new_grads_target_log_prob,
    ] = mcmc_util.maybe_call_fn_and_grads(
        partial(self.target_log_prob_fn, sigma=inverse_temperatures),
        init_state)

    # Updating inner kernel results.
    dtype = inverse_temperatures.dtype
    post_tempering_results = post_tempering_results._replace(
        proposed_results=tf.convert_to_tensor(np.nan, dtype=dtype),
        proposed_state=tf.convert_to_tensor(np.nan, dtype=dtype),
    )

    if isinstance(post_tempering_results.accepted_results,
                  hmc.UncalibratedHamiltonianMonteCarloKernelResults):
      post_tempering_results = post_tempering_results._replace(
          accepted_results=post_tempering_results.accepted_results._replace(
              target_log_prob=new_target_log_prob,
              grads_target_log_prob=new_grads_target_log_prob))
    elif isinstance(post_tempering_results.accepted_results,
                    random_walk_metropolis.UncalibratedRandomWalkResults):
      post_tempering_results = post_tempering_results._replace(
          accepted_results=post_tempering_results.accepted_results._replace(
              target_log_prob=new_target_log_prob))
    else:
      # TODO(b/143702650) Handle other kernels.
      raise NotImplementedError(
          'Only HMC and RWMH Kernels are handled at this time. Please file a '
          'request with the TensorFlow Probability team.')

    return TemperedMCKernelResults(
        pre_tempering_results=inner_results,
        post_tempering_results=post_tempering_results,
        pre_tempering_inverse_temperatures=inverse_temperatures,
        post_tempering_inverse_temperatures=inverse_temperatures,
        tempering_log_accept_ratio=tf.zeros_like(inverse_temperatures),
        steps_at_temperature=tf.zeros_like(inverse_temperatures,
                                           dtype=tf.int32),
        seed=samplers.zeros_seed(),
    )
def _log_unnormalized_prob(self, samples):
  samples = self._maybe_assert_valid_sample(samples)
  bcast_mean_dir = (self.mean_direction +
                    tf.zeros_like(self.concentration)[..., tf.newaxis])
  inner_product = tf.reduce_sum(samples * bcast_mean_dir, axis=-1)
  return self.concentration * inner_product
def _sample_n(self, num_samples, seed=None, name=None):
  """Returns a Tensor of samples from an LKJ distribution.

  Args:
    num_samples: Python `int`. The number of samples to draw.
    seed: Python integer seed for RNG.
    name: Python `str` name prefixed to Ops created by this function.

  Returns:
    samples: A Tensor of correlation matrices with shape `[n, B, D, D]`,
      where `B` is the shape of the `concentration` parameter, and `D`
      is the `dimension`.

  Raises:
    ValueError: If `dimension` is negative.
  """
  if self.dimension < 0:
    raise ValueError(
        'Cannot sample negative-dimension correlation matrices.')
  # Notation below: B is the batch shape, i.e., tf.shape(concentration)
  seed = SeedStream(seed, 'sample_lkj')
  with tf.name_scope(name or 'sample_lkj'):
    concentration = tf.convert_to_tensor(self.concentration)
    if not dtype_util.is_floating(concentration.dtype):
      raise TypeError(
          'The concentration argument should have floating type, not '
          '{}'.format(dtype_util.name(concentration.dtype)))

    concentration = _replicate(num_samples, concentration)
    concentration_shape = tf.shape(concentration)
    if self.dimension <= 1:
      # For any dimension <= 1, there is only one possible correlation matrix.
      shape = tf.concat(
          [concentration_shape, [self.dimension, self.dimension]], axis=0)
      return tf.ones(shape=shape, dtype=concentration.dtype)
    beta_conc = concentration + (self.dimension - 2.) / 2.
    beta_dist = beta.Beta(concentration1=beta_conc, concentration0=beta_conc)

    # Note that the sampler below deviates from [1], by doing the sampling in
    # cholesky space. This does not change the fundamental logic of the
    # sampler, but does speed up the sampling.

    # This is the correlation coefficient between the first two dimensions.
    # This is also `r` in reference [1].
    corr12 = 2. * beta_dist.sample(seed=seed()) - 1.

    # Below we construct the Cholesky of the initial 2x2 correlation matrix,
    # which is of the form:
    # [[1, 0], [r, sqrt(1 - r**2)]], where r is the correlation between the
    # first two dimensions.
    # This is the top-left corner of the cholesky of the final sample.
    first_row = tf.concat([
        tf.ones_like(corr12)[..., tf.newaxis],
        tf.zeros_like(corr12)[..., tf.newaxis]], axis=-1)
    second_row = tf.concat([
        corr12[..., tf.newaxis],
        tf.sqrt(1 - corr12**2)[..., tf.newaxis]], axis=-1)

    chol_result = tf.concat([
        first_row[..., tf.newaxis, :],
        second_row[..., tf.newaxis, :]], axis=-2)

    for n in range(2, self.dimension):
      # Loop invariant: on entry, result has shape B + [n, n]
      beta_conc = beta_conc - 0.5
      # norm is y in reference [1].
      norm = beta.Beta(
          concentration1=n / 2.,
          concentration0=beta_conc).sample(seed=seed())
      # distance shape: B + [1] for broadcast
      distance = tf.sqrt(norm)[..., tf.newaxis]
      # direction is u in reference [1].
      # direction shape: B + [n]
      direction = _uniform_unit_norm(
          n, concentration_shape, concentration.dtype, seed)
      # raw_correlation is w in reference [1].
      raw_correlation = distance * direction  # shape: B + [n]

      # This is the next row in the cholesky of the result,
      # which differs from the construction in reference [1].
      # In the reference, the new row `z` = chol_result @ raw_correlation^T
      # = C @ raw_correlation^T (where as short hand we use C = chol_result).
      # We prove that the below equation is the right row to add to the
      # cholesky, by showing equality with reference [1].
      # Let S be the sample constructed so far, and let `z` be as in
      # reference [1]. Then at this iteration, the new sample S' will be
      # [[S z^T]
      #  [z  1]]
      # In our case we have the cholesky decomposition factor C, so
      # we want our new row x (same size as z) to satisfy:
      #   [[S z^T]   [[C 0]   [[C^T x^T]   [[CC^T  Cx^T       ]
      #    [z  1]] =  [x k]]   [0    k]] =  [xC^T  xx^T + k**2]]
      # Since C @ raw_correlation^T = z = C @ x^T, and C is invertible,
      # we have that x = raw_correlation. Also 1 = xx^T + k**2, so k
      # = sqrt(1 - xx^T) = sqrt(1 - |raw_correlation|**2) = sqrt(1 -
      # distance**2).
      new_row = tf.concat(
          [raw_correlation, tf.sqrt(1. - norm[..., tf.newaxis])], axis=-1)

      # Finally add this new row, by growing the cholesky of the result.
      chol_result = tf.concat([
          chol_result,
          tf.zeros_like(chol_result[..., 0][..., tf.newaxis])], axis=-1)

      chol_result = tf.concat(
          [chol_result, new_row[..., tf.newaxis, :]], axis=-2)

    if self.input_output_cholesky:
      return chol_result

    result = tf.matmul(chol_result, chol_result, transpose_b=True)
    # The diagonal for a correlation matrix should always be ones. Due to
    # numerical instability the matmul might not achieve that, so manually
    # set these to ones.
    result = tf.linalg.set_diag(
        result, tf.ones(shape=tf.shape(result)[:-1], dtype=result.dtype))
    # This sampling algorithm can produce near-PSD matrices on which standard
    # algorithms such as `tf.cholesky` or `tf.linalg.self_adjoint_eigvals`
    # fail. Specifically, as documented in b/116828694, around 2% of trials
    # of 900,000 5x5 matrices (distributed according to 9 different
    # concentration parameter values) contained at least one matrix on which
    # the Cholesky decomposition failed.
    return result
def _mode(self):
  """The mode of the von Mises-Fisher distribution is the mean direction."""
  return (self.mean_direction +
          tf.zeros_like(self.concentration)[..., tf.newaxis])
def _create_polynomial(var, coeffs):
  """Compute n_th order polynomial via Horner's method."""
  coeffs = np.array(coeffs, dtype_util.as_numpy_dtype(var.dtype))
  if not coeffs.size:
    return tf.zeros_like(var)
  return coeffs[0] + _create_polynomial(var, coeffs[1:]) * var
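# Quick illustrative check of the Horner recursion above (hypothetical values;
# assumes this module's `tf`, `np`, and `dtype_util` imports): coeffs are
# ordered from the constant term upward, so [c0, c1, c2] evaluates
# c0 + c1 * x + c2 * x**2.
import numpy as np
import tensorflow as tf

x = tf.constant([0.5, 2.0])
coeffs = [1., -3., 2.]  # 1 - 3x + 2x^2
np.testing.assert_allclose(_create_polynomial(x, coeffs).numpy(),
                           1. - 3. * x.numpy() + 2. * x.numpy() ** 2)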