def backward_simulation_full(ssm_scenario: StateSpaceModel, marginal_particles: cdict, n_samps: int, random_key: jnp.ndarray) -> cdict:
    """Full backward-simulation smoother.

    Generates `n_samps` smoothing trajectories from stored particle-filter
    output by resampling backwards in time against the complete cloud of
    filtering particles at every step.

    Args:
        ssm_scenario: State-space model supplying the transition density.
        marginal_particles: Particle-filter output with fields
            value (T, n_pf, d), t (T,) and log_weight (T, n_pf).
        n_samps: Number of backward trajectories to draw.
        random_key: JAX PRNG key.

    Returns:
        cdict of unweighted trajectories, value of shape (T, n_samps, d);
        the log_weight field is deleted since samples are equally weighted.
    """
    vals = marginal_particles.value
    ts = marginal_particles.t
    log_ws = marginal_particles.log_weight

    num_steps, num_pf, _ = vals.shape
    step_keys = random.split(random_key, num_steps)

    # Initialise each trajectory by resampling the terminal particle cloud.
    terminal_inds = random.categorical(step_keys[-1], log_ws[-1], shape=(n_samps,))
    terminal_vals = vals[-1, terminal_inds]

    def back_sim_body(x_tplus1_all: jnp.ndarray, ind: int):
        # Resample time-ind particles conditional on the later states.
        x_t_all = full_resampling(ssm_scenario,
                                  vals[ind], ts[ind],
                                  x_tplus1_all, ts[ind + 1],
                                  log_ws[ind],
                                  step_keys[ind])
        return x_t_all, x_t_all

    # Sweep backwards from index T-2 down to 0.
    _, rev_particles = scan(back_sim_body, terminal_vals,
                            jnp.arange(num_steps - 2, -1, -1))

    out_samps = marginal_particles.copy()
    # rev_particles is stacked newest-first; flip and append the final slice.
    out_samps.value = jnp.vstack([rev_particles[::-1], terminal_vals[jnp.newaxis]])
    # Each backward step evaluates n_pf transitions per sample; none at T-1.
    out_samps.num_transition_evals = jnp.append(0, jnp.ones(num_steps - 1) * num_pf * n_samps)
    del out_samps.log_weight
    return out_samps
def proposal(self, abc_scenario: ABCScenario, reject_state: cdict, reject_extra: cdict) -> Tuple[cdict, cdict]: proposed_state = reject_state.copy() proposed_extra = reject_extra.copy() stepsize = reject_extra.parameters.stepsize proposed_extra.random_key, subkey1, subkey2, subkey3 = random.split( reject_extra.random_key, 4) proposed_state.value = reject_state.value + jnp.sqrt( stepsize) * random.normal(subkey1, (abc_scenario.dim, )) proposed_state.prior_potential = abc_scenario.prior_potential( proposed_state.value, subkey2) proposed_state.simulated_data = abc_scenario.likelihood_sample( proposed_state.value, subkey3) proposed_state.distance = abc_scenario.distance_function( proposed_state.simulated_data) return proposed_state, proposed_extra
def particle_filter_body(samps_previous: cdict, iter_ind: int) -> Tuple[cdict, cdict]:
    """Single particle-filter step: conditional resample, then propagate/weight.

    NOTE(review): this is a scan body relying on enclosing scope for
    `int_rand_keys`, `ess_threshold`, `n`, `particle_filter`,
    `ssm_scenario`, `y`, `t`, `_resample`, `cond` and `ess_log_weight`.
    """
    prev_vals = samps_previous.value
    prev_log_weight = samps_previous.log_weight
    step_key = int_rand_keys[iter_ind]

    # Resample only when the effective sample size drops below threshold.
    needs_resample = samps_previous.ess < (ess_threshold * n)
    resampled_vals = cond(needs_resample,
                          _resample,
                          lambda tup: tup[0],
                          (prev_vals, prev_log_weight, step_key))
    # After resampling, weights reset to uniform (zero in log space).
    resampled_log_weight = jnp.where(needs_resample, jnp.zeros(n), prev_log_weight)

    # Propagate each particle forward and accumulate incremental weights.
    prop_keys = random.split(step_key, len(prev_vals))
    new_vals, inc_log_weight = particle_filter.propose_and_intermediate_weight_vectorised(
        ssm_scenario, resampled_vals, samps_previous.t, y[iter_ind], t[iter_ind], prop_keys)
    new_log_weight = resampled_log_weight + inc_log_weight

    samps_new = samps_previous.copy()
    samps_new.value = new_vals
    samps_new.log_weight = new_log_weight
    samps_new.y = y[iter_ind]
    samps_new.t = t[iter_ind]
    samps_new.ess = ess_log_weight(new_log_weight)
    # Return the new state as both carry and per-step stacked output.
    return samps_new, samps_new
def proposal(self, abc_scenario: ABCScenario, reject_state: cdict, reject_extra: cdict) -> Tuple[cdict, cdict]: proposed_state = reject_state.copy() proposed_extra = reject_extra.copy() proposed_extra.random_key, subkey1, subkey2, subkey3 = random.split( reject_extra.random_key, 4) proposed_state.value = self.importance_proposal(abc_scenario, subkey1) proposed_state.prior_potential = abc_scenario.prior_potential( proposed_state.value, subkey2) proposed_state.simulated_data = abc_scenario.likelihood_sample( proposed_state.value, subkey3) proposed_state.distance = abc_scenario.distance_function( proposed_state.simulated_data) proposed_state.log_weight = self.log_weight(abc_scenario, proposed_state, proposed_extra) return proposed_state, proposed_extra
def leapfrog_step(init_state: cdict, i: int):
    """One leapfrog integration step: half-kick, drift, half-kick.

    NOTE(review): scan body — `stepsize`, `potential_and_grad` and
    `random_keys` come from the enclosing scope.
    """
    new_state = init_state.copy()
    # Half momentum update using the current gradient.
    momenta_half = init_state.momenta - stepsize / 2. * init_state.grad_potential
    # Full position update with the intermediate momenta.
    new_state.value = init_state.value + stepsize * momenta_half
    new_state.potential, new_state.grad_potential = potential_and_grad(
        new_state.value, random_keys[i])
    # Second half momentum update using the fresh gradient.
    new_state.momenta = momenta_half - stepsize / 2. * new_state.grad_potential
    # Stacked output keeps both intermediate and final momenta for the chain.
    next_sample_chain = new_state.copy()
    next_sample_chain.momenta = jnp.vstack([momenta_half, new_state.momenta])
    return new_state, next_sample_chain
def proposal(self, scenario: Scenario, reject_state: cdict, reject_extra: cdict) -> Tuple[cdict, cdict]: proposed_state = reject_state.copy() d = scenario.dim x = reject_state.value stepsize = reject_extra.parameters.stepsize reject_extra.random_key, subkey, scen_key = random.split( reject_extra.random_key, 3) proposed_state.value = x + jnp.sqrt(stepsize) * random.normal( subkey, (d, )) proposed_state.potential = scenario.potential(proposed_state.value, scen_key) return proposed_state, reject_extra
def resample_particles(particles: cdict, random_key: jnp.ndarray, resample_full: bool = True) -> cdict:
    """Multinomially resample a particle collection and reset its weights.

    Handles a single particle cloud (1-D log_weight) or a stacked history
    of clouds (2-D log_weight). In the stacked case, `resample_full=True`
    resamples whole trajectories by the terminal weights, whereas
    `resample_full=False` resamples only the latest time slice.

    Args:
        particles: Particle collection with value and log_weight fields.
        random_key: JAX PRNG key.
        resample_full: Whether to resample entire trajectories.

    Returns:
        Resampled particles with uniform (zero) log-weights at the
        resampled slice.
    """
    out_particles = particles.copy()
    n = particles.value.shape[-2]
    if out_particles.log_weight.ndim == 1:
        # Single time slice: resample values, weights become uniform.
        out_particles.value = _resample(
            (particles.value, particles.log_weight, random_key))
        out_particles.log_weight = jnp.zeros(n)
    elif resample_full:
        # Resample entire trajectories using the terminal weights.
        out_particles.value = _resample(
            (particles.value, particles.log_weight[-1], random_key))
        out_particles.log_weight = index_update(out_particles.log_weight, -1, jnp.zeros(n))
    else:
        # Resample only the most recent cloud, leaving history untouched.
        resampled_latest = _resample(
            (particles.value[-1], particles.log_weight[-1], random_key))
        out_particles.value = index_update(out_particles.value, -1, resampled_latest)
        out_particles.log_weight = index_update(out_particles.log_weight, -1, jnp.zeros(n))
    return out_particles
def backward_simulation_rejection(ssm_scenario: StateSpaceModel, marginal_particles: cdict, n_samps: int, random_key: jnp.ndarray, maximum_rejections: int, init_bound_param: float, bound_inflation: float) -> cdict:
    """Backward-simulation smoother using rejection resampling.

    As `backward_simulation_full` but each backward step draws ancestors by
    rejection sampling against a transition-density bound, falling back
    after `maximum_rejections` attempts; the realised number of transition
    evaluations per step is recorded.

    Args:
        ssm_scenario: State-space model supplying the transition density.
        marginal_particles: Particle-filter output with fields
            value (T, n_pf, d), t (T,) and log_weight (T, n_pf).
        n_samps: Number of backward trajectories to draw.
        random_key: JAX PRNG key.
        maximum_rejections: Cap on rejection attempts per step.
        init_bound_param: Initial transition-density bound.
        bound_inflation: Multiplier applied when the bound is exceeded.

    Returns:
        cdict of unweighted trajectories with per-step transition-eval
        counts; the log_weight field is deleted.
    """
    vals = marginal_particles.value
    ts = marginal_particles.t
    log_ws = marginal_particles.log_weight

    num_steps, num_pf, _ = vals.shape
    step_keys = random.split(random_key, num_steps)

    # Initialise each trajectory by resampling the terminal particle cloud.
    terminal_inds = random.categorical(step_keys[-1], log_ws[-1], shape=(n_samps,))
    terminal_vals = vals[-1, terminal_inds]

    def back_sim_body(x_tplus1_all: jnp.ndarray, ind: int):
        # Rejection-resample time-ind particles given the later states,
        # also returning the number of transition-density evaluations used.
        x_t_all, num_transition_evals = rejection_resampling(ssm_scenario,
                                                             vals[ind], ts[ind],
                                                             x_tplus1_all, ts[ind + 1],
                                                             log_ws[ind],
                                                             step_keys[ind],
                                                             maximum_rejections,
                                                             init_bound_param,
                                                             bound_inflation)
        return x_t_all, (x_t_all, num_transition_evals)

    # Sweep backwards from index T-2 down to 0.
    _, (rev_particles, rev_num_evals) = scan(back_sim_body, terminal_vals,
                                             jnp.arange(num_steps - 2, -1, -1),
                                             unroll=1)

    out_samps = marginal_particles.copy()
    # Scan output is stacked newest-first; flip and append the final slice.
    out_samps.value = jnp.vstack([rev_particles[::-1], terminal_vals[jnp.newaxis]])
    out_samps.num_transition_evals = jnp.append(0, rev_num_evals[::-1])
    del out_samps.log_weight
    return out_samps
def run_particle_filter_for_marginals(ssm_scenario: StateSpaceModel, particle_filter: ParticleFilter, y: jnp.ndarray, t: jnp.ndarray, random_key: jnp.ndarray, n: int = None, initial_sample: cdict = None, ess_threshold: float = 0.5) -> cdict:
    """Run a particle filter over observations, storing every marginal cloud.

    Args:
        ssm_scenario: State-space model.
        particle_filter: Proposal/weighting mechanism.
        y: Observations, shape (T,) or (T, d_obs).
        t: Observation times, shape (T,).
        random_key: JAX PRNG key.
        n: Number of particles; inferred from `initial_sample` if None.
        initial_sample: Optional pre-initiated particles. If None, the
            first observation is consumed to initiate the cloud.
        ess_threshold: Resample when ESS < ess_threshold * n.

    Returns:
        cdict stacking particle values, log-weights, observations, times
        and effective sample sizes over all time steps.
    """
    if y.ndim == 1:
        y = y[..., jnp.newaxis]

    if initial_sample is None:
        # Consume the first observation to initiate the particle cloud.
        random_key, init_key = random.split(random_key)
        initial_sample = initiate_particles(ssm_scenario, particle_filter, n,
                                            init_key, y[0], t[0])
        y = y[1:]
        t = t[1:]

    if n is None:
        n = initial_sample.value.shape[1]

    num_propagate_steps = len(y)
    int_rand_keys = random.split(random_key, num_propagate_steps)

    def particle_filter_body(samps_previous: cdict, iter_ind: int) -> Tuple[cdict, cdict]:
        # Scan body: conditional ESS-triggered resample, then propagation.
        prev_vals = samps_previous.value
        prev_log_weight = samps_previous.log_weight
        step_key = int_rand_keys[iter_ind]

        needs_resample = samps_previous.ess < (ess_threshold * n)
        resampled_vals = cond(needs_resample,
                              _resample,
                              lambda tup: tup[0],
                              (prev_vals, prev_log_weight, step_key))
        # After resampling, weights reset to uniform (zero in log space).
        resampled_log_weight = jnp.where(needs_resample, jnp.zeros(n), prev_log_weight)

        prop_keys = random.split(step_key, len(prev_vals))
        new_vals, inc_log_weight = particle_filter.propose_and_intermediate_weight_vectorised(
            ssm_scenario, resampled_vals, samps_previous.t, y[iter_ind], t[iter_ind], prop_keys)
        new_log_weight = resampled_log_weight + inc_log_weight

        samps_new = samps_previous.copy()
        samps_new.value = new_vals
        samps_new.log_weight = new_log_weight
        samps_new.y = y[iter_ind]
        samps_new.t = t[iter_ind]
        samps_new.ess = ess_log_weight(new_log_weight)
        return samps_new, samps_new

    # Carry the latest cloud through all remaining observations.
    _, after_init_samps = scan(particle_filter_body,
                               initial_sample[-1],
                               jnp.arange(num_propagate_steps))

    # Concatenate the initial cloud with the scanned history.
    out_samps = initial_sample.copy()
    out_samps.value = jnp.append(initial_sample.value, after_init_samps.value, axis=0)
    out_samps.log_weight = jnp.append(initial_sample.log_weight,
                                      after_init_samps.log_weight, axis=0)
    out_samps.y = jnp.append(initial_sample.y, after_init_samps.y, axis=0)
    out_samps.t = jnp.append(initial_sample.t, after_init_samps.t)
    out_samps.ess = jnp.append(initial_sample.ess, after_init_samps.ess)
    return out_samps