Example #1
0
 def _sample_n(self, n, seed=None):
   """Draws `n` Beta samples as a ratio of two independent Gamma draws.

   Uses the identity: if G1 ~ Gamma(concentration1, 1) and
   G0 ~ Gamma(concentration0, 1) are independent, then
   G1 / (G1 + G0) ~ Beta(concentration1, concentration0).

   Args:
     n: Number of samples to draw; passed as `shape=[n]` to both Gamma draws.
     seed: Optional seed, split into two sub-seeds (one per Gamma draw).

   Returns:
     A `Tensor` of Beta samples with dtype `self.dtype`.
   """
   # NOTE(review): when both concentrations are very small, both Gamma draws
   # can underflow to exactly 0 and the ratio becomes 0/0 = NaN. Sampling in
   # log space (sigmoid(log G1 - log G0)) would avoid this.
   seed_c1, seed_c0 = samplers.split_seed(seed, salt='beta')
   c1 = tf.convert_to_tensor(self.concentration1)
   c0 = tf.convert_to_tensor(self.concentration0)
   # Broadcast both parameters to the common batch shape so each Gamma draw
   # yields one variate per batch member.
   batch_shape = self._batch_shape_tensor(c1, c0)
   c1_full = tf.broadcast_to(c1, batch_shape)
   c0_full = tf.broadcast_to(c0, batch_shape)
   g1 = samplers.gamma(
       shape=[n], alpha=c1_full, dtype=self.dtype, seed=seed_c1)
   g0 = samplers.gamma(
       shape=[n], alpha=c0_full, dtype=self.dtype, seed=seed_c0)
   total = g1 + g0
   return g1 / total
Example #2
0
 def _sample_n(self, n, seed=None):
     """Draws `n` Gamma-Gamma samples via a two-stage Gamma draw.

     First draws rate ~ Gamma(mixing_concentration, rate=mixing_rate), then
     draws X ~ Gamma(concentration, rate=rate) using that sampled rate.

     Args:
       n: Number of samples; it is the `shape` of the first (rate) draw.
       seed: Optional seed, split into a rate sub-seed and a sample sub-seed.

     Returns:
       A `Tensor` of samples with dtype `self.dtype`.
     """
     conc = tf.convert_to_tensor(self.concentration)
     mix_conc = tf.convert_to_tensor(self.mixing_concentration)
     mix_rate = tf.convert_to_tensor(self.mixing_rate)
     rate_seed, sample_seed = samplers.split_seed(seed, salt='gamma_gamma')
     # Adding zeros shaped like `concentration` broadcasts the mixing
     # concentration up to it, so enough rates are drawn for the
     # fully-broadcast Gamma-Gamma.
     broadcast_mix_conc = mix_conc + tf.zeros_like(conc)
     sampled_rate = samplers.gamma(
         shape=[n],
         alpha=broadcast_mix_conc,
         beta=mix_rate,
         dtype=self.dtype,
         seed=rate_seed)
     # shape=[] because the sampled rate already carries the [n] dimension.
     return samplers.gamma(
         shape=[],
         alpha=conc,
         beta=sampled_rate,
         dtype=self.dtype,
         seed=sample_seed)
 def _sample_n(self, n, seed=None):
   """Draws `n` Negative Binomial samples as a Gamma-Poisson mixture.

   If lam ~ Gamma(concentration=total_count, rate=(1 - probs) / probs) and
   X | lam ~ Poisson(lam), then X is Negative Binomially distributed.
   Assuming probs = sigmoid(logits), that rate equals exp(-logits), which is
   what is passed as `beta` below.

   Args:
     n: Number of samples; it is the `shape` of the Gamma draw.
     seed: Optional seed, split into Gamma and Poisson sub-seeds.

   Returns:
     A `Tensor` of counts with dtype `self.dtype`.
   """
   neg_logits = -self._logits_parameter_no_checks()
   seed_gamma, seed_poisson = samplers.split_seed(
       seed, salt='NegativeBinomial')
   lam = samplers.gamma(
       shape=[n],
       alpha=self.total_count,
       beta=tf.math.exp(neg_logits),
       dtype=self.dtype,
       seed=seed_gamma)
   # shape=[] because `lam` already carries the [n] sample dimension.
   return samplers.poisson(
       shape=[], lam=lam, dtype=self.dtype, seed=seed_poisson)
 def _sample_n(self, n, seed=None):
   """Draws `n` Negative Binomial samples as a Gamma-Poisson mixture.

   Relies on: if lam ~ Gamma(concentration=total_count,
   rate=(1 - probs) / probs) and X | lam ~ Poisson(lam), then X is
   Negative Binomially distributed. With probs = sigmoid(logits) (assumed),
   the Gamma rate is exp(-logits).

   Args:
     n: Number of samples; it is the `shape` of the Gamma draw.
     seed: Optional seed, split into Gamma and Poisson sub-seeds.

   Returns:
     A `Tensor` of counts with dtype `self.dtype`.
   """
   lg = self._logits_parameter_no_checks()
   seed_gamma, seed_poisson = samplers.split_seed(
       seed, salt='NegativeBinomial')
   gamma_rate = tf.math.exp(-lg)
   # TODO(b/152785714): For some reason switching to gamma_lib.random_gamma
   # makes tests time out. Note: observed similar in jax_transformation_test.
   lam = samplers.gamma(
       shape=[n],
       alpha=self.total_count,
       beta=gamma_rate,
       dtype=self.dtype,
       seed=seed_gamma)
   return samplers.poisson(
       shape=[], lam=lam, dtype=self.dtype, seed=seed_poisson)
Example #5
0
def sample_n(n, df, loc, scale, batch_shape, dtype, seed):
  """Draw n samples from a Student T distribution.

  The sampling method comes from the fact that if:
    X ~ Normal(0, 1)
    Z ~ Chi2(df)
    Y = X / sqrt(Z / df)
  then:
    Y ~ StudentT(df)
  and the returned value is `Y * scale + loc`. Because Y is symmetric about
  zero, `scale` may be negative or zero. A negative scale gives the same
  distribution as `abs(scale)`, and a zero scale returns `loc` exactly.

  Args:
    n: int, number of samples.
    df: Floating-point `Tensor`. The degrees of freedom of the
      distribution(s). `df` must contain only positive values.
    loc: Floating-point `Tensor`; the location(s) of the distribution(s).
    scale: Floating-point `Tensor`; the scale(s) of the distribution(s). It
      is applied by plain multiplication, so negative or zero values are
      accepted (see above).
    batch_shape: 1-D integer shape (list or `Tensor`) of the batch. It is
      concatenated after `[n]` to form the Normal sample shape and is used
      to broadcast `df` via `tf.ones`. It is a shape value, not a callable.
    dtype: Return dtype.
    seed: Optional seed for random draw; split into independent Normal and
      Gamma sub-seeds.

  Returns:
    samples: a `Tensor` of shape `[n] + batch_shape`, broadcast against the
      shapes of `scale` and `loc`.
  """
  normal_seed, gamma_seed = samplers.split_seed(seed, salt='student_t')
  shape = tf.concat([[n], batch_shape], 0)

  normal_sample = samplers.normal(shape, dtype=dtype, seed=normal_seed)
  # Broadcast df to the full batch shape so the Gamma draw below yields one
  # variate per (sample, batch member), matching `normal_sample`.
  df = df * tf.ones(batch_shape, dtype=dtype)
  # Chi2(df) == Gamma(concentration=df / 2, rate=1 / 2).
  gamma_sample = samplers.gamma([n],
                                0.5 * df,
                                beta=0.5,
                                dtype=dtype,
                                seed=gamma_seed)
  # rsqrt(Z / df) == 1 / sqrt(Z / df).
  samples = normal_sample * tf.math.rsqrt(gamma_sample / df)
  return samples * scale + loc
Example #6
0
    def _sample_n(self, n, seed=None):
        """Draws `n` Dirichlet-Multinomial samples.

        Independent Gamma(concentration) variates are drawn as unnormalized
        Dirichlet weights. Their logs are then used as logits for a
        Multinomial draw with `total_count` trials (cast to int32).

        Args:
          n: Number of samples; it is the `shape` of the Gamma draw.
          seed: Optional seed, split into Gamma and Multinomial sub-seeds.

        Returns:
          A `Tensor` reshaped to `[n] + batch_shape + [k]`, where `k` is the
          number of categories.
        """
        seed_gamma, seed_multinomial = samplers.split_seed(
            seed, salt='dirichlet_multinomial')

        conc = tf.convert_to_tensor(self._concentration)
        counts = tf.convert_to_tensor(self._total_count)

        num_trials = tf.cast(counts, dtype=tf.int32)
        num_classes = self._event_shape_tensor(conc)[0]
        # Multiplying by ones shaped like `total_count[..., newaxis]`
        # broadcasts the concentration against total_count's batch dims.
        broadcast_alpha = tf.math.multiply(
            tf.ones_like(counts[..., tf.newaxis]), conc, name='alpha')

        gamma_draws = samplers.gamma(shape=[n],
                                     alpha=broadcast_alpha,
                                     dtype=self.dtype,
                                     seed=gamma_seed if False else seed_gamma)
        # NOTE(review): for very small concentrations Gamma draws can
        # underflow to 0, so log() yields -inf. If every category underflows,
        # the logits are all -inf; confirm draw_sample handles this.
        logits = tf.math.log(gamma_draws)
        draws = multinomial.draw_sample(1, num_classes, logits, num_trials,
                                        self.dtype, seed_multinomial)
        batch_shape = self._batch_shape_tensor(conc, counts)
        target_shape = tf.concat([[n], batch_shape, [num_classes]], 0)
        return tf.reshape(draws, target_shape)
Example #7
0
 def _sample_n(self, n, seed=None):
     """Draws `n` InverseGamma samples as reciprocals of Gamma draws.

     If Y ~ Gamma(concentration, rate=scale), then
     1 / Y ~ InverseGamma(concentration, scale).

     Args:
       n: Number of samples; it is the `shape` of the Gamma draw.
       seed: Optional seed forwarded to the Gamma sampler.

     Returns:
       A `Tensor` of samples with dtype `self.dtype`.
     """
     gamma_draws = samplers.gamma(
         shape=[n],
         alpha=self.concentration,
         beta=self.scale,
         dtype=self.dtype,
         seed=seed)
     return 1. / gamma_draws
 def _sample_n(self, n, seed=None):
   """Draws `n` Dirichlet samples by normalizing independent Gamma draws.

   Each Gamma(concentration) vector is divided by its sum over the last
   (category) axis, so every sample sums to 1.

   Args:
     n: Number of samples; it is the `shape` of the Gamma draw.
     seed: Optional seed forwarded to the Gamma sampler.

   Returns:
     A `Tensor` of samples with dtype `self.dtype`.
   """
   # NOTE(review): if all Gamma draws in a vector underflow to 0 (very small
   # concentrations), the normalization is 0/0 = NaN.
   draws = samplers.gamma(
       shape=[n], alpha=self.concentration, dtype=self.dtype, seed=seed)
   totals = tf.reduce_sum(draws, axis=-1, keepdims=True)
   return draws / totals
Example #9
0
 def _sample_n(self, n, seed=None):
     """Draws `n` chi-squared samples using Chi2(df) == Gamma(df/2, rate=1/2).

     Args:
       n: Number of samples; it is the `shape` of the Gamma draw.
       seed: Optional seed forwarded to the Gamma sampler.

     Returns:
       A `Tensor` of samples with dtype `self.dtype`.
     """
     half_df = 0.5 * self.df
     half_rate = tf.convert_to_tensor(0.5, dtype=self.dtype)
     return samplers.gamma(
         shape=[n],
         alpha=half_df,
         beta=half_rate,
         dtype=self.dtype,
         seed=seed)
Example #10
0
def slice_sampler_one_dim(target_log_prob,
                          x_initial,
                          step_size=0.01,
                          max_doublings=30,
                          seed=None,
                          name=None):
    """For a given x position in each Markov chain, returns the next x.

    Applies the one-dimensional slice sampling algorithm of Neal (2003) to
    `x_initial`. Every element of `x_initial` is treated as an independent,
    simultaneously evolving Markov chain. The function returns the next state
    of each chain after one slice-sampling step.

    Args:
      target_log_prob: Callable accepting a tensor like `x_initial` and
        returning a tensor of the same shape containing the log density at
        that point.
      x_initial: A tensor of any shape. The initial positions of the chains.
        This function assumes that all the dimensions of `x_initial` are batch
        dimensions (i.e. the event shape is `[]`).
      step_size: A tensor of shape and dtype compatible with `x_initial`. The
        min interval size in the doubling algorithm.
      max_doublings: Scalar tensor of dtype `tf.int32`. The maximum number of
        doublings to try to find the slice bounds.
      seed: (Optional) positive int, or Tensor seed pair. The random seed.
      name: Python `str` name prefixed to Ops created by this function.
        Default value: `None` (i.e., 'slice_sampler_one_dim').

    Returns:
      retval: A tensor of the same shape and dtype as `x_initial`. The next
        state of the Markov chain.
      next_target_log_prob: The target log density evaluated at `retval`.
      bounds_satisfied: A tensor of bool dtype and shape batch dimensions.
      upper_bounds: Tensor of the same shape and dtype as `x_initial`. The
        upper bounds for the slice found.
      lower_bounds: Tensor of the same shape and dtype as `x_initial`. The
        lower bounds for the slice found.
    """
    gamma_seed, bounds_seed, sample_seed = samplers.split_seed(
        seed, n=3, salt='ssu.slice_sampler_one_dim')
    with tf.name_scope(name or 'slice_sampler_one_dim'):
        # Obtain a common dtype for the inputs (float32 if none is implied).
        dtype = dtype_util.common_dtype([x_initial, step_size],
                                        dtype_hint=tf.float32)
        x_initial = tf.convert_to_tensor(x_initial, dtype=dtype)
        step_size = tf.convert_to_tensor(step_size, dtype=dtype)
        # Select the height of the slice. Tensor of shape x_initial.shape.
        # Subtracting a Gamma(alpha=1) == Exponential(1) draw from the log
        # density is equivalent to adding log(U) with U ~ Uniform(0, 1), i.e.
        # choosing a height uniformly beneath the density at x_initial.
        log_slice_heights = target_log_prob(x_initial) - samplers.gamma(
            ps.shape(x_initial), alpha=1, dtype=dtype, seed=gamma_seed)
        # Given the above x and slice heights, compute the bounds of the slice
        # for each chain.
        upper_bounds, lower_bounds, bounds_satisfied = slice_bounds_by_doubling(
            x_initial,
            target_log_prob,
            log_slice_heights,
            max_doublings,
            step_size,
            seed=bounds_seed)
        retval = _sample_with_shrinkage(x_initial,
                                        target_log_prob=target_log_prob,
                                        log_slice_heights=log_slice_heights,
                                        step_size=step_size,
                                        lower_bounds=lower_bounds,
                                        upper_bounds=upper_bounds,
                                        seed=sample_seed)
        return (retval, target_log_prob(retval), bounds_satisfied,
                upper_bounds, lower_bounds)
Example #11
0
    def _sample_n(self, n, seed):
        """Draws `n` Wishart samples from a Bartlett-style lower factor.

        A lower-triangular matrix is built from standard Normal entries below
        the diagonal. The diagonal holds square roots of Gamma draws whose
        concentrations come from `self._multi_gamma_sequence(df / 2, k)` with
        rate 1/2. That factor is left-multiplied by the scale operator. Unless
        `input_output_cholesky` is set, the result `A` is returned as `A @ A^H`.

        Args:
          n: Number of samples to draw.
          seed: Seed, split into independent Normal and Gamma sub-seeds.

        Returns:
          A `Tensor` of shape `[n] + batch_shape + event_shape`.
        """
        df = tf.convert_to_tensor(self.df)
        batch_shape = self._batch_shape_tensor(df)
        event_shape = self._event_shape_tensor()
        # Rank of a full sample: 1 sample dim + batch dims + 2 event dims.
        rank = tf.shape(batch_shape)[0] + 3
        normal_seed, gamma_seed = samplers.split_seed(seed, salt='Wishart')

        # Off-diagonal entries: i.i.d. standard Normal. Complexity: O(nbk**2).
        noise = samplers.normal(
            shape=tf.concat([[n], batch_shape, event_shape], 0),
            mean=0.,
            stddev=1.,
            dtype=self.dtype,
            seed=normal_seed)

        # Diagonal entries. This parameterization is equivalent to Chi2, i.e.,
        # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2). Complexity: O(nbk).
        df_over_batch = df * tf.ones(self._scale.batch_shape_tensor(),
                                     dtype=dtype_util.base_dtype(df.dtype))
        diag_sq = samplers.gamma(
            shape=[n],
            alpha=self._multi_gamma_sequence(0.5 * df_over_batch,
                                             self._dimension()),
            beta=0.5,
            dtype=self.dtype,
            seed=gamma_seed)

        # Lower triangle of the noise with sqrt(Gamma) on the diagonal.
        # Complexity: O(nbk**2) + O(nbk).
        factor = tf.linalg.band_part(noise, -1, 0)
        factor = tf.linalg.set_diag(factor, tf.sqrt(diag_sq))

        # Move the sample axis to the end and fold it into the columns so the
        # scale operator is applied in one batched matmul. Complexity:
        # O(nbk**2).
        to_back = tf.concat([tf.range(1, rank), [0]], 0)
        factor = tf.transpose(a=factor, perm=to_back)
        folded_shape = tf.concat(
            [batch_shape, [event_shape[0]], [event_shape[1] * n]], 0)
        factor = tf.reshape(factor, folded_shape)

        # Complexity: O(nbM) where M is the cost of the operator solving a
        # vector system. For LinearOperatorLowerTriangular each matmul is
        # O(k^3), so this step is O(nbk^3).
        factor = self._scale.matmul(factor)

        # Unfold the columns and move the sample axis back to the front.
        # Complexity: O(nbk**2).
        factor = tf.reshape(factor,
                            tf.concat([batch_shape, event_shape, [n]], 0))
        to_front = tf.concat([[rank - 1], tf.range(0, rank - 1)], 0)
        factor = tf.transpose(a=factor, perm=to_front)

        if self.input_output_cholesky:
            return factor
        # Complexity: O(nbk**3)
        return tf.matmul(factor, factor, adjoint_b=True)