def __init__(self, straces):
    """Index single-chain traces by chain number and gather their warnings.

    Parameters
    ----------
    straces: iterable of single-chain trace objects
        Each object must expose a ``chain`` attribute. The iterable is
        traversed twice, so it must be re-iterable (e.g. a list).

    Raises
    ------
    ValueError
        If two traces report the same ``chain`` value.
    """
    by_chain = {}
    self._straces = by_chain
    for single in straces:
        chain_id = single.chain
        if chain_id in by_chain:
            raise ValueError("Chains are not unique.")
        by_chain[chain_id] = single

    # The report is only created once every chain number is known to be
    # unique; traces without a ``_warnings`` attribute are skipped.
    report = SamplerReport()
    self._report = report
    for single in straces:
        if not hasattr(single, "_warnings"):
            continue
        report._add_warnings(single._warnings, single.chain)
class MultiTrace:
    """Main interface for accessing values from MCMC results.

    The core method to select values is `get_values`. The method
    to select sampler statistics is `get_sampler_stats`. Both kinds of
    values can also be accessed by indexing the MultiTrace object.
    Indexing can behave in four ways:

    1. Indexing with a variable or variable name (str) returns all
       values for that variable, combining values for all chains.

       >>> trace[varname]

       Slicing after the variable name can be used to burn and thin
       the samples.

       >>> trace[varname, 1000:]

       For convenience during interactive use, values can also be
       accessed using the variable as an attribute.

       >>> trace.varname

    2. Indexing with an integer returns a dictionary with values for
       each variable at the given index (corresponding to a single
       sampling iteration).

    3. Slicing with a range returns a new trace with the number of draws
       corresponding to the range.

    4. Indexing with the name of a sampler statistic that is not also
       the name of a variable returns those values from all chains.
       If there is more than one sampler that provides that statistic,
       the values are concatenated along a new axis.

    For any methods that require a single trace (e.g., taking the length
    of the MultiTrace instance, which returns the number of draws), the
    trace with the highest chain number is always used. A MultiTrace
    without any chains has length 0 and no variable names.

    Attributes
    ----------
    nchains: int
        Number of chains in the `MultiTrace`.
    chains: `List[int]`
        List of chain indices, sorted in ascending order.
    report: SamplerReport
        Report on the sampling process.
    varnames: `List[str]`
        List of variable names in the trace(s)
    """

    def __init__(self, straces):
        """Index single-chain traces by chain number and collect warnings.

        Parameters
        ----------
        straces: iterable of single-chain trace objects
            Each object must expose a ``chain`` attribute. The iterable is
            traversed twice, so it must be re-iterable (e.g. a list).

        Raises
        ------
        ValueError
            If two traces report the same ``chain`` value.
        """
        self._straces = {}
        for strace in straces:
            if strace.chain in self._straces:
                raise ValueError("Chains are not unique.")
            self._straces[strace.chain] = strace

        self._report = SamplerReport()
        for strace in straces:
            if hasattr(strace, "_warnings"):
                self._report._add_warnings(strace._warnings, strace.chain)

    def __repr__(self):
        template = "<{}: {} chains, {} iterations, {} variables>"
        return template.format(
            self.__class__.__name__, self.nchains, len(self), len(self.varnames)
        )

    @property
    def nchains(self):
        """Number of chains held by this trace."""
        return len(self._straces)

    @property
    def chains(self):
        """Chain numbers in ascending order."""
        return sorted(self._straces)

    @property
    def report(self):
        """The sampler report attached to this trace."""
        return self._report

    def __iter__(self):
        # Explicitly unsupported: without this method Python would fall back
        # to iterating through ``__getitem__`` with integer indices.
        raise NotImplementedError

    def __getitem__(self, idx):
        """Dispatch indexing as described in the class docstring.

        Raises
        ------
        KeyError
            If `idx` names neither a variable nor a sampler statistic.
        """
        if isinstance(idx, slice):
            return self._slice(idx)

        try:
            return self.point(int(idx))
        except (ValueError, TypeError):  # Passed variable or variable name.
            pass

        if isinstance(idx, tuple):
            # ``trace[var, start::step]``: the slice start is the burn-in and
            # the slice step is the thinning interval.
            var, vslice = idx
            burn, thin = vslice.start, vslice.step
            if burn is None:
                burn = 0
            if thin is None:
                thin = 1
        else:
            var = idx
            burn, thin = 0, 1

        var = get_var_name(var)
        if var in self.varnames:
            if var in self.stat_names:
                warnings.warn(
                    "Attribute access on a trace object is ambigous. "
                    "Sampler statistic and model variable share a name. Use "
                    "trace.get_values or trace.get_sampler_stats."
                )
            return self.get_values(var, burn=burn, thin=thin)
        if var in self.stat_names:
            return self.get_sampler_stats(var, burn=burn, thin=thin)
        raise KeyError("Unknown variable %s" % var)

    # Names that ``__getattr__`` must never try to resolve as variables.
    _attrs = {
        "_straces",
        "varnames",
        "chains",
        "stat_names",
        "supports_sampler_stats",
        "_report",
    }

    def __getattr__(self, name):
        """Resolve unknown attributes as variable or sampler-statistic names.

        Raises
        ------
        AttributeError
            If `name` is neither a variable nor a sampler statistic.
        """
        # Avoid infinite recursion when called before __init__
        # variables are set up (e.g., when pickling).
        if name in self._attrs:
            raise AttributeError

        name = get_var_name(name)
        if name in self.varnames:
            if name in self.stat_names:
                warnings.warn(
                    "Attribute access on a trace object is ambigous. "
                    "Sampler statistic and model variable share a name. Use "
                    "trace.get_values or trace.get_sampler_stats."
                )
            return self.get_values(name)
        if name in self.stat_names:
            return self.get_sampler_stats(name)
        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{name}'"
        )

    def __len__(self):
        """Number of draws in the chain with the highest chain number."""
        if not self._straces:
            # No chains: report zero draws instead of failing on chains[-1].
            return 0
        chain = self.chains[-1]
        return len(self._straces[chain])

    @property
    def varnames(self):
        """Variable names of the chain with the highest chain number.

        The list object owned by that chain is returned (not a copy), so
        mutating it changes that chain's ``varnames``.
        """
        if not self._straces:
            # No chains: there are no variables to report.
            return []
        chain = self.chains[-1]
        return self._straces[chain].varnames

    @property
    def stat_names(self):
        """Set of sampler-statistic names available in the chains.

        Raises
        ------
        ValueError
            If the chains do not all carry the same ``sampler_vars``.
        """
        if not self._straces:
            return set()

        sampler_vars = [s.sampler_vars for s in self._straces.values()]
        if not all(svars == sampler_vars[0] for svars in sampler_vars):
            raise ValueError("Inividual chains contain different sampler stats")

        names = set()
        for trace in self._straces.values():
            if trace.sampler_vars is None:
                continue
            # One mapping of statistic name -> value per sampler.
            for stats in trace.sampler_vars:
                names.update(stats.keys())
        return names

    def add_values(self, vals, overwrite=False) -> None:
        """Add variables to traces.

        Parameters
        ----------
        vals: dict (str: array-like)
            The keys should be the names of the new variables. The values are
            expected to be array-like objects whose first axis holds the
            draws of all chains, chain after chain in ascending chain order
            (the order in which `get_values` concatenates chains). The length
            of each value should match the number of total samples already in
            the trace `(chains * iterations)`, otherwise a warning is raised
            and the subsequent reshape may fail with a ``ValueError``.
        overwrite: bool
            If `False` (default) a ValueError is raised if the variable already
            exists. Change to `True` to overwrite the values of variables

        Returns
        -------
            None.

        Raises
        ------
        ValueError
            If a variable name already exists and `overwrite` is False.
        """
        for k, v in vals.items():
            new_var = True
            if k in self.varnames:
                if not overwrite:
                    raise ValueError(f"Variable name {k} already exists.")
                self.varnames.remove(k)
                new_var = False

            # NOTE: ``self.varnames`` is the list owned by the chain with the
            # highest chain number, so only that list is updated here.
            self.varnames.append(k)

            chain_ids = self.chains
            n_chains, n_draws = len(chain_ids), len(self)
            l_samples = n_draws * n_chains
            l_v = len(v)
            if l_v != l_samples:
                warnings.warn("The length of the values you are trying to "
                              "add ({}) does not match the number ({}) of "
                              "total samples in the trace "
                              "(chains * iterations)".format(l_v, l_samples))

            # Accept any array-like, then split into one row per chain. Only
            # a trailing singleton axis is dropped: squeezing every axis would
            # also collapse the chain axis for single-chain traces (and the
            # draw axis for single-draw traces), so that indexing by chain
            # below would pick out a single draw instead of a chain.
            per_chain = np.asarray(v).reshape(n_chains, n_draws, -1)
            if per_chain.shape[-1] == 1:
                per_chain = per_chain[..., 0]

            # Walk the chains in ascending chain order so that row i goes to
            # the i-th chain as seen by `get_values`.
            for idx, chain_id in enumerate(chain_ids):
                chain = self._straces[chain_id]
                if new_var:
                    dummy = at.as_tensor_variable([], k)
                    chain.vars.append(dummy)
                chain.samples[k] = per_chain[idx]

    def remove_values(self, name):
        """remove variables from traces.

        Parameters
        ----------
        name: str
            Name of the variable to remove.

        Raises
        ------
        KeyError
            If the variable is not present.
        """
        varnames = self.varnames
        if name not in varnames:
            raise KeyError(f"Unknown variable {name}")
        self.varnames.remove(name)

        for chain in self._straces.values():
            # Build the filtered list first instead of removing entries from
            # ``chain.vars`` while iterating over it, which skips elements.
            kept = [va for va in chain.vars if va.name != name]
            if len(kept) != len(chain.vars):
                chain.vars[:] = kept  # in place: other references stay valid
                del chain.samples[name]

    def get_values(self, varname, burn=0, thin=1, combine=True, chains=None,
                   squeeze=True):
        """Get values from traces.

        Parameters
        ----------
        varname: str
        burn: int
        thin: int
        combine: bool
            If True, results from `chains` will be concatenated.
        chains: int or list of ints
            Chains to retrieve. If None, all chains are used. A single
            chain value can also be given.
        squeeze: bool
            Return a single array element if the resulting list of
            values only has one element. If False, the result will
            always be a list of arrays, even if `combine` is True.

        Returns
        -------
        A list of NumPy arrays or a single NumPy array (depending on
        `squeeze`).
        """
        if chains is None:
            chains = self.chains
        varname = get_var_name(varname)
        # Normalise `chains` up front (same approach as `get_sampler_stats`)
        # so that a TypeError raised while reading values is not mistaken for
        # "a single chain was passed".
        try:
            chains = iter(chains)
        except TypeError:  # Single chain passed.
            chains = [chains]
        results = [
            self._straces[chain].get_values(varname, burn, thin)
            for chain in chains
        ]
        return _squeeze_cat(results, combine, squeeze)

    def get_sampler_stats(self, stat_name, burn=0, thin=1, combine=True,
                          chains=None, squeeze=True):
        """Get sampler statistics from the trace.

        Parameters
        ----------
        stat_name: str
        burn: int
        thin: int
        combine: bool
            Passed to the same combining helper that `get_values` uses.
        chains: int or list of ints
            Chains to retrieve. If None, all chains are used. A single
            chain value can also be given.
        squeeze: bool
            Passed to the same combining helper that `get_values` uses.

        Returns
        -------
        The per-chain statistics, combined in the same way as in
        `get_values`. Each chain is queried without selecting a specific
        sampler (sampler index ``None``); if there is more than one sampler
        that provides this statistic, the per-chain result is expected to be
        a numpy array of shape (m, n), where `m` is the number of such
        samplers, and `n` is the number of samples.

        Raises
        ------
        KeyError
            If `stat_name` is not a known sampler statistic.
        """
        if stat_name not in self.stat_names:
            raise KeyError("Unknown sampler statistic %s" % stat_name)

        if chains is None:
            chains = self.chains
        try:
            chains = iter(chains)
        except TypeError:  # Single chain passed.
            chains = [chains]

        results = [
            self._straces[chain].get_sampler_stats(stat_name, None, burn, thin)
            for chain in chains
        ]
        return _squeeze_cat(results, combine, squeeze)

    def _slice(self, slice):
        """Return a new MultiTrace object sliced according to `slice`."""
        new_traces = [trace._slice(slice) for trace in self._straces.values()]
        trace = MultiTrace(new_traces)
        # Slice the report with the concrete (start, stop, step) so that it
        # matches the sliced draws.
        idxs = slice.indices(len(self))
        trace._report = self._report._slice(*idxs)
        return trace

    def point(self, idx, chain=None):
        """Return a dictionary of point values at `idx`.

        Parameters
        ----------
        idx: int
        chain: int
            If a chain is not given, the highest chain number is used.
        """
        if chain is None:
            chain = self.chains[-1]
        return self._straces[chain].point(idx)

    def points(self, chains=None):
        """Return an iterator over all or some of the sample points

        Parameters
        ----------
        chains: list of int or None
            The chains whose points should be included in the iterator. If
            chains is not given, include points from all chains.
        """
        if chains is None:
            chains = self.chains

        return itl.chain.from_iterable(self._straces[chain] for chain in chains)