def plot_priors(self, var_names=None, draws=1000):
    """Draw random samples from each prior and plot their marginal densities.

    The priors are re-created as random variables inside a fresh ``pm.Model()``
    context and then sampled with ``.random()``.

    Parameters
    ----------
    var_names : list of str, optional
        Names to plot. Matched against fixed-term names, random-term names and
        response-prior names (``"<y name>_<argument>"``). If ``None`` (default),
        every prior is plotted.
    draws : int
        Number of random draws taken from each prior. Defaults to 1000.

    Returns
    -------
    axes
        Whatever ``plot_posterior`` returns, or ``None`` when ``var_names``
        leaves no prior to plot.

    Raises
    ------
    ValueError
        If the model has not been built yet.
    """
    if not self.built:
        raise ValueError("Cannot plot priors until model is built!")

    with pm.Model():
        dists = []

        # Fixed effects: one prior per level of each predictor. Array-valued
        # prior arguments are indexed per level, cycling via ``i % v.size``.
        for fixed_term in self.fixed_terms.values():
            if var_names is not None and fixed_term.name not in var_names:
                continue
            for i, level in enumerate(fixed_term.levels):
                params = {
                    k: np.atleast_1d(v)[i % v.size] if isinstance(v, np.ndarray) else v
                    for k, v in fixed_term.prior.args.items()
                }
                dists.append(getattr(pm, fixed_term.prior.name)(level, **params))

        # Random effects: plot the prior placed on each term's ``sigma`` argument.
        for random_term in self.random_terms.values():
            if var_names is not None and random_term.name not in var_names:
                continue
            sigma_prior = random_term.prior.args["sigma"]
            dists.append(
                getattr(pm, sigma_prior.name)(random_term.name + "_sigma", **sigma_prior.args)
            )

        # Response: every argument of the response prior whose value is itself a
        # ``Prior`` gets its own panel. It is filtered by ``var_names`` like the
        # terms above, so restricting ``var_names`` no longer always adds these.
        for arg_name, arg_prior in self.y.prior.args.items():
            if not isinstance(arg_prior, Prior):
                continue
            name = "_".join([self.y.name, arg_name])
            if var_names is not None and name not in var_names:
                continue
            dists.append(getattr(pm, arg_prior.name)(name, **arg_prior.args))

        # Sample each prior, keyed by the random variable's name.
        priors_to_plot = {}
        for dist in dists:
            dist_ = dist.distribution if isinstance(dist, pm.model.FreeRV) else dist
            priors_to_plot[dist.name] = dist_.random(size=draws).flatten()

    # Nothing matched ``var_names``: there is nothing to plot.
    if not priors_to_plot:
        return None

    # TODO: plot_posterior is used here as a generic density plotter for the
    # prior samples; consider a dedicated plotting function.
    return plot_posterior(priors_to_plot, credible_interval=None, point_estimate=None)
def plot_priors(
    self,
    draws=5000,
    var_names=None,
    random_seed=None,
    figsize=None,
    textsize=None,
    hdi_prob=None,
    round_to=2,
    point_estimate="mean",
    kind="kde",
    bins=None,
    omit_offsets=True,
    omit_group_specific=True,
    ax=None,
    **kwargs,
):
    """Sample from the prior distribution and plot its marginals.

    Parameters
    ----------
    draws : int
        Number of draws to sample from the prior predictive distribution. Defaults to 5000.
    var_names : str or list of str
        Name, or list of names, of the variables whose priors are plotted. Defaults to all
        unobserved RVs of the backend model, excluding PyMC's transformed variables.
        Variables with a flat prior are never plotted; their names are logged instead.
    random_seed : int
        Seed for the random number generator.
    figsize : tuple
        Figure size. If ``None`` it will be defined automatically.
    textsize : float
        Text size scaling factor for labels, titles and lines. If ``None`` it will be
        autoscaled based on ``figsize``.
    hdi_prob : float
        Plots highest density interval for chosen percentage of density. Use ``'hide'`` to
        hide the highest density interval. Defaults to ``None``, which lets ArviZ apply its
        own default (0.94).
    round_to : int
        Controls formatting of floats. Defaults to 2 or the integer part, whichever is bigger.
    point_estimate : str
        Plot point estimate per variable. Values should be ``'mean'``, ``'median'``,
        ``'mode'`` or ``None``. Defaults to ``'mean'``.
    kind : str
        Type of plot to display (``'kde'`` or ``'hist'``). For discrete variables this
        argument is ignored and a histogram is always used.
    bins : integer or sequence or 'auto'
        Controls the number of bins, accepts the same keywords ``matplotlib.hist()`` does.
        Only works if ``kind == hist``. If ``None`` (default) it will use ``auto`` for
        continuous variables and ``range(xmin, xmax + 1)`` for discrete variables.
    omit_offsets : bool
        Whether to omit variables whose name ends in ``"_offset"``. Defaults to ``True``.
    omit_group_specific : bool
        Whether to omit the group-specific terms (the names in
        ``self.group_specific_terms``). Defaults to ``True``.
    ax : numpy array-like of matplotlib axes or bokeh figures
        A 2D array of locations into which to plot the densities. If not supplied, ArviZ
        will create its own array of plot areas (and return it).
    **kwargs
        Forwarded to ArviZ's ``plot_posterior``, which passes them on to ``plt.hist()`` or
        ``plt.plot()`` depending on the value of ``kind``.

    Returns
    -------
    axes : matplotlib axes or bokeh figures
        ``None`` when no variable is left to plot after filtering.

    Raises
    ------
    ValueError
        If the model has not been built yet.
    """
    if not self.built:
        raise ValueError(
            "Cannot plot priors until model is built!! "
            "Call .build() to build the model or .fit() to build and sample from the posterior."
        )

    # A single name was given as a plain string. Without this, the filters below
    # would iterate over it character by character.
    if isinstance(var_names, str):
        var_names = [var_names]

    # Split the backend's unobserved RVs into flat-prior ones (never plotted) and the rest.
    unobserved_rvs_names = []
    flat_rvs = []
    for unobserved in self.backend.model.unobserved_RVs:
        if "Flat" in str(unobserved):
            flat_rvs.append(unobserved.name)
        else:
            unobserved_rvs_names.append(unobserved.name)

    if var_names is None:
        var_names = pm.util.get_default_varnames(unobserved_rvs_names, include_transformed=False)
    else:
        # Only report flat RVs the caller actually asked for.
        flat_rvs = [fv for fv in flat_rvs if fv in var_names]
        var_names = [vn for vn in var_names if vn not in flat_rvs]

    if flat_rvs:
        _log.info(
            "Variables %s have flat priors, and hence they are not plotted", ", ".join(flat_rvs)
        )

    if omit_offsets:
        var_names = [vn for vn in var_names if not vn.endswith("_offset")]

    if omit_group_specific:
        omitted = set(self.group_specific_terms)
        var_names = [vn for vn in var_names if vn not in omitted]

    if not var_names:
        return None

    pps = self.prior_predictive(draws=draws, var_names=var_names, random_seed=random_seed)
    return plot_posterior(
        pps,
        group="prior",
        figsize=figsize,
        textsize=textsize,
        hdi_prob=hdi_prob,
        round_to=round_to,
        point_estimate=point_estimate,
        kind=kind,
        bins=bins,
        ax=ax,
        **kwargs,
    )