Esempio n. 1
0
def summary(samples, prob=0.9, group_by_chain=True):
    """
    Returns a summary table displaying diagnostics of ``samples`` from the
    posterior. The diagnostics computed are mean, standard deviation, median,
    the credibility interval holding ``prob`` of the probability mass (90% by
    default), :func:`~pyro.ops.stats.effective_sample_size` and
    :func:`~pyro.ops.stats.split_gelman_rubin`.

    :param dict samples: dictionary of samples keyed by site name.
    :param float prob: the probability mass of samples within the credibility interval.
    :param bool group_by_chain: If True, each variable in `samples`
        will be treated as having shape `num_chains x num_samples x sample_shape`.
        Otherwise, the corresponding shape will be `num_samples x sample_shape`
        (i.e. without chain dimension).
    :return: dictionary keyed by site name. Each value is an ``OrderedDict``
        with the keys ``mean``, ``std``, ``median``, the two interval-bound
        labels (``5.0%`` and ``95.0%`` for ``prob=0.9``), ``n_eff`` and
        ``r_hat``, in that order.
    :rtype: dict
    """
    if not group_by_chain:
        # Add a leading chain dimension of size 1 so the code below can always
        # assume a `num_chains x num_samples x ...` layout.
        samples = {k: v.unsqueeze(0) for k, v in samples.items()}

    # The interval labels depend only on ``prob``, so build them once instead
    # of once per site.
    hpd_lower = '{:.1f}%'.format(50 * (1 - prob))
    hpd_upper = '{:.1f}%'.format(50 * (1 + prob))

    summary_dict = {}
    for name, value in samples.items():
        # Collapse the first two dimensions (chain, sample) into a single one
        # for the statistics that do not need per-chain information.
        value_flat = torch.reshape(value, (-1,) + value.shape[2:])
        mean = value_flat.mean(dim=0)
        std = value_flat.std(dim=0)
        # ``median(dim=...)`` returns (values, indices); keep only the values.
        median = value_flat.median(dim=0)[0]
        hpdi = stats.hpdi(value_flat, prob=prob)
        # The chain-aware diagnostics receive the un-flattened tensor.
        n_eff = _safe(stats.effective_sample_size)(value)
        r_hat = stats.split_gelman_rubin(value)
        summary_dict[name] = OrderedDict([("mean", mean), ("std", std), ("median", median),
                                          (hpd_lower, hpdi[0]), (hpd_upper, hpdi[1]),
                                          ("n_eff", n_eff), ("r_hat", r_hat)])
    return summary_dict
Esempio n. 2
0
def summary(samples, prob=0.9, group_by_chain=True):
    """
    Prints a summary table displaying diagnostics of ``samples`` from the
    posterior. The diagnostics displayed are mean, standard deviation, median,
    the credibility interval holding ``prob`` of the probability mass (90% by
    default), :func:`~pyro.ops.stats.effective_sample_size` and
    :func:`~pyro.ops.stats.split_gelman_rubin`.

    One row is printed per scalar component of every site; components of
    non-scalar sites are labelled ``name[i,j,...]``. An empty ``samples``
    dictionary prints only the header.

    :param dict samples: dictionary of samples keyed by site name.
    :param float prob: the probability mass of samples within the credibility interval.
    :param bool group_by_chain: If True, each variable in `samples`
        will be treated as having shape `num_chains x num_samples x sample_shape`.
        Otherwise, the corresponding shape will be `num_samples x sample_shape`
        (i.e. without chain dimension).
    :return: None; the table is written to standard output.
    """
    if not group_by_chain:
        # Add a leading chain dimension of size 1 so the code below can always
        # assume a `num_chains x num_samples x ...` layout.
        samples = {k: v.unsqueeze(0) for k, v in samples.items()}

    # Widest label each site can produce: its name followed by the largest
    # index along every sample dimension. Only used to size the name column.
    row_names = {
        k: k + '[' + ','.join(str(size - 1) for size in v.shape[2:]) + ']'
        for k, v in samples.items()
    }
    # The name column is at least 10 characters wide. Putting that minimum in
    # the candidate list keeps ``max`` from raising ValueError when
    # ``samples`` is empty.
    max_len = max([len(label) for label in row_names.values()] + [10])
    name_format = '{:>' + str(max_len) + '}'
    header_format = name_format + ' {:>9} {:>9} {:>9} {:>9} {:>9} {:>9} {:>9}'
    columns = [
        '', 'mean', 'std', 'median', '{:.1f}%'.format(50 * (1 - prob)),
        '{:.1f}%'.format(50 * (1 + prob)), 'n_eff', 'r_hat'
    ]
    print('\n')
    print(header_format.format(*columns))

    row_format = name_format + ' {:>9.2f} {:>9.2f} {:>9.2f} {:>9.2f} {:>9.2f} {:>9.2f} {:>9.2f}'
    for name, value in samples.items():
        # Collapse the first two dimensions (chain, sample) into a single one
        # for the statistics that do not need per-chain information.
        value_flat = torch.reshape(value, (-1, ) + value.shape[2:])
        mean = value_flat.mean(dim=0)
        sd = value_flat.std(dim=0)
        # ``median(dim=...)`` returns (values, indices); keep only the values.
        median = value_flat.median(dim=0)[0]
        hpd = stats.hpdi(value_flat, prob=prob)
        # The chain-aware diagnostics receive the un-flattened tensor.
        n_eff = _safe(stats.effective_sample_size)(value)
        r_hat = stats.split_gelman_rubin(value)
        shape = value_flat.shape[1:]
        if len(shape) == 0:
            # Scalar site: a single row labelled with the bare site name.
            print(
                row_format.format(name, mean, sd, median, hpd[0], hpd[1],
                                  n_eff, r_hat))
        else:
            # Non-scalar site: one row per index tuple of the sample shape.
            for idx in product(*map(range, shape)):
                idx_str = '[{}]'.format(','.join(map(str, idx)))
                print(
                    row_format.format(name + idx_str, mean[idx], sd[idx],
                                      median[idx], hpd[0][idx], hpd[1][idx],
                                      n_eff[idx], r_hat[idx]))
    print('\n')
Esempio n. 3
0
File: mcmc.py Progetto: zyxue/pyro
 def diagnostics(self):
     """
     Return the per-site diagnostics, computing and caching them on first use.

     If ``self._diagnostics`` is already non-empty it is returned as is.
     Otherwise, for every site in ``self.sites`` an ``OrderedDict`` with the
     keys ``n_eff`` (effective sample size) and ``r_hat`` (split Gelman-Rubin)
     is stored in ``self._diagnostics``.

     :return: ``self._diagnostics``, a mapping from site name to its
         diagnostics.
     """
     if self._diagnostics:
         return self._diagnostics
     for site in self.sites:
         # Look the site's support up once and reuse it for both statistics
         # instead of calling ``self.support()`` twice per site.
         site_support = self.support()[site]
         site_stats = OrderedDict()
         try:
             site_stats["n_eff"] = stats.effective_sample_size(site_support)
         except NotImplementedError:
             # Report NaN rather than failing when the effective sample size
             # is not implemented for this input.
             site_stats["n_eff"] = torch.tensor(float('nan'))
         site_stats["r_hat"] = stats.split_gelman_rubin(site_support)
         self._diagnostics[site] = site_stats
     return self._diagnostics
Esempio n. 4
0
def diagnostics(samples, group_by_chain=True):
    """
    Computes convergence diagnostics (effective sample size and split
    Gelman-Rubin) for every site from its posterior samples.

    :param dict samples: dictionary of samples keyed by site name.
    :param bool group_by_chain: If True, each variable in `samples`
        will be treated as having shape `num_chains x num_samples x sample_shape`.
        Otherwise, the corresponding shape will be `num_samples x sample_shape`
        (i.e. without chain dimension).
    :return: dictionary mapping each site name to an ``OrderedDict`` with the
        keys ``n_eff`` and ``r_hat``.
    """
    per_site = {}
    for name, draws in samples.items():
        # Without a chain dimension, add one of size 1 at the front.
        chained = draws if group_by_chain else draws.unsqueeze(0)
        per_site[name] = OrderedDict([
            ("n_eff", _safe(stats.effective_sample_size)(chained)),
            ("r_hat", stats.split_gelman_rubin(chained)),
        ])
    return per_site
Esempio n. 5
0
def test_split_gelman_rubin_agree_with_gelman_rubin():
    """
    ``split_gelman_rubin`` on a 2 x 10 tensor must match ``gelman_rubin`` on
    the same data rearranged into 4 rows of 5.
    """
    chains = torch.rand(2, 10)
    # Cut each row into two halves of 5 values and stack the halves as rows.
    half_chains = chains.reshape(2, 2, 5).reshape(4, 5)
    expected = gelman_rubin(half_chains)
    actual = split_gelman_rubin(chains)
    assert_equal(expected, actual)