def mcmc_stats(runs, burnin, prob, batch):
    """
    Summarise MCMC output as a table of posterior statistics.

    Input
        runs:   Monte Carlo sample, indexed as (draw, parameter)
        burnin: number of leading draws discarded as burn-in
        prob:   interval probability (0 < prob < 1)
        batch:  number of batches the sequence of draws is split into
    Output
        Data frame of posterior statistics, one row per parameter and one
        column per statistic (mean, median, standard deviation, Monte Carlo
        error, credible interval bounds, HPDI bounds, R-hat)
    Raises
        ValueError: if fewer than `batch` draws remain after burn-in
    """
    traces = runs[burnin:, :]
    n_draws = traces.shape[0]
    k = traces.shape[1]
    n = n_draws // batch
    if n < 1:
        raise ValueError(
            'need at least `batch` draws after burn-in: got {0} draws for '
            'batch={1}'.format(n_draws, batch)
        )

    # The batched diagnostics reshape each column to (n, batch), which needs
    # exactly n * batch values. When the draw count is not a multiple of
    # `batch`, the leftover leading draws are dropped for these diagnostics
    # only; with an exact multiple nothing is dropped.
    batched = traces[n_draws - n * batch:, :]

    alpha = 100 * (1.0 - prob)
    post_mean = np.mean(traces, axis=0)
    post_median = np.median(traces, axis=0)
    post_sd = np.std(traces, axis=0)
    # order='F' fills column by column, so each of the `batch` columns holds
    # n consecutive draws.
    mc_err = [
        pm.mcse(batched[:, i].reshape((n, batch), order='F')).item(0)
        for i in range(k)
    ]
    # Equal-tailed interval: alpha / 2 percent is cut from each tail.
    ci_lower = np.percentile(traces, 0.5 * alpha, axis=0)
    ci_upper = np.percentile(traces, 100 - 0.5 * alpha, axis=0)
    # NOTE(review): the second positional argument is passed as 1 - prob,
    # i.e. a tail probability. Confirm the installed PyMC3 version interprets
    # it that way rather than as the interval probability itself.
    hpdi = pm.hpd(traces, 1.0 - prob)
    rhat = [
        pm.rhat(batched[:, i].reshape((n, batch), order='F')).item(0)
        for i in range(k)
    ]

    # Stack one row per statistic, then transpose to one row per parameter.
    stats = np.vstack((post_mean, post_median, post_sd, mc_err,
                       ci_lower, ci_upper, hpdi.T, rhat)).T
    stats_string = [
        '平均', '中央値', '標準偏差', '近似誤差',
        '信用区間(下限)', '信用区間(上限)',
        'HPDI(下限)', 'HPDI(上限)', '$\\hat R$'
    ]
    # The first k - 1 columns are labelled as coefficients beta_1..beta_{k-1};
    # the last column is labelled as the variance sigma^2.
    param_string = ['$\\beta_{0:<d}$'.format(i + 1) for i in range(k - 1)]
    param_string.append('$\\sigma^2$')
    return pd.DataFrame(stats, index=param_string, columns=stats_string)
def test_Rhat(self):
    """Check that R-hat is within 1% of 1 for every variable after burn-in."""
    post_burn = self.trace[self.burn :]
    rhat_by_var = pm.rhat(post_burn)
    for name in rhat_by_var:
        npt.assert_allclose(rhat_by_var[name], 1, rtol=0.01)
def _run_convergence_checks(self, trace, model):
    """Run R-hat and effective-sample-size checks and record any warnings.

    With a single chain only an informational warning is recorded and the
    method returns early. Otherwise the R-hat and ESS results for the
    model's free variables found in the trace are stored on ``self._rhat``
    and ``self._ess``, and at most one warning per diagnostic is passed to
    ``self._add_warnings``.

    Parameters
    ----------
    trace
        Sampling result; ``nchains``, ``varnames`` and ``len()`` are read.
    model
        Model whose ``free_RVs`` and ``deterministics`` name the variables.
    """
    if trace.nchains == 1:
        single_chain_msg = ("Only one chain was sampled, this makes it impossible to "
                            "run some convergence checks")
        single_chain_warning = SamplerWarning(
            WarningType.BAD_PARAMS, single_chain_msg, 'info', None, None, None)
        self._add_warnings([single_chain_warning])
        return

    from pymc3 import rhat, ess

    known_names = {rv.name for rv in model.free_RVs + model.deterministics}

    # Prefer the untransformed name of a transformed variable when the model
    # also exposes it; keep only names that are actually in the trace.
    checked_names = []
    for free_rv in model.free_RVs:
        name = free_rv.name
        if is_transformed_name(name):
            untransformed = get_untransformed_name(name)
            if untransformed in known_names:
                name = untransformed
        if name in trace.varnames:
            checked_names.append(name)

    ess_stats = ess(trace, var_names=checked_names)
    self._ess = ess_stats
    rhat_stats = rhat(trace, var_names=checked_names)
    self._rhat = rhat_stats

    collected = []

    # R-hat: report only the most severe threshold that is exceeded.
    rhat_levels = (
        (1.4,
         "The rhat statistic is larger than 1.4 for some "
         "parameters. The sampler did not converge.",
         'error'),
        (1.2,
         "The rhat statistic is larger than 1.2 for some "
         "parameters.",
         'warn'),
        (1.05,
         "The rhat statistic is larger than 1.05 for some "
         "parameters. This indicates slight problems during "
         "sampling.",
         'info'),
    )
    worst_rhat = max(values.max() for values in rhat_stats.values())
    for threshold, text, level in rhat_levels:
        if worst_rhat > threshold:
            collected.append(SamplerWarning(
                WarningType.CONVERGENCE, text, level, None, None, rhat_stats))
            break

    # ESS: an absolute floor first (only when enough draws exist), then the
    # fraction of effective samples relative to the total number of draws.
    smallest_ess = min(values.min() for values in ess_stats.values())
    total_draws = len(trace) * trace.nchains
    ess_finding = None
    if smallest_ess < 200 and total_draws >= 500:
        ess_finding = ("The estimated number of effective samples is smaller than "
                       "200 for some parameters.", 'error')
    else:
        ess_fraction = smallest_ess / total_draws
        if ess_fraction < 0.1:
            ess_finding = ("The number of effective samples is smaller than "
                           "10% for some parameters.", 'warn')
        elif ess_fraction < 0.25:
            ess_finding = ("The number of effective samples is smaller than "
                           "25% for some parameters.", 'info')
    if ess_finding is not None:
        text, level = ess_finding
        collected.append(SamplerWarning(
            WarningType.CONVERGENCE, text, level, None, None, ess_stats))

    self._add_warnings(collected)