Example #1
0
    def always(self, scenario: Scenario, reject_state: cdict,
               reject_extra: cdict) -> Tuple[cdict, cdict]:
        """Negate the momenta, then refresh them with an exact OU transition.

        The momenta of ``reject_state`` are multiplied by -1 and then partially
        refreshed using the exact solution of an Ornstein-Uhlenbeck process over
        one step of length ``stepsize`` with rate ``friction``. Per the original
        author's note, this refresh is kept even when the leapfrog step is
        rejected.

        Args:
            scenario: Provides ``dim``, the length of the Gaussian noise vector.
            reject_state: State whose ``momenta`` attribute is overwritten.
            reject_extra: Supplies ``parameters.stepsize``,
                ``parameters.friction`` and ``random_key``; its ``random_key``
                is advanced.

        Returns:
            The mutated ``(reject_state, reject_extra)`` pair.
        """
        params = reject_extra.parameters
        dt = params.stepsize
        gamma = params.friction

        flipped = reject_state.momenta * -1

        # Exact OU step: the current momenta decay by exp(-gamma * dt) and fresh
        # standard-normal noise is scaled by sqrt(1 - exp(-2 * gamma * dt)).
        new_key, noise_key = random.split(reject_extra.random_key)
        reject_extra.random_key = new_key
        noise = random.normal(noise_key, (scenario.dim,))
        decay = jnp.exp(- gamma * dt)
        noise_scale = jnp.sqrt(1 - jnp.exp(- 2 * gamma * dt))

        reject_state.momenta = flipped * decay + noise_scale * noise
        return reject_state, reject_extra
Example #2
0
 def startup(self, scenario: Scenario, n: int, initial_state: cdict,
             initial_extra: cdict, **kwargs) -> Tuple[cdict, cdict]:
     """Prepare the starting state: potential, gradient and momenta.

     After delegating to the parent ``startup``, this advances ``random_key``,
     evaluates ``scenario.potential_and_grad`` at ``value`` with a fresh
     subkey, and resets ``momenta`` to a zero vector of length
     ``scenario.dim`` when the state has no ``momenta`` attribute or its last
     axis does not have size ``scenario.dim``.

     Returns:
         The ``(state, extra)`` pair returned by the parent ``startup``,
         updated in place.
     """
     state, extra = super().startup(scenario, n, initial_state,
                                    initial_extra, **kwargs)

     next_key, potential_key = random.split(extra.random_key)
     extra.random_key = next_key

     potential, grad_potential = scenario.potential_and_grad(state.value,
                                                             potential_key)
     state.potential = potential
     state.grad_potential = grad_potential

     # Keep existing momenta only when present and sized for this scenario.
     momenta_ok = (hasattr(state, 'momenta')
                   and state.momenta.shape[-1] == scenario.dim)
     if not momenta_ok:
         state.momenta = jnp.zeros(scenario.dim)
     return state, extra
Example #3
0
    def proposal(self, scenario: Scenario, reject_state: cdict,
                 reject_extra: cdict) -> Tuple[cdict, cdict]:
        """Draw fresh Gaussian momenta and run a leapfrog trajectory.

        The key in ``reject_extra`` is split into ``leapfrog_steps + 2`` keys:
        - the first replaces ``reject_extra.random_key``;
        - the second draws standard-normal momenta of length ``scenario.dim``,
          written onto ``reject_state``;
        - the remaining keys are passed to ``utils.leapfrog``.

        Returns:
            The final state of the leapfrog trajectory, and ``reject_extra``.
        """
        n_keys = self.parameters.leapfrog_steps + 2
        keys = random.split(reject_extra.random_key, n_keys)
        reject_extra.random_key = keys[0]

        reject_state.momenta = random.normal(keys[1], (scenario.dim,))

        trajectory = utils.leapfrog(scenario.potential_and_grad,
                                    reject_state,
                                    reject_extra.parameters.stepsize,
                                    keys[2:])

        # Momenta are not negated at the end of the trajectory: their target is
        # symmetric and they are redrawn at the start of the next proposal.
        return trajectory[-1], reject_extra