Пример #1
0
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Proposes a new state with `new_state_fn` and scores it.

        Args:
          current_state: A single state part or a list-like of state parts.
            Every part is converted to a `Tensor` before use.
          previous_kernel_results: Not read by this method.
          seed: Optional seed. When supplied it is sanitized and handed to
            `new_state_fn` with no fallback. When omitted, a seed is drawn from
            `self._seed_stream` and sanitized; if `new_state_fn` rejects it with
            a seed-type `TypeError`, the call is retried with the raw
            stream seed.

        Returns:
          A two-element list `[next_state, kernel_results]`. `next_state` is the
          list of proposed parts when `current_state` is list-like, otherwise
          the single proposed part. `kernel_results` is an
          `UncalibratedRandomWalkResults` carrying a zero
          `log_acceptance_correction`, the proposal's `target_log_prob`, and the
          seed used (a zeros seed when the stateful fallback ran).

        Raises:
          TypeError: Propagated from `new_state_fn` when a seed was supplied, or
            when the message does not look like a seed incompatibility.
        """
        scope = mcmc_util.make_name(self.name, 'rwm', 'one_step')
        with tf.name_scope(scope):
            with tf.name_scope('initialize'):
                raw_parts = (list(current_state)
                             if mcmc_util.is_list_like(current_state)
                             else [current_state])
                current_state_parts = [
                    tf.convert_to_tensor(part, name='current_state')
                    for part in raw_parts
                ]

            # Seed policy (kept for backwards compatibility with callers whose
            # `new_state_fn` still expects an old-style stateful seed):
            # - Caller passed a seed: the contract is stateless sampling, so
            #   the seed is sanitized and any failure in `new_state_fn` is
            #   propagated unchanged.
            # - Caller passed no seed: try a stateless seed first; if that
            #   fails with what looks like a seed incompatibility, warn and
            #   retry with the stateful-style seed so unseeded user code does
            #   not suddenly break.
            # TODO(b/159636942): Clean up after 2020-09-20.
            force_stateless = seed is not None
            if force_stateless:
                seed = samplers.sanitize_seed(seed)
            else:
                if self._seed_stream.original_seed is not None:
                    warnings.warn(mcmc_util.SEED_CTOR_ARG_DEPRECATION_MSG)
                stateful_seed = self._seed_stream()
                seed = samplers.sanitize_seed(stateful_seed)

            try:
                next_state_parts = self.new_state_fn(current_state_parts, seed)  # pylint: disable=not-callable
            except TypeError as e:
                error_text = str(e)
                looks_like_seed_error = (
                    'Expected int for argument' in error_text
                    or TENSOR_SEED_MSG_PREFIX in error_text)
                if force_stateless or not looks_like_seed_error:
                    raise
                fallback_msg = (
                    'Falling back to `int` seed for `new_state_fn` {}. Please update '
                    'to use `tf.random.stateless_*` RNGs. '
                    'This fallback may be removed after 10-Sep-2020. ({})')
                warnings.warn(
                    fallback_msg.format(self.new_state_fn, error_text))
                # Signals below that no stateless seed was actually consumed.
                seed = None
                next_state_parts = self.new_state_fn(  # pylint: disable=not-callable
                    current_state_parts, stateful_seed)

            # Compute `target_log_prob` so it is available to
            # MetropolisHastings.
            next_target_log_prob = self.target_log_prob_fn(*next_state_parts)  # pylint: disable=not-callable

            # Hand back the same structure (list vs. single part) we received.
            if mcmc_util.is_list_like(current_state):
                next_state = next_state_parts
            else:
                next_state = next_state_parts[0]

            kernel_results = UncalibratedRandomWalkResults(
                log_acceptance_correction=tf.zeros_like(next_target_log_prob),
                target_log_prob=next_target_log_prob,
                seed=samplers.zeros_seed() if seed is None else seed,
            )
            return [next_state, kernel_results]
    def one_step(self, current_state, previous_kernel_results):
        """Runs one iteration of the Elliptical Slice Sampler.

        A threshold `init_log_likelihood + log(u)` is drawn, then angles are
        repeatedly sampled from a shrinking bracket and the state is rotated
        by `_rotate_on_ellipse` until no element of the log-likelihood is
        below the threshold.

        Args:
          current_state: `Tensor` or Python `list` of `Tensor`s representing the
            current state(s) of the Markov chain(s). The first `r` dimensions
            index independent chains,
            `r = tf.rank(log_likelihood_fn(*normal_sampler_fn()))`.
          previous_kernel_results: `collections.namedtuple` containing `Tensor`s
            representing values from previous calls to this function (or from
            the `bootstrap_results` function.) Only its `log_likelihood` field
            is read here.

        Returns:
          next_state: Tensor or Python list of `Tensor`s representing the
            state(s) of the Markov chain(s) after taking exactly one step. Has
            same type and shape as `current_state`.
          kernel_results: `EllipticalSliceSamplerKernelResults` holding the
            final `log_likelihood`, the accepted `angle` and the
            `normal_samples` used for this step.

        Raises:
          TypeError: if `not log_likelihood.dtype.is_floating`.
            NOTE(review): not raised directly in this body; presumably comes
            from `_prepare_args` -- confirm.
        """
        with tf.compat.v1.name_scope(name=mcmc_util.make_name(
                self.name, 'elliptical_slice', 'one_step'),
                                     values=[
                                         self._seed_stream, current_state,
                                         previous_kernel_results.log_likelihood
                                     ]):
            with tf.compat.v1.name_scope('initialize'):
                [init_state_parts, init_log_likelihood
                 ] = _prepare_args(self.log_likelihood_fn, current_state,
                                   previous_kernel_results.log_likelihood)

            # NOTE(review): every `self._seed_stream()` call below presumably
            # yields a fresh seed, so the order of these draws matters for
            # reproducibility -- do not reorder without confirming.
            normal_samples = self.normal_sampler_fn(self._seed_stream())  # pylint: disable=not-callable
            # Canonicalize to a list so it can be handled part-by-part.
            normal_samples = list(normal_samples) if mcmc_util.is_list_like(
                normal_samples) else [normal_samples]
            u = tf.random.uniform(
                shape=tf.shape(init_log_likelihood),
                seed=self._seed_stream(),
                dtype=init_log_likelihood.dtype.base_dtype,
            )
            # Log-likelihood level a rotated point must reach for its chain to
            # be considered done (see `chain_not_done`).
            threshold = init_log_likelihood + tf.math.log(u)

            starting_angle = tf.random.uniform(
                shape=tf.shape(init_log_likelihood),
                minval=0.,
                maxval=2 * np.pi,
                name='angle',
                seed=self._seed_stream(),
                dtype=init_log_likelihood.dtype.base_dtype,
            )
            # Initial bracket is [angle - 2*pi, angle]; it is narrowed inside
            # the loop body.
            starting_angle_min = starting_angle - 2 * np.pi
            starting_angle_max = starting_angle

            starting_state_parts = _rotate_on_ellipse(init_state_parts,
                                                      normal_samples,
                                                      starting_angle)
            starting_log_likelihood = self.log_likelihood_fn(
                *starting_state_parts)  # pylint: disable=not-callable

            # Loop condition: keep iterating while any element of the
            # log-likelihood is still below the threshold.
            def chain_not_done(angle, angle_min, angle_max,
                               current_state_parts, current_log_likelihood):
                del angle, angle_min, angle_max, current_state_parts
                return tf.reduce_any(current_log_likelihood < threshold)

            def sample_next_angle(angle, angle_min, angle_max,
                                  current_state_parts, current_log_likelihood):
                """Slice sample a new angle, and rotate init_state by that amount."""
                # Elementwise mask. Note this local shadows the
                # `chain_not_done` loop-condition function inside this body.
                chain_not_done = current_log_likelihood < threshold
                # Box in on angle. Only update angles for which we haven't generated a
                # point that beats the threshold.
                angle_min = tf.where(
                    tf.math.logical_and(angle < 0, chain_not_done), angle,
                    angle_min)
                angle_max = tf.where(
                    tf.math.logical_and(angle >= 0, chain_not_done), angle,
                    angle_max)
                new_angle = tf.random.uniform(
                    shape=tf.shape(current_log_likelihood),
                    minval=angle_min,
                    maxval=angle_max,
                    seed=self._seed_stream(),
                    dtype=angle.dtype.base_dtype)
                # Elements that already passed the threshold keep their angle.
                angle = tf.where(chain_not_done, new_angle, angle)
                # Rotation always starts from the initial state, not from the
                # loop-carried state.
                next_state_parts = _rotate_on_ellipse(init_state_parts,
                                                      normal_samples, angle)

                new_state_parts = []
                # The mask is padded to the rank of the first state part and
                # then reused for every part.
                broadcasted_chain_not_done = _right_pad_with_ones(
                    chain_not_done, tf.rank(next_state_parts[0]))
                for n_state, c_state in zip(next_state_parts,
                                            current_state_parts):
                    # Take the newly rotated value only where still not done.
                    new_state_part = tf.where(broadcasted_chain_not_done,
                                              n_state, c_state)
                    new_state_parts.append(new_state_part)

                return (
                    angle,
                    angle_min,
                    angle_max,
                    new_state_parts,
                    self.log_likelihood_fn(*new_state_parts)  # pylint: disable=not-callable
                )

            # The final bracket bounds are not needed after the loop.
            [
                next_angle,
                _,
                _,
                next_state_parts,
                next_log_likelihood,
            ] = tf.while_loop(cond=chain_not_done,
                              body=sample_next_angle,
                              loop_vars=[
                                  starting_angle, starting_angle_min,
                                  starting_angle_max, starting_state_parts,
                                  starting_log_likelihood
                              ])

            # Return a list only if the caller passed a list-like state.
            return [
                next_state_parts if mcmc_util.is_list_like(current_state) else
                next_state_parts[0],
                EllipticalSliceSamplerKernelResults(
                    log_likelihood=next_log_likelihood,
                    angle=next_angle,
                    normal_samples=normal_samples,
                ),
            ]
Пример #3
0
    def one_step(self, current_state, previous_kernel_results):
        """Steps the inner kernel once and computes the next step size.

        The step size stored in `previous_kernel_results.new_step_size` is
        written into the inner kernel results, the inner kernel is advanced,
        and each step size part is then multiplied or divided by
        `1 + adaptation_rate` depending on whether the (reduced) log acceptance
        probability exceeds the log of the target. Once
        `previous_kernel_results.step` reaches `self.num_adaptation_steps` the
        step size read back from the inner kernel is kept unchanged.

        Args:
          current_state: State passed through to `self.inner_kernel.one_step`.
          previous_kernel_results: Results from the previous call. Fields read
            here: `inner_results`, `new_step_size`, `target_accept_prob`,
            `adaptation_rate` and `step`.

        Returns:
          A tuple `(new_state, kernel_results)`, where `kernel_results` is
          `previous_kernel_results` with `inner_results`, `step` (incremented by
          one) and `new_step_size` replaced.
        """
        scope = mcmc_util.make_name(self.name,
                                    'simple_step_size_adaptation',
                                    'one_step')
        with tf.compat.v1.name_scope(
                name=scope,
                values=[current_state, previous_kernel_results]):

            # Install the step size chosen on the previous call.
            inner_results = self.step_size_setter_fn(
                previous_kernel_results.inner_results,
                previous_kernel_results.new_step_size)

            # Advance the wrapped kernel.
            new_state, new_inner_results = self.inner_kernel.one_step(
                current_state, inner_results)

            # Gather what is needed to adapt the step size.
            log_accept_prob = self.log_accept_prob_getter_fn(new_inner_results)
            log_target_accept_prob = tf.math.log(
                previous_kernel_results.target_accept_prob)

            state_parts = tf.nest.flatten(current_state)
            step_size = self.step_size_getter_fn(new_inner_results)
            step_size_parts = tf.nest.flatten(step_size)
            log_accept_prob_rank = tf.rank(log_accept_prob)

            def _adapt_part(step_size_part, state_part):
                """Returns the adapted value for a single step size part."""
                # If the step size part has smaller rank than its state part,
                # the difference is averaged away in the log accept prob.
                #
                # Example:
                #
                # state_part has shape      [2, 3, 4, 5]
                # step_size_part has shape     [1, 4, 1]
                # log_accept_prob has shape [2, 3, 4]
                #
                # The step size has one rank fewer than the state, so the
                # leading dimension of log_accept_prob is reduced away, giving
                # shape [3, 4]. Because log_accept_prob must broadcast into
                # step_size_part on the left, the dimensions where the shapes
                # differ are reduced as well, giving shape [1, 4], which is
                # compatible with the leading dimensions of step_size_part.
                #
                # Subtlety: step_size_parts may be a length-1 list that is
                # "structure-broadcast" across all state parts (see logic in,
                # e.g., hmc.py). Then the lone step size must broadcast with
                # the event dims of every state part, so either the step size
                # has no dimensions corresponding to chain dimensions, or all
                # states share one shape. In the former case we want to reduce
                # over all chain dimensions; in the latter, the same logic as
                # the non-broadcast case applies.
                #
                # Both cases are handled uniformly by using the rank of any
                # state part. That clearly works when all state ranks agree.
                # Otherwise every state part has rank L + D_i + B, where L is
                # the rank of log_accept_prob, D_i the dimensions not shared
                # amongst states, and B the shared dimensions (equal to the
                # step size's). Subtracting B always leaves a number >= L, so
                # the `minimum` below yields the full reduction we want.
                rank_gap = tf.rank(state_part) - tf.rank(step_size_part)
                num_reduce_dims = tf.minimum(log_accept_prob_rank, rank_gap)
                partially_reduced = _reduce_logmeanexp(
                    log_accept_prob, tf.range(num_reduce_dims))
                # Second reduction, over the dimensions whose sizes differ, so
                # the result broadcasts into step_size_part on the left.
                differing_dims = _get_differing_dims(partially_reduced,
                                                     step_size_part)
                fully_reduced = _reduce_logmeanexp(
                    partially_reduced, differing_dims, keepdims=True)

                accepting_too_often = fully_reduced > log_target_accept_prob
                grown = step_size_part * (
                    1. + previous_kernel_results.adaptation_rate)
                shrunk = step_size_part / (
                    1. + previous_kernel_results.adaptation_rate)
                adapted = mcmc_util.choose(accepting_too_often, grown, shrunk)

                # Freeze the step size once the adaptation window has passed.
                still_adapting = (
                    previous_kernel_results.step < self.num_adaptation_steps)
                return tf.compat.v1.where(still_adapting, adapted,
                                          step_size_part)

            new_step_size_parts = [
                _adapt_part(step_size_part, state_part)
                for step_size_part, state_part in zip(step_size_parts,
                                                      state_parts)
            ]
            new_step_size = tf.nest.pack_sequence_as(step_size,
                                                     new_step_size_parts)

            updated_results = previous_kernel_results._replace(
                inner_results=new_inner_results,
                step=previous_kernel_results.step + 1,
                new_step_size=new_step_size)
            return new_state, updated_results
Пример #4
0
  def one_step(self, current_state, previous_kernel_results, seed=None):
    """Takes one step of the TransitionKernel.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s).
      previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
        `list` of `Tensor`s representing internal calculations made within the
        previous call to this function (or as returned by `bootstrap_results`).
      seed: Optional, a seed for reproducible sampling.

    Returns:
      next_state: `Tensor` or Python `list` of `Tensor`s representing the
        next state(s) of the Markov chain(s).
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.
        This includes replica states.

    Raises:
      TypeError: if `make_kernel_fn` raises a `TypeError` (one whose message
        mentions 'argument' is replaced by a hint that `seed` is no longer
        passed), or if `swap_proposal_fn` raises a `TypeError` whose message
        does not mention `step_count`.
    """

    # The code below propagates one step states of shape
    #  [n_replica] + batch_shape + event_shape.
    #
    # The step is done in three parts:
    #  1) Call one_step to transition states via a tempered version of
    #     self.target_log_prob_fn (see _replica_target_log_prob).
    #  2) Permute values in states
    #  3) Update state-dependent values, such as log_probs.
    #
    # We chose to swap states, rather than temperatures, because...
    # (i)  If swapping temperatures, you *still* have to swap log_probs to
    #      determine acceptance, as well as states (for kernel results).
    #      So it's just as difficult to swap temperatures.
    # (ii) If swapping temperatures, you have to take care to swap any user-
    #      supplied temperature related things (like step size).
    #      A-priori, we don't know what else will need to be swapped!
    # (iii)In both cases, the kernel results need to be updated in a non-trivial
    #      manner....so we either special-case, or use bootstrap.

    with tf.name_scope(mcmc_util.make_name(self.name, 'remc', 'one_step')):
      # Force a read in case the `inverse_temperatures` is a `tf.Variable`.
      inverse_temperatures = tf.convert_to_tensor(
          previous_kernel_results.inverse_temperatures,
          name='inverse_temperatures')

      target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
          target_log_prob_fn=self.target_log_prob_fn,
          inverse_temperatures=inverse_temperatures,
          untempered_log_prob_fn=self.untempered_log_prob_fn,
          tempered_log_prob_fn=self.tempered_log_prob_fn,
      )
      # TODO(b/159636942): Clean up the helpful error msg after 2020-11-10.
      try:
        inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
            target_log_prob_for_inner_kernel)
      except TypeError as e:
        # Only TypeErrors that mention 'argument' are treated as a signature
        # mismatch; anything else is propagated untouched.
        if 'argument' not in str(e):
          raise
        raise TypeError(
            '`ReplicaExchangeMC`s `make_kernel_fn` no longer receives a `seed` '
            'argument. `TransitionKernel` instances now receive seeds via '
            '`one_step`.')

      seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
      # One seed each for the inner kernel, the swap proposal and the
      # acceptance uniforms.
      inner_seed, swap_seed, logu_seed = samplers.split_seed(seed, n=3)
      # Step the inner TransitionKernel.
      [
          pre_swap_replica_states,
          pre_swap_replica_results,
      ] = inner_kernel.one_step(
          previous_kernel_results.post_swap_replica_states,
          previous_kernel_results.post_swap_replica_results,
          seed=inner_seed)

      pre_swap_replica_target_log_prob = _get_field(
          # These are tempered log probs (have been divided by temperature).
          pre_swap_replica_results, 'target_log_prob')

      dtype = pre_swap_replica_target_log_prob.dtype
      replica_and_batch_shape = ps.shape(
          pre_swap_replica_target_log_prob)
      # Everything after the leading (replica) dimension.
      batch_shape = replica_and_batch_shape[1:]
      replica_and_batch_rank = ps.rank(
          pre_swap_replica_target_log_prob)
      num_replica = ps.size0(inverse_temperatures)

      inverse_temperatures = bu.left_justified_broadcast_to(
          inverse_temperatures, replica_and_batch_shape)

      # Now that each replica has done one_step, it is time to consider swaps.

      # swaps.shape = [n_replica], and is a "once only" permutation, meaning it
      # is achievable by a sequence of pairwise permutations, where each element
      # is moved at most once.
      # E.g. if swaps = [1, 0, 2], we will consider swapping temperatures 0 and
      # 1, keeping 2 fixed.  This exact same swap is considered for *every*
      # batch member.  Of course some batch members may accept and some reject.
      try:
        swaps = tf.cast(
            self.swap_proposal_fn(  # pylint: disable=not-callable
                num_replica,
                batch_shape=batch_shape,
                seed=swap_seed,
                step_count=previous_kernel_results.step_count),
            dtype=tf.int32)
      except TypeError as e:
        # Backwards compatibility: retry without `step_count` only when the
        # error message names that argument.
        if 'step_count' not in str(e):
          raise
        warnings.warn(
            'The `swap_proposal_fn` given to ReplicaExchangeMC did not accept '
            'the `step_count` argument. Falling back to omitting the '
            'argument. This fallback will be removed after 24-Oct-2020.')
        swaps = tf.cast(
            self.swap_proposal_fn(  # pylint: disable=not-callable
                num_replica,
                batch_shape=batch_shape,
                seed=swap_seed),
            dtype=tf.int32)

      # The identity permutation (range(num_replica)), expanded to line up
      # with `swaps`.
      null_swaps = bu.left_justified_expand_dims_like(
          tf.range(num_replica, dtype=swaps.dtype), swaps)
      swaps = _maybe_embed_swaps_validation(swaps, null_swaps,
                                            self.validate_args)

      # Un-temper the log probs for use in the swap acceptance ratio.
      if self.tempered_log_prob_fn is None:
        # Efficient way of re-evaluating target_log_prob_fn on the
        # pre_swap_replica_states.
        # NOTE(review): the comment below refers to `untempered_log_prob_fn`
        # while this branch tests `tempered_log_prob_fn`; presumably the two
        # are always None together -- confirm against the constructor.
        untempered_energy_ignoring_ulp = (
            # Since untempered_log_prob_fn is None, we may assume
            # inverse_temperatures > 0 (else the target is improper).
            pre_swap_replica_target_log_prob / inverse_temperatures)
      else:
        # The untempered_log_prob_fn does not factor into the acceptance ratio.
        # Proof: Suppose the tempered target is
        #   p_k(x) = f(x)^{beta_k} g(x),
        # So f(x) is tempered, and g(x) is not.  Then, the acceptance ratio for
        # a 1 <--> 2 swap is...
        #   (p_1(x_2) p_2(x_1)) / (p_1(x_1) p_2(x_2))
        # which depends only on f(x), since terms involving g(x) cancel.
        untempered_energy_ignoring_ulp = self.tempered_log_prob_fn(
            *pre_swap_replica_states)

      # Since `swaps` is its own inverse permutation we automatically know the
      # swap counterpart: range(num_replica). We use this idea to compute the
      # acceptance in a vectorized manner at the cost of wasting roughly half
      # our computation. Although we could use `unique` to solve this problem,
      # we expect the cost of `unique` to be higher than the dozens of wasted
      # arithmetic calculations. Worse, it'd mean we need dynamic sized Tensors
      # (eg, using `tf.where(bool)`) and so we wouldn't be able to XLA compile.

      # Note: diffs would normally be "proposed - current" however energy is
      # flipped since `energy == -log_prob`.
      # Note: The untempered_log_prob_fn (if provided) is not included in
      # untempered_pre_swap_replica_target_log_prob, and hence does not factor
      # into energy_diff. Why? Because, it cancels out in the acceptance ratio.
      energy_diff = (
          untempered_energy_ignoring_ulp -
          mcmc_util.index_remapping_gather(
              untempered_energy_ignoring_ulp,
              swaps, name='gather_swap_tlp'))
      swapped_inverse_temperatures = mcmc_util.index_remapping_gather(
          inverse_temperatures, swaps, name='gather_swap_temps')
      inverse_temp_diff = swapped_inverse_temperatures - inverse_temperatures

      # If i and j are swapping, the entries of log_accept_ratio at i and j
      # are equal.
      log_accept_ratio = (
          energy_diff * bu.left_justified_expand_dims_to(
              inverse_temp_diff, replica_and_batch_rank))

      # Any non-finite ratio (NaN or +/-inf) is replaced by -inf, i.e. the
      # swap is rejected.
      log_accept_ratio = tf.where(
          tf.math.is_finite(log_accept_ratio),
          log_accept_ratio, tf.constant(-np.inf, dtype=dtype))

      # Produce log[Uniform] draws that are identical at swapped indices.
      log_uniform = tf.math.log(
          samplers.uniform(shape=replica_and_batch_shape,
                           dtype=dtype,
                           seed=logu_seed))
      # Both members of a swapping pair read the draw at the smaller index.
      anchor_swaps = tf.minimum(swaps, null_swaps)
      log_uniform = mcmc_util.index_remapping_gather(log_uniform, anchor_swaps)

      is_swap_accepted_mask = tf.less(
          log_uniform,
          log_accept_ratio,
          name='is_swap_accepted_mask')

      def _swap_tensor(x):
        # Where the swap was accepted take the permuted value, else keep `x`.
        return mcmc_util.choose(
            is_swap_accepted_mask,
            mcmc_util.index_remapping_gather(x, swaps), x)

      post_swap_replica_states = [
          _swap_tensor(s) for s in pre_swap_replica_states]

      expanded_null_swaps = bu.left_justified_broadcast_to(
          null_swaps, replica_and_batch_shape)
      is_swap_proposed = _compute_swap_notmatrix(
          # Broadcast both so they have shape [num_replica] + batch_shape.
          # This (i) makes them have same shape as is_swap_accepted, and
          # (ii) keeps shape consistent if someday swaps has a batch shape.
          expanded_null_swaps,
          bu.left_justified_broadcast_to(swaps, replica_and_batch_shape))

      # To get is_swap_accepted in ordered position, we use
      # _compute_swap_notmatrix on current and next replica positions.
      post_swap_replica_position = _swap_tensor(expanded_null_swaps)

      is_swap_accepted = _compute_swap_notmatrix(
          post_swap_replica_position,
          expanded_null_swaps)

      if self._state_includes_replicas:
        post_swap_states = post_swap_replica_states
      else:
        # Only the first replica (index 0) of each state part is returned.
        post_swap_states = [s[0] for s in post_swap_replica_states]

      post_swap_replica_results = _set_swapped_fields_to_nan(
          _swap_log_prob_and_maybe_grads(
              pre_swap_replica_results,
              post_swap_replica_states,
              inner_kernel))

      if mcmc_util.is_list_like(current_state):
        # We *always* canonicalize the states in the kernel results.
        states = post_swap_states
      else:
        states = post_swap_states[0]

      post_swap_kernel_results = ReplicaExchangeMCKernelResults(
          post_swap_replica_states=post_swap_replica_states,
          pre_swap_replica_results=pre_swap_replica_results,
          post_swap_replica_results=post_swap_replica_results,
          is_swap_proposed=is_swap_proposed,
          is_swap_accepted=is_swap_accepted,
          is_swap_proposed_adjacent=_sub_diag(is_swap_proposed),
          is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
          # Store the original pkr.inverse_temperatures in case its a
          # `tf.Variable`.
          inverse_temperatures=previous_kernel_results.inverse_temperatures,
          swaps=swaps,
          step_count=previous_kernel_results.step_count + 1,
          seed=seed,
      )

      return states, post_swap_kernel_results
Пример #5
0
    def __init__(self,
                 inner_kernel,
                 num_adaptation_steps,
                 target_accept_prob=0.75,
                 exploration_shrinkage=0.05,
                 step_count_smoothing=10,
                 decay_rate=0.75,
                 step_size_setter_fn=_hmc_like_step_size_setter_fn,
                 step_size_getter_fn=_hmc_like_step_size_getter_fn,
                 log_accept_prob_getter_fn=_hmc_like_log_accept_prob_getter_fn,
                 validate_args=False,
                 name=None):
        """Initializes this transition kernel.

        Args:
          inner_kernel: `TransitionKernel`-like object.
          num_adaptation_steps: Scalar `int` `Tensor` number of initial steps
            during which to adjust the step size. This may be greater than,
            less than, or equal to the number of burnin steps.
          target_accept_prob: A floating point `Tensor` representing desired
            acceptance probability. Must be a positive number less than 1. This
            can either be a scalar, or have shape [num_chains]. Default value:
            `0.75` (the center of the asymptotically optimal rate for HMC).
          exploration_shrinkage: Floating point scalar `Tensor`. How strongly
            the exploration rate is biased towards the shrinkage target.
          step_count_smoothing: Int32 scalar `Tensor`. Number of
            "pseudo-steps" added to the number of steps taken to prevent noisy
            exploration during the early samples. It is converted here to the
            same floating dtype as the other hyperparameters.
          decay_rate: Floating point scalar `Tensor`. How much to favor recent
            iterations over earlier ones. A value of 1 gives equal weight to
            all history.
          step_size_setter_fn: A callable with the signature `(kernel_results,
            new_step_size) -> new_kernel_results` where `kernel_results` are
            the results of the `inner_kernel`, `new_step_size` is a `Tensor` or
            a nested collection of `Tensor`s with the same structure as
            returned by the `step_size_getter_fn`, and `new_kernel_results` are
            a copy of `kernel_results` with the step size(s) set.
          step_size_getter_fn: A callable with the signature
            `(kernel_results) -> step_size` where `kernel_results` are the
            results of the `inner_kernel`, and `step_size` is a floating point
            `Tensor` or a nested collection of such `Tensor`s.
          log_accept_prob_getter_fn: A callable with the signature
            `(kernel_results) -> log_accept_prob` where `kernel_results` are
            the results of the `inner_kernel`, and `log_accept_prob` is a
            floating point `Tensor`. `log_accept_prob` can either be a scalar,
            or have shape [num_chains]. If it's the latter, `step_size` should
            also have the same leading dimension.
          validate_args: Python `bool`. When `True` kernel parameters are
            checked for validity. When `False` invalid inputs may silently
            render incorrect outputs.
          name: Python `str` name prefixed to Ops created by this function.
            Default value: `None` (i.e.,
            'dual_averaging_step_size_adaptation').
        """
        inner_kernel = mcmc_util.enable_store_parameters_in_results(
            inner_kernel)

        scope = mcmc_util.make_name(name,
                                    'dual_averaging_step_size_adaptation',
                                    '__init__')
        # `name` is rebound to the scope name and stored in the parameters.
        with tf.name_scope(scope) as name:
            dtype = dtype_util.common_dtype(
                [target_accept_prob, exploration_shrinkage, decay_rate],
                dtype_hint=tf.float32)

            def _to_common_dtype(value, tensor_name):
                """Converts `value` to a `Tensor` of the common float dtype."""
                return tf.convert_to_tensor(value,
                                            dtype=dtype,
                                            name=tensor_name)

            target_accept_prob = _to_common_dtype(target_accept_prob,
                                                  'target_accept_prob')
            exploration_shrinkage = _to_common_dtype(exploration_shrinkage,
                                                     'exploration_shrinkage')
            step_count_smoothing = _to_common_dtype(step_count_smoothing,
                                                    'step_count_smoothing')
            decay_rate = _to_common_dtype(decay_rate, 'decay_rate')

            # The step budget is the only integer-typed hyperparameter.
            num_adaptation_steps = tf.convert_to_tensor(
                num_adaptation_steps,
                dtype=tf.int32,
                name='num_adaptation_steps')
            target_accept_prob = _maybe_validate_target_accept_prob(
                target_accept_prob, validate_args)

        # `validate_args` is consumed above and is not recorded here.
        self._parameters = {
            'inner_kernel': inner_kernel,
            'num_adaptation_steps': num_adaptation_steps,
            'target_accept_prob': target_accept_prob,
            'exploration_shrinkage': exploration_shrinkage,
            'step_count_smoothing': step_count_smoothing,
            'decay_rate': decay_rate,
            'step_size_setter_fn': step_size_setter_fn,
            'step_size_getter_fn': step_size_getter_fn,
            'log_accept_prob_getter_fn': log_accept_prob_getter_fn,
            'name': name,
        }
Пример #6
0
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Advances the chain by one Metropolis-Hastings step.

        A proposal is obtained from `inner_kernel.one_step`, its log acceptance
        ratio is assembled from the proposed and previously accepted
        `target_log_prob` values (plus the optional
        `log_acceptance_correction`), and the proposal is accepted wherever the
        log of a uniform draw falls below that ratio.

        Args:
          current_state: `Tensor` or Python `list` of `Tensor`s representing the
            current state(s) of the Markov chain(s).
          previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
            `list` of `Tensor`s representing internal calculations made within
            the previous call to this function (or as returned by
            `bootstrap_results`).
          seed: PRNG seed; see `tfp.random.sanitize_seed` for details. The
            inner kernel only receives a `seed` keyword when this is not
            `None`.

        Returns:
          next_state: `Tensor` or Python `list` of `Tensor`s representing the
            next state(s) of the Markov chain(s).
          kernel_results: A `MetropolisHastingsKernelResults` holding the
            internal calculations made within this function.

        Raises:
          ValueError: if `inner_kernel` results doesn't contain the member
            "target_log_prob".
        """
        caller_gave_seed = seed is not None
        # The sanitized seed is reported back in the kernel results.
        seed = samplers.sanitize_seed(seed)
        proposal_seed, acceptance_seed = samplers.split_seed(seed)

        with tf.name_scope(mcmc_util.make_name(self.name, 'mh', 'one_step')):
            accepted_results = previous_kernel_results.accepted_results

            # Run the inner kernel once to obtain a proposal.
            if caller_gave_seed:
                proposed_state, proposed_results = self.inner_kernel.one_step(
                    current_state, accepted_results, seed=proposal_seed)
            else:
                proposed_state, proposed_results = self.inner_kernel.one_step(
                    current_state, accepted_results)

            # Give the proposal the same nesting as a list-like input state.
            if mcmc_util.is_list_like(current_state):
                proposed_state = tf.nest.pack_sequence_as(
                    current_state, proposed_state)

            if not (has_target_log_prob(proposed_results)
                    and has_target_log_prob(accepted_results)):
                raise ValueError('"target_log_prob" must be a member of '
                                 '`inner_kernel` results.')

            proposed_log_prob = proposed_results.target_log_prob

            # log(acceptance_ratio) = proposed - current (+ correction).
            log_ratio_terms = [
                proposed_log_prob,
                -accepted_results.target_log_prob,
            ]
            try:
                correction = proposed_results.log_acceptance_correction
                # An empty list-like correction contributes nothing.
                if not (mcmc_util.is_list_like(correction) and not correction):
                    log_ratio_terms.append(correction)
            except AttributeError:
                warnings.warn(
                    'Supplied inner `TransitionKernel` does not have a '
                    '`log_acceptance_correction`. Assuming its value is `0.`')
            log_accept_ratio = mcmc_util.safe_sum(
                log_ratio_terms, name='compute_log_accept_ratio')

            # Accept when u < min(1, accept_ratio) with u ~ Uniform[0, 1),
            # i.e. when log(u) < log_accept_ratio. Proposals that increase
            # the target density are therefore always accepted.
            uniform_draw = samplers.uniform(
                shape=prefer_static.shape(proposed_log_prob),
                dtype=dtype_util.base_dtype(proposed_log_prob.dtype),
                seed=acceptance_seed)
            log_uniform = tf.math.log(uniform_draw)
            is_accepted = log_uniform < log_accept_ratio

            next_state = mcmc_util.choose(
                is_accepted,
                proposed_state,
                current_state,
                name='choose_next_state')

            # Seeds are stripped from the proposed results before choosing:
            # a seed is not a per-chain quantity, so there is no way to pick
            # between the previous and the proposed seed chain by chain.
            new_accepted_results = mcmc_util.choose(
                is_accepted,
                mcmc_util.strip_seeds(proposed_results),
                accepted_results,
                name='choose_inner_results')

            kernel_results = MetropolisHastingsKernelResults(
                accepted_results=new_accepted_results,
                is_accepted=is_accepted,
                log_accept_ratio=log_accept_ratio,
                proposed_state=proposed_state,
                proposed_results=proposed_results,
                extra=[],
                seed=seed,
            )

            return next_state, kernel_results
Пример #7
0
    def __init__(
        self,
        target_log_prob_fn,
        initial_covariance,
        initial_covariance_scaling=2.38**2,
        covariance_scaling_reducer=0.7,
        covariance_scaling_limiter=0.01,
        covariance_burnin=100,
        target_accept_ratio=0.234,
        pu=0.95,
        fixed_variance=0.01,
        extra_getter_fn=rwm_extra_getter_fn,
        extra_setter_fn=rwm_extra_setter_fn,
        log_accept_prob_getter_fn=rwm_log_accept_prob_getter_fn,
        seed=None,
        name=None,
    ):
        """Initializes this transition kernel.

        Args:
          target_log_prob_fn: Python callable which takes an argument like
            `current_state` and returns its (possibly unnormalized) log-density
            under the target distribution.
          initial_covariance: A single covariance (scalar or matrix, given as
            a `Tensor` or anything `tf.convert_to_tensor` accepts) or a Python
            `list` of `Tensor`s each representing the initial covariance
            matrix of the proposal. A single scalar is promoted to a 1x1
            matrix. The covariance matrix is tuned during the evolution of the
            MCMC chain. This argument is required.
          initial_covariance_scaling: Python floating point number representing
            the initial value of the `covariance_scaling`. The value of
            `covariance_scaling` is tuned during the evolution of the MCMC
            chain. Let d represent the number of parameters e.g. as determined
            by the `initial_covariance`. The ratio given by the
            `covariance_scaling` divided by d is used to multiply the running
            covariance. The covariance scaling factor multiplied by the
            covariance matrix is used in the proposal at each step.
            Default value: 2.38**2.
          covariance_scaling_reducer: Python floating point number, bounded
            over the range (0.5,1.0], representing the constant factor used
            during the adaptation of the `covariance_scaling`.
            Default value: 0.7.
          covariance_scaling_limiter: Python floating point number, bounded
            between 0.0 and 1.0, which places a limit on the maximum amount the
            `covariance_scaling` value can be perturbed at each iteration of
            the MCMC chain.
            Default value: 0.01.
          covariance_burnin: Python integer number of steps to take before
            starting to compute the running covariance.
            Default value: 100.
          target_accept_ratio: Python floating point number, bounded over the
            range (0.0, 1.0], representing the target acceptance probability of
            the Metropolis-Hastings algorithm.
            The default value of 0.234 is applicable when the number of
            parameters is 3 or more.  For the one parameter case typically the
            'target_accept_ratio' should be set to 0.44.
          pu: Python floating point number, bounded between 0.0 and 1.0,
            representing the bounded convergence parameter.  See
            `random_walk_mvnorm_fn()` for further details.
            Default value: 0.95.
          fixed_variance: Non-negative Python floating point number
            representing the variance of the fixed proposal distribution. See
            `random_walk_mvnorm_fn` for further details.
            Default value: 0.01.
          extra_getter_fn: A callable with the signature
            `(kernel_results) -> extra` where `kernel_results` are the results
            of the `inner_kernel`, and `extra` is a nested collection of
            `Tensor`s.
          extra_setter_fn: A callable with the signature
            `(kernel_results, args) -> new_kernel_results` where
            `kernel_results` are the results of the `inner_kernel`, `args`
            are a nested collection of `Tensor`s with the same
            structure as returned by the `extra_getter_fn`, and
            `new_kernel_results` are a copy of `kernel_results` with `args`
            in the `extra` field set.
          log_accept_prob_getter_fn: A callable with the signature
            `(kernel_results) -> log_accept_prob` where `kernel_results` are the
            results of the `inner_kernel`, and `log_accept_prob` is either
            a scalar, or has shape [num_chains].
          seed: Python integer to seed the random number generator.
            Default value: `None`.
          name: Python `str` name prefixed to Ops created by this function.
            Default value: `None`.

        Raises:
          ValueError: if `initial_covariance_scaling` is less than or equal
            to 0.0.
          ValueError: if `covariance_scaling_reducer` is less than or equal
            to 0.5 or greater than 1.0.
          ValueError: if `covariance_scaling_limiter` is less than 0.0 or
            greater than 1.0.
          ValueError: if `covariance_burnin` is less than 0.
          ValueError: if `target_accept_ratio` is less than or equal to 0.0 or
            greater than 1.0.
          ValueError: if `pu` is less than 0.0 or greater than 1.0.
          ValueError: if `fixed_variance` is less than 0.0.
        """
        with tf.name_scope(
                mcmc_util.make_name(name, "AdaptiveRandomWalkMetropolis",
                                    "__init__")) as name:
            if initial_covariance_scaling <= 0.0:
                raise ValueError(
                    "`{}` must be a `float` greater than 0.0".format(
                        "initial_covariance_scaling"))
            if (covariance_scaling_reducer <= 0.5
                    or covariance_scaling_reducer > 1.0):
                raise ValueError(
                    "`{}` must be a `float` greater than 0.5 and less than or equal to 1.0."
                    .format("covariance_scaling_reducer"))
            if (covariance_scaling_limiter < 0.0
                    or covariance_scaling_limiter > 1.0):
                raise ValueError(
                    "`{}` must be a `float` between 0.0 and 1.0.".format(
                        "covariance_scaling_limiter"))
            if covariance_burnin < 0:
                raise ValueError(
                    "`{}` must be a `integer` greater or equal to 0.".format(
                        "covariance_burnin"))
            if target_accept_ratio <= 0.0 or target_accept_ratio > 1.0:
                raise ValueError(
                    "`{}` must be a `float` between 0.0 and 1.0.".format(
                        "target_accept_ratio"))
            if pu < 0.0 or pu > 1.0:
                raise ValueError(
                    "`{}` must be a `float` between 0.0 and 1.0.".format("pu"))
            if fixed_variance < 0.0:
                # Zero is accepted, so the message states the inclusive bound.
                raise ValueError(
                    "`{}` must be a `float` greater than or equal to 0.0."
                    .format("fixed_variance"))

        # Normalize `initial_covariance` into a list of tensors. The list case
        # is checked first: a Python list has no `.shape`, so inspecting the
        # shape before this test would fail for the documented list input.
        if mcmc_util.is_list_like(initial_covariance):
            initial_covariance_parts = [
                tf.convert_to_tensor(s, name="initial_covariance_")
                for s in initial_covariance
            ]
        else:
            single_covariance = tf.convert_to_tensor(
                initial_covariance, name="initial_covariance_")
            if single_covariance.shape == ():
                # A scalar variance is treated as a 1x1 covariance matrix.
                single_covariance = tf.reshape(single_covariance, shape=(1, 1))
            initial_covariance_parts = [single_covariance]
        self._initial_covariance_matrices = tf.stack(initial_covariance_parts)

        dtype = dtype_util.base_dtype(self._initial_covariance_matrices.dtype)
        shape = self._initial_covariance_matrices.shape

        # Running covariance accumulator sized from the last covariance axis.
        self._running_covar = stats.RunningCovariance(shape=(1, shape[-1]),
                                                      dtype=dtype,
                                                      event_ndims=1)
        self._accum_covar = self._running_covar.initialize()

        # One Bernoulli(pu) indicator per stacked covariance; the initial
        # indicator is all zeros with the shape and dtype of a sample.
        probs = tf.expand_dims(tf.ones([shape[0]], dtype=dtype) * pu, axis=1)
        self._u = Bernoulli(probs=probs, dtype=tf.dtypes.int32)
        self._initial_u = tf.zeros_like(self._u.sample(seed=seed),
                                        dtype=tf.dtypes.int32)

        name = mcmc_util.make_name(name, "AdaptiveRandomWalkMetropolis", "")

        self._parameters = dict(
            target_log_prob_fn=target_log_prob_fn,
            initial_covariance=initial_covariance,
            initial_covariance_scaling=initial_covariance_scaling,
            covariance_scaling_reducer=covariance_scaling_reducer,
            covariance_scaling_limiter=covariance_scaling_limiter,
            covariance_burnin=covariance_burnin,
            target_accept_ratio=target_accept_ratio,
            pu=pu,
            fixed_variance=fixed_variance,
            extra_getter_fn=extra_getter_fn,
            extra_setter_fn=extra_setter_fn,
            log_accept_prob_getter_fn=log_accept_prob_getter_fn,
            seed=seed,
            name=name,
        )
        # The actual sampling is delegated to a Metropolis-Hastings kernel
        # wrapping an uncalibrated random walk with the adaptive proposal.
        self._impl = metropolis_hastings.MetropolisHastings(
            inner_kernel=random_walk_metropolis.UncalibratedRandomWalk(
                target_log_prob_fn=target_log_prob_fn,
                new_state_fn=random_walk_mvnorm_fn(
                    covariance=self._initial_covariance_matrices,
                    pu=pu,
                    fixed_variance=fixed_variance,
                    is_adaptive=self._initial_u,
                    name=name,
                ),
                name=name,
            ),
            name=name,
        )
Пример #8
0
  def __init__(
      self,
      inner_kernel,
      num_adaptation_steps,
      target_accept_prob=0.75,
      exploration_shrinkage=0.05,
      shrinkage_target=None,
      step_count_smoothing=10,
      decay_rate=0.75,
      step_size_setter_fn=hmc_like_step_size_setter_fn,
      step_size_getter_fn=hmc_like_step_size_getter_fn,
      log_accept_prob_getter_fn=hmc_like_log_accept_prob_getter_fn,
      reduce_fn=reduce_logmeanexp,
      experimental_reduce_chain_axis_names=None,
      validate_args=False,
      name=None):
    """Creates the dual-averaging step size adaptation kernel.

    The numeric hyperparameters are converted to `Tensor`s of a common
    floating dtype and stored, with the callables, in the kernel parameters.

    Args:
      inner_kernel: `TransitionKernel`-like object.
      num_adaptation_steps: Scalar `int` `Tensor` number of initial steps
        during which to adjust the step size. This may be greater than, less
        than, or equal to the number of burnin steps.
      target_accept_prob: A floating point `Tensor` representing desired
        acceptance probability. Must be a positive number less than 1. This can
        either be a scalar, or have shape [num_chains]. Default value: `0.75`
        (the [center of asymptotically optimal rate for HMC][1]).
      exploration_shrinkage: Floating point scalar `Tensor`. How strongly the
        exploration rate is biased towards the shrinkage target.
      shrinkage_target: `Tensor` or list of tensors. Value the exploration
        step size(s) is/are biased towards.
        As `num_adaptation_steps --> infinity`, this bias goes to zero.
        Defaults to 10 times the initial step size.
      step_count_smoothing: Int32 scalar `Tensor`. Number of "pseudo-steps"
        added to the number of steps taken to prevent noisy exploration during
        the early samples.
      decay_rate: Floating point scalar `Tensor`. How much to favor recent
        iterations over earlier ones. A value of 1 gives equal weight to all
        history. A value of 0 gives weight only to the most recent iteration.
      step_size_setter_fn: A callable with the signature `(kernel_results,
        new_step_size) -> new_kernel_results` where `kernel_results` are the
        results of the `inner_kernel`, `new_step_size` is a `Tensor` or a nested
        collection of `Tensor`s with the same structure as returned by the
        `step_size_getter_fn`, and `new_kernel_results` are a copy of
        `kernel_results` with the step size(s) set.
      step_size_getter_fn: A callable with the signature `(kernel_results) ->
        step_size` where `kernel_results` are the results of the `inner_kernel`,
        and `step_size` is a floating point `Tensor` or a nested collection of
        such `Tensor`s.
      log_accept_prob_getter_fn: A callable with the signature `(kernel_results)
        -> log_accept_prob` where `kernel_results` are the results of the
        `inner_kernel`, and `log_accept_prob` is a floating point `Tensor`.
        `log_accept_prob` can either be a scalar, or have shape [num_chains]. If
        it's the latter, `step_size` should also have the same leading
        dimension.
      reduce_fn: A callable with signature `(input_tensor, axis, keepdims) ->
        tensor` that returns a log-reduction of `log_accept_prob`, typically
        some sort of mean. Default value: `reduce_logmeanexp`.
      experimental_reduce_chain_axis_names: A `str` or list of `str`s indicating
        the named axes that should additionally be reduced during the
        log-reduction of `log_accept_prob`.
      validate_args: Python `bool`. When `True` kernel parameters are checked
        for validity. When `False` invalid inputs may silently render incorrect
        outputs.
      name: Python `str` name prefixed to Ops created by this function.
        Default value: `None` (i.e., 'dual_averaging_step_size_adaptation').
    """
    inner_kernel = mcmc_util.enable_store_parameters_in_results(inner_kernel)

    scope_name = mcmc_util.make_name(
        name, 'dual_averaging_step_size_adaptation', '__init__')
    with tf.name_scope(scope_name) as name:
      # All floating hyperparameters share one dtype (float32 if unspecified).
      floating_args = [
          target_accept_prob,
          exploration_shrinkage,
          shrinkage_target,
          decay_rate,
      ]
      dtype = dtype_util.common_dtype(floating_args, dtype_hint=tf.float32)

      def _as_float_tensor(value, tensor_name):
        return tf.convert_to_tensor(value, dtype=dtype, name=tensor_name)

      target_accept_prob = _as_float_tensor(
          target_accept_prob, 'target_accept_prob')
      exploration_shrinkage = _as_float_tensor(
          exploration_shrinkage, 'exploration_shrinkage')
      # `step_count_smoothing` is documented as Int32 but is needed in
      # `dtype`; `tf.cast` is used because `convert_to_tensor` with a
      # floating dtype would reject an integer `Tensor`.
      step_count_smoothing = tf.cast(
          step_count_smoothing, dtype=dtype, name='step_count_smoothing')
      decay_rate = _as_float_tensor(decay_rate, 'decay_rate')
      num_adaptation_steps = tf.convert_to_tensor(
          num_adaptation_steps, dtype=tf.int32, name='num_adaptation_steps')
      target_accept_prob = _maybe_validate_target_accept_prob(
          target_accept_prob, validate_args)

      if shrinkage_target is not None:
        shrinkage_target = tf.nest.map_structure(
            lambda part: _as_float_tensor(part, 'shrinkage_target'),
            shrinkage_target)

    self._parameters = {
        'inner_kernel': inner_kernel,
        'num_adaptation_steps': num_adaptation_steps,
        'target_accept_prob': target_accept_prob,
        'exploration_shrinkage': exploration_shrinkage,
        'shrinkage_target': shrinkage_target,
        'step_count_smoothing': step_count_smoothing,
        'decay_rate': decay_rate,
        'step_size_setter_fn': step_size_setter_fn,
        'step_size_getter_fn': step_size_getter_fn,
        'log_accept_prob_getter_fn': log_accept_prob_getter_fn,
        'reduce_fn': reduce_fn,
        'experimental_reduce_chain_axis_names':
            experimental_reduce_chain_axis_names,
        'name': name,
    }
Пример #9
0
  def bootstrap_results(self, init_state):
    """Builds the initial kernel results for the adaptation kernel.

    Bootstraps the inner kernel, reads its step size and log acceptance
    probability, and creates zero-initialized accumulators (`error_sum`,
    `log_averaging_step`) plus the log shrinkage target for each step size
    part.

    Args:
      init_state: `Tensor` or nested structure of `Tensor`s representing the
        initial state(s) of the Markov chain(s).

    Returns:
      kernel_results: A `DualAveragingStepSizeAdaptationResults` wrapping the
        inner kernel's bootstrap results and the adaptation bookkeeping.

    Raises:
      ValueError: if `shrinkage_target` was given and its number of parts is
        neither 1 nor the number of step size parts.
    """
    scope_name = mcmc_util.make_name(
        self.name, 'dual_averaging_step_size_adaptation', 'bootstrap_results')
    with tf.name_scope(scope_name):
      inner_results = self.inner_kernel.bootstrap_results(init_state)
      step_size = self.step_size_getter_fn(inner_results)

      log_accept_prob = self.log_accept_prob_getter_fn(inner_results)

      state_parts = tf.nest.flatten(init_state)
      step_size_parts = tf.nest.flatten(step_size)
      num_step_sizes = len(step_size_parts)

      user_shrinkage_target = self._parameters['shrinkage_target']
      if user_shrinkage_target is None:
        shrinkage_target_parts = [None] * num_step_sizes
      else:
        shrinkage_target_parts = tf.nest.flatten(user_shrinkage_target)
        num_targets = len(shrinkage_target_parts)
        if num_targets != 1 and num_targets != num_step_sizes:
          raise ValueError(
              '`shrinkage_target` should be a Tensor or list of tensors of '
              'same length as `step_size`. Found len(`step_size`) = {} and '
              'len(shrinkage_target) = {}'.format(
                  num_step_sizes, num_targets))
        if num_targets < num_step_sizes:
          # A single target is shared by every step size part.
          shrinkage_target_parts = shrinkage_target_parts * num_step_sizes

      dtype = dtype_util.common_dtype(step_size_parts, tf.float32)
      error_sum = []
      log_averaging_step = []
      log_shrinkage_target = []
      for state_part, step_size_part, target_part in zip(
          state_parts, step_size_parts, shrinkage_target_parts):
        # Only the shape of the reduced acceptance probability matters here:
        # it determines the shape of the zero-initialized `error_sum`.
        num_reduce_dims = ps.minimum(
            ps.rank(log_accept_prob),
            ps.rank(state_part) - ps.rank(step_size_part))
        reduced_log_accept_prob = reduce_logmeanexp(
            log_accept_prob,
            axis=ps.range(num_reduce_dims),
            experimental_named_axis=self.experimental_reduce_chain_axis_names)
        reduce_indices = get_differing_dims(
            reduced_log_accept_prob, step_size_part)
        reduced_log_accept_prob = reduce_logmeanexp(
            reduced_log_accept_prob,
            axis=reduce_indices,
            keepdims=True)
        error_sum.append(tf.zeros_like(reduced_log_accept_prob, dtype=dtype))
        log_averaging_step.append(tf.zeros_like(step_size_part, dtype=dtype))

        if target_part is not None:
          log_target = tf.math.log(tf.cast(target_part, dtype))
        else:
          # Default target: ten times the initial step size, in log space.
          log_target = float(np.log(10.)) + tf.math.log(step_size_part)
        log_shrinkage_target.append(log_target)

      return DualAveragingStepSizeAdaptationResults(
          inner_results=inner_results,
          step=tf.constant(0, dtype=tf.int32),
          target_accept_prob=tf.cast(self.parameters['target_accept_prob'],
                                     log_accept_prob.dtype),
          log_shrinkage_target=log_shrinkage_target,
          exploration_shrinkage=tf.cast(
              self.parameters['exploration_shrinkage'], dtype),
          step_count_smoothing=tf.cast(
              self.parameters['step_count_smoothing'], dtype),
          decay_rate=tf.cast(self.parameters['decay_rate'], dtype),
          error_sum=error_sum,
          log_averaging_step=log_averaging_step,
          new_step_size=step_size)
Пример #10
0
    def one_step(self, current_state, previous_kernel_results):
        """Takes one step of the TransitionKernel.

        Each replica is advanced by one step of a freshly built inner kernel
        that targets a tempered log-probability, after which a swap of states
        between replicas is proposed and accepted or rejected per batch
        member. This method takes no `seed` argument; all randomness is drawn
        from `self._seed_stream()`.

        Args:
          current_state: `Tensor` or Python `list` of `Tensor`s representing the
            current state(s) of the Markov chain(s). Only its list-likeness is
            used here; the replica states are read from
            `previous_kernel_results`.
          previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
            `list` of `Tensor`s representing internal calculations made within
            the previous call to this function (or as returned by
            `bootstrap_results`).

        Returns:
          next_state: `Tensor` or Python `list` of `Tensor`s representing the
            next state(s) of the Markov chain(s).
          kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
            `Tensor`s representing internal calculations made within this
            function. This includes replica states.
        """
        # The code below propagates one step states of shape
        #  [n_replica] + batch_shape + event_shape.
        #
        # The step is done in three parts:
        #  1) Call one_step to transition states via a tempered version of
        #     self.target_log_prob_fn (see _replica_target_log_prob).
        #  2) Permute values in states
        #  3) Update state-dependent values, such as log_probs.
        #
        # We chose to swap states, rather than temperatures, because...
        # (i)  If swapping temperatures, you *still* have to swap log_probs to
        #      determine acceptance, as well as states (for kernel results).
        #      So it's just as difficult to swap temperatures.
        # (ii) If swapping temperatures, you have to take care to swap any user-
        #      supplied temperature related things (like step size).
        #      A-priori, we don't know what else will need to be swapped!
        # (iii)In both cases, the kernel results need to be updated in a non-trivial
        #      manner....so we either special-case, or use bootstrap.
        #
        # NOTE: `self._seed_stream()` is called three times below (inner
        # kernel, swap proposal, acceptance draw); the seeds handed out depend
        # on that call order, so these statements must not be reordered.

        with tf.name_scope(mcmc_util.make_name(self.name, 'remc', 'one_step')):
            # Force a read in case the `inverse_temperatures` is a `tf.Variable`.
            inverse_temperatures = tf.convert_to_tensor(
                previous_kernel_results.inverse_temperatures,
                name='inverse_temperatures')

            # Build the inner kernel against the tempered target for the
            # current inverse temperatures.
            inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                _make_replica_target_log_prob_fn(self.target_log_prob_fn,
                                                 inverse_temperatures),
                self._seed_stream())

            [
                pre_swap_replica_states,
                pre_swap_replica_results,
            ] = inner_kernel.one_step(
                previous_kernel_results.post_swap_replica_states,
                previous_kernel_results.post_swap_replica_results)

            pre_swap_replica_target_log_prob = _get_field(
                # These are tempered log probs (have been divided by temperature).
                pre_swap_replica_results,
                'target_log_prob')

            dtype = pre_swap_replica_target_log_prob.dtype
            replica_and_batch_shape = prefer_static.shape(
                pre_swap_replica_target_log_prob)
            batch_shape = replica_and_batch_shape[1:]
            replica_and_batch_rank = prefer_static.rank(
                pre_swap_replica_target_log_prob)
            num_replica = prefer_static.size0(inverse_temperatures)

            # From here on `inverse_temperatures` has the full
            # [n_replica] + batch_shape shape of the log probs.
            inverse_temperatures = mcmc_util.left_justified_broadcast_to(
                inverse_temperatures, replica_and_batch_shape)

            # Now that each replica has done one_step, it is time to consider swaps.

            # swap.shape = [n_replica], and is a "once only" permutation, meaning it
            # is achievable by a sequence of pairwise permutations, where each element
            # is moved at most once.
            # E.g. if swaps = [1, 0, 2], we will consider swapping temperatures 0 and
            # 1, keeping 2 fixed.  This exact same swap is considered for *every*
            # batch member.  Of course some batch members may accept and some reject.
            swaps = tf.cast(
                self.swap_proposal_fn(  # pylint: disable=not-callable
                    num_replica,
                    batch_shape=batch_shape,
                    seed=self._seed_stream()),
                dtype=tf.int32)
            # `null_swaps` is the identity permutation (no replica moves),
            # expanded to the rank of `swaps`.
            null_swaps = mcmc_util.left_justified_expand_dims_like(
                tf.range(num_replica, dtype=swaps.dtype), swaps)
            swaps = _maybe_embed_swaps_validation(swaps, null_swaps,
                                                  self.validate_args)

            # Un-temper the log probs.  E.g., for replica k, at point x_k, this is
            # Log[p(x_k)], and *not* the tempered value Log[p(x_k)] * beta_k.
            untempered_pre_swap_replica_target_log_prob = (
                pre_swap_replica_target_log_prob / inverse_temperatures)

            # Since `swaps` is its own inverse permutation we automatically know the
            # swap counterpart: range(num_replica). We use this idea to compute the
            # acceptance in a vectorized manner at the cost of wasting roughly half
            # our computation. Although we could use `unique` to solve this problem,
            # we expect the cost of `unique` to be higher than the dozens of wasted
            # arithmetic calculations. Worse, it'd mean we need dynamic sized Tensors
            # (eg, using `tf.where(bool)`) and so we wouldn't be able to XLA compile.

            # Note: diffs would normally be "proposed - current" however energy is
            # flipped since `energy == -log_prob`.
            energy_diff = (untempered_pre_swap_replica_target_log_prob -
                           mcmc_util.index_remapping_gather(
                               untempered_pre_swap_replica_target_log_prob,
                               swaps,
                               name='gather_swap_tlp'))
            swapped_inverse_temperatures = mcmc_util.index_remapping_gather(
                inverse_temperatures, swaps, name='gather_swap_temps')
            inverse_temp_diff = swapped_inverse_temperatures - inverse_temperatures

            # If i and j are swapping, log_accept_ratio[i] and
            # log_accept_ratio[j] are equal.
            log_accept_ratio = (energy_diff *
                                mcmc_util.left_justified_expand_dims_to(
                                    inverse_temp_diff, replica_and_batch_rank))

            # Any non-finite ratio (NaN or +/-inf) is replaced by -inf, so the
            # corresponding swap is never accepted.
            log_accept_ratio = tf.where(tf.math.is_finite(log_accept_ratio),
                                        log_accept_ratio,
                                        tf.constant(-np.inf, dtype=dtype))

            # Produce Log[Uniform] draws that are identical at swapped indices.
            log_uniform = tf.math.log(
                tf.random.uniform(shape=replica_and_batch_shape,
                                  dtype=dtype,
                                  seed=self._seed_stream()))
            # Both members of a swapping pair read the draw stored at the
            # smaller of their two indices.
            anchor_swaps = tf.minimum(swaps, null_swaps)
            log_uniform = mcmc_util.index_remapping_gather(
                log_uniform, anchor_swaps)

            is_swap_accepted_mask = tf.less(log_uniform,
                                            log_accept_ratio,
                                            name='is_swap_accepted_mask')

            # Applies the proposed permutation to `x` where the swap was
            # accepted and leaves `x` unchanged elsewhere.
            def _swap_tensor(x):
                return mcmc_util.choose(
                    is_swap_accepted_mask,
                    mcmc_util.index_remapping_gather(x, swaps), x)

            post_swap_replica_states = [
                _swap_tensor(s) for s in pre_swap_replica_states
            ]

            expanded_null_swaps = mcmc_util.left_justified_broadcast_to(
                null_swaps, replica_and_batch_shape)
            is_swap_proposed = _compute_swap_notmatrix(
                # Broadcast both so they have shape [num_replica] + batch_shape.
                # This (i) makes them have same shape as is_swap_accepted, and
                # (ii) keeps shape consistent if someday swaps has a batch shape.
                expanded_null_swaps,
                mcmc_util.left_justified_broadcast_to(swaps,
                                                      replica_and_batch_shape))

            # To get is_swap_accepted in ordered position, we use
            # _compute_swap_notmatrix on current and next replica positions.
            post_swap_replica_position = _swap_tensor(expanded_null_swaps)

            is_swap_accepted = _compute_swap_notmatrix(
                post_swap_replica_position, expanded_null_swaps)

            # The returned chain state is replica 0 of every state part.
            post_swap_states = [s[0] for s in post_swap_replica_states]

            post_swap_replica_results = _make_post_swap_replica_results(
                pre_swap_replica_results, inverse_temperatures,
                swapped_inverse_temperatures, is_swap_accepted_mask,
                _swap_tensor)

            if mcmc_util.is_list_like(current_state):
                # We *always* canonicalize the states in the kernel results.
                states = post_swap_states
            else:
                states = post_swap_states[0]

            post_swap_kernel_results = ReplicaExchangeMCKernelResults(
                post_swap_replica_states=post_swap_replica_states,
                pre_swap_replica_results=pre_swap_replica_results,
                post_swap_replica_results=post_swap_replica_results,
                is_swap_proposed=is_swap_proposed,
                is_swap_accepted=is_swap_accepted,
                is_swap_proposed_adjacent=_sub_diag(is_swap_proposed),
                is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
                # Store the original pkr.inverse_temperatures in case it's a
                # `tf.Variable`.
                inverse_temperatures=previous_kernel_results.
                inverse_temperatures,
                swaps=swaps,
            )

            return states, post_swap_kernel_results
Пример #11
0
    def bootstrap_results(self, init_state):
        """Returns an object with the same type as returned by `one_step`.

        Every part of `init_state` is replicated once per inverse temperature,
        the inner kernel is bootstrapped on those replicas, and the swap fields
        are filled in as if an identity ("null") swap had just been accepted.

        Args:
          init_state: `Tensor` or Python `list` of `Tensor`s representing the
            initial state(s) of the Markov chain(s).

        Returns:
          kernel_results: A `ReplicaExchangeMCKernelResults` holding the
            replica states, the inner kernel results and the swap bookkeeping
            computed here.
        """
        scope_name = mcmc_util.make_name(self.name, 'remc',
                                         'bootstrap_results')
        with tf.name_scope(scope_name):
            state_parts, _ = mcmc_util.prepare_state_parts(init_state)

            inverse_temperatures = tf.convert_to_tensor(
                self.inverse_temperatures, name='inverse_temperatures')

            # Replicate each (possibly batched) initial state once per inverse
            # temperature: for init_state=[x, y] with shapes [Sx, Sy] the
            # replicas have shapes [(T, Sx), (T, Sy)], where (a, b) denotes
            # concatenation and T is the leading size of inverse_temperatures.
            num_replica = prefer_static.size0(inverse_temperatures)
            replica_shape = tf.convert_to_tensor([num_replica])

            replica_states = []
            for part in state_parts:
                tiled_shape = prefer_static.concat(
                    [replica_shape, prefer_static.shape(part)], axis=0)
                replica_states.append(
                    tf.broadcast_to(part, tiled_shape, name='replica_states'))

            tempered_log_prob_fn = _make_replica_target_log_prob_fn(
                self.target_log_prob_fn, inverse_temperatures)
            inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                tempered_log_prob_fn, self._seed_stream())
            replica_results = inner_kernel.bootstrap_results(replica_states)

            # The inner kernel's target_log_prob fixes the full
            # [num_replica] + batch_shape layout used below.
            replica_and_batch_shape = prefer_static.shape(
                _get_field(replica_results, 'target_log_prob'))
            batch_shape = replica_and_batch_shape[1:]

            inverse_temperatures = mcmc_util.left_justified_broadcast_to(
                inverse_temperatures, replica_and_batch_shape)

            # Pretend we did a "null swap", which will always be accepted.
            null_swaps = mcmc_util.left_justified_broadcast_to(
                tf.range(num_replica), replica_and_batch_shape)
            # is_swap_accepted.shape = [n_replica, n_replica] + batch_shape.
            identity_matrix = tf.eye(
                num_replica, batch_shape=batch_shape, dtype=tf.bool)
            is_swap_accepted = distribution_util.rotate_transpose(
                identity_matrix, shift=2)

            def _leave_unswapped(x):
                return x

            post_swap_replica_results = _make_post_swap_replica_results(
                replica_results,
                inverse_temperatures,
                inverse_temperatures,
                is_swap_accepted[0],
                _leave_unswapped,
            )

            # The null swap is reported as both proposed and accepted. The
            # un-broadcast `self.inverse_temperatures` is what gets stored.
            return ReplicaExchangeMCKernelResults(
                post_swap_replica_states=replica_states,
                pre_swap_replica_results=replica_results,
                post_swap_replica_results=post_swap_replica_results,
                is_swap_proposed=is_swap_accepted,
                is_swap_accepted=is_swap_accepted,
                is_swap_proposed_adjacent=_sub_diag(is_swap_accepted),
                is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
                inverse_temperatures=self.inverse_temperatures,
                swaps=null_swaps,
            )
Пример #12
0
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Proposes a new state by leapfrog integration from fresh momenta.

        Args:
          current_state: `Tensor` or Python `list` of `Tensor`s representing
            the current state(s) of the Markov chain(s).
          previous_kernel_results: Results from the previous step. Its
            `target_log_prob` and `grads_target_log_prob` are read, as are
            `step_size` and `num_leapfrog_steps` when
            `self._store_parameters_in_results` is set.
          seed: Optional seed. When `None`, a seed is drawn from
            `self._seed_stream` instead.

        Returns:
          next_state: The integrated state, a list if `current_state` was
            list-like and a single element otherwise.
          kernel_results: `previous_kernel_results` with the log acceptance
            correction, target log-prob, gradients, momenta and seed replaced.
        """
        with tf.name_scope(mcmc_util.make_name(self.name, 'hmc', 'one_step')):
            # Parameters live either in the previous results or on the kernel.
            params_source = (previous_kernel_results
                             if self._store_parameters_in_results else self)
            step_size = params_source.step_size
            num_leapfrog_steps = params_source.num_leapfrog_steps

            (
                current_state_parts,
                step_sizes,
                current_target_log_prob,
                current_target_log_prob_grad_parts,
            ) = _prepare_args(
                self.target_log_prob_fn,
                current_state,
                step_size,
                previous_kernel_results.target_log_prob,
                previous_kernel_results.grads_target_log_prob,
                maybe_expand=True,
                state_gradients_are_stopped=self.state_gradients_are_stopped)

            # TODO(b/159636942): Clean up after 2020-09-20.
            if seed is None:
                # No explicit seed: fall back to the constructor seed stream,
                # warning if the deprecated constructor seed was supplied.
                if self._seed_stream.original_seed is not None:
                    warnings.warn(mcmc_util.SEED_CTOR_ARG_DEPRECATION_MSG)
                seed = self._seed_stream()
            seed = samplers.sanitize_seed(seed)
            part_seeds = samplers.split_seed(seed, n=len(current_state_parts))

            # One normal momentum draw per state part, each with its own seed.
            current_momentum_parts = [
                samplers.normal(
                    shape=tf.shape(part),
                    dtype=(self._momentum_dtype
                           or dtype_util.base_dtype(part.dtype)),
                    seed=part_seed)
                for part_seed, part in zip(part_seeds, current_state_parts)
            ]

            integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
                self.target_log_prob_fn, step_sizes, num_leapfrog_steps)

            (
                next_momentum_parts,
                next_state_parts,
                next_target_log_prob,
                next_target_log_prob_grad_parts,
            ) = integrator(current_momentum_parts, current_state_parts,
                           current_target_log_prob,
                           current_target_log_prob_grad_parts)
            if self.state_gradients_are_stopped:
                next_state_parts = [
                    tf.stop_gradient(part) for part in next_state_parts
                ]

            independent_chain_ndims = prefer_static.rank(
                current_target_log_prob)
            log_acceptance_correction = _compute_log_acceptance_correction(
                current_momentum_parts, next_momentum_parts,
                independent_chain_ndims)

            new_kernel_results = previous_kernel_results._replace(
                log_acceptance_correction=log_acceptance_correction,
                target_log_prob=next_target_log_prob,
                grads_target_log_prob=next_target_log_prob_grad_parts,
                initial_momentum=current_momentum_parts,
                final_momentum=next_momentum_parts,
                seed=seed,
            )

            # Mirror the structure of the caller's state.
            if mcmc_util.is_list_like(current_state):
                next_state = next_state_parts
            else:
                next_state = next_state_parts[0]
            return next_state, new_kernel_results
Пример #13
0
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Steps the inner kernel, then adapts the diagonal mass matrix.

        Args:
          current_state: Current state of the chain(s), passed unchanged to
            the inner kernel.
          previous_kernel_results: Results from the previous step. Its
            `running_variance`, `inner_results`, `step` and
            `num_estimation_steps` fields are read.
          seed: Optional seed. Forwarded to the inner kernel only when it is
            not `None`.

        Returns:
          new_state: The state returned by the inner kernel.
          new_kernel_results: `previous_kernel_results` with `inner_results`,
            `running_variance` and `step` replaced.
        """
        scope_name = mcmc_util.make_name(self.name,
                                         'diagonal_mass_matrix_adaptation',
                                         'one_step')
        with tf.name_scope(scope_name):
            variance_parts = previous_kernel_results.running_variance
            prev_inner_results = previous_kernel_results.inner_results

            # Step the inner kernel.
            if seed is None:
                inner_kwargs = {}
            else:
                inner_kwargs = dict(seed=seed)
            new_state, new_inner_results = self.inner_kernel.one_step(
                current_state, prev_inner_results, **inner_kwargs)

            def _updated_running_variances():
                """Folds `new_state` into every running variance estimate."""
                current_diags = [
                    running_var.variance() for running_var in variance_parts
                ]
                state_parts = tf.nest.flatten(new_state)
                updated = []
                for running_var, diag, part in zip(variance_parts,
                                                   current_diags, state_parts):
                    # The variance may be shared across some, all or none of
                    # the chains. For example, with
                    #
                    #   state part shape   [2, 3, 4] + [5, 6] (batch + event)
                    #   variance shape           [4] + [5, 6]
                    #
                    # there are 4 mass matrices, each shared by a [2, 3] batch
                    # of chains. Until RunningVariance supports rank > 1
                    # chunking, the shared leading dims are collapsed into a
                    # single axis (here giving [6, 4, 5, 6]) so that
                    # `update(..., axis=0)` records them as 6 new observations.
                    num_reduce_dims = ps.rank(part) - ps.rank(diag)
                    part_shape = ps.shape(part)

                    # The product over an empty slice yields a leading 1 when
                    # num_reduce_dims == 0; otherwise all leading dims merge.
                    lead_size = ps.reduce_prod(part_shape[:num_reduce_dims])
                    collapsed_shape = ps.concat(
                        [[lead_size], part_shape[num_reduce_dims:]], axis=0)
                    collapsed_part = ps.reshape(part, collapsed_shape)

                    # Reducing over axis 0 removes the axis added above, so
                    # the estimate keeps its original shape.
                    updated.append(running_var.update(collapsed_part, axis=0))
                return updated

            def _with_updated_momentum(running_variances, inner_results):
                """Returns `inner_results` using the new variance estimates."""
                diags = [
                    running_var.variance()
                    for running_var in running_variances
                ]
                prev_momentum_distribution = (
                    self.momentum_distribution_getter_fn(inner_results))
                new_momentum_distribution = (
                    preconditioning_utils.update_momentum_distribution(
                        prev_momentum_distribution, diags))
                return self.momentum_distribution_setter_fn(
                    inner_results, new_momentum_distribution)

            step = previous_kernel_results.step + 1
            if self.num_estimation_steps is None:
                # No estimation window: adapt on every step.
                new_variance_parts = _updated_running_variances()
                new_inner_results = _with_updated_momentum(
                    new_variance_parts, new_inner_results)
            else:
                # Keep accumulating variance up to and including the last
                # estimation step; swap in the new momentum distribution only
                # on that last step.
                estimation_steps = previous_kernel_results.num_estimation_steps
                new_variance_parts = mcmc_util.choose(
                    step <= estimation_steps,
                    _updated_running_variances(), variance_parts)
                new_inner_results = mcmc_util.choose(
                    tf.equal(step, estimation_steps),
                    _with_updated_momentum(new_variance_parts,
                                           new_inner_results),
                    new_inner_results)

            new_kernel_results = previous_kernel_results._replace(
                inner_results=new_inner_results,
                running_variance=new_variance_parts,
                step=step)

            return new_state, new_kernel_results
Пример #14
0
    def bootstrap_results(self, init_state=None, transformed_init_state=None):
        """Returns an object with the same type as returned by `one_step`.

        Unlike other `TransitionKernel`s,
        `TransformedTransitionKernel.bootstrap_results` can initialize the
        `TransformedTransitionKernelResults` in two ways: from an initial
        state, which is passed through `self._inverse_transform`, or directly
        from `transformed_init_state`, i.e., a `Tensor` or list of `Tensor`s
        interpreted as the already inverse-transformed state.

        Args:
          init_state: `Tensor` or Python `list` of `Tensor`s representing the
            state(s) of the Markov chain(s). Must specify `init_state` or
            `transformed_init_state` but not both.
          transformed_init_state: `Tensor` or Python `list` of `Tensor`s
            representing the transformed state(s) of the Markov chain(s). Must
            specify `init_state` or `transformed_init_state` but not both.

        Returns:
          kernel_results: A `TransformedTransitionKernelResults` holding the
            transformed state and the inner kernel's bootstrap results.

        Raises:
          ValueError: if both or neither of `init_state` and
            `transformed_init_state` are given.

        #### Examples

        To use `transformed_init_state` in context of
        `tfp.mcmc.sample_chain`, you need to explicitly pass the
        `previous_kernel_results`, e.g.,

        ```python
        transformed_kernel = tfp.mcmc.TransformedTransitionKernel(...)
        init_state = ...        # Doesn't matter.
        transformed_init_state = ... # Does matter.
        results, _ = tfp.mcmc.sample_chain(
            num_results=...,
            current_state=init_state,
            previous_kernel_results=transformed_kernel.bootstrap_results(
                transformed_init_state=transformed_init_state),
            kernel=transformed_kernel)
        ```
        """
        exactly_one_given = ((init_state is None) !=
                             (transformed_init_state is None))
        if not exactly_one_given:
            raise ValueError('Must specify exactly one of `init_state` '
                             'or `transformed_init_state`.')
        scope_name = mcmc_util.make_name(self.name, 'transformed_kernel',
                                         'bootstrap_results')
        with tf1.name_scope(name=scope_name,
                            values=[init_state, transformed_init_state]):
            if transformed_init_state is not None:
                # Already in the transformed space: only convert to tensors,
                # keeping the list / single-tensor structure.
                if mcmc_util.is_list_like(transformed_init_state):
                    transformed_init_state = [
                        tf.convert_to_tensor(value=part,
                                             name='transformed_init_state')
                        for part in transformed_init_state
                    ]
                else:
                    transformed_init_state = tf.convert_to_tensor(
                        value=transformed_init_state,
                        name='transformed_init_state')
            else:
                # Inverse-transform the state as a list of parts, then restore
                # the caller's structure.
                is_multipart = mcmc_util.is_list_like(init_state)
                init_state_parts = init_state if is_multipart else [init_state]
                transformed_parts = self._inverse_transform(init_state_parts)
                if is_multipart:
                    transformed_init_state = transformed_parts
                else:
                    transformed_init_state = transformed_parts[0]

            inner_results = self._inner_kernel.bootstrap_results(
                transformed_init_state)
            return TransformedTransitionKernelResults(
                transformed_state=transformed_init_state,
                inner_results=inner_results)
Пример #15
0
    def one_step(self, current_state, previous_kernel_results):
        """Proposes a new state with one Euler-Maruyama step.

        Args:
          current_state: `Tensor` or Python `list` of `Tensor`s representing
            the current state(s) of the Markov chain(s).
          previous_kernel_results: Results from the previous step. Its
            `target_log_prob`, `grads_target_log_prob`, `volatility`,
            `grads_volatility` and `diffusion_drift` fields are read.

        Returns:
          A two-element Python `list` `[next_state, kernel_results]`:
          `next_state` mirrors the structure of `current_state`, and
          `kernel_results` is an `UncalibratedLangevinKernelResults` evaluated
          at the proposed state.
        """
        with tf.compat.v1.name_scope(
                name=mcmc_util.make_name(self.name, 'mala', 'one_step'),
                values=[
                    self.step_size, current_state,
                    previous_kernel_results.target_log_prob,
                    previous_kernel_results.grads_target_log_prob,
                    previous_kernel_results.volatility,
                    previous_kernel_results.diffusion_drift
                ]):
            with tf.compat.v1.name_scope('initialize'):
                # Canonicalize the inputs consumed by `_euler_method`.
                (
                    current_state_parts,
                    step_size_parts,
                    current_target_log_prob,
                    _,  # grads_target_log_prob
                    current_volatility_parts,
                    _,  # grads_volatility
                    current_drift_parts,
                ) = _prepare_args(
                    self.target_log_prob_fn, self.volatility_fn, current_state,
                    self.step_size, previous_kernel_results.target_log_prob,
                    previous_kernel_results.grads_target_log_prob,
                    previous_kernel_results.volatility,
                    previous_kernel_results.grads_volatility,
                    previous_kernel_results.diffusion_drift,
                    self.parallel_iterations)

                # One normal draw per state part, each with a fresh seed.
                random_draw_parts = [
                    tf.random.normal(shape=tf.shape(input=part),
                                     dtype=part.dtype.base_dtype,
                                     seed=self._seed_stream())
                    for part in current_state_parts
                ]

            # Number of independent chains run by the algorithm.
            independent_chain_ndims = distribution_util.prefer_static_rank(
                current_target_log_prob)

            # Generate the next state of the algorithm using Euler-Maruyama
            # method.
            next_state_parts = _euler_method(random_draw_parts,
                                             current_state_parts,
                                             current_drift_parts,
                                             step_size_parts,
                                             current_volatility_parts)

            # Re-evaluate the target, volatility and drift at the proposal.
            # They feed both the acceptance correction below and the results
            # handed to the next call of `one_step`.
            (
                _,  # state_parts
                _,  # step_sizes
                next_target_log_prob,
                next_grads_target_log_prob,
                next_volatility_parts,
                next_grads_volatility,
                next_drift_parts,
            ) = _prepare_args(self.target_log_prob_fn,
                              self.volatility_fn,
                              next_state_parts,
                              step_size_parts,
                              parallel_iterations=self.parallel_iterations)

            def _match_input_structure(parts):
                # Lists are returned only when the caller passed a list.
                if mcmc_util.is_list_like(current_state):
                    return parts
                return parts[0]

            # Both candidate corrections are built up front; `tf.cond` then
            # picks one according to `self.compute_acceptance`.
            computed_correction = _compute_log_acceptance_correction(
                current_state_parts, next_state_parts,
                current_volatility_parts, next_volatility_parts,
                current_drift_parts, next_drift_parts, step_size_parts,
                independent_chain_ndims)
            zero_correction = tf.zeros_like(next_target_log_prob)

            def _use_computed_correction():
                return computed_correction

            def _use_zero_correction():
                return zero_correction

            log_acceptance_correction = tf.cond(
                pred=self.compute_acceptance,
                true_fn=_use_computed_correction,
                false_fn=_use_zero_correction)

            next_state = _match_input_structure(next_state_parts)
            kernel_results = UncalibratedLangevinKernelResults(
                log_acceptance_correction=log_acceptance_correction,
                target_log_prob=next_target_log_prob,
                grads_target_log_prob=next_grads_target_log_prob,
                volatility=_match_input_structure(next_volatility_parts),
                grads_volatility=next_grads_volatility,
                diffusion_drift=next_drift_parts)
            return [next_state, kernel_results]
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Proposes a new state using momenta from a momentum distribution.

        Args:
          current_state: `Tensor` or Python `list` of `Tensor`s representing
            the current state(s) of the Markov chain(s).
          previous_kernel_results: Results from the previous step. Its
            `target_log_prob` and `grads_target_log_prob` are read, as are
            `step_size`, `num_leapfrog_steps` and `momentum_distribution`
            when `self._store_parameters_in_results` is set.
          seed: Optional seed, passed through `samplers.sanitize_seed`.

        Returns:
          next_state: The integrated state, a list if `current_state` was
            list-like and a single element otherwise.
          kernel_results: `previous_kernel_results` with the log acceptance
            correction, target log-prob, gradients, momenta and seed replaced.
        """
        with tf.name_scope(mcmc_util.make_name(self.name, 'phmc', 'one_step')):
            # Parameters live either in the previous results or on the kernel.
            params_source = (previous_kernel_results
                             if self._store_parameters_in_results else self)
            step_size = params_source.step_size
            num_leapfrog_steps = params_source.num_leapfrog_steps
            momentum_distribution = params_source.momentum_distribution

            (
                current_state_parts,
                step_sizes,
                momentum_distribution,
                current_target_log_prob,
                current_target_log_prob_grad_parts,
            ) = _prepare_args(
                self.target_log_prob_fn,
                current_state,
                step_size,
                momentum_distribution,
                previous_kernel_results.target_log_prob,
                previous_kernel_results.grads_target_log_prob,
                maybe_expand=True,
                state_gradients_are_stopped=self.state_gradients_are_stopped)

            seed = samplers.sanitize_seed(seed)
            current_momentum_parts = list(
                momentum_distribution.sample(seed=seed))

            # Use `_log_prob_unnormalized` when the distribution provides it,
            # otherwise fall back to `log_prob`.
            momentum_log_prob = getattr(momentum_distribution,
                                        '_log_prob_unnormalized',
                                        momentum_distribution.log_prob)

            def kinetic_energy_fn(*args):
                return -momentum_log_prob(*args)

            # Let the integrator handle the case where no momentum distribution
            # is provided
            leapfrog_kinetic_energy_fn = (None
                                          if self.momentum_distribution is None
                                          else kinetic_energy_fn)

            integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
                self.target_log_prob_fn, step_sizes, num_leapfrog_steps)

            (
                next_momentum_parts,
                next_state_parts,
                next_target_log_prob,
                next_target_log_prob_grad_parts,
            ) = integrator(
                current_momentum_parts,
                current_state_parts,
                target=current_target_log_prob,
                target_grad_parts=current_target_log_prob_grad_parts,
                kinetic_energy_fn=leapfrog_kinetic_energy_fn)
            if self.state_gradients_are_stopped:
                next_state_parts = [
                    tf.stop_gradient(part) for part in next_state_parts
                ]

            log_acceptance_correction = _compute_log_acceptance_correction(
                kinetic_energy_fn, current_momentum_parts,
                next_momentum_parts)

            new_kernel_results = previous_kernel_results._replace(
                log_acceptance_correction=log_acceptance_correction,
                target_log_prob=next_target_log_prob,
                grads_target_log_prob=next_target_log_prob_grad_parts,
                initial_momentum=current_momentum_parts,
                final_momentum=next_momentum_parts,
                seed=seed,
            )

            # Mirror the structure of the caller's state.
            if mcmc_util.is_list_like(current_state):
                next_state = next_state_parts
            else:
                next_state = next_state_parts[0]
            return next_state, new_kernel_results
Пример #17
0
    def one_step(self, current_state, previous_kernel_results):
        """Runs one HMC proposal on a flattened, per-element view of the state.

        The state parts are flattened and concatenated into a single vector,
        split into one tensor per element, integrated by
        `self.run_integrator`, and finally reshaped back to the original part
        shapes.

        Side effect: `self.restoreShapes` is overwritten with one
        `[shape, num_elements]` entry per state part.

        Args:
          current_state: `Tensor` or Python `list` of `Tensor`s representing
            the current state(s) of the Markov chain(s).
          previous_kernel_results: Results from the previous step. Its
            `target_log_prob` is read, as are `step_size` and
            `num_leapfrog_steps` when `self._store_parameters_in_results` is
            set.

        Returns:
          next_state_parts: Python `list` with one `Tensor` per original
            state part, reshaped to that part's shape.
          kernel_results: `previous_kernel_results` with
            `log_acceptance_correction` and `target_log_prob` replaced.
        """
        with tf.compat.v2.name_scope(
                mcmc_util.make_name(self.name, 'hmc', 'one_step')):
            # Parameters live either in the previous results or on the kernel.
            params_source = (previous_kernel_results
                             if self._store_parameters_in_results else self)
            step_size = params_source.step_size
            num_leapfrog_steps = params_source.num_leapfrog_steps

            (
                current_state_parts,
                step_sizes,
                current_target_log_prob,
            ) = _prepare_args(
                self.target_log_prob_fn,
                current_state,
                step_size,
                previous_kernel_results.target_log_prob,
                maybe_expand=True,
                state_gradients_are_stopped=self.state_gradients_are_stopped)

            # Remember each part's shape and element count so the flat result
            # can be cut up and reshaped at the end.
            # NOTE(review): iterating `part.shape` presumably needs fully
            # known static dimensions; confirm against callers.
            self.restoreShapes = []
            for part in current_state_parts:
                part_shape = part.shape
                num_elements = 1
                for dim in part_shape:
                    num_elements *= dim
                self.restoreShapes.append([part_shape, num_elements])

            # Flatten every part, join them into one vector, then split that
            # vector into a list with one tensor per element.
            flat_state = tf.concat(
                [tf.reshape(part, [-1]) for part in current_state_parts], -1)
            current_state_parts = [
                flat_state[i] for i in range(flat_state.shape[0])
            ]

            # One normal momentum draw per element, each with a fresh seed.
            current_momentum_parts = [
                tf.random.normal(shape=tf.shape(input=element),
                                 dtype=(self._momentum_dtype
                                        or element.dtype.base_dtype),
                                 seed=self._seed_stream())
                for element in current_state_parts
            ]

            (
                next_state_parts,
                initial_kinetic,
                final_kinetic,
                final_target_log_prob,
            ) = self.run_integrator(step_sizes, num_leapfrog_steps,
                                    current_momentum_parts,
                                    current_state_parts)

            if self.state_gradients_are_stopped:
                next_state_parts = [
                    tf.stop_gradient(part) for part in next_state_parts
                ]

            independent_chain_ndims = distribution_util.prefer_static_rank(
                current_target_log_prob)

            # NOTE(review): for a non-list `current_state` only element 0 is
            # kept here, before the slicing below, and the method returns a
            # list either way. Confirm this is intended.
            if mcmc_util.is_list_like(current_state):
                flat_next_state = next_state_parts
            else:
                flat_next_state = next_state_parts[0]

            new_kernel_results = previous_kernel_results._replace(
                log_acceptance_correction=_compute_log_acceptance_correction(
                    initial_kinetic, final_kinetic, independent_chain_ndims),
                target_log_prob=final_target_log_prob)

            # Cut the flat result back into the original part shapes.
            next_state_parts = []
            offset = 0
            for part_shape, num_elements in self.restoreShapes:
                next_state_parts.append(
                    tf.reshape(flat_next_state[offset:offset + num_elements],
                               part_shape))
                offset += num_elements

            return next_state_parts, new_kernel_results
  def __init__(
      self,
      inner_kernel,
      num_adaptation_steps,
      use_halton_sequence_jitter=True,
      adaptation_rate=0.025,
      jitter_amount=1.,
      criterion_fn=chees_criterion,
      max_leapfrog_steps=1000,
      num_leapfrog_steps_getter_fn=hmc_like_num_leapfrog_steps_getter_fn,
      num_leapfrog_steps_setter_fn=hmc_like_num_leapfrog_steps_setter_fn,
      step_size_getter_fn=hmc_like_step_size_getter_fn,
      proposed_velocity_getter_fn=hmc_like_proposed_velocity_getter_fn,
      log_accept_prob_getter_fn=hmc_like_log_accept_prob_getter_fn,
      proposed_state_getter_fn=hmc_like_proposed_state_getter_fn,
      validate_args=False,
      name=None):
    """Creates the trajectory length adaptation kernel.

    The default setter_fn and the getter_fn callbacks assume that the inner
    kernel produces kernel results structurally the same as the
    `HamiltonianMonteCarlo` kernel (possibly wrapped in some step size
    adaptation kernel).

    Args:
      inner_kernel: `TransitionKernel`-like object.
      num_adaptation_steps: Scalar `int` `Tensor` number of initial steps to
        during which to adjust the step size. This may be greater, less than, or
        equal to the number of burnin steps.
      use_halton_sequence_jitter: Python bool. Whether to use a Halton sequence
        for jittering the trajectory length. This makes the procedure more
        stable than sampling trajectory lengths from a uniform distribution.
      adaptation_rate: Floating point scalar `Tensor`. How rapidly to adapt the
        trajectory length.
      jitter_amount: Floating point scalar `Tensor`. How much to jitter the
        trajectory on the next step. The trajectory length is sampled from `[(1
        - jitter_amount) * max_trajectory_length, max_trajectory_length]`.
      criterion_fn: Callable with `(previous_state, proposed_state, accept_prob)
        -> criterion`. Computes the criterion value.
      max_leapfrog_steps: Int32 scalar `Tensor`. Clips the number of leapfrog
        steps to this value.
      num_leapfrog_steps_getter_fn: A callable with the signature
        `(kernel_results) -> num_leapfrog_steps` where `kernel_results` are the
        results of the `inner_kernel`, and `num_leapfrog_steps` is a floating
        point `Tensor`.
      num_leapfrog_steps_setter_fn: A callable with the signature
        `(kernel_results, new_num_leapfrog_steps) -> new_kernel_results` where
        `kernel_results` are the results of the `inner_kernel`,
        `new_num_leapfrog_steps` is a scalar tensor `Tensor`, and
        `new_kernel_results` are a copy of `kernel_results` with the number of
        leapfrog steps set.
      step_size_getter_fn: A callable with the signature `(kernel_results) ->
        step_size` where `kernel_results` are the results of the `inner_kernel`,
        and `step_size` is a floating point `Tensor`.
      proposed_velocity_getter_fn: A callable with the signature
        `(kernel_results) -> proposed_velocity` where `kernel_results` are the
        results of the `inner_kernel`, and `proposed_velocity` is a (possibly
        nested) floating point `Tensor`. Velocity is derivative of state with
        respect to trajectory length.
      log_accept_prob_getter_fn: A callable with the signature `(kernel_results)
        -> log_accept_prob` where `kernel_results` are the results of the
        `inner_kernel`, and `log_accept_prob` is a floating point `Tensor`.
        `log_accept_prob` has shape `[C0, ...., Cb]` with `b > 0`.
      proposed_state_getter_fn: A callable with the signature `(kernel_results)
        -> proposed_state` where `kernel_results` are the results of the
        `inner_kernel`, and `proposed_state` is a (possibly nested) floating
        point `Tensor`.
      validate_args: Python `bool`. When `True` kernel parameters are checked
        for validity. When `False` invalid inputs may silently render incorrect
        outputs.
      name: Python `str` name prefixed to Ops created by this class. Default:
        'simple_step_size_adaptation'.

    Raises:
      ValueError: If `inner_kernel` contains a `TransformedTransitionKernel` in
        its hierarchy. If you need to use the `TransformedTransitionKernel`,
        place it above this kernel in the hierarchy (see the example in the
        class docstring).
    """
    inner_kernel = mcmc_util.enable_store_parameters_in_results(inner_kernel)
    _forbid_inner_transformed_kernel(inner_kernel)

    with tf.name_scope(
        mcmc_util.make_name(name, 'gradient_based_trajectory_length_adaptation',
                            '__init__')) as name:
      dtype = dtype_util.common_dtype([adaptation_rate, jitter_amount],
                                      tf.float32)
      num_adaptation_steps = tf.convert_to_tensor(
          num_adaptation_steps, dtype=tf.int32, name='num_adaptation_steps')
      adaptation_rate = tf.convert_to_tensor(
          adaptation_rate, dtype=dtype, name='adaptation_rate')
      jitter_amount = tf.convert_to_tensor(
          jitter_amount, dtype=dtype, name='jitter_amount')
      max_leapfrog_steps = tf.convert_to_tensor(
          max_leapfrog_steps, dtype=tf.int32, name='max_leapfrog_steps')

    self._parameters = dict(
        inner_kernel=inner_kernel,
        num_adaptation_steps=num_adaptation_steps,
        use_halton_sequence_jitter=use_halton_sequence_jitter,
        adaptation_rate=adaptation_rate,
        jitter_amount=jitter_amount,
        criterion_fn=criterion_fn,
        max_leapfrog_steps=max_leapfrog_steps,
        num_leapfrog_steps_getter_fn=num_leapfrog_steps_getter_fn,
        num_leapfrog_steps_setter_fn=num_leapfrog_steps_setter_fn,
        step_size_getter_fn=step_size_getter_fn,
        proposed_velocity_getter_fn=proposed_velocity_getter_fn,
        log_accept_prob_getter_fn=log_accept_prob_getter_fn,
        proposed_state_getter_fn=hmc_like_proposed_state_getter_fn,
        validate_args=validate_args,
        name=name,
    )
Пример #19
0
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Runs one Euler-Maruyama proposal step of the Langevin kernel.

        Args:
          current_state: `Tensor` or Python `list` of `Tensor`s holding the
            current state(s) of the chain(s).
          previous_kernel_results: Results of the previous call; the fields
            `target_log_prob`, `grads_target_log_prob`, `volatility`,
            `grads_volatility` and `diffusion_drift` are read from it.
          seed: Optional seed. It is sanitized, used to derive one seed per
            state part, and stored in the returned kernel results.

        Returns:
          A two-element list `[next_state, kernel_results]`. `next_state` is a
          list when `current_state` is list-like and a single element
          otherwise; `kernel_results` is an
          `UncalibratedLangevinKernelResults`.
        """
        scope_name = mcmc_util.make_name(self.name, 'mala', 'one_step')
        with tf.name_scope(scope_name):
            with tf.name_scope('initialize'):
                # Canonicalize everything `_euler_method` consumes. The two
                # gradient outputs are not needed to build the proposal.
                (
                    state_parts,
                    step_sizes,
                    target_log_prob,
                    _,  # grads_target_log_prob
                    volatility_parts,
                    _,  # grads_volatility
                    drift_parts,
                ) = _prepare_args(
                    self.target_log_prob_fn,
                    self.volatility_fn,
                    current_state,
                    self.step_size,
                    previous_kernel_results.target_log_prob,
                    previous_kernel_results.grads_target_log_prob,
                    previous_kernel_results.volatility,
                    previous_kernel_results.grads_volatility,
                    previous_kernel_results.diffusion_drift,
                    self.parallel_iterations)

                # The sanitized seed is kept so it can be reported in the
                # kernel results for diagnostics.
                seed = samplers.sanitize_seed(seed)
                part_seeds = distribute_lib.fold_in_axis_index(
                    samplers.split_seed(seed,
                                        n=len(state_parts),
                                        salt='langevin.one_step'),
                    self.experimental_shard_axis_names)

                # One standard-normal draw per state part, shaped like it.
                noise_parts = [
                    samplers.normal(
                        shape=ps.shape(part),
                        dtype=dtype_util.base_dtype(part.dtype),
                        seed=part_seed)
                    for part, part_seed in zip(state_parts, part_seeds)
                ]

            # Rank of the target log prob, i.e. the number of independent
            # chain dimensions.
            chain_ndims = ps.rank(target_log_prob)

            # Propose the next state with the Euler-Maruyama method.
            proposed_parts = _euler_method(noise_parts,
                                           state_parts,
                                           drift_parts,
                                           step_sizes,
                                           volatility_parts)

            # Evaluate the quantities at the proposal; they feed both the
            # acceptance correction below and the next call to `one_step`.
            (
                _,  # state_parts
                _,  # step_sizes
                proposed_target_log_prob,
                proposed_grads_target_log_prob,
                proposed_volatility_parts,
                proposed_grads_volatility,
                proposed_drift_parts,
            ) = _prepare_args(
                self.target_log_prob_fn,
                self.volatility_fn,
                proposed_parts,
                step_sizes,
                parallel_iterations=self.parallel_iterations)

            def _match_input_structure(parts):
                # Hand back a list only if the caller passed a list-like state.
                if mcmc_util.is_list_like(current_state):
                    return parts
                return parts[0]

            # Both candidate corrections are built up front; `tf.cond` then
            # only selects between the two already-constructed tensors.
            correction_if_computed = _compute_log_acceptance_correction(
                state_parts,
                proposed_parts,
                volatility_parts,
                proposed_volatility_parts,
                drift_parts,
                proposed_drift_parts,
                step_sizes,
                chain_ndims,
                experimental_shard_axis_names=(
                    self.experimental_shard_axis_names))
            correction_if_skipped = tf.zeros_like(proposed_target_log_prob)

            log_acceptance_correction = tf.cond(
                pred=self.compute_acceptance,
                true_fn=lambda: correction_if_computed,
                false_fn=lambda: correction_if_skipped)

            next_state = _match_input_structure(proposed_parts)
            kernel_results = UncalibratedLangevinKernelResults(
                log_acceptance_correction=log_acceptance_correction,
                target_log_prob=proposed_target_log_prob,
                grads_target_log_prob=proposed_grads_target_log_prob,
                volatility=_match_input_structure(proposed_volatility_parts),
                grads_volatility=proposed_grads_volatility,
                diffusion_drift=proposed_drift_parts,
                seed=seed,
            )
            return [next_state, kernel_results]
  def one_step(self, current_state, previous_kernel_results, seed=None):
    """Takes one inner-kernel step with a jittered trajectory length.

    The number of leapfrog steps is derived from the (jittered) maximum
    trajectory length and the scalar step size, written into the kernel
    results, and the inner kernel is then advanced. While
    `step < num_adaptation_steps` the results updated by
    `_update_trajectory_grad` are kept; afterwards the pre-update results are
    kept, except for `criterion`, which is always taken from the update.

    Args:
      current_state: (Possibly nested) state; every leaf is converted to a
        `Tensor` with the dtype of `previous_kernel_results.adaptation_rate`.
      previous_kernel_results: Kernel results from the previous step.
      seed: Optional seed; it is split into one seed for the jitter draw and
        one for the inner kernel.

    Returns:
      new_state: The state returned by the inner kernel.
      new_kernel_results: Updated kernel results, with `step` incremented by
        one and `inner_results` set to the inner kernel's new results.
    """
    scope_name = mcmc_util.make_name(
        self.name, 'gradient_based_trajectory_length_adaptation', 'one_step')
    with tf.name_scope(scope_name):
      jitter_seed, inner_seed = samplers.split_seed(seed)

      pkr = previous_kernel_results
      dtype = pkr.adaptation_rate.dtype

      def _to_tensor(part):
        return tf.convert_to_tensor(part, dtype=dtype)

      current_state = tf.nest.map_structure(_to_tensor, current_state)
      step_f = tf.cast(pkr.step, dtype)

      # Base jitter: either the Halton sequence evaluated at the step
      # counter, or a scalar uniform draw.
      unit_jitter = (
          _halton_sequence(step_f) if self.use_halton_sequence_jitter else
          samplers.uniform((), seed=jitter_seed, dtype=dtype))

      jitter_amount = pkr.jitter_amount
      trajectory_jitter = unit_jitter * jitter_amount + (1. - jitter_amount)

      # During adaptation use the live trajectory length; afterwards use the
      # averaged one.
      is_adapting = pkr.step < self.num_adaptation_steps
      max_trajectory_length = tf.where(
          is_adapting,
          pkr.max_trajectory_length,
          pkr.averaged_max_trajectory_length)
      jittered_length = max_trajectory_length * trajectory_jitter

      step_size = _ensure_step_size_is_scalar(
          self.step_size_getter_fn(pkr), self.validate_args)

      # Always take at least one leapfrog step.
      steps_as_float = tf.maximum(
          tf.ones([], dtype), tf.math.ceil(jittered_length / step_size))
      num_leapfrog_steps = tf.cast(steps_as_float, tf.int32)

      jittered_results = self.num_leapfrog_steps_setter_fn(
          pkr, num_leapfrog_steps)

      new_state, new_inner_results = self.inner_kernel.one_step(
          current_state, jittered_results.inner_results, inner_seed)

      proposed_state = self.proposed_state_getter_fn(new_inner_results)
      proposed_velocity = self.proposed_velocity_getter_fn(new_inner_results)
      accept_prob = tf.exp(self.log_accept_prob_getter_fn(new_inner_results))

      adapted_results = _update_trajectory_grad(
          jittered_results,
          previous_state=current_state,
          proposed_state=proposed_state,
          proposed_velocity=proposed_velocity,
          trajectory_jitter=trajectory_jitter,
          accept_prob=accept_prob,
          step_size=step_size,
          criterion_fn=self.criterion_fn,
          max_leapfrog_steps=self.max_leapfrog_steps)

      # Outside the adaptation phase the update is discarded. The criterion
      # is kept regardless because it is a diagnostic, and the leapfrog-step
      # setting survives because it lives in `jittered_results` as well.
      criterion = adapted_results.criterion
      selected_results = mcmc_util.choose(
          is_adapting, adapted_results, jittered_results)

      new_kernel_results = selected_results._replace(
          inner_results=new_inner_results,
          step=pkr.step + 1,
          criterion=criterion)

      return new_state, new_kernel_results
Пример #21
0
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Runs one adaptive random-walk Metropolis step.

        A Metropolis-Hastings kernel wrapping an `UncalibratedRandomWalk` is
        rebuilt on every call from the current scaled covariance and stored
        in `self._impl`.

        Args:
          current_state: `Tensor` or Python `list` of `Tensor`s with the
            current chain state(s).
          previous_kernel_results: Results from the previous step; the extra
            adaptation fields are read through `self.extra_getter_fn`.
          seed: Accepted for interface compatibility but not read by this
            method; the adaptive indicator is sampled with `self.seed`.

        Returns:
          A two-element list `[new_state, new_kernel_results]`, where the
          results have been passed through `self.extra_setter_fn`.
        """
        scope_name = mcmc_util.make_name(
            self.name, "AdaptiveRandomWalkMetropolis", "one_step")
        with tf.name_scope(scope_name):
            with tf.name_scope("initialize"):
                raw_parts = (
                    list(current_state)
                    if mcmc_util.is_list_like(current_state)
                    else [current_state])
                current_state_parts = [
                    tf.convert_to_tensor(part, name="current_state")
                    for part in raw_parts
                ]

            # 'covariance_scaling' and 'accum_covar' change on every step,
            # whereas 'covariance' only changes once
            # 'num_steps' >= 'covariance_burnin'.
            num_steps = self.extra_getter_fn(previous_kernel_results).num_steps
            previous_is_adaptive = self.extra_getter_fn(
                previous_kernel_results).is_adaptive

            # Select between the old and the updated scaling with a gather
            # (rather than a cond) indexed by the previous adaptive flag.
            previous_scaling = self.extra_getter_fn(
                previous_kernel_results).covariance_scaling
            updated_scaling = self.update_covariance_scaling(
                previous_kernel_results, num_steps)
            scaling_candidates = tf.stack(
                [previous_scaling, updated_scaling], axis=-1)
            current_covariance_scaling = tf.gather(
                scaling_candidates,
                previous_is_adaptive,
                batch_dims=1,
                axis=1,
            )

            previous_accum_covar = self.extra_getter_fn(
                previous_kernel_results).running_covariance
            current_accum_covar = self.running_covar.update(
                state=previous_accum_covar, new_sample=current_state_parts)

            # Index 0 keeps the previous covariance, index 1 takes the
            # freshly finalized running estimate.
            previous_covariance = self.extra_getter_fn(
                previous_kernel_results).covariance
            finalized_covariance = self.running_covar.finalize(
                current_accum_covar, ddof=1)
            past_burnin_index = tf.cast(
                num_steps >= self.covariance_burnin,
                dtype=tf.dtypes.int32,
            )
            current_covariance = tf.gather(
                [previous_covariance, finalized_covariance],
                past_burnin_index,
            )

            expanded_scaling = tf.expand_dims(
                current_covariance_scaling, axis=1)
            stacked_covariance = tf.stack([current_covariance])
            current_scaled_covariance = tf.squeeze(
                expanded_scaling * stacked_covariance,
                axis=0,
            )

            current_is_adaptive = self.u.sample(seed=self.seed)

            proposal_fn = random_walk_mvnorm_fn(
                covariance=current_scaled_covariance,
                pu=self.pu,
                fixed_variance=self.fixed_variance,
                is_adaptive=current_is_adaptive,
                name=self.name,
            )
            uncalibrated_kernel = random_walk_metropolis.UncalibratedRandomWalk(
                target_log_prob_fn=self.target_log_prob_fn,
                new_state_fn=proposal_fn,
                name=self.name,
            )
            self._impl = metropolis_hastings.MetropolisHastings(
                inner_kernel=uncalibrated_kernel,
                name=self.name,
            )

            new_state, new_inner_results = self._impl.one_step(
                current_state, previous_kernel_results)

            # Attach the updated adaptation bookkeeping to the results.
            squeezed_scaling = tf.squeeze(current_covariance_scaling, axis=1)
            new_inner_results = self.extra_setter_fn(
                new_inner_results,
                num_steps + 1,
                squeezed_scaling,
                current_covariance,
                current_accum_covar,
                current_is_adaptive,
            )
            return [new_state, new_inner_results]
Пример #22
0
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Takes one step of the TransitionKernel.

        Every replica is first advanced by the inner kernel built by
        `self.make_kernel_fn` for the tempered target. A permutation of the
        replicas obtained from `self.swap_proposal_fn` is then accepted or
        rejected elementwise, and the accepted swaps are applied to the
        replica states and results.

        Args:
          current_state: `Tensor` or Python `list` of `Tensor`s representing
            the current state(s) of the Markov chain(s).
          previous_kernel_results: A (possibly nested) `tuple`, `namedtuple`
            or `list` of `Tensor`s representing internal calculations made
            within the previous call to this function (or as returned by
            `bootstrap_results`).
          seed: Optional, a seed for reproducible sampling. When given, it is
            sanitized and split into seeds for the inner kernel, the swap
            proposal and the log-uniform draws. When omitted, the swap and
            log-uniform seeds come from `self._seed_stream` and the inner
            kernel is called without a seed.

        Returns:
          next_state: `Tensor` or Python `list` of `Tensor`s representing the
            next state(s) of the Markov chain(s).
          kernel_results: A `ReplicaExchangeMCKernelResults` holding internal
            calculations made within this function. This includes replica
            states.
        """

        # The code below propagates one step states of shape
        #  [n_replica] + batch_shape + event_shape.
        #
        # The step is done in three parts:
        #  1) Call one_step to transition states via a tempered version of
        #     self.target_log_prob_fn (see _replica_target_log_prob).
        #  2) Permute values in states
        #  3) Update state-dependent values, such as log_probs.
        #
        # We chose to swap states, rather than temperatures, because...
        # (i)  If swapping temperatures, you *still* have to swap log_probs to
        #      determine acceptance, as well as states (for kernel results).
        #      So it's just as difficult to swap temperatures.
        # (ii) If swapping temperatures, you have to take care to swap any user-
        #      supplied temperature related things (like step size).
        #      A-priori, we don't know what else will need to be swapped!
        # (iii)In both cases, the kernel results need to be updated in a non-trivial
        #      manner....so we either special-case, or use bootstrap.

        with tf.name_scope(mcmc_util.make_name(self.name, 'remc', 'one_step')):
            # Force a read in case the `inverse_temperatures` is a `tf.Variable`.
            inverse_temperatures = tf.convert_to_tensor(
                previous_kernel_results.inverse_temperatures,
                name='inverse_temperatures')

            target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
                self.target_log_prob_fn, inverse_temperatures)
            # Seed handling complexity is due to users possibly expecting an old-style
            # stateful seed to be passed to `self.make_kernel_fn`, and no seed
            # expected by `kernel.one_step`.
            # In other words:
            # - We try `make_kernel_fn` without a seed first; this is the future. The
            #   kernel will receive a seed later, as part of `one_step`.
            # - If the user code doesn't like that (Python complains about a missing
            #   required argument), we warn and fall back to the previous behavior.
            try:
                inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                    target_log_prob_for_inner_kernel)
            except TypeError as e:
                # Only a `TypeError` whose message mentions 'argument' is treated
                # as a signature mismatch; any other `TypeError` is re-raised.
                if 'argument' not in str(e):
                    raise
                warnings.warn(
                    'The `seed` argument to `ReplicaExchangeMC`s `make_kernel_fn` is '
                    'deprecated. `TransitionKernel` instances now receive seeds via '
                    '`one_step`.')
                inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                    target_log_prob_for_inner_kernel, self._seed_stream())

            # Now that we've constructed the TransitionKernel instance:
            # - If we were given a seed, we sanitize it to stateless and pass along
            #   to `kernel.one_step`. If it doesn't like that, we crash and propagate
            #   the error.  Rationale: The contract is stateless sampling given
            #   seed, and doing otherwise would not meet it.
            # - If not given a seed, we don't pass one along. This avoids breaking
            #   underlying kernels lacking a `seed` arg on `one_step`.
            # TODO(b/159636942): Clean up after 2020-09-20.
            if seed is not None:
                seed = samplers.sanitize_seed(seed)
                inner_seed, swap_seed, logu_seed = samplers.split_seed(
                    seed, n=3, salt='remc_one_step')
                inner_kwargs = dict(seed=inner_seed)
            else:
                if self._seed_stream.original_seed is not None:
                    warnings.warn(mcmc_util.SEED_CTOR_ARG_DEPRECATION_MSG)
                inner_kwargs = {}
                swap_seed, logu_seed = samplers.split_seed(self._seed_stream())
            # Advance all replicas, starting from the post-swap states and
            # results recorded by the previous step.
            [
                pre_swap_replica_states,
                pre_swap_replica_results,
            ] = inner_kernel.one_step(
                previous_kernel_results.post_swap_replica_states,
                previous_kernel_results.post_swap_replica_results,
                **inner_kwargs)

            pre_swap_replica_target_log_prob = _get_field(
                # These are tempered log probs (have been divided by temperature).
                pre_swap_replica_results,
                'target_log_prob')

            dtype = pre_swap_replica_target_log_prob.dtype
            replica_and_batch_shape = ps.shape(
                pre_swap_replica_target_log_prob)
            # Everything after the leading (replica) dimension is batch shape.
            batch_shape = replica_and_batch_shape[1:]
            replica_and_batch_rank = ps.rank(pre_swap_replica_target_log_prob)
            num_replica = ps.size0(inverse_temperatures)

            inverse_temperatures = mcmc_util.left_justified_broadcast_to(
                inverse_temperatures, replica_and_batch_shape)

            # Now that each replica has done one_step, it is time to consider swaps.

            # swap.shape = [n_replica], and is a "once only" permutation, meaning it
            # is achievable by a sequence of pairwise permutations, where each element
            # is moved at most once.
            # E.g. if swaps = [1, 0, 2], we will consider swapping temperatures 0 and
            # 1, keeping 2 fixed.  This exact same swap is considered for *every*
            # batch member.  Of course some batch members may accept and some reject.
            try:
                swaps = tf.cast(
                    self.swap_proposal_fn(  # pylint: disable=not-callable
                        num_replica,
                        batch_shape=batch_shape,
                        seed=swap_seed,
                        step_count=previous_kernel_results.step_count),
                    dtype=tf.int32)
            except TypeError as e:
                # Fall back only when the error message mentions `step_count`;
                # any other `TypeError` is re-raised unchanged.
                if 'step_count' not in str(e):
                    raise
                warnings.warn(
                    'The `swap_proposal_fn` given to ReplicaExchangeMC did not accept '
                    'the `step_count` argument. Falling back to omitting the '
                    'argument. This fallback will be removed after 24-Oct-2020.'
                )
                swaps = tf.cast(
                    self.swap_proposal_fn(  # pylint: disable=not-callable
                        num_replica,
                        batch_shape=batch_shape,
                        seed=swap_seed),
                    dtype=tf.int32)

            # The identity permutation `range(num_replica)`, given the same
            # number of dimensions as `swaps` so the two can be combined.
            null_swaps = mcmc_util.left_justified_expand_dims_like(
                tf.range(num_replica, dtype=swaps.dtype), swaps)
            swaps = _maybe_embed_swaps_validation(swaps, null_swaps,
                                                  self.validate_args)

            # Un-temper the log probs.  E.g., for replica k, at point x_k, this is
            # Log[p(x_k)], and *not* Log[p_x(x_k)] = Log[p(x_k)] * beta_k.
            untempered_pre_swap_replica_target_log_prob = (
                pre_swap_replica_target_log_prob / inverse_temperatures)

            # Since `swaps` is its own inverse permutation we automatically know the
            # swap counterpart: range(num_replica). We use this idea to compute the
            # acceptance in a vectorized manner at the cost of wasting roughly half
            # our computation. Although we could use `unique` to solve this problem,
            # we expect the cost of `unique` to be higher than the dozens of wasted
            # arithmetic calculations. Worse, it'd mean we need dynamic sized Tensors
            # (eg, using `tf.where(bool)`) and so we wouldn't be able to XLA compile.

            # Note: diffs would normally be "proposed - current" however energy is
            # flipped since `energy == -log_prob`.
            energy_diff = (untempered_pre_swap_replica_target_log_prob -
                           mcmc_util.index_remapping_gather(
                               untempered_pre_swap_replica_target_log_prob,
                               swaps,
                               name='gather_swap_tlp'))
            swapped_inverse_temperatures = mcmc_util.index_remapping_gather(
                inverse_temperatures, swaps, name='gather_swap_temps')
            inverse_temp_diff = swapped_inverse_temperatures - inverse_temperatures

            # If i and j are swapping, log_accept_ratio[] i and j are equal.
            log_accept_ratio = (energy_diff *
                                mcmc_util.left_justified_expand_dims_to(
                                    inverse_temp_diff, replica_and_batch_rank))

            # Any non-finite ratio (NaN or +/-inf) is replaced by -inf, so the
            # strict `less` comparison below never accepts such a swap.
            log_accept_ratio = tf.where(tf.math.is_finite(log_accept_ratio),
                                        log_accept_ratio,
                                        tf.constant(-np.inf, dtype=dtype))

            # Produce Log[Uniform] draws that are identical at swapped indices.
            log_uniform = tf.math.log(
                samplers.uniform(shape=replica_and_batch_shape,
                                 dtype=dtype,
                                 seed=logu_seed))
            # Both members of a swapping pair read the draw stored at the
            # smaller of their two indices.
            anchor_swaps = tf.minimum(swaps, null_swaps)
            log_uniform = mcmc_util.index_remapping_gather(
                log_uniform, anchor_swaps)

            is_swap_accepted_mask = tf.less(log_uniform,
                                            log_accept_ratio,
                                            name='is_swap_accepted_mask')

            def _swap_tensor(x):
                # Presumably yields the permuted `x` where the swap was
                # accepted and the untouched `x` elsewhere; relies on the
                # semantics of `mcmc_util.choose` -- TODO confirm.
                return mcmc_util.choose(
                    is_swap_accepted_mask,
                    mcmc_util.index_remapping_gather(x, swaps), x)

            post_swap_replica_states = [
                _swap_tensor(s) for s in pre_swap_replica_states
            ]

            expanded_null_swaps = mcmc_util.left_justified_broadcast_to(
                null_swaps, replica_and_batch_shape)
            is_swap_proposed = _compute_swap_notmatrix(
                # Broadcast both so they have shape [num_replica] + batch_shape.
                # This (i) makes them have same shape as is_swap_accepted, and
                # (ii) keeps shape consistent if someday swaps has a batch shape.
                expanded_null_swaps,
                mcmc_util.left_justified_broadcast_to(swaps,
                                                      replica_and_batch_shape))

            # To get is_swap_accepted in ordered position, we use
            # _compute_swap_notmatrix on current and next replica positions.
            post_swap_replica_position = _swap_tensor(expanded_null_swaps)

            is_swap_accepted = _compute_swap_notmatrix(
                post_swap_replica_position, expanded_null_swaps)

            if self._state_includes_replicas:
                post_swap_states = post_swap_replica_states
            else:
                # The caller's state carries no replica axis, so only index 0
                # along the leading dimension of each state part is returned.
                post_swap_states = [s[0] for s in post_swap_replica_states]

            post_swap_replica_results = _make_post_swap_replica_results(
                pre_swap_replica_results, inverse_temperatures,
                swapped_inverse_temperatures, is_swap_accepted_mask,
                _swap_tensor)

            if mcmc_util.is_list_like(current_state):
                # We *always* canonicalize the states in the kernel results.
                states = post_swap_states
            else:
                # Single-part input: hand back the bare element, not a list.
                states = post_swap_states[0]

            post_swap_kernel_results = ReplicaExchangeMCKernelResults(
                post_swap_replica_states=post_swap_replica_states,
                pre_swap_replica_results=pre_swap_replica_results,
                post_swap_replica_results=post_swap_replica_results,
                is_swap_proposed=is_swap_proposed,
                is_swap_accepted=is_swap_accepted,
                is_swap_proposed_adjacent=_sub_diag(is_swap_proposed),
                is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
                # Store the original pkr.inverse_temperatures in case its a
                # `tf.Variable`.
                inverse_temperatures=previous_kernel_results.
                inverse_temperatures,
                swaps=swaps,
                step_count=previous_kernel_results.step_count + 1,
                # Record the sanitized seed, or a zeros seed when the caller
                # supplied none.
                seed=samplers.zeros_seed() if seed is None else seed,
            )

            return states, post_swap_kernel_results
Пример #23
0
  def bootstrap_results(self, init_state):
    """Returns an object with the same type as returned by `one_step`.

    The initial state is replicated once per inverse temperature (unless the
    state already carries a replica axis), the inner kernel is bootstrapped on
    the replicated state, and the swap bookkeeping is initialized as if an
    identity ("null") swap had been proposed and accepted.

    Args:
      init_state: `Tensor` or Python `list` of `Tensor`s representing the
        initial state(s) of the Markov chain(s).

    Returns:
      kernel_results: A `ReplicaExchangeMCKernelResults` holding internal
        calculations made within this function. This includes replica states.

    Raises:
      ValueError: If the state is declared to include replicas and the
        statically known leading dimension of the first state part differs
        from the statically known number of inverse temperatures.
      TypeError: If `make_kernel_fn` cannot be called with the tempered target
        log prob function as its only argument.
    """
    with tf.name_scope(mcmc_util.make_name(
        self.name, 'remc', 'bootstrap_results')):
      init_state, unused_is_multipart_state = mcmc_util.prepare_state_parts(
          init_state)

      inverse_temperatures = tf.convert_to_tensor(
          self.inverse_temperatures,
          name='inverse_temperatures')

      if self._state_includes_replicas:
        # Only static shapes are compared; unknown (None) dimensions are
        # not validated here.
        it_n_replica = inverse_temperatures.shape[0]
        state_n_replica = init_state[0].shape[0]
        if ((it_n_replica is not None) and (state_n_replica is not None) and
            (it_n_replica != state_n_replica)):
          # The placeholders are filled in the order the message names them:
          # first the state's replica count, then the temperatures'.
          raise ValueError(
              'Number of replicas implied by initial state ({}) must equal '
              'number of replicas implied by inverse_temperatures ({}), but '
              'did not'.format(state_n_replica, it_n_replica))

      # We will now replicate each of a possible batch of initial states, one
      # for each inverse_temperature. So if init_state=[x, y] of shapes
      # [Sx, Sy] then the new shape is [(T, Sx), (T, Sy)] where (a, b) means
      # concatenation and T=shape(inverse_temperature).
      num_replica = ps.size0(inverse_temperatures)
      replica_shape = ps.convert_to_shape_tensor([num_replica])

      if self._state_includes_replicas:
        replica_states = init_state
      else:
        replica_states = [
            tf.broadcast_to(  # pylint: disable=g-complex-comprehension
                x,
                ps.concat([replica_shape, ps.shape(x)], axis=0),
                name='replica_states')
            for x in init_state
        ]

      target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
          target_log_prob_fn=self.target_log_prob_fn,
          inverse_temperatures=inverse_temperatures,
          untempered_log_prob_fn=self.untempered_log_prob_fn,
          tempered_log_prob_fn=self.tempered_log_prob_fn,
      )
      # TODO(b/159636942): Clean up the helpful error msg after 2020-11-10.
      try:
        inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
            target_log_prob_for_inner_kernel)
      except TypeError as e:
        # Only a `TypeError` whose message mentions 'argument' is treated as
        # a signature mismatch; anything else propagates unchanged.
        if 'argument' not in str(e):
          raise
        # Chain explicitly so the original signature error stays attached as
        # the cause of the more helpful message.
        raise TypeError(
            '`ReplicaExchangeMC`s `make_kernel_fn` no longer receives a second '
            '(`seed`) argument. `TransitionKernel` instances now receive seeds '
            'via `one_step`.') from e

      replica_results = inner_kernel.bootstrap_results(replica_states)

      pre_swap_replica_target_log_prob = _get_field(
          replica_results, 'target_log_prob')

      replica_and_batch_shape = ps.shape(
          pre_swap_replica_target_log_prob)
      # Everything after the leading (replica) dimension is batch shape.
      batch_shape = replica_and_batch_shape[1:]

      inverse_temperatures = bu.left_justified_broadcast_to(
          inverse_temperatures, replica_and_batch_shape)

      # Pretend we did a "null swap", which will always be accepted.
      swaps = bu.left_justified_broadcast_to(
          tf.range(num_replica), replica_and_batch_shape)
      # is_swap_accepted.shape = [n_replica, n_replica] + batch_shape.
      is_swap_accepted = distribution_util.rotate_transpose(
          tf.eye(num_replica, batch_shape=batch_shape, dtype=tf.bool),
          shift=2)

      return ReplicaExchangeMCKernelResults(
          post_swap_replica_states=replica_states,
          pre_swap_replica_results=replica_results,
          post_swap_replica_results=_set_swapped_fields_to_nan(replica_results),
          # The null swap counts as both proposed and accepted.
          is_swap_proposed=is_swap_accepted,
          is_swap_accepted=is_swap_accepted,
          is_swap_proposed_adjacent=_sub_diag(is_swap_accepted),
          is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
          # Store the original attribute (not the broadcast tensor) in the
          # results.
          inverse_temperatures=self.inverse_temperatures,
          swaps=swaps,
          step_count=tf.zeros(shape=(), dtype=tf.int32),
          seed=samplers.zeros_seed(),
      )
Пример #24
0
    def bootstrap_results(self, init_state):
        """Returns an object with the same type as returned by `one_step`.

        Args:
          init_state: `Tensor` or Python `list` of `Tensor`s representing the
            initial state(s) of the Markov chain(s).

        Returns:
          kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
            `Tensor`s representing internal calculations made within this
            function. This includes replica states.

        Raises:
          ValueError: if the state is flagged as already including replicas and
            the statically known leading dimension of the first state part
            differs from that of `inverse_temperatures`.
        """
        with tf.name_scope(
                mcmc_util.make_name(self.name, 'remc', 'bootstrap_results')):
            init_state, unused_is_multipart_state = mcmc_util.prepare_state_parts(
                init_state)

            inverse_temperatures = tf.convert_to_tensor(
                self.inverse_temperatures, name='inverse_temperatures')

            if self._state_includes_replicas:
                # Only static shapes are compared; unknown (`None`) dimensions
                # skip the check.
                it_n_replica = inverse_temperatures.shape[0]
                state_n_replica = init_state[0].shape[0]
                if ((it_n_replica is not None)
                        and (state_n_replica is not None)
                        and (it_n_replica != state_n_replica)):
                    # The format arguments follow the order of the placeholders:
                    # initial state first, inverse_temperatures second.
                    raise ValueError(
                        'Number of replicas implied by initial state ({}) must equal '
                        'number of replicas implied by inverse_temperatures ({}), but '
                        'did not'.format(state_n_replica, it_n_replica))

            # We will now replicate each of a possible batch of initial states, one
            # for each inverse_temperature. So if init_state=[x, y] of shapes
            # [Sx, Sy] then the new shape is [(T, Sx), (T, Sy)] where (a, b) means
            # concatenation and T=shape(inverse_temperature).
            num_replica = ps.size0(inverse_temperatures)
            replica_shape = tf.convert_to_tensor([num_replica])

            if self._state_includes_replicas:
                replica_states = init_state
            else:
                replica_states = [
                    tf.broadcast_to(  # pylint: disable=g-complex-comprehension
                        x,
                        ps.concat([replica_shape, ps.shape(x)], axis=0),
                        name='replica_states') for x in init_state
                ]

            target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
                self.target_log_prob_fn, inverse_temperatures)
            # Seed handling complexity is due to users possibly expecting an old-style
            # stateful seed to be passed to `self.make_kernel_fn`.
            # In other words:
            # - We try `make_kernel_fn` without a seed first; this is the future. The
            #   kernel will receive a seed later, as part of `one_step`.
            # - If the user code doesn't like that (Python complains about a missing
            #   required argument), we fall back to the previous behavior and warn.
            try:
                inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                    target_log_prob_for_inner_kernel)
            except TypeError as e:
                # Any TypeError that does not mention an argument is re-raised.
                if 'argument' not in str(e):
                    raise
                warnings.warn(
                    'The second (`seed`) argument to `ReplicaExchangeMC`s '
                    '`make_kernel_fn` is deprecated. `TransitionKernel` instances now '
                    'receive seeds via `bootstrap_results` and `one_step`. This '
                    'fallback may become an error 2020-09-20.')
                inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                    target_log_prob_for_inner_kernel, self._seed_stream())

            replica_results = inner_kernel.bootstrap_results(replica_states)

            pre_swap_replica_target_log_prob = _get_field(
                replica_results, 'target_log_prob')

            replica_and_batch_shape = ps.shape(
                pre_swap_replica_target_log_prob)
            batch_shape = replica_and_batch_shape[1:]

            inverse_temperatures = mcmc_util.left_justified_broadcast_to(
                inverse_temperatures, replica_and_batch_shape)

            # Pretend we did a "null swap", which will always be accepted.
            swaps = mcmc_util.left_justified_broadcast_to(
                tf.range(num_replica), replica_and_batch_shape)
            # is_swap_accepted.shape = [n_replica, n_replica] + batch_shape.
            is_swap_accepted = distribution_util.rotate_transpose(
                tf.eye(num_replica, batch_shape=batch_shape, dtype=tf.bool),
                shift=2)

            # The same temperatures are passed twice and the swap function is
            # the identity, consistent with the null swap above.
            post_swap_replica_results = _make_post_swap_replica_results(
                replica_results,
                inverse_temperatures,
                inverse_temperatures,
                is_swap_accepted[0],
                lambda x: x,
            )

            return ReplicaExchangeMCKernelResults(
                post_swap_replica_states=replica_states,
                pre_swap_replica_results=replica_results,
                post_swap_replica_results=post_swap_replica_results,
                is_swap_proposed=is_swap_accepted,
                is_swap_accepted=is_swap_accepted,
                is_swap_proposed_adjacent=_sub_diag(is_swap_accepted),
                is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
                inverse_temperatures=self.inverse_temperatures,
                swaps=swaps,
                step_count=tf.zeros(shape=(), dtype=tf.int32),
                seed=samplers.zeros_seed(),
            )
Пример #25
0
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Proposes a next state by sampling momenta and running leapfrog steps.

        Args:
          current_state: `Tensor` or Python `list` of `Tensor`s holding the
            current state part(s).
          previous_kernel_results: results of the previous step; its
            `target_log_prob` and `grads_target_log_prob` fields are read, plus
            `step_size` / `num_leapfrog_steps` when parameters are stored in the
            results.
          seed: Optional seed; it is sanitized and stored in the returned results.

        Returns:
          next_state: the integrated state, a list if `current_state` was
            list-like and a single element otherwise.
          kernel_results: `previous_kernel_results` with the log acceptance
            correction, target log prob, gradients, momenta and seed replaced.
        """
        with tf.name_scope(mcmc_util.make_name(self.name, 'hmc', 'one_step')):
            # Step parameters live either in the previous results or on the kernel.
            params_source = (previous_kernel_results
                             if self._store_parameters_in_results else self)
            step_size = params_source.step_size
            num_leapfrog_steps = params_source.num_leapfrog_steps

            (state_parts,
             step_sizes,
             start_log_prob,
             start_grads) = _prepare_args(
                 self.target_log_prob_fn,
                 current_state,
                 step_size,
                 previous_kernel_results.target_log_prob,
                 previous_kernel_results.grads_target_log_prob,
                 maybe_expand=True,
                 state_gradients_are_stopped=self.state_gradients_are_stopped)

            seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
            part_seeds = distribute_lib.fold_in_axis_index(
                samplers.split_seed(seed, n=len(state_parts)),
                self.experimental_shard_axis_names)

            # One normal momentum draw per state part, each with its own seed.
            start_momenta = [
                samplers.normal(
                    shape=ps.shape(part),
                    dtype=(self._momentum_dtype
                           or dtype_util.base_dtype(part.dtype)),
                    seed=part_seed)
                for part_seed, part in zip(part_seeds, state_parts)
            ]

            leapfrog = leapfrog_impl.SimpleLeapfrogIntegrator(
                self.target_log_prob_fn, step_sizes, num_leapfrog_steps)
            (end_momenta,
             end_state_parts,
             end_log_prob,
             end_grads) = leapfrog(start_momenta, state_parts, start_log_prob,
                                   start_grads)

            if self.state_gradients_are_stopped:
                end_state_parts = [
                    tf.stop_gradient(part) for part in end_state_parts
                ]

            log_correction = _compute_log_acceptance_correction(
                start_momenta,
                end_momenta,
                ps.rank(start_log_prob),
                shard_axis_names=self.experimental_shard_axis_names)

            updated_results = previous_kernel_results._replace(
                log_acceptance_correction=log_correction,
                target_log_prob=end_log_prob,
                grads_target_log_prob=end_grads,
                initial_momentum=start_momenta,
                final_momentum=end_momenta,
                seed=seed,
            )

            # Mirror the structure of the caller's state: list in, list out.
            if mcmc_util.is_list_like(current_state):
                next_state = end_state_parts
            else:
                next_state = end_state_parts[0]
            return next_state, updated_results
Пример #26
0
    def one_step(self, current_state, previous_kernel_results):
        """Proposes a next state by sampling momenta and running leapfrog steps.

        Momenta are drawn with `tf.random.normal`, seeded from the kernel's
        seed stream (one seed per state part).

        Args:
          current_state: `Tensor` or Python `list` of `Tensor`s holding the
            current state part(s).
          previous_kernel_results: results of the previous step; its
            `target_log_prob` and `grads_target_log_prob` fields are read, plus
            `step_size` / `num_leapfrog_steps` when parameters are stored in the
            results.

        Returns:
          next_state: the integrated state, a list if `current_state` was
            list-like and a single element otherwise.
          kernel_results: `previous_kernel_results` with the log acceptance
            correction, target log prob and gradients replaced.
        """
        with tf.compat.v2.name_scope(
                mcmc_util.make_name(self.name, 'hmc', 'one_step')):
            # Step parameters live either in the previous results or on the kernel.
            params_source = (previous_kernel_results
                             if self._store_parameters_in_results else self)
            step_size = params_source.step_size
            num_leapfrog_steps = params_source.num_leapfrog_steps

            (state_parts,
             step_sizes,
             start_log_prob,
             start_grads) = _prepare_args(
                 self.target_log_prob_fn,
                 current_state,
                 step_size,
                 previous_kernel_results.target_log_prob,
                 previous_kernel_results.grads_target_log_prob,
                 maybe_expand=True,
                 state_gradients_are_stopped=self.state_gradients_are_stopped)

            # One momentum draw per state part; each draw pulls a fresh seed
            # from the seed stream, in state-part order.
            start_momenta = [
                tf.random.normal(
                    shape=tf.shape(input=part),
                    dtype=self._momentum_dtype or part.dtype.base_dtype,
                    seed=self._seed_stream())
                for part in state_parts
            ]

            leapfrog = leapfrog_impl.SimpleLeapfrogIntegrator(
                self.target_log_prob_fn, step_sizes, num_leapfrog_steps)
            (end_momenta,
             end_state_parts,
             end_log_prob,
             end_grads) = leapfrog(start_momenta, state_parts, start_log_prob,
                                   start_grads)

            if self.state_gradients_are_stopped:
                end_state_parts = [
                    tf.stop_gradient(part) for part in end_state_parts
                ]

            chain_ndims = distribution_util.prefer_static_rank(start_log_prob)
            log_correction = _compute_log_acceptance_correction(
                start_momenta, end_momenta, chain_ndims)

            updated_results = previous_kernel_results._replace(
                log_acceptance_correction=log_correction,
                target_log_prob=end_log_prob,
                grads_target_log_prob=end_grads,
            )

            # Mirror the structure of the caller's state: list in, list out.
            if mcmc_util.is_list_like(current_state):
                next_state = end_state_parts
            else:
                next_state = end_state_parts[0]
            return next_state, updated_results
Пример #27
0
    def __init__(self,
                 inner_kernel,
                 num_adaptation_steps,
                 target_accept_prob=0.75,
                 adaptation_rate=0.01,
                 step_size_setter_fn=_hmc_like_step_size_setter_fn,
                 step_size_getter_fn=_hmc_like_step_size_getter_fn,
                 log_accept_prob_getter_fn=_hmc_like_log_accept_prob_getter_fn,
                 validate_args=False,
                 name=None):
        """Creates the step size adaptation kernel.

        The default setter and getter callbacks assume that the inner kernel
        produces kernel results structurally the same as the
        `HamiltonianMonteCarlo` kernel.

        Args:
          inner_kernel: `TransitionKernel`-like object.
          num_adaptation_steps: Scalar `int` `Tensor` number of initial steps
            during which to adjust the step size. This may be greater than, less
            than, or equal to the number of burnin steps.
          target_accept_prob: A floating point `Tensor` representing the desired
            acceptance probability. Must be a positive number less than 1. This
            can either be a scalar, or have shape [num_chains]. Default value:
            `0.75` (the [center of asymptotically optimal rate for HMC][1]).
          adaptation_rate: `Tensor` representing the amount by which to scale
            the current `step_size`.
          step_size_setter_fn: A callable with the signature
            `(kernel_results, new_step_size) -> new_kernel_results` where
            `kernel_results` are the results of the `inner_kernel`,
            `new_step_size` is a `Tensor` or a nested collection of `Tensor`s
            with the same structure as returned by the `step_size_getter_fn`,
            and `new_kernel_results` are a copy of `kernel_results` with the
            step size(s) set.
          step_size_getter_fn: A callable with the signature
            `(kernel_results) -> step_size` where `kernel_results` are the
            results of the `inner_kernel`, and `step_size` is a floating point
            `Tensor` or a nested collection of such `Tensor`s.
          log_accept_prob_getter_fn: A callable with the signature
            `(kernel_results) -> log_accept_prob` where `kernel_results` are the
            results of the `inner_kernel`, and `log_accept_prob` is a floating
            point `Tensor`. `log_accept_prob` can either be a scalar, or have
            shape [num_chains]. If it's the latter, `step_size` should also have
            the same leading dimension.
          validate_args: Python `bool`. When `True` kernel parameters are
            checked for validity. When `False` invalid inputs may silently
            render incorrect outputs.
          name: Python `str` name prefixed to Ops created by this class.
            Default: 'simple_step_size_adaptation'.

        #### References

        [1]: Betancourt, M. J., Byrne, S., & Girolami, M. (2014). _Optimizing The
             Integrator Step Size for Hamiltonian Monte Carlo_.
             http://arxiv.org/abs/1411.6669
        """
        inner_kernel = mcmc_util.enable_store_parameters_in_results(
            inner_kernel)

        scope_name = mcmc_util.make_name(
            name, 'simple_step_size_adaptation', '__init__')
        scope_values = [
            target_accept_prob, adaptation_rate, num_adaptation_steps
        ]
        # `name` is rebound to the scope name produced by the context manager.
        with tf.compat.v1.name_scope(scope_name, values=scope_values) as name:
            # Both floating-point parameters share one dtype (float32 hint).
            dtype = dtype_util.common_dtype(
                [target_accept_prob, adaptation_rate], tf.float32)
            target_accept_prob = tf.convert_to_tensor(
                value=target_accept_prob, dtype=dtype,
                name='target_accept_prob')
            adaptation_rate = tf.convert_to_tensor(
                value=adaptation_rate, dtype=dtype, name='adaptation_rate')
            num_adaptation_steps = tf.convert_to_tensor(
                value=num_adaptation_steps, dtype=tf.int32,
                name='num_adaptation_steps')

            target_accept_prob = _maybe_validate_target_accept_prob(
                target_accept_prob, validate_args)

        self._parameters = {
            'inner_kernel': inner_kernel,
            'num_adaptation_steps': num_adaptation_steps,
            'target_accept_prob': target_accept_prob,
            'adaptation_rate': adaptation_rate,
            'step_size_setter_fn': step_size_setter_fn,
            'step_size_getter_fn': step_size_getter_fn,
            'log_accept_prob_getter_fn': log_accept_prob_getter_fn,
            'name': name,
        }
Пример #28
0
    def one_step(self, current_state, previous_kernel_results):
        """Takes one step of the TransitionKernel.

        The inner kernel proposes a state, and the proposal is accepted or
        rejected by comparing the log of a uniform draw with the log
        acceptance ratio.

        Args:
          current_state: `Tensor` or Python `list` of `Tensor`s representing the
            current state(s) of the Markov chain(s).
          previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
            `list` of `Tensor`s representing internal calculations made within
            the previous call to this function (or as returned by
            `bootstrap_results`).

        Returns:
          next_state: `Tensor` or Python `list` of `Tensor`s representing the
            next state(s) of the Markov chain(s).
          kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
            `Tensor`s representing internal calculations made within this
            function.

        Raises:
          ValueError: if `inner_kernel` results doesn't contain the member
            "target_log_prob".
        """
        with tf.compat.v1.name_scope(
                name=mcmc_util.make_name(self.name, 'mh', 'one_step'),
                values=[current_state, previous_kernel_results]):
            prior_results = previous_kernel_results.accepted_results

            # Take one inner step.
            proposed_state, proposed_results = self.inner_kernel.one_step(
                current_state, prior_results)

            if not (has_target_log_prob(proposed_results)
                    and has_target_log_prob(prior_results)):
                raise ValueError('"target_log_prob" must be a member of '
                                 '`inner_kernel` results.')

            # Compute log(acceptance_ratio).
            ratio_terms = [
                proposed_results.target_log_prob,
                -prior_results.target_log_prob,
            ]
            try:
                correction = proposed_results.log_acceptance_correction
                # An empty list-like correction contributes nothing to the sum.
                if not mcmc_util.is_list_like(correction) or correction:
                    ratio_terms.append(correction)
            except AttributeError:
                warnings.warn(
                    'Supplied inner `TransitionKernel` does not have a '
                    '`log_acceptance_correction`. Assuming its value is `0.`')
            log_accept_ratio = mcmc_util.safe_sum(
                ratio_terms, name='compute_log_accept_ratio')

            # If proposed state reduces likelihood: randomly accept.
            # If proposed state increases likelihood: always accept.
            # I.e., u < min(1, accept_ratio),  where u ~ Uniform[0,1)
            #       ==> log(u) < log_accept_ratio
            proposed_log_prob = proposed_results.target_log_prob
            uniform_draw = tf.random.uniform(
                shape=tf.shape(input=proposed_log_prob),
                dtype=proposed_log_prob.dtype.base_dtype,
                seed=self._seed_stream())
            is_accepted = tf.math.log(uniform_draw) < log_accept_ratio

            next_state = mcmc_util.choose(
                is_accepted, proposed_state, current_state,
                name='choose_next_state')
            accepted_results = mcmc_util.choose(
                is_accepted, proposed_results, prior_results,
                name='choose_inner_results')

            return next_state, MetropolisHastingsKernelResults(
                accepted_results=accepted_results,
                is_accepted=is_accepted,
                log_accept_ratio=log_accept_ratio,
                proposed_state=proposed_state,
                proposed_results=proposed_results,
                extra=[],
            )
Пример #29
0
  def one_step(self, current_state, previous_kernel_results, seed=None):
    """Runs one iteration of Slice Sampler.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s). The first `r` dimensions
        index independent chains,
        `r = tf.rank(target_log_prob_fn(*current_state))`.
      previous_kernel_results: `collections.namedtuple` containing `Tensor`s
        representing values from previous calls to this function (or from the
        `bootstrap_results` function.)
      seed: Optional, a seed for reproducible sampling.

    Returns:
      next_state: Tensor or Python list of `Tensor`s representing the state(s)
        of the Markov chain(s) after taking exactly one step. Has same type and
        shape as `current_state`.
      kernel_results: `collections.namedtuple` of internal calculations used to
        advance the chain.

    Raises:
      ValueError: if there isn't one `step_size` or a list with same length as
        `current_state`. (Not raised directly here; presumably propagated from
        argument preparation.)
      TypeError: if `not target_log_prob.dtype.is_floating`. (Likewise not
        raised directly in this method.)
    """
    seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.

    with tf.name_scope(mcmc_util.make_name(self.name, 'slice', 'one_step')):
      with tf.name_scope('initialize'):
        state_parts, step_sizes, start_log_prob = _prepare_args(
            self.target_log_prob_fn,
            current_state,
            self.step_size,
            previous_kernel_results.target_log_prob,
            maybe_expand=True)

        max_doublings = ps.convert_to_shape_tensor(
            value=self.max_doublings, dtype=tf.int32, name='max_doublings')

      chain_ndims = ps.rank(start_log_prob)

      (proposed_parts,
       proposed_log_prob,
       bounds_satisfied,
       direction,
       upper_bounds,
       lower_bounds) = _sample_next(
           self.target_log_prob_fn,
           state_parts,
           step_sizes,
           max_doublings,
           start_log_prob,
           chain_ndims,
           seed=seed,
       )

      # Mirror the structure of the caller's state: list in, list out.
      if mcmc_util.is_list_like(current_state):
        next_state = proposed_parts
      else:
        next_state = proposed_parts[0]

      results = SliceSamplerKernelResults(
          target_log_prob=proposed_log_prob,
          bounds_satisfied=bounds_satisfied,
          direction=direction,
          upper_bounds=upper_bounds,
          lower_bounds=lower_bounds,
          seed=seed,
      )
      return [next_state, results]
Пример #30
0
    def one_step(self, current_state, previous_kernel_results):
        """Takes one step of the TransitionKernel.

        Args:
          current_state: `Tensor` or Python `list` of `Tensor`s representing the
            current state(s) of the Markov chain(s).
          previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
            `list` of `Tensor`s representing internal calculations made within
            the previous call to this function (or as returned by
            `bootstrap_results`).

        Returns:
          next_state: `Tensor` or Python `list` of `Tensor`s representing the
            next state(s) of the Markov chain(s).
          kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
            `Tensor`s representing internal calculations made within this
            function. This includes replica states.
        """
        # Key difficulty:  The type of exchanges differs from one call to the
        # next...even the number of exchanges can differ.
        # As a result, exchanges must happen dynamically, in while loops.
        with tf.compat.v1.name_scope(
                name=mcmc_util.make_name(self.name, 'remc', 'one_step'),
                values=[current_state, previous_kernel_results]):

            # Each replica does `one_step` to get pre-exchange states/KernelResults.
            step_outputs = [
                kernel.one_step(previous_kernel_results.replica_states[idx],
                                previous_kernel_results.replica_results[idx])
                for idx, kernel in enumerate(self.replica_kernels)
            ]
            sampled_replica_states, sampled_replica_results = (
                list(column) for column in zip(*step_outputs))

            states_are_lists = mcmc_util.is_list_like(
                sampled_replica_states[0])

            # Internally every replica state is handled as a list of parts.
            if not states_are_lists:
                sampled_replica_states = [
                    [state] for state in sampled_replica_states
                ]
            num_state_parts = len(sampled_replica_states[0])

            dtype = sampled_replica_states[0][0].dtype
            num_replica = self.num_replica

            # Must put states into TensorArrays.  Why?  We will read/write states
            # dynamically with Tensor index `i`, and you cannot do this with lists.
            # old_states[k][i] is Tensor of (old) state part k, for replica i.
            # The `k` will be known statically, and `i` is a Tensor.
            old_states = []
            for k in range(num_state_parts):
                old_states.append(
                    tf.TensorArray(
                        dtype,
                        size=num_replica,
                        dynamic_size=False,
                        clear_after_read=False,
                        tensor_array_name='old_states',
                        # State part k has same shape, regardless of replica.
                        # So use 0.
                        element_shape=sampled_replica_states[0][k].shape))
            for k in range(num_state_parts):
                part_array = old_states[k]
                for i in range(num_replica):
                    part_array = part_array.write(
                        i, sampled_replica_states[i][k])
                old_states[k] = part_array

            exchange_proposed = self.exchange_proposed_fn(
                num_replica, seed=self._seed_stream())
            exchange_proposed_n = tf.shape(input=exchange_proposed)[0]

            exchanged_states = self._get_exchanged_states(
                old_states, exchange_proposed, exchange_proposed_n,
                sampled_replica_states, sampled_replica_results)

            # Replica indices that appear in no proposed exchange.
            flat_proposed = tf.reshape(exchange_proposed, [-1])
            no_exchange_proposed, _ = tf.compat.v1.setdiff1d(
                tf.range(num_replica), flat_proposed)

            exchanged_states = self._insert_old_states_where_no_exchange_was_proposed(
                no_exchange_proposed, old_states, exchanged_states)

            next_replica_states = [
                [exchanged_states[k].read(i) for k in range(num_state_parts)]
                for i in range(num_replica)
            ]

            # Undo the list wrapping applied above for single-part states.
            if not states_are_lists:
                next_replica_states = [
                    parts[0] for parts in next_replica_states
                ]
                sampled_replica_states = [
                    parts[0] for parts in sampled_replica_states
                ]

            # Now that states are/aren't exchanged, bootstrap next kernel_results.
            # The viewpoint is that after each exchange, we are starting anew.
            next_replica_results = [
                kernel.bootstrap_results(state)
                for kernel, state in zip(self.replica_kernels,
                                         next_replica_states)
            ]

            # Replica 0 is the returned state(s).
            next_state = next_replica_states[0]

            return next_state, ReplicaExchangeMCKernelResults(
                replica_states=next_replica_states,
                replica_results=next_replica_results,
                sampled_replica_states=sampled_replica_states,
                sampled_replica_results=sampled_replica_results,
            )