Example #1
def _log_bessel_kve_jvp(primals, tangents):
  """Computes JVP for bessel_kve (supports JAX custom derivative)."""
  v, z = primals
  _, dz = tangents

  dtype = dtype_util.common_dtype([v, z], tf.float32)
  numpy_dtype = dtype_util.as_numpy_dtype(dtype)

  # TODO(https://github.com/google/jax/issues/3768): eliminate broadcast_to?
  bc_shp = ps.broadcast_shape(ps.shape(v), ps.shape(dz))
  dz = tf.broadcast_to(dz, bc_shp)

  log_kve = _log_bessel_kve_custom_gradient(v, z)
  pz = tfp_math.log_add_exp(
      _log_bessel_kve_custom_gradient(v - 1., z),
      _log_bessel_kve_custom_gradient(v + 1., z)) - numpy_dtype(
          np.log(2.)) - log_kve
  pz = -tf.math.expm1(pz)

  # `bessel_kve` does not have gradients with respect to `v`, and thus
  # this `JVP` rule matches TF.
  # Ideally, it would be nice to throw an exception when taking gradients with
  # respect to `v` in JAX mode, but this is not possible at the moment with
  # `custom_jvp`.
  # See https://github.com/google/jax/issues/5913 for details.
  # TODO(https://github.com/google/jax/issues/5913): Define vjp for v.

  return log_kve, pz * dz
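
This JVP rule relies on the derivative identity K_v'(z) = -(K_{v-1}(z) + K_{v+1}(z)) / 2, so for the exponentially scaled kve(v, z) = K_v(z) * exp(z) one gets d/dz log(kve) = 1 - (K_{v-1} + K_{v+1}) / (2 * K_v), which is exactly the `-expm1(pz)` computed above. A quick SciPy spot-check of that identity (not part of the original code):

import numpy as np
from scipy import special

v, z, eps = 1.5, 2.3, 1e-6
# Central finite difference of log(kve(v, z)).
fd = (np.log(special.kve(v, z + eps))
      - np.log(special.kve(v, z - eps))) / (2. * eps)
# Closed form used by the JVP rule above.
closed = 1. - (special.kve(v - 1., z)
               + special.kve(v + 1., z)) / (2. * special.kve(v, z))
np.testing.assert_allclose(fd, closed, rtol=1e-5)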
Example #2

 def mutate_onestep(i, state, pkr, log_accept_prob_sum):
   next_state, next_kernel_results = kernel.one_step(state, pkr)
   kernel_log_accept_ratio, _ = gather_mh_like_result(pkr)
   log_accept_prob = tf.minimum(kernel_log_accept_ratio, 0.)
   log_accept_prob_sum = log_add_exp(
       log_accept_prob_sum, log_accept_prob)
   return i + 1, next_state, next_kernel_results, log_accept_prob_sum
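
This fragment is shaped as a `tf.while_loop` body. Below is a minimal sketch of how such a body is typically driven; `kernel`, `initial_state`, and `num_steps` are assumptions, not part of the original snippet:

import numpy as np
import tensorflow as tf

initial_pkr = kernel.bootstrap_results(initial_state)  # assumed kernel
initial_log_accept_sum = tf.constant(-np.inf, tf.float32)  # log(0)

_, final_state, final_pkr, log_accept_prob_sum = tf.while_loop(
    cond=lambda i, *_: i < num_steps,
    body=mutate_onestep,
    loop_vars=(0, initial_state, initial_pkr, initial_log_accept_sum))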
Example #3

def simple_heuristic_tuning(num_steps,
                            log_scalings,
                            log_accept_prob,
                            optimal_accept=0.234,
                            target_accept_prob=0.99,
                            name=None):
  """Tune the number of steps and scaling of one mutation.

  # TODO(b/152412213): Better explanation of the heuristic used here.

  This is a simple heuristic for tuning the number of steps of the next
  mutation, as well as the scaling of a transition kernel (e.g., step size in
  HMC, scale of a Normal proposal in RWMH) using the acceptance probability from
  the previous mutation stage in SMC.

  Args:
    num_steps: The initial number of steps for the next mutation, to be tuned.
    log_scalings: The log of the scale of the proposal kernel.
    log_accept_prob: The log of the acceptance ratio from the last mutation.
    optimal_accept: Optimal acceptance ratio for a Transitional Kernel. Default
      value is 0.234 (Optimal for Random Walk Metropolis kernel).
    target_accept_prob: Target acceptance probability at the end of one
      mutation step. Default value: 0.99.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None`.

  Returns:
    num_steps: The number of steps for the next mutation.
    new_log_scalings: The log of the scale of the proposal kernel for the next
      mutation.

  """
  with tf.name_scope(name or 'simple_heuristic_tuning'):
    optimal_accept = tf.constant(optimal_accept, dtype=log_accept_prob.dtype)
    target_accept_prob = tf.constant(
        target_accept_prob, dtype=log_accept_prob.dtype)
    log_half_constant = tf.constant(np.log(.5), dtype=log_scalings.dtype)

    avg_log_scalings = reduce_logmeanexp(log_scalings)
    avg_log_accept_prob = reduce_logmeanexp(log_accept_prob)

    avg_log_scaling_target = avg_log_scalings + (
        tf.exp(avg_log_accept_prob) - optimal_accept)
    new_log_scalings = log_half_constant + log_add_exp(
        avg_log_scaling_target,
        log_scalings + (tf.exp(log_accept_prob) - optimal_accept)
        )

    num_replica = ps.size0(log_accept_prob)
    num_proposed = tf.cast(
        num_replica * num_steps, dtype=avg_log_accept_prob.dtype)
    log_avg_accept = tf.math.maximum(-tf.math.log(num_proposed),
                                     avg_log_accept_prob)
    num_steps = tf.cast(
        tf.math.log1p(-target_accept_prob) / log1mexp(log_avg_accept),
        dtype=num_steps.dtype)
    return num_steps, new_log_scalings
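
The closing lines solve 1 - (1 - p)^k >= target_accept_prob for k, where p is the average per-proposal acceptance probability (floored at 1/num_proposed so the logarithm stays finite): since log1mexp(log p) = log(1 - p), the quotient is k = log(1 - target) / log(1 - p). A hedged NumPy rendering of that arithmetic:

import numpy as np

p = 0.234        # average per-proposal acceptance probability
target = 0.99    # desired P(at least one acceptance during the mutation)
# 1 - (1 - p)**k >= target  =>  k = log(1 - target) / log(1 - p).
k = np.log1p(-target) / np.log1p(-p)
print(int(np.ceil(k)))  # ~18 steps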
Example #4
 def bessel_recurrence(index, kve, kvep1):
   if output_log_space:
     next_kvep1 = tfp_math.log_add_exp(
         kvep1 + tf.math.log(u + index) +
         numpy_dtype(np.log(2.)) - tf.math.log(x_abs), kve)
   else:
     next_kvep1 = 2 * (u + index) * kvep1 / x_abs + kve
   kve = tf.where(index > n, kve, kvep1)
   kvep1 = tf.where(index > n, kvep1, next_kvep1)
   return index + 1., kve, kvep1
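
The fragment applies the forward recurrence K_{v+1}(x) = (2v / x) K_v(x) + K_{v-1}(x), which also holds for the exponentially scaled kve since every term carries the same exp(x) factor; in log space the sum becomes a `log_add_exp`. A SciPy spot-check of the recurrence (not from the original code):

import numpy as np
from scipy import special

v, x = 0.3, 1.7
lhs = special.kve(v + 1., x)
rhs = 2. * v / x * special.kve(v, x) + special.kve(v - 1., x)
np.testing.assert_allclose(lhs, rhs, rtol=1e-10)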
Example #5

def _hyp2f1_z_near_one(a, b, c, z):
  """"Compute 2F1(a, b, c, z) when z is near 1."""
  with tf.name_scope('hyp2f1_z_near_one'):
    dtype = dtype_util.common_dtype([a, b, c, z], tf.float32)
    a = tf.convert_to_tensor(a, dtype=dtype)
    b = tf.convert_to_tensor(b, dtype=dtype)
    c = tf.convert_to_tensor(c, dtype=dtype)
    z = tf.convert_to_tensor(z, dtype=dtype)

    # When z > 0.5, we can transform z to 1 - z and make use of a
    # hypergeometric identity.

    d = c - a - b

    # TODO(b/171982819): When tfp.math.log_gamma_difference and tfp.math.lbeta
    # support negative parameters, use them here for greater accuracy.
    log_first_coefficient = (tf.math.lgamma(c) + tf.math.lgamma(d) -
                             tf.math.lgamma(c - a) - tf.math.lgamma(c - b))

    sign_first_coefficient = (
        _gamma_negative(c) ^ _gamma_negative(d) ^
        _gamma_negative(c - a) ^ _gamma_negative(c - b))
    sign_first_coefficient = -2. * tf.cast(sign_first_coefficient, dtype) + 1.

    log_second_coefficient = (
        tf.math.xlog1py(d, -z) +
        tf.math.lgamma(c) + tf.math.lgamma(-d) -
        tf.math.lgamma(a) - tf.math.lgamma(b))

    sign_second_coefficient = (
        _gamma_negative(c) ^ _gamma_negative(a) ^ _gamma_negative(b) ^
        _gamma_negative(-d))
    sign_second_coefficient = -2. * tf.cast(sign_second_coefficient, dtype) + 1.

    first_term = _hyp2f1_internal(a, b, 1 - d, 1 - z)
    second_term = _hyp2f1_internal(c - a, c - b, d + 1., 1 - z)
    log_first_term = log_first_coefficient + tf.math.log(
        tf.math.abs(first_term))
    log_second_term = log_second_coefficient + tf.math.log(
        tf.math.abs(second_term))

    sign_first_term = sign_first_coefficient * tf.math.sign(first_term)
    sign_second_term = sign_second_coefficient * tf.math.sign(second_term)
    log_diff, sign_log_diff = tfp_math.log_sub_exp(
        log_first_term, log_second_term, return_sign=True)
    sign = tf.where(
        tf.math.equal(sign_first_term, sign_second_term),
        sign_first_term,
        sign_first_term * sign_log_diff)
    log_result = tf.where(
        tf.math.equal(sign_first_term, sign_second_term),
        tfp_math.log_add_exp(log_first_term, log_second_term),
        log_diff)
    return tf.math.exp(log_result) * sign
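
The function evaluates the classical connection formula (DLMF 15.8.4), with d = c - a - b:

2F1(a, b; c; z) = [Γ(c)Γ(d) / (Γ(c-a)Γ(c-b))] 2F1(a, b; 1-d; 1-z)
                + (1-z)^d [Γ(c)Γ(-d) / (Γ(a)Γ(b))] 2F1(c-a, c-b; d+1; 1-z),

carrying the two terms in log space with explicit signs so they can be combined by `log_add_exp` or `log_sub_exp`. A SciPy spot-check of the identity (valid when d is not an integer; the parameters here are illustrative):

import numpy as np
from scipy import special

a, b, c, z = 0.3, 0.4, 1.2, 0.9
d = c - a - b
first = (special.gamma(c) * special.gamma(d)
         / (special.gamma(c - a) * special.gamma(c - b))
         * special.hyp2f1(a, b, 1. - d, 1. - z))
second = ((1. - z)**d * special.gamma(c) * special.gamma(-d)
          / (special.gamma(a) * special.gamma(b))
          * special.hyp2f1(c - a, c - b, d + 1., 1. - z))
np.testing.assert_allclose(first + second, special.hyp2f1(a, b, c, z),
                           rtol=1e-8)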
Example #6

                def mutate_onestep(i, seed, state, pkr, log_accept_prob_sum):
                    iter_seed, next_seed = (samplers.split_seed(seed)
                                            if is_seeded else (None, seed))

                    one_step_kwargs = dict(seed=iter_seed) if is_seeded else {}
                    next_state, next_kernel_results = kernel.one_step(
                        state, pkr, **one_step_kwargs)
                    kernel_log_accept_ratio, _ = gather_mh_like_result(pkr)
                    log_accept_prob = tf.minimum(kernel_log_accept_ratio, 0.)
                    log_accept_prob_sum = log_add_exp(log_accept_prob_sum,
                                                      log_accept_prob)
                    return [
                        i + 1, next_seed, next_state, next_kernel_results,
                        log_accept_prob_sum
                    ]
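
This seeded variant threads a stateless seed through the loop: each iteration consumes `iter_seed` and forwards `next_seed`, so rerunning with the same starting seed reproduces the chain. A hedged sketch of that pattern using the public `tfp.random.split_seed` (the snippet itself uses the internal `samplers.split_seed`):

import tensorflow_probability as tfp

seed = [42, 0]  # a shape-[2] stateless seed
for _ in range(3):
    iter_seed, seed = tfp.random.split_seed(seed, n=2)
    # ... pass `iter_seed` to kernel.one_step(..., seed=iter_seed) ...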
Example #7
def _log_soosum_exp_impl(logx, axis, keepdims, compute_mean):
    """Implementation for `*soosum*` functions."""
    with tf.name_scope('log_soosum_exp_impl'):
        logx = tf.convert_to_tensor(logx, name='logx')
        log_loosum_x, log_sum_x, n = _log_loosum_exp_impl(logx,
                                                          axis,
                                                          keepdims,
                                                          compute_mean=False)
        # The swap-one-out-sum ('soosum') is n different sums, each of which
        # replaces the i-th item with the i-th-left-out average (or the user
        # specified value), i.e.,
        # soo_sum_x[i] = [exp(logx) - exp(logx[i])] + exp(mean(logx[!=i]))
        #              =  exp(log_loosum_x[i])      + exp(loo_log_swap_in[i])
        loo_log_swap_in = (
            (tf.reduce_sum(logx, axis=axis, keepdims=True) - logx) / (n - 1.))
        log_soosum_x = log_add_exp(log_loosum_x, loo_log_swap_in)
        if not compute_mean:
            return log_soosum_x, log_sum_x
        log_n = prefer_static.log(n)
        return log_soosum_x - log_n, log_sum_x - log_n
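
A brute-force NumPy check of the swap-one-out sum described in the comment, mirroring (not calling) the log-space computation above:

import numpy as np

rng = np.random.default_rng(0)
logx = rng.normal(size=5)
n = logx.size

# Direct construction: drop the i-th term, swap in exp(mean of the rest).
brute = np.array([
    np.exp(logx).sum() - np.exp(logx[i]) + np.exp(np.delete(logx, i).mean())
    for i in range(n)])

# Log-space construction as in the implementation.
log_loosum = np.array(
    [np.logaddexp.reduce(np.delete(logx, i)) for i in range(n)])
loo_log_swap_in = (logx.sum() - logx) / (n - 1.)
log_soosum = np.logaddexp(log_loosum, loo_log_swap_in)
np.testing.assert_allclose(np.exp(log_soosum), brute, rtol=1e-12)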
Example #8
def _log_bessel_kve_bwd(aux, g):
  """Reverse mode impl for bessel_kve."""
  v, z = aux
  dtype = dtype_util.common_dtype([v, z], tf.float32)
  numpy_dtype = dtype_util.as_numpy_dtype(dtype)

  log_kve = _log_bessel_kve_custom_gradient(v, z)
  grad_z = tfp_math.log_add_exp(
      _log_bessel_kve_custom_gradient(v - 1., z),
      _log_bessel_kve_custom_gradient(v + 1., z)) - numpy_dtype(
          np.log(2.)) - log_kve
  grad_z = g * -tf.math.expm1(grad_z)
  _, grad_z = _fix_gradient_for_broadcasting(
      v, z, tf.ones_like(grad_z), grad_z)

  # No gradient for v at the moment: the gradient with respect to the
  # parameter doesn't have an easy closed form. More work will need to be
  # done to ensure good numerics for the gradient.
  # TODO(b/169357627): Implement gradients of modified bessel functions with
  # respect to parameters.

  return None, grad_z
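
The `(aux, g)` signature follows the usual custom-VJP convention: `aux` holds residuals saved by the forward pass (here the primal inputs `v` and `z`), `g` is the incoming cotangent, and returning `None` for `v` declines that gradient. A minimal, self-contained sketch of the same wiring in JAX, using a toy function rather than the Bessel code:

import jax
import jax.numpy as jnp

@jax.custom_vjp
def log_square(x):
    return 2. * jnp.log(x)

def log_square_fwd(x):
    # Save x as the residual ("aux") for the backward pass.
    return log_square(x), x

def log_square_bwd(aux, g):
    # d/dx 2*log(x) = 2/x, scaled by the upstream cotangent g.
    return (g * 2. / aux,)

log_square.defvjp(log_square_fwd, log_square_bwd)
print(jax.grad(log_square)(3.0))  # 2/3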
Example #9
    def _loop_build_sub_tree(self, directions, integrator,
                             current_step_meta_info, iter_,
                             energy_diff_sum_previous,
                             momentum_cumsum_previous, leapfrogs_taken,
                             prev_tree_state, candidate_tree_state,
                             continue_tree_previous, not_divergent_previous,
                             momentum_state_memory):
        """Base case in tree doubling."""
        with tf.name_scope('loop_build_sub_tree'):
            # Take one leapfrog step in the direction v and check divergence
            [
                next_momentum_parts, next_state_parts, next_target,
                next_target_grad_parts
            ] = integrator(prev_tree_state.momentum, prev_tree_state.state,
                           prev_tree_state.target,
                           prev_tree_state.target_grad_parts)

            next_tree_state = TreeDoublingState(
                momentum=next_momentum_parts,
                state=next_state_parts,
                target=next_target,
                target_grad_parts=next_target_grad_parts)
            momentum_cumsum = [
                p0 + p1 for p0, p1 in zip(momentum_cumsum_previous,
                                          next_momentum_parts)
            ]
            # If the tree has not terminated previously, count this leapfrog.
            leapfrogs_taken = tf.where(continue_tree_previous,
                                       leapfrogs_taken + 1, leapfrogs_taken)

            write_instruction = current_step_meta_info.write_instruction
            read_instruction = current_step_meta_info.read_instruction
            init_energy = current_step_meta_info.init_energy

            if GENERALIZED_UTURN:
                state_to_write = momentum_cumsum_previous
                state_to_check = momentum_cumsum
            else:
                state_to_write = next_state_parts
                state_to_check = next_state_parts

            batch_shape = prefer_static.shape(next_target)
            has_not_u_turn_init = prefer_static.ones(batch_shape,
                                                     dtype=tf.bool)

            read_index = read_instruction.gather([iter_])[0]
            no_u_turns_within_tree = has_not_u_turn_at_all_index(  # pylint: disable=g-long-lambda
                read_index,
                directions,
                momentum_state_memory,
                next_momentum_parts,
                state_to_check,
                has_not_u_turn_init,
                log_prob_rank=prefer_static.rank(next_target))

            # Get index to write state into memory swap
            write_index = write_instruction.gather([iter_])
            momentum_state_memory = MomentumStateSwap(
                momentum_swap=[
                    tf.tensor_scatter_nd_update(old, [write_index], [new])
                    for old, new in zip(momentum_state_memory.momentum_swap,
                                        next_momentum_parts)
                ],
                state_swap=[
                    tf.tensor_scatter_nd_update(old, [write_index], [new])
                    for old, new in zip(momentum_state_memory.state_swap,
                                        state_to_write)
                ])

            energy = compute_hamiltonian(next_target, next_momentum_parts)
            current_energy = tf.where(tf.math.is_nan(energy),
                                      tf.constant(-np.inf, dtype=energy.dtype),
                                      energy)
            energy_diff = current_energy - init_energy

            if MULTINOMIAL_SAMPLE:
                not_divergent = -energy_diff < self.max_energy_diff
                weight_sum = log_add_exp(candidate_tree_state.weight,
                                         energy_diff)
                log_accept_thresh = energy_diff - weight_sum
            else:
                log_slice_sample = current_step_meta_info.log_slice_sample
                not_divergent = log_slice_sample - energy_diff < self.max_energy_diff
                # Uniform sampling on the trajectory within the subtree across valid
                # samples.
                is_valid = log_slice_sample <= energy_diff
                weight_sum = tf.where(is_valid,
                                      candidate_tree_state.weight + 1,
                                      candidate_tree_state.weight)
                log_accept_thresh = tf.where(
                    is_valid,
                    -tf.math.log(tf.cast(weight_sum, dtype=tf.float32)),
                    tf.constant(-np.inf, dtype=tf.float32))
            u = tf.math.log1p(-tf.random.uniform(shape=batch_shape,
                                                 dtype=log_accept_thresh.dtype,
                                                 seed=self._seed_stream()))
            is_sample_accepted = u <= log_accept_thresh

            next_candidate_tree_state = TreeDoublingStateCandidate(
                state=[
                    tf.where(  # pylint: disable=g-complex-comprehension
                        _rightmost_expand_to_rank(is_sample_accepted,
                                                  prefer_static.rank(s0)), s0,
                        s1) for s0, s1 in zip(next_state_parts,
                                              candidate_tree_state.state)
                ],
                target=tf.where(
                    _rightmost_expand_to_rank(is_sample_accepted,
                                              prefer_static.rank(next_target)),
                    next_target, candidate_tree_state.target),
                target_grad_parts=[
                    tf.where(  # pylint: disable=g-complex-comprehension
                        _rightmost_expand_to_rank(is_sample_accepted,
                                                  prefer_static.rank(grad0)),
                        grad0, grad1) for grad0, grad1 in zip(
                            next_target_grad_parts,
                            candidate_tree_state.target_grad_parts)
                ],
                energy=tf.where(
                    _rightmost_expand_to_rank(is_sample_accepted,
                                              prefer_static.rank(next_target)),
                    current_energy, candidate_tree_state.energy),
                weight=weight_sum)

            continue_tree = not_divergent & continue_tree_previous
            continue_tree_next = no_u_turns_within_tree & continue_tree

            not_divergent_tokeep = tf.where(
                continue_tree_previous, not_divergent,
                prefer_static.ones(batch_shape, dtype=tf.bool))

            # min(1., exp(energy_diff)).
            exp_energy_diff = tf.math.exp(tf.minimum(energy_diff, 0.))
            energy_diff_sum = tf.where(
                continue_tree, energy_diff_sum_previous + exp_energy_diff,
                energy_diff_sum_previous)

            return (
                iter_ + 1,
                energy_diff_sum,
                momentum_cumsum,
                leapfrogs_taken,
                next_tree_state,
                next_candidate_tree_state,
                continue_tree_next,
                not_divergent_previous & not_divergent_tokeep,
                momentum_state_memory,
            )
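
In the `MULTINOMIAL_SAMPLE` branch, `log_add_exp` maintains a running log-normalizer, and each new leapfrog point is accepted with probability exp(energy_diff - weight_sum), i.e., its weight over the total weight so far. That is streaming (reservoir-style) categorical sampling; a hedged NumPy sketch of the scheme:

import numpy as np

rng = np.random.default_rng(0)
log_weights = rng.normal(size=100)  # stand-ins for energy_diff

chosen = 0
log_total = log_weights[0]
for i in range(1, log_weights.size):
    log_total = np.logaddexp(log_total, log_weights[i])
    # Accept point i with probability w_i / sum_{j <= i} w_j.
    if np.log(rng.uniform()) <= log_weights[i] - log_total:
        chosen = i
# chosen is a single draw from Categorical(softmax(log_weights)).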
Example #10
    def loop_tree_doubling(self, step_size, momentum_state_memory,
                           current_step_meta_info, iter_, initial_step_state,
                           initial_step_metastate):
        """Main loop for tree doubling."""
        with tf.name_scope('loop_tree_doubling'):
            batch_shape = prefer_static.shape(
                current_step_meta_info.init_energy)
            direction = tf.cast(tf.random.uniform(shape=batch_shape,
                                                  minval=0,
                                                  maxval=2,
                                                  dtype=tf.int32,
                                                  seed=self._seed_stream()),
                                dtype=tf.bool)

            tree_start_states = tf.nest.map_structure(
                lambda v: tf.where(  # pylint: disable=g-long-lambda
                    _rightmost_expand_to_rank(
                        direction, prefer_static.rank(v[1])), v[1], v[0]),
                initial_step_state)

            directions_expanded = [
                _rightmost_expand_to_rank(direction, prefer_static.rank(state))
                for state in tree_start_states.state
            ]

            integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
                self.target_log_prob_fn,
                step_sizes=[
                    tf.where(d, ss, -ss)
                    for d, ss in zip(directions_expanded, step_size)
                ],
                num_steps=self.unrolled_leapfrog_steps)

            [
                candidate_tree_state, tree_final_states, final_not_divergence,
                continue_tree_final, energy_diff_tree_sum,
                momentum_subtree_cumsum, leapfrogs_taken
            ] = self._build_sub_tree(
                directions_expanded,
                integrator,
                current_step_meta_info,
                # num_steps_at_this_depth = 2**iter_ = 1 << iter_
                tf.bitwise.left_shift(1, iter_),
                tree_start_states,
                initial_step_metastate.continue_tree,
                initial_step_metastate.not_divergence,
                momentum_state_memory)

            last_candidate_state = initial_step_metastate.candidate_state

            energy_diff_sum = (energy_diff_tree_sum +
                               initial_step_metastate.energy_diff_sum)
            if MULTINOMIAL_SAMPLE:
                tree_weight = tf.where(
                    continue_tree_final, candidate_tree_state.weight,
                    tf.constant(-np.inf,
                                dtype=candidate_tree_state.weight.dtype))
                weight_sum = log_add_exp(tree_weight,
                                         last_candidate_state.weight)
                log_accept_thresh = tree_weight - last_candidate_state.weight
            else:
                tree_weight = tf.where(continue_tree_final,
                                       candidate_tree_state.weight,
                                       tf.zeros([], dtype=TREE_COUNT_DTYPE))
                weight_sum = tree_weight + last_candidate_state.weight
                log_accept_thresh = tf.math.log(
                    tf.cast(tree_weight, tf.float32) /
                    tf.cast(last_candidate_state.weight, tf.float32))
            log_accept_thresh = tf.where(tf.math.is_nan(log_accept_thresh),
                                         tf.zeros([], log_accept_thresh.dtype),
                                         log_accept_thresh)
            u = tf.math.log1p(-tf.random.uniform(shape=batch_shape,
                                                 dtype=log_accept_thresh.dtype,
                                                 seed=self._seed_stream()))
            is_sample_accepted = u <= log_accept_thresh

            choose_new_state = is_sample_accepted & continue_tree_final

            new_candidate_state = TreeDoublingStateCandidate(
                state=[
                    tf.where(  # pylint: disable=g-complex-comprehension
                        _rightmost_expand_to_rank(choose_new_state,
                                                  prefer_static.rank(s0)), s0,
                        s1) for s0, s1 in zip(candidate_tree_state.state,
                                              last_candidate_state.state)
                ],
                target=tf.where(
                    _rightmost_expand_to_rank(
                        choose_new_state,
                        prefer_static.rank(candidate_tree_state.target)),
                    candidate_tree_state.target, last_candidate_state.target),
                target_grad_parts=[
                    tf.where(  # pylint: disable=g-complex-comprehension
                        _rightmost_expand_to_rank(choose_new_state,
                                                  prefer_static.rank(grad0)),
                        grad0, grad1) for grad0, grad1 in zip(
                            candidate_tree_state.target_grad_parts,
                            last_candidate_state.target_grad_parts)
                ],
                energy=tf.where(
                    _rightmost_expand_to_rank(
                        choose_new_state,
                        prefer_static.rank(candidate_tree_state.target)),
                    candidate_tree_state.energy, last_candidate_state.energy),
                weight=weight_sum)

            for new_candidate_state_temp, old_candidate_state_temp in zip(
                    new_candidate_state.state, last_candidate_state.state):
                tensorshape_util.set_shape(new_candidate_state_temp,
                                           old_candidate_state_temp.shape)

            for new_candidate_grad_temp, old_candidate_grad_temp in zip(
                    new_candidate_state.target_grad_parts,
                    last_candidate_state.target_grad_parts):
                tensorshape_util.set_shape(new_candidate_grad_temp,
                                           old_candidate_grad_temp.shape)

            # Update the left/right information of the trajectory, and check
            # for a trajectory-level U-turn.
            tree_otherend_states = tf.nest.map_structure(
                lambda v: tf.where(  # pylint: disable=g-long-lambda
                    _rightmost_expand_to_rank(
                        direction, prefer_static.rank(v[1])), v[0], v[1]),
                initial_step_state)

            new_step_state = tf.nest.pack_sequence_as(
                initial_step_state,
                [
                    tf.stack(
                        [  # pylint: disable=g-complex-comprehension
                            tf.where(
                                _rightmost_expand_to_rank(
                                    direction, prefer_static.rank(l)), r, l),
                            tf.where(
                                _rightmost_expand_to_rank(
                                    direction, prefer_static.rank(l)), l, r),
                        ],
                        axis=0)
                    for l, r in zip(tf.nest.flatten(tree_final_states),
                                    tf.nest.flatten(tree_otherend_states))
                ])

            momentum_tree_cumsum = []
            for p0, p1 in zip(initial_step_metastate.momentum_sum,
                              momentum_subtree_cumsum):
                momentum_part_temp = p0 + p1
                tensorshape_util.set_shape(momentum_part_temp, p0.shape)
                momentum_tree_cumsum.append(momentum_part_temp)

            for new_state_temp, old_state_temp in zip(
                    tf.nest.flatten(new_step_state),
                    tf.nest.flatten(initial_step_state)):
                tensorshape_util.set_shape(new_state_temp,
                                           old_state_temp.shape)

            if GENERALIZED_UTURN:
                state_diff = momentum_tree_cumsum
            else:
                state_diff = [s[1] - s[0] for s in new_step_state.state]

            no_u_turns_trajectory = has_not_u_turn(
                state_diff, [m[0] for m in new_step_state.momentum],
                [m[1] for m in new_step_state.momentum],
                log_prob_rank=prefer_static.rank_from_shape(batch_shape))

            new_step_metastate = TreeDoublingMetaState(
                candidate_state=new_candidate_state,
                is_accepted=choose_new_state
                | initial_step_metastate.is_accepted,
                momentum_sum=momentum_tree_cumsum,
                energy_diff_sum=energy_diff_sum,
                continue_tree=continue_tree_final & no_u_turns_trajectory,
                not_divergence=final_not_divergence,
                leapfrog_count=(initial_step_metastate.leapfrog_count +
                                leapfrogs_taken))

            return iter_ + 1, new_step_state, new_step_metastate
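
Note the acceptance test draws u = log1p(-uniform): since 1 - U is also Uniform(0, 1], this is just a numerically safe sample of log(Uniform), so `u <= log_accept_thresh` fires with probability exp(log_accept_thresh). A quick NumPy check of that reading:

import numpy as np

rng = np.random.default_rng(0)
log_accept_thresh = np.log(0.3)
u = np.log1p(-rng.uniform(size=1_000_000))  # log of a Uniform(0, 1] draw
print((u <= log_accept_thresh).mean())      # ~0.3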
Example #11
def _temme_expansion(v, x, output_log_space=False):
  """Compute modified bessel functions using Temme's method."""
  # The implementation of this is based on [1].
  # [1] N. Temme, On the Numerical Evaluation of the Modified Bessel Function
  #   of the Third Kind. Journal of Computational Physics 19, 1975.
  dtype = dtype_util.common_dtype([v, x], tf.float32)
  numpy_dtype = dtype_util.as_numpy_dtype(dtype)
  v_less_than_zero = v < 0.
  v = tf.math.abs(v)
  n = tf.math.round(v)
  # Use this to compute Kv(u, x) and Kv(u + 1., x)
  u = v - n
  x_abs = tf.math.abs(x)

  small_x = tf.where(x_abs <= 2., x_abs, numpy_dtype(0.1))
  large_x = tf.where(x_abs > 2., x_abs, numpy_dtype(1000.))
  temme_kue, temme_kuep1 = _temme_series(
      u, small_x, output_log_space=output_log_space)
  cf_kue, cf_kuep1 = _continued_fraction_kv(
      u, large_x, output_log_space=output_log_space)

  kue = tf.where(x_abs <= 2., temme_kue, cf_kue)
  kuep1 = tf.where(x_abs <= 2., temme_kuep1, cf_kuep1)

  # Now use the forward recurrence for modified bessel functions
  # to compute Kv(v, x). That is,
  # K_{v + 1}(z) - (2v / z) K_v(z) - K_{v - 1}(z) = 0.
  # This is known to be forward numerically stable.
  # Note: This recurrence is also satisfied by K_v(z) * exp(z)

  def bessel_recurrence(index, kve, kvep1):
    if output_log_space:
      next_kvep1 = tfp_math.log_add_exp(
          kvep1 + tf.math.log(u + index) +
          numpy_dtype(np.log(2.)) - tf.math.log(x_abs), kve)
    else:
      next_kvep1 = 2 * (u + index) * kvep1 / x_abs + kve
    kve = tf.where(index > n, kve, kvep1)
    kvep1 = tf.where(index > n, kvep1, next_kvep1)
    return index + 1., kve, kvep1

  _, kve, kvep1 = tf.while_loop(
      cond=lambda i, *_: tf.reduce_any(i <= n),
      body=bessel_recurrence,
      loop_vars=(tf.cast(1., dtype=dtype), kue, kuep1))

  # Finally, it is known that the Wronskian
  # det(I_v * K'_v - K_v * I'_v) = - 1. / x. We can
  # use this to evaluate I_v by taking advantage of identities of Bessel
  # derivatives.

  if output_log_space:
    ive = -tf.math.log(x_abs) - tfp_math.log_add_exp(
        kve + tf.math.log(bessel_iv_ratio(v + 1., x)), kvep1)
  else:
    ive = tf.math.reciprocal(
        x_abs * (kve * bessel_iv_ratio(v + 1., x) + kvep1))

  # We need to add a correction term for negative v.

  if output_log_space:
    log_ive = ive
    negative_v_correction = kve - 2. * x_abs
  else:
    log_ive = tf.math.log(ive)
    negative_v_correction = tf.math.log(kve) - 2. * x_abs

  coeff = 2 / np.pi * tf.math.sin(np.pi * u)
  coeff = (1. - 2. * tf.math.mod(n, 2.)) * coeff

  lse, sign = tfp_math.log_sub_exp(
      log_ive,
      negative_v_correction + tf.math.log(tf.math.abs(coeff)),
      return_sign=True)
  sign = tf.where(coeff < 0., sign, 1.)

  log_ive_negative_v = tf.where(
      coeff < 0.,
      lse,
      tfp_math.log_add_exp(
          log_ive, negative_v_correction + tf.math.log(tf.math.abs(coeff))))

  z = u + tf.math.mod(n, 2.)

  if output_log_space:
    ive = tf.where(v_less_than_zero, log_ive_negative_v, ive)

    ive = tf.where(
        tf.math.equal(x, 0.),
        tf.where(
            tf.math.equal(v, 0.), numpy_dtype(0.), numpy_dtype(-np.inf)), ive)
  else:
    ive = tf.where(
        v_less_than_zero, sign * tf.math.exp(log_ive_negative_v), ive)

    ive = tf.where(
        tf.math.equal(x, 0.),
        tf.where(tf.math.equal(v, 0.), numpy_dtype(1.), numpy_dtype(0.)), ive)

  ive = tf.where(tf.math.equal(x, 0.) & v_less_than_zero,
                 tf.where(
                     tf.math.equal(z, tf.math.floor(z)),
                     ive,
                     numpy_dtype(np.inf)), ive)

  kve = tf.where(tf.math.equal(x, 0.), numpy_dtype(np.inf), kve)
  ive = tf.where(x < 0., numpy_dtype(np.nan), ive)
  kve = tf.where(x < 0., numpy_dtype(np.nan), kve)
  return ive, kve
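
The Wronskian step uses the identity I_v(x) K_{v+1}(x) + I_{v+1}(x) K_v(x) = 1/x (DLMF 10.28.2); dividing through by I_v(x) and rescaling by exp(x) gives the `ive = 1 / (x_abs * (kve * bessel_iv_ratio(v + 1., x) + kvep1))` form above. A SciPy spot-check of the identity:

import numpy as np
from scipy import special

v, x = 0.7, 3.2
lhs = (special.iv(v, x) * special.kv(v + 1., x)
       + special.iv(v + 1., x) * special.kv(v, x))
np.testing.assert_allclose(lhs, 1. / x, rtol=1e-10)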
Example #12
def _olver_asymptotic_uniform(v, z, output_log_space=False, name=None):
  """Use Olver's uniform asymptotic expansion for the Bessel function.

  Olver's uniform asymptotic expansion [1] is specified by

  `I_v(v, v * z) ~ f(a, v) * sum_k U_k(1 / sqrt(1 + z^2)) / v^k`
  `K_v(v, v * z) ~ f(a, v) * sum_k (-1) ** k * U_k(1 / sqrt(1 + z^2)) / v^k`
  where

  * `f(a, v) = exp(v * a) / (sqrt(2 * pi * v) * (1 + z^2)^0.25)`
  * `U_k(z)` are polynomials that are given in [2]. We use the first
  10 polynomials.

  #### References
  [1]: Digital Library of Mathematical Functions: https://dlmf.nist.gov/10.41
  [2]: F. Olver, Tables for Bessel Functions of Moderate or Large Orders.
       National Physical Laboratory Mathematical Tables, Vol. 6.
       Department of Scientific and Industrial Research

  Args:
    v: value for which `I_{v}(z)` and `K_{v}(z)` should be computed.
    z: value for which `I_{v}(z)` and `K_{v}(z)` should be computed.
    output_log_space: `bool`. If `True`, output is in log-space.
      Default value: `False`.
    name: A name for the operation (optional).
      Default value: `None` (i.e., 'olver_asymptotic_uniform').
  Returns:
    ive, kve: Asymptotic approximations to the modified bessel functions of the
      first and second kind.
  """
  with tf.name_scope(name or 'olver_asymptotic_uniform'):
    v_abs = tf.math.abs(v)
    w = z / v_abs
    t = tf.math.reciprocal(_sqrt1px2(w))
    n_ufactors = len(_ASYMPTOTIC_OLVER_EXPANSION_COEFFICIENTS)

    divisor = v_abs
    ive_sum = 1.
    kve_sum = 1.

    # Note the polynomials have odd/even symmetry that could be exploited
    # when evaluating them. For simplicity, we naively evaluate each one
    # with Horner's method.
    for i in range(n_ufactors):
      coeff = 0.
      for c in _ASYMPTOTIC_OLVER_EXPANSION_COEFFICIENTS[i]:
        coeff = coeff * t + c
      term = coeff / divisor
      ive_sum = ive_sum + term
      kve_sum = kve_sum + (term if i % 2 == 1 else -term)
      divisor = divisor * v_abs

    # This is modified from the original impl to be more numerically stable
    # since we are subtracting off x.
    shared_prefactor = (tf.math.reciprocal(_sqrt1px2(w) + w) + tf.math.log(w)
                        - tf.math.log1p(tf.math.reciprocal(t)))
    log_i_prefactor = 0.5 * tf.math.log(
        t / (2 * np.pi * v_abs)) + v_abs * shared_prefactor

    # The same rewrite is not needed here since the terms share the same sign.
    log_k_prefactor = 0.5 * tf.math.log(
        np.pi * t / (2 * v_abs)) - v_abs * shared_prefactor

    log_kve = log_k_prefactor + tf.math.log(kve_sum)
    log_ive = log_i_prefactor + tf.math.log(ive_sum)

    # We need to add a correction term for negative v.
    negative_v_correction = log_kve - 2. * z
    n = tf.math.round(v)
    u = v - n
    coeff = 2 / np.pi * tf.math.sin(np.pi * u)
    coeff = (1. - 2. * tf.math.mod(n, 2.)) * coeff

    lse, sign = tfp_math.log_sub_exp(
        log_ive,
        negative_v_correction + tf.math.log(tf.math.abs(coeff)),
        return_sign=True)
    sign = tf.where(coeff < 0., sign, 1.)

    log_ive_negative_v = tf.where(
        coeff < 0.,
        lse,
        tfp_math.log_add_exp(
            log_ive, negative_v_correction + tf.math.log(tf.math.abs(coeff))))

    if output_log_space:
      log_ive = tf.where(v >= 0., log_ive, log_ive_negative_v)
      return log_ive, log_kve

    ive = tf.where(
        v >= 0.,
        tf.math.exp(log_ive),
        sign * tf.math.exp(log_ive_negative_v))
    return ive, tf.math.exp(log_kve)
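
The negative-v correction implements the reflection formula I_{-v}(x) = I_v(x) + (2/π) sin(πv) K_v(x) (DLMF 10.27.2); writing v = n + u gives sin(πv) = (-1)^n sin(πu), which is where the `(1. - 2. * tf.math.mod(n, 2.))` factor comes from, and in exponentially scaled form the K term picks up exp(-2x) relative to ive. A SciPy spot-check:

import numpy as np
from scipy import special

v, x = 2.3, 5.0
lhs = special.ive(-v, x)
rhs = (special.ive(v, x)
       + 2. / np.pi * np.sin(np.pi * v) * special.kve(v, x)
       * np.exp(-2. * x))
np.testing.assert_allclose(lhs, rhs, rtol=1e-8)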
Example #13
    def _loop_build_sub_tree(self, directions, integrator, log_slice_sample,
                             init_energy, iter_, energy_diff_sum_previous,
                             leapfrogs_taken, prev_tree_state,
                             candidate_tree_state, continue_tree_previous,
                             not_divergent_previous, momentum_state_memory):
        """Base case in tree doubling."""
        with tf.name_scope('loop_build_sub_tree'):
            # Take one leapfrog step in the direction v and check divergence
            [
                next_momentum_parts, next_state_parts, next_target,
                next_target_grad_parts
            ] = integrator(prev_tree_state.momentum, prev_tree_state.state,
                           prev_tree_state.target,
                           prev_tree_state.target_grad_parts)

            next_tree_state = TreeDoublingState(
                momentum=next_momentum_parts,
                state=next_state_parts,
                target=next_target,
                target_grad_parts=next_target_grad_parts)
            # If the tree has not terminated previously, count this leapfrog.
            leapfrogs_taken = tf.where(continue_tree_previous,
                                       leapfrogs_taken + 1, leapfrogs_taken)

            # Save state and momentum at odd steps and check the U turn at
            # even steps. Note that we also write to a placeholder slot at
            # even steps to avoid using tf.cond.
            index = iter_ // 2
            if USE_RAGGED_TENSOR:
                write_index_ = self.write_instruction[index]
            else:
                write_index_ = tf.switch_case(index, self.write_instruction)

            write_index = tf.where(tf.equal(iter_ % 2, 0), write_index_,
                                   self.max_tree_depth)

            if USE_TENSORARRAY:
                momentum_state_memory = MomentumStateSwap(
                    momentum_swap=[
                        old.write(write_index, new) for old, new in zip(
                            momentum_state_memory.momentum_swap,
                            next_momentum_parts)
                    ],
                    state_swap=[
                        old.write(write_index, new) for old, new in zip(
                            momentum_state_memory.state_swap, next_state_parts)
                    ])
            else:
                momentum_state_memory = MomentumStateSwap(
                    momentum_swap=[
                        tf.tensor_scatter_nd_update(old, [[write_index]],
                                                    [new])
                        for old, new in zip(
                            momentum_state_memory.momentum_swap,
                            next_momentum_parts)
                    ],
                    state_swap=[
                        tf.tensor_scatter_nd_update(old, [[write_index]],
                                                    [new]) for old, new in
                        zip(momentum_state_memory.state_swap, next_state_parts)
                    ])
            batch_size = prefer_static.size(next_target)
            has_not_u_turn_at_even_step = tf.ones([batch_size], dtype=tf.bool)

            if USE_RAGGED_TENSOR:
                no_u_turns_within_tree = tf.cond(
                    tf.equal(iter_ % 2, 0),
                    lambda: has_not_u_turn_at_even_step,
                    lambda: has_not_u_turn_at_odd_step(  # pylint: disable=g-long-lambda
                        self.read_instruction, iter_ // 2, directions,
                        momentum_state_memory, next_momentum_parts,
                        next_state_parts))
            else:
                f = lambda int_iter: has_not_u_turn_at_odd_step(  # pylint: disable=g-long-lambda
                    self.read_instruction, int_iter, directions,
                    momentum_state_memory, next_momentum_parts,
                    next_state_parts)
                branch_execution = {
                    x: functools.partial(f, x)
                    for x in range(len(self.read_instruction))
                }
                no_u_turns_within_tree = tf.cond(
                    tf.equal(iter_ % 2, 0),
                    lambda: has_not_u_turn_at_even_step,
                    lambda: tf.switch_case(iter_ // 2, branch_execution))

            energy = compute_hamiltonian(next_target, next_momentum_parts)
            energy = tf.where(tf.math.is_nan(energy),
                              tf.constant(-np.inf, dtype=energy.dtype), energy)
            energy_diff = energy - init_energy

            if MULTINOMIAL_SAMPLE:
                not_divergent = -energy_diff < self.max_energy_diff
                weight_sum = log_add_exp(candidate_tree_state.weight,
                                         energy_diff)
                log_accept_thresh = energy_diff - weight_sum
            else:
                not_divergent = log_slice_sample - energy_diff < self.max_energy_diff
                # Uniform sampling on the trajectory within the subtree across valid
                # samples.
                is_valid = log_slice_sample <= energy_diff
                weight_sum = tf.where(is_valid,
                                      candidate_tree_state.weight + 1,
                                      candidate_tree_state.weight)
                log_accept_thresh = tf.where(
                    is_valid,
                    -tf.math.log(tf.cast(weight_sum, dtype=tf.float32)),
                    tf.constant(-np.inf, dtype=tf.float32))
            u = tf.math.log1p(-tf.random.uniform(shape=[batch_size],
                                                 dtype=log_accept_thresh.dtype,
                                                 seed=self._seed_stream()))
            is_sample_accepted = u <= log_accept_thresh

            next_candidate_tree_state = TreeDoublingStateCandidate(
                state=[
                    tf.where(  # pylint: disable=g-complex-comprehension
                        _expand_dims_under_batch_dim(is_sample_accepted,
                                                     prefer_static.rank(s0)),
                        s0, s1) for s0, s1 in zip(next_state_parts,
                                                  candidate_tree_state.state)
                ],
                target=tf.where(is_sample_accepted, next_target,
                                candidate_tree_state.target),
                target_grad_parts=[
                    tf.where(  # pylint: disable=g-complex-comprehension
                        _expand_dims_under_batch_dim(
                            is_sample_accepted, prefer_static.rank(grad0)),
                        grad0, grad1) for grad0, grad1 in zip(
                            next_target_grad_parts,
                            candidate_tree_state.target_grad_parts)
                ],
                weight=weight_sum)

            continue_tree = not_divergent & continue_tree_previous
            continue_tree_next = no_u_turns_within_tree & continue_tree

            not_divergent_tokeep = tf.where(
                continue_tree_previous, not_divergent,
                tf.ones([batch_size], dtype=tf.bool))

            # min(1., exp(energy_diff)).
            exp_energy_diff = tf.clip_by_value(tf.exp(energy_diff), 0., 1.)
            energy_diff_sum = tf.where(
                continue_tree, energy_diff_sum_previous + exp_energy_diff,
                energy_diff_sum_previous)

            return (
                iter_ + 1,
                energy_diff_sum,
                leapfrogs_taken,
                next_tree_state,
                next_candidate_tree_state,
                continue_tree_next,
                not_divergent_previous & not_divergent_tokeep,
                momentum_state_memory,
            )
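
The `branch_execution` table exists because `tf.switch_case` dispatches on a tensor index against a static collection of callables, so each branch must be bound to its Python integer up front (here via `functools.partial`). A minimal standalone demo of that dispatch:

import functools
import tensorflow as tf

def branch(k):
    return tf.constant(k * 10)

branch_fns = {k: functools.partial(branch, k) for k in range(4)}
index = tf.constant(2)
print(tf.switch_case(index, branch_fns))  # tf.Tensor(20, ...)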
Example #14
    def loop_tree_doubling(self, step_size, log_slice_sample, init_energy,
                           momentum_state_memory, iter_, initial_step_state,
                           initial_step_metastate):
        """Main loop for tree doubling."""
        with tf.name_scope('loop_tree_doubling'):
            batch_size = prefer_static.size(init_energy)
            direction = tf.cast(tf.random.uniform(shape=[batch_size],
                                                  minval=0,
                                                  maxval=2,
                                                  dtype=tf.int32,
                                                  seed=self._seed_stream()),
                                dtype=tf.bool)

            left_right_index = tf.concat([
                tf.cast(direction, tf.int32)[..., tf.newaxis],
                tf.range(batch_size, dtype=tf.int32)[..., tf.newaxis]
            ],
                                         axis=1)
            tree_start_states = tf.nest.map_structure(
                # Alternatively: `lambda v: tf.where(direction, v[1], v[0])`
                lambda v: tf.gather_nd(v, left_right_index),
                initial_step_state)

            directions_expanded = [
                _expand_dims_under_batch_dim(direction,
                                             prefer_static.rank(state))
                for state in tree_start_states.state
            ]
            integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
                self.target_log_prob_fn,
                step_sizes=[
                    tf.where(direction, ss, -ss)
                    for direction, ss in zip(directions_expanded, step_size)
                ],
                num_steps=self.unrolled_leapfrog_steps)

            [
                candidate_tree_state,
                tree_final_states,
                final_not_divergence,
                continue_tree_final,
                energy_diff_tree_sum,
                leapfrogs_taken,
            ] = self._build_sub_tree(
                directions_expanded,
                integrator,
                log_slice_sample,
                init_energy,
                # num_steps_at_this_depth = 2**iter_ = 1 << iter_
                tf.bitwise.left_shift(1, iter_),
                tree_start_states,
                initial_step_metastate.continue_tree,
                initial_step_metastate.not_divergence,
                momentum_state_memory)

            last_candidate_state = initial_step_metastate.candidate_state
            tree_weight = candidate_tree_state.weight
            if MULTINOMIAL_SAMPLE:
                weight_sum = log_add_exp(tree_weight,
                                         last_candidate_state.weight)
                log_accept_thresh = tree_weight - last_candidate_state.weight
            else:
                weight_sum = tree_weight + last_candidate_state.weight
                log_accept_thresh = tf.math.log(
                    tf.cast(tree_weight, tf.float32) /
                    tf.cast(last_candidate_state.weight, tf.float32))
            log_accept_thresh = tf.where(tf.math.is_nan(log_accept_thresh),
                                         tf.zeros([], log_accept_thresh.dtype),
                                         log_accept_thresh)
            u = tf.math.log1p(-tf.random.uniform(shape=[batch_size],
                                                 dtype=log_accept_thresh.dtype,
                                                 seed=self._seed_stream()))
            is_sample_accepted = u <= log_accept_thresh

            choose_new_state = is_sample_accepted & continue_tree_final

            new_candidate_state = TreeDoublingStateCandidate(
                state=[
                    tf.where(  # pylint: disable=g-complex-comprehension
                        _expand_dims_under_batch_dim(choose_new_state,
                                                     prefer_static.rank(s0)),
                        s0, s1) for s0, s1 in zip(candidate_tree_state.state,
                                                  last_candidate_state.state)
                ],
                target=tf.where(choose_new_state, candidate_tree_state.target,
                                last_candidate_state.target),
                target_grad_parts=[
                    tf.where(  # pylint: disable=g-complex-comprehension
                        _expand_dims_under_batch_dim(
                            choose_new_state, prefer_static.rank(grad0)),
                        grad0, grad1) for grad0, grad1 in zip(
                            candidate_tree_state.target_grad_parts,
                            last_candidate_state.target_grad_parts)
                ],
                weight=weight_sum)
            # Update the left/right information of the trajectory, and check
            # for a trajectory-level U-turn.

            # Alternative approach
            # left_right_mask = tf.transpose(
            #     tf.tile(tf.one_hot(tf.cast(direction, tf.int32), 2),
            #            [1, initial_step_metastate.candidate_state[0].shape[-1], 1]),
            #     [2, 0, 1])

            # trajectory_state_left_right = tf.where(
            #     tf.equal(left_right_mask, 0.),
            #     trajectory_state_left_right,
            #     tf.tile(tree_final_states[1][0][tf.newaxis, ...], [2, 1, 1]))
            new_step_state = tf.nest.pack_sequence_as(
                initial_step_state,
                [
                    # Alternative approach:
                    # tf.where(tf.equal(left_right_mask, 0.),
                    #          v,
                    #          tf.tile(r[tf.newaxis],
                    #                  tf.concat([[2], tf.ones_like(tf.shape(r))], 0)))
                    tf.tensor_scatter_nd_update(v, left_right_index, r)
                    for v, r in zip(tf.nest.flatten(initial_step_state),
                                    tf.nest.flatten(tree_final_states))
                ])
            no_u_turns_trajectory = has_not_u_turn(
                [s[0] for s in new_step_state.state],
                [m[0] for m in new_step_state.momentum],
                [s[1] for s in new_step_state.state],
                [m[1] for m in new_step_state.momentum])

            new_step_metastate = TreeDoublingMetaState(
                candidate_state=new_candidate_state,
                is_accepted=choose_new_state
                | initial_step_metastate.is_accepted,
                energy_diff_sum=(energy_diff_tree_sum +
                                 initial_step_metastate.energy_diff_sum),
                continue_tree=continue_tree_final & no_u_turns_trajectory,
                not_divergence=final_not_divergence,
                leapfrog_count=(initial_step_metastate.leapfrog_count +
                                leapfrogs_taken))

            return iter_ + 1, new_step_state, new_step_metastate
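
The `left_right_index` trick pairs each batch member with the trajectory side chosen by `direction`, so a single `tf.gather_nd` reads the chosen endpoint and a single `tf.tensor_scatter_nd_update` writes it back. A small self-contained demo of that indexing:

import tensorflow as tf

batch = 3
v = tf.constant([[1., 2., 3.],       # left endpoints  (index 0)
                 [10., 20., 30.]])   # right endpoints (index 1)
direction = tf.constant([True, False, True])

left_right_index = tf.concat(
    [tf.cast(direction, tf.int32)[..., tf.newaxis],
     tf.range(batch, dtype=tf.int32)[..., tf.newaxis]], axis=1)

picked = tf.gather_nd(v, left_right_index)  # [10., 2., 30.]
updated = tf.tensor_scatter_nd_update(v, left_right_index, picked + 100.)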