示例#1
0
def sample_chain(
    num_results,
    current_state,
    previous_kernel_results=None,
    kernel=None,
    num_burnin_steps=0,
    num_steps_between_results=0,
    trace_fn=_trace_kernel_results,
    return_final_kernel_results=False,
    parallel_iterations=10,
    seed=None,
    name=None,
):
    """Implements Markov chain Monte Carlo via repeated `TransitionKernel` steps.

  This function samples from a Markov chain at `current_state` whose
  stationary distribution is governed by the supplied `TransitionKernel`
  instance (`kernel`).

  This function can sample from multiple chains, in parallel. (Whether or not
  there are multiple chains is dictated by the `kernel`.)

  The `current_state` can be represented as a single `Tensor` or a `list` of
  `Tensors` which collectively represent the current state.

  Since MCMC states are correlated, it is sometimes desirable to produce
  additional intermediate states, and then discard them, ending up with a set of
  states with decreased autocorrelation.  See [Owen (2017)][1]. Such 'thinning'
  is made possible by setting `num_steps_between_results > 0`. The chain then
  takes `num_steps_between_results` extra steps between the steps that make it
  into the results. The extra steps are never materialized, and thus do not
  increase memory requirements.

  In addition to returning the chain state, this function supports tracing of
  auxiliary variables used by the kernel. The traced values are selected by
  specifying `trace_fn`. By default, all kernel results are traced but in the
  future the default will be changed to no results being traced, so plan
  accordingly. Relying on the current default emits a deprecation warning.

  Args:
    num_results: Integer number of Markov chain draws.
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s).
    previous_kernel_results: A `Tensor` or a nested collection of `Tensor`s
      representing internal calculations made within the previous call to this
      function (or as returned by `bootstrap_results`).
    kernel: An instance of `tfp.mcmc.TransitionKernel` which implements one step
      of the Markov chain. Although it defaults to `None`, a kernel must be
      supplied: its `is_calibrated` attribute is read unconditionally, and a
      warning is emitted when it is `False`.
    num_burnin_steps: Integer number of chain steps to take before starting to
      collect results.
      Default value: 0 (i.e., no burn-in).
    num_steps_between_results: Integer number of chain steps between collecting
      a result. Only one out of every `num_steps_between_results + 1` steps is
      included in the returned results.  The number of returned chain states is
      still equal to `num_results`.  Default value: 0 (i.e., no thinning).
    trace_fn: A callable that takes in the current chain state and the previous
      kernel results and returns a `Tensor` or a nested collection of `Tensor`s
      that is then traced along with the chain state. If `None`, nothing is
      traced and (unless `return_final_kernel_results` is `True`) only the
      chain states are returned.
    return_final_kernel_results: If `True`, then the final kernel results are
      returned alongside the chain state and the trace specified by the
      `trace_fn`.
    parallel_iterations: The number of iterations allowed to run in parallel. It
      must be a positive integer. See `tf.while_loop` for more details.
    seed: Optional, a seed for reproducible sampling.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'experimental_mcmc_sample_chain').

  Returns:
    checkpointable_states_and_trace: if `return_final_kernel_results` is
      `True`. The return value is an instance of
      `CheckpointableStatesAndTrace`.
    all_states: if `return_final_kernel_results` is `False` and `trace_fn` is
      `None`. The return value is a `Tensor` or Python list of `Tensor`s
      representing the state(s) of the Markov chain(s) at each result step. Has
      same shape as input `current_state` but with a prepended
      `num_results`-size dimension.
    states_and_trace: if `return_final_kernel_results` is `False` and
      `trace_fn` is not `None`. The return value is an instance of
      `StatesAndTrace`.

  #### References

  [1]: Art B. Owen. Statistically efficient thinning of a Markov chain sampler.
       _Technical Report_, 2017.
       http://statweb.stanford.edu/~owen/reports/bestthinning.pdf
  """
    with tf.name_scope(name or 'experimental_mcmc_sample_chain'):
        if not kernel.is_calibrated:
            warnings.warn(
                'supplied `TransitionKernel` is not calibrated. Markov '
                'chain may not converge to intended target distribution.')

        if trace_fn is None:
            trace_fn = lambda *args: ()
            no_trace = True
        else:
            no_trace = False

        # Compare against the default by identity. Looking it up positionally
        # via `sample_chain.__defaults__[4]` silently breaks whenever the
        # signature's parameter order changes.
        if trace_fn is _trace_kernel_results:
            warnings.warn(
                'Tracing all kernel results by default is deprecated. Set '
                'the `trace_fn` argument to None (the future default '
                'value) or an explicit callback that traces the values '
                'you are interested in.')

        # `WithReductions` assumes all its reducers want to reduce over the
        # immediate inner results of its kernel results. However,
        # we don't care about the kernel results of `SampleDiscardingKernel`;
        # hence, we evaluate the `trace_fn` on a deeper level of inner results.
        def real_trace_fn(curr_state, kr):
            return curr_state, trace_fn(curr_state, kr.inner_results)

        trace_reducer = tracing_reducer.TracingReducer(trace_fn=real_trace_fn,
                                                       size=num_results)
        # pylint: disable=unbalanced-tuple-unpacking
        trace_results, _, final_kernel_results = sample_fold(
            num_steps=num_results,
            current_state=current_state,
            previous_kernel_results=previous_kernel_results,
            kernel=kernel,
            reducer=trace_reducer,
            num_burnin_steps=num_burnin_steps,
            num_steps_between_results=num_steps_between_results,
            parallel_iterations=parallel_iterations,
            seed=seed,
            name=name,
        )

        # The tracing reducer yields (states, user trace) pairs; see
        # `real_trace_fn` above.
        all_states, trace = trace_results
        if return_final_kernel_results:
            return sample.CheckpointableStatesAndTrace(
                all_states=all_states,
                trace=trace,
                final_kernel_results=final_kernel_results)
        if no_trace:
            return all_states
        return sample.StatesAndTrace(all_states=all_states, trace=trace)
def _windowed_adaptive_impl(n_draws, joint_dist, *, kind, n_chains,
                            proposal_kernel_kwargs, num_adaptation_steps,
                            current_state, dual_averaging_kwargs, trace_fn,
                            return_final_kernel_results, discard_tuning, seed,
                            chain_axis_names, **pins):
    """Runs windowed sampling using either HMC or NUTS as internal sampler.

    All adaptation and sampling is delegated to a single `_do_sampling` call.
    With `discard_tuning=True`, the `num_adaptation_steps` adaptation steps run
    as burn-in and only `n_draws` draws are kept. Otherwise
    `n_draws + num_adaptation_steps` draws are returned.

    Args:
      n_draws: Number of post-adaptation draws.
      joint_dist: Joint distribution to sample from; conditioned on `pins` by
        `_setup_mcmc`.
      kind: Sampler kind forwarded to `_do_sampling` (presumably 'hmc' or
        'nuts' -- confirm against `_do_sampling`).
      n_chains: Number of chains. A Python `int` is wrapped into a one-element
        list; other values are used as a shape-like prefix of the momentum
        batch shape.
      proposal_kernel_kwargs: Keyword arguments for the proposal kernel. If the
        `'step_size'` entry is missing or `None`, an initial step size is
        chosen by `_get_step_size`. This mapping is copied, never modified.
      num_adaptation_steps: Number of adaptation steps. It is converted to a
        shape Tensor unless a graph is being built inside an XLA context.
      current_state: Optional initial position, forwarded to `_setup_mcmc`.
      dual_averaging_kwargs: Keyword arguments for dual-averaging step-size
        adaptation, or `None` for no overrides. A user-supplied
        `'num_adaptation_steps'` entry is overridden, with a warning. The
        mapping is copied, never modified.
      trace_fn: Trace callback forwarded to `_do_sampling`, or `None` to trace
        nothing and return only the draws.
      return_final_kernel_results: If `True`, also return the final kernel
        results.
      discard_tuning: Whether to drop the adaptation draws (see above).
      seed: Seed for reproducible sampling.
      chain_axis_names: Named axes that step-size adaptation reduces over by
        default.
      **pins: Values used to condition `joint_dist`.

    Returns:
      A `CheckpointableStatesAndTrace` if `return_final_kernel_results` is
      `True`. Otherwise, the draws alone if `trace_fn` is `None`, or else a
      `StatesAndTrace`. Draws are mapped back through `bijector.inverse`.

    Raises:
      ValueError: If no step size is given and the target has a non-scalar
        batch shape.
    """
    if trace_fn is None:
        trace_fn = lambda *args: ()
        no_trace = True
    else:
        no_trace = False

    # Work on shallow copies so the caller's keyword dictionaries are never
    # mutated by the `update` / `setdefault` / item assignments below.
    proposal_kernel_kwargs = dict(proposal_kernel_kwargs)
    dual_averaging_kwargs = dict(dual_averaging_kwargs or {})

    if isinstance(n_chains, int):
        n_chains = [n_chains]

    if (tf.executing_eagerly()
            or not control_flow_util.GraphOrParentsInXlaContext(
                tf1.get_default_graph())):
        # A Tensor num_draws argument breaks XLA, which requires static TensorArray
        # trace_fn result allocation sizes.
        num_adaptation_steps = ps.convert_to_shape_tensor(num_adaptation_steps)

    if 'num_adaptation_steps' in dual_averaging_kwargs:
        warnings.warn(
            'Dual averaging adaptation will use the value specified in'
            ' the `num_adaptation_steps` argument for its construction,'
            ' hence there is no need to specify it in the'
            ' `dual_averaging_kwargs` argument.')

    # TODO(b/180011931): if num_adaptation_steps is small, this throws an error.
    dual_averaging_kwargs['num_adaptation_steps'] = num_adaptation_steps
    dual_averaging_kwargs.setdefault(
        'reduce_fn',
        functools.partial(
            generic_math.reduce_log_harmonic_mean_exp,
            # There is only one log_accept_prob per chain, and we reduce across
            # all chains, so typically the all_gather will be gathering scalars,
            # which should be relatively efficient.
            experimental_allow_all_gather=True))
    # By default, reduce over named axes for step size adaptation
    dual_averaging_kwargs.setdefault('experimental_reduce_chain_axis_names',
                                     chain_axis_names)
    setup_seed, sample_seed = samplers.split_seed(samplers.sanitize_seed(seed),
                                                  n=2)
    (target_log_prob_fn, initial_transformed_position, bijector,
     step_broadcast, batch_shape,
     shard_axis_names) = _setup_mcmc(joint_dist,
                                     n_chains=n_chains,
                                     init_position=current_state,
                                     seed=setup_seed,
                                     **pins)

    if proposal_kernel_kwargs.get('step_size') is None:
        if batch_shape.shape != (0, ):  # Scalar batch has a 0-vector shape.
            raise ValueError(
                'Batch target density must specify init_step_size. Got '
                f'batch shape {batch_shape} from joint {joint_dist}.')

        init_step_size = _get_step_size(initial_transformed_position,
                                        target_log_prob_fn)

    else:
        init_step_size = step_broadcast(proposal_kernel_kwargs['step_size'])

    proposal_kernel_kwargs.update({
        'target_log_prob_fn':
        target_log_prob_fn,
        'step_size':
        init_step_size,
        'momentum_distribution':
        _init_momentum(initial_transformed_position,
                       batch_shape=ps.concat([n_chains, batch_shape], axis=0),
                       shard_axis_names=shard_axis_names)
    })

    # Running-variance accumulators start from zero samples, zero mean and
    # unit variance for every part of the (transformed) state.
    initial_running_variance = [
        sample_stats.RunningVariance.from_stats(  # pylint: disable=g-complex-comprehension
            num_samples=tf.zeros([], part.dtype),
            mean=tf.zeros_like(part),
            variance=tf.ones_like(part))
        for part in initial_transformed_position
    ]
    # TODO(phandu): Consider splitting out warmup and post warmup phases
    # to avoid executing adaptation code during the post warmup phase.
    ret = _do_sampling(
        kind=kind,
        proposal_kernel_kwargs=proposal_kernel_kwargs,
        dual_averaging_kwargs=dual_averaging_kwargs,
        num_draws=n_draws if discard_tuning else n_draws +
        num_adaptation_steps,
        num_burnin_steps=num_adaptation_steps if discard_tuning else 0,
        initial_position=initial_transformed_position,
        initial_running_variance=initial_running_variance,
        bijector=bijector,
        trace_fn=trace_fn,
        return_final_kernel_results=return_final_kernel_results,
        chain_axis_names=chain_axis_names,
        shard_axis_names=shard_axis_names,
        seed=sample_seed)

    if return_final_kernel_results:
        draws, trace, fkr = ret
        return sample.CheckpointableStatesAndTrace(
            all_states=bijector.inverse(draws),
            trace=trace,
            final_kernel_results=fkr)
    else:
        draws, trace = ret
        if no_trace:
            return bijector.inverse(draws)
        else:
            return sample.StatesAndTrace(all_states=bijector.inverse(draws),
                                         trace=trace)
示例#3
0
def _windowed_adaptive_impl(n_draws,
                            joint_dist,
                            *,
                            kind,
                            n_chains,
                            proposal_kernel_kwargs,
                            num_adaptation_steps,
                            dual_averaging_kwargs,
                            trace_fn,
                            return_final_kernel_results,
                            discard_tuning,
                            seed,
                            **pins):
  """Runs windowed sampling using either HMC or NUTS as internal sampler.

  Adaptation proceeds through three stages, with window sizes from
  `_get_window_sizes`: one fast window, then four slow windows whose length
  doubles each time, then a closing fast window. `_do_sampling` then takes
  `n_draws` further steps with the adapted settings. The step size is carried
  between stages by updating `proposal_kernel_kwargs` in place, as is the
  momentum distribution after each slow window.

  Args:
    n_draws: Number of draws taken by the final `_do_sampling` call.
    joint_dist: Joint distribution to sample from; conditioned on `pins` by
      `_setup_mcmc`.
    kind: Sampler kind forwarded to the window helpers and `_do_sampling`
      (presumably 'hmc' or 'nuts' -- confirm against those helpers).
    n_chains: Number of chains; the initial step size has shape
      `[n_chains, 1]`.
    proposal_kernel_kwargs: Keyword arguments for the proposal kernel. NOTE:
      this dict is mutated in place.
    num_adaptation_steps: Total number of adaptation steps; converted to a
      Tensor and split into windows by `_get_window_sizes`.
    dual_averaging_kwargs: Keyword arguments forwarded to the window helpers.
    trace_fn: Trace callback, or `None` to trace nothing.
    return_final_kernel_results: If `True`, also return the final kernel
      results.
    discard_tuning: If `True`, return only the post-adaptation draws and
      trace. Otherwise, the draws and trace of every adaptation window are
      prepended.
    seed: Seed for reproducible sampling.
    **pins: Values used to condition `joint_dist`.

  Returns:
    A `CheckpointableStatesAndTrace` if `return_final_kernel_results` is
    `True`. Otherwise, the bare draws if `trace_fn` is `None`, or else a
    `StatesAndTrace`. Draws are mapped back through `bijector.inverse`.
  """
  no_trace = trace_fn is None
  if no_trace:
    trace_fn = lambda *args: ()

  num_adaptation_steps = tf.convert_to_tensor(num_adaptation_steps)

  setup_seed, init_seed, seed = samplers.split_seed(
      samplers.sanitize_seed(seed), n=3)
  target_log_prob_fn, initial_transformed_position, bijector = _setup_mcmc(
      joint_dist, n_chains=n_chains, seed=setup_seed, **pins)

  head_size, slow_base_size, tail_size = _get_window_sizes(
      num_adaptation_steps)

  # If we (over) optimistically assume good scaling, d**-0.25 is near the
  # optimal step size; see Langmore, Ian, Michael Dikovsky, Scott Geraedts,
  # Peter Norgaard, and Rob Von Behren. 2019. "A Condition Number for
  # Hamiltonian Monte Carlo." arXiv [stat.CO]. http://arxiv.org/abs/1905.09813.
  event_size = tf.cast(ps.shape(initial_transformed_position)[-1], tf.float32)
  init_step_size = event_size ** -0.25

  proposal_kernel_kwargs.update({
      'target_log_prob_fn': target_log_prob_fn,
      'step_size': tf.fill([n_chains, 1], init_step_size),
      'momentum_distribution': _init_momentum(initial_transformed_position),
  })

  # Per-window draws and traces, in chronological order.
  draw_chunks = []
  trace_chunks = []

  # Opening fast window; only its step size is carried forward.
  window_draws, window_trace, step_size, running_variance = _fast_window(
      kind=kind,
      proposal_kernel_kwargs=proposal_kernel_kwargs,
      dual_averaging_kwargs=dual_averaging_kwargs,
      num_draws=head_size,
      initial_position=initial_transformed_position,
      bijector=bijector,
      trace_fn=trace_fn,
      seed=init_seed)
  proposal_kernel_kwargs['step_size'] = step_size
  draw_chunks.append(window_draws)
  trace_chunks.append(window_trace)

  # Four slow windows of doubling length; one seed is kept for later stages.
  *slow_seeds, seed = samplers.split_seed(seed, n=5)
  for idx, slow_seed in enumerate(slow_seeds):
    # TODO(b/180011931): if num_adaptation_steps is small, this throws an error.
    (window_draws, window_trace, step_size, running_variance,
     momentum_distribution) = _slow_window(
         kind=kind,
         proposal_kernel_kwargs=proposal_kernel_kwargs,
         dual_averaging_kwargs=dual_averaging_kwargs,
         num_draws=slow_base_size * 2**idx,
         initial_position=window_draws[-1],
         initial_running_variance=running_variance,
         bijector=bijector,
         trace_fn=trace_fn,
         seed=slow_seed)
    draw_chunks.append(window_draws)
    trace_chunks.append(window_trace)
    proposal_kernel_kwargs['step_size'] = step_size
    proposal_kernel_kwargs['momentum_distribution'] = momentum_distribution

  # Closing fast window, starting from the last slow-window draw.
  fast_seed, sample_seed = samplers.split_seed(seed)
  window_draws, window_trace, step_size, _ = _fast_window(
      kind=kind,
      proposal_kernel_kwargs=proposal_kernel_kwargs,
      dual_averaging_kwargs=dual_averaging_kwargs,
      num_draws=tail_size,
      initial_position=window_draws[-1],
      bijector=bijector,
      trace_fn=trace_fn,
      seed=fast_seed)
  proposal_kernel_kwargs['step_size'] = step_size
  draw_chunks.append(window_draws)
  trace_chunks.append(window_trace)

  ret = _do_sampling(
      kind=kind,
      proposal_kernel_kwargs=proposal_kernel_kwargs,
      num_draws=n_draws,
      initial_position=window_draws[-1],
      bijector=bijector,
      trace_fn=trace_fn,
      return_final_kernel_results=return_final_kernel_results,
      seed=sample_seed)

  if return_final_kernel_results:
    final_draws, final_trace, final_kernel_results = ret
  else:
    final_draws, final_trace = ret

  if discard_tuning:
    states = bijector.inverse(final_draws)
    if return_final_kernel_results:
      return sample.CheckpointableStatesAndTrace(
          all_states=states,
          trace=final_trace,
          final_kernel_results=final_kernel_results)
    if no_trace:
      return states
    return sample.StatesAndTrace(all_states=states, trace=final_trace)

  # Keep the tuning output: append the post-adaptation chunk and concatenate
  # every chunk along axis 0.
  draw_chunks.append(final_draws)
  trace_chunks.append(final_trace)
  states = bijector.inverse(tf.concat(draw_chunks, axis=0))
  if no_trace and not return_final_kernel_results:
    return states
  stacked_trace = tf.nest.map_structure(lambda *s: tf.concat(s, axis=0),
                                        *trace_chunks, expand_composites=True)
  if return_final_kernel_results:
    return sample.CheckpointableStatesAndTrace(
        all_states=states,
        trace=stacked_trace,
        final_kernel_results=final_kernel_results)
  return sample.StatesAndTrace(all_states=states, trace=stacked_trace)
示例#4
0
def _windowed_adaptive_impl(n_draws,
                            joint_dist,
                            *,
                            kind,
                            n_chains,
                            proposal_kernel_kwargs,
                            num_adaptation_steps,
                            current_state,
                            dual_averaging_kwargs,
                            trace_fn,
                            return_final_kernel_results,
                            discard_tuning,
                            seed,
                            **pins):
  """Runs windowed sampling using either HMC or NUTS as internal sampler.

  Adaptation runs in three stages, with window sizes from `_get_window_sizes`:
  one fast window, then four slow windows whose length doubles each time, then
  a closing fast window. `_do_sampling` then takes `n_draws` steps with the
  adapted settings. The step size is carried forward by updating
  `proposal_kernel_kwargs` in place, as is the momentum distribution after each
  slow window. The state is handled as a list of parts: draws are kept and
  concatenated per part.

  Args:
    n_draws: Number of draws taken by the final `_do_sampling` call.
    joint_dist: Joint distribution to sample from; conditioned on `pins` by
      `_setup_mcmc`.
    kind: Sampler kind forwarded to the window helpers (presumably 'hmc' or
      'nuts' -- confirm against those helpers).
    n_chains: Number of chains; prepended to `batch_shape` to form the batch
      shape of the initial momentum distribution.
    proposal_kernel_kwargs: Keyword arguments for the proposal kernel. NOTE:
      this dict is mutated in place. If its `'step_size'` is missing or `None`,
      an initial step size is computed by `_get_step_size`. That is only
      allowed when the target has a scalar batch shape.
    num_adaptation_steps: Total number of adaptation steps. It is converted to
      a Tensor unless a graph is being built inside an XLA context.
    current_state: Optional initial position, forwarded to `_setup_mcmc`.
    dual_averaging_kwargs: Keyword arguments forwarded to the window helpers.
    trace_fn: Trace callback, or `None` to trace nothing.
    return_final_kernel_results: If `True`, also return the final kernel
      results.
    discard_tuning: If `True`, return only the `_do_sampling` draws and trace.
      Otherwise, the draws and trace of every adaptation window are prepended.
    seed: Seed for reproducible sampling.
    **pins: Values used to condition `joint_dist`.

  Returns:
    A `CheckpointableStatesAndTrace` if `return_final_kernel_results` is
    `True`. Otherwise, the draws alone if `trace_fn` is `None`, or else a
    `StatesAndTrace`. Draws are mapped back through `bijector.inverse`.

  Raises:
    ValueError: If no step size is given and the target has a non-scalar
      batch shape.
  """
  if trace_fn is None:
    trace_fn = lambda *args: ()
    no_trace = True
  else:
    no_trace = False

  if (tf.executing_eagerly() or
      not control_flow_util.GraphOrParentsInXlaContext(
          tf1.get_default_graph())):
    # A Tensor num_draws argument breaks XLA, which requires static TensorArray
    # trace_fn result allocation sizes.
    num_adaptation_steps = tf.convert_to_tensor(num_adaptation_steps)

  setup_seed, init_seed, seed = samplers.split_seed(
      samplers.sanitize_seed(seed), n=3)
  (target_log_prob_fn, initial_transformed_position, bijector,
   step_broadcast, batch_shape) = _setup_mcmc(
       joint_dist,
       n_chains=n_chains,
       init_position=current_state,
       seed=setup_seed,
       **pins)

  # Choose the initial step size: computed automatically (scalar batch only)
  # or broadcast from the user-supplied value.
  if proposal_kernel_kwargs.get('step_size') is None:
    if batch_shape.shape != (0,):  # Scalar batch has a 0-vector shape.
      raise ValueError('Batch target density must specify init_step_size. Got '
                       f'batch shape {batch_shape} from joint {joint_dist}.')

    init_step_size = _get_step_size(initial_transformed_position,
                                    target_log_prob_fn)

  else:
    init_step_size = step_broadcast(proposal_kernel_kwargs['step_size'])

  proposal_kernel_kwargs.update({
      'target_log_prob_fn': target_log_prob_fn,
      'step_size': init_step_size,
      'momentum_distribution': _init_momentum(
          initial_transformed_position,
          batch_shape=ps.concat([[n_chains], batch_shape], axis=0))})

  first_window_size, slow_window_size, last_window_size = _get_window_sizes(
      num_adaptation_steps)

  all_traces = []
  # Using tf.function here and on _slow_window_closure caches tracing
  # of _fast_window and _slow_window, respectively, within a single
  # call to windowed sampling.  Why not annotate _fast_window and
  # _slow_window directly?  Two reasons:
  # - Caching across calls to windowed sampling is probably futile,
  #   because the trace function and bijector will be different Python
  #   objects, preventing cache hits.
  # - The cache of a global tf.function sticks around for the lifetime
  #   of the Python process, potentially leaking memory.
  @tf.function(autograph=False)
  def _fast_window_closure(proposal_kernel_kwargs,
                           window_size,
                           initial_position,
                           seed):
    return _fast_window(
        kind=kind,
        proposal_kernel_kwargs=proposal_kernel_kwargs,
        dual_averaging_kwargs=dual_averaging_kwargs,
        num_draws=window_size,
        initial_position=initial_position,
        bijector=bijector,
        trace_fn=trace_fn,
        seed=seed)
  # Opening fast window.
  draws, trace, step_size, running_variances = _fast_window_closure(
      proposal_kernel_kwargs=proposal_kernel_kwargs,
      window_size=first_window_size,
      initial_position=initial_transformed_position,
      seed=init_seed)
  proposal_kernel_kwargs.update({'step_size': step_size})

  # `draws` is a list with one entry per state part; keep one list of
  # per-window chunks for each part.
  all_draws = [[d] for d in draws]
  all_traces.append(trace)
  # Four seeds for the slow windows; the remaining seed feeds the final stages.
  *slow_seeds, seed = samplers.split_seed(seed, n=5)
  @tf.function(autograph=False)
  def _slow_window_closure(proposal_kernel_kwargs,
                           window_size,
                           initial_position,
                           running_variances,
                           seed):
    return _slow_window(
        kind=kind,
        proposal_kernel_kwargs=proposal_kernel_kwargs,
        dual_averaging_kwargs=dual_averaging_kwargs,
        num_draws=window_size,
        initial_position=initial_position,
        initial_running_variance=running_variances,
        bijector=bijector,
        trace_fn=trace_fn,
        seed=seed)
  for idx, slow_seed in enumerate(slow_seeds):
    # Slow window lengths double each iteration.
    window_size = slow_window_size * (2**idx)

    # TODO(b/180011931): if num_adaptation_steps is small, this throws an error.
    # Each window starts from the last draw of every state part.
    (draws, trace, step_size, running_variances, momentum_distribution
     ) = _slow_window_closure(
         proposal_kernel_kwargs=proposal_kernel_kwargs,
         window_size=window_size,
         initial_position=[d[-1] for d in draws],
         running_variances=running_variances,
         seed=slow_seed)
    for all_d, d in zip(all_draws, draws):
      all_d.append(d)
    all_traces.append(trace)
    proposal_kernel_kwargs.update(
        {'step_size': step_size,
         'momentum_distribution': momentum_distribution})

  # Closing fast window.
  fast_seed, sample_seed = samplers.split_seed(seed)
  draws, trace, step_size, _ = _fast_window_closure(
      proposal_kernel_kwargs=proposal_kernel_kwargs,
      window_size=last_window_size,
      initial_position=[d[-1] for d in draws],
      seed=fast_seed)
  proposal_kernel_kwargs.update({'step_size': step_size})
  for all_d, d in zip(all_draws, draws):
    all_d.append(d)
  all_traces.append(trace)

  # Post-adaptation sampling with the adapted kernel settings.
  ret = _do_sampling(
      kind=kind,
      proposal_kernel_kwargs=proposal_kernel_kwargs,
      num_draws=n_draws,
      initial_position=[d[-1] for d in draws],
      bijector=bijector,
      trace_fn=trace_fn,
      return_final_kernel_results=return_final_kernel_results,
      seed=sample_seed)

  if discard_tuning:
    if return_final_kernel_results:
      draws, trace, fkr = ret
      return sample.CheckpointableStatesAndTrace(
          all_states=bijector.inverse(draws),
          trace=trace,
          final_kernel_results=fkr)
    else:
      draws, trace = ret
      if no_trace:
        return bijector.inverse(draws)
      else:
        return sample.StatesAndTrace(all_states=bijector.inverse(draws),
                                     trace=trace)
  else:
    # Keep the tuning output: per state part, concatenate every window's
    # chunk along axis 0; traces are concatenated structure-wise.
    if return_final_kernel_results:
      draws, trace, fkr = ret
      for all_d, d in zip(all_draws, draws):
        all_d.append(d)
      all_traces.append(trace)
      return sample.CheckpointableStatesAndTrace(
          all_states=bijector.inverse(
              [tf.concat(d, axis=0) for d in all_draws]),
          trace=tf.nest.map_structure(lambda *s: tf.concat(s, axis=0),
                                      *all_traces, expand_composites=True),
          final_kernel_results=fkr)
    else:
      draws, trace = ret
      for all_d, d in zip(all_draws, draws):
        all_d.append(d)
      all_states = bijector.inverse([tf.concat(d, axis=0) for d in all_draws])
      if no_trace:
        return all_states
      else:
        all_traces.append(trace)
        return sample.StatesAndTrace(
            all_states=all_states,
            trace=tf.nest.map_structure(lambda *s: tf.concat(s, axis=0),
                                        *all_traces, expand_composites=True))
示例#5
0
def windowed_adaptive_hmc(n_draws,
                          joint_dist,
                          *,
                          num_leapfrog_steps,
                          n_chains=64,
                          num_adaptation_steps=525,
                          target_accept_prob=0.75,
                          trace_fn=_default_trace_fn,
                          return_final_kernel_results=False,
                          discard_tuning=True,
                          seed=None,
                          **pins):
    """Adapt and sample from a joint distribution, conditioned on pins.

  This uses Hamiltonian Monte Carlo to do the sampling. Step size is tuned using
  a dual-averaging adaptation, and the kernel is conditioned using a diagonal
  mass matrix, which is estimated using expanding windows.

  Args:
    n_draws: int
      Number of draws after adaptation.
    joint_dist: `tfd.JointDistribution`
      A joint distribution to sample from.
    num_leapfrog_steps: int
      Number of leapfrog steps to use for the Hamiltonian Monte Carlo step.
    n_chains: int
      Number of independent chains to run MCMC with. Default: 64.
    num_adaptation_steps: int
      Number of draws used to adapt the step size and the diagonal mass
      matrix. Default: 525.
    target_accept_prob: float
      Target acceptance rate for the step size adaptation. Default: 0.75.
    trace_fn: Optional callable
      The trace function should accept the arguments
      `(state, bijector, is_adapting, phmc_kernel_results)`, where the `state`
      is an unconstrained, flattened float tensor, `bijector` is the
      `tfb.Bijector` that is used for unconstraining and flattening,
      `is_adapting` is a boolean to mark whether the draw is from an adaptation
      step, and `phmc_kernel_results` is the
      `UncalibratedPreconditionedHamiltonianMonteCarloKernelResults` from the
      `PreconditionedHamiltonianMonteCarlo` kernel. Note that
      `bijector.inverse(state)` will provide access to the current draw in the
      untransformed space, using the structure of the provided `joint_dist`.
      Pass `None` to trace nothing. Default: `_default_trace_fn`.
    return_final_kernel_results: If `True`, then the final kernel results are
      returned alongside the chain state and the trace specified by the
      `trace_fn`.
    discard_tuning: bool
      Whether to discard the tuning draws and traces. If `False`, they are
      returned ahead of the post-adaptation ones. Default: `True`.
    seed: Optional, a seed for reproducible sampling.
    **pins:
      These are used to condition the provided joint distribution, and are
      passed directly to `joint_dist.experimental_pin(**pins)`.
  Returns:
    A single structure of draws is returned in case the trace_fn is `None`, and
    `return_final_kernel_results` is `False`. If there is a trace function,
    the return value is a tuple, with the trace second. If the
    `return_final_kernel_results` is `True`, the return value is a tuple of
    length 3, with final kernel results returned last. If `discard_tuning` is
    `True`, the tensors in `draws` and `trace` will have length `n_draws`,
    otherwise, they will have length `n_draws + num_adaptation_steps`.
  """
    if trace_fn is None:
        trace_fn = lambda *args: ()
        no_trace = True
    else:
        no_trace = False

    num_leapfrog_steps = tf.convert_to_tensor(num_leapfrog_steps)
    target_accept_prob = tf.convert_to_tensor(target_accept_prob)
    num_adaptation_steps = tf.convert_to_tensor(num_adaptation_steps)

    setup_seed, init_seed, seed = samplers.split_seed(
        samplers.sanitize_seed(seed), n=3)
    target_log_prob_fn, initial_transformed_position, bijector = _setup_mcmc(
        joint_dist, n_chains=n_chains, seed=setup_seed, **pins)

    first_window_size, slow_window_size, last_window_size = _get_window_sizes(
        num_adaptation_steps)
    # If we (over) optimistically assume good scaling, this will be near the
    # optimal step size, see Langmore, Ian, Michael Dikovsky, Scott Geraedts,
    # Peter Norgaard, and Rob Von Behren. 2019. “A Condition Number for
    # Hamiltonian Monte Carlo.” arXiv [stat.CO]. arXiv.
    # http://arxiv.org/abs/1905.09813.
    init_step_size = tf.cast(
        1. / ps.shape(initial_transformed_position)[-1]**0.25, tf.float32)

    all_draws = []
    all_traces = []
    # Opening fast window: starts from the initial momentum distribution.
    draws, trace, step_size, running_variance = _fast_window(
        target_log_prob_fn=target_log_prob_fn,
        num_leapfrog_steps=num_leapfrog_steps,
        num_draws=first_window_size,
        initial_position=initial_transformed_position,
        initial_step_size=tf.fill([n_chains, 1], init_step_size),
        target_accept_prob=target_accept_prob,
        momentum_distribution=_init_momentum(initial_transformed_position),
        bijector=bijector,
        trace_fn=trace_fn,
        seed=init_seed)
    all_draws.append(draws)
    all_traces.append(trace)

    # Four slow windows of doubling length; the remaining seed feeds the
    # closing fast window and the final sampling.
    *slow_seeds, seed = samplers.split_seed(seed, n=5)
    for idx, slow_seed in enumerate(slow_seeds):
        window_size = slow_window_size * (2**idx)

        # TODO(b/180011931): if num_adaptation_steps is small, this throws an error.
        draws, trace, step_size, running_variance, momentum_distribution = _slow_window(
            target_log_prob_fn=target_log_prob_fn,
            num_leapfrog_steps=num_leapfrog_steps,
            num_draws=window_size,
            initial_position=draws[-1],
            initial_running_variance=running_variance,
            initial_step_size=step_size,
            target_accept_prob=target_accept_prob,
            bijector=bijector,
            trace_fn=trace_fn,
            seed=slow_seed)
        all_draws.append(draws)
        all_traces.append(trace)

    # Closing fast window uses the momentum distribution of the last slow
    # window; its running variance is not needed afterwards.
    fast_seed, sample_seed = samplers.split_seed(seed)
    draws, trace, step_size, _ = _fast_window(
        target_log_prob_fn=target_log_prob_fn,
        num_leapfrog_steps=num_leapfrog_steps,
        num_draws=last_window_size,
        initial_position=draws[-1],
        initial_step_size=step_size,
        target_accept_prob=target_accept_prob,
        momentum_distribution=momentum_distribution,
        bijector=bijector,
        trace_fn=trace_fn,
        seed=fast_seed)
    all_draws.append(draws)
    all_traces.append(trace)

    ret = _do_sampling(target_log_prob_fn=target_log_prob_fn,
                       num_leapfrog_steps=num_leapfrog_steps,
                       num_draws=n_draws,
                       initial_position=draws[-1],
                       step_size=step_size,
                       momentum_distribution=momentum_distribution,
                       bijector=bijector,
                       trace_fn=trace_fn,
                       return_final_kernel_results=return_final_kernel_results,
                       seed=sample_seed)

    if discard_tuning:
        if return_final_kernel_results:
            draws, trace, fkr = ret
            return sample.CheckpointableStatesAndTrace(
                all_states=bijector.inverse(draws),
                trace=trace,
                final_kernel_results=fkr)
        else:
            draws, trace = ret
            if no_trace:
                return bijector.inverse(draws)
            else:
                return sample.StatesAndTrace(
                    all_states=bijector.inverse(draws), trace=trace)
    else:
        # Keep the tuning output by concatenating every window along axis 0.
        if return_final_kernel_results:
            draws, trace, fkr = ret
            all_draws.append(draws)
            all_traces.append(trace)
            return sample.CheckpointableStatesAndTrace(
                all_states=bijector.inverse(tf.concat(all_draws, axis=0)),
                trace=tf.nest.map_structure(lambda *s: tf.concat(s, axis=0),
                                            *all_traces,
                                            expand_composites=True),
                final_kernel_results=fkr)
        else:
            draws, trace = ret
            all_draws.append(draws)
            all_traces.append(trace)
            if no_trace:
                return bijector.inverse(tf.concat(all_draws, axis=0))
            else:
                return sample.StatesAndTrace(
                    all_states=bijector.inverse(tf.concat(all_draws, axis=0)),
                    trace=tf.nest.map_structure(
                        lambda *s: tf.concat(s, axis=0),
                        *all_traces,
                        expand_composites=True))
def _windowed_adaptive_impl(n_draws,
                            joint_dist,
                            *,
                            kind,
                            n_chains,
                            proposal_kernel_kwargs,
                            num_adaptation_steps,
                            current_state,
                            dual_averaging_kwargs,
                            trace_fn,
                            return_final_kernel_results,
                            discard_tuning,
                            seed,
                            **pins):
  """Runs windowed sampling using either HMC or NUTS as internal sampler.

  Adaptation runs in windows whose sizes come from
  `_get_window_sizes(num_adaptation_steps)`:
  - one fast window;
  - four slow windows, each twice the size of the previous one;
  - a final fast window.
  The adapted step size and momentum distribution are then used to take
  `n_draws` sampling steps via `_do_sampling`.

  Args:
    n_draws: Number of draws taken after adaptation (passed as `num_draws`
      to `_do_sampling`).
    joint_dist: Joint distribution to sample from; passed to `_setup_mcmc`.
    kind: Sampler kind, forwarded to the window and sampling helpers
      (presumably selects HMC vs. NUTS).
    n_chains: Number of chains; sets the batch shape of the initial step
      size and of the initial momentum distribution.
    proposal_kernel_kwargs: Dict of keyword arguments for the proposal
      kernel. It is shallow-copied before use, so the caller's dict is not
      modified. `target_log_prob_fn`, `step_size` and
      `momentum_distribution` are filled in on the copy.
    num_adaptation_steps: Total number of adaptation steps, split into
      windows.
    current_state: Initial position, passed to `_setup_mcmc` as
      `init_position`.
    dual_averaging_kwargs: Keyword arguments forwarded to the window helpers.
    trace_fn: Callable selecting what to trace, or `None` to trace nothing.
    return_final_kernel_results: If `True`, also return the final kernel
      results of the sampling phase.
    discard_tuning: If `True`, only the post-adaptation draws (and trace) are
      returned. Otherwise the adaptation draws and traces are concatenated
      in front of them along axis 0.
    seed: Seed; sanitized and split across the setup, adaptation and
      sampling phases.
    **pins: Extra keyword arguments forwarded to `_setup_mcmc` (presumably
      values to pin in `joint_dist`).

  Returns:
    If `return_final_kernel_results` is `True`, a
    `sample.CheckpointableStatesAndTrace`. Otherwise, if `trace_fn` is
    `None`, just the states; else a `sample.StatesAndTrace`. States are
    mapped back to the original space with `bijector.inverse`.
  """
  # Work on a copy so the step-size / momentum updates below do not leak
  # into the caller's dict.
  proposal_kernel_kwargs = dict(proposal_kernel_kwargs)

  if trace_fn is None:
    trace_fn = lambda *args: ()
    no_trace = True
  else:
    no_trace = False

  if (tf.executing_eagerly() or
      not control_flow_util.GraphOrParentsInXlaContext(
          tf1.get_default_graph())):
    # A Tensor num_draws argument breaks XLA, which requires static TensorArray
    # trace_fn result allocation sizes, so only convert outside of XLA.
    num_adaptation_steps = tf.convert_to_tensor(num_adaptation_steps)

  setup_seed, init_seed, seed = samplers.split_seed(
      samplers.sanitize_seed(seed), n=3)
  target_log_prob_fn, initial_transformed_position, bijector = _setup_mcmc(
      joint_dist,
      n_chains=n_chains,
      init_position=current_state,
      seed=setup_seed,
      **pins)

  first_window_size, slow_window_size, last_window_size = _get_window_sizes(
      num_adaptation_steps)
  # If we (over) optimistically assume good scaling, this will be near the
  # optimal step size, see Langmore, Ian, Michael Dikovsky, Scott Geraedts,
  # Peter Norgaard, and Rob Von Behren. 2019. “A Condition Number for
  # Hamiltonian Monte Carlo.” arXiv [stat.CO]. arXiv.
  # http://arxiv.org/abs/1905.09813.
  # NOTE(review): the step size is computed in float32 regardless of the
  # state dtype -- confirm this is intended for float64 models.
  init_step_size = 0.5 * sum([
      tf.cast(ps.shape(state_part)[-1], tf.float32)
      for state_part in initial_transformed_position])**-0.25

  proposal_kernel_kwargs.update({
      'target_log_prob_fn': target_log_prob_fn,
      'step_size': tf.fill([n_chains, 1], init_step_size, name='step_size'),
      'momentum_distribution': _init_momentum(initial_transformed_position,
                                              batch_shape=[n_chains]),
  })

  # Initial fast window: adapts the step size only.
  all_traces = []
  draws, trace, step_size, running_variances = _fast_window(
      kind=kind,
      proposal_kernel_kwargs=proposal_kernel_kwargs,
      dual_averaging_kwargs=dual_averaging_kwargs,
      num_draws=first_window_size,
      initial_position=initial_transformed_position,
      bijector=bijector,
      trace_fn=trace_fn,
      seed=init_seed)
  proposal_kernel_kwargs.update({'step_size': step_size})

  # One list of draw chunks per state part; chunks are concatenated at the end.
  all_draws = [[d] for d in draws]
  all_traces.append(trace)

  # Four slow windows (the remaining seed is kept for the final phases), each
  # twice as long as the previous one. Each updates the step size and the
  # momentum distribution.
  *slow_seeds, seed = samplers.split_seed(seed, n=5)
  for idx, slow_seed in enumerate(slow_seeds):
    window_size = slow_window_size * (2**idx)

    # TODO(b/180011931): if num_adaptation_steps is small, this throws an error.
    draws, trace, step_size, running_variances, momentum_distribution = _slow_window(
        kind=kind,
        proposal_kernel_kwargs=proposal_kernel_kwargs,
        dual_averaging_kwargs=dual_averaging_kwargs,
        num_draws=window_size,
        initial_position=[d[-1] for d in draws],
        initial_running_variance=running_variances,
        bijector=bijector,
        trace_fn=trace_fn,
        seed=slow_seed)
    for all_d, d in zip(all_draws, draws):
      all_d.append(d)
    all_traces.append(trace)
    proposal_kernel_kwargs.update(
        {'step_size': step_size,
         'momentum_distribution': momentum_distribution})

  # Final fast window, then the actual sampling phase.
  fast_seed, sample_seed = samplers.split_seed(seed)
  draws, trace, step_size, _ = _fast_window(
      kind=kind,
      proposal_kernel_kwargs=proposal_kernel_kwargs,
      dual_averaging_kwargs=dual_averaging_kwargs,
      num_draws=last_window_size,
      initial_position=[d[-1] for d in draws],
      bijector=bijector,
      trace_fn=trace_fn,
      seed=fast_seed)
  proposal_kernel_kwargs.update({'step_size': step_size})
  for all_d, d in zip(all_draws, draws):
    all_d.append(d)
  all_traces.append(trace)

  ret = _do_sampling(
      kind=kind,
      proposal_kernel_kwargs=proposal_kernel_kwargs,
      num_draws=n_draws,
      initial_position=[d[-1] for d in draws],
      bijector=bijector,
      trace_fn=trace_fn,
      return_final_kernel_results=return_final_kernel_results,
      seed=sample_seed)

  if return_final_kernel_results:
    draws, trace, fkr = ret
  else:
    draws, trace = ret
    fkr = None

  if discard_tuning:
    all_states = bijector.inverse(draws)
  else:
    for all_d, d in zip(all_draws, draws):
      all_d.append(d)
    all_states = bijector.inverse([tf.concat(d, axis=0) for d in all_draws])

  # With no trace and no kernel results requested, only the states are
  # returned; skip building the concatenated trace entirely.
  if not return_final_kernel_results and no_trace:
    return all_states

  if not discard_tuning:
    all_traces.append(trace)
    trace = tf.nest.map_structure(lambda *s: tf.concat(s, axis=0),
                                  *all_traces, expand_composites=True)

  if return_final_kernel_results:
    return sample.CheckpointableStatesAndTrace(
        all_states=all_states,
        trace=trace,
        final_kernel_results=fkr)
  return sample.StatesAndTrace(all_states=all_states, trace=trace)