def diagnostics(self):
    """
    Return per-site diagnostics, computing them on first use.

    If ``self._diagnostics`` is already non-empty it is returned unchanged.
    Otherwise, for every site in ``self.sites`` an ``OrderedDict`` is stored
    under ``self._diagnostics[site]`` with two entries:

    - ``"n_eff"``: result of ``stats.effective_sample_size`` on the site's
      support, or a NaN tensor when that call raises ``NotImplementedError``.
    - ``"r_hat"``: result of ``stats.split_gelman_rubin`` on the site's
      support.

    :return: ``self._diagnostics``, mapping each site to its statistics.
    """
    if self._diagnostics:
        return self._diagnostics

    for site in self.sites:
        # Fetch the support once per site; both statistics are computed
        # from it, so there is no need to call ``support()`` twice.
        site_support = self.support()[site]
        site_stats = OrderedDict()
        try:
            site_stats["n_eff"] = stats.effective_sample_size(site_support)
        except NotImplementedError:
            # Best-effort: report NaN rather than failing the whole summary.
            site_stats["n_eff"] = torch.tensor(float('nan'))
        site_stats["r_hat"] = stats.split_gelman_rubin(site_support)
        self._diagnostics[site] = site_stats
    return self._diagnostics
def test_effective_sample_size():
    """Check ``effective_sample_size`` on a 100x10 ramp against a reference value."""
    samples = torch.arange(1000.0).reshape(100, 10)
    # Reference value obtained from arviz.
    expected_n_eff = 52.64

    with xfail_if_not_implemented():
        # test against arviz
        actual_n_eff = effective_sample_size(samples).item()
        assert_equal(actual_n_eff, expected_n_eff, prec=0.01)