def _build_basetree(self, z, r, z_grads, log_slice, direction, energy_current):
    """Build a depth-0 tree consisting of a single leapfrog step from ``(z, r)``.

    The one new state is used at the same time as the left leaf, the right leaf
    and the proposal of the returned ``_TreeInfo``.

    :param z: Latent values at the starting point.
    :param r: Momentum at the starting point, indexed by site name.
    :param z_grads: Potential-energy gradients at ``z``, forwarded to ``velocity_verlet``.
    :param log_slice: Log of the slice variable; it is added to the new energy to form
        the "sliced" energy.
    :param int direction: ``1`` integrates with ``+step_size``; any other value
        integrates with ``-step_size``.
    :param energy_current: Energy of the initial state of the whole trajectory; it is
        used for the acceptance-probability statistic.
    :return: ``_TreeInfo`` for the one-state tree. It has weight/``diverging`` flags, an
        acceptance probability and ``num_proposals == 1``. ``turning`` is always ``False``.
    """
    if direction == 1:
        signed_step = self.step_size
    else:
        signed_step = -self.step_size
    z_next, r_next, grads_next, pe_next = velocity_verlet(
        z, r, self.potential_fn, self.inverse_mass_matrix, signed_step, z_grads=z_grads)

    # Concatenate the per-site momenta into one flat vector, visiting sites in
    # sorted-name order so the layout is deterministic.
    r_next_flat = torch.cat([r_next[name].reshape(-1) for name in sorted(r_next)])

    hamiltonian = pe_next + self._kinetic_energy(r_next)
    # A NaN energy is treated as +inf so the state is rejected / flagged as diverging.
    if torch_isnan(hamiltonian):
        hamiltonian = scalar_like(hamiltonian, float("inf"))

    sliced_energy = hamiltonian + log_slice
    is_diverging = sliced_energy > self._max_sliced_energy
    energy_change = hamiltonian - energy_current
    accept_prob = (-energy_change).exp().clamp(max=1.0)

    if not self.use_multinomial_sampling:
        # Slice sampling: along the trajectory, states with p(z, r) < u (i.e. a
        # positive sliced energy) are eliminated and get weight 0. Because of this
        # (and the stop-doubling conditions), the weight of a binary tree need not
        # equal 2^tree_depth.
        in_slice = sliced_energy <= 0
        tree_weight = scalar_like(sliced_energy, 1. if in_slice else 0.)
    else:
        # Multinomial sampling stores minus the sliced energy as the (log-scale) weight.
        tree_weight = -sliced_energy

    # Left leaf, right leaf and proposal are all the freshly computed state.
    return _TreeInfo(z_next, r_next, grads_next,   # left leaf
                     z_next, r_next, grads_next,   # right leaf
                     z_next, pe_next, grads_next,  # proposal, its PE and grads
                     r_next_flat, tree_weight, False, is_diverging, accept_prob, 1)
def sample(self, params):
    """Draw one HMC sample by simulating a fixed-length trajectory and then applying a Metropolis accept/reject step.

    :param params: Latent values to start from when the kernel cache is empty
        (``_fetch_from_cache`` returned ``z is None``). They are also returned
        unchanged when the cached state has no sample sites.
    :return: ``params`` in the no-sample-site case. Otherwise ``z.copy()``, where ``z``
        is either the accepted proposal or the previous state.
    """
    z, potential_energy, z_grads = self._fetch_from_cache()
    # recompute PE (and its gradients) when cache is cleared
    if z is None:
        z = params
        z_grads, potential_energy = potential_grad(self.potential_fn, z)
        self._cache(z, potential_energy, z_grads)
    # return early if no sample sites
    elif len(z) == 0:
        self._t += 1
        self._mean_accept_prob = 1.
        if self._t > self._warmup_steps:
            self._accept_cnt += 1
        return params
    # Resample momentum. The kinetic energy is evaluated on the unscaled copy.
    r, r_unscaled = self._sample_r(name="r_t={}".format(self._t))
    energy_current = self._kinetic_energy(r_unscaled) + potential_energy

    # Temporarily disable distributions args checking as
    # NaNs are expected during step size adaptation (only while self._t < warmup).
    with optional(pyro.validation_enabled(False), self._t < self._warmup_steps):
        # Integrate the trajectory. This is presumably self.num_steps leapfrog steps
        # of size self.step_size; confirm against velocity_verlet.
        z_new, r_new, z_grads_new, potential_energy_new = velocity_verlet(
            z, r, self.potential_fn, self.mass_matrix_adapter.kinetic_grad,
            self.step_size, self.num_steps, z_grads=z_grads)
        # apply Metropolis correction.
        r_new_unscaled = self.mass_matrix_adapter.unscale(r_new)
        energy_proposal = self._kinetic_energy(r_new_unscaled) + potential_energy_new
    delta_energy = energy_proposal - energy_current
    # handle the NaN case which may be the case for a diverging trajectory
    # when using a large step size.
    delta_energy = scalar_like(delta_energy, float("inf")) if torch_isnan(delta_energy) else delta_energy
    # Divergences are recorded only after warmup, indexed from the end of warmup.
    if delta_energy > self._max_sliced_energy and self._t >= self._warmup_steps:
        self._divergences.append(self._t - self._warmup_steps)

    # Metropolis acceptance probability min(1, exp(-dE)).
    accept_prob = (-delta_energy).exp().clamp(max=1.)
    rand = pyro.sample("rand_t={}".format(self._t), dist.Uniform(scalar_like(accept_prob, 0.),
                                                                  scalar_like(accept_prob, 1.)))
    accepted = False
    if rand < accept_prob:
        accepted = True
        z = z_new
        z_grads = z_grads_new
        self._cache(z, potential_energy_new, z_grads)

    self._t += 1
    # n counts steps within the current phase. It restarts at 1 on the first
    # post-warmup step, so the running mean below restarts there as well.
    if self._t > self._warmup_steps:
        n = self._t - self._warmup_steps
        if accepted:
            self._accept_cnt += 1
    else:
        n = self._t

    # Let the adapter (step size / mass matrix) see this iteration's statistics.
    self._adapter.step(self._t, z, accept_prob, z_grads)

    # Incremental running mean of the acceptance probability.
    self._mean_accept_prob += (accept_prob.item() - self._mean_accept_prob) / n
    return z.copy()
def _build_basetree(self, z, r, z_grads, log_slice, direction, energy_current):
    """Build a depth-0 tree consisting of a single leapfrog step from ``(z, r)``.

    The one new state is used at the same time as the left leaf, the right leaf and
    the proposal of the returned ``_TreeInfo``. Both the momentum as produced by the
    integrator and its ``mass_matrix_adapter.unscale`` counterpart are stored.

    :param z: Latent values at the starting point.
    :param r: Momentum at the starting point, as consumed by ``velocity_verlet``.
    :param z_grads: Potential-energy gradients at ``z``, forwarded to ``velocity_verlet``.
    :param log_slice: Log of the slice variable; it is added to the new energy to form
        the "sliced" energy.
    :param int direction: ``1`` integrates with ``+step_size``; any other value
        integrates with ``-step_size``.
    :param energy_current: Energy of the initial state of the whole trajectory; it is
        used for the acceptance-probability statistic.
    :return: ``_TreeInfo`` for the one-state tree. ``turning`` is always ``False`` and
        ``num_proposals == 1``.
    """
    if direction == 1:
        signed_step = self.step_size
    else:
        signed_step = -self.step_size
    z_next, r_next, grads_next, pe_next = velocity_verlet(
        z,
        r,
        self.potential_fn,
        self.mass_matrix_adapter.kinetic_grad,
        signed_step,
        z_grads=z_grads,
    )
    r_next_unscaled = self.mass_matrix_adapter.unscale(r_next)

    hamiltonian = pe_next + self._kinetic_energy(r_next_unscaled)
    # A NaN energy is treated as +inf so the state is rejected / flagged as diverging.
    if torch_isnan(hamiltonian):
        hamiltonian = scalar_like(hamiltonian, float("inf"))

    sliced_energy = hamiltonian + log_slice
    is_diverging = sliced_energy > self._max_sliced_energy
    energy_change = hamiltonian - energy_current
    accept_prob = (-energy_change).exp().clamp(max=1.0)

    if not self.use_multinomial_sampling:
        # Slice sampling: along the trajectory, states with p(z, r) < u (i.e. a
        # positive sliced energy) are eliminated and get weight 0. Because of this
        # (and the stop-doubling conditions), the weight of a binary tree need not
        # equal 2^tree_depth.
        in_slice = sliced_energy <= 0
        tree_weight = scalar_like(sliced_energy, 1.0 if in_slice else 0.0)
    else:
        # Multinomial sampling stores minus the sliced energy as the (log-scale) weight.
        tree_weight = -sliced_energy

    return _TreeInfo(
        z_next, r_next, r_next_unscaled, grads_next,  # left leaf
        z_next, r_next, r_next_unscaled, grads_next,  # right leaf
        z_next, pe_next, grads_next,  # proposal, its PE and grads
        r_next_unscaled,  # r_sum of a one-state tree is just its own (unscaled) momentum
        tree_weight,
        False,
        is_diverging,
        accept_prob,
        1,
    )
def not_pooled(at_bats, hits):
    r"""
    Fully unpooled model: every player has an independent success probability.

    The number of hits for player $i$ over their at bats is Binomial with success
    probability $\phi_i \sim Uniform(0, 1)$. No information is shared between
    players.

    :param (torch.Tensor) at_bats: Number of at bats for each player.
    :param (torch.Tensor) hits: Number of hits for the given at bats.
    :return: Number of hits predicted by the model.
    """
    with pyro.plate("num_players", at_bats.shape[0]):
        lower = scalar_like(at_bats, 0)
        upper = scalar_like(at_bats, 1)
        phi = pyro.sample("phi", Uniform(lower, upper))
        return pyro.sample("obs", Binomial(at_bats, phi), obs=hits)
def partially_pooled_with_logit(at_bats, hits):
    r"""
    Partially pooled model on the logit scale.

    Each player's number of hits is Binomial with logit $\alpha_i$. The logits
    $\alpha_i$ are normally distributed around a shared mean with a shared scale,
    and these hyperparameters have the priors $loc \sim Normal(-1, 1)$ and
    $scale \sim HalfCauchy(1)$.

    :param (torch.Tensor) at_bats: Number of at bats for each player.
    :param (torch.Tensor) hits: Number of hits for the given at bats.
    :return: Number of hits predicted by the model.
    """
    player_count = at_bats.shape[0]
    # Shared hyperparameters for the per-player logits.
    loc = pyro.sample("loc", Normal(scalar_like(at_bats, -1), scalar_like(at_bats, 1)))
    scale = pyro.sample("scale", HalfCauchy(scale=scalar_like(at_bats, 1)))
    with pyro.plate("num_players", player_count):
        alpha = pyro.sample("alpha", Normal(loc, scale))
        return pyro.sample("obs", Binomial(at_bats, logits=alpha), obs=hits)
def partially_pooled(at_bats, hits):
    r"""
    Partially pooled model with a shared Beta prior on success probabilities.

    Each player's number of hits is Binomial with success probability
    $\phi_i \sim Beta(m \kappa, (1 - m) \kappa)$. The hyperpriors are
    $m \sim Uniform(0, 1)$ (the prior mean) and $\kappa \sim Pareto(1, 1.5)$
    (the concentration).

    :param (torch.Tensor) at_bats: Number of at bats for each player.
    :param (torch.Tensor) hits: Number of hits for the given at bats.
    :return: Number of hits predicted by the model.
    """
    player_count = at_bats.shape[0]
    m = pyro.sample("m", Uniform(scalar_like(at_bats, 0), scalar_like(at_bats, 1)))
    kappa = pyro.sample("kappa", Pareto(scalar_like(at_bats, 1), scalar_like(at_bats, 1.5)))
    # Beta concentrations c1 = m * kappa and c2 = (1 - m) * kappa.
    conc1 = m * kappa
    conc0 = (1 - m) * kappa
    with pyro.plate("num_players", player_count):
        phi = pyro.sample("phi", Beta(conc1, conc0))
        return pyro.sample("obs", Binomial(at_bats, phi), obs=hits)
def sample(self, params):
    """Draw one NUTS sample by repeatedly doubling a leapfrog trajectory.

    In each doubling iteration a direction is drawn with probability 1/2 and a new
    subtree of depth ``tree_depth`` is grown from the corresponding leaf. Doubling
    stops when the new subtree diverges or turns, when the whole trajectory makes a
    U-turn (``self._is_turning``), or when ``self._max_tree_depth`` is reached. A
    proposal is chosen from the trajectory by slice sampling or by multinomial
    sampling, depending on ``self.use_multinomial_sampling``.

    :param params: Latent values to start from when the kernel cache is empty.
    :return: ``z`` unchanged when there are no sample sites. Otherwise ``z.copy()``,
        where ``z`` is the last accepted proposal or the previous state.
    """
    z, potential_energy, z_grads = self._fetch_from_cache()
    # recompute PE when cache is cleared
    if z is None:
        z = params
        potential_energy = self.potential_fn(z)
        # NOTE(review): z_grads is left as returned by the cache (presumably None);
        # it is forwarded to _build_tree as-is.
        self._cache(z, potential_energy)
    # return early if no sample sites
    elif len(z) == 0:
        self._t += 1
        self._mean_accept_prob = 1.
        if self._t > self._warmup_steps:
            self._accept_cnt += 1
        return z
    r, r_flat = self._sample_r(name="r_t={}".format(self._t))
    energy_current = self._kinetic_energy(r) + potential_energy

    # Ideally, following a symplectic integrator trajectory, the energy is constant.
    # In that case, we can sample the proposal uniformly, and there is no need to use "slice".
    # In practice, however, this does not hold: the computation introduces integration errors.
    # To deal with that problem, as in [1], we introduce an auxiliary "slice" variable (denoted
    # by u).
    # The sampling process goes as follows:
    #   first sampling u from initial state (z_0, r_0) according to
    #     u ~ Uniform(0, p(z_0, r_0)),
    #   then sampling state (z, r) from the integrator trajectory according to
    #     (z, r) ~ Uniform({(z', r') in trajectory | p(z', r') >= u}).
    #
    # For more information about slice sampling method, see [3].
    # For another version of NUTS which uses multinomial sampling instead of slice sampling,
    # see [2].

    if self.use_multinomial_sampling:
        log_slice = -energy_current
    else:
        # Rather than sampling the slice variable from `Uniform(0, exp(-energy))`, we can
        # sample log_slice directly using `energy`, so as to avoid potential underflow or
        # overflow issues ([2]).
        slice_exp_term = pyro.sample("slicevar_exp_t={}".format(self._t),
                                     dist.Exponential(scalar_like(energy_current, 1.)))
        log_slice = -energy_current - slice_exp_term

    # The trajectory initially consists of the current state only.
    z_left = z_right = z
    r_left = r_right = r
    z_left_grads = z_right_grads = z_grads
    accepted = False
    r_sum = r_flat
    sum_accept_probs = 0.
    num_proposals = 0
    # Initial weight of the one-state trajectory: 0 in log space (multinomial), 1 otherwise.
    tree_weight = scalar_like(energy_current, 0. if self.use_multinomial_sampling else 1.)

    # Temporarily disable distributions args checking as
    # NaNs are expected during step size adaptation.
    with optional(pyro.validation_enabled(False), self._t < self._warmup_steps):
        # doubling process, stop when turning or diverging
        tree_depth = 0
        while tree_depth < self._max_tree_depth:
            direction = pyro.sample("direction_t={}_treedepth={}".format(self._t, tree_depth),
                                    dist.Bernoulli(probs=scalar_like(tree_weight, 0.5)))
            direction = int(direction.item())
            if direction == 1:  # go to the right, start from the right leaf of current tree
                new_tree = self._build_tree(z_right, r_right, z_right_grads, log_slice,
                                            direction, tree_depth, energy_current)
                # update leaf for the next doubling process
                z_right = new_tree.z_right
                r_right = new_tree.r_right
                z_right_grads = new_tree.z_right_grads
            else:  # go to the left, start from the left leaf of current tree
                new_tree = self._build_tree(z_left, r_left, z_left_grads, log_slice,
                                            direction, tree_depth, energy_current)
                z_left = new_tree.z_left
                r_left = new_tree.r_left
                z_left_grads = new_tree.z_left_grads

            # Acceptance statistics include states from subtrees that end up stopping
            # the doubling.
            sum_accept_probs = sum_accept_probs + new_tree.sum_accept_probs
            num_proposals = num_proposals + new_tree.num_proposals

            # stop doubling
            if new_tree.diverging:
                # Divergences are recorded only after warmup, indexed from the end of warmup.
                if self._t >= self._warmup_steps:
                    self._divergences.append(self._t - self._warmup_steps)
                break

            if new_tree.turning:
                break

            tree_depth += 1

            # Probability of moving the proposal into the new subtree: the ratio of
            # the new subtree's weight to the existing trajectory's weight (computed
            # in log space for multinomial sampling).
            if self.use_multinomial_sampling:
                new_tree_prob = (new_tree.weight - tree_weight).exp()
            else:
                new_tree_prob = new_tree.weight / tree_weight
            rand = pyro.sample("rand_t={}_treedepth={}".format(self._t, tree_depth),
                               dist.Uniform(scalar_like(new_tree_prob, 0.),
                                            scalar_like(new_tree_prob, 1.)))
            if rand < new_tree_prob:
                accepted = True
                z = new_tree.z_proposal
                self._cache(z, new_tree.z_proposal_pe, new_tree.z_proposal_grads)

            # U-turn check over the whole trajectory using the accumulated momentum.
            r_sum = r_sum + new_tree.r_sum
            if self._is_turning(r_left, r_right, r_sum):  # stop doubling
                break
            else:  # update tree_weight
                if self.use_multinomial_sampling:
                    tree_weight = _logaddexp(tree_weight, new_tree.weight)
                else:
                    tree_weight = tree_weight + new_tree.weight

    # Average acceptance probability over all states visited in this iteration.
    accept_prob = sum_accept_probs / num_proposals

    self._t += 1
    # n counts steps within the current phase. It restarts at 1 on the first
    # post-warmup step, so the running mean below restarts there as well.
    if self._t > self._warmup_steps:
        n = self._t - self._warmup_steps
        if accepted:
            self._accept_cnt += 1
    else:
        n = self._t
    self._adapter.step(self._t, z, accept_prob)

    # Incremental running mean of the acceptance probability.
    self._mean_accept_prob += (accept_prob.item() - self._mean_accept_prob) / n
    return z.copy()
def _build_tree(self, z, r, z_grads, log_slice, direction, tree_depth, energy_current):
    """Recursively build a binary tree of up to ``2 ** tree_depth`` base (one-step) trees.

    The first half is grown from ``(z, r)``. If it neither turns nor diverges, the
    second half is grown from the first half's outer leaf in ``direction``. A
    proposal is then chosen from one of the halves with probability proportional to
    the halves' weights.

    :param z: Latent values at the leaf to grow from.
    :param r: Momentum at that leaf.
    :param z_grads: Potential-energy gradients at ``z``.
    :param log_slice: Log slice value, forwarded to the base trees.
    :param int direction: ``1`` grows to the right; any other value grows to the left.
    :param int tree_depth: Depth of the tree to build; ``0`` means a single step.
    :param energy_current: Energy of the initial state of the whole trajectory.
    :return: ``_TreeInfo`` with the outer leaves, the chosen proposal, accumulated
        momentum sum, weight, turning/diverging flags and acceptance statistics.
        If the first half turns or diverges, that half is returned as-is.
    """
    if tree_depth == 0:
        return self._build_basetree(z, r, z_grads, log_slice, direction, energy_current)

    # build the first half of tree
    half_tree = self._build_tree(z, r, z_grads, log_slice,
                                 direction, tree_depth-1, energy_current)
    z_proposal = half_tree.z_proposal
    z_proposal_pe = half_tree.z_proposal_pe
    z_proposal_grads = half_tree.z_proposal_grads

    # Check conditions to stop doubling. If we meet that condition,
    # there is no need to build the other tree.
    if half_tree.turning or half_tree.diverging:
        return half_tree

    # Else, build remaining half of tree.
    # If we are going to the right, start from the right leaf of the first half.
    if direction == 1:
        z = half_tree.z_right
        r = half_tree.r_right
        z_grads = half_tree.z_right_grads
    else:  # otherwise, start from the left leaf of the first half
        z = half_tree.z_left
        r = half_tree.r_left
        z_grads = half_tree.z_left_grads
    other_half_tree = self._build_tree(z, r, z_grads, log_slice,
                                       direction, tree_depth-1, energy_current)

    # Combine the halves' weights: log-sum-exp for multinomial (log-scale) weights,
    # plain sum for slice-sampling weights.
    if self.use_multinomial_sampling:
        tree_weight = _logaddexp(half_tree.weight, other_half_tree.weight)
    else:
        tree_weight = half_tree.weight + other_half_tree.weight
    sum_accept_probs = half_tree.sum_accept_probs + other_half_tree.sum_accept_probs
    num_proposals = half_tree.num_proposals + other_half_tree.num_proposals
    r_sum = half_tree.r_sum + other_half_tree.r_sum

    # The probability that the proposal is taken from the other half
    # is that half's share of the combined weight.
    if self.use_multinomial_sampling:
        other_half_tree_prob = (other_half_tree.weight - tree_weight).exp()
    else:
        # For the special case that the weights of each half are both 0,
        # we choose the proposal from the first half
        # (any is fine, because the probability of picking it at the end is 0!).
        other_half_tree_prob = (other_half_tree.weight / tree_weight if tree_weight > 0
                                else scalar_like(tree_weight, 0.))
    is_other_half_tree = pyro.sample("is_other_half_tree",
                                     dist.Bernoulli(probs=other_half_tree_prob))

    if is_other_half_tree == 1:
        z_proposal = other_half_tree.z_proposal
        z_proposal_pe = other_half_tree.z_proposal_pe
        z_proposal_grads = other_half_tree.z_proposal_grads

    # leaves of the full tree are determined by the direction
    if direction == 1:
        z_left = half_tree.z_left
        r_left = half_tree.r_left
        z_left_grads = half_tree.z_left_grads
        z_right = other_half_tree.z_right
        r_right = other_half_tree.r_right
        z_right_grads = other_half_tree.z_right_grads
    else:
        z_left = other_half_tree.z_left
        r_left = other_half_tree.r_left
        z_left_grads = other_half_tree.z_left_grads
        z_right = half_tree.z_right
        r_right = half_tree.r_right
        z_right_grads = half_tree.z_right_grads

    # We already check if first half tree is turning. Now, we check
    # if the other half tree or full tree are turning.
    turning = other_half_tree.turning or self._is_turning(r_left, r_right, r_sum)

    # The divergence is checked by the second half tree (the first half is already checked).
    diverging = other_half_tree.diverging

    return _TreeInfo(z_left, r_left, z_left_grads, z_right, r_right, z_right_grads, z_proposal,
                     z_proposal_pe, z_proposal_grads, r_sum, tree_weight, turning, diverging,
                     sum_accept_probs, num_proposals)