예제 #1
0
def test_simple(
    train_window, min_train_window, test_window, min_test_window, stride, warm_start
):
    """Run ``backtest`` on a small random series and sanity-check every window.

    Combining ``warm_start`` with a fixed ``train_window`` is expected to raise
    ``ValueError``. Otherwise some window must start at t0 == 0, and, when
    ``stride == 1``, some window must end at the full length. Every window must
    report non-negative wall times and finite, positive values for each metric
    in ``DEFAULT_METRICS``.
    """
    length = 30
    num_features = 2
    covariates = torch.zeros(length, 0)
    data = 4 + torch.randn(length, num_features)
    options = dict(num_steps=2, warm_start=warm_start)

    should_fail = warm_start and train_window is not None
    with optional(pytest.raises(ValueError), should_fail):
        windows = backtest(
            data,
            covariates,
            Model,
            train_window=train_window,
            min_train_window=min_train_window,
            test_window=test_window,
            min_test_window=min_test_window,
            stride=stride,
            forecaster_options=options,
        )
    if should_fail:
        return

    assert any(w["t0"] == 0 for w in windows)
    if stride == 1:
        assert any(w["t2"] == length for w in windows)
    for w in windows:
        assert w["train_walltime"] >= 0
        assert w["test_walltime"] >= 0
        for metric in DEFAULT_METRICS:
            assert metric in w
            assert 0 < w[metric] < math.inf
예제 #2
0
def test_num_chains(num_chains, cpu_count, default_init_params, monkeypatch):
    """Check which sampler MCMC selects for a given chain/CPU configuration.

    ``torch.multiprocessing.cpu_count`` is monkeypatched to ``cpu_count``. The
    test treats ``max(1, cpu_count - 1)`` CPUs as usable and expects a
    ``UserWarning`` when more chains than that are requested. It then expects the
    unary sampler for a single chain or an oversubscribed run, and the
    multi-process sampler otherwise.
    """
    monkeypatch.setattr(torch.multiprocessing, "cpu_count", lambda: cpu_count)
    data = torch.tensor([1.0])
    init_params, _, transforms, _ = initialize_model(
        normal_normal_model, model_args=(data,), num_chains=num_chains
    )
    if default_init_params:
        init_params = None
    kernel = PriorKernel(normal_normal_model)

    usable_cpus = max(1, cpu_count - 1)
    oversubscribed = num_chains > usable_cpus
    with optional(pytest.warns(UserWarning), oversubscribed):
        mcmc = MCMC(
            kernel,
            num_samples=10,
            warmup_steps=10,
            num_chains=num_chains,
            initial_params=init_params,
            transforms=transforms,
            mp_context="spawn",
        )
    mcmc.run(data)

    assert mcmc.num_chains == num_chains
    use_unary = mcmc.num_chains == 1 or oversubscribed
    expected_cls = _UnarySampler if use_unary else _MultiSampler
    assert isinstance(mcmc.sampler, expected_cls)
예제 #3
0
def test_predictive(num_samples, parallel):
    """Posterior-predictive draws from a NUTS fit of the beta-bernoulli model.

    A ``UserWarning`` is expected whenever ``num_samples`` is neither ``None``
    nor 100. In every case the returned ``beta``/``obs`` samples must have 100
    leading draws, and the mean of ``obs`` must be close to ``true_probs``.
    """
    model, data, true_probs = beta_bernoulli()
    init_params, potential_fn, transforms, _ = initialize_model(model, model_args=(data,))
    kernel = NUTS(potential_fn=potential_fn, transforms=transforms)
    mcmc = MCMC(kernel, 100, initial_params=init_params, warmup_steps=100)
    mcmc.run(data)
    posterior = mcmc.get_samples()

    expect_warning = num_samples not in (None, 100)
    with ignore_experimental_warning(), optional(pytest.warns(UserWarning), expect_warning):
        preds = predictive(
            model,
            posterior,
            num_samples=num_samples,
            return_sites=["beta", "obs"],
            parallel=parallel,
        )

    # check shapes
    assert preds["beta"].shape == (100, 5)
    assert preds["obs"].shape == (100, 1000, 5)

    # check sample mean
    obs_flat = preds["obs"].reshape([-1, 5])
    assert_close(obs_flat.mean(0), true_probs, rtol=0.1)
예제 #4
0
    def _potential_energy_jit(self, z):
        """Evaluate the potential energy at ``z`` using a cached ``torch.jit`` trace.

        On the first call, a closure that maps the site values (sorted by name) to the
        potential energy is traced with ``torch.jit.trace`` and stored in
        ``self._compiled_potential_fn``. Later calls reuse that trace.

        :param dict z: mapping from site name to unconstrained value.
        :returns: ``-log p(constrained values)`` plus, for every transformed site, the
            summed ``log_abs_det_jacobian`` of its transform.
        """
        # Sort by site name so the traced function's positional inputs have a stable order.
        names, vals = zip(*sorted(z.items()))
        if self._compiled_potential_fn:
            # NOTE(review): the cached trace assumes the same site names/order as the call
            # that created it; this is not re-checked here.
            return self._compiled_potential_fn(*vals)

        def compiled(*zi):
            # `zi` holds the unconstrained values in the order of `names`.
            z_constrained = list(zi)
            # transform to constrained space.
            for i, name in enumerate(names):
                if name in self.transforms:
                    transform = self.transforms[name]
                    z_constrained[i] = transform.inv(z_constrained[i])
            z_constrained = dict(zip(names, z_constrained))
            trace = self._get_trace(z_constrained)
            potential_energy = -self._compute_trace_log_prob(trace)
            # adjust by the jacobian for this transformation
            # (log|det d(unconstrained)/d(constrained)|, summed over elements).
            for i, name in enumerate(names):
                if name in self.transforms:
                    transform = self.transforms[name]
                    potential_energy += transform.log_abs_det_jacobian(
                        z_constrained[name], zi[i]).sum()
            return potential_energy

        # Trace with Pyro validation disabled, optionally silencing JIT warnings;
        # check_trace=False skips torch's re-run consistency check of the trace.
        with pyro.validation_enabled(False), optional(
                ignore_jit_warnings(), self._ignore_jit_warnings):
            self._compiled_potential_fn = torch.jit.trace(compiled,
                                                          vals,
                                                          check_trace=False)
        return self._compiled_potential_fn(*vals)
예제 #5
0
파일: jit.py 프로젝트: pyro-ppl/pyro
    def __call__(self, *args, **kwargs):
        """Call ``self.fn`` through a ``torch.jit`` trace compiled once per argument key.

        On the first call for a given hashable ``(args, kwargs)`` key, ``self.fn`` is run
        once under ``poutine.trace(param_only=True)`` to discover which Pyro params it
        reads. The function is then traced with ``torch.jit.trace`` over the params'
        unconstrained tensors followed by ``args``, and stored in ``self.compiled[key]``.
        Later calls with the same key reuse the stored trace.

        If ``self.jit_options`` contains ``time_compilation=True``, each compilation is
        timed and its duration stored in ``self.compile_time``.

        :raises NotImplementedError: if a call records a param that was not seen on the
            first (tracing) invocation.
        """
        key = _hashable_args_kwargs(args, kwargs)

        # if first time
        if key not in self.compiled:
            # param capture: run fn once, with its effects blocked, to record which
            # Pyro params it reads.
            with poutine.block():
                with poutine.trace(param_only=True) as first_param_capture:
                    self.fn(*args, **kwargs)

            # NOTE(review): `_param_names` is shared by every compiled key and rebuilt
            # from an unordered set on each new compilation; traces compiled for earlier
            # keys rely on it keeping the same names/order -- confirm.
            self._param_names = list(
                set(first_param_capture.trace.nodes.keys()))
            unconstrained_params = tuple(
                pyro.param(name).unconstrained() for name in self._param_names)
            params_and_args = unconstrained_params + args
            # Presumably a weak reference so the stored trace's closure does not form a
            # strong cycle self -> self.compiled -> closure -> self.
            weakself = weakref.ref(self)

            def compiled(*params_and_args):
                self = weakself()
                # Leading inputs are the unconstrained params (in `_param_names` order);
                # the remainder are the original positional args.
                unconstrained_params = params_and_args[:len(self._param_names)]
                args = params_and_args[len(self._param_names):]
                constrained_params = {}
                for name, unconstrained_param in zip(self._param_names,
                                                     unconstrained_params):
                    constrained_param = pyro.param(
                        name)  # assume param has been initialized
                    assert constrained_param.unconstrained(
                    ) is unconstrained_param
                    constrained_params[name] = constrained_param
                return poutine.replay(self.fn,
                                      params=constrained_params)(*args,
                                                                 **kwargs)

            if self.ignore_warnings:
                compiled = ignore_jit_warnings()(compiled)
            # `time_compilation` is not a torch.jit.trace option, so it is removed before
            # forwarding. Pop it from a copy: popping from `self.jit_options` itself would
            # drop the flag after the first compilation, and later compilations (for new
            # argument keys) would silently stop being timed.
            jit_options = dict(self.jit_options)
            time_compilation = jit_options.pop("time_compilation", False)
            with pyro.validation_enabled(False):
                with optional(timed(), time_compilation) as t:
                    self.compiled[key] = torch.jit.trace(
                        compiled, params_and_args, **jit_options)
                if time_compilation:
                    self.compile_time = t.elapsed
        else:
            unconstrained_params = [
                pyro.param(name).unconstrained() for name in self._param_names
            ]
            params_and_args = unconstrained_params + list(args)

        with poutine.block(hide=self._param_names):
            with poutine.trace(param_only=True) as param_capture:
                ret = self.compiled[key](*params_and_args)

        # Params must all exist at tracing time; a newly recorded one is unsupported.
        for name in param_capture.trace.nodes.keys():
            if name not in self._param_names:
                raise NotImplementedError(
                    "pyro.ops.jit.trace assumes all params are created on "
                    "first invocation, but found new param: {}".format(name))

        return ret
예제 #6
0
def test_num_chains(num_chains, cpu_count, monkeypatch):
    """MCMC should cap ``num_chains`` at the number of usable CPUs.

    ``torch.multiprocessing.cpu_count`` is monkeypatched to ``cpu_count``. The
    test treats ``max(1, cpu_count - 1)`` CPUs as usable and expects a
    ``UserWarning`` when more chains are requested. It expects the single sampler
    for one effective chain and the parallel sampler otherwise.
    """
    monkeypatch.setattr(torch.multiprocessing, 'cpu_count', lambda: cpu_count)
    kernel = PriorKernel(normal_normal_model)
    usable_cpus = max(1, cpu_count - 1)

    with optional(pytest.warns(UserWarning), num_chains > usable_cpus):
        mcmc = MCMC(kernel, num_samples=10, num_chains=num_chains)

    assert mcmc.num_chains == min(num_chains, usable_cpus)
    expected_cls = _SingleSampler if mcmc.num_chains == 1 else _ParallelSampler
    assert isinstance(mcmc.sampler, expected_cls)
예제 #7
0
    def sample(self, params):
        """Perform one HMC (Metropolis-corrected leapfrog) transition.

        The starting state comes from the cache. If the cache is empty, the potential
        energy and gradients at ``params`` are computed with ``potential_grad`` and cached.
        If the cached latent dict is empty (no sample sites), the step counts as accepted
        and ``params`` is returned immediately. Otherwise a momentum is sampled,
        ``velocity_verlet`` integrates ``self.num_steps`` steps of size ``self.step_size``,
        and the proposal is accepted with probability ``min(1, exp(-delta_energy))``.
        During warmup the adapter is stepped; after warmup, accepted proposals are counted
        and divergences are recorded.

        :param dict params: current latent values, used only when the cache is empty.
        :returns: a shallow copy of the resulting latent dict (``params`` itself when there
            are no sample sites).
        """
        z, potential_energy, z_grads = self._fetch_from_cache()
        # recompute PE when cache is cleared
        if z is None:
            z = params
            z_grads, potential_energy = potential_grad(self.potential_fn, z)
            self._cache(z, potential_energy, z_grads)
        # return early if no sample sites
        elif len(z) == 0:
            self._t += 1
            self._mean_accept_prob = 1.
            if self._t > self._warmup_steps:
                self._accept_cnt += 1
            return params
        # The step name includes the iteration index so each momentum draw is a distinct site.
        r, r_unscaled = self._sample_r(name="r_t={}".format(self._t))
        energy_current = self._kinetic_energy(r_unscaled) + potential_energy

        # Temporarily disable distributions args checking as
        # NaNs are expected during step size adaptation
        with optional(pyro.validation_enabled(False), self._t < self._warmup_steps):
            z_new, r_new, z_grads_new, potential_energy_new = velocity_verlet(
                z, r, self.potential_fn, self.mass_matrix_adapter.kinetic_grad,
                self.step_size, self.num_steps, z_grads=z_grads)
            # apply Metropolis correction.
            r_new_unscaled = self.mass_matrix_adapter.unscale(r_new)
            energy_proposal = self._kinetic_energy(r_new_unscaled) + potential_energy_new
        delta_energy = energy_proposal - energy_current
        # handle the NaN case which may be the case for a diverging trajectory
        # when using a large step size: NaN becomes +inf, so accept_prob below is 0.
        delta_energy = scalar_like(delta_energy, float("inf")) if torch_isnan(delta_energy) else delta_energy
        # Only post-warmup divergences are recorded, indexed from the end of warmup.
        if delta_energy > self._max_sliced_energy and self._t >= self._warmup_steps:
            self._divergences.append(self._t - self._warmup_steps)

        accept_prob = (-delta_energy).exp().clamp(max=1.)
        rand = pyro.sample("rand_t={}".format(self._t), dist.Uniform(scalar_like(accept_prob, 0.),
                                                                     scalar_like(accept_prob, 1.)))
        accepted = False
        if rand < accept_prob:
            accepted = True
            z = z_new
            z_grads = z_grads_new
            self._cache(z, potential_energy_new, z_grads)

        self._t += 1
        # `n` counts steps within the current phase (warmup or sampling), so the running
        # mean of the acceptance probability restarts when sampling begins.
        if self._t > self._warmup_steps:
            n = self._t - self._warmup_steps
            if accepted:
                self._accept_cnt += 1
        else:
            n = self._t
            self._adapter.step(self._t, z, accept_prob, z_grads)

        self._mean_accept_prob += (accept_prob.item() - self._mean_accept_prob) / n
        return z.copy()
예제 #8
0
파일: api.py 프로젝트: pyro-ppl/pyro
    def run(self, *args, **kwargs):
        """
        Run StreamingMCMC to compute required `self._statistics`.

        Samples are streamed from ``self.sampler`` and fed into ``self._statistics`` one at
        a time as ``(chain_id, site_name) -> value`` entries, instead of being accumulated.
        Per chain, the first message is the latent structure, the next ``self.num_samples``
        messages are flattened samples, and the message after those is stored as that
        chain's diagnostics.
        """
        self._args, self._kwargs = args, kwargs
        # Per-chain message counter (1 after the structure message, +1 per sample).
        num_samples = [0] * self.num_chains

        with optional(
                pyro.validation_enabled(not self.disable_validation),
                self.disable_validation is not None,
        ):
            # Tensor args are detached before being handed to the sampler.
            args = [
                arg.detach() if torch.is_tensor(arg) else arg for arg in args
            ]
            for x, chain_id in self.sampler.run(*args, **kwargs):
                if num_samples[chain_id] == 0:
                    # If transforms is not explicitly provided, infer automatically using
                    # model args, kwargs.
                    if self.transforms is None:
                        self._set_transforms(*args, **kwargs)
                    num_samples[chain_id] += 1
                    # First message: mapping site name -> shape (presumably torch.Size,
                    # given the .numel()/.reshape usage below).
                    z_structure = x
                elif num_samples[chain_id] == self.num_samples + 1:
                    self._diagnostics[chain_id] = x
                else:
                    num_samples[chain_id] += 1
                    # With several chains, the sample is cloned and the original reference
                    # dropped; presumably the received tensor may share memory with a worker.
                    if self.num_chains > 1:
                        x_cloned = x.clone()
                        del x
                    else:
                        x_cloned = x

                    # unpack latent: the flat vector is laid out by sorted site name.
                    pos = 0
                    z_acc = z_structure.copy()
                    for k in sorted(z_structure):
                        shape = z_structure[k]
                        next_pos = pos + shape.numel()
                        z_acc[k] = x_cloned[pos:next_pos].reshape(shape)
                        pos = next_pos

                    # Map values back to constrained space where a transform exists.
                    for name, z in z_acc.items():
                        if name in self.transforms:
                            z_acc[name] = self.transforms[name].inv(z)

                    self._statistics.update({
                        (chain_id, name): transformed_sample
                        for name, transformed_sample in z_acc.items()
                    })

        # terminate the sampler (shut down worker processes)
        self.sampler.terminate(True)
예제 #9
0
    def sample(self, trace):
        """Take one HMC step from the latent values recorded in ``trace``.

        Latent values are read from ``trace``, moved to unconstrained space with
        ``self.transforms``, and integrated with ``velocity_verlet`` for ``self.num_steps``
        steps. The proposal is accepted with probability ``min(1, exp(-delta_energy))``,
        or 0 when ``delta_energy`` is NaN. During warmup the adapter is stepped.

        :param trace: trace whose latent nodes hold the current (constrained) values.
        :returns: a new trace built from the constrained values of the resulting state.
        """
        z = {
            name: node["value"].detach()
            for name, node in self._iter_latent_nodes(trace)
        }
        # automatically transform `z` to unconstrained space, if needed.
        for name, transform in self.transforms.items():
            z[name] = transform(z[name])

        r, _ = self._sample_r(name="r_t={}".format(self._t))

        # NOTE(review): the cache is not updated with potential_energy_new/z_grads_new in
        # this method when a proposal is accepted; confirm it is refreshed elsewhere.
        potential_energy, z_grads = self._fetch_from_cache()
        # Temporarily disable distributions args checking as
        # NaNs are expected during step size adaptation
        with optional(pyro.validation_enabled(False),
                      self._t < self._warmup_steps):
            z_new, r_new, z_grads_new, potential_energy_new = velocity_verlet(
                z,
                r,
                self._potential_energy,
                self.inverse_mass_matrix,
                self.step_size,
                self.num_steps,
                z_grads=z_grads)
            # apply Metropolis correction.
            energy_proposal = self._kinetic_energy(
                r_new) + potential_energy_new
            # Use the cached potential energy when available; otherwise recompute the
            # full energy at the current state.
            energy_current = self._kinetic_energy(r) + potential_energy if potential_energy is not None \
                else self._energy(z, r)
        delta_energy = energy_proposal - energy_current
        # Set accept prob to 0.0 if delta_energy is `NaN` which may be
        # the case for a diverging trajectory when using a large step size.
        if torch_isnan(delta_energy):
            accept_prob = delta_energy.new_tensor(0.0)
        else:
            accept_prob = (-delta_energy).exp().clamp(max=1.)
        rand = pyro.sample("rand_t={}".format(self._t),
                           dist.Uniform(torch.zeros(1), torch.ones(1)))
        if rand < accept_prob:
            self._accept_cnt += 1
            z = z_new

        if self._t < self._warmup_steps:
            self._adapter.step(self._t, z, accept_prob)

        self._t += 1

        # get trace with the constrained values for `z`.
        for name, transform in self.transforms.items():
            z[name] = transform.inv(z[name])
        return self._get_trace(z)
예제 #10
0
def test_welford_dense(n_samples, dim_size):
    """Dense Welford covariance must match ``np.cov`` (unbiased).

    With a single sample, ``get_covariance`` is expected to raise ``RuntimeError``.
    """
    estimator = WelfordCovariance(diagonal=False)
    factor = torch.randn(dim_size, dim_size)
    covariance = factor @ factor.t()
    mvn = torch.distributions.MultivariateNormal(
        loc=torch.zeros(dim_size), covariance_matrix=covariance
    )
    draws = mvn.sample(torch.Size([n_samples]))
    for draw in draws:
        estimator.update(draw)

    with optional(pytest.raises(RuntimeError), n_samples == 1):
        estimated = estimator.get_covariance(regularize=False).cpu().numpy()
        expected = np.cov(draws.cpu().numpy(), bias=False, rowvar=False)
        assert_equal(estimated, expected)
예제 #11
0
def test_ubersum_collide_implemented(impl, implemented):
    """Non-tree plate structures must be rejected when not implemented.

    Such structures cause exponential blowup, so ``ubersum()`` refuses them.
    Layout: ``z`` lives in plates {a, b}; ``x`` in {a}; ``y`` in {b}; and the
    target is the empty plate set::

          z {a,b}
            /   \\
        x {a}  y {b}
             \\  /
              {}  <--- target
    """
    expected_error = pytest.raises(
        NotImplementedError, match='Expected tree-structured plate nesting'
    )
    size_a, size_b, size_c, size_d = 2, 3, 4, 5
    x = torch.randn(size_a, size_c)
    y = torch.randn(size_b, size_d)
    z = torch.randn(size_a, size_b, size_c, size_d)
    with optional(expected_error, not implemented):
        impl('ac,bd,abcd->', x, y, z, plates='ab', modulo_total=True)
예제 #12
0
def test_welford_diagonal(n_samples, dim_size):
    """Diagonal Welford covariance must match the unbiased per-dimension variance.

    With a single sample, ``get_covariance`` is expected to raise ``RuntimeError``.
    """
    estimator = WelfordCovariance(diagonal=True)
    variances = torch.rand(dim_size)
    mvn = torch.distributions.MultivariateNormal(
        loc=torch.zeros(dim_size), covariance_matrix=torch.diag(variances)
    )
    draws = []
    for _ in range(n_samples):
        draw = mvn.sample()
        draws.append(draw)
        estimator.update(draw)

    expected_variance = torch.stack(draws).var(dim=0, unbiased=True)
    with optional(pytest.raises(RuntimeError), n_samples == 1):
        assert_equal(estimator.get_covariance(regularize=False), expected_variance)
예제 #13
0
파일: mcmc.py 프로젝트: zyxue/pyro
 def _traces(self, *args, **kwargs):
     """Yield ``(trace, 1.0)`` for each post-warmup sample of one chain.

     ``logger_id`` and ``log_queue`` are popped from ``kwargs``; the remaining arguments
     are forwarded to ``self.kernel.setup``. A non-``None`` ``log_queue`` marks a
     multiprocessing run, in which case no progress bar is created here.

     ``self.warmup_steps`` warmup traces are generated and discarded. Then
     ``self.num_samples`` traces are yielded, each with weight ``1.0``.
     ``self.kernel.cleanup()`` runs after the last sample has been yielded.
     """
     logger_id = kwargs.pop("logger_id", "")
     log_queue = kwargs.pop("log_queue", None)
     self.logger = logging.getLogger("pyro.infer.mcmc")
     is_multiprocessing = log_queue is not None
     progress_bar = None
     if not is_multiprocessing:
         progress_bar = initialize_progbar(self.warmup_steps,
                                           self.num_samples,
                                           disable=self.disable_progbar)
     self.logger = initialize_logger(self.logger, logger_id, progress_bar,
                                     log_queue)
     self.kernel.setup(self.warmup_steps, *args, **kwargs)
     trace = self.kernel.initial_trace
     # The progress bar is only used as a context manager when one was created.
     with optional(progress_bar, not is_multiprocessing):
         # Warmup: discard the traces, keeping only the last one as the starting point.
         for trace in self._gen_samples(self.warmup_steps, trace):
             continue
         # Test for existence explicitly. A tqdm-style bar defines __bool__ from its
         # `total` (False when total == 0, TypeError when total is None), so
         # truthiness is not a reliable "a bar was created" check.
         if progress_bar is not None:
             progress_bar.set_description("Sample")
         for trace in self._gen_samples(self.num_samples, trace):
             yield (trace, 1.0)
     self.kernel.cleanup()
예제 #14
0
파일: api.py 프로젝트: nwjnwj/pyro
    def run(self, *args, **kwargs):
        """
        Run MCMC to generate samples and populate `self._samples`.

        Example usage:

        .. code-block:: python

            def model(data):
                ...

            nuts_kernel = NUTS(model)
            mcmc = MCMC(nuts_kernel, num_samples=500)
            mcmc.run(data)
            samples = mcmc.get_samples()

        :param args: optional arguments taken by
            :meth:`MCMCKernel.setup <pyro.infer.mcmc.mcmc_kernel.MCMCKernel.setup>`.
        :param kwargs: optional keywords arguments taken by
            :meth:`MCMCKernel.setup <pyro.infer.mcmc.mcmc_kernel.MCMCKernel.setup>`.
        """
        self._args, self._kwargs = args, kwargs
        # Per-chain message counter: 1 after the structure message, +1 per sample.
        num_samples = [0] * self.num_chains
        # Per-chain list of flattened samples.
        z_flat_acc = [[] for _ in range(self.num_chains)]
        with optional(pyro.validation_enabled(not self.disable_validation),
                      self.disable_validation is not None):
            for x, chain_id in self.sampler.run(*args, **kwargs):
                if num_samples[chain_id] == 0:
                    num_samples[chain_id] += 1
                    # First message per chain: mapping site name -> shape.
                    z_structure = x
                elif num_samples[chain_id] == self.num_samples + 1:
                    # Message after the last sample: this chain's diagnostics.
                    self._diagnostics[chain_id] = x
                else:
                    num_samples[chain_id] += 1
                    # With several chains, the sample is cloned and the original reference
                    # dropped; presumably the received tensor may share memory with a worker.
                    if self.num_chains > 1:
                        x_cloned = x.clone()
                        del x
                    else:
                        x_cloned = x
                    z_flat_acc[chain_id].append(x_cloned)

        # Stack into (num_chains, num_samples, flat_dim).
        # NOTE(review): torch.stack requires every chain to have produced the same,
        # non-zero number of samples; `z_structure` is unbound if nothing was received.
        z_flat_acc = torch.stack([torch.stack(l) for l in z_flat_acc])

        # unpack latent: the flat dimension is laid out by sorted site name.
        pos = 0
        z_acc = z_structure.copy()
        for k in sorted(z_structure):
            shape = z_structure[k]
            next_pos = pos + shape.numel()
            z_acc[k] = z_flat_acc[:, :,
                                  pos:next_pos].reshape((self.num_chains,
                                                         self.num_samples) +
                                                        shape)
            pos = next_pos
        assert pos == z_flat_acc.shape[-1]

        # If transforms is not explicitly provided, infer automatically using
        # model args, kwargs.
        if self.transforms is None:
            # Use `kernel.transforms` when available
            if getattr(self.kernel, "transforms", None) is not None:
                self.transforms = self.kernel.transforms
            # Else, get transforms from model (e.g. in multiprocessing).
            elif self.kernel.model:
                warmup_steps = 0
                self.kernel.setup(warmup_steps, *args, **kwargs)
                self.transforms = self.kernel.transforms
            # Assign default value
            else:
                self.transforms = {}

        # transform samples back to constrained space
        for name, transform in self.transforms.items():
            z_acc[name] = transform.inv(z_acc[name])
        self._samples = z_acc

        # terminate the sampler (shut down worker processes)
        self.sampler.terminate(True)
예제 #15
0
    def sample(self, params):
        """Perform one NUTS transition.

        If the cache is empty, the potential energy at ``params`` is computed and cached;
        no gradients are computed at this point. If the cached latent dict is empty (no
        sample sites), the step counts as accepted and that empty dict is returned.

        Otherwise the trajectory is repeatedly doubled in a randomly chosen direction.
        Doubling stops when a new sub-tree diverges or turns, when the whole trajectory
        turns, or when ``self._max_tree_depth`` doublings have been made. Each new
        sub-tree's proposal replaces the current state with a probability given by the
        ratio of tree weights (slice sampling, or multinomial sampling when
        ``self.use_multinomial_sampling`` is set).

        :param dict params: current latent values, used only when the cache is empty.
        :returns: a shallow copy of the resulting latent dict (the cached empty dict when
            there are no sample sites).
        """
        z, potential_energy, z_grads = self._fetch_from_cache()
        # recompute PE when cache is cleared
        if z is None:
            z = params
            potential_energy = self.potential_fn(z)
            self._cache(z, potential_energy)
        # return early if no sample sites
        elif len(z) == 0:
            self._t += 1
            self._mean_accept_prob = 1.
            if self._t > self._warmup_steps:
                self._accept_cnt += 1
            return z
        r, r_flat = self._sample_r(name="r_t={}".format(self._t))
        energy_current = self._kinetic_energy(r) + potential_energy

        # Ideally, following a symplectic integrator trajectory, the energy is constant.
        # In that case, we can sample the proposal uniformly, and there is no need to use "slice".
        # However, it is not the case for real situation: there are errors during the computation.
        # To deal with that problem, as in [1], we introduce an auxiliary "slice" variable (denoted
        # by u).
        # The sampling process goes as follows:
        #   first sampling u from initial state (z_0, r_0) according to
        #     u ~ Uniform(0, p(z_0, r_0)),
        #   then sampling state (z, r) from the integrator trajectory according to
        #     (z, r) ~ Uniform({(z', r') in trajectory | p(z', r') >= u}).
        #
        # For more information about slice sampling method, see [3].
        # For another version of NUTS which uses multinomial sampling instead of slice sampling,
        # see [2].

        if self.use_multinomial_sampling:
            log_slice = -energy_current
        else:
            # Rather than sampling the slice variable from `Uniform(0, exp(-energy))`, we can
            # sample log_slice directly using `energy`, so as to avoid potential underflow or
            # overflow issues ([2]).
            slice_exp_term = pyro.sample("slicevar_exp_t={}".format(self._t),
                                         dist.Exponential(scalar_like(energy_current, 1.)))
            log_slice = -energy_current - slice_exp_term

        # Both ends of the trajectory start at the current state.
        z_left = z_right = z
        r_left = r_right = r
        z_left_grads = z_right_grads = z_grads
        accepted = False
        # Running sum of (flattened) momenta over the whole trajectory, for the U-turn check.
        r_sum = r_flat
        sum_accept_probs = 0.
        num_proposals = 0
        # Weight of the current tree: a log-weight (starting at 0) for multinomial sampling,
        # a plain weight (starting at 1) for slice sampling.
        tree_weight = scalar_like(energy_current, 0. if self.use_multinomial_sampling else 1.)

        # Temporarily disable distributions args checking as
        # NaNs are expected during step size adaptation.
        with optional(pyro.validation_enabled(False), self._t < self._warmup_steps):
            # doubling process, stop when turning or diverging
            tree_depth = 0
            while tree_depth < self._max_tree_depth:
                direction = pyro.sample("direction_t={}_treedepth={}".format(self._t, tree_depth),
                                        dist.Bernoulli(probs=scalar_like(tree_weight, 0.5)))
                direction = int(direction.item())
                if direction == 1:  # go to the right, start from the right leaf of current tree
                    new_tree = self._build_tree(z_right, r_right, z_right_grads, log_slice,
                                                direction, tree_depth, energy_current)
                    # update leaf for the next doubling process
                    z_right = new_tree.z_right
                    r_right = new_tree.r_right
                    z_right_grads = new_tree.z_right_grads
                else:  # go to the left, start from the left leaf of current tree
                    new_tree = self._build_tree(z_left, r_left, z_left_grads, log_slice,
                                                direction, tree_depth, energy_current)
                    z_left = new_tree.z_left
                    r_left = new_tree.r_left
                    z_left_grads = new_tree.z_left_grads

                sum_accept_probs = sum_accept_probs + new_tree.sum_accept_probs
                num_proposals = num_proposals + new_tree.num_proposals

                # stop doubling
                if new_tree.diverging:
                    # Only post-warmup divergences are recorded, indexed from the end of warmup.
                    if self._t >= self._warmup_steps:
                        self._divergences.append(self._t - self._warmup_steps)
                    break

                if new_tree.turning:
                    break

                tree_depth += 1

                # Probability of jumping to the new sub-tree's proposal: ratio of the new
                # sub-tree's weight to the current tree's weight (in log space for
                # multinomial sampling). Values >= 1 always accept.
                if self.use_multinomial_sampling:
                    new_tree_prob = (new_tree.weight - tree_weight).exp()
                else:
                    new_tree_prob = new_tree.weight / tree_weight
                rand = pyro.sample("rand_t={}_treedepth={}".format(self._t, tree_depth),
                                   dist.Uniform(scalar_like(new_tree_prob, 0.),
                                                scalar_like(new_tree_prob, 1.)))
                if rand < new_tree_prob:
                    accepted = True
                    z = new_tree.z_proposal
                    self._cache(z, new_tree.z_proposal_pe, new_tree.z_proposal_grads)

                # Check the U-turn criterion on the whole (doubled) trajectory.
                r_sum = r_sum + new_tree.r_sum
                if self._is_turning(r_left, r_right, r_sum):  # stop doubling
                    break
                else:  # update tree_weight
                    if self.use_multinomial_sampling:
                        tree_weight = _logaddexp(tree_weight, new_tree.weight)
                    else:
                        tree_weight = tree_weight + new_tree.weight

        # Average acceptance probability over all proposals built in this trajectory.
        accept_prob = sum_accept_probs / num_proposals

        self._t += 1
        # `n` counts steps within the current phase (warmup or sampling), so the running
        # mean of the acceptance probability restarts when sampling begins.
        if self._t > self._warmup_steps:
            n = self._t - self._warmup_steps
            if accepted:
                self._accept_cnt += 1
        else:
            n = self._t
            self._adapter.step(self._t, z, accept_prob)
        self._mean_accept_prob += (accept_prob.item() - self._mean_accept_prob) / n

        return z.copy()
예제 #16
0
    def run(self, *args, **kwargs):
        """
        Run MCMC to generate samples and populate `self._samples`.

        Example usage:

        .. code-block:: python

            def model(data):
                ...

            nuts_kernel = NUTS(model)
            mcmc = MCMC(nuts_kernel, num_samples=500)
            mcmc.run(data)
            samples = mcmc.get_samples()

        :param args: optional arguments taken by
            :meth:`MCMCKernel.setup <pyro.infer.mcmc.mcmc_kernel.MCMCKernel.setup>`.
        :param kwargs: optional keywords arguments taken by
            :meth:`MCMCKernel.setup <pyro.infer.mcmc.mcmc_kernel.MCMCKernel.setup>`.
        """
        self._args, self._kwargs = args, kwargs
        # Per-chain message counter: 1 after the structure message, +1 per sample.
        num_samples = [0] * self.num_chains
        # Per-chain list of flattened samples.
        z_flat_acc = [[] for _ in range(self.num_chains)]
        with optional(pyro.validation_enabled(not self.disable_validation),
                      self.disable_validation is not None):
            # XXX we clone CUDA tensor args to resolve the issue "Invalid device pointer"
            # at https://github.com/pytorch/pytorch/issues/10375
            # This also resolves "RuntimeError: Cowardly refusing to serialize non-leaf tensor which
            # requires_grad", which happens with `jit_compile` under PyTorch 1.7
            args = [
                arg.detach() if torch.is_tensor(arg) else arg for arg in args
            ]
            for x, chain_id in self.sampler.run(*args, **kwargs):
                if num_samples[chain_id] == 0:
                    num_samples[chain_id] += 1
                    # First message per chain: mapping site name -> shape.
                    z_structure = x
                elif num_samples[chain_id] == self.num_samples + 1:
                    # Message after the last sample: this chain's diagnostics.
                    self._diagnostics[chain_id] = x
                else:
                    num_samples[chain_id] += 1
                    # With several chains, the sample is cloned and the original reference
                    # dropped; presumably the received tensor may share memory with a worker.
                    if self.num_chains > 1:
                        x_cloned = x.clone()
                        del x
                    else:
                        x_cloned = x
                    z_flat_acc[chain_id].append(x_cloned)

        # Stack into (num_chains, num_samples, flat_dim).
        # NOTE(review): torch.stack requires every chain to have produced the same,
        # non-zero number of samples; `z_structure` is unbound if nothing was received.
        z_flat_acc = torch.stack([torch.stack(l) for l in z_flat_acc])

        # unpack latent: the flat dimension is laid out by sorted site name.
        pos = 0
        z_acc = z_structure.copy()
        for k in sorted(z_structure):
            shape = z_structure[k]
            next_pos = pos + shape.numel()
            z_acc[k] = z_flat_acc[:, :,
                                  pos:next_pos].reshape((self.num_chains,
                                                         self.num_samples) +
                                                        shape)
            pos = next_pos
        assert pos == z_flat_acc.shape[-1]

        # If transforms is not explicitly provided, infer automatically using
        # model args, kwargs.
        if self.transforms is None:
            # Use `kernel.transforms` when available
            if getattr(self.kernel, "transforms", None) is not None:
                self.transforms = self.kernel.transforms
            # Else, get transforms from model (e.g. in multiprocessing).
            elif self.kernel.model:
                warmup_steps = 0
                self.kernel.setup(warmup_steps, *args, **kwargs)
                self.transforms = self.kernel.transforms
            # Assign default value
            else:
                self.transforms = {}

        # transform samples back to constrained space
        for name, transform in self.transforms.items():
            z_acc[name] = transform.inv(z_acc[name])
        self._samples = z_acc

        # terminate the sampler (shut down worker processes)
        self.sampler.terminate(True)
예제 #17
0
    def sample(self, trace):
        """
        Run one batched HMC transition and return a trace of the new state.

        Every chain in the batch integrates ``self.num_steps`` leapfrog steps
        from the cached state, and is then independently accepted or rejected
        by a Metropolis correction. The per-chain result (position, potential
        energy and gradients) is written back to the cache, so the next call
        resumes from it.

        :param trace: not read here; the current state comes from the cache.
        :returns: trace whose latent sites hold the constrained values of the
            new state.
        """
        z, potential_energy, z_grads = self._fetch_from_cache()
        r, _ = self._sample_r(name="r_t={}".format(self._t))
        energy_current = self._kinetic_energy(r) + potential_energy

        # Temporarily disable distributions args checking as
        # NaNs are expected during step size adaptation
        with optional(pyro.validation_enabled(False), self._t < self._warmup_steps):
            z_new, r_new, z_grads_new, potential_energy_new = velocity_verlet(
                z, r, self._potential_energy,
                self.inverse_mass_matrix,
                self.step_size,
                self.batch_size,
                self.num_steps,
                z_grads=z_grads)
            # apply Metropolis correction.
            energy_proposal = self._kinetic_energy(r_new) + potential_energy_new
        delta_energy = energy_proposal - energy_current
        # A diverging trajectory (e.g. a large step size) can yield a NaN
        # energy. Zero the acceptance probability of exactly those chains,
        # instead of rejecting the whole batch whenever any one chain diverges.
        accept_prob = torch.where(torch.isnan(delta_energy),
                                  torch.zeros_like(delta_energy),
                                  (-delta_energy).exp().clamp(max=1.))
        rand = torch.rand(self.batch_size)
        accepted = rand < accept_prob
        # Fraction of chains accepted in this step. Cast before averaging so the
        # division is not an integer division on integer/byte count tensors.
        self._accept_cnt += accepted.float().mean()

        # For each site, keep the proposal for accepted chains and the previous
        # value for rejected ones. Positions and gradients are selected alike.
        transitioned_z = {}
        transitioned_grads = {}
        for name in z:
            assert len(z_grads[name].shape) == 2
            assert z_grads[name].shape[0] == self.batch_size
            assert len(z[name].shape) == 2
            assert z[name].shape[0] == self.batch_size
            old_val = z[name]
            old_grad = z_grads[name]
            accept_val = accepted.view(self.batch_size, 1).expand_as(old_val)
            transitioned_z[name] = torch.where(accept_val, z_new[name], old_val)
            transitioned_grads[name] = torch.where(accept_val,
                                                   z_grads_new[name],
                                                   old_grad)

        # The cached potential energy must match the cached position: an
        # accepted chain carries its proposal's energy, and a rejected chain
        # keeps its old one.
        # NOTE(review): assumes both energies are per-chain, shape
        # (batch_size,) — TODO confirm against _kinetic_energy/velocity_verlet.
        transitioned_pe = torch.where(accepted, potential_energy_new, potential_energy)

        self._cache(transitioned_z,
                    transitioned_pe,
                    transitioned_grads)

        if self._t < self._warmup_steps:
            self._adapter.step(self._t, transitioned_z, accept_prob)

        self._t += 1

        # get trace with the constrained values for `z`; work on a copy so the
        # cached (unconstrained) dict is not mutated.
        z = transitioned_z.copy()
        for name, transform in self.transforms.items():
            z[name] = transform.inv(z[name])
        return self._get_trace(z)
예제 #18
0
파일: nuts.py 프로젝트: zyxue/pyro
    def sample(self, trace):
        """
        Run one NUTS transition starting from the latent values in ``trace``.

        The trajectory is grown by repeated doubling, each time in a direction
        chosen by a fair coin flip. Growth stops when a subtree reports turning
        or diverging, when the whole trajectory fails the turning check, or
        after ``self._max_tree_depth + 1`` doublings. A proposal is moved into
        each accepted subtree using either slice sampling or multinomial
        sampling (``self.use_multinomial_sampling``).

        Side effects: may update the cache with the chosen proposal's potential
        energy and gradients, steps the adapter during warmup, increments
        ``self._accept_cnt`` when the state moved, and increments ``self._t``.

        :param trace: trace whose latent sites hold the current (constrained)
            values.
        :returns: trace whose latent sites hold the constrained values of the
            selected state.
        """
        # Current latent values, detached from any autograd graph.
        z = {
            name: node["value"].detach()
            for name, node in self._iter_latent_nodes(trace)
        }
        # Cached potential energy / gradients of the current state.
        # potential_energy may be None; that case is handled below.
        potential_energy, z_grads = self._fetch_from_cache()
        # automatically transform `z` to unconstrained space, if needed.
        for name, transform in self.transforms.items():
            z[name] = transform(z[name])
        r, r_flat = self._sample_r(name="r_t={}".format(self._t))
        # Reuse the cached potential energy when available; otherwise compute
        # the full energy of (z, r). Parsed as `(K + PE) if ... else E`.
        energy_current = self._kinetic_energy(r) + potential_energy if potential_energy is not None \
            else self._energy(z, r)

        # Ideally, following a symplectic integrator trajectory, the energy is constant.
        # In that case, we can sample the proposal uniformly, and there is no need to use "slice".
        # However, it is not the case for real situation: there are errors during the computation.
        # To deal with that problem, as in [1], we introduce an auxiliary "slice" variable (denoted
        # by u).
        # The sampling process goes as follows:
        #   first sampling u from initial state (z_0, r_0) according to
        #     u ~ Uniform(0, p(z_0, r_0)),
        #   then sampling state (z, r) from the integrator trajectory according to
        #     (z, r) ~ Uniform({(z', r') in trajectory | p(z', r') >= u}).
        #
        # For more information about slice sampling method, see [3].
        # For another version of NUTS which uses multinomial sampling instead of slice sampling,
        # see [2].

        if self.use_multinomial_sampling:
            # No random slice is drawn: the initial state's log density is
            # passed to _build_tree as the reference level.
            log_slice = -energy_current
        else:
            # Rather than sampling the slice variable from `Uniform(0, exp(-energy))`, we can
            # sample log_slice directly using `energy`, so as to avoid potential underflow or
            # overflow issues ([2]).
            slice_exp_term = pyro.sample(
                "slicevar_exp_t={}".format(self._t),
                dist.Exponential(energy_current.new_tensor(1.)))
            log_slice = -energy_current - slice_exp_term

        # The trajectory starts as a single node, so both leaves are the
        # initial state. These names alias the same dicts until a subtree
        # replaces them.
        z_left = z_right = z
        r_left = r_right = r
        z_left_grads = z_right_grads = z_grads
        # Set once any subtree's proposal replaces z.
        accepted = False
        # Running sum of momenta over the whole trajectory, used by the
        # turning check below.
        r_sum = r_flat
        # Weight of the trajectory built so far. It is kept in log space for
        # multinomial sampling (log 1 = 0; combined via logsumexp) and as a
        # plain weight for slice sampling (1 for the initial node; combined by
        # addition).
        if self.use_multinomial_sampling:
            tree_weight = energy_current.new_zeros(())
        else:
            tree_weight = energy_current.new_ones(())

        # Temporarily disable distributions args checking as
        # NaNs are expected during step size adaptation.
        with optional(pyro.validation_enabled(False),
                      self._t < self._warmup_steps):
            # doubling process, stop when turning or diverging
            for tree_depth in range(self._max_tree_depth + 1):
                # Fair coin: 1 extends the right end, 0 extends the left end.
                direction = pyro.sample(
                    "direction_t={}_treedepth={}".format(self._t, tree_depth),
                    dist.Bernoulli(probs=torch.ones(1) * 0.5))
                direction = int(direction.item())
                if direction == 1:  # go to the right, start from the right leaf of current tree
                    new_tree = self._build_tree(z_right, r_right,
                                                z_right_grads, log_slice,
                                                direction, tree_depth,
                                                energy_current)
                    # update leaf for the next doubling process
                    z_right = new_tree.z_right
                    r_right = new_tree.r_right
                    z_right_grads = new_tree.z_right_grads
                else:  # go to the left, start from the left leaf of current tree
                    new_tree = self._build_tree(z_left, r_left, z_left_grads,
                                                log_slice, direction,
                                                tree_depth, energy_current)
                    z_left = new_tree.z_left
                    r_left = new_tree.r_left
                    z_left_grads = new_tree.z_left_grads

                # An invalid subtree is discarded: its proposal is never considered.
                if new_tree.turning or new_tree.diverging:  # stop doubling
                    break

                # Probability of moving the proposal into the new subtree,
                # given by its weight relative to the trajectory built so far.
                # Values >= 1 always accept.
                if self.use_multinomial_sampling:
                    new_tree_prob = (new_tree.weight - tree_weight).exp()
                else:
                    new_tree_prob = new_tree.weight / tree_weight
                rand = pyro.sample(
                    "rand_t={}_treedepth={}".format(self._t, tree_depth),
                    dist.Uniform(new_tree_prob.new_tensor(0.),
                                 new_tree_prob.new_tensor(1.)))
                if rand < new_tree_prob:
                    # The subtree's proposal becomes the current sample. Its
                    # potential energy and gradients are cached alongside it.
                    accepted = True
                    z = new_tree.z_proposal
                    self._cache(new_tree.z_proposal_pe,
                                new_tree.z_proposal_grads)

                # Check the whole (merged) trajectory for turning, not just
                # the new subtree.
                r_sum = r_sum + new_tree.r_sum
                if self._is_turning(r_left, r_right, r_sum):  # stop doubling
                    break
                else:  # update tree_weight
                    if self.use_multinomial_sampling:
                        tree_weight = logsumexp(torch.stack(
                            [tree_weight, new_tree.weight]),
                                                dim=0)
                    else:
                        tree_weight = tree_weight + new_tree.weight

        # During warmup, the adapter receives the mean acceptance probability
        # of the last subtree built.
        # NOTE(review): only the final subtree's statistics are used, not
        # those of the whole trajectory — confirm this is intended.
        if self._t < self._warmup_steps:
            accept_prob = new_tree.sum_accept_probs / new_tree.num_proposals
            self._adapter.step(self._t, z, accept_prob)

        # Counts transitions in which z moved to a new proposal.
        if accepted:
            self._accept_cnt += 1

        self._t += 1
        # get trace with the constrained values for `z`.
        # This mutates the dict `z` refers to in place.
        for name, transform in self.transforms.items():
            z[name] = transform.inv(z[name])
        return self._get_trace(z)