Code example #1
 def _swap_then_retemper(x):
     x, is_multipart = mcmc_util.prepare_state_parts(x)
     it_ratio_ = mcmc_util.left_justified_expand_dims_like(it_ratio, x[0])
     x = [swap_tensor_fn(x_part) * it_ratio_ for x_part in x]
     if not is_multipart:
         x = x[0]
     return x
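As background for every example on this page: below is a minimal sketch (not the library implementation) of what `mcmc_util.left_justified_expand_dims_like(x, reference)` does, assuming static shapes. It appends size-1 dimensions on the right of `x` until its rank matches `reference`, so per-replica or per-chain quantities broadcast against states of shape [n_replica] + batch_shape + event_shape.

import tensorflow as tf

def left_justified_expand_dims_like_sketch(x, reference):
    # Append trailing 1s to x's shape until its rank matches reference's.
    x = tf.convert_to_tensor(x)
    ndims_to_add = len(reference.shape) - len(x.shape)
    return tf.reshape(x, list(x.shape) + [1] * ndims_to_add)

it_ratio = tf.constant([0.5, 1.0, 2.0])   # shape [3]: one ratio per replica
state = tf.ones([3, 4, 2])                # [n_replica, batch, event]
scaled = left_justified_expand_dims_like_sketch(it_ratio, state) * state
# scaled.shape == [3, 4, 2]; replica k is multiplied by it_ratio[k].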
Code example #2
def _make_momentum_distribution(running_variance_parts, state_parts,
                                batch_ndims):
  """Construct a momentum distribution from the running variance.

  This uses a running variance to construct a momentum distribution with the
  correct batch_shape and event_shape.

  Args:
    running_variance_parts: List of `Tensor`, outputs of
      `tfp.experimental.stats.RunningVariance.variance()`.
    state_parts: List of `Tensor`.
    batch_ndims: Scalar, for leading batch dimensions.

  Returns:
    `tfd.Distribution` where `.sample` has the same structure as `state_parts`,
    and `.log_prob` of the sample will have rank equal to `batch_ndims`.
  """
  distributions = []
  for variance_part, state_part in zip(running_variance_parts, state_parts):
    running_variance_rank = ps.rank(variance_part)
    state_rank = ps.rank(state_part)
    # Pad dimensions and tile by multiplying by tf.ones to add a batch shape
    ones = tf.ones(ps.shape(state_part)[:-(state_rank - running_variance_rank)])
    ones = mcmc_util.left_justified_expand_dims_like(ones, state_part)
    variance_tiled = variance_part * ones
    reinterpreted_batch_ndims = state_rank - batch_ndims - 1

    distributions.append(
        _CompositeIndependent(
            _CompositeMultivariateNormalPrecisionFactorLinearOperator(
                precision_factor=_CompositeLinearOperatorDiag(
                    tf.math.sqrt(variance_tiled)),
                precision=_CompositeLinearOperatorDiag(variance_tiled)),
            reinterpreted_batch_ndims=reinterpreted_batch_ndims))
  return _CompositeJointDistributionSequential(distributions)
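For intuition about the shapes above, here is a plain-`tfd` sketch of the distribution built for a single state part; the `_Composite*` classes are `CompositeTensor`-friendly wrappers around the same constructs, and the concrete shapes below are hypothetical.

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

variance_part = tf.fill([5], 2.)   # stand-in for RunningVariance.variance()
state_part = tf.zeros([7, 3, 5])   # state_rank = 3
batch_ndims = 1                    # the leading [7] axis is the batch

# Tile the variance to [7, 1, 5] so the distribution gains a batch shape.
variance_tiled = variance_part * tf.reshape(tf.ones([7]), [7, 1, 1])
# `precision = variance` corresponds to scale_diag = 1 / sqrt(variance).
momentum_dist = tfd.Independent(
    tfd.MultivariateNormalDiag(scale_diag=tf.math.rsqrt(variance_tiled)),
    reinterpreted_batch_ndims=3 - batch_ndims - 1)
# momentum_dist.log_prob(state_part).shape == [7]  (rank == batch_ndims)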
Code example #3
 def leapfrog_action(dt):
   # This represents the effect on the criterion value as the state follows the
   # proposed velocity. This implicitly assumes an identity mass matrix.
   return criterion_fn(
       previous_state,
       tf.nest.map_structure(
           lambda x, v:  # pylint: disable=g-long-lambda
           (x + mcmc_util.left_justified_expand_dims_like(dt, v) * v),
           proposed_state,
           proposed_velocity),
       accept_prob)
Code example #4
    def _replica_target_log_prob(*x):
        if tempered_log_prob_fn is not None:
            tlp = tempered_log_prob_fn(*x)
        else:
            tlp = target_log_prob_fn(*x)

        log_prob = tf.cast(mcmc_util.left_justified_expand_dims_like(
            inverse_temperatures, tlp),
                           dtype=tlp.dtype) * tlp

        if untempered_log_prob_fn is not None:
            log_prob = log_prob + untempered_log_prob_fn(*x)

        return log_prob
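As a hedged sketch of the family this helper evaluates: replica k targets log p_k(x) = beta_k * tlp(x) + ulp(x), where only the tempered term is scaled by the inverse temperature. The standard-normal target below is a stand-in.

import tensorflow as tf

inverse_temperatures = tf.constant([1.0, 0.5, 0.25])  # beta_k per replica

def replica_target_log_prob(x):                 # x: [n_replica, batch, event]
    tlp = -0.5 * tf.reduce_sum(x**2, axis=-1)   # tempered part (std. normal)
    beta = tf.reshape(inverse_temperatures, [3, 1])  # left-justified expand
    # An untempered part, if supplied, would be added here without scaling.
    return beta * tlp

x = tf.random.normal([3, 4, 2])
print(replica_target_log_prob(x).shape)         # (3, 4)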
Code example #5
def make_rwmh_kernel_fn(target_log_prob_fn, init_state, scalings):
    """Generate a Random Walk MH kernel."""
    with tf.name_scope('make_rwmh_kernel_fn'):
        state_std = [
            tf.math.reduce_std(x, axis=0, keepdims=True) for x in init_state
        ]
        step_size = [
            s * ps.cast(  # pylint: disable=g-complex-comprehension
                mcmc_util.left_justified_expand_dims_like(scalings, s),
                s.dtype) for s in state_std
        ]
        return random_walk_metropolis.RandomWalkMetropolis(
            target_log_prob_fn,
            new_state_fn=random_walk_metropolis.random_walk_normal_fn(
                scale=step_size))
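Concretely, with hypothetical shapes: `scalings` carries one entry per chain while each `state_std` part keeps a size-1 leading dim from `keepdims=True`, so the left-justified expansion turns them into a per-chain, per-coordinate scale.

import tensorflow as tf

init_state = [tf.random.normal([100, 3])]   # 100 chains, 3-d state part
scalings = tf.fill([100], 0.5)              # one scaling per chain
state_std = [tf.math.reduce_std(x, axis=0, keepdims=True) for x in init_state]
# state_std[0].shape == [1, 3]; scalings expands to [100, 1], so:
step_size = [s * tf.reshape(scalings, [100, 1]) for s in state_std]
# step_size[0].shape == [100, 3]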
Code example #6
  def _center_proposed_state(x):
    # The empirical mean here is a stand-in for the true mean, so we drop the
    # gradient that flows through this term. The goal is to get a reliable
    # diagnostic of the underlying dynamics, rather than incorporating the
    # effect of the Metropolis-Hastings correction.
    # TODO(mhoffman): Needs more experimentation.
    expanded_accept_prob = mcmc_util.left_justified_expand_dims_like(
        accept_prob, x)

    # accept_prob is zero when x is NaN, but we still want to sanitize such
    # values.
    x_safe = tf.where(tf.math.is_finite(x), x, tf.zeros_like(x))
    # If all accept_prob's are zero, the x_center will have a nonsense value,
    # but we'll discard the resultant gradients later on, so it's fine.
    x_center = (
        tf.reduce_sum(expanded_accept_prob * x_safe, axis=batch_axes) /
        (tf.reduce_sum(expanded_accept_prob, axis=batch_axes) + 1e-20))

    return x - tf.stop_gradient(x_center)
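A toy numeric sketch of the acceptance-weighted centering above, with a single chain axis: the NaN state contributes nothing because both its weight and its sanitized value are zero.

import tensorflow as tf

accept_prob = tf.constant([0.0, 0.5, 1.0])   # zero where x was NaN
x = tf.constant([[float('nan')], [2.0], [4.0]])
w = tf.reshape(accept_prob, [3, 1])          # left-justified expansion
x_safe = tf.where(tf.math.is_finite(x), x, tf.zeros_like(x))
x_center = (tf.reduce_sum(w * x_safe, axis=0) /
            (tf.reduce_sum(w, axis=0) + 1e-20))
# x_center == [(0.5 * 2.0 + 1.0 * 4.0) / 1.5] == [3.333...]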
Code example #7
def compute_hmc_step_size(scalings, state_std, num_leapfrog_steps):
    return [
        s / ps.cast(num_leapfrog_steps, s.dtype) * ps.cast(  # pylint: disable=g-complex-comprehension
            mcmc_util.left_justified_expand_dims_like(scalings, s), s.dtype)
        for s in state_std
    ]
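The division by `num_leapfrog_steps` keeps the total trajectory length (`step_size * num_leapfrog_steps`) equal to `scalings * state_std` regardless of the step count. A usage sketch with made-up values, assuming the module's `ps` and `mcmc_util` imports are in scope:

state_std = [tf.constant([[1.0, 2.0]])]          # keepdims std: shape [1, 2]
step_size = compute_hmc_step_size(
    scalings=tf.constant([0.5, 1.0]),            # one scaling per chain
    state_std=state_std,
    num_leapfrog_steps=4)
# step_size[0] == [[0.125, 0.25], [0.25, 0.5]]   # shape [2, 2]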
Code example #8
File: nuts.py  Project: wataruhashimoto52/probability
  def _loop_build_sub_tree(self,
                           directions,
                           integrator,
                           current_step_meta_info,
                           iter_,
                           energy_diff_sum_previous,
                           momentum_cumsum_previous,
                           leapfrogs_taken,
                           prev_tree_state,
                           candidate_tree_state,
                           continue_tree_previous,
                           not_divergent_previous,
                           momentum_state_memory):
    """Base case in tree doubling."""
    with tf.name_scope('loop_build_sub_tree'):
      # Take one leapfrog step in the direction v and check divergence
      [
          next_momentum_parts,
          next_state_parts,
          next_target,
          next_target_grad_parts
      ] = integrator(prev_tree_state.momentum,
                     prev_tree_state.state,
                     prev_tree_state.target,
                     prev_tree_state.target_grad_parts)

      next_tree_state = TreeDoublingState(
          momentum=next_momentum_parts,
          state=next_state_parts,
          target=next_target,
          target_grad_parts=next_target_grad_parts)
      momentum_cumsum = [p0 + p1 for p0, p1 in zip(momentum_cumsum_previous,
                                                   next_momentum_parts)]
      # If the tree has not terminated previously, count this leapfrog.
      leapfrogs_taken = tf.where(
          continue_tree_previous, leapfrogs_taken + 1, leapfrogs_taken)

      write_instruction = current_step_meta_info.write_instruction
      read_instruction = current_step_meta_info.read_instruction
      init_energy = current_step_meta_info.init_energy

      if GENERALIZED_UTURN:
        state_to_write = momentum_cumsum_previous
        state_to_check = momentum_cumsum
      else:
        state_to_write = next_state_parts
        state_to_check = next_state_parts

      batch_shape = ps.shape(next_target)
      has_not_u_turn_init = ps.ones(batch_shape, dtype=tf.bool)

      read_index = read_instruction.gather([iter_])[0]
      no_u_turns_within_tree = has_not_u_turn_at_all_index(  # pylint: disable=g-long-lambda
          read_index,
          directions,
          momentum_state_memory,
          next_momentum_parts,
          state_to_check,
          has_not_u_turn_init,
          log_prob_rank=ps.rank(next_target))

      # Get index to write state into memory swap
      write_index = write_instruction.gather([iter_])
      momentum_state_memory = MomentumStateSwap(
          momentum_swap=[
              tf.tensor_scatter_nd_update(old, [write_index], [new])
              for old, new in zip(momentum_state_memory.momentum_swap,
                                  next_momentum_parts)
          ],
          state_swap=[
              tf.tensor_scatter_nd_update(old, [write_index], [new])
              for old, new in zip(momentum_state_memory.state_swap,
                                  state_to_write)
          ])

      energy = compute_hamiltonian(next_target, next_momentum_parts)
      current_energy = tf.where(tf.math.is_nan(energy),
                                tf.constant(-np.inf, dtype=energy.dtype),
                                energy)
      energy_diff = current_energy - init_energy

      if MULTINOMIAL_SAMPLE:
        not_divergent = -energy_diff < self.max_energy_diff
        weight_sum = log_add_exp(candidate_tree_state.weight, energy_diff)
        log_accept_thresh = energy_diff - weight_sum
      else:
        log_slice_sample = current_step_meta_info.log_slice_sample
        not_divergent = log_slice_sample - energy_diff < self.max_energy_diff
        # Uniform sampling on the trajectory within the subtree across valid
        # samples.
        is_valid = log_slice_sample <= energy_diff
        weight_sum = tf.where(is_valid,
                              candidate_tree_state.weight + 1,
                              candidate_tree_state.weight)
        log_accept_thresh = tf.where(
            is_valid,
            -tf.math.log(tf.cast(weight_sum, dtype=tf.float32)),
            tf.constant(-np.inf, dtype=tf.float32))
      u = tf.math.log1p(-tf.random.uniform(
          shape=batch_shape,
          dtype=log_accept_thresh.dtype,
          seed=self._seed_stream()))
      is_sample_accepted = u <= log_accept_thresh

      next_candidate_tree_state = TreeDoublingStateCandidate(
          state=[
              tf.where(  # pylint: disable=g-complex-comprehension
                  mcmc_util.left_justified_expand_dims_like(
                      is_sample_accepted, s0), s0, s1)
              for s0, s1 in zip(next_state_parts, candidate_tree_state.state)
          ],
          target=tf.where(
              mcmc_util.left_justified_expand_dims_like(
                  is_sample_accepted, next_target),
              next_target, candidate_tree_state.target),
          target_grad_parts=[
              tf.where(  # pylint: disable=g-complex-comprehension
                  mcmc_util.left_justified_expand_dims_like(
                      is_sample_accepted, grad0),
                  grad0, grad1)
              for grad0, grad1 in zip(next_target_grad_parts,
                                      candidate_tree_state.target_grad_parts)
          ],
          energy=tf.where(
              mcmc_util.left_justified_expand_dims_like(
                  is_sample_accepted, next_target),
              current_energy, candidate_tree_state.energy),
          weight=weight_sum)

      continue_tree = not_divergent & continue_tree_previous
      continue_tree_next = no_u_turns_within_tree & continue_tree

      not_divergent_tokeep = tf.where(
          continue_tree_previous,
          not_divergent,
          ps.ones(batch_shape, dtype=tf.bool))

      # min(1., exp(energy_diff)).
      exp_energy_diff = tf.math.exp(tf.minimum(energy_diff, 0.))
      energy_diff_sum = tf.where(continue_tree,
                                 energy_diff_sum_previous + exp_energy_diff,
                                 energy_diff_sum_previous)

      return (
          iter_ + 1,
          energy_diff_sum,
          momentum_cumsum,
          leapfrogs_taken,
          next_tree_state,
          next_candidate_tree_state,
          continue_tree_next,
          not_divergent_previous & not_divergent_tokeep,
          momentum_state_memory,
      )
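Two details above are worth unpacking: in the multinomial branch the new point is kept with probability w_new / (w_old + w_new), computed entirely in log space, and `log1p(-uniform)` is simply a log-uniform draw that cannot hit log(0). A self-contained sketch, with our own `log_add_exp` standing in for the module's helper:

import tensorflow as tf

def log_add_exp(a, b):
    # Numerically stable log(exp(a) + exp(b)).
    return tf.maximum(a, b) + tf.math.log1p(tf.math.exp(-tf.abs(a - b)))

weight_old = tf.constant(0.0)    # log-weight of the current candidate
energy_diff = tf.constant(-1.0)  # log-weight of the newly visited point
weight_sum = log_add_exp(weight_old, energy_diff)
log_accept_thresh = energy_diff - weight_sum  # log[w_new / (w_old + w_new)]
u = tf.math.log1p(-tf.random.uniform([]))     # log of a Uniform(0, 1) draw
is_sample_accepted = u <= log_accept_thresh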
Code example #9
File: nuts.py  Project: wataruhashimoto52/probability
  def loop_tree_doubling(self, step_size, momentum_state_memory,
                         current_step_meta_info, iter_, initial_step_state,
                         initial_step_metastate):
    """Main loop for tree doubling."""
    with tf.name_scope('loop_tree_doubling'):
      batch_shape = ps.shape(current_step_meta_info.init_energy)
      direction = tf.cast(
          tf.random.uniform(
              shape=batch_shape,
              minval=0,
              maxval=2,
              dtype=tf.int32,
              seed=self._seed_stream()),
          dtype=tf.bool)

      tree_start_states = tf.nest.map_structure(
          lambda v: tf.where(  # pylint: disable=g-long-lambda
              mcmc_util.left_justified_expand_dims_like(direction, v[1]),
              v[1], v[0]),
          initial_step_state)

      directions_expanded = [
          mcmc_util.left_justified_expand_dims_like(direction, state)
          for state in tree_start_states.state
      ]

      integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
          self.target_log_prob_fn,
          step_sizes=[
              tf.where(d, ss, -ss)
              for d, ss in zip(directions_expanded, step_size)
          ],
          num_steps=self.unrolled_leapfrog_steps)

      [
          candidate_tree_state,
          tree_final_states,
          final_not_divergence,
          continue_tree_final,
          energy_diff_tree_sum,
          momentum_subtree_cumsum,
          leapfrogs_taken
      ] = self._build_sub_tree(
          directions_expanded,
          integrator,
          current_step_meta_info,
          # num_steps_at_this_depth = 2**iter_ = 1 << iter_
          tf.bitwise.left_shift(1, iter_),
          tree_start_states,
          initial_step_metastate.continue_tree,
          initial_step_metastate.not_divergence,
          momentum_state_memory)

      last_candidate_state = initial_step_metastate.candidate_state

      energy_diff_sum = (
          energy_diff_tree_sum + initial_step_metastate.energy_diff_sum)
      if MULTINOMIAL_SAMPLE:
        tree_weight = tf.where(
            continue_tree_final,
            candidate_tree_state.weight,
            tf.constant(-np.inf, dtype=candidate_tree_state.weight.dtype))
        weight_sum = log_add_exp(tree_weight, last_candidate_state.weight)
        log_accept_thresh = tree_weight - last_candidate_state.weight
      else:
        tree_weight = tf.where(
            continue_tree_final,
            candidate_tree_state.weight,
            tf.zeros([], dtype=TREE_COUNT_DTYPE))
        weight_sum = tree_weight + last_candidate_state.weight
        log_accept_thresh = tf.math.log(
            tf.cast(tree_weight, tf.float32) /
            tf.cast(last_candidate_state.weight, tf.float32))
      log_accept_thresh = tf.where(
          tf.math.is_nan(log_accept_thresh),
          tf.zeros([], log_accept_thresh.dtype),
          log_accept_thresh)
      u = tf.math.log1p(-tf.random.uniform(
          shape=batch_shape,
          dtype=log_accept_thresh.dtype,
          seed=self._seed_stream()))
      is_sample_accepted = u <= log_accept_thresh

      choose_new_state = is_sample_accepted & continue_tree_final

      new_candidate_state = TreeDoublingStateCandidate(
          state=[
              tf.where(  # pylint: disable=g-complex-comprehension
                  mcmc_util.left_justified_expand_dims_like(
                      choose_new_state, s0),
                  s0, s1)
              for s0, s1 in zip(candidate_tree_state.state,
                                last_candidate_state.state)
          ],
          target=tf.where(
              mcmc_util.left_justified_expand_dims_like(
                  choose_new_state,
                  candidate_tree_state.target),
              candidate_tree_state.target, last_candidate_state.target),
          target_grad_parts=[
              tf.where(  # pylint: disable=g-complex-comprehension
                  mcmc_util.left_justified_expand_dims_like(
                      choose_new_state, grad0),
                  grad0, grad1)
              for grad0, grad1 in zip(candidate_tree_state.target_grad_parts,
                                      last_candidate_state.target_grad_parts)
          ],
          energy=tf.where(
              mcmc_util.left_justified_expand_dims_like(
                  choose_new_state, candidate_tree_state.target),
              candidate_tree_state.energy, last_candidate_state.energy),
          weight=weight_sum)

      for new_candidate_state_temp, old_candidate_state_temp in zip(
          new_candidate_state.state, last_candidate_state.state):
        tensorshape_util.set_shape(new_candidate_state_temp,
                                   old_candidate_state_temp.shape)

      for new_candidate_grad_temp, old_candidate_grad_temp in zip(
          new_candidate_state.target_grad_parts,
          last_candidate_state.target_grad_parts):
        tensorshape_util.set_shape(new_candidate_grad_temp,
                                   old_candidate_grad_temp.shape)

      # Update left right information of the trajectory, and check trajectory
      # level U turn
      tree_otherend_states = tf.nest.map_structure(
          lambda v: tf.where(  # pylint: disable=g-long-lambda
              mcmc_util.left_justified_expand_dims_like(direction, v[1]),
              v[0], v[1]), initial_step_state)

      new_step_state = tf.nest.pack_sequence_as(initial_step_state, [
          tf.stack([  # pylint: disable=g-complex-comprehension
              tf.where(
                  mcmc_util.left_justified_expand_dims_like(direction, left),
                  right, left),
              tf.where(
                  mcmc_util.left_justified_expand_dims_like(direction, left),
                  left, right),
          ], axis=0)
          for left, right in zip(tf.nest.flatten(tree_final_states),
                                 tf.nest.flatten(tree_otherend_states))
      ])

      momentum_tree_cumsum = []
      for p0, p1 in zip(
          initial_step_metastate.momentum_sum, momentum_subtree_cumsum):
        momentum_part_temp = p0 + p1
        tensorshape_util.set_shape(momentum_part_temp, p0.shape)
        momentum_tree_cumsum.append(momentum_part_temp)

      for new_state_temp, old_state_temp in zip(
          tf.nest.flatten(new_step_state),
          tf.nest.flatten(initial_step_state)):
        tensorshape_util.set_shape(new_state_temp, old_state_temp.shape)

      if GENERALIZED_UTURN:
        state_diff = momentum_tree_cumsum
      else:
        state_diff = [s[1] - s[0] for s in new_step_state.state]

      no_u_turns_trajectory = has_not_u_turn(
          state_diff,
          [m[0] for m in new_step_state.momentum],
          [m[1] for m in new_step_state.momentum],
          log_prob_rank=ps.rank_from_shape(batch_shape))

      new_step_metastate = TreeDoublingMetaState(
          candidate_state=new_candidate_state,
          is_accepted=choose_new_state | initial_step_metastate.is_accepted,
          momentum_sum=momentum_tree_cumsum,
          energy_diff_sum=energy_diff_sum,
          continue_tree=continue_tree_final & no_u_turns_trajectory,
          not_divergence=final_not_divergence,
          leapfrog_count=(initial_step_metastate.leapfrog_count +
                          leapfrogs_taken))

      return iter_ + 1, new_step_state, new_step_metastate
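The per-chain direction draw above is the crux of tree doubling: each chain independently grows its trajectory to the left (0) or right (1), and `tf.where` with a left-justified boolean picks the matching endpoint from the stacked [left, right] states. A minimal sketch with made-up shapes:

import tensorflow as tf

direction = tf.cast(
    tf.random.uniform([4], minval=0, maxval=2, dtype=tf.int32), tf.bool)
left = tf.zeros([4, 3])    # leftmost state of each of 4 chains
right = tf.ones([4, 3])    # rightmost state of each of 4 chains
start = tf.where(tf.reshape(direction, [4, 1]), right, left)  # [4, 3]
# Chains with direction == True expand from the right endpoint and keep
# positive step sizes; the rest expand left with negated step sizes.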
Code example #10
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Takes one step of the TransitionKernel.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s).
      previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
        `list` of `Tensor`s representing internal calculations made within the
        previous call to this function (or as returned by `bootstrap_results`).
      seed: Optional, a seed for reproducible sampling.

    Returns:
      next_state: `Tensor` or Python `list` of `Tensor`s representing the
        next state(s) of the Markov chain(s).
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.
        This includes replica states.
    """

        # The code below propagates one step states of shape
        #  [n_replica] + batch_shape + event_shape.
        #
        # The step is done in three parts:
        #  1) Call one_step to transition states via a tempered version of
        #     self.target_log_prob_fn (see _replica_target_log_prob).
        #  2) Permute values in states
        #  3) Update state-dependent values, such as log_probs.
        #
        # We chose to swap states, rather than temperatures, because...
        # (i)  If swapping temperatures, you *still* have to swap log_probs to
        #      determine acceptance, as well as states (for kernel results).
        #      So it's just as difficult to swap temperatures.
        # (ii) If swapping temperatures, you have to take care to swap any
        #      user-supplied temperature-related things (like step size).
        #      A priori, we don't know what else will need to be swapped!
        # (iii) In both cases, the kernel results need to be updated in a
        #      non-trivial manner, so we either special-case or use bootstrap.

        with tf.name_scope(mcmc_util.make_name(self.name, 'remc', 'one_step')):
            # Force a read in case the `inverse_temperatures` is a `tf.Variable`.
            inverse_temperatures = tf.convert_to_tensor(
                previous_kernel_results.inverse_temperatures,
                name='inverse_temperatures')

            target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
                target_log_prob_fn=self.target_log_prob_fn,
                inverse_temperatures=inverse_temperatures,
                untempered_log_prob_fn=self.untempered_log_prob_fn,
                tempered_log_prob_fn=self.tempered_log_prob_fn,
            )
            # TODO(b/159636942): Clean up the helpful error msg after 2020-11-10.
            try:
                inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                    target_log_prob_for_inner_kernel)
            except TypeError as e:
                if 'argument' not in str(e):
                    raise
                raise TypeError(
                    '`ReplicaExchangeMC`s `make_kernel_fn` no longer receives a `seed` '
                    'argument. `TransitionKernel` instances now receive seeds via '
                    '`one_step`.')

            seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
            inner_seed, swap_seed, logu_seed = samplers.split_seed(seed, n=3)
            # Step the inner TransitionKernel.
            [
                pre_swap_replica_states,
                pre_swap_replica_results,
            ] = inner_kernel.one_step(
                previous_kernel_results.post_swap_replica_states,
                previous_kernel_results.post_swap_replica_results,
                seed=inner_seed)

            pre_swap_replica_target_log_prob = _get_field(
                # These are tempered log probs (have been divided by temperature).
                pre_swap_replica_results,
                'target_log_prob')

            dtype = pre_swap_replica_target_log_prob.dtype
            replica_and_batch_shape = ps.shape(
                pre_swap_replica_target_log_prob)
            batch_shape = replica_and_batch_shape[1:]
            replica_and_batch_rank = ps.rank(pre_swap_replica_target_log_prob)
            num_replica = ps.size0(inverse_temperatures)

            inverse_temperatures = mcmc_util.left_justified_broadcast_to(
                inverse_temperatures, replica_and_batch_shape)

            # Now that each replica has done one_step, it is time to consider swaps.

            # swaps.shape = [n_replica], and is a "once only" permutation, meaning it
            # is achievable by a sequence of pairwise permutations, where each element
            # is moved at most once.
            # E.g. if swaps = [1, 0, 2], we will consider swapping temperatures 0 and
            # 1, keeping 2 fixed.  This exact same swap is considered for *every*
            # batch member.  Of course some batch members may accept and some reject.
            try:
                swaps = tf.cast(
                    self.swap_proposal_fn(  # pylint: disable=not-callable
                        num_replica,
                        batch_shape=batch_shape,
                        seed=swap_seed,
                        step_count=previous_kernel_results.step_count),
                    dtype=tf.int32)
            except TypeError as e:
                if 'step_count' not in str(e):
                    raise
                warnings.warn(
                    'The `swap_proposal_fn` given to ReplicaExchangeMC did not accept '
                    'the `step_count` argument. Falling back to omitting the '
                    'argument. This fallback will be removed after 24-Oct-2020.'
                )
                swaps = tf.cast(
                    self.swap_proposal_fn(  # pylint: disable=not-callable
                        num_replica,
                        batch_shape=batch_shape,
                        seed=swap_seed),
                    dtype=tf.int32)

            null_swaps = mcmc_util.left_justified_expand_dims_like(
                tf.range(num_replica, dtype=swaps.dtype), swaps)
            swaps = _maybe_embed_swaps_validation(swaps, null_swaps,
                                                  self.validate_args)

            # Un-temper the log probs for use in the swap acceptance ratio.
            if self.tempered_log_prob_fn is None:
                # Efficient way of re-evaluating target_log_prob_fn on the
                # pre_swap_replica_states.
                untempered_energy_ignoring_ulp = (
                    # Since untempered_log_prob_fn is None, we may assume
                    # inverse_temperatures > 0 (else the target is improper).
                    pre_swap_replica_target_log_prob / inverse_temperatures)
            else:
                # The untempered_log_prob_fn does not factor into the acceptance ratio.
                # Proof: Suppose the tempered target is
                #   p_k(x) = f(x)^{beta_k} g(x),
                # So f(x) is tempered, and g(x) is not.  Then, the acceptance ratio for
                # a 1 <--> 2 swap is...
                #   (p_1(x_2) p_2(x_1)) / (p_1(x_1) p_2(x_2))
                # which depends only on f(x), since terms involving g(x) cancel.
                untempered_energy_ignoring_ulp = self.tempered_log_prob_fn(
                    *pre_swap_replica_states)

            # Since `swaps` is its own inverse permutation we automatically know the
            # swap counterpart: range(num_replica). We use this idea to compute the
            # acceptance in a vectorized manner at the cost of wasting roughly half
            # our computation. Although we could use `unique` to solve this problem,
            # we expect the cost of `unique` to be higher than the dozens of wasted
            # arithmetic calculations. Worse, it'd mean we need dynamic sized Tensors
            # (eg, using `tf.where(bool)`) and so we wouldn't be able to XLA compile.

            # Note: diffs would normally be "proposed - current" however energy is
            # flipped since `energy == -log_prob`.
            # Note: The untempered_log_prob_fn (if provided) is not included in
            # untempered_pre_swap_replica_target_log_prob, and hence does not factor
            # into energy_diff. Why? Because, it cancels out in the acceptance ratio.
            energy_diff = (untempered_energy_ignoring_ulp -
                           mcmc_util.index_remapping_gather(
                               untempered_energy_ignoring_ulp,
                               swaps,
                               name='gather_swap_tlp'))
            swapped_inverse_temperatures = mcmc_util.index_remapping_gather(
                inverse_temperatures, swaps, name='gather_swap_temps')
            inverse_temp_diff = swapped_inverse_temperatures - inverse_temperatures

            # If i and j are swapping, log_accept_ratio[i] and
            # log_accept_ratio[j] are equal.
            log_accept_ratio = (energy_diff *
                                mcmc_util.left_justified_expand_dims_to(
                                    inverse_temp_diff, replica_and_batch_rank))

            log_accept_ratio = tf.where(tf.math.is_finite(log_accept_ratio),
                                        log_accept_ratio,
                                        tf.constant(-np.inf, dtype=dtype))

            # Produce log[Uniform] draws that are identical at swapped indices.
            log_uniform = tf.math.log(
                samplers.uniform(shape=replica_and_batch_shape,
                                 dtype=dtype,
                                 seed=logu_seed))
            anchor_swaps = tf.minimum(swaps, null_swaps)
            log_uniform = mcmc_util.index_remapping_gather(
                log_uniform, anchor_swaps)

            is_swap_accepted_mask = tf.less(log_uniform,
                                            log_accept_ratio,
                                            name='is_swap_accepted_mask')

            def _swap_tensor(x):
                return mcmc_util.choose(
                    is_swap_accepted_mask,
                    mcmc_util.index_remapping_gather(x, swaps), x)

            post_swap_replica_states = [
                _swap_tensor(s) for s in pre_swap_replica_states
            ]

            expanded_null_swaps = mcmc_util.left_justified_broadcast_to(
                null_swaps, replica_and_batch_shape)
            is_swap_proposed = _compute_swap_notmatrix(
                # Broadcast both so they have shape [num_replica] + batch_shape.
                # This (i) makes them have same shape as is_swap_accepted, and
                # (ii) keeps shape consistent if someday swaps has a batch shape.
                expanded_null_swaps,
                mcmc_util.left_justified_broadcast_to(swaps,
                                                      replica_and_batch_shape))

            # To get is_swap_accepted in ordered position, we use
            # _compute_swap_notmatrix on current and next replica positions.
            post_swap_replica_position = _swap_tensor(expanded_null_swaps)

            is_swap_accepted = _compute_swap_notmatrix(
                post_swap_replica_position, expanded_null_swaps)

            if self._state_includes_replicas:
                post_swap_states = post_swap_replica_states
            else:
                post_swap_states = [s[0] for s in post_swap_replica_states]

            post_swap_replica_results = _set_swapped_fields_to_nan(
                _swap_log_prob_and_maybe_grads(pre_swap_replica_results,
                                               post_swap_replica_states,
                                               inner_kernel))

            if mcmc_util.is_list_like(current_state):
                # We *always* canonicalize the states in the kernel results.
                states = post_swap_states
            else:
                states = post_swap_states[0]

            post_swap_kernel_results = ReplicaExchangeMCKernelResults(
                post_swap_replica_states=post_swap_replica_states,
                pre_swap_replica_results=pre_swap_replica_results,
                post_swap_replica_results=post_swap_replica_results,
                is_swap_proposed=is_swap_proposed,
                is_swap_accepted=is_swap_accepted,
                is_swap_proposed_adjacent=_sub_diag(is_swap_proposed),
                is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
                # Store the original pkr.inverse_temperatures in case it's a
                # `tf.Variable`.
                inverse_temperatures=previous_kernel_results.
                inverse_temperatures,
                swaps=swaps,
                step_count=previous_kernel_results.step_count + 1,
                seed=seed,
            )

            return states, post_swap_kernel_results
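For context, a hedged end-to-end usage sketch of the kernel whose `one_step` appears above. Argument names follow the public `tfp.mcmc.ReplicaExchangeMC` API of this era; the target and tuning values are stand-ins.

import tensorflow_probability as tfp
tfd = tfp.distributions

target = tfd.Normal(loc=0., scale=1.)
remc = tfp.mcmc.ReplicaExchangeMC(
    target_log_prob_fn=target.log_prob,
    inverse_temperatures=[1.0, 0.5, 0.25],
    make_kernel_fn=lambda tlp_fn: tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=tlp_fn, step_size=0.5, num_leapfrog_steps=3))
samples = tfp.mcmc.sample_chain(
    num_results=100, num_burnin_steps=50, current_state=1.0,
    kernel=remc, trace_fn=None, seed=[7, 42])
# samples.shape == [100]: draws from the beta == 1.0 (untempered) replica.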
Code example #11
 def _replica_target_log_prob(*x):
     tlp = target_log_prob_fn(*x)
     return tf.cast(mcmc_util.left_justified_expand_dims_like(
         inverse_temperatures, tlp),
                    dtype=tlp.dtype) * tlp
Code example #12
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Takes one step of the TransitionKernel.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s).
      previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
        `list` of `Tensor`s representing internal calculations made within the
        previous call to this function (or as returned by `bootstrap_results`).
      seed: Optional, a seed for reproducible sampling.

    Returns:
      next_state: `Tensor` or Python `list` of `Tensor`s representing the
        next state(s) of the Markov chain(s).
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.
        This includes replica states.
    """

        # The code below propagates one step states of shape
        #  [n_replica] + batch_shape + event_shape.
        #
        # The step is done in three parts:
        #  1) Call one_step to transition states via a tempered version of
        #     self.target_log_prob_fn (see _replica_target_log_prob).
        #  2) Permute values in states
        #  3) Update state-dependent values, such as log_probs.
        #
        # We chose to swap states, rather than temperatures, because...
        # (i)  If swapping temperatures, you *still* have to swap log_probs to
        #      determine acceptance, as well as states (for kernel results).
        #      So it's just as difficult to swap temperatures.
        # (ii) If swapping temperatures, you have to take care to swap any
        #      user-supplied temperature-related things (like step size).
        #      A priori, we don't know what else will need to be swapped!
        # (iii) In both cases, the kernel results need to be updated in a
        #      non-trivial manner, so we either special-case or use bootstrap.

        with tf.name_scope(mcmc_util.make_name(self.name, 'remc', 'one_step')):
            # Force a read in case the `inverse_temperatures` is a `tf.Variable`.
            inverse_temperatures = tf.convert_to_tensor(
                previous_kernel_results.inverse_temperatures,
                name='inverse_temperatures')

            target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
                self.target_log_prob_fn, inverse_temperatures)
            # Seed handling complexity is due to users possibly expecting an old-style
            # stateful seed to be passed to `self.make_kernel_fn`, and no seed
            # expected by `kernel.one_step`.
            # In other words:
            # - We try `make_kernel_fn` without a seed first; this is the future. The
            #   kernel will receive a seed later, as part of `one_step`.
            # - If the user code doesn't like that (Python complains about a missing
            #   required argument), we warn and fall back to the previous behavior.
            try:
                inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                    target_log_prob_for_inner_kernel)
            except TypeError as e:
                if 'argument' not in str(e):
                    raise
                warnings.warn(
                    'The `seed` argument to `ReplicaExchangeMC`s `make_kernel_fn` is '
                    'deprecated. `TransitionKernel` instances now receive seeds via '
                    '`one_step`.')
                inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                    target_log_prob_for_inner_kernel, self._seed_stream())

            # Now that we've constructed the TransitionKernel instance:
            # - If we were given a seed, we sanitize it to stateless and pass along
            #   to `kernel.one_step`. If it doesn't like that, we crash and propagate
            #   the error.  Rationale: The contract is stateless sampling given
            #   seed, and doing otherwise would not meet it.
            # - If not given a seed, we don't pass one along. This avoids breaking
            #   underlying kernels lacking a `seed` arg on `one_step`.
            # TODO(b/159636942): Clean up after 2020-09-20.
            if seed is not None:
                seed = samplers.sanitize_seed(seed)
                inner_seed, swap_seed, logu_seed = samplers.split_seed(
                    seed, n=3, salt='remc_one_step')
                inner_kwargs = dict(seed=inner_seed)
            else:
                if self._seed_stream.original_seed is not None:
                    warnings.warn(mcmc_util.SEED_CTOR_ARG_DEPRECATION_MSG)
                inner_kwargs = {}
                swap_seed, logu_seed = samplers.split_seed(self._seed_stream())
            [
                pre_swap_replica_states,
                pre_swap_replica_results,
            ] = inner_kernel.one_step(
                previous_kernel_results.post_swap_replica_states,
                previous_kernel_results.post_swap_replica_results,
                **inner_kwargs)

            pre_swap_replica_target_log_prob = _get_field(
                # These are tempered log probs (have been divided by temperature).
                pre_swap_replica_results,
                'target_log_prob')

            dtype = pre_swap_replica_target_log_prob.dtype
            replica_and_batch_shape = ps.shape(
                pre_swap_replica_target_log_prob)
            batch_shape = replica_and_batch_shape[1:]
            replica_and_batch_rank = ps.rank(pre_swap_replica_target_log_prob)
            num_replica = ps.size0(inverse_temperatures)

            inverse_temperatures = mcmc_util.left_justified_broadcast_to(
                inverse_temperatures, replica_and_batch_shape)

            # Now that each replica has done one_step, it is time to consider swaps.

            # swaps.shape = [n_replica], and is a "once only" permutation, meaning it
            # is achievable by a sequence of pairwise permutations, where each element
            # is moved at most once.
            # E.g. if swaps = [1, 0, 2], we will consider swapping temperatures 0 and
            # 1, keeping 2 fixed.  This exact same swap is considered for *every*
            # batch member.  Of course some batch members may accept and some reject.
            try:
                swaps = tf.cast(
                    self.swap_proposal_fn(  # pylint: disable=not-callable
                        num_replica,
                        batch_shape=batch_shape,
                        seed=swap_seed,
                        step_count=previous_kernel_results.step_count),
                    dtype=tf.int32)
            except TypeError as e:
                if 'step_count' not in str(e):
                    raise
                warnings.warn(
                    'The `swap_proposal_fn` given to ReplicaExchangeMC did not accept '
                    'the `step_count` argument. Falling back to omitting the '
                    'argument. This fallback will be removed after 24-Oct-2020.'
                )
                swaps = tf.cast(
                    self.swap_proposal_fn(  # pylint: disable=not-callable
                        num_replica,
                        batch_shape=batch_shape,
                        seed=swap_seed),
                    dtype=tf.int32)

            null_swaps = mcmc_util.left_justified_expand_dims_like(
                tf.range(num_replica, dtype=swaps.dtype), swaps)
            swaps = _maybe_embed_swaps_validation(swaps, null_swaps,
                                                  self.validate_args)

            # Un-temper the log probs.  E.g., for replica k, at point x_k, this is
            # Log[p(x_k)], and *not* Log[p_x(x_k)] = Log[p(x_k)] * beta_k.
            untempered_pre_swap_replica_target_log_prob = (
                pre_swap_replica_target_log_prob / inverse_temperatures)

            # Since `swaps` is its own inverse permutation we automatically know the
            # swap counterpart: range(num_replica). We use this idea to compute the
            # acceptance in a vectorized manner at the cost of wasting roughly half
            # our computation. Although we could use `unique` to solve this problem,
            # we expect the cost of `unique` to be higher than the dozens of wasted
            # arithmetic calculations. Worse, it'd mean we need dynamic sized Tensors
            # (eg, using `tf.where(bool)`) and so we wouldn't be able to XLA compile.

            # Note: diffs would normally be "proposed - current" however energy is
            # flipped since `energy == -log_prob`.
            energy_diff = (untempered_pre_swap_replica_target_log_prob -
                           mcmc_util.index_remapping_gather(
                               untempered_pre_swap_replica_target_log_prob,
                               swaps,
                               name='gather_swap_tlp'))
            swapped_inverse_temperatures = mcmc_util.index_remapping_gather(
                inverse_temperatures, swaps, name='gather_swap_temps')
            inverse_temp_diff = swapped_inverse_temperatures - inverse_temperatures

            # If i and j are swapping, log_accept_ratio[i] and
            # log_accept_ratio[j] are equal.
            log_accept_ratio = (energy_diff *
                                mcmc_util.left_justified_expand_dims_to(
                                    inverse_temp_diff, replica_and_batch_rank))

            log_accept_ratio = tf.where(tf.math.is_finite(log_accept_ratio),
                                        log_accept_ratio,
                                        tf.constant(-np.inf, dtype=dtype))

            # Produce Log[Uniform] draws that are identical at swapped indices.
            log_uniform = tf.math.log(
                samplers.uniform(shape=replica_and_batch_shape,
                                 dtype=dtype,
                                 seed=logu_seed))
            anchor_swaps = tf.minimum(swaps, null_swaps)
            log_uniform = mcmc_util.index_remapping_gather(
                log_uniform, anchor_swaps)

            is_swap_accepted_mask = tf.less(log_uniform,
                                            log_accept_ratio,
                                            name='is_swap_accepted_mask')

            def _swap_tensor(x):
                return mcmc_util.choose(
                    is_swap_accepted_mask,
                    mcmc_util.index_remapping_gather(x, swaps), x)

            post_swap_replica_states = [
                _swap_tensor(s) for s in pre_swap_replica_states
            ]

            expanded_null_swaps = mcmc_util.left_justified_broadcast_to(
                null_swaps, replica_and_batch_shape)
            is_swap_proposed = _compute_swap_notmatrix(
                # Broadcast both so they have shape [num_replica] + batch_shape.
                # This (i) makes them have same shape as is_swap_accepted, and
                # (ii) keeps shape consistent if someday swaps has a batch shape.
                expanded_null_swaps,
                mcmc_util.left_justified_broadcast_to(swaps,
                                                      replica_and_batch_shape))

            # To get is_swap_accepted in ordered position, we use
            # _compute_swap_notmatrix on current and next replica positions.
            post_swap_replica_position = _swap_tensor(expanded_null_swaps)

            is_swap_accepted = _compute_swap_notmatrix(
                post_swap_replica_position, expanded_null_swaps)

            if self._state_includes_replicas:
                post_swap_states = post_swap_replica_states
            else:
                post_swap_states = [s[0] for s in post_swap_replica_states]

            post_swap_replica_results = _make_post_swap_replica_results(
                pre_swap_replica_results, inverse_temperatures,
                swapped_inverse_temperatures, is_swap_accepted_mask,
                _swap_tensor)

            if mcmc_util.is_list_like(current_state):
                # We *always* canonicalize the states in the kernel results.
                states = post_swap_states
            else:
                states = post_swap_states[0]

            post_swap_kernel_results = ReplicaExchangeMCKernelResults(
                post_swap_replica_states=post_swap_replica_states,
                pre_swap_replica_results=pre_swap_replica_results,
                post_swap_replica_results=post_swap_replica_results,
                is_swap_proposed=is_swap_proposed,
                is_swap_accepted=is_swap_accepted,
                is_swap_proposed_adjacent=_sub_diag(is_swap_proposed),
                is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
                # Store the original pkr.inverse_temperatures in case it's a
                # `tf.Variable`.
                inverse_temperatures=previous_kernel_results.
                inverse_temperatures,
                swaps=swaps,
                step_count=previous_kernel_results.step_count + 1,
                seed=samplers.zeros_seed() if seed is None else seed,
            )

            return states, post_swap_kernel_results
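The seed plumbing is what distinguishes this version: a single stateless seed is sanitized and split into independent per-purpose seeds, so the inner kernel step, the swap proposal, and the uniform draw are each reproducible. A sketch mirroring the calls above, using the internal `samplers` module this file imports:

from tensorflow_probability.python.internal import samplers

seed = samplers.sanitize_seed([1, 2])   # canonicalize to a stateless seed
inner_seed, swap_seed, logu_seed = samplers.split_seed(
    seed, n=3, salt='remc_one_step')
# Each output is itself a valid stateless seed; giving every consumer its
# own split avoids correlating the inner step with the swap decisions.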
Code example #13
    def one_step(self, current_state, previous_kernel_results):
        """Takes one step of the TransitionKernel.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s).
      previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
        `list` of `Tensor`s representing internal calculations made within the
        previous call to this function (or as returned by `bootstrap_results`).

    Returns:
      next_state: `Tensor` or Python `list` of `Tensor`s representing the
        next state(s) of the Markov chain(s).
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.
        This includes replica states.
    """
        # The code below propagates one step states of shape
        #  [n_replica] + batch_shape + event_shape.
        #
        # The step is done in three parts:
        #  1) Call one_step to transition states via a tempered version of
        #     self.target_log_prob_fn (see _replica_target_log_prob).
        #  2) Permute values in states
        #  3) Update state-dependent values, such as log_probs.
        #
        # We chose to swap states, rather than temperatures, because...
        # (i)  If swapping temperatures, you *still* have to swap log_probs to
        #      determine acceptance, as well as states (for kernel results).
        #      So it's just as difficult to swap temperatures.
        # (ii) If swapping temperatures, you have to take care to swap any
        #      user-supplied temperature-related things (like step size).
        #      A priori, we don't know what else will need to be swapped!
        # (iii) In both cases, the kernel results need to be updated in a
        #      non-trivial manner, so we either special-case or use bootstrap.

        with tf.name_scope(mcmc_util.make_name(self.name, 'remc', 'one_step')):
            # Force a read in case the `inverse_temperatures` is a `tf.Variable`.
            inverse_temperatures = tf.convert_to_tensor(
                previous_kernel_results.inverse_temperatures,
                name='inverse_temperatures')

            inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                _make_replica_target_log_prob_fn(self.target_log_prob_fn,
                                                 inverse_temperatures),
                self._seed_stream())

            [
                pre_swap_replica_states,
                pre_swap_replica_results,
            ] = inner_kernel.one_step(
                previous_kernel_results.post_swap_replica_states,
                previous_kernel_results.post_swap_replica_results)

            pre_swap_replica_target_log_prob = _get_field(
                # These are tempered log probs (have been divided by temperature).
                pre_swap_replica_results,
                'target_log_prob')

            dtype = pre_swap_replica_target_log_prob.dtype
            replica_and_batch_shape = prefer_static.shape(
                pre_swap_replica_target_log_prob)
            batch_shape = replica_and_batch_shape[1:]
            replica_and_batch_rank = prefer_static.rank(
                pre_swap_replica_target_log_prob)
            num_replica = prefer_static.size0(inverse_temperatures)

            inverse_temperatures = mcmc_util.left_justified_broadcast_to(
                inverse_temperatures, replica_and_batch_shape)

            # Now that each replica has done one_step, it is time to consider swaps.

            # swaps.shape = [n_replica], and is a "once only" permutation, meaning it
            # is achievable by a sequence of pairwise permutations, where each element
            # is moved at most once.
            # E.g. if swaps = [1, 0, 2], we will consider swapping temperatures 0 and
            # 1, keeping 2 fixed.  This exact same swap is considered for *every*
            # batch member.  Of course some batch members may accept and some reject.
            swaps = tf.cast(
                self.swap_proposal_fn(  # pylint: disable=not-callable
                    num_replica,
                    batch_shape=batch_shape,
                    seed=self._seed_stream()),
                dtype=tf.int32)
            null_swaps = mcmc_util.left_justified_expand_dims_like(
                tf.range(num_replica, dtype=swaps.dtype), swaps)
            swaps = _maybe_embed_swaps_validation(swaps, null_swaps,
                                                  self.validate_args)

            # Un-temper the log probs.  E.g., for replica k, at point x_k, this is
            # Log[p(x_k)], and *not* Log[p_x(x_k)] = Log[p(x_k)] * beta_k.
            untempered_pre_swap_replica_target_log_prob = (
                pre_swap_replica_target_log_prob / inverse_temperatures)

            # Since `swaps` is its own inverse permutation we automatically know the
            # swap counterpart: range(num_replica). We use this idea to compute the
            # acceptance in a vectorized manner at the cost of wasting roughly half
            # our computation. Although we could use `unique` to solve this problem,
            # we expect the cost of `unique` to be higher than the dozens of wasted
            # arithmetic calculations. Worse, it'd mean we need dynamic sized Tensors
            # (eg, using `tf.where(bool)`) and so we wouldn't be able to XLA compile.

            # Note: diffs would normally be "proposed - current" however energy is
            # flipped since `energy == -log_prob`.
            energy_diff = (untempered_pre_swap_replica_target_log_prob -
                           mcmc_util.index_remapping_gather(
                               untempered_pre_swap_replica_target_log_prob,
                               swaps,
                               name='gather_swap_tlp'))
            swapped_inverse_temperatures = mcmc_util.index_remapping_gather(
                inverse_temperatures, swaps, name='gather_swap_temps')
            inverse_temp_diff = swapped_inverse_temperatures - inverse_temperatures

            # If i and j are swapping, log_accept_ratio[i] and
            # log_accept_ratio[j] are equal.
            log_accept_ratio = (energy_diff *
                                mcmc_util.left_justified_expand_dims_to(
                                    inverse_temp_diff, replica_and_batch_rank))

            log_accept_ratio = tf.where(tf.math.is_finite(log_accept_ratio),
                                        log_accept_ratio,
                                        tf.constant(-np.inf, dtype=dtype))

            # Produce Log[Uniform] draws that are identical at swapped indices.
            log_uniform = tf.math.log(
                tf.random.uniform(shape=replica_and_batch_shape,
                                  dtype=dtype,
                                  seed=self._seed_stream()))
            anchor_swaps = tf.minimum(swaps, null_swaps)
            log_uniform = mcmc_util.index_remapping_gather(
                log_uniform, anchor_swaps)

            is_swap_accepted_mask = tf.less(log_uniform,
                                            log_accept_ratio,
                                            name='is_swap_accepted_mask')

            def _swap_tensor(x):
                return mcmc_util.choose(
                    is_swap_accepted_mask,
                    mcmc_util.index_remapping_gather(x, swaps), x)

            post_swap_replica_states = [
                _swap_tensor(s) for s in pre_swap_replica_states
            ]

            expanded_null_swaps = mcmc_util.left_justified_broadcast_to(
                null_swaps, replica_and_batch_shape)
            is_swap_proposed = _compute_swap_notmatrix(
                # Broadcast both so they have shape [num_replica] + batch_shape.
                # This (i) makes them have same shape as is_swap_accepted, and
                # (ii) keeps shape consistent if someday swaps has a batch shape.
                expanded_null_swaps,
                mcmc_util.left_justified_broadcast_to(swaps,
                                                      replica_and_batch_shape))

            # To get is_swap_accepted in ordered position, we use
            # _compute_swap_notmatrix on current and next replica positions.
            post_swap_replica_position = _swap_tensor(expanded_null_swaps)

            is_swap_accepted = _compute_swap_notmatrix(
                post_swap_replica_position, expanded_null_swaps)

            post_swap_states = [s[0] for s in post_swap_replica_states]

            post_swap_replica_results = _make_post_swap_replica_results(
                pre_swap_replica_results, inverse_temperatures,
                swapped_inverse_temperatures, is_swap_accepted_mask,
                _swap_tensor)

            if mcmc_util.is_list_like(current_state):
                # We *always* canonicalize the states in the kernel results.
                states = post_swap_states
            else:
                states = post_swap_states[0]

            post_swap_kernel_results = ReplicaExchangeMCKernelResults(
                post_swap_replica_states=post_swap_replica_states,
                pre_swap_replica_results=pre_swap_replica_results,
                post_swap_replica_results=post_swap_replica_results,
                is_swap_proposed=is_swap_proposed,
                is_swap_accepted=is_swap_accepted,
                is_swap_proposed_adjacent=_sub_diag(is_swap_proposed),
                is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
                # Store the original pkr.inverse_temperatures in case it's a
                # `tf.Variable`.
                inverse_temperatures=previous_kernel_results.
                inverse_temperatures,
                swaps=swaps,
            )

            return states, post_swap_kernel_results