Esempio n. 1
0
    def startup(self,
                abc_scenario: ABCScenario,
                n: int,
                initial_state: cdict,
                initial_extra: cdict,
                startup_correction: bool = True,
                **kwargs) -> Tuple[cdict, cdict]:
        """Run the parent startup, then fill in any ABC fields the state lacks.

        A field is only computed when the state returned by the parent does
        not already have it, so caller-supplied values are left untouched.

        Args:
            abc_scenario: Scenario providing ``prior_potential``,
                ``likelihood_sample`` and ``distance_function``.
            n: Forwarded unchanged to the parent ``startup``.
            initial_state: State to complete; its ``value`` is read whenever a
                missing field has to be computed.
            initial_extra: Extra container whose ``random_key`` is split each
                time a fresh key is needed.
            startup_correction: Not read in this body; declaring it here keeps
                it out of the ``**kwargs`` forwarded to the parent.
            **kwargs: Forwarded to the parent ``startup``.

        Returns:
            The completed ``(state, extra)`` pair.
        """
        state, extra = super().startup(abc_scenario, n, initial_state,
                                       initial_extra, **kwargs)

        # Prior potential: only when absent and the scenario implements it.
        if not hasattr(state, 'prior_potential'):
            if is_implemented(abc_scenario.prior_potential):
                carried_key, prior_key = random.split(extra.random_key)
                extra.random_key = carried_key
                state.prior_potential = abc_scenario.prior_potential(
                    state.value, prior_key)

        # Simulated data drawn at the current value.
        if not hasattr(state, 'simulated_data'):
            carried_key, sim_key = random.split(extra.random_key)
            extra.random_key = carried_key
            state.simulated_data = abc_scenario.likelihood_sample(
                state.value, sim_key)

        # Distance is evaluated on the simulated data only.
        if not hasattr(state, 'distance'):
            state.distance = abc_scenario.distance_function(
                state.simulated_data)

        return state, extra
Esempio n. 2
0
    def forward_proposal_non_zero_weight(self,
                                         abc_scenario: ABCScenario,
                                         state: cdict,
                                         extra: cdict,
                                         random_key: jnp.ndarray) -> cdict:
        """Advance ``state`` by running the inner MCMC sampler under ``scan``.

        Args:
            abc_scenario: Scenario handed to the inner sampler's ``startup``
                and ``update``.
            state: State used to start the inner chain.
            extra: Extra container; its ``random_key`` is overwritten with
                ``random_key`` before the chain starts.
            random_key: Key assigned to ``extra.random_key``.

        Returns:
            The state produced by ``self.clean_mcmc_chain`` from the scanned
            chain (the accompanying extra is discarded).
        """
        extra.random_key = random_key

        def one_mcmc_step(carry: Tuple[cdict, cdict],
                          _: None) -> Tuple[Tuple[cdict, cdict], Tuple[cdict, cdict]]:
            # The updated pair is both the next carry and the stacked output.
            carried_state, carried_extra = carry
            stepped = self.mcmc_sampler.update(abc_scenario, carried_state,
                                               carried_extra)
            return stepped, stepped

        # NOTE: startup receives extra.parameters.mcmc_steps, whereas the scan
        # length below is read from self.parameters.mcmc_steps.
        start_state, start_extra = self.mcmc_sampler.startup(
            abc_scenario, extra.parameters.mcmc_steps, state, extra)

        _, chain = scan(one_mcmc_step,
                        (start_state, start_extra),
                        None,
                        length=self.parameters.mcmc_steps)

        chain_states = chain[0]
        chain_extras = chain[1]
        advanced_state, _ = self.clean_mcmc_chain(chain_states, chain_extras)
        return advanced_state
Esempio n. 3
0
    def forward_proposal(self, scenario: Scenario, state: cdict, extra: cdict,
                         random_key: jnp.ndarray) -> cdict:
        """Advance ``state`` with the inner MCMC sampler and refresh its potentials.

        Args:
            scenario: Scenario handed to the inner sampler; its
                ``prior_potential`` and ``temperature`` are also read here.
            state: State used to start the inner chain.
            extra: Extra container; its ``random_key`` is overwritten with
                ``random_key`` before the chain starts.
            random_key: Key assigned to ``extra.random_key``.

        Returns:
            The state produced by ``self.clean_mcmc_chain``, with
            ``prior_potential`` and ``likelihood_potential`` set on it.
        """
        extra.random_key = random_key

        def one_mcmc_step(
                carry: Tuple[cdict, cdict],
                _: None) -> Tuple[Tuple[cdict, cdict], Tuple[cdict, cdict]]:
            # The updated pair is both the next carry and the stacked output.
            carried_state, carried_extra = carry
            stepped = self.mcmc_sampler.update(scenario, carried_state,
                                               carried_extra)
            return stepped, stepped

        # NOTE: startup receives extra.parameters.mcmc_steps, whereas the scan
        # length below is read from self.parameters.mcmc_steps.
        start_state, start_extra = self.mcmc_sampler.startup(
            scenario, extra.parameters.mcmc_steps, state, extra)

        _, chain = scan(one_mcmc_step, (start_state, start_extra),
                        None,
                        length=self.parameters.mcmc_steps)

        chain_states = chain[0]
        chain_extras = chain[1]
        advanced_state, advanced_extra = self.clean_mcmc_chain(
            chain_states, chain_extras)

        # Re-evaluate the prior potential at the advanced value, then recover
        # the likelihood part from the total potential and the temperature.
        advanced_state.prior_potential = scenario.prior_potential(
            advanced_state.value, advanced_extra.random_key)
        potential_gap = (advanced_state.potential
                         - advanced_state.prior_potential)
        advanced_state.likelihood_potential = potential_gap / scenario.temperature
        return advanced_state
Esempio n. 4
0
    def update(self, scenario: Scenario, ensemble_state: cdict,
               extra: cdict) -> Tuple[cdict, cdict]:
        """Perform one ensemble iteration: conditional resample, propose, adapt.

        Args:
            scenario: Scenario forwarded to ``self.forward_proposal``.
            ensemble_state: Current ensemble; the leading axis of its
                ``value`` gives the number of particles.
            extra: Extra container; ``iter`` is incremented and ``random_key``
                is replaced here.

        Returns:
            The ``(state, extra)`` pair returned by ``self.adapt``.
        """
        extra.iter = extra.iter + 1
        particle_count = ensemble_state.value.shape[0]

        should_resample = self.resample_criterion(ensemble_state, extra)

        # Key layout: [:n] one per particle, [-2] resampling, [-1] carried on.
        all_keys = random.split(extra.random_key, particle_count + 2)
        extra.random_key = all_keys[-1]
        resample_key = all_keys[-2]
        proposal_keys = all_keys[:particle_count]

        def do_resample(state):
            return self.resample(state, resample_key)

        def keep_as_is(state):
            return state

        resampled_state = cond(should_resample, do_resample, keep_as_is,
                               ensemble_state)

        # Map over particles and their keys; scenario and extra are shared.
        batched_proposal = vmap(self.forward_proposal,
                                in_axes=(None, 0, None, 0))
        proposed_state = batched_proposal(scenario, resampled_state, extra,
                                          proposal_keys)

        # The same extra is passed as both the previous and the new extra.
        advanced_state, advanced_extra = self.adapt(resampled_state, extra,
                                                    proposed_state, extra)
        return advanced_state, advanced_extra
Esempio n. 5
0
    def startup(self, scenario: Scenario, n: int, initial_state: cdict,
                initial_extra: cdict, **kwargs) -> Tuple[cdict, cdict]:
        """Prepare the ensemble, the batch-index generator and the optimiser.

        Args:
            scenario: Scenario providing ``potential_and_grad``.
            n: Ensemble size; used for key splitting and batch-index shapes.
            initial_state: Passed to the parent ``startup``.
            initial_extra: Passed to the parent ``startup``; its parameters
                and ``random_key`` are modified here.
            **kwargs: Forwarded to the parent ``startup``.

        Returns:
            The ``(state, extra)`` pair, with ``extra.opt_state`` initialised.
        """
        state, extra = super().startup(scenario, n, initial_state,
                                       initial_extra, **kwargs)

        # Default the batch size to the full ensemble on both parameter sets.
        if self.parameters.ensemble_batchsize is None:
            self.parameters.ensemble_batchsize = n
            extra.parameters.ensemble_batchsize = n

        def full_batch_inds(_):
            # Every particle sees every index: n rows of arange(n).
            return jnp.repeat(jnp.arange(n)[None], n, axis=0)

        def random_batch_inds(rk):
            # Batch size is read when called, not when startup runs.
            return random.choice(rk, n, shape=(
                n,
                self.parameters.ensemble_batchsize,
            ))

        if self.parameters.ensemble_batchsize == n:
            self.get_batch_inds = full_batch_inds
        else:
            self.get_batch_inds = random_batch_inds

        # stepsize is removed from the extra's parameters; the optimiser below
        # reads it from self.parameters instead.
        del extra.parameters.stepsize

        split_keys = random.split(extra.random_key, n + 1)
        extra.random_key = split_keys[-1]
        particle_keys = split_keys[:n]

        batched_potential_and_grad = vmap(scenario.potential_and_grad)
        state.potential, state.grad_potential = batched_potential_and_grad(
            state.value, particle_keys)

        state, extra = self.adapt(state, extra)

        optimiser_triple = self.optimiser(
            step_size=self.parameters.stepsize,
            **extra.parameters.optim_params)
        self.opt_init, self.opt_update, self.get_params = optimiser_triple
        extra.opt_state = self.opt_init(state.value)
        return state, extra
Esempio n. 6
0
    def startup(self, scenario: Scenario, n: int, initial_state: cdict,
                initial_extra: cdict, **kwargs) -> Tuple[cdict, cdict]:
        """Initialise the ensemble at temperature zero.

        Args:
            scenario: Scenario providing ``prior_potential`` and
                ``likelihood_potential``; its ``temperature`` is set to ``0.``.
            n: Ensemble size; sizes the key split and per-particle arrays.
            initial_state: Passed to the parent ``startup``.
            initial_extra: Passed to the parent ``startup``; its ``random_key``
                is replaced here.
            **kwargs: Forwarded to the parent ``startup``.

        Returns:
            The initialised ``(state, extra)`` pair.

        Raises:
            TypeError: If ``scenario`` has no ``prior_sample`` attribute.
        """
        if not hasattr(scenario, 'prior_sample'):
            raise TypeError(
                f'Likelihood tempering requires scenario {scenario.name} to have prior_sample implemented'
            )

        state, extra = super().startup(scenario, n, initial_state,
                                       initial_extra, **kwargs)

        # Key layout: [:n] prior, [n:2n] likelihood, [-1] carried forward.
        split_keys = random.split(extra.random_key, 2 * n + 1)
        prior_keys = split_keys[:n]
        likelihood_keys = split_keys[n:(2 * n)]
        extra.random_key = split_keys[-1]

        batched_prior = vmap(scenario.prior_potential)
        state.prior_potential = batched_prior(state.value, prior_keys)

        batched_likelihood = vmap(scenario.likelihood_potential)
        state.likelihood_potential = batched_likelihood(state.value,
                                                        likelihood_keys)

        # At temperature zero the potential is just the prior potential.
        state.potential = state.prior_potential
        state.temperature = jnp.zeros(n)
        state.log_weight = jnp.zeros(n)
        state.ess = jnp.zeros(n) + n

        scenario.temperature = 0.

        return state, extra
Esempio n. 7
0
    def startup(self,
                scenario: Scenario,
                n: int,
                initial_state: cdict,
                initial_extra: cdict,
                startup_correction: bool = True,
                **kwargs) -> Tuple[cdict, cdict]:
        """Initialise the chain and, optionally, its correction step.

        Args:
            scenario: Scenario; ``prior_sample`` (if implemented) or ``dim``
                is used to build a starting value when none is given.
            n: Forwarded to the parent ``startup``; also sets
                ``self.max_iter`` to ``n - 1``.
            initial_state: Starting state, or ``None`` to build one here.
            initial_extra: Extra container holding ``random_key``.
            startup_correction: When true, ``self.correction.startup`` is run
                after the parent ``startup``.
            **kwargs: A ``'correction'`` entry, if present, replaces
                ``self.correction`` and is removed; the rest is forwarded.

        Returns:
            The initialised ``(state, extra)`` pair.
        """
        if initial_state is None:
            if not is_implemented(scenario.prior_sample):
                # No prior sampler available: start from the zero vector.
                init_vals = jnp.zeros(scenario.dim)
            else:
                next_key, sub_key = random.split(initial_extra.random_key)
                initial_extra.random_key = next_key
                init_vals = scenario.prior_sample(sub_key)
            initial_state = cdict(value=init_vals)

        self.max_iter = n - 1

        # Take the correction override out of kwargs before forwarding them.
        if 'correction' in kwargs:
            self.correction = kwargs.pop('correction')

        self.correction = check_correction(self.correction)

        state, extra = super().startup(scenario, n, initial_state,
                                       initial_extra, **kwargs)

        if startup_correction:
            state, extra = self.correction.startup(scenario, self, n, state,
                                                   extra, **kwargs)

        return state, extra
Esempio n. 8
0
 def startup(self, scenario: Scenario, n: int, initial_state: cdict,
             initial_extra: cdict, **kwargs) -> Tuple[cdict, cdict]:
     """Run the parent startup, then evaluate potential and gradient once.

     Args:
         scenario: Scenario providing ``potential_and_grad``.
         n: Forwarded unchanged to the parent ``startup``.
         initial_state: Passed to the parent ``startup``.
         initial_extra: Passed to the parent ``startup``; its ``random_key``
             is split here.
         **kwargs: Forwarded to the parent ``startup``.

     Returns:
         The ``(state, extra)`` pair with ``potential`` and
         ``grad_potential`` set on the state.
     """
     state, extra = super().startup(scenario, n, initial_state,
                                    initial_extra, **kwargs)

     # One half of the split is carried on, the other goes to the scenario.
     carried_key, scen_key = random.split(extra.random_key)
     extra.random_key = carried_key

     potential, grad_potential = scenario.potential_and_grad(
         state.value, scen_key)
     state.potential = potential
     state.grad_potential = grad_potential
     return state, extra
Esempio n. 9
0
 def startup(self, scenario: Scenario, n: int, initial_state: cdict,
             initial_extra: cdict, **kwargs) -> Tuple[cdict, cdict]:
     """Run the parent startup, evaluate potential/gradient, ensure momenta.

     Args:
         scenario: Scenario providing ``potential_and_grad`` and ``dim``.
         n: Forwarded unchanged to the parent ``startup``.
         initial_state: Passed to the parent ``startup``.
         initial_extra: Passed to the parent ``startup``; its ``random_key``
             is split here.
         **kwargs: Forwarded to the parent ``startup``.

     Returns:
         The ``(state, extra)`` pair with ``potential``, ``grad_potential``
         and ``momenta`` set on the state.
     """
     state, extra = super().startup(scenario, n, initial_state,
                                    initial_extra, **kwargs)

     # One half of the split is carried on, the other goes to the scenario.
     carried_key, scen_key = random.split(extra.random_key)
     extra.random_key = carried_key

     potential, grad_potential = scenario.potential_and_grad(
         state.value, scen_key)
     state.potential = potential
     state.grad_potential = grad_potential

     # Existing momenta are kept only when their last axis matches the
     # scenario dimension; otherwise they are reset to zeros.
     momenta_usable = (hasattr(state, 'momenta')
                       and state.momenta.shape[-1] == scenario.dim)
     if not momenta_usable:
         state.momenta = jnp.zeros(scenario.dim)
     return state, extra
Esempio n. 10
0
    def startup(self,
                abc_scenario: ABCScenario,
                n: int,
                initial_state: cdict,
                initial_extra: cdict,
                **kwargs) -> Tuple[cdict, cdict]:
        """Run ``SMCSampler.startup``, then fill in missing per-particle ABC fields.

        A field is only computed when the state does not already have it.

        Args:
            abc_scenario: Scenario providing ``prior_potential``,
                ``likelihood_sample`` and ``distance_function``.
            n: Passed to ``SMCSampler.startup``; afterwards replaced by
                ``len(state.value)`` for all sizing in this method.
            initial_state: Passed to ``SMCSampler.startup``.
            initial_extra: Passed to ``SMCSampler.startup``; its
                ``random_key`` is split when keys are needed.
            **kwargs: Forwarded to ``SMCSampler.startup``.

        Returns:
            The completed ``(state, extra)`` pair.
        """
        state, extra = SMCSampler.startup(self, abc_scenario, n,
                                          initial_state, initial_extra, **kwargs)

        # From here on the particle count comes from the state itself.
        n = len(state.value)

        if not hasattr(state, 'prior_potential'):
            if is_implemented(abc_scenario.prior_potential):
                prior_keys = random.split(extra.random_key, n + 1)
                extra.random_key = prior_keys[-1]
                batched_prior = vmap(abc_scenario.prior_potential)
                state.prior_potential = batched_prior(state.value,
                                                      prior_keys[:n])

        if not hasattr(state, 'simulated_data'):
            sim_keys = random.split(extra.random_key, n + 1)
            extra.random_key = sim_keys[-1]
            batched_simulate = vmap(abc_scenario.likelihood_sample)
            state.simulated_data = batched_simulate(state.value, sim_keys[:n])

        if not hasattr(state, 'distance'):
            batched_distance = vmap(abc_scenario.distance_function)
            state.distance = batched_distance(state.simulated_data)

        if not hasattr(state, 'threshold'):
            # Without a schedule every particle starts at an infinite threshold.
            first_threshold = (jnp.inf if self.threshold_schedule is None
                               else self.threshold_schedule[0])
            state.threshold = jnp.zeros(n) + first_threshold

        if not hasattr(state, 'ess'):
            state.ess = jnp.zeros(n) + n

        return state, extra
Esempio n. 11
0
    def always(self, scenario: Scenario, reject_state: cdict,
               reject_extra: cdict) -> Tuple[cdict, cdict]:
        """Flip the momenta, then partially refresh them with Gaussian noise.

        Args:
            scenario: Scenario; only ``dim`` is read, as the noise shape.
            reject_state: State whose ``momenta`` are updated in place.
            reject_extra: Extra container providing ``parameters.stepsize``,
                ``parameters.friction`` and ``random_key``.

        Returns:
            The updated ``(reject_state, reject_extra)`` pair.
        """
        dim = scenario.dim
        params = reject_extra.parameters
        stepsize = params.stepsize
        friction = params.friction

        # Sign flip of the momenta.
        reject_state.momenta = reject_state.momenta * -1

        # Momentum update following the exact solution of the OU process.
        # It is applied even if the leapfrog step is rejected.
        carried_key, noise_key = random.split(reject_extra.random_key)
        reject_extra.random_key = carried_key

        decay = jnp.exp(- friction * stepsize)
        noise_scale = jnp.sqrt(1 - jnp.exp(- 2 * friction * stepsize))
        noise = random.normal(noise_key, (dim,))
        reject_state.momenta = reject_state.momenta * decay + noise_scale * noise
        return reject_state, reject_extra
Esempio n. 12
0
    def proposal(self, scenario: Scenario, reject_state: cdict,
                 reject_extra: cdict) -> Tuple[cdict, cdict]:
        """Propose the last leapfrog state with its momenta negated.

        Args:
            scenario: Scenario providing ``potential_and_grad``.
            reject_state: State the leapfrog integration starts from.
            reject_extra: Extra container providing ``parameters.stepsize``
                and ``random_key``.

        Returns:
            ``(proposed_state, reject_extra)``.
        """
        # Key layout: [0] carried forward, [1:] handed to utils.leapfrog.
        # The split size uses self.parameters.leapfrog_steps.
        key_count = self.parameters.leapfrog_steps + 1
        split_keys = random.split(reject_extra.random_key, key_count)
        reject_extra.random_key = split_keys[0]
        leapfrog_keys = split_keys[1:]

        trajectory = utils.leapfrog(scenario.potential_and_grad,
                                    reject_state,
                                    reject_extra.parameters.stepsize,
                                    leapfrog_keys)

        # Only the final entry of the trajectory is proposed.
        proposed_state = trajectory[-1]
        proposed_state.momenta *= -1

        return proposed_state, reject_extra
Esempio n. 13
0
    def startup(self, scenario: Scenario, n: int, initial_state: cdict,
                initial_extra: cdict, **kwargs) -> Tuple[cdict, cdict]:
        """Build an ``n``-member starting ensemble if needed, then defer to the parent.

        Args:
            scenario: Scenario; ``prior_sample`` (if implemented) or ``dim``
                is used to draw starting values.
            n: Number of ensemble members to draw.
            initial_state: Starting state, or ``None`` to draw one here.
            initial_extra: Extra container holding ``random_key``.
            **kwargs: Forwarded to the parent ``startup``.

        Returns:
            The ``(state, extra)`` pair from the parent ``startup``.
        """
        if initial_state is None:
            # The key is split whichever initialisation branch is taken.
            carried_key, sub_key = random.split(initial_extra.random_key)
            initial_extra.random_key = carried_key

            if not is_implemented(scenario.prior_sample):
                # Fall back to standard normal draws of shape (n, dim).
                init_vals = random.normal(sub_key, shape=(n, scenario.dim))
            else:
                member_keys = random.split(sub_key, n)
                init_vals = vmap(scenario.prior_sample)(member_keys)
            initial_state = cdict(value=init_vals)

        state, extra = super().startup(scenario, n, initial_state,
                                       initial_extra, **kwargs)
        return state, extra
Esempio n. 14
0
    def proposal(self, scenario: Scenario, reject_state: cdict,
                 reject_extra: cdict) -> Tuple[cdict, cdict]:
        """Propose a Gaussian perturbation of the current value.

        Args:
            scenario: Scenario providing ``dim`` and ``potential``.
            reject_state: Current state; copied, not modified.
            reject_extra: Extra container providing ``parameters.stepsize``
                and ``random_key``.

        Returns:
            ``(proposed_state, reject_extra)``, where the proposal carries the
            new ``value`` and its ``potential``.
        """
        proposed_state = reject_state.copy()

        current_value = reject_state.value
        dim = scenario.dim
        stepsize = reject_extra.parameters.stepsize

        # Three-way split: carried key, noise key, scenario key.
        carried_key, noise_key, scen_key = random.split(
            reject_extra.random_key, 3)
        reject_extra.random_key = carried_key

        # The noise is scaled by sqrt(stepsize).
        noise = random.normal(noise_key, (dim, ))
        proposed_state.value = current_value + jnp.sqrt(stepsize) * noise
        proposed_state.potential = scenario.potential(proposed_state.value,
                                                      scen_key)

        return proposed_state, reject_extra
Esempio n. 15
0
    def proposal(self, scenario: Scenario, reject_state: cdict,
                 reject_extra: cdict) -> Tuple[cdict, cdict]:
        """Draw fresh momenta and propose the last leapfrog state.

        Args:
            scenario: Scenario providing ``dim`` and ``potential_and_grad``.
            reject_state: State the leapfrog integration starts from; its
                ``momenta`` are overwritten with standard normal draws.
            reject_extra: Extra container providing ``parameters.stepsize``
                and ``random_key``.

        Returns:
            ``(proposed_state, reject_extra)``.
        """
        dim = scenario.dim

        # Key layout: [0] carried forward, [1] momenta draw, [2:] leapfrog.
        # The split size uses self.parameters.leapfrog_steps.
        key_count = self.parameters.leapfrog_steps + 2
        split_keys = random.split(reject_extra.random_key, key_count)
        reject_extra.random_key = split_keys[0]
        momenta_key = split_keys[1]
        leapfrog_keys = split_keys[2:]

        reject_state.momenta = random.normal(momenta_key, (dim, ))

        trajectory = utils.leapfrog(scenario.potential_and_grad,
                                    reject_state,
                                    reject_extra.parameters.stepsize,
                                    leapfrog_keys)
        proposed_state = trajectory[-1]

        # Strictly the momenta should be reversed here. That is skipped
        # because the momenta target is symmetric and the momenta are
        # resampled at the start of the next call anyway.

        return proposed_state, reject_extra
Esempio n. 16
0
def run(scenario: Scenario,
        sampler: Union[Sampler, Type[Sampler]],
        n: int,
        random_key: Union[None, jnp.ndarray],
        initial_state: cdict = None,
        initial_extra: cdict = None,
        **kwargs) -> Union[cdict, Tuple[cdict, jnp.ndarray]]:
    """Start a sampler, iterate it until termination and return the chain.

    Args:
        scenario: Scenario handed to the sampler's ``startup``, ``summary``,
            ``update`` and ``clean_chain``.
        sampler: Sampler instance, or a sampler class that is instantiated
            with ``**kwargs``.
        n: Stored on the sampler as ``sampler.n`` and passed to ``startup``.
        random_key: If not ``None``, stored as ``initial_extra.random_key``.
        initial_state: Optional starting state, passed to ``startup``.
        initial_extra: Optional extra container; an empty ``cdict`` is
            created when omitted.
        **kwargs: Passed to the sampler constructor (when ``sampler`` is a
            class) and to ``sampler.startup``.

    Returns:
        The cleaned chain, with ``time`` (seconds spent in the loop and
        clean-up) and ``summary`` attached.
    """
    if isclass(sampler):
        sampler = sampler(**kwargs)

    sampler.n = n

    if initial_extra is None:
        initial_extra = cdict()
    if random_key is not None:
        initial_extra.random_key = random_key

    start_state, start_extra = sampler.startup(scenario, n, initial_state,
                                               initial_extra, **kwargs)

    run_summary = sampler.summary(scenario, start_state, start_extra)

    step = partial(sampler.update, scenario)

    def keep_going(state, extra):
        # Loop while the termination criterion is not met.
        return ~sampler.termination_criterion(state, extra)

    tic = time()
    chain = while_loop_stacked(keep_going,
                               step,
                               (start_state, start_extra),
                               sampler.max_iter)
    # Put the initial state in front of the states produced by the loop.
    chain = start_state[jnp.newaxis] + chain
    chain = sampler.clean_chain(scenario, chain)
    # Wait for the result so that it is included in the measured time.
    chain.value.block_until_ready()
    toc = time()

    chain.time = toc - tic
    chain.summary = run_summary

    return chain
Esempio n. 17
0
    def startup(self,
                abc_scenario: ABCScenario,
                n: int,
                initial_state: cdict = None,
                initial_extra: cdict = None,
                **kwargs) -> Tuple[cdict, cdict]:
        """Initialise the state, run the parent startup and attach a log weight.

        Args:
            abc_scenario: Scenario; ``prior_sample`` (if implemented) or
                ``dim`` is used to build a starting value when none is given.
            n: Forwarded to the parent ``startup``; also sets
                ``self.max_iter`` to ``n - 1``.
            initial_state: Starting state, or ``None`` to build one here.
            initial_extra: Extra container. NOTE(review): defaults to ``None``
                but ``random_key`` is read from it when a prior sample is
                drawn, so that path needs a real container.
            **kwargs: Forwarded to the parent ``startup``.

        Returns:
            The ``(state, extra)`` pair with ``log_weight`` set on the state.
        """
        if initial_state is None:
            if not is_implemented(abc_scenario.prior_sample):
                # No prior sampler available: start from the zero vector.
                init_vals = jnp.zeros(abc_scenario.dim)
            else:
                carried_key, sub_key = random.split(initial_extra.random_key)
                initial_extra.random_key = carried_key
                init_vals = abc_scenario.prior_sample(sub_key)
            initial_state = cdict(value=init_vals)

        self.max_iter = n - 1

        state, extra = super().startup(abc_scenario, n, initial_state,
                                       initial_extra, **kwargs)

        state.log_weight = self.log_weight(abc_scenario, state, extra)
        return state, extra
Esempio n. 18
0
    def update(self, scenario: Scenario, ensemble_state: cdict,
               extra: cdict) -> Tuple[cdict, cdict]:
        """Take one optimiser step on the ensemble and refresh its potentials.

        Args:
            scenario: Scenario providing ``potential_and_grad``.
            ensemble_state: Current ensemble; ``value``, ``potential`` and
                ``grad_potential`` are overwritten.
            extra: Extra container; ``iter``, ``random_key`` and ``opt_state``
                are updated.

        Returns:
            The ``(state, extra)`` pair returned by ``self.adapt``.
        """
        particle_count = ensemble_state.value.shape[0]
        extra.iter = extra.iter + 1

        # Key layout: [:n] potential evaluations, [-2] carried, [-1] batching.
        split_keys = random.split(extra.random_key, particle_count + 2)
        batch_inds = self.get_batch_inds(split_keys[-1])
        extra.random_key = split_keys[-2]
        particle_keys = split_keys[:particle_count]

        phi_hat = self.kernelised_grad_matrix(ensemble_state.value,
                                              ensemble_state.grad_potential,
                                              extra.parameters.kernel_params,
                                              batch_inds)

        # The optimiser is given the negated phi_hat.
        negated_phi = -phi_hat
        extra.opt_state = self.opt_update(extra.iter, negated_phi,
                                          extra.opt_state)
        ensemble_state.value = self.get_params(extra.opt_state)

        batched_potential_and_grad = vmap(scenario.potential_and_grad)
        ensemble_state.potential, ensemble_state.grad_potential \
            = batched_potential_and_grad(ensemble_state.value, particle_keys)

        adapted_state, adapted_extra = self.adapt(ensemble_state, extra)
        return adapted_state, adapted_extra
Esempio n. 19
0
 def always(self, abc_scenario: ABCScenario, reject_state: cdict,
            reject_extra: cdict) -> Tuple[cdict, cdict]:
     """Advance the random key and return the state unchanged.

     Args:
         abc_scenario: Not read in this body.
         reject_state: Returned as-is.
         reject_extra: Extra container whose ``random_key`` is replaced by
             the first half of a split of itself.

     Returns:
         ``(reject_state, reject_extra)``.
     """
     # The second half of the split is deliberately discarded.
     carried_key, _ = random.split(reject_extra.random_key)
     reject_extra.random_key = carried_key
     return reject_state, reject_extra