def __init__(self, *args, **kwargs):
    """Set up a correlated 2-D Gaussian target and, if a sampler is configured, pre-run it.

    The target is ``toy_examples.Gaussian`` with covariance
    ``[[1., 0.9], [0.9, 2.]]`` and every run uses ``self.n = 100000`` steps.

    When ``self.sampler`` is not None two runs are performed:

    1. ``adapt_sample``: a run with the ``RMMetropolis()`` correction; the last
       entry of its ``stepsize`` trace is written back into
       ``self.sampler.parameters.stepsize``.
    2. ``warmstart_sample``: a run with the ``Metropolis()`` correction, which
       therefore starts from the adapted stepsize.
    """
    super().__init__(*args, **kwargs)
    self.scenario_cov = jnp.array([[1., 0.9], [0.9, 2.]])
    self.scenario = toy_examples.Gaussian(covariance=self.scenario_cov)
    self.n = int(1e5)

    # Only run the (expensive) samplers when a concrete sampler is configured.
    # NOTE(review): pytest decides collection from the *class* attribute
    # ``__test__``; setting it on the instance here only controls the guard
    # below — confirm the class itself sets ``__test__`` if collection of the
    # sampler-less base class should be suppressed.
    self.__test__ = self.sampler is not None
    if self.__test__:
        # NOTE(review): unittest instantiates the TestCase once per test
        # method, so both runs below are repeated for every test; consider
        # moving them to ``setUpClass`` if that cost matters.
        self.adapt_sample = run(self.scenario, self.sampler, self.n, PRNGKey(0),
                                correction=RMMetropolis())
        # Mutates the sampler's parameters in place; if ``sampler`` is a shared
        # class attribute this change is visible to every instance.
        self.sampler.parameters.stepsize = self.adapt_sample.stepsize[-1]
        # NOTE(review): this reuses PRNGKey(0), so the warm-start run shares its
        # randomness with the adaptation run — intentional for reproducibility?
        # Pass an instance, consistent with ``RMMetropolis()`` above.
        self.warmstart_sample = run(self.scenario, self.sampler, self.n, PRNGKey(0),
                                    correction=Metropolis())
def test_adaptive_diag_stepsize(self):
    """RandomWalkABC with the ``RMMetropolisDiagStepsize(rm_stepsize_scale=0.1)`` correction.

    The sampler's tuning target is set to 0.1. The test checks the sample mean,
    and checks that the mean of ``alpha`` matches that target to one decimal.

    NOTE(review): a second ``test_adaptive_diag_stepsize`` appears later in
    this chunk; if both sit in the same class, the later one shadows this one.
    """
    abc_sampler = RandomWalkABC()
    abc_sampler.tuning.target = 0.1
    diag_correction = RMMetropolisDiagStepsize(rm_stepsize_scale=0.1)
    result = run(self.scenario,
                 abc_sampler,
                 n=self.n,
                 random_key=PRNGKey(0),
                 correction=diag_correction)
    self._test_mean(result.value)
    # ``alpha`` is presumably the per-step acceptance probability — TODO confirm.
    observed_alpha = result.alpha.mean()
    npt.assert_almost_equal(observed_alpha, abc_sampler.tuning.target, decimal=1)
def test_callable_stepsize(self):
    """SVGD with a callable, iteration-dependent stepsize ``i ** -0.5``.

    The test checks the mean and covariance of the returned sample.
    """
    # NOTE(review): if the iteration counter passed here starts at 0, this
    # evaluates 0 ** -0.5 on the first step — confirm the counter's origin.
    def decaying_stepsize(i):
        return i ** -0.5

    svgd = SVGD(max_iter=self.n_iter, stepsize=decaying_stepsize)
    result = run(self.scenario, svgd, n=self.n, random_key=random.PRNGKey(0))
    self._test_mean(result)
    self._test_cov(result)
def test_threshold_preschedule(self):
    """ABC-SMC driven by a fixed threshold schedule of 10 values from 10 down to 0.1.

    Mean and covariance are checked on ``value[-1]``, the last entry along the
    leading axis (presumably the final SMC iteration).
    """
    schedule = jnp.linspace(10, 0.1, 10)
    smc = MetropolisedABCSMCSampler(threshold_schedule=schedule)
    out = run(self.scenario, smc, n=self.n, random_key=PRNGKey(0))
    last_particles = out.value[-1]
    self._test_mean(last_particles)
    self._test_cov(last_particles)
def test_post_threshold(self):
    """VanillaABC configured by a target acceptance rate of 0.1 rather than a threshold.

    Checks:
    - the mean of the samples with a finite log-weight (``_test_mean`` is
      given a second argument of 10.);
    - that the fraction of samples with log-weight exactly 0 equals the
      acceptance rate to 3 decimals;
    - that the sampler's threshold is not infinite after the run.
    """
    target_rate = 0.1
    abc = VanillaABC(acceptance_rate=target_rate)
    out = run(self.scenario, abc, n=self.n, random_key=PRNGKey(0))

    finite_mask = out.log_weight > -jnp.inf
    self._test_mean(out.value[finite_mask], 10.)

    zero_weight_fraction = jnp.mean(out.log_weight == 0)
    npt.assert_almost_equal(zero_weight_fraction, target_rate, decimal=3)
    self.assertNotEqual(abc.parameters.threshold, jnp.inf)
def test_adaptive(self):
    """ABC-SMC with adaptive thresholds controlled by ``ess_threshold_retain=0.8``.

    Mean and covariance are checked on ``value[-1]``.
    """
    retain = 0.8
    smc = MetropolisedABCSMCSampler(ess_threshold_retain=retain)
    out = run(self.scenario, smc, n=self.n, random_key=PRNGKey(0))
    last_particles = out.value[-1]
    self._test_mean(last_particles)
    self._test_cov(last_particles)
def test_tempered_adaptive_OD(self):
    """Tempered SMC with an ``Overdamped(stepsize=1.0)`` kernel and ``ess_threshold=0.9``.

    The final weighted sample is resampled via ``self.resample_final``. The
    test then checks the mean and covariance of the unweighted result.
    """
    smc = MetropolisedSMCSampler(Overdamped(stepsize=1.0))
    weighted = run(self.scenario, smc, n=self.n,
                   random_key=random.PRNGKey(0), ess_threshold=0.9)
    resampled = self.resample_final(weighted)
    for check in (self._test_mean, self._test_cov):
        check(resampled)
def test_ensemble_minibatch(self):
    """SVGD with ``ensemble_batchsize=100`` and a fixed ``n=1000``.

    ``n`` is hard-coded here rather than taken from ``self.n``. The test checks
    the mean and covariance of the returned sample.
    """
    fixed_n = 1000
    svgd = SVGD(max_iter=self.n_iter, stepsize=1.0, ensemble_batchsize=100)
    result = run(self.scenario, svgd, n=fixed_n, random_key=random.PRNGKey(0))
    self._test_mean(result)
    self._test_cov(result)
def test_tempered_preschedule_OD(self):
    """Tempered SMC with an ``Overdamped(stepsize=1.0)`` kernel on ``self.preschedule``.

    Checks the mean and covariance after ``resample_final``. It also checks
    that the recorded temperatures equal the preschedule without its first
    entry.
    """
    smc = MetropolisedSMCSampler(Overdamped(stepsize=1.0))
    weighted = run(self.scenario, smc, n=self.n,
                   random_key=random.PRNGKey(0),
                   temperature_schedule=self.preschedule)
    resampled = self.resample_final(weighted)
    for check in (self._test_mean, self._test_cov):
        check(resampled)
    expected_temperatures = self.preschedule[1:]
    npt.assert_array_equal(weighted.temperature, expected_temperatures)
def test_adaptive_diag_stepsize(self):
    """RandomWalkABC with a default ``RMMetropolisDiagStepsize()`` correction.

    The tuning target is set to 0.1. The test checks:
    - the sample mean and covariance;
    - that the mean ``alpha`` matches the target to one decimal;
    - that the final per-dimension stepsize is close to
      ``diag(cov(samples)) / dim * 2.38 ** 2`` (the familiar 2.38**2 / d
      random-walk scaling).
    """
    abc_sampler = RandomWalkABC()
    abc_sampler.tuning.target = 0.1
    out = run(self.scenario, abc_sampler,
              n=self.n,
              random_key=PRNGKey(0),
              correction=RMMetropolisDiagStepsize())

    draws = out.value
    self._test_mean(draws)
    self._test_cov(draws)
    npt.assert_almost_equal(out.alpha.mean(), abc_sampler.tuning.target, decimal=1)

    empirical_var = jnp.diag(jnp.cov(draws, rowvar=False))
    expected_stepsize = empirical_var / self.scenario.dim * 2.38 ** 2
    npt.assert_array_almost_equal(out.stepsize[-1], expected_stepsize, decimal=1)
def test_fixed_kernel_params(self):
    """SVGD whose ``adapt`` step is a no-op, with ``stepsize=0.8``.

    The test checks the mean and covariance of ``sample[-1]``.
    """
    class SVGD_fixed(SVGD):
        def adapt(self, ensemble_state: cdict, ensemble_extra: cdict) -> Tuple[cdict, cdict]:
            # Hand both inputs back untouched, so (presumably) the kernel
            # parameters keep their initial values throughout the run.
            return ensemble_state, ensemble_extra

    fixed_svgd = SVGD_fixed(max_iter=self.n_iter, stepsize=0.8)
    history = run(self.scenario, fixed_svgd, n=self.n, random_key=random.PRNGKey(0))
    self._test_mean(history[-1])
    self._test_cov(history[-1])
def test_mean_update(self):
    """SVGD whose ``adapt`` resets the kernel bandwidth with ``mean_bandwidth_update``.

    On each adapt call the bandwidth is recomputed from the current ensemble
    values. The kernel parameters are then copied onto the ensemble state.
    The test checks the mean and covariance of the result.
    """
    class SVGD_mean(SVGD):
        def adapt(self, ensemble_state: cdict, ensemble_extra: cdict) -> Tuple[cdict, cdict]:
            params = ensemble_extra.parameters
            params.kernel_params.bandwidth = mean_bandwidth_update(ensemble_state.value)
            ensemble_state.kernel_params = params.kernel_params
            return ensemble_state, ensemble_extra

    mean_svgd = SVGD_mean(max_iter=self.n_iter, stepsize=1.0)
    result = run(self.scenario, mean_svgd, n=self.n, random_key=random.PRNGKey(0))
    self._test_mean(result)
    self._test_cov(result)
def test_pre_threshold(self):
    """VanillaABC with a fixed threshold of 1000.

    Only samples with a finite log-weight are checked, via ``_test_mean``
    with a second argument of 10.
    """
    fixed_threshold = 1000.
    abc = VanillaABC(threshold=fixed_threshold)
    out = run(self.scenario, abc, n=self.n, random_key=PRNGKey(0))
    kept = out.value[out.log_weight > -jnp.inf]
    self._test_mean(kept, 10.)
def test_fixed_stepsize(self):
    """RandomWalkABC with a constant stepsize of 0.1 and the test's ``self.threshold``.

    The test checks the mean and covariance of the samples.
    """
    constant_step = 0.1
    abc = RandomWalkABC(stepsize=constant_step, threshold=self.threshold)
    out = run(self.scenario, abc, n=self.n, random_key=PRNGKey(0))
    draws = out.value
    self._test_mean(draws)
    self._test_cov(draws)