Example No. 1
def _where(proposed, current):
  """Wraps `tf.where`."""
  if proposed is current:
    return proposed
  # Preserve the name from `current` so names can propagate from
  # `bootstrap_results`.
  name = getattr(current, 'name', None)
  if name is not None:
    name = name.rpartition('/')[2].rsplit(':', 1)[0]
  # Since this is an internal utility it is ok to assume
  # tf.shape(proposed) == tf.shape(current).
  return tf.where(bu.left_justified_expand_dims_like(is_accepted, proposed),
                  proposed, current, name=name)
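
All of these examples rely on `bu.left_justified_expand_dims_like`. A minimal sketch of the behavior assumed here (an illustration, not the library implementation): append trailing singleton dimensions to the first argument until its rank matches the second, so broadcasting aligns on the leading axes.

import tensorflow as tf

def left_justified_expand_dims_like_sketch(x, reference):
  # Append singleton dims on the right of `x` until it matches the rank of
  # `reference`; broadcasting then lines up on the leading (left) axes.
  x = tf.convert_to_tensor(x)
  rank_diff = tf.rank(reference) - tf.rank(x)
  padded_shape = tf.concat(
      [tf.shape(x), tf.ones([rank_diff], dtype=tf.int32)], axis=0)
  return tf.reshape(x, padded_shape)

# E.g. an `is_accepted` mask of shape [n_chains] becomes [n_chains, 1] when
# expanded against a state of shape [n_chains, d], so `tf.where` selects whole
# chain states rather than individual elements.
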
Example No. 2
  def _replica_target_log_prob(*x):
    if tempered_log_prob_fn is not None:
      tlp = tempered_log_prob_fn(*x)
    else:
      tlp = target_log_prob_fn(*x)

    log_prob = tf.cast(bu.left_justified_expand_dims_like(
        inverse_temperatures, tlp), dtype=tlp.dtype) * tlp

    if untempered_log_prob_fn is not None:
      log_prob = log_prob + untempered_log_prob_fn(*x)

    return log_prob
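
For context, the tempering above scales each replica's tempered log density by its inverse temperature, broadcasting a `[n_replica]` vector against `[n_replica, ...]` log probs. A small self-contained sketch of that broadcast (the shapes and values are illustrative assumptions):

import tensorflow as tf

inverse_temperatures = tf.constant([1.0, 0.5, 0.25])   # [n_replica]
tlp = tf.random.normal([3, 8])                          # [n_replica, n_chains]
# Expand on the right so the replica axes line up, then scale.
log_prob = inverse_temperatures[:, tf.newaxis] * tlp    # shape [3, 8]
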
def make_rwmh_kernel_fn(target_log_prob_fn, init_state, scalings):
    """Generate a Random Walk MH kernel."""
    with tf.name_scope('make_rwmh_kernel_fn'):
        state_std = [
            tf.math.reduce_std(x, axis=0, keepdims=True) for x in init_state
        ]
        step_size = [
            s * ps.cast(  # pylint: disable=g-complex-comprehension
                bu.left_justified_expand_dims_like(scalings, s), s.dtype)
            for s in state_std
        ]
        return random_walk_metropolis.RandomWalkMetropolis(
            target_log_prob_fn,
            new_state_fn=random_walk_metropolis.random_walk_normal_fn(
                scale=step_size))
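
A rough standalone sketch of the same construction using the public `tfp.mcmc` API; the standard-normal target and the chain/scaling shapes below are illustrative assumptions, not part of the snippet above:

import tensorflow as tf
import tensorflow_probability as tfp

init_state = [tf.random.normal([64, 2])]      # 64 chains, one 2-d state part
scalings = tf.fill([64], 0.5)                 # per-chain scaling factor
state_std = [tf.math.reduce_std(x, axis=0, keepdims=True) for x in init_state]
# Per-chain step size: empirical std of the states times the chain's scaling.
step_size = [s * tf.cast(scalings[:, tf.newaxis], s.dtype) for s in state_std]
kernel = tfp.mcmc.RandomWalkMetropolis(
    target_log_prob_fn=lambda x: -0.5 * tf.reduce_sum(x**2, axis=-1),
    new_state_fn=tfp.mcmc.random_walk_normal_fn(scale=step_size))
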
def _make_momentum_distribution(running_variance_parts, state_parts,
                                batch_ndims):
    """Construct a momentum distribution from the running variance.

  This uses a running variance to construct a momentum distribution with the
  correct batch_shape and event_shape.

  Args:
    running_variance_parts: List of `Tensor`, outputs of
      `tfp.experimental.stats.RunningVariance.variance()`.
    state_parts: List of `Tensor`.
    batch_ndims: Scalar, for leading batch dimensions.

  Returns:
    `tfd.Distribution` where `.sample` has the same structure as `state_parts`,
    and `.log_prob` of the sample will have rank `batch_ndims`.
  """
    distributions = []
    for variance_part, state_part in zip(running_variance_parts, state_parts):
        running_variance_rank = ps.rank(variance_part)
        state_rank = ps.rank(state_part)
        event_shape = ps.shape(state_part)[batch_ndims:]
        nevt = ps.reduce_prod(event_shape)
        # Pad dimensions and tile by multiplying by tf.ones to add a batch shape
        ones = tf.ones(
            ps.shape(state_part)[:-(state_rank - running_variance_rank)],
            dtype=variance_part.dtype)
        ones = bu.left_justified_expand_dims_like(ones, state_part)
        variance_tiled = ones * variance_part
        variance_flattened = tf.reshape(
            variance_tiled,
            ps.concat([ps.shape(variance_tiled)[:batch_ndims], [nevt]],
                      axis=0))

        distributions.append(
            _CompositeTransformedDistribution(
                bijector=_CompositeReshape(event_shape_out=event_shape,
                                           event_shape_in=[nevt]),
                distribution=(
                    _CompositeMultivariateNormalPrecisionFactorLinearOperator(
                        precision_factor=_CompositeLinearOperatorDiag(
                            tf.math.sqrt(variance_flattened)),
                        precision=_CompositeLinearOperatorDiag(
                            variance_flattened)))))
    return _CompositeJointDistributionSequential(distributions)
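
A rough public-API analogue of the per-part distribution built above, assuming `tfp.distributions`/`tfp.bijectors` in place of the internal `_Composite*` wrappers: a diagonal normal whose precision equals the (flattened) running variance, reshaped back to the state part's event shape. The variance values below are made up:

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

variance_flattened = tf.constant([0.5, 2.0, 4.0, 1.0, 1.0, 0.25])  # nevt = 6
momentum_dist = tfd.TransformedDistribution(
    distribution=tfd.MultivariateNormalDiag(
        scale_diag=tf.math.rsqrt(variance_flattened)),  # precision = variance
    bijector=tfb.Reshape(event_shape_out=[2, 3], event_shape_in=[6]))
# momentum_dist.sample() has event shape [2, 3], matching the state part.
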
  def _center_proposed_state(x):
    # The empirical mean here is a stand-in for the true mean, so we drop the
    # gradient that flows through this term. The goal here is to get a reliable
    # diagnostic of the underlying dynamics, rather than incorporating the
    # effect of the Metropolis-Hastings correction.
    # TODO(mhoffman): Needs more experimentation.
    expanded_accept_prob = bu.left_justified_expand_dims_like(
        accept_prob, x)

    # accept_prob is zero when x is NaN, but we still want to sanitize such
    # values.
    x_safe = tf.where(tf.math.is_finite(x), x, tf.zeros_like(x))
    # If all accept_prob's are zero, the x_center will have a nonsense value,
    # but we'll discard the resultant gradients later on, so it's fine.
    x_center = (
        tf.reduce_sum(expanded_accept_prob * x_safe, axis=batch_axes) /
        (tf.reduce_sum(expanded_accept_prob, axis=batch_axes) + 1e-20))

    return x - tf.stop_gradient(x_center)
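
A minimal sketch of the acceptance-probability-weighted centering above, with the batch-axis details stripped out; the shapes and values are illustrative:

import tensorflow as tf

accept_prob = tf.constant([0.9, 0.0, 0.3])                     # [n_chains]
x = tf.constant([[1., 2.], [float('nan'), 0.], [3., 4.]])      # [n_chains, d]
w = accept_prob[:, tf.newaxis]
# Zero out non-finite entries; their weight is zero anyway.
x_safe = tf.where(tf.math.is_finite(x), x, tf.zeros_like(x))
x_center = (tf.reduce_sum(w * x_safe, axis=0) /
            (tf.reduce_sum(w, axis=0) + 1e-20))
centered = x - tf.stop_gradient(x_center)
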
Example No. 6
  def _where(proposed, current):
    """Wraps `tf.where`."""
    if proposed is current:
      return proposed

    # Handle CompositeTensor types at the leafmost `addr`.
    flat_p = tf.nest.flatten(proposed, expand_composites=True)
    flat_c = tf.nest.flatten(current, expand_composites=True)

    res = []
    for p, c in zip(flat_p, flat_c):
      # Preserve the name from `current` so names can propagate from
      # `bootstrap_results`.
      name = getattr(c, 'name', None)
      if name is not None:
        name = name.rpartition('/')[2].rsplit(':', 1)[0]
      # Since this is an internal utility it is ok to assume
      # tf.shape(proposed) == tf.shape(current).
      res.append(
          tf.where(bu.left_justified_expand_dims_like(is_accepted, p), p, c,
                   name=name))
    return tf.nest.pack_sequence_as(current, res, expand_composites=True)
Example No. 7
    def _center_proposed_state(x, x_mean):
        # Note that we don't do a Monte Carlo average of the accepted chain
        # position, but rather try to get an estimate of the underlying dynamics.
        # This is done by only looking at proposed states where the integration
        # error is low.
        expanded_accept_prob = bu.left_justified_expand_dims_like(
            accept_prob, x)

        # accept_prob is zero when x is NaN, but we still want to sanitize such
        # values.
        x_safe = tf.where(tf.math.is_finite(x), x, tf.zeros_like(x))
        # The empirical mean here is a stand-in for the true mean, so we drop the
        # gradient that flows through this term.
        # If all accept_prob's are zero, the x_center will have a nonsense value,
        # but we'll discard the resultant gradients later on, so it's fine.
        emp_x_mean = tf.stop_gradient(
            distribute_lib.reduce_sum(expanded_accept_prob * x_safe,
                                      batch_axes, reduce_chain_axis_names) /
            (distribute_lib.reduce_sum(expanded_accept_prob, batch_axes,
                                       reduce_chain_axis_names) + 1e-20))

        x_mean = _mix_in_state_mean(emp_x_mean, x_mean)
        return x - x_mean
def _make_momentum_distribution(running_variance_parts, state_parts,
                                batch_ndims):
    """Construct a momentum distribution from the running variance.

  This uses a running variance to construct a momentum distribution with the
  correct batch_shape and event_shape.

  Args:
    running_variance_parts: List of `Tensor`, outputs of
      `tfp.experimental.stats.RunningVariance.variance()`.
    state_parts: List of `Tensor`.
    batch_ndims: Scalar, for leading batch dimensions.

  Returns:
    `tfd.Distribution` where `.sample` has the same structure as `state_parts`,
    and `.log_prob` of the sample will have rank `batch_ndims`.
  """
    distributions = []
    for variance_part, state_part in zip(running_variance_parts, state_parts):
        running_variance_rank = ps.rank(variance_part)
        state_rank = ps.rank(state_part)
        # Pad dimensions and tile by multiplying by tf.ones to add a batch shape
        ones = tf.ones(
            ps.shape(state_part)[:-(state_rank - running_variance_rank)],
            dtype=variance_part.dtype)
        ones = bu.left_justified_expand_dims_like(ones, state_part)
        variance_tiled = variance_part * ones
        reinterpreted_batch_ndims = state_rank - batch_ndims - 1

        distributions.append(
            _CompositeIndependent(
                _CompositeMultivariateNormalPrecisionFactorLinearOperator(
                    precision_factor=_CompositeLinearOperatorDiag(
                        tf.math.sqrt(variance_tiled)),
                    precision=_CompositeLinearOperatorDiag(variance_tiled)),
                reinterpreted_batch_ndims=reinterpreted_batch_ndims))
    return _CompositeJointDistributionSequential(distributions)
Example No. 9
    def _loop_tree_doubling(self, step_size, momentum_state_memory,
                            current_step_meta_info, iter_, initial_step_state,
                            initial_step_metastate, seed):
        """Main loop for tree doubling."""
        with tf.name_scope('loop_tree_doubling'):
            (direction_seed, subtree_seed, acceptance_seed,
             next_seed) = samplers.split_seed(seed, n=4)
            batch_shape = ps.shape(current_step_meta_info.init_energy)
            direction = tf.cast(samplers.uniform(shape=batch_shape,
                                                 minval=0,
                                                 maxval=2,
                                                 dtype=tf.int32,
                                                 seed=direction_seed),
                                dtype=tf.bool)

            tree_start_states = tf.nest.map_structure(
                lambda v: bu.where_left_justified_mask(direction, v[1], v[0]),
                initial_step_state)

            directions_expanded = [
                bu.left_justified_expand_dims_like(direction, state)
                for state in tree_start_states.state
            ]

            integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
                self.target_log_prob_fn,
                step_sizes=[
                    tf.where(d, ss, -ss)
                    for d, ss in zip(directions_expanded, step_size)
                ],
                num_steps=self.unrolled_leapfrog_steps)

            [
                candidate_tree_state, tree_final_states, final_not_divergence,
                continue_tree_final, energy_diff_tree_sum,
                momentum_subtree_cumsum, leapfrogs_taken
            ] = self._build_sub_tree(
                directions_expanded,
                integrator,
                current_step_meta_info,
                # num_steps_at_this_depth = 2**iter_ = 1 << iter_
                tf.bitwise.left_shift(1, iter_),
                tree_start_states,
                initial_step_metastate.continue_tree,
                initial_step_metastate.not_divergence,
                momentum_state_memory,
                seed=subtree_seed)

            last_candidate_state = initial_step_metastate.candidate_state

            energy_diff_sum = (energy_diff_tree_sum +
                               initial_step_metastate.energy_diff_sum)
            if MULTINOMIAL_SAMPLE:
                tree_weight = tf.where(
                    continue_tree_final, candidate_tree_state.weight,
                    tf.constant(-np.inf,
                                dtype=candidate_tree_state.weight.dtype))
                weight_sum = log_add_exp(tree_weight,
                                         last_candidate_state.weight)
                log_accept_thresh = tree_weight - last_candidate_state.weight
            else:
                tree_weight = tf.where(continue_tree_final,
                                       candidate_tree_state.weight,
                                       tf.zeros([], dtype=TREE_COUNT_DTYPE))
                weight_sum = tree_weight + last_candidate_state.weight
                log_accept_thresh = tf.math.log(
                    tf.cast(tree_weight, tf.float32) /
                    tf.cast(last_candidate_state.weight, tf.float32))
            log_accept_thresh = tf.where(tf.math.is_nan(log_accept_thresh),
                                         tf.zeros([], log_accept_thresh.dtype),
                                         log_accept_thresh)
            u = tf.math.log1p(-samplers.uniform(shape=batch_shape,
                                                dtype=log_accept_thresh.dtype,
                                                seed=acceptance_seed))
            is_sample_accepted = u <= log_accept_thresh

            choose_new_state = is_sample_accepted & continue_tree_final

            new_candidate_state = TreeDoublingStateCandidate(
                state=[
                    bu.where_left_justified_mask(choose_new_state, s0, s1)
                    for s0, s1 in zip(candidate_tree_state.state,
                                      last_candidate_state.state)
                ],
                target=bu.where_left_justified_mask(
                    choose_new_state, candidate_tree_state.target,
                    last_candidate_state.target),
                target_grad_parts=[
                    bu.where_left_justified_mask(choose_new_state, grad0,
                                                 grad1)
                    for grad0, grad1 in zip(
                        candidate_tree_state.target_grad_parts,
                        last_candidate_state.target_grad_parts)
                ],
                energy=bu.where_left_justified_mask(
                    choose_new_state, candidate_tree_state.energy,
                    last_candidate_state.energy),
                weight=weight_sum)

            for new_candidate_state_temp, old_candidate_state_temp in zip(
                    new_candidate_state.state, last_candidate_state.state):
                tensorshape_util.set_shape(new_candidate_state_temp,
                                           old_candidate_state_temp.shape)

            for new_candidate_grad_temp, old_candidate_grad_temp in zip(
                    new_candidate_state.target_grad_parts,
                    last_candidate_state.target_grad_parts):
                tensorshape_util.set_shape(new_candidate_grad_temp,
                                           old_candidate_grad_temp.shape)

            # Update the left/right information of the trajectory, and check
            # for a trajectory-level U-turn.
            tree_otherend_states = tf.nest.map_structure(
                lambda v: bu.where_left_justified_mask(direction, v[0], v[1]),
                initial_step_state)

            new_step_state = tf.nest.pack_sequence_as(
                initial_step_state,
                [
                    tf.stack(
                        [  # pylint: disable=g-complex-comprehension
                            bu.where_left_justified_mask(
                                direction, right, left),
                            bu.where_left_justified_mask(
                                direction, left, right),
                        ],
                        axis=0) for left, right in zip(
                            tf.nest.flatten(tree_final_states),
                            tf.nest.flatten(tree_otherend_states))
                ])

            momentum_tree_cumsum = []
            for p0, p1 in zip(initial_step_metastate.momentum_sum,
                              momentum_subtree_cumsum):
                momentum_part_temp = p0 + p1
                tensorshape_util.set_shape(momentum_part_temp, p0.shape)
                momentum_tree_cumsum.append(momentum_part_temp)

            for new_state_temp, old_state_temp in zip(
                    tf.nest.flatten(new_step_state),
                    tf.nest.flatten(initial_step_state)):
                tensorshape_util.set_shape(new_state_temp,
                                           old_state_temp.shape)

            if GENERALIZED_UTURN:
                state_diff = momentum_tree_cumsum
            else:
                state_diff = [s[1] - s[0] for s in new_step_state.state]

            no_u_turns_trajectory = has_not_u_turn(
                state_diff, [m[0] for m in new_step_state.momentum],
                [m[1] for m in new_step_state.momentum],
                log_prob_rank=ps.rank_from_shape(batch_shape))

            new_step_metastate = TreeDoublingMetaState(
                candidate_state=new_candidate_state,
                is_accepted=choose_new_state
                | initial_step_metastate.is_accepted,
                momentum_sum=momentum_tree_cumsum,
                energy_diff_sum=energy_diff_sum,
                continue_tree=continue_tree_final & no_u_turns_trajectory,
                not_divergence=final_not_divergence,
                leapfrog_count=(initial_step_metastate.leapfrog_count +
                                leapfrogs_taken))

            return iter_ + 1, next_seed, new_step_state, new_step_metastate
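
A toy check of the `MULTINOMIAL_SAMPLE` acceptance rule above: accepting when `log1p(-u) <= log_w_new - log_w_old` selects the new subtree's candidate with probability `min(1, w_new / w_old)`. The weights below are made up:

import numpy as np

rng = np.random.default_rng(0)
log_w_new, log_w_old = np.log(0.3), np.log(0.9)
u = rng.uniform(size=100_000)
accept = np.log1p(-u) <= log_w_new - log_w_old
print(accept.mean())   # ~0.333, i.e. roughly w_new / w_old = 0.3 / 0.9
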
Example No. 10
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Takes one step of the TransitionKernel.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s).
      previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
        `list` of `Tensor`s representing internal calculations made within the
        previous call to this function (or as returned by `bootstrap_results`).
      seed: Optional, a seed for reproducible sampling.

    Returns:
      next_state: `Tensor` or Python `list` of `Tensor`s representing the
        next state(s) of the Markov chain(s).
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.
        This includes replica states.
    """

        # The code below propagates one step states of shape
        #  [n_replica] + batch_shape + event_shape.
        #
        # The step is done in three parts:
        #  1) Call one_step to transition states via a tempered version of
        #     self.target_log_prob_fn (see _replica_target_log_prob).
        #  2) Permute values in states
        #  3) Update state-dependent values, such as log_probs.
        #
        # We chose to swap states, rather than temperatures, because...
        # (i)  If swapping temperatures, you *still* have to swap log_probs to
        #      determine acceptance, as well as states (for kernel results).
        #      So it's just as difficult to swap temperatures.
        # (ii) If swapping temperatures, you have to take care to swap any user-
        #      supplied temperature related things (like step size).
        #      A-priori, we don't know what else will need to be swapped!
        # (iii) In both cases, the kernel results need to be updated in a
        #       non-trivial manner, so we either special-case or use bootstrap.

        with tf.name_scope(mcmc_util.make_name(self.name, 'remc', 'one_step')):
            # Force a read in case the `inverse_temperatures` is a `tf.Variable`.
            inverse_temperatures = tf.convert_to_tensor(
                previous_kernel_results.inverse_temperatures,
                name='inverse_temperatures')

            target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
                target_log_prob_fn=self.target_log_prob_fn,
                inverse_temperatures=inverse_temperatures,
                untempered_log_prob_fn=self.untempered_log_prob_fn,
                tempered_log_prob_fn=self.tempered_log_prob_fn,
            )
            # TODO(b/159636942): Clean up the helpful error msg after 2020-11-10.
            try:
                inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                    target_log_prob_for_inner_kernel)
            except TypeError as e:
                if 'argument' not in str(e):
                    raise
                raise TypeError(
                    '`ReplicaExchangeMC`s `make_kernel_fn` no longer receives a `seed` '
                    'argument. `TransitionKernel` instances now receive seeds via '
                    '`one_step`.')

            seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
            inner_seed, swap_seed, logu_seed = samplers.split_seed(seed, n=3)
            # Step the inner TransitionKernel.
            [
                pre_swap_replica_states,
                pre_swap_replica_results,
            ] = inner_kernel.one_step(
                previous_kernel_results.post_swap_replica_states,
                previous_kernel_results.post_swap_replica_results,
                seed=inner_seed)

            pre_swap_replica_target_log_prob = _get_field(
                # These are tempered log probs (have been divided by temperature).
                pre_swap_replica_results,
                'target_log_prob')

            dtype = pre_swap_replica_target_log_prob.dtype
            replica_and_batch_shape = ps.shape(
                pre_swap_replica_target_log_prob)
            batch_shape = replica_and_batch_shape[1:]
            replica_and_batch_rank = ps.rank(pre_swap_replica_target_log_prob)
            num_replica = ps.size0(inverse_temperatures)

            inverse_temperatures = bu.left_justified_broadcast_to(
                inverse_temperatures, replica_and_batch_shape)

            # Now that each replica has done one_step, it is time to consider swaps.

            # swaps.shape = [n_replica], and is a "once only" permutation, meaning it
            # is achievable by a sequence of pairwise permutations, where each element
            # is moved at most once.
            # E.g. if swaps = [1, 0, 2], we will consider swapping temperatures 0 and
            # 1, keeping 2 fixed.  This exact same swap is considered for *every*
            # batch member.  Of course some batch members may accept and some reject.
            try:
                swaps = tf.cast(
                    self.swap_proposal_fn(  # pylint: disable=not-callable
                        num_replica,
                        batch_shape=batch_shape,
                        seed=swap_seed,
                        step_count=previous_kernel_results.step_count),
                    dtype=tf.int32)
            except TypeError as e:
                if 'step_count' not in str(e):
                    raise
                warnings.warn(
                    'The `swap_proposal_fn` given to ReplicaExchangeMC did not accept '
                    'the `step_count` argument. Falling back to omitting the '
                    'argument. This fallback will be removed after 24-Oct-2020.'
                )
                swaps = tf.cast(
                    self.swap_proposal_fn(  # pylint: disable=not-callable
                        num_replica,
                        batch_shape=batch_shape,
                        seed=swap_seed),
                    dtype=tf.int32)

            null_swaps = bu.left_justified_expand_dims_like(
                tf.range(num_replica, dtype=swaps.dtype), swaps)
            swaps = _maybe_embed_swaps_validation(swaps, null_swaps,
                                                  self.validate_args)

            # Un-temper the log probs for use in the swap acceptance ratio.
            if self.tempered_log_prob_fn is None:
                # Efficient way of re-evaluating target_log_prob_fn on the
                # pre_swap_replica_states.
                untempered_negative_energy_ignoring_ulp = (
                    # Since untempered_log_prob_fn is None, we may assume
                    # inverse_temperatures > 0 (else the target is improper).
                    pre_swap_replica_target_log_prob / inverse_temperatures)
            else:
                # The untempered_log_prob_fn does not factor into the acceptance ratio.
                # Proof: Suppose the tempered target is
                #   p_k(x) = f(x)^{beta_k} g(x),
                # So f(x) is tempered, and g(x) is not.  Then, the acceptance ratio for
                # a 1 <--> 2 swap is...
                #   (p_1(x_2) p_2(x_1)) / (p_1(x_1) p_2(x_2))
                # which depends only on f(x), since terms involving g(x) cancel.
                untempered_negative_energy_ignoring_ulp = self.tempered_log_prob_fn(
                    *pre_swap_replica_states)

            # Since `swaps` is its own inverse permutation we automatically know the
            # swap counterpart: range(num_replica). We use this idea to compute the
            # acceptance in a vectorized manner at the cost of wasting roughly half
            # our computation. Although we could use `unique` to solve this problem,
            # we expect the cost of `unique` to be higher than the dozens of wasted
            # arithmetic calculations. Worse, it'd mean we need dynamically sized
            # Tensors (e.g., using `tf.where(bool)`) and so we wouldn't be able to
            # XLA compile.

            # Note: diffs would normally be "proposed - current" however energy is
            # flipped since `energy == -log_prob`.
            # Note: The untempered_log_prob_fn (if provided) is not included in
            # untempered_pre_swap_replica_target_log_prob, and hence does not factor
            # into energy_diff. Why? Because, it cancels out in the acceptance ratio.
            energy_diff = (untempered_negative_energy_ignoring_ulp -
                           mcmc_util.index_remapping_gather(
                               untempered_negative_energy_ignoring_ulp,
                               swaps,
                               name='gather_swap_tlp'))
            swapped_inverse_temperatures = mcmc_util.index_remapping_gather(
                inverse_temperatures, swaps, name='gather_swap_temps')
            inverse_temp_diff = swapped_inverse_temperatures - inverse_temperatures

            # If i and j are swapping, log_accept_ratio[i] and
            # log_accept_ratio[j] are equal.
            log_accept_ratio = (energy_diff * bu.left_justified_expand_dims_to(
                inverse_temp_diff, replica_and_batch_rank))

            log_accept_ratio = tf.where(tf.math.is_finite(log_accept_ratio),
                                        log_accept_ratio,
                                        tf.constant(-np.inf, dtype=dtype))

            # Produce log[Uniform] draws that are identical at swapped indices.
            log_uniform = tf.math.log(
                samplers.uniform(shape=replica_and_batch_shape,
                                 dtype=dtype,
                                 seed=logu_seed))
            anchor_swaps = tf.minimum(swaps, null_swaps)
            log_uniform = mcmc_util.index_remapping_gather(
                log_uniform, anchor_swaps)

            is_swap_accepted_mask = tf.less(log_uniform,
                                            log_accept_ratio,
                                            name='is_swap_accepted_mask')

            def _swap_tensor(x):
                return mcmc_util.choose(
                    is_swap_accepted_mask,
                    mcmc_util.index_remapping_gather(x, swaps), x)

            post_swap_replica_states = [
                _swap_tensor(s) for s in pre_swap_replica_states
            ]

            expanded_null_swaps = bu.left_justified_broadcast_to(
                null_swaps, replica_and_batch_shape)
            is_swap_proposed = _compute_swap_notmatrix(
                # Broadcast both so they have shape [num_replica] + batch_shape.
                # This (i) makes them have same shape as is_swap_accepted, and
                # (ii) keeps shape consistent if someday swaps has a batch shape.
                expanded_null_swaps,
                bu.left_justified_broadcast_to(swaps, replica_and_batch_shape))

            # To get is_swap_accepted in ordered position, we use
            # _compute_swap_notmatrix on current and next replica positions.
            post_swap_replica_position = _swap_tensor(expanded_null_swaps)

            is_swap_accepted = _compute_swap_notmatrix(
                post_swap_replica_position, expanded_null_swaps)

            if self._state_includes_replicas:
                post_swap_states = post_swap_replica_states
            else:
                post_swap_states = [s[0] for s in post_swap_replica_states]

            post_swap_replica_results = _set_swapped_fields_to_nan(
                _swap_log_prob_and_maybe_grads(pre_swap_replica_results,
                                               post_swap_replica_states,
                                               inner_kernel))

            if mcmc_util.is_list_like(current_state):
                # We *always* canonicalize the states in the kernel results.
                states = post_swap_states
            else:
                states = post_swap_states[0]

            post_swap_kernel_results = ReplicaExchangeMCKernelResults(
                post_swap_replica_states=post_swap_replica_states,
                pre_swap_replica_results=pre_swap_replica_results,
                post_swap_replica_results=post_swap_replica_results,
                is_swap_proposed=is_swap_proposed,
                is_swap_accepted=is_swap_accepted,
                is_swap_proposed_adjacent=_sub_diag(is_swap_proposed),
                is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
                # Store the original pkr.inverse_temperatures in case it's a
                # `tf.Variable`.
                inverse_temperatures=previous_kernel_results.
                inverse_temperatures,
                swaps=swaps,
                step_count=previous_kernel_results.step_count + 1,
                seed=seed,
                potential_energy=-untempered_negative_energy_ignoring_ulp,
            )

            return states, post_swap_kernel_results
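
A toy, self-contained illustration of the swap acceptance computation above for two replicas and no batch dimensions (all numbers are made up); it also shows why swapping partners share the same `log_accept_ratio`:

import numpy as np

inverse_temperatures = np.array([1.0, 0.5])     # beta_0 > beta_1
potential_energy = np.array([3.2, 1.1])         # U(x_0), U(x_1)
neg_energy = -potential_energy                  # untempered negative energy
swaps = np.array([1, 0])                        # propose swapping replicas 0 and 1

energy_diff = neg_energy - neg_energy[swaps]
inverse_temp_diff = inverse_temperatures[swaps] - inverse_temperatures
log_accept_ratio = energy_diff * inverse_temp_diff
# log_accept_ratio[0] == log_accept_ratio[1]: the 0 <-> 1 swap is accepted or
# rejected as a pair, using a shared log-uniform draw gathered at the anchor.
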
Example No. 11
def _normalize(x, axis, named_axis):
    norm = tf.sqrt(_dot_product(x, x, axis, named_axis)) + 1e-20
    return tf.nest.map_structure(
        lambda x: x / bu.left_justified_expand_dims_like(norm, x), x)
Example No. 12
def _weighted_sum_part(x):
    return distribute_lib.reduce_sum(
        bu.left_justified_expand_dims_like(state_dot_p, x) * x,
        reduce_axes, self.experimental_reduce_chain_axis_names)
Example No. 13
    def __init__(self,
                 design_matrix,
                 nonzero_prior_prob=0.5,
                 weights_prior_precision=None,
                 default_pseudo_observations=1.,
                 observation_noise_variance_prior_concentration=0.005,
                 observation_noise_variance_prior_scale=0.0025,
                 observation_noise_variance_upper_bound=None,
                 num_missing=0.):
        """Initializes priors for the spike and slab sampler.

    Args:
      design_matrix: (batch of) float `Tensor`(s) regression design matrix (`X`
        in [1]) having shape `[num_outputs, num_features]`.
      nonzero_prior_prob: scalar float `Tensor` prior probability of the 'slab',
        i.e., prior probability that any given feature has nonzero weight (`pi`
        in [1]). Default value: `0.5`.
      weights_prior_precision: (batch of) float `Tensor` complete prior
        precision matrix(s) over the weights, of shape `[num_features,
        num_features]`. If not specified, defaults to the Zellner g-prior
        specified in `[1]` as `Omega^{-1} = kappa * (X'X + diag(X'X)) / (2 *
        num_outputs)`, in which we've plugged in the suggested default of `w =
        0.5`. The parameter `kappa` is controlled by the
        `default_pseudo_observations` argument. Default value: `None`.
      default_pseudo_observations: scalar float `Tensor` controlling the number
        of pseudo-observations for the prior precision matrix over the weights.
        Corresponds to `kappa` in [1]. See also `weights_prior_precision`.
      observation_noise_variance_prior_concentration: scalar float `Tensor`
        concentration parameter of the inverse gamma prior on the noise
        variance. Corresponds to `nu / 2` in [1]. Default value: 0.005.
      observation_noise_variance_prior_scale: scalar float `Tensor` scale
        parameter of the inverse gamma prior on the noise variance. Corresponds
        to `ss / 2` in [1]. Default value: 0.0025.
      observation_noise_variance_upper_bound: optional scalar float `Tensor`
        maximum value of sampled observation noise variance. Specifying a bound
        can help avoid divergence when the sampler is initialized far from the
        posterior. Default value: `None`.
      num_missing: Optional scalar float `Tensor`. Corrects for how many
        missing values are coded as zero in the design matrix.
    """
        with tf.name_scope('spike_slab_sampler'):
            dtype = dtype_util.common_dtype([
                design_matrix, nonzero_prior_prob, weights_prior_precision,
                observation_noise_variance_prior_concentration,
                observation_noise_variance_prior_scale,
                observation_noise_variance_upper_bound, num_missing
            ],
                                            dtype_hint=tf.float32)
            design_matrix = tf.convert_to_tensor(design_matrix, dtype=dtype)
            nonzero_prior_prob = tf.convert_to_tensor(nonzero_prior_prob,
                                                      dtype=dtype)
            observation_noise_variance_prior_concentration = tf.convert_to_tensor(
                observation_noise_variance_prior_concentration, dtype=dtype)
            observation_noise_variance_prior_scale = tf.convert_to_tensor(
                observation_noise_variance_prior_scale, dtype=dtype)
            num_missing = tf.convert_to_tensor(num_missing, dtype=dtype)
            if observation_noise_variance_upper_bound is not None:
                observation_noise_variance_upper_bound = tf.convert_to_tensor(
                    observation_noise_variance_upper_bound, dtype=dtype)

            design_shape = ps.shape(design_matrix)
            num_outputs = tf.cast(design_shape[-2], dtype=dtype) - num_missing
            num_features = design_shape[-1]

            x_transpose_x = tf.matmul(design_matrix,
                                      design_matrix,
                                      adjoint_a=True)
            if weights_prior_precision is None:
                # Default prior: 'Zellner's g-prior' from section 3.2.1 of [1]:
                #   `omega^{-1} = kappa * (w X'X + (1 - w) diag(X'X))/n`
                # with default `w = 0.5`.
                padded_inputs = broadcast_util.left_justified_expand_dims_like(
                    num_outputs, x_transpose_x)
                weights_prior_precision = default_pseudo_observations * tf.linalg.set_diag(
                    0.5 * x_transpose_x,
                    tf.linalg.diag_part(x_transpose_x)) / padded_inputs

            observation_noise_variance_posterior_concentration = (
                observation_noise_variance_prior_concentration +
                tf.convert_to_tensor(num_outputs / 2., dtype=dtype))

            self.num_outputs = num_outputs
            self.num_features = num_features
            self.design_matrix = design_matrix
            self.x_transpose_x = x_transpose_x
            self.dtype = dtype
            self.nonzeros_prior = sample_dist.Sample(
                bernoulli.Bernoulli(probs=nonzero_prior_prob),
                sample_shape=[num_features])
            self.weights_prior_precision = weights_prior_precision
            self.observation_noise_variance_prior_concentration = (
                observation_noise_variance_prior_concentration)
            self.observation_noise_variance_prior_scale = (
                observation_noise_variance_prior_scale)
            self.observation_noise_variance_upper_bound = (
                observation_noise_variance_upper_bound)
            self.observation_noise_variance_posterior_concentration = (
                observation_noise_variance_posterior_concentration)
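
A small NumPy sketch of the default prior construction above, with the suggested `w = 0.5` plugged in, i.e. `Omega^{-1} = kappa * (0.5 X'X + 0.5 diag(X'X)) / num_outputs`; the design matrix below is made up:

import numpy as np

X = np.random.randn(100, 3)          # [num_outputs, num_features]
kappa = 1.0                           # default_pseudo_observations
num_outputs = X.shape[0]
xtx = X.T @ X
# Halve the off-diagonal entries, keep the diagonal, then scale by kappa / n.
weights_prior_precision = (
    kappa * (0.5 * xtx + 0.5 * np.diag(np.diag(xtx))) / num_outputs)
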
Example No. 14
def remc_thermodynamic_integrals(
    inverse_temperatures,
    potential_energy,
    iid_chain_ndims=0,
):
    """Estimate thermodynamic integrals using results of ReplicaExchangeMC.

  Write the density, when tempering with inverse temperature `b`, as
  `p_b(x) = exp(-b * U(x)) f(x) / Z_b`. Here `Z_b` is a normalizing constant,
  and `U(x)` is the potential energy. `f(x)` is the untempered part, if any.

  Let `E_b[U(X)]` be the expected potential energy when `X ~ p_b`. Then,
  `-1 * integral_c^d E_b[U(X)] db = log[Z_d / Z_c]`, the log normalizing
  constant ratio.

  Let `Var_b[U(X)]` be the variance of potential energy when `X ~ p_b(x)`. Then,
  `integral_c^d Var_b[U(X)] db = E_d[U(X)] - E_c[U(X)]`, the cross entropy
  difference.

  Integration is done via the trapezoidal rule. Assume `E_b[U(X)]` and
  `Var_b[U(X)]` have bounded second derivatives, uniform in `b`. Then, the
  bias due to approximation of the integral by a summation is `O(1 / K^2)`.

  Suppose `U(X)`, `X ~ p_b` has bounded fourth moment, uniform in `b`. Suppose
  further that the swap acceptance rate between every adjacent pair is greater
  than `C_s > 0`.  If we have `N` effective samples from each of the `n_replica`
  replicas, then the standard error of the summation is
  `O(1 / Sqrt(n_replica * N))`.

  Args:
    inverse_temperatures: `Tensor` of shape `[n_replica, ...]`, used to temper
      `n_replica` replicas. Assumed to be decreasing with respect to the replica
      index.
    potential_energy: The `potential_energy` field of
      `ReplicaExchangeMCKernelResults`, shape `[n_samples, n_replica, ...]`.
      If the kth replica has density `p_k(x) = exp(-beta_k * U(x)) * f_k(x)`,
      then `potential_energy[k]` is `U(X)`, where `X ~ p_k`.
    iid_chain_ndims: Number of dimensions in `potential_energy`, to the
      right of the replica dimension, that index independent identically
      distributed chains. In particular, the temperature for these chains should
      be identical. The sample means will be computed over these dimensions.

  Returns:
    ReplicaExchangeMCThermodynamicIntegrals namedtuple.
  """
    dtype = dtype_util.common_dtype([inverse_temperatures, potential_energy],
                                    dtype_hint=tf.float32)
    inverse_temperatures = tf.convert_to_tensor(inverse_temperatures,
                                                dtype=dtype)
    potential_energy = tf.convert_to_tensor(potential_energy, dtype=dtype)

    # mean is E[U(beta)].
    # Reduction is over samples and (possibly) independent chains.
    # Squeeze out the singleton left over from samples in axis=0.
    # Keepdims so we can broadcast with inverse_temperatures, which *may* have
    # additional batch dimensions.
    iid_axis = ps.concat([[0], ps.range(2, 2 + iid_chain_ndims)], axis=0)
    mean = tf.reduce_mean(potential_energy, axis=iid_axis, keepdims=True)[0]
    var = sample_stats.variance(potential_energy,
                                sample_axis=iid_axis,
                                keepdims=True)[0]

    # Integrate over the single temperature dimension.
    # dx[k] = beta_k - beta_{k+1} > 0.
    dx = bu.left_justified_expand_dims_like(
        inverse_temperatures[:-1] - inverse_temperatures[1:], mean)

    def _trapz(y):
        avg_y = 0.5 * (y[:-1] + y[1:])
        return tf.reduce_sum(avg_y * dx, axis=0)

    def _squeeze_chains(x):
        # Squeeze with a reshape, since squeeze can't use tensors.
        return tf.reshape(x, ps.shape(x)[iid_chain_ndims:])

    return ReplicaExchangeMCThermodynamicIntegrals(
        log_normalizing_constant_ratio=-_squeeze_chains(_trapz(mean)),
        cross_entropy_difference=_squeeze_chains(_trapz(var)),
    )
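
The `_trapz` helper above is the trapezoidal rule over the decreasing inverse-temperature grid. A quick standalone sketch with a made-up integrand standing in for `E_b[U(X)]`:

import numpy as np

beta = np.array([1.0, 0.7, 0.4, 0.1])          # decreasing inverse temperatures
y = beta**2                                     # stand-in for E_b[U(X)]
dx = beta[:-1] - beta[1:]                       # positive widths
# Trapezoid approximation of the integral of y over b in [0.1, 1.0].
integral = np.sum(0.5 * (y[:-1] + y[1:]) * dx)
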
def compute_hmc_step_size(scalings, state_std, num_leapfrog_steps):
    return [
        s / ps.cast(num_leapfrog_steps, s.dtype) * ps.cast(  # pylint: disable=g-complex-comprehension
            bu.left_justified_expand_dims_like(scalings, s), s.dtype)
        for s in state_std
    ]
def adjust_state(x, v, shard_axes=None):
  broadcasted_dt = distribute_lib.pbroadcast(
      bu.left_justified_expand_dims_like(dt, v), shard_axes)
  return x + broadcasted_dt * v