def propagate_particle_smoother_bs(ssm_scenario: StateSpaceModel,
                                   particle_filter: ParticleFilter,
                                   particles: cdict,
                                   y_new: jnp.ndarray,
                                   t_new: float,
                                   random_key: jnp.ndarray,
                                   lag: int,
                                   ess_threshold: float,
                                   maximum_rejections: int,
                                   init_bound_param: float,
                                   bound_inflation: float) -> cdict:
    """Advance a fixed-lag particle smoother by one observation via backward simulation.

    A marginal particle filter is kept on ``particles.marginal_filter``. On the
    first call it is initialised from ``particles`` itself. It is propagated to
    ``(y_new, t_new)`` with ``propagate_particle_filter``. Smoothed trajectories
    are then drawn with ``backward_simulation``:

    * once the history holds at least ``lag + 1`` time points, backward
      simulation runs only over the last ``lag + 1`` filter steps, and the
      result is joined onto the earlier smoothed trajectories with
      ``fixed_lag_stitching``;
    * otherwise backward simulation runs over the whole marginal filter.

    Args:
        ssm_scenario: State space model passed through to the helpers.
        particle_filter: Particle filter used to propagate the marginal filter.
        particles: Current smoother state. ``value.shape[1]`` is taken as the
            number of particles ``n``. Axis 0 of ``value`` is sliced by time
            index when stitching.
        y_new: New observation, appended along axis 0 of ``particles.y``.
        t_new: Time of the new observation.
        random_key: JAX PRNG key. It is split into independent sub-keys and
            never consumed directly.
        lag: Fixed lag for backward simulation and stitching.
        ess_threshold: Forwarded to ``propagate_particle_filter``.
        maximum_rejections: Forwarded to the rejection-based helpers.
        init_bound_param: Forwarded to the rejection-based helpers.
        bound_inflation: Forwarded to the rejection-based helpers.

    Returns:
        A copy of ``particles`` with the following fields updated:
        ``marginal_filter``, ``value``, ``y``, ``t``, ``ess`` (the last ESS of
        the marginal filter), ``log_weight`` (all zeros, i.e. unweighted) and
        ``num_transition_evals`` (this step's count appended). The input
        ``particles`` object is not modified.
    """
    n = particles.value.shape[1]

    # Work on a copy so the caller's ``particles`` is not mutated in place
    # (consistent with ``propagate_particle_smoother_pf``).
    out_particles = particles.copy()
    if not hasattr(out_particles, 'num_transition_evals'):
        out_particles.num_transition_evals = jnp.array(0)
    if not hasattr(out_particles, 'marginal_filter'):
        out_particles.marginal_filter = cdict(value=particles.value,
                                              log_weight=particles.log_weight,
                                              y=particles.y,
                                              t=particles.t,
                                              ess=particles.ess)

    # One independent key per random operation. split_keys[0] is unused. The
    # 4-way split is kept so the filter/backward-simulation keys are the same as
    # before. Stitching previously reused the raw ``random_key``, which was
    # already split.
    split_keys = random.split(random_key, 4)
    filter_key = split_keys[1]
    backward_key = split_keys[2]
    stitch_key = split_keys[3]

    # Propagate the marginal filter particles to the new observation.
    out_particles.marginal_filter = propagate_particle_filter(ssm_scenario,
                                                              particle_filter,
                                                              out_particles.marginal_filter,
                                                              y_new,
                                                              t_new,
                                                              filter_key,
                                                              ess_threshold,
                                                              False)
    out_particles.y = jnp.append(out_particles.y, y_new[jnp.newaxis], axis=0)
    out_particles.t = jnp.append(out_particles.t, t_new)
    out_particles.log_weight = jnp.zeros(n)
    out_particles.ess = out_particles.marginal_filter.ess[-1]

    len_t = len(out_particles.t)
    stitch_ind_min_1 = len_t - lag - 1
    stitch_ind = len_t - lag

    if stitch_ind_min_1 >= 0:
        # Enough history: backward-simulate over the last lag + 1 filter steps
        # only, then stitch onto the earlier smoothed trajectories.
        # ``out_particles.value`` still holds the previous step's smoothed values here.
        backward_sim = backward_simulation(ssm_scenario,
                                           out_particles.marginal_filter[stitch_ind_min_1:],
                                           backward_key,
                                           n,
                                           maximum_rejections,
                                           init_bound_param,
                                           bound_inflation)
        new_value, stitch_nte = fixed_lag_stitching(ssm_scenario,
                                                    out_particles.value[:(stitch_ind_min_1 + 1)],
                                                    out_particles.t[stitch_ind_min_1],
                                                    backward_sim.value,
                                                    jnp.zeros(n),
                                                    out_particles.t[stitch_ind],
                                                    stitch_key,
                                                    maximum_rejections,
                                                    init_bound_param,
                                                    bound_inflation)
        num_transition_evals = stitch_nte + backward_sim.num_transition_evals.sum()
    else:
        # Short history: backward-simulate over the whole marginal filter.
        backward_sim = backward_simulation(ssm_scenario,
                                           out_particles.marginal_filter,
                                           backward_key,
                                           n,
                                           maximum_rejections,
                                           init_bound_param,
                                           bound_inflation)
        new_value = backward_sim.value
        num_transition_evals = backward_sim.num_transition_evals.sum()

    out_particles.value = new_value
    out_particles.num_transition_evals = jnp.append(out_particles.num_transition_evals,
                                                    num_transition_evals)
    return out_particles
def propagate_particle_smoother_pf(ssm_scenario: StateSpaceModel,
                                   particle_filter: ParticleFilter,
                                   particles: cdict,
                                   y_new: jnp.ndarray,
                                   t_new: float,
                                   random_key: jnp.ndarray,
                                   lag: int,
                                   maximum_rejections: int,
                                   init_bound_param: float,
                                   bound_inflation: float) -> cdict:
    """Advance a fixed-lag particle smoother by one observation using filter proposals.

    Steps:

    1. If the latest weights are not uniform (ESS below ``n - 1e-3``), the
       particles are resampled first; otherwise they are copied.
    2. For each particle, a new state is proposed from its last state with
       ``particle_filter.propose_and_intermediate_weight_vectorised``. The new
       state is appended as a new time step.
    3. Once the history holds at least ``lag + 1`` time points, the trajectories
       are re-joined at the lag point with ``fixed_lag_stitching``, using the
       new log weights. The log weights are then reset to zeros.

    Args:
        ssm_scenario: State space model passed through to the helpers.
        particle_filter: Provides ``propose_and_intermediate_weight_vectorised``.
        particles: Current smoother state. ``value.shape[1]`` is taken as the
            number of particles ``n``. New states are appended along axis 0 of
            ``value``.
        y_new: New observation, appended along axis 0 of ``particles.y``.
        t_new: Time of the new observation.
        random_key: JAX PRNG key. It is split into independent sub-keys and
            never consumed directly.
        lag: Fixed lag at which trajectories are stitched.
        maximum_rejections: Forwarded to ``fixed_lag_stitching``.
        init_bound_param: Forwarded to ``fixed_lag_stitching``.
        bound_inflation: Forwarded to ``fixed_lag_stitching``.

    Returns:
        A new ``cdict`` with ``value``, ``y`` and ``t`` extended by one time
        step, plus ``log_weight``, ``ess`` and ``num_transition_evals`` updated.
        ``num_transition_evals`` has this step's count appended (0 when no
        stitching happened). ``ess`` is computed from the proposal log weights,
        before any stitching.
    """
    if not hasattr(particles, 'num_transition_evals'):
        # Copy before adding the counter so the caller's object is not mutated.
        particles = particles.copy()
        particles.num_transition_evals = jnp.array(0)

    n = particles.value.shape[1]

    # Independent keys for resampling, proposals and stitching. Previously
    # ``random_key`` was passed to both resampling and stitching, which made
    # those two random operations share the exact same randomness.
    resample_key, propose_key, stitch_key = random.split(random_key, 3)

    # Ensure the particles are unweighted before proposing.
    out_particles = cond(ess_log_weight(jnp.atleast_2d(particles.log_weight)[-1]) < (n - 1e-3),
                         lambda p: resample_particles(p, resample_key, True),
                         lambda p: p.copy(),
                         particles)

    x_previous = out_particles.value[-1]
    t_previous = out_particles.t[-1]
    proposal_keys = random.split(propose_key, len(x_previous))
    x_new, out_particles.log_weight = particle_filter.propose_and_intermediate_weight_vectorised(
        ssm_scenario, x_previous, t_previous, y_new, t_new, proposal_keys)

    out_particles.value = jnp.append(out_particles.value, x_new[jnp.newaxis], axis=0)
    out_particles.y = jnp.append(out_particles.y, y_new[jnp.newaxis], axis=0)
    out_particles.t = jnp.append(out_particles.t, t_new)
    out_particles.ess = ess_log_weight(out_particles.log_weight)

    len_t = len(out_particles.t)
    stitch_ind_min_1 = len_t - lag - 1
    stitch_ind = len_t - lag

    num_transition_evals = 0
    if stitch_ind_min_1 >= 0:
        out_particles.value, num_transition_evals = fixed_lag_stitching(ssm_scenario,
                                                                        out_particles.value[:(stitch_ind_min_1 + 1)],
                                                                        out_particles.t[stitch_ind_min_1],
                                                                        out_particles.value[stitch_ind_min_1:],
                                                                        out_particles.log_weight,
                                                                        out_particles.t[stitch_ind],
                                                                        stitch_key,
                                                                        maximum_rejections,
                                                                        init_bound_param,
                                                                        bound_inflation)
        # After stitching the trajectories are treated as unweighted.
        out_particles.log_weight = jnp.zeros(n)

    out_particles.num_transition_evals = jnp.append(out_particles.num_transition_evals,
                                                    num_transition_evals)
    return out_particles