def _validate_predictive_group(data: az.InferenceData, group: str):
    """Validate the predictive groups in data.

    Args:
        data (arviz.InferenceData): Inference data object.
        group (str): One of ['posterior', 'prior'].

    Raises:
        ValueError: If group is not valid.
        KeyError: If predictive is not in data, gives helpful suggestion.

    Returns:
        xarray.Dataset: Dataset corresponding to the predictive of group.
    """
    if group == "posterior":
        key = "posterior_predictive"
        predictive = data.get(key, None)
    elif group == "prior":
        key = "prior_predictive"
        predictive = data.get(key, None)
    else:
        raise ValueError(
            f"Group '{group}' is not one of ['posterior', 'prior'].")

    if predictive is None:
        raise KeyError(f"Group '{key}' not in data. Consider using method " +
                       "'Inference.{key}()' to sample the predictive.")
    return predictive
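
# A quick usage sketch, not part of the original source: it assumes a recent ArviZ
# where InferenceData supports .get() and uses the bundled "centered_eight" dataset,
# which ships with a posterior_predictive group.
import arviz as az

idata = az.load_arviz_data("centered_eight")
pp = _validate_predictive_group(idata, "posterior")
print(type(pp))  # xarray.Dataset holding the posterior predictive variables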
Example #2
 def test_empty_inference_data_object(self):
     inference_data = InferenceData()
     here = os.path.dirname(os.path.abspath(__file__))
     data_directory = os.path.join(here, "..", "saved_models")
     filepath = os.path.join(data_directory, "empty_test_file.nc")
     assert not os.path.exists(filepath)
     inference_data.to_netcdf(filepath)
     assert os.path.exists(filepath)
     os.remove(filepath)
     assert not os.path.exists(filepath)
Example #3
 def test_group_names(self, args_res):
     args, result = args_res
     ds = dict_to_dataset({"a": np.random.normal(size=(3, 10))})
     idata = InferenceData(
         posterior=(ds, ds),
         sample_stats=(ds, ds),
         observed_data=ds,
         posterior_predictive=ds,
     )
     group_names = idata._group_names(*args)  # pylint: disable=protected-access
     assert np.all([name in result for name in group_names])
Example #4
    def test_io_method(self, data, eight_schools_params, store):
        # create InferenceData and check it has been properly created
        inference_data = self.get_inference_data(  # pylint: disable=W0612
            data, eight_schools_params
        )
        test_dict = {
            "posterior": ["eta", "theta", "mu", "tau"],
            "posterior_predictive": ["eta", "theta", "mu", "tau"],
            "sample_stats": ["eta", "theta", "mu", "tau"],
            "prior": ["eta", "theta", "mu", "tau"],
            "prior_predictive": ["eta", "theta", "mu", "tau"],
            "sample_stats_prior": ["eta", "theta", "mu", "tau"],
            "observed_data": ["J", "y", "sigma"],
        }
        fails = check_multiple_attrs(test_dict, inference_data)
        assert not fails

        # check filename does not exist and use to_zarr method
        here = os.path.dirname(os.path.abspath(__file__))
        data_directory = os.path.join(here, "..", "saved_models")
        filepath = os.path.join(data_directory, "zarr")
        assert not os.path.exists(filepath)

        # InferenceData method
        if store == 0:
            # Tempdir
            store = inference_data.to_zarr(store=None)
            assert isinstance(store, MutableMapping)
        elif store == 1:
            inference_data.to_zarr(store=filepath)
            # assert file has been saved correctly
            assert os.path.exists(filepath)
            assert os.path.getsize(filepath) > 0
        elif store == 2:
            store = zarr.storage.DirectoryStore(filepath)
            inference_data.to_zarr(store=store)
            # assert file has been saved correctly
            assert os.path.exists(filepath)
            assert os.path.getsize(filepath) > 0

        if isinstance(store, MutableMapping):
            inference_data2 = InferenceData.from_zarr(store)
        else:
            inference_data2 = InferenceData.from_zarr(filepath)

        # Everything in dict still available in inference_data2 ?
        fails = check_multiple_attrs(test_dict, inference_data2)
        assert not fails

        # Remove created folder structure
        if os.path.exists(filepath):
            shutil.rmtree(filepath)
        assert not os.path.exists(filepath)
Example #5
def test_inference_data_other_groups():
    datadict = {"a": np.random.randn(100), "b": np.random.randn(1, 100, 10)}
    dataset = convert_to_dataset(datadict, coords={"c": np.arange(10)}, dims={"b": ["c"]})
    with pytest.warns(UserWarning, match="not.+in.+InferenceData scheme"):
        idata = InferenceData(other_group=dataset)
    fails = check_multiple_attrs({"other_group": ["a", "b"]}, idata)
    assert not fails
Example #6
    def to_inference_data(self):
        """Convert all available data to an InferenceData object.

        Note that if groups cannot be created (e.g., there is no `trace`, so the
        `posterior` and `sample_stats` cannot be extracted), then the InferenceData
        will not have those groups.
        """
        id_dict = {
            "posterior": self.posterior_to_xarray(),
            "sample_stats": self.sample_stats_to_xarray(),
            "log_likelihood": self.log_likelihood_to_xarray(),
            "posterior_predictive": self.posterior_predictive_to_xarray(),
            "predictions": self.predictions_to_xarray(),
            **self.priors_to_xarray(),
            "observed_data": self.observed_data_to_xarray(),
        }
        if self.predictions:
            id_dict["predictions_constant_data"] = self.constant_data_to_xarray()
        else:
            id_dict["constant_data"] = self.constant_data_to_xarray()
        return InferenceData(save_warmup=self.save_warmup, **id_dict)
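
# A related sketch, not from the original converter: the same kind of grouped
# InferenceData can be assembled directly from dictionaries of arrays with
# arviz.from_dict (the variable names below are made up for illustration).
import numpy as np
import arviz as az

idata = az.from_dict(
    posterior={"mu": np.random.randn(4, 500)},
    posterior_predictive={"y": np.random.randn(4, 500, 8)},
    observed_data={"y": np.random.randn(8)},
)
print(idata.groups())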
Example #7
def load_data(filename):
    """
    Load netcdf file back into an arviz.InferenceData

    Parameters
    ----------
    filename : str
        Name or path of the netCDF file containing the trace.
    """
    return InferenceData(filename)
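
# Note, not part of the original snippet: InferenceData(filename) is the early ArviZ
# API that opened a netCDF file directly. With current ArviZ, a rough equivalent is:
import arviz as az

def load_data_current(filename):
    """Load a netCDF file back into an arviz.InferenceData."""
    return az.from_netcdf(filename)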
Example #8
def _save_sample_stats(
    sample_settings,
    sample_stats,
    chains,
    trace,
    return_inferencedata,
    _t_sampling,
    idata_kwargs,
    model,
):
    sample_settings_dict = sample_settings[0]
    sample_settings_dict["_t_sampling"] = _t_sampling
    sample_stats_dict = sample_stats[0]

    if chains > 1:
        # Collect the stat values from each chain in a single list
        for stat in sample_stats[0].keys():
            value_list = []
            for chain_sample_stats in sample_stats:
                value_list.append(chain_sample_stats[stat])
            sample_stats_dict[stat] = value_list

    if not return_inferencedata:
        for stat, value in sample_stats_dict.items():
            setattr(trace.report, stat, value)
        for stat, value in sample_settings_dict.items():
            setattr(trace.report, stat, value)
        idata = None
    else:
        for stat, value in sample_stats_dict.items():
            if chains > 1:
                # Different chains might have different numbers of iteration steps, leading
                # to a non-square `sample_stats` dataset; we cast to `object` to avoid
                # numpy's ragged-array deprecation warning.
                sample_stats_dict[stat] = np.array(value, dtype=object)
            else:
                sample_stats_dict[stat] = np.array(value)

        sample_stats = dict_to_dataset(
            sample_stats_dict,
            attrs=sample_settings_dict,
            library=pymc,
        )

        ikwargs = dict(model=model)
        if idata_kwargs is not None:
            ikwargs.update(idata_kwargs)
        idata = to_inference_data(trace, **ikwargs)
        idata = InferenceData(**idata, sample_stats=sample_stats)

    return sample_stats, idata
Example #9
 def test_io_method(self, data, eight_schools_params):
     inference_data = self.get_inference_data(  # pylint: disable=W0612
         data, eight_schools_params)
     assert hasattr(inference_data, "posterior")
     here = os.path.dirname(os.path.abspath(__file__))
     data_directory = os.path.join(here, "saved_models")
     filepath = os.path.join(data_directory, "io_method_testfile.nc")
     assert not os.path.exists(filepath)
     # InferenceData method
     inference_data.to_netcdf(filepath)
     assert os.path.exists(filepath)
     assert os.path.getsize(filepath) > 0
     inference_data2 = InferenceData.from_netcdf(filepath)
     assert hasattr(inference_data2, "posterior")
     os.remove(filepath)
     assert not os.path.exists(filepath)
Example #10
    def test_io_method(self, data, eight_schools_params, groups_arg):
        # create InferenceData and check it has been properly created
        inference_data = self.get_inference_data(  # pylint: disable=W0612
            data, eight_schools_params
        )
        test_dict = {
            "posterior": ["eta", "theta", "mu", "tau"],
            "posterior_predictive": ["eta", "theta", "mu", "tau"],
            "sample_stats": ["eta", "theta", "mu", "tau"],
            "prior": ["eta", "theta", "mu", "tau"],
            "prior_predictive": ["eta", "theta", "mu", "tau"],
            "sample_stats_prior": ["eta", "theta", "mu", "tau"],
            "observed_data": ["J", "y", "sigma"],
        }
        fails = check_multiple_attrs(test_dict, inference_data)
        assert not fails

        # check filename does not exist and use to_netcdf method
        here = os.path.dirname(os.path.abspath(__file__))
        data_directory = os.path.join(here, "saved_models")
        filepath = os.path.join(data_directory, "io_method_testfile.nc")
        assert not os.path.exists(filepath)
        # InferenceData method
        inference_data.to_netcdf(
            filepath, groups=("posterior", "observed_data") if groups_arg else None
        )

        # assert file has been saved correctly
        assert os.path.exists(filepath)
        assert os.path.getsize(filepath) > 0
        inference_data2 = InferenceData.from_netcdf(filepath)
        if groups_arg:  # if groups arg, update test dict to contain only saved groups
            test_dict = {
                "posterior": ["eta", "theta", "mu", "tau"],
                "observed_data": ["J", "y", "sigma"],
            }
            assert not hasattr(inference_data2, "sample_stats")
        fails = check_multiple_attrs(test_dict, inference_data2)
        assert not fails

        os.remove(filepath)
        assert not os.path.exists(filepath)
Example #11
def load_arviz_data(dataset):
    """Load built-in arviz dataset into memory

    Will print out available datasets in case of error.

    Parameters
    ----------
    dataset : str
        Name of dataset to load

    Returns
    -------
    InferenceData
    """
    top = os.path.dirname(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    data_path = os.path.join(top, 'doc', 'data')
    datasets_available = {
        'centered_eight': {
            'description': '''
                Centered eight schools model.  Four chains, 500 draws each, fit with
                NUTS in PyMC3.  Features named coordinates for each of the eight schools.
            ''',
            'path': os.path.join(data_path, 'centered_eight.nc')
        },
        'non_centered_eight': {
            'description': '''
                Non-centered eight schools model.  Four chains, 500 draws each, fit with
                NUTS in PyMC3.  Features named coordinates for each of the eight schools.
            ''',
            'path': os.path.join(data_path, 'non_centered_eight.nc')
        }
    }
    if dataset in datasets_available:
        return InferenceData(datasets_available[dataset]['path'])
    else:
        msg = ['\'dataset\' must be one of the following options:']
        for key, value in sorted(datasets_available.items()):
            msg.append('{key}: {description}'.format(
                key=key, description=value['description']))

        raise ValueError('\n'.join(msg))
Example #12
 def to_netcdf(self):
     has_group = False
     mode = 'w'  # overwrite first, then append
     for group, func_name in self.converters.items():
         if hasattr(self, func_name):
             data = getattr(self, func_name)()
             try:
                 data.to_netcdf(self.filename, mode=mode, group=group)
             except PermissionError as err:
                 msg = 'File "{}" is in use - is another object using it?'.format(
                     self.filename)
                 raise PermissionError(msg) from err
             has_group = True
             mode = 'a'
     if not has_group:
         msg = (
             '{} has no functions creating groups! Must implement one of '
             'the following functions:\n{}'.format(
                 self.__class__.__name__,
                 '\n'.join(self.converters.values())))
         raise RuntimeError(msg)
     return InferenceData(self.filename)
Example #13
def test_bad_inference_data():
    with pytest.raises(ValueError):
        InferenceData(posterior=[1, 2, 3])
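
# For contrast, a sketch that is not part of the original tests: a valid construction
# passes xarray Datasets (here built with arviz.convert_to_dataset) as group kwargs.
import numpy as np
from arviz import InferenceData, convert_to_dataset

good_idata = InferenceData(posterior=convert_to_dataset({"a": np.random.randn(4, 100)}))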
Example #14
def sample_smc(
    draws=2000,
    kernel=IMH,
    *,
    start=None,
    model=None,
    random_seed=None,
    chains=None,
    cores=None,
    compute_convergence_checks=True,
    return_inferencedata=True,
    idata_kwargs=None,
    progressbar=True,
    **kernel_kwargs,
):
    r"""
    Sequential Monte Carlo based sampling.

    Parameters
    ----------
    draws: int
        The number of samples to draw from the posterior (i.e. the last stage), which is also
        the number of independent chains. Defaults to 2000.
    kernel: SMC Kernel used. Defaults to pm.smc.IMH (Independent Metropolis Hastings)
    start: dict, or array of dict
        Starting point in parameter space. It should be a list of dict with length `chains`.
        When None (default) the starting point is sampled from the prior distribution.
    model: Model (optional if in ``with`` context).
    random_seed: int
        random seed
    chains : int
        The number of chains to sample. Running independent chains is important for some
        convergence statistics. If ``None`` (default), then set to either ``cores`` or 2, whichever
        is larger.
    cores : int
        The number of chains to run in parallel. If ``None``, set to the number of CPUs in the
        system.
    compute_convergence_checks : bool
        Whether to compute sampler statistics like Gelman-Rubin and ``effective_n``.
        Defaults to ``True``.
    return_inferencedata : bool, default=True
        Whether to return the trace as an :class:`arviz:arviz.InferenceData` (True) object or a
        `MultiTrace` (False). Defaults to ``True``.
    idata_kwargs : dict, optional
        Keyword arguments for :func:`pymc.to_inference_data`
    progressbar : bool, optional, default=True
        Whether or not to display a progress bar in the command line.
    **kernel_kwargs: keyword arguments passed to the SMC kernel.
        The default IMH kernel takes the following keywords:
            threshold: float
                Determines the change of beta from stage to stage, i.e. indirectly the number of stages,
                the higher the value of `threshold` the higher the number of stages. Defaults to 0.5.
                It should be between 0 and 1.
            n_steps: int
                The number of steps of each Markov Chain. If ``tune_steps == True``, ``n_steps`` will
                be used for the first stage; for the others it will be determined automatically based
                on the acceptance rate and `p_acc_rate`, with ``n_steps`` as the maximum.
            tune_steps: bool
                Whether to compute the number of steps automatically or not. Defaults to ``True``.
            p_acc_rate: float
                Used to compute ``n_steps`` when ``tune_steps == True``. The higher the value of
                ``p_acc_rate`` the higher the number of steps computed automatically. Defaults to 0.85.
                It should be between 0 and 1.
        Keyword arguments for other kernels should be checked in the respective docstrings.

    Notes
    -----
    SMC works by moving through successive stages. At each stage the inverse temperature
    :math:`\beta` is increased a little bit (starting from 0 up to 1). When :math:`\beta = 0`
    we have the prior distribution and when :math:`\beta = 1` we have the posterior distribution.
    So in more general terms we are always computing samples from a tempered posterior that we can
    write as:

    .. math::

        p(\theta \mid y)_{\beta} = p(y \mid \theta)^{\beta} p(\theta)

    A summary of the algorithm is:

     1. Initialize :math:`\beta` at zero and stage at zero.
     2. Generate N samples :math:`S_{\beta}` from the prior (because when :math:`\beta = 0` the
        tempered posterior is the prior).
     3. Increase :math:`\beta` in order to make the effective sample size equal some predefined
        value (we use :math:`Nt`, where :math:`t` is 0.5 by default).
     4. Compute a set of N importance weights W. The weights are computed as the ratio of the
        likelihoods of a sample at stage i+1 and stage i.
     5. Obtain :math:`S_{w}` by re-sampling according to W.
     6. Use W to compute the mean and covariance for the proposal distribution, a MVNormal.
     7. For stages other than 0 use the acceptance rate from the previous stage to estimate
        `n_steps`.
     8. Run N independent Metropolis-Hastings (IMH) chains (each one of length `n_steps`),
        starting each one from a different sample in :math:`S_{w}`. Samples are IMH as the proposal
        mean is the mean of the previous posterior stage and not the current point in parameter space.
     9. Repeat from step 3 until :math:`\beta \ge 1`.
     10. The final result is a collection of N samples from the posterior.


    References
    ----------
    .. [Minson2013] Minson, S. E. and Simons, M. and Beck, J. L., (2013),
        Bayesian inversion for finite fault earthquake source models I- Theory and algorithm.
        Geophysical Journal International, 2013, 194(3), pp.1701-1726,
        `link <https://gji.oxfordjournals.org/content/194/3/1701.full>`__

    .. [Ching2007] Ching, J. and Chen, Y. (2007).
        Transitional Markov Chain Monte Carlo Method for Bayesian Model Updating, Model Class
        Selection, and Model Averaging. J. Eng. Mech., 10.1061/(ASCE)0733-9399(2007)133:7(816),
        816-832. `link <http://ascelibrary.org/doi/abs/10.1061/%28ASCE%290733-9399
        %282007%29133:7%28816%29>`__
    """

    if isinstance(kernel, str) and kernel.lower() in ("abc", "metropolis"):
        warnings.warn(
            f'The kernel string argument "{kernel}" in sample_smc has been deprecated. '
            f"It is no longer needed to distinguish between `abc` and `metropolis`",
            FutureWarning,
            stacklevel=2,
        )
        kernel = IMH

    if kernel_kwargs.pop("save_sim_data", None) is not None:
        warnings.warn(
            "save_sim_data has been deprecated. Use pm.sample_posterior_predictive "
            "to obtain the same type of samples.",
            FutureWarning,
            stacklevel=2,
        )

    if kernel_kwargs.pop("save_log_pseudolikelihood", None) is not None:
        warnings.warn(
            "save_log_pseudolikelihood has been deprecated. This information is "
            "now saved as log_likelihood in models with Simulator distributions.",
            FutureWarning,
            stacklevel=2,
        )

    parallel = kernel_kwargs.pop("parallel", None)
    if parallel is not None:
        warnings.warn(
            "The argument parallel is deprecated, use the argument cores instead.",
            FutureWarning,
            stacklevel=2,
        )
        if parallel is False:
            cores = 1

    if cores is None:
        cores = _cpu_count()

    if chains is None:
        chains = max(2, cores)
    else:
        cores = min(chains, cores)

    if random_seed == -1:
        warnings.warn(
            f"random_seed should be a non-negative integer or None, got: {random_seed}. "
            "This will raise a ValueError in the future.",
            FutureWarning,
            stacklevel=2,
        )
        random_seed = None
    if isinstance(random_seed, int) or random_seed is None:
        rng = np.random.default_rng(seed=random_seed)
        random_seed = list(rng.integers(2**30, size=chains))
    elif isinstance(random_seed, Iterable):
        if len(random_seed) != chains:
            raise ValueError(
                f"Length of seeds ({len(seeds)}) must match number of chains {chains}"
            )
    else:
        raise TypeError(
            "Invalid value for `random_seed`. Must be tuple, list, int or None"
        )

    model = modelcontext(model)

    _log = logging.getLogger("pymc")
    _log.info("Initializing SMC sampler...")
    _log.info(f"Sampling {chains} chain{'s' if chains > 1 else ''} "
              f"in {cores} job{'s' if cores > 1 else ''}")

    params = (
        draws,
        kernel,
        start,
        model,
    )

    t1 = time.time()
    if cores > 1:
        pbar = progress_bar((), total=100, display=progressbar)
        pbar.update(0)
        pbars = [pbar] + [None] * (chains - 1)

        pool = mp.Pool(cores)

        # "manually" (de)serialize params before/after multiprocessing
        params = tuple(cloudpickle.dumps(p) for p in params)
        kernel_kwargs = {
            key: cloudpickle.dumps(value)
            for key, value in kernel_kwargs.items()
        }
        results = _starmap_with_kwargs(
            pool,
            _sample_smc_int,
            [(*params, random_seed[chain], chain, pbars[chain])
             for chain in range(chains)],
            repeat(kernel_kwargs),
        )
        results = tuple(cloudpickle.loads(r) for r in results)
        pool.close()
        pool.join()

    else:
        results = []
        pbar = progress_bar((), total=100 * chains, display=progressbar)
        pbar.update(0)
        for chain in range(chains):
            pbar.offset = 100 * chain
            pbar.base_comment = f"Chain: {chain+1}/{chains}"
            results.append(
                _sample_smc_int(*params, random_seed[chain], chain, pbar,
                                **kernel_kwargs))

    (
        traces,
        sample_stats,
        sample_settings,
    ) = zip(*results)

    trace = MultiTrace(traces)
    idata = None

    # Save sample_stats
    _t_sampling = time.time() - t1
    sample_settings_dict = sample_settings[0]
    sample_settings_dict["_t_sampling"] = _t_sampling

    sample_stats_dict = sample_stats[0]
    if chains > 1:
        # Collect the stat values from each chain in a single list
        for stat in sample_stats[0].keys():
            value_list = []
            for chain_sample_stats in sample_stats:
                value_list.append(chain_sample_stats[stat])
            sample_stats_dict[stat] = value_list

    if not return_inferencedata:
        for stat, value in sample_stats_dict.items():
            setattr(trace.report, stat, value)
        for stat, value in sample_settings_dict.items():
            setattr(trace.report, stat, value)
    else:
        for stat, value in sample_stats_dict.items():
            if chains > 1:
                # Different chains might have different numbers of iteration steps, leading
                # to a non-square `sample_stats` dataset; we cast to `object` to avoid
                # numpy's ragged-array deprecation warning.
                sample_stats_dict[stat] = np.array(value, dtype=object)
            else:
                sample_stats_dict[stat] = np.array(value)

        sample_stats = dict_to_dataset(
            sample_stats_dict,
            attrs=sample_settings_dict,
            library=pymc,
        )

        ikwargs = dict(model=model)
        if idata_kwargs is not None:
            ikwargs.update(idata_kwargs)
        idata = to_inference_data(trace, **ikwargs)
        idata = InferenceData(**idata, sample_stats=sample_stats)

    if compute_convergence_checks:
        if draws < 100:
            warnings.warn(
                "The number of samples is too small to check convergence reliably.",
                stacklevel=2,
            )
        else:
            if idata is None:
                idata = to_inference_data(trace, log_likelihood=False)
            trace.report._run_convergence_checks(idata, model)
    trace.report._log_summary()

    return idata if return_inferencedata else trace
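
# A minimal usage sketch, not from the original source: a toy model just to show the
# call; with return_inferencedata=True the result is an arviz.InferenceData whose
# sample_stats group also carries the SMC settings collected above.
import pymc as pm

with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("obs", mu, 1.0, observed=[0.1, -0.3, 0.2])
    idata = pm.sample_smc(draws=500, chains=2, cores=1, progressbar=False)

print(idata.groups())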
Example #15
def _226(obj: az.InferenceData) -> MonteCarloTensorFormat:
    ff = MonteCarloTensorFormat()
    obj.to_netcdf(str(ff))
    return ff
Example #16
def predict(
    mi: MaudInput,
    output_dir: str,
    idata_train: az.InferenceData,
) -> az.InferenceData:
    """Call CmdStanModel.sample for out of sample predictions.

    :param mi: a MaudInput object
    :param output_dir: directory where output will be saved
    :param idata_train: InferenceData object with posterior draws
    """
    model = cmdstanpy.CmdStanModel(
        stan_file=os.path.join(HERE, STAN_PROGRAM_RELATIVE_PATH_PREDICT),
        cpp_options=mi.config.cpp_options,
        stanc_options=mi.config.stanc_options,
    )
    set_up_output_dir(output_dir, mi)
    kinetic_parameters = [
        "keq",
        "km",
        "kcat",
        "dissociation_constant",
        "transfer_constant",
        "kcat_phos",
        "ki",
    ]
    posterior = idata_train.get("posterior")
    sample_stats = idata_train.get("sample_stats")
    assert posterior is not None
    assert sample_stats is not None
    chains = sample_stats["chain"]
    draws = sample_stats["draw"]
    dims = {
        "conc": ["experiment", "mic"],
        "conc_enzyme": ["experiment", "enzyme"],
        "flux": ["experiment", "reaction"],
    }
    for chain in chains:
        for draw in draws:
            inits = {
                par: (
                    posterior[par]
                    .sel(chain=chain, draw=draw)
                    .to_series()
                    .values
                )
                for par in kinetic_parameters
                if par in posterior.keys()
            }
            sample_args: dict = {
                "data": os.path.join(output_dir, "input_data_test.json"),
                "inits": inits,
                "output_dir": output_dir,
                "iter_warmup": 0,
                "iter_sampling": 1,
                "fixed_param": True,
                "show_progress": False,
            }
            if mi.config.cmdstanpy_config_predict is not None:
                sample_args = {
                    **sample_args,
                    **mi.config.cmdstanpy_config_predict,
                }
            mcmc_draw = model.sample(**sample_args)
            idata_draw = az.from_cmdstan(
                mcmc_draw.runset.csv_files,
                coords={
                    "experiment": [
                        e.id for e in mi.measurements.experiments if e.is_test
                    ],
                    "mic": [m.id for m in mi.kinetic_model.mics],
                    "enzyme": [e.id for e in mi.kinetic_model.enzymes],
                    "reaction": [r.id for r in mi.kinetic_model.reactions],
                },
                dims=dims,
            ).assign_coords(
                coords={"chain": [chain], "draw": [draw]},
                groups="posterior_groups",
            )
            if draw == 0:
                idata_chain = idata_draw.copy()
            else:
                idata_chain = az.concat(
                    [idata_chain, idata_draw], dim="draw", reset_dim=False
                )
        if chain == 0:
            out = idata_chain.copy()
        else:
            out = az.concat([out, idata_chain], dim="chain", reset_dim=False)
    return out
Example #17
def plot_posterior_predictive_checks(inference_object: az.InferenceData):
    """Plot posterior predictive checks of fitted model.

    :param inference_object: Inference object containing posterior predictive
        and observed data groups
    :type inference_object: az.InferenceData

    :returns: matplotlib axes figure
    """
    if "posterior_predictive" not in inference_object.groups():
        raise ValueError(
            "Must include posterior predictive values to perform PPC!")
    if "observed_data" not in inference_object.groups():
        raise ValueError("Must include observed data to perform PPC!")

    obs = inference_object.observed_data.transpose("tbl_sample", "feature")
    ppc = inference_object.posterior_predictive.transpose(
        "chain", "draw", "tbl_sample", "feature")
    ppc_median = ppc.median(["chain", "draw"])
    ppc_lower = ppc.quantile(0.025, ["chain", "draw"])
    ppc_upper = ppc.quantile(0.975, ["chain", "draw"])
    # ppc_in_ci = (
    #     (obs["observed"] <= ppc_upper["y_predict"])
    #     & (obs["observed"] >= ppc_lower["y_predict"])
    # )
    # pct_in_ci = ppc_in_ci.data.sum() / len(ppc_in_ci.data.ravel()) * 100

    obs_data = obs["observed"].data.reshape(-1)
    sort_indices = obs_data.argsort()
    obs_data = obs_data[sort_indices]
    ppc_median_data = ppc_median["y_predict"].data.reshape(-1)[sort_indices]
    ppc_lower_data = ppc_lower["y_predict"].data.reshape(-1)[sort_indices]
    ppc_upper_data = ppc_upper["y_predict"].data.reshape(-1)[sort_indices]

    fig, ax = plt.subplots(1, 1)
    x = np.arange(len(obs_data))
    ax.plot(x, obs_data, zorder=3, color="black")
    y_min, y_max = ax.get_ylim()
    ax.scatter(x=x, y=ppc_median_data, zorder=1, color="gray")
    for i, (lower, upper) in enumerate(zip(ppc_lower_data, ppc_upper_data)):
        ax.plot(  # credible interval
            [i, i],
            [lower, upper],
            zorder=0,
            color="lightgray",
        )
    ax.set_ylim([y_min, y_max])

    obs_legend_entry = Line2D([0], [0], color="black", linewidth=2)
    ci_legend_entry = Line2D([0], [0], color="lightgray", linewidth=2)
    ppc_median_legend_entry = Line2D([0], [0],
                                     color="gray",
                                     marker="o",
                                     linewidth=0)
    ax.legend(
        handles=[obs_legend_entry, ci_legend_entry, ppc_median_legend_entry],
        labels=["Observed", "95% Credible Interval", "Median"],
        bbox_to_anchor=[0.5, -0.2],
        loc="center",
        ncol=3)

    ax.set_ylabel("Count")
    ax.set_xlabel("Table Entry")
    plt.tight_layout()

    return ax
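
# A hypothetical usage sketch, not from the original module: fabricated toy data with
# the variable names and dimensions the plotting code above expects ("y_predict",
# "observed", "tbl_sample", "feature").
import numpy as np
import arviz as az

rng = np.random.default_rng(0)
idata = az.from_dict(
    posterior_predictive={"y_predict": rng.poisson(5, size=(2, 100, 10, 3))},
    observed_data={"observed": rng.poisson(5, size=(10, 3))},
    dims={"y_predict": ["tbl_sample", "feature"], "observed": ["tbl_sample", "feature"]},
    coords={"tbl_sample": np.arange(10), "feature": ["f1", "f2", "f3"]},
)
ax = plot_posterior_predictive_checks(idata)
ax.get_figure().savefig("ppc.png", dpi=150, bbox_inches="tight")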