def filter_routine(initial_state: MVNormalParameters,
                   observations: jnp.ndarray,
                   transition_function: Callable,
                   transition_covariance: jnp.ndarray,
                   observation_function: Callable,
                   observation_covariance: jnp.ndarray,
                   linearization_points: jnp.ndarray = None):
    """
    Extended Kalman Filter predict-update pass, evaluated with temporal
    parallelization: one element is built per time step with ``vmap`` and the
    elements are combined by ``lax.associative_scan`` over
    ``filtering_operator``.
    TODO: reference

    Parameters
    ----------
    initial_state: MVNormalParameters
        prior belief on the initial state distribution
    observations: (n, K) array
        array of n observations of dimension K
    transition_function: callable
        transition function of the state space model
    transition_covariance: (D, D) array
        transition covariance for each time step
    observation_function: callable
        observation function of the state space model
    observation_covariance: (K, K) array
        observation error covariances for each time step
    linearization_points: (n, D) array, optional
        points at which to compute the jacobians. When omitted, an all-zero
        (n, D) array with the dtype of ``initial_state.mean`` is used.

    Returns
    -------
    filtered_states: MVNormalParameters
        list of filtered states
    """
    num_steps = observations.shape[0]
    state_dim = initial_state.mean.shape[0]

    if linearization_points is None:
        linearization_points = jnp.zeros((num_steps, state_dim),
                                         dtype=initial_state.mean.dtype)

    # "Previous" point paired with each step: the prior mean for the first
    # step, then the linearization points shifted by one.
    previous_points = jnp.concatenate(
        (initial_state.mean.reshape(1, -1), linearization_points[:-1]), 0)

    def element_params(observation, step, previous_point, current_point):
        return make_associative_filtering_params(
            observation_function, observation_covariance,
            transition_function, transition_covariance,
            observation, step,
            initial_state.mean, initial_state.cov,
            previous_point, current_point)

    A, b, C, eta, J = vmap(element_params)(
        observations, jnp.arange(num_steps), previous_points, linearization_points)

    # Only the second and third components of the scanned elements are kept.
    scanned = lax.associative_scan(filtering_operator, (A, b, C, eta, J))
    _, filtered_means, filtered_covariances, _, _ = scanned

    return vmap(MVNormalParameters)(filtered_means, filtered_covariances)
def smoother_routine(transition_function: Callable,
                     transition_covariance: jnp.ndarray,
                     filtered_states: MVNormalParameters,
                     linearization_states: MVNormalParameters = None):
    """
    Computes the smoothing pass using temporal parallelization and returns a
    series of smoothed states.

    One smoothing element is built per time step with ``vmap``; the elements
    are then combined by a reverse ``lax.associative_scan`` over
    ``smoothing_operator``.
    TODO:reference

    Parameters
    ----------
    transition_function: callable
        transition function of the state space model
    transition_covariance: (D, D) array
        transition covariance for each time step
    filtered_states: MVNormalParameters
        states resulting from the filtering pass, stacked along the leading
        axis (its length is read from ``filtered_states.mean.shape[0]``)
    linearization_states: MVNormalParameters, optional
        states at which to compute the cubature linearized functions.
        When omitted, ``filtered_states`` are used.

    Returns
    -------
    smoothed_states: MVNormalParameters
        list of smoothed states
    """
    n_observations = filtered_states.mean.shape[0]

    # Resolve the default before mapping so the vmapped function always
    # receives a concrete state rather than None.
    if linearization_states is None:
        linearization_states = filtered_states

    @vmap
    def make_params(i, filtered_state, linearization_state):
        return make_associative_smoothing_params(transition_function,
                                                 transition_covariance,
                                                 i, n_observations,
                                                 filtered_state,
                                                 linearization_state)

    gs, Es, Ls = make_params(jnp.arange(n_observations),
                             filtered_states,
                             linearization_states)

    # reverse=True: the scan runs from the last time step towards the first.
    # The first and third components of the result are the smoothed moments.
    smoothed_means, _, smoothed_covariances = lax.associative_scan(
        smoothing_operator, (gs, Es, Ls), reverse=True)

    return vmap(MVNormalParameters)(smoothed_means, smoothed_covariances)
def filter_routine(initial_state: MVNormalParameters,
                   observations: jnp.ndarray,
                   transition_function: Callable,
                   transition_covariance: jnp.ndarray,
                   observation_function: Callable,
                   observation_covariance: jnp.ndarray,
                   linearization_states: MVNormalParameters = None,
                   propagate_first: bool = True):
    """
    Computes the predict-update routine of the Cubature Kalman Filter
    equations using temporal parallelization and returns a series of
    filtered_states
    TODO:reference

    Parameters
    ----------
    initial_state: MVNormalParameters
        prior belief on the initial state distribution
    observations: (n, K) array
        array of n observations of dimension K
    transition_function: callable
        transition function of the state space model
    transition_covariance: (D, D) array
        transition covariance for each time step
    observation_function: callable
        observation function of the state space model
    observation_covariance: (K, K) array
        observation error covariances for each time step
    linearization_states: MVNormalParameters, optional
        states around which the sigma-point linearization is done. With
        ``propagate_first=True`` entries ``[:-1]`` are used as the previous
        states and entries ``[1:]`` as the current ones, so one more entry
        than observations is expected. With ``propagate_first=False`` they
        are used as the current states and the first entry is repeated as the
        previous state of the first step. When omitted, zero means and
        identity covariances are used for every step.
    propagate_first: bool, optional
        Is the first step a transition or an update? i.e. False if the
        initial time step has an associated observation. Default is True.

    Returns
    -------
    filtered_states: MVNormalParameters
        list of filtered states. When ``propagate_first`` is True the initial
        state is prepended to the result.
    """
    n_observations = observations.shape[0]
    x_dim = initial_state.mean.shape[0]
    dtype = initial_state.mean.dtype

    # jax.tree_util.tree_map is the long-standing spelling; the top-level
    # jax.tree_map alias is deprecated and missing from recent JAX releases.
    tree_map = jax.tree_util.tree_map

    if linearization_states is not None:
        if propagate_first:
            x_k_1_s = tree_map(lambda z: z[:-1], linearization_states)
            x_k_s = tree_map(lambda z: z[1:], linearization_states)
        else:
            # No state precedes the first one: reuse it as its own predecessor.
            x_k_1_s = tree_map(
                lambda z: jnp.concatenate([z[None, 0], z[:-1]], 0),
                linearization_states)
            x_k_s = linearization_states
    else:
        m_k_s = jnp.zeros((n_observations, x_dim), dtype=dtype)
        # Keep the covariance dtype aligned with the means.
        P_k_s = jnp.repeat(jnp.eye(x_dim, dtype=dtype)[None, ...],
                           n_observations, axis=0)
        x_k_1_s = x_k_s = MVNormalParameters(m_k_s, P_k_s)

    @vmap
    def make_params(obs, i, prev_linearization_state, linearisation_state):
        return make_associative_filtering_params(
            observation_function, observation_covariance,
            transition_function, transition_covariance,
            obs, i, initial_state,
            prev_linearization_state, linearisation_state,
            propagate_first)

    As, bs, Cs, etas, Js = make_params(observations,
                                       jnp.arange(n_observations),
                                       x_k_1_s, x_k_s)

    _, filtered_means, filtered_covariances, _, _ = lax.associative_scan(
        filtering_operator, (As, bs, Cs, etas, Js))

    filtered_states = MVNormalParameters(filtered_means, filtered_covariances)
    if propagate_first:
        # Prepend the prior so the output also carries the initial state.
        filtered_states = tree_map(
            lambda x, y: jnp.concatenate([x[None, ...], y], 0),
            initial_state, filtered_states)
    return filtered_states
def f():
    """Run ``lax.associative_scan`` with ``err`` as the combining function
    over the array ``jnp.arange(4.)`` and return the result."""
    return lax.associative_scan(err, jnp.arange(4.))
def filter_routine(initial_state: MVNormalParameters,
                   observations: jnp.ndarray,
                   transition_function: Callable,
                   transition_covariance: jnp.ndarray,
                   observation_function: Callable,
                   observation_covariance: jnp.ndarray,
                   linearization_states: MVNormalParameters = None):
    """
    Cubature Kalman Filter predict-update pass, evaluated with temporal
    parallelization: one element is built per time step with ``vmap`` and the
    elements are combined by ``lax.associative_scan`` over
    ``filtering_operator``.
    TODO: reference

    Parameters
    ----------
    initial_state: MVNormalParameters
        prior belief on the initial state distribution
    observations: (n, K) array
        array of n observations of dimension K
    transition_function: callable
        transition function of the state space model
    transition_covariance: (D, D) array
        transition covariance for each time step
    observation_function: callable
        observation function of the state space model
    observation_covariance: (K, K) array
        observation error covariances for each time step
    linearization_states: MVNormalParameters, optional
        states used for the sigma-point linearization at each time step.
        When omitted, zero means are paired with the output of
        ``make_matrices_parameters`` applied to an identity matrix and n
        (presumably n stacked copies -- confirm against that helper).

    Returns
    -------
    filtered_states: MVNormalParameters
        list of filtered states
    """
    num_steps = observations.shape[0]
    state_dim = initial_state.mean.shape[0]

    if linearization_states is None:
        float_type = initial_state.mean.dtype
        default_means = jnp.zeros((num_steps, state_dim), dtype=float_type)
        default_covs = make_matrices_parameters(
            jnp.eye(state_dim, dtype=float_type), num_steps)
        linearization_states = MVNormalParameters(default_means, default_covs)

    # "Previous" state paired with each step: the prior for the first step,
    # then the linearization states shifted by one.
    shifted_means = jnp.concatenate(
        (initial_state.mean.reshape(1, -1), linearization_states.mean[:-1]), 0)
    shifted_covs = jnp.concatenate(
        (initial_state.cov.reshape(1, state_dim, state_dim),
         linearization_states.cov[:-1]), 0)
    previous_states = MVNormalParameters(shifted_means, shifted_covs)

    def element_params(observation, step, previous_state, current_state):
        return make_associative_filtering_params(
            observation_function, observation_covariance,
            transition_function, transition_covariance,
            observation, step, initial_state,
            previous_state, current_state)

    A, b, C, eta, J = vmap(element_params)(
        observations, jnp.arange(num_steps), previous_states, linearization_states)

    # Only the second and third components of the scanned elements are kept.
    scanned = lax.associative_scan(filtering_operator, (A, b, C, eta, J))
    _, filtered_means, filtered_covariances, _, _ = scanned

    return vmap(MVNormalParameters)(filtered_means, filtered_covariances)