Esempio n. 1
0
def filter_routine(initial_state: MVNormalParameters,
                   observations: jnp.ndarray,
                   transition_function: Callable,
                   transition_covariance: jnp.ndarray,
                   observation_function: Callable,
                   observation_covariance: jnp.ndarray,
                   linearization_points: jnp.ndarray = None):
    """Run the parallel-in-time Extended Kalman Filter and return the filtered states.

    Every time step is turned into an element of an associative filtering
    operator, and all elements are combined with ``lax.associative_scan`` so
    the filtered marginals are obtained with temporal parallelization.
    TODO: add the reference for the parallel formulation.

    Parameters
    ----------
    initial_state: MVNormalParameters
        prior belief on the initial state distribution
    observations: (n, K) array
        the n observations, each of dimension K
    transition_function: callable
        transition function of the state space model
    transition_covariance: (D, D) array
        transition noise covariance, the same matrix for every time step
    observation_function: callable
        observation function of the state space model
    observation_covariance: (K, K) array
        observation noise covariance, the same matrix for every time step
    linearization_points: (n, D) array, optional
        points at which the Jacobians are computed; zeros (in the dtype of
        ``initial_state.mean``) when omitted

    Returns
    -------
    filtered_states: MVNormalParameters
        filtered means and covariances, stacked along the leading (time) axis
    """
    n_steps = observations.shape[0]
    prior_mean, prior_cov = initial_state.mean, initial_state.cov
    state_dim = prior_mean.shape[0]

    if linearization_points is None:
        linearization_points = jnp.zeros((n_steps, state_dim),
                                         dtype=prior_mean.dtype)

    # Step k is linearised around (x_{k-1}, x_k); the prior mean stands in
    # for the point preceding step 0.
    previous_points = jnp.concatenate(
        (prior_mean.reshape(1, -1), linearization_points[:-1]), 0)

    def element_params(obs, step, x_prev, x_curr):
        # Build the associative-scan element for a single time step.
        return make_associative_filtering_params(
            observation_function, observation_covariance,
            transition_function, transition_covariance,
            obs, step, prior_mean, prior_cov, x_prev, x_curr)

    As, bs, Cs, etas, Js = vmap(element_params)(
        observations, jnp.arange(n_steps), previous_points,
        linearization_points)
    _, means, covariances, _, _ = lax.associative_scan(
        filtering_operator, (As, bs, Cs, etas, Js))
    return vmap(MVNormalParameters)(means, covariances)
Esempio n. 2
0
def smoother_routine(transition_function: Callable,
                     transition_covariance: jnp.ndarray,
                     filtered_states: MVNormalParameters,
                     linearization_states: MVNormalParameters = None):
    """ Computes the smoothing pass using temporal parallelization and returns
    a series of smoothed states TODO:reference

    Each time step is turned into an element of an associative smoothing
    operator. The elements are combined with a reverse (backward in time)
    ``lax.associative_scan``, so all smoothed marginals are obtained in parallel.

    Parameters
    ----------
    transition_function: callable
        transition function of the state space model
    transition_covariance: (D, D) array
        transition covariance, the same matrix for every time step
    filtered_states: MVNormalParameters
        states resulting from the (iterated) filter, stacked along the leading
        (time) axis
    linearization_states: MVNormalParameters, optional
        states at which to compute the linearized transition function. Defaults
        to ``filtered_states``.

    Returns
    -------
    smoothed_states: MVNormalParameters
        smoothed means and covariances, stacked along the leading (time) axis
    """
    n_observations = filtered_states.mean.shape[0]

    # Resolve the default before vmapping so the traced function never has to
    # branch on a `None` pytree argument.
    if linearization_states is None:
        linearization_states = filtered_states

    @vmap
    def make_params(i, filtered_state, linearization_state):
        # Build the associative-scan element for time step i.
        return make_associative_smoothing_params(transition_function,
                                                 transition_covariance, i,
                                                 n_observations,
                                                 filtered_state,
                                                 linearization_state)

    gs, Es, Ls = make_params(jnp.arange(n_observations), filtered_states,
                             linearization_states)

    # reverse=True: smoothing information flows from the last step backwards.
    smoothed_means, _, smoothed_covariances = lax.associative_scan(
        smoothing_operator, (gs, Es, Ls), reverse=True)

    return vmap(MVNormalParameters)(smoothed_means, smoothed_covariances)
Esempio n. 3
0
def filter_routine(initial_state: MVNormalParameters,
                   observations: jnp.ndarray,
                   transition_function: Callable,
                   transition_covariance: jnp.ndarray,
                   observation_function: Callable,
                   observation_covariance: jnp.ndarray,
                   linearization_states: MVNormalParameters = None,
                   propagate_first: bool = True):
    """ Computes the predict-update routine of the Cubature Kalman Filter equations
    using temporal parallelization and returns a series of filtered_states TODO:reference

    Parameters
    ----------
    initial_state: MVNormalParameters
        prior belief on the initial state distribution
    observations: (n, K) array
        array of n observations of dimension K
    transition_function: callable
        transition function of the state space model
    transition_covariance: (D, D) array
        transition covariance, the same matrix for every time step
    observation_function: callable
        observation function of the state space model
    observation_covariance: (K, K) array
        observation error covariance, the same matrix for every time step
    linearization_states: MVNormalParameters, optional
        states around which the sigma-point linearization is computed.
        The required length depends on ``propagate_first``:

        - True: one more entry than ``observations``, where entry 0 corresponds
          to the initial time.
        - False: as many entries as ``observations``.

        When omitted, zero means and identity covariances are used (n entries).
    propagate_first: bool, optional
        Is the first step a transition or an update? i.e. False if the initial time step has
        an associated observation. Default is True.

    Returns
    -------
    filtered_states: MVNormalParameters
        filtered means and covariances stacked along the leading (time) axis.
        When ``propagate_first`` is True, ``initial_state`` is prepended, so
        the result has n + 1 entries.
    """
    n_observations = observations.shape[0]
    x_dim = initial_state.mean.shape[0]
    dtype = initial_state.mean.dtype

    if linearization_states is not None:
        if propagate_first:
            # Entry 0 is the initial time, so step k pairs entries (k, k + 1).
            x_k_1_s = jax.tree_util.tree_map(lambda z: z[:-1],
                                             linearization_states)
            x_k_s = jax.tree_util.tree_map(lambda z: z[1:],
                                           linearization_states)
        else:
            # Nothing precedes step 0, so its "previous" state duplicates entry 0.
            x_k_1_s = jax.tree_util.tree_map(
                lambda z: jnp.concatenate([z[None, 0], z[:-1]], 0),
                linearization_states)
            x_k_s = linearization_states
    else:
        m_k_s = jnp.zeros((n_observations, x_dim), dtype=dtype)
        # Build the identity in the same dtype as the means. Without it,
        # jnp.eye falls back to JAX's default float type.
        P_k_s = jnp.repeat(jnp.eye(x_dim, dtype=dtype)[None, ...],
                           n_observations, axis=0)
        x_k_1_s = x_k_s = MVNormalParameters(m_k_s, P_k_s)

    @vmap
    def make_params(obs, i, prev_linearization_state, linearisation_state):
        # Build the associative-scan element for time step i.
        return make_associative_filtering_params(
            observation_function, observation_covariance, transition_function,
            transition_covariance, obs, i, initial_state,
            prev_linearization_state, linearisation_state, propagate_first)

    As, bs, Cs, etas, Js = make_params(observations,
                                       jnp.arange(n_observations), x_k_1_s,
                                       x_k_s)
    _, filtered_means, filtered_covariances, _, _ = lax.associative_scan(
        filtering_operator, (As, bs, Cs, etas, Js))

    filtered_states = MVNormalParameters(filtered_means, filtered_covariances)
    if propagate_first:
        # The initial time has no update of its own; report the prior there.
        filtered_states = jax.tree_util.tree_map(
            lambda x, y: jnp.concatenate([x[None, ...], y], 0), initial_state,
            filtered_states)
    return filtered_states
Esempio n. 4
0
 def f():
     """Apply ``lax.associative_scan`` with the combiner ``err`` to [0., 1., 2., 3.].

     NOTE(review): ``err`` is defined outside this view; presumably it is a
     combining function used to exercise error handling — confirm.
     """
     return lax.associative_scan(err, jnp.arange(4.))
def filter_routine(initial_state: MVNormalParameters,
                   observations: jnp.ndarray,
                   transition_function: Callable,
                   transition_covariance: jnp.ndarray,
                   observation_function: Callable,
                   observation_covariance: jnp.ndarray,
                   linearization_states: MVNormalParameters = None):
    """Run the parallel-in-time Cubature Kalman Filter and return the filtered states.

    Every time step becomes an element of an associative filtering operator.
    All elements are combined with ``lax.associative_scan``.
    TODO: add the reference for the parallel formulation.

    Parameters
    ----------
    initial_state: MVNormalParameters
        prior belief on the initial state distribution
    observations: (n, K) array
        the n observations, each of dimension K
    transition_function: callable
        transition function of the state space model
    transition_covariance: (D, D) array
        transition covariance, the same matrix for every time step
    observation_function: callable
        observation function of the state space model
    observation_covariance: (K, K) array
        observation error covariance, the same matrix for every time step
    linearization_states: MVNormalParameters, optional
        n states around which the sigma-point linearization is computed.
        When omitted, the means are zeros. The covariances are built by
        ``make_matrices_parameters`` from the D x D identity (presumably one
        identity per step — confirm).

    Returns
    -------
    filtered_states: MVNormalParameters
        filtered means and covariances, stacked along the leading (time) axis
    """
    dtype = initial_state.mean.dtype
    x_dim = initial_state.mean.shape[0]
    n_steps = observations.shape[0]

    if linearization_states is None:
        linearization_states = MVNormalParameters(
            jnp.zeros((n_steps, x_dim), dtype=dtype),
            make_matrices_parameters(jnp.eye(x_dim, dtype=dtype), n_steps))

    # Shift the linearization states one step back in time; the initial prior
    # plays the role of the state preceding step 0.
    previous_means = jnp.concatenate(
        (initial_state.mean.reshape(1, -1), linearization_states.mean[:-1]), 0)
    previous_covs = jnp.concatenate(
        (initial_state.cov.reshape(1, x_dim, x_dim),
         linearization_states.cov[:-1]), 0)
    previous_states = MVNormalParameters(previous_means, previous_covs)

    def element_params(obs, step, prev_state, current_state):
        # Build the associative-scan element for a single time step.
        return make_associative_filtering_params(
            observation_function, observation_covariance, transition_function,
            transition_covariance, obs, step, initial_state,
            prev_state, current_state)

    As, bs, Cs, etas, Js = vmap(element_params)(
        observations, jnp.arange(n_steps), previous_states,
        linearization_states)
    _, means, covariances, _, _ = lax.associative_scan(
        filtering_operator, (As, bs, Cs, etas, Js))
    return vmap(MVNormalParameters)(means, covariances)