def convert_to_arviz_inference_data(
        traces, chain_stats, sample_stats_key=None):
    """Wrap chain outputs in an `arviz.InferenceData` container object.

    The `traces` and `chain_stats` arguments should correspond to a
    multiple-chain sampler output, i.e. the returned values from a
    `sample_chains` call.

    Args:
        traces (Dict[str, List[array]]): Trace arrays, with one entry per
            function in `trace_funcs` passed to the sampler method. Each
            entry consists of a list of arrays, one per chain, with the first
            axes of the arrays corresponding to the sampling (draw) index.
        chain_stats (Dict[str, List[array]]): Chain integration transition
            statistics as a dictionary with string keys describing the
            statistics recorded and values corresponding to a list of arrays
            with one array per chain and the first axis of the arrays
            corresponding to the sampling index.
        sample_stats_key (str): Optional. Key of the transition in
            `chain_stats` whose recorded statistics are used to populate the
            `sample_stats` group in the returned `InferenceData` object.

    Returns:
        arviz.InferenceData: An arviz data container with groups `posterior`
            and `sample_stats`, both instances of `xarray.Dataset`. The
            `posterior` group corresponds to the chain variable traces
            provided in the `traces` argument and the `sample_stats` group
            corresponds to the chain transition statistics passed in the
            `chain_stats` argument (if multiple transition statistics
            dictionaries are present the `sample_stats_key` argument should
            be specified to indicate which to use).
    """
    if sample_stats_key is not None and sample_stats_key not in chain_stats:
        raise ValueError(
            f'Specified `sample_stats_key` ({sample_stats_key}) does '
            f'not match any transition in `chain_stats`.')
    if sample_stats_key is not None:
        return arviz.InferenceData(
            posterior=arviz.dict_to_dataset(traces, library=mici),
            sample_stats=arviz.dict_to_dataset(
                chain_stats[sample_stats_key], library=mici))
    elif not isinstance(next(iter(chain_stats.values())), dict):
        # chain_stats dictionary value is not another dictionary, therefore
        # assume it corresponds to statistics for a single transition
        return arviz.InferenceData(
            posterior=arviz.dict_to_dataset(traces, library=mici),
            sample_stats=arviz.dict_to_dataset(chain_stats, library=mici))
    elif len(chain_stats) == 1:
        # single transition statistics dictionary in chain_stats, therefore
        # unambiguous to set sample_stats
        return arviz.InferenceData(
            posterior=arviz.dict_to_dataset(traces, library=mici),
            sample_stats=arviz.dict_to_dataset(
                # use next(iter(...)) rather than popitem() to avoid mutating
                # the caller's chain_stats dictionary
                next(iter(chain_stats.values())), library=mici))
    else:
        raise ValueError(
            '`sample_stats_key` must be specified as `chain_stats` '
            'contains multiple transition statistics dictionaries.')
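# Minimal usage sketch for `convert_to_arviz_inference_data`, assuming `arviz`,
# `mici` and `numpy` are installed; the trace/statistic names ('pos',
# 'accept_prob', 'n_step') and the synthetic arrays are illustrative stand-ins
# for the output of a real `sample_chains` call.
import numpy as np

n_chain, n_sample, n_dim = 4, 100, 2
traces = {'pos': [np.random.standard_normal((n_sample, n_dim))
                  for _ in range(n_chain)]}
chain_stats = {'accept_prob': [np.random.uniform(size=n_sample)
                               for _ in range(n_chain)],
               'n_step': [np.random.randint(1, 10, size=n_sample)
                          for _ in range(n_chain)]}
# values of chain_stats are per-statistic lists of arrays (not nested
# per-transition dictionaries) so no `sample_stats_key` is needed here
inference_data = convert_to_arviz_inference_data(traces, chain_stats)
print(inference_data.posterior)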
def sample_chains_arviz(self, n_sample, init_states, chain_var_funcs=None,
                        n_process=1):
    """Sample one or more Markov chains from given initial states.

    Performs a specified number of chain iterations (each of which may be
    composed of multiple individual Markov transitions), recording the
    outputs of functions of the sampled chain state after each iteration.
    The chains may be run in parallel across multiple independent processes
    or sequentially. Chain data is returned in an `arviz.InferenceData`
    container object.

    Args:
        n_sample (int): Number of samples to draw per chain.
        init_states (Iterable[HamiltonianState] or Iterable[array]):
            Initial chain states. Each state can be either an array
            specifying the state position component or a `HamiltonianState`
            instance. If an array is passed or the `mom` attribute of the
            state is not set, a momentum component will be independently
            sampled from its conditional distribution. One chain will be run
            for each state in the iterable sequence.
        chain_var_funcs (dict[str, callable]): Dictionary of functions which
            compute the chain variables to be recorded at each iteration,
            with each function being passed the current state and returning
            an array corresponding to the variable(s) to be stored. By
            default (or if set to `None`) a single function which returns
            the position component of the state is used. The dictionary keys
            are used to index the chain variable arrays in the returned data.
        n_process (int or None): Number of parallel processes to run chains
            over. If set to one then chains will be run sequentially,
            otherwise a `multiprocessing.Pool` object will be used to
            dynamically assign the chains across multiple processes. If set
            to `None` then the number of processes will default to the
            output of `os.cpu_count()`.

    Returns:
        arviz.InferenceData: An arviz data container with groups `posterior`
            and `sample_stats`, both instances of `xarray.Dataset`. The
            `posterior` group corresponds to the chain variable samples
            computed using the `chain_var_funcs` entries (with the data
            variable keys corresponding to the keys there). The
            `sample_stats` group corresponds to the statistics of the
            integration transition such as the acceptance probabilities and
            number of integrator steps.
    """
    chains, chain_stats = self.sample_chains(
        n_sample, init_states, chain_var_funcs, n_process)
    return arviz.InferenceData(
        posterior=arviz.dict_to_dataset(chains, library=hmc),
        sample_stats=arviz.dict_to_dataset(chain_stats, library=hmc))
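# Hypothetical usage sketch for `sample_chains_arviz`: `sampler` is a
# placeholder for an instance of the class this method belongs to, and the
# 'pos' chain variable function assumes the chain state exposes a `pos`
# attribute; both are assumptions rather than part of the documented API.
import numpy as np
import arviz

init_states = [np.random.standard_normal(2) for _ in range(4)]
inference_data = sampler.sample_chains_arviz(
    1000, init_states,
    chain_var_funcs={'pos': lambda state: state.pos},
    n_process=2)
print(arviz.summary(inference_data))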
def concatenate_inferences(
    inf_list: List[az.InferenceData],
    coords: dict,
    concatenation_name: str = "feature"
) -> az.InferenceData:
    """Concatenates multiple single feature fits into one object.

    :param inf_list: List of InferenceData objects for each feature
    :type inf_list: List[az.InferenceData]

    :param coords: Coordinates containing concatenation name labels
    :type coords: dict

    :param concatenation_name: Name of feature dimension used when
        concatenating, defaults to "feature"
    :type concatenation_name: str

    :returns: Combined InferenceData object
    :rtype: az.InferenceData
    """
    group_list = []
    group_list.append([x.posterior for x in inf_list])
    group_list.append([x.sample_stats for x in inf_list])
    if "log_likelihood" in inf_list[0].groups():
        group_list.append([x.log_likelihood for x in inf_list])
    if "posterior_predictive" in inf_list[0].groups():
        group_list.append([x.posterior_predictive for x in inf_list])

    po_ds = xr.concat(group_list[0], concatenation_name)
    ss_ds = xr.concat(group_list[1], concatenation_name)
    group_dict = {"posterior": po_ds, "sample_stats": ss_ds}

    if "log_likelihood" in inf_list[0].groups():
        ll_ds = xr.concat(group_list[2], concatenation_name)
        group_dict["log_likelihood"] = ll_ds

    if "posterior_predictive" in inf_list[0].groups():
        # index from the end so this still works when the log_likelihood
        # group is absent and posterior_predictive is at position 2
        pp_ds = xr.concat(group_list[-1], concatenation_name)
        group_dict["posterior_predictive"] = pp_ds

    all_group_inferences = []
    for group in group_dict:
        # Set concatenation dim coords
        group_ds = group_dict[group].assign_coords(
            {concatenation_name: coords[concatenation_name]}
        )

        group_inf = az.InferenceData(**{group: group_ds})  # hacky
        all_group_inferences.append(group_inf)

    return az.concat(*all_group_inferences)
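# Minimal sketch of `concatenate_inferences` on synthetic per-feature fits,
# assuming `arviz` and `numpy` are installed; `_fake_fit` and the taxon labels
# are hypothetical stand-ins for real single-feature model fits.
import numpy as np
import arviz as az


def _fake_fit(seed):
    rng = np.random.default_rng(seed)
    return az.from_dict(
        posterior={"beta": rng.standard_normal((2, 50))},
        sample_stats={"lp": rng.standard_normal((2, 50))},
    )


feature_fits = [_fake_fit(i) for i in range(3)]
coords = {"feature": ["taxon_1", "taxon_2", "taxon_3"]}
combined = concatenate_inferences(feature_fits, coords)
print(combined.posterior.dims)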
def convert_pyjags_samples_dict_to_arviz_inference_data(
        samples: tp.Dict[str, np.ndarray]) -> az.InferenceData:
    # pyjags returns a dictionary of numpy arrays with shape
    #     (parameter dimension, chain length, number of chains)
    # but arviz expects samples with shape
    #     (number of chains, chain length, parameter dimension)
    parameter_name_to_samples_map = {}

    for parameter_name, chains in samples.items():
        parameter_dimension, chain_length, number_of_chains = chains.shape
        if parameter_dimension == 1:
            # scalar parameter: drop the leading axis and transpose to
            # (number of chains, chain length)
            parameter_name_to_samples_map[parameter_name] = \
                chains[0, :, :].transpose()
        else:
            # vector parameter: store each component as a separate
            # (number of chains, chain length) entry
            for i in range(parameter_dimension):
                parameter_name_to_samples_map[f'{parameter_name}_{i+1}'] = \
                    chains[i, :, :].transpose()

    return az.InferenceData(
        posterior=az.data.base.dict_to_dataset(parameter_name_to_samples_map))
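# Minimal sketch of the pyjags conversion with synthetic samples, assuming
# `numpy` and `arviz` are installed; the variable names and shapes below are
# illustrative (a real dictionary would come from a pyjags sampling call).
import numpy as np
import arviz as az

n_dim, n_draws, n_chains = 3, 500, 4
samples = {
    'mu': np.random.standard_normal((1, n_draws, n_chains)),
    'beta': np.random.standard_normal((n_dim, n_draws, n_chains)),
}
idata = convert_pyjags_samples_dict_to_arviz_inference_data(samples)
print(az.summary(idata))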
def merge_inferences(inf_list, log_likelihood, posterior_predictive, coords,
                     concatenation_name='features'):
    group_list = []
    group_list.append(dask.persist(*[x.posterior for x in inf_list]))
    group_list.append(dask.persist(*[x.sample_stats for x in inf_list]))
    if log_likelihood is not None:
        group_list.append(dask.persist(*[x.log_likelihood for x in inf_list]))
    if posterior_predictive is not None:
        group_list.append(
            dask.persist(*[x.posterior_predictive for x in inf_list]))

    # trigger computation of all persisted groups in a single pass
    group_list = dask.compute(*group_list)

    po_ds = xr.concat(group_list[0], concatenation_name)
    ss_ds = xr.concat(group_list[1], concatenation_name)
    group_dict = {"posterior": po_ds, "sample_stats": ss_ds}

    if log_likelihood is not None:
        ll_ds = xr.concat(group_list[2], concatenation_name)
        group_dict["log_likelihood"] = ll_ds

    if posterior_predictive is not None:
        # index from the end so this still works when log_likelihood is None
        # and the posterior_predictive group sits at position 2
        pp_ds = xr.concat(group_list[-1], concatenation_name)
        group_dict["posterior_predictive"] = pp_ds

    all_group_inferences = []
    for group in group_dict:
        # Set concatenation dim coords
        group_ds = group_dict[group].assign_coords(
            {concatenation_name: coords[concatenation_name]})

        group_inf = az.InferenceData(**{group: group_ds})  # hacky
        all_group_inferences.append(group_inf)

    return az.concat(*all_group_inferences)
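# Small sketch of the dask persist/compute pattern used in `merge_inferences`,
# assuming `dask`, `xarray` and `numpy` are installed; the chunked synthetic
# datasets stand in for lazily-built per-feature posterior groups.
import dask
import numpy as np
import xarray as xr

datasets = [
    xr.Dataset(
        {"beta": (("chain", "draw"), np.random.standard_normal((2, 50)))}
    ).chunk()
    for _ in range(3)
]
persisted = dask.persist(*datasets)       # kick off computation of each graph
computed = dask.compute(*persisted)       # materialise in-memory datasets
merged = xr.concat(computed, "features")  # stack along the new feature dim
print(merged.dims)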
def sample_chains_arviz(self, n_sample, init_states, chain_var_funcs=None,
                        sample_stats_key=None, **kwargs):
    """Sample one or more Markov chains from given initial states.

    Performs a specified number of chain iterations (each of which may be
    composed of multiple individual Markov transitions), recording the
    outputs of functions of the sampled chain state after each iteration.
    The chains may be run in parallel across multiple independent processes
    or sequentially. Chain data is returned in an `arviz.InferenceData`
    container object.

    Args:
        n_sample (int): Number of samples to draw per chain.
        init_states (Iterable[ChainState] or Iterable[array]): Initial
            chain states. Each entry can be either an array specifying the
            state or a `ChainState` instance. One chain will be run for each
            state in the iterable sequence.
        chain_var_funcs (dict[str, callable]): Dictionary of functions which
            compute the chain variables to be recorded at each iteration,
            with each function being passed the current state and returning
            an array corresponding to the variable(s) to be stored. The
            dictionary keys are used to index the chain variable arrays in
            the returned data.
        sample_stats_key (str): Key of the transition whose recorded
            statistics are used to populate the `sample_stats` group in the
            returned `InferenceData` object.

    Kwargs:
        n_process (int or None): Number of parallel processes to run chains
            over. If set to one then chains will be run sequentially,
            otherwise a `multiprocessing.Pool` object will be used to
            dynamically assign the chains across multiple processes. If set
            to `None` then the number of processes will default to the
            output of `os.cpu_count()`.
        memmap_enabled (bool): Whether to memory-map arrays used to store
            chain data to files on disk to avoid excessive system memory
            usage for long chains and/or high memory chain states. The chain
            data is written to `.npy` files in the directory specified by
            `memmap_path` (or a temporary directory if not provided). These
            files persist after the termination of the function so should be
            manually deleted when no longer required.
        memmap_path (str): Path to directory to write memory-mapped chain
            data to. If not provided, a temporary directory will be created
            and the chain data written to files there.

    Returns:
        arviz.InferenceData: An arviz data container with groups `posterior`
            and `sample_stats`, both instances of `xarray.Dataset`. The
            `posterior` group corresponds to the chain variable samples
            computed using the `chain_var_funcs` entries (with the data
            variable keys corresponding to the keys there). The
            `sample_stats` group corresponds to the statistics of the
            transition indicated by the `sample_stats_key` argument (and is
            omitted if `sample_stats_key` is `None`).
    """
    if (sample_stats_key is not None
            and sample_stats_key not in self.transitions):
        raise ValueError(
            f'Specified `sample_stats_key` ({sample_stats_key}) does '
            f'not match any transition.')
    chains, chain_stats = self.sample_chains(
        n_sample, init_states, chain_var_funcs, **kwargs)
    if sample_stats_key is None:
        return arviz.InferenceData(
            posterior=arviz.dict_to_dataset(chains, library=hmc))
    else:
        return arviz.InferenceData(
            posterior=arviz.dict_to_dataset(chains, library=hmc),
            sample_stats=arviz.dict_to_dataset(
                chain_stats[sample_stats_key], library=hmc))
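# Hypothetical usage sketch for this `sample_chains_arviz` variant: `sampler`,
# the 'integration' transition name, the `pos` state attribute and the
# memory-mapping keyword arguments are placeholders (the kwargs are simply
# forwarded to `sample_chains`), not part of a documented API.
import numpy as np

init_states = [np.random.standard_normal(2) for _ in range(4)]
inference_data = sampler.sample_chains_arviz(
    1000, init_states,
    chain_var_funcs={'pos': lambda state: state.pos},
    sample_stats_key='integration',
    n_process=2, memmap_enabled=True, memmap_path='chain_data')
print(inference_data.sample_stats)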
def main(args):
    print("Loading data...")
    teams, df = load_data()
    train = df[df["split"] == "train"]

    print("Starting inference...")
    samples = bm.GlobalNoUTurnSampler().infer(
        queries=[
            alpha(),
            home(),
            sd_att(),
            sd_def(),
            attack(),
            defend(),
        ],
        observations={
            s1(): torch.tensor(train["score1"].values),
            s2(): torch.tensor(train["score2"].values),
        },
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        num_adaptive_samples=args.num_warmup,
    )

    samples = samples.to_xarray()
    fit = az.InferenceData(posterior=samples)

    print("Analyse posterior...")
    az.plot_forest(
        fit,
        backend="bokeh",
    )
    az.plot_trace(
        fit,
        backend="bokeh",
    )

    # Attack and defence
    quality = teams.copy()
    quality = quality.assign(
        attack=samples[attack()].mean(axis=(0, 1)),
        attacksd=samples[attack()].std(axis=(0, 1)),
        defend=samples[defend()].mean(axis=(0, 1)),
        defendsd=samples[defend()].std(axis=(0, 1)),
    )
    quality = quality.assign(
        attack_low=quality["attack"] - quality["attacksd"],
        attack_high=quality["attack"] + quality["attacksd"],
        defend_low=quality["defend"] - quality["defendsd"],
        defend_high=quality["defend"] + quality["defendsd"],
    )
    plot_quality(quality)

    # Predicted goals and table
    predict = df[df["split"] == "predict"]

    theta1 = (samples[alpha()].expand_dims("", axis=-1).values
              + samples[home()].expand_dims("", axis=-1).values
              + samples[attack()][:, :, predict["Home_id"]].values
              - samples[defend()][:, :, predict["Away_id"]].values)
    theta1 = torch.tensor(theta1.reshape(-1, theta1.shape[-1]))

    theta2 = (samples[alpha()].expand_dims("", axis=-1).values
              + samples[attack()][:, :, predict["Away_id"]].values
              - samples[defend()][:, :, predict["Home_id"]].values)
    theta2 = torch.tensor(theta2.reshape(-1, theta2.shape[-1]))

    score1 = np.array(dist.Poisson(torch.exp(theta1)).sample())
    score2 = np.array(dist.Poisson(torch.exp(theta2)).sample())

    predicted_full = predict.copy()
    predicted_full = predicted_full.assign(
        score1=score1.mean(axis=0).round(),
        score1error=score1.std(axis=0),
        score2=score2.mean(axis=0).round(),
        score2error=score2.std(axis=0),
    )
    # `DataFrame.append` was removed in pandas 2.0; `pd.concat` is the
    # equivalent (assumes pandas is imported as `pd` in this module)
    predicted_full = pd.concat(
        [train, predicted_full.drop(columns=["score1error", "score2error"])])

    print(score_table(df))
    print(score_table(predicted_full))
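# Self-contained sketch of the per-match rate construction used above, with
# synthetic arrays in place of posterior samples; all names and shapes
# (n_chains, n_draws, n_teams, home_id, away_id) are assumptions for
# illustration only.
import numpy as np

n_chains, n_draws, n_teams, n_matches = 2, 100, 4, 3
alpha_s = np.random.standard_normal((n_chains, n_draws))
home_s = np.random.standard_normal((n_chains, n_draws))
attack_s = np.random.standard_normal((n_chains, n_draws, n_teams))
defend_s = np.random.standard_normal((n_chains, n_draws, n_teams))
home_id = np.array([0, 2, 3])
away_id = np.array([1, 3, 0])

# log-rate for home goals: intercept + home advantage + attack of the home
# team minus defence of the away team, one column per match
theta1 = (alpha_s[..., None] + home_s[..., None]
          + attack_s[:, :, home_id] - defend_s[:, :, away_id])
score1_rate = np.exp(theta1.reshape(-1, n_matches))  # (samples, matches)
print(score1_rate.mean(axis=0))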