Example #1
 def _transpose_around_bijector_fn(self,
                                   bijector_fn,
                                   arg,
                                   src_event_ndims,
                                   dest_event_ndims=None,
                                   fn_reduces_event=False,
                                   **kwargs):
     # This function moves the axes corresponding to `self.sample_shape` to the
     # left of the batch shape, then applies `bijector_fn`, then moves the axes
     # corresponding to `self.sample_shape` back to the event part of the shape.
     #
     # `src_event_ndims` and `dest_event_ndims` indicate the expected event rank
     # (omitting `self.sample_shape`) before and after applying `bijector_fn`.
     #
     # This function arose because forward and inverse ended up being quite
     # similar. It was then only a small generalization to also support {F/I}LDJ.
     batch_ndims = ps.rank_from_shape(self.distribution.batch_shape_tensor,
                                      self.distribution.batch_shape)
     extra_sample_ndims = ps.rank_from_shape(self.sample_shape)
     arg_ndims = ps.rank(arg)
     # (1) Expand arg's dims.
     d = arg_ndims - batch_ndims - extra_sample_ndims - src_event_ndims
     arg = tf.reshape(arg,
                      shape=ps.pad(ps.shape(arg),
                                   paddings=[[ps.maximum(0, -d), 0]],
                                   constant_values=1))
     arg_ndims = ps.rank(arg)
     sample_ndims = ps.maximum(0, d)
     # (2) Transpose arg's dims.
     sample_dims = ps.range(0, sample_ndims)
     batch_dims = ps.range(sample_ndims, sample_ndims + batch_ndims)
     extra_sample_dims = ps.range(
         sample_ndims + batch_ndims,
         sample_ndims + batch_ndims + extra_sample_ndims)
     event_dims = ps.range(sample_ndims + batch_ndims + extra_sample_ndims,
                           arg_ndims)
     perm = ps.concat(
         [sample_dims, extra_sample_dims, batch_dims, event_dims], axis=0)
     arg = tf.transpose(arg, perm=perm)
     # (3) Apply underlying bijector.
     result = bijector_fn(arg, **kwargs)
     # (4) Transpose sample_shape from the sample to the event shape.
     result_ndims = ps.rank(result)
     if fn_reduces_event:
         dest_event_ndims = 0
     d = result_ndims - batch_ndims - extra_sample_ndims - dest_event_ndims
     if fn_reduces_event:
         # In some cases, fn may reduce event too far, i.e. ildj may return a
         # scalar `0.`, which won't work with the transpose we do below.
         result = tf.reshape(result,
                             shape=ps.pad(ps.shape(result),
                                          paddings=[[ps.maximum(0, -d), 0]],
                                          constant_values=1))
         result_ndims = ps.rank(result)
     sample_ndims = ps.maximum(0, d)
     sample_dims = ps.range(0, sample_ndims)
     extra_sample_dims = ps.range(sample_ndims,
                                  sample_ndims + extra_sample_ndims)
     batch_dims = ps.range(sample_ndims + extra_sample_ndims,
                           sample_ndims + extra_sample_ndims + batch_ndims)
     event_dims = ps.range(sample_ndims + extra_sample_ndims + batch_ndims,
                           result_ndims)
     perm = ps.concat(
         [sample_dims, batch_dims, extra_sample_dims, event_dims], axis=0)
     return tf.transpose(result, perm=perm)
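The axis bookkeeping above is easier to follow with concrete ranks. Below is a minimal numpy sketch (assuming one sample dim, one batch dim, one `sample_shape` dim, and one event dim; an illustration, not TFP's actual helper) of moving the `sample_shape` axis to the left of the batch axis, applying a function, and moving it back.

import numpy as np

# Hypothetical ranks: [sample, batch, extra_sample, event].
sample, batch, extra_sample, event = (5,), (3,), (2,), (4,)
arg = np.zeros(sample + batch + extra_sample + event)       # shape (5, 3, 2, 4)

# (2) Transpose to [sample, extra_sample, batch, event].
arg_t = np.transpose(arg, (0, 2, 1, 3))                      # shape (5, 2, 3, 4)

# (3) Apply some rank-preserving function in place of `bijector_fn`.
result = np.log1p(arg_t)

# (4) Transpose back to [sample, batch, extra_sample, event].
result = np.transpose(result, (0, 2, 1, 3))                  # shape (5, 3, 2, 4)
assert result.shape == arg.shape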
Example #2
def auto_correlation(x,
                     axis=-1,
                     max_lags=None,
                     center=True,
                     normalize=True,
                     name='auto_correlation'):
    """Auto correlation along one axis.

  Given a `1-D` wide sense stationary (WSS) sequence `X`, the auto correlation
  `RXX` may be defined as  (with `E` expectation and `Conj` complex conjugate)

  ```
  RXX[m] := E{ W[m] Conj(W[0]) } = E{ W[0] Conj(W[-m]) },
  W[n]   := (X[n] - MU) / S,
  MU     := E{ X[0] },
  S**2   := E{ (X[0] - MU) Conj(X[0] - MU) }.
  ```

  This function takes the viewpoint that `x` is (along one axis) a finite
  sub-sequence of a realization of (WSS) `X`, and then uses `x` to produce an
  estimate of `RXX[m]` as follows:

  After extending `x` from length `L` to `inf` by zero padding, the auto
  correlation estimate `rxx[m]` is computed for `m = 0, 1, ..., max_lags` as

  ```
  rxx[m] := (L - m)**-1 sum_n w[n + m] Conj(w[n]),
  w[n]   := (x[n] - mu) / s,
  mu     := L**-1 sum_n x[n],
  s**2   := L**-1 sum_n (x[n] - mu) Conj(x[n] - mu)
  ```

  The error in this estimate is proportional to `1 / sqrt(len(x) - m)`, so users
  often set `max_lags` small enough so that the entire output is meaningful.

  Note that since `mu` is an imperfect estimate of `E{ X[0] }`, and we divide by
  `len(x) - m` rather than `len(x) - m - 1`, our estimate of auto correlation
  contains a slight bias, which goes to zero as `len(x) - m --> infinity`.

  Args:
    x:  `float32` or `complex64` `Tensor`.
    axis:  Python `int`. The axis number along which to compute correlation.
      Other dimensions index different batch members.
    max_lags:  Positive `int` tensor.  The maximum value of `m` to consider (in
      equation above).  If `max_lags >= x.shape[axis]`, we effectively re-set
      `max_lags` to `x.shape[axis] - 1`.
    center:  Python `bool`.  If `False`, do not subtract the mean estimate `mu`
      from `x[n]` when forming `w[n]`.
    normalize:  Python `bool`.  If `False`, do not divide by the variance
      estimate `s**2` when forming `w[n]`.
    name:  `String` name to prepend to created ops.

  Returns:
    `rxx`: `Tensor` of same `dtype` as `x`.  `rxx.shape[i] = x.shape[i]` for
      `i != axis`, and `rxx.shape[axis] = max_lags + 1`.

  Raises:
    TypeError:  If `x` is not a supported type.
  """
    # Implementation details:
    # Extend length N / 2 1-D array x to length N by zero padding onto the end.
    # Then, set
    #   F[x]_k := sum_n x_n exp{-i 2 pi k n / N }.
    # It is not hard to see that
    #   F[x]_k Conj(F[x]_k) = F[R]_k, where
    #   R_m := sum_n x_n Conj(x_{(n - m) mod N}).
    # One can also check that R_m / (N / 2 - m) is an unbiased estimate of RXX[m].

    # Since F[x] is the DFT of x, this leads us to a zero-padding and FFT/IFFT
    # based version of estimating RXX.
    # Note that this is a special case of the Wiener-Khinchin Theorem.
    with tf.name_scope(name):
        x = tf.convert_to_tensor(x, name='x')

        # Rotate dimensions of x in order to put axis at the rightmost dim.
        # FFT op requires this.
        rank = prefer_static.rank(x)
        if axis < 0:
            axis = rank + axis
        shift = rank - 1 - axis
        # Suppose x.shape[axis] = T, so there are T 'time' steps.
        #   ==> x_rotated.shape = B + [T],
        # where B is x_rotated's batch shape.
        x_rotated = distribution_util.rotate_transpose(x, shift)

        if center:
            x_rotated -= tf.reduce_mean(x_rotated, axis=-1, keepdims=True)

        # x_len = N / 2 from above explanation.  The length of x along axis.
        # Get a value for x_len that works in all cases.
        x_len = prefer_static.shape(x_rotated)[-1]

        # TODO(langmore) Investigate whether this zero padding helps or hurts.  At
        # the moment it is necessary so that all FFT implementations work.
        # Zero pad to the next power of 2 greater than 2 * x_len, which equals
        # 2**(ceil(Log_2(2 * x_len))).  Note: Log_2(X) = Log_e(X) / Log_e(2).
        x_len_float64 = tf.cast(x_len, np.float64)
        target_length = tf.pow(
            np.float64(2.),
            tf.math.ceil(tf.math.log(x_len_float64 * 2) / np.log(2.)))
        pad_length = tf.cast(target_length - x_len_float64, np.int32)

        # We should have:
        # x_rotated_pad.shape = x_rotated.shape[:-1] + [T + pad_length]
        #                     = B + [T + pad_length]
        x_rotated_pad = distribution_util.pad(x_rotated,
                                              axis=-1,
                                              back=True,
                                              count=pad_length)

        dtype = x.dtype
        if not dtype_util.is_complex(dtype):
            if not dtype_util.is_floating(dtype):
                raise TypeError(
                    'Argument x must have either float or complex dtype'
                    ' found: {}'.format(dtype))
            x_rotated_pad = tf.complex(
                x_rotated_pad,
                dtype_util.as_numpy_dtype(dtype_util.real_dtype(dtype))(0.))

        # Autocorrelation is IFFT of power-spectral density (up to some scaling).
        fft_x_rotated_pad = tf.signal.fft(x_rotated_pad)
        spectral_density = fft_x_rotated_pad * tf.math.conj(fft_x_rotated_pad)
        # shifted_product is R[m] from above detailed explanation.
        # It is the inner product sum_n X[n] * Conj(X[n - m]).
        shifted_product = tf.signal.ifft(spectral_density)

        # Cast back to real-valued if x was real to begin with.
        shifted_product = tf.cast(shifted_product, dtype)

        # Figure out if we can deduce the final static shape, and set max_lags.
        # Use x_rotated as a reference, because it has the time dimension at the far
        # right, and was created before we performed all sorts of crazy shape
        # manipulations.
        know_static_shape = True
        if not tensorshape_util.is_fully_defined(x_rotated.shape):
            know_static_shape = False
        if max_lags is None:
            max_lags = x_len - 1
        else:
            max_lags = tf.convert_to_tensor(max_lags, name='max_lags')
            max_lags_ = tf.get_static_value(max_lags)
            if max_lags_ is None or not know_static_shape:
                know_static_shape = False
                max_lags = tf.minimum(x_len - 1, max_lags)
            else:
                max_lags = min(x_len - 1, max_lags_)

        # Chop off the padding.
        # We allow users to provide a huge max_lags, but cut it off here.
        # shifted_product_chopped.shape = x_rotated.shape[:-1] + [max_lags + 1]
        shifted_product_chopped = shifted_product[..., :max_lags + 1]

        # If possible, set shape.
        if know_static_shape:
            chopped_shape = tensorshape_util.as_list(x_rotated.shape)
            chopped_shape[-1] = min(x_len, max_lags + 1)
            shifted_product_chopped.set_shape(chopped_shape)

        # Recall R[m] is a sum of N / 2 - m nonzero terms x[n] Conj(x[n - m]).  The
        # other terms were zeros arising only due to zero padding.
        # `denominator = (N / 2 - m)` (defined below) is the proper term to
        # divide by to make this an unbiased estimate of the expectation
        # E[X[n] Conj(X[n - m])].
        x_len = tf.cast(x_len, dtype_util.real_dtype(dtype))
        max_lags = tf.cast(max_lags, dtype_util.real_dtype(dtype))
        denominator = x_len - tf.range(0., max_lags + 1.)
        denominator = tf.cast(denominator, dtype)
        shifted_product_rotated = shifted_product_chopped / denominator

        if normalize:
            shifted_product_rotated /= shifted_product_rotated[..., :1]

        # Transpose dimensions back to those of x.
        return distribution_util.rotate_transpose(shifted_product_rotated,
                                                  -shift)
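For intuition, here is a numpy-only sketch of the same zero-pad/FFT/IFFT estimator for a single 1-D series. `auto_correlation_np` is a hypothetical helper for illustration; it omits the batching, complex-dtype, and static-shape handling done above.

import numpy as np

def auto_correlation_np(x, max_lags):
    # Illustrative re-implementation of the estimator described in the docstring.
    x = np.asarray(x, dtype=np.float64)
    n = len(x)
    w = (x - x.mean()) / x.std()                    # center=True, normalize=True
    # Zero pad to a power of 2 that is >= 2 * n so circular correlation == linear.
    n_fft = int(2 ** np.ceil(np.log2(2 * n)))
    f = np.fft.fft(w, n_fft)
    r = np.fft.ifft(f * np.conj(f)).real[:max_lags + 1]
    return r / (n - np.arange(max_lags + 1))        # unbiased denominators n - m

rng = np.random.default_rng(0)
rxx = auto_correlation_np(rng.normal(size=1000), max_lags=5)
# rxx[0] is ~1.0 because w was standardized; rxx[1:] is ~0 for white noise.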
Example #3
def where_left_justified_mask(mask, vals1, vals2, name=None):
  """Like `tf.where`, but broadcasts the `mask` left-justified."""
  with tf.name_scope(name or 'where_left_justified_mask'):
    target_rank = ps.maximum(ps.rank(vals1), ps.rank(vals2))
    bcast_mask = left_justified_expand_dims_to(mask, target_rank)
    return tf.where(bcast_mask, vals1, vals2)
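A quick sketch of what the left-justified broadcast buys over plain `tf.where`: the mask is aligned with the leading (e.g. chain) dimensions by appending size-1 dimensions on the right. Numpy stand-in, assuming that is the helper's semantics:

import numpy as np

mask = np.array([True, False])             # shape [2], e.g. one flag per chain
vals1 = np.zeros([2, 3])
vals2 = np.ones([2, 3])

bcast_mask = mask[:, np.newaxis]           # shape [2, 1]: left-justified
out = np.where(bcast_mask, vals1, vals2)   # row 0 from vals1, row 1 from vals2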
Example #4
    def _loop_build_sub_tree(self, directions, integrator,
                             current_step_meta_info, iter_,
                             energy_diff_sum_previous,
                             momentum_cumsum_previous, leapfrogs_taken,
                             prev_tree_state, candidate_tree_state,
                             continue_tree_previous, not_divergent_previous,
                             momentum_state_memory, seed):
        """Base case in tree doubling."""
        acceptance_seed, next_seed = samplers.split_seed(seed)
        with tf.name_scope('loop_build_sub_tree'):
            # Take one leapfrog step in the direction v and check divergence
            [
                next_momentum_parts, next_state_parts, next_target,
                next_target_grad_parts
            ] = integrator(prev_tree_state.momentum, prev_tree_state.state,
                           prev_tree_state.target,
                           prev_tree_state.target_grad_parts)

            next_tree_state = TreeDoublingState(
                momentum=next_momentum_parts,
                state=next_state_parts,
                target=next_target,
                target_grad_parts=next_target_grad_parts)
            momentum_cumsum = [
                p0 + p1 for p0, p1 in zip(momentum_cumsum_previous,
                                          next_momentum_parts)
            ]
            # If the tree has not terminated previously, we count this leapfrog.
            leapfrogs_taken = tf.where(continue_tree_previous,
                                       leapfrogs_taken + 1, leapfrogs_taken)

            write_instruction = current_step_meta_info.write_instruction
            read_instruction = current_step_meta_info.read_instruction
            init_energy = current_step_meta_info.init_energy

            if GENERALIZED_UTURN:
                state_to_write = momentum_cumsum_previous
                state_to_check = momentum_cumsum
            else:
                state_to_write = next_state_parts
                state_to_check = next_state_parts

            batch_shape = ps.shape(next_target)
            has_not_u_turn_init = ps.ones(batch_shape, dtype=tf.bool)

            read_index = read_instruction.gather([iter_])[0]
            no_u_turns_within_tree = has_not_u_turn_at_all_index(  # pylint: disable=g-long-lambda
                read_index,
                directions,
                momentum_state_memory,
                next_momentum_parts,
                state_to_check,
                has_not_u_turn_init,
                log_prob_rank=ps.rank(next_target),
                shard_axis_names=self.experimental_shard_axis_names)

            # Get index to write state into memory swap
            write_index = write_instruction.gather([iter_])
            momentum_state_memory = MomentumStateSwap(
                momentum_swap=[
                    _safe_tensor_scatter_nd_update(old, [write_index], [new])
                    for old, new in zip(momentum_state_memory.momentum_swap,
                                        next_momentum_parts)
                ],
                state_swap=[
                    _safe_tensor_scatter_nd_update(old, [write_index], [new])
                    for old, new in zip(momentum_state_memory.state_swap,
                                        state_to_write)
                ])

            energy = compute_hamiltonian(
                next_target,
                next_momentum_parts,
                shard_axis_names=self.experimental_shard_axis_names)
            current_energy = tf.where(tf.math.is_nan(energy),
                                      tf.constant(-np.inf, dtype=energy.dtype),
                                      energy)
            energy_diff = current_energy - init_energy

            if MULTINOMIAL_SAMPLE:
                not_divergent = -energy_diff < self.max_energy_diff
                weight_sum = log_add_exp(candidate_tree_state.weight,
                                         energy_diff)
                log_accept_thresh = energy_diff - weight_sum
            else:
                log_slice_sample = current_step_meta_info.log_slice_sample
                not_divergent = log_slice_sample - energy_diff < self.max_energy_diff
                # Uniform sampling on the trajectory within the subtree across valid
                # samples.
                is_valid = log_slice_sample <= energy_diff
                weight_sum = tf.where(is_valid,
                                      candidate_tree_state.weight + 1,
                                      candidate_tree_state.weight)
                log_accept_thresh = tf.where(
                    is_valid,
                    -tf.math.log(tf.cast(weight_sum, dtype=tf.float32)),
                    tf.constant(-np.inf, dtype=tf.float32))
            u = tf.math.log1p(-samplers.uniform(shape=batch_shape,
                                                dtype=log_accept_thresh.dtype,
                                                seed=acceptance_seed))
            is_sample_accepted = u <= log_accept_thresh

            next_candidate_tree_state = TreeDoublingStateCandidate(
                state=[
                    bu.where_left_justified_mask(is_sample_accepted, s0, s1)
                    for s0, s1 in zip(next_state_parts,
                                      candidate_tree_state.state)
                ],
                target=bu.where_left_justified_mask(
                    is_sample_accepted, next_target,
                    candidate_tree_state.target),
                target_grad_parts=[
                    bu.where_left_justified_mask(is_sample_accepted, grad0,
                                                 grad1)
                    for grad0, grad1 in zip(
                        next_target_grad_parts,
                        candidate_tree_state.target_grad_parts)
                ],
                energy=bu.where_left_justified_mask(
                    is_sample_accepted, current_energy,
                    candidate_tree_state.energy),
                weight=weight_sum)

            continue_tree = not_divergent & continue_tree_previous
            continue_tree_next = no_u_turns_within_tree & continue_tree

            not_divergent_tokeep = tf.where(
                continue_tree_previous, not_divergent,
                ps.ones(batch_shape, dtype=tf.bool))

            # min(1., exp(energy_diff)).
            exp_energy_diff = tf.math.exp(tf.minimum(energy_diff, 0.))
            energy_diff_sum = tf.where(
                continue_tree, energy_diff_sum_previous + exp_energy_diff,
                energy_diff_sum_previous)

            return (
                iter_ + 1,
                next_seed,
                energy_diff_sum,
                momentum_cumsum,
                leapfrogs_taken,
                next_tree_state,
                next_candidate_tree_state,
                continue_tree_next,
                not_divergent_previous & not_divergent_tokeep,
                momentum_state_memory,
            )
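The multinomial-acceptance branch above can be read as keeping the new point with probability exp(energy_diff) / (exp(weight) + exp(energy_diff)), computed in log space. A small numpy sketch under that reading, with illustrative values:

import numpy as np

def log_add_exp(a, b):
    # Stable log(exp(a) + exp(b)), mirroring the `log_add_exp` used above.
    m = np.maximum(a, b)
    return m + np.log(np.exp(a - m) + np.exp(b - m))

candidate_weight = np.array([-0.3, -1.2])   # log-weights of current candidates
energy_diff = np.array([-0.1, -2.0])        # current_energy - init_energy

weight_sum = log_add_exp(candidate_weight, energy_diff)
log_accept_thresh = energy_diff - weight_sum

rng = np.random.default_rng(0)
u = np.log1p(-rng.uniform(size=2))          # log of a Uniform(0, 1) draw
is_sample_accepted = u <= log_accept_thresh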
Example #5
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Takes one step of the TransitionKernel.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s).
      previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
        `list` of `Tensor`s representing internal calculations made within the
        previous call to this function (or as returned by `bootstrap_results`).
      seed: Optional, a seed for reproducible sampling.

    Returns:
      next_state: `Tensor` or Python `list` of `Tensor`s representing the
        next state(s) of the Markov chain(s).
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.
        This includes replica states.
    """

        # The code below propagates states of shape
        #  [n_replica] + batch_shape + event_shape through one step.
        #
        # The step is done in three parts:
        #  1) Call one_step to transition states via a tempered version of
        #     self.target_log_prob_fn (see _replica_target_log_prob).
        #  2) Permute values in states
        #  3) Update state-dependent values, such as log_probs.
        #
        # We chose to swap states, rather than temperatures, because...
        # (i)  If swapping temperatures, you *still* have to swap log_probs to
        #      determine acceptance, as well as states (for kernel results).
        #      So it's just as difficult to swap temperatures.
        # (ii) If swapping temperatures, you have to take care to swap any user-
        #      supplied temperature related things (like step size).
        #      A-priori, we don't know what else will need to be swapped!
        # (iii) In both cases, the kernel results need to be updated in a
        #       non-trivial manner, so we either special-case or use bootstrap.

        with tf.name_scope(mcmc_util.make_name(self.name, 'remc', 'one_step')):
            # Force a read in case the `inverse_temperatures` is a `tf.Variable`.
            inverse_temperatures = tf.convert_to_tensor(
                previous_kernel_results.inverse_temperatures,
                name='inverse_temperatures')

            target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
                self.target_log_prob_fn, inverse_temperatures)
            # Seed handling complexity is due to users possibly expecting an old-style
            # stateful seed to be passed to `self.make_kernel_fn`, and no seed
            # expected by `kernel.one_step`.
            # In other words:
            # - We try `make_kernel_fn` without a seed first; this is the future. The
            #   kernel will receive a seed later, as part of `one_step`.
            # - If the user code doesn't like that (Python complains about a missing
            #   required argument), we warn and fall back to the previous behavior.
            try:
                inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                    target_log_prob_for_inner_kernel)
            except TypeError as e:
                if 'argument' not in str(e):
                    raise
                warnings.warn(
                    'The `seed` argument to `ReplicaExchangeMC`s `make_kernel_fn` is '
                    'deprecated. `TransitionKernel` instances now receive seeds via '
                    '`one_step`.')
                inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                    target_log_prob_for_inner_kernel, self._seed_stream())

            # Now that we've constructed the TransitionKernel instance:
            # - If we were given a seed, we sanitize it to stateless and pass along
            #   to `kernel.one_step`. If it doesn't like that, we crash and propagate
            #   the error.  Rationale: The contract is stateless sampling given
            #   seed, and doing otherwise would not meet it.
            # - If not given a seed, we don't pass one along. This avoids breaking
            #   underlying kernels lacking a `seed` arg on `one_step`.
            # TODO(b/159636942): Clean up after 2020-09-20.
            if seed is not None:
                seed = samplers.sanitize_seed(seed)
                inner_seed, swap_seed, logu_seed = samplers.split_seed(
                    seed, n=3, salt='remc_one_step')
                inner_kwargs = dict(seed=inner_seed)
            else:
                if self._seed_stream.original_seed is not None:
                    warnings.warn(mcmc_util.SEED_CTOR_ARG_DEPRECATION_MSG)
                inner_kwargs = {}
                swap_seed, logu_seed = samplers.split_seed(self._seed_stream())
            [
                pre_swap_replica_states,
                pre_swap_replica_results,
            ] = inner_kernel.one_step(
                previous_kernel_results.post_swap_replica_states,
                previous_kernel_results.post_swap_replica_results,
                **inner_kwargs)

            pre_swap_replica_target_log_prob = _get_field(
                # These are tempered log probs (have been divided by temperature).
                pre_swap_replica_results,
                'target_log_prob')

            dtype = pre_swap_replica_target_log_prob.dtype
            replica_and_batch_shape = ps.shape(
                pre_swap_replica_target_log_prob)
            batch_shape = replica_and_batch_shape[1:]
            replica_and_batch_rank = ps.rank(pre_swap_replica_target_log_prob)
            num_replica = ps.size0(inverse_temperatures)

            inverse_temperatures = mcmc_util.left_justified_broadcast_to(
                inverse_temperatures, replica_and_batch_shape)

            # Now that each replica has done one_step, it is time to consider swaps.

            # swaps.shape = [n_replica]; `swaps` is a "once only" permutation, meaning it
            # is achievable by a sequence of pairwise permutations, where each element
            # is moved at most once.
            # E.g. if swaps = [1, 0, 2], we will consider swapping temperatures 0 and
            # 1, keeping 2 fixed.  This exact same swap is considered for *every*
            # batch member.  Of course some batch members may accept and some reject.
            try:
                swaps = tf.cast(
                    self.swap_proposal_fn(  # pylint: disable=not-callable
                        num_replica,
                        batch_shape=batch_shape,
                        seed=swap_seed,
                        step_count=previous_kernel_results.step_count),
                    dtype=tf.int32)
            except TypeError as e:
                if 'step_count' not in str(e):
                    raise
                warnings.warn(
                    'The `swap_proposal_fn` given to ReplicaExchangeMC did not accept '
                    'the `step_count` argument. Falling back to omitting the '
                    'argument. This fallback will be removed after 24-Oct-2020.'
                )
                swaps = tf.cast(
                    self.swap_proposal_fn(  # pylint: disable=not-callable
                        num_replica,
                        batch_shape=batch_shape,
                        seed=swap_seed),
                    dtype=tf.int32)

            null_swaps = mcmc_util.left_justified_expand_dims_like(
                tf.range(num_replica, dtype=swaps.dtype), swaps)
            swaps = _maybe_embed_swaps_validation(swaps, null_swaps,
                                                  self.validate_args)

            # Un-temper the log probs.  E.g., for replica k, at point x_k, this is
            # Log[p(x_k)], and *not* Log[p_x(x_k)] = Log[p(x_k)] * beta_k.
            untempered_pre_swap_replica_target_log_prob = (
                pre_swap_replica_target_log_prob / inverse_temperatures)

            # Since `swaps` is its own inverse permutation we automatically know the
            # swap counterpart: range(num_replica). We use this idea to compute the
            # acceptance in a vectorized manner at the cost of wasting roughly half
            # our computation. Although we could use `unique` to solve this problem,
            # we expect the cost of `unique` to be higher than the dozens of wasted
            # arithmetic calculations. Worse, it'd mean we need dynamic sized Tensors
            # (eg, using `tf.where(bool)`) and so we wouldn't be able to XLA compile.

            # Note: diffs would normally be "proposed - current" however energy is
            # flipped since `energy == -log_prob`.
            energy_diff = (untempered_pre_swap_replica_target_log_prob -
                           mcmc_util.index_remapping_gather(
                               untempered_pre_swap_replica_target_log_prob,
                               swaps,
                               name='gather_swap_tlp'))
            swapped_inverse_temperatures = mcmc_util.index_remapping_gather(
                inverse_temperatures, swaps, name='gather_swap_temps')
            inverse_temp_diff = swapped_inverse_temperatures - inverse_temperatures

            # If i and j are swapping, log_accept_ratio entries i and j are equal.
            log_accept_ratio = (energy_diff *
                                mcmc_util.left_justified_expand_dims_to(
                                    inverse_temp_diff, replica_and_batch_rank))

            log_accept_ratio = tf.where(tf.math.is_finite(log_accept_ratio),
                                        log_accept_ratio,
                                        tf.constant(-np.inf, dtype=dtype))

            # Produce Log[Uniform] draws that are identical at swapped indices.
            log_uniform = tf.math.log(
                samplers.uniform(shape=replica_and_batch_shape,
                                 dtype=dtype,
                                 seed=logu_seed))
            anchor_swaps = tf.minimum(swaps, null_swaps)
            log_uniform = mcmc_util.index_remapping_gather(
                log_uniform, anchor_swaps)

            is_swap_accepted_mask = tf.less(log_uniform,
                                            log_accept_ratio,
                                            name='is_swap_accepted_mask')

            def _swap_tensor(x):
                return mcmc_util.choose(
                    is_swap_accepted_mask,
                    mcmc_util.index_remapping_gather(x, swaps), x)

            post_swap_replica_states = [
                _swap_tensor(s) for s in pre_swap_replica_states
            ]

            expanded_null_swaps = mcmc_util.left_justified_broadcast_to(
                null_swaps, replica_and_batch_shape)
            is_swap_proposed = _compute_swap_notmatrix(
                # Broadcast both so they have shape [num_replica] + batch_shape.
                # This (i) makes them have same shape as is_swap_accepted, and
                # (ii) keeps shape consistent if someday swaps has a batch shape.
                expanded_null_swaps,
                mcmc_util.left_justified_broadcast_to(swaps,
                                                      replica_and_batch_shape))

            # To get is_swap_accepted in ordered position, we use
            # _compute_swap_notmatrix on current and next replica positions.
            post_swap_replica_position = _swap_tensor(expanded_null_swaps)

            is_swap_accepted = _compute_swap_notmatrix(
                post_swap_replica_position, expanded_null_swaps)

            if self._state_includes_replicas:
                post_swap_states = post_swap_replica_states
            else:
                post_swap_states = [s[0] for s in post_swap_replica_states]

            post_swap_replica_results = _make_post_swap_replica_results(
                pre_swap_replica_results, inverse_temperatures,
                swapped_inverse_temperatures, is_swap_accepted_mask,
                _swap_tensor)

            if mcmc_util.is_list_like(current_state):
                # We *always* canonicalize the states in the kernel results.
                states = post_swap_states
            else:
                states = post_swap_states[0]

            post_swap_kernel_results = ReplicaExchangeMCKernelResults(
                post_swap_replica_states=post_swap_replica_states,
                pre_swap_replica_results=pre_swap_replica_results,
                post_swap_replica_results=post_swap_replica_results,
                is_swap_proposed=is_swap_proposed,
                is_swap_accepted=is_swap_accepted,
                is_swap_proposed_adjacent=_sub_diag(is_swap_proposed),
                is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
                # Store the original pkr.inverse_temperatures in case it's a
                # `tf.Variable`.
                inverse_temperatures=previous_kernel_results.inverse_temperatures,
                swaps=swaps,
                step_count=previous_kernel_results.step_count + 1,
                seed=samplers.zeros_seed() if seed is None else seed,
            )

            return states, post_swap_kernel_results
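For a single pair of replicas, the swap-acceptance arithmetic above reduces to the usual parallel-tempering ratio exp[(beta_j - beta_i) * (log p(x_i) - log p(x_j))]. A numpy sketch with two replicas and illustrative numbers:

import numpy as np

log_prob = np.array([-3.0, -7.0])     # untempered log p(x_k) at each replica
beta = np.array([1.0, 0.5])           # inverse temperatures
swaps = np.array([1, 0])              # propose swapping replicas 0 and 1

energy_diff = log_prob - log_prob[swaps]             # log p(x_i) - log p(x_j)
inverse_temp_diff = beta[swaps] - beta               # beta_j - beta_i
log_accept_ratio = energy_diff * inverse_temp_diff   # equal at indices i and j

rng = np.random.default_rng(0)
log_u = np.log(rng.uniform())         # one draw, shared by the pair (the real
                                      # code gathers at `anchor_swaps`)
is_swap_accepted = log_u < log_accept_ratio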
Example #6
  def _sample_n(self, n, seed=None, conditional_input=None, training=False):
    """Samples from the distribution, with optional conditional input.

    Args:
      n: `int`, number of samples desired.
      seed: `int`, seed for RNG. Setting a random seed enforces reproducibility
        of the samples between sessions (not within a single session).
      conditional_input: `Tensor` on which to condition the distribution (e.g.
        class labels), or `None`.
      training: `bool` or `None`. If `bool`, it controls the dropout layer,
        where `True` implies dropout is active. If `None`, it defers to Keras'
        handling of train/eval status.
    Returns:
      samples: a `Tensor` of shape `[n, height, width, num_channels]`.
    """
    if conditional_input is not None:
      conditional_input = tf.convert_to_tensor(
          conditional_input, dtype=self.dtype)
      conditional_event_rank = tensorshape_util.rank(self.conditional_shape)
      conditional_input_shape = prefer_static.shape(conditional_input)
      conditional_sample_rank = prefer_static.rank(
          conditional_input) - conditional_event_rank

      # If `conditional_input` has no sample dimensions, prepend a sample
      # dimension
      if conditional_sample_rank == 0:
        conditional_input = conditional_input[tf.newaxis, ...]
        conditional_sample_rank = 1

      # Assert that the conditional event shape in the `PixelCnnNetwork` is the
      # same as that implied by `conditional_input`.
      conditional_event_shape = conditional_input_shape[
          conditional_sample_rank:]
      with tf.control_dependencies([
          tf.assert_equal(self.conditional_shape, conditional_event_shape)]):

        conditional_sample_shape = conditional_input_shape[
            :conditional_sample_rank]
        repeat = n // prefer_static.reduce_prod(conditional_sample_shape)
        h = tf.reshape(
            conditional_input,
            prefer_static.concat([(-1,), self.conditional_shape], axis=0))
        h = tf.tile(h,
                    prefer_static.pad(
                        [repeat], paddings=[[0, conditional_event_rank]],
                        constant_values=1))

    samples_0 = tf.random.uniform(
        prefer_static.concat([(n,), self.event_shape], axis=0),
        minval=-1., maxval=1., dtype=self.dtype, seed=seed)
    inputs = samples_0 if conditional_input is None else [samples_0, h]
    params_0 = self.network(inputs, training=training)
    samples_0 = self._sample_channels(*params_0, seed=seed)

    image_height, image_width, _ = tensorshape_util.as_list(self.event_shape)
    def loop_body(index, samples):
      """Loop for iterative pixel sampling.

      Args:
        index: 0D `Tensor` of type `int32`. Index of the current pixel.
        samples: 4D `Tensor`. Images with pixels sampled in raster order, up to
          pixel `[index]`, with dimensions `[batch_size, height, width,
          num_channels]`.

      Returns:
        samples: 4D `Tensor`. Images with pixels sampled in raster order, up to
          and including pixel `[index]`, with dimensions `[batch_size, height,
          width, num_channels]`.
      """
      inputs = samples if conditional_input is None else [samples, h]
      params = self.network(inputs, training=training)
      samples_new = self._sample_channels(*params, seed=seed)

      # Update the current pixel
      samples = tf.transpose(samples, [1, 2, 3, 0])
      samples_new = tf.transpose(samples_new, [1, 2, 3, 0])
      row, col = index // image_width, index % image_width
      updates = samples_new[row, col, ...][tf.newaxis, ...]
      samples = tf.tensor_scatter_nd_update(samples, [[row, col]], updates)
      samples = tf.transpose(samples, [3, 0, 1, 2])

      return index + 1, samples

    index0 = tf.zeros([], dtype=tf.int32)

    # Construct the while loop for sampling
    total_pixels = image_height * image_width
    loop_cond = lambda ind, _: tf.less(ind, total_pixels)
    init_vars = (index0, samples_0)
    _, samples = tf.while_loop(loop_cond, loop_body, init_vars,
                               parallel_iterations=1)

    transformed_samples = (self._low +
                           0.5 * (self._high - self._low) * (samples + 1.))
    return tf.round(transformed_samples)
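The while loop above samples pixels in raster order, re-running the network once per pixel so each pixel is conditioned on everything sampled before it. A stripped-down numpy skeleton of that loop, with a hypothetical stand-in for the PixelCNN network:

import numpy as np

height, width, channels, batch = 4, 4, 1, 2
rng = np.random.default_rng(0)
samples = rng.uniform(-1., 1., size=(batch, height, width, channels))

def dummy_network_sample(images):
    # Hypothetical stand-in for `self._sample_channels(*self.network(images))`.
    return np.tanh(images + rng.normal(size=images.shape))

for index in range(height * width):
    row, col = index // width, index % width
    proposal = dummy_network_sample(samples)
    samples[:, row, col, :] = proposal[:, row, col, :]   # update one pixel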
Example #7
def _compute_log_acceptance_correction(current_momentums,
                                       proposed_momentums,
                                       independent_chain_ndims,
                                       name=None):
    """Helper to `kernel` which computes the log acceptance-correction.

  A sufficient but not necessary condition for the existence of a stationary
  distribution, `p(x)`, is "detailed balance", i.e.:

  ```none
  p(x'|x) p(x) = p(x|x') p(x')
  ```

  In the Metropolis-Hastings algorithm, a state is proposed according to
  `g(x'|x)` and accepted according to `a(x'|x)`, hence
  `p(x'|x) = g(x'|x) a(x'|x)`.

  Inserting this into the detailed balance equation implies:

  ```none
      g(x'|x) a(x'|x) p(x) = g(x|x') a(x|x') p(x')
  ==> a(x'|x) / a(x|x') = p(x') / p(x) [g(x|x') / g(x'|x)]    (*)
  ```

  One definition of `a(x'|x)` which satisfies (*) is:

  ```none
  a(x'|x) = min(1, p(x') / p(x) [g(x|x') / g(x'|x)])
  ```

  (To see that this satisfies (*), notice that under this definition at most
  one of `a(x'|x)` and `a(x|x')` can be other than one.)

  We call the bracketed term the "acceptance correction".

  In the case of UncalibratedHMC, the log acceptance-correction is not the log
  proposal-ratio. UncalibratedHMC augments the state-space with momentum, z.
  Assuming a standard Gaussian distribution for momentums, the chain eventually
  converges to:

  ```none
  p([x, z]) propto= target_prob(x) exp(-0.5 z**2)
  ```

  Relating this back to Metropolis-Hastings parlance, for HMC we have:

  ```none
  p([x, z]) propto= target_prob(x) exp(-0.5 z**2)
  g([x, z] | [x', z']) = g([x', z'] | [x, z])
  ```

  In other words, the MH bracketed term is `1`. However, because we desire to
  use a general MH framework, we can place the momentum probability ratio inside
  the Metropolis-correction factor, thus getting an acceptance probability:

  ```none
                       target_prob(x')
  accept_prob(x'|x) = -----------------  [exp(-0.5 z**2) / exp(-0.5 z'**2)]
                       target_prob(x)
  ```

  (Note: we actually need to handle the kinetic energy change at each leapfrog
  step, but this is the idea.)

  Args:
    current_momentums: `Tensor` representing the value(s) of the current
      momentum(s) of the state (parts).
    proposed_momentums: `Tensor` representing the value(s) of the proposed
      momentum(s) of the state (parts).
    independent_chain_ndims: Scalar `int` `Tensor` representing the number of
      leftmost `Tensor` dimensions which index independent chains.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'compute_log_acceptance_correction').

  Returns:
    log_acceptance_correction: `Tensor` representing the `log`
      acceptance-correction.  (See docstring for mathematical definition.)
  """
    with tf.name_scope(name or 'compute_log_acceptance_correction'):
        sum_sq = lambda v: tf.reduce_sum(
            v**2.,
            axis=prefer_static.range(  # pylint: disable=g-long-lambda
                independent_chain_ndims, prefer_static.rank(v)))
        current_kinetic = tf.add_n([sum_sq(v) for v in current_momentums])
        proposed_kinetic = tf.add_n([sum_sq(v) for v in proposed_momentums])
        return 0.5 * mcmc_util.safe_sum([current_kinetic, -proposed_kinetic])
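Concretely, the correction is half the difference of the kinetic energies, with each kinetic energy summed over all but the leading chain dimensions. A numpy sketch with one state part and illustrative shapes:

import numpy as np

independent_chain_ndims = 1
current_momentums = [np.random.default_rng(0).normal(size=(3, 4))]   # [chains, event]
proposed_momentums = [np.random.default_rng(1).normal(size=(3, 4))]

def sum_sq(v):
    # Sum of squares over the non-chain dimensions.
    return (v ** 2).sum(axis=tuple(range(independent_chain_ndims, v.ndim)))

current_kinetic = sum(sum_sq(v) for v in current_momentums)
proposed_kinetic = sum(sum_sq(v) for v in proposed_momentums)
log_acceptance_correction = 0.5 * (current_kinetic - proposed_kinetic)  # shape [3]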
Example #8
        def op(x, kernel):
            input_dtype = dtype_util.common_dtype([x, kernel],
                                                  dtype_hint=tf.float32)
            x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
            kernel = tf.convert_to_tensor(kernel,
                                          dtype=input_dtype,
                                          name='kernel')

            batch_shape, event_shape = ps.split(ps.shape(x),
                                                num_or_size_splits=[-1, 3])
            xh, xw, c_in = ps.unstack(event_shape, num=3)

            kernel_shape = ps.shape(kernel)
            c_out = kernel_shape[-1]
            kernel_batch = kernel_shape[:-2]
            assertions = _maybe_validate_input_shapes(
                kernel_shape,
                channels_in=c_in,
                filter_height=fh,
                filter_width=fw,
                validate_args=validate_args)

            with tf.control_dependencies(assertions):

                # If the kernel does not have batch shape, fall back to
                # `conv2d_transpose` (unless dilations > 1, which is not implemented in
                # `conv2d_transpose`).
                if (tf.get_static_value(ps.rank(kernel)) == 2
                        and all(d == 1 for d in dilations)):
                    return _call_conv2d_transpose(x,
                                                  kernel=kernel,
                                                  filter_shape=filter_shape,
                                                  strides=(strides, ) * rank,
                                                  padding=padding,
                                                  dilations=dilations,
                                                  c_out=c_out,
                                                  batch_shape=batch_shape,
                                                  event_shape=event_shape)

                n = ps.maximum(0, ps.rank(x) - 3)
                paddings = ps.pad(padding_vals,
                                  paddings=[[n, 1], [0, 0]],
                                  constant_values=0)

                x_pad = tf.pad(x, paddings=paddings, constant_values=0)
                x_pad_shape = ps.shape(x_pad)[:-3]
                flat_shape = ps.pad(x_pad_shape,
                                    paddings=[[0, 1]],
                                    constant_values=-1)
                flat_x = tf.reshape(x_pad, shape=flat_shape)

                idx, s = im2row_index(
                    (xh + tf.reduce_sum(padding_vals[0]),
                     xw + tf.reduce_sum(padding_vals[1]), c_in),
                    block_shape=(sub_fh, sub_fw),
                    slice_step=(1, 1),
                    dilations=dilations)

                x_ = tf.gather(flat_x, indices=idx, axis=-1)
                im_x = tf.reshape(x_,
                                  shape=ps.concat([x_pad_shape, s], axis=0))

                # Add channels to subkernel indices
                idx_event = event_ind * [[c_in, 1]]
                idx_event_channels = (idx_event[tf.newaxis] + tf.stack(
                    [ps.range(c_in),
                     tf.zeros(
                         (c_in, ), dtype=dtype)], axis=-1)[:, tf.newaxis, :])
                idx_event = tf.squeeze(tf.batch_to_space(idx_event_channels,
                                                         block_shape=[c_in],
                                                         crops=[[0, 0]]),
                                       axis=0)
                idx_event_broadcast = tf.broadcast_to(
                    idx_event,
                    shape=ps.concat(
                        [kernel_batch, ps.shape(idx_event)], axis=0))

                # Add cartesian product of batch indices, since scatter_nd can only be
                # applied to leading dimensions.
                idx_batch = tf.stack(tf.meshgrid(*[
                    ps.range(b_, delta=1, dtype=dtype)
                    for b_ in tf.unstack(kernel_batch)
                ],
                                                 indexing='ij'),
                                     axis=ps.size(kernel_batch))

                idx_batch = tf.cast(idx_batch,
                                    dtype=dtype)  # empty tensor is float

                idx_batch_broadcast = idx_batch[..., tf.newaxis, :] + tf.zeros(
                    (ps.shape(idx_event)[0], 1), dtype=dtype)
                idx_kernel = tf.concat(
                    [idx_batch_broadcast, idx_event_broadcast], axis=-1)

                kernel_mat = tf.scatter_nd(
                    idx_kernel,
                    updates=kernel,
                    shape=ps.cast(ps.concat([
                        kernel_batch,
                        [sub_fh * sub_fw * c_in, strides**2, c_out]
                    ],
                                            axis=0),
                                  dtype=dtype))

                kernel_mat = tf.reshape(
                    kernel_mat,
                    shape=ps.concat(
                        [ps.shape(kernel_mat)[:-2], [strides**2 * c_out]],
                        axis=0))

                kernel_mat = kernel_mat[..., tf.newaxis, :, :]
                out = tf.matmul(im_x, kernel_mat)
                broadcast_batch_shape = ps.broadcast_shape(
                    batch_shape, kernel_batch)

                if strides > 1:
                    tot_size = tf.reduce_prod(broadcast_batch_shape)
                    flat_out = tf.reshape(out,
                                          shape=ps.concat([[tot_size],
                                                           ps.shape(out)[-3:]],
                                                          axis=0))
                    out = tf.nn.depth_to_space(flat_out, block_size=strides)

                if padding == 'VALID':
                    out_height = fh + strides * (xh - 1)
                    out_width = fw + strides * (xw - 1)
                elif padding == 'SAME':
                    out_height = xh * strides
                    out_width = xw * strides

                out = out[..., truncate_top:truncate_top + out_height,
                          truncate_left:truncate_left + out_width, :]
                out = tf.reshape(
                    out,
                    shape=ps.concat([
                        broadcast_batch_shape, [out_height, out_width, c_out]
                    ],
                                    axis=0))
                return out
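The `depth_to_space` call near the end is what turns the `strides**2 * c_out` output channels into spatial upsampling. A minimal sketch of just that step, with illustrative shapes:

import tensorflow as tf

strides, c_out = 2, 3
flat_out = tf.random.normal([1, 5, 5, strides**2 * c_out])     # [batch, h, w, s*s*c]
upsampled = tf.nn.depth_to_space(flat_out, block_size=strides)
# upsampled.shape == [1, 10, 10, 3]: each s*s block of channels becomes an
# s x s spatial block.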
Example #9
        def op(x, kernel):
            input_dtype = dtype_util.common_dtype([x, kernel],
                                                  dtype_hint=tf.float32)
            x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
            kernel = tf.convert_to_tensor(kernel,
                                          dtype=input_dtype,
                                          name='kernel')

            batch_shape, event_shape = ps.split(ps.shape(x),
                                                num_or_size_splits=[-1, 3])
            xh, xw, c_in = ps.unstack(event_shape, num=3)

            kernel_shape = ps.shape(kernel)
            c_out = kernel_shape[-1]
            kernel_batch = kernel_shape[:-2]
            assertions = _maybe_validate_input_shapes(
                kernel_shape,
                channels_in=c_in,
                filter_height=fh,
                filter_width=fw,
                validate_args=validate_args)

            with tf.control_dependencies(assertions):
                # If the kernel does not have batch shape, fall back to
                # `conv2d_transpose` (unless dilations > 1, which is not implemented in
                # `conv2d_transpose`).
                if (tf.get_static_value(ps.rank(kernel)) == 2
                        and all(d == 1 for d in dilations)):
                    return _call_conv2d_transpose(x, kernel, filter_shape,
                                                  strides, padding, dilations,
                                                  c_out, batch_shape,
                                                  event_shape)

                n = ps.maximum(0, ps.rank(x) - 3)
                paddings = ps.pad(padding_vals,
                                  paddings=[[n, 1], [0, 0]],
                                  constant_values=0)
                x_pad = tf.pad(x, paddings=paddings, constant_values=0)

                ex_h = xh + tf.reduce_sum(padding_vals[0]) - sub_fh + 1
                ex_w = xw + tf.reduce_sum(padding_vals[1]) - sub_fw + 1

                def loop_body(i, outputs):
                    subkernel_ind = kernels_ind.read(i)
                    fh_, fw_ = ps.unstack(ps.shape(subkernel_ind), num=2)
                    eh = ex_h + fh_ - 1
                    ew = ex_w + fw_ - 1

                    subkernel_ind = ps.reshape(ps.reshape(
                        subkernel_ind * c_in, shape=[-1])[:, tf.newaxis] +
                                               ps.range(c_in),
                                               shape=[-1])

                    k = tf.gather(kernel, subkernel_ind, axis=-2)
                    ind, shape = im2row_index([eh, ew, c_in],
                                              block_shape=(fh_, fw_),
                                              slice_step=(1, 1),
                                              dilations=dilations)
                    x_i = x_pad[..., :eh, :ew, :]
                    x_i_shape = ps.shape(x_i)
                    flat_shape = ps.pad(x_i_shape[:-3],
                                        paddings=[[0, 1]],
                                        constant_values=-1)
                    flat_x = tf.reshape(x_i, flat_shape)
                    x_ = tf.gather(flat_x, ind, axis=-1)
                    im_x = tf.reshape(
                        x_, ps.concat([x_i_shape[:-3], shape], axis=0))
                    outputs = outputs.write(
                        i,
                        tf.matmul(
                            im_x,
                            tf.reshape(
                                k,
                                ps.concat([
                                    kernel_batch, [1, fh_ * fw_ * c_in, c_out]
                                ],
                                          axis=0))))
                    return i + 1, outputs

                outputs = tf.TensorArray(dtype=input_dtype,
                                         infer_shape=False,
                                         size=1,
                                         dynamic_size=True)

                _, outputs = tf.while_loop(lambda i, _: i < sh * sw, loop_body,
                                           [0, outputs])

                y = outputs.concat()

                m = tf.reduce_prod(ps.shape(y)[:-3])
                y_ = tf.reshape(y,
                                shape=ps.concat([[m], ps.shape(y)[-3:]],
                                                axis=0))
                y2 = tf.batch_to_space(y_,
                                       strides,
                                       crops=tf.zeros([2, 2], dtype=tf.int64))
                broadcast_batch_shape = ps.broadcast_shape(
                    batch_shape, kernel_batch)
                y2 = tf.reshape(
                    y2,
                    ps.concat([broadcast_batch_shape,
                               ps.shape(y2)[-3:]],
                              axis=0))

                if padding == 'VALID':
                    out_height = fh + sh * (xh - 1)
                    out_width = fw + sw * (xw - 1)
                elif padding == 'SAME':
                    out_height = xh * sh
                    out_width = xw * sw

                return y2[..., truncate_top:truncate_top + out_height,
                          truncate_left:truncate_left + out_width, :]
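This variant computes one output per stride offset inside the while loop, stacks them along the batch axis via the `TensorArray`, and then interleaves them spatially with `batch_to_space`. A minimal sketch of that last step, with illustrative shapes:

import tensorflow as tf

sh = sw = 2
y_ = tf.reshape(tf.range(sh * sw * 3 * 3, dtype=tf.float32),
                [sh * sw, 3, 3, 1])                              # [4, 3, 3, 1]
y2 = tf.batch_to_space(y_, block_shape=[sh, sw], crops=[[0, 0], [0, 0]])
# y2.shape == [1, 6, 6, 1]: the 4 offset grids are interleaved spatially.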
Example #10
    def op(x, kernel):
        input_dtype = dtype_util.common_dtype([x, kernel],
                                              dtype_hint=tf.float32)
        x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
        kernel = tf.convert_to_tensor(kernel, dtype=input_dtype, name='kernel')

        batch_shape, event_shape = ps.split(ps.shape(x),
                                            num_or_size_splits=[-1, 3])
        xh, xw, c_in = ps.unstack(event_shape, num=3)
        fh, fw = filter_shape

        assertions = _maybe_validate_input_shapes(ps.shape(kernel),
                                                  channels_in=c_in,
                                                  filter_height=fh,
                                                  filter_width=fw,
                                                  validate_args=validate_args)

        with tf.control_dependencies(assertions):
            if tf.get_static_value(ps.rank(kernel)) == 2:
                flat_x = tf.reshape(x,
                                    shape=ps.concat([[-1], event_shape],
                                                    axis=0))
                flat_y = tf.nn.conv2d(x,
                                      filters=tf.reshape(
                                          kernel, shape=[fh, fw, c_in, -1]),
                                      strides=strides,
                                      padding=padding,
                                      data_format='NHWC',
                                      dilations=dilations)
                output_shape = ps.shape(flat_y)[-3:]
                return tf.reshape(flat_y,
                                  shape=ps.concat([batch_shape, output_shape],
                                                  axis=0))

            pad_values = [
                _get_conv_padding(xdim,
                                  filter_dim=k,
                                  stride=s,
                                  dilation=d,
                                  padding=padding)
                for (xdim, k, s,
                     d) in zip((xh, xw), filter_shape, strides, dilations)
            ]

            idx, shape = im2row_index(
                (xh + sum(pad_values[0]), xw + sum(pad_values[1]), c_in),
                block_shape=filter_shape,
                slice_step=strides,
                dilations=dilations,
                dtype=dtype)

            if padding == 'SAME':
                n = ps.maximum(0, ps.rank(x) - 3)
                paddings = ps.pad(pad_values,
                                  paddings=[[n, 1], [0, 0]],
                                  constant_values=0)
                x = tf.pad(x, paddings=paddings, constant_values=0)

            flat_shape = ps.pad(batch_shape,
                                paddings=[[0, 1]],
                                constant_values=-1)
            flat_x = tf.gather(tf.reshape(x, shape=flat_shape),
                               indices=idx,
                               axis=-1)
            im_x = tf.reshape(flat_x,
                              shape=ps.concat([batch_shape, shape], axis=0))
            return tf.matmul(im_x, kernel[..., tf.newaxis, :, :])
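The gather/matmul pattern above is the familiar im2col (here "im2row") formulation of convolution. A numpy sketch for the stride-1, 'VALID', unbatched case (illustrative only):

import numpy as np

xh, xw, c_in, fh, fw, c_out = 5, 5, 2, 3, 3, 4
rng = np.random.default_rng(0)
x = rng.normal(size=(xh, xw, c_in))
kernel = rng.normal(size=(fh * fw * c_in, c_out))   # flattened [fh, fw, c_in, c_out]

oh, ow = xh - fh + 1, xw - fw + 1
# Each row of `patches` is one flattened receptive field, mirroring im2row_index.
patches = np.stack([x[i:i + fh, j:j + fw, :].ravel()
                    for i in range(oh) for j in range(ow)])   # [oh*ow, fh*fw*c_in]
y = (patches @ kernel).reshape(oh, ow, c_out)       # convolution as one matmul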
Example #11
        def op(x, kernel):
            input_dtype = dtype_util.common_dtype([x, kernel],
                                                  dtype_hint=tf.float32)
            x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
            kernel = tf.convert_to_tensor(kernel,
                                          dtype=input_dtype,
                                          name='kernel')

            batch_shape, event_shape = ps.split(ps.shape(x),
                                                num_or_size_splits=[-1, 3])
            xh, xw, c_in = ps.unstack(event_shape, num=3)
            kernel_shape = ps.shape(kernel)
            assertions = _maybe_validate_input_shapes(
                kernel_shape,
                channels_in=c_in,
                filter_height=fh,
                filter_width=fw,
                validate_args=validate_args)

            with tf.control_dependencies(assertions):
                # If the kernel does not have batch shape, fall back to
                # `conv2d_transpose` (unless dilations > 1, which is not implemented in
                # `conv2d_transpose`).
                if (tf.get_static_value(ps.rank(kernel)) == 2
                        and all(d == 1 for d in dilations)):
                    return _call_conv2d_transpose(x, kernel, filter_shape,
                                                  strides, padding, dilations,
                                                  kernel_shape[-1],
                                                  batch_shape, event_shape)

                idx, shape = im2row_index((xh * sh + sum(pad_values[0]),
                                           xw * sw + sum(pad_values[1]), c_in),
                                          block_shape=filter_shape,
                                          slice_step=(1, 1),
                                          dilations=dilations,
                                          dtype=dtype,
                                          transpose=True)

                n = ps.maximum(0, ps.rank(x) - 3)
                paddings = ps.pad(pad_values,
                                  paddings=[[n, 1], [0, 0]],
                                  constant_values=0)

                # Interleave the rows and columns of the input with rows and columns of
                # zeros equal to the number of strides.
                x_half_dilated = tf.concat([
                    tf.zeros(ps.concat([batch_shape, (xh * xw, sw - 1, c_in)],
                                       axis=0),
                             dtype=input_dtype),
                    tf.reshape(x,
                               shape=ps.concat(
                                   [batch_shape, (xh * xw, 1, c_in)], axis=0))
                ],
                                           axis=-2)
                y = tf.reshape(x_half_dilated,
                               shape=ps.concat(
                                   [batch_shape, (xh, 1, xw * sw, c_in)],
                                   axis=0))

                x = tf.reshape(tf.concat([
                    tf.zeros(ps.concat(
                        [batch_shape, (xh, sh - 1, xw * sw, c_in)], axis=0),
                             dtype=input_dtype), y
                ],
                                         axis=-3),
                               shape=ps.concat(
                                   [batch_shape, (xh * sh, xw * sw, c_in)],
                                   axis=0))
                x_pad = tf.pad(x, paddings=paddings, constant_values=0)
                flat_shape = ps.pad(batch_shape,
                                    paddings=[[0, 1]],
                                    constant_values=-1)
                flat_x = tf.gather(tf.reshape(x_pad, shape=flat_shape),
                                   indices=idx,
                                   axis=-1)
                im_x = tf.reshape(flat_x,
                                  shape=ps.concat([batch_shape, shape],
                                                  axis=0))
                return tf.matmul(im_x, kernel[..., tf.newaxis, :, :])
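The concat/reshape block above implements the usual transposed-convolution trick: interleave the input with `stride - 1` zeros along each spatial axis, so the strided transpose convolution reduces to a stride-1 convolution on the dilated input. A 1-D sketch of just that interleaving step (shapes chosen for illustration):

```python
import tensorflow as tf

x = tf.constant([1., 2., 3.])
stride = 2
# Each element is preceded by (stride - 1) zeros, mirroring the concat order
# above (zeros first, then the value), so length n becomes length n * stride.
zeros = tf.zeros([3, stride - 1])
interleaved = tf.reshape(tf.concat([zeros, x[:, tf.newaxis]], axis=-1), [-1])
# interleaved == [0., 1., 0., 2., 0., 3.]
```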
Example #12
def _sample_next(target_log_prob_fn,
                 current_state_parts,
                 step_sizes,
                 max_doublings,
                 current_target_log_prob,
                 batch_rank,
                 seed=None,
                 name=None):
  """Applies a single iteration of slice sampling update.

  Applies hit and run style slice sampling. Chooses a uniform random direction
  on the unit sphere in the event space. Applies the one dimensional slice
  sampling update along that direction.

  Args:
    target_log_prob_fn: Python callable which takes an argument like
      `*current_state_parts` and returns its (possibly unnormalized) log-density
      under the target distribution.
    current_state_parts: Python `list` of `Tensor`s representing the current
      state(s) of the Markov chain(s). The first `independent_chain_ndims` of
      the `Tensor`(s) index different chains.
    step_sizes: Python `list` of `Tensor`s. Provides a measure of the width
      of the density. Used to find the slice bounds. Must broadcast with the
      shape of `current_state_parts`.
    max_doublings: Integer number of doublings to allow while locating the slice
      boundaries.
    current_target_log_prob: `Tensor` representing the value of
      `target_log_prob_fn(*current_state_parts)`. The only reason to specify
      this argument is to reduce TF graph size.
    batch_rank: Integer. The number of axes in the state that correspond to
      independent batches.
    seed: Tensor seed pair.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'find_slice_bounds').

  Returns:
    proposed_state_parts: Tensor or Python list of `Tensor`s representing the
      state(s) of the Markov chain(s) at each result step. Has same shape as
      input `current_state_parts`.
    proposed_target_log_prob: `Tensor` representing the value of
      `target_log_prob_fn` at `next_state`.
    bounds_satisfied: Boolean `Tensor` of the same shape as the log density.
      True indicates whether an interval containing the slice for that
      batch was found successfully.
    direction: `Tensor` or Python list of `Tensor`s representing the direction
      along which the slice was sampled. Has the same shape and dtype(s) as
      `current_state_parts`.
    upper_bounds: `Tensor` of batch shape and the dtype of the input state. The
      upper bounds of the slices along the sampling direction.
    lower_bounds: `Tensor` of batch shape and the dtype of the input state. The
      lower bounds of the slices along the sampling direction.
  """
  direction_seed, slice_seed = samplers.split_seed(seed)
  with tf.name_scope(name or 'sample_next'):
    # First step: Choose a random direction.
    # Direction is a list of tensors. The i'th tensor should have the same shape
    # as the i'th state part.
    direction = _choose_random_direction(current_state_parts,
                                         batch_rank=batch_rank,
                                         seed=direction_seed)

    # Interpolates the step sizes for the chosen direction.
    # Applies an ellipsoidal interpolation to compute the step size for
    # the chosen direction. Suppose we are given step sizes for each direction.
    # Label these s_1, s_2, ... s_k. These are the step sizes to use if moving
    # in a direction parallel to one of the axes. Consider an ellipsoid which
    # intercepts the i'th axis at s_i. The step size for a direction specified
    # by the unit vector (n_1, n_2 ...n_k) is then defined as the intersection
    # of the line through this vector with this ellipsoid.
    #
    # One can show that the length of the vector from the origin to the
    # intersection point is given by:
    # 1 / sqrt(n_1^2 / s_1^2  + n_2^2 / s_2^2  + ...).
    #
    # Proof:
    # The equation of the ellipsoid is:
    # Sum_i [x_i^2 / s_i^2 ] = 1. Let n be a unit direction vector. Points
    # along the line given by n may be parameterized as alpha*n where alpha is
    # the distance along the vector. Plugging this into the equation for the
    # ellipsoid, we get:
    # alpha^2 ( n_1^2 / s_1^2 + n_2^2 / s_2^2 + ...) = 1
    # so alpha = \sqrt { \frac{1} { ( n_1^2 / s_1^2 + n_2^2 / s_2^2 + ...) } }
    reduce_axes = [ps.range(batch_rank, ps.rank(dirn_part))
                   for dirn_part in direction]

    components = [
        tf.reduce_sum((dirn_part / step_size)**2, axis=reduce_axes[i])
        for i, (step_size, dirn_part) in enumerate(zip(step_sizes, direction))
    ]
    step_size = tf.math.rsqrt(tf.add_n(components))
    # Computes the rank of a tensor. Uses the static rank if possible.
    state_part_ranks = [ps.rank(part)
                        for part in current_state_parts]

    def _step_along_direction(alpha):
      """Converts the scalar alpha into an n-dim vector with full state info.

      Computes x_0 + alpha * direction where x_0 is the current state and
      direction is the direction chosen above.

      Args:
        alpha: A tensor of shape equal to the batch dimensions of
          `current_state_parts`.

      Returns:
        state_parts: Tensor or Python list of `Tensor`s representing the
          state(s) of the Markov chain(s) for a given alpha and a given chosen
          direction. Has the same shape as `current_state_parts`.
      """
      padded_alphas = [_right_pad(alpha, final_rank=part_rank)
                       for part_rank in state_part_ranks]

      state_parts = [state_part + padded_alpha * direction_part
                     for state_part, direction_part, padded_alpha in
                     zip(current_state_parts, direction, padded_alphas)]
      return state_parts

    def projected_target_log_prob_fn(alpha):
      """The target log density projected along the chosen direction.

      Args:
        alpha: A tensor of shape equal to the batch dimensions of
          `current_state_parts`.

      Returns:
        Target log density evaluated at x_0 + alpha * direction where x_0 is the
        current state and direction is the direction chosen above. Has the same
        shape as `alpha`.
      """
      return target_log_prob_fn(*_step_along_direction(alpha))

    alpha_init = tf.zeros_like(current_target_log_prob,
                               dtype=current_state_parts[0].dtype)
    [
        next_alpha,
        next_target_log_prob,
        bounds_satisfied,
        upper_bounds,
        lower_bounds
    ] = ssu.slice_sampler_one_dim(projected_target_log_prob_fn,
                                  x_initial=alpha_init,
                                  max_doublings=max_doublings,
                                  step_size=step_size, seed=slice_seed)
    return [
        _step_along_direction(next_alpha),
        next_target_log_prob,
        bounds_satisfied,
        direction,
        upper_bounds,
        lower_bounds
    ]
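The ellipsoidal interpolation derived in the comments above has the closed form `alpha = 1 / sqrt(sum_i n_i**2 / s_i**2)`. A small numeric check with made-up step sizes:

```python
import numpy as np

s = np.array([2.0, 0.5])                    # per-axis step sizes
n = np.array([1.0, 1.0]) / np.sqrt(2.0)     # unit direction at 45 degrees
step = 1.0 / np.sqrt(np.sum(n**2 / s**2))   # ~0.686, between 0.5 and 2.0

# Moving exactly along axis i recovers s_i, as the ellipsoid construction requires.
e0 = np.array([1.0, 0.0])
assert np.isclose(1.0 / np.sqrt(np.sum(e0**2 / s**2)), s[0])
```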
Example #13
  def one_step(self, current_state, previous_kernel_results, seed=None):
    """Runs one iteration of Slice Sampler.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s). The first `r` dimensions
        index independent chains,
        `r = tf.rank(target_log_prob_fn(*current_state))`.
      previous_kernel_results: `collections.namedtuple` containing `Tensor`s
        representing values from previous calls to this function (or from the
        `bootstrap_results` function.)
      seed: Optional, a seed for reproducible sampling.

    Returns:
      next_state: Tensor or Python list of `Tensor`s representing the state(s)
        of the Markov chain(s) after taking exactly one step. Has same type and
        shape as `current_state`.
      kernel_results: `collections.namedtuple` of internal calculations used to
        advance the chain.

    Raises:
      ValueError: if there isn't exactly one `step_size` or a list of the same
        length as `current_state`.
      TypeError: if `not target_log_prob.dtype.is_floating`.
    """
    seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.

    with tf.name_scope(mcmc_util.make_name(self.name, 'slice', 'one_step')):
      with tf.name_scope('initialize'):
        [
            current_state_parts,
            step_sizes,
            current_target_log_prob
        ] = _prepare_args(
            self.target_log_prob_fn,
            current_state,
            self.step_size,
            previous_kernel_results.target_log_prob,
            maybe_expand=True)

        max_doublings = ps.convert_to_shape_tensor(
            value=self.max_doublings, dtype=tf.int32, name='max_doublings')

      independent_chain_ndims = ps.rank(current_target_log_prob)

      [
          next_state_parts,
          next_target_log_prob,
          bounds_satisfied,
          direction,
          upper_bounds,
          lower_bounds
      ] = _sample_next(
          self.target_log_prob_fn,
          current_state_parts,
          step_sizes,
          max_doublings,
          current_target_log_prob,
          independent_chain_ndims,
          seed=seed,
      )

      def maybe_flatten(x):
        return x if mcmc_util.is_list_like(current_state) else x[0]

      return [
          maybe_flatten(next_state_parts),
          SliceSamplerKernelResults(
              target_log_prob=next_target_log_prob,
              bounds_satisfied=bounds_satisfied,
              direction=direction,
              upper_bounds=upper_bounds,
              lower_bounds=lower_bounds,
              seed=seed,
          ),
      ]
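For context, a minimal way to drive this kernel through the public API might look like the sketch below, assuming the surrounding class is exposed as `tfp.mcmc.SliceSampler` and using `tfp.mcmc.sample_chain`; the target and parameter values are illustrative only.

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

target = tfd.Normal(loc=0., scale=1.)
kernel = tfp.mcmc.SliceSampler(
    target_log_prob_fn=target.log_prob,
    step_size=1.0,
    max_doublings=5)

# Four independent chains; trace_fn=None returns only the states.
samples = tfp.mcmc.sample_chain(
    num_results=200,
    num_burnin_steps=100,
    current_state=tf.zeros([4]),
    kernel=kernel,
    trace_fn=None)
```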
Example #14
def _event_size(tensor_structure, event_ndims):
  """Returns the number of elements in the event-portion of a structure."""
  event_shapes = nest.map_structure(
      lambda t, nd: ps.slice(ps.shape(t), [ps.rank(t)-nd], [nd]),
      tensor_structure, event_ndims)
  return sum(ps.reduce_prod(shape) for shape in nest.flatten(event_shapes))
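A quick sketch of what `_event_size` computes, using stock `tf.nest` on an assumed structure: the trailing `event_ndims` dimensions of each part are multiplied out and summed across the structure.

```python
import numpy as np
import tensorflow as tf

structure = {'a': tf.zeros([4, 3, 2]), 'b': tf.zeros([4, 5])}
event_ndims = {'a': 2, 'b': 1}

event_shapes = tf.nest.map_structure(
    lambda t, nd: t.shape[t.shape.rank - nd:], structure, event_ndims)
total = sum(np.prod(s.as_list()) for s in tf.nest.flatten(event_shapes))
# total == 3 * 2 + 5 == 11
```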
Example #15
    def one_step(self, current_state, previous_kernel_results):
        """Takes one step of the TransitionKernel.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s).
      previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
        `list` of `Tensor`s representing internal calculations made within the
        previous call to this function (or as returned by `bootstrap_results`).

    Returns:
      next_state: `Tensor` or Python `list` of `Tensor`s representing the
        next state(s) of the Markov chain(s).
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.
        This includes replica states.
    """
        # The code below propagates one step states of shape
        #  [n_replica] + batch_shape + event_shape.
        #
        # The step is done in three parts:
        #  1) Call one_step to transition states via a tempered version of
        #     self.target_log_prob_fn (see _replica_target_log_prob).
        #  2) Permute values in states
        #  3) Update state-dependent values, such as log_probs.
        #
        # We chose to swap states, rather than temperatures, because...
        # (i)  If swapping temperatures, you *still* have to swap log_probs to
        #      determine acceptance, as well as states (for kernel results).
        #      So it's just as difficult to swap temperatures.
        # (ii) If swapping temperatures, you have to take care to swap any user-
        #      supplied temperature related things (like step size).
        #      A-priori, we don't know what else will need to be swapped!
        # (iii) In both cases, the kernel results need to be updated in a
        #      non-trivial manner, so we either special-case or use bootstrap.

        with tf.name_scope(mcmc_util.make_name(self.name, 'remc', 'one_step')):
            # Force a read in case the `inverse_temperatures` is a `tf.Variable`.
            inverse_temperatures = tf.convert_to_tensor(
                previous_kernel_results.inverse_temperatures,
                name='inverse_temperatures')

            inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                _make_replica_target_log_prob_fn(self.target_log_prob_fn,
                                                 inverse_temperatures),
                self._seed_stream())

            [
                pre_swap_replica_states,
                pre_swap_replica_results,
            ] = inner_kernel.one_step(
                previous_kernel_results.post_swap_replica_states,
                previous_kernel_results.post_swap_replica_results)

            pre_swap_replica_target_log_prob = _get_field(
                # These are tempered log probs (have been divided by temperature).
                pre_swap_replica_results,
                'target_log_prob')

            dtype = pre_swap_replica_target_log_prob.dtype
            replica_and_batch_shape = prefer_static.shape(
                pre_swap_replica_target_log_prob)
            batch_shape = replica_and_batch_shape[1:]
            replica_and_batch_rank = prefer_static.rank(
                pre_swap_replica_target_log_prob)
            num_replica = prefer_static.size0(inverse_temperatures)

            inverse_temperatures = mcmc_util.left_justified_broadcast_to(
                inverse_temperatures, replica_and_batch_shape)

            # Now that each replica has done one_step, it is time to consider swaps.

            # swaps.shape = [n_replica], and is a "once only" permutation, meaning it
            # is achievable by a sequence of pairwise permutations, where each element
            # is moved at most once.
            # E.g. if swaps = [1, 0, 2], we will consider swapping temperatures 0 and
            # 1, keeping 2 fixed.  This exact same swap is considered for *every*
            # batch member.  Of course some batch members may accept and some reject.
            swaps = tf.cast(
                self.swap_proposal_fn(  # pylint: disable=not-callable
                    num_replica,
                    batch_shape=batch_shape,
                    seed=self._seed_stream()),
                dtype=tf.int32)
            null_swaps = mcmc_util.left_justified_expand_dims_like(
                tf.range(num_replica, dtype=swaps.dtype), swaps)
            swaps = _maybe_embed_swaps_validation(swaps, null_swaps,
                                                  self.validate_args)

            # Un-temper the log probs.  E.g., for replica k, at point x_k, this is
            # Log[p(x_k)], and *not* Log[p_x(x_k)] = Log[p(x_k)] * beta_k.
            untempered_pre_swap_replica_target_log_prob = (
                pre_swap_replica_target_log_prob / inverse_temperatures)

            # Since `swaps` is its own inverse permutation we automatically know the
            # swap counterpart: range(num_replica). We use this idea to compute the
            # acceptance in a vectorized manner at the cost of wasting roughly half
            # our computation. Although we could use `unique` to solve this problem,
            # we expect the cost of `unique` to be higher than the dozens of wasted
            # arithmetic calculations. Worse, it'd mean we need dynamic sized Tensors
            # (eg, using `tf.where(bool)`) and so we wouldn't be able to XLA compile.

            # Note: diffs would normally be "proposed - current" however energy is
            # flipped since `energy == -log_prob`.
            energy_diff = (untempered_pre_swap_replica_target_log_prob -
                           mcmc_util.index_remapping_gather(
                               untempered_pre_swap_replica_target_log_prob,
                               swaps,
                               name='gather_swap_tlp'))
            swapped_inverse_temperatures = mcmc_util.index_remapping_gather(
                inverse_temperatures, swaps, name='gather_swap_temps')
            inverse_temp_diff = swapped_inverse_temperatures - inverse_temperatures

            # If i and j are swapping, entries i and j of log_accept_ratio are equal.
            log_accept_ratio = (energy_diff *
                                mcmc_util.left_justified_expand_dims_to(
                                    inverse_temp_diff, replica_and_batch_rank))

            log_accept_ratio = tf.where(tf.math.is_finite(log_accept_ratio),
                                        log_accept_ratio,
                                        tf.constant(-np.inf, dtype=dtype))

            # Produce Log[Uniform] draws that are identical at swapped indices.
            log_uniform = tf.math.log(
                tf.random.uniform(shape=replica_and_batch_shape,
                                  dtype=dtype,
                                  seed=self._seed_stream()))
            anchor_swaps = tf.minimum(swaps, null_swaps)
            log_uniform = mcmc_util.index_remapping_gather(
                log_uniform, anchor_swaps)

            is_swap_accepted_mask = tf.less(log_uniform,
                                            log_accept_ratio,
                                            name='is_swap_accepted_mask')

            def _swap_tensor(x):
                return mcmc_util.choose(
                    is_swap_accepted_mask,
                    mcmc_util.index_remapping_gather(x, swaps), x)

            post_swap_replica_states = [
                _swap_tensor(s) for s in pre_swap_replica_states
            ]

            expanded_null_swaps = mcmc_util.left_justified_broadcast_to(
                null_swaps, replica_and_batch_shape)
            is_swap_proposed = _compute_swap_notmatrix(
                # Broadcast both so they have shape [num_replica] + batch_shape.
                # This (i) makes them have same shape as is_swap_accepted, and
                # (ii) keeps shape consistent if someday swaps has a batch shape.
                expanded_null_swaps,
                mcmc_util.left_justified_broadcast_to(swaps,
                                                      replica_and_batch_shape))

            # To get is_swap_accepted in ordered position, we use
            # _compute_swap_notmatrix on current and next replica positions.
            post_swap_replica_position = _swap_tensor(expanded_null_swaps)

            is_swap_accepted = _compute_swap_notmatrix(
                post_swap_replica_position, expanded_null_swaps)

            post_swap_states = [s[0] for s in post_swap_replica_states]

            post_swap_replica_results = _make_post_swap_replica_results(
                pre_swap_replica_results, inverse_temperatures,
                swapped_inverse_temperatures, is_swap_accepted_mask,
                _swap_tensor)

            if mcmc_util.is_list_like(current_state):
                # We *always* canonicalize the states in the kernel results.
                states = post_swap_states
            else:
                states = post_swap_states[0]

            post_swap_kernel_results = ReplicaExchangeMCKernelResults(
                post_swap_replica_states=post_swap_replica_states,
                pre_swap_replica_results=pre_swap_replica_results,
                post_swap_replica_results=post_swap_replica_results,
                is_swap_proposed=is_swap_proposed,
                is_swap_accepted=is_swap_accepted,
                is_swap_proposed_adjacent=_sub_diag(is_swap_proposed),
                is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
                # Store the original pkr.inverse_temperatures in case it's a
                # `tf.Variable`.
                inverse_temperatures=previous_kernel_results.
                inverse_temperatures,
                swaps=swaps,
            )

            return states, post_swap_kernel_results
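The untempering and energy-difference arithmetic above reduces to the standard replica-exchange acceptance rule: for a proposed swap of replicas i and j, `log_accept = (beta_j - beta_i) * (log p(x_i) - log p(x_j))`, which is symmetric in i and j. A tiny numeric sketch with made-up values:

```python
import numpy as np

beta = np.array([1.0, 0.5])          # inverse temperatures
log_p = np.array([-3.0, -1.0])       # untempered log-probs at the current states
swaps = np.array([1, 0])             # propose swapping replicas 0 and 1

energy_diff = log_p - log_p[swaps]          # [-2.,  2.]
inverse_temp_diff = beta[swaps] - beta      # [-0.5, 0.5]
log_accept_ratio = energy_diff * inverse_temp_diff
# Both entries equal 1.0: as the comment notes, the two participants of a swap
# always see the same acceptance ratio.
assert np.allclose(log_accept_ratio, [1.0, 1.0])
```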
Example #16
  def _one_step_part(
      self,
      step_size,
      state,
      error_sum,
      log_averaging_step,
      log_shrinkage_target,
      log_accept_prob_rank=None,
      log_accept_prob=None,
      target_accept_prob=None,
      previous_kernel_results=None):
    """Compute new step sizes for each step size part.

    If step size part has smaller rank than the corresponding state part, then
    the difference is averaged away in the log accept prob.

    Example:

      state_part has shape      [2, 3, 4, 5]
      step_size_part has shape     [1, 4, 1]
      log_accept_prob has shape [2, 3, 4]

    Since step size has 1 rank fewer than the state, we reduce away the leading
    dimension of `log_accept_prob` to get a Tensor with shape [3, 4]. Next,
    since `log_accept_prob` must broadcast into step_size_part on the left, we
    reduce the dimensions where their shapes differ, to get a Tensor with shape
    [1, 4], which now is compatible with the leading dimensions of
    step_size_part.

    There is a subtlety here in that `step_size_parts` might be a length-1 list,
    which means that we'll be "structure-broadcasting" it for all the state
    parts (see logic in, e.g., hmc.py). In this case we must assume that
    the lone step size provided broadcasts with the event dims of each state
    part. This means that either step size has no dimensions corresponding to
    chain dimensions, or all states are of the same shape. For the former, we
    want to reduce over all chain dimensions. For the latter, we want to use
    the same logic as in the non-structure-broadcasted case.

    It turns out we can compute the reduction dimensions for both cases
    uniformly by taking the rank of any state part. This obviously works in
    the second case (where all state ranks are the same). In the first case,
    all state parts have the rank L + D_i + B, where L is the rank of
    log_accept_prob, D_i is the non-shared dimensions amongst all states, and
    B are the shared dimensions of all the states, which are equal to the step
    size. When we subtract B, we will always get a number >= L, which means
    we'll get the full reduction we want.

    Args:
      step_size: Previous step's step_size.
      state: Previous step's state value.
      error_sum: Previous step's error accumulator.
      log_averaging_step: Previous step's log_averaging_step.
      log_shrinkage_target: Floating point scalar `Tensor`. Logarithm of value
        the exploration step size is biased towards.
      log_accept_prob_rank: Rank of log_accept_prob.
      log_accept_prob: Floating point `Tensor`. Log acceptance probability of
        the previous step, as reported by the inner kernel.
      target_accept_prob: A floating point `Tensor` representing desired
        acceptance probability. Must be a positive number less than 1.
      previous_kernel_results: Results struct from previous step.

    Returns:
      new_step_size: Updated `step_size`.
      new_log_averaging_step: Updated `log_averaging_step`.
      new_error_sum: Updated `error_sum`.
    """
    num_reduce_dims = prefer_static.minimum(
        log_accept_prob_rank,
        (prefer_static.rank(state) - prefer_static.rank(step_size)))
    reduced_log_accept_prob = reduce_logmeanexp(
        log_accept_prob,
        axis=prefer_static.range(num_reduce_dims))

    # reduced_log_accept_prob must broadcast into step_size on the
    # left, so we do an additional reduction over dimensions where their
    # shapes differ.
    reduce_indices = get_differing_dims(
        reduced_log_accept_prob, step_size)
    reduced_log_accept_prob = reduce_logmeanexp(
        reduced_log_accept_prob, axis=reduce_indices, keepdims=True)
    new_error_sum = (error_sum +
                     target_accept_prob -
                     tf.math.exp(reduced_log_accept_prob))
    num_ones_to_pad = prefer_static.maximum(
        prefer_static.rank(log_shrinkage_target) -
        prefer_static.rank(new_error_sum), 0)
    new_error_sum_extend = tf.reshape(
        new_error_sum,
        shape=prefer_static.pad(
            prefer_static.shape(new_error_sum),
            paddings=[[0, num_ones_to_pad]],
            constant_values=1))

    step_count_smoothing = previous_kernel_results.step_count_smoothing
    step = tf.cast(
        previous_kernel_results.step, step_count_smoothing.dtype) + 1.
    soft_t = step_count_smoothing + step

    new_log_step = (
        log_shrinkage_target -
        ((tf.cast(new_error_sum_extend, step.dtype) * tf.math.sqrt(step)) /
         (soft_t * previous_kernel_results.exploration_shrinkage)))

    eta = step**(-previous_kernel_results.decay_rate)
    new_log_averaging_step = (eta * new_log_step +
                              (1. - eta) * log_averaging_step)

    # - If still adapting, return an exploring step size,
    # - If just finished, return the averaging step size
    # - Otherwise, do not update
    new_step_size = tf.where(
        previous_kernel_results.step < self.num_adaptation_steps,
        tf.math.exp(new_log_step),
        tf.where(previous_kernel_results.step > self.num_adaptation_steps,
                 step_size,
                 tf.math.exp(new_log_averaging_step)))
    new_log_averaging_step = tf.where(
        previous_kernel_results.step > self.num_adaptation_steps,
        log_averaging_step,
        new_log_averaging_step)
    return new_step_size, new_log_averaging_step, new_error_sum
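The shape bookkeeping in the docstring example above can be checked directly. The sketch below uses plain `reduce_logsumexp` minus `log(n)` in place of the `reduce_logmeanexp`/`get_differing_dims` helpers used above, just to show which axes get reduced:

```python
import tensorflow as tf

state = tf.zeros([2, 3, 4, 5])          # state_part
step_size = tf.zeros([1, 4, 1])         # step_size_part
log_accept_prob = tf.zeros([2, 3, 4])

# Rank difference decides how many leading dims of log_accept_prob to average:
# min(rank(log_accept_prob), 4 - 3) == 1.
num_reduce_dims = min(log_accept_prob.shape.rank,
                      state.shape.rank - step_size.shape.rank)
reduced = (tf.reduce_logsumexp(log_accept_prob,
                               axis=list(range(num_reduce_dims)))
           - tf.math.log(2.))           # logmeanexp over the 2 leading entries
# reduced.shape == (3, 4)

# reduced still differs from step_size's leading dims (3 vs 1), so reduce axis 0.
reduced = tf.reduce_logsumexp(reduced, axis=[0], keepdims=True) - tf.math.log(3.)
# reduced.shape == (1, 4): broadcasts against the leading dims of step_size
```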
Example #17
  def _log_prob(self, value, conditional_input=None, training=None):
    """Log probability function with optional conditional input.

    Calculates the log probability of a batch of data under the modeled
    distribution (or conditional distribution, if conditional input is
    provided).

    Args:
      value: `Tensor` or Numpy array of image data. May have leading batch
        dimension(s), which must broadcast to the leading batch dimensions of
        `conditional_input`.
      conditional_input: `Tensor` on which to condition the distribution (e.g.
        class labels), or `None`. May have leading batch dimension(s), which
        must broadcast to the leading batch dimensions of `value`.
      training: `bool` or `None`. If `bool`, it controls the dropout layer,
        where `True` implies dropout is active. If `None`, it defaults to
        `tf.keras.backend.learning_phase()`.
    Returns:
      log_prob_values: `Tensor`.
    """
    # Determine the batch shape of the input images
    image_batch_shape = prefer_static.shape(value)[:-3]

    # Broadcast `value` and `conditional_input` to the same batch_shape
    if conditional_input is None:
      image_batch_and_conditional_shape = image_batch_shape
    else:
      conditional_input = tf.convert_to_tensor(conditional_input)
      conditional_input_shape = prefer_static.shape(conditional_input)
      conditional_batch_rank = (prefer_static.rank(conditional_input) -
                                tensorshape_util.rank(self.conditional_shape))
      conditional_batch_shape = conditional_input_shape[:conditional_batch_rank]

      image_batch_and_conditional_shape = prefer_static.broadcast_shape(
          image_batch_shape, conditional_batch_shape)
      conditional_input = tf.broadcast_to(
          conditional_input,
          prefer_static.concat(
              [image_batch_and_conditional_shape, self.conditional_shape],
              axis=0))
      value = tf.broadcast_to(
          value,
          prefer_static.concat(
              [image_batch_and_conditional_shape, self.event_shape],
              axis=0))

      # Flatten batch dimension for input to Keras model
      conditional_input = tf.reshape(
          conditional_input,
          prefer_static.concat([(-1,), self.conditional_shape], axis=0))

    value = tf.reshape(
        value, prefer_static.concat([(-1,), self.event_shape], axis=0))

    transformed_value = (2. * (value - self._low) /
                         (self._high - self._low)) - 1.
    inputs = (transformed_value if conditional_input is None
              else [transformed_value, conditional_input])

    params = self.network(inputs, training=training)

    num_channels = self.event_shape[-1]
    if num_channels == 1:
      component_logits, locs, scales = params
    else:
      # If there is more than one channel, we create a linear autoregressive
      # dependency among the location parameters of the channels of a single
      # pixel (the scale parameters within a pixel are independent). For a pixel
      # with R/G/B channels, the `r`, `g`, and `b` saturation values are
      # distributed as:
      #
      # r ~ Logistic(loc_r, scale_r)
      # g ~ Logistic(coef_rg * r + loc_g, scale_g)
      # b ~ Logistic(coef_rb * r + coef_gb * g + loc_b, scale_b)
      # TODO(emilyaf) Investigate using fill_triangular/matrix multiplication
      # on the coefficients instead of split/multiply/concat
      component_logits, locs, scales, coeffs = params
      num_coeffs = num_channels * (num_channels - 1) // 2
      loc_tensors = tf.split(locs, num_channels, axis=-1)
      coef_tensors = tf.split(coeffs, num_coeffs, axis=-1)
      channel_tensors = tf.split(transformed_value, num_channels, axis=-1)

      coef_count = 0
      for i in range(num_channels):
        channel_tensors[i] = channel_tensors[i][..., tf.newaxis, :]
        for j in range(i):
          loc_tensors[i] += channel_tensors[j] * coef_tensors[coef_count]
          coef_count += 1
      locs = tf.concat(loc_tensors, axis=-1)

    dist = self._make_mixture_dist(component_logits, locs, scales)
    return tf.reshape(dist.log_prob(value), image_batch_and_conditional_shape)
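The channel-coupling loop above builds locations of the form `loc_r`, `loc_g + c_rg * r`, `loc_b + c_rb * r + c_gb * g`, consuming the coefficients in that order. A tiny numeric sketch (all values made up):

```python
import numpy as np

locs = [0.1, 0.2, 0.3]          # raw per-channel location parameters
coeffs = [0.5, -0.25, 2.0]      # consumed as c_rg, c_rb, c_gb
r, g = 0.8, -0.4                # already-observed (transformed) channel values

loc_r = locs[0]
loc_g = locs[1] + coeffs[0] * r                  # 0.2 + 0.5 * 0.8 = 0.6
loc_b = locs[2] + coeffs[1] * r + coeffs[2] * g  # 0.3 - 0.2 - 0.8 = -0.7
assert np.isclose(loc_b, -0.7)
```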
Example #18
  def bootstrap_results(self, init_state):
    with tf.name_scope(
        mcmc_util.make_name(self.name, 'dual_averaging_step_size_adaptation',
                            'bootstrap_results')):
      inner_results = self.inner_kernel.bootstrap_results(init_state)
      step_size = self.step_size_getter_fn(inner_results)

      log_accept_prob = self.log_accept_prob_getter_fn(inner_results)

      state_parts = tf.nest.flatten(init_state)
      step_size_parts = tf.nest.flatten(step_size)

      if self._parameters['shrinkage_target'] is None:
        shrinkage_target_parts = [None] * len(step_size_parts)
      else:
        shrinkage_target_parts = tf.nest.flatten(
            self._parameters['shrinkage_target'])
        if len(shrinkage_target_parts) not in [1, len(step_size_parts)]:
          raise ValueError(
              '`shrinkage_target` should be a Tensor or list of tensors of '
              'same length as `step_size`. Found len(`step_size`) = {} and '
              'len(shrinkage_target) = {}'.format(
                  len(step_size_parts), len(shrinkage_target_parts)))
        if len(shrinkage_target_parts) < len(step_size_parts):
          shrinkage_target_parts *= len(step_size_parts)

      dtype = dtype_util.common_dtype(step_size_parts, tf.float32)
      error_sum, log_averaging_step, log_shrinkage_target = [], [], []
      for state_part, step_size_part, shrinkage_target_part in zip(
          state_parts, step_size_parts, shrinkage_target_parts):
        num_reduce_dims = prefer_static.minimum(
            prefer_static.rank(log_accept_prob),
            prefer_static.rank(state_part) - prefer_static.rank(step_size_part))
        reduced_log_accept_prob = reduce_logmeanexp(
            log_accept_prob,
            axis=prefer_static.range(num_reduce_dims))
        reduce_indices = get_differing_dims(
            reduced_log_accept_prob, step_size_part)
        reduced_log_accept_prob = reduce_logmeanexp(
            reduced_log_accept_prob,
            axis=reduce_indices,
            keepdims=True)
        error_sum.append(tf.zeros_like(reduced_log_accept_prob, dtype=dtype))
        log_averaging_step.append(tf.zeros_like(step_size_part, dtype=dtype))

        if shrinkage_target_part is None:
          log_shrinkage_target.append(
              float(np.log(10.)) + tf.math.log(step_size_part))
        else:
          log_shrinkage_target.append(
              tf.math.log(tf.cast(shrinkage_target_part, dtype)))

      return DualAveragingStepSizeAdaptationResults(
          inner_results=inner_results,
          step=tf.constant(0, dtype=tf.int32),
          target_accept_prob=tf.cast(self.parameters['target_accept_prob'],
                                     log_accept_prob.dtype),
          log_shrinkage_target=log_shrinkage_target,
          exploration_shrinkage=tf.cast(
              self.parameters['exploration_shrinkage'], dtype),
          step_count_smoothing=tf.cast(
              self.parameters['step_count_smoothing'], dtype),
          decay_rate=tf.cast(self.parameters['decay_rate'], dtype),
          error_sum=error_sum,
          log_averaging_step=log_averaging_step,
          new_step_size=step_size)
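For context, a minimal composition of this adaptation with an inner kernel might look like the following sketch, assuming the class is exposed as `tfp.mcmc.DualAveragingStepSizeAdaptation`; all numbers are illustrative.

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

target = tfd.Normal(0., 1.)
hmc = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=target.log_prob,
    step_size=0.1,
    num_leapfrog_steps=3)
adaptive = tfp.mcmc.DualAveragingStepSizeAdaptation(
    inner_kernel=hmc,
    num_adaptation_steps=400,
    target_accept_prob=0.75)

samples = tfp.mcmc.sample_chain(
    num_results=500,
    num_burnin_steps=500,
    current_state=tf.zeros([8]),   # 8 chains
    kernel=adaptive,
    trace_fn=None)
```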
Example #19
    def one_step(self, current_state, previous_kernel_results):
        with tf.name_scope(mcmc_util.make_name(self.name, 'hmc', 'one_step')):
            if self._store_parameters_in_results:
                step_size = previous_kernel_results.step_size
                num_leapfrog_steps = previous_kernel_results.num_leapfrog_steps
            else:
                step_size = self.step_size
                num_leapfrog_steps = self.num_leapfrog_steps

            [
                current_state_parts,
                step_sizes,
                current_target_log_prob,
                current_target_log_prob_grad_parts,
            ] = _prepare_args(
                self.target_log_prob_fn,
                current_state,
                step_size,
                previous_kernel_results.target_log_prob,
                previous_kernel_results.grads_target_log_prob,
                maybe_expand=True,
                state_gradients_are_stopped=self.state_gradients_are_stopped)

            current_momentum_parts = []
            for x in current_state_parts:
                current_momentum_parts.append(
                    tf.random.normal(shape=tf.shape(x),
                                     dtype=self._momentum_dtype
                                     or dtype_util.base_dtype(x.dtype),
                                     seed=self._seed_stream()))

            integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
                self.target_log_prob_fn, step_sizes, num_leapfrog_steps)

            [
                next_momentum_parts,
                next_state_parts,
                next_target_log_prob,
                next_target_log_prob_grad_parts,
            ] = integrator(current_momentum_parts, current_state_parts,
                           current_target_log_prob,
                           current_target_log_prob_grad_parts)
            if self.state_gradients_are_stopped:
                next_state_parts = [
                    tf.stop_gradient(x) for x in next_state_parts
                ]

            def maybe_flatten(x):
                return x if mcmc_util.is_list_like(current_state) else x[0]

            independent_chain_ndims = prefer_static.rank(
                current_target_log_prob)

            new_kernel_results = previous_kernel_results._replace(
                log_acceptance_correction=_compute_log_acceptance_correction(
                    current_momentum_parts, next_momentum_parts,
                    independent_chain_ndims),
                target_log_prob=next_target_log_prob,
                grads_target_log_prob=next_target_log_prob_grad_parts,
            )

            return maybe_flatten(next_state_parts), new_kernel_results
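The `SimpleLeapfrogIntegrator` used above runs the standard leapfrog scheme: a half step on the momentum, alternating full steps on position and momentum, and a final half step on the momentum. A bare numpy sketch for a standard-normal target (so `grad log p(x) = -x`):

```python
import numpy as np

def leapfrog(x, p, step_size, num_steps, grad_log_prob):
    p = p + 0.5 * step_size * grad_log_prob(x)      # initial half step (momentum)
    for _ in range(num_steps - 1):
        x = x + step_size * p                        # full step (position)
        p = p + step_size * grad_log_prob(x)         # full step (momentum)
    x = x + step_size * p
    p = p + 0.5 * step_size * grad_log_prob(x)       # final half step (momentum)
    return x, p

x, p = leapfrog(x=1.0, p=0.5, step_size=0.1, num_steps=10,
                grad_log_prob=lambda x: -x)
# The Hamiltonian 0.5 * (x**2 + p**2) is approximately conserved along the path.
```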
Example #20
def batch_interp_regular_nd_grid(x,
                                 x_ref_min,
                                 x_ref_max,
                                 y_ref,
                                 axis,
                                 fill_value='constant_extension',
                                 name=None):
  """Multi-linear interpolation on a regular (constant spacing) grid.

  Given [a batch of] reference values, this function computes a multi-linear
  interpolant and evaluates it on [a batch of] of new `x` values.

  The interpolant is built from reference values indexed by `nd` dimensions
  of `y_ref`, starting at `axis`.

  For example, take the case of a `2-D` scalar valued function and no leading
  batch dimensions.  In this case, `y_ref.shape = [C1, C2]` and `y_ref[i, j]`
  is the reference value corresponding to grid point

  ```
  [x_ref_min[0] + i * (x_ref_max[0] - x_ref_min[0]) / (C1 - 1),
   x_ref_min[1] + j * (x_ref_max[1] - x_ref_min[1]) / (C2 - 1)]
  ```

  In the general case, dimensions to the left of `axis` in `y_ref` are broadcast
  with leading dimensions in `x`, `x_ref_min`, `x_ref_max`.

  Args:
    x: Numeric `Tensor` The x-coordinates of the interpolated output values for
      each batch.  Shape `[..., D, nd]`, designating [a batch of] `D`
      coordinates in `nd` space.  `D` must be `>= 1` and is not a batch dim.
    x_ref_min:  `Tensor` of same `dtype` as `x`.  The minimum values of the
      (implicitly defined) reference `x_ref`.  Shape `[..., nd]`.
    x_ref_max:  `Tensor` of same `dtype` as `x`.  The maximum values of the
      (implicitly defined) reference `x_ref`.  Shape `[..., nd]`.
    y_ref:  `Tensor` of same `dtype` as `x`.  The reference output values. Shape
      `[..., C1, ..., Cnd, B1,...,BM]`, designating [a batch of] reference
      values indexed by `nd` dimensions, of a shape `[B1,...,BM]` valued
      function (for `M >= 0`).
    axis:  Scalar integer `Tensor`.  Dimensions `[axis, axis + nd)` of `y_ref`
      index the interpolation table.  E.g. `3-D` interpolation of a scalar
      valued function requires `axis=-3` and a `3-D` matrix valued function
      requires `axis=-5`.
    fill_value:  Determines what values output should take for `x` values that
      are below `x_ref_min` or above `x_ref_max`. Scalar `Tensor` or
      'constant_extension' ==> Extend as constant function.
      Default value: `'constant_extension'`
    name:  A name to prepend to created ops.
      Default value: `'batch_interp_regular_nd_grid'`.

  Returns:
    y_interp:  Interpolation between members of `y_ref`, at points `x`.
      `Tensor` of same `dtype` as `x`, and shape `[..., D, B1, ..., BM].`

  Raises:
    ValueError:  If `rank(x) < 2` is determined statically.
    ValueError:  If `axis` is not a scalar is determined statically.
    ValueError:  If `axis + nd > rank(y_ref)` is determined statically.

  #### Examples

  Interpolate a function of one variable.

  ```python
  y_ref = tf.exp(tf.linspace(start=0., stop=10., num=20))

  tfp.math.batch_interp_regular_nd_grid(
      # x.shape = [3, 1], x_ref_min/max.shape = [1].  Trailing `1` for `1-D`.
      x=[[6.0], [0.5], [3.3]], x_ref_min=[0.], x_ref_max=[10.], y_ref=y_ref,
      axis=0)
  ==> approx [exp(6.0), exp(0.5), exp(3.3)]
  ```

  Interpolate a scalar function of two variables.

  ```python
  x_ref_min = [0., 0.]
  x_ref_max = [2 * np.pi, 2 * np.pi]

  # Build y_ref.
  x0s, x1s = tf.meshgrid(
      tf.linspace(x_ref_min[0], x_ref_max[0], num=100),
      tf.linspace(x_ref_min[1], x_ref_max[1], num=100),
      indexing='ij')

  def func(x0, x1):
    return tf.sin(x0) * tf.cos(x1)

  y_ref = func(x0s, x1s)

  x = np.pi * tf.random.uniform(shape=(10, 2))

  tfp.math.batch_interp_regular_nd_grid(x, x_ref_min, x_ref_max, y_ref, axis=-2)
  ==> tf.sin(x[:, 0]) * tf.cos(x[:, 1])
  ```

  """
  with tf.name_scope(name or 'interp_regular_nd_grid'):
    dtype = dtype_util.common_dtype([x, x_ref_min, x_ref_max, y_ref],
                                    dtype_hint=tf.float32)

    # Arg checking.
    if isinstance(fill_value, str):
      if fill_value != 'constant_extension':
        raise ValueError(
            'A fill value ({}) was not an allowed string ({})'.format(
                fill_value, 'constant_extension'))
    else:
      fill_value = tf.convert_to_tensor(
          fill_value, name='fill_value', dtype=dtype)
      _assert_ndims_statically(fill_value, expect_ndims=0)

    # x.shape = [..., nd].
    x = tf.convert_to_tensor(x, name='x', dtype=dtype)
    _assert_ndims_statically(x, expect_ndims_at_least=2)

    # y_ref.shape = [..., C1,...,Cnd, B1,...,BM]
    y_ref = tf.convert_to_tensor(y_ref, name='y_ref', dtype=dtype)

    # x_ref_min.shape = [nd]
    x_ref_min = tf.convert_to_tensor(
        x_ref_min, name='x_ref_min', dtype=dtype)
    x_ref_max = tf.convert_to_tensor(
        x_ref_max, name='x_ref_max', dtype=dtype)
    _assert_ndims_statically(
        x_ref_min, expect_ndims_at_least=1, expect_static=True)
    _assert_ndims_statically(
        x_ref_max, expect_ndims_at_least=1, expect_static=True)

    # nd is the number of dimensions indexing the interpolation table; it's the
    # 'nd' in the function name.
    nd = tf.compat.dimension_value(x_ref_min.shape[-1])
    if nd is None:
      raise ValueError('`x_ref_min.shape[-1]` must be known statically.')
    tensorshape_util.assert_is_compatible_with(
        x_ref_max.shape[-1:], x_ref_min.shape[-1:])

    # Convert axis and check it statically.
    axis = tf.convert_to_tensor(axis, dtype=tf.int32, name='axis')
    axis = ps.non_negative_axis(axis, tf.rank(y_ref))
    tensorshape_util.assert_has_rank(axis.shape, 0)
    axis_ = tf.get_static_value(axis)
    y_ref_rank_ = tf.get_static_value(tf.rank(y_ref))
    if axis_ is not None and y_ref_rank_ is not None:
      if axis_ + nd > y_ref_rank_:
        raise ValueError(
            'Since dims `[axis, axis + nd)` index the interpolation table, we '
            'must have `axis + nd <= rank(y_ref)`.  Found: '
            '`axis`: {},  rank(y_ref): {}, and inferred `nd` from trailing '
            'dimensions of `x_ref_min` to be {}.'.format(
                axis_, y_ref_rank_, nd))

    x_batch_shape = ps.shape(x)[:-2]
    x_ref_min_batch_shape = ps.shape(x_ref_min)[:-1]
    x_ref_max_batch_shape = ps.shape(x_ref_max)[:-1]
    y_ref_batch_shape = ps.shape(y_ref)[:axis]

    # Do a brute-force broadcast of batch dims (add zeros).
    batch_shape = y_ref_batch_shape
    for tensor in [x_batch_shape, x_ref_min_batch_shape, x_ref_max_batch_shape]:
      batch_shape = ps.broadcast_shape(batch_shape, tensor)

    def _batch_shape_of_zeros_with_rightmost_singletons(n_singletons):
      """Return Tensor of zeros with some singletons on the rightmost dims."""
      ones = ps.ones(shape=[n_singletons], dtype=tf.int32)
      return ps.concat([batch_shape, ones], axis=0)

    x = _broadcast_with(
        x, _batch_shape_of_zeros_with_rightmost_singletons(n_singletons=2))
    x_ref_min = _broadcast_with(
        x_ref_min,
        _batch_shape_of_zeros_with_rightmost_singletons(n_singletons=1))
    x_ref_max = _broadcast_with(
        x_ref_max,
        _batch_shape_of_zeros_with_rightmost_singletons(n_singletons=1))
    y_ref = _broadcast_with(
        y_ref,
        _batch_shape_of_zeros_with_rightmost_singletons(
            n_singletons=tf.rank(y_ref) - axis))

    return _batch_interp_with_gather_nd(
        x=x,
        x_ref_min=x_ref_min,
        x_ref_max=x_ref_max,
        y_ref=y_ref,
        nd=nd,
        fill_value=fill_value,
        batch_dims=ps.rank(x) - 2)
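At its core, regular-grid interpolation maps each query point to a fractional grid index and mixes the neighboring reference values linearly (in `nd` dimensions this means combining the 2**nd surrounding grid values). A 1-D numpy sketch of the weighting:

```python
import numpy as np

x_ref_min, x_ref_max, C = 0.0, 10.0, 21
y_ref = np.exp(np.linspace(x_ref_min, x_ref_max, C))      # reference values

x = np.array([0.25, 3.3, 9.9])                             # query points
t = (x - x_ref_min) / (x_ref_max - x_ref_min) * (C - 1)    # fractional grid index
lo = np.clip(np.floor(t).astype(int), 0, C - 2)
frac = t - lo
y = (1.0 - frac) * y_ref[lo] + frac * y_ref[lo + 1]        # ~np.exp(x)
```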
Example #21
 def _copy(v):
     return v * ps.ones(ps.pad(
         [2], paddings=[[0, ps.rank(v)]], constant_values=1),
                        dtype=v.dtype)
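The `_copy` helper above builds a shape like `[2, 1, ..., 1]` and multiplies, which simply duplicates `v` along a new leading axis of size 2. A quick equivalence check (sketch):

```python
import tensorflow as tf

v = tf.constant([[1., 2.], [3., 4.]])
duplicated = v * tf.ones([2] + [1] * v.shape.rank, dtype=v.dtype)
# Same as stacking two copies along a new leading axis.
tf.debugging.assert_equal(duplicated, tf.stack([v, v], axis=0))
```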
Example #22
def window_tune_nuts_sampling(target_log_prob,
                              prior_samples,
                              constraining_bijectors=None,
                              init_state=None,
                              num_samples=500,
                              nchains=4,
                              init_nchains=1,
                              target_accept_prob=.8,
                              max_tree_depth=9,
                              use_scaled_init=True,
                              tuning_window_schedule=(75, 25, 25, 25, 25, 25,
                                                      50),
                              use_wide_window_expanding_mode=True,
                              seed=None,
                              parallel_iterations=10,
                              jit_compile=True,
                              use_input_signature=True,
                              reduce_retracing=False):
    """Sample from a density with NUTS and an expanding window tuning scheme.

  This function implements a turnkey MCMC sampling routine using NUTS and an
  expanding window tuning strategy similar to Stan [1]. It learns a pre-
  conditioner that scales and rotates the target distribution using a series of
  expanding windows - either in number of samples (same as in Stan,
  use_wide_window_expanding_mode=False) or in number of batches/chains
  (use_wide_window_expanding_mode=True).

  Currently, the function uses `prior_samples` to initialize MCMC chains
  uniformly at random between -1 and 1 scaled by the prior standard deviation
  (i.e., [-prior_std, prior_std]). The scaling is ignored if `use_scaled_init`
  is set to False. Alternatively, user can input the `init_state` directly.

  Currently, the tuning and sampling routine is run in Python, with each block
  of the tuning epoch (window 1, 2, and 3 in Stan [1]) run with two @tf.function
  compiled functions. The user can control the compilation options using the
  kwargs `jit_compile`, `use_input_signature`, and
  `reduce_retracing`.  Setting all to True would compile to XLA and
  potentially avoid the small overhead of function recompilation (though this
  is not yet the case with XLA). It is not yet clear whether doing it
  this way is better than just wrapping the full inference routine in
  tf.function with XLA.

  Internally, this calls `_sample_posterior`, which assumes a real-valued target
  density function and takes a Tensor with shape=(batch * dimension) as input.
  The tuning routine is a memory-less (i.e., no warm-up samples are saved) MCMC
  sampling with number of samples specified by a list-like
  `tuning_window_schedule`.

  Args:
    target_log_prob: Python callable which takes an argument like
      `current_state` (or `*current_state` if it's a list) and returns its
      (possibly unnormalized) log-density under the target distribution.
    prior_samples: Nested structure of `Tensor`s, each of shape `[batches,
      latent_part_event_shape]` and should be sample from the prior. They are
      used to generate an initial chain position if `init_state` is not
      supplied.
    constraining_bijectors: `tfp.distributions.Bijector` or list of
      `tfp.distributions.Bijector`s. These bijectors use `forward` to map the
      state on the real space to the constrained state expected by
      `target_log_prob`.
    init_state: (Optional) `Tensor` or Python `list` of `Tensor`s representing
      the initial state(s) of the Markov chain(s).
    num_samples: Integer number of the Markov chain draws after tuning.
    nchains: Integer number of the Markov chains after tuning.
    init_nchains: Integer number of the Markov chains in the first phase of
      tuning.
    target_accept_prob: Floating point scalar `Tensor`. Target acceptance
      probability for step size adaptation.
    max_tree_depth: Maximum depth of the tree implicitly built by NUTS. See
      `tfp.mcmc.NoUTurnSampler` for more details.
    use_scaled_init: Boolean. If `True`, generate initial state within [-1, 1]
      scaled by prior sample standard deviation in the unconstrained real space.
      This kwarg is ignored if `init_state` is not None.
    tuning_window_schedule: List-like sequence of integers that specify the
      tuning schedule. Each integer number specifies the number of MCMC samples
      within a single warm-up window. The first and the last window tunes the
      step size (a scalar) only, while the intermediate windows tune both step
      size and the pre-conditioner. Moreover, the intermediate windows double
      the number of samples taken: for example, the default schedule (75,
        25, 25, 25, 25, 25, 50) actually means it will take (75, 25 * 2**0, 25 *
        2**1, 25 * 2**2, 25 * 2**3, 25 * 2**4, 50) samples.
    use_wide_window_expanding_mode: Boolean. Default to `True` that we double
      the number of chains from the previous stage for the intermediate windows.
      See `tuning_window_schedule` kwarg for more details.
    seed: Python integer to seed the random number generator.
    parallel_iterations: The number of iterations allowed to run in parallel.
      It must be a positive integer. See `tf.while_loop` for more details.
      Note that if you set the seed to have deterministic output you should
      also set `parallel_iterations` to 1.
    jit_compile: kwarg passed to the tf.function decorator. If True, the
      function is always compiled by XLA.
    use_input_signature: If True, generate an input_signature kwarg to pass to
      tf.function decorator.
    reduce_retracing: kwarg passed to the tf.function decorator. When True,
      tf.function may generate fewer graphs that are less specialized on input
      shapes.

  Returns:
    posterior_samples: A `Tensor` or Python list of `Tensor`s representing the
      posterior MCMC samples after tuning. It has the same structure as
      `prior_samples` but with the leading shape being (num_samples * nchains).
    diagnostic: A list of `Tensor` representing the diagnostics from NUTS:
      `target_log_prob`, `leapfrogs_taken`, `has_divergence`, `energy`,
      `log_accept_ratio`, `reach_max_depth`, `step_size`.
    conditioning_bijector: A tfp bijector that scales and rotates the target
      density function in latent unconstrained space as determined by
      adaptation.

  ### Examples

  Sampling from a multivariate Student-T distribution.

  ```python
  DTYPE = np.float32

  nd = 50
  concentration = 1.

  prior_dist = tfd.Sample(tfd.Normal(tf.constant(0., DTYPE), 100.), nd)

  mu = tf.cast(np.linspace(-100, 100, nd), dtype=DTYPE)
  sigma = tf.cast(np.exp(np.linspace(-1, 1.5, nd)), dtype=DTYPE)
  corr_tril = tfd.CholeskyLKJ(
      dimension=nd, concentration=concentration).sample()
  scale_tril = tf.linalg.matmul(tf.linalg.diag(sigma), corr_tril)
  target_dist = tfd.MultivariateStudentTLinearOperator(
      df=5., loc=mu, scale=tf.linalg.LinearOperatorLowerTriangular(scale_tril))

  target_log_prob = lambda *x: (
      prior_dist.log_prob(*x) + target_dist.log_prob(*x))

  (
      [mcmc_samples], diagnostic, conditioning_bijector
  ) = window_tune_nuts_sampling(target_log_prob, [prior_dist.sample(2000)])

  loc_conditioner, scale_conditioner = conditioning_bijector.trainable_variables

  _, ax = plt.subplots(1, 2, figsize=(10, 5))
  ax[0].plot(mu, loc_conditioner.numpy(), 'o', label='conditioner mean')
  ax[0].plot(mu, tf.reduce_mean(
      mcmc_samples, axis=[0, 1]), 'o', label='estimated mean')
  ax[0].legend()

  sigma_sim = target_dist._stddev()
  ax[1].plot(sigma_sim, scale_conditioner.numpy(), 'o', label='conditioner std')
  ax[1].plot(sigma_sim, tf.math.reduce_std(
      mcmc_samples, axis=[0, 1]), 'o', label='estimated std');
  ax[1].legend()

  ax[0].plot([min(mu), max(mu)], [min(mu), max(mu)])
  ax[1].plot([min(sigma_sim), max(sigma_sim)], [min(sigma_sim), max(sigma_sim)])
  ```

  #### References

  [1]: Stan Reference Manual.
  https://mc-stan.org/docs/2_23/reference-manual/hmc-algorithm-parameters.html#automatic-parameter-tuning
  """

    log_prob_val = target_log_prob(*prior_samples)
    log_prob_rank = ps.rank(log_prob_val)
    assert log_prob_rank == 1

    if constraining_bijectors is not None:
        target_log_prob_unconstrained = make_transformed_log_prob(
            target_log_prob,
            constraining_bijectors,
            direction='forward',
            enable_bijector_caching=False)
        # constrain to unconstrain
        inverse_transform = make_transform_fn(constraining_bijectors,
                                              'inverse')
        # unconstrain to constrain
        forward_transform = make_transform_fn(constraining_bijectors,
                                              'forward')
    else:
        target_log_prob_unconstrained = target_log_prob
        inverse_transform = lambda x: x
        forward_transform = lambda y: y

    prior_samples_unconstrained = inverse_transform(prior_samples)
    init_state_unconstrained = None

    # If the input to target_log_prob_fn is a nested structure of Tensors, we
    # flatten and concatenate them into a 1D vector so that it is easier to work
    # with in mass matrix adaptation.
    if tf.nest.is_nested(prior_samples_unconstrained):
        free_rv_event_shape = [x.shape[log_prob_rank:] for x in prior_samples]
        flat_event_splits = [s.num_elements() for s in free_rv_event_shape]

        # TODO(b/158878248): replace the two function below with `tfb.Split`.
        def split_and_reshape(x):
            assertions = []
            message = 'Input must have at least one dimension.'
            if tensorshape_util.rank(x.shape) is not None:
                if tensorshape_util.rank(x.shape) == 0:
                    raise ValueError(message)
            else:
                assertions.append(
                    assert_util.assert_rank_at_least(x, 1, message=message))
            with tf.control_dependencies(assertions):
                x = tf.nest.pack_sequence_as(
                    free_rv_event_shape, tf.split(x,
                                                  flat_event_splits,
                                                  axis=-1))

                def _reshape_map_part(part, event_shape):
                    static_rank = tf.get_static_value(
                        ps.rank_from_shape(event_shape))
                    if static_rank == 1:
                        return part
                    new_shape = ps.concat([ps.shape(part)[:-1], event_shape],
                                          axis=-1)
                    return tf.reshape(part, ps.cast(new_shape, tf.int32))

                x = tf.nest.map_structure(_reshape_map_part, x,
                                          free_rv_event_shape)
            return x

        def concat_list_event(x):
            def handle_part(x, shape):
                if len(shape) == 0:  # pylint: disable=g-explicit-length-test
                    return x[..., tf.newaxis]
                return tf.reshape(x, list(x.shape)[:-len(shape)] + [-1])

            flat_parts = [
                handle_part(v, s) for v, s in zip(x, free_rv_event_shape)
            ]
            return tf.concat(flat_parts, axis=-1)

        def target_log_prob_unconstrained_concated(x):
            x = split_and_reshape(x)
            return target_log_prob_unconstrained(*x)

        prior_samples_unconstrained_concated = concat_list_event(
            prior_samples_unconstrained)
        if init_state is not None:
            init_state_unconstrained = concat_list_event(
                inverse_transform(init_state))
    else:
        target_log_prob_unconstrained_concated = target_log_prob_unconstrained
        prior_samples_unconstrained_concated = prior_samples_unconstrained
        split_and_reshape = lambda x: x
        if init_state is not None:
            init_state_unconstrained = inverse_transform(init_state)

    nuts_samples, diagnostic, conditioning_bijector = _sample_posterior(
        target_log_prob_unconstrained_concated,
        prior_samples_unconstrained_concated,
        init_state=init_state_unconstrained,
        num_samples=num_samples,
        nchains=nchains,
        init_nchains=init_nchains,
        target_accept_prob=target_accept_prob,
        max_tree_depth=max_tree_depth,
        use_scaled_init=use_scaled_init,
        tuning_window_schedule=tuning_window_schedule,
        use_wide_window_expanding_mode=use_wide_window_expanding_mode,
        seed=seed,
        parallel_iterations=parallel_iterations,
        jit_compile=jit_compile,
        use_input_signature=use_input_signature,
        reduce_retracing=reduce_retracing)
    return forward_transform(
        split_and_reshape(nuts_samples)), diagnostic, conditioning_bijector
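As a side note (not part of the listing above): the flatten/unflatten pair is just a `tf.reshape`/`tf.concat`/`tf.split` round trip over the event dimensions. A minimal, self-contained sketch of the same idea, assuming a toy nested state with event shapes `[2, 3]` and `[]`:

import tensorflow as tf

# Toy nested state: batch of 5, parts with event shapes [2, 3] and [] (scalar).
parts = [tf.random.normal([5, 2, 3]), tf.random.normal([5])]
event_shapes = [[2, 3], []]
splits = [6, 1]  # number of flattened event elements per part

# Flatten: reshape each part to [batch, -1] and concatenate on the last axis.
flat = tf.concat([tf.reshape(p, [5, -1]) for p in parts], axis=-1)  # shape [5, 7]

# Unflatten: split and reshape back to the original event shapes.
recovered = [
    tf.reshape(piece, [5] + shape)
    for piece, shape in zip(tf.split(flat, splits, axis=-1), event_shapes)
]
assert all(r.shape == p.shape for r, p in zip(recovered, parts))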
Beispiel #23
0
def reduce_sum(x, m, shard_axes):
    # Sum `x` over the event dimensions of `m` (every axis beyond
    # `log_prob_rank`, a free variable from the enclosing scope); if the
    # computation is sharded, also psum the partial sums across `shard_axes`.
    out = tf.reduce_sum(x, axis=ps.range(log_prob_rank, ps.rank(m)))
    if shard_axes is not None:
        out = distribute_lib.psum(out, shard_axes)
    return out
def _sample_posterior(target_log_prob_unconstrained,
                      prior_samples_unconstrained,
                      init_state=None,
                      num_samples=500,
                      nchains=4,
                      init_nchains=1,
                      target_accept_prob=.8,
                      max_tree_depth=9,
                      use_scaled_init=True,
                      tuning_window_schedule=(75, 25, 25, 25, 25, 25, 50),
                      use_wide_window_expanding_mode=True,
                      seed=None,
                      parallel_iterations=10,
                      jit_compile=True,
                      use_input_signature=False,
                      reduce_retracing=False):
    """MCMC sampling with HMC/NUTS using an expanding epoch tuning scheme."""

    seed_stream = tfp.util.SeedStream(seed, 'window_tune_nuts_sampling')
    rv_rank = ps.rank(prior_samples_unconstrained)
    assert rv_rank == 2
    total_ndims = ps.shape(prior_samples_unconstrained)[-1]
    dtype = prior_samples_unconstrained.dtype

    # TODO(b/158878248): explore an option for the user to control the
    # parameterization of conditioning_bijector.
    # TODO(b/158878248): right now, we use 2 `tf.Variable`s to initialize a scaling
    # bijector, and update the underlying value at the end of each warmup window.
    # It might be faster to rewrite it into a functional style (with a small
    # additional compilation cost).
    loc_conditioner = tf.Variable(tf.zeros([total_ndims], dtype=dtype),
                                  name='loc_conditioner')
    scale_conditioner = tf.Variable(tf.ones([total_ndims], dtype=dtype),
                                    name='scale_conditioner')

    # Start with an identity covariance matrix.
    scale = tf.linalg.LinearOperatorDiag(diag=scale_conditioner,
                                         is_non_singular=True,
                                         is_self_adjoint=True,
                                         is_positive_definite=True)
    conditioning_bijector = tfb.Shift(shift=loc_conditioner)(
        tfb.ScaleMatvecLinearOperator(scale))
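    # `conditioning_bijector` is the affine map z -> loc + diag(scale) @ z. Used
    # inside the TransformedTransitionKernel below, it acts as a diagonal
    # mass-matrix preconditioner whose parameters are updated between warmup
    # windows via the two variables above.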

    if init_state is None:
        # Start at uniform random [-1, 1] around the prior mean in latent space
        init_state_uniform = tf.random.uniform([init_nchains, total_ndims],
                                               dtype=dtype,
                                               seed=seed_stream()) * 2. - 1.
        if use_scaled_init:
            prior_z_mean = tf.math.reduce_mean(prior_samples_unconstrained,
                                               axis=0)
            prior_z_std = tf.math.reduce_std(prior_samples_unconstrained,
                                             axis=0)
            init_state = init_state_uniform * prior_z_std + prior_z_mean
        else:
            init_state = init_state_uniform

    # The denominator is the O(N^0.25) scaling from Beskos et al. 2010. The
    # numerator corresponds to the trajectory length. Candidate values include 1
    # and 1.57 (pi / 2). We use a conservatively small value here (0.25).
    init_step_size = tf.constant(0.25 / (total_ndims**0.25), dtype=dtype)
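    # E.g. for total_ndims = 100: init_step_size = 0.25 / 100**0.25 ~= 0.079.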

    hmc_inner = tfp.mcmc.TransformedTransitionKernel(
        tfp.mcmc.NoUTurnSampler(
            target_log_prob_fn=target_log_prob_unconstrained,
            step_size=init_step_size,
            max_tree_depth=max_tree_depth,
            parallel_iterations=parallel_iterations,
        ), conditioning_bijector)

    hmc_step_size_tuning = tfp.mcmc.DualAveragingStepSizeAdaptation(
        inner_kernel=hmc_inner,
        num_adaptation_steps=max(tuning_window_schedule),
        target_accept_prob=target_accept_prob)

    if use_input_signature:
        input_signature = [
            tf.TensorSpec(shape=None, dtype=tf.int32),
            tf.TensorSpec(shape=[None, total_ndims], dtype=dtype),
        ]
    else:
        input_signature = None

    # TODO(b/158878248): move the nested function definitions to module top-level.
    @tf.function(input_signature=input_signature,
                 autograph=False,
                 jit_compile=jit_compile,
                 reduce_retracing=reduce_retracing)
    def fast_adaptation_interval(num_steps, previous_state):
        """Step size only adaptation interval.

    This corresponds to window 1 and window 3 in the Stan HMC parameter
    tuning scheme.

    Args:
      num_steps: Number of tuning steps the interval will run.
      previous_state: Initial state of the tuning interval.

    Returns:
      last_state: Last state of the tuning interval.
      last_pkr: Kernel result from the TransitionKernel at the end of the
        tuning interval.
    """
        def body_fn(i, state, pkr):
            next_state, next_pkr = hmc_step_size_tuning.one_step(state, pkr)
            return i + 1, next_state, next_pkr

        current_pkr = hmc_step_size_tuning.bootstrap_results(previous_state)
        _, last_state, last_pkr = tf.while_loop(
            lambda i, *_: i < num_steps,
            body_fn,
            loop_vars=(0, previous_state, current_pkr),
            maximum_iterations=num_steps,
            parallel_iterations=parallel_iterations)
        return last_state, last_pkr

    def body_fn_window2(i, previous_state, previous_pkr, previous_mean,
                        previous_cov):
        """Take one MCMC step and update the step size and mass matrix."""
        next_state, next_pkr = hmc_step_size_tuning.one_step(
            previous_state, previous_pkr)
        n_next = i + 1
        delta_pre = previous_state - previous_mean
        next_mean = previous_mean + delta_pre / tf.cast(
            n_next, delta_pre.dtype)
        delta_post = previous_state - next_mean
        delta_cov = tf.expand_dims(delta_post, -1) * tf.expand_dims(
            delta_pre, -2)
        next_cov = previous_cov + delta_cov

        next_mean.set_shape(previous_mean.shape)
        next_cov.set_shape(previous_cov.shape)
        return n_next, next_state, next_pkr, next_mean, next_cov

    if use_input_signature:
        input_signature = [
            tf.TensorSpec(shape=None, dtype=tf.int32),
            tf.TensorSpec(shape=None, dtype=tf.int32),
            tf.TensorSpec(shape=[None, total_ndims], dtype=dtype),
            tf.TensorSpec(shape=[None, total_ndims], dtype=dtype),
            tf.TensorSpec(shape=[None, total_ndims, total_ndims], dtype=dtype),
        ]
    else:
        input_signature = None

    # TODO(b/158878248): move the nested function definitions to module top-level.
    @tf.function(input_signature=input_signature,
                 autograph=False,
                 jit_compile=jit_compile,
                 reduce_retracing=reduce_retracing)
    def slow_adaptation_interval(num_steps, previous_n, previous_state,
                                 previous_mean, previous_cov):
        """Interval that tunes the mass matrix and step size simultaneously.

    This corresponds to window 2 in the Stan HMC parameter tuning scheme.

    Args:
      num_steps: Number of tuning steps the interval will run.
      previous_n: Previous number of tuning steps we have run.
      previous_state: Initial state of the tuning interval.
      previous_mean: Current estimated posterior mean.
      previous_cov: Current estimated posterior covariance matrix.

    Returns:
      total_n: Total number of tuning steps we have run.
      next_state: Last state of the tuning interval.
      next_pkr: Kernel result from the TransitionKernel at the end of the
        tuning interval.
      next_mean: Estimated posterior mean after tuning.
      next_cov: Estimated posterior covariance matrix after tuning.
    """
        previous_pkr = hmc_step_size_tuning.bootstrap_results(previous_state)
        total_n, next_state, next_pkr, next_mean, next_cov = tf.while_loop(
            lambda i, *_: i < num_steps + previous_n,
            body_fn_window2,
            loop_vars=(previous_n, previous_state, previous_pkr, previous_mean,
                       previous_cov),
            maximum_iterations=num_steps,
            parallel_iterations=parallel_iterations)
        float_n = tf.cast(total_n, next_cov.dtype)
        cov = next_cov / (float_n - 1.)

        # Regularization
        scaled_cov = (float_n / (float_n + 5.)) * cov
        shrinkage = 1e-3 * (5. / (float_n + 5.))
        next_cov = scaled_cov + shrinkage
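        # This mirrors Stan's regularization of the sample covariance: shrink the
        # estimate toward a small multiple of the identity, with the shrinkage
        # decaying as the number of samples grows. Only the diagonal of `next_cov`
        # is consumed downstream (via `diag_updater`), so broadcasting the scalar
        # `shrinkage` over all entries is equivalent, for that purpose, to adding
        # it to the diagonal only.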

        return total_n, next_state, next_pkr, next_mean, next_cov

    def trace_fn(_, pkr):
        return (
            pkr.inner_results.target_log_prob,
            pkr.inner_results.leapfrogs_taken,
            pkr.inner_results.has_divergence,
            pkr.inner_results.energy,
            pkr.inner_results.log_accept_ratio,
            pkr.inner_results.reach_max_depth,
            pkr.inner_results.step_size,
        )

    @tf.function(autograph=False, jit_compile=jit_compile)
    def run_chain(num_results, current_state, previous_kernel_results):
        return tfp.mcmc.sample_chain(
            num_results=num_results,
            num_burnin_steps=0,
            current_state=current_state,
            previous_kernel_results=previous_kernel_results,
            kernel=hmc_inner,
            trace_fn=trace_fn,
            parallel_iterations=parallel_iterations,
            seed=seed_stream())

    # Main sampling with tuning routine.
    num_steps_tuning_window_schedule0 = tuning_window_schedule[0]

    # Window 1 to tune step size
    logging.info('Tuning Window 1...')
    next_state, _ = fast_adaptation_interval(num_steps_tuning_window_schedule0,
                                             init_state)

    next_mean = tf.zeros_like(init_state)
    next_cov = tf.zeros(ps.concat(
        [ps.shape(init_state), ps.shape(init_state)[-1:]], axis=-1),
                        dtype=dtype)

    mean_updater = tf.zeros([total_ndims], dtype=dtype)
    diag_updater = tf.ones([total_ndims], dtype=dtype)

    # Window 2 to tune mass matrix.
    total_n = 0
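    # `tuning_window_schedule` is interpreted Stan-style: the first entry is the
    # step-size-only window 1 (run above), the middle entries are window-2 epochs
    # that also adapt the diagonal mass matrix, and the last entry is the
    # step-size-only window 3 below. With `use_wide_window_expanding_mode`, each
    # window-2 epoch keeps `num_steps` fixed and doubles the number of chains
    # instead of doubling the epoch length.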
    for i, num_steps in enumerate(tuning_window_schedule[1:-1]):
        logging.info('Tuning Window 2 - %s...', i)
        if not use_wide_window_expanding_mode:
            num_steps = num_steps * 2**i
        with tf.control_dependencies([
                loc_conditioner.assign(mean_updater, read_value=False),
                scale_conditioner.assign(diag_updater, read_value=False)
        ]):
            (total_n, next_state_, _, next_mean_,
             next_cov_) = slow_adaptation_interval(num_steps, total_n,
                                                   next_state, next_mean,
                                                   next_cov)
            diag_part = tf.linalg.diag_part(next_cov_)
            if ps.rank(next_state) > 1:
                mean_updater = tf.reduce_mean(next_mean_, axis=0)
                diag_updater = tf.math.sqrt(tf.reduce_mean(diag_part, axis=0))
            else:
                mean_updater = next_mean_
                diag_updater = tf.math.sqrt(diag_part)

            if use_wide_window_expanding_mode:
                next_mean = tf.concat([next_mean_, next_mean_], axis=0)
                next_cov = tf.concat([next_cov_, next_cov_], axis=0)
                next_state = tf.concat([next_state_, next_state_], axis=0)
            else:
                next_mean, next_cov, next_state = next_mean_, next_cov_, next_state_

    num_steps_tuning_window_schedule3 = tuning_window_schedule[-1]
    num_batches = ps.size0(next_state)
    if nchains > num_batches:
        final_init_state = tf.repeat(next_state, (nchains + 1) // num_batches,
                                     axis=0)[:nchains]
    else:
        final_init_state = next_state[:nchains]

    with tf.control_dependencies([
            loc_conditioner.assign(mean_updater, read_value=False),
            scale_conditioner.assign(diag_updater, read_value=False)
    ]):
        # Window 3 step size tuning
        logging.info('Tuning Window 3...')
        final_tuned_state, final_pkr = fast_adaptation_interval(
            num_steps_tuning_window_schedule3, final_init_state)

        # Final samples
        logging.info('Sampling...')
        nuts_samples, diagnostic = run_chain(num_samples, final_tuned_state,
                                             final_pkr.inner_results)

    return nuts_samples, diagnostic, conditioning_bijector
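For intuition (a side note, not from the original source), the warmup implied by the default `tuning_window_schedule=(75, 25, 25, 25, 25, 25, 50)` and `init_nchains=1` can be tabulated with a small hypothetical helper that only reproduces the loop arithmetic above:

def sketch_warmup_schedule(schedule=(75, 25, 25, 25, 25, 25, 50),
                           init_nchains=1,
                           use_wide_window_expanding_mode=True):
    """Prints the (num_steps, nchains) pairs implied by the tuning loop above."""
    nchains = init_nchains
    print('window 1:', schedule[0], 'steps,', nchains, 'chain(s)')
    for i, num_steps in enumerate(schedule[1:-1]):
        if not use_wide_window_expanding_mode:
            num_steps = num_steps * 2**i
        print('window 2, epoch', i, ':', num_steps, 'steps,', nchains, 'chain(s)')
        if use_wide_window_expanding_mode:
            nchains *= 2  # chains are duplicated at the end of each epoch
    print('window 3:', schedule[-1], 'steps, followed by the final sampling run')

sketch_warmup_schedule()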
Beispiel #25
0
def _make_post_swap_replica_results(pre_swap_replica_results,
                                    inverse_temperatures,
                                    swapped_inverse_temperatures,
                                    is_swap_accepted_mask, swap_tensor_fn):
    """Return Kernel results, valid for post-swap states.

  Fields will be removed if they cannot be updated in an unambiguous manner.

  Args:
    pre_swap_replica_results: Kernel results obtained by running
      inner_kernel.one_step before swapping.
    inverse_temperatures: Tensor of inverse temperatures.
    swapped_inverse_temperatures: Tensor of inverse temperatures, permuted by
      swaps.
    is_swap_accepted_mask: Shape [num_replica] + batch_shape boolean Tensor
      telling which swaps were accepted.
    swap_tensor_fn: Callable. For `x.shape = [num_replica] + batch_shape`,
      swap_tensor_fn(x) performs swaps where they are accepted, and does not
      swap otherwise.

  Returns:
    new_replica_results:  Same type as pre_swap_replica_results.

  Raises:
    NotImplementedError: If type of [nested] results is not handled.
  """
    if not isinstance(pre_swap_replica_results,
                      metropolis_hastings.MetropolisHastingsKernelResults):
        # TODO(b/143702650) Handle other kernels.
        raise NotImplementedError(
            '`pre_swap_replica_results` currently works only for '
            'MetropolisHastingsKernelResults.  Found: {}. '
            'Please file a request with the TensorFlow Probability team.'.
            format(type(pre_swap_replica_results)))

    kr = pre_swap_replica_results
    dtype = swapped_inverse_temperatures.dtype

    # Hard to modify proposed_results in an unambiguous manner.
    # ...we also don't need to.
    kr = kr._replace(
        proposed_results=tf.convert_to_tensor(np.nan, dtype=dtype),
        proposed_state=tf.convert_to_tensor(np.nan, dtype=dtype),
    )

    replica_and_batch_rank = ps.rank(kr.log_accept_ratio)

    # After using swap_tensor_fn on "values", each value will have been multiplied
    # by the swapped_inverse_temperatures.  We need it to be multiplied instead by
    # the inverse temperature corresponding to its index.
    it_ratio_raw = inverse_temperatures / swapped_inverse_temperatures
    it_ratio = tf.where(
        is_swap_accepted_mask,
        mcmc_util.left_justified_expand_dims_to(it_ratio_raw,
                                                replica_and_batch_rank),
        tf.convert_to_tensor(1.0, dtype=dtype))
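    # Why this works: the stored (tempered) log prob is beta * untempered_log_prob.
    # After swapping, position i holds beta_j * lp(x_j); multiplying by
    # it_ratio[i] = beta_i / beta_j recovers beta_i * lp(x_j), which is what the
    # kernel at inverse temperature beta_i expects for its new state x_j.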

    def _swap_then_retemper(x):
        x, is_multipart = mcmc_util.prepare_state_parts(x)
        it_ratio_ = mcmc_util.left_justified_expand_dims_like(it_ratio, x[0])
        x = [swap_tensor_fn(x_part) * it_ratio_ for x_part in x]
        if not is_multipart:
            x = x[0]
        return x

    if isinstance(kr.accepted_results,
                  hmc.UncalibratedHamiltonianMonteCarloKernelResults):
        kr = kr._replace(accepted_results=kr.accepted_results._replace(
            target_log_prob=_swap_then_retemper(
                kr.accepted_results.target_log_prob),
            grads_target_log_prob=_swap_then_retemper(
                kr.accepted_results.grads_target_log_prob)))
    elif isinstance(kr.accepted_results,
                    random_walk_metropolis.UncalibratedRandomWalkResults):
        kr = kr._replace(accepted_results=kr.accepted_results._replace(
            target_log_prob=_swap_then_retemper(
                kr.accepted_results.target_log_prob)))
    else:
        # TODO(b/143702650) Handle other kernels.
        raise NotImplementedError(
            'Only HMC and RWMH Kernels are handled at this time. Please file a '
            'request with the TensorFlow Probability team.')

    return kr
Beispiel #26
0
def covariance(x,
               y=None,
               sample_axis=0,
               event_axis=-1,
               keepdims=False,
               name=None):
    """Sample covariance between observations indexed by `event_axis`.

  Given `N` samples of scalar random variables `X` and `Y`, covariance may be
  estimated as

  ```none
  Cov[X, Y] := N^{-1} sum_{n=1}^N (X_n - Xbar) Conj{(Y_n - Ybar)}
  Xbar := N^{-1} sum_{n=1}^N X_n
  Ybar := N^{-1} sum_{n=1}^N Y_n
  ```

  For vector-variate random variables `X = (X1, ..., Xd)`, `Y = (Y1, ..., Yd)`,
  one is often interested in the covariance matrix, `C_{ij} := Cov[Xi, Yj]`.

  ```python
  x = tf.random.normal(shape=(100, 2, 3))
  y = tf.random.normal(shape=(100, 2, 3))

  # cov[i, j] is the sample covariance between x[:, i, j] and y[:, i, j].
  cov = tfp.stats.covariance(x, y, sample_axis=0, event_axis=None)

  # cov_matrix[i, m, n] is the sample covariance of x[:, i, m] and y[:, i, n]
  cov_matrix = tfp.stats.covariance(x, y, sample_axis=0, event_axis=-1)
  ```

  Notice we divide by `N`, which does not create `NaN` when `N = 1`, but is
  slightly biased.

  Args:
    x:  A numeric `Tensor` holding samples.
    y:  Optional `Tensor` with same `dtype` and `shape` as `x`.
      Default value: `None` (`y` is effectively set to `x`).
    sample_axis: Scalar or vector `Tensor` designating axis holding samples, or
      `None` (meaning all axes hold samples).
      Default value: `0` (leftmost dimension).
    event_axis:  Scalar or vector `Tensor`, or `None` (scalar events).
      Axis indexing random events, whose covariance we are interested in.
      If a vector, entries must form a contiguous block of dims. `sample_axis`
      and `event_axis` should not intersect.
      Default value: `-1` (rightmost axis holds events).
    keepdims:  Boolean.  Whether to keep the sample axis as singletons.
    name: Python `str` name prefixed to Ops created by this function.
          Default value: `None` (i.e., `'covariance'`).

  Returns:
    cov: A `Tensor` of same `dtype` as `x`, and rank equal to
      `rank(x) - len(sample_axis) + len(event_axis)` (when `keepdims=False`).

  Raises:
    AssertionError:  If `x` and `y` are found to have different shape.
    ValueError:  If `sample_axis` and `event_axis` are found to overlap.
    ValueError:  If `event_axis` is found to not be contiguous.
  """

    with tf.name_scope(name or 'covariance'):
        x = tf.convert_to_tensor(x, name='x')
        # Covariance *only* uses the centered versions of x (and y).
        x = x - tf.reduce_mean(x, axis=sample_axis, keepdims=True)

        if y is None:
            y = x
        else:
            y = tf.convert_to_tensor(y, name='y', dtype=x.dtype)
            # If x and y have different shape, sample_axis and event_axis will likely
            # be wrong for one of them!
            tensorshape_util.assert_is_compatible_with(x.shape, y.shape)
            y = y - tf.reduce_mean(y, axis=sample_axis, keepdims=True)

        if event_axis is None:
            return tf.reduce_mean(x * tf.math.conj(y),
                                  axis=sample_axis,
                                  keepdims=keepdims)

        if sample_axis is None:
            raise ValueError(
                'sample_axis was None, which means all axes hold events, and this '
                'overlaps with event_axis ({})'.format(event_axis))

        event_axis = _make_positive_axis(event_axis, ps.rank(x))
        sample_axis = _make_positive_axis(sample_axis, ps.rank(x))

        # If we get lucky and axis is statically defined, we can do some checks.
        if _is_list_like(event_axis) and _is_list_like(sample_axis):
            event_axis = tuple(map(int, event_axis))
            sample_axis = tuple(map(int, sample_axis))
            if set(event_axis).intersection(sample_axis):
                raise ValueError(
                    'sample_axis ({}) and event_axis ({}) overlapped'.format(
                        sample_axis, event_axis))
            if (np.diff(np.array(sorted(event_axis))) > 1).any():
                raise ValueError(
                    'event_axis must be contiguous. Found: {}'.format(
                        event_axis))
            batch_axis = list(
                sorted(
                    set(range(tensorshape_util.rank(
                        x.shape))).difference(sample_axis + event_axis)))
        else:
            batch_axis = ps.setdiff1d(ps.range(0, ps.rank(x)),
                                      ps.concat((sample_axis, event_axis), 0))

        event_axis = ps.cast(event_axis, dtype=tf.int32)
        sample_axis = ps.cast(sample_axis, dtype=tf.int32)
        batch_axis = ps.cast(batch_axis, dtype=tf.int32)

        # Permute x/y until shape = B + E + S
        perm_for_xy = ps.concat((batch_axis, event_axis, sample_axis), 0)
        x_permed = tf.transpose(a=x, perm=perm_for_xy)
        y_permed = tf.transpose(a=y, perm=perm_for_xy)

        batch_ndims = ps.size(batch_axis)
        batch_shape = ps.shape(x_permed)[:batch_ndims]
        event_ndims = ps.size(event_axis)
        event_shape = ps.shape(x_permed)[batch_ndims:batch_ndims + event_ndims]
        sample_shape = ps.shape(x_permed)[batch_ndims + event_ndims:]
        sample_ndims = ps.size(sample_shape)
        n_samples = ps.reduce_prod(sample_shape)
        n_events = ps.reduce_prod(event_shape)

        # Flatten sample_axis into one long dim, and event_shape into [n_events].
        x_permed_flat = tf.reshape(
            x_permed, ps.concat((batch_shape, [n_events], [n_samples]), 0))
        y_permed_flat = tf.reshape(
            y_permed, ps.concat((batch_shape, [n_events], [n_samples]), 0))

        # After matmul, cov.shape = batch_shape + [n_events, n_events]
        cov = tf.matmul(x_permed_flat, y_permed_flat,
                        adjoint_b=True) / ps.cast(n_samples, x.dtype)
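        # Each entry is sum_n x_centered[..., i, n] * conj(y_centered[..., j, n]) / N,
        # i.e. the estimator from the docstring applied to the centered, flattened
        # inputs.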

        # Insert some singletons to make
        # cov.shape = batch_shape + event_shape**2 + [1,...,1]
        # This is just like x_permed.shape, except the sample_axis is all 1's, and
        # the [n_events] became event_shape**2.
        cov = tf.reshape(
            cov,
            ps.concat(
                (
                    batch_shape,
                    # event_shape**2 used here because it is the same length as
                    # event_shape, and has the same number of elements as one
                    # batch of covariance.
                    event_shape**2,
                    ps.ones([sample_ndims], tf.int32)),
                0))
        # Permuting by the argsort inverts the permutation, making
        # cov.shape have ones in the position where there were samples, and
        # [n_events * n_events] in the event position.
        cov = tf.transpose(a=cov, perm=ps.invert_permutation(perm_for_xy))

        # Now expand event_shape**2 into event_shape + event_shape.
        # We here use (for the first time) the fact that we require event_axis to be
        # contiguous.
        e_start = event_axis[0]
        e_len = 1 + event_axis[-1] - event_axis[0]
        cov = tf.reshape(
            cov,
            ps.concat((ps.shape(cov)[:e_start], event_shape, event_shape,
                       ps.shape(cov)[e_start + e_len:]), 0))

        # tf.squeeze requires python ints for axis, not Tensor.  This is enough to
        # require our axis args to be constants.
        if not keepdims:
            squeeze_axis = ps.where(sample_axis < e_start, sample_axis,
                                    sample_axis + e_len)
            cov = _squeeze(cov, axis=squeeze_axis)

        return cov
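# A quick self-contained check of the estimator above (a sketch, assuming
# `tfp.stats.covariance` as the public entry point): for `x` of shape
# [num_samples, event_dim] the result should match the matmul formula directly.
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

x_check = tf.random.normal([1000, 3])
cov_check = tfp.stats.covariance(x_check, sample_axis=0, event_axis=-1)
x_centered = x_check - tf.reduce_mean(x_check, axis=0)
manual = tf.matmul(x_centered, x_centered, adjoint_a=True) / 1000.
np.testing.assert_allclose(cov_check.numpy(), manual.numpy(), rtol=1e-4, atol=1e-4)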
def left_justified_expand_dims_like(x, reference, name=None):
    """Right pads `x` with `rank(reference) - rank(x)` ones."""
    with tf.name_scope(name or 'left_justified_expand_dims_like'):
        return left_justified_expand_dims_to(x, ps.rank(reference))
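# E.g. if x.shape == [3] and `reference` has rank 4, the result has shape
# [3, 1, 1, 1], i.e. `x` stays aligned with the leftmost axis of `reference`.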
class MarkovChainBijectorTest(test_util.TestCase):

    # pylint: disable=g-long-lambda
    @parameterized.named_parameters(
        dict(testcase_name='deterministic_prior',
             prior_fn=lambda: tfd.Deterministic([-100., 0., 100.]),
             transition_fn=lambda _, x: tfd.Normal(loc=x, scale=1.)),
        dict(testcase_name='deterministic_transition',
             prior_fn=lambda: tfd.Normal(loc=[-100., 0., 100.], scale=1.),
             transition_fn=lambda _, x: tfd.Deterministic(x)),
        dict(testcase_name='fully_deterministic',
             prior_fn=lambda: tfd.Deterministic([-100., 0., 100.]),
             transition_fn=lambda _, x: tfd.Deterministic(x)),
        dict(testcase_name='mvn_diag',
             prior_fn=(lambda: tfd.MultivariateNormalDiag(loc=[[2.], [2.]],
                                                          scale_diag=[1.])),
             transition_fn=lambda _, x: tfd.VectorDeterministic(x)),
        dict(testcase_name='docstring_dirichlet',
             prior_fn=lambda: tfd.JointDistributionNamedAutoBatched(
                 {'probs': tfd.Dirichlet([1., 1.])}),
             transition_fn=lambda _, x: tfd.JointDistributionNamedAutoBatched(
                 {
                     'probs':
                     tfd.MultivariateNormalDiag(loc=x['probs'],
                                                scale_diag=[0.1, 0.1])
                 },
                 batch_ndims=ps.rank(x['probs']))),
        dict(testcase_name='uniform_step',
             prior_fn=lambda: tfd.Exponential(tf.ones([4, 1])),
             transition_fn=lambda _, x: tfd.Uniform(low=x, high=x + 1.)),
        dict(testcase_name='joint_distribution',
             prior_fn=lambda: tfd.JointDistributionNamedAutoBatched(
                 batch_ndims=2,
                 model={
                     'a':
                     tfd.Gamma(tf.zeros([5]), 1.),
                     'b':
                     lambda a: (tfb.Reshape(event_shape_in=[4, 3],
                                            event_shape_out=[2, 3, 2])
                                (tfd.Independent(tfd.Normal(
                                    loc=tf.zeros([5, 4, 3]),
                                    scale=a[..., tf.newaxis, tf.newaxis]),
                                                 reinterpreted_batch_ndims=2)))
                 }),
             transition_fn=lambda _, x: tfd.JointDistributionNamedAutoBatched(
                 batch_ndims=ps.rank_from_shape(x['a'].shape),
                 model={
                     'a':
                     tfd.Normal(loc=x['a'], scale=1.),
                     'b':
                     lambda a: tfd.Deterministic(x['b'] + a[
                         ..., tf.newaxis, tf.newaxis, tf.newaxis])
                 })),
        dict(testcase_name='nested_chain',
             prior_fn=lambda: tfd.
             MarkovChain(initial_state_prior=tfb.Split(2)
                         (tfd.MultivariateNormalDiag(0., [1., 2.])),
                         transition_fn=lambda _, x: tfb.Split(2)
                         (tfd.MultivariateNormalDiag(x[0], [1., 2.])),
                         num_steps=6),
             transition_fn=(
                 lambda _, x: tfd.JointDistributionSequentialAutoBatched(
                     [
                         tfd.MultivariateNormalDiag(x[0], [1.]),
                         tfd.MultivariateNormalDiag(x[1], [1.])
                     ],
                     batch_ndims=ps.rank(x[0])))))
    # pylint: enable=g-long-lambda
    def test_default_bijector(self, prior_fn, transition_fn):
        chain = tfd.MarkovChain(initial_state_prior=prior_fn(),
                                transition_fn=transition_fn,
                                num_steps=7)

        y = self.evaluate(chain.sample(seed=test_util.test_seed()))
        bijector = chain.experimental_default_event_space_bijector()

        self.assertAllEqual(chain.batch_shape_tensor(),
                            bijector.experimental_batch_shape_tensor())

        x = bijector.inverse(y)
        yy = bijector.forward(tf.nest.map_structure(
            tf.identity, x))  # Bypass bijector cache.
        self.assertAllCloseNested(y, yy)

        chain_event_ndims = tf.nest.map_structure(ps.rank_from_shape,
                                                  chain.event_shape_tensor())
        self.assertAllEqualNested(bijector.inverse_min_event_ndims,
                                  chain_event_ndims)

        ildj = bijector.inverse_log_det_jacobian(
            tf.nest.map_structure(tf.identity, y),  # Bypass bijector cache.
            event_ndims=chain_event_ndims)
        if not bijector.is_constant_jacobian:
            self.assertAllEqual(ildj.shape, chain.batch_shape)
        fldj = bijector.forward_log_det_jacobian(
            tf.nest.map_structure(tf.identity, x),  # Bypass bijector cache.
            event_ndims=bijector.inverse_event_ndims(chain_event_ndims))
        self.assertAllClose(ildj, -fldj)

        # Verify that event shapes are passed through and flattened/unflattened
        # correctly.
        inverse_event_shapes = bijector.inverse_event_shape(chain.event_shape)
        x_event_shapes = tf.nest.map_structure(
            lambda t, nd: t.shape[ps.rank(t) - nd:], x,
            bijector.forward_min_event_ndims)
        self.assertAllEqualNested(inverse_event_shapes, x_event_shapes)
        forward_event_shapes = bijector.forward_event_shape(
            inverse_event_shapes)
        self.assertAllEqualNested(forward_event_shapes, chain.event_shape)

        # Verify that the outputs of other methods have the correct structure.
        inverse_event_shape_tensors = bijector.inverse_event_shape_tensor(
            chain.event_shape_tensor())
        self.assertAllEqualNested(inverse_event_shape_tensors, x_event_shapes)
        forward_event_shape_tensors = bijector.forward_event_shape_tensor(
            inverse_event_shape_tensors)
        self.assertAllEqualNested(forward_event_shape_tensors,
                                  chain.event_shape_tensor())
Beispiel #29
0
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Takes one step of the TransitionKernel.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s).
      previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
        `list` of `Tensor`s representing internal calculations made within the
        previous call to this function (or as returned by `bootstrap_results`).
      seed: Optional, a seed for reproducible sampling.

    Returns:
      next_state: `Tensor` or Python `list` of `Tensor`s representing the
        next state(s) of the Markov chain(s).
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.
        This includes replica states.
    """

        # The code below propagates one step states of shape
        #  [n_replica] + batch_shape + event_shape.
        #
        # The step is done in three parts:
        #  1) Call one_step to transition states via a tempered version of
        #     self.target_log_prob_fn (see _replica_target_log_prob).
        #  2) Permute values in states
        #  3) Update state-dependent values, such as log_probs.
        #
        # We chose to swap states, rather than temperatures, because...
        # (i)  If swapping temperatures, you *still* have to swap log_probs to
        #      determine acceptance, as well as states (for kernel results).
        #      So it's just as difficult to swap temperatures.
        # (ii) If swapping temperatures, you have to take care to swap any user-
        #      supplied temperature related things (like step size).
        #      A-priori, we don't know what else will need to be swapped!
        # (iii) In both cases, the kernel results need to be updated in a
        #      non-trivial manner... so we either special-case, or use bootstrap.

        with tf.name_scope(mcmc_util.make_name(self.name, 'remc', 'one_step')):
            # Force a read in case the `inverse_temperatures` is a `tf.Variable`.
            inverse_temperatures = tf.convert_to_tensor(
                previous_kernel_results.inverse_temperatures,
                name='inverse_temperatures')

            target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
                target_log_prob_fn=self.target_log_prob_fn,
                inverse_temperatures=inverse_temperatures,
                untempered_log_prob_fn=self.untempered_log_prob_fn,
                tempered_log_prob_fn=self.tempered_log_prob_fn,
            )
            # TODO(b/159636942): Clean up the helpful error msg after 2020-11-10.
            try:
                inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                    target_log_prob_for_inner_kernel)
            except TypeError as e:
                if 'argument' not in str(e):
                    raise
                raise TypeError(
                    '`ReplicaExchangeMC`s `make_kernel_fn` no longer receives a `seed` '
                    'argument. `TransitionKernel` instances now receive seeds via '
                    '`one_step`.')

            seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
            inner_seed, swap_seed, logu_seed = samplers.split_seed(seed, n=3)
            # Step the inner TransitionKernel.
            [
                pre_swap_replica_states,
                pre_swap_replica_results,
            ] = inner_kernel.one_step(
                previous_kernel_results.post_swap_replica_states,
                previous_kernel_results.post_swap_replica_results,
                seed=inner_seed)

            pre_swap_replica_target_log_prob = _get_field(
                # These are tempered log probs (have been divided by temperature).
                pre_swap_replica_results,
                'target_log_prob')

            dtype = pre_swap_replica_target_log_prob.dtype
            replica_and_batch_shape = ps.shape(
                pre_swap_replica_target_log_prob)
            batch_shape = replica_and_batch_shape[1:]
            replica_and_batch_rank = ps.rank(pre_swap_replica_target_log_prob)
            num_replica = ps.size0(inverse_temperatures)

            inverse_temperatures = mcmc_util.left_justified_broadcast_to(
                inverse_temperatures, replica_and_batch_shape)

            # Now that each replica has done one_step, it is time to consider swaps.

            # swap.shape = [n_replica], and is a "once only" permutation, meaning it
            # is achievable by a sequence of pairwise permutations, where each element
            # is moved at most once.
            # E.g. if swaps = [1, 0, 2], we will consider swapping temperatures 0 and
            # 1, keeping 2 fixed.  This exact same swap is considered for *every*
            # batch member.  Of course some batch members may accept and some reject.
            try:
                swaps = tf.cast(
                    self.swap_proposal_fn(  # pylint: disable=not-callable
                        num_replica,
                        batch_shape=batch_shape,
                        seed=swap_seed,
                        step_count=previous_kernel_results.step_count),
                    dtype=tf.int32)
            except TypeError as e:
                if 'step_count' not in str(e):
                    raise
                warnings.warn(
                    'The `swap_proposal_fn` given to ReplicaExchangeMC did not accept '
                    'the `step_count` argument. Falling back to omitting the '
                    'argument. This fallback will be removed after 24-Oct-2020.'
                )
                swaps = tf.cast(
                    self.swap_proposal_fn(  # pylint: disable=not-callable
                        num_replica,
                        batch_shape=batch_shape,
                        seed=swap_seed),
                    dtype=tf.int32)

            null_swaps = mcmc_util.left_justified_expand_dims_like(
                tf.range(num_replica, dtype=swaps.dtype), swaps)
            swaps = _maybe_embed_swaps_validation(swaps, null_swaps,
                                                  self.validate_args)

            # Un-temper the log probs for use in the swap acceptance ratio.
            if self.tempered_log_prob_fn is None:
                # Efficient way of re-evaluating target_log_prob_fn on the
                # pre_swap_replica_states.
                untempered_energy_ignoring_ulp = (
                    # Since untempered_log_prob_fn is None, we may assume
                    # inverse_temperatures > 0 (else the target is improper).
                    pre_swap_replica_target_log_prob / inverse_temperatures)
            else:
                # The untempered_log_prob_fn does not factor into the acceptance ratio.
                # Proof: Suppose the tempered target is
                #   p_k(x) = f(x)^{beta_k} g(x),
                # So f(x) is tempered, and g(x) is not.  Then, the acceptance ratio for
                # a 1 <--> 2 swap is...
                #   (p_1(x_2) p_2(x_1)) / (p_1(x_1) p_2(x_2))
                # which depends only on f(x), since terms involving g(x) cancel.
                untempered_energy_ignoring_ulp = self.tempered_log_prob_fn(
                    *pre_swap_replica_states)

            # Since `swaps` is its own inverse permutation we automatically know the
            # swap counterpart: range(num_replica). We use this idea to compute the
            # acceptance in a vectorized manner at the cost of wasting roughly half
            # our computation. Although we could use `unique` to solve this problem,
            # we expect the cost of `unique` to be higher than the dozens of wasted
            # arithmetic calculations. Worse, it'd mean we need dynamically sized
            # Tensors (e.g., using `tf.where(bool)`) and so we wouldn't be able to
            # XLA-compile.

            # Note: diffs would normally be "proposed - current" however energy is
            # flipped since `energy == -log_prob`.
            # Note: The untempered_log_prob_fn (if provided) is not included in
            # untempered_pre_swap_replica_target_log_prob, and hence does not factor
            # into energy_diff. Why? Because, it cancels out in the acceptance ratio.
            energy_diff = (untempered_energy_ignoring_ulp -
                           mcmc_util.index_remapping_gather(
                               untempered_energy_ignoring_ulp,
                               swaps,
                               name='gather_swap_tlp'))
            swapped_inverse_temperatures = mcmc_util.index_remapping_gather(
                inverse_temperatures, swaps, name='gather_swap_temps')
            inverse_temp_diff = swapped_inverse_temperatures - inverse_temperatures

            # If i and j are swapping, log_accept_ratio at indices i and j are equal.
            log_accept_ratio = (energy_diff *
                                mcmc_util.left_justified_expand_dims_to(
                                    inverse_temp_diff, replica_and_batch_rank))
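            # Concretely, for a proposed swap between replicas i and j this is the
            # standard parallel-tempering log acceptance ratio
            #   (beta_j - beta_i) * (log f(x_i) - log f(x_j)),
            # where f is the tempered part of the target.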

            log_accept_ratio = tf.where(tf.math.is_finite(log_accept_ratio),
                                        log_accept_ratio,
                                        tf.constant(-np.inf, dtype=dtype))

            # Produce log[Uniform] draws that are identical at swapped indices.
            log_uniform = tf.math.log(
                samplers.uniform(shape=replica_and_batch_shape,
                                 dtype=dtype,
                                 seed=logu_seed))
            anchor_swaps = tf.minimum(swaps, null_swaps)
            log_uniform = mcmc_util.index_remapping_gather(
                log_uniform, anchor_swaps)

            is_swap_accepted_mask = tf.less(log_uniform,
                                            log_accept_ratio,
                                            name='is_swap_accepted_mask')

            def _swap_tensor(x):
                return mcmc_util.choose(
                    is_swap_accepted_mask,
                    mcmc_util.index_remapping_gather(x, swaps), x)

            post_swap_replica_states = [
                _swap_tensor(s) for s in pre_swap_replica_states
            ]

            expanded_null_swaps = mcmc_util.left_justified_broadcast_to(
                null_swaps, replica_and_batch_shape)
            is_swap_proposed = _compute_swap_notmatrix(
                # Broadcast both so they have shape [num_replica] + batch_shape.
                # This (i) makes them have same shape as is_swap_accepted, and
                # (ii) keeps shape consistent if someday swaps has a batch shape.
                expanded_null_swaps,
                mcmc_util.left_justified_broadcast_to(swaps,
                                                      replica_and_batch_shape))

            # To get is_swap_accepted in ordered position, we use
            # _compute_swap_notmatrix on current and next replica positions.
            post_swap_replica_position = _swap_tensor(expanded_null_swaps)

            is_swap_accepted = _compute_swap_notmatrix(
                post_swap_replica_position, expanded_null_swaps)

            if self._state_includes_replicas:
                post_swap_states = post_swap_replica_states
            else:
                post_swap_states = [s[0] for s in post_swap_replica_states]

            post_swap_replica_results = _set_swapped_fields_to_nan(
                _swap_log_prob_and_maybe_grads(pre_swap_replica_results,
                                               post_swap_replica_states,
                                               inner_kernel))

            if mcmc_util.is_list_like(current_state):
                # We *always* canonicalize the states in the kernel results.
                states = post_swap_states
            else:
                states = post_swap_states[0]

            post_swap_kernel_results = ReplicaExchangeMCKernelResults(
                post_swap_replica_states=post_swap_replica_states,
                pre_swap_replica_results=pre_swap_replica_results,
                post_swap_replica_results=post_swap_replica_results,
                is_swap_proposed=is_swap_proposed,
                is_swap_accepted=is_swap_accepted,
                is_swap_proposed_adjacent=_sub_diag(is_swap_proposed),
                is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
                # Store the original pkr.inverse_temperatures in case its a
                # `tf.Variable`.
                inverse_temperatures=previous_kernel_results.
                inverse_temperatures,
                swaps=swaps,
                step_count=previous_kernel_results.step_count + 1,
                seed=seed,
            )

            return states, post_swap_kernel_results
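As a usage note (a minimal sketch, not taken from the listing; the target distribution, temperatures, and step sizes are illustrative assumptions): `one_step` is normally driven by `tfp.mcmc.sample_chain`, roughly like this:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

target = tfd.Normal(loc=0., scale=1.)

remc = tfp.mcmc.ReplicaExchangeMC(
    target_log_prob_fn=target.log_prob,
    inverse_temperatures=[1., 0.3, 0.1, 0.03],
    make_kernel_fn=lambda tlp_fn: tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=tlp_fn, step_size=0.5, num_leapfrog_steps=3))

samples = tfp.mcmc.sample_chain(
    num_results=500,
    num_burnin_steps=200,
    current_state=tf.zeros([]),
    kernel=remc,
    trace_fn=None)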
Beispiel #30
0
    def loop_tree_doubling(self, step_size, momentum_state_memory,
                           current_step_meta_info, iter_, initial_step_state,
                           initial_step_metastate):
        """Main loop for tree doubling."""
        with tf.name_scope('loop_tree_doubling'):
            batch_shape = prefer_static.shape(
                current_step_meta_info.init_energy)
            direction = tf.cast(tf.random.uniform(shape=batch_shape,
                                                  minval=0,
                                                  maxval=2,
                                                  dtype=tf.int32,
                                                  seed=self._seed_stream()),
                                dtype=tf.bool)

            tree_start_states = tf.nest.map_structure(
                lambda v: tf.where(  # pylint: disable=g-long-lambda
                    _rightmost_expand_to_rank(
                        direction, prefer_static.rank(v[1])), v[1], v[0]),
                initial_step_state)

            directions_expanded = [
                _rightmost_expand_to_rank(direction, prefer_static.rank(state))
                for state in tree_start_states.state
            ]

            integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
                self.target_log_prob_fn,
                step_sizes=[
                    tf.where(d, ss, -ss)
                    for d, ss in zip(directions_expanded, step_size)
                ],
                num_steps=self.unrolled_leapfrog_steps)

            [
                candidate_tree_state, tree_final_states, final_not_divergence,
                continue_tree_final, energy_diff_tree_sum,
                momentum_subtree_cumsum, leapfrogs_taken
            ] = self._build_sub_tree(
                directions_expanded,
                integrator,
                current_step_meta_info,
                # num_steps_at_this_depth = 2**iter_ = 1 << iter_
                tf.bitwise.left_shift(1, iter_),
                tree_start_states,
                initial_step_metastate.continue_tree,
                initial_step_metastate.not_divergence,
                momentum_state_memory)

            last_candidate_state = initial_step_metastate.candidate_state
            tree_weight = candidate_tree_state.weight
            if MULTINOMIAL_SAMPLE:
                weight_sum = log_add_exp(tree_weight,
                                         last_candidate_state.weight)
                log_accept_thresh = tree_weight - last_candidate_state.weight
            else:
                weight_sum = tree_weight + last_candidate_state.weight
                log_accept_thresh = tf.math.log(
                    tf.cast(tree_weight, tf.float32) /
                    tf.cast(last_candidate_state.weight, tf.float32))
            log_accept_thresh = tf.where(tf.math.is_nan(log_accept_thresh),
                                         tf.zeros([], log_accept_thresh.dtype),
                                         log_accept_thresh)
            u = tf.math.log1p(-tf.random.uniform(shape=batch_shape,
                                                 dtype=log_accept_thresh.dtype,
                                                 seed=self._seed_stream()))
            is_sample_accepted = u <= log_accept_thresh
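            # In the multinomial scheme this accepts the proposal drawn from the
            # new subtree with probability min(1, weight_new / weight_old), where
            # weight_old is the weight accumulated by the trajectory so far; `u`
            # is a log-uniform draw, so `u <= log_accept_thresh` implements that
            # test. The slice-sampling branch uses the ratio of integer weights in
            # the same way.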

            choose_new_state = is_sample_accepted & continue_tree_final

            new_candidate_state = TreeDoublingStateCandidate(
                state=[
                    tf.where(  # pylint: disable=g-complex-comprehension
                        _rightmost_expand_to_rank(choose_new_state,
                                                  prefer_static.rank(s0)), s0,
                        s1) for s0, s1 in zip(candidate_tree_state.state,
                                              last_candidate_state.state)
                ],
                target=tf.where(
                    _rightmost_expand_to_rank(
                        choose_new_state,
                        prefer_static.rank(candidate_tree_state.target)),
                    candidate_tree_state.target, last_candidate_state.target),
                target_grad_parts=[
                    tf.where(  # pylint: disable=g-complex-comprehension
                        _rightmost_expand_to_rank(choose_new_state,
                                                  prefer_static.rank(grad0)),
                        grad0, grad1) for grad0, grad1 in zip(
                            candidate_tree_state.target_grad_parts,
                            last_candidate_state.target_grad_parts)
                ],
                energy=tf.where(
                    _rightmost_expand_to_rank(
                        choose_new_state,
                        prefer_static.rank(candidate_tree_state.target)),
                    candidate_tree_state.energy, last_candidate_state.energy),
                weight=weight_sum)

            for new_candidate_state_temp, old_candidate_state_temp in zip(
                    new_candidate_state.state, last_candidate_state.state):
                new_candidate_state_temp.set_shape(
                    old_candidate_state_temp.shape)

            for new_candidate_grad_temp, old_candidate_grad_temp in zip(
                    new_candidate_state.target_grad_parts,
                    last_candidate_state.target_grad_parts):
                new_candidate_grad_temp.set_shape(
                    old_candidate_grad_temp.shape)

            # Update left right information of the trajectory, and check trajectory
            # level U turn
            tree_otherend_states = tf.nest.map_structure(
                lambda v: tf.where(  # pylint: disable=g-long-lambda
                    _rightmost_expand_to_rank(
                        direction, prefer_static.rank(v[1])), v[0], v[1]),
                initial_step_state)

            new_step_state = tf.nest.pack_sequence_as(
                initial_step_state,
                [
                    tf.stack(
                        [  # pylint: disable=g-complex-comprehension
                            tf.where(
                                _rightmost_expand_to_rank(
                                    direction, prefer_static.rank(l)), r, l),
                            tf.where(
                                _rightmost_expand_to_rank(
                                    direction, prefer_static.rank(l)), l, r),
                        ],
                        axis=0)
                    for l, r in zip(tf.nest.flatten(tree_final_states),
                                    tf.nest.flatten(tree_otherend_states))
                ])

            momentum_tree_cumsum = []
            for p0, p1 in zip(initial_step_metastate.momentum_sum,
                              momentum_subtree_cumsum):
                momentum_part_temp = p0 + p1
                momentum_part_temp.set_shape(p0.shape)
                momentum_tree_cumsum.append(momentum_part_temp)

            for new_state_temp, old_state_temp in zip(
                    tf.nest.flatten(new_step_state),
                    tf.nest.flatten(initial_step_state)):
                new_state_temp.set_shape(old_state_temp.shape)

            if GENERALIZED_UTURN:
                state_diff = momentum_tree_cumsum
            else:
                state_diff = [s[1] - s[0] for s in new_step_state.state]

            no_u_turns_trajectory = has_not_u_turn(
                state_diff, [m[0] for m in new_step_state.momentum],
                [m[1] for m in new_step_state.momentum],
                log_prob_rank=prefer_static.rank_from_shape(batch_shape))

            new_step_metastate = TreeDoublingMetaState(
                candidate_state=new_candidate_state,
                is_accepted=choose_new_state
                | initial_step_metastate.is_accepted,
                momentum_sum=momentum_tree_cumsum,
                energy_diff_sum=(energy_diff_tree_sum +
                                 initial_step_metastate.energy_diff_sum),
                continue_tree=continue_tree_final & no_u_turns_trajectory,
                not_divergence=final_not_divergence,
                leapfrog_count=(initial_step_metastate.leapfrog_count +
                                leapfrogs_taken))

            return iter_ + 1, new_step_state, new_step_metastate