Example #1
0
    def particle_filter_body(samps_previous: cdict,
                             iter_ind: int) -> Tuple[cdict, cdict]:
        """Advance the particle filter by one observation.

        Written with a ``(carry, x) -> (carry, output)`` signature, presumably as a
        ``jax.lax.scan`` body (TODO confirm against the enclosing function). Reads
        ``int_rand_keys``, ``ess_threshold``, ``n``, ``_resample``,
        ``particle_filter``, ``ssm_scenario``, ``y`` and ``t`` from the enclosing
        scope.

        Args:
            samps_previous: particles at the previous step; its ``value``,
                ``log_weight``, ``ess`` and ``t`` fields are read.
            iter_ind: index of the step; selects the random key, observation
                ``y[iter_ind]`` and time ``t[iter_ind]``.

        Returns:
            The updated particles twice (as both carry and output).
        """
        x_previous = samps_previous.value
        log_weight_previous = samps_previous.log_weight
        int_rand_key = int_rand_keys[iter_ind]
        # Independent keys for resampling and proposals. Reusing one key for both
        # (which JAX explicitly warns against) makes the two sets of draws correlated.
        resample_key, propose_key = random.split(int_rand_key)

        # Resample only when the effective sample size has degenerated.
        ess_previous = samps_previous.ess
        resample_bool = ess_previous < (ess_threshold * n)
        x_res = cond(resample_bool, _resample, lambda tup: tup[0],
                     (x_previous, log_weight_previous, resample_key))
        # Resampled particles are equally weighted, so their log weights reset to 0.
        log_weight_res = jnp.where(resample_bool, jnp.zeros(n),
                                   log_weight_previous)

        split_keys = random.split(propose_key, len(x_previous))

        x_new, log_weight_incr = particle_filter.propose_and_intermediate_weight_vectorised(
            ssm_scenario, x_res, samps_previous.t, y[iter_ind], t[iter_ind],
            split_keys)

        log_weight_new = log_weight_res + log_weight_incr

        samps_new = samps_previous.copy()
        samps_new.value = x_new
        samps_new.log_weight = log_weight_new
        samps_new.y = y[iter_ind]
        samps_new.t = t[iter_ind]
        samps_new.ess = ess_log_weight(log_weight_new)
        return samps_new, samps_new
Example #2
0
    def adapt(self,
              previous_ensemble_state: cdict,
              previous_extra: cdict,
              new_ensemble_state: cdict,
              new_extra: cdict) -> Tuple[cdict, cdict]:
        """Refresh the new ensemble's threshold, weights, ESS and alpha summary.

        Steps run in this order, because later steps may read what earlier ones wrote:
          1. ``self.next_threshold`` picks the next threshold. It is broadcast to
             every particle in ``new_ensemble_state.threshold`` and stored in
             ``new_extra.parameters.threshold``.
          2. ``self.log_weight`` computes new log weights. Their ESS is
             broadcast to every particle in ``new_ensemble_state.ess``.
          3. ``new_extra.alpha_mean`` is the mean of ``new_ensemble_state.alpha``
             over particles whose previous log weight is greater than -inf.
          4. ``self.adapt_mcmc_params`` makes the final update to both containers.

        Returns:
            The ``(new_ensemble_state, new_extra)`` pair produced by
            ``self.adapt_mcmc_params``.
        """
        num_particles = new_ensemble_state.value.shape[0]

        threshold = self.next_threshold(new_ensemble_state, new_extra)
        new_ensemble_state.threshold = jnp.ones(num_particles) * threshold
        new_extra.parameters.threshold = threshold

        new_ensemble_state.log_weight = self.log_weight(
            previous_ensemble_state, previous_extra,
            new_ensemble_state, new_extra)
        per_particle_ones = jnp.ones(num_particles)
        new_ensemble_state.ess = per_particle_ones * ess_log_weight(new_ensemble_state.log_weight)

        # Average alpha (presumably MCMC acceptance probabilities) only over
        # particles that were alive before this step.
        # NOTE(review): yields NaN if every previous log weight is -inf; confirm intended.
        is_alive = previous_ensemble_state.log_weight > -jnp.inf
        alive_alpha_total = (new_ensemble_state.alpha * is_alive).sum()
        new_extra.alpha_mean = alive_alpha_total / is_alive.sum()

        adapted_state, adapted_extra = self.adapt_mcmc_params(
            previous_ensemble_state, previous_extra,
            new_ensemble_state, new_extra)
        return adapted_state, adapted_extra
Example #3
0
def initiate_particles(ssm_scenario: StateSpaceModel,
                       particle_filter: ParticleFilter,
                       n: int,
                       random_key: jnp.ndarray,
                       y: jnp.ndarray = None,
                       t: float = None) -> cdict:
    """Draw and weight ``n`` initial particles.

    Calls ``particle_filter.startup`` first. Each particle then gets its own
    subkey for ``particle_filter.initial_sample_and_weight_vectorised``.

    Args:
        ssm_scenario: state-space model passed through to the particle filter.
        particle_filter: supplies the startup hook and initial sampler.
        n: number of particles.
        random_key: JAX PRNG key, split into ``n`` subkeys.
        y: optional initial observation.
        t: optional initial time; defaults to 0.

    Returns:
        cdict whose ``value``, ``log_weight``, ``t``, ``ess`` (and ``y`` when
        given) carry a new leading time axis of length 1. ``y`` is ``None``
        when no observation is supplied.
    """
    particle_filter.startup(ssm_scenario)

    particle_keys = random.split(random_key, n)
    values, log_weights = particle_filter.initial_sample_and_weight_vectorised(
        ssm_scenario, y, t, particle_keys)

    # A 1-D sample (presumably one scalar state per particle) gets a trailing
    # state axis of size 1.
    values = values[..., jnp.newaxis] if values.ndim == 1 else values

    initial_times = jnp.zeros(1) if t is None else jnp.atleast_1d(t)
    initial_obs = None if y is None else y[jnp.newaxis]

    return cdict(value=values[jnp.newaxis],
                 log_weight=log_weights[jnp.newaxis],
                 t=initial_times,
                 y=initial_obs,
                 ess=jnp.atleast_1d(ess_log_weight(log_weights)))
Example #4
0
def propagate_particle_filter(ssm_scenario: StateSpaceModel,
                              particle_filter: ParticleFilter,
                              particles: cdict,
                              y_new: jnp.ndarray,
                              t_new: float,
                              random_key: jnp.ndarray,
                              ess_threshold: float = 0.5,
                              resample_full: bool = True) -> cdict:
    """Advance a particle filter by one observation, appending to its history.

    If the latest effective sample size is below ``ess_threshold * n``, the
    particles are first resampled via ``resample_particles``. New particles are
    then proposed from the latest step, and their log weights are
    ``previous log weight + intermediate weight``.

    Args:
        ssm_scenario: state-space model passed through to the particle filter.
        particle_filter: supplies the vectorised proposal and weight function.
        particles: filter history; ``value`` and ``log_weight`` are indexed
            ``[time, particle, ...]`` and ``ess``, ``t``, ``y`` are indexed by time.
        y_new: new observation.
        t_new: time of the new observation.
        random_key: JAX PRNG key.
        ess_threshold: resampling threshold as a fraction of ``n``.
        resample_full: forwarded to ``resample_particles``.

    Returns:
        A new cdict with one extra step appended along axis 0 of ``value``,
        ``log_weight``, ``y``, ``t`` and ``ess``.
    """
    n = particles.value.shape[1]

    # Independent keys for resampling and proposals. Reusing one key for both
    # (which JAX explicitly warns against) makes the two sets of draws correlated.
    resample_key, propose_key = random.split(random_key)

    ess_previous = particles.ess[-1]
    # Both branches return a distinct object, so the in-place field updates below
    # never touch the caller's ``particles`` (same pattern as the smoother).
    out_particles = cond(
        ess_previous < ess_threshold * n,
        lambda p: resample_particles(p, resample_key, resample_full),
        lambda p: p.copy(),
        particles)

    x_previous = out_particles.value[-1]
    log_weight_previous = out_particles.log_weight[-1]
    t_previous = out_particles.t[-1]

    split_keys = random.split(propose_key, n)

    x_new, log_weight_incr = particle_filter.propose_and_intermediate_weight_vectorised(
        ssm_scenario, x_previous, t_previous, y_new, t_new, split_keys)

    log_weight_new = log_weight_previous + log_weight_incr

    out_particles.value = jnp.append(out_particles.value,
                                     x_new[jnp.newaxis],
                                     axis=0)
    out_particles.log_weight = jnp.append(out_particles.log_weight,
                                          log_weight_new[jnp.newaxis],
                                          axis=0)
    # axis=0 keeps observations stacked by time step. Without it, jnp.append
    # flattens multi-dimensional observation histories into one long vector.
    out_particles.y = jnp.append(out_particles.y, y_new[jnp.newaxis], axis=0)
    out_particles.t = jnp.append(out_particles.t, t_new)
    out_particles.ess = jnp.append(out_particles.ess,
                                   ess_log_weight(log_weight_new))
    return out_particles
def propagate_particle_smoother_pf(ssm_scenario: StateSpaceModel,
                                   particle_filter: ParticleFilter,
                                   particles: cdict,
                                   y_new: jnp.ndarray,
                                   t_new: float,
                                   random_key: jnp.ndarray,
                                   lag: int,
                                   maximum_rejections: int,
                                   init_bound_param: float,
                                   bound_inflation: float) -> cdict:
    """Advance a fixed-lag particle smoother by one observation.

    The latest particles are resampled unless they are already equally
    weighted. New particles are proposed and appended to the history. Once at
    least ``lag + 1`` time points are stored, ``fixed_lag_stitching`` rewrites
    the particle history and the log weights are reset to zero.

    Args:
        ssm_scenario: state-space model passed through to the filter and stitcher.
        particle_filter: supplies the vectorised proposal and weight function.
        particles: smoother history; ``value`` is indexed ``[time, particle, ...]``.
        y_new: new observation.
        t_new: time of the new observation.
        random_key: JAX PRNG key.
        lag: fixed lag (in time steps) at which stitching is applied.
        maximum_rejections: forwarded to ``fixed_lag_stitching``.
        init_bound_param: forwarded to ``fixed_lag_stitching``.
        bound_inflation: forwarded to ``fixed_lag_stitching``.

    Returns:
        A new cdict. ``value``, ``y``, ``t`` and ``num_transition_evals`` are
        extended by one step. ``log_weight`` holds the latest weights, and
        ``ess`` is the ESS of the proposal weights, computed before any
        stitching reset.
    """
    # Independent keys for resampling, proposals and stitching. Reusing one key
    # (which JAX explicitly warns against) correlates the three sets of draws.
    resample_key, propose_key, stitch_key = random.split(random_key, 3)

    if not hasattr(particles, 'num_transition_evals'):
        # Add the counter on a copy so the caller's object is not mutated.
        particles = particles.copy()
        particles.num_transition_evals = jnp.array(0)

    n = particles.value.shape[1]

    # New weights below replace (not add to) the old ones, so the latest
    # particles must be unweighted: resample unless ESS is already ~n.
    out_particles = cond(ess_log_weight(jnp.atleast_2d(particles.log_weight)[-1]) < (n - 1e-3),
                         lambda p: resample_particles(p, resample_key, True),
                         lambda p: p.copy(),
                         particles)

    x_previous = out_particles.value[-1]
    t_previous = out_particles.t[-1]

    split_keys = random.split(propose_key, len(x_previous))

    x_new, log_weight_new = particle_filter.propose_and_intermediate_weight_vectorised(
        ssm_scenario, x_previous, t_previous, y_new, t_new, split_keys)

    out_particles.value = jnp.append(out_particles.value, x_new[jnp.newaxis], axis=0)
    out_particles.y = jnp.append(out_particles.y, y_new[jnp.newaxis], axis=0)
    out_particles.t = jnp.append(out_particles.t, t_new)
    out_particles.log_weight = log_weight_new
    out_particles.ess = ess_log_weight(log_weight_new)

    len_t = len(out_particles.t)
    stitch_ind_min_1 = len_t - lag - 1
    stitch_ind = len_t - lag

    num_transition_evals = 0
    # len_t is a static Python int (array length), so plain `if` is fine under jit.
    if stitch_ind_min_1 >= 0:
        out_particles.value, num_transition_evals = fixed_lag_stitching(
            ssm_scenario,
            out_particles.value[:(stitch_ind_min_1 + 1)],
            out_particles.t[stitch_ind_min_1],
            out_particles.value[stitch_ind_min_1:],
            out_particles.log_weight,
            out_particles.t[stitch_ind],
            stitch_key,
            maximum_rejections,
            init_bound_param,
            bound_inflation)
        # After stitching the particles are treated as equally weighted.
        out_particles.log_weight = jnp.zeros(n)
    out_particles.num_transition_evals = jnp.append(out_particles.num_transition_evals,
                                                    num_transition_evals)
    return out_particles