Пример #1
0
def convert_str_to_rv_dict(
    model, start: StartDict
) -> Dict[TensorVariable, Optional[Union[np.ndarray, Variable, str]]]:
    """Re-key a user-provided start dict by random-variable tensors.

    Each key of ``start`` is handled as follows:

    * a non-string key is assumed to already be the variable and is kept as-is;
    * a plain variable name is resolved with ``model[name]``;
    * a transformed variable name is resolved to its untransformed variable,
      and the value is mapped back through that variable's
      ``transform.backward`` (called with the variable's owner inputs).

    TODO: Deprecate this functionality and only accept TensorVariables as keys
    """
    resolved = {}
    for name_or_rv, value in start.items():
        if not isinstance(name_or_rv, str):
            # Already keyed by the variable itself; nothing to look up.
            resolved[name_or_rv] = value
            continue

        if not is_transformed_name(name_or_rv):
            resolved[model[name_or_rv]] = value
            continue

        # Transformed name: find the untransformed RV and undo the transform
        # so the stored initval lives in the untransformed space.
        rv = model[get_untransformed_name(name_or_rv)]
        transform = model.rvs_to_values[rv].tag.transform
        resolved[rv] = transform.backward(value, *rv.owner.inputs)
    return resolved
Пример #2
0
    def _run_convergence_checks(self, idata: arviz.InferenceData, model):
        """Run R-hat and effective-sample-size diagnostics on ``idata``.

        Problems are reported through ``self._add_warnings`` as
        ``SamplerWarning`` objects; nothing is returned. The computed
        diagnostics are also stored on ``self._ess`` and ``self._rhat``.

        Reported conditions:

        * no ``posterior`` group, or a single chain -> ``"info"`` warning and
          the diagnostics are skipped;
        * fewer than 4 chains -> ``"info"`` advisory, diagnostics still run;
        * max R-hat > 1.01 -> ``"info"`` convergence warning;
        * min ESS per chain < 100 -> ``"error"`` convergence warning.

        Parameters
        ----------
        idata : arviz.InferenceData
            Sampling result to diagnose.
        model
            Model whose free random variables are diagnosed. A transformed
            variable is reported under its untransformed name when that name
            belongs to a free RV or deterministic of the model.
        """
        if not hasattr(idata, "posterior"):
            msg = "No posterior samples. Unable to run convergence checks"
            warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info", None,
                                  None, None)
            self._add_warnings([warn])
            return

        n_chains = idata.posterior.sizes["chain"]
        if n_chains == 1:
            msg = ("Only one chain was sampled, this makes it impossible to "
                   "run some convergence checks")
            warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info")
            self._add_warnings([warn])
            return

        if n_chains < 4:
            # 2-3 chains are enough to compute rhat and ess, so this is only an
            # advisory: previously the method returned here and silently
            # skipped every convergence diagnostic for 2-3 chains.
            msg = (
                "We recommend running at least 4 chains for robust computation of "
                "convergence diagnostics")
            warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info")
            self._add_warnings([warn])

        valid_name = {rv.name for rv in model.free_RVs + model.deterministics}
        varnames = []
        for rv in model.free_RVs:
            rv_name = rv.name
            if is_transformed_name(rv_name):
                # Prefer the untransformed name when the model defines it.
                rv_name2 = get_untransformed_name(rv_name)
                rv_name = rv_name2 if rv_name2 in valid_name else rv_name
            if rv_name in idata.posterior:
                varnames.append(rv_name)

        if not varnames:
            # Nothing to diagnose; max()/min() over an empty diagnostics
            # result would raise ValueError below.
            return

        self._ess = ess = arviz.ess(idata, var_names=varnames)
        self._rhat = rhat = arviz.rhat(idata, var_names=varnames)

        warnings = []
        rhat_max = max(val.max() for val in rhat.values())
        if rhat_max > 1.01:
            msg = ("The rhat statistic is larger than 1.01 for some "
                   "parameters. This indicates problems during sampling. "
                   "See https://arxiv.org/abs/1903.08008 for details")
            warn = SamplerWarning(WarningType.CONVERGENCE,
                                  msg,
                                  "info",
                                  extra=rhat)
            warnings.append(warn)

        eff_min = min(val.min() for val in ess.values())
        eff_per_chain = eff_min / n_chains
        if eff_per_chain < 100:
            msg = (
                "The effective sample size per chain is smaller than 100 for some parameters. "
                " A higher number is needed for reliable rhat and ess computation. "
                "See https://arxiv.org/abs/1903.08008 for details")
            warn = SamplerWarning(WarningType.CONVERGENCE,
                                  msg,
                                  "error",
                                  extra=ess)
            warnings.append(warn)

        self._add_warnings(warnings)
Пример #3
0
    def _run_convergence_checks(self, idata: arviz.InferenceData, model):
        """Run tiered R-hat and effective-sample-size checks on ``idata``.

        Findings are reported through ``self._add_warnings`` as
        ``SamplerWarning`` objects; nothing is returned. The computed
        diagnostics are also stored on ``self._ess`` and ``self._rhat``.

        Parameters
        ----------
        idata : arviz.InferenceData
            Sampling result. Checks are skipped (with an ``"info"`` warning)
            when it has no ``posterior`` group or only a single chain.
        model
            Model whose free random variables are diagnosed. A transformed
            variable is looked up under its untransformed name when that name
            belongs to a free RV or deterministic of the model.

        Notes
        -----
        Severity tiers (only the most severe matching tier of each kind is
        reported):

        * max R-hat > 1.4 -> ``"error"``; > 1.2 -> ``"warn"``;
          > 1.05 -> ``"info"``.
        * min ESS < 200 with at least 500 total draws -> ``"error"``;
          otherwise min ESS below 10% of total draws -> ``"warn"``;
          below 25% -> ``"info"``.
        """
        if not hasattr(idata, "posterior"):
            msg = "No posterior samples. Unable to run convergence checks"
            warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info", None,
                                  None, None)
            self._add_warnings([warn])
            return

        sizes = idata.posterior.sizes
        if sizes["chain"] == 1:
            msg = ("Only one chain was sampled, this makes it impossible to "
                   "run some convergence checks")
            warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info")
            self._add_warnings([warn])
            return

        valid_name = {rv.name for rv in model.free_RVs + model.deterministics}
        varnames = []
        for rv in model.free_RVs:
            rv_name = rv.name
            if is_transformed_name(rv_name):
                # Prefer the untransformed name when the model defines it.
                rv_name2 = get_untransformed_name(rv_name)
                rv_name = rv_name2 if rv_name2 in valid_name else rv_name
            if rv_name in idata.posterior:
                varnames.append(rv_name)

        if not varnames:
            # Nothing to diagnose; max()/min() over an empty diagnostics
            # result would raise ValueError below.
            return

        self._ess = ess = arviz.ess(idata, var_names=varnames)
        self._rhat = rhat = arviz.rhat(idata, var_names=varnames)

        warnings = []
        rhat_max = max(val.max() for val in rhat.values())
        if rhat_max > 1.4:
            msg = ("The rhat statistic is larger than 1.4 for some "
                   "parameters. The sampler did not converge.")
            warn = SamplerWarning(WarningType.CONVERGENCE,
                                  msg,
                                  "error",
                                  extra=rhat)
            warnings.append(warn)
        elif rhat_max > 1.2:
            msg = "The rhat statistic is larger than 1.2 for some " "parameters."
            warn = SamplerWarning(WarningType.CONVERGENCE,
                                  msg,
                                  "warn",
                                  extra=rhat)
            warnings.append(warn)
        elif rhat_max > 1.05:
            msg = ("The rhat statistic is larger than 1.05 for some "
                   "parameters. This indicates slight problems during "
                   "sampling.")
            warn = SamplerWarning(WarningType.CONVERGENCE,
                                  msg,
                                  "info",
                                  extra=rhat)
            warnings.append(warn)

        eff_min = min(val.min() for val in ess.values())
        n_samples = sizes["chain"] * sizes["draw"]
        # The absolute ESS floor is only enforced once enough draws were taken
        # for it to be reachable; otherwise fall back to relative thresholds.
        if eff_min < 200 and n_samples >= 500:
            msg = ("The estimated number of effective samples is smaller than "
                   "200 for some parameters.")
            warn = SamplerWarning(WarningType.CONVERGENCE,
                                  msg,
                                  "error",
                                  extra=ess)
            warnings.append(warn)
        elif eff_min / n_samples < 0.1:
            msg = "The number of effective samples is smaller than " "10% for some parameters."
            warn = SamplerWarning(WarningType.CONVERGENCE,
                                  msg,
                                  "warn",
                                  extra=ess)
            warnings.append(warn)
        elif eff_min / n_samples < 0.25:
            msg = "The number of effective samples is smaller than " "25% for some parameters."
            warn = SamplerWarning(WarningType.CONVERGENCE,
                                  msg,
                                  "info",
                                  extra=ess)
            warnings.append(warn)

        self._add_warnings(warnings)