def dict_to_dataset(
    data,
    library=None,
    coords=None,
    dims=None,
    attrs=None,
    default_dims=None,
    skip_event_dims=None,
    index_origin=None,
):
    """Temporal workaround for dict_to_dataset.

    Once ArviZ>0.11.2 release is available, only two changes are needed for
    everything to work. 1) this should be deleted, 2) dict_to_dataset should be
    imported as is from arviz, no underscore, also remove unnecessary imports

    Parameters
    ----------
    data : dict
        Mapping from variable name to array-like values.
    library : optional
        Forwarded to ``_dict_to_dataset`` or to ``make_attrs``.
    coords : dict, optional
        Coordinate values keyed by dimension name. The same mapping is offered
        to every variable.
    dims : dict, optional
        Mapping from variable name to its dimension names. ``None`` is treated
        as an empty mapping on the manual path.
    attrs : optional
        Accepted for signature compatibility; not used by this workaround.
    default_dims : optional
        When ``None`` the call is delegated to ``_dict_to_dataset``. Any other
        value selects the manual construction below; the value itself is not
        used further.
    skip_event_dims : optional
        Only forwarded to ``_dict_to_dataset`` (delegated path).
    index_origin : optional
        Accepted for signature compatibility; not used by this workaround.

    Returns
    -------
    The result of ``_dict_to_dataset`` on the delegated path, otherwise an
    ``xr.Dataset`` with one ``xr.DataArray`` per entry of ``data``.
    """
    if default_dims is None:
        return _dict_to_dataset(
            data,
            library=library,
            coords=coords,
            dims=dims,
            skip_event_dims=skip_event_dims,
        )

    # Without this, ``dims.get`` below fails when ``dims`` is left at its default.
    if dims is None:
        dims = {}

    out_data = {}
    for name, vals in data.items():
        vals = np.atleast_1d(vals)
        # Use per-variable names: rebinding ``coords`` here would hand the
        # next variable only this variable's filtered coordinates instead of
        # the caller-supplied ones.
        var_dims, var_coords = generate_dims_coords(
            vals.shape, name, dims=dims.get(name), coords=coords
        )
        var_coords = {
            key: xr.IndexVariable((key,), data=var_coords[key]) for key in var_dims
        }
        out_data[name] = xr.DataArray(vals, dims=var_dims, coords=var_coords)
    return xr.Dataset(data_vars=out_data, attrs=make_attrs(library=library))
def sample_numpyro_nuts(
    draws: int = 1000,
    tune: int = 1000,
    chains: int = 4,
    target_accept: float = 0.8,
    random_seed: Optional[int] = None,
    initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None,
    model: Optional[Model] = None,
    var_names=None,
    progress_bar: bool = True,
    keep_untransformed: bool = False,
    chain_method: str = "parallel",
    idata_kwargs: Optional[Dict] = None,
    nuts_kwargs: Optional[Dict] = None,
):
    """
    Draw samples from the posterior using the NUTS method from the ``numpyro`` library.

    Parameters
    ----------
    draws : int, default 1000
        The number of samples to draw. Tuning samples are discarded by default.
    tune : int, default 1000
        Number of iterations to tune. Samplers adjust the step sizes, scalings or
        similar during tuning. Tuning samples will be drawn in addition to the
        number specified in the ``draws`` argument.
    chains : int, default 4
        The number of chains to sample.
    target_accept : float in [0, 1].
        The step size is tuned such that we approximate this acceptance rate.
        Higher values like 0.9 or 0.95 often work better for problematic
        posteriors.
    random_seed : int, optional
        Random seed used by the sampling steps. If ``None``, a seed is drawn from
        the model's ``rng_seeder``.
    initvals : StartDict or sequence of StartDict, optional
        Initial values, forwarded to ``_get_batched_jittered_initial_points``
        when building the starting points of the chains.
    model : Model, optional
        Model to sample from. The model needs to have free random variables. When
        inside a ``with`` model context, it defaults to that model, otherwise the
        model must be passed explicitly.
    var_names : iterable of str, optional
        Names of variables for which to compute the posterior samples. Defaults
        to all variables in the posterior.
    progress_bar : bool, default True
        Whether or not to display a progress bar in the command line. The bar
        shows the percentage of completion, the sampling speed in samples per
        second (SPS), and the estimated remaining time until completion
        ("expected time of arrival"; ETA).
    keep_untransformed : bool, default False
        Include untransformed variables in the posterior samples.
    chain_method : str, default "parallel"
        Specify how samples should be drawn. The choices include "sequential",
        "parallel", and "vectorized".
    idata_kwargs : dict, optional
        Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean
        as value for the ``log_likelihood`` key; a false value means the
        pointwise log likelihood is not included in the returned object.
    nuts_kwargs : dict, optional
        Keyword arguments for ``numpyro.infer.NUTS``. Entries override the
        defaults used here (``adapt_step_size=True``, ``adapt_mass_matrix=True``,
        ``dense_mass=False``).

    Returns
    -------
    InferenceData
        ArviZ ``InferenceData`` object that contains the posterior samples,
        together with their respective sample stats and pointwise log likelihood
        values (unless skipped with ``idata_kwargs``).
    """
    import numpyro

    from numpyro.infer import MCMC, NUTS

    model = modelcontext(model)

    if var_names is None:
        var_names = model.unobserved_value_vars

    vars_to_sample = list(
        get_default_varnames(var_names, include_transformed=keep_untransformed)
    )

    # Tuples are converted to arrays; coordinates without values are dropped.
    coords = {
        cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
        for cname, cvals in model.coords.items()
        if cvals is not None
    }

    if hasattr(model, "RV_dims"):
        dims = {
            var_name: [dim for dim in var_dims if dim is not None]
            for var_name, var_dims in model.RV_dims.items()
        }
    else:
        dims = {}

    if random_seed is None:
        random_seed = model.rng_seeder.randint(2**30, dtype=np.int64)

    tic1 = datetime.now()
    print("Compiling...", file=sys.stdout)

    init_params = _get_batched_jittered_initial_points(
        model=model,
        chains=chains,
        initvals=initvals,
        random_seed=random_seed,
    )

    logp_fn = get_jaxified_logp(model, negative_logp=False)

    # Merge user options over the defaults instead of passing both as keywords:
    # the latter raises ``TypeError`` as soon as a user tries to override one of
    # the defaults (e.g. ``dense_mass=True``).
    kernel_kwargs = {
        "adapt_step_size": True,
        "adapt_mass_matrix": True,
        "dense_mass": False,
    }
    if nuts_kwargs is not None:
        kernel_kwargs.update(nuts_kwargs)

    nuts_kernel = NUTS(
        potential_fn=logp_fn,
        target_accept_prob=target_accept,
        **kernel_kwargs,
    )

    pmap_numpyro = MCMC(
        nuts_kernel,
        num_warmup=tune,
        num_samples=draws,
        num_chains=chains,
        postprocess_fn=None,
        chain_method=chain_method,
        progress_bar=progress_bar,
    )

    tic2 = datetime.now()
    print("Compilation time = ", tic2 - tic1, file=sys.stdout)

    print("Sampling...", file=sys.stdout)

    # One PRNG key per chain when sampling several chains.
    map_seed = jax.random.PRNGKey(random_seed)
    if chains > 1:
        map_seed = jax.random.split(map_seed, chains)

    pmap_numpyro.run(
        map_seed,
        init_params=init_params,
        extra_fields=(
            "num_steps",
            "potential_energy",
            "energy",
            "adapt_state.step_size",
            "accept_prob",
            "diverging",
        ),
    )

    raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)

    tic3 = datetime.now()
    print("Sampling time = ", tic3 - tic2, file=sys.stdout)

    print("Transforming variables...", file=sys.stdout)
    mcmc_samples = {}
    for v in vars_to_sample:
        jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[v])
        # Double vmap: map the graph over the two leading axes of the samples.
        result = jax.vmap(jax.vmap(jax_fn))(*raw_mcmc_samples)[0]
        mcmc_samples[v.name] = result

    tic4 = datetime.now()
    print("Transformation time = ", tic4 - tic3, file=sys.stdout)

    # Work on a copy so the caller's dict is not modified by the ``pop`` below.
    if idata_kwargs is None:
        idata_kwargs = {}
    else:
        idata_kwargs = idata_kwargs.copy()

    if idata_kwargs.pop("log_likelihood", True):
        log_likelihood = _get_log_likelihood(model, raw_mcmc_samples)
    else:
        log_likelihood = None

    attrs = {
        "sampling_time": (tic3 - tic2).total_seconds(),
    }

    posterior = mcmc_samples
    az_trace = az.from_dict(
        posterior=posterior,
        log_likelihood=log_likelihood,
        observed_data=find_observations(model),
        sample_stats=_sample_stats_to_xarray(pmap_numpyro),
        coords=coords,
        dims=dims,
        attrs=make_attrs(attrs, library=numpyro),
        **idata_kwargs,
    )

    return az_trace
def sample_blackjax_nuts(
    draws=1000,
    tune=1000,
    chains=4,
    target_accept=0.8,
    random_seed=10,
    initvals=None,
    model=None,
    var_names=None,
    keep_untransformed=False,
    chain_method="parallel",
    idata_kwargs=None,
):
    """
    Draw samples from the posterior using the NUTS method from the ``blackjax`` library.

    Parameters
    ----------
    draws : int, default 1000
        The number of samples to draw. Tuning samples are discarded by default.
    tune : int, default 1000
        Number of iterations to tune. Samplers adjust the step sizes, scalings or
        similar during tuning. Tuning samples will be drawn in addition to the
        number specified in the ``draws`` argument.
    chains : int, default 4
        The number of chains to sample.
    target_accept : float in [0, 1].
        The step size is tuned such that we approximate this acceptance rate.
        Higher values like 0.9 or 0.95 often work better for problematic
        posteriors.
    random_seed : int, default 10
        Random seed used by the sampling steps. If ``None``, a seed is drawn from
        the model's ``rng_seeder``.
    initvals : optional
        Initial values, forwarded to ``_get_batched_jittered_initial_points``
        when building the starting points of the chains.
    model : Model, optional
        Model to sample from. The model needs to have free random variables. When
        inside a ``with`` model context, it defaults to that model, otherwise the
        model must be passed explicitly.
    var_names : iterable of str, optional
        Names of variables for which to compute the posterior samples. Defaults
        to all variables in the posterior.
    keep_untransformed : bool, default False
        Include untransformed variables in the posterior samples.
    chain_method : str, default "parallel"
        Specify how samples should be drawn. The choices include "parallel", and
        "vectorized".
    idata_kwargs : dict, optional
        Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean
        as value for the ``log_likelihood`` key; a false value means the
        pointwise log likelihood is not included in the returned object.

    Returns
    -------
    InferenceData
        ArviZ ``InferenceData`` object that contains the posterior samples,
        together with their pointwise log likelihood values (unless skipped with
        ``idata_kwargs``).

    Raises
    ------
    ValueError
        If ``chain_method`` is neither "parallel" nor "vectorized".
    """
    import blackjax

    model = modelcontext(model)

    if var_names is None:
        var_names = model.unobserved_value_vars

    vars_to_sample = list(
        get_default_varnames(var_names, include_transformed=keep_untransformed)
    )

    # Tuples are converted to arrays; coordinates without values are dropped.
    coords = {
        cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
        for cname, cvals in model.coords.items()
        if cvals is not None
    }

    if hasattr(model, "RV_dims"):
        dims = {
            var_name: [dim for dim in var_dims if dim is not None]
            for var_name, var_dims in model.RV_dims.items()
        }
    else:
        dims = {}

    # Same fallback as the numpyro sampler, so ``random_seed=None`` is usable
    # here too instead of failing when the PRNG key is created.
    if random_seed is None:
        random_seed = model.rng_seeder.randint(2**30, dtype=np.int64)

    tic1 = datetime.now()
    print("Compiling...", file=sys.stdout)

    init_params = _get_batched_jittered_initial_points(
        model=model,
        chains=chains,
        initvals=initvals,
        random_seed=random_seed,
    )

    # NOTE(review): both re-stacking statements are treated as belonging to the
    # single-chain branch (adding a leading chain axis for the mapped call
    # below) -- confirm against the intended nesting.
    if chains == 1:
        init_params = [np.stack(init_params)]
        init_params = [np.stack(init_state) for init_state in zip(*init_params)]

    logprob_fn = get_jaxified_logp(model)

    seed = jax.random.PRNGKey(random_seed)
    keys = jax.random.split(seed, chains)

    get_posterior_samples = partial(
        _blackjax_inference_loop,
        logprob_fn=logprob_fn,
        tune=tune,
        draws=draws,
        target_accept=target_accept,
    )

    tic2 = datetime.now()
    print("Compilation time = ", tic2 - tic1, file=sys.stdout)

    print("Sampling...", file=sys.stdout)

    # Adapted from numpyro
    if chain_method == "parallel":
        map_fn = jax.pmap
    elif chain_method == "vectorized":
        map_fn = jax.vmap
    else:
        raise ValueError(
            "Only supporting the following methods to draw chains:"
            ' "parallel" or "vectorized"'
        )

    states, _ = map_fn(get_posterior_samples)(keys, init_params)
    raw_mcmc_samples = states.position

    tic3 = datetime.now()
    print("Sampling time = ", tic3 - tic2, file=sys.stdout)

    print("Transforming variables...", file=sys.stdout)
    mcmc_samples = {}
    for v in vars_to_sample:
        jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[v])
        # Double vmap: map the graph over the two leading axes of the samples.
        result = jax.vmap(jax.vmap(jax_fn))(*raw_mcmc_samples)[0]
        mcmc_samples[v.name] = result

    tic4 = datetime.now()
    print("Transformation time = ", tic4 - tic3, file=sys.stdout)

    # Work on a copy so the caller's dict is not modified by the ``pop`` below.
    if idata_kwargs is None:
        idata_kwargs = {}
    else:
        idata_kwargs = idata_kwargs.copy()

    if idata_kwargs.pop("log_likelihood", True):
        log_likelihood = _get_log_likelihood(model, raw_mcmc_samples)
    else:
        log_likelihood = None

    attrs = {
        "sampling_time": (tic3 - tic2).total_seconds(),
    }

    posterior = mcmc_samples
    az_trace = az.from_dict(
        posterior=posterior,
        log_likelihood=log_likelihood,
        observed_data=find_observations(model),
        coords=coords,
        dims=dims,
        attrs=make_attrs(attrs, library=blackjax),
        **idata_kwargs,
    )

    return az_trace