Example #1
0
 def _flat_sample_distributions(self, sample_shape=(), seed=None, value=None):
   """Creates every distribution of the model and a value for each of them.

   Distribution `i` is built by calling `self._dist_fn_wrapped[i]` on the
   values of the `i` distributions that precede it (chain rule of
   probability).

   This function additionally depends on:
     self._dist_fn_wrapped
     self._dist_fn_args
     self._always_use_specified_sample_shape

   Args:
     sample_shape: Shape passed to `sample` for a distribution whose entry in
       `self._dist_fn_args` is empty, or for every distribution when
       `self._always_use_specified_sample_shape` is truthy; otherwise `()` is
       used.
     seed: Seed handed to `SeedStream`.
     value: Optional sequence with one entry per distribution. Entries that
       are not `None` are converted to tensors and used instead of sampling.

   Returns:
     ds: List of the created distributions.
     xs: List of sampled or supplied values, aligned with `ds`.

   Raises:
     ValueError: If `value` does not have one entry per distribution.
   """
   # `SeedStream` takes the seed first and the salt second. Passing them the
   # other way round turns the constant string into the seed and the caller's
   # seed into the salt.
   seed = SeedStream(seed, salt='JointDistributionSequential')
   ds = []
   xs = [None]*len(self._dist_fn_wrapped) if value is None else list(value)
   if len(xs) != len(self._dist_fn_wrapped):
     raise ValueError('Number of `xs`s must match number of '
                      'distributions.')
   for i, (dist_fn, args) in enumerate(zip(self._dist_fn_wrapped,
                                           self._dist_fn_args)):
     ds.append(dist_fn(*xs[:i]))  # Chain rule of probability.
     if xs[i] is None:
       # TODO(b/129364796): We should ignore args prefixed with `_`; this
       # would mean we more often identify when to use `sample_shape=()`
       # rather than `sample_shape=sample_shape`.
       xs[i] = ds[-1].sample(
           () if args and not self._always_use_specified_sample_shape
           else sample_shape, seed=seed())
     else:
       xs[i] = tf.convert_to_tensor(xs[i], dtype_hint=ds[-1].dtype)
       seed()  # Ensure reproducibility even when xs are (partially) set.
   # Note: we could also resolve distributions up to the first non-`None` in
   # `self._model_flatten(value)`, however we omit this feature for simplicity,
   # speed, and because it has not yet been requested.
   return ds, xs
 def _sample_n(self, n, seed=None):
     """Draws `n` samples by repeatedly re-sampling through `distribution_fn`.

     A single seed is derived from `seed` and reused for the initial draw from
     `distribution0` and for each of the `self._num_steps` follow-up draws.
     """
     derived_seed = SeedStream(seed, salt="Autoregressive")()
     draws = self.distribution0.sample(n, seed=derived_seed)
     for _step in range(self._num_steps):
         # pylint: disable=not-callable
         conditioned = self.distribution_fn(draws)
         draws = conditioned.sample(seed=derived_seed)
     return draws
Example #3
0
    def _sample_n(self, n, seed=None):
        """Draws `n` samples by inverting the piecewise-quadratic CDF."""
        low = tf.convert_to_tensor(self.low)
        high = tf.convert_to_tensor(self.high)
        peak = tf.convert_to_tensor(self.peak)

        seeds = SeedStream(seed, salt='triangular')
        batch_shape = self._batch_shape_tensor(low=low, high=high, peak=peak)
        draw_shape = tf.concat([[n], batch_shape], axis=0)
        u = tf.random.uniform(shape=draw_shape, dtype=self.dtype, seed=seeds())

        # Inverse CDF sampling: the CDF is quadratic on each side of the peak,
        # hence the square roots below.
        span = high - low
        # Left of the peak the CDF is
        # (x - low) ** 2 / ((high - low) * (peak - low)), which evaluates to
        # (peak - low) / (high - low) at x = peak. A uniform draw below that
        # threshold therefore belongs to the left branch.
        left_of_peak = u < (peak - low) / span
        # Inverse of (x - low) ** 2 / ((high - low) * (peak - low)).
        left_branch = low + tf.sqrt(u * span * (peak - low))
        # Inverse of 1 - (high - x) ** 2 / ((high - low) * (high - peak)).
        right_branch = high - tf.sqrt((1. - u) * span * (high - peak))
        return tf.where(left_of_peak, left_branch, right_branch)
Example #4
0
 def _sample_n(self, n, seed=None):
     """Draws `n` samples: a standard normal scaled by a sampled shrinkage.

     The shrinkage is `scale` times a draw from `self._half_cauchy`.
     """
     scale = tf.convert_to_tensor(self.scale)
     draw_shape = tf.concat([[n], tf.shape(scale)], axis=0)
     stream = SeedStream(seed, salt='random_horseshoe')
     # The shrinkage draw takes the first seed of the stream, the normal draw
     # the second.
     shrinkage = scale * self._half_cauchy.sample(draw_shape, seed=stream())
     standard_normal = tf.random.normal(
         shape=draw_shape,
         mean=0.,
         stddev=1.,
         dtype=scale.dtype,
         seed=stream())
     return standard_normal * shrinkage
Example #5
0
 def _sample_n(self, n, seed=None):
     """Draws `n` samples through a Gamma-Poisson mixture."""
     # If lam ~ Gamma(concentration=total_count, rate=(1 - probs) / probs),
     # then X ~ Poisson(lam) follows the Negative Binomial distribution.
     logits = self._logits_parameter_no_checks()
     seeds = SeedStream(seed, salt='NegativeBinomial')
     poisson_rate = tf.random.gamma(
         shape=[n],
         alpha=self.total_count,
         beta=tf.exp(-logits),
         dtype=self.dtype,
         seed=seeds())
     draws = tf.random.poisson(
         lam=poisson_rate, shape=[], dtype=self.dtype, seed=seeds())
     return draws
Example #6
0
 def _sample_n(self, n, seed=None):
     """Draws `n` samples: a gamma-distributed rate, then a gamma with that rate."""
     concentration = tf.convert_to_tensor(self.concentration)
     mixing_concentration = tf.convert_to_tensor(self.mixing_concentration)
     mixing_rate = tf.convert_to_tensor(self.mixing_rate)
     stream = SeedStream(seed, 'gamma_gamma')
     # Adding zeros shaped like `concentration` makes sure enough rates are
     # drawn for the fully-broadcasted gamma-gamma.
     rate_alpha = mixing_concentration + tf.zeros_like(concentration)
     sampled_rate = tf.random.gamma(
         shape=[n],
         alpha=rate_alpha,
         beta=mixing_rate,
         dtype=self.dtype,
         seed=stream())
     draws = tf.random.gamma(
         shape=[],
         alpha=concentration,
         beta=sampled_rate,
         dtype=self.dtype,
         seed=stream())
     return draws
Example #7
0
 def _sample_n(self, n, seed=None):
     """Draws `n` samples as `g1 / (g1 + g2)` for two gamma draws `g1`, `g2`."""
     stream = SeedStream(seed, "beta")
     concentration1 = tf.convert_to_tensor(self.concentration1)
     concentration0 = tf.convert_to_tensor(self.concentration0)
     batch_shape = self._batch_shape_tensor(concentration1, concentration0)
     # Broadcast both concentrations to the batch shape before sampling.
     alpha1 = tf.broadcast_to(concentration1, batch_shape)
     alpha0 = tf.broadcast_to(concentration0, batch_shape)
     g1 = tf.random.gamma(
         shape=[n], alpha=alpha1, dtype=self.dtype, seed=stream())
     g2 = tf.random.gamma(
         shape=[n], alpha=alpha0, dtype=self.dtype, seed=stream())
     return g1 / (g1 + g2)
    def _sample_n(self, n, seed=None):
        """Draws `n` samples as a Gaussian draw divided by a scaled chi-squared.

        Like with the univariate Student's t, sampling can be implemented as a
        ratio of samples from a multivariate gaussian with the appropriate
        covariance matrix and a sample from the chi-squared distribution.
        """
        stream = SeedStream(seed, salt='multivariate t')

        # Zero-mean multivariate normal draw sharing this distribution's scale.
        broadcast_loc = tf.broadcast_to(self.loc, self._sample_shape())
        zero_mean_mvn = mvn_linear_operator.MultivariateNormalLinearOperator(
            loc=tf.zeros_like(broadcast_loc), scale=self.scale)
        gaussian_draw = zero_mean_mvn.sample(n, seed=stream())

        # Chi-squared draw with `df` broadcast to the batch shape.
        broadcast_df = tf.broadcast_to(self.df, self.batch_shape_tensor())
        chi2_draw = chi2_lib.Chi2(df=broadcast_df).sample(n, seed=stream())

        # The trailing new axis lets the scaling broadcast over the event
        # dimension of the Gaussian draw.
        scaling = tf.math.rsqrt(chi2_draw / self._df)[..., tf.newaxis]
        return self._loc + gaussian_draw * scaling
    def _sample_n(self, n, seed=None):
        """Draws `n` samples: pick a quadrature component, then sample Poisson.

        Args:
          n: Number of samples to draw.
          seed: Seed for a `SeedStream`; the mixture draw and the Poisson draw
            each take their own value from that stream.

        Returns:
          Poisson samples drawn with a `lam` tensor of shape
          `[n] + batch_shape`.
        """
        # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
        # ids as a [n]-shaped vector.
        distributions = self.poisson_and_mixture_distributions()
        dist, mixture_dist = distributions
        batch_size = tensorshape_util.num_elements(self.batch_shape)
        if batch_size is None:
            # Static batch size unknown: compute it from the dynamic shape.
            batch_size = tf.reduce_prod(
                self._batch_shape_tensor(distributions=distributions))
        # We need to 'sample extra' from the mixture distribution if it doesn't
        # already specify a probs vector for each batch coordinate.
        # We only support this kind of reduced broadcasting, i.e., there is exactly
        # one probs vector for all batch dims or one for each.
        stream = SeedStream(seed, salt='PoissonLogNormalQuadratureCompound')
        ids = mixture_dist.sample(sample_shape=concat_vectors(
            [n],
            distribution_util.pick_vector(mixture_dist.is_scalar_batch(),
                                          [batch_size], np.int32([]))),
                                  seed=stream())
        # We need to flatten batch dims in case mixture_dist has its own
        # batch dims.
        ids = tf.reshape(ids,
                         shape=concat_vectors([n],
                                              distribution_util.pick_vector(
                                                  self.is_scalar_batch(),
                                                  np.int32([]),
                                                  np.int32([-1]))))

        # Stride `quadrature_size` for `batch_size` number of times.
        offset = tf.range(start=0,
                          limit=batch_size * self._quadrature_size,
                          delta=self._quadrature_size,
                          dtype=ids.dtype)
        ids = ids + offset
        rate = tf.gather(tf.reshape(dist.rate, shape=[-1]), ids)
        rate = tf.reshape(
            rate,
            shape=concat_vectors(
                [n], self._batch_shape_tensor(distributions=distributions)))
        # Take the Poisson seed from the stream as well, rather than reusing the
        # caller's raw seed, so that both draws are salted and distinct.
        return tf.random.poisson(lam=rate,
                                 shape=[],
                                 dtype=self.dtype,
                                 seed=stream())
 def _sample_n(self, n, seed=None):
   """Draws `n` samples from a squared-normal draw and a uniform draw.

   See https://en.wikipedia.org/wiki/Inverse_Gaussian_distribution or
   https://www.jstor.org/stable/2683801
   """
   concentration = tf.convert_to_tensor(self.concentration)
   loc = tf.convert_to_tensor(self.loc)
   stream = SeedStream(seed, 'inverse_gaussian')
   batch_shape = self._batch_shape_tensor(
       loc=loc, concentration=concentration)
   draw_shape = tf.concat([[n], batch_shape], axis=0)
   # Squared standard normal draw.
   normal_draw = tf.random.normal(
       draw_shape, mean=0., stddev=1., seed=stream(), dtype=self.dtype)
   chi2_draw = normal_draw**2.
   uniform_draw = tf.random.uniform(
       draw_shape, minval=0., maxval=1., seed=stream(), dtype=self.dtype)
   linear_term = loc ** 2. * chi2_draw / (2. * concentration)
   root_term = (
       loc / (2. * concentration) *
       (4. * loc * concentration * chi2_draw +
        (loc * chi2_draw) ** 2) ** 0.5)
   candidate = loc + linear_term - root_term
   # Keep the candidate where the uniform draw is at most
   # loc / (loc + candidate); elsewhere return loc**2 / candidate.
   keep_candidate = uniform_draw <= loc / (loc + candidate)
   return tf.where(keep_candidate, candidate, loc**2 / candidate)
Example #11
0
    def _flat_sample_distributions(self,
                                   sample_shape=(),
                                   seed=None,
                                   value=None):
        """Executes `model`, creating both samples and distributions.

        The generator returned by `self._model()` is driven to exhaustion: each
        yielded distribution is recorded, a value is sampled for it (or taken
        from `value`), and that value is sent back into the generator.

        Args:
          sample_shape: Shape used when sampling a distribution wrapped in
            `self.Root`; all other distributions are sampled with shape `()`.
          seed: Seed handed to `SeedStream`.
          value: Optional sequence of values. Entry `i`, when present and not
            `None`, is used for the `i`-th distribution instead of sampling.

        Returns:
          ds: List of the distributions yielded by the model (unwrapped from
            `Root`).
          values_out: List of the values used for each distribution.

        Raises:
          ValueError: If the first yielded object is not a `self.Root`.
        """
        ds = []
        values_out = []
        # `SeedStream` takes the seed first and the salt second. Passing them
        # the other way round turns the constant string into the seed and the
        # caller's seed into the salt.
        seed = SeedStream(seed, salt='JointDistributionCoroutine')
        gen = self._model()
        index = 0
        d = next(gen)
        if not isinstance(d, self.Root):
            raise ValueError('First distribution yielded by coroutine must '
                             'be wrapped in `Root`.')
        try:
            while True:
                is_root = isinstance(d, self.Root)
                actual_distribution = d.distribution if is_root else d
                ds.append(actual_distribution)
                if (value is not None and len(value) > index
                        and value[index] is not None):
                    # Advance the stream anyway so that later draws do not
                    # depend on which values were supplied.
                    seed()
                    next_value = value[index]
                else:
                    next_value = actual_distribution.sample(
                        sample_shape=sample_shape if is_root else (),
                        seed=seed())

                if self._validate_args:
                    with tf.control_dependencies(
                            self._assert_compatible_shape(
                                index, sample_shape, next_value)):
                        values_out.append(tf.identity(next_value))
                else:
                    values_out.append(next_value)

                index += 1
                d = gen.send(next_value)
        except StopIteration:
            # The generator is exhausted: every distribution has been visited.
            pass
        return ds, values_out
Example #12
0
 def _sample_n(self, n, seed):
     """Draws `n` samples by masking component draws with a one-hot selection.

     Every component is sampled; a one-hot mask built from draws of
     `self.mixture_distribution` then zeroes out all but the chosen component
     before summing over the component axis.
     """
     with tf.control_dependencies(self._runtime_assertions):
         stream = SeedStream(seed, salt="MixtureSameFamily")
         component_draws = self.components_distribution.sample(
             n, seed=stream())  # [n, B, k, E]
         # TODO(jvdillon): Consider using tf.gather (by way of index unrolling).
         np_dtype = dtype_util.as_numpy_dtype(component_draws.dtype)
         chosen = self.mixture_distribution.sample(n, seed=stream())  # [n, B]
         selector = tf.one_hot(
             indices=chosen,
             depth=self._num_components,  # == k
             on_value=np_dtype(1),
             off_value=np_dtype(0))  # [n, B, k]
         selector = distribution_utils.pad_mixture_dimensions(
             selector, self, self.mixture_distribution,
             self._event_ndims)  # [n, B, k, [1]*e]
         component_axis = -1 - self._event_ndims
         mixed = tf.reduce_sum(
             component_draws * selector, axis=component_axis)  # [n, B, E]
         if not self._reparameterize:
             return mixed
         return self._reparameterize_sample(mixed)
 def _sample_3d(self, n, seed=None):
   """Specialized inversion sampler for 3D."""
   stream = SeedStream(seed, salt='von_mises_fisher_3d')
   draw_shape = tf.concat([[n], self._batch_shape_tensor()], axis=0)
   z = tf.random.uniform(draw_shape, seed=stream(), dtype=self.dtype)
   # TODO(bjp): Higher-order odd dim analytic CDFs are available in [1], could
   # be bisected for bounded sampling runtime (i.e. not rejection sampling).
   # [1]: Inversion sampler via: https://ieeexplore.ieee.org/document/7347705/
   # The inversion is: u = 1 + log(z + (1-z)*exp(-2*kappa)) / kappa
   # Both kappa and z may be zero, so substitute ones before taking logs or
   # dividing, and patch the affected entries afterwards.
   safe_conc = tf.where(self.concentration > 0, self.concentration,
                        tf.ones_like(self.concentration))
   safe_z = tf.where(z > 0, z, tf.ones_like(z))
   log_terms = [
       tf.math.log(safe_z),
       tf.math.log1p(-safe_z) - 2 * safe_conc,
   ]
   safe_u = 1 + tf.reduce_logsumexp(log_terms, axis=0) / safe_conc
   # Limit of the above expression as kappa->0 is 2*z-1
   has_positive_conc = self.concentration > tf.zeros_like(safe_u)
   u = tf.where(has_positive_conc, safe_u, 2 * z - 1)
   # Limit of the expression as z->0 is -1.
   z_is_zero = tf.equal(z, 0)
   u = tf.where(z_is_zero, -tf.ones_like(u), u)
   if not self._allow_nan_stats:
     u = tf.debugging.check_numerics(u, 'u in _sample_3d')
   return u[..., tf.newaxis]
Example #14
0
    def _sample_n(self, n, seed=None):
        """Draws `n` samples: log-gamma logits fed to a multinomial sampler."""
        stream = SeedStream(seed, 'dirichlet_multinomial')

        concentration = tf.convert_to_tensor(self._concentration)
        total_count = tf.convert_to_tensor(self._total_count)

        n_draws = tf.cast(total_count, dtype=tf.int32)
        k = self._event_shape_tensor(concentration)[0]
        # Multiplying by ones shaped like `total_count` (plus a trailing axis)
        # broadcasts `concentration` against `total_count`.
        count_ones = tf.ones_like(total_count[..., tf.newaxis])
        alpha = tf.math.multiply(count_ones, concentration, name='alpha')

        gamma_draws = tf.random.gamma(
            shape=[n], alpha=alpha, dtype=self.dtype, seed=stream())
        unnormalized_logits = tf.math.log(gamma_draws)
        x = multinomial.draw_sample(
            1, k, unnormalized_logits, n_draws, self.dtype, stream())
        batch_shape = self._batch_shape_tensor(concentration, total_count)
        final_shape = tf.concat([[n], batch_shape, [k]], 0)
        return tf.reshape(x, final_shape)
Example #15
0
    def _sample_n(self, n, seed=None):
        """Draws `n` samples as a normal draw over a scaled chi-squared root.

        The sampling method comes from the fact that if:
          X ~ Normal(0, 1)
          Z ~ Chi2(df)
          Y = X / sqrt(Z / df)
        then:
          Y ~ StudentT(df).
        """
        df = tf.convert_to_tensor(self.df)
        loc = tf.convert_to_tensor(self.loc)
        scale = tf.convert_to_tensor(self.scale)
        batch_shape = self._batch_shape_tensor(df=df, loc=loc, scale=scale)
        draw_shape = tf.concat([[n], batch_shape], 0)
        stream = SeedStream(seed, 'student_t')

        x = tf.random.normal(draw_shape, dtype=self.dtype, seed=stream())
        # Expand `df` to the full batch shape before using it as a gamma
        # parameter.
        full_df = df * tf.ones(batch_shape, dtype=self.dtype)
        z = tf.random.gamma(
            [n], 0.5 * full_df, beta=0.5, dtype=self.dtype, seed=stream())
        standard_t = x * tf.math.rsqrt(z / full_df)
        return standard_t * scale + loc  # Abs(scale) not wanted.
Example #16
0
    def _sample_n(self, num_samples, seed=None, name=None):
        """Returns a Tensor of samples from an LKJ distribution.

        Args:
          num_samples: Python `int`. The number of samples to draw.
          seed: Python integer seed for RNG
          name: Python `str` name prefixed to Ops created by this function.
            Defaults to `'sample_lkj'` when not given.

        Returns:
          samples: A Tensor of correlation matrices with shape `[n, B, D, D]`,
            where `B` is the shape of the `concentration` parameter, and `D`
            is the `dimension`. When `self.input_output_cholesky` is set and
            `dimension > 1`, the Cholesky factors are returned instead.

        Raises:
          ValueError: If `dimension` is negative.
          TypeError: If `concentration` does not have a floating dtype.
        """
        if self.dimension < 0:
            raise ValueError(
                'Cannot sample negative-dimension correlation matrices.')
        # Notation below: B is the batch shape, i.e., tf.shape(concentration)
        seed = SeedStream(seed, 'sample_lkj')
        # `name or default`, not `default or name`: the latter always picks the
        # default and silently ignores the caller's `name`.
        with tf.name_scope(name or 'sample_lkj'):
            concentration = tf.convert_to_tensor(self.concentration)
            if not dtype_util.is_floating(concentration.dtype):
                raise TypeError(
                    'The concentration argument should have floating type, not '
                    '{}'.format(dtype_util.name(concentration.dtype)))

            concentration = _replicate(num_samples, concentration)
            concentration_shape = tf.shape(concentration)
            if self.dimension <= 1:
                # For any dimension <= 1, there is only one possible correlation matrix.
                shape = tf.concat(
                    [concentration_shape, [self.dimension, self.dimension]],
                    axis=0)
                return tf.ones(shape=shape, dtype=concentration.dtype)
            beta_conc = concentration + (self.dimension - 2.) / 2.
            beta_dist = beta.Beta(concentration1=beta_conc,
                                  concentration0=beta_conc)

            # Note that the sampler below deviates from [1], by doing the sampling in
            # cholesky space. This does not change the fundamental logic of the
            # sampler, but does speed up the sampling.

            # This is the correlation coefficient between the first two dimensions.
            # This is also `r` in reference [1].
            corr12 = 2. * beta_dist.sample(seed=seed()) - 1.

            # Below we construct the Cholesky of the initial 2x2 correlation matrix,
            # which is of the form:
            # [[1, 0], [r, sqrt(1 - r**2)]], where r is the correlation between the
            # first two dimensions.
            # This is the top-left corner of the cholesky of the final sample.
            first_row = tf.concat([
                tf.ones_like(corr12)[..., tf.newaxis],
                tf.zeros_like(corr12)[..., tf.newaxis]
            ],
                                  axis=-1)
            second_row = tf.concat([
                corr12[..., tf.newaxis],
                tf.sqrt(1 - corr12**2)[..., tf.newaxis]
            ],
                                   axis=-1)

            chol_result = tf.concat([
                first_row[..., tf.newaxis, :], second_row[..., tf.newaxis, :]
            ],
                                    axis=-2)

            for n in range(2, self.dimension):
                # Loop invariant: on entry, result has shape B + [n, n]
                beta_conc = beta_conc - 0.5
                # norm is y in reference [1].
                norm = beta.Beta(concentration1=n / 2.,
                                 concentration0=beta_conc).sample(seed=seed())
                # distance shape: B + [1] for broadcast
                distance = tf.sqrt(norm)[..., tf.newaxis]
                # direction is u in reference [1].
                # direction shape: B + [n]
                direction = _uniform_unit_norm(n, concentration_shape,
                                               concentration.dtype, seed)
                # raw_correlation is w in reference [1].
                raw_correlation = distance * direction  # shape: B + [n]

                # This is the next row in the cholesky of the result,
                # which differs from the construction in reference [1].
                # In the reference, the new row `z` = chol_result @ raw_correlation^T
                # = C @ raw_correlation^T (where as short hand we use C = chol_result).
                # We prove that the below equation is the right row to add to the
                # cholesky, by showing equality with reference [1].
                # Let S be the sample constructed so far, and let `z` be as in
                # reference [1]. Then at this iteration, the new sample S' will be
                # [[S z^T]
                #  [z 1]]
                # In our case we have the cholesky decomposition factor C, so
                # we want our new row x (same size as z) to satisfy:
                #  [[S z^T]  [[C 0]    [[C^T  x^T]         [[CC^T  Cx^T]
                #   [z 1]] =  [x k]]    [0     k]]  =       [xC^t   xx^T + k**2]]
                # Since C @ raw_correlation^T = z = C @ x^T, and C is invertible,
                # we have that x = raw_correlation. Also 1 = xx^T + k**2, so k
                # = sqrt(1 - xx^T) = sqrt(1 - |raw_correlation|**2) = sqrt(1 -
                # distance**2).
                new_row = tf.concat(
                    [raw_correlation,
                     tf.sqrt(1. - norm[..., tf.newaxis])],
                    axis=-1)

                # Finally add this new row, by growing the cholesky of the result.
                chol_result = tf.concat([
                    chol_result,
                    tf.zeros_like(chol_result[..., 0][..., tf.newaxis])
                ],
                                        axis=-1)

                chol_result = tf.concat(
                    [chol_result, new_row[..., tf.newaxis, :]], axis=-2)

            if self.input_output_cholesky:
                return chol_result

            result = tf.matmul(chol_result, chol_result, transpose_b=True)
            # The diagonal for a correlation matrix should always be ones. Due to
            # numerical instability the matmul might not achieve that, so manually set
            # these to ones.
            result = tf.linalg.set_diag(
                result, tf.ones(shape=tf.shape(result)[:-1],
                                dtype=result.dtype))
            # This sampling algorithm can produce near-PSD matrices on which standard
            # algorithms such as `tf.cholesky` or `tf.linalg.self_adjoint_eigvals`
            # fail. Specifically, as documented in b/116828694, around 2% of trials
            # of 900,000 5x5 matrices (distributed according to 9 different
            # concentration parameter values) contained at least one matrix on which
            # the Cholesky decomposition failed.
            return result
Example #17
0
  def _sample_n(self, n, seed=None):
    """Draws `n` samples from the mixture.

    Two code paths exist:

    * When `self._use_static_graph` is truthy, every component is sampled `n`
      times and the draws are combined through a one-hot mask built from the
      categorical samples.
    * Otherwise the categorical samples are used to partition sample indices
      by component (`tf.dynamic_partition`), each component is sampled only as
      many times as it was selected, and the results are reassembled with
      `tf.dynamic_stitch`.

    In both paths `self.cat` is sampled with the raw `seed`, while the
    components take successive seeds from a `SeedStream` salted with
    "Mixture".

    Args:
      n: Number of samples to draw.
      seed: Seed passed to `self.cat.sample` and to the `SeedStream`.

    Returns:
      The mixture samples; the inline shape comments give `[n, B, E]`.
    """
    if self._use_static_graph:
      with tf.control_dependencies(self._assertions):
        # This sampling approach is almost the same as the approach used by
        # `MixtureSameFamily`. The differences are due to having a list of
        # `Distribution` objects rather than a single object, and maintaining
        # random seed management that is consistent with the non-static code
        # path.
        samples = []
        cat_samples = self.cat.sample(n, seed=seed)
        stream = SeedStream(seed, salt="Mixture")

        # Sample every component `n` times; the mask below selects among them.
        for c in range(self.num_components):
          samples.append(self.components[c].sample(n, seed=stream()))
        # Stack the components on a new axis placed just before the event dims.
        stack_axis = -1 - tensorshape_util.rank(self._static_event_shape)
        x = tf.stack(samples, axis=stack_axis)  # [n, B, k, E]
        npdt = dtype_util.as_numpy_dtype(x.dtype)
        mask = tf.one_hot(
            indices=cat_samples,  # [n, B]
            depth=self._num_components,  # == k
            on_value=npdt(1),
            off_value=npdt(0))  # [n, B, k]
        mask = distribution_util.pad_mixture_dimensions(
            mask, self, self._cat,
            tensorshape_util.rank(self._static_event_shape))  # [n, B, k, [1]*e]
        return tf.reduce_sum(x * mask, axis=stack_axis)  # [n, B, E]

    with tf.control_dependencies(self._assertions):
      n = tf.convert_to_tensor(n, name="n")
      # Use a Python int for `n` whenever its value is statically known.
      static_n = tf.get_static_value(n)
      n = int(static_n) if static_n is not None else n
      cat_samples = self.cat.sample(n, seed=seed)

      # For the sample, batch and event shapes, prefer static values and fall
      # back to dynamic tensors when the static shape is not fully defined.
      static_samples_shape = cat_samples.shape
      if tensorshape_util.is_fully_defined(static_samples_shape):
        samples_shape = tensorshape_util.as_list(static_samples_shape)
        samples_size = tensorshape_util.num_elements(static_samples_shape)
      else:
        samples_shape = tf.shape(cat_samples)
        samples_size = tf.size(cat_samples)
      static_batch_shape = self.batch_shape
      if tensorshape_util.is_fully_defined(static_batch_shape):
        batch_shape = tensorshape_util.as_list(static_batch_shape)
        batch_size = tensorshape_util.num_elements(static_batch_shape)
      else:
        batch_shape = self.batch_shape_tensor()
        batch_size = tf.reduce_prod(batch_shape)
      static_event_shape = self.event_shape
      if tensorshape_util.is_fully_defined(static_event_shape):
        event_shape = np.array(
            tensorshape_util.as_list(static_event_shape), dtype=np.int32)
      else:
        event_shape = self.event_shape_tensor()

      # Get indices into the raw cat sampling tensor. We will
      # need these to stitch sample values back out after sampling
      # within the component partitions.
      samples_raw_indices = tf.reshape(tf.range(0, samples_size), samples_shape)

      # Partition the raw indices so that we can use
      # dynamic_stitch later to reconstruct the samples from the
      # known partitions.
      partitioned_samples_indices = tf.dynamic_partition(
          data=samples_raw_indices,
          partitions=cat_samples,
          num_partitions=self.num_components)

      # Copy the batch indices n times, as we will need to know
      # these to pull out the appropriate rows within the
      # component partitions.
      batch_raw_indices = tf.reshape(
          tf.tile(tf.range(0, batch_size), [n]), samples_shape)

      # Explanation of the dynamic partitioning below:
      #   batch indices are i.e., [0, 1, 0, 1, 0, 1]
      # Suppose partitions are:
      #     [1 1 0 0 1 1]
      # After partitioning, batch indices are cut as:
      #     [batch_indices[x] for x in 2, 3]
      #     [batch_indices[x] for x in 0, 1, 4, 5]
      # i.e.
      #     [1 1] and [0 0 0 0]
      # Now we sample n=2 from part 0 and n=4 from part 1.
      # For part 0 we want samples from batch entries 1, 1 (samples 0, 1),
      # and for part 1 we want samples from batch entries 0, 0, 0, 0
      #   (samples 0, 1, 2, 3).
      partitioned_batch_indices = tf.dynamic_partition(
          data=batch_raw_indices,
          partitions=cat_samples,
          num_partitions=self.num_components)
      samples_class = [None for _ in range(self.num_components)]

      stream = SeedStream(seed, salt="Mixture")

      for c in range(self.num_components):
        # Number of draws assigned to component `c` by the categorical samples.
        n_class = tf.size(partitioned_samples_indices[c])
        samples_class_c = self.components[c].sample(
            n_class, seed=stream())

        # Pull out the correct batch entries from each index.
        # To do this, we may have to flatten the batch shape.

        # For sample s, batch element b of component c, we get the
        # partitioned batch indices from
        # partitioned_batch_indices[c]; and shift each element by
        # the sample index. The final lookup can be thought of as
        # a matrix gather along locations (s, b) in
        # samples_class_c where the n_class rows correspond to
        # samples within this component and the batch_size columns
        # correspond to batch elements within the component.
        #
        # Thus the lookup index is
        #   lookup[c, i] = batch_size * s[i] + b[c, i]
        # for i = 0 ... n_class[c] - 1.
        lookup_partitioned_batch_indices = (
            batch_size * tf.range(n_class) + partitioned_batch_indices[c])
        samples_class_c = tf.reshape(
            samples_class_c, tf.concat([[n_class * batch_size], event_shape],
                                       0))
        samples_class_c = tf.gather(
            samples_class_c,
            lookup_partitioned_batch_indices,
            name="samples_class_c_gather")
        samples_class[c] = samples_class_c

      # Stitch back together the samples across the components.
      lhs_flat_ret = tf.dynamic_stitch(
          indices=partitioned_samples_indices, data=samples_class)
      # Reshape back to proper sample, batch, and event shape.
      ret = tf.reshape(
          lhs_flat_ret, tf.concat(
              [samples_shape, self.event_shape_tensor()], 0))
      # Attach the static sample and event shape to the result.
      tensorshape_util.set_shape(
          ret,
          tensorshape_util.concatenate(static_samples_shape, self.event_shape))
      return ret
Example #18
0
  def _sample_n(self, n, seed=None):
    """Draws `n` samples with a rejection-inversion `tf.while_loop`.

    Entries still rejected when the loop stops (see
    `self.sample_maximum_iterations`) are overwritten, when
    `self.validate_args` is set, with the dtype's minimum for integer dtypes
    or NaN otherwise.
    """
    power = tf.convert_to_tensor(self.power)
    draw_shape = tf.concat([[n], tf.shape(power)], axis=0)

    # Remember whether the caller passed a seed: it sets the loop's
    # `parallel_iterations` below.
    seeded = seed is not None
    stream = SeedStream(seed, salt='zipf')

    u_lower = self._hat_integral(0.5, power=power) + 1.
    u_upper = self._hat_integral(tf.int64.max - 0.5, power=power)

    def resample_rejected(pending, k):
      """Resample the non-accepted points."""
      # The range of U is chosen so that the resulting sample K lies in
      # [0, tf.int64.max). The final sample, if accepted, is K + 1.
      u = tf.random.uniform(
          draw_shape,
          minval=u_lower,
          maxval=u_upper,
          dtype=power.dtype,
          seed=stream())

      # Sample the point X from the continuous density h(x) \propto x^(-power).
      x = self._hat_integral_inverse(u, power=power)

      # Rejection-inversion requires a `hat` function, h(x) such that
      # \int_{k - .5}^{k + .5} h(x) dx >= pmf(k + 1) for points k in the
      # support. A natural hat function for us is h(x) = x^(-power).
      #
      # After sampling X from h(x), suppose it lies in the interval
      # (K - .5, K + .5) for integer K. Then the corresponding K is accepted
      # if it lies to the left of x_K, where x_K is defined by:
      #   \int_{x_k}^{K + .5} h(x) dx = H(x_K) - H(K + .5) = pmf(K + 1),
      # where H(x) = \int_x^inf h(x) dx.
      #
      # Solving for x_K gives x_K = H_inverse(H(K + .5) + pmf(K + 1)), i.e.
      # the acceptance condition is X <= H_inverse(H(K + .5) + pmf(K + 1)).
      # Since X = H_inverse(U), this simplifies to U <= H(K + .5) + pmf(K + 1).

      # Only the still-pending points are updated.
      # Since X \in (K - .5, K + .5), the sample K is chosen as floor(X + 0.5).
      k = tf.where(pending, tf.floor(x + 0.5), k)
      threshold = self._hat_integral(k + .5, power=power) + tf.exp(
          self._log_prob(k + 1, power=power))
      accepted = u <= threshold

      return [pending & (~accepted), k]

    def any_pending(pending, *ignore):
      """Keep looping while at least one point is not yet accepted."""
      return tf.reduce_any(pending)

    initial_state = [
        tf.ones(draw_shape, dtype=tf.bool),  # pending
        tf.zeros(draw_shape, dtype=power.dtype),  # k
    ]
    still_pending, samples = tf.while_loop(
        cond=any_pending,
        body=resample_rejected,
        loop_vars=initial_state,
        parallel_iterations=1 if seeded else 10,
        maximum_iterations=self.sample_maximum_iterations,
    )
    samples = samples + 1.

    if self.validate_args and dtype_util.is_integer(self.dtype):
      samples = distribution_util.embed_check_integer_casting_closed(
          samples, target_dtype=self.dtype, assert_positive=True)

    samples = tf.cast(samples, self.dtype)

    if self.validate_args:
      # Mark the points that never got accepted with a sentinel value.
      np_dtype = dtype_util.as_numpy_dtype(self.dtype)
      if dtype_util.is_integer(np_dtype):
        sentinel = np_dtype(dtype_util.min(np_dtype))
      else:
        sentinel = np_dtype(np.nan)
      samples = tf.where(still_pending, sentinel, samples)

    return samples
Example #19
0
    def _sample_n(self, n, seed):
        """Draw `n` samples, each a matrix built from a random triangular factor.

        A lower-triangular matrix is assembled whose below-diagonal entries
        are standard normal draws and whose diagonal is the square root of
        gamma draws. That matrix is left-multiplied by `self.scale_operator`.
        Unless `self.input_output_cholesky` is truthy, the result is then
        multiplied by its own adjoint.

        Args:
          n: Number of samples to draw (leading dimension of the result).
          seed: Seed forwarded to `SeedStream` with salt "Wishart".

        Returns:
          A tensor whose shape is `[n]`, then the batch shape, then the
          event shape.
        """
        batch_dims = self.batch_shape_tensor()
        event_dims = self.event_shape_tensor()
        # A sample has one sample axis, the batch axes and two event axes.
        rank = tf.shape(batch_dims)[0] + 3
        rng = SeedStream(seed, salt="Wishart")

        # Complexity: O(nbk**2)
        normals = tf.random.normal(
            shape=tf.concat([[n], batch_dims, event_dims], axis=0),
            mean=0.,
            stddev=1.,
            dtype=self.dtype,
            seed=rng())

        # Complexity: O(nbk)
        # This parametrization is equivalent to Chi2, i.e.,
        # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2)
        df_dtype = dtype_util.base_dtype(self.df.dtype)
        broadcast_df = self.df * tf.ones(
            self.scale_operator.batch_shape_tensor(), dtype=df_dtype)
        gamma_alpha = self._multi_gamma_sequence(0.5 * broadcast_df,
                                                 self.dimension)
        diag_squared = tf.random.gamma(
            shape=[n],
            alpha=gamma_alpha,
            beta=0.5,
            dtype=self.dtype,
            seed=rng())

        # Keep only the lower triangle of the normal draws.
        # Complexity: O(nbk**2)
        factor = tf.linalg.band_part(normals, -1, 0)

        # Overwrite the diagonal with sqrt of the gamma draws.
        # Complexity: O(nbk)
        factor = tf.linalg.set_diag(factor, tf.sqrt(diag_squared))

        # Make batch-op ready: rotate the sample axis to the back and fold it
        # into the last event axis so a single operator matmul covers all
        # samples.
        # Complexity: O(nbk**2)
        sample_to_back = tf.concat([tf.range(1, rank), [0]], axis=0)
        factor = tf.transpose(factor, perm=sample_to_back)
        folded_shape = tf.concat(
            [batch_dims, [event_dims[0]], [event_dims[1] * n]], axis=0)
        factor = tf.reshape(factor, folded_shape)

        # Complexity: O(nbM) where M is the complexity of the operator solving a
        # vector system. For LinearOperatorLowerTriangular, each matmul is O(k^3) so
        # this step has complexity O(nbk^3).
        factor = self.scale_operator.matmul(factor)

        # Undo make batch-op ready: unfold the sample axis and rotate it back
        # to the front.
        # Complexity: O(nbk**2)
        unfolded_shape = tf.concat([batch_dims, event_dims, [n]], axis=0)
        factor = tf.reshape(factor, unfolded_shape)
        sample_to_front = tf.concat([[rank - 1], tf.range(0, rank - 1)],
                                    axis=0)
        factor = tf.transpose(factor, perm=sample_to_front)

        if self.input_output_cholesky:
            return factor

        # Complexity: O(nbk**3)
        return tf.matmul(factor, factor, adjoint_b=True)
    def _sample_n(self, n, seed=None):
        """Draw `n` samples by interpolating between two affine endpoints.

        A single draw from `self.distribution` is pushed through every
        bijector in `self.endpoint_affine`. A quadrature index is sampled
        from `self.mixture_distribution` and used to look up an
        interpolation weight in `self.grid`. The two endpoint images are
        then combined as `weight * x0 + (1 - weight) * x1`.

        Args:
          n: Number of samples to draw.
          seed: Seed forwarded to `SeedStream` with salt
            "VectorDiffeomixture".

        Returns:
          The interpolated samples.

        Raises:
          NotImplementedError: If `self.endpoint_affine` does not contain
            exactly two entries.
        """
        rng = SeedStream(seed, salt="VectorDiffeomixture")

        base_shape = concat_vectors([n], self.batch_shape_tensor(),
                                    self.event_shape_tensor())
        base_draw = self.distribution.sample(sample_shape=base_shape,
                                             seed=rng())  # shape: [n, B, e]
        endpoints = [
            affine.forward(base_draw) for affine in self.endpoint_affine
        ]

        # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
        # ids as a [n]-shaped vector.
        # Prefer statically known element counts; fall back to graph ops.
        batch_size = tensorshape_util.num_elements(self.batch_shape)
        if batch_size is None:
            batch_size = tf.reduce_prod(self.batch_shape_tensor())
        mix_batch_size = tensorshape_util.num_elements(
            self.mixture_distribution.batch_shape)
        if mix_batch_size is None:
            mix_batch_size = tf.reduce_prod(
                self.mixture_distribution.batch_shape_tensor())

        ids_trailing = distribution_util.pick_vector(
            self.is_scalar_batch(), np.int32([]),
            [batch_size // mix_batch_size])
        ids = self.mixture_distribution.sample(
            sample_shape=concat_vectors([n], ids_trailing), seed=rng())

        # We need to flatten batch dims in case mixture_distribution has its own
        # batch dims.
        flat_trailing = distribution_util.pick_vector(
            self.is_scalar_batch(), np.int32([]), np.int32([-1]))
        ids = tf.reshape(ids, shape=concat_vectors([n], flat_trailing))

        # Stride `components * quadrature_size` for `batch_size` number of times.
        grid_tail = tensorshape_util.with_rank_at_least(self.grid.shape,
                                                        2)[-2:]
        stride = tensorshape_util.num_elements(grid_tail)
        if stride is None:
            stride = tf.reduce_prod(tf.shape(self.grid)[-2:])
        offset = tf.range(start=0,
                          limit=batch_size * stride,
                          delta=stride,
                          dtype=ids.dtype)

        flat_grid = tf.reshape(self.grid, shape=[-1])
        weight = tf.gather(flat_grid, ids + offset)

        # At this point, weight flattened all batch dims into one.
        # We also need to append a singleton to broadcast with event dims.
        if tensorshape_util.is_fully_defined(self.batch_shape):
            weight_shape = ([-1] +
                            tensorshape_util.as_list(self.batch_shape) + [1])
        else:
            weight_shape = tf.concat(([-1], self.batch_shape_tensor(), [1]),
                                     axis=0)
        weight = tf.reshape(weight, shape=weight_shape)

        if len(endpoints) != 2:
            # We actually should have already triggered this exception. However as a
            # policy we're putting this exception wherever we exploit the bimixture
            # assumption.
            raise NotImplementedError(
                "Currently only bimixtures are supported; "
                "len(scale)={} is not 2.".format(len(endpoints)))

        first, second = endpoints
        # Same value as weight * first + (1. - weight) * second.
        return weight * (first - second) + second
  def _sample_n(self, n, seed=None):
    """Draws `n` unit-vector samples and rotates them to the mean direction.

    The first coordinate is sampled either by `self._sample_3d` (when the
    event dimension is 3) or by a rejection loop driven by Beta and uniform
    draws. The remaining coordinates are a normalized Gaussian direction
    scaled so that the full vector has unit length. The result is passed
    through `self._rotate`.

    Args:
      n: Number of samples to draw; becomes the leading dimension.
      seed: Seed forwarded to `SeedStream`.

    Returns:
      `self._rotate(samples)`, where `samples` has shape
      `[n] + batch_shape + [event_dim]`.
    """
    # NOTE(review): the salt looks like a typo of 'von_mises_fisher'.
    # Presumably changing it would change seeded sample streams, so it is
    # left as is -- confirm before touching.
    seed = SeedStream(seed, salt='vom_mises_fisher')
    # The sampling strategy relies on the fact that vMF variates are symmetric
    # about the mean direction. Accordingly, if we have a sampling strategy for
    # the away-from-mean angle, then we can uniformly sample the remaining
    # dimensions on the S^{dim-2} sphere, and rotate these samples from a
    # (1, 0, 0, ..., 0)-mode distribution into the target orientation.
    #
    # This is easy to imagine on the 1-sphere (S^1; in 2-D space): sample a
    # von-Mises distributed `x` value in [-1, 1], then uniformly select what
    # amounts to a "up" or "down" additional degree of freedom after unit
    # normalizing, followed by a final rotation to the desired mean direction
    # from a basis of (1, 0).
    #
    # On S^2 (in 3-D), selecting a vMF `x` identifies a circle in `yz` on the
    # unit sphere over which the distribution is uniform, in particular the
    # circle where x = \hat{x} intersects the unit sphere. We pick a point on
    # that circle, then rotate to the desired mean direction from a basis of
    # (1, 0, 0).

    # Use the static event size when it is known, else the dynamic one.
    event_dim = (
        tf.compat.dimension_value(self.event_shape[0]) or
        self._event_shape_tensor()[0])

    sample_batch_shape = tf.concat([[n], self._batch_shape_tensor()], axis=0)
    dim = tf.cast(event_dim - 1, self.dtype)
    # NOTE(review): when the static size is unknown, `event_dim` is a Tensor
    # and this is a Python-level comparison -- confirm it behaves as intended.
    if event_dim == 3:
      samples_dim0 = self._sample_3d(n, seed=seed)
    else:
      # Wood'94 provides a rejection algorithm to sample the x coordinate.
      # Wood'94 definition of b:
      # b = (-2 * kappa + tf.sqrt(4 * kappa**2 + dim**2)) / dim
      # https://stats.stackexchange.com/questions/156729 suggests:
      b = dim / (2 * self.concentration +
                 tf.sqrt(4 * self.concentration**2 + dim**2))
      # TODO(bjp): Integrate any useful numerical tricks from hyperspherical VAE
      #     https://github.com/nicola-decao/s-vae-tf/
      x = (1 - b) / (1 + b)
      c = self.concentration * x + dim * tf.math.log1p(-x**2)
      beta = beta_lib.Beta(dim / 2, dim / 2)

      def cond_fn(w, should_continue):
        # Keep looping while any element is still flagged for resampling.
        del w
        return tf.reduce_any(should_continue)

      def body_fn(w, should_continue):
        # Propose a new `w` only where `should_continue` is set; elements
        # that are no longer flagged keep their previous value.
        z = beta.sample(sample_shape=sample_batch_shape, seed=seed())
        # set_shape needed here because of b/139013403
        z.set_shape(w.shape)
        w = tf.where(should_continue, (1 - (1 + b) * z) / (1 - (1 - b) * z), w)
        w = tf.debugging.check_numerics(w, 'w')
        unif = tf.random.uniform(
            sample_batch_shape, seed=seed(), dtype=self.dtype)
        # set_shape needed here because of b/139013403
        unif.set_shape(w.shape)
        # An element stays flagged (is rejected) while the log acceptance
        # quantity is below log(unif).
        should_continue = tf.logical_and(
            should_continue,
            self.concentration * w + dim * tf.math.log1p(-x * w) - c <
            tf.math.log(unif))
        return w, should_continue

      w = tf.zeros(sample_batch_shape, dtype=self.dtype)
      should_continue = tf.ones(sample_batch_shape, dtype=tf.bool)
      # Only the final `w` is needed; the flag tensor is discarded.
      samples_dim0 = tf.while_loop(
          cond=cond_fn, body=body_fn, loop_vars=(w, should_continue))[0]
      # Add a trailing axis so it can be concatenated with the other dims.
      samples_dim0 = samples_dim0[..., tf.newaxis]
    if not self._allow_nan_stats:
      # Verify samples are w/in -1, 1, with useful error output tensors (top
      # value rather than all values).
      with tf.control_dependencies([
          assert_util.assert_less_equal(
              samples_dim0,
              dtype_util.as_numpy_dtype(self.dtype)(1.01),
              data=[tf.math.top_k(tf.reshape(samples_dim0, [-1]))[0]]),
          assert_util.assert_greater_equal(
              samples_dim0,
              dtype_util.as_numpy_dtype(self.dtype)(-1.01),
              data=[-tf.math.top_k(tf.reshape(-samples_dim0, [-1]))[0]])
      ]):
        samples_dim0 = tf.identity(samples_dim0)
    # Direction for the remaining `event_dim - 1` coordinates: a normalized
    # Gaussian draw.
    samples_otherdims_shape = tf.concat([sample_batch_shape, [event_dim - 1]],
                                        axis=0)
    unit_otherdims = tf.math.l2_normalize(
        tf.random.normal(
            samples_otherdims_shape, seed=seed(), dtype=self.dtype),
        axis=-1)
    samples = tf.concat([
        samples_dim0,  # we must avoid sqrt(1 - (>1)**2)
        tf.sqrt(tf.maximum(1 - samples_dim0**2, 0.)) * unit_otherdims
    ], axis=-1)
    # Renormalize the concatenated vector to unit length.
    samples = tf.math.l2_normalize(samples, axis=-1)
    if not self._allow_nan_stats:
      samples = tf.debugging.check_numerics(samples, 'samples')

    # Runtime assert that samples are unit length.
    if not self._allow_nan_stats:
      # `worst` is the largest deviation of any sample norm from 1.
      worst, idx = tf.math.top_k(
          tf.reshape(tf.abs(1 - tf.linalg.norm(samples, axis=-1)), [-1]))
      with tf.control_dependencies([
          assert_util.assert_near(
              dtype_util.as_numpy_dtype(self.dtype)(0),
              worst,
              data=[
                  worst, idx,
                  tf.gather(tf.reshape(samples, [-1, event_dim]), idx)
              ],
              atol=1e-4,
              summarize=100)
      ]):
        samples = tf.identity(samples)
    # The samples generated are symmetric around a mode at (1, 0, 0, ...., 0).
    # Now, we move the mode to `self.mean_direction` using a rotation matrix.
    if not self._allow_nan_stats:
      # Assert that the basis vector rotates to the mean direction, as expected.
      basis = tf.cast(tf.concat([[1.], tf.zeros([event_dim - 1])], axis=0),
                      self.dtype)
      with tf.control_dependencies([
          assert_util.assert_less(
              tf.linalg.norm(
                  self._rotate(basis) - self.mean_direction, axis=-1),
              dtype_util.as_numpy_dtype(self.dtype)(1e-5))
      ]):
        return self._rotate(samples)
    return self._rotate(samples)
    def _sample_n(self, n, seed=None):
        """Sample `n` observation sequences from the hidden Markov model.

        Hidden states are generated first: an initial state is drawn from
        `self._initial_distribution`, then `self._num_steps - 1` transitions
        are drawn from `self._transition_distribution` with `tf.scan`. For
        every step an observation is drawn from
        `self._observation_distribution` for every possible state, and the
        one-hot encoded hidden state selects which draw is kept.

        Args:
          n: Number of sequences to draw.
          seed: Seed forwarded to `SeedStream`. When it is not `None` the
            scan over steps runs with `parallel_iterations=1` so that
            results are reproducible.

        Returns:
          Observations laid out as
          `[n] + batch_shape + [num_steps] + observation event shape`.
        """
        with tf.control_dependencies(self._runtime_assertions):
            strm = SeedStream(seed, salt="HiddenMarkovModel")

            num_states = self._num_states

            batch_shape = self.batch_shape_tensor()
            batch_size = tf.reduce_prod(batch_shape)

            # The batch sizes of the underlying initial distributions and
            # transition distributions might not match the batch size of
            # the HMM distribution.
            # As a result we need to ask for more samples from the
            # underlying distributions and then reshape the results into
            # the correct batch size for the HMM.
            # `batch_size` is reused below instead of recomputing the
            # product of the batch shape for every repeat factor.
            init_repeat = batch_size // tf.reduce_prod(
                self._initial_distribution.batch_shape_tensor())
            init_state = self._initial_distribution.sample(n * init_repeat,
                                                           seed=strm())
            init_state = tf.reshape(init_state, [n, batch_size])
            # init_state :: n batch_size

            transition_repeat = batch_size // tf.reduce_prod(
                self._transition_distribution.batch_shape_tensor()[:-1])

            def generate_step(state, _):
                """Take a single step in Markov chain."""

                gen = self._transition_distribution.sample(n *
                                                           transition_repeat,
                                                           seed=strm())
                # gen :: (n * transition_repeat) transition_batch

                new_states = tf.reshape(gen, [n, batch_size, num_states])

                # new_states :: n batch_size num_states

                old_states_one_hot = tf.one_hot(state,
                                                num_states,
                                                dtype=tf.int32)

                # old_states :: n batch_size num_states

                # Keep only the transition drawn from the current state.
                return tf.reduce_sum(old_states_one_hot * new_states, axis=-1)

            def _scan_multiple_steps():
                """Take multiple steps with tf.scan."""
                # Only the length of this tensor matters; its values are
                # ignored by `generate_step`.
                dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32)
                # Force parallel_iterations to 1 when a seed is given to
                # ensure reproducibility (b/139210489); otherwise leave
                # tf.scan's default parallel_iterations behavior in place.
                scan_kwargs = {}
                if seed is not None:
                    scan_kwargs["parallel_iterations"] = 1
                hidden_states = tf.scan(generate_step,
                                        dummy_index,
                                        initializer=init_state,
                                        **scan_kwargs)

                # TODO(b/115618503): add/use prepend_initializer to tf.scan
                return tf.concat([[init_state], hidden_states], axis=0)

            hidden_states = prefer_static.cond(
                self._num_steps > 1, _scan_multiple_steps,
                lambda: init_state[tf.newaxis, ...])

            hidden_one_hot = tf.one_hot(
                hidden_states,
                num_states,
                dtype=self._observation_distribution.dtype)
            # hidden_one_hot :: num_steps n batch_size num_states

            # The observation distribution batch size might not match
            # the required batch size so as with the initial and
            # transition distributions we generate more samples and
            # reshape.
            observation_repeat = (batch_size // tf.reduce_prod(
                self._observation_distribution.batch_shape_tensor()[:-1]))

            possible_observations = self._observation_distribution.sample(
                [self._num_steps, observation_repeat * n], seed=strm())

            # Use the dynamic event shape so the reshapes below also work
            # when the static event shape is not fully defined; for fully
            # defined shapes the values are the same.
            inner_shape = self._observation_distribution.event_shape_tensor()

            # possible_observations :: num_steps (observation_repeat * n)
            #                          observation_batch[:-1] num_states inner_shape

            possible_observations = tf.reshape(
                possible_observations,
                tf.concat([[self._num_steps, n], batch_shape, [num_states],
                           inner_shape],
                          axis=0))

            # possible_observations :: steps n batch_size num_states inner_shape

            hidden_one_hot = tf.reshape(
                hidden_one_hot,
                tf.concat([[self._num_steps, n], batch_shape, [num_states],
                           tf.ones_like(inner_shape)],
                          axis=0))

            # hidden_one_hot :: steps n batch_size num_states "inner_shape"

            # Summing over the state axis keeps the observation drawn for
            # the active hidden state and drops the rest.
            observations = tf.reduce_sum(hidden_one_hot *
                                         possible_observations,
                                         axis=-1 - tf.size(inner_shape))

            # observations :: steps n batch_size inner_shape

            observations = distribution_util.move_dimension(
                observations, 0, 1 + tf.size(batch_shape))

            # returned :: n batch_shape steps inner_shape

            return observations