Exemplo n.º 1
0
        return x + dt * jnp.array([jnp.sin(x[1]), jnp.cos(x[0])])

    def fx(x):
        """Observation function: the state is observed as-is (identity map)."""
        observation = x
        return observation

    dt = 0.4
    nsteps = 100
    # Initial state vector
    x0 = jnp.array([1.5, 0.0])
    # State-noise matrix (presumably a covariance; consumed by ds.NLDS as Qt)
    Qt = jnp.eye(2) * 0.001
    # Observation-noise matrix (presumably a covariance; consumed as Rt)
    Rt = jnp.eye(2) * 0.05

    # JAX PRNG keys must be consumed exactly once. Split the root key so that
    # data generation and the particle filter draw from independent streams
    # instead of reusing the same random bits.
    key = random.PRNGKey(314)
    key_sample, key_filter = random.split(key)
    model = ds.NLDS(lambda x: fz(x, dt), fx, Qt, Rt)
    sample_state, sample_obs = model.sample(key_sample, x0, nsteps)

    n_particles = 3_000
    # Vectorise the transition over the leading (particle) axis; dt is shared.
    fz_vec = jax.vmap(fz, in_axes=(0, None))
    particle_filter = ds.BootstrapFiltering(lambda x: fz_vec(x, dt), fx, Qt,
                                            Rt)
    pf_mean = particle_filter.filter(key_filter, x0, sample_obs, n_particles)

    plot_inference(sample_obs, pf_mean)
    pml.savefig("nlds2d_bootstrap.pdf")

    plot_samples(sample_state, sample_obs)
    pml.savefig("nlds2d_data.pdf")

    plt.show()
Exemplo n.º 2
0
# Model parameters forwarded to fz (g, dt) and used to build the noise model.
# NOTE(review): g is presumably gravitational acceleration — confirm.
g = 10
dt = 0.015
qc = 0.06
# State-noise matrix of the form qc * [[dt^3/3, dt^2/2], [dt^2/2, dt]],
# passed to ds.NLDS below as its state-noise term.
Q = jnp.array([[qc * dt**3 / 3, qc * dt**2 / 2], [qc * dt**2 / 2, qc * dt]])

# Versions of the observation / transition functions vmapped over the
# leading axis.
fx_vmap = jax.vmap(fx)
fz_vec = jax.vmap(lambda x: fz(x, g=g, dt=dt))

nsteps = 200
Rt = jnp.eye(1) * 0.02
# jnp.array already produces a floating-point array; an extra .astype(float)
# requests float64 and makes JAX emit a warning when x64 mode is disabled.
x0 = jnp.array([1.5, 0.0])
# Build the time grid from integer step indices so it always has exactly
# nsteps entries; a float-step arange can gain an extra element by rounding.
time = jnp.arange(nsteps) * dt

# Split the root key once and give each consumer its own subkey.
key = random.PRNGKey(3141)
key_samples, key_pf, key_noisy = random.split(key, 3)
model = ds.NLDS(lambda x: fz(x, g=g, dt=dt), fx, Q, Rt)
# Use the dedicated sampling subkey. The root key was already consumed by
# random.split, and JAX keys must not be reused.
sample_state, sample_obs = model.sample(key_samples, x0, nsteps)

# *** Perturbed data ***
# Each of the nsteps flattened observations is replaced, with probability
# 0.5, by a uniform draw from [-2, 2).
# NOTE(review): assumes sample_obs holds nsteps scalar observations (Rt is
# 1x1), so that its raveled length matches the mask — confirm.
key_noisy, key_values = random.split(key_noisy)
samples_map = random.bernoulli(key_noisy, 0.5, (nsteps, ))
n_replaced = int(samples_map.sum())
replacement_values = random.uniform(key_values, (n_replaced, ),
                                    minval=-2,
                                    maxval=2)
# Functional update (replaces the removed jax.ops.index_update); the boolean
# mask is concrete here, so it is valid as an index.
sample_obs_noise = sample_obs.ravel().at[samples_map].set(replacement_values)
# Convert the mask to Python bools once rather than testing each 0-d array.
colors = [
    "tab:red" if samp else "tab:blue" for samp in samples_map.tolist()
]

# *** Perform filtering ****
# NOTE(review): presumably sigma-point (UKF) tuning parameters used further
# down the script — confirm.
alpha, beta, kappa = 1, 0, 2