    def testBatchedLasVegasAlgorithm(self):
        seed = test_util.test_seed()

        def uniform_less_than_point_five(seed):
            seed_stream = SeedStream(seed, 'uniform_less_than_point_five')
            values = tf.random.uniform([6], seed=seed_stream())
            negative_values = -values
            good = tf.less(values, 0.5)

            return ((negative_values, values), good)

        ((negative_values, values), _) = self.evaluate(
            brs.batched_las_vegas_algorithm(uniform_less_than_point_five,
                                            seed=seed))
        self.assertAllLess(values, 0.5)
        self.assertAllClose(-values, negative_values)

        if tf.executing_eagerly():
            tf.random.set_seed(seed)

        # Check for reproducibility.
        ((negative_values_2, values_2), _) = self.evaluate(
            brs.batched_las_vegas_algorithm(uniform_less_than_point_five,
                                            seed=seed))
        self.assertAllEqual(negative_values, negative_values_2)
        self.assertAllEqual(values, values_2)

    def testGivingUp(self):
        def trial(seed):
            del seed
            return tf.constant([1, 0]), tf.constant([True, False])

        values, final_successes, num_iters = self.evaluate(
            brs.batched_las_vegas_algorithm(trial,
                                            max_trials=50,
                                            seed=test_util.test_seed()))
        self.assertAllEqual(values, [1, 0])
        self.assertAllEqual(final_successes, [True, False])
        self.assertAllEqual(50, num_iters)

    def testBatchedLasVegasAlgorithm(self):
        def uniform_less_than_point_five(seed):
            values = samplers.uniform([6], seed=seed)
            negative_values = -values
            good = tf.less(values, 0.5)

            return ((negative_values, values), good)

        ((negative_values, values), _, _) = self.evaluate(
            brs.batched_las_vegas_algorithm(uniform_less_than_point_five,
                                            seed=test_util.test_seed()))

        self.assertAllLess(values, 0.5)
        self.assertAllClose(-values, negative_values)

        # Check for reproducibility.
        ((negative_values_2, values_2), _, _) = self.evaluate(
            brs.batched_las_vegas_algorithm(uniform_less_than_point_five,
                                            seed=test_util.test_seed()))
        self.assertAllEqual(negative_values, negative_values_2)
        self.assertAllEqual(values, values_2)
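The tests above all use the same trial contract: the trial function maps a seed to a `(values, good_mask)` pair, and `batched_las_vegas_algorithm` keeps re-drawing only the rejected batch members until every mask entry is True or `max_trials` is exhausted. In the two tests directly above, the return value unpacks as `(values, successes, num_trials)` (the first test is written against an older two-tuple API). A minimal standalone sketch of the newer convention, assuming TFP's internal module layout for the imports and a made-up trial name mirroring `uniform_less_than_point_five`:

import tensorflow as tf
from tensorflow_probability.python.internal import batched_rejection_sampler as brs
from tensorflow_probability.python.internal import samplers

def half_unit_trial(seed):
    # Propose a batch of eight uniforms; accept only the ones below 0.5.
    values = samplers.uniform([8], seed=seed)
    return values, values < 0.5

values, successes, num_trials = brs.batched_las_vegas_algorithm(
    half_unit_trial, seed=samplers.sanitize_seed([0, 1]))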
Example #4
def _btrs(counts, probs, full_shape, seed):
    """Binomial transformed rejection sampler, for count*prob >= 10."""
    # We use a transformation-rejection algorithm from
    # pairs of uniform random variables due to Hormann.
    # https://www.tandfonline.com/doi/abs/10.1080/00949659308811496

    seed = samplers.sanitize_seed(seed)
    # This is spq in the paper.
    stddev = tf.math.sqrt(counts * probs * (1 - probs))

    # Other coefficients for Transformed Rejection sampling.
    b = 1.15 + 2.53 * stddev
    a = -0.0873 + 0.0248 * b + 0.01 * probs
    c = counts * probs + 0.5
    r = probs / (1 - probs)

    alpha = (2.83 + 5.1 / b) * stddev
    m = tf.math.floor((counts + 1) * probs)
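    # `m` is the mode of the Binomial(counts, probs) distribution; it reappears
    # in the acceptance bound computed inside the trial function below.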

    def batched_las_vegas_trial_fn(seed):
        u_seed, v_seed = samplers.split_seed(seed)
        u = samplers.uniform(full_shape, seed=u_seed, dtype=counts.dtype) - 0.5
        v = samplers.uniform(full_shape, seed=v_seed, dtype=counts.dtype)
        us = 0.5 - tf.math.abs(u)
        k = tf.math.floor((2 * a / us + b) * u + c)
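        # `k` above is the proposal: the floor of Hormann's dominating transform
        # G(u) = (2 * a / us + b) * u + c, with us = 1/2 - |u|.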

        # When the bounding box is tight, this criterion is more numerically stable
        # and equally valid. Particularly on GPU/TPU, it may make the difference
        # between terminating and non-terminating loops.
        v_r = 0.92 - 4.2 / b
        accept_boxed = (us >= 0.07) & (v <= v_r)

        # Reject nonsensical answers.
        reject = (k < 0) | (k > counts)

        # This deviates from Hormann's BTRS algorithm, as there is a log missing.
        # For all (u, v) pairs outside of the bounding box, this calculates the
        # transformed-reject ratio.
        v = tf.math.log(v * alpha / (a / (us * us) + b))
        upperbound = ((m + 0.5) * tf.math.log(
            (m + 1) / (r * (counts - m + 1))) + (counts + 1) * tf.math.log(
                (counts - m + 1) / (counts - k + 1)) +
                      (k + 0.5) * tf.math.log(r * (counts - k + 1) / (k + 1)) +
                      _stirling_approx_tail(m) +
                      _stirling_approx_tail(counts - m) -
                      _stirling_approx_tail(k) -
                      _stirling_approx_tail(counts - k))
        accept_bounded = v <= upperbound
        return k, (~reject) & (accept_boxed | accept_bounded)

    return batched_rejection_sampler.batched_las_vegas_algorithm(
        batched_las_vegas_trial_fn, seed=seed)[0]  # Pick out samples.
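The `_btrs` sampler calls a `_stirling_approx_tail` helper that is not shown in this listing. As a rough, hypothetical stand-in (a real implementation would typically handle small `k` separately, e.g. with a lookup table, where the series is least accurate), the quantity being approximated is the error term of Stirling's series for log(k!):

def _stirling_approx_tail_sketch(k):
    """Hypothetical stand-in: tail(k) such that
    log(k!) ~= (k + 1/2) * log(k + 1) - (k + 1) + 0.5 * log(2 * pi) + tail(k).
    """
    kp1 = k + 1.
    return (1. / (12. * kp1)
            - 1. / (360. * kp1 ** 3)
            + 1. / (1260. * kp1 ** 5))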
Example #5
def _btrs(counts, probs, full_shape, seed):
    """Binomial transformed rejection sampler, for count*prob >= 10."""
    # We use a transformation-rejection algorithm from
    # pairs of uniform random variables due to Hormann.
    # https://www.tandfonline.com/doi/abs/10.1080/00949659308811496

    seed = samplers.sanitize_seed(seed)
    # This is spq in the paper.
    stddev = tf.math.sqrt(counts * probs * (1 - probs))

    # Other coefficients for Transformed Rejection sampling.
    b = 1.15 + 2.53 * stddev
    a = -0.0873 + 0.0248 * b + 0.01 * probs
    c = counts * probs + 0.5
    r = probs / (1 - probs)

    alpha = (2.83 + 5.1 / b) * stddev
    m = tf.math.floor((counts + 1) * probs)

    def batched_las_vegas_trial_fn(seed):
        u_seed, v_seed = samplers.split_seed(seed)
        u = samplers.uniform(full_shape, seed=u_seed, dtype=counts.dtype) - 0.5
        v = samplers.uniform(full_shape, seed=v_seed, dtype=counts.dtype)
        us = 0.5 - tf.math.abs(u)
        k = tf.math.floor((2 * a / us + b) * u + c)

        # The original algorithm accepts early if (us >= 0.07) & (v <= v_r), where
        # `v_r = 0.92 - 4.2 / b`, the region for which the box is tight. Since we
        # have rewritten from scalar evaluation to batch and we can't avoid the
        # computation below, we omit the cheaper check.

    # Reject nonsensical answers.
        reject = (k < 0) | (k > counts)

        # This deviates from Hormann's BTRS algorithm, as there is a log missing.
        # For all (u, v) pairs outside of the bounding box, this calculates the
        # transformed-reject ratio.
        v = tf.math.log(v * alpha / (a / (us * us) + b))
        upperbound = ((m + 0.5) * tf.math.log(
            (m + 1) / (r * (counts - m + 1))) + (counts + 1) * tf.math.log(
                (counts - m + 1) / (counts - k + 1)) +
                      (k + 0.5) * tf.math.log(r * (counts - k + 1) / (k + 1)) +
                      _stirling_approx_tail(m) +
                      _stirling_approx_tail(counts - m) -
                      _stirling_approx_tail(k) -
                      _stirling_approx_tail(counts - k))
        return k, (~reject) & (v <= upperbound)

    return batched_rejection_sampler.batched_las_vegas_algorithm(
        batched_las_vegas_trial_fn, seed)[0]  # Drop `num_trials`.
Example #6
      def generate_positive_v():
        """Generate positive v."""
        def _inner(seed):
          x = samplers.normal(shape, dtype=internal_dtype, seed=seed)
          # This implicitly broadcasts concentration up to sample shape.
          v = 1 + c * x
          return (x, v), v > 0.

        # Note: It should be possible to remove this 'inner' call to
        # `batched_las_vegas_algorithm` and merge the v > 0 check into the
        # overall check for a good sample. This would lead to a slightly simpler
        # implementation; it is unclear whether it would be faster. We include
        # the inner loop so this implementation is more easily comparable to
        # Ref. [1] and other implementations.
        return brs.batched_las_vegas_algorithm(_inner, v_seed)[0]
Example #7

    def testBatchedStructuredLasVegasAlgorithm(self):
        def uniform_in_circle(seed):
            coords = samplers.uniform([6, 2],
                                      minval=-1.0,
                                      maxval=1.0,
                                      seed=seed)
            radii = tf.reduce_sum(coords * coords, axis=-1)
            good = tf.less(radii, 1)
            return (coords, good)

        (coords, _,
         _) = brs.batched_las_vegas_algorithm(uniform_in_circle,
                                              seed=test_util.test_seed())

        radii = self.evaluate(tf.reduce_sum(coords * coords, axis=-1))
        self.assertAllLess(radii, 1.0)
Example #8
def _random_poisson_high_rate(sample_shape,
                              log_rate,
                              internal_dtype=tf.float64,
                              seed=None):
  """Samples from the Poisson distribution using transformed rejection sampling.

  Given a CDF F(x), and G(x), a dominating distribution chosen such that it is
  close to the inverse CDF F^-1(x), compute the following steps:

  1) Generate U and V, two independent random variates. Set U = U - 0.5 (this
  step isn't strictly necessary, but is done to make some calculations symmetric
  and convenient. Henceforth, G is defined on [-0.5, 0.5]).

  2) If V <= alpha * F'(G(U)) * G'(U), return floor(G(U)), else return to
  step 1. alpha is the acceptance probability of the rejection algorithm.
  The dominating distribution in this case:
    G(u) = (2 * a / (1/2 - |u|) + b) * u + c

  For more details on transformed rejection, see [1].

  Args:
    sample_shape: The output sample shape. Must broadcast with `log_rate`.
    log_rate: Floating point tensor, log rate.
    internal_dtype: dtype to use for internal computations.
    seed: (optional) The random seed.

  Returns:
    Samples from the Poisson distribution using transformed rejection.

  #### References

  [1]: W. Hormann, G. Derflinger, The Transformed Rejection Method For
  Generating Random Variables, An Alternative To The Ratio Of Uniforms Method
  (1994), Manuskript, Institut f. Statistik, Wirtschaftsuniversitat
  """
  rate = tf.math.exp(log_rate)

  b = 0.931 + 2.53 * tf.math.exp(0.5 * log_rate)
  a = -0.059 + 0.02483 * b
  inverse_alpha = 1.1239 + 1.1328 / (b - 3.4)
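  # The constants above appear to be the fitted coefficients of the dominating
  # ("hat") distribution in Hormann's transformed rejection method for the
  # Poisson; `inverse_alpha` is presumably 1/alpha for the acceptance
  # probability alpha from step 2 of the docstring.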

  def generate_and_test_samples(seed):
    """Generate and test samples."""
    u_seed, v_seed = samplers.split_seed(seed)

    u = samplers.uniform(sample_shape, dtype=internal_dtype, seed=u_seed)
    u = u - 0.5
    u_shifted = 0.5 - tf.math.abs(u)

    v = samplers.uniform(sample_shape, dtype=internal_dtype, seed=v_seed)

    k = tf.math.floor(((2. * a) / u_shifted + b) * u + rate + 0.43)
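    # `k` above is the proposed Poisson sample: the floor of the dominating
    # transform from the docstring, with `rate + 0.43` playing the role of the
    # centering constant c.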

    good_sample_mask = (u_shifted >= 0.07) & (v <= 0.9277 - 3.6224 / (b - 2.))

    s = tf.math.log(v * inverse_alpha / (a / tf.math.square(u_shifted) + b))
    t = -rate + k * log_rate - tf.math.lgamma(k + 1)

    good_sample_mask = good_sample_mask | (s <= t)
    # Make sure the sample is within bounds.
    good_sample_mask = good_sample_mask & (k >= 0) & ((u_shifted >= 0.013) |
                                                      (v <= u_shifted))
    return k, good_sample_mask

  samples = brs.batched_las_vegas_algorithm(
      generate_and_test_samples, seed=seed)[0]

  return samples
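A hypothetical call, as if made from within the same module (reusing its `tf` and `samplers` imports; the argument values are illustrative, and the helper only requires that `sample_shape` broadcast with `log_rate`):

log_rate = tf.math.log(tf.constant([15., 40., 100.], dtype=tf.float64))
poisson_draws = _random_poisson_high_rate(
    [3], log_rate, seed=samplers.sanitize_seed([2, 3]))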
Example #9
    def rejection_sample(concentration):
        """Gamma rejection sampler."""
        # Note, concentration here already has a shape that is broadcast with rate.
        cast_concentration = tf.cast(concentration, internal_dtype)

        good_params_mask = (concentration >= 0.)
        # When replacing NaN values, use 100. for concentration, since that leads to
        # a high likelihood of the rejection sampler accepting on the first pass.
        safe_concentration = tf.where(good_params_mask, cast_concentration,
                                      100.)

        modified_safe_concentration = tf.where(safe_concentration < 1.,
                                               safe_concentration + 1.,
                                               safe_concentration)

        one_third = tf.constant(1. / 3, dtype=internal_dtype)
        d = modified_safe_concentration - one_third
        c = one_third * tf.math.rsqrt(d)
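        # `d` and `c` above are the constants of the squeeze-based gamma sampler
        # (presumably Marsaglia & Tsang, Ref. [1] in the comments below):
        # d = concentration - 1/3 (after the +1 shift applied when concentration < 1)
        # and c = 1 / (3 * sqrt(d)).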

        def generate_and_test_samples(seed):
            """Generate and test samples."""
            v_seed, u_seed = samplers.split_seed(seed)

            x = samplers.normal(shape, dtype=internal_dtype, seed=v_seed)
            # This implicitly broadcasts concentration up to sample shape.
            v = 1 + c * x
            # In [1], there is an 'inner' rejection sampling loop which checks that
            # v > 0 and generates a new normal sample if it's not, saving the rest of
            # the computations below. We found that merging the check for  v > 0 with
            # the `good_sample_mask` not only simplifies the code, but leads to a
            # ~2x speedup for small concentrations on GPU, at the cost of deviating
            # slightly from the implementation given in Ref. [1].
            accept_v = v > 0.
            logv = tf.math.log1p(c * x)
            x2 = x * x
            v3 = v * v * v
            logv3 = logv * 3
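            # `logv3` is log(v**3) (computed stably via log1p); it is both part of
            # the acceptance test and the value returned when `log_space` is True.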

            u = samplers.uniform(shape, dtype=internal_dtype, seed=u_seed)

            # In [1], the suggestion is to first check u < 1 - 0.331 * x2 * x2, and to
            # run the check below only if it fails, in order to avoid the relatively
            # expensive logarithm calls. Our algorithm operates in batch mode: we will
            # have to compute or not compute the logarithms for the entire batch, and
            # as the batch gets larger, the odds we compute it grow. Therefore we
            # don't bother with the "cheap" check.
            good_sample_mask = tf.logical_and(
                tf.math.log(u) < (x2 / 2. + d * (1 - v3 + logv3)), accept_v)

            return logv3 if log_space else v3, good_sample_mask

        samples = brs.batched_las_vegas_algorithm(
            generate_and_test_samples, seed=generate_and_test_samples_seed)[0]

        concentration_fix_unif = samplers.uniform(  # in [0, 1)
            shape,
            dtype=internal_dtype,
            seed=concentration_fix_seed)

        if log_space:
            concentration_lt_one_fix = tf.where(
                safe_concentration < 1.,
                # Why do we use log1p(-x)? x is in [0, 1) and log(0) = -inf, is bad.
                # x ~ U(0,1) => 1-x ~ U(0,1)
                # But at the boundary, 1-x in (0, 1]. Good.
                # So we can take log(unif(0,1)) safely as log(1-unif(0,1)).
                # log1p(-0) = 0, and log1p(-almost_one) = -not_quite_inf. Good.
                tf.math.log1p(-concentration_fix_unif) / safe_concentration,
                tf.zeros((), dtype=internal_dtype))
            samples = samples + tf.math.log(d) + concentration_lt_one_fix
        else:
            concentration_lt_one_fix = tf.where(
                safe_concentration < 1.,
                tf.math.pow(concentration_fix_unif,
                            tf.math.reciprocal(safe_concentration)),
                tf.ones((), dtype=internal_dtype))
            samples = samples * d * concentration_lt_one_fix

        samples = tf.where(good_params_mask, samples, np.nan)
        output_type_samples = tf.cast(samples, output_dtype)

        return output_type_samples
Example #10
  def rejection_sample(concentration):
    """Gamma rejection sampler."""
    # Note, concentration here already has a shape that is broadcast with rate.
    cast_concentration = tf.cast(concentration, internal_dtype)

    good_params_mask = (concentration > 0.)
    # When replacing NaN values, use 100. for concentration, since that leads to
    # a high likelihood of the rejection sampler accepting on the first pass.
    safe_concentration = tf.where(good_params_mask, cast_concentration, 100.)

    modified_safe_concentration = tf.where(
        safe_concentration < 1., safe_concentration + 1., safe_concentration)

    one_third = tf.constant(1. / 3, dtype=internal_dtype)
    d = modified_safe_concentration - one_third
    c = one_third * tf.math.rsqrt(d)

    def generate_and_test_samples(seed):
      """Generate and test samples."""
      v_seed, u_seed = samplers.split_seed(seed)

      def generate_positive_v():
        """Generate positive v."""
        def _inner(seed):
          x = samplers.normal(shape, dtype=internal_dtype, seed=seed)
          # This implicitly broadcasts concentration up to sample shape.
          v = 1 + c * x
          return (x, v), v > 0.

        # Note: It should be possible to remove this 'inner' call to
        # `batched_las_vegas_algorithm` and merge the v > 0 check into the
        # overall check for a good sample. This would lead to a slightly simpler
        # implementation; it is unclear whether it would be faster. We include
        # the inner loop so this implementation is more easily comparable to
        # Ref. [1] and other implementations.
        return brs.batched_las_vegas_algorithm(_inner, v_seed)[0]

      (x, v) = generate_positive_v()
      logv = tf.math.log1p(c * x)
      x2 = x * x
      v3 = v * v * v
      logv3 = logv * 3

      u = samplers.uniform(
          shape, dtype=internal_dtype, seed=u_seed)

      # In [1], the suggestion is to first check u < 1 - 0.331 * x2 * x2, and to
      # run the check below only if it fails, in order to avoid the relatively
      # expensive logarithm calls. Our algorithm operates in batch mode: we will
      # have to compute or not compute the logarithms for the entire batch, and
      # as the batch gets larger, the odds we compute it grow. Therefore we
      # don't bother with the "cheap" check.
      good_sample_mask = tf.math.log(u) < (x2 / 2. + d * (1 - v3 + logv3))

      return logv3 if log_space else v3, good_sample_mask

    samples = brs.batched_las_vegas_algorithm(
        generate_and_test_samples, seed=generate_and_test_samples_seed)[0]

    concentration_fix_unif = samplers.uniform(  # in [0, 1)
        shape, dtype=internal_dtype, seed=concentration_fix_seed)

    if log_space:
      concentration_lt_one_fix = tf.where(
          safe_concentration < 1.,
          # Why do we use log1p(-x)? x is in [0, 1) and log(0) = -inf, is bad.
          # x ~ U(0,1) => 1-x ~ U(0,1)
          # But at the boundary, 1-x in (0, 1]. Good.
          # So we can take log(unif(0,1)) safely as log(1-unif(0,1)).
          # log1p(-0) = 0, and log1p(-almost_one) = -not_quite_inf. Good.
          tf.math.log1p(-concentration_fix_unif) / safe_concentration,
          tf.zeros((), dtype=internal_dtype))
      samples = samples + tf.math.log(d) + concentration_lt_one_fix
    else:
      concentration_lt_one_fix = tf.where(
          safe_concentration < 1.,
          tf.math.pow(concentration_fix_unif,
                      tf.math.reciprocal(safe_concentration)),
          tf.ones((), dtype=internal_dtype))
      samples = samples * d * concentration_lt_one_fix

    samples = tf.where(good_params_mask, samples, np.nan)
    output_type_samples = tf.cast(samples, output_dtype)

    return output_type_samples
Example #11
    def rejection_sample(alpha):
        """Gamma rejection sampler."""
        # Note that alpha here already has a shape that is broadcast with beta.
        cast_alpha = tf.cast(alpha, internal_dtype)

        good_params_mask = (alpha > 0.)
        # When replacing NaN values, use 100. for alpha, since that leads to a
        # high likelihood of the rejection sampler accepting on the first pass.
        safe_alpha = tf.where(good_params_mask, cast_alpha, 100.)

        modified_safe_alpha = tf.where(safe_alpha < 1., safe_alpha + 1.,
                                       safe_alpha)

        one_third = tf.constant(1. / 3, dtype=internal_dtype)
        d = modified_safe_alpha - one_third
        c = one_third / tf.sqrt(d)

        def generate_and_test_samples(seed):
            """Generate and test samples."""
            seed_stream = SeedStream(seed, 'generate_and_test_samples')

            def generate_positive_v():
                """Generate positive v."""
                def _inner(seed):
                    seed_stream = SeedStream(seed, '_inner')
                    x = tf.random.normal(sample_shape,
                                         dtype=internal_dtype,
                                         seed=seed_stream())
                    # This implicitly broadcasts alpha up to sample shape.
                    v = 1 + c * x
                    return (x, v), v > 0.

                # Note: It should be possible to remove this 'inner' call to
                # `batched_las_vegas_algorithm` and merge the v > 0 check into the
                # overall check for a good sample. This would lead to a slightly simpler
                # implementation; it is unclear whether it would be faster. We include
                # the inner loop so this implementation is more easily comparable to
                # Ref. [1] and other implementations.
                return brs.batched_las_vegas_algorithm(_inner, seed)[0]

            (x, v) = generate_positive_v()
            x2 = x * x
            v3 = v * v * v
            u = tf.random.uniform(sample_shape,
                                  dtype=internal_dtype,
                                  seed=seed_stream())

            # In [1], the suggestion is to first check u < 1 - 0.331 * x2 * x2, and to
            # run the check below only if it fails, in order to avoid the relatively
            # expensive logarithm calls. Our algorithm operates in batch mode: we will
            # have to compute or not compute the logarithms for the entire batch, and
            # as the batch gets larger, the odds we compute it grow. Therefore we
            # don't bother with the "cheap" check.
            good_sample_mask = (tf.math.log(u) < x2 / 2. + d *
                                (1 - v3 + tf.math.log(v3)))

            return v3, good_sample_mask

        samples = brs.batched_las_vegas_algorithm(generate_and_test_samples,
                                                  seed=seed_stream())[0]

        samples = samples * d

        one = tf.constant(1., dtype=internal_dtype)

        alpha_lt_one_fix = tf.where(
            safe_alpha < 1.,
            tf.math.pow(
                tf.random.uniform(sample_shape,
                                  dtype=internal_dtype,
                                  seed=seed_stream()), one / safe_alpha), one)
        samples = samples * alpha_lt_one_fix
        samples = tf.where(good_params_mask, samples, np.nan)

        output_type_samples = tf.cast(samples, output_dtype)

        # We use `tf.where` instead of `tf.maximum` because we need to allow for
        # `samples` to be `nan`, but `tf.maximum(nan, x) == x`.
        output_type_samples = tf.where(
            output_type_samples == 0,
            np.finfo(dtype_util.as_numpy_dtype(
                output_type_samples.dtype)).tiny, output_type_samples)

        def grad(dy):
            """The gradient of the normalized (beta=1) gamma samples w.r.t alpha."""
            # Recall that cast_alpha has shape broadcast with beta, and samples and dy
            # have shape sample_shape (which further expands the alpha-beta broadcast
            # shape on the left).
            cast_dy = tf.cast(dy, internal_dtype)
            partial_alpha = tf.raw_ops.RandomGammaGrad(alpha=cast_alpha,
                                                       sample=samples)
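            # Sum out the leading sample dimensions so the gradient matches
            # `alpha`'s shape; broadcasting `alpha` up to the sample shape in the
            # forward pass corresponds to a reduce_sum in the backward pass.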
            grad = tf.cast(
                tf.math.reduce_sum(
                    cast_dy * partial_alpha,
                    axis=tf.range(tf.rank(partial_alpha) - tf.rank(alpha))),
                output_dtype)
            return grad

        return output_type_samples, grad  # rejection_sample

def retry_init(proposal_fn, target_fn, *args, max_trials=50,
               seed=None, name=None, **kwargs):
  """Tries an MCMC initialization proposal until it gets a valid state.

  In this case, "valid" is defined as the value of `target_fn` is
  finite.  This corresponds to an MCMC workflow where `target_fn`
  computes the log-probability one wants to sample from, in which case
  "finite `target_fn`" means "finite and positive probability state".
  If `target_fn` returns a Tensor of size greater than 1, the results
  are assumed to be independent of each other, so that different batch
  members can be accepted individually.

  The method is bounded rejection sampling.  The bound serves to avoid
  wasting computation on hopeless initialization procedures.  In
  interactive MCMC, one would presumably rather come up with a better
  initialization proposal than wait for an unbounded number of
  attempts with a bad one.  If unbounded re-trials are desired,
  set `max_trials` to `None`.

  Note: XLA and @jax.jit do not support assertions, so this function
  can return invalid states on those platforms without raising an
  error (unless `max_trials` is set to `None`).

  Args:
    proposal_fn: A function accepting a `seed` keyword argument and no other
      required arguments which generates proposed initial states.
    target_fn: A function accepting the return value of `proposal_fn`
      and returning a floating-point Tensor.
    *args: Additional arguments passed to `proposal_fn`.
    max_trials: Size-1 integer `Tensor` or None. Maximum number of
      calls to `proposal_fn` to attempt.  If acceptable states are not
      found in this many trials, `retry_init` signals an error.  If
      `None`, there is no limit, and `retry_init` skips the control
      flow cost of checking for success.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'mcmc_retry_init').
    **kwargs: Additional keyword arguments passed to `proposal_fn`.

  Returns:
    states: An acceptable result from `proposal_fn`.

  #### Example

  One popular MCMC initialization scheme is to start the chains near 0
  in unconstrained space.  There are models where the unconstraining
  transformation cannot exactly capture the space of valid states,
  such that this initialization has some material but not overwhelming
  chance of failure.  In this case, we can use `retry_init` to compensate.

  ```python
  @tfp.distributions.JointDistributionCoroutine
  def model():
    ...

  raw_init_dist = tfp.experimental.mcmc.init_near_unconstrained_zero(model)
  init_states = tfp.experimental.mcmc.retry_init(
    proposal_fn=raw_init_dist.sample,
    target_fn=model.log_prob,
    sample_shape=[100],
    seed=[4, 8])
  states = tfp.mcmc.sample_chain(
    current_state=init_states,
    ...)
  ```

  """
  def trial(seed):
    values = proposal_fn(*args, seed=seed, **kwargs)
    log_probs = target_fn(values)
    success = tf.math.is_finite(log_probs)
    return values, success
  with tf.name_scope(name or 'mcmc_retry_init'):
    values, successes, _ = brs.batched_las_vegas_algorithm(
        trial, max_trials=max_trials, seed=seed)
    if max_trials is None:
      # We were authorized to compute until success, so no need to
      # check for failure
      return values
    else:
      num_states = tf.size(successes)
      num_successes = tf.reduce_sum(tf.cast(successes, tf.int32))
      msg = ('Failed to find acceptable initial states after {} trials;\n'
             '{} of {} states have non-finite log probability').format(
                 max_trials, num_states - num_successes, num_states)
      with tf.control_dependencies([tf.debugging.assert_equal(
          successes, tf.ones_like(successes), message=msg)]):
        return tf.nest.map_structure(tf.identity, values)
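As a self-contained toy exercising the same contract (all names below are hypothetical, and the target is a log-density that is -inf for roughly half of the proposals):

import tensorflow as tf

def toy_proposal(seed=None):
    # Propose 16 states on the whole real line.
    return tf.random.stateless_normal([16], seed=seed, stddev=3.)

def toy_log_prob(x):
    # Supported only on x > 0; non-positive proposals get -inf log-probability.
    return tf.where(x > 0., -x, tf.constant(float('-inf')))

states = retry_init(toy_proposal, toy_log_prob, max_trials=50, seed=[1, 2])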