Пример #1
0
    def __init__(self, *args, **kwargs):
        """Build a shared 2-D Gaussian scenario and, if a sampler is set, pre-run it.

        Attributes set here:
            scenario_cov: covariance matrix of the Gaussian target.
            scenario: ``toy_examples.Gaussian`` built from ``scenario_cov``.
            n: chain length used for both pre-runs (1e5).
            __test__: True iff ``self.sampler`` is not None.
            adapt_sample / warmstart_sample: only set when ``__test__`` is True.

        When ``self.sampler`` is not None, two chains of length ``n`` are run.
        The first uses an ``RMMetropolis`` correction, and its last recorded
        stepsize is written back onto the sampler. The second uses a
        ``Metropolis`` correction with that stepsize.
        """
        super().__init__(*args, **kwargs)

        # Correlated 2-D Gaussian target.
        self.scenario_cov = jnp.array([[1., 0.9], [0.9, 2.]])
        self.scenario = toy_examples.Gaussian(covariance=self.scenario_cov)
        self.n = int(1e5)

        # Presumably lets an abstract base test class (sampler is None) be
        # skipped by the test collector -- TODO confirm against the runner used.
        self.__test__ = self.sampler is not None
        if self.__test__:
            # NOTE(review): if this class is a unittest.TestCase, __init__ runs
            # once per test method, so both 1e5-step chains are recomputed for
            # every test. Also, if ``sampler`` is a class attribute, the stepsize
            # assignment below leaks into later instances, whose adaptation
            # then starts from an already-adapted stepsize. Consider
            # setUpClass -- confirm against the enclosing class.
            self.adapt_sample = run(self.scenario, self.sampler, self.n, PRNGKey(0), correction=RMMetropolis())
            # Warm start: reuse the last stepsize recorded by the adaptive run.
            self.sampler.parameters.stepsize = self.adapt_sample.stepsize[-1]
            # Pass an instance, matching the adaptive run above, so this does
            # not rely on ``run`` instantiating a correction class itself.
            self.warmstart_sample = run(self.scenario, self.sampler, self.n, PRNGKey(0), correction=Metropolis())
Пример #2
0
 def test_adaptive_diag_stepsize(self):
     """Random-walk ABC with a diagonal Robbins-Monro-style stepsize correction.

     Sets the sampler's tuning target to 0.1 and runs it with
     ``RMMetropolisDiagStepsize(rm_stepsize_scale=0.1)``. It then checks the
     sample mean via ``self._test_mean``, and checks that the mean of ``alpha``
     (presumably per-step acceptance probabilities) is within numpy's
     ``decimal=1`` tolerance of the tuning target.
     """
     abc_sampler = RandomWalkABC()
     abc_sampler.tuning.target = 0.1
     diag_correction = RMMetropolisDiagStepsize(rm_stepsize_scale=0.1)
     result = run(self.scenario, abc_sampler, n=self.n,
                  random_key=PRNGKey(0), correction=diag_correction)
     self._test_mean(result.value)
     # Compare against the target as stored on the sampler after the run.
     npt.assert_almost_equal(result.alpha.mean(), abc_sampler.tuning.target, decimal=1)
Пример #3
0
 def test_callable_stepsize(self):
     """SVGD accepts a stepsize given as a function of the iteration index.

     The schedule is ``i ** -0.5``. The resulting ensemble is checked with the
     shared ``_test_mean`` / ``_test_cov`` helpers.
     """
     # NOTE(review): i ** -0.5 is infinite (or raises) at i == 0; presumably
     # SVGD's iteration counter starts at 1 -- confirm.
     def decaying_stepsize(i):
         return i ** -0.5

     svgd = SVGD(max_iter=self.n_iter, stepsize=decaying_stepsize)
     ensemble = run(self.scenario, svgd,
                    n=self.n,
                    random_key=random.PRNGKey(0))
     self._test_mean(ensemble)
     self._test_cov(ensemble)
Пример #4
0
 def test_threshold_preschedule(self):
     """Metropolised ABC-SMC driven by a fixed, decreasing threshold schedule.

     The schedule is ten thresholds evenly spaced from 10 down to 0.1. Mean
     and covariance are checked on the last entry of ``value`` (presumably
     the final SMC population).
     """
     sampler = MetropolisedABCSMCSampler(threshold_schedule=jnp.linspace(10, 0.1, 10))
     final_particles = run(self.scenario, sampler, n=self.n,
                           random_key=PRNGKey(0)).value[-1]
     self._test_mean(final_particles)
     self._test_cov(final_particles)
Пример #5
0
 def test_post_threshold(self):
     """Vanilla ABC configured by a target acceptance rate instead of a threshold.

     Checks that:
     * samples with finite log-weight pass ``_test_mean`` (extra argument 10.);
     * the fraction of zero log-weights matches the rate to ``decimal=3``;
     * the sampler ends up with a finite threshold (not ``jnp.inf``).
     """
     target_rate = 0.1
     abc = VanillaABC(acceptance_rate=target_rate)
     result = run(self.scenario, abc, n=self.n, random_key=PRNGKey(0))
     weights = result.log_weight
     accepted = weights > -jnp.inf
     self._test_mean(result.value[accepted], 10.)
     npt.assert_almost_equal(jnp.mean(weights == 0), target_rate, decimal=3)
     self.assertNotEqual(abc.parameters.threshold, jnp.inf)
Пример #6
0
 def test_adaptive(self):
     """Metropolised ABC-SMC using ``ess_threshold_retain=0.8``.

     Mean and covariance are checked on the last entry of ``value``
     (presumably the final SMC population).
     """
     sampler = MetropolisedABCSMCSampler(ess_threshold_retain=0.8)
     last = run(self.scenario, sampler, n=self.n, random_key=PRNGKey(0)).value[-1]
     for check in (self._test_mean, self._test_cov):
         check(last)
Пример #7
0
 def test_tempered_adaptive_OD(self):
     """Tempered SMC with overdamped Metropolised moves and ``ess_threshold=0.9``.

     The ESS threshold presumably drives an adaptive temperature schedule.
     The weighted output is resampled with ``self.resample_final`` before the
     mean and covariance checks.
     """
     smc = MetropolisedSMCSampler(Overdamped(stepsize=1.0))
     weighted = run(self.scenario, smc, self.n,
                    random_key=random.PRNGKey(0),
                    ess_threshold=0.9)
     resampled = self.resample_final(weighted)
     for check in (self._test_mean, self._test_cov):
         check(resampled)
Пример #8
0
 def test_ensemble_minibatch(self):
     """SVGD with ``ensemble_batchsize=100`` and ``n=1000``.

     ``n`` is hard-coded here rather than taken from ``self.n`` (presumably
     it is the ensemble size, so batching 100 of 1000 particles).
     """
     svgd = SVGD(max_iter=self.n_iter,
                 stepsize=1.0,
                 ensemble_batchsize=100)
     ensemble = run(self.scenario, svgd, n=1000, random_key=random.PRNGKey(0))
     self._test_mean(ensemble)
     self._test_cov(ensemble)
Пример #9
0
 def test_tempered_preschedule_OD(self):
     """Tempered SMC with overdamped moves and a user-supplied temperature schedule.

     After resampling and the mean/covariance checks, verifies that the
     recorded temperatures equal ``self.preschedule`` without its first entry
     (presumably the initial temperature, which is not recorded -- confirm).
     """
     schedule = self.preschedule
     smc = MetropolisedSMCSampler(Overdamped(stepsize=1.0))
     weighted = run(self.scenario, smc, self.n,
                    random_key=random.PRNGKey(0),
                    temperature_schedule=schedule)
     resampled = self.resample_final(weighted)
     self._test_mean(resampled)
     self._test_cov(resampled)
     npt.assert_array_equal(weighted.temperature, schedule[1:])
Пример #10
0
 def test_adaptive_diag_stepsize(self):
     """Random-walk ABC with a diagonal adaptive stepsize correction.

     Besides the mean/covariance checks and the acceptance-target check, this
     asserts that the final per-dimension stepsize matches the chain's
     marginal variances scaled by 2.38**2 / dim, to ``decimal=1``.
     """
     abc_sampler = RandomWalkABC()
     abc_sampler.tuning.target = 0.1
     result = run(self.scenario, abc_sampler, n=self.n,
                  random_key=PRNGKey(0), correction=RMMetropolisDiagStepsize())
     chain = result.value
     self._test_mean(chain)
     self._test_cov(chain)
     npt.assert_almost_equal(result.alpha.mean(), abc_sampler.tuning.target, decimal=1)
     # Diagonal of the empirical covariance (rows are samples), scaled by the
     # 2.38^2 / d factor familiar from random-walk Metropolis tuning.
     marginal_vars = jnp.diag(jnp.cov(chain, rowvar=False))
     expected_stepsize = marginal_vars / self.scenario.dim * 2.38 ** 2
     npt.assert_array_almost_equal(result.stepsize[-1], expected_stepsize, decimal=1)
Пример #11
0
    def test_fixed_kernel_params(self):
        """SVGD with its ``adapt`` step replaced by a no-op.

        A local ``SVGD`` subclass returns the state and extra unchanged from
        ``adapt``. The kernel parameters are therefore never adapted, assuming
        ``adapt`` is where that normally happens (TODO confirm). Mean and
        covariance are checked on ``sample[-1]``.
        """
        class SVGD_fixed(SVGD):
            def adapt(self, ensemble_state: cdict,
                      ensemble_extra: cdict) -> Tuple[cdict, cdict]:
                # Deliberately does not call the parent's adapt.
                return ensemble_state, ensemble_extra

        sampler = SVGD_fixed(max_iter=self.n_iter, stepsize=0.8)
        final = run(self.scenario, sampler,
                    n=self.n,
                    random_key=random.PRNGKey(0))[-1]
        self._test_mean(final)
        self._test_cov(final)
Пример #12
0
    def test_mean_update(self):
        """SVGD whose ``adapt`` resets the kernel bandwidth via ``mean_bandwidth_update``.

        The local subclass mutates ``ensemble_extra`` and ``ensemble_state`` in
        place and returns them. The resulting ensemble is checked with
        ``_test_mean`` / ``_test_cov``.
        """
        class SVGD_mean(SVGD):
            def adapt(self, ensemble_state: cdict,
                      ensemble_extra: cdict) -> Tuple[cdict, cdict]:
                # Recompute the bandwidth from the current particle values.
                new_bandwidth = mean_bandwidth_update(ensemble_state.value)
                ensemble_extra.parameters.kernel_params.bandwidth = new_bandwidth
                # Mirror the updated kernel parameters onto the state.
                ensemble_state.kernel_params = ensemble_extra.parameters.kernel_params
                return ensemble_state, ensemble_extra

        sampler = SVGD_mean(max_iter=self.n_iter, stepsize=1.0)
        ensemble = run(self.scenario, sampler,
                       n=self.n,
                       random_key=random.PRNGKey(0))
        self._test_mean(ensemble)
        self._test_cov(ensemble)
Пример #13
0
 def test_pre_threshold(self):
     """Vanilla ABC with a fixed threshold of 1000 supplied up front.

     Only samples with finite log-weight go to ``_test_mean``. The extra
     argument 10. is presumably a tolerance -- confirm against the helper.
     """
     abc = VanillaABC(threshold=1000.)
     result = run(self.scenario, abc, n=self.n, random_key=PRNGKey(0))
     finite = result.log_weight > -jnp.inf
     self._test_mean(result.value[finite], 10.)
Пример #14
0
 def test_fixed_stepsize(self):
     """Random-walk ABC with stepsize 0.1 and threshold ``self.threshold``.

     No stepsize adaptation is configured here. The chain's mean and
     covariance are checked with the shared helpers.
     """
     abc = RandomWalkABC(stepsize=0.1, threshold=self.threshold)
     chain = run(self.scenario, abc, n=self.n, random_key=PRNGKey(0)).value
     self._test_mean(chain)
     self._test_cov(chain)