def _design_matrix_for_one_seasonal_effect(num_steps, duration, period, dtype):
  current_period = np.int32(np.arange(num_steps) / duration) % period
  return np.transpose([
      ps.where(current_period == p,  # pylint: disable=g-complex-comprehension
               ps.ones([], dtype=dtype),
               ps.zeros([], dtype=dtype))
      for p in range(period)])
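
# A minimal numeric sketch of the design matrix built above, using plain
# NumPy in place of `ps` (assumption: `prefer_static.where/ones/zeros` reduce
# to their NumPy equivalents for concrete inputs). With num_steps=6,
# duration=1, period=3, each row one-hot encodes the active seasonal effect.
import numpy as np

def _design_matrix_sketch(num_steps, duration, period):
  current_period = np.int32(np.arange(num_steps) / duration) % period
  return np.transpose(
      [np.where(current_period == p, 1., 0.) for p in range(period)])

# _design_matrix_sketch(num_steps=6, duration=1, period=3) ->
# [[1., 0., 0.],
#  [0., 1., 0.],
#  [0., 0., 1.],
#  [1., 0., 0.],
#  [0., 1., 0.],
#  [0., 0., 1.]]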
def _inverse(self, y):
  ndims = prefer_static.rank(y)
  indices = prefer_static.reshape(
      prefer_static.add(self.axis, ndims), shape=[-1, 1])
  num_left, num_right = prefer_static.unstack(self.paddings, num=2, axis=-1)
  x = tf.slice(
      y,
      begin=prefer_static.tensor_scatter_nd_update(
          prefer_static.zeros(ndims, dtype=tf.int32), indices, num_left),
      size=prefer_static.tensor_scatter_nd_sub(
          prefer_static.shape(y), indices, num_left + num_right))
  if not self.validate_args:
    return x
  assertions = [
      assert_util.assert_equal(
          self._forward(x), y,
          message=('Argument `y` to `inverse` was not padded with '
                   '`constant_values`.')),
  ]
  with tf.control_dependencies(assertions):
    return tf.identity(x)
def _init(shape_and_dtype):
  """Allocate TensorArray for storing state and momentum."""
  return [  # pylint: disable=g-complex-comprehension
      prefer_static.zeros(
          prefer_static.concat([[max(self._write_instruction) + 1], s],
                               axis=0),
          dtype=d)
      for (s, d) in shape_and_dtype
  ]
def init_velocity_state_memory(self, input_tensors):
  """Allocate TensorArray for storing state and momentum."""
  shape_and_dtype = [(ps.shape(x_), x_.dtype) for x_ in input_tensors]
  return [  # pylint: disable=g-complex-comprehension
      ps.zeros(
          ps.concat([[max(self._write_instruction) + 1], s], axis=0),
          dtype=d)
      for (s, d) in shape_and_dtype
  ]
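
# Hedged shape sketch for the buffer allocated above: with
# `self._write_instruction == [0, 1, 2]` (an illustrative value) and a state
# part of shape [7], `ps.zeros` is called with shape [3, 7] -- one slot per
# write instruction, stacked along a new leading axis.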
def _squeeze(x, axis):
  """A version of squeeze that works with dynamic axis."""
  x = tf.convert_to_tensor(x, name='x')
  if axis is None:
    return tf.squeeze(x, axis=None)
  axis = ps.convert_to_shape_tensor(axis, name='axis', dtype=tf.int32)
  axis = axis + ps.zeros([1], dtype=axis.dtype)  # Make axis at least 1d.
  keep_axis = ps.setdiff1d(ps.range(0, ps.rank(x)), axis)
  return tf.reshape(x, ps.gather(ps.shape(x), keep_axis))
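
# Hedged usage sketch for `_squeeze` above: unlike `tf.squeeze`, the axis may
# be a dynamic (runtime) tensor, since the kept dimensions are computed and
# gathered at runtime. A plain-TF rendering of the same computation:
import tensorflow as tf

x = tf.ones([3, 1, 4])
axis = tf.constant([1])  # may be data-dependent
mask = ~tf.reduce_any(
    tf.equal(tf.range(tf.rank(x))[:, tf.newaxis], axis), axis=-1)
keep_axis = tf.boolean_mask(tf.range(tf.rank(x)), mask)
tf.reshape(x, tf.gather(tf.shape(x), keep_axis))  # shape [3, 4]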
def _forward(self, x):
  ndims = ps.rank(x)
  indices = ps.reshape(ps.add(self.axis, ndims), shape=[-1, 1])
  return tf.pad(
      x,
      paddings=ps.tensor_scatter_nd_update(
          ps.zeros([ndims, 2], dtype=tf.int32), indices, self.paddings),
      mode=self.mode,
      constant_values=ps.cast(self.constant_values, dtype=x.dtype))
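
# Hedged sketch of the scatter above for a Pad-style bijector with axis=-1
# and paddings=[[0, 1]] on a rank-2 input: only the row of the full paddings
# matrix corresponding to the padded axis is nonzero.
import tensorflow as tf

x = tf.constant([[1., 2.]])
# Equivalent to the scattered paddings [[0, 0], [0, 1]]:
tf.pad(x, paddings=[[0, 0], [0, 1]], constant_values=0.)  # [[1., 2., 0.]]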
def _entropy(self, **kwargs):
  if not self.bijector.is_constant_jacobian:
    raise NotImplementedError('`entropy` is not implemented.')
  if not self.bijector._is_injective:  # pylint: disable=protected-access
    raise NotImplementedError('`entropy` is not implemented when '
                              '`bijector` is not injective.')
  distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)
  override_event_shape = tf.convert_to_tensor(self._override_event_shape)
  override_batch_shape = tf.convert_to_tensor(self._override_batch_shape)
  base_batch_shape_tensor = self.distribution.batch_shape_tensor()
  base_event_shape_tensor = self.distribution.event_shape_tensor()
  # Suppose Y = g(X) where g is a diffeomorphism and X is a continuous rv. It
  # can be shown that:
  #   H[Y] = H[X] + E_X[(log o abs o det o J o g)(X)].
  # If is_constant_jacobian then:
  #   E_X[(log o abs o det o J o g)(X)] = (log o abs o det o J o g)(c)
  # where c can be anything.
  entropy = self.distribution.entropy(**distribution_kwargs)
  if self._is_maybe_event_override:
    # H[X] = sum_i H[X_i] if X_i are mutually independent.
    # This means that a reduce_sum is a simple rescaling.
    entropy = entropy * tf.cast(
        tf.reduce_prod(override_event_shape),
        dtype=dtype_util.base_dtype(entropy.dtype))
  if self._is_maybe_batch_override:
    new_shape = tf.concat([
        prefer_static.ones_like(override_batch_shape),
        base_batch_shape_tensor
    ], 0)
    entropy = tf.reshape(entropy, new_shape)
    multiples = tf.concat([
        override_batch_shape,
        prefer_static.ones_like(base_batch_shape_tensor)
    ], 0)
    entropy = tf.tile(entropy, multiples)
  dummy = prefer_static.zeros(
      shape=tf.concat([
          self._batch_shape_tensor(override_batch_shape,
                                   base_batch_shape_tensor),
          self._event_shape_tensor(override_event_shape,
                                   base_event_shape_tensor)
      ], 0),
      dtype=self.dtype)
  event_ndims = (
      tensorshape_util.rank(self.event_shape)  # pylint: disable=g-long-ternary
      if tensorshape_util.rank(self.event_shape) is not None
      else tf.size(
          self._event_shape_tensor(override_event_shape,
                                   base_event_shape_tensor)))
  ildj = self.bijector.inverse_log_det_jacobian(
      dummy, event_ndims=event_ndims, **bijector_kwargs)
  entropy = entropy - tf.cast(ildj, entropy.dtype)
  tensorshape_util.set_shape(entropy, self.batch_shape)
  return entropy
def _inverse(self, y):
  ndims = ps.rank(y)
  shifted_y = ps.pad(
      ps.slice(
          y,
          ps.zeros(ndims, dtype=tf.int32),
          ps.shape(y) - ps.one_hot(ndims + self.axis, ndims, dtype=tf.int32)
      ),  # Remove the last entry of y in the chosen dimension.
      paddings=ps.one_hot(
          ps.one_hot(ndims + self.axis, ndims, on_value=0, off_value=-1),
          2,
          dtype=tf.int32
      )  # Insert zeros at the beginning of the chosen dimension.
  )
  return y - shifted_y
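
# Hedged sketch of what `_inverse` above computes: it inverts a
# cumulative-sum bijector by taking first differences along `self.axis`.
# With axis=-1 and y = cumsum(x) = [1., 3., 6.]:
import tensorflow as tf

y = tf.constant([1., 3., 6.])
shifted_y = tf.pad(y[:-1], paddings=[[1, 0]])  # [0., 1., 3.]
x = y - shifted_y                              # [1., 2., 3.] recovers x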
def body_fn(vecs, i):
  # Slice out the vector w.r.t. which we're orthogonalizing the rest.
  vecs_ndims = ps.rank(vecs)
  select_axis = (ps.range(vecs_ndims) == vecs_ndims - 1)
  start = ps.where(select_axis, i, ps.zeros([vecs_ndims], i.dtype))
  size = ps.where(select_axis, 1, ps.shape(vecs))
  u = tf.math.l2_normalize(tf.slice(vecs, start, size), axis=-2)
  # TODO(b/171730305): XLA can't handle this line...
  # u = tf.math.l2_normalize(vecs[..., i, tf.newaxis], axis=-2)
  # Find weights by dotting the d x 1 against the d x n.
  weights = tf.einsum('...dm,...dn->...n', u, vecs)
  # Project out vector `u` from the trailing vectors.
  masked_weights = tf.where(tf.range(n) > i, weights, 0.)[..., tf.newaxis, :]
  vecs = vecs - tf.math.multiply_no_nan(u, masked_weights)
  tensorshape_util.set_shape(vecs, vectors.shape)
  return vecs, i + 1
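
# Hedged sketch of the (modified) Gram-Schmidt loop that `body_fn` above is
# assumed to sit inside (`n` and `vectors` are closure variables there). At
# step `i` the i-th column is normalized and projected out of all later
# columns, so the loop ends with mutually orthogonal columns.
import tensorflow as tf

def gram_schmidt_sketch(vectors):
  n = tf.shape(vectors)[-1]

  def body_fn(vecs, i):
    u = tf.math.l2_normalize(vecs[..., i, tf.newaxis], axis=-2)
    weights = tf.einsum('...dm,...dn->...n', u, vecs)
    masked_weights = tf.where(
        tf.range(n) > i, weights, 0.)[..., tf.newaxis, :]
    return vecs - u * masked_weights, i + 1

  vecs, _ = tf.while_loop(
      cond=lambda _, i: i < n - 1,
      body=body_fn,
      loop_vars=(vectors, tf.zeros([], tf.int32)))
  return tf.math.l2_normalize(vecs, axis=-2)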
def _entropy(self):
  if not self.bijector.is_constant_jacobian:
    raise NotImplementedError("entropy is not implemented")
  if not self.bijector._is_injective:  # pylint: disable=protected-access
    raise NotImplementedError("entropy is not implemented when "
                              "bijector is not injective.")
  # Suppose Y = g(X) where g is a diffeomorphism and X is a continuous rv. It
  # can be shown that:
  #   H[Y] = H[X] + E_X[(log o abs o det o J o g)(X)].
  # If is_constant_jacobian then:
  #   E_X[(log o abs o det o J o g)(X)] = (log o abs o det o J o g)(c)
  # where c can be anything.
  entropy = self.distribution.entropy()
  if self._is_maybe_event_override:
    # H[X] = sum_i H[X_i] if X_i are mutually independent.
    # This means that a reduce_sum is a simple rescaling.
    entropy *= tf.cast(
        tf.reduce_prod(input_tensor=self._override_event_shape),
        dtype=entropy.dtype.base_dtype)
  if self._is_maybe_batch_override:
    new_shape = tf.concat([
        prefer_static.ones_like(self._override_batch_shape),
        self.distribution.batch_shape_tensor()
    ], 0)
    entropy = tf.reshape(entropy, new_shape)
    multiples = tf.concat([
        self._override_batch_shape,
        prefer_static.ones_like(self.distribution.batch_shape_tensor())
    ], 0)
    entropy = tf.tile(entropy, multiples)
  dummy = prefer_static.zeros(
      shape=tf.concat(
          [self.batch_shape_tensor(), self.event_shape_tensor()], 0),
      dtype=self.dtype)
  event_ndims = (self.event_shape.ndims
                 if self.event_shape.ndims is not None
                 else tf.size(input=self.event_shape_tensor()))
  ildj = self.bijector.inverse_log_det_jacobian(
      dummy, event_ndims=event_ndims)
  entropy -= tf.cast(ildj, entropy.dtype)
  entropy.set_shape(self.batch_shape)
  return entropy
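
# A small numeric check of the constant-Jacobian entropy identity used in
# both implementations above (assumes the standard public TFP API): if
# Y = 2 * X with X ~ Normal(0, 1), then H[Y] = H[X] + log 2, with no
# sampling needed because the Jacobian log-determinant is constant.
import numpy as np
import tensorflow_probability as tfp
tfd, tfb = tfp.distributions, tfp.bijectors

base = tfd.Normal(loc=0., scale=1.)
y = tfd.TransformedDistribution(distribution=base, bijector=tfb.Scale(2.))
np.testing.assert_allclose(
    y.entropy(), base.entropy() + np.log(2.), rtol=1e-6)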
def interpolate_backward_differences(backward_differences, order,
                                     step_size_ratio):
  """Updates backward differences when a change in the step size occurs."""
  state_dtype = backward_differences.dtype
  interpolation_matrix_ = interpolation_matrix(state_dtype, order,
                                               step_size_ratio)
  interpolation_matrix_unit_step_size_ratio = interpolation_matrix(
      state_dtype, order, 1.)
  interpolated_backward_differences_orders_one_to_five = tf.matmul(
      interpolation_matrix_unit_step_size_ratio,
      tf.matmul(interpolation_matrix_, backward_differences[1:MAX_ORDER + 1]))
  interpolated_backward_differences = tf.concat([
      tf.gather(backward_differences, [0]),
      interpolated_backward_differences_orders_one_to_five,
      ps.zeros(
          ps.stack([2, ps.shape(backward_differences)[1]]),
          dtype=state_dtype),
  ], 0)
  return interpolated_backward_differences
def _forward(self, x):
  x = tf.convert_to_tensor(x, name='x')
  batch_shape = ps.shape(x)[:-1]

  # Pad zeros on the top row and right column.
  y = fill_triangular.FillTriangular().forward(x)
  rank = ps.rank(y)
  paddings = ps.concat(
      [ps.zeros([rank - 2, 2], dtype=tf.int32), [[1, 0], [0, 1]]],
      axis=0)
  y = tf.pad(y, paddings)

  # Set diagonal to 1s.
  n = ps.shape(y)[-1]
  diag = tf.ones(ps.concat([batch_shape, [n]], axis=-1), dtype=x.dtype)
  y = tf.linalg.set_diag(y, diag)

  # Normalize each row to have Euclidean (L2) norm 1.
  y /= tf.norm(y, axis=-1)[..., tf.newaxis]
  return y
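
# Hedged worked example of the forward map above (this matches the
# construction of a correlation-matrix Cholesky factor, e.g.
# `tfb.CorrelationCholesky`). For a length-1 input x = [t]:
#   FillTriangular     -> [[t]]
#   pad top row/right  -> [[0, 0], [t, 0]]
#   set diagonal to 1  -> [[1, 0], [t, 1]]
#   row-normalize      -> [[1, 0], [t / sqrt(1 + t**2), 1 / sqrt(1 + t**2)]]
# so y @ y^T has unit diagonal and off-diagonal t / sqrt(1 + t**2) in (-1, 1),
# i.e. y is the Cholesky factor of a valid 2x2 correlation matrix.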
def _particle_filter_initial_weighted_particles(observations,
                                                observation_fn,
                                                initial_state_prior,
                                                initial_state_proposal,
                                                num_particles,
                                                seed=None):
  """Initialize a set of weighted particles including the first observation."""
  # Initial particles all have the same weight, `1. / num_particles`.
  broadcast_batch_shape = tf.convert_to_tensor(
      functools.reduce(
          ps.broadcast_shape,
          tf.nest.flatten(initial_state_prior.batch_shape_tensor()),
          []),
      dtype=tf.int32)
  initial_log_weights = ps.zeros(
      ps.concat([[num_particles], broadcast_batch_shape], axis=0),
      dtype=tf.float32) - ps.log(num_particles)

  # Propose an initial state.
  if initial_state_proposal is None:
    initial_state = initial_state_prior.sample(num_particles, seed=seed)
  else:
    initial_state = initial_state_proposal.sample(num_particles, seed=seed)
    initial_log_weights += (initial_state_prior.log_prob(initial_state) -
                            initial_state_proposal.log_prob(initial_state))
  # The initial proposal weights are normalized in expectation, but actually
  # normalizing them reduces variance in the initial marginal likelihood.
  initial_log_weights = tf.nn.log_softmax(initial_log_weights, axis=0)

  # Return particles weighted by the initial observation.
  return smc_kernel.WeightedParticles(
      particles=initial_state,
      log_weights=initial_log_weights + _compute_observation_log_weights(
          step=0,
          particles=initial_state,
          observations=observations,
          observation_fn=observation_fn))
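
# Hedged numeric sketch of the weight initialization above: with no proposal,
# every particle starts at log(1 / num_particles), and `log_softmax` leaves
# already-normalized log-weights unchanged.
import numpy as np

num_particles = 4
initial_log_weights = np.zeros([num_particles]) - np.log(num_particles)
# log_softmax(w) = w - logsumexp(w); here logsumexp(w) == 0.
normalized = initial_log_weights - np.log(np.sum(np.exp(initial_log_weights)))
np.testing.assert_allclose(normalized, initial_log_weights)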
def make_convolution_transpose_fn_with_subkernels_matrix(
    filter_shape,
    strides,
    padding,
    rank=2,
    dilations=None,
    dtype=tf.int32,
    validate_args=False,
    name=None):
  """Like `tf.nn.conv2d` except applies batch of kernels to batch of `x`."""
  with tf.name_scope(name or 'make_convolution_transpose_fn_with_dilation'):

    if tf.get_static_value(rank) != 2:
      raise NotImplementedError(
          'Argument `rank` currently only supports `2`; '
          'saw "{}".'.format(rank))

    strides = tf.get_static_value(strides)
    if not isinstance(strides, int):
      raise ValueError('Argument `strides` must be a statically known '
                       'integer. Saw: {}'.format(strides))

    [
        filter_shape,
        rank,
        _,
        padding,
        dilations,
    ] = prepare_conv_args(
        filter_shape,
        rank=rank,
        strides=strides,
        padding=padding,
        dilations=dilations,
        is_transpose=True,
        validate_args=validate_args)

    fh, fw = filter_shape
    dh, dw = dilations

    # Determine maximum filter height and filter width of sub-kernels.
    sub_fh = (fh - 1) // strides + 1
    sub_fw = (fw - 1) // strides + 1

    def loop_body(i_, event_ind):
      i = i_ // strides
      j = i_ % strides

      i_ind = ps.range(i * fw, fw * fh, delta=strides * fw, dtype=dtype)
      j_ind = ps.range(j, fw, delta=strides, dtype=dtype)

      nc = cartesian_add([i_ind, j_ind])
      ind = ps.reverse(ps.reshape(nc, shape=[-1]), axis=[0])

      k = ps.reshape(
          cartesian_add([
              ps.range(ps.shape(nc)[0] * sub_fw, delta=sub_fw, dtype=dtype),
              ps.range(ps.shape(nc)[1], dtype=dtype)
          ]),
          shape=[-1])
      last_j = strides - (fw - j - 1) % strides - 1
      last_i = strides - (fh - i - 1) % strides - 1
      kernel_ind = ps.stack(
          [k, ps.ones_like(k) * last_i * strides + last_j], axis=1)
      event_ind = ps.tensor_scatter_nd_update(
          event_ind, ind[..., tf.newaxis], kernel_ind)

      return i_ + 1, event_ind

    event_ind = ps.zeros((fh * fw, 2), dtype=dtype)
    _, event_ind = tf.while_loop(
        lambda i, _: i < strides**2,
        loop_body,
        [tf.zeros([], dtype=dtype), event_ind])

    tot_pad_top, tot_pad_bottom = _get_transpose_conv_dilated_padding(
        fh, stride=strides, dilation=dh, padding=padding)
    tot_pad_left, tot_pad_right = _get_transpose_conv_dilated_padding(
        fw, stride=strides, dilation=dw, padding=padding)

    pad_bottom = (tot_pad_bottom - 1) // strides + 1
    pad_top = (tot_pad_top - 1) // strides + 1
    pad_right = (tot_pad_right - 1) // strides + 1
    pad_left = (tot_pad_left - 1) // strides + 1
    padding_vals = ((pad_top, pad_bottom), (pad_left, pad_right))

    truncate_top = pad_top * strides - tot_pad_top
    truncate_left = pad_left * strides - tot_pad_left

    def op(x, kernel):
      input_dtype = dtype_util.common_dtype([x, kernel],
                                            dtype_hint=tf.float32)
      x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
      kernel = tf.convert_to_tensor(kernel, dtype=input_dtype, name='kernel')

      batch_shape, event_shape = ps.split(
          ps.shape(x), num_or_size_splits=[-1, 3])
      xh, xw, c_in = ps.unstack(event_shape, num=3)

      kernel_shape = ps.shape(kernel)
      c_out = kernel_shape[-1]
      kernel_batch = kernel_shape[:-2]
      assertions = _maybe_validate_input_shapes(
          kernel_shape,
          channels_in=c_in,
          filter_height=fh,
          filter_width=fw,
          validate_args=validate_args)

      with tf.control_dependencies(assertions):
        # If the kernel does not have batch shape, fall back to
        # `conv2d_transpose` (unless dilations > 1, which is not implemented
        # in `conv2d_transpose`).
        if (tf.get_static_value(ps.rank(kernel)) == 2 and
            all(d == 1 for d in dilations)):
          return _call_conv2d_transpose(
              x,
              kernel=kernel,
              filter_shape=filter_shape,
              strides=(strides,) * rank,
              padding=padding,
              dilations=dilations,
              c_out=c_out,
              batch_shape=batch_shape,
              event_shape=event_shape)

        n = ps.maximum(0, ps.rank(x) - 3)
        paddings = ps.pad(
            padding_vals, paddings=[[n, 1], [0, 0]], constant_values=0)

        x_pad = tf.pad(x, paddings=paddings, constant_values=0)
        x_pad_shape = ps.shape(x_pad)[:-3]
        flat_shape = ps.pad(
            x_pad_shape, paddings=[[0, 1]], constant_values=-1)
        flat_x = tf.reshape(x_pad, shape=flat_shape)

        idx, s = im2row_index(
            (xh + tf.reduce_sum(padding_vals[0]),
             xw + tf.reduce_sum(padding_vals[1]), c_in),
            block_shape=(sub_fh, sub_fw),
            slice_step=(1, 1),
            dilations=dilations)

        x_ = tf.gather(flat_x, indices=idx, axis=-1)
        im_x = tf.reshape(x_, shape=ps.concat([x_pad_shape, s], axis=0))

        # Add channels to subkernel indices.
        idx_event = event_ind * [[c_in, 1]]
        idx_event_channels = (
            idx_event[tf.newaxis] + tf.stack(
                [ps.range(c_in),
                 tf.zeros((c_in,), dtype=dtype)], axis=-1)[:, tf.newaxis, :])
        idx_event = tf.squeeze(
            tf.batch_to_space(
                idx_event_channels, block_shape=[c_in], crops=[[0, 0]]),
            axis=0)
        idx_event_broadcast = tf.broadcast_to(
            idx_event,
            shape=ps.concat([kernel_batch, ps.shape(idx_event)], axis=0))

        # Add cartesian product of batch indices, since scatter_nd can only
        # be applied to leading dimensions.
        idx_batch = tf.stack(
            tf.meshgrid(
                *[ps.range(b_, delta=1, dtype=dtype)
                  for b_ in tf.unstack(kernel_batch)],
                indexing='ij'),
            axis=ps.size(kernel_batch))

        idx_batch = tf.cast(idx_batch, dtype=dtype)  # empty tensor is float

        idx_batch_broadcast = idx_batch[..., tf.newaxis, :] + tf.zeros(
            (ps.shape(idx_event)[0], 1), dtype=dtype)
        idx_kernel = tf.concat(
            [idx_batch_broadcast, idx_event_broadcast], axis=-1)

        kernel_mat = tf.scatter_nd(
            idx_kernel,
            updates=kernel,
            shape=ps.cast(
                ps.concat(
                    [kernel_batch,
                     [sub_fh * sub_fw * c_in, strides**2, c_out]],
                    axis=0),
                dtype=dtype))

        kernel_mat = tf.reshape(
            kernel_mat,
            shape=ps.concat(
                [ps.shape(kernel_mat)[:-2], [strides**2 * c_out]], axis=0))

        kernel_mat = kernel_mat[..., tf.newaxis, :, :]
        out = tf.matmul(im_x, kernel_mat)
        broadcast_batch_shape = ps.broadcast_shape(batch_shape, kernel_batch)

        if strides > 1:
          tot_size = tf.reduce_prod(broadcast_batch_shape)
          flat_out = tf.reshape(
              out,
              shape=ps.concat([[tot_size], ps.shape(out)[-3:]], axis=0))
          out = tf.nn.depth_to_space(flat_out, block_size=strides)

        if padding == 'VALID':
          out_height = fh + strides * (xh - 1)
          out_width = fw + strides * (xw - 1)
        elif padding == 'SAME':
          out_height = xh * strides
          out_width = xw * strides

        out = out[..., truncate_top:truncate_top + out_height,
                  truncate_left:truncate_left + out_width, :]
        out = tf.reshape(
            out,
            shape=ps.concat(
                [broadcast_batch_shape, [out_height, out_width, c_out]],
                axis=0))
        return out

    return op
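
# Hedged mini-demo of the `depth_to_space` trick relied on above: a transpose
# convolution with stride `s` can be assembled from s**2 sub-kernel matmuls
# whose outputs live in the channel dimension, then interleaved spatially.
import tensorflow as tf

x = tf.reshape(tf.range(4, dtype=tf.float32), [1, 1, 1, 4])
# The 4 "channels" become a 2x2 spatial block:
tf.nn.depth_to_space(x, block_size=2)  # shape [1, 2, 2, 1]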
def _batched_isotropic_normal_like(state_part):
  return sample.Sample(
      normal.Normal(ps.zeros([], dtype=state_part.dtype), 1.),
      ps.shape(state_part)[batch_rank:])
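
# Hedged shape sketch (assuming `normal`/`sample` are the TFP distribution
# modules and `batch_rank` is a closure variable): for a state part of shape
# [3, 4, 2] and batch_rank=1, this builds
# Sample(Normal(0., 1.), sample_shape=[4, 2]) -- a standard isotropic normal
# over the event dimensions, broadcast over the leading batch dimension.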
def particle_filter(observations,
                    initial_state_prior,
                    transition_fn,
                    observation_fn,
                    num_particles,
                    initial_state_proposal=None,
                    proposal_fn=None,
                    resample_criterion_fn=ess_below_threshold,
                    rejuvenation_kernel_fn=None,  # TODO(davmre): not yet supported. pylint: disable=unused-argument
                    num_transitions_per_observation=1,
                    num_steps_state_history_to_pass=None,
                    num_steps_observation_history_to_pass=None,
                    seed=None,
                    name=None):  # pylint: disable=g-doc-args
  """Samples a series of particles representing filtered latent states.

  The particle filter samples from the sequence of "filtering" distributions
  `p(state[t] | observations[:t])` over latent states: at each point in time,
  this is the distribution conditioned on all observations *up to that time*.
  Because particles may be resampled, a particle at time `t` may be different
  from the particle with the same index at time `t + 1`. To reconstruct
  trajectories by tracing back through the resampling process, see
  `tfp.experimental.mcmc.reconstruct_trajectories`.

  ${particle_filter_arg_str}
  Returns:
    particles: a (structure of) Tensor(s) matching the latent state, each
      of shape
      `concat([[num_timesteps, num_particles, b1, ..., bN], event_shape])`,
      representing (possibly weighted) samples from the series of filtering
      distributions `p(latent_states[t] | observations[:t])`.
    log_weights: `float` `Tensor` of shape
      `[num_timesteps, num_particles, b1, ..., bN]`, such that
      `log_weights[t, :]` are the logarithms of normalized importance weights
      (such that `exp(reduce_logsumexp(log_weights), axis=-1) == 1.`) of the
      particles at time `t`. These may be used in conjunction with `particles`
      to compute expectations under the series of filtering distributions.
    parent_indices: `int` `Tensor` of shape
      `[num_timesteps, num_particles, b1, ..., bN]`, such that
      `parent_indices[t, k]` gives the index of the particle at time `t - 1`
      that the `k`th particle at time `t` is immediately descended from. See
      also `tfp.experimental.mcmc.reconstruct_trajectories`.
    step_log_marginal_likelihoods: float `Tensor` of shape
      `[num_observation_steps, b1, ..., bN]`, giving the natural logarithm of
      an unbiased estimate of `p(observations[t] | observations[:t])` at each
      observed timestep `t`. Note that (by [Jensen's inequality](
      https://en.wikipedia.org/wiki/Jensen%27s_inequality)) this is *smaller*
      in expectation than the true
      `log p(observations[t] | observations[:t])`.

  ${non_markovian_specification_str}
  """
  seed = SeedStream(seed, 'particle_filter')
  with tf.name_scope(name or 'particle_filter'):
    num_observation_steps = prefer_static.shape(
        tf.nest.flatten(observations)[0])[0]
    num_timesteps = (
        1 + num_transitions_per_observation * (num_observation_steps - 1))

    # If no criterion is specified, default is to resample at every step.
    if not resample_criterion_fn:
      resample_criterion_fn = lambda _: True

    # Dress up the prior and prior proposal as a fake `transition_fn` and
    # `proposal_fn` respectively.
    prior_fn = lambda _1, _2: SampleParticles(  # pylint: disable=g-long-lambda
        initial_state_prior, num_particles)
    prior_proposal_fn = (
        None if initial_state_proposal is None
        else lambda _1, _2: SampleParticles(  # pylint: disable=g-long-lambda
            initial_state_proposal, num_particles))

    # Initially the particles all have the same weight, `1. / num_particles`.
    broadcast_batch_shape = tf.convert_to_tensor(
        functools.reduce(
            prefer_static.broadcast_shape,
            tf.nest.flatten(initial_state_prior.batch_shape_tensor()),
            []),
        dtype=tf.int32)
    log_uniform_weights = prefer_static.zeros(
        prefer_static.concat([[num_particles], broadcast_batch_shape], axis=0),
        dtype=tf.float32) - prefer_static.log(num_particles)

    # Initialize from the prior, and incorporate the first observation.
    initial_step_results = _filter_one_step(
        step=0,
        # `previous_particles` at the first step is a dummy quantity, used
        # only to convey state structure and num_particles to an optional
        # proposal fn.
        previous_particles=prior_fn(0, []).sample(),
        log_weights=log_uniform_weights,
        observation=tf.nest.map_structure(
            lambda x: tf.gather(x, 0), observations),
        transition_fn=prior_fn,
        observation_fn=observation_fn,
        proposal_fn=prior_proposal_fn,
        resample_criterion_fn=resample_criterion_fn,
        seed=seed)

    def _loop_body(step,
                   previous_step_results,
                   accumulated_step_results,
                   state_history):
      """Take one step in dynamics and accumulate marginal likelihood."""
      step_has_observation = (
          # The second of these conditions subsumes the first, but both are
          # useful because the first can often be evaluated statically.
          prefer_static.equal(num_transitions_per_observation, 1) |
          prefer_static.equal(step % num_transitions_per_observation, 0))
      observation_idx = step // num_transitions_per_observation
      current_observation = tf.nest.map_structure(
          lambda x, step=step: tf.gather(x, observation_idx), observations)

      history_to_pass_into_fns = {}
      if num_steps_observation_history_to_pass:
        history_to_pass_into_fns['observation_history'] = _gather_history(
            observations,
            observation_idx,
            num_steps_observation_history_to_pass)
      if num_steps_state_history_to_pass:
        history_to_pass_into_fns['state_history'] = state_history

      new_step_results = _filter_one_step(
          step=step,
          previous_particles=previous_step_results.particles,
          log_weights=previous_step_results.log_weights,
          observation=current_observation,
          transition_fn=functools.partial(
              transition_fn, **history_to_pass_into_fns),
          observation_fn=functools.partial(
              observation_fn, **history_to_pass_into_fns),
          proposal_fn=(
              None if proposal_fn is None else
              functools.partial(proposal_fn, **history_to_pass_into_fns)),
          resample_criterion_fn=resample_criterion_fn,
          has_observation=step_has_observation,
          seed=seed)

      return _update_loop_variables(
          step, new_step_results, accumulated_step_results, state_history)

    loop_results = tf.while_loop(
        cond=lambda step, *_: step < num_timesteps,
        body=_loop_body,
        loop_vars=_initialize_loop_variables(
            initial_step_results,
            num_steps_state_history_to_pass,
            num_timesteps))

    results = tf.nest.map_structure(
        lambda ta: ta.stack(), loop_results.accumulated_step_results)
    if num_transitions_per_observation != 1:
      # Return a log-prob for each observed step.
      observed_steps = prefer_static.range(
          0, num_timesteps, num_transitions_per_observation)
      results = results._replace(
          step_log_marginal_likelihood=tf.gather(
              results.step_log_marginal_likelihood, observed_steps))
    return results
def pad_tensor_with_trailing_zeros(x, num_zeros):
  return tf.pad(
      x,
      ps.concat(
          [ps.zeros([ps.rank(x) - 1, 2], dtype=np.int32), [[0, num_zeros]]],
          axis=0))
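
# Hedged usage sketch: the constructed paddings matrix is zero for every axis
# except the last, where `num_zeros` trailing zeros are appended.
import tensorflow as tf

x = tf.constant([[1., 2.], [3., 4.]])
# pad_tensor_with_trailing_zeros(x, num_zeros=2) is equivalent to:
tf.pad(x, [[0, 0], [0, 2]])  # [[1., 2., 0., 0.], [3., 4., 0., 0.]]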
def particle_filter(observations,
                    initial_state_prior,
                    transition_fn,
                    observation_fn,
                    num_particles,
                    initial_state_proposal=None,
                    proposal_fn=None,
                    resample_fn=weighted_resampling.resample_systematic,
                    resample_criterion_fn=ess_below_threshold,
                    rejuvenation_kernel_fn=None,  # TODO(davmre): not yet supported. pylint: disable=unused-argument
                    num_transitions_per_observation=1,
                    trace_fn=_default_trace_fn,
                    step_indices_to_trace=None,
                    seed=None,
                    name=None):  # pylint: disable=g-doc-args
  """Samples a series of particles representing filtered latent states.

  The particle filter samples from the sequence of "filtering" distributions
  `p(state[t] | observations[:t])` over latent states: at each point in time,
  this is the distribution conditioned on all observations *up to that time*.
  Because particles may be resampled, a particle at time `t` may be different
  from the particle with the same index at time `t + 1`. To reconstruct
  trajectories by tracing back through the resampling process, see
  `tfp.experimental.mcmc.reconstruct_trajectories`.

  ${particle_filter_arg_str}
    trace_fn: Python `callable` defining the values to be traced at each step.
      It takes a `ParticleFilterStepResults` tuple and returns a structure of
      `Tensor`s. The default function returns
      `(particles, log_weights, parent_indices, step_log_likelihood)`.
    step_indices_to_trace: optional `int` `Tensor` listing, in increasing
      order, the indices of steps at which to record the values traced by
      `trace_fn`. If `None`, the default behavior is to trace at every
      timestep, equivalent to specifying
      `step_indices_to_trace=tf.range(num_timesteps)`.
    seed: Python `int` seed for random ops.
    name: Python `str` name for ops created by this method.
      Default value: `None` (i.e., `'particle_filter'`).
  Returns:
    particles: a (structure of) Tensor(s) matching the latent state, each
      of shape
      `concat([[num_timesteps, num_particles, b1, ..., bN], event_shape])`,
      representing (possibly weighted) samples from the series of filtering
      distributions `p(latent_states[t] | observations[:t])`.
    log_weights: `float` `Tensor` of shape
      `[num_timesteps, num_particles, b1, ..., bN]`, such that
      `log_weights[t, :]` are the logarithms of normalized importance weights
      (such that `exp(reduce_logsumexp(log_weights), axis=-1) == 1.`) of the
      particles at time `t`. These may be used in conjunction with `particles`
      to compute expectations under the series of filtering distributions.
    parent_indices: `int` `Tensor` of shape
      `[num_timesteps, num_particles, b1, ..., bN]`, such that
      `parent_indices[t, k]` gives the index of the particle at time `t - 1`
      that the `k`th particle at time `t` is immediately descended from. See
      also `tfp.experimental.mcmc.reconstruct_trajectories`.
    incremental_log_marginal_likelihoods: float `Tensor` of shape
      `[num_observation_steps, b1, ..., bN]`, giving the natural logarithm of
      an unbiased estimate of `p(observations[t] | observations[:t])` at each
      observed timestep `t`. Note that (by [Jensen's inequality](
      https://en.wikipedia.org/wiki/Jensen%27s_inequality)) this is *smaller*
      in expectation than the true
      `log p(observations[t] | observations[:t])`.
  """
  seed = SeedStream(seed, 'particle_filter')
  with tf.name_scope(name or 'particle_filter'):
    num_observation_steps = ps.size0(tf.nest.flatten(observations)[0])
    num_timesteps = (
        1 + num_transitions_per_observation * (num_observation_steps - 1))

    # If no criterion is specified, default is to resample at every step.
    if not resample_criterion_fn:
      resample_criterion_fn = lambda _: True

    # Canonicalize the list of steps to trace as a rank-1 tensor of (sorted)
    # positive integers. E.g., `3` -> `[3]`, `[-2, -1]` -> `[N - 2, N - 1]`.
    if step_indices_to_trace is not None:
      (step_indices_to_trace,
       traced_steps_have_rank_zero) = _canonicalize_steps_to_trace(
           step_indices_to_trace, num_timesteps)

    # Dress up the prior and prior proposal as a fake `transition_fn` and
    # `proposal_fn` respectively.
    prior_fn = lambda _1, _2: SampleParticles(  # pylint: disable=g-long-lambda
        initial_state_prior, num_particles)
    prior_proposal_fn = (
        None if initial_state_proposal is None
        else lambda _1, _2: SampleParticles(  # pylint: disable=g-long-lambda
            initial_state_proposal, num_particles))

    # Initially the particles all have the same weight, `1. / num_particles`.
    broadcast_batch_shape = tf.convert_to_tensor(
        functools.reduce(
            ps.broadcast_shape,
            tf.nest.flatten(initial_state_prior.batch_shape_tensor()),
            []),
        dtype=tf.int32)
    log_uniform_weights = ps.zeros(
        ps.concat([[num_particles], broadcast_batch_shape], axis=0),
        dtype=tf.float32) - ps.log(num_particles)

    # Initialize from the prior and incorporate the first observation.
    dummy_previous_step = ParticleFilterStepResults(
        particles=prior_fn(0, []).sample(),
        log_weights=log_uniform_weights,
        parent_indices=None,
        incremental_log_marginal_likelihood=0.,
        accumulated_log_marginal_likelihood=0.)
    initial_step_results = _filter_one_step(
        step=0,
        # `previous_particles` at the first step is a dummy quantity, used
        # only to convey state structure and num_particles to an optional
        # proposal fn.
        previous_step_results=dummy_previous_step,
        observation=tf.nest.map_structure(
            lambda x: tf.gather(x, 0), observations),
        transition_fn=prior_fn,
        observation_fn=observation_fn,
        proposal_fn=prior_proposal_fn,
        resample_fn=resample_fn,
        resample_criterion_fn=resample_criterion_fn,
        seed=seed)

    def _loop_body(step,
                   previous_step_results,
                   accumulated_traced_results,
                   num_steps_traced):
      """Take one step in dynamics and accumulate marginal likelihood."""
      step_has_observation = (
          # The second of these conditions subsumes the first, but both are
          # useful because the first can often be evaluated statically.
          ps.equal(num_transitions_per_observation, 1) |
          ps.equal(step % num_transitions_per_observation, 0))
      observation_idx = step // num_transitions_per_observation
      current_observation = tf.nest.map_structure(
          lambda x, step=step: tf.gather(x, observation_idx), observations)

      new_step_results = _filter_one_step(
          step=step,
          previous_step_results=previous_step_results,
          observation=current_observation,
          transition_fn=transition_fn,
          observation_fn=observation_fn,
          proposal_fn=proposal_fn,
          resample_criterion_fn=resample_criterion_fn,
          resample_fn=resample_fn,
          has_observation=step_has_observation,
          seed=seed)

      return _update_loop_variables(
          step=step,
          current_step_results=new_step_results,
          accumulated_traced_results=accumulated_traced_results,
          trace_fn=trace_fn,
          step_indices_to_trace=step_indices_to_trace,
          num_steps_traced=num_steps_traced)

    loop_results = tf.while_loop(
        cond=lambda step, *_: step < num_timesteps,
        body=_loop_body,
        loop_vars=_initialize_loop_variables(
            initial_step_results=initial_step_results,
            num_timesteps=num_timesteps,
            trace_fn=trace_fn,
            step_indices_to_trace=step_indices_to_trace))

    results = tf.nest.map_structure(
        lambda ta: ta.stack(), loop_results.accumulated_traced_results)
    if step_indices_to_trace is not None:
      # If we were passed a rank-0 (single scalar) step to trace, don't
      # return a time axis in the returned results.
      results = ps.cond(
          traced_steps_have_rank_zero,
          lambda: tf.nest.map_structure(lambda x: x[0, ...], results),
          lambda: results)

    return results
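
# Hedged end-to-end usage sketch for `particle_filter` (assumes the public
# `tfp.experimental.mcmc.particle_filter` entry point and the default
# `trace_fn`; the toy linear-Gaussian model below is illustrative only).
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

observed_data = tf.random.normal([20])  # 20 scalar observations

(particles, log_weights, parent_indices,
 incremental_log_marginal_likelihoods) = tfp.experimental.mcmc.particle_filter(
     observations=observed_data,
     initial_state_prior=tfd.Normal(loc=0., scale=1.),
     transition_fn=lambda _, state: tfd.Normal(loc=state, scale=0.1),
     observation_fn=lambda _, state: tfd.Normal(loc=state, scale=0.5),
     num_particles=1024,
     seed=42)
# `particles` has shape [20, 1024]; summing
# `incremental_log_marginal_likelihoods` over time estimates
# log p(observed_data).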