def update(self,
           scenario: Scenario,
           ensemble_state: cdict,
           extra: cdict) -> Tuple[cdict, cdict]:
    extra.iter = extra.iter + 1
    n = ensemble_state.value.shape[0]

    resample_bool = self.resample_criterion(ensemble_state, extra)

    random_keys_all = random.split(extra.random_key, n + 2)
    extra.random_key = random_keys_all[-1]

    # Resample the ensemble only if the resampling criterion fires
    resampled_ensemble_state = cond(resample_bool,
                                    lambda state: self.resample(state, random_keys_all[-2]),
                                    lambda state: state,
                                    ensemble_state)

    # Propagate each particle forward independently with its own random key
    advanced_state = vmap(self.forward_proposal, in_axes=(None, 0, None, 0))(scenario,
                                                                             resampled_ensemble_state,
                                                                             extra,
                                                                             random_keys_all[:n])

    advanced_state, advanced_extra = self.adapt(resampled_ensemble_state, extra,
                                                advanced_state, extra)
    return advanced_state, advanced_extra
def startup(self,
            abc_scenario: ABCScenario,
            n: int,
            initial_state: cdict,
            initial_extra: cdict,
            startup_correction: bool = True,
            **kwargs) -> Tuple[cdict, cdict]:
    initial_state, initial_extra = super().startup(abc_scenario, n, initial_state, initial_extra, **kwargs)

    if not hasattr(initial_state, 'prior_potential') and is_implemented(abc_scenario.prior_potential):
        initial_extra.random_key, subkey = random.split(initial_extra.random_key)
        initial_state.prior_potential = abc_scenario.prior_potential(initial_state.value, subkey)

    if not hasattr(initial_state, 'simulated_data'):
        initial_extra.random_key, subkey = random.split(initial_extra.random_key)
        initial_state.simulated_data = abc_scenario.likelihood_sample(initial_state.value, subkey)

    if not hasattr(initial_state, 'distance'):
        initial_state.distance = abc_scenario.distance_function(initial_state.simulated_data)

    return initial_state, initial_extra
def startup(self,
            scenario: Scenario,
            n: int,
            initial_state: Union[None, cdict],
            initial_extra: cdict,
            **kwargs) -> Tuple[cdict, cdict]:
    for key, value in kwargs.items():
        if hasattr(self, key):
            setattr(self, key, value)
        if hasattr(self, 'parameters') and hasattr(self.parameters, key):
            setattr(self.parameters, key, value)

    if not hasattr(self, 'max_iter') \
            or not (isinstance(self.max_iter, int)
                    or (isinstance(self.max_iter, jnp.ndarray) and self.max_iter.dtype == 'int32')):
        raise AttributeError(self.__repr__() + ' max_iter must be int')

    if not hasattr(initial_extra, 'iter'):
        initial_extra.iter = 0

    if hasattr(self, 'parameters'):
        if not hasattr(initial_extra, 'parameters'):
            initial_extra.parameters = cdict()
        for key, value in self.parameters.__dict__.items():
            if not hasattr(initial_extra.parameters, key) or getattr(initial_extra.parameters, key) is None:
                setattr(initial_extra.parameters, key, value)

    return initial_state, initial_extra
def clean_chain(self, scenario: Scenario, chain_ensemble_state: cdict) -> cdict:
    chain_ensemble_state.temperature = chain_ensemble_state.temperature[:, 0]
    scenario.temperature = float(chain_ensemble_state.temperature[-1])
    chain_ensemble_state.ess = chain_ensemble_state.ess[:, 0]
    return chain_ensemble_state
def startup(self,
            scenario: Scenario,
            n: int,
            initial_state: cdict,
            initial_extra: cdict,
            **kwargs) -> Tuple[cdict, cdict]:
    initial_state, initial_extra = super().startup(scenario, n, initial_state, initial_extra, **kwargs)

    if self.parameters.ensemble_batchsize is None:
        self.parameters.ensemble_batchsize = n
        initial_extra.parameters.ensemble_batchsize = n

    if self.parameters.ensemble_batchsize == n:
        self.get_batch_inds = lambda _: jnp.repeat(jnp.arange(n)[None], n, axis=0)
    else:
        self.get_batch_inds = lambda rk: random.choice(rk, n,
                                                       shape=(n, self.parameters.ensemble_batchsize,))

    del initial_extra.parameters.stepsize

    random_keys = random.split(initial_extra.random_key, n + 1)
    initial_extra.random_key = random_keys[-1]

    initial_state.potential, initial_state.grad_potential = vmap(scenario.potential_and_grad)(initial_state.value,
                                                                                              random_keys[:n])
    initial_state, initial_extra = self.adapt(initial_state, initial_extra)

    self.opt_init, self.opt_update, self.get_params = self.optimiser(step_size=self.parameters.stepsize,
                                                                     **initial_extra.parameters.optim_params)
    initial_extra.opt_state = self.opt_init(initial_state.value)
    return initial_state, initial_extra
def startup(self,
            scenario: Scenario,
            n: int,
            initial_state: cdict,
            initial_extra: cdict,
            **kwargs) -> Tuple[cdict, cdict]:
    initial_state, initial_extra = super().startup(scenario, n, initial_state, initial_extra, **kwargs)
    if not hasattr(initial_state, 'log_weight'):
        initial_state.log_weight = jnp.zeros(n)
    if not hasattr(initial_state, 'ess'):
        initial_state.ess = jnp.zeros(n) + n
    return initial_state, initial_extra
def startup(self,
            scenario: Scenario,
            n: int,
            initial_state: cdict,
            initial_extra: cdict,
            **kwargs) -> Tuple[cdict, cdict]:
    initial_state, initial_extra = super().startup(scenario, n, initial_state, initial_extra, **kwargs)
    initial_extra.random_key, scen_key = random.split(initial_extra.random_key)
    initial_state.potential, initial_state.grad_potential = scenario.potential_and_grad(initial_state.value,
                                                                                        scen_key)
    return initial_state, initial_extra
def adapt(self,
          previous_ensemble_state: cdict,
          previous_extra: cdict,
          new_ensemble_state: cdict,
          new_extra: cdict) -> Tuple[cdict, cdict]:
    n = new_ensemble_state.value.shape[0]

    next_temperature = self.next_temperature(new_ensemble_state, new_extra)
    new_ensemble_state.temperature = jnp.ones(n) * next_temperature

    new_ensemble_state.log_weight = previous_ensemble_state.log_weight \
                                    + self.log_weight(previous_ensemble_state, previous_extra,
                                                      new_ensemble_state, new_extra)

    new_ensemble_state.ess = jnp.ones(n) * jnp.exp(log_ess_log_weight(new_ensemble_state.log_weight))

    return new_ensemble_state, new_extra
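# Illustrative sketch only (not part of the library): log_ess_log_weight above (and
# ess_log_weight used elsewhere) are assumed to implement the standard effective sample
# size identity ESS = (sum_i w_i)^2 / sum_i w_i^2, evaluated stably in log space from the
# log weights. The names and signatures below are assumptions for exposition.
from jax import numpy as jnp
from jax.scipy.special import logsumexp


def log_ess_log_weight_sketch(log_weight: jnp.ndarray) -> jnp.ndarray:
    # log ESS = 2 * logsumexp(lw) - logsumexp(2 * lw)
    return 2 * logsumexp(log_weight) - logsumexp(2 * log_weight)


def ess_log_weight_sketch(log_weight: jnp.ndarray) -> jnp.ndarray:
    return jnp.exp(log_ess_log_weight_sketch(log_weight))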
def startup(self,
            scenario: Scenario,
            n: int,
            initial_state: cdict,
            initial_extra: cdict,
            **kwargs) -> Tuple[cdict, cdict]:
    initial_state, initial_extra = super().startup(scenario, n, initial_state, initial_extra, **kwargs)
    initial_extra.random_key, scen_key = random.split(initial_extra.random_key)
    initial_state.potential, initial_state.grad_potential = scenario.potential_and_grad(initial_state.value,
                                                                                        scen_key)
    if not hasattr(initial_state, 'momenta') or initial_state.momenta.shape[-1] != scenario.dim:
        initial_state.momenta = jnp.zeros(scenario.dim)
    return initial_state, initial_extra
def always(self,
           scenario: Scenario,
           reject_state: cdict,
           reject_extra: cdict) -> Tuple[cdict, cdict]:
    d = scenario.dim
    stepsize = reject_extra.parameters.stepsize
    friction = reject_extra.parameters.friction

    reject_state.momenta = reject_state.momenta * -1

    # Update p - exactly according to solution of OU process
    # Accepted even if leapfrog step is rejected
    reject_extra.random_key, subkey = random.split(reject_extra.random_key)
    reject_state.momenta = reject_state.momenta * jnp.exp(- friction * stepsize) \
                           + jnp.sqrt(1 - jnp.exp(- 2 * friction * stepsize)) * random.normal(subkey, (d,))

    return reject_state, reject_extra
def forward_proposal(self,
                     scenario: Scenario,
                     state: cdict,
                     extra: cdict,
                     random_key: jnp.ndarray) -> cdict:
    def mcmc_kernel(previous_carry: Tuple[cdict, cdict],
                    _: None) -> Tuple[Tuple[cdict, cdict], Tuple[cdict, cdict]]:
        new_carry = self.mcmc_sampler.update(scenario, *previous_carry)
        return new_carry, new_carry

    extra.random_key = random_key

    start_state, start_extra = self.mcmc_sampler.startup(scenario,
                                                         extra.parameters.mcmc_steps,
                                                         state,
                                                         extra)

    final_carry, chain = scan(mcmc_kernel,
                              (start_state, start_extra),
                              None,
                              length=self.parameters.mcmc_steps)

    advanced_state, advanced_extra = self.clean_mcmc_chain(chain[0], chain[1])

    advanced_state.prior_potential = scenario.prior_potential(advanced_state.value, advanced_extra.random_key)
    advanced_state.likelihood_potential = (advanced_state.potential - advanced_state.prior_potential) \
                                          / scenario.temperature

    return advanced_state
def backward_simulation_full(ssm_scenario: StateSpaceModel,
                             marginal_particles: cdict,
                             n_samps: int,
                             random_key: jnp.ndarray) -> cdict:
    marg_particles_vals = marginal_particles.value
    times = marginal_particles.t
    marginal_log_weight = marginal_particles.log_weight

    T, n_pf, d = marg_particles_vals.shape

    t_keys = random.split(random_key, T)

    # Draw the final-time samples from the terminal marginal particle weights
    final_particle_vals = marg_particles_vals[-1, random.categorical(t_keys[-1],
                                                                     marginal_log_weight[-1],
                                                                     shape=(n_samps,))]

    def back_sim_body(x_tplus1_all: jnp.ndarray, ind: int):
        x_t_all = full_resampling(ssm_scenario,
                                  marg_particles_vals[ind], times[ind],
                                  x_tplus1_all, times[ind + 1],
                                  marginal_log_weight[ind],
                                  t_keys[ind])
        return x_t_all, x_t_all

    # Scan backwards in time, resampling a predecessor for each trajectory at every step
    _, back_sim_particles = scan(back_sim_body,
                                 final_particle_vals,
                                 jnp.arange(T - 2, -1, -1))

    out_samps = marginal_particles.copy()
    out_samps.value = jnp.vstack([back_sim_particles[::-1], final_particle_vals[jnp.newaxis]])
    out_samps.num_transition_evals = jnp.append(0, jnp.ones(T - 1) * n_pf * n_samps)
    del out_samps.log_weight
    return out_samps
def forward_proposal_non_zero_weight(self,
                                     abc_scenario: ABCScenario,
                                     state: cdict,
                                     extra: cdict,
                                     random_key: jnp.ndarray) -> cdict:
    def mcmc_kernel(previous_carry: Tuple[cdict, cdict],
                    _: None) -> Tuple[Tuple[cdict, cdict], Tuple[cdict, cdict]]:
        new_carry = self.mcmc_sampler.update(abc_scenario, *previous_carry)
        return new_carry, new_carry

    extra.random_key = random_key

    start_state, start_extra = self.mcmc_sampler.startup(abc_scenario,
                                                         extra.parameters.mcmc_steps,
                                                         state,
                                                         extra)

    final_carry, chain = scan(mcmc_kernel,
                              (start_state, start_extra),
                              None,
                              length=self.parameters.mcmc_steps)

    advanced_state, advanced_extra = self.clean_mcmc_chain(chain[0], chain[1])
    return advanced_state
def clean_chain_ar(self, abc_scenario: ABCScenario, chain_state: cdict):
    threshold = jnp.quantile(chain_state.distance, self.parameters.acceptance_rate)
    self.parameters.threshold = float(threshold)
    chain_state.log_weight = jnp.where(chain_state.distance < threshold, 0., -jnp.inf)
    return chain_state
def particle_filter_body(samps_previous: cdict,
                         iter_ind: int) -> Tuple[cdict, cdict]:
    x_previous = samps_previous.value
    log_weight_previous = samps_previous.log_weight
    int_rand_key = int_rand_keys[iter_ind]

    ess_previous = samps_previous.ess

    # Resample (and reset the weights) only when the ESS drops below the threshold
    resample_bool = ess_previous < (ess_threshold * n)
    x_res = cond(resample_bool,
                 _resample,
                 lambda tup: tup[0],
                 (x_previous, log_weight_previous, int_rand_key))
    log_weight_res = jnp.where(resample_bool, jnp.zeros(n), log_weight_previous)

    split_keys = random.split(int_rand_key, len(x_previous))

    # Propagate particles to the next observation time and accumulate incremental weights
    x_new, log_weight_new = particle_filter.propose_and_intermediate_weight_vectorised(ssm_scenario,
                                                                                       x_res,
                                                                                       samps_previous.t,
                                                                                       y[iter_ind],
                                                                                       t[iter_ind],
                                                                                       split_keys)
    log_weight_new = log_weight_res + log_weight_new

    samps_new = samps_previous.copy()
    samps_new.value = x_new
    samps_new.log_weight = log_weight_new
    samps_new.y = y[iter_ind]
    samps_new.t = t[iter_ind]
    samps_new.ess = ess_log_weight(log_weight_new)

    return samps_new, samps_new
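# Illustrative sketch only (not the library's code): _resample above is passed the tuple
# (values, log_weights, random_key) and is assumed to perform multinomial resampling in
# proportion to the normalised weights. A minimal version might look like this; the name
# _resample_sketch is hypothetical.
from jax import numpy as jnp, random


def _resample_sketch(tup):
    vals, log_weights, key = tup
    # Draw particle indices with probability proportional to exp(log_weights)
    inds = random.categorical(key, log_weights, shape=(vals.shape[0],))
    return vals[inds]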
def startup(self,
            scenario: Scenario,
            n: int,
            initial_state: cdict,
            initial_extra: cdict,
            startup_correction: bool = True,
            **kwargs) -> Tuple[cdict, cdict]:
    if initial_state is None:
        if is_implemented(scenario.prior_sample):
            initial_extra.random_key, sub_key = random.split(initial_extra.random_key)
            init_vals = scenario.prior_sample(sub_key)
        else:
            init_vals = jnp.zeros(scenario.dim)
        initial_state = cdict(value=init_vals)

    self.max_iter = n - 1

    if 'correction' in kwargs.keys():
        self.correction = kwargs['correction']
        del kwargs['correction']
    self.correction = check_correction(self.correction)

    initial_state, initial_extra = super().startup(scenario, n, initial_state, initial_extra, **kwargs)

    if startup_correction:
        initial_state, initial_extra = self.correction.startup(scenario, self, n,
                                                               initial_state, initial_extra, **kwargs)
    return initial_state, initial_extra
def proposal(self,
             abc_scenario: ABCScenario,
             reject_state: cdict,
             reject_extra: cdict) -> Tuple[cdict, cdict]:
    proposed_state = reject_state.copy()
    proposed_extra = reject_extra.copy()

    stepsize = reject_extra.parameters.stepsize

    proposed_extra.random_key, subkey1, subkey2, subkey3 = random.split(reject_extra.random_key, 4)

    proposed_state.value = reject_state.value + jnp.sqrt(stepsize) * random.normal(subkey1, (abc_scenario.dim,))
    proposed_state.prior_potential = abc_scenario.prior_potential(proposed_state.value, subkey2)
    proposed_state.simulated_data = abc_scenario.likelihood_sample(proposed_state.value, subkey3)
    proposed_state.distance = abc_scenario.distance_function(proposed_state.simulated_data)
    return proposed_state, proposed_extra
def startup(self,
            abc_scenario: ABCScenario,
            sampler: MCMCSampler,
            n: int,
            initial_state: cdict,
            initial_extra: cdict,
            **kwargs) -> Tuple[cdict, cdict]:
    initial_state, initial_extra = super().startup(abc_scenario, sampler, n,
                                                   initial_state, initial_extra, **kwargs)
    if sampler.parameters.stepsize is None:
        initial_extra.parameters.stepsize = 1.0
    if sampler.parameters.threshold is None:
        initial_extra.parameters.threshold = 50.
    initial_state.threshold = initial_extra.parameters.threshold

    initial_extra.parameters.stepsize = jnp.ones(abc_scenario.dim) * initial_extra.parameters.stepsize
    initial_state.stepsize = initial_extra.parameters.stepsize

    initial_extra.post_mean = initial_state.value
    initial_extra.diag_post_cov = initial_extra.parameters.stepsize * abc_scenario.dim / 2.38 ** 2
    return initial_state, initial_extra
def startup(self,
            scenario: Scenario,
            sampler: MCMCSampler,
            n: int,
            initial_state: cdict,
            initial_extra: cdict,
            **kwargs) -> Tuple[cdict, cdict]:
    initial_state, initial_extra = super().startup(scenario, sampler, n,
                                                   initial_state, initial_extra, **kwargs)
    initial_state.alpha = 1.
    return initial_state, initial_extra
def proposal(self,
             abc_scenario: ABCScenario,
             reject_state: cdict,
             reject_extra: cdict) -> Tuple[cdict, cdict]:
    proposed_state = reject_state.copy()
    proposed_extra = reject_extra.copy()

    proposed_extra.random_key, subkey1, subkey2, subkey3 = random.split(reject_extra.random_key, 4)

    proposed_state.value = self.importance_proposal(abc_scenario, subkey1)
    proposed_state.prior_potential = abc_scenario.prior_potential(proposed_state.value, subkey2)
    proposed_state.simulated_data = abc_scenario.likelihood_sample(proposed_state.value, subkey3)
    proposed_state.distance = abc_scenario.distance_function(proposed_state.simulated_data)
    proposed_state.log_weight = self.log_weight(abc_scenario, proposed_state, proposed_extra)
    return proposed_state, proposed_extra
def startup(self,
            scenario: Scenario,
            n: int,
            initial_state: cdict,
            initial_extra: cdict,
            **kwargs) -> Tuple[cdict, cdict]:
    if not hasattr(scenario, 'prior_sample'):
        raise TypeError(f'Likelihood tempering requires scenario {scenario.name} to have prior_sample implemented')

    initial_state, initial_extra = super().startup(scenario, n, initial_state, initial_extra, **kwargs)

    random_keys = random.split(initial_extra.random_key, 2 * n + 1)
    initial_extra.random_key = random_keys[-1]

    initial_state.prior_potential = vmap(scenario.prior_potential)(initial_state.value, random_keys[:n])
    initial_state.likelihood_potential = vmap(scenario.likelihood_potential)(initial_state.value,
                                                                             random_keys[n:(2 * n)])
    initial_state.potential = initial_state.prior_potential
    initial_state.temperature = jnp.zeros(n)
    initial_state.log_weight = jnp.zeros(n)
    initial_state.ess = jnp.zeros(n) + n
    scenario.temperature = 0.
    return initial_state, initial_extra
def proposal(self,
             scenario: Scenario,
             reject_state: cdict,
             reject_extra: cdict) -> Tuple[cdict, cdict]:
    proposed_state = reject_state.copy()

    d = scenario.dim
    x = reject_state.value
    stepsize = reject_extra.parameters.stepsize

    reject_extra.random_key, subkey, scen_key = random.split(reject_extra.random_key, 3)

    proposed_state.value = x + jnp.sqrt(stepsize) * random.normal(subkey, (d,))
    proposed_state.potential = scenario.potential(proposed_state.value, scen_key)
    return proposed_state, reject_extra
def adapt(self,
          previous_ensemble_state: cdict,
          previous_extra: cdict,
          new_ensemble_state: cdict,
          new_extra: cdict) -> Tuple[cdict, cdict]:
    n = new_ensemble_state.value.shape[0]

    next_threshold = self.next_threshold(new_ensemble_state, new_extra)
    new_ensemble_state.threshold = jnp.ones(n) * next_threshold
    new_extra.parameters.threshold = next_threshold

    new_ensemble_state.log_weight = self.log_weight(previous_ensemble_state, previous_extra,
                                                    new_ensemble_state, new_extra)
    new_ensemble_state.ess = jnp.ones(n) * ess_log_weight(new_ensemble_state.log_weight)

    alive_inds = previous_ensemble_state.log_weight > -jnp.inf
    new_extra.alpha_mean = (new_ensemble_state.alpha * alive_inds).sum() / alive_inds.sum()

    new_ensemble_state, new_extra = self.adapt_mcmc_params(previous_ensemble_state, previous_extra,
                                                           new_ensemble_state, new_extra)
    return new_ensemble_state, new_extra
def proposal(self,
             scenario: Scenario,
             reject_state: cdict,
             reject_extra: cdict) -> Tuple[cdict, cdict]:
    d = scenario.dim

    random_keys = random.split(reject_extra.random_key, self.parameters.leapfrog_steps + 2)
    reject_extra.random_key = random_keys[0]

    reject_state.momenta = random.normal(random_keys[1], (d,))

    all_leapfrog_state = utils.leapfrog(scenario.potential_and_grad,
                                        reject_state,
                                        reject_extra.parameters.stepsize,
                                        random_keys[2:])
    proposed_state = all_leapfrog_state[-1]

    # Technically we should reverse momenta now
    # but momenta target is symmetric and then immediately resampled at the next step anyway

    return proposed_state, reject_extra
def leapfrog_step(init_state: cdict,
                  i: int):
    new_state = init_state.copy()

    # Half-step momentum update, full position step, then second half-step momentum update
    p_half = init_state.momenta - stepsize / 2. * init_state.grad_potential
    new_state.value = init_state.value + stepsize * p_half
    new_state.potential, new_state.grad_potential = potential_and_grad(new_state.value, random_keys[i])
    new_state.momenta = p_half - stepsize / 2. * new_state.grad_potential

    # Record both the half-step and end-of-step momenta in the output chain
    next_sample_chain = new_state.copy()
    next_sample_chain.momenta = jnp.vstack([p_half, new_state.momenta])

    return new_state, next_sample_chain
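# Illustrative usage sketch only: leapfrog_step above is written as a jax.lax.scan body,
# closing over stepsize, potential_and_grad and random_keys from its enclosing scope.
# Assuming a start state init_leapfrog_state (carrying value, potential, grad_potential
# and momenta) and a step count n_leapfrog_steps - both names are assumptions for
# exposition - a full trajectory would be produced roughly as follows:
from jax import numpy as jnp
from jax.lax import scan

final_state, leapfrog_chain = scan(leapfrog_step,
                                   init_leapfrog_state,
                                   jnp.arange(n_leapfrog_steps))
# leapfrog_chain stacks the state after every step, with the half-step and end-of-step
# momenta recorded together at each iteration (via the jnp.vstack above).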
def startup(self,
            scenario: Scenario,
            n: int,
            initial_state: cdict,
            initial_extra: cdict,
            **kwargs) -> Tuple[cdict, cdict]:
    self.mcmc_sampler.correction = check_correction(self.mcmc_sampler.correction)

    initial_state, initial_extra = super().startup(scenario, n, initial_state, initial_extra, **kwargs)

    first_temp = self.next_temperature(initial_state, initial_extra)
    scenario.temperature = first_temp
    initial_state.temperature += first_temp
    initial_state.potential = initial_state.prior_potential + first_temp * initial_state.likelihood_potential
    initial_state.log_weight = -first_temp * initial_state.likelihood_potential
    initial_state.ess = jnp.repeat(jnp.exp(log_ess_log_weight(initial_state.log_weight)), n)

    initial_state, initial_extra = vmap(lambda state: self.mcmc_sampler.startup(scenario, n, state,
                                                                                initial_extra))(initial_state)
    initial_extra = initial_extra[0]
    return initial_state, initial_extra
def startup(self,
            abc_scenario: ABCScenario,
            n: int,
            initial_state: cdict = None,
            initial_extra: cdict = None,
            **kwargs) -> Tuple[cdict, cdict]:
    if initial_state is None:
        if is_implemented(abc_scenario.prior_sample):
            initial_extra.random_key, sub_key = random.split(initial_extra.random_key)
            init_vals = abc_scenario.prior_sample(sub_key)
        else:
            init_vals = jnp.zeros(abc_scenario.dim)
        initial_state = cdict(value=init_vals)

    self.max_iter = n - 1

    initial_state, initial_extra = super().startup(abc_scenario, n, initial_state, initial_extra, **kwargs)
    initial_state.log_weight = self.log_weight(abc_scenario, initial_state, initial_extra)
    return initial_state, initial_extra
def update(self,
           scenario: Scenario,
           ensemble_state: cdict,
           extra: cdict) -> Tuple[cdict, cdict]:
    n = ensemble_state.value.shape[0]
    extra.iter = extra.iter + 1

    random_keys = random.split(extra.random_key, n + 2)
    batch_inds = self.get_batch_inds(random_keys[-1])
    extra.random_key = random_keys[-2]

    # Stein variational gradient direction for each particle, over a mini-batch of the ensemble
    phi_hat = self.kernelised_grad_matrix(ensemble_state.value,
                                          ensemble_state.grad_potential,
                                          extra.parameters.kernel_params,
                                          batch_inds)

    # The optimiser performs descent, so feed -phi_hat to ascend along phi_hat
    extra.opt_state = self.opt_update(extra.iter, -phi_hat, extra.opt_state)
    ensemble_state.value = self.get_params(extra.opt_state)

    ensemble_state.potential, ensemble_state.grad_potential \
        = vmap(scenario.potential_and_grad)(ensemble_state.value, random_keys[:n])

    ensemble_state, extra = self.adapt(ensemble_state, extra)

    return ensemble_state, extra
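# Illustrative sketch only (not the library's implementation): kernelised_grad_matrix is
# assumed to compute the standard SVGD direction
#   phi(x_i) = (1/m) * sum_j [ k(x_j, x_i) * (-grad U(x_j)) + grad_{x_j} k(x_j, x_i) ],
# here with a fixed-bandwidth Gaussian kernel and no mini-batching; the name
# svgd_direction_sketch and its arguments are assumptions for exposition.
from jax import numpy as jnp


def svgd_direction_sketch(values: jnp.ndarray,
                          grad_potentials: jnp.ndarray,
                          bandwidth: float) -> jnp.ndarray:
    # Pairwise differences, squared distances and Gaussian kernel matrix
    diffs = values[:, None, :] - values[None, :, :]        # (m, m, d)
    sq_dists = jnp.sum(diffs ** 2, axis=-1)                # (m, m)
    kernel = jnp.exp(-sq_dists / (2 * bandwidth ** 2))     # (m, m)

    # Attraction: kernel-weighted negative gradients of the potential
    attraction = kernel @ (-grad_potentials)               # (m, d)
    # Repulsion: sum over j of grad_{x_j} k(x_j, x_i)
    repulsion = jnp.einsum('ij,ijd->id', kernel, diffs) / bandwidth ** 2
    return (attraction + repulsion) / values.shape[0]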
def proposal(self,
             scenario: Scenario,
             reject_state: cdict,
             reject_extra: cdict) -> Tuple[cdict, cdict]:
    random_keys = random.split(reject_extra.random_key, self.parameters.leapfrog_steps + 1)
    reject_extra.random_key = random_keys[0]

    all_leapfrog_state = utils.leapfrog(scenario.potential_and_grad,
                                        reject_state,
                                        reject_extra.parameters.stepsize,
                                        random_keys[1:])
    proposed_state = all_leapfrog_state[-1]
    proposed_state.momenta *= -1
    return proposed_state, reject_extra
def startup(self,
            scenario: Scenario,
            n: int,
            initial_state: cdict,
            initial_extra: cdict,
            **kwargs) -> Tuple[cdict, cdict]:
    if initial_state is None:
        initial_extra.random_key, sub_key = random.split(initial_extra.random_key)
        if is_implemented(scenario.prior_sample):
            init_vals = vmap(scenario.prior_sample)(random.split(sub_key, n))
        else:
            init_vals = random.normal(sub_key, shape=(n, scenario.dim))
        initial_state = cdict(value=init_vals)
    initial_state, initial_extra = super().startup(scenario, n, initial_state, initial_extra, **kwargs)
    return initial_state, initial_extra