def test_concat_group(copy, inplace, sequence):
    """Concatenate three InferenceData objects that hold disjoint groups.

    The inputs are passed either as one sequence or as positional arguments
    (``sequence``). Afterwards the result must expose every group, and the
    identity of each group object must match what ``copy``/``inplace`` imply.
    """
    posterior_idata = from_dict(
        posterior={"A": np.random.randn(2, 10, 2), "B": np.random.randn(2, 10, 5, 2)}
    )
    if copy and inplace:
        # Remember the identity before concat gets a chance to touch it.
        posterior_id_before = id(posterior_idata.posterior)
    prior_idata = from_dict(
        prior={"C": np.random.randn(2, 10, 2), "D": np.random.randn(2, 10, 5, 2)}
    )
    observed_idata = from_dict(
        observed_data={"E": np.random.randn(100), "F": np.random.randn(2, 100)}
    )

    # basic case
    assert concat(posterior_idata, prior_idata, copy=True, inplace=False) is not None

    inputs = (posterior_idata, prior_idata, observed_idata)
    if sequence:
        merged = concat(inputs, copy=copy, inplace=inplace)
    else:
        merged = concat(*inputs, copy=copy, inplace=inplace)

    if inplace:
        # In-place concatenation returns nothing; the first input is the result.
        assert merged is None
        merged = posterior_idata
    assert merged is not None

    expected = {"posterior": ["A", "B"], "prior": ["C", "D"], "observed_data": ["E", "F"]}
    assert not check_multiple_attrs(expected, merged)

    if not copy:
        assert id(merged.posterior) == id(posterior_idata.posterior)
        assert id(merged.prior) == id(prior_idata.prior)
        assert id(merged.observed_data) == id(observed_idata.observed_data)
        return

    if inplace:
        assert id(merged.posterior) == posterior_id_before
    else:
        assert id(merged.posterior) != id(posterior_idata.posterior)
    assert id(merged.prior) != id(prior_idata.prior)
    assert id(merged.observed_data) != id(observed_idata.observed_data)
def _trace_to_arviz(
    trace=None,
    sample_stats=None,
    observed_data=None,
    prior_predictive=None,
    posterior_predictive=None,
    inplace=True,
):
    """Build an ArviZ ``InferenceData`` object from sampling outputs.

    Parameters
    ----------
    trace : dict or InferenceData, optional
        When a dict, every value is converted with ``.numpy()``.
    sample_stats : dict, optional
        Every value is converted with ``.numpy()`` and transposed.
    observed_data : dict, optional
        Forwarded unchanged to ``az.from_dict``.
    prior_predictive : dict, optional
        Every value gets a new leading axis.
    posterior_predictive : dict, optional
        Forwarded unchanged to ``az.from_dict``.
    inplace : bool
        When true, ``trace`` is an ``InferenceData`` and
        ``posterior_predictive`` is a dict, the posterior predictive group is
        added to ``trace`` and that sum is returned.

    Returns
    -------
    InferenceData
    """
    # isinstance(None, dict) is False, so no separate None check is needed.
    if isinstance(trace, dict):
        trace = {k: v.numpy() for k, v in trace.items()}
    if isinstance(sample_stats, dict):
        sample_stats = {k: v.numpy().T for k, v in sample_stats.items()}
    if isinstance(prior_predictive, dict):
        prior_predictive = {k: v[np.newaxis] for k, v in prior_predictive.items()}
    if isinstance(posterior_predictive, dict):
        if inplace and isinstance(trace, az.InferenceData):
            return trace + az.from_dict(posterior_predictive=posterior_predictive)
        # Otherwise the posterior is deliberately left out of the result.
        trace = None
    return az.from_dict(
        posterior=trace,
        sample_stats=sample_stats,
        prior_predictive=prior_predictive,
        posterior_predictive=posterior_predictive,
        observed_data=observed_data,
    )
def test_addition():
    """Adding two InferenceData objects yields one holding both groups."""
    posterior_vars = {"A": np.random.randn(2, 10, 2), "B": np.random.randn(2, 10, 5, 2)}
    left = from_dict(posterior=posterior_vars)
    prior_vars = {"C": np.random.randn(2, 10, 2), "D": np.random.randn(2, 10, 5, 2)}
    right = from_dict(prior=prior_vars)

    combined = left + right
    assert combined is not None

    expected = {"posterior": ["A", "B"], "prior": ["C", "D"]}
    assert not check_multiple_attrs(expected, combined)
def trace_to_arviz(
    trace=None,
    sample_stats=None,
    observed_data=None,
    prior_predictive=None,
    posterior_predictive=None,
    inplace=True,
):
    """
    Tensorflow to ArviZ trace converter.

    Creates an ArviZ InferenceData object with inference, prediction and/or
    sampling data generated by PyMC4.

    Parameters
    ----------
    trace : dict or InferenceData
        When a dict, only keys containing ``"/"`` are kept; each value is
        converted with ``.numpy()`` and has its first two axes swapped.
    sample_stats : dict
        Every value is converted with ``.numpy()`` and transposed.
    observed_data : dict
        Forwarded unchanged to ``az.from_dict``.
    prior_predictive : dict
        Every value gets a new leading axis.
    posterior_predictive : dict
        Forwarded unchanged to ``az.from_dict``.
    inplace : bool
        When true, ``trace`` is an InferenceData and ``posterior_predictive``
        is a dict, the posterior predictive group is added to ``trace`` and
        that sum is returned.

    Returns
    -------
    ArviZ's InferenceData object
    """
    # isinstance(None, dict) is False, so no separate None check is needed.
    if isinstance(trace, dict):
        trace = {
            k: np.swapaxes(v.numpy(), 1, 0) for k, v in trace.items() if "/" in k
        }
    if isinstance(sample_stats, dict):
        sample_stats = {k: v.numpy().T for k, v in sample_stats.items()}
    if isinstance(prior_predictive, dict):
        prior_predictive = {k: v[np.newaxis] for k, v in prior_predictive.items()}
    if isinstance(posterior_predictive, dict):
        if inplace and isinstance(trace, az.InferenceData):
            return trace + az.from_dict(posterior_predictive=posterior_predictive)
        # Otherwise the posterior is deliberately left out of the result.
        trace = None
    return az.from_dict(
        posterior=trace,
        sample_stats=sample_stats,
        prior_predictive=prior_predictive,
        posterior_predictive=posterior_predictive,
        observed_data=observed_data,
    )
def test_inference_data_edge_cases(self):
    """from_dict accepts a ``log_likelihood`` variable and ``dims=None``."""
    # create data
    edge_case_vars = {
        "y": np.random.randn(4, 100),
        "log_likelihood": np.random.randn(4, 100, 8),
    }

    # log_likelihood to posterior
    as_posterior = from_dict(posterior=edge_case_vars)
    assert as_posterior is not None

    # dims == None
    as_observed = from_dict(observed_data=edge_case_vars, dims=None)
    assert as_observed is not None
def test_inference_data_edge_cases(self):
    """from_dict warns on ``log_likelihood`` in posterior and accepts ``dims=None``."""
    # create data
    edge_case_vars = {
        "y": np.random.randn(4, 100),
        "log_likelihood": np.random.randn(4, 100, 8),
    }

    # log_likelihood to posterior
    with pytest.warns(UserWarning, match="log_likelihood.+in posterior"):
        as_posterior = from_dict(posterior=edge_case_vars)
        assert as_posterior is not None

    # dims == None
    as_observed = from_dict(observed_data=edge_case_vars, dims=None)
    assert as_observed is not None
def test_inference_concat_keeps_all_fields():
    """Regression test for the failures observed in issue #907.

    Concatenating in either argument order must keep every group of both inputs.
    """
    sampled = from_dict(posterior={"A": [1, 2, 3, 4]}, sample_stats={"B": [2, 3, 4, 5]})
    prior_and_obs = from_dict(prior={"C": [1, 2, 3, 4]}, observed_data={"D": [2, 3, 4, 5]})

    forward = concat(sampled, prior_and_obs)
    backward = concat(prior_and_obs, sampled)

    expected = {
        "posterior": ["A"],
        "sample_stats": ["B"],
        "prior": ["C"],
        "observed_data": ["D"],
    }
    assert not check_multiple_attrs(expected, forward)
    assert not check_multiple_attrs(expected, backward)
def test_sel_method(inplace):
    """``InferenceData.sel`` subsets every group, in place or on a new object."""
    data = np.random.normal(size=(4, 500, 8))

    def chain_draw_vars():
        # A fresh dict per group, as each group is handed its own mapping.
        return {"a": data[..., 0], "b": data}

    idata = from_dict(
        posterior=chain_draw_vars(),
        sample_stats=chain_draw_vars(),
        observed_data={"b": data[0, 0, :]},
        posterior_predictive=chain_draw_vars(),
    )
    groups_before = getattr(idata, "_groups")
    ndraws = idata.posterior.draw.values.size
    selection = {
        "draw": slice(200, None),
        "chain": slice(None, None, 2),
        "b_dim_0": [1, 2, 7],
    }

    if inplace:
        idata.sel(inplace=inplace, **selection)
    else:
        selected = idata.sel(inplace=inplace, **selection)
        assert selected is not idata
        idata = selected

    groups_after = getattr(idata, "_groups")
    assert np.all(np.isin(groups_after, groups_before))

    for group in groups_after:
        dataset = getattr(idata, group)
        assert "b_dim_0" in dataset.dims
        assert np.all(dataset.b_dim_0.values == np.array(selection["b_dim_0"]))
        if group == "observed_data":
            # chain/draw are only checked on the other groups.
            continue
        assert np.all(np.isin(["chain", "draw"], dataset.dims))
        assert np.all(dataset.chain.values == np.arange(0, 4, 2))
        assert np.all(dataset.draw.values == np.arange(200, ndraws))
def to_inference_object(self) -> az.InferenceData:
    """Convert fitted Stan model into ``arviz`` InferenceData object.

    If ``self.fit`` already is an InferenceData object it is returned as is.
    When ``self.include_observed_data`` is set, an ``observed_data`` group
    built from ``self.dat["y"]`` is concatenated onto the result.

    :returns: ``arviz`` InferenceData object with selected values
    :rtype: az.InferenceData
    :raises ValueError: if the model has not been fit or not been specified
    """
    if self.fit is None:
        raise ValueError("Model has not been fit!")

    # if already Inference, just return
    if isinstance(self.fit, az.InferenceData):
        return self.fit

    if not self.specified:
        raise ValueError("Model has not been specified!")

    inference = single_feature_fit_to_inference(
        fit=self.fit,
        params=self.params,
        coords=self.coords,
        dims=self.dims,
        posterior_predictive=self.posterior_predictive,
        log_likelihood=self.log_likelihood,
        **self.specifications
    )
    if not self.include_observed_data:
        return inference

    observed = az.from_dict(
        observed_data={"observed": self.dat["y"]},
        coords={"tbl_sample": self.sample_names},
        dims={"observed": ["tbl_sample"]},
    )
    return az.concat(inference, observed)
def _run_corner(pandas=False, arviz=False, N=10000, seed=1234, ndim=3,
                factor=None, **kwargs):
    """Draw a seeded two-component random sample and pass it to ``corner.corner``.

    The sample is handed over as a plain array, as a pandas DataFrame
    (``pandas=True``) or as an ArviZ InferenceData (``arviz=True``); in the
    ArviZ case random ``truths`` are added to ``kwargs``. Returns the figure.
    """
    np.random.seed(seed)

    # The order of the random draws below is kept fixed for reproducibility.
    bulk = np.random.randn(ndim * 4 * N // 5).reshape([4 * N // 5, ndim])
    shift = 5 * np.random.rand(ndim)[None, :]
    shifted = shift + np.random.randn(ndim * N // 5).reshape([N // 5, ndim])
    data = np.vstack([bulk, shifted])

    if factor is not None:
        data[:, 0] *= factor
        data[:, 1] /= factor

    if pandas:
        column_names = ["d{0}".format(i) for i in range(ndim)]
        data = pd.DataFrame.from_dict(OrderedDict(zip(column_names, data.T)))
    elif arviz:
        chained = data[None]
        data = az.from_dict(
            posterior={"x": chained},
            sample_stats={"diverging": chained[:, :, 0] < 0.0},
        )
        kwargs["truths"] = {"x": np.random.randn(ndim)}

    return corner.corner(data, **kwargs)
def fit_svi(model, n_draws=1000, autoguide=AutoLaplaceApproximation, loss=Trace_ELBO(), optim=optim.Adam(step_size=.00001), num_warmup=2000, use_gpu=False, num_chains=1, progress_bar=False, sampler=None, **kwargs):
    """Fit ``model`` with SVI and draw samples from the fitted guide.

    ``num_warmup`` is used as the number of SVI steps. ``sampler`` is accepted
    but not used. Extra ``kwargs`` are forwarded to the ``SVI`` constructor.
    Both random keys are fixed (``PRNGKey(0)`` for fitting, ``PRNGKey(1)`` for
    sampling).

    Returns a tuple ``(trace, post)``: the ArviZ object built from the samples
    and the raw samples themselves.
    """
    select_device(use_gpu, num_chains)

    guide = autoguide(model)
    engine = SVI(model=model, guide=guide, loss=loss, optim=optim, **kwargs)

    fit_key = jax.random.PRNGKey(0)
    fitted = engine.run(
        fit_key,
        num_steps=num_warmup,
        stable_update=True,
        progress_bar=progress_bar,
    )

    # One "chain" of n_draws samples, hence the (1, n_draws) sample shape.
    draw_key = jax.random.PRNGKey(1)
    post = guide.sample_posterior(
        draw_key,
        params=fitted.params,
        sample_shape=(1, n_draws),
    )
    return az.from_dict(post), post
def get_inference_data(self, data, eight_schools_params, save_warmup=False):
    """Build an InferenceData whose sampled groups all reuse ``data.obj``.

    ``eight_schools_params`` becomes the ``observed_data`` group and
    ``save_warmup`` is forwarded to ``from_dict``.
    """
    group_names = (
        "posterior",
        "posterior_predictive",
        "sample_stats",
        "prior",
        "prior_predictive",
        "sample_stats_prior",
        "warmup_posterior",
        "warmup_posterior_predictive",
        "predictions",
    )
    groups = {name: data.obj for name in group_names}
    school_vars = ("theta", "eta")
    return from_dict(
        **groups,
        observed_data=eight_schools_params,
        coords={"school": np.arange(8)},
        pred_coords={"school_pred": np.arange(8)},
        dims={name: ["school"] for name in school_vars},
        pred_dims={name: ["school_pred"] for name in school_vars},
        save_warmup=save_warmup,
    )
def test_to_dict(self, models):
    """Round-tripping through ``to_dict``/``from_dict`` preserves every group."""
    original = models.model_1
    rebuilt = from_dict(**original.to_dict())
    assert rebuilt
    for group in original._groups_all:  # pylint: disable=protected-access
        expected_ds = getattr(original, group)
        rebuilt_ds = getattr(rebuilt, group)
        assert expected_ds.equals(rebuilt_ds)
def test_from_pymc_trace_inference_data(self):
    """Passing an InferenceData object as ``trace`` must raise ValueError."""
    posterior_vars = {
        "A": np.random.randn(2, 10, 2),
        "B": np.random.randn(2, 10, 5, 2),
    }
    not_a_trace = from_dict(posterior=posterior_vars)
    assert isinstance(not_a_trace, InferenceData)

    with pytest.raises(ValueError):
        from_pymc3(trace=not_a_trace, model=pm.Model())
def get_data(self) -> az.InferenceData:
    """Get inference data.

    Assembles the stored samples into one InferenceData object, copies the
    sample metadata onto ``sample_stats`` when present, adds a
    ``weighted_posterior`` group for the ``"nested"`` method, and tags every
    variable with ``unit``, ``symbol`` and ``is_circular`` attributes.

    Returns:
        arviz.InferenceData: Inference data.
    """
    dims, coords = self._get_dims()

    # A missing error estimate is represented by zeros shaped like ``nu``.
    nu_err = np.zeros_like(self.nu) if self.nu_err is None else self.nu_err

    data = az.from_dict(
        posterior=self.samples,
        prior_predictive=self.prior_predictive_samples,
        posterior_predictive=self.predictive_samples,
        sample_stats=self.sample_stats,
        observed_data={"nu": self.nu},
        constant_data={"nu_err": nu_err},
        dims=dims,
        coords=coords,
    )

    # Add sample metadata info
    if self.sample_stats is not None:
        data.sample_stats.attrs.update(self.sample_metadata)

    # NOTE(review): this check is treated as independent of the one above;
    # confirm against the original layout.
    if self.sample_metadata.get("method", None) == "nested":
        # The weights are just the logP in sampler stats
        with warnings.catch_warnings():
            # Silence user warnings while the extra group is added.
            warnings.filterwarnings("ignore", category=UserWarning)
            data.add_groups(
                {"weighted_posterior": self.weighted_samples},
                coords=coords,
                dims=dims,
            )

    circ_var_names = self.get_circ_var_names()

    # Add unit, symbol and circular attributes to groups
    for group in data.groups():
        for key in data[group].keys():
            # Variables ending in "_pred" share metadata with their base name.
            sub_key = key[:-5] if key.endswith("_pred") else key
            unit = self.model.units.get(sub_key, u.Unit())
            sym = self.model.symbols.get(sub_key, "")
            circ = 1 if sub_key in circ_var_names else 0
            data[group][key].attrs.update(
                unit=unit.to_string(),
                symbol=sym,
                is_circular=circ,
            )
    return data
def test_concat_bad():
    """``concat`` raises TypeError for bad inputs and mismatched dim merges."""
    with pytest.raises(TypeError):
        concat("hello", "hello")

    two_vars = from_dict(
        posterior={"A": np.random.randn(2, 10, 2), "B": np.random.randn(2, 10, 5, 2)}
    )
    one_var = from_dict(posterior={"A": np.random.randn(2, 10, 2)})
    other_group = from_dict(prior={"A": np.random.randn(2, 10, 2)})

    with pytest.raises(TypeError):
        concat(two_vars, np.array([1, 2, 3, 4, 5]))
    with pytest.raises(TypeError):
        concat(two_vars, two_vars, dim=None)

    # Along a dimension, differing variables or groups are rejected either way round.
    mismatched_pairs = (
        (two_vars, one_var),
        (one_var, two_vars),
        (two_vars, other_group),
        (other_group, two_vars),
    )
    for first, second in mismatched_pairs:
        with pytest.raises(TypeError):
            concat(first, second, dim="chain")
def get_trace_stats(
    self,
    trace,
    statnames=(
        "log_likelihood",
        "tree_size",
        "diverging",
        "energy",
        "mean_tree_accept",
    ),
):
    """Wrap sampler statistics in an ArviZ ``sample_stats`` group.

    Parameters
    ----------
    trace : iterable
        Statistics, paired positionally with ``statnames``. Each entry must
        provide ``.numpy()``; the resulting array is transposed.
    statnames : sequence of str
        Names given to the entries of ``trace``. The default is an immutable
        tuple so it cannot be altered between calls.

    Returns
    -------
    InferenceData containing only ``sample_stats``. Pairing stops at the
    shorter of ``statnames`` and ``trace``.
    """
    stats = {name: values.numpy().T for name, values in zip(statnames, trace)}
    return az.from_dict(sample_stats=stats)
def test_concat_dim(dim, copy, inplace, sequence, reset_dim):
    """Concatenate three identically shaped InferenceData objects along ``dim``.

    Checks that all groups survive, that ``inplace`` controls whether the
    first input is the result, and that the concatenated dimension has the
    expected length (and, with ``reset_dim``, a fresh 0..n-1 index).
    """

    def make_idata():
        return from_dict(
            posterior={"A": np.random.randn(2, 10, 2), "B": np.random.randn(2, 10, 5, 2)},
            observed_data={"C": np.random.randn(100), "D": np.random.randn(2, 100)},
        )

    idata1 = make_idata()
    if inplace:
        original_idata1_id = id(idata1)
    idata2 = make_idata()
    idata3 = make_idata()

    # basic case
    assert (
        concat(idata1, idata2, dim=dim, copy=copy, inplace=False, reset_dim=reset_dim)
        is not None
    )

    if sequence:
        new_idata = concat(
            (idata1, idata2, idata3), copy=copy, dim=dim, inplace=inplace, reset_dim=reset_dim
        )
    else:
        new_idata = concat(
            idata1, idata2, idata3, dim=dim, copy=copy, inplace=inplace, reset_dim=reset_dim
        )
    if inplace:
        assert new_idata is None
        new_idata = idata1
    assert new_idata is not None

    test_dict = {"posterior": ["A", "B"], "observed_data": ["C", "D"]}
    fails = check_multiple_attrs(test_dict, new_idata)
    assert not fails

    if inplace:
        assert id(new_idata) == original_idata1_id
    else:
        assert id(new_idata) != id(idata1)

    # Three inputs of 2 chains x 10 draws each. The expected size is computed
    # first: writing ``assert size == 6 if ... else 30`` would assert the
    # constant 30 in the non-chain case and check nothing.
    expected_size = 6 if dim == "chain" else 30
    assert getattr(new_idata.posterior, dim).size == expected_size
    if reset_dim:
        assert np.all(getattr(new_idata.posterior, dim).values == np.arange(expected_size))
def sample_numpyro_nuts(
    log_posterior_fun,
    flat_init_params,
    parameter_summary,
    constrain_fun_dict=None,
    target_accept=0.8,
    draws=1000,
    tune=1000,
    chains=4,
    progress_bar=True,
    random_seed=10,
    chain_method="parallel",
    thinning=1,
):
    """Sample a flat-parameter log posterior with NumPyro NUTS.

    Parameters
    ----------
    log_posterior_fun : callable
        Log posterior of the flat parameter vector; its negative is used as
        the NUTS potential function.
    flat_init_params
        Initial parameters handed to ``MCMC.run`` as ``init_params``.
    parameter_summary
        Passed to ``reconstruct`` to rebuild the structured parameters.
    constrain_fun_dict : dict, optional
        Passed to ``apply_constraints``; ``None`` means an empty dict.
    target_accept, draws, tune, chains, progress_bar, chain_method, thinning
        Forwarded to the NumPyro ``NUTS`` kernel / ``MCMC`` driver.
    random_seed : int
        Seed for ``jax.random.PRNGKey``.

    Returns
    -------
    InferenceData whose posterior holds the reshaped, constrained samples.
    """
    # Strongly inspired by:
    # https://github.com/pymc-devs/pymc3/blob/master/pymc3/sampling_jax.py#L116
    if constrain_fun_dict is None:
        # A fresh dict per call avoids sharing one mutable default.
        constrain_fun_dict = {}

    def _sample(current_state, seed):
        nuts_kernel = NUTS(
            potential_fn=lambda x: -log_posterior_fun(x),
            target_accept_prob=target_accept,
            adapt_step_size=True,
            adapt_mass_matrix=True,
            dense_mass=False,
        )
        numpyro = MCMC(
            nuts_kernel,
            num_warmup=tune,
            num_samples=draws,
            num_chains=chains,
            postprocess_fn=None,
            progress_bar=progress_bar,
            chain_method=chain_method,
            thinning=thinning,
        )
        numpyro.run(seed, init_params=current_state)
        return numpyro.get_samples(group_by_chain=True)

    seed = jax.random.PRNGKey(random_seed)
    samples = _sample(flat_init_params, seed)

    # Reshape this into a dict
    def reshape_single_draw(x):
        structured = reconstruct(x, parameter_summary, jnp.reshape)
        return apply_constraints(structured, constrain_fun_dict)[0]

    def reshape_single_chain(theta):
        return vmap(reshape_single_draw)(theta)

    samples = vmap(reshape_single_chain)(samples)
    return az.from_dict(posterior=samples)
def test_potentials_warning(self):
    """Posterior predictive sampling warns when the model has a Potential."""
    expected_warning = "The effect of Potentials on other parameters is ignored during"

    with pm.Model() as model:
        a = pm.Normal("a", 0, 1)
        # Registered in the model context; the return values are not needed.
        pm.Potential("p", a + 1)
        pm.Normal("obs", a, 1, observed=5)

    trace = az.from_dict({"a": np.random.rand(10)})
    with pytest.warns(UserWarning, match=expected_warning):
        pm.sample_posterior_predictive_w(
            samples=5,
            traces=[trace, trace],
            models=[model, model],
        )
def sample(self, n: int = 500) -> az.InferenceData:
    """Generate samples from posterior distribution.

    Draws ``n`` samples from the approximation, adds the deterministics and
    returns them, together with the observed values, as InferenceData.
    """
    raw = self.approx.sample(n)
    split = self.order.split_samples(raw, n)
    with_deterministics = dict(**split, **self.deterministics_callback(split))

    # Add a new axis so as n_chains=1 for InferenceData: handles shape issues
    posterior = {}
    for name, value in with_deterministics.items():
        posterior[name] = value.numpy()[np.newaxis]

    return az.from_dict(posterior, observed_data=self.state.observed_values)
def test_concat_bad():
    """``concat`` rejects non-InferenceData inputs and duplicate groups."""
    with pytest.raises(TypeError):
        concat("hello", "hello")

    posterior_vars = {
        "A": np.random.randn(2, 10, 2),
        "B": np.random.randn(2, 10, 5, 2),
    }
    idata = from_dict(posterior=posterior_vars)

    not_idata = np.array([1, 2, 3, 4, 5])
    with pytest.raises(TypeError):
        concat(idata, not_idata)

    with pytest.raises(NotImplementedError):
        concat(idata, idata)
def get_inference_data(self, data, eight_schools_params):
    """Build an InferenceData whose sampled groups all reuse ``data.obj``.

    ``eight_schools_params`` becomes the ``observed_data`` group and its
    ``"J"`` entry sets the length of the ``school`` coordinate.
    """
    group_names = (
        "posterior",
        "posterior_predictive",
        "sample_stats",
        "prior",
        "prior_predictive",
        "sample_stats_prior",
    )
    groups = {name: data.obj for name in group_names}
    school_coord = np.arange(eight_schools_params["J"])
    return from_dict(
        **groups,
        observed_data=eight_schools_params,
        coords={"school": school_coord},
        dims={"theta": ["school"], "eta": ["school"]},
    )
def plot_pair_evolution(params, mcmc_kernel):
    """Overlay pair plots of every saved iteration, coloured by iteration.

    Reads each ``./results/output_it<N>.<ext>`` pickle (sorted by ``N``),
    expecting a ``(samples, stats)`` pair, and saves the combined figure to
    ``./results/pair_plot_evo.png``.

    Parameters
    ----------
    params : iterable
        Objects with a ``name`` attribute, paired positionally with the
        pickled ``samples``.
    mcmc_kernel : str
        ``"hmc"`` or ``"nuts"``; selects the names paired positionally with
        the pickled ``stats``.

    Raises
    ------
    ValueError
        If ``mcmc_kernel`` is not supported (previously a NameError on the
        unbound statistics names).
    IndexError
        If no matching result file exists.
    """
    stats_names_by_kernel = {
        "hmc": ["logprob", "diverging", "acceptance", "step_size"],
        "nuts": [
            "logprob",
            "tree_size",
            "diverging",
            "energy",
            "acceptance",
            "mean_tree_accept",
        ],
    }
    if mcmc_kernel not in stats_names_by_kernel:
        raise ValueError(f"Unsupported mcmc_kernel: {mcmc_kernel!r}")
    # Neither list depends on the file being read, so build them once.
    stats_names = stats_names_by_kernel[mcmc_kernel]
    var_names = [p.name for p in params]

    # File names look like "output_it<N>" plus a 4-character extension.
    files = sorted(
        (name for name in os.listdir("./results") if name.startswith("output_it")),
        key=lambda x: int(x[9:-4]),
    )

    arvzs, cs = [], []
    for position, name in enumerate(files, start=1):
        with open(f"./results/{name}", "rb") as obj:
            # NOTE: pickle.load executes arbitrary code from the file; only
            # use this on result files produced by a trusted run.
            samples, stats = pickle.load(obj)
        arvzs.append(
            az.from_dict(
                posterior=dict(zip(var_names, samples)),
                sample_stats=dict(zip(stats_names, stats)),
            )
        )
        # Colour position in (0, 1]: later iterations map further along.
        cs.append(position / len(files))

    ax = az.plot_pair(
        arvzs[0],
        kind="scatter",
        marginals=True,
        marginal_kwargs={"color": cm.hot_r(cs[0])},
        scatter_kwargs={"c": cm.hot_r(cs[0])},
    )
    # The first dataset is already on the axes; draw only the remaining ones.
    for arvz, c in zip(arvzs[1:], cs[1:]):
        az.plot_pair(
            arvz,
            kind="scatter",
            marginals=True,
            marginal_kwargs={"color": cm.hot_r(c)},
            scatter_kwargs={"c": cm.hot_r(c)},
            ax=ax,
        )
    fig = ax.ravel()[0].figure
    fig.savefig("./results/pair_plot_evo.png")
def convert_to_arviz(chains, model, burnin, remove_stuck=False, iparas_time=None, phy_space=False):
    """Convert sampler output into an ArviZ InferenceData object.

    ``chains`` unpacks to ``(rands, chain, lnprob, lnprob_i)``. The first
    ``burnin`` samples are dropped and, with ``remove_stuck``, so are chains
    whose remaining minimum ``lnprob`` is not above ``-1e10``.

    With ``iparas_time`` set, the implicit parameters are evaluated for every
    sample and appended to the parameters, which are transformed to physical
    space. Otherwise the parameters are transformed only when ``phy_space``
    is true.

    Returns a tuple ``(inference_data, names)`` where ``names`` lists the
    labels of the last axis of posterior variable ``'a'``.
    """
    rands, chain, lnprob, lnprob_i = chains
    para_names = model.parameter_names

    def _to_physical(raw_chain):
        # Split per parameter, transform, and restack with parameters last.
        per_param = {
            name: raw_chain[:, :, i] for i, name in enumerate(model.parameter_names)
        }
        physical = model.transform_fit_to_physical(per_param, mode='bayes')
        stacked = np.array([physical[name] for name in model.parameter_names])
        return np.transpose(stacked, axes=(1, 2, 0))

    if remove_stuck:
        n_before = lnprob.shape[0]
        keep = np.where(lnprob[:, burnin:].min(axis=1) > -1e10)[0]
        chain = chain[keep, burnin:]
        lnprob_i = lnprob_i[burnin:, keep]
        lnprob = lnprob[keep, burnin:]
        print(n_before - len(keep), ' chains are stuck')
    else:
        chain = chain[:, burnin:]
        lnprob_i = lnprob_i[burnin:, :]
        lnprob = lnprob[:, burnin:]

    sample_stats = {
        'log_likelihood': np.transpose(lnprob_i, axes=(1, 0, 2)),
        'loglike_values': lnprob,
    }

    if iparas_time is None:
        if phy_space:
            chain = _to_physical(chain)
        idata = az.from_dict(
            posterior={'a': chain},
            sample_stats=sample_stats,
            dims={'a': ['ac']},
            coords={'ac': para_names},
        )
        return idata, list(para_names)

    shape = chain.shape
    n_chains, n_samples, n_params = shape[0], shape[1], shape[2]
    iparas_name = list(model.calc_implicit_parameters(iparas_time).keys())
    extended = np.zeros((n_chains, n_samples, n_params + len(iparas_name)))

    # The implicit parameters are computed from the untransformed samples.
    for c, s in np.ndindex(n_chains, n_samples):
        model.set_parameters_fit_array(chain[c, s], mode='bayes')
        extended[c, s, n_params:] = [
            model.calc_implicit_parameters(iparas_time)[name] for name in iparas_name
        ]

    extended[:, :, :n_params] = _to_physical(chain)
    idata = az.from_dict(
        posterior={'a': extended},
        sample_stats=sample_stats,
        dims={'a': ['ac']},
        coords={'ac': list(para_names) + iparas_name},
    )
    return idata, list(para_names) + iparas_name
def simulate_fixed(fixed_value, n_chains, n_samples, model):
    """Run the slice-Gibbs sampler with the last parameter held fixed.

    Sets ``model.t_init[-1]`` to ``fixed_value``, draws ``n_samples`` samples
    for each of ``n_chains`` chains, drops the last (fixed) parameter column
    and summarises the rest with ArviZ.

    Returns
    -------
    tuple
        ``(samples, summary)`` where ``samples`` has the fixed column removed
        and ``summary`` has an extra ``'sampling_effeciency'`` column:
        ``ess_mean`` divided by ``n_chains * n_samples``.
    """
    samples = np.zeros((n_chains, n_samples, model.n_params))
    model.t_init[-1] = fixed_value
    for i in range(n_chains):
        samples[i] = model.sample_slice_gibbs5(n_samples)
    samples = samples[:, :, :-1]

    az_dict = to_arviz_dict(
        samples,
        {"$\\beta$": np.arange(4), "$\\theta$": np.arange(4, 7)},
        burnin=1000,
    )
    az_data = az.from_dict(az_dict)
    summary = az.summary(az_data)

    # np.prod replaces np.product, which was removed in NumPy 2.0.
    total_draws = np.prod(samples.shape[:2])
    summary['sampling_effeciency'] = summary['ess_mean'] / total_draws
    return samples, summary
def test_to_dict_warmup(self):
    """The ``to_dict``/``from_dict`` round trip also preserves warmup groups."""
    groups = [
        "posterior",
        "sample_stats",
        "observed_data",
        "warmup_posterior",
        "warmup_posterior_predictive",
    ]
    original = create_data_random(groups=groups)
    rebuilt = from_dict(**original.to_dict(), save_warmup=True)
    assert rebuilt
    for group in original._groups_all:  # pylint: disable=protected-access
        expected_ds = getattr(original, group)
        rebuilt_ds = getattr(rebuilt, group)
        assert expected_ds.equals(rebuilt_ds)
def mcmc_diagnostic_plots(posterior, sample_stats, it):
    """Save trace, posterior, autocorrelation and ESS plots for iteration ``it``.

    Each figure is written to ``./results/<kind>_plot_it<it>.png`` and the
    current figure is cleared afterwards.
    """
    az_trace = az.from_dict(posterior=posterior, sample_stats=sample_stats)

    def _save_and_clear(axes, path):
        # All axes belong to one figure; take it from the first one.
        axes.ravel()[0].figure.savefig(path)
        plt.clf()

    # The pair plots (hexbin, scatter/kde) are intentionally not produced here.

    _save_and_clear(
        az.plot_trace(az_trace, divergences=False),
        f"./results/trace_plot_it{it}.png",
    )
    _save_and_clear(
        az.plot_posterior(az_trace),
        f"./results/posterior_plot_it{it}.png",
    )

    # Cap the lag at 100, or at the length of the first posterior entry.
    first_entry = list(posterior.values())[0]
    lag = np.minimum(len(first_entry), 100)
    _save_and_clear(
        az.plot_autocorr(az_trace, max_lag=lag),
        f"./results/autocorr_plot_it{it}.png",
    )
    _save_and_clear(
        az.plot_ess(az_trace, kind="evolution"),
        f"./results/ess_evolution_plot_it{it}.png",
    )
    plt.close()
def test_del(self, use):
    """Deleting a group via ``del`` or ``delattr`` removes it everywhere."""
    # create inference data object
    data = np.random.normal(size=(4, 500, 8))

    def chain_draw_vars():
        # A fresh dict per group, as each group is handed its own mapping.
        return {"a": data[..., 0], "b": data}

    idata = from_dict(
        posterior=chain_draw_vars(),
        sample_stats=chain_draw_vars(),
        observed_data={"b": data[0, 0, :]},
        posterior_predictive=chain_draw_vars(),
    )

    # assert inference data object has all attributes
    expected = {
        "posterior": ("a", "b"),
        "sample_stats": ("a", "b"),
        "observed_data": ["b"],
        "posterior_predictive": ("a", "b"),
    }
    assert not check_multiple_attrs(expected, idata)

    # assert _groups attribute contains all groups
    known_groups = getattr(idata, "_groups")
    assert all(group in known_groups for group in expected)

    # Use del method
    if use == "del":
        del idata.sample_stats
    else:
        delattr(idata, "sample_stats")

    # assert attribute has been removed
    expected.pop("sample_stats")
    assert not check_multiple_attrs(expected, idata)
    assert not hasattr(idata, "sample_stats")

    # assert _groups attribute has been updated
    assert "sample_stats" not in getattr(idata, "_groups")
def sample(self, n: int = 500, include_log_likelihood: bool = False) -> az.InferenceData:
    """Generate samples from posterior distribution.

    Draws ``n`` samples from the approximation, adds the deterministics and
    returns them, with the observed values, as InferenceData. With
    ``include_log_likelihood`` a log likelihood group is computed as well.
    """
    raw = self.approx.sample(n)
    split = self.order.split_samples(raw, n)
    with_deterministics = dict(**split, **self.deterministics_callback(split))

    # Add a new axis so as n_chains=1 for InferenceData: handles shape issues
    posterior = {}
    for name, value in with_deterministics.items():
        posterior[name] = value.numpy()[np.newaxis]

    if include_log_likelihood:
        log_likelihood = calculate_log_likelihood(self.model, posterior, self.state)
    else:
        log_likelihood = None

    return az.from_dict(
        posterior,
        observed_data=self.state.observed_values,
        log_likelihood=log_likelihood,
    )