示例#1
0
def _swap_m_with_i(vecs, m, i):
    """Swap entry `m` with entry `i` along axis -1 (pivoted_cholesky helper).

    For a batch of int64 vectors `vecs`, a scalar position `m`, and compatibly
    shaped per-vector positions `i`, exchanges the elements at `m` and `i` in
    every vector. In the pivoted Cholesky use-case the vectors are
    permutation vectors.

    Args:
      vecs: int64 `Tensor` of vectors whose last axis is permuted.
      m: Scalar int64 `Tensor`; the position that receives the `i`th element.
      i: Batch int64 `Tensor` shaped like `vecs.shape[:-1] + [1]`; the position
        that receives the `m`th element.

    Returns:
      vecs: The vectors with the swap applied.
    """
    vecs = tf.convert_to_tensor(vecs, dtype=tf.int64, name='vecs')
    m = tf.convert_to_tensor(m, dtype=tf.int64, name='m')
    i = tf.convert_to_tensor(i, dtype=tf.int64, name='i')
    static_shape = vecs.shape

    tail = vecs[..., m + 1:]
    num_elts = prefer_static.shape(vecs, out_type=tf.int64)[-1]
    # Positions m+1, ..., num_elts-1, broadcast to the tail's shape, so we can
    # locate (per vector) the tail slot that `i` points at.
    tail_positions = tf.broadcast_to(tf.range(m + 1, num_elts),
                                     prefer_static.shape(tail))
    elt_m = tf.gather(vecs, [m], axis=-1)
    new_tail = tf.where(tf.equal(tail_positions, i), elt_m, tail)

    batch_rank = int(prefer_static.rank(vecs)) - 1
    elt_i = tf.gather(vecs, i, batch_dims=batch_rank)

    # TODO(bjp): Could we use tensor_scatter_nd_update?
    swapped = tf.concat([vecs[..., :m], elt_i, new_tail], axis=-1)
    tensorshape_util.set_shape(swapped, static_shape)
    return swapped
示例#2
0
    def _cdf(self, k):
        """Cumulative distribution function of the categorical at index `k`.

        Entries with `k < 0` map to 0. All other entries are clipped into
        `{0, ..., K-1}` and looked up in the running sum of `probs`.
        """
        # TODO(b/135263541): Improve numerical precision of categorical.cdf.
        probs = self.probs_parameter()
        num_categories = self._num_categories(probs)

        k, probs = _broadcast_cat_event_and_params(
            k, probs, base_dtype=dtype_util.base_dtype(self.dtype))

        # The support starts at 0, so any k < 0 has CDF 0.
        below_support = k < 0

        # k is used as a gather index below, so clip it into {0, ..., K-1}.
        clipped_k = tf.clip_by_value(tf.cast(k, tf.int32), 0,
                                     num_categories - 1)
        out_shape = tf.shape(clipped_k)

        # tf.gather needs a static `batch_dims`; flattening every batch dim
        # into a single one lets us always use batch_dims=1, even when the
        # batch shape is dynamic.
        flat_k = tf.reshape(clipped_k, [-1])[..., tf.newaxis]
        flat_probs = tf.reshape(
            probs, tf.concat(([-1], [num_categories]), axis=0))
        flat_cdf = tf.gather(tf.cumsum(flat_probs, axis=-1),
                             flat_k,
                             batch_dims=1)
        cdf = tf.reshape(flat_cdf, shape=out_shape)

        zero = np.array(0, dtype=dtype_util.as_numpy_dtype(cdf.dtype))
        return tf.where(below_support, zero, cdf)
示例#3
0
 def _log_prob(self, x):
     """Log-probability of `x` under the finite discrete distribution.

     Each value is matched (exactly or within tolerance, per
     `self._is_equal_or_close`) to the outcome just at or below it, or just
     above it. The matching category's log-probability is returned, and
     values matching neither get -inf.
     """
     x = tf.convert_to_tensor(x, name='x')
     last_index = tf.size(self.outcomes) - 1
     flat_positions = tf.searchsorted(self.outcomes,
                                      values=tf.reshape(x, shape=[-1]),
                                      side='right')
     right_indices = tf.minimum(
         last_index,
         tf.reshape(flat_positions, dist_util.prefer_static_shape(x)))
     matches_right = self._is_equal_or_close(
         x, tf.gather(self.outcomes, indices=right_indices))
     left_indices = tf.maximum(0, right_indices - 1)
     matches_left = self._is_equal_or_close(
         x, tf.gather(self.outcomes, indices=left_indices))
     # Prefer the left neighbour when it matches; otherwise use the right one.
     chosen = tf.where(matches_left, left_indices, right_indices)
     log_probs = self._categorical.log_prob(chosen)
     neg_inf = dtype_util.as_numpy_dtype(log_probs.dtype)(-np.inf)
     unmatched = tf.logical_not(matches_left | matches_right)
     return tf.where(unmatched, neg_inf, log_probs)
示例#4
0
 def _sample_n(self, n, seed=None):
     """Draws `n` samples by uniform resampling (with replacement) of `self._samples`.

     Returns the gathered draws with the new length-`n` axis moved to the front.
     """
     samples = tf.convert_to_tensor(self._samples)
     num_samples = self._compute_num_samples(samples)
     indices = tf.random.uniform([n], maxval=num_samples,
                                 dtype=tf.int32, seed=seed)
     draws = tf.gather(samples, indices, axis=self._samples_axis)
     # gather puts the n draws at position `self._samples_axis`; bring that axis
     # to the front and keep the remaining axes in their original order.
     leading_axes = tf.range(self._samples_axis, dtype=tf.int32)
     trailing_axes = (tf.range(self._event_ndims, dtype=tf.int32) +
                      self._samples_axis + 1)
     perm = tf.concat([[self._samples_axis], leading_axes, trailing_axes],
                      axis=0)
     return tf.transpose(a=draws, perm=perm)
示例#5
0
 def _forward(self, x):
     """Maps integer indices `x` to the corresponding entries of `map_values`.

     Args:
       x: Integer `Tensor` of indices into `self.map_values`.

     Returns:
       `tf.gather(self.map_values, x)`.

     When `self.validate_args` is set, an assertion checks that every index
     lies in `[0, size(map_values))`.
     """
     x = tf.convert_to_tensor(x, name='x')
     map_values = tf.convert_to_tensor(self.map_values)
     if self.validate_args:
         # Build the upper bound in x's dtype: comparing int64 indices against
         # the default int32 `tf.size` would raise a dtype-mismatch error.
         num_values = tf.size(map_values, out_type=x.dtype)
         with tf.control_dependencies([
                 assert_util.assert_equal(
                     (0 <= x) & (x < num_values),
                     True,
                     message='indices out of bound')
         ]):
             x = tf.identity(x)
     # If we want batch dims in self.map_values, we can (after broadcasting),
     # use:
     # tf.gather(self.map_values, x, batch_dims=-1, axis=-1)
     return tf.gather(map_values, indices=x)
示例#6
0
 def _cdf(self, x):
     """CDF of `x` under the finite discrete distribution.

     `x` is mapped to a category index, then the categorical CDF is evaluated
     there. The index is the `side='right'` searchsorted position if `x` is
     close to that outcome (per `self._is_equal_or_close`), otherwise the
     position just below it.
     """
     x = tf.convert_to_tensor(x, name='x')
     flat_x = tf.reshape(x, shape=[-1])
     upper_bound = tf.searchsorted(self.outcomes, values=flat_x, side='right')
     last_index = dist_util.prefer_static_shape(self.outcomes)[-1] - 1
     # Clamp so the gather stays in range when x exceeds every outcome.
     values_at_ub = tf.gather(self.outcomes,
                              indices=tf.minimum(upper_bound, last_index))
     matches_ub = self._is_equal_or_close(flat_x, values_at_ub)
     flat_indices = tf.where(matches_ub, upper_bound, upper_bound - 1)
     indices = tf.reshape(flat_indices, shape=dist_util.prefer_static_shape(x))
     return self._categorical.cdf(indices)
示例#7
0
 def _prob(self, x):
     """Point-mass density of a vector deterministic distribution.

     Returns 1 where every coordinate of `x` (last axis) is within
     `self._slack(loc)` of `loc`, and 0 elsewhere, cast to `self.dtype`.
     With `validate_args`, also checks that `x` has rank >= 1 and that its
     last dimension matches the event shape.
     """
     if self.validate_args:
         rank_check = assert_util.assert_rank_at_least(x, 1)
         last_dim = tf.gather(tf.shape(x), tf.rank(x) - 1)
         space_check = assert_util.assert_equal(
             self.event_shape_tensor(),
             last_dim,
             message=
             "Argument 'x' not defined in the same space R^k as this distribution"
         )
         with tf.control_dependencies([rank_check, space_check]):
             x = tf.identity(x)
     loc = tf.convert_to_tensor(self.loc)
     within_slack = tf.abs(x - loc) <= self._slack(loc)
     return tf.cast(tf.reduce_all(within_slack, axis=-1), dtype=self.dtype)
示例#8
0
  def _sample_n(self, n, seed=None):
    """Draws `n` samples from the Poisson-LogNormal quadrature compound.

    A quadrature component id is drawn from `mixture_dist` for each sample,
    and also for each batch member unless the mixture has a scalar batch. The
    matching rate is gathered from the flattened `dist.rate` and used to draw
    a Poisson variate.

    Args:
      n: Number of samples to draw.
      seed: Optional seed. All randomness in this method is derived from it
        through a single `SeedStream`.

    Returns:
      Samples shaped `[n] + batch_shape` with dtype `self.dtype`.
    """
    # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
    # ids as a [n]-shaped vector.
    distributions = self.poisson_and_mixture_distributions()
    dist, mixture_dist = distributions
    batch_size = tensorshape_util.num_elements(self.batch_shape)
    if batch_size is None:
      batch_size = tf.reduce_prod(
          self._batch_shape_tensor(distributions=distributions))
    # We need to 'sample extra' from the mixture distribution if it doesn't
    # already specify a probs vector for each batch coordinate.
    # We only support this kind of reduced broadcasting, i.e., there is exactly
    # one probs vector for all batch dims or one for each.
    stream = SeedStream(seed, salt='PoissonLogNormalQuadratureCompound')
    ids = mixture_dist.sample(
        sample_shape=concat_vectors(
            [n],
            distribution_util.pick_vector(
                mixture_dist.is_scalar_batch(),
                [batch_size],
                np.int32([]))),
        seed=stream())
    # We need to flatten batch dims in case mixture_dist has its own
    # batch dims.
    ids = tf.reshape(
        ids,
        shape=concat_vectors([n],
                             distribution_util.pick_vector(
                                 self.is_scalar_batch(), np.int32([]),
                                 np.int32([-1]))))

    # Stride `quadrature_size` for `batch_size` number of times.
    offset = tf.range(
        start=0,
        limit=batch_size * self._quadrature_size,
        delta=self._quadrature_size,
        dtype=ids.dtype)
    # Shift each batch member's id into its own block of `quadrature_size`
    # entries of the flattened rate vector.
    ids = ids + offset
    rate = tf.gather(tf.reshape(dist.rate, shape=[-1]), ids)
    rate = tf.reshape(
        rate, shape=concat_vectors([n], self._batch_shape_tensor(
            distributions=distributions)))
    # Draw the Poisson seed from the same stream as the mixture draw rather than
    # reusing the raw `seed` that already initialized the stream.
    return tf.random.poisson(
        lam=rate, shape=[], dtype=self.dtype, seed=stream())
    def _sample_n(self, n, seed=None):
        """Draws `n` samples from the von Mises-Fisher distribution.

        The first coordinate (the component along the mode) is sampled first.
        For `event_dim == 3` this uses `self._sample_3d`; otherwise it uses
        Wood's (1994) rejection loop. The remaining coordinates are a
        uniformly random direction, scaled so each sample has unit length.
        Samples are built around a (1, 0, ..., 0) mode and then moved onto
        `self.mean_direction` with `self._rotate`.

        Args:
          n: Number of samples to draw.
          seed: Optional seed. It is wrapped in a `SeedStream`, and successive
            calls of that stream seed the random ops below.

        Returns:
          Rotated samples. Before rotation they have shape
          `[n] + batch_shape + [event_dim]` (assumes `self._sample_3d` returns
          a trailing singleton axis like the rejection path, and that
          `self._rotate` preserves shape -- TODO confirm).
        """
        seed = SeedStream(seed, salt='vom_mises_fisher')
        # The sampling strategy relies on the fact that vMF variates are symmetric
        # about the mean direction. Accordingly, if we have a sampling strategy for
        # the away-from-mean angle, then we can uniformly sample the remaining
        # dimensions on the S^{dim-2} sphere, and rotate these samples from a
        # (1, 0, 0, ..., 0)-mode distribution into the target orientation.
        #
        # This is easy to imagine on the 1-sphere (S^1; in 2-D space): sample a
        # von-Mises distributed `x` value in [-1, 1], then uniformly select what
        # amounts to an "up" or "down" additional degree of freedom after unit
        # normalizing, followed by a final rotation to the desired mean direction
        # from a basis of (1, 0).
        #
        # On S^2 (in 3-D), selecting a vMF `x` identifies a circle in `yz` on the
        # unit sphere over which the distribution is uniform, in particular the
        # circle where x = \hat{x} intersects the unit sphere. We pick a point on
        # that circle, then rotate to the desired mean direction from a basis of
        # (1, 0, 0).
        # Use the static event size when known, else fall back to the dynamic one.
        event_dim = (tf.compat.dimension_value(self.event_shape[0])
                     or self._event_shape_tensor()[0])

        sample_batch_shape = tf.concat([[n], self._batch_shape_tensor()],
                                       axis=0)
        # `dim` is event_dim - 1 (m - 1 in Wood'94 notation), in self.dtype.
        dim = tf.cast(event_dim - 1, self.dtype)
        # NOTE(review): if event_dim is a dynamic Tensor, this Python `if` relies
        # on eager execution (or autograph) -- confirm for graph mode.
        if event_dim == 3:
            samples_dim0 = self._sample_3d(n, seed=seed)
        else:
            # Wood'94 provides a rejection algorithm to sample the x coordinate.
            # Wood'94 definition of b:
            # b = (-2 * kappa + tf.sqrt(4 * kappa**2 + dim**2)) / dim
            # https://stats.stackexchange.com/questions/156729 suggests:
            b = dim / (2 * self.concentration +
                       tf.sqrt(4 * self.concentration**2 + dim**2))
            # TODO(bjp): Integrate any useful numerical tricks from hyperspherical VAE
            #     https://github.com/nicola-decao/s-vae-tf/
            x = (1 - b) / (1 + b)
            c = self.concentration * x + dim * tf.math.log1p(-x**2)
            beta = beta_lib.Beta(dim / 2, dim / 2)

            # Loop while at least one entry still has no accepted proposal.
            def cond_fn(w, should_continue):
                del w
                return tf.reduce_any(should_continue)

            # One rejection step: propose a new w where still rejected (keep
            # accepted values), then update the rejection mask.
            def body_fn(w, should_continue):
                z = beta.sample(sample_shape=sample_batch_shape, seed=seed())
                # set_shape needed here because of b/139013403
                z.set_shape(w.shape)
                w = tf.where(should_continue,
                             (1 - (1 + b) * z) / (1 - (1 - b) * z), w)
                w = tf.debugging.check_numerics(w, 'w')
                unif = tf.random.uniform(sample_batch_shape,
                                         seed=seed(),
                                         dtype=self.dtype)
                # set_shape needed here because of b/139013403
                unif.set_shape(w.shape)
                # Keep rejecting while
                # concentration * w + dim * log1p(-x * w) - c < log(unif).
                should_continue = tf.logical_and(
                    should_continue,
                    self.concentration * w + dim * tf.math.log1p(-x * w) - c <
                    tf.math.log(unif))
                return w, should_continue

            w = tf.zeros(sample_batch_shape, dtype=self.dtype)
            should_continue = tf.ones(sample_batch_shape, dtype=tf.bool)
            samples_dim0 = tf.while_loop(cond=cond_fn,
                                         body=body_fn,
                                         loop_vars=(w, should_continue))[0]
            samples_dim0 = samples_dim0[..., tf.newaxis]
        if not self._allow_nan_stats:
            # Verify samples are w/in -1, 1, with useful error output tensors (top
            # value rather than all values).
            with tf.control_dependencies([
                    assert_util.assert_less_equal(
                        samples_dim0,
                        dtype_util.as_numpy_dtype(self.dtype)(1.01),
                        data=[
                            tf.math.top_k(tf.reshape(samples_dim0, [-1]))[0]
                        ]),
                    assert_util.assert_greater_equal(
                        samples_dim0,
                        dtype_util.as_numpy_dtype(self.dtype)(-1.01),
                        data=[
                            -tf.math.top_k(tf.reshape(-samples_dim0, [-1]))[0]
                        ])
            ]):
                samples_dim0 = tf.identity(samples_dim0)
        # The other event_dim - 1 coordinates: a uniformly random unit direction.
        samples_otherdims_shape = tf.concat(
            [sample_batch_shape, [event_dim - 1]], axis=0)
        unit_otherdims = tf.math.l2_normalize(tf.random.normal(
            samples_otherdims_shape, seed=seed(), dtype=self.dtype),
                                              axis=-1)
        samples = tf.concat(
            [
                samples_dim0,  # we must avoid sqrt(1 - (>1)**2)
                tf.sqrt(tf.maximum(1 - samples_dim0**2, 0.)) * unit_otherdims
            ],
            axis=-1)
        samples = tf.math.l2_normalize(samples, axis=-1)
        if not self._allow_nan_stats:
            samples = tf.debugging.check_numerics(samples, 'samples')

        # Runtime assert that samples are unit length.
        if not self._allow_nan_stats:
            worst, idx = tf.math.top_k(
                tf.reshape(tf.abs(1 - tf.linalg.norm(samples, axis=-1)), [-1]))
            with tf.control_dependencies([
                    assert_util.assert_near(
                        dtype_util.as_numpy_dtype(self.dtype)(0),
                        worst,
                        data=[
                            worst, idx,
                            tf.gather(tf.reshape(samples, [-1, event_dim]),
                                      idx)
                        ],
                        atol=1e-4,
                        summarize=100)
            ]):
                samples = tf.identity(samples)
        # The samples generated are symmetric around a mode at (1, 0, 0, ...., 0).
        # Now, we move the mode to `self.mean_direction` using a rotation matrix.
        if not self._allow_nan_stats:
            # Assert that the basis vector rotates to the mean direction, as expected.
            basis = tf.cast(
                tf.concat([[1.], tf.zeros([event_dim - 1])], axis=0),
                self.dtype)
            with tf.control_dependencies([
                    assert_util.assert_less(
                        tf.linalg.norm(self._rotate(basis) -
                                       self.mean_direction,
                                       axis=-1),
                        dtype_util.as_numpy_dtype(self.dtype)(1e-5))
            ]):
                return self._rotate(samples)
        return self._rotate(samples)
示例#10
0
def lu_reconstruct(lower_upper, perm, validate_args=False, name=None):
    """The inverse LU decomposition, `X == lu_reconstruct(*tf.linalg.lu(X))`.

    Args:
      lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
        `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
      perm: `p` as returned by `tf.linalg.lu`, i.e., if
        `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`. May be int32 or
        int64.
      validate_args: Python `bool` indicating whether arguments should be checked
        for correctness.
        Default value: `False` (i.e., don't validate arguments).
      name: Python `str` name given to ops managed by this object.
        Default value: `None` (i.e., 'lu_reconstruct').

    Returns:
      x: The original input to `tf.linalg.lu`, i.e., `x` as in,
        `lu_reconstruct(*tf.linalg.lu(x))`.

    #### Examples

    ```python
    import numpy as np
    from tensorflow_probability.python.internal.backend import numpy as tf
    import tensorflow_probability as tfp; tfp = tfp.experimental.substrates.numpy

    x = [[[3., 4], [1, 2]],
         [[7., 8], [3, 4]]]
    x_reconstructed = tfp.math.lu_reconstruct(*tf.linalg.lu(x))
    tf.assert_near(x, x_reconstructed)
    # ==> True
    ```

    """
    with tf.name_scope(name or 'lu_reconstruct'):
        lower_upper = tf.convert_to_tensor(lower_upper,
                                           dtype_hint=tf.float32,
                                           name='lower_upper')
        perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')

        assertions = lu_reconstruct_assertions(lower_upper, perm,
                                               validate_args)
        if assertions:
            with tf.control_dependencies(assertions):
                lower_upper = tf.identity(lower_upper)
                perm = tf.identity(perm)

        shape = tf.shape(lower_upper)

        # L is the strictly-lower part of `lower_upper` with a unit diagonal;
        # U is its upper triangle (including the diagonal).
        lower = tf.linalg.set_diag(
            tf.linalg.band_part(lower_upper, num_lower=-1, num_upper=0),
            tf.ones(shape[:-1], dtype=lower_upper.dtype))
        upper = tf.linalg.band_part(lower_upper, num_lower=0, num_upper=-1)
        x = tf.matmul(lower, upper)

        if (tensorshape_util.rank(lower_upper.shape) is None
                or tensorshape_util.rank(lower_upper.shape) != 2):
            # We either don't know the batch rank or there are >0 batch dims.
            # Flatten all batch dims into one so a single gather_nd can apply
            # each matrix's own inverse permutation to its rows.
            batch_size = tf.reduce_prod(shape[:-2])
            d = shape[-1]
            x = tf.reshape(x, [batch_size, d, d])
            perm = tf.reshape(perm, [batch_size, d])
            perm = tf.map_fn(tf.math.invert_permutation, perm)
            # Build the batch indices in perm's dtype: `tf.stack` requires
            # matching dtypes, and `perm` may be int64 (e.g. from
            # `tf.linalg.lu(..., output_idx_type=tf.int64)`).
            batch_indices = tf.broadcast_to(
                tf.cast(tf.range(batch_size), perm.dtype)[:, tf.newaxis],
                [batch_size, d])
            x = tf.gather_nd(x, tf.stack([batch_indices, perm], axis=-1))
            x = tf.reshape(x, shape)
        else:
            x = tf.gather(x, tf.math.invert_permutation(perm))

        x.set_shape(lower_upper.shape)
        return x
示例#11
0
def lu_solve(lower_upper, perm, rhs, validate_args=False, name=None):
    """Solves systems of linear eqns `A X = RHS`, given LU factorizations.

    Note: this function does not verify the implied matrix is actually invertible
    nor is this condition checked even when `validate_args=True`.

    Args:
      lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
        `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
      perm: `p` as returned by `tf.linalg.lu`, i.e., if
        `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`. May be int32 or
        int64.
      rhs: Matrix-shaped float `Tensor` representing targets for which to solve;
        `A X = RHS`. To handle vector cases, use:
        `lu_solve(..., rhs[..., tf.newaxis])[..., 0]`.
      validate_args: Python `bool` indicating whether arguments should be checked
        for correctness. Note: this function does not verify the implied matrix is
        actually invertible, even when `validate_args=True`.
        Default value: `False` (i.e., don't validate arguments).
      name: Python `str` name given to ops managed by this object.
        Default value: `None` (i.e., 'lu_solve').

    Returns:
      x: The `X` in `A @ X = RHS`.

    #### Examples

    ```python
    import numpy as np
    from tensorflow_probability.python.internal.backend import numpy as tf
    import tensorflow_probability as tfp; tfp = tfp.experimental.substrates.numpy

    x = [[[1., 2],
          [3, 4]],
         [[7, 8],
          [3, 4]]]
    inv_x = tfp.math.lu_solve(*tf.linalg.lu(x), rhs=tf.eye(2))
    tf.assert_near(tf.linalg.inv(x), inv_x)
    # ==> True
    ```

    """

    with tf.name_scope(name or 'lu_solve'):
        lower_upper = tf.convert_to_tensor(lower_upper,
                                           dtype_hint=tf.float32,
                                           name='lower_upper')
        perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')
        rhs = tf.convert_to_tensor(rhs,
                                   dtype_hint=lower_upper.dtype,
                                   name='rhs')

        assertions = _lu_solve_assertions(lower_upper, perm, rhs,
                                          validate_args)
        if assertions:
            with tf.control_dependencies(assertions):
                lower_upper = tf.identity(lower_upper)
                perm = tf.identity(perm)
                rhs = tf.identity(rhs)

        if (tensorshape_util.rank(rhs.shape) == 2
                and tensorshape_util.rank(perm.shape) == 1):
            # Both rhs and perm have scalar batch_shape.
            permuted_rhs = tf.gather(rhs, perm, axis=-2)
        else:
            # Either rhs or perm have non-scalar batch_shape or we can't determine
            # this information statically.
            rhs_shape = tf.shape(rhs)
            broadcast_batch_shape = tf.broadcast_dynamic_shape(
                rhs_shape[:-2],
                tf.shape(perm)[:-1])
            d, m = rhs_shape[-2], rhs_shape[-1]
            rhs_broadcast_shape = tf.concat([broadcast_batch_shape, [d, m]],
                                            axis=0)

            # Tile out rhs.
            broadcast_rhs = tf.broadcast_to(rhs, rhs_broadcast_shape)
            broadcast_rhs = tf.reshape(broadcast_rhs, [-1, d, m])

            # Tile out perm and add batch indices.
            broadcast_perm = tf.broadcast_to(perm, rhs_broadcast_shape[:-1])
            broadcast_perm = tf.reshape(broadcast_perm, [-1, d])
            broadcast_batch_size = tf.reduce_prod(broadcast_batch_shape)
            # Batch indices must share perm's dtype for `tf.stack`; `perm` may
            # be int64 (e.g. from `tf.linalg.lu(..., output_idx_type=tf.int64)`).
            broadcast_batch_indices = tf.broadcast_to(
                tf.cast(tf.range(broadcast_batch_size),
                        perm.dtype)[:, tf.newaxis],
                [broadcast_batch_size, d])
            broadcast_perm = tf.stack(
                [broadcast_batch_indices, broadcast_perm], axis=-1)

            permuted_rhs = tf.gather_nd(broadcast_rhs, broadcast_perm)
            permuted_rhs = tf.reshape(permuted_rhs, rhs_broadcast_shape)

        # Solve L Y = P RHS (L has a unit diagonal), then U X = Y.
        lower = tf.linalg.set_diag(
            tf.linalg.band_part(lower_upper, num_lower=-1, num_upper=0),
            tf.ones(tf.shape(lower_upper)[:-1], dtype=lower_upper.dtype))
        return linear_operator_util.matrix_triangular_solve_with_broadcast(
            lower_upper,  # Only upper is accessed.
            linear_operator_util.matrix_triangular_solve_with_broadcast(
                lower, permuted_rhs),
            lower=False)
示例#12
0
 def batch_gather(params, indices, axis=-1):
     """Gathers `indices` from `params` along `axis`, batched over `batch_dims`.

     NOTE(review): `batch_dims` is not a parameter; it is presumably a variable
     of an enclosing scope not visible here -- confirm.
     """
     gathered = tf.gather(params, indices, batch_dims=batch_dims, axis=axis)
     return gathered
示例#13
0
    def _sample_n(self, n, seed=None):
        """Draws `n` samples from the mixture.

        There are two code paths:
          * With `self._use_static_graph`, every component is sampled `n`
            times. A one-hot mask built from the categorical draws then picks
            each sample's component.
          * Otherwise, the categorical draws are partitioned by component.
            Each component is sampled only as many times as it was chosen,
            and `tf.dynamic_stitch` reassembles the results.

        Args:
          n: Number of samples to draw (Python int or scalar integer `Tensor`).
          seed: Optional seed. The categorical draw uses `seed` directly; the
            component draws use seeds from `SeedStream(seed, salt="Mixture")`.

        Returns:
          Samples whose shape is the categorical sample shape followed by the
          event shape (`[n, B, E]` per the shape comments below).
        """
        if self._use_static_graph:
            with tf.control_dependencies(self._assertions):
                # This sampling approach is almost the same as the approach used by
                # `MixtureSameFamily`. The differences are due to having a list of
                # `Distribution` objects rather than a single object, and maintaining
                # random seed management that is consistent with the non-static code
                # path.
                samples = []
                cat_samples = self.cat.sample(n, seed=seed)
                stream = SeedStream(seed, salt="Mixture")

                for c in range(self.num_components):
                    samples.append(self.components[c].sample(n, seed=stream()))
                stack_axis = -1 - tensorshape_util.rank(
                    self._static_event_shape)
                x = tf.stack(samples, axis=stack_axis)  # [n, B, k, E]
                npdt = dtype_util.as_numpy_dtype(x.dtype)
                mask = tf.one_hot(
                    indices=cat_samples,  # [n, B]
                    depth=self._num_components,  # == k
                    on_value=npdt(1),
                    off_value=npdt(0))  # [n, B, k]
                mask = distribution_util.pad_mixture_dimensions(
                    mask, self, self._cat,
                    tensorshape_util.rank(
                        self._static_event_shape))  # [n, B, k, [1]*e]
                # Zero out all but the selected component, then sum over k.
                return tf.reduce_sum(x * mask, axis=stack_axis)  # [n, B, E]

        with tf.control_dependencies(self._assertions):
            n = tf.convert_to_tensor(n, name="n")
            # Use a Python int for `n` whenever its value is statically known.
            static_n = tf.get_static_value(n)
            n = int(static_n) if static_n is not None else n
            cat_samples = self.cat.sample(n, seed=seed)

            # Prefer static (Python) shapes/sizes; fall back to dynamic tensors.
            static_samples_shape = cat_samples.shape
            if tensorshape_util.is_fully_defined(static_samples_shape):
                samples_shape = tensorshape_util.as_list(static_samples_shape)
                samples_size = tensorshape_util.num_elements(
                    static_samples_shape)
            else:
                samples_shape = tf.shape(cat_samples)
                samples_size = tf.size(cat_samples)
            static_batch_shape = self.batch_shape
            if tensorshape_util.is_fully_defined(static_batch_shape):
                batch_shape = tensorshape_util.as_list(static_batch_shape)
                batch_size = tensorshape_util.num_elements(static_batch_shape)
            else:
                batch_shape = self.batch_shape_tensor()
                batch_size = tf.reduce_prod(batch_shape)
            static_event_shape = self.event_shape
            if tensorshape_util.is_fully_defined(static_event_shape):
                event_shape = np.array(
                    tensorshape_util.as_list(static_event_shape),
                    dtype=np.int32)
            else:
                event_shape = self.event_shape_tensor()

            # Get indices into the raw cat sampling tensor. We will
            # need these to stitch sample values back out after sampling
            # within the component partitions.
            samples_raw_indices = tf.reshape(tf.range(0, samples_size),
                                             samples_shape)

            # Partition the raw indices so that we can use
            # dynamic_stitch later to reconstruct the samples from the
            # known partitions.
            partitioned_samples_indices = tf.dynamic_partition(
                data=samples_raw_indices,
                partitions=cat_samples,
                num_partitions=self.num_components)

            # Copy the batch indices n times, as we will need to know
            # these to pull out the appropriate rows within the
            # component partitions.
            batch_raw_indices = tf.reshape(
                tf.tile(tf.range(0, batch_size), [n]), samples_shape)

            # Explanation of the dynamic partitioning below:
            #   batch indices are i.e., [0, 1, 0, 1, 0, 1]
            # Suppose partitions are:
            #     [1 1 0 0 1 1]
            # After partitioning, batch indices are cut as:
            #     [batch_indices[x] for x in 2, 3]
            #     [batch_indices[x] for x in 0, 1, 4, 5]
            # i.e.
            #     [1 1] and [0 0 0 0]
            # Now we sample n=2 from part 0 and n=4 from part 1.
            # For part 0 we want samples from batch entries 1, 1 (samples 0, 1),
            # and for part 1 we want samples from batch entries 0, 0, 0, 0
            #   (samples 0, 1, 2, 3).
            partitioned_batch_indices = tf.dynamic_partition(
                data=batch_raw_indices,
                partitions=cat_samples,
                num_partitions=self.num_components)
            samples_class = [None for _ in range(self.num_components)]

            stream = SeedStream(seed, salt="Mixture")

            for c in range(self.num_components):
                # Number of draws assigned to component c.
                n_class = tf.size(partitioned_samples_indices[c])
                samples_class_c = self.components[c].sample(n_class,
                                                            seed=stream())

                # Pull out the correct batch entries from each index.
                # To do this, we may have to flatten the batch shape.

                # For sample s, batch element b of component c, we get the
                # partitioned batch indices from
                # partitioned_batch_indices[c]; and shift each element by
                # the sample index. The final lookup can be thought of as
                # a matrix gather along locations (s, b) in
                # samples_class_c where the n_class rows correspond to
                # samples within this component and the batch_size columns
                # correspond to batch elements within the component.
                #
                # Thus the lookup index is
                #   lookup[c, i] = batch_size * s[i] + b[c, i]
                # for i = 0 ... n_class[c] - 1.
                lookup_partitioned_batch_indices = (
                    batch_size * tf.range(n_class) +
                    partitioned_batch_indices[c])
                samples_class_c = tf.reshape(
                    samples_class_c,
                    tf.concat([[n_class * batch_size], event_shape], 0))
                samples_class_c = tf.gather(samples_class_c,
                                            lookup_partitioned_batch_indices,
                                            name="samples_class_c_gather")
                samples_class[c] = samples_class_c

            # Stitch back together the samples across the components.
            lhs_flat_ret = tf.dynamic_stitch(
                indices=partitioned_samples_indices, data=samples_class)
            # Reshape back to proper sample, batch, and event shape.
            ret = tf.reshape(
                lhs_flat_ret,
                tf.concat(
                    [samples_shape, self.event_shape_tensor()], 0))
            tensorshape_util.set_shape(
                ret,
                tensorshape_util.concatenate(static_samples_shape,
                                             self.event_shape))
            return ret
示例#14
0
 def _sample_n(self, n, seed=None, **distribution_kwargs):
     """Draws `n` category indices and maps each one to its outcome value."""
     outcomes = self.outcomes
     category_ids = self._categorical.sample(
         sample_shape=[n], seed=seed, **distribution_kwargs)
     return tf.gather(outcomes, indices=category_ids)
示例#15
0
 def _forward(self, x):
     """Reorders `x` along `self.axis` by `self.permutation`, keeping its static shape."""
     permuted = tf.gather(params=x, indices=self.permutation, axis=self.axis)
     tensorshape_util.set_shape(permuted, x.shape)
     return permuted
示例#16
0
 def _forward_event_shape_tensor(self, input_shape):
     """Reorders the entries of `input_shape` according to `self.perm`."""
     return tf.gather(input_shape,
                      self._make_perm(tf.size(input_shape), self.perm))
示例#17
0
 def _inverse(self, y):
     """Undoes the permutation by gathering with `invert_permutation(self.permutation)`."""
     inverse_perm = tf.math.invert_permutation(self.permutation)
     restored = tf.gather(y, inverse_perm, axis=self.axis)
     tensorshape_util.set_shape(restored, y.shape)
     return restored
    def _sample_n(self, n, seed=None):
        """Draws `n` samples from the vector diffeomixture (bimixture only).

        Base samples are pushed through both endpoint affines. A quadrature
        weight is picked per sample from `self.grid` using ids drawn from
        `self.mixture_distribution`, and the two transformed samples are
        blended as `weight * x[0] + (1 - weight) * x[1]`.

        Args:
          n: Number of samples to draw.
          seed: Optional seed used to build the `SeedStream` for all sampling.

        Returns:
          The blended samples.

        Raises:
          NotImplementedError: if there are not exactly two endpoint affines.
        """
        endpoint_affine = self.endpoint_affine
        if len(endpoint_affine) != 2:
            # We actually should have already triggered this exception. However as a
            # policy we're putting this exception wherever we exploit the bimixture
            # assumption. It is checked up front so no sampling work is wasted.
            raise NotImplementedError(
                "Currently only bimixtures are supported; "
                "len(scale)={} is not 2.".format(len(endpoint_affine)))

        stream = SeedStream(seed, salt="VectorDiffeomixture")
        x = self.distribution.sample(sample_shape=concat_vectors(
            [n], self.batch_shape_tensor(), self.event_shape_tensor()),
                                     seed=stream())  # shape: [n, B, e]
        x = [aff.forward(x) for aff in endpoint_affine]

        # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
        # ids as a [n]-shaped vector.
        batch_size = tensorshape_util.num_elements(self.batch_shape)
        if batch_size is None:
            batch_size = tf.reduce_prod(self.batch_shape_tensor())
        mix_batch_size = tensorshape_util.num_elements(
            self.mixture_distribution.batch_shape)
        if mix_batch_size is None:
            mix_batch_size = tf.reduce_prod(
                self.mixture_distribution.batch_shape_tensor())
        ids = self.mixture_distribution.sample(sample_shape=concat_vectors(
            [n],
            distribution_util.pick_vector(self.is_scalar_batch(), np.int32([]),
                                          [batch_size // mix_batch_size])),
                                               seed=stream())
        # We need to flatten batch dims in case mixture_distribution has its own
        # batch dims.
        ids = tf.reshape(ids,
                         shape=concat_vectors([n],
                                              distribution_util.pick_vector(
                                                  self.is_scalar_batch(),
                                                  np.int32([]),
                                                  np.int32([-1]))))

        # Stride `components * quadrature_size` for `batch_size` number of times.
        stride = tensorshape_util.num_elements(
            tensorshape_util.with_rank_at_least(self.grid.shape, 2)[-2:])
        if stride is None:
            stride = tf.reduce_prod(tf.shape(self.grid)[-2:])
        offset = tf.range(start=0,
                          limit=batch_size * stride,
                          delta=stride,
                          dtype=ids.dtype)

        weight = tf.gather(tf.reshape(self.grid, shape=[-1]), ids + offset)
        # At this point, weight flattened all batch dims into one.
        # We also need to append a singleton to broadcast with event dims.
        if tensorshape_util.is_fully_defined(self.batch_shape):
            new_shape = [-1] + tensorshape_util.as_list(self.batch_shape) + [1]
        else:
            new_shape = tf.concat(([-1], self.batch_shape_tensor(), [1]),
                                  axis=0)
        weight = tf.reshape(weight, shape=new_shape)

        # Alternatively:
        # x = weight * x[0] + (1. - weight) * x[1]
        x = weight * (x[0] - x[1]) + x[1]

        return x
示例#19
0
 def _inverse_event_shape_tensor(self, output_shape):
     """Reorders `output_shape` by the inverse of `self.perm` (its argsort)."""
     ndims = tf.size(output_shape)
     inverse_perm = self._make_perm(ndims, tf.argsort(self.perm))
     return tf.gather(output_shape, inverse_perm)
示例#20
0
 def _mode(self):
     """Returns the outcome whose category index is the categorical mode."""
     outcomes = self.outcomes
     mode_index = self._categorical.mode()
     return tf.gather(outcomes, indices=mode_index)