Esempio n. 1
0
    def __getitem__(self, idx):
        """Index the trace by slice, point number, or variable.

        Parameters
        ----------
        idx: slice, int-like, variable/name, or (variable, slice) tuple
            * A ``slice`` is forwarded to ``self._slice``.
            * Anything accepted by ``int()`` selects a single point via
              ``self.point``.
            * A variable (or variable name) returns its values through
              ``self.get_values``; if no variable matches, a sampler
              statistic of that name is returned through
              ``self.get_sampler_stats``.
            * A ``(variable, slice)`` tuple additionally uses the slice's
              ``start`` as ``burn`` (default 0) and ``step`` as ``thin``
              (default 1); the slice's ``stop`` is ignored.

        Raises
        ------
        KeyError
            If the name matches neither a variable nor a sampler statistic.
        """
        if isinstance(idx, slice):
            return self._slice(idx)

        # Guard only the int() conversion: a ValueError/TypeError raised from
        # inside ``self.point`` must propagate rather than be mistaken for
        # "idx is a variable name".
        try:
            point_idx = int(idx)
        except (ValueError, TypeError):  # Passed variable or variable name.
            pass
        else:
            return self.point(point_idx)

        if isinstance(idx, tuple):
            var, vslice = idx
            burn = 0 if vslice.start is None else vslice.start
            thin = 1 if vslice.step is None else vslice.step
        else:
            var = idx
            burn, thin = 0, 1

        var = get_var_name(var)
        if var in self.varnames:
            if var in self.stat_names:
                warnings.warn(
                    "Attribute access on a trace object is ambigous. "
                    "Sampler statistic and model variable share a name. Use "
                    "trace.get_values or trace.get_sampler_stats.",
                    stacklevel=2,
                )
            return self.get_values(var, burn=burn, thin=thin)
        if var in self.stat_names:
            return self.get_sampler_stats(var, burn=burn, thin=thin)
        raise KeyError("Unknown variable %s" % var)
Esempio n. 2
0
    def get_values(self, varname, burn=0, thin=1, combine=True, chains=None, squeeze=True):
        """Get values from traces.

        Parameters
        ----------
        varname: str
            Variable (or variable name) to retrieve; normalized with
            ``get_var_name``.
        burn: int
            Forwarded to each chain's ``get_values`` as ``burn``.
        thin: int
            Forwarded to each chain's ``get_values`` as ``thin``.
        combine: bool
            If True, results from `chains` will be concatenated.
        chains: int or list of ints
            Chains to retrieve. If None, all chains are used. A single
            chain value can also be given.
        squeeze: bool
            Return a single array element if the resulting list of
            values only has one element. If False, the result will
            always be a list of arrays, even if `combine` is True.

        Returns
        -------
        A list of NumPy arrays or a single NumPy array (depending on
        `squeeze`).
        """
        if chains is None:
            chains = self.chains
        varname = get_var_name(varname)
        # Decide "single chain vs. collection of chains" before doing any
        # lookups. Catching TypeError around the per-chain retrieval would also
        # swallow TypeErrors raised by a chain's own get_values and then retry
        # with the whole collection as a dict key, masking the real error.
        try:
            chain_ids = list(chains)
        except TypeError:  # Single chain passed.
            chain_ids = [chains]
        results = [self._straces[chain].get_values(varname, burn, thin) for chain in chain_ids]
        return _squeeze_cat(results, combine, squeeze)
Esempio n. 3
0
 def __init__(self, vars, shared, blocked=True):
     """
     Parameters
     ----------
     vars: list of sampling value variables
     shared: dict of Aesara variable -> shared variable
         Stored on ``self.shared`` re-keyed by variable name (via
         ``get_var_name``); the shared-variable values are kept as-is.
     blocked: Boolean (default True)
     """
     self.vars = vars
     # Use a distinct loop name: reusing ``shared`` for the values shadowed the
     # parameter inside the comprehension and only worked because the
     # outermost iterable is evaluated before the loop variables are bound.
     self.shared = {get_var_name(var): shared_var for var, shared_var in shared.items()}
     self.blocked = blocked
Esempio n. 4
0
 def setup_class(cls):
     """Sample the test model once and cache per-variable draws on the class.

     Builds the model with ``cls.make_model()`` and, inside that model's
     context, the step method with ``cls.make_step()``. It then stores the
     result of ``pm.sample`` on ``cls.trace``. For every entry in
     ``cls.model.unobserved_RVs``, ``cls.samples[<variable name>]`` holds
     ``cls.trace.get_values(<variable>, burn=cls.burn)``.
     """
     super().setup_class()
     cls.model = cls.make_model()
     with cls.model:
         cls.step = cls.make_step()
         sampler_kwargs = dict(
             tune=cls.tune,
             step=cls.step,
             cores=cls.chains,
             return_inferencedata=False,
             compute_convergence_checks=False,
         )
         cls.trace = pm.sample(cls.n_samples, **sampler_kwargs)
     trace, burn = cls.trace, cls.burn
     cls.samples = {
         get_var_name(rv): trace.get_values(rv, burn=burn)
         for rv in cls.model.unobserved_RVs
     }
Esempio n. 5
0
    def get_parent_names(self, var: TensorVariable) -> Set[VarName]:
        """Return the names of named ancestors of ``var``.

        The walk starts from ``var.owner.inputs``. Unnamed nodes are expanded
        through the inputs of their ``Apply`` owner. A named node is collected
        (via ``get_var_name``), and its inputs are not handed to the walk.
        A variable without an owner (or with ``owner.inputs`` of None) yields
        an empty set.
        """
        owner = var.owner
        if owner is None or owner.inputs is None:
            return set()

        def _expand(node):
            # A named node is handed back unchanged instead of its inputs so
            # the walk does not descend past it (assumes ``walk`` skips nodes it
            # has already visited -- confirm against aesara's ``walk``).
            if node.name:
                return [node]
            producer = node.owner
            if not isinstance(producer, Apply):
                return []
            return reversed(producer.inputs)

        visited = walk(nodes=owner.inputs, expand=_expand)
        return {get_var_name(node) for node in visited if node.name}
Esempio n. 6
0
    def __getattr__(self, name):
        """Resolve unknown attributes as variable values or sampler statistics.

        Python only calls this after normal attribute lookup fails. The name
        is normalized with ``get_var_name`` and looked up first in
        ``self.varnames`` (returned via ``self.get_values``) and then in
        ``self.stat_names`` (returned via ``self.get_sampler_stats``). A
        warning is emitted when the name is both a variable and a statistic.

        Raises
        ------
        AttributeError
            If ``name`` is listed in ``self._attrs`` or matches neither a
            variable nor a sampler statistic.
        """
        # Avoid infinite recursion when called before __init__
        # variables are set up (e.g., when pickling).
        # NOTE(review): this relies on ``_attrs`` being resolvable without
        # __getattr__ (presumably a class attribute) -- confirm.
        if name in self._attrs:
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{name}'")

        name = get_var_name(name)
        if name in self.varnames:
            if name in self.stat_names:
                warnings.warn(
                    "Attribute access on a trace object is ambigous. "
                    "Sampler statistic and model variable share a name. Use "
                    "trace.get_values or trace.get_sampler_stats.",
                    stacklevel=2,
                )
            return self.get_values(name)
        if name in self.stat_names:
            return self.get_sampler_stats(name)
        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{name}'")
Esempio n. 7
0
    def get_parent_names(self, var: TensorVariable) -> Set[VarName]:
        """Return the names of the model variables that ``var`` depends on.

        The walk starts from ``var.owner.inputs`` and moves towards the graph
        inputs. A node whose name is in ``self._all_var_names`` is collected
        (via ``get_var_name``) and not expanded further. Any other node that
        has an ``Apply`` owner, including named intermediates that are not
        model variables, is expanded through that owner's inputs. A variable
        without an owner (or with ``owner.inputs`` of None) yields an empty
        set.
        """
        if var.owner is None or var.owner.inputs is None:
            return set()

        # Hoisted once: used both to stop the walk and to filter the results.
        model_var_names = set(self._all_var_names)

        def _expand(x):
            # Stop only at model variables. Stopping at *any* named node hid
            # the real parents behind named intermediates that are not model
            # variables: such nodes were neither reported nor expanded.
            # Returning [x] (rather than its inputs) keeps the walk from
            # descending past x.
            if x.name in model_var_names:
                return [x]
            if isinstance(x.owner, Apply):
                return reversed(x.owner.inputs)
            return []

        parents = {
            get_var_name(x)
            for x in walk(nodes=var.owner.inputs, expand=_expand)
            # Only consider nodes that are in the named model variables.
            if x.name and x.name in model_var_names
        }

        return parents
Esempio n. 8
0
 def flat_t(var):
     """Return the stored draws of ``var`` from ``trace`` as a 2-D array.

     The first axis is kept and all remaining axes are collapsed into one, so
     a 1-D array becomes a single column.
     """
     draws = trace[get_var_name(var)]
     n_rows = draws.shape[0]
     n_cols = np.prod(draws.shape[1:], dtype=int)
     return draws.reshape((n_rows, n_cols))