Example #1
def test_seed_stream(salt='Salt of the Earth', hardcoded_seed=None):
    """Returns a command-line-controllable SeedStream PRNG for unit tests.

  When seeding unit-test PRNGs, we want:

  - The seed to be fixed to an arbitrary value most of the time, so the test
    doesn't flake even if its failure probability is noticeable.

  - To switch to different seeds per run when using --runs_per_test to measure
    the test's failure probability.

  - To set the seed to a specific value when reproducing a low-probability event
    (e.g., debugging a crash that only some seeds trigger).

  To those ends, this function returns a `SeedStream` seeded with `test_seed`
  (which see).  The latter respects the command line flags `--fixed_seed=<seed>`
  and `--vary_seed` (Boolean, default False).  `--vary_seed` uses system entropy
  to produce unpredictable seeds.  `--fixed_seed` takes precedence over
  `--vary_seed` when both are present.

  Note that TensorFlow graph mode operations tend to read seed state from two
  sources: a "graph-level seed" and an "op-level seed".  test_util.TestCase will
  set the former to a fixed value per test, but in general it may be necessary
  to explicitly set both to ensure reproducibility.

  Args:
    salt: Optional string wherewith to salt the returned SeedStream.  Setting
      this guarantees independent random numbers across tests.
    hardcoded_seed: Optional Python value.  The seed to use if both the
      `--vary_seed` and `--fixed_seed` flags are unset.  This should usually be
      unnecessary, since a test should pass with any seed.

  Returns:
    strm: A SeedStream instance seeded with 17, unless otherwise specified by
      arguments or command line flags.
  """
    return SeedStream(test_seed(hardcoded_seed), salt=salt)
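The helper above is only as useful as the discipline of threading its stream through every random op in a test. Below is a minimal sketch of that pattern, assuming the public `tfp.util.SeedStream` API; the helper name `my_test_seed_stream` and the fixed fallback seed are illustrative, not part of the library.

import tensorflow as tf
import tensorflow_probability as tfp

def my_test_seed_stream(salt='MyDistributionTest', hardcoded_seed=None):
    # Fall back to a fixed seed so the test is deterministic by default; a real
    # test harness would additionally consult --fixed_seed / --vary_seed flags.
    seed = 17 if hardcoded_seed is None else hardcoded_seed
    return tfp.util.SeedStream(seed, salt=salt)

stream = my_test_seed_stream()
# Each call to the stream derives a fresh op-level seed, so the two draws below
# are independent of each other yet reproducible across test runs.
x = tf.random.normal([3], seed=stream())
y = tf.random.normal([3], seed=stream())
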
Example #2
 def _flat_sample_distributions(self,
                                sample_shape=(),
                                seed=None,
                                value=None):
     # This function additionally depends on:
     #   self._dist_fn_wrapped
     #   self._dist_fn_args
     #   self._always_use_specified_sample_shape
     seed = SeedStream(seed, salt='JointDistributionSequential')
     ds = []
     xs = [None] * len(self._dist_fn_wrapped) if value is None else list(
         value)
     if len(xs) != len(self._dist_fn_wrapped):
         raise ValueError('Number of `xs`s must match number of '
                          'distributions.')
     for i, (dist_fn, args) in enumerate(
             zip(self._dist_fn_wrapped, self._dist_fn_args)):
         ds.append(dist_fn(*xs[:i]))  # Chain rule of probability.
         if xs[i] is None:
             # TODO(b/129364796): We should ignore args prefixed with `_`; this
             # would mean we more often identify when to use `sample_shape=()`
             # rather than `sample_shape=sample_shape`.
             xs[i] = ds[-1].sample(
                 () if args and not self._always_use_specified_sample_shape
                 else sample_shape,
                 seed=seed())
         else:
             xs[i] = nest.map_structure_up_to(
                 ds[-1].dtype,
                 lambda x, dtype: tf.convert_to_tensor(x, dtype_hint=dtype),
                 xs[i], ds[-1].dtype)
             seed()  # Ensure reproducibility even when xs are (partially) set.
     # Note: we could also resolve distributions up to the first non-`None` in
     # `self._model_flatten(value)`, however we omit this feature for simplicity,
     # speed, and because it has not yet been requested.
     return ds, xs
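Note that the code above advances the stream with `seed()` even when a value is supplied, so the seeds consumed by later distributions do not depend on which values were pinned. A short sketch of the two `SeedStream` properties this relies on, determinism for a fixed (seed, salt) pair and distinct sequences for distinct salts, assuming the public `tfp.util.SeedStream` API:

import tensorflow_probability as tfp

a = tfp.util.SeedStream(123, salt='JointDistributionSequential')
b = tfp.util.SeedStream(123, salt='JointDistributionSequential')
c = tfp.util.SeedStream(123, salt='SomeOtherComponent')

# The same (seed, salt) pair replays the same sequence of derived seeds,
# which is what makes the sampling code above reproducible.
assert [a() for _ in range(3)] == [b() for _ in range(3)]
# A different salt yields a different sequence from the same user seed, so
# distinct components draw effectively independent randomness.
print([c() for _ in range(3)])
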
Example #3
  def _flat_sample_distributions(self, sample_shape=(), seed=None, value=None):
    """Executes `model`, creating both samples and distributions."""
    ds = []
    values_out = []
    seed = SeedStream('JointDistributionCoroutine', seed)
    gen = self._model()
    index = 0
    d = next(gen)
    if not isinstance(d, self.Root):
      raise ValueError('First distribution yielded by coroutine must '
                       'be wrapped in `Root`.')
    try:
      while True:
        actual_distribution = d.distribution if isinstance(d, self.Root) else d
        ds.append(actual_distribution)
        if (value is not None and len(value) > index and
            value[index] is not None):
          seed()
          next_value = value[index]
        else:
          next_value = actual_distribution.sample(
              sample_shape=sample_shape if isinstance(d, self.Root) else (),
              seed=seed())

        if self._validate_args:
          with tf.control_dependencies(
              self._assert_compatible_shape(
                  index, sample_shape, next_value)):
            values_out.append(tf.nest.map_structure(tf.identity, next_value))
        else:
          values_out.append(next_value)

        index += 1
        d = gen.send(next_value)
    except StopIteration:
      pass
    return ds, values_out
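This method executes a user model written as a Python generator: each yielded distribution is sampled (the first must be wrapped in `Root`, and only `Root` nodes receive the caller's `sample_shape`), and the drawn value is sent back into the coroutine so later distributions can depend on it. A minimal sketch of such a model, assuming the public `tfp.distributions.JointDistributionCoroutine` API:

import tensorflow_probability as tfp
tfd = tfp.distributions
Root = tfd.JointDistributionCoroutine.Root

def model():
    # The first yield has no upstream dependencies, so it is wrapped in `Root`
    # and is the only one sampled with the caller's `sample_shape`.
    scale = yield Root(tfd.HalfNormal(1.))
    # Later distributions receive earlier samples via `gen.send(...)` above and
    # are sampled with `sample_shape=()`.
    yield tfd.Normal(loc=0., scale=scale)

joint = tfd.JointDistributionCoroutine(model)
scale_draws, normal_draws = joint.sample(5, seed=42)
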
Example #4
  def _sample_n(self, n, seed=None):
    seed = SeedStream(seed, 'dirichlet_multinomial')

    concentration = tf.convert_to_tensor(self._concentration)
    total_count = tf.convert_to_tensor(self._total_count)

    n_draws = tf.cast(total_count, dtype=tf.int32)
    k = self._event_shape_tensor(concentration)[0]
    alpha = tf.math.multiply(
        tf.ones_like(total_count[..., tf.newaxis]),
        concentration,
        name='alpha')

    unnormalized_logits = tf.math.log(
        tf.random.gamma(
            shape=[n],
            alpha=alpha,
            dtype=self.dtype,
            seed=seed()))
    x = multinomial.draw_sample(
        1, k, unnormalized_logits, n_draws, self.dtype, seed())
    final_shape = tf.concat(
        [[n], self._batch_shape_tensor(concentration, total_count), [k]], 0)
    return tf.reshape(x, final_shape)
Example #5
  def make_transform_hmc_kernel_fn(
      target_log_prob_fn,
      init_state,
      scalings,
      seed=None):
    """Generate a transform hmc kernel."""

    with tf.name_scope('make_transformed_hmc_kernel_fn'):
      seed = SeedStream(seed, salt='make_transformed_hmc_kernel_fn')
      # TransformedTransitionKernel doesn't modify the input step size, thus we
      # need to pass step sizes that are already in unconstrained space.
      state_std = [
          tf.math.reduce_std(bij.inverse(x), axis=0, keepdims=True)
          for x, bij in zip(init_state, unconstraining_bijectors)
      ]
      step_size = compute_hmc_step_size(scalings, state_std, num_leapfrog_steps)
      return transformed_kernel.TransformedTransitionKernel(
          hmc.HamiltonianMonteCarlo(
              target_log_prob_fn=target_log_prob_fn,
              num_leapfrog_steps=num_leapfrog_steps,
              step_size=step_size,
              seed=seed),
          unconstraining_bijectors)
Example #6
 def _sample_3d(self, n, mean_direction, concentration, seed=None):
   """Specialized inversion sampler for 3D."""
   seed = SeedStream(seed, salt='von_mises_fisher_3d')
   u_shape = tf.concat([[n], self._batch_shape_tensor(
       mean_direction=mean_direction, concentration=concentration)], axis=0)
   z = tf.random.uniform(u_shape, seed=seed(), dtype=self.dtype)
   # TODO(bjp): Higher-order odd dim analytic CDFs are available in [1], could
   # be bisected for bounded sampling runtime (i.e. not rejection sampling).
   # [1]: Inversion sampler via: https://ieeexplore.ieee.org/document/7347705/
   # The inversion is: u = 1 + log(z + (1-z)*exp(-2*kappa)) / kappa
   # We must protect against both kappa and z being zero.
   safe_conc = tf.where(concentration > 0, concentration,
                        tf.ones_like(concentration))
   safe_z = tf.where(z > 0, z, tf.ones_like(z))
   safe_u = 1 + tf.reduce_logsumexp(
       [tf.math.log(safe_z),
        tf.math.log1p(-safe_z) - 2 * safe_conc], axis=0) / safe_conc
   # Limit of the above expression as kappa->0 is 2*z-1
   u = tf.where(concentration > 0., safe_u, 2 * z - 1)
   # Limit of the expression as z->0 is -1.
   u = tf.where(tf.equal(z, 0), -tf.ones_like(u), u)
   if not self._allow_nan_stats:
     u = tf.debugging.check_numerics(u, 'u in _sample_3d')
   return u[..., tf.newaxis]
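The single expression computing `safe_u` above is easier to read with the comment's formula written out. A small NumPy sketch of the same inversion, guarding the same two limits (the function name is illustrative, not from the library):

import numpy as np

def invert_vmf3_cdf(z, kappa):
    # u = 1 + log(z + (1 - z) * exp(-2 * kappa)) / kappa, with the same guards
    # the TF code applies via tf.where.
    if kappa == 0.0:
        return 2.0 * z - 1.0   # limit of the expression as kappa -> 0
    if z == 0.0:
        return -1.0            # limit of the expression as z -> 0
    return 1.0 + np.log(z + (1.0 - z) * np.exp(-2.0 * kappa)) / kappa

print(invert_vmf3_cdf(0.5, 2.0))  # cosine of the angle to the mean direction
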
Example #7
    def _fn(state_parts, seed):
        """Adds a normal perturbation to the input state.

    Args:
      state_parts: A list of `Tensor`s of any shape and real dtype representing
        the state parts of the `current_state` of the Markov chain.
      seed: `int` or None. The random seed for this `Op`. If `None`, no seed is
        applied.
        Default value: `None`.

    Returns:
      perturbed_state_parts: A Python `list` of `Tensor`s. Has the same
        shape and type as `state_parts`.

    Raises:
      ValueError: if `scale` does not broadcast with `state_parts`.
    """
        with tf.compat.v1.name_scope(name,
                                     'random_walk_normal_fn',
                                     values=[state_parts, scale, seed]):
            scales = scale if mcmc_util.is_list_like(scale) else [scale]
            if len(scales) == 1:
                scales *= len(state_parts)
            if len(state_parts) != len(scales):
                raise ValueError('`scale` must broadcast with `state_parts`.')
            seed_stream = SeedStream(seed, salt='RandomWalkNormalFn')
            next_state_parts = [
                tf.random.normal(mean=state_part,
                                 stddev=scale_part,
                                 shape=tf.shape(input=state_part),
                                 dtype=state_part.dtype.base_dtype,
                                 seed=seed_stream())
                for scale_part, state_part in zip(scales, state_parts)
            ]

            return next_state_parts
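The `_fn` above is the closure produced by `tfp.mcmc.random_walk_normal_fn`; user code normally obtains it from that factory and hands it to a random walk kernel rather than calling it directly. A hedged usage sketch, assuming the public `tfp.mcmc` API:

import tensorflow as tf
import tensorflow_probability as tfp

new_state_fn = tfp.mcmc.random_walk_normal_fn(scale=0.5)
kernel = tfp.mcmc.RandomWalkMetropolis(
    # Standard normal target; the proposal adds N(0, 0.5^2) noise per step.
    target_log_prob_fn=lambda x: -0.5 * tf.reduce_sum(x**2, axis=-1),
    new_state_fn=new_state_fn)
samples = tfp.mcmc.sample_chain(
    num_results=100, current_state=tf.zeros([3]),
    kernel=kernel, trace_fn=None, seed=42)
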
Example #8
    def _sample_n(self, n, seed=None):
        # The sampling method comes from the fact that if:
        #   X ~ Normal(0, 1)
        #   Z ~ Chi2(df)
        #   Y = |X| / sqrt(Z / df)
        # then:
        #   Y ~ HalfStudentT(df).
        df = tf.convert_to_tensor(self.df)
        loc = tf.convert_to_tensor(self.loc)
        scale = tf.convert_to_tensor(self.scale)
        batch_shape = self._batch_shape_tensor(df=df, loc=loc, scale=scale)
        shape = tf.concat([[n], batch_shape], 0)
        seed = SeedStream(seed, "half_student_t")

        abs_normal_sample = tf.math.abs(
            tf.random.normal(shape, dtype=self.dtype, seed=seed()))
        df = df * tf.ones(batch_shape, dtype=self.dtype)
        gamma_sample = tf.random.gamma([n],
                                       0.5 * df,
                                       beta=0.5,
                                       dtype=self.dtype,
                                       seed=seed())
        samples = abs_normal_sample * tf.math.rsqrt(gamma_sample / df)
        return samples * scale + loc  # Abs(scale) not wanted.
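The comment block at the top of this sampler is a complete recipe on its own; a small NumPy sketch of the same construction may make the final shift-and-scale easier to follow (purely illustrative, not library code):

import numpy as np

rng = np.random.default_rng(17)
df, loc, scale, n = 3.0, 0.0, 2.0, 100000

# |Normal(0, 1)| / sqrt(Chi2(df) / df) is a standard half-Student-t; the TF code
# draws the Chi2 as Gamma(df / 2, rate=1 / 2), which is the same distribution.
abs_normal = np.abs(rng.standard_normal(n))
chi2 = rng.chisquare(df, n)
samples = loc + scale * abs_normal / np.sqrt(chi2 / df)
print(samples.mean())  # approx. scale * 2 * sqrt(3) / pi ~= 2.20 for df = 3
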
Example #9
    def __init__(self, inner_kernel, seed=None, name=None):
        """Instantiates this object.

    Args:
      inner_kernel: `TransitionKernel`-like object which has
        `collections.namedtuple` `kernel_results` and which contains a
        `target_log_prob` member and optionally a `log_acceptance_correction`
        member.
      seed: Python integer to seed the random number generator.
      name: Python `str` name prefixed to Ops created by this function.
        Default value: `None` (i.e., "mh_kernel").

    Returns:
      metropolis_hastings_kernel: Instance of `TransitionKernel` which wraps the
        input transition kernel with the Metropolis-Hastings algorithm.
    """
        if inner_kernel.is_calibrated:
            warnings.warn('Supplied `TransitionKernel` is already calibrated. '
                          'Composing `MetropolisHastings` `TransitionKernel` '
                          'may not be required.')
        self._seed_stream = SeedStream(seed, 'metropolis_hastings_one_step')
        self._parameters = dict(inner_kernel=inner_kernel,
                                seed=seed,
                                name=name)
Example #10
    def _sample_n(self, n, seed=None):
        # Generate samples using:
        #   mu + sigma * sgn(U - 0.5) * sqrt(X^2 + Y^2 + Z^2),
        #   where U ~ Uniform(0, 1) and X, Y, Z ~ Normal(0, 1).
        seed = SeedStream(seed, salt='DoublesidedMaxwell')

        loc = tf.convert_to_tensor(self.loc)
        scale = tf.convert_to_tensor(self.scale)
        shape = prefer_static.pad(self._batch_shape_tensor(loc=loc,
                                                           scale=scale),
                                  paddings=[[1, 0]],
                                  constant_values=n)

        # Generate one-sided Maxwell variables by using 3 Gaussian variates
        norm_rvs = tf.random.normal(shape=prefer_static.pad(shape,
                                                            paddings=[[0, 1]],
                                                            constant_values=3),
                                    dtype=self.dtype,
                                    seed=seed())
        maxwell_rvs = tf.norm(norm_rvs, axis=-1)

        # Generate random signs for the symmetric variates.
        random_sign = tfp_math.random_rademacher(shape, seed=seed())
        sampled = random_sign * maxwell_rvs * scale + loc
        return sampled
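As with the half-Student-t example, the sampling comment here is a self-contained recipe: the Euclidean norm of three standard normals is a one-sided Maxwell variate, and a random sign makes it double-sided. A small NumPy sketch under that reading (illustrative only):

import numpy as np

rng = np.random.default_rng(0)
loc, scale, n = 0.0, 1.0, 100000

maxwell = np.linalg.norm(rng.standard_normal((n, 3)), axis=-1)
sign = np.where(rng.uniform(size=n) < 0.5, -1.0, 1.0)  # Rademacher sign
samples = loc + scale * sign * maxwell
print(samples.mean(), samples.std())  # mean ~ 0, std ~ sqrt(3) for scale = 1
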
Example #11
  def _sample_n(self, n, seed=None):
    low = tf.convert_to_tensor(self.low)
    high = tf.convert_to_tensor(self.high)
    peak = tf.convert_to_tensor(self.peak)

    stream = SeedStream(seed, salt='triangular')
    shape = tf.concat([[n], self._batch_shape_tensor(
        low=low, high=high, peak=peak)], axis=0)
    samples = tf.random.uniform(shape=shape, dtype=self.dtype, seed=stream())
    # We use Inverse CDF sampling here. Because the CDF is a quadratic function,
    # we must use sqrts here.
    interval_length = high - low
    return tf.where(
        # Note the CDF on the left side of the peak is
        # (x - low) ** 2 / ((high - low) * (peak - low)).
        # If we plug in peak for x, we get that the CDF at the peak
        # is (peak - low) / (high - low). Because of this we decide
        # which part of the piecewise CDF we should use based on the cdf samples
        # we drew.
        samples < (peak - low) / interval_length,
        # Inverse of (x - low) ** 2 / ((high - low) * (peak - low)).
        low + tf.sqrt(samples * interval_length * (peak - low)),
        # Inverse of 1 - (high - x) ** 2 / ((high - low) * (high - peak))
        high - tf.sqrt((1. - samples) * interval_length * (high - peak)))
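The piecewise CDF described in the comments and the inverse applied above can be restated as a small scalar NumPy function (illustrative only; the TF code above does the same thing vectorized with tf.where):

import numpy as np

def triangular_inverse_cdf(u, low, peak, high):
    # CDF left of the peak:  (x - low)^2 / ((high - low) * (peak - low));
    # CDF right of the peak: 1 - (high - x)^2 / ((high - low) * (high - peak)).
    interval = high - low
    if u < (peak - low) / interval:
        return low + np.sqrt(u * interval * (peak - low))
    return high - np.sqrt((1.0 - u) * interval * (high - peak))

print(triangular_inverse_cdf(0.5, low=0.0, peak=1.0, high=4.0))  # the median
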
Example #12
  def __init__(
      self,
      posterior,
      prior,
      penalty_weight=None,
      posterior_penalty_fn=kl_divergence_monte_carlo,
      posterior_value_fn=tfd.Distribution.sample,
      seed=None,
      dtype=tf.float32,
      name=None):
    """Base class for variational layers.

    # mean ==> penalty_weight =          1 / train_size
    # sum  ==> penalty_weight = batch_size / train_size

    Args:
      posterior: ...
      prior: ...
      penalty_weight: ...
      posterior_penalty_fn: ...
      posterior_value_fn: ...
      seed: ...
      dtype: ...
      name: Python `str` prepended to ops created by this object.
        Default value: `None` (i.e., `type(self).__name__`).
    """
    super(VariationalLayer, self).__init__(name=name)
    self._posterior = posterior
    self._prior = prior
    self._penalty_weight = penalty_weight
    self._posterior_penalty_fn = posterior_penalty_fn
    self._posterior_value_fn = posterior_value_fn
    self._seed = SeedStream(seed, salt=self.name)
    self._dtype = dtype
    tf.nest.assert_same_structure(prior.dtype, posterior.dtype,
                                  check_types=False)
Example #13
def sample_sequential_monte_carlo(
        prior_log_prob_fn,
        likelihood_log_prob_fn,
        current_state,
        max_num_steps=25,
        max_stage=100,
        make_kernel_fn=make_rwmh_kernel_fn,
        tuning_fn=simple_heuristic_tuning,
        make_tempered_target_log_prob_fn=default_make_tempered_target_log_prob_fn,
        ess_threshold_ratio=0.5,
        parallel_iterations=10,
        seed=None,
        name=None):
    """Runs Sequential Monte Carlo to sample from the posterior distribution.

  This function uses an MCMC transition operator (e.g., Hamiltonian Monte Carlo)
  to sample from a series of distributions that slowly interpolates between
  an initial 'prior' distribution:

    `exp(prior_log_prob_fn(x))`

  and the target 'posterior' distribution:

    `exp(prior_log_prob_fn(x) + target_log_prob_fn(x))`,

  by mutating a collection of MC samples (i.e., particles). The approach is also
  known as a particle filter in some literature. The current implementation is
  largely based on Del Moral et al. [1], which adapts the tempering sequence
  (based on the effective sample size) and the scaling of the mutation kernel
  (based on the sample covariance of the particles) at each stage.

  Args:
    prior_log_prob_fn: Python callable that returns the log density of the
      prior distribution.
    likelihood_log_prob_fn: Python callable which takes an argument like
      `current_state` (or `*current_state` if it's a list) and returns its
      (possibly unnormalized) log-density under the likelihood distribution.
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s). The first `r` dimensions index
      independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`.
    max_num_steps: The maximum number of kernel transition steps in one mutation
      of the MC samples. Note that the actual number of steps in one mutation is
      tuned during sampling and is likely lower than `max_num_steps`.
    max_stage: Integer maximum number of stages for increasing the temperature
      from 0 to 1.
    make_kernel_fn: Python `callable` which returns a `TransitionKernel`-like
      object. Must take one argument representing the `TransitionKernel`'s
      `target_log_prob_fn`. The `target_log_prob_fn` argument represents the
      `TransitionKernel`'s target log distribution.  Note:
      `sample_sequential_monte_carlo` creates a new `target_log_prob_fn`
      which is an interpolation between the supplied `prior_log_prob_fn` and
      `likelihood_log_prob_fn`; it is this interpolated function which is used as
      an argument to `make_kernel_fn`.
    tuning_fn: Python `callable` which takes the number of steps, the log
      scaling, and the log acceptance ratio from the last mutation and outputs
      the number of steps and log scaling for the next mutation.
    make_tempered_target_log_prob_fn: Python `callable` that takes the
      `prior_log_prob_fn`, `likelihood_log_prob_fn`, and `inverse_temperatures`
      and creates a `target_log_prob_fn` `callable` that is passed to
      `make_kernel_fn`.
    ess_threshold_ratio: Target ratio for effective sample size.
    parallel_iterations: The number of iterations allowed to run in parallel.
        It must be a positive integer. See `tf.while_loop` for more details.
    seed: Python integer or TFP seedstream to seed the random number generator.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'sample_sequential_monte_carlo').

  Returns:
    n_stage: Number of mutation stages SMC ran.
    final_state: `Tensor` or Python `list` of `Tensor`s representing the
      final state(s) of the Markov chain(s). These are the posterior
      samples.
    final_kernel_results: `collections.namedtuple` of internal calculations used
      to advance the chain.

  #### References

  [1] Del Moral, Pierre, Arnaud Doucet, and Ajay Jasra. An adaptive sequential
      Monte Carlo method for approximate Bayesian computation.
      _Statistics and Computing_, 22(5):1009-1020, 2012.

  """

    with tf.name_scope(name or 'sample_sequential_monte_carlo'):
        seed_stream = SeedStream(seed, salt='smc_seed')

        unwrap_state_list = not tf.nest.is_nested(current_state)
        if unwrap_state_list:
            current_state = [current_state]
        current_state = [
            tf.convert_to_tensor(s, dtype_hint=tf.float32)
            for s in current_state
        ]

        # Initial preprocessing at Stage 0
        likelihood_log_prob = likelihood_log_prob_fn(*current_state)

        likelihood_rank = ps.rank(likelihood_log_prob)
        dimension = ps.reduce_sum([
            ps.reduce_prod(ps.shape(x)[likelihood_rank:])
            for x in current_state
        ])

        # We infer the particle shapes from the resulting likelihood:
        # [num_particles, b1, ..., bN]
        particle_shape = ps.shape(likelihood_log_prob)
        num_particles, batch_shape = particle_shape[0], particle_shape[1:]
        effective_sample_size_threshold = tf.cast(
            num_particles * ess_threshold_ratio, tf.int32)

        # TODO(b/152412213): Revisit this default parameter.
        # Default to the optimal scaling of a random walk kernel for a
        # d-dimensional normally distributed target: 2.38 ** 2 / d.
        # For more detail see:
        # Roberts GO, Gelman A, Gilks WR. Weak convergence and optimal scaling of
        # random walk Metropolis algorithms. _The annals of applied probability_.
        # 1997;7(1):110-20.
        scale_start = (tf.constant(2.38**2, dtype=likelihood_log_prob.dtype) /
                       tf.constant(dimension, dtype=likelihood_log_prob.dtype))

        inverse_temperature = tf.zeros(batch_shape,
                                       dtype=likelihood_log_prob.dtype)
        scalings = ps.ones_like(likelihood_log_prob) * ps.minimum(
            scale_start, 1.)
        kernel = make_kernel_fn(make_tempered_target_log_prob_fn(
            prior_log_prob_fn, likelihood_log_prob_fn, inverse_temperature),
                                current_state,
                                scalings,
                                seed=seed_stream)
        pkr = kernel.bootstrap_results(current_state)
        _, kernel_target_log_prob = gather_mh_like_result(pkr)

        particle_info = ParticleInfo(
            log_accept_prob=ps.zeros_like(likelihood_log_prob),
            log_scalings=tf.math.log(scalings),
            tempered_log_prob=kernel_target_log_prob,
            likelihood_log_prob=likelihood_log_prob,
        )

        current_pkr = SMCResults(
            num_steps=tf.convert_to_tensor(max_num_steps,
                                           dtype=tf.int32,
                                           name='num_steps'),
            inverse_temperature=inverse_temperature,
            log_marginal_likelihood=tf.zeros_like(inverse_temperature),
            particle_info=particle_info)

        def update_weights_temperature(inverse_temperature,
                                       likelihood_log_prob):
            """Calculate the next inverse temperature and update weights."""
            likelihood_diff = likelihood_log_prob - tf.reduce_max(
                likelihood_log_prob, axis=0)

            def _body_fn(new_beta, upper_beta, lower_beta, eff_size,
                         log_weights):
                """One iteration of the temperature and weight update."""
                new_beta = (lower_beta + upper_beta) / 2.0
                log_weights = (new_beta -
                               inverse_temperature) * likelihood_diff
                log_weights_norm = tf.math.log_softmax(log_weights, axis=0)
                eff_size = tf.cast(
                    tf.exp(-tf.math.reduce_logsumexp(2 * log_weights_norm,
                                                     axis=0)), tf.int32)
                upper_beta = tf.where(
                    eff_size < effective_sample_size_threshold, new_beta,
                    upper_beta)
                lower_beta = tf.where(
                    eff_size < effective_sample_size_threshold, lower_beta,
                    new_beta)
                return new_beta, upper_beta, lower_beta, eff_size, log_weights

            def _cond_fn(new_beta, upper_beta, lower_beta, eff_size, *_):  # pylint: disable=unused-argument
                # TODO(junpenglao): revisit threshold below to be dtype specific.
                threshold = 1e-6
                return (tf.math.reduce_any(upper_beta - lower_beta > threshold)
                        & tf.math.reduce_any(
                            eff_size != effective_sample_size_threshold))

            (new_beta, upper_beta, lower_beta, eff_size,
             log_weights) = tf.while_loop(  # pylint: disable=unused-variable
                 cond=_cond_fn,
                 body=_body_fn,
                 loop_vars=(tf.zeros_like(inverse_temperature),
                            tf.fill(ps.shape(inverse_temperature),
                                    tf.constant(2, inverse_temperature.dtype)),
                            inverse_temperature,
                            tf.zeros_like(inverse_temperature, dtype=tf.int32),
                            tf.zeros_like(likelihood_diff)),
                 parallel_iterations=parallel_iterations)

            log_weights = tf.where(new_beta < 1., log_weights,
                                   (1. - inverse_temperature) *
                                   likelihood_diff)
            marginal_loglike_ = reduce_logmeanexp(
                (new_beta - inverse_temperature) * likelihood_log_prob, axis=0)
            new_inverse_temperature = tf.clip_by_value(new_beta, 0., 1.)

            return marginal_loglike_, new_inverse_temperature, log_weights

        def mutate(current_state, log_scalings, num_steps,
                   inverse_temperature):
            """Mutate the state using a Transition kernel."""
            with tf.name_scope('mutate_states'):
                scalings = tf.exp(log_scalings)
                kernel = make_kernel_fn(make_tempered_target_log_prob_fn(
                    prior_log_prob_fn, likelihood_log_prob_fn,
                    inverse_temperature),
                                        current_state,
                                        scalings,
                                        seed=seed_stream)
                pkr = kernel.bootstrap_results(current_state)
                kernel_log_accept_ratio, _ = gather_mh_like_result(pkr)

                def mutate_onestep(i, state, pkr, log_accept_prob_sum):
                    next_state, next_kernel_results = kernel.one_step(
                        state, pkr)
                    kernel_log_accept_ratio, _ = gather_mh_like_result(pkr)
                    log_accept_prob = tf.minimum(kernel_log_accept_ratio, 0.)
                    log_accept_prob_sum = log_add_exp(log_accept_prob_sum,
                                                      log_accept_prob)
                    return i + 1, next_state, next_kernel_results, log_accept_prob_sum

                (
                    _, next_state, next_kernel_results, log_accept_prob_sum
                ) = tf.while_loop(
                    cond=lambda i, *args: i < num_steps,
                    body=mutate_onestep,
                    loop_vars=(
                        tf.zeros([], dtype=tf.int32),
                        current_state,
                        pkr,
                        # we accumulate the acceptance probability in log space.
                        tf.fill(
                            ps.shape(kernel_log_accept_ratio),
                            tf.constant(-np.inf,
                                        kernel_log_accept_ratio.dtype))),
                    parallel_iterations=parallel_iterations)
                _, kernel_target_log_prob = gather_mh_like_result(
                    next_kernel_results)
                avg_log_accept_prob_per_particle = log_accept_prob_sum - tf.math.log(
                    tf.cast(num_steps + 1, log_accept_prob_sum.dtype))
                return (next_state, avg_log_accept_prob_per_particle,
                        kernel_target_log_prob)

        # One SMC step.
        def smc_body_fn(stage, state, smc_kernel_result):
            """Run one stage of SMC with constant temperature."""
            (new_marginal, new_inv_temperature,
             log_weights) = update_weights_temperature(
                 smc_kernel_result.inverse_temperature,
                 smc_kernel_result.particle_info.likelihood_log_prob)
            # TODO(b/152412213) Use a tf.scan to better collect debug info.
            if PRINT_DEBUG:
                tf.print(
                    'Stage:', stage, 'Beta:', new_inv_temperature, 'n_steps:',
                    smc_kernel_result.num_steps, 'accept:',
                    tf.exp(
                        reduce_logmeanexp(
                            smc_kernel_result.particle_info.log_accept_prob,
                            axis=0)), 'scaling:',
                    tf.exp(
                        reduce_logmeanexp(
                            smc_kernel_result.particle_info.log_scalings,
                            axis=0)))
            (resampled_state,
             resampled_particle_info), _ = resample_particle_and_info(
                 (state, smc_kernel_result.particle_info),
                 log_weights,
                 seed=seed_stream)
            next_num_steps, next_log_scalings = tuning_fn(
                smc_kernel_result.num_steps,
                resampled_particle_info.log_scalings,
                resampled_particle_info.log_accept_prob)
            # Skip tuning at stage 0.
            next_num_steps = tf.where(stage == 0, smc_kernel_result.num_steps,
                                      next_num_steps)
            next_log_scalings = tf.where(stage == 0,
                                         resampled_particle_info.log_scalings,
                                         next_log_scalings)
            next_num_steps = tf.clip_by_value(next_num_steps, 2, max_num_steps)

            next_state, log_accept_prob, tempered_log_prob = mutate(
                resampled_state, next_log_scalings, next_num_steps,
                new_inv_temperature)
            next_pkr = SMCResults(
                num_steps=next_num_steps,
                inverse_temperature=new_inv_temperature,
                log_marginal_likelihood=(
                    new_marginal + smc_kernel_result.log_marginal_likelihood),
                particle_info=ParticleInfo(
                    log_accept_prob=log_accept_prob,
                    log_scalings=next_log_scalings,
                    tempered_log_prob=tempered_log_prob,
                    likelihood_log_prob=likelihood_log_prob_fn(*next_state),
                ))
            return stage + 1, next_state, next_pkr

        (n_stage, final_state, final_kernel_results) = tf.while_loop(
            cond=lambda i, state, pkr: (  # pylint: disable=g-long-lambda
                (i < max_stage) & tf.reduce_any(pkr.inverse_temperature < 1.)),
            body=smc_body_fn,
            loop_vars=(tf.zeros([],
                                dtype=tf.int32), current_state, current_pkr),
            parallel_iterations=parallel_iterations)
        if unwrap_state_list:
            final_state = final_state[0]
        return n_stage, final_state, final_kernel_results
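A hedged end-to-end sketch of calling this sampler, assuming it is exported as `tfp.experimental.mcmc.sample_sequential_monte_carlo` (the module path is an assumption; the argument names follow the docstring above):

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

prior = tfd.Normal(0., 10.)
likelihood = tfd.Normal(3., 1.)

# 1000 particles drawn from the prior; SMC tempers from exp(prior_log_prob)
# toward exp(prior_log_prob + likelihood_log_prob), resampling and mutating
# the particles at each stage.
init_state = prior.sample(1000, seed=1)
n_stage, final_state, _ = tfp.experimental.mcmc.sample_sequential_monte_carlo(
    prior.log_prob, likelihood.log_prob, init_state, seed=17)
print(n_stage, tf.reduce_mean(final_state))  # posterior mean ~= 3 / 1.01 ~= 2.97
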
Example #14
  def __init__(self,
               target_log_prob_fn,
               step_size,
               volatility_fn=None,
               parallel_iterations=10,
               compute_acceptance=True,
               seed=None,
               name=None):
    """Initializes Langevin diffusion transition kernel.

    Args:
      target_log_prob_fn: Python callable which takes an argument like
        `current_state` (or `*current_state` if it's a list) and returns its
        (possibly unnormalized) log-density under the target distribution.
      step_size: `Tensor` or Python `list` of `Tensor`s representing the step
        size for the leapfrog integrator. Must broadcast with the shape of
        `current_state`. Larger step sizes lead to faster progress, but
        too-large step sizes make rejection exponentially more likely. When
        possible, it's often helpful to match per-variable step sizes to the
        standard deviations of the target distribution in each variable.
      volatility_fn: Python callable which takes an argument like
        `current_state` (or `*current_state` if it's a list) and returns
        volatility value at `current_state`. Should return a `Tensor` or Python
        `list` of `Tensor`s that must broadcast with the shape of
        `current_state`. Defaults to the identity function.
      parallel_iterations: the number of coordinates for which the gradients of
        the volatility matrix `volatility_fn` can be computed in parallel.
      compute_acceptance: Python 'bool' indicating whether to compute the
        Metropolis log-acceptance ratio used to construct
        `MetropolisAdjustedLangevinAlgorithm` kernel.
      seed: Python integer to seed the random number generator.
        Default value: `None` (i.e., no seed).
      name: Python `str` name prefixed to Ops created by this function.
        Default value: `None` (i.e., 'mala_kernel').

    Returns:
      next_state: Tensor or Python list of `Tensor`s representing the state(s)
        of the Markov chain(s) at each result step. Has same shape as
        `current_state`.
      kernel_results: `collections.namedtuple` of internal calculations used to
        advance the chain.

    Raises:
      ValueError: if there isn't one `step_size` or a list with same length as
        `current_state`.
      TypeError: if `volatility_fn` is not callable.
    """
    self._seed_stream = SeedStream(seed, salt='UncalibratedLangevin')
    # Default value of `volatility_fn` is the identity function.
    if volatility_fn is None:
      volatility_fn = lambda *args: 1.
    if not callable(volatility_fn):
      raise TypeError('`volatility_fn` must be callable (saw: {})'.format(
          type(volatility_fn)))
    self._parameters = dict(
        target_log_prob_fn=target_log_prob_fn,
        step_size=step_size,
        volatility_fn=volatility_fn,
        compute_acceptance=tf.convert_to_tensor(value=compute_acceptance),
        seed=seed,
        parallel_iterations=parallel_iterations,
        name=name)
Example #15
    def _sample_n(self, n, seed=None):
        stream = SeedStream(seed, salt="VectorDiffeomixture")
        x = self.distribution.sample(sample_shape=concat_vectors(
            [n], self.batch_shape_tensor(), self.event_shape_tensor()),
                                     seed=stream())  # shape: [n, B, e]
        x = [aff.forward(x) for aff in self.endpoint_affine]

        # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[], in
        # which case ids is a [n]-shaped vector.
        batch_size = tensorshape_util.num_elements(self.batch_shape)
        if batch_size is None:
            batch_size = tf.reduce_prod(self.batch_shape_tensor())
        mix_batch_size = tensorshape_util.num_elements(
            self.mixture_distribution.batch_shape)
        if mix_batch_size is None:
            mix_batch_size = tf.reduce_prod(
                self.mixture_distribution.batch_shape_tensor())
        ids = self.mixture_distribution.sample(sample_shape=concat_vectors(
            [n],
            distribution_util.pick_vector(self.is_scalar_batch(), np.int32([]),
                                          [batch_size // mix_batch_size])),
                                               seed=stream())
        # We need to flatten batch dims in case mixture_distribution has its own
        # batch dims.
        ids = tf.reshape(ids,
                         shape=concat_vectors([n],
                                              distribution_util.pick_vector(
                                                  self.is_scalar_batch(),
                                                  np.int32([]),
                                                  np.int32([-1]))))

        # Stride `components * quadrature_size` for `batch_size` number of times.
        stride = tensorshape_util.num_elements(
            tensorshape_util.with_rank(self.grid.shape[-2:], rank=2))
        if stride is None:
            stride = tf.reduce_prod(tf.shape(self.grid)[-2:])
        offset = tf.range(start=0,
                          limit=batch_size * stride,
                          delta=stride,
                          dtype=ids.dtype)

        weight = tf.gather(tf.reshape(self.grid, shape=[-1]), ids + offset)
        # At this point, weight flattened all batch dims into one.
        # We also need to append a singleton to broadcast with event dims.
        if tensorshape_util.is_fully_defined(self.batch_shape):
            new_shape = [-1] + tensorshape_util.as_list(self.batch_shape) + [1]
        else:
            new_shape = tf.concat(([-1], self.batch_shape_tensor(), [1]),
                                  axis=0)
        weight = tf.reshape(weight, shape=new_shape)

        if len(x) != 2:
            # We actually should have already triggered this exception. However as a
            # policy we're putting this exception wherever we exploit the bimixture
            # assumption.
            raise NotImplementedError(
                "Currently only bimixtures are supported; "
                "len(scale)={} is not 2.".format(len(x)))

        # Alternatively:
        # x = weight * x[0] + (1. - weight) * x[1]
        x = weight * (x[0] - x[1]) + x[1]

        return x
Example #16
  def __init__(self,
               target_log_prob_fn,
               step_size,
               volatility_fn=None,
               seed=None,
               parallel_iterations=10,
               name=None):
    """Initializes MALA transition kernel.

    Args:
      target_log_prob_fn: Python callable which takes an argument like
        `current_state` (or `*current_state` if it's a list) and returns its
        (possibly unnormalized) log-density under the target distribution.
      step_size: `Tensor` or Python `list` of `Tensor`s representing the step
        size for the leapfrog integrator. Must broadcast with the shape of
        `current_state`. Larger step sizes lead to faster progress, but
        too-large step sizes make rejection exponentially more likely. When
        possible, it's often helpful to match per-variable step sizes to the
        standard deviations of the target distribution in each variable.
      volatility_fn: Python callable which takes an argument like
        `current_state` (or `*current_state` if it's a list) and returns
        volatility value at `current_state`. Should return a `Tensor` or Python
        `list` of `Tensor`s that must broadcast with the shape of
        `current_state`. Defaults to the identity function.
      seed: Python integer to seed the random number generator. Deprecated, pass
        seed to `tfp.mcmc.sample_chain`.
      parallel_iterations: the number of coordinates for which the gradients of
        the volatility matrix `volatility_fn` can be computed in parallel.
        Default value: `10`.
      name: Python `str` name prefixed to Ops created by this function.
        Default value: `None` (i.e., 'mala_kernel').

    Returns:
      next_state: Tensor or Python list of `Tensor`s representing the state(s)
        of the Markov chain(s) at each result step. Has same shape as
        `current_state`.
      kernel_results: `collections.namedtuple` of internal calculations used to
        advance the chain.

    Raises:
      ValueError: if there isn't one `step_size` or a list with same length as
        `current_state`.
      TypeError: if `volatility_fn` is not callable.
    """
    seed_stream = SeedStream(seed, salt='langevin')
    mh_kwargs = {} if seed is None else dict(seed=seed_stream())
    uncal_kwargs = {} if seed is None else dict(seed=seed_stream())
    impl = metropolis_hastings.MetropolisHastings(
        inner_kernel=UncalibratedLangevin(
            target_log_prob_fn=target_log_prob_fn,
            step_size=step_size,
            volatility_fn=volatility_fn,
            parallel_iterations=parallel_iterations,
            name=name,
            **uncal_kwargs),
        **mh_kwargs)

    self._impl = impl
    parameters = impl.inner_kernel.parameters.copy()
    # Remove `compute_acceptance` parameter as this is not a MALA kernel
    # `__init__` parameter.
    del parameters['compute_acceptance']
    self._parameters = parameters
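A hedged usage sketch of the calibrated kernel this constructor builds (Metropolis-Hastings wrapping `UncalibratedLangevin`), assuming the public `tfp.mcmc` API:

import tensorflow as tf
import tensorflow_probability as tfp

kernel = tfp.mcmc.MetropolisAdjustedLangevinAlgorithm(
    # Standard normal target; gradients of this log-density drive the proposal.
    target_log_prob_fn=lambda x: -0.5 * tf.reduce_sum(x**2, axis=-1),
    step_size=0.3)
samples = tfp.mcmc.sample_chain(
    num_results=200, current_state=tf.zeros([3]),
    kernel=kernel, trace_fn=None, seed=42)
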
Example #17
  def _flat_sample_distributions(self, sample_shape=(), seed=None, value=None):
    # This function additionally depends on:
    #   self._dist_fn_wrapped
    #   self._dist_fn_args
    #   self._always_use_specified_sample_shape
    num_dists = len(self._dist_fn_wrapped)
    if seed is not None and samplers.is_stateful_seed(seed):
      seed_stream = SeedStream(seed, salt='JointDistributionSequential')
    else:
      seed_stream = None
    if seed is not None:
      seeds = samplers.split_seed(seed, n=num_dists,
                                  salt='JointDistributionSequential')
    else:
      seeds = [None] * num_dists
    ds = []
    xs = [None] * num_dists if value is None else list(value)
    if len(xs) != num_dists:
      raise ValueError('Number of `xs`s must match number of '
                       'distributions.')
    for i, (dist_fn, args) in enumerate(zip(self._dist_fn_wrapped,
                                            self._dist_fn_args)):
      ds.append(dist_fn(*xs[:i]))  # Chain rule of probability.

      # Ensure reproducibility even when xs are (partially) set.
      stateful_seed = None if seed_stream is None else seed_stream()

      if xs[i] is None:
        # TODO(b/129364796): We should ignore args prefixed with `_`; this
        # would mean we more often identify when to use `sample_shape=()`
        # rather than `sample_shape=sample_shape`.
        try:  # TODO(b/147874898): Eliminate the stateful fallback 20 Dec 2020.
          xs[i] = ds[-1].sample(
              () if args and not self._always_use_specified_sample_shape
              else sample_shape, seed=seeds[i])
        except TypeError as e:
          if ('Expected int for argument' not in str(e) and
              TENSOR_SEED_MSG_PREFIX not in str(e)) or stateful_seed is None:
            raise

          if not getattr(self, '_resolving_names', False):  # avoid recursion
            self._resolving_names = True
            resolved_names = self._flat_resolve_names()
            self._resolving_names = False
            msg = (
                'Falling back to stateful sampling for distribution #{i} '
                '(0-based) of type `{dist_cls}` with component name '
                '"{component_name}" and `dist.name` "{dist_name}". Please '
                'update to use `tf.random.stateless_*` RNGs. This fallback may '
                'be removed after 20-Dec-2020. ({exc})')
            warnings.warn(msg.format(
                i=i,
                dist_name=ds[-1].name,
                component_name=resolved_names[i],
                dist_cls=type(ds[-1]),
                exc=str(e)))
          xs[i] = ds[-1].sample(
              () if args and not self._always_use_specified_sample_shape
              else sample_shape, seed=stateful_seed)

      else:
        # This signature does not allow kwarg names. Applies
        # `convert_to_tensor` on the next value.
        xs[i] = nest.map_structure_up_to(
            ds[-1].dtype,  # shallow_tree
            lambda x, dtype: tf.convert_to_tensor(x, dtype_hint=dtype),  # func
            xs[i],  # x
            ds[-1].dtype)  # dtype
    # Note: we could also resolve distributions up to the first non-`None` in
    # `self._model_flatten(value)`, however we omit this feature for simplicity,
    # speed, and because it has not yet been requested.
    return ds, xs
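The branch structure above exists because TFP supports both stateful integer seeds (driving a `SeedStream`) and stateless `Tensor` seeds that are split per distribution. A short sketch of the stateless path, assuming the public `tfp.random.split_seed` wrapper around the internal `samplers` module:

import tensorflow_probability as tfp

# One user-supplied stateless seed is split into per-distribution seeds, which
# plays the role the SeedStream calls play on the stateful path above.
seeds = tfp.random.split_seed([0, 42], n=3, salt='JointDistributionSequential')
for s in seeds:
    print(s)  # three independent shape-[2] stateless seeds
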
Example #18
    def _flat_sample_distributions(self,
                                   sample_shape=(),
                                   seed=None,
                                   value=None):
        """Executes `model`, creating both samples and distributions."""
        ds = []
        values_out = []
        if samplers.is_stateful_seed(seed):
            seed_stream = SeedStream(seed, salt='JointDistributionCoroutine')
            if not self._stateful_to_stateless:
                seed = None
        else:
            seed_stream = None  # We got a stateless seed for seed=.

        # TODO(b/166658748): Make _stateful_to_stateless always True (eliminate it).
        if self._stateful_to_stateless and (seed is not None or not JAX_MODE):
            seed = samplers.sanitize_seed(seed,
                                          salt='JointDistributionCoroutine')
        gen = self._model_coroutine()
        index = 0
        d = next(gen)
        if self._require_root and not isinstance(d, self.Root):
            raise ValueError('First distribution yielded by coroutine must '
                             'be wrapped in `Root`.')
        try:
            while True:
                actual_distribution = d.distribution if isinstance(
                    d, self.Root) else d
                ds.append(actual_distribution)
                # Ensure reproducibility even when xs are (partially) set. Always split.
                stateful_sample_seed = None if seed_stream is None else seed_stream(
                )
                if seed is None:
                    stateless_sample_seed = None
                else:
                    stateless_sample_seed, seed = samplers.split_seed(seed)

                if (value is not None and len(value) > index
                        and value[index] is not None):

                    def convert_tree_to_tensor(x, dtype_hint):
                        return tf.convert_to_tensor(x, dtype_hint=dtype_hint)

                    # This signature does not allow kwarg names. Applies
                    # `convert_to_tensor` on the next value.
                    next_value = nest.map_structure_up_to(
                        ds[-1].dtype,  # shallow_tree
                        convert_tree_to_tensor,  # func
                        value[index],  # x
                        ds[-1].dtype)  # dtype_hint
                else:
                    try:
                        next_value = actual_distribution.sample(
                            sample_shape=sample_shape if isinstance(
                                d, self.Root) else (),
                            seed=(stateful_sample_seed
                                  if stateless_sample_seed is None else
                                  stateless_sample_seed))
                    except TypeError as e:
                        if ('Expected int for argument' not in str(e)
                                and TENSOR_SEED_MSG_PREFIX not in str(e)) or (
                                    stateful_sample_seed is None):
                            raise
                        msg = (
                            'Falling back to stateful sampling for distribution #{index} '
                            '(0-based) of type `{dist_cls}` with component name '
                            '{component_name} and `dist.name` "{dist_name}". Please '
                            'update to use `tf.random.stateless_*` RNGs. This fallback may '
                            'be removed after 20-Dec-2020. ({exc})')
                        component_name = (joint_distribution_lib.
                                          get_explicit_name_for_component(
                                              ds[-1]))
                        if component_name is None:
                            component_name = '[None specified]'
                        else:
                            component_name = '"{}"'.format(component_name)
                        warnings.warn(
                            msg.format(index=index,
                                       component_name=component_name,
                                       dist_name=ds[-1].name,
                                       dist_cls=type(ds[-1]),
                                       exc=str(e)))
                        next_value = actual_distribution.sample(
                            sample_shape=sample_shape if isinstance(
                                d, self.Root) else (),
                            seed=stateful_sample_seed)

                if self._validate_args:
                    with tf.control_dependencies(
                            self._assert_compatible_shape(
                                index, sample_shape, next_value)):
                        values_out.append(
                            tf.nest.map_structure(tf.identity, next_value))
                else:
                    values_out.append(next_value)

                index += 1
                d = gen.send(next_value)
        except StopIteration:
            pass
        return ds, values_out
Example #19
def random_von_mises(shape, concentration, dtype=tf.float32, seed=None):
    """Samples from the standardized von Mises distribution.

  The distribution is vonMises(loc=0, concentration=concentration), so the mean
  is zero.
  The location can then be changed by adding it to the samples.

  The sampling algorithm is rejection sampling with wrapped Cauchy proposal [1].
  The samples are pathwise differentiable using the approach of [2].

  Arguments:
    shape: The output sample shape.
    concentration: The concentration parameter of the von Mises distribution.
    dtype: The data type of concentration and the outputs.
    seed: (optional) The random seed.

  Returns:
    Differentiable samples of standardized von Mises.

  References:
    [1] Luc Devroye "Non-Uniform Random Variate Generation", Springer-Verlag,
    1986; Chapter 9, p. 473-476.
    http://www.nrbook.com/devroye/Devroye_files/chapter_nine.pdf
    + corrections http://www.nrbook.com/devroye/Devroye_files/errors.pdf
    [2] Michael Figurnov, Shakir Mohamed, Andriy Mnih. "Implicit
    Reparameterization Gradients", 2018.
  """
    seed = SeedStream(seed, salt='von_mises')
    concentration = tf.convert_to_tensor(concentration,
                                         dtype=dtype,
                                         name='concentration')

    @tf.custom_gradient
    def rejection_sample_with_gradient(concentration):
        """Performs rejection sampling for standardized von Mises.

    A nested function is required because @tf.custom_gradient does not handle
    non-tensor inputs such as dtype. Instead, they are captured by the outer
    scope.

    Arguments:
      concentration: The concentration parameter of the distribution.

    Returns:
      Differentiable samples of standardized von Mises.
    """
        r = 1. + tf.sqrt(1. + 4. * concentration**2)
        rho = (r - tf.sqrt(2. * r)) / (2. * concentration)

        s_exact = (1. + rho**2) / (2. * rho)

        # For low concentration, s becomes numerically unstable.
        # To fix that, we use an approximation. Here is the derivation.
        # First-order Taylor expansion at conc = 0 gives
        #   sqrt(1 + 4 concentration^2) ~= 1 + (2 concentration)^2 / 2.
        # Therefore, r ~= 2 + 2 concentration. By plugging this into rho, we have
        #   rho ~= conc + 1 / conc - sqrt(1 + 1 / concentration^2).
        # Let's expand the last term at concentration=0 up to the linear term:
        #   sqrt(1 + 1 / concentration^2) ~= 1 / concentration + concentration / 2
        # Thus, rho ~= concentration / 2. Finally,
        #   s = 1 / (2 rho) + rho / 2 ~= 1 / concentration + concentration / 4.
        # Since concentration is small, we drop the second term and simply use
        #   s ~= 1 / concentration.
        s_approximate = 1. / concentration

        # To compute the cutoff, we compute s_exact using mpmath with 30 decimal
        # digits precision and compare that to the s_exact and s_approximate
        # computed with dtype. Then, the cutoff is the largest concentration for
        # which abs(s_exact - s_exact_mpmath) > abs(s_approximate - s_exact_mpmath).
        s_concentration_cutoff_dict = {
            tf.float16: 1.8e-1,
            tf.float32: 2e-2,
            tf.float64: 1.2e-4,
        }
        s_concentration_cutoff = s_concentration_cutoff_dict[dtype]

        s = tf.where(concentration > s_concentration_cutoff, s_exact,
                     s_approximate)

        def loop_body(done, u, w):
            """Resample the non-accepted points."""
            # We resample u each time completely. Only its sign is used outside the
            # loop, which is random.
            u = tf.random.uniform(shape,
                                  minval=-1.,
                                  maxval=1.,
                                  dtype=dtype,
                                  seed=seed())
            z = tf.cos(np.pi * u)
            # Update the non-accepted points.
            w = tf.where(done, w, (1. + s * z) / (s + z))
            y = concentration * (s - w)

            v = tf.random.uniform(shape,
                                  minval=0.,
                                  maxval=1.,
                                  dtype=dtype,
                                  seed=seed())
            accept = (y * (2. - y) >= v) | (tf.math.log(y / v) + 1. >= y)

            return done | accept, u, w

        _, u, w = tf.while_loop(
            cond=lambda done, *_: ~tf.reduce_all(done),
            body=loop_body,
            loop_vars=(
                tf.zeros(shape, dtype=tf.bool, name='done'),
                tf.zeros(shape, dtype=dtype, name='u'),
                tf.zeros(shape, dtype=dtype, name='w'),
            ),
            # The expected number of iterations depends on concentration.
            # It monotonically increases from one iteration for concentration = 0 to
            # sqrt(2 pi / e) ~= 1.52 iterations for concentration = +inf [1].
            # We use a limit of 100 iterations to avoid infinite loops
            # for very large / nan concentration.
            maximum_iterations=100,
            parallel_iterations=1 if seed.original_seed is None else 10,
        )

        x = tf.sign(u) * tf.math.acos(w)

        def grad(dy):
            """The gradient of the von Mises samples w.r.t. concentration."""
            broadcast_concentration = tf.broadcast_to(concentration,
                                                      prefer_static.shape(x))
            _, dcdf_dconcentration = value_and_gradient(
                lambda conc: von_mises_cdf(x, conc), broadcast_concentration)
            inv_prob = tf.exp(-broadcast_concentration * (tf.cos(x) - 1.)) * (
                (2. * np.pi) * tf.math.bessel_i0e(broadcast_concentration))
            # Compute the implicit reparameterization gradient [2],
            # dz/dconc = -(dF(z; conc) / dconc) / p(z; conc)
            ret = dy * (-inv_prob * dcdf_dconcentration)
            # Sum over the sample dimensions. Assume that they are always the first
            # ones.
            num_sample_dimensions = (tf.rank(broadcast_concentration) -
                                     tf.rank(concentration))
            return tf.reduce_sum(ret, axis=tf.range(num_sample_dimensions))

        return x, grad

    return rejection_sample_with_gradient(concentration)
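`random_von_mises` is the sampler behind the `tfd.VonMises` distribution; the implicit-reparameterization gradient defined in `grad` above is what makes draws from that distribution differentiable in `concentration`. A short sketch, assuming the standard `tfp.distributions` API:

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

concentration = tf.constant(2.0)
with tf.GradientTape() as tape:
    tape.watch(concentration)
    # Samples come from the rejection sampler above, yet remain differentiable
    # with respect to `concentration` via the custom gradient it defines.
    x = tfd.VonMises(loc=0., concentration=concentration).sample(1000, seed=17)
    loss = tf.reduce_mean(tf.cos(x))
print(tape.gradient(loss, concentration))  # a finite, non-None gradient
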
Example #20
def batched_las_vegas_algorithm(
    batched_las_vegas_trial_fn, seed=None, name=None):
  """Batched Las Vegas Algorithm.

  This utility encapsulates the notion of a 'batched las_vegas_algorithm'
  (BLVA): a batch of independent (but not necessarily identical) randomized
  computations, each of which will eventually terminate after an unknown number
  of trials [(Babai, 1979)][1]. The number of trials will in general vary
  across batch points.

  The computation is parameterized by a callable representing a single trial for
  the entire batch. The utility runs the callable repeatedly, keeping track of
  which batch points have succeeded, until all have succeeded.

  Because we keep running the callable repeatedly until we've generated at least
  one good value for every batch point, we may generate multiple good values for
  many batch points. In that case, which of the good values is returned for such
  a batch point is deliberately left unspecified.

  Args:
    batched_las_vegas_trial_fn: A callable that takes a Python integer PRNG seed
      and returns two values. (1) A structure of Tensors containing the results
      of the computation, all with a shape broadcastable with (2) a boolean mask
      representing whether each batch point succeeded.
    seed: Python integer or `tfp.util.SeedStream` instance, for seeding PRNG.
    name: A name to prepend to created ops.
      Default value: `'batched_las_vegas_algorithm'`.

  Returns:
    results, num_iters: A structure of Tensors representing the results of a
    successful computation for each batch point, and a scalar int32 tensor, the
    number of calls to `batched_las_vegas_trial_fn`.

  #### References

  [1]: Laszlo Babai. Monte-Carlo algorithms in graph isomorphism
       testing. Universite de Montreal, D.M.S. No. 79-10.
  """
  with tf.name_scope(name or 'batched_las_vegas_algorithm'):
    seed_stream = SeedStream(seed, 'batched_las_vegas_algorithm')
    values, good_values_mask = batched_las_vegas_trial_fn(seed_stream())
    num_iters = tf.constant(1)

    def cond(unused_values, good_values_mask, unused_num_iters):
      return tf.math.logical_not(tf.reduce_all(good_values_mask))

    def body(values, good_values_mask, num_iters):
      """Batched Las Vegas Algorithm body."""

      new_values, new_good_values_mask = batched_las_vegas_trial_fn(
          seed_stream())

      values = tf.nest.map_structure(
          lambda new, old: tf.where(new_good_values_mask, new, old),
          *(new_values, values))

      good_values_mask = tf.logical_or(good_values_mask, new_good_values_mask)

      return values, good_values_mask, num_iters+1

    (values, _, num_iters) = tf.while_loop(
        cond, body, (values, good_values_mask, num_iters),
        parallel_iterations=1 if seed is not None else 10)
    return values, num_iters
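# A hypothetical usage sketch (the trial function below is made up, not from
# the library): rejection-sample standard normals restricted to [0, 1] for a
# batch of four points. The trial function takes a Python integer seed and
# returns one proposal per batch point plus a success mask.
import tensorflow as tf

def _unit_interval_normal_trial(seed):
  x = tf.random.normal([4], seed=seed)  # stateful sampling, re-seeded per trial
  return x, (x >= 0.) & (x <= 1.)

values, num_iters = batched_las_vegas_algorithm(
    _unit_interval_normal_trial, seed=42)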
Example #21
    def _sample_n(self, n, seed):
        components_seed, mix_seed = samplers.split_seed(
            seed, salt='MixtureSameFamily')
        try:
            seed_stream = SeedStream(seed, salt='MixtureSameFamily')
        except TypeError as e:  # Can happen for Tensor seeds.
            seed_stream = None
            seed_stream_err = e
        try:
            x = self.components_distribution.sample(  # [n, B, k, E]
                n, seed=components_seed)
            if seed_stream is not None:
                seed_stream()  # Advance even if unused.
        except TypeError as e:
            if ('Expected int for argument' not in str(e)
                    and TENSOR_SEED_MSG_PREFIX not in str(e)):
                raise
            if seed_stream is None:
                raise seed_stream_err
            msg = (
                'Falling back to stateful sampling for `components_distribution` '
                '{} of type `{}`. Please update to use `tf.random.stateless_*` '
                'RNGs. This fallback may be removed after 20-Aug-2020. {}')
            warnings.warn(
                msg.format(self.components_distribution.name,
                           type(self.components_distribution), str(e)))
            x = self.components_distribution.sample(  # [n, B, k, E]
                n, seed=seed_stream())

        event_shape = None
        event_ndims = tensorshape_util.rank(self.event_shape)
        if event_ndims is None:
            event_shape = self.components_distribution.event_shape_tensor()
            event_ndims = ps.rank_from_shape(event_shape)
        event_ndims_static = tf.get_static_value(event_ndims)

        num_components = None
        if event_ndims_static is not None:
            num_components = tf.compat.dimension_value(
                x.shape[-1 - event_ndims_static])
        # We could also check if num_components can be computed statically from
        # self.mixture_distribution's logits or probs.
        if num_components is None:
            num_components = tf.shape(x)[-1 - event_ndims]

        # TODO(jvdillon): Consider using tf.gather (by way of index unrolling).
        npdt = dtype_util.as_numpy_dtype(x.dtype)
        try:
            mix_sample = self.mixture_distribution.sample(
                n, seed=mix_seed)  # [n, B] or [n]
        except TypeError as e:
            if ('Expected int for argument' not in str(e)
                    and TENSOR_SEED_MSG_PREFIX not in str(e)):
                raise
            if seed_stream is None:
                raise seed_stream_err
            msg = (
                'Falling back to stateful sampling for `mixture_distribution` '
                '{} of type `{}`. Please update to use `tf.random.stateless_*` '
                'RNGs. This fallback may be removed after 20-Aug-2020. ({})')
            warnings.warn(
                msg.format(self.mixture_distribution.name,
                           type(self.mixture_distribution), str(e)))
            mix_sample = self.mixture_distribution.sample(
                n, seed=seed_stream())  # [n, B] or [n]
        mask = tf.one_hot(
            indices=mix_sample,  # [n, B] or [n]
            depth=num_components,
            on_value=npdt(1),
            off_value=npdt(0))  # [n, B, k] or [n, k]

        # Pad `mask` to [n, B, k, [1]*e] or [n, [1]*b, k, [1]*e] .
        batch_ndims = ps.rank(x) - event_ndims - 1
        mask_batch_ndims = ps.rank(mask) - 1
        pad_ndims = batch_ndims - mask_batch_ndims
        mask_shape = ps.shape(mask)
        mask = tf.reshape(mask,
                          shape=ps.concat([
                              mask_shape[:-1],
                              ps.ones([pad_ndims], dtype=tf.int32),
                              mask_shape[-1:],
                              ps.ones([event_ndims], dtype=tf.int32),
                          ],
                                          axis=0))

        if x.dtype in [
                tf.bfloat16, tf.float16, tf.float32, tf.float64, tf.complex64,
                tf.complex128
        ]:
            masked = tf.math.multiply_no_nan(x, mask)
        else:
            masked = x * mask
        ret = tf.reduce_sum(masked, axis=-1 - event_ndims)  # [n, B, E]

        if self._reparameterize:
            if event_shape is None:
                event_shape = self.components_distribution.event_shape_tensor()
            ret = self._reparameterize_sample(ret, event_shape=event_shape)

        return ret
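# A toy, standalone illustration (made-up shapes) of the masking trick used
# above: select one component per draw with a one-hot mask and a reduce_sum
# instead of a gather.
import tensorflow as tf

x = tf.random.normal([5, 3, 2])                   # [n, k, E]: 5 draws, 3 components, event size 2
mix_sample = tf.constant([0, 2, 1, 0, 2])         # chosen component per draw, shape [n]
mask = tf.one_hot(mix_sample, depth=3)[..., tf.newaxis]  # [n, k, 1]; broadcasts over E
selected = tf.reduce_sum(x * mask, axis=-2)       # [n, E]; draw i comes from component mix_sample[i]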
Example #22
  def __init__(self,
               target_log_prob_fn,
               step_size,
               max_tree_depth=10,
               unrolled_leapfrog_steps=1,
               num_trajectories_per_step=1,
               use_auto_batching=True,
               stackless=False,
               backend=None,
               seed=None,
               name=None):
    """Initializes this transition kernel.

    Args:
      target_log_prob_fn: Python callable which takes an argument like
        `current_state` (or `*current_state` if it's a list) and returns its
        (possibly unnormalized) log-density under the target distribution.  Due
        to limitations of the underlying auto-batching system,
        target_log_prob_fn may be invoked with junk data at some batch indexes,
        which it must process without crashing.  (The results at those indexes
        are ignored).
      step_size: `Tensor` or Python `list` of `Tensor`s representing the step
        size for the leapfrog integrator. Must broadcast with the shape of
        `current_state`. Larger step sizes lead to faster progress, but
        too-large step sizes make rejection exponentially more likely. When
        possible, it's often helpful to match per-variable step sizes to the
        standard deviations of the target distribution in each variable.
      max_tree_depth: Maximum depth of the tree implicitly built by NUTS. The
        maximum number of leapfrog steps is bounded by `2**max_tree_depth-1`
        i.e. the number of nodes in a binary tree `max_tree_depth` nodes deep.
        The default setting of 10 takes up to 1023 leapfrog steps.
      unrolled_leapfrog_steps: The number of leapfrogs to unroll per tree
        expansion step. Applies a direct linear multiplier to the maximum
        trajectory length implied by max_tree_depth. Defaults to 1. This
        parameter can be useful for amortizing the auto-batching control flow
        overhead.
      num_trajectories_per_step: Python `int` giving the number of NUTS
        trajectories to run as "one" step.  Setting this higher than 1 may be
        favorable for performance by giving the autobatching system the
        opportunity to batch gradients across consecutive trajectories.  The
        intermediate samples are thinned: only the last sample from the run (in
        each batch member) is returned.
      use_auto_batching: Boolean.  If `False`, do not invoke the auto-batching
        system; operate on batch size 1 only.
      stackless: Boolean.  If `True`, invoke the stackless version of
        the auto-batching system.  Only works in Eager mode.
      backend: Auto-batching backend object. Falls back to a default
        TensorFlowBackend().
      seed: Python integer to seed the random number generator.
      name: Python `str` name prefixed to Ops created by this function.
        Default value: `None` (i.e., 'nuts_kernel').
    """
    self._parameters = dict(locals())
    del self._parameters["self"]
    self.target_log_prob_fn = target_log_prob_fn
    self.step_size = step_size
    if max_tree_depth < 1:
      raise ValueError(
          "max_tree_depth must be >= 1 but was {}".format(max_tree_depth))
    self.max_tree_depth = max_tree_depth
    self.unrolled_leapfrog_steps = unrolled_leapfrog_steps
    self.num_trajectories_per_step = num_trajectories_per_step
    self.use_auto_batching = use_auto_batching
    self.stackless = stackless
    self.backend = backend
    self._seed_stream = SeedStream(seed, "nuts_one_step")
    self.name = "nuts_kernel" if name is None else name
    # TODO(b/125544625): Identify why we need `use_gradient_tape=True`, i.e.,
    # what's different between `tape.gradient` and `tf.gradient`.
    value_and_gradients_fn = lambda *args: tfp_math.value_and_gradient(  # pylint: disable=g-long-lambda
        self.target_log_prob_fn, args, use_gradient_tape=True)
    self.value_and_gradients_fn = _embed_no_none_gradient_check(
        value_and_gradients_fn)
    max_tree_edges = max_tree_depth - 1
    self.many_steps, self.autobatch_context = _make_evolve_trajectory(
        self.value_and_gradients_fn, max_tree_edges, unrolled_leapfrog_steps,
        self._seed_stream)
    self._block_code_cache = {}
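# A quick arithmetic check (plain Python) of the docstring above: with the
# default max_tree_depth of 10, at most 2**10 - 1 = 1023 leapfrog steps are
# taken per trajectory, before the unrolled_leapfrog_steps multiplier.
max_tree_depth = 10
assert 2**max_tree_depth - 1 == 1023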
Example #23
  def _sample_n(self, n, seed=None):
    power = tf.convert_to_tensor(self.power)
    shape = tf.concat([[n], tf.shape(power)], axis=0)

    has_seed = seed is not None
    seed = SeedStream(seed, salt='zipf')

    minval_u = self._hat_integral(0.5, power=power) + 1.
    maxval_u = self._hat_integral(tf.int64.max - 0.5, power=power)

    def loop_body(should_continue, k):
      """Resample the non-accepted points."""
      # The range of U is chosen so that the resulting sample K lies in
      # [0, tf.int64.max). The final sample, if accepted, is K + 1.
      u = tf.random.uniform(
          shape,
          minval=minval_u,
          maxval=maxval_u,
          dtype=power.dtype,
          seed=seed())

      # Sample the point X from the continuous density h(x) \propto x^(-power).
      x = self._hat_integral_inverse(u, power=power)

      # Rejection-inversion requires a `hat` function, h(x) such that
      # \int_{k - .5}^{k + .5} h(x) dx >= pmf(k + 1) for points k in the
      # support. A natural hat function for us is h(x) = x^(-power).
      #
      # After sampling X from h(x), suppose it lies in the interval
      # (K - .5, K + .5) for integer K. Then the corresponding K is accepted if
      # X lies to the left of x_K, where x_K is defined by:
      #   \int_{x_k}^{K + .5} h(x) dx = H(x_K) - H(K + .5) = pmf(K + 1),
      # where H(x) = \int_x^inf h(x) dx.

      # Solving for x_K, we find that x_K = H_inverse(H(K + .5) + pmf(K + 1)).
      # Or, the acceptance condition is X <= H_inverse(H(K + .5) + pmf(K + 1)).
      # Since X = H_inverse(U), this simplifies to U <= H(K + .5) + pmf(K + 1).

      # Update the non-accepted points.
      # Since X \in (K - .5, K + .5), the sample K is chosen as floor(X + 0.5).
      k = tf.where(should_continue, tf.floor(x + 0.5), k)
      accept = (u <= self._hat_integral(k + .5, power=power) + tf.exp(
          self._log_prob(k + 1, power=power)))

      return [should_continue & (~accept), k]

    should_continue, samples = tf.while_loop(
        cond=lambda should_continue, *ignore: tf.reduce_any(should_continue),
        body=loop_body,
        loop_vars=[
            tf.ones(shape, dtype=tf.bool),  # should_continue
            tf.zeros(shape, dtype=power.dtype),  # k
        ],
        parallel_iterations=1 if has_seed else 10,
        maximum_iterations=self.sample_maximum_iterations,
    )
    samples = samples + 1.

    if self.validate_args and dtype_util.is_integer(self.dtype):
      samples = distribution_util.embed_check_integer_casting_closed(
          samples, target_dtype=self.dtype, assert_positive=True)

    samples = tf.cast(samples, self.dtype)

    if self.validate_args:
      npdt = dtype_util.as_numpy_dtype(self.dtype)
      v = npdt(dtype_util.min(npdt) if dtype_util.is_integer(npdt) else np.nan)
      samples = tf.where(should_continue, v, samples)

    return samples
Example #24
  def _sample_n(self, n, seed=None):
    seed = SeedStream(seed, salt='vom_mises_fisher')
    # The sampling strategy relies on the fact that vMF variates are symmetric
    # about the mean direction. Accordingly, if we have a sampling strategy for
    # the away-from-mean angle, then we can uniformly sample the remaining
    # dimensions on the S^{dim-2} sphere, and rotate these samples from a
    # (1, 0, 0, ..., 0)-mode distribution into the target orientation.
    #
    # This is easy to imagine on the 1-sphere (S^1; in 2-D space): sample a
    # von-Mises distributed `x` value in [-1, 1], then uniformly select what
    # amounts to an "up" or "down" additional degree of freedom after unit
    # normalizing, followed by a final rotation to the desired mean direction
    # from a basis of (1, 0).
    #
    # On S^2 (in 3-D), selecting a vMF `x` identifies a circle in `yz` on the
    # unit sphere over which the distribution is uniform, in particular the
    # circle where x = \hat{x} intersects the unit sphere. We pick a point on
    # that circle, then rotate to the desired mean direction from a basis of
    # (1, 0, 0).
    event_dim = (
        tf.compat.dimension_value(self.event_shape[0]) or
        self._event_shape_tensor()[0])

    sample_batch_shape = tf.concat([[n], self._batch_shape_tensor()], axis=0)
    dim = tf.cast(event_dim - 1, self.dtype)
    if event_dim == 3:
      samples_dim0 = self._sample_3d(n, seed=seed)
    else:
      # Wood'94 provides a rejection algorithm to sample the x coordinate.
      # Wood'94 definition of b:
      # b = (-2 * kappa + tf.sqrt(4 * kappa**2 + dim**2)) / dim
      # https://stats.stackexchange.com/questions/156729 suggests:
      b = dim / (2 * self.concentration +
                 tf.sqrt(4 * self.concentration**2 + dim**2))
      # TODO(bjp): Integrate any useful numerical tricks from hyperspherical VAE
      #     https://github.com/nicola-decao/s-vae-tf/
      x = (1 - b) / (1 + b)
      c = self.concentration * x + dim * tf.math.log1p(-x**2)
      beta = beta_lib.Beta(dim / 2, dim / 2)

      def cond_fn(w, should_continue):
        del w
        return tf.reduce_any(should_continue)

      def body_fn(w, should_continue):
        z = beta.sample(sample_shape=sample_batch_shape, seed=seed())
        # set_shape needed here because of b/139013403
        z.set_shape(w.shape)
        w = tf.where(should_continue, (1 - (1 + b) * z) / (1 - (1 - b) * z), w)
        w = tf.debugging.check_numerics(w, 'w')
        unif = tf.random.uniform(
            sample_batch_shape, seed=seed(), dtype=self.dtype)
        # set_shape needed here because of b/139013403
        unif.set_shape(w.shape)
        should_continue = tf.logical_and(
            should_continue,
            self.concentration * w + dim * tf.math.log1p(-x * w) - c <
            tf.math.log(unif))
        return w, should_continue

      w = tf.zeros(sample_batch_shape, dtype=self.dtype)
      should_continue = tf.ones(sample_batch_shape, dtype=tf.bool)
      samples_dim0 = tf.while_loop(
          cond=cond_fn, body=body_fn, loop_vars=(w, should_continue))[0]
      samples_dim0 = samples_dim0[..., tf.newaxis]
    if not self._allow_nan_stats:
      # Verify samples are w/in -1, 1, with useful error output tensors (top
      # value rather than all values).
      with tf.control_dependencies([
          assert_util.assert_less_equal(
              samples_dim0,
              dtype_util.as_numpy_dtype(self.dtype)(1.01),
              data=[tf.math.top_k(tf.reshape(samples_dim0, [-1]))[0]]),
          assert_util.assert_greater_equal(
              samples_dim0,
              dtype_util.as_numpy_dtype(self.dtype)(-1.01),
              data=[-tf.math.top_k(tf.reshape(-samples_dim0, [-1]))[0]])
      ]):
        samples_dim0 = tf.identity(samples_dim0)
    samples_otherdims_shape = tf.concat([sample_batch_shape, [event_dim - 1]],
                                        axis=0)
    unit_otherdims = tf.math.l2_normalize(
        tf.random.normal(
            samples_otherdims_shape, seed=seed(), dtype=self.dtype),
        axis=-1)
    samples = tf.concat([
        samples_dim0,  # we must avoid sqrt(1 - (>1)**2)
        tf.sqrt(tf.maximum(1 - samples_dim0**2, 0.)) * unit_otherdims
    ], axis=-1)
    samples = tf.math.l2_normalize(samples, axis=-1)
    if not self._allow_nan_stats:
      samples = tf.debugging.check_numerics(samples, 'samples')

    # Runtime assert that samples are unit length.
    if not self._allow_nan_stats:
      worst, idx = tf.math.top_k(
          tf.reshape(tf.abs(1 - tf.linalg.norm(samples, axis=-1)), [-1]))
      with tf.control_dependencies([
          assert_util.assert_near(
              dtype_util.as_numpy_dtype(self.dtype)(0),
              worst,
              data=[
                  worst, idx,
                  tf.gather(tf.reshape(samples, [-1, event_dim]), idx)
              ],
              atol=1e-4,
              summarize=100)
      ]):
        samples = tf.identity(samples)
    # The samples generated are symmetric around a mode at (1, 0, 0, ...., 0).
    # Now, we move the mode to `self.mean_direction` using a rotation matrix.
    if not self._allow_nan_stats:
      # Assert that the basis vector rotates to the mean direction, as expected.
      basis = tf.cast(tf.concat([[1.], tf.zeros([event_dim - 1])], axis=0),
                      self.dtype)
      with tf.control_dependencies([
          assert_util.assert_less(
              tf.linalg.norm(
                  self._rotate(basis) - self.mean_direction, axis=-1),
              dtype_util.as_numpy_dtype(self.dtype)(1e-5))
      ]):
        return self._rotate(samples)
    return self._rotate(samples)
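# A standalone sketch (this is an assumption about one common approach, not the
# class's `_rotate`): move samples whose mode is the first basis vector e1 onto
# a given unit `mean_direction` via a Householder reflection exchanging the two.
import tensorflow as tf

def rotate_from_e1(samples, mean_direction):
  dim = tf.shape(mean_direction)[-1]
  e1 = tf.one_hot(0, dim, dtype=samples.dtype)
  v = tf.math.l2_normalize(e1 - mean_direction, axis=-1)  # reflection axis
  # Reflect: x -> x - 2 v (v . x).  This maps e1 to mean_direction (and back).
  return samples - 2. * tf.reduce_sum(samples * v, axis=-1, keepdims=True) * v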
Example #25
    def _sample_n(self, n, seed=None):
        if self._use_static_graph:
            # This sampling approach is almost the same as the approach used by
            # `MixtureSameFamily`. The differences are due to having a list of
            # `Distribution` objects rather than a single object, and maintaining
            # random seed management that is consistent with the non-static code
            # path.
            samples = []
            cat_samples = self.cat.sample(n, seed=seed)
            stream = SeedStream(seed, salt='Mixture')

            for c in range(self.num_components):
                samples.append(self.components[c].sample(n, seed=stream()))
            stack_axis = -1 - tensorshape_util.rank(self._static_event_shape)
            x = tf.stack(samples, axis=stack_axis)  # [n, B, k, E]
            npdt = dtype_util.as_numpy_dtype(x.dtype)
            mask = tf.one_hot(
                indices=cat_samples,  # [n, B]
                depth=self._num_components,  # == k
                on_value=npdt(1),
                off_value=npdt(0))  # [n, B, k]
            mask = distribution_util.pad_mixture_dimensions(
                mask, self, self._cat,
                tensorshape_util.rank(
                    self._static_event_shape))  # [n, B, k, [1]*e]
            return tf.reduce_sum(x * mask, axis=stack_axis)  # [n, B, E]

        n = tf.convert_to_tensor(n, name='n')
        static_n = tf.get_static_value(n)
        n = int(static_n) if static_n is not None else n
        cat_samples = self.cat.sample(n, seed=seed)

        static_samples_shape = cat_samples.shape
        if tensorshape_util.is_fully_defined(static_samples_shape):
            samples_shape = tensorshape_util.as_list(static_samples_shape)
            samples_size = tensorshape_util.num_elements(static_samples_shape)
        else:
            samples_shape = tf.shape(cat_samples)
            samples_size = tf.size(cat_samples)
        static_batch_shape = self.batch_shape
        if tensorshape_util.is_fully_defined(static_batch_shape):
            batch_shape = tensorshape_util.as_list(static_batch_shape)
            batch_size = tensorshape_util.num_elements(static_batch_shape)
        else:
            batch_shape = tf.shape(cat_samples)[1:]
            batch_size = tf.reduce_prod(batch_shape)
        static_event_shape = self.event_shape
        if tensorshape_util.is_fully_defined(static_event_shape):
            event_shape = np.array(
                tensorshape_util.as_list(static_event_shape), dtype=np.int32)
        else:
            event_shape = None

        # Get indices into the raw cat sampling tensor. We will
        # need these to stitch sample values back out after sampling
        # within the component partitions.
        samples_raw_indices = tf.reshape(tf.range(0, samples_size),
                                         samples_shape)

        # Partition the raw indices so that we can use
        # dynamic_stitch later to reconstruct the samples from the
        # known partitions.
        partitioned_samples_indices = tf.dynamic_partition(
            data=samples_raw_indices,
            partitions=cat_samples,
            num_partitions=self.num_components)

        # Copy the batch indices n times, as we will need to know
        # these to pull out the appropriate rows within the
        # component partitions.
        batch_raw_indices = tf.reshape(tf.tile(tf.range(0, batch_size), [n]),
                                       samples_shape)

        # Explanation of the dynamic partitioning below:
        #   batch indices are, e.g., [0, 1, 0, 1, 0, 1]
        # Suppose partitions (i.e. cat_samples) are:
        #     [1 1 0 0 1 1]
        # After partitioning, batch indices are cut as:
        #     [batch_indices[x] for x in 2, 3]
        #     [batch_indices[x] for x in 0, 1, 4, 5]
        # i.e.
        #     [0 1] and [0 1 0 1]
        # Now we sample n_class=2 from part 0 and n_class=4 from part 1.
        # For part 0 we want samples from batch entries 0, 1 (samples 0, 1),
        # and for part 1 we want samples from batch entries 0, 1, 0, 1
        #   (samples 0, 1, 2, 3).
        partitioned_batch_indices = tf.dynamic_partition(
            data=batch_raw_indices,
            partitions=cat_samples,
            num_partitions=self.num_components)
        samples_class = [None for _ in range(self.num_components)]

        stream = SeedStream(seed, salt='Mixture')

        for c in range(self.num_components):
            n_class = tf.size(partitioned_samples_indices[c])
            samples_class_c = self.components[c].sample(n_class, seed=stream())

            if event_shape is None:
                batch_ndims = prefer_static.rank_from_shape(batch_shape)
                event_shape = tf.shape(samples_class_c)[1 + batch_ndims:]

            # Pull out the correct batch entries from each index.
            # To do this, we may have to flatten the batch shape.

            # For sample s, batch element b of component c, we get the
            # partitioned batch indices from
            # partitioned_batch_indices[c]; and shift each element by
            # the sample index. The final lookup can be thought of as
            # a matrix gather along locations (s, b) in
            # samples_class_c where the n_class rows correspond to
            # samples within this component and the batch_size columns
            # correspond to batch elements within the component.
            #
            # Thus the lookup index is
            #   lookup[c, i] = batch_size * s[i] + b[c, i]
            # for i = 0 ... n_class[c] - 1.
            lookup_partitioned_batch_indices = (
                batch_size * tf.range(n_class) + partitioned_batch_indices[c])
            samples_class_c = tf.reshape(
                samples_class_c,
                tf.concat([[n_class * batch_size], event_shape], 0))
            samples_class_c = tf.gather(samples_class_c,
                                        lookup_partitioned_batch_indices,
                                        name='samples_class_c_gather')
            samples_class[c] = samples_class_c

        # Stitch back together the samples across the components.
        lhs_flat_ret = tf.dynamic_stitch(indices=partitioned_samples_indices,
                                         data=samples_class)
        # Reshape back to proper sample, batch, and event shape.
        ret = tf.reshape(lhs_flat_ret,
                         tf.concat([samples_shape, event_shape], 0))
        tensorshape_util.set_shape(
            ret,
            tensorshape_util.concatenate(static_samples_shape,
                                         self.event_shape))
        return ret
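# A toy, standalone illustration (made-up values) of the partition/stitch
# round trip used above.
import tensorflow as tf

data = tf.constant([10, 11, 12, 13, 14, 15])
cat = tf.constant([1, 1, 0, 0, 1, 1])              # component id per flat sample
idx = tf.dynamic_partition(tf.range(6), cat, 2)    # [[2, 3], [0, 1, 4, 5]]
vals = tf.dynamic_partition(data, cat, 2)          # [[12, 13], [10, 11, 14, 15]]
stitched = tf.dynamic_stitch(idx, vals)            # [10, 11, 12, 13, 14, 15]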
Example #26
  def __init__(self,
               output_shape=(32, 32, 3),
               num_glow_blocks=3,
               num_steps_per_block=32,
               coupling_bijector_fn=None,
               exit_bijector_fn=None,
               grab_after_block=None,
               use_actnorm=True,
               seed=None,
               validate_args=False,
               name='glow'):
    """Creates the Glow bijector.

    Args:
      output_shape: A list of integers, specifying the event shape of the
        output of the bijector's forward pass (the image).  Specified as
        [H, W, C].
        Default value: (32, 32, 3)
      num_glow_blocks: An integer, specifying how many downsampling levels to
        include in the model. Both H and W must be divisible by
        2**num_glow_blocks, otherwise the bijector would not be invertible.
        Default value: 3
      num_steps_per_block: An integer specifying how many Affine Coupling and
        1x1 convolution layers to include at each level of the spatial
        hierarchy.
        Default value: 32 (i.e. the value used in the original glow paper).
      coupling_bijector_fn: A function which takes the argument `input_shape`
        and returns a callable neural network (e.g. a keras.Sequential). The
        network should either return a tensor with the same event shape as
        `input_shape` (this will employ additive coupling), a tensor with the
        same height and width as `input_shape` but twice the number of channels
        (this will employ affine coupling), or a bijector which takes in a
        tensor with event shape `input_shape`, and returns a tensor with shape
        `input_shape`.
      exit_bijector_fn: Similar to coupling_bijector_fn, exit_bijector_fn is
        a function which takes the argument `input_shape` and `output_chan`
        and returns a callable neural network. The neural network it returns
        should take a tensor of shape `input_shape` as the input, and return
        one of three options: A tensor with `output_chan` channels, a tensor
        with `2 * output_chan` channels, or a bijector. Additional details can
        be found in the documentation for ExitBijector.
      grab_after_block: A tuple of floats, specifying what fraction of the
        remaining channels to remove following each glow block. Glow will take
        the integer floor of this number multiplied by the remaining number of
        channels. The default is half at each spatial hierarchy.
        Default value: None (this will take out half of the channels after each
          block).
      use_actnorm: A bool deciding whether or not to use actnorm. Data-dependent
        initialization is used to initialize this layer.
        Default value: `True`
      seed: A seed to control randomness in the 1x1 convolution initialization.
        Default value: `None` (i.e., non-reproducible sampling).
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
        Default value: `False`
      name: Python `str`, name given to ops managed by this object.
        Default value: `'glow'`.
    """
    # Make sure that the input shape is fully defined.
    if not tensorshape_util.is_fully_defined(output_shape):
      raise ValueError('Shape must be fully defined.')
    if tensorshape_util.rank(output_shape) != 3:
      raise ValueError('Shape ndims must be 3 for images.  Your shape has '
                       'ndims {}.'.format(tensorshape_util.rank(output_shape)))

    num_glow_blocks_ = tf.get_static_value(num_glow_blocks)
    if (num_glow_blocks_ is None or
        int(num_glow_blocks_) != num_glow_blocks_ or
        num_glow_blocks_ < 1):
      raise ValueError('Argument `num_glow_blocks` must be a statically known '
                       'positive `int` (saw: {}).'.format(num_glow_blocks))
    num_glow_blocks = int(num_glow_blocks_)

    output_shape = tensorshape_util.as_list(output_shape)
    h, w, c = output_shape
    n = num_glow_blocks
    nsteps = num_steps_per_block

    # Default Glow: Half of the channels are split off after each block,
    # and after the final block, no channels are split off.
    if grab_after_block is None:
      grab_after_block = tuple([0.5] * (n - 1) + [0.])

    # Thing we know must be true: h and w are evenly divisible by 2, n times.
    # Otherwise, the squeeze bijector will not work.
    if w % 2**n != 0:
      raise ValueError('Width must be divisible by 2 at least n times. '
                       'Saw: {} % {} != 0'.format(w, 2**n))
    if h % 2**n != 0:
      raise ValueError('Height should be divisible by 2 at least n times.')
    if h // 2**n < 1:
      raise ValueError('num_glow_blocks ({0}) is too large. The image height '
                       '({1}) must be divisible by 2 no more than {2} '
                       'times.'.format(num_glow_blocks, h,
                                       int(np.log(h) / np.log(2.))))
    if w // 2**n < 1:
      raise ValueError('num_glow_blocks ({0}) is too large. The image width '
                       '({1}) must be divisible by 2 no more than {2} '
                       'times.'.format(num_glow_blocks, w,
                                       int(np.log(w) / np.log(2.))))

    # Other things we want to be true:
    # - The number of times we take must be equal to the number of glow blocks.
    if len(grab_after_block) != num_glow_blocks:
      raise ValueError('Length of grab_after_block ({0}) must match the number '
                       'of blocks ({1}).'.format(len(grab_after_block),
                                                 num_glow_blocks))

    self._blockwise_splits = self._get_blockwise_splits(output_shape,
                                                        grab_after_block[::-1])

    # Now check on the values of blockwise splits
    if any([bs[0] < 1 for bs in self._blockwise_splits]):
      first_offender = [bs[0] < 1 for bs in self._blockwise_splits].index(True)
      raise ValueError('At least one exit takes out all of your channels, '
                       'and therefore later blocks have no inputs. '
                       'Try setting grab_after_block to a lower value at index '
                       '{}.'.format(first_offender))

    if any(np.isclose(gab, 0) for gab in grab_after_block):
      # Special case: if specifically exiting no channels, then the exit is
      # just an identity bijector.
      pass
    elif any([bs[1] < 1 for bs in self._blockwise_splits]):
      first_offender = [bs[1] < 1 for bs in self._blockwise_splits].index(True)
      raise ValueError('At least one of your layers has < 1 output channels. '
                       'This means you set grab_after_block too small. '
                       'Try setting grab_after_block to a larger value at index '
                       '{}.'.format(first_offender))

    # Let's start to build our bijector. We assume that the distribution is
    # 1-dimensional. First, let's reshape it to an image.
    glow_chain = [
        reshape.Reshape(
            event_shape_out=[h // 2**n, w // 2**n, c * 4**n],
            event_shape_in=[h * w * c])
    ]

    seedstream = SeedStream(seed=seed, salt='random_beta')

    for i in range(n):

      # This is the shape of the current tensor
      current_shape = (h // 2**n * 2**i, w // 2**n * 2**i, c * 4**(i + 1))

      # This is the shape of the input to both the glow block and exit bijector.
      this_nchan = sum(self._blockwise_splits[i][0:2])
      this_input_shape = (h // 2**n * 2**i, w // 2**n * 2**i, this_nchan)

      glow_chain.append(invert.Invert(ExitBijector(current_shape,
                                                   self._blockwise_splits[i],
                                                   exit_bijector_fn)))

      glow_block = GlowBlock(input_shape=this_input_shape,
                             num_steps=nsteps,
                             coupling_bijector_fn=coupling_bijector_fn,
                             use_actnorm=use_actnorm,
                             seedstream=seedstream)

      if self._blockwise_splits[i][2] == 0:
        # All channels are passed to the RealNVP
        glow_chain.append(glow_block)
      else:
        # Some channels are passed around the block.
        # This is done with the Blockwise bijector.
        glow_chain.append(
            blockwise.Blockwise(
                [glow_block, identity.Identity()],
                [sum(self._blockwise_splits[i][0:2]),
                 self._blockwise_splits[i][2]]))

      # Finally, let's expand the channels into spatial features.
      glow_chain.append(
          Expand(input_shape=[
              h // 2**n * 2**i,
              w // 2**n * 2**i,
              c * 4**n // 4**i,
          ]))

    glow_chain = glow_chain[::-1]
    # To finish off, we initialize the bijector with the chain we've built.
    # This way, the rest of the model attributes are taken care of for us.
    super(Glow, self).__init__(
        bijectors=glow_chain, validate_args=validate_args, name=name)
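# A quick shape check (plain Python) of the squeeze/reshape bookkeeping above,
# using the default output_shape (32, 32, 3) and num_glow_blocks = 3.
h, w, c, n = 32, 32, 3, 3
squeezed = [h // 2**n, w // 2**n, c * 4**n]  # event shape after the initial Reshape
assert squeezed == [4, 4, 192]
assert h * w * c == 4 * 4 * 192              # the reshape preserves the event size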
Example #27
    def _sample_n(self, n, seed=None):
        with tf.control_dependencies(self._runtime_assertions):
            strm = SeedStream(seed, salt="HiddenMarkovModel")

            num_states = self._num_states

            batch_shape = self.batch_shape_tensor()
            batch_size = tf.reduce_prod(batch_shape)

            # The batch sizes of the underlying initial distributions and
            # transition distributions might not match the batch size of
            # the HMM distribution.
            # As a result we need to ask for more samples from the
            # underlying distributions and then reshape the results into
            # the correct batch size for the HMM.
            init_repeat = (
                tf.reduce_prod(self.batch_shape_tensor()) // tf.reduce_prod(
                    self._initial_distribution.batch_shape_tensor()))
            init_state = self._initial_distribution.sample(n * init_repeat,
                                                           seed=strm())
            init_state = tf.reshape(init_state, [n, batch_size])
            # init_state :: n batch_size

            transition_repeat = (
                tf.reduce_prod(self.batch_shape_tensor()) // tf.reduce_prod(
                    self._transition_distribution.batch_shape_tensor()[:-1]))

            def generate_step(state, _):
                """Take a single step in Markov chain."""

                gen = self._transition_distribution.sample(n *
                                                           transition_repeat,
                                                           seed=strm())
                # gen :: (n * transition_repeat) transition_batch

                new_states = tf.reshape(gen, [n, batch_size, num_states])

                # new_states :: n batch_size num_states

                old_states_one_hot = tf.one_hot(state,
                                                num_states,
                                                dtype=tf.int32)

                # old_states :: n batch_size num_states

                return tf.reduce_sum(old_states_one_hot * new_states, axis=-1)

            def _scan_multiple_steps():
                """Take multiple steps with tf.scan."""
                dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32)
                if seed is not None:
                    # Force parallel_iterations to 1 to ensure reproducibility
                    # b/139210489
                    hidden_states = tf.scan(generate_step,
                                            dummy_index,
                                            initializer=init_state,
                                            parallel_iterations=1)
                else:
                    # Invoke default parallel_iterations behavior
                    hidden_states = tf.scan(generate_step,
                                            dummy_index,
                                            initializer=init_state)

                # TODO(b/115618503): add/use prepend_initializer to tf.scan
                return tf.concat([[init_state], hidden_states], axis=0)

            hidden_states = prefer_static.cond(
                self._num_steps > 1, _scan_multiple_steps,
                lambda: init_state[tf.newaxis, ...])

            hidden_one_hot = tf.one_hot(
                hidden_states,
                num_states,
                dtype=self._observation_distribution.dtype)
            # hidden_one_hot :: num_steps n batch_size num_states

            # The observation distribution batch size might not match
            # the required batch size so as with the initial and
            # transition distributions we generate more samples and
            # reshape.
            observation_repeat = (batch_size // tf.reduce_prod(
                self._observation_distribution.batch_shape_tensor()[:-1]))

            possible_observations = self._observation_distribution.sample(
                [self._num_steps, observation_repeat * n], seed=strm())

            inner_shape = self._observation_distribution.event_shape

            # possible_observations :: num_steps (observation_repeat * n)
            #                          observation_batch[:-1] num_states inner_shape

            possible_observations = tf.reshape(
                possible_observations,
                tf.concat([[self._num_steps, n], batch_shape, [num_states],
                           inner_shape],
                          axis=0))

            # possible_observations :: steps n batch_size num_states inner_shape

            hidden_one_hot = tf.reshape(
                hidden_one_hot,
                tf.concat([[self._num_steps, n], batch_shape, [num_states],
                           tf.ones_like(inner_shape)],
                          axis=0))

            # hidden_one_hot :: steps n batch_size num_states "inner_shape"

            observations = tf.reduce_sum(hidden_one_hot *
                                         possible_observations,
                                         axis=-1 - tf.size(inner_shape))

            # observations :: steps n batch_size inner_shape

            observations = distribution_util.move_dimension(
                observations, 0, 1 + tf.size(batch_shape))

            # returned :: n batch_shape steps inner_shape

            return observations
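# A toy, standalone illustration (made-up values) of the one-hot trick in
# `generate_step` above: for each chain, pick the sampled next state that
# corresponds to its current state.
import tensorflow as tf

num_states = 3
state = tf.constant([0, 2])                   # current state of two chains
gen = tf.constant([[7, 8, 9], [4, 5, 6]])     # a sampled next state per (chain, current state)
one_hot = tf.one_hot(state, num_states, dtype=tf.int32)
next_state = tf.reduce_sum(one_hot * gen, axis=-1)  # [7, 6]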
Example #28
    def __init__(
        self,
        target_log_prob_fn,
        initial_state,
        initial_covariance=None,
        initial_covariance_scaling=2.38**2,
        covariance_scaling_reducer=0.7,
        covariance_scaling_limiter=0.01,
        covariance_burnin=100,
        target_accept_ratio=0.234,
        pu=0.95,
        fixed_variance=0.01,
        extra_getter_fn=rwm_extra_getter_fn,
        extra_setter_fn=rwm_extra_setter_fn,
        log_accept_prob_getter_fn=rwm_log_accept_prob_getter_fn,
        seed=None,
        name=None,
    ):
        """Initializes this transition kernel.

    Args:
      target_log_prob_fn: Python callable which takes an argument like
        `current_state` and returns its (possibly unnormalized) log-density
        under the target distribution.
      initial_state: Python `list` of `Tensor`s representing the initial
        state of each parameter.
      initial_covariance: Python `list` of `Tensor`s representing the
        initial covariance of the proposal. The `initial_covariance` and 
        `initial_state` should have identical `dtype`s and 
        batch dimensions.  If `initial_covariance` is `None` then it is
        initialized to a Python `list` of `Tensor`s where each tensor is
        the identity matrix multiplied by 0.001; the `list` structure will
        be identical to `initial_state`. The covariance matrix is tuned
        during the evolution of the MCMC chain.
        Default value: `None`.
      initial_covariance_scaling: Python floating point number representing
        the initial value of the `covariance_scaling`. The value of
        `covariance_scaling` is tuned during the evolution of the MCMC chain.
        Let d represent the number of parameters e.g. as given by the 
        `initial_state`. The ratio given by the `covariance_scaling` divided
        by d is used to multiply the running covariance. The covariance
        scaling factor multiplied by the covariance matrix is used in the
        proposal at each step.
        Default value: 2.38**2.
      covariance_scaling_reducer: Python floating point number, bounded over the 
        range (0.5,1.0], representing the constant factor used during the
        adaptation of the `covariance_scaling`. 
        Default value: 0.7.
      covariance_scaling_limiter: Python floating point number, bounded between 
        0.0 and 1.0, which places a limit on the maximum amount the
        `covariance_scaling` value can be perturbed at each iteration of the
        MCMC chain.
        Default value: 0.01.
      covariance_burnin: Python integer number of steps to take before starting to 
        compute the running covariance.
        Default value: 100.
      target_accept_ratio: Python floating point number, bounded between 0.0 and 1.0,
        representing the target acceptance probability of the 
        Metropolis–Hastings algorithm.
        Default value: 0.234.
      pu: Python floating point number, bounded between 0.0 and 1.0, representing the 
        bounded convergence parameter.  See `random_walk_mvnorm_fn()` for further
        details.
        Default value: 0.95.
      fixed_variance: Python floating point number representing the variance of
        the fixed proposal distribution. See `random_walk_mvnorm_fn` for 
        further details.
        Default value: 0.01.
      extra_getter_fn: A callable with the signature
        `(kernel_results) -> extra` where `kernel_results` are the results
        of the `inner_kernel`, and `extra` is a nested collection of 
        `Tensor`s.
      extra_setter_fn: A callable with the signature
        `(kernel_results, args) -> new_kernel_results` where
        `kernel_results` are the results of the `inner_kernel`, `args`
        are a nested collection of `Tensor`s with the same
        structure as returned by the `extra_getter_fn`, and
        `new_kernel_results` are a copy of `kernel_results` with `args`
        in the `extra` field set.
      log_accept_prob_getter_fn: A callable with the signature
        `(kernel_results) -> log_accept_prob` where `kernel_results` are the
        results of the `inner_kernel`, and `log_accept_prob` is either
        a scalar, or has shape [num_chains].
      seed: Python integer to seed the random number generator.
        Default value: `None`.
      name: Python `str` name prefixed to Ops created by this function.
        Default value: `None`.

    Returns:
      next_state: Tensor or list of `Tensor`s representing the state(s)
        of the Markov chain(s) at each result step. Has same shape as
        `current_state`.
      kernel_results: `collections.namedtuple` of internal calculations used to
        advance the chain.

    Raises:
      ValueError: if `initial_covariance_scaling` is less than or equal
        to 0.0.
      ValueError: if `covariance_scaling_reducer` is less than or equal
        to 0.5 or greater than 1.0.
      ValueError: if `covariance_scaling_limiter` is less than 0.0 or
        greater than 1.0.
      ValueError: if `covariance_burnin` is less than 0.
      ValueError: if `target_accept_ratio` is less than 0.0 or
        greater than 1.0.
      ValueError: if `pu` is less than 0.0 or greater than 1.0.
      ValueError: if `fixed_variance` is less than 0.0.
    """
        with tf.name_scope(
                mcmc_util.make_name(name,
                                    "AdaptiveRandomWalkMetropolisHastings",
                                    "__init__")) as name:
            if initial_covariance_scaling <= 0.0:
                raise ValueError(
                    "`{}` must be a `float` greater than 0.0".format(
                        "initial_covariance_scaling"))
            if covariance_scaling_reducer <= 0.5 or covariance_scaling_reducer > 1.0:
                raise ValueError(
                    "`{}` must be a `float` greater than 0.5 and less than or equal to 1.0."
                    .format("covariance_scaling_reducer"))
            if covariance_scaling_limiter < 0.0 or covariance_scaling_limiter > 1.0:
                raise ValueError(
                    "`{}` must be a `float` between 0.0 and 1.0.".format(
                        "covariance_scaling_limiter"))
            if covariance_burnin < 0:
                raise ValueError(
                    "`{}` must be a `integer` greater or equal to 0.".format(
                        "covariance_burnin"))
            if target_accept_ratio <= 0.0 or target_accept_ratio > 1.0:
                raise ValueError(
                    "`{}` must be a `float` between 0.0 and 1.0.".format(
                        "target_accept_ratio"))
            if pu < 0.0 or pu > 1.0:
                raise ValueError(
                    "`{}` must be a `float` between 0.0 and 1.0.".format("pu"))
            if fixed_variance < 0.0:
                raise ValueError(
                    "`{}` must be a `float` greater than 0.0.".format(
                        "fixed_variance"))

        if mcmc_util.is_list_like(initial_state):
            initial_state_parts = list(initial_state)
        else:
            initial_state_parts = [initial_state]
        initial_state_parts = [
            tf.convert_to_tensor(s, name="initial_state")
            for s in initial_state_parts
        ]

        shape = tf.stack(initial_state_parts).shape
        dtype = dtype_util.base_dtype(tf.stack(initial_state_parts).dtype)

        if initial_covariance is None:
            initial_covariance = 0.001 * tf.eye(
                num_rows=shape[-1], dtype=dtype, batch_shape=[shape[0]])
        else:
            initial_covariance = tf.stack(initial_covariance)

        if mcmc_util.is_list_like(initial_covariance):
            initial_covariance_parts = list(initial_covariance)
        else:
            initial_covariance_parts = [initial_covariance]
        initial_covariance_parts = [
            tf.convert_to_tensor(s, name="initial_covariance")
            for s in initial_covariance_parts
        ]

        self._running_covar = stats.RunningCovariance(shape=(1, shape[-1]),
                                                      dtype=dtype,
                                                      event_ndims=1)
        self._accum_covar = self._running_covar.initialize()

        probs = tf.expand_dims(tf.ones([shape[0]], dtype=dtype) * pu, axis=1)
        self._u = Bernoulli(probs=probs, dtype=tf.dtypes.int32)
        self._initial_u = tf.zeros_like(self._u.sample(seed=seed),
                                        dtype=tf.dtypes.int32)

        name = mcmc_util.make_name(name,
                                   "AdaptiveRandomWalkMetropolisHastings", "")
        seed_stream = SeedStream(seed,
                                 salt="AdaptiveRandomWalkMetropolisHastings")

        self._parameters = dict(
            target_log_prob_fn=target_log_prob_fn,
            initial_state=initial_state,
            initial_covariance=initial_covariance,
            initial_covariance_scaling=initial_covariance_scaling,
            covariance_scaling_reducer=covariance_scaling_reducer,
            covariance_scaling_limiter=covariance_scaling_limiter,
            covariance_burnin=covariance_burnin,
            target_accept_ratio=target_accept_ratio,
            pu=pu,
            fixed_variance=fixed_variance,
            extra_getter_fn=extra_getter_fn,
            extra_setter_fn=extra_setter_fn,
            log_accept_prob_getter_fn=log_accept_prob_getter_fn,
            seed=seed,
            name=name,
        )
        self._impl = metropolis_hastings.MetropolisHastings(
            inner_kernel=random_walk_metropolis.UncalibratedRandomWalk(
                target_log_prob_fn=target_log_prob_fn,
                new_state_fn=random_walk_mvnorm_fn(
                    covariance=initial_covariance_parts,
                    pu=pu,
                    fixed_variance=fixed_variance,
                    is_adaptive=self._initial_u,
                    name=name,
                ),
                name=name,
            ),
            name=name,
        )
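# A quick numeric illustration (plain Python, hypothetical d) of the scaling
# rule described in the docstring above: the running covariance is multiplied
# by covariance_scaling / d.
d = 10                                  # number of parameters
initial_covariance_scaling = 2.38**2
proposal_scale = initial_covariance_scaling / d  # ~0.566, the classic 2.38**2 / d rule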
Example #29
    def __init__(self,
                 target_log_prob_fn,
                 step_size,
                 max_tree_depth=10,
                 max_energy_diff=1000.,
                 unrolled_leapfrog_steps=1,
                 parallel_iterations=10,
                 seed=None,
                 name=None):
        """Initializes this transition kernel.

    Args:
      target_log_prob_fn: Python callable which takes an argument like
        `current_state` (or `*current_state` if it's a list) and returns its
        (possibly unnormalized) log-density under the target distribution.
      step_size: `Tensor` or Python `list` of `Tensor`s representing the step
        size for the leapfrog integrator. Must broadcast with the shape of
        `current_state`. Larger step sizes lead to faster progress, but
        too-large step sizes make rejection exponentially more likely. When
        possible, it's often helpful to match per-variable step sizes to the
        standard deviations of the target distribution in each variable.
      max_tree_depth: Maximum depth of the tree implicitly built by NUTS. The
        maximum number of leapfrog steps is bounded by `2**max_tree_depth` i.e.
        the number of nodes in a binary tree `max_tree_depth` nodes deep. The
        default setting of 10 takes up to 1024 leapfrog steps.
      max_energy_diff: Scalar threshold of energy differences at each leapfrog;
        divergence samples are defined as leapfrog steps that exceed this
        threshold. Defaults to 1000.
      unrolled_leapfrog_steps: The number of leapfrogs to unroll per tree
        expansion step. Applies a direct linear multiplier to the maximum
        trajectory length implied by max_tree_depth. Defaults to 1.
      parallel_iterations: The number of iterations allowed to run in parallel.
        It must be a positive integer. See `tf.while_loop` for more details.
        Note that if you set the seed to have deterministic output you should
        also set `parallel_iterations` to 1.
      seed: Python integer to seed the random number generator.
      name: Python `str` name prefixed to Ops created by this function.
        Default value: `None` (i.e., 'nuts_kernel').
    """
        with tf.name_scope(name or 'NoUTurnSampler') as name:
            # Process `max_tree_depth` argument.
            max_tree_depth = tf.get_static_value(max_tree_depth)
            if max_tree_depth is None or max_tree_depth < 1:
                raise ValueError(
                    'max_tree_depth must be known statically and >= 1 but was '
                    '{}'.format(max_tree_depth))
            self._max_tree_depth = max_tree_depth

            # Compute parameters derived from `max_tree_depth`.
            instruction_array = build_tree_uturn_instruction(max_tree_depth,
                                                             init_memory=-1)
            [write_instruction_numpy, read_instruction_numpy
             ] = generate_efficient_write_read_instruction(instruction_array)

            # TensorArray version of the read/write instruction need to be created
            # within the function call to be compatible with XLA. Here we store the
            # numpy version of the instruction and convert it to TensorArray later.
            self._write_instruction = write_instruction_numpy
            self._read_instruction = read_instruction_numpy

            # Process all other arguments.
            self._target_log_prob_fn = target_log_prob_fn
            if not tf.nest.is_nested(step_size):
                step_size = [step_size]
            step_size = [
                tf.convert_to_tensor(s, dtype_hint=tf.float32)
                for s in step_size
            ]
            self._step_size = step_size

            self._parameters = dict(
                target_log_prob_fn=target_log_prob_fn,
                step_size=step_size,
                max_tree_depth=max_tree_depth,
                max_energy_diff=max_energy_diff,
                unrolled_leapfrog_steps=unrolled_leapfrog_steps,
                parallel_iterations=parallel_iterations,
                seed=seed,
                name=name,
            )
            self._parallel_iterations = parallel_iterations
            self._seed_stream = SeedStream(seed, salt='nuts_one_step')
            self._unrolled_leapfrog_steps = unrolled_leapfrog_steps
            self._name = name
            self._max_energy_diff = max_energy_diff
Example #30
def sample_halton_sequence(dim,
                           num_results=None,
                           sequence_indices=None,
                           dtype=tf.float32,
                           randomized=True,
                           seed=None,
                           name=None):
    r"""Returns a sample from the `dim` dimensional Halton sequence.

  Warning: The sequence elements take values only between 0 and 1. Care must be
  taken to appropriately transform the domain of a function if it differs from
  the unit cube before evaluating integrals using Halton samples. It is also
  important to remember that quasi-random numbers without randomization are not
  a replacement for pseudo-random numbers in every context. Quasi random numbers
  are completely deterministic and typically have significant negative
  autocorrelation unless randomization is used.

  Computes the members of the low discrepancy Halton sequence in dimension
  `dim`. The `dim`-dimensional sequence takes values in the unit hypercube in
  `dim` dimensions. Currently, only dimensions up to 1000 are supported. The
  prime base for the k-th axis is the k-th prime starting from 2. For example,
  if `dim` = 3, then the bases will be [2, 3, 5] respectively and the first
  element of the non-randomized sequence will be: [0.5, 0.333, 0.2]. For a more
  complete description of the Halton sequences see
  [here](https://en.wikipedia.org/wiki/Halton_sequence). For low discrepancy
  sequences and their applications see
  [here](https://en.wikipedia.org/wiki/Low-discrepancy_sequence).

  If `randomized` is true, this function produces a scrambled version of the
  Halton sequence introduced by [Owen (2017)][1]. For the advantages of
  randomization of low discrepancy sequences see [here](
  https://en.wikipedia.org/wiki/Quasi-Monte_Carlo_method#Randomization_of_quasi-Monte_Carlo).

  The number of samples produced is controlled by the `num_results` and
  `sequence_indices` parameters. The user must supply either `num_results` or
  `sequence_indices` but not both.
  The former is the number of samples to produce starting from the first
  element. If `sequence_indices` is given instead, the specified elements of
  the sequence are generated. For example, sequence_indices=tf.range(10) is
  equivalent to specifying num_results=10.

  #### Examples

  ```python
  import tensorflow as tf
  import tensorflow_probability as tfp

  # Produce the first 1000 members of the Halton sequence in 3 dimensions.
  num_results = 1000
  dim = 3
  sample = tfp.mcmc.sample_halton_sequence(
    dim,
    num_results=num_results,
    seed=127)

  # Evaluate the integral of x_1 * x_2^2 * x_3^3  over the three dimensional
  # hypercube.
  powers = tf.range(1.0, limit=dim + 1)
  integral = tf.reduce_mean(tf.reduce_prod(sample ** powers, axis=-1))
  true_value = 1.0 / tf.reduce_prod(powers + 1.0)

  # Produces a relative absolute error of 1.7%.
  print("Estimated: %f, True Value: %f" % (integral, true_value))

  # Now skip the first 1000 samples and recompute the integral with the next
  # thousand samples. The sequence_indices argument can be used to do this.


  sequence_indices = tf.range(start=1000, limit=1000 + num_results,
                              dtype=tf.int32)
  sample_leaped = tfp.mcmc.sample_halton_sequence(
      dim,
      sequence_indices=sequence_indices,
      seed=111217)

  integral_leaped = tf.reduce_mean(tf.reduce_prod(sample_leaped ** powers,
                                                  axis=-1))
  # Now produces a relative absolute error of 0.05%.
  print("Leaped Estimated: %f, True Value: %f"
        % (integral_leaped, true_value))
  ```
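
  The warning above means Halton points must be rescaled when integrating over
  a domain other than the unit cube: map the samples into the target box and
  multiply the estimate by the box volume. A minimal sketch (assuming eager
  execution), reusing the imports above:

  ```python
  # Integrate f(x) = x_1 + x_2 over the box [1, 2] x [3, 5] (volume 2).
  low = tf.constant([1., 3.])
  high = tf.constant([2., 5.])
  unit_sample = tfp.mcmc.sample_halton_sequence(2, num_results=1000, seed=127)
  scaled_sample = low + (high - low) * unit_sample  # map [0, 1]^2 to the box
  volume = tf.reduce_prod(high - low)
  integral_box = volume * tf.reduce_mean(
      tf.reduce_sum(scaled_sample, axis=-1))
  # Exact value: 2 * (1.5 + 4.0) = 11.
  print("Estimated: %f, True Value: 11" % integral_box)
  ```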

  Args:
    dim: Positive Python `int` representing each sample's `event_size`. Must
      not be greater than 1000.
    num_results: (Optional) Positive scalar `Tensor` of dtype int32. The number
      of samples to generate. Either this parameter or sequence_indices must
      be specified but not both. If this parameter is None, then the behaviour
      is determined by the `sequence_indices`.
      Default value: `None`.
    sequence_indices: (Optional) `Tensor` of dtype int32 and rank 1. The
      elements of the sequence to compute specified by their position in the
      sequence. The entries index into the Halton sequence starting with 0 and
      hence, must be whole numbers. For example, sequence_indices=[0, 5, 6] will
      produce the first, sixth and seventh elements of the sequence. If this
      parameter is None, then the `num_results` parameter must be specified
      which gives the number of desired samples starting from the first sample.
      Default value: `None`.
    dtype: (Optional) The dtype of the sample. One of: `float16`, `float32` or
      `float64`.
      Default value: `tf.float32`.
    randomized: (Optional) bool indicating whether to produce a randomized
      Halton sequence. If True, applies the randomization described in
      [Owen (2017)][1].
      Default value: `True`.
    seed: (Optional) Python integer to seed the random number generator. Only
      used if `randomized` is True. If not supplied and `randomized` is True,
      no seed is set.
      Default value: `None`.
    name: (Optional) Python `str` describing ops managed by this function. If
      not supplied, the name of this function is used.
      Default value: "sample_halton_sequence".

  Returns:
    halton_elements: Elements of the Halton sequence. `Tensor` of supplied dtype
      and `shape` `[num_results, dim]` if `num_results` was specified or shape
      `[s, dim]` where s is the size of `sequence_indices` if `sequence_indices`
      were specified.

  Raises:
    ValueError: if both `sequence_indices` and `num_results` were specified or
      if dimension `dim` is less than 1 or greater than 1000.

  #### References

  [1]: Art B. Owen. A randomized Halton algorithm in R. _arXiv preprint
       arXiv:1706.02808_, 2017. https://arxiv.org/abs/1706.02808
  """
    if dim < 1 or dim > _MAX_DIMENSION:
        raise ValueError(
            'Dimension must be between 1 and {}. Supplied {}'.format(
                _MAX_DIMENSION, dim))
    if (num_results is None) == (sequence_indices is None):
        raise ValueError('Either `num_results` or `sequence_indices` must be'
                         ' specified but not both.')

    if not dtype.is_floating:
        raise ValueError('dtype must be of `float`-type')

    with tf.name_scope(name or 'sample_halton_sequence'):
        # Here and in the following, the shape layout is as follows:
        # [sample dimension, event dimension, coefficient dimension].
        # The coefficient dimension is an intermediate axis which will hold the
        # weights of the starting integer when expressed in the (prime) base for
        # an event dimension.
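        # For example, `weights` below has shape [dim, max_size] and broadcasts
        # against the per-sample `indices`, giving coefficients of shape
        # [num samples, dim, max_size]; the final reduce_sum over the last axis
        # collapses this back to the returned [num samples, dim].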
        if num_results is not None:
            num_results = tf.convert_to_tensor(num_results)
        if sequence_indices is not None:
            sequence_indices = tf.convert_to_tensor(sequence_indices)
        indices = _get_indices(num_results, sequence_indices, dtype)
        radixes = tf.constant(_PRIMES[0:dim], dtype=dtype, shape=[dim, 1])

        max_sizes_by_axes = _base_expansion_size(tf.reduce_max(indices),
                                                 radixes)

        max_size = tf.reduce_max(max_sizes_by_axes)

        # The powers of the radixes that we will need. Note that there is a bit
        # of an excess here. Suppose we need the place value coefficients of 7
        # in base 2 and 3. For 2, we will have 3 digits but we only need 2 digits
        # for base 3. However, we can only create rectangular tensors so we
        # store both expansions in a [2, 3] tensor. This leads to the problem that
        # we might end up attempting to raise large numbers to large powers. For
        # example, the base 2 expansion of 1024 has 11 digits. If we were in 10
        # dimensions, then for the 10th prime (29) we would end up computing
        # 29^10 even though we don't need it. We avoid this by setting the
        # exponents for each axis to 0 beyond the maximum needed for that
        # dimension.
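        # For instance, with index values up to 7 and dim = 2 (bases 2 and 3),
        # max_sizes_by_axes is [[3], [2]] and max_size is 3: the base-3 row of
        # the exponent grid [[0, 1, 2], [0, 1, 2]] has its last exponent capped
        # to 0, so we compute 3**0 rather than an unneeded 3**2, and weight_mask
        # removes that coefficient's contribution further down.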
        exponents_by_axes = tf.tile([tf.range(max_size)], [dim, 1])

        # The mask is True for the coefficients that are relevant for an axis
        # (i.e. within that axis' digit count) and False for the excess ones.
        weight_mask = exponents_by_axes < max_sizes_by_axes
        capped_exponents = tf.where(weight_mask, exponents_by_axes,
                                    tf.constant(0, exponents_by_axes.dtype))
        weights = radixes**capped_exponents
        # The following computes the base b expansion of the indices. Suppose,
        # x = a0 + a1*b + a2*b^2 + ... Then, performing a floor div of x with
        # the vector (1, b, b^2, b^3, ...) will produce
        # (a0 + s1 * b, a1 + s2 * b, ...) where s_i are coefficients we don't care
        # about. Noting that all a_i < b by definition of place value expansion,
        # we see that taking the elements mod b of the above vector produces the
        # place value expansion coefficients.
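        # Concretely, for x = 6 in base b = 2 the weights are (1, 2, 4):
        # floordiv(6, (1, 2, 4)) = (6, 3, 1), and taking each entry mod 2 gives
        # the digits (0, 1, 1), i.e. 6 = 0*1 + 1*2 + 1*4.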
        coeffs = tf.math.floordiv(indices, weights)
        coeffs *= tf.cast(weight_mask, dtype)
        coeffs %= radixes
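        # Each row of `coeffs` now holds the place value digits a_i of an index
        # (scrambled below when `randomized` is True); dividing by `radixes` and
        # `weights` then computes the radical inverse sum_i a_i / b**(i + 1).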
        if not randomized:
            coeffs /= radixes
            return tf.reduce_sum(coeffs / weights, axis=-1)
        stream = SeedStream(seed, salt='MCMCSampleHaltonSequence')
        coeffs = _randomize(coeffs, radixes, seed=stream())
        # Remove the contribution from randomizing the trailing zeros for the
        # axes where max_sizes_by_axes < max_size. This will be accounted
        # for separately below (using zero_correction).
        coeffs *= tf.cast(weight_mask, dtype)
        coeffs /= radixes
        base_values = tf.reduce_sum(coeffs / weights, axis=-1)

        # The randomization used in Owen (2017) does not leave 0 invariant. While
        # we have accounted for the randomization of the first `max_sizes_by_axes`
        # coefficients, we still need to correct for the trailing zeros. Luckily,
        # this is equivalent to adding a uniform random value scaled so that its
        # first `max_sizes_by_axes` base-b digits are zero. The following
        # statements perform this correction.
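        # E.g. for an axis with base 2 and a max_sizes_by_axes entry of 3, the
        # correction below is uniform on [0, 2**-3), so it only perturbs binary
        # digits beyond the three that were randomized explicitly above.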
        zero_correction = tf.random.uniform([dim, 1],
                                            seed=stream(),
                                            dtype=dtype)
        zero_correction /= radixes**max_sizes_by_axes
        return base_values + tf.reshape(zero_correction, [-1])
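
As a quick check of the documented non-randomized behaviour (a minimal sketch,
assuming eager execution): with `randomized=False` the first element of the
3-dimensional sequence should be approximately `[0.5, 0.333, 0.2]`, matching
the bases `[2, 3, 5]`.

```python
import tensorflow_probability as tfp

first = tfp.mcmc.sample_halton_sequence(3, num_results=1, randomized=False)
print(first.numpy())  # approximately [[0.5, 0.33333334, 0.2]]
```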