Example #1
0
File: nuts.py Project: xidulu/pyro
    def _build_basetree(self, z, r, z_grads, log_slice, direction,
                        energy_current):
        """Build a depth-0 tree from a single ``velocity_verlet`` call.

        Integrates from ``(z, r)`` with step ``self.step_size`` when
        ``direction == 1`` and ``-self.step_size`` otherwise. No ``num_steps``
        is passed, presumably meaning one leapfrog step. The new state is then
        scored against the slice ``log_slice``.

        :returns: a ``_TreeInfo`` in which the new state fills the left-leaf,
            right-leaf and proposal slots. Its weight is ``-sliced_energy`` (a
            log-weight) under multinomial sampling; otherwise it is 1 when the
            state lies inside the slice and 0 when it does not. The tree is
            never turning, carries ``min(1, exp(-(E_new - energy_current)))``
            as its acceptance statistic, and counts as one proposal.
        """
        if direction == 1:
            signed_step = self.step_size
        else:
            signed_step = -self.step_size
        z_next, r_next, z_grads, pe_next = velocity_verlet(
            z, r, self.potential_fn, self.inverse_mass_matrix, signed_step,
            z_grads=z_grads)
        # Flattened momentum (sites in sorted-name order); it seeds r_sum,
        # which the U-turn check consumes.
        r_next_flat = torch.cat([r_next[name].reshape(-1) for name in sorted(r_next)])
        energy_next = pe_next + self._kinetic_energy(r_next)
        if torch_isnan(energy_next):
            # Treat a NaN energy as +inf so the state compares as arbitrarily bad.
            energy_next = scalar_like(energy_next, float("inf"))
        sliced_energy = energy_next + log_slice
        diverging = sliced_energy > self._max_sliced_energy
        delta_energy = energy_next - energy_current
        accept_prob = (-delta_energy).exp().clamp(max=1.0)

        if not self.use_multinomial_sampling:
            # Slice sampling discards states with p(z, r) < u (sliced energy > 0).
            # Because of this, a tree's weight need not equal 2 ** tree_depth.
            in_slice = sliced_energy <= 0
            tree_weight = scalar_like(sliced_energy, 1. if in_slice else 0.)
        else:
            tree_weight = -sliced_energy

        return _TreeInfo(z_next, r_next, z_grads,      # left leaf
                         z_next, r_next, z_grads,      # right leaf
                         z_next, pe_next, z_grads,     # proposal, its PE and grads
                         r_next_flat, tree_weight,     # r_sum, weight
                         False, diverging,             # turning, diverging
                         accept_prob, 1)               # sum_accept_probs, num_proposals
Example #2
0
    def sample(self, params):
        """Draw one HMC transition and return the resulting latent state.

        Resamples the momentum, then runs ``self.num_steps`` ``velocity_verlet``
        steps of size ``self.step_size``. The proposal is accepted or rejected
        with a Metropolis test on the change in total energy.

        Side effects:

        * increments ``self._t``;
        * refreshes the cache when the proposal is accepted;
        * records post-warmup divergences in ``self._divergences``;
        * during warmup, calls ``self._adapter.step``;
        * after warmup, counts acceptances in ``self._accept_cnt``;
        * keeps a running mean in ``self._mean_accept_prob``.

        :param params: current latent values (presumably a dict keyed by site
            name). Only read when the cache is empty.
        :return: a copy of the new latent state, or ``params`` itself when
            there are no sample sites.
        """
        z, potential_energy, z_grads = self._fetch_from_cache()
        if z is None:
            # Cache was cleared: recompute the potential energy and its gradient.
            z = params
            z_grads, potential_energy = potential_grad(self.potential_fn, z)
            self._cache(z, potential_energy, z_grads)
        elif len(z) == 0:
            # No sample sites: count the step as a trivially accepted no-op.
            self._t += 1
            self._mean_accept_prob = 1.
            if self._t > self._warmup_steps:
                self._accept_cnt += 1
            return params

        r, r_unscaled = self._sample_r(name="r_t={}".format(self._t))
        energy_current = self._kinetic_energy(r_unscaled) + potential_energy

        # NaNs are expected while the step size is being adapted, so
        # distribution argument checks are disabled during warmup iterations.
        with optional(pyro.validation_enabled(False), self._t < self._warmup_steps):
            z_prop, r_prop, z_grads_prop, pe_prop = velocity_verlet(
                z, r, self.potential_fn, self.mass_matrix_adapter.kinetic_grad,
                self.step_size, self.num_steps, z_grads=z_grads)
            r_prop_unscaled = self.mass_matrix_adapter.unscale(r_prop)
            energy_proposal = self._kinetic_energy(r_prop_unscaled) + pe_prop

        delta_energy = energy_proposal - energy_current
        if torch_isnan(delta_energy):
            # A diverging trajectory (e.g. from a too-large step size) can
            # produce NaN. Treat it as an infinitely bad proposal.
            delta_energy = scalar_like(delta_energy, float("inf"))
        if delta_energy > self._max_sliced_energy and self._t >= self._warmup_steps:
            self._divergences.append(self._t - self._warmup_steps)

        # Metropolis correction: accept with probability min(1, exp(-delta_energy)).
        accept_prob = (-delta_energy).exp().clamp(max=1.)
        uniform = dist.Uniform(scalar_like(accept_prob, 0.), scalar_like(accept_prob, 1.))
        rand = pyro.sample("rand_t={}".format(self._t), uniform)
        accepted = bool(rand < accept_prob)
        if accepted:
            z, z_grads = z_prop, z_grads_prop
            self._cache(z, pe_prop, z_grads)

        self._t += 1
        if self._t <= self._warmup_steps:
            n = self._t
            self._adapter.step(self._t, z, accept_prob, z_grads)
        else:
            n = self._t - self._warmup_steps
            if accepted:
                self._accept_cnt += 1

        self._mean_accept_prob += (accept_prob.item() - self._mean_accept_prob) / n
        return z.copy()
Example #3
0
    def _build_basetree(self, z, r, z_grads, log_slice, direction,
                        energy_current):
        """Build a depth-0 tree from a single ``velocity_verlet`` call.

        Integrates from ``(z, r)`` using the mass-matrix adapter's kinetic
        gradient, with step ``self.step_size`` when ``direction == 1`` and
        ``-self.step_size`` otherwise. No ``num_steps`` is passed, presumably
        meaning one leapfrog step. The new state is then scored against the
        slice ``log_slice``.

        :returns: a ``_TreeInfo`` in which the single new state (with scaled
            and unscaled momentum) fills every leaf/proposal slot, and the
            unscaled momentum serves as ``r_sum``. Its weight is
            ``-sliced_energy`` (a log-weight) under multinomial sampling;
            otherwise it is 1 inside the slice and 0 outside. The tree is never
            turning, carries ``min(1, exp(-(E_new - energy_current)))``, and
            counts as one proposal.
        """
        signed_step = -self.step_size if direction != 1 else self.step_size
        z_next, r_next, z_grads, pe_next = velocity_verlet(
            z, r, self.potential_fn, self.mass_matrix_adapter.kinetic_grad,
            signed_step, z_grads=z_grads,
        )
        r_next_unscaled = self.mass_matrix_adapter.unscale(r_next)
        energy_next = pe_next + self._kinetic_energy(r_next_unscaled)
        if torch_isnan(energy_next):
            # Treat a NaN energy as +inf so the state compares as arbitrarily bad.
            energy_next = scalar_like(energy_next, float("inf"))
        sliced_energy = energy_next + log_slice
        diverging = sliced_energy > self._max_sliced_energy
        delta_energy = energy_next - energy_current
        accept_prob = (-delta_energy).exp().clamp(max=1.0)

        if not self.use_multinomial_sampling:
            # Slice sampling discards states with p(z, r) < u (sliced energy > 0).
            # Because of this, a tree's weight need not equal 2 ** tree_depth.
            in_slice = sliced_energy <= 0
            tree_weight = scalar_like(sliced_energy, 1.0 if in_slice else 0.0)
        else:
            tree_weight = -sliced_energy

        return _TreeInfo(
            z_next, r_next, r_next_unscaled, z_grads,
            z_next, r_next, r_next_unscaled, z_grads,
            z_next, pe_next, z_grads,
            r_next_unscaled,
            tree_weight,
            False,
            diverging,
            accept_prob,
            1,
        )
Example #4
0
def not_pooled(at_bats, hits):
    r"""
    No-pooling model: each player's hits follow a Binomial distribution with
    that player's own success probability $\phi_i \sim \mathrm{Uniform}(0, 1)$.
    No information is shared between players.

    :param (torch.Tensor) at_bats: Number of at bats for each player.
    :param (torch.Tensor) hits: Number of hits for the given at bats.
    :return: Number of hits predicted by the model.
    """
    lower = scalar_like(at_bats, 0)
    upper = scalar_like(at_bats, 1)
    with pyro.plate("num_players", at_bats.shape[0]):
        success_prob = pyro.sample("phi", Uniform(lower, upper))
        return pyro.sample("obs", Binomial(at_bats, success_prob), obs=hits)
Example #5
0
def partially_pooled_with_logit(at_bats, hits):
    r"""
    Partial-pooling model with a logit link. Each player's hits follow a
    Binomial distribution whose logit $\alpha_i$ is drawn from
    $\mathrm{Normal}(\mu, \sigma)$. The location and scale are shared by all
    players, with priors $\mu \sim \mathrm{Normal}(-1, 1)$ and
    $\sigma \sim \mathrm{HalfCauchy}(1)$.

    :param (torch.Tensor) at_bats: Number of at bats for each player.
    :param (torch.Tensor) hits: Number of hits for the given at bats.
    :return: Number of hits predicted by the model.
    """
    player_count = at_bats.shape[0]
    loc_prior = Normal(scalar_like(at_bats, -1), scalar_like(at_bats, 1))
    scale_prior = HalfCauchy(scale=scalar_like(at_bats, 1))
    mu = pyro.sample("loc", loc_prior)
    sigma = pyro.sample("scale", scale_prior)
    with pyro.plate("num_players", player_count):
        logits = pyro.sample("alpha", Normal(mu, sigma))
        return pyro.sample("obs", Binomial(at_bats, logits=logits), obs=hits)
Example #6
0
def partially_pooled(at_bats, hits):
    r"""
    Partial-pooling model. Each player's hits follow a Binomial distribution
    with success probability $\phi_i \sim \mathrm{Beta}(c_1, c_2)$. The
    concentrations $c_1 = m \kappa$ and $c_2 = (1 - m) \kappa$ are shared by
    all players, with $m \sim \mathrm{Uniform}(0, 1)$ and
    $\kappa \sim \mathrm{Pareto}(1, 1.5)$.

    NOTE(review): ``scalar_like`` presumably copies ``at_bats``' dtype. If so,
    ``at_bats`` must be a floating-point tensor, or the Pareto ``1.5`` would
    be truncated. Confirm.

    :param (torch.Tensor) at_bats: Number of at bats for each player.
    :param (torch.Tensor) hits: Number of hits for the given at bats.
    :return: Number of hits predicted by the model.
    """
    mean_prior = Uniform(scalar_like(at_bats, 0), scalar_like(at_bats, 1))
    concentration_prior = Pareto(scalar_like(at_bats, 1), scalar_like(at_bats, 1.5))
    m = pyro.sample("m", mean_prior)
    kappa = pyro.sample("kappa", concentration_prior)
    c1, c2 = m * kappa, (1 - m) * kappa
    with pyro.plate("num_players", at_bats.shape[0]):
        phi = pyro.sample("phi", Beta(c1, c2))
        return pyro.sample("obs", Binomial(at_bats, phi), obs=hits)
Example #7
0
    def sample(self, params):
        """Draw one NUTS transition and return the resulting latent state.

        The trajectory grows by repeatedly doubling a binary tree of leapfrog
        steps (see ``_build_tree``) in a randomly chosen direction. Doubling
        stops when the tree U-turns, diverges, or has doubled
        ``self._max_tree_depth`` times. The next state is then chosen among the
        visited states by slice sampling or, when
        ``self.use_multinomial_sampling`` is set, by multinomial sampling.

        Side effects:

        * increments ``self._t``;
        * updates the cache whenever a new proposal is taken;
        * records post-warmup divergences in ``self._divergences``;
        * during warmup, calls ``self._adapter.step``;
        * after warmup, counts accepted transitions in ``self._accept_cnt``;
        * keeps a running mean of the acceptance statistic in
          ``self._mean_accept_prob``.

        :param params: current latent values (presumably a dict keyed by site
            name). Only read when the cache is empty.
        :return: a copy of the new latent state, or the cached (empty) state
            as-is when there are no sample sites.
        """
        z, potential_energy, z_grads = self._fetch_from_cache()
        # recompute PE when cache is cleared
        # (gradients are not recomputed here: z_grads keeps whatever
        # _fetch_from_cache returned, presumably None, and is passed on as the
        # starting leaf gradients)
        if z is None:
            z = params
            potential_energy = self.potential_fn(z)
            self._cache(z, potential_energy)
        # return early if no sample sites
        elif len(z) == 0:
            self._t += 1
            self._mean_accept_prob = 1.
            if self._t > self._warmup_steps:
                self._accept_cnt += 1
            return z
        # r_flat seeds r_sum, the running momentum sum used by the U-turn check.
        r, r_flat = self._sample_r(name="r_t={}".format(self._t))
        energy_current = self._kinetic_energy(r) + potential_energy

        # Ideally, along a symplectic integrator trajectory, the energy is constant.
        # In that case we could sample the proposal uniformly, with no need for a "slice".
        # In practice this is not the case, because the computation introduces errors.
        # To deal with that problem, as in [1], we introduce an auxiliary "slice" variable
        # (denoted by u).
        # The sampling process goes as follows:
        #   first sample u from the initial state (z_0, r_0) according to
        #     u ~ Uniform(0, p(z_0, r_0)),
        #   then sample a state (z, r) from the integrator trajectory according to
        #     (z, r) ~ Uniform({(z', r') in trajectory | p(z', r') >= u}).
        #
        # For more information about the slice sampling method, see [3].
        # For another version of NUTS which uses multinomial sampling instead of slice sampling,
        # see [2].

        if self.use_multinomial_sampling:
            log_slice = -energy_current
        else:
            # Rather than sampling the slice variable from `Uniform(0, exp(-energy))`, we can
            # sample log_slice directly using `energy`, so as to avoid potential underflow or
            # overflow issues ([2]).
            slice_exp_term = pyro.sample("slicevar_exp_t={}".format(self._t),
                                         dist.Exponential(scalar_like(energy_current, 1.)))
            log_slice = -energy_current - slice_exp_term

        # Both leaves of the (initially single-state) tree start at the current state.
        z_left = z_right = z
        r_left = r_right = r
        z_left_grads = z_right_grads = z_grads
        accepted = False
        r_sum = r_flat
        # Acceptance statistics accumulated over every visited state.
        sum_accept_probs = 0.
        num_proposals = 0
        # Weight of the current tree: log-weight 0 (multinomial) or a count of 1 (slice).
        tree_weight = scalar_like(energy_current, 0. if self.use_multinomial_sampling else 1.)

        # Temporarily disable distributions args checking as
        # NaNs are expected during step size adaptation.
        with optional(pyro.validation_enabled(False), self._t < self._warmup_steps):
            # doubling process, stop when turning or diverging
            tree_depth = 0
            while tree_depth < self._max_tree_depth:
                # Fair coin: 1 extends the trajectory to the right (positive step), 0 to the left.
                direction = pyro.sample("direction_t={}_treedepth={}".format(self._t, tree_depth),
                                        dist.Bernoulli(probs=scalar_like(tree_weight, 0.5)))
                direction = int(direction.item())
                if direction == 1:  # go to the right, start from the right leaf of current tree
                    new_tree = self._build_tree(z_right, r_right, z_right_grads, log_slice,
                                                direction, tree_depth, energy_current)
                    # update leaf for the next doubling process
                    z_right = new_tree.z_right
                    r_right = new_tree.r_right
                    z_right_grads = new_tree.z_right_grads
                else:  # go to the left, start from the left leaf of current tree
                    new_tree = self._build_tree(z_left, r_left, z_left_grads, log_slice,
                                                direction, tree_depth, energy_current)
                    z_left = new_tree.z_left
                    r_left = new_tree.r_left
                    z_left_grads = new_tree.z_left_grads

                sum_accept_probs = sum_accept_probs + new_tree.sum_accept_probs
                num_proposals = num_proposals + new_tree.num_proposals

                # stop doubling
                if new_tree.diverging:
                    if self._t >= self._warmup_steps:
                        self._divergences.append(self._t - self._warmup_steps)
                    break

                if new_tree.turning:
                    break

                tree_depth += 1

                # Move the proposal into the new subtree with probability
                # min(1, new_weight / current_weight). The ratio is formed in log space
                # for multinomial sampling, where weights are log-weights.
                if self.use_multinomial_sampling:
                    new_tree_prob = (new_tree.weight - tree_weight).exp()
                else:
                    new_tree_prob = new_tree.weight / tree_weight
                rand = pyro.sample("rand_t={}_treedepth={}".format(self._t, tree_depth),
                                   dist.Uniform(scalar_like(new_tree_prob, 0.),
                                                scalar_like(new_tree_prob, 1.)))
                if rand < new_tree_prob:
                    accepted = True
                    z = new_tree.z_proposal
                    self._cache(z, new_tree.z_proposal_pe, new_tree.z_proposal_grads)

                # U-turn check across the whole (merged) trajectory.
                r_sum = r_sum + new_tree.r_sum
                if self._is_turning(r_left, r_right, r_sum):  # stop doubling
                    break
                else:  # update tree_weight
                    if self.use_multinomial_sampling:
                        tree_weight = _logaddexp(tree_weight, new_tree.weight)
                    else:
                        tree_weight = tree_weight + new_tree.weight

        # NOTE(review): if self._max_tree_depth <= 0 the loop never runs, num_proposals
        # stays 0 and this division raises ZeroDivisionError; presumably the depth is
        # always >= 1. Confirm.
        accept_prob = sum_accept_probs / num_proposals

        self._t += 1
        if self._t > self._warmup_steps:
            n = self._t - self._warmup_steps
            if accepted:
                self._accept_cnt += 1
        else:
            n = self._t
            self._adapter.step(self._t, z, accept_prob)
        # Running mean over post-warmup iterations (or over warmup iterations during warmup).
        self._mean_accept_prob += (accept_prob.item() - self._mean_accept_prob) / n

        return z.copy()
Example #8
0
    def _build_tree(self, z, r, z_grads, log_slice, direction, tree_depth, energy_current):
        """Recursively build a subtree of up to ``2 ** tree_depth`` leapfrog states.

        At depth 0 this delegates to ``_build_basetree``. Otherwise it builds
        two halves of depth ``tree_depth - 1`` in ``direction``, the second
        starting from the outer leaf of the first. If the first half is
        already turning or diverging it is returned as-is and the second half
        is never built. The subtree's proposal comes from one of the two
        halves, chosen at random in proportion to their weights.

        :returns: a ``_TreeInfo`` holding:

            * the leftmost and rightmost states (with momentum and gradients);
            * the chosen proposal with its potential energy and gradients;
            * the summed momentum ``r_sum``;
            * the combined weight (a log-sum-exp of log-weights for
              multinomial sampling, a plain sum otherwise);
            * the turning and diverging flags;
            * the accumulated acceptance statistics.
        """
        if tree_depth == 0:
            return self._build_basetree(z, r, z_grads, log_slice, direction, energy_current)

        # build the first half of tree
        half_tree = self._build_tree(z, r, z_grads, log_slice,
                                     direction, tree_depth-1, energy_current)
        z_proposal = half_tree.z_proposal
        z_proposal_pe = half_tree.z_proposal_pe
        z_proposal_grads = half_tree.z_proposal_grads

        # Check conditions to stop doubling. If we meet that condition,
        #     there is no need to build the other tree.
        if half_tree.turning or half_tree.diverging:
            return half_tree

        # Else, build remaining half of tree.
        # If we are going to the right, start from the right leaf of the first half.
        if direction == 1:
            z = half_tree.z_right
            r = half_tree.r_right
            z_grads = half_tree.z_right_grads
        else:  # otherwise, start from the left leaf of the first half
            z = half_tree.z_left
            r = half_tree.r_left
            z_grads = half_tree.z_left_grads
        other_half_tree = self._build_tree(z, r, z_grads, log_slice,
                                           direction, tree_depth-1, energy_current)

        if self.use_multinomial_sampling:
            tree_weight = _logaddexp(half_tree.weight, other_half_tree.weight)
        else:
            tree_weight = half_tree.weight + other_half_tree.weight
        sum_accept_probs = half_tree.sum_accept_probs + other_half_tree.sum_accept_probs
        num_proposals = half_tree.num_proposals + other_half_tree.num_proposals
        r_sum = half_tree.r_sum + other_half_tree.r_sum

        # The probability that the proposal is taken from the second half
        #     is computed from the weights of the two halves.
        if self.use_multinomial_sampling:
            other_half_tree_prob = (other_half_tree.weight - tree_weight).exp()
        else:
            # For the special case that the weights of each half are both 0,
            #   we choose the proposal from the first half
            #   (any is fine, because the probability of picking it at the end is 0!).
            other_half_tree_prob = (other_half_tree.weight / tree_weight if tree_weight > 0
                                    else scalar_like(tree_weight, 0.))
        # NOTE(review): every recursive call reuses this site name. Presumably this is
        # fine because the kernel is not run under a trace handler. Confirm.
        is_other_half_tree = pyro.sample("is_other_half_tree",
                                         dist.Bernoulli(probs=other_half_tree_prob))

        if is_other_half_tree == 1:
            z_proposal = other_half_tree.z_proposal
            z_proposal_pe = other_half_tree.z_proposal_pe
            z_proposal_grads = other_half_tree.z_proposal_grads

        # leaves of the full tree are determined by the direction
        if direction == 1:
            z_left = half_tree.z_left
            r_left = half_tree.r_left
            z_left_grads = half_tree.z_left_grads
            z_right = other_half_tree.z_right
            r_right = other_half_tree.r_right
            z_right_grads = other_half_tree.z_right_grads
        else:
            z_left = other_half_tree.z_left
            r_left = other_half_tree.r_left
            z_left_grads = other_half_tree.z_left_grads
            z_right = half_tree.z_right
            r_right = half_tree.r_right
            z_right_grads = half_tree.z_right_grads

        # We already check if first half tree is turning. Now, we check
        #     if the other half tree or full tree are turning.
        turning = other_half_tree.turning or self._is_turning(r_left, r_right, r_sum)

        # The divergence is checked by the second half tree (the first half is already checked).
        diverging = other_half_tree.diverging

        return _TreeInfo(z_left, r_left, z_left_grads, z_right, r_right, z_right_grads, z_proposal,
                         z_proposal_pe, z_proposal_grads, r_sum, tree_weight, turning, diverging,
                         sum_accept_probs, num_proposals)