コード例 #1
0
    def _log_prob(self, x):
        """Log-probability of `x` given this distribution's logits.

        `x` is used as the label distribution (with gradients stopped) in a
        softmax cross-entropy against the logits, and the result is negated.

        Args:
          x: `Tensor` whose last dimension has size `event_size`; it is cast to
            the logits dtype and validated before use.

        Returns:
          `Tensor` whose shape is the broadcast shape of `x` and the logits with
          the trailing event dimension removed.
        """
        logits = self._logits_parameter_no_checks()
        event_size = self._event_size(logits)
        dtype = logits.dtype

        x = self._maybe_assert_valid_sample(tf.cast(x, dtype), dtype=dtype)

        # Only skip broadcasting when both shapes are statically known and
        # already identical.
        shapes_match = (tensorshape_util.is_fully_defined(x.shape)
                        and tensorshape_util.is_fully_defined(logits.shape)
                        and x.shape == logits.shape)
        if not shapes_match:
            full_shape = tf.broadcast_dynamic_shape(tf.shape(logits),
                                                    tf.shape(x))
            logits = tf.broadcast_to(logits, full_shape)
            x = tf.broadcast_to(x, full_shape)

        # Result shape: everything except the event dimension.
        out_shape = tf.shape(tf.reduce_sum(logits, axis=-1))
        flat_logits = tf.reshape(logits, [-1, event_size])
        flat_x = tf.reshape(x, [-1, event_size])
        flat_log_prob = -tf.nn.softmax_cross_entropy_with_logits(
            labels=tf.stop_gradient(flat_x), logits=flat_logits)

        # Undo the 2-D flattening to restore batch and sample dimensions.
        return tf.reshape(flat_log_prob, out_shape)
コード例 #2
0
ファイル: beta.py プロジェクト: HackerShohag/SuggestBot-bn
 def _cdf(self, x):
     """Regularized incomplete beta function `I_x(concentration1, concentration0)`.

     `tf.math.betainc` does not broadcast non-scalar arguments against each
     other, so both concentrations and `x` are broadcast to their common shape
     first. This lets `x` carry sample dimensions in addition to the batch
     shape.
     """
     with tf.control_dependencies(self._maybe_assert_valid_sample(x)):
         concentration1 = tf.convert_to_tensor(self.concentration1)
         concentration0 = tf.convert_to_tensor(self.concentration0)
         batch_shape = self._batch_shape_tensor(concentration1, concentration0)
         # Include `x` in the broadcast: previously only the parameters were
         # broadcast, so any `x` whose shape differed from the batch shape
         # (e.g. leading sample dims) was rejected by `betainc`.
         shape = tf.broadcast_dynamic_shape(batch_shape, tf.shape(x))
         concentration1 = tf.broadcast_to(concentration1, shape)
         concentration0 = tf.broadcast_to(concentration0, shape)
         x = tf.broadcast_to(x, shape)
         return tf.math.betainc(concentration1, concentration0, x)
コード例 #3
0
 def _cdf(self, x):
     """CDF as `betainc(total_count, 1 + x, sigmoid(-logits))`.

     `tf.math.betainc` does not broadcast non-scalar arguments, so all three
     arguments are broadcast to the common shape of `x` and the batch shape.
     This lets `x` carry extra sample dimensions.
     """
     x = self._maybe_assert_valid_sample(x)
     logits = self._logits_parameter_no_checks()
     total_count = tf.convert_to_tensor(self.total_count)
     batch_shape = self._batch_shape_tensor(logits_or_probs=logits,
                                            total_count=total_count)
     # Broadcasting `1. + x` to the batch shape alone fails whenever `x` has
     # dimensions beyond the batch shape, so include `x` in the target shape.
     shape = tf.broadcast_dynamic_shape(batch_shape, tf.shape(x))
     return tf.math.betainc(tf.broadcast_to(total_count, shape),
                            tf.broadcast_to(1. + x, shape),
                            tf.broadcast_to(tf.sigmoid(-logits), shape))
コード例 #4
0
    def _variance(self):
        """Per-step observation variance, marginalized over hidden states.

        Treats each step's observation as a mixture over states with weights
        from `_marginal_hidden_probs()`.

        Returns:
          `Tensor` shaped `batch_shape + [num_steps] + observation_event_shape`.
        """
        with tf.control_dependencies(self._runtime_assertions):
            # Shape tensors used several times below; compute them once.
            batch_shape = self.batch_shape_tensor()
            observation_event_shape = (
                self._observation_distribution.event_shape_tensor())

            probs = self._marginal_hidden_probs()
            # probs :: num_steps batch_shape num_states
            means = self._observation_distribution.mean()
            # means :: observation_batch_shape[:-1] num_states
            #          observation_event_shape
            means_shape = tf.concat(
                [batch_shape, [self._num_states], observation_event_shape],
                axis=0)
            means = tf.broadcast_to(means, means_shape)
            # means :: batch_shape num_states observation_event_shape

            batch_size = tf.reduce_prod(batch_shape)
            flat_probs_shape = [self._num_steps, batch_size, self._num_states]
            flat_means_shape = [
                batch_size, 1, self._num_states,
                tf.reduce_prod(observation_event_shape)
            ]

            flat_probs = tf.reshape(probs, flat_probs_shape)
            # flat_probs :: num_steps batch_size num_states
            flat_means = tf.reshape(means, flat_means_shape)
            # flat_means :: batch_size 1 num_states observation_event_size
            flat_mean = tf.einsum("ijk,jmkl->jiml", flat_probs, flat_means)
            # flat_mean :: batch_size num_steps 1 observation_event_size

            variances = self._observation_distribution.variance()
            variances = tf.broadcast_to(variances, means_shape)
            # variances :: batch_shape num_states observation_event_shape
            flat_variances = tf.reshape(variances, flat_means_shape)
            # flat_variances :: batch_size 1 num_states observation_event_size

            # For a mixture of n distributions with mixture probabilities
            # p[i], and where the individual distributions have means and
            # variances given by mean[i] and var[i], the variance of
            # the mixture is given by:
            #
            # var = sum i=1..n p[i] * ((mean[i] - mean)**2 + var[i])
            #
            # `var[i]` is already a variance, so it is NOT squared.
            # `flat_means - flat_mean` broadcasts
            # [batch, 1, states, event] - [batch, steps, 1, event]
            # to [batch, steps, states, event].
            flat_variance = tf.einsum("ijk,jikl->jil", flat_probs,
                                      (flat_means - flat_mean)**2 +
                                      flat_variances)
            # flat_variance :: batch_size num_steps observation_event_size

            unflat_mean_shape = tf.concat(
                [batch_shape, [self._num_steps], observation_event_shape],
                axis=0)

            # returns :: batch_shape num_steps observation_event_shape
            return tf.reshape(flat_variance, unflat_mean_shape)
コード例 #5
0
    def _log_prob(self, value):
        """Log-probability of sequences of observations (forward recursion).

        Args:
          value: `Tensor` shaped
            `observation_batch_shape + [num_steps] + underlying_event_shape`.

        Returns:
          `Tensor` of log-probabilities whose shape is the broadcast of
          `observation_batch_shape` with this distribution's batch shape.
        """
        with tf.control_dependencies(self._runtime_assertions):
            # The argument `value` is a tensor of sequences of observations.
            # `observation_batch_shape` is the shape of that tensor with the
            # sequence part removed.
            # `observation_batch_shape` is then broadcast to the full batch shape
            # to give the `batch_shape` that defines the shape of the result.
            r = self._underlying_event_rank

            observation_tensor_shape = tf.shape(value)
            observation_batch_shape = observation_tensor_shape[:-1 - r]
            # value :: observation_batch_shape num_steps observation_event_shape
            batch_shape = tf.broadcast_dynamic_shape(observation_batch_shape,
                                                     self.batch_shape_tensor())
            log_init = tf.broadcast_to(
                self._log_init,
                tf.concat([batch_shape, [self._num_states]], axis=0))
            # log_init :: batch_shape num_states
            log_transition = self._log_trans

            # `observation_event_shape` is the shape of each sequence of
            # observations emitted by the model: [num_steps] followed by the
            # underlying event shape.
            observation_event_shape = observation_tensor_shape[-1 - r:]
            working_obs = tf.broadcast_to(
                value, tf.concat([batch_shape, observation_event_shape],
                                 axis=0))
            # working_obs :: batch_shape num_steps underlying_event_shape

            # Move index into sequence of observations to front so we can apply
            # tf.foldl.
            working_obs = distribution_util.move_dimension(
                working_obs, -1 - r, 0)
            # working_obs :: num_steps batch_shape underlying_event_shape

            # Insert a singleton axis just BEFORE the underlying event dims so
            # it broadcasts against the trailing `num_states` batch dim of the
            # observation distribution. Appending it at the very end (as
            # `[..., tf.newaxis]` did) is only correct when r == 0.
            working_obs = tf.expand_dims(working_obs, -1 - r)
            # working_obs :: num_steps batch_shape 1 underlying_event_shape
            observation_probs = (
                self._observation_distribution.log_prob(working_obs))
            # observation_probs :: num_steps batch_shape num_states

            def forward_step(log_prev_step, log_prob_observation):
                return _log_vector_matrix(
                    log_prev_step, log_transition) + log_prob_observation

            # The initial distribution governs the hidden state at step 0, so
            # the first observation is scored against it directly. The
            # transition is applied only between consecutive steps. This
            # matches `_marginal_hidden_probs`, where step 0 is `_log_init`
            # itself. Assumes num_steps >= 1 (as `_marginal_hidden_probs`
            # also does). With a single step the fold runs over an empty
            # tensor and returns the initializer.
            fwd_prob = tf.foldl(forward_step,
                                observation_probs[1:],
                                initializer=log_init + observation_probs[0])
            # fwd_prob :: batch_shape num_states

            # Marginalize over the hidden state at the final step.
            log_prob = tf.reduce_logsumexp(fwd_prob, axis=-1)
            # log_prob :: batch_shape

            return log_prob
コード例 #6
0
 def _log_moment(self, n, concentration1=None, concentration0=None):
     """Compute the log of the n'th (uncentered) moment.

     Evaluates `log(c0) + lbeta(1 + n / c1, c0)`, where `c1` and `c0` are the
     two concentrations broadcast to a common shape. Concentrations that are
     not supplied default to this distribution's parameters.
     """
     if concentration0 is None:
         concentration0 = tf.convert_to_tensor(self.concentration0)
     if concentration1 is None:
         concentration1 = tf.convert_to_tensor(self.concentration1)
     # The sum is used only to obtain the broadcast shape of the two params.
     common_shape = tf.shape(concentration1 + concentration0)
     c1 = tf.broadcast_to(concentration1, common_shape)
     c0 = tf.broadcast_to(concentration0, common_shape)
     lbeta_args = tf.stack([1 + n / c1, c0], -1)
     return tf.math.log(c0) + tf.math.lbeta(lbeta_args)
コード例 #7
0
    def _marginal_hidden_probs(self):
        """Compute marginal pdf for each individual observable.

        Step 0 is the initial state distribution. Each later step applies the
        transition matrix once more (in log space). The result is
        exponentiated at the end.

        Returns:
          `Tensor` shaped `[num_steps] + batch_shape + [num_states]`.
        """
        init_shape = tf.concat([self.batch_shape_tensor(), [self._num_states]],
                               axis=0)
        log_probs_0 = tf.broadcast_to(self._log_init, init_shape)
        # log_probs_0 :: batch_shape num_states

        def _only_initial_step():
            """Single-step case: just the initial distribution."""
            return log_probs_0[tf.newaxis, ...]

        def _scan_multiple_steps():
            """Perform `scan` operation when `num_steps` > 1."""
            log_trans = self._log_trans

            def _advance(log_probs, _):
                return _log_vector_matrix(log_probs, log_trans)

            # Only the number of scanned elements (num_steps - 1) matters;
            # their values are ignored by `_advance`.
            placeholder_steps = tf.zeros(self._num_steps - 1, dtype=tf.float32)
            later_log_probs = tf.scan(_advance,
                                      placeholder_steps,
                                      initializer=log_probs_0,
                                      name="forward_log_probs")
            return tf.concat([[log_probs_0], later_log_probs], axis=0)

        all_log_probs = prefer_static.cond(self._num_steps > 1,
                                           _scan_multiple_steps,
                                           _only_initial_step)
        return tf.exp(all_log_probs)
コード例 #8
0
ファイル: linalg.py プロジェクト: HackerShohag/SuggestBot-bn
def _swap_m_with_i(vecs, m, i):
    """Swaps `m` and `i` on axis -1. (Helper for pivoted_cholesky.)

    Given a batch of int64 vectors `vecs`, scalar index `m`, and compatibly
    shaped per-vector indices `i`, this function swaps elements `m` and `i` in
    each vector. For the use-case below, these are permutation vectors.

    Args:
      vecs: Vectors on which we perform the swap, int64 `Tensor`.
      m: Scalar int64 `Tensor`, the index into which the `i`th element is going.
      i: Batch int64 `Tensor`, shaped like vecs.shape[:-1] + [1]; the index into
        which the `m`th element is going.

    Returns:
      vecs: The updated vectors.
    """
    vecs = tf.convert_to_tensor(vecs, dtype=tf.int64, name='vecs')
    m = tf.convert_to_tensor(m, dtype=tf.int64, name='m')
    i = tf.convert_to_tensor(i, dtype=tf.int64, name='i')
    static_shape = vecs.shape

    # Elements after position `m`: the one sitting at position `i` is replaced
    # by the old element `m`; all others are kept.
    tail = vecs[..., m + 1:]
    vec_len = prefer_static.shape(vecs, out_type=tf.int64)[-1]
    tail_positions = tf.broadcast_to(tf.range(m + 1, vec_len),
                                     prefer_static.shape(tail))
    new_tail = tf.where(tf.equal(tail_positions, i),
                        tf.gather(vecs, [m], axis=-1), tail)

    # Position `m` receives the old element `i`, gathered per batch member.
    # TODO(bjp): Could we use tensor_scatter_nd_update?
    num_batch_dims = int(prefer_static.rank(vecs)) - 1
    swapped = tf.concat(
        [vecs[..., :m],
         tf.gather(vecs, i, batch_dims=num_batch_dims),
         new_tail],
        axis=-1)
    tensorshape_util.set_shape(swapped, static_shape)
    return swapped
コード例 #9
0
    def _mean(self, **kwargs):
        """Mean of the transformed distribution (constant-Jacobian bijectors only).

        The base distribution's mean is reshaped and broadcast when batch or
        event shapes are overridden, then pushed through `bijector.forward`.

        Raises:
          NotImplementedError: if the bijector's Jacobian is not constant.
        """
        if not self.bijector.is_constant_jacobian:
            raise NotImplementedError("mean is not implemented for non-affine "
                                      "bijectors")

        distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)
        base_mean = self.distribution.mean(**distribution_kwargs)

        has_override = (self._is_maybe_batch_override
                        or self._is_maybe_event_override)
        if has_override:
            # An override is only permitted when the corresponding base shape
            # is [], so interleaving ones with the base shapes and then
            # broadcasting produces the overridden shape.
            singleton_shape = prefer_static.concat([
                prefer_static.ones_like(self._override_batch_shape),
                self.distribution.batch_shape_tensor(),
                prefer_static.ones_like(self._override_event_shape),
                self.distribution.event_shape_tensor(),
            ], 0)
            base_mean = tf.reshape(base_mean, singleton_shape)
            target_shape = prefer_static.concat(
                [self.batch_shape_tensor(), self.event_shape_tensor()], 0)
            base_mean = tf.broadcast_to(base_mean, target_shape)

        transformed = self.bijector.forward(base_mean, **bijector_kwargs)

        empty_sample_shape = tf.convert_to_tensor([],
                                                  dtype=tf.int32,
                                                  name="sample_shape")
        return self._set_sample_static_shape(transformed, empty_sample_shape)
コード例 #10
0
ファイル: beta.py プロジェクト: HackerShohag/SuggestBot-bn
 def _sample_n(self, n, seed=None):
     """Draw `n` samples as `g1 / (g1 + g0)` from two gamma draws.

     `g1` uses `concentration1` and `g0` uses `concentration0` as `alpha`, each
     broadcast to the batch shape and seeded from a `SeedStream`, in that order.
     """
     seed_stream = SeedStream(seed, "beta")
     concentration1 = tf.convert_to_tensor(self.concentration1)
     concentration0 = tf.convert_to_tensor(self.concentration0)
     batch_shape = self._batch_shape_tensor(concentration1, concentration0)
     alpha1 = tf.broadcast_to(concentration1, batch_shape)
     alpha0 = tf.broadcast_to(concentration0, batch_shape)

     def _draw_gamma(alpha):
         return tf.random.gamma(shape=[n],
                                alpha=alpha,
                                dtype=self.dtype,
                                seed=seed_stream())

     # Call order matters: each call consumes the next seed from the stream.
     g1 = _draw_gamma(alpha1)
     g0 = _draw_gamma(alpha0)
     return g1 / (g1 + g0)
コード例 #11
0
 def _cdf(self, x):
   """CDF from the regularized incomplete beta function of `df / (y**2 + df)`.

   Here `y` is `x` standardized by `loc` and `|scale|`. The lower-tail mass is
   returned for `y < 0`; its complement is returned otherwise.
   """
   df = tf.convert_to_tensor(self.df)
   # Take Abs(scale) so the sign test on `y` below is meaningful.
   y = (x - self.loc) / tf.abs(self.scale)
   ratio = df / (y**2. + df)
   half_df = 0.5 * tf.broadcast_to(df, prefer_static.shape(ratio))
   lower_tail = 0.5 * tf.math.betainc(half_df, 0.5, ratio)
   return tf.where(y < 0., lower_tail, 1. - lower_tail)
コード例 #12
0
ファイル: linalg.py プロジェクト: HackerShohag/SuggestBot-bn
def cholesky_concat(chol, cols, name=None):
    """Concatenates `chol @ chol.T` with additional rows and columns.

    This operation is conceptually identical to:
    ```python
    def cholesky_concat_slow(chol, cols):  # cols shaped (n + m) x m = z x m
      mat = tf.matmul(chol, chol, adjoint_b=True)  # batch of n x n
      # Concat columns.
      mat = tf.concat([mat, cols[..., :tf.shape(mat)[-2], :]], axis=-1)  # n x z
      # Concat rows.
      mat = tf.concat([mat, tf.linalg.matrix_transpose(cols)], axis=-2)  # z x z
      return tf.linalg.cholesky(mat)
    ```
    but whereas `cholesky_concat_slow` would cost `O(z**3)` work,
    `cholesky_concat` only costs `O(z**2 + m**3)` work.

    The resulting (implicit) matrix must be symmetric and positive definite.
    Thus, the bottom right `m x m` must be self-adjoint, and we do not require a
    separate `rows` argument (which can be inferred from `conj(cols.T)`).

    Args:
      chol: Cholesky decomposition of `mat = chol @ chol.T`.
      cols: The new columns whose first `n` rows we would like concatenated to
        the right of `mat = chol @ chol.T`, and whose conjugate transpose we
        would like concatenated to the bottom of `concat(mat, cols[:n,:])`. A
        `Tensor` with final dims `(n+m, m)`. The first `n` rows are the top
        right rectangle (their conjugate transpose forms the bottom left), and
        the bottom `m x m` is self-adjoint.
      name: Optional name for this op.
        Default value: `None` (i.e., 'cholesky_concat').

    Returns:
      chol_concat: The Cholesky decomposition of:
        ```
        [ [ mat  cols[:n, :] ]
          [   conj(cols.T)   ] ]
        ```
    """
    # Default scope now matches the function name (was 'cholesky_extend').
    with tf.name_scope(name or 'cholesky_concat'):
        dtype = dtype_util.common_dtype([chol, cols], dtype_hint=tf.float32)
        chol = tf.convert_to_tensor(chol, name='chol', dtype=dtype)
        cols = tf.convert_to_tensor(cols, name='cols', dtype=dtype)
        n = prefer_static.shape(chol)[-1]
        # Split the new columns into the block beside `mat` (n x m) and the
        # new bottom-right block (m x m).
        mat_nm, mat_mm = cols[..., :n, :], cols[..., n:, :]
        # Block Cholesky update: the bottom-left factor is the conjugate
        # transpose of the triangular solve against `chol`. The bottom-right
        # factor is the Cholesky of the Schur complement
        # `mat_mm - solved^H @ solved`.
        solved_nm = linear_operator_util.matrix_triangular_solve_with_broadcast(
            chol, mat_nm)
        lower_right_mm = tf.linalg.cholesky(
            mat_mm - tf.matmul(solved_nm, solved_nm, adjoint_a=True))
        lower_left_mn = tf.math.conj(tf.linalg.matrix_transpose(solved_nm))
        # `solved_nm` carries the broadcast batch shape of `chol` and `cols`;
        # broadcast `chol` to it so the blocks can be concatenated.
        out_batch = prefer_static.shape(solved_nm)[:-2]
        chol = tf.broadcast_to(
            chol,
            tf.concat([out_batch, prefer_static.shape(chol)[-2:]], axis=0))
        top_right_zeros_nm = tf.zeros_like(solved_nm)
        return tf.concat([
            tf.concat([chol, top_right_zeros_nm], axis=-1),
            tf.concat([lower_left_mn, lower_right_mm], axis=-1)
        ], axis=-2)
コード例 #13
0
 def _sample_n(self, n, seed=None):
     """Return `loc` broadcast to `[n] + batch_shape + event_shape`.

     Sampling is deterministic here, so `seed` is accepted but ignored.
     """
     del seed  # unused
     loc = tf.convert_to_tensor(self.loc)
     batch_shape = self._batch_shape_tensor(loc=loc)
     event_shape = self._event_shape_tensor(loc=loc)
     output_shape = tf.concat([[n], batch_shape, event_shape], axis=0)
     return tf.broadcast_to(loc, output_shape)
コード例 #14
0
 def _variance(self):
     """Probability-weighted squared deviation of the outcomes from the mean."""
     probs = self._categorical.probs_parameter()
     outcomes = tf.broadcast_to(self.outcomes,
                                shape=dist_util.prefer_static_shape(probs))
     if dtype_util.is_integer(outcomes.dtype):
         # Integer outcomes are cast to the dtype of `probs`. With validation
         # enabled, the cast is first guarded by
         # `embed_check_integer_casting_closed`.
         if self._validate_args:
             outcomes = dist_util.embed_check_integer_casting_closed(
                 outcomes, target_dtype=probs.dtype)
         outcomes = tf.cast(outcomes, dtype=probs.dtype)
     mean = self._mean(probs)[..., tf.newaxis]
     squared_deviation = tf.math.squared_difference(outcomes, mean)
     weighted = probs * squared_deviation
     return tf.reduce_sum(weighted, axis=-1)
コード例 #15
0
ファイル: sample.py プロジェクト: HackerShohag/SuggestBot-bn
 def _fn(self, **kwargs):
     """Implements summary statistic, eg, mean, stddev, mode.

     The statistic `attr` of the underlying distribution is computed, given
     singleton axes where the sample dimensions belong, and then broadcast to
     `batch_shape + sample_shape + event_shape`.
     """
     stat = getattr(self.distribution, attr)(**kwargs)
     batch_shape = self.distribution.batch_shape_tensor()
     event_shape = self.distribution.event_shape_tensor()
     sample_rank = prefer_static.rank_from_shape(self.sample_shape)
     unit_sample_shape = prefer_static.ones(sample_rank,
                                            dtype=self.sample_shape.dtype)
     with_unit_samples = prefer_static.concat(
         [batch_shape, unit_sample_shape, event_shape], axis=0)
     stat = tf.reshape(stat, shape=with_unit_samples)
     full_shape = prefer_static.concat(
         [batch_shape, self.sample_shape, event_shape], axis=0)
     return tf.broadcast_to(stat, full_shape)
コード例 #16
0
    def _mean(self):
        """Per-step observation mean, marginalized over hidden states.

        Returns:
          `Tensor` shaped `batch_shape + [num_steps] + observation_event_shape`.
        """
        with tf.control_dependencies(self._runtime_assertions):
            state_probs = self._marginal_hidden_probs()
            # state_probs :: num_steps batch_shape num_states
            state_means = self._observation_distribution.mean()
            # state_means :: observation_batch_shape[:-1] num_states
            #                observation_event_shape

            batch_shape = self.batch_shape_tensor()
            event_shape = self._observation_distribution.event_shape_tensor()
            state_means = tf.broadcast_to(
                state_means,
                tf.concat([batch_shape, [self._num_states], event_shape],
                          axis=0))
            # state_means :: batch_shape num_states observation_event_shape

            batch_size = tf.reduce_prod(batch_shape)
            event_size = tf.reduce_prod(event_shape)
            flat_probs = tf.reshape(
                state_probs, [self._num_steps, batch_size, self._num_states])
            # flat_probs :: num_steps batch_size num_states
            flat_means = tf.reshape(
                state_means, [batch_size, self._num_states, event_size])
            # flat_means :: batch_size num_states observation_event_size

            # Probability-weighted sum over states; batch moves to the front.
            flat_mean = tf.einsum("ijk,jkl->jil", flat_probs, flat_means)
            # flat_mean :: batch_size num_steps observation_event_size

            out_shape = tf.concat([batch_shape, [self._num_steps], event_shape],
                                  axis=0)
            return tf.reshape(flat_mean, out_shape)
コード例 #17
0
ファイル: pareto.py プロジェクト: HackerShohag/SuggestBot-bn
 def _mode(self):
     """Mode equals `scale`, broadcast to the batch shape."""
     scale = tf.convert_to_tensor(self.scale)
     batch_shape = self._batch_shape_tensor(scale=scale)
     return tf.broadcast_to(scale, batch_shape)
コード例 #18
0
    def _observation_log_probs(self, observations, mask):
        """Compute and shape the log probs associated with `observations`.

        Let E be the underlying event shape, M the number of steps and N the
        number of states. Incoming shapes are:

            observations : batch_o [M] E
            mask (optional) : batch_m [M]

        With this distribution's batch shape batch_d, all three batch shapes
        are broadcast together into `batch`. The step dimension is moved to
        the front (for folding/scanning), and a singleton axis is inserted
        before E so that `log_prob` is evaluated for every state:

            observations : [M] batch [1] E
            result       : [M] batch [N]

        Where `mask` is True, the corresponding log prob is replaced by 0.
        """
        event_rank = self._underlying_event_rank
        obs_shape = tf.shape(observations)
        obs_batch_shape = obs_shape[:-1 - event_rank]
        obs_event_shape = obs_shape[-1 - event_rank:]

        batch_shape = tf.broadcast_dynamic_shape(obs_batch_shape,
                                                 self.batch_shape_tensor())
        if mask is not None:
            # The mask's last dimension is the step dimension.
            batch_shape = tf.broadcast_dynamic_shape(batch_shape,
                                                     tf.shape(mask)[:-1])

        observations = tf.broadcast_to(
            observations, tf.concat([batch_shape, obs_event_shape], axis=0))
        obs_rank = tf.rank(observations)
        step_axis = obs_rank - event_rank - 1
        observations = distribution_util.move_dimension(observations,
                                                        step_axis, 0)
        observations = tf.expand_dims(observations, obs_rank - event_rank)
        log_probs = self._observation_distribution.log_prob(observations)

        if mask is None:
            return log_probs

        mask = tf.broadcast_to(
            mask, tf.concat([batch_shape, [self._num_steps]], axis=0))
        mask = distribution_util.move_dimension(mask, -1, 0)
        # mask :: [M] batch, extended with a trailing axis to broadcast over N.
        return tf.where(mask[..., tf.newaxis], tf.zeros_like(log_probs),
                        log_probs)
コード例 #19
0
 def _stddev(self):
   """Standard deviation `scale * pi / sqrt(3)`, broadcast to the batch shape."""
   scale = tf.convert_to_tensor(self.scale)
   batch_shape = self._batch_shape_tensor(scale=scale)
   stddev = scale * np.pi / np.sqrt(3)
   return tf.broadcast_to(stddev, batch_shape)
コード例 #20
0
 def _mean(self):
   """Mean equals `loc`, broadcast to the batch shape."""
   loc = tf.convert_to_tensor(self.loc)
   batch_shape = self._batch_shape_tensor(loc=loc)
   return tf.broadcast_to(loc, batch_shape)
コード例 #21
0
 def _entropy(self):
   """Entropy `2 + log(scale)`, broadcast to the batch shape."""
   scale = tf.convert_to_tensor(self.scale)
   entropy = 2. + tf.math.log(scale)
   return tf.broadcast_to(entropy, self._batch_shape_tensor(scale=scale))
コード例 #22
0
ファイル: laplace.py プロジェクト: HackerShohag/SuggestBot-bn
 def _stddev(self):
     """Standard deviation `sqrt(2) * scale`, broadcast to the batch shape."""
     scale = tf.convert_to_tensor(self.scale)
     stddev = np.sqrt(2.) * scale
     return tf.broadcast_to(stddev, self._batch_shape_tensor(scale=scale))
コード例 #23
0
ファイル: linalg.py プロジェクト: HackerShohag/SuggestBot-bn
def lu_reconstruct(lower_upper, perm, validate_args=False, name=None):
    """The inverse LU decomposition, `X == lu_reconstruct(*tf.linalg.lu(X))`.

    Args:
      lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
        `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
      perm: `p` as returned by `tf.linalg.lu`, i.e., if
        `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`. Either int32 or
        int64 is accepted.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
        Default value: `False` (i.e., don't validate arguments).
      name: Python `str` name given to ops managed by this object.
        Default value: `None` (i.e., 'lu_reconstruct').

    Returns:
      x: The original input to `tf.linalg.lu`, i.e., `x` as in,
        `lu_reconstruct(*tf.linalg.lu(x))`.

    #### Examples

    ```python
    import numpy as np
    from tensorflow_probability.python.internal.backend import numpy as tf
    import tensorflow_probability as tfp; tfp = tfp.experimental.substrates.numpy

    x = [[[3., 4], [1, 2]],
         [[7., 8], [3, 4]]]
    x_reconstructed = tfp.math.lu_reconstruct(*tf.linalg.lu(x))
    tf.assert_near(x, x_reconstructed)
    # ==> True
    ```

    """
    with tf.name_scope(name or 'lu_reconstruct'):
        lower_upper = tf.convert_to_tensor(lower_upper,
                                           dtype_hint=tf.float32,
                                           name='lower_upper')
        perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')

        assertions = lu_reconstruct_assertions(lower_upper, perm,
                                               validate_args)
        if assertions:
            with tf.control_dependencies(assertions):
                lower_upper = tf.identity(lower_upper)
                perm = tf.identity(perm)

        shape = tf.shape(lower_upper)

        # L is the strict lower triangle with an implicit unit diagonal;
        # U is the upper triangle (diagonal included).
        lower = tf.linalg.set_diag(
            tf.linalg.band_part(lower_upper, num_lower=-1, num_upper=0),
            tf.ones(shape[:-1], dtype=lower_upper.dtype))
        upper = tf.linalg.band_part(lower_upper, num_lower=0, num_upper=-1)
        x = tf.matmul(lower, upper)

        if (tensorshape_util.rank(lower_upper.shape) is None
                or tensorshape_util.rank(lower_upper.shape) != 2):
            # We either don't know the batch rank or there are >0 batch dims.
            batch_size = tf.reduce_prod(shape[:-2])
            d = shape[-1]
            x = tf.reshape(x, [batch_size, d, d])
            perm = tf.reshape(perm, [batch_size, d])
            perm = tf.map_fn(tf.math.invert_permutation, perm)
            # `tf.range` yields int32 here; cast to `perm.dtype` so the
            # `tf.stack` below also works for int64 permutations.
            batch_indices = tf.cast(
                tf.broadcast_to(tf.range(batch_size)[:, tf.newaxis],
                                [batch_size, d]),
                perm.dtype)
            x = tf.gather_nd(x, tf.stack([batch_indices, perm], axis=-1))
            x = tf.reshape(x, shape)
        else:
            x = tf.gather(x, tf.math.invert_permutation(perm))

        # Use the `tensorshape_util` helper rather than the `Tensor.set_shape`
        # method, which backend arrays (e.g. the numpy substrate in the example
        # above) may not provide. NOTE(review): confirm the helper is a no-op
        # for such arrays.
        tensorshape_util.set_shape(x, lower_upper.shape)
        return x
コード例 #24
0
ファイル: linalg.py プロジェクト: HackerShohag/SuggestBot-bn
def lu_solve(lower_upper, perm, rhs, validate_args=False, name=None):
    """Solves systems of linear eqns `A X = RHS`, given LU factorizations.

    Note: this function does not verify the implied matrix is actually
    invertible nor is this condition checked even when `validate_args=True`.

    Args:
      lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
        `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
      perm: `p` as returned by `tf.linalg.lu`, i.e., if
        `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`. Either int32 or
        int64 is accepted.
      rhs: Matrix-shaped float `Tensor` representing targets for which to
        solve; `A X = RHS`. To handle vector cases, use:
        `lu_solve(..., rhs[..., tf.newaxis])[..., 0]`.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness. Note: this function does not verify the
        implied matrix is actually invertible, even when `validate_args=True`.
        Default value: `False` (i.e., don't validate arguments).
      name: Python `str` name given to ops managed by this object.
        Default value: `None` (i.e., 'lu_solve').

    Returns:
      x: The `X` in `A @ X = RHS`.

    #### Examples

    ```python
    import numpy as np
    from tensorflow_probability.python.internal.backend import numpy as tf
    import tensorflow_probability as tfp; tfp = tfp.experimental.substrates.numpy

    x = [[[1., 2],
          [3, 4]],
         [[7, 8],
          [3, 4]]]
    inv_x = tfp.math.lu_solve(*tf.linalg.lu(x), rhs=tf.eye(2))
    tf.assert_near(tf.linalg.inv(x), inv_x)
    # ==> True
    ```

    """

    with tf.name_scope(name or 'lu_solve'):
        lower_upper = tf.convert_to_tensor(lower_upper,
                                           dtype_hint=tf.float32,
                                           name='lower_upper')
        perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')
        rhs = tf.convert_to_tensor(rhs,
                                   dtype_hint=lower_upper.dtype,
                                   name='rhs')

        assertions = _lu_solve_assertions(lower_upper, perm, rhs,
                                          validate_args)
        if assertions:
            with tf.control_dependencies(assertions):
                lower_upper = tf.identity(lower_upper)
                perm = tf.identity(perm)
                rhs = tf.identity(rhs)

        if (tensorshape_util.rank(rhs.shape) == 2
                and tensorshape_util.rank(perm.shape) == 1):
            # Both rhs and perm have scalar batch_shape.
            permuted_rhs = tf.gather(rhs, perm, axis=-2)
        else:
            # Either rhs or perm have non-scalar batch_shape or we can't
            # determine this information statically.
            rhs_shape = tf.shape(rhs)
            broadcast_batch_shape = tf.broadcast_dynamic_shape(
                rhs_shape[:-2],
                tf.shape(perm)[:-1])
            d, m = rhs_shape[-2], rhs_shape[-1]
            rhs_broadcast_shape = tf.concat([broadcast_batch_shape, [d, m]],
                                            axis=0)

            # Tile out rhs.
            broadcast_rhs = tf.broadcast_to(rhs, rhs_broadcast_shape)
            broadcast_rhs = tf.reshape(broadcast_rhs, [-1, d, m])

            # Tile out perm and add batch indices.
            broadcast_perm = tf.broadcast_to(perm, rhs_broadcast_shape[:-1])
            broadcast_perm = tf.reshape(broadcast_perm, [-1, d])
            broadcast_batch_size = tf.reduce_prod(broadcast_batch_shape)
            # `tf.range` yields int32 here; cast to the permutation's dtype so
            # the `tf.stack` below also works for int64 permutations.
            broadcast_batch_indices = tf.cast(
                tf.broadcast_to(
                    tf.range(broadcast_batch_size)[:, tf.newaxis],
                    [broadcast_batch_size, d]),
                broadcast_perm.dtype)
            broadcast_perm = tf.stack(
                [broadcast_batch_indices, broadcast_perm], axis=-1)

            permuted_rhs = tf.gather_nd(broadcast_rhs, broadcast_perm)
            permuted_rhs = tf.reshape(permuted_rhs, rhs_broadcast_shape)

        # Solve L (U X) = permuted RHS: first against the unit-diagonal lower
        # factor, then against the upper factor.
        lower = tf.linalg.set_diag(
            tf.linalg.band_part(lower_upper, num_lower=-1, num_upper=0),
            tf.ones(tf.shape(lower_upper)[:-1], dtype=lower_upper.dtype))
        return linear_operator_util.matrix_triangular_solve_with_broadcast(
            lower_upper,  # Only upper is accessed.
            linear_operator_util.matrix_triangular_solve_with_broadcast(
                lower, permuted_rhs),
            lower=False)
コード例 #25
0
ファイル: linalg.py プロジェクト: HackerShohag/SuggestBot-bn
def pivoted_cholesky(matrix, max_rank, diag_rtol=1e-3, name=None):
    """Computes the (partial) pivoted Cholesky decomposition of `matrix`.

    The pivoted Cholesky is a low rank approximation of the Cholesky
    decomposition of `matrix`, as described in [(Harbrecht et al., 2012)][1].
    At each iteration the currently worst-approximated diagonal element (the
    largest entry of the residual diagonal `diag(matrix - lr @ lr.T)`) is
    selected as the pivot. From a `[B1...Bn, N, N]` shaped `matrix` this yields
    a `[B1...Bn, N, K]` shaped approximation `lr`, with `K = min(max_rank, N)`,
    such that `lr @ lr.T ~= matrix`. If iteration terminates early, the
    trailing (never computed) columns of `lr` are zero.

    Note that, unlike the Cholesky decomposition, `lr` is not triangular even
    in a rectangular-matrix sense. However, under a row permutation it could be
    made triangular: each column has one more zero than the column to its left.

    Such a matrix can be useful as a preconditioner for conjugate gradient
    optimization, e.g. as in [(Wang et al. 2019)][2], since matmuls and solves
    can be done cheaply via the Woodbury matrix identity, as implemented by
    `tf.linalg.LinearOperatorLowRankUpdate`.

    Args:
      matrix: Floating point `Tensor` batch of symmetric, positive definite
        matrices. Its rank must be known statically.
      max_rank: Scalar `int` `Tensor`, the rank at which to truncate the
        approximation. Values larger than `N` are clipped to `N`.
      diag_rtol: Scalar floating point `Tensor` (same dtype as `matrix`).
        Relative tolerance for early termination. Iteration stops once, for
        every batch member, the L1 norm of the residual diagonal
        `diag(matrix - lr @ lr.T)`, divided by the largest diagonal element of
        the original `matrix`, is at most `diag_rtol`. At least one iteration
        is always run when `max_rank > 0`.
      name: Optional name for the op.

    Returns:
      lr: Low rank pivoted Cholesky approximation of `matrix`, shaped
        `[B1...Bn, N, K]`.

    Raises:
      NotImplementedError: If the rank of `matrix` is not known statically.

    #### References

    [1]: H Harbrecht, M Peters, R Schneider. On the low-rank approximation by the
         pivoted Cholesky decomposition. _Applied numerical mathematics_,
         62(4):428-440, 2012.

    [2]: K. A. Wang et al. Exact Gaussian Processes on a Million Data Points.
         _arXiv preprint arXiv:1903.08114_, 2019. https://arxiv.org/abs/1903.08114
    """
    with tf.name_scope(name or 'pivoted_cholesky'):
        # Resolve a single dtype for `matrix` and `diag_rtol`; tf.float32 is
        # only the hint handed to `dtype_util.common_dtype`.
        dtype = dtype_util.common_dtype([matrix, diag_rtol],
                                        dtype_hint=tf.float32)
        matrix = tf.convert_to_tensor(matrix, name='matrix', dtype=dtype)
        # The static rank is required below to compute `batch_dims` for
        # `tf.gather`.
        if tensorshape_util.rank(matrix.shape) is None:
            raise NotImplementedError(
                'Rank of `matrix` must be known statically')

        max_rank = tf.convert_to_tensor(max_rank,
                                        name='max_rank',
                                        dtype=tf.int64)
        # There are only `N` columns to pivot on, so cap the rank at `N`.
        max_rank = tf.minimum(
            max_rank,
            prefer_static.shape(matrix, out_type=tf.int64)[-1])
        diag_rtol = tf.convert_to_tensor(diag_rtol,
                                         dtype=dtype,
                                         name='diag_rtol')
        # Residual diagonal `diag(matrix - lr @ lr.T)`. Since `lr` starts at
        # zero, this is initially just the diagonal of `matrix`.
        matrix_diag = tf.linalg.diag_part(matrix)
        # matrix is P.D., therefore all matrix_diag > 0, so we don't need abs.
        # Per-batch-member normalizer for the stopping criterion in `cond`.
        orig_error = tf.reduce_max(matrix_diag, axis=-1)

        def cond(m, pchol, perm, matrix_diag):
            """Condition for `tf.while_loop` continuation.

            Keeps iterating while fewer than `max_rank` rows have been
            computed, and either no row has been computed yet or the worst
            batch member's relative residual (L1 norm of the residual diagonal
            divided by `orig_error`) still exceeds `diag_rtol`.
            """
            del pchol
            del perm
            error = tf.linalg.norm(matrix_diag, ord=1, axis=-1)
            # Max over the whole batch: all batch members iterate together, so
            # the loop stops only once every member meets the tolerance.
            max_err = tf.reduce_max(error / orig_error)
            return (m < max_rank) & (tf.equal(m, 0) | (max_err > diag_rtol))

        # Number of leading batch dimensions of `matrix`.
        batch_dims = tensorshape_util.rank(matrix.shape) - 2

        def batch_gather(params, indices, axis=-1):
            """`tf.gather` that treats the leading `batch_dims` dims as batch."""
            return tf.gather(params, indices, axis=axis, batch_dims=batch_dims)

        def body(m, pchol, perm, matrix_diag):
            """Body of a single `tf.while_loop` iteration.

            Computes row `m` of `pchol` (the transposed factor, whose rows are
            stored in the original, unpermuted column order), updates the
            pivot permutation `perm`, and subtracts the new row's squares from
            the residual diagonal `matrix_diag`.
            """
            # Here is roughly a numpy, non-batched version of what's going to happen.
            # (See also Algorithm 1 of Harbrecht et al.)
            # 1: maxi = np.argmax(matrix_diag[perm[m:]]) + m
            # 2: maxval = matrix_diag[perm][maxi]
            # 3: perm[m], perm[maxi] = perm[maxi], perm[m]
            # 4: row = matrix[perm[m]][perm[m + 1:]]
            # 5: row -= np.sum(pchol[:m, perm[m + 1:]] * pchol[:m, perm[m:m + 1]], axis=-2)
            # 6: pivot = np.sqrt(maxval); row /= pivot
            # 7: row = np.concatenate([[[pivot]], row], -1)
            # 8: matrix_diag[perm[m:]] -= row**2
            # 9: pchol[m, perm[m:]] = row

            # Find the maximal position of the (remaining) permuted diagonal.
            # Steps 1, 2 above.
            permuted_diag = batch_gather(matrix_diag, perm[..., m:])
            maxi = tf.argmax(permuted_diag, axis=-1,
                             output_type=tf.int64)[..., tf.newaxis]
            maxval = batch_gather(permuted_diag, maxi)
            # `argmax` ran over `perm[..., m:]`; shift back to an index into
            # the full `perm`.
            maxi = maxi + m
            maxval = maxval[..., 0]
            # Update perm: Swap perm[...,m] with perm[...,maxi]. Step 3 above.
            perm = _swap_m_with_i(perm, m, maxi)
            # Step 4: row `perm[m]` of `matrix`, restricted to the not-yet-
            # pivoted columns `perm[m + 1:]`.
            row = batch_gather(matrix, perm[..., m:m + 1], axis=-2)
            row = batch_gather(row, perm[..., m + 1:])
            # Step 5: subtract the contribution of the `m` rows computed so far.
            prev_rows = pchol[..., :m, :]
            prev_rows_perm_m_onward = batch_gather(prev_rows, perm[...,
                                                                   m + 1:])
            prev_rows_pivot_col = batch_gather(prev_rows, perm[..., m:m + 1])
            row -= tf.reduce_sum(prev_rows_perm_m_onward * prev_rows_pivot_col,
                                 axis=-2)[..., tf.newaxis, :]
            # Step 6.
            pivot = tf.sqrt(maxval)[..., tf.newaxis, tf.newaxis]
            # Step 7: `row` now covers permuted columns `m:`, pivot first.
            row = tf.concat([pivot, row / pivot], axis=-1)
            # TODO(b/130899118): Pad grad fails with int64 paddings.
            # Step 8.
            # Left-pad only the last axis with `m` zeros (the already-pivoted
            # columns) so `row` spans all `N` permuted columns. The paddings
            # are int32 because of the TODO above.
            paddings = tf.concat([
                tf.zeros([prefer_static.rank(pchol) - 1, 2], dtype=tf.int32),
                [[tf.cast(m, tf.int32), 0]]
            ],
                                 axis=0)
            diag_update = tf.pad(row**2, paddings=paddings)[..., 0, :]
            # `_invert_permutation` is defined elsewhere. Presumably it returns
            # the inverse of `perm`, so gathering with it maps permuted
            # positions back to the original column order.
            reverse_perm = _invert_permutation(perm)
            matrix_diag -= batch_gather(diag_update, reverse_perm)
            # Step 9.
            row = tf.pad(row, paddings=paddings)
            # TODO(bjp): Defer the reverse permutation all-at-once at the end?
            row = batch_gather(row, reverse_perm)
            # Replace row `m` of `pchol`. Slicing with the loop tensor `m`
            # loses static shape information, so the original static shape is
            # restored to keep the loop variable's shape consistent across
            # `tf.while_loop` iterations.
            pchol_shape = pchol.shape
            pchol = tf.concat([pchol[..., :m, :], row, pchol[..., m + 1:, :]],
                              axis=-2)
            tensorshape_util.set_shape(pchol, pchol_shape)
            return m + 1, pchol, perm, matrix_diag

        # Loop counter; int64 to match `max_rank` in the comparison in `cond`.
        m = np.int64(0)
        # `pchol` holds `lr` transposed, shape `[B1...Bn, K, N]`. It is
        # zero-filled, so rows not reached before early termination stay zero.
        pchol = tf.zeros_like(matrix[..., :max_rank, :])
        matrix_shape = prefer_static.shape(matrix, out_type=tf.int64)
        # Start every batch member from the identity permutation `0..N-1`.
        perm = tf.broadcast_to(prefer_static.range(matrix_shape[-1]),
                               matrix_shape[:-1])
        _, pchol, _, _ = tf.while_loop(cond=cond,
                                       body=body,
                                       loop_vars=(m, pchol, perm, matrix_diag))
        pchol = tf.linalg.matrix_transpose(pchol)
        # Static shape `[B1...Bn, N, None]`; the rank dimension is left unknown.
        tensorshape_util.set_shape(
            pchol, tensorshape_util.concatenate(matrix_diag.shape, [None]))
        return pchol