Example #1
    def _log_prob(self, x):
        logits = self._logits_parameter_no_checks()
        event_size = self._event_size(logits)

        x = tf.cast(x, logits.dtype)
        x = self._maybe_assert_valid_sample(x, dtype=logits.dtype)

        # broadcast logits or x if need be.
        if (not tensorshape_util.is_fully_defined(x.shape)
                or not tensorshape_util.is_fully_defined(logits.shape)
                or x.shape != logits.shape):
            broadcast_shape = tf.broadcast_dynamic_shape(
                tf.shape(logits), tf.shape(x))
            logits = tf.broadcast_to(logits, broadcast_shape)
            x = tf.broadcast_to(x, broadcast_shape)

        logits_shape = tf.shape(tf.reduce_sum(logits, axis=-1))
        logits_2d = tf.reshape(logits, [-1, event_size])
        x_2d = tf.reshape(x, [-1, event_size])
        ret = -tf.nn.softmax_cross_entropy_with_logits(
            labels=tf.stop_gradient(x_2d), logits=logits_2d)

        # Reshape back to user-supplied batch and sample dims prior to 2D reshape.
        ret = tf.reshape(ret, logits_shape)
        return ret
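The `-tf.nn.softmax_cross_entropy_with_logits` trick above is just a vectorized way of reading off the log-softmax entry selected by a one-hot sample. A minimal sketch of that equivalence (the tensors here are illustrative, not from the original):

```python
import tensorflow as tf

logits = tf.constant([[2.0, 0.5, -1.0],
                      [0.1, 0.2, 0.3]])
x = tf.one_hot([0, 2], depth=3)  # one-hot "samples"

# Negative softmax cross entropy with one-hot labels ...
via_xent = -tf.nn.softmax_cross_entropy_with_logits(labels=x, logits=logits)
# ... equals the log-softmax entry that the one-hot vector selects.
via_log_softmax = tf.reduce_sum(x * tf.math.log_softmax(logits, axis=-1), axis=-1)

tf.debugging.assert_near(via_xent, via_log_softmax)
```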
Example #2
 def _cdf(self, x):
     with tf.control_dependencies(self._maybe_assert_valid_sample(x)):
         concentration1 = tf.convert_to_tensor(self.concentration1)
         concentration0 = tf.convert_to_tensor(self.concentration0)
         shape = self._batch_shape_tensor(concentration1, concentration0)
         concentration1 = tf.broadcast_to(concentration1, shape)
         concentration0 = tf.broadcast_to(concentration0, shape)
         return tf.math.betainc(concentration1, concentration0, x)
Example #3
 def _cdf(self, x):
     x = self._maybe_assert_valid_sample(x)
     logits = self._logits_parameter_no_checks()
     total_count = tf.convert_to_tensor(self.total_count)
     shape = self._batch_shape_tensor(logits_or_probs=logits,
                                      total_count=total_count)
     return tf.math.betainc(tf.broadcast_to(total_count, shape),
                            tf.broadcast_to(1. + x, shape),
                            tf.broadcast_to(tf.sigmoid(-logits), shape))
Example #4
    def _variance(self):
        with tf.control_dependencies(self._runtime_assertions):
            probs = self._marginal_hidden_probs()
            # probs :: num_steps batch_shape num_states
            means = self._observation_distribution.mean()
            # means :: observation_batch_shape[:-1] num_states
            #          observation_event_shape
            means_shape = tf.concat([
                self.batch_shape_tensor(), [self._num_states],
                self._observation_distribution.event_shape_tensor()
            ], axis=0)
            means = tf.broadcast_to(means, means_shape)
            # means :: batch_shape num_states observation_event_shape

            observation_event_shape = (
                self._observation_distribution.event_shape_tensor())
            batch_size = tf.reduce_prod(self.batch_shape_tensor())
            flat_probs_shape = [self._num_steps, batch_size, self._num_states]
            flat_means_shape = [
                batch_size, 1, self._num_states,
                tf.reduce_prod(observation_event_shape)
            ]

            flat_probs = tf.reshape(probs, flat_probs_shape)
            # flat_probs :: num_steps batch_size num_states
            flat_means = tf.reshape(means, flat_means_shape)
            # flat_means :: batch_size 1 num_states observation_event_size
            flat_mean = tf.einsum("ijk,jmkl->jiml", flat_probs, flat_means)
            # flat_mean :: batch_size num_steps 1 observation_event_size

            variances = self._observation_distribution.variance()
            variances = tf.broadcast_to(variances, means_shape)
            # variances :: batch_shape num_states observation_event_shape
            flat_variances = tf.reshape(variances, flat_means_shape)
            # flat_variances :: batch_size 1 num_states observation_event_size

            # For a mixture of n distributions with mixture probabilities
            # p[i], and where the individual distributions have means and
            # variances given by mean[i] and var[i], the variance of
            # the mixture is given by:
            #
            # var = sum i=1..n p[i] * ((mean[i] - mean)**2 + var[i])

            flat_variance = tf.einsum("ijk,jikl->jil", flat_probs,
                                      (flat_means - flat_mean)**2 +
                                      flat_variances)
            # flat_variance :: batch_size num_steps observation_event_size

            unflat_mean_shape = tf.concat([
                self.batch_shape_tensor(), [self._num_steps],
                observation_event_shape
            ], axis=0)

            # returns :: batch_shape num_steps observation_event_shape
            return tf.reshape(flat_variance, unflat_mean_shape)
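The comment above states the mixture-variance identity that the final `einsum` implements. A quick NumPy check of that identity for a two-component Gaussian mixture (all numbers illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
p = np.array([0.3, 0.7])       # mixture weights
mean = np.array([-1.0, 2.0])   # component means
var = np.array([0.5, 1.5])     # component variances

# Identity used above: var = sum_i p[i] * ((mean[i] - mixture_mean)**2 + var[i])
mixture_mean = np.sum(p * mean)
mixture_var = np.sum(p * ((mean - mixture_mean) ** 2 + var))

# Monte Carlo sanity check by sampling the mixture.
comps = rng.choice(2, size=200_000, p=p)
samples = rng.normal(mean[comps], np.sqrt(var[comps]))
assert np.isclose(samples.var(), mixture_var, rtol=0.05)
```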
Example #5
    def _log_prob(self, value):
        with tf.control_dependencies(self._runtime_assertions):
            # The argument `value` is a tensor of sequences of observations.
            # `observation_batch_shape` is the shape of that tensor with the
            # sequence part removed.
            # `observation_batch_shape` is then broadcast to the full batch shape
            # to give the `batch_shape` that defines the shape of the result.

            observation_tensor_shape = tf.shape(value)
            observation_batch_shape = observation_tensor_shape[
                :-1 - self._underlying_event_rank]
            # value :: observation_batch_shape num_steps observation_event_shape
            batch_shape = tf.broadcast_dynamic_shape(observation_batch_shape,
                                                     self.batch_shape_tensor())
            log_init = tf.broadcast_to(
                self._log_init,
                tf.concat([batch_shape, [self._num_states]], axis=0))
            # log_init :: batch_shape num_states
            log_transition = self._log_trans

            # `observation_event_shape` is the shape of each sequence of observations
            # emitted by the model.
            observation_event_shape = observation_tensor_shape[
                -1 - self._underlying_event_rank:]
            working_obs = tf.broadcast_to(
                value, tf.concat([batch_shape, observation_event_shape],
                                 axis=0))
            # working_obs :: batch_shape observation_event_shape
            r = self._underlying_event_rank

            # Move index into sequence of observations to front so we can apply
            # tf.foldl
            working_obs = distribution_util.move_dimension(
                working_obs, -1 - r, 0)[..., tf.newaxis]
            # working_obs :: num_steps batch_shape underlying_event_shape
            observation_probs = (
                self._observation_distribution.log_prob(working_obs))

            def forward_step(log_prev_step, log_prob_observation):
                return _log_vector_matrix(
                    log_prev_step, log_transition) + log_prob_observation

            fwd_prob = tf.foldl(forward_step,
                                observation_probs,
                                initializer=log_init)
            # fwd_prob :: batch_shape num_states

            log_prob = tf.reduce_logsumexp(fwd_prob, axis=-1)
            # log_prob :: batch_shape

            return log_prob
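The `foldl` above is a log-space forward algorithm for a hidden Markov model. A non-batched NumPy sketch of the textbook forward recursion it is based on, with a toy two-state chain and unit-variance Gaussian emissions (all parameters illustrative; initialization conventions may differ slightly from the batched TF code):

```python
import numpy as np

init = np.array([0.6, 0.4])        # initial state probabilities
trans = np.array([[0.7, 0.3],
                  [0.2, 0.8]])     # transition matrix
emit_means = np.array([0.0, 5.0])  # per-state emission means

def emission_prob(obs):
    # N(obs; emit_means[k], 1) for each hidden state k.
    return np.exp(-0.5 * (obs - emit_means) ** 2) / np.sqrt(2 * np.pi)

def forward_prob(observations):
    # alpha[k] = P(obs_1..obs_t, state_t = k); same vector-matrix recursion as
    # the foldl over `_log_vector_matrix` above, minus the log-space bookkeeping.
    alpha = init * emission_prob(observations[0])
    for obs in observations[1:]:
        alpha = (alpha @ trans) * emission_prob(obs)
    return alpha.sum()  # P(obs_1..obs_T) == exp(log_prob)

print(forward_prob(np.array([0.1, 4.8, 5.2])))
```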
Example #6
def _swap_m_with_i(vecs, m, i):
    """Swaps `m` and `i` on axis -1. (Helper for pivoted_cholesky.)

  Given a batch of int64 vectors `vecs`, scalar index `m`, and compatibly shaped
  per-vector indices `i`, this function swaps elements `m` and `i` in each
  vector. For the use-case below, these are permutation vectors.

  Args:
    vecs: Vectors on which we perform the swap, int64 `Tensor`.
    m: Scalar int64 `Tensor`, the index into which the `i`th element is going.
    i: Batch int64 `Tensor`, shaped like vecs.shape[:-1] + [1]; the index into
      which the `m`th element is going.

  Returns:
    vecs: The updated vectors.
  """
    vecs = tf.convert_to_tensor(vecs, dtype=tf.int64, name='vecs')
    m = tf.convert_to_tensor(m, dtype=tf.int64, name='m')
    i = tf.convert_to_tensor(i, dtype=tf.int64, name='i')
    trailing_elts = tf.broadcast_to(
        tf.range(m + 1,
                 prefer_static.shape(vecs, out_type=tf.int64)[-1]),
        prefer_static.shape(vecs[..., m + 1:]))
    trailing_elts = tf.where(tf.equal(trailing_elts, i),
                             tf.gather(vecs, [m], axis=-1), vecs[..., m + 1:])
    # TODO(bjp): Could we use tensor_scatter_nd_update?
    vecs_shape = vecs.shape
    vecs = tf.concat([
        vecs[..., :m],
        tf.gather(vecs, i, batch_dims=int(prefer_static.rank(vecs)) - 1),
        trailing_elts
    ], axis=-1)
    tensorshape_util.set_shape(vecs, vecs_shape)
    return vecs
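A small reference for the semantics described in the docstring, written for a single (non-batched) permutation vector in NumPy; the batched TF version above performs the same swap per vector:

```python
import numpy as np

def swap_m_with_i_reference(vec, m, i):
    # Non-batched reference semantics of `_swap_m_with_i`: exchange the
    # elements at positions m and i.
    vec = vec.copy()
    vec[m], vec[i] = vec[i], vec[m]
    return vec

perm = np.array([0, 1, 2, 3, 4], dtype=np.int64)
print(swap_m_with_i_reference(perm, m=1, i=3))  # -> [0 3 2 1 4]
```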
Example #7
    def _marginal_hidden_probs(self):
        """Compute marginal pdf for each individual observable."""

        initial_log_probs = tf.broadcast_to(
            self._log_init,
            tf.concat([self.batch_shape_tensor(), [self._num_states]], axis=0))

        # initial_log_probs :: batch_shape num_states

        def _scan_multiple_steps():
            """Perform `scan` operation when `num_steps` > 1."""

            transition_log_probs = self._log_trans

            def forward_step(log_probs, _):
                return _log_vector_matrix(log_probs, transition_log_probs)

            dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32)

            forward_log_probs = tf.scan(forward_step,
                                        dummy_index,
                                        initializer=initial_log_probs,
                                        name="forward_log_probs")

            return tf.concat([[initial_log_probs], forward_log_probs], axis=0)

        forward_log_probs = prefer_static.cond(
            self._num_steps > 1, _scan_multiple_steps,
            lambda: initial_log_probs[tf.newaxis, ...])

        return tf.exp(forward_log_probs)
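In probability space, the marginal that the `scan` above accumulates is just the initial state distribution pushed through the transition matrix step by step. A NumPy sketch with an illustrative two-state chain:

```python
import numpy as np

init = np.array([0.6, 0.4])
trans = np.array([[0.7, 0.3],
                  [0.2, 0.8]])
num_steps = 4

# marginals[t, k] = P(state_t = k); the scan over `_log_vector_matrix`
# performs the same vector-matrix products in log space.
marginals = [init]
for _ in range(num_steps - 1):
    marginals.append(marginals[-1] @ trans)
print(np.stack(marginals))  # shape [num_steps, num_states]
```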
Example #8
    def _mean(self, **kwargs):
        if not self.bijector.is_constant_jacobian:
            raise NotImplementedError("mean is not implemented for non-affine "
                                      "bijectors")

        distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)
        x = self.distribution.mean(**distribution_kwargs)

        if self._is_maybe_batch_override or self._is_maybe_event_override:
            # A batch (respectively event) shape override is only allowed if the batch
            # (event) shape of the base distribution is [], so concatenating all the
            # shapes does the right thing.
            new_shape = prefer_static.concat([
                prefer_static.ones_like(self._override_batch_shape),
                self.distribution.batch_shape_tensor(),
                prefer_static.ones_like(self._override_event_shape),
                self.distribution.event_shape_tensor(),
            ], 0)
            x = tf.reshape(x, new_shape)
            new_shape = prefer_static.concat(
                [self.batch_shape_tensor(),
                 self.event_shape_tensor()], 0)
            x = tf.broadcast_to(x, new_shape)

        y = self.bijector.forward(x, **bijector_kwargs)

        sample_shape = tf.convert_to_tensor([],
                                            dtype=tf.int32,
                                            name="sample_shape")
        y = self._set_sample_static_shape(y, sample_shape)
        return y
Example #9
 def _sample_n(self, n, seed=None):
     seed = SeedStream(seed, "beta")
     concentration1 = tf.convert_to_tensor(self.concentration1)
     concentration0 = tf.convert_to_tensor(self.concentration0)
     shape = self._batch_shape_tensor(concentration1, concentration0)
     expanded_concentration1 = tf.broadcast_to(concentration1, shape)
     expanded_concentration0 = tf.broadcast_to(concentration0, shape)
     gamma1_sample = tf.random.gamma(shape=[n],
                                     alpha=expanded_concentration1,
                                     dtype=self.dtype,
                                     seed=seed())
     gamma2_sample = tf.random.gamma(shape=[n],
                                     alpha=expanded_concentration0,
                                     dtype=self.dtype,
                                     seed=seed())
     beta_sample = gamma1_sample / (gamma1_sample + gamma2_sample)
     return beta_sample
Example #10
 def _cdf(self, x):
     df = tf.convert_to_tensor(self.df)
     # Take Abs(scale) to make subsequent where work correctly.
     y = (x - self.loc) / tf.abs(self.scale)
     x_t = df / (y**2. + df)
     neg_cdf = 0.5 * tf.math.betainc(
         0.5 * tf.broadcast_to(df, prefer_static.shape(x_t)), 0.5, x_t)
     return tf.where(y < 0., neg_cdf, 1. - neg_cdf)
Example #11
def cholesky_concat(chol, cols, name=None):
    """Concatenates `chol @ chol.T` with additional rows and columns.

  This operation is conceptually identical to:
  ```python
  def cholesky_concat_slow(chol, cols):  # cols shaped (n + m) x m = z x m
    mat = tf.matmul(chol, chol, adjoint_b=True)  # batch of n x n
    # Concat columns.
    mat = tf.concat([mat, cols[..., :tf.shape(mat)[-2], :]], axis=-1)  # n x z
    # Concat rows.
    mat = tf.concat([mat, tf.linalg.matrix_transpose(cols)], axis=-2)  # z x z
    return tf.linalg.cholesky(mat)
  ```
  but whereas `cholesky_concat_slow` would cost `O(z**3)` work,
  `cholesky_concat` only costs `O(z**2 + m**3)` work.

  The resulting (implicit) matrix must be symmetric and positive definite.
  Thus, the bottom right `m x m` must be self-adjoint, and we do not require a
  separate `rows` argument (which can be inferred from `conj(cols.T)`).

  Args:
    chol: Cholesky decomposition of `mat = chol @ chol.T`.
    cols: The new columns whose first `n` rows we would like concatenated to the
      right of `mat = chol @ chol.T`, and whose conjugate transpose we would
      like concatenated to the bottom of `concat(mat, cols[:n,:])`. A `Tensor`
      with final dims `(n+m, m)`. The first `n` rows are the top right rectangle
      (their conjugate transpose forms the bottom left), and the bottom `m x m`
      is self-adjoint.
    name: Optional name for this op.

  Returns:
    chol_concat: The Cholesky decomposition of:
      ```
      [ [ mat  cols[:n, :] ]
        [   conj(cols.T)   ] ]
      ```
  """
    with tf.name_scope(name or 'cholesky_extend'):
        dtype = dtype_util.common_dtype([chol, cols], dtype_hint=tf.float32)
        chol = tf.convert_to_tensor(chol, name='chol', dtype=dtype)
        cols = tf.convert_to_tensor(cols, name='cols', dtype=dtype)
        n = prefer_static.shape(chol)[-1]
        mat_nm, mat_mm = cols[..., :n, :], cols[..., n:, :]
        solved_nm = linear_operator_util.matrix_triangular_solve_with_broadcast(
            chol, mat_nm)
        lower_right_mm = tf.linalg.cholesky(
            mat_mm - tf.matmul(solved_nm, solved_nm, adjoint_a=True))
        lower_left_mn = tf.math.conj(tf.linalg.matrix_transpose(solved_nm))
        out_batch = prefer_static.shape(solved_nm)[:-2]
        chol = tf.broadcast_to(
            chol,
            tf.concat([out_batch, prefer_static.shape(chol)[-2:]], axis=0))
        top_right_zeros_nm = tf.zeros_like(solved_nm)
        return tf.concat([
            tf.concat([chol, top_right_zeros_nm], axis=-1),
            tf.concat([lower_left_mn, lower_right_mm], axis=-1)
        ], axis=-2)
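A usage sketch for the function above, checking the extended factor against a full Cholesky recomputation. It assumes this function is exposed as `tfp.math.cholesky_concat`; the matrices are illustrative:

```python
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

n, m = 3, 1
rng = np.random.default_rng(1)
a = rng.normal(size=(n + m, n + m))
full = a @ a.T + (n + m) * np.eye(n + m)   # symmetric positive definite

mat = full[:n, :n]                         # existing n x n block
cols = full[:, n:]                         # new columns, shape (n + m, m)

chol = tf.linalg.cholesky(mat)
chol_ext = tfp.math.cholesky_concat(chol, cols)  # assumed public alias of the above
tf.debugging.assert_near(chol_ext, tf.linalg.cholesky(full), atol=1e-6)
```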
Example #12
    def _sample_n(self, n, seed=None):
        # Like with the univariate Student's t, sampling can be implemented as a
        # ratio of samples from a multivariate gaussian with the appropriate
        # covariance matrix and a sample from the chi-squared distribution.
        seed = SeedStream(seed, salt='multivariate t')

        loc = tf.broadcast_to(self.loc, self._sample_shape())
        mvn = mvn_linear_operator.MultivariateNormalLinearOperator(
            loc=tf.zeros_like(loc), scale=self.scale)
        normal_samp = mvn.sample(n, seed=seed())

        df = tf.broadcast_to(self.df, self.batch_shape_tensor())
        chi2 = chi2_lib.Chi2(df=df)
        chi2_samp = chi2.sample(n, seed=seed())

        return (
            self._loc +
            normal_samp * tf.math.rsqrt(chi2_samp / self._df)[..., tf.newaxis])
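The comment at the top of `_sample_n` refers to the standard construction t = loc + scale * z * sqrt(df / chi2). A univariate NumPy sketch of that ratio (illustrative parameters, not the batched TF path):

```python
import numpy as np

rng = np.random.default_rng(2)
df, loc, scale, n = 5.0, 1.0, 2.0, 200_000

z = rng.normal(size=n)
chi2 = rng.chisquare(df, size=n)
t_samples = loc + scale * z * np.sqrt(df / chi2)  # same ratio as rsqrt(chi2 / df)

# For df > 2 the Student-t variance is scale**2 * df / (df - 2).
print(t_samples.var(), scale**2 * df / (df - 2))
```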
Example #13
 def _sample_n(self, n, seed=None):
     del seed  # unused
     loc = tf.convert_to_tensor(self.loc)
     return tf.broadcast_to(
         loc,
         tf.concat([[n],
                    self._batch_shape_tensor(loc=loc),
                    self._event_shape_tensor(loc=loc)],
                   axis=0))
Example #14
 def _variance(self):
   probs = self._categorical.probs_parameter()
   outcomes = tf.broadcast_to(
       self.outcomes, shape=dist_util.prefer_static_shape(probs))
   if dtype_util.is_integer(outcomes.dtype):
     if self._validate_args:
       outcomes = dist_util.embed_check_integer_casting_closed(
           outcomes, target_dtype=probs.dtype)
     outcomes = tf.cast(outcomes, dtype=probs.dtype)
   square_d = tf.math.squared_difference(
       outcomes, self._mean(probs)[..., tf.newaxis])
   return tf.reduce_sum(probs * square_d, axis=-1)
Example #15
    def _covariance(self):
        if distribution_util.is_diagonal_scale(self.scale):
            mvn_cov = tf.linalg.diag(tf.square(self.scale.diag_part()))
        else:
            mvn_cov = self.scale.matmul(self.scale.to_dense(),
                                        adjoint_arg=True)

        cov_shape = tf.concat(
            [self._sample_shape(),
             self._event_shape_tensor()], -1)
        mvn_cov = tf.broadcast_to(mvn_cov, cov_shape)
        return self._std_var_helper(mvn_cov, 'covariance', 2, lambda x: x)
Example #16
    def _variance(self):
        if distribution_util.is_diagonal_scale(self.scale):
            mvn_var = tf.square(self.scale.diag_part())
        elif (isinstance(self.scale, tf.linalg.LinearOperatorLowRankUpdate)
              and self.scale.is_self_adjoint):
            mvn_var = tf.linalg.diag_part(
                self.scale.matmul(self.scale.to_dense()))
        else:
            mvn_var = tf.linalg.diag_part(
                self.scale.matmul(self.scale.to_dense(), adjoint_arg=True))

        mvn_var = tf.broadcast_to(mvn_var, self._sample_shape())
        return self._std_var_helper(mvn_var, 'variance', 1, lambda x: x)
Example #17
    def _mean(self):
        mean = tf.broadcast_to(self.loc, self._sample_shape())

        if self.allow_nan_stats:
            return tf.where(self.df[..., tf.newaxis] > 1., mean,
                            dtype_util.as_numpy_dtype(self.dtype)(np.nan))
        else:
            with tf.control_dependencies([
                    assert_util.assert_less(
                        tf.cast(1., self.dtype),
                        self.df,
                        message='Mean not defined for components of df <= 1.'),
            ]):
                return tf.identity(mean)
Example #18
    def _stddev(self):
        if distribution_util.is_diagonal_scale(self.scale):
            mvn_std = tf.abs(self.scale.diag_part())
        elif (isinstance(self.scale, tf.linalg.LinearOperatorLowRankUpdate)
              and self.scale.is_self_adjoint):
            mvn_std = tf.sqrt(
                tf.linalg.diag_part(self.scale.matmul(self.scale.to_dense())))
        else:
            mvn_std = tf.sqrt(
                tf.linalg.diag_part(
                    self.scale.matmul(self.scale.to_dense(),
                                      adjoint_arg=True)))

        mvn_std = tf.broadcast_to(mvn_std, self._sample_shape())
        return self._std_var_helper(mvn_std, 'standard deviation', 1, tf.sqrt)
Example #19
    def _entropy(self):
        df = tf.broadcast_to(self.df, self.batch_shape_tensor())
        num_dims = tf.cast(self.event_shape_tensor()[0], self.dtype)

        def _lbeta(concentration0, concentration1):
            return (tf.math.lgamma(concentration1) +
                    tf.math.lgamma(concentration0) -
                    tf.math.lgamma(concentration0 + concentration1))

        shape_factor = self._scale.log_abs_determinant()
        beta_factor = (num_dims / 2. * (tf.math.log(df) + np.log(np.pi))
                       - tf.math.lgamma(num_dims / 2.)
                       + _lbeta(num_dims / 2., df / 2.))
        digamma_factor = ((num_dims + df) / 2. *
                          (tf.math.digamma((num_dims + df) / 2.) -
                           tf.math.digamma(df / 2.)))
        return shape_factor + beta_factor + digamma_factor
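Transcribing the terms computed above (a restatement of the code, not an independent derivation), with p = num_dims and ν = df:

    H = log|det(scale)| + (p/2) * (log ν + log π) - log Γ(p/2) + log B(p/2, ν/2)
        + ((ν + p)/2) * (ψ((ν + p)/2) - ψ(ν/2))

Here B is the Beta function and ψ the digamma function; `shape_factor`, `beta_factor`, and `digamma_factor` correspond to the first term, the next three terms, and the last term respectively.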
Example #20
 def _fn(self, **kwargs):
     """Implements summary statistic, eg, mean, stddev, mode."""
     x = getattr(self.distribution, attr)(**kwargs)
     shape = prefer_static.concat([
         self.distribution.batch_shape_tensor(),
         prefer_static.ones(prefer_static.rank_from_shape(self.sample_shape),
                            dtype=self.sample_shape.dtype),
         self.distribution.event_shape_tensor(),
     ], axis=0)
     x = tf.reshape(x, shape=shape)
     shape = prefer_static.concat([
         self.distribution.batch_shape_tensor(),
         self.sample_shape,
         self.distribution.event_shape_tensor(),
     ], axis=0)
     return tf.broadcast_to(x, shape)
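The reshape-then-broadcast pattern in `_fn` (insert length-1 dimensions where `sample_shape` will go, then broadcast) can be sketched with plain TF; the shapes here are illustrative:

```python
import tensorflow as tf

# A per-batch statistic with batch_shape [2] and event_shape [3].
stat = tf.constant([[0., 1., 2.],
                    [3., 4., 5.]])
sample_shape = [4]  # illustrative sample_shape

# Insert singleton dims for the sample_shape axes: [batch, 1, event] ...
x = tf.reshape(stat, [2, 1, 3])
# ... then broadcast to [batch] + sample_shape + [event].
tiled = tf.broadcast_to(x, [2] + sample_shape + [3])
print(tiled.shape)  # (2, 4, 3)
```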
Example #21
    def _mean(self):
        with tf.control_dependencies(self._runtime_assertions):
            probs = self._marginal_hidden_probs()
            # probs :: num_steps batch_shape num_states
            means = self._observation_distribution.mean()
            # means :: observation_batch_shape[:-1] num_states
            #          observation_event_shape
            means_shape = tf.concat([
                self.batch_shape_tensor(), [self._num_states],
                self._observation_distribution.event_shape_tensor()
            ], axis=0)
            means = tf.broadcast_to(means, means_shape)
            # means :: batch_shape num_states observation_event_shape

            observation_event_shape = (
                self._observation_distribution.event_shape_tensor())
            batch_size = tf.reduce_prod(self.batch_shape_tensor())
            flat_probs_shape = [self._num_steps, batch_size, self._num_states]
            flat_means_shape = [
                batch_size, self._num_states,
                tf.reduce_prod(observation_event_shape)
            ]

            flat_probs = tf.reshape(probs, flat_probs_shape)
            # flat_probs :: num_steps batch_size num_states
            flat_means = tf.reshape(means, flat_means_shape)
            # flat_means :: batch_size num_states observation_event_size
            flat_mean = tf.einsum("ijk,jkl->jil", flat_probs, flat_means)
            # flat_mean :: batch_size num_steps observation_event_size
            unflat_mean_shape = tf.concat([
                self.batch_shape_tensor(), [self._num_steps],
                observation_event_shape
            ], axis=0)
            # returns :: batch_shape num_steps observation_event_shape
            return tf.reshape(flat_mean, unflat_mean_shape)
Example #22
def lu_reconstruct(lower_upper, perm, validate_args=False, name=None):
    """The inverse LU decomposition, `X == lu_reconstruct(*tf.linalg.lu(X))`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_reconstruct').

  Returns:
    x: The original input to `tf.linalg.lu`, i.e., `x` as in,
      `lu_reconstruct(*tf.linalg.lu(x))`.

  #### Examples

  ```python
  import numpy as np
  from tensorflow_probability.python.internal.backend import jax as tf
  import tensorflow_probability as tfp; tfp = tfp.experimental.substrates.jax

  x = [[[3., 4], [1, 2]],
       [[7., 8], [3, 4]]]
  x_reconstructed = tfp.math.lu_reconstruct(*tf.linalg.lu(x))
  tf.assert_near(x, x_reconstructed)
  # ==> True
  ```

  """
    with tf.name_scope(name or 'lu_reconstruct'):
        lower_upper = tf.convert_to_tensor(lower_upper,
                                           dtype_hint=tf.float32,
                                           name='lower_upper')
        perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')

        assertions = lu_reconstruct_assertions(lower_upper, perm,
                                               validate_args)
        if assertions:
            with tf.control_dependencies(assertions):
                lower_upper = tf.identity(lower_upper)
                perm = tf.identity(perm)

        shape = tf.shape(lower_upper)

        lower = tf.linalg.set_diag(
            tf.linalg.band_part(lower_upper, num_lower=-1, num_upper=0),
            tf.ones(shape[:-1], dtype=lower_upper.dtype))
        upper = tf.linalg.band_part(lower_upper, num_lower=0, num_upper=-1)
        x = tf.matmul(lower, upper)

        if (tensorshape_util.rank(lower_upper.shape) is None
                or tensorshape_util.rank(lower_upper.shape) != 2):
            # We either don't know the batch rank or there are >0 batch dims.
            batch_size = tf.reduce_prod(shape[:-2])
            d = shape[-1]
            x = tf.reshape(x, [batch_size, d, d])
            perm = tf.reshape(perm, [batch_size, d])
            perm = tf.map_fn(tf.math.invert_permutation, perm)
            batch_indices = tf.broadcast_to(
                tf.range(batch_size)[:, tf.newaxis], [batch_size, d])
            x = tf.gather_nd(x, tf.stack([batch_indices, perm], axis=-1))
            x = tf.reshape(x, shape)
        else:
            x = tf.gather(x, tf.math.invert_permutation(perm))

        x.set_shape(lower_upper.shape)
        return x
Example #23
 def _mode(self):
     loc = tf.convert_to_tensor(self.loc)
     return tf.broadcast_to(loc, self._batch_shape_tensor(loc=loc))
Example #24
 def _stddev(self):
     scale = tf.convert_to_tensor(self.scale)
     return tf.broadcast_to(scale * np.pi / np.sqrt(3),
                            self._batch_shape_tensor(scale=scale))
Example #25
 def _entropy(self):
   scale = tf.convert_to_tensor(self.scale)
   return tf.broadcast_to(np.log(2.) + 1 + tf.math.log(scale),
                          self._batch_shape_tensor(scale=scale))
Example #26
 def _mode(self):
     return tf.broadcast_to(self.loc, self._sample_shape())
Example #27
 def _stddev(self):
   scale = tf.convert_to_tensor(self.scale)
   return tf.broadcast_to(np.sqrt(2.) * scale,
                          self._batch_shape_tensor(scale=scale))
Example #28
 def _mode(self):
   scale = tf.convert_to_tensor(self.scale)
   return tf.broadcast_to(scale, self._batch_shape_tensor(scale=scale))
Example #29
def pivoted_cholesky(matrix, max_rank, diag_rtol=1e-3, name=None):
    """Computes the (partial) pivoted cholesky decomposition of `matrix`.

  The pivoted Cholesky is a low rank approximation of the Cholesky decomposition
  of `matrix`, i.e. as described in [(Harbrecht et al., 2012)][1]. The
  currently-worst-approximated diagonal element is selected as the pivot at each
  iteration. This yields from a `[B1...Bn, N, N]` shaped `matrix` a `[B1...Bn,
  N, K]` shaped rank-`K` approximation `lr` such that `lr @ lr.T ~= matrix`.
  Note that, unlike the Cholesky decomposition, `lr` is not triangular even in
  a rectangular-matrix sense. However, under a permutation it could be made
  triangular (it has one more zero in each column as you move to the right).

  Such a matrix can be useful as a preconditioner for conjugate gradient
  optimization, i.e. as in [(Wang et al. 2019)][2], as matmuls and solves can be
  cheaply done via the Woodbury matrix identity, as implemented by
  `tf.linalg.LinearOperatorLowRankUpdate`.

  Args:
    matrix: Floating point `Tensor` batch of symmetric, positive definite
      matrices.
    max_rank: Scalar `int` `Tensor`, the rank at which to truncate the
      approximation.
    diag_rtol: Scalar floating point `Tensor` (same dtype as `matrix`). If the
      errors of all diagonal elements of `lr @ lr.T` are each lower than
      `element * diag_rtol`, iteration is permitted to terminate early.
    name: Optional name for the op.

  Returns:
    lr: Low rank pivoted Cholesky approximation of `matrix`.

  #### References

  [1]: H Harbrecht, M Peters, R Schneider. On the low-rank approximation by the
       pivoted Cholesky decomposition. _Applied numerical mathematics_,
       62(4):428-440, 2012.

  [2]: K. A. Wang et al. Exact Gaussian Processes on a Million Data Points.
       _arXiv preprint arXiv:1903.08114_, 2019. https://arxiv.org/abs/1903.08114
  """
    with tf.name_scope(name or 'pivoted_cholesky'):
        dtype = dtype_util.common_dtype([matrix, diag_rtol],
                                        dtype_hint=tf.float32)
        matrix = tf.convert_to_tensor(matrix, name='matrix', dtype=dtype)
        if tensorshape_util.rank(matrix.shape) is None:
            raise NotImplementedError(
                'Rank of `matrix` must be known statically')

        max_rank = tf.convert_to_tensor(max_rank,
                                        name='max_rank',
                                        dtype=tf.int64)
        max_rank = tf.minimum(
            max_rank,
            prefer_static.shape(matrix, out_type=tf.int64)[-1])
        diag_rtol = tf.convert_to_tensor(diag_rtol,
                                         dtype=dtype,
                                         name='diag_rtol')
        matrix_diag = tf.linalg.diag_part(matrix)
        # matrix is P.D., therefore all matrix_diag > 0, so we don't need abs.
        orig_error = tf.reduce_max(matrix_diag, axis=-1)

        def cond(m, pchol, perm, matrix_diag):
            """Condition for `tf.while_loop` continuation."""
            del pchol
            del perm
            error = tf.linalg.norm(matrix_diag, ord=1, axis=-1)
            max_err = tf.reduce_max(error / orig_error)
            return (m < max_rank) & (tf.equal(m, 0) | (max_err > diag_rtol))

        batch_dims = tensorshape_util.rank(matrix.shape) - 2

        def batch_gather(params, indices, axis=-1):
            return tf.gather(params, indices, axis=axis, batch_dims=batch_dims)

        def body(m, pchol, perm, matrix_diag):
            """Body of a single `tf.while_loop` iteration."""
            # Here is roughly a numpy, non-batched version of what's going to happen.
            # (See also Algorithm 1 of Harbrecht et al.)
            # 1: maxi = np.argmax(matrix_diag[perm[m:]]) + m
            # 2: maxval = matrix_diag[perm][maxi]
            # 3: perm[m], perm[maxi] = perm[maxi], perm[m]
            # 4: row = matrix[perm[m]][perm[m + 1:]]
            # 5: row -= np.sum(pchol[:m][perm[m + 1:]] * pchol[:m][perm[m]], axis=-2)
            # 6: pivot = np.sqrt(maxval); row /= pivot
            # 7: row = np.concatenate([[[pivot]], row], -1)
            # 8: matrix_diag[perm[m:]] -= row**2
            # 9: pchol[m, perm[m:]] = row

            # Find the maximal position of the (remaining) permuted diagonal.
            # Steps 1, 2 above.
            permuted_diag = batch_gather(matrix_diag, perm[..., m:])
            maxi = tf.argmax(permuted_diag, axis=-1,
                             output_type=tf.int64)[..., tf.newaxis]
            maxval = batch_gather(permuted_diag, maxi)
            maxi = maxi + m
            maxval = maxval[..., 0]
            # Update perm: Swap perm[...,m] with perm[...,maxi]. Step 3 above.
            perm = _swap_m_with_i(perm, m, maxi)
            # Step 4.
            row = batch_gather(matrix, perm[..., m:m + 1], axis=-2)
            row = batch_gather(row, perm[..., m + 1:])
            # Step 5.
            prev_rows = pchol[..., :m, :]
            prev_rows_perm_m_onward = batch_gather(prev_rows, perm[...,
                                                                   m + 1:])
            prev_rows_pivot_col = batch_gather(prev_rows, perm[..., m:m + 1])
            row -= tf.reduce_sum(prev_rows_perm_m_onward * prev_rows_pivot_col,
                                 axis=-2)[..., tf.newaxis, :]
            # Step 6.
            pivot = tf.sqrt(maxval)[..., tf.newaxis, tf.newaxis]
            # Step 7.
            row = tf.concat([pivot, row / pivot], axis=-1)
            # TODO(b/130899118): Pad grad fails with int64 paddings.
            # Step 8.
            paddings = tf.concat([
                tf.zeros([prefer_static.rank(pchol) - 1, 2], dtype=tf.int32),
                [[tf.cast(m, tf.int32), 0]]
            ], axis=0)
            diag_update = tf.pad(row**2, paddings=paddings)[..., 0, :]
            reverse_perm = _invert_permutation(perm)
            matrix_diag -= batch_gather(diag_update, reverse_perm)
            # Step 9.
            row = tf.pad(row, paddings=paddings)
            # TODO(bjp): Defer the reverse permutation all-at-once at the end?
            row = batch_gather(row, reverse_perm)
            pchol_shape = pchol.shape
            pchol = tf.concat([pchol[..., :m, :], row, pchol[..., m + 1:, :]],
                              axis=-2)
            tensorshape_util.set_shape(pchol, pchol_shape)
            return m + 1, pchol, perm, matrix_diag

        m = np.int64(0)
        pchol = tf.zeros_like(matrix[..., :max_rank, :])
        matrix_shape = prefer_static.shape(matrix, out_type=tf.int64)
        perm = tf.broadcast_to(prefer_static.range(matrix_shape[-1]),
                               matrix_shape[:-1])
        _, pchol, _, _ = tf.while_loop(cond=cond,
                                       body=body,
                                       loop_vars=(m, pchol, perm, matrix_diag))
        pchol = tf.linalg.matrix_transpose(pchol)
        tensorshape_util.set_shape(
            pchol, tensorshape_util.concatenate(matrix_diag.shape, [None]))
        return pchol
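A usage sketch for the routine above, checking that the low-rank factor approximately reconstructs the input. It assumes this corresponds to the public `tfp.math.pivoted_cholesky`; the kernel matrix is illustrative:

```python
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

# A well-conditioned SPD matrix: squared-exponential kernel plus jitter.
x = np.linspace(0., 1., 20)[:, None]
kernel = np.exp(-0.5 * (x - x.T) ** 2) + 1e-6 * np.eye(20)

lr = tfp.math.pivoted_cholesky(tf.constant(kernel), max_rank=20, diag_rtol=1e-6)
approx = tf.matmul(lr, lr, adjoint_b=True)
print(tf.reduce_max(tf.abs(approx - kernel)).numpy())  # small reconstruction error
```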
Example #30
def lu_solve(lower_upper, perm, rhs, validate_args=False, name=None):
    """Solves systems of linear eqns `A X = RHS`, given LU factorizations.

  Note: this function does not verify the implied matrix is actually invertible
  nor is this condition checked even when `validate_args=True`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
    rhs: Matrix-shaped float `Tensor` representing targets for which to solve;
      `A X = RHS`. To handle vector cases, use:
      `lu_solve(..., rhs[..., tf.newaxis])[..., 0]`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness. Note: this function does not verify the implied matrix is
      actually invertible, even when `validate_args=True`.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_solve').

  Returns:
    x: The `X` in `A @ X = RHS`.

  #### Examples

  ```python
  import numpy as np
  from tensorflow_probability.python.internal.backend import jax as tf
  import tensorflow_probability as tfp; tfp = tfp.experimental.substrates.jax

  x = [[[1., 2],
        [3, 4]],
       [[7, 8],
        [3, 4]]]
  inv_x = tfp.math.lu_solve(*tf.linalg.lu(x), rhs=tf.eye(2))
  tf.assert_near(tf.matrix_inverse(x), inv_x)
  # ==> True
  ```

  """

    with tf.name_scope(name or 'lu_solve'):
        lower_upper = tf.convert_to_tensor(lower_upper,
                                           dtype_hint=tf.float32,
                                           name='lower_upper')
        perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')
        rhs = tf.convert_to_tensor(rhs,
                                   dtype_hint=lower_upper.dtype,
                                   name='rhs')

        assertions = _lu_solve_assertions(lower_upper, perm, rhs,
                                          validate_args)
        if assertions:
            with tf.control_dependencies(assertions):
                lower_upper = tf.identity(lower_upper)
                perm = tf.identity(perm)
                rhs = tf.identity(rhs)

        if (tensorshape_util.rank(rhs.shape) == 2
                and tensorshape_util.rank(perm.shape) == 1):
            # Both rhs and perm have scalar batch_shape.
            permuted_rhs = tf.gather(rhs, perm, axis=-2)
        else:
            # Either rhs or perm have non-scalar batch_shape or we can't determine
            # this information statically.
            rhs_shape = tf.shape(rhs)
            broadcast_batch_shape = tf.broadcast_dynamic_shape(
                rhs_shape[:-2],
                tf.shape(perm)[:-1])
            d, m = rhs_shape[-2], rhs_shape[-1]
            rhs_broadcast_shape = tf.concat([broadcast_batch_shape, [d, m]],
                                            axis=0)

            # Tile out rhs.
            broadcast_rhs = tf.broadcast_to(rhs, rhs_broadcast_shape)
            broadcast_rhs = tf.reshape(broadcast_rhs, [-1, d, m])

            # Tile out perm and add batch indices.
            broadcast_perm = tf.broadcast_to(perm, rhs_broadcast_shape[:-1])
            broadcast_perm = tf.reshape(broadcast_perm, [-1, d])
            broadcast_batch_size = tf.reduce_prod(broadcast_batch_shape)
            broadcast_batch_indices = tf.broadcast_to(
                tf.range(broadcast_batch_size)[:, tf.newaxis],
                [broadcast_batch_size, d])
            broadcast_perm = tf.stack(
                [broadcast_batch_indices, broadcast_perm], axis=-1)

            permuted_rhs = tf.gather_nd(broadcast_rhs, broadcast_perm)
            permuted_rhs = tf.reshape(permuted_rhs, rhs_broadcast_shape)

        lower = tf.linalg.set_diag(
            tf.linalg.band_part(lower_upper, num_lower=-1, num_upper=0),
            tf.ones(tf.shape(lower_upper)[:-1], dtype=lower_upper.dtype))
        return linear_operator_util.matrix_triangular_solve_with_broadcast(
            lower_upper,  # Only upper is accessed.
            linear_operator_util.matrix_triangular_solve_with_broadcast(
                lower, permuted_rhs),
            lower=False)