def expand_dims_(x):
  """Reshapes `x` so that size-1 dimensions are inserted for each `axis` entry.

  `axis` and `name` are not parameters of this function; they are read from
  the enclosing scope (not visible in this block).

  Args:
    x: Value convertible to a `Tensor`.

  Returns:
    `x` reshaped to its original shape with one extra singleton dimension per
    element of `axis`.
  """
  with tf.name_scope(name or 'expand_dims'):
    x = tf.convert_to_tensor(x, name='x')
    axes = tf.convert_to_tensor(axis, dtype_hint=tf.int32, name='axis')
    x_rank = prefer_static.rank(x)
    num_axes = prefer_static.size(axes)

    # Count the negative axes, then shift them by the rank of `x`.
    negative_mask = axes < 0
    num_negative = prefer_static.reduce_sum(
        prefer_static.cast(negative_mask, axes.dtype))
    shifted_axes = prefer_static.where(negative_mask, axes + x_rank, axes)
    sorted_axes = prefer_static.sort(shifted_axes)
    leading_axes, trailing_axes = prefer_static.split(
        sorted_axes, [num_negative, -1])

    # A stable argsort over these keys yields the gather order that
    # interleaves the padded 1s with the original dimensions.
    sort_keys = prefer_static.concat(
        [trailing_axes, prefer_static.range(x_rank), leading_axes], axis=0)
    gather_order = prefer_static.argsort(sort_keys, stable=True)

    # Pad the shape with 1s on both sides, then permute them into place.
    padded_shape = prefer_static.pad(
        prefer_static.shape(x),
        paddings=[[num_axes - num_negative, num_negative]],
        constant_values=1)
    return tf.reshape(x, prefer_static.gather(padded_shape, gather_order))
def _forward_event_shape_tensor(self, input_shape, is_inverse=False):
  """Adjusts `input_shape` by the total padding along each of `self.axis`.

  Args:
    input_shape: 1-D integer shape vector.
    is_inverse: Python `bool`. When `True` the padding sizes are subtracted
      from the affected dimensions instead of added.

  Returns:
    A copy of `input_shape` with the summed `self.paddings` added to (or
    subtracted from) the dimensions indexed by `self.axis + size(input_shape)`.
  """
  if is_inverse:
    scatter_update = ps.tensor_scatter_nd_sub
  else:
    scatter_update = ps.tensor_scatter_nd_add

  rank = ps.size(input_shape)
  # NOTE(review): offsetting by the rank presumes `self.axis` holds negative
  # indices -- confirm against the class constructor.
  affected_dims = ps.reshape(ps.add(self.axis, rank), shape=[-1, 1])
  # Sum over the last axis of `paddings` to get the total growth per axis.
  padding_totals = ps.reduce_sum(self.paddings, axis=-1)
  return scatter_update(
      ps.identity(input_shape), affected_dims, padding_totals)
def moments_of_masked_time_series(time_series_tensor, broadcast_mask):
  """Computes mean and variance over the last axis, ignoring masked entries.

  Entries where `broadcast_mask` is `True` are excluded: they contribute zero
  to the sums and are not counted in the denominator.

  Args:
    time_series_tensor: float `Tensor` time series of shape
      `concat([batch_shape, [num_timesteps]])`.
    broadcast_mask: bool `Tensor` of the same shape as `time_series_tensor`.

  Returns:
    mean: float `Tensor` of shape `batch_shape`.
    variance: float `Tensor` of shape `batch_shape`.
  """
  dtype = time_series_tensor.dtype
  zero = tf.zeros([], dtype=dtype)
  unmasked_count = ps.cast(
      ps.reduce_sum(ps.cast(~broadcast_mask, np.int32), axis=-1), dtype)

  def _unmasked_average(values):
    # Zero out the masked positions so they drop out of the sum.
    kept = tf.where(broadcast_mask, zero, values)
    return tf.reduce_sum(kept, axis=-1) / unmasked_count

  mean = _unmasked_average(time_series_tensor)
  squared_deviation = (time_series_tensor - mean[..., tf.newaxis])**2
  variance = _unmasked_average(squared_deviation)
  return mean, variance
def expand_dims(x, axis, name=None):
  """Like `tf.expand_dims` but accepts a vector of axes to expand.

  Args:
    x: Value convertible to a `Tensor`.
    axis: Integer scalar or vector of positions at which to insert size-1
      dimensions; entries may be negative.
    name: Python `str` name for the op scope. Defaults to 'expand_dims'.

  Returns:
    `x` reshaped so that one singleton dimension is added per entry of `axis`.
  """
  with tf.name_scope(name or 'expand_dims'):
    x = tf.convert_to_tensor(x, name='x')
    requested = tf.convert_to_tensor(axis, dtype_hint=tf.int32, name='axis')
    rank_x = prefer_static.rank(x)
    n_new = prefer_static.size(requested)

    # How many of the requested axes are negative.
    was_negative = requested < 0
    n_negative = prefer_static.reduce_sum(
        prefer_static.cast(was_negative, requested.dtype))

    # Shift negative axes by the rank, sort, and split into two groups.
    ordered = prefer_static.sort(
        prefer_static.where(was_negative, requested + rank_x, requested))
    head, tail = prefer_static.split(ordered, [n_negative, -1])

    # The stable argsort of the keys gives the final dimension ordering.
    keys = prefer_static.concat(
        [tail, prefer_static.range(rank_x), head], axis=0)
    order = prefer_static.argsort(keys, stable=True)

    # Surround the original shape with 1s and shuffle them into position.
    ones_padded = prefer_static.pad(
        prefer_static.shape(x),
        paddings=[[n_new - n_negative, n_negative]],
        constant_values=1)
    target_shape = prefer_static.gather(ones_padded, order)
    return tf.reshape(x, target_shape)
def _sample_bates(total_count, low, high, n, seed=None):
  """Draws `Bates` samples for a whole batch in one vectorized pass.

  Args:
    total_count: (Batches of) counts of `Uniform`s to take means of. Should
      have integer dtype and already be broadcasted to the batch shape.
    low: (Batches of) lower bounds of the `Uniform` variables to sample. Should
      be the same floating dtype as `high` and broadcastable to the batch
      shape.
    high: (Batches of) upper bounds of the `Uniform` variables to sample.
      Should be the same floating dtype as `low` and broadcastable to the batch
      shape.
    n: `int32` number of samples to generate.
    seed: Random seed to pass to `Uniform` sampler.

  Returns:
    samples: Samples of (batches of) the `Bates` variable. Will have same
      dtype as `low` and `high`. If the batch shape is `[B1,..., Bn]`,
      `samples` has shape `[n, B1,..., Bn]`.
  """
  # Step 1: draw all required Uniform(0, 1) values at once. Axis 0 enumerates
  # every uniform of every batch member; axis 1 indexes the `n` samples.
  num_uniforms = ps.reduce_sum(total_count)
  draws = samplers.uniform(
      ps.concat([[num_uniforms], [n]], axis=0),
      minval=0.,
      maxval=1.,
      dtype=low.dtype,
      seed=seed)

  # Step 2: average the consecutive rows belonging to each batch member.
  counts_flat = tf.reshape(total_count, [-1])
  member_ids = tf.repeat(tf.range(tf.size(counts_flat)), counts_flat)
  flat_means = tf.math.segment_mean(draws, member_ids)

  # Step 3: restore the batch shape (sample axis last), then rotate the sample
  # axis to the front.
  batch_major = tf.reshape(
      flat_means, tf.concat([tf.shape(total_count), [n]], axis=0))
  rotation = tf.roll(tf.range(tf.rank(batch_major)), shift=1, axis=0)
  unit_means = tf.transpose(batch_major, rotation)

  # Step 4: affine map from (0, 1) onto (low, high).
  return low + (high - low) * unit_means
def _calculate_batch_shape(self):
  """Computes the batch shape of the combined distribution.

  Along `self._axis` the result is the sum of the component batch sizes; along
  every other dimension it is the maximum over the components.

  Returns:
    Integer shape vector.
  """
  # Use the static shape when it is fully known, else the dynamic shape.
  component_shapes = []
  for dist in self.distributions:
    if tensorshape_util.is_fully_defined(dist.batch_shape):
      component_shapes.append(dist.batch_shape.as_list())
    else:
      component_shapes.append(dist.batch_shape_tensor())

  # One row per component distribution.
  stacked = ps.stack(component_shapes, axis=0)
  num_batch_dims = ps.shape(stacked)[1]
  is_concat_dim = ps.cast(ps.one_hot(self._axis, num_batch_dims), dtype=tf.bool)
  concat_size = ps.cast(
      ps.reduce_sum(stacked, axis=0)[self._axis], dtype=tf.int32)
  max_sizes = ps.reduce_max(stacked, axis=0)
  return ps.where(is_concat_dim, concat_size, max_sizes)
def reduce_fn(operands, inits, axis=None, keepdims=False):
  """Applies `reducer` to the given operands along the given axes.

  `reducer` is not a parameter; it is read from the enclosing scope (not
  visible in this block).

  Args:
    operands: tuple of tensors, all having the same shape.
    inits: tuple of scalar tensors, with dtypes aligned to those of operands.
    axis: The axis or axes to reduce. One of `None`, an `int` or a sequence of
      `int`. `None` is taken to mean "reduce all axes".
    keepdims: When `True`, we do not squeeze away the reduced dims, instead
      returning values with singleton dims in those axes.

  Returns:
    reduced: A tuple of the reduced operands.

  Raises:
    ValueError: if the merged static rank of `operands` is unknown, or if
      `axis` has more than one dimension.
  """
  # Merge the static shapes of all operands into one.
  merged_shape = operands[0].shape
  for operand in operands[1:]:
    merged_shape = tensorshape_util.merge_with(merged_shape, operand.shape)
  rank = tensorshape_util.rank(merged_shape)
  if rank is None:
    raise ValueError(
        'Rank of at least one of `operands` must be known statically.')

  # Canonicalize `axis` into a tuple of non-negative Python ints.
  if axis is None:
    axis_array = np.arange(rank)
  else:
    axis_array = np.array(axis)
  if axis_array.ndim > 1:
    raise ValueError(
        '`axis` must be `None`, an `int`, or a sequence of '
        '`int`, but got {}'.format(axis_array))
  flat_axes = np.reshape(axis_array, [-1])
  wrapped_axes = np.where(flat_axes < 0, flat_axes + rank, flat_axes)
  axis = tuple(map(int, wrapped_axes))

  # Per-dimension marker that is non-zero for each reduced dimension.
  reduced_marker = ps.reduce_sum(
      ps.one_hot(axis, depth=rank, on_value=True, off_value=False,
                 dtype=tf.bool),
      axis=0)

  # Fall back to the dynamic shape when the static one has unknown dims.
  if tensorshape_util.is_fully_defined(merged_shape):
    full_shape = merged_shape
  else:
    full_shape = tf.shape(operands[0])
  unsqueezed_shape = ps.where(reduced_marker, 1, full_shape)

  result = _variadic_reduce_custom_grad(
      operands, inits, axis, reducer, unsqueezed_shape)
  if not keepdims:
    return result
  return tf.nest.map_structure(
      lambda t: tf.reshape(t, unsqueezed_shape), result)
def preprocess_state(init_state):
  """Builds the stage-0 `SMCResults` for the initial particles.

  Evaluates the likelihood at `init_state`, chooses an initial proposal
  scaling, constructs a kernel targeting the tempered density at inverse
  temperature 0, and bootstraps that kernel on `init_state`.

  Besides `init_state`, this function reads `likelihood_log_prob_fn`,
  `prior_log_prob_fn`, `make_kernel_fn`, `seed_stream` and `max_num_steps`
  from the enclosing scope (not visible in this block).

  Args:
    init_state: Python list of `Tensor`s holding the initial particles. It is
      unpacked with `*` when passed to `likelihood_log_prob_fn`.

  Returns:
    `SMCResults` whose `inverse_temperature` and `log_marginal_likelihood` are
    scalar zeros and whose `particle_info` describes `init_state`.
  """
  # Count the state elements per particle; the leading axis of each state
  # part is skipped (assumed to index particles -- TODO confirm).
  dimension = ps.reduce_sum([
      ps.reduce_prod(ps.shape(x)[1:]) for x in init_state])
  likelihood_log_prob = likelihood_log_prob_fn(*init_state)

  # Default to the optimal for normal distributed targets.
  # TODO(b/152412213): Revisit this default parameter.
  scale_start = (
      tf.constant(2.38 ** 2, dtype=likelihood_log_prob.dtype) /
      tf.constant(dimension, dtype=likelihood_log_prob.dtype))
  # TODO(b/152412213): Enable batch of batches style by using non-scalar
  # inverse_temperature
  inverse_temperature = tf.zeros([], dtype=likelihood_log_prob.dtype)
  # One scaling per likelihood entry, capped at 1.
  scalings = ps.ones_like(likelihood_log_prob) * ps.minimum(scale_start, 1.)
  kernel = make_kernel_fn(
      _make_tempered_target_log_prob_fn(
          prior_log_prob_fn,
          likelihood_log_prob_fn,
          inverse_temperature),
      init_state,
      scalings,
      seed=seed_stream())
  # Bootstrap on the argument this function was given (the same state the
  # kernel was built with), not on a variable captured from outside.
  pkr = kernel.bootstrap_results(init_state)
  _, kernel_target_log_prob = gather_mh_like_result(pkr)

  particle_info = ParticleInfo(
      log_accept_prob=ps.zeros_like(likelihood_log_prob),
      log_scalings=tf.math.log(scalings),
      tempered_log_prob=kernel_target_log_prob,
      likelihood_log_prob=likelihood_log_prob,
  )

  return SMCResults(
      num_steps=tf.convert_to_tensor(
          max_num_steps, dtype=tf.int32, name='num_steps'),
      inverse_temperature=inverse_temperature,
      log_marginal_likelihood=tf.constant(
          0., dtype=likelihood_log_prob.dtype),
      particle_info=particle_info
  )
def sample_sequential_monte_carlo(
    prior_log_prob_fn,
    likelihood_log_prob_fn,
    current_state,
    max_num_steps=25,
    max_stage=100,
    make_kernel_fn=make_rwmh_kernel_fn,
    tuning_fn=simple_heuristic_tuning,
    make_tempered_target_log_prob_fn=default_make_tempered_target_log_prob_fn,
    ess_threshold_ratio=0.5,
    parallel_iterations=10,
    seed=None,
    name=None):
  """Runs Sequential Monte Carlo to sample from the posterior distribution.

  This function uses an MCMC transition operator (e.g., Hamiltonian Monte Carlo)
  to sample from a series of distributions that slowly interpolates between
  an initial 'prior' distribution:

    `exp(prior_log_prob_fn(x))`

  and the target 'posterior' distribution:

    `exp(prior_log_prob_fn(x) + likelihood_log_prob_fn(x))`,

  by mutating a collection of MC samples (i.e., particles). The approach is also
  known as Particle Filter in some literature. The current implementation is
  largely based on Del Moral et al [1], which adapts the tempering sequence
  adaptively (based on the effective sample size) and the scaling of the
  mutation kernel (based on the sample covariance of the particles) at each
  stage.

  Each stage (1) picks the next inverse temperature by bisection so that the
  effective sample size of the importance weights matches
  `num_particles * ess_threshold_ratio`, (2) resamples the particles with those
  weights, (3) retunes the number of steps and the scalings via `tuning_fn`
  (skipped at stage 0), and (4) mutates the particles with the kernel returned
  by `make_kernel_fn`. Stages repeat until every inverse temperature has
  reached 1 or `max_stage` stages have run.

  Args:
    prior_log_prob_fn: Python callable that returns the log density of the
      prior distribution.
    likelihood_log_prob_fn: Python callable which takes an argument like
      `current_state` (or `*current_state` if it's a list) and returns its
      (possibly unnormalized) log-density under the likelihood distribution.
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s). The first `r` dimensions index
      independent chains,
      `r = tf.rank(likelihood_log_prob_fn(*current_state))`. The leading
      dimension indexes particles; the remaining `r - 1` are batch dimensions.
    max_num_steps: The maximum number of kernel transition steps in one mutation
      of the MC samples. Note that the actual number of steps in one mutation is
      tuned during sampling (it is clipped to `[2, max_num_steps]`) and is
      likely lower than `max_num_steps`.
    max_stage: Integer upper bound on the number of stages used to increase the
      inverse temperature from 0 to 1.
    make_kernel_fn: Python `callable` which returns a `TransitionKernel`-like
      object. It is called as
      `make_kernel_fn(target_log_prob_fn, state, scalings, seed=...)`, where
      `target_log_prob_fn` is the tempered target created by
      `make_tempered_target_log_prob_fn`, i.e. an interpolation between the
      supplied `prior_log_prob_fn` and `likelihood_log_prob_fn`.
    tuning_fn: Python `callable` which takes the number of steps, the log
      scaling, and the log acceptance ratio from the last mutation and outputs
      the number of steps and log scaling for the next mutation.
    make_tempered_target_log_prob_fn: Python `callable` that takes the
      `prior_log_prob_fn`, `likelihood_log_prob_fn`, and `inverse_temperatures`
      and creates a `target_log_prob_fn` `callable` that is passed to
      `make_kernel_fn`.
    ess_threshold_ratio: Target ratio for effective sample size.
    parallel_iterations: The number of iterations allowed to run in parallel.
      It must be a positive integer. See `tf.while_loop` for more details.
    seed: Python integer or TFP seedstream to seed the random number generator.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'sample_sequential_monte_carlo').

  Returns:
    n_stage: Number of mutation stages SMC ran.
    final_state: `Tensor` or Python `list` of `Tensor`s representing the
      final state(s) of the Markov chain(s). The output are the posterior
      samples.
    final_kernel_results: `collections.namedtuple` of internal calculations used
      to advance the chain.

  #### References

  [1] Del Moral, Pierre, Arnaud Doucet, and Ajay Jasra. An adaptive sequential
      Monte Carlo method for approximate Bayesian computation.
      _Statistics and Computing_, 22.5(1009-1020), 2012.

  """

  with tf.name_scope(name or 'sample_sequential_monte_carlo'):
    seed_stream = SeedStream(seed, salt='smc_seed')

    # Normalize the state to a list of tensors; remember whether to unwrap the
    # result before returning.
    unwrap_state_list = not tf.nest.is_nested(current_state)
    if unwrap_state_list:
      current_state = [current_state]
    current_state = [
        tf.convert_to_tensor(s, dtype_hint=tf.float32) for s in current_state
    ]

    # Initial preprocessing at Stage 0
    likelihood_log_prob = likelihood_log_prob_fn(*current_state)

    # Number of state elements per chain: everything beyond the leading
    # `likelihood_rank` dimensions of each state part.
    likelihood_rank = ps.rank(likelihood_log_prob)
    dimension = ps.reduce_sum([
        ps.reduce_prod(ps.shape(x)[likelihood_rank:]) for x in current_state
    ])

    # We infer the particle shapes from the resulting likelihood:
    # [num_particles, b1, ..., bN]
    particle_shape = ps.shape(likelihood_log_prob)
    num_particles, batch_shape = particle_shape[0], particle_shape[1:]
    effective_sample_size_threshold = tf.cast(
        num_particles * ess_threshold_ratio, tf.int32)

    # TODO(b/152412213): Revisit this default parameter.
    # Default to the optimal scaling of a random walk kernel for a d-dimensional
    # normal distributed targets: 2.38 ** 2 / d.
    # For more detail see:
    # Roberts GO, Gelman A, Gilks WR. Weak convergence and optimal scaling of
    # random walk Metropolis algorithms. _The annals of applied probability_.
    # 1997;7(1):110-20.
    scale_start = (tf.constant(2.38**2, dtype=likelihood_log_prob.dtype) /
                   tf.constant(dimension, dtype=likelihood_log_prob.dtype))

    # Start at inverse temperature 0, one value per batch member.
    inverse_temperature = tf.zeros(batch_shape, dtype=likelihood_log_prob.dtype)
    # One scaling per chain, capped at 1.
    scalings = ps.ones_like(likelihood_log_prob) * ps.minimum(
        scale_start, 1.)
    kernel = make_kernel_fn(make_tempered_target_log_prob_fn(
        prior_log_prob_fn, likelihood_log_prob_fn, inverse_temperature),
                            current_state, scalings, seed=seed_stream)
    pkr = kernel.bootstrap_results(current_state)
    _, kernel_target_log_prob = gather_mh_like_result(pkr)

    particle_info = ParticleInfo(
        log_accept_prob=ps.zeros_like(likelihood_log_prob),
        log_scalings=tf.math.log(scalings),
        tempered_log_prob=kernel_target_log_prob,
        likelihood_log_prob=likelihood_log_prob,
    )

    current_pkr = SMCResults(
        num_steps=tf.convert_to_tensor(max_num_steps,
                                       dtype=tf.int32,
                                       name='num_steps'),
        inverse_temperature=inverse_temperature,
        log_marginal_likelihood=tf.zeros_like(inverse_temperature),
        particle_info=particle_info)

    def update_weights_temperature(inverse_temperature, likelihood_log_prob):
      """Calculate the next inverse temperature and update weights.

      Bisects on the next inverse temperature, starting from the bracket
      `[inverse_temperature, 2]`, until the effective sample size of the
      resulting importance weights equals the threshold or the bracket is
      narrower than a fixed tolerance.

      Args:
        inverse_temperature: Current inverse temperature(s).
        likelihood_log_prob: Likelihood log prob of the current particles;
          axis 0 indexes particles.

      Returns:
        marginal_loglike_: Log-mean-exp over particles of the likelihood
          scaled by the temperature increment.
        new_inverse_temperature: Next inverse temperature, clipped to [0, 1].
        log_weights: Unnormalized log importance weights for resampling.
      """
      # Subtract the per-batch maximum over particles before scaling.
      likelihood_diff = likelihood_log_prob - tf.reduce_max(
          likelihood_log_prob, axis=0)

      def _body_fn(new_beta, upper_beta, lower_beta, eff_size, log_weights):
        """One iteration of the temperature and weight update."""
        # Bisection midpoint; the incoming `new_beta`, `eff_size` and
        # `log_weights` are overwritten, not read.
        new_beta = (lower_beta + upper_beta) / 2.0
        log_weights = (new_beta - inverse_temperature) * likelihood_diff
        log_weights_norm = tf.math.log_softmax(log_weights, axis=0)
        # Effective sample size 1 / sum(w_i ** 2), truncated to an integer.
        eff_size = tf.cast(
            tf.exp(-tf.math.reduce_logsumexp(2 * log_weights_norm, axis=0)),
            tf.int32)
        # ESS too small: the step was too large, so lower the upper bound.
        # Otherwise raise the lower bound.
        upper_beta = tf.where(
            eff_size < effective_sample_size_threshold,
            new_beta, upper_beta)
        lower_beta = tf.where(
            eff_size < effective_sample_size_threshold,
            lower_beta, new_beta)
        return new_beta, upper_beta, lower_beta, eff_size, log_weights

      def _cond_fn(new_beta, upper_beta, lower_beta, eff_size, *_):  # pylint: disable=unused-argument
        # Keep iterating while some bracket is still wider than the threshold
        # and some effective sample size differs from the target.
        # TODO(junpenglao): revisit threshold below to be dtype specific.
        threshold = 1e-6
        return (tf.math.reduce_any(upper_beta - lower_beta > threshold) &
                tf.math.reduce_any(
                    eff_size != effective_sample_size_threshold))

      (new_beta, upper_beta, lower_beta, eff_size, log_weights) = tf.while_loop(  # pylint: disable=unused-variable
          cond=_cond_fn,
          body=_body_fn,
          loop_vars=(
              # new_beta (placeholder; recomputed in the first iteration).
              tf.zeros_like(inverse_temperature),
              # upper_beta: the bracket starts at 2.
              tf.fill(ps.shape(inverse_temperature),
                      tf.constant(2, inverse_temperature.dtype)),
              # lower_beta: the current inverse temperature.
              inverse_temperature,
              # eff_size placeholder.
              tf.zeros_like(inverse_temperature, dtype=tf.int32),
              # log_weights placeholder.
              tf.zeros_like(likelihood_diff)),
          parallel_iterations=parallel_iterations)

      # Where the bisection landed at or past 1, use the weights for moving
      # directly to inverse temperature 1.
      log_weights = tf.where(new_beta < 1., log_weights,
                             (1. - inverse_temperature) * likelihood_diff)
      # NOTE(review): this uses the unclipped `new_beta`, which may exceed 1,
      # whereas the returned temperature is clipped -- confirm this is
      # intended.
      marginal_loglike_ = reduce_logmeanexp(
          (new_beta - inverse_temperature) * likelihood_log_prob, axis=0)
      new_inverse_temperature = tf.clip_by_value(new_beta, 0., 1.)

      return marginal_loglike_, new_inverse_temperature, log_weights

    def mutate(current_state, log_scalings, num_steps, inverse_temperature):
      """Mutate the state using a Transition kernel.

      Args:
        current_state: List of `Tensor`s, the (resampled) particles.
        log_scalings: Log of the per-particle proposal scalings.
        num_steps: Number of kernel steps to run.
        inverse_temperature: Inverse temperature of the tempered target.

      Returns:
        next_state: Particles after `num_steps` kernel steps.
        avg_log_accept_prob_per_particle: Accumulated log acceptance
          probability minus `log(num_steps + 1)`.
        kernel_target_log_prob: Target log prob taken from the final kernel
          results.
      """
      with tf.name_scope('mutate_states'):
        # The names `scalings`, `kernel` and `pkr` below are local to this
        # function and shadow the ones defined in the enclosing scope.
        scalings = tf.exp(log_scalings)
        kernel = make_kernel_fn(make_tempered_target_log_prob_fn(
            prior_log_prob_fn, likelihood_log_prob_fn, inverse_temperature),
                                current_state, scalings, seed=seed_stream)
        pkr = kernel.bootstrap_results(current_state)
        # Only used below for the shape and dtype of the accumulator.
        kernel_log_accept_ratio, _ = gather_mh_like_result(pkr)

        def mutate_onestep(i, state, pkr, log_accept_prob_sum):
          next_state, next_kernel_results = kernel.one_step(
              state, pkr)
          # NOTE(review): the accept ratio is read from the incoming `pkr`
          # (results of the previous step), not from `next_kernel_results`;
          # confirm this is intended.
          kernel_log_accept_ratio, _ = gather_mh_like_result(pkr)
          # Cap at 0 so the value is a valid log probability.
          log_accept_prob = tf.minimum(kernel_log_accept_ratio, 0.)
          log_accept_prob_sum = log_add_exp(
              log_accept_prob_sum, log_accept_prob)
          return i + 1, next_state, next_kernel_results, log_accept_prob_sum

        (
            _,
            next_state,
            next_kernel_results,
            log_accept_prob_sum
        ) = tf.while_loop(
            cond=lambda i, *args: i < num_steps,
            body=mutate_onestep,
            loop_vars=(
                tf.zeros([], dtype=tf.int32),
                current_state,
                pkr,
                # we accumulate the acceptance probability in log space.
                tf.fill(
                    ps.shape(kernel_log_accept_ratio),
                    tf.constant(-np.inf, kernel_log_accept_ratio.dtype))),
            parallel_iterations=parallel_iterations)
        _, kernel_target_log_prob = gather_mh_like_result(
            next_kernel_results)
        # NOTE(review): the loop accumulates `num_steps` terms but the
        # divisor is `num_steps + 1` -- confirm.
        avg_log_accept_prob_per_particle = log_accept_prob_sum - tf.math.log(
            tf.cast(num_steps + 1, log_accept_prob_sum.dtype))

        return (next_state,
                avg_log_accept_prob_per_particle,
                kernel_target_log_prob)

    # One SMC step.
    def smc_body_fn(stage, state, smc_kernel_result):
      """Run one stage of SMC with constant temperature."""
      (
          new_marginal,
          new_inv_temperature,
          log_weights
      ) = update_weights_temperature(
          smc_kernel_result.inverse_temperature,
          smc_kernel_result.particle_info.likelihood_log_prob)
      # TODO(b/152412213) Use a tf.scan to better collect debug info.
      if PRINT_DEBUG:
        tf.print(
            'Stage:', stage,
            'Beta:', new_inv_temperature,
            'n_steps:', smc_kernel_result.num_steps,
            'accept:', tf.exp(
                reduce_logmeanexp(
                    smc_kernel_result.particle_info.log_accept_prob, axis=0)),
            'scaling:', tf.exp(
                reduce_logmeanexp(
                    smc_kernel_result.particle_info.log_scalings, axis=0)))
      # Resample the particles together with their bookkeeping info so the
      # two stay aligned.
      (resampled_state,
       resampled_particle_info), _ = resample_particle_and_info(
           (state, smc_kernel_result.particle_info),
           log_weights,
           seed=seed_stream)
      next_num_steps, next_log_scalings = tuning_fn(
          smc_kernel_result.num_steps,
          resampled_particle_info.log_scalings,
          resampled_particle_info.log_accept_prob)
      # Skip tuning at stage 0.
      next_num_steps = tf.where(stage == 0,
                                smc_kernel_result.num_steps,
                                next_num_steps)
      next_log_scalings = tf.where(stage == 0,
                                   resampled_particle_info.log_scalings,
                                   next_log_scalings)
      # Always run at least 2 and at most `max_num_steps` kernel steps.
      next_num_steps = tf.clip_by_value(next_num_steps, 2, max_num_steps)

      next_state, log_accept_prob, tempered_log_prob = mutate(
          resampled_state,
          next_log_scalings,
          next_num_steps,
          new_inv_temperature)
      next_pkr = SMCResults(
          num_steps=next_num_steps,
          inverse_temperature=new_inv_temperature,
          # The marginal likelihood accumulates additively in log space.
          log_marginal_likelihood=(
              new_marginal + smc_kernel_result.log_marginal_likelihood),
          particle_info=ParticleInfo(
              log_accept_prob=log_accept_prob,
              log_scalings=next_log_scalings,
              tempered_log_prob=tempered_log_prob,
              likelihood_log_prob=likelihood_log_prob_fn(*next_state),
          ))
      return stage + 1, next_state, next_pkr

    # Run stages until every inverse temperature reaches 1 or `max_stage`
    # stages have been executed.
    (n_stage, final_state, final_kernel_results) = tf.while_loop(
        cond=lambda i, state, pkr: (  # pylint: disable=g-long-lambda
            (i < max_stage) &
            tf.reduce_any(pkr.inverse_temperature < 1.)),
        body=smc_body_fn,
        loop_vars=(tf.zeros([], dtype=tf.int32), current_state, current_pkr),
        parallel_iterations=parallel_iterations)
    if unwrap_state_list:
      final_state = final_state[0]
    return n_stage, final_state, final_kernel_results
def _inverse_event_shape_tensor(self, output_shape):
  """Returns `output_shape` with its last dim replaced by the input size.

  Args:
    output_shape: 1-D integer shape vector.

  Returns:
    Shape vector made of all but the last entry of `output_shape` followed by
    the sum of `self.block_sizes`.
  """
  total_input_size = ps.reduce_sum(self.block_sizes)
  leading_dims = output_shape[:-1]
  # `[tf.newaxis]` lifts the scalar total to a length-1 vector for the concat.
  last_dim = total_input_size[tf.newaxis]
  return ps.concat([leading_dims, last_dim], -1)
def _forward_event_shape_tensor(self, input_shape):
  """Returns `input_shape` with its last dim replaced by the output size.

  Args:
    input_shape: 1-D integer shape vector.

  Returns:
    Shape vector made of all but the last entry of `input_shape` followed by
    the sum of `self._output_block_sizes()`.
  """
  total_output_size = ps.reduce_sum(self._output_block_sizes())
  leading_dims = input_shape[:-1]
  # `[tf.newaxis]` lifts the scalar total to a length-1 vector for the concat.
  last_dim = total_output_size[tf.newaxis]
  return ps.concat([leading_dims, last_dim], -1)
def reduce_fn(operands, inits, axis=None, keepdims=False):
  """Applies `reducer` to the given operands along the given axes.

  `reducer` is not a parameter; it is read from the enclosing scope (not
  visible in this block). The reduction is dispatched to `jax.lax.reduce` in
  JAX mode, to `_variadic_reduce` when executing eagerly or outside an XLA
  context, and to `_xla_reduce` otherwise.

  Args:
    operands: tuple of tensors, all having the same shape.
    inits: tuple of scalar tensors, with dtypes aligned to those of operands.
    axis: The axis or axes to reduce. One of `None`, an `int` or a sequence of
      `int`. `None` is taken to mean "reduce all axes".
    keepdims: When `True`, we do not squeeze away the reduced dims, instead
      returning values with singleton dims in those axes.

  Returns:
    reduced: A tuple of the reduced operands.

  Raises:
    ValueError: if the merged static rank of `operands` is unknown, or if
      `axis` has more than one dimension.
  """
  # Merge the static shapes of all operands into one.
  static_shape = operands[0].shape
  for other in operands[1:]:
    static_shape = tensorshape_util.merge_with(static_shape, other.shape)
  num_dims = tensorshape_util.rank(static_shape)
  if num_dims is None:
    raise ValueError(
        'Rank of at least one of `operands` must be known statically.')

  # Canonicalize `axis` into a tuple of non-negative Python ints.
  if axis is None:
    requested_axes = np.arange(num_dims)
  else:
    requested_axes = np.array(axis)
  if requested_axes.ndim > 1:
    raise ValueError(
        '`axis` must be `None`, an `int`, or a sequence of '
        '`int`, but got {}'.format(requested_axes))
  requested_axes = np.reshape(requested_axes, [-1])
  requested_axes = np.where(
      requested_axes < 0, requested_axes + num_dims, requested_axes)
  axis = tuple(map(int, requested_axes))

  # Pick the backend implementation.
  if JAX_MODE:
    from jax import lax  # pylint: disable=g-import-not-at-top
    result = lax.reduce(
        operands, init_values=inits, dimensions=axis, computation=reducer)
  elif (tf.executing_eagerly() or
        not control_flow_util.GraphOrParentsInXlaContext(
            tf1.get_default_graph())):
    result = _variadic_reduce(
        operands, init=inits, axis=axis, reducer=reducer)
  else:
    result = _xla_reduce(operands, inits, axis)

  if not keepdims:
    return result

  # Re-insert a singleton dim for every reduced axis.
  reduced_marker = ps.reduce_sum(
      ps.one_hot(axis, depth=num_dims, on_value=True, off_value=False,
                 dtype=tf.bool),
      axis=0)
  # Fall back to the dynamic shape when the static one has unknown dims.
  if tensorshape_util.is_fully_defined(static_shape):
    full_shape = static_shape
  else:
    full_shape = tf.shape(operands[0])
  final_shape = ps.where(reduced_marker, 1, full_shape)
  return tf.nest.map_structure(
      lambda t: tf.reshape(t, final_shape), result)