Пример #1
0
def _validate_predictive_group(data: az.InferenceData, group: str):
    """Validate the predictive groups in data.

    Args:
        data (arviz.InferenceData): Inference data object.
        group (str): One of ['posterior', 'prior'].

    Raises:
        ValueError: If group is not valid.
        KeyError: If predictive is not in data, gives helpful suggestion.

    Returns:
        xarray.Dataset: Dataset corresponding to the predictive of group.
    """
    if group == "posterior":
        key = "posterior_predictive"
        predictive = data.get(key, None)
    elif group == "prior":
        key = "prior_predictive"
        predictive = data.get(key, None)
    else:
        raise ValueError(
            f"Group '{group}' is not one of ['posterior', 'prior'].")

    if predictive is None:
        raise KeyError(f"Group '{key}' not in data. Consider using method " +
                       "'Inference.{key}()' to sample the predictive.")
    return predictive
Пример #2
0
def predict(
    mi: MaudInput,
    output_dir: str,
    idata_train: az.InferenceData,
) -> az.InferenceData:
    """Call CmdStanModel.sample for out of sample predictions.

    For every (chain, draw) pair in ``idata_train``'s sample_stats
    coordinates, the predict Stan program is run once with
    ``fixed_param=True``, zero warmup and a single sampling iteration. Each
    run is initialised at that draw's kinetic parameter values from the
    training posterior. Each run's output is relabelled with the original
    chain/draw coordinates. The runs are concatenated along ``draw`` within
    a chain and then along ``chain``.

    :param mi: a MaudInput object
    :param output_dir: directory where output will be saved
    :param idata_train: InferenceData object with posterior draws; must
        contain ``posterior`` and ``sample_stats`` groups
    :return: InferenceData combining the per-draw prediction runs
    :raises ValueError: if ``idata_train`` has no chains or no draws
    """
    model = cmdstanpy.CmdStanModel(
        stan_file=os.path.join(HERE, STAN_PROGRAM_RELATIVE_PATH_PREDICT),
        cpp_options=mi.config.cpp_options,
        stanc_options=mi.config.stanc_options,
    )
    set_up_output_dir(output_dir, mi)
    kinetic_parameters = [
        "keq",
        "km",
        "kcat",
        "dissociation_constant",
        "transfer_constant",
        "kcat_phos",
        "ki",
    ]
    posterior = idata_train.get("posterior")
    sample_stats = idata_train.get("sample_stats")
    assert posterior is not None
    assert sample_stats is not None
    chain_values = sample_stats["chain"].values
    draw_values = sample_stats["draw"].values
    if chain_values.size == 0 or draw_values.size == 0:
        raise ValueError(
            "idata_train contains no posterior draws to predict from."
        )
    # Only parameters present in the posterior can be used as inits.
    init_parameters = [p for p in kinetic_parameters if p in posterior.keys()]
    dims = {
        "conc": ["experiment", "mic"],
        "conc_enzyme": ["experiment", "enzyme"],
        "flux": ["experiment", "reaction"],
    }
    # Coordinates and the non-init sampler arguments do not depend on the
    # draw, so they are built once instead of on every iteration.
    coords = {
        "experiment": [e.id for e in mi.measurements.experiments if e.is_test],
        "mic": [m.id for m in mi.kinetic_model.mics],
        "enzyme": [e.id for e in mi.kinetic_model.enzymes],
        "reaction": [r.id for r in mi.kinetic_model.reactions],
    }
    # NOTE(review): input_data_test.json is presumably written by
    # set_up_output_dir above — confirm.
    static_args: dict = {
        "data": os.path.join(output_dir, "input_data_test.json"),
        "output_dir": output_dir,
        "iter_warmup": 0,
        "iter_sampling": 1,
        "fixed_param": True,
        "show_progress": False,
    }
    user_args = (
        mi.config.cmdstanpy_config_predict
        if mi.config.cmdstanpy_config_predict is not None
        else {}
    )
    out = None
    for chain in chain_values:
        idata_chain = None
        for draw in draw_values:
            inits = {
                par: (
                    posterior[par]
                    .sel(chain=chain, draw=draw)
                    .to_series()
                    .values
                )
                for par in init_parameters
            }
            # User-supplied cmdstanpy options come last so they override
            # the defaults (including "inits"), as before.
            sample_args = {**static_args, "inits": inits, **user_args}
            mcmc_draw = model.sample(**sample_args)
            idata_draw = az.from_cmdstan(
                mcmc_draw.runset.csv_files,
                coords=coords,
                dims=dims,
            ).assign_coords(
                coords={"chain": [chain], "draw": [draw]},
                groups="posterior_groups",
            )
            # Detect the first iteration explicitly: the chain/draw
            # coordinates are not guaranteed to start at 0.
            if idata_chain is None:
                idata_chain = idata_draw
            else:
                idata_chain = az.concat(
                    [idata_chain, idata_draw], dim="draw", reset_dim=False
                )
        if out is None:
            out = idata_chain
        else:
            out = az.concat([out, idata_chain], dim="chain", reset_dim=False)
    return out