예제 #1
0
def backward_simulation_full(ssm_scenario: StateSpaceModel,
                             marginal_particles: cdict,
                             n_samps: int,
                             random_key: jnp.ndarray) -> cdict:
    """Simulate smoothing trajectories backwards through stored filter marginals.

    ``n_samps`` particles are drawn from the final time step according to its
    log-weights. The sweep then steps backwards from time T-2 to 0, and at each
    step ``full_resampling`` chooses predecessors for the current trajectories.

    Args:
        ssm_scenario: Model object forwarded to ``full_resampling``.
        marginal_particles: Filter output; ``value`` must be 3-D (unpacked as
            T, n_pf, d), with per-time ``t`` and ``log_weight``.
        n_samps: Number of trajectories to simulate.
        random_key: JAX PRNG key.

    Returns:
        A copy of ``marginal_particles`` with these changes:
        - ``value`` is replaced by the simulated trajectories.
        - ``num_transition_evals`` is added.
        - ``log_weight`` is removed.
    """
    particle_vals = marginal_particles.value
    time_points = marginal_particles.t
    log_weights = marginal_particles.log_weight

    num_times, num_particles, _ = particle_vals.shape

    # One key per time step: the last seeds the terminal draw, the others are
    # consumed by the backward sweep (index ``time_ind`` for step ``time_ind``).
    step_keys = random.split(random_key, num_times)
    terminal_inds = random.categorical(step_keys[-1], log_weights[-1],
                                       shape=(n_samps,))
    terminal_vals = particle_vals[-1, terminal_inds]

    def backward_step(later_vals, time_ind):
        earlier_vals = full_resampling(ssm_scenario,
                                       particle_vals[time_ind], time_points[time_ind],
                                       later_vals, time_points[time_ind + 1],
                                       log_weights[time_ind], step_keys[time_ind])
        return earlier_vals, earlier_vals

    # Indices run T-2, ..., 0, so scan stacks its outputs in reverse time order.
    _, reversed_trajectory = scan(backward_step, terminal_vals,
                                  jnp.arange(num_times - 2, -1, -1))

    smoothed = marginal_particles.copy()
    smoothed.value = jnp.vstack([reversed_trajectory[::-1],
                                 terminal_vals[jnp.newaxis]])
    # Record n_pf * n_samps transition evaluations for each backward step and
    # none for the terminal draw.
    smoothed.num_transition_evals = jnp.append(
        0, jnp.ones(num_times - 1) * num_particles * n_samps)
    del smoothed.log_weight
    return smoothed
예제 #2
0
 def proposal(self, abc_scenario: ABCScenario, reject_state: cdict,
              reject_extra: cdict) -> Tuple[cdict, cdict]:
     """Gaussian random-walk proposal for ABC.

     The new value is the old one perturbed by ``sqrt(stepsize)`` times a
     standard-normal vector of length ``abc_scenario.dim``. The proposed state
     then gets:
     - its prior potential,
     - a simulated data set,
     - the distance of that data set.

     All writes go to copies of ``reject_state``/``reject_extra``. The returned
     extra carries an advanced random key.
     """
     new_state = reject_state.copy()
     new_extra = reject_extra.copy()
     scale = jnp.sqrt(reject_extra.parameters.stepsize)

     # Four keys: one to carry forward, three for the stochastic steps below.
     new_extra.random_key, step_key, prior_key, sim_key = random.split(
         reject_extra.random_key, 4)

     noise = random.normal(step_key, (abc_scenario.dim, ))
     new_state.value = reject_state.value + scale * noise
     new_state.prior_potential = abc_scenario.prior_potential(
         new_state.value, prior_key)
     new_state.simulated_data = abc_scenario.likelihood_sample(
         new_state.value, sim_key)
     new_state.distance = abc_scenario.distance_function(
         new_state.simulated_data)
     return new_state, new_extra
예제 #3
0
    def particle_filter_body(samps_previous: cdict,
                             iter_ind: int) -> Tuple[cdict, cdict]:
        """Advance the particle filter by one observation (scan body).

        Particles are resampled when the previous ESS is below
        ``ess_threshold * n``. They are then propagated and reweighted against
        ``y[iter_ind]`` at time ``t[iter_ind]``.

        The following are read from the enclosing scope:
        ``int_rand_keys``, ``ess_threshold``, ``n``, ``particle_filter``,
        ``ssm_scenario``, ``y`` and ``t``.

        Returns:
            The new particle cdict twice: once as the carry, once as the
            per-step output.
        """
        x_previous = samps_previous.value
        log_weight_previous = samps_previous.log_weight

        # Split the step key so that the resampling draw and the per-particle
        # proposal keys are independent. Previously the same key was fed to
        # `_resample` and to `random.split`, which correlates the two.
        resample_key, propose_key = random.split(int_rand_keys[iter_ind])

        resample_bool = samps_previous.ess < (ess_threshold * n)
        x_res = cond(resample_bool, _resample, lambda tup: tup[0],
                     (x_previous, log_weight_previous, resample_key))
        # After resampling, the log-weights are reset to zero (equal weights).
        log_weight_res = jnp.where(resample_bool, jnp.zeros(n),
                                   log_weight_previous)

        split_keys = random.split(propose_key, len(x_previous))

        x_new, log_weight_new = particle_filter.propose_and_intermediate_weight_vectorised(
            ssm_scenario, x_res, samps_previous.t, y[iter_ind], t[iter_ind],
            split_keys)

        log_weight_new = log_weight_res + log_weight_new

        samps_new = samps_previous.copy()
        samps_new.value = x_new
        samps_new.log_weight = log_weight_new
        samps_new.y = y[iter_ind]
        samps_new.t = t[iter_ind]
        samps_new.ess = ess_log_weight(log_weight_new)
        return samps_new, samps_new
예제 #4
0
 def proposal(self, abc_scenario: ABCScenario, reject_state: cdict,
              reject_extra: cdict) -> Tuple[cdict, cdict]:
     """Importance-sampling proposal for ABC.

     A fresh value is drawn via ``self.importance_proposal``. The proposed state
     then gets:
     - its prior potential,
     - a simulated data set,
     - the distance of that data set,
     - its log-weight from ``self.log_weight``.

     All writes go to copies of ``reject_state``/``reject_extra``. The returned
     extra carries an advanced random key.
     """
     state = reject_state.copy()
     extra = reject_extra.copy()
     extra.random_key, draw_key, prior_key, sim_key = random.split(
         reject_extra.random_key, 4)

     state.value = self.importance_proposal(abc_scenario, draw_key)
     state.prior_potential = abc_scenario.prior_potential(state.value,
                                                          prior_key)
     state.simulated_data = abc_scenario.likelihood_sample(state.value,
                                                           sim_key)
     state.distance = abc_scenario.distance_function(state.simulated_data)
     # Computed last; self.log_weight presumably reads the fields set above.
     state.log_weight = self.log_weight(abc_scenario, state, extra)
     return state, extra
예제 #5
0
    def leapfrog_step(init_state: cdict, i: int):
        """Perform one leapfrog step: half momentum, full position, half momentum.

        Reads ``stepsize``, ``potential_and_grad`` and ``random_keys`` from the
        enclosing scope; ``i`` selects the key passed to ``potential_and_grad``.

        Returns:
            A pair of states:
            - the updated state;
            - a copy of it whose ``momenta`` stacks the half-step and
              full-step momenta.
        """
        half_step = stepsize / 2.
        momenta_half = init_state.momenta - half_step * init_state.grad_potential

        advanced = init_state.copy()
        advanced.value = init_state.value + stepsize * momenta_half
        potential, grad_potential = potential_and_grad(advanced.value,
                                                       random_keys[i])
        advanced.potential = potential
        advanced.grad_potential = grad_potential
        advanced.momenta = momenta_half - half_step * advanced.grad_potential

        recorded = advanced.copy()
        recorded.momenta = jnp.vstack([momenta_half, advanced.momenta])
        return advanced, recorded
예제 #6
0
    def proposal(self, scenario: Scenario, reject_state: cdict,
                 reject_extra: cdict) -> Tuple[cdict, cdict]:
        """Gaussian random-walk proposal.

        The new value is the old one plus ``sqrt(stepsize)`` times a
        standard-normal vector of length ``scenario.dim``. Its potential is
        then evaluated.

        NOTE: ``reject_extra.random_key`` is advanced in place, and the same
        ``reject_extra`` object (not a copy) is returned.
        """
        proposed_state = reject_state.copy()
        current = reject_state.value
        scale = jnp.sqrt(reject_extra.parameters.stepsize)

        reject_extra.random_key, normal_key, potential_key = random.split(
            reject_extra.random_key, 3)

        increment = scale * random.normal(normal_key, (scenario.dim, ))
        proposed_state.value = current + increment
        proposed_state.potential = scenario.potential(proposed_state.value,
                                                      potential_key)
        return proposed_state, reject_extra
예제 #7
0
def resample_particles(particles: cdict,
                       random_key: jnp.ndarray,
                       resample_full: bool = True) -> cdict:
    """Resample particles according to their (latest) log-weights.

    Case 1: ``log_weight`` is 1-D. ``value`` is resampled and ``log_weight``
    is reset to zeros.

    Case 2: otherwise, the last row of ``log_weight`` is used.
    - With ``resample_full``, ``_resample`` is applied to the whole ``value``
      array.
    - Without it, only ``value[-1]`` is resampled and written back.
    In either sub-case only ``log_weight[-1]`` is reset to zeros.

    The particle count ``n`` is taken from ``value.shape[-2]``.
    """
    resampled = particles.copy()
    num_particles = particles.value.shape[-2]
    fresh_log_weight = jnp.zeros(num_particles)

    if resampled.log_weight.ndim == 1:
        resampled.value = _resample(
            (particles.value, particles.log_weight, random_key))
        resampled.log_weight = fresh_log_weight
        return resampled

    latest_log_weight = particles.log_weight[-1]
    if resample_full:
        resampled.value = _resample(
            (particles.value, latest_log_weight, random_key))
    else:
        latest_value = _resample(
            (particles.value[-1], latest_log_weight, random_key))
        resampled.value = index_update(resampled.value, -1, latest_value)
    resampled.log_weight = index_update(resampled.log_weight, -1,
                                        fresh_log_weight)
    return resampled
예제 #8
0
def backward_simulation_rejection(ssm_scenario: StateSpaceModel,
                                  marginal_particles: cdict,
                                  n_samps: int,
                                  random_key: jnp.ndarray,
                                  maximum_rejections: int,
                                  init_bound_param: float,
                                  bound_inflation: float) -> cdict:
    """Simulate smoothing trajectories backwards using rejection resampling.

    This has the same structure as full backward simulation, but
    ``rejection_resampling`` chooses each predecessor. That function also
    reports how many transition evaluations it used at each step.

    Args:
        ssm_scenario: Model object forwarded to ``rejection_resampling``.
        marginal_particles: Filter output; ``value`` must be 3-D, with
            per-time ``t`` and ``log_weight``.
        n_samps: Number of trajectories to simulate.
        random_key: JAX PRNG key.
        maximum_rejections: Forwarded to ``rejection_resampling``.
        init_bound_param: Forwarded to ``rejection_resampling``.
        bound_inflation: Forwarded to ``rejection_resampling``.

    Returns:
        A copy of ``marginal_particles`` with these changes:
        - ``value`` is replaced by the simulated trajectories.
        - ``num_transition_evals`` holds the per-step counts, with 0 for the
          terminal draw.
        - ``log_weight`` is removed.
    """
    particle_vals = marginal_particles.value
    time_points = marginal_particles.t
    log_weights = marginal_particles.log_weight

    num_times, _, _ = particle_vals.shape

    step_keys = random.split(random_key, num_times)
    terminal_inds = random.categorical(step_keys[-1], log_weights[-1],
                                       shape=(n_samps,))
    terminal_vals = particle_vals[-1, terminal_inds]

    def backward_step(later_vals, time_ind):
        earlier_vals, evals = rejection_resampling(ssm_scenario,
                                                   particle_vals[time_ind], time_points[time_ind],
                                                   later_vals, time_points[time_ind + 1],
                                                   log_weights[time_ind], step_keys[time_ind],
                                                   maximum_rejections, init_bound_param,
                                                   bound_inflation)
        return earlier_vals, (earlier_vals, evals)

    # Indices run T-2, ..., 0, so both stacked outputs are in reverse time order.
    _, (reversed_trajectory, reversed_evals) = scan(backward_step,
                                                    terminal_vals,
                                                    jnp.arange(num_times - 2, -1, -1),
                                                    unroll=1)

    smoothed = marginal_particles.copy()
    smoothed.value = jnp.vstack([reversed_trajectory[::-1],
                                 terminal_vals[jnp.newaxis]])
    smoothed.num_transition_evals = jnp.append(0, reversed_evals[::-1])
    del smoothed.log_weight
    return smoothed
예제 #9
0
def run_particle_filter_for_marginals(ssm_scenario: StateSpaceModel,
                                      particle_filter: ParticleFilter,
                                      y: jnp.ndarray,
                                      t: jnp.ndarray,
                                      random_key: jnp.ndarray,
                                      n: int = None,
                                      initial_sample: cdict = None,
                                      ess_threshold: float = 0.5) -> cdict:
    """Run a particle filter over ``y`` and return the filtering marginals.

    If ``initial_sample`` is None:
    - ``y[0]`` at ``t[0]`` initialises the particles via ``initiate_particles``;
    - the remaining observations are then filtered.
    Otherwise filtering starts from ``initial_sample[-1]`` and every entry of
    ``y`` is filtered.

    At each step the particles are resampled when the previous ESS is below
    ``ess_threshold * n``.

    Args:
        ssm_scenario: Model object forwarded to the particle filter.
        particle_filter: Supplies ``propose_and_intermediate_weight_vectorised``.
        y: Observations; a 1-D array is given a trailing singleton axis.
        t: Observation times aligned with ``y``.
        random_key: JAX PRNG key.
        n: Number of particles. If None, it is taken from
            ``initial_sample.value.shape[1]``. When ``initial_sample`` is also
            None, it is passed as-is to ``initiate_particles``.
        initial_sample: Optional starting particles, stacked along axis 0.
        ess_threshold: Fraction of ``n`` below which resampling is triggered.

    Returns:
        ``initial_sample`` with ``value``, ``log_weight``, ``y``, ``t`` and
        ``ess`` each extended along axis 0 by one entry per filtered
        observation.
    """
    if y.ndim == 1:
        y = y[..., jnp.newaxis]

    if initial_sample is None:
        random_key, sub_key = random.split(random_key)
        init_y = y[0]
        y = y[1:]

        initial_sample = initiate_particles(ssm_scenario, particle_filter, n,
                                            sub_key, init_y, t[0])
        t = t[1:]

    if n is None:
        n = initial_sample.value.shape[1]

    num_propagate_steps = len(y)
    int_rand_keys = random.split(random_key, num_propagate_steps)

    def particle_filter_body(samps_previous: cdict,
                             iter_ind: int) -> Tuple[cdict, cdict]:
        x_previous = samps_previous.value
        log_weight_previous = samps_previous.log_weight

        # Split the step key so that the resampling draw and the per-particle
        # proposal keys are independent. Previously the same key was fed to
        # `_resample` and to `random.split`, which correlates the two.
        resample_key, propose_key = random.split(int_rand_keys[iter_ind])

        resample_bool = samps_previous.ess < (ess_threshold * n)
        x_res = cond(resample_bool, _resample, lambda tup: tup[0],
                     (x_previous, log_weight_previous, resample_key))
        # After resampling, the log-weights are reset to zero (equal weights).
        log_weight_res = jnp.where(resample_bool, jnp.zeros(n),
                                   log_weight_previous)

        split_keys = random.split(propose_key, len(x_previous))

        x_new, log_weight_new = particle_filter.propose_and_intermediate_weight_vectorised(
            ssm_scenario, x_res, samps_previous.t, y[iter_ind], t[iter_ind],
            split_keys)

        log_weight_new = log_weight_res + log_weight_new

        samps_new = samps_previous.copy()
        samps_new.value = x_new
        samps_new.log_weight = log_weight_new
        samps_new.y = y[iter_ind]
        samps_new.t = t[iter_ind]
        samps_new.ess = ess_log_weight(log_weight_new)
        return samps_new, samps_new

    _, after_init_samps = scan(particle_filter_body, initial_sample[-1],
                               jnp.arange(num_propagate_steps))

    out_samps = initial_sample.copy()
    out_samps.value = jnp.append(initial_sample.value,
                                 after_init_samps.value,
                                 axis=0)
    out_samps.log_weight = jnp.append(initial_sample.log_weight,
                                      after_init_samps.log_weight,
                                      axis=0)
    out_samps.y = jnp.append(initial_sample.y, after_init_samps.y, axis=0)
    out_samps.t = jnp.append(initial_sample.t, after_init_samps.t)
    out_samps.ess = jnp.append(initial_sample.ess, after_init_samps.ess)

    return out_samps