def test_time_reversibility(example):
    """Check that velocity Verlet is time reversible.

    Integrate ``args.num_steps`` steps forward from ``(args.q_i, args.p_i)``,
    negate every momentum component, integrate the same number of steps
    again, and expect to arrive back at ``args.q_i`` (tolerance 1e-5).
    """
    model, args = example
    eps, n_steps = args.step_size, args.num_steps

    q_mid, p_mid = velocity_verlet(args.q_i, args.p_i, model.potential_fn, eps, n_steps)
    flipped = {site: -momentum for site, momentum in p_mid.items()}
    # Only the returned position is checked; the final momentum is ignored.
    q_back, _ = velocity_verlet(q_mid, flipped, model.potential_fn, eps, n_steps)

    assert_equal(q_back, args.q_i, 1e-5)
def test_trajectory(example):
    """Compare the integrator's end point with the reference trajectory.

    Starting from ``(args.q_i, args.p_i)``, the final position and momentum
    must match ``args.q_f`` and ``args.p_f`` to within ``args.prec``.
    """
    model, args = example
    q_end, p_end = velocity_verlet(
        args.q_i,
        args.p_i,
        model.potential_fn,
        args.step_size,
        args.num_steps,
    )

    logger.info("initial q: {}".format(args.q_i))
    logger.info("final q: {}".format(q_end))

    tolerance = args.prec
    assert_equal(q_end, args.q_f, tolerance)
    assert_equal(p_end, args.p_f, tolerance)
def test_energy_conservation(example):
    """Check that the total energy is preserved along the trajectory.

    ``model.energy`` is evaluated at the start and end points of an
    ``args.num_steps``-step velocity Verlet run, and the two values are
    compared with ``assert_equal`` at its default tolerance.
    """
    model, args = example
    q_end, p_end = velocity_verlet(
        args.q_i, args.p_i, model.potential_fn, args.step_size, args.num_steps
    )

    h_start = model.energy(args.q_i, args.p_i)
    h_end = model.energy(q_end, p_end)
    logger.info("initial energy: {}".format(h_start.item()))
    logger.info("final energy: {}".format(h_end.item()))

    assert_equal(h_end, h_start)
def test_time_reversibility(example):
    """Check that velocity Verlet (with a kinetic-gradient callback) is reversible.

    A forward run, followed by a run of equal length from the negated final
    momenta, must return to ``args.q_i`` within 1e-5.
    """
    model, args = example

    def run(q, p):
        # The integrator returns (q, p, grads, potential_energy); only the
        # position and momentum are needed here.
        q_out, p_out, _, _ = velocity_verlet(
            q,
            p,
            model.potential_fn,
            model.kinetic_grad,
            args.step_size,
            args.num_steps,
        )
        return q_out, p_out

    q_mid, p_mid = run(args.q_i, args.p_i)
    q_back, _ = run(q_mid, {site: -mom for site, mom in p_mid.items()})

    assert_equal(q_back, args.q_i, 1e-5)
def sample(self, params):
    """Perform one HMC transition and return the new latent values.

    Steps:
    1. Fetch (or recompute) the cached position, potential energy and gradients.
    2. Sample a momentum and integrate ``self.num_steps`` velocity-Verlet steps
       of size ``self.step_size``.
    3. Apply a Metropolis accept/reject test.
    4. Update divergence, acceptance and warmup-adaptation bookkeeping.

    :param params: latent values used only when the cache is empty
        (presumably a dict of site name -> unconstrained tensor; TODO confirm).
    :returns: a shallow copy (``dict.copy``) of the accepted or retained
        latent values, or ``params`` unchanged when there are no sample sites.
    """
    z, potential_energy, z_grads = self._fetch_from_cache()
    # recompute PE when cache is cleared
    if z is None:
        z = params
        z_grads, potential_energy = potential_grad(self.potential_fn, z)
        self._cache(z, potential_energy, z_grads)
    # return early if no sample sites
    elif len(z) == 0:
        self._t += 1
        self._mean_accept_prob = 1.
        if self._t > self._warmup_steps:
            self._accept_cnt += 1
        return params
    r, r_unscaled = self._sample_r(name="r_t={}".format(self._t))
    energy_current = self._kinetic_energy(r_unscaled) + potential_energy

    # Temporarily disable distributions args checking as
    # NaNs are expected during step size adaptation
    with optional(pyro.validation_enabled(False), self._t < self._warmup_steps):
        z_new, r_new, z_grads_new, potential_energy_new = velocity_verlet(
            z, r, self.potential_fn, self.mass_matrix_adapter.kinetic_grad,
            self.step_size, self.num_steps, z_grads=z_grads)
        # apply Metropolis correction.
        r_new_unscaled = self.mass_matrix_adapter.unscale(r_new)
        energy_proposal = self._kinetic_energy(r_new_unscaled) + potential_energy_new
    delta_energy = energy_proposal - energy_current
    # handle the NaN case which may be the case for a diverging trajectory
    # when using a large step size.
    delta_energy = scalar_like(delta_energy, float("inf")) if torch_isnan(delta_energy) else delta_energy
    # Divergences are only recorded after warmup; the recorded index is
    # relative to the first post-warmup iteration.
    if delta_energy > self._max_sliced_energy and self._t >= self._warmup_steps:
        self._divergences.append(self._t - self._warmup_steps)

    accept_prob = (-delta_energy).exp().clamp(max=1.)
    # Uniform bounds are built with scalar_like so that the draw matches
    # accept_prob (presumably its dtype/device).
    rand = pyro.sample("rand_t={}".format(self._t),
                       dist.Uniform(scalar_like(accept_prob, 0.), scalar_like(accept_prob, 1.)))
    accepted = False
    if rand < accept_prob:
        accepted = True
        z = z_new
        z_grads = z_grads_new
        # Keep the cache in sync with the accepted state.
        self._cache(z, potential_energy_new, z_grads)

    self._t += 1
    if self._t > self._warmup_steps:
        # Post-warmup: acceptance is counted and n restarts from 1, so the
        # running mean below is effectively reset at the first sampling step.
        n = self._t - self._warmup_steps
        if accepted:
            self._accept_cnt += 1
    else:
        # Warmup: feed the acceptance probability to the adapter.
        n = self._t
        self._adapter.step(self._t, z, accept_prob, z_grads)

    # Incremental running mean of the acceptance probability over n steps.
    self._mean_accept_prob += (accept_prob.item() - self._mean_accept_prob) / n
    return z.copy()
def _find_reasonable_step_size(self, z):
    """Heuristically rescale ``self.step_size`` for the position ``z``.

    We look for a step_size whose Metropolis accept_prob is near the target.
    If accept_prob := exp(-delta_energy) is small, the step size must shrink;
    otherwise it can grow. Concretely:

    - One velocity-Verlet step is taken with a freshly sampled momentum.
    - ``-delta_energy`` is compared against ``self._direction_threshold``.
    - The step size is then repeatedly doubled (comparison true) or halved
      (comparison false) until the comparison flips.
    - Trial ``t`` samples its momentum at site ``"r_presample_{t}"``.

    :param dict z: latent values the trial steps start from.
    :returns: the rescaled step size.
    """
    potential_energy = self.potential_fn(z)
    # This is required so as to avoid issues with autograd when model
    # contains transforms with cache_size > 0 (https://github.com/pyro-ppl/pyro/issues/2292)
    z = {k: v.clone() for k, v in z.items()}

    def trial_delta_energy(step_size, t):
        # Sample momentum at site "r_presample_{t}", take one velocity-Verlet
        # step of size ``step_size`` from ``z``, and return the energy change.
        r, r_unscaled = self._sample_r(name="r_presample_{}".format(t))
        energy_current = self._kinetic_energy(r_unscaled) + potential_energy
        _, r_new, _, potential_energy_new = velocity_verlet(
            z, r, self.potential_fn, self.mass_matrix_adapter.kinetic_grad, step_size)
        r_new_unscaled = self.mass_matrix_adapter.unscale(r_new)
        energy_new = self._kinetic_energy(r_new_unscaled) + potential_energy_new
        return energy_new - energy_current

    def trial_direction(delta_energy):
        # direction=1 means keep increasing step_size, otherwise decreasing.
        # The comparison is False when delta_energy is `NaN` (e.g. a diverging
        # trajectory for a constrained site at a large step size), giving -1.
        return 1 if self._direction_threshold < -delta_energy else -1

    step_size = self.step_size
    direction = trial_direction(trial_delta_energy(step_size, 0))

    # define scale for step_size: 2 for increasing, 1/2 for decreasing
    step_size_scale = 2 ** direction
    direction_new = direction
    # keep scale step_size until accept_prob crosses its target
    # TODO: make thresholds for too small step_size or too large step_size
    t = 0
    while direction_new == direction:
        t += 1
        step_size = step_size_scale * step_size
        direction_new = trial_direction(trial_delta_energy(step_size, t))
    return step_size
def sample(self, trace): z = { name: node["value"].detach() for name, node in self._iter_latent_nodes(trace) } # automatically transform `z` to unconstrained space, if needed. for name, transform in self.transforms.items(): z[name] = transform(z[name]) r, _ = self._sample_r(name="r_t={}".format(self._t)) potential_energy, z_grads = self._fetch_from_cache() # Temporarily disable distributions args checking as # NaNs are expected during step size adaptation with optional(pyro.validation_enabled(False), self._t < self._warmup_steps): z_new, r_new, z_grads_new, potential_energy_new = velocity_verlet( z, r, self._potential_energy, self.inverse_mass_matrix, self.step_size, self.num_steps, z_grads=z_grads) # apply Metropolis correction. energy_proposal = self._kinetic_energy( r_new) + potential_energy_new energy_current = self._kinetic_energy(r) + potential_energy if potential_energy is not None \ else self._energy(z, r) delta_energy = energy_proposal - energy_current # Set accept prob to 0.0 if delta_energy is `NaN` which may be # the case for a diverging trajectory when using a large step size. if torch_isnan(delta_energy): accept_prob = delta_energy.new_tensor(0.0) else: accept_prob = (-delta_energy).exp().clamp(max=1.) rand = pyro.sample("rand_t={}".format(self._t), dist.Uniform(torch.zeros(1), torch.ones(1))) if rand < accept_prob: self._accept_cnt += 1 z = z_new if self._t < self._warmup_steps: self._adapter.step(self._t, z, accept_prob) self._t += 1 # get trace with the constrained values for `z`. for name, transform in self.transforms.items(): z[name] = transform.inv(z[name]) return self._get_trace(z)
def _build_basetree(self, z, r, z_grads, log_slice, direction, energy_current):
    """Build a single-leaf tree by taking one velocity-Verlet step.

    The step has magnitude ``self.step_size``. It is taken with a positive
    sign when ``direction == 1`` and a negative sign otherwise.

    :param z: current latent values.
    :param r: current momentum.
    :param z_grads: potential gradients at ``z``, forwarded to the integrator.
    :param log_slice: value added to the new energy to form the sliced energy.
    :param direction: 1 for a forward step, anything else for a backward step.
    :param energy_current: total energy of the starting state.
    :returns: a ``_TreeInfo`` built entirely from the single new state.
    """
    signed_step = self.step_size if direction == 1 else -self.step_size
    z_new, r_new, z_grads, potential_energy = velocity_verlet(
        z,
        r,
        self.potential_fn,
        self.mass_matrix_adapter.kinetic_grad,
        signed_step,
        z_grads=z_grads,
    )
    r_new_unscaled = self.mass_matrix_adapter.unscale(r_new)

    energy_new = potential_energy + self._kinetic_energy(r_new_unscaled)
    if torch_isnan(energy_new):
        # A NaN energy (e.g. from a diverging step) is treated as +inf.
        energy_new = scalar_like(energy_new, float("inf"))

    sliced_energy = energy_new + log_slice
    diverging = sliced_energy > self._max_sliced_energy
    delta_energy = energy_new - energy_current
    accept_prob = (-delta_energy).exp().clamp(max=1.0)

    if not self.use_multinomial_sampling:
        # Slice sampling: along the trajectory, states with p(z, r) < u
        # (dE > 0) are eliminated. Because of this elimination (and the stop
        # doubling conditions), a tree's weight need not equal 2^tree_depth.
        inside_slice = sliced_energy <= 0
        tree_weight = scalar_like(sliced_energy, 1.0 if inside_slice else 0.0)
    else:
        tree_weight = -sliced_energy

    # Positional order is fixed by _TreeInfo. The grouping comments below are
    # inferred from the argument pattern; confirm against _TreeInfo's fields.
    return _TreeInfo(
        z_new, r_new, r_new_unscaled, z_grads,  # presumably the left edge
        z_new, r_new, r_new_unscaled, z_grads,  # presumably the right edge
        z_new, potential_energy, z_grads,  # presumably the proposal
        r_new_unscaled,  # momentum sum for a single leaf
        tree_weight,
        False,
        diverging,
        accept_prob,
        1,
    )
def _find_reasonable_step_size(self, z):
    """Heuristically rescale ``self.step_size`` for the position ``z``.

    We look for a step_size whose Metropolis accept_prob is near the target.
    If accept_prob := exp(-delta_energy) is small, the step size must shrink;
    otherwise it can grow. Concretely:

    - One velocity-Verlet step is taken with a freshly sampled momentum.
    - ``-delta_energy`` is compared against ``self._direction_threshold``.
    - The step size is then repeatedly doubled (comparison true) or halved
      (comparison false) until the comparison flips.
    - Trial ``t`` samples its momentum at site ``"r_presample_{t}"``.

    :param dict z: latent values the trial steps start from.
    :returns: the rescaled step size.
    """
    potential_energy = self.potential_fn(z)
    # Clone the latent values (after computing the potential above) to avoid
    # autograd issues when the model contains transforms with cache_size > 0
    # (https://github.com/pyro-ppl/pyro/issues/2292).
    z = {k: v.clone() for k, v in z.items()}

    def trial_delta_energy(step_size, t):
        # Sample momentum at site "r_presample_{t}", take one velocity-Verlet
        # step of size ``step_size`` from ``z``, and return the energy change.
        r, _ = self._sample_r(name="r_presample_{}".format(t))
        energy_current = self._kinetic_energy(r) + potential_energy
        _, r_new, _, potential_energy_new = velocity_verlet(
            z, r, self.potential_fn, self.inverse_mass_matrix, step_size)
        energy_new = self._kinetic_energy(r_new) + potential_energy_new
        return energy_new - energy_current

    def trial_direction(delta_energy):
        # direction=1 means keep increasing step_size, otherwise decreasing.
        # The comparison is False when delta_energy is `NaN` (e.g. a diverging
        # trajectory for a constrained site at a large step size), giving -1.
        return 1 if self._direction_threshold < -delta_energy else -1

    step_size = self.step_size
    direction = trial_direction(trial_delta_energy(step_size, 0))

    # define scale for step_size: 2 for increasing, 1/2 for decreasing
    step_size_scale = 2**direction
    direction_new = direction
    # keep scale step_size until accept_prob crosses its target
    # TODO: make thresholds for too small step_size or too large step_size
    t = 0
    while direction_new == direction:
        t += 1
        step_size = step_size_scale * step_size
        direction_new = trial_direction(trial_delta_energy(step_size, t))
    return step_size
def sample(self, trace):
    """Perform one HMC transition starting from the latent values in ``trace``.

    Steps:
    1. Map stochastic-node values to unconstrained space.
    2. Draw per-site momenta (sites ``"r_{name}_t={t}"``) and integrate with
       velocity Verlet.
    3. Accept the proposal with probability ``min(1, exp(-delta_energy))``;
       a `NaN` delta_energy counts as probability 0.
    4. During the adaptation phase, pass that probability to
       ``self._adapt_step_size`` as a Python float in every case.

    :param trace: trace holding the current latent values.
    :returns: a trace for the accepted or retained latent values, mapped back
        to constrained space via ``transform.inv``.
    """
    z = {name: node["value"].detach() for name, node in trace.iter_stochastic_nodes()}
    # automatically transform `z` to unconstrained space, if needed.
    for name, transform in self.transforms.items():
        z[name] = transform(z[name])
    r = {name: pyro.sample("r_{}_t={}".format(name, self._t), self._r_dist[name])
         for name in self._r_dist}

    # Temporarily disable distributions args checking as
    # NaNs are expected during step size adaptation
    dist_arg_check = False if self._adapt_phase else pyro.distributions.is_validation_enabled()
    with dist.validation_enabled(dist_arg_check):
        z_new, r_new = velocity_verlet(z, r, self._potential_energy, self.step_size, self.num_steps)
        # apply Metropolis correction.
        energy_proposal = self._energy(z_new, r_new)
        energy_current = self._energy(z, r)
    delta_energy = energy_proposal - energy_current

    # Set accept prob to 0.0 if delta_energy is `NaN` which may be
    # the case for a diverging trajectory when using a large step size.
    if torch_isnan(delta_energy):
        accept_prob = delta_energy.new_tensor(0.0)
    else:
        accept_prob = (-delta_energy).exp().clamp(max=1)

    # rand lies in [0, 1). Comparing against the clamped probability (or 0 for
    # NaN) gives the same decision as comparing against exp(-delta_energy).
    rand = pyro.sample("rand_t={}".format(self._t), dist.Uniform(torch.zeros(1), torch.ones(1)))
    if rand < accept_prob:
        self._accept_cnt += 1
        z = z_new

    if self._adapt_phase:
        # Always hand the adapter a Python float; previously the NaN branch
        # passed a tensor while the normal branch passed a float.
        self._adapt_step_size(accept_prob.item())

    self._t += 1

    # get trace with the constrained values for `z`.
    for name, transform in self.transforms.items():
        z[name] = transform.inv(z[name])
    return self._get_trace(z)
def sample(self, trace):
    """Perform one HMC transition starting from the latent values in ``trace``.

    Steps:
    1. Map stochastic-node values to unconstrained space.
    2. Draw per-site momenta (sites ``"r_{name}_t={t}"``) and integrate with
       velocity Verlet.
    3. Accept the proposal with probability ``min(1, exp(-delta_energy))``;
       a `NaN` delta_energy counts as probability 0.
    4. During the adaptation phase, pass that probability to
       ``self._adapt_step_size`` as a Python float in every case.

    :param trace: trace holding the current latent values.
    :returns: a trace for the accepted or retained latent values, mapped back
        to constrained space via ``transform.inv``.
    """
    z = {name: node["value"].detach() for name, node in trace.iter_stochastic_nodes()}
    # automatically transform `z` to unconstrained space, if needed.
    for name, transform in self.transforms.items():
        z[name] = transform(z[name])
    r = {name: pyro.sample("r_{}_t={}".format(name, self._t), self._r_dist[name])
         for name in self._r_dist}

    # Temporarily disable distributions args checking as
    # NaNs are expected during step size adaptation
    dist_arg_check = False if self._adapt_phase else pyro.distributions.is_validation_enabled()
    with dist.validation_enabled(dist_arg_check):
        z_new, r_new = velocity_verlet(z, r, self._potential_energy, self.step_size, self.num_steps)
        # apply Metropolis correction.
        energy_proposal = self._energy(z_new, r_new)
        energy_current = self._energy(z, r)
    delta_energy = energy_proposal - energy_current

    # Set accept prob to 0.0 if delta_energy is `NaN` which may be
    # the case for a diverging trajectory when using a large step size.
    if torch_isnan(delta_energy):
        accept_prob = delta_energy.new_tensor(0.0)
    else:
        accept_prob = (-delta_energy).exp().clamp(max=1)

    # rand lies in [0, 1). Comparing against the clamped probability (or 0 for
    # NaN) gives the same decision as comparing against exp(-delta_energy).
    rand = pyro.sample("rand_t={}".format(self._t), dist.Uniform(torch.zeros(1), torch.ones(1)))
    if rand < accept_prob:
        self._accept_cnt += 1
        z = z_new

    if self._adapt_phase:
        # Always hand the adapter a Python float; previously the NaN branch
        # passed a tensor while the normal branch passed a float.
        self._adapt_step_size(accept_prob.item())

    self._t += 1

    # get trace with the constrained values for `z`.
    for name, transform in self.transforms.items():
        z[name] = transform.inv(z[name])
    return self._get_trace(z)
def _build_basetree(self, z, r, z_grads, log_slice, direction, energy_current):
    """Build a single-leaf tree by taking one velocity-Verlet step.

    The step has magnitude ``self.step_size``. It is taken with a positive
    sign when ``direction == 1`` and a negative sign otherwise.

    :param z: current latent values.
    :param r: current momentum (dict keyed by site name).
    :param z_grads: potential gradients at ``z``, forwarded to the integrator.
    :param log_slice: value added to the new energy to form the sliced energy.
    :param direction: 1 for a forward step, anything else for a backward step.
    :param energy_current: total energy of the starting state.
    :returns: a ``_TreeInfo`` built entirely from the single new state.
    """
    signed_step = self.step_size if direction == 1 else -self.step_size
    z_new, r_new, z_grads, potential_energy = velocity_verlet(
        z, r, self.potential_fn, self.inverse_mass_matrix, signed_step, z_grads=z_grads)
    # Flatten every momentum site into one vector, visiting sites in sorted
    # name order so the layout is deterministic.
    flat_momentum = torch.cat([r_new[site].reshape(-1) for site in sorted(r_new)])

    energy_new = potential_energy + self._kinetic_energy(r_new)
    if torch_isnan(energy_new):
        # A NaN energy (e.g. from a diverging step) is treated as +inf.
        energy_new = scalar_like(energy_new, float("inf"))

    sliced_energy = energy_new + log_slice
    diverging = (sliced_energy > self._max_sliced_energy)
    delta_energy = energy_new - energy_current
    accept_prob = (-delta_energy).exp().clamp(max=1.0)

    if not self.use_multinomial_sampling:
        # Slice sampling: along the trajectory, states with p(z, r) < u
        # (dE > 0) are eliminated. Because of this elimination (and the stop
        # doubling conditions), a tree's weight need not equal 2^tree_depth.
        inside_slice = sliced_energy <= 0
        tree_weight = scalar_like(sliced_energy, 1. if inside_slice else 0.)
    else:
        tree_weight = -sliced_energy

    # Positional order is fixed by _TreeInfo. The grouping comments below are
    # inferred from the argument pattern; confirm against _TreeInfo's fields.
    return _TreeInfo(
        z_new, r_new, z_grads,  # presumably the left edge
        z_new, r_new, z_grads,  # presumably the right edge
        z_new, potential_energy, z_grads,  # presumably the proposal
        flat_momentum,
        tree_weight,
        False,
        diverging,
        accept_prob,
        1,
    )