Exemple #1
0
def test_concat_group(copy, inplace, sequence):
    """Concatenate InferenceData objects that hold disjoint groups.

    The three inputs are passed either as one sequence or as separate
    positional arguments; afterwards the identity of each group is checked
    against the ``copy`` / ``inplace`` flags.
    """
    first = from_dict(
        posterior={"A": np.random.randn(2, 10, 2), "B": np.random.randn(2, 10, 5, 2)}
    )
    # Only needed (and only read) when both copy and inplace are requested.
    posterior_id_before = id(first.posterior) if copy and inplace else None
    second = from_dict(prior={"C": np.random.randn(2, 10, 2), "D": np.random.randn(2, 10, 5, 2)})
    third = from_dict(observed_data={"E": np.random.randn(100), "F": np.random.randn(2, 100)})

    # Basic two-argument case.
    assert concat(first, second, copy=True, inplace=False) is not None

    parts = (first, second, third)
    options = {"copy": copy, "inplace": inplace}
    result = concat(parts, **options) if sequence else concat(*parts, **options)

    if inplace:
        # In-place concatenation returns None and extends the first object.
        assert result is None
        result = first
    assert result is not None

    expected = {"posterior": ["A", "B"], "prior": ["C", "D"], "observed_data": ["E", "F"]}
    missing = check_multiple_attrs(expected, result)
    assert not missing

    if not copy:
        assert id(result.posterior) == id(first.posterior)
        assert id(result.prior) == id(second.prior)
        assert id(result.observed_data) == id(third.observed_data)
        return

    if inplace:
        assert id(result.posterior) == posterior_id_before
    else:
        assert id(result.posterior) != id(first.posterior)
    assert id(result.prior) != id(second.prior)
    assert id(result.observed_data) != id(third.observed_data)
Exemple #2
0
def _trace_to_arviz(
    trace=None,
    sample_stats=None,
    observed_data=None,
    prior_predictive=None,
    posterior_predictive=None,
    inplace=True,
):
    """Build an ArviZ InferenceData from raw sampler outputs.

    Dict inputs are converted value-by-value:

    * ``trace`` values via ``.numpy()``;
    * ``sample_stats`` values via ``.numpy().T``;
    * ``prior_predictive`` values get a leading length-1 axis.

    When ``posterior_predictive`` is a dict, the call either returns
    ``trace + az.from_dict(posterior_predictive=...)``, if ``trace`` is an
    InferenceData and ``inplace == True``, or builds a result with no
    posterior group.
    """
    if isinstance(trace, dict):
        trace = {name: value.numpy() for name, value in trace.items()}
    if isinstance(sample_stats, dict):
        sample_stats = {name: value.numpy().T for name, value in sample_stats.items()}
    if isinstance(prior_predictive, dict):
        # Leading length-1 axis: ArviZ's first axis is the chain dimension.
        prior_predictive = {name: value[np.newaxis] for name, value in prior_predictive.items()}

    if isinstance(posterior_predictive, dict):
        # Equality test (not truthiness) is intentional and kept as-is.
        merge_into_trace = isinstance(trace, az.InferenceData) and inplace == True  # noqa: E712
        if merge_into_trace:
            return trace + az.from_dict(posterior_predictive=posterior_predictive)
        # Otherwise the result carries no posterior group.
        trace = None

    return az.from_dict(
        posterior=trace,
        sample_stats=sample_stats,
        prior_predictive=prior_predictive,
        posterior_predictive=posterior_predictive,
        observed_data=observed_data,
    )
Exemple #3
0
def test_addition():
    """Adding two InferenceData objects yields one holding both sets of groups."""
    posterior_only = from_dict(
        posterior={"A": np.random.randn(2, 10, 2), "B": np.random.randn(2, 10, 5, 2)}
    )
    prior_only = from_dict(prior={"C": np.random.randn(2, 10, 2), "D": np.random.randn(2, 10, 5, 2)})
    combined = posterior_only + prior_only
    assert combined is not None
    expected = {"posterior": ["A", "B"], "prior": ["C", "D"]}
    missing = check_multiple_attrs(expected, combined)
    assert not missing
Exemple #4
0
def trace_to_arviz(
    trace=None,
    sample_stats=None,
    observed_data=None,
    prior_predictive=None,
    posterior_predictive=None,
    inplace=True,
):
    """
    Convert PyMC4 (TensorFlow) sampling output into an ArviZ InferenceData.

    Dict-valued inputs are converted entry by entry before being handed to
    ``az.from_dict``:

    * ``trace`` keeps only keys containing ``"/"``; each value is passed
      through ``.numpy()`` and has its first two axes swapped.
    * ``sample_stats`` values are passed through ``.numpy()`` and transposed.
    * ``prior_predictive`` values gain a leading length-1 axis.

    Parameters
    ----------
    trace : dict or InferenceData
    sample_stats : dict
    observed_data : dict
    prior_predictive : dict
    posterior_predictive : dict
    inplace : bool
        Only consulted when ``posterior_predictive`` is a dict. If ``trace``
        is an InferenceData and ``inplace == True``, the posterior predictive
        group is added to it (``trace + ...``). Otherwise the result has no
        posterior group.

    Returns
    -------
    ArviZ's InferenceData object
    """
    if isinstance(trace, dict):
        trace = {
            name: np.swapaxes(value.numpy(), 1, 0)
            for name, value in trace.items()
            if "/" in name
        }
    if isinstance(sample_stats, dict):
        sample_stats = {name: value.numpy().T for name, value in sample_stats.items()}
    if isinstance(prior_predictive, dict):
        # Leading length-1 axis: ArviZ's first axis is the chain dimension.
        prior_predictive = {name: value[np.newaxis] for name, value in prior_predictive.items()}

    if isinstance(posterior_predictive, dict):
        # Equality test (not truthiness) is intentional and kept as-is.
        merge_into_trace = isinstance(trace, az.InferenceData) and inplace == True  # noqa: E712
        if merge_into_trace:
            return trace + az.from_dict(posterior_predictive=posterior_predictive)
        trace = None

    return az.from_dict(
        posterior=trace,
        sample_stats=sample_stats,
        prior_predictive=prior_predictive,
        posterior_predictive=posterior_predictive,
        observed_data=observed_data,
    )
Exemple #5
0
    def test_inference_data_edge_cases(self):
        """``from_dict`` tolerates a log_likelihood variable in unusual groups."""
        variables = {
            "y": np.random.randn(4, 100),
            "log_likelihood": np.random.randn(4, 100, 8),
        }

        # A variable named "log_likelihood" placed in the posterior group.
        posterior_idata = from_dict(posterior=variables)
        assert posterior_idata is not None

        # Observed data with an explicit ``dims=None``.
        observed_idata = from_dict(observed_data=variables, dims=None)
        assert observed_idata is not None
Exemple #6
0
    def test_inference_data_edge_cases(self):
        """``from_dict`` warns about log_likelihood in the posterior and accepts dims=None."""
        variables = {
            "y": np.random.randn(4, 100),
            "log_likelihood": np.random.randn(4, 100, 8),
        }

        # A "log_likelihood" variable in the posterior must trigger a UserWarning.
        with pytest.warns(UserWarning, match="log_likelihood.+in posterior"):
            posterior_idata = from_dict(posterior=variables)
            assert posterior_idata is not None

        # Observed data with an explicit ``dims=None``.
        observed_idata = from_dict(observed_data=variables, dims=None)
        assert observed_idata is not None
Exemple #7
0
def test_inference_concat_keeps_all_fields():
    """Regression test for issue #907: ``concat`` must keep every group in either order."""
    left = from_dict(posterior={"A": [1, 2, 3, 4]}, sample_stats={"B": [2, 3, 4, 5]})
    right = from_dict(prior={"C": [1, 2, 3, 4]}, observed_data={"D": [2, 3, 4, 5]})

    combos = (concat(left, right), concat(right, left))

    expected = {"posterior": ["A"], "sample_stats": ["B"], "prior": ["C"], "observed_data": ["D"]}
    for combined in combos:
        missing = check_multiple_attrs(expected, combined)
        assert not missing
Exemple #8
0
def test_sel_method(inplace):
    """``InferenceData.sel`` slices every group that has the selected dimensions."""
    values = np.random.normal(size=(4, 500, 8))
    idata = from_dict(
        posterior={"a": values[..., 0], "b": values},
        sample_stats={"a": values[..., 0], "b": values},
        observed_data={"b": values[0, 0, :]},
        posterior_predictive={"a": values[..., 0], "b": values},
    )
    groups_before = getattr(idata, "_groups")
    n_draws = idata.posterior.draw.values.size
    selection = {"draw": slice(200, None), "chain": slice(None, None, 2), "b_dim_0": [1, 2, 7]}

    if not inplace:
        selected = idata.sel(inplace=inplace, **selection)
        assert selected is not idata
        idata = selected
    else:
        idata.sel(inplace=inplace, **selection)

    groups_after = getattr(idata, "_groups")
    assert np.all(np.isin(groups_after, groups_before))
    for name in groups_after:
        dataset = getattr(idata, name)
        assert "b_dim_0" in dataset.dims
        assert np.all(dataset.b_dim_0.values == np.array(selection["b_dim_0"]))
        # Observed data has no chain/draw dimensions to check.
        if name == "observed_data":
            continue
        assert np.all(np.isin(["chain", "draw"], dataset.dims))
        assert np.all(dataset.chain.values == np.arange(0, 4, 2))
        assert np.all(dataset.draw.values == np.arange(200, n_draws))
Exemple #9
0
    def to_inference_object(self) -> az.InferenceData:
        """Return the fitted Stan model as an ``arviz`` InferenceData object.

        If ``self.fit`` is already an InferenceData it is returned unchanged.
        Otherwise it is converted with ``single_feature_fit_to_inference``.
        When ``self.include_observed_data`` is set, an ``observed_data`` group
        built from ``self.dat["y"]`` (indexed by ``tbl_sample``) is
        concatenated onto the result.

        :returns: ``arviz`` InferenceData object with selected values
        :rtype: az.InferenceData
        :raises ValueError: if the model has not been fit or not specified
        """
        if self.fit is None:
            raise ValueError("Model has not been fit!")
        if isinstance(self.fit, az.InferenceData):
            # Nothing to convert.
            return self.fit
        if not self.specified:
            raise ValueError("Model has not been specified!")

        inference = single_feature_fit_to_inference(
            fit=self.fit,
            params=self.params,
            coords=self.coords,
            dims=self.dims,
            posterior_predictive=self.posterior_predictive,
            log_likelihood=self.log_likelihood,
            **self.specifications,
        )
        if not self.include_observed_data:
            return inference

        observed = az.from_dict(
            observed_data={"observed": self.dat["y"]},
            coords={"tbl_sample": self.sample_names},
            dims={"observed": ["tbl_sample"]},
        )
        return az.concat(inference, observed)
Exemple #10
0
def _run_corner(pandas=False,
                arviz=False,
                N=10000,
                seed=1234,
                ndim=3,
                factor=None,
                **kwargs):
    """Draw a corner plot of a seeded synthetic two-component sample.

    About four fifths of the rows are standard-normal draws. The remainder are
    standard-normal draws shifted by ``5 * U(0, 1)`` per dimension. With
    ``factor``, column 0 is multiplied and column 1 divided by it.

    The sample is optionally wrapped as a pandas DataFrame (columns ``d0``,
    ``d1``, ...) or as an ArviZ InferenceData. In the ArviZ case, negative
    first-coordinate values are flagged as divergences and random ``truths``
    are added to ``kwargs``. Returns the figure from ``corner.corner``.
    """
    np.random.seed(seed)
    bulk = np.random.randn(ndim * 4 * N // 5).reshape([4 * N // 5, ndim])
    # The shift is drawn before the second block so the RNG stream is unchanged.
    shift = 5 * np.random.rand(ndim)[None, :]
    shifted = shift + np.random.randn(ndim * N // 5).reshape([N // 5, ndim])
    samples = np.vstack([bulk, shifted])

    if factor is not None:
        samples[:, 0] *= factor
        samples[:, 1] /= factor

    if pandas:
        column_names = map("d{0}".format, range(ndim))
        samples = pd.DataFrame.from_dict(OrderedDict(zip(column_names, samples.T)))
    elif arviz:
        samples = az.from_dict(
            posterior={"x": samples[None]},
            sample_stats={"diverging": samples[None, :, 0] < 0.0},
        )
        kwargs["truths"] = {"x": np.random.randn(ndim)}

    return corner.corner(samples, **kwargs)
Exemple #11
0
def fit_svi(model,
            n_draws=1000,
            autoguide=AutoLaplaceApproximation,
            loss=Trace_ELBO(),
            optim=optim.Adam(step_size=.00001),
            num_warmup=2000,
            use_gpu=False,
            num_chains=1,
            progress_bar=False,
            sampler=None,
            **kwargs):
    """Fit ``model`` with stochastic variational inference.

    Runs ``num_warmup`` SVI optimisation steps (PRNG key 0). It then draws
    ``n_draws`` samples from the fitted guide with shape ``(1, n_draws, ...)``
    (PRNG key 1). Extra ``kwargs`` are forwarded to ``SVI``. ``num_chains`` is
    only passed to ``select_device``; ``sampler`` is unused.

    Returns:
        tuple: the InferenceData built from the draws, and the raw draws.
    """
    select_device(use_gpu, num_chains)
    guide = autoguide(model)
    svi = SVI(model=model, guide=guide, loss=loss, optim=optim, **kwargs)

    fitted = svi.run(
        jax.random.PRNGKey(0),
        num_steps=num_warmup,
        stable_update=True,
        progress_bar=progress_bar,
    )
    draws = guide.sample_posterior(
        jax.random.PRNGKey(1),
        params=fitted.params,
        sample_shape=(1, n_draws),
    )
    return az.from_dict(draws), draws
Exemple #12
0
 def get_inference_data(self,
                        data,
                        eight_schools_params,
                        save_warmup=False):
     """Build an InferenceData whose sample groups all come from ``data.obj``.

     ``eight_schools_params`` is used as observed data. ``theta`` and ``eta``
     are indexed by ``school`` (``school_pred`` for prediction groups).
     Warmup groups are kept only when ``save_warmup`` is true.
     """
     sample_groups = {
         group: data.obj
         for group in (
             "posterior",
             "posterior_predictive",
             "sample_stats",
             "prior",
             "prior_predictive",
             "sample_stats_prior",
             "warmup_posterior",
             "warmup_posterior_predictive",
             "predictions",
         )
     }
     return from_dict(
         **sample_groups,
         observed_data=eight_schools_params,
         coords={"school": np.arange(8)},
         pred_coords={"school_pred": np.arange(8)},
         dims={"theta": ["school"], "eta": ["school"]},
         pred_dims={"theta": ["school_pred"], "eta": ["school_pred"]},
         save_warmup=save_warmup,
     )
Exemple #13
0
 def test_to_dict(self, models):
     """A ``to_dict`` -> ``from_dict`` round trip reproduces every group."""
     original = models.model_1
     rebuilt = from_dict(**original.to_dict())
     assert rebuilt
     group_names = original._groups_all  # pylint: disable=protected-access
     for name in group_names:
         expected = getattr(original, name)
         actual = getattr(rebuilt, name)
         assert expected.equals(actual)
Exemple #14
0
 def test_from_pymc_trace_inference_data(self):
     """Passing an InferenceData as ``trace`` to ``from_pymc3`` raises ValueError."""
     posterior = {
         "A": np.random.randn(2, 10, 2),
         "B": np.random.randn(2, 10, 5, 2),
     }
     idata = from_dict(posterior=posterior)
     assert isinstance(idata, InferenceData)
     with pytest.raises(ValueError):
         from_pymc3(trace=idata, model=pm.Model())
    def get_data(self) -> az.InferenceData:
        """Collect samples and data into an ArviZ InferenceData.

        Groups built:

        * posterior, prior/posterior predictive and sample stats, from the
          corresponding attributes;
        * ``observed_data`` with ``nu``;
        * ``constant_data`` with ``nu_err``, zeros shaped like ``nu`` when
          ``self.nu_err`` is None.

        ``self.sample_metadata`` is copied into ``sample_stats.attrs``. For
        nested sampling a ``weighted_posterior`` group is added. Every variable
        is tagged with ``unit``, ``symbol`` and ``is_circular`` attributes,
        looked up by its name minus any ``_pred`` suffix.

        Returns:
            arviz.InferenceData: Inference data.
        """
        dims, coords = self._get_dims()
        nu_err = np.zeros_like(self.nu) if self.nu_err is None else self.nu_err
        data = az.from_dict(
            posterior=self.samples,
            prior_predictive=self.prior_predictive_samples,
            posterior_predictive=self.predictive_samples,
            sample_stats=self.sample_stats,
            observed_data={"nu": self.nu},
            constant_data={"nu_err": nu_err},
            dims=dims,
            coords=coords,
        )

        if self.sample_stats is not None:
            data.sample_stats.attrs.update(self.sample_metadata)

        if self.sample_metadata.get("method", None) == "nested":
            # The weights are just the logP in sampler stats; silence UserWarnings.
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=UserWarning)
                data.add_groups(
                    {"weighted_posterior": self.weighted_samples},
                    coords=coords,
                    dims=dims,
                )

        circular_names = self.get_circ_var_names()
        pred_suffix = "_pred"
        for group in data.groups():
            for key in data[group].keys():
                # Predictive variables share metadata with their base variable.
                base = key[:-len(pred_suffix)] if key.endswith(pred_suffix) else key

                unit = self.model.units.get(base, u.Unit())
                symbol = self.model.symbols.get(base, "")
                is_circular = 1 if base in circular_names else 0

                attrs = data[group][key].attrs
                attrs["unit"] = unit.to_string()
                attrs["symbol"] = symbol
                attrs["is_circular"] = is_circular

        return data
Exemple #16
0
def test_concat_bad():
    """``concat`` raises TypeError for non-InferenceData or incompatible inputs."""
    with pytest.raises(TypeError):
        concat("hello", "hello")

    full = from_dict(posterior={"A": np.random.randn(2, 10, 2), "B": np.random.randn(2, 10, 5, 2)})
    only_a = from_dict(posterior={"A": np.random.randn(2, 10, 2)})
    prior_only = from_dict(prior={"A": np.random.randn(2, 10, 2)})

    # (positional args, keyword args) pairs that must all be rejected.
    bad_calls = [
        ((full, np.array([1, 2, 3, 4, 5])), {}),
        ((full, full), {"dim": None}),
        ((full, only_a), {"dim": "chain"}),
        ((only_a, full), {"dim": "chain"}),
        ((full, prior_only), {"dim": "chain"}),
        ((prior_only, full), {"dim": "chain"}),
    ]
    for args, kwargs in bad_calls:
        with pytest.raises(TypeError):
            concat(*args, **kwargs)
Exemple #17
0
 def get_trace_stats(self,
                     trace,
                     statnames=(
                         'log_likelihood', 'tree_size', 'diverging',
                         'energy', 'mean_tree_accept'
                     )):
     """Wrap sampler statistics in an InferenceData ``sample_stats`` group.

     ``trace`` is an iterable of objects with a ``.numpy()`` method (presumably
     TensorFlow tensors — confirm with callers). They are paired positionally
     with ``statnames``; ``zip`` stops at the shorter of the two. Each value is
     converted with ``.numpy()`` and transposed.

     The default names are a tuple rather than a list so the shared default
     object cannot be mutated between calls.
     """
     stats = {name: values.numpy().T for name, values in zip(statnames, trace)}
     return az.from_dict(sample_stats=stats)
Exemple #18
0
def test_concat_dim(dim, copy, inplace, sequence, reset_dim):
    """Concatenate three InferenceData objects along ``dim``.

    Each posterior is created with 2 chains x 10 draws. The concatenated
    posterior must therefore have 6 entries along ``dim`` when
    ``dim == "chain"``, and 30 otherwise (presumably ``dim == "draw"``).
    """
    idata1 = from_dict(
        posterior={"A": np.random.randn(2, 10, 2), "B": np.random.randn(2, 10, 5, 2)},
        observed_data={"C": np.random.randn(100), "D": np.random.randn(2, 100)},
    )
    if inplace:
        original_idata1_id = id(idata1)
    idata2 = from_dict(
        posterior={"A": np.random.randn(2, 10, 2), "B": np.random.randn(2, 10, 5, 2)},
        observed_data={"C": np.random.randn(100), "D": np.random.randn(2, 100)},
    )
    idata3 = from_dict(
        posterior={"A": np.random.randn(2, 10, 2), "B": np.random.randn(2, 10, 5, 2)},
        observed_data={"C": np.random.randn(100), "D": np.random.randn(2, 100)},
    )
    # basic case
    assert (
        concat(idata1, idata2, dim=dim, copy=copy, inplace=False, reset_dim=reset_dim) is not None
    )
    if sequence:
        new_idata = concat(
            (idata1, idata2, idata3), copy=copy, dim=dim, inplace=inplace, reset_dim=reset_dim
        )
    else:
        new_idata = concat(
            idata1, idata2, idata3, dim=dim, copy=copy, inplace=inplace, reset_dim=reset_dim
        )
    if inplace:
        # In-place concatenation returns None and extends idata1.
        assert new_idata is None
        new_idata = idata1
    assert new_idata is not None
    test_dict = {"posterior": ["A", "B"], "observed_data": ["C", "D"]}
    fails = check_multiple_attrs(test_dict, new_idata)
    assert not fails
    if inplace:
        assert id(new_idata) == original_idata1_id
    else:
        assert id(new_idata) != id(idata1)

    # Computed separately: in ``assert x == 6 if c else 30`` the conditional
    # expression binds looser than ``==``, so the non-chain case would
    # assert the always-true constant 30.
    expected_size = 6 if dim == "chain" else 30
    concat_coord = getattr(new_idata.posterior, dim)
    assert concat_coord.size == expected_size
    if reset_dim:
        assert np.all(concat_coord.values == np.arange(expected_size))
Exemple #19
0
def sample_numpyro_nuts(
    log_posterior_fun,
    flat_init_params,
    parameter_summary,
    constrain_fun_dict=None,
    target_accept=0.8,
    draws=1000,
    tune=1000,
    chains=4,
    progress_bar=True,
    random_seed=10,
    chain_method="parallel",
    thinning=1,
):
    """Sample a flat log-posterior with NumPyro's NUTS and return InferenceData.

    Args:
        log_posterior_fun: callable mapping a flat parameter vector to its log
            posterior density; NUTS uses its negation as the potential.
        flat_init_params: initial flat parameter state passed to ``MCMC.run``.
        parameter_summary: passed to ``reconstruct`` to turn each flat draw
            back into named parameters.
        constrain_fun_dict: passed to ``apply_constraints``. ``None`` (the
            default) means an empty dict; a fresh one is created per call so
            no state is shared between calls.
        target_accept, draws, tune, chains, progress_bar, chain_method,
            thinning: forwarded to NUTS / MCMC.
        random_seed: seed for ``jax.random.PRNGKey``.

    Returns:
        InferenceData whose posterior holds the reconstructed, constrained
        samples grouped by chain.
    """
    if constrain_fun_dict is None:
        constrain_fun_dict = {}

    # Strongly inspired by:
    # https://github.com/pymc-devs/pymc3/blob/master/pymc3/sampling_jax.py#L116
    def _sample(current_state, seed):
        nuts_kernel = NUTS(
            potential_fn=lambda x: -log_posterior_fun(x),
            target_accept_prob=target_accept,
            adapt_step_size=True,
            adapt_mass_matrix=True,
            dense_mass=False,
        )

        # Named ``mcmc`` so the local does not shadow the ``numpyro`` package.
        mcmc = MCMC(
            nuts_kernel,
            num_warmup=tune,
            num_samples=draws,
            num_chains=chains,
            postprocess_fn=None,
            progress_bar=progress_bar,
            chain_method=chain_method,
            thinning=thinning,
        )

        mcmc.run(seed, init_params=current_state)
        return mcmc.get_samples(group_by_chain=True)

    seed = jax.random.PRNGKey(random_seed)
    samples = _sample(flat_init_params, seed)

    # Map each flat draw back to constrained, named parameters:
    # the inner vmap runs over draws, the outer one over chains.
    def reshape_single_chain(theta):
        fun_to_map = lambda x: apply_constraints(
            reconstruct(x, parameter_summary, jnp.reshape), constrain_fun_dict
        )[0]
        return vmap(fun_to_map)(theta)

    samples = vmap(reshape_single_chain)(samples)

    return az.from_dict(posterior=samples)
Exemple #20
0
    def test_potentials_warning(self):
        """Weighted posterior-predictive sampling warns that Potentials are ignored."""
        expected_message = "The effect of Potentials on other parameters is ignored during"
        with pm.Model() as model:
            a = pm.Normal("a", 0, 1)
            # Registered on the model; no local reference is needed.
            pm.Potential("p", a + 1)
            pm.Normal("obs", a, 1, observed=5)

        trace = az.from_dict({"a": np.random.rand(10)})
        with pytest.warns(UserWarning, match=expected_message):
            pm.sample_posterior_predictive_w(
                samples=5, traces=[trace, trace], models=[model, model]
            )
Exemple #21
0
    def sample(self, n: int = 500) -> az.InferenceData:
        """Draw ``n`` samples from the fitted approximation as InferenceData.

        Sampled variables and the values returned by
        ``self.deterministics_callback`` are merged; ``dict(**a, **b)`` raises
        TypeError on duplicate names. Each value gets a leading length-1 chain
        axis, and ``self.state.observed_values`` is attached as observed data.
        """
        raw = self.approx.sample(n)
        per_variable = self.order.split_samples(raw, n)
        merged = dict(**per_variable, **self.deterministics_callback(per_variable))

        # n_chains=1 for InferenceData: handles shape issues.
        posterior = {name: value.numpy()[np.newaxis] for name, value in merged.items()}
        return az.from_dict(posterior, observed_data=self.state.observed_values)
Exemple #22
0
def test_concat_bad():
    """``concat`` rejects non-InferenceData inputs; self-concatenation is not implemented."""
    with pytest.raises(TypeError):
        concat("hello", "hello")

    idata = from_dict(
        posterior={"A": np.random.randn(2, 10, 2), "B": np.random.randn(2, 10, 5, 2)}
    )
    not_idata = np.array([1, 2, 3, 4, 5])

    with pytest.raises(TypeError):
        concat(idata, not_idata)
    with pytest.raises(NotImplementedError):
        concat(idata, idata)
Exemple #23
0
 def get_inference_data(self, data, eight_schools_params):
     """Build an InferenceData whose sample groups all come from ``data.obj``.

     ``eight_schools_params`` provides the observed data and, via its ``"J"``
     entry, the number of ``school`` coordinates used by ``theta`` and ``eta``.
     """
     sample_groups = {
         group: data.obj
         for group in (
             "posterior",
             "posterior_predictive",
             "sample_stats",
             "prior",
             "prior_predictive",
             "sample_stats_prior",
         )
     }
     return from_dict(
         **sample_groups,
         observed_data=eight_schools_params,
         coords={"school": np.arange(eight_schools_params["J"])},
         dims={"theta": ["school"], "eta": ["school"]},
     )
def plot_pair_evolution(params, mcmc_kernel):
    """Overlay pair plots of every saved iteration found in ``./results``.

    Reads pickled ``(samples, stats)`` tuples from ``./results/output_it<N>...``
    files, ordered by ``N``. Each iteration is coloured along the ``hot_r``
    colormap by its 1-based position divided by the number of files. The
    overlay is saved to ``./results/pair_plot_evo.png``.

    Args:
        params: objects with a ``name`` attribute, one per sampled variable,
            in the same order as the pickled samples.
        mcmc_kernel: ``"hmc"`` or ``"nuts"``; selects the sample-stat names.

    Raises:
        ValueError: if ``mcmc_kernel`` is not ``"hmc"`` or ``"nuts"``.
        FileNotFoundError: if no ``output_it*`` file exists in ``./results``.
    """
    stats_names_by_kernel = {
        "hmc": ["logprob", "diverging", "acceptance", "step_size"],
        "nuts": [
            "logprob",
            "tree_size",
            "diverging",
            "energy",
            "acceptance",
            "mean_tree_accept",
        ],
    }
    # Previously an unknown kernel surfaced later as an UnboundLocalError.
    if mcmc_kernel not in stats_names_by_kernel:
        raise ValueError(f"Unsupported mcmc_kernel {mcmc_kernel!r}; expected 'hmc' or 'nuts'")
    stats_names = stats_names_by_kernel[mcmc_kernel]

    # "output_it" is 9 characters; the slice assumes a 4-character extension
    # (e.g. ".pkl") and leaves the iteration number in between.
    files = sorted(
        (name for name in os.listdir("./results") if name.startswith("output_it")),
        key=lambda name: int(name[9:-4]),
    )
    if not files:
        raise FileNotFoundError("No 'output_it*' files found in ./results")

    var_names = [p.name for p in params]
    arvzs, cs = [], []
    for position, fname in enumerate(files, start=1):
        # NOTE: pickle can execute arbitrary code; only load result files this
        # project wrote itself.
        with open(f"./results/{fname}", "rb") as obj:
            samples, stats = pickle.load(obj)
        sample_stats = dict(zip(stats_names, stats))
        posterior = dict(zip(var_names, samples))
        arvzs.append(az.from_dict(posterior=posterior, sample_stats=sample_stats))
        cs.append(position / len(files))

    ax = az.plot_pair(
        arvzs[0],
        kind="scatter",
        marginals=True,
        marginal_kwargs={"color": cm.hot_r(cs[0])},
        scatter_kwargs={"c": cm.hot_r(cs[0])},
    )
    # The first iteration is already drawn above; overlay only the rest.
    for arvz, c in zip(arvzs[1:], cs[1:]):
        az.plot_pair(
            arvz,
            kind="scatter",
            marginals=True,
            marginal_kwargs={"color": cm.hot_r(c)},
            scatter_kwargs={"c": cm.hot_r(c)},
            ax=ax,
        )

    fig = ax.ravel()[0].figure
    fig.savefig("./results/pair_plot_evo.png")
Exemple #25
0
def _chain_to_physical(chain, model):
    """Transform a ``[chain, sample, parameter]`` array from fit to physical space."""
    chain_dict = {name: chain[:, :, i] for i, name in enumerate(model.parameter_names)}
    chain_dict_phy = model.transform_fit_to_physical(chain_dict, mode='bayes')
    return np.transpose(
        np.array([chain_dict_phy[name] for name in model.parameter_names]), axes=(1, 2, 0)
    )


def convert_to_arviz(chains,model,burnin,remove_stuck=False,iparas_time=None,phy_space=False):
    """Convert raw sampler chains into an ArviZ InferenceData.

    Array layouts below follow from the slicing in this function.

    Args:
        chains: tuple ``(rands, chain, lnprob, lnprob_i)``; ``rands`` is unused.
            ``chain`` is indexed ``[chain, sample, parameter]``, ``lnprob``
            ``[chain, sample]`` and ``lnprob_i`` ``[sample, chain, k]``.
        model: provides ``parameter_names``, ``transform_fit_to_physical``,
            ``set_parameters_fit_array`` and ``calc_implicit_parameters``.
        burnin: number of leading samples dropped from every chain.
        remove_stuck: drop chains whose post-burn-in log probability ever
            reaches -1e10 or below, and print how many were dropped.
        iparas_time: if given, implicit parameters evaluated at this value are
            appended after the regular parameters. In this mode the regular
            parameters are always reported in physical space, whatever
            ``phy_space`` says.
        phy_space: report parameters in physical rather than fit space.

    Returns:
        tuple: ``(InferenceData, names)``. The posterior variable ``'a'`` has
        its last dimension ``'ac'`` labelled by ``names``. ``sample_stats``
        holds ``log_likelihood`` (``lnprob_i`` reordered to ``[chain, sample, k]``)
        and ``loglike_values`` (``lnprob``).
    """
    _, chain, lnprob, lnprob_i = chains
    para_names = model.parameter_names
    if remove_stuck:
        old_number_chains = lnprob.shape[0]
        chain_mask = np.where(lnprob[:, burnin:].min(axis=1) > -1e10)[0]
        chain = chain[chain_mask, burnin:]
        lnprob_i = lnprob_i[burnin:, chain_mask]
        lnprob = lnprob[chain_mask, burnin:]
        print(old_number_chains - len(chain_mask),' chains are stuck')
    else:
        chain = chain[:, burnin:]
        lnprob_i = lnprob_i[burnin:, :]
        lnprob = lnprob[:, burnin:]

    # Built once, outside any loop, so it is always defined, even for zero
    # chains or zero samples.
    sample_stats = {
        'log_likelihood': np.transpose(lnprob_i, axes=(1, 0, 2)),
        'loglike_values': lnprob,
    }

    if iparas_time is None:
        if phy_space:
            chain = _chain_to_physical(chain, model)
        idata = az.from_dict(
            posterior={'a': chain},
            sample_stats=sample_stats,
            dims={'a': ['ac']},
            coords={'ac': para_names},
        )
        return idata, list(para_names)

    n_chains, n_samples, n_params = chain.shape
    iparas_name = list(model.calc_implicit_parameters(iparas_time).keys())
    chain_woth_ip = np.zeros((n_chains, n_samples, n_params + len(iparas_name)))
    for n_chain in range(n_chains):
        for n_sample in range(n_samples):
            # Implicit parameters are computed from the fit-space sample,
            # before the chain is transformed below.
            model.set_parameters_fit_array(chain[n_chain, n_sample], mode='bayes')
            # Evaluate once per sample rather than once per implicit name.
            implicit = model.calc_implicit_parameters(iparas_time)
            chain_woth_ip[n_chain, n_sample, n_params:] = [implicit[name] for name in iparas_name]

    chain_woth_ip[:, :, :n_params] = _chain_to_physical(chain, model)

    names = list(para_names) + iparas_name
    idata = az.from_dict(
        posterior={'a': chain_woth_ip},
        sample_stats=sample_stats,
        dims={'a': ['ac']},
        coords={'ac': names},
    )
    return idata, names
Exemple #26
0
def simulate_fixed(fixed_value, n_chains, n_samples, model):
    """Run Gibbs chains with the last parameter held fixed and summarise them.

    Sets ``model.t_init[-1] = fixed_value``, which mutates ``model``. Fills
    ``n_chains`` chains of ``n_samples`` draws each via
    ``model.sample_slice_gibbs5`` and drops the last (fixed) parameter. The
    result is summarised with ``az.summary``. Parameters 0-3 are labelled
    ``$\\beta$`` and 4-6 ``$\\theta$``, presumably as index groups — confirm
    against ``to_arviz_dict``. ``burnin=1000`` is passed through.

    Returns:
        tuple: ``(samples, summary)``. ``samples`` has shape
        ``(n_chains, n_samples, model.n_params - 1)``. ``summary`` gains a
        ``sampling_effeciency`` column equal to ``ess_mean`` divided by
        ``n_chains * n_samples``.
    """
    samples = np.zeros((n_chains, n_samples, model.n_params))
    model.t_init[-1] = fixed_value
    for i in range(n_chains):
        samples[i] = model.sample_slice_gibbs5(n_samples)

    # Drop the fixed parameter.
    samples = samples[:, :, :-1]
    az_dict = to_arviz_dict(samples, {"$\\beta$": np.arange(4),
                                      "$\\theta$": np.arange(4, 7)}, burnin=1000)

    az_data = az.from_dict(az_dict)
    summary = az.summary(az_data)
    # np.product was deprecated in NumPy 1.25 and removed in 2.0; np.prod is the
    # equivalent. NOTE(review): the denominator counts all draws, including
    # those passed as burn-in; confirm this is intended. The key spelling is
    # kept because callers may read it.
    summary['sampling_effeciency'] = summary['ess_mean'] / np.prod(samples.shape[:2])
    return samples, summary
Exemple #27
0
 def test_to_dict_warmup(self):
     """A ``to_dict`` -> ``from_dict`` round trip keeps warmup groups with ``save_warmup=True``."""
     groups = [
         "posterior",
         "sample_stats",
         "observed_data",
         "warmup_posterior",
         "warmup_posterior_predictive",
     ]
     original = create_data_random(groups=groups)
     rebuilt = from_dict(**original.to_dict(), save_warmup=True)
     assert rebuilt
     for name in original._groups_all:  # pylint: disable=protected-access
         expected = getattr(original, name)
         actual = getattr(rebuilt, name)
         assert expected.equals(actual)
def _save_and_close(ax, path):
    """Save the figure owning ``ax`` to ``path``, then close that figure.

    ``ax`` may be a single Axes or an array of Axes; ``np.ravel`` handles both.
    Closing the figure stops figures accumulating across calls.
    """
    fig = np.ravel(ax)[0].figure
    fig.savefig(path)
    plt.close(fig)


def mcmc_diagnostic_plots(posterior, sample_stats, it):
    """Save trace, posterior, autocorrelation and ESS-evolution plots.

    Args:
        posterior: dict of posterior samples, passed to ``az.from_dict``.
        sample_stats: dict of sampler statistics, passed to ``az.from_dict``.
        it: iteration label used in the ``./results/*_it{it}.png`` file names.
    """
    az_trace = az.from_dict(posterior=posterior, sample_stats=sample_stats)

    _save_and_close(
        az.plot_trace(az_trace, divergences=False),
        f"./results/trace_plot_it{it}.png",
    )
    _save_and_close(
        az.plot_posterior(az_trace),
        f"./results/posterior_plot_it{it}.png",
    )

    # NOTE(review): len() is the size of the first axis of the first posterior
    # value. That equals the draw count only for 1-D sample arrays — confirm
    # the layout.
    lag = np.minimum(len(list(posterior.values())[0]), 100)
    _save_and_close(
        az.plot_autocorr(az_trace, max_lag=lag),
        f"./results/autocorr_plot_it{it}.png",
    )

    _save_and_close(
        az.plot_ess(az_trace, kind="evolution"),
        f"./results/ess_evolution_plot_it{it}.png",
    )
Exemple #29
0
    def test_del(self, use):
        """Removing a group via ``del`` or ``delattr`` drops it from the object and ``_groups``."""
        values = np.random.normal(size=(4, 500, 8))

        def sampled_group():
            # A fresh dict per group, as each group had its own literal before.
            return {"a": values[..., 0], "b": values}

        idata = from_dict(
            posterior=sampled_group(),
            sample_stats=sampled_group(),
            observed_data={"b": values[0, 0, :]},
            posterior_predictive=sampled_group(),
        )

        expected = {
            "posterior": ("a", "b"),
            "sample_stats": ("a", "b"),
            "observed_data": ["b"],
            "posterior_predictive": ("a", "b"),
        }
        assert not check_multiple_attrs(expected, idata)
        registered = getattr(idata, "_groups")
        assert all(name in registered for name in expected)

        if use != "del":
            delattr(idata, "sample_stats")
        else:
            del idata.sample_stats

        expected.pop("sample_stats")
        assert not check_multiple_attrs(expected, idata)
        assert not hasattr(idata, "sample_stats")
        assert "sample_stats" not in getattr(idata, "_groups")
Exemple #30
0
    def sample(self, n: int = 500, include_log_likelihood: bool = False) -> az.InferenceData:
        """Draw ``n`` samples from the fitted approximation as InferenceData.

        Sampled variables and deterministics are merged, and each value gets a
        leading length-1 chain axis. ``self.state.observed_values`` is attached
        as observed data. When ``include_log_likelihood`` is true, a
        log-likelihood group from ``calculate_log_likelihood`` is added as well.
        """
        raw = self.approx.sample(n)
        per_variable = self.order.split_samples(raw, n)
        merged = dict(**per_variable, **self.deterministics_callback(per_variable))

        # n_chains=1 for InferenceData: handles shape issues.
        posterior = {name: value.numpy()[np.newaxis] for name, value in merged.items()}

        log_likelihood = None
        if include_log_likelihood:
            log_likelihood = calculate_log_likelihood(self.model, posterior, self.state)

        return az.from_dict(
            posterior,
            observed_data=self.state.observed_values,
            log_likelihood=log_likelihood,
        )