class TestAutocorrelation(unittest.TestCase):
    key = random.PRNGKey(0)
    ind_draws_arr = random.normal(key, (100,))
    ind_draws_cdict = cdict(value=ind_draws_arr)
    corr_draws_arr = 0.1 * jnp.cumsum(ind_draws_arr)
    corr_draws_cdict = cdict(value=corr_draws_arr)
    ind_draws_cdict_pot = cdict(value=corr_draws_arr, potential=ind_draws_arr)
    corr_draws_cdict_pot = cdict(value=ind_draws_arr, potential=corr_draws_arr)

    def test_array(self):
        ind_autocorr = metrics.autocorrelation(self.ind_draws_arr)
        self.assertEqual(ind_autocorr[0], 1.)
        npt.assert_array_equal(jnp.abs(ind_autocorr[1:]) < 0.3, True)

        corr_autocorr = metrics.autocorrelation(self.corr_draws_arr)
        self.assertEqual(corr_autocorr[0], 1.)
        npt.assert_array_equal(jnp.abs(corr_autocorr[1]) > 0.5, True)
        npt.assert_array_equal(jnp.abs(corr_autocorr[50]) < 0.5, True)
        npt.assert_array_equal(jnp.abs(corr_autocorr[-1]) < 0.3, True)

    def test_cdict_pot(self):
        # Potential is used in favour of value when both are present
        ind_autocorr = metrics.autocorrelation(self.ind_draws_cdict_pot)
        self.assertEqual(ind_autocorr[0], 1.)
        npt.assert_array_equal(jnp.abs(ind_autocorr[1:]) < 0.3, True)

        corr_autocorr = metrics.autocorrelation(self.corr_draws_cdict_pot)
        self.assertEqual(corr_autocorr[0], 1.)
        npt.assert_array_equal(jnp.abs(corr_autocorr[1]) > 0.5, True)
        npt.assert_array_equal(jnp.abs(corr_autocorr[50]) < 0.5, True)
        npt.assert_array_equal(jnp.abs(corr_autocorr[-1]) < 0.3, True)
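# Illustrative sketch (an assumption, not necessarily the library's
# metrics.autocorrelation, which may e.g. be FFT-based): a direct O(n^2)
# empirical autocorrelation, matching the behaviour the tests above rely on --
# exactly 1 at lag 0 and decaying magnitudes thereafter.
import jax.numpy as jnp

def autocorrelation_reference(draws: jnp.ndarray) -> jnp.ndarray:
    # Centre the chain, then compute the empirical autocovariance at each lag
    centred = draws - draws.mean()
    n = len(centred)
    acov = jnp.array([jnp.dot(centred[:n - k], centred[k:]) for k in range(n)])
    # Normalise so the lag-0 autocorrelation is exactly 1
    return acov / acov[0]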
class testKSDStdGaussian(unittest.TestCase):
    key = random.PRNGKey(0)
    dim = 2

    n_small = 10
    ind_draws_arr_n_small = random.normal(key, (n_small, dim))
    # For a standard Gaussian target, grad_potential(x) = x
    sample_n_small = cdict(value=ind_draws_arr_n_small,
                           grad_potential=ind_draws_arr_n_small)

    n_large = 1000
    ind_draws_arr_n_large = random.normal(key, (n_large, dim))
    sample_n_large = cdict(value=ind_draws_arr_n_large,
                           grad_potential=ind_draws_arr_n_large)

    def testksd_gaussian_kernel(self):
        kernel = kernels.Gaussian(bandwidth=1.)
        ksd_n_small_a = metrics.ksd(self.sample_n_small, kernel)
        ksd_n_large_a = metrics.ksd(self.sample_n_large, kernel)
        ksd_n_large_a_minibatch = metrics.ksd(self.sample_n_large, kernel,
                                              ensemble_batchsize=100,
                                              random_key=self.key)
        self.assertLess(ksd_n_large_a, ksd_n_small_a)
        npt.assert_almost_equal(ksd_n_large_a, ksd_n_large_a_minibatch, 1)

        kernel.parameters.bandwidth = 10.
        ksd_n_small_b = metrics.ksd(self.sample_n_small, kernel)
        ksd_n_large_b = metrics.ksd(self.sample_n_large, kernel)
        self.assertLess(ksd_n_large_b, ksd_n_small_b)

    def testksd_IMQ_kernel(self):
        kernel = kernels.IMQ(bandwidth=1., c=1., beta=-0.5)
        ksd_n_small_a = metrics.ksd(self.sample_n_small, kernel)
        ksd_n_large_a = metrics.ksd(self.sample_n_large, kernel)
        ksd_n_large_a_minibatch = metrics.ksd(self.sample_n_large, kernel,
                                              ensemble_batchsize=100,
                                              random_key=self.key)
        self.assertLess(ksd_n_large_a, ksd_n_small_a)
        npt.assert_almost_equal(ksd_n_large_a, ksd_n_large_a_minibatch, 1)

        kernel.parameters.bandwidth = 10.
        kernel.parameters.c = 0.1
        kernel.parameters.beta = -0.1
        ksd_n_small_b = metrics.ksd(self.sample_n_small, kernel)
        ksd_n_large_b = metrics.ksd(self.sample_n_large, kernel)
        self.assertLess(ksd_n_large_b, ksd_n_small_b)
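# Illustrative sketch (an assumption, not the library's metrics.ksd): the
# V-statistic kernel Stein discrepancy with a Gaussian kernel
# k(x, y) = exp(-||x - y||^2 / (2 h^2)), using the convention from the tests
# above that the score is -grad_potential (so grad_potential(x) = x for a
# standard Gaussian target).
import jax
import jax.numpy as jnp

def ksd_gaussian_reference(vals: jnp.ndarray,
                           grad_potentials: jnp.ndarray,
                           bandwidth: float = 1.) -> jnp.ndarray:
    d = vals.shape[-1]
    h2 = bandwidth ** 2

    def stein_kernel(x, y, sx, sy):
        # Langevin Stein kernel: s(x)^T s(y) k + s(x)^T grad_y k
        #                        + s(y)^T grad_x k + tr(grad_x grad_y k)
        diff = x - y
        sq_dist = jnp.dot(diff, diff)
        k = jnp.exp(-sq_dist / (2 * h2))
        grad_x_k = -diff / h2 * k
        grad_y_k = diff / h2 * k
        trace_term = (d / h2 - sq_dist / h2 ** 2) * k
        return (jnp.dot(sx, sy) * k + jnp.dot(sx, grad_y_k)
                + jnp.dot(sy, grad_x_k) + trace_term)

    scores = -grad_potentials
    # Average the Stein kernel over all pairs (V-statistic), then take the root
    pairwise = jax.vmap(lambda x, sx: jax.vmap(
        lambda y, sy: stein_kernel(x, y, sx, sy))(vals, scores))(vals, scores)
    return jnp.sqrt(pairwise.mean())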
    def startup(self,
                scenario: Scenario,
                n: int,
                initial_state: cdict,
                initial_extra: cdict,
                startup_correction: bool = True,
                **kwargs) -> Tuple[cdict, cdict]:
        # Initialise from the prior if implemented, otherwise at the origin
        if initial_state is None:
            if is_implemented(scenario.prior_sample):
                initial_extra.random_key, sub_key = random.split(initial_extra.random_key)
                init_vals = scenario.prior_sample(sub_key)
            else:
                init_vals = jnp.zeros(scenario.dim)
            initial_state = cdict(value=init_vals)

        self.max_iter = n - 1

        # Pop any correction passed via kwargs, then validate the correction
        if 'correction' in kwargs.keys():
            self.correction = kwargs['correction']
            del kwargs['correction']
        self.correction = check_correction(self.correction)

        initial_state, initial_extra = super().startup(scenario, n, initial_state, initial_extra, **kwargs)
        if startup_correction:
            initial_state, initial_extra = self.correction.startup(scenario, self, n,
                                                                   initial_state, initial_extra, **kwargs)
        return initial_state, initial_extra
    def startup(self,
                scenario: Scenario,
                n: int,
                initial_state: Union[None, cdict],
                initial_extra: cdict,
                **kwargs) -> Tuple[cdict, cdict]:
        # Set any attributes and parameters passed via kwargs
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)
            if hasattr(self, 'parameters') and hasattr(self.parameters, key):
                setattr(self.parameters, key, value)

        if not hasattr(self, 'max_iter') \
                or not (isinstance(self.max_iter, int)
                        or (isinstance(self.max_iter, jnp.ndarray) and self.max_iter.dtype == 'int32')):
            raise AttributeError(self.__repr__() + ' max_iter must be int')

        if not hasattr(initial_extra, 'iter'):
            initial_extra.iter = 0

        # Copy sampler parameters into initial_extra without overwriting any
        # already present
        if hasattr(self, 'parameters'):
            if not hasattr(initial_extra, 'parameters'):
                initial_extra.parameters = cdict()
            for key, value in self.parameters.__dict__.items():
                if not hasattr(initial_extra.parameters, key) or getattr(initial_extra.parameters, key) is None:
                    setattr(initial_extra.parameters, key, value)
        return initial_state, initial_extra
    def simulate(self,
                 t_all: jnp.ndarray,
                 random_key: jnp.ndarray) -> cdict:
        len_t = len(t_all)
        random_keys = random.split(random_key, 2 * len_t)
        latent_keys = random_keys[:len_t]
        obs_keys = random_keys[len_t:]

        # Simulate the latent process forwards with a scan over time indices
        x_init = self.initial_sample(t_all[0], latent_keys[0])

        def transition_body(x, i):
            new_x = self.transition_sample(x, t_all[i - 1], t_all[i], latent_keys[i])
            return new_x, new_x

        _, x_all_but_zero = scan(transition_body, x_init, jnp.arange(1, len_t))
        x_all = jnp.append(x_init[jnp.newaxis], x_all_but_zero, axis=0)

        # Simulate observations conditionally independently given the latent path
        y = vmap(self.likelihood_sample)(x_all, t_all, obs_keys)

        out_cdict = cdict(x=x_all, y=y, t=t_all, name=f'{self.name} simulation')
        return out_cdict
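# Hypothetical usage sketch for simulate: ToyRandomWalk is an assumption (not
# library code) and presumes StateSpaceModel can be subclassed by overriding
# just initial_sample, transition_sample and likelihood_sample with the
# signatures used above.
from jax import random
import jax.numpy as jnp

class ToyRandomWalk(StateSpaceModel):
    name = 'Toy random walk'
    dim = 1

    def initial_sample(self, t, key):
        return random.normal(key, (self.dim,))

    def transition_sample(self, x, t_prev, t_new, key):
        # Gaussian increments scaled by the time gap
        return x + jnp.sqrt(t_new - t_prev) * random.normal(key, (self.dim,))

    def likelihood_sample(self, x, t, key):
        # Observe the latent state under additive Gaussian noise
        return x + 0.1 * random.normal(key, (self.dim,))

# sim = ToyRandomWalk().simulate(jnp.arange(20.), random.PRNGKey(0))
# sim.x has shape (20, 1); sim.y holds one observation per time in sim.t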
    def resample_final(self, sample: cdict) -> cdict:
        # Multinomial resampling of the final-time particles
        # (fixed key, so deterministic across calls)
        unweighted_vals = sample.value[-1, random.categorical(random.PRNGKey(1),
                                                              logits=sample.log_weight[-1],
                                                              shape=(self.n,))]
        unweighted_sample = cdict(value=unweighted_vals)
        return unweighted_sample
class Testcdict(unittest.TestCase):
    cdict = core.cdict(test_arr=jnp.ones((10, 3)),
                       test_float=3.)

    def test_init(self):
        npt.assert_(hasattr(self.cdict, 'test_arr'))
        npt.assert_array_equal(self.cdict.test_arr, jnp.ones((10, 3)))
        npt.assert_(hasattr(self.cdict, 'test_float'))
        npt.assert_equal(self.cdict.test_float, 3.)

    def test_copy(self):
        cdict2 = self.cdict.copy()
        npt.assert_(isinstance(cdict2, core.cdict))
        npt.assert_(isinstance(cdict2.test_arr, jnp.DeviceArray))
        npt.assert_array_equal(cdict2.test_arr, jnp.ones((10, 3)))
        npt.assert_(isinstance(cdict2.test_float, float))
        npt.assert_equal(cdict2.test_float, 3.)
        # Mutating the copy must not affect the original
        cdict2.test_arr = jnp.zeros(5)
        npt.assert_array_equal(self.cdict.test_arr, jnp.ones((10, 3)))
        cdict2.test_float = 9.
        npt.assert_equal(self.cdict.test_float, 3.)

    def test_getitem(self):
        cdict_0get = self.cdict[0]
        npt.assert_(isinstance(cdict_0get, core.cdict))
        npt.assert_(isinstance(cdict_0get.test_arr, jnp.DeviceArray))
        npt.assert_array_equal(cdict_0get.test_arr, jnp.ones(3))
        npt.assert_(isinstance(cdict_0get.test_float, float))
        npt.assert_equal(cdict_0get.test_float, 3.)

    def test_additem(self):
        cdict_other = core.cdict(test_arr=jnp.ones((2, 3)),
                                 test_float=7.,
                                 time=25.)
        self.cdict.time = 10.
        cdict_add = self.cdict + cdict_other
        npt.assert_(isinstance(cdict_add, core.cdict))
        npt.assert_(isinstance(cdict_add.test_arr, jnp.DeviceArray))
        # Arrays are concatenated, times summed, other attributes kept from self
        npt.assert_array_equal(cdict_add.test_arr, jnp.ones((12, 3)))
        npt.assert_array_equal(cdict_add.time, 35.)
        npt.assert_(isinstance(cdict_add.test_float, float))
        npt.assert_equal(cdict_add.test_float, 3.)
        # The operands are left unchanged
        npt.assert_array_equal(self.cdict.test_arr, jnp.ones((10, 3)))
        npt.assert_equal(self.cdict.test_float, 3.)
        npt.assert_equal(self.cdict.time, 10.)
        del self.cdict.time
    def __init__(self,
                 threshold: float = None,
                 stepsize: float = None):
        super().__init__()
        self.parameters.threshold = threshold
        self.parameters.stepsize = stepsize
        self.tuning = cdict(parameter='threshold',
                            target=0.1,
                            metric='alpha',
                            monotonicity='increasing')
    def __init__(self, **kwargs):
        if not hasattr(self, 'tuning'):
            self.tuning = cdict(parameter='stepsize',
                                target=None,
                                metric='alpha',
                                monotonicity='decreasing')
        # Initiate sampler class (set any additional parameters from init)
        super().__init__(**kwargs)
class testSJD(unittest.TestCase):
    zeros_arr = jnp.zeros(10)
    zeros_cdict = cdict(value=zeros_arr)
    seq_arr = jnp.arange(10)
    seq_cdict = cdict(value=seq_arr)

    def test_array(self):
        sjd = metrics.squared_jumping_distance(self.zeros_arr)
        self.assertEqual(sjd, 0.)
        sjd = metrics.squared_jumping_distance(self.seq_arr)
        self.assertEqual(sjd, 1.)

    def test_cdict(self):
        sjd = metrics.squared_jumping_distance(self.zeros_cdict)
        self.assertEqual(sjd, 0.)
        sjd = metrics.squared_jumping_distance(self.seq_cdict)
        self.assertEqual(sjd, 1.)
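# Illustrative sketch (an assumption, not the library's implementation):
# squared jumping distance for a 1-D chain is the mean squared difference
# between consecutive draws, so a constant chain gives 0 and jnp.arange(10)
# gives 1, as asserted above.
import jax.numpy as jnp

def squared_jumping_distance_reference(draws: jnp.ndarray) -> jnp.ndarray:
    return jnp.square(draws[1:] - draws[:-1]).mean()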
class testIAT(unittest.TestCase):
    key = random.PRNGKey(0)
    ind_draws_arr = random.normal(key, (100,))
    ind_draws_cdict = cdict(value=ind_draws_arr)
    corr_draws_arr = 0.1 * jnp.cumsum(ind_draws_arr)

    def test_array(self):
        ind_iat = metrics.integrated_autocorrelation_time(self.ind_draws_arr)
        corr_iat = metrics.integrated_autocorrelation_time(self.corr_draws_arr)
        self.assertLess(ind_iat, 2.)
        self.assertGreater(corr_iat, 8.)
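# Illustrative sketch (an assumption, not the library's implementation): the
# integrated autocorrelation time is commonly estimated as
# 1 + 2 * (sum of autocorrelations up to a truncation lag). Independent draws
# give a value near 1; the strongly correlated cumsum chain a much larger one.
import jax.numpy as jnp

def iat_reference(autocorr: jnp.ndarray, max_lag: int = 50) -> jnp.ndarray:
    # autocorr as in metrics.autocorrelation above, with autocorr[0] == 1
    return 1. + 2. * jnp.sum(autocorr[1:max_lag + 1])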
    def __init__(self,
                 name: str = None,
                 **kwargs):
        if name is not None:
            self.name = name
        if not hasattr(self, 'parameters'):
            self.parameters = cdict()
        # kwargs matching existing attributes are set directly,
        # anything else is stored as a parameter
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)
            else:
                setattr(self.parameters, key, value)
    def startup(self,
                scenario: Scenario,
                n: int,
                initial_state: cdict,
                initial_extra: cdict,
                **kwargs) -> Tuple[cdict, cdict]:
        # Initialise n particles from the prior if implemented,
        # otherwise from a standard Gaussian
        if initial_state is None:
            initial_extra.random_key, sub_key = random.split(initial_extra.random_key)
            if is_implemented(scenario.prior_sample):
                init_vals = vmap(scenario.prior_sample)(random.split(sub_key, n))
            else:
                init_vals = random.normal(sub_key, shape=(n, scenario.dim))
            initial_state = cdict(value=init_vals)
        initial_state, initial_extra = super().startup(scenario, n, initial_state, initial_extra, **kwargs)
        return initial_state, initial_extra
class TestLeapfrog(unittest.TestCase):
    stepsize = 0.1
    leapfrog_steps = 3
    start_state = cdict(value=jnp.zeros(2),
                        potential=jnp.array(0.),
                        grad_potential=jnp.array([1., 2.]),
                        momenta=jnp.ones(2),
                        auxiliary_float=5.,
                        auxiliary_0darray=jnp.array(0.),
                        auxiliary_2darray=jnp.zeros(2))

    def test_leapfrog(self):
        full_state = utils.leapfrog(lambda x, _: (0., x),
                                    self.start_state,
                                    self.stepsize,
                                    random.split(random.PRNGKey(0), self.leapfrog_steps))
        out_state = full_state[-1]
        npt.assert_array_equal(out_state.value, jnp.array([0.28120947, 0.266409]))
        npt.assert_array_equal(out_state.grad_potential, jnp.array([0.28120947, 0.266409]))
        npt.assert_array_equal(out_state.momenta, jnp.array([0.9075344, 0.8597696]))

    def test_all_leapfrog(self):
        full_state = utils.leapfrog(lambda x, _: (0., x),
                                    self.start_state,
                                    self.stepsize,
                                    random.split(random.PRNGKey(0), self.leapfrog_steps))
        # Positions are stored per full step, momenta per half step
        npt.assert_array_equal(full_state.value.shape, (self.leapfrog_steps + 1, 2))
        npt.assert_array_equal(full_state.momenta.shape, (2 * self.leapfrog_steps + 1, 2))
        npt.assert_array_equal(full_state.value[0], jnp.zeros(2))
        npt.assert_array_equal(full_state.value[-1], jnp.array([0.28120947, 0.266409]))
        npt.assert_array_equal(full_state.momenta[0], jnp.ones(2))
        npt.assert_array_equal(full_state.momenta[-1], jnp.array([0.9075344, 0.8597696]))
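# Illustrative sketch (an assumption, simplified relative to utils.leapfrog,
# which also stacks intermediate states, threads random keys and carries the
# full cdict state): one leapfrog / velocity-Verlet step with identity mass
# matrix. Note the first half-step uses the supplied gradient, consistent with
# the stored grad_potential of [1., 2.] feeding the expected values above.
import jax.numpy as jnp

def leapfrog_step_reference(value: jnp.ndarray,
                            momenta: jnp.ndarray,
                            grad: jnp.ndarray,
                            grad_potential_fn,
                            stepsize: float):
    # Half step in momenta using the current gradient
    momenta = momenta - stepsize / 2 * grad
    # Full step in position
    value = value + stepsize * momenta
    # Recompute the gradient at the new position, second half step in momenta
    grad = grad_potential_fn(value)
    momenta = momenta - stepsize / 2 * grad
    return value, momenta, grad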
def run(scenario: Scenario,
        sampler: Union[Sampler, Type[Sampler]],
        n: int,
        random_key: Union[None, jnp.ndarray],
        initial_state: cdict = None,
        initial_extra: cdict = None,
        **kwargs) -> Union[cdict, Tuple[cdict, jnp.ndarray]]:
    # Accept either a sampler instance or a sampler class to be instantiated
    if isclass(sampler):
        sampler = sampler(**kwargs)
    sampler.n = n

    if initial_extra is None:
        initial_extra = cdict()
    if random_key is not None:
        initial_extra.random_key = random_key

    initial_state, initial_extra = sampler.startup(scenario, n, initial_state, initial_extra, **kwargs)
    summary = sampler.summary(scenario, initial_state, initial_extra)
    transport_kernel = partial(sampler.update, scenario)

    start = time()
    chain = while_loop_stacked(lambda state, extra: ~sampler.termination_criterion(state, extra),
                               transport_kernel,
                               (initial_state, initial_extra),
                               sampler.max_iter)
    # Prepend the initial state and tidy the chain
    chain = initial_state[jnp.newaxis] + chain
    chain = sampler.clean_chain(scenario, chain)
    # block_until_ready ensures the asynchronous JAX computation has finished
    # before the timer stops
    chain.value.block_until_ready()
    end = time()
    chain.time = end - start
    chain.summary = summary
    return chain
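# Hypothetical usage sketch (GaussianScenario and RandomWalk are assumptions,
# standing in for whatever scenario and sampler are in scope):
#
#     from jax import random
#     chain = run(GaussianScenario(), RandomWalk(stepsize=0.5),
#                 n=1000, random_key=random.PRNGKey(0))
#     chain.value   # stacked chain states, leading axis of length n
#     chain.time    # wall-clock sampling time recorded above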
def initiate_particles(ssm_scenario: StateSpaceModel,
                       particle_filter: ParticleFilter,
                       n: int,
                       random_key: jnp.ndarray,
                       y: jnp.ndarray = None,
                       t: float = None) -> cdict:
    particle_filter.startup(ssm_scenario)

    sub_keys = random.split(random_key, n)
    init_vals, init_log_weight = particle_filter.initial_sample_and_weight_vectorised(
        ssm_scenario, y, t, sub_keys)

    if init_vals.ndim == 1:
        init_vals = init_vals[..., jnp.newaxis]

    # Store with a leading time axis of length 1
    initial_sample = cdict(value=init_vals[jnp.newaxis],
                           log_weight=init_log_weight[jnp.newaxis],
                           t=jnp.atleast_1d(t) if t is not None else jnp.zeros(1),
                           y=y[jnp.newaxis] if y is not None else None,
                           ess=jnp.atleast_1d(ess_log_weight(init_log_weight)))
    return initial_sample
    def startup(self,
                abc_scenario: ABCScenario,
                n: int,
                initial_state: cdict = None,
                initial_extra: cdict = None,
                **kwargs) -> Tuple[cdict, cdict]:
        if initial_state is None:
            if is_implemented(abc_scenario.prior_sample):
                initial_extra.random_key, sub_key = random.split(initial_extra.random_key)
                init_vals = abc_scenario.prior_sample(sub_key)
            else:
                init_vals = jnp.zeros(abc_scenario.dim)
            initial_state = cdict(value=init_vals)

        self.max_iter = n - 1

        initial_state, initial_extra = super().startup(abc_scenario, n, initial_state, initial_extra, **kwargs)
        initial_state.log_weight = self.log_weight(abc_scenario, initial_state, initial_extra)
        return initial_state, initial_extra
def propagate_particle_smoother_bs(ssm_scenario: StateSpaceModel,
                                   particle_filter: ParticleFilter,
                                   particles: cdict,
                                   y_new: jnp.ndarray,
                                   t_new: float,
                                   random_key: jnp.ndarray,
                                   lag: int,
                                   ess_threshold: float,
                                   maximum_rejections: int,
                                   init_bound_param: float,
                                   bound_inflation: float) -> cdict:
    n = particles.value.shape[1]

    if not hasattr(particles, 'num_transition_evals'):
        particles.num_transition_evals = jnp.array(0)
    if not hasattr(particles, 'marginal_filter'):
        particles.marginal_filter = cdict(value=particles.value,
                                          log_weight=particles.log_weight,
                                          y=particles.y,
                                          t=particles.t,
                                          ess=particles.ess)

    split_keys = random.split(random_key, 4)

    out_particles = particles

    # Propagate marginal filter particles
    out_particles.marginal_filter = propagate_particle_filter(ssm_scenario,
                                                              particle_filter,
                                                              particles.marginal_filter,
                                                              y_new,
                                                              t_new,
                                                              split_keys[1],
                                                              ess_threshold,
                                                              False)

    out_particles.y = jnp.append(out_particles.y, y_new[jnp.newaxis], axis=0)
    out_particles.t = jnp.append(out_particles.t, t_new)
    out_particles.log_weight = jnp.zeros(n)
    out_particles.ess = out_particles.marginal_filter.ess[-1]

    len_t = len(out_particles.t)
    stitch_ind_min_1 = len_t - lag - 1
    stitch_ind = len_t - lag

    def back_sim_only(marginal_filter):
        # Within the lag window: backward simulate over the full history
        backward_sim = backward_simulation(ssm_scenario,
                                           marginal_filter,
                                           split_keys[2],
                                           n,
                                           maximum_rejections,
                                           init_bound_param,
                                           bound_inflation)
        return backward_sim.value, backward_sim.num_transition_evals.sum()

    def back_sim_and_stitch(marginal_filter):
        # Beyond the lag window: backward simulate over the window only,
        # then stitch onto the fixed historical trajectories
        backward_sim = backward_simulation(ssm_scenario,
                                           marginal_filter[stitch_ind_min_1:],
                                           split_keys[2],
                                           n,
                                           maximum_rejections,
                                           init_bound_param,
                                           bound_inflation)
        vals, stitch_nte = fixed_lag_stitching(ssm_scenario,
                                               out_particles.value[:(stitch_ind_min_1 + 1)],
                                               out_particles.t[stitch_ind_min_1],
                                               backward_sim.value,
                                               jnp.zeros(n),
                                               out_particles.t[stitch_ind],
                                               random_key,
                                               maximum_rejections,
                                               init_bound_param,
                                               bound_inflation)
        return vals, stitch_nte + backward_sim.num_transition_evals.sum()

    if stitch_ind_min_1 >= 0:
        out_particles.value, num_transition_evals = back_sim_and_stitch(out_particles.marginal_filter)
    else:
        out_particles.value, num_transition_evals = back_sim_only(out_particles.marginal_filter)

    out_particles.num_transition_evals = jnp.append(out_particles.num_transition_evals,
                                                    num_transition_evals)

    # out_particles.value = cond(stitch_ind_min_1 >= 0,
    #                            back_sim_and_stitch,
    #                            back_sim_only,
    #                            out_particles.marginal_filter)
    return out_particles
    def __init__(self, **kwargs):
        if not hasattr(self, 'parameters'):
            self.parameters = cdict()
        self.parameters.__dict__.update(kwargs)