def startup(self, abc_scenario: ABCScenario, n: int, initial_state: cdict, initial_extra: cdict, startup_correction: bool = True, **kwargs) -> Tuple[cdict, cdict]:
    """Initialise an ABC chain, filling in missing ABC-specific fields.

    Delegates to the parent startup, then lazily populates
    ``prior_potential`` (when the scenario implements it),
    ``simulated_data`` and ``distance`` on the initial state, consuming
    fresh subkeys from ``initial_extra.random_key`` as needed.
    """
    initial_state, initial_extra = super().startup(abc_scenario, n,
                                                   initial_state,
                                                   initial_extra, **kwargs)

    def fresh_key() -> jnp.ndarray:
        # Advance the stored key chain and return a one-use subkey.
        initial_extra.random_key, subkey = random.split(
            initial_extra.random_key)
        return subkey

    if not hasattr(initial_state, 'prior_potential') \
            and is_implemented(abc_scenario.prior_potential):
        initial_state.prior_potential = abc_scenario.prior_potential(
            initial_state.value, fresh_key())

    if not hasattr(initial_state, 'simulated_data'):
        initial_state.simulated_data = abc_scenario.likelihood_sample(
            initial_state.value, fresh_key())

    if not hasattr(initial_state, 'distance'):
        # Distance is computed from the simulated data alone - no key needed.
        initial_state.distance = abc_scenario.distance_function(
            initial_state.simulated_data)

    return initial_state, initial_extra
def forward_proposal_non_zero_weight(self, abc_scenario: ABCScenario, state: cdict, extra: cdict, random_key: jnp.ndarray) -> cdict:
    """Move a (non-zero-weight) particle by running the internal MCMC sampler.

    Seeds the inner sampler with ``random_key``, runs ``mcmc_steps``
    transitions via ``jax.lax.scan`` and returns the cleaned final state.
    """

    def step(carry: Tuple[cdict, cdict],
             _: None) -> Tuple[Tuple[cdict, cdict], Tuple[cdict, cdict]]:
        # One MCMC transition; the carry and per-step output coincide.
        out = self.mcmc_sampler.update(abc_scenario, *carry)
        return out, out

    extra.random_key = random_key
    # NOTE(review): startup is sized with extra.parameters.mcmc_steps while
    # the scan length uses self.parameters.mcmc_steps - presumably equal;
    # confirm the two parameter sets cannot diverge.
    start_state, start_extra = self.mcmc_sampler.startup(
        abc_scenario, extra.parameters.mcmc_steps, state, extra)
    _, chain = scan(step, (start_state, start_extra), None,
                    length=self.parameters.mcmc_steps)
    advanced_state, _ = self.clean_mcmc_chain(chain[0], chain[1])
    return advanced_state
def forward_proposal(self, scenario: Scenario, state: cdict, extra: cdict, random_key: jnp.ndarray) -> cdict:
    """Propagate a particle with the internal MCMC sampler, then refresh its
    prior and likelihood potentials.

    ``likelihood_potential`` is recovered from the tempered decomposition
    potential = prior_potential + temperature * likelihood_potential.
    """

    def step(carry: Tuple[cdict, cdict],
             _: None) -> Tuple[Tuple[cdict, cdict], Tuple[cdict, cdict]]:
        # One MCMC transition; the carry and per-step output coincide.
        out = self.mcmc_sampler.update(scenario, *carry)
        return out, out

    extra.random_key = random_key
    # NOTE(review): startup is sized with extra.parameters.mcmc_steps while
    # the scan length uses self.parameters.mcmc_steps - presumably equal;
    # confirm.
    start_state, start_extra = self.mcmc_sampler.startup(
        scenario, extra.parameters.mcmc_steps, state, extra)
    _, chain = scan(step, (start_state, start_extra), None,
                    length=self.parameters.mcmc_steps)
    advanced_state, advanced_extra = self.clean_mcmc_chain(chain[0], chain[1])

    advanced_state.prior_potential = scenario.prior_potential(
        advanced_state.value, advanced_extra.random_key)
    advanced_state.likelihood_potential \
        = (advanced_state.potential - advanced_state.prior_potential) \
        / scenario.temperature
    return advanced_state
def update(self, scenario: Scenario, ensemble_state: cdict, extra: cdict) -> Tuple[cdict, cdict]:
    """One SMC iteration: conditionally resample, move every particle with
    the forward proposal, then adapt."""
    extra.iter = extra.iter + 1
    n = ensemble_state.value.shape[0]

    do_resample = self.resample_criterion(ensemble_state, extra)

    # n keys for the per-particle proposals, one for resampling, one kept.
    keys = random.split(extra.random_key, n + 2)
    extra.random_key = keys[-1]

    resampled_state = cond(do_resample,
                           lambda s: self.resample(s, keys[-2]),
                           lambda s: s,
                           ensemble_state)

    advanced_state = vmap(self.forward_proposal,
                          in_axes=(None, 0, None, 0))(scenario,
                                                      resampled_state,
                                                      extra,
                                                      keys[:n])
    # NOTE(review): the same extra is passed for both the previous and the
    # advanced extra slots - confirm adapt expects this.
    advanced_state, advanced_extra = self.adapt(resampled_state, extra,
                                                advanced_state, extra)
    return advanced_state, advanced_extra
def startup(self, scenario: Scenario, n: int, initial_state: cdict, initial_extra: cdict, **kwargs) -> Tuple[cdict, cdict]:
    """Initialise the ensemble sampler: configure mini-batching over the
    ensemble, evaluate initial potentials/gradients and set up the
    optimiser."""
    initial_state, initial_extra = super().startup(scenario, n,
                                                   initial_state,
                                                   initial_extra, **kwargs)

    if self.parameters.ensemble_batchsize is None:
        # Default to full-ensemble interaction.
        self.parameters.ensemble_batchsize = n
        initial_extra.parameters.ensemble_batchsize = n

    if self.parameters.ensemble_batchsize == n:
        # Full batch: every particle sees every particle.
        def get_batch_inds(_):
            return jnp.repeat(jnp.arange(n)[None], n, axis=0)
    else:
        # Sub-sampled batch: random particle indices per particle.
        def get_batch_inds(rk):
            return random.choice(
                rk, n, shape=(n, self.parameters.ensemble_batchsize))
    self.get_batch_inds = get_batch_inds

    # NOTE(review): stepsize is removed from the per-run extra parameters
    # while self.parameters.stepsize is still fed to the optimiser below -
    # presumably so stepsize is owned by the optimiser rather than carried
    # per iteration; confirm.
    del initial_extra.parameters.stepsize

    keys = random.split(initial_extra.random_key, n + 1)
    initial_extra.random_key = keys[-1]
    initial_state.potential, initial_state.grad_potential = vmap(
        scenario.potential_and_grad)(initial_state.value, keys[:n])

    initial_state, initial_extra = self.adapt(initial_state, initial_extra)

    self.opt_init, self.opt_update, self.get_params = self.optimiser(
        step_size=self.parameters.stepsize,
        **initial_extra.parameters.optim_params)
    initial_extra.opt_state = self.opt_init(initial_state.value)
    return initial_state, initial_extra
def startup(self, scenario: Scenario, n: int, initial_state: cdict, initial_extra: cdict, **kwargs) -> Tuple[cdict, cdict]:
    """Initialise likelihood-tempering SMC at temperature zero.

    Evaluates prior and likelihood potentials separately so the
    temperature can bridge between them.

    Raises:
        TypeError: if the scenario has no prior_sample attribute.
    """
    if not hasattr(scenario, 'prior_sample'):
        raise TypeError(
            f'Likelihood tempering requires scenario {scenario.name} to have prior_sample implemented'
        )
    initial_state, initial_extra = super().startup(scenario, n,
                                                   initial_state,
                                                   initial_extra, **kwargs)

    # One key per particle for each of the two potential evaluations,
    # plus one retained for the sampler.
    keys = random.split(initial_extra.random_key, 2 * n + 1)
    initial_extra.random_key = keys[-1]

    initial_state.prior_potential = vmap(scenario.prior_potential)(
        initial_state.value, keys[:n])
    initial_state.likelihood_potential = vmap(scenario.likelihood_potential)(
        initial_state.value, keys[n:(2 * n)])

    # At temperature 0 the target is the prior alone.
    initial_state.potential = initial_state.prior_potential
    initial_state.temperature = jnp.zeros(n)
    initial_state.log_weight = jnp.zeros(n)
    initial_state.ess = jnp.zeros(n) + n
    scenario.temperature = 0.
    return initial_state, initial_extra
def startup(self, scenario: Scenario, n: int, initial_state: cdict, initial_extra: cdict, startup_correction: bool = True, **kwargs) -> Tuple[cdict, cdict]:
    """Initialise an MCMC run.

    Draws an initial value from the prior (or zeros) when no state is
    supplied, resolves and validates the correction, then runs the
    parent startup and (optionally) the correction's own startup.

    Args:
        scenario: target scenario.
        n: number of samples to generate.
        initial_state: starting state, drawn here if None.
        initial_extra: extra information carried through the run.
        startup_correction: whether to run the correction's startup too.
        **kwargs: may contain 'correction' to override self.correction;
            remaining kwargs are forwarded to the parent startup.

    Returns:
        Tuple of initialised (state, extra).
    """
    if initial_state is None:
        if is_implemented(scenario.prior_sample):
            initial_extra.random_key, sub_key = random.split(
                initial_extra.random_key)
            init_vals = scenario.prior_sample(sub_key)
        else:
            # No prior sampler available - start from the origin.
            init_vals = jnp.zeros(scenario.dim)
        initial_state = cdict(value=init_vals)

    # n samples means n - 1 transitions.
    self.max_iter = n - 1

    # Idiomatic pop replaces the membership test + del: an explicit kwarg
    # overrides self.correction, which is validated either way, and
    # 'correction' is removed so it is not forwarded below.
    self.correction = check_correction(
        kwargs.pop('correction', self.correction))

    initial_state, initial_extra = super().startup(scenario, n,
                                                   initial_state,
                                                   initial_extra, **kwargs)
    if startup_correction:
        initial_state, initial_extra = self.correction.startup(
            scenario, self, n, initial_state, initial_extra, **kwargs)
    return initial_state, initial_extra
def startup(self, scenario: Scenario, n: int, initial_state: cdict, initial_extra: cdict, **kwargs) -> Tuple[cdict, cdict]:
    """Run the parent startup, then evaluate the potential and its gradient
    at the starting value with a fresh subkey."""
    initial_state, initial_extra = super().startup(scenario, n,
                                                   initial_state,
                                                   initial_extra, **kwargs)
    initial_extra.random_key, eval_key = random.split(
        initial_extra.random_key)
    initial_state.potential, initial_state.grad_potential \
        = scenario.potential_and_grad(initial_state.value, eval_key)
    return initial_state, initial_extra
def startup(self, scenario: Scenario, n: int, initial_state: cdict, initial_extra: cdict, **kwargs) -> Tuple[cdict, cdict]:
    """Parent startup plus an initial potential/gradient evaluation and
    zero-initialised momenta of the scenario's dimension."""
    initial_state, initial_extra = super().startup(scenario, n,
                                                   initial_state,
                                                   initial_extra, **kwargs)
    initial_extra.random_key, eval_key = random.split(
        initial_extra.random_key)
    initial_state.potential, initial_state.grad_potential \
        = scenario.potential_and_grad(initial_state.value, eval_key)

    # (Re)initialise momenta when absent or of the wrong trailing dimension.
    needs_momenta = not hasattr(initial_state, 'momenta') \
        or initial_state.momenta.shape[-1] != scenario.dim
    if needs_momenta:
        initial_state.momenta = jnp.zeros(scenario.dim)
    return initial_state, initial_extra
def startup(self, abc_scenario: ABCScenario, n: int, initial_state: cdict, initial_extra: cdict, **kwargs) -> Tuple[cdict, cdict]:
    """Initialise the ABC-SMC ensemble, lazily filling per-particle prior
    potentials, simulated data, distances, thresholds and ESS."""
    initial_state, initial_extra = SMCSampler.startup(self, abc_scenario, n,
                                                      initial_state,
                                                      initial_extra,
                                                      **kwargs)
    # The ensemble size is taken from the state, which may differ from n.
    n = len(initial_state.value)

    def fresh_keys() -> jnp.ndarray:
        # One subkey per particle; advance the stored key chain.
        keys = random.split(initial_extra.random_key, n + 1)
        initial_extra.random_key = keys[-1]
        return keys[:n]

    if not hasattr(initial_state, 'prior_potential') \
            and is_implemented(abc_scenario.prior_potential):
        initial_state.prior_potential = vmap(abc_scenario.prior_potential)(
            initial_state.value, fresh_keys())

    if not hasattr(initial_state, 'simulated_data'):
        initial_state.simulated_data = vmap(abc_scenario.likelihood_sample)(
            initial_state.value, fresh_keys())

    if not hasattr(initial_state, 'distance'):
        initial_state.distance = vmap(abc_scenario.distance_function)(
            initial_state.simulated_data)

    if not hasattr(initial_state, 'threshold'):
        if self.threshold_schedule is None:
            # No schedule supplied: start fully permissive.
            initial_state.threshold = jnp.zeros(n) + jnp.inf
        else:
            initial_state.threshold = jnp.zeros(n) \
                + self.threshold_schedule[0]

    if not hasattr(initial_state, 'ess'):
        initial_state.ess = jnp.zeros(n) + n
    return initial_state, initial_extra
def always(self, scenario: Scenario, reject_state: cdict, reject_extra: cdict) -> Tuple[cdict, cdict]:
    """Post-step update applied whether or not the leapfrog move was
    accepted: flip the momenta, then partially refresh them with the exact
    solution of the Ornstein-Uhlenbeck process."""
    dim = scenario.dim
    stepsize = reject_extra.parameters.stepsize
    friction = reject_extra.parameters.friction

    # Momentum flip - applied regardless of acceptance.
    reject_state.momenta = -reject_state.momenta

    # Exact OU refresh: deterministic decay plus matching Gaussian noise.
    reject_extra.random_key, noise_key = random.split(
        reject_extra.random_key)
    decay = jnp.exp(-friction * stepsize)
    noise_scale = jnp.sqrt(1 - jnp.exp(-2 * friction * stepsize))
    reject_state.momenta = reject_state.momenta * decay \
        + noise_scale * random.normal(noise_key, (dim,))
    return reject_state, reject_extra
def proposal(self, scenario: Scenario, reject_state: cdict, reject_extra: cdict) -> Tuple[cdict, cdict]:
    """Leapfrog proposal: integrate for ``leapfrog_steps`` steps, then flip
    the final momenta so the proposal is an involution."""
    keys = random.split(reject_extra.random_key,
                        self.parameters.leapfrog_steps + 1)
    reject_extra.random_key = keys[0]

    # One key per leapfrog step is handed to the integrator.
    leapfrog_states = utils.leapfrog(scenario.potential_and_grad,
                                     reject_state,
                                     reject_extra.parameters.stepsize,
                                     keys[1:])
    proposed_state = leapfrog_states[-1]
    proposed_state.momenta *= -1
    return proposed_state, reject_extra
def startup(self, scenario: Scenario, n: int, initial_state: cdict, initial_extra: cdict, **kwargs) -> Tuple[cdict, cdict]:
    """Ensemble startup: when no initial state is supplied, draw n values
    from the prior (when implemented) or from a standard normal."""
    if initial_state is None:
        initial_extra.random_key, sub_key = random.split(
            initial_extra.random_key)
        if is_implemented(scenario.prior_sample):
            init_vals = vmap(scenario.prior_sample)(random.split(sub_key, n))
        else:
            # Fallback: standard normal initialisation of the ensemble.
            init_vals = random.normal(sub_key, shape=(n, scenario.dim))
        initial_state = cdict(value=init_vals)
    return super().startup(scenario, n, initial_state, initial_extra,
                           **kwargs)
def proposal(self, scenario: Scenario, reject_state: cdict, reject_extra: cdict) -> Tuple[cdict, cdict]:
    """Random-walk proposal: Gaussian perturbation with variance
    ``stepsize``, followed by evaluation of the proposed potential."""
    proposed_state = reject_state.copy()
    dim = scenario.dim
    current_value = reject_state.value
    stepsize = reject_extra.parameters.stepsize

    reject_extra.random_key, walk_key, potential_key = random.split(
        reject_extra.random_key, 3)

    proposed_state.value = current_value \
        + jnp.sqrt(stepsize) * random.normal(walk_key, (dim,))
    proposed_state.potential = scenario.potential(proposed_state.value,
                                                  potential_key)
    return proposed_state, reject_extra
def proposal(self, scenario: Scenario, reject_state: cdict, reject_extra: cdict) -> Tuple[cdict, cdict]:
    """HMC proposal with full momentum refreshment: resample momenta from a
    standard normal, then integrate with leapfrog.

    Momenta are deliberately not reversed afterwards: the momenta target is
    symmetric and they are resampled again at the next step anyway.
    """
    dim = scenario.dim
    keys = random.split(reject_extra.random_key,
                        self.parameters.leapfrog_steps + 2)
    reject_extra.random_key = keys[0]

    reject_state.momenta = random.normal(keys[1], (dim,))
    leapfrog_states = utils.leapfrog(scenario.potential_and_grad,
                                     reject_state,
                                     reject_extra.parameters.stepsize,
                                     keys[2:])
    return leapfrog_states[-1], reject_extra
def run(scenario: Scenario, sampler: Union[Sampler, Type[Sampler]], n: int, random_key: Union[None, jnp.ndarray], initial_state: cdict = None, initial_extra: cdict = None, **kwargs) -> Union[cdict, Tuple[cdict, jnp.ndarray]]:
    """Execute a sampler on a scenario.

    Instantiates ``sampler`` if a class is given, runs its startup, then
    iterates its update kernel until the termination criterion fires (or
    ``sampler.max_iter`` is hit), recording wall-clock time and a summary
    on the returned chain.
    """
    if isclass(sampler):
        # Allow passing the sampler class itself; instantiate with kwargs.
        sampler = sampler(**kwargs)
    sampler.n = n

    if initial_extra is None:
        initial_extra = cdict()
    if random_key is not None:
        initial_extra.random_key = random_key

    initial_state, initial_extra = sampler.startup(scenario, n,
                                                   initial_state,
                                                   initial_extra, **kwargs)
    summary = sampler.summary(scenario, initial_state, initial_extra)

    update_kernel = partial(sampler.update, scenario)

    start_time = time()
    chain = while_loop_stacked(
        lambda state, extra: ~sampler.termination_criterion(state, extra),
        update_kernel,
        (initial_state, initial_extra),
        sampler.max_iter)
    # Prepend the initial state so the chain includes iteration zero.
    chain = initial_state[jnp.newaxis] + chain
    chain = sampler.clean_chain(scenario, chain)
    # Block on JAX's async dispatch so the recorded time is meaningful.
    chain.value.block_until_ready()
    end_time = time()

    chain.time = end_time - start_time
    chain.summary = summary
    return chain
def startup(self, abc_scenario: ABCScenario, n: int, initial_state: cdict = None, initial_extra: cdict = None, **kwargs) -> Tuple[cdict, cdict]:
    """ABC importance-sampling startup: draw an initial value from the prior
    (or zeros) when none is given, run the parent startup, then compute the
    initial log weight."""
    if initial_state is None:
        if is_implemented(abc_scenario.prior_sample):
            initial_extra.random_key, sub_key = random.split(
                initial_extra.random_key)
            init_vals = abc_scenario.prior_sample(sub_key)
        else:
            # No prior sampler available - start from the origin.
            init_vals = jnp.zeros(abc_scenario.dim)
        initial_state = cdict(value=init_vals)

    # n samples means n - 1 transitions.
    self.max_iter = n - 1

    initial_state, initial_extra = super().startup(abc_scenario, n,
                                                   initial_state,
                                                   initial_extra, **kwargs)
    initial_state.log_weight = self.log_weight(abc_scenario, initial_state,
                                               initial_extra)
    return initial_state, initial_extra
def update(self, scenario: Scenario, ensemble_state: cdict, extra: cdict) -> Tuple[cdict, cdict]:
    """One ensemble update: compute the kernelised gradient direction, take
    an optimiser step, re-evaluate potentials and gradients, then adapt."""
    n = ensemble_state.value.shape[0]
    extra.iter = extra.iter + 1

    # n keys for the potential evaluations, one for batch indices, one kept.
    keys = random.split(extra.random_key, n + 2)
    batch_inds = self.get_batch_inds(keys[-1])
    extra.random_key = keys[-2]

    phi_hat = self.kernelised_grad_matrix(ensemble_state.value,
                                          ensemble_state.grad_potential,
                                          extra.parameters.kernel_params,
                                          batch_inds)
    # NOTE(review): phi_hat is negated before the optimiser step -
    # presumably because the optimiser descends while phi_hat is the update
    # direction; confirm against kernelised_grad_matrix's sign convention.
    extra.opt_state = self.opt_update(extra.iter, -phi_hat, extra.opt_state)
    ensemble_state.value = self.get_params(extra.opt_state)

    ensemble_state.potential, ensemble_state.grad_potential \
        = vmap(scenario.potential_and_grad)(ensemble_state.value, keys[:n])
    return self.adapt(ensemble_state, extra)
def always(self, abc_scenario: ABCScenario, reject_state: cdict, reject_extra: cdict) -> Tuple[cdict, cdict]:
    """Applied on every step regardless of acceptance: advance the random
    key so the key stream stays in sync; the state is returned unchanged."""
    reject_extra.random_key = random.split(reject_extra.random_key)[0]
    return reject_state, reject_extra