def summary(samples, prob=0.9, group_by_chain=True):
    """
    Builds a dictionary of posterior diagnostics for every site in ``samples``.

    Each site's entry contains the mean, standard deviation and median of the
    pooled draws, the two bounds returned by ``stats.hpdi`` for probability
    mass ``prob``, :func:`~pyro.ops.stats.effective_sample_size` and
    :func:`~pyro.ops.stats.split_gelman_rubin`.

    :param dict samples: dictionary of samples keyed by site name.
    :param float prob: the probability mass of samples within the credibility
        interval.
    :param bool group_by_chain: If True, each variable in `samples`
        will be treated as having shape `num_chains x num_samples x sample_shape`.
        Otherwise, the corresponding shape will be `num_samples x sample_shape`
        (i.e. without chain dimension).
    :return: dictionary mapping each site name to an ``OrderedDict`` with the
        keys ``"mean"``, ``"std"``, ``"median"``, two percentage labels derived
        from ``prob`` (``"5.0%"`` and ``"95.0%"`` for the default), ``"n_eff"``
        and ``"r_hat"``.
    """
    if not group_by_chain:
        # Prepend a chain dimension of size one so the indexing below applies.
        samples = {site: draws.unsqueeze(0) for site, draws in samples.items()}

    table = {}
    for site, draws in samples.items():
        # Collapse the first two dimensions; every statistic reduces over dim 0.
        pooled = torch.reshape(draws, (-1,) + draws.shape[2:])
        center = pooled.mean(dim=0)
        spread = pooled.std(dim=0)
        middle, _ = pooled.median(dim=0)
        interval = stats.hpdi(pooled, prob=prob)
        # The chain-aware diagnostics receive the un-pooled tensor.
        ess = _safe(stats.effective_sample_size)(draws)
        gr = stats.split_gelman_rubin(draws)

        lower_key = '{:.1f}%'.format(50 * (1 - prob))
        upper_key = '{:.1f}%'.format(50 * (1 + prob))

        entry = OrderedDict()
        entry["mean"] = center
        entry["std"] = spread
        entry["median"] = middle
        entry[lower_key] = interval[0]
        entry[upper_key] = interval[1]
        entry["n_eff"] = ess
        entry["r_hat"] = gr
        table[site] = entry
    return table
def summary(samples, prob=0.9, group_by_chain=True):
    """
    Prints a summary table displaying diagnostics of ``samples`` from the
    posterior. The diagnostics displayed are mean, standard deviation, median,
    the bounds returned by ``stats.hpdi`` for probability mass ``prob``
    (labelled ``5.0%`` / ``95.0%`` for the default ``prob=0.9``),
    :func:`~pyro.ops.stats.effective_sample_size` and
    :func:`~pyro.ops.stats.split_gelman_rubin`.

    One row is printed per scalar site, and one row per element (labelled
    ``name[i,j,...]``) for sites with a non-empty sample shape. An empty
    ``samples`` dictionary prints only the header row.

    :param dict samples: dictionary of samples keyed by site name.
    :param float prob: the probability mass of samples within the credibility
        interval.
    :param bool group_by_chain: If True, each variable in `samples`
        will be treated as having shape `num_chains x num_samples x sample_shape`.
        Otherwise, the corresponding shape will be `num_samples x sample_shape`
        (i.e. without chain dimension).
    :return: None; the table is written to standard output.
    """
    if not group_by_chain:
        samples = {k: v.unsqueeze(0) for k, v in samples.items()}

    # Width of the widest row label each site can produce: the site name
    # followed by its largest index (size - 1 per dimension) in brackets.
    label_widths = [
        len(k + '[' + ','.join(str(dim - 1) for dim in v.shape[2:]) + ']')
        for k, v in samples.items()
    ]
    # ``default`` keeps an empty ``samples`` from raising ValueError in max();
    # the name column is never narrower than 10 characters.
    max_len = max(max(label_widths, default=0), 10)

    name_format = '{:>' + str(max_len) + '}'
    header_format = name_format + ' {:>9} {:>9} {:>9} {:>9} {:>9} {:>9} {:>9}'
    row_format = name_format + ' {:>9.2f} {:>9.2f} {:>9.2f} {:>9.2f} {:>9.2f} {:>9.2f} {:>9.2f}'
    columns = [
        '',
        'mean',
        'std',
        'median',
        '{:.1f}%'.format(50 * (1 - prob)),
        '{:.1f}%'.format(50 * (1 + prob)),
        'n_eff',
        'r_hat',
    ]

    print('\n')
    print(header_format.format(*columns))

    for name, value in samples.items():
        # Collapse the first two dimensions; statistics reduce over dim 0.
        value_flat = torch.reshape(value, (-1,) + value.shape[2:])
        mean = value_flat.mean(dim=0)
        sd = value_flat.std(dim=0)
        median = value_flat.median(dim=0)[0]
        hpd = stats.hpdi(value_flat, prob=prob)
        # The chain-aware diagnostics receive the un-flattened tensor.
        n_eff = _safe(stats.effective_sample_size)(value)
        r_hat = stats.split_gelman_rubin(value)

        shape = value_flat.shape[1:]
        if not shape:
            # Scalar site: a single row labelled with the bare site name.
            print(row_format.format(name, mean, sd, median,
                                    hpd[0], hpd[1], n_eff, r_hat))
            continue

        # Non-scalar site: one row per element of the sample shape.
        for idx in product(*map(range, shape)):
            idx_str = '[{}]'.format(','.join(map(str, idx)))
            print(row_format.format(name + idx_str, mean[idx], sd[idx],
                                    median[idx], hpd[0][idx], hpd[1][idx],
                                    n_eff[idx], r_hat[idx]))
    print('\n')
def diagnostics(self):
    """
    Return the per-site diagnostics, filling them in when none are stored yet.

    When ``self._diagnostics`` is already truthy it is returned unchanged.
    Otherwise an ``OrderedDict`` with the entries ``"n_eff"`` and ``"r_hat"``
    is stored in ``self._diagnostics`` for every site in ``self.sites``,
    computed from ``self.support()[site]``. ``"n_eff"`` is a NaN tensor when
    computing the effective sample size raises ``NotImplementedError``.

    :return: ``self._diagnostics``.
    """
    if not self._diagnostics:
        for name in self.sites:
            try:
                ess = stats.effective_sample_size(self.support()[name])
            except NotImplementedError:
                # Fall back to NaN rather than failing the whole report.
                ess = torch.tensor(float('nan'))
            entry = OrderedDict()
            entry["n_eff"] = ess
            entry["r_hat"] = stats.split_gelman_rubin(self.support()[name])
            self._diagnostics[name] = entry
    return self._diagnostics
def diagnostics(samples, group_by_chain=True):
    """
    Collects the effective sample size and split Gelman-Rubin statistic for
    each site in ``samples``.

    :param dict samples: dictionary of samples keyed by site name.
    :param bool group_by_chain: If True, each variable in `samples`
        will be treated as having shape `num_chains x num_samples x sample_shape`.
        Otherwise, the corresponding shape will be `num_samples x sample_shape`
        (i.e. without chain dimension).
    :return: dictionary mapping each site name to an ``OrderedDict`` with the
        keys ``"n_eff"`` and ``"r_hat"``.
    """
    result = {}
    for site, support in samples.items():
        # Prepend a chain dimension of size one when the input has none.
        chains = support if group_by_chain else support.unsqueeze(0)
        result[site] = OrderedDict([
            ("n_eff", _safe(stats.effective_sample_size)(chains)),
            ("r_hat", stats.split_gelman_rubin(chains)),
        ])
    return result
def test_split_gelman_rubin_agree_with_gelman_rubin():
    """
    ``split_gelman_rubin`` on a ``(2, 10)`` tensor must equal ``gelman_rubin``
    on the same values laid out as four rows of five (each row split in half).
    """
    x = torch.rand(2, 10)
    # (2, 10) -> (2, 2, 5) -> (4, 5): every row of ten becomes two rows of five.
    half_chains = x.reshape(2, 2, 5).reshape(4, 5)
    plain_on_halves = gelman_rubin(half_chains)
    split_on_full = split_gelman_rubin(x)
    assert_equal(plain_on_halves, split_on_full)