Exemple #1
0
def test_clear_data_home():
    """Check that clear_data_home removes the cached "test_remote" file.

    The file must be absent before loading, present after loading, and
    absent again once its containing directory has been cleared.
    """
    cached_path = REMOTE_DATASETS["test_remote"].filename

    # Precondition: nothing cached yet.
    assert not os.path.exists(cached_path)

    load_arviz_data("test_remote")
    assert os.path.exists(cached_path)

    cache_dir = os.path.dirname(cached_path)
    clear_data_home(data_home=cache_dir)
    assert not os.path.exists(cached_path)
def test_multi_model():
    """Write two models into one zip archive and validate each member.

    For every model key, the archive must contain a ``<key>.npz`` entry,
    which is extracted to the working directory and passed to ``check_zip``.
    """
    models = {
        "centered": az.load_arviz_data("centered_eight"),
        "noncentered": az.load_arviz_data("non_centered_eight"),
    }
    multi_arviz_to_json(models, "multimodel.zip")
    # verify zip file is good
    # Both the raw file handle and the ZipFile are context-managed so the
    # archive is closed even when an assertion below fails.
    with open("multimodel.zip", "rb") as f, zipfile.ZipFile(f) as z:
        elements = z.namelist()
        for key in models:
            npz_name = key + ".npz"
            assert npz_name in elements
            z.extract(npz_name)
            check_zip(npz_name)
Exemple #3
0
    def test_id_conversion_args(self):
        """Convert a plain dict of arrays using explicit ``dims`` and ``coords``.

        The dict is rebuilt from the posterior group of the stored
        "centered_eight" dataset, then converted with a custom "Ivies"
        dimension attached to ``theta``.
        """
        ivies = [
            "Yale", "Harvard", "MIT", "Princeton", "Cornell", "Dartmouth",
            "Columbia", "Brown"
        ]
        stored = load_arviz_data("centered_eight")

        # Reverse engineer a {variable name: ndarray} dictionary out of the
        # stored posterior: each variable's "data" entry is iterated chain by
        # chain, every chain becomes an array, and the chains are stacked.
        posterior_vars = stored.posterior.to_dict()["data_vars"]
        test_dict = {
            var_name: np.stack([np.array(chain) for chain in var_info["data"]])
            for var_name, var_info in posterior_vars.items()
        }

        inference_data = convert_to_inference_data(
            test_dict,
            dims={"theta": ["Ivies"]},
            coords={"Ivies": ivies},
        )

        assert isinstance(inference_data, InferenceData)
        converted_ivies = inference_data.posterior.coords["Ivies"].values
        assert set(converted_ivies) == set(ivies)
        theta_dims = inference_data.posterior["theta"].dims
        assert theta_dims == ("chain", "draw", "Ivies")
def test_arviz_to_json():
    """Export "centered_eight" by filename and by open file object.

    Each resulting file is handed to ``check_zip`` for validation.
    """
    import arviz as az

    idata = az.load_arviz_data("centered_eight")

    # Target given as a filename.
    by_name = "centered_eight.npz"
    arviz_to_json(idata, by_name)
    check_zip(by_name)

    # Target given as an already-open binary file object.
    by_handle = "centered_eight_as_f.npz"
    with open(by_handle, "wb") as handle:
        arviz_to_json(idata, handle)
    check_zip(by_handle)
Exemple #5
0
def test_load_local_arviz_data():
    """Load "centered_eight" and check its type, school labels and theta dims."""
    expected_schools = {
        "Hotchkiss",
        "Mt. Hermon",
        "Choate",
        "Deerfield",
        "Phillips Andover",
        "St. Paul's",
        "Lawrenceville",
        "Phillips Exeter",
    }
    inference_data = load_arviz_data("centered_eight")
    assert isinstance(inference_data, InferenceData)

    schools = inference_data.observed_data.obs.coords["school"].values
    assert set(schools) == expected_schools

    theta_dims = inference_data.posterior["theta"].dims
    assert theta_dims == ("chain", "draw", "school")
Exemple #6
0
 def test_dataset_conversion_idempotent(self):
     """Pass an existing posterior group through convert_to_dataset.

     The result must be an ``xr.Dataset`` that still carries the school
     labels and the (chain, draw, school) dims of ``theta``.
     """
     expected_schools = {
         "Hotchkiss",
         "Mt. Hermon",
         "Choate",
         "Deerfield",
         "Phillips Andover",
         "St. Paul's",
         "Lawrenceville",
         "Phillips Exeter",
     }
     posterior = load_arviz_data("centered_eight").posterior
     data_set = convert_to_dataset(posterior)
     assert isinstance(data_set, xr.Dataset)

     schools = data_set.coords["school"].values
     assert set(schools) == expected_schools
     assert data_set["theta"].dims == ("chain", "draw", "school")
Exemple #7
0
 def test_id_conversion_idempotent(self):
     """Pass a loaded dataset through convert_to_inference_data.

     The result must be an ``InferenceData`` that still carries the school
     labels and the (chain, draw, school) dims of ``theta``.
     """
     expected_schools = {
         "Hotchkiss",
         "Mt. Hermon",
         "Choate",
         "Deerfield",
         "Phillips Andover",
         "St. Paul's",
         "Lawrenceville",
         "Phillips Exeter",
     }
     inference_data = convert_to_inference_data(load_arviz_data("centered_eight"))
     assert isinstance(inference_data, InferenceData)

     schools = inference_data.observed_data.obs.coords["school"].values
     assert set(schools) == expected_schools

     theta_dims = inference_data.posterior["theta"].dims
     assert theta_dims == ("chain", "draw", "school")
Exemple #8
0
def test_sel_method_chain_prior():
    """Exercise ``InferenceData.sel`` on ``chain`` with both ``chain_prior`` values.

    With ``chain_prior=False`` every group that has a chain dimension keeps
    it, and groups without "prior" in their name hold exactly the selected
    chains. With ``chain_prior=True`` the same selection must raise KeyError.
    """
    idata = load_arviz_data("centered_eight")
    all_groups = idata._groups  # pylint: disable=protected-access
    selected_chains = [0, 1, 3]

    subset = idata.sel(inplace=False, chain_prior=False, chain=selected_chains)
    subset_groups = subset._groups  # pylint: disable=protected-access
    assert np.all(np.isin(subset_groups, all_groups))

    for name in subset_groups:
        sub_ds = getattr(subset, name)
        full_ds = getattr(idata, name)
        if "chain" not in full_ds.dims:
            # Groups without a chain dimension must not gain one.
            assert "chain" not in sub_ds.dims
            continue
        assert "chain" in sub_ds.dims
        if "prior" not in name:
            assert np.all(sub_ds.chain.values == np.array([0, 1, 3]))

    with pytest.raises(KeyError):
        idata.sel(inplace=False, chain_prior=True, chain=[0, 1, 3])
Exemple #9
0
    def test_repr_html(self):
        """Test if the function _repr_html is generating html.

        With ``display_style="html"`` the output must contain a ``<div``, every
        group name, every item name of each group, and the ArviZ-specific
        style rule. With ``display_style="text"`` it must contain the escaped
        plain ``repr``.

        The xarray ``display_style`` option is global, so the value read at
        the start is restored in a ``finally`` block; a failing assertion
        would otherwise leave the changed setting in place for later tests.
        """
        idata = load_arviz_data("centered_eight")
        display_style = OPTIONS["display_style"]
        try:
            xr.set_options(display_style="html")
            html = idata._repr_html_()  # pylint: disable=protected-access

            assert html is not None
            assert "<div" in html
            for group in idata._groups:  # pylint: disable=protected-access
                assert group in html
                xr_data = getattr(idata, group)
                for item, _ in xr_data.items():
                    assert item in html
            specific_style = ".xr-wrap{width:700px!important;}"
            assert specific_style in html

            xr.set_options(display_style="text")
            html = idata._repr_html_()  # pylint: disable=protected-access
            assert escape(idata.__repr__()) in html
        finally:
            # Always put the global option back to what it was on entry.
            xr.set_options(display_style=display_style)
Exemple #10
0
 def test_map(self, use):
     """Apply a "+3" function via ``InferenceData.map`` to selected groups.

     ``use`` picks how the constant reaches the function: hard-coded when
     ``None``, positionally through ``args`` when ``"args"``, otherwise as
     the keyword ``a``. The mapped groups must match the function applied
     directly, the group list must be unchanged, and ``posterior.mu`` (not
     in the mapped groups) must be unaffected.
     """
     idata = load_arviz_data("centered_eight")
     args = []
     kwargs = {}
     if use is None:

         def fun(x):
             return x + 3

     else:

         def fun(x, a):
             return x + a

         if use == "args":
             args = [3]
         else:
             kwargs = {"a": 3}

     groups = ("observed_data", "posterior_predictive")
     idata_map = idata.map(fun, groups, args=args, **kwargs)

     # pylint: disable=protected-access
     assert idata_map._groups == idata._groups

     expected_obs = fun(idata.observed_data.obs, *args, **kwargs)
     assert np.allclose(idata_map.observed_data.obs, expected_obs)

     expected_pp = fun(idata.posterior_predictive.obs, *args, **kwargs)
     assert np.allclose(idata_map.posterior_predictive.obs, expected_pp)

     assert np.allclose(idata_map.posterior.mu, idata.posterior.mu)
Exemple #11
0
def test_load_remote_arviz_data():
    """Loading the "test_remote" dataset must return a truthy object."""
    remote_data = load_arviz_data("test_remote")
    assert remote_data
Exemple #12
0
"""
ESS Local Plot
==============

_thumb: .7, .5
"""
import matplotlib.pyplot as plt
import arviz as az

az.style.use("arviz-darkgrid")

idata = az.load_arviz_data("centered_eight")

# Local ESS plot for "mu", drawn with wide horizontal-bar markers.
az.plot_ess(idata, var_names=["mu"], kind="local", marker="_", ms=20, mew=2)

# Display the figure when run as a plain script, as the other examples do.
plt.show()
Exemple #13
0
"""
Joint Plot
==========

_thumb: .5, .8
"""
import matplotlib.pyplot as plt
import arviz as az

az.style.use("arviz-darkgrid")

data = az.load_arviz_data("non_centered_eight")

# Restrict "theta" to the two schools plotted against each other.
school_selection = {"school": ["Choate", "Phillips Andover"]}
joint_options = {
    "var_names": ["theta"],
    "coords": school_selection,
    "kind": "hexbin",
    "figsize": (10, 10),
}
az.plot_joint(data, **joint_options)
plt.show()
Exemple #14
0
"""
Autocorrelation Plot
====================

_thumb: .8, .8
"""
import matplotlib.pyplot as plt
import arviz as az

az.style.use("arviz-darkgrid")

data = az.load_arviz_data("centered_eight")

# Autocorrelation plot for the "tau" and "mu" variables.
az.plot_autocorr(data, var_names=("tau", "mu"))

# Display the figure when run as a plain script, as the other examples do.
plt.show()
Exemple #15
0
"""
Violinplot
==========

_thumb: .2, .8
"""
import matplotlib.pyplot as plt
import arviz as az

az.style.use("arviz-darkgrid")

non_centered = az.load_arviz_data("non_centered_eight")

# Violin plot for "mu" and "tau" with a reduced text size.
az.plot_violin(non_centered, var_names=["mu", "tau"], textsize=8)

# Display the figure when run as a plain script, as the other examples do.
plt.show()
Exemple #16
0
def test_waic():
    """Smoke test: ``waic`` runs on the "centered_eight" dataset without raising."""
    idata = load_arviz_data("centered_eight")
    waic(idata)
Exemple #17
0
def test_bad_checksum():
    """Loading the "bad_checksum" dataset must raise IOError."""
    dataset_name = "bad_checksum"
    with pytest.raises(IOError):
        load_arviz_data(dataset_name)
Exemple #18
0
 def setup_class(cls):
     """Load both eight-schools example datasets once and store them on the class."""
     centered = load_arviz_data("centered_eight")
     non_centered = load_arviz_data("non_centered_eight")
     cls.centered, cls.non_centered = centered, non_centered
Exemple #19
0
"""
Ridgeplot
=========

_thumb: .8, .5
"""
import matplotlib.pyplot as plt
import arviz as az

az.style.use("arviz-darkgrid")

rugby_data = az.load_arviz_data("rugby")

# Forest plot drawn as a ridgeplot of the "defs" variable, chains combined.
ridge_options = {
    "kind": "ridgeplot",
    "var_names": ["defs"],
    "linewidth": 4,
    "combined": True,
    "ridgeplot_overlap": 1.5,
    "colors": "blue",
    "figsize": (9, 4),
}
axes = az.plot_forest(rugby_data, **ridge_options)

title = "Relative defensive strength\nof Six Nation rugby teams"
axes[0].set_title(title)

plt.show()
Exemple #20
0
def test_summary_bad_fmt():
    """``summary`` must raise TypeError for the unsupported fmt "bad_fmt"."""
    idata = load_arviz_data("centered_eight")
    invalid_fmt = "bad_fmt"
    with pytest.raises(TypeError):
        summary(idata, fmt=invalid_fmt)
Exemple #21
0
def test_summary_fmt(fmt):
    """Smoke test: ``summary`` runs on "centered_eight" with the given ``fmt``."""
    idata = load_arviz_data("centered_eight")
    summary(idata, fmt=fmt)
Exemple #22
0
 def setup_class(cls):
     """Load both eight-schools example datasets once and store them on the class."""
     centered = load_arviz_data("centered_eight")
     non_centered = load_arviz_data("non_centered_eight")
     cls.centered, cls.non_centered = centered, non_centered
Exemple #23
0
"""
Quantile MCSE Errorbar Plot
===========================

_thumb: .6, .4
"""
import matplotlib.pyplot as plt
import arviz as az

az.style.use("arviz-darkgrid")

data = az.load_arviz_data("radon")

# MCSE plot for "sigma_a", drawn with error bars.
az.plot_mcse(data, var_names=["sigma_a"], color="C4", errorbar=True)

# Display the figure when run as a plain script, as the other examples do.
plt.show()
Exemple #24
0
def test_load_local_arviz_data():
    """Loading the "centered_eight" dataset must return a truthy object."""
    local_data = load_arviz_data("centered_eight")
    assert local_data
Exemple #25
0
"""
Bayesian p-value Posterior plot
===============================

_thumb: .6, .5
"""
import matplotlib.pyplot as plt
import arviz as az

az.style.use("arviz-darkgrid")

# Load the "regression1d" example dataset and draw the plot with defaults.
data = az.load_arviz_data("regression1d")
az.plot_bpv(data)

plt.show()
Exemple #26
0
"""
Separationplot
==============

_thumb: .2, .8
"""
import matplotlib.pyplot as plt

import arviz as az

az.style.use("arviz-darkgrid")

idata = az.load_arviz_data("classification10d")

# "outcome" is used both as the observed variable and as the prediction name.
separation_options = {"y": "outcome", "y_hat": "outcome", "figsize": (8, 1)}
az.plot_separation(idata=idata, **separation_options)

plt.show()
Exemple #27
0
def test_missing_dataset():
    """Requesting an unknown dataset name must raise ValueError."""
    unknown_name = "does not exist"
    with pytest.raises(ValueError):
        load_arviz_data(unknown_name)
Exemple #28
0
"""
ESS Quantile Plot
=================

_thumb: .4, .5
"""
import matplotlib.pyplot as plt
import arviz as az

az.style.use("arviz-darkgrid")

idata = az.load_arviz_data("radon")

# Quantile ESS plot for the "sigma_y" variable.
ess_options = {"var_names": ["sigma_y"], "kind": "quantile", "color": "C4"}
az.plot_ess(idata, **ess_options)

plt.show()
Exemple #29
0
def test_summary(include_circ):
    """Smoke test: ``summary`` runs on "centered_eight" with the given ``include_circ``."""
    idata = load_arviz_data("centered_eight")
    summary(idata, include_circ=include_circ)
Exemple #30
0
"""
ELPD Plot
=========

_thumb: .6, .5
"""
import matplotlib.pyplot as plt
import arviz as az

az.style.use("arviz-darkgrid")

d1 = az.load_arviz_data("centered_eight")
d2 = az.load_arviz_data("non_centered_eight")

# Label each dataset; the dict keys are the model names passed to the plot.
labelled_models = {"Centered eight": d1, "Non centered eight": d2}
az.plot_elpd(labelled_models, xlabels=True)

plt.show()