def startup(self,
            scenario: Scenario,
            n: int,
            initial_state: cdict,
            initial_extra: cdict,
            **kwargs) -> Tuple[cdict, cdict]:
    """Initialise the particle ensemble.

    Evaluates the potential and its gradient for every particle, configures
    how interaction indices are drawn (full batch vs sampled mini-batch)
    and builds the optimiser state.

    Args:
        scenario: Target scenario providing `potential_and_grad`.
        n: Number of particles in the ensemble.
        initial_state: Initial ensemble state.
        initial_extra: Initial extra carry (random key, parameters, ...).

    Returns:
        Updated (initial_state, initial_extra) pair.
    """
    initial_state, initial_extra = super().startup(
        scenario, n, initial_state, initial_extra, **kwargs)

    # No batch size supplied -> default to full-batch interactions.
    if self.parameters.ensemble_batchsize is None:
        self.parameters.ensemble_batchsize = n
        initial_extra.parameters.ensemble_batchsize = n

    if self.parameters.ensemble_batchsize == n:
        # Full batch: each particle interacts with the whole ensemble,
        # so the index matrix is constant and ignores the random key.
        def full_batch_inds(_):
            return jnp.repeat(jnp.arange(n)[None], n, axis=0)

        self.get_batch_inds = full_batch_inds
    else:
        # Mini-batch: resample interaction indices every call.
        def sampled_batch_inds(rkey):
            return random.choice(
                rkey, n,
                shape=(n, self.parameters.ensemble_batchsize,))

        self.get_batch_inds = sampled_batch_inds

    # Step size is handed to the optimiser below; drop the copy in extra.
    del initial_extra.parameters.stepsize

    keys = random.split(initial_extra.random_key, n + 1)
    initial_extra.random_key = keys[-1]
    initial_state.potential, initial_state.grad_potential = vmap(
        scenario.potential_and_grad)(initial_state.value, keys[:n])

    initial_state, initial_extra = self.adapt(initial_state, initial_extra)

    self.opt_init, self.opt_update, self.get_params = self.optimiser(
        step_size=self.parameters.stepsize,
        **initial_extra.parameters.optim_params)
    initial_extra.opt_state = self.opt_init(initial_state.value)
    return initial_state, initial_extra
def startup(self,
            scenario: Scenario,
            n: int,
            initial_state: cdict,
            initial_extra: cdict,
            **kwargs) -> Tuple[cdict, cdict]:
    """Initialise a likelihood-tempered SMC run at temperature zero.

    Evaluates prior and likelihood potentials for every particle, starts
    the ensemble from the prior (temperature 0) with uniform weights and
    full effective sample size.

    Args:
        scenario: Target scenario; must implement `prior_sample`.
        n: Number of particles.
        initial_state: Initial particle state.
        initial_extra: Initial extra carry (random key, parameters, ...).

    Returns:
        Updated (initial_state, initial_extra) pair.

    Raises:
        TypeError: If the scenario does not implement `prior_sample`.
    """
    if not hasattr(scenario, 'prior_sample'):
        raise TypeError(
            f'Likelihood tempering requires scenario {scenario.name} to have prior_sample implemented'
        )

    initial_state, initial_extra = super().startup(
        scenario, n, initial_state, initial_extra, **kwargs)

    # One key per prior evaluation, one per likelihood evaluation,
    # plus one carried forward.
    keys = random.split(initial_extra.random_key, 2 * n + 1)
    initial_extra.random_key = keys[-1]
    prior_keys = keys[:n]
    likelihood_keys = keys[n:(2 * n)]

    initial_state.prior_potential = vmap(
        scenario.prior_potential)(initial_state.value, prior_keys)
    initial_state.likelihood_potential = vmap(
        scenario.likelihood_potential)(initial_state.value, likelihood_keys)

    # At temperature 0 the target is the prior alone.
    initial_state.potential = initial_state.prior_potential
    initial_state.temperature = jnp.zeros(n)
    # Uniform weights -> ESS equals the full ensemble size.
    initial_state.log_weight = jnp.zeros(n)
    initial_state.ess = jnp.zeros(n) + n
    scenario.temperature = 0.
    return initial_state, initial_extra
def startup(self,
            scenario: Scenario,
            n: int,
            initial_state: cdict,
            initial_extra: cdict,
            **kwargs) -> Tuple[cdict, cdict]:
    """Initialise the sampler state with potential and gradient.

    Delegates generic setup to the parent class, then evaluates the
    scenario's potential and gradient at the initial value.

    Args:
        scenario: Target scenario providing `potential_and_grad`.
        n: Chain/ensemble size forwarded to the parent startup.
        initial_state: Initial state.
        initial_extra: Initial extra carry (random key, parameters, ...).

    Returns:
        Updated (initial_state, initial_extra) pair.
    """
    initial_state, initial_extra = super().startup(
        scenario, n, initial_state, initial_extra, **kwargs)

    # Consume one key for the potential evaluation, carry the other.
    initial_extra.random_key, pot_key = random.split(
        initial_extra.random_key)
    initial_state.potential, initial_state.grad_potential \
        = scenario.potential_and_grad(initial_state.value, pot_key)
    return initial_state, initial_extra
def startup(self,
            scenario: Scenario,
            n: int,
            initial_state: cdict,
            initial_extra: cdict,
            **kwargs) -> Tuple[cdict, cdict]:
    """Initialise the sampler state including momenta.

    Delegates generic setup to the parent class, evaluates potential and
    gradient at the initial value, and (re)initialises momenta to zeros
    when absent or of the wrong dimension.

    Args:
        scenario: Target scenario providing `potential_and_grad` and `dim`.
        n: Chain/ensemble size forwarded to the parent startup.
        initial_state: Initial state.
        initial_extra: Initial extra carry (random key, parameters, ...).

    Returns:
        Updated (initial_state, initial_extra) pair.
    """
    initial_state, initial_extra = super().startup(
        scenario, n, initial_state, initial_extra, **kwargs)

    # Consume one key for the potential evaluation, carry the other.
    initial_extra.random_key, pot_key = random.split(
        initial_extra.random_key)
    initial_state.potential, initial_state.grad_potential \
        = scenario.potential_and_grad(initial_state.value, pot_key)

    # Zero-initialise momenta when missing or of a stale dimension.
    needs_momenta = (not hasattr(initial_state, 'momenta')
                     or initial_state.momenta.shape[-1] != scenario.dim)
    if needs_momenta:
        initial_state.momenta = jnp.zeros(scenario.dim)
    return initial_state, initial_extra
def startup(self,
            scenario: Scenario,
            n: int,
            initial_state: cdict,
            initial_extra: cdict,
            **kwargs) -> Tuple[cdict, cdict]:
    """Initialise a tempered SMC run that drives an inner MCMC sampler.

    After generic setup (which is expected to populate prior_potential and
    likelihood_potential on the state), moves the ensemble to the first
    temperature, computes the corresponding importance weights and ESS,
    and runs the inner MCMC sampler's startup per particle.

    Args:
        scenario: Target scenario; its `temperature` attribute is set here.
        n: Number of particles.
        initial_state: Initial particle state.
        initial_extra: Initial extra carry (random key, parameters, ...).

    Returns:
        Updated (initial_state, initial_extra) pair.
    """
    # Normalise/validate the inner sampler's correction object first.
    self.mcmc_sampler.correction = check_correction(
        self.mcmc_sampler.correction)
    initial_state, initial_extra = super().startup(scenario, n,
                                                   initial_state,
                                                   initial_extra, **kwargs)
    # Pick the first tempering level and move the target onto it.
    first_temp = self.next_temperature(initial_state, initial_extra)
    scenario.temperature = first_temp
    initial_state.temperature += first_temp
    # Tempered potential: prior + temperature-scaled likelihood.
    initial_state.potential = initial_state.prior_potential \
        + first_temp * initial_state.likelihood_potential
    # Importance weights for moving from temperature 0 to first_temp.
    initial_state.log_weight = -first_temp \
        * initial_state.likelihood_potential
    # Same scalar ESS broadcast to every particle.
    initial_state.ess = jnp.repeat(
        jnp.exp(log_ess_log_weight(initial_state.log_weight)), n)
    # Run the inner sampler's startup per particle; initial_extra is
    # closed over, so the vmapped output stacks n identical extras.
    initial_state, initial_extra = vmap(
        lambda state: self.mcmc_sampler.startup(
            scenario, n, state, initial_extra))(initial_state)
    # Collapse the stacked extras back to a single carry.
    initial_extra = initial_extra[0]
    return initial_state, initial_extra
def update(self,
           scenario: Scenario,
           ensemble_state: cdict,
           extra: cdict) -> Tuple[cdict, cdict]:
    """Advance the particle ensemble by one optimisation step.

    Draws interaction batch indices, computes the kernelised gradient
    direction, pushes it through the optimiser and re-evaluates potentials
    at the new particle positions.

    Args:
        scenario: Target scenario providing `potential_and_grad`.
        ensemble_state: Current ensemble state (value, potential, grads).
        extra: Extra carry (random key, iteration counter, opt state, ...).

    Returns:
        Updated (ensemble_state, extra) pair.
    """
    n_particles = ensemble_state.value.shape[0]
    extra.iter += 1

    # One key per particle for the potential evaluations, one for the
    # batch-index draw, one carried forward.
    keys = random.split(extra.random_key, n_particles + 2)
    batch_inds = self.get_batch_inds(keys[-1])
    extra.random_key = keys[-2]

    update_direction = self.kernelised_grad_matrix(
        ensemble_state.value,
        ensemble_state.grad_potential,
        extra.parameters.kernel_params,
        batch_inds)

    # Optimisers minimise, so negate the update direction.
    extra.opt_state = self.opt_update(extra.iter, -update_direction,
                                      extra.opt_state)
    ensemble_state.value = self.get_params(extra.opt_state)
    ensemble_state.potential, ensemble_state.grad_potential = vmap(
        scenario.potential_and_grad)(ensemble_state.value,
                                     keys[:n_particles])

    ensemble_state, extra = self.adapt(ensemble_state, extra)
    return ensemble_state, extra