Example #1
0
    def update(self, scenario: Scenario, ensemble_state: cdict,
               extra: cdict) -> Tuple[cdict, cdict]:
        """Advance the particle ensemble by one SMC iteration.

        Increments ``extra.iter``, resamples when ``self.resample_criterion``
        returns true, moves every particle with ``self.forward_proposal``
        (vectorised with ``vmap`` over particles and keys), then hands the
        result to ``self.adapt``.

        Args:
            scenario: target scenario shared by all particles.
            ensemble_state: current ensemble; ``value`` has the particle
                index on axis 0.
            extra: ensemble-level auxiliary state (``iter``, ``random_key``,
                parameters). Mutated in place.

        Returns:
            ``(state, extra)`` as produced by ``self.adapt``.
        """
        extra.iter = extra.iter + 1
        num_particles = ensemble_state.value.shape[0]

        should_resample = self.resample_criterion(ensemble_state, extra)

        # One key per particle proposal, one for resampling, one carried forward.
        keys = random.split(extra.random_key, num_particles + 2)
        proposal_keys = keys[:num_particles]
        resample_key = keys[-2]
        extra.random_key = keys[-1]

        def _do_resample(state):
            return self.resample(state, resample_key)

        def _keep(state):
            return state

        resampled_state = cond(should_resample, _do_resample, _keep,
                               ensemble_state)

        # scenario and extra are broadcast; particles and keys are batched.
        batched_proposal = vmap(self.forward_proposal,
                                in_axes=(None, 0, None, 0))
        proposed_state = batched_proposal(scenario, resampled_state, extra,
                                          proposal_keys)

        # The same ``extra`` object is supplied as both previous and new extra.
        new_state, new_extra = self.adapt(resampled_state, extra,
                                          proposed_state, extra)
        return new_state, new_extra
Example #2
0
    def startup(self,
                abc_scenario: ABCScenario,
                n: int,
                initial_state: cdict,
                initial_extra: cdict,
                startup_correction: bool = True,
                **kwargs) -> Tuple[cdict, cdict]:
        """Initialise an ABC-MCMC chain.

        Delegates to the parent ``startup``, forwarding ``startup_correction``
        so callers can actually disable the correction's startup. It then
        fills in whichever of ``prior_potential``, ``simulated_data`` and
        ``distance`` the initial state does not already carry.

        Args:
            abc_scenario: ABC scenario providing ``prior_potential``,
                ``likelihood_sample`` and ``distance_function``.
            n: number of samples requested; passed through to the parent.
            initial_state: starting state; must expose ``value`` after the
                parent startup.
            initial_extra: auxiliary state holding ``random_key``; its key is
                advanced each time a fresh subkey is drawn.
            startup_correction: whether the parent should also run the
                correction's startup.
            **kwargs: forwarded to the parent ``startup``.

        Returns:
            The completed ``(initial_state, initial_extra)`` pair.
        """
        initial_state, initial_extra = super().startup(
            abc_scenario, n, initial_state, initial_extra,
            startup_correction=startup_correction, **kwargs)

        # Only evaluate the prior potential when the scenario implements one.
        if not hasattr(initial_state, 'prior_potential') and is_implemented(
                abc_scenario.prior_potential):
            initial_extra.random_key, subkey = random.split(
                initial_extra.random_key)
            initial_state.prior_potential = abc_scenario.prior_potential(
                initial_state.value, subkey)

        if not hasattr(initial_state, 'simulated_data'):
            initial_extra.random_key, subkey = random.split(
                initial_extra.random_key)
            initial_state.simulated_data = abc_scenario.likelihood_sample(
                initial_state.value, subkey)

        # distance_function takes only the simulated data (no random key).
        if not hasattr(initial_state, 'distance'):
            initial_state.distance = abc_scenario.distance_function(
                initial_state.simulated_data)
        return initial_state, initial_extra
Example #3
0
    def startup(self,
                scenario: Scenario,
                n: int,
                initial_state: Union[None, cdict],
                initial_extra: cdict,
                **kwargs) -> Tuple[cdict, cdict]:
        """Base startup: apply keyword overrides, validate ``max_iter``, seed extra.

        Each keyword argument whose name matches an attribute of the sampler
        overwrites that attribute. If the sampler has ``parameters``, a
        matching field there is overwritten too. ``max_iter`` must then be a
        Python ``int`` or an integer-dtype ``jnp.ndarray``. Any integer dtype
        is accepted, e.g. int64 when ``jax_enable_x64`` is on, not only int32.

        ``initial_extra.iter`` defaults to 0. When the sampler has
        ``parameters``, ``initial_extra.parameters`` is created if missing,
        and any field that is absent or ``None`` is filled from
        ``self.parameters``.

        Args:
            scenario: target scenario (unused here).
            n: number of samples requested (unused here).
            initial_state: returned unchanged.
            initial_extra: auxiliary state, mutated in place.
            **kwargs: attribute / parameter overrides.

        Returns:
            ``(initial_state, initial_extra)``.

        Raises:
            AttributeError: if ``max_iter`` is missing or not an integer.
        """
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)
            if hasattr(self, 'parameters') and hasattr(self.parameters, key):
                setattr(self.parameters, key, value)

        if not hasattr(self, 'max_iter'):
            raise AttributeError(self.__repr__() + ' max_iter must be int')
        max_iter = self.max_iter
        # Accept any integer array dtype; an exact 'int32' comparison rejects
        # int64 arrays produced when 64-bit mode is enabled.
        is_int_array = (isinstance(max_iter, jnp.ndarray)
                        and jnp.issubdtype(max_iter.dtype, jnp.integer))
        if not (isinstance(max_iter, int) or is_int_array):
            raise AttributeError(self.__repr__() + ' max_iter must be int')

        if not hasattr(initial_extra, 'iter'):
            initial_extra.iter = 0

        if hasattr(self, 'parameters'):
            if not hasattr(initial_extra, 'parameters'):
                initial_extra.parameters = cdict()
            for key, value in self.parameters.__dict__.items():
                if not hasattr(initial_extra.parameters, key) or getattr(initial_extra.parameters, key) is None:
                    setattr(initial_extra.parameters, key, value)
        return initial_state, initial_extra
Example #4
0
 def clean_chain(self, scenario: Scenario,
                 chain_ensemble_state: cdict) -> cdict:
     """Keep column 0 of the per-step temperature and ESS arrays.

     The last retained temperature is also written to
     ``scenario.temperature`` as a Python float. Assumes ``temperature`` and
     ``ess`` are 2-D with identical columns (presumably one per particle);
     TODO confirm against the caller.
     """
     temperatures = chain_ensemble_state.temperature[:, 0]
     chain_ensemble_state.temperature = temperatures
     scenario.temperature = float(temperatures[-1])
     chain_ensemble_state.ess = chain_ensemble_state.ess[:, 0]
     return chain_ensemble_state
Example #5
0
    def startup(self, scenario: Scenario, n: int, initial_state: cdict,
                initial_extra: cdict, **kwargs) -> Tuple[cdict, cdict]:
        """Prepare batching, potentials and optimiser state for the ensemble.

        Steps:
          * default ``ensemble_batchsize`` to ``n`` (in both ``self.parameters``
            and ``initial_extra.parameters``);
          * install ``self.get_batch_inds``, mapping a random key to an
            ``(n, ensemble_batchsize)`` index matrix;
          * drop ``stepsize`` from ``initial_extra.parameters`` (it is later
            read from ``self.parameters`` for the optimiser);
          * evaluate potential and gradient for every particle;
          * run ``self.adapt`` and build the optimiser from ``self.optimiser``.
        """
        initial_state, initial_extra = super().startup(scenario, n,
                                                       initial_state,
                                                       initial_extra, **kwargs)

        batchsize = self.parameters.ensemble_batchsize
        if batchsize is None:
            batchsize = n
            self.parameters.ensemble_batchsize = n
            initial_extra.parameters.ensemble_batchsize = n

        if batchsize == n:
            # Full batch: every row indexes all n particles.
            def get_batch_inds(_):
                return jnp.repeat(jnp.arange(n)[None], n, axis=0)
        else:
            # Mini-batch: random particle indices per row (random.choice
            # samples with replacement by default).
            def get_batch_inds(rk):
                return random.choice(
                    rk, n, shape=(n, self.parameters.ensemble_batchsize))
        self.get_batch_inds = get_batch_inds

        del initial_extra.parameters.stepsize

        keys = random.split(initial_extra.random_key, n + 1)
        initial_extra.random_key = keys[-1]

        batched_potential_and_grad = vmap(scenario.potential_and_grad)
        initial_state.potential, initial_state.grad_potential = \
            batched_potential_and_grad(initial_state.value, keys[:n])

        initial_state, initial_extra = self.adapt(initial_state, initial_extra)

        optimiser_fns = self.optimiser(
            step_size=self.parameters.stepsize,
            **initial_extra.parameters.optim_params)
        self.opt_init, self.opt_update, self.get_params = optimiser_fns
        initial_extra.opt_state = self.opt_init(initial_state.value)
        return initial_state, initial_extra
Example #6
0
 def startup(self, scenario: Scenario, n: int, initial_state: cdict,
             initial_extra: cdict, **kwargs) -> Tuple[cdict, cdict]:
     """Ensure the ensemble state has ``log_weight`` and ``ess`` fields.

     After the parent startup, a missing ``log_weight`` defaults to zeros
     (equal weights). A missing ``ess`` defaults to ``n`` for each of the
     ``n`` entries.
     """
     state, extra = super().startup(scenario, n, initial_state,
                                    initial_extra, **kwargs)
     if not hasattr(state, 'log_weight'):
         state.log_weight = jnp.zeros(n)
     if not hasattr(state, 'ess'):
         state.ess = jnp.zeros(n) + n
     return state, extra
Example #7
0
 def startup(self, scenario: Scenario, n: int, initial_state: cdict,
             initial_extra: cdict, **kwargs) -> Tuple[cdict, cdict]:
     """Run the parent startup, then evaluate the potential and its gradient.

     A fresh subkey is split from ``initial_extra.random_key`` for the
     potential evaluation, and the remaining key is stored back on the extra.
     """
     initial_state, initial_extra = super().startup(scenario, n,
                                                    initial_state,
                                                    initial_extra, **kwargs)
     carried_key, potential_key = random.split(initial_extra.random_key)
     initial_extra.random_key = carried_key
     potential, grad_potential = scenario.potential_and_grad(
         initial_state.value, potential_key)
     initial_state.potential = potential
     initial_state.grad_potential = grad_potential
     return initial_state, initial_extra
Example #8
0
 def adapt(self, previous_ensemble_state: cdict, previous_extra: cdict,
           new_ensemble_state: cdict,
           new_extra: cdict) -> Tuple[cdict, cdict]:
     """Choose the next temperature and update the ensemble weights.

     The new temperature (from ``self.next_temperature``) is broadcast to all
     particles. Log-weights accumulate the incremental weight given by
     ``self.log_weight``. ``ess`` is the exponentiated
     ``log_ess_log_weight`` of those log-weights, broadcast to every particle.
     """
     n = new_ensemble_state.value.shape[0]
     ones = jnp.ones(n)

     temperature = self.next_temperature(new_ensemble_state, new_extra)
     new_ensemble_state.temperature = ones * temperature

     previous_log_weight = previous_ensemble_state.log_weight
     incremental = self.log_weight(previous_ensemble_state, previous_extra,
                                   new_ensemble_state, new_extra)
     log_weight = previous_log_weight + incremental
     new_ensemble_state.log_weight = log_weight

     new_ensemble_state.ess = ones * jnp.exp(log_ess_log_weight(log_weight))
     return new_ensemble_state, new_extra
Example #9
0
 def startup(self, scenario: Scenario, n: int, initial_state: cdict,
             initial_extra: cdict, **kwargs) -> Tuple[cdict, cdict]:
     """Run the parent startup, evaluate potential/gradient, and seed momenta.

     Momenta are reset to zeros of length ``scenario.dim`` whenever they are
     absent or their last dimension does not equal ``scenario.dim``.
     """
     initial_state, initial_extra = super().startup(scenario, n,
                                                    initial_state,
                                                    initial_extra, **kwargs)
     carried_key, scen_key = random.split(initial_extra.random_key)
     initial_extra.random_key = carried_key
     (initial_state.potential,
      initial_state.grad_potential) = scenario.potential_and_grad(
          initial_state.value, scen_key)

     momenta_ok = (hasattr(initial_state, 'momenta')
                   and initial_state.momenta.shape[-1] == scenario.dim)
     if not momenta_ok:
         initial_state.momenta = jnp.zeros(scenario.dim)
     return initial_state, initial_extra
Example #10
0
    def always(self, scenario: Scenario, reject_state: cdict,
               reject_extra: cdict) -> Tuple[cdict, cdict]:
        """Flip the momenta, then refresh them with an exact OU update.

        The refresh is applied whether or not the preceding leapfrog step was
        accepted:

            p <- p * exp(-friction * stepsize)
                 + sqrt(1 - exp(-2 * friction * stepsize)) * N(0, I_dim)

        ``reject_state`` and ``reject_extra`` are mutated in place and returned.
        """
        stepsize = reject_extra.parameters.stepsize
        friction = reject_extra.parameters.friction

        decay = jnp.exp(- friction * stepsize)
        noise_scale = jnp.sqrt(1 - jnp.exp(- 2 * friction * stepsize))

        reject_state.momenta = reject_state.momenta * -1

        reject_extra.random_key, subkey = random.split(reject_extra.random_key)
        noise = random.normal(subkey, (scenario.dim,))
        reject_state.momenta = reject_state.momenta * decay + noise_scale * noise
        return reject_state, reject_extra
Example #11
0
    def forward_proposal(self, scenario: Scenario, state: cdict, extra: cdict,
                         random_key: jnp.ndarray) -> cdict:
        """Move one particle with ``self.parameters.mcmc_steps`` MCMC updates.

        The chain is started via ``self.mcmc_sampler.startup`` and iterated with
        ``scan``. It is reduced with ``self.clean_mcmc_chain``, and the
        resulting state is annotated with ``prior_potential`` and with
        ``likelihood_potential = (potential - prior_potential) /
        scenario.temperature``.

        Args:
            scenario: tempered target scenario.
            state: single-particle state.
            extra: auxiliary state; its ``random_key`` is overwritten with
                ``random_key``.
            random_key: key for this particle's chain.

        Returns:
            The advanced single-particle state (the advanced extra is dropped).
        """
        def kernel_step(carry: Tuple[cdict, cdict],
                        _: None) -> Tuple[Tuple[cdict, cdict], Tuple[cdict, cdict]]:
            stepped = self.mcmc_sampler.update(scenario, *carry)
            return stepped, stepped

        extra.random_key = random_key

        first_state, first_extra = self.mcmc_sampler.startup(
            scenario, extra.parameters.mcmc_steps, state, extra)

        # scan length must be static, hence self.parameters rather than extra.
        stacked = scan(kernel_step, (first_state, first_extra), None,
                       length=self.parameters.mcmc_steps)[1]

        advanced_state, advanced_extra = self.clean_mcmc_chain(stacked[0],
                                                               stacked[1])

        prior_potential = scenario.prior_potential(advanced_state.value,
                                                   advanced_extra.random_key)
        advanced_state.prior_potential = prior_potential
        advanced_state.likelihood_potential = (
            (advanced_state.potential - prior_potential) / scenario.temperature)
        return advanced_state
Example #12
0
def backward_simulation_full(ssm_scenario: StateSpaceModel,
                             marginal_particles: cdict,
                             n_samps: int,
                             random_key: jnp.ndarray) -> cdict:
    """Draw ``n_samps`` trajectories by backward simulation over filter output.

    Final-time particles are drawn from the last filtering weights. The scan
    then walks backwards from time ``T - 2`` to ``0``, calling
    ``full_resampling`` to pick each earlier particle given the later one.

    Args:
        ssm_scenario: state-space model passed to ``full_resampling``.
        marginal_particles: filter output with ``value`` of shape
            ``(T, n_pf, d)`` and per-time ``t`` and ``log_weight``.
        n_samps: number of trajectories to draw.
        random_key: PRNG key; split into one key per time step.

    Returns:
        A copy of ``marginal_particles`` with ``value`` replaced by the
        ``(T, n_samps, d)`` sampled trajectories, ``num_transition_evals``
        added (0 at the first time step, ``n_pf * n_samps`` after), and
        ``log_weight`` removed.
    """
    values = marginal_particles.value
    times = marginal_particles.t
    log_weights = marginal_particles.log_weight

    T, n_pf, _ = values.shape

    keys = random.split(random_key, T)

    last_inds = random.categorical(keys[-1], log_weights[-1],
                                   shape=(n_samps,))
    last_vals = values[-1, last_inds]

    def step_back(x_next: jnp.ndarray, t_ind: int):
        x_curr = full_resampling(ssm_scenario, values[t_ind], times[t_ind],
                                 x_next, times[t_ind + 1],
                                 log_weights[t_ind], keys[t_ind])
        return x_curr, x_curr

    reverse_inds = jnp.arange(T - 2, -1, -1)
    _, backward_vals = scan(step_back, last_vals, reverse_inds)

    out_samps = marginal_particles.copy()
    # scan emitted times T-2..0; reverse to chronological order, then append T-1.
    out_samps.value = jnp.vstack([backward_vals[::-1],
                                  last_vals[jnp.newaxis]])
    out_samps.num_transition_evals = jnp.append(
        0, jnp.ones(T - 1) * n_pf * n_samps)
    del out_samps.log_weight
    return out_samps
Example #13
0
    def forward_proposal_non_zero_weight(self,
                                         abc_scenario: ABCScenario,
                                         state: cdict,
                                         extra: cdict,
                                         random_key: jnp.ndarray) -> cdict:
        """Move one particle with ``self.parameters.mcmc_steps`` ABC-MCMC updates.

        Per the name, this is presumably applied only to particles with
        non-zero weight (confirm at the call site). The chain is started via
        ``self.mcmc_sampler.startup``, iterated with ``scan`` and reduced with
        ``self.clean_mcmc_chain``.

        Args:
            abc_scenario: ABC scenario.
            state: single-particle state.
            extra: auxiliary state; its ``random_key`` is overwritten.
            random_key: key for this particle's chain.

        Returns:
            The advanced single-particle state.
        """
        extra.random_key = random_key

        def kernel_step(carry: Tuple[cdict, cdict],
                        _: None) -> Tuple[Tuple[cdict, cdict], Tuple[cdict, cdict]]:
            stepped = self.mcmc_sampler.update(abc_scenario, *carry)
            return stepped, stepped

        first_state, first_extra = self.mcmc_sampler.startup(
            abc_scenario, extra.parameters.mcmc_steps, state, extra)

        stacked = scan(kernel_step, (first_state, first_extra), None,
                       length=self.parameters.mcmc_steps)[1]

        advanced_state, _ = self.clean_mcmc_chain(stacked[0], stacked[1])
        return advanced_state
Example #14
0
 def clean_chain_ar(self, abc_scenario: ABCScenario, chain_state: cdict):
     """Set the ABC threshold to the ``acceptance_rate`` quantile of distances.

     Samples with distance strictly below the threshold get log-weight 0; all
     others get ``-inf``. The threshold is also stored (as a float) on
     ``self.parameters.threshold``.
     """
     distances = chain_state.distance
     threshold = jnp.quantile(distances, self.parameters.acceptance_rate)
     self.parameters.threshold = float(threshold)
     accepted = distances < threshold
     chain_state.log_weight = jnp.where(accepted, 0., -jnp.inf)
     return chain_state
Example #15
0
    def particle_filter_body(samps_previous: cdict,
                             iter_ind: int) -> Tuple[cdict, cdict]:
        """One particle-filter step, shaped for use as a ``scan`` body.

        Resamples when the previous ESS falls below ``ess_threshold * n``.
        It then propagates and weights the particles against observation
        ``y[iter_ind]`` at time ``t[iter_ind]``, and returns the new samples
        as both carry and output.

        Reads ``n``, ``ess_threshold``, ``int_rand_keys``, ``_resample``,
        ``particle_filter``, ``ssm_scenario``, ``y`` and ``t`` from the
        enclosing scope.

        Args:
            samps_previous: samples from the previous step (``value``,
                ``log_weight``, ``ess``, ``t``).
            iter_ind: index of the current observation.

        Returns:
            ``(samps_new, samps_new)``.
        """
        x_previous = samps_previous.value
        log_weight_previous = samps_previous.log_weight

        # Give resampling and propagation independent keys: a JAX key must
        # not be both consumed directly and split, or the draws correlate.
        resample_key, propose_key = random.split(int_rand_keys[iter_ind])

        resample_bool = samps_previous.ess < (ess_threshold * n)
        x_res = cond(resample_bool, _resample, lambda tup: tup[0],
                     (x_previous, log_weight_previous, resample_key))
        # Resampled particles restart with equal (zero) log-weights.
        log_weight_res = jnp.where(resample_bool, jnp.zeros(n),
                                   log_weight_previous)

        split_keys = random.split(propose_key, len(x_previous))

        x_new, log_weight_new = particle_filter.propose_and_intermediate_weight_vectorised(
            ssm_scenario, x_res, samps_previous.t, y[iter_ind], t[iter_ind],
            split_keys)

        log_weight_new = log_weight_res + log_weight_new

        samps_new = samps_previous.copy()
        samps_new.value = x_new
        samps_new.log_weight = log_weight_new
        samps_new.y = y[iter_ind]
        samps_new.t = t[iter_ind]
        samps_new.ess = ess_log_weight(log_weight_new)
        return samps_new, samps_new
Example #16
0
    def startup(self,
                scenario: Scenario,
                n: int,
                initial_state: cdict,
                initial_extra: cdict,
                startup_correction: bool = True,
                **kwargs) -> Tuple[cdict, cdict]:
        """Initialise an MCMC chain of ``n`` samples.

        When no initial state is given, the start value is drawn from
        ``scenario.prior_sample`` if implemented, otherwise it is zeros of
        length ``scenario.dim``. ``max_iter`` is set to ``n - 1``. An optional
        ``correction`` keyword replaces ``self.correction`` (and is removed
        from the kwargs passed on). The correction is normalised with
        ``check_correction``. After the parent startup, the correction's own
        startup runs if ``startup_correction`` is true.

        Returns:
            ``(initial_state, initial_extra)``.
        """
        if initial_state is None:
            if not is_implemented(scenario.prior_sample):
                init_vals = jnp.zeros(scenario.dim)
            else:
                initial_extra.random_key, sub_key = random.split(
                    initial_extra.random_key)
                init_vals = scenario.prior_sample(sub_key)
            initial_state = cdict(value=init_vals)

        # Presumably the initial state counts as the first of the n samples.
        self.max_iter = n - 1

        if 'correction' in kwargs:
            self.correction = kwargs.pop('correction')
        self.correction = check_correction(self.correction)

        initial_state, initial_extra = super().startup(
            scenario, n, initial_state, initial_extra, **kwargs)

        if startup_correction:
            initial_state, initial_extra = self.correction.startup(
                scenario, self, n, initial_state, initial_extra, **kwargs)
        return initial_state, initial_extra
Example #17
0
 def proposal(self, abc_scenario: ABCScenario, reject_state: cdict,
              reject_extra: cdict) -> Tuple[cdict, cdict]:
     """Gaussian random-walk proposal for ABC-MCMC.

     Perturbs ``value`` by ``sqrt(stepsize) * N(0, I_dim)``. It then
     evaluates the prior potential, simulates data at the new value, and
     computes its distance. Works on copies; the inputs are not modified.
     """
     keys = random.split(reject_extra.random_key, 4)
     proposed_state = reject_state.copy()
     proposed_extra = reject_extra.copy()
     proposed_extra.random_key = keys[0]

     scale = jnp.sqrt(reject_extra.parameters.stepsize)
     noise = random.normal(keys[1], (abc_scenario.dim, ))
     proposed_state.value = reject_state.value + scale * noise

     proposed_state.prior_potential = abc_scenario.prior_potential(
         proposed_state.value, keys[2])
     proposed_state.simulated_data = abc_scenario.likelihood_sample(
         proposed_state.value, keys[3])
     proposed_state.distance = abc_scenario.distance_function(
         proposed_state.simulated_data)
     return proposed_state, proposed_extra
Example #18
0
 def startup(self, abc_scenario: ABCScenario, sampler: MCMCSampler, n: int,
             initial_state: cdict, initial_extra: cdict,
             **kwargs) -> Tuple[cdict, cdict]:
     """Seed adaptive stepsize / threshold state for an ABC correction.

     Defaults: ``stepsize`` 1.0 and ``threshold`` 50. when the sampler leaves
     them as ``None``. The stepsize is broadcast to a length-``dim`` vector
     and mirrored onto the state. The posterior mean starts at the initial
     value, and the diagonal covariance guess is
     ``stepsize * dim / 2.38**2``.
     """
     initial_state, initial_extra = super().startup(abc_scenario, sampler,
                                                    n, initial_state,
                                                    initial_extra, **kwargs)
     extra_params = initial_extra.parameters

     if sampler.parameters.stepsize is None:
         extra_params.stepsize = 1.0
     if sampler.parameters.threshold is None:
         extra_params.threshold = 50.
         # NOTE(review): state.threshold is only set on this default branch;
         # confirm a user-supplied threshold reaches the state elsewhere.
         initial_state.threshold = extra_params.threshold

     diag_stepsize = jnp.ones(abc_scenario.dim) * extra_params.stepsize
     extra_params.stepsize = diag_stepsize
     initial_state.stepsize = diag_stepsize
     initial_extra.post_mean = initial_state.value
     # Inverse of the 2.38**2 / dim random-walk scaling rule.
     initial_extra.diag_post_cov = diag_stepsize * abc_scenario.dim / 2.38**2
     return initial_state, initial_extra
Example #19
0
 def startup(self, scenario: Scenario, sampler: MCMCSampler, n: int,
             initial_state: cdict, initial_extra: cdict,
             **kwargs) -> Tuple[cdict, cdict]:
     """Run the parent correction startup and initialise ``alpha`` to 1."""
     state, extra = super().startup(scenario, sampler, n, initial_state,
                                    initial_extra, **kwargs)
     state.alpha = 1.
     return state, extra
Example #20
0
 def proposal(self, abc_scenario: ABCScenario, reject_state: cdict,
              reject_extra: cdict) -> Tuple[cdict, cdict]:
     """Independent (importance) proposal for ABC.

     Draws a fresh value from ``self.importance_proposal``. It then evaluates
     the prior potential, simulates data, computes the distance, and assigns
     ``self.log_weight``. Works on copies; the inputs are not modified.
     """
     keys = random.split(reject_extra.random_key, 4)
     proposed_state = reject_state.copy()
     proposed_extra = reject_extra.copy()
     proposed_extra.random_key = keys[0]

     value = self.importance_proposal(abc_scenario, keys[1])
     proposed_state.value = value
     proposed_state.prior_potential = abc_scenario.prior_potential(value,
                                                                   keys[2])
     proposed_state.simulated_data = abc_scenario.likelihood_sample(value,
                                                                    keys[3])
     proposed_state.distance = abc_scenario.distance_function(
         proposed_state.simulated_data)
     proposed_state.log_weight = self.log_weight(abc_scenario,
                                                 proposed_state,
                                                 proposed_extra)
     return proposed_state, proposed_extra
Example #21
0
    def startup(self, scenario: Scenario, n: int, initial_state: cdict,
                initial_extra: cdict, **kwargs) -> Tuple[cdict, cdict]:
        """Initialise a likelihood-tempered ensemble at temperature zero.

        Evaluates prior and likelihood potentials per particle, sets the
        tempered potential equal to the prior potential, and sets zero
        temperatures and log-weights with ESS ``n``. Also sets
        ``scenario.temperature`` to 0.

        Raises:
            TypeError: if ``scenario`` has no ``prior_sample`` attribute.
        """
        if not hasattr(scenario, 'prior_sample'):
            raise TypeError(
                f'Likelihood tempering requires scenario {scenario.name} to have prior_sample implemented'
            )

        initial_state, initial_extra = super().startup(scenario, n,
                                                       initial_state,
                                                       initial_extra, **kwargs)

        keys = random.split(initial_extra.random_key, 2 * n + 1)
        prior_keys = keys[:n]
        likelihood_keys = keys[n:(2 * n)]
        initial_extra.random_key = keys[-1]

        values = initial_state.value
        initial_state.prior_potential = vmap(scenario.prior_potential)(
            values, prior_keys)
        initial_state.likelihood_potential = vmap(
            scenario.likelihood_potential)(values, likelihood_keys)

        # At temperature zero only the prior contributes to the potential.
        initial_state.potential = initial_state.prior_potential
        initial_state.temperature = jnp.zeros(n)
        initial_state.log_weight = jnp.zeros(n)
        initial_state.ess = jnp.zeros(n) + n

        scenario.temperature = 0.

        return initial_state, initial_extra
Example #22
0
    def proposal(self, scenario: Scenario, reject_state: cdict,
                 reject_extra: cdict) -> Tuple[cdict, cdict]:
        """Gaussian random-walk Metropolis proposal.

        Proposes ``value + sqrt(stepsize) * N(0, I_dim)`` on a copy of the
        state and evaluates its potential. ``reject_extra`` itself (not a
        copy) has its random key advanced and is returned.
        """
        keys = random.split(reject_extra.random_key, 3)
        reject_extra.random_key = keys[0]

        scale = jnp.sqrt(reject_extra.parameters.stepsize)
        proposed_state = reject_state.copy()
        proposed_state.value = reject_state.value + scale * random.normal(
            keys[1], (scenario.dim, ))
        proposed_state.potential = scenario.potential(proposed_state.value,
                                                      keys[2])
        return proposed_state, reject_extra
Example #23
0
    def adapt(self,
              previous_ensemble_state: cdict,
              previous_extra: cdict,
              new_ensemble_state: cdict,
              new_extra: cdict) -> Tuple[cdict, cdict]:
        """Update the ABC threshold, weights, ESS and mean acceptance.

        The next threshold (``self.next_threshold``) is broadcast to all
        particles and stored in ``new_extra.parameters``. Log-weights come
        from ``self.log_weight`` (not accumulated here). ``alpha_mean``
        averages ``alpha`` over particles whose previous log-weight was
        finite. MCMC parameters are then adapted by ``self.adapt_mcmc_params``.
        """
        n = new_ensemble_state.value.shape[0]
        ones = jnp.ones(n)

        threshold = self.next_threshold(new_ensemble_state, new_extra)
        new_ensemble_state.threshold = ones * threshold
        new_extra.parameters.threshold = threshold

        log_weight = self.log_weight(previous_ensemble_state, previous_extra,
                                     new_ensemble_state, new_extra)
        new_ensemble_state.log_weight = log_weight
        new_ensemble_state.ess = ones * ess_log_weight(log_weight)

        alive = previous_ensemble_state.log_weight > -jnp.inf
        new_extra.alpha_mean = (new_ensemble_state.alpha * alive).sum() / alive.sum()

        adapted_state, adapted_extra = self.adapt_mcmc_params(
            previous_ensemble_state, previous_extra,
            new_ensemble_state, new_extra)
        return adapted_state, adapted_extra
Example #24
0
    def proposal(self, scenario: Scenario, reject_state: cdict,
                 reject_extra: cdict) -> Tuple[cdict, cdict]:
        """Hamiltonian Monte Carlo proposal.

        Draws standard-normal momenta (written onto ``reject_state``), runs
        ``self.parameters.leapfrog_steps`` leapfrog steps via
        ``utils.leapfrog`` and proposes the trajectory's final state.
        """
        n_leapfrog = self.parameters.leapfrog_steps
        keys = random.split(reject_extra.random_key, n_leapfrog + 2)
        reject_extra.random_key = keys[0]
        momenta_key = keys[1]
        leapfrog_keys = keys[2:]

        reject_state.momenta = random.normal(momenta_key, (scenario.dim, ))

        trajectory = utils.leapfrog(scenario.potential_and_grad,
                                    reject_state,
                                    reject_extra.parameters.stepsize,
                                    leapfrog_keys)

        # No momentum reversal: the momentum target is symmetric and momenta
        # are resampled at the start of the next proposal anyway.
        return trajectory[-1], reject_extra
Example #25
0
    def leapfrog_step(init_state: cdict, i: int):
        """One leapfrog step, written as a ``scan`` body.

        Half-step momenta, full-step position, new potential/gradient (using
        ``random_keys[i]``), then the second half-step momenta. Reads
        ``stepsize``, ``potential_and_grad`` and ``random_keys`` from the
        enclosing scope.

        Returns:
            ``(new_state, recorded)``. ``recorded`` matches ``new_state``
            except that ``momenta`` stacks the half-step and full-step momenta.
        """
        half_momenta = init_state.momenta - stepsize / 2. * init_state.grad_potential

        new_state = init_state.copy()
        new_state.value = init_state.value + stepsize * half_momenta
        new_state.potential, new_state.grad_potential = potential_and_grad(
            new_state.value, random_keys[i])
        new_state.momenta = half_momenta - stepsize / 2. * new_state.grad_potential

        recorded = new_state.copy()
        recorded.momenta = jnp.vstack([half_momenta, new_state.momenta])
        return new_state, recorded
Example #26
0
    def startup(self, scenario: Scenario, n: int, initial_state: cdict,
                initial_extra: cdict, **kwargs) -> Tuple[cdict, cdict]:
        """Start a tempered SMC run at the first adaptively chosen temperature.

        After the parent startup, the first temperature from
        ``self.next_temperature`` is applied to the scenario and the particles.
        Potentials are set to ``prior + temp * likelihood``, log-weights to
        ``-temp * likelihood`` and ESS to the exponentiated
        ``log_ess_log_weight`` (repeated ``n`` times). The inner MCMC sampler
        is started once per particle with ``vmap``.

        Returns:
            The per-particle states and the first particle's extra.
        """
        self.mcmc_sampler.correction = check_correction(
            self.mcmc_sampler.correction)

        initial_state, initial_extra = super().startup(scenario, n,
                                                       initial_state,
                                                       initial_extra, **kwargs)

        temp = self.next_temperature(initial_state, initial_extra)
        scenario.temperature = temp
        initial_state.temperature = initial_state.temperature + temp

        likelihood_potential = initial_state.likelihood_potential
        initial_state.potential = initial_state.prior_potential + temp * likelihood_potential
        log_weight = -temp * likelihood_potential
        initial_state.log_weight = log_weight
        initial_state.ess = jnp.repeat(
            jnp.exp(log_ess_log_weight(log_weight)), n)

        def mcmc_startup(state):
            return self.mcmc_sampler.startup(scenario, n, state, initial_extra)

        initial_state, batched_extra = vmap(mcmc_startup)(initial_state)
        # vmap stacks the extra along axis 0; keep the first particle's copy.
        return initial_state, batched_extra[0]
Example #27
0
    def startup(self,
                abc_scenario: ABCScenario,
                n: int,
                initial_state: cdict = None,
                initial_extra: cdict = None,
                **kwargs) -> Tuple[cdict, cdict]:
        """Initialise an ABC importance sampler and weight the initial state.

        Without an initial state, the start value comes from
        ``abc_scenario.prior_sample`` if implemented, otherwise it is zeros of
        length ``abc_scenario.dim``. ``max_iter`` is set to ``n - 1``, and the
        initial state's ``log_weight`` is computed by ``self.log_weight``.

        NOTE(review): ``initial_extra`` defaults to ``None`` but is
        dereferenced (``initial_extra.random_key``) when sampling from the
        prior. Callers presumably always pass a cdict; confirm.
        """
        if initial_state is None:
            if not is_implemented(abc_scenario.prior_sample):
                init_vals = jnp.zeros(abc_scenario.dim)
            else:
                initial_extra.random_key, sub_key = random.split(
                    initial_extra.random_key)
                init_vals = abc_scenario.prior_sample(sub_key)
            initial_state = cdict(value=init_vals)

        self.max_iter = n - 1

        initial_state, initial_extra = super().startup(
            abc_scenario, n, initial_state, initial_extra, **kwargs)

        initial_state.log_weight = self.log_weight(abc_scenario,
                                                   initial_state,
                                                   initial_extra)
        return initial_state, initial_extra
Example #28
0
    def update(self, scenario: Scenario, ensemble_state: cdict,
               extra: cdict) -> Tuple[cdict, cdict]:
        """One kernelised-gradient ensemble step (presumably SVGD-style).

        Builds ``phi_hat`` with ``self.kernelised_grad_matrix`` on a batch of
        indices. The optimiser is stepped with ``-phi_hat``, which moves
        particles along ``phi_hat`` if the optimiser performs descent on the
        given gradient. Potentials and gradients are then re-evaluated, and
        ``self.adapt`` is applied.
        """
        extra.iter = extra.iter + 1
        n = ensemble_state.value.shape[0]

        keys = random.split(extra.random_key, n + 2)
        potential_keys = keys[:n]
        batch_inds = self.get_batch_inds(keys[-1])
        extra.random_key = keys[-2]

        phi_hat = self.kernelised_grad_matrix(ensemble_state.value,
                                              ensemble_state.grad_potential,
                                              extra.parameters.kernel_params,
                                              batch_inds)

        extra.opt_state = self.opt_update(extra.iter, -phi_hat,
                                          extra.opt_state)
        ensemble_state.value = self.get_params(extra.opt_state)

        potentials, grads = vmap(scenario.potential_and_grad)(
            ensemble_state.value, potential_keys)
        ensemble_state.potential = potentials
        ensemble_state.grad_potential = grads

        ensemble_state, extra = self.adapt(ensemble_state, extra)
        return ensemble_state, extra
Example #29
0
    def proposal(self, scenario: Scenario, reject_state: cdict,
                 reject_extra: cdict) -> Tuple[cdict, cdict]:
        """Leapfrog from the current momenta, then negate the momenta.

        Runs ``self.parameters.leapfrog_steps`` steps via ``utils.leapfrog``
        and takes the final state. Flipping the momenta makes the proposal
        reversible (the standard HMC momentum flip).
        """
        keys = random.split(reject_extra.random_key,
                            self.parameters.leapfrog_steps + 1)
        reject_extra.random_key = keys[0]

        trajectory = utils.leapfrog(scenario.potential_and_grad,
                                    reject_state,
                                    reject_extra.parameters.stepsize,
                                    keys[1:])
        proposed_state = trajectory[-1]
        proposed_state.momenta = proposed_state.momenta * -1
        return proposed_state, reject_extra
Example #30
0
    def startup(self, scenario: Scenario, n: int, initial_state: cdict,
                initial_extra: cdict, **kwargs) -> Tuple[cdict, cdict]:
        """Create an ``n``-particle ensemble if none is given, then defer to parent.

        Particles are drawn with ``vmap(scenario.prior_sample)`` when the prior
        sampler is implemented, otherwise as standard normals of shape
        ``(n, scenario.dim)``.
        """
        if initial_state is None:
            carried_key, sub_key = random.split(initial_extra.random_key)
            initial_extra.random_key = carried_key
            if not is_implemented(scenario.prior_sample):
                init_vals = random.normal(sub_key, shape=(n, scenario.dim))
            else:
                particle_keys = random.split(sub_key, n)
                init_vals = vmap(scenario.prior_sample)(particle_keys)
            initial_state = cdict(value=init_vals)

        initial_state, initial_extra = super().startup(
            scenario, n, initial_state, initial_extra, **kwargs)
        return initial_state, initial_extra