def particle_filter_body(samps_previous: cdict, iter_ind: int) -> Tuple[cdict, cdict]:
    """Advance the particle filter by one observation (one loop/scan step).

    Closes over names from the enclosing scope: ``int_rand_keys``,
    ``ess_threshold``, ``n``, ``_resample``, ``particle_filter``,
    ``ssm_scenario``, ``y`` and ``t``.

    Args:
        samps_previous: Particle state from the previous step. Its ``value``,
            ``log_weight``, ``ess`` and ``t`` fields are read.
        iter_ind: Index of the current step into ``int_rand_keys``, ``y``
            and ``t``.

    Returns:
        ``(samps_new, samps_new)``. The same object is returned twice,
        presumably as the (carry, output) pair of a ``jax.lax.scan`` body;
        confirm against the caller.
    """
    x_previous = samps_previous.value
    log_weight_previous = samps_previous.log_weight
    int_rand_key = int_rand_keys[iter_ind]
    # Derive independent keys for resampling and proposal: a JAX PRNG key
    # must not be consumed by more than one random operation.
    resample_key, propose_key = random.split(int_rand_key)

    # Resample only when the effective sample size drops below the threshold.
    ess_previous = samps_previous.ess
    resample_bool = ess_previous < (ess_threshold * n)
    x_res = cond(resample_bool,
                 _resample,
                 lambda tup: tup[0],
                 (x_previous, log_weight_previous, resample_key))
    # Resampled particles restart from equal (zero) log weight.
    log_weight_res = jnp.where(resample_bool, jnp.zeros(n), log_weight_previous)

    # One key per particle for the vectorised proposal.
    split_keys = random.split(propose_key, len(x_previous))
    x_new, log_weight_new = particle_filter.propose_and_intermediate_weight_vectorised(
        ssm_scenario, x_res, samps_previous.t, y[iter_ind], t[iter_ind], split_keys)
    # Accumulate the incremental weight onto the (possibly reset) previous weight.
    log_weight_new = log_weight_res + log_weight_new

    samps_new = samps_previous.copy()
    samps_new.value = x_new
    samps_new.log_weight = log_weight_new
    samps_new.y = y[iter_ind]
    samps_new.t = t[iter_ind]
    samps_new.ess = ess_log_weight(log_weight_new)
    return samps_new, samps_new
def adapt(self,
          previous_ensemble_state: cdict,
          previous_extra: cdict,
          new_ensemble_state: cdict,
          new_extra: cdict) -> Tuple[cdict, cdict]:
    """Refresh threshold, weights, ESS and acceptance diagnostics after a move.

    Args:
        previous_ensemble_state: Ensemble before the move. Its ``log_weight``
            decides which particles count as alive.
        previous_extra: Extra state before the move.
        new_ensemble_state: Ensemble after the move. Its ``threshold``,
            ``log_weight`` and ``ess`` fields are overwritten in place.
        new_extra: Extra state after the move. Its ``parameters.threshold``
            and ``alpha_mean`` fields are overwritten in place.

    Returns:
        The ``(ensemble_state, extra)`` pair produced by
        ``self.adapt_mcmc_params``.
    """
    n = new_ensemble_state.value.shape[0]

    # The threshold is written to the state before self.log_weight is called;
    # keep this order in case log_weight reads it (not visible here).
    next_threshold = self.next_threshold(new_ensemble_state, new_extra)
    new_ensemble_state.threshold = jnp.ones(n) * next_threshold
    new_extra.parameters.threshold = next_threshold

    new_ensemble_state.log_weight = self.log_weight(previous_ensemble_state, previous_extra,
                                                    new_ensemble_state, new_extra)
    # ESS is broadcast so every particle carries the same ensemble-level value.
    new_ensemble_state.ess = jnp.ones(n) * ess_log_weight(new_ensemble_state.log_weight)

    # Mean acceptance over particles whose previous log weight was above -inf.
    # jnp.where instead of multiplying by the mask: nan * 0 is nan, so a NaN
    # alpha on an excluded particle would otherwise poison the mean.
    alive_inds = previous_ensemble_state.log_weight > -jnp.inf
    new_extra.alpha_mean = jnp.where(alive_inds, new_ensemble_state.alpha, 0.).sum() / alive_inds.sum()

    new_ensemble_state, new_extra = self.adapt_mcmc_params(previous_ensemble_state, previous_extra,
                                                           new_ensemble_state, new_extra)
    return new_ensemble_state, new_extra
def initiate_particles(ssm_scenario: StateSpaceModel,
                       particle_filter: ParticleFilter,
                       n: int,
                       random_key: jnp.ndarray,
                       y: jnp.ndarray = None,
                       t: float = None) -> cdict:
    """Draw the initial particle population and wrap it with a time axis.

    Args:
        ssm_scenario: State-space model passed to the particle filter.
        particle_filter: Filter whose ``startup`` hook is run and whose
            vectorised initial sampler produces the particles.
        n: Number of particles, i.e. the number of PRNG keys split off.
        random_key: PRNG key, split into ``n`` per-particle keys.
        y: Optional first observation. Stored with a new leading axis, or as
            ``None`` when absent.
        t: Optional first time. Stored as a 1-d array; defaults to ``[0.]``.

    Returns:
        A ``cdict`` whose ``value`` and ``log_weight`` have a leading time
        axis of length 1, plus ``t``, ``y`` and a 1-element ``ess``.
    """
    particle_filter.startup(ssm_scenario)

    particle_keys = random.split(random_key, n)
    values, log_weights = particle_filter.initial_sample_and_weight_vectorised(
        ssm_scenario, y, t, particle_keys)

    # A one-dimensional sample gets a trailing axis so the state dimension
    # is always explicit.
    values = values[..., jnp.newaxis] if values.ndim == 1 else values

    times = jnp.zeros(1) if t is None else jnp.atleast_1d(t)
    observations = None if y is None else y[jnp.newaxis]

    return cdict(value=values[jnp.newaxis],
                 log_weight=log_weights[jnp.newaxis],
                 t=times,
                 y=observations,
                 ess=jnp.atleast_1d(ess_log_weight(log_weights)))
def propagate_particle_filter(ssm_scenario: StateSpaceModel,
                              particle_filter: ParticleFilter,
                              particles: cdict,
                              y_new: jnp.ndarray,
                              t_new: float,
                              random_key: jnp.ndarray,
                              ess_threshold: float = 0.5,
                              resample_full: bool = True) -> cdict:
    """Extend a time-stacked particle filter history by one observation.

    ``particles`` stores histories with time as the leading axis: ``[-1]``
    selects the latest step, and ``value.shape[1]`` is the particle count.

    Args:
        ssm_scenario: State-space model passed to the proposal.
        particle_filter: Filter providing the vectorised proposal and
            incremental weight.
        particles: Current particle history. It is not mutated.
        y_new: New observation, appended to ``y`` along axis 0.
        t_new: Time of the new observation.
        random_key: PRNG key for this step.
        ess_threshold: Resample when the latest ESS is below
            ``ess_threshold * n``.
        resample_full: Forwarded to ``resample_particles``.

    Returns:
        A new ``cdict`` with ``value``, ``log_weight``, ``y``, ``t`` and
        ``ess`` each extended by one time step.
    """
    n = particles.value.shape[1]
    # Independent keys for resampling and proposal; a JAX key must not be reused.
    resample_key, propose_key = random.split(random_key)

    ess_previous = particles.ess[-1]
    # p.copy() on the no-resample branch so the in-place field updates below
    # never touch the caller's object.
    out_particles = cond(ess_previous < ess_threshold * n,
                         lambda p: resample_particles(p, resample_key, resample_full),
                         lambda p: p.copy(),
                         particles)

    x_previous = out_particles.value[-1]
    log_weight_previous = out_particles.log_weight[-1]
    t_previous = out_particles.t[-1]

    split_keys = random.split(propose_key, n)
    x_new, log_weight_new = particle_filter.propose_and_intermediate_weight_vectorised(
        ssm_scenario, x_previous, t_previous, y_new, t_new, split_keys)
    log_weight_new = log_weight_previous + log_weight_new

    out_particles.value = jnp.append(out_particles.value, x_new[jnp.newaxis], axis=0)
    out_particles.log_weight = jnp.append(out_particles.log_weight, log_weight_new[jnp.newaxis], axis=0)
    # axis=0 keeps y as a (time, ...) stack; without an axis jnp.append
    # flattens both operands, destroying multi-dimensional observations.
    out_particles.y = jnp.append(out_particles.y, y_new[jnp.newaxis], axis=0)
    out_particles.t = jnp.append(out_particles.t, t_new)
    out_particles.ess = jnp.append(out_particles.ess, ess_log_weight(log_weight_new))
    return out_particles
def propagate_particle_smoother_pf(ssm_scenario: StateSpaceModel,
                                   particle_filter: ParticleFilter,
                                   particles: cdict,
                                   y_new: jnp.ndarray,
                                   t_new: float,
                                   random_key: jnp.ndarray,
                                   lag: int,
                                   maximum_rejections: int,
                                   init_bound_param: float,
                                   bound_inflation: float) -> cdict:
    """Propagate a fixed-lag particle smoother by one observation.

    The latest particles are resampled unless their weights are already
    (numerically) uniform. They are then propagated with the filter's
    proposal, and the new step is appended to the ``value``/``y``/``t``
    histories. Once the history holds more than ``lag`` time points, the
    trajectories are rejoined via ``fixed_lag_stitching`` and the log
    weights are reset to zero.

    Args:
        ssm_scenario: State-space model passed to proposal and stitching.
        particle_filter: Filter providing the vectorised proposal and
            incremental weight.
        particles: Current smoother state. It is not mutated.
        y_new: New observation, appended to ``y`` along axis 0.
        t_new: Time of the new observation.
        random_key: PRNG key for this step.
        lag: Stitching starts once ``len(t) - lag - 1 >= 0``.
        maximum_rejections: Forwarded to ``fixed_lag_stitching``.
        init_bound_param: Forwarded to ``fixed_lag_stitching``.
        bound_inflation: Forwarded to ``fixed_lag_stitching``.

    Returns:
        Updated ``cdict``. ``log_weight`` and ``ess`` describe only the latest
        step, and ``num_transition_evals`` gains one entry (0 when no
        stitching happened).
    """
    if not hasattr(particles, 'num_transition_evals'):
        # Work on a copy so the caller's object is not given a new field.
        particles = particles.copy()
        particles.num_transition_evals = jnp.array(0)

    n = particles.value.shape[1]
    # One independent key per random operation; JAX keys must not be reused.
    resample_key, propose_key, stitch_key = random.split(random_key, 3)

    # The smoother expects unweighted particles: resample unless the latest
    # weights already give an ESS of (almost) n.
    latest_log_weight = jnp.atleast_2d(particles.log_weight)[-1]
    out_particles = cond(ess_log_weight(latest_log_weight) < (n - 1e-3),
                         lambda p: resample_particles(p, resample_key, True),
                         lambda p: p.copy(),
                         particles)

    x_previous = out_particles.value[-1]
    t_previous = out_particles.t[-1]
    split_keys = random.split(propose_key, len(x_previous))
    # Previous weights are (near) uniform here, so the new log weights are
    # just the incremental proposal weights.
    x_new, out_particles.log_weight = particle_filter.propose_and_intermediate_weight_vectorised(
        ssm_scenario, x_previous, t_previous, y_new, t_new, split_keys)

    out_particles.value = jnp.append(out_particles.value, x_new[jnp.newaxis], axis=0)
    out_particles.y = jnp.append(out_particles.y, y_new[jnp.newaxis], axis=0)
    out_particles.t = jnp.append(out_particles.t, t_new)
    # ESS of the proposal weights, computed before any stitching reset below.
    out_particles.ess = ess_log_weight(out_particles.log_weight)

    # Python ints (static shapes), so a plain `if` is safe under tracing.
    len_t = len(out_particles.t)
    stitch_ind_min_1 = len_t - lag - 1
    stitch_ind = len_t - lag

    num_transition_evals = 0
    if stitch_ind_min_1 >= 0:
        out_particles.value, num_transition_evals = fixed_lag_stitching(
            ssm_scenario,
            out_particles.value[:(stitch_ind_min_1 + 1)],
            out_particles.t[stitch_ind_min_1],
            out_particles.value[stitch_ind_min_1:],
            out_particles.log_weight,
            out_particles.t[stitch_ind],
            stitch_key,
            maximum_rejections,
            init_bound_param,
            bound_inflation)
        # Stitched trajectories are treated as unweighted.
        out_particles.log_weight = jnp.zeros(n)

    out_particles.num_transition_evals = jnp.append(out_particles.num_transition_evals,
                                                    num_transition_evals)
    return out_particles