# Kalman filter / smoother demo on a 4-D linear-Gaussian tracking model.
# Relies on `jnp`, `lds`, `sample_filter_smooth`, `plt` and a JAX PRNG `key`
# defined elsewhere in the file; none of them is defined in this span.

# Number of time steps passed to the model.
timesteps = 15
# Time step used in the transition matrix below.
Δ = 1.0
# Transition matrix: the first two state components are advanced by Δ times
# the last two, and the last two are carried over unchanged
# (a constant-velocity model with state [pos_0, pos_1, vel_0, vel_1]).
A = jnp.array([[1, 0, Δ, 0],
               [0, 1, 0, Δ],
               [0, 0, 1, 0],
               [0, 0, 0, 1]])
# Observation matrix: selects only the first two state components.
C = jnp.array([[1, 0, 0, 0],
               [0, 1, 0, 0]])
state_size, _ = A.shape
observation_size, _ = C.shape
# Presumably the process- and observation-noise covariances, following the
# usual (A, C, Q, R) naming -- confirm against lds.KalmanFilter.
Q = jnp.eye(state_size) * 0.001
R = jnp.eye(observation_size) * 1.0
# Prior distribution: mean `mu0` and covariance `Sigma0` (presumably over the
# initial state -- confirm against lds.KalmanFilter).
# The mean is cast to float so it is not an integer array.
mu0 = jnp.array([8, 10, 1, 0]).astype(float)
Sigma0 = jnp.eye(state_size) * 1.0
lds_instance = lds.KalmanFilter(A, C, Q, R, mu0, Sigma0, timesteps)
# Returns a dict read below via the keys "z_hist", "mu_hist",
# "mu_hist_smooth" and "x_hist".
result = sample_filter_smooth(lds_instance, key)
# Error between the first two components of "z_hist" (presumably the sampled
# latent states) and the filtered / smoothed means.
# NOTE(review): with ord=2 on a 2-D array, jnp.linalg.norm returns the
# spectral norm (largest singular value), not the element-wise Euclidean
# (Frobenius) norm of the error -- confirm which metric is intended.
l2_filter = jnp.linalg.norm(
    result["z_hist"][:, :2] - result["mu_hist"][:, :2], 2)
l2_smooth = jnp.linalg.norm(
    result["z_hist"][:, :2] - result["mu_hist_smooth"][:, :2], 2)
print(f"L2-filter: {l2_filter:0.4f}")
print(f"L2-smooth: {l2_smooth:0.4f}")
fig, axs = plt.subplots()
# Plot the two columns of "x_hist" (presumably the observations) as markers
# only (linewidth=0).
# NOTE(review): this argument list is not closed before the next script
# starts -- the remaining arguments and closing parenthesis appear to be
# missing.
axs.plot(result["x_hist"][:, 0], result["x_hist"][:, 1], marker="o", linewidth=0,
# Spiral-dynamics demo: sample a 4-D linear-Gaussian state-space model, run
# the Kalman filter and smoother on the sampled observations, and plot them.
# Relies on `plt`, `jnp`, `random` (presumably jax.random), `lds` and `pml`
# being imported elsewhere in the file.

# Hide the right and top axis spines on every figure created afterwards.
plt.rcParams["axes.spines.right"] = False
plt.rcParams["axes.spines.top"] = False

# Coupling from the last two state components into the first two.
dx = 1.1
timesteps = 20
key = random.PRNGKey(27182)

# Prior mean and covariance passed to the model. The mean is cast to float
# (as the tracking demo does for `mu0`) so it is not an int32 array; an
# integer mean is fragile wherever dtypes must match (e.g. a scan carry),
# while float32 represents these values exactly, so results are unchanged.
mean_0 = jnp.array([1, 1, 1, 0]).astype(float)
Sigma_0 = jnp.eye(4)

# Transition matrix. The upper-left 2x2 block [[0.1, 1.1], [-1, 1]] has
# trace 1.1 and determinant 1.2, so its eigenvalues are a complex pair with
# modulus sqrt(1.2) ~= 1.095: the first two components rotate and grow.
# The last two components are scaled by 0.1 each step and feed the first
# two through `dx`.
A = jnp.array([[0.1, 1.1, dx, 0],
               [-1, 1, 0, dx],
               [0, 0, 0.1, 0],
               [0, 0, 0, 0.1]])
# Observation matrix: selects only the first two state components.
C = jnp.array([[1, 0, 0, 0],
               [0, 1, 0, 0]])
# Presumably the process- and observation-noise covariances -- confirm
# against lds.KalmanFilter.
Q = jnp.eye(4) * 0.001
R = jnp.eye(2) * 4

lds_instance = lds.KalmanFilter(A, C, Q, R, mean_0, Sigma_0, timesteps)
state_hist, obs_hist = lds_instance.sample(key)

# The filter returns four histories; all four are passed on to the smoother.
res = lds_instance.filter(obs_hist)
mean_hist, Sigma_hist, mean_cond_hist, Sigma_cond_hist = res
mean_hist_smooth, Sigma_hist_smooth = lds_instance.smooth(
    mean_hist, Sigma_hist, mean_cond_hist, Sigma_cond_hist)

# Sampled trajectory of the first two state components (dashed) together
# with the observations ("+" markers).
fig, ax = plt.subplots()
ax.plot(*state_hist[:, :2].T, linestyle="--")
ax.scatter(*obs_hist.T, marker="+", s=60)
ax.set_title("State space")
pml.savefig("spiral-state.pdf")

# First two components of the filter's mean history.
fig, ax = plt.subplots()
ax.plot(*mean_hist[:, :2].T)