def ess_below_threshold(unnormalized_log_weights, threshold=0.5):
  """Determines if the effective sample size is much less than num_particles."""
  with tf.name_scope('ess_below_threshold'):
    n = ps.size0(unnormalized_log_weights)
    # Normalize the weights along the particle axis (axis 0).
    normalized = tf.math.log_softmax(unnormalized_log_weights, axis=0)
    # ESS = 1 / sum_i w_i**2, computed in log space for stability.
    log_ess = -tf.math.reduce_logsumexp(2. * normalized, axis=0)
    # Resampling triggers when ESS < threshold * n, compared in log space.
    log_cutoff = ps.log(n) + ps.log(threshold)
    return log_ess < log_cutoff
def _filter_one_step(step,
                     observation,
                     previous_particles,
                     log_weights,
                     transition_fn,
                     observation_fn,
                     proposal_fn,
                     resample_criterion_fn,
                     seed=None):
  """Advances the particle filter by a single time step.

  Args:
    step: integer index of the current timestep.
    observation: the observation associated with this step.
    previous_particles: (structure of) particle `Tensor`(s) from step - 1.
    log_weights: `Tensor` of log importance weights. The particle axis is the
      rightmost axis here (reductions below use `axis=-1`).
    transition_fn: callable returning the state-transition distribution.
    observation_fn: callable returning the per-particle observation
      distribution.
    proposal_fn: optional callable returning a proposal distribution
      (handled inside `_propose_with_log_weights`; may be `None`).
    resample_criterion_fn: callable mapping unnormalized log weights to a
      (batch of) boolean(s) saying whether to resample.
    seed: PRNG seed; see `SeedStream`.

  Returns:
    A `ParticleFilterStepResults` namedtuple with the (possibly resampled)
    particles, normalized log weights, parent indices, and this step's
    log marginal-likelihood increment.
  """
  with tf.name_scope('filter_one_step'):
    seed = SeedStream(seed, 'filter_one_step')
    num_particles = prefer_static.shape(log_weights)[-1]

    # Propose new particles and correct the weights for any mismatch between
    # the proposal and the transition model.
    proposed_particles, proposal_log_weights = _propose_with_log_weights(
        step=step - 1,
        particles=previous_particles,
        transition_fn=transition_fn,
        proposal_fn=proposal_fn,
        seed=seed)
    observation_log_weights = _compute_observation_log_weights(
        step, proposed_particles, observation, observation_fn)

    unnormalized_log_weights = (log_weights +
                                proposal_log_weights +
                                observation_log_weights)
    # The normalizer of the weights is this step's incremental estimate of
    # the marginal likelihood p(observation[t] | observations[:t]).
    step_log_marginal_likelihood = tf.math.reduce_logsumexp(
        unnormalized_log_weights, axis=-1)
    log_weights = (unnormalized_log_weights -
                   step_log_marginal_likelihood[..., tf.newaxis])

    # Adaptive resampling: resample particles iff the specified criterion.
    do_resample = tf.convert_to_tensor(
        resample_criterion_fn(unnormalized_log_weights))[
            ..., tf.newaxis]  # Broadcast over particles.

    # Some batch elements may require resampling and others not, so
    # we first do the resampling for all elements, then select whether to use
    # the resampled values for each batch element according to
    # `do_resample`. If there were no batching, we might prefer to use
    # `tf.cond` to avoid the resampling computation on steps where it's not
    # needed---but we're ultimately interested in adaptive resampling
    # for statistical (not computational) purposes, so this isn't a
    # dealbreaker.
    resampled_particles, resample_indices = _resample(
        proposed_particles, log_weights, seed=seed)

    # `dummy_indices[..., k] == k`: the parent indices reported when no
    # resampling happens (each particle is its own parent).
    dummy_indices = tf.broadcast_to(
        prefer_static.range(num_particles),
        prefer_static.shape(resample_indices))
    # After resampling all particles carry equal weight `1 / num_particles`.
    uniform_weights = (prefer_static.zeros_like(log_weights) -
                       prefer_static.log(num_particles))
    (resampled_particles,
     resample_indices,
     log_weights) = tf.nest.map_structure(
         lambda r, p: prefer_static.where(do_resample, r, p),
         (resampled_particles, resample_indices, uniform_weights),
         (proposed_particles, dummy_indices, log_weights))

    return ParticleFilterStepResults(
        particles=resampled_particles,
        log_weights=log_weights,
        parent_indices=resample_indices,
        step_log_marginal_likelihood=step_log_marginal_likelihood)
def _forward_log_det_jacobian(self, x):
  """Returns the forward log-det-Jacobian at unconstrained input `x`.

  NOTE(review): the formula treats `x` as having an implicit extra
  (zero) coordinate appended before normalization, so the output lives on
  an (n+1)-dimensional simplex -- confirm against the enclosing bijector
  class, which is not visible in this chunk.
  """
  # This code is similar to tf.math.log_softmax but different because we have
  # an implicit zero column to handle. I.e., instead of:
  #   reduce_sum(logits - reduce_sum(exp(logits), dim))
  # we must do:
  #   log_normalization = 1 + reduce_sum(exp(logits))
  #   -log_normalization + reduce_sum(logits - log_normalization)
  np1 = prefer_static.cast(1 + prefer_static.shape(x)[-1], dtype=x.dtype)
  # The 0.5 * log(n + 1) term accounts for the constant Jacobian of the
  # linear lift appending the last coordinate (see the matching derivation
  # in the inverse log-det-Jacobian).
  return (0.5 * prefer_static.log(np1) +
          tf.reduce_sum(x, axis=-1) -
          np1 * tf.math.softplus(tf.reduce_logsumexp(x, axis=-1)))
def _log_soosum_exp_impl(logx, axis, keepdims, compute_mean):
  """Implementation for `*soosum*` functions.

  Args:
    logx: `Tensor` of log-values.
    axis: axis (or axes) along which the sum is taken.
    keepdims: Python `bool`; forwarded to the leave-one-out computation.
    compute_mean: Python `bool`; if `True`, return log-means instead of
      log-sums (i.e., subtract `log(n)` from both results).

  Returns:
    log_soosum_x: the swap-one-out log-sum (or log-mean), same shape as
      `logx`.
    log_sum_x: `logsumexp(logx, axis)` (or its log-mean).
  """
  with tf.name_scope('log_soosum_exp_impl'):
    logx = tf.convert_to_tensor(logx, name='logx')
    # Leave-one-out log-sums; `n` is the number of reduced elements.
    log_loosum_x, log_sum_x, n = _log_loosum_exp_impl(logx, axis, keepdims,
                                                      compute_mean=False)
    # The swap-one-out-sum ('soosum') is n different sums, each of which
    # replaces the i-th item with the i-th-left-out average (or the user
    # specified value), i.e.,
    # soo_sum_x[i] = [exp(logx) - exp(logx[i])] + exp(mean(logx[!=i]))
    #             =  exp(log_loosum_x[i])      + exp(loo_log_swap_in[i])
    loo_log_swap_in = (
        (tf.reduce_sum(logx, axis=axis, keepdims=True) - logx) /
        (n - 1.))
    log_soosum_x = log_add_exp(log_loosum_x, loo_log_swap_in)
    if not compute_mean:
      return log_soosum_x, log_sum_x
    # Convert log-sums into log-means.
    log_n = prefer_static.log(n)
    return log_soosum_x - log_n, log_sum_x - log_n
def _inverse_log_det_jacobian(self, y):
  """Returns the inverse log-det-Jacobian at a simplex point `y`."""
  # Let B be the forward map defined by the bijector. Consider the map
  # F : R^n -> R^n where the image of B in R^{n+1} is restricted to the first
  # n coordinates.
  #
  # Claim: det{ dF(X)/dX } = prod(Y) where Y = B(X).
  # Proof: WLOG, in vector notation:
  #   X = log(Y[:-1]) - log(Y[-1])
  # where,
  #   Y[-1] = 1 - sum(Y[:-1]).
  # We have:
  #   det{dF} = 1 / det{ dX/dF(X} }                                      (1)
  #           = 1 / det{ diag(1 / Y[:-1]) + 1 / Y[-1] }
  #           = 1 / det{ inv{ diag(Y[:-1]) - Y[:-1]' Y[:-1] } }
  #           = det{ diag(Y[:-1]) - Y[:-1]' Y[:-1] }
  #           = (1 + Y[:-1]' inv{diag(Y[:-1])} Y[:-1]) det{diag(Y[:-1])} (2)
  #           = Y[-1] prod(Y[:-1])
  #           = prod(Y)
  #
  # Let P be the image of R^n under F. Define the lift G, from P to R^{n+1},
  # which appends the last coordinate, Y[-1] := 1 - \sum_k Y_k. G is linear,
  # so its Jacobian is constant.
  #
  # The differential of G, DG, is eye(n) with a row of -1s appended to the
  # bottom. To compute the Jacobian sqrt{det{(DG)^T(DG)}}, one can see that
  # (DG)^T(DG) = A + eye(n), where A is the n x n matrix of 1s. This has
  # eigenvalues (n + 1, 1,...,1), so the determinant is (n + 1). Hence, the
  # Jacobian of G is sqrt{n + 1} everywhere.
  #
  # Putting it all together, the forward bijective map B can be written as
  # B(X) = G(F(X)) and has Jacobian sqrt{n + 1} * prod(F(X)).
  #
  # (1) - https://en.wikipedia.org/wiki/Sherman%E2%80%93Morrison_formula
  #       or by noting that det{ dX/dY } = 1 / det{ dY/dX } from Bijector
  #       docstring "Tip".
  # (2) - https://en.wikipedia.org/wiki/Matrix_determinant_lemma
  #
  # Note: `y` already includes the implicit last coordinate, so no `+ 1` is
  # needed when computing n + 1 from its trailing dimension.
  np1 = ps.cast(ps.shape(y)[-1], dtype=y.dtype)
  return -(0.5 * ps.log(np1) + tf.reduce_sum(tf.math.log(y), axis=-1))
def _particle_filter_initial_weighted_particles(observations, observation_fn, initial_state_prior, initial_state_proposal, num_particles, seed=None): """Initialize a set of weighted particles including the first observation.""" # Initial particles all have the same weight, `1. / num_particles`. broadcast_batch_shape = tf.convert_to_tensor( functools.reduce( ps.broadcast_shape, tf.nest.flatten(initial_state_prior.batch_shape_tensor()), []), dtype=tf.int32) initial_log_weights = ps.zeros( ps.concat([[num_particles], broadcast_batch_shape], axis=0), dtype=tf.float32) - ps.log(num_particles) # Propose an initial state. if initial_state_proposal is None: initial_state = initial_state_prior.sample(num_particles, seed=seed) else: initial_state = initial_state_proposal.sample(num_particles, seed=seed) initial_log_weights += (initial_state_prior.log_prob(initial_state) - initial_state_proposal.log_prob(initial_state)) # The initial proposal weights are normalized in expectation, but actually # normalizing them reduces variance in the initial marginal # likelihood. initial_log_weights = tf.nn.log_softmax(initial_log_weights, axis=0) # Return particles weighted by the initial observation. return smc_kernel.WeightedParticles( particles=initial_state, log_weights=initial_log_weights + _compute_observation_log_weights( step=0, particles=initial_state, observations=observations, observation_fn=observation_fn))
def auto_correlation(x,
                     axis=-1,
                     max_lags=None,
                     center=True,
                     normalize=True,
                     name='auto_correlation'):
  """Auto correlation along one axis.

  Given a `1-D` wide sense stationary (WSS) sequence `X`, the auto
  correlation `RXX` may be defined as (with `E` expectation and `Conj`
  complex conjugate)

  ```
  RXX[m] := E{ W[m] Conj(W[0]) } = E{ W[0] Conj(W[-m]) },
  W[n]   := (X[n] - MU) / S,
  MU     := E{ X[0] },
  S**2   := E{ (X[0] - MU) Conj(X[0] - MU) }.
  ```

  This function takes the viewpoint that `x` is (along one axis) a finite
  sub-sequence of a realization of (WSS) `X`, and then uses `x` to produce an
  estimate of `RXX[m]` as follows:

  After extending `x` from length `L` to `inf` by zero padding, the auto
  correlation estimate `rxx[m]` is computed for `m = 0, 1, ..., max_lags` as

  ```
  rxx[m] := (L - m)**-1 sum_n w[n + m] Conj(w[n]),
  w[n]   := (x[n] - mu) / s,
  mu     := L**-1 sum_n x[n],
  s**2   := L**-1 sum_n (x[n] - mu) Conj(x[n] - mu)
  ```

  The error in this estimate is proportional to `1 / sqrt(len(x) - m)`, so
  users often set `max_lags` small enough so that the entire output is
  meaningful.

  Note that since `mu` is an imperfect estimate of `E{ X[0] }`, and we divide
  by `len(x) - m` rather than `len(x) - m - 1`, our estimate of auto
  correlation contains a slight bias, which goes to zero as
  `len(x) - m --> infinity`.

  Args:
    x: `float32` or `complex64` `Tensor`.
    axis: Python `int`. The axis number along which to compute correlation.
      Other dimensions index different batch members.
    max_lags: Positive `int` tensor. The maximum value of `m` to consider (in
      equation above). If `max_lags >= x.shape[axis]`, we effectively re-set
      `max_lags` to `x.shape[axis] - 1`.
    center: Python `bool`. If `False`, do not subtract the mean estimate `mu`
      from `x[n]` when forming `w[n]`.
    normalize: Python `bool`. If `False`, do not divide by the variance
      estimate `s**2` when forming `w[n]`.
    name: `String` name to prepend to created ops.

  Returns:
    `rxx`: `Tensor` of same `dtype` as `x`. `rxx.shape[i] = x.shape[i]` for
      `i != axis`, and `rxx.shape[axis] = max_lags + 1`.

  Raises:
    TypeError: If `x` is not a supported type.
  """
  # Implementation details:
  # Extend length N / 2 1-D array x to length N by zero padding onto the end.
  # Then, set
  #   F[x]_k := sum_n x_n exp{-i 2 pi k n / N }.
  # It is not hard to see that
  #   F[x]_k Conj(F[x]_k) = F[R]_k, where
  #   R_m := sum_n x_n Conj(x_{(n - m) mod N}).
  # One can also check that R_m / (N / 2 - m) is an unbiased estimate of
  # RXX[m].
  # Since F[x] is the DFT of x, this leads us to a zero-padding and FFT/IFFT
  # based version of estimating RXX.
  # Note that this is a special case of the Wiener-Khinchin Theorem.
  with tf.name_scope(name):
    x = tf.convert_to_tensor(x, name='x')

    # Rotate dimensions of x in order to put axis at the rightmost dim.
    # FFT op requires this.
    rank = ps.rank(x)
    if axis < 0:
      axis = rank + axis
    shift = rank - 1 - axis
    # Suppose x.shape[axis] = T, so there are T 'time' steps.
    #   ==> x_rotated.shape = B + [T],
    # where B is x_rotated's batch shape.
    x_rotated = distribution_util.rotate_transpose(x, shift)

    if center:
      x_rotated = x_rotated - tf.reduce_mean(
          x_rotated, axis=-1, keepdims=True)

    # x_len = N / 2 from above explanation. The length of x along axis.
    # Get a value for x_len that works in all cases.
    x_len = ps.shape(x_rotated)[-1]

    # TODO(langmore) Investigate whether this zero padding helps or hurts. At
    # the moment is necessary so that all FFT implementations work.
    # Zero pad to the next power of 2 greater than 2 * x_len, which equals
    # 2**(ceil(Log_2(2 * x_len))). Note: Log_2(X) = Log_e(X) / Log_e(2).
    x_len_float64 = ps.cast(x_len, np.float64)
    target_length = ps.pow(
        np.float64(2.),
        ps.ceil(ps.log(x_len_float64 * 2) / np.log(2.)))
    pad_length = ps.cast(target_length - x_len_float64, np.int32)

    # We should have:
    # x_rotated_pad.shape = x_rotated.shape[:-1] + [T + pad_length]
    #                     = B + [T + pad_length]
    x_rotated_pad = distribution_util.pad(
        x_rotated, axis=-1, back=True, count=pad_length)

    dtype = x.dtype
    if not dtype_util.is_complex(dtype):
      if not dtype_util.is_floating(dtype):
        raise TypeError(
            'Argument x must have either float or complex dtype'
            ' found: {}'.format(dtype))
      # FFT requires complex input; pair the real signal with a zero
      # imaginary part of the matching real dtype.
      x_rotated_pad = tf.complex(
          x_rotated_pad,
          dtype_util.as_numpy_dtype(dtype_util.real_dtype(dtype))(0.))

    # Autocorrelation is IFFT of power-spectral density (up to some scaling).
    fft_x_rotated_pad = tf.signal.fft(x_rotated_pad)
    spectral_density = fft_x_rotated_pad * tf.math.conj(fft_x_rotated_pad)
    # shifted_product is R[m] from above detailed explanation.
    # It is the inner product sum_n X[n] * Conj(X[n - m]).
    shifted_product = tf.signal.ifft(spectral_density)

    # Cast back to real-valued if x was real to begin with.
    shifted_product = tf.cast(shifted_product, dtype)

    # Figure out if we can deduce the final static shape, and set max_lags.
    # Use x_rotated as a reference, because it has the time dimension in the
    # far right, and was created before we performed all sorts of crazy shape
    # manipulations.
    know_static_shape = True
    if not tensorshape_util.is_fully_defined(x_rotated.shape):
      know_static_shape = False
    if max_lags is None:
      max_lags = x_len - 1
    else:
      max_lags = tf.convert_to_tensor(max_lags, name='max_lags')
      max_lags_ = tf.get_static_value(max_lags)
      if max_lags_ is None or not know_static_shape:
        know_static_shape = False
        max_lags = tf.minimum(x_len - 1, max_lags)
      else:
        max_lags = min(x_len - 1, max_lags_)

    # Chop off the padding.
    # We allow users to provide a huge max_lags, but cut it off here.
    # shifted_product_chopped.shape = x_rotated.shape[:-1] + [max_lags]
    shifted_product_chopped = shifted_product[..., :max_lags + 1]

    # If possible, set shape.
    if know_static_shape:
      chopped_shape = tensorshape_util.as_list(x_rotated.shape)
      chopped_shape[-1] = min(x_len, max_lags + 1)
      tensorshape_util.set_shape(shifted_product_chopped, chopped_shape)

    # Recall R[m] is a sum of N / 2 - m nonzero terms x[n] Conj(x[n - m]).
    # The other terms were zeros arising only due to zero padding.
    # `denominator = (N / 2 - m)` (defined below) is the proper term to
    # divide by to make this an unbiased estimate of the expectation
    # E[X[n] Conj(X[n - m])].
    x_len = ps.cast(x_len, dtype_util.real_dtype(dtype))
    max_lags = ps.cast(max_lags, dtype_util.real_dtype(dtype))
    denominator = x_len - ps.range(0., max_lags + 1.)
    denominator = ps.cast(denominator, dtype)
    shifted_product_rotated = shifted_product_chopped / denominator

    if normalize:
      # Divide by rxx[0] so the zero-lag correlation is 1.
      shifted_product_rotated /= shifted_product_rotated[..., :1]

    # Transpose dimensions back to those of x.
    return distribution_util.rotate_transpose(shifted_product_rotated, -shift)
def _filter_one_step(step,
                     observation,
                     previous_particles,
                     log_weights,
                     transition_fn,
                     observation_fn,
                     proposal_fn,
                     resample_criterion_fn,
                     has_observation=True,
                     seed=None):
  """Advances the particle filter by a single time step.

  Args:
    step: integer index of the current timestep.
    observation: the observation to (optionally) incorporate at this step.
    previous_particles: (structure of) particle `Tensor`(s) from step - 1.
    log_weights: `Tensor` of log importance weights; the particle axis is
      the leading axis (`shape[0]` below).
    transition_fn: callable returning the state-transition distribution.
    observation_fn: callable returning the per-particle observation
      distribution.
    proposal_fn: optional callable returning a proposal distribution.
    resample_criterion_fn: callable mapping unnormalized log weights to a
      (batch of) boolean(s) saying whether to resample.
    has_observation: Python or Tensor `bool`; whether this step carries an
      observation (steps between observations pass `False`).
    seed: PRNG seed; see `SeedStream`.

  Returns:
    A `ParticleFilterStepResults` namedtuple.
  """
  with tf.name_scope('filter_one_step'):
    seed = SeedStream(seed, 'filter_one_step')
    num_particles = prefer_static.shape(log_weights)[0]

    proposed_particles, proposal_log_weights = _propose_with_log_weights(
        step=step - 1,
        particles=previous_particles,
        transition_fn=transition_fn,
        proposal_fn=proposal_fn,
        seed=seed)
    # NOTE(review): weights are normalized along `axis=-1` here but reduced
    # along `axis=0` below (and `num_particles` is taken from `shape[0]`);
    # these agree only if `log_weights` is 1-D -- confirm against callers.
    log_weights = tf.nn.log_softmax(
        proposal_log_weights + log_weights, axis=-1)

    # If this step has an observation, compute its weights and marginal
    # likelihood (and otherwise, leave weights unchanged).
    observation_log_weights = prefer_static.cond(
        has_observation,
        lambda: prefer_static.broadcast_to(  # pylint: disable=g-long-lambda
            _compute_observation_log_weights(
                step, proposed_particles, observation, observation_fn),
            prefer_static.shape(log_weights)),
        lambda: tf.zeros_like(log_weights))

    unnormalized_log_weights = log_weights + observation_log_weights
    step_log_marginal_likelihood = tf.math.reduce_logsumexp(
        unnormalized_log_weights, axis=0)
    log_weights = (unnormalized_log_weights - step_log_marginal_likelihood)

    # Adaptive resampling: resample particles iff the specified criterion.
    do_resample = resample_criterion_fn(unnormalized_log_weights)

    # Some batch elements may require resampling and others not, so
    # we first do the resampling for all elements, then select whether to use
    # the resampled values for each batch element according to
    # `do_resample`. If there were no batching, we might prefer to use
    # `tf.cond` to avoid the resampling computation on steps where it's not
    # needed---but we're ultimately interested in adaptive resampling
    # for statistical (not computational) purposes, so this isn't a
    # dealbreaker.
    # NOTE(review): `resample_independent` is not defined within this
    # function; presumably a module-level resampling routine -- confirm it
    # exists in the enclosing module.
    resampled_particles, resample_indices = _resample(
        proposed_particles, log_weights, resample_independent, seed=seed)

    # After resampling all particles carry equal weight `1 / num_particles`.
    uniform_weights = (prefer_static.zeros_like(log_weights) -
                       prefer_static.log(num_particles))
    (resampled_particles,
     resample_indices,
     log_weights) = tf.nest.map_structure(
         lambda r, p: prefer_static.where(do_resample, r, p),
         (resampled_particles, resample_indices, uniform_weights),
         (proposed_particles, _dummy_indices_like(resample_indices),
          log_weights))

  return ParticleFilterStepResults(
      particles=resampled_particles,
      log_weights=log_weights,
      parent_indices=resample_indices,
      step_log_marginal_likelihood=step_log_marginal_likelihood)
def particle_filter(observations,
                    initial_state_prior,
                    transition_fn,
                    observation_fn,
                    num_particles,
                    initial_state_proposal=None,
                    proposal_fn=None,
                    resample_criterion_fn=ess_below_threshold,
                    rejuvenation_kernel_fn=None,  # TODO(davmre): not yet supported. pylint: disable=unused-argument
                    num_transitions_per_observation=1,
                    num_steps_state_history_to_pass=None,
                    num_steps_observation_history_to_pass=None,
                    seed=None,
                    name=None):  # pylint: disable=g-doc-args
  """Samples a series of particles representing filtered latent states.

  The particle filter samples from the sequence of "filtering" distributions
  `p(state[t] | observations[:t])` over latent states: at each point in time,
  this is the distribution conditioned on all observations *up to that time*.
  Because particles may be resampled, a particle at time `t` may be different
  from the particle with the same index at time `t + 1`. To reconstruct
  trajectories by tracing back through the resampling process, see
  `tfp.mcmc.experimental.reconstruct_trajectories`.

  ${particle_filter_arg_str}

  Returns:
    particles: a (structure of) Tensor(s) matching the latent state, each
      of shape
      `concat([[num_timesteps, num_particles, b1, ..., bN], event_shape])`,
      representing (possibly weighted) samples from the series of filtering
      distributions `p(latent_states[t] | observations[:t])`.
    log_weights: `float` `Tensor` of shape
      `[num_timesteps, num_particles, b1, ..., bN]`, such that
      `log_weights[t, :]` are the logarithms of normalized importance weights
      (such that `exp(reduce_logsumexp(log_weights), axis=-1) == 1.`) of
      the particles at time `t`. These may be used in conjunction with
      `particles` to compute expectations under the series of filtering
      distributions.
    parent_indices: `int` `Tensor` of shape
      `[num_timesteps, num_particles, b1, ..., bN]`,
      such that `parent_indices[t, k]` gives the index of the particle at
      time `t - 1` that the `k`th particle at time `t` is immediately
      descended from. See also
      `tfp.experimental.mcmc.reconstruct_trajectories`.
    step_log_marginal_likelihoods: float `Tensor` of shape
      `[num_observation_steps, b1, ..., bN]`,
      giving the natural logarithm of an unbiased estimate of
      `p(observations[t] | observations[:t])` at each observed timestep `t`.
      Note that (by [Jensen's inequality](
      https://en.wikipedia.org/wiki/Jensen%27s_inequality))
      this is *smaller* in expectation than the true
      `log p(observations[t] | observations[:t])`.

  ${non_markovian_specification_str}
  """
  seed = SeedStream(seed, 'particle_filter')
  with tf.name_scope(name or 'particle_filter'):
    num_observation_steps = prefer_static.shape(
        tf.nest.flatten(observations)[0])[0]
    # Each observation may be separated by multiple latent transitions.
    num_timesteps = (
        1 + num_transitions_per_observation * (num_observation_steps - 1))

    # If no criterion is specified, default is to resample at every step.
    if not resample_criterion_fn:
      resample_criterion_fn = lambda _: True

    # Dress up the prior and prior proposal as a fake `transition_fn` and
    # `proposal_fn` respectively.
    prior_fn = lambda _1, _2: SampleParticles(  # pylint: disable=g-long-lambda
        initial_state_prior, num_particles)
    prior_proposal_fn = (
        None if initial_state_proposal is None
        else lambda _1, _2: SampleParticles(  # pylint: disable=g-long-lambda
            initial_state_proposal, num_particles))

    # Initially the particles all have the same weight, `1. / num_particles`.
    broadcast_batch_shape = tf.convert_to_tensor(
        functools.reduce(
            prefer_static.broadcast_shape,
            tf.nest.flatten(initial_state_prior.batch_shape_tensor()),
            []), dtype=tf.int32)
    log_uniform_weights = prefer_static.zeros(
        prefer_static.concat([[num_particles], broadcast_batch_shape],
                             axis=0),
        dtype=tf.float32) - prefer_static.log(num_particles)

    # Initialize from the prior, and incorporate the first observation.
    initial_step_results = _filter_one_step(
        step=0,
        # `previous_particles` at the first step is a dummy quantity, used
        # only to convey state structure and num_particles to an optional
        # proposal fn.
        previous_particles=prior_fn(0, []).sample(),
        log_weights=log_uniform_weights,
        observation=tf.nest.map_structure(
            lambda x: tf.gather(x, 0), observations),
        transition_fn=prior_fn,
        observation_fn=observation_fn,
        proposal_fn=prior_proposal_fn,
        resample_criterion_fn=resample_criterion_fn,
        seed=seed)

    def _loop_body(step,
                   previous_step_results,
                   accumulated_step_results,
                   state_history):
      """Take one step in dynamics and accumulate marginal likelihood."""
      step_has_observation = (
          # The second of these conditions subsumes the first, but both are
          # useful because the first can often be evaluated statically.
          prefer_static.equal(num_transitions_per_observation, 1) |
          prefer_static.equal(step % num_transitions_per_observation, 0))
      observation_idx = step // num_transitions_per_observation
      current_observation = tf.nest.map_structure(
          lambda x, step=step: tf.gather(x, observation_idx), observations)

      # Optionally pass recent observations/states into the model fns,
      # enabling non-Markovian models.
      history_to_pass_into_fns = {}
      if num_steps_observation_history_to_pass:
        history_to_pass_into_fns['observation_history'] = _gather_history(
            observations,
            observation_idx,
            num_steps_observation_history_to_pass)
      if num_steps_state_history_to_pass:
        history_to_pass_into_fns['state_history'] = state_history

      new_step_results = _filter_one_step(
          step=step,
          previous_particles=previous_step_results.particles,
          log_weights=previous_step_results.log_weights,
          observation=current_observation,
          transition_fn=functools.partial(
              transition_fn, **history_to_pass_into_fns),
          observation_fn=functools.partial(
              observation_fn, **history_to_pass_into_fns),
          proposal_fn=(
              None if proposal_fn is None
              else functools.partial(
                  proposal_fn, **history_to_pass_into_fns)),
          resample_criterion_fn=resample_criterion_fn,
          has_observation=step_has_observation,
          seed=seed)

      return _update_loop_variables(
          step, new_step_results, accumulated_step_results, state_history)

    loop_results = tf.while_loop(
        cond=lambda step, *_: step < num_timesteps,
        body=_loop_body,
        loop_vars=_initialize_loop_variables(
            initial_step_results,
            num_steps_state_history_to_pass,
            num_timesteps))

    results = tf.nest.map_structure(lambda ta: ta.stack(),
                                    loop_results.accumulated_step_results)
    if num_transitions_per_observation != 1:
      # Return a log-prob for each observed step.
      observed_steps = prefer_static.range(
          0, num_timesteps, num_transitions_per_observation)
      results = results._replace(step_log_marginal_likelihood=tf.gather(
          results.step_log_marginal_likelihood, observed_steps))
    return results
def _log_loosum_exp_impl(logx, axis, keepdims, compute_mean):
  """Implementation for `*loosum*` functions.

  Computes, for each element `i`, the leave-one-out log-sum
  `logsumexp(logx[j] : j != i)`, along with the ordinary `logsumexp`.

  Args:
    logx: `Tensor` of log-values.
    axis: `None`, Python int(s), or int `Tensor` of axes to reduce over.
    keepdims: Python `bool`; whether `log_sum_x` retains the reduced axes.
    compute_mean: Python `bool`; if `True`, return log-means (subtracting
      `log(n - 1)` resp. `log(n)`) instead of log-sums.

  Returns:
    log_loosum_x: leave-one-out log-sum (or log-mean), same shape as `logx`.
    log_sum_x: `logsumexp(logx, axis)` (or its log-mean).
    n: count of reduced elements, cast to `logx`'s dtype.
  """
  with tf.name_scope('log_loosum_exp_impl'):
    logx = tf.convert_to_tensor(logx, name='logx')
    dtype = dtype_util.as_numpy_dtype(logx.dtype)

    if axis is not None:
      x = np.array(axis)
      # If the static value of `axis` isn't numeric (object dtype, e.g. a
      # Tensor buried in a list), let TF resolve it; otherwise use the
      # static int32 value directly.
      # FIX: the previous check `x.dtype is np.object` was always `False`
      # (a `np.dtype` instance is never identical to the builtin-`object`
      # alias), and `np.object` itself was removed in NumPy 1.24. Compare
      # dtypes by equality against `np.dtype(object)` instead.
      axis = (tf.convert_to_tensor(axis, name='axis', dtype_hint=tf.int32)
              if x.dtype == np.dtype(object)
              else x.astype(np.int32))

    log_sum_x = tf.reduce_logsumexp(logx, axis=axis, keepdims=True)

    # Later we'll want to compute the mean from a sum so we calculate the
    # number of reduced elements, n.
    n = prefer_static.size(logx) // prefer_static.size(log_sum_x)
    n = prefer_static.cast(n, dtype)

    # log_loosum_x[i] =
    # = logsumexp(logx[j] : j != i)
    # = log( exp(logsumexp(logx)) - exp(logx[i]) )
    # = log( exp(logsumexp(logx - logx[i])) exp(logx[i])  - exp(logx[i]))
    # = logx[i] + log(exp(logsumexp(logx - logx[i])) - 1)
    # = logx[i] + log(exp(logsumexp(logx) - logx[i]) - 1)
    # = logx[i] + softplus_inverse(logsumexp(logx) - logx[i])
    d = log_sum_x - logx
    # We use `d != 0` rather than `d > 0.` because `d < 0.` should never
    # happen; if it does we want to complain loudly (which
    # `softplus_inverse` will).
    d_ok = tf.not_equal(d, 0.)
    safe_d = tf.where(d_ok, d, 1.)
    d_ok_result = logx + softplus_inverse(safe_d)

    neg_inf = tf.constant(-np.inf, dtype=dtype)

    # When not(d_ok) and is_positive_and_largest then we manually compute
    # the log_loosum_x. (We can efficiently do this for any one point but
    # not all, hence we still need the above calculation.) This is good
    # because when this condition is met, we cannot use the above
    # calculation; its -inf.
    # We now compute the log-leave-out-max-sum, replicate it to every
    # point and make sure to select it only when we need to.
    max_logx = tf.reduce_max(logx, axis=axis, keepdims=True)
    is_positive_and_largest = (logx > 0.) & tf.equal(logx, max_logx)
    log_lomsum_x = tf.reduce_logsumexp(
        tf.where(is_positive_and_largest, neg_inf, logx),
        axis=axis,
        keepdims=True)
    d_not_ok_result = tf.where(
        is_positive_and_largest, log_lomsum_x, neg_inf)

    log_loosum_x = tf.where(d_ok, d_ok_result, d_not_ok_result)

    # We now squeeze log_sum_x so as if we used `keepdims=False`.
    # TODO(b/136176077): These mental gymnastics could all be replaced with
    # `tf.squeeze(log_sum_x, axis)` if tf.squeeze supported Tensor valued
    # `axis` arguments.
    if not keepdims:
      if axis is None:
        keepdims = np.array([], dtype=np.int32)
      else:
        rank = prefer_static.rank(logx)
        keepdims = prefer_static.setdiff1d(
            prefer_static.range(rank),
            prefer_static.non_negative_axis(axis, rank))
      squeeze_shape = tf.gather(prefer_static.shape(logx), indices=keepdims)
      log_sum_x = tf.reshape(log_sum_x, shape=squeeze_shape)
      if prefer_static.is_numpy(keepdims):
        tensorshape_util.set_shape(log_sum_x, np.array(logx.shape)[keepdims])

    # Set static shapes just in case we lost them.
    tensorshape_util.set_shape(n, [])
    tensorshape_util.set_shape(log_loosum_x, logx.shape)

    if not compute_mean:
      return log_loosum_x, log_sum_x, n

    # Leave-one-out sums average over n - 1 elements (guard against n == 1).
    log_nm1 = prefer_static.log(max(1., n - 1.))
    log_n = prefer_static.log(n)
    return log_loosum_x - log_nm1, log_sum_x - log_n, n
def particle_filter(observations,
                    initial_state_prior,
                    transition_fn,
                    observation_fn,
                    num_particles,
                    initial_state_proposal=None,
                    proposal_fn=None,
                    resample_fn=weighted_resampling.resample_systematic,
                    resample_criterion_fn=ess_below_threshold,
                    rejuvenation_kernel_fn=None,  # TODO(davmre): not yet supported. pylint: disable=unused-argument
                    num_transitions_per_observation=1,
                    trace_fn=_default_trace_fn,
                    step_indices_to_trace=None,
                    seed=None,
                    name=None):  # pylint: disable=g-doc-args
  """Samples a series of particles representing filtered latent states.

  The particle filter samples from the sequence of "filtering" distributions
  `p(state[t] | observations[:t])` over latent states: at each point in time,
  this is the distribution conditioned on all observations *up to that time*.
  Because particles may be resampled, a particle at time `t` may be different
  from the particle with the same index at time `t + 1`. To reconstruct
  trajectories by tracing back through the resampling process, see
  `tfp.mcmc.experimental.reconstruct_trajectories`.

  ${particle_filter_arg_str}
    trace_fn: Python `callable` defining the values to be traced at each
      step. It takes a `ParticleFilterStepResults` tuple and returns a
      structure of `Tensor`s. The default function returns
      `(particles, log_weights, parent_indices, step_log_likelihood)`.
    step_indices_to_trace: optional `int` `Tensor` listing, in increasing
      order, the indices of steps at which to record the values traced by
      `trace_fn`. If `None`, the default behavior is to trace at every
      timestep, equivalent to specifying
      `step_indices_to_trace=tf.range(num_timsteps)`.
    seed: Python `int` seed for random ops.
    name: Python `str` name for ops created by this method.
      Default value: `None` (i.e., `'particle_filter'`).

  Returns:
    particles: a (structure of) Tensor(s) matching the latent state, each
      of shape
      `concat([[num_timesteps, num_particles, b1, ..., bN], event_shape])`,
      representing (possibly weighted) samples from the series of filtering
      distributions `p(latent_states[t] | observations[:t])`.
    log_weights: `float` `Tensor` of shape
      `[num_timesteps, num_particles, b1, ..., bN]`, such that
      `log_weights[t, :]` are the logarithms of normalized importance weights
      (such that `exp(reduce_logsumexp(log_weights), axis=-1) == 1.`) of
      the particles at time `t`. These may be used in conjunction with
      `particles` to compute expectations under the series of filtering
      distributions.
    parent_indices: `int` `Tensor` of shape
      `[num_timesteps, num_particles, b1, ..., bN]`,
      such that `parent_indices[t, k]` gives the index of the particle at
      time `t - 1` that the `k`th particle at time `t` is immediately
      descended from. See also
      `tfp.experimental.mcmc.reconstruct_trajectories`.
    incremental_log_marginal_likelihoods: float `Tensor` of shape
      `[num_observation_steps, b1, ..., bN]`,
      giving the natural logarithm of an unbiased estimate of
      `p(observations[t] | observations[:t])` at each observed timestep `t`.
      Note that (by [Jensen's inequality](
      https://en.wikipedia.org/wiki/Jensen%27s_inequality))
      this is *smaller* in expectation than the true
      `log p(observations[t] | observations[:t])`.
  """
  seed = SeedStream(seed, 'particle_filter')
  with tf.name_scope(name or 'particle_filter'):
    num_observation_steps = ps.size0(tf.nest.flatten(observations)[0])
    # Each observation may be separated by multiple latent transitions.
    num_timesteps = (
        1 + num_transitions_per_observation * (num_observation_steps - 1))

    # If no criterion is specified, default is to resample at every step.
    if not resample_criterion_fn:
      resample_criterion_fn = lambda _: True

    # Canonicalize the list of steps to trace as a rank-1 tensor of (sorted)
    # positive integers. E.g., `3` -> `[3]`, `[-2, -1]` -> `[N - 2, N - 1]`.
    if step_indices_to_trace is not None:
      (step_indices_to_trace,
       traced_steps_have_rank_zero) = _canonicalize_steps_to_trace(
           step_indices_to_trace, num_timesteps)

    # Dress up the prior and prior proposal as a fake `transition_fn` and
    # `proposal_fn` respectively.
    prior_fn = lambda _1, _2: SampleParticles(  # pylint: disable=g-long-lambda
        initial_state_prior, num_particles)
    prior_proposal_fn = (
        None if initial_state_proposal is None
        else lambda _1, _2: SampleParticles(  # pylint: disable=g-long-lambda
            initial_state_proposal, num_particles))

    # Initially the particles all have the same weight, `1. / num_particles`.
    broadcast_batch_shape = tf.convert_to_tensor(
        functools.reduce(
            ps.broadcast_shape,
            tf.nest.flatten(initial_state_prior.batch_shape_tensor()),
            []), dtype=tf.int32)
    log_uniform_weights = ps.zeros(
        ps.concat([[num_particles], broadcast_batch_shape], axis=0),
        dtype=tf.float32) - ps.log(num_particles)

    # Initialize from the prior and incorporate the first observation.
    dummy_previous_step = ParticleFilterStepResults(
        particles=prior_fn(0, []).sample(),
        log_weights=log_uniform_weights,
        parent_indices=None,
        incremental_log_marginal_likelihood=0.,
        accumulated_log_marginal_likelihood=0.)
    initial_step_results = _filter_one_step(
        step=0,
        # `previous_particles` at the first step is a dummy quantity, used
        # only to convey state structure and num_particles to an optional
        # proposal fn.
        previous_step_results=dummy_previous_step,
        observation=tf.nest.map_structure(
            lambda x: tf.gather(x, 0), observations),
        transition_fn=prior_fn,
        observation_fn=observation_fn,
        proposal_fn=prior_proposal_fn,
        resample_fn=resample_fn,
        resample_criterion_fn=resample_criterion_fn,
        seed=seed)

    def _loop_body(step,
                   previous_step_results,
                   accumulated_traced_results,
                   num_steps_traced):
      """Take one step in dynamics and accumulate marginal likelihood."""
      step_has_observation = (
          # The second of these conditions subsumes the first, but both are
          # useful because the first can often be evaluated statically.
          ps.equal(num_transitions_per_observation, 1) |
          ps.equal(step % num_transitions_per_observation, 0))
      observation_idx = step // num_transitions_per_observation
      current_observation = tf.nest.map_structure(
          lambda x, step=step: tf.gather(x, observation_idx), observations)

      new_step_results = _filter_one_step(
          step=step,
          previous_step_results=previous_step_results,
          observation=current_observation,
          transition_fn=transition_fn,
          observation_fn=observation_fn,
          proposal_fn=proposal_fn,
          resample_criterion_fn=resample_criterion_fn,
          resample_fn=resample_fn,
          has_observation=step_has_observation,
          seed=seed)

      return _update_loop_variables(
          step=step,
          current_step_results=new_step_results,
          accumulated_traced_results=accumulated_traced_results,
          trace_fn=trace_fn,
          step_indices_to_trace=step_indices_to_trace,
          num_steps_traced=num_steps_traced)

    loop_results = tf.while_loop(
        cond=lambda step, *_: step < num_timesteps,
        body=_loop_body,
        loop_vars=_initialize_loop_variables(
            initial_step_results=initial_step_results,
            num_timesteps=num_timesteps,
            trace_fn=trace_fn,
            step_indices_to_trace=step_indices_to_trace))

    results = tf.nest.map_structure(
        lambda ta: ta.stack(), loop_results.accumulated_traced_results)
    if step_indices_to_trace is not None:
      # If we were passed a rank-0 (single scalar) step to trace, don't
      # return a time axis in the returned results.
      results = ps.cond(
          traced_steps_have_rank_zero,
          lambda: tf.nest.map_structure(lambda x: x[0, ...], results),
          lambda: results)

    return results