return x + dt * jnp.array([jnp.sin(x[1]), jnp.cos(x[0])]) def fx(x): return x dt = 0.4 nsteps = 100 # Initial state vector x0 = jnp.array([1.5, 0.0]) # State noise Qt = jnp.eye(2) * 0.001 # Observed noise Rt = jnp.eye(2) * 0.05 key = random.PRNGKey(314) model = ds.NLDS(lambda x: fz(x, dt), fx, Qt, Rt) sample_state, sample_obs = model.sample(key, x0, nsteps) n_particles = 3_000 fz_vec = jax.vmap(fz, in_axes=(0, None)) particle_filter = ds.BootstrapFiltering(lambda x: fz_vec(x, dt), fx, Qt, Rt) pf_mean = particle_filter.filter(key, x0, sample_obs, n_particles) plot_inference(sample_obs, pf_mean) pml.savefig("nlds2d_bootstrap.pdf") plot_samples(sample_state, sample_obs) pml.savefig("nlds2d_data.pdf") plt.show()
# *** Model settings ***
g = 10
dt = 0.015
qc = 0.06
# Process-noise covariance: qc * [[dt^3/3, dt^2/2], [dt^2/2, dt]].
Q = jnp.array([[qc * dt**3 / 3, qc * dt**2 / 2], [qc * dt**2 / 2, qc * dt]])

fx_vmap = jax.vmap(fx)
fz_vec = jax.vmap(lambda x: fz(x, g=g, dt=dt))

nsteps = 200
Rt = jnp.eye(1) * 0.02
x0 = jnp.array([1.5, 0.0]).astype(float)
time = jnp.arange(0, nsteps * dt, dt)

key = random.PRNGKey(3141)
key_samples, key_pf, key_noisy = random.split(key, 3)

model = ds.NLDS(lambda x: fz(x, g=g, dt=dt), fx, Q, Rt)
# Sample with the dedicated sub-key. Using the parent `key` after it has been
# split (as before) left key_samples unused and reused randomness shared with
# key_pf / key_noisy.
sample_state, sample_obs = model.sample(key_samples, x0, nsteps)

# *** Perturbed data ***
# Each time step is independently replaced, with probability 0.5, by a
# uniform draw in [-2, 2).
key_noisy, key_values = random.split(key_noisy)
samples_map = random.bernoulli(key_noisy, 0.5, (nsteps,))
# Shapes must be concrete Python ints, hence int(...) around the count.
replacement_values = random.uniform(
    key_values, (int(samples_map.sum()),), minval=-2, maxval=2
)
# JAX arrays are immutable: `.at[...].set` returns an updated copy and is the
# replacement for the deprecated jax.ops.index_update. Boolean-mask indexing
# needs a concrete mask, i.e. this must run outside of jit.
# NOTE(review): assumes sample_obs holds exactly nsteps scalars (1-D
# observations, consistent with the 1x1 Rt) so the mask lines up after ravel().
sample_obs_noise = sample_obs.ravel().at[samples_map].set(replacement_values)
colors = ["tab:red" if samp else "tab:blue" for samp in samples_map]

# *** Perform filtering ***
alpha, beta, kappa = 1, 0, 2