Example #1
from collections import defaultdict

import torch

from pyro import poutine
from pyro.infer.autoguide.initialization import InitMessenger, init_to_uniform
from pyro.ops.integrator import potential_grad


def _find_valid_initial_params(model, model_args, model_kwargs, transforms, potential_fn,
                               prototype_params, max_tries_initial_params=100, num_chains=1,
                               init_strategy=init_to_uniform, trace=None):
    params = prototype_params

    # For empty models, exit early
    if not params:
        return params

    params_per_chain = defaultdict(list)
    num_found = 0
    model = InitMessenger(init_strategy)(model)
    # Keep drawing candidate points until the potential energy and its
    # gradients are finite for every chain, within a bounded attempt budget.
    for attempt in range(num_chains * max_tries_initial_params):
        if trace is None:
            trace = poutine.trace(model).get_trace(*model_args, **model_kwargs)
        samples = {name: trace.nodes[name]["value"].detach() for name in params}
        # map constrained sample values into the unconstrained space
        params = {k: transforms[k](v) for k, v in samples.items()}
        pe_grad, pe = potential_grad(potential_fn, params)

        if torch.isfinite(pe) and all(map(torch.all, map(torch.isfinite, pe_grad.values()))):
            for k, v in params.items():
                params_per_chain[k].append(v)
            num_found += 1
            if num_found == num_chains:
                if num_chains == 1:
                    return {k: v[0] for k, v in params_per_chain.items()}
                else:
                    return {k: torch.stack(v) for k, v in params_per_chain.items()}
        trace = None  # discard this trace; a fresh one is drawn on the next attempt
    raise ValueError("Model specification seems incorrect - cannot find valid initial params.")
Example #2
    def sample(self, params):
        z, potential_energy, z_grads = self._fetch_from_cache()
        # recompute PE when cache is cleared
        if z is None:
            z = params
            z_grads, potential_energy = potential_grad(self.potential_fn, z)
            self._cache(z, potential_energy, z_grads)
        # return early if no sample sites
        elif len(z) == 0:
            self._t += 1
            self._mean_accept_prob = 1.
            if self._t > self._warmup_steps:
                self._accept_cnt += 1
            return params
        # sample a fresh momentum and compute the current Hamiltonian
        r, r_unscaled = self._sample_r(name="r_t={}".format(self._t))
        energy_current = self._kinetic_energy(r_unscaled) + potential_energy

        # Temporarily disable distributions args checking as
        # NaNs are expected during step size adaptation
        with optional(pyro.validation_enabled(False), self._t < self._warmup_steps):
            z_new, r_new, z_grads_new, potential_energy_new = velocity_verlet(
                z, r, self.potential_fn, self.mass_matrix_adapter.kinetic_grad,
                self.step_size, self.num_steps, z_grads=z_grads)
            # apply Metropolis correction.
            r_new_unscaled = self.mass_matrix_adapter.unscale(r_new)
            energy_proposal = self._kinetic_energy(r_new_unscaled) + potential_energy_new
        delta_energy = energy_proposal - energy_current
        # handle the NaN case which may be the case for a diverging trajectory
        # when using a large step size.
        delta_energy = scalar_like(delta_energy, float("inf")) if torch_isnan(delta_energy) else delta_energy
        if delta_energy > self._max_sliced_energy and self._t >= self._warmup_steps:
            self._divergences.append(self._t - self._warmup_steps)

        # Metropolis acceptance probability: min(1, exp(-delta_energy))
        accept_prob = (-delta_energy).exp().clamp(max=1.)
        rand = pyro.sample("rand_t={}".format(self._t), dist.Uniform(scalar_like(accept_prob, 0.),
                                                                     scalar_like(accept_prob, 1.)))
        accepted = False
        if rand < accept_prob:
            accepted = True
            z = z_new
            z_grads = z_grads_new
            self._cache(z, potential_energy_new, z_grads)

        self._t += 1
        if self._t > self._warmup_steps:
            n = self._t - self._warmup_steps
            if accepted:
                self._accept_cnt += 1
        else:
            n = self._t
            self._adapter.step(self._t, z, accept_prob, z_grads)

        # incremental update of the running mean acceptance probability
        self._mean_accept_prob += (accept_prob.item() - self._mean_accept_prob) / n
        return z.copy()
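
The velocity_verlet call above integrates Hamilton's equations for num_steps leapfrog steps. A single step of that integrator looks roughly like the sketch below, assuming an identity mass matrix (so the kinetic gradient is just r); the name leapfrog_step is illustrative, not Pyro's API, and it reuses the potential_grad_sketch helper from the note after Example #1.

def leapfrog_step(z, r, potential_fn, step_size, z_grads):
    # half step for momentum: r <- r - (eps / 2) * dU/dz
    r = {k: r[k] - 0.5 * step_size * z_grads[k] for k in r}
    # full step for position: z <- z + eps * r (identity mass matrix)
    z = {k: z[k] + step_size * r[k] for k in z}
    # refresh the potential energy and its gradients at the new position
    z_grads, potential_energy = potential_grad_sketch(potential_fn, z)
    # second half step for momentum
    r = {k: r[k] - 0.5 * step_size * z_grads[k] for k in r}
    return z, r, z_grads, potential_energy
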
Example #3
    def setup(self, warmup_steps, *args, **kwargs):
        self._warmup_steps = warmup_steps
        if self.model is not None:
            self._initialize_model_properties(args, kwargs)
        if self.initial_params:
            z = {k: v.detach() for k, v in self.initial_params.items()}
            z_grads, potential_energy = potential_grad(self.potential_fn, z)
        else:
            # empty model: no gradients, but still evaluate the potential energy
            z_grads, potential_energy = {}, self.potential_fn(self.initial_params)
        # seed the cache so the first sample() call can skip recomputation
        self._cache(self.initial_params, potential_energy, z_grads)
        if self.initial_params:
            self._initialize_adapter()
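
For context, setup is called by the MCMC driver rather than by user code: MCMC.run passes warmup_steps and the model arguments through to kernel.setup before sampling begins. A minimal end-to-end usage with Pyro's public API might look like this (the toy model is illustrative):

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS

def model(data):
    loc = pyro.sample("loc", dist.Normal(0., 10.))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, 1.), obs=data)

data = 3. + torch.randn(100)
kernel = NUTS(model)
mcmc = MCMC(kernel, num_samples=500, warmup_steps=500)
mcmc.run(data)
posterior_loc = mcmc.get_samples()["loc"]
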
Example #4
def _get_init_params(model,
                     model_args,
                     model_kwargs,
                     transforms,
                     potential_fn,
                     prototype_params,
                     max_tries_initial_params=100,
                     num_chains=1,
                     strategy="uniform"):
    params = prototype_params
    params_per_chain = defaultdict(list)
    n = 0

    # For empty models, exit early
    if not params:
        return params

    # Retry until the potential energy and its gradients are finite for every
    # chain. A single bounded loop replaces the original nested for/while, which
    # could spin forever on a bad model and made the final error unreachable.
    for _ in range(num_chains * max_tries_initial_params):
        if strategy == "uniform":
            # sample uniformly in a box around 0 in the unconstrained space
            params = {
                k: dist.Uniform(v.new_full(v.shape, -2),
                                v.new_full(v.shape, 2)).sample()
                for k, v in params.items()
            }
        elif strategy == "prior":
            trace = poutine.trace(model).get_trace(*model_args,
                                                   **model_kwargs)
            samples = {
                name: trace.nodes[name]["value"].detach()
                for name in params
            }
            params = {k: transforms[k](v) for k, v in samples.items()}
        pe_grad, pe = potential_grad(potential_fn, params)

        if torch.isfinite(pe) and all(
                map(torch.all, map(torch.isfinite, pe_grad.values()))):
            for k, v in params.items():
                params_per_chain[k].append(v)
            n += 1
            if n == num_chains:
                if num_chains == 1:
                    return {k: v[0] for k, v in params_per_chain.items()}
                return {k: torch.stack(v) for k, v in params_per_chain.items()}
    raise ValueError(
        "Model specification seems incorrect - cannot find valid initial params."
    )
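
The transforms dict used by the "prior" strategy maps each site's constrained sample into the unconstrained space where HMC operates; in Pyro these are typically the inverses of biject_to(support). A small illustration of that mapping:

import torch
from torch.distributions import biject_to, constraints

# constrained (positive reals) -> unconstrained (all reals)
transform = biject_to(constraints.positive).inv
v = torch.tensor(2.5)
u = transform(v)            # log(2.5): HMC can explore this space freely
v_back = transform.inv(u)   # exp(u) recovers the constrained value
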
Example #5
    def sample(self, params):
        z, potential_energy, z_grads = self._fetch_from_cache()
        # recompute PE when cache is cleared
        if z is None:
            z = params
            z_grads, potential_energy = potential_grad(self.potential_fn, z)
            self._cache(z, potential_energy, z_grads)
        # return early if no sample sites
        elif len(z) == 0:
            self._t += 1
            self._mean_accept_prob = 1.
            if self._t > self._warmup_steps:
                self._accept_cnt += 1
            return z
        r, r_unscaled = self._sample_r(name="r_t={}".format(self._t))
        energy_current = self._kinetic_energy(r_unscaled) + potential_energy

        # Ideally, along a symplectic integrator trajectory the energy is constant,
        # so the proposal could be drawn uniformly from the trajectory and no
        # "slice" would be needed. In practice the integration introduces numerical
        # error, so, as in [1], we introduce an auxiliary "slice" variable (denoted
        # by u).
        # The sampling process goes as follows:
        #   first, sample u from the initial state (z_0, r_0) according to
        #     u ~ Uniform(0, p(z_0, r_0)),
        #   then sample a state (z, r) from the integrator trajectory according to
        #     (z, r) ~ Uniform({(z', r') in trajectory | p(z', r') >= u}).
        #
        # For more information about the slice sampling method, see [3].
        # For a variant of NUTS that uses multinomial sampling instead of slice
        # sampling, see [2].

        if self.use_multinomial_sampling:
            log_slice = -energy_current
        else:
            # Rather than sampling the slice variable from `Uniform(0, exp(-energy))`, we can
            # sample log_slice directly using `energy`, so as to avoid potential underflow or
            # overflow issues ([2]).
            slice_exp_term = pyro.sample(
                "slicevar_exp_t={}".format(self._t),
                dist.Exponential(scalar_like(energy_current, 1.)))
            log_slice = -energy_current - slice_exp_term
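            # Note: if u = w * exp(-energy_current) with w ~ Uniform(0, 1), then
            # log(u) = -energy_current + log(w) and -log(w) ~ Exponential(1), so
            # this draws log_slice = log(u) without evaluating exp(-energy_current).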

        z_left = z_right = z
        r_left = r_right = r
        r_left_unscaled = r_right_unscaled = r_unscaled
        z_left_grads = z_right_grads = z_grads
        accepted = False
        r_sum = r_unscaled
        sum_accept_probs = 0.
        num_proposals = 0
        tree_weight = scalar_like(energy_current,
                                  0. if self.use_multinomial_sampling else 1.)

        # Temporarily disable distributions args checking as
        # NaNs are expected during step size adaptation.
        with optional(pyro.validation_enabled(False),
                      self._t < self._warmup_steps):
            # doubling process, stop when turning or diverging
            tree_depth = 0
            while tree_depth < self._max_tree_depth:
                direction = pyro.sample(
                    "direction_t={}_treedepth={}".format(self._t, tree_depth),
                    dist.Bernoulli(probs=scalar_like(tree_weight, 0.5)))
                direction = int(direction.item())
                if direction == 1:  # go to the right, start from the right leaf of current tree
                    new_tree = self._build_tree(z_right, r_right,
                                                z_right_grads, log_slice,
                                                direction, tree_depth,
                                                energy_current)
                    # update leaf for the next doubling process
                    z_right = new_tree.z_right
                    r_right = new_tree.r_right
                    r_right_unscaled = new_tree.r_right_unscaled
                    z_right_grads = new_tree.z_right_grads
                else:  # go to the left, start from the left leaf of current tree
                    new_tree = self._build_tree(z_left, r_left, z_left_grads,
                                                log_slice, direction,
                                                tree_depth, energy_current)
                    z_left = new_tree.z_left
                    r_left = new_tree.r_left
                    r_left_unscaled = new_tree.r_left_unscaled
                    z_left_grads = new_tree.z_left_grads

                sum_accept_probs = sum_accept_probs + new_tree.sum_accept_probs
                num_proposals = num_proposals + new_tree.num_proposals

                # stop doubling
                if new_tree.diverging:
                    if self._t >= self._warmup_steps:
                        self._divergences.append(self._t - self._warmup_steps)
                    break

                if new_tree.turning:
                    break

                tree_depth += 1

                if self.use_multinomial_sampling:
                    new_tree_prob = (new_tree.weight - tree_weight).exp()
                else:
                    new_tree_prob = new_tree.weight / tree_weight
                rand = pyro.sample(
                    "rand_t={}_treedepth={}".format(self._t, tree_depth),
                    dist.Uniform(scalar_like(new_tree_prob, 0.),
                                 scalar_like(new_tree_prob, 1.)))
                if rand < new_tree_prob:
                    accepted = True
                    z = new_tree.z_proposal
                    z_grads = new_tree.z_proposal_grads
                    self._cache(z, new_tree.z_proposal_pe, z_grads)

                r_sum = {
                    site_names: r_sum[site_names] + new_tree.r_sum[site_names]
                    for site_names in r_unscaled
                }
                if self._is_turning(r_left_unscaled, r_right_unscaled,
                                    r_sum):  # stop doubling
                    break
                else:  # update tree_weight
                    if self.use_multinomial_sampling:
                        tree_weight = _logaddexp(tree_weight, new_tree.weight)
                    else:
                        tree_weight = tree_weight + new_tree.weight

        accept_prob = sum_accept_probs / num_proposals

        self._t += 1
        if self._t > self._warmup_steps:
            n = self._t - self._warmup_steps
            if accepted:
                self._accept_cnt += 1
        else:
            n = self._t
            self._adapter.step(self._t, z, accept_prob, z_grads)
        self._mean_accept_prob += (accept_prob.item() -
                                   self._mean_accept_prob) / n

        return z.copy()
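
The _logaddexp helper used above accumulates log tree weights for multinomial sampling in a numerically stable way. A minimal version is sketched below (recent PyTorch versions also provide torch.logaddexp natively):

def _logaddexp(x, y):
    # log(exp(x) + exp(y)), factoring out the larger term to avoid overflow
    minval, maxval = (x, y) if x < y else (y, x)
    return (minval - maxval).exp().log1p() + maxval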