示例#1
0
文件: mcmc.py 项目: zyxue/pyro
 def diagnostics(self):
     """Return per-site convergence diagnostics, computing them on first use.

     For every site in ``self.sites`` an ``OrderedDict`` is built holding:

     - ``"n_eff"``: the result of ``stats.effective_sample_size`` on that
       site's entry in ``self.support()``, or a NaN tensor when that call
       raises ``NotImplementedError``;
     - ``"r_hat"``: the result of ``stats.split_gelman_rubin`` on the same
       entry.

     Results are cached in ``self._diagnostics``. If the cache is already
     non-empty, it is returned without recomputation.

     :return: ``self._diagnostics``, mapping each site to its ``OrderedDict``
         of statistics.
     """
     if self._diagnostics:
         return self._diagnostics
     # Fetch the support once instead of twice per site.
     # NOTE(review): assumes self.support() returns the same mapping on every
     # call and has no side effects. TODO confirm.
     support = self.support()
     # Collect into a local mapping first. If a statistic raises part-way
     # through, the cache stays empty rather than partially filled. A partial
     # cache would otherwise be returned as if complete on later calls.
     computed = OrderedDict()
     for site in self.sites:
         # Presumably the stacked samples for this site. TODO confirm.
         site_support = support[site]
         site_stats = OrderedDict()
         try:
             site_stats["n_eff"] = stats.effective_sample_size(site_support)
         except NotImplementedError:
             # Fall back to NaN when the estimator is unavailable for this input.
             site_stats["n_eff"] = torch.tensor(float('nan'))
         site_stats["r_hat"] = stats.split_gelman_rubin(site_support)
         computed[site] = site_stats
     # Update in place rather than rebinding, so any existing reference to
     # the cache object keeps seeing the results.
     self._diagnostics.update(computed)
     return self._diagnostics
示例#2
0
def test_effective_sample_size():
    """Check ``effective_sample_size`` on a fixed input against a reference value.

    The reference value (52.64) was obtained with arviz. The check is skipped
    or xfailed via ``xfail_if_not_implemented`` where the estimator is not
    implemented.
    """
    # Deterministic 100 x 10 tensor holding 0.0, 1.0, ..., 999.0 in row-major order.
    # NOTE(review): which axis is treated as chains vs. draws depends on
    # effective_sample_size's defaults. Confirm.
    samples = torch.arange(1000.0).reshape(100, 10)
    reference_ess = 52.64  # from arviz, per the original comment
    with xfail_if_not_implemented():
        estimated_ess = effective_sample_size(samples).item()
        assert_equal(estimated_ess, reference_ess, prec=0.01)