Example #1
    def adjacent_swaps(num_replica,
                       batch_shape=(),
                       step_count=None,
                       seed=None):
        """Make random shuffle using only one time swaps."""
        del step_count  # Unused for this function.
        # Note: `name` and `prob_swap` are free variables here, closed over
        # from the enclosing swap-proposal factory function this excerpt
        # comes from.
        with tf.name_scope(name or 'adjacent_swaps'):
            parity_seed, proposal_seed = samplers.split_seed(seed)
            # u selects parity.  E.g.,
            #  u==False ==> [1, 0, 3, 2, 4] even parity swaps
            #  u==True ==>  [0, 2, 1, 4, 3] odd parity swaps
            # If there are only 2 replicas, then the "True" swaps are null
            # swaps...which would contradict the user provided `prob_swap`.
            # So special case num_replica==2, forcing u==False in this case.
            u_shape = ps.concat(
                (ps.ones(1, dtype=tf.int32), ps.cast(batch_shape, tf.int32)),
                axis=0)
            u = samplers.uniform(u_shape, seed=parity_seed) < 0.5
            u = tf.where(num_replica > 2, u, False)

            x = mcmc_util.left_justified_expand_dims_to(
                ps.range(num_replica, dtype=tf.int64),
                rank=ps.size(u_shape))
            y = tf.where(tf.equal(x % 2, tf.cast(u, dtype=tf.int64)), x + 1,
                         x - 1)
            y = tf.clip_by_value(y, 0, num_replica - 1)
            # TODO(b/142689785): Consider using tf.cond to return an empty
            # list, and then, in REMC, using a tf.cond for short-circuiting.
            return tf.where(
                samplers.uniform(batch_shape, seed=proposal_seed) < prob_swap,
                y, x)
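
The parity trick in the comments above is easy to sanity-check outside of TF. Below is a minimal NumPy sketch (illustrative only, not library code) reproducing the two permutations named in the comments:

    import numpy as np

    def parity_swap(num_replica, odd_parity):
        # Pair each index with a neighbor; clipping sends endpoints without a
        # partner back to themselves (a null swap).
        x = np.arange(num_replica)
        y = np.where(x % 2 == int(odd_parity), x + 1, x - 1)
        return np.clip(y, 0, num_replica - 1)

    print(parity_swap(5, odd_parity=False))  # [1 0 3 2 4]: swaps (0,1), (2,3).
    print(parity_swap(5, odd_parity=True))   # [0 2 1 4 3]: swaps (1,2), (3,4).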
Example #2
    def adjacent_swaps(num_replica, batch_shape=(), seed=None):
        """Make random shuffle using only one time swaps."""
        # Note: `name` and `prob_swap` are free variables closed over from the
        # enclosing swap-proposal factory function this excerpt comes from.
        with tf.name_scope(name or 'adjacent_swaps'):
            seed = SeedStream(seed, salt='random_adjacent_shuffle')
            # u selects parity.  E.g.,
            #  u==True ==>  [0, 2, 1, 4, 3] odd parity swaps
            #  u==False ==> [1, 0, 3, 2, 4] even parity swaps
            # If there are only 2 replicas, then the "True" swaps are null
            # swaps...which would contradict the user provided `prob_swap`.
            # So special case num_replica==2, forcing u==False in this case.
            u_shape = prefer_static.concat(
                (tf.ones(1, dtype=tf.int32), tf.cast(batch_shape, tf.int32)),
                axis=0)
            u = tf.random.uniform(u_shape, seed=seed()) < 0.5
            u = tf.where(num_replica > 2, u, False)

            x = mcmc_util.left_justified_expand_dims_to(
                tf.range(num_replica, dtype=tf.int64),
                rank=prefer_static.size(u_shape))
            y = tf.where(tf.equal(x % 2, tf.cast(u, dtype=tf.int64)), x + 1,
                         x - 1)
            y = tf.clip_by_value(y, 0, num_replica - 1)
            # TODO(b/142689785): Consider using tf.cond to return an empty
            # list, and then, in REMC, using a tf.cond for short-circuiting.
            return tf.where(
                tf.random.uniform(batch_shape, seed=seed()) < prob_swap, y, x)
Example #3
    def even_odd_swaps(num_replica,
                       batch_shape=(),
                       step_count=None,
                       seed=None):
        """Make deterministic even_odd one time swaps."""
        if step_count is None:
            raise ValueError('`step_count` must be supplied. Found `None`.')
        del seed  # Unused for this function.
        # Note: `name` and `swap_frequency` are free variables closed over from
        # the enclosing swap-proposal factory function this excerpt comes from.
        with tf.name_scope(name or 'even_odd_swaps'):
            # Period is 1 / frequency, and we want period = Inf if frequency = 0.
            # safe_swap_period is the correct swap period in case swap_frequency > 0.
            # If swap_frequency == 0, safe_swap_period is set to 1 (to avoid integer
            # div by zero below). We will hard-set this case to "null swap."
            swap_freq = tf.convert_to_tensor(swap_frequency,
                                             name='swap_frequency')
            safe_swap_period = tf.cast(
                tf.where(swap_freq > 0,
                         tf.math.ceil(tf.math.reciprocal_no_nan(swap_freq)),
                         1),
                # Although period = 1 / frequency may have roundoff error, and result
                # in a period different than what the user intended, the
                # user will end up with a single integer period, and thus well defined
                # deterministic swaps.
                tf.int32,
            )

            # u selects parity.  E.g.,
            #  u==False ==> [1, 0, 3, 2, 4] even parity swaps
            #  u==True ==>  [0, 2, 1, 4, 3] odd parity swaps
            # If there are 2 replicas, then the "True" swaps are null
            # swaps...which would contradict the user provided `swap_frequency`.
            # So special case num_replica==2, forcing u==False in this case.
            u_shape = ps.concat(
                (ps.ones(1, dtype=tf.int32), ps.cast(batch_shape, tf.int32)),
                axis=0)
            u = tf.fill(u_shape,
                        tf.cast((step_count // safe_swap_period) % 2, tf.bool))
            u = tf.where(num_replica > 2, u, False)

            x = mcmc_util.left_justified_expand_dims_to(
                tf.range(num_replica, dtype=tf.int64),
                rank=ps.size(u_shape))
            y = tf.where(tf.equal(x % 2, tf.cast(u, dtype=tf.int64)), x + 1,
                         x - 1)
            y = tf.clip_by_value(y, 0, num_replica - 1)
            # TODO(b/142689785): Consider using tf.cond to return an empty
            # list, and then, in REMC, using a tf.cond for short-circuiting.
            return tf.where(
                (tf.cast(step_count % safe_swap_period, tf.bool)
                 | tf.math.equal(swap_freq, 0)),
                x,  # Don't swap
                y,  # Swap
            )
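
The schedule encoded above can be restated without TF machinery: swaps fire once per period = ceil(1 / swap_frequency), the parity alternates each time a swap fires, and a frequency of zero means no swap ever fires. A minimal pure-Python sketch of that reading (not library code):

    import math

    def swap_parity(step_count, swap_frequency):
        """Return None for a null swap, else 0 (even parity) or 1 (odd)."""
        if swap_frequency <= 0:
            return None  # frequency == 0 is hard-set to "null swap".
        period = int(math.ceil(1.0 / swap_frequency))
        if step_count % period != 0:
            return None  # Off-period step: no swap proposed.
        return (step_count // period) % 2

    # With frequency 0.5 (period 2), even steps alternate parity; odd steps
    # propose nothing.
    print([swap_parity(s, 0.5) for s in range(6)])  # [0, None, 1, None, 0, None]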
Example #4
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Takes one step of the TransitionKernel.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s).
      previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
        `list` of `Tensor`s representing internal calculations made within the
        previous call to this function (or as returned by `bootstrap_results`).
      seed: Optional, a seed for reproducible sampling.

    Returns:
      next_state: `Tensor` or Python `list` of `Tensor`s representing the
        next state(s) of the Markov chain(s).
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.
        This includes the replica states.
    """

        # The code below propagates one step states of shape
        #  [n_replica] + batch_shape + event_shape.
        #
        # The step is done in three parts:
        #  1) Call one_step to transition states via a tempered version of
        #     self.target_log_prob_fn (see _replica_target_log_prob).
        #  2) Permute values in states
        #  3) Update state-dependent values, such as log_probs.
        #
        # We chose to swap states, rather than temperatures, because...
        # (i)  If swapping temperatures, you *still* have to swap log_probs to
        #      determine acceptance, as well as states (for kernel results).
        #      So it's just as difficult to swap temperatures.
        # (ii) If swapping temperatures, you have to take care to swap any user-
        #      supplied temperature related things (like step size).
        #      A-priori, we don't know what else will need to be swapped!
        # (iii) In both cases, the kernel results need to be updated in a non-trivial
        #      manner....so we either special-case, or use bootstrap.

        with tf.name_scope(mcmc_util.make_name(self.name, 'remc', 'one_step')):
            # Force a read in case the `inverse_temperatures` is a `tf.Variable`.
            inverse_temperatures = tf.convert_to_tensor(
                previous_kernel_results.inverse_temperatures,
                name='inverse_temperatures')

            target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
                target_log_prob_fn=self.target_log_prob_fn,
                inverse_temperatures=inverse_temperatures,
                untempered_log_prob_fn=self.untempered_log_prob_fn,
                tempered_log_prob_fn=self.tempered_log_prob_fn,
            )
            # TODO(b/159636942): Clean up the helpful error msg after 2020-11-10.
            try:
                inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                    target_log_prob_for_inner_kernel)
            except TypeError as e:
                if 'argument' not in str(e):
                    raise
                raise TypeError(
                    '`ReplicaExchangeMC`s `make_kernel_fn` no longer receives a `seed` '
                    'argument. `TransitionKernel` instances now receive seeds via '
                    '`one_step`.')

            seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
            inner_seed, swap_seed, logu_seed = samplers.split_seed(seed, n=3)
            # Step the inner TransitionKernel.
            [
                pre_swap_replica_states,
                pre_swap_replica_results,
            ] = inner_kernel.one_step(
                previous_kernel_results.post_swap_replica_states,
                previous_kernel_results.post_swap_replica_results,
                seed=inner_seed)

            pre_swap_replica_target_log_prob = _get_field(
                # These are tempered log probs (have been divided by temperature).
                pre_swap_replica_results,
                'target_log_prob')

            dtype = pre_swap_replica_target_log_prob.dtype
            replica_and_batch_shape = ps.shape(
                pre_swap_replica_target_log_prob)
            batch_shape = replica_and_batch_shape[1:]
            replica_and_batch_rank = ps.rank(pre_swap_replica_target_log_prob)
            num_replica = ps.size0(inverse_temperatures)

            inverse_temperatures = mcmc_util.left_justified_broadcast_to(
                inverse_temperatures, replica_and_batch_shape)

            # Now that each replica has done one_step, it is time to consider swaps.

            # swaps.shape = [n_replica], and is a "once only" permutation, meaning it
            # is achievable by a sequence of pairwise permutations, where each element
            # is moved at most once.
            # E.g. if swaps = [1, 0, 2], we will consider swapping temperatures 0 and
            # 1, keeping 2 fixed.  This exact same swap is considered for *every*
            # batch member.  Of course some batch members may accept and some reject.
            try:
                swaps = tf.cast(
                    self.swap_proposal_fn(  # pylint: disable=not-callable
                        num_replica,
                        batch_shape=batch_shape,
                        seed=swap_seed,
                        step_count=previous_kernel_results.step_count),
                    dtype=tf.int32)
            except TypeError as e:
                if 'step_count' not in str(e):
                    raise
                warnings.warn(
                    'The `swap_proposal_fn` given to ReplicaExchangeMC did not accept '
                    'the `step_count` argument. Falling back to omitting the '
                    'argument. This fallback will be removed after 24-Oct-2020.'
                )
                swaps = tf.cast(
                    self.swap_proposal_fn(  # pylint: disable=not-callable
                        num_replica,
                        batch_shape=batch_shape,
                        seed=swap_seed),
                    dtype=tf.int32)

            null_swaps = mcmc_util.left_justified_expand_dims_like(
                tf.range(num_replica, dtype=swaps.dtype), swaps)
            swaps = _maybe_embed_swaps_validation(swaps, null_swaps,
                                                  self.validate_args)

            # Un-temper the log probs for use in the swap acceptance ratio.
            if self.tempered_log_prob_fn is None:
                # Efficient way of re-evaluating target_log_prob_fn on the
                # pre_swap_replica_states.
                untempered_energy_ignoring_ulp = (
                    # Since untempered_log_prob_fn is None, we may assume
                    # inverse_temperatures > 0 (else the target is improper).
                    pre_swap_replica_target_log_prob / inverse_temperatures)
            else:
                # The untempered_log_prob_fn does not factor into the acceptance ratio.
                # Proof: Suppose the tempered target is
                #   p_k(x) = f(x)^{beta_k} g(x),
                # so f(x) is tempered, and g(x) is not.  Then, the acceptance ratio for
                # a 1 <--> 2 swap is...
                #   (p_1(x_2) p_2(x_1)) / (p_1(x_1) p_2(x_2))
                # which depends only on f(x), since terms involving g(x) cancel.
                untempered_energy_ignoring_ulp = self.tempered_log_prob_fn(
                    *pre_swap_replica_states)

            # Since `swaps` is its own inverse permutation we automatically know the
            # swap counterpart: range(num_replica). We use this idea to compute the
            # acceptance in a vectorized manner at the cost of wasting roughly half
            # our computation. Although we could use `unique` to solve this problem,
            # we expect the cost of `unique` to be higher than the dozens of wasted
            # arithmetic calculations. Worse, it'd mean we need dynamic sized Tensors
            # (eg, using `tf.where(bool)`) and so we wouldn't be able to XLA compile.

            # Note: diffs would normally be "proposed - current" however energy is
            # flipped since `energy == -log_prob`.
            # Note: The untempered_log_prob_fn (if provided) is not included in
            # untempered_pre_swap_replica_target_log_prob, and hence does not factor
            # into energy_diff. Why? Because, it cancels out in the acceptance ratio.
            energy_diff = (untempered_energy_ignoring_ulp -
                           mcmc_util.index_remapping_gather(
                               untempered_energy_ignoring_ulp,
                               swaps,
                               name='gather_swap_tlp'))
            swapped_inverse_temperatures = mcmc_util.index_remapping_gather(
                inverse_temperatures, swaps, name='gather_swap_temps')
            inverse_temp_diff = swapped_inverse_temperatures - inverse_temperatures

            # If i and j are swapping, log_accept_ratio[i] and
            # log_accept_ratio[j] are equal.
            log_accept_ratio = (energy_diff *
                                mcmc_util.left_justified_expand_dims_to(
                                    inverse_temp_diff, replica_and_batch_rank))

            log_accept_ratio = tf.where(tf.math.is_finite(log_accept_ratio),
                                        log_accept_ratio,
                                        tf.constant(-np.inf, dtype=dtype))

            # Produce log[Uniform] draws that are identical at swapped indices.
            log_uniform = tf.math.log(
                samplers.uniform(shape=replica_and_batch_shape,
                                 dtype=dtype,
                                 seed=logu_seed))
            anchor_swaps = tf.minimum(swaps, null_swaps)
            log_uniform = mcmc_util.index_remapping_gather(
                log_uniform, anchor_swaps)

            is_swap_accepted_mask = tf.less(log_uniform,
                                            log_accept_ratio,
                                            name='is_swap_accepted_mask')

            def _swap_tensor(x):
                return mcmc_util.choose(
                    is_swap_accepted_mask,
                    mcmc_util.index_remapping_gather(x, swaps), x)

            post_swap_replica_states = [
                _swap_tensor(s) for s in pre_swap_replica_states
            ]

            expanded_null_swaps = mcmc_util.left_justified_broadcast_to(
                null_swaps, replica_and_batch_shape)
            is_swap_proposed = _compute_swap_notmatrix(
                # Broadcast both so they have shape [num_replica] + batch_shape.
                # This (i) makes them have same shape as is_swap_accepted, and
                # (ii) keeps shape consistent if someday swaps has a batch shape.
                expanded_null_swaps,
                mcmc_util.left_justified_broadcast_to(swaps,
                                                      replica_and_batch_shape))

            # To get is_swap_accepted in ordered position, we use
            # _compute_swap_notmatrix on current and next replica positions.
            post_swap_replica_position = _swap_tensor(expanded_null_swaps)

            is_swap_accepted = _compute_swap_notmatrix(
                post_swap_replica_position, expanded_null_swaps)

            if self._state_includes_replicas:
                post_swap_states = post_swap_replica_states
            else:
                post_swap_states = [s[0] for s in post_swap_replica_states]

            post_swap_replica_results = _set_swapped_fields_to_nan(
                _swap_log_prob_and_maybe_grads(pre_swap_replica_results,
                                               post_swap_replica_states,
                                               inner_kernel))

            if mcmc_util.is_list_like(current_state):
                # We *always* canonicalize the states in the kernel results.
                states = post_swap_states
            else:
                states = post_swap_states[0]

            post_swap_kernel_results = ReplicaExchangeMCKernelResults(
                post_swap_replica_states=post_swap_replica_states,
                pre_swap_replica_results=pre_swap_replica_results,
                post_swap_replica_results=post_swap_replica_results,
                is_swap_proposed=is_swap_proposed,
                is_swap_accepted=is_swap_accepted,
                is_swap_proposed_adjacent=_sub_diag(is_swap_proposed),
                is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
                # Store the original pkr.inverse_temperatures in case it's a
                # `tf.Variable`.
                inverse_temperatures=(
                    previous_kernel_results.inverse_temperatures),
                swaps=swaps,
                step_count=previous_kernel_results.step_count + 1,
                seed=seed,
            )

            return states, post_swap_kernel_results
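
The core of the swap acceptance above reduces to one gather and two subtractions. A small NumPy sketch (illustrative only), using the fact that `swaps` is its own inverse so a single gather pairs each replica with its proposed partner:

    import numpy as np

    untempered_log_prob = np.array([-1.0, -4.0, -9.0])  # Log[p(x_k)] per replica.
    inverse_temperatures = np.array([1.0, 0.5, 0.25])   # beta_k.
    swaps = np.array([1, 0, 2])                         # Propose swapping 0 <--> 1.

    energy_diff = untempered_log_prob - untempered_log_prob[swaps]
    inverse_temp_diff = inverse_temperatures[swaps] - inverse_temperatures
    log_accept_ratio = energy_diff * inverse_temp_diff

    # Entries 0 and 1 agree (the one ratio for the 0 <--> 1 swap); entry 2 is
    # a null swap and contributes log(1) = 0.
    print(log_accept_ratio)  # [-1.5 -1.5  0. ]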
Example #5
def _make_post_swap_replica_results(pre_swap_replica_results,
                                    inverse_temperatures,
                                    swapped_inverse_temperatures,
                                    is_swap_accepted_mask, swap_tensor_fn):
    """Return Kernel results, valid for post-swap states.

  Fields will be removed if they cannot be updated in an unambiguous manner.

  Args:
    pre_swap_replica_results: Kernel results obtained by running
      inner_kernel.one_step before swapping.
    inverse_temperatures: Tensor of inverse temperatures.
    swapped_inverse_temperatures: Tensor of inverse temperatures, permuted by
      swaps.
    is_swap_accepted_mask: Shape [num_replica] + batch_shape boolean Tensor
      telling which swaps were accepted.
    swap_tensor_fn: Callable. For `x.shape = [num_replica] + batch_shape`,
      swap_tensor_fn(x) performs swaps where they are accepted, and does not
      swap otherwise.

  Returns:
    new_replica_results:  Same type as pre_swap_replica_results.

  Raises:
    NotImplementedError: If type of [nested] results is not handled.
  """
    if not isinstance(pre_swap_replica_results,
                      metropolis_hastings.MetropolisHastingsKernelResults):
        # TODO(b/143702650) Handle other kernels.
        raise NotImplementedError(
            '`pre_swap_replica_results` currently works only for '
            'MetropolisHastingsKernelResults.  Found: {}. '
            'Please file a request with the TensorFlow Probability team.'.
            format(type(pre_swap_replica_results)))

    kr = pre_swap_replica_results
    dtype = swapped_inverse_temperatures.dtype

    # Hard to modify proposed_results in an unambiguous manner.
    # ...we also don't need to.
    kr = kr._replace(
        proposed_results=tf.convert_to_tensor(np.nan, dtype=dtype),
        proposed_state=tf.convert_to_tensor(np.nan, dtype=dtype),
    )

    replica_and_batch_rank = ps.rank(kr.log_accept_ratio)

    # After applying swap_tensor_fn to "values", the values are still scaled by
    # the swapped_inverse_temperatures.  They need instead to be scaled by the
    # inverse temperature corresponding to their post-swap index.
    it_ratio_raw = inverse_temperatures / swapped_inverse_temperatures
    it_ratio = tf.where(
        is_swap_accepted_mask,
        mcmc_util.left_justified_expand_dims_to(it_ratio_raw,
                                                replica_and_batch_rank),
        tf.convert_to_tensor(1.0, dtype=dtype))

    def _swap_then_retemper(x):
        x, is_multipart = mcmc_util.prepare_state_parts(x)
        it_ratio_ = mcmc_util.left_justified_expand_dims_like(it_ratio, x[0])
        x = [swap_tensor_fn(x_part) * it_ratio_ for x_part in x]
        if not is_multipart:
            x = x[0]
        return x

    if isinstance(kr.accepted_results,
                  hmc.UncalibratedHamiltonianMonteCarloKernelResults):
        kr = kr._replace(accepted_results=kr.accepted_results._replace(
            target_log_prob=_swap_then_retemper(
                kr.accepted_results.target_log_prob),
            grads_target_log_prob=_swap_then_retemper(
                kr.accepted_results.grads_target_log_prob)))
    elif isinstance(kr.accepted_results,
                    random_walk_metropolis.UncalibratedRandomWalkResults):
        kr = kr._replace(accepted_results=kr.accepted_results._replace(
            target_log_prob=_swap_then_retemper(
                kr.accepted_results.target_log_prob)))
    else:
        # TODO(b/143702650) Handle other kernels.
        raise NotImplementedError(
            'Only HMC and RWMH Kernels are handled at this time. Please file a '
            'request with the TensorFlow Probability team.')

    return kr
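
The `it_ratio` retempering can be checked numerically. A minimal NumPy sketch (not the library implementation), assuming a single accepted 0 <--> 1 swap:

    import numpy as np

    betas = np.array([1.0, 0.5])
    log_prob = np.array([-2.0, -8.0])   # Untempered Log[p(x_k)].
    tempered = betas * log_prob         # What the inner kernel stores.

    swaps = np.array([1, 0])            # Accepted swap of replicas 0 and 1.
    swapped = tempered[swaps]           # Scaled by the *wrong* temperatures.
    it_ratio = betas / betas[swaps]     # beta_i / beta_j at each slot i.
    retempered = swapped * it_ratio     # Now scaled by each slot's own beta.

    print(np.allclose(retempered, betas * log_prob[swaps]))  # True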
Example #6
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Takes one step of the TransitionKernel.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s).
      previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
        `list` of `Tensor`s representing internal calculations made within the
        previous call to this function (or as returned by `bootstrap_results`).
      seed: Optional, a seed for reproducible sampling.

    Returns:
      next_state: `Tensor` or Python `list` of `Tensor`s representing the
        next state(s) of the Markov chain(s).
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.
        This includes the replica states.
    """

        # The code below propagates one step states of shape
        #  [n_replica] + batch_shape + event_shape.
        #
        # The step is done in three parts:
        #  1) Call one_step to transition states via a tempered version of
        #     self.target_log_prob_fn (see _replica_target_log_prob).
        #  2) Permute values in states
        #  3) Update state-dependent values, such as log_probs.
        #
        # We chose to swap states, rather than temperatures, because...
        # (i)  If swapping temperatures, you *still* have to swap log_probs to
        #      determine acceptance, as well as states (for kernel results).
        #      So it's just as difficult to swap temperatures.
        # (ii) If swapping temperatures, you have to take care to swap any user-
        #      supplied temperature related things (like step size).
        #      A-priori, we don't know what else will need to be swapped!
        # (iii) In both cases, the kernel results need to be updated in a non-trivial
        #      manner....so we either special-case, or use bootstrap.

        with tf.name_scope(mcmc_util.make_name(self.name, 'remc', 'one_step')):
            # Force a read in case the `inverse_temperatures` is a `tf.Variable`.
            inverse_temperatures = tf.convert_to_tensor(
                previous_kernel_results.inverse_temperatures,
                name='inverse_temperatures')

            target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
                self.target_log_prob_fn, inverse_temperatures)
            # Seed handling complexity is due to users possibly expecting an old-style
            # stateful seed to be passed to `self.make_kernel_fn`, and no seed
            # expected by `kernel.one_step`.
            # In other words:
            # - We try `make_kernel_fn` without a seed first; this is the future. The
            #   kernel will receive a seed later, as part of `one_step`.
            # - If the user code doesn't like that (Python complains about a missing
            #   required argument), we warn and fall back to the previous behavior.
            try:
                inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                    target_log_prob_for_inner_kernel)
            except TypeError as e:
                if 'argument' not in str(e):
                    raise
                warnings.warn(
                    'The `seed` argument to `ReplicaExchangeMC`s `make_kernel_fn` is '
                    'deprecated. `TransitionKernel` instances now receive seeds via '
                    '`one_step`.')
                inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                    target_log_prob_for_inner_kernel, self._seed_stream())

            # Now that we've constructed the TransitionKernel instance:
            # - If we were given a seed, we sanitize it to stateless and pass along
            #   to `kernel.one_step`. If it doesn't like that, we crash and propagate
            #   the error.  Rationale: The contract is stateless sampling given
            #   seed, and doing otherwise would not meet it.
            # - If not given a seed, we don't pass one along. This avoids breaking
            #   underlying kernels lacking a `seed` arg on `one_step`.
            # TODO(b/159636942): Clean up after 2020-09-20.
            if seed is not None:
                seed = samplers.sanitize_seed(seed)
                inner_seed, swap_seed, logu_seed = samplers.split_seed(
                    seed, n=3, salt='remc_one_step')
                inner_kwargs = dict(seed=inner_seed)
            else:
                if self._seed_stream.original_seed is not None:
                    warnings.warn(mcmc_util.SEED_CTOR_ARG_DEPRECATION_MSG)
                inner_kwargs = {}
                swap_seed, logu_seed = samplers.split_seed(self._seed_stream())
            [
                pre_swap_replica_states,
                pre_swap_replica_results,
            ] = inner_kernel.one_step(
                previous_kernel_results.post_swap_replica_states,
                previous_kernel_results.post_swap_replica_results,
                **inner_kwargs)

            pre_swap_replica_target_log_prob = _get_field(
                # These are tempered log probs (have been divided by temperature).
                pre_swap_replica_results,
                'target_log_prob')

            dtype = pre_swap_replica_target_log_prob.dtype
            replica_and_batch_shape = ps.shape(
                pre_swap_replica_target_log_prob)
            batch_shape = replica_and_batch_shape[1:]
            replica_and_batch_rank = ps.rank(pre_swap_replica_target_log_prob)
            num_replica = ps.size0(inverse_temperatures)

            inverse_temperatures = mcmc_util.left_justified_broadcast_to(
                inverse_temperatures, replica_and_batch_shape)

            # Now that each replica has done one_step, it is time to consider swaps.

            # swaps.shape = [n_replica], and is a "once only" permutation, meaning it
            # is achievable by a sequence of pairwise permutations, where each element
            # is moved at most once.
            # E.g. if swaps = [1, 0, 2], we will consider swapping temperatures 0 and
            # 1, keeping 2 fixed.  This exact same swap is considered for *every*
            # batch member.  Of course some batch members may accept and some reject.
            try:
                swaps = tf.cast(
                    self.swap_proposal_fn(  # pylint: disable=not-callable
                        num_replica,
                        batch_shape=batch_shape,
                        seed=swap_seed,
                        step_count=previous_kernel_results.step_count),
                    dtype=tf.int32)
            except TypeError as e:
                if 'step_count' not in str(e):
                    raise
                warnings.warn(
                    'The `swap_proposal_fn` given to ReplicaExchangeMC did not accept '
                    'the `step_count` argument. Falling back to omitting the '
                    'argument. This fallback will be removed after 24-Oct-2020.'
                )
                swaps = tf.cast(
                    self.swap_proposal_fn(  # pylint: disable=not-callable
                        num_replica,
                        batch_shape=batch_shape,
                        seed=swap_seed),
                    dtype=tf.int32)

            null_swaps = mcmc_util.left_justified_expand_dims_like(
                tf.range(num_replica, dtype=swaps.dtype), swaps)
            swaps = _maybe_embed_swaps_validation(swaps, null_swaps,
                                                  self.validate_args)

            # Un-temper the log probs.  E.g., for replica k, at point x_k, this is
            # Log[p(x_k)], and *not* the tempered
            # Log[p_k(x_k)] = beta_k * Log[p(x_k)].
            untempered_pre_swap_replica_target_log_prob = (
                pre_swap_replica_target_log_prob / inverse_temperatures)

            # Since `swaps` is its own inverse permutation we automatically know the
            # swap counterpart: range(num_replica). We use this idea to compute the
            # acceptance in a vectorized manner at the cost of wasting roughly half
            # our computation. Although we could use `unique` to solve this problem,
            # we expect the cost of `unique` to be higher than the dozens of wasted
            # arithmetic calculations. Worse, it'd mean we need dynamic sized Tensors
            # (eg, using `tf.where(bool)`) and so we wouldn't be able to XLA compile.

            # Note: diffs would normally be "proposed - current" however energy is
            # flipped since `energy == -log_prob`.
            energy_diff = (untempered_pre_swap_replica_target_log_prob -
                           mcmc_util.index_remapping_gather(
                               untempered_pre_swap_replica_target_log_prob,
                               swaps,
                               name='gather_swap_tlp'))
            swapped_inverse_temperatures = mcmc_util.index_remapping_gather(
                inverse_temperatures, swaps, name='gather_swap_temps')
            inverse_temp_diff = swapped_inverse_temperatures - inverse_temperatures

            # If i and j are swapping, log_accept_ratio[i] and
            # log_accept_ratio[j] are equal.
            log_accept_ratio = (energy_diff *
                                mcmc_util.left_justified_expand_dims_to(
                                    inverse_temp_diff, replica_and_batch_rank))

            log_accept_ratio = tf.where(tf.math.is_finite(log_accept_ratio),
                                        log_accept_ratio,
                                        tf.constant(-np.inf, dtype=dtype))

            # Produce Log[Uniform] draws that are identical at swapped indices.
            log_uniform = tf.math.log(
                samplers.uniform(shape=replica_and_batch_shape,
                                 dtype=dtype,
                                 seed=logu_seed))
            anchor_swaps = tf.minimum(swaps, null_swaps)
            log_uniform = mcmc_util.index_remapping_gather(
                log_uniform, anchor_swaps)

            is_swap_accepted_mask = tf.less(log_uniform,
                                            log_accept_ratio,
                                            name='is_swap_accepted_mask')

            def _swap_tensor(x):
                return mcmc_util.choose(
                    is_swap_accepted_mask,
                    mcmc_util.index_remapping_gather(x, swaps), x)

            post_swap_replica_states = [
                _swap_tensor(s) for s in pre_swap_replica_states
            ]

            expanded_null_swaps = mcmc_util.left_justified_broadcast_to(
                null_swaps, replica_and_batch_shape)
            is_swap_proposed = _compute_swap_notmatrix(
                # Broadcast both so they have shape [num_replica] + batch_shape.
                # This (i) makes them have same shape as is_swap_accepted, and
                # (ii) keeps shape consistent if someday swaps has a batch shape.
                expanded_null_swaps,
                mcmc_util.left_justified_broadcast_to(swaps,
                                                      replica_and_batch_shape))

            # To get is_swap_accepted in ordered position, we use
            # _compute_swap_notmatrix on current and next replica positions.
            post_swap_replica_position = _swap_tensor(expanded_null_swaps)

            is_swap_accepted = _compute_swap_notmatrix(
                post_swap_replica_position, expanded_null_swaps)

            if self._state_includes_replicas:
                post_swap_states = post_swap_replica_states
            else:
                post_swap_states = [s[0] for s in post_swap_replica_states]

            post_swap_replica_results = _make_post_swap_replica_results(
                pre_swap_replica_results, inverse_temperatures,
                swapped_inverse_temperatures, is_swap_accepted_mask,
                _swap_tensor)

            if mcmc_util.is_list_like(current_state):
                # We *always* canonicalize the states in the kernel results.
                states = post_swap_states
            else:
                states = post_swap_states[0]

            post_swap_kernel_results = ReplicaExchangeMCKernelResults(
                post_swap_replica_states=post_swap_replica_states,
                pre_swap_replica_results=pre_swap_replica_results,
                post_swap_replica_results=post_swap_replica_results,
                is_swap_proposed=is_swap_proposed,
                is_swap_accepted=is_swap_accepted,
                is_swap_proposed_adjacent=_sub_diag(is_swap_proposed),
                is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
                # Store the original pkr.inverse_temperatures in case it's a
                # `tf.Variable`.
                inverse_temperatures=(
                    previous_kernel_results.inverse_temperatures),
                swaps=swaps,
                step_count=previous_kernel_results.step_count + 1,
                seed=samplers.zeros_seed() if seed is None else seed,
            )

            return states, post_swap_kernel_results
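
One subtle piece shared by the `one_step` variants above is `anchor_swaps`: gathering the uniform draws at min(swaps, range(num_replica)) guarantees each swapped pair shares a single draw, so both members accept or reject together. A small NumPy sketch (illustrative only):

    import numpy as np

    swaps = np.array([1, 0, 3, 2, 4])              # Adjacent swaps; 4 is fixed.
    null_swaps = np.arange(5)
    anchor_swaps = np.minimum(swaps, null_swaps)   # [0, 0, 2, 2, 4].

    log_uniform = np.log(np.random.default_rng(0).uniform(size=5))
    shared = log_uniform[anchor_swaps]             # One draw per swapped pair.
    print(shared[0] == shared[1], shared[2] == shared[3])  # True True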
Example #7
    def one_step(self, current_state, previous_kernel_results):
        """Takes one step of the TransitionKernel.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s).
      previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
        `list` of `Tensor`s representing internal calculations made within the
        previous call to this function (or as returned by `bootstrap_results`).

    Returns:
      next_state: `Tensor` or Python `list` of `Tensor`s representing the
        next state(s) of the Markov chain(s).
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.
        This includes the replica states.
    """
        # The code below propagates one step states of shape
        #  [n_replica] + batch_shape + event_shape.
        #
        # The step is done in three parts:
        #  1) Call one_step to transition states via a tempered version of
        #     self.target_log_prob_fn (see _replica_target_log_prob).
        #  2) Permute values in states
        #  3) Update state-dependent values, such as log_probs.
        #
        # We chose to swap states, rather than temperatures, because...
        # (i)  If swapping temperatures, you *still* have to swap log_probs to
        #      determine acceptance, as well as states (for kernel results).
        #      So it's just as difficult to swap temperatures.
        # (ii) If swapping temperatures, you have to take care to swap any user-
        #      supplied temperature related things (like step size).
        #      A-priori, we don't know what else will need to be swapped!
        # (iii) In both cases, the kernel results need to be updated in a non-trivial
        #      manner....so we either special-case, or use bootstrap.

        with tf.name_scope(mcmc_util.make_name(self.name, 'remc', 'one_step')):
            # Force a read in case the `inverse_temperatures` is a `tf.Variable`.
            inverse_temperatures = tf.convert_to_tensor(
                previous_kernel_results.inverse_temperatures,
                name='inverse_temperatures')

            inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                _make_replica_target_log_prob_fn(self.target_log_prob_fn,
                                                 inverse_temperatures),
                self._seed_stream())

            [
                pre_swap_replica_states,
                pre_swap_replica_results,
            ] = inner_kernel.one_step(
                previous_kernel_results.post_swap_replica_states,
                previous_kernel_results.post_swap_replica_results)

            pre_swap_replica_target_log_prob = _get_field(
                # These are tempered log probs (have been divided by temperature).
                pre_swap_replica_results,
                'target_log_prob')

            dtype = pre_swap_replica_target_log_prob.dtype
            replica_and_batch_shape = prefer_static.shape(
                pre_swap_replica_target_log_prob)
            batch_shape = replica_and_batch_shape[1:]
            replica_and_batch_rank = prefer_static.rank(
                pre_swap_replica_target_log_prob)
            num_replica = prefer_static.size0(inverse_temperatures)

            inverse_temperatures = mcmc_util.left_justified_broadcast_to(
                inverse_temperatures, replica_and_batch_shape)

            # Now that each replica has done one_step, it is time to consider swaps.

            # swaps.shape = [n_replica], and is a "once only" permutation, meaning it
            # is achievable by a sequence of pairwise permutations, where each element
            # is moved at most once.
            # E.g. if swaps = [1, 0, 2], we will consider swapping temperatures 0 and
            # 1, keeping 2 fixed.  This exact same swap is considered for *every*
            # batch member.  Of course some batch members may accept and some reject.
            swaps = tf.cast(
                self.swap_proposal_fn(  # pylint: disable=not-callable
                    num_replica,
                    batch_shape=batch_shape,
                    seed=self._seed_stream()),
                dtype=tf.int32)
            null_swaps = mcmc_util.left_justified_expand_dims_like(
                tf.range(num_replica, dtype=swaps.dtype), swaps)
            swaps = _maybe_embed_swaps_validation(swaps, null_swaps,
                                                  self.validate_args)

            # Un-temper the log probs.  E.g., for replica k, at point x_k, this is
            # Log[p(x_k)], and *not* the tempered
            # Log[p_k(x_k)] = beta_k * Log[p(x_k)].
            untempered_pre_swap_replica_target_log_prob = (
                pre_swap_replica_target_log_prob / inverse_temperatures)

            # Since `swaps` is its own inverse permutation we automatically know the
            # swap counterpart: range(num_replica). We use this idea to compute the
            # acceptance in a vectorized manner at the cost of wasting roughly half
            # our computation. Although we could use `unique` to solve this problem,
            # we expect the cost of `unique` to be higher than the dozens of wasted
            # arithmetic calculations. Worse, it'd mean we need dynamic sized Tensors
            # (eg, using `tf.where(bool)`) and so we wouldn't be able to XLA compile.

            # Note: diffs would normally be "proposed - current" however energy is
            # flipped since `energy == -log_prob`.
            energy_diff = (untempered_pre_swap_replica_target_log_prob -
                           mcmc_util.index_remapping_gather(
                               untempered_pre_swap_replica_target_log_prob,
                               swaps,
                               name='gather_swap_tlp'))
            swapped_inverse_temperatures = mcmc_util.index_remapping_gather(
                inverse_temperatures, swaps, name='gather_swap_temps')
            inverse_temp_diff = swapped_inverse_temperatures - inverse_temperatures

            # If i and j are swapping, log_accept_ratio[i] and
            # log_accept_ratio[j] are equal.
            log_accept_ratio = (energy_diff *
                                mcmc_util.left_justified_expand_dims_to(
                                    inverse_temp_diff, replica_and_batch_rank))

            log_accept_ratio = tf.where(tf.math.is_finite(log_accept_ratio),
                                        log_accept_ratio,
                                        tf.constant(-np.inf, dtype=dtype))

            # Produce Log[Uniform] draws that are identical at swapped indices.
            log_uniform = tf.math.log(
                tf.random.uniform(shape=replica_and_batch_shape,
                                  dtype=dtype,
                                  seed=self._seed_stream()))
            anchor_swaps = tf.minimum(swaps, null_swaps)
            log_uniform = mcmc_util.index_remapping_gather(
                log_uniform, anchor_swaps)

            is_swap_accepted_mask = tf.less(log_uniform,
                                            log_accept_ratio,
                                            name='is_swap_accepted_mask')

            def _swap_tensor(x):
                return mcmc_util.choose(
                    is_swap_accepted_mask,
                    mcmc_util.index_remapping_gather(x, swaps), x)

            post_swap_replica_states = [
                _swap_tensor(s) for s in pre_swap_replica_states
            ]

            expanded_null_swaps = mcmc_util.left_justified_broadcast_to(
                null_swaps, replica_and_batch_shape)
            is_swap_proposed = _compute_swap_notmatrix(
                # Broadcast both so they have shape [num_replica] + batch_shape.
                # This (i) makes them have same shape as is_swap_accepted, and
                # (ii) keeps shape consistent if someday swaps has a batch shape.
                expanded_null_swaps,
                mcmc_util.left_justified_broadcast_to(swaps,
                                                      replica_and_batch_shape))

            # To get is_swap_accepted in ordered position, we use
            # _compute_swap_notmatrix on current and next replica positions.
            post_swap_replica_position = _swap_tensor(expanded_null_swaps)

            is_swap_accepted = _compute_swap_notmatrix(
                post_swap_replica_position, expanded_null_swaps)

            post_swap_states = [s[0] for s in post_swap_replica_states]

            post_swap_replica_results = _make_post_swap_replica_results(
                pre_swap_replica_results, inverse_temperatures,
                swapped_inverse_temperatures, is_swap_accepted_mask,
                _swap_tensor)

            if mcmc_util.is_list_like(current_state):
                # We *always* canonicalize the states in the kernel results.
                states = post_swap_states
            else:
                states = post_swap_states[0]

            post_swap_kernel_results = ReplicaExchangeMCKernelResults(
                post_swap_replica_states=post_swap_replica_states,
                pre_swap_replica_results=pre_swap_replica_results,
                post_swap_replica_results=post_swap_replica_results,
                is_swap_proposed=is_swap_proposed,
                is_swap_accepted=is_swap_accepted,
                is_swap_proposed_adjacent=_sub_diag(is_swap_proposed),
                is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
                # Store the original pkr.inverse_temperatures in case it's a
                # `tf.Variable`.
                inverse_temperatures=(
                    previous_kernel_results.inverse_temperatures),
                swaps=swaps,
            )

            return states, post_swap_kernel_results
Example #8
def _make_post_swap_replica_results(pre_swap_replica_results,
                                    inverse_temperatures,
                                    swapped_inverse_temperatures,
                                    is_swap_accepted_mask, swap_tensor_fn):
    """Return Kernel results, valid for post-swap states.

  Fields will be removed if they cannot be updated in an unambiguous manner.

  Args:
    pre_swap_replica_results: Kernel results obtained by running
      inner_kernel.one_step before swapping.
    inverse_temperatures: Tensor of inverse temperatures.
    swapped_inverse_temperatures: Tensor of inverse temperatures, permuted by
      swaps.
    is_swap_accepted_mask: Shape [num_replica] + batch_shape boolean Tensor
      telling which swaps were accepted.
    swap_tensor_fn: Callable. For `x.shape = [num_replica] + batch_shape`,
      swap_tensor_fn(x) performs swaps where they are accepted, and does not
      swap otherwise.

  Returns:
    new_replica_results:  Same type as pre_swap_replica_results.

  Raises:
    NotImplementedError: If type of [nested] results is not handled.
  """

    kr = pre_swap_replica_results
    dtype = swapped_inverse_temperatures.dtype

    # Hard to modify proposed_results in an unambiguous manner.
    # ...we also don't need to.
    kr = _update_field(kr, 'proposed_results',
                       tf.convert_to_tensor(np.nan, dtype=dtype))
    kr = _update_field(kr, 'proposed_state',
                       tf.convert_to_tensor(np.nan, dtype=dtype))

    replica_and_batch_rank = ps.rank(_get_field(kr, 'log_accept_ratio'))

    # After applying swap_tensor_fn to "values", the values are still scaled by
    # the swapped_inverse_temperatures.  They need instead to be scaled by the
    # inverse temperature corresponding to their post-swap index.
    it_ratio_raw = inverse_temperatures / swapped_inverse_temperatures
    it_ratio = tf.where(
        is_swap_accepted_mask,
        mcmc_util.left_justified_expand_dims_to(it_ratio_raw,
                                                replica_and_batch_rank),
        tf.convert_to_tensor(1.0, dtype=dtype))

    def _swap_then_retemper(x):
        x, is_multipart = mcmc_util.prepare_state_parts(x)
        it_ratio_ = mcmc_util.left_justified_expand_dims_like(it_ratio, x[0])
        x = [swap_tensor_fn(x_part) * it_ratio_ for x_part in x]
        if not is_multipart:
            x = x[0]
        return x

    kr = _update_field(kr, 'target_log_prob',
                       _swap_then_retemper(_get_field(kr, 'target_log_prob')))
    try:
        new_grads_target_log_prob = _swap_then_retemper(
            _get_field(kr, 'grads_target_log_prob'))
        kr = _update_field(kr, 'grads_target_log_prob',
                           new_grads_target_log_prob)
    # For transition kernels not involving the gradient of the log-probability,
    # grads_target_log_prob will not exist in the (possibly multiply wrapped)
    # kernel results and that's perfectly fine. But _get_field() / _update_field()
    # will throw a REMCFieldNotFoundError, which we thus catch silently.
    except REMCFieldNotFoundError:
        pass

    return kr
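
The `_get_field` / `_update_field` helpers are not shown in this excerpt. As a rough picture of their contract, here is a hypothetical minimal sketch of the "get" side (the real library helpers may differ substantially): recurse into wrapped results, and raise a dedicated error when a field is absent, so callers like the `grads_target_log_prob` update above can opt out gracefully.

    import collections

    class REMCFieldNotFoundError(AttributeError):
        """Raised when a field is absent from (possibly wrapped) results."""

    def _get_field(kernel_results, field_name):
        if hasattr(kernel_results, field_name):
            return getattr(kernel_results, field_name)
        if hasattr(kernel_results, 'accepted_results'):
            # Unwrap one level, e.g. MetropolisHastings-style wrapping.
            return _get_field(kernel_results.accepted_results, field_name)
        raise REMCFieldNotFoundError(field_name)

    Results = collections.namedtuple('Results', ['accepted_results'])
    Inner = collections.namedtuple('Inner', ['target_log_prob'])
    kr = Results(accepted_results=Inner(target_log_prob=-3.0))
    print(_get_field(kr, 'target_log_prob'))  # -3.0, found one level down.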