def _build_sub_tree(self, directions, integrator, current_step_meta_info, nsteps, initial_state, continue_tree, not_divergence, momentum_state_memory, seed, name=None): with tf.name_scope('build_sub_tree'): batch_shape = ps.shape(current_step_meta_info.init_energy) # We never want to select the inital state if MULTINOMIAL_SAMPLE: init_weight = tf.fill( batch_shape, tf.constant( -np.inf, dtype=current_step_meta_info.init_energy.dtype)) else: init_weight = tf.zeros(batch_shape, dtype=TREE_COUNT_DTYPE) init_momentum_cumsum = [ tf.zeros_like(x) for x in initial_state.momentum ] initial_state_candidate = TreeDoublingStateCandidate( state=initial_state.state, target=initial_state.target, target_grad_parts=initial_state.target_grad_parts, energy=initial_state.target, weight=init_weight) energy_diff_sum = tf.zeros_like(current_step_meta_info.init_energy, name='energy_diff_sum') [ _, _, energy_diff_tree_sum, momentum_tree_cumsum, leapfrogs_taken, final_state, candidate_tree_state, final_continue_tree, final_not_divergence, momentum_state_memory, ] = tf.while_loop( cond=lambda iter_, seed, energy_diff_sum, init_momentum_cumsum, # pylint: disable=g-long-lambda leapfrogs_taken, state, state_c, continue_tree, not_divergence, momentum_state_memory: ( (iter_ < nsteps) & tf.reduce_any(continue_tree)), body=lambda iter_, seed, energy_diff_sum, init_momentum_cumsum, # pylint: disable=g-long-lambda leapfrogs_taken, state, state_c, continue_tree, not_divergence, momentum_state_memory: (self._loop_build_sub_tree( directions, integrator, current_step_meta_info, iter_, energy_diff_sum, init_momentum_cumsum, leapfrogs_taken, state, state_c, continue_tree, not_divergence, momentum_state_memory, seed)), loop_vars=( tf.zeros([], dtype=tf.int32, name='iter'), seed, energy_diff_sum, init_momentum_cumsum, tf.zeros(batch_shape, dtype=TREE_COUNT_DTYPE), initial_state, initial_state_candidate, continue_tree, not_divergence, momentum_state_memory, ), parallel_iterations=self.parallel_iterations) return ( candidate_tree_state, final_state, final_not_divergence, final_continue_tree, energy_diff_tree_sum, momentum_tree_cumsum, leapfrogs_taken, )
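# Context note (not part of the original source): `_build_sub_tree` is the
# tree-doubling inner loop of TFP's No-U-Turn Sampler. For orientation, a
# minimal usage sketch of the public kernel that ultimately drives it; the
# standard-normal target, step size, and seed below are illustrative only.
import tensorflow as tf
import tensorflow_probability as tfp

# Draw 500 NUTS samples from a 2-D standard normal, 4 chains in parallel.
nuts_kernel = tfp.mcmc.NoUTurnSampler(
    target_log_prob_fn=lambda x: -0.5 * tf.reduce_sum(x**2, axis=-1),
    step_size=0.3)
nuts_samples = tfp.mcmc.sample_chain(
    num_results=500,
    current_state=tf.zeros([4, 2]),
    kernel=nuts_kernel,
    trace_fn=None,
    seed=17)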
def sample_sequential_monte_carlo( prior_log_prob_fn, likelihood_log_prob_fn, current_state, min_num_steps=2, max_num_steps=25, max_stage=100, make_kernel_fn=make_rwmh_kernel_fn, tuning_fn=simple_heuristic_tuning, make_tempered_target_log_prob_fn=default_make_tempered_target_log_prob_fn, resample_fn=weighted_resampling.resample_systematic, ess_threshold_ratio=0.5, parallel_iterations=10, seed=None, name=None): """Runs Sequential Monte Carlo to sample from the posterior distribution. This function uses an MCMC transition operator (e.g., Hamiltonian Monte Carlo) to sample from a series of distributions that slowly interpolate between an initial 'prior' distribution: `exp(prior_log_prob_fn(x))` and the target 'posterior' distribution: `exp(prior_log_prob_fn(x) + likelihood_log_prob_fn(x))`, by mutating a collection of MC samples (i.e., particles). The approach is also known as a particle filter in some of the literature. The current implementation is largely based on Del Moral et al. [1], which adapts the tempering sequence (based on the effective sample size) and the scaling of the mutation kernel (based on the sample covariance of the particles) at each stage. Args: prior_log_prob_fn: Python callable that returns the log density of the prior distribution. likelihood_log_prob_fn: Python callable which takes an argument like `current_state` (or `*current_state` if it's a list) and returns its (possibly unnormalized) log-density under the likelihood distribution. current_state: Nested structure of `Tensor`s, each of shape `concat([[num_particles, b1, ..., bN], latent_part_event_shape])`, where `b1, ..., bN` are optional batch dimensions. Each batch represents an independent SMC run. min_num_steps: The minimal number of kernel transition steps in one mutation of the MC samples. max_num_steps: The maximum number of kernel transition steps in one mutation of the MC samples. Note that the actual number of steps in one mutation is tuned during sampling and is likely lower than `max_num_steps`. max_stage: Integer maximum number of stages for increasing the temperature from 0 to 1. make_kernel_fn: Python `callable` which returns a `TransitionKernel`-like object. Must take one argument representing the `TransitionKernel`'s `target_log_prob_fn`. The `target_log_prob_fn` argument represents the `TransitionKernel`'s target log distribution. Note: `sample_sequential_monte_carlo` creates a new `target_log_prob_fn` which is an interpolation between the supplied `prior_log_prob_fn` and `likelihood_log_prob_fn`; it is this interpolated function which is used as an argument to `make_kernel_fn`. tuning_fn: Python `callable` which takes the number of steps, the log scaling, and the log acceptance ratio from the last mutation and outputs the number of steps and log scaling for the next mutation. make_tempered_target_log_prob_fn: Python `callable` that takes the `prior_log_prob_fn`, `likelihood_log_prob_fn`, and `inverse_temperatures` and creates a `target_log_prob_fn` `callable` that is passed to `make_kernel_fn`. resample_fn: Python `callable` to generate the indices of resampled particles, given their weights. Generally, one of `tfp.experimental.mcmc.resample_independent` or `tfp.experimental.mcmc.resample_systematic`, or any function with the same signature. Default value: `tfp.experimental.mcmc.resample_systematic`. ess_threshold_ratio: Target ratio for effective sample size. parallel_iterations: The number of iterations allowed to run in parallel. It must be a positive integer. See `tf.while_loop` for more details. 
seed: Python integer or TFP seedstream to seed the random number generator. name: Python `str` name prefixed to Ops created by this function. Default value: `None` (i.e., 'sample_sequential_monte_carlo'). Returns: n_stage: Number of the mutation stage SMC ran. final_state: `Tensor` or Python `list` of `Tensor`s representing the final state(s) of the Markov chain(s). The output are the posterior samples. final_kernel_results: `collections.namedtuple` of internal calculations used to advance the chain. #### References [1] Del Moral, Pierre, Arnaud Doucet, and Ajay Jasra. An adaptive sequential Monte Carlo method for approximate Bayesian computation. _Statistics and Computing_, 22.5(1009-1020), 2012. """ with tf.name_scope(name or 'sample_sequential_monte_carlo'): is_seeded = seed is not None seed = samplers.sanitize_seed(seed, salt='mcmc.sample_smc') unwrap_state_list = not tf.nest.is_nested(current_state) if unwrap_state_list: current_state = [current_state] current_state = [ tf.convert_to_tensor(s, dtype_hint=tf.float32) for s in current_state ] # Initial preprocessing at Stage 0 likelihood_log_prob = likelihood_log_prob_fn(*current_state) likelihood_rank = ps.rank(likelihood_log_prob) dimension = ps.reduce_sum([ ps.reduce_prod(ps.shape(x)[likelihood_rank:]) for x in current_state ]) # We infer the particle shapes from the resulting likelihood: # [num_particles, b1, ..., bN] particle_shape = ps.shape(likelihood_log_prob) num_particles, batch_shape = particle_shape[0], particle_shape[1:] effective_sample_size_threshold = tf.cast( num_particles * ess_threshold_ratio, tf.int32) # TODO(b/152412213): Revisit this default parameter. # Default to the optimal scaling of a random walk kernel for a d-dimensional # normal distributed targets: 2.38 ** 2 / d. # For more detail see: # Roberts GO, Gelman A, Gilks WR. Weak convergence and optimal scaling of # random walk Metropolis algorithms. _The annals of applied probability_. # 1997;7(1):110-20. scale_start = (tf.constant(2.38**2, dtype=likelihood_log_prob.dtype) / tf.constant(dimension, dtype=likelihood_log_prob.dtype)) inverse_temperature = tf.zeros(batch_shape, dtype=likelihood_log_prob.dtype) scalings = ps.ones_like(likelihood_log_prob) * ps.minimum( scale_start, 1.) 
kernel = make_kernel_fn( make_tempered_target_log_prob_fn(prior_log_prob_fn, likelihood_log_prob_fn, inverse_temperature), current_state, scalings) pkr = kernel.bootstrap_results(current_state) _, kernel_target_log_prob = gather_mh_like_result(pkr) particle_info = ParticleInfo( log_accept_prob=ps.zeros_like(likelihood_log_prob), log_scalings=tf.math.log(scalings), tempered_log_prob=kernel_target_log_prob, likelihood_log_prob=likelihood_log_prob, ) current_pkr = SMCResults( num_steps=tf.convert_to_tensor(max_num_steps, dtype=tf.int32, name='num_steps'), inverse_temperature=inverse_temperature, log_marginal_likelihood=tf.zeros_like(inverse_temperature), particle_info=particle_info) def update_weights_temperature(inverse_temperature, likelihood_log_prob): """Calculate the next inverse temperature and update weights.""" likelihood_diff = likelihood_log_prob - tf.reduce_max( likelihood_log_prob, axis=0) def _body_fn(new_beta, upper_beta, lower_beta, eff_size, log_weights): """One iteration of the temperature and weight update.""" new_beta = (lower_beta + upper_beta) / 2.0 log_weights = (new_beta - inverse_temperature) * likelihood_diff log_weights_norm = tf.math.log_softmax(log_weights, axis=0) eff_size = tf.cast( tf.exp(-tf.math.reduce_logsumexp(2 * log_weights_norm, axis=0)), tf.int32) upper_beta = tf.where( eff_size < effective_sample_size_threshold, new_beta, upper_beta) lower_beta = tf.where( eff_size < effective_sample_size_threshold, lower_beta, new_beta) return new_beta, upper_beta, lower_beta, eff_size, log_weights def _cond_fn(new_beta, upper_beta, lower_beta, eff_size, *_): # pylint: disable=unused-argument # TODO(junpenglao): revisit threshold below to be dtype specific. threshold = 1e-6 return (tf.math.reduce_any(upper_beta - lower_beta > threshold) & tf.math.reduce_any( eff_size != effective_sample_size_threshold)) (new_beta, upper_beta, lower_beta, eff_size, log_weights) = tf.while_loop( # pylint: disable=unused-variable cond=_cond_fn, body=_body_fn, loop_vars=(tf.zeros_like(inverse_temperature), tf.fill(ps.shape(inverse_temperature), tf.constant(2, inverse_temperature.dtype)), inverse_temperature, tf.zeros_like(inverse_temperature, dtype=tf.int32), tf.zeros_like(likelihood_diff)), parallel_iterations=parallel_iterations) log_weights = tf.where(new_beta < 1., log_weights, (1. - inverse_temperature) * likelihood_diff) marginal_loglike_ = reduce_logmeanexp( (new_beta - inverse_temperature) * likelihood_log_prob, axis=0) new_inverse_temperature = tf.clip_by_value(new_beta, 0., 1.) return marginal_loglike_, new_inverse_temperature, log_weights def mutate(current_state, log_scalings, num_steps, inverse_temperature): """Mutate the state using a Transition kernel.""" with tf.name_scope('mutate_states'): scalings = tf.exp(log_scalings) kernel = make_kernel_fn( make_tempered_target_log_prob_fn(prior_log_prob_fn, likelihood_log_prob_fn, inverse_temperature), current_state, scalings) pkr = kernel.bootstrap_results(current_state) kernel_log_accept_ratio, _ = gather_mh_like_result(pkr) def mutate_onestep(i, seed, state, pkr, log_accept_prob_sum): iter_seed, next_seed = (samplers.split_seed(seed) if is_seeded else (None, seed)) one_step_kwargs = dict(seed=iter_seed) if is_seeded else {} next_state, next_kernel_results = kernel.one_step( state, pkr, **one_step_kwargs) kernel_log_accept_ratio, _ = gather_mh_like_result(pkr) log_accept_prob = tf.minimum(kernel_log_accept_ratio, 0.) 
log_accept_prob_sum = log_add_exp(log_accept_prob_sum, log_accept_prob) return [ i + 1, next_seed, next_state, next_kernel_results, log_accept_prob_sum ] ( _, _, next_state, next_kernel_results, log_accept_prob_sum ) = tf.while_loop( cond=lambda i, *args: i < num_steps, body=mutate_onestep, loop_vars=( tf.zeros([], dtype=tf.int32), seed, current_state, pkr, # we accumulate the acceptance probability in log space. tf.fill( ps.shape(kernel_log_accept_ratio), tf.constant(-np.inf, kernel_log_accept_ratio.dtype))), parallel_iterations=parallel_iterations) _, kernel_target_log_prob = gather_mh_like_result( next_kernel_results) avg_log_accept_prob_per_particle = log_accept_prob_sum - tf.math.log( tf.cast(num_steps + 1, log_accept_prob_sum.dtype)) return (next_state, avg_log_accept_prob_per_particle, kernel_target_log_prob) # One SMC steps. def smc_body_fn(stage, state, smc_kernel_result): """Run one stage of SMC with constant temperature.""" (new_marginal, new_inv_temperature, log_weights) = update_weights_temperature( smc_kernel_result.inverse_temperature, smc_kernel_result.particle_info.likelihood_log_prob) # TODO(b/152412213) Use a tf.scan to better collect debug info. if PRINT_DEBUG: tf.print( 'Stage:', stage, 'Beta:', new_inv_temperature, 'n_steps:', smc_kernel_result.num_steps, 'accept:', tf.exp( reduce_logmeanexp( smc_kernel_result.particle_info.log_accept_prob, axis=0)), 'scaling:', tf.exp( reduce_logmeanexp( smc_kernel_result.particle_info.log_scalings, axis=0))) (resampled_state, resampled_particle_info), _, _ = weighted_resampling.resample( particles=(state, smc_kernel_result.particle_info), log_weights=log_weights, resample_fn=resample_fn, seed=seed) next_num_steps, next_log_scalings = tuning_fn( smc_kernel_result.num_steps, resampled_particle_info.log_scalings, resampled_particle_info.log_accept_prob) # Skip tuning at stage 0. next_num_steps = tf.where(stage == 0, smc_kernel_result.num_steps, next_num_steps) next_log_scalings = tf.where(stage == 0, resampled_particle_info.log_scalings, next_log_scalings) next_num_steps = tf.clip_by_value(next_num_steps, min_num_steps, max_num_steps) next_state, log_accept_prob, tempered_log_prob = mutate( resampled_state, next_log_scalings, next_num_steps, new_inv_temperature) next_pkr = SMCResults( num_steps=next_num_steps, inverse_temperature=new_inv_temperature, log_marginal_likelihood=( new_marginal + smc_kernel_result.log_marginal_likelihood), particle_info=ParticleInfo( log_accept_prob=log_accept_prob, log_scalings=next_log_scalings, tempered_log_prob=tempered_log_prob, likelihood_log_prob=likelihood_log_prob_fn(*next_state), )) return stage + 1, next_state, next_pkr (n_stage, final_state, final_kernel_results) = tf.while_loop( cond=lambda i, state, pkr: ( # pylint: disable=g-long-lambda (i < max_stage) & tf.reduce_any(pkr.inverse_temperature < 1.)), body=smc_body_fn, loop_vars=(tf.zeros([], dtype=tf.int32), current_state, current_pkr), parallel_iterations=parallel_iterations) if unwrap_state_list: final_state = final_state[0] return n_stage, final_state, final_kernel_results
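# Usage sketch (not part of the original source): a toy conjugate-normal
# model for `sample_sequential_monte_carlo`. The model, particle count, and
# seeds below are illustrative assumptions, not library defaults.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
observations = tf.constant([0.8, 1.2, 0.9, 1.1])

def prior_log_prob_fn(loc):
  return tfd.Normal(0., 1.).log_prob(loc)

def likelihood_log_prob_fn(loc):
  return tf.reduce_sum(
      tfd.Normal(loc[..., tf.newaxis], 1.).log_prob(observations), axis=-1)

# 1000 particles drawn from the prior; SMC tempers them toward the posterior.
initial_particles = tfd.Normal(0., 1.).sample(1000, seed=42)
n_stage, posterior_particles, final_results = (
    tfp.experimental.mcmc.sample_sequential_monte_carlo(
        prior_log_prob_fn, likelihood_log_prob_fn, initial_particles, seed=7))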
def trace_scan(loop_fn, initial_state, elems, trace_fn, trace_criterion_fn=None, static_trace_allocation_size=None, parallel_iterations=10, name=None): """A simplified version of `tf.scan` that has configurable tracing. This function repeatedly calls `loop_fn(state, elem)`, where `state` is the `initial_state` during the first iteration, and the return value of `loop_fn` for every iteration thereafter. `elem` is a slice of `elements` along the first dimension, accessed in order. Additionally, it calls `trace_fn` on the return value of `loop_fn`. The `Tensor`s in return values of `trace_fn` are stacked and returned from this function, such that the first dimension of those `Tensor`s matches the size of `elems`. Args: loop_fn: A callable that takes in a `Tensor` or a nested collection of `Tensor`s with the same structure as `initial_state`, a slice of `elems` and returns the same structure as `initial_state`. initial_state: A `Tensor` or a nested collection of `Tensor`s passed to `loop_fn` in the first iteration. elems: A `Tensor` that is split along the first dimension and each element of which is passed to `loop_fn`. trace_fn: A callable that takes in the return value of `loop_fn` and returns a `Tensor` or a nested collection of `Tensor`s. trace_criterion_fn: Optional callable that takes in the return value of `loop_fn` and returns a boolean `Tensor` indicating whether to trace it. If `None`, all steps are traced. Default value: `None`. static_trace_allocation_size: Optional Python `int` size of trace to allocate statically. This should be an upper bound on the number of steps traced and is used only when the length cannot be statically inferred (for example, if a `trace_criterion_fn` is specified). It is primarily intended for contexts where static shapes are required, such as in XLA-compiled code. Default value: `None`. parallel_iterations: Passed to the internal `tf.while_loop`. name: Name scope used in this function. Default: 'trace_scan'. Returns: final_state: The final return value of `loop_fn`. trace: The same structure as the return value of `trace_fn`, but with each `Tensor` being a stack of the corresponding `Tensors` in the return value of `trace_fn` for each slice of `elems`. """ with tf.name_scope(name or 'trace_scan'), tf1.variable_scope( tf1.get_variable_scope()) as vs: if vs.caching_device is None and not tf.executing_eagerly(): vs.set_caching_device(lambda op: op.device) initial_state = tf.nest.map_structure( lambda x: tf.convert_to_tensor(x, name='initial_state'), initial_state, expand_composites=True) elems = tf.convert_to_tensor(elems, name='elems') length = ps.size0(elems) # This is an TensorArray in part because of XLA, which had trouble with # non-statically known indices. I.e. elems[i] errored, but # elems_array.read(i) worked. elems_array = tf.TensorArray( elems.dtype, size=length, element_shape=elems.shape[1:]) elems_array = elems_array.unstack(elems) # Initialize trace arrays. 
if trace_criterion_fn is None: dynamic_size, initial_size = tf.is_tensor(length), length elif static_trace_allocation_size is not None: dynamic_size, initial_size = False, static_trace_allocation_size elif JAX_MODE or (not tf.executing_eagerly() and control_flow_util.GraphOrParentsInXlaContext( tf1.get_default_graph())): dynamic_size, initial_size = False, length else: dynamic_size, initial_size = True, 0 initial_trace = trace_fn(initial_state) flat_initial_trace = tf.nest.flatten(initial_trace, expand_composites=True) trace_arrays = [] for trace_elt in flat_initial_trace: trace_arrays.append( tf.TensorArray( trace_elt.dtype, size=initial_size, dynamic_size=dynamic_size, element_shape=trace_elt.shape)) # Helper for writing a (structured) state to (structured) arrays. def trace_one_step(num_steps_traced, trace_arrays, state): return [ta.write(num_steps_traced, x) for ta, x in zip(trace_arrays, tf.nest.flatten(trace_fn(state), expand_composites=True))] def _body(i, state, num_steps_traced, trace_arrays): elem = elems_array.read(i) state = loop_fn(state, elem) trace_arrays, num_steps_traced = ps.cond( trace_criterion_fn(state) if trace_criterion_fn else True, lambda: (trace_one_step(num_steps_traced, trace_arrays, state), # pylint: disable=g-long-lambda num_steps_traced + 1), lambda: (trace_arrays, num_steps_traced)) return i + 1, state, num_steps_traced, trace_arrays _, final_state, _, trace_arrays = tf.while_loop( cond=lambda i, *_: i < length, body=_body, loop_vars=(0, initial_state, 0, trace_arrays), parallel_iterations=parallel_iterations) # unflatten stacked_trace = tf.nest.pack_sequence_as( initial_trace, [ta.stack() for ta in trace_arrays], expand_composites=True) # Restore the static length if we know it. static_length = tf.TensorShape(None if dynamic_size else initial_size) def _merge_static_length(x): tensorshape_util.set_shape(x, static_length.concatenate(x.shape[1:])) return x stacked_trace = tf.nest.map_structure( _merge_static_length, stacked_trace, expand_composites=True) return final_state, stacked_trace
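# Small sketch of the contract described in the docstring, calling
# `trace_scan` as defined above on a running sum (values are illustrative).
import tensorflow as tf

final_state, trace = trace_scan(
    loop_fn=lambda state, elem: state + elem,
    initial_state=tf.constant(0.),
    elems=tf.constant([1., 2., 3., 4.]),
    trace_fn=lambda state: state)
# final_state == 10.; trace == [1., 3., 6., 10.]

# With a criterion, only the steps it accepts are traced:
_, partial_trace = trace_scan(
    loop_fn=lambda state, elem: state + elem,
    initial_state=tf.constant(0.),
    elems=tf.constant([1., 2., 3., 4.]),
    trace_fn=lambda state: state,
    trace_criterion_fn=lambda state: state > 4.)
# partial_trace == [6., 10.]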
def _sample_n(self, n, seed=None): power = tf.convert_to_tensor(self.power) shape = tf.concat([[n], tf.shape(power)], axis=0) seed = samplers.sanitize_seed(seed, salt='zipf') minval_u = self._hat_integral(0.5, power=power) + 1. maxval_u = self._hat_integral(dtype_util.max(tf.int64) - 0.5, power=power) def loop_body(should_continue, k, seed): """Resample the non-accepted points.""" u_seed, next_seed = samplers.split_seed(seed) # The range of U is chosen so that the resulting sample K lies in # [0, tf.int64.max). The final sample, if accepted, is K + 1. u = samplers.uniform(shape, minval=minval_u, maxval=maxval_u, dtype=power.dtype, seed=u_seed) # set_shape needed here because of b/139013403 tensorshape_util.set_shape(u, should_continue.shape) # Sample the point X from the continuous density h(x) \propto x^(-power). x = self._hat_integral_inverse(u, power=power) # Rejection-inversion requires a `hat` function, h(x) such that # \int_{k - .5}^{k + .5} h(x) dx >= pmf(k + 1) for points k in the # support. A natural hat function for us is h(x) = x^(-power). # # After sampling X from h(x), suppose it lies in the interval # (K - .5, K + .5) for integer K. Then the corresponding K is accepted if # if lies to the left of x_K, where x_K is defined by: # \int_{x_k}^{K + .5} h(x) dx = H(x_K) - H(K + .5) = pmf(K + 1), # where H(x) = \int_x^inf h(x) dx. # Solving for x_K, we find that x_K = H_inverse(H(K + .5) + pmf(K + 1)). # Or, the acceptance condition is X <= H_inverse(H(K + .5) + pmf(K + 1)). # Since X = H_inverse(U), this simplifies to U <= H(K + .5) + pmf(K + 1). # Update the non-accepted points. # Since X \in (K - .5, K + .5), the sample K is chosen as floor(X + 0.5). k = tf.where(should_continue, tf.floor(x + 0.5), k) accept = (u <= self._hat_integral(k + .5, power=power) + tf.exp(self._log_prob(k + 1, power=power))) return [should_continue & (~accept), k, next_seed] should_continue, samples, _ = tf.while_loop( cond=lambda should_continue, *ignore: tf.reduce_any(should_continue ), body=loop_body, loop_vars=[ tf.ones(shape, dtype=tf.bool), # should_continue tf.zeros(shape, dtype=power.dtype), # k seed, # seed ], maximum_iterations=self.sample_maximum_iterations, ) samples = samples + 1. if self.validate_args and dtype_util.is_integer(self.dtype): samples = distribution_util.embed_check_integer_casting_closed( samples, target_dtype=self.dtype, assert_positive=True) samples = tf.cast(samples, self.dtype) if self.validate_args: npdt = dtype_util.as_numpy_dtype(self.dtype) v = npdt( dtype_util.min(npdt) if dtype_util.is_integer(npdt) else np.nan ) samples = tf.where(should_continue, v, samples) return samples
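# Context note (a hedged sketch, since `_sample_n` is the private sampler
# behind `tfd.Zipf`): drawing samples from the public distribution exercises
# exactly this rejection-inversion loop. The power and seed are illustrative.
import tensorflow_probability as tfp

tfd = tfp.distributions
zipf = tfd.Zipf(power=2.5)          # P(k) proportional to k**-2.5, k = 1, 2, ...
zipf_samples = zipf.sample(10000, seed=7)
# Empirical frequencies should roughly follow k**-2.5 / zeta(2.5).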
def bessel_iv_ratio(v, z, name=None): """Computes `I_{v} (z) / I_{v - 1} (z)` in a numerically stable way. Let I(v, z) be the modified bessel function of the first kind. This computes the ratio of I(v, z) / I(v - 1, z). This can be more numerically stable and faster than computing the ratio directly. This uses a continued fraction approximation attributed to Gauss for computing this quantity in the limit where z <= v, and a continued fraction approximation attributed to Perron for z > v. Args: v: value for which `I_{v}(z) / I_{v - 1}(z)` should be computed. Expect v > 0. z: value for which `I_{v}(z) / I_{v - 1}(z)` should be computed. Expect z > 0. name: A name for the operation (optional). Default value: `None` (i.e., 'bessel_iv_ratio'). Returns: I(v, z) / I(v - 1, z). #### References [1]: Walter Gautschi and Josef Slavik. On the Computation of Modified Bessel Function Ratios. http://www.jstor.com/stable/2006491 """ with tf.name_scope(name or 'bessel_iv_ratio'): dtype = dtype_util.common_dtype([v, z], tf.float32) v = tf.convert_to_tensor(v, dtype=dtype) z = tf.convert_to_tensor(z, dtype=dtype) np_finfo = np.finfo(dtype_util.as_numpy_dtype(dtype)) tolerance = tf.cast(np_finfo.resolution, dtype=dtype) safe_to_use_perron = z > v def gauss_term_fn(iteration_count, v, z): """Terms for the Gauss continued fraction.""" return tf.math.square(z) / 4. / ((v + iteration_count - 1) * (v + iteration_count)) # The Gauss continued fraction converges faster for z < v. # For z > v, set z to something much less than v. safe_z_less_v = tf.where(safe_to_use_perron, v / 1000., z) # We use forward recurrence for the Gauss continued fraction. # This is so that we can do early termination. # There are a few reasons why this doesn't overflow: # * All partial numerators / denominators are positive. # * Partial numerators approach zero as 1 / n**2, where # n is the iteration count. # * All partial numerators are less than 1. # Combined with the recurrence, this ensures no overflow. # as the number of iterations -> infinity. gauss_cf = _compute_general_continued_fraction( # Use a max of 200 steps. Almost always we will be much less # than this. 200, [v, safe_z_less_v], tolerance=tolerance, partial_numerator_fn=gauss_term_fn) # Add the zeroth term for the Gauss continued fraction. gauss_cf = tf.math.reciprocal((1. + gauss_cf) * 2. * v / z) # For the Perron CF we use the backward recurrence. This is because # generally the backward recurrence is more numerically stable # than forward recurrence, especially with negative terms. # We use a flat 50 steps. Anecdotally, for z > v, convergence is # much faster than that. # The Perron continued fraction converges much faster for z >> v. # For z < v, set z to something much greater than v. safe_z_greater_v = tf.where(~safe_to_use_perron, 1000. * v, z) def perron_term_fn(iteration_count, v, z): """Terms for the Perron continued fraction.""" return -0.5 * z * (v + iteration_count - 0.5) / ( (v + z + (iteration_count - 1.) / 2.) * (v + z + iteration_count / 2.)) total_perron_iteration_count = 50 def _backward_cf_one_step(iteration_count, cf): cf = perron_term_fn(total_perron_iteration_count - iteration_count, v, safe_z_greater_v) / (1. + cf) return [iteration_count + 1., cf] # For the Perron CF, we omit the first numerator because it # has a different form. _, perron_cf = tf.while_loop( cond=lambda i, _: i < total_perron_iteration_count - 1, body=_backward_cf_one_step, # Use 50 iterations. Empirically, the Perron continued fraction # converges much faster than this. 
loop_vars=[ tf.cast(0., dtype=dtype), tf.zeros_like(safe_z_greater_v) ]) first_term = -0.5 * z * (v + 0.5) / ((v + z / 2.) * (v + z + 0.5)) perron_cf = first_term / (1. + perron_cf) # Add the zeroth term for the Perron continued fraction. perron_zeroth_term = (z + 2 * v) / z perron_cf = tf.math.reciprocal(perron_zeroth_term * (1. + perron_cf)) result = tf.where(safe_to_use_perron, perron_cf, gauss_cf) def grad(dy): """Computes the derivative of the ratio elementwise with respect to z. For shorthand, let `I(v) = I(v, z)`, `R(v) = I(v, z) / I(v - 1, z)` ``` R'(v) = (I'(v)I(v - 1) - I(v)I'(v - 1)) / I(v - 1) ** 2 = 0.5 * ((I(v - 1) + I(v + 1))I(v - 1) - I(v)( I(v) + I(v - 2))) / I(v - 1) ** 2 = 0.5 * (1. + I(v + 1) / I(v - 1) - (I(v) / I(v - 1)) ** 2 - ( I(v) / I(v - 1)) * (I(v - 2) / I(v - 1))) = 0.5 * (1. + R(v + 1) * R(v) - R(v) ** 2 - R(v) / R(v - 1)) = 0.5 * (1. + R(v) * (R(v + 1) - R(v) - 1. / R(v - 1))) ``` To avoid computing R(v - 1) when v <= 1 (which is not valid), we can rewrite `I(v - 2) = 2 (v - 1) / z * I(v - 1) + I(v)`. Thus the last term becomes: ``` -1. / R(v - 1) = -I(v - 2) / I(v - 1) = -2 (v - 1) / z - R(v) ``` Args: dy: A Tensor with type `float32` or `float64`. Returns: A Tensor with same shape and dtype as `z`. """ grad_z = 0.5 * (1. + result * (bessel_iv_ratio(v + 1., z) - 2. * result - 2. * (v - 1) / z)) * dy # We don't have an easily computable gradient with respect to v at the # moment, so ignore that for now. _, grad_z = _fix_gradient_for_broadcasting(v, z, tf.ones_like(grad_z), grad_z) return None, grad_z return result, grad
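# Quick numerical sanity check (a sketch, assuming SciPy is available):
# compare the continued-fraction ratio, exposed publicly as
# `tfp.math.bessel_iv_ratio`, against the direct ratio of SciPy's modified
# Bessel functions of the first kind.
import numpy as np
import tensorflow_probability as tfp
from scipy import special

v, z = 2.5, 1.3
ratio_cf = tfp.math.bessel_iv_ratio(v, z)             # continued-fraction form
ratio_ref = special.iv(v, z) / special.iv(v - 1., z)  # direct ratio
np.testing.assert_allclose(ratio_cf, ratio_ref, rtol=1e-5)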
def _sample_multinomial_as_iterated_binomial( num_samples, num_classes, probs, num_trials, dtype, seed): """Sample a multinomial by drawing one binomial sample per class. The batch shape is given by broadcasting num_trials with remove_last_dimension(probs). The loop over binomial samples is a `tf.while_loop`, thus supporting a dynamic number of classes. Args: num_samples: Singleton integer Tensor: number of multinomial samples to draw. num_classes: Singleton integer Tensor: number of classes. probs: Floating Tensor with last dimension `num_classes`, of normalized probabilities per class. num_trials: Tensor of number of categorical trials each multinomial consists of. num_trials[..., tf.newaxis] must broadcast with probs. dtype: dtype at which to emit samples. seed: Random seed. Returns: samples: Tensor of given dtype and shape [num_samples] + batch_shape + [num_classes]. """ with tf.name_scope('draw_sample'): # `convert_to_tensor(num_classes) here to avoid unstacking inside # `split_seed`. We can't take advantage of the Python-list code path anyway # because the index at which we will take the seed is a Tensor. seeds = samplers.split_seed( seed, n=tf.convert_to_tensor(num_classes), salt='multinomial_draw_sample') def fn(i, num_trials, consumed_prob, accum): """Sample the counts for one class using binomial.""" probs_here = tf.gather(probs, i, axis=-1) binomial_probs = tf.clip_by_value(probs_here / (1. - consumed_prob), 0, 1) seed_here = tf.gather(seeds, i, axis=0) binom = binomial.Binomial(total_count=num_trials, probs=binomial_probs) # Not passing `num_samples` to `binom.sample`, as it's is already in # `num_trials.shape`. sample = binom.sample(seed=seed_here) accum = accum.write(i, tf.cast(sample, dtype=dtype)) return i + 1, num_trials - sample, consumed_prob + probs_here, accum num_trials = tf.cast(num_trials, probs.dtype) # Pre-broadcast with probs num_trials = num_trials + tf.zeros_like(probs[..., 0]) # Pre-enlarge for different output samples num_trials = _replicate_along_left(num_trials, num_samples) i = tf.constant(0) consumed_prob = tf.zeros_like(probs[..., 0]) accum = tf.TensorArray( dtype, size=num_classes, element_shape=num_trials.shape) _, num_trials_left, _, accum = tf.while_loop( cond=lambda index, _0, _1, _2: tf.less(index, num_classes - 1), body=fn, loop_vars=(i, num_trials, consumed_prob, accum)) # Force the last iteration to put all the trials into the last bucket, # because probs[..., -1] / (1. - consumed_prob) might numerically not be 1. # Also saves one iteration around the while_loop and one run of the binomial # sampler. accum = accum.write(num_classes - 1, tf.cast(num_trials_left, dtype=dtype)) # This stop_gradient is necessary to prevent spurious zero gradients coming # from b/138796859, and a spurious gradient through num_trials_left. results = tf.stop_gradient(accum.stack()) return distribution_util.move_dimension(results, 0, -1)
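# The chained-binomial decomposition in the docstring, as a self-contained
# NumPy sketch (the helper name and toy probabilities are illustrative):
# each class draws a binomial out of the trials not yet consumed, with its
# probability renormalized by the remaining probability mass.
import numpy as np

def multinomial_via_binomials(num_trials, probs, rng):
  remaining, consumed, counts = num_trials, 0.0, []
  for p in probs[:-1]:
    cond_p = np.clip(p / (1. - consumed), 0., 1.)  # renormalize remaining mass
    c = rng.binomial(remaining, cond_p)
    counts.append(c)
    remaining -= c
    consumed += p
  counts.append(remaining)  # last class takes all remaining trials
  return np.array(counts)

counts = multinomial_via_binomials(100, [0.2, 0.3, 0.5], np.random.default_rng(0))
assert counts.sum() == 100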
def make_convolution_transpose_fn_with_subkernels(filter_shape, strides, padding, rank=2, dilations=None, dtype=tf.int32, validate_args=False, name=None): """Like `tf.nn.conv2d` except applies batch of kernels to batch of `x`.""" with tf.name_scope(name or 'make_convolution_transpose_fn_with_dilation'): # Enable v2 control flow to avoid None gradients through TensorArray. tf.compat.v1.enable_control_flow_v2() if tf.get_static_value(rank) != 2: raise NotImplementedError( 'Argument `rank` currently only supports `2`; ' 'saw "{}".'.format(rank)) [ filter_shape, rank, strides, padding, dilations, ] = prepare_conv_args(filter_shape, rank=rank, strides=strides, padding=padding, dilations=dilations, is_transpose=True, validate_args=validate_args) sh, sw = strides fh, fw = filter_shape dh, dw = dilations # Determine maximum filter height and filter width of sub-kernels. sub_fh = (fh - 1) // sh + 1 sub_fw = (fw - 1) // sw + 1 def loop_body(i_, kernels_ind): i = i_ // sw j = i_ % sw i_ind = ps.range(i * fw, ps.maximum(i, fh) * fw, delta=sh * fw, dtype=dtype) j_ind = ps.range(j, ps.maximum(j, fw), delta=sw, dtype=dtype) last_j = sw - (fw - j - 1) % sw - 1 last_i = sh - (fh - i - 1) % sh - 1 pos = last_i * sw + last_j nc = cartesian_add([i_ind, j_ind]) kernels_ind = kernels_ind.write( pos, ps.reverse(ps.reverse(nc, [0]), [1])) return i_ + 1, kernels_ind kernels_ind = tf.TensorArray(dtype=dtype, infer_shape=False, size=sh * sw) _, kernels_ind = tf.while_loop(lambda i, _: i < sh * sw, loop_body, [0, kernels_ind]) tot_pad_top, tot_pad_bottom = _get_transpose_conv_dilated_padding( fh, stride=sh, dilation=dh, padding=padding) tot_pad_left, tot_pad_right = _get_transpose_conv_dilated_padding( fw, stride=sw, dilation=dw, padding=padding) pad_bottom = (tot_pad_bottom - 1) // sh + 1 pad_top = (tot_pad_top - 1) // sh + 1 pad_right = (tot_pad_right - 1) // sw + 1 pad_left = (tot_pad_left - 1) // sw + 1 padding_vals = ((pad_top, pad_bottom), (pad_left, pad_right)) truncate_top = pad_top * sh - tot_pad_top truncate_left = pad_left * sw - tot_pad_left def op(x, kernel): input_dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32) x = tf.convert_to_tensor(x, dtype=input_dtype, name='x') kernel = tf.convert_to_tensor(kernel, dtype=input_dtype, name='kernel') batch_shape, event_shape = ps.split(ps.shape(x), num_or_size_splits=[-1, 3]) xh, xw, c_in = ps.unstack(event_shape, num=3) kernel_shape = ps.shape(kernel) c_out = kernel_shape[-1] kernel_batch = kernel_shape[:-2] assertions = _maybe_validate_input_shapes( kernel_shape, channels_in=c_in, filter_height=fh, filter_width=fw, validate_args=validate_args) with tf.control_dependencies(assertions): # If the kernel does not have batch shape, fall back to # `conv2d_transpose` (unless dilations > 1, which is not implemented in # `conv2d_transpose`). 
if (tf.get_static_value(ps.rank(kernel)) == 2 and all(d == 1 for d in dilations)): return _call_conv2d_transpose(x, kernel, filter_shape, strides, padding, dilations, c_out, batch_shape, event_shape) n = ps.maximum(0, ps.rank(x) - 3) paddings = ps.pad(padding_vals, paddings=[[n, 1], [0, 0]], constant_values=0) x_pad = tf.pad(x, paddings=paddings, constant_values=0) ex_h = xh + tf.reduce_sum(padding_vals[0]) - sub_fh + 1 ex_w = xw + tf.reduce_sum(padding_vals[1]) - sub_fw + 1 def loop_body(i, outputs): subkernel_ind = kernels_ind.read(i) fh_, fw_ = ps.unstack(ps.shape(subkernel_ind), num=2) eh = ex_h + fh_ - 1 ew = ex_w + fw_ - 1 subkernel_ind = ps.reshape(ps.reshape( subkernel_ind * c_in, shape=[-1])[:, tf.newaxis] + ps.range(c_in), shape=[-1]) k = tf.gather(kernel, subkernel_ind, axis=-2) ind, shape = im2row_index([eh, ew, c_in], block_shape=(fh_, fw_), slice_step=(1, 1), dilations=dilations) x_i = x_pad[..., :eh, :ew, :] x_i_shape = ps.shape(x_i) flat_shape = ps.pad(x_i_shape[:-3], paddings=[[0, 1]], constant_values=-1) flat_x = tf.reshape(x_i, flat_shape) x_ = tf.gather(flat_x, ind, axis=-1) im_x = tf.reshape( x_, ps.concat([x_i_shape[:-3], shape], axis=0)) outputs = outputs.write( i, tf.matmul( im_x, tf.reshape( k, ps.concat([ kernel_batch, [1, fh_ * fw_ * c_in, c_out] ], axis=0)))) return i + 1, outputs outputs = tf.TensorArray(dtype=input_dtype, size=sh * sw) _, outputs = tf.while_loop(lambda i, _: i < sh * sw, loop_body, [0, outputs]) y = outputs.concat() m = tf.reduce_prod(ps.shape(y)[:-3]) y_ = tf.reshape(y, shape=ps.concat([[m], ps.shape(y)[-3:]], axis=0)) y2 = tf.batch_to_space(y_, strides, crops=tf.zeros([2, 2], dtype=tf.int64)) broadcast_batch_shape = ps.broadcast_shape( batch_shape, kernel_batch) y2 = tf.reshape( y2, ps.concat([broadcast_batch_shape, ps.shape(y2)[-3:]], axis=0)) out_height = _deconv_output_length(xh, filter_size=fh, padding=padding, output_padding=None, stride=sh, dilation=dh) out_width = _deconv_output_length(xw, filter_size=fw, padding=padding, output_padding=None, stride=sw, dilation=dw) return y2[..., truncate_top:truncate_top + out_height, truncate_left:truncate_left + out_width, :] return op
def _solve( self, ode_fn, initial_time, initial_state, solution_times, jacobian_fn=None, jacobian_sparsity=None, batch_ndims=None, previous_solver_internal_state=None, ): # Static assertions del jacobian_fn, jacobian_sparsity # not used by DormandPrince if batch_ndims is not None and batch_ndims != 0: raise NotImplementedError( 'For homogeneous batching use `batch_ndims=0`.') solution_times_by_solver = isinstance(solution_times, base.ChosenBySolver) with tf.name_scope(self._name): # (2) Convert to tensors, determined dtypes. get_dtype = lambda x: x.dtype error_if_wrong_dtype = functools.partial( util.error_if_not_real_or_complex, identifier='initial_state') initial_state = tf.nest.map_structure(tf.convert_to_tensor, initial_state) tf.nest.map_structure(error_if_wrong_dtype, initial_state) state_dtypes = tf.nest.map_structure(get_dtype, initial_state) common_state_dtype = dtype_util.common_dtype(initial_state) real_dtype = dtype_util.real_dtype(common_state_dtype) initial_time = tf.cast(initial_time, real_dtype) max_num_steps = self._max_num_steps max_ode_fn_evals = self._max_num_steps if max_num_steps is not None: max_num_steps = tf.convert_to_tensor(max_num_steps, dtype=tf.int32) max_ode_fn_evals = max_num_steps * self.ODE_FN_EVALS_PER_STEP step_size = tf.convert_to_tensor(self._first_step_size, dtype=real_dtype) rtol = tf.convert_to_tensor(tf.cast(self._rtol, real_dtype)) atol = tf.convert_to_tensor(tf.cast(self._atol, real_dtype)) safety = tf.convert_to_tensor(self._safety_factor, dtype=real_dtype) # Use i(d)factor notation for increasing and decreasing factors. ifactor, dfactor = self._max_step_size_factor, self._min_step_size_factor ifactor = tf.convert_to_tensor(ifactor, dtype=real_dtype) dfactor = tf.convert_to_tensor(dfactor, dtype=real_dtype) solver_internal_state = previous_solver_internal_state if solver_internal_state is None: initial_derivative = ode_fn(initial_time, initial_state) initial_derivative = tf.nest.map_structure( tf.convert_to_tensor, initial_derivative) solver_internal_state = _RungeKuttaSolverInternalState( current_state=initial_state, current_derivative=initial_derivative, last_step_start=initial_time, current_time=initial_time, step_size=step_size, interpolating_coefficients=[initial_state] * self.ORDER) num_solution_times = 0 if solution_times_by_solver: final_time = tf.cast(solution_times.final_time, real_dtype) times_array = tf.TensorArray(real_dtype, size=num_solution_times, dynamic_size=True, element_shape=tf.TensorShape([])) else: solution_times = tf.cast(solution_times, real_dtype) util.error_if_not_vector(solution_times, 'solution_times') num_solution_times = tf.size(solution_times) times_array = tf.TensorArray( real_dtype, size=num_solution_times, dynamic_size=False, element_shape=[]).unstack(solution_times) solutions_arrays = [ tf.TensorArray(dtype=component_dtype, size=num_solution_times, dynamic_size=solution_times_by_solver) for component_dtype in tf.nest.flatten(state_dtypes) ] solutions_arrays = tf.nest.pack_sequence_as( initial_state, solutions_arrays) rk_step = functools.partial(self._step, max_ode_fn_evals=max_ode_fn_evals, ode_fn=ode_fn, atol=atol, rtol=rtol, safety=safety, ifactor=ifactor, dfactor=dfactor) advance_to_solution_time = functools.partial( _advance_to_solution_time, times_array=solution_times, step_fn=rk_step, validate_args=self._validate_args) assert_ops = self._assert_ops( ode_fn=ode_fn, initial_time=initial_time, initial_state=initial_state, solution_times=solution_times, previous_solver_state=previous_solver_internal_state, 
rtol=rtol, atol=atol, first_step_size=step_size, safety_factor=safety, min_step_size_factor=ifactor, max_step_size_factor=dfactor, max_num_steps=max_num_steps, solution_times_by_solver=solution_times_by_solver) with tf.control_dependencies(assert_ops): ode_evals_by_now = 1 if self._validate_args else 0 ode_evals_by_now += 1 if solver_internal_state is None else 0 diagnostics = _DopriDiagnostics( num_ode_fn_evaluations=ode_evals_by_now, num_jacobian_evaluations=0, num_matrix_factorizations=0, status=0) if solution_times_by_solver: r = _dense_solutions_to_final_time( final_time=final_time, solver_state=solver_internal_state, diagnostics=diagnostics, step_fn=rk_step, ode_fn=ode_fn, times_array=times_array, solutions_arrays=solutions_arrays, validate_args=self._validate_args) solver_internal_state, diagnostics, times_array, solutions_arrays = r else: def iterate_cond(time_id, *_): return time_id < num_solution_times [_, solver_internal_state, diagnostics, solutions_arrays ] = tf.while_loop(iterate_cond, advance_to_solution_time, [ 0, solver_internal_state, diagnostics, solutions_arrays ], back_prop=False) times = times_array.stack() stack_components = lambda x: x.stack() states = tf.nest.map_structure(stack_components, solutions_arrays) return base.Results( times=times, states=states, diagnostics=diagnostics, solver_internal_state=solver_internal_state)
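# Usage sketch (not part of the original source) of the public solver that
# calls this `_solve`; the exponential-decay ODE and tolerances are
# illustrative.
import tensorflow as tf
import tensorflow_probability as tfp

solver = tfp.math.ode.DormandPrince(rtol=1e-4, atol=1e-6)
results = solver.solve(
    ode_fn=lambda t, y: -y,                      # dy/dt = -y
    initial_time=0.,
    initial_state=tf.constant(1., dtype=tf.float64),
    solution_times=[0.5, 1.0, 2.0])
# results.states is approximately exp(-t) at the requested times;
# results.diagnostics reports ODE function evaluations, status, etc.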
def _dense_solutions_to_final_time(final_time, solver_state, diagnostics, step_fn, ode_fn, times_array, solutions_arrays, validate_args=False): """Integrates `solver_state` to `final_time`. Performs integration of the `solver_state` to `final_time` while saving solutions at all intermediate time steps. This corresponds to the expected behavior of the `ChosenBySolver` option. The solution at `final_time` is obtained by interpolation and is set as a final state of the solver. Args: final_time: Floating `Tensor` representing the final time of integration. solver_state: `_DopriSolverInternalState` - initial solver state. diagnostics: `_DopriDiagnostics` - info on the current `_solve` call. step_fn: Partial `Dopri._step` method that performs a single step updating the `solver_state` and `diagnostics`. ode_fn: Callable(t, y) -> dy_dt. times_array: `TensorArray` where time values are recorded. solutions_arrays: `TensorArray`s where solutions are recorded. validate_args: Python `bool` indicating whether to validate inputs. Default value: False. Returns: solver_state: `_RungeKuttaSolverInternalState` holding final solver state. diagnostics: `_DopriDiagnostics` holding diagnostic values. times_array: `TensorArray` with recorded solution times. solutions_arrays: `TensorArray`s with solution values at times corresponding to `times_array`. """ def step_and_record(solver_state, diagnostics, solutions_arrays, times_array): y = solver_state.current_state time_id = times_array.size() solutions_arrays = _write_solution_components(y, solutions_arrays, time_id) times_array = times_array.write(time_id, solver_state.current_time) solver_state, diagnostics = step_fn(solver_state, diagnostics) return (solver_state, diagnostics, solutions_arrays, times_array) def step_cond(solver_internal_state, *_): return solver_internal_state.current_time <= final_time [solver_state, diagnostics, solutions_arrays, times_array] = tf.while_loop( step_cond, step_and_record, [solver_state, diagnostics, solutions_arrays, times_array], back_prop=False) # Interpolate at the final time point, update the state, and write results. y, coefficients = _interpolate_solution_at(final_time, solver_state, validate_args) dy_dt = ode_fn(final_time, y) dy_dt = tf.nest.map_structure(tf.convert_to_tensor, dy_dt) time_id = times_array.size() times_array = times_array.write(time_id, final_time) solutions_arrays = _write_solution_components(y, solutions_arrays, time_id) solver_state = _RungeKuttaSolverInternalState( current_state=y, current_derivative=dy_dt, last_step_start=solver_state.last_step_start, current_time=final_time, step_size=solver_state.step_size, interpolating_coefficients=coefficients) return solver_state, diagnostics, times_array, solutions_arrays
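# This dense-output path is reached when the caller lets the solver choose
# its own time grid. A minimal sketch, reusing the toy ODE and imports from
# the Dormand-Prince sketch above:
dense_results = tfp.math.ode.DormandPrince().solve(
    ode_fn=lambda t, y: -y,
    initial_time=0.,
    initial_state=tf.constant(1., dtype=tf.float64),
    solution_times=tfp.math.ode.ChosenBySolver(final_time=2.))
# dense_results.times contains every accepted step plus the interpolated
# solution at final_time, as produced by `_dense_solutions_to_final_time`.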
def iterative_mergesort(y, permutation, name=None): """Non-recursive mergesort that counts exchanges. Args: y: a `Tensor` of shape `[n]` containing values to be sorted. permutation: `Tensor` of shape `[n]` with original ordering. name: Optional Python `str` name for ops created by this method. Default value: `None` (i.e., 'iterative_mergesort'). Returns: exchanges: `int32` scalar that counts the number of exchanges required to produce a sorted permutation. permutation: a `tf.int32` `Tensor` containing the ordering of `y` values that sorts them. """ with tf.name_scope(name or 'iterative_mergesort'): y = tf.convert_to_tensor(y, name='y') permutation = tf.convert_to_tensor( permutation, name='permutation', dtype=tf.int32) shape = permutation.shape tensorshape_util.assert_is_compatible_with(y.shape, shape) n = ps.size(y) def outer_body(k, exchanges, permutation): # The outer body progressively merges lists as k grows by powers of 2, # tracking the total swaps required in exchanges as the new permutation is # built in place. y_ordered = tf.gather(y, permutation) def middle_body(left, exchanges, permutation): # The middle body advances through the sublists of size k, advancing # the left edge until the end of the input is reached. right = left + k end = tf.minimum(right + k, n) # See explanation here # https://www.geeksforgeeks.org/counting-inversions/. def inner_body(i, j, x, np, p): # The [left, right) and [right, end) lists are merge-sorted, with # i and j tracking the advance through each range. x records the # number of order (bubble-sort equivalent) swaps that are happening # with each insertion, and np represents the size of the output # permutation that's been filled in using the p tensor. y_less = y_ordered[i] <= y_ordered[j] element = tf.where(y_less, [permutation[i]], [permutation[j]]) new_p = tf.concat([p[0:np], element, p[np + 1:n]], axis=0) tensorshape_util.set_shape(new_p, p.shape) return (tf.where(y_less, i + 1, i), tf.where(y_less, j, j + 1), tf.where(y_less, x, x + right - i), np + 1, new_p) i_j_x_np_p = (left, right, exchanges, 0, tf.zeros([n], dtype=tf.int32)) (i, j, exchanges, np, p) = tf.while_loop( cond=lambda i, j, x, np, p: tf.math.logical_and(i < right, j < end), body=inner_body, loop_vars=i_j_x_np_p) permutation = tf.concat([ permutation[0:left], p[0:np], permutation[i:right], permutation[j:end], permutation[end:n] ], axis=0) tensorshape_util.set_shape(permutation, shape) return left + 2 * k, exchanges, permutation _, exchanges, permutation = tf.while_loop( cond=lambda left, exchanges, permutation: left < n - k, body=middle_body, loop_vars=(0, exchanges, permutation)) k *= 2 return k, exchanges, permutation _, exchanges, permutation = tf.while_loop( cond=lambda k, exchanges, permutation: k < n, body=outer_body, loop_vars=(1, 0, permutation)) return exchanges, permutation
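# Tiny worked example (illustrative) of the exchange count, which is an
# inversion count, using `iterative_mergesort` as defined above.
import tensorflow as tf

# y = [3, 1, 2] has two inversions: (3, 1) and (3, 2).
exchanges, perm = iterative_mergesort(
    tf.constant([3., 1., 2.]), tf.constant([0, 1, 2], dtype=tf.int32))
# exchanges == 2; perm == [1, 2, 0], i.e. y gathered by perm is ascending.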
def grad_fn(*dresults, **kwargs): """Adjoint sensitivity method to compute gradients.""" dresults = tf.nest.pack_sequence_as(results, dresults) dstates = dresults.states # The signature grad_fn(*dresults, variables=None) is not valid Python 2 # so use kwargs instead. variables = kwargs.pop('variables', []) assert not kwargs # This assert should never fail. # TODO(b/138304303): Support complex types. with tf.name_scope('{}Gradients'.format(self._name)): get_dtype = lambda x: x.dtype def error_if_complex(dtype): if dtype.is_complex: raise NotImplementedError( 'The adjoint sensitivity method does ' 'not support complex dtypes.') state_dtypes = tf.nest.map_structure( get_dtype, initial_state) tf.nest.map_structure(error_if_complex, state_dtypes) common_state_dtype = dtype_util.common_dtype(initial_state) real_dtype = dtype_util.real_dtype(common_state_dtype) # We add initial_time to ensure that we know where to stop. result_times = tf.concat( [[tf.cast(initial_time, real_dtype)], results.times], 0) num_result_times = tf.size(result_times) # First two components correspond to reverse and adjoint states. # the last component is adjoint state for variables. terminal_augmented_state = tuple([ rk_util.nest_constant(initial_state, 0.0), rk_util.nest_constant(initial_state, 0.0), tuple( rk_util.nest_constant(variable, 0.0) for variable in variables) ]) # The XLA compiler does not compile code which slices/indexes using # integer `Tensor`s. `TensorArray`s are used to get around this. result_time_array = tf.TensorArray( results.times.dtype, clear_after_read=False, size=num_result_times, element_shape=[]).unstack(result_times) # TensorArray shape should not include time dimension, hence shape[1:] result_state_arrays = [ tf.TensorArray( # pylint: disable=g-complex-comprehension dtype=component.dtype, size=num_result_times - 1, element_shape=component.shape[1:]).unstack( component) for component in tf.nest.flatten(results.states) ] result_state_arrays = tf.nest.pack_sequence_as( results.states, result_state_arrays) dresult_state_arrays = [ tf.TensorArray( # pylint: disable=g-complex-comprehension dtype=component.dtype, size=num_result_times - 1, element_shape=component.shape[1:]).unstack( component) for component in tf.nest.flatten(dstates) ] dresult_state_arrays = tf.nest.pack_sequence_as( results.states, dresult_state_arrays) def augmented_ode_fn(backward_time, augmented_state): """Dynamics function for the augmented system. Describes a differential equation that evolves the augmented state backwards in time to compute gradients using the adjoint method. Augmented state consists of 3 components `(state, adjoint_state, vars)` all evaluated at time `backward_time`: state: represents the solution of user provided `ode_fn`. The structure coincides with the `initial_state`. adjoint_state: represents the solution of adjoint sensitivity differential equation as discussed below. Has the same structure and shape as `state`. vars: represent the solution of the adjoint equation for variable gradients. Represented as a `Tuple(Tensor, ...)` with as many tensors as there are `variables`. Adjoint sensitivity equation describes the gradient of the solution with respect to the value of the solution at previous time t. 
Its dynamics are given by d/dt[adj(t)] = -1 * adj(t) @ jacobian(ode_fn(t, z), z) Which is computed as: d/dt[adj(t)]_i = -1 * sum_j(adj(t)_j * d/dz_i[ode_fn(t, z)_j)] d/dt[adj(t)]_i = -1 * d/dz_i[sum_j(no_grad_adj_j * ode_fn(t, z)_j)] where in the last line we moved adj(t)_j under derivative by removing gradient from it. Adjoint equation for the gradient with respect to every `tf.Variable` theta follows: d/dt[grad_theta(t)] = -1 * adj(t) @ jacobian(ode_fn(t, z), theta) = -1 * d/d theta_i[sum_j(no_grad_adj_j * ode_fn(t, z)_j)] Args: backward_time: Floating `Tensor` representing current time. augmented_state: `Tuple(state, adjoint_state, variable_grads)` Returns: negative_derivatives: Structure of `Tensor`s equal to backwards time derivative of the `state` componnent. adjoint_ode: Structure of `Tensor`s equal to backwards time derivative of the `adjoint_state` component. adjoint_variables_ode: Structure of `Tensor`s equal to backwards time derivative of the `vars` component. """ # The negative signs disappears after the change of variables. # The ODE solver cannot handle the case initial_time > final_time # and hence a change of variables backward_time = -time is used. time = -backward_time state, adjoint_state, _ = augmented_state with tf.GradientTape() as tape: tape.watch(variables) tape.watch(state) derivatives = ode_fn(time, state) adjoint_no_grad = tf.nest.map_structure( tf.stop_gradient, adjoint_state) negative_derivatives = rk_util.weighted_sum( [-1.0], [derivatives]) def dot_prod(tensor_a, tensor_b): return tf.reduce_sum(tensor_a * tensor_b) # See docstring for details. adjoint_dot_derivatives = tf.nest.map_structure( dot_prod, adjoint_no_grad, derivatives) adjoint_dot_derivatives = tf.squeeze( tf.add_n( tf.nest.flatten(adjoint_dot_derivatives))) adjoint_ode, adjoint_variables_ode = tape.gradient( adjoint_dot_derivatives, (state, tuple(variables)), unconnected_gradients=tf.UnconnectedGradients.ZERO) return negative_derivatives, adjoint_ode, adjoint_variables_ode def reverse_to_result_time(n, augmented_state, _): """Integrates the augmented system backwards in time.""" lower_bound_of_integration = result_time_array.read(n) upper_bound_of_integration = result_time_array.read(n - 1) _, adjoint_state, adjoint_variable_state = augmented_state initial_state = _read_solution_components( result_state_arrays, input_state_structure, n - 1) initial_adjoint = _read_solution_components( dresult_state_arrays, input_state_structure, n - 1) initial_adjoint_state = rk_util.weighted_sum( [1.0, 1.0], [adjoint_state, initial_adjoint]) initial_augmented_state = (initial_state, initial_adjoint_state, adjoint_variable_state) augmented_results = self._solve( ode_fn=augmented_ode_fn, initial_time=-lower_bound_of_integration, initial_state=initial_augmented_state, solution_times=[-upper_bound_of_integration], batch_ndims=batch_ndims) # Results added an extra time dim of size 1, squeeze it. select_result = lambda x: tf.squeeze(x, [0]) result_state = augmented_results.states result_state = tf.nest.map_structure( select_result, result_state) status = augmented_results.diagnostics.status return n - 1, result_state, status _, augmented_state, _ = tf.while_loop( lambda n, _, status: (n >= 1) & tf.equal(status, 0), reverse_to_result_time, (num_result_times - 1, terminal_augmented_state, 0), ) _, adjoint_state, adjoint_variables = augmented_state return adjoint_state, list(adjoint_variables)
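# What this adjoint machinery buys the user, as a short sketch (the decay-rate
# variable and target below are illustrative): gradients of the ODE solution
# with respect to `tf.Variable`s used inside `ode_fn` flow through
# `tf.GradientTape` via the adjoint pass above, rather than by backpropagating
# through every internal solver step.
import tensorflow as tf
import tensorflow_probability as tfp

a = tf.Variable(2.0, dtype=tf.float64)  # decay rate; y(t) = exp(-a * t)
with tf.GradientTape() as tape:
  results = tfp.math.ode.DormandPrince().solve(
      ode_fn=lambda t, y: -a * y,
      initial_time=0.,
      initial_state=tf.constant(1., dtype=tf.float64),
      solution_times=[1.0])
  y_at_1 = results.states[0]
grad_a = tape.gradient(y_at_1, a)
# Analytically: d/da exp(-a) = -exp(-a), roughly -0.135 at a = 2.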
def kendalls_tau(y_true, y_pred, name=None): """Computes Kendall's Tau for two ordered lists. Kendall's Tau measures the correlation between ordinal rankings. This implementation is similar to the one used in scipy.stats.kendalltau. The provided values may be of any type that is sortable, with the argsort indices indicating the true or proposed ordinal sequence. Args: y_true: a `Tensor` of shape `[n]` containing the true ordinal ranking. y_pred: a `Tensor` of shape `[n]` containing the predicted ordering of the same N items. name: Optional Python `str` name for ops created by this method. Default value: `None` (i.e., 'kendalls_tau'). Returns: kendalls_tau: Kendall's Tau, the 1945 tau-b formulation that ignores ordering of ties, as a `float32` scalar Tensor. """ with tf.name_scope(name or 'kendalls_tau'): in_type = dtype_util.common_dtype([y_true, y_pred], dtype_hint=tf.float32) y_true = tf.convert_to_tensor(y_true, name='y_true', dtype=in_type) y_pred = tf.convert_to_tensor(y_pred, name='y_pred', dtype=in_type) tensorshape_util.assert_is_compatible_with(y_true.shape, y_pred.shape) assertions = [ assert_util.assert_rank(y_true, 1), assert_util.assert_greater( ps.size(y_true), 1, 'Ordering requires at least 2 elements.') ] with tf.control_dependencies(assertions): lexa = lexicographical_indirect_sort(y_true, y_pred) # See A Computer Method for Calculating Kendall's Tau with Ungrouped Data # by William Night, Journal of the American Statistical Association, # Jun., 1966, Vol. 61, No. 314, Part 1 (Jun., 1966), pp. 436-439 # for notation https://www.jstor.org/stable/2282833 def jointly_tied_pairs_body(first, t, i): not_equal = tf.math.logical_or( tf.not_equal(y_true[lexa[first]], y_true[lexa[i]]), tf.not_equal(y_pred[lexa[first]], y_pred[lexa[i]])) return (tf.where(not_equal, i, first), tf.where(not_equal, t + ((i - first) * (i - first - 1)) // 2, t), i + 1) n = ps.size0(y_true) first, t, _ = tf.while_loop( cond=lambda first, t, i: i < n, body=jointly_tied_pairs_body, loop_vars=(0, 0, 1)) t += ((n - first) * (n - first - 1)) // 2 def ties_y_true_body(first, v, i): not_equal = tf.not_equal(y_true[lexa[first]], y_true[lexa[i]]) return (tf.where(not_equal, i, first), tf.where(not_equal, v + ((i - first) * (i - first - 1)) // 2, v), i + 1) first, v, _ = tf.while_loop( cond=lambda first, v, i: i < n, body=ties_y_true_body, loop_vars=(0, 0, 1)) v += ((n - first) * (n - first - 1)) // 2 # count exchanges exchanges, newperm = iterative_mergesort(y_pred, lexa) def ties_in_y_pred_body(first, u, i): not_equal = tf.not_equal(y_pred[newperm[first]], y_pred[newperm[i]]) return (tf.where(not_equal, i, first), tf.where(not_equal, u + ((i - first) * (i - first - 1)) // 2, u), i + 1) first, u, _ = tf.while_loop( cond=lambda first, u, i: i < n, body=ties_in_y_pred_body, loop_vars=(0, 0, 1)) u += ((n - first) * (n - first - 1)) // 2 n0 = (n * (n - 1)) // 2 assertions = [ assert_util.assert_less(v, tf.cast(n0, tf.int32), 'All ranks are ties for y_true.'), assert_util.assert_less(u, tf.cast(n0, tf.int32), 'All ranks are ties for y_pred.') ] with tf.control_dependencies(assertions): return (tf.cast(n0 - (u + v - t), tf.float32) - 2.0 * tf.cast(exchanges, tf.float32)) / tf.math.sqrt( tf.cast(n0 - v, tf.float32) * tf.cast(n0 - u, tf.float32))
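# Quick agreement check against SciPy's tau-b (a sketch; the inputs are
# illustrative, SciPy is assumed available, and it calls the `kendalls_tau`
# defined above).
import numpy as np
from scipy import stats

y_true = np.array([1., 2., 3., 4., 5.])
y_pred = np.array([2., 1., 3., 5., 4.])
tau = kendalls_tau(y_true, y_pred)            # -> 0.6 for this example
tau_ref, _ = stats.kendalltau(y_true, y_pred)
np.testing.assert_allclose(tau, tau_ref, rtol=1e-6)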
def _sample_with_shrinkage(x_initial, target_log_prob, log_slice_heights, step_size, lower_bounds, upper_bounds, seed, name=None): """Samples from the slice by applying shrinkage for rejected points. Implements the one-dimensional slice sampling algorithm of Neal (2003), with a doubling algorithm (Neal 2003 P715 Fig. 4), which doubles the size of the interval at each iteration, and shrinkage (Neal 2003 P716 Fig. 5), which reduces the width of the slice when a selected point is rejected by setting the relevant bound to that value. Randomly sampled points are checked for two criteria: that they lie within the slice and that they pass the acceptability check (Neal 2003 P717 Fig. 6), which tests that the new state could have generated the previous one. Args: x_initial: A tensor of any shape. The initial positions of the chains. This function assumes that all the dimensions of `x_initial` are batch dimensions (i.e. the event shape is `[]`). target_log_prob: Callable accepting a tensor like `x_initial` and returning a tensor containing the log density at that point of the same shape. log_slice_heights: Tensor of the same shape and dtype as the return value of `target_log_prob` when applied to `x_initial`. The log of the height of the chosen slice. step_size: A tensor of shape and dtype compatible with `x_initial`. The min interval size in the doubling algorithm. lower_bounds: Tensor of same shape and dtype as `x_initial`. Slice lower bounds for each chain. upper_bounds: Tensor of same shape and dtype as `x_initial`. Slice upper bounds for each chain. seed: Tensor seed pair. The random seed. name: Python `str` name prefixed to Ops created by this function. Default value: `None` (i.e., 'sample_with_shrinkage'). Returns: x_proposed: A tensor of the same shape and dtype as `x_initial`. The next proposed state of the chain. """ with tf.name_scope(name or 'sample_with_shrinkage'): seed = samplers.sanitize_seed(seed) # Keeps track of whether an acceptable sample has been found for the chain. found = tf.zeros_like(x_initial, dtype=tf.bool) cond = lambda found, *ignored_args: ~tf.reduce_all(found) x_next = tf.identity(x_initial) x_initial_shape = ps.shape(x_initial) x_initial_dtype = dtype_util.base_dtype(x_initial.dtype) def _body(found, seed, left, right, x_next): """Iterates until every chain has found a suitable next state.""" proportions_seed, next_seed = samplers.split_seed(seed) proportions = samplers.uniform(x_initial_shape, dtype=x_initial_dtype, seed=proportions_seed) x_proposed = tf.where(~found, left + proportions * (right - left), x_next) accept_res = _test_acceptance(x_initial, target_log_prob=target_log_prob, decided=found, log_slice_heights=log_slice_heights, x_proposed=x_proposed, step_size=step_size, lower_bounds=left, upper_bounds=right) boundary_test = log_slice_heights < target_log_prob(x_proposed) can_accept = boundary_test & accept_res next_found = found | can_accept # Note that it might seem that we are moving the left and right end points # even if the point has been accepted (which is contrary to the stated # algorithm in Neal). However, this is fine because the endpoints of points # that have already been accepted are not used again. 
next_left = tf.where(x_proposed < x_initial, x_proposed, left) next_right = tf.where(x_proposed >= x_initial, x_proposed, right) return (next_found, next_seed, next_left, next_right, x_proposed) return tf.while_loop(cond=cond, body=_body, loop_vars=(found, seed, lower_bounds, upper_bounds, x_next))[-1]
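This shrinkage helper is internal; end users would normally reach it through the public slice sampling kernel. A hedged usage sketch, assuming `tfp.mcmc.SliceSampler` and `tfp.mcmc.sample_chain` accept the arguments shown:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
target = tfd.Normal(loc=0., scale=1.)
samples = tfp.mcmc.sample_chain(
    num_results=500,
    current_state=tf.zeros([8]),              # 8 independent chains
    kernel=tfp.mcmc.SliceSampler(
        target_log_prob_fn=target.log_prob,
        step_size=1.0,                        # min interval size for doubling
        max_doublings=5),
    trace_fn=None,
    seed=[1, 2])
```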
def _test_acceptance(x_initial, target_log_prob, decided, log_slice_heights, x_proposed, step_size, lower_bounds, upper_bounds, name=None): """Ensures the chosen point does not violate reversibility. Implements Fig 6 of Neal 2003 page 717, which checks that the path from the existing point to the new point would also have been possible in reverse. This is done by checking that the algorithm would not have been terminated before reaching the old point. Args: x_initial: A tensor of any shape and real dtype. The initial positions of the chains. This function assumes that all the dimensions of `x_initial` are batch dimensions (i.e. the event shape is `[]`). target_log_prob: Callable accepting a tensor like `x_initial` and returning a tensor containing the log density at that point of the same shape. decided: A `tf.bool` tensor of the same shape as `x_initial`. Indicates whether the acceptance has already been decided. A point is tested only if `decided` for that point is False. log_slice_heights: Tensor of the same shape and dtype as the return value of `target_log_prob` when applied to `x_initial`. The log of the height of the chosen slice. x_proposed: A tensor of the same shape and dtype as `x_initial`. The proposed points. step_size: A tensor of shape and dtype compatible with `x_initial`. The min interval size in the doubling algorithm. lower_bounds: Tensor of same shape and dtype as `x_initial`. Slice lower bounds for each chain. upper_bounds: Tensor of same shape and dtype as `x_initial`. Slice upper bounds for each chain. name: Python `str` name prefixed to Ops created by this function. Default value: `None` (i.e., 'find_slice_bounds'). Returns: acceptable: A boolean tensor of same shape as `x_initial` indicating whether the proposed points are acceptable for reversibility or not. """ with tf.name_scope(name or 'test_acceptance'): d = tf.zeros_like(x_initial, dtype=tf.bool) # Keeps track of points for which the loop has "effectively terminated". # Termination is when either their interval width has shrunk to the minimum # value (step_size) or if the point has already been rejected. def cond(_, decided, *ignored_args): # pylint: disable=unused-argument # Continue until all the points have been decided. return ~tf.reduce_all(decided) acceptable = tf.ones_like(x_initial, dtype=tf.bool) def body(acceptable, decided, left, right, d): """Checks reversibility as described on P717 of Neal 2003.""" midpoint = (left + right) / 2 divided = (((x_initial < midpoint) & (x_proposed >= midpoint)) | ((x_proposed < midpoint) & (x_initial >= midpoint))) next_d = d | divided next_right = tf.where(x_proposed < midpoint, midpoint, right) next_left = tf.where(x_proposed >= midpoint, midpoint, left) left_test = (log_slice_heights >= target_log_prob(next_left)) right_test = (log_slice_heights >= target_log_prob(next_right)) unacceptable = next_d & left_test & right_test # Logic here: For points which have not already been decided, # and are unacceptable, set acceptable to False. For others, let them # be as they were. now_decided = ~decided & unacceptable next_acceptable = tf.where(now_decided, ~unacceptable, acceptable) # Decided if (a) was already decided, or # (b) the new width is less than 1.1 step_size, or # (c) was marked unacceptable. next_decided = (decided | (next_right - next_left <= 1.1 * step_size) | now_decided) return (next_acceptable, next_decided, next_left, next_right, next_d) return tf.while_loop(cond=cond, body=body, loop_vars=(acceptable, decided, lower_bounds, upper_bounds, d))[0]
def fn(): x = np.asarray(0) c = lambda x: x < 10000 b = lambda x: [x + 1] return tf.while_loop(c, b, [x], parallel_iterations=20)
def sample_annealed_importance_chain(num_steps, proposal_log_prob_fn, target_log_prob_fn, current_state, make_kernel_fn, parallel_iterations=10, seed=None, name=None): """Runs annealed importance sampling (AIS) to estimate normalizing constants. This function uses an MCMC transition operator (e.g., Hamiltonian Monte Carlo) to sample from a series of distributions that slowly interpolates between an initial 'proposal' distribution: `exp(proposal_log_prob_fn(x) - proposal_log_normalizer)` and the target distribution: `exp(target_log_prob_fn(x) - target_log_normalizer)`, accumulating importance weights along the way. The product of these importance weights gives an unbiased estimate of the ratio of the normalizing constants of the initial distribution and the target distribution: `E[exp(ais_weights)] = exp(target_log_normalizer - proposal_log_normalizer)`. Note: When running in graph mode, `proposal_log_prob_fn` and `target_log_prob_fn` are called exactly three times (although this may be reduced to two times in the future). Args: num_steps: Integer number of Markov chain updates to run. More iterations means more expense, but smoother annealing between q and p, which in turn means exponentially lower variance for the normalizing constant estimator. proposal_log_prob_fn: Python callable that returns the log density of the initial distribution. target_log_prob_fn: Python callable which takes an argument like `current_state` (or `*current_state` if it's a list) and returns its (possibly unnormalized) log-density under the target distribution. current_state: `Tensor` or Python `list` of `Tensor`s representing the current state(s) of the Markov chain(s). The first `r` dimensions index independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`. make_kernel_fn: Python `callable` which returns a `TransitionKernel`-like object. Must take one argument representing the `TransitionKernel`'s `target_log_prob_fn`. The `target_log_prob_fn` argument represents the `TransitionKernel`'s target log distribution. Note: `sample_annealed_importance_chain` creates a new `target_log_prob_fn` which is an interpolation between the supplied `target_log_prob_fn` and `proposal_log_prob_fn`; it is this interpolated function which is used as an argument to `make_kernel_fn`. parallel_iterations: The number of iterations allowed to run in parallel. It must be a positive integer. See `tf.while_loop` for more details. seed: PRNG seed; see `tfp.random.sanitize_seed` for details. name: Python `str` name prefixed to Ops created by this function. Default value: `None` (i.e., 'sample_annealed_importance_chain'). Returns: next_state: `Tensor` or Python list of `Tensor`s representing the state(s) of the Markov chain(s) at the final iteration. Has same shape as input `current_state`. ais_weights: Tensor with the estimated weight(s). Has shape matching `target_log_prob_fn(current_state)`. kernel_results: `collections.namedtuple` of internal calculations used to advance the chain. #### Examples ##### Estimate the normalizing constant of a log-gamma distribution. 
```python tfd = tfp.distributions # Run 100 AIS chains in parallel num_chains = 100 dims = 20 dtype = np.float32 proposal = tfd.MultivariateNormalDiag( loc=tf.zeros([dims], dtype=dtype)) target = tfd.TransformedDistribution( distribution=tfd.Sample( tfd.Gamma(concentration=dtype(2), rate=dtype(3)), sample_shape=[dims]), bijector=tfp.bijectors.Invert(tfp.bijectors.Exp())) chains_state, ais_weights, kernels_results = ( tfp.mcmc.sample_annealed_importance_chain( num_steps=1000, proposal_log_prob_fn=proposal.log_prob, target_log_prob_fn=target.log_prob, current_state=proposal.sample(num_chains), make_kernel_fn=lambda tlp_fn: tfp.mcmc.HamiltonianMonteCarlo( target_log_prob_fn=tlp_fn, step_size=0.2, num_leapfrog_steps=2))) log_estimated_normalizer = (tf.reduce_logsumexp(ais_weights) - np.log(num_chains)) log_true_normalizer = tf.math.lgamma(2.) - 2. * tf.math.log(3.) ``` ##### Estimate marginal likelihood of a Bayesian regression model. ```python tfd = tfp.distributions def make_prior(dims, dtype): return tfd.MultivariateNormalDiag( loc=tf.zeros(dims, dtype)) def make_likelihood(weights, x): return tfd.MultivariateNormalDiag( loc=tf.tensordot(weights, x, axes=[[0], [-1]])) # Run 100 AIS chains in parallel num_chains = 100 dims = 10 dtype = np.float32 # Make training data. x = np.random.randn(num_chains, dims).astype(dtype) true_weights = np.random.randn(dims).astype(dtype) y = np.dot(x, true_weights) + np.random.randn(num_chains) # Setup model. prior = make_prior(dims, dtype) def target_log_prob_fn(weights): return prior.log_prob(weights) + make_likelihood(weights, x).log_prob(y) proposal = tfd.MultivariateNormalDiag( loc=tf.zeros(dims, dtype)) weight_samples, ais_weights, kernel_results = ( tfp.mcmc.sample_annealed_importance_chain( num_steps=1000, proposal_log_prob_fn=proposal.log_prob, target_log_prob_fn=target_log_prob_fn, current_state=tf.zeros([num_chains, dims], dtype), make_kernel_fn=lambda tlp_fn: tfp.mcmc.HamiltonianMonteCarlo( target_log_prob_fn=tlp_fn, step_size=0.1, num_leapfrog_steps=2))) log_normalizer_estimate = (tf.reduce_logsumexp(ais_weights) - np.log(num_chains)) ``` """ is_seeded = seed is not None seed = samplers.sanitize_seed(seed, salt='mcmc.sample_ais_chain') with tf.name_scope(name or 'sample_annealed_importance_chain'): num_steps = tf.convert_to_tensor(value=num_steps, dtype=tf.int32, name='num_steps') if mcmc_util.is_list_like(current_state): current_state = [ tf.convert_to_tensor(s, name='current_state') for s in current_state ] else: current_state = tf.convert_to_tensor(value=current_state, name='current_state') def _make_convex_combined_log_prob_fn(iter_): def _fn(*args): p = tf.identity(proposal_log_prob_fn(*args), name='proposal_log_prob') t = tf.identity(target_log_prob_fn(*args), name='target_log_prob') dtype = dtype_util.base_dtype(p.dtype) beta = tf.cast(iter_ + 1, dtype) / tf.cast(num_steps, dtype) return tf.identity(beta * t + (1.
- beta) * p, name='convex_combined_log_prob') return _fn def _loop_body(iter_, seed, ais_weights, current_state, kernel_results): """Closure which implements `tf.while_loop` body.""" iter_seed, next_seed = samplers.split_seed( seed, salt='ais_chain.seeded_one_step') if is_seeded else (seed, seed) x = (current_state if mcmc_util.is_list_like(current_state) else [current_state]) proposal_log_prob = proposal_log_prob_fn(*x) target_log_prob = target_log_prob_fn(*x) ais_weights += ((target_log_prob - proposal_log_prob) / tf.cast(num_steps, ais_weights.dtype)) kernel = make_kernel_fn(_make_convex_combined_log_prob_fn(iter_)) # TODO(b/147676843): Should we warn if the kernel is not calibrated? one_step_kwargs = dict(seed=iter_seed) if is_seeded else {} next_state, inner_results = kernel.one_step( current_state, kernel_results.inner_results, **one_step_kwargs) kernel_results = AISResults( proposal_log_prob=proposal_log_prob, target_log_prob=target_log_prob, inner_results=inner_results, ) return [ iter_ + 1, next_seed, ais_weights, next_state, kernel_results ] def _bootstrap_results(init_state): """Creates first version of `previous_kernel_results`.""" kernel = make_kernel_fn(_make_convex_combined_log_prob_fn(iter_=0)) inner_results = kernel.bootstrap_results(init_state) mh_results = _find_inner_mh_results(inner_results) convex_combined_log_prob = mh_results.accepted_results.target_log_prob dtype = dtype_util.as_numpy_dtype(convex_combined_log_prob.dtype) shape = tf.shape(convex_combined_log_prob) proposal_log_prob = tf.fill(shape, dtype(np.nan), name='bootstrap_proposal_log_prob') target_log_prob = tf.fill(shape, dtype(np.nan), name='target_target_log_prob') return AISResults( proposal_log_prob=proposal_log_prob, target_log_prob=target_log_prob, inner_results=inner_results, ) previous_kernel_results = _bootstrap_results(current_state) inner_results = previous_kernel_results.inner_results mh_results = _find_inner_mh_results(inner_results) ais_weights = tf.zeros( shape=tf.broadcast_dynamic_shape( tf.shape(mh_results.proposed_results.target_log_prob), tf.shape(mh_results.accepted_results.target_log_prob)), dtype=mh_results.proposed_results.target_log_prob.dtype) [_, _, ais_weights, current_state, kernel_results] = tf.while_loop( cond=lambda iter_, *args: iter_ < num_steps, body=_loop_body, loop_vars=[ np.int32(0), # iter_ seed, ais_weights, current_state, previous_kernel_results, ], parallel_iterations=parallel_iterations) return [current_state, ais_weights, kernel_results]
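For intuition, the weight recursion in `_loop_body` can be sketched outside TensorFlow. The NumPy example below is illustrative only: it uses a random-walk Metropolis mutation and a 1-D Gaussian pair for which the true normalizer ratio is known, and it accumulates `(target_log_prob - proposal_log_prob) / num_steps` at the current state exactly as above.

```python
import numpy as np

rng = np.random.default_rng(0)
num_chains, num_steps, sigma = 2000, 200, 0.5

def proposal_log_prob(z):        # standard normal, normalized
  return -0.5 * z**2 - 0.5 * np.log(2. * np.pi)

def target_log_prob(z):          # unnormalized N(0, sigma^2)
  return -0.5 * (z / sigma)**2

def annealed_log_prob(z, beta):  # convex combination, as in the function above
  return beta * target_log_prob(z) + (1. - beta) * proposal_log_prob(z)

z = rng.standard_normal(num_chains)       # start from proposal samples
ais_weights = np.zeros(num_chains)
for k in range(num_steps):
  ais_weights += (target_log_prob(z) - proposal_log_prob(z)) / num_steps
  beta = (k + 1) / num_steps
  prop = z + 0.5 * rng.standard_normal(num_chains)   # random-walk proposal
  accept = np.log(rng.uniform(size=num_chains)) < (
      annealed_log_prob(prop, beta) - annealed_log_prob(z, beta))
  z = np.where(accept, prop, z)

# E[exp(ais_weights)] estimates Z_target / Z_proposal = sqrt(2 * pi) * sigma.
log_normalizer_estimate = np.log(np.mean(np.exp(ais_weights)))
```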
def option_price_binomial(*, volatilities, strikes, expiries, spots, discount_rates=None, dividend_rates=None, is_call_options=None, is_american=None, num_steps=100, dtype=None, name=None): """Computes the BS price for a batch of European or American options. Uses the Cox-Ross-Rubinstein version of the binomial tree method to compute the price of American or European options. Supports batching of the options and allows mixing of European and American style exercises in a batch. For more information about the binomial tree method and the Cox-Ross-Rubinstein method in particular see the references below. #### Example ```python # Prices 5 options with a mix of Call/Put, American/European features # in a single batch. dtype = np.float64 spots = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=dtype) strikes = np.array([3.0, 3.0, 3.0, 3.0, 3.0], dtype=dtype) volatilities = np.array([0.1, 0.22, 0.32, 0.01, 0.4], dtype=dtype) is_call_options = np.array([True, True, False, False, False]) is_american = np.array([False, True, True, False, True]) discount_rates = np.array(0.035, dtype=dtype) dividend_rates = np.array([0.02, 0.0, 0.07, 0.01, 0.0], dtype=dtype) expiries = np.array(1.0, dtype=dtype) prices = option_price_binomial( volatilities=volatilities, strikes=strikes, expiries=expiries, spots=spots, discount_rates=discount_rates, dividend_rates=dividend_rates, is_call_options=is_call_options, is_american=is_american, dtype=dtype) # Prints [0., 0.0098847, 0.41299509, 0., 0.06046989] ``` #### References [1] Hull, John C., Options, Futures and Other Derivatives. Pearson, 2018. [2] Wikipedia contributors. Binomial Options Pricing Model. Available at: https://en.wikipedia.org/wiki/Binomial_options_pricing_model Args: volatilities: Real `Tensor` of any shape and dtype. The volatilities to expiry of the options to price. strikes: A real `Tensor` of the same dtype and compatible shape as `volatilities`. The strikes of the options to be priced. expiries: A real `Tensor` of same dtype and compatible shape as `volatilities`. The expiry of each option. The units should be such that `expiry * volatility**2` is dimensionless. spots: A real `Tensor` of any shape that broadcasts to the shape of the `volatilities`. The current spot price of the underlying. discount_rates: An optional real `Tensor` of same dtype as the `volatilities`. The risk free discount rate. If None the rate is assumed to be 0. Default value: None, equivalent to discount rates = 0. dividend_rates: An optional real `Tensor` of same dtype as the `volatilities`. If None the rate is assumed to be 0. Default value: None, equivalent to dividend rates = 0. is_call_options: A boolean `Tensor` of a shape compatible with `volatilities`. Indicates whether the option is a call (if True) or a put (if False). If not supplied, call options are assumed. Default value: None, equivalent to is_call_options = True. is_american: A boolean `Tensor` of a shape compatible with `volatilities`. Indicates whether the option exercise style is American (if True) or European (if False). If not supplied, European style exercise is assumed. Default value: None, equivalent to is_american = False. num_steps: A positive scalar int32 `Tensor`. The size of the time discretization to use. Default value: 100. dtype: Optional `tf.DType`. If supplied, the dtype to be used for conversion of any supplied non-`Tensor` arguments to `Tensor`. Default value: None which maps to the default dtype inferred by TensorFlow (float32). name: str. The name for the ops created by this function.
Default value: None which is mapped to the default name `option_price`. Returns: A `Tensor` of the same shape as the inferred batch shape of the input data. The Black Scholes price of the options computed on a binomial tree. """ with tf.name_scope(name or 'crr_option_price'): strikes = tf.convert_to_tensor(strikes, dtype=dtype, name='strikes') dtype = strikes.dtype volatilities = tf.convert_to_tensor(volatilities, dtype=dtype, name='volatilities') expiries = tf.convert_to_tensor(expiries, dtype=dtype, name='expiries') spots = tf.convert_to_tensor(spots, dtype=dtype, name='spots') if discount_rates is None: discount_rates = tf.zeros_like(volatilities) else: discount_rates = tf.convert_to_tensor(discount_rates, dtype=dtype, name='discount_rates') if dividend_rates is None: dividend_rates = tf.zeros_like(volatilities) else: dividend_rates = tf.convert_to_tensor(dividend_rates, dtype=dtype, name='dividend_rates') if is_call_options is None: is_call_options = tf.ones_like(volatilities, dtype=tf.bool, name='is_call_options') else: is_call_options = tf.convert_to_tensor(is_call_options, dtype=tf.bool, name='is_call_options') if is_american is None: is_american = tf.zeros_like(volatilities, dtype=tf.bool, name='is_american') else: is_american = tf.convert_to_tensor(is_american, dtype=tf.bool, name='is_american') num_steps = tf.cast(num_steps, dtype=dtype) dt = expiries / num_steps # CRR choices for the up and down move multipliers ln_up = volatilities * tf.math.sqrt(dt) ln_dn = -ln_up # Prepares the spot grid. grid_idx = tf.range(num_steps + 1) # Stores the grid as shape [input_batch, N + 1] where N = num_steps. log_spot_grid_1 = tf.expand_dims(tf.math.log(spots) + ln_up * num_steps, axis=-1) log_spot_grid_2 = tf.expand_dims(ln_dn - ln_up, axis=-1) * grid_idx log_spot_grid = log_spot_grid_1 + log_spot_grid_2 # Adding the new dimension is to ensure that batch shape is at the front. payoff_fn = _get_payoff_fn(tf.expand_dims(strikes, axis=-1), tf.expand_dims(is_call_options, axis=-1)) value_mod_fn = _get_value_modifier( tf.expand_dims(is_american, axis=-1), payoff_fn) # Shape [batch shape, num time steps + 1] values_grid = payoff_fn(tf.math.exp(log_spot_grid)) p_up = tf.math.exp((discount_rates - dividend_rates) * dt + ln_up) - 1 p_up /= tf.math.exp(2 * ln_up) - 1 p_up = tf.expand_dims(p_up, axis=-1) p_dn = 1 - p_up discount_factors = tf.expand_dims(tf.math.exp(-discount_rates * dt), axis=-1) ln_up = tf.expand_dims(ln_up, axis=-1) def one_step_back(current_values, current_log_spot_grid): next_values = (current_values[..., 1:] * p_dn + current_values[..., :-1] * p_up) next_log_spot_grid = current_log_spot_grid[..., :-1] - ln_up next_values = value_mod_fn(next_values, tf.math.exp(next_log_spot_grid)) return discount_factors * next_values, next_log_spot_grid def should_continue(current_values, current_log_spot_grid): del current_values, current_log_spot_grid return True batch_shape = values_grid.shape[:-1] pv, _ = tf.while_loop( should_continue, one_step_back, (values_grid, log_spot_grid), maximum_iterations=tf.cast(num_steps, dtype=tf.int32), shape_invariants=(tf.TensorShape(batch_shape + [None]), tf.TensorShape(batch_shape + [None]))) return tf.squeeze(pv, axis=-1)
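The backward induction above can be summarized, for a single European call, by the small NumPy sketch below. It is illustrative only (the helper name and parameters are made up), but the up-move multiplier and risk-neutral probability mirror `ln_up` and `p_up` in the function.

```python
import numpy as np

def crr_european_call(spot, strike, vol, expiry, rate, num_steps=100):
  dt = expiry / num_steps
  up = np.exp(vol * np.sqrt(dt))                         # CRR up-move multiplier
  p_up = (np.exp(rate * dt) - 1. / up) / (up - 1. / up)  # risk-neutral up prob.
  idx = np.arange(num_steps + 1)
  terminal_spots = spot * up**(num_steps - 2 * idx)      # grid, highest spot first
  values = np.maximum(terminal_spots - strike, 0.)       # call payoff at expiry
  discount = np.exp(-rate * dt)
  for _ in range(num_steps):                             # roll back one step at a time
    values = discount * (p_up * values[:-1] + (1. - p_up) * values[1:])
  return values[0]

price = crr_european_call(spot=1.0, strike=1.0, vol=0.2, expiry=1.0, rate=0.0)
```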
def _solve( time_direction_fn, start_time, end_time, coord_grid, values_grid, num_steps=None, start_step_count=0, time_step=None, one_step_fn=None, boundary_conditions=None, values_transform_fn=None, second_order_coeff_fn=None, first_order_coeff_fn=None, zeroth_order_coeff_fn=None, inner_second_order_coeff_fn=None, inner_first_order_coeff_fn=None, maximum_steps=None, swap_memory=True, name=None): """Common code for solve_backward and solve_forward.""" if (num_steps is None) == (time_step is None): raise ValueError('Exactly one of num_steps or time_step' ' should be supplied.') coord_grid = [ tf.convert_to_tensor(dim_grid, dtype=values_grid.dtype) for dim_grid in coord_grid ] n_dims = len(coord_grid) if one_step_fn is None: if n_dims == 1: one_step_fn = oscillation_damped_crank_nicolson_step() else: one_step_fn = douglas_adi_step(theta=0.5) if boundary_conditions is None: def zero_dirichlet(t, grid): del t, grid return 1, None, tf.constant(0, dtype=values_grid.dtype) boundary_conditions = [(zero_dirichlet, zero_dirichlet)] * n_dims with tf.compat.v1.name_scope( name, default_name='solve', values=[ start_time, end_time, coord_grid, values_grid, num_steps, time_step, ]): time_step_fn, est_max_steps = _get_time_steps_info(start_time, end_time, num_steps, time_step, time_direction_fn) if est_max_steps is None and maximum_steps is not None: est_max_steps = maximum_steps def loop_cond(should_stop, time, x_grid, f_grid, steps_performed): del time, x_grid, f_grid, steps_performed return tf.logical_not(should_stop) def loop_body(should_stop, time, x_grid, f_grid, steps_performed): """Propagates the grid in time.""" del should_stop next_should_stop, t_next = time_step_fn(time) next_xs, next_fs = one_step_fn( time=time, next_time=t_next, coord_grid=x_grid, value_grid=f_grid, boundary_conditions=boundary_conditions, second_order_coeff_fn=second_order_coeff_fn, first_order_coeff_fn=first_order_coeff_fn, zeroth_order_coeff_fn=zeroth_order_coeff_fn, inner_second_order_coeff_fn=inner_second_order_coeff_fn, inner_first_order_coeff_fn=inner_first_order_coeff_fn, num_steps_performed=steps_performed) if values_transform_fn is not None: next_xs, next_fs = values_transform_fn(t_next, next_xs, next_fs) return next_should_stop, t_next, next_xs, next_fs, steps_performed + 1 # If the start time is already equal to end time, no stepping is needed. # solve_backward, solve_forward already took care of the case when end_time # is on the "wrong side" of start_time. should_already_stop = (start_time == end_time) initial_args = (should_already_stop, start_time, coord_grid, values_grid, start_step_count) (_, final_time, final_coords, final_values, steps_performed) = tf.while_loop( loop_cond, loop_body, initial_args, swap_memory=swap_memory, maximum_iterations=est_max_steps) return final_values, final_coords, final_time, steps_performed
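To make the `loop_cond`/`loop_body` structure concrete, here is an illustrative analogue: an explicit finite-difference march of the 1-D heat equation driven by `tf.while_loop`. The grid, time step, and frozen Dirichlet boundaries are made up for the sketch and are unrelated to the solver's actual coefficient callbacks.

```python
import tensorflow as tf

nx = 101
dx = 1. / (nx - 1)
dt = 0.4 * dx**2                                  # stable explicit step
num_steps = 200
x = tf.linspace(0., 1., nx)
u0 = tf.exp(-100. * (x - 0.5)**2)                 # initial condition

def loop_cond(step, u):
  del u
  return step < num_steps

def loop_body(step, u):
  lap = (u[2:] - 2. * u[1:-1] + u[:-2]) / dx**2   # discrete Laplacian
  # Interior update plus frozen (Dirichlet) boundary values.
  u_next = tf.concat([u[:1], u[1:-1] + dt * lap, u[-1:]], axis=0)
  return step + 1, u_next

_, u_final = tf.while_loop(loop_cond, loop_body, (0, u0))
```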
def make_convolution_transpose_fn_with_subkernels_matrix( filter_shape, strides, padding, rank=2, dilations=None, dtype=tf.int32, validate_args=False, name=None): """Like `tf.nn.conv2d` except applies batch of kernels to batch of `x`.""" with tf.name_scope(name or 'make_convolution_transpose_fn_with_dilation'): if tf.get_static_value(rank) != 2: raise NotImplementedError( 'Argument `rank` currently only supports `2`; ' 'saw "{}".'.format(rank)) strides = tf.get_static_value(strides) if not isinstance(strides, int): raise ValueError( 'Argument `strides` must be a statically known integer.' 'Saw: {}'.format(strides)) [ filter_shape, rank, _, padding, dilations, ] = prepare_conv_args(filter_shape, rank=rank, strides=strides, padding=padding, dilations=dilations, is_transpose=True, validate_args=validate_args) fh, fw = filter_shape dh, dw = dilations # Determine maximum filter height and filter width of sub-kernels. sub_fh = (fh - 1) // strides + 1 sub_fw = (fw - 1) // strides + 1 def loop_body(i_, event_ind): i = i_ // strides j = i_ % strides i_ind = ps.range(i * fw, ps.maximum(i, fh) * fw, delta=strides * fw, dtype=dtype) j_ind = ps.range(j, ps.maximum(j, fw), delta=strides, dtype=dtype) nc = cartesian_add([i_ind, j_ind]) ind = ps.reverse(ps.reshape(nc, shape=[-1]), axis=[0]) k = ps.reshape(cartesian_add([ ps.range(ps.shape(nc)[0] * sub_fw, delta=sub_fw, dtype=dtype), ps.range(ps.shape(nc)[1], dtype=dtype) ]), shape=[-1]) last_j = strides - (fw - j - 1) % strides - 1 last_i = strides - (fh - i - 1) % strides - 1 kernel_ind = ps.stack( [k, ps.ones_like(k) * last_i * strides + last_j], axis=1) event_ind = ps.tensor_scatter_nd_update(event_ind, ind[..., tf.newaxis], kernel_ind) return i_ + 1, event_ind event_ind = ps.zeros((fh * fw, 2), dtype=dtype) _, event_ind = tf.while_loop(lambda i, _: i < strides**2, loop_body, [tf.zeros([], dtype=dtype), event_ind]) tot_pad_top, tot_pad_bottom = _get_transpose_conv_dilated_padding( fh, stride=strides, dilation=dh, padding=padding) tot_pad_left, tot_pad_right = _get_transpose_conv_dilated_padding( fw, stride=strides, dilation=dw, padding=padding) pad_bottom = (tot_pad_bottom - 1) // strides + 1 pad_top = (tot_pad_top - 1) // strides + 1 pad_right = (tot_pad_right - 1) // strides + 1 pad_left = (tot_pad_left - 1) // strides + 1 padding_vals = ((pad_top, pad_bottom), (pad_left, pad_right)) truncate_top = pad_top * strides - tot_pad_top truncate_left = pad_left * strides - tot_pad_left def op(x, kernel): input_dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32) x = tf.convert_to_tensor(x, dtype=input_dtype, name='x') kernel = tf.convert_to_tensor(kernel, dtype=input_dtype, name='kernel') batch_shape, event_shape = ps.split(ps.shape(x), num_or_size_splits=[-1, 3]) xh, xw, c_in = ps.unstack(event_shape, num=3) kernel_shape = ps.shape(kernel) c_out = kernel_shape[-1] kernel_batch = kernel_shape[:-2] assertions = _maybe_validate_input_shapes( kernel_shape, channels_in=c_in, filter_height=fh, filter_width=fw, validate_args=validate_args) with tf.control_dependencies(assertions): # If the kernel does not have batch shape, fall back to # `conv2d_transpose` (unless dilations > 1, which is not implemented in # `conv2d_transpose`). 
if (tf.get_static_value(ps.rank(kernel)) == 2 and all(d == 1 for d in dilations)): return _call_conv2d_transpose(x, kernel=kernel, filter_shape=filter_shape, strides=(strides, ) * rank, padding=padding, dilations=dilations, c_out=c_out, batch_shape=batch_shape, event_shape=event_shape) n = ps.maximum(0, ps.rank(x) - 3) paddings = ps.pad(padding_vals, paddings=[[n, 1], [0, 0]], constant_values=0) x_pad = tf.pad(x, paddings=paddings, constant_values=0) x_pad_shape = ps.shape(x_pad)[:-3] flat_shape = ps.pad(x_pad_shape, paddings=[[0, 1]], constant_values=-1) flat_x = tf.reshape(x_pad, shape=flat_shape) idx, s = im2row_index( (xh + tf.reduce_sum(padding_vals[0]), xw + tf.reduce_sum(padding_vals[1]), c_in), block_shape=(sub_fh, sub_fw), slice_step=(1, 1), dilations=dilations) x_ = tf.gather(flat_x, indices=idx, axis=-1) im_x = tf.reshape(x_, shape=ps.concat([x_pad_shape, s], axis=0)) # Add channels to subkernel indices idx_event = event_ind * [[c_in, 1]] idx_event_channels = (idx_event[tf.newaxis] + tf.stack( [ps.range(c_in), tf.zeros( (c_in, ), dtype=dtype)], axis=-1)[:, tf.newaxis, :]) idx_event = tf.squeeze(tf.batch_to_space(idx_event_channels, block_shape=[c_in], crops=[[0, 0]]), axis=0) idx_event_broadcast = tf.broadcast_to( idx_event, shape=ps.concat( [kernel_batch, ps.shape(idx_event)], axis=0)) # Add cartesian product of batch indices, since scatter_nd can only be # applied to leading dimensions. idx_batch = tf.stack(tf.meshgrid(*[ ps.range(b_, delta=1, dtype=dtype) for b_ in tf.unstack(kernel_batch) ], indexing='ij'), axis=ps.size(kernel_batch)) idx_batch = tf.cast(idx_batch, dtype=dtype) # empty tensor is float idx_batch_broadcast = idx_batch[..., tf.newaxis, :] + tf.zeros( (ps.shape(idx_event)[0], 1), dtype=dtype) idx_kernel = tf.concat( [idx_batch_broadcast, idx_event_broadcast], axis=-1) kernel_mat = tf.scatter_nd( idx_kernel, updates=kernel, shape=ps.cast(ps.concat([ kernel_batch, [sub_fh * sub_fw * c_in, strides**2, c_out] ], axis=0), dtype=dtype)) kernel_mat = tf.reshape( kernel_mat, shape=ps.concat( [ps.shape(kernel_mat)[:-2], [strides**2 * c_out]], axis=0)) kernel_mat = kernel_mat[..., tf.newaxis, :, :] out = tf.matmul(im_x, kernel_mat) broadcast_batch_shape = ps.broadcast_shape( batch_shape, kernel_batch) if strides > 1: tot_size = tf.reduce_prod(broadcast_batch_shape) flat_out = tf.reshape(out, shape=ps.concat([[tot_size], ps.shape(out)[-3:]], axis=0)) out = tf.nn.depth_to_space(flat_out, block_size=strides) out_height = _deconv_output_length(xh, filter_size=fh, padding=padding, output_padding=None, stride=strides, dilation=dh) out_width = _deconv_output_length(xw, filter_size=fw, padding=padding, output_padding=None, stride=strides, dilation=dw) out = out[..., truncate_top:truncate_top + out_height, truncate_left:truncate_left + out_width, :] out = tf.reshape( out, shape=ps.concat([ broadcast_batch_shape, [out_height, out_width, c_out] ], axis=0)) return out return op
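For orientation only: when the kernel carries no batch dimensions (and dilations are 1), the fallback path corresponds to an ordinary transposed convolution. The snippet below shows that plain `tf.nn.conv2d_transpose` call with illustrative shapes; it is not the batched sub-kernel path implemented above.

```python
import tensorflow as tf

x = tf.random.normal([4, 8, 8, 3])          # [batch, height, width, c_in]
kernel = tf.random.normal([5, 5, 16, 3])    # [fh, fw, c_out, c_in]
y = tf.nn.conv2d_transpose(
    x, kernel,
    output_shape=[4, 16, 16, 16],           # stride-2 'SAME' upsampling
    strides=2,
    padding='SAME')
```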
def sinkhorn_iterations(x, y, a, b, power=2.0, epsilon=1e-3, epsilon_0=1e-1, epsilon_decay=0.95, threshold=1e-2, inner_num_iter=5, max_iterations=2000): """Runs the Sinkhorn's algorithm from (x, a) to (y, b). Args: x: Tensor<float>[batch, n]: the input point clouds. y: Tensor<float>[batch, m]: the target point clouds. a: Tensor<float>[batch, n]: the weight of each input point. The sum of all elements of b must match that of a to converge. b: Tensor<float>[batch, m]: the weight of each target point. The sum of all elements of b must match that of a to converge. power: (float) the power of the distance for the cost function. epsilon: (float) the level of entropic regularization wanted. epsilon_0: (float) the initial level of entropic regularization. epsilon_decay: (float) a multiplicative factor applied at each iteration until reaching the epsilon value. threshold: (float) the relative threshold on the Sinkhorn error to stop the Sinkhorn iterations. inner_num_iter: (int32) the Sinkhorn error is not recomputed at each iteration but every inner_num_iter instead to avoid computational overhead. max_iterations: (int32) the maximum number of Sinkhorn iterations. Returns: A 5-tuple containing: the values of the conjugate variables f and g, the final value of the entropic parameter epsilon, the cost matrix and the number of iterations. """ max_outer_iterations = max_iterations // inner_num_iter loga = tf.math.log(a) logb = tf.math.log(b) cost, d_cost = cost_fn(x, y, power) def body_fn(f, g, eps, num_iter): for _ in range(inner_num_iter): g = eps * logb + softmin(cost, f, g, eps, axis=1) + g f = eps * loga + softmin(cost, f, g, eps, axis=2) + f eps = tf.math.maximum(eps * epsilon_decay, epsilon) return [f, g, eps, num_iter + inner_num_iter] def cond_fn(f, g, eps, num_iter): return tf.math.reduce_all([ tf.math.less(num_iter, max_iterations), tf.math.reduce_any([ tf.math.greater(eps, epsilon), tf.math.greater(error(cost, f, g, eps, b), threshold) ]) ]) f, g, eps, iterations = tf.while_loop( cond_fn, body_fn, [ tf.zeros_like(loga), tf.zeros_like(logb), tf.cast(epsilon_0, dtype=x.dtype), tf.constant(0, dtype=tf.int32) ], parallel_iterations=1, maximum_iterations=max_outer_iterations + 1) return f, g, eps, cost, d_cost, iterations
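A compact NumPy rendering of the same log-domain updates may help. The `softmin` helper below is local to the sketch (a stabilized `-eps * logsumexp`, matching how the updates above are usually written), and the point clouds, weights and epsilon are illustrative.

```python
import numpy as np

def softmin(cost, f, g, eps, axis):
  # -eps * logsumexp((-cost + f[:, None] + g[None, :]) / eps) over `axis`.
  z = (-cost + f[:, None] + g[None, :]) / eps
  zmax = z.max(axis=axis, keepdims=True)
  return -eps * (np.log(np.exp(z - zmax).sum(axis=axis)) + np.squeeze(zmax, axis))

rng = np.random.default_rng(0)
x, y = rng.normal(size=5), rng.normal(size=7)
a, b = np.full(5, 1 / 5), np.full(7, 1 / 7)
cost = (x[:, None] - y[None, :]) ** 2        # power=2 cost
eps, f, g = 0.1, np.zeros(5), np.zeros(7)
for _ in range(200):                         # fixed iteration count for the sketch
  g = eps * np.log(b) + softmin(cost, f, g, eps, axis=0) + g
  f = eps * np.log(a) + softmin(cost, f, g, eps, axis=1) + f
plan = np.exp((f[:, None] + g[None, :] - cost) / eps)  # rows sum ~a, columns ~b
```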
def op(x, kernel): input_dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32) x = tf.convert_to_tensor(x, dtype=input_dtype, name='x') kernel = tf.convert_to_tensor(kernel, dtype=input_dtype, name='kernel') batch_shape, event_shape = ps.split(ps.shape(x), num_or_size_splits=[-1, 3]) xh, xw, c_in = ps.unstack(event_shape, num=3) kernel_shape = ps.shape(kernel) c_out = kernel_shape[-1] kernel_batch = kernel_shape[:-2] assertions = _maybe_validate_input_shapes( kernel_shape, channels_in=c_in, filter_height=fh, filter_width=fw, validate_args=validate_args) with tf.control_dependencies(assertions): # If the kernel does not have batch shape, fall back to # `conv2d_transpose` (unless dilations > 1, which is not implemented in # `conv2d_transpose`). if (tf.get_static_value(ps.rank(kernel)) == 2 and all(d == 1 for d in dilations)): return _call_conv2d_transpose(x, kernel, filter_shape, strides, padding, dilations, c_out, batch_shape, event_shape) n = ps.maximum(0, ps.rank(x) - 3) paddings = ps.pad(padding_vals, paddings=[[n, 1], [0, 0]], constant_values=0) x_pad = tf.pad(x, paddings=paddings, constant_values=0) ex_h = xh + tf.reduce_sum(padding_vals[0]) - sub_fh + 1 ex_w = xw + tf.reduce_sum(padding_vals[1]) - sub_fw + 1 def loop_body(i, outputs): subkernel_ind = kernels_ind.read(i) fh_, fw_ = ps.unstack(ps.shape(subkernel_ind), num=2) eh = ex_h + fh_ - 1 ew = ex_w + fw_ - 1 subkernel_ind = ps.reshape(ps.reshape( subkernel_ind * c_in, shape=[-1])[:, tf.newaxis] + ps.range(c_in), shape=[-1]) k = tf.gather(kernel, subkernel_ind, axis=-2) ind, shape = im2row_index([eh, ew, c_in], block_shape=(fh_, fw_), slice_step=(1, 1), dilations=dilations) x_i = x_pad[..., :eh, :ew, :] x_i_shape = ps.shape(x_i) flat_shape = ps.pad(x_i_shape[:-3], paddings=[[0, 1]], constant_values=-1) flat_x = tf.reshape(x_i, flat_shape) x_ = tf.gather(flat_x, ind, axis=-1) im_x = tf.reshape( x_, ps.concat([x_i_shape[:-3], shape], axis=0)) outputs = outputs.write( i, tf.matmul( im_x, tf.reshape( k, ps.concat([ kernel_batch, [1, fh_ * fw_ * c_in, c_out] ], axis=0)))) return i + 1, outputs outputs = tf.TensorArray(dtype=input_dtype, size=sh * sw) _, outputs = tf.while_loop(lambda i, _: i < sh * sw, loop_body, [0, outputs]) y = outputs.concat() m = tf.reduce_prod(ps.shape(y)[:-3]) y_ = tf.reshape(y, shape=ps.concat([[m], ps.shape(y)[-3:]], axis=0)) y2 = tf.batch_to_space(y_, strides, crops=tf.zeros([2, 2], dtype=tf.int64)) broadcast_batch_shape = ps.broadcast_shape( batch_shape, kernel_batch) y2 = tf.reshape( y2, ps.concat([broadcast_batch_shape, ps.shape(y2)[-3:]], axis=0)) out_height = _deconv_output_length(xh, filter_size=fh, padding=padding, output_padding=None, stride=sh, dilation=dh) out_width = _deconv_output_length(xw, filter_size=fw, padding=padding, output_padding=None, stride=sw, dilation=dw) return y2[..., truncate_top:truncate_top + out_height, truncate_left:truncate_left + out_width, :]
def rejection_sample_with_gradient(concentration): """Performs rejection sampling for standardized von Mises. A nested function is required because @tf.custom_gradient does not handle non-tensor inputs such as dtype. Instead, they are captured by the outer scope. Arguments: concentration: The concentration parameter of the distribution. Returns: Differentiable samples of standardized von Mises. """ r = 1. + tf.sqrt(1. + 4. * concentration**2) rho = (r - tf.sqrt(2. * r)) / (2. * concentration) s_exact = (1. + rho**2) / (2. * rho) # For low concentration, s becomes numerically unstable. # To fix that, we use an approximation. Here is the derivation. # First-order Taylor expansion at conc = 0 gives # sqrt(1 + 4 concentration^2) ~= 1 + (2 concentration)^2 / 2. # Therefore, r ~= 2 + 2 concentration. By plugging this into rho, we have # rho ~= conc + 1 / conc - sqrt(1 + 1 / concentration^2). # Let's expand the last term at concentration=0 up to the linear term: # sqrt(1 + 1 / concentration^2) ~= 1 / concentration + concentration / 2 # Thus, rho ~= concentration / 2. Finally, # s = 1 / (2 rho) + rho / 2 ~= 1 / concentration + concentration / 4. # Since concentration is small, we drop the second term and simply use # s ~= 1 / concentration. s_approximate = 1. / concentration # To compute the cutoff, we compute s_exact using mpmath with 30 decimal # digits precision and compare that to the s_exact and s_approximate # computed with dtype. Then, the cutoff is the largest concentration for # which abs(s_exact - s_exact_mpmath) > abs(s_approximate - s_exact_mpmath). s_concentration_cutoff_dict = { tf.float16: 1.8e-1, tf.float32: 2e-2, tf.float64: 1.2e-4, } s_concentration_cutoff = s_concentration_cutoff_dict[dtype] s = tf.where(concentration > s_concentration_cutoff, s_exact, s_approximate) def loop_body(done, u, w, seed): """Resample the non-accepted points.""" # We resample u each time completely. Only its sign is used outside the # loop, which is random. u_seed, v_seed, next_seed = samplers.split_seed(seed, n=3) u = samplers.uniform(shape, minval=-1., maxval=1., dtype=dtype, seed=u_seed) z = tf.cos(np.pi * u) # Update the non-accepted points. w = tf.where(done, w, (1. + s * z) / (s + z)) y = concentration * (s - w) v = samplers.uniform(shape, minval=0., maxval=1., dtype=dtype, seed=v_seed) accept = (y * (2. - y) >= v) | (tf.math.log(y / v) + 1. >= y) return done | accept, u, w, next_seed _, u, w, _ = tf.while_loop( cond=lambda done, *_: ~tf.reduce_all(done), body=loop_body, loop_vars=( tf.zeros(shape, dtype=tf.bool, name='done'), tf.zeros(shape, dtype=dtype, name='u'), tf.zeros(shape, dtype=dtype, name='w'), seed, ), # The expected number of iterations depends on concentration. # It monotonically increases from one iteration for concentration = 0 to # sqrt(2 pi / e) ~= 1.52 iterations for concentration = +inf [1]. # We use a limit of 100 iterations to avoid infinite loops # for very large / nan concentration. maximum_iterations=100, ) x = tf.sign(u) * tf.math.acos(w) def grad(dy): """The gradient of the von Mises samples w.r.t. concentration.""" broadcast_concentration = tf.broadcast_to(concentration, prefer_static.shape(x)) _, dcdf_dconcentration = value_and_gradient( lambda conc: von_mises_cdf(x, conc), broadcast_concentration) inv_prob = tf.exp(-broadcast_concentration * (tf.cos(x) - 1.)) * ( (2. 
* np.pi) * tf.math.bessel_i0e(broadcast_concentration)) # Compute the implicit reparameterization gradient [2], # dz/dconc = -(dF(z; conc) / dconc) / p(z; conc) ret = dy * (-inv_prob * dcdf_dconcentration) # Sum over the sample dimensions. Assume that they are always the first # ones. num_sample_dimensions = (tf.rank(broadcast_concentration) - tf.rank(concentration)) return tf.reduce_sum(ret, axis=tf.range(num_sample_dimensions)) return x, grad
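End users typically obtain these differentiable samples through the public distribution class. A hedged usage sketch, assuming `tfp.distributions.VonMises` exposes reparameterized sampling as described:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
concentration = tf.Variable(2.0)
with tf.GradientTape() as tape:
  samples = tfd.VonMises(loc=0., concentration=concentration).sample(
      1000, seed=42)
  mean_cosine = tf.reduce_mean(tf.cos(samples))
grad = tape.gradient(mean_cosine, concentration)  # defined via implicit reparam.
```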
def _sample_n(self, n, seed=None): seed = seed_stream.SeedStream(seed, salt='vom_mises_fisher') # The sampling strategy relies on the fact that vMF variates are symmetric # about the mean direction. Accordingly, if we have a sampling strategy for # the away-from-mean angle, then we can uniformly sample the remaining # dimensions on the S^{dim-2} sphere for free, and rotate these samples from a # (1, 0, 0, ..., 0)-mode distribution into the target orientation. # # This is easy to imagine on the 1-sphere (S^1; in 2-D space): sample a # von-Mises distributed `x` value in [-1, 1], then uniformly select what # amounts to an "up" or "down" additional degree of freedom after unit # normalizing, followed by a final rotation to the desired mean direction # from a basis of (1, 0). # # On S^2 (in 3-D), selecting a vMF `x` identifies a circle in `yz` on the # unit sphere over which the distribution is uniform, in particular the # circle where x = \hat{x} intersects the unit sphere. We pick a point on # that circle, then rotate to the desired mean direction from a basis of # (1, 0, 0). event_dim = (tf.compat.dimension_value(self.event_shape[0]) or self._event_shape_tensor()[0]) sample_batch_shape = tf.concat([[n], self._batch_shape_tensor()], axis=0) dim = tf.cast(event_dim - 1, self.dtype) if event_dim == 3: samples_dim0 = self._sample_3d(n, seed=seed) else: # Wood'94 provides a rejection algorithm to sample the x coordinate. # Wood'94 definition of b: # b = (-2 * kappa + tf.sqrt(4 * kappa**2 + dim**2)) / dim # https://stats.stackexchange.com/questions/156729 suggests: b = dim / (2 * self.concentration + tf.sqrt(4 * self.concentration**2 + dim**2)) # TODO(bjp): Integrate any useful numerical tricks from hyperspherical VAE # https://github.com/nicola-decao/s-vae-tf/ x = (1 - b) / (1 + b) c = self.concentration * x + dim * tf.math.log1p(-x**2) beta = beta_lib.Beta(dim / 2, dim / 2) def cond_fn(w, should_continue): del w return tf.reduce_any(input_tensor=should_continue) def body_fn(w, should_continue): z = beta.sample(sample_shape=sample_batch_shape, seed=seed()) w = tf.where(should_continue, (1 - (1 + b) * z) / (1 - (1 - b) * z), w) w = tf.debugging.check_numerics(w, 'w') should_continue = tf.logical_and( should_continue, self.concentration * w + dim * tf.math.log1p(-x * w) - c < tf.math.log( tf.random.uniform(sample_batch_shape, seed=seed(), dtype=self.dtype))) return w, should_continue w = tf.zeros(sample_batch_shape, dtype=self.dtype) should_continue = tf.ones(sample_batch_shape, dtype=tf.bool) samples_dim0 = tf.while_loop(cond=cond_fn, body=body_fn, loop_vars=(w, should_continue))[0] samples_dim0 = samples_dim0[..., tf.newaxis] if not self._allow_nan_stats: # Verify samples are w/in -1, 1, with useful error output tensors (top # value rather than all values).
with tf.control_dependencies([ assert_util.assert_less_equal( samples_dim0, dtype_util.as_numpy_dtype(self.dtype)(1.01), data=[tf.nn.top_k(tf.reshape(samples_dim0, [-1]))[0]]), assert_util.assert_greater_equal( samples_dim0, dtype_util.as_numpy_dtype(self.dtype)(-1.01), data=[ -tf.nn.top_k(tf.reshape(-samples_dim0, [-1]))[0] ]) ]): samples_dim0 = tf.identity(samples_dim0) samples_otherdims_shape = tf.concat( [sample_batch_shape, [event_dim - 1]], axis=0) unit_otherdims = tf.nn.l2_normalize(tf.random.normal( samples_otherdims_shape, seed=seed(), dtype=self.dtype), axis=-1) samples = tf.concat( [ samples_dim0, # we must avoid sqrt(1 - (>1)**2) tf.sqrt(tf.maximum(1 - samples_dim0**2, 0.)) * unit_otherdims ], axis=-1) samples = tf.nn.l2_normalize(samples, axis=-1) if not self._allow_nan_stats: samples = tf.debugging.check_numerics(samples, 'samples') # Runtime assert that samples are unit length. if not self._allow_nan_stats: worst, idx = tf.nn.top_k( tf.reshape(tf.abs(1 - tf.linalg.norm(tensor=samples, axis=-1)), [-1])) with tf.control_dependencies([ assert_util.assert_near( dtype_util.as_numpy_dtype(self.dtype)(0), worst, data=[ worst, idx, tf.gather(tf.reshape(samples, [-1, event_dim]), idx) ], atol=1e-4, summarize=100) ]): samples = tf.identity(samples) # The samples generated are symmetric around a mode at (1, 0, 0, ...., 0). # Now, we move the mode to `self.mean_direction` using a rotation matrix. if not self._allow_nan_stats: # Assert that the basis vector rotates to the mean direction, as expected. basis = tf.cast( tf.concat([[1.], tf.zeros([event_dim - 1])], axis=0), self.dtype) with tf.control_dependencies([ assert_util.assert_less( tf.linalg.norm(tensor=self._rotate(basis) - self.mean_direction, axis=-1), dtype_util.as_numpy_dtype(self.dtype)(1e-5)) ]): return self._rotate(samples) return self._rotate(samples)
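A hedged usage sketch of this sampler through the public class (the dimension and concentration are illustrative):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
vmf = tfd.VonMisesFisher(
    mean_direction=tf.math.l2_normalize([1., 1., 0.]),
    concentration=10.)
samples = vmf.sample(500, seed=42)             # unit vectors on S^2
norms = tf.linalg.norm(samples, axis=-1)       # ~1 up to float error
```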
def _build_discount_curve(bond_cashflows, bond_cashflow_times, present_values, pv_settle_times, initial_discount_rates, discount_tolerance, maximum_iterations): """Estimates the discount curve. The procedure is recursive and as follows: 1. Assume some initial set of discount rates/discount factors. Set this as the current yield curve. 2. From the current yield curve, interpolate to get the discount rates for each time at which bond_cashflows occur. 3. Using these discounts and the known bond prices, compute the discount rate to expiry of each bond by inverting the bond pricing formula as follows. We know that the bond price satisfies (`P` is the present value, `r_i` is the discount rate to time `t_i`, `c_i` is the cashflow occurring at time `t_i`.): ```None P e^{-r_0 t_0} = c_1 e^{-r_1 t_1} + ... + c_n e^{-r_n t_n} (A) ``` Assuming we have estimated r_0, r_1, r_2, ..., r_{n-1}, we can invert the above equation to calculate r_n. We write this in a suggestive form suitable for the implementation below. ```None -c_n z_n = -P z_0 + c_1 z_1 + c_2 z_2 + ... + c_{n-1} z_{n-1} (B) ``` where ```None z_i = e^{-r_i t_i} (C) ``` The RHS of Eq. (B) looks like the PV of cashflows `[-P, c_1, c_2, ... c_{n-1}]` paid out at times `[t_0, t_1, ..., t_{n-1}]`. Concatenate these "synthetic" cashflow times for each bond: `Ts = [t1_0, t1_1, ... t1_{n1-1}] + [t2_0, t2_1, ... t2_{n2-1}] ...` Also concatenate the synthetic bond cashflows as: `Cs = [-P1, c1_1, ..., c1_{n1-1}] + [-P2, c2_1, ..., c2_{n2-1}] ...` Then compute `Rs = InterpolateRates[Ts], Zs = exp(-Rs * Ts)` Let `Zns = [z_n1, z_n2, ... ], Cns = [c1_n, c2_n, ...]` be the discount factors to expiry and the final cashflow of each bond. We can derive `Zns = - SegmentSum(Cs * Zs) / Cns`. From that, we get Rns = -log(Zns) / Tns. Using this as the next guess for the discount rates, we repeat the procedure from Step (1) until convergence. Args: bond_cashflows: List of `Tensor`s. Each `Tensor` must be of rank 1 and of the same real dtype. They may be of different sizes. Each `Tensor` represents the bond cashflows defining a particular bond. The elements of the list are the bonds to be used to build the curve. bond_cashflow_times: List of `Tensor`s. The list must be of the same length as the `bond_cashflows` and each `Tensor` in the list must be of the same length as the `Tensor` at the same index in the `bond_cashflows` list. Each `Tensor` must be of rank 1 and of the same dtype as the `Tensor`s in `bond_cashflows` and contain strictly positive and increasing values. The times of the bond cashflows for the bonds must be in ascending order. present_values: List containing scalar `Tensor`s of the same dtype as elements of `bond_cashflows`. The length of the list must be the same as the length of `bond_cashflows`. The market price (i.e. the all-in or dirty price) of the bond cashflows supplied in the `bond_cashflows`. pv_settle_times: List containing scalar `Tensor`s of the same dtype as elements of `bond_cashflows`. The length of the list must be the same as the length of `bond_cashflows`. The settlement times for the present values, i.e. the time from when the bond is traded to the time that the purchase price is actually delivered. initial_discount_rates: Rank 1 `Tensor` of same shape and dtype as `pv_settle_times`. The initial guess for the discount rates to bond expiry times. discount_tolerance: Positive scalar `Tensor` of same dtype as `initial_discount_rates`. The absolute tolerance for terminating the iterations used to fit the rate curve.
The iterations are stopped when the estimated discounts at the expiry times of the bond cashflows change by an amount smaller than `discount_tolerance` in an iteration. maximum_iterations: Positive scalar `tf.int32` `Tensor`. The maximum number of iterations permitted. Returns: curve_builder_result: An instance of `CurveBuilderResult` containing the following attributes. times: Rank 1 real `Tensor`. Times for the computed discount rates. discount_rates: Rank 1 `Tensor` of the same dtype as `times`. The inferred discount rates. discount_factors: Rank 1 `Tensor` of the same dtype as `times`. The inferred discount factors. initial_discount_rates: Rank 1 `Tensor` of the same dtype as `times`. The initial guess for the discount rates. converged: Scalar boolean `Tensor`. Whether the procedure converged. The procedure is said to have converged when the maximum absolute difference in the discount factors from one iteration to the next falls below the `discount_tolerance`. failed: Scalar boolean `Tensor`. Whether the procedure failed. Procedure may fail either because a NaN value was encountered for the discount rates or the discount factors. iterations: Scalar `tf.int32` `Tensor`. Number of iterations performed. """ calc_bond_cashflows = [] # Cs calc_times = [] # Ts expiry_times = [] # Tns expiry_bond_cashflows = [] # Cns calc_groups = [] num_bonds = len(bond_cashflows) for i in range(num_bonds): calc_bond_cashflows.extend([[-present_values[i]], bond_cashflows[i][:-1]]) calc_times.extend([[pv_settle_times[i]], bond_cashflow_times[i][:-1]]) expiry_times.append(bond_cashflow_times[i][-1]) expiry_bond_cashflows.append(bond_cashflows[i][-1]) calc_groups.append(tf.fill(tf.shape(bond_cashflows[i]), i)) calc_bond_cashflows = tf.concat(calc_bond_cashflows, axis=0) calc_times = tf.concat(calc_times, axis=0) expiry_times = tf.stack(expiry_times, axis=0) expiry_bond_cashflows = tf.stack(expiry_bond_cashflows, axis=0) calc_groups = tf.concat(calc_groups, axis=0) def one_step(converged, failed, iteration, expiry_discounts): """One step of the iteration.""" expiry_rates = -tf.math.log(expiry_discounts) / expiry_times failed = tf.math.reduce_any( tf.math.is_nan(expiry_rates) | tf.math.is_nan(expiry_discounts)) calc_rates = monotone_convex.interpolate_yields(calc_times, expiry_times, yields=expiry_rates) calc_discounts = tf.math.exp(-calc_rates * calc_times) next_expiry_discounts = -tf.math.segment_sum( calc_bond_cashflows * calc_discounts, calc_groups) / expiry_bond_cashflows discount_diff = tf.math.abs(next_expiry_discounts - expiry_discounts) converged = (~tf.math.reduce_any(tf.math.is_nan(discount_diff)) & (tf.math.reduce_max(discount_diff) < discount_tolerance)) return converged, failed, iteration + 1, next_expiry_discounts def cond(converged, failed, iteration, expiry_discounts): del expiry_discounts, iteration # Note we do not need to check iteration count here because that # termination mode is imposed by the maximum_iterations parameter in the # while loop.
return ~tf.math.logical_or(converged, failed) initial_discount_factors = tf.math.exp(-initial_discount_rates * expiry_times) initial_vals = (False, False, 0, initial_discount_factors) loop_result = tf.while_loop(cond, one_step, initial_vals, maximum_iterations=maximum_iterations) discount_factors = loop_result[-1] discount_rates = -tf.math.log(discount_factors) / expiry_times results = CurveBuilderResult(times=expiry_times, discount_rates=discount_rates, discount_factors=discount_factors, initial_discount_rates=initial_discount_rates, converged=loop_result[0], failed=loop_result[1], iterations=loop_result[2]) return results
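The inversion of Eq. (B) for a single bond can be written out directly. The NumPy sketch below is illustrative (the cashflows, rates, and settlement time are made up) and assumes the discount rates at the earlier cashflow times are already known, which is exactly the role interpolation plays inside `one_step`.

```python
import numpy as np

cashflow_times = np.array([0.5, 1.0, 1.5, 2.0])   # t_1 ... t_n
cashflows = np.array([2.5, 2.5, 2.5, 102.5])      # c_1 ... c_n
present_value, settle_time = 106.25, 0.0          # P and t_0
known_rates = np.array([0.015, 0.016, 0.017])     # r_1 ... r_{n-1}, given here

z = np.exp(-known_rates * cashflow_times[:-1])    # z_1 ... z_{n-1}
z0 = np.exp(-0.0 * settle_time)                   # z_0; the rate to t_0 is 0 here
# Eq. (B):  -c_n z_n = -P z_0 + c_1 z_1 + ... + c_{n-1} z_{n-1}
z_n = (present_value * z0 - np.sum(cashflows[:-1] * z)) / cashflows[-1]
r_n = -np.log(z_n) / cashflow_times[-1]           # discount rate to expiry
```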
def _compute_general_continued_fraction(max_iterations, numerator_denominator_args_list, tolerance=None, partial_numerator_fn=None, partial_denominator_fn=None, dtype=tf.float32, name=None): """Compute a general continued fraction. Given at least one of `partial_numerator_fn` and `partial_denominator_fn`, compute the continued fraction associated with it via the forward recurrence. Let `a_i = partial_numerator_fn` and `b_i = partial_denominator_fn`. Then, this evaluates the infinite continued fraction: ```result = a_1 / (b_1 + a_2 / (b_2 + a_3 / (b_3 .....)```. If `partial_numerator_fn` or `partial_denominator_fn` are not given, then `a_i` (respectively `b_i`) are assumed to be 1. However one must be given. NOTE: Use this with caution. Forward recursion doesn't have numerical stability guarantees, compared to backward recursion. Args: max_iterations: Integer `Tensor` specifying the maximum number of terms to use. numerator_denominator_args_list: Arguments to pass in to `partial_numerator_fn` and `partial_denominator_fn`. tolerance: Float `Tensor` specifying the maximum acceptable tolerance between convergents. If unset, convergence is dictated by the number of iterations. Default value: `None`. partial_numerator_fn: Python callable that takes in as its first argument the current iteration count (an integer >= 1), and a list of *args, and returns a `Tensor`. These are used as partial numerators for the continued fraction. Default value: `None`. partial_denominator_fn: Python callable that takes in as its first argument the current iteration count (an integer >= 1), and a list of *args, and returns a `Tensor`. These are used as partial denominators for the continued fraction. Default value: `None`. dtype: The default dtype of the continued fraction. Default: `float32`. name: A name for the operation (optional). Default value: `None` (i.e., 'continued_fraction'). Returns: Continued fraction computed to `max_iterations` iterations and/or up to absolute error `tolerance`. #### References [1]: Walter Gautschi and Josef Slavik. On the Computation of Modified Bessel Function Ratios. http://www.jstor.com/stable/2006491 """ with tf.name_scope(name or 'continued_fraction'): dtype = dtype_util.common_dtype(numerator_denominator_args_list, dtype) if (partial_numerator_fn is None) and (partial_denominator_fn is None): raise ValueError('Expect one of `partial_numerator_fn` and ' '`partial_denominator_fn` to be set.') def _continued_fraction_one_step(unused_should_stop, numerator, previous_numerator, denominator, previous_denominator, iteration_count): partial_denominator = 1. if partial_denominator_fn: partial_denominator = partial_denominator_fn( iteration_count, *numerator_denominator_args_list) new_numerator = partial_denominator * numerator new_denominator = partial_denominator * denominator partial_numerator = 1. if partial_numerator_fn: partial_numerator = partial_numerator_fn( iteration_count, *numerator_denominator_args_list) new_numerator = new_numerator + partial_numerator * previous_numerator new_denominator = (new_denominator + partial_numerator * previous_denominator) should_stop_next = iteration_count > max_iterations if tolerance is not None: # We can use a more efficient computation when the partial numerators # are 1. if partial_numerator_fn is None: # We now want to compute to relative error between the fraction at # this iteration, vs. the previous iteration. # Let h_i be the numerator and k_i the denominator, and a_i be the # i-th term. 
# h_i / k_i - h_{i-1} / k_{i-1} = # (h_i * k_{i - 1} - h_{i - 1} * k_i) / (k_i * k_{i - 1}) = # ((a_i h_{i - 1} + h_{i - 2}) * k_{i - 1} - # (a_i k_{i - 1} + k_{i - 2}) * h_{i - 1}) / (k_i * k_{i - 1}) = # -(h_{i - 1} * k_{i - 2} - h_{i - 2} * k_{i - 1}) / (k_i * k_{i - 1}) # This suggests we should prove something about the numerator # inductively, and indeed # (h_i * k_{i - 1} - h_{i - 1} * k_i) = (-1)**i delta = tf.math.reciprocal(new_denominator * denominator) # We actually need to compute the difference of fractions. else: delta = new_numerator / new_denominator - numerator / denominator converged = tf.math.abs(delta) <= tolerance should_stop_next = tf.reduce_all(converged) | should_stop_next return (should_stop_next, new_numerator, numerator, new_denominator, denominator, iteration_count + 1.) # This is to infer the correct shape of tensors if partial_denominator_fn: term = partial_denominator_fn(1., *numerator_denominator_args_list) else: term = partial_numerator_fn(1., *numerator_denominator_args_list) zeroth_numerator = tf.ones_like(term, dtype=dtype) zeroth_denominator = tf.zeros_like(term, dtype=dtype) first_numerator = tf.zeros_like(term, dtype=dtype) first_denominator = tf.ones_like(term, dtype=dtype) results = tf.while_loop(cond=lambda stop, *_: ~stop, body=_continued_fraction_one_step, loop_vars=(False, first_numerator, zeroth_numerator, first_denominator, zeroth_denominator, tf.cast(1., dtype=dtype))) return results[1] / results[3]
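For reference, the forward recurrence seeded by the `zeroth_*` and `first_*` variables above reduces, in scalar form, to the classic convergent recursion. Evaluating it on the all-ones fraction `1/(1 + 1/(1 + ...))` recovers `(sqrt(5) - 1)/2`; the sketch below is illustrative only.

```python
h_prev, h = 1., 0.    # h_{-1}, h_0  (numerators; b_0 = 0)
k_prev, k = 0., 1.    # k_{-1}, k_0  (denominators)
for _ in range(30):
  a_i, b_i = 1., 1.   # partial numerator / denominator of this term
  h, h_prev = b_i * h + a_i * h_prev, h
  k, k_prev = b_i * k + a_i * k_prev, k
approx = h / k        # -> 0.6180339887... = (sqrt(5) - 1) / 2
```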
def minimize(objective_function, initial_simplex=None, initial_vertex=None, step_sizes=None, objective_at_initial_simplex=None, objective_at_initial_vertex=None, batch_evaluate_objective=False, func_tolerance=1e-8, position_tolerance=1e-8, parallel_iterations=1, max_iterations=None, reflection=None, expansion=None, contraction=None, shrinkage=None, name=None): """Minimum of the objective function using the Nelder Mead simplex algorithm. Performs an unconstrained minimization of a (possibly non-smooth) function using the Nelder Mead simplex method. Nelder Mead method does not support univariate functions. Hence the dimensions of the domain must be 2 or greater. For details of the algorithm, see [Press, Teukolsky, Vetterling and Flannery(2007)][1]. Points in the domain of the objective function may be represented as a `Tensor` of general shape but with rank at least 1. The algorithm proceeds by modifying a full rank simplex in the domain. The initial simplex may either be specified by the user or can be constructed using a single vertex supplied by the user. In the latter case, if `v0` is the supplied vertex, the simplex is the convex hull of the set: ```None S = {v0} + {v0 + step_i * e_i} ``` Here `e_i` is a vector which is `1` along the `i`-th axis and zero elsewhere and `step_i` is a characteristic length scale along the `i`-th axis. If the step size is not supplied by the user, a unit step size is used in every axis. Alternately, a single step size may be specified which is used for every axis. The most flexible option is to supply a bespoke step size for every axis. ### Usage: The following example demonstrates the usage of the Nelder Mead minimization on a two dimensional problem with the minimum located at a non-differentiable point. ```python # The objective function def sqrt_quadratic(x): return tf.sqrt(tf.reduce_sum(x ** 2, axis=-1)) start = tf.constant([6.0, -21.0]) # Starting point for the search. optim_results = tfp.optimizer.nelder_mead_minimize( sqrt_quadratic, initial_vertex=start, func_tolerance=1e-8, batch_evaluate_objective=True) # Check that the search converged assert(optim_results.converged) # Check that the argmin is close to the actual value. np.testing.assert_allclose(optim_results.position, np.array([0.0, 0.0]), atol=1e-7) # Print out the total number of function evaluations it took. print("Function evaluations: %d" % optim_results.num_objective_evaluations) ``` ### References: [1]: William Press, Saul Teukolsky, William Vetterling and Brian Flannery. Numerical Recipes in C++, third edition. pp. 502-507. (2007). http://numerical.recipes/cpppages/chap0sel.pdf [2]: Jeffrey Lagarias, James Reeds, Margaret Wright and Paul Wright. Convergence properties of the Nelder-Mead simplex method in low dimensions, Siam J. Optim., Vol 9, No. 1, pp. 112-147. (1998). http://www.math.kent.edu/~reichel/courses/Opt/reading.material.2/nelder.mead.pdf [3]: Fuchang Gao and Lixing Han. Implementing the Nelder-Mead simplex algorithm with adaptive parameters. Computational Optimization and Applications, Vol 51, Issue 1, pp 259-277. (2012). https://pdfs.semanticscholar.org/15b4/c4aa7437df4d032c6ee6ce98d6030dd627be.pdf Args: objective_function: A Python callable that accepts a point as a real `Tensor` and returns a `Tensor` of real dtype containing the value of the function at that point. The function to be minimized.
If `batch_evaluate_objective` is `True`, the callable may be evaluated on a `Tensor` of shape `[n+1] + s` where `n` is the dimension of the problem and `s` is the shape of a single point in the domain (so `n` is the size of a `Tensor` representing a single point). In this case, the expected return value is a `Tensor` of shape `[n+1]`. Note that this method does not support univariate functions so the problem dimension `n` must be strictly greater than 1. initial_simplex: (Optional) `Tensor` of real dtype. The initial simplex to start the search. If supplied, should be a `Tensor` of shape `[n+1] + s` where `n` is the dimension of the problem and `s` is the shape of a single point in the domain. Each row (i.e. the `Tensor` with a given value of the first index) is interpreted as a vertex of a simplex and hence the rows must be affinely independent. If not supplied, an axes aligned simplex is constructed using the `initial_vertex` and `step_sizes`. Exactly one of `initial_simplex` and `initial_vertex` must be supplied. initial_vertex: (Optional) `Tensor` of real dtype and any shape that can be consumed by the `objective_function`. A single point in the domain that will be used to construct an axes aligned initial simplex. step_sizes: (Optional) `Tensor` of real dtype and shape broadcasting compatible with `initial_vertex`. Supplies the simplex scale along each axis. Only used if `initial_simplex` is not supplied. See description above for details on how step sizes and initial vertex are used to construct the initial simplex. objective_at_initial_simplex: (Optional) Rank `1` `Tensor` of real dtype. The value of the objective function at the initial simplex. May be supplied only if `initial_simplex` is supplied. If not supplied, it will be computed. objective_at_initial_vertex: (Optional) Scalar `Tensor` of real dtype. The value of the objective function at the initial vertex. May be supplied only if the `initial_vertex` is also supplied. batch_evaluate_objective: (Optional) Python `bool`. If True, the objective function will be evaluated on all the vertices of the simplex packed into a single tensor. If False, the objective will be mapped across each vertex separately. Evaluating the objective function in a batch allows use of vectorization and should be preferred if the objective function allows it. func_tolerance: (Optional) Scalar `Tensor` of real dtype. The algorithm stops if the absolute difference between the largest and the smallest function value on the vertices of the simplex is below this number. position_tolerance: (Optional) Scalar `Tensor` of real dtype. The algorithm stops if the largest absolute difference between the coordinates of the vertices is below this threshold. parallel_iterations: (Optional) Positive integer. The number of iterations allowed to run in parallel. max_iterations: (Optional) Scalar positive `Tensor` of dtype `int32`. The maximum number of iterations allowed. If `None` then no limit is applied. reflection: (Optional) Positive Scalar `Tensor` of same dtype as `initial_vertex`. This parameter controls the scaling of the reflected vertex. See, [Press et al(2007)][1] for details. If not specified, uses the dimension dependent prescription of [Gao and Han(2012)][3]. expansion: (Optional) Positive Scalar `Tensor` of same dtype as `initial_vertex`. Should be greater than `1` and `reflection`. This parameter controls the expanded scaling of a reflected vertex. See, [Press et al(2007)][1] for details.
If not specified, uses the dimension dependent prescription of [Gao and Han(2012)][3]. contraction: (Optional) Positive scalar `Tensor` of same dtype as `initial_vertex`. Must be between `0` and `1`. This parameter controls the contraction of the reflected vertex when the objective function at the reflected point fails to show sufficient decrease. See, [Press et al(2007)][1] for more details. If not specified, uses the dimension dependent prescription of [Gao and Han(2012)][3]. shrinkage: (Optional) Positive scalar `Tensor` of same dtype as `initial_vertex`. Must be between `0` and `1`. This parameter is the scale by which the simplex is shrunk around the best point when the other steps fail to produce improvements. See, [Press et al(2007)][1] for more details. If not specified, uses the dimension dependent prescription of [Gao and Han(2012)][3]. name: (Optional) Python str. The name prefixed to the ops created by this function. If not supplied, the default name 'minimize' is used. Returns: optimizer_results: A namedtuple containing the following items: converged: Scalar boolean tensor indicating whether the minimum was found within tolerance. num_objective_evaluations: The total number of objective evaluations performed. position: A `Tensor` containing the last argument value found during the search. If the search converged, then this value is the argmin of the objective function. objective_value: A tensor containing the value of the objective function at the `position`. If the search converged, then this is the (local) minimum of the objective function. final_simplex: The last simplex constructed before stopping. final_objective_values: The objective function evaluated at the vertices of the final simplex. initial_simplex: The starting simplex. initial_objective_values: The objective function evaluated at the vertices of the initial simplex. num_iterations: The number of iterations of the main algorithm body. Raises: ValueError: If any of the following conditions hold: 1. If none or more than one of `initial_simplex` and `initial_vertex` are supplied. 2. If `initial_simplex` and `step_sizes` are both specified. """ with tf.name_scope(name or 'minimize'): (dim, _, simplex, objective_at_simplex, num_evaluations) = _prepare_args(objective_function, initial_simplex, initial_vertex, step_sizes, objective_at_initial_simplex, objective_at_initial_vertex, batch_evaluate_objective) domain_dtype = simplex.dtype (reflection, expansion, contraction, shrinkage) = _resolve_parameters(dim, reflection, expansion, contraction, shrinkage, domain_dtype) closure_kwargs = dict( objective_function=objective_function, dim=dim, func_tolerance=func_tolerance, position_tolerance=position_tolerance, batch_evaluate_objective=batch_evaluate_objective, reflection=reflection, expansion=expansion, contraction=contraction, shrinkage=shrinkage) def _loop_body(_, iterations, simplex, objective_at_simplex, num_evaluations): (converged, next_simplex, next_objective, evaluations) = nelder_mead_one_step(simplex, objective_at_simplex, **closure_kwargs) return (converged, iterations + 1, next_simplex, next_objective, num_evaluations + evaluations) initial_args = (False, 0, simplex, objective_at_simplex, num_evaluations) # Loop until either we have converged or if the max iterations are supplied # then until we have converged or exhausted the available iteration budget. def _is_converged(converged, num_iterations, *ignored_args): # pylint:disable=unused-argument # It is important to ensure that not_converged is a tensor.
If # converged is not a tensor but a Python bool, then the overloaded # op '~' acts as bitwise complement so ~True = -2 and ~False = -1. # In that case, the loop will never terminate. not_converged = tf.logical_not(converged) return (not_converged if max_iterations is None else (not_converged & (num_iterations < max_iterations))) (converged, num_iterations, final_simplex, final_objective_values, final_evaluations) = tf.while_loop( cond=_is_converged, body=_loop_body, loop_vars=initial_args, parallel_iterations=parallel_iterations) order = tf.argsort(final_objective_values, direction='ASCENDING', stable=True) best_index = order[0] # The explicit cast to Tensor below is done to avoid returning a mixture # of Python types and Tensors which cause problems with session.run. # In the eager mode, converged may remain a Python bool. Trying to evaluate # the whole tuple in one evaluate call will raise an exception because # of the presence of non-tensors. This is very annoying so we explicitly # cast those arguments to Tensors. return NelderMeadOptimizerResults( converged=tf.convert_to_tensor(converged), num_objective_evaluations=final_evaluations, position=final_simplex[best_index], objective_value=final_objective_values[best_index], final_simplex=final_simplex, final_objective_values=final_objective_values, num_iterations=tf.convert_to_tensor(num_iterations), initial_simplex=simplex, initial_objective_values=objective_at_simplex)
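# A quick illustration (not library code) of the pitfall `_is_converged`
# guards against: on a Python bool, `~` is bitwise complement rather than
# logical negation, so both `~True` (-2) and `~False` (-1) are truthy and a
# loop condition built that way would never terminate.
import tensorflow as tf

print(~True, ~False)                      # -2 -1, both truthy
print(tf.logical_not(tf.constant(True)))  # tf.Tensor(False, ...), as intended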
def _sample_paths(self, times, num_samples, random_type, skip, seed, normal_draws=None, times_grid=None, validate_args=False): """Returns a sample of paths from the process.""" # Note: all the notations below are the same as in [1]. num_requested_times = tf.shape(times)[0] params = [self._mean_reversion, self._volatility] if self._corr_matrix is not None: params = params + [self._corr_matrix] times, keep_mask = _prepare_grid( times, times_grid, *params) # Add zeros as a starting location dt = times[1:] - times[:-1] if dt.shape.is_fully_defined(): steps_num = dt.shape.as_list()[-1] else: steps_num = tf.shape(dt)[-1] # TODO(b/148133811): Re-enable Sobol test when TF 2.2 is released. if random_type == random.RandomType.SOBOL: raise ValueError('Sobol sequence for Euler sampling is temporarily ' 'unsupported when `time_step` or `times` have a ' 'non-constant value') if normal_draws is None: # In order to use low-discrepancy random_type we need to generate the # sequence of independent random normals upfront. We also precompute # random numbers for stateless random type in order to ensure independent # samples for multiple function calls whith different seeds. if random_type in (random.RandomType.SOBOL, random.RandomType.HALTON, random.RandomType.HALTON_RANDOMIZED, random.RandomType.STATELESS, random.RandomType.STATELESS_ANTITHETIC): normal_draws = utils.generate_mc_normal_draws( num_normal_draws=self._dim, num_time_steps=steps_num, num_sample_paths=num_samples, random_type=random_type, seed=seed, dtype=self._dtype, skip=skip) else: normal_draws = None else: if validate_args: draws_times = tf.shape(normal_draws)[0] asserts = tf.assert_equal( draws_times, tf.shape(times)[0] - 1, # We have added `0` to `times` message='`tf.shape(normal_draws)[1]` should be equal to the ' 'number of all `times` plus the number of all jumps of ' 'the piecewise constant parameters.') with tf.compat.v1.control_dependencies([asserts]): normal_draws = tf.identity(normal_draws) # The below is OK because we support exact discretization with piecewise # constant mr and vol. mean_reversion = self._mean_reversion(times) volatility = self._volatility(times) if self._corr_matrix is not None: corr_matrix = _get_parameters( times + tf.math.reduce_min(dt) / 2, self._corr_matrix)[0] corr_matrix_root = tf.linalg.cholesky(corr_matrix) else: corr_matrix_root = None exp_x_t = self._conditional_mean_x(times, mean_reversion, volatility) var_x_t = self._conditional_variance_x(times, mean_reversion, volatility) if self._dim == 1: mean_reversion = tf.expand_dims(mean_reversion, axis=0) cond_fn = lambda i, *args: i < tf.size(dt) def body_fn(i, written_count, current_x, rate_paths): """Simulate hull-white process to the next time point.""" if normal_draws is None: normals = random.mv_normal_sample( (num_samples,), mean=tf.zeros((self._dim,), dtype=mean_reversion.dtype), random_type=random_type, seed=seed) else: normals = normal_draws[i] if corr_matrix_root is not None: normals = tf.linalg.matvec(corr_matrix_root[i], normals) vol_x_t = tf.math.sqrt(tf.nn.relu(tf.transpose(var_x_t)[i])) # If numerically `vol_x_t == 0`, the gradient of `vol_x_t` becomes `NaN`. # To prevent this, we explicitly set `vol_x_t` to zero tensor at zero # values so that the gradient is set to zero at this values. 
vol_x_t = tf.where(vol_x_t > 0.0, vol_x_t, 0.0) next_x = (tf.math.exp(-tf.transpose(mean_reversion)[i + 1] * dt[i]) * current_x + tf.transpose(exp_x_t)[i] + vol_x_t * normals) f_0_t = self._instant_forward_rate_fn(times[i + 1]) # Update `rate_paths` rate_paths = utils.maybe_update_along_axis( tensor=rate_paths, do_update=keep_mask[i + 1], ind=written_count, axis=1, new_tensor=tf.expand_dims(next_x, axis=1) + f_0_t) written_count += tf.cast(keep_mask[i + 1], dtype=tf.int32) return (i + 1, written_count, next_x, rate_paths) rate_paths = tf.zeros((num_samples, num_requested_times, self._dim), dtype=self._dtype) # Include initial state, if necessary f0_t = self._instant_forward_rate_fn(times[0]) rate_paths = utils.maybe_update_along_axis( tensor=rate_paths, do_update=keep_mask[0], ind=0, axis=1, new_tensor=f0_t) written_count = tf.cast(keep_mask[0], dtype=tf.int32) initial_x = tf.zeros((num_samples, self._dim), dtype=self._dtype) # TODO(b/157232803): Use tf.cumsum instead? _, _, _, rate_paths = tf.while_loop( cond_fn, body_fn, (0, written_count, initial_x, rate_paths)) return rate_paths
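# A minimal NumPy sketch (not the class above) of the exact Ornstein-Uhlenbeck
# update at the core of `body_fn`, using assumed constant mean reversion `a`
# and volatility `sigma`, and taking the deterministic shift `exp_x_t` to be
# zero for simplicity:
#   x_{t+dt} | x_t ~ N(exp(-a*dt) * x_t, sigma**2 * (1 - exp(-2*a*dt)) / (2*a))
# The short rate is then `x_t` plus the instantaneous forward rate, here a
# hypothetical flat `f0`.
import numpy as np

a, sigma, f0 = 0.1, 0.01, 0.02        # hypothetical mean reversion, vol, forward
times = np.linspace(0.0, 5.0, 21)
num_samples = 10000

rng = np.random.default_rng(42)
x = np.zeros(num_samples)
rate_paths = np.empty((num_samples, times.size))
rate_paths[:, 0] = x + f0
for i, dt in enumerate(np.diff(times)):
  # Exact conditional standard deviation of the OU increment over `dt`.
  cond_std = sigma * np.sqrt((1.0 - np.exp(-2.0 * a * dt)) / (2.0 * a))
  x = np.exp(-a * dt) * x + cond_std * rng.standard_normal(num_samples)
  rate_paths[:, i + 1] = x + f0
print(rate_paths[:, -1].mean())  # ~f0, since E[x_t] = 0 in this sketch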
def __call__(self, momentum_parts, state_parts, target=None, target_grad_parts=None, name=None): """Applies `num_steps` of the leapfrog integrator. Args: momentum_parts: Python `list` of `Tensor`s representing momentum for each state part. state_parts: Python `list` of `Tensor`s which collectively represent the state. target: Batch of scalar `Tensor` representing the target (i.e., unnormalized log prob) evaluated at `state_parts`. target_grad_parts: Python `list` of `Tensor`s representing the gradient of `target` with respect to each of `state_parts`. name: Python `str` used to group ops created by this function. Returns: next_momentum_parts: Python `list` of `Tensor`s representing new momentum. next_state_parts: Python `list` of `Tensor`s which collectively represent the new state. next_target: Batch of scalar `Tensor` representing the target (i.e., unnormalized log prob) evaluated at `next_state_parts`. next_target_grad_parts: Python `list` of `Tensor`s representing the gradient of `next_target` with respect to each of `next_state_parts`. """ with tf.name_scope(name or 'leapfrog_integrate'): [ momentum_parts, state_parts, target, target_grad_parts, ] = process_args(self.target_fn, momentum_parts, state_parts, target, target_grad_parts) # See Algorithm 1 of "Faster Hamiltonian Monte Carlo by Learning Leapfrog # Scale", https://arxiv.org/abs/1810.04449. half_next_momentum_parts = [ v + tf.cast(0.5 * eps, v.dtype) * tf.cast(g, v.dtype) for v, eps, g in zip(momentum_parts, self.step_sizes, target_grad_parts) ] [ _, next_half_next_momentum_parts, next_state_parts, next_target, next_target_grad_parts, ] = tf.while_loop( cond=lambda i, *_: i < self.num_steps, body=lambda i, *args: [i + 1] + list( _one_step( # pylint: disable=no-value-for-parameter,g-long-lambda self.target_fn, self.step_sizes, *args)), loop_vars=[ tf.zeros_like(self.num_steps, name='iter'), half_next_momentum_parts, state_parts, target, target_grad_parts, ]) next_momentum_parts = [ v - tf.cast(0.5 * eps, v.dtype) * tf.cast(g, v.dtype) # pylint: disable=g-complex-comprehension for v, eps, g in zip(next_half_next_momentum_parts, self.step_sizes, next_target_grad_parts) ] return ( next_momentum_parts, next_state_parts, next_target, next_target_grad_parts, )
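# A minimal NumPy sketch (not the class above) of the same leapfrog structure:
# one initial half-step momentum kick, `num_steps` full position/momentum
# updates inside the loop, then a final minus-half kick so the momentum ends on
# an integer step. The standard-normal target here is a hypothetical stand-in
# for `target_fn`.
import numpy as np

def target_log_prob_and_grad(x):
  return -0.5 * np.sum(x * x), -x

def leapfrog(x, p, step_size, num_steps):
  _, grad = target_log_prob_and_grad(x)
  p = p + 0.5 * step_size * grad          # half-step momentum update
  for _ in range(num_steps):
    x = x + step_size * p                 # full position update
    _, grad = target_log_prob_and_grad(x)
    p = p + step_size * grad              # full momentum update
  p = p - 0.5 * step_size * grad          # roll back the extra half step
  return x, p

x, p = np.array([1.0, -2.0]), np.array([0.5, 0.3])
print(leapfrog(x, p, step_size=0.1, num_steps=10))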
def solve(self, ode_fn, initial_time, initial_state, solution_times, jacobian_fn=None, jacobian_sparsity=None, batch_ndims=None, previous_solver_internal_state=None): """See `tfp.math.ode.Solver.solve`.""" # The `solve` function is comprised of the following sequential stages: # (1) Make static assertions. # (2) Initialize variables. # (3) Make non-static assertions. # (4) Solve up to final time. # (5) Return `Results` object. # # The stages can be found in the code by searching for (n) where n=1..5. # # By static vs. non-static assertions (see stages 1 and 3), we mean # assertions that can be made before the graph is run vs. those that can # only be made at run time. The latter are constructed as a list of # tf.Assert operations by the function `assert_ops` (see below). # # If `solution_times` is specified as a `Tensor`, stage 4 consists of three # nested loops, which can be conceptually understood as follows: # ``` # current_time, current_state = initial_time, initial_state # order, step_size = 1, first_step_size # for solution_time in solution_times: # while current_time < solution_time: # while True: # next_time = current_time + step_size # next_state, error = ( # solve_nonlinear_equation_to_get_approximate_state_at_next_time( # current_time, current_state, next_time, order)) # if error < tolerance: # current_time, current_state = next_time, next_state # order, step_size = ( # maybe_update_order_and_step_size(order, step_size)) # break # else: # step_size = decrease_step_size(step_size) # ``` # The outermost loop advances the solver to the next `solution_time` (see # `advance_to_solution_time`). The middle loop advances the solver by a # small timestep (see `step`). The innermost loop determines the size of # that timestep (see `maybe_step`). # # If `solution_times` is specified as # `tfp.math.ode.ChosenBySolver(final_time)`, the outermost loop is skipped # and `solution_time` in the middle loop is replaced by `final_time`. 
def assert_ops(): """Creates a list of assert operations.""" if not self._validate_args: return [] assert_ops = [] if ((not initial_state_missing) and (previous_solver_internal_state is not None)): assert_initial_state_matches_previous_solver_internal_state = ( tf.assert_near( tf.norm( original_initial_state - previous_solver_internal_state.backward_differences[0], np.inf), 0., message='`previous_solver_internal_state` does not match ' '`initial_state`.')) assert_ops.append( assert_initial_state_matches_previous_solver_internal_state) if solution_times_chosen_by_solver: assert_ops.append( util.assert_positive(final_time - initial_time, 'final_time - initial_time')) else: assert_ops += [ util.assert_increasing(solution_times, 'solution_times'), util.assert_nonnegative(solution_times[0] - initial_time, 'solution_times[0] - initial_time'), ] if max_num_steps is not None: assert_ops.append(util.assert_positive(max_num_steps, 'max_num_steps')) if max_num_newton_iters is not None: assert_ops.append( util.assert_positive(max_num_newton_iters, 'max_num_newton_iters')) assert_ops += [ util.assert_positive(rtol, 'rtol'), util.assert_positive(atol, 'atol'), util.assert_positive(first_step_size, 'first_step_size'), util.assert_positive(safety_factor, 'safety_factor'), util.assert_positive(min_step_size_factor, 'min_step_size_factor'), util.assert_positive(max_step_size_factor, 'max_step_size_factor'), tf.Assert((max_order >= 1) & (max_order <= bdf_util.MAX_ORDER), [ '`max_order` must be between 1 and {}.'.format(bdf_util.MAX_ORDER) ]), util.assert_positive(newton_tol_factor, 'newton_tol_factor'), util.assert_positive(newton_step_size_factor, 'newton_step_size_factor'), ] return assert_ops def advance_to_solution_time(n, diagnostics, iterand, solver_internal_state, states_array, times_array): """Takes multiple steps to advance time to `solution_times[n]`.""" def step_cond(next_time, diagnostics, iterand, *_): return (iterand.time < next_time) & (tf.equal(diagnostics.status, 0)) solution_times_n = solution_times_array.read(n) [ _, diagnostics, iterand, solver_internal_state, states_array, times_array ] = tf.while_loop(step_cond, step, [ solution_times_n, diagnostics, iterand, solver_internal_state, states_array, times_array ]) states_array = states_array.write( n, solver_internal_state.backward_differences[0]) times_array = times_array.write(n, solution_times_n) return (n + 1, diagnostics, iterand, solver_internal_state, states_array, times_array) def step(next_time, diagnostics, iterand, solver_internal_state, states_array, times_array): """Takes a single step.""" distance_to_next_time = next_time - iterand.time overstepped = iterand.new_step_size > distance_to_next_time iterand = iterand._replace( new_step_size=tf1.where(overstepped, distance_to_next_time, iterand.new_step_size), should_update_step_size=overstepped | iterand.should_update_step_size) if not self._evaluate_jacobian_lazily: diagnostics = diagnostics._replace( num_jacobian_evaluations=diagnostics.num_jacobian_evaluations + 1) iterand = iterand._replace( jacobian=jacobian_fn_mat( iterand.time, solver_internal_state.backward_differences[0]), jacobian_is_up_to_date=True) def maybe_step_cond(accepted, diagnostics, *_): return tf.logical_not(accepted) & tf.equal(diagnostics.status, 0) _, diagnostics, iterand, solver_internal_state = tf.while_loop( maybe_step_cond, maybe_step, [False, diagnostics, iterand, solver_internal_state]) if solution_times_chosen_by_solver: states_array = states_array.write( states_array.size(), 
solver_internal_state.backward_differences[0]) times_array = times_array.write(times_array.size(), iterand.time) return (next_time, diagnostics, iterand, solver_internal_state, states_array, times_array) def maybe_step(accepted, diagnostics, iterand, solver_internal_state): """Takes a single step only if the outcome has a low enough error.""" [ num_jacobian_evaluations, num_matrix_factorizations, num_ode_fn_evaluations, status ] = diagnostics [ jacobian, jacobian_is_up_to_date, new_step_size, num_steps, num_steps_same_size, should_update_jacobian, should_update_step_size, time, unitary, upper ] = iterand backward_differences, order, state_shape, step_size = solver_internal_state if max_num_steps is not None: status = tf1.where(tf.equal(num_steps, max_num_steps), -1, 0) backward_differences = tf1.where( should_update_step_size, bdf_util.interpolate_backward_differences(backward_differences, order, new_step_size / step_size), backward_differences) step_size = tf1.where(should_update_step_size, new_step_size, step_size) should_update_factorization = should_update_step_size num_steps_same_size = tf1.where(should_update_step_size, 0, num_steps_same_size) def update_factorization(): return bdf_util.newton_qr(jacobian, newton_coefficients_array.read(order), step_size) if self._evaluate_jacobian_lazily: def update_jacobian_and_factorization(): new_jacobian = jacobian_fn_mat(time, backward_differences[0]) new_unitary, new_upper = update_factorization() return [ new_jacobian, True, num_jacobian_evaluations + 1, new_unitary, new_upper ] def maybe_update_factorization(): new_unitary, new_upper = tf.cond( should_update_factorization, update_factorization, lambda: [unitary, upper]) return [ jacobian, jacobian_is_up_to_date, num_jacobian_evaluations, new_unitary, new_upper ] [ jacobian, jacobian_is_up_to_date, num_jacobian_evaluations, unitary, upper ] = tf.cond(should_update_jacobian, update_jacobian_and_factorization, maybe_update_factorization) else: unitary, upper = update_factorization() num_matrix_factorizations += 1 tol = atol + rtol * tf.abs(backward_differences[0]) newton_tol = newton_tol_factor * tf.norm(tol) [ newton_converged, next_backward_difference, next_state, newton_num_iters ] = bdf_util.newton(backward_differences, max_num_newton_iters, newton_coefficients_array.read(order), ode_fn_vec, order, step_size, time, newton_tol, unitary, upper) num_steps += 1 num_ode_fn_evaluations += newton_num_iters # If Newton's method failed and the Jacobian was up to date, decrease the # step size. newton_failed = tf.logical_not(newton_converged) should_update_step_size = newton_failed & jacobian_is_up_to_date new_step_size = step_size * tf1.where(should_update_step_size, newton_step_size_factor, 1.) # If Newton's method failed and the Jacobian was NOT up to date, update # the Jacobian. should_update_jacobian = newton_failed & tf.logical_not( jacobian_is_up_to_date) error_ratio = tf1.where( newton_converged, bdf_util.error_ratio(next_backward_difference, error_coefficients_array.read(order), tol), np.nan) accepted = error_ratio < 1. converged_and_rejected = newton_converged & tf.logical_not(accepted) # If Newton's method converged but the solution was NOT accepted, decrease # the step size. 
new_step_size = tf1.where( converged_and_rejected, util.next_step_size(step_size, order, error_ratio, safety_factor, min_step_size_factor, max_step_size_factor), new_step_size) should_update_step_size = should_update_step_size | converged_and_rejected # If Newton's method converged and the solution was accepted, update the # matrix of backward differences. time = tf1.where(accepted, time + step_size, time) backward_differences = tf1.where( accepted, bdf_util.update_backward_differences(backward_differences, next_backward_difference, next_state, order), backward_differences) jacobian_is_up_to_date = jacobian_is_up_to_date & tf.logical_not(accepted) num_steps_same_size = tf1.where(accepted, num_steps_same_size + 1, num_steps_same_size) # Order and step size are only updated if we have taken strictly more than # order + 1 steps of the same size. This is to prevent the order from # being throttled. should_update_order_and_step_size = accepted & ( num_steps_same_size > order + 1) backward_differences_array = tf.TensorArray( backward_differences.dtype, size=bdf_util.MAX_ORDER + 3, clear_after_read=False, element_shape=next_backward_difference.get_shape()).unstack( backward_differences) new_order = order new_error_ratio = error_ratio for offset in [-1, +1]: proposed_order = tf.clip_by_value(order + offset, 1, max_order) proposed_error_ratio = bdf_util.error_ratio( backward_differences_array.read(proposed_order + 1), error_coefficients_array.read(proposed_order), tol) proposed_error_ratio_is_lower = proposed_error_ratio < new_error_ratio new_order = tf1.where( should_update_order_and_step_size & proposed_error_ratio_is_lower, proposed_order, new_order) new_error_ratio = tf1.where( should_update_order_and_step_size & proposed_error_ratio_is_lower, proposed_error_ratio, new_error_ratio) order = new_order error_ratio = new_error_ratio new_step_size = tf1.where( should_update_order_and_step_size, util.next_step_size(step_size, order, error_ratio, safety_factor, min_step_size_factor, max_step_size_factor), new_step_size) should_update_step_size = ( should_update_step_size | should_update_order_and_step_size) diagnostics = _BDFDiagnostics(num_jacobian_evaluations, num_matrix_factorizations, num_ode_fn_evaluations, status) iterand = _BDFIterand(jacobian, jacobian_is_up_to_date, new_step_size, num_steps, num_steps_same_size, should_update_jacobian, should_update_step_size, time, unitary, upper) solver_internal_state = _BDFSolverInternalState(backward_differences, order, state_shape, step_size) return accepted, diagnostics, iterand, solver_internal_state # (1) Make static assertions. # TODO(parsiad): Support specifying Jacobian sparsity patterns. if jacobian_sparsity is not None: raise NotImplementedError('The BDF solver does not support specifying ' 'Jacobian sparsity patterns.') if batch_ndims is not None and batch_ndims != 0: raise NotImplementedError('The BDF solver does not support batching.') solution_times_chosen_by_solver = ( isinstance(solution_times, base.ChosenBySolver)) initial_state_missing = initial_state is None if initial_state_missing and previous_solver_internal_state is None: raise ValueError( 'At least one of `initial_state` or `previous_solver_internal_state` ' 'must be specified') with tf.name_scope(self._name): # (2) Initialize variables. 
original_initial_state = initial_state if previous_solver_internal_state is None: initial_state = tf.convert_to_tensor(initial_state) original_state_shape = tf.shape(initial_state) else: initial_state = previous_solver_internal_state.backward_differences[0] original_state_shape = previous_solver_internal_state.state_shape state_dtype = initial_state.dtype util.error_if_not_real_or_complex(initial_state, 'initial_state') # TODO(parsiad): Support complex automatic Jacobians. if jacobian_fn is None and state_dtype.is_complex: raise NotImplementedError('The BDF solver does not support automatic ' 'Jacobian computations for complex dtypes.') num_odes = tf.size(initial_state) original_state_tensor_shape = initial_state.get_shape() initial_state = tf.reshape(initial_state, [-1]) ode_fn_vec = util.get_ode_fn_vec(ode_fn, original_state_shape) # `real_dtype` is the floating point `dtype` associated with # `initial_state.dtype` (recall that the latter can be complex). real_dtype = tf.abs(initial_state).dtype initial_time = tf.ensure_shape( tf.convert_to_tensor(initial_time, dtype=real_dtype), []) num_solution_times = 0 if solution_times_chosen_by_solver: final_time = solution_times.final_time final_time = tf.ensure_shape( tf.convert_to_tensor(final_time, dtype=real_dtype), []) else: solution_times = tf.convert_to_tensor(solution_times, dtype=real_dtype) num_solution_times = tf.size(solution_times) solution_times_array = tf.TensorArray( solution_times.dtype, size=num_solution_times, element_shape=[]).unstack(solution_times) util.error_if_not_vector(solution_times, 'solution_times') jacobian_fn_mat = util.get_jacobian_fn_mat( jacobian_fn, ode_fn_vec, original_state_shape, use_pfor=self._use_pfor_to_compute_jacobian) rtol = tf.convert_to_tensor(self._rtol, dtype=real_dtype) atol = tf.convert_to_tensor(self._atol, dtype=real_dtype) safety_factor = tf.ensure_shape( tf.convert_to_tensor(self._safety_factor, dtype=real_dtype), []) min_step_size_factor = tf.ensure_shape( tf.convert_to_tensor(self._min_step_size_factor, dtype=real_dtype), []) max_step_size_factor = tf.ensure_shape( tf.convert_to_tensor(self._max_step_size_factor, dtype=real_dtype), []) max_num_steps = self._max_num_steps if max_num_steps is not None: max_num_steps = tf.convert_to_tensor(max_num_steps, dtype=tf.int32) max_order = tf.convert_to_tensor(self._max_order, dtype=tf.int32) max_num_newton_iters = self._max_num_newton_iters if max_num_newton_iters is not None: max_num_newton_iters = tf.convert_to_tensor( max_num_newton_iters, dtype=tf.int32) newton_tol_factor = tf.ensure_shape( tf.convert_to_tensor(self._newton_tol_factor, dtype=real_dtype), []) newton_step_size_factor = tf.ensure_shape( tf.convert_to_tensor(self._newton_step_size_factor, dtype=real_dtype), []) bdf_coefficients = tf.cast( tf.concat( [[0.], tf.convert_to_tensor(self._bdf_coefficients, dtype=real_dtype)], 0), state_dtype) util.error_if_not_vector(bdf_coefficients, 'bdf_coefficients') newton_coefficients = 1. / ( (1. - bdf_coefficients) * bdf_util.RECIPROCAL_SUMS) newton_coefficients_array = tf.TensorArray( newton_coefficients.dtype, size=bdf_util.MAX_ORDER + 1, clear_after_read=False, element_shape=[]).unstack(newton_coefficients) error_coefficients = bdf_coefficients * bdf_util.RECIPROCAL_SUMS + 1. 
/ ( bdf_util.ORDERS + 1) error_coefficients_array = tf.TensorArray( error_coefficients.dtype, size=bdf_util.MAX_ORDER + 1, clear_after_read=False, element_shape=[]).unstack(error_coefficients) first_step_size = self._first_step_size if first_step_size is None: first_step_size = bdf_util.first_step_size( atol, error_coefficients_array.read(1), initial_state, initial_time, ode_fn_vec, rtol, safety_factor) elif previous_solver_internal_state is not None: tf.logging.warn('`first_step_size` is ignored since' '`previous_solver_internal_state` was specified.') first_step_size = tf.convert_to_tensor(first_step_size, dtype=real_dtype) if self._validate_args: if max_num_steps is not None: max_num_steps = tf.ensure_shape(max_num_steps, []) max_order = tf.ensure_shape(max_order, []) if max_num_newton_iters is not None: max_num_newton_iters = tf.ensure_shape(max_num_newton_iters, []) bdf_coefficients = tf.ensure_shape(bdf_coefficients, [6]) first_step_size = tf.ensure_shape(first_step_size, []) solver_internal_state = previous_solver_internal_state if solver_internal_state is None: first_order_backward_difference = ode_fn_vec( initial_time, initial_state) * tf.cast(first_step_size, state_dtype) backward_differences = tf.concat([ tf.reshape(initial_state, [1, -1]), first_order_backward_difference[tf.newaxis, :], tf.zeros( tf.stack([bdf_util.MAX_ORDER + 1, num_odes]), dtype=state_dtype), ], 0) solver_internal_state = _BDFSolverInternalState( backward_differences=backward_differences, order=1, state_shape=original_state_shape, step_size=first_step_size) states_array = tf.TensorArray( state_dtype, size=num_solution_times, dynamic_size=solution_times_chosen_by_solver, element_shape=initial_state.get_shape()) times_array = tf.TensorArray( real_dtype, size=num_solution_times, dynamic_size=solution_times_chosen_by_solver, element_shape=tf.TensorShape([])) diagnostics = _BDFDiagnostics( num_jacobian_evaluations=0, num_matrix_factorizations=0, num_ode_fn_evaluations=0, status=0) iterand = _BDFIterand( jacobian=tf.zeros([num_odes, num_odes], dtype=state_dtype), jacobian_is_up_to_date=False, new_step_size=solver_internal_state.step_size, num_steps=0, num_steps_same_size=0, should_update_jacobian=True, should_update_step_size=False, time=initial_time, unitary=tf.zeros([num_odes, num_odes], dtype=state_dtype), upper=tf.zeros([num_odes, num_odes], dtype=state_dtype)) # (3) Make non-static assertions. with tf.control_dependencies(assert_ops()): # (4) Solve up to final time. if solution_times_chosen_by_solver: def step_cond(next_time, diagnostics, iterand, *_): return (iterand.time < next_time) & ( tf.equal(diagnostics.status, 0)) [ _, diagnostics, iterand, solver_internal_state, states_array, times_array ] = tf.while_loop(step_cond, step, [ final_time, diagnostics, iterand, solver_internal_state, states_array, times_array ]) else: def advance_to_solution_time_cond(n, diagnostics, *_): return (n < num_solution_times) & (tf.equal(diagnostics.status, 0)) [ _, diagnostics, iterand, solver_internal_state, states_array, times_array ] = tf.while_loop(advance_to_solution_time_cond, advance_to_solution_time, [ 0, diagnostics, iterand, solver_internal_state, states_array, times_array ]) # (6) Return `Results` object. 
states = tf.reshape(states_array.stack(), tf.concat([[-1], original_state_shape], 0)) times = times_array.stack() if not solution_times_chosen_by_solver: times.set_shape(solution_times.get_shape()) states.set_shape(solution_times.get_shape().concatenate( original_state_tensor_shape)) return base.Results( times=times, states=states, diagnostics=diagnostics, solver_internal_state=solver_internal_state)
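# A short usage sketch for the solver above; the stiff linear ODE and the
# requested solution times are illustrative choices, not canonical defaults.
import tensorflow as tf
import tensorflow_probability as tfp

def ode_fn(t, y):
  # dy/dt = A y with widely separated decay rates (a mildly stiff system).
  matrix = tf.constant([[-10.0, 0.0], [0.0, -0.1]], dtype=tf.float64)
  return tf.linalg.matvec(matrix, y)

results = tfp.math.ode.BDF().solve(
    ode_fn,
    initial_time=0.0,
    initial_state=tf.constant([1.0, 1.0], dtype=tf.float64),
    solution_times=[0.1, 1.0, 10.0])
print(results.times)   # the requested solution times
print(results.states)  # shape [3, 2]; the state at each requested time
print(results.diagnostics.num_ode_fn_evaluations)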
def one_step(self, current_state, previous_kernel_results, seed=None): seed = samplers.sanitize_seed(seed) # Retain for diagnostics. start_trajectory_seed, loop_seed = samplers.split_seed(seed) with tf.name_scope(self.name + '.one_step'): unwrap_state_list = not tf.nest.is_nested(current_state) if unwrap_state_list: current_state = [current_state] current_target_log_prob = previous_kernel_results.target_log_prob [init_momentum, init_energy, log_slice_sample ] = self._start_trajectory_batched(current_state, current_target_log_prob, seed=start_trajectory_seed) def _copy(v): return v * ps.ones(ps.pad( [2], paddings=[[0, ps.rank(v)]], constant_values=1), dtype=v.dtype) initial_state = TreeDoublingState( momentum=init_momentum, state=current_state, target=current_target_log_prob, target_grad_parts=previous_kernel_results.grads_target_log_prob ) initial_step_state = tf.nest.map_structure(_copy, initial_state) if MULTINOMIAL_SAMPLE: init_weight = tf.zeros_like(init_energy) # log(exp(H0 - H0)) else: init_weight = tf.ones_like(init_energy, dtype=TREE_COUNT_DTYPE) candidate_state = TreeDoublingStateCandidate( state=current_state, target=current_target_log_prob, target_grad_parts=previous_kernel_results. grads_target_log_prob, energy=init_energy, weight=init_weight) initial_step_metastate = TreeDoublingMetaState( candidate_state=candidate_state, is_accepted=tf.zeros_like(init_energy, dtype=tf.bool), momentum_sum=init_momentum, energy_diff_sum=tf.zeros_like(init_energy), leapfrog_count=tf.zeros_like(init_energy, dtype=TREE_COUNT_DTYPE), continue_tree=tf.ones_like(init_energy, dtype=tf.bool), not_divergence=tf.ones_like(init_energy, dtype=tf.bool)) # Convert the write/read instruction into TensorArray so that it is # compatible with XLA. write_instruction = tf.TensorArray( TREE_COUNT_DTYPE, size=len(self._write_instruction), clear_after_read=False).unstack(self._write_instruction) read_instruction = tf.TensorArray(tf.int32, size=len(self._read_instruction), clear_after_read=False).unstack( self._read_instruction) current_step_meta_info = OneStepMetaInfo( log_slice_sample=log_slice_sample, init_energy=init_energy, write_instruction=write_instruction, read_instruction=read_instruction) _, _, _, new_step_metastate = tf.while_loop( cond=lambda iter_, seed, state, metastate: ( # pylint: disable=g-long-lambda (iter_ < self.max_tree_depth) & tf.reduce_any( metastate.continue_tree)), body=lambda iter_, seed, state, metastate: self. _loop_tree_doubling( # pylint: disable=g-long-lambda previous_kernel_results.step_size, previous_kernel_results. momentum_state_memory, current_step_meta_info, iter_, state, metastate, seed), loop_vars=(tf.zeros([], dtype=tf.int32, name='iter'), loop_seed, initial_step_state, initial_step_metastate), parallel_iterations=self.parallel_iterations, ) kernel_results = NUTSKernelResults( target_log_prob=new_step_metastate.candidate_state.target, grads_target_log_prob=( new_step_metastate.candidate_state.target_grad_parts), momentum_state_memory=previous_kernel_results. 
momentum_state_memory, step_size=previous_kernel_results.step_size, log_accept_ratio=tf.math.log( new_step_metastate.energy_diff_sum / tf.cast(new_step_metastate.leapfrog_count, dtype=new_step_metastate.energy_diff_sum.dtype)), leapfrogs_taken=(new_step_metastate.leapfrog_count * self.unrolled_leapfrog_steps), is_accepted=new_step_metastate.is_accepted, reach_max_depth=new_step_metastate.continue_tree, has_divergence=~new_step_metastate.not_divergence, energy=new_step_metastate.candidate_state.energy, seed=seed, ) result_state = new_step_metastate.candidate_state.state if unwrap_state_list: result_state = result_state[0] return result_state, kernel_results
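# A short usage sketch for the kernel above (tfp.mcmc.NoUTurnSampler); the
# target distribution, step size, and chain lengths are illustrative choices.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

target = tfd.MultivariateNormalDiag(loc=[0., 0.], scale_diag=[1., 2.])
kernel = tfp.mcmc.NoUTurnSampler(
    target_log_prob_fn=target.log_prob, step_size=0.3)

samples, is_accepted = tfp.mcmc.sample_chain(
    num_results=500,
    num_burnin_steps=200,
    current_state=tf.zeros([2]),
    kernel=kernel,
    trace_fn=lambda _, pkr: pkr.is_accepted)
print(tf.reduce_mean(samples, axis=0))                   # ~[0., 0.]
print(tf.reduce_mean(tf.cast(is_accepted, tf.float32)))  # acceptance fraction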