def test_clear_data_home():
    """Fetching "test_remote" creates its file; clear_data_home() deletes it again."""
    target_path = REMOTE_DATASETS["test_remote"].filename
    # Precondition: nothing cached yet.
    assert not os.path.exists(target_path)

    load_arviz_data("test_remote")
    assert os.path.exists(target_path)

    # Wipe the directory holding the cached file and confirm it is gone.
    clear_data_home(data_home=os.path.dirname(target_path))
    assert not os.path.exists(target_path)
def test_multi_model():
    """Check that ``multi_arviz_to_json`` bundles one ``<key>.npz`` per model into a zip.

    Each extracted member is then validated with ``check_zip``.
    """
    models = {
        "centered": az.load_arviz_data("centered_eight"),
        "noncentered": az.load_arviz_data("non_centered_eight"),
    }
    multi_arviz_to_json(models, "multimodel.zip")
    # Verify the zip file is good. ZipFile is used as a context manager so the
    # archive handle is closed even when an assertion below fails.
    with open("multimodel.zip", "rb") as f, zipfile.ZipFile(f) as z:
        elements = z.namelist()
        for key in models:
            npz_name = key + ".npz"
            assert npz_name in elements
            z.extract(npz_name)
            check_zip(npz_name)
def test_id_conversion_args(self):
    """Rebuild a plain dict from a stored posterior and convert it with custom dims/coords."""
    stored = load_arviz_data("centered_eight")
    IVIES = [
        "Yale",
        "Harvard",
        "MIT",
        "Princeton",
        "Cornell",
        "Dartmouth",
        "Columbia",
        "Brown",
    ]
    # test dictionary argument...
    # Reverse-engineer a {var_name: ndarray} dictionary from the centered_eight
    # posterior: to_dict() yields nested lists (one list per chain), which are
    # turned into arrays and stacked back along a leading chain axis.
    data_vars = stored.posterior.to_dict()["data_vars"]
    test_dict = {  # type: Dict[str, np.ndarray]
        var_name: np.stack([np.array(chain) for chain in spec["data"]])
        for var_name, spec in data_vars.items()
    }

    inference_data = convert_to_inference_data(
        test_dict, dims={"theta": ["Ivies"]}, coords={"Ivies": IVIES}
    )
    assert isinstance(inference_data, InferenceData)
    assert set(inference_data.posterior.coords["Ivies"].values) == set(IVIES)
    assert inference_data.posterior["theta"].dims == ("chain", "draw", "Ivies")
def test_arviz_to_json():
    """``arviz_to_json`` accepts both a filename and an open binary file object."""
    import arviz as az

    idata = az.load_arviz_data("centered_eight")

    # Destination given as a path.
    arviz_to_json(idata, "centered_eight.npz")
    check_zip("centered_eight.npz")

    # Destination given as a file object opened for binary writing.
    with open("centered_eight_as_f.npz", "wb") as handle:
        arviz_to_json(idata, handle)
    check_zip("centered_eight_as_f.npz")
def test_load_local_arviz_data():
    """The "centered_eight" dataset loads with the expected school coords and theta dims."""
    expected_schools = {
        "Hotchkiss",
        "Mt. Hermon",
        "Choate",
        "Deerfield",
        "Phillips Andover",
        "St. Paul's",
        "Lawrenceville",
        "Phillips Exeter",
    }
    idata = load_arviz_data("centered_eight")

    assert isinstance(idata, InferenceData)
    observed_schools = idata.observed_data.obs.coords["school"].values
    assert set(observed_schools) == expected_schools
    assert idata.posterior["theta"].dims == ("chain", "draw", "school")
def test_dataset_conversion_idempotent(self):
    """Converting an existing posterior Dataset yields a Dataset with the same coords/dims."""
    expected_schools = {
        "Hotchkiss",
        "Mt. Hermon",
        "Choate",
        "Deerfield",
        "Phillips Andover",
        "St. Paul's",
        "Lawrenceville",
        "Phillips Exeter",
    }
    posterior = load_arviz_data("centered_eight").posterior
    data_set = convert_to_dataset(posterior)

    assert isinstance(data_set, xr.Dataset)
    assert set(data_set.coords["school"].values) == expected_schools
    assert data_set["theta"].dims == ("chain", "draw", "school")
def test_id_conversion_idempotent(self):
    """Passing an InferenceData to convert_to_inference_data preserves its structure."""
    expected_schools = {
        "Hotchkiss",
        "Mt. Hermon",
        "Choate",
        "Deerfield",
        "Phillips Andover",
        "St. Paul's",
        "Lawrenceville",
        "Phillips Exeter",
    }
    converted = convert_to_inference_data(load_arviz_data("centered_eight"))

    assert isinstance(converted, InferenceData)
    school_coords = converted.observed_data.obs.coords["school"].values
    assert set(school_coords) == expected_schools
    assert converted.posterior["theta"].dims == ("chain", "draw", "school")
def test_sel_method_chain_prior():
    """``sel(chain=...)`` with chain_prior=False subsets chains only in non-prior groups.

    With chain_prior=True the same selection is expected to raise KeyError.
    """
    idata = load_arviz_data("centered_eight")
    all_groups = getattr(idata, "_groups")
    subset = idata.sel(inplace=False, chain_prior=False, chain=[0, 1, 3])
    subset_groups = getattr(subset, "_groups")
    assert np.all(np.isin(subset_groups, all_groups))

    for name in subset_groups:
        sub_ds = getattr(subset, name)
        full_ds = getattr(idata, name)
        # Groups without a chain dimension must stay without one.
        if "chain" not in full_ds.dims:
            assert "chain" not in sub_ds.dims
            continue
        assert "chain" in sub_ds.dims
        # Prior groups are left untouched when chain_prior=False.
        if "prior" not in name:
            assert np.all(sub_ds.chain.values == np.array([0, 1, 3]))

    with pytest.raises(KeyError):
        idata.sel(inplace=False, chain_prior=True, chain=[0, 1, 3])
def test_repr_html(self):
    """Test if the function _repr_html is generating html.

    Renders ``idata`` with xarray's ``display_style`` set to "html" and then to
    "text", and checks the output in each mode. The original ``display_style``
    is restored in a ``finally`` block, so a failing assertion cannot leak the
    modified global option into later tests.
    """
    idata = load_arviz_data("centered_eight")
    display_style = OPTIONS["display_style"]
    try:
        xr.set_options(display_style="html")
        html = idata._repr_html_()  # pylint: disable=protected-access
        assert html is not None
        assert "<div" in html
        # Every group, and every variable within each group, must appear in the output.
        for group in idata._groups:  # pylint: disable=protected-access
            assert group in html
            xr_data = getattr(idata, group)
            for item, _ in xr_data.items():
                assert item in html
        specific_style = ".xr-wrap{width:700px!important;}"
        assert specific_style in html

        # In text mode the HTML repr is expected to embed the escaped plain repr.
        xr.set_options(display_style="text")
        html = idata._repr_html_()  # pylint: disable=protected-access
        assert escape(idata.__repr__()) in html
    finally:
        xr.set_options(display_style=display_style)
def test_map(self, use):
    """``InferenceData.map`` applies a function to selected groups only.

    ``use`` selects how the extra operand reaches the function: None (baked in),
    "args" (positional), or anything else (keyword ``a``).
    """
    idata = load_arviz_data("centered_eight")

    def add_three(x):
        return x + 3

    def add(x, a):
        return x + a

    args, kwargs = [], {}
    if use is None:
        fun = add_three
    elif use == "args":
        fun, args = add, [3]
    else:
        fun, kwargs = add, {"a": 3}

    groups = ("observed_data", "posterior_predictive")
    idata_map = idata.map(fun, groups, args=args, **kwargs)

    # Group layout is unchanged.
    assert idata_map._groups == idata._groups  # pylint: disable=protected-access
    # Mapped groups hold fun(original).
    expected_obs = fun(idata.observed_data.obs, *args, **kwargs)
    assert np.allclose(idata_map.observed_data.obs, expected_obs)
    expected_pp = fun(idata.posterior_predictive.obs, *args, **kwargs)
    assert np.allclose(idata_map.posterior_predictive.obs, expected_pp)
    # Groups not listed are left untouched.
    assert np.allclose(idata_map.posterior.mu, idata.posterior.mu)
def test_load_remote_arviz_data():
    """Loading the "test_remote" dataset returns a truthy object."""
    remote = load_arviz_data("test_remote")
    assert remote
""" ESS Local Plot ============== _thumb: .7, .5 """ import arviz as az az.style.use("arviz-darkgrid") idata = az.load_arviz_data("centered_eight") az.plot_ess(idata, var_names=["mu"], kind="local", marker="_", ms=20, mew=2)
""" Joint Plot ========== _thumb: .5, .8 """ import matplotlib.pyplot as plt import arviz as az az.style.use("arviz-darkgrid") data = az.load_arviz_data("non_centered_eight") az.plot_joint( data, var_names=["theta"], coords={"school": ["Choate", "Phillips Andover"]}, kind="hexbin", figsize=(10, 10), ) plt.show()
""" Autocorrelation Plot ==================== _thumb: .8, .8 """ import arviz as az az.style.use('arviz-darkgrid') data = az.load_arviz_data('centered_eight') az.plot_autocorr(data, var_names=('tau', 'mu'))
""" Violinplot ========== _thumb: .2, .8 """ import arviz as az az.style.use('arviz-darkgrid') non_centered = az.load_arviz_data('non_centered_eight') az.plot_violin(non_centered, var_names=["mu", "tau"], textsize=8)
def test_waic():
    """Test widely applicable information criterion (WAIC) calculation."""
    waic(load_arviz_data('centered_eight'))
def test_bad_checksum():
    """Loading the "bad_checksum" dataset must raise IOError (presumably a checksum mismatch)."""
    dataset_name = "bad_checksum"
    with pytest.raises(IOError):
        load_arviz_data(dataset_name)
def setup_class(cls):
    """Load the "centered_eight" and "non_centered_eight" datasets once for the class."""
    fixtures = (
        ("centered", 'centered_eight'),
        ("non_centered", 'non_centered_eight'),
    )
    for attr_name, dataset_name in fixtures:
        setattr(cls, attr_name, load_arviz_data(dataset_name))
""" Ridgeplot ========= _thumb: .8, .5 """ import matplotlib.pyplot as plt import arviz as az az.style.use("arviz-darkgrid") rugby_data = az.load_arviz_data("rugby") axes = az.plot_forest( rugby_data, kind="ridgeplot", var_names=["defs"], linewidth=4, combined=True, ridgeplot_overlap=1.5, colors="blue", figsize=(9, 4), ) axes[0].set_title("Relative defensive strength\nof Six Nation rugby teams") plt.show()
def test_summary_bad_fmt():
    """An unsupported ``fmt`` value makes summary() raise TypeError."""
    idata = load_arviz_data("centered_eight")
    invalid_fmt = "bad_fmt"
    with pytest.raises(TypeError):
        summary(idata, fmt=invalid_fmt)
def test_summary_fmt(fmt):
    """summary() runs without error for the given ``fmt`` value."""
    summary(load_arviz_data("centered_eight"), fmt=fmt)
def setup_class(cls):
    """Load the "centered_eight" and "non_centered_eight" datasets once for the class."""
    for attr_name, dataset_name in (
        ("centered", "centered_eight"),
        ("non_centered", "non_centered_eight"),
    ):
        setattr(cls, attr_name, load_arviz_data(dataset_name))
""" Quantile MCSE Errobar Plot ========================== _thumb: .6, .4 """ import arviz as az az.style.use('arviz-darkgrid') data = az.load_arviz_data('radon') az.plot_mcse(data, var_names=["sigma_a"], color="C4", errorbar=True)
def test_load_local_arviz_data():
    """Loading the "centered_eight" dataset returns a truthy object."""
    idata = load_arviz_data("centered_eight")
    assert idata
""" Bayesian p-value Posterior plot =============================== _thumb: .6, .5 """ import matplotlib.pyplot as plt import arviz as az az.style.use("arviz-darkgrid") data = az.load_arviz_data("regression1d") az.plot_bpv(data) plt.show()
""" Separationplot ========== _thumb: .2, .8 """ import matplotlib.pyplot as plt import arviz as az az.style.use("arviz-darkgrid") idata = az.load_arviz_data("classification10d") az.plot_separation(idata=idata, y="outcome", y_hat="outcome", figsize=(8, 1)) plt.show()
def test_missing_dataset():
    """Requesting an unknown dataset name raises ValueError."""
    unknown_name = "does not exist"
    with pytest.raises(ValueError):
        load_arviz_data(unknown_name)
""" ESS Quantile Plot ================= _thumb: .4, .5 """ import matplotlib.pyplot as plt import arviz as az az.style.use("arviz-darkgrid") idata = az.load_arviz_data("radon") az.plot_ess(idata, var_names=["sigma_y"], kind="quantile", color="C4") plt.show()
def test_summary(include_circ):
    """summary() runs without error for the given ``include_circ`` value."""
    summary(load_arviz_data('centered_eight'), include_circ=include_circ)
""" ELPD Plot ========= _thumb: .6, .5 """ import matplotlib.pyplot as plt import arviz as az az.style.use("arviz-darkgrid") d1 = az.load_arviz_data("centered_eight") d2 = az.load_arviz_data("non_centered_eight") az.plot_elpd({"Centered eight": d1, "Non centered eight": d2}, xlabels=True) plt.show()