def always(self,
           scenario: Scenario,
           reject_state: cdict,
           reject_extra: cdict) -> Tuple[cdict, cdict]:
    """Flip the momenta and then partially refresh them with Gaussian noise.

    The momenta stored on ``reject_state`` are first negated and then
    replaced by ``momenta * exp(-friction * stepsize)`` plus
    ``sqrt(1 - exp(-2 * friction * stepsize))`` times a standard normal
    draw of shape ``(scenario.dim,)``.  ``stepsize`` and ``friction`` are
    read from ``reject_extra.parameters``.

    Args:
        scenario: Provides ``dim``, the length of the noise vector.
        reject_state: State whose ``momenta`` attribute is updated in place.
        reject_extra: Carries ``parameters`` and ``random_key``; the key is
            split and the first half is stored back on ``reject_extra``.

    Returns:
        The same ``reject_state`` and ``reject_extra`` objects, mutated.
    """
    params = reject_extra.parameters
    stepsize = params.stepsize
    friction = params.friction

    # Momentum flip
    reject_state.momenta = -1 * reject_state.momenta

    # Update p - exactly according to solution of OU process
    # Accepted even if leapfrog step is rejected
    reject_extra.random_key, noise_key = random.split(reject_extra.random_key)
    decay = jnp.exp(- friction * stepsize)
    noise_scale = jnp.sqrt(1 - jnp.exp(- 2 * friction * stepsize))
    noise = random.normal(noise_key, (scenario.dim,))
    reject_state.momenta = reject_state.momenta * decay + noise_scale * noise

    return reject_state, reject_extra
def startup(self,
            scenario: Scenario,
            n: int,
            initial_state: cdict,
            initial_extra: cdict,
            **kwargs) -> Tuple[cdict, cdict]:
    """Prepare the initial state and extra before sampling starts.

    Delegates to the parent class ``startup`` first, then evaluates the
    scenario's potential and gradient at ``initial_state.value`` and makes
    sure the state carries a momenta array whose last axis has length
    ``scenario.dim``.

    Args:
        scenario: Provides ``dim`` and ``potential_and_grad``.
        n: Forwarded unchanged to the parent ``startup``.
        initial_state: State to initialise; gains ``potential``,
            ``grad_potential`` and (if needed) ``momenta``.
        initial_extra: Carries ``random_key``, which is split here.
        **kwargs: Forwarded unchanged to the parent ``startup``.

    Returns:
        The initialised ``(initial_state, initial_extra)`` pair.
    """
    initial_state, initial_extra = super().startup(
        scenario, n, initial_state, initial_extra, **kwargs)

    # One half of the split is kept for later use, the other is handed to the
    # scenario for this potential evaluation.
    next_key, scen_key = random.split(initial_extra.random_key)
    initial_extra.random_key = next_key

    potential, grad_potential = scenario.potential_and_grad(initial_state.value, scen_key)
    initial_state.potential = potential
    initial_state.grad_potential = grad_potential

    # Keep any momenta already present with a matching trailing dimension;
    # otherwise start from zeros.
    momenta_usable = (hasattr(initial_state, 'momenta')
                      and initial_state.momenta.shape[-1] == scenario.dim)
    if not momenta_usable:
        initial_state.momenta = jnp.zeros(scenario.dim)

    return initial_state, initial_extra
def proposal(self,
             scenario: Scenario,
             reject_state: cdict,
             reject_extra: cdict) -> Tuple[cdict, cdict]:
    """Draw fresh momenta and propose the final state of a leapfrog run.

    The random key on ``reject_extra`` is split into
    ``self.parameters.leapfrog_steps + 2`` keys: the first is stored back on
    ``reject_extra``, the second draws standard normal momenta of shape
    ``(scenario.dim,)``, and the rest are passed to ``utils.leapfrog``.

    Args:
        scenario: Provides ``dim`` and ``potential_and_grad``.
        reject_state: Current state; its ``momenta`` are overwritten in place
            before integration.
        reject_extra: Carries ``random_key`` and ``parameters.stepsize``.

    Returns:
        The last element of the ``utils.leapfrog`` output and ``reject_extra``.
    """
    num_keys = self.parameters.leapfrog_steps + 2
    keys = random.split(reject_extra.random_key, num_keys)
    carry_key, momenta_key, leapfrog_keys = keys[0], keys[1], keys[2:]

    reject_extra.random_key = carry_key
    reject_state.momenta = random.normal(momenta_key, (scenario.dim,))

    trajectory = utils.leapfrog(scenario.potential_and_grad,
                                reject_state,
                                reject_extra.parameters.stepsize,
                                leapfrog_keys)

    # Technically we should reverse momenta now
    # but momenta target is symmetric and then immediately resampled at the next step anyway
    return trajectory[-1], reject_extra