def _find_valid_initial_params(model, model_args, model_kwargs, transforms, potential_fn,
                               prototype_params, max_tries_initial_params=100, num_chains=1,
                               init_strategy=init_to_uniform, trace=None):
    """Search for ``num_chains`` sets of initial parameters with finite energy.

    Each attempt runs the model (wrapped in ``InitMessenger(init_strategy)``)
    under ``poutine.trace``, reads the values of the sites named in
    ``prototype_params`` and maps them through ``transforms``. A candidate is
    kept only if the potential energy and every gradient entry are finite.

    :param dict prototype_params: maps site names to prototype values. Only
        its keys are used here. When it is empty it is returned as-is.
    :param int max_tries_initial_params: per-chain attempt budget. The total
        number of attempts is ``num_chains * max_tries_initial_params``.
    :param trace: optional pre-recorded trace, used for the first attempt only.
    :returns: for a single chain, a dict of transformed site values. Otherwise,
        a dict of values combined with ``torch.stack`` over the chains.
    :raises ValueError: if the attempt budget runs out first.
    """
    # Models without latent sites have nothing to initialize.
    if not prototype_params:
        return prototype_params

    collected = defaultdict(list)
    hits = 0
    # NOTE(review): presumably makes latent sites take values chosen by
    # init_strategy -- confirm against InitMessenger.
    seeded_model = InitMessenger(init_strategy)(model)
    current = prototype_params
    for _ in range(num_chains * max_tries_initial_params):
        if trace is None:
            trace = poutine.trace(seeded_model).get_trace(*model_args, **model_kwargs)
        raw = {name: trace.nodes[name]["value"].detach() for name in current}
        current = {name: transforms[name](value) for name, value in raw.items()}
        grads, energy = potential_grad(potential_fn, current)

        # Reject candidates whose energy or any gradient entry is inf/NaN.
        if torch.isfinite(energy) and all(torch.all(torch.isfinite(g)) for g in grads.values()):
            for name, value in current.items():
                collected[name].append(value)
            hits += 1
            if hits == num_chains:
                if num_chains == 1:
                    return {name: values[0] for name, values in collected.items()}
                return {name: torch.stack(values) for name, values in collected.items()}
        # A caller-supplied trace is used only once; later attempts re-run the model.
        trace = None
    raise ValueError("Model specification seems incorrect - cannot find valid initial params.")
def sample(self, params):
    """Perform one HMC transition from the cached (or given) state.

    Proposes a new state by integrating with ``velocity_verlet`` for
    ``self.num_steps`` steps of size ``self.step_size``. The proposal is
    accepted or rejected with a Metropolis correction. The method also
    advances the step counter and updates the acceptance statistics and the
    divergence record. During warmup it additionally steps the adapter.

    :param dict params: current values of the latent sites. Used only when the
        cache is empty (``z is None``). For models without latent sites it is
        returned unchanged.
    :returns: a copy of the resulting state dict ``z`` (which is unchanged if
        the proposal is rejected).
    """
    z, potential_energy, z_grads = self._fetch_from_cache()
    # recompute PE when cache is cleared
    if z is None:
        z = params
        z_grads, potential_energy = potential_grad(self.potential_fn, z)
        self._cache(z, potential_energy, z_grads)
    # return early if no sample sites
    elif len(z) == 0:
        self._t += 1
        # Nothing is proposed, so every step counts as accepted.
        self._mean_accept_prob = 1.
        if self._t > self._warmup_steps:
            self._accept_cnt += 1
        return params
    # Sample the momentum. The site name includes the step index, so each
    # step's draw is a distinct sample site.
    r, r_unscaled = self._sample_r(name="r_t={}".format(self._t))
    energy_current = self._kinetic_energy(r_unscaled) + potential_energy

    # Temporarily disable distributions args checking as
    # NaNs are expected during step size adaptation
    with optional(pyro.validation_enabled(False), self._t < self._warmup_steps):
        z_new, r_new, z_grads_new, potential_energy_new = velocity_verlet(
            z, r, self.potential_fn, self.mass_matrix_adapter.kinetic_grad,
            self.step_size, self.num_steps, z_grads=z_grads)
    # apply Metropolis correction.
    r_new_unscaled = self.mass_matrix_adapter.unscale(r_new)
    energy_proposal = self._kinetic_energy(r_new_unscaled) + potential_energy_new
    delta_energy = energy_proposal - energy_current
    # handle the NaN case which may be the case for a diverging trajectory
    # when using a large step size.
    delta_energy = scalar_like(delta_energy, float("inf")) if torch_isnan(delta_energy) else delta_energy
    # Record a divergence (after warmup only) when the energy error exceeds
    # self._max_sliced_energy.
    if delta_energy > self._max_sliced_energy and self._t >= self._warmup_steps:
        self._divergences.append(self._t - self._warmup_steps)

    # Metropolis acceptance probability: min(1, exp(-delta_energy)).
    accept_prob = (-delta_energy).exp().clamp(max=1.)
    rand = pyro.sample("rand_t={}".format(self._t), dist.Uniform(scalar_like(accept_prob, 0.), scalar_like(accept_prob, 1.)))
    accepted = False
    if rand < accept_prob:
        accepted = True
        z = z_new
        z_grads = z_grads_new
        self._cache(z, potential_energy_new, z_grads)

    self._t += 1
    if self._t > self._warmup_steps:
        n = self._t - self._warmup_steps
        if accepted:
            self._accept_cnt += 1
    else:
        n = self._t
        # The adapter is only stepped during warmup (presumably step-size /
        # mass-matrix adaptation -- confirm in the adapter implementation).
        self._adapter.step(self._t, z, accept_prob, z_grads)

    # Running mean of accept_prob. n restarts at 1 on the first post-warmup
    # step, so the mean covers only the current phase (warmup or sampling).
    self._mean_accept_prob += (accept_prob.item() - self._mean_accept_prob) / n
    return z.copy()
def setup(self, warmup_steps, *args, **kwargs): self._warmup_steps = warmup_steps if self.model is not None: self._initialize_model_properties(args, kwargs) if self.initial_params: z = {k: v.detach() for k, v in self.initial_params.items()} z_grads, potential_energy = potential_grad(self.potential_fn, z) else: z_grads, potential_energy = {}, self.potential_fn(self.initial_params) self._cache(self.initial_params, potential_energy, z_grads) if self.initial_params: self._initialize_adapter()
def _get_init_params(model, model_args, model_kwargs, transforms, potential_fn, prototype_params,
                     max_tries_initial_params=100, num_chains=1, strategy="uniform"):
    """Find ``num_chains`` sets of initial parameters with finite potential energy.

    Candidates are drawn in one of two ways:

    * ``strategy="uniform"``: each site is sampled from ``Uniform(-2, 2)``
      with the shape of its current value.
    * ``strategy="prior"``: the model is traced, and the sites named in
      ``prototype_params`` are read and mapped through ``transforms``.

    A candidate is accepted when the potential energy and every gradient entry
    are finite.

    :param dict prototype_params: maps site names to prototype values, which
        give the shapes for uniform draws. When it is empty it is returned
        as-is.
    :param int max_tries_initial_params: per-chain attempt budget. At most
        ``num_chains * max_tries_initial_params`` candidates are drawn.
    :param int num_chains: number of valid candidates to collect.
    :param str strategy: ``"uniform"`` or ``"prior"``.
    :returns: for a single chain, a dict of site values. Otherwise, a dict of
        values combined with ``torch.stack`` over the chains.
    :raises ValueError: if ``strategy`` is unknown, or if the attempt budget is
        exhausted before ``num_chains`` valid candidates are found.
    """
    params = prototype_params
    # For empty models, exit early
    if not params:
        return params
    if strategy not in ("uniform", "prior"):
        raise ValueError("Unknown initialization strategy: {}".format(strategy))

    params_per_chain = defaultdict(list)
    num_found = 0
    # Bound the total number of draws (as _find_valid_initial_params does), so
    # that a model with no valid region raises instead of looping forever.
    for _ in range(num_chains * max_tries_initial_params):
        if strategy == "uniform":
            params = {
                k: dist.Uniform(v.new_full(v.shape, -2), v.new_full(v.shape, 2)).sample()
                for k, v in params.items()
            }
        else:  # strategy == "prior"
            trace = poutine.trace(model).get_trace(*model_args, **model_kwargs)
            samples = {name: trace.nodes[name]["value"].detach() for name in params}
            params = {k: transforms[k](v) for k, v in samples.items()}
        pe_grad, pe = potential_grad(potential_fn, params)

        # Reject candidates whose energy or any gradient entry is inf/NaN.
        if torch.isfinite(pe) and all(map(torch.all, map(torch.isfinite, pe_grad.values()))):
            for k, v in params.items():
                params_per_chain[k].append(v)
            num_found += 1
            if num_found == num_chains:
                if num_chains == 1:
                    return {k: v[0] for k, v in params_per_chain.items()}
                return {k: torch.stack(v) for k, v in params_per_chain.items()}
    raise ValueError(
        "Model specification seems incorrect - cannot find valid initial params."
    )
def sample(self, params):
    """Perform one NUTS transition from the cached (or given) state.

    Repeatedly doubles the trajectory (via ``self._build_tree``) in a randomly
    chosen direction until a subtree diverges, a U-turn is detected, or
    ``self._max_tree_depth`` is reached. After each doubling it may move the
    proposal into the new subtree. The method also advances the step counter
    and updates the acceptance statistics and the divergence record. During
    warmup it additionally steps the adapter.

    :param dict params: current values of the latent sites. Used only when the
        cache is empty (``z is None``).
    :returns: a copy of the resulting state dict ``z``. For models without
        latent sites, the (empty) cached state is returned directly.
    """
    z, potential_energy, z_grads = self._fetch_from_cache()
    # recompute PE when cache is cleared
    if z is None:
        z = params
        z_grads, potential_energy = potential_grad(self.potential_fn, z)
        self._cache(z, potential_energy, z_grads)
    # return early if no sample sites
    elif len(z) == 0:
        self._t += 1
        # Nothing is proposed, so every step counts as accepted.
        self._mean_accept_prob = 1.
        if self._t > self._warmup_steps:
            self._accept_cnt += 1
        return z
    r, r_unscaled = self._sample_r(name="r_t={}".format(self._t))
    energy_current = self._kinetic_energy(r_unscaled) + potential_energy

    # Ideally, following a symplectic integrator trajectory, the energy is constant.
    # In that case, we can sample the proposal uniformly, and there is no need to use "slice".
    # However, it is not the case for real situation: there are errors during the computation.
    # To deal with that problem, as in [1], we introduce an auxiliary "slice" variable (denoted
    # by u).
    # The sampling process goes as follows:
    #   first sampling u from initial state (z_0, r_0) according to
    #     u ~ Uniform(0, p(z_0, r_0)),
    #   then sampling state (z, r) from the integrator trajectory according to
    #     (z, r) ~ Uniform({(z', r') in trajectory | p(z', r') >= u}).
    #
    # For more information about slice sampling method, see [3].
    # For another version of NUTS which uses multinomial sampling instead of slice sampling,
    # see [2].

    if self.use_multinomial_sampling:
        log_slice = -energy_current
    else:
        # Rather than sampling the slice variable from `Uniform(0, exp(-energy))`, we can
        # sample log_slice directly using `energy`, so as to avoid potential underflow or
        # overflow issues ([2]). Since -log(U) ~ Exponential(1) for U ~ Uniform(0, 1),
        # log(u) = -energy - Exponential(1).
        slice_exp_term = pyro.sample(
            "slicevar_exp_t={}".format(self._t),
            dist.Exponential(scalar_like(energy_current, 1.)))
        log_slice = -energy_current - slice_exp_term

    # Both ends of the trajectory start at the current state.
    z_left = z_right = z
    r_left = r_right = r
    r_left_unscaled = r_right_unscaled = r_unscaled
    z_left_grads = z_right_grads = z_grads
    accepted = False
    r_sum = r_unscaled
    sum_accept_probs = 0.
    num_proposals = 0
    # With multinomial sampling tree_weight is a log-weight (combined with
    # _logaddexp), so it starts at 0. Otherwise it is a plain weight (combined
    # by addition), so it starts at 1.
    tree_weight = scalar_like(energy_current, 0. if self.use_multinomial_sampling else 1.)

    # Temporarily disable distributions args checking as
    # NaNs are expected during step size adaptation.
    with optional(pyro.validation_enabled(False), self._t < self._warmup_steps):
        # doubling process, stop when turning or diverging
        tree_depth = 0
        while tree_depth < self._max_tree_depth:
            # Fair coin: 1 extends the trajectory to the right, 0 to the left.
            direction = pyro.sample(
                "direction_t={}_treedepth={}".format(self._t, tree_depth),
                dist.Bernoulli(probs=scalar_like(tree_weight, 0.5)))
            direction = int(direction.item())
            if direction == 1:
                # go to the right, start from the right leaf of current tree
                new_tree = self._build_tree(z_right, r_right, z_right_grads, log_slice,
                                            direction, tree_depth, energy_current)
                # update leaf for the next doubling process
                z_right = new_tree.z_right
                r_right = new_tree.r_right
                r_right_unscaled = new_tree.r_right_unscaled
                z_right_grads = new_tree.z_right_grads
            else:
                # go to the left, start from the left leaf of current tree
                new_tree = self._build_tree(z_left, r_left, z_left_grads, log_slice,
                                            direction, tree_depth, energy_current)
                z_left = new_tree.z_left
                r_left = new_tree.r_left
                r_left_unscaled = new_tree.r_left_unscaled
                z_left_grads = new_tree.z_left_grads

            # Accumulate acceptance statistics over every subtree built so far.
            sum_accept_probs = sum_accept_probs + new_tree.sum_accept_probs
            num_proposals = num_proposals + new_tree.num_proposals

            # stop doubling
            if new_tree.diverging:
                if self._t >= self._warmup_steps:
                    self._divergences.append(self._t - self._warmup_steps)
                break

            if new_tree.turning:
                break

            tree_depth += 1

            # Probability of moving the proposal into the new subtree: the new
            # subtree's weight relative to the current tree's weight (computed
            # in log space under multinomial sampling).
            if self.use_multinomial_sampling:
                new_tree_prob = (new_tree.weight - tree_weight).exp()
            else:
                new_tree_prob = new_tree.weight / tree_weight
            rand = pyro.sample(
                "rand_t={}_treedepth={}".format(self._t, tree_depth),
                dist.Uniform(scalar_like(new_tree_prob, 0.),
                             scalar_like(new_tree_prob, 1.)))
            if rand < new_tree_prob:
                accepted = True
                z = new_tree.z_proposal
                z_grads = new_tree.z_proposal_grads
                self._cache(z, new_tree.z_proposal_pe, z_grads)

            # Running per-site momentum sum across the trajectory, passed to
            # the U-turn check (assumes new_tree.r_sum is the subtree's
            # momentum sum -- confirm in _build_tree).
            r_sum = {
                site_names: r_sum[site_names] + new_tree.r_sum[site_names]
                for site_names in r_unscaled
            }
            if self._is_turning(r_left_unscaled, r_right_unscaled, r_sum):
                # stop doubling
                break
            else:
                # update tree_weight
                if self.use_multinomial_sampling:
                    tree_weight = _logaddexp(tree_weight, new_tree.weight)
                else:
                    tree_weight = tree_weight + new_tree.weight

    # Average acceptance probability over all proposals in the trajectory.
    accept_prob = sum_accept_probs / num_proposals

    self._t += 1
    if self._t > self._warmup_steps:
        n = self._t - self._warmup_steps
        if accepted:
            self._accept_cnt += 1
    else:
        n = self._t
        # The adapter is only stepped during warmup.
        self._adapter.step(self._t, z, accept_prob, z_grads)

    # Running mean of accept_prob. n restarts at 1 on the first post-warmup
    # step, so the mean covers only the current phase (warmup or sampling).
    self._mean_accept_prob += (accept_prob.item() - self._mean_accept_prob) / n
    return z.copy()