def warmup(self, num_warmup_steps: int = 1000, compile: bool = False, **kwargs):
    """Warmup the sampler.

    Warmup is necessary to get values for the evaluator's parameters that
    are adapted to the geometry of the posterior distribution. While `run`
    will automatically run the warmup if it hasn't been done before, running
    this method independently gives access to the trace for the warmup phase
    and the values of the parameters for diagnostics.

    Parameters
    ----------
    num_warmup_steps
        The number of warmup steps to perform.
    compile
        If False the progress of the warmup will be displayed. Otherwise it
        will use `lax.scan` to iterate (which is potentially faster).
    kwargs
        Parameters to pass to the evaluator's warmup.

    Returns
    -------
    trace
        A Trace object that contains the warmup sampling chain, warmup
        sampling info and warmup info.

    """
    last_state, parameters, warmup_chain = self.evaluator.warmup(
        self.rng_key,
        self.state,
        self.kernel_factory,
        self.num_chains,
        num_warmup_steps,
        compile,
        **kwargs,
    )
    self.state = last_state
    self.parameters = parameters
    self.is_warmed_up = True

    # The evaluator must return `None` when no warmup is needed.
    if warmup_chain is None:
        return None

    samples, sampling_info, warmup_info = self.evaluator.make_warmup_trace(
        chain=warmup_chain, unravel_fn=self.unravel_fn
    )
    trace = Trace(
        warmup_samples=samples,
        warmup_sampling_info=sampling_info,
        warmup_info=warmup_info,
        loglikelihood_contributions_fn=self.loglikelihood_contributions,
    )

    return trace
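# A minimal usage sketch for running the warmup independently and inspecting
# its trace before sampling. The sampler construction below is an assumption
# for illustration: `mcx.HMC`, the `mcx.sampler` signature, the
# `linear_regression` model and the `x_data`/`y_data` arrays are hypothetical
# stand-ins, not confirmed parts of this API.
import jax
import mcx

rng_key = jax.random.PRNGKey(0)
evaluator = mcx.HMC(num_integration_steps=10)  # hypothetical evaluator
sampler = mcx.sampler(  # hypothetical constructor signature
    rng_key, linear_regression, evaluator, x=x_data, predictions=y_data
)

# Running the warmup on its own returns the warmup trace, so the adapted
# parameters can be diagnosed before sampling starts.
warmup_trace = sampler.warmup(num_warmup_steps=1000)
trace = sampler.run(num_samples=1000)  # the warmup is not run again here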
def __init__(self, model_fn: FunctionType, trace: Trace, chain_id: int = 0) -> None:
    """Create a generative function.

    We create a generative function, or stochastic program, by conditioning
    the values of a model's random variables. A typical application is to
    create a function that returns samples from the posterior predictive
    distribution.

    """
    self.graph, self.namespace = mcx.core.parse(model_fn)
    self.model_fn = model_fn
    self.conditioning = trace
    self.call_fn, self.src = mcx.core.sample_posterior_predictive(
        self, trace.keys()
    )
    self.trace = trace
    self.chain_id = chain_id
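# A sketch of how such a generative function might be used for posterior
# predictive sampling. The class name `GenerativeFunction`, the `x_new`
# array and the `call_fn` calling convention are assumptions for
# illustration; only the attributes set in `__init__` above come from the
# source.
posterior_predictive = GenerativeFunction(linear_regression, trace, chain_id=0)
predictions = posterior_predictive.call_fn(rng_key, x=x_new)  # hypothetical signature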
def run(
    self,
    num_samples: int = 1000,
    num_warmup_steps: int = 1000,
    compile: bool = False,
    **warmup_kwargs,
) -> Trace:
    """Run the posterior inference.

    For convenience we automatically run the warmup if it hasn't been run
    independently previously. Samples taken during the warmup phase are
    discarded by default. To keep them, call the `warmup` method
    independently first; it returns the warmup trace.

    Parameters
    ----------
    num_samples
        The number of samples to take from the posterior distribution.
    num_warmup_steps
        The number of warmup steps to perform.
    compile
        If False the progress of the warmup and sampling will be displayed.
        Otherwise it will use `lax.scan` to iterate (which is potentially
        faster).
    warmup_kwargs
        Parameters to pass to the evaluator's warmup.

    Returns
    -------
    trace
        A Trace object that contains the chains and some information about
        the inference process (e.g. divergences for evaluators in the HMC
        family).

    """
    if not self.is_warmed_up:
        self.warmup(num_warmup_steps, compile, **warmup_kwargs)

    @jax.jit
    def update_chain(rng_key, parameters, chain_state):
        kernel = self.kernel_factory(*parameters)
        new_chain_state, info = kernel(rng_key, chain_state)
        return new_chain_state, info

    _, self.rng_key = jax.random.split(self.rng_key)
    rng_keys = jax.random.split(self.rng_key, num_samples)

    # The progress bar, displayed when compile=False, is an important
    # indicator for exploratory analysis, while `sample_scan` is optimal
    # for getting a large number of samples. Note that for small sample
    # sizes `lax.scan` is not dramatically faster than a for loop due to
    # the compilation time. Expect, however, an increasing improvement as
    # the number of samples increases.
    if compile:
        last_state, chain = sample_scan(
            update_chain, self.state, self.parameters, rng_keys, self.num_chains
        )
    else:
        last_state, chain = sample_loop(
            update_chain, self.state, self.parameters, rng_keys, self.num_chains
        )

    samples, sampling_info = self.evaluator.make_trace(
        chain=chain, unravel_fn=self.unravel_fn
    )
    trace = Trace(
        samples=samples,
        sampling_info=sampling_info,
        loglikelihood_contributions_fn=self.loglikelihood_contributions,
    )

    self.state = last_state

    return trace
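# A sketch contrasting the two iteration strategies described in the comment
# above, reusing the hypothetical `sampler` from the first sketch:
# `compile=False` iterates with a Python loop and displays a progress bar,
# while `compile=True` compiles the whole sampling loop with `lax.scan`.
exploratory_trace = sampler.run(num_samples=500)  # progress bar, easy to interrupt
production_trace = sampler.run(num_samples=100_000, compile=True)  # lax.scan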
def run(
    self,
    num_samples: int = 1000,
    num_warmup_steps: int = 1000,
    compile: bool = False,
    metrics: Sequence[Callable[..., Tuple[str, Callable, Callable]]] = (
        divergences,
        online_gelman_rubin,
    ),
    **warmup_kwargs,
) -> Trace:
    """Run the posterior inference.

    For convenience we automatically run the warmup if it hasn't been run
    independently previously. Samples taken during the warmup phase are
    discarded by default. To keep them, call the `warmup` method
    independently first; it returns the warmup trace.

    Parameters
    ----------
    num_samples
        The number of samples to take from the posterior distribution.
    num_warmup_steps
        The number of warmup steps to perform.
    compile
        If False the progress of the warmup and sampling will be displayed.
        Otherwise it will use `lax.scan` to iterate (which is potentially
        faster).
    metrics
        A list of functions used to generate online metrics while sampling.
        Only used when `compile` is False. Each function must return the
        metric's name along with two functions: an `init` function and an
        `update` function.
    warmup_kwargs
        Parameters to pass to the evaluator's warmup.

    Returns
    -------
    trace
        A Trace object that contains the chains and some information about
        the inference process (e.g. divergences for evaluators in the HMC
        family).

    Notes
    -----
    Passing functions to `metrics` may slow down sampling. It may
    nevertheless be useful to have online metrics when building or
    diagnosing a model.

    """
    if not self.is_warmed_up:
        self.warmup(num_warmup_steps, compile, **warmup_kwargs)

    @jax.jit
    def update_one_chain(rng_key, parameters, chain_state):
        kernel = self.kernel_factory(*parameters)
        new_chain_state, info = kernel(rng_key, chain_state)
        return new_chain_state, info

    _, self.rng_key = jax.random.split(self.rng_key)
    rng_keys = jax.random.split(self.rng_key, num_samples)

    # The progress bar, displayed when compile=False, is an important
    # indicator for exploratory analysis, while `sample_scan` is optimal
    # for getting a large number of samples. Note that for small sample
    # sizes `lax.scan` is not dramatically faster than a for loop due to
    # the compilation time. Expect, however, an increasing improvement as
    # the number of samples increases.
    if compile:
        last_state, chain = sample_scan(
            update_one_chain,
            self.state,
            self.parameters,
            rng_keys,
            self.num_chains,
        )
    else:
        if metrics is None:
            metrics = ()
        last_state, chain = sample_loop(
            update_one_chain,
            self.state,
            self.parameters,
            rng_keys,
            self.num_chains,
            metrics,
        )

    samples, sampling_info = self.evaluator.make_trace(
        chain=chain, unravel_fn=self.unravel_fn
    )
    trace = Trace(
        samples=samples,
        sampling_info=sampling_info,
        loglikelihood_contributions_fn=self.loglikelihood_contributions,
    )

    self.state = last_state

    return trace
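# A sketch of a custom online metric following the contract implied by the
# `metrics` type hint above: a factory returning a (name, init, update)
# triple. The arguments that `sample_loop` passes to the factory, to `init`
# and to `update`, as well as the `info.acceptance_probability` field, are
# assumptions for illustration.
import jax.numpy as jnp

def online_acceptance_rate(*args):
    def init(num_chains):
        # one running mean per chain
        return jnp.zeros(num_chains)

    def update(running_mean, info, step):
        # incremental update of the mean acceptance probability (assumed field)
        return running_mean + (info.acceptance_probability - running_mean) / (step + 1)

    return "acceptance_rate", init, update

trace = sampler.run(
    num_samples=2_000,
    metrics=[divergences, online_acceptance_rate],
)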