예제 #1
0
    def __init__(self, straces):
        """Collect per-chain traces and gather their sampler warnings.

        Parameters
        ----------
        straces: iterable of trace objects
            Each item must expose a ``chain`` attribute. Items that also
            carry a ``_warnings`` attribute have those warnings added to
            the sampler report.

        Raises
        ------
        ValueError
            If two traces report the same chain number.
        """
        # The input is walked twice below. Materialize it once so that a
        # one-shot iterator (e.g. a generator) is not already exhausted when
        # the warnings are collected.
        straces = list(straces)

        self._straces = {}
        for strace in straces:
            if strace.chain in self._straces:
                raise ValueError("Chains are not unique.")
            self._straces[strace.chain] = strace

        self._report = SamplerReport()
        for strace in straces:
            if hasattr(strace, "_warnings"):
                self._report._add_warnings(strace._warnings, strace.chain)
예제 #2
0
class MultiTrace:
    """Main interface for accessing values from MCMC results.

    The core method to select values is `get_values`. The method
    to select sampler statistics is `get_sampler_stats`. Both kinds of
    values can also be accessed by indexing the MultiTrace object.
    Indexing can behave in four ways:

    1. Indexing with a variable or variable name (str) returns all
       values for that variable, combining values for all chains.

       >>> trace[varname]

       Slicing after the variable name can be used to burn and thin
       the samples.

       >>> trace[varname, 1000:]

       For convenience during interactive use, values can also be
       accessed using the variable as an attribute.

       >>> trace.varname

    2. Indexing with an integer returns a dictionary with values for
       each variable at the given index (corresponding to a single
       sampling iteration).

    3. Slicing with a range returns a new trace with the number of draws
       corresponding to the range.

    4. Indexing with the name of a sampler statistic that is not also
       the name of a variable returns those values from all chains.
       If there is more than one sampler that provides that statistic,
       the values are concatenated along a new axis.

    For any methods that require a single trace (e.g., taking the length
    of the MultiTrace instance, which returns the number of draws), the
    trace with the highest chain number is always used.

    Attributes
    ----------
    nchains: int
        Number of chains in the `MultiTrace`.
    chains: `List[int]`
        List of chain indices
    report: SamplerReport
        Report on the sampling process.
    varnames: `List[str]`
        List of variable names in the trace(s)
    """
    def __init__(self, straces):
        """Collect per-chain traces and gather their sampler warnings.

        Parameters
        ----------
        straces: iterable of trace objects
            Each item must expose a ``chain`` attribute. Items that also
            carry a ``_warnings`` attribute have those warnings added to
            the sampler report.

        Raises
        ------
        ValueError
            If two traces report the same chain number.
        """
        # The input is walked twice below. Materialize it once so that a
        # one-shot iterator (e.g. a generator) is not already exhausted when
        # the warnings are collected.
        straces = list(straces)

        self._straces = {}
        for strace in straces:
            if strace.chain in self._straces:
                raise ValueError("Chains are not unique.")
            self._straces[strace.chain] = strace

        self._report = SamplerReport()
        for strace in straces:
            if hasattr(strace, "_warnings"):
                self._report._add_warnings(strace._warnings, strace.chain)

    def __repr__(self):
        template = "<{}: {} chains, {} iterations, {} variables>"
        return template.format(self.__class__.__name__, self.nchains,
                               len(self), len(self.varnames))

    @property
    def nchains(self):
        return len(self._straces)

    @property
    def chains(self):
        return list(sorted(self._straces.keys()))

    @property
    def report(self):
        return self._report

    def __iter__(self):
        raise NotImplementedError

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            return self._slice(idx)

        try:
            return self.point(int(idx))
        except (ValueError, TypeError):  # Passed variable or variable name.
            pass

        if isinstance(idx, tuple):
            var, vslice = idx
            burn, thin = vslice.start, vslice.step
            if burn is None:
                burn = 0
            if thin is None:
                thin = 1
        else:
            var = idx
            burn, thin = 0, 1

        var = get_var_name(var)
        if var in self.varnames:
            if var in self.stat_names:
                warnings.warn(
                    "Attribute access on a trace object is ambigous. "
                    "Sampler statistic and model variable share a name. Use "
                    "trace.get_values or trace.get_sampler_stats.")
            return self.get_values(var, burn=burn, thin=thin)
        if var in self.stat_names:
            return self.get_sampler_stats(var, burn=burn, thin=thin)
        raise KeyError("Unknown variable %s" % var)

    # Names that __getattr__ refuses to resolve as variables or sampler
    # statistics: it raises AttributeError for them immediately.
    _attrs = {
        "_straces", "varnames", "chains", "stat_names",
        "supports_sampler_stats", "_report"
    }

    def __getattr__(self, name):
        # Avoid infinite recursion when called before __init__
        # variables are set up (e.g., when pickling).
        if name in self._attrs:
            raise AttributeError

        name = get_var_name(name)
        if name in self.varnames:
            if name in self.stat_names:
                warnings.warn(
                    "Attribute access on a trace object is ambigous. "
                    "Sampler statistic and model variable share a name. Use "
                    "trace.get_values or trace.get_sampler_stats.")
            return self.get_values(name)
        if name in self.stat_names:
            return self.get_sampler_stats(name)
        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{name}'")

    def __len__(self):
        chain = self.chains[-1]
        return len(self._straces[chain])

    @property
    def varnames(self):
        chain = self.chains[-1]
        return self._straces[chain].varnames

    @property
    def stat_names(self):
        if not self._straces:
            return set()
        sampler_vars = [s.sampler_vars for s in self._straces.values()]
        if not all(svars == sampler_vars[0] for svars in sampler_vars):
            raise ValueError(
                "Inividual chains contain different sampler stats")
        names = set()
        for trace in self._straces.values():
            if trace.sampler_vars is None:
                continue
            for vars in trace.sampler_vars:
                names.update(vars.keys())
        return names

    def add_values(self, vals, overwrite=False) -> None:
        """Add variables to traces.

        Parameters
        ----------
        vals: dict (str: array-like)
             The keys should be the names of the new variables. The values are expected to be
             array-like objects. For traces with more than one chain the length of each value
             should match the number of total samples already in the trace `(chains * iterations)`,
             otherwise a warning is raised.
        overwrite: bool
            If `False` (default) a ValueError is raised if the variable already exists.
            Change to `True` to overwrite the values of variables

        Returns
        -------
            None.
        """
        for k, v in vals.items():
            new_var = 1
            if k in self.varnames:
                if overwrite:
                    self.varnames.remove(k)
                    new_var = 0
                else:
                    raise ValueError(f"Variable name {k} already exists.")

            self.varnames.append(k)

            chains = self._straces
            l_samples = len(self) * len(self.chains)
            l_v = len(v)
            if l_v != l_samples:
                warnings.warn("The length of the values you are trying to "
                              "add ({}) does not match the number ({}) of "
                              "total samples in the trace "
                              "(chains * iterations)".format(l_v, l_samples))

            v = np.squeeze(v.reshape(len(chains), len(self), -1))

            for idx, chain in enumerate(chains.values()):
                if new_var:
                    dummy = at.as_tensor_variable([], k)
                    chain.vars.append(dummy)
                chain.samples[k] = v[idx]

    def remove_values(self, name):
        """remove variables from traces.

        Parameters
        ----------
        name: str
            Name of the variable to remove. Raises KeyError if the variable is not present
        """
        varnames = self.varnames
        if name not in varnames:
            raise KeyError(f"Unknown variable {name}")
        self.varnames.remove(name)
        chains = self._straces
        for chain in chains.values():
            for va in chain.vars:
                if va.name == name:
                    chain.vars.remove(va)
                    del chain.samples[name]

    def get_values(self,
                   varname,
                   burn=0,
                   thin=1,
                   combine=True,
                   chains=None,
                   squeeze=True):
        """Get values from traces.

        Parameters
        ----------
        varname: str
        burn: int
        thin: int
        combine: bool
            If True, results from `chains` will be concatenated.
        chains: int or list of ints
            Chains to retrieve. If None, all chains are used. A single
            chain value can also be given.
        squeeze: bool
            Return a single array element if the resulting list of
            values only has one element. If False, the result will
            always be a list of arrays, even if `combine` is True.

        Returns
        -------
        A list of NumPy arrays or a single NumPy array (depending on
        `squeeze`).
        """
        if chains is None:
            chains = self.chains
        # Decide "one chain vs. several" up front. Doing it with a try/except
        # around the per-chain calls would mistake any TypeError raised
        # inside a chain's `get_values` for "single chain passed".
        try:
            chain_ids = list(chains)
        except TypeError:  # Single chain passed.
            chain_ids = [chains]

        varname = get_var_name(varname)
        results = [
            self._straces[chain].get_values(varname, burn, thin)
            for chain in chain_ids
        ]
        return _squeeze_cat(results, combine, squeeze)

    def get_sampler_stats(self,
                          stat_name,
                          burn=0,
                          thin=1,
                          combine=True,
                          chains=None,
                          squeeze=True):
        """Get sampler statistics from the trace.

        Parameters
        ----------
        stat_name: str
        sampler_idx: int or None
        burn: int
        thin: int

        Returns
        -------
        If the `sampler_idx` is specified, return the statistic with
        the given name in a numpy array. If it is not specified and there
        is more than one sampler that provides this statistic, return
        a numpy array of shape (m, n), where `m` is the number of
        such samplers, and `n` is the number of samples.
        """
        if stat_name not in self.stat_names:
            raise KeyError("Unknown sampler statistic %s" % stat_name)

        if chains is None:
            chains = self.chains
        try:
            chains = iter(chains)
        except TypeError:
            chains = [chains]

        results = [
            self._straces[chain].get_sampler_stats(stat_name, None, burn, thin)
            for chain in chains
        ]
        return _squeeze_cat(results, combine, squeeze)

    def _slice(self, slice):
        """Build a new MultiTrace whose chains and report are cut by `slice`."""
        sliced = MultiTrace(
            [strace._slice(slice) for strace in self._straces.values()])
        # Resolve the slice against this trace's length before handing the
        # concrete (start, stop, step) to the report.
        start, stop, step = slice.indices(len(self))
        sliced._report = self._report._slice(start, stop, step)
        return sliced

    def point(self, idx, chain=None):
        """Return a dictionary of point values at `idx`.

        Parameters
        ----------
        idx: int
        chain: int
            If a chain is not given, the highest chain number is used.
        """
        if chain is None:
            chain = self.chains[-1]
        return self._straces[chain].point(idx)

    def points(self, chains=None):
        """Return an iterator over all or some of the sample points

        Parameters
        ----------
        chains: list of int or N
            The chains whose points should be inlcuded in the iterator.  If
            chains is not given, include points from all chains.
        """
        if chains is None:
            chains = self.chains

        return itl.chain.from_iterable(self._straces[chain]
                                       for chain in chains)