Example #1
0
    def bootstrap_results(self, init_state):
        """Returns an object with the same type as returned by `one_step`.

        Args:
          init_state: `Tensor` or Python `list` of `Tensor`s representing the
            initial state(s) of the Markov chain(s).

        Returns:
          kernel_results: A `ReplicaExchangeMCKernelResults` describing a
            "null swap" (every replica stays where it is). This includes
            replica states.

        Raises:
          ValueError: If `self._state_includes_replicas` is set and the
            statically known leading dimension of the first state part differs
            from the statically known leading dimension of
            `inverse_temperatures`.
          TypeError: Re-raised from `make_kernel_fn` when the error message
            does not mention 'argument' (i.e. it does not look like a
            missing-`seed` signature problem).
        """
        with tf.name_scope(
                mcmc_util.make_name(self.name, 'remc', 'bootstrap_results')):
            init_state, unused_is_multipart_state = mcmc_util.prepare_state_parts(
                init_state)

            inverse_temperatures = tf.convert_to_tensor(
                self.inverse_temperatures, name='inverse_temperatures')

            if self._state_includes_replicas:
                it_n_replica = inverse_temperatures.shape[0]
                state_n_replica = init_state[0].shape[0]
                # Only compare when both sizes are statically known.
                if ((it_n_replica is not None)
                        and (state_n_replica is not None)
                        and (it_n_replica != state_n_replica)):
                    # The message names the state count first and the
                    # inverse_temperatures count second; format in that order.
                    raise ValueError(
                        'Number of replicas implied by initial state ({}) must equal '
                        'number of replicas implied by inverse_temperatures ({}), but '
                        'did not'.format(state_n_replica, it_n_replica))

            # We will now replicate each of a possible batch of initial states, one
            # for each inverse_temperature. So if init_state=[x, y] of shapes
            # [Sx, Sy] then the new shape is [(T, Sx), (T, Sy)] where (a, b) means
            # concatenation and T=shape(inverse_temperature).
            num_replica = ps.size0(inverse_temperatures)
            replica_shape = tf.convert_to_tensor([num_replica])

            if self._state_includes_replicas:
                # The caller already supplied one state per replica.
                replica_states = init_state
            else:
                replica_states = [
                    tf.broadcast_to(  # pylint: disable=g-complex-comprehension
                        x,
                        ps.concat([replica_shape, ps.shape(x)], axis=0),
                        name='replica_states') for x in init_state
                ]

            target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
                self.target_log_prob_fn, inverse_temperatures)
            # Seed handling complexity is due to users possibly expecting an old-style
            # stateful seed to be passed to `self.make_kernel_fn`.
            # In other words:
            # - We try `make_kernel_fn` without a seed first; this is the future. The
            #   kernel will receive a seed later, as part of `one_step`.
            # - If the user code doesn't like that (Python complains about a missing
            #   required argument), we fall back to the previous behavior and warn.
            try:
                inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                    target_log_prob_for_inner_kernel)
            except TypeError as e:
                if 'argument' not in str(e):
                    raise
                warnings.warn(
                    'The second (`seed`) argument to `ReplicaExchangeMC`s '
                    '`make_kernel_fn` is deprecated. `TransitionKernel` instances now '
                    'receive seeds via `bootstrap_results` and `one_step`. This '
                    'fallback may become an error 2020-09-20.')
                inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                    target_log_prob_for_inner_kernel, self._seed_stream())

            replica_results = inner_kernel.bootstrap_results(replica_states)

            pre_swap_replica_target_log_prob = _get_field(
                replica_results, 'target_log_prob')

            # The leading axis indexes replicas; the remaining axes are the batch.
            replica_and_batch_shape = ps.shape(
                pre_swap_replica_target_log_prob)
            batch_shape = replica_and_batch_shape[1:]

            inverse_temperatures = mcmc_util.left_justified_broadcast_to(
                inverse_temperatures, replica_and_batch_shape)

            # Pretend we did a "null swap", which will always be accepted.
            swaps = mcmc_util.left_justified_broadcast_to(
                tf.range(num_replica), replica_and_batch_shape)
            # is_swap_accepted.shape = [n_replica, n_replica] + batch_shape.
            is_swap_accepted = distribution_util.rotate_transpose(tf.eye(
                num_replica, batch_shape=batch_shape, dtype=tf.bool),
                                                                  shift=2)

            # With a null swap the "swapped" temperatures equal the originals and
            # the swap function is the identity.
            post_swap_replica_results = _make_post_swap_replica_results(
                replica_results,
                inverse_temperatures,
                inverse_temperatures,
                is_swap_accepted[0],
                lambda x: x,
            )

            return ReplicaExchangeMCKernelResults(
                post_swap_replica_states=replica_states,
                pre_swap_replica_results=replica_results,
                post_swap_replica_results=post_swap_replica_results,
                is_swap_proposed=is_swap_accepted,
                is_swap_accepted=is_swap_accepted,
                is_swap_proposed_adjacent=_sub_diag(is_swap_accepted),
                is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
                # Store the un-broadcast attribute, not the broadcast local copy.
                inverse_temperatures=self.inverse_temperatures,
                swaps=swaps,
                step_count=tf.zeros(shape=(), dtype=tf.int32),
                seed=samplers.zeros_seed(),
            )
Example #2
0
  def bootstrap_results(self, init_state):
    """Builds the initial `NUTSKernelResults` from a supplied `init_state`.

    Args:
      init_state: A single state part, or a nested structure of state parts.
        A non-nested value is wrapped in a one-element list.

    Returns:
      A `NUTSKernelResults` whose log prob and gradients are evaluated at
      `init_state`, whose diagnostic fields are zero-filled, and whose `seed`
      is the placeholder from `samplers.zeros_seed()`.

    Raises:
      ValueError: If `self.step_size` has neither one entry nor one entry per
        state part.
    """
    with tf.name_scope(self.name + '.bootstrap_results'):
      state_parts = (
          init_state if tf.nest.is_nested(init_state) else [init_state])
      placeholder_momentum = [tf.ones_like(part) for part in state_parts]

      def _zero_buffers(parts):
        """Returns one zero-filled swap buffer per entry of `parts`."""
        buffers = []
        for part in parts:
          # Prepend the swap-slot axis to the part's own shape.
          buffer_shape = ps.concat(
              [[max(self._write_instruction) + 1], ps.shape(part)], axis=0)
          buffers.append(ps.zeros(buffer_shape, dtype=part.dtype))
        return buffers

      momentum_state_memory = MomentumStateSwap(
          momentum_swap=_zero_buffers(placeholder_momentum),
          state_swap=_zero_buffers(state_parts))

      # Only the log prob and its gradients are needed from the processed args.
      _, _, initial_log_prob, initial_grads = leapfrog_impl.process_args(
          self.target_log_prob_fn, placeholder_momentum, state_parts)

      # Pad the step size so there is one entry per state part.
      num_parts = len(state_parts)
      step_sizes = self.step_size
      if len(step_sizes) == 1:
        step_sizes = step_sizes * num_parts
      if len(step_sizes) != num_parts:
        raise ValueError('Expected either one step size or {} (size of '
                         '`init_state`), but found {}'.format(
                             num_parts, len(step_sizes)))

      def _as_step_size_tensor(value):
        """Converts one step size to a `Tensor` of the log-prob dtype."""
        return tf.convert_to_tensor(
            value, dtype=initial_log_prob.dtype, name='step_size')

      step_sizes = tf.nest.map_structure(_as_step_size_tensor, step_sizes)

      def _zeros(name, dtype=None):
        """Returns zeros shaped like the initial log prob."""
        return tf.zeros_like(initial_log_prob, dtype=dtype, name=name)

      return NUTSKernelResults(
          target_log_prob=initial_log_prob,
          grads_target_log_prob=initial_grads,
          momentum_state_memory=momentum_state_memory,
          step_size=step_sizes,
          log_accept_ratio=_zeros('log_accept_ratio'),
          leapfrogs_taken=_zeros('leapfrogs_taken', dtype=TREE_COUNT_DTYPE),
          is_accepted=_zeros('is_accepted', dtype=tf.bool),
          reach_max_depth=_zeros('reach_max_depth', dtype=tf.bool),
          has_divergence=_zeros('has_divergence', dtype=tf.bool),
          energy=compute_hamiltonian(initial_log_prob, placeholder_momentum),
          # Placeholder so the structure matches what `one_step` returns.
          seed=samplers.zeros_seed(),
      )
Example #3
0
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Takes one step of the TransitionKernel.

        Args:
          current_state: `Tensor` or Python `list` of `Tensor`s representing the
            current state(s) of the Markov chain(s). Only its list-likeness is
            consulted here; the replica states that are advanced come from
            `previous_kernel_results.post_swap_replica_states`.
          previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
            `list` of `Tensor`s representing internal calculations made within
            the previous call to this function (or as returned by
            `bootstrap_results`).
          seed: Optional, a seed for reproducible sampling.

        Returns:
          next_state: `Tensor` or Python `list` of `Tensor`s representing the
            next state(s) of the Markov chain(s).
          kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
            `Tensor`s representing internal calculations made within this
            function. This includes replica states.

        Raises:
          TypeError: Re-raised from `make_kernel_fn` or `swap_proposal_fn` when
            the error message does not match the deprecated-signature fallbacks
            handled below.
        """

        # The code below propagates one step states of shape
        #  [n_replica] + batch_shape + event_shape.
        #
        # The step is done in three parts:
        #  1) Call one_step to transition states via a tempered version of
        #     self.target_log_prob_fn (see _replica_target_log_prob).
        #  2) Permute values in states
        #  3) Update state-dependent values, such as log_probs.
        #
        # We chose to swap states, rather than temperatures, because...
        # (i)  If swapping temperatures, you *still* have to swap log_probs to
        #      determine acceptance, as well as states (for kernel results).
        #      So it's just as difficult to swap temperatures.
        # (ii) If swapping temperatures, you have to take care to swap any user-
        #      supplied temperature related things (like step size).
        #      A-priori, we don't know what else will need to be swapped!
        # (iii)In both cases, the kernel results need to be updated in a non-trivial
        #      manner....so we either special-case, or use bootstrap.

        with tf.name_scope(mcmc_util.make_name(self.name, 'remc', 'one_step')):
            # Force a read in case the `inverse_temperatures` is a `tf.Variable`.
            inverse_temperatures = tf.convert_to_tensor(
                previous_kernel_results.inverse_temperatures,
                name='inverse_temperatures')

            target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
                self.target_log_prob_fn, inverse_temperatures)
            # Seed handling complexity is due to users possibly expecting an old-style
            # stateful seed to be passed to `self.make_kernel_fn`, and no seed
            # expected by `kernel.one_step`.
            # In other words:
            # - We try `make_kernel_fn` without a seed first; this is the future. The
            #   kernel will receive a seed later, as part of `one_step`.
            # - If the user code doesn't like that (Python complains about a missing
            #   required argument), we warn and fall back to the previous behavior.
            try:
                inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                    target_log_prob_for_inner_kernel)
            except TypeError as e:
                if 'argument' not in str(e):
                    raise
                warnings.warn(
                    'The `seed` argument to `ReplicaExchangeMC`s `make_kernel_fn` is '
                    'deprecated. `TransitionKernel` instances now receive seeds via '
                    '`one_step`.')
                inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                    target_log_prob_for_inner_kernel, self._seed_stream())

            # Now that we've constructed the TransitionKernel instance:
            # - If we were given a seed, we sanitize it to stateless and pass along
            #   to `kernel.one_step`. If it doesn't like that, we crash and propagate
            #   the error.  Rationale: The contract is stateless sampling given
            #   seed, and doing otherwise would not meet it.
            # - If not given a seed, we don't pass one along. This avoids breaking
            #   underlying kernels lacking a `seed` arg on `one_step`.
            # TODO(b/159636942): Clean up after 2020-09-20.
            if seed is not None:
                seed = samplers.sanitize_seed(seed)
                # One seed each for the inner kernel, the swap proposal and the
                # uniform draws used in the accept/reject test.
                inner_seed, swap_seed, logu_seed = samplers.split_seed(
                    seed, n=3, salt='remc_one_step')
                inner_kwargs = dict(seed=inner_seed)
            else:
                if self._seed_stream.original_seed is not None:
                    warnings.warn(mcmc_util.SEED_CTOR_ARG_DEPRECATION_MSG)
                inner_kwargs = {}
                # No explicit seed: derive the swap/uniform seeds from the
                # kernel's own seed stream.
                swap_seed, logu_seed = samplers.split_seed(self._seed_stream())
            [
                pre_swap_replica_states,
                pre_swap_replica_results,
            ] = inner_kernel.one_step(
                previous_kernel_results.post_swap_replica_states,
                previous_kernel_results.post_swap_replica_results,
                **inner_kwargs)

            pre_swap_replica_target_log_prob = _get_field(
                # These are tempered log probs (have been divided by temperature).
                pre_swap_replica_results,
                'target_log_prob')

            dtype = pre_swap_replica_target_log_prob.dtype
            replica_and_batch_shape = ps.shape(
                pre_swap_replica_target_log_prob)
            batch_shape = replica_and_batch_shape[1:]
            replica_and_batch_rank = ps.rank(pre_swap_replica_target_log_prob)
            num_replica = ps.size0(inverse_temperatures)

            inverse_temperatures = mcmc_util.left_justified_broadcast_to(
                inverse_temperatures, replica_and_batch_shape)

            # Now that each replica has done one_step, it is time to consider swaps.

            # swap.shape = [n_replica], and is a "once only" permutation, meaning it
            # is achievable by a sequence of pairwise permutations, where each element
            # is moved at most once.
            # E.g. if swaps = [1, 0, 2], we will consider swapping temperatures 0 and
            # 1, keeping 2 fixed.  This exact same swap is considered for *every*
            # batch member.  Of course some batch members may accept and some reject.
            try:
                swaps = tf.cast(
                    self.swap_proposal_fn(  # pylint: disable=not-callable
                        num_replica,
                        batch_shape=batch_shape,
                        seed=swap_seed,
                        step_count=previous_kernel_results.step_count),
                    dtype=tf.int32)
            except TypeError as e:
                # Only fall back when the failure is about the `step_count`
                # keyword; any other TypeError is propagated unchanged.
                if 'step_count' not in str(e):
                    raise
                warnings.warn(
                    'The `swap_proposal_fn` given to ReplicaExchangeMC did not accept '
                    'the `step_count` argument. Falling back to omitting the '
                    'argument. This fallback will be removed after 24-Oct-2020.'
                )
                swaps = tf.cast(
                    self.swap_proposal_fn(  # pylint: disable=not-callable
                        num_replica,
                        batch_shape=batch_shape,
                        seed=swap_seed),
                    dtype=tf.int32)

            # The identity permutation, expanded to have the same rank as `swaps`.
            null_swaps = mcmc_util.left_justified_expand_dims_like(
                tf.range(num_replica, dtype=swaps.dtype), swaps)
            swaps = _maybe_embed_swaps_validation(swaps, null_swaps,
                                                  self.validate_args)

            # Un-temper the log probs.  E.g., for replica k, at point x_k, this is
            # Log[p(x_k)], and *not* Log[p_x(x_k)] = Log[p(x_k)] * beta_k.
            untempered_pre_swap_replica_target_log_prob = (
                pre_swap_replica_target_log_prob / inverse_temperatures)

            # Since `swaps` is its own inverse permutation we automatically know the
            # swap counterpart: range(num_replica). We use this idea to compute the
            # acceptance in a vectorized manner at the cost of wasting roughly half
            # our computation. Although we could use `unique` to solve this problem,
            # we expect the cost of `unique` to be higher than the dozens of wasted
            # arithmetic calculations. Worse, it'd mean we need dynamic sized Tensors
            # (eg, using `tf.where(bool)`) and so we wouldn't be able to XLA compile.

            # Note: diffs would normally be "proposed - current" however energy is
            # flipped since `energy == -log_prob`.
            energy_diff = (untempered_pre_swap_replica_target_log_prob -
                           mcmc_util.index_remapping_gather(
                               untempered_pre_swap_replica_target_log_prob,
                               swaps,
                               name='gather_swap_tlp'))
            swapped_inverse_temperatures = mcmc_util.index_remapping_gather(
                inverse_temperatures, swaps, name='gather_swap_temps')
            inverse_temp_diff = swapped_inverse_temperatures - inverse_temperatures

            # If i and j are swapping, log_accept_ratio[] i and j are equal.
            log_accept_ratio = (energy_diff *
                                mcmc_util.left_justified_expand_dims_to(
                                    inverse_temp_diff, replica_and_batch_rank))

            # Any non-finite ratio (NaN or +/-inf) is replaced by -inf, so such a
            # swap can never pass the `log_uniform < log_accept_ratio` test below.
            log_accept_ratio = tf.where(tf.math.is_finite(log_accept_ratio),
                                        log_accept_ratio,
                                        tf.constant(-np.inf, dtype=dtype))

            # Produce Log[Uniform] draws that are identical at swapped indices.
            log_uniform = tf.math.log(
                samplers.uniform(shape=replica_and_batch_shape,
                                 dtype=dtype,
                                 seed=logu_seed))
            # Both members of a swap pair read the draw at the smaller index.
            anchor_swaps = tf.minimum(swaps, null_swaps)
            log_uniform = mcmc_util.index_remapping_gather(
                log_uniform, anchor_swaps)

            is_swap_accepted_mask = tf.less(log_uniform,
                                            log_accept_ratio,
                                            name='is_swap_accepted_mask')

            def _swap_tensor(x):
                # Where the swap was accepted take the permuted value, otherwise
                # keep the original.
                return mcmc_util.choose(
                    is_swap_accepted_mask,
                    mcmc_util.index_remapping_gather(x, swaps), x)

            post_swap_replica_states = [
                _swap_tensor(s) for s in pre_swap_replica_states
            ]

            expanded_null_swaps = mcmc_util.left_justified_broadcast_to(
                null_swaps, replica_and_batch_shape)
            is_swap_proposed = _compute_swap_notmatrix(
                # Broadcast both so they have shape [num_replica] + batch_shape.
                # This (i) makes them have same shape as is_swap_accepted, and
                # (ii) keeps shape consistent if someday swaps has a batch shape.
                expanded_null_swaps,
                mcmc_util.left_justified_broadcast_to(swaps,
                                                      replica_and_batch_shape))

            # To get is_swap_accepted in ordered position, we use
            # _compute_swap_notmatrix on current and next replica positions.
            post_swap_replica_position = _swap_tensor(expanded_null_swaps)

            is_swap_accepted = _compute_swap_notmatrix(
                post_swap_replica_position, expanded_null_swaps)

            if self._state_includes_replicas:
                post_swap_states = post_swap_replica_states
            else:
                # Return only the replica at index 0 of each state part.
                post_swap_states = [s[0] for s in post_swap_replica_states]

            post_swap_replica_results = _make_post_swap_replica_results(
                pre_swap_replica_results, inverse_temperatures,
                swapped_inverse_temperatures, is_swap_accepted_mask,
                _swap_tensor)

            if mcmc_util.is_list_like(current_state):
                # We *always* canonicalize the states in the kernel results.
                states = post_swap_states
            else:
                # Mirror the caller's non-list input with a non-list output.
                states = post_swap_states[0]

            post_swap_kernel_results = ReplicaExchangeMCKernelResults(
                post_swap_replica_states=post_swap_replica_states,
                pre_swap_replica_results=pre_swap_replica_results,
                post_swap_replica_results=post_swap_replica_results,
                is_swap_proposed=is_swap_proposed,
                is_swap_accepted=is_swap_accepted,
                is_swap_proposed_adjacent=_sub_diag(is_swap_proposed),
                is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
                # Store the original pkr.inverse_temperatures in case it's a
                # `tf.Variable`.
                inverse_temperatures=previous_kernel_results.
                inverse_temperatures,
                swaps=swaps,
                step_count=previous_kernel_results.step_count + 1,
                # Record the (sanitized) seed used, or a placeholder if none.
                seed=samplers.zeros_seed() if seed is None else seed,
            )

            return states, post_swap_kernel_results
Example #4
0
def _fixed_sample(d):
    """Samples from `d` using the seed returned by `samplers.zeros_seed()`."""
    fixed_seed = samplers.zeros_seed()
    return d.sample(seed=fixed_seed)
  def bootstrap_results(self, init_state):
    """Returns an object with the same type as returned by `one_step`.

    Args:
      init_state: `Tensor` or Python `list` of `Tensor`s representing the
        initial state(s) of the Markov chain(s).

    Returns:
      kernel_results: A `ReplicaExchangeMCKernelResults` describing a "null
        swap" (every replica stays where it is). This includes replica states.

    Raises:
      ValueError: If `self._state_includes_replicas` is set and the statically
        known leading dimension of the first state part differs from the
        statically known leading dimension of `inverse_temperatures`.
      TypeError: If `make_kernel_fn` rejects being called with a single
        argument (it no longer receives a `seed`), or any other `TypeError`
        raised by `make_kernel_fn`.
    """
    with tf.name_scope(mcmc_util.make_name(
        self.name, 'remc', 'bootstrap_results')):
      init_state, unused_is_multipart_state = mcmc_util.prepare_state_parts(
          init_state)

      inverse_temperatures = tf.convert_to_tensor(
          self.inverse_temperatures,
          name='inverse_temperatures')

      if self._state_includes_replicas:
        it_n_replica = inverse_temperatures.shape[0]
        state_n_replica = init_state[0].shape[0]
        # Only compare when both sizes are statically known.
        if ((it_n_replica is not None) and (state_n_replica is not None) and
            (it_n_replica != state_n_replica)):
          # The message names the state count first and the
          # inverse_temperatures count second; format in that order.
          raise ValueError(
              'Number of replicas implied by initial state ({}) must equal '
              'number of replicas implied by inverse_temperatures ({}), but '
              'did not'.format(state_n_replica, it_n_replica))

      # We will now replicate each of a possible batch of initial states, one
      # for each inverse_temperature. So if init_state=[x, y] of shapes [Sx, Sy]
      # then the new shape is [(T, Sx), (T, Sy)] where (a, b) means
      # concatenation and T=shape(inverse_temperature).
      num_replica = ps.size0(inverse_temperatures)
      replica_shape = ps.convert_to_shape_tensor([num_replica])

      if self._state_includes_replicas:
        # The caller already supplied one state per replica.
        replica_states = init_state
      else:
        replica_states = [
            tf.broadcast_to(  # pylint: disable=g-complex-comprehension
                x,
                ps.concat([replica_shape, ps.shape(x)], axis=0),
                name='replica_states')
            for x in init_state
        ]

      target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
          target_log_prob_fn=self.target_log_prob_fn,
          inverse_temperatures=inverse_temperatures,
          untempered_log_prob_fn=self.untempered_log_prob_fn,
          tempered_log_prob_fn=self.tempered_log_prob_fn,
      )
      # TODO(b/159636942): Clean up the helpful error msg after 2020-11-10.
      try:
        inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
            target_log_prob_for_inner_kernel)
      except TypeError as e:
        # Errors that do not look like a signature mismatch pass through
        # unchanged; signature mismatches get a more helpful message.
        if 'argument' not in str(e):
          raise
        raise TypeError(
            '`ReplicaExchangeMC`s `make_kernel_fn` no longer receives a second '
            '(`seed`) argument. `TransitionKernel` instances now receive seeds '
            'via `one_step`.')

      replica_results = inner_kernel.bootstrap_results(replica_states)

      pre_swap_replica_target_log_prob = _get_field(
          replica_results, 'target_log_prob')

      # The leading axis indexes replicas; the remaining axes are the batch.
      replica_and_batch_shape = ps.shape(
          pre_swap_replica_target_log_prob)
      batch_shape = replica_and_batch_shape[1:]

      inverse_temperatures = bu.left_justified_broadcast_to(
          inverse_temperatures, replica_and_batch_shape)

      # Pretend we did a "null swap", which will always be accepted.
      swaps = bu.left_justified_broadcast_to(
          tf.range(num_replica), replica_and_batch_shape)
      # is_swap_accepted.shape = [n_replica, n_replica] + batch_shape.
      is_swap_accepted = distribution_util.rotate_transpose(
          tf.eye(num_replica, batch_shape=batch_shape, dtype=tf.bool),
          shift=2)

      return ReplicaExchangeMCKernelResults(
          post_swap_replica_states=replica_states,
          pre_swap_replica_results=replica_results,
          post_swap_replica_results=_set_swapped_fields_to_nan(replica_results),
          is_swap_proposed=is_swap_accepted,
          is_swap_accepted=is_swap_accepted,
          is_swap_proposed_adjacent=_sub_diag(is_swap_accepted),
          is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
          # Store the un-broadcast attribute, not the broadcast local copy.
          inverse_temperatures=self.inverse_temperatures,
          swaps=swaps,
          step_count=tf.zeros(shape=(), dtype=tf.int32),
          seed=samplers.zeros_seed(),
      )
Example #6
0
    def bootstrap_results(self, init_state):
        """Creates initial `previous_kernel_results` using a supplied `state`.

        Args:
          init_state: A single state part, or a nested structure of state
            parts. A non-nested value is wrapped in a one-element list.

        Returns:
          A `PreconditionedNUTSKernelResults` whose log prob and gradients are
          evaluated at `init_state`, whose diagnostic fields are zero-filled,
          and whose `seed` is the placeholder from `samplers.zeros_seed()`.

        Raises:
          ValueError: If `self.step_size` has neither one entry nor one entry
            per state part.
        """
        with tf.name_scope(self.name + '.bootstrap_results'):
            if not tf.nest.is_nested(init_state):
                init_state = [init_state]
            # Pad the step_size so it is compatible with the states.
            step_size = self.step_size
            if len(step_size) == 1:
                step_size = step_size * len(init_state)
            if len(step_size) != len(init_state):
                raise ValueError(
                    'Expected either one step size or {} (size of '
                    '`init_state`), but found {}'.format(
                        len(init_state), len(step_size)))
            state_parts, _ = mcmc_util.prepare_state_parts(
                init_state, name='current_state')
            (current_target_log_prob,
             current_grads_log_prob) = mcmc_util.maybe_call_fn_and_grads(
                 self.target_log_prob_fn, state_parts)
            momentum_distribution = self.momentum_distribution
            if momentum_distribution is None:
                # No user-supplied distribution: build one from the state parts
                # and the batch shape of the log prob.
                momentum_distribution = pu.make_momentum_distribution(
                    state_parts, ps.shape(current_target_log_prob))
            momentum_distribution = pu.maybe_make_list_and_batch_broadcast(
                momentum_distribution, ps.shape(current_target_log_prob))
            # This draw only supplies shapes/dtypes for the swap buffers and a
            # placeholder energy. A fixed seed keeps the bootstrap results
            # deterministic and avoids an unseeded `sample()` call.
            momentum_parts = momentum_distribution.sample(
                seed=samplers.zeros_seed())

            def _init(shape_and_dtype):
                """Allocate zero-filled buffers for storing state and velocity."""
                return [  # pylint: disable=g-complex-comprehension
                    ps.zeros(ps.concat([[max(self._write_instruction) + 1], s],
                                       axis=0),
                             dtype=d) for (s, d) in shape_and_dtype
                ]

            get_shapes_and_dtypes = lambda x: [
                (ps.shape(x_), x_.dtype)  # pylint: disable=g-long-lambda
                for x_ in x
            ]
            velocity_state_memory = VelocityStateSwap(
                velocity_swap=_init(get_shapes_and_dtypes(momentum_parts)),
                state_swap=_init(get_shapes_and_dtypes(init_state)))

            return PreconditionedNUTSKernelResults(
                target_log_prob=current_target_log_prob,
                grads_target_log_prob=current_grads_log_prob,
                velocity_state_memory=velocity_state_memory,
                step_size=step_size,
                log_accept_ratio=tf.zeros_like(current_target_log_prob,
                                               name='log_accept_ratio'),
                leapfrogs_taken=tf.zeros_like(current_target_log_prob,
                                              dtype=TREE_COUNT_DTYPE,
                                              name='leapfrogs_taken'),
                is_accepted=tf.zeros_like(current_target_log_prob,
                                          dtype=tf.bool,
                                          name='is_accepted'),
                reach_max_depth=tf.zeros_like(current_target_log_prob,
                                              dtype=tf.bool,
                                              name='reach_max_depth'),
                has_divergence=tf.zeros_like(current_target_log_prob,
                                             dtype=tf.bool,
                                             name='has_divergence'),
                energy=compute_hamiltonian(current_target_log_prob,
                                           momentum_parts,
                                           momentum_distribution),
                momentum_distribution=momentum_distribution,
                # Allow room for one_step's seed.
                seed=samplers.zeros_seed(),
            )
Example #7
0
def dummy_seed():
    """Gives a constant seed for callers that must sample but have no seed.

    Returns:
      `samplers.zeros_seed()` when `JAX_MODE` is truthy, otherwise the
      integer `42`.
    """
    # TODO(b/147874898): After 20 Dec 2020, drop the 42 and inline the zeros_seed.
    if JAX_MODE:
        return samplers.zeros_seed()
    return 42
from tensorflow_probability.python.internal import callable_util
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.internal import samplers

# pylint: disable=g-long-lambda,protected-access
# Registry mapping a distribution class to a callable that takes an instance
# `d` of that class and returns a bijector constructed from it.
# NOTE(review): presumably looked up by distribution type from
# `make_distribution_bijector` (defined outside this view) -- confirm.
preconditioning_bijector_fns = {
    # Deterministic: use the distribution's own default event-space bijector.
    deterministic.Deterministic:
    (lambda d: d.experimental_default_event_space_bijector()),
    # Independent: delegate to the bijector of the wrapped `d.distribution`.
    independent.Independent:
    lambda d: make_distribution_bijector(d.distribution),
    # MarkovChain: the transition bijector is built from the distribution that
    # `d.transition_fn` returns for step 0 and a draw from
    # `d.initial_state_prior`. The draw uses a fixed zeros seed, so it is only
    # a placeholder needed to call `transition_fn`, not a source of randomness.
    markov_chain.MarkovChain:
    lambda d: markov_chain._MarkovChainBijector(
        chain=d,
        transition_bijector=make_distribution_bijector(
            d.transition_fn(
                0, d.initial_state_prior.sample(seed=samplers.zeros_seed()))),
        bijector_fn=make_distribution_bijector),
    # Normal: `Shift(loc)` called on `Scale(scale)`.
    # NOTE(review): calling one bijector on another is assumed to chain them
    # (scale first, then shift) -- confirm against the tfb API.
    normal.Normal:
    lambda d: tfb.Shift(d.loc)(tfb.Scale(d.scale)),
    # Sample: wrap the bijector of the underlying distribution so that it
    # covers the extra `sample_shape` dimensions.
    sample.Sample:
    lambda d: sample._DefaultSampleBijector(
        distribution=d.distribution,
        sample_shape=d.sample_shape,
        sum_fn=d._sum_fn(),
        bijector=make_distribution_bijector(d.distribution)),
    # Uniform: `NormalCDF`, then `Scale(high - low)`, then `Shift(low)`
    # (same chaining assumption as for Normal above).
    uniform.Uniform:
    lambda d: (tfb.Shift(d.low)(tfb.Scale(d.high - d.low)(tfb.NormalCDF())))
}
# pylint: enable=g-long-lambda,protected-access

Example #9
0
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Proposes a new state with `new_state_fn` and evaluates its log-prob.

        Args:
          current_state: `Tensor` or list-like of `Tensor`s holding the current
            state of the chain(s).
          previous_kernel_results: Results from the previous step. Not read by
            this method.
          seed: Optional seed. When supplied it is sanitized and passed to
            `new_state_fn`; when omitted, a seed is drawn from
            `self._seed_stream`.

        Returns:
          A two-element `list` `[next_state, kernel_results]`. `next_state` is
          a list if `current_state` was list-like and a single value otherwise;
          `kernel_results` is an `UncalibratedRandomWalkResults`.

        Raises:
          TypeError: re-raised from `new_state_fn` when a seed was passed
            explicitly, or when the error text does not look like a seed-type
            incompatibility.
        """
        with tf.name_scope(mcmc_util.make_name(self.name, 'rwm', 'one_step')):
            # Remember the caller's structure so the result can mirror it.
            state_is_list = mcmc_util.is_list_like(current_state)
            with tf.name_scope('initialize'):
                raw_parts = (
                    list(current_state) if state_is_list else [current_state])
                current_state_parts = [
                    tf.convert_to_tensor(part, name='current_state')
                    for part in raw_parts
                ]

            # The seed handling below exists because user-provided
            # `new_state_fn`s may still expect an old-style stateful seed:
            # - An explicit `seed` is sanitized to a stateless seed. If
            #   `new_state_fn` rejects it, the error propagates, because the
            #   contract for an explicit seed is stateless sampling.
            # - Without a `seed`, a stateless seed is tried first.
            # - If that attempt fails with what looks like a seed
            #   incompatibility, a warning is issued and the call is retried
            #   with the stateful-style seed, so that code which never set a
            #   seed does not suddenly break.
            # TODO(b/159636942): Clean up after 2020-09-20.
            force_stateless = seed is not None
            if force_stateless:
                seed = samplers.sanitize_seed(seed)
            else:
                if self._seed_stream.original_seed is not None:
                    warnings.warn(mcmc_util.SEED_CTOR_ARG_DEPRECATION_MSG)
                stateful_seed = self._seed_stream()
                seed = samplers.sanitize_seed(stateful_seed)

            try:
                next_state_parts = self.new_state_fn(current_state_parts, seed)  # pylint: disable=not-callable
            except TypeError as e:
                error_text = str(e)
                looks_like_seed_error = (
                    'Expected int for argument' in error_text
                    or TENSOR_SEED_MSG_PREFIX in error_text)
                if force_stateless or not looks_like_seed_error:
                    raise
                fallback_msg = (
                    'Falling back to `int` seed for `new_state_fn` {}. Please update '
                    'to use `tf.random.stateless_*` RNGs. '
                    'This fallback may be removed after 10-Sep-2020. ({})')
                warnings.warn(
                    fallback_msg.format(self.new_state_fn, error_text))
                # No stateless seed was consumed, so record a placeholder below.
                seed = None
                next_state_parts = self.new_state_fn(  # pylint: disable=not-callable
                    current_state_parts, stateful_seed)

            # Compute `target_log_prob` so its available to MetropolisHastings.
            next_target_log_prob = self.target_log_prob_fn(*next_state_parts)  # pylint: disable=not-callable

            if state_is_list:
                next_state = next_state_parts
            else:
                next_state = next_state_parts[0]

            if seed is None:
                results_seed = samplers.zeros_seed()
            else:
                results_seed = seed

            results = UncalibratedRandomWalkResults(
                log_acceptance_correction=tf.zeros_like(next_target_log_prob),
                target_log_prob=next_target_log_prob,
                seed=results_seed,
            )
            return [next_state, results]
Example #10
0
    def __init__(self,
                 parameter_prior,
                 parameterized_initial_state_prior_fn,
                 parameterized_transition_fn,
                 parameterized_observation_fn,
                 parameterized_initial_state_proposal_fn=None,
                 parameterized_proposal_fn=None,
                 parameter_constraining_bijector=None,
                 name=None):
        """Builds an iterated filter for parameter estimation in sequential models.

    Iterated filtering is a parameter estimation method in which parameters
    are included in an augmented state space, with dynamics that introduce
    parameter perturbations, and a filtering
    algorithm such as particle filtering is run several times with perturbations
    of decreasing size. This class implements the IF2 algorithm of
    [Ionides et al., 2015][1], for which, under appropriate conditions
    (including a uniform prior) the final parameter distribution approaches a
    point mass at the maximum likelihood estimate. If a non-uniform prior is
    provided, the final parameter distribution will (under appropriate
    conditions) approach a point mass at the maximum a posteriori (MAP) value.

    This class augments the state space of a sequential model to include
    parameter perturbations, and provides utilities to run particle filtering
    on that augmented model. Alternately, the augmented components may be passed
    directly into a filtering algorithm of the user's choice.

    Args:
      parameter_prior: prior `tfd.Distribution` over parameters (may be a joint
        distribution).
      parameterized_initial_state_prior_fn: `callable` with signature
        `initial_state_prior = parameterized_initial_state_prior_fn(parameters)`
        where `parameters` has the form of a sample from `parameter_prior`,
        and `initial_state_prior` is a distribution over the initial state.
      parameterized_transition_fn: `callable` with signature
        `next_state_dist = parameterized_transition_fn(
        step, state, parameters, **kwargs)`.
      parameterized_observation_fn: `callable` with signature
        `observation_dist = parameterized_observation_fn(
        step, state, parameters, **kwargs)`.
      parameterized_initial_state_proposal_fn: optional `callable` with
        signature `initial_state_proposal =
        parameterized_initial_state_proposal_fn(parameters)` where `parameters`
        has the form of a sample from `parameter_prior`, and
        `initial_state_proposal` is a distribution over the initial state.
        Default value: `None`.
      parameterized_proposal_fn: optional `callable` with signature
        `next_state_dist = parameterized_proposal_fn(
        step, state, parameters, **kwargs)`.
        Default value: `None`.
      parameter_constraining_bijector: optional `tfb.Bijector` instance
        such that `parameter_constraining_bijector.forward(x)` returns valid
        parameters for any real-valued `x` of the same structure and shape
        as `parameters`. If `None`, the default bijector of the provided
        `parameter_prior` will be used.
        Default value: `None`.
      name: `str` name for ops constructed by this object.
        Default value: `None` (i.e., `'IteratedFilter'`).

    Raises:
      ValueError: if a test sample from the joint initial state prior has
        parts whose particle dimension is statically incompatible with the
        requested number of particles.

    #### Example

    We'll walk through applying iterated filtering to a toy
    Susceptible-Infected-Recovered (SIR) model, a [compartmental model](
    https://en.wikipedia.org/wiki/Compartmental_models_in_epidemiology#The_SIR_model)
    of infectious disease. Note that the model we use here is extremely
    simplified and is intended as a pedagogical example; it should not be
    interpreted to describe disease spread in the real world.

    We begin by specifying a prior distribution over the parameters to be
    inferred, thus defining the structure of the parameter space and the support
    of the parameters (which will imply a default constraining bijector). Here
    we'll use uniform priors over ranges that we expect to contain the
    parameters:

    ```python
    parameter_prior = tfd.JointDistributionNamed({
        'infection_rate': tfd.Uniform(low=0., high=3.),
        'recovery_rate': tfd.Uniform(low=0., high=3.),
    })
    ```

    The model specification itself is identical to that used by
    `tfp.experimental.mcmc.infer_trajectories`, except that each component
    accepts an additional `parameters` keyword argument. We start by specifying
    a parameterized prior on initial states. In this case, our state
    includes the current number of susceptible and infected individuals
    (the third compartment, recovered individuals, is implicitly defined
    to include the remaining population). We'll also include, as auxiliary
    variables, the daily counts of new infections and new recoveries; these
    will help ensure that people shift consistently across compartments.

    ```python
    population_size = 1000
    initial_state_prior_fn = lambda parameters: tfd.JointDistributionNamed({
        'new_infections': tfd.Poisson(parameters['infection_rate']),
        'new_recoveries': tfd.Deterministic(
            tf.broadcast_to(0., tf.shape(parameters['recovery_rate']))),
        'susceptible': (lambda new_infections:
                        tfd.Deterministic(population_size - new_infections)),
        'infected': (lambda new_infections:
                     tfd.Deterministic(new_infections))})
    ```

    **Note**: the state prior must have the same batch shape as the
    passed-in parameters; equivalently, it must sample a full state for each
    parameter particle. If any part of the state prior does not depend
    on the parameters, you must manually ensure that it has the appropriate
    batch shape. For example, in the definition of `new_recoveries` above,
    applying `broadcast_to` with the shape of a parameter ensures that
    the batch shape is maintained.

    Next, we specify a transition model. This takes the state at the
    previous day, along with parameters, and returns a distribution
    over the state for the current day.

    ```python
    def parameterized_infection_dynamics(_, previous_state, parameters):
      new_infections = tfd.Poisson(
          parameters['infection_rate'] * previous_state['infected'] *
          previous_state['susceptible'] / population_size)
      new_recoveries = tfd.Poisson(
          previous_state['infected'] * parameters['recovery_rate'])
      return tfd.JointDistributionNamed({
          'new_infections': new_infections,
          'new_recoveries': new_recoveries,
          'susceptible': lambda new_infections: tfd.Deterministic(
            tf.maximum(0., previous_state['susceptible'] - new_infections)),
          'infected': lambda new_infections, new_recoveries: tfd.Deterministic(
            tf.maximum(0.,
                       (previous_state['infected'] +
                        new_infections - new_recoveries)))})
    ```

    Finally, assume that every day we get to observe noisy counts of new
    infections and recoveries.

    ```python
    def parameterized_infection_observations(_, state, parameters):
      del parameters  # Not used.
      return tfd.JointDistributionNamed({
          'new_infections': tfd.Poisson(state['new_infections'] + 0.1),
          'new_recoveries': tfd.Poisson(state['new_recoveries'] + 0.1)})
    ```

    Combining these components, an `IteratedFilter` augments
    the state space to include parameters that may change over time.

    ```python
    iterated_filter = tfp.experimental.sequential.IteratedFilter(
      parameter_prior=parameter_prior,
      parameterized_initial_state_prior_fn=initial_state_prior_fn,
      parameterized_transition_fn=parameterized_infection_dynamics,
      parameterized_observation_fn=parameterized_infection_observations)
    ```

    We may then run the filter to estimate parameters from a series
    of observations:

    ```python
     # Simulated with `infection_rate=1.2` and `recovery_rate=0.1`.
     observed_values = {
       'new_infections': tf.convert_to_tensor([
          2., 7., 14., 24., 45., 93., 160., 228., 252., 158.,  17.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
       'new_recoveries': tf.convert_to_tensor([
          0., 0., 3., 4., 3., 8., 12., 31., 49., 73., 85., 65., 71.,
          58., 42., 65., 36., 31., 32., 27., 31., 20., 19., 19., 14., 27.])
     }
     parameter_particles = iterated_filter.estimate_parameters(
         observations=observed_values,
         num_iterations=20,
         num_particles=4096,
         initial_perturbation_scale=1.0,
         cooling_schedule=(
             tfp.experimental.sequential.geometric_cooling_schedule(
                 0.001, k=20)),
         seed=test_util.test_seed())
     print('Mean of parameter particles from final iteration: {}'.format(
       tf.nest.map_structure(lambda x: tf.reduce_mean(x[-1], axis=0),
                             parameter_particles)))
     print('Standard deviation of parameter particles from '
           'final iteration: {}'.format(
           tf.nest.map_structure(lambda x: tf.math.reduce_std(x[-1], axis=0),
                                 parameter_particles)))
    ```

    For more control, we could alternately choose to run filtering iterations
    on the augmented model manually, using the filter of our choice.
    For example, manually invoking `infer_trajectories` would allow us
    to inspect the parameter and state values at all timesteps, and their
    corresponding log-probabilities:

    ```python
    trajectories, lps = tfp.experimental.mcmc.infer_trajectories(
      observations=observed_values,
      initial_state_prior=iterated_filter.joint_initial_state_prior,
      transition_fn=functools.partial(
          iterated_filter.joint_transition_fn,
          perturbation_scale=perturbation_scale),
      observation_fn=iterated_filter.joint_observation_fn,
      proposal_fn=iterated_filter.joint_proposal_fn,
      initial_state_proposal=iterated_filter.joint_initial_state_proposal(
          initial_unconstrained_parameters),
      num_particles=4096)
    ```

    #### References:

    [1] Edward L. Ionides, Dao Nguyen, Yves Atchade, Stilian Stoev, and Aaron A.
    King. Inference for dynamic and latent variable models via iterated,
    perturbed Bayes maps. _Proceedings of the National Academy of Sciences_
    112, no. 3: 719-724, 2015.
    https://www.pnas.org/content/pnas/112/3/719.full.pdf
    """
        name = name or 'IteratedFilter'
        with tf.name_scope(name):
            self._parameter_prior = parameter_prior
            self._parameterized_initial_state_prior_fn = (
                parameterized_initial_state_prior_fn)

            # Fall back to the prior's own default bijector when the caller
            # did not supply one.
            if parameter_constraining_bijector is None:
                parameter_constraining_bijector = (
                    parameter_prior.experimental_default_event_space_bijector(
                    ))
            self._parameter_constraining_bijector = parameter_constraining_bijector

            # Augment the prior to include both parameters and states.
            self._joint_initial_state_prior = joint_prior_on_parameters_and_state(
                parameter_prior,
                parameterized_initial_state_prior_fn,
                parameter_constraining_bijector,
                prior_is_constrained=True)

            # Check that prior samples have a consistent number of particles.
            # TODO(davmre): remove the need for dummy shape dependencies,
            # and this check, by using `JointDistributionNamedAutoBatched` with
            # auto-vectorization enabled in `joint_prior_on_parameters_and_state`.

            num_particles_canary = 13
            canary_seed = samplers.zeros_seed()

            def _get_shape_1(x):
                # Some sampled parts expose their values through a `.state`
                # attribute; unwrap those before reading the static shape.
                if hasattr(x, 'state'):
                    x = x.state
                # Keep only dimension 1, which is the particle dimension of a
                # sample drawn with shape `[0, num_particles_canary]`.
                return tf.TensorShape(x.shape[1:2])

            prior_static_sample_shapes = tf.nest.map_structure(
                # Sample shape [0, num_particles_canary] particles (size will be zero)
                # then trim off the leading 0 and (possibly) any event shape.
                # We expect shape [num_particles_canary] to remain.
                _get_shape_1,
                self._joint_initial_state_prior.sample(
                    [0, num_particles_canary], seed=canary_seed))
            shapes_are_consistent = all(
                tensorshape_util.is_compatible_with(
                    s[:1], [num_particles_canary])
                for s in tf.nest.flatten(prior_static_sample_shapes))
            if not shapes_are_consistent:
                raise ValueError(
                    'The specified prior does not generate consistent '
                    'shapes when sampled. Please verify that all parts of '
                    '`initial_state_prior_fn` have batch shape matching '
                    'that of the parameters. This may require creating '
                    '"dummy" dependencies on parameters; for example: '
                    '`tf.broadcast_to(value, tf.shape(parameter))`. (in a '
                    f'test sample with {num_particles_canary} particles, we expected '
                    'all values to have shape compatible with '
                    f'[{num_particles_canary}, ...]; '
                    f'saw shapes {prior_static_sample_shapes})')

            # Augment the transition and observation fns to cover both
            # parameters and states.
            self._joint_transition_fn = augment_transition_fn_with_parameters(
                parameter_prior, parameterized_transition_fn,
                parameter_constraining_bijector)
            self._joint_observation_fn = augment_observation_fn_with_parameters(
                parameterized_observation_fn, parameter_constraining_bijector)

            # If given a proposal for the initial state, augment it into a joint
            # proposal over parameters and states. Otherwise no joint proposal
            # is stored, and the parameterized proposal fn defaults to the
            # parameterized prior fn.
            joint_initial_state_proposal = None
            if parameterized_initial_state_proposal_fn:
                joint_initial_state_proposal = joint_prior_on_parameters_and_state(
                    parameter_prior, parameterized_initial_state_proposal_fn,
                    parameter_constraining_bijector)
            else:
                parameterized_initial_state_proposal_fn = (
                    parameterized_initial_state_prior_fn)
            self._joint_initial_state_proposal = joint_initial_state_proposal
            self._parameterized_initial_state_proposal_fn = (
                parameterized_initial_state_proposal_fn)

            # If given a conditional proposal fn (for non-initial states), augment
            # it to be joint over states and parameters.
            self._joint_proposal_fn = None
            if parameterized_proposal_fn:
                self._joint_proposal_fn = augment_transition_fn_with_parameters(
                    parameter_prior, parameterized_proposal_fn,
                    parameter_constraining_bijector)

            # Rank of each part of the parameter prior's batch shape.
            self._batch_ndims = tf.nest.map_structure(
                ps.rank_from_shape, parameter_prior.batch_shape_tensor())
            self._name = name
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Runs the inner kernel once and accepts or rejects its proposal.

        Args:
          current_state: `Tensor` or Python `list` of `Tensor`s holding the
            current state(s) of the Markov chain(s).
          previous_kernel_results: Results returned by the previous call to
            this method (or by `bootstrap_results`); its `accepted_results`
            field is handed to the inner kernel.
          seed: Optional seed for reproducible sampling. When given, it is
            sanitized and split into one seed for the inner proposal and one
            for the accept/reject draw.

        Returns:
          next_state: The proposed state where the proposal was accepted and
            `current_state` elsewhere.
          kernel_results: A `MetropolisHastingsKernelResults` describing this
            step.

        Raises:
          ValueError: if the inner kernel's results, proposed or previously
            accepted, do not contain the member "target_log_prob".
        """
        # TODO(b/159636942): Clean up after 2020-09-20.
        if seed is None:
            if self._seed_stream.original_seed is not None:
                warnings.warn(mcmc_util.SEED_CTOR_ARG_DEPRECATION_MSG)
            acceptance_seed = samplers.sanitize_seed(self._seed_stream())
            # Without an explicit seed the inner kernel manages its own.
            inner_kwargs = {}
        else:
            seed = samplers.sanitize_seed(seed)  # preserve for kernel results
            proposal_seed, acceptance_seed = samplers.split_seed(seed)
            inner_kwargs = {'seed': proposal_seed}

        with tf.name_scope(mcmc_util.make_name(self.name, 'mh', 'one_step')):
            previously_accepted = previous_kernel_results.accepted_results

            # Take one inner step.
            proposed_state, proposed_results = self.inner_kernel.one_step(
                current_state, previously_accepted, **inner_kwargs)

            if not (has_target_log_prob(proposed_results)
                    and has_target_log_prob(previously_accepted)):
                raise ValueError('"target_log_prob" must be a member of '
                                 '`inner_kernel` results.')

            # Compute log(acceptance_ratio).
            log_ratio_terms = [
                proposed_results.target_log_prob,
                -previously_accepted.target_log_prob,
            ]
            try:
                correction = proposed_results.log_acceptance_correction
                # Skip the correction only when it is an empty list-like.
                if not mcmc_util.is_list_like(correction) or correction:
                    log_ratio_terms.append(correction)
            except AttributeError:
                warnings.warn(
                    'Supplied inner `TransitionKernel` does not have a '
                    '`log_acceptance_correction`. Assuming its value is `0.`')
            log_accept_ratio = mcmc_util.safe_sum(
                log_ratio_terms, name='compute_log_accept_ratio')

            # If proposed state reduces likelihood: randomly accept.
            # If proposed state increases likelihood: always accept.
            # I.e., u < min(1, accept_ratio),  where u ~ Uniform[0,1)
            #       ==> log(u) < log_accept_ratio
            proposed_log_prob = proposed_results.target_log_prob
            uniform_draw = samplers.uniform(
                shape=prefer_static.shape(proposed_log_prob),
                dtype=dtype_util.base_dtype(proposed_log_prob.dtype),
                seed=acceptance_seed)
            log_uniform = tf.math.log(uniform_draw)
            is_accepted = log_uniform < log_accept_ratio

            next_state = mcmc_util.choose(
                is_accepted,
                proposed_state,
                current_state,
                name='choose_next_state')

            # We strip seeds when populating `accepted_results` because unlike
            # other kernel result fields, seeds are not a per-chain value.
            # Thus it is impossible to choose between a previously accepted
            # seed value and a proposed seed, since said choice would need to
            # be made on a per-chain basis.
            accepted_results = mcmc_util.choose(
                is_accepted,
                mcmc_util.strip_seeds(proposed_results),
                previously_accepted,
                name='choose_inner_results')

            if seed is None:
                results_seed = samplers.zeros_seed()
            else:
                results_seed = seed

            kernel_results = MetropolisHastingsKernelResults(
                accepted_results=accepted_results,
                is_accepted=is_accepted,
                log_accept_ratio=log_accept_ratio,
                proposed_state=proposed_state,
                proposed_results=proposed_results,
                extra=[],
                seed=results_seed,
            )

            return next_state, kernel_results
Example #12
0
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Takes one NUTS step by repeatedly doubling the trajectory tree.

        Args:
          current_state: `Tensor` or nested structure of `Tensor`s holding the
            current state(s) of the Markov chain(s). A non-nested value is
            wrapped in a list internally and unwrapped again on return.
          previous_kernel_results: Results of the previous step (or of
            `bootstrap_results`). The fields `target_log_prob`,
            `grads_target_log_prob`, `step_size` and `momentum_state_memory`
            are read.
          seed: Optional seed for reproducible sampling. When omitted, a seed
            is drawn from `self._seed_stream`.

        Returns:
          next_state: The state of the selected candidate, with the same
            nesting as `current_state`.
          kernel_results: A `NUTSKernelResults` describing this step.
        """
        # TODO(b/159636942): Clean up after 2020-09-20.
        if seed is not None:
            # Sanitize before splitting so that the value recorded in
            # `kernel_results.seed` is the sanitized seed rather than the raw
            # object supplied by the caller.
            seed = samplers.sanitize_seed(seed)
            start_trajectory_seed, loop_seed = samplers.split_seed(
                seed, salt='nuts.one_step')
        else:
            if self._seed_stream.original_seed is not None:
                warnings.warn(mcmc_util.SEED_CTOR_ARG_DEPRECATION_MSG)
            start_trajectory_seed, loop_seed = samplers.split_seed(
                self._seed_stream(), salt='nuts.one_step')

        with tf.name_scope(self.name + '.one_step'):
            unwrap_state_list = not tf.nest.is_nested(current_state)
            if unwrap_state_list:
                current_state = [current_state]

            current_target_log_prob = previous_kernel_results.target_log_prob
            [init_momentum, init_energy, log_slice_sample
             ] = self._start_trajectory_batched(current_state,
                                                current_target_log_prob,
                                                seed=start_trajectory_seed)

            def _copy(v):
                # Multiply by ones of shape `[2, 1, ..., 1]` (one trailing 1
                # per dimension of `v`), which broadcasts `v` into two
                # identical copies stacked along a new leading axis.
                return v * ps.ones(ps.pad(
                    [2], paddings=[[0, ps.rank(v)]], constant_values=1),
                                   dtype=v.dtype)

            initial_state = TreeDoublingState(
                momentum=init_momentum,
                state=current_state,
                target=current_target_log_prob,
                target_grad_parts=previous_kernel_results.grads_target_log_prob
            )
            initial_step_state = tf.nest.map_structure(_copy, initial_state)

            if MULTINOMIAL_SAMPLE:
                init_weight = tf.zeros_like(init_energy)  # log(exp(H0 - H0))
            else:
                init_weight = tf.ones_like(init_energy, dtype=TREE_COUNT_DTYPE)

            candidate_state = TreeDoublingStateCandidate(
                state=current_state,
                target=current_target_log_prob,
                target_grad_parts=(
                    previous_kernel_results.grads_target_log_prob),
                energy=init_energy,
                weight=init_weight)

            initial_step_metastate = TreeDoublingMetaState(
                candidate_state=candidate_state,
                is_accepted=tf.zeros_like(init_energy, dtype=tf.bool),
                momentum_sum=init_momentum,
                energy_diff_sum=tf.zeros_like(init_energy),
                leapfrog_count=tf.zeros_like(init_energy,
                                             dtype=TREE_COUNT_DTYPE),
                continue_tree=tf.ones_like(init_energy, dtype=tf.bool),
                not_divergence=tf.ones_like(init_energy, dtype=tf.bool))

            # Convert the write/read instruction into TensorArray so that it is
            # compatible with XLA.
            write_instruction = tf.TensorArray(
                TREE_COUNT_DTYPE,
                size=len(self._write_instruction),
                clear_after_read=False).unstack(self._write_instruction)
            read_instruction = tf.TensorArray(
                tf.int32,
                size=len(self._read_instruction),
                clear_after_read=False).unstack(self._read_instruction)

            current_step_meta_info = OneStepMetaInfo(
                log_slice_sample=log_slice_sample,
                init_energy=init_energy,
                write_instruction=write_instruction,
                read_instruction=read_instruction)

            def _should_continue(iter_, step_seed, state, metastate):
                # Keep doubling while below the depth limit and at least one
                # chain still wants to continue its tree.
                del step_seed, state  # Unused.
                return ((iter_ < self.max_tree_depth)
                        & tf.reduce_any(metastate.continue_tree))

            def _double_tree(iter_, step_seed, state, metastate):
                return self._loop_tree_doubling(
                    previous_kernel_results.step_size,
                    previous_kernel_results.momentum_state_memory,
                    current_step_meta_info, iter_, state, metastate,
                    step_seed)

            _, _, _, new_step_metastate = tf.while_loop(
                cond=_should_continue,
                body=_double_tree,
                loop_vars=(tf.zeros([], dtype=tf.int32,
                                    name='iter'), loop_seed,
                           initial_step_state, initial_step_metastate),
                parallel_iterations=self.parallel_iterations,
            )

            kernel_results = NUTSKernelResults(
                target_log_prob=new_step_metastate.candidate_state.target,
                grads_target_log_prob=(
                    new_step_metastate.candidate_state.target_grad_parts),
                momentum_state_memory=(
                    previous_kernel_results.momentum_state_memory),
                step_size=previous_kernel_results.step_size,
                # Log of the accumulated energy-difference sum divided by the
                # leapfrog count.
                log_accept_ratio=tf.math.log(
                    new_step_metastate.energy_diff_sum /
                    tf.cast(new_step_metastate.leapfrog_count,
                            dtype=new_step_metastate.energy_diff_sum.dtype)),
                leapfrogs_taken=(new_step_metastate.leapfrog_count *
                                 self.unrolled_leapfrog_steps),
                is_accepted=new_step_metastate.is_accepted,
                # A chain that still had `continue_tree` set when the loop
                # ended was stopped by the depth limit rather than by its own
                # termination criterion.
                reach_max_depth=new_step_metastate.continue_tree,
                has_divergence=~new_step_metastate.not_divergence,
                energy=new_step_metastate.candidate_state.energy,
                seed=samplers.zeros_seed() if seed is None else seed,
            )

            result_state = new_step_metastate.candidate_state.state
            if unwrap_state_list:
                result_state = result_state[0]

            return result_state, kernel_results