Beispiel #1
0
 def posterior_to_xarray(self):
     """Convert the posterior to an xarray dataset."""
     # Transformed variables are internal to the sampler; export only the
     # user-facing names.
     names = get_default_varnames(self.trace.varnames,
                                  include_transformed=False)
     posterior_vars = {}
     warmup_vars = {}
     for name in names:
         if self.warmup_trace:
             warmup_vals = self.warmup_trace.get_values(
                 name, combine=False, squeeze=False)
             warmup_vars[name] = np.array(warmup_vals)
         if self.posterior_trace:
             posterior_vals = self.posterior_trace.get_values(
                 name, combine=False, squeeze=False)
             posterior_vars[name] = np.array(posterior_vals)
     # Both groups share identical dataset metadata.
     shared = dict(library=pymc, coords=self.coords, dims=self.dims,
                   attrs=self.attrs)
     return (
         dict_to_dataset(posterior_vars, **shared),
         dict_to_dataset(warmup_vars, **shared),
     )
Beispiel #2
0
 def log_likelihood_to_xarray(self):
     """Extract log likelihood and log_p data from PyMC trace."""
     if self.predictions or not self.log_likelihood:
         return None
     warn_msg = ("Could not compute log_likelihood, it will be omitted. "
                 "Check your model object or set log_likelihood=False")

     def best_effort_extract(trace):
         # Models without a pointwise log-likelihood raise TypeError;
         # warn and fall back to an empty group rather than failing.
         try:
             return self._extract_log_likelihood(trace)
         except TypeError:
             warnings.warn(warn_msg)
             return {}

     data = best_effort_extract(self.posterior_trace) if self.posterior_trace else {}
     data_warmup = best_effort_extract(self.warmup_trace) if self.warmup_trace else {}
     return (
         dict_to_dataset(
             data,
             library=pymc,
             dims=self.dims,
             coords=self.coords,
             skip_event_dims=True,
         ),
         dict_to_dataset(
             data_warmup,
             library=pymc,
             dims=self.dims,
             coords=self.coords,
             skip_event_dims=True,
         ),
     )
Beispiel #3
0
    def priors_to_xarray(self):
        """Convert prior samples (and if possible prior predictive too) to xarray."""
        if self.prior is None:
            return {"prior": None, "prior_predictive": None}
        # Prior-predictive variables are exactly those that appear in the
        # observations; everything else in the prior dict is a prior sample.
        if self.observations is None:
            predictive_names = None
            prior_names = list(self.prior.keys())
        else:
            predictive_names = list(self.observations.keys())
            prior_names = [
                name for name in self.prior.keys()
                if name not in predictive_names
            ]

        def _build_group(names):
            # A missing group (no observations) stays None.
            if names is None:
                return None
            # Prepend a chain axis so each entry is (chain, draw, ...).
            arrays = {name: np.expand_dims(self.prior[name], 0)
                      for name in names}
            return dict_to_dataset(
                arrays,
                library=pymc,
                coords=self.coords,
                dims=self.dims,
            )

        return {
            "prior": _build_group(prior_names),
            "prior_predictive": _build_group(predictive_names),
        }
Beispiel #4
0
 def observed_data_to_xarray(self):
     """Convert observed data to xarray."""
     # Prediction runs carry no observed-data group.
     if self.predictions:
         return None
     # default_dims=[] keeps observed data free of chain/draw dimensions.
     return dict_to_dataset(self.observations,
                            library=pymc,
                            coords=self.coords,
                            dims=self.dims,
                            default_dims=[])
Beispiel #5
0
 def warmup_sample_stats(self):
     """Bundle the warm-up sampler diagnostics into an xarray dataset."""
     info = self.raw.warmup_sampling_info
     # (output name, raw sampler key) pairs; "lp" is ArviZ's conventional
     # name for the potential energy.
     key_map = (
         ("lp", "potential_energy"),
         ("acceptance_probability", "acceptance_probability"),
         ("diverging", "is_divergent"),
         ("energy", "energy"),
         ("step_size", "step_size"),
         ("num_integration_steps", "num_integration_steps"),
     )
     stats = {out_name: info[raw_key] for out_name, raw_key in key_map}
     return dict_to_dataset(data=stats, library=mcx)
Beispiel #6
0
    def sample_stats_to_xarray(self):
        """Extract sample_stats from PyMC trace.

        Returns a ``(posterior_stats, warmup_stats)`` pair of datasets, with
        sampler statistic names rewritten to ArviZ conventions.
        """
        # PyMC stat name -> ArviZ convention.
        rename_key = {
            "model_logp": "lp",
            "mean_tree_accept": "acceptance_rate",
            "depth": "tree_depth",
            "tree_size": "n_steps",
        }
        # Fix: the original initialized ``data = {}`` twice; one dead
        # assignment removed.
        data = {}
        data_warmup = {}
        for stat in self.trace.stat_names:
            name = rename_key.get(stat, stat)
            if name == "tune":
                # "tune" only flags warm-up iterations; it is not a statistic.
                continue
            if self.warmup_trace:
                data_warmup[name] = np.array(
                    self.warmup_trace.get_sampler_stats(stat, combine=False))
            if self.posterior_trace:
                data[name] = np.array(
                    self.posterior_trace.get_sampler_stats(stat,
                                                           combine=False))

        return (
            dict_to_dataset(
                data,
                library=pymc,
                dims=None,
                coords=self.coords,
                attrs=self.attrs,
            ),
            dict_to_dataset(
                data_warmup,
                library=pymc,
                dims=None,
                coords=self.coords,
                attrs=self.attrs,
            ),
        )
Beispiel #7
0
def test_var_names_filter(var_args):
    """Test var_names filter with partial naming or regular expressions."""
    draws = np.random.randn(10)
    # Every variable shares the same sample array; only the names matter.
    names = ("alpha", "beta1", "beta2", "p1", "p2", "phi", "theta", "theta_t")
    data = dict_to_dataset({name: draws for name in names})
    var_names, expected, filter_vars = var_args
    assert _var_names(var_names, data, filter_vars) == expected
Beispiel #8
0
    def log_likelihood(self):
        """Return pointwise log-likelihoods as an xarray dataset.

        Uses cached values from ``self.raw.loglikelihoods`` when available;
        otherwise evaluates ``loglikelihood_contributions_fn`` over every
        sample with two nested ``jax.vmap`` calls (presumably mapping the
        two leading sample axes, e.g. chain and draw -- TODO confirm the
        layout of ``self.raw.samples``) and caches the result on
        ``self.raw``.
        """
        if self.raw.loglikelihoods:
            loglikelihoods = self.raw.loglikelihoods
        else:

            def compute_in(samples):
                # Evaluate the model's log-likelihood contributions for a
                # single (unbatched) sample dict.
                return self.loglikelihood_contributions_fn(**samples)

            def compute(samples):
                # Map over the leading axis of every array in the dict.
                in_axes = ({key: 0 for key in self.raw.samples},)
                return jax.vmap(compute_in, in_axes=in_axes)(samples)

            # Outer vmap: map `compute` over the first leading axis of the
            # raw samples, so both leading axes end up vectorized.
            in_axes = ({key: 0 for key in self.raw.samples},)
            loglikelihoods = jax.vmap(compute, in_axes=in_axes)(self.raw.samples)
            # Cache so repeated accesses skip the recomputation above.
            self.raw.loglikelihoods = loglikelihoods

        return dict_to_dataset(data=loglikelihoods, library=mcx)
Beispiel #9
0
 def translate_posterior_predictive_dict_to_xarray(self, dct) -> xr.Dataset:
     """Take Dict of variables to numpy ndarrays (samples) and translate into dataset."""
     data = {}
     for name, values in dct.items():
         dims = values.shape
         if dims[0] == self.nchains and dims[1] == self.ndraws:
             # Already laid out as (chain, draw, ...).
             data[name] = values
         elif dims[0] == self.nchains * self.ndraws:
             # Flattened draws: split back into (chain, draw, ...).
             data[name] = values.reshape(
                 (self.nchains, self.ndraws, *dims[1:]))
         else:
             # Unrecognized layout: treat as a single chain and warn.
             data[name] = np.expand_dims(values, 0)
             # pylint: disable=line-too-long
             _log.warning(
                 "posterior predictive variable %s's shape not compatible with number of chains and draws. "
                 "This can mean that some draws or even whole chains are not represented.",
                 name,
             )
     return dict_to_dataset(data, library=pymc, coords=self.coords,
                            dims=self.dims)
Beispiel #10
0
    def constant_data_to_xarray(self):
        """Convert constant data to xarray."""
        # For constant data we only want pm.Data containers
        # (TensorConstant/SharedVariable): anything the model samples,
        # observes, or computes is excluded, as are observations themselves.

        def _is_constant_data(name, var) -> bool:
            assert self.model is not None
            model_groups = (
                self.model.deterministics,
                self.model.observed_RVs,
                self.model.free_RVs,
                self.model.potentials,
                self.model.value_vars,
            )
            if any(var in group for group in model_groups):
                return False
            if self.observations is not None and name in self.observations:
                return False
            return isinstance(var, (Constant, SharedVariable))

        # pm.Data has no dedicated registry, so scan every named variable
        # and keep the ones that pass the exclusion checks above.
        candidates = {
            name: var
            for name, var in self.model.named_vars.items()
            if _is_constant_data(name, var)
        }  # type: Dict[str, Var]

        if not candidates:
            return None

        constant_data = {}
        for name, value in candidates.items():
            # Unwrap shared/constant containers into plain array data.
            if hasattr(value, "get_value"):
                value = value.get_value()
            elif hasattr(value, "data"):
                value = value.data
            constant_data[name] = value

        # default_dims=[] keeps constant data free of chain/draw dimensions.
        return dict_to_dataset(
            constant_data,
            library=pymc,
            coords=self.coords,
            dims=self.dims,
            default_dims=[],
        )
Beispiel #11
0
 def translate_posterior_predictive_dict_to_xarray(self, dct,
                                                   kind) -> xr.Dataset:
     """Take Dict of variables to numpy ndarrays (samples) and translate into dataset."""
     data = {}
     mismatched = []
     for name, values in dct.items():
         chain_draw_layout = ((values.shape[0] == self.nchains)
                              and (values.shape[1] == self.ndraws))
         if chain_draw_layout:
             data[name] = values
         else:
             # Fall back to a single pseudo-chain and remember the name so
             # one combined warning covers all offending variables.
             data[name] = np.expand_dims(values, 0)
             mismatched.append(name)
     if mismatched:
         warnings.warn(
             f"The shape of variables {', '.join(mismatched)} in {kind} group is not compatible "
             "with number of chains and draws. The automatic dimension naming might not have worked. "
             "This can also mean that some draws or even whole chains are not represented.",
             UserWarning,
         )
     return dict_to_dataset(data,
                            library=pymc,
                            coords=self.coords,
                            dims=self.dims)
Beispiel #12
0
 def warmup_posterior(self):
     """Wrap the raw warm-up samples in an xarray dataset."""
     return dict_to_dataset(data=self.raw.warmup_samples, library=mcx)