def _log_bessel_kve_jvp(primals, tangents):
  """Forward-mode (JVP) rule for the log of `bessel_kve`.

  Only the tangent with respect to `z` is propagated; the tangent supplied
  for `v` is ignored.

  Args:
    primals: Pair `(v, z)` of primal inputs.
    tangents: Pair of tangents matching `primals`; only the second entry
      (the tangent of `z`) is used.

  Returns:
    A pair `(log_kve, tangent_out)`. `log_kve` is
    `_log_bessel_kve_custom_gradient(v, z)` and `tangent_out` is its
    derivative with respect to `z` times the (broadcast) `z` tangent.
  """
  v, z = primals
  _, z_dot = tangents

  float_dtype = dtype_util.common_dtype([v, z], tf.float32)
  log_two = dtype_util.as_numpy_dtype(float_dtype)(np.log(2.))

  # TODO(https://github.com/google/jax/issues/3768): eliminate broadcast_to?
  z_dot = tf.broadcast_to(
      z_dot, ps.broadcast_shape(ps.shape(v), ps.shape(z_dot)))

  log_kve = _log_bessel_kve_custom_gradient(v, z)
  # log((kve(v - 1) + kve(v + 1)) / (2 * kve(v))).
  log_ratio = tfp_math.log_add_exp(
      _log_bessel_kve_custom_gradient(v - 1., z),
      _log_bessel_kve_custom_gradient(v + 1., z)) - log_two - log_kve
  # Using K'_v = -(K_{v-1} + K_{v+1}) / 2 and the exp(z) scaling, the
  # derivative of log(kve) in z is 1 - exp(log_ratio); expm1 keeps precision.
  dlog_kve_dz = -tf.math.expm1(log_ratio)

  # `bessel_kve` has no gradient with respect to `v`, so this JVP rule
  # matches TF. Raising an error for `v` tangents in JAX mode is not possible
  # with `custom_jvp`; see https://github.com/google/jax/issues/5913.
  # TODO(https://github.com/google/jax/issues/5913): Define vjp for v.
  return log_kve, dlog_kve_dz * z_dot
def mutate_onestep(i, state, pkr, log_accept_prob_sum):
  """Advance `kernel` one step and accumulate the clipped log acceptance.

  Closes over `kernel`, `gather_mh_like_result` and `log_add_exp` from the
  enclosing scope; presumably used as a `tf.while_loop` body.

  Args:
    i: Loop counter.
    state: Current state of the particles.
    pkr: Kernel results fed into this step.
    log_accept_prob_sum: Running log-sum-exp of clipped log acceptance
      probabilities.

  Returns:
    Tuple `(i + 1, next_state, next_kernel_results, log_accept_prob_sum)`.
  """
  stepped_state, stepped_results = kernel.one_step(state, pkr)
  # NOTE(review): the ratio is gathered from `pkr` (the results passed in),
  # not from `stepped_results`; confirm this one-step lag is intended.
  log_ratio, _ = gather_mh_like_result(pkr)
  # min(0, log ratio) == log(min(1, ratio)), the MH acceptance probability.
  updated_sum = log_add_exp(log_accept_prob_sum, tf.minimum(log_ratio, 0.))
  return i + 1, stepped_state, stepped_results, updated_sum
def simple_heuristic_tuning(num_steps,
                            log_scalings,
                            log_accept_prob,
                            optimal_accept=0.234,
                            target_accept_prob=0.99,
                            name=None):
  """Tune the number of steps and scaling of one mutation.

  This is a simple heuristic. It uses the acceptance probability from the
  previous mutation stage in SMC to tune two things: the number of steps of
  the next mutation, and the scaling of a transition kernel (e.g., step size
  in HMC, scale of a Normal proposal in RWMH).

  Args:
    num_steps: The initial number of steps for the next mutation, to be tuned.
      It is also multiplied by the number of replicas to count the proposals
      made when flooring the average acceptance probability.
    log_scalings: The log of the scale of the proposal kernel.
    log_accept_prob: The log of the acceptance ratio from the last mutation.
    optimal_accept: Optimal acceptance ratio for a Transitional Kernel.
      Default value is 0.234 (Optimal for Random Walk Metropolis kernel).
    target_accept_prob: Target acceptance probability at the end of one
      mutation step.
      Default value: 0.99
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None`.

  Returns:
    num_steps: The number of steps for the next mutation.
    new_log_scalings: The log of the scale of the proposal kernel for the next
      mutation.
  """
  # TODO(b/152412213): Better explanation of the heuristic used here.
  with tf.name_scope(name or 'simple_heuristic_tuning'):
    accept_dtype = log_accept_prob.dtype
    optimal_accept = tf.constant(optimal_accept, dtype=accept_dtype)
    target_accept_prob = tf.constant(target_accept_prob, dtype=accept_dtype)
    log_half = tf.constant(np.log(.5), dtype=log_scalings.dtype)

    mean_log_accept = reduce_logmeanexp(log_accept_prob)

    # Two candidate log-scalings, each shifted by its acceptance excess over
    # `optimal_accept`: one pooled over replicas, one per replica. The new
    # scaling is their arithmetic mean in linear space.
    pooled_target = reduce_logmeanexp(log_scalings) + (
        tf.exp(mean_log_accept) - optimal_accept)
    per_replica_target = log_scalings + (
        tf.exp(log_accept_prob) - optimal_accept)
    new_log_scalings = log_half + log_add_exp(pooled_target,
                                              per_replica_target)

    proposals = tf.cast(ps.size0(log_accept_prob) * num_steps,
                        dtype=mean_log_accept.dtype)
    # Floor the averaged acceptance probability at 1 / (total proposals).
    floored_log_accept = tf.math.maximum(-tf.math.log(proposals),
                                         mean_log_accept)
    # Pick k with (1 - p)**k == 1 - target_accept_prob, i.e. the chance of at
    # least one acceptance in k proposals equals the target (assumes
    # `log1mexp(x)` computes log(1 - exp(x)) for x <= 0).
    new_num_steps = tf.cast(
        tf.math.log1p(-target_accept_prob) / log1mexp(floored_log_accept),
        dtype=num_steps.dtype)
    return new_num_steps, new_log_scalings
def bessel_recurrence(index, kve, kvep1):
  """One step of the forward recurrence K_{m+1} = (2 m / x) K_m + K_{m-1}.

  Closes over `u`, `n`, `x_abs`, `numpy_dtype` and `output_log_space` from
  the enclosing scope, with `m = u + index`. Entries whose `index` already
  exceeds `n` are carried through unchanged. When `output_log_space` is set,
  `kve`/`kvep1` hold logarithms and the update is done with log-add-exp.

  Returns:
    Tuple `(index + 1., kve, kvep1)` of the advanced loop variables.
  """
  if output_log_space:
    following = tfp_math.log_add_exp(
        kvep1 + tf.math.log(u + index) + numpy_dtype(np.log(2.)) -
        tf.math.log(x_abs),
        kve)
  else:
    following = 2 * (u + index) * kvep1 / x_abs + kve
  finished = index > n
  return (index + 1.,
          tf.where(finished, kve, kvep1),
          tf.where(finished, kvep1, following))
def _hyp2f1_z_near_one(a, b, c, z):
  """Compute 2F1(a, b, c, z) when z is near 1.

  Uses the linear transformation z -> 1 - z (Abramowitz & Stegun 15.3.6):

    2F1(a, b; c; z) =
        G(c) G(d) / (G(c - a) G(c - b)) * 2F1(a, b; 1 - d; 1 - z)
      + (1 - z)**d G(c) G(-d) / (G(a) G(b)) * 2F1(c - a, c - b; d + 1; 1 - z)

  where `d = c - a - b` and `G` is the Gamma function. Each term is formed in
  log space with its sign tracked separately, and the two are then combined.
  Note that `lgamma(d)` / `lgamma(-d)` are infinite when `d` is an integer.

  Args:
    a: First numerator parameter.
    b: Second numerator parameter.
    c: Denominator parameter.
    z: Argument, expected to be close to 1.

  Returns:
    Tensor of `2F1(a, b; c; z)` values in the common dtype of the inputs.
  """
  with tf.name_scope('hyp2f1_z_near_one'):
    dtype = dtype_util.common_dtype([a, b, c, z], tf.float32)
    a = tf.convert_to_tensor(a, dtype=dtype)
    b = tf.convert_to_tensor(b, dtype=dtype)
    c = tf.convert_to_tensor(c, dtype=dtype)
    z = tf.convert_to_tensor(z, dtype=dtype)

    # When z > 0.5, we can transform z to 1 - z and make use of a
    # hypergeometric identity.

    d = c - a - b

    # TODO(b/171982819): When tfp.math.log_gamma_difference and tfp.math.lbeta
    # support negative parameters, use them here for greater accuracy.
    log_first_coefficient = (tf.math.lgamma(c) + tf.math.lgamma(d) -
                             tf.math.lgamma(c - a) - tf.math.lgamma(c - b))

    # lgamma only gives log|Gamma|; the sign is the parity of the number of
    # negative Gamma factors (assumes `_gamma_negative(x)` is True where
    # Gamma(x) < 0). The bool parity is mapped to +1 / -1.
    sign_first_coefficient = (
        _gamma_negative(c) ^ _gamma_negative(d) ^
        _gamma_negative(c - a) ^ _gamma_negative(c - b))
    sign_first_coefficient = -2. * tf.cast(sign_first_coefficient, dtype) + 1.

    # xlog1py(d, -z) == d * log(1 - z), i.e. log((1 - z)**d).
    log_second_coefficient = (
        tf.math.xlog1py(d, -z) +
        tf.math.lgamma(c) + tf.math.lgamma(-d) -
        tf.math.lgamma(a) - tf.math.lgamma(b))

    sign_second_coefficient = (
        _gamma_negative(c) ^ _gamma_negative(a) ^ _gamma_negative(b) ^
        _gamma_negative(-d))
    sign_second_coefficient = -2. * tf.cast(sign_second_coefficient, dtype) + 1.

    first_term = _hyp2f1_internal(a, b, 1 - d, 1 - z)
    second_term = _hyp2f1_internal(c - a, c - b, d + 1., 1 - z)

    log_first_term = log_first_coefficient + tf.math.log(
        tf.math.abs(first_term))
    log_second_term = log_second_coefficient + tf.math.log(
        tf.math.abs(second_term))

    sign_first_term = sign_first_coefficient * tf.math.sign(first_term)
    sign_second_term = sign_second_coefficient * tf.math.sign(second_term)

    # Same signs: magnitudes add. Opposite signs: magnitudes subtract, and the
    # result takes the first term's sign flipped by the sign of the difference.
    same_sign = tf.math.equal(sign_first_term, sign_second_term)
    log_diff, sign_log_diff = tfp_math.log_sub_exp(
        log_first_term, log_second_term, return_sign=True)
    sign = tf.where(
        same_sign,
        sign_first_term,
        sign_first_term * sign_log_diff)
    log_result = tf.where(
        same_sign,
        tfp_math.log_add_exp(log_first_term, log_second_term),
        log_diff)

    return tf.math.exp(log_result) * sign
def mutate_onestep(i, seed, state, pkr, log_accept_prob_sum):
  """Seeded mutation step that also accumulates the clipped log acceptance.

  Closes over `kernel`, `is_seeded`, `samplers`, `gather_mh_like_result` and
  `log_add_exp` from the enclosing scope; presumably used as a
  `tf.while_loop` body.

  Args:
    i: Loop counter.
    seed: Seed carried across iterations (passed through unchanged when
      `is_seeded` is False).
    state: Current state of the particles.
    pkr: Kernel results fed into this step.
    log_accept_prob_sum: Running log-sum-exp of clipped log acceptance
      probabilities.

  Returns:
    List `[i + 1, next_seed, next_state, next_kernel_results,
    log_accept_prob_sum]`.
  """
  if is_seeded:
    # One half seeds this step; the other is carried to the next iteration.
    step_seed, carried_seed = samplers.split_seed(seed)
    step_kwargs = dict(seed=step_seed)
  else:
    carried_seed = seed
    step_kwargs = {}
  stepped_state, stepped_results = kernel.one_step(state, pkr, **step_kwargs)
  # NOTE(review): the ratio is gathered from `pkr` (the results passed in),
  # not from `stepped_results`; confirm this one-step lag is intended.
  log_ratio, _ = gather_mh_like_result(pkr)
  updated_sum = log_add_exp(log_accept_prob_sum, tf.minimum(log_ratio, 0.))
  return [i + 1, carried_seed, stepped_state, stepped_results, updated_sum]
def _log_soosum_exp_impl(logx, axis, keepdims, compute_mean):
  """Implementation for `*soosum*` functions.

  The swap-one-out sum ('soosum') is a collection of n sums. The i-th sum
  replaces item i with the geometric mean of the other n - 1 items:

    soosum_x[i] = sum_j exp(logx[j]) - exp(logx[i])
                  + exp(mean_{j != i} logx[j])
                = exp(log_loosum_x[i]) + exp(loo_log_swap_in[i])

  Args:
    logx: Log-values to combine.
    axis: Axis (or axes) being reduced.
    keepdims: Forwarded to `_log_loosum_exp_impl`.
    compute_mean: If `True`, divide both results by the item count `n`.

  Returns:
    Tuple `(log_soosum_x, log_sum_x)`, or their means when `compute_mean`.
  """
  with tf.name_scope('log_soosum_exp_impl'):
    logx = tf.convert_to_tensor(logx, name='logx')
    log_loo_sums, log_total, count = _log_loosum_exp_impl(
        logx, axis, keepdims, compute_mean=False)
    # For each i, the mean of logx over the other n - 1 entries.
    summed_logx = tf.reduce_sum(logx, axis=axis, keepdims=True)
    log_swap_in = (summed_logx - logx) / (count - 1.)
    log_swapped_sums = log_add_exp(log_loo_sums, log_swap_in)
    if compute_mean:
      log_count = prefer_static.log(count)
      return log_swapped_sums - log_count, log_total - log_count
    return log_swapped_sums, log_total
def _log_bessel_kve_bwd(aux, g):
  """Reverse-mode (VJP) rule for the log of `bessel_kve`.

  Args:
    aux: Residuals `(v, z)` saved from the forward pass.
    g: Upstream cotangent.

  Returns:
    `(None, grad_z)`; no gradient is produced for `v`.
  """
  v, z = aux
  log_two = dtype_util.as_numpy_dtype(
      dtype_util.common_dtype([v, z], tf.float32))(np.log(2.))

  log_kve = _log_bessel_kve_custom_gradient(v, z)
  # log((kve(v - 1) + kve(v + 1)) / (2 * kve(v))).
  log_ratio = tfp_math.log_add_exp(
      _log_bessel_kve_custom_gradient(v - 1., z),
      _log_bessel_kve_custom_gradient(v + 1., z)) - log_two - log_kve
  # d/dz log(kve) = 1 - exp(log_ratio), computed with expm1 for precision.
  scaled = g * -tf.math.expm1(log_ratio)
  # Presumably reduces `scaled` back to the shape of `z` when `v` and `z`
  # were broadcast against each other.
  _, grad_z = _fix_gradient_for_broadcasting(
      v, z, tf.ones_like(scaled), scaled)

  # No gradient for `v` for now: the derivative with respect to the order has
  # no simple closed form, and getting good numerics needs more work.
  # TODO(b/169357627): Implement gradients of modified bessel functions with
  # respect to parameters.
  return None, grad_z
def _loop_build_sub_tree(self,
                         directions,
                         integrator,
                         current_step_meta_info,
                         iter_,
                         energy_diff_sum_previous,
                         momentum_cumsum_previous,
                         leapfrogs_taken,
                         prev_tree_state,
                         candidate_tree_state,
                         continue_tree_previous,
                         not_divergent_previous,
                         momentum_state_memory):
  """Base case in tree doubling.

  Takes one leapfrog step from `prev_tree_state` and records the new momentum
  and state in `momentum_state_memory`. It then checks for U-turns against
  earlier entries of that memory, samples the subtree candidate, and updates
  the continuation and divergence flags.

  The returned tuple is `(iter_ + 1, energy_diff_sum, momentum_cumsum,
  leapfrogs_taken, next_tree_state, next_candidate_tree_state,
  continue_tree_next, not_divergent, momentum_state_memory)`. It mirrors the
  loop-variable arguments, presumably so this method can serve as a
  `tf.while_loop` body (TODO confirm against the caller).
  """
  with tf.name_scope('loop_build_sub_tree'):
    # Take one leapfrog step in the current direction and check divergence.
    [
        next_momentum_parts,
        next_state_parts,
        next_target,
        next_target_grad_parts
    ] = integrator(prev_tree_state.momentum,
                   prev_tree_state.state,
                   prev_tree_state.target,
                   prev_tree_state.target_grad_parts)

    next_tree_state = TreeDoublingState(
        momentum=next_momentum_parts,
        state=next_state_parts,
        target=next_target,
        target_grad_parts=next_target_grad_parts)
    # Running sum of momenta along the subtree (used by the generalized
    # U-turn criterion).
    momentum_cumsum = [p0 + p1 for p0, p1 in zip(momentum_cumsum_previous,
                                                 next_momentum_parts)]
    # If the tree has not yet terminated previously, we count this leapfrog.
    leapfrogs_taken = tf.where(
        continue_tree_previous, leapfrogs_taken + 1, leapfrogs_taken)

    write_instruction = current_step_meta_info.write_instruction
    read_instruction = current_step_meta_info.read_instruction
    init_energy = current_step_meta_info.init_energy

    # With the generalized U-turn criterion, momentum cumulative sums are
    # stored/checked; otherwise raw positions are.
    if GENERALIZED_UTURN:
      state_to_write = momentum_cumsum_previous
      state_to_check = momentum_cumsum
    else:
      state_to_write = next_state_parts
      state_to_check = next_state_parts

    batch_shape = prefer_static.shape(next_target)
    has_not_u_turn_init = prefer_static.ones(
        batch_shape, dtype=tf.bool)

    # Check the new step against the memory slots listed for this iteration.
    read_index = read_instruction.gather([iter_])[0]
    no_u_turns_within_tree = has_not_u_turn_at_all_index(  # pylint: disable=g-long-lambda
        read_index,
        directions,
        momentum_state_memory,
        next_momentum_parts,
        state_to_check,
        has_not_u_turn_init,
        log_prob_rank=prefer_static.rank(next_target))

    # Get index to write state into memory swap
    write_index = write_instruction.gather([iter_])
    momentum_state_memory = MomentumStateSwap(
        momentum_swap=[
            tf.tensor_scatter_nd_update(old, [write_index], [new])
            for old, new in zip(momentum_state_memory.momentum_swap,
                                next_momentum_parts)
        ],
        state_swap=[
            tf.tensor_scatter_nd_update(old, [write_index], [new])
            for old, new in zip(momentum_state_memory.state_swap,
                                state_to_write)
        ])

    energy = compute_hamiltonian(next_target, next_momentum_parts)
    # NaN energies are replaced by -inf so the step gets zero weight and is
    # flagged as divergent below.
    current_energy = tf.where(tf.math.is_nan(energy),
                              tf.constant(-np.inf, dtype=energy.dtype),
                              energy)
    energy_diff = current_energy - init_energy

    if MULTINOMIAL_SAMPLE:
      not_divergent = -energy_diff < self.max_energy_diff
      # Weights are kept in log space; accept the new point with probability
      # exp(energy_diff) / (running weight).
      weight_sum = log_add_exp(candidate_tree_state.weight, energy_diff)
      log_accept_thresh = energy_diff - weight_sum
    else:
      log_slice_sample = current_step_meta_info.log_slice_sample
      not_divergent = log_slice_sample - energy_diff < self.max_energy_diff
      # Uniform sampling on the trajectory within the subtree across valid
      # samples.
      is_valid = log_slice_sample <= energy_diff
      weight_sum = tf.where(is_valid,
                            candidate_tree_state.weight + 1,
                            candidate_tree_state.weight)
      # NOTE(review): the threshold is hard-coded to float32 regardless of the
      # target dtype; confirm this is intended for float64 models.
      log_accept_thresh = tf.where(
          is_valid,
          -tf.math.log(tf.cast(weight_sum, dtype=tf.float32)),
          tf.constant(-np.inf, dtype=tf.float32))
    # log(1 - U) with U ~ Uniform[0, 1) is a log-uniform draw.
    u = tf.math.log1p(-tf.random.uniform(
        shape=batch_shape,
        dtype=log_accept_thresh.dtype,
        seed=self._seed_stream()))
    is_sample_accepted = u <= log_accept_thresh

    next_candidate_tree_state = TreeDoublingStateCandidate(
        state=[
            tf.where(  # pylint: disable=g-complex-comprehension
                _rightmost_expand_to_rank(
                    is_sample_accepted, prefer_static.rank(s0)),
                s0, s1)
            for s0, s1 in zip(next_state_parts, candidate_tree_state.state)
        ],
        target=tf.where(
            _rightmost_expand_to_rank(
                is_sample_accepted, prefer_static.rank(next_target)),
            next_target, candidate_tree_state.target),
        target_grad_parts=[
            tf.where(  # pylint: disable=g-complex-comprehension
                _rightmost_expand_to_rank(
                    is_sample_accepted, prefer_static.rank(grad0)),
                grad0, grad1)
            for grad0, grad1 in zip(next_target_grad_parts,
                                    candidate_tree_state.target_grad_parts)
        ],
        energy=tf.where(
            _rightmost_expand_to_rank(
                is_sample_accepted, prefer_static.rank(next_target)),
            current_energy, candidate_tree_state.energy),
        weight=weight_sum)

    continue_tree = not_divergent & continue_tree_previous
    continue_tree_next = no_u_turns_within_tree & continue_tree

    # Chains that had already stopped do not contribute to divergence.
    not_divergent_tokeep = tf.where(
        continue_tree_previous,
        not_divergent,
        prefer_static.ones(batch_shape, dtype=tf.bool))

    # min(1., exp(energy_diff)).
    exp_energy_diff = tf.math.exp(tf.minimum(energy_diff, 0.))
    energy_diff_sum = tf.where(continue_tree,
                               energy_diff_sum_previous + exp_energy_diff,
                               energy_diff_sum_previous)

    return (
        iter_ + 1,
        energy_diff_sum,
        momentum_cumsum,
        leapfrogs_taken,
        next_tree_state,
        next_candidate_tree_state,
        continue_tree_next,
        not_divergent_previous & not_divergent_tokeep,
        momentum_state_memory,
    )
def loop_tree_doubling(self, step_size, momentum_state_memory,
                       current_step_meta_info, iter_, initial_step_state,
                       initial_step_metastate):
  """Main loop for tree doubling.

  Draws a random direction per chain and builds a subtree of
  `2**iter_` leapfrog steps from the corresponding end of the trajectory
  (via `self._build_sub_tree`). It then possibly replaces the candidate
  sample with the subtree's candidate, extends the trajectory's two ends, and
  checks the trajectory-level U-turn criterion.

  Returns:
    Tuple `(iter_ + 1, new_step_state, new_step_metastate)`.
  """
  with tf.name_scope('loop_tree_doubling'):
    batch_shape = prefer_static.shape(current_step_meta_info.init_energy)
    # True means extend to the right (index 1), False to the left (index 0).
    direction = tf.cast(
        tf.random.uniform(
            shape=batch_shape,
            minval=0,
            maxval=2,
            dtype=tf.int32,
            seed=self._seed_stream()),
        dtype=tf.bool)

    # `initial_step_state` holds [left_end, right_end] stacked on axis 0;
    # start the new subtree from the end in the chosen direction.
    tree_start_states = tf.nest.map_structure(
        lambda v: tf.where(  # pylint: disable=g-long-lambda
            _rightmost_expand_to_rank(direction, prefer_static.rank(v[1])),
            v[1], v[0]),
        initial_step_state)

    directions_expanded = [
        _rightmost_expand_to_rank(direction, prefer_static.rank(state))
        for state in tree_start_states.state
    ]

    # Negative step sizes integrate backward in time for leftward chains.
    integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
        self.target_log_prob_fn,
        step_sizes=[
            tf.where(d, ss, -ss) for d, ss in zip(directions_expanded, step_size)
        ],
        num_steps=self.unrolled_leapfrog_steps)

    [
        candidate_tree_state,
        tree_final_states,
        final_not_divergence,
        continue_tree_final,
        energy_diff_tree_sum,
        momentum_subtree_cumsum,
        leapfrogs_taken
    ] = self._build_sub_tree(
        directions_expanded,
        integrator,
        current_step_meta_info,
        # num_steps_at_this_depth = 2**iter_ = 1 << iter_
        tf.bitwise.left_shift(1, iter_),
        tree_start_states,
        initial_step_metastate.continue_tree,
        initial_step_metastate.not_divergence,
        momentum_state_memory)

    last_candidate_state = initial_step_metastate.candidate_state

    energy_diff_sum = (
        energy_diff_tree_sum + initial_step_metastate.energy_diff_sum)
    # A terminated subtree contributes zero weight (-inf in log space).
    if MULTINOMIAL_SAMPLE:
      tree_weight = tf.where(
          continue_tree_final,
          candidate_tree_state.weight,
          tf.constant(-np.inf, dtype=candidate_tree_state.weight.dtype))
      weight_sum = log_add_exp(tree_weight, last_candidate_state.weight)
      log_accept_thresh = tree_weight - last_candidate_state.weight
    else:
      tree_weight = tf.where(
          continue_tree_final,
          candidate_tree_state.weight,
          tf.zeros([], dtype=TREE_COUNT_DTYPE))
      weight_sum = tree_weight + last_candidate_state.weight
      log_accept_thresh = tf.math.log(
          tf.cast(tree_weight, tf.float32) /
          tf.cast(last_candidate_state.weight, tf.float32))
    # 0/0 or -inf - (-inf) produce NaN; treat those as threshold 0 (always
    # accept). The final choice is still masked by `continue_tree_final`.
    log_accept_thresh = tf.where(
        tf.math.is_nan(log_accept_thresh),
        tf.zeros([], log_accept_thresh.dtype),
        log_accept_thresh)
    # log(1 - U) with U ~ Uniform[0, 1) is a log-uniform draw.
    u = tf.math.log1p(-tf.random.uniform(
        shape=batch_shape,
        dtype=log_accept_thresh.dtype,
        seed=self._seed_stream()))
    is_sample_accepted = u <= log_accept_thresh

    choose_new_state = is_sample_accepted & continue_tree_final

    new_candidate_state = TreeDoublingStateCandidate(
        state=[
            tf.where(  # pylint: disable=g-complex-comprehension
                _rightmost_expand_to_rank(
                    choose_new_state, prefer_static.rank(s0)),
                s0, s1)
            for s0, s1 in zip(candidate_tree_state.state,
                              last_candidate_state.state)
        ],
        target=tf.where(
            _rightmost_expand_to_rank(
                choose_new_state,
                prefer_static.rank(candidate_tree_state.target)),
            candidate_tree_state.target, last_candidate_state.target),
        target_grad_parts=[
            tf.where(  # pylint: disable=g-complex-comprehension
                _rightmost_expand_to_rank(
                    choose_new_state, prefer_static.rank(grad0)),
                grad0, grad1)
            for grad0, grad1 in zip(candidate_tree_state.target_grad_parts,
                                    last_candidate_state.target_grad_parts)
        ],
        energy=tf.where(
            _rightmost_expand_to_rank(
                choose_new_state,
                prefer_static.rank(candidate_tree_state.target)),
            candidate_tree_state.energy, last_candidate_state.energy),
        weight=weight_sum)

    # Restore static shapes that tf.where may lose (keeps loop vars
    # shape-invariant).
    for new_candidate_state_temp, old_candidate_state_temp in zip(
        new_candidate_state.state, last_candidate_state.state):
      tensorshape_util.set_shape(new_candidate_state_temp,
                                 old_candidate_state_temp.shape)

    for new_candidate_grad_temp, old_candidate_grad_temp in zip(
        new_candidate_state.target_grad_parts,
        last_candidate_state.target_grad_parts):
      tensorshape_util.set_shape(new_candidate_grad_temp,
                                 old_candidate_grad_temp.shape)

    # Update left right information of the trajectory, and check trajectory
    # level U turn
    # The end opposite to `direction` is unchanged by this doubling.
    tree_otherend_states = tf.nest.map_structure(
        lambda v: tf.where(  # pylint: disable=g-long-lambda
            _rightmost_expand_to_rank(direction, prefer_static.rank(v[1])),
            v[0], v[1]),
        initial_step_state)

    # Re-stack as [left_end, right_end]: for rightward chains the new subtree
    # end becomes index 1, for leftward chains index 0.
    new_step_state = tf.nest.pack_sequence_as(initial_step_state, [
        tf.stack([  # pylint: disable=g-complex-comprehension
            tf.where(
                _rightmost_expand_to_rank(direction, prefer_static.rank(l)),
                r, l),
            tf.where(
                _rightmost_expand_to_rank(direction, prefer_static.rank(l)),
                l, r),
        ], axis=0)
        for l, r in zip(tf.nest.flatten(tree_final_states),
                        tf.nest.flatten(tree_otherend_states))
    ])

    momentum_tree_cumsum = []
    for p0, p1 in zip(
        initial_step_metastate.momentum_sum, momentum_subtree_cumsum):
      momentum_part_temp = p0 + p1
      tensorshape_util.set_shape(momentum_part_temp, p0.shape)
      momentum_tree_cumsum.append(momentum_part_temp)

    for new_state_temp, old_state_temp in zip(
        tf.nest.flatten(new_step_state),
        tf.nest.flatten(initial_step_state)):
      tensorshape_util.set_shape(new_state_temp, old_state_temp.shape)

    # Generalized criterion uses the summed momentum; the classic one uses
    # the displacement between the two ends.
    if GENERALIZED_UTURN:
      state_diff = momentum_tree_cumsum
    else:
      state_diff = [s[1] - s[0] for s in new_step_state.state]

    no_u_turns_trajectory = has_not_u_turn(
        state_diff,
        [m[0] for m in new_step_state.momentum],
        [m[1] for m in new_step_state.momentum],
        log_prob_rank=prefer_static.rank_from_shape(batch_shape))

    new_step_metastate = TreeDoublingMetaState(
        candidate_state=new_candidate_state,
        is_accepted=choose_new_state | initial_step_metastate.is_accepted,
        momentum_sum=momentum_tree_cumsum,
        energy_diff_sum=energy_diff_sum,
        continue_tree=continue_tree_final & no_u_turns_trajectory,
        not_divergence=final_not_divergence,
        leapfrog_count=(initial_step_metastate.leapfrog_count +
                        leapfrogs_taken))

    return iter_ + 1, new_step_state, new_step_metastate
def _temme_expansion(v, x, output_log_space=False):
  """Compute modified bessel functions using Temme's method.

  Splits `|v| = n + u` with integer `n` and `|u| <= 1/2`. It evaluates
  K_u and K_{u+1} (Temme's series for `|x| <= 2`, a continued fraction
  otherwise), then runs the forward recurrence up to order `|v|`. I_v is
  recovered from the Wronskian, and a reflection correction handles `v < 0`.

  Args:
    v: Order of the Bessel functions.
    x: Argument. Entries with `x < 0` produce NaN; `x == 0` is special-cased.
    output_log_space: If `True`, the outputs are logarithms.

  Returns:
    ive, kve: The `*e` naming, the note on the recurrence, and the
      `- 2. * x_abs` correction below are consistent with the exponentially
      scaled `I_v(x) * exp(-x)` and `K_v(x) * exp(x)` (or their logs);
      TODO confirm against the callers.
  """
  # The implementation of this is based on [1].
  # [1] N. Temme, On the Numerical Evaluation of the Modified Bessel Function
  #   of the Third Kind. Journal of Computational Physics 19, 1975.
  dtype = dtype_util.common_dtype([v, x], tf.float32)
  numpy_dtype = dtype_util.as_numpy_dtype(dtype)
  v_less_than_zero = v < 0.
  v = tf.math.abs(v)
  n = tf.math.round(v)
  # Use this to compute Kv(u, x) and Kv(u + 1., x); u lies in [-0.5, 0.5].
  u = v - n
  x_abs = tf.math.abs(x)

  # Safe-where inputs: each method only sees values in its own regime, so
  # the unused branch cannot produce NaNs that leak into gradients.
  small_x = tf.where(x_abs <= 2., x_abs, numpy_dtype(0.1))
  large_x = tf.where(x_abs > 2., x_abs, numpy_dtype(1000.))

  temme_kue, temme_kuep1 = _temme_series(
      u, small_x, output_log_space=output_log_space)
  cf_kue, cf_kuep1 = _continued_fraction_kv(
      u, large_x, output_log_space=output_log_space)

  kue = tf.where(x_abs <= 2., temme_kue, cf_kue)
  kuep1 = tf.where(x_abs <= 2., temme_kuep1, cf_kuep1)

  # Now use the forward recurrence for modified bessel functions
  # to compute Kv(v, x). That is,
  # K_{v + 1}(z) - (2v / z) K_v(z) - K_{v - 1}(z) = 0.
  # This is known to be forward numerically stable.
  # Note: This recurrence is also satisfied by K_v(z) * exp(z)

  def bessel_recurrence(index, kve, kvep1):
    # On entry kve = K_{u+index-1}, kvep1 = K_{u+index}; entries with
    # index > n have already reached order v and are frozen.
    if output_log_space:
      next_kvep1 = tfp_math.log_add_exp(
          kvep1 + tf.math.log(u + index) +
          numpy_dtype(np.log(2.)) - tf.math.log(x_abs), kve)
    else:
      next_kvep1 = 2 * (u + index) * kvep1 / x_abs + kve
    kve = tf.where(index > n, kve, kvep1)
    kvep1 = tf.where(index > n, kvep1, next_kvep1)
    return index + 1., kve, kvep1

  # Loop until every batch member has done its n recurrence steps; afterwards
  # kve = K_v and kvep1 = K_{v+1} (scaled / logged as above).
  _, kve, kvep1 = tf.while_loop(
      cond=lambda i, *_: tf.reduce_any(i <= n),
      body=bessel_recurrence,
      loop_vars=(tf.cast(1., dtype=dtype), kue, kuep1))

  # Finally, it is known that the Wronskian
  # I_v * K'_v - K_v * I'_v = - 1. / x. We can
  # use this to evaluate I_v by taking advantage of identities of Bessel
  # derivatives: I_v = 1 / (x * (K_{v+1} + (I_{v+1} / I_v) * K_v)), assuming
  # `bessel_iv_ratio(v + 1., x)` returns I_{v+1}(x) / I_v(x).

  if output_log_space:
    ive = -tf.math.log(x_abs) - tfp_math.log_add_exp(
        kve + tf.math.log(bessel_iv_ratio(v + 1., x)), kvep1)
  else:
    ive = tf.math.reciprocal(
        x_abs * (kve * bessel_iv_ratio(v + 1., x) + kvep1))

  # We need to add a correction term for negative v:
  # I_{-v} = I_v + (2 / pi) * sin(v * pi) * K_v for v > 0 (DLMF 10.27.2).
  # With the exp(-x) / exp(x) scalings the K term picks up exp(-2x).

  if output_log_space:
    log_ive = ive
    negative_v_correction = kve - 2. * x_abs
  else:
    log_ive = tf.math.log(ive)
    negative_v_correction = tf.math.log(kve) - 2. * x_abs

  # sin(v * pi) = (-1)**n * sin(u * pi) since v = n + u.
  coeff = 2 / np.pi * tf.math.sin(np.pi * u)
  coeff = (1. - 2. * tf.math.mod(n, 2.)) * coeff

  lse, sign = tfp_math.log_sub_exp(
      log_ive,
      negative_v_correction + tf.math.log(tf.math.abs(coeff)),
      return_sign=True)
  sign = tf.where(coeff < 0., sign, 1.)

  # log|I_{-v}|: subtract magnitudes when coeff < 0, add them otherwise.
  log_ive_negative_v = tf.where(
      coeff < 0.,
      lse,
      tfp_math.log_add_exp(
          log_ive, negative_v_correction + tf.math.log(tf.math.abs(coeff))))

  # z is an integer exactly when v is an integer (u == 0).
  z = u + tf.math.mod(n, 2.)

  if output_log_space:
    ive = tf.where(v_less_than_zero, log_ive_negative_v, ive)

    ive = tf.where(
        tf.math.equal(x, 0.),
        tf.where(
            tf.math.equal(v, 0.), numpy_dtype(0.), numpy_dtype(-np.inf)),
        ive)

  else:
    ive = tf.where(
        v_less_than_zero, sign * tf.math.exp(log_ive_negative_v), ive)

    ive = tf.where(
        tf.math.equal(x, 0.),
        tf.where(tf.math.equal(v, 0.), numpy_dtype(1.), numpy_dtype(0.)),
        ive)

  # At x == 0 with negative non-integer order, I_{-v} diverges.
  ive = tf.where(tf.math.equal(x, 0.) & v_less_than_zero,
                 tf.where(
                     tf.math.equal(z, tf.math.floor(z)),
                     ive,
                     numpy_dtype(np.inf)), ive)

  kve = tf.where(tf.math.equal(x, 0.), numpy_dtype(np.inf), kve)
  ive = tf.where(x < 0., numpy_dtype(np.nan), ive)
  kve = tf.where(x < 0., numpy_dtype(np.nan), kve)
  return ive, kve
def _olver_asymptotic_uniform(v, z, output_log_space=False, name=None):
  """Use Olver's uniform asymptotic expansion for the Bessel function.

  Olver's uniform asymptotic expansion [1] states, for `w > 0`,

  `I_v(v * w) ~ exp(v * eta) / (sqrt(2 * pi * v) * (1 + w^2)^0.25) *
      sum_k U_k(p) / v^k`
  `K_v(v * w) ~ sqrt(pi / (2 * v)) * exp(-v * eta) / (1 + w^2)^0.25 *
      sum_k (-1)^k U_k(p) / v^k`

  where

  * `p = 1 / sqrt(1 + w^2)`,
  * `eta = sqrt(1 + w^2) + log(w / (1 + sqrt(1 + w^2)))`,
  * `U_k(p)` are the polynomials given in [2]; their coefficients are taken
    from `_ASYMPTOTIC_OLVER_EXPANSION_COEFFICIENTS`.

  The expansion is evaluated at `|v|`, using `K_{-v} = K_v`. For negative `v`,
  the connection formula `I_{-a}(z) = I_a(z) + (2 / pi) sin(a pi) K_a(z)`
  with `a = |v|` [3] is applied.

  #### References

  [1]: Digital Library of Mathematical Functions: https://dlmf.nist.gov/10.41
  [2]: F. Olver, Tables for Bessel Functions of Moderate or Large Orders.
       National Physical Laboratory Mathematical Tables, Vol. 6. Department of
       Scientific and Industrial Research
  [3]: Digital Library of Mathematical Functions: https://dlmf.nist.gov/10.27

  Args:
    v: Order for which `I_{v}(z)` and `K_{v}(z)` should be computed.
    z: Argument at which `I_{v}(z)` and `K_{v}(z)` should be computed.
    output_log_space: `bool`. If `True`, output is in log-space.
      Default value: `False`.
    name: A name for the operation (optional).
      Default value: `None` (i.e., 'olver_asymptotic_uniform').

  Returns:
    ive, kve: Asymptotic approximations to the exponentially scaled modified
      bessel functions of the first and second kind, `I_v(z) * exp(-z)` and
      `K_v(z) * exp(z)`. When `output_log_space` is `True`, their logarithms
      are returned; the sign of `ive` for negative `v` is then dropped.
  """
  with tf.name_scope(name or 'olver_asymptotic_uniform'):
    v_abs = tf.math.abs(v)
    w = z / v_abs
    t = tf.math.reciprocal(_sqrt1px2(w))
    n_ufactors = len(_ASYMPTOTIC_OLVER_EXPANSION_COEFFICIENTS)

    divisor = v_abs
    ive_sum = 1.
    kve_sum = 1.

    # Note the polynomials have properties of oddness and evenness that
    # could be taken advantage of when doing evaluation. For simplicity,
    # we naively sum using Horner's method.
    for i in range(n_ufactors):
      coeff = 0.
      for c in _ASYMPTOTIC_OLVER_EXPANSION_COEFFICIENTS[i]:
        coeff = coeff * t + c
      term = coeff / divisor
      ive_sum = ive_sum + term
      # Iteration i adds U_{i+1}; the K series alternates (-1)**(i + 1).
      kve_sum = kve_sum + (term if i % 2 == 1 else -term)
      divisor = divisor * v_abs

    # This is eta - w, rewritten for numerical stability: it folds in the
    # exp(-z) / exp(z) scaling (z = |v| * w), using
    # sqrt(1 + w^2) - w = 1 / (sqrt(1 + w^2) + w).
    shared_prefactor = (tf.math.reciprocal(_sqrt1px2(w) + w) + tf.math.log(w)
                        - tf.math.log1p(tf.math.reciprocal(t)))
    log_i_prefactor = 0.5 * tf.math.log(
        t / (2 * np.pi * v_abs)) + v_abs * shared_prefactor

    # The K prefactor carries exp(-|v| * (eta - w)), hence the minus sign.
    log_k_prefactor = 0.5 * tf.math.log(
        np.pi * t / (2 * v_abs)) - v_abs * shared_prefactor

    log_kve = log_k_prefactor + tf.math.log(kve_sum)
    log_ive = log_i_prefactor + tf.math.log(ive_sum)

    # We need to add a correction term for negative v. The connection formula
    # needs sin(|v| * pi), so split |v| (not the signed v) as n + u with
    # integer n and |u| <= 1/2: sin(|v| * pi) = (-1)**n * sin(u * pi).
    negative_v_correction = log_kve - 2. * z
    n = tf.math.round(v_abs)
    u = v_abs - n
    coeff = 2 / np.pi * tf.math.sin(np.pi * u)
    coeff = (1. - 2. * tf.math.mod(n, 2.)) * coeff
    log_abs_correction = (
        negative_v_correction + tf.math.log(tf.math.abs(coeff)))

    lse, sign = tfp_math.log_sub_exp(
        log_ive, log_abs_correction, return_sign=True)
    sign = tf.where(coeff < 0., sign, 1.)

    # log|I_{-v}|: subtract magnitudes when coeff < 0, add them otherwise.
    log_ive_negative_v = tf.where(
        coeff < 0.,
        lse,
        tfp_math.log_add_exp(log_ive, log_abs_correction))

    if output_log_space:
      log_ive = tf.where(v >= 0., log_ive, log_ive_negative_v)
      return log_ive, log_kve

    ive = tf.where(
        v >= 0.,
        tf.math.exp(log_ive),
        sign * tf.math.exp(log_ive_negative_v))
    return ive, tf.math.exp(log_kve)
def _loop_build_sub_tree(self,
                         directions,
                         integrator,
                         log_slice_sample,
                         init_energy,
                         iter_,
                         energy_diff_sum_previous,
                         leapfrogs_taken,
                         prev_tree_state,
                         candidate_tree_state,
                         continue_tree_previous,
                         not_divergent_previous,
                         momentum_state_memory):
  """Base case in tree doubling.

  Takes one leapfrog step from `prev_tree_state`. It writes the new
  momentum/state into `momentum_state_memory` on even `iter_` and checks for
  U-turns on odd `iter_`. It then samples the subtree candidate and updates
  the continuation and divergence flags.

  The returned tuple is `(iter_ + 1, energy_diff_sum, leapfrogs_taken,
  next_tree_state, next_candidate_tree_state, continue_tree_next,
  not_divergent, momentum_state_memory)`. It mirrors the loop-variable
  arguments, presumably for use as a `tf.while_loop` body (TODO confirm).
  """
  with tf.name_scope('loop_build_sub_tree'):
    # Take one leapfrog step in the current direction and check divergence.
    [
        next_momentum_parts,
        next_state_parts,
        next_target,
        next_target_grad_parts
    ] = integrator(prev_tree_state.momentum,
                   prev_tree_state.state,
                   prev_tree_state.target,
                   prev_tree_state.target_grad_parts)

    next_tree_state = TreeDoublingState(
        momentum=next_momentum_parts,
        state=next_state_parts,
        target=next_target,
        target_grad_parts=next_target_grad_parts)
    # If the tree has not yet terminated previously, we count this leapfrog.
    leapfrogs_taken = tf.where(
        continue_tree_previous, leapfrogs_taken + 1, leapfrogs_taken)

    # Save state and momentum at odd step, check U turn at even step
    # (steps counted from 1, so "odd step" is `iter_ % 2 == 0`).
    # Note that here we also write to a placeholder slot
    # (`self.max_tree_depth`) at even steps to avoid using tf.cond.
    index = iter_ // 2
    if USE_RAGGED_TENSOR:
      write_index_ = self.write_instruction[index]
    else:
      write_index_ = tf.switch_case(index, self.write_instruction)
    write_index = tf.where(tf.equal(iter_ % 2, 0),
                           write_index_, self.max_tree_depth)

    if USE_TENSORARRAY:
      momentum_state_memory = MomentumStateSwap(
          momentum_swap=[
              old.write(write_index, new) for old, new in
              zip(momentum_state_memory.momentum_swap, next_momentum_parts)
          ],
          state_swap=[
              old.write(write_index, new) for old, new in
              zip(momentum_state_memory.state_swap, next_state_parts)
          ])
    else:
      momentum_state_memory = MomentumStateSwap(
          momentum_swap=[
              tf.tensor_scatter_nd_update(old, [[write_index]], [new])
              for old, new in zip(momentum_state_memory.momentum_swap,
                                  next_momentum_parts)
          ],
          state_swap=[
              tf.tensor_scatter_nd_update(old, [[write_index]], [new])
              for old, new in zip(momentum_state_memory.state_swap,
                                  next_state_parts)
          ])
    batch_size = prefer_static.size(next_target)
    has_not_u_turn_at_even_step = tf.ones([batch_size], dtype=tf.bool)

    # U-turn check only on odd `iter_`; even `iter_` reports "no U-turn".
    if USE_RAGGED_TENSOR:
      no_u_turns_within_tree = tf.cond(
          tf.equal(iter_ % 2, 0),
          lambda: has_not_u_turn_at_even_step,
          lambda: has_not_u_turn_at_odd_step(  # pylint: disable=g-long-lambda
              self.read_instruction, iter_ // 2, directions,
              momentum_state_memory, next_momentum_parts, next_state_parts))
    else:
      # One statically-indexed branch per read-instruction entry, dispatched
      # with tf.switch_case on `iter_ // 2`.
      f = lambda int_iter: has_not_u_turn_at_odd_step(  # pylint: disable=g-long-lambda
          self.read_instruction, int_iter, directions, momentum_state_memory,
          next_momentum_parts, next_state_parts)
      branch_excution = {x: functools.partial(f, x)
                         for x in range(len(self.read_instruction))}
      no_u_turns_within_tree = tf.cond(
          tf.equal(iter_ % 2, 0),
          lambda: has_not_u_turn_at_even_step,
          lambda: tf.switch_case(iter_ // 2, branch_excution))

    energy = compute_hamiltonian(next_target, next_momentum_parts)
    # NaN energies are replaced by -inf so the step gets zero weight and is
    # flagged as divergent below.
    energy = tf.where(tf.math.is_nan(energy),
                      tf.constant(-np.inf, dtype=energy.dtype),
                      energy)
    energy_diff = energy - init_energy

    if MULTINOMIAL_SAMPLE:
      not_divergent = -energy_diff < self.max_energy_diff
      # Weights are kept in log space.
      weight_sum = log_add_exp(candidate_tree_state.weight, energy_diff)
      log_accept_thresh = energy_diff - weight_sum
    else:
      not_divergent = log_slice_sample - energy_diff < self.max_energy_diff
      # Uniform sampling on the trajectory within the subtree across valid
      # samples.
      is_valid = log_slice_sample <= energy_diff
      weight_sum = tf.where(is_valid,
                            candidate_tree_state.weight + 1,
                            candidate_tree_state.weight)
      # NOTE(review): hard-coded float32 threshold regardless of target dtype.
      log_accept_thresh = tf.where(
          is_valid,
          -tf.math.log(tf.cast(weight_sum, dtype=tf.float32)),
          tf.constant(-np.inf, dtype=tf.float32))
    # log(1 - U) with U ~ Uniform[0, 1) is a log-uniform draw.
    u = tf.math.log1p(-tf.random.uniform(
        shape=[batch_size],
        dtype=log_accept_thresh.dtype,
        seed=self._seed_stream()))
    is_sample_accepted = u <= log_accept_thresh

    next_candidate_tree_state = TreeDoublingStateCandidate(
        state=[
            tf.where(  # pylint: disable=g-complex-comprehension
                _expand_dims_under_batch_dim(
                    is_sample_accepted, prefer_static.rank(s0)),
                s0, s1)
            for s0, s1 in zip(next_state_parts, candidate_tree_state.state)
        ],
        target=tf.where(is_sample_accepted,
                        next_target,
                        candidate_tree_state.target),
        target_grad_parts=[
            tf.where(  # pylint: disable=g-complex-comprehension
                _expand_dims_under_batch_dim(
                    is_sample_accepted, prefer_static.rank(grad0)),
                grad0, grad1)
            for grad0, grad1 in zip(next_target_grad_parts,
                                    candidate_tree_state.target_grad_parts)
        ],
        weight=weight_sum)

    continue_tree = not_divergent & continue_tree_previous
    continue_tree_next = no_u_turns_within_tree & continue_tree

    # Chains that had already stopped do not contribute to divergence.
    not_divergent_tokeep = tf.where(
        continue_tree_previous,
        not_divergent,
        tf.ones([batch_size], dtype=tf.bool))

    # min(1., exp(energy_diff)).
    exp_energy_diff = tf.clip_by_value(tf.exp(energy_diff), 0., 1.)
    energy_diff_sum = tf.where(continue_tree,
                               energy_diff_sum_previous + exp_energy_diff,
                               energy_diff_sum_previous)

    return (
        iter_ + 1,
        energy_diff_sum,
        leapfrogs_taken,
        next_tree_state,
        next_candidate_tree_state,
        continue_tree_next,
        not_divergent_previous & not_divergent_tokeep,
        momentum_state_memory,
    )
def loop_tree_doubling(self, step_size, log_slice_sample, init_energy,
                       momentum_state_memory, iter_, initial_step_state,
                       initial_step_metastate):
  """Main loop for tree doubling.

  Draws a random direction per chain and builds a subtree of `2**iter_`
  leapfrog steps from that end of the trajectory (via
  `self._build_sub_tree`). It then possibly replaces the candidate sample,
  writes the new end back into `[left, right]` storage, and checks the
  trajectory-level U-turn criterion.

  Returns:
    Tuple `(iter_ + 1, new_step_state, new_step_metastate)`.
  """
  with tf.name_scope('loop_tree_doubling'):
    batch_size = prefer_static.size(init_energy)
    # True means extend to the right (index 1), False to the left (index 0).
    direction = tf.cast(
        tf.random.uniform(
            shape=[batch_size],
            minval=0,
            maxval=2,
            dtype=tf.int32,
            seed=self._seed_stream()),
        dtype=tf.bool)

    # Per-chain [end_index, chain_index] pairs into the [2, batch, ...] ends.
    left_right_index = tf.concat([
        tf.cast(direction, tf.int32)[..., tf.newaxis],
        tf.range(batch_size, dtype=tf.int32)[..., tf.newaxis]
    ], axis=1)

    tree_start_states = tf.nest.map_structure(
        # Alternatively: `lambda v: tf.where(direction, v[1], v[0])`
        lambda v: tf.gather_nd(v, left_right_index),
        initial_step_state)

    directions_expanded = [
        _expand_dims_under_batch_dim(direction, prefer_static.rank(state))
        for state in tree_start_states.state
    ]

    # Negative step sizes integrate backward for leftward chains. (The
    # comprehension variable `direction` is local to the comprehension.)
    integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
        self.target_log_prob_fn,
        step_sizes=[
            tf.where(direction, ss, -ss)
            for direction, ss in zip(directions_expanded, step_size)
        ],
        num_steps=self.unrolled_leapfrog_steps)

    [
        candidate_tree_state,
        tree_final_states,
        final_not_divergence,
        continue_tree_final,
        energy_diff_tree_sum,
        leapfrogs_taken,
    ] = self._build_sub_tree(
        directions_expanded,
        integrator,
        log_slice_sample,
        init_energy,
        # num_steps_at_this_depth = 2**iter_ = 1 << iter_
        tf.bitwise.left_shift(1, iter_),
        tree_start_states,
        initial_step_metastate.continue_tree,
        initial_step_metastate.not_divergence,
        momentum_state_memory)

    last_candidate_state = initial_step_metastate.candidate_state
    # NOTE(review): `tree_weight` is not masked by `continue_tree_final`, so
    # `weight_sum` also grows for a subtree that terminated; confirm intended.
    tree_weight = candidate_tree_state.weight
    if MULTINOMIAL_SAMPLE:
      weight_sum = log_add_exp(tree_weight, last_candidate_state.weight)
      log_accept_thresh = tree_weight - last_candidate_state.weight
    else:
      weight_sum = tree_weight + last_candidate_state.weight
      log_accept_thresh = tf.math.log(
          tf.cast(tree_weight, tf.float32) /
          tf.cast(last_candidate_state.weight, tf.float32))
    # NaN thresholds (e.g. 0/0) become 0, i.e. always accept; the final
    # choice is still masked by `continue_tree_final`.
    log_accept_thresh = tf.where(
        tf.math.is_nan(log_accept_thresh),
        tf.zeros([], log_accept_thresh.dtype),
        log_accept_thresh)
    # log(1 - U) with U ~ Uniform[0, 1) is a log-uniform draw.
    u = tf.math.log1p(-tf.random.uniform(
        shape=[batch_size],
        dtype=log_accept_thresh.dtype,
        seed=self._seed_stream()))
    is_sample_accepted = u <= log_accept_thresh

    choose_new_state = is_sample_accepted & continue_tree_final

    new_candidate_state = TreeDoublingStateCandidate(
        state=[
            tf.where(  # pylint: disable=g-complex-comprehension
                _expand_dims_under_batch_dim(
                    choose_new_state, prefer_static.rank(s0)),
                s0, s1)
            for s0, s1 in zip(candidate_tree_state.state,
                              last_candidate_state.state)
        ],
        target=tf.where(choose_new_state,
                        candidate_tree_state.target,
                        last_candidate_state.target),
        target_grad_parts=[
            tf.where(  # pylint: disable=g-complex-comprehension
                _expand_dims_under_batch_dim(
                    choose_new_state, prefer_static.rank(grad0)),
                grad0, grad1)
            for grad0, grad1 in zip(candidate_tree_state.target_grad_parts,
                                    last_candidate_state.target_grad_parts)
        ],
        weight=weight_sum)

    # Update left right information of the trajectory, and check trajectory
    # level U turn

    # Alternative approach
    # left_right_mask = tf.transpose(
    #     tf.tile(tf.one_hot(tf.cast(direction, tf.int32), 2),
    #             [1, initial_step_metastate.candidate_state[0].shape[-1], 1]),
    #     [2, 0, 1])
    # trajactory_state_left_right = tf.where(
    #     tf.equal(left_right_mask, 0.),
    #     trajactory_state_left_right,
    #     tf.tile(tree_final_states[1][0][tf.newaxis, ...], [2, 1, 1]))

    # Overwrite only the end that was extended; the other end is unchanged.
    new_step_state = tf.nest.pack_sequence_as(initial_step_state, [
        # Alternative approach:
        # tf.where(tf.equal(left_right_mask, 0.),
        #          v,
        #          tf.tile(r[tf.newaxis],
        #                  tf.concat([[2], tf.ones_like(tf.shape(r))], 0)))
        tf.tensor_scatter_nd_update(v, left_right_index, r)
        for v, r in zip(tf.nest.flatten(initial_step_state),
                        tf.nest.flatten(tree_final_states))
    ])

    no_u_turns_trajectory = has_not_u_turn(
        [s[0] for s in new_step_state.state],
        [m[0] for m in new_step_state.momentum],
        [s[1] for s in new_step_state.state],
        [m[1] for m in new_step_state.momentum])

    new_step_metastate = TreeDoublingMetaState(
        candidate_state=new_candidate_state,
        is_accepted=choose_new_state | initial_step_metastate.is_accepted,
        energy_diff_sum=(energy_diff_tree_sum +
                         initial_step_metastate.energy_diff_sum),
        continue_tree=continue_tree_final & no_u_turns_trajectory,
        not_divergence=final_not_divergence,
        leapfrog_count=(initial_step_metastate.leapfrog_count +
                        leapfrogs_taken))

    return iter_ + 1, new_step_state, new_step_metastate