Example #1
  def _cdf(self, counts):
    counts = self._maybe_assert_valid_sample(counts)
    probs = self._probs_parameter_no_checks()
    if not (tensorshape_util.is_fully_defined(counts.shape) and
            tensorshape_util.is_fully_defined(probs.shape) and
            tensorshape_util.is_compatible_with(counts.shape, probs.shape)):
      # Broadcast explicitly only when we cannot statically verify that the
      # shapes match.
      probs = probs + tf.zeros_like(counts)
      counts = counts + tf.zeros_like(probs)

    return _bdtr(k=counts, n=tf.convert_to_tensor(self.total_count), p=probs)
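The `x + tf.zeros_like(y)` idiom above broadcasts both operands to their common shape. A minimal standalone sketch of the idiom (shapes and values are purely illustrative, plain TensorFlow only):

```python
import tensorflow as tf

counts = tf.constant([[1.], [2.]])       # shape [2, 1]
probs = tf.constant([0.1, 0.5, 0.9])     # shape [3]
probs = probs + tf.zeros_like(counts)    # broadcast to shape [2, 3]
counts = counts + tf.zeros_like(probs)   # broadcast to shape [2, 3]
```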
 def _mean(self):
   # Derivation: https://sachinruk.github.io/blog/von-Mises-Fisher/
   event_dim = tf.compat.dimension_value(self.event_shape[0])
   if event_dim is None:
     raise ValueError('event shape must be statically known for _bessel_ive')
   safe_conc = tf.where(self.concentration > 0, self.concentration,
                        tf.ones_like(self.concentration))
   safe_mean = self.mean_direction * (
       _bessel_ive(event_dim / 2, safe_conc) /
       _bessel_ive(event_dim / 2 - 1, safe_conc))[..., tf.newaxis]
   return tf.where(
       self.concentration[..., tf.newaxis] > tf.zeros_like(safe_mean),
       safe_mean, tf.zeros_like(safe_mean))
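The mean length used above is the Bessel-function ratio I_{d/2}(kappa) / I_{d/2-1}(kappa). A NumPy/SciPy sketch of the same ratio (SciPy's exponentially scaled `ive` is assumed here; the scaling cancels in the ratio, and the values are illustrative):

```python
import numpy as np
from scipy import special

dim, kappa = 3, 2.0
mean_direction = np.array([0., 0., 1.])
ratio = special.ive(dim / 2, kappa) / special.ive(dim / 2 - 1, kappa)
mean = mean_direction * ratio  # for dim = 3 this ratio equals coth(kappa) - 1/kappa
```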
    def _survival_function(self, y):
        low = self._low
        high = self._high

        # Recall the promise:
        # survival_function(y) := P[Y > y]
        #                       = 0, if y >= high,
        #                       = 1, if y < low,
        #                       = P[X > y], otherwise.

        # P[Y > j] = P[ceiling(Y) > j] since mass is only at integers, not in
        # between.
        j = tf.math.ceil(y)

        # P[X > j], used when low < X < high.
        result_so_far = self.distribution.survival_function(j)

        # Re-define values at the cutoffs.
        if low is not None:
            result_so_far = tf.where(j < low, tf.ones_like(result_so_far),
                                     result_so_far)
        if high is not None:
            result_so_far = tf.where(j >= high, tf.zeros_like(result_so_far),
                                     result_so_far)

        return result_so_far
    def _log_cdf(self, y):
        low = self._low
        high = self._high

        # Recall the promise:
        # cdf(y) := P[Y <= y]
        #         = 1, if y >= high,
        #         = 0, if y < low,
        #         = P[X <= y], otherwise.

        # P[Y <= j] = P[floor(Y) <= j] since mass is only at integers, not in
        # between.
        j = tf.floor(y)

        result_so_far = self.distribution.log_cdf(j)

        # Re-define values at the cutoffs.
        if low is not None:
            result_so_far = tf.where(
                j < low,
                dtype_util.as_numpy_dtype(self.dtype)(-np.inf), result_so_far)
        if high is not None:
            result_so_far = tf.where(j >= high, tf.zeros_like(result_so_far),
                                     result_so_far)

        return result_so_far
Example #5
    def _log_normalization(self, concentration=None, name='log_normalization'):
        """Returns the log normalization of an LKJ distribution.

    Args:
      concentration: `float` or `double` `Tensor`. The positive concentration
        parameter of the LKJ distributions.
      name: Python `str` name prefixed to Ops created by this function.

    Returns:
      log_z: A Tensor of the same shape and dtype as `concentration`, containing
        the corresponding log normalizers.
    """
        # The formula is from D. Lewandowski et al [1], p. 1999, from the
        # proof that eqs 16 and 17 are equivalent.
        with tf.name_scope(name or 'log_normalization_lkj'):
            concentration = (tf.convert_to_tensor(
                self.concentration if concentration is None else concentration)
                             )
            logpi = np.log(np.pi)
            ans = tf.zeros_like(concentration)
            for k in range(1, self.dimension):
                ans = ans + logpi * (k / 2.)
                ans = ans + tf.math.lgamma(concentration +
                                           (self.dimension - 1 - k) / 2.)
                ans = ans - tf.math.lgamma(concentration +
                                           (self.dimension - 1) / 2.)
            return ans
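A non-TF sketch of the same sum, using SciPy's `gammaln` in place of `tf.math.lgamma` (the dimension and concentration values below are illustrative):

```python
import numpy as np
from scipy.special import gammaln

def lkj_log_normalization(concentration, dimension):
    ans = 0.0
    for k in range(1, dimension):
        ans += np.log(np.pi) * (k / 2.)
        ans += gammaln(concentration + (dimension - 1 - k) / 2.)
        ans -= gammaln(concentration + (dimension - 1) / 2.)
    return ans

print(lkj_log_normalization(2.0, 3))  # log normalizer of a 3x3 LKJ(2.)
```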
    def _cdf(self, y):
        low = self._low
        high = self._high

        # Recall the promise:
        # cdf(y) := P[Y <= y]
        #         = 1, if y >= high,
        #         = 0, if y < low,
        #         = P[X <= y], otherwise.

        # P[Y <= j] = P[floor(Y) <= j] since mass is only at integers, not in
        # between.
        j = tf.floor(y)

        # P[X <= j], used when low < X < high.
        result_so_far = self.distribution.cdf(j)

        # Re-define values at the cutoffs.
        if low is not None:
            result_so_far = tf.where(j < low, tf.zeros_like(result_so_far),
                                     result_so_far)
        if high is not None:
            result_so_far = tf.where(j >= high, tf.ones_like(result_so_far),
                                     result_so_far)

        return result_so_far
Example #7
def _log_ndtr_asymptotic_series(x, series_order):
    """Calculates the asymptotic series used in log_ndtr."""
    npdt = dtype_util.as_numpy_dtype(x.dtype)
    if series_order <= 0:
        return npdt(1)
    x_2 = tf.square(x)
    even_sum = tf.zeros_like(x)
    odd_sum = tf.zeros_like(x)
    x_2n = x_2  # Start with x^{2*1} = x^{2*n} with n = 1.
    for n in range(1, series_order + 1):
        y = npdt(_double_factorial(2 * n - 1)) / x_2n
        if n % 2:
            odd_sum += y
        else:
            even_sum += y
        x_2n *= x_2
    return 1. + even_sum - odd_sum
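The returned value is the series 1 + sum_{n>=1} (-1)^n (2n-1)!! / x^(2n). A NumPy sketch of how it enters `log_ndtr` for large negative `x` (the prefactor below is the standard asymptotic expansion of the normal CDF; the comparison against `scipy.special.log_ndtr` is only illustrative):

```python
import numpy as np
from scipy import special

def double_factorial(n):
    return float(np.prod(np.arange(n, 0, -2)))

def asymptotic_series(x, series_order=3):
    return 1. + sum((-1.)**n * double_factorial(2 * n - 1) / x**(2 * n)
                    for n in range(1, series_order + 1))

x = -10.0
log_ndtr_approx = (-0.5 * x**2 - 0.5 * np.log(2. * np.pi) - np.log(-x)
                   + np.log(asymptotic_series(x)))
# log_ndtr_approx is close to special.log_ndtr(x) for large negative x.
```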
Example #8
def cholesky_concat(chol, cols, name=None):
    """Concatenates `chol @ chol.T` with additional rows and columns.

  This operation is conceptually identical to:
  ```python
  def cholesky_concat_slow(chol, cols):  # cols shaped (n + m) x m = z x m
    mat = tf.matmul(chol, chol, adjoint_b=True)  # batch of n x n
    # Concat columns.
    mat = tf.concat([mat, cols[..., :tf.shape(mat)[-2], :]], axis=-1)  # n x z
    # Concat rows.
    mat = tf.concat([mat, tf.linalg.matrix_transpose(cols)], axis=-2)  # z x z
    return tf.linalg.cholesky(mat)
  ```
  but whereas `cholesky_concat_slow` would cost `O(z**3)` work,
  `cholesky_concat` only costs `O(z**2 + m**3)` work.

  The resulting (implicit) matrix must be symmetric and positive definite.
  Thus, the bottom right `m x m` must be self-adjoint, and we do not require a
  separate `rows` argument (which can be inferred from `conj(cols.T)`).

  Args:
    chol: Cholesky decomposition of `mat = chol @ chol.T`.
    cols: The new columns whose first `n` rows we would like concatenated to the
      right of `mat = chol @ chol.T`, and whose conjugate transpose we would
      like concatenated to the bottom of `concat(mat, cols[:n,:])`. A `Tensor`
      with final dims `(n+m, m)`. The first `n` rows are the top right rectangle
      (their conjugate transpose forms the bottom left), and the bottom `m x m`
      is self-adjoint.
    name: Optional name for this op.

  Returns:
    chol_concat: The Cholesky decomposition of:
      ```
      [ [ mat  cols[:n, :] ]
        [   conj(cols.T)   ] ]
      ```
  """
    with tf.name_scope(name or 'cholesky_extend'):
        dtype = dtype_util.common_dtype([chol, cols], dtype_hint=tf.float32)
        chol = tf.convert_to_tensor(chol, name='chol', dtype=dtype)
        cols = tf.convert_to_tensor(cols, name='cols', dtype=dtype)
        n = prefer_static.shape(chol)[-1]
        mat_nm, mat_mm = cols[..., :n, :], cols[..., n:, :]
        solved_nm = linear_operator_util.matrix_triangular_solve_with_broadcast(
            chol, mat_nm)
        lower_right_mm = tf.linalg.cholesky(
            mat_mm - tf.matmul(solved_nm, solved_nm, adjoint_a=True))
        lower_left_mn = tf.math.conj(tf.linalg.matrix_transpose(solved_nm))
        out_batch = prefer_static.shape(solved_nm)[:-2]
        chol = tf.broadcast_to(
            chol,
            tf.concat([out_batch, prefer_static.shape(chol)[-2:]], axis=0))
        top_right_zeros_nm = tf.zeros_like(solved_nm)
        return tf.concat([
            tf.concat([chol, top_right_zeros_nm], axis=-1),
            tf.concat([lower_left_mn, lower_right_mm], axis=-1)
        ],
                         axis=-2)
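A hypothetical usage sketch, assuming `cholesky_concat` and its TFP-internal dependencies are in scope: extend the Cholesky factor of the leading 3x3 block of an SPD matrix by one row/column instead of re-factorizing the full 4x4 matrix (values are random and purely illustrative):

```python
import numpy as np
import tensorflow as tf

rng = np.random.default_rng(0)
a = rng.standard_normal((4, 4))
mat = a @ a.T + 4. * np.eye(4)            # a 4x4 SPD matrix
chol3 = tf.linalg.cholesky(mat[:3, :3])   # factor of the leading 3x3 block
cols = tf.constant(mat[:, 3:])            # shape (3 + 1, 1): new column and corner
chol4 = cholesky_concat(chol3, cols)      # factor of all of `mat`
# tf.matmul(chol4, chol4, adjoint_b=True) should be close to `mat`.
```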
Example #9
 def _cdf(self, x):
     # CDF is the probability that the Poisson variable is less than or equal
     # to x.
     # For fractional x, the CDF is equal to the CDF at n = floor(x).
     # For negative x, the CDF is zero, but tf.igammac gives NaNs, so we impute
     # the values and handle this case explicitly.
     safe_x = tf.maximum(x if self.interpolate_nondiscrete else tf.floor(x),
                         0.)
     cdf = tf.math.igammac(1. + safe_x, self._rate_parameter_no_checks())
     return tf.where(x < 0., tf.zeros_like(cdf), cdf)
Example #10
 def _cdf(self, x):
     with tf.control_dependencies(self._maybe_assert_valid_sample(x)):
         probs = self._probs_parameter_no_checks()
         if not self.validate_args:
             # Whether or not x is integer-form, the following is well-defined.
             # However, scipy takes the floor, so we do too.
             x = tf.floor(x)
         return tf.where(x < 0., tf.zeros_like(x), -tf.math.expm1(
             (1. + x) * tf.math.log1p(-probs)))
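The `-expm1((1. + x) * log1p(-probs))` expression is a numerically stable form of the geometric CDF `1 - (1 - p)**(x + 1)`. A quick NumPy check of that identity (values are illustrative):

```python
import numpy as np

p, x = 0.3, 4.
stable = -np.expm1((1. + x) * np.log1p(-p))
direct = 1. - (1. - p) ** (x + 1.)
# stable and direct agree up to floating-point error (~0.83193 here).
```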
Example #11
 def _log_prob(self, x):
     with tf.control_dependencies(self._maybe_assert_valid_sample(x)):
         probs = self._probs_parameter_no_checks()
         if not self.validate_args:
             # For consistency with cdf, we take the floor.
             x = tf.floor(x)
         safe_domain = tf.where(tf.equal(x, 0.), tf.zeros_like(probs),
                                probs)
         return x * tf.math.log1p(-safe_domain) + tf.math.log(probs)
Example #12
 def _prob(self, x):
     low = tf.convert_to_tensor(self.low)
     high = tf.convert_to_tensor(self.high)
     return tf.where(
         tf.math.is_nan(x),
         x,
         tf.where(
             # This > is only sound for continuous uniform
             (x < low) | (x > high),
             tf.zeros_like(x),
             tf.ones_like(x) / self._range(low=low, high=high)))
 def _assertions(self, x):
     if not self.validate_args:
         return []
     shape = tf.shape(x)
     is_matrix = assert_util.assert_rank_at_least(
         x, 2, message="Input must have rank at least 2.")
     is_square = assert_util.assert_equal(
         shape[-2], shape[-1], message="Input must be a square matrix.")
     above_diagonal = tf.linalg.band_part(
          tf.linalg.set_diag(x, tf.zeros(shape[:-1], dtype=x.dtype)), 0,
         -1)
     is_lower_triangular = assert_util.assert_equal(
         above_diagonal,
         tf.zeros_like(above_diagonal),
         message="Input must be lower triangular.")
     # A lower triangular matrix is nonsingular iff all its diagonal entries are
     # nonzero.
     diag_part = tf.linalg.diag_part(x)
     is_nonsingular = assert_util.assert_none_equal(
         diag_part,
         tf.zeros_like(diag_part),
         message="Input must have all diagonal entries nonzero.")
     return [is_matrix, is_square, is_lower_triangular, is_nonsingular]
Example #14
  def _cdf(self, x):
    # CDF(x) at positive integer x is the probability that the Zipf variable is
    # less than or equal to x; given by the formula:
    #     CDF(x) = 1 - (zeta(power, x + 1) / Z)
    # For fractional x, the CDF is equal to the CDF at n = floor(x).
    # For x < 1, the CDF is zero.

    # If interpolate_nondiscrete is True, we return a continuous relaxation
    # which agrees with the CDF at integer points.
    power = tf.convert_to_tensor(self.power)
    x = tf.cast(x, power.dtype)
    safe_x = tf.maximum(x if self.interpolate_nondiscrete else tf.floor(x), 0.)

    cdf = 1. - (
        tf.math.zeta(power, safe_x + 1.) / tf.math.zeta(power, 1.))
    return tf.where(x < 1., tf.zeros_like(cdf), cdf)
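A NumPy/SciPy sketch of the same formula, with `scipy.special.zeta` playing the role of `tf.math.zeta` (values are illustrative; `interpolate_nondiscrete` is taken as False, so `x` is floored):

```python
import numpy as np
from scipy.special import zeta

power = 3.0
x = np.array([0.5, 1.0, 2.7, 10.0])
safe_x = np.maximum(np.floor(x), 0.)
cdf = 1. - zeta(power, safe_x + 1.) / zeta(power, 1.)
cdf = np.where(x < 1., 0., cdf)
```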
Example #15
 def _sample_n(self, n, seed=None):
     concentration = tf.convert_to_tensor(self.concentration)
     mixing_concentration = tf.convert_to_tensor(self.mixing_concentration)
     mixing_rate = tf.convert_to_tensor(self.mixing_rate)
     seed = SeedStream(seed, 'gamma_gamma')
     rate = tf.random.gamma(
         shape=[n],
         # Be sure to draw enough rates for the fully-broadcasted gamma-gamma.
         alpha=mixing_concentration + tf.zeros_like(concentration),
         beta=mixing_rate,
         dtype=self.dtype,
         seed=seed())
     return tf.random.gamma(shape=[],
                            alpha=concentration,
                            beta=rate,
                            dtype=self.dtype,
                            seed=seed())
    def _sample_n(self, n, seed=None):
        # Like with the univariate Student's t, sampling can be implemented as a
        # ratio of samples from a multivariate gaussian with the appropriate
        # covariance matrix and a sample from the chi-squared distribution.
        seed = SeedStream(seed, salt='multivariate t')

        loc = tf.broadcast_to(self.loc, self._sample_shape())
        mvn = mvn_linear_operator.MultivariateNormalLinearOperator(
            loc=tf.zeros_like(loc), scale=self.scale)
        normal_samp = mvn.sample(n, seed=seed())

        df = tf.broadcast_to(self.df, self.batch_shape_tensor())
        chi2 = chi2_lib.Chi2(df=df)
        chi2_samp = chi2.sample(n, seed=seed())

        return (
            self._loc +
            normal_samp * tf.math.rsqrt(chi2_samp / self._df)[..., tf.newaxis])
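A plain NumPy sketch of the same construction for a single 2-dimensional draw (all values are illustrative): draw Z ~ N(0, scale @ scale.T) and V ~ Chi2(df) independently, then loc + Z / sqrt(V / df) is a multivariate Student's t sample.

```python
import numpy as np

rng = np.random.default_rng(0)
df = 5.
loc = np.array([1., -1.])
scale = np.array([[1.0, 0.0],
                  [0.5, 2.0]])          # lower-triangular scale factor
z = scale @ rng.standard_normal(2)      # N(0, scale @ scale.T) draw
v = rng.chisquare(df)                   # Chi2(df) draw
x = loc + z / np.sqrt(v / df)
```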
Example #17
 def _sample_n(self, n, seed=None):
   # Need to create logits corresponding to [p, 1 - p].
   # Note that for this distribution, logits corresponds to
   # inverse sigmoid(p), while in multivariate distributions,
   # such as the multinomial, logits corresponds to log(p).
   # Because of this, when we construct the logits for the multinomial
   # sampler, we'll have to be careful.
   # log(p) = log(sigmoid(logits)) = logits - softplus(logits)
   # log(1 - p) = log(1 - sigmoid(logits)) = -softplus(logits)
   # Because softmax is invariant to a constant shift in all inputs,
   # we can offset the logits by softplus(logits) so that we can use
   # [logits, 0.] as our input.
   orig_logits = self._logits_parameter_no_checks()
   logits = tf.stack([orig_logits, tf.zeros_like(orig_logits)], axis=-1)
   return multinomial.draw_sample(
       num_samples=n,
       num_classes=2,
       logits=logits,
       num_trials=tf.cast(self.total_count, dtype=tf.int32),
       dtype=self.dtype,
       seed=seed)[..., 0]
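A quick NumPy check of the identity the comment above relies on: `softmax([logits, 0])` equals `[p, 1 - p]` with `p = sigmoid(logits)` (the scalar value is illustrative):

```python
import numpy as np

logits = 0.3
p = 1. / (1. + np.exp(-logits))
pair = np.array([logits, 0.])
softmax = np.exp(pair) / np.exp(pair).sum()
# softmax is approximately [p, 1 - p].
```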
 def _sample_3d(self, n, seed=None):
   """Specialized inversion sampler for 3D."""
   seed = SeedStream(seed, salt='von_mises_fisher_3d')
   u_shape = tf.concat([[n], self._batch_shape_tensor()], axis=0)
   z = tf.random.uniform(u_shape, seed=seed(), dtype=self.dtype)
   # TODO(bjp): Higher-order odd dim analytic CDFs are available in [1], could
   # be bisected for bounded sampling runtime (i.e. not rejection sampling).
   # [1]: Inversion sampler via: https://ieeexplore.ieee.org/document/7347705/
   # The inversion is: u = 1 + log(z + (1-z)*exp(-2*kappa)) / kappa
   # We must protect against both kappa and z being zero.
   safe_conc = tf.where(self.concentration > 0, self.concentration,
                        tf.ones_like(self.concentration))
   safe_z = tf.where(z > 0, z, tf.ones_like(z))
   safe_u = 1 + tf.reduce_logsumexp(
       [tf.math.log(safe_z),
        tf.math.log1p(-safe_z) - 2 * safe_conc], axis=0) / safe_conc
   # Limit of the above expression as kappa->0 is 2*z-1
   u = tf.where(self.concentration > tf.zeros_like(safe_u), safe_u, 2 * z - 1)
   # Limit of the expression as z->0 is -1.
   u = tf.where(tf.equal(z, 0), -tf.ones_like(u), u)
   if not self._allow_nan_stats:
     u = tf.debugging.check_numerics(u, 'u in _sample_3d')
   return u[..., tf.newaxis]
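A NumPy sketch of the inversion formula from the comment above, without the zero-`kappa` / zero-`z` guards (values are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
kappa = 2.0
z = rng.uniform(size=5)
u = 1. + np.log(z + (1. - z) * np.exp(-2. * kappa)) / kappa
# As kappa -> 0 this tends to 2 * z - 1, the case handled by tf.where above.
```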
Example #19
    def _cdf(self, x):
        low = tf.convert_to_tensor(self.low)
        high = tf.convert_to_tensor(self.high)
        peak = tf.convert_to_tensor(self.peak)

        interval_length = high - low
        # Due to the PDF being not smooth at the peak, we have to treat each side
        # somewhat differently. The PDF is two line segments, and thus we get
        # quadratics here for the CDF.
        result_inside_interval = tf.where(
            (x >= low) & (x <= peak),
            # (x - low) ** 2 / ((high - low) * (peak - low))
            tf.math.squared_difference(x, low) / (interval_length *
                                                  (peak - low)),
            # 1 - (high - x) ** 2 / ((high - low) * (high - peak))
            1. - tf.math.squared_difference(high, x) / (interval_length *
                                                        (high - peak)))

        # We now add that the left tail is 0 and the right tail is 1.
        result_if_not_big = tf.where(x < low, tf.zeros_like(x),
                                     result_inside_interval)

        return tf.where(x >= high, tf.ones_like(x), result_if_not_big)
Example #20
    def _prob(self, x):
        low = tf.convert_to_tensor(self.low)
        high = tf.convert_to_tensor(self.high)
        peak = tf.convert_to_tensor(self.peak)

        if self.validate_args:
            with tf.control_dependencies([
                    assert_util.assert_greater_equal(x, low),
                    assert_util.assert_less_equal(x, high)
            ]):
                x = tf.identity(x)

        interval_length = high - low
        # This is the pdf when low <= x <= high. This looks like
        # a triangle, so we have to treat each line segment separately.
        result_inside_interval = tf.where(
            (x >= low) & (x <= peak),
            # Line segment from (low, 0) to (peak, 2 / (high - low)).
            2. * (x - low) / (interval_length * (peak - low)),
            # Line segment from (peak, 2 / (high - low)) to (high, 0).
            2. * (high - x) / (interval_length * (high - peak)))

        return tf.where((x < low) | (x > high), tf.zeros_like(x),
                        result_inside_interval)
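A NumPy sketch of the same two line segments (endpoints are illustrative; the density peaks at height 2 / (high - low) and integrates to 1):

```python
import numpy as np

low, peak, high = 0., 1., 4.
x = np.linspace(low, high, 201)
pdf = np.where(
    (x >= low) & (x <= peak),
    2. * (x - low) / ((high - low) * (peak - low)),
    2. * (high - x) / ((high - low) * (high - peak)))
# np.trapz(pdf, x) is approximately 1.
```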
 def _covariance(self):
   # Derivation: https://sachinruk.github.io/blog/von-Mises-Fisher/
   event_dim = tf.compat.dimension_value(self.event_shape[0])
   if event_dim is None:
     raise ValueError('event shape must be statically known for _bessel_ive')
   # TODO(bjp): Enable this; numerically unstable.
   if event_dim > 2:
     raise ValueError('vMF covariance is numerically unstable for dim>2')
   concentration = self.concentration[..., tf.newaxis]
   safe_conc = tf.where(concentration > 0, concentration,
                        tf.ones_like(concentration))
   h = (_bessel_ive(event_dim / 2, safe_conc) /
        _bessel_ive(event_dim / 2 - 1, safe_conc))
   intermediate = (
       tf.matmul(self.mean_direction[..., :, tf.newaxis],
                 self.mean_direction[..., tf.newaxis, :]) *
       (1 - event_dim * h / safe_conc - h**2)[..., tf.newaxis])
   cov = tf.linalg.set_diag(
       intermediate,
       tf.linalg.diag_part(intermediate) + (h / safe_conc))
   return tf.where(
       concentration[..., tf.newaxis] > tf.zeros_like(cov), cov,
       tf.linalg.eye(event_dim, batch_shape=self.batch_shape_tensor()) /
       event_dim)
Example #22
def pivoted_cholesky(matrix, max_rank, diag_rtol=1e-3, name=None):
    """Computes the (partial) pivoted cholesky decomposition of `matrix`.

  The pivoted Cholesky is a low rank approximation of the Cholesky decomposition
  of `matrix`, i.e. as described in [(Harbrecht et al., 2012)][1]. The
  currently-worst-approximated diagonal element is selected as the pivot at each
  iteration. This yields from a `[B1...Bn, N, N]` shaped `matrix` a `[B1...Bn,
  N, K]` shaped rank-`K` approximation `lr` such that `lr @ lr.T ~= matrix`.
  Note that, unlike the Cholesky decomposition, `lr` is not triangular even in
  a rectangular-matrix sense. However, under a permutation it could be made
  triangular (it has one more zero in each column as you move to the right).

  Such a matrix can be useful as a preconditioner for conjugate gradient
  optimization, i.e. as in [(Wang et al. 2019)][2], as matmuls and solves can be
  cheaply done via the Woodbury matrix identity, as implemented by
  `tf.linalg.LinearOperatorLowRankUpdate`.

  Args:
    matrix: Floating point `Tensor` batch of symmetric, positive definite
      matrices.
    max_rank: Scalar `int` `Tensor`, the rank at which to truncate the
      approximation.
    diag_rtol: Scalar floating point `Tensor` (same dtype as `matrix`). If the
      errors of all diagonal elements of `lr @ lr.T` are each lower than
      `element * diag_rtol`, iteration is permitted to terminate early.
    name: Optional name for the op.

  Returns:
    lr: Low rank pivoted Cholesky approximation of `matrix`.

  #### References

  [1]: H Harbrecht, M Peters, R Schneider. On the low-rank approximation by the
       pivoted Cholesky decomposition. _Applied numerical mathematics_,
       62(4):428-440, 2012.

  [2]: K. A. Wang et al. Exact Gaussian Processes on a Million Data Points.
       _arXiv preprint arXiv:1903.08114_, 2019. https://arxiv.org/abs/1903.08114
  """
    with tf.name_scope(name or 'pivoted_cholesky'):
        dtype = dtype_util.common_dtype([matrix, diag_rtol],
                                        dtype_hint=tf.float32)
        matrix = tf.convert_to_tensor(matrix, name='matrix', dtype=dtype)
        if tensorshape_util.rank(matrix.shape) is None:
            raise NotImplementedError(
                'Rank of `matrix` must be known statically')

        max_rank = tf.convert_to_tensor(max_rank,
                                        name='max_rank',
                                        dtype=tf.int64)
        max_rank = tf.minimum(
            max_rank,
            prefer_static.shape(matrix, out_type=tf.int64)[-1])
        diag_rtol = tf.convert_to_tensor(diag_rtol,
                                         dtype=dtype,
                                         name='diag_rtol')
        matrix_diag = tf.linalg.diag_part(matrix)
        # matrix is P.D., therefore all matrix_diag > 0, so we don't need abs.
        orig_error = tf.reduce_max(matrix_diag, axis=-1)

        def cond(m, pchol, perm, matrix_diag):
            """Condition for `tf.while_loop` continuation."""
            del pchol
            del perm
            error = tf.linalg.norm(matrix_diag, ord=1, axis=-1)
            max_err = tf.reduce_max(error / orig_error)
            return (m < max_rank) & (tf.equal(m, 0) | (max_err > diag_rtol))

        batch_dims = tensorshape_util.rank(matrix.shape) - 2

        def batch_gather(params, indices, axis=-1):
            return tf.gather(params, indices, axis=axis, batch_dims=batch_dims)

        def body(m, pchol, perm, matrix_diag):
            """Body of a single `tf.while_loop` iteration."""
            # Here is roughly a numpy, non-batched version of what's going to happen.
            # (See also Algorithm 1 of Harbrecht et al.)
            # 1: maxi = np.argmax(matrix_diag[perm[m:]]) + m
            # 2: maxval = matrix_diag[perm][maxi]
            # 3: perm[m], perm[maxi] = perm[maxi], perm[m]
            # 4: row = matrix[perm[m]][perm[m + 1:]]
            # 5: row -= np.sum(pchol[:m][perm[m + 1:]] * pchol[:m][perm[m]], axis=-2)
            # 6: pivot = np.sqrt(maxval); row /= pivot
            # 7: row = np.concatenate([[[pivot]], row], -1)
            # 8: matrix_diag[perm[m:]] -= row**2
            # 9: pchol[m, perm[m:]] = row

            # Find the maximal position of the (remaining) permuted diagonal.
            # Steps 1, 2 above.
            permuted_diag = batch_gather(matrix_diag, perm[..., m:])
            maxi = tf.argmax(permuted_diag, axis=-1,
                             output_type=tf.int64)[..., tf.newaxis]
            maxval = batch_gather(permuted_diag, maxi)
            maxi = maxi + m
            maxval = maxval[..., 0]
            # Update perm: Swap perm[...,m] with perm[...,maxi]. Step 3 above.
            perm = _swap_m_with_i(perm, m, maxi)
            # Step 4.
            row = batch_gather(matrix, perm[..., m:m + 1], axis=-2)
            row = batch_gather(row, perm[..., m + 1:])
            # Step 5.
            prev_rows = pchol[..., :m, :]
            prev_rows_perm_m_onward = batch_gather(prev_rows, perm[...,
                                                                   m + 1:])
            prev_rows_pivot_col = batch_gather(prev_rows, perm[..., m:m + 1])
            row -= tf.reduce_sum(prev_rows_perm_m_onward * prev_rows_pivot_col,
                                 axis=-2)[..., tf.newaxis, :]
            # Step 6.
            pivot = tf.sqrt(maxval)[..., tf.newaxis, tf.newaxis]
            # Step 7.
            row = tf.concat([pivot, row / pivot], axis=-1)
            # TODO(b/130899118): Pad grad fails with int64 paddings.
            # Step 8.
            paddings = tf.concat([
                tf.zeros([prefer_static.rank(pchol) - 1, 2], dtype=tf.int32),
                [[tf.cast(m, tf.int32), 0]]
            ],
                                 axis=0)
            diag_update = tf.pad(row**2, paddings=paddings)[..., 0, :]
            reverse_perm = _invert_permutation(perm)
            matrix_diag -= batch_gather(diag_update, reverse_perm)
            # Step 9.
            row = tf.pad(row, paddings=paddings)
            # TODO(bjp): Defer the reverse permutation all-at-once at the end?
            row = batch_gather(row, reverse_perm)
            pchol_shape = pchol.shape
            pchol = tf.concat([pchol[..., :m, :], row, pchol[..., m + 1:, :]],
                              axis=-2)
            tensorshape_util.set_shape(pchol, pchol_shape)
            return m + 1, pchol, perm, matrix_diag

        m = np.int64(0)
        pchol = tf.zeros_like(matrix[..., :max_rank, :])
        matrix_shape = prefer_static.shape(matrix, out_type=tf.int64)
        perm = tf.broadcast_to(prefer_static.range(matrix_shape[-1]),
                               matrix_shape[:-1])
        _, pchol, _, _ = tf.while_loop(cond=cond,
                                       body=body,
                                       loop_vars=(m, pchol, perm, matrix_diag))
        pchol = tf.linalg.matrix_transpose(pchol)
        tensorshape_util.set_shape(
            pchol, tensorshape_util.concatenate(matrix_diag.shape, [None]))
        return pchol
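A non-batched NumPy sketch that follows the numbered steps in the comments of `body` above (no `diag_rtol` early termination; it assumes `matrix` is SPD with rank at least `max_rank`):

```python
import numpy as np

def pivoted_cholesky_np(matrix, max_rank):
    n = matrix.shape[-1]
    diag = np.diag(matrix).copy()
    perm = np.arange(n)
    pchol = np.zeros((max_rank, n))
    for m in range(max_rank):
        maxi = np.argmax(diag[perm[m:]]) + m                         # steps 1-2
        perm[m], perm[maxi] = perm[maxi], perm[m]                    # step 3
        row = matrix[perm[m], perm[m + 1:]]                          # step 4
        row = row - pchol[:m, perm[m + 1:]].T @ pchol[:m, perm[m]]   # step 5
        pivot = np.sqrt(diag[perm[m]])                               # step 6
        row = np.concatenate([[pivot], row / pivot])                 # step 7
        diag[perm[m:]] -= row ** 2                                   # step 8
        pchol[m, perm[m:]] = row                                     # step 9
    return pchol.T  # low-rank factor lr of shape [n, max_rank]

# lr = pivoted_cholesky_np(spd_matrix, k) gives lr @ lr.T ~= spd_matrix.
```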
 def _mean(self):
   # The shape is broadcast by adding tf.zeros_like(self.concentration).
   return self.loc + tf.zeros_like(self.concentration)
Example #24
 def _variance(self):
     return tf.zeros_like(self.loc)
Example #25
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Takes one step of the TransitionKernel.
    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s).
      previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
        `list` of `Tensor`s representing internal calculations made within the
        previous call to this function (or as returned by `bootstrap_results`).
      seed: Optional, a seed for reproducible sampling.
    Returns:
      next_state: `Tensor` or Python `list` of `Tensor`s representing the
        next state(s) of the Markov chain(s).
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.
        This includes replica states.
    """

        with tf.name_scope(mcmc_util.make_name(self.name, 'tmc', 'one_step')):
            # Force a read in case the `inverse_temperatures` is a `tf.Variable`.
            inverse_temperatures = tf.convert_to_tensor(
                previous_kernel_results.post_tempering_inverse_temperatures,
                name='inverse_temperatures')

            steps_at_temperature = tf.convert_to_tensor(
                previous_kernel_results.steps_at_temperature,
                name='number of steps')

            target_score_for_inner_kernel = partial(self.target_score_fn,
                                                    sigma=inverse_temperatures)
            target_log_prob_for_inner_kernel = partial(
                self.target_log_prob_fn, sigma=inverse_temperatures)

            try:
                inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                    target_log_prob_for_inner_kernel,
                    target_score_for_inner_kernel, inverse_temperatures)
            except TypeError as e:
                if 'argument' not in str(e):
                    raise
                warnings.warn(
                    'The `seed` argument to `ReplicaExchangeMC`s `make_kernel_fn` is '
                    'deprecated. `TransitionKernel` instances now receive seeds via '
                    '`one_step`.')
                inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                    target_log_prob_for_inner_kernel,
                    target_score_for_inner_kernel, inverse_temperatures,
                    self._seed_stream())

            if seed is not None:
                seed = samplers.sanitize_seed(seed)
                inner_seed, swap_seed, logu_seed = samplers.split_seed(
                    seed, n=3, salt='tmc_one_step')
                inner_kwargs = dict(seed=inner_seed)
            else:
                if self._seed_stream.original_seed is not None:
                    warnings.warn(mcmc_util.SEED_CTOR_ARG_DEPRECATION_MSG)
                inner_kwargs = {}
                swap_seed, logu_seed = samplers.split_seed(self._seed_stream())

            if mcmc_util.is_list_like(current_state):
                # We *always* canonicalize the states in the kernel results.
                states = current_state
            else:
                states = [current_state]
            [
                new_state,
                pre_tempering_results,
            ] = inner_kernel.one_step(
                states, previous_kernel_results.post_tempering_results,
                **inner_kwargs)

            # Now that we have run one step, consider lowering the temperature.
            # Proposed new inverse temperatures:
            proposed_inverse_temperatures = tf.clip_by_value(
                self.gamma * inverse_temperatures, self.min_temp, 1e6)
            dtype = inverse_temperatures.dtype

            # We will lower the temperature if this new proposed step is compatible with
            # a temperature swap
            v = new_state[0] - states[0]
            cs = states[0]

            @jax.vmap
            def integrand(t):
                return jnp.sum(self._parameters['target_score_fn'](
                    t * v + cs, inverse_temperatures) * v,
                               axis=-1)

            delta_logp1 = simps(integrand, 0., 1.,
                                self._parameters['num_delta_logp_steps'])

            # Now we compute the reverse
            v = -v
            cs = new_state[0]

            @jax.vmap
            def integrand(t):
                return jnp.sum(self._parameters['target_score_fn'](
                    t * v + cs, proposed_inverse_temperatures) * v,
                               axis=-1)

            delta_logp2 = simps(integrand, 0., 1.,
                                self._parameters['num_delta_logp_steps'])

            log_accept_ratio = (delta_logp1 + delta_logp2)

            log_accept_ratio = tf.where(tf.math.is_finite(log_accept_ratio),
                                        log_accept_ratio,
                                        tf.constant(-np.inf, dtype=dtype))

            # Produce Log[Uniform] draws that are identical at swapped indices.
            log_uniform = tf.math.log(
                samplers.uniform(shape=log_accept_ratio.shape,
                                 dtype=dtype,
                                 seed=logu_seed))

            is_tempering_accepted_mask = tf.less(
                log_uniform,
                log_accept_ratio,
                name='is_tempering_accepted_mask')

            is_min_steps_satisfied = tf.greater(
                steps_at_temperature,
                self.min_steps_per_temp * tf.ones_like(steps_at_temperature),
                name='is_min_steps_satisfied')

            # Only propose tempering if the chain was going to accept this point anyway
            is_tempering_accepted_mask = tf.math.logical_and(
                is_tempering_accepted_mask, pre_tempering_results.is_accepted)

            is_tempering_accepted_mask = tf.math.logical_and(
                is_tempering_accepted_mask, is_min_steps_satisfied)

            # Updating accepted inverse temperatures
            post_tempering_inverse_temperatures = mcmc_util.choose(
                is_tempering_accepted_mask, proposed_inverse_temperatures,
                inverse_temperatures)

            steps_at_temperature = mcmc_util.choose(
                is_tempering_accepted_mask,
                tf.zeros_like(steps_at_temperature), steps_at_temperature + 1)

            # Invalidating and recomputing results
            [
                new_target_log_prob,
                new_grads_target_log_prob,
            ] = mcmc_util.maybe_call_fn_and_grads(
                partial(self.target_log_prob_fn,
                        sigma=post_tempering_inverse_temperatures), new_state)

            # Updating inner kernel results
            post_tempering_results = pre_tempering_results._replace(
                proposed_results=tf.convert_to_tensor(np.nan, dtype=dtype),
                proposed_state=tf.convert_to_tensor(np.nan, dtype=dtype),
            )

            if isinstance(post_tempering_results.accepted_results,
                          hmc.UncalibratedHamiltonianMonteCarloKernelResults):
                post_tempering_results = post_tempering_results._replace(
                    accepted_results=post_tempering_results.accepted_results.
                    _replace(target_log_prob=new_target_log_prob,
                             grads_target_log_prob=new_grads_target_log_prob))
            elif isinstance(
                    post_tempering_results.accepted_results,
                    random_walk_metropolis.UncalibratedRandomWalkResults):
                post_tempering_results = post_tempering_results._replace(
                    accepted_results=post_tempering_results.accepted_results.
                    _replace(target_log_prob=new_target_log_prob))
            else:
                # TODO(b/143702650) Handle other kernels.
                raise NotImplementedError(
                    'Only HMC and RWMH Kernels are handled at this time. Please file a '
                    'request with the TensorFlow Probability team.')

            new_kernel_results = TemperedMCKernelResults(
                pre_tempering_results=pre_tempering_results,
                post_tempering_results=post_tempering_results,
                pre_tempering_inverse_temperatures=inverse_temperatures,
                post_tempering_inverse_temperatures=
                post_tempering_inverse_temperatures,
                tempering_log_accept_ratio=log_accept_ratio,
                steps_at_temperature=steps_at_temperature,
                seed=samplers.zeros_seed() if seed is None else seed,
            )

            return new_state[0], new_kernel_results
Example #26
    def bootstrap_results(self, init_state):
        """Returns an object with the same type as returned by `one_step`.
    Args:
      init_state: `Tensor` or Python `list` of `Tensor`s representing the
        initial state(s) of the Markov chain(s).
    Returns:
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.
        This includes replica states.
    """
        with tf.name_scope(
                mcmc_util.make_name(self.name, 'tmc', 'bootstrap_results')):
            init_state, unused_is_multipart_state = mcmc_util.prepare_state_parts(
                init_state)

            inverse_temperatures = tf.convert_to_tensor(
                self.inverse_temperatures, name='inverse_temperatures')

            target_score_for_inner_kernel = partial(self.target_score_fn,
                                                    sigma=inverse_temperatures)
            target_log_prob_for_inner_kernel = partial(
                self.target_log_prob_fn, sigma=inverse_temperatures)

            # Seed handling complexity is due to users possibly expecting an old-style
            # stateful seed to be passed to `self.make_kernel_fn`.
            # In other words:
            # - We try `make_kernel_fn` without a seed first; this is the future. The
            #   kernel will receive a seed later, as part of `one_step`.
            # - If the user code doesn't like that (Python complains about a missing
            #   required argument), we fall back to the previous behavior and warn.
            try:
                inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                    target_log_prob_for_inner_kernel,
                    target_score_for_inner_kernel, inverse_temperatures)
            except TypeError as e:
                if 'argument' not in str(e):
                    raise
                warnings.warn(
                    'The second (`seed`) argument to `ReplicaExchangeMC`s '
                    '`make_kernel_fn` is deprecated. `TransitionKernel` instances now '
                    'receive seeds via `bootstrap_results` and `one_step`. This '
                    'fallback may become an error 2020-09-20.')
                inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                    target_log_prob_for_inner_kernel,
                    target_score_for_inner_kernel, inverse_temperatures,
                    self._seed_stream())

            inner_results = inner_kernel.bootstrap_results(init_state)
            post_tempering_results = inner_results

            # Invalidating and recomputing results
            [
                new_target_log_prob,
                new_grads_target_log_prob,
            ] = mcmc_util.maybe_call_fn_and_grads(
                partial(self.target_log_prob_fn, sigma=inverse_temperatures),
                init_state)

            # Updating inner kernel results
            dtype = inverse_temperatures.dtype
            post_tempering_results = post_tempering_results._replace(
                proposed_results=tf.convert_to_tensor(np.nan, dtype=dtype),
                proposed_state=tf.convert_to_tensor(np.nan, dtype=dtype),
            )

            if isinstance(post_tempering_results.accepted_results,
                          hmc.UncalibratedHamiltonianMonteCarloKernelResults):
                post_tempering_results = post_tempering_results._replace(
                    accepted_results=post_tempering_results.accepted_results.
                    _replace(target_log_prob=new_target_log_prob,
                             grads_target_log_prob=new_grads_target_log_prob))
            elif isinstance(
                    post_tempering_results.accepted_results,
                    random_walk_metropolis.UncalibratedRandomWalkResults):
                post_tempering_results = post_tempering_results._replace(
                    accepted_results=post_tempering_results.accepted_results.
                    _replace(target_log_prob=new_target_log_prob))
            else:
                # TODO(b/143702650) Handle other kernels.
                raise NotImplementedError(
                    'Only HMC and RWMH Kernels are handled at this time. Please file a '
                    'request with the TensorFlow Probability team.')

            return TemperedMCKernelResults(
                pre_tempering_results=inner_results,
                post_tempering_results=post_tempering_results,
                pre_tempering_inverse_temperatures=inverse_temperatures,
                post_tempering_inverse_temperatures=inverse_temperatures,
                tempering_log_accept_ratio=tf.zeros_like(inverse_temperatures),
                steps_at_temperature=tf.zeros_like(inverse_temperatures,
                                                   dtype=tf.int32),
                seed=samplers.zeros_seed(),
            )
 def _log_unnormalized_prob(self, samples):
   samples = self._maybe_assert_valid_sample(samples)
   bcast_mean_dir = (self.mean_direction +
                     tf.zeros_like(self.concentration)[..., tf.newaxis])
   inner_product = tf.reduce_sum(samples * bcast_mean_dir, axis=-1)
   return self.concentration * inner_product
Example #28
    def _sample_n(self, num_samples, seed=None, name=None):
        """Returns a Tensor of samples from an LKJ distribution.

    Args:
      num_samples: Python `int`. The number of samples to draw.
      seed: Python integer seed for RNG
      name: Python `str` name prefixed to Ops created by this function.

    Returns:
      samples: A Tensor of correlation matrices with shape `[n, B, D, D]`,
        where `B` is the shape of the `concentration` parameter, and `D`
        is the `dimension`.

    Raises:
      ValueError: If `dimension` is negative.
    """
        if self.dimension < 0:
            raise ValueError(
                'Cannot sample negative-dimension correlation matrices.')
        # Notation below: B is the batch shape, i.e., tf.shape(concentration)
        seed = SeedStream(seed, 'sample_lkj')
        with tf.name_scope(name or 'sample_lkj'):
            concentration = tf.convert_to_tensor(self.concentration)
            if not dtype_util.is_floating(concentration.dtype):
                raise TypeError(
                    'The concentration argument should have floating type, not '
                    '{}'.format(dtype_util.name(concentration.dtype)))

            concentration = _replicate(num_samples, concentration)
            concentration_shape = tf.shape(concentration)
            if self.dimension <= 1:
                # For any dimension <= 1, there is only one possible correlation matrix.
                shape = tf.concat(
                    [concentration_shape, [self.dimension, self.dimension]],
                    axis=0)
                return tf.ones(shape=shape, dtype=concentration.dtype)
            beta_conc = concentration + (self.dimension - 2.) / 2.
            beta_dist = beta.Beta(concentration1=beta_conc,
                                  concentration0=beta_conc)

            # Note that the sampler below deviates from [1], by doing the sampling in
            # cholesky space. This does not change the fundamental logic of the
            # sampler, but does speed up the sampling.

            # This is the correlation coefficient between the first two dimensions.
            # This is also `r` in reference [1].
            corr12 = 2. * beta_dist.sample(seed=seed()) - 1.

            # Below we construct the Cholesky of the initial 2x2 correlation matrix,
            # which is of the form:
            # [[1, 0], [r, sqrt(1 - r**2)]], where r is the correlation between the
            # first two dimensions.
            # This is the top-left corner of the cholesky of the final sample.
            first_row = tf.concat([
                tf.ones_like(corr12)[..., tf.newaxis],
                tf.zeros_like(corr12)[..., tf.newaxis]
            ],
                                  axis=-1)
            second_row = tf.concat([
                corr12[..., tf.newaxis],
                tf.sqrt(1 - corr12**2)[..., tf.newaxis]
            ],
                                   axis=-1)

            chol_result = tf.concat([
                first_row[..., tf.newaxis, :], second_row[..., tf.newaxis, :]
            ],
                                    axis=-2)

            for n in range(2, self.dimension):
                # Loop invariant: on entry, result has shape B + [n, n]
                beta_conc = beta_conc - 0.5
                # norm is y in reference [1].
                norm = beta.Beta(concentration1=n / 2.,
                                 concentration0=beta_conc).sample(seed=seed())
                # distance shape: B + [1] for broadcast
                distance = tf.sqrt(norm)[..., tf.newaxis]
                # direction is u in reference [1].
                # direction shape: B + [n]
                direction = _uniform_unit_norm(n, concentration_shape,
                                               concentration.dtype, seed)
                # raw_correlation is w in reference [1].
                raw_correlation = distance * direction  # shape: B + [n]

                # This is the next row in the cholesky of the result,
                # which differs from the construction in reference [1].
                # In the reference, the new row `z` = chol_result @ raw_correlation^T
                # = C @ raw_correlation^T (where, as shorthand, we use C = chol_result).
                # We prove that the below equation is the right row to add to the
                # cholesky, by showing equality with reference [1].
                # Let S be the sample constructed so far, and let `z` be as in
                # reference [1]. Then at this iteration, the new sample S' will be
                # [[S z^T]
                #  [z 1]]
                # In our case we have the cholesky decomposition factor C, so
                # we want our new row x (same size as z) to satisfy:
                #  [[S z^T]     [[C 0]     [[C^T  x^T]      [[CC^T  Cx^T       ]
                #   [z  1 ]]  =  [x k]]  @  [0     k ]]  =   [xC^T  xx^T + k**2]]
                # Since C @ raw_correlation^T = z = C @ x^T, and C is invertible,
                # we have that x = raw_correlation. Also 1 = xx^T + k**2, so k
                # = sqrt(1 - xx^T) = sqrt(1 - |raw_correlation|**2) = sqrt(1 -
                # distance**2).
                new_row = tf.concat(
                    [raw_correlation,
                     tf.sqrt(1. - norm[..., tf.newaxis])],
                    axis=-1)

                # Finally add this new row, by growing the cholesky of the result.
                chol_result = tf.concat([
                    chol_result,
                    tf.zeros_like(chol_result[..., 0][..., tf.newaxis])
                ],
                                        axis=-1)

                chol_result = tf.concat(
                    [chol_result, new_row[..., tf.newaxis, :]], axis=-2)

            if self.input_output_cholesky:
                return chol_result

            result = tf.matmul(chol_result, chol_result, transpose_b=True)
            # The diagonal for a correlation matrix should always be ones. Due to
            # numerical instability the matmul might not achieve that, so manually set
            # these to ones.
            result = tf.linalg.set_diag(
                result, tf.ones(shape=tf.shape(result)[:-1],
                                dtype=result.dtype))
            # This sampling algorithm can produce near-PSD matrices on which standard
            # algorithms such as `tf.cholesky` or `tf.linalg.self_adjoint_eigvals`
            # fail. Specifically, as documented in b/116828694, around 2% of trials
            # of 900,000 5x5 matrices (distributed according to 9 different
            # concentration parameter values) contained at least one matrix on which
            # the Cholesky decomposition failed.
            return result
 def _mode(self):
   """The mode of the von Mises-Fisher distribution is the mean direction."""
   return (self.mean_direction +
           tf.zeros_like(self.concentration)[..., tf.newaxis])
Example #30
 def _create_polynomial(var, coeffs):
     """Compute n_th order polynomial via Horner's method."""
     coeffs = np.array(coeffs, dtype_util.as_numpy_dtype(var.dtype))
     if not coeffs.size:
         return tf.zeros_like(var)
     return coeffs[0] + _create_polynomial(var, coeffs[1:]) * var
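An illustrative call, assuming the helper and its `dtype_util` dependency are in scope: Horner evaluation of 2 + 3x + 4x^2 at x = 1 and x = 2.

```python
import tensorflow as tf

x = tf.constant([1., 2.])
poly = _create_polynomial(x, [2., 3., 4.])  # -> [9., 24.]
```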