Esempio n. 1
0
 def _sample_n(self, n, seed=None):
     """Draws `n` samples by resampling the stored samples with replacement.

     Args:
       n: Number of draws.
       seed: Optional seed for the uniform index draw.

     Returns:
       `Tensor` whose leading axis indexes the `n` draws, followed by the axes
       that preceded `self._samples_axis` and then the `self._event_ndims`
       trailing axes.
     """
     stored = tf.convert_to_tensor(self._samples)
     num_stored = self._compute_num_samples(stored)
     # One uniformly random stored-sample index per draw.
     picks = tf.random.uniform([n], maxval=num_stored, dtype=tf.int32,
                               seed=seed)
     draws = tf.gather(stored, picks, axis=self._samples_axis)
     # `draws` holds the new draws at `self._samples_axis`; bring that axis to
     # the front while keeping every other axis in its original order.
     leading_axes = tf.range(self._samples_axis, dtype=tf.int32)
     trailing_axes = (tf.range(self._event_ndims, dtype=tf.int32) +
                      self._samples_axis + 1)
     perm = tf.concat([[self._samples_axis], leading_axes, trailing_axes],
                      axis=0)
     return tf.transpose(a=draws, perm=perm)
Esempio n. 2
0
def _maybe_validate_perm(perm, validate_args, name=None):
    """Checks that `perm` is a valid permutation vector.

    Properties that are statically known (dtype, rank, value) are checked
    immediately and raise. Properties only knowable at runtime are returned as
    assertion ops when `validate_args` is `True`.

    Args:
      perm: Integer `Tensor` expected to be a permutation of
        `[0, 1, ..., size(perm) - 1]`.
      validate_args: Python `bool`; if `True`, emit runtime assertions for the
        properties that cannot be checked statically.
      name: Optional name scope. Default value: `'maybe_validate_perm'`.

    Returns:
      assertions: Python `list` of assertion ops (possibly empty).

    Raises:
      TypeError: if `perm` does not have an integer dtype.
      ValueError: if `perm` is statically known not to be a vector, or not to
        be a permutation.
    """
    with tf.name_scope(name or 'maybe_validate_perm'):
        assertions = []
        if not dtype_util.is_integer(perm.dtype):
            raise TypeError('`perm` must be integer type')

        msg = '`perm` must be a vector.'
        perm_rank = tensorshape_util.rank(perm.shape)
        if perm_rank is not None:
            if perm_rank != 1:
                raise ValueError(msg[:-1] + ', saw rank: {}.'.format(perm_rank))
        elif validate_args:
            assertions += [assert_util.assert_rank(perm, 1, message=msg)]

        perm_ = tf.get_static_value(perm)
        msg = '`perm` must be a valid permutation vector.'
        if perm_ is not None:
            # A permutation of 0..N-1 sorts to exactly 0..N-1.
            if not np.all(np.arange(np.size(perm_)) == np.sort(perm_)):
                raise ValueError(msg[:-1] + ', saw: {}.'.format(perm_))
        elif validate_args:
            # Build the reference range in `perm.dtype`. `tf.range` would
            # otherwise produce int32, and `assert_equal` rejects mismatched
            # dtypes (e.g. an int64 `perm`).
            assertions += [
                assert_util.assert_equal(tf.sort(perm),
                                         tf.range(tf.size(perm),
                                                  dtype=perm.dtype),
                                         message=msg)
            ]

        return assertions
Esempio n. 3
0
    def _mode(self, samples=None):
        """Returns the most frequent stored sample for each batch member.

        Args:
          samples: Optional `Tensor` of samples. Default value:
            `tf.convert_to_tensor(self._samples)`.

        Returns:
          `Tensor` holding one modal sample per batch member, reshaped to the
          batch shape (followed by the event shape when `self._event_ndims > 0`).
        """
        # Samples count can vary by batch member. Use map_fn to compute mode for
        # each batch separately.
        def _get_mode(samples):
            # TODO(b/123985779): Switch to tf.unique_with_counts_v2 when exposed
            unique = gen_array_ops.unique_with_counts_v2(samples, axis=[0])
            # `unique.count[k]` is the multiplicity of the k-th *distinct* value
            # (distinct values are ordered by first appearance). So
            # `argmax(count)` indexes the unique values, not `samples`. Map it
            # back to a position in `samples`: every position whose `idx`
            # equals that distinct-value id holds the mode.
            mode_id = tf.cast(tf.argmax(unique.count), unique.idx.dtype)
            is_mode = tf.cast(tf.equal(unique.idx, mode_id), tf.int32)
            return tf.argmax(is_mode, output_type=tf.int64)

        if samples is None:
            samples = tf.convert_to_tensor(self._samples)
        num_samples = self._compute_num_samples(samples)

        # Flatten samples for each batch.
        if self._event_ndims == 0:
            flattened_samples = tf.reshape(samples, [-1, num_samples])
            mode_shape = self._batch_shape_tensor(samples)
        else:
            event_size = tf.reduce_prod(self._event_shape_tensor(samples))
            mode_shape = tf.concat([
                self._batch_shape_tensor(samples),
                self._event_shape_tensor(samples)
            ],
                                   axis=0)
            flattened_samples = tf.reshape(samples,
                                           [-1, num_samples, event_size])

        # `indices[b]` is a position within batch member `b`'s samples.
        indices = tf.map_fn(_get_mode, flattened_samples, dtype=tf.int64)
        full_indices = tf.stack(
            [tf.range(tf.shape(indices)[0]),
             tf.cast(indices, tf.int32)],
            axis=1)

        mode = tf.gather_nd(flattened_samples, full_indices)
        return tf.reshape(mode, mode_shape)
Esempio n. 4
0
def _swap_m_with_i(vecs, m, i):
    """Swaps `m` and `i` on axis -1. (Helper for pivoted_cholesky.)

    Given a batch of int64 vectors `vecs`, scalar index `m`, and compatibly
    shaped per-vector indices `i`, this function swaps elements `m` and `i` in
    each vector. For the use-case below, these are permutation vectors.

    Only positions after `m` are compared against `i`, so the swap is complete
    only when `i >= m` in every vector. If some `i < m`, position `i` lies in
    the untouched prefix and is not overwritten. NOTE(review): callers
    presumably guarantee `i >= m`; confirm at the call site.

    Args:
      vecs: Vectors on which we perform the swap, int64 `Tensor`.
      m: Scalar int64 `Tensor`, the index into which the `i`th element is going.
      i: Batch int64 `Tensor`, shaped like vecs.shape[:-1] + [1]; the index into
        which the `m`th element is going.

    Returns:
      vecs: The updated vectors.
    """
    vecs = tf.convert_to_tensor(vecs, dtype=tf.int64, name='vecs')
    m = tf.convert_to_tensor(m, dtype=tf.int64, name='m')
    i = tf.convert_to_tensor(i, dtype=tf.int64, name='i')
    # Position indices m+1, ..., size-1, broadcast to the shape of
    # `vecs[..., m + 1:]`.
    trailing_elts = tf.broadcast_to(
        tf.range(m + 1,
                 prefer_static.shape(vecs, out_type=tf.int64)[-1]),
        prefer_static.shape(vecs[..., m + 1:]))
    # At the trailing position equal to `i`, place the old element `m`;
    # elsewhere keep the original trailing values.
    trailing_elts = tf.where(tf.equal(trailing_elts, i),
                             tf.gather(vecs, [m], axis=-1), vecs[..., m + 1:])
    # TODO(bjp): Could we use tensor_scatter_nd_update?
    # Record the static shape so it can be reapplied after the concat below,
    # which (with a tensor-valued `m`) may not be able to infer it.
    vecs_shape = vecs.shape
    # New layout: unchanged prefix, the old element `i` at position `m`
    # (gathered per vector via `batch_dims`), then the updated trailing part.
    vecs = tf.concat([
        vecs[..., :m],
        tf.gather(vecs, i, batch_dims=int(prefer_static.rank(vecs)) - 1),
        trailing_elts
    ],
                     axis=-1)
    tensorshape_util.set_shape(vecs, vecs_shape)
    return vecs
 def _inverse(self, y):
     """Maps each value of `y` to the index of the closest `map_values` entry.

     Args:
       y: `Tensor` of values, each expected to be (approximately) an entry of
         `self.map_values`.

     Returns:
       `Tensor` of indices into `self.map_values`, with the same shape as `y`.

     Raises:
       InvalidArgumentError: at runtime, only when `self.validate_args` is set,
         if some value of `y` is not near any entry of `map_values`.
     """
     map_values = tf.convert_to_tensor(self.map_values)
     flat_y = tf.reshape(y, shape=[-1])
     # Search for the indices of map_values that are closest to flat_y.
     # Since map_values is strictly increasing, the closest is either the
     # first one that is strictly greater than flat_y, or the one before it.
     upper_candidates = tf.minimum(
         tf.size(map_values) - 1,
         tf.searchsorted(map_values, values=flat_y, side='right'))
     lower_candidates = tf.maximum(0, upper_candidates - 1)
     candidates = tf.stack([lower_candidates, upper_candidates], axis=-1)
     lower_cand_diff = tf.abs(flat_y - self._forward(lower_candidates))
     upper_cand_diff = tf.abs(flat_y - self._forward(upper_candidates))
     if self.validate_args:
         with tf.control_dependencies([
                 assert_util.assert_near(tf.minimum(lower_cand_diff,
                                                    upper_cand_diff),
                                         0,
                                         message='inverse value not found')
         ]):
             candidates = tf.identity(candidates)
     # For each flattened value, choose column 0 (lower) or 1 (upper),
     # whichever candidate is nearer.
     candidate_selector = tf.stack([
         tf.range(tf.size(flat_y), dtype=tf.int32),
         tf.argmin([lower_cand_diff, upper_cand_diff], output_type=tf.int32)
     ],
                                   axis=-1)
     # Use the dynamic shape: `y.shape` may be only partially known in graph
     # mode, and `tf.reshape` cannot accept a partially known TensorShape.
     return tf.reshape(tf.gather_nd(candidates, candidate_selector),
                       shape=tf.shape(y))
def _extract_log_probs(num_states, dist):
    """Evaluates `dist.log_prob` at every integer state in `[0, num_states)`.

    Args:
      num_states: Number of states to tabulate.
      dist: Distribution (possibly batched) exposing `log_prob` and
        `batch_shape_tensor`.

    Returns:
      `Tensor` of log-probabilities with the state axis moved to the last
      position.
    """
    # Lay the states out along one leading axis, followed by a singleton axis
    # per batch dimension, so `log_prob` broadcasts them against the batch.
    state_values = tf.range(num_states)
    singleton_batch_dims = tf.ones_like(dist.batch_shape_tensor())
    states_shape = tf.concat([[num_states], singleton_batch_dims], axis=0)
    states = tf.reshape(state_values, states_shape)
    table = dist.log_prob(states)
    return distribution_util.move_dimension(table, 0, -1)
Esempio n. 7
0
 def _inverse(self, y):
     """Inverts the iterated-sigmoid map from the simplex back to R^{N-1}.

     As specified in the Stan reference manual, with N = y.shape[-1]:
       z_k = y_k / (1 - sum_{i=1 to k-1} y_i)
       x_k = logit(z_k) - log(1 / (N - k))
     """
     dtype = dtype_util.base_dtype(y.dtype)
     num_entries = tf.shape(y)[-1]
     # log(N - k) for k = 1, ..., N - 1, i.e. -log(1 / (N - k)).
     offset = tf.math.log(
         tf.cast(tf.range(num_entries - 1, 0, delta=-1), dtype=dtype))
     remaining_mass = 1. - tf.math.cumsum(y, axis=-1, exclusive=True)
     # The last coordinate is fully determined by the others, so drop it.
     z = (y / remaining_mass)[..., :-1]
     logit_z = tf.math.log(z) - tf.math.log1p(-z)
     return logit_z + offset
Esempio n. 8
0
 def _prob(self, event):
     """Returns the fraction of stored samples that equal `event`.

     The result is cast to `self.dtype` only when that dtype is floating.
     """
     samples = tf.convert_to_tensor(self._samples)
     num_samples = self._compute_num_samples(samples)
     event = tf.convert_to_tensor(event, name='event', dtype=self.dtype)
     event, samples = _broadcast_event_and_samples(
         event, samples, event_ndims=self._event_ndims)
     # A sample matches only if every one of its event coordinates matches.
     event_axes = tf.range(-self._event_ndims, 0)
     is_match = tf.reduce_all(tf.equal(samples, event), axis=event_axes)
     num_matches = tf.reduce_sum(tf.cast(is_match, dtype=tf.int32), axis=-1)
     prob = num_matches / num_samples
     if not dtype_util.is_floating(self.dtype):
         return prob
     return tf.cast(prob, self.dtype)
Esempio n. 9
0
 def _make_perm(self, x_rank, perm):
     """Extends `perm` over the rightmost dims to a permutation of all dims.

     The leading `x_rank - rightmost_transposed_ndims` dims keep their
     positions; the entries of `perm` are shifted to index past them.
     """
     num_leading = (
         distribution_util.prefer_static_value(x_rank) -
         distribution_util.prefer_static_value(self.rightmost_transposed_ndims))
     dtype = perm.dtype
     identity_part = tf.range(tf.cast(num_leading, dtype))
     shifted_part = tf.cast(
         num_leading + distribution_util.prefer_static_value(perm), dtype)
     return tf.concat([identity_part, shifted_part], axis=0)
 def _compute_quantiles():
     """Returns `dist` quantiles at interior grid edges, edge axis last."""
     # Drop the endpoints 0 and 1, since their quantiles might be Inf/NaN.
     zero = tf.zeros([], dtype=dist.dtype)
     all_edges = tf.linspace(zero, 1., quadrature_size + 3)
     interior_edges = all_edges[1:-1]
     # Give the edges one leading axis plus a singleton axis per batch dim, so
     # that `quantile` broadcasts them across the batch.
     edges_shape = tf.concat(
         [[-1], tf.ones([batch_ndims], dtype=tf.int32)], axis=0)
     edges = tf.reshape(interior_edges, shape=edges_shape)
     raw_quantiles = dist.quantile(edges)
     # Rotate the axes left by one, moving the edge axis to the end.
     rotate_left = tf.concat([tf.range(1, 1 + batch_ndims), [0]], axis=0)
     return tf.transpose(a=raw_quantiles, perm=rotate_left)
Esempio n. 11
0
 def _forward(self, x):
     """Maps unconstrained `x` onto the simplex via iterated sigmoids.

     As specified in the Stan reference manual, with N = x.shape[-1] + 1:
       z_k = sigmoid(x + log(1 / (N - k)))
       y_1 = z_1
       y_k = (1 - sum_{i=1 to k-1} y_i) * z_k
       y_N = 1 - sum_{i=1 to N-1} y_i
     """
     # TODO(b/128857065): The numerics can possibly be improved here with a
     # log-space computation.
     dtype = dtype_util.base_dtype(x.dtype)
     remaining_counts = tf.cast(tf.range(tf.shape(x)[-1], 0, delta=-1),
                                dtype=dtype)
     # -log(N - k) for k = 1, ..., N - 1.
     offset = -tf.math.log(remaining_counts)
     z = tf.math.sigmoid(x + offset)
     # 1 - sum_{i<k} y_i equals prod_{i<k} (1 - z_i).
     y = z * tf.math.cumprod(1 - z, axis=-1, exclusive=True)
     last_entry = 1. - tf.reduce_sum(y, axis=-1, keepdims=True)
     return tf.concat([y, last_entry], axis=-1)
Esempio n. 12
0
 def _inverse_log_det_jacobian(self, y):
     # The inverse log det jacobian (ILDJ) of the whole mapping is the sum of
     # the ILDJs of the per-row mappings.
     #
     # Fix the `k`th (1-indexed) row. The forward map restricted to that row,
     # `f_k : R^{k-1} -> R^k`, sends unconstrained reals to unit vectors:
     #
     #   f(x_1, x_2, ..., x_{k-1}) = (x_1/s, x_2/s, ..., x_{k-1}/s, 1/s)
     #
     # with `s = norm(x_1, x_2, ..., x_{k-1}, 1)`.
     #
     # The change in infinitesimal `(k-1)`-dimensional volume (surface area) is
     # sqrt(|det(J^T J)|), where J is the `k x (k-1)` Jacobian matrix.
     #
     # Claim: sqrt(|det(J^T J)|) = s^{-k}.
     #
     # Proof: The entries of J are
     #
     #     J_{i, j} =  -x_j / s^3           if i == k
     #     J_{i, j} =  (s^2 - x_i^2) / s^3  if i == j and i < k
     #     J_{i, j} = -(x_i * x_j) / s^3    if i != j and i < k
     #
     #   By spherical symmetry the volume element depends only on `s`, so
     #   w.l.o.g. take `x_1 = r` and `x_2 = ... = x_{k-1} = 0`, where
     #   `r^2 + 1 = s^2`.
     #
     #   Then `J^T = [A|B]`, where `A` is the `(k-1) x (k-1)` diagonal matrix
     #   `diag(1/s^3, 1/s, ..., 1/s)` and `B` is the column vector
     #   `(-r/s^3, 0, ..., 0)`. Hence `J^T J = A^2 + B B^T`, and
     #
     #     det(J^T J) = det(diag((r^2 + 1) / s^6, 1/s^2, ..., 1/s^2))
     #                = s^{-4} * s^{-2(k-2)}
     #                = s^{-2k},
     #
     #   i.e. sqrt(|det(J^T J)|) = s^{-k}.
     #
     # So the forward log det jacobian (FLDJ) for row `k` is `-k * log(s)`. The
     # ILDJ is the negated FLDJ at the pre-image, `k * log(s)`, where `s` is the
     # reciprocal of the `k`th diagonal entry of `y`, i.e. `-k * log(y_kk)`.
     num_rows = prefer_static.shape(y)[-1]
     row_weights = tf.range(1, num_rows + 1, dtype=y.dtype)
     log_diag = tf.math.log(tf.linalg.diag_part(y))
     return -tf.reduce_sum(row_weights * log_diag, axis=-1)
    def _sample_n(self, n, seed=None):
        """Draws `n` samples from the quadrature-compound Poisson.

        A quadrature component id is drawn from the mixture for every sample
        and batch member. That id selects a Poisson rate, and a Poisson variate
        is drawn at that rate.

        Args:
          n: Number of samples to draw.
          seed: Optional seed.

        Returns:
          `Tensor` of shape `[n] + batch_shape` with dtype `self.dtype`.
        """
        distributions = self.poisson_and_mixture_distributions()
        poisson_dist, mixture = distributions
        batch_size = tensorshape_util.num_elements(self.batch_shape)
        if batch_size is None:
            batch_size = tf.reduce_prod(
                self._batch_shape_tensor(distributions=distributions))
        stream = SeedStream(seed, salt='PoissonLogNormalQuadratureCompound')
        # A mixture with a scalar batch must be sampled once per batch member;
        # otherwise its own batch already supplies one probs vector per batch
        # coordinate. Only these two layouts are supported.
        extra_sample_dims = distribution_util.pick_vector(
            mixture.is_scalar_batch(), [batch_size], np.int32([]))
        ids = mixture.sample(
            sample_shape=concat_vectors([n], extra_sample_dims),
            seed=stream())
        # Flatten all batch dims (the mixture may carry its own batch dims),
        # producing ids of shape [n, batch_size], or [n] for a scalar batch.
        flat_batch_dims = distribution_util.pick_vector(
            self.is_scalar_batch(), np.int32([]), np.int32([-1]))
        ids = tf.reshape(ids, shape=concat_vectors([n], flat_batch_dims))

        # Batch member b's rates start at b * quadrature_size in the flattened
        # rate tensor.
        offset = tf.range(start=0,
                          limit=batch_size * self._quadrature_size,
                          delta=self._quadrature_size,
                          dtype=ids.dtype)
        flat_rates = tf.reshape(poisson_dist.rate, shape=[-1])
        rate = tf.gather(flat_rates, ids + offset)
        sample_and_batch_shape = concat_vectors(
            [n], self._batch_shape_tensor(distributions=distributions))
        rate = tf.reshape(rate, shape=sample_and_batch_shape)
        return tf.random.poisson(lam=rate,
                                 shape=[],
                                 dtype=self.dtype,
                                 seed=seed)
Esempio n. 14
0
  def _sample_n(self, n, seed=None):
    """Draws `n` samples from the mixture.

    Two code paths are used:

    * With `self._use_static_graph`, every component is sampled `n` times, and
      a one-hot mask built from the categorical draws keeps only the selected
      component's value.
    * Otherwise the categorical draws are partitioned by component. Each
      component is sampled only as many times as it was selected, and the
      results are stitched back into sample order.

    Args:
      n: Number of samples to draw.
      seed: Optional seed. It is passed directly to the categorical draw and,
        through `SeedStream(seed, salt="Mixture")`, to the component draws.

    Returns:
      `Tensor` of samples. Per the inline shape comments this is `[n, B, E]`
      (sample, batch, event); the dynamic path reshapes to
      `cat_samples.shape + event_shape`.
    """
    if self._use_static_graph:
      with tf.control_dependencies(self._assertions):
        # This sampling approach is almost the same as the approach used by
        # `MixtureSameFamily`. The differences are due to having a list of
        # `Distribution` objects rather than a single object, and maintaining
        # random seed management that is consistent with the non-static code
        # path.
        samples = []
        cat_samples = self.cat.sample(n, seed=seed)
        stream = SeedStream(seed, salt="Mixture")

        for c in range(self.num_components):
          samples.append(self.components[c].sample(n, seed=stream()))
        # Stack components on a new axis placed just before the event dims.
        stack_axis = -1 - tensorshape_util.rank(self._static_event_shape)
        x = tf.stack(samples, axis=stack_axis)  # [n, B, k, E]
        npdt = dtype_util.as_numpy_dtype(x.dtype)
        mask = tf.one_hot(
            indices=cat_samples,  # [n, B]
            depth=self._num_components,  # == k
            on_value=npdt(1),
            off_value=npdt(0))  # [n, B, k]
        mask = distribution_util.pad_mixture_dimensions(
            mask, self, self._cat,
            tensorshape_util.rank(self._static_event_shape))  # [n, B, k, [1]*e]
        # Zero out every unselected component, then sum the component axis away.
        return tf.reduce_sum(x * mask, axis=stack_axis)  # [n, B, E]

    with tf.control_dependencies(self._assertions):
      # Use a Python int for `n` when its value is statically known.
      n = tf.convert_to_tensor(n, name="n")
      static_n = tf.get_static_value(n)
      n = int(static_n) if static_n is not None else n
      cat_samples = self.cat.sample(n, seed=seed)

      # For each of the sample, batch and event shapes, prefer static Python
      # values when fully defined and fall back to dynamic tensors otherwise.
      static_samples_shape = cat_samples.shape
      if tensorshape_util.is_fully_defined(static_samples_shape):
        samples_shape = tensorshape_util.as_list(static_samples_shape)
        samples_size = tensorshape_util.num_elements(static_samples_shape)
      else:
        samples_shape = tf.shape(cat_samples)
        samples_size = tf.size(cat_samples)
      static_batch_shape = self.batch_shape
      if tensorshape_util.is_fully_defined(static_batch_shape):
        batch_shape = tensorshape_util.as_list(static_batch_shape)
        batch_size = tensorshape_util.num_elements(static_batch_shape)
      else:
        batch_shape = self.batch_shape_tensor()
        batch_size = tf.reduce_prod(batch_shape)
      static_event_shape = self.event_shape
      if tensorshape_util.is_fully_defined(static_event_shape):
        event_shape = np.array(
            tensorshape_util.as_list(static_event_shape), dtype=np.int32)
      else:
        event_shape = self.event_shape_tensor()

      # Get indices into the raw cat sampling tensor. We will
      # need these to stitch sample values back out after sampling
      # within the component partitions.
      samples_raw_indices = tf.reshape(tf.range(0, samples_size), samples_shape)

      # Partition the raw indices so that we can use
      # dynamic_stitch later to reconstruct the samples from the
      # known partitions.
      partitioned_samples_indices = tf.dynamic_partition(
          data=samples_raw_indices,
          partitions=cat_samples,
          num_partitions=self.num_components)

      # Copy the batch indices n times, as we will need to know
      # these to pull out the appropriate rows within the
      # component partitions.
      batch_raw_indices = tf.reshape(
          tf.tile(tf.range(0, batch_size), [n]), samples_shape)

      # Explanation of the dynamic partitioning below:
      #   batch indices are i.e., [0, 1, 0, 1, 0, 1]
      # Suppose partitions are:
      #     [1 1 0 0 1 1]
      # After partitioning, batch indices are cut as:
      #     [batch_indices[x] for x in 2, 3]
      #     [batch_indices[x] for x in 0, 1, 4, 5]
      # i.e.
      #     [1 1] and [0 0 0 0]
      # Now we sample n=2 from part 0 and n=4 from part 1.
      # For part 0 we want samples from batch entries 1, 1 (samples 0, 1),
      # and for part 1 we want samples from batch entries 0, 0, 0, 0
      #   (samples 0, 1, 2, 3).
      partitioned_batch_indices = tf.dynamic_partition(
          data=batch_raw_indices,
          partitions=cat_samples,
          num_partitions=self.num_components)
      samples_class = [None for _ in range(self.num_components)]

      stream = SeedStream(seed, salt="Mixture")

      for c in range(self.num_components):
        # Draw only as many samples as component `c` was selected.
        n_class = tf.size(partitioned_samples_indices[c])
        samples_class_c = self.components[c].sample(
            n_class, seed=stream())

        # Pull out the correct batch entries from each index.
        # To do this, we may have to flatten the batch shape.

        # For sample s, batch element b of component c, we get the
        # partitioned batch indices from
        # partitioned_batch_indices[c]; and shift each element by
        # the sample index. The final lookup can be thought of as
        # a matrix gather along locations (s, b) in
        # samples_class_c where the n_class rows correspond to
        # samples within this component and the batch_size columns
        # correspond to batch elements within the component.
        #
        # Thus the lookup index is
        #   lookup[c, i] = batch_size * s[i] + b[c, i]
        # for i = 0 ... n_class[c] - 1.
        lookup_partitioned_batch_indices = (
            batch_size * tf.range(n_class) + partitioned_batch_indices[c])
        samples_class_c = tf.reshape(
            samples_class_c, tf.concat([[n_class * batch_size], event_shape],
                                       0))
        samples_class_c = tf.gather(
            samples_class_c,
            lookup_partitioned_batch_indices,
            name="samples_class_c_gather")
        samples_class[c] = samples_class_c

      # Stitch back together the samples across the components.
      lhs_flat_ret = tf.dynamic_stitch(
          indices=partitioned_samples_indices, data=samples_class)
      # Reshape back to proper sample, batch, and event shape.
      ret = tf.reshape(
          lhs_flat_ret, tf.concat(
              [samples_shape, self.event_shape_tensor()], 0))
      tensorshape_util.set_shape(
          ret,
          tensorshape_util.concatenate(static_samples_shape, self.event_shape))
      return ret
Esempio n. 15
0
    def _log_prob(self, x):
        """Computes the log density at `x`.

        Evaluates
        `(df - dimension - 1) * sum(log(diag(L))) - tr(inv(V) X) / 2
        - log_normalization()`,
        where `V = S S'` with `S = self.scale_operator`, and `X = L L'` with `L`
        the lower-triangular Cholesky factor of `x`.

        Args:
          x: `Tensor` of matrices. When `self.input_output_cholesky` is `True`,
            `x` is used as the Cholesky factor `L` directly; otherwise `L` is
            computed with `tf.linalg.cholesky(x)`.

        Returns:
          log_prob: `Tensor` of log densities, one per matrix in `x` (the
            leading sample and batch dims are kept).
        """
        if self.input_output_cholesky:
            x_sqrt = x
        else:
            # Complexity: O(nbk**3)
            x_sqrt = tf.linalg.cholesky(x)

        batch_shape = self.batch_shape_tensor()
        event_shape = self.event_shape_tensor()
        # Left-pad `x_sqrt` with singleton axes so it has at least
        # `batch_ndims + 2` dims, making the sample/batch/event split below
        # well defined.
        x_ndims = tf.rank(x_sqrt)
        num_singleton_axes_to_prepend = (
            tf.maximum(tf.size(batch_shape) + 2, x_ndims) - x_ndims)
        x_with_prepended_singletons_shape = tf.concat([
            tf.ones([num_singleton_axes_to_prepend], dtype=tf.int32),
            tf.shape(x_sqrt)
        ], 0)
        x_sqrt = tf.reshape(x_sqrt, x_with_prepended_singletons_shape)
        ndims = tf.rank(x_sqrt)
        # sample_ndims = ndims - batch_ndims - event_ndims
        sample_ndims = ndims - tf.size(batch_shape) - 2
        sample_shape = tf.shape(x_sqrt)[:sample_ndims]

        # We need to be able to pre-multiply each matrix by its corresponding
        # batch scale matrix. Since a Distribution Tensor supports multiple
        # samples per batch, this means we need to reshape the input matrix `x`
        # so that the first b dimensions are batch dimensions and the last two
        # are of shape [dimension, dimensions*number_of_samples]. Doing these
        # gymnastics allows us to do a batch_solve.
        #
        # After we're done with sqrt_solve (the batch operation) we need to undo
        # this reshaping so what we're left with is a Tensor partitionable by
        # sample, batch, event dimensions.

        # Complexity: O(nbk**2) since transpose must access every element.
        scale_sqrt_inv_x_sqrt = x_sqrt
        # Move the sample dims to the end: [B, k, k, S].
        perm = tf.concat(
            [tf.range(sample_ndims, ndims),
             tf.range(0, sample_ndims)], 0)
        scale_sqrt_inv_x_sqrt = tf.transpose(a=scale_sqrt_inv_x_sqrt,
                                             perm=perm)
        last_dim_size = (
            tf.cast(self.dimension, dtype=tf.int32) *
            tf.reduce_prod(x_with_prepended_singletons_shape[:sample_ndims]))
        shape = tf.concat([
            x_with_prepended_singletons_shape[sample_ndims:-2],
            [tf.cast(self.dimension, dtype=tf.int32), last_dim_size]
        ],
                          axis=0)
        scale_sqrt_inv_x_sqrt = tf.reshape(scale_sqrt_inv_x_sqrt, shape)

        # Complexity: O(nbM*k) where M is the complexity of the operator solving a
        # vector system. For LinearOperatorLowerTriangular, each solve is O(k**2) so
        # this step has complexity O(nbk^3).
        scale_sqrt_inv_x_sqrt = self.scale_operator.solve(
            scale_sqrt_inv_x_sqrt)

        # Undo make batch-op ready.
        # Complexity: O(nbk**2)
        shape = tf.concat(
            [tf.shape(scale_sqrt_inv_x_sqrt)[:-2], event_shape, sample_shape],
            axis=0)
        scale_sqrt_inv_x_sqrt = tf.reshape(scale_sqrt_inv_x_sqrt, shape)
        # Move the sample dims back to the front.
        perm = tf.concat([
            tf.range(ndims - sample_ndims, ndims),
            tf.range(0, ndims - sample_ndims)
        ], 0)
        scale_sqrt_inv_x_sqrt = tf.transpose(a=scale_sqrt_inv_x_sqrt,
                                             perm=perm)

        # Write V = SS', X = LL'. Then:
        # tr[inv(V) X] = tr[inv(S)' inv(S) L L']
        #              = tr[inv(S) L L' inv(S)']
        #              = tr[(inv(S) L) (inv(S) L)']
        #              = sum_{ik} (inv(S) L)_{ik}**2
        # The second equality follows from the cyclic permutation property.
        # Complexity: O(nbk**2)
        trace_scale_inv_x = tf.reduce_sum(tf.square(scale_sqrt_inv_x_sqrt),
                                          axis=[-2, -1])

        # sum(log(diag(L))) = 0.5 * log|X|, since X = L L'.
        # Complexity: O(nbk)
        half_log_det_x = tf.reduce_sum(tf.math.log(
            tf.linalg.diag_part(x_sqrt)),
                                       axis=[-1])

        # Complexity: O(nbk**2)
        log_prob = ((self.df - self.dimension - 1.) * half_log_det_x -
                    0.5 * trace_scale_inv_x - self.log_normalization())

        # Set shape hints.
        # Try to merge what we know from the input x with what we know from the
        # parameters of this distribution.
        if tensorshape_util.rank(
                x.shape) is not None and tensorshape_util.rank(
                    self.batch_shape) is not None:
            tensorshape_util.set_shape(
                log_prob,
                tf.broadcast_static_shape(x.shape[:-2], self.batch_shape))

        return log_prob
Esempio n. 16
0
    def __init__(self,
                 perm=None,
                 rightmost_transposed_ndims=None,
                 validate_args=False,
                 name='transpose'):
        """Instantiates the `Transpose` bijector.

        Args:
          perm: Non-negative `int32` vector-shaped `Tensor` representing a
            permutation of the rightmost dims (for the forward transformation).
            Note that the `0`th index represents the first of the rightmost
            dims, and the largest value must be
            `rightmost_transposed_ndims - 1`, which corresponds to
            `tf.rank(x) - 1`.
            Only one of `perm` and `rightmost_transposed_ndims` can (and must)
            be specified.
            Default value (reverses the rightmost dims):
            `tf.range(start=rightmost_transposed_ndims - 1, limit=-1, delta=-1)`.
          rightmost_transposed_ndims: Positive `int32` scalar-shaped `Tensor`
            representing the number of rightmost dimensions to permute.
            Only one of `perm` and `rightmost_transposed_ndims` can (and must)
            be specified.
            Default value: `tf.size(perm)`.
          validate_args: Python `bool` indicating whether arguments should be
            checked for correctness.
          name: Python `str` name given to ops managed by this object.

        Raises:
          ValueError: if both or neither `perm` and `rightmost_transposed_ndims`
            are specified.
          NotImplementedError: if `rightmost_transposed_ndims` is not known
            prior to graph execution.
        """
        with tf.name_scope(name) as name:
            if (rightmost_transposed_ndims is None) == (perm is None):
                raise ValueError('Must specify exactly one of '
                                 '`rightmost_transposed_ndims` and `perm`.')
            if rightmost_transposed_ndims is not None:
                rightmost_transposed_ndims = tf.convert_to_tensor(
                    rightmost_transposed_ndims,
                    dtype_hint=np.int32,
                    name='rightmost_transposed_ndims')
                rightmost_transposed_ndims_ = tf.get_static_value(
                    rightmost_transposed_ndims)
                assertions = _maybe_validate_rightmost_transposed_ndims(
                    rightmost_transposed_ndims, validate_args)
                if assertions:
                    with tf.control_dependencies(assertions):
                        rightmost_transposed_ndims = tf.identity(
                            rightmost_transposed_ndims)
                # Default perm reverses the rightmost dims: [ndims-1, ..., 0].
                perm_start = (distribution_util.prefer_static_value(
                    rightmost_transposed_ndims) - 1)
                perm = tf.range(start=perm_start,
                                limit=-1,
                                delta=-1,
                                name='perm')
            else:  # perm is not None:
                perm = tf.convert_to_tensor(perm,
                                            dtype_hint=np.int32,
                                            name='perm')
                rightmost_transposed_ndims = tf.size(
                    perm, name='rightmost_transposed_ndims')
                rightmost_transposed_ndims_ = tf.get_static_value(
                    rightmost_transposed_ndims)
                assertions = _maybe_validate_perm(perm, validate_args)
                if assertions:
                    with tf.control_dependencies(assertions):
                        perm = tf.identity(perm)

            # TODO(b/110828604): If bijector base class ever supports dynamic
            # `min_event_ndims`, then this class already works dynamically and the
            # following five lines can be removed.
            if rightmost_transposed_ndims_ is None:
                raise NotImplementedError(
                    '`rightmost_transposed_ndims` must be '
                    'known prior to graph execution.')
            else:
                rightmost_transposed_ndims_ = int(rightmost_transposed_ndims_)

            self._perm = perm
            self._rightmost_transposed_ndims = rightmost_transposed_ndims
            super(Transpose, self).__init__(
                forward_min_event_ndims=rightmost_transposed_ndims_,
                is_constant_jacobian=True,
                validate_args=validate_args,
                name=name)
    def _sample_n(self, n, seed=None):
        """Draws `n` samples from the vector diffeomixture.

        A single base sample is pushed through every endpoint affine. A grid
        weight is then looked up per sample and batch member using an id drawn
        from `self.mixture_distribution`, and the two endpoint results are
        blended as `weight * x[0] + (1 - weight) * x[1]`.

        Args:
          n: Number of samples to draw.
          seed: Optional seed, expanded with `SeedStream`.

        Returns:
          `Tensor` of shape `[n, B, e]` (sample, batch, event).

        Raises:
          NotImplementedError: if `self.endpoint_affine` does not yield exactly
            two results (raised only after the sampling ops above are built).
        """
        stream = SeedStream(seed, salt="VectorDiffeomixture")
        x = self.distribution.sample(sample_shape=concat_vectors(
            [n], self.batch_shape_tensor(), self.event_shape_tensor()),
                                     seed=stream())  # shape: [n, B, e]
        x = [aff.forward(x) for aff in self.endpoint_affine]

        # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
        # ids as a [n]-shaped vector.
        batch_size = tensorshape_util.num_elements(self.batch_shape)
        if batch_size is None:
            batch_size = tf.reduce_prod(self.batch_shape_tensor())
        mix_batch_size = tensorshape_util.num_elements(
            self.mixture_distribution.batch_shape)
        if mix_batch_size is None:
            mix_batch_size = tf.reduce_prod(
                self.mixture_distribution.batch_shape_tensor())
        # For a non-scalar batch, draw `batch_size // mix_batch_size` extra ids
        # so that, combined with the mixture's own batch, there is one id per
        # batch member.
        ids = self.mixture_distribution.sample(sample_shape=concat_vectors(
            [n],
            distribution_util.pick_vector(self.is_scalar_batch(), np.int32([]),
                                          [batch_size // mix_batch_size])),
                                               seed=stream())
        # We need to flatten batch dims in case mixture_distribution has its own
        # batch dims.
        ids = tf.reshape(ids,
                         shape=concat_vectors([n],
                                              distribution_util.pick_vector(
                                                  self.is_scalar_batch(),
                                                  np.int32([]),
                                                  np.int32([-1]))))

        # Stride `components * quadrature_size` for `batch_size` number of times.
        # `stride` is the number of grid entries per batch member: the product of
        # the grid's last two dims.
        stride = tensorshape_util.num_elements(
            tensorshape_util.with_rank_at_least(self.grid.shape, 2)[-2:])
        if stride is None:
            stride = tf.reduce_prod(tf.shape(self.grid)[-2:])
        offset = tf.range(start=0,
                          limit=batch_size * stride,
                          delta=stride,
                          dtype=ids.dtype)

        weight = tf.gather(tf.reshape(self.grid, shape=[-1]), ids + offset)
        # At this point, weight flattened all batch dims into one.
        # We also need to append a singleton to broadcast with event dims.
        if tensorshape_util.is_fully_defined(self.batch_shape):
            new_shape = [-1] + tensorshape_util.as_list(self.batch_shape) + [1]
        else:
            new_shape = tf.concat(([-1], self.batch_shape_tensor(), [1]),
                                  axis=0)
        weight = tf.reshape(weight, shape=new_shape)

        if len(x) != 2:
            # We actually should have already triggered this exception. However as a
            # policy we're putting this exception wherever we exploit the bimixture
            # assumption.
            raise NotImplementedError(
                "Currently only bimixtures are supported; "
                "len(scale)={} is not 2.".format(len(x)))

        # Alternatively:
        # x = weight * x[0] + (1. - weight) * x[1]
        x = weight * (x[0] - x[1]) + x[1]

        return x
Esempio n. 18
0
def lu_solve(lower_upper, perm, rhs, validate_args=False, name=None):
    """Solves systems of linear eqns `A X = RHS`, given LU factorizations.

    Note: this function does not verify the implied matrix is actually
    invertible nor is this condition checked even when `validate_args=True`.

    Args:
      lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
        `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
      perm: `p` as returned by `tf.linalg.lu`, i.e., if
        `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`. Any integer
        dtype produced by `tf.linalg.lu` (`int32` or `int64`) is accepted.
      rhs: Matrix-shaped float `Tensor` representing targets for which to
        solve; `A X = RHS`. To handle vector cases, use:
        `lu_solve(..., rhs[..., tf.newaxis])[..., 0]`.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness. Note: this function does not verify the
        implied matrix is actually invertible, even when
        `validate_args=True`.
        Default value: `False` (i.e., don't validate arguments).
      name: Python `str` name given to ops managed by this object.
        Default value: `None` (i.e., 'lu_solve').

    Returns:
      x: The `X` in `A @ X = RHS`.

    #### Examples

    ```python
    import numpy as np
    from tensorflow_probability.python.internal.backend import jax as tf
    import tensorflow_probability as tfp; tfp = tfp.experimental.substrates.jax

    x = [[[1., 2],
          [3, 4]],
         [[7, 8],
          [3, 4]]]
    inv_x = tfp.math.lu_solve(*tf.linalg.lu(x), rhs=tf.eye(2))
    tf.assert_near(tf.linalg.inv(x), inv_x)
    # ==> True
    ```

    """

    with tf.name_scope(name or 'lu_solve'):
        lower_upper = tf.convert_to_tensor(lower_upper,
                                           dtype_hint=tf.float32,
                                           name='lower_upper')
        perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')
        rhs = tf.convert_to_tensor(rhs,
                                   dtype_hint=lower_upper.dtype,
                                   name='rhs')

        assertions = _lu_solve_assertions(lower_upper, perm, rhs,
                                          validate_args)
        if assertions:
            # Route the inputs through `identity` ops created under the
            # assertions' control dependencies so the checks run first.
            with tf.control_dependencies(assertions):
                lower_upper = tf.identity(lower_upper)
                perm = tf.identity(perm)
                rhs = tf.identity(rhs)

        if (tensorshape_util.rank(rhs.shape) == 2
                and tensorshape_util.rank(perm.shape) == 1):
            # Both rhs and perm have scalar batch_shape.
            permuted_rhs = tf.gather(rhs, perm, axis=-2)
        else:
            # Either rhs or perm have non-scalar batch_shape or we can't
            # determine this information statically.
            rhs_shape = tf.shape(rhs)
            broadcast_batch_shape = tf.broadcast_dynamic_shape(
                rhs_shape[:-2],
                tf.shape(perm)[:-1])
            d, m = rhs_shape[-2], rhs_shape[-1]
            rhs_broadcast_shape = tf.concat([broadcast_batch_shape, [d, m]],
                                            axis=0)

            # Broadcast rhs to the common batch shape, then flatten the batch.
            broadcast_rhs = tf.broadcast_to(rhs, rhs_broadcast_shape)
            broadcast_rhs = tf.reshape(broadcast_rhs, [-1, d, m])

            # Broadcast perm likewise and pair each row index with its batch
            # index so one `gather_nd` permutes every batch member at once.
            broadcast_perm = tf.broadcast_to(perm, rhs_broadcast_shape[:-1])
            broadcast_perm = tf.reshape(broadcast_perm, [-1, d])
            broadcast_batch_size = tf.reduce_prod(broadcast_batch_shape)
            # Build the batch indices in `perm.dtype`: `tf.stack` requires
            # matching dtypes and `perm` may be int64.
            broadcast_batch_indices = tf.broadcast_to(
                tf.range(broadcast_batch_size, dtype=perm.dtype)[:, tf.newaxis],
                [broadcast_batch_size, d])
            broadcast_perm = tf.stack(
                [broadcast_batch_indices, broadcast_perm], axis=-1)

            permuted_rhs = tf.gather_nd(broadcast_rhs, broadcast_perm)
            permuted_rhs = tf.reshape(permuted_rhs, rhs_broadcast_shape)

        # `lower_upper = L + U - eye`, so its diagonal belongs to U; rebuild L
        # by keeping the lower band and writing L's unit diagonal back in.
        lower = tf.linalg.set_diag(
            tf.linalg.band_part(lower_upper, num_lower=-1, num_upper=0),
            tf.ones(tf.shape(lower_upper)[:-1], dtype=lower_upper.dtype))
        # Forward-substitute with L on the permuted rhs, then back-substitute
        # with U.
        return linear_operator_util.matrix_triangular_solve_with_broadcast(
            lower_upper,  # Only upper is accessed.
            linear_operator_util.matrix_triangular_solve_with_broadcast(
                lower, permuted_rhs),
            lower=False)
Esempio n. 19
0
def lu_reconstruct(lower_upper, perm, validate_args=False, name=None):
    """The inverse LU decomposition, `X == lu_reconstruct(*tf.linalg.lu(X))`.

    Args:
      lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
        `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
      perm: `p` as returned by `tf.linalg.lu`, i.e., if
        `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`. Any integer
        dtype produced by `tf.linalg.lu` (`int32` or `int64`) is accepted.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
        Default value: `False` (i.e., don't validate arguments).
      name: Python `str` name given to ops managed by this object.
        Default value: `None` (i.e., 'lu_reconstruct').

    Returns:
      x: The original input to `tf.linalg.lu`, i.e., `x` as in,
        `lu_reconstruct(*tf.linalg.lu(x))`.

    #### Examples

    ```python
    import numpy as np
    from tensorflow_probability.python.internal.backend import jax as tf
    import tensorflow_probability as tfp; tfp = tfp.experimental.substrates.jax

    x = [[[3., 4], [1, 2]],
         [[7., 8], [3, 4]]]
    x_reconstructed = tfp.math.lu_reconstruct(*tf.linalg.lu(x))
    tf.assert_near(x, x_reconstructed)
    # ==> True
    ```

    """
    with tf.name_scope(name or 'lu_reconstruct'):
        lower_upper = tf.convert_to_tensor(lower_upper,
                                           dtype_hint=tf.float32,
                                           name='lower_upper')
        perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')

        assertions = lu_reconstruct_assertions(lower_upper, perm,
                                               validate_args)
        if assertions:
            # Route the inputs through `identity` ops created under the
            # assertions' control dependencies so the checks run first.
            with tf.control_dependencies(assertions):
                lower_upper = tf.identity(lower_upper)
                perm = tf.identity(perm)

        shape = tf.shape(lower_upper)

        # Split `lower_upper = L + U - eye` back into L (unit diagonal) and U.
        lower = tf.linalg.set_diag(
            tf.linalg.band_part(lower_upper, num_lower=-1, num_upper=0),
            tf.ones(shape[:-1], dtype=lower_upper.dtype))
        upper = tf.linalg.band_part(lower_upper, num_lower=0, num_upper=-1)
        x = tf.matmul(lower, upper)

        if (tensorshape_util.rank(lower_upper.shape) is None
                or tensorshape_util.rank(lower_upper.shape) != 2):
            # We either don't know the batch rank or there are >0 batch dims.
            batch_size = tf.reduce_prod(shape[:-2])
            d = shape[-1]
            x = tf.reshape(x, [batch_size, d, d])
            perm = tf.reshape(perm, [batch_size, d])
            perm = tf.map_fn(tf.math.invert_permutation, perm)
            # Build the batch indices in `perm.dtype`: `tf.stack` requires
            # matching dtypes and `perm` may be int64.
            batch_indices = tf.broadcast_to(
                tf.range(batch_size, dtype=perm.dtype)[:, tf.newaxis],
                [batch_size, d])
            x = tf.gather_nd(x, tf.stack([batch_indices, perm], axis=-1))
            x = tf.reshape(x, shape)
        else:
            x = tf.gather(x, tf.math.invert_permutation(perm))

        x.set_shape(lower_upper.shape)
        return x
Esempio n. 20
0
    def __init__(self, permutation, axis=-1, validate_args=False, name=None):
        """Creates the `Permute` bijector.

        Args:
          permutation: An `int`-like vector-shaped `Tensor` representing the
            permutation to apply to the `axis` dimension of the transformed
            `Tensor`.
          axis: Scalar `int` `Tensor` representing the dimension over which to
            `tf.gather`. `axis` must be relative to the end (reading left to
            right) thus must be negative.
            Default value: `-1` (i.e., right-most).
          validate_args: Python `bool` indicating whether arguments should be
            checked for correctness. When `permutation` has no static value,
            this attaches a runtime assertion that it is a valid permutation.
          name: Python `str`, name given to ops managed by this object.

        Raises:
          TypeError: if `axis` or `permutation` does not have an integer dtype.
          ValueError: if a statically-known `permutation` of size `d` is not
            a rearrangement of `{0, 1, ..., d - 1}`.
          NotImplementedError: if `axis` is not known prior to graph execution.
          NotImplementedError: if `axis` is not negative.
        """
        # Shared by the static (ValueError) and dynamic (assertion) checks.
        perm_error = ("Permutation over `d` must contain exactly one of "
                      "each of `{0, 1, ..., d}`.")
        with tf.name_scope(name or "permute") as name:
            axis = tf.convert_to_tensor(axis, name="axis")
            if not dtype_util.is_integer(axis.dtype):
                raise TypeError("axis.dtype ({}) should be `int`-like.".format(
                    dtype_util.name(axis.dtype)))

            permutation = tf.convert_to_tensor(permutation, name="permutation")
            if not dtype_util.is_integer(permutation.dtype):
                raise TypeError(
                    "permutation.dtype ({}) should be `int`-like.".format(
                        dtype_util.name(permutation.dtype)))

            static_perm = tf.get_static_value(permutation)
            if static_perm is None:
                if validate_args:
                    # `top_k` of the negation sorts ascending; the sorted
                    # values must then be exactly 0, 1, ..., size - 1.
                    neg_sorted, _ = tf.math.top_k(
                        -permutation,
                        k=tf.shape(permutation)[-1],
                        sorted=True)
                    permutation = distribution_util.with_dependencies([
                        assert_util.assert_equal(
                            -neg_sorted,
                            tf.range(tf.size(neg_sorted)),
                            message=perm_error),
                    ], permutation)
            elif set(static_perm) != set(np.arange(static_perm.size)):
                raise ValueError(perm_error)

            static_axis = tf.get_static_value(axis)
            if static_axis is None:
                raise NotImplementedError(
                    "`axis` must be known prior to graph "
                    "execution.")
            if static_axis >= 0:
                raise NotImplementedError(
                    "`axis` must be relative the rightmost "
                    "dimension, i.e., negative.")
            # `axis` is negative here, so its magnitude is the number of
            # trailing dims the bijector must see.
            forward_min_event_ndims = int(np.abs(static_axis))

            self._permutation = permutation
            self._axis = axis
            super(Permute, self).__init__(
                forward_min_event_ndims=forward_min_event_ndims,
                is_constant_jacobian=True,
                validate_args=validate_args,
                name=name)
Esempio n. 21
0
    def _sample_n(self, n, seed):
        """Draws `n` samples.

        Builds a random lower-triangular matrix with standard-normal entries
        strictly below the diagonal and the square root of Gamma draws on the
        diagonal. The Gamma draws use `alpha` from
        `self._multi_gamma_sequence(0.5 * df, self.dimension)` and
        `beta=0.5`. That matrix is left-multiplied by `self.scale_operator`.
        Unless `self.input_output_cholesky` is set, the result is multiplied
        by its own adjoint.

        Args:
          n: Number of samples to draw.
          seed: Seed for `SeedStream(seed, salt="Wishart")`, which supplies
            the normal draw first and the gamma draw second.

        Returns:
          Tensor shaped `[n] + batch_shape + event_shape`.
        """
        batch_shape = self.batch_shape_tensor()
        event_shape = self.event_shape_tensor()
        # Rank of a sample tensor: one sample dim + batch dims + two matrix
        # dims.
        ndims = tf.shape(batch_shape)[0] + 3
        seed_stream = SeedStream(seed, salt="Wishart")

        # Complexity: O(nbk**2)
        noise = tf.random.normal(
            shape=tf.concat([[n], batch_shape, event_shape], 0),
            mean=0.,
            stddev=1.,
            dtype=self.dtype,
            seed=seed_stream())

        # Complexity: O(nbk)
        # Gamma(alpha=k/2, beta=1/2) is the ChiSquared(k) distribution.
        expanded_df = self.df * tf.ones(
            self.scale_operator.batch_shape_tensor(),
            dtype=dtype_util.base_dtype(self.df.dtype))
        gamma_draws = tf.random.gamma(
            shape=[n],
            alpha=self._multi_gamma_sequence(0.5 * expanded_df,
                                             self.dimension),
            beta=0.5,
            dtype=self.dtype,
            seed=seed_stream())

        # Keep the lower triangle (O(nbk**2)), then overwrite the diagonal
        # (O(nbk)).
        factor = tf.linalg.set_diag(
            tf.linalg.band_part(noise, -1, 0),
            tf.sqrt(gamma_draws))

        # Move the sample dim last and fold it into the column dim so that a
        # single batched matmul covers every sample. Complexity: O(nbk**2)
        to_back = tf.concat([tf.range(1, ndims), [0]], 0)
        folded_shape_parts = [batch_shape, [event_shape[0]],
                              [event_shape[1] * n]]
        folded = tf.reshape(
            tf.transpose(a=factor, perm=to_back),
            tf.concat(folded_shape_parts, 0))

        # Complexity: O(nbM), where M is the cost of one operator `matmul`
        # against a k x k block; for LinearOperatorLowerTriangular that is
        # O(k^3), giving O(nbk^3) overall.
        scaled = self.scale_operator.matmul(folded)

        # Unfold and move the sample dim back to the front.
        # Complexity: O(nbk**2)
        unfolded = tf.reshape(
            scaled, tf.concat([batch_shape, event_shape, [n]], 0))
        to_front = tf.concat([[ndims - 1], tf.range(0, ndims - 1)], 0)
        samples = tf.transpose(a=unfolded, perm=to_front)

        if self.input_output_cholesky:
            return samples
        # Complexity: O(nbk**3)
        return tf.matmul(samples, samples, adjoint_b=True)