Example #1
File: sample.py Project: rlouf/mcx
    def warmup(self,
               num_warmup_steps: int = 1000,
               compile: bool = False,
               **kwargs):
        """Warmup the sampler.

        Warmup is necessary to get values for the evaluator's parameters that are
        adapted to the geometry of the posterior distribution. While `run` will
        automatically run the warmup if it hasn't been done before, running
        this method independently gives access to the trace for the warmup
        phase and the values of the parameters for diagnostics.

        Parameters
        ----------
        num_warmup_steps
            The number of warmup steps to perform.
        compile
            If False the progress of the warmup will be displayed. Otherwise
            it will use `lax.scan` to iterate (which is potentially faster).
        kwargs
            Parameters to pass to the evaluator's warmup.

        Returns
        -------
        trace
            A Trace object that contains the warmup sampling chain, warmup
            sampling info and warmup info, or `None` if the evaluator does
            not require a warmup phase.

        """
        last_state, parameters, warmup_chain = self.evaluator.warmup(
            self.rng_key,
            self.state,
            self.kernel_factory,
            self.num_chains,
            num_warmup_steps,
            compile,
            **kwargs,
        )
        self.state = last_state
        self.parameters = parameters
        self.is_warmed_up = True

        # The evaluator must return `None` when no warmup is needed.
        if warmup_chain is None:
            return

        samples, sampling_info, warmup_info = self.evaluator.make_warmup_trace(
            chain=warmup_chain, unravel_fn=self.unravel_fn)

        trace = Trace(
            warmup_samples=samples,
            warmup_sampling_info=sampling_info,
            warmup_info=warmup_info,
            loglikelihood_contributions_fn=self.loglikelihood_contributions,
        )

        return trace
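A minimal usage sketch of the method above, assuming a `sampler` object has already been constructed (the constructor call is not shown in this example):

    # Run the adaptation phase on its own before sampling.
    warmup_trace = sampler.warmup(num_warmup_steps=500, compile=True)
    # Evaluators that need no warmup return None; otherwise the trace
    # carries the warmup chain and adaptation info for diagnostics.
    if warmup_trace is not None:
        print(warmup_trace)
    # `is_warmed_up` is now True, so a subsequent `run` skips the warmup.
    trace = sampler.run(num_samples=1000)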
Example #2
    def __init__(self,
                 model_fn: FunctionType,
                 trace: Trace,
                 chain_id=0) -> None:
        """Create a generative function.

        We create a generative function, or stochastic program, by conditioning
        the values of a model's random variables. A typical application is to
        create a function that returns samples from the posterior predictive
        distribution.

        """
        self.graph, self.namespace = mcx.core.parse(model_fn)
        self.model_fn = model_fn
        self.conditioning = trace

        self.call_fn, self.src = mcx.core.sample_posterior_predictive(
            self, trace.keys())
        self.trace = trace
        self.chain_id = chain_id
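A hedged sketch of how this constructor might be used, assuming the class is exposed as `GenerativeFunction` (a name assumed here for illustration), that `linear_model` is an mcx model, and that `trace` holds posterior samples from a previous `run`:

    # Condition the model on posterior samples to get a posterior
    # predictive program (class name assumed for illustration).
    predictive = GenerativeFunction(linear_model, trace, chain_id=0)
    # The compiled `call_fn` built by `mcx.core.sample_posterior_predictive`
    # performs the actual sampling; its call signature is an assumption.
    y_pred = predictive.call_fn(rng_key, x_new)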
Example #3
    def run(
        self,
        num_samples: int = 1000,
        num_warmup_steps: int = 1000,
        compile: bool = False,
        **warmup_kwargs,
    ) -> Trace:
        """Run the posterior inference.

        For convenience we automatically run the warmup if it hasn't been run
        independently previously. Samples taken during the warmup phase are
        discarded by default; to keep them, run the `warmup` method
        independently before calling `run`.

        Parameters
        ----------
        num_samples
            The number of samples to take from the posterior distribution.
        num_warmup_steps
            The number of warmup steps to perform.
        compile
            If False the progress of the warmup and sampling will be displayed.
            Otherwise it will use `lax.scan` to iterate (which is potentially
            faster).
        warmup_kwargs
            Parameters to pass to the evaluator's warmup.

        Returns
        -------
        trace
            A Trace object that contains the chains, some information about
            the inference process (e.g. divergences for evaluators in the
            HMC family).

        """
        if not self.is_warmed_up:
            self.warmup(num_warmup_steps, compile, **warmup_kwargs)

        @jax.jit
        def update_chain(rng_key, parameters, chain_state):
            kernel = self.kernel_factory(*parameters)
            new_chain_state, info = kernel(rng_key, chain_state)
            return new_chain_state, info

        _, self.rng_key = jax.random.split(self.rng_key)
        rng_keys = jax.random.split(self.rng_key, num_samples)

        # The progress bar, displayed when compile=False, is an important
        # indicator for exploratory analysis, while sample_scan is optimal for
        # getting a large number of samples.  Note that for small sample sizes
        # lax.scan is not dramatically faster than a for loop due to compile
        # time. Expect, however, an increasing improvement as the number of
        # samples increases.
        if compile:
            last_state, chain = sample_scan(
                update_chain, self.state, self.parameters, rng_keys, self.num_chains
            )
        else:
            last_state, chain = sample_loop(
                update_chain, self.state, self.parameters, rng_keys, self.num_chains
            )

        samples, sampling_info = self.evaluator.make_trace(
            chain=chain, unravel_fn=self.unravel_fn
        )
        trace = Trace(
            samples=samples,
            sampling_info=sampling_info,
            loglikelihood_contributions_fn=self.loglikelihood_contributions,
        )

        self.state = last_state

        return trace
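The `compile` flag chooses between the two iteration strategies described in the comment above. A short sketch, assuming the same `sampler` object as in the first example:

    # Exploratory run: Python loop with a progress bar.
    trace = sampler.run(num_samples=200, compile=False)
    # Long run: compiled `lax.scan`, which pays off as the number of
    # samples grows.
    trace = sampler.run(num_samples=50_000, compile=True)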
Example #4
File: sample.py Project: rlouf/mcx
    def run(
        self,
        num_samples: int = 1000,
        num_warmup_steps: int = 1000,
        compile: bool = False,
        metrics: Sequence[Callable[..., Tuple[str, Callable, Callable]]] = (
            divergences,
            online_gelman_rubin,
        ),
        **warmup_kwargs,
    ) -> Trace:
        """Run the posterior inference.

        For convenience we automatically run the warmup if it hasn't been run
        independently previously. Samples taken during the warmup phase are
        discarded by default; to keep them, run the `warmup` method
        independently before calling `run`.

        Parameters
        ----------
        num_samples
            The number of samples to take from the posterior distribution.
        num_warmup_steps
            The number of warmup steps to perform.
        compile
            If False the progress of the warmup and sampling will be displayed.
            Otherwise it will use `lax.scan` to iterate (which is potentially
            faster).
        metrics
            A list of functions to generate online metrics when sampling. Only
            used when `compile` is False. Each function must return a metric
            name along with two functions: an `init` function and an `update`
            function.
        warmup_kwargs
            Parameters to pass to the evaluator's warmup.

        Returns
        -------
        trace
            A Trace object that contains the chains, some information about
            the inference process (e.g. divergences for evaluators in the
            HMC family).

        Notes
        -----
        Passing functions to `metrics` may slow down sampling. It may be useful to have
        online metrics when building or diagnosing a model.
        """
        if not self.is_warmed_up:
            self.warmup(num_warmup_steps, compile, **warmup_kwargs)

        @jax.jit
        def update_one_chain(rng_key, parameters, chain_state):
            kernel = self.kernel_factory(*parameters)
            new_chain_state, info = kernel(rng_key, chain_state)
            return new_chain_state, info

        _, self.rng_key = jax.random.split(self.rng_key)
        rng_keys = jax.random.split(self.rng_key, num_samples)

        # The progress bar, displayed when compile=False, is an important
        # indicator for exploratory analysis, while sample_scan is optimal for
        # getting a large number of samples.  Note that for small sample sizes
        # lax.scan is not dramatically faster than a for loop due to compile
        # time. Expect, however, an increasing improvement as the number of
        # samples increases.
        if compile:
            last_state, chain = sample_scan(update_one_chain, self.state,
                                            self.parameters, rng_keys,
                                            self.num_chains)
        else:
            if metrics is None:
                metrics = ()
            last_state, chain = sample_loop(
                update_one_chain,
                self.state,
                self.parameters,
                rng_keys,
                self.num_chains,
                metrics,
            )

        samples, sampling_info = self.evaluator.make_trace(
            chain=chain, unravel_fn=self.unravel_fn)
        trace = Trace(
            samples=samples,
            sampling_info=sampling_info,
            loglikelihood_contributions_fn=self.loglikelihood_contributions,
        )

        self.state = last_state

        return trace
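Per the `metrics` annotation, each entry is a factory that returns a metric name plus `init` and `update` callables. A hedged sketch of a custom metric; the exact arguments that `sample_loop` passes to `init` and `update`, and the `acceptance_probability` field on the sampling info, are assumptions here:

    def acceptance_rate():
        # Hypothetical online metric: running mean acceptance probability.
        def init(num_chains):
            return (0, 0.0)  # (updates seen, running sum of rates)
        def update(metric_state, info):
            n, total = metric_state
            # `info.acceptance_probability` is assumed to be exposed by
            # HMC-family sampling infos; adapt to your evaluator.
            total += float(info.acceptance_probability.mean())
            n += 1
            return (n, total), total / n
        return "acceptance_rate", init, update

    trace = sampler.run(num_samples=1000, compile=False,
                        metrics=[acceptance_rate])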