示例#1
0
def mcmc_stats(runs, burnin, prob, batch):
    """Summarise MCMC draws as a DataFrame of posterior statistics.

    Parameters
    ----------
    runs : numpy.ndarray
        Monte Carlo draws as a 2-D array (rows are iterations, columns are
        parameters). The last column is labelled sigma^2 and the others
        beta_1, beta_2, ... in LaTeX form.
    burnin : int
        Number of leading rows discarded as burn-in.
    prob : float
        Probability mass of the credible intervals; must satisfy 0 < prob < 1.
    batch : int
        Number of consecutive segments the post-burn-in series is split into
        for the Monte Carlo error and R-hat. Draws that do not fill a complete
        segment (the remainder of the division) are left out of those two
        statistics only; all other statistics use every post-burn-in draw.

    Returns
    -------
    pandas.DataFrame
        One row per parameter with columns: mean, median, standard deviation,
        Monte Carlo error, equal-tailed credible interval (lower/upper),
        HPD interval (lower/upper) and R-hat (column labels are in Japanese).

    Raises
    ------
    ValueError
        If ``prob`` is not in (0, 1), ``batch`` is not positive, or fewer
        than ``batch`` draws remain after burn-in.
    """
    if not 0.0 < prob < 1.0:
        raise ValueError('prob must satisfy 0 < prob < 1, got {!r}'.format(prob))
    if batch < 1:
        raise ValueError('batch must be a positive integer, got {!r}'.format(batch))
    traces = runs[burnin:, :]
    n_draws = traces.shape[0]
    k = traces.shape[1]
    n = n_draws // batch  # draws per segment
    if n < 1:
        raise ValueError('need at least batch={} draws after burn-in, got {}'
                         .format(batch, n_draws))
    # reshape() needs exactly n * batch elements, so trailing draws that do not
    # fill a complete segment are dropped for the segment-based statistics
    # (previously this raised ValueError whenever n_draws % batch != 0).
    segmented = traces[:n * batch, :]
    alpha = 100 * (1.0 - prob)  # total tail mass, in percent
    post_mean = np.mean(traces, axis=0)
    post_median = np.median(traces, axis=0)
    post_sd = np.std(traces, axis=0)
    # order='F' fills columns first, so column j holds draws j*n .. (j+1)*n - 1.
    # NOTE(review): if pm.mcse / pm.rhat are the ArviZ-backed versions, an
    # ndarray is read as (chain, draw), i.e. this (n, batch) matrix would be
    # taken as n chains of length batch -- confirm against the installed
    # PyMC3 version before relying on these two columns.
    mc_err = [pm.mcse(segmented[:, i].reshape((n, batch), order='F')).item(0)
              for i in range(k)]
    ci_lower = np.percentile(traces, 0.5 * alpha, axis=0)
    ci_upper = np.percentile(traces, 100 - 0.5 * alpha, axis=0)
    # NOTE(review): the second positional argument of pm.hpd is the tail
    # probability ``alpha`` in pre-ArviZ PyMC3 but ``credible_interval`` in
    # ArviZ-backed releases; passing 1 - prob is only right for the former.
    hpdi = pm.hpd(traces, 1.0 - prob)
    rhat = [pm.rhat(segmented[:, i].reshape((n, batch), order='F')).item(0)
            for i in range(k)]
    stats = np.vstack((post_mean, post_median, post_sd, mc_err, ci_lower,
                       ci_upper, hpdi.T, rhat)).T
    stats_string = [
        '平均', '中央値', '標準偏差', '近似誤差', '信用区間(下限)', '信用区間(上限)', 'HPDI(下限)',
        'HPDI(上限)', '$\\hat R$'
    ]
    # LaTeX subscripts longer than one character must be braced ($\beta_{10}$,
    # not $\beta_10$); single-digit labels keep their original unbraced form.
    param_string = [
        '$\\beta_{0}$'.format(j) if j < 10 else '$\\beta_{{{0}}}$'.format(j)
        for j in range(1, k)
    ]
    param_string.append('$\\sigma^2$')
    return pd.DataFrame(stats, index=param_string, columns=stats_string)
示例#2
0
 def test_Rhat(self):
     """Assert every variable's R-hat on the post-burn-in trace is within 1% of 1."""
     post_burn_trace = self.trace[self.burn:]
     rhat_by_var = pm.rhat(post_burn_trace)
     for name in rhat_by_var:
         npt.assert_allclose(rhat_by_var[name], 1, rtol=0.01)
示例#3
0
    def _run_convergence_checks(self, trace, model):
        """Run R-hat and effective-sample-size diagnostics and record warnings.

        With a single chain, only an 'info' BAD_PARAMS warning is recorded and
        no diagnostics are computed. Otherwise the free variables of ``model``
        that appear in ``trace`` (mapped back to their untransformed name when
        the model defines that name) are checked. The ESS and R-hat results are
        stored on ``self._ess`` and ``self._rhat``, and at most one R-hat
        warning plus at most one ESS warning are passed to
        ``self._add_warnings``.

        Parameters
        ----------
        trace
            Sampled trace; must expose ``nchains``, ``varnames`` and ``len()``.
        model
            Model providing ``free_RVs`` and ``deterministics``.
        """
        if trace.nchains == 1:
            # Some checks cannot be run on one chain, so skip them all.
            note = ("Only one chain was sampled, this makes it impossible to "
                    "run some convergence checks")
            self._add_warnings([
                SamplerWarning(WarningType.BAD_PARAMS, note, 'info', None,
                               None, None),
            ])
            return

        from pymc3 import rhat, ess

        # Names the model knows about: free RVs and deterministics.
        model_names = [var.name for var in model.free_RVs + model.deterministics]
        selected = []
        for var in model.free_RVs:
            name = var.name
            if is_transformed_name(name):
                # Prefer the untransformed name, but only if the model defines it.
                untransformed = get_untransformed_name(name)
                if untransformed in model_names:
                    name = untransformed
            if name in trace.varnames:
                selected.append(name)

        ess_result = ess(trace, var_names=selected)
        self._ess = ess_result
        rhat_result = rhat(trace, var_names=selected)
        self._rhat = rhat_result

        found = []

        # R-hat: thresholds are tested from most to least severe and only the
        # first one exceeded produces a warning.
        worst_rhat = max(values.max() for values in rhat_result.values())
        rhat_levels = (
            (1.4, 'error', "The rhat statistic is larger than 1.4 for some "
                           "parameters. The sampler did not converge."),
            (1.2, 'warn', "The rhat statistic is larger than 1.2 for some "
                          "parameters."),
            (1.05, 'info', "The rhat statistic is larger than 1.05 for some "
                           "parameters. This indicates slight problems during "
                           "sampling."),
        )
        for limit, severity, text in rhat_levels:
            if worst_rhat > limit:
                found.append(SamplerWarning(WarningType.CONVERGENCE, text,
                                            severity, None, None, rhat_result))
                break

        # ESS: an absolute floor of 200 once at least 500 draws exist,
        # otherwise floors relative to the total number of draws.
        worst_ess = min(values.min() for values in ess_result.values())
        total_draws = len(trace) * trace.nchains
        if worst_ess < 200 and total_draws >= 500:
            ess_severity = 'error'
            ess_text = ("The estimated number of effective samples is smaller than "
                        "200 for some parameters.")
        elif worst_ess / total_draws < 0.1:
            ess_severity = 'warn'
            ess_text = ("The number of effective samples is smaller than "
                        "10% for some parameters.")
        elif worst_ess / total_draws < 0.25:
            ess_severity = 'info'
            ess_text = ("The number of effective samples is smaller than "
                        "25% for some parameters.")
        else:
            ess_severity = None
        if ess_severity is not None:
            found.append(SamplerWarning(WarningType.CONVERGENCE, ess_text,
                                        ess_severity, None, None, ess_result))

        self._add_warnings(found)