Code example #1
    def testExponentialExponentialKL(self):
        a_rate = np.arange(0.5, 1.6, 0.1)
        b_rate = np.arange(0.5, 1.6, 0.1)

        # This reshape is intended to expand the number of test cases.
        a_rate = a_rate.reshape((len(a_rate), 1))
        b_rate = b_rate.reshape((1, len(b_rate)))

        a = exponential_lib.Exponential(rate=a_rate)
        b = exponential_lib.Exponential(rate=b_rate)

        # Consistent with
        # http://www.mast.queensu.ca/~communications/Papers/gil-msc11.pdf, page 108
        true_kl = np.log(a_rate) - np.log(b_rate) + (b_rate - a_rate) / a_rate

        kl = tfd.kl_divergence(a, b)

        x = a.sample(int(4e5), seed=tfp_test_util.test_seed())
        kl_sample = tf.reduce_mean(a.log_prob(x) - b.log_prob(x), axis=0)

        kl_, kl_sample_ = self.evaluate([kl, kl_sample])
        self.assertAllClose(true_kl, kl_, atol=0., rtol=1e-12)
        self.assertAllClose(true_kl, kl_sample_, atol=0., rtol=8e-2)

        zero_kl = tfd.kl_divergence(a, a)
        true_zero_kl_, zero_kl_ = self.evaluate(
            [tf.zeros_like(zero_kl), zero_kl])
        self.assertAllEqual(true_zero_kl_, zero_kl_)
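The closed form asserted above can be cross-checked outside the test harness. A minimal sketch, assuming SciPy is available, comparing the formula against direct numerical integration of the KL integrand:

```python
import numpy as np
from scipy import integrate, stats

a_rate, b_rate = 1.3, 0.6
# KL(Exp(a) || Exp(b)) = log(a) - log(b) + (b - a) / a.
true_kl = np.log(a_rate) - np.log(b_rate) + (b_rate - a_rate) / a_rate

a = stats.expon(scale=1. / a_rate)
b = stats.expon(scale=1. / b_rate)
# KL(a || b) = E_a[log a(x) - log b(x)], integrated over the support [0, inf).
numeric_kl, _ = integrate.quad(
    lambda x: a.pdf(x) * (a.logpdf(x) - b.logpdf(x)), 0., np.inf)
assert np.isclose(true_kl, numeric_kl)
```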
Code example #2
 def testExponentialLogPDFBoundary(self):
     # Check that Log PDF is finite at 0.
     rate = np.array([0.1, 0.5, 1., 2., 5., 10.], dtype=np.float32)
     exponential = exponential_lib.Exponential(rate=rate,
                                               validate_args=False)
     log_pdf = exponential.log_prob(0.)
     self.assertAllClose(np.log(rate), self.evaluate(log_pdf))
Code example #3
 def testExponentialMean(self):
     lam_v = np.array([1.0, 4.0, 2.5])
     exponential = exponential_lib.Exponential(rate=lam_v,
                                               validate_args=True)
     self.assertEqual(exponential.mean().shape, (3, ))
     expected_mean = sp_stats.expon.mean(scale=1 / lam_v)
     self.assertAllClose(self.evaluate(exponential.mean()), expected_mean)
Code example #4
 def testExponentialVariance(self):
     lam_v = np.array([1.0, 4.0, 2.5])
     exponential = exponential_lib.Exponential(rate=lam_v)
     self.assertEqual(exponential.variance().shape, (3, ))
     expected_variance = sp_stats.expon.var(scale=1 / lam_v)
     self.assertAllClose(self.evaluate(exponential.variance()),
                         expected_variance)
Code example #5
 def testExponentialEntropy(self):
     lam_v = np.array([1.0, 4.0, 2.5])
     exponential = exponential_lib.Exponential(rate=lam_v)
     self.assertEqual(exponential.entropy().shape, (3, ))
     expected_entropy = sp_stats.expon.entropy(scale=1 / lam_v)
     self.assertAllClose(self.evaluate(exponential.entropy()),
                         expected_entropy)
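The SciPy reference values used in the mean, variance, and entropy tests all have simple closed forms for `Exponential(rate)`. A quick sketch, assuming only NumPy and SciPy:

```python
import numpy as np
from scipy import stats as sp_stats

lam_v = np.array([1.0, 4.0, 2.5])
# mean = 1/rate, variance = 1/rate**2, entropy = 1 - log(rate).
assert np.allclose(sp_stats.expon.mean(scale=1 / lam_v), 1 / lam_v)
assert np.allclose(sp_stats.expon.var(scale=1 / lam_v), 1 / lam_v**2)
assert np.allclose(sp_stats.expon.entropy(scale=1 / lam_v), 1 - np.log(lam_v))
```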
Code example #6
 def testFullyReparameterized(self):
     lam = tf.constant([0.1, 1.0])
     _, grad_lam = tfp.math.value_and_gradient(
         lambda l: exponential_lib.Exponential(  # pylint: disable=g-long-lambda
             rate=l, validate_args=True).sample(
                 100, seed=test_util.test_seed()),
         lam)
     self.assertIsNotNone(grad_lam)
Code example #7
 def testExponentialMean(self):
     lam_v = np.array([1.0, 4.0, 2.5])
     exponential = exponential_lib.Exponential(rate=lam_v)
     self.assertEqual(exponential.mean().shape, (3, ))
     if not stats:
         return
     expected_mean = stats.expon.mean(scale=1 / lam_v)
     self.assertAllClose(self.evaluate(exponential.mean()), expected_mean)
Code example #8
 def testFullyReparameterized(self):
     lam = tf.constant([0.1, 1.0])
     with tf.GradientTape() as tape:
         tape.watch(lam)
         exponential = exponential_lib.Exponential(rate=lam)
         samples = exponential.sample(100)
     grad_lam = tape.gradient(samples, lam)
     self.assertIsNotNone(grad_lam)
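Both reparameterization tests exercise the same underlying fact: an exponential draw can be written as the inverse CDF of a parameter-free uniform, `x = -log(1 - u) / rate`, so gradients flow from samples to `rate`. A hand-rolled sketch of that pathwise gradient (illustrative only, not the library's internal sampler):

```python
import tensorflow as tf

lam = tf.constant([0.1, 1.0])
with tf.GradientTape() as tape:
    tape.watch(lam)
    u = tf.random.uniform([100, 2])      # uniforms carry no dependence on lam
    samples = -tf.math.log1p(-u) / lam   # inverse-CDF (quantile) transform
grad_lam = tape.gradient(samples, lam)   # defined: the path is differentiable
assert grad_lam is not None
```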
Code example #9
def resample_independent(log_probs, event_size, sample_shape,
                         seed=None, name=None):
  """Categorical resampler for sequential Monte Carlo.

  The return value from this function is similar to sampling with

  ```python
  expanded_sample_shape = tf.concat([[event_size], sample_shape], axis=-1)
  tfd.Categorical(logits=log_probs).sample(expanded_sample_shape)
  ```

  but with values sorted along the first axis. It can be considered to be
  sampling events made up of a length-`event_size` vector of draws from
  the `Categorical` distribution. For large input values this function should
  give better performance than using `Categorical`.
  The sortedness is an unintended side effect of the algorithm that is
  harmless in the context of simple SMC algorithms.

  This implementation is based on the algorithms in [Maskell et al. (2006)][1].
  It is also known as multinomial resampling as described in
  [Doucet et al. (2011)][2].

  Args:
    log_probs: A tensor-valued batch of discrete log probability distributions.
    event_size: the dimension of the vector considered a single draw.
    sample_shape: the `sample_shape` determining the number of draws.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
      Default value: None (i.e. no seed).
    name: Python `str` name for ops created by this method.
      Default value: `None` (i.e., `'resample_independent'`).

  Returns:
    resampled_indices: a tensor of samples.

  #### References

  [1]: S. Maskell, B. Alun-Jones and M. Macleod. A Single Instruction Multiple
       Data Particle Filter.
       In 2006 IEEE Nonlinear Statistical Signal Processing Workshop.
       http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf
  [2]: A. Doucet & A. M. Johansen. Tutorial on Particle Filtering and
       Smoothing: Fifteen Years Later
       In 2011 The Oxford Handbook of Nonlinear Filtering
       https://www.stats.ox.ac.uk/~doucet/doucet_johansen_tutorialPF2011.pdf

  """
  with tf.name_scope(name or 'resample_independent') as name:
    log_probs = tf.convert_to_tensor(log_probs, dtype_hint=tf.float32)
    log_probs = dist_util.move_dimension(log_probs, source_idx=0, dest_idx=-1)
    points_shape = ps.concat([sample_shape,
                              ps.shape(log_probs)[:-1],
                              [event_size]], axis=0)
    log_points = -exponential.Exponential(
        rate=tf.constant(1.0, dtype=log_probs.dtype)).sample(
            points_shape, seed=seed)

    resampled = _resample_using_log_points(log_probs, sample_shape, log_points)
    return dist_util.move_dimension(resampled, source_idx=-1, dest_idx=0)
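A usage sketch, assuming the function is exposed as `tfp.experimental.mcmc.resample_independent` (the exact module path may differ across TFP versions):

```python
import tensorflow as tf
import tensorflow_probability as tfp

log_weights = tf.math.log([0.1, 0.2, 0.3, 0.4])
indices = tfp.experimental.mcmc.resample_independent(
    log_probs=log_weights, event_size=4, sample_shape=[2], seed=42)
# Shape [event_size, 2] = [4, 2]: ancestor indices in [0, 4), sorted along
# the first axis within each of the two independent events.
print(indices)
```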
Code example #10
    def testExponentialQuantile(self):
        exponential = exponential_lib.Exponential(rate=[1., 2.])

        # Corner cases.
        result = self.evaluate(exponential.quantile([0., 1.]))
        self.assertAllClose(result, [0., np.inf])

        # Two sample values calculated by hand.
        result = self.evaluate(exponential.quantile(0.5))
        self.assertAllClose(result, [0.693147, 0.346574])
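The hand-calculated values follow from the exponential quantile function, `x_p = -log(1 - p) / rate`:

```python
import numpy as np

rate = np.array([1., 2.])
print(-np.log1p(-0.5) / rate)  # -> [0.6931472, 0.3465736], i.e. log(2) / rate
```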
Code example #11
    def testExponentialCDF(self):
        batch_size = 6
        lam = tf.constant([2.0] * batch_size)
        lam_v = 2.0
        x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)

        exponential = exponential_lib.Exponential(rate=lam)

        cdf = exponential.cdf(x)
        self.assertEqual(cdf.shape, (6, ))

        expected_cdf = sp_stats.expon.cdf(x, scale=1 / lam_v)
        self.assertAllClose(self.evaluate(cdf), expected_cdf)
Code example #12
    def testExponentialLogSurvival(self):
        batch_size = 7
        lam = tf.constant([2.0] * batch_size)
        lam_v = 2.0
        x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0, 10.0], dtype=np.float32)

        exponential = exponential_lib.Exponential(rate=lam)

        log_survival = exponential.log_survival_function(x)
        self.assertEqual(log_survival.shape, (7, ))

        expected_log_survival = sp_stats.expon.logsf(x, scale=1 / lam_v)
        self.assertAllClose(self.evaluate(log_survival), expected_log_survival)
Code example #13
  def testExponentialSample(self):
    lam = tf.constant([3.0, 4.0])
    lam_v = [3.0, 4.0]
    n = tf.constant(100000)
    exponential = exponential_lib.Exponential(rate=lam)

    samples = exponential.sample(n, seed=tfp_test_util.test_seed())
    sample_values = self.evaluate(samples)
    self.assertEqual(sample_values.shape, (100000, 2))
    self.assertFalse(np.any(sample_values < 0.0))
    for i in range(2):
      self.assertLess(
          sp_stats.kstest(sample_values[:, i],
                          sp_stats.expon(scale=1.0 / lam_v[i]).cdf)[0], 0.01)
Code example #14
 def _sample_n(self, n, seed=None):
   normal_seed, exp_seed = samplers.split_seed(seed, salt='emg_sample')
   # Need to make sure the component distributions are broadcast
   # appropriately for correct generation of samples.
   loc = tf.convert_to_tensor(self.loc)
   rate = tf.convert_to_tensor(self.rate)
   scale = tf.convert_to_tensor(self.scale)
   batch_shape = self._batch_shape_tensor(loc, scale, rate)
   loc_broadcast = tf.broadcast_to(loc, batch_shape)
   rate_broadcast = tf.broadcast_to(rate, batch_shape)
   normal_dist = normal_lib.Normal(loc=loc_broadcast, scale=scale)
   exp_dist = exponential_lib.Exponential(rate_broadcast)
   x = normal_dist.sample(n, normal_seed)
   y = exp_dist.sample(n, exp_seed)
   return x + y
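The same construction with the public distribution API: an exponentially modified Gaussian draw is a Normal draw plus an independent Exponential draw. A sketch with illustrative parameter values:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

loc, scale, rate = 0., 1., 2.
n = 100000
x = (tfd.Normal(loc=loc, scale=scale).sample(n, seed=1) +
     tfd.Exponential(rate=rate).sample(n, seed=2))
# Moments of the sum: E[x] = loc + 1/rate, Var[x] = scale**2 + 1/rate**2.
print(tf.reduce_mean(x))           # ~0.5
print(tf.math.reduce_variance(x))  # ~1.25
```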
Code example #15
    def testExponentialLogPDF(self):
        batch_size = 6
        lam = tf.constant([2.0] * batch_size)
        lam_v = 2.0
        x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
        exponential = exponential_lib.Exponential(rate=lam)

        log_pdf = exponential.log_prob(x)
        self.assertEqual(log_pdf.shape, (6, ))

        pdf = exponential.prob(x)
        self.assertEqual(pdf.shape, (6, ))

        expected_log_pdf = sp_stats.expon.logpdf(x, scale=1 / lam_v)
        self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
        self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))
Code example #16
  def testExponentialSampleMultiDimensional(self):
    batch_size = 2
    lam_v = [3.0, 22.0]
    lam = tf.constant([lam_v] * batch_size)

    exponential = exponential_lib.Exponential(rate=lam)

    n = 100000
    samples = exponential.sample(n, seed=tfp_test_util.test_seed())
    self.assertEqual(samples.shape, (n, batch_size, 2))

    sample_values = self.evaluate(samples)

    self.assertFalse(np.any(sample_values < 0.0))
    for i in range(2):
      self.assertLess(
          sp_stats.kstest(sample_values[:, 0, i],
                          sp_stats.expon(scale=1.0 / lam_v[i]).cdf)[0], 0.01)
      self.assertLess(
          sp_stats.kstest(sample_values[:, 1, i],
                          sp_stats.expon(scale=1.0 / lam_v[i]).cdf)[0], 0.01)
Code example #17
def _log_concave_rejection_sampler(
    mode,
    prob_fn,
    dtype,
    sample_shape=(),
    distribution_minimum=None,
    distribution_maximum=None,
    seed=None):
  """Utility for rejection sampling from log-concave discrete distributions.

  This utility constructs an easy-to-sample-from upper bound for a discrete
  univariate log-concave distribution (for discrete univariate distributions, a
  necessary and sufficient condition is p_k^2 >= p_{k-1} p_{k+1} for all k).
  The method requires that the mode of the distribution is known. While a better
  method can likely be derived for any given distribution, this method is
  general and easy to implement. The expected number of iterations is bounded by
4+m, where m is the probability of the mode. For details, see [(Devroye,
  1987)][1].

  Args:
    mode: Tensor, the mode[s] of the [batch of] distribution[s].
    prob_fn: Python callable, counts -> prob(counts).
    dtype: DType of the generated samples.
    sample_shape: 0D or 1D `int32` `Tensor`. Shape of the generated samples.
    distribution_minimum: Tensor of type `dtype`. The minimum value
      taken by the distribution. The `prob` method will only be called on values
      greater than or equal to the specified minimum. The shape must broadcast with
      the batch shape of the distribution. If unspecified, the domain is treated
      as unbounded below.
    distribution_maximum: Tensor of type `dtype`. The maximum value
      taken by the distribution. See `distribution_minimum` for details.
    seed: Python integer or `Tensor` instance, for seeding PRNG.

  Returns:
    samples: a `Tensor` with prepended dimensions `sample_shape`.

  #### References

  [1] Luc Devroye. A Simple Generator for Discrete Log-Concave
      Distributions. Computing, 1987.

  [2] Dillon et al. TensorFlow Distributions. 2017.
      https://arxiv.org/abs/1711.10604
  """
  mode = tf.broadcast_to(
      mode, tf.concat([sample_shape, prefer_static.shape(mode)], axis=0))

  mode_height = prob_fn(mode)
  mode_shape = prefer_static.shape(mode)

  top_width = 1. + mode_height / 2.  # w in ref [1].
  top_fraction = top_width / (1 + top_width)
  exponential_distribution = exponential.Exponential(
      rate=tf.constant(1., dtype=dtype))  # E in ref [1].

  if distribution_minimum is None:
    distribution_minimum = tf.constant(-np.inf, dtype)
  if distribution_maximum is None:
    distribution_maximum = tf.constant(np.inf, dtype)

  def proposal(seed):
    """Proposal for log-concave rejection sampler."""
    (top_lobe_fractions_seed,
     exponential_samples_seed,
     top_selector_seed,
     rademacher_seed) = samplers.split_seed(
         seed, n=4, salt='log_concave_rejection_sampler_proposal')

    top_lobe_fractions = samplers.uniform(
        mode_shape, seed=top_lobe_fractions_seed, dtype=dtype)  # V in ref [1].
    top_offsets = top_lobe_fractions * top_width / mode_height

    exponential_samples = exponential_distribution.sample(
        mode_shape, seed=exponential_samples_seed)  # E in ref [1].
    exponential_height = (exponential_distribution.prob(exponential_samples) *
                          mode_height)
    exponential_offsets = (top_width + exponential_samples) / mode_height

    top_selector = samplers.uniform(
        mode_shape, seed=top_selector_seed, dtype=dtype)  # U in ref [1].
    on_top_mask = tf.less_equal(top_selector, top_fraction)

    unsigned_offsets = tf.where(on_top_mask, top_offsets, exponential_offsets)
    offsets = tf.round(
        tfp_random.rademacher(
            mode_shape, seed=rademacher_seed, dtype=dtype) *
        unsigned_offsets)

    potential_samples = mode + offsets
    envelope_height = tf.where(on_top_mask, mode_height, exponential_height)

    return potential_samples, envelope_height

  def target(values):
    in_range_mask = (
        (values >= distribution_minimum) & (values <= distribution_maximum))
    in_range_values = tf.where(in_range_mask, values, 0.)
    return tf.where(in_range_mask, prob_fn(in_range_values), 0.)

  return tf.stop_gradient(
      batched_rejection_sampler.batched_rejection_sampler(
          proposal, target, seed, dtype=dtype)[0])  # Discard `num_iters`.
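A hypothetical usage sketch for the sampler above: the Poisson pmf is log-concave with mode `floor(rate)`, so it satisfies the utility's requirements (the call assumes `_log_concave_rejection_sampler` is importable from this module; the seed is illustrative):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

rate = tf.constant(7.5)
poisson_samples = _log_concave_rejection_sampler(
    mode=tf.floor(rate),
    prob_fn=lambda counts: tfd.Poisson(rate).prob(counts),
    dtype=tf.float32,
    sample_shape=[1000],
    distribution_minimum=tf.constant(0., tf.float32),
    seed=42)
# Sanity check: tf.reduce_mean(poisson_samples) should be close to `rate`.
```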
Code example #18
    def __init__(self,
                 loc=None,
                 scale=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='VectorExponentialLinearOperator'):
        """Construct Vector Exponential distribution supported on a subset of `R^k`.

    The `batch_shape` is the broadcast shape between `loc` and `scale`
    arguments.

    The `event_shape` is given by last dimension of the matrix implied by
    `scale`. The last dimension of `loc` (if provided) must broadcast with this.

    Recall that `covariance = scale @ scale.T`.

    Additional leading dimensions (if any) will index batches.

    Args:
      loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
        implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
        `b >= 0` and `k` is the event size.
      scale: Instance of `LinearOperator` with same `dtype` as `loc` and shape
        `[B1, ..., Bb, k, k]`.
      validate_args: Python `bool`, default `False`. Whether to validate input
        with asserts. If `validate_args` is `False`, and the inputs are
        invalid, correct behavior is not guaranteed.
      allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      name: The name to give Ops created by the initializer.

    Raises:
      ValueError: if `scale` is unspecified.
      TypeError: if not `scale.dtype.is_floating`
    """
        parameters = dict(locals())
        if loc is None:
            loc = 0.0  # Implicit value for backwards compatibility.
        if scale is None:
            raise ValueError('Missing required `scale` parameter.')
        if not dtype_util.is_floating(scale.dtype):
            raise TypeError(
                '`scale` parameter must have floating-point dtype.')

        with tf.name_scope(name) as name:
            # Since expand_dims doesn't preserve constant-ness, we obtain the
            # non-dynamic value if possible.
            loc = loc if loc is None else tf.convert_to_tensor(
                loc, name='loc', dtype=scale.dtype)
            batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
                loc, scale)
            self._loc = loc
            self._scale = scale
            super(VectorExponentialLinearOperator, self).__init__(
                # TODO(b/137665504): Use batch-adding meta-distribution to set the
                # batch shape instead of tf.ones.
                # We use `Sample` instead of `Independent` because `Independent`
                # requires concatenating `batch_shape` and `event_shape`, which loses
                # static `batch_shape` information when `event_shape` is not
                # statically known.
                distribution=sample.Sample(
                    exponential.Exponential(rate=tf.ones(batch_shape,
                                                         dtype=scale.dtype),
                                            allow_nan_stats=allow_nan_stats),
                    event_shape),
                bijector=shift_bijector.Shift(shift=loc)(
                    scale_matvec_linear_operator.ScaleMatvecLinearOperator(
                        scale=scale, validate_args=validate_args)),
                validate_args=validate_args,
                name=name)
            self._parameters = parameters
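A construction sketch, assuming the class is exposed as `tfd.VectorExponentialLinearOperator` (its public location has moved across TFP versions):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

scale = tf.linalg.LinearOperatorLowerTriangular([[1.0, 0.0], [0.5, 2.0]])
vex = tfd.VectorExponentialLinearOperator(loc=[1.0, -1.0], scale=scale)
x = vex.sample(3, seed=42)  # shape [3, 2]: empty batch shape, event size 2
print(vex.log_prob(x))      # shape [3]
```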
Code example #19
 def testExponentialQuantileIsInverseOfCdf(self):
     exponential = exponential_lib.Exponential(rate=[1., 2.],
                                               validate_args=False)
     values = [2 * [t / 10.] for t in range(0, 11)]
     result = self.evaluate(exponential.cdf(exponential.quantile(values)))
     self.assertAllClose(result, values)
Code example #20
    def __init__(self,
                 loc=None,
                 scale=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="VectorExponentialLinearOperator"):
        """Construct Vector Exponential distribution supported on a subset of `R^k`.

    The `batch_shape` is the broadcast shape between `loc` and `scale`
    arguments.

    The `event_shape` is given by last dimension of the matrix implied by
    `scale`. The last dimension of `loc` (if provided) must broadcast with this.

    Recall that `covariance = scale @ scale.T`.

    Additional leading dimensions (if any) will index batches.

    Args:
      loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
        implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
        `b >= 0` and `k` is the event size.
      scale: Instance of `LinearOperator` with same `dtype` as `loc` and shape
        `[B1, ..., Bb, k, k]`.
      validate_args: Python `bool`, default `False`. Whether to validate input
        with asserts. If `validate_args` is `False`, and the inputs are
        invalid, correct behavior is not guaranteed.
      allow_nan_stats: Python `bool`, default `True`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      name: The name to give Ops created by the initializer.

    Raises:
      ValueError: if `scale` is unspecified.
      TypeError: if not `scale.dtype.is_floating`
    """
        parameters = dict(locals())
        if scale is None:
            raise ValueError("Missing required `scale` parameter.")
        if not scale.dtype.is_floating:
            raise TypeError(
                "`scale` parameter must have floating-point dtype.")

        with tf.compat.v2.name_scope(name) as name:
            # Since expand_dims doesn't preserve constant-ness, we obtain the
            # non-dynamic value if possible.
            loc = loc if loc is None else tf.convert_to_tensor(
                value=loc, name="loc", dtype=scale.dtype)
            batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale(
                loc, scale)

            super(VectorExponentialLinearOperator, self).__init__(
                distribution=exponential.Exponential(
                    rate=tf.ones([], dtype=scale.dtype),
                    allow_nan_stats=allow_nan_stats),
                bijector=affine_linear_operator_bijector.AffineLinearOperator(
                    shift=loc, scale=scale, validate_args=validate_args),
                batch_shape=batch_shape,
                event_shape=event_shape,
                validate_args=validate_args,
                name=name)
            self._parameters = parameters
Code example #21
 def testFullyReparameterized(self):
     lam = tf.constant([0.1, 1.0])
     _, grad_lam = tfp.math.value_and_gradient(
         lambda l: exponential_lib.Exponential(rate=l).sample(100), lam)
     self.assertIsNotNone(grad_lam)
Code example #22
def resample_independent(log_probs,
                         event_size,
                         sample_shape,
                         seed=None,
                         name=None):
    """Categorical resampler for sequential Monte Carlo.

  This function is based on Algorithm #1 in the paper
  [Maskell et al. (2006)][1].

  Args:
    log_probs: A tensor-valued batch of discrete log probability distributions.
    event_size: the dimension of the vector considered a single draw.
    sample_shape: the `sample_shape` determining the number of draws.
    seed: Python `int` used to seed calls to `tf.random.*`.
      Default value: None (i.e. no seed).
    name: Python `str` name for ops created by this method.
      Default value: `None` (i.e., `'resample_independent'`).

  Returns:
    resampled_indices: The result is similar to sampling with
    ```python
    expanded_sample_shape = tf.concat([[event_size], sample_shape], axis=-1)
    tfd.Categorical(logits=log_probs).sample(expanded_sample_shape)
    ```
    but with values sorted along the first axis. It can be considered to be
    sampling events made up of a length-`event_size` vector of draws from
    the `Categorical` distribution. For large input values this function should
    give better performance than using `Categorical`.
    The sortedness is an unintended side effect of the algorithm that is
    harmless in the context of simple SMC algorithms.

  #### References

  [1]: S. Maskell, B. Alun-Jones and M. Macleod. A Single Instruction Multiple
       Data Particle Filter.
       In 2006 IEEE Nonlinear Statistical Signal Processing Workshop.
       http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf

  """
    with tf.name_scope(name or 'resample_independent') as name:
        log_probs = tf.convert_to_tensor(log_probs, dtype_hint=tf.float32)
        log_probs = dist_util.move_dimension(log_probs,
                                             source_idx=0,
                                             dest_idx=-1)

        batch_shape = prefer_static.shape(log_probs)[:-1]
        num_markers = prefer_static.shape(log_probs)[-1]

        # `working_shape` specifies the total number of events
        # we will be generating.
        working_shape = prefer_static.concat([sample_shape, batch_shape],
                                             axis=0)
        # `points_shape` is the shape of the final result.
        points_shape = prefer_static.concat([working_shape, [event_size]],
                                            axis=0)
        # `markers_shape` is the shape of the markers we temporarily insert.
        markers_shape = prefer_static.concat([working_shape, [num_markers]],
                                             axis=0)
        # Generate one real point for each particle.
        log_points = -exponential.Exponential(
            rate=tf.constant(1.0, dtype=log_probs.dtype)).sample(points_shape,
                                                                 seed=seed)

        # We divide up the unit interval [0, 1] according to the provided
        # probability distributions using `cumsum`.
        # At the end of each division we place a 'marker'.
        # We generate random points on the unit interval.
        # We sort the combination of points and markers. The number
        # of points between the markers defining a division gives the number
        # of samples we require in that division.
        # For example, suppose `probs` is `[0.2, 0.3, 0.5]`.
        # We divide up `[0, 1]` using 3 markers:
        #
        #     |     |          |
        # 0.  0.2   0.5        1.0  <- markers
        #
        # Suppose we generate four points: [0.1, 0.25, 0.9, 0.75]
        # After sorting the combination we get:
        #
        # 0.1  0.25     0.75 0.9    <- points
        #  *  | *   |    *    *|
        # 0.   0.2 0.5         1.0  <- markers
        #
        # We have one sample in the first category, one in the second and
        # two in the last.
        #
        # All of these computations are carried out in batched form.
        markers = prefer_static.concat([
            tf.zeros(points_shape, dtype=tf.int32),
            tf.ones(markers_shape, dtype=tf.int32)
        ],
                                       axis=-1)
        log_marker_positions = tf.broadcast_to(
            tf.math.cumulative_logsumexp(log_probs, axis=-1), markers_shape)
        log_points_and_markers = prefer_static.concat(
            [log_points, log_marker_positions], axis=-1)
        indices = tf.argsort(log_points_and_markers, axis=-1, stable=False)
        sorted_markers = tf.gather_nd(
            markers,
            indices[..., tf.newaxis],
            batch_dims=(prefer_static.rank_from_shape(sample_shape) +
                        prefer_static.rank_from_shape(batch_shape)))
        markers_and_samples = prefer_static.cast(tf.cumsum(sorted_markers,
                                                           axis=-1),
                                                 dtype=tf.int32)
        markers_and_samples = tf.minimum(markers_and_samples, num_markers - 1)
        # Collect up samples, omitting markers.
        resampled = tf.reshape(
            markers_and_samples[tf.equal(sorted_markers, 0)], points_shape)
        resampled = dist_util.move_dimension(resampled,
                                             source_idx=-1,
                                             dest_idx=0)
        return resampled
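The marker/sort bookkeeping described in the comments above can be reproduced in a few lines of NumPy; plain uniform draws stand in for the log-space points used by the TF implementation:

```python
import numpy as np

probs = np.array([0.2, 0.3, 0.5])
points = np.array([0.1, 0.25, 0.9, 0.75])            # four uniform draws

values = np.concatenate([points, np.cumsum(probs)])  # points then markers
is_marker = np.concatenate([np.zeros(4, int), np.ones(3, int)])
sorted_flags = is_marker[np.argsort(values)]
# Running count of markers passed = category index at each sorted slot.
categories = np.cumsum(sorted_flags)
print(categories[sorted_flags == 0])                 # -> [0 1 2 2]
```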