timesteps = 15
    Δ = 1.0
    A = jnp.array([[1, 0, Δ, 0], [0, 1, 0, Δ], [0, 0, 1, 0], [0, 0, 0, 1]])

    C = jnp.array([[1, 0, 0, 0], [0, 1, 0, 0]])

    state_size, _ = A.shape
    observation_size, _ = C.shape

    Q = jnp.eye(state_size) * 0.001
    R = jnp.eye(observation_size) * 1.0
    # Prior parameter distribution
    mu0 = jnp.array([8, 10, 1, 0]).astype(float)
    Sigma0 = jnp.eye(state_size) * 1.0

    lds_instance = lds.KalmanFilter(A, C, Q, R, mu0, Sigma0, timesteps)
    result = sample_filter_smooth(lds_instance, key)

    l2_filter = jnp.linalg.norm(
        result["z_hist"][:, :2] - result["mu_hist"][:, :2], 2)
    l2_smooth = jnp.linalg.norm(
        result["z_hist"][:, :2] - result["mu_hist_smooth"][:, :2], 2)

    print(f"L2-filter: {l2_filter:0.4f}")
    print(f"L2-smooth: {l2_smooth:0.4f}")

    fig, axs = plt.subplots()
    axs.plot(result["x_hist"][:, 0],
             result["x_hist"][:, 1],
             marker="o",
             linewidth=0,
Beispiel #2
0
    plt.rcParams["axes.spines.right"] = False
    plt.rcParams["axes.spines.top"] = False

    dx = 1.1
    timesteps = 20
    key = random.PRNGKey(27182)

    mean_0 = jnp.array([1, 1, 1, 0])
    Sigma_0 = jnp.eye(4)
    A = jnp.array([[0.1, 1.1, dx, 0], [-1, 1, 0, dx], [0, 0, 0.1, 0],
                   [0, 0, 0, 0.1]])
    C = jnp.array([[1, 0, 0, 0], [0, 1, 0, 0]])
    Q = jnp.eye(4) * 0.001
    R = jnp.eye(2) * 4

    lds_instance = lds.KalmanFilter(A, C, Q, R, mean_0, Sigma_0, timesteps)
    state_hist, obs_hist = lds_instance.sample(key)

    res = lds_instance.filter(obs_hist)
    mean_hist, Sigma_hist, mean_cond_hist, Sigma_cond_hist = res
    mean_hist_smooth, Sigma_hist_smooth = lds_instance.smooth(
        mean_hist, Sigma_hist, mean_cond_hist, Sigma_cond_hist)

    fig, ax = plt.subplots()
    ax.plot(*state_hist[:, :2].T, linestyle="--")
    ax.scatter(*obs_hist.T, marker="+", s=60)
    ax.set_title("State space")
    pml.savefig("spiral-state.pdf")

    fig, ax = plt.subplots()
    ax.plot(*mean_hist[:, :2].T)