def _validate_predictive_group(data: az.InferenceData, group: str): """Validate the predictive groups in data. Args: data (arviz.InferenceData): Inference data object. group (str): One of ['posterior', 'prior']. Raises: ValueError: If group is not valid. KeyError: If predictive is not in data, gives helpful suggestion. Returns: xarray.Dataset: Dataset corresponding to the predictive of group. """ if group == "posterior": key = "posterior_predictive" predictive = data.get(key, None) elif group == "prior": key = "prior_predictive" predictive = data.get(key, None) else: raise ValueError( f"Group '{group}' is not one of ['posterior', 'prior'].") if predictive is None: raise KeyError(f"Group '{key}' not in data. Consider using method " + "'Inference.{key}()' to sample the predictive.") return predictive
def predict(
    mi: MaudInput,
    output_dir: str,
    idata_train: az.InferenceData,
) -> az.InferenceData:
    """Call CmdStanModel.sample for out of sample predictions.

    For every (chain, draw) pair found in the training sample statistics,
    the predict Stan program is sampled once with initial values taken from
    that draw's kinetic parameters. The per-draw results are concatenated
    along "draw" within a chain and then along "chain".

    :param mi: a MaudInput object
    :param output_dir: directory where output will be saved
    :param idata_train: InferenceData object with posterior draws
    :return: InferenceData combining the predictions for all chains and draws
    :raises ValueError: if idata_train contains no chains or no draws
    """
    model = cmdstanpy.CmdStanModel(
        stan_file=os.path.join(HERE, STAN_PROGRAM_RELATIVE_PATH_PREDICT),
        cpp_options=mi.config.cpp_options,
        stanc_options=mi.config.stanc_options,
    )
    # NOTE(review): presumably this writes input_data_test.json into
    # output_dir, since that file is read below - confirm in the helper.
    set_up_output_dir(output_dir, mi)
    kinetic_parameters = [
        "keq",
        "km",
        "kcat",
        "dissociation_constant",
        "transfer_constant",
        "kcat_phos",
        "ki",
    ]
    posterior = idata_train.get("posterior")
    sample_stats = idata_train.get("sample_stats")
    assert posterior is not None
    assert sample_stats is not None
    chains = sample_stats["chain"]
    draws = sample_stats["draw"]
    dims = {
        "conc": ["experiment", "mic"],
        "conc_enzyme": ["experiment", "enzyme"],
        "flux": ["experiment", "reaction"],
    }

    # Everything below up to the loops is the same for every draw, so it is
    # computed once rather than per Stan run.
    init_parameters = [par for par in kinetic_parameters if par in posterior]
    coords = {
        "experiment": [
            e.id for e in mi.measurements.experiments if e.is_test
        ],
        "mic": [m.id for m in mi.kinetic_model.mics],
        "enzyme": [e.id for e in mi.kinetic_model.enzymes],
        "reaction": [r.id for r in mi.kinetic_model.reactions],
    }
    base_sample_args: dict = {
        "data": os.path.join(output_dir, "input_data_test.json"),
        "output_dir": output_dir,
        "iter_warmup": 0,
        "iter_sampling": 1,
        "fixed_param": True,
        "show_progress": False,
    }
    # User configuration is merged last so it overrides the defaults above.
    config_overrides: dict = {}
    if mi.config.cmdstanpy_config_predict is not None:
        config_overrides = mi.config.cmdstanpy_config_predict

    # First-iteration detection uses None sentinels instead of comparing the
    # coordinate value with 0, so inputs whose chain/draw labels do not
    # start at 0 (e.g. after slicing or thinning) are handled as well.
    out = None
    for chain in chains:
        idata_chain = None
        for draw in draws:
            inits = {
                par: (
                    posterior[par]
                    .sel(chain=chain, draw=draw)
                    .to_series()
                    .values
                )
                for par in init_parameters
            }
            sample_args: dict = {
                **base_sample_args,
                "inits": inits,
                **config_overrides,
            }
            mcmc_draw = model.sample(**sample_args)
            # Relabel the result with the training chain/draw it came from
            # so that concatenation with reset_dim=False keeps those labels.
            idata_draw = az.from_cmdstan(
                mcmc_draw.runset.csv_files,
                coords=coords,
                dims=dims,
            ).assign_coords(
                coords={"chain": [chain], "draw": [draw]},
                groups="posterior_groups",
            )
            if idata_chain is None:
                idata_chain = idata_draw.copy()
            else:
                idata_chain = az.concat(
                    [idata_chain, idata_draw], dim="draw", reset_dim=False
                )
        if idata_chain is None:
            # No draws at all: nothing to add for this chain.
            continue
        if out is None:
            out = idata_chain.copy()
        else:
            out = az.concat([out, idata_chain], dim="chain", reset_dim=False)
    if out is None:
        raise ValueError(
            "idata_train contains no chains or draws to predict from."
        )
    return out