Ejemplo n.º 1
0
class TestAutocorrelation(unittest.TestCase):
    """Autocorrelation of independent draws versus a scaled random walk."""

    # Independent standard-normal draws and a random walk built from them.
    key = random.PRNGKey(0)
    ind_draws_arr = random.normal(key, (100, ))
    ind_draws_cdict = cdict(value=ind_draws_arr)

    corr_draws_arr = 0.1 * jnp.cumsum(ind_draws_arr)
    corr_draws_cdict = cdict(value=corr_draws_arr)

    # In these cdicts `value` and `potential` deliberately carry opposite
    # series; test_cdict_pot expects the autocorrelation to follow `potential`.
    ind_draws_cdict_pot = cdict(value=corr_draws_arr, potential=ind_draws_arr)
    corr_draws_cdict_pot = cdict(value=ind_draws_arr, potential=corr_draws_arr)

    def _check_independent(self, draws):
        """Lag 0 is exactly 1 and every other lag is small."""
        acf = metrics.autocorrelation(draws)
        self.assertEqual(acf[0], 1.)
        npt.assert_array_equal(jnp.abs(acf[1:]) < 0.3, True)

    def _check_correlated(self, draws):
        """Lag 0 is exactly 1, lag 1 is large, and the tail decays."""
        acf = metrics.autocorrelation(draws)
        self.assertEqual(acf[0], 1.)
        npt.assert_array_equal(jnp.abs(acf[1]) > 0.5, True)
        npt.assert_array_equal(jnp.abs(acf[50]) < 0.5, True)
        npt.assert_array_equal(jnp.abs(acf[-1]) < 0.3, True)

    def test_array(self):
        self._check_independent(self.ind_draws_arr)
        self._check_correlated(self.corr_draws_arr)

    def test_cdict_pot(self):
        # Potential
        self._check_independent(self.ind_draws_cdict_pot)
        self._check_correlated(self.corr_draws_cdict_pot)
Ejemplo n.º 2
0
class testKSDStdGaussian(unittest.TestCase):
    """Kernel Stein discrepancy (KSD) of standard-Gaussian draws.

    For a standard Gaussian target the gradient of the potential at ``x`` is
    ``x`` itself, which is why ``grad_potential`` is set to the draws.
    """
    key = random.PRNGKey(0)
    dim = 2

    n_small = 10
    ind_draws_arr_n_small = random.normal(key, (n_small, dim))
    sample_n_small = cdict(value=ind_draws_arr_n_small,
                           grad_potential=ind_draws_arr_n_small)

    n_large = 1000
    ind_draws_arr_n_large = random.normal(key, (n_large, dim))
    sample_n_large = cdict(value=ind_draws_arr_n_large,
                           grad_potential=ind_draws_arr_n_large)

    def testksd_gaussian_kernel(self):
        kernel = kernels.Gaussian(bandwidth=1.)
        ksd_n_small_a = metrics.ksd(self.sample_n_small, kernel)
        ksd_n_large_a = metrics.ksd(self.sample_n_large, kernel)
        # The mini-batched estimate must be computed on the *large* sample
        # (batch 100 < n_large) to be comparable with ksd_n_large_a.
        ksd_n_large_a_minibatch = metrics.ksd(self.sample_n_large,
                                              kernel,
                                              ensemble_batchsize=100,
                                              random_key=self.key)

        # More samples from the target -> smaller discrepancy.
        self.assertLess(ksd_n_large_a, ksd_n_small_a)
        npt.assert_almost_equal(ksd_n_large_a, ksd_n_large_a_minibatch, 1)

        kernel.parameters.bandwidth = 10.
        ksd_n_small_b = metrics.ksd(self.sample_n_small, kernel)
        ksd_n_large_b = metrics.ksd(self.sample_n_large, kernel)

        self.assertLess(ksd_n_large_b, ksd_n_small_b)

    def testksd_IMQ_kernel(self):
        kernel = kernels.IMQ(bandwidth=1., c=1., beta=-0.5)
        ksd_n_small_a = metrics.ksd(self.sample_n_small, kernel)
        ksd_n_large_a = metrics.ksd(self.sample_n_large, kernel)
        # Mini-batch on the large sample, as in the Gaussian-kernel test.
        ksd_n_large_a_minibatch = metrics.ksd(self.sample_n_large,
                                              kernel,
                                              ensemble_batchsize=100,
                                              random_key=self.key)

        self.assertLess(ksd_n_large_a, ksd_n_small_a)
        npt.assert_almost_equal(ksd_n_large_a, ksd_n_large_a_minibatch, 1)

        kernel.parameters.bandwidth = 10.
        kernel.parameters.c = 0.1
        kernel.parameters.beta = -0.1

        ksd_n_small_b = metrics.ksd(self.sample_n_small, kernel)
        ksd_n_large_b = metrics.ksd(self.sample_n_large, kernel)

        self.assertLess(ksd_n_large_b, ksd_n_small_b)
Ejemplo n.º 3
0
    def startup(self,
                scenario: Scenario,
                n: int,
                initial_state: cdict,
                initial_extra: cdict,
                startup_correction: bool = True,
                **kwargs) -> Tuple[cdict, cdict]:
        """Prepare the initial state and extra for a run of ``n`` samples.

        If ``initial_state`` is None, its value is drawn from
        ``scenario.prior_sample`` with a key split off
        ``initial_extra.random_key`` (which is advanced in place). When no prior
        sampler is implemented, the value is ``jnp.zeros(scenario.dim)``.

        A ``correction`` keyword, if given, replaces ``self.correction`` and is
        not forwarded to later ``startup`` calls. ``self.correction`` is then
        passed through ``check_correction``. After the parent ``startup``,
        ``self.correction.startup`` runs unless ``startup_correction`` is False.

        Returns:
            ``(initial_state, initial_extra)``.
        """
        if initial_state is None:
            if not is_implemented(scenario.prior_sample):
                start_value = jnp.zeros(scenario.dim)
            else:
                initial_extra.random_key, prior_key = random.split(
                    initial_extra.random_key)
                start_value = scenario.prior_sample(prior_key)
            initial_state = cdict(value=start_value)

        # n counts the initial state, so n - 1 further iterations are needed.
        self.max_iter = n - 1

        # Consume the override so it is not forwarded via **kwargs below.
        if 'correction' in kwargs:
            self.correction = kwargs.pop('correction')
        self.correction = check_correction(self.correction)

        initial_state, initial_extra = super().startup(
            scenario, n, initial_state, initial_extra, **kwargs)

        if startup_correction:
            initial_state, initial_extra = self.correction.startup(
                scenario, self, n, initial_state, initial_extra, **kwargs)

        return initial_state, initial_extra
Ejemplo n.º 4
0
    def startup(self,
                scenario: Scenario,
                n: int,
                initial_state: Union[None, cdict],
                initial_extra: cdict,
                **kwargs) -> Tuple[cdict, cdict]:
        """Apply keyword overrides, validate ``max_iter`` and seed ``initial_extra``.

        Args:
            scenario: Not used here.
            n: Not used here.
            initial_state: Returned unchanged.
            initial_extra: Updated in place.
                ``iter`` is set to 0 if absent.
                ``parameters`` is created as an empty cdict if absent, then
                receives every entry of ``self.parameters`` that it lacks or
                holds as ``None``.
            **kwargs: Each key overwrites an existing attribute of ``self``
                and/or an existing field of ``self.parameters`` with that name.

        Returns:
            ``(initial_state, initial_extra)``.

        Raises:
            AttributeError: if ``self.max_iter`` is missing, or is neither a
                Python ``int`` nor a JAX array with an integer dtype.
        """
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)
            if hasattr(self, 'parameters') and hasattr(self.parameters, key):
                setattr(self.parameters, key, value)

        # Accept any integer dtype (e.g. int64 when jax x64 is enabled), not
        # just int32.
        max_iter = getattr(self, 'max_iter', None)
        max_iter_is_int = isinstance(max_iter, int) or (
            isinstance(max_iter, jnp.ndarray)
            and jnp.issubdtype(max_iter.dtype, jnp.integer))
        if not max_iter_is_int:
            raise AttributeError(self.__repr__() + ' max_iter must be int')

        if not hasattr(initial_extra, 'iter'):
            initial_extra.iter = 0

        if hasattr(self, 'parameters'):
            if not hasattr(initial_extra, 'parameters'):
                initial_extra.parameters = cdict()
            # Only fill gaps: values already set in initial_extra.parameters win.
            for key, value in self.parameters.__dict__.items():
                if getattr(initial_extra.parameters, key, None) is None:
                    setattr(initial_extra.parameters, key, value)
        return initial_state, initial_extra
Ejemplo n.º 5
0
    def simulate(self, t_all: jnp.ndarray, random_key: jnp.ndarray) -> cdict:
        """Simulate one latent trajectory and its observations at times ``t_all``.

        ``random.split(random_key, 2 * len(t_all))`` is divided into two halves.
        The first half drives the latent chain (initial draw and transitions);
        the second half drives the observation draws.

        Returns:
            cdict with fields ``x`` (latent states, initial state first),
            ``y`` (observations), ``t`` (= ``t_all``) and ``name``.
        """
        num_times = len(t_all)

        all_keys = random.split(random_key, 2 * num_times)
        latent_keys, obs_keys = all_keys[:num_times], all_keys[num_times:]

        x_first = self.initial_sample(t_all[0], latent_keys[0])

        def step(x_prev, inputs):
            # One transition from t_prev to t_new using its dedicated key.
            t_prev, t_new, step_key = inputs
            x_new = self.transition_sample(x_prev, t_prev, t_new, step_key)
            return x_new, x_new

        _, x_rest = scan(step, x_first,
                         (t_all[:-1], t_all[1:], latent_keys[1:]))

        x_all = jnp.append(x_first[jnp.newaxis], x_rest, axis=0)
        y_all = vmap(self.likelihood_sample)(x_all, t_all, obs_keys)

        return cdict(x=x_all,
                     y=y_all,
                     t=t_all,
                     name=f'{self.name} simulation')
Ejemplo n.º 6
0
 def resample_final(self, sample: cdict, random_key: jnp.ndarray = None) -> cdict:
     """Draw ``self.n`` unweighted particles from the final weighted step.

     Indices are sampled categorically with logits ``sample.log_weight[-1]``
     and used to index ``sample.value[-1]``.

     Args:
         sample: cdict whose ``value`` and ``log_weight`` are indexed with
             ``[-1]`` to select the final step.
         random_key: Optional PRNG key. ``None`` keeps the previous fixed key
             ``random.PRNGKey(1)``, which yields the same resample on every
             call. Pass a fresh key for independent resamples.

     Returns:
         cdict containing only ``value`` (the resampled particles).
     """
     if random_key is None:
         # Backward-compatible default; previously this key was hard-coded.
         random_key = random.PRNGKey(1)
     final_log_weight = sample.log_weight[-1]
     indices = random.categorical(random_key,
                                  logits=final_log_weight,
                                  shape=(self.n, ))
     return cdict(value=sample.value[-1, indices])
Ejemplo n.º 7
0
class Testcdict(unittest.TestCase):
    """Behaviour of ``core.cdict``: construction, copy, indexing and addition."""

    # Shared class-level fixture: tests must leave it unchanged.
    cdict = core.cdict(test_arr=jnp.ones((10, 3)), test_float=3.)

    def test_init(self):
        npt.assert_(hasattr(self.cdict, 'test_arr'))
        npt.assert_array_equal(self.cdict.test_arr, jnp.ones((10, 3)))

        npt.assert_(hasattr(self.cdict, 'test_float'))
        npt.assert_equal(self.cdict.test_float, 3.)

    def test_copy(self):
        cdict2 = self.cdict.copy()
        npt.assert_(isinstance(cdict2, core.cdict))

        # `jnp.ndarray` is the JAX array type in both old and new releases,
        # whereas `jnp.DeviceArray` has been removed from recent JAX.
        npt.assert_(isinstance(cdict2.test_arr, jnp.ndarray))
        npt.assert_array_equal(cdict2.test_arr, jnp.ones((10, 3)))

        npt.assert_(isinstance(cdict2.test_float, float))
        npt.assert_equal(cdict2.test_float, 3.)

        # Rebinding fields on the copy must not affect the original.
        cdict2.test_arr = jnp.zeros(5)
        npt.assert_array_equal(self.cdict.test_arr, jnp.ones((10, 3)))

        cdict2.test_float = 9.
        npt.assert_equal(self.cdict.test_float, 3.)

    def test_getitem(self):
        # Indexing slices the array field; the float field is kept as is.
        cdict_0get = self.cdict[0]
        npt.assert_(isinstance(cdict_0get, core.cdict))

        npt.assert_(isinstance(cdict_0get.test_arr, jnp.ndarray))
        npt.assert_array_equal(cdict_0get.test_arr, jnp.ones(3))

        npt.assert_(isinstance(cdict_0get.test_float, float))
        npt.assert_equal(cdict_0get.test_float, 3.)

    def test_additem(self):
        cdict_other = core.cdict(test_arr=jnp.ones((2, 3)),
                                 test_float=7.,
                                 time=25.)

        # `time` is added to the shared fixture only for this test; `finally`
        # removes it even if an assertion fails, so other tests never see it.
        self.cdict.time = 10.
        try:
            cdict_add = self.cdict + cdict_other

            npt.assert_(isinstance(cdict_add, core.cdict))

            # Arrays are stacked (10 + 2 rows), `time` is summed and the
            # left operand's float is kept.
            npt.assert_(isinstance(cdict_add.test_arr, jnp.ndarray))
            npt.assert_array_equal(cdict_add.test_arr, jnp.ones((12, 3)))
            npt.assert_array_equal(cdict_add.time, 35.)

            npt.assert_(isinstance(cdict_add.test_float, float))
            npt.assert_equal(cdict_add.test_float, 3.)

            # Addition must not mutate the left operand.
            npt.assert_array_equal(self.cdict.test_arr, jnp.ones((10, 3)))
            npt.assert_equal(self.cdict.test_float, 3.)
            npt.assert_equal(self.cdict.time, 10.)
        finally:
            del self.cdict.time
Ejemplo n.º 8
0
    def __init__(self, threshold: float = None, stepsize: float = None):
        """Store ``threshold`` and ``stepsize`` as parameters and set the tuning spec.

        The tuning spec names ``threshold`` as the tuned parameter, with target
        0.1 for metric ``'alpha'`` and monotonicity ``'increasing'``. Those
        fields are interpreted by the tuner elsewhere.
        """
        super().__init__()
        params = self.parameters
        params.threshold = threshold
        params.stepsize = stepsize

        self.tuning = cdict(
            parameter='threshold',
            target=0.1,
            metric='alpha',
            monotonicity='increasing',
        )
Ejemplo n.º 9
0
    def __init__(self, **kwargs):
        """Install a default step-size tuning spec, then run the parent initialiser.

        A ``tuning`` attribute that already exists (e.g. set by a subclass before
        calling this) is left untouched. All keyword arguments are forwarded to
        ``super().__init__``.
        """
        already_tuned = hasattr(self, 'tuning')
        if not already_tuned:
            self.tuning = cdict(
                parameter='stepsize',
                target=None,
                metric='alpha',
                monotonicity='decreasing',
            )

        # Parent initialiser applies any additional parameters from kwargs.
        super().__init__(**kwargs)
Ejemplo n.º 10
0
class testSJD(unittest.TestCase):
    """Squared jumping distance on a constant chain and a unit-step chain."""

    zeros_arr = jnp.zeros(10)
    zeros_cdict = cdict(value=zeros_arr)
    seq_arr = jnp.arange(10)
    seq_cdict = cdict(value=seq_arr)

    def _assert_sjd(self, draws, expected):
        sjd = metrics.squared_jumping_distance(draws)
        self.assertEqual(sjd, expected)

    def test_array(self):
        # A constant chain never moves.
        self._assert_sjd(self.zeros_arr, 0.)
        # Consecutive integers differ by exactly 1, so every squared jump is 1.
        self._assert_sjd(self.seq_arr, 1.)

    def test_cdict(self):
        self._assert_sjd(self.zeros_cdict, 0.)
        self._assert_sjd(self.seq_cdict, 1.)
Ejemplo n.º 11
0
class testIAT(unittest.TestCase):
    """Integrated autocorrelation time of independent versus random-walk draws."""

    key = random.PRNGKey(0)
    ind_draws_arr = random.normal(key, (100, ))
    ind_draws_cdict = cdict(value=ind_draws_arr)

    corr_draws_arr = 0.1 * jnp.cumsum(ind_draws_arr)

    def test_array(self):
        iat = metrics.integrated_autocorrelation_time
        iat_independent = iat(self.ind_draws_arr)
        iat_random_walk = iat(self.corr_draws_arr)

        # Independent draws should have a small IAT...
        self.assertLess(iat_independent, 2.)
        # ...while the cumulative-sum chain is strongly autocorrelated.
        self.assertGreater(iat_random_walk, 8.)
Ejemplo n.º 12
0
    def __init__(self,
                 name: str = None,
                 **kwargs):
        """Optionally set ``name``, ensure ``self.parameters`` exists, then apply kwargs.

        A keyword matching an existing attribute of ``self`` overwrites that
        attribute. Every other keyword is stored on ``self.parameters``.
        """
        if name is not None:
            self.name = name

        if not hasattr(self, 'parameters'):
            self.parameters = cdict()

        for attr_name, attr_value in kwargs.items():
            # `self.parameters` is re-read each time, as it may itself be overridden.
            destination = self if hasattr(self, attr_name) else self.parameters
            setattr(destination, attr_name, attr_value)
Ejemplo n.º 13
0
    def startup(self, scenario: Scenario, n: int, initial_state: cdict,
                initial_extra: cdict, **kwargs) -> Tuple[cdict, cdict]:
        """Create an ``n``-particle initial state if needed, then defer to the parent.

        Without an ``initial_state``, a sub-key is split off
        ``initial_extra.random_key`` (which is advanced in place). It is used
        either to draw ``n`` values via a vmapped ``scenario.prior_sample``
        (one split key per particle) or, if no prior sampler is implemented,
        ``random.normal`` values of shape ``(n, scenario.dim)``.

        Returns:
            ``(initial_state, initial_extra)`` from the parent ``startup``.
        """
        if initial_state is None:
            initial_extra.random_key, draw_key = random.split(
                initial_extra.random_key)
            if is_implemented(scenario.prior_sample):
                particle_keys = random.split(draw_key, n)
                start_values = vmap(scenario.prior_sample)(particle_keys)
            else:
                start_values = random.normal(draw_key, shape=(n, scenario.dim))
            initial_state = cdict(value=start_values)

        state, extra = super().startup(scenario, n, initial_state,
                                       initial_extra, **kwargs)
        return state, extra
Ejemplo n.º 14
0
class TestLeapfrog(unittest.TestCase):
    """Leapfrog integration checked against stored reference values.

    The integrand ``lambda x, _: (0., x)`` presumably returns
    ``(potential, grad_potential)``. The test checks that the final
    ``grad_potential`` equals the final ``value``.
    """
    stepsize = 0.1
    leapfrog_steps = 3

    start_state = cdict(value=jnp.zeros(2),
                        potential=jnp.array(0.),
                        grad_potential=jnp.array([1., 2.]),
                        momenta=jnp.ones(2),
                        auxiliary_float=5.,
                        auxiliary_0darray=jnp.array(0.),
                        auxiliary_2darray=jnp.zeros(2))

    # Reference values at the end of the trajectory.
    final_value = jnp.array([0.28120947, 0.266409])
    final_momenta = jnp.array([0.9075344, 0.8597696])

    # Floating-point results are compared with a tolerance, not bit-exactly.
    # Exact equality is brittle across backends and XLA versions.
    rtol = 1e-5

    def _leapfrog(self):
        """Integrate ``leapfrog_steps`` steps from ``start_state`` (one key per step)."""
        return utils.leapfrog(
            lambda x, _: (0., x), self.start_state, self.stepsize,
            random.split(random.PRNGKey(0), self.leapfrog_steps))

    def test_leapfrog(self):
        out_state = self._leapfrog()[-1]

        npt.assert_allclose(out_state.value, self.final_value, rtol=self.rtol)
        npt.assert_allclose(out_state.grad_potential, self.final_value,
                            rtol=self.rtol)
        npt.assert_allclose(out_state.momenta, self.final_momenta,
                            rtol=self.rtol)

    def test_all_leapfrog(self):
        full_state = self._leapfrog()

        # `value` has leapfrog_steps + 1 entries (start included);
        # `momenta` has 2 * leapfrog_steps + 1 entries.
        npt.assert_array_equal(full_state.value.shape,
                               (self.leapfrog_steps + 1, 2))
        npt.assert_array_equal(full_state.momenta.shape,
                               (2 * self.leapfrog_steps + 1, 2))

        npt.assert_array_equal(full_state.value[0], jnp.zeros(2))
        npt.assert_allclose(full_state.value[-1], self.final_value,
                            rtol=self.rtol)

        npt.assert_array_equal(full_state.momenta[0], jnp.ones(2))
        npt.assert_allclose(full_state.momenta[-1], self.final_momenta,
                            rtol=self.rtol)
Ejemplo n.º 15
0
    def test_additem(self):
        """Adding two cdicts stacks arrays, sums ``time`` and keeps the left float.

        The shared ``self.cdict`` fixture temporarily gains a ``time`` field.
        ``finally`` removes it even when an assertion fails, so other tests are
        not polluted.
        """
        cdict_other = core.cdict(test_arr=jnp.ones((2, 3)),
                                 test_float=7.,
                                 time=25.)

        self.cdict.time = 10.
        try:
            cdict_add = self.cdict + cdict_other

            npt.assert_(isinstance(cdict_add, core.cdict))

            # `jnp.ndarray` works across JAX versions; `jnp.DeviceArray` has
            # been removed from recent releases.
            npt.assert_(isinstance(cdict_add.test_arr, jnp.ndarray))
            npt.assert_array_equal(cdict_add.test_arr, jnp.ones((12, 3)))
            npt.assert_array_equal(cdict_add.time, 35.)

            npt.assert_(isinstance(cdict_add.test_float, float))
            npt.assert_equal(cdict_add.test_float, 3.)

            # Addition must not mutate the left operand.
            npt.assert_array_equal(self.cdict.test_arr, jnp.ones((10, 3)))
            npt.assert_equal(self.cdict.test_float, 3.)
            npt.assert_equal(self.cdict.time, 10.)
        finally:
            del self.cdict.time
Ejemplo n.º 16
0
def run(scenario: Scenario,
        sampler: Union[Sampler, Type[Sampler]],
        n: int,
        random_key: Union[None, jnp.ndarray],
        initial_state: cdict = None,
        initial_extra: cdict = None,
        **kwargs) -> Union[cdict, Tuple[cdict, jnp.ndarray]]:
    """Run ``sampler`` on ``scenario`` and return the resulting chain.

    Args:
        scenario: Target passed to the sampler's ``startup``, ``summary``,
            ``update`` and ``clean_chain``.
        sampler: A sampler instance, or a sampler class that is instantiated
            with ``**kwargs``.
        n: Requested number of samples. Stored on ``sampler.n`` and passed to
            ``startup``.
        random_key: If not None, stored as ``initial_extra.random_key``. Note
            that a caller-supplied ``initial_extra`` is modified in place.
        initial_state: Optional starting state; ``None`` lets
            ``sampler.startup`` choose.
        initial_extra: Optional extra cdict; an empty cdict is used if None.
        **kwargs: Passed to the sampler constructor (when a class is given)
            and to ``sampler.startup``.

    Returns:
        The cleaned chain (initial state prepended). Its ``time`` field holds
        the wall-clock seconds spent in the sampling loop and chain cleaning,
        and its ``summary`` field holds ``sampler.summary(...)``.
    """
    if isclass(sampler):
        sampler = sampler(**kwargs)

    sampler.n = n

    if initial_extra is None:
        initial_extra = cdict()
    if random_key is not None:
        initial_extra.random_key = random_key

    initial_state, initial_extra = sampler.startup(scenario, n, initial_state, initial_extra,
                                                   **kwargs)

    summary = sampler.summary(scenario, initial_state, initial_extra)

    transport_kernel = partial(sampler.update, scenario)

    def continue_criterion(state, extra):
        # Logical (not bitwise) negation: `~` on a Python bool gives -1/-2,
        # which are always truthy and not a boolean predicate.
        return jnp.logical_not(sampler.termination_criterion(state, extra))

    start = time()
    chain = while_loop_stacked(continue_criterion,
                               transport_kernel,
                               (initial_state, initial_extra),
                               sampler.max_iter)
    # Prepend the starting state so the chain begins at the initial point.
    chain = initial_state[jnp.newaxis] + chain
    chain = sampler.clean_chain(scenario, chain)
    # Block on JAX's asynchronous dispatch so the measured time is meaningful.
    chain.value.block_until_ready()
    end = time()
    chain.time = end - start

    chain.summary = summary

    return chain
Ejemplo n.º 17
0
def initiate_particles(ssm_scenario: StateSpaceModel,
                       particle_filter: ParticleFilter,
                       n: int,
                       random_key: jnp.ndarray,
                       y: jnp.ndarray = None,
                       t: float = None) -> cdict:
    """Draw the first generation of ``n`` weighted particles.

    ``particle_filter.startup`` is called first. Values and log weights come
    from ``initial_sample_and_weight_vectorised`` with one split key per
    particle.

    Returns:
        cdict whose fields gain a leading time axis of length 1:
        ``value``, ``log_weight``, ``t`` (``t`` or ``zeros(1)``), ``y``
        (``y[None]`` or ``None``) and ``ess``.
    """
    particle_filter.startup(ssm_scenario)

    particle_keys = random.split(random_key, n)
    vals, log_w = particle_filter.initial_sample_and_weight_vectorised(
        ssm_scenario, y, t, particle_keys)

    # 1-D values (presumably one scalar per particle) get a trailing state axis.
    if vals.ndim == 1:
        vals = vals[..., jnp.newaxis]

    t_history = jnp.zeros(1) if t is None else jnp.atleast_1d(t)
    y_history = None if y is None else y[jnp.newaxis]

    return cdict(value=vals[jnp.newaxis],
                 log_weight=log_w[jnp.newaxis],
                 t=t_history,
                 y=y_history,
                 ess=jnp.atleast_1d(ess_log_weight(log_w)))
Ejemplo n.º 18
0
    def startup(self,
                abc_scenario: ABCScenario,
                n: int,
                initial_state: cdict = None,
                initial_extra: cdict = None,
                **kwargs) -> Tuple[cdict, cdict]:
        """Build a starting state for an ABC run and attach its log weight.

        Args:
            abc_scenario: Scenario supplying ``prior_sample``/``dim``.
            n: Total number of samples; ``self.max_iter`` is set to ``n - 1``.
            initial_state: Optional starting state. If ``None``, its value is
                drawn from ``abc_scenario.prior_sample`` (which requires
                ``initial_extra.random_key``, advanced in place). When no prior
                sampler is implemented, the value is zeros of length
                ``abc_scenario.dim``.
            initial_extra: Optional extra cdict; an empty cdict is used when
                ``None``.
            **kwargs: Forwarded to the parent ``startup``.

        Returns:
            ``(initial_state, initial_extra)``, with ``initial_state.log_weight``
            set from ``self.log_weight``.
        """
        if initial_extra is None:
            # A None default would otherwise fail on the first attribute access.
            initial_extra = cdict()

        if initial_state is None:
            if is_implemented(abc_scenario.prior_sample):
                initial_extra.random_key, sub_key = random.split(
                    initial_extra.random_key)
                init_vals = abc_scenario.prior_sample(sub_key)
            else:
                init_vals = jnp.zeros(abc_scenario.dim)
            initial_state = cdict(value=init_vals)

        # n counts the initial state, so n - 1 further iterations are needed.
        self.max_iter = n - 1

        initial_state, initial_extra = super().startup(abc_scenario, n,
                                                       initial_state,
                                                       initial_extra, **kwargs)

        initial_state.log_weight = self.log_weight(abc_scenario, initial_state,
                                                   initial_extra)
        return initial_state, initial_extra
Ejemplo n.º 19
0
def propagate_particle_smoother_bs(ssm_scenario: StateSpaceModel,
                                   particle_filter: ParticleFilter,
                                   particles: cdict,
                                   y_new: jnp.ndarray,
                                   t_new: float,
                                   random_key: jnp.ndarray,
                                   lag: int,
                                   ess_threshold: float,
                                   maximum_rejections: int,
                                   init_bound_param: float,
                                   bound_inflation: float) -> cdict:
    """Extend a backward-simulation particle smoother by one observation.

    Steps:
      1. Propagate the stored marginal particle filter with ``(y_new, t_new)``.
      2. Append ``y_new``/``t_new`` to the histories, set ``log_weight`` to
         zeros and copy the latest filter ESS into ``ess``.
      3. Re-draw trajectories by backward simulation.
         - If the history is longer than ``lag`` (``len(t) - lag - 1 >= 0``),
           only the marginal filter from that index on is backward simulated.
           It is then joined onto the existing trajectories via
           ``fixed_lag_stitching``.
         - Otherwise the whole marginal filter is backward simulated.
      4. Append this step's transition-density evaluation count to
         ``num_transition_evals``.

    Args:
        particles: Current smoother state. Its ``value`` is indexed with the
            particle count on axis 1. It is not modified; a copy is returned.
        random_key: Split into independent keys for filtering, backward
            simulation and stitching.
        lag: Fixed lag of the smoother.
        Other arguments are forwarded to ``propagate_particle_filter``,
        ``backward_simulation`` and ``fixed_lag_stitching``.

    Returns:
        Updated smoother cdict.
    """
    n = particles.value.shape[1]

    # Work on a shallow copy so the caller's `particles` is not mutated.
    out_particles = particles.copy()

    if not hasattr(out_particles, 'num_transition_evals'):
        out_particles.num_transition_evals = jnp.array(0)

    if not hasattr(out_particles, 'marginal_filter'):
        out_particles.marginal_filter = cdict(value=particles.value,
                                              log_weight=particles.log_weight,
                                              y=particles.y,
                                              t=particles.t,
                                              ess=particles.ess)

    # Independent keys: [1] filter propagation, [2] backward simulation,
    # [3] fixed-lag stitching ([0] is unused). The parent key is not reused.
    split_keys = random.split(random_key, 4)

    # Propagate marginal filter particles
    out_particles.marginal_filter = propagate_particle_filter(ssm_scenario, particle_filter,
                                                              out_particles.marginal_filter,
                                                              y_new, t_new, split_keys[1],
                                                              ess_threshold, False)
    out_particles.y = jnp.append(out_particles.y, y_new[jnp.newaxis], axis=0)
    out_particles.t = jnp.append(out_particles.t, t_new)
    out_particles.log_weight = jnp.zeros(n)
    out_particles.ess = out_particles.marginal_filter.ess[-1]

    len_t = len(out_particles.t)
    stitch_ind_min_1 = len_t - lag - 1
    stitch_ind = len_t - lag

    def back_sim_only(marginal_filter):
        # Backward simulate over the whole filter history.
        backward_sim = backward_simulation(ssm_scenario,
                                           marginal_filter,
                                           split_keys[2],
                                           n,
                                           maximum_rejections,
                                           init_bound_param,
                                           bound_inflation)
        return backward_sim.value, backward_sim.num_transition_evals.sum()

    def back_sim_and_stitch(marginal_filter):
        # Backward simulate only the lagged window, then stitch it onto the
        # existing trajectories (value is still the pre-update trajectory here).
        backward_sim = backward_simulation(ssm_scenario,
                                           marginal_filter[stitch_ind_min_1:],
                                           split_keys[2],
                                           n,
                                           maximum_rejections,
                                           init_bound_param,
                                           bound_inflation)

        vals, stitch_nte = fixed_lag_stitching(ssm_scenario,
                                               out_particles.value[:(stitch_ind_min_1 + 1)],
                                               out_particles.t[stitch_ind_min_1],
                                               backward_sim.value,
                                               jnp.zeros(n),
                                               out_particles.t[stitch_ind],
                                               split_keys[3],
                                               maximum_rejections,
                                               init_bound_param,
                                               bound_inflation)
        return vals, stitch_nte + backward_sim.num_transition_evals.sum()

    # A Python `if` (not lax.cond) suffices: stitch_ind_min_1 is a Python int
    # derived from the static length of the time history.
    if stitch_ind_min_1 >= 0:
        out_particles.value, num_transition_evals = back_sim_and_stitch(out_particles.marginal_filter)
    else:
        out_particles.value, num_transition_evals = back_sim_only(out_particles.marginal_filter)

    out_particles.num_transition_evals = jnp.append(out_particles.num_transition_evals,
                                                    num_transition_evals)

    return out_particles
Ejemplo n.º 20
0
 def __init__(self, **kwargs):
     """Ensure ``self.parameters`` exists and copy every keyword onto it."""
     if not hasattr(self, 'parameters'):
         self.parameters = cdict()
     # Write straight into the parameters object's attribute namespace.
     vars(self.parameters).update(kwargs)