Ejemplo n.º 1
0
        return x + dt * jnp.array([jnp.sin(x[1]), jnp.cos(x[0])])

    def fx(x):
        """Observation function: the state is observed directly (identity map).

        The input is returned unchanged (same object), so this works equally
        for a single state or a batch of states.
        """
        observed = x
        return observed

    # Simulation settings: step size passed to fz and trajectory length.
    dt = 0.4
    nsteps = 100
    # Initial state vector
    x0 = jnp.array([1.5, 0.0])
    # State noise matrix (presumably a covariance; passed as Qt to the models)
    Qt = jnp.eye(2) * 0.001
    # Observation noise matrix (presumably a covariance; passed as Rt)
    Rt = jnp.eye(2) * 0.05

    # Use independent PRNG streams for simulating the data and for running the
    # particle filter. Reusing one key for both would correlate the filter's
    # particle noise with the noise that generated the observations.
    key = random.PRNGKey(314)
    key_sample, key_pf = random.split(key)

    model = ds.NLDS(lambda x: fz(x, dt), fx, Qt, Rt)
    sample_state, sample_obs = model.sample(key_sample, x0, nsteps)

    n_particles = 3_000
    # Map fz over axis 0 of its first argument (one row per particle, presumably)
    # while sharing the scalar dt across the batch. fx is the identity, so it
    # already handles a batch without vmap.
    fz_vec = jax.vmap(fz, in_axes=(0, None))
    particle_filter = ds.BootstrapFiltering(lambda x: fz_vec(x, dt), fx, Qt,
                                            Rt)
    pf_mean = particle_filter.filter(key_pf, x0, sample_obs, n_particles)

    plot_inference(sample_obs, pf_mean)
    pml.savefig("nlds2d_bootstrap.pdf")

    plot_samples(sample_state, sample_obs)
    pml.savefig("nlds2d_data.pdf")

    plt.show()
Ejemplo n.º 2
0
# --- Corrupt a random subset of the observations with outliers ---
# Split key_noisy into a fresh key_noisy (used for the outlier mask) and
# key_values (used for the replacement values).
key_noisy, key_values = random.split(key_noisy)
sample_obs_noise = sample_obs.copy()
# Boolean mask of length nsteps: True (probability 0.5) marks a time step whose
# observation is replaced by an outlier.
samples_map = random.bernoulli(key_noisy, 0.5, (nsteps, ))
# One Uniform(-2, 2) replacement value per selected step.
# NOTE(review): samples_map.sum() is used as a shape, so it must be a concrete
# value -- this only works outside jax.jit.
replacement_values = random.uniform(key_values, (samples_map.sum(), ),
                                    minval=-2,
                                    maxval=2)
# Scatter the replacement values into the flattened observations at the masked
# positions.
# NOTE(review): aligning a length-nsteps mask with .ravel() assumes one scalar
# observation per step (e.g. sample_obs of shape (nsteps,) or (nsteps, 1)) --
# TODO confirm. The result is 1-D regardless of sample_obs's original shape.
# NOTE(review): index_update (jax.ops) is deprecated/removed in newer JAX; the
# replacement is x.at[idx].set(vals).
sample_obs_noise = index_update(sample_obs_noise.ravel(), samples_map,
                                replacement_values)
# Per-step plot colour: red for outlier-replaced steps, blue for untouched ones.
colors = ["tab:red" if samp else "tab:blue" for samp in samples_map]

# *** Perform filtering ****
# Extra arguments for UnscentedKalmanFilter.from_base (presumably the usual
# sigma-point parameters alpha, beta, kappa -- confirm against ds).
alpha, beta, kappa = 1, 0, 2
# Passed to the particle filter as its Vinit argument.
Vinit = jnp.eye(2)
ekf = ds.ExtendedKalmanFilter.from_base(model)
ukf = ds.UnscentedKalmanFilter.from_base(model, alpha, beta, kappa)
particle_filter = ds.BootstrapFiltering(fz_vec, fx_vmap, Q, Rt)

# Run all three filters on the clean observations.
print("Filtering data...")
ekf_mean_hist, ekf_Sigma_hist = ekf.filter(x0, sample_obs)
ukf_mean_hist, ukf_Sigma_hist = ukf.filter(x0, sample_obs)
# NOTE(review): the same key_pf is also passed to the outlier-data particle
# filter run below; presumably intentional (common random numbers), confirm.
pf_mean_hist = particle_filter.filter(key_pf,
                                      x0,
                                      sample_obs,
                                      nsamples=4_000,
                                      Vinit=Vinit)

# Run the filters again on the outlier-corrupted observations.
print("Filtering outlier data...")
# NOTE(review): these rebind ekf_Sigma_hist / ukf_Sigma_hist, overwriting the
# clean-data covariance histories computed above. If the clean ones are needed
# later, bind these to separate names (e.g. ekf_perturbed_Sigma_hist).
ekf_perturbed_mean_hist, ekf_Sigma_hist = ekf.filter(x0, sample_obs_noise)
ukf_perturbed_mean_hist, ukf_Sigma_hist = ukf.filter(x0, sample_obs_noise)
pf_perturbed_mean_hist = particle_filter.filter(key_pf,
                                                x0,