Exemplo n.º 1
0
def propagate_particle_smoother_bs(ssm_scenario: StateSpaceModel,
                                   particle_filter: ParticleFilter,
                                   particles: cdict,
                                   y_new: jnp.ndarray,
                                   t_new: float,
                                   random_key: jnp.ndarray,
                                   lag: int,
                                   ess_threshold: float,
                                   maximum_rejections: int,
                                   init_bound_param: float,
                                   bound_inflation: float) -> cdict:
    """Advance a fixed-lag backward-simulation particle smoother by one observation.

    A marginal particle filter, kept on ``marginal_filter``, is propagated to
    ``t_new`` with ``propagate_particle_filter``. Smoothed values are then
    redrawn with ``backward_simulation``, as follows:

    * While the history holds at most ``lag`` time points, the draw covers the
      whole marginal filter.
    * Otherwise only the trailing ``lag + 1`` entries are redrawn. They are
      joined onto the existing smoothed values with ``fixed_lag_stitching``.

    Args:
        ssm_scenario: model forwarded to the filter, backward-simulation and
            stitching routines.
        particle_filter: filter used to propagate the marginal filter.
        particles: current smoother state. The particle count ``n`` is read
            from axis 1 of ``particles.value`` (presumably laid out as
            ``(time, particle, ...)``; confirm against callers). The input is
            not modified; a copy is updated and returned.
        y_new: new observation, appended to ``y`` along a new leading axis.
        t_new: time of ``y_new``, appended to ``t``.
        random_key: JAX PRNG key. It is split into independent subkeys and is
            never consumed directly.
        lag: smoothing lag (see the summary above).
        ess_threshold: forwarded to ``propagate_particle_filter``.
        maximum_rejections: forwarded to ``backward_simulation`` and
            ``fixed_lag_stitching``.
        init_bound_param: forwarded to ``backward_simulation`` and
            ``fixed_lag_stitching``.
        bound_inflation: forwarded to ``backward_simulation`` and
            ``fixed_lag_stitching``.

    Returns:
        Updated cdict with these fields:

        * extended ``y`` and ``t``;
        * redrawn ``value``;
        * zero ``log_weight``;
        * ``ess`` taken from the last entry of the marginal filter's ``ess``;
        * ``num_transition_evals`` extended by this step's total.
    """
    n = particles.value.shape[1]

    # Work on a copy so the caller's cdict is left untouched.
    out_particles = particles.copy()

    if not hasattr(out_particles, 'num_transition_evals'):
        out_particles.num_transition_evals = jnp.array(0)

    # First call: seed the marginal filter from the smoother's own fields.
    if not hasattr(out_particles, 'marginal_filter'):
        out_particles.marginal_filter = cdict(value=out_particles.value,
                                              log_weight=out_particles.log_weight,
                                              y=out_particles.y,
                                              t=out_particles.t,
                                              ess=out_particles.ess)

    # One independent subkey per random stage. Index 0 is unused; the indices
    # are kept so the filter and backward-simulation keys are unchanged.
    # random_key itself is not reused, because that would correlate draws
    # with those made from its subkeys.
    split_keys = random.split(random_key, 4)
    filter_key, backsim_key, stitch_key = split_keys[1], split_keys[2], split_keys[3]

    # Propagate marginal filter particles
    out_particles.marginal_filter = propagate_particle_filter(ssm_scenario, particle_filter,
                                                              out_particles.marginal_filter,
                                                              y_new, t_new, filter_key,
                                                              ess_threshold, False)
    out_particles.y = jnp.append(out_particles.y, y_new[jnp.newaxis], axis=0)
    out_particles.t = jnp.append(out_particles.t, t_new)
    out_particles.log_weight = jnp.zeros(n)
    out_particles.ess = out_particles.marginal_filter.ess[-1]

    len_t = len(out_particles.t)
    stitch_ind_min_1 = len_t - lag - 1
    stitch_ind = len_t - lag
    stitching = stitch_ind_min_1 >= 0

    # With stitching, only the trailing lag + 1 marginal-filter entries are
    # backward-simulated; otherwise the whole history is.
    backsim_input = (out_particles.marginal_filter[stitch_ind_min_1:] if stitching
                     else out_particles.marginal_filter)
    backward_sim = backward_simulation(ssm_scenario,
                                       backsim_input,
                                       backsim_key,
                                       n,
                                       maximum_rejections,
                                       init_bound_param,
                                       bound_inflation)
    num_transition_evals = backward_sim.num_transition_evals.sum()

    if stitching:
        # Join the fresh trailing draws onto the previously smoothed values up
        # to (and including) index stitch_ind_min_1. out_particles.value still
        # holds the previous step's values at this point.
        new_value, stitch_nte = fixed_lag_stitching(ssm_scenario,
                                                    out_particles.value[:(stitch_ind_min_1 + 1)],
                                                    out_particles.t[stitch_ind_min_1],
                                                    backward_sim.value,
                                                    jnp.zeros(n),
                                                    out_particles.t[stitch_ind],
                                                    stitch_key,
                                                    maximum_rejections,
                                                    init_bound_param,
                                                    bound_inflation)
        num_transition_evals = stitch_nte + num_transition_evals
    else:
        new_value = backward_sim.value

    out_particles.value = new_value
    out_particles.num_transition_evals = jnp.append(out_particles.num_transition_evals,
                                                    num_transition_evals)
    return out_particles
Exemplo n.º 2
0
def propagate_particle_smoother_pf(ssm_scenario: StateSpaceModel,
                                   particle_filter: ParticleFilter,
                                   particles: cdict,
                                   y_new: jnp.ndarray,
                                   t_new: float,
                                   random_key: jnp.ndarray,
                                   lag: int,
                                   maximum_rejections: int,
                                   init_bound_param: float,
                                   bound_inflation: float) -> cdict:
    """Advance a fixed-lag particle-filter smoother by one observation.

    The step runs in this order:

    1. Resample the particles unless their latest weights already have an
       ESS within 1e-3 of ``n``.
    2. Propose new states with
       ``particle_filter.propose_and_intermediate_weight_vectorised`` and
       append them, ``y_new`` and ``t_new`` to the trajectories.
    3. Once the history holds more than ``lag`` time points, apply
       ``fixed_lag_stitching`` to the trailing part using the new weights,
       then reset the weights to zero.

    Args:
        ssm_scenario: model forwarded to the proposal and stitching routines.
        particle_filter: provides ``propose_and_intermediate_weight_vectorised``.
        particles: current smoother state. The particle count ``n`` is read
            from axis 1 of ``particles.value`` (presumably laid out as
            ``(time, particle, ...)``; confirm against callers).
            NOTE(review): ``num_transition_evals`` is initialised on this
            input object itself if it is missing.
        y_new: new observation, appended to ``y`` along a new leading axis.
        t_new: time of ``y_new``, appended to ``t``.
        random_key: JAX PRNG key. It is split into independent subkeys for
            resampling, proposals and stitching, and is never consumed directly.
        lag: smoothing lag (see step 3 above).
        maximum_rejections: forwarded to ``fixed_lag_stitching``.
        init_bound_param: forwarded to ``fixed_lag_stitching``.
        bound_inflation: forwarded to ``fixed_lag_stitching``.

    Returns:
        Updated cdict with these fields:

        * extended ``value``, ``y`` and ``t``;
        * ``ess`` computed from the proposal weights, before any reset;
        * ``log_weight`` equal to the proposal weights, or zeros if
          stitching ran;
        * ``num_transition_evals`` extended by this step's count, or by 0
          when no stitching ran.
    """
    if not hasattr(particles, 'num_transition_evals'):
        particles.num_transition_evals = jnp.array(0)

    n = particles.value.shape[1]

    # Independent subkeys per random stage. Using random_key for resampling
    # and also as the parent of further keys would correlate the draws.
    resample_key, propose_key, stitch_key = random.split(random_key, 3)

    # Resample unless the ESS of the latest log-weights is within 1e-3 of n.
    out_particles = cond(ess_log_weight(jnp.atleast_2d(particles.log_weight)[-1]) < (n - 1e-3),
                         lambda p: resample_particles(p, resample_key, True),
                         lambda p: p.copy(),
                         particles)
    out_particles.log_weight = jnp.zeros(n)

    x_previous = out_particles.value[-1]
    t_previous = out_particles.t[-1]

    # One proposal key per particle.
    proposal_keys = random.split(propose_key, n)

    x_new, out_particles.log_weight = particle_filter.propose_and_intermediate_weight_vectorised(ssm_scenario,
                                                                                                 x_previous, t_previous,
                                                                                                 y_new, t_new,
                                                                                                 proposal_keys)

    out_particles.value = jnp.append(out_particles.value, x_new[jnp.newaxis], axis=0)
    out_particles.y = jnp.append(out_particles.y, y_new[jnp.newaxis], axis=0)
    out_particles.t = jnp.append(out_particles.t, t_new)
    out_particles.ess = ess_log_weight(out_particles.log_weight)

    len_t = len(out_particles.t)
    stitch_ind_min_1 = len_t - lag - 1
    stitch_ind = len_t - lag

    num_transition_evals = 0
    if stitch_ind_min_1 >= 0:
        out_particles.value, num_transition_evals = fixed_lag_stitching(ssm_scenario,
                                                                        out_particles.value[:(stitch_ind_min_1 + 1)],
                                                                        out_particles.t[stitch_ind_min_1],
                                                                        out_particles.value[stitch_ind_min_1:],
                                                                        out_particles.log_weight,
                                                                        out_particles.t[stitch_ind],
                                                                        stitch_key,
                                                                        maximum_rejections,
                                                                        init_bound_param,
                                                                        bound_inflation)
        # Weights are reset to zero once stitching has been applied.
        out_particles.log_weight = jnp.zeros(n)
    out_particles.num_transition_evals = jnp.append(out_particles.num_transition_evals, num_transition_evals)
    return out_particles